In [1]:
# import libaries
from pyspark.sql.session import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import StructType, StructField, IntegerType, StringType,FloatType
import pyspark.sql.functions as f
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os


from sklearn.preprocessing import OneHotEncoder
from sklearn.compose import make_column_transformer
from sklearn.feature_extraction.text import CountVectorizer



In [4]:
class PSTool:
    def __init__(self):
        print('Creating output folder')
        if os.path.exists('output'):
            pass
        else:
            os.makedirs('output')

    def pyspark_session(self, host_location):
        """
        Creates and returns spark session object
        """
        print('Starting session')
        sc = SparkContext(host_location)  # Create spark context
        spark = SparkSession(sc)  # Create session
        return spark

    def file_loader(self, path, delim, spark_obj, schema):
        print('Loading in file')
        data = spark_obj.read.options(delimiter=delim).option("header","False").csv(path, schema=schema)
        
        print('File loaded')
        return data

    def get_questions(self, df):
        pass

if __name__ == "__main__":
    pstool = PSTool()  # Instanciate object
    spk = pstool.pyspark_session('local[16]')  # start session
    # load data
    path = '/data/dataprocessing/interproscan/all_bacilli.tsv'
    # path = 'all_bacilli_subset.tsv'
    schema = StructType([
        StructField("Protein_accession", StringType(), True),
        StructField("Sequence_MD5_digest", StringType(), True),
        StructField("Sequence_length", IntegerType(), True),
        StructField("Analysis", StringType(), True),
        StructField("Signature_accession", StringType(), True),
        StructField("Signature_description", StringType(), True),
        StructField("Start_location", IntegerType(), True),
        StructField("Stop_location", IntegerType(), True),
        StructField("Score", FloatType(), True),
        StructField("Status", StringType(), True),
        StructField("Date", StringType(), True),
        StructField("InterPro_annotations_accession", StringType(), True),
        StructField("InterPro_annotations_description", StringType(), True),
        StructField("GO_annotations", StringType(), True),
        StructField("Pathways_annotations", StringType(), True)])
    
    df = pstool.file_loader(path, '\t', spk, schema)
#     pstool.get_questions(df)
#     print('Closing spark session')
#     spk.sparkContext.stop()
    df.printSchema()  # Shows column names and some info

Creating output folder
Starting session


22/07/08 13:30:02 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
22/07/08 13:30:04 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
22/07/08 13:30:04 WARN Utils: Service 'SparkUI' could not bind on port 4041. Attempting port 4042.


Loading in file
File loaded
root
 |-- Protein_accession: string (nullable = true)
 |-- Sequence_MD5_digest: string (nullable = true)
 |-- Sequence_length: integer (nullable = true)
 |-- Analysis: string (nullable = true)
 |-- Signature_accession: string (nullable = true)
 |-- Signature_description: string (nullable = true)
 |-- Start_location: integer (nullable = true)
 |-- Stop_location: integer (nullable = true)
 |-- Score: float (nullable = true)
 |-- Status: string (nullable = true)
 |-- Date: string (nullable = true)
 |-- InterPro_annotations_accession: string (nullable = true)
 |-- InterPro_annotations_description: string (nullable = true)
 |-- GO_annotations: string (nullable = true)
 |-- Pathways_annotations: string (nullable = true)



### Data cleaning
The "function" of the protein which is the "class" your model should predict is defined as the InterPRO number which covers:
- ">" 90% of the protein's sequenc
- Covers the largest length of the sequence

As we are to train on the interpro number:
- Remove any rows without interpro number


In [6]:
# remove rows that do not have Interpro number
print(df.select('InterPro_annotations_accession').distinct().count())
IPRO_filt = df.filter(df["InterPro_annotations_accession"] != '-')
print(IPRO_filt.select('InterPro_annotations_accession').distinct().count())

# check amount of rows left. 
print('len:', len(df.columns))
print('count:' , df.count())
# check if the columns are propperly loaded in. 
df_sizes = IPRO_filt.withColumn('perc', abs(df.Start_location - df.Stop_location) / df.Sequence_length).sort('perc')
for i in df_sizes.columns:
    print(df_sizes.select(i).show(5))



                                                                                

9704


                                                                                

9703
len: 15


                                                                                

count: 4200591


                                                                                

+--------------------+
|   Protein_accession|
+--------------------+
|gi|510143242|gb|A...|
|gi|510143242|gb|A...|
|gi|510143242|gb|A...|
|gi|510143242|gb|A...|
|gi|510143242|gb|A...|
+--------------------+
only showing top 5 rows

None


                                                                                

+--------------------+
| Sequence_MD5_digest|
+--------------------+
|d6f8e49a4de47c68d...|
|d6f8e49a4de47c68d...|
|d6f8e49a4de47c68d...|
|d6f8e49a4de47c68d...|
|d6f8e49a4de47c68d...|
+--------------------+
only showing top 5 rows

None


                                                                                

+---------------+
|Sequence_length|
+---------------+
|           6359|
|           6359|
|           6359|
|           6359|
|           6359|
+---------------+
only showing top 5 rows

None


                                                                                

+---------------+
|       Analysis|
+---------------+
|ProSitePatterns|
|ProSitePatterns|
|ProSitePatterns|
|ProSitePatterns|
|ProSitePatterns|
+---------------+
only showing top 5 rows

None


                                                                                

+-------------------+
|Signature_accession|
+-------------------+
|            PS00455|
|            PS00455|
|            PS00455|
|            PS00455|
|            PS00455|
+-------------------+
only showing top 5 rows

None


                                                                                

+---------------------+
|Signature_description|
+---------------------+
| Putative AMP-bind...|
| Putative AMP-bind...|
| Putative AMP-bind...|
| Putative AMP-bind...|
| Putative AMP-bind...|
+---------------------+
only showing top 5 rows

None


                                                                                

+--------------+
|Start_location|
+--------------+
|           604|
|          5690|
|          4177|
|          1645|
|          3144|
+--------------+
only showing top 5 rows

None


                                                                                

+-------------+
|Stop_location|
+-------------+
|          615|
|         1656|
|         4188|
|         5701|
|         3155|
+-------------+
only showing top 5 rows

None


                                                                                

+-----+
|Score|
+-----+
| null|
| null|
| null|
| null|
| null|
+-----+
only showing top 5 rows

None


                                                                                

+------+
|Status|
+------+
|     T|
|     T|
|     T|
|     T|
|     T|
+------+
only showing top 5 rows

None


                                                                                

+----------+
|      Date|
+----------+
|25-04-2022|
|25-04-2022|
|25-04-2022|
|25-04-2022|
|25-04-2022|
+----------+
only showing top 5 rows

None


                                                                                

+------------------------------+
|InterPro_annotations_accession|
+------------------------------+
|                     IPR020845|
|                     IPR020845|
|                     IPR020845|
|                     IPR020845|
|                     IPR020845|
+------------------------------+
only showing top 5 rows

None


                                                                                

+--------------------------------+
|InterPro_annotations_description|
+--------------------------------+
|            AMP-binding, cons...|
|            AMP-binding, cons...|
|            AMP-binding, cons...|
|            AMP-binding, cons...|
|            AMP-binding, cons...|
+--------------------------------+
only showing top 5 rows

None


                                                                                

+--------------+
|GO_annotations|
+--------------+
|             -|
|             -|
|             -|
|             -|
|             -|
+--------------+
only showing top 5 rows

None


                                                                                

+--------------------+
|Pathways_annotations|
+--------------------+
|MetaCyc: PWY-1061...|
|MetaCyc: PWY-1061...|
|MetaCyc: PWY-1061...|
|MetaCyc: PWY-1061...|
|MetaCyc: PWY-1061...|
+--------------------+
only showing top 5 rows

None




+--------------------+
|                perc|
+--------------------+
|0.001729831734549...|
|0.001729831734549...|
|0.001729831734549...|
|0.001729831734549...|
|0.001729831734549...|
+--------------------+
only showing top 5 rows

None


                                                                                

### modeling

In [5]:
# Organize data into labels for easier use
label_names = ["InterPro_annotations_accession"]
labels = IPRO_filt.select(label_names)
feature_names = [i for i in IPRO_filt.columns if i not in label_names]
features = IPRO_filt.select(feature_names)

print(label_names)
print(labels[:5])
print(feature_names[:5])
print(features[:5])

train, test, train_labels, test_labels = train_test_split(features, labels, test_size=0.2, random_state=42)

['InterPro_annotations_accession']


TypeError: unexpected item type: <class 'slice'>

In [None]:
# get a list of models to evaluate

def get_models():
	models = {'lr': LogisticRegression()}
	models['knn'] = KNeighborsClassifier()
	models['cart'] = DecisionTreeClassifier()
	models['svm'] = SVC()
	models['bayes'] = GaussianNB()
	return models
 
# evaluate a given model using cross-validation
def evaluate_model(model, X, y):
	cv = RepeatedStratifiedKFold(n_splits=10, n_repeats=3, random_state=42)
	return cross_val_score(
	    model, X, y, scoring='accuracy', cv=cv, n_jobs=-1, error_score='raise')
 
# define dataset
X = train
y = train_labels

# get the models to evaluate
models = get_models()
# evaluate the models and store results
results, names = list(), list()
for name, model in models.items():
	scores = evaluate_model(model, X, y)
	results.append(scores)
	names.append(name)
	print('model %s accuracy: %.3f (%.3f)' % (name, np.mean(scores), np.std(scores)))
# plot model performance for comparison
pyplot.boxplot(results, labels=names, showmeans=True)
pyplot.show()

In [None]:
# model evaluation
bayes = GaussianNB()
model = bayes.fit(train, train_labels)
preds = bayes.predict(test)

print(f"accuracy: {accuracy_score(test_labels, preds)}")
cf_matrix = confusion_matrix(test_labels, preds)
# print(confusion_matrix)
# turn confusion matrix into percentages
df_cm = cf_matrix.astype('float') / cf_matrix.sum(axis=1)[:, np.newaxis] 
plt.figure()
heatmap = sns.heatmap(df_cm, cmap="Blues", annot=True)
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.title('Naive base Model Results in confusion matrix')
plt.show()  

print(classification_report(test_labels, preds))

In [None]:
# closing the spark sessison

print('Closing spark session')
spk.sparkContext.stop()