# Module 1: Basics of Machine Learning
## Part 2: Classification

In this notebook, we illustrate how to use PyTorch to classify molecules into aqueous solubility categories.

### 1. Install and load python libraries

In [None]:
# Install the third-party packages that the import cell below relies on.
# NOTE(review): `rdkit-pypi` appears to be the legacy PyPI name of RDKit (newer
# releases are published as `rdkit`) — confirm which one the target environment needs.
!pip install torch numpy matplotlib scikit-learn pandas rdkit-pypi seaborn

In [None]:
# Required Libraries
import torch  # PyTorch main package
import torch.nn as nn  # Neural network modules
import torch.optim as optim  # Optimization algorithms
from torch.utils.data import TensorDataset, DataLoader
import numpy as np  # Numerical computations
import matplotlib.pyplot as plt  # Plotting library
from sklearn.datasets import make_regression  # To generate synthetic regression data
from sklearn.model_selection import train_test_split  # To split data into train/test sets
from sklearn.preprocessing import StandardScaler  # To standardize data
from sklearn.metrics import root_mean_squared_error, mean_absolute_error
import pandas as pd # Pandas for handling input data
from rdkit import Chem # Work with molecules
from rdkit.Chem import Draw # Draw molecules
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

### 2. Load and Prepare Solubility Data

In [None]:
# Load the curated AqSolDB table directly from GitHub. pandas reads URLs natively,
# so there is no need to shell out to wget (which is not available on every
# platform and saves duplicate copies such as data_curated.csv.1 on every re-run).
DATA_URL = 'https://raw.githubusercontent.com/mcsorkun/AqSolDB/refs/heads/master/results/data_curated.csv'
data = pd.read_csv(DATA_URL)

# Choose some features and define X.
# X is kept as the raw (unscaled) NumPy array: later cells index into it and
# apply scaler_X themselves.
descriptor_names = [
    'MolWt', 'MolLogP', 'MolMR', 'HeavyAtomCount', 'NumHAcceptors', 'NumHDonors',
    'NumHeteroatoms', 'NumRotatableBonds', 'NumValenceElectrons', 'NumAromaticRings',
    'NumSaturatedRings', 'NumAliphaticRings', 'RingCount', 'TPSA', 'LabuteASA',
    'BalabanJ', 'BertzCT',
]
X = data[descriptor_names].to_numpy()
# Solubility (logS) is the quantity that gets binned into classes
solub = data['Solubility'].to_numpy()

# Divide the solubility into categories
category_names = {0: "Soluble", 1: "Somewhat soluble", 2: "Insoluble"}
# Integer labels from the start; -1 marks "not assigned to any category yet"
y = np.full(solub.shape[0], -1, dtype=np.int64)
y[solub >= -1] = 0                       # Soluble
y[(solub >= -3) & (solub < -1)] = 1      # Somewhat soluble
y[solub < -3] = 2                        # Insoluble

# A value that matches none of the bins (e.g. NaN) would keep the -1 sentinel and
# only fail much later inside the loss function, so fail here with a clear message.
if (y < 0).any():
    raise ValueError(
        f"{int((y < 0).sum())} solubility value(s) could not be assigned to a category "
        "(missing or NaN entries in the 'Solubility' column?)"
    )

# Prepare the data set

# Split into training and test sets (80/20)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Scale the input features only (the class labels are left untouched).
# The scaler is fitted on the training split and reused for the test split.
scaler_X = StandardScaler()
X_train = scaler_X.fit_transform(X_train)
X_test = scaler_X.transform(X_test)

# Convert to PyTorch tensors with explicit dtypes:
# float32 features for the linear layers, int64 class indices for CrossEntropyLoss
X_train = torch.tensor(X_train, dtype=torch.float32)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
y_test = torch.tensor(y_test, dtype=torch.long)

### 3. Define Machine Learning Setup (Neural Network Model, Loss, Optimizer)

In [None]:
# ------------------------------
# Define the neural network with a selectable activation (ReLU, PReLU or Tanh)
# ------------------------------
class ClassificationNN(nn.Module):
    """Fully connected classifier: input -> 64 -> 32 -> num_classes (logits)."""

    # Supported activation names (matched case-insensitively) and their modules
    _ACTIVATIONS = {"relu": nn.ReLU, "prelu": nn.PReLU, "tanh": nn.Tanh}

    def __init__(self, input_size, num_classes, activation_func="ReLU", dropout_rate=0.0):
        """
        Initialize the model with the specified activation function.
        :param input_size: The size of the input features.
        :param num_classes: The number of output classes.
        :param activation_func: "ReLU" (default), "PReLU" or "Tanh"; case-insensitive.
        :param dropout_rate: Dropout probability applied after each hidden layer
                             (default 0.0, i.e. dropout is disabled).
        :raises ValueError: If activation_func is not one of the supported names.
        """
        super().__init__()

        try:
            activation_cls = self._ACTIVATIONS[str(activation_func).lower()]
        except KeyError:
            raise ValueError(
                f"Unknown activation_func {activation_func!r}; "
                'choose "ReLU", "PReLU" or "Tanh"'
            ) from None

        self.fc1 = nn.Linear(input_size, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, num_classes)  # Output layer (no activation, use CrossEntropyLoss)
        self.dropout = nn.Dropout(dropout_rate)  # Dropout for regularization
        # One activation module is reused after both hidden layers; with PReLU this
        # means both layers share the same learnable slope parameter.
        self.activation = activation_cls()

    def forward(self, x):
        """Return the raw class scores (logits) for a batch of feature vectors."""
        x = self.activation(self.fc1(x))  # First layer with activation
        x = self.dropout(x)
        x = self.activation(self.fc2(x))  # Second layer with activation
        x = self.dropout(x)
        x = self.fc3(x)                   # Output layer (logits)
        return x

In [None]:
# ------------------------------
# Initialize model, loss, optimizer
# ------------------------------
# Seed PyTorch so the weight initialization (and therefore the training run)
# is the same every time this cell is executed.
torch.manual_seed(42)

input_size = X_train.shape[1]
# Size the output layer from the defined categories. Counting the labels that
# happen to occur in y_train would produce too few outputs if a class were
# missing from the training split.
num_classes = len(category_names)

model = ClassificationNN(input_size, num_classes)

# Loss Function: CrossEntropyLoss includes LogSoftmax internally,
# so the model outputs raw logits
criterion = nn.CrossEntropyLoss()

# Optimizer: Adam
optimizer = optim.Adam(model.parameters(), lr=0.001)

### 4. Train the Model

In [None]:
epochs = 200
log_every = 10  # print progress every `log_every` epochs (the last epoch is always printed)
train_losses, test_losses = [], []
train_accuracies, test_accuracies = [], []

for epoch in range(epochs):
    # --- Training: one full-batch gradient step on the whole training set ---
    model.train()
    optimizer.zero_grad()
    outputs = model(X_train)
    loss = criterion(outputs, y_train)
    loss.backward()
    optimizer.step()
    train_losses.append(loss.item())

    # Training accuracy is taken from the logits computed *before* this epoch's
    # weight update, whereas the test metrics below are measured after it.
    predicted = outputs.argmax(dim=1)
    correct = (predicted == y_train).sum().item()
    train_accuracies.append(correct / len(y_train))

    # --- Evaluation on the held-out test set (no gradients needed) ---
    model.eval()
    with torch.no_grad():
        test_outputs = model(X_test)
        test_loss = criterion(test_outputs, y_test)
        test_losses.append(test_loss.item())

        test_pred = test_outputs.argmax(dim=1)
        correct_test = (test_pred == y_test).sum().item()
        test_accuracies.append(correct_test / len(y_test))

    # The full history is stored in the lists above; only a subset is printed
    # to keep the cell output readable.
    if (epoch + 1) % log_every == 0 or epoch + 1 == epochs:
        print(f"Epoch [{epoch+1}/{epochs}] "
              f"Train Loss: {loss.item():.4f}, Test Loss: {test_loss.item():.4f} "
              f"| Train Acc: {train_accuracies[-1]:.4f}, Test Acc: {test_accuracies[-1]:.4f}")

In [None]:
# ------------------------------
# Plot Loss and Accuracy
# ------------------------------
# Two side-by-side panels: loss curves on the left, accuracy curves on the right
fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 5))

# Left panel: training vs. test loss per epoch
ax_loss.plot(train_losses, label='Train Loss')
ax_loss.plot(test_losses, label='Test Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.legend()
ax_loss.set_title('Loss over Epochs')

# Right panel: training vs. test accuracy per epoch
ax_acc.plot(train_accuracies, label='Train Accuracy')
ax_acc.plot(test_accuracies, label='Test Accuracy')
ax_acc.set_xlabel('Epoch')
ax_acc.set_ylabel('Accuracy')
ax_acc.legend()
ax_acc.set_title('Accuracy over Epochs')

fig.tight_layout()
plt.show()

### 5. Classification Report and Confusion Matrix

In [None]:
# Predict a class for every test sample
model.eval()
with torch.no_grad():
    test_logits = model(X_test)
# y_pred holds the predicted class index per test sample (highest logit wins)
y_pred = test_logits.argmax(dim=1)

# Fix the class order explicitly so the report and the matrix always contain all
# categories, in the same order, even if one of them is absent from the test split.
class_ids = sorted(category_names)
class_labels = [category_names[class_id] for class_id in class_ids]

y_true_np = y_test.numpy()
y_pred_np = y_pred.numpy()

print("\nClassification Report:")
print(classification_report(y_true_np, y_pred_np, labels=class_ids, target_names=class_labels))

# Rows are the actual classes, columns the predicted ones
cm = confusion_matrix(y_true_np, y_pred_np, labels=class_ids)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_labels, yticklabels=class_labels)
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix')
plt.show()

### 6. Test predictions

In [None]:
# Try octane, benzene, ethanol, acetic acid
compound_name = "benzene"

# Row positions of all entries whose name matches exactly
matching_rows = np.flatnonzero((data['Name'] == compound_name).to_numpy())
if matching_rows.size == 0:
    # Without this check a misspelled name surfaces as an opaque IndexError
    raise ValueError(f"Compound {compound_name!r} was not found in the 'Name' column of the data set")

# Positional index of the first match (used to look up the row in `data` and in `X`)
compound_index = int(matching_rows[0])

In [None]:
# Display every row of the table whose name equals the selected compound
data.loc[data['Name'].eq(compound_name)]

In [None]:
# Visualize molecules in the data set
s = data['SMILES'][compound_index]
mol = Chem.MolFromSmiles(s)
Draw.MolToImage(mol)

In [None]:
# Predict the solubility category of the selected compound.
# Put the model in evaluation mode explicitly instead of relying on an earlier
# cell having left it that way.
model.eval()

# Apply the same scaling that was fitted on the training features;
# reshape(1, -1) turns the single descriptor vector into a batch of one sample.
X_example = scaler_X.transform(X[compound_index].reshape(1, -1))

with torch.no_grad():
    example_logits = model(torch.tensor(X_example, dtype=torch.float32))

# Index of the highest logit is the predicted class
compound_class = example_logits.argmax(dim=1)
print(category_names[compound_class.item()])