In [1]:
import tqdm
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.figure_factory as ff

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForMaskedLM, AutoTokenizer
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

/kaggle/input/open-problems-single-cell-perturbations/multiome_train.parquet
/kaggle/input/open-problems-single-cell-perturbations/multiome_obs_meta.csv
/kaggle/input/open-problems-single-cell-perturbations/sample_submission.csv
/kaggle/input/open-problems-single-cell-perturbations/adata_train.parquet
/kaggle/input/open-problems-single-cell-perturbations/multiome_var_meta.csv
/kaggle/input/open-problems-single-cell-perturbations/adata_obs_meta.csv
/kaggle/input/open-problems-single-cell-perturbations/id_map.csv
/kaggle/input/open-problems-single-cell-perturbations/de_train.parquet
/kaggle/input/open-problems-single-cell-perturbations/adata_excluded_ids.csv


In [2]:
de_train = pd.read_parquet('../input/open-problems-single-cell-perturbations/de_train.parquet')
id_map = pd.read_csv('../input/open-problems-single-cell-perturbations/id_map.csv')
sample_submission = pd.read_csv('../input/open-problems-single-cell-perturbations/sample_submission.csv', index_col='id')

### Distributional Analysis

In [None]:
de_train.iloc[:,5:].describe()

## Data Creation


In [None]:
de_train=de_train.sample(frac=1)
thr=int(de_train.shape[0]*0.8)
train_data=de_train.iloc[:thr]
test_data=de_train.iloc[thr:]

In [None]:
# train_data.sm_name.unique()

In [None]:
chemberta = AutoModelForMaskedLM.from_pretrained("DeepChem/ChemBERTa-77M-MTR")
tokenizer = AutoTokenizer.from_pretrained("DeepChem/ChemBERTa-77M-MTR")
chemberta._modules["lm_head"] = nn.Identity()

In [None]:
def chemberta_features(smiles):
    encoded_input=tokenizer(smiles,return_tensors="pt", padding=False, truncation=True)
    model_output = chemberta(**encoded_input)
    return model_output.logits.mean(dim=1)

In [None]:
gene_keys=list(de_train.keys()[5:])
cell_gene_mean = train_data[['cell_type']+gene_keys].groupby('cell_type').mean().reset_index()
sm_gene_mean = train_data[['sm_name']+gene_keys].groupby('sm_name').mean().reset_index()
cell_gene_std = train_data[['cell_type']+gene_keys].groupby('cell_type').std().reset_index().replace({np.nan:0})
sm_gene_std = train_data[['sm_name']+gene_keys].groupby('sm_name').std().reset_index().replace({np.nan:0})

In [None]:
cell_type_len=len(train_data.cell_type.unique())
sm_name_len=len(train_data.sm_name.unique())

In [None]:
cell_type_dict={item:i for i,item in enumerate(train_data.cell_type.unique())}
sm_dict={item:i for i,item in enumerate(train_data.sm_name.unique())}

In [56]:
class Drug_Dataset(Dataset):
    def __init__(self, df,features,):
        self.df = df
        self.features=features
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row=self.df.iloc[idx]
        x=np.array([])
        y=torch.tensor(np.array(row[gene_keys]).astype(np.float32),dtype=torch.float32)
        
        for f in self.features:
            indicator=f.keys()[0]
            ind=row[indicator]
            extracted_f=f[f[indicator]==ind].iloc[0,list(range(1,f.shape[1]))]
            x=np.concatenate([x,np.array(extracted_f).astype(np.float32)],axis=0)
        x=np.concatenate([x,chemberta_features(row['SMILES']).detach().numpy()[0,:]],axis=0)
        x=torch.tensor(x,dtype=torch.float32)
        cell=torch.tensor(cell_type_dict[row['cell_type']])
        drug=torch.tensor(sm_dict[row['sm_name']])
        return {'x':x,'cell':cell,'drug':drug,'y':y}


In [57]:
train_p=Drug_Dataset(train_data,[cell_gene_mean,sm_gene_mean,cell_gene_std,sm_gene_std])
test_p=Drug_Dataset(test_data,[cell_gene_mean,sm_gene_mean,cell_gene_std,sm_gene_std])

In [59]:
train_loader=DataLoader(train_p,batch_size=32,shuffle=True)
test_loader=DataLoader(test_p,batch_size=8)

In [60]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [61]:
class Decoder(torch.nn.Module):
    def __init__(self, embed_dim=16,dim1=64,dim2=128):
        super(Decoder, self).__init__()
        self.cell_type_embeddings = nn.Embedding(cell_type_len, embed_dim)
        self.sm_name_embeddings = nn.Embedding(sm_name_len, embed_dim)
#         self.bn0=nn.BatchNorm1d(embed_dim*2+train_p[0]['x'].shape[0])
        self.dense1=nn.Linear(embed_dim*2+train_p[0]['x'].shape[0],dim1)
        self.bn1=nn.BatchNorm1d(dim1)
        self.dp1=nn.Dropout(0.2)
        self.dense2=nn.Linear(dim1,dim2)
        self.bn2=nn.BatchNorm1d(dim2)
        self.dp2=nn.Dropout(0.2)
        self.dense3=nn.Linear(dim2,18211)

    def forward(self, x_cell,x_sm,x):
        x1 = self.cell_type_embeddings(x_cell)
        x2 = self.sm_name_embeddings(x_sm)
        x=torch.concat([x1,x2,x],dim=-1)
#         x=self.bn0(x)
        x=self.dp1(F.relu((self.bn1(self.dense1(x)))))
        x=self.dp2(F.relu((self.bn2(self.dense2(x)))))
        x=self.dense3(x)
        return x


In [62]:
model=Decoder(embed_dim=16,dim1=64,dim2=128)
model.to(device)

Decoder(
  (cell_type_embeddings): Embedding(6, 16)
  (sm_name_embeddings): Embedding(146, 16)
  (dense1): Linear(in_features=73260, out_features=64, bias=True)
  (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dp1): Dropout(p=0.2, inplace=False)
  (dense2): Linear(in_features=64, out_features=128, bias=True)
  (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dp2): Dropout(p=0.2, inplace=False)
  (dense3): Linear(in_features=128, out_features=18211, bias=True)
)

In [63]:
def mrrmse(outputs,y):
    return (torch.sqrt(((outputs-y)**2).mean(dim=-1))).mean()

In [64]:
def train(num_epochs=1000):
    criterion= nn.MSELoss()
    criterion2= mrrmse
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    best_test_loss = float('inf')

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        train_loss2 = 0.0
        train_total=0.0
        test_loss=0.0
        test_loss2=0.0
        test_total=0.0
        

        for batch in tqdm.tqdm(train_loader):
            cell=batch['cell'].to(device)
            drug=batch['cell'].to(device)
            x=batch['x'].to(device)
            y=batch['y'].to(device)
            optimizer.zero_grad()
            outputs = model(cell,drug,x)
            loss = criterion(outputs,y)
            loss2 = criterion2(outputs,y)
            train_loss += loss.item()
            train_total+=outputs.size(0)
            train_loss2 += loss2.item()
            loss.backward()
            optimizer.step()
        avg_train_loss = train_loss / len(train_loader)
        avg_train_loss2 = train_loss2 / len(train_loader)


        model.eval()
        with torch.no_grad():
            for batch in test_loader:
                cell=batch['cell'].to(device)
                drug=batch['cell'].to(device)
                x=batch['x'].to(device)
                y=batch['y'].to(device)
                outputs = model(cell,drug,x)
                loss1 = criterion(outputs,y)
                loss2 = criterion2(outputs,y)
                test_total+=outputs.size(0)
                test_loss += loss.item()
                test_loss2 += loss2.item()
        avg_test_loss = test_loss / len(test_loader)
        avg_test_loss2 = test_loss2 / len(test_loader)

        if avg_test_loss < best_test_loss:
            best_test_loss = avg_test_loss
            torch.save(model.state_dict(), "best_model.pt")
        print(f"Epoch {epoch+1}/{num_epochs}: "
              f"Train MSE Loss: {avg_train_loss:.4f} , Train MRRMSE Loss: {avg_train_loss2:.4f}, Valid MSE Loss: {avg_test_loss:.4f}, Valid MRRMSE Loss: {avg_test_loss2:.4f}")

In [None]:
train()

100%|██████████| 16/16 [01:08<00:00,  4.31s/it]


Epoch 1/1000: Train MSE Loss: 5.4611 , Train MRRMSE Loss: 1.3457, Valid MSE Loss: 1.2459, Valid MRRMSE Loss: 1.2358


100%|██████████| 16/16 [01:09<00:00,  4.31s/it]


Epoch 2/1000: Train MSE Loss: 5.5421 , Train MRRMSE Loss: 1.3541, Valid MSE Loss: 5.5180, Valid MRRMSE Loss: 1.2240


100%|██████████| 16/16 [01:08<00:00,  4.31s/it]


Epoch 3/1000: Train MSE Loss: 5.8493 , Train MRRMSE Loss: 1.3621, Valid MSE Loss: 14.6035, Valid MRRMSE Loss: 1.2216


100%|██████████| 16/16 [01:09<00:00,  4.37s/it]


Epoch 4/1000: Train MSE Loss: 5.8354 , Train MRRMSE Loss: 1.3535, Valid MSE Loss: 16.7365, Valid MRRMSE Loss: 1.2132


100%|██████████| 16/16 [01:10<00:00,  4.38s/it]


Epoch 5/1000: Train MSE Loss: 5.3470 , Train MRRMSE Loss: 1.3140, Valid MSE Loss: 6.8601, Valid MRRMSE Loss: 1.2112


100%|██████████| 16/16 [01:09<00:00,  4.36s/it]


Epoch 6/1000: Train MSE Loss: 5.0441 , Train MRRMSE Loss: 1.2830, Valid MSE Loss: 1.4920, Valid MRRMSE Loss: 1.2054


100%|██████████| 16/16 [01:09<00:00,  4.37s/it]


Epoch 7/1000: Train MSE Loss: 4.9631 , Train MRRMSE Loss: 1.2701, Valid MSE Loss: 1.3266, Valid MRRMSE Loss: 1.2045


100%|██████████| 16/16 [01:09<00:00,  4.33s/it]


Epoch 8/1000: Train MSE Loss: 4.8074 , Train MRRMSE Loss: 1.2526, Valid MSE Loss: 0.9484, Valid MRRMSE Loss: 1.2017


100%|██████████| 16/16 [01:09<00:00,  4.34s/it]


Epoch 9/1000: Train MSE Loss: 4.7106 , Train MRRMSE Loss: 1.2440, Valid MSE Loss: 1.3301, Valid MRRMSE Loss: 1.1973


100%|██████████| 16/16 [01:09<00:00,  4.35s/it]


Epoch 10/1000: Train MSE Loss: 4.9376 , Train MRRMSE Loss: 1.2641, Valid MSE Loss: 7.5556, Valid MRRMSE Loss: 1.1961


100%|██████████| 16/16 [01:09<00:00,  4.36s/it]


Epoch 11/1000: Train MSE Loss: 4.5045 , Train MRRMSE Loss: 1.2226, Valid MSE Loss: 0.8257, Valid MRRMSE Loss: 1.1953


100%|██████████| 16/16 [01:09<00:00,  4.32s/it]


Epoch 12/1000: Train MSE Loss: 5.0784 , Train MRRMSE Loss: 1.2656, Valid MSE Loss: 16.4708, Valid MRRMSE Loss: 1.1891


100%|██████████| 16/16 [01:09<00:00,  4.33s/it]


Epoch 13/1000: Train MSE Loss: 4.4178 , Train MRRMSE Loss: 1.2185, Valid MSE Loss: 1.7133, Valid MRRMSE Loss: 1.1883


100%|██████████| 16/16 [01:09<00:00,  4.32s/it]


Epoch 14/1000: Train MSE Loss: 4.3468 , Train MRRMSE Loss: 1.2102, Valid MSE Loss: 0.5647, Valid MRRMSE Loss: 1.1952


100%|██████████| 16/16 [01:09<00:00,  4.34s/it]


Epoch 15/1000: Train MSE Loss: 4.5291 , Train MRRMSE Loss: 1.2322, Valid MSE Loss: 7.7375, Valid MRRMSE Loss: 1.1901


100%|██████████| 16/16 [01:09<00:00,  4.34s/it]


Epoch 16/1000: Train MSE Loss: 4.3974 , Train MRRMSE Loss: 1.2230, Valid MSE Loss: 8.1140, Valid MRRMSE Loss: 1.1940


100%|██████████| 16/16 [01:09<00:00,  4.31s/it]


Epoch 17/1000: Train MSE Loss: 4.1306 , Train MRRMSE Loss: 1.1911, Valid MSE Loss: 1.5196, Valid MRRMSE Loss: 1.2049


100%|██████████| 16/16 [01:08<00:00,  4.30s/it]


Epoch 18/1000: Train MSE Loss: 3.9560 , Train MRRMSE Loss: 1.1746, Valid MSE Loss: 0.7660, Valid MRRMSE Loss: 1.1999


100%|██████████| 16/16 [01:09<00:00,  4.35s/it]


Epoch 19/1000: Train MSE Loss: 4.1184 , Train MRRMSE Loss: 1.1993, Valid MSE Loss: 5.2893, Valid MRRMSE Loss: 1.1914


100%|██████████| 16/16 [01:09<00:00,  4.32s/it]


Epoch 20/1000: Train MSE Loss: 4.1928 , Train MRRMSE Loss: 1.2068, Valid MSE Loss: 4.5221, Valid MRRMSE Loss: 1.1978


100%|██████████| 16/16 [01:08<00:00,  4.30s/it]


Epoch 21/1000: Train MSE Loss: 4.1106 , Train MRRMSE Loss: 1.1939, Valid MSE Loss: 4.1201, Valid MRRMSE Loss: 1.2007


100%|██████████| 16/16 [01:08<00:00,  4.30s/it]


Epoch 22/1000: Train MSE Loss: 3.8828 , Train MRRMSE Loss: 1.1811, Valid MSE Loss: 2.5732, Valid MRRMSE Loss: 1.2111


100%|██████████| 16/16 [01:08<00:00,  4.31s/it]


Epoch 23/1000: Train MSE Loss: 3.9909 , Train MRRMSE Loss: 1.1801, Valid MSE Loss: 7.6411, Valid MRRMSE Loss: 1.2070


100%|██████████| 16/16 [01:09<00:00,  4.33s/it]


Epoch 24/1000: Train MSE Loss: 3.9029 , Train MRRMSE Loss: 1.1825, Valid MSE Loss: 4.2676, Valid MRRMSE Loss: 1.2107


100%|██████████| 16/16 [01:09<00:00,  4.34s/it]


Epoch 25/1000: Train MSE Loss: 3.7200 , Train MRRMSE Loss: 1.1602, Valid MSE Loss: 0.8388, Valid MRRMSE Loss: 1.2197


100%|██████████| 16/16 [01:09<00:00,  4.36s/it]


Epoch 26/1000: Train MSE Loss: 3.6407 , Train MRRMSE Loss: 1.1639, Valid MSE Loss: 1.2481, Valid MRRMSE Loss: 1.2245


100%|██████████| 16/16 [01:09<00:00,  4.33s/it]


Epoch 27/1000: Train MSE Loss: 3.6176 , Train MRRMSE Loss: 1.1546, Valid MSE Loss: 1.6393, Valid MRRMSE Loss: 1.2216


100%|██████████| 16/16 [01:09<00:00,  4.33s/it]


Epoch 28/1000: Train MSE Loss: 3.6868 , Train MRRMSE Loss: 1.1634, Valid MSE Loss: 2.9973, Valid MRRMSE Loss: 1.2169


100%|██████████| 16/16 [01:09<00:00,  4.32s/it]


Epoch 29/1000: Train MSE Loss: 3.6741 , Train MRRMSE Loss: 1.1661, Valid MSE Loss: 0.8942, Valid MRRMSE Loss: 1.2232


100%|██████████| 16/16 [01:09<00:00,  4.36s/it]


Epoch 30/1000: Train MSE Loss: 3.5735 , Train MRRMSE Loss: 1.1549, Valid MSE Loss: 1.2095, Valid MRRMSE Loss: 1.2189


100%|██████████| 16/16 [01:09<00:00,  4.35s/it]


Epoch 31/1000: Train MSE Loss: 3.6111 , Train MRRMSE Loss: 1.1733, Valid MSE Loss: 1.2994, Valid MRRMSE Loss: 1.2332


100%|██████████| 16/16 [01:09<00:00,  4.34s/it]


Epoch 32/1000: Train MSE Loss: 3.5665 , Train MRRMSE Loss: 1.1572, Valid MSE Loss: 4.3126, Valid MRRMSE Loss: 1.2316


100%|██████████| 16/16 [01:09<00:00,  4.34s/it]


Epoch 33/1000: Train MSE Loss: 3.5550 , Train MRRMSE Loss: 1.1697, Valid MSE Loss: 1.6954, Valid MRRMSE Loss: 1.2322


100%|██████████| 16/16 [01:09<00:00,  4.34s/it]


Epoch 34/1000: Train MSE Loss: 3.4735 , Train MRRMSE Loss: 1.1458, Valid MSE Loss: 0.6147, Valid MRRMSE Loss: 1.2322


100%|██████████| 16/16 [01:09<00:00,  4.35s/it]


Epoch 35/1000: Train MSE Loss: 3.3333 , Train MRRMSE Loss: 1.1266, Valid MSE Loss: 0.7201, Valid MRRMSE Loss: 1.2145


100%|██████████| 16/16 [01:09<00:00,  4.36s/it]


Epoch 36/1000: Train MSE Loss: 3.5452 , Train MRRMSE Loss: 1.1576, Valid MSE Loss: 4.3935, Valid MRRMSE Loss: 1.2200


100%|██████████| 16/16 [01:09<00:00,  4.37s/it]


Epoch 37/1000: Train MSE Loss: 3.3864 , Train MRRMSE Loss: 1.1522, Valid MSE Loss: 1.0912, Valid MRRMSE Loss: 1.2624


100%|██████████| 16/16 [01:09<00:00,  4.35s/it]


Epoch 38/1000: Train MSE Loss: 3.3514 , Train MRRMSE Loss: 1.1302, Valid MSE Loss: 0.7529, Valid MRRMSE Loss: 1.2342


100%|██████████| 16/16 [01:09<00:00,  4.37s/it]


Epoch 39/1000: Train MSE Loss: 3.2606 , Train MRRMSE Loss: 1.1166, Valid MSE Loss: 0.6880, Valid MRRMSE Loss: 1.2476


100%|██████████| 16/16 [01:10<00:00,  4.38s/it]


Epoch 40/1000: Train MSE Loss: 3.4789 , Train MRRMSE Loss: 1.1495, Valid MSE Loss: 6.4860, Valid MRRMSE Loss: 1.2358


100%|██████████| 16/16 [01:09<00:00,  4.36s/it]


Epoch 41/1000: Train MSE Loss: 3.5897 , Train MRRMSE Loss: 1.1791, Valid MSE Loss: 7.1397, Valid MRRMSE Loss: 1.2335


100%|██████████| 16/16 [01:09<00:00,  4.37s/it]


Epoch 42/1000: Train MSE Loss: 3.2643 , Train MRRMSE Loss: 1.1346, Valid MSE Loss: 0.9471, Valid MRRMSE Loss: 1.2405


100%|██████████| 16/16 [01:09<00:00,  4.36s/it]


Epoch 43/1000: Train MSE Loss: 3.3689 , Train MRRMSE Loss: 1.1413, Valid MSE Loss: 1.7061, Valid MRRMSE Loss: 1.2270


100%|██████████| 16/16 [01:09<00:00,  4.35s/it]


Epoch 44/1000: Train MSE Loss: 3.1663 , Train MRRMSE Loss: 1.1281, Valid MSE Loss: 1.4037, Valid MRRMSE Loss: 1.2435


100%|██████████| 16/16 [01:09<00:00,  4.35s/it]


Epoch 45/1000: Train MSE Loss: 3.3692 , Train MRRMSE Loss: 1.1578, Valid MSE Loss: 1.0458, Valid MRRMSE Loss: 1.2334


100%|██████████| 16/16 [01:09<00:00,  4.33s/it]


Epoch 46/1000: Train MSE Loss: 3.1668 , Train MRRMSE Loss: 1.1323, Valid MSE Loss: 0.9770, Valid MRRMSE Loss: 1.2372


100%|██████████| 16/16 [01:09<00:00,  4.34s/it]


Epoch 47/1000: Train MSE Loss: 3.4592 , Train MRRMSE Loss: 1.1440, Valid MSE Loss: 6.9670, Valid MRRMSE Loss: 1.2097


100%|██████████| 16/16 [01:09<00:00,  4.36s/it]


Epoch 48/1000: Train MSE Loss: 3.2808 , Train MRRMSE Loss: 1.1602, Valid MSE Loss: 0.9986, Valid MRRMSE Loss: 1.2412


100%|██████████| 16/16 [01:09<00:00,  4.35s/it]


Epoch 49/1000: Train MSE Loss: 3.1906 , Train MRRMSE Loss: 1.1331, Valid MSE Loss: 1.7239, Valid MRRMSE Loss: 1.2183


100%|██████████| 16/16 [01:10<00:00,  4.41s/it]


Epoch 50/1000: Train MSE Loss: 3.1530 , Train MRRMSE Loss: 1.1229, Valid MSE Loss: 0.8026, Valid MRRMSE Loss: 1.2370


100%|██████████| 16/16 [01:10<00:00,  4.39s/it]


Epoch 51/1000: Train MSE Loss: 3.1152 , Train MRRMSE Loss: 1.1197, Valid MSE Loss: 1.5949, Valid MRRMSE Loss: 1.2462


100%|██████████| 16/16 [01:13<00:00,  4.61s/it]


Epoch 52/1000: Train MSE Loss: 3.0965 , Train MRRMSE Loss: 1.1282, Valid MSE Loss: 0.9262, Valid MRRMSE Loss: 1.2365


100%|██████████| 16/16 [01:12<00:00,  4.56s/it]


Epoch 53/1000: Train MSE Loss: 2.9887 , Train MRRMSE Loss: 1.1097, Valid MSE Loss: 4.5675, Valid MRRMSE Loss: 1.2343


100%|██████████| 16/16 [01:12<00:00,  4.55s/it]


Epoch 54/1000: Train MSE Loss: 3.1556 , Train MRRMSE Loss: 1.1366, Valid MSE Loss: 4.0297, Valid MRRMSE Loss: 1.2153


100%|██████████| 16/16 [01:13<00:00,  4.56s/it]


Epoch 55/1000: Train MSE Loss: 2.9508 , Train MRRMSE Loss: 1.1158, Valid MSE Loss: 0.7552, Valid MRRMSE Loss: 1.2456


100%|██████████| 16/16 [01:12<00:00,  4.56s/it]


Epoch 56/1000: Train MSE Loss: 3.1755 , Train MRRMSE Loss: 1.1379, Valid MSE Loss: 1.0923, Valid MRRMSE Loss: 1.2165


100%|██████████| 16/16 [01:14<00:00,  4.64s/it]


Epoch 57/1000: Train MSE Loss: 2.9232 , Train MRRMSE Loss: 1.1091, Valid MSE Loss: 1.4928, Valid MRRMSE Loss: 1.2233


100%|██████████| 16/16 [01:14<00:00,  4.64s/it]


Epoch 58/1000: Train MSE Loss: 3.1800 , Train MRRMSE Loss: 1.1378, Valid MSE Loss: 0.8621, Valid MRRMSE Loss: 1.2260


100%|██████████| 16/16 [01:14<00:00,  4.67s/it]


Epoch 59/1000: Train MSE Loss: 2.8288 , Train MRRMSE Loss: 1.0840, Valid MSE Loss: 0.7735, Valid MRRMSE Loss: 1.2218


100%|██████████| 16/16 [01:14<00:00,  4.65s/it]


Epoch 60/1000: Train MSE Loss: 3.1180 , Train MRRMSE Loss: 1.1271, Valid MSE Loss: 2.1613, Valid MRRMSE Loss: 1.2396


100%|██████████| 16/16 [01:12<00:00,  4.52s/it]


Epoch 61/1000: Train MSE Loss: 3.0007 , Train MRRMSE Loss: 1.1233, Valid MSE Loss: 1.0067, Valid MRRMSE Loss: 1.2470


100%|██████████| 16/16 [01:11<00:00,  4.46s/it]


Epoch 62/1000: Train MSE Loss: 3.2438 , Train MRRMSE Loss: 1.1554, Valid MSE Loss: 1.5013, Valid MRRMSE Loss: 1.2179


100%|██████████| 16/16 [01:11<00:00,  4.46s/it]


Epoch 63/1000: Train MSE Loss: 3.0889 , Train MRRMSE Loss: 1.1299, Valid MSE Loss: 2.7326, Valid MRRMSE Loss: 1.2241


100%|██████████| 16/16 [01:13<00:00,  4.57s/it]


Epoch 64/1000: Train MSE Loss: 2.9065 , Train MRRMSE Loss: 1.1172, Valid MSE Loss: 0.9356, Valid MRRMSE Loss: 1.2480


100%|██████████| 16/16 [01:12<00:00,  4.51s/it]


Epoch 65/1000: Train MSE Loss: 2.8676 , Train MRRMSE Loss: 1.1145, Valid MSE Loss: 0.7612, Valid MRRMSE Loss: 1.2313


100%|██████████| 16/16 [01:10<00:00,  4.43s/it]


Epoch 66/1000: Train MSE Loss: 3.0749 , Train MRRMSE Loss: 1.1247, Valid MSE Loss: 2.0842, Valid MRRMSE Loss: 1.2429


100%|██████████| 16/16 [01:09<00:00,  4.36s/it]


Epoch 67/1000: Train MSE Loss: 2.9859 , Train MRRMSE Loss: 1.1297, Valid MSE Loss: 1.2251, Valid MRRMSE Loss: 1.2274


100%|██████████| 16/16 [01:09<00:00,  4.37s/it]


Epoch 68/1000: Train MSE Loss: 2.9477 , Train MRRMSE Loss: 1.1303, Valid MSE Loss: 2.0022, Valid MRRMSE Loss: 1.2268


100%|██████████| 16/16 [01:10<00:00,  4.40s/it]


Epoch 69/1000: Train MSE Loss: 2.8465 , Train MRRMSE Loss: 1.1000, Valid MSE Loss: 0.8576, Valid MRRMSE Loss: 1.2103


100%|██████████| 16/16 [01:10<00:00,  4.39s/it]


Epoch 70/1000: Train MSE Loss: 3.6091 , Train MRRMSE Loss: 1.1709, Valid MSE Loss: 14.5543, Valid MRRMSE Loss: 1.2112


100%|██████████| 16/16 [01:10<00:00,  4.40s/it]


Epoch 71/1000: Train MSE Loss: 2.9275 , Train MRRMSE Loss: 1.1192, Valid MSE Loss: 1.3324, Valid MRRMSE Loss: 1.2108


100%|██████████| 16/16 [01:10<00:00,  4.41s/it]


Epoch 72/1000: Train MSE Loss: 3.0551 , Train MRRMSE Loss: 1.1379, Valid MSE Loss: 2.2633, Valid MRRMSE Loss: 1.2101


100%|██████████| 16/16 [01:10<00:00,  4.42s/it]


Epoch 73/1000: Train MSE Loss: 2.7814 , Train MRRMSE Loss: 1.1029, Valid MSE Loss: 1.2227, Valid MRRMSE Loss: 1.2456


100%|██████████| 16/16 [01:10<00:00,  4.42s/it]


Epoch 74/1000: Train MSE Loss: 2.8024 , Train MRRMSE Loss: 1.1011, Valid MSE Loss: 1.4009, Valid MRRMSE Loss: 1.2378


100%|██████████| 16/16 [01:10<00:00,  4.39s/it]


Epoch 75/1000: Train MSE Loss: 2.8398 , Train MRRMSE Loss: 1.0953, Valid MSE Loss: 0.8995, Valid MRRMSE Loss: 1.2384


100%|██████████| 16/16 [01:10<00:00,  4.41s/it]


Epoch 76/1000: Train MSE Loss: 3.0222 , Train MRRMSE Loss: 1.1445, Valid MSE Loss: 1.9371, Valid MRRMSE Loss: 1.2266


100%|██████████| 16/16 [01:10<00:00,  4.40s/it]


Epoch 77/1000: Train MSE Loss: 2.7687 , Train MRRMSE Loss: 1.0969, Valid MSE Loss: 0.9688, Valid MRRMSE Loss: 1.2375


100%|██████████| 16/16 [01:10<00:00,  4.39s/it]


Epoch 78/1000: Train MSE Loss: 3.0254 , Train MRRMSE Loss: 1.1257, Valid MSE Loss: 4.8833, Valid MRRMSE Loss: 1.2208


100%|██████████| 16/16 [01:10<00:00,  4.43s/it]


Epoch 79/1000: Train MSE Loss: 2.7618 , Train MRRMSE Loss: 1.1030, Valid MSE Loss: 1.1182, Valid MRRMSE Loss: 1.2429


100%|██████████| 16/16 [01:10<00:00,  4.38s/it]


Epoch 80/1000: Train MSE Loss: 2.8393 , Train MRRMSE Loss: 1.1193, Valid MSE Loss: 1.5791, Valid MRRMSE Loss: 1.2201


100%|██████████| 16/16 [01:10<00:00,  4.40s/it]


Epoch 81/1000: Train MSE Loss: 2.8723 , Train MRRMSE Loss: 1.1167, Valid MSE Loss: 0.9071, Valid MRRMSE Loss: 1.2260


100%|██████████| 16/16 [01:10<00:00,  4.39s/it]


Epoch 82/1000: Train MSE Loss: 2.8545 , Train MRRMSE Loss: 1.1280, Valid MSE Loss: 2.2946, Valid MRRMSE Loss: 1.2205


100%|██████████| 16/16 [01:10<00:00,  4.41s/it]


Epoch 83/1000: Train MSE Loss: 2.8639 , Train MRRMSE Loss: 1.1166, Valid MSE Loss: 3.9665, Valid MRRMSE Loss: 1.2354


100%|██████████| 16/16 [01:09<00:00,  4.36s/it]


Epoch 84/1000: Train MSE Loss: 3.7386 , Train MRRMSE Loss: 1.1743, Valid MSE Loss: 22.2121, Valid MRRMSE Loss: 1.2156


100%|██████████| 16/16 [01:10<00:00,  4.40s/it]


Epoch 85/1000: Train MSE Loss: 2.8933 , Train MRRMSE Loss: 1.1035, Valid MSE Loss: 5.5330, Valid MRRMSE Loss: 1.2110


100%|██████████| 16/16 [01:10<00:00,  4.39s/it]


Epoch 86/1000: Train MSE Loss: 3.6567 , Train MRRMSE Loss: 1.1589, Valid MSE Loss: 21.2227, Valid MRRMSE Loss: 1.2050


100%|██████████| 16/16 [01:10<00:00,  4.40s/it]


Epoch 87/1000: Train MSE Loss: 2.7999 , Train MRRMSE Loss: 1.1046, Valid MSE Loss: 2.9250, Valid MRRMSE Loss: 1.2069


100%|██████████| 16/16 [01:10<00:00,  4.41s/it]


Epoch 88/1000: Train MSE Loss: 3.2722 , Train MRRMSE Loss: 1.1487, Valid MSE Loss: 12.1039, Valid MRRMSE Loss: 1.2011


100%|██████████| 16/16 [01:09<00:00,  4.37s/it]


Epoch 89/1000: Train MSE Loss: 2.7789 , Train MRRMSE Loss: 1.1079, Valid MSE Loss: 3.1198, Valid MRRMSE Loss: 1.2316


100%|██████████| 16/16 [01:09<00:00,  4.37s/it]


Epoch 90/1000: Train MSE Loss: 2.9670 , Train MRRMSE Loss: 1.1393, Valid MSE Loss: 1.9652, Valid MRRMSE Loss: 1.2212


100%|██████████| 16/16 [01:09<00:00,  4.36s/it]


Epoch 91/1000: Train MSE Loss: 2.8065 , Train MRRMSE Loss: 1.1221, Valid MSE Loss: 0.8201, Valid MRRMSE Loss: 1.2172


100%|██████████| 16/16 [01:10<00:00,  4.38s/it]


Epoch 92/1000: Train MSE Loss: 2.6301 , Train MRRMSE Loss: 1.0858, Valid MSE Loss: 1.0521, Valid MRRMSE Loss: 1.2656


100%|██████████| 16/16 [01:09<00:00,  4.35s/it]


Epoch 93/1000: Train MSE Loss: 2.8188 , Train MRRMSE Loss: 1.1114, Valid MSE Loss: 5.3731, Valid MRRMSE Loss: 1.2129


100%|██████████| 16/16 [01:09<00:00,  4.35s/it]


Epoch 94/1000: Train MSE Loss: 2.7260 , Train MRRMSE Loss: 1.0939, Valid MSE Loss: 3.7269, Valid MRRMSE Loss: 1.2400


100%|██████████| 16/16 [01:09<00:00,  4.36s/it]


Epoch 95/1000: Train MSE Loss: 2.6215 , Train MRRMSE Loss: 1.0957, Valid MSE Loss: 2.4353, Valid MRRMSE Loss: 1.2266


100%|██████████| 16/16 [01:09<00:00,  4.36s/it]


Epoch 96/1000: Train MSE Loss: 2.7065 , Train MRRMSE Loss: 1.1000, Valid MSE Loss: 3.0114, Valid MRRMSE Loss: 1.2183


100%|██████████| 16/16 [01:10<00:00,  4.39s/it]


Epoch 97/1000: Train MSE Loss: 2.9457 , Train MRRMSE Loss: 1.1257, Valid MSE Loss: 8.3505, Valid MRRMSE Loss: 1.2070


100%|██████████| 16/16 [01:10<00:00,  4.39s/it]


Epoch 98/1000: Train MSE Loss: 2.6486 , Train MRRMSE Loss: 1.0898, Valid MSE Loss: 2.8304, Valid MRRMSE Loss: 1.2217


100%|██████████| 16/16 [01:10<00:00,  4.39s/it]


Epoch 99/1000: Train MSE Loss: 2.7918 , Train MRRMSE Loss: 1.1231, Valid MSE Loss: 0.7994, Valid MRRMSE Loss: 1.2191


100%|██████████| 16/16 [01:09<00:00,  4.36s/it]


Epoch 100/1000: Train MSE Loss: 2.5500 , Train MRRMSE Loss: 1.0730, Valid MSE Loss: 2.1943, Valid MRRMSE Loss: 1.2227


100%|██████████| 16/16 [01:09<00:00,  4.37s/it]


Epoch 101/1000: Train MSE Loss: 2.7516 , Train MRRMSE Loss: 1.1229, Valid MSE Loss: 1.6438, Valid MRRMSE Loss: 1.2344


100%|██████████| 16/16 [01:10<00:00,  4.38s/it]


Epoch 102/1000: Train MSE Loss: 2.6374 , Train MRRMSE Loss: 1.0939, Valid MSE Loss: 1.1854, Valid MRRMSE Loss: 1.2313


100%|██████████| 16/16 [01:09<00:00,  4.35s/it]


Epoch 103/1000: Train MSE Loss: 3.0582 , Train MRRMSE Loss: 1.1602, Valid MSE Loss: 6.9971, Valid MRRMSE Loss: 1.2139


100%|██████████| 16/16 [01:10<00:00,  4.40s/it]


Epoch 104/1000: Train MSE Loss: 2.8708 , Train MRRMSE Loss: 1.1293, Valid MSE Loss: 6.8303, Valid MRRMSE Loss: 1.2060


100%|██████████| 16/16 [01:09<00:00,  4.36s/it]


Epoch 105/1000: Train MSE Loss: 2.5480 , Train MRRMSE Loss: 1.0844, Valid MSE Loss: 0.9699, Valid MRRMSE Loss: 1.2181


100%|██████████| 16/16 [01:09<00:00,  4.35s/it]


Epoch 106/1000: Train MSE Loss: 2.7224 , Train MRRMSE Loss: 1.0916, Valid MSE Loss: 4.8705, Valid MRRMSE Loss: 1.1997


100%|██████████| 16/16 [01:10<00:00,  4.41s/it]


Epoch 107/1000: Train MSE Loss: 2.7423 , Train MRRMSE Loss: 1.1024, Valid MSE Loss: 4.1968, Valid MRRMSE Loss: 1.2074


100%|██████████| 16/16 [01:10<00:00,  4.39s/it]


Epoch 108/1000: Train MSE Loss: 2.6205 , Train MRRMSE Loss: 1.0979, Valid MSE Loss: 0.9382, Valid MRRMSE Loss: 1.2443


100%|██████████| 16/16 [01:09<00:00,  4.36s/it]


Epoch 109/1000: Train MSE Loss: 2.8130 , Train MRRMSE Loss: 1.1051, Valid MSE Loss: 5.5822, Valid MRRMSE Loss: 1.2060


100%|██████████| 16/16 [01:09<00:00,  4.37s/it]


Epoch 110/1000: Train MSE Loss: 2.7351 , Train MRRMSE Loss: 1.0966, Valid MSE Loss: 8.3023, Valid MRRMSE Loss: 1.2124


100%|██████████| 16/16 [01:09<00:00,  4.37s/it]


Epoch 111/1000: Train MSE Loss: 2.6811 , Train MRRMSE Loss: 1.1052, Valid MSE Loss: 1.9757, Valid MRRMSE Loss: 1.2305


100%|██████████| 16/16 [01:10<00:00,  4.40s/it]


Epoch 112/1000: Train MSE Loss: 2.7015 , Train MRRMSE Loss: 1.1356, Valid MSE Loss: 3.3634, Valid MRRMSE Loss: 1.2204


100%|██████████| 16/16 [01:10<00:00,  4.38s/it]


Epoch 113/1000: Train MSE Loss: 2.6871 , Train MRRMSE Loss: 1.1108, Valid MSE Loss: 5.4284, Valid MRRMSE Loss: 1.2181


100%|██████████| 16/16 [01:10<00:00,  4.41s/it]


Epoch 114/1000: Train MSE Loss: 2.6157 , Train MRRMSE Loss: 1.1043, Valid MSE Loss: 0.9830, Valid MRRMSE Loss: 1.2234


100%|██████████| 16/16 [01:10<00:00,  4.40s/it]


Epoch 115/1000: Train MSE Loss: 2.4848 , Train MRRMSE Loss: 1.0783, Valid MSE Loss: 1.5531, Valid MRRMSE Loss: 1.2346


100%|██████████| 16/16 [01:10<00:00,  4.38s/it]


Epoch 116/1000: Train MSE Loss: 2.5733 , Train MRRMSE Loss: 1.0912, Valid MSE Loss: 5.1050, Valid MRRMSE Loss: 1.2313


100%|██████████| 16/16 [01:10<00:00,  4.39s/it]


Epoch 117/1000: Train MSE Loss: 2.4146 , Train MRRMSE Loss: 1.0700, Valid MSE Loss: 1.1104, Valid MRRMSE Loss: 1.2214


100%|██████████| 16/16 [01:10<00:00,  4.40s/it]


Epoch 118/1000: Train MSE Loss: 2.6178 , Train MRRMSE Loss: 1.1123, Valid MSE Loss: 0.9858, Valid MRRMSE Loss: 1.2502


100%|██████████| 16/16 [01:10<00:00,  4.42s/it]


Epoch 119/1000: Train MSE Loss: 2.4369 , Train MRRMSE Loss: 1.0859, Valid MSE Loss: 1.9994, Valid MRRMSE Loss: 1.2185


100%|██████████| 16/16 [01:10<00:00,  4.38s/it]


Epoch 120/1000: Train MSE Loss: 2.5373 , Train MRRMSE Loss: 1.0919, Valid MSE Loss: 1.0086, Valid MRRMSE Loss: 1.2408


100%|██████████| 16/16 [01:09<00:00,  4.37s/it]


Epoch 121/1000: Train MSE Loss: 2.4964 , Train MRRMSE Loss: 1.0745, Valid MSE Loss: 2.9217, Valid MRRMSE Loss: 1.2151


100%|██████████| 16/16 [01:10<00:00,  4.38s/it]


Epoch 122/1000: Train MSE Loss: 2.8590 , Train MRRMSE Loss: 1.1389, Valid MSE Loss: 5.2194, Valid MRRMSE Loss: 1.2253


100%|██████████| 16/16 [01:09<00:00,  4.36s/it]


Epoch 123/1000: Train MSE Loss: 2.4248 , Train MRRMSE Loss: 1.0765, Valid MSE Loss: 1.0043, Valid MRRMSE Loss: 1.2191


100%|██████████| 16/16 [01:10<00:00,  4.38s/it]


Epoch 124/1000: Train MSE Loss: 2.5577 , Train MRRMSE Loss: 1.0800, Valid MSE Loss: 1.6461, Valid MRRMSE Loss: 1.2294


100%|██████████| 16/16 [01:10<00:00,  4.40s/it]


Epoch 125/1000: Train MSE Loss: 2.4672 , Train MRRMSE Loss: 1.0776, Valid MSE Loss: 1.7760, Valid MRRMSE Loss: 1.2443


100%|██████████| 16/16 [01:10<00:00,  4.40s/it]


Epoch 126/1000: Train MSE Loss: 2.5605 , Train MRRMSE Loss: 1.0779, Valid MSE Loss: 2.1801, Valid MRRMSE Loss: 1.2506


100%|██████████| 16/16 [01:09<00:00,  4.37s/it]


Epoch 127/1000: Train MSE Loss: 2.5346 , Train MRRMSE Loss: 1.0941, Valid MSE Loss: 1.1311, Valid MRRMSE Loss: 1.2365


100%|██████████| 16/16 [01:09<00:00,  4.37s/it]


Epoch 128/1000: Train MSE Loss: 2.5082 , Train MRRMSE Loss: 1.0823, Valid MSE Loss: 5.6784, Valid MRRMSE Loss: 1.2189


100%|██████████| 16/16 [01:09<00:00,  4.36s/it]


Epoch 129/1000: Train MSE Loss: 2.3957 , Train MRRMSE Loss: 1.0708, Valid MSE Loss: 1.9792, Valid MRRMSE Loss: 1.2257


100%|██████████| 16/16 [01:10<00:00,  4.40s/it]


Epoch 130/1000: Train MSE Loss: 2.3308 , Train MRRMSE Loss: 1.0485, Valid MSE Loss: 0.9394, Valid MRRMSE Loss: 1.2269


100%|██████████| 16/16 [01:09<00:00,  4.36s/it]


Epoch 131/1000: Train MSE Loss: 2.4410 , Train MRRMSE Loss: 1.0696, Valid MSE Loss: 2.9555, Valid MRRMSE Loss: 1.2070


100%|██████████| 16/16 [01:09<00:00,  4.34s/it]


Epoch 132/1000: Train MSE Loss: 2.4329 , Train MRRMSE Loss: 1.1011, Valid MSE Loss: 0.9031, Valid MRRMSE Loss: 1.2195


100%|██████████| 16/16 [01:09<00:00,  4.35s/it]


Epoch 133/1000: Train MSE Loss: 2.4165 , Train MRRMSE Loss: 1.0792, Valid MSE Loss: 1.2237, Valid MRRMSE Loss: 1.2332


100%|██████████| 16/16 [01:09<00:00,  4.34s/it]


Epoch 134/1000: Train MSE Loss: 2.4808 , Train MRRMSE Loss: 1.0798, Valid MSE Loss: 1.6802, Valid MRRMSE Loss: 1.2352


100%|██████████| 16/16 [01:09<00:00,  4.33s/it]


Epoch 135/1000: Train MSE Loss: 2.4848 , Train MRRMSE Loss: 1.0914, Valid MSE Loss: 1.0154, Valid MRRMSE Loss: 1.2304


100%|██████████| 16/16 [01:09<00:00,  4.34s/it]


Epoch 136/1000: Train MSE Loss: 2.4062 , Train MRRMSE Loss: 1.0671, Valid MSE Loss: 0.9963, Valid MRRMSE Loss: 1.2116


100%|██████████| 16/16 [01:09<00:00,  4.36s/it]


Epoch 137/1000: Train MSE Loss: 2.5384 , Train MRRMSE Loss: 1.0988, Valid MSE Loss: 1.4527, Valid MRRMSE Loss: 1.2333


100%|██████████| 16/16 [01:09<00:00,  4.34s/it]


Epoch 138/1000: Train MSE Loss: 2.4981 , Train MRRMSE Loss: 1.0787, Valid MSE Loss: 1.8847, Valid MRRMSE Loss: 1.2294


100%|██████████| 16/16 [01:09<00:00,  4.34s/it]


Epoch 139/1000: Train MSE Loss: 2.3090 , Train MRRMSE Loss: 1.0540, Valid MSE Loss: 0.7441, Valid MRRMSE Loss: 1.2321


100%|██████████| 16/16 [01:09<00:00,  4.35s/it]


Epoch 140/1000: Train MSE Loss: 2.4939 , Train MRRMSE Loss: 1.0886, Valid MSE Loss: 0.9581, Valid MRRMSE Loss: 1.2264


100%|██████████| 16/16 [01:09<00:00,  4.34s/it]


Epoch 141/1000: Train MSE Loss: 2.4063 , Train MRRMSE Loss: 1.0686, Valid MSE Loss: 1.7909, Valid MRRMSE Loss: 1.2446


 62%|██████▎   | 10/16 [00:45<00:27,  4.51s/it]

In [55]:
for batch in tqdm.tqdm(train_loader):
    cell=batch['cell'].to(device)
    drug=batch['cell'].to(device)
    y=batch['y'].to(device)
    outputs = model(cell,drug)
    loss = nn.L1Loss()(outputs,y)
    break

  0%|          | 0/16 [00:02<?, ?it/s]


In [56]:
loss

tensor(1.1626, grad_fn=<MeanBackward0>)

In [57]:
y

tensor([[-0.7790, -0.3305, -1.1666,  ...,  0.7030,  0.8946,  0.2529],
        [ 0.4641,  0.1390,  0.6811,  ...,  0.5811,  0.4223, -0.7910],
        [ 0.0288, -0.7789, -3.6663,  ...,  1.1905, -3.5731,  0.8906],
        ...,
        [ 0.1066,  0.0069,  0.4245,  ..., -0.6688,  0.7983, -0.0600],
        [ 0.6422, -0.4220,  0.4593,  ...,  0.0229, -0.0209,  0.4794],
        [ 2.2251,  0.5643, -0.4376,  ...,  3.3962, -0.5671,  0.5014]])

In [58]:
outputs

tensor([[ 6.2332e-02, -2.1571e-01,  2.4501e-02,  ...,  6.1414e-02,
          8.4924e-02, -1.6814e-01],
        [ 3.8393e-01, -4.5674e-01, -5.3927e-01,  ..., -4.0812e-02,
         -2.8494e-01,  1.1883e-03],
        [ 7.3435e-02, -1.4179e-01, -2.9638e-01,  ...,  3.4764e-01,
         -1.1420e-01, -1.3031e-01],
        ...,
        [ 4.1656e-02,  4.7643e-01, -4.7299e-01,  ..., -6.7893e-01,
         -1.8814e-01, -1.7352e-01],
        [-3.7704e-01, -1.2595e-02, -1.6405e-01,  ..., -1.5800e-01,
         -2.3328e-01, -1.0147e-01],
        [ 1.2016e+00,  1.8736e-01,  4.0568e-01,  ...,  5.7204e-01,
         -4.1059e-01,  8.9473e-01]], grad_fn=<AddmmBackward0>)