In [1]:
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--device', type=str, default='cuda:0')
# gpustat -cuFi 1
parser.add_argument('--seed', type=int, default=42)

# learning params
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--hdim', type=float, default=64)
parser.add_argument('--batch_size', type=float, default=128)

args = parser.parse_args([])

In [2]:
import pandas as pd
import numpy as np
import sys

import torch
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
sys.path.append('../')
from utils_dm import EarlyStopper, set_seed

In [4]:
model_name = f'Solubility_ECFP_MLP_h{args.hdim}b{args.batch_size}_lr{args.lr}'

In [5]:
set_seed(args.seed)

random seed with 42


In [6]:
dataset = 'Solubility_AqSolDB'

traindf = pd.read_csv(f'../../../2023-2/processed_data/ECFP/{dataset}_train_ECFP_R2B1024.csv', index_col=0)
validdf = pd.read_csv(f'../../../2023-2/processed_data/ECFP/{dataset}_valid_ECFP_R2B1024.csv', index_col=0)
testdf = pd.read_csv(f'../../../2023-2/processed_data/ECFP/{dataset}_test_ECFP_R2B1024.csv', index_col=0)

In [7]:
traindf

Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,...,1015,1016,1017,1018,1019,1020,1021,1022,1023,label
Drug_ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Benzo[cd]indol-2(1H)-one,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,-3.254767
4-chlorobenzaldehyde,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,-2.177078
"4-({4-[bis(oxiran-2-ylmethyl)amino]phenyl}methyl)-N,N-bis(oxiran-2-ylmethyl)aniline",0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,1,0,0,0,0,-4.662065
vinyltoluene,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,-3.123150
3-(3-ethylcyclopentyl)propanoic acid,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,1,0,0,0,0,-3.286116
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
sarafloxacin,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,-3.130000
sparfloxacin,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,1,0,0,0,0,-3.370000
sulindac_form_II,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,-4.500000
tetracaine,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,-3.010000


In [8]:
from torch.utils.data import Dataset, DataLoader

In [9]:
class MyDataset(Dataset):
    def __init__(self, dataset, labels):
        self.dataset = torch.tensor(dataset).float()
        self.labels = torch.tensor(labels, dtype=torch.float32)
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx], self.labels[idx]

In [10]:
train_data, train_labels = traindf.values[:, :-1],  traindf.values[:, -1]
valid_data, valid_labels = validdf.values[:, :-1],  validdf.values[:, -1]
test_data, test_labels = testdf.values[:, :-1],  testdf.values[:, -1]
train_labels.shape, (train_labels==1).sum() # check number of positives

((6988,), 0)

In [11]:
trainset = MyDataset(train_data, train_labels)
validset = MyDataset(valid_data, valid_labels)
testset = MyDataset(test_data, test_labels)

In [12]:
# build dataloader
trainloader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=False,
                        generator=torch.Generator().manual_seed(42))
validloader = DataLoader(validset, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=False)
testloader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=False)

# Build model

In [13]:
import torch
import torch.nn as nn

class NeuralNetwork(nn.Module):
    def __init__(self, in_dim, hdim, out_dim=1, dropout=0.1):
        super(NeuralNetwork, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(in_dim,hdim),
            nn.LayerNorm(hdim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hdim,hdim)
        )
        self.prediction_head = nn.Sequential(
            nn.LayerNorm(hdim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hdim,out_dim)
        )

        
    def forward(self, x):
        x = self.encoder(x)
        x = self.prediction_head(x)
        return x

In [14]:
in_dim = trainset[0][0].shape[0]
model = NeuralNetwork(in_dim, args.hdim).to(args.device)

# Train

In [15]:
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

# criterion = nn.BCEWithLogitsLoss()
criterion = nn.MSELoss()

# binary cross entropy nn.BCEWithLogitsLoss()
# Mean squared error  nn.MSELoss()

early_stopper = EarlyStopper(patience=20,printfunc=print,verbose=True,path=f'ckpts/{model_name}.pt')

In [16]:
def train(model, trainloader, args, optimizer=optimizer, criterion=criterion):
    model.train()
    train_loss = 0
    for batch, label in trainloader:
        batch = batch.to(args.device)
        label = label.to(args.device)

        optimizer.zero_grad()
        pred = model(batch).squeeze()
        
        loss = criterion(pred, label)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    return train_loss/len(trainloader)

In [17]:
def eval(model, loader, args, return_output=False, criterion=criterion):
    model.eval()
    preds = []
    labels = []
    with torch.no_grad():
        for batch, label in loader:
            batch = batch.to(args.device)
            label = label.to(args.device)
            pred = model(batch)
            preds.append(pred)
            labels.append(label)
    preds = torch.cat(preds, dim=0)
    labels = torch.cat(labels, dim=0)
    
    loss = criterion(preds.squeeze(), labels.squeeze())

    if return_output:
        return loss.item(), preds, labels
    else:
        return loss.item()

In [18]:
epoch = 0
while True:
    epoch+=1
    train_loss = train(model,trainloader,args)
    valid_loss = eval(model,validloader,args)
    print(f'[Epoch{epoch}] train_loss: {train_loss:.4f}, valid_loss: {valid_loss:.4f}')
    early_stopper(valid_loss,model)
    if early_stopper.early_stop:
        print('early stopping')
        break

[Epoch1] train_loss: 11.7360, valid_loss: 7.9027
[Epoch2] train_loss: 7.0950, valid_loss: 5.9699
[Epoch3] train_loss: 6.0525, valid_loss: 5.3284
[Epoch4] train_loss: 5.5032, valid_loss: 4.8736
[Epoch5] train_loss: 4.9770, valid_loss: 4.3919
[Epoch6] train_loss: 4.4753, valid_loss: 3.9833
[Epoch7] train_loss: 4.0494, valid_loss: 3.6527
[Epoch8] train_loss: 3.6790, valid_loss: 3.3883
[Epoch9] train_loss: 3.3574, valid_loss: 3.1803
[Epoch10] train_loss: 3.0562, valid_loss: 3.0140
[Epoch11] train_loss: 2.8225, valid_loss: 2.8839
[Epoch12] train_loss: 2.6373, valid_loss: 2.7772
[Epoch13] train_loss: 2.4294, valid_loss: 2.6970
[Epoch14] train_loss: 2.2373, valid_loss: 2.6110
[Epoch15] train_loss: 2.1012, valid_loss: 2.5676
[Epoch16] train_loss: 1.9815, valid_loss: 2.5069
[Epoch17] train_loss: 1.8445, valid_loss: 2.4782
[Epoch18] train_loss: 1.7323, valid_loss: 2.4353
[Epoch19] train_loss: 1.6117, valid_loss: 2.3877
[Epoch20] train_loss: 1.5218, valid_loss: 2.3759
[Epoch21] train_loss: 1.4123

### Validate

In [19]:
model.load_state_dict(torch.load(early_stopper.path, map_location=args.device))
model.eval()
print(f'loaded best model "{early_stopper.path}", valid loss: {early_stopper.val_loss_min:.4f}')

loaded best model "ckpts/Solubility_ECFP_MLP_h64b128_lr0.0001.pt", valid loss: 2.0056


In [20]:
test_loss = eval(model,testloader,args)
print(f'{dataset}: Final test loss: {test_loss:.4f}')

# [A] Solubility_AqSolDB: Final test loss: 1.8527 - MSE
# [D] BBB_Martins: Final test loss: 0.3858 - BCE
# [M] CYP3A4_Veith: Final test loss: 0.4666 - BCE
# [E] Clearance_Hepatocyte_AZ: Final test loss: 3122.5181 - MSE

Solubility_AqSolDB: Final test loss: 1.8527
