In [1]:
import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer

from riken.word2vec import classification_tools

K = 4
RANDOM_STATE = 42

data_path = '/home/pierre/riken/data/riken_data/complete_from_xlsx.tsv'

In [2]:
df = pd.read_csv(data_path, sep='\t').dropna()
df.loc[:, 'seq_len'] = df.sequences.apply(len)
df = df.loc[df.seq_len >= 50, :]

sequences, y = df['sequences'].values, df['is_allergenic'].values
groups = df.species.values

# Feature extraction

In [4]:
#TODO: Test this
def overlapping_tokenizer(seq, k=K):
    """
    'ABCDE' ==> ['ABC', 'BCE', 'CDE'] 
    """
    return [seq[idx:idx+K] for idx in range(len(seq)-k+1)]


vectorizer = TfidfVectorizer(tokenizer=overlapping_tokenizer, max_features=32000, 
                             lowercase=False, ngram_range=(1,4),) 
X = vectorizer.fit_transform(sequences)

# Prediction using SVM

In [5]:
from protein_io import data_op

train_inds, test_inds = data_op.group_shuffle_indices(X, y, groups)
Xtrain, Xtest, ytrain, ytest = X[train_inds], X[test_inds], y[train_inds], y[test_inds]

In [6]:
ytest.sum() / len(ytest)

0.3192452830188679

In [8]:
from sklearn.svm import LinearSVC

clf = LinearSVC(tol=1e-6, max_iter=50000, penalty='l1', C=15.0, 
                loss='squared_hinge', class_weight='balanced', dual=False, random_state=RANDOM_STATE)
clf.fit(Xtrain, ytrain)
ypred = clf.predict(Xtest)
yscore = clf.decision_function(Xtest)

from sklearn.metrics import classification_report, roc_auc_score

print(classification_report(ytest, ypred >= 0.5))
print('ROC AUC SCORE: ', roc_auc_score(ytest, yscore))

             precision    recall  f1-score   support

      False       0.78      0.93      0.85       902
       True       0.74      0.45      0.56       423

avg / total       0.77      0.78      0.76      1325

ROC AUC SCORE:  0.8052135260230745


In [36]:
study_df = (pd.DataFrame({'coef': clf.coef_.reshape(-1), 'name': [pat for pat in vectorizer.vocabulary_.keys()]})
.assign(absolute_val=lambda x: x.coef.abs()))

study_df.sort_values('absolute_val', ascending=False)[['coef', 'name']]

Unnamed: 0,coef,name
6096,9.470084,QAIH
9607,9.283566,RSSP
13096,9.117230,GLKI
7809,8.603098,TVLD
9008,8.244090,SFGA
19579,8.085166,AGGS
10697,7.685303,SEED
14709,7.505808,ELDD
27460,-7.493306,QTAT
27379,7.105758,ETYK


In [44]:
study_df[study_df.absolute_val!=0].sample(15)

Unnamed: 0,coef,name,absolute_val
17733,0.923072,RSVY,0.923072
17869,-0.96876,KTHL,0.96876
1343,-0.190381,SPVR,0.190381
21684,-1.69756,EDVI,1.69756
8553,0.923886,EREL,0.923886
15977,-1.377065,VDSN,1.377065
3273,0.927878,SLFA,0.927878
19807,-2.3861,DAEG,2.3861
1635,4.592474,LFES,4.592474
1477,-0.006801,IEIK,0.006801


# Results study

In [46]:
df.loc[:, 'prediction'] = clf.predict(X)
mask_test = np.zeros(len(df), dtype=bool)
mask_test[test_inds] = True
df.loc[:, 'is_test'] = mask_test

gped = df.loc[df.is_test].groupby(['genre', 'species'])
counts = gped.is_allergenic.count()
specific_accuracy = gped.apply(lambda x: (x.is_allergenic==x.prediction).sum()/len(x))
specific_precision = gped.apply(lambda x: ((x.prediction==True) & x.is_allergenic).sum()/len(x[x.prediction==True]))

specific_df = pd.DataFrame({'count': counts, 'accuracy': specific_accuracy, 'precision': specific_precision})
specific_df  # [specific_df['count']>=10]

  # This is added back by InteractiveShellApp.init_path()


Unnamed: 0_level_0,Unnamed: 1_level_0,accuracy,count,precision
genre,species,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
actinidia,deliciosa,0.742857,35,0.681818
alnus,glutinosa,1.000000,2,1.000000
alternaria,alternata,0.368421,19,1.000000
ambrosia,trifida,0.000000,1,
ananas,comosus,0.500000,2,1.000000
anisakis,pegreffii,1.000000,4,1.000000
aquareovirus,a,1.000000,3,
argas,reflexus,0.000000,2,
artemia,franciscana,0.960000,25,0.000000
artemia,salina,1.000000,19,


In [48]:
precision = lambda x: ((x.prediction==True) & x.is_allergenic).sum()/len(x[x.prediction])
precision(df.loc[df.is_test])

0.7431906614785992

# Dimension reduction

In [66]:
from sklearn.decomposition import TruncatedSVD

svd = TruncatedSVD(n_components=1000, random_state=RANDOM_STATE, n_iter=10)
svd.fit(Xtrain)

TruncatedSVD(algorithm='randomized', n_components=1000, n_iter=10,
       random_state=42, tol=0.0)

In [67]:
svd.explained_variance_ratio_.cumsum()

array([0.00096352, 0.00651107, 0.0115898 , 0.01534664, 0.01812787,
       0.02047562, 0.02281182, 0.02494843, 0.02702831, 0.02905577,
       0.0310529 , 0.03300765, 0.03482276, 0.03656517, 0.03824042,
       0.03968834, 0.04110445, 0.04248367, 0.04379914, 0.04510358,
       0.04639908, 0.04764358, 0.04884808, 0.05003206, 0.0511772 ,
       0.05230598, 0.05341584, 0.05449862, 0.05555849, 0.05657269,
       0.05756687, 0.0585453 , 0.05948794, 0.06040185, 0.06130921,
       0.06219831, 0.06307711, 0.06395219, 0.06481176, 0.06566332,
       0.06650143, 0.06732711, 0.06815176, 0.06896465, 0.06976225,
       0.07055236, 0.07133892, 0.07212233, 0.07289678, 0.07365496,
       0.0744065 , 0.07514708, 0.07587935, 0.07659837, 0.07731393,
       0.07801962, 0.07872187, 0.07942279, 0.08011282, 0.08079607,
       0.08147837, 0.08215131, 0.08281846, 0.08347398, 0.08412777,
       0.08477802, 0.08541688, 0.08605448, 0.08668483, 0.08730615,
       0.08792503, 0.08853798, 0.08914865, 0.08974738, 0.09034

In [68]:
Xtrain_red = svd.transform(Xtrain)
Xtest_red = svd.transform(Xtest)

In [69]:
import catboost

clf = catboost.CatBoostClassifier()
clf.fit(Xtrain_red, ytrain, eval_set=(Xtest_red, ytest))

0:	learn: 0.6685673	test: 0.6817719	best: 0.6817719 (0)	total: 56.8ms	remaining: 56.8s
1:	learn: 0.6451478	test: 0.6726415	best: 0.6726415 (1)	total: 117ms	remaining: 58.4s
2:	learn: 0.6248550	test: 0.6641524	best: 0.6641524 (2)	total: 178ms	remaining: 59.1s
3:	learn: 0.6038690	test: 0.6546043	best: 0.6546043 (3)	total: 242ms	remaining: 1m
4:	learn: 0.5848593	test: 0.6459055	best: 0.6459055 (4)	total: 303ms	remaining: 1m
5:	learn: 0.5663223	test: 0.6383831	best: 0.6383831 (5)	total: 364ms	remaining: 1m
6:	learn: 0.5498369	test: 0.6317474	best: 0.6317474 (6)	total: 424ms	remaining: 1m
7:	learn: 0.5337085	test: 0.6264910	best: 0.6264910 (7)	total: 483ms	remaining: 60s
8:	learn: 0.5184445	test: 0.6217632	best: 0.6217632 (8)	total: 547ms	remaining: 1m
9:	learn: 0.5028386	test: 0.6156711	best: 0.6156711 (9)	total: 607ms	remaining: 1m
10:	learn: 0.4887266	test: 0.6108401	best: 0.6108401 (10)	total: 658ms	remaining: 59.2s
11:	learn: 0.4751364	test: 0.6070544	best: 0.6070544 (11)	total: 717ms	

96:	learn: 0.1990982	test: 0.6059427	best: 0.5779421 (29)	total: 5.8s	remaining: 54s
97:	learn: 0.1983578	test: 0.6061671	best: 0.5779421 (29)	total: 5.86s	remaining: 53.9s
98:	learn: 0.1976419	test: 0.6065969	best: 0.5779421 (29)	total: 5.91s	remaining: 53.8s
99:	learn: 0.1968416	test: 0.6070717	best: 0.5779421 (29)	total: 5.95s	remaining: 53.6s
100:	learn: 0.1959215	test: 0.6081388	best: 0.5779421 (29)	total: 6.01s	remaining: 53.5s
101:	learn: 0.1948561	test: 0.6075503	best: 0.5779421 (29)	total: 6.07s	remaining: 53.4s
102:	learn: 0.1940400	test: 0.6078237	best: 0.5779421 (29)	total: 6.12s	remaining: 53.3s
103:	learn: 0.1935175	test: 0.6080995	best: 0.5779421 (29)	total: 6.18s	remaining: 53.2s
104:	learn: 0.1925923	test: 0.6074743	best: 0.5779421 (29)	total: 6.24s	remaining: 53.2s
105:	learn: 0.1916781	test: 0.6078705	best: 0.5779421 (29)	total: 6.3s	remaining: 53.1s
106:	learn: 0.1911440	test: 0.6082450	best: 0.5779421 (29)	total: 6.36s	remaining: 53.1s
107:	learn: 0.1905217	test: 0

192:	learn: 0.1494630	test: 0.6058283	best: 0.5779421 (29)	total: 11.5s	remaining: 48.2s
193:	learn: 0.1491503	test: 0.6058904	best: 0.5779421 (29)	total: 11.6s	remaining: 48.2s
194:	learn: 0.1489389	test: 0.6058753	best: 0.5779421 (29)	total: 11.7s	remaining: 48.1s
195:	learn: 0.1485614	test: 0.6054124	best: 0.5779421 (29)	total: 11.7s	remaining: 48.1s
196:	learn: 0.1481885	test: 0.6053729	best: 0.5779421 (29)	total: 11.8s	remaining: 48s
197:	learn: 0.1478524	test: 0.6053777	best: 0.5779421 (29)	total: 11.8s	remaining: 48s
198:	learn: 0.1474850	test: 0.6057945	best: 0.5779421 (29)	total: 11.9s	remaining: 47.9s
199:	learn: 0.1472953	test: 0.6060344	best: 0.5779421 (29)	total: 12s	remaining: 47.9s
200:	learn: 0.1469447	test: 0.6055275	best: 0.5779421 (29)	total: 12s	remaining: 47.9s
201:	learn: 0.1466458	test: 0.6054995	best: 0.5779421 (29)	total: 12.1s	remaining: 47.9s
202:	learn: 0.1464078	test: 0.6054764	best: 0.5779421 (29)	total: 12.2s	remaining: 47.8s
203:	learn: 0.1460260	test: 0

286:	learn: 0.1249767	test: 0.6022937	best: 0.5779421 (29)	total: 17s	remaining: 42.2s
287:	learn: 0.1247511	test: 0.6023273	best: 0.5779421 (29)	total: 17.1s	remaining: 42.2s
288:	learn: 0.1245091	test: 0.6023823	best: 0.5779421 (29)	total: 17.1s	remaining: 42.1s
289:	learn: 0.1243197	test: 0.6021333	best: 0.5779421 (29)	total: 17.2s	remaining: 42s
290:	learn: 0.1241246	test: 0.6019091	best: 0.5779421 (29)	total: 17.2s	remaining: 42s
291:	learn: 0.1239202	test: 0.6016729	best: 0.5779421 (29)	total: 17.3s	remaining: 41.9s
292:	learn: 0.1236294	test: 0.6019311	best: 0.5779421 (29)	total: 17.3s	remaining: 41.8s
293:	learn: 0.1234670	test: 0.6018695	best: 0.5779421 (29)	total: 17.4s	remaining: 41.8s
294:	learn: 0.1233320	test: 0.6021941	best: 0.5779421 (29)	total: 17.4s	remaining: 41.7s
295:	learn: 0.1231704	test: 0.6021963	best: 0.5779421 (29)	total: 17.5s	remaining: 41.6s
296:	learn: 0.1230152	test: 0.6021923	best: 0.5779421 (29)	total: 17.5s	remaining: 41.5s
297:	learn: 0.1228142	test:

382:	learn: 0.1081663	test: 0.5997331	best: 0.5779421 (29)	total: 22.4s	remaining: 36.1s
383:	learn: 0.1080144	test: 0.6001544	best: 0.5779421 (29)	total: 22.5s	remaining: 36s
384:	learn: 0.1078884	test: 0.5999713	best: 0.5779421 (29)	total: 22.5s	remaining: 36s
385:	learn: 0.1077874	test: 0.5999647	best: 0.5779421 (29)	total: 22.6s	remaining: 35.9s
386:	learn: 0.1076014	test: 0.5995839	best: 0.5779421 (29)	total: 22.6s	remaining: 35.8s
387:	learn: 0.1074729	test: 0.5996697	best: 0.5779421 (29)	total: 22.7s	remaining: 35.8s
388:	learn: 0.1073572	test: 0.5995958	best: 0.5779421 (29)	total: 22.7s	remaining: 35.7s
389:	learn: 0.1071989	test: 0.5992481	best: 0.5779421 (29)	total: 22.8s	remaining: 35.6s
390:	learn: 0.1071239	test: 0.5992291	best: 0.5779421 (29)	total: 22.8s	remaining: 35.6s
391:	learn: 0.1069914	test: 0.5990272	best: 0.5779421 (29)	total: 22.9s	remaining: 35.5s
392:	learn: 0.1067996	test: 0.5991912	best: 0.5779421 (29)	total: 23s	remaining: 35.5s
393:	learn: 0.1066683	test:

478:	learn: 0.0953660	test: 0.5980265	best: 0.5779421 (29)	total: 27.7s	remaining: 30.1s
479:	learn: 0.0952328	test: 0.5981613	best: 0.5779421 (29)	total: 27.8s	remaining: 30.1s
480:	learn: 0.0951336	test: 0.5980361	best: 0.5779421 (29)	total: 27.8s	remaining: 30s
481:	learn: 0.0950328	test: 0.5978486	best: 0.5779421 (29)	total: 27.9s	remaining: 30s
482:	learn: 0.0948169	test: 0.5980914	best: 0.5779421 (29)	total: 27.9s	remaining: 29.9s
483:	learn: 0.0946948	test: 0.5980632	best: 0.5779421 (29)	total: 28s	remaining: 29.8s
484:	learn: 0.0945901	test: 0.5980678	best: 0.5779421 (29)	total: 28s	remaining: 29.8s
485:	learn: 0.0945193	test: 0.5980842	best: 0.5779421 (29)	total: 28.1s	remaining: 29.7s
486:	learn: 0.0943742	test: 0.5980279	best: 0.5779421 (29)	total: 28.2s	remaining: 29.7s
487:	learn: 0.0942533	test: 0.5979429	best: 0.5779421 (29)	total: 28.2s	remaining: 29.6s
488:	learn: 0.0941535	test: 0.5977972	best: 0.5779421 (29)	total: 28.3s	remaining: 29.5s
489:	learn: 0.0939937	test: 0

574:	learn: 0.0849887	test: 0.5944156	best: 0.5779421 (29)	total: 33.1s	remaining: 24.4s
575:	learn: 0.0849120	test: 0.5945901	best: 0.5779421 (29)	total: 33.1s	remaining: 24.4s
576:	learn: 0.0848439	test: 0.5945422	best: 0.5779421 (29)	total: 33.2s	remaining: 24.3s
577:	learn: 0.0847057	test: 0.5945324	best: 0.5779421 (29)	total: 33.2s	remaining: 24.3s
578:	learn: 0.0846669	test: 0.5943594	best: 0.5779421 (29)	total: 33.3s	remaining: 24.2s
579:	learn: 0.0845505	test: 0.5945301	best: 0.5779421 (29)	total: 33.3s	remaining: 24.1s
580:	learn: 0.0844341	test: 0.5946264	best: 0.5779421 (29)	total: 33.4s	remaining: 24.1s
581:	learn: 0.0843605	test: 0.5946172	best: 0.5779421 (29)	total: 33.4s	remaining: 24s
582:	learn: 0.0843105	test: 0.5946821	best: 0.5779421 (29)	total: 33.5s	remaining: 24s
583:	learn: 0.0841843	test: 0.5945001	best: 0.5779421 (29)	total: 33.6s	remaining: 23.9s
584:	learn: 0.0840361	test: 0.5944701	best: 0.5779421 (29)	total: 33.6s	remaining: 23.8s
585:	learn: 0.0838763	tes

669:	learn: 0.0766115	test: 0.5936449	best: 0.5779421 (29)	total: 38.4s	remaining: 18.9s
670:	learn: 0.0765765	test: 0.5935814	best: 0.5779421 (29)	total: 38.4s	remaining: 18.8s
671:	learn: 0.0765196	test: 0.5937060	best: 0.5779421 (29)	total: 38.5s	remaining: 18.8s
672:	learn: 0.0764534	test: 0.5938824	best: 0.5779421 (29)	total: 38.5s	remaining: 18.7s
673:	learn: 0.0763595	test: 0.5938512	best: 0.5779421 (29)	total: 38.6s	remaining: 18.7s
674:	learn: 0.0762793	test: 0.5937480	best: 0.5779421 (29)	total: 38.7s	remaining: 18.6s
675:	learn: 0.0762021	test: 0.5938785	best: 0.5779421 (29)	total: 38.7s	remaining: 18.6s
676:	learn: 0.0761431	test: 0.5939938	best: 0.5779421 (29)	total: 38.8s	remaining: 18.5s
677:	learn: 0.0760690	test: 0.5939888	best: 0.5779421 (29)	total: 38.8s	remaining: 18.4s
678:	learn: 0.0759571	test: 0.5942186	best: 0.5779421 (29)	total: 38.9s	remaining: 18.4s
679:	learn: 0.0759059	test: 0.5941815	best: 0.5779421 (29)	total: 39s	remaining: 18.3s
680:	learn: 0.0758222	t

765:	learn: 0.0698460	test: 0.5935436	best: 0.5779421 (29)	total: 43.8s	remaining: 13.4s
766:	learn: 0.0697942	test: 0.5931887	best: 0.5779421 (29)	total: 43.9s	remaining: 13.3s
767:	learn: 0.0697253	test: 0.5930522	best: 0.5779421 (29)	total: 43.9s	remaining: 13.3s
768:	learn: 0.0696227	test: 0.5931760	best: 0.5779421 (29)	total: 44s	remaining: 13.2s
769:	learn: 0.0695518	test: 0.5932757	best: 0.5779421 (29)	total: 44.1s	remaining: 13.2s
770:	learn: 0.0695064	test: 0.5932340	best: 0.5779421 (29)	total: 44.1s	remaining: 13.1s
771:	learn: 0.0694276	test: 0.5932130	best: 0.5779421 (29)	total: 44.2s	remaining: 13s
772:	learn: 0.0693880	test: 0.5933724	best: 0.5779421 (29)	total: 44.2s	remaining: 13s
773:	learn: 0.0693372	test: 0.5934081	best: 0.5779421 (29)	total: 44.3s	remaining: 12.9s
774:	learn: 0.0692875	test: 0.5933174	best: 0.5779421 (29)	total: 44.3s	remaining: 12.9s
775:	learn: 0.0692361	test: 0.5930996	best: 0.5779421 (29)	total: 44.4s	remaining: 12.8s
776:	learn: 0.0691838	test:

861:	learn: 0.0643789	test: 0.5925470	best: 0.5779421 (29)	total: 49.2s	remaining: 7.87s
862:	learn: 0.0643118	test: 0.5924619	best: 0.5779421 (29)	total: 49.2s	remaining: 7.81s
863:	learn: 0.0642631	test: 0.5922779	best: 0.5779421 (29)	total: 49.3s	remaining: 7.75s
864:	learn: 0.0641915	test: 0.5921111	best: 0.5779421 (29)	total: 49.3s	remaining: 7.7s
865:	learn: 0.0641580	test: 0.5922778	best: 0.5779421 (29)	total: 49.4s	remaining: 7.64s
866:	learn: 0.0641272	test: 0.5921882	best: 0.5779421 (29)	total: 49.4s	remaining: 7.58s
867:	learn: 0.0640708	test: 0.5921759	best: 0.5779421 (29)	total: 49.5s	remaining: 7.53s
868:	learn: 0.0640064	test: 0.5922882	best: 0.5779421 (29)	total: 49.6s	remaining: 7.47s
869:	learn: 0.0639729	test: 0.5923003	best: 0.5779421 (29)	total: 49.6s	remaining: 7.41s
870:	learn: 0.0639262	test: 0.5921853	best: 0.5779421 (29)	total: 49.7s	remaining: 7.36s
871:	learn: 0.0638041	test: 0.5920322	best: 0.5779421 (29)	total: 49.7s	remaining: 7.3s
872:	learn: 0.0637892	t

957:	learn: 0.0598960	test: 0.5910787	best: 0.5779421 (29)	total: 54.5s	remaining: 2.39s
958:	learn: 0.0598334	test: 0.5912349	best: 0.5779421 (29)	total: 54.5s	remaining: 2.33s
959:	learn: 0.0598076	test: 0.5911891	best: 0.5779421 (29)	total: 54.6s	remaining: 2.27s
960:	learn: 0.0597267	test: 0.5908282	best: 0.5779421 (29)	total: 54.6s	remaining: 2.22s
961:	learn: 0.0596662	test: 0.5910073	best: 0.5779421 (29)	total: 54.7s	remaining: 2.16s
962:	learn: 0.0596139	test: 0.5909747	best: 0.5779421 (29)	total: 54.8s	remaining: 2.1s
963:	learn: 0.0595785	test: 0.5910391	best: 0.5779421 (29)	total: 54.8s	remaining: 2.05s
964:	learn: 0.0595108	test: 0.5910701	best: 0.5779421 (29)	total: 54.9s	remaining: 1.99s
965:	learn: 0.0594615	test: 0.5910672	best: 0.5779421 (29)	total: 54.9s	remaining: 1.93s
966:	learn: 0.0594215	test: 0.5909776	best: 0.5779421 (29)	total: 55s	remaining: 1.88s
967:	learn: 0.0593882	test: 0.5909718	best: 0.5779421 (29)	total: 55s	remaining: 1.82s
968:	learn: 0.0593267	test

<catboost.core.CatBoostClassifier at 0x7f0284f598d0>

In [70]:
ypred = clf.predict_proba(Xtest_red)[:, 1]
print(classification_report(ytest, ypred >= 0.5))
print('ROC AUC SCORE: ', roc_auc_score(ytest, ypred))

             precision    recall  f1-score   support

      False       0.70      1.00      0.82       902
       True       0.90      0.08      0.15       423

avg / total       0.76      0.70      0.61      1325

ROC AUC SCORE:  0.742239467849224
