In [1]:
from src.layers import *
from src.utils import process_edges, auprc_auroc_ap
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score
import matplotlib.pyplot as plt
from scipy imporinterp
from collections import OrderedDict
import numpy as np
from umap import UMAP

In [2]:
pose = Pose(data_dir, device='cpu')

model = pose.model
model.eval()


go-basic.obo: fmt(1.2) rel(2021-05-01) 47,284 GO Terms
  EXISTS: gene2go
HMS:0:00:06.084110 335,858 annotations, 20,671 genes, 18,441 GOs, 1 taxids READ: gene2go 

Load BP Gene Ontology Analysis ...
fisher module not installed.  Falling back on scipy.stats.fisher_exact
 81% 16,936 of 20,913 population items found in association

Load CC Gene Ontology Analysis ...
fisher module not installed.  Falling back on scipy.stats.fisher_exact
 85% 17,872 of 20,913 population items found in association

Load MF Gene Ontology Analysis ...
fisher module not installed.  Falling back on scipy.stats.fisher_exact
 83% 17,408 of 20,913 population items found in association


Model(
  (pp): PP(
    (conv_list): ModuleList(
      (0): myGCN(64, 32)
      (1): myGCN(32, 32)
    )
  )
  (pd): PD(
    (conv): myGCN(128, 128)
  )
  (mip): MultiInnerProductDecoder()
)

In [6]:
def evaluate(mod: str, pos_score: torch.Tensor, neg_score: torch.Tensor) -> np.array:
    assert mod in {'train', 'test'}, "'idx' should in {'train', 'test'}"
    
    record = np.zeros((3, gdata.n_et))
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    precision = dict()
    recall = dict()
    average_precision = dict()
    
    for i in range(getattr(gdata, f'{mod}_range').shape[0]):
        [start, end] = getattr(gdata, f'{mod}_range')[i]
        p_s = pos_score[start: end]
        n_s = neg_score[start: end]

        pos_target = torch.ones(p_s.shape[0])
        neg_target = torch.zeros(n_s.shape[0])

        score = torch.cat([p_s, n_s])
        target = torch.cat([pos_target, neg_target])

        # auprc_auroc_ap
        record[0, i], record[1, i], record[2, i] = auprc_auroc_ap(target, score)
        
        # auroc
        fpr[i], tpr[i], _ = roc_curve(target.detach().numpy(), score.detach().numpy())
        roc_auc[i] = auc(fpr[i], tpr[i])
        # auprc
        precision[i], recall[i], _ = precision_recall_curve(target.detach().numpy(), score.detach().numpy())

    return record, fpr, tpr, roc_auc, precision, recall, average_precision

In [7]:
pp_static_edge_weights = torch.ones((gdata.pp_index.shape[1]))
pd_static_edge_weights = torch.ones((gdata.pd_index.shape[1]))
test_neg_index = typed_negative_sampling(gdata.test_idx, gdata.n_drug, gdata.test_range)

z = model.pp(gdata.p_feat, gdata.pp_index, pp_static_edge_weights)
z = model.pd(z, gdata.pd_index, pd_static_edge_weights)

pos_score = model.mip(z, gdata.test_idx, gdata.test_et)
neg_score = model.mip(z, test_neg_index, gdata.test_et)

record, fpr, tpr, roc_auc, precision, recall, average_precision = evaluate('test', pos_score, neg_score)
y_score = torch.cat((pos_score, neg_score)).detach().numpy()
y_test = torch.cat((torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0]))).detach().numpy()

# Compute micro-average ROC curve and ROC area
fpr["micro"], tpr["micro"], _ = roc_curve(y_test, y_score)
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

# # Compute macro-average ROC curve and ROC area

# # First aggregate all false positive rates
# n_classes = record.shape[1]
# all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))

# # Then interpolate all ROC curves at this points
# mean_tpr = np.zeros_like(all_fpr)
# for i in range(n_classes):
#     mean_tpr += interp(all_fpr, fpr[i], tpr[i])

# # Finally average it and compute AUC
# mean_tpr /= n_classes

# fpr["macro"] = all_fpr
# tpr["macro"] = mean_tpr
# roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

In [10]:
pp_static_edge_weights = torch.ones((gdata.pp_index.shape[1]))
pd_static_edge_weights = torch.ones((gdata.pd_index.shape[1]))
train_neg_index = typed_negative_sampling(gdata.train_idx, gdata.n_drug, gdata.train_range)

z = model.pp(gdata.p_feat, gdata.pp_index, pp_static_edge_weights)
z = model.pd(z, gdata.pd_index, pd_static_edge_weights)

pos_score_train = model.mip(z, gdata.train_idx, gdata.train_et)
neg_score_train = model.mip(z, train_neg_index, gdata.train_et)

y_score_train = torch.cat((pos_score_train, neg_score_train)).detach().numpy()
y_test_train = torch.cat((torch.ones(pos_score_train.shape[0]), torch.zeros(neg_score_train.shape[0]))).detach().numpy()

record_train, fpr_train, tpr_train, roc_auc_train, _, _, _ = evaluate('train', pos_score_train, neg_score_train)

# Compute micro-average ROC curve and ROC area_train
fpr_train["micro"], tpr_train["micro"], _ = roc_curve(y_test_train, y_score_train)
roc_auc_train["micro"] = auc(fpr_train["micro"], tpr_train["micro"])

In [None]:
plt.figure()
plt.plot(fpr_train["micro"], tpr_train["micro"],
         label='Train - Mi-ROC curve (area = {0:0.4f})'
               ''.format(roc_auc_train["micro"]),
         color='darkblue', linestyle='-', linewidth=2, alpha=0.5)

# plt.plot(fpr_train["macro"], tpr_train["macro"],
#          label='Test - Ma-ROC curve for test (area = {0:0.2f})'
#                ''.format(roc_auc_train["macro"]),
#          color='darkorange', linestyle='dashdot', linewidth=2)

plt.plot(fpr["micro"], tpr["micro"],
         label='Test - Mi-ROC curve (area = {0:0.4f})'
               ''.format(roc_auc["micro"]),
         color='darkorange', linestyle='-', linewidth=2, alpha=0.5)

plt.plot(fpr["macro"], tpr["macro"],
         label='Test - Ma-ROC curve for test (area = {0:0.4f})'
               ''.format(roc_auc["macro"]),
         color='darkorange', linestyle='dashdot', linewidth=2, alpha=0.5)

plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Some extension of Receiver operating characteristic to multi-class')
plt.legend(loc="lower right")
plt.show()


In [86]:
record_train.mean(axis=1), record.mean(axis=1)

(array([0.85682845, 0.88356708, 0.85718549]),
 array([0.7946841 , 0.83275441, 0.79720226]))

In [95]:
y, xx, _ = precision_recall_curve(y_test_train, y_score_train)
auprc = auc(xx, y)
print(f"micro-prc - train - {auprc}")
y, xx, _ = precision_recall_curve(y_test, y_score)
auprc = auc(xx, y)
print(f"micro-prc - test - {auprc}")

micro-prc - train - 0.8345149626588475
micro-prc - test - 0.8347589294847769


### side effect embedding learned from pose-pred



In [8]:
# get trined embedding of side effects
side_effect_embedding = model.mip.weight.data.numpy()
side_effect_embedding.shape

(861, 160)

In [3]:
# load side effect catagory
se_c = pandas.read_csv('index-map/side_effect+catagory.tsv', sep='\t')
se_c

Unnamed: 0,side effect,catagory,catagory name
0,atelectasis,C08,respiratory tract diseases
1,back ache,C23,"pathological conditions, signs and symptoms"
2,lung edema,C08,respiratory tract diseases
3,acidosis,C18,nutritional and metabolic diseases
4,peliosis,C06,digestive system diseases
...,...,...,...
838,corneal abrasion,C10,nervous system diseases
839,corneal abrasion,C11,eye diseases
840,corneal abrasion,C26,wounds and injuries
841,bunion,C05,musculoskeletal diseases


In [4]:
# side effect mapping
se_name2idx, se_idx2name = {}, {}
for se, idx in gdata.side_effect_name_to_idx.items():
    se = se.lower()
    se_name2idx[se] = idx
    se_idx2name[idx] = se

In [5]:
cata2se, cata2name = OrderedDict(), OrderedDict()
for se, ca, na in se_c.values:
    if ca == 'C23':
        ca = 'S'
    if cata2se.get(ca) is None:
        cata2se[ca] = []
    if cata2name.get(ca) is None:
        cata2name[ca] = na
    cata2se[ca].append(se_name2idx[se.lower()])
    se_set = set()

# update symptom cata
for ca, ses in cata2se.items():
    if ca == 'S':
        continue
    se_set.update(set(ses))
se_all = set(range(gdata.n_et))
se_tmp = (se_all - se_set) | set(cata2se['S'])
cata2se['S'] = list(se_tmp)


In [6]:
v = list(cata2se.values())[0]

In [9]:
se_ca_embed = np.concatenate([side_effect_embedding[v].mean(axis=0) for v in list(cata2se.values())]).reshape((-1, side_effect_embedding.shape[1]))
se_ca_embed = np.concatenate((se_ca_embed, side_effect_embedding[cata2se['F']]), axis=0)

In [13]:
ca_embed = se_ca_embed[:22]

In [14]:
umap = UMAP(n_components=2, init='random', random_state=8)
se_umap_np = umap.fit_transform(se_ca_embed)
se_umap = pandas.DataFrame(se_umap_np[:22, :], columns=['x', 'y'])
se_umap['category'] = [f"{i}" for i in cata2se.keys()]
se_umap['n_se'] = [len(v) for v in cata2se.values()]
se_umap['cat_all'] = [f"{i}-{cata2name[i]}" for i in cata2se.keys()]
se_umap['color'] = [f"{i[0]}" for i in cata2se.keys()]
se_umap.sort_values(by='category')



Unnamed: 0,x,y,category,n_se,cat_all,color
8,4.73751,8.858061,C01,50,C01-infections,C
10,6.22428,8.271191,C04,32,C04-neoplasms,C
17,6.210504,8.924448,C05,38,C05-musculoskeletal diseases,C
3,4.216232,9.131621,C06,55,C06-digestive system diseases,C
18,6.037347,8.604804,C07,18,C07-stomatognathic diseases,C
0,4.183841,9.74394,C08,39,C08-respiratory tract diseases,C
16,4.940152,11.09022,C09,21,C09-otorhinolaryngologic diseases,C
9,4.26257,11.085861,C10,96,C10-nervous system diseases,C
11,4.915984,10.317763,C11,24,C11-eye diseases,C
12,4.90381,9.444127,C12,30,C12-male urogenital diseases,C


(22, 6)

In [75]:
se_umap.to_csv('se_embed_2d_pred.csv')

In [15]:
import plotly.graph_objects as go
import plotly.express as px

In [16]:
# plot mental disorder
se_umap_md = se_umap_np[22:]

In [156]:
col_dict = {'C': 'rgb(93, 164, 214)', 'S': 'rgb(44, 160, 101)', 'F': 'rgb(255, 65, 54)'}
fig = go.Figure(data=[go.Scatter(
    x = se_umap['x'],
    y = se_umap['y'],
    mode = 'markers+text',
    text = se_umap['category'],
    textposition='middle center',
    textfont=dict(
        family="sans serif",
        size=15,
        # color="LightSeaGreen"
    ),
    marker = dict(color=[col_dict[i] for i in se_umap['color']], opacity=0.3, size=np.log(se_umap['n_se'])*15))])
# fig = px.scatter(se_umap, x='x', y='y', color=[col_dict[i] for i in se_umap['color']], text='category', size='n_se', size_max=100, opacity=0.3)

fig.add_trace(go.Scatter(
    x = se_umap_md[:, 0],
    y = se_umap_md[:, 1],
    mode='markers'
))

fig.update_xaxes(gridcolor='#e0e0e0', showline=True, linecolor='black', mirror=True)
fig.update_yaxes(gridcolor='#e0e0e0', showline=True, linecolor='black', mirror=True)
fig.update_layout(plot_bgcolor='#FFF')
fig.show()

In [17]:
# average after umap
umap = UMAP(n_components=2, init='random', random_state=8)
se_umap_np = umap.fit_transform(side_effect_embedding)

In [18]:
se_umap_np = (se_umap_np - se_umap_np.mean(axis=0))*10+10

In [19]:
se_ca_embed = np.concatenate([se_umap_np[v].mean(axis=0) for v in list(cata2se.values())]).reshape((-1, 2))
se_umap = pandas.DataFrame(se_ca_embed, columns=['x', 'y'])
se_umap['category'] = [f"{i}" for i in cata2se.keys()]
se_umap['n_se'] = [len(v) for v in cata2se.values()]
se_umap['cat_all'] = [f"{i} - {cata2name[i]}" for i in cata2se.keys()]
se_umap['color'] = [f"{i[0]}" for i in cata2se.keys()]
se_umap = se_umap.sort_values(by='category')
se_umap

Unnamed: 0,x,y,category,n_se,cat_all,color
8,9.086743,7.471842,C01,50,C01 - infections,C
10,13.716795,10.056763,C04,32,C04 - neoplasms,C
17,17.519848,13.967044,C05,38,C05 - musculoskeletal diseases,C
3,8.308552,5.27629,C06,55,C06 - digestive system diseases,C
18,23.783897,12.842505,C07,18,C07 - stomatognathic diseases,C
0,9.170995,8.973595,C08,39,C08 - respiratory tract diseases,C
16,12.150425,16.570591,C09,21,C09 - otorhinolaryngologic diseases,C
9,5.474772,13.598949,C10,96,C10 - nervous system diseases,C
11,8.93924,16.282234,C11,24,C11 - eye diseases,C
12,10.514725,6.916021,C12,30,C12 - male urogenital diseases,C


In [22]:
col_dict = {'C': 'rgb(93, 164, 214)', 'S': 'rgb(44, 160, 101)', 'F': 'rgb(255, 65, 54)'}
fig = go.Figure(data=[go.Scatter(
    x = se_umap['y'].values+5,
    y = se_umap['x'],
    mode = 'markers+text',
    text = se_umap['category'],
    textposition='middle center',
    textfont=dict(
        family="sans serif",
        size=15,
        # color="LightSeaGreen"
    ),
    marker = dict(color=[col_dict[i] for i in se_umap['color']], opacity=0.3, size=np.log(se_umap['n_se'])*15))])
# fig = px.scatter(se_umap, x='x', y='y', color=[col_dict[i] for i in se_umap['color']], text='category', size='n_se', size_max=100, opacity=0.3)

# se_umap_md = se_umap_np[cata2se['F']]
# fig.add_trace(go.Scatter(
#     x = se_umap_md[:, 1],
#     y = se_umap_md[:, 0],
#     mode='markers+text',

# ))

# se_umap_md = se_umap_np[cata2se['C07']]
# fig.add_trace(go.Scatter(
#     x = se_umap_md[:, 1],
#     y = se_umap_md[:, 0],
#     mode='markers+text',

# ))

# se_umap_md = se_umap_np[cata2se['S']]
# fig.add_trace(go.Scatter(
#     x = se_umap_md[:, 1],
#     y = se_umap_md[:, 0],
#     mode='markers+text',

# ))

fig.update_xaxes(range=(0, 40), gridcolor='#e0e0e0', showline=True, linecolor='black', mirror=True)
fig.update_yaxes(range=(2.5, 25), tickvals=[2.5, 5,7.5, 10,12.5,15, 17.5, 20,22.5, 25], gridcolor='#e0e0e0', showline=True, linecolor='black', mirror=True)
fig.update_layout(plot_bgcolor='#FFF', width=800, height=1000)
fig.show()

In [236]:
for i in se_umap['cat_all'].values:
    print(i)

C01 - infections
C04 - neoplasms
C05 - musculoskeletal diseases
C06 - digestive system diseases
C07 - stomatognathic diseases
C08 - respiratory tract diseases
C09 - otorhinolaryngologic diseases
C10 - nervous system diseases
C11 - eye diseases
C12 - male urogenital diseases
C13 - female urogenital diseases and pregnancy complications
C14 - cardiovascular diseases
C15 - hemic and lymphatic diseases
C16 - congenital, hereditary, and neonatal diseases and abnormalities
C17 - skin and connective tissue diseases
C18 - nutritional and metabolic diseases
C19 - endocrine system diseases
C20 - immune system diseases
C25 - chemically-induced disorders
C26 - wounds and injuries
F - mental discorders
S - pathological conditions, signs and symptoms


## overlapping between side effect categories


In [25]:
se_umap.head()

Unnamed: 0,x,y,category,n_se,cat_all,color
8,9.086743,7.471842,C01,50,C01 - infections,C
10,13.716795,10.056763,C04,32,C04 - neoplasms,C
17,17.519848,13.967044,C05,38,C05 - musculoskeletal diseases,C
3,8.308552,5.27629,C06,55,C06 - digestive system diseases,C
18,23.783897,12.842505,C07,18,C07 - stomatognathic diseases,C


In [84]:
cata2se.keys()

odict_keys(['C08', 'S', 'C18', 'C06', 'C14', 'C20', 'C25', 'C15', 'C01', 'C10', 'C04', 'C11', 'C12', 'C13', 'C26', 'C17', 'C09', 'C05', 'C07', 'C19', 'F', 'C16'])

In [68]:
n = se_umap.shape[0]
cate_pair_dis_js = []
for i in range(n):
    for j in range(i+1, n):
        # get category and coordinates
        x1, y1, c1, _, _, _ = se_umap.iloc[i, :]
        x2, y2, c2, _, _, _ = se_umap.iloc[j, :]

        # get side effets
        s1 = set(cata2se[c1])
        s2 = set(cata2se[c2])

        # compute distance
        dis = np.sqrt((x1-x2)**2 + (y1-y2)**2)
        
        # compute jaccard similarity coefficient score
        js = len(s1&s2) / len(s1) / len(s2)

        if js < 1e-10:
            continue

        cate_pair_dis_js.append([c1, c2, dis, js])
cate_pair_dis_js = pandas.DataFrame(cate_pair_dis_js)


In [69]:
cate_pair_dis_js

Unnamed: 0,0,1,2,3
0,C01,C05,10.644478,0.001053
1,C01,C06,2.329384,0.001818
2,C01,C07,15.647695,0.003333
3,C01,C08,1.504114,0.007179
4,C01,C09,9.600697,0.003810
...,...,...,...,...
95,C17,S,4.743208,0.000281
96,C18,C19,9.214382,0.003676
97,C18,S,1.425412,0.000129
98,C20,C25,11.519726,0.022556


In [82]:
fig = go.Figure(data=[go.Scatter(
    x = cate_pair_dis_js.iloc[:, 2],
    y = cate_pair_dis_js.iloc[:, 3],
    mode = 'markers+text',
    text = [f'{i}-{j}' if js > 0.006 else '' for i, j, _, js in cate_pair_dis_js.values],
    textposition='top center',
    textfont=dict(
        family="sans serif",
        size=12,
        # color="LightSeaGreen"
    ),
    marker = dict(opacity=0.3))])

fig.update_xaxes(gridcolor='#e0e0e0', zerolinecolor='#e0e0e0', showline=True, linecolor='black', mirror=True, title='embedding distance')
fig.update_yaxes(gridcolor='#e0e0e0', zerolinecolor='#e0e0e0', showline=True, linecolor='black', mirror=True, title='jaccard distance')
fig.update_layout(plot_bgcolor='#FFF')
fig.show()