In [2]:
import numpy as np
from sklearn.metrics import pairwise_distances
from matplotlib import pyplot as plt
import pandas as pd
from chembl_webresource_client.new_client import new_client

In [181]:
# working with chembl to get data about molecules  - EXAMPLE
available_resources = [resource for resource in dir(new_client) if not resource.startswith('_')]
print(available_resources)
molecule = new_client.molecule
mols = molecule.filter(pref_name__iexact='ibuprofen')
mols[0]['atc_classifications']

# Here is how you read them:
#https://www.whocc.no/atc_ddd_index/?code=J05A&showdescription=no
# It is a hierarchy, with prefixes



['G02CC01', 'M01AE51', 'M02AA13', 'R02AX02', 'C01EB16', 'M01AE01']

In [182]:
# Read the list of molecules
molecules = pd.read_excel("data/gain-synonyms-list.xlsx")

In [232]:
molecules_list = list(molecules['Gain'].str.lower())
# selected_molecules = [m.split(" ")[-1] for m in molecules_list[0:7000] if (("+" not in m) and 

selected_molecules = [m.split(" ")[-1] for m in molecules_list[0:10000] if (("+" not in m) and 
                      (m.endswith("olol") or 
                      m.endswith('cillin') or 
                      m.endswith('sartan') or 
                      m.endswith('mycin') or 
                      m.endswith('vir') or 
                      m.endswith('parin') or
                      m.endswith('mab') or 
                      m.endswith('lamide') or
                      m.endswith('caine') or
                      m.endswith('bital') or
                      m.endswith('afil') or
                      m.endswith('asone') or
                      m.endswith('profen') or 
                      m.endswith('statin') or
                      m.endswith("fenac") or
                      m.endswith("floxacin") or 
                      m.endswith("nazole") or
                      m.endswith("parin") or 
                      m.endswith("setron") or 
                      m.endswith("tadine") or 
                      m.endswith("thiazide") or
                      m.endswith("vudine") or
                      m.endswith('tinib')))]

In [233]:
selected_molecules = list(set(selected_molecules))
len(selected_molecules)

246

In [234]:
ATC_codes = {}
for mol in selected_molecules:
    mols = molecule.filter(pref_name__iexact=mol)
    if len(mols)>0:
        ATC_codes[mol] = mols[0]['atc_classifications']
    else:
        print(mol + " not found")
ATC_codes

aciclovir not found
oxetacaine not found
flucloxacillin not found
dalteparin not found
glibenclamide not found
beclometasone not found
tinzaparin not found
lignocaine not found
clobetasone not found
valaciclovir not found
reviparin not found
phenoxymethylpenicillin not found


{'dasabuvir': ['J05AP09'],
 'ofloxacin': ['J01MA01', 'S02AA16', 'S01AE01'],
 'phenobarbital': ['N03AA02'],
 'telmisartan': ['C09CA07'],
 'betaxolol': ['S01ED52', 'C07AB05', 'S01ED02'],
 'ritonavir': ['J05AE03'],
 'pitavastatin': ['C10AA08'],
 'loratadine': ['R06AX13'],
 'axitinib': ['L01EK01'],
 'bleomycin': ['L01DC01'],
 'adalimumab': ['L04AB04'],
 'idrocilamide': ['M02AX05'],
 'stavudine': ['J05AF04'],
 'lamivudine': ['J05AF05'],
 'zanubrutinib': ['L01EL03'],
 'ropivacaine': ['N01BB09'],
 'cobimetinib': ['L01EE02'],
 'ondansetron': ['A04AA01'],
 'dolutegravir': ['J05AJ03'],
 'talinolol': ['C07AB13'],
 'moxifloxacin': ['J01MA14', 'S01AE07'],
 'nilotinib': ['L01EA03'],
 'cabozantinib': ['L01EX07'],
 'prilocaine': ['N01BB04', 'N01BB54'],
 'binimetinib': ['L01EE03'],
 'hydrochlorothiazide': ['C03AX01', 'C03AA03'],
 'sofosbuvir': ['J05AP08'],
 'ceritinib': ['L01ED02'],
 'fluticasone': [],
 'sunitinib': ['L01EX01'],
 'tadalafil': ['G04BE08'],
 'pindolol': ['C07AA03'],
 'baricitinib': ['L04

In [236]:
# Select highest class in the ATC hierarchy
ATC_codes_highest = {}
for mol in ATC_codes.keys():
    a = list(set([code[0:3] for code in ATC_codes[mol]]))
    if len(a) == 1:
        ATC_codes_highest[mol] = list(set([code[0] for code in ATC_codes[mol]]))[0]
classif = pd.DataFrame(ATC_codes_highest.keys())
classif.columns = ["Molecule"]
classif['ATC_CODE'] =  ATC_codes_highest.values()
classif

Unnamed: 0,Molecule,ATC_CODE
0,dasabuvir,J
1,phenobarbital,N
2,telmisartan,C
3,ritonavir,J
4,pitavastatin,C
...,...,...
170,nebivolol,C
171,voriconazole,J
172,propranolol,C
173,dicloxacillin,J


In [212]:
from transformers import AutoTokenizer, AutoModel
import torch as torch

tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
model = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [216]:
classif['ATC_CODE'].value_counts()

L    61
J    43
C    29
S     8
N     7
R     5
A     5
G     4
M     3
D     3
B     2
H     1
Name: ATC_CODE, dtype: int64

In [246]:
classif['ATC_CODE'].value_counts()
classif = classif[classif['ATC_CODE'].isin(["L", "J", "C", "N", "S", "G", "R", "A"])].reset_index(drop = True)
classif = classif.sample(frac = 1).reset_index(drop = True)
classif['ATC_CODE'].value_counts()

L    61
J    47
C    29
S     8
N     7
A     5
R     5
G     4
Name: ATC_CODE, dtype: int64

In [247]:
classif

Unnamed: 0,Molecule,ATC_CODE
0,simeprevir,J
1,tipranavir,J
2,tadalafil,G
3,brigatinib,L
4,chloroprocaine,N
...,...,...
161,voriconazole,J
162,lorlatinib,L
163,palonosetron,A
164,bleomycin,L


In [248]:
#embeddings and cosines
mat = np.zeros([classif.shape[0], 768])
i=0
print("Computing embeddings...")
for drug in classif['Molecule']:
    print(drug)
    inputs = tokenizer.encode(drug, return_tensors="pt")
    with torch.no_grad():
        last_hidden_states = model(inputs)[0] # Models outputs are now tuples
        last_hidden_states = last_hidden_states.mean(1)
        mat[i,:] = last_hidden_states
        i+=1
mat

Computing embeddings...
simeprevir
tipranavir
tadalafil
brigatinib
chloroprocaine
vardenafil
panitumumab
mitomycin
talinolol
pindolol
canakinumab
sildenafil
pentostatin
oseltamivir
tropisetron
cabozantinib
rosuvastatin
bromfenac
mepolizumab
abacavir
itraconazole
letermovir
prilocaine
golimumab
obinutuzumab
zanubrutinib
dolutegravir
amoxicillin
metipranolol
rituximab
bendroflumethiazide
selpercatinib
grazoprevir
indinavir
infliximab
clarithromycin
bosutinib
imatinib
afatinib
neratinib
maribavir
bupivacaine
tofacitinib
posaconazole
alectinib
granisetron
palivizumab
dasabuvir
ponatinib
daclatasvir
dasatinib
raltegravir
saquinavir
phenobarbital
cobimetinib
nilotinib
avanafil
valganciclovir
omalizumab
fosamprenavir
daptomycin
crizotinib
fedratinib
natalizumab
nelfinavir
hydrochlorothiazide
telmisartan
dicloxacillin
dorzolamide
pralsetinib
sunitinib
brinzolamide
methyclothiazide
filgotinib
nepafenac
piperacillin
esmolol
lenvatinib
rupatadine
amantadine
ceritinib
ropivacaine
candesartan
mezlo

array([[ 1.67311564e-01,  1.17157809e-01,  1.33802310e-01, ...,
        -6.23593573e-03, -2.66351905e-02, -2.99514621e-03],
       [ 1.06332973e-01,  1.19729951e-01,  1.70214608e-01, ...,
         4.05985832e-01, -1.28458440e-01,  2.20620468e-01],
       [ 6.71423459e-03, -1.71441659e-02,  2.21307024e-01, ...,
         1.90958425e-01,  2.10867822e-02,  7.22294450e-02],
       ...,
       [ 8.60813037e-02,  1.32632747e-04, -9.07927752e-02, ...,
         2.27935519e-02,  2.79648066e-01,  1.85039341e-01],
       [ 2.27924883e-01,  2.83039778e-01, -1.19182311e-01, ...,
         1.17901593e-01,  1.34647176e-01,  1.50358200e-01],
       [ 1.44957945e-01,  1.17119960e-01, -8.53406340e-02, ...,
         1.44371882e-01,  6.86201081e-02,  1.27034381e-01]])

In [249]:
data1 = pd.concat([classif,pd.DataFrame(mat)], axis = 1)
data1 = data1.drop(columns = ["Molecule"])
data1.columns = ["ATC_CODE"] + ["EMB"+str(c) for c in data1.columns[1:]]
data1

Unnamed: 0,ATC_CODE,EMB0,EMB1,EMB2,EMB3,EMB4,EMB5,EMB6,EMB7,EMB8,...,EMB758,EMB759,EMB760,EMB761,EMB762,EMB763,EMB764,EMB765,EMB766,EMB767
0,J,0.167312,0.117158,0.133802,0.349809,0.123612,-0.007661,0.122647,0.222882,0.454677,...,0.243016,-0.165767,-0.590992,-0.006415,0.265588,0.076795,-0.119888,-0.006236,-0.026635,-0.002995
1,J,0.106333,0.119730,0.170215,0.155003,-0.297473,0.059607,0.187581,0.181555,0.405078,...,0.017044,-0.004044,-0.414418,-0.218121,0.015648,0.171644,-0.101942,0.405986,-0.128458,0.220620
2,G,0.006714,-0.017144,0.221307,0.164459,-0.198167,0.046431,-0.021800,0.198044,0.595853,...,-0.259564,-0.070330,-0.240444,0.170309,0.471298,0.328822,0.094685,0.190958,0.021087,0.072229
3,L,0.453510,0.491506,-0.247160,-0.095819,-0.136647,-0.171095,0.168358,0.185787,0.252265,...,-0.294981,-0.013078,-0.282505,0.019374,-0.091646,0.174837,-0.085112,0.278661,-0.180826,-0.062410
4,N,-0.005502,-0.149610,-0.084269,0.112583,0.125535,-0.172069,0.046240,0.370068,0.231221,...,-0.115704,0.069219,-0.302195,-0.030950,0.096424,0.057980,0.307277,0.508175,-0.095182,0.235100
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
161,J,0.201963,0.080109,0.343997,0.052824,-0.155790,0.182558,-0.036449,0.006687,0.226137,...,-0.253253,0.051454,-0.178581,-0.213150,0.139666,0.184857,-0.179525,0.197767,0.415012,0.116435
162,L,0.308642,0.229553,0.009449,0.012115,-0.067055,-0.126821,-0.109676,0.056756,0.200215,...,-0.146335,-0.034794,-0.318155,-0.091889,0.218285,0.345893,-0.003602,0.047165,-0.058309,0.035905
163,A,0.086081,0.000133,-0.090793,-0.010786,-0.212796,-0.130711,0.139322,0.029819,-0.046446,...,0.158807,-0.093717,-0.170097,-0.070406,0.086640,0.214953,-0.007235,0.022794,0.279648,0.185039
164,L,0.227925,0.283040,-0.119182,0.080856,0.027410,0.041163,0.020383,0.126342,0.226626,...,0.009477,0.031909,-0.313182,0.023950,0.483168,0.194325,0.155545,0.117902,0.134647,0.150358


In [250]:
train_examples = 80
all_examples = data1.shape[0]

In [251]:
# Method 1: each embedding is a feature

from catboost import CatBoostClassifier
from sklearn.metrics import confusion_matrix

train_data1 = data1.drop(columns = ['ATC_CODE']).iloc[0:train_examples] 
train_labels_data1 = data1['ATC_CODE'].iloc[0:train_examples]
test_data1 = data1.drop(columns = ['ATC_CODE']).iloc[train_examples:all_examples] 
test_labels_data1 = data1['ATC_CODE'].iloc[train_examples:all_examples] 

clf = CatBoostClassifier(
    iterations=100, 
    learning_rate=0.1, 
    loss_function='MultiClass'
)

clf.fit(train_data1, train_labels_data1,  
        eval_set=(test_data1, test_labels_data1), 
        verbose=True
)

print('CatBoost model is fitted: ' + str(clf.is_fitted()))
print('CatBoost model parameters:')
print(clf.get_params())

preds_raw = clf.predict(test_data1)
confusion_matrix( test_labels_data1, preds_raw)

0:	learn: 2.0118559	test: 2.0464586	best: 2.0464586 (0)	total: 191ms	remaining: 18.9s
1:	learn: 1.9525276	test: 2.0304807	best: 2.0304807 (1)	total: 326ms	remaining: 16s
2:	learn: 1.8960607	test: 1.9959669	best: 1.9959669 (2)	total: 473ms	remaining: 15.3s
3:	learn: 1.8511744	test: 1.9659799	best: 1.9659799 (3)	total: 610ms	remaining: 14.6s
4:	learn: 1.8099187	test: 1.9416851	best: 1.9416851 (4)	total: 757ms	remaining: 14.4s
5:	learn: 1.7733592	test: 1.9193666	best: 1.9193666 (5)	total: 893ms	remaining: 14s
6:	learn: 1.7212131	test: 1.8941485	best: 1.8941485 (6)	total: 1.04s	remaining: 13.8s
7:	learn: 1.6797279	test: 1.8777251	best: 1.8777251 (7)	total: 1.19s	remaining: 13.7s
8:	learn: 1.6446453	test: 1.8594441	best: 1.8594441 (8)	total: 1.32s	remaining: 13.4s
9:	learn: 1.5975550	test: 1.8326730	best: 1.8326730 (9)	total: 1.47s	remaining: 13.2s
10:	learn: 1.5566877	test: 1.8144908	best: 1.8144908 (10)	total: 1.64s	remaining: 13.3s
11:	learn: 1.5170203	test: 1.7949789	best: 1.7949789 (11

94:	learn: 0.2627234	test: 1.1360035	best: 1.1360035 (94)	total: 14.5s	remaining: 763ms
95:	learn: 0.2578442	test: 1.1326268	best: 1.1326268 (95)	total: 14.7s	remaining: 611ms
96:	learn: 0.2534972	test: 1.1302799	best: 1.1302799 (96)	total: 14.8s	remaining: 458ms
97:	learn: 0.2488582	test: 1.1252829	best: 1.1252829 (97)	total: 15s	remaining: 305ms
98:	learn: 0.2446630	test: 1.1227259	best: 1.1227259 (98)	total: 15.1s	remaining: 153ms
99:	learn: 0.2407409	test: 1.1219146	best: 1.1219146 (99)	total: 15.3s	remaining: 0us

bestTest = 1.121914648
bestIteration = 99

CatBoost model is fitted: True
CatBoost model parameters:
{'iterations': 100, 'learning_rate': 0.1, 'loss_function': 'MultiClass'}


array([[ 0,  0,  0,  3,  0,  0,  0],
       [ 0,  4,  2, 15,  0,  0,  0],
       [ 0,  0, 21,  2,  0,  0,  0],
       [ 0,  0,  3, 29,  0,  0,  0],
       [ 0,  0,  0,  0,  2,  0,  0],
       [ 0,  0,  0,  2,  0,  0,  0],
       [ 0,  1,  1,  1,  0,  0,  0]], dtype=int64)

In [252]:
# Second option, the embeddings are one feature

from catboost import CatBoostClassifier, Pool

train_data = Pool(
    [
        [list(mat[i,].flatten()),"dummy"] for i in range(train_examples)
    ],
    label = list(classif['ATC_CODE'])[0:train_examples],
    cat_features=[1],
    embedding_features=[0]
)

test_data = Pool(
    [
        [list(mat[i,].flatten()),"dummy"] for i in range(train_examples, all_examples)
    ],
    label = list(classif['ATC_CODE'])[train_examples:all_examples],
    cat_features=[1],
    embedding_features=[0]
)

clf = CatBoostClassifier(
    iterations=70, 
    learning_rate=0.1, 
    loss_function='MultiClass'
)

clf.fit(train_data, eval_set=test_data)
preds_class = clf.predict(test_data)
#pd.DataFrame(zip([p[0] for p in preds_class], list(classif['ATC_CODE'])[100:133]))
confusion_matrix(list(classif['ATC_CODE'])[train_examples:all_examples], [p[0] for p in preds_class])

0:	learn: 2.0084648	test: 2.0200774	best: 2.0200774 (0)	total: 4.92ms	remaining: 340ms
1:	learn: 1.9267961	test: 1.9625076	best: 1.9625076 (1)	total: 8.31ms	remaining: 283ms
2:	learn: 1.8443941	test: 1.8900145	best: 1.8900145 (2)	total: 11.6ms	remaining: 259ms
3:	learn: 1.7342478	test: 1.8030286	best: 1.8030286 (3)	total: 14.8ms	remaining: 243ms
4:	learn: 1.6635941	test: 1.7627266	best: 1.7627266 (4)	total: 18.2ms	remaining: 237ms
5:	learn: 1.6239734	test: 1.7330612	best: 1.7330612 (5)	total: 21.7ms	remaining: 232ms
6:	learn: 1.5621182	test: 1.6840753	best: 1.6840753 (6)	total: 25.3ms	remaining: 228ms
7:	learn: 1.5202333	test: 1.6537695	best: 1.6537695 (7)	total: 27.6ms	remaining: 214ms
8:	learn: 1.4638323	test: 1.6116802	best: 1.6116802 (8)	total: 30.7ms	remaining: 208ms
9:	learn: 1.4310573	test: 1.5843552	best: 1.5843552 (9)	total: 33.9ms	remaining: 204ms
10:	learn: 1.3999901	test: 1.5699656	best: 1.5699656 (10)	total: 37.4ms	remaining: 201ms
11:	learn: 1.3661170	test: 1.5330278	best

array([[ 0,  0,  0,  1,  2,  0,  0],
       [ 0,  9,  0,  9,  0,  0,  3],
       [ 0,  0, 20,  3,  0,  0,  0],
       [ 0,  0,  2, 30,  0,  0,  0],
       [ 0,  1,  0,  0,  1,  0,  0],
       [ 0,  0,  0,  2,  0,  0,  0],
       [ 0,  0,  2,  1,  0,  0,  0]], dtype=int64)