In [1]:
!pip install PyTDC

Collecting PyTDC
  Downloading PyTDC-0.4.1.tar.gz (107 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/107.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m107.7/107.7 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting rdkit-pypi (from PyTDC)
  Downloading rdkit_pypi-2022.9.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (29.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m29.4/29.4 MB[0m [31m39.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting fuzzywuzzy (from PyTDC)
  Downloading fuzzywuzzy-0.18.0-py2.py3-none-any.whl (18 kB)
Collecting dataclasses (from PyTDC)
  Downloading dataclasses-0.6-py3-none-any.whl (14 kB)
Building wheels for collected packages: PyTDC
  Building wheel for PyTDC (setup.py) ... [?25l[?25hdone
  Created wheel for PyTDC: filename=PyTDC-0.4.1-py3-none-any.whl size=140644 sha256=37af

In [2]:
import tqdm
import numpy as np
import pandas as pd
import plotly.express as px
from tdc.multi_pred import DTI
import plotly.figure_factory as ff

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

In [None]:
data = DTI(name = 'Davis')
split = data.get_split()

Downloading...
  8%|▊         | 1.60M/21.4M [00:00<00:03, 5.42MiB/s]

## Data Provision

In [None]:
new_data=data.get_data()
# new_data['weights']=1/new_data['Y']
# new_data=new_data.sample(frac=0.5,weights='weights')
y_max=new_data['Y'].max()
y_min=new_data['Y'].min()
new_data['Y']=(new_data['Y']-y_min)/(y_max-y_min)
# new_data['Y']=-np.log10(new_data['Y']/1e9)
new_data=new_data.sample(frac=1)
new_data=new_data.reset_index(drop=True)

## Some basic Analysis

In [None]:
new_data['Drug_l']=new_data.Drug.apply(len)
new_data['Target_l']=new_data.Target.apply(len)
new_data[['Drug_l','Target_l','Y']].describe()

Unnamed: 0,Drug_l,Target_l,Y
count,25772.0,25772.0,25772.0
mean,54.176471,744.849604,0.755811
std,10.962637,372.813592,0.399002
min,32.0,244.0,0.0
25%,45.0,479.0,0.377499
50%,53.0,632.0,1.0
75%,61.25,954.0,1.0
max,81.0,2549.0,1.0


In [None]:
new_data[['Drug',"Target"]].tail()

Unnamed: 0,Drug,Target
25767,Cc1cc2c(F)c(Oc3ncnn4cc(OCC(C)O)c(C)c34)ccc2[nH]1,MTAVYMNGGGLVNPHYARWDRRDSVESGCQTESSKEGEEGQPRQLT...
25768,Cc1[nH]c(C=C2C(=O)Nc3ccc(F)cc32)c(C)c1C(=O)NCC...,MHTGGETSACKPSSVRLAPSFSFHAAGLQMAGQMPHSHQYSDRRQP...
25769,O=C(NOCC1CC1)c1ccc(F)c(F)c1Nc1ccc(I)cc1Cl,MELQAARACFALLWGCALAAAAAAQGKEVVLLDFAAAGGELGWLTH...
25770,Nc1nc(N)c2nc(-c3cccc(O)c3)c(-c3cccc(O)c3)nc2n1,MFQASMRSPNMEPFKQQKVEDFYDIGEELGSGQFAIVKKCREKSTG...
25771,Cc1nc(Nc2ncc(C(=O)Nc3c(C)cccc3Cl)s2)cc(N2CCN(C...,MVDMGALDNLIANTAYLQARKPSDCDSKELQRRRRSLALPGLQGCA...


In [None]:
fig = px.histogram(new_data['Y'], nbins=200,marginal="box")

# Show the plot
fig.show()

## Let's create the tokenizer

In [None]:
def tokenize(input_string):
  return [ord(char) for char in input_string]
def encode(input_string,max_length=128,padding=True):
  tokens=tokenize(input_string)
  if len(tokens)>max_length:
    tokens=tokens[:max_length]
  if (len(tokens)<max_length) & padding:
    tokens.extend([0 for _ in range(max_length-len(tokens))])
  return tokens
def decode(input_tokens):
  return ''.join(list(map(lambda x:chr(x), input_tokens)))

In [None]:
l_tokenizer=encode('z',padding=False)[0]+1

## Now, let's create the dataset object

In [None]:
class DTIA_Dataset(Dataset):
    def __init__(self, df,drug_max_length,target_max_length):
        self.df = df
        self.dml=drug_max_length
        self.tml=target_max_length
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row=self.df.iloc[idx]
        input_drug=torch.tensor(encode(row['Drug'],max_length=self.dml))
        input_target=torch.tensor(encode(row['Target'],max_length=self.tml))
        y=torch.tensor(row['Y'],dtype=torch.float32)
        return {'input_drug':input_drug,
         'input_target': input_target,
         'y':y}

In [None]:
dml=45
tml=700

In [None]:
l=int(new_data.shape[0]*0.8)
train_p=DTIA_Dataset(new_data[:l],drug_max_length=45,target_max_length=700)
test_p=DTIA_Dataset(new_data[l:],drug_max_length=45,target_max_length=700)

In [None]:
train_loader=DataLoader(train_p,batch_size=32,shuffle=True)
test_loader=DataLoader(test_p,batch_size=32)

## Now a model

In [None]:
def create_mask(size,distance):
    indices = torch.arange(size).view(1, -1)
    matrix = torch.abs(indices - indices.t()) <= distance
    return matrix.int()
class SingleHeadAttention(torch.nn.Module):
    def __init__(self, qkv_dim,embed_dim, dropout_rate=0.1):
        super(SingleHeadAttention, self).__init__()
        self.qkv_dim=qkv_dim
        self.embed_dim = embed_dim
        self.W_q = torch.nn.Linear(qkv_dim, embed_dim, bias=False)
        self.W_k = torch.nn.Linear(qkv_dim, embed_dim, bias=False)
        self.W_v = torch.nn.Linear(qkv_dim, embed_dim, bias=False)
        self.dp = nn.Dropout(dropout_rate)
    def forward(self, query, key, value, mask=None):
        Q = self.W_q(query) #Q,KVe
        Q=self.dp(Q)
        K = self.W_k(key)   #KV,KVe
        K=self.dp(K)
        V = self.W_v(value) #KV,Ve
        V=self.dp(V)
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.qkv_dim, dtype=torch.float32)) #Q,KV
        if mask is not None:
            attention_scores = attention_scores.masked_fill(mask == 0, float('-inf')) #Q,KV
        attention_weights = F.softmax(attention_scores, dim=-1) #Q,KV
        attended_values = torch.matmul(attention_weights, V) #Q,Ve
        return attended_values, attention_weights


class MultiHeadAttention(torch.nn.Module):
    def __init__(self, embed_size, num_heads,output_size,dropout_rate=0.1):
        super(MultiHeadAttention, self).__init__()
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.attention_heads = torch.nn.ModuleList([SingleHeadAttention(embed_size,embed_size,dropout_rate) for _ in range(num_heads)])
        self.fc_out = torch.nn.Linear(num_heads * embed_size, output_size)

    def forward(self, query, key, value, mask=None):
        head_outputs = [attention(query, key, value, mask)[0] for attention in self.attention_heads]
        concatenated_output = torch.cat(head_outputs, dim=-1)
        output = self.fc_out(concatenated_output)
        return output

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
drug_mask=create_mask(45,5).to(device)
target_mask=create_mask(700,10).to(device)
# target_mask1=create_mask(700,10).to(device)
class Drug_Model_att(torch.nn.Module):
    def __init__(self, embed_dim=16,dim1=32):
        super(Drug_Model_att, self).__init__()
        self.embeddings = nn.Embedding(l_tokenizer, embed_dim)
        self.pos_embedding = nn.Embedding(45, embed_dim)
        self.att1=MultiHeadAttention(embed_dim,2,embed_dim)
        self.att2=MultiHeadAttention(embed_dim,2,embed_dim)
        self.att3=MultiHeadAttention(embed_dim,2,embed_dim)

    def forward(self, input_ids):
        x = self.embeddings(input_ids)
        x = x + self.pos_embedding(torch.arange(input_ids.size(1)).to(device))
        x=self.att1(x,x,x,drug_mask)
        x=self.att2(x,x,x,drug_mask)
        x=self.att3(x,x,x,drug_mask)
        x=torch.transpose(x,1,2)
        x = torch.squeeze(F.adaptive_avg_pool1d(x, 1),2)
        return x


class Attention_Block(torch.nn.Module):
  def __init__(self,embed_dim,n_att,dropout_rate=0.2):
        super(Attention_Block,self).__init__()
        self.atts=[MultiHeadAttention(embed_dim,2,embed_dim,dropout_rate).to(device) for _ in range(n_att)]
  def forward(self,x,mask=None):
    for layer in self.atts:
          x=layer(x,x,x,mask)
    return x

class Target_Model_att(torch.nn.Module):
    def __init__(self, embed_dim=16,dim1=32):
        super(Target_Model_att, self).__init__()
        self.embeddings = nn.Embedding(l_tokenizer, embed_dim)
        self.pos_embedding = nn.Embedding(700, embed_dim)
        self.att_b1=Attention_Block(embed_dim,2)
        self.pool1=nn.MaxPool1d(10)
        self.att_b2=Attention_Block(embed_dim,2)

        # self.fpool=nn.AvgPool1d(700)

    def forward(self, input_ids):
        x = self.embeddings(input_ids)
        x = x + self.pos_embedding(torch.arange(input_ids.size(1)).to(device))
        x=self.att_b1(x,target_mask)
        x=torch.transpose(x,1,2)
        x=self.pool1(x)
        x=torch.transpose(x,1,2)
        x=self.att_b2(x)
        x=torch.transpose(x,1,2)
        x = torch.squeeze(F.adaptive_max_pool1d(x, 1),2)
        return x


class DTIA_Model(torch.nn.Module):
    def __init__(self, embed_dim=32,dim1=64):
        super(DTIA_Model, self).__init__()
        self.drug_encoder=Drug_Model_att(embed_dim=embed_dim,dim1=dim1)
        self.target_encoder=Target_Model_att(embed_dim=embed_dim,dim1=dim1)
        self.dense1 = nn.Linear(embed_dim*2, 1024)
        self.dp1=nn.Dropout(0.2)
        self.dense2 = nn.Linear(1024, 1024)
        self.dp2=nn.Dropout(0.2)
        self.dense3 = nn.Linear(1024, 512)
        self.dp3=nn.Dropout(0.2)
        self.fdense = nn.Linear(512, 1)


    def forward(self, drug_inputs,target_inputs):
        drug_encoded=self.drug_encoder(drug_inputs)
        target_encoded=self.target_encoder(target_inputs)
        x=torch.concat((drug_encoded,target_encoded),dim=-1)
        x = F.relu(self.dp1(self.dense1(x)))
        x = F.relu(self.dp2(self.dense2(x)))
        x = F.relu(self.dp3(self.dense3(x)))
        x=self.fdense(x)

        return x

In [None]:

model=DTIA_Model(embed_dim=64,dim1=32)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device='cpu'
model.to(device)

DTIA_Model(
  (drug_encoder): Drug_Model_att(
    (embeddings): Embedding(123, 64)
    (pos_embedding): Embedding(45, 64)
    (att1): MultiHeadAttention(
      (attention_heads): ModuleList(
        (0-1): 2 x SingleHeadAttention(
          (W_q): Linear(in_features=64, out_features=64, bias=False)
          (W_k): Linear(in_features=64, out_features=64, bias=False)
          (W_v): Linear(in_features=64, out_features=64, bias=False)
          (dp): Dropout(p=0.1, inplace=False)
        )
      )
      (fc_out): Linear(in_features=128, out_features=64, bias=True)
    )
    (att2): MultiHeadAttention(
      (attention_heads): ModuleList(
        (0-1): 2 x SingleHeadAttention(
          (W_q): Linear(in_features=64, out_features=64, bias=False)
          (W_k): Linear(in_features=64, out_features=64, bias=False)
          (W_v): Linear(in_features=64, out_features=64, bias=False)
          (dp): Dropout(p=0.1, inplace=False)
        )
      )
      (fc_out): Linear(in_features=128, out_

## Now, training.

In [None]:
def train(num_epochs=200):
    criterion= nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
    best_test_loss = float('inf')

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        train_total=0.0
        test_loss=0.0
        test_total=0.0

        for batch in tqdm.tqdm(train_loader):
            input_drug=batch['input_drug'].to(device)
            input_target=batch['input_target'].to(device)
            y=batch['y'].to(device)
            optimizer.zero_grad()
            outputs = model(input_drug,input_target).view(-1,)
            loss = criterion(outputs,y)
            train_loss += loss.item()
            train_total+=outputs.size(0)
            loss.backward()
            optimizer.step()
        avg_train_loss = train_loss / len(train_loader)



        model.eval()
        with torch.no_grad():
            for batch in test_loader:
                input_drug=batch['input_drug'].to(device)
                input_target=batch['input_target'].to(device)
                y=batch['y'].to(device)
                outputs = model(input_drug,input_target).view(-1,)
                loss = criterion(outputs,y)
                test_total+=outputs.size(0)
                test_loss += loss.item()
        avg_test_loss = test_loss / len(test_loader)

        if avg_test_loss < best_test_loss:
            best_test_loss = avg_test_loss
            torch.save(model.state_dict(), "best_model.pt")
        print(f"Epoch {epoch+1}/{num_epochs}: "
              f"Train Loss: {avg_train_loss:.4f}, Valid Loss: {avg_test_loss:.4f}")

In [None]:
train()

100%|██████████| 645/645 [00:45<00:00, 14.12it/s]


Epoch 1/200: Train Loss: 0.1343, Valid Loss: 0.1330


100%|██████████| 645/645 [00:45<00:00, 14.17it/s]


Epoch 2/200: Train Loss: 0.1224, Valid Loss: 0.1211


100%|██████████| 645/645 [00:45<00:00, 14.18it/s]


Epoch 3/200: Train Loss: 0.1209, Valid Loss: 0.1173


100%|██████████| 645/645 [00:44<00:00, 14.34it/s]


Epoch 4/200: Train Loss: 0.1201, Valid Loss: 0.1254


100%|██████████| 645/645 [00:45<00:00, 14.19it/s]


Epoch 5/200: Train Loss: 0.1198, Valid Loss: 0.1229


100%|██████████| 645/645 [00:45<00:00, 14.27it/s]


Epoch 6/200: Train Loss: 0.1186, Valid Loss: 0.1183


100%|██████████| 645/645 [00:45<00:00, 14.28it/s]


Epoch 7/200: Train Loss: 0.1186, Valid Loss: 0.1172


100%|██████████| 645/645 [00:45<00:00, 14.20it/s]


Epoch 8/200: Train Loss: 0.1183, Valid Loss: 0.1172


100%|██████████| 645/645 [00:45<00:00, 14.19it/s]


Epoch 9/200: Train Loss: 0.1179, Valid Loss: 0.1183


100%|██████████| 645/645 [00:45<00:00, 14.27it/s]


Epoch 10/200: Train Loss: 0.1180, Valid Loss: 0.1184


100%|██████████| 645/645 [00:45<00:00, 14.23it/s]


Epoch 11/200: Train Loss: 0.1175, Valid Loss: 0.1169


100%|██████████| 645/645 [00:45<00:00, 14.06it/s]


Epoch 12/200: Train Loss: 0.1178, Valid Loss: 0.1175


100%|██████████| 645/645 [00:45<00:00, 14.26it/s]


Epoch 13/200: Train Loss: 0.1172, Valid Loss: 0.1173


100%|██████████| 645/645 [00:44<00:00, 14.39it/s]


Epoch 14/200: Train Loss: 0.1170, Valid Loss: 0.1218


100%|██████████| 645/645 [00:45<00:00, 14.23it/s]


Epoch 15/200: Train Loss: 0.1169, Valid Loss: 0.1189


100%|██████████| 645/645 [00:45<00:00, 14.30it/s]


Epoch 16/200: Train Loss: 0.1165, Valid Loss: 0.1182


100%|██████████| 645/645 [00:44<00:00, 14.34it/s]


Epoch 17/200: Train Loss: 0.1163, Valid Loss: 0.1172


100%|██████████| 645/645 [00:44<00:00, 14.36it/s]


Epoch 18/200: Train Loss: 0.1160, Valid Loss: 0.1173


100%|██████████| 645/645 [00:44<00:00, 14.43it/s]


Epoch 19/200: Train Loss: 0.1160, Valid Loss: 0.1179


100%|██████████| 645/645 [00:44<00:00, 14.37it/s]


Epoch 20/200: Train Loss: 0.1159, Valid Loss: 0.1154


100%|██████████| 645/645 [00:44<00:00, 14.43it/s]


Epoch 21/200: Train Loss: 0.1153, Valid Loss: 0.1160


100%|██████████| 645/645 [00:44<00:00, 14.42it/s]


Epoch 22/200: Train Loss: 0.1150, Valid Loss: 0.1146


100%|██████████| 645/645 [00:44<00:00, 14.39it/s]


Epoch 23/200: Train Loss: 0.1140, Valid Loss: 0.1139


100%|██████████| 645/645 [00:44<00:00, 14.36it/s]


Epoch 24/200: Train Loss: 0.1126, Valid Loss: 0.1160


100%|██████████| 645/645 [00:44<00:00, 14.34it/s]


Epoch 25/200: Train Loss: 0.1118, Valid Loss: 0.1105


100%|██████████| 645/645 [00:44<00:00, 14.37it/s]


Epoch 26/200: Train Loss: 0.1107, Valid Loss: 0.1091


100%|██████████| 645/645 [00:44<00:00, 14.41it/s]


Epoch 27/200: Train Loss: 0.1085, Valid Loss: 0.1112


100%|██████████| 645/645 [00:44<00:00, 14.37it/s]


Epoch 28/200: Train Loss: 0.1070, Valid Loss: 0.1060


100%|██████████| 645/645 [00:44<00:00, 14.42it/s]


Epoch 29/200: Train Loss: 0.1061, Valid Loss: 0.1047


100%|██████████| 645/645 [00:44<00:00, 14.39it/s]


Epoch 30/200: Train Loss: 0.1050, Valid Loss: 0.1091


100%|██████████| 645/645 [00:44<00:00, 14.41it/s]


Epoch 31/200: Train Loss: 0.1048, Valid Loss: 0.1032


100%|██████████| 645/645 [00:45<00:00, 14.26it/s]


Epoch 32/200: Train Loss: 0.1034, Valid Loss: 0.1080


100%|██████████| 645/645 [00:45<00:00, 14.29it/s]


Epoch 33/200: Train Loss: 0.1023, Valid Loss: 0.1026


100%|██████████| 645/645 [00:45<00:00, 14.26it/s]


Epoch 34/200: Train Loss: 0.1015, Valid Loss: 0.1038


100%|██████████| 645/645 [00:45<00:00, 14.32it/s]


Epoch 35/200: Train Loss: 0.1007, Valid Loss: 0.1014


100%|██████████| 645/645 [00:45<00:00, 14.25it/s]


Epoch 36/200: Train Loss: 0.1002, Valid Loss: 0.1026


100%|██████████| 645/645 [00:45<00:00, 14.33it/s]


Epoch 37/200: Train Loss: 0.1003, Valid Loss: 0.1016


100%|██████████| 645/645 [00:44<00:00, 14.41it/s]


Epoch 38/200: Train Loss: 0.1001, Valid Loss: 0.1021


100%|██████████| 645/645 [00:44<00:00, 14.36it/s]


Epoch 39/200: Train Loss: 0.0991, Valid Loss: 0.1010


100%|██████████| 645/645 [00:44<00:00, 14.37it/s]


Epoch 40/200: Train Loss: 0.0992, Valid Loss: 0.1004


100%|██████████| 645/645 [00:44<00:00, 14.36it/s]


Epoch 41/200: Train Loss: 0.0986, Valid Loss: 0.1009


100%|██████████| 645/645 [00:44<00:00, 14.36it/s]


Epoch 42/200: Train Loss: 0.0979, Valid Loss: 0.1048


100%|██████████| 645/645 [00:44<00:00, 14.37it/s]


Epoch 43/200: Train Loss: 0.0983, Valid Loss: 0.1007


100%|██████████| 645/645 [00:44<00:00, 14.42it/s]


Epoch 44/200: Train Loss: 0.0979, Valid Loss: 0.0996


100%|██████████| 645/645 [00:44<00:00, 14.38it/s]


Epoch 45/200: Train Loss: 0.0980, Valid Loss: 0.0998


100%|██████████| 645/645 [00:44<00:00, 14.38it/s]


Epoch 46/200: Train Loss: 0.0975, Valid Loss: 0.1022


100%|██████████| 645/645 [00:44<00:00, 14.41it/s]


Epoch 47/200: Train Loss: 0.0975, Valid Loss: 0.1044


100%|██████████| 645/645 [00:44<00:00, 14.49it/s]


Epoch 48/200: Train Loss: 0.0966, Valid Loss: 0.0990


100%|██████████| 645/645 [00:44<00:00, 14.42it/s]


Epoch 49/200: Train Loss: 0.0969, Valid Loss: 0.0986


100%|██████████| 645/645 [00:44<00:00, 14.50it/s]


Epoch 50/200: Train Loss: 0.0964, Valid Loss: 0.0987


100%|██████████| 645/645 [00:44<00:00, 14.44it/s]


Epoch 51/200: Train Loss: 0.0960, Valid Loss: 0.0988


100%|██████████| 645/645 [00:44<00:00, 14.39it/s]


Epoch 52/200: Train Loss: 0.0961, Valid Loss: 0.1100


100%|██████████| 645/645 [00:45<00:00, 14.32it/s]


Epoch 53/200: Train Loss: 0.0962, Valid Loss: 0.0993


100%|██████████| 645/645 [00:45<00:00, 14.32it/s]


Epoch 54/200: Train Loss: 0.0952, Valid Loss: 0.0985


100%|██████████| 645/645 [00:45<00:00, 14.26it/s]


Epoch 55/200: Train Loss: 0.0957, Valid Loss: 0.0997


100%|██████████| 645/645 [00:45<00:00, 14.32it/s]


Epoch 56/200: Train Loss: 0.0957, Valid Loss: 0.0969


100%|██████████| 645/645 [00:45<00:00, 14.28it/s]


Epoch 57/200: Train Loss: 0.0948, Valid Loss: 0.1001


100%|██████████| 645/645 [00:45<00:00, 14.32it/s]


Epoch 58/200: Train Loss: 0.0948, Valid Loss: 0.0976


100%|██████████| 645/645 [00:44<00:00, 14.36it/s]


Epoch 59/200: Train Loss: 0.0946, Valid Loss: 0.1027


100%|██████████| 645/645 [00:44<00:00, 14.42it/s]


Epoch 60/200: Train Loss: 0.0952, Valid Loss: 0.1005


100%|██████████| 645/645 [00:44<00:00, 14.34it/s]


Epoch 61/200: Train Loss: 0.0941, Valid Loss: 0.0981


100%|██████████| 645/645 [00:45<00:00, 14.33it/s]


Epoch 62/200: Train Loss: 0.0947, Valid Loss: 0.0968


100%|██████████| 645/645 [00:44<00:00, 14.41it/s]


Epoch 63/200: Train Loss: 0.0945, Valid Loss: 0.0987


100%|██████████| 645/645 [00:44<00:00, 14.44it/s]


Epoch 64/200: Train Loss: 0.0943, Valid Loss: 0.0970


100%|██████████| 645/645 [00:44<00:00, 14.44it/s]


Epoch 65/200: Train Loss: 0.0941, Valid Loss: 0.0969


100%|██████████| 645/645 [00:44<00:00, 14.49it/s]


Epoch 66/200: Train Loss: 0.0940, Valid Loss: 0.0967


100%|██████████| 645/645 [00:44<00:00, 14.44it/s]


Epoch 67/200: Train Loss: 0.0942, Valid Loss: 0.0974


100%|██████████| 645/645 [00:44<00:00, 14.41it/s]


Epoch 68/200: Train Loss: 0.0938, Valid Loss: 0.0976


100%|██████████| 645/645 [00:44<00:00, 14.42it/s]


Epoch 69/200: Train Loss: 0.0939, Valid Loss: 0.1002


100%|██████████| 645/645 [00:44<00:00, 14.43it/s]


Epoch 70/200: Train Loss: 0.0930, Valid Loss: 0.0966


100%|██████████| 645/645 [00:44<00:00, 14.38it/s]


Epoch 71/200: Train Loss: 0.0930, Valid Loss: 0.0982


100%|██████████| 645/645 [00:44<00:00, 14.40it/s]


Epoch 72/200: Train Loss: 0.0939, Valid Loss: 0.1004


100%|██████████| 645/645 [00:44<00:00, 14.38it/s]


Epoch 73/200: Train Loss: 0.0929, Valid Loss: 0.0973


100%|██████████| 645/645 [00:44<00:00, 14.43it/s]


Epoch 74/200: Train Loss: 0.0930, Valid Loss: 0.0958


100%|██████████| 645/645 [00:44<00:00, 14.43it/s]


Epoch 75/200: Train Loss: 0.0926, Valid Loss: 0.0971


100%|██████████| 645/645 [00:44<00:00, 14.38it/s]


Epoch 76/200: Train Loss: 0.0924, Valid Loss: 0.0975


100%|██████████| 645/645 [00:44<00:00, 14.61it/s]


Epoch 77/200: Train Loss: 0.0925, Valid Loss: 0.0997


100%|██████████| 645/645 [00:44<00:00, 14.63it/s]


Epoch 78/200: Train Loss: 0.0931, Valid Loss: 0.1017


100%|██████████| 645/645 [00:43<00:00, 14.67it/s]


Epoch 79/200: Train Loss: 0.0924, Valid Loss: 0.0968


100%|██████████| 645/645 [00:44<00:00, 14.64it/s]


Epoch 80/200: Train Loss: 0.0930, Valid Loss: 0.0967


100%|██████████| 645/645 [00:44<00:00, 14.62it/s]


Epoch 81/200: Train Loss: 0.0930, Valid Loss: 0.0959


100%|██████████| 645/645 [00:43<00:00, 14.66it/s]


Epoch 82/200: Train Loss: 0.0927, Valid Loss: 0.0946


100%|██████████| 645/645 [00:44<00:00, 14.44it/s]


Epoch 83/200: Train Loss: 0.0927, Valid Loss: 0.0955


100%|██████████| 645/645 [00:44<00:00, 14.58it/s]


Epoch 84/200: Train Loss: 0.0924, Valid Loss: 0.0967


100%|██████████| 645/645 [00:44<00:00, 14.50it/s]


Epoch 85/200: Train Loss: 0.0924, Valid Loss: 0.0949


100%|██████████| 645/645 [00:44<00:00, 14.51it/s]


Epoch 86/200: Train Loss: 0.0921, Valid Loss: 0.0965


100%|██████████| 645/645 [00:44<00:00, 14.50it/s]


Epoch 87/200: Train Loss: 0.0920, Valid Loss: 0.0951


100%|██████████| 645/645 [00:44<00:00, 14.65it/s]


Epoch 88/200: Train Loss: 0.0924, Valid Loss: 0.0958


100%|██████████| 645/645 [00:43<00:00, 14.72it/s]


Epoch 89/200: Train Loss: 0.0917, Valid Loss: 0.0952


100%|██████████| 645/645 [00:44<00:00, 14.50it/s]


Epoch 90/200: Train Loss: 0.0917, Valid Loss: 0.0976


100%|██████████| 645/645 [00:44<00:00, 14.34it/s]


Epoch 91/200: Train Loss: 0.0914, Valid Loss: 0.0961


100%|██████████| 645/645 [00:44<00:00, 14.38it/s]


Epoch 92/200: Train Loss: 0.0921, Valid Loss: 0.0954


100%|██████████| 645/645 [00:44<00:00, 14.45it/s]


Epoch 93/200: Train Loss: 0.0910, Valid Loss: 0.0950


100%|██████████| 645/645 [00:44<00:00, 14.39it/s]


Epoch 94/200: Train Loss: 0.0906, Valid Loss: 0.0958


100%|██████████| 645/645 [00:44<00:00, 14.46it/s]


Epoch 95/200: Train Loss: 0.0910, Valid Loss: 0.0967


100%|██████████| 645/645 [00:44<00:00, 14.47it/s]


Epoch 96/200: Train Loss: 0.0911, Valid Loss: 0.0959


100%|██████████| 645/645 [00:44<00:00, 14.40it/s]


Epoch 97/200: Train Loss: 0.0911, Valid Loss: 0.0960


100%|██████████| 645/645 [00:44<00:00, 14.40it/s]


Epoch 98/200: Train Loss: 0.0903, Valid Loss: 0.0943


100%|██████████| 645/645 [00:44<00:00, 14.47it/s]


Epoch 99/200: Train Loss: 0.0908, Valid Loss: 0.0943


100%|██████████| 645/645 [00:44<00:00, 14.58it/s]


Epoch 100/200: Train Loss: 0.0908, Valid Loss: 0.0947


100%|██████████| 645/645 [00:44<00:00, 14.48it/s]


Epoch 101/200: Train Loss: 0.0911, Valid Loss: 0.0953


100%|██████████| 645/645 [00:44<00:00, 14.50it/s]


Epoch 102/200: Train Loss: 0.0910, Valid Loss: 0.0942


100%|██████████| 645/645 [00:44<00:00, 14.36it/s]


Epoch 103/200: Train Loss: 0.0907, Valid Loss: 0.0952


100%|██████████| 645/645 [00:44<00:00, 14.42it/s]


Epoch 104/200: Train Loss: 0.0905, Valid Loss: 0.0963


100%|██████████| 645/645 [00:45<00:00, 14.27it/s]


Epoch 105/200: Train Loss: 0.0903, Valid Loss: 0.0970


100%|██████████| 645/645 [00:44<00:00, 14.41it/s]


Epoch 106/200: Train Loss: 0.0900, Valid Loss: 0.0952


100%|██████████| 645/645 [00:44<00:00, 14.50it/s]


Epoch 107/200: Train Loss: 0.0899, Valid Loss: 0.0939


100%|██████████| 645/645 [00:44<00:00, 14.54it/s]


Epoch 108/200: Train Loss: 0.0897, Valid Loss: 0.0955


100%|██████████| 645/645 [00:44<00:00, 14.45it/s]


Epoch 109/200: Train Loss: 0.0902, Valid Loss: 0.0941


100%|██████████| 645/645 [00:44<00:00, 14.45it/s]


Epoch 110/200: Train Loss: 0.0891, Valid Loss: 0.0969


100%|██████████| 645/645 [00:44<00:00, 14.49it/s]


Epoch 111/200: Train Loss: 0.0899, Valid Loss: 0.0955


100%|██████████| 645/645 [00:44<00:00, 14.37it/s]


Epoch 112/200: Train Loss: 0.0899, Valid Loss: 0.0948


100%|██████████| 645/645 [00:44<00:00, 14.62it/s]


Epoch 113/200: Train Loss: 0.0896, Valid Loss: 0.0963


100%|██████████| 645/645 [00:44<00:00, 14.61it/s]


Epoch 114/200: Train Loss: 0.0895, Valid Loss: 0.0927


100%|██████████| 645/645 [00:44<00:00, 14.58it/s]


Epoch 115/200: Train Loss: 0.0891, Valid Loss: 0.0932


100%|██████████| 645/645 [00:44<00:00, 14.54it/s]


Epoch 116/200: Train Loss: 0.0894, Valid Loss: 0.0962


 48%|████▊     | 312/645 [00:21<00:22, 14.95it/s]

In [None]:
train()

## VISUAL EVALUATION

In [None]:
ys=[]
preds=[]
for batch in tqdm.tqdm(test_loader):
    input_drug=batch['input_drug'].to(device)
    input_target=batch['input_target'].to(device)
    y=batch['y'].to(device)
    outputs = model(input_drug,input_target).view(-1,)
    ys.extend(list(y.detach().cpu().numpy()))
    preds.extend(list(outputs.detach().cpu().numpy()))
temp_df=pd.DataFrame({'ys':ys,'preds':preds})
px.scatter(temp_df,x='ys',y='preds')

In [None]:
temp_df.corr()