In [1]:
!pip install rdkit torch torch-geometric pandas numpy scikit-learn tqdm



In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import AttentiveFP
from rdkit import Chem
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_absolute_error
import numpy as np
import pandas as pd
from tqdm import tqdm

# ---------------------- Data Processing ----------------------

def load_data(file_path, smiles_column="SMILES", target_column="MLM"):
    df = pd.read_csv(file_path).dropna(subset=[smiles_column, target_column])
    df = df[df[target_column] > 0]  # Remove zero/negative values
    df[target_column] = np.log1p(df[target_column])  # Log1p transformation
    return df, smiles_column, target_column

def smiles_to_graph(smiles, target):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    atom_features = []
    for atom in mol.GetAtoms():
        atomic_num = atom.GetAtomicNum()
        one_hot = [0] * 10
        one_hot[min(atomic_num, 9)] = 1
        features = one_hot + [
            atom.GetDegree() / 4.0,
            atom.GetFormalCharge() / 5.0,
            int(atom.GetHybridization()) / 6.0,
            float(atom.GetMass()) / 200.0,
            int(atom.GetIsAromatic()),
            atom.GetTotalNumHs() / 4.0,
            atom.GetExplicitValence() / 8.0
        ]
        atom_features.append(features)

    edge_index, edge_attr = [], []
    for bond in mol.GetBonds():
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        bond_features = [
            int(bond.GetBondType() == Chem.rdchem.BondType.SINGLE),
            int(bond.GetBondType() == Chem.rdchem.BondType.DOUBLE),
            int(bond.GetBondType() == Chem.rdchem.BondType.TRIPLE),
            int(bond.GetBondType() == Chem.rdchem.BondType.AROMATIC),
            int(bond.GetIsConjugated())
        ]
        edge_index.extend([[start, end], [end, start]])
        edge_attr.extend([bond_features, bond_features])

    if not edge_index:
        return None

    return Data(
        x=torch.tensor(atom_features, dtype=torch.float),
        edge_index=torch.tensor(edge_index, dtype=torch.long).t().contiguous(),
        edge_attr=torch.tensor(edge_attr, dtype=torch.float),
        y=torch.tensor([target], dtype=torch.float)
    )

# ---------------------- Model Definition ----------------------

class EnhancedAttentiveFP(nn.Module):
    def __init__(self, node_feat_size, edge_feat_size, num_layers=4, num_timesteps=6, hidden_channels=256, dropout=0.25):
        super(EnhancedAttentiveFP, self).__init__()
        self.gnn = AttentiveFP(
            in_channels=node_feat_size,
            hidden_channels=hidden_channels,
            out_channels=hidden_channels,
            edge_dim=edge_feat_size,
            num_layers=num_layers,
            num_timesteps=num_timesteps,
            dropout=dropout
        )
        self.linear = nn.Linear(hidden_channels, hidden_channels)
        self.norm = nn.LayerNorm(hidden_channels)
        self.relu = nn.ReLU()
        self.output_layer = nn.Linear(hidden_channels, 1)

    def forward(self, x, edge_index, edge_attr, batch):
        x = self.gnn(x, edge_index, edge_attr, batch)
        x = self.linear(x)
        x = self.norm(x)
        x = self.relu(x)
        return self.output_layer(x).squeeze(1)


# ---------------------- Training Function ----------------------

def train_model(data, epochs=200, batch_size=64, lr=3e-4, patience=15, weight_decay=1e-5, grad_accum_steps=2):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_data, val_data = train_test_split(data, test_size=0.2, random_state=42)
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_data, batch_size=batch_size, drop_last=False)

    num_node_features = data[0].x.size(1)
    num_edge_features = data[0].edge_attr.size(1)

    model = EnhancedAttentiveFP(num_node_features, num_edge_features, num_layers=4, num_timesteps=6).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5, min_lr=1e-6)
    criterion = nn.L1Loss()

    best_r2 = -float("inf")
    best_mae = float("inf")
    early_stop_counter = 0
    val_metrics = []

    print("Starting validation phase...")
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        optimizer.zero_grad()

        for step, batch in enumerate(train_loader):
            batch = batch.to(device)
            out = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
            loss = criterion(out, batch.y.view(-1))
            loss.backward()

            if (step + 1) % grad_accum_steps == 0 or step == len(train_loader) - 1:
                optimizer.step()
                optimizer.zero_grad()

            total_loss += loss.item()

        # Validation
        model.eval()
        preds, labels = [], []
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)
                out = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
                preds.extend(out.cpu().numpy())
                labels.extend(batch.y.view(-1).cpu().numpy())


        # Convert back from log1p space for evaluation metrics
        all_labels_exp = np.expm1(all_labels)
        all_preds_exp = np.expm1(all_preds)

        r2 = r2_score(all_labels_exp, all_preds_exp)
        mae = mean_absolute_error(all_labels_exp, all_preds_exp)
        val_metrics.append({
            'epoch': epoch + 1,
            'train_loss': total_loss / len(train_loader),
            'val_r2': r2,
            'val_mae': mae
        })

        print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}, R²: {r2:.4f}, MAE: {mae:.4f}")

        scheduler.step(r2)

        if r2 > best_r2 or (r2 == best_r2 and mae < best_mae):
            best_r2 = r2
            best_mae = mae
            early_stop_counter = 0

        else:
            early_stop_counter += 1
            if early_stop_counter >= patience:
                print("Early stopping.")
                break

    pd.DataFrame(val_metrics).to_csv("validation_metrics.csv", index=False)

    # Full training
    print("\nTraining on full dataset...")
    full_loader = DataLoader(data, batch_size=batch_size, shuffle=True)
    best_loss = float("inf")
    early_stop_counter = 0
    full_metrics = []

    for epoch in range(epochs // 2):
        model.train()
        total_loss = 0.0
        optimizer.zero_grad()

        for step, batch in enumerate(full_loader):
            batch = batch.to(device)
            out = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
            loss = criterion(out, batch.y.view(-1))
            loss.backward()

            if (step + 1) % grad_accum_steps == 0 or step == len(full_loader) - 1:
                optimizer.step()
                optimizer.zero_grad()

            total_loss += loss.item()

        avg_loss = total_loss / len(full_loader)
        full_metrics.append({'epoch': epoch + 1, 'train_loss': avg_loss})
        print(f"Full Data Epoch {epoch+1}, Loss: {avg_loss:.4f}")

        if avg_loss < best_loss:
            best_loss = avg_loss
            early_stop_counter = 0
            torch.save(model.state_dict(), "best_model_mlmfull.pth")
        else:
            early_stop_counter += 1
            if early_stop_counter >= patience:
                print("Early stopping in full training.")
                break


    return model

# ---------------------- Run Training ----------------------

if __name__ == '__main__':
    file_path = "/content/MLM_MERGED.csv"
    df, smiles_col, target_col = load_data(file_path)

    print("Converting SMILES to graph...")
    graph_data = [smiles_to_graph(row[smiles_col], row[target_col]) for _, row in tqdm(df.iterrows(), total=len(df))]
    graph_data = [g for g in graph_data if g is not None]
    print(f"Loaded {len(graph_data)} valid molecules.")

    trained_model = train_model(graph_data)


Converting SMILES to graph...


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
100%|██████████| 2074/2074 [00:05<00:00, 375.66it/s]


Loaded 2074 valid molecules.
Starting validation phase...
Epoch 1, Loss: 143.6198, R²: -1.5035, MAE: 4.4873
Epoch 2, Loss: 101.3355, R²: -0.9313, MAE: 3.9247
Epoch 3, Loss: 91.0570, R²: -0.6542, MAE: 3.6068
Epoch 4, Loss: 84.8316, R²: -0.4490, MAE: 3.3546
Epoch 5, Loss: 79.1362, R²: -0.2922, MAE: 3.1445
Epoch 6, Loss: 75.2806, R²: -0.1815, MAE: 2.9873
Epoch 7, Loss: 72.7240, R²: -0.1033, MAE: 2.8687
Epoch 8, Loss: 70.2624, R²: -0.0532, MAE: 2.7841
Epoch 9, Loss: 68.9254, R²: -0.0241, MAE: 2.7273
Epoch 10, Loss: 68.0976, R²: -0.0083, MAE: 2.6844
Epoch 11, Loss: 67.8210, R²: -0.0016, MAE: 2.6545
Epoch 12, Loss: 67.1641, R²: -0.0001, MAE: 2.6365
Epoch 13, Loss: 67.3854, R²: -0.0005, MAE: 2.6292
Epoch 14, Loss: 67.2251, R²: -0.0014, MAE: 2.6245
Epoch 15, Loss: 67.3845, R²: -0.0026, MAE: 2.6222
Epoch 16, Loss: 67.0242, R²: -0.0007, MAE: 2.6229
Epoch 17, Loss: 66.7060, R²: 0.0006, MAE: 2.6208
Epoch 18, Loss: 66.8955, R²: 0.0172, MAE: 2.6271
Epoch 19, Loss: 66.1187, R²: 0.0418, MAE: 2.5771
Ep