In [1]:
import pandas as pd
import numpy as np

from sklearn import metrics
from sklearn import preprocessing
from sklearn.manifold import TSNE

import matplotlib.pyplot as plt
import seaborn as sns

from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

In [40]:
def pred2cls_df(df, col=None):
    """Return the argmax class index of each row's comma-separated score string.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with one prediction string (e.g. ``"0.1,0.7,0.1,0.1"``) per row.
    col : str, optional
        Column to parse. Defaults to ``'Pred_avg'`` when that column exists,
        otherwise ``'Pred'`` (the raw per-row column that ``avg_res`` reads).

    Returns
    -------
    numpy.ndarray
        Integer class index (position of the highest score) for every row.

    NOTE(review): the section-1 frame is read straight from CSV and is never
    passed through ``average_results_on_enantiomers``; if those CSVs have no
    ``Pred_avg`` column, the ``'Pred'`` fallback is what lets that section
    run on a fresh kernel.
    """
    if col is None:
        col = 'Pred_avg' if 'Pred_avg' in df.columns else 'Pred'
    scores = np.array([np.array(s.split(','), dtype=float) for s in df[col].tolist()])
    return np.argmax(scores, axis=1)

def pred2prob_df(df, col=None):
    """Parse each row's comma-separated score string into a float array.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with one prediction string per row.
    col : str, optional
        Column to parse. Defaults to ``'Pred_avg'`` when that column exists,
        otherwise ``'Pred'`` (the raw per-row column that ``avg_res`` reads).

    Returns
    -------
    numpy.ndarray
        For a non-empty frame whose strings all have the same number of
        entries, a float64 array of shape (n_rows, n_scores).

    NOTE(review): the section-1 frame is never averaged; the ``'Pred'``
    fallback lets the metric helpers run on it when ``Pred_avg`` is absent.
    """
    if col is None:
        col = 'Pred_avg' if 'Pred_avg' in df.columns else 'Pred'
    return np.array([np.array(s.split(','), dtype=float) for s in df[col].tolist()])

def group_f1(df, average='micro'):
    """F1 score of the argmax predictions in ``df`` against its ``Class`` column.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows with an integer-castable ``Class`` column and a prediction
        column that ``pred2cls_df`` can parse.
    average : str, default 'micro'
        Forwarded to ``sklearn.metrics.f1_score``. For single-label
        multiclass data, 'micro' F1 is identical to accuracy, so the default
        duplicates ``group_acc``. Pass 'macro' or 'weighted' for an F1 that
        reflects per-class performance.

    Returns
    -------
    float
    """
    labels = np.array(df['Class'].tolist(), dtype=int)
    predicted = pred2cls_df(df)
    return metrics.f1_score(labels, predicted, average=average)

def group_kappa(df):
    """Cohen's kappa between the ``Class`` labels of ``df`` and its argmax predictions."""
    return metrics.cohen_kappa_score(
        np.array(df['Class'].tolist(), dtype=int),
        pred2cls_df(df),
    )

def group_acc(df):
    """Fraction of rows in ``df`` whose argmax prediction equals ``Class``."""
    y_true = np.array(df['Class'].tolist(), dtype=int)
    y_hat = pred2cls_df(df)
    return metrics.accuracy_score(y_true, y_hat)

def group_auc(df, average='macro', labels=None):
    """One-vs-rest ROC AUC of the class-probability strings in ``df``.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows with an integer-castable ``Class`` column and a prediction
        column that ``pred2prob_df`` can parse.
    average : str, default 'macro'
        How per-class OvR AUCs are combined ('macro' or 'weighted');
        forwarded to ``sklearn.metrics.roc_auc_score``.
    labels : sequence of int, optional
        Class label for each probability column, in column order. Defaults to
        ``0 .. n_columns - 1``, i.e. ``[0, 1, 2, 3]`` for the 4-column
        predictions in this notebook, instead of a hard-coded list.

    Returns
    -------
    float
    """
    y_true = np.array(df['Class'].tolist(), dtype=int)
    y_prob = pred2prob_df(df)
    if labels is None:
        labels = list(range(y_prob.shape[1]))
    # With multi_class='ovr', sklearn requires every row of y_prob to be a
    # probability distribution (rows summing to ~1).
    return metrics.roc_auc_score(y_true, y_prob, multi_class='ovr',
                                 average=average, labels=labels)

## 1. Independent training

In [3]:
# Collect the per-fold result files for every csp{0..17} setting
# (outer loop: setting, inner loop: fold) and stack them into one frame.
dfs = [
    pd.read_csv(
        '../results1203/molnet_chirality_cls_etkdg_csp{}-5fold_{}.csv'.format(mb, fold),
        sep='\t',
        index_col=0,
    )
    for mb in range(18)
    for fold in range(5)
]
df = pd.concat(dfs, ignore_index=True)

In [4]:
# Score each MB group separately; the same grouping is reused for all metrics.
by_mb = df.groupby('MB')
auc = by_mb.apply(group_auc)
acc = by_mb.apply(group_acc)
kappa = by_mb.apply(group_kappa)
f1 = by_mb.apply(group_f1)

In [5]:
# One value per line so each column can be pasted straight into a spreadsheet.
for name, series in [('AUC', auc), ('ACC', acc), ('KAPPA', kappa), ('F1', f1)]:
    print(name + ':', '\n' + '\n'.join(series.astype(str).tolist()), '\n')

AUC: 
0.9466460938141763
0.9632375087795964
0.9330071702351392
0.9615860706611726
0.8894157137118267
0.9597761785509201
0.9255694453206886
0.9209678627083075
0.9485127948586194
0.9263728018881191
0.9138234121358867
0.9037921878432107
0.9553984376445753
0.9478379457888514
0.8998460724951834
0.9047693404329458
0.9036204103858668
0.8997815170842871 

ACC: 
0.8725925925925926
0.8426229508196721
0.8268398268398268
0.8371428571428572
0.7828402366863906
0.8679245283018868
0.842925659472422
0.8112781954887218
0.9023644961804292
0.8510402219140083
0.8669322709163346
0.8773039889958735
0.8586956521739131
0.8660869565217392
0.8019607843137255
0.8834645669291339
0.7985130111524164
0.7837837837837838 

KAPPA: 
0.7781234711350293
0.7751808228013329
0.7456952592532249
0.7717600219659528
0.6429804909385209
0.7945577176525311
0.7118865887893202
0.717877663885417
0.8057184619076038
0.7083922056559607
0.7488574352386864
0.7514447267062125
0.7943815974968195
0.7658299483794534
0.664666448579422
0.67178649

In [6]:
# def process_prob(x): 
#     x = np.array(x.split(','), dtype=float)
#     return x

# def plot_roc_curve(mb_idx, save_fig=False, print_confusion_metrics=False):
#     lb = preprocessing.LabelBinarizer()
#     if mb_idx == 'all':
#         df_tmp = df
#     else: 
#         df_tmp = df[df['MB'] == mb_idx]
        
#     class_oh = lb.fit_transform(df_tmp['Class'])
#     pred_prob = df_tmp['Pred'].apply(process_prob)
#     pred_prob = np.array(pred_prob.tolist())

#     f, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 9), gridspec_kw={'height_ratios': [0.6, 2]})

#     sns.countplot(x=df_tmp["Class"], ax=ax1)

#     for c in range(4): 
#         fpr, tpr, thresh = metrics.roc_curve(class_oh[:, c], pred_prob[:, c])
#         auc = metrics.roc_auc_score(class_oh[:, c], pred_prob[:, c])
#         ax2.plot(fpr, tpr, label="class {} vs the rest (AUC={:.2f})".format(c, auc))

#     ax2.plot([0, 1], [0, 1], "k--", label="chance level (AUC = 0.5)")
#     ax2.axis("square")
#     ax2.set_xlabel("False Positive Rate")
#     ax2.set_ylabel("True Positive Rate")
#     ax2.set_title("One-vs-Rest ROC curves (encoded csp: {})".format(mb_idx))
#     ax2.legend()

#     plt.subplots_adjust(hspace=.3)
#     if save_fig: 
#         plt.savefig('./roc_curve_{}.png'.format(str(mb_idx)), dpi=300, bbox_inches='tight')
#         print('Save!')
#     plt.show()
    
#     # confusion metrics
#     if print_confusion_metrics: 
#         pred = np.argmax(pred_prob, axis=1)
#         print('confusion metrics: \n[[tn, fp], \n[fn, tp]]\n')
#         print(metrics.multilabel_confusion_matrix(df_tmp['Class'].to_numpy(), pred))

In [7]:
# plot_roc_curve('all', save_fig=True)

In [8]:
# plot_roc_curve(2, save_fig=True)

## 2. Transfer learning

In [37]:
import torch
import torch.nn.functional as F

def avg_res(preds, apply_softmax=True):
    """Reduce one group of ``Pred`` strings to a single comma-separated score string.

    Parameters
    ----------
    preds : pandas.DataFrame
        One group of rows whose ``Pred`` column holds comma-separated
        per-class scores; all rows must have the same number of entries.
    apply_softmax : bool, default True
        Pass the mean score vector through a softmax before serialising.
        ``True`` reproduces the ``Pred_avg`` values shown in this notebook.

    Returns
    -------
    str
        Comma-separated float32 scores, or ``''`` for an empty group.

    Notes
    -----
    The mean is taken for every non-empty group, including single-row
    groups, and (by default) softmaxed. A check of the form
    ``len(arr.shape) == 1`` cannot detect "a single configuration", because
    ``np.array`` of a list of rows is always 2-D; it is only true for an
    empty group, which is handled explicitly below.

    NOTE(review): the displayed ``Pred`` values already sum to ~1, so the
    extra softmax pulls every vector toward uniform (a near one-hot
    prediction becomes ~0.475/0.175/0.175/0.175). The argmax is unchanged,
    so ACC/KAPPA/F1 are unaffected, but AUC can shift. Consider
    ``apply_softmax=False`` once ``Pred`` is confirmed to be probabilities.
    """
    rows = [p.split(',') for p in preds.Pred.values]
    if not rows:
        return ''
    scores = np.array(rows, dtype=np.float32)        # (n_rows, n_classes)
    mean_scores = np.average(scores, axis=0)         # stays float32
    if apply_softmax:
        mean_scores = F.softmax(torch.from_numpy(mean_scores), dim=0).numpy()
    return ','.join(mean_scores.astype('str'))

def average_results_on_enantiomers(df):
    """Attach a group-level reduced prediction to every row of ``df``.

    Rows are grouped by identical ``(SMILES, MB)`` pairs. Each group's
    ``Pred`` strings are reduced with ``avg_res``, and the result is merged
    back onto every row of that group as ``Pred_avg``. Each row's own
    prediction is kept as ``Pred``.

    NOTE(review): despite the function name, grouping uses the exact SMILES
    string, so stereo-labelled variants such as ``B[C@@H](C=C)CO`` and
    ``B[C@H](C=C)CO`` fall into different groups. Confirm this is intended.

    Returns
    -------
    pandas.DataFrame
        The merged frame after ``reset_index()``, which also adds an
        ``index`` column.
    """
    grouped = df.groupby(['SMILES', 'MB'])
    # Select only 'Pred' (the one column avg_res reads) so apply() does not
    # operate on the grouping columns, which pandas >= 2.2 deprecates.
    avg_df = grouped[['Pred']].apply(avg_res).to_frame('Pred')
    avg_df = avg_df.merge(df, on=['SMILES', 'MB']).rename(
        columns={'Pred_x': 'Pred_avg', 'Pred_y': 'Pred'})
    return avg_df.reset_index()

In [38]:
# Collect the transfer-learning result files for every csp{0..17} setting
# (outer loop: setting, inner loop: fold), stack them, then attach the
# per-(SMILES, MB) reduced prediction as Pred_avg.
dfs = [
    pd.read_csv(
        '../results0804/molnet_chirality_cls_etkdg_csp{}-5fold_tl_{}.csv'.format(mb, fold),
        sep='\t',
        index_col=0,
    )
    for mb in range(18)
    for fold in range(5)
]
df = pd.concat(dfs, ignore_index=True)

print('Average the results of enantiomers...')
df = average_results_on_enantiomers(df)

Average the results of enantiomers...


In [39]:
# Inspect the merged frame: one row per original prediction, with the
# (SMILES, MB)-group reduction in Pred_avg and the row's own scores in Pred.
df

Unnamed: 0,index,SMILES,MB,Pred_avg,Class,Pred
0,0,B[C@@H](C=C)CO,14,"0.17491437,0.4752549,0.1749144,0.17491636",1,"0.0001083978932001628,0.9996631145477295,0.000..."
1,1,B[C@H](C=C)CO,14,"0.17491248,0.47526157,0.17491233,0.17491361",1,"0.00010372867109254003,0.9996832609176636,0.00..."
2,2,B[PH](C)(CCO)c1ccccc1,8,"0.17495552,0.17494585,0.4751624,0.17493619",2,"0.0002516253152862191,0.00021932336676400155,0..."
3,3,B[PH](C)(CCO)c1ccccc1,8,"0.17495552,0.17494585,0.4751624,0.17493619",2,"0.00026885821716859937,0.00019048835383728147,..."
4,4,B[PH](C)(N(C)P(c1ccccc1)c1ccccc1)C(C)(C)C,14,"0.17487806,0.17487794,0.47536623,0.17487781",0,"1.6194647969314246e-06,1.082708536159771e-06,0..."
...,...,...,...,...,...,...
87485,87485,c1csc([C@H]2COc3ccccc3N2)c1,8,"0.17901778,0.17992991,0.18496746,0.4560849",2,"0.006757269613444805,0.01183957327157259,0.039..."
87486,87486,c1ncn2c1CC1(Cc3cncn3C1)C2,0,"0.17488813,0.17488635,0.47533616,0.17488943",2,"3.1708415917819366e-05,2.1495376131497324e-05,..."
87487,87487,c1ncn2c1CC1(Cc3cncn3C1)C2,0,"0.17488813,0.17488635,0.47533616,0.17488943",2,"3.190322604496032e-05,2.159194991691038e-05,0...."
87488,87488,c1onc2c1CCCC21CCCc2conc21,11,"0.17488635,0.174887,0.47534245,0.17488424",2,"2.7015395971830003e-05,3.0077526389504783e-05,..."


In [41]:
# Score each MB group separately; the same grouping is reused for all metrics.
by_mb = df.groupby('MB')
auc = by_mb.apply(group_auc)
acc = by_mb.apply(group_acc)
kappa = by_mb.apply(group_kappa)
f1 = by_mb.apply(group_f1)

In [42]:
# One value per line so each column can be pasted straight into a spreadsheet.
for name, series in [('AUC', auc), ('ACC', acc), ('KAPPA', kappa), ('F1', f1)]:
    print(name + ':', '\n' + '\n'.join(series.astype(str).tolist()), '\n')

AUC: 
0.9115390431010146
0.9540741665218911
0.9304861965313138
0.976658798505573
0.7902068143536398
0.933946287500884
0.8340170750624439
0.9240428077650971
0.8434172323873177
0.8317083235412089
0.7764934478651258
0.7902334349911714
0.964933617047022
0.8353739245366469
0.7819361223278316
0.619008084882074
0.810060657914573
0.9202308784970906 

ACC: 
0.8781954887218045
0.8639344262295082
0.8578260869565217
0.8671428571428571
0.7940652818991097
0.8886792452830189
0.8663855421686747
0.8620300751879699
0.8697002923976608
0.8669456066945607
0.8087649402390438
0.85
0.8956521739130435
0.8504347826086956
0.7975369458128079
0.889763779527559
0.7731343283582089
0.8675675675675676 

KAPPA: 
0.777154452877258
0.8025012092558785
0.7899625501635115
0.8107007851119512
0.6351719236593685
0.8243204143948268
0.725852213414195
0.7926150196164072
0.7254854685063032
0.69984708401217
0.6245285570839383
0.671542199667243
0.845995570977698
0.7174503681144597
0.6151969375518864
0.6469826747515179
0.615715514769