In [None]:
import os
import math
from IPython.display import clear_output
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pandas.core.dtypes.cast import maybe_box_datetimelike
from copy import deepcopy
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.nn.modules.container import Sequential
from torch.utils.data import random_split, TensorDataset, Dataset, DataLoader

import torchvision
from torchvision import datasets


from Model_and_trainer import CustomStructureDataset, First_CNN, ProgressPlotter

import random


def set_random_seed(seed):
    """Seed every random number generator the notebook relies on.

    Applies ``seed`` to PyTorch's generator (``torch.manual_seed``), CUDA
    (``torch.cuda.manual_seed``), NumPy's legacy global RNG and Python's
    ``random`` module, in that order.

    Args:
        seed: integer seed shared by all generators.
    """
    seeders = (torch.manual_seed, torch.cuda.manual_seed, np.random.seed, random.seed)
    for seeder in seeders:
        seeder(seed)


# Use CUDA whenever PyTorch reports it is available; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)

Using device: cuda


In [None]:
#Loading train, val, test datasets
# (CSV file, structure directory) for each split, in the order
# train / validation / test-1.
train_dataset, val_dataset, test_dataset = (
    CustomStructureDataset(csv_path, str_dir=structure_dir)
    for csv_path, structure_dir in (
        ('Data/train_data_files.csv', 'Data/Dataset'),
        ('Data/Val_data.csv', 'Data/Dataset_val'),
        ('Data/test_1.csv', 'Data/Dataset_t1'),
    )
)

In [None]:
# Calculate mean and variance to normalize target values
# Target statistics from the training split only, reused below to normalise
# the validation and test splits. Column 2 of str_labels is assumed to hold
# the regression target (pKD) — TODO confirm against CustomStructureDataset.
train_labels = torch.Tensor(train_dataset.str_labels[2].astype(float).tolist())
mean, std = torch.mean(train_labels).item(), torch.std(train_labels).item()
print(mean, std)

7.316920280456543 2.1500697135925293


In [None]:

# Training split: normalise targets with the training statistics and switch
# on both random augmentations.
train_dataset.normalize = True  # normalization of target values
train_dataset.mean = mean
train_dataset.std = std
train_dataset.train = True  # with probability 0.5, -1 is replaced by 1 and vice versa
train_dataset.transform = True  # with probability 0.5, a mirror reflection along one of the axes

# Evaluation splits: identical normalisation (training mean/std), no augmentation.
for _eval_split in (val_dataset, test_dataset):
    _eval_split.train = False
    _eval_split.normalize = True
    _eval_split.mean = mean
    _eval_split.std = std
    _eval_split.transform = None
del _eval_split


In [None]:
from torch.utils.data import WeightedRandomSampler
batch_size = 32

# Only the training split is shuffled and batched. The evaluation loaders keep
# file order and yield one structure per batch.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)
val_loader, test_loader = (
    DataLoader(split, shuffle=False, batch_size=1, num_workers=2)
    for split in (val_dataset, test_dataset)
)

In [None]:
import torch.nn as nn
import torch
from torch.nn.modules.container import Sequential

class First_CNN(nn.Module):
    """3-D CNN mapping a voxelised structure to a single regression output.

    NOTE(review): this cell redefines ``First_CNN`` and shadows the class of the
    same name imported from ``Model_and_trainer``. Every later cell uses this
    definition, so keep the two in sync or drop one of them.

    Layer order and indices are unchanged from the original, so
    ``conv_stack[0]`` is still the first Conv3d and existing state_dicts load.

    * Four stages of Conv3d -> MaxPool3d(2) -> ReLU -> BatchNorm3d, with
      32 / 64 / 128 / 256 output channels. Each Conv3d uses stride 1 and
      size-preserving padding (k=7/p=3, k=5/p=2, k=3/p=1).
    * Head: Flatten -> Linear(flat, 1000) -> Dropout(0.3) -> ReLU
      -> Linear(1000, 200) -> ReLU -> Linear(200, 1).

    Args:
        in_channels: number of input feature channels (default 10).
        grid_size: spatial size of the input volume. Each of the four
            MaxPool3d(2) layers floors every spatial dimension by 2, so the head
            width is ``256 * prod(s // 2 // 2 // 2 // 2 for s in grid_size)``.
            The default (81, 81, 41) reproduces the original hard-coded
            ``256 * 5 * 5 * 2``.
    """

    def __init__(self, in_channels=10, grid_size=(81, 81, 41)):
        super().__init__()
        # Flattened feature count after the last pooling stage.
        flat_features = 256
        for size in grid_size:
            flat_features *= size // 2 // 2 // 2 // 2
        self.conv_stack = nn.Sequential(
            nn.Conv3d(in_channels, 32, 7, padding=3, bias=False),  # -> [32, D, H, W]
            nn.MaxPool3d(2),  # -> [32, D//2, H//2, W//2]
            nn.ReLU(),
            nn.BatchNorm3d(32),

            nn.Conv3d(32, 64, 5, padding=2, bias=False),  # in channel=32, out=64
            nn.MaxPool3d(2),  # -> [64, D//4, H//4, W//4]
            nn.ReLU(),
            nn.BatchNorm3d(64),

            nn.Conv3d(64, 128, 3, padding=1, bias=False),  # in channel=64, out=128
            nn.MaxPool3d(2),  # -> [128, D//8, H//8, W//8]
            nn.ReLU(),
            nn.BatchNorm3d(128),

            nn.Conv3d(128, 256, 3, padding=1, bias=False),  # in channel=128, out=256
            nn.MaxPool3d(2),  # -> [256, D//16, H//16, W//16] (floored per stage)
            nn.ReLU(),
            nn.BatchNorm3d(256),

            nn.Flatten(),
            nn.Linear(flat_features, 1000),
            nn.Dropout(0.3),
            nn.ReLU(),
            nn.Linear(1000, 200),
            nn.ReLU(),
            nn.Linear(200, 1))

    def forward(self, x):
        """Return the (batch, 1) regression output for input batch ``x``."""
        return self.conv_stack(x)


In [None]:
from torchmetrics import PearsonCorrCoef, R2Score, MeanSquaredError

def train_func(model, criterion, optimizer, num_epochs, best_metrics):
    """Train ``model`` and return the snapshot with the best validation Pearson r.

    Each epoch does the following:

    * runs one optimisation pass over ``train_loader``;
    * scores the model on ``val_loader``, using the loss on normalised targets
      and Pearson r on de-normalised targets;
    * computes Pearson r on ``train_loader`` for the progress plot.

    Relies on notebook globals: ``train_loader``, ``val_loader``, ``device``,
    ``mean``, ``std``, ``ProgressPlotter``, ``PearsonCorrCoef`` and ``tqdm``.

    Args:
        model: network trained in place; must output one value per sample.
        criterion: loss applied to (prediction, target) on normalised targets.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of epochs to run.
        best_metrics: validation Pearson r that an epoch must strictly exceed
            to replace the returned snapshot.

    Returns:
        A deep copy of ``model`` from the epoch with the highest validation
        Pearson r above ``best_metrics``. If no epoch beats ``best_metrics``,
        returns a deep copy of ``model`` as it was passed in. Previously this
        case raised UnboundLocalError.
    """
    pp = ProgressPlotter(title="baseline", groups=["loss"])
    loss_hist = []  # mean training loss per epoch, for plotting
    # Fallback: the incoming model, returned if no epoch beats best_metrics.
    best_model = deepcopy(model)
    for epoch in range(num_epochs):
        loss_hist.append(_train_one_epoch(model, criterion, optimizer))

        val_loss, corr_coef = _evaluate_on_val(model, criterion)
        if corr_coef > best_metrics:
            best_metrics = corr_coef
            best_model = deepcopy(model)

        corr_coef_train = _train_set_corr(model)

        print(f"Epoch={epoch} loss={loss_hist[epoch]:.4f} val_loss: {round(val_loss.item(),3)} val_corr: {round(corr_coef.item(),3)}")
        pp.add_scalar(group="loss", value=loss_hist[epoch], tag="train")
        pp.add_scalar(group="loss", value=val_loss.item(), tag="val")
        pp.add_scalar(group='corr', value=corr_coef_train.item(), tag="train")
        pp.add_scalar(group='corr', value=corr_coef.item(), tag="val")
        pp.display()
    return best_model


def _train_one_epoch(model, criterion, optimizer):
    """Run one optimisation pass over ``train_loader``; return the mean batch loss."""
    model.train()
    running_loss = 0
    for structure, labels in tqdm(train_loader):
        structure = structure.to(device=device, dtype=torch.float)
        labels = labels.to(device=device, dtype=torch.float)
        optimizer.zero_grad()
        # Flatten predictions to 1-D before comparing with the targets.
        y_pred = torch.reshape(model(structure), (-1,))
        loss = criterion(y_pred, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(train_loader)


def _evaluate_on_val(model, criterion):
    """Score ``model`` on ``val_loader`` without gradients.

    Returns ``(val_loss, val_corr)`` as tensors. ``val_loss`` is ``criterion``
    on normalised values. ``val_corr`` is Pearson r on de-normalised values
    (``* std + mean``).
    """
    model.eval()
    preds_norm, labels_norm = [], []
    with torch.no_grad():
        for struct_val, labels_val in val_loader:
            struct_val = struct_val.to(device=device, dtype=torch.float)
            labels_val = labels_val.to(device=device, dtype=torch.float)
            pred_val = model(struct_val)
            # Flatten so every sample in the batch counts, not only element [0].
            preds_norm.extend(torch.reshape(pred_val, (-1,)).tolist())
            labels_norm.extend(torch.reshape(labels_val, (-1,)).tolist())
    val_loss = criterion(torch.Tensor(preds_norm).to(device), torch.Tensor(labels_norm).to(device))
    preds_val = torch.Tensor([p * std + mean for p in preds_norm]).to(device)
    targets_val = torch.Tensor([t * std + mean for t in labels_norm]).to(device)
    val_corr = PearsonCorrCoef().to(device)(preds_val, targets_val)
    return val_loss, val_corr


def _train_set_corr(model):
    """Return Pearson r between de-normalised predictions and targets over ``train_loader``.

    Runs in eval mode under ``torch.no_grad()``, so no autograd graph is
    built. The loader's own shuffling and dataset augmentation flags still apply.
    """
    model.eval()
    preds, targets = [], []
    with torch.no_grad():
        for struct_train, labels_train in tqdm(train_loader):
            struct_train = struct_train.to(device=device, dtype=torch.float)
            labels_train = labels_train.to(device=device, dtype=torch.float)
            pred_train = model(struct_train).cpu()
            preds.append(torch.reshape(pred_train * std + mean, (-1,)))
            targets.append(torch.reshape((labels_train * std + mean).cpu(), (-1,)))
    return PearsonCorrCoef()(torch.cat(preds), torch.cat(targets))

In [None]:
# Stage 1: train First_CNN from scratch. Any epoch whose validation Pearson r
# is above 0 may become the returned snapshot.
set_random_seed(42)
model = First_CNN().to(device)

criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-3)

num_epochs = 20
best_model = train_func(model, criterion, optimizer, num_epochs, best_metrics=0)

In [None]:
# Stage 2: fine-tune the stage-1 snapshot with a 10x smaller learning rate and
# lighter weight decay. train_func updates best_model in place and returns a
# deep copy, so best_model itself ends up in its last fine-tuned state.
# NOTE(review): 0.5 is a fixed threshold, not the stage-1 validation score, so
# an epoch scoring above 0.5 but below stage 1 can still be returned.
set_random_seed(42)
best_model.train()
optimizer = torch.optim.AdamW(best_model.parameters(), lr=1e-5, weight_decay=1e-5)
num_epochs = 10
best_model2 = train_func(best_model, criterion, optimizer, num_epochs, best_metrics=0.5)

In [None]:

def validate(model, test_loader, device = "cpu"):
    """Evaluate ``model`` on ``test_loader`` and print its Pearson r and RMSE.

    Labels and predictions are de-normalised with the notebook-level ``mean``
    and ``std``. Predictions are rounded to 2 decimals, and both metrics are
    computed on those rounded values.

    Args:
        model: trained network; must output one value per sample.
        test_loader: DataLoader yielding ``(structure, label)`` batches. Any
            batch size works; previously only element [0] of each batch was
            used, which silently dropped samples when batch_size > 1.
        device: device the model lives on; inputs are moved there.

    Returns:
        ``(labels_list, pred_list)``: de-normalised true values and rounded
        de-normalised predictions, as Python lists in loader order.
    """
    model.eval()
    pred_list = []
    labels_list = []
    # Inference only: skip autograd bookkeeping for every forward pass.
    with torch.no_grad():
        for struct, labels in test_loader:
            struct = struct.to(device=device, dtype=torch.float)
            labels = labels.to(device=device, dtype=torch.float)
            pred = model(struct)
            # Flatten so every sample in the batch is kept.
            labels_list.extend(v * std + mean for v in torch.reshape(labels, (-1,)).tolist())
            pred_list.extend(round(v * std + mean, 2) for v in torch.reshape(pred, (-1,)).tolist())

    pearson = PearsonCorrCoef().to(device)
    corr_coef = pearson(torch.Tensor(pred_list).to(device), torch.Tensor(labels_list).to(device))

    mse = MeanSquaredError().to(device)
    rmse = math.sqrt(mse(torch.Tensor(pred_list).to(device), torch.Tensor(labels_list).to(device)).item())

    print(f'PearsonCorr: {round(corr_coef.item(),2)}')
    print(f'RMSE: {round(rmse,3)}')
    return labels_list, pred_list

In [None]:
# a: de-normalised true labels, b: rounded de-normalised predictions.
a, b = validate(model=best_model2, test_loader=test_loader, device=device)

In [None]:
# Older SciPy releases do not lazily load submodules, so a bare `import scipy`
# may leave `scipy.stats` undefined. Import the submodule explicitly.
import scipy.stats
scipy.stats.pearsonr(b, a)

In [None]:
import seaborn as sns

# Predicted vs. true pKD on the first test set (a = true labels,
# b = predictions, both de-normalised, as returned by validate).
corr_df = pd.DataFrame(columns=['Pred_pKD', 'True_pKD'])
corr_df['Pred_pKD'] = b
corr_df['True_pKD'] = a
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
pd.set_option('display.max_colwidth', None)

# sns.set_context only honours context rc keys (font/line sizes) and silently
# ignores 'figure.dpi', so the 500 dpi resolution is applied on savefig instead.
sns.set_context(rc={'font.size': 12})
# `palette` only affects the hue mapping, and no hue is used, so it is omitted.
fig = sns.jointplot(data=corr_df, x='True_pKD', y='Pred_pKD', ylim=(2, 14), xlim=(2, 14))

os.makedirs("Result Model", exist_ok=True)
fig.figure.savefig("Result Model/model_test1.png", dpi=500)

In [None]:
# Persist test-1 predictions next to the saved figure.
corr_df.to_csv(path_or_buf='Result Model/best_model_test_1.csv')

In [None]:
# Short labels for the 10 input channels (pr1 / pr2 = the two interacting
# partners), indexed in channel order.
# NOTE(review): channels 7 and 8 are described as carbonyl atoms, but their
# labels say "Carboxy"; confirm which one the featuriser actually encodes.
channels_names = [
    'HB_Ac1+Don2',            # 0: H-bond acceptors pr1 + H-bond donors pr2
    'HB_Ac2+Don1',            # 1: H-bond donors pr1 + H-bond acceptors pr2
    'HB_Ac1+Weak_Don2',       # 2: H-bond acceptors pr1 + weak H-bond donors pr2
    'HB_Ac2+Weak_Don1',       # 3: weak H-bond donors pr1 + H-bond acceptors pr2
    'Pos1+Neg2',              # 4: positive charge atoms pr1 + negative charge atoms pr2
    'Pos2+Neg1',              # 5: negative charge atoms pr1 + positive charge atoms pr2
    'Hph1+Hph2',              # 6: hydrophobic atoms pr1 + hydrophobic atoms pr2
    'Carboxy_C1+Carboxy_C2',  # 7: carbonyl carbons pr1 + carbonyl carbons pr2
    'Carboxy_O1+Carboxy_O2',  # 8: carbonyl oxygens pr1 + carbonyl oxygens pr2
    'Arom1+Arom2',            # 9: aromatic atoms pr1 + aromatic atoms pr2
]

In [None]:
# output of weights for each input channel
# First-layer Conv3d kernels as a table with one column per input channel.
# Each row is one (output filter, kernel position) pair.
w0 = pd.DataFrame(
    best_model2.conv_stack[0].weight.detach().cpu().numpy()
    .transpose(0, 4, 2, 3, 1)  # move the input-channel axis (dim 1) last
    .reshape(-1, 10),
    columns=channels_names,
)

In [None]:
# calculating the average value for significant neurons
# Per input channel: fraction of first-layer kernel weights with magnitude
# above 0.001, shown largest first.
diff = w0.abs().gt(0.001).mean()
diff.sort_values(ascending=False)

In [None]:
# range between 25th and 75th percentiles
# Interquartile range (75th minus 25th percentile) of the first-layer weights
# of each input channel, widest first. It is computed for inspection only;
# the plot below does not use it.
perc_diff = (
    w0.apply(np.percentile, args=(75,))
    .sub(w0.apply(np.percentile, args=(25,)))
    .sort_values(ascending=False)
)

# Horizontal box plot of each input channel's first-layer weights, with
# outlier markers hidden (fliersize=0).
fig, ax = plt.subplots(figsize=(7, 6), dpi=300)
sns.boxplot(data=w0, orient='h', fliersize=0, ax=ax)
ax.set_xlim(-0.055, 0.055)
ax.set_xticks(np.arange(-0.04, 0.05, 0.02))
ax.set_ylim(10, -1)  # reversed limits: first channel on top, with a margin
fig.tight_layout()
fig.savefig("Result Model/model_feat_imp.png")

In [None]:
# Second held-out test set. It is configured like the other evaluation splits:
# train=False, transform=None, and targets normalised with the training mean/std.
test_dataset2 = CustomStructureDataset('Data/test_2.csv', str_dir='Data/Dataset_t2')
for _attr, _value in (('train', False), ('normalize', True), ('mean', mean),
                      ('std', std), ('transform', None)):
    setattr(test_dataset2, _attr, _value)
del _attr, _value

test_loader2 = DataLoader(test_dataset2, batch_size=1, shuffle=False, num_workers=2)

# a: de-normalised true labels, b: rounded de-normalised predictions.
a, b = validate(best_model2, test_loader2, device=device)

In [None]:
# Import the submodule explicitly: `import scipy` alone does not load
# `scipy.stats` on older SciPy releases, and this cell should run on its own.
import scipy.stats
scipy.stats.pearsonr(b, a)

In [None]:
import seaborn as sns

# Predicted vs. true pKD on the second test set (a = true labels,
# b = predictions, both de-normalised, as returned by validate).
corr_df = pd.DataFrame(columns=['Pred_pKD', 'True_pKD'])
corr_df['Pred_pKD'] = b
corr_df['True_pKD'] = a
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
pd.set_option('display.max_colwidth', None)

# sns.set_context only honours context rc keys (font/line sizes) and silently
# ignores 'figure.dpi', so the 500 dpi resolution is applied on savefig instead.
sns.set_context(rc={'font.size': 12})
# `palette` only affects the hue mapping, and no hue is used, so it is omitted.
fig = sns.jointplot(data=corr_df, x='True_pKD', y='Pred_pKD', ylim=(2, 14), xlim=(2, 14))

os.makedirs("Result Model", exist_ok=True)
fig.figure.savefig("Result Model/model_test_2.png", dpi=500)

In [None]:
# Persist test-2 predictions next to the saved figure.
corr_df.to_csv(path_or_buf='Result Model/best_model_test_2.csv')

In [None]:
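# Uncomment to persist the fine-tuned model. Saving the whole module pickles the
# object, so loading it later requires the First_CNN class definition to be
# available. Saving best_model2.state_dict() instead is the more portable option.
#torch.save(best_model2, 'Models/best_model.pt')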