In [None]:
import pandas as pd

import torch
from torch import optim
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import mean_squared_error
from math import sqrt

from common import *

#### Load data

In [None]:
# Read the pre-split train / test tables from disk.
train_data = pd.read_csv("../../data/ld50/train.csv")
test_data = pd.read_csv("../../data/ld50/test.csv")

# Model inputs come from the "smiles" column, regression targets from "ld50".
x_train, y_train = train_data["smiles"], train_data["ld50"]
x_test, y_test = test_data["smiles"], test_data["ld50"]

# Cell output: summary statistics of the training inputs and targets.
x_train.describe(), y_train.describe()

#### Load model and extend layers

In [None]:
# Encoder and tokenizer; load_model is not defined in this notebook
# (presumably it comes from the star import of `common`).
molformer, tokenizer = load_model()

# Regression head: 768 input features -> two 1024-unit hidden layers, each
# followed by batch normalisation and LeakyReLU -> a single output value.
head_layers = [
    torch.nn.Linear(768, 1024),
    torch.nn.BatchNorm1d(1024),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(1024, 1024),
    torch.nn.BatchNorm1d(1024),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(1024, 1),
]
model = torch.nn.Sequential(*head_layers)

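#### Head-only training

The cells in this section are commented out. They train only the regression head on embeddings computed once by `embed(molformer, ...)`, without updating MoLFormer itself.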

In [None]:
# train_smiles = x_train.apply(canonicalize)
# test_smiles = x_test.apply(canonicalize)

# train_embeddings = embed(molformer, train_smiles, tokenizer)
# test_embeddings = embed(molformer, test_smiles, tokenizer)

In [None]:
# class EmbeddingsDataset(Dataset):
#     def __init__(self, x: torch.tensor, y: pd.Series):
#         self.X = x
#         self.Y = y

#     def __len__(self):
#         return len(self.Y)
    
#     def __getitem__(self, index: int):
#         x = self.X[index]
#         y = torch.tensor(self.Y.iloc[index])
#         return x, y.float()

# train_dataset = EmbeddingsDataset(train_embeddings, y_train)
# test_dataset = EmbeddingsDataset(test_embeddings, y_test)
# train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=True)

In [None]:
# optimizer = optim.Adam(model.parameters())
# criterion = torch.nn.MSELoss()
# num_epochs = 100

# model.train()
# for epoch_index in range(num_epochs):
#     running_loss = 0.
#     last_loss = 0.

#     # Here, we use enumerate(train_dataloader) instead of
#     # iter(train_dataloader) so that we can track the batch
#     # index and do some intra-epoch reporting
#     for i, data in enumerate(train_dataloader):
#         # Every data instance is an input + label pair
#         inputs, labels = data

#         # Zero your gradients for every batch!
#         optimizer.zero_grad()

#         # Make predictions for this batch
#         outputs = model(inputs)

#         # Compute the loss and its gradients
#         loss = criterion(outputs, labels.float().unsqueeze(-1))
#         loss.backward()

#         # Adjust learning weights
#         optimizer.step()

#         # Gather data and report
#         running_loss += loss.item()

#         if (i) % 20 == 19:
#             train_loss = running_loss / 20 # loss per batch
#             running_loss = 0.
#             print('Epoch {} batch {} train loss: {}'.format(epoch_index, i + 1, train_loss))

In [None]:
# import matplotlib.pyplot as plt
# from sklearn.metrics import mean_squared_error, r2_score
# from math import sqrt

# model.eval()
# inputs = test_embeddings
# with torch.no_grad():
#     y_pred = model(inputs)
#     y_pred = y_pred.squeeze(-1).detach().numpy()
#     plt.scatter(y_test.values,
#                 y_pred,
#                 color='r')

# sqrt(mean_squared_error(y_test, y_pred)), r2_score(y_test, y_pred)

#### Full model training

In [None]:
class SmilesDataset(Dataset):
    """Dataset of (canonicalized SMILES string, float32 target) pairs.

    The SMILES strings are passed through ``canonicalize`` once, at
    construction time; targets are converted to tensors on access.
    """

    def __init__(self, x: pd.Series, y: pd.Series):
        """Store the canonicalized inputs and their targets.

        Args:
            x: SMILES strings, one per sample.
            y: Target values, positionally aligned with ``x``.

        Raises:
            ValueError: If ``x`` and ``y`` do not have the same length.
        """
        # Fail fast: a mismatch would otherwise silently drop samples (x longer
        # than y) or raise an IndexError in the middle of an epoch (y longer).
        if len(x) != len(y):
            raise ValueError(
                f"x and y must have the same length, got {len(x)} and {len(y)}"
            )
        self.X = x.apply(canonicalize)
        self.Y = y

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.Y)

    def __getitem__(self, index: int):
        """Return the ``index``-th (SMILES string, float32 scalar tensor) pair.

        Lookup is positional (``iloc``), so the Series index labels do not matter.
        """
        x = self.X.iloc[index]
        y = torch.tensor(self.Y.iloc[index], dtype=torch.float32)
        return x, y

# Wrap the raw SMILES / target columns; canonicalization happens inside
# SmilesDataset at construction time.
train_dataset = SmilesDataset(x_train, y_train)
test_dataset = SmilesDataset(x_test, y_test)

# Training batches are reshuffled every epoch.
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# Evaluation data is iterated in its original order so that predictions stay
# positionally aligned with y_test.
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [None]:
# Fine-tune the encoder together with the regression head: both parameter
# sets are handed to a single Adam optimizer.
parameters = list(model.parameters()) + list(molformer.parameters())
optimizer = optim.Adam(parameters, lr=12e-5)
criterion = torch.nn.MSELoss()
num_epochs = 100
log_every = 1000  # number of batches between intra-epoch loss reports

# The head contains BatchNorm1d layers, so make sure it is in training mode
# even if it was switched to eval() earlier in the session.
model.train()
# NOTE(review): molformer is left in whatever mode load_model() returned -
# confirm whether it should be put in train() mode while fine-tuning.

for epoch_index in range(num_epochs):
    running_loss = 0.    # summed loss of the current reporting window
    window_batches = 0   # batches accumulated in the current window
    epoch_loss = 0.      # summed loss over the whole epoch
    epoch_batches = 0    # batches that contributed to epoch_loss
    last_loss = 0.

    # Here, we use enumerate(train_dataloader) instead of
    # iter(train_dataloader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(train_dataloader):
        # Every data instance is an input + label pair
        inputs, labels = data

        # BatchNorm1d raises in training mode when a batch holds a single
        # sample (no batch statistics can be computed), which can happen on
        # the last, incomplete batch of an epoch. Skip it instead of crashing.
        if labels.size(0) < 2:
            continue

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch: tokenize, run the encoder, then
        # mean-pool the token embeddings over the non-padding positions.
        batch_enc = tokenizer.batch_encode_plus(inputs, padding=True, add_special_tokens=True)
        idx, mask = torch.tensor(batch_enc['input_ids']), torch.tensor(batch_enc['attention_mask'])
        token_embeddings = molformer.blocks(molformer.tok_emb(idx), length_mask=LM(mask.sum(-1)))
        input_mask_expanded = mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        # Clamp guards against division by zero for an all-padding row.
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        embedding = sum_embeddings / sum_mask
        outputs = model(embedding)

        # Compute the loss and its gradients; labels get a trailing dimension
        # via unsqueeze(-1) before being compared with the head's output.
        loss = criterion(outputs, labels.float().unsqueeze(-1))
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        batch_loss = loss.item()
        running_loss += batch_loss
        window_batches += 1
        epoch_loss += batch_loss
        epoch_batches += 1

        if window_batches == log_every:
            last_loss = running_loss / window_batches  # loss per batch
            print('Epoch {}  batch {} loss: {}'.format(epoch_index, i + 1, last_loss))
            running_loss = 0.
            window_batches = 0

    # Always report once per epoch: with fewer than `log_every` batches per
    # epoch the intra-epoch report above never fires.
    if epoch_batches:
        print('Epoch {} mean loss: {}'.format(epoch_index, epoch_loss / epoch_batches))

In [None]:
from datetime import datetime

# One timestamp for both files so the head and encoder checkpoints of the
# same run can be matched up afterwards.
timestamp = datetime.now().isoformat()

# Regression head (file name pattern unchanged).
torch.save(model.state_dict(), f"nn_model_{timestamp}")

# The full-model training cell also optimizes molformer's parameters, so the
# head alone is not enough to reproduce predictions: persist the fine-tuned
# encoder weights as well.
torch.save(molformer.state_dict(), f"molformer_{timestamp}")