<a href="https://colab.research.google.com/github/Swayamprakashpatel/Sol_ME/blob/main/GNN_Solubility.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
# Install necessary libraries
!pip install rdkit-pypi torch torch-geometric scikit-learn matplotlib

# Import required libraries
import pandas as pd
import numpy as np
import torch
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool
from sklearn.model_selection import train_test_split
from rdkit import Chem
import matplotlib.pyplot as plt

# Load the data
data = pd.read_csv('/content/GNN_Smiles.csv')

# Define a function to convert SMILES to graph data
def smiles_to_graph(smiles):
    mol = Chem.MolFromSmiles(smiles)
    edges = []
    for bond in mol.GetBonds():
        edges.append((bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()))
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    x = torch.tensor([atom.GetAtomicNum() for atom in mol.GetAtoms()], dtype=torch.float).view(-1, 1)
    return Data(x=x, edge_index=edge_index)

# Convert SMILES to graph data
drug_graphs = [smiles_to_graph(smiles) for smiles in data['Drug_Smile']]
solvent_graphs = [smiles_to_graph(smiles) for smiles in data['Solvent_Smile']]

# Combine drug and solvent graphs
def combine_graphs(drug_graph, solvent_graph):
    x = torch.cat([drug_graph.x, solvent_graph.x], dim=0)
    edge_index = torch.cat([drug_graph.edge_index, solvent_graph.edge_index + drug_graph.x.size(0)], dim=1)
    return Data(x=x, edge_index=edge_index)

graphs = [combine_graphs(drug, solvent) for drug, solvent in zip(drug_graphs, solvent_graphs)]
labels = torch.tensor(data['Solubility'].values, dtype=torch.float)

# Prepare the dataset
train_graphs, test_graphs, train_labels, test_labels = train_test_split(graphs, labels, test_size=0.2, random_state=42)

# Define the GNN model
class GNN(torch.nn.Module):
    def __init__(self):
        super(GNN, self).__init__()
        self.conv1 = GCNConv(1, 64)
        self.conv2 = GCNConv(64, 128)
        self.fc = torch.nn.Linear(128, 1)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = global_mean_pool(x, data.batch)
        x = self.fc(x)
        return x

# Define early stopping and model checkpointing based on RMSE
class EarlyStopping:
    def __init__(self, patience=15, delta=0):
        self.patience = patience
        self.delta = delta
        self.best_rmse = float('inf')
        self.wait = 0
        self.stop = False

    def __call__(self, val_rmse):
        if val_rmse < self.best_rmse - self.delta:
            self.best_rmse = val_rmse
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stop = True

early_stopping = EarlyStopping(patience=15, delta=0.01)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GNN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.MSELoss()

# Initialize lists to store metrics and predictions
train_losses = []
val_rmse_list = []
train_predictions = []
train_actuals = []
test_predictions = []
test_actuals = []

# Define functions for training and validation
def train():
    model.train()
    total_loss = 0
    for graph, label in zip(train_graphs, train_labels):
        graph = graph.to(device)
        label = label.to(device)
        optimizer.zero_grad()
        out = model(graph)
        loss = criterion(out.squeeze(), label)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        # Collect training predictions
        train_predictions.append(out.cpu().detach().numpy())
        train_actuals.append(label.cpu().numpy())

    return total_loss / len(train_graphs)

def validate():
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for graph, label in zip(test_graphs, test_labels):
            graph = graph.to(device)
            label = label.to(device)
            out = model(graph)
            loss = criterion(out.squeeze(), label)
            val_loss += loss.item()

            # Collect test predictions
            test_predictions.append(out.cpu().detach().numpy())
            test_actuals.append(label.cpu().numpy())

    rmse = np.sqrt(val_loss / len(test_graphs))
    return rmse

# Training loop with early stopping and model checkpointing based on RMSE
num_epochs = 10
best_rmse = float('inf')

for epoch in range(num_epochs):
    train_loss = train()
    val_rmse = validate()

    train_losses.append(train_loss)
    val_rmse_list.append(val_rmse)

    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val RMSE: {val_rmse:.4f}')

    if val_rmse < best_rmse:
        best_rmse = val_rmse
        torch.save(model.state_dict(), 'best_model.pth')  # Save the model checkpoint

    early_stopping(val_rmse)
    if early_stopping.stop:
        print('Early stopping triggered.')
        break

# Load the best model for evaluation
model.load_state_dict(torch.load('best_model.pth'))

# Evaluate the model
model.eval()
predictions = []
with torch.no_grad():
    for graph in test_graphs:
        graph = graph.to(device)
        pred = model(graph)
        predictions.append(pred.cpu().numpy())
predictions = np.concatenate(predictions)

print("Predictions:", predictions)
print("True Labels:", test_labels.numpy())

# Plotting the metrics
epochs = range(1, len(train_losses) + 1)

plt.figure(figsize=(12, 12))

# Training Loss
plt.subplot(3, 1, 1)
plt.plot(epochs, train_losses, 'b', label='Train Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Epoch vs Train Loss')
plt.legend()

# Validation RMSE
plt.subplot(3, 1, 2)
plt.plot(epochs, val_rmse_list, 'r', label='Validation RMSE')
plt.xlabel('Epochs')
plt.ylabel('RMSE')
plt.title('Epoch vs Validation RMSE')
plt.legend()

# Predicted vs Actual (Train)
plt.subplot(3, 1, 3)
train_predictions = np.concatenate(train_predictions)
train_actuals = np.concatenate(train_actuals)
plt.scatter(train_actuals, train_predictions, alpha=0.5)
plt.plot([train_actuals.min(), train_actuals.max()], [train_actuals.min(), train_actuals.max()], 'r--', lw=2)
plt.xlabel('Actual Values')
plt.ylabel('Predicted Values')
plt.title('Train Predicted vs Actual')
plt.grid(True)

plt.tight_layout()
plt.show()

# Plot Predicted vs Actual (Test)
plt.figure(figsize=(6, 6))
test_predictions = np.concatenate(test_predictions)
test_actuals = np.concatenate(test_actuals)
plt.scatter(test_actuals, test_predictions, alpha=0.5)
plt.plot([test_actuals.min(), test_actuals.max()], [test_actuals.min(), test_actuals.max()], 'r--', lw=2)
plt.xlabel('Actual Values')
plt.ylabel('Predicted Values')
plt.title('Test Predicted vs Actual')
plt.grid(True)
plt.show()


Epoch 1/100, Train Loss: 6304.4975, Val RMSE: 83.4173
Epoch 2/100, Train Loss: 6218.9940, Val RMSE: 83.1607
Epoch 3/100, Train Loss: 6181.1014, Val RMSE: 82.9472
Epoch 4/100, Train Loss: 6161.1046, Val RMSE: 82.8024
Epoch 5/100, Train Loss: 6150.0767, Val RMSE: 82.6904
Epoch 6/100, Train Loss: 6143.7469, Val RMSE: 82.6122
Epoch 7/100, Train Loss: 6136.6788, Val RMSE: 82.5472
Epoch 8/100, Train Loss: 6131.9603, Val RMSE: 82.4777
Epoch 9/100, Train Loss: 6127.5112, Val RMSE: 82.4062
Epoch 10/100, Train Loss: 6122.4144, Val RMSE: 82.3441
Epoch 11/100, Train Loss: 6118.4768, Val RMSE: 82.2911
Epoch 12/100, Train Loss: 6115.3136, Val RMSE: 82.2416
Epoch 13/100, Train Loss: 6112.3584, Val RMSE: 82.2004
Epoch 14/100, Train Loss: 6110.2681, Val RMSE: 82.1740
Epoch 15/100, Train Loss: 6108.9904, Val RMSE: 82.1546
Epoch 16/100, Train Loss: 6107.8984, Val RMSE: 82.1348
Epoch 17/100, Train Loss: 6107.1908, Val RMSE: 82.1124
Epoch 18/100, Train Loss: 6105.5338, Val RMSE: 82.0940
Epoch 19/100, Train

KeyboardInterrupt: 