In [None]:
# Import necessary libraries
import tqdm
import pandas as pd
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem, Draw, rdDepictor
from rdkit.Chem import PandasTools, Descriptors
import py3Dmol


from IPython.display import display, HTML
from rdkit.Chem.Draw import rdMolDraw2D
from IPython.display import SVG

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

import matplotlib
import matplotlib.pyplot as plt

# Classification of molecules using CNN #

**Goals**
1. Introduction to SMILES as molecular representation for ML models.
2. Use ML models, more specifically **Convolutional Neural Networks** (CNNs), to classify molecules.





## Data loading and analysis ##

In [None]:
# Load the HIV dataset used in this course (SMILES strings plus activity labels).
data_url = "https://github.com/RodrigoAVargasHdz/CHEM-4PB3/raw/w2024/Course_Notes/data/HIV.csv"
data = pd.read_csv(data_url)
# len() is the number of molecules (rows); DataFrame.count() would print per-column non-null counts.
print('Total data:', len(data))
# Only the last expression of a cell is rendered automatically, so show the preview explicitly.
display(data.head())

# Distribution of the binary target and the distinct values of the `activity` column.
data.hist(column='HIV_active')
print('Possible values of HIV_active:', data['HIV_active'].unique())
print('Possible values of activity:', data['activity'].unique())

Plot some molecules that are not **HIV active**

In [None]:
# Parse every SMILES string into an RDKit molecule, stored in the 'ROMol' column.
PandasTools.AddMoleculeColumnToFrame(data, 'smiles')
# Show the first nine molecules labelled as not HIV active in a 3x3 grid.
inactive_molecules = data.loc[data['HIV_active'] == 0]
first_nine_inactive = inactive_molecules.iloc[:9]
inactive_grid = PandasTools.FrameToGridImage(
    first_nine_inactive,
    column='ROMol',
    legendsCol='smiles',
    molsPerRow=3,
    subImgSize=(300, 300),
)
inactive_grid

Plot some molecules that are **HIV active**

In [None]:
# The 'ROMol' column is normally created by the previous cell; only rebuild it when
# it is missing, to avoid re-parsing every SMILES string in the dataset.
if 'ROMol' not in data.columns:
    PandasTools.AddMoleculeColumnToFrame(data, 'smiles')
# Show the first nine HIV-active molecules in a 3x3 grid.
active_molecules = data[data['HIV_active'] == 1]
active_grid = PandasTools.FrameToGridImage(active_molecules[:9], column='ROMol', legendsCol='smiles',
                                           molsPerRow=3, subImgSize=(300, 300))
active_grid

In the previous class, we saw that a molecule written in SMILES notation can be transformed into a *"figure"* using a dictionary of characters and the one-hot encoding transformation. <br>

To build this figure we need two ingredients: a dictionary of all the characters (tokens) that can appear in a SMILES string, which fixes the number of rows, and the maximum number of characters in a SMILES string (the length of the text), which fixes the number of columns.

In [None]:
# Vocabulary of SMILES tokens. A token's position in this list is its row in the
# one-hot matrix; the list includes two-character tokens such as "Cu" and "Br".
SMILES_CHARS = ["7", "6", "o", "]", "3", "s", "(", "-", "S", "/", "B", "4", "[", ")", "#", "I",
                "l", "O", "H", "c", "1", "@", "=", "n", "P", "8", "C", "2", "F", "5", "r", "N", 
                "+", "\\", " ", "Cu", ".", "Si", "Se", "Na", "Li", "Ge", "K", "Zn", "Mo", "Rh",
                "9", "p", "se", "Fe", "W", "Te", "Pd", "Ni", "As", "Pt", "Mg", "%","U","0",
                "Tl", "Ga", "Au", "Ti", "Mn", "Bi", "Br", "Hg", "b", "Ca", "Ag"]
# Reverse lookup: token -> row index
smi2index = {token: position for position, token in enumerate(SMILES_CHARS)}


def smiles_to_one_hot_and_list(smile, maxlen, char_to_index=None):
    """Tokenize a SMILES string and one-hot encode it as a fixed-width matrix.

    Tokenization is greedy: at each position a two-character token (e.g. "Br",
    "Cu") is preferred over a single character when it is in the vocabulary.

    Parameters
    ----------
    smile : str
        SMILES string; newline characters are stripped before encoding.
    maxlen : int
        Number of token positions (columns) in the output. Extra tokens are
        truncated; shorter strings leave the remaining columns all zero.
    char_to_index : dict[str, int], optional
        Token -> row-index vocabulary. Defaults to the module-level ``smi2index``.

    Returns
    -------
    X : np.ndarray, shape (vocab_size, maxlen)
        ``X[k, j] == 1`` iff the j-th token is vocabulary entry ``k``.
    smile_list : list[str]
        The tokens that were encoded (at most ``maxlen`` of them).

    Raises
    ------
    ValueError
        If ``smile`` contains a character not covered by the vocabulary.
    """
    if char_to_index is None:
        char_to_index = smi2index
    # Rows are indexed by the vocabulary's indices, so the width is max index + 1.
    vocab_size = max(char_to_index.values()) + 1 if char_to_index else 0
    X = np.zeros((maxlen, vocab_size))  # (token position, vocabulary entry)
    smile_list = []
    smile = smile.replace('\n', '')
    i = 0  # position in the SMILES string
    j = 0  # token position (row of X)
    # Checking j before writing handles maxlen == 0 and truncation in one place.
    while i < len(smile) and j < maxlen:
        two_chars = smile[i:i + 2]
        if len(two_chars) == 2 and two_chars in char_to_index:
            token = two_chars
        elif smile[i] in char_to_index:
            token = smile[i]
        else:
            raise ValueError(
                f"Unrecognized SMILES character {smile[i]!r} at position {i} in {smile!r}")
        X[j, char_to_index[token]] = 1
        smile_list.append(token)
        i += len(token)
        j += 1

    return X.T, smile_list

# Example usage: encode a single molecule (caffeine) and visualize its one-hot matrix.
smile = 'CN1C=NC2=C1C(=O)N(C(=O)N2C)C'  # Example SMILES
max_length_test = len(smile) + 5  # a few spare columns show the all-zero padding
one_hot_encoded_smile, smile_list = smiles_to_one_hot_and_list(
    smile, max_length_test)
print(smile_list)

# Rows: vocabulary tokens; columns: token positions in the SMILES string.
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(one_hot_encoded_smile, cmap='binary')
ax.set_ylabel('Tokens')
ax.set_xlabel('SMILES')
ax.set_title('One-hot encoding for %s' % smile)
ax.set_xticks(np.arange(len(smile_list)))
ax.set_xticklabels(smile_list, fontsize=8)
ax.set_yticks(np.arange(len(SMILES_CHARS)))
ax.set_yticklabels(SMILES_CHARS, fontsize=8)
plt.show()

In [None]:
# Draw the example molecule with RDKit, to compare with its one-hot "figure" above.
mol = Chem.MolFromSmiles(smile)
mol

In classification tasks, having a balanced dataset is crucial for ensuring the model learns to accurately identify each class. When a dataset is imbalanced, with some classes significantly overrepresented compared to others, models tend to become biased towards the majority classes. This bias can lead to poorer performance on the minority classes, as the model might not learn their characteristics adequately. Essentially, the model may simply "learn" to predict the majority class most of the time, because doing so minimizes its error on the training data. However, such a model is not genuinely understanding or distinguishing between the classes effectively. <br>

Balancing the dataset, either through undersampling the majority classes, oversampling the minority classes, or employing synthetic data generation techniques like SMOTE, helps in creating a more equitable training environment. This balanced approach encourages the model to pay equal attention to learning the distinctive features of each class, leading to better generalization and a more robust performance across all classes, not just the predominant ones.

In [None]:
# Split by class, then undersample the inactive class to the size of the active class.
data_negative = data[data['HIV_active'] == 0]
data_positive = data[data['HIV_active'] == 1]
print('Inactive molecules:', len(data_negative))
print('Active molecules:', len(data_positive))

# balanced dataset (seeded so the subsample is reproducible on Restart & Run All)
BALANCE_SEED = 0
n_positive = len(data_positive)
data_negative_red = data_negative.sample(n_positive, random_state=BALANCE_SEED)

# Shuffle so the rows are not grouped by class; the original index labels are kept.
balanced_data = (
    pd.concat([data_negative_red, data_positive], axis=0)
    .sample(frac=1, random_state=BALANCE_SEED)
)
display(balanced_data.head())



In [None]:
# Longest SMILES string (in characters) over the full dataset. It is the width of the
# one-hot matrix; every token has at least one character, so no molecule is truncated.
max_length = max((len(si) for si in data['smiles']), default=0)

print('Max Molecule length', max_length)

Let's create a data loader for this dataset.
1. We will transform each SMILES string into its "figure" (one-hot matrix) representation.
2. Because we are working with two classes, 'active' and 'inactive', we also need to transform each label/class into a one-hot encoding.

In [None]:
class CustomDataset(Dataset):
    """PyTorch dataset yielding one-hot SMILES "images" and one-hot class labels.

    Each item is ``(molecule, label)``:
    ``molecule`` is a float tensor of shape ``(1, vocab_size, max_length)``, a
    single-channel image built by ``smiles_to_one_hot_and_list``.
    ``label`` is a length-2 one-hot ``LongTensor``.

    Parameters
    ----------
    smiles_all : sequence of str
        SMILES strings, indexable by position.
    labels_all : sequence of int
        Class labels (0 or 1), aligned with ``smiles_all``.
    max_length : int
        Number of token columns in every encoded molecule.

    Raises
    ------
    ValueError
        If ``smiles_all`` and ``labels_all`` differ in length.
    """

    def __init__(self, smiles_all, labels_all, max_length):
        if len(smiles_all) != len(labels_all):
            raise ValueError(
                f"smiles_all ({len(smiles_all)}) and labels_all ({len(labels_all)}) "
                "must have the same length")
        self.labels = labels_all
        self.smiles = smiles_all
        self.max_length = max_length

    def __len__(self):
        return len(self.labels)

    @staticmethod
    def _one_hot_label(label, num_classes=2):
        """Convert an integer class label into a one-hot LongTensor of length ``num_classes``."""
        # F.one_hot requires an int64 tensor; np.array(int) is int32 on some platforms
        # (e.g. Windows with NumPy < 2), so build the index tensor explicitly.
        return F.one_hot(torch.tensor(int(label), dtype=torch.long), num_classes=num_classes)

    def __getitem__(self, idx):
        mi, _ = smiles_to_one_hot_and_list(self.smiles[idx], self.max_length)
        molecules_torch = torch.from_numpy(mi).float()
        # unsqueeze(0) adds the channel dimension expected by Conv2d
        return molecules_torch.unsqueeze(0), self._one_hot_label(self.labels[idx])

In [None]:
data_full = balanced_data
train_size = int(0.8 * len(data_full))  # 80% for training
validation_size = len(data_full) - train_size  # 20% for validation
print('Training data', train_size)
print('Validation data', validation_size)

# Draw the training rows, then keep the *remaining* rows for validation. Two independent
# .sample() calls would put the same molecules in both sets (train/validation leakage).
SPLIT_SEED = 0
tr_dataset = data_full.sample(train_size, random_state=SPLIT_SEED)
val_dataset = data_full.drop(index=tr_dataset.index)
assert len(val_dataset) == validation_size, "index labels must be unique for a disjoint split"

training_data = CustomDataset(
    tr_dataset['smiles'].to_list(), tr_dataset['HIV_active'].to_list(), max_length)
train_dataloader = DataLoader(training_data, batch_size=128, shuffle=True)
train_molecules, train_labels = next(iter(train_dataloader))

print('Size of the training data')
print(train_molecules.shape)
print(train_labels.shape)

## CNN ##

In [None]:
# Architecture adapted from https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
class CNN(nn.Module):
    """Two conv blocks plus three linear layers producing logits for 2 classes.

    Parameters
    ----------
    input_shape : tuple[int, int], optional
        ``(height, width)`` of the single-channel input image, i.e.
        (vocabulary size, max SMILES length). When omitted, the flattened feature
        size defaults to ``16 * 15 * 143``, the value hard-coded for this
        notebook's data. Pass the real shape to support any vocabulary or length.
    """

    def __init__(self, input_shape=None):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 6, 5, padding=2),   # 5x5 kernel with padding 2 keeps H and W
            nn.ReLU(),
            nn.MaxPool2d(2, 2),              # halves H and W
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(6, 16, 5, padding=0),  # no padding: H and W shrink by 4
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        if input_shape is None:
            n_flat = 16 * 15 * 143
        else:
            n_flat = self._flattened_size(input_shape)
        self.fc1 = nn.Sequential(
            nn.Linear(n_flat, 512),
            nn.ReLU(),
            )
        self.fc2 = nn.Sequential(
            nn.Linear(512, 128),
            nn.ReLU(),
            )
        self.fc3 = nn.Linear(128, 2)  # fully connected output layer, 2 classes

    def _flattened_size(self, input_shape):
        """Number of features produced by conv1 + conv2 for a ``(height, width)`` input."""
        height, width = input_shape
        with torch.no_grad():
            dummy = torch.zeros(1, 1, height, width)
            return self.conv2(self.conv1(dummy)).numel()

    def forward(self, x):
        """Return unnormalized logits of shape (batch, 2) for ``x`` of shape (batch, 1, H, W)."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = self.fc1(x)
        x = self.fc2(x)
        output = self.fc3(x)
        return output

In [None]:
def train(model, training_data, training_epochs=60, batch_size=64, lr=3e-4):
    """Train ``model`` with Adam and cross-entropy loss on one-hot targets.

    Parameters
    ----------
    model : nn.Module
        Network that maps a batch of inputs to class logits.
    training_data : Dataset
        Yields ``(inputs, one_hot_targets)`` pairs.
    training_epochs : int
        Number of passes over ``training_data``.
    batch_size : int
        Mini-batch size for the shuffled DataLoader.
    lr : float
        Adam learning rate.

    Returns
    -------
    list[list[float]]
        For every epoch, the list of per-batch loss values.
    """
    # tqdm.notebook is not loaded by a plain `import tqdm`; tqdm.auto picks the notebook
    # widget when available. The progress bar is optional.
    try:
        from tqdm.auto import tqdm as progress_bar
    except ImportError:
        progress_bar = None

    # Define the loss function and optimizer
    loss_function = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    trainloader = torch.utils.data.DataLoader(
        training_data, batch_size=batch_size, shuffle=True)

    epochs = range(training_epochs)
    iterator = progress_bar(epochs) if progress_bar is not None else epochs

    model.train()  # in case the model was left in eval mode by an earlier cell
    loss_trajectory = []
    for _ in iterator:
        epoch_losses = []
        for inputs, targets in trainloader:
            optimizer.zero_grad()
            outputs = model(inputs)
            # Targets are one-hot; CrossEntropyLoss accepts class probabilities as floats.
            loss = loss_function(outputs, targets.float())
            loss.backward()
            optimizer.step()
            epoch_losses.append(loss.item())
        if progress_bar is not None and epoch_losses:
            iterator.set_postfix(loss=float(np.mean(epoch_losses)))
        loss_trajectory.append(epoch_losses)
    return loss_trajectory

In [None]:
cnn = CNN()
print(train_molecules.shape)
# Sanity-check the forward pass on one training batch. These are logits, not labels,
# and no gradients are needed, so skip building the autograd graph.
with torch.no_grad():
    sanity_logits = cnn(train_molecules)
print(sanity_logits[:5])

# A single epoch keeps this demo fast; increase it for a real training run.
N_EPOCHS = 1
loss_trj = train(cnn, training_data, N_EPOCHS)


## Confusion Matrix ##
A Confusion Matrix is a powerful tool used in machine learning to evaluate the performance of classification models. It is a table that visualizes the accuracy of a model by comparing the actual versus predicted classifications. The matrix is divided into four parts: True Positives (TP), where the model correctly predicts the positive class; True Negatives (TN), where the model correctly predicts the negative class; False Positives (FP), where the model incorrectly predicts the positive class; and False Negatives (FN), where the model incorrectly predicts the negative class. This breakdown allows not only for the calculation of overall accuracy but also for more nuanced performance metrics such as precision, recall, and the F1 score. By providing a detailed view of how a model is performing across different classes, the Confusion Matrix helps in identifying specific areas where the model may be struggling, making it invaluable for diagnosing and improving classification models.

In [None]:
from sklearn.metrics import confusion_matrix
import seaborn as sns

# Encode the held-out validation molecules exactly like the training set.
val_data = CustomDataset(
    val_dataset['smiles'].to_list(), val_dataset['HIV_active'].to_list(), max_length)
# Use a separate loader name so the training loader is not overwritten, and
# evaluate in mini-batches to bound memory use.
val_dataloader = DataLoader(val_data, batch_size=128, shuffle=False)

# Collect predicted and true class indices: argmax of the logits / one-hot labels.
cnn.eval()
true_batches, pred_batches = [], []
with torch.no_grad():
    for val_molecules, val_labels in val_dataloader:
        logits = cnn(val_molecules)
        pred_batches.append(logits.argmax(dim=1))
        true_batches.append(val_labels.argmax(dim=1))
true_labels = torch.cat(true_batches).numpy()
predictions = torch.cat(pred_batches).numpy()

# Compute confusion matrix (rows: true class, columns: predicted class)
cm = confusion_matrix(true_labels, predictions, labels=[0, 1])

# Plotting
fig, ax = plt.subplots()
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax)
ax.set_xlabel('Predicted labels')
ax.set_ylabel('True labels')
ax.set_title('Confusion Matrix')
ax.xaxis.set_ticklabels(['Class 0', 'Class 1'])
ax.yaxis.set_ticklabels(['Class 0', 'Class 1'])

plt.show()