In [2]:
from preproc_utils import *
from Get_PSSM import *
from Get_dataset import *
import os
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, roc_auc_score,f1_score

# NOTE(review): absolute, user-specific location; point DATA_DIR at your local copy.
DATA_DIR = '/Users/suhancho/data/Uniprot_metalbinding_challenge/'

# NOTE(review): `chebi` is not used in the cells shown here.
chebi = pd.read_table(os.path.join(DATA_DIR, 'POS_TRAIN_FULL.tsv'))
inpath = os.path.join(DATA_DIR, 'chebi/')
# sorted(): os.listdir order is arbitrary, which would make file order (and any
# pssm_files[0:n] subset) machine-dependent. Hidden entries such as macOS
# '.DS_Store' are skipped so they are never parsed as data.
bind_tsv_list = [os.path.join(inpath, f) for f in sorted(os.listdir(inpath)) if not f.startswith('.')]
pssm_path = os.path.join(DATA_DIR, 'PSSM/')
pssm_files = [os.path.join(pssm_path, f) for f in sorted(os.listdir(pssm_path)) if not f.startswith('.')]
bindlist = pd.concat([pd.read_table(f) for f in bind_tsv_list])
# Names with fewer than 1000 rows, with spaces stripped.
name_counts = bindlist.Name.value_counts()
low_labels = [l.replace(' ', '') for l in name_counts.index[name_counts < 1000]]

In [3]:
# Encode each binding-partner name as a categorical; its integer codes
# become the classification target.
name_categorical = pd.Categorical(bindlist['Name'])
bindlist['Name'] = name_categorical
bindlist['Target'] = name_categorical.codes

In [4]:
def calculate_window(num_inspections, bs_df, pssm_list=None):
    """Build windowed samples from the first ``num_inspections`` PSSM files.

    For every selected PSSM file, the protein identifier (file name up to its
    first '.') is used to look up binding sites in ``bs_df`` via
    ``get_binding_site_multi``. ``get_dataset_padded_multi`` then turns the
    processed PSSM plus those sites into windows.

    Parameters
    ----------
    num_inspections : int
        Maximum number of PSSM files to process (``pssm_list[:num_inspections]``).
    bs_df : pandas.DataFrame
        Binding-site table forwarded to ``get_binding_site_multi``.
    pssm_list : list of str, optional
        PSSM file paths to draw from. Defaults to the module-level ``pssm_files``.

    Returns
    -------
    list
        ``[window_values, label]`` pairs, where ``window_values`` is
        ``g[0].values.tolist()`` and ``label`` is ``g[1]`` for each ``g``
        produced by ``get_dataset_padded_multi``. Presumably ``g[0]`` is a
        DataFrame slice of PSSM rows (TODO confirm).
    """
    if pssm_list is None:
        pssm_list = pssm_files
    selected = pssm_list[:num_inspections]
    train_dat = []
    # Iterating the list directly (no enumerate) lets tqdm show a total.
    for pssm in tqdm(selected):
        protein_id = os.path.basename(pssm).split('.')[0]
        bs = get_binding_site_multi(bs_df, protein_id)
        gt = get_dataset_padded_multi(get_processed_pssm(pssm), bs)
        for g in gt:
            train_dat.append([g[0].values.tolist(), g[1]])
    return train_dat

In [5]:
def preproc_data(windowdata):
    """Split ``[features, label]`` pairs into parallel feature and label lists.

    Prints the number of samples and returns ``(train_X, train_Y)``.
    """
    def column(idx):
        return [row[idx] for row in windowdata]

    features = column(0)
    labels = column(1)
    count = len(features)
    print('Size of dataset : ' + str(count))
    return (features, labels)

In [6]:
import seaborn as sns
import matplotlib.pyplot as plt
def check_windowdata(traindata):
    """Plot and show a histogram of the length of each sample in ``traindata``.

    Parameters
    ----------
    traindata : sequence of sequences
        Windowed samples. Only ``len()`` of each item is used.
    """
    fig, ax = plt.subplots()
    sns.histplot([len(t) for t in traindata], ax=ax)
    ax.set(title='Window length distribution', xlabel='Window length', ylabel='Count')
    plt.show()

In [7]:
def filter_traindata(Xdata, Ydata, window_size=9):
    """Keep only samples whose window has exactly ``window_size`` entries.

    Parameters
    ----------
    Xdata : sequence of sequences
        Windowed feature samples.
    Ydata : sequence
        Labels aligned one-to-one with ``Xdata``.
    window_size : int, default 9
        Required ``len()`` of a kept sample (the original hard-coded value).

    Returns
    -------
    tuple of (list, list)
        The filtered features and their matching labels, in original order.

    Raises
    ------
    ValueError
        If ``Xdata`` and ``Ydata`` differ in length. A mismatch means the
        labels are misaligned with the features.
    """
    if len(Xdata) != len(Ydata):
        raise ValueError('Xdata and Ydata must have the same length '
                         '(got %d and %d)' % (len(Xdata), len(Ydata)))
    # Single pass: evaluate the length predicate once per sample.
    kept = [(x, y) for x, y in zip(Xdata, Ydata) if len(x) == window_size]
    train_X_filtered = [x for x, _ in kept]
    train_Y_filtered = [y for _, y in kept]
    return (train_X_filtered, train_Y_filtered)

In [8]:
from itertools import chain
def flatten_Xdata(filtered_X):
    """Concatenate each sample's rows into a single flat feature list.

    ``[[[a, b], [c]], ...]`` becomes ``[[a, b, c], ...]``.
    """
    flattened = []
    for sample in filtered_X:
        row_vector = []
        for row in sample:
            row_vector.extend(row)
        flattened.append(row_vector)
    return flattened

In [9]:
def _roc_auc(testset_y, testset_X, classifier, pred_y, average):
    """Return ROC-AUC, preferring continuous scores over hard predictions."""
    if average == 'binary':
        # ROC-AUC needs a ranking score; hard labels give a single-threshold AUC.
        if hasattr(classifier, 'decision_function'):
            return roc_auc_score(testset_y, classifier.decision_function(testset_X))
        if hasattr(classifier, 'predict_proba'):
            return roc_auc_score(testset_y, classifier.predict_proba(testset_X)[:, 1])
        # Last resort (original behaviour): AUC of the hard predictions.
        return roc_auc_score(testset_y, pred_y)
    # Multi-class ROC-AUC requires per-class probabilities.
    if hasattr(classifier, 'predict_proba'):
        return roc_auc_score(testset_y, classifier.predict_proba(testset_X),
                             average=average, multi_class='ovr')
    return float('nan')


def get_MLmetrics(testset_y, testset_X, classifier, ion_name, average='binary'):
    """Evaluate a fitted classifier on a held-out set and print a short report.

    Parameters
    ----------
    testset_y : sequence
        True labels of the test set.
    testset_X : sequence
        Test-set feature vectors in a form ``classifier.predict`` accepts.
    classifier : fitted estimator
        Object exposing ``predict`` (e.g. a fitted ``GridSearchCV``). If it also
        exposes ``decision_function`` or ``predict_proba``, those scores are
        used for ROC-AUC.
    ion_name : str
        Label printed at the top of the report.
    average : str, default 'binary'
        Averaging mode forwarded to precision/recall/F1. Use e.g. 'macro' for
        multi-class labels. For non-binary modes, AUC is one-vs-rest over
        ``predict_proba`` when available, else NaN.

    Returns
    -------
    tuple
        ``(auc, acc, recall, f1, prec)`` in exactly that order.
    """
    # Predict once and reuse; the original re-ran predict() for every metric.
    pred_y = list(classifier.predict(testset_X))
    auc = _roc_auc(testset_y, testset_X, classifier, pred_y, average)
    acc = accuracy_score(testset_y, pred_y)
    recall = recall_score(testset_y, pred_y, average=average)
    f1 = f1_score(testset_y, pred_y, average=average)
    prec = precision_score(testset_y, pred_y, average=average)
    print('ION = '+ion_name)
    print('\nAUC = '+str(round(auc,2)))
    print('\nAccuracy = '+str(round(acc,2)))
    print('\nRecall = '+str(round(recall,2)))
    print('\nF1 = '+str(round(f1,2)))
    print('Precision = '+str(round(prec,2)))
    return (auc, acc, recall, f1, prec)

In [10]:
def balance_classes(traindata, fold, random_state=None):
    """Undersample label-0 samples to ``fold`` per label-1 sample.

    Parameters
    ----------
    traindata : list
        ``[features, label]`` pairs. Items whose label is neither 0 nor 1 are
        dropped.
    fold : int
        Number of label-0 samples drawn per label-1 sample.
    random_state : int, optional
        Seed for a private ``random.Random``, for reproducible sampling. If
        None, the global ``random`` module state is used (original behaviour).

    Returns
    -------
    list
        The sampled label-0 items followed by all label-1 items, in their
        original order. Not shuffled.

    Raises
    ------
    ValueError
        From ``random.sample`` when fewer than ``len(label1) * fold`` label-0
        items exist.
    """
    # `random` is not imported anywhere visible in the notebook; make it explicit.
    import random

    label1, label0 = [], []
    for sample in traindata:
        if sample[1] == 1:
            label1.append(sample)
        elif sample[1] == 0:
            label0.append(sample)
    rng = random if random_state is None else random.Random(random_state)
    balanced0 = rng.sample(label0, len(label1) * fold)
    return balanced0 + label1

In [11]:
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler


In [15]:
# Number of PSSM files to window (renamed from `iter`, which shadowed the builtin).
num_pssm_files = 10000
# Shared seed for the train/test split and the SVC, so reruns are reproducible.
SEED = 9510

train_dat = calculate_window(num_pssm_files, bindlist)
train_X, train_Y = preproc_data(train_dat)
check_windowdata(train_X)
train_X, train_Y = filter_traindata(train_X, train_Y)
flatten_trainX = flatten_Xdata(train_X)
# TODO(review): stratify=train_Y was disabled originally; re-enable if every
# class has at least 2 samples, to keep class ratios equal across splits.
trainX, testX, trainy, testy = train_test_split(
    flatten_trainX, train_Y, test_size=0.4, shuffle=True, random_state=SEED)

pipe_svc = make_pipeline(StandardScaler(), SVC(random_state=SEED))

param_range = [0.0001, 0.001, 0.01, 0.1, 1.0, 10.0, 100.0, 1000.0]

param_grid = [{'svc__C': param_range,
               'svc__kernel': ['linear']},
              {'svc__C': param_range,
               'svc__gamma': param_range,
               'svc__kernel': ['rbf']}]

gs = GridSearchCV(estimator=pipe_svc,
                  param_grid=param_grid,
                  scoring='accuracy',
                  cv=3,
                  n_jobs=-1)
gs = gs.fit(trainX, trainy)
# get_MLmetrics returns (auc, acc, recall, f1, prec); the original unpacked
# f1 and recall in swapped positions.
auc, acc, recall, f1, prec = get_MLmetrics(testy, testX, gs, 'multi')