# Libraries & Dependencies

In [1]:
# ! pip uninstall -y numpy scipy scikit-learn pandas

In [2]:
# ! pip install "numpy<2.0" "scipy==1.11.4" "scikit-learn==1.2.2" pandas

In [4]:
#! pip install torch-geometric

In [5]:
# conda install -c conda-forge pytdc

In [51]:
import re
import json
import logging
from typing import List, Dict, Union, Tuple, Optional, Set

import numpy as np
import torch
from rdkit import Chem, RDLogger
from rdkit.Chem import AllChem, rdFingerprintGenerator
from torch_geometric.data import Data
from torch.utils.data import Dataset, DataLoader

import os
import yaml
import pandas as pd
from tdc.single_pred import ADME, Tox
from tdc.generation import MolGen

import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_add_pool

import torch.optim as optim
from sklearn.metrics import roc_auc_score, mean_absolute_error, average_precision_score
from scipy.stats import spearmanr
from tqdm import tqdm

from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader as PyGDataLoader
from torch_geometric.nn import GCNConv, global_add_pool

In [52]:
RDLogger.DisableLog('rdApp.*')

In [53]:
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("Data")

# Data Preparation

In [54]:
class SMILESTokenizer:
    """
    Regex-based tokenizer for chemical formulas.
    Essential for Sequence models (CNN, Mamba).
    """
    def __init__(self, vocab_file: str = None, max_len: int = 128):
        self.max_len = max_len
        self.pad_token = "<pad>"
        self.unk_token = "<unk>"
        self.sos_token = "<sos>"
        self.eos_token = "<eos>"
        self.token_pattern = re.compile(r"(\[[^\]]+]|Br?|Cl?|N|O|S|"
            "P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])")
        
        if vocab_file:
            self.load_vocab(vocab_file)
        else:
            self.vocab = {
                self.pad_token: 0,
                self.sos_token: 1,
                self.eos_token: 2,
                self.unk_token: 3
            }
            self.inverse_vocab = {v: k for k, v in self.vocab.items()}

    def train(self, smiles_list: List[str]):
        """Build vocabulary from a list of SMILES."""
        unique_tokens = set()
        for smi in smiles_list:
            tokens = self.token_pattern.findall(smi)
            unique_tokens.update(tokens)
        
        start_idx = len(self.vocab)
        for i, token in enumerate(sorted(unique_tokens)):
            self.vocab[token] = start_idx + i
            
        self.inverse_vocab = {v: k for k, v in self.vocab.items()}
        logger.info(f"Tokenizer trained. Vocab size: {len(self.vocab)}")

    def encode(self, smiles: str) -> torch.Tensor:
        """SMILES -> LongTensor [max_len]"""
        tokens = self.token_pattern.findall(smiles)
        ids = [self.vocab.get(t, self.vocab[self.unk_token]) for t in tokens]
        ids = [self.vocab[self.sos_token]] + ids + [self.vocab[self.eos_token]]
        
        if len(ids) < self.max_len:
            ids += [self.vocab[self.pad_token]] * (self.max_len - len(ids))
        else:
            ids = ids[:self.max_len-1] + [self.vocab[self.eos_token]]
            
        return torch.tensor(ids, dtype=torch.long)

    def save_vocab(self, path: str):
        with open(path, 'w') as f:
            json.dump(self.vocab, f)

    def load_vocab(self, path: str):
        with open(path, 'r') as f:
            self.vocab = json.load(f)
        self.inverse_vocab = {v: k for k, v in self.vocab.items()}

    def __len__(self):
        return len(self.vocab)

In [55]:
class MolFeaturizer:
    """Factory class to generate different molecular representations."""
    
    @staticmethod
    def smiles_to_morgan(smiles: str, radius: int = 2, n_bits: int = 2048) -> torch.Tensor:
        """Generates Morgan Fingerprint (ECFP)."""
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return torch.zeros(n_bits)
        fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
        arr = np.zeros((0,), dtype=np.int8)
        Chem.DataStructs.ConvertToNumpyArray(fp, arr)
        return torch.tensor(arr, dtype=torch.float32)

    @staticmethod
    def smiles_to_graph(smiles: str) -> Data:
        """
        Generates PyTorch Geometric Data object.
        Simple node features: AtomicNum (one-hot or integer).
        """
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return Data(x=torch.zeros((1, 1)), edge_index=torch.zeros((2, 0)))
        atom_features = []
        for atom in mol.GetAtoms():
            atom_features.append(atom.GetAtomicNum())
        
        x = torch.tensor(atom_features, dtype=torch.long).unsqueeze(1) # [N, 1]
        rows, cols = [], []
        for bond in mol.GetBonds():
            start = bond.GetBeginAtomIdx()
            end = bond.GetEndAtomIdx()
            rows.extend([start, end])
            cols.extend([end, start])
            
        edge_index = torch.tensor([rows, cols], dtype=torch.long)
        
        return Data(x=x, edge_index=edge_index)

In [56]:
class MolecularDataset(Dataset):
    def __init__(self, smiles_list, labels, featurizer):
        """
        Обновленный класс: принимает featurizer (функцию/объект), 
        а не строку modality.
        """
        self.smiles = smiles_list
        self.labels = labels
        self.featurizer = featurizer

    def __len__(self):
        return len(self.smiles)

    def __getitem__(self, idx):
        smi = self.smiles[idx]
        label = self.labels[idx]
        x = self.featurizer(smi)
        if label is None or np.isnan(label): 
            label = 0.0 
            
        y = torch.tensor([label], dtype=torch.float32)
        
        return x, y

In [57]:
DATA_DIR = "data"
RAW_DIR = os.path.join(DATA_DIR, "raw")
PROCESSED_DIR = os.path.join(DATA_DIR, "processed")
TOKENIZER_PATH = os.path.join(DATA_DIR, "tokenizer.json")
MAX_SEQ_LEN = 128  # Максимальная длина SMILES



def ensure_dirs():
    if not os.path.exists(RAW_DIR):
        os.makedirs(RAW_DIR)
    if not os.path.exists(PROCESSED_DIR):
        os.makedirs(PROCESSED_DIR)

def process_and_download_all():
    ensure_dirs()
    
    logger.info("Downloading ZINC for pretraining...")
    data = MolGen(name='ZINC', path=RAW_DIR) 
    split = data.get_split()
    
    train_smiles = split['train']['smiles'][:250000].tolist()
    
    logger.info(f"ZINC loaded. Using {len(train_smiles)} molecules for tokenizer training.")

    logger.info("Training Tokenizer...")
    tokenizer = SMILESTokenizer(max_len=MAX_SEQ_LEN)
    tokenizer.train(train_smiles)
    tokenizer.save_vocab(TOKENIZER_PATH)
    logger.info(f"Tokenizer saved to {TOKENIZER_PATH}")
    logger.info(f"Vocab size: {len(tokenizer.vocab)}")

    tasks = {
        'Caco2_Wang': 'Caco2',
        'HIA_Hou': 'HIA',
        'Pgp_Broccatelli': 'Pgp',
        'Bioavailability_Ma': 'Bioav',
        'Lipophilicity_AstraZeneca': 'Lipo',
        'Solubility_AqSolDB': 'AqSol'
    }
    
    for tdc_name, paper_name in tasks.items():
        logger.info(f"Downloading downstream task: {paper_name} ({tdc_name})...")
        try:
            data = ADME(name=tdc_name, path=RAW_DIR)
            split = data.get_split(method='scaffold', seed=42, frac=[0.8, 0.1, 0.1])
            
            save_task_dir = os.path.join(PROCESSED_DIR, paper_name)
            os.makedirs(save_task_dir, exist_ok=True)
            
            split['train'].to_csv(os.path.join(save_task_dir, 'train.csv'), index=False)
            split['val'].to_csv(os.path.join(save_task_dir, 'val.csv'), index=False)
            split['test'].to_csv(os.path.join(save_task_dir, 'test.csv'), index=False)
            
            logger.info(f"Task {paper_name} processed and saved to {save_task_dir}")
            
        except Exception as e:
            logger.error(f"Error downloading {paper_name}: {e}")



In [58]:
# process_and_download_all()
# logger.info("All data prepared successfully!")

# Feature Engineering

In [59]:
class MorganFeaturizer:
    def __init__(self, radius=2, n_bits=1024):
        self.generator = rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=n_bits)
        self.n_bits = n_bits
        
    def __call__(self, smiles):
        mol = Chem.MolFromSmiles(smiles)
        if mol is None: return torch.zeros(self.n_bits)
        return torch.tensor(self.generator.GetFingerprintAsNumPy(mol), dtype=torch.float32)

In [60]:
class GraphFeaturizer:
    def __call__(self, smiles: str):
        if not HAS_PYG: return None
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # Заглушка тоже должна быть (1, 9)
            return Data(x=torch.zeros((1, 9)), edge_index=torch.empty((2, 0), dtype=torch.long))

        # Атомы
        atom_features = []
        for atom in mol.GetAtoms():
            # Базовые 5 фич
            feats = [
                atom.GetAtomicNum(),
                atom.GetDegree(),
                atom.GetExplicitValence(),
                int(atom.GetIsAromatic()),
                atom.GetFormalCharge()
            ]
            
            # !!! ИСПРАВЛЕНИЕ: Дополняем нулями до 9, чтобы совпадало с GCN(node_features=9) !!!
            padding = [0] * (9 - len(feats))
            feats += padding
            
            atom_features.append(feats)
            
        x = torch.tensor(atom_features, dtype=torch.float32)

        # Связи
        rows, cols = [], []
        for bond in mol.GetBonds():
            start = bond.GetBeginAtomIdx()
            end = bond.GetEndAtomIdx()
            rows.extend([start, end])
            cols.extend([end, start])

        if len(rows) == 0:
            edge_index = torch.empty((2, 0), dtype=torch.long)
        else:
            edge_index = torch.tensor([rows, cols], dtype=torch.long)

        return Data(x=x, edge_index=edge_index)

In [61]:
class CharTokenizer:
    """
    Modality 3: Sequence (for SMILES-CNN, Mamba)
    Character-level tokenization.
    """
    def __init__(self, max_len: int = 128, vocab_path: Optional[str] = None):
        self.max_len = max_len
        self.pad_token = "<pad>"
        self.sos_token = "<sos>"
        self.eos_token = "<eos>"
        self.unk_token = "<unk>"
        
        base_chars = [
            'C', 'c', 'O', 'o', 'N', 'n', 'S', 's',
            '=', '#', '(', ')', '[', ']', '1', '2', '3', 
            '+', '-', '.', 'F', 'Cl', 'Br', 'I', 'H'
        ]
        
        self.vocab = {
            self.pad_token: 0,
            self.sos_token: 1,
            self.eos_token: 2,
            self.unk_token: 3
        }
        
        for char in base_chars:
            if char not in self.vocab:
                self.vocab[char] = len(self.vocab)
                
        self.inverse_vocab = {v: k for k, v in self.vocab.items()}
        
        if vocab_path:
            self.load_vocab(vocab_path)

    def fit_on_zinc(self, smiles_list: List[str]):
        """
        Scans ZINC dataset to add missing characters (e.g., '4', '5', 'P', '/', '\')
        to the vocabulary to ensure coverage.
        """
        unique_chars = set()
        for smi in smiles_list:
            unique_chars.update(list(smi))
            
        for char in sorted(list(unique_chars)):
            if char not in self.vocab:
                self.vocab[char] = len(self.vocab)
        
        self.inverse_vocab = {v: k for k, v in self.vocab.items()}
        print(f"Tokenizer fitted. Vocab size: {len(self.vocab)}")

    def encode(self, smiles: str) -> torch.Tensor:
        """
        Process: String -> List[Int] -> Pad/Truncate -> Tensor
        """
        chars = list(smiles)
        
        ids = [self.vocab.get(c, self.vocab[self.unk_token]) for c in chars]
        
        ids = [self.vocab[self.sos_token]] + ids + [self.vocab[self.eos_token]]
        
        if len(ids) < self.max_len:
            padding = [self.vocab[self.pad_token]] * (self.max_len - len(ids))
            ids = ids + padding
        else:
            ids = ids[:self.max_len - 1] + [self.vocab[self.eos_token]]
            
        return torch.tensor(ids, dtype=torch.long)

    def save_vocab(self, path: str):
        with open(path, 'w') as f:
            json.dump(self.vocab, f)
            
    def load_vocab(self, path: str):
        with open(path, 'r') as f:
            self.vocab = json.load(f)
        self.inverse_vocab = {v: k for k, v in self.vocab.items()}

In [62]:
smi = "CN1C=NC2=C1C(=O)N(C(=O)N2C)C"

morgan = MorganFeaturizer(radius=2, n_bits=1024)
print(f"Morgan Shape: {morgan(smi).shape}")

graph_gen = GraphFeaturizer()
graph_data = graph_gen(smi)
print(f"Graph Nodes: {graph_data.x.shape}")
print(f"Graph Edges: {graph_data.edge_index.shape}")

tokenizer = CharTokenizer(max_len=128)
tokens = tokenizer.encode(smi)
print(f"Tokens Shape: {tokens.shape}")
print(f"Tokens: {tokens[:15]}")

Morgan Shape: torch.Size([1024])
Graph Nodes: torch.Size([14, 9])
Graph Edges: torch.Size([2, 30])
Tokens Shape: torch.Size([128])
Tokens: tensor([ 1,  4,  8, 18,  4, 12,  8,  4, 19, 12,  4, 18,  4, 14, 12])


# Models

In [82]:
class BaseModel(nn.Module):
    """Base class with utility to count parameters."""
    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


In [83]:
# ==========================================
# Baseline 1: Morgan + MLP
# ==========================================
class MorganMLP(BaseModel):
    def __init__(self, input_dim=1024, output_dim=1, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 1024),
            nn.ReLU(),
            nn.Dropout(dropout),
            
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            
            nn.Linear(512, 128),
            nn.ReLU(),

            nn.Linear(128, output_dim)
        )

    def forward(self, x):
        return self.net(x)

In [84]:
MorganMLP()

MorganMLP(
  (net): Sequential(
    (0): Linear(in_features=1024, out_features=1024, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=1024, out_features=512, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.1, inplace=False)
    (6): Linear(in_features=512, out_features=128, bias=True)
    (7): ReLU()
    (8): Linear(in_features=128, out_features=1, bias=True)
  )
)

In [64]:
# ==========================================
# Baseline 2: SMILES + CNN
# ==========================================
class SmilesCNN(BaseModel):
    def __init__(self, vocab_size, embed_dim=64, output_dim=1, max_len=128):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        
        self.conv1 = nn.Conv1d(in_channels=embed_dim, out_channels=32, kernel_size=4)
        self.conv2 = nn.Conv1d(in_channels=32, out_channels=64, kernel_size=6)
        self.conv3 = nn.Conv1d(in_channels=64, out_channels=96, kernel_size=8)
        
        self.mlp = nn.Sequential(
            nn.Linear(96, 32),
            nn.ReLU(),
            nn.Linear(32, 32),
            nn.Linear(32, output_dim)
        )

    def forward(self, x):
        x = self.embedding(x)
        x = x.transpose(1, 2)
        
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        
        x = torch.max(x, dim=2)[0]
        
        return self.mlp(x)

In [86]:
SmilesCNN(64)

SmilesCNN(
  (embedding): Embedding(64, 64, padding_idx=0)
  (conv1): Conv1d(64, 32, kernel_size=(4,), stride=(1,))
  (conv2): Conv1d(32, 64, kernel_size=(6,), stride=(1,))
  (conv3): Conv1d(64, 96, kernel_size=(8,), stride=(1,))
  (mlp): Sequential(
    (0): Linear(in_features=96, out_features=32, bias=True)
    (1): ReLU()
    (2): Linear(in_features=32, out_features=32, bias=True)
    (3): Linear(in_features=32, out_features=1, bias=True)
  )
)

In [65]:
# ==========================================
# Baseline 3: GCN
# ==========================================
class GCN(nn.Module):
    def __init__(self, node_features=9, hidden_dim=100):
        super().__init__()
        self.convs = nn.ModuleList([GCNConv(node_features, hidden_dim)])
        for _ in range(4): self.convs.append(GCNConv(hidden_dim, hidden_dim))
        self.head = nn.Linear(hidden_dim, 1)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
        x = global_add_pool(x, batch)
        return self.head(x)

In [87]:
GCN()

GCN(
  (convs): ModuleList(
    (0): GCNConv(9, 100)
    (1-4): 4 x GCNConv(100, 100)
  )
  (head): Linear(in_features=100, out_features=1, bias=True)
)

In [66]:
# ==========================================
# Baseline 4: NeuralFP (Frozen Encoder)
# ==========================================
class NeuralFP(BaseModel):
    def __init__(self, pretrained_gcn: GCN, output_dim=1):
        super().__init__()
        self.encoder = pretrained_gcn
        self.encoder.head = nn.Identity()
        
        for param in self.encoder.parameters():
            param.requires_grad = False
            
        self.decoder = nn.Sequential(
            nn.Linear(100, 200),
            nn.ReLU(),
            nn.Linear(200, 100),
            nn.ReLU(),
            nn.Linear(100, 50),
            nn.ReLU(),
            nn.Linear(50, output_dim)
        )

    def forward(self, data):
        with torch.no_grad():
            embedding = self.encoder(data)

        return self.decoder(embedding)
    
    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

In [89]:
NeuralFP(GCN())

NeuralFP(
  (encoder): GCN(
    (convs): ModuleList(
      (0): GCNConv(9, 100)
      (1-4): 4 x GCNConv(100, 100)
    )
    (head): Identity()
  )
  (decoder): Sequential(
    (0): Linear(in_features=100, out_features=200, bias=True)
    (1): ReLU()
    (2): Linear(in_features=200, out_features=100, bias=True)
    (3): ReLU()
    (4): Linear(in_features=100, out_features=50, bias=True)
    (5): ReLU()
    (6): Linear(in_features=50, out_features=1, bias=True)
  )
)

# Pretraining

In [67]:
class UniversalTrainer:
    def __init__(self, model, params, task_config, device):
        self.model = model.to(device)
        self.task_type = task_config['type']
        self.metric_name = task_config['metric']
        self.device = device
        self.optimizer = optim.AdamW(model.parameters(), lr=params['learning_rate'])
        self.criterion = nn.BCEWithLogitsLoss() if self.task_type == 'classification' else nn.L1Loss()

    def run(self, train, val, test, epochs):
        train_l = get_dataloader(train, 32, True)
        val_l = get_dataloader(val, 32, False)
        test_l = get_dataloader(test, 32, False)
        
        higher_is_better = self.metric_name in ['roc-auc', 'pr-auc', 'spearman']
        best_score = -float('inf') if higher_is_better else float('inf')
        best_state = None

        for epoch in range(epochs):
            self.model.train()
            for batch in tqdm(train_l, desc=f"Ep {epoch+1}", leave=False):
                self.optimizer.zero_grad()
                if isinstance(batch, list): x, y = batch[0].to(self.device), batch[1].to(self.device)
                else: x, y = batch.to(self.device), batch.y.to(self.device)
                
                pred = self.model(x)
                if y.shape!=pred.shape: y=y.view_as(pred)
                loss = self.criterion(pred, y)
                loss.backward()
                self.optimizer.step()
            val_score = self.evaluate(val_l)
            improved = (val_score > best_score) if higher_is_better else (val_score < best_score)
            if improved:
                best_score = val_score
                best_state = self.model.state_dict()
        if best_state: self.model.load_state_dict(best_state)
        return self.evaluate(test_l)

    def evaluate(self, loader):
        self.model.eval()
        preds, targets = [], []
        with torch.no_grad():
            for batch in loader:
                if isinstance(batch, list): x, y = batch[0].to(self.device), batch[1].to(self.device)
                else: x, y = batch.to(self.device), batch.y.to(self.device)
                out = self.model(x)
                preds.extend(out.cpu().numpy())
                targets.extend(y.cpu().numpy())
        
        preds = np.array(preds).flatten()
        targets = np.array(targets).flatten()
        
        if np.isnan(preds).any(): preds = np.nan_to_num(preds)

        if self.metric_name == 'roc-auc':
            try:
                probs = 1 / (1 + np.exp(-preds))
                if len(np.unique(targets)) < 2: return 0.5
                return roc_auc_score(targets, probs)
            except: return 0.5
        elif self.metric_name == 'pr-auc':
            try:
                probs = 1 / (1 + np.exp(-preds))
                return average_precision_score(targets, probs)
            except: return 0.0
        elif self.metric_name == 'mae':
            return mean_absolute_error(targets, preds)
        elif self.metric_name == 'spearman':
            val, _ = spearmanr(targets, preds)
            return 0.0 if np.isnan(val) else val
        return 0.0


In [68]:
class ModelFactory:
    @staticmethod
    def create_model(model_name, config, device, tokenizer=None):
        """
        Args:
            model_name: One of ['morgan_mlp', 'cnn', 'gcn', 'neural_fp', 'mamba']
            config: The dict from config.yaml
            device: torch.device
            tokenizer: Instance of CharTokenizer/SMILESTokenizer (needed for seq models)
        Returns:
            model (nn.Module), modality_type (str)
        """
        
        if model_name == 'morgan_mlp':
            input_dim = config['featurization']['morgan_nbits']
            model = MorganMLP(input_dim=input_dim)
            return model, 'morgan'

        elif model_name == 'cnn':
            if tokenizer is None:
                raise ValueError("Tokenizer required for CNN initialization")
            
            vocab_size = len(tokenizer.vocab)
            model = SmilesCNN(
                vocab_size=vocab_size,
                embed_dim=64,
                max_len=config['featurization']['max_seq_len']
            )
            return model, 'seq'

        elif model_name == 'gcn':
            model = GCN(
                node_features=6, 
                hidden_dim=config['model_params']['gnn_hidden_dim']
            )
            return model, 'graph'

        elif model_name == 'neural_fp':
            base_gcn = GCN(node_features=6, hidden_dim=config['model_params']['gnn_hidden_dim'])
            
            model = NeuralFP(pretrained_gcn=base_gcn)
            return model, 'graph'
        else:
            raise ValueError(f"Unknown model name: {model_name}")

In [75]:
TASKS_CONFIG = {
    # --- Distribution (Table 2) ---
    'BBB':      {'tdc_name': 'BBB_Martins',       'type': 'classification', 'metric': 'roc-auc'},
    'PPBR':     {'tdc_name': 'PPBR_AZ',           'type': 'regression',     'metric': 'mae'},
    'VD':       {'tdc_name': 'VDss_Lombardo',     'type': 'regression',     'metric': 'mae'},

    # --- Metabolism (Table 3) ---
    'CYP2D6-I': {'tdc_name': 'CYP2D6_Veith',      'type': 'classification', 'metric': 'pr-auc'},
    'CYP3A4-I': {'tdc_name': 'CYP3A4_Veith',      'type': 'classification', 'metric': 'pr-auc'},
    'CYP2C9-I': {'tdc_name': 'CYP2C9_Veith',      'type': 'classification', 'metric': 'pr-auc'},
    
    'CYP2D6-S': {'tdc_name': 'CYP2D6_Substrate_CarbonMangels', 'type': 'classification', 'metric': 'pr-auc'},
    'CYP3A4-S': {'tdc_name': 'CYP3A4_Substrate_CarbonMangels', 'type': 'classification', 'metric': 'roc-auc'},
    'CYP2C9-S': {'tdc_name': 'CYP2C9_Substrate_CarbonMangels', 'type': 'classification', 'metric': 'pr-auc'},

    # --- Pharmacokinetics ---
    'Half-Life':{'tdc_name': 'Half_Life_Obach',        'type': 'regression', 'metric': 'spearman'},
    'CL-Micro': {'tdc_name': 'Clearance_Microsome_AZ', 'type': 'regression', 'metric': 'spearman'},
    'CL-Hepa':  {'tdc_name': 'Clearance_Hepatocyte_AZ','type': 'regression', 'metric': 'spearman'},

    # --- Toxicity (Table 4 & 1) ---
    'hERG':     {'tdc_name': 'hERG',         'type': 'classification', 'metric': 'roc-auc'},
    'AMES':     {'tdc_name': 'AMES',         'type': 'classification', 'metric': 'roc-auc'},
    'DILI':     {'tdc_name': 'DILI',         'type': 'classification', 'metric': 'roc-auc'},
    'LD50':     {'tdc_name': 'LD50_Zhu',     'type': 'regression',     'metric': 'mae'},
    
    # --- Absorption (Table 1) ---
    'Caco2':    {'tdc_name': 'Caco2_Wang',   'type': 'regression',     'metric': 'mae'},
    'HIA':      {'tdc_name': 'HIA_Hou',      'type': 'classification', 'metric': 'roc-auc'},
    'Pgp':      {'tdc_name': 'Pgp_Broccatelli','type': 'classification', 'metric': 'roc-auc'},
    'Bioav':    {'tdc_name': 'Bioavailability_Ma','type': 'classification', 'metric': 'roc-auc'},
    'Lipo':     {'tdc_name': 'Lipophilicity_AstraZeneca', 'type': 'regression', 'metric': 'mae'},
    'AqSol':    {'tdc_name': 'Solubility_AqSolDB', 'type': 'regression', 'metric': 'mae'},
}

CFG = {
    'device': "cuda" if torch.cuda.is_available() else "cpu",
    'data_path': './data/raw',
    'training': {'learning_rate': 1e-3, 'epochs': 10} 
}

In [76]:
def get_data(cfg_entry):
    name = cfg_entry['tdc_name']
    try: data = ADME(name=name, path=CFG['data_path'])
    except: data = Tox(name=name, path=CFG['data_path'])
    s = data.get_split(method='scaffold', seed=42, frac=[0.8, 0.1, 0.1])
    return s['train'], s['valid'], s['test']

def main():
    device = torch.device(CFG['device'])
    logger.info(f"Device: {device}")
    
    tok = CharTokenizer()
    morgan = MorganFeaturizer()
    graph = GraphFeaturizer()
    
    models = ['morgan_mlp', 'cnn', 'gcn', 'neural_fp']
    
    target_tasks = list(TASKS_CONFIG.keys())
    
    all_results = {}

    print(f"Tasks scheduled: {target_tasks}")

    for task in target_tasks:
        t_cfg = TASKS_CONFIG[task]
        metric_display = t_cfg['metric'].upper()
        logger.info(f"\n{'='*50}\nTask: {task} ({t_cfg['tdc_name']}) | Metric: {metric_display}\n{'='*50}")
        
        all_results[task] = {}

        try:
            train, val, test = get_data(t_cfg)
        except Exception as e:
            logger.error(f"Failed to load {task}: {e}"); continue

        for m_name in models:
            logger.info(f"Training {m_name}...")
            
            if m_name == 'morgan_mlp':
                feat = morgan; model = MorganMLP()
            elif m_name == 'cnn':
                feat = lambda x: tok.encode(x); model = SmilesCNN(len(tok.vocab))
            elif m_name in ['gcn', 'neural_fp']:
                if not HAS_PYG: 
                    all_results[task][m_name] = None
                    continue
                feat = graph; base = GCN()
                model = base if m_name=='gcn' else NeuralFP(base)
            
            tr_ds = MolecularDataset(train['Drug'].tolist(), train['Y'].tolist(), feat)
            va_ds = MolecularDataset(val['Drug'].tolist(), val['Y'].tolist(), feat)
            te_ds = MolecularDataset(test['Drug'].tolist(), test['Y'].tolist(), feat)
            
            trainer = UniversalTrainer(model, CFG['training'], t_cfg, device)
            final_metric = trainer.run(tr_ds, va_ds, te_ds, CFG['training']['epochs'])
            
            all_results[task][m_name] = final_metric
            logger.info(f">>> {task} | {m_name} Test {metric_display}: {final_metric:.4f}")


    print("\n" + "="*85)
    print("FINAL RESULTS SUMMARY (Best Test Metrics)")
    print("="*85)
    
    header = f"{'Task':<15} | {'Metric':<9} | " + " | ".join([f"{m:>10}" for m in models])
    print(header)
    print("-" * len(header))
    
    for task in target_tasks:
        if task not in all_results: continue
        
        metric_name = TASKS_CONFIG[task]['metric'].upper()
        row_str = f"{task:<15} | {metric_name:<9} | "
        
        for m in models:
            val = all_results[task].get(m, None)
            if val is None:
                val_str = "   N/A    "
            else:
                val_str = f"{val:10.4f}"
            row_str += f"{val_str} | "
        print(row_str)
        
    print("="*85)

In [77]:
main()

2025-12-09 21:41:13,081 - INFO - Device: cpu
2025-12-09 21:41:13,083 - INFO - 
Task: BBB (BBB_Martins) | Metric: ROC-AUC
Found local copy...
Loading...
Done!


Tasks scheduled: ['BBB', 'PPBR', 'VD', 'CYP2D6-I', 'CYP3A4-I', 'CYP2C9-I', 'CYP2D6-S', 'CYP3A4-S', 'CYP2C9-S', 'Half-Life', 'CL-Micro', 'CL-Hepa', 'hERG', 'AMES', 'DILI', 'LD50', 'Caco2', 'HIA', 'Pgp', 'Bioav', 'Lipo', 'AqSol']


100%|█████████████████████████████████████| 2030/2030 [00:01<00:00, 1862.99it/s]
2025-12-09 21:41:14,190 - INFO - Training morgan_mlp...
2025-12-09 21:41:31,091 - INFO - >>> BBB | morgan_mlp Test ROC-AUC: 0.8429      
2025-12-09 21:41:31,091 - INFO - Training cnn...
2025-12-09 21:41:44,230 - INFO - >>> BBB | cnn Test ROC-AUC: 0.9058             
2025-12-09 21:41:44,231 - INFO - Training gcn...
2025-12-09 21:42:06,338 - INFO - >>> BBB | gcn Test ROC-AUC: 0.4588             
2025-12-09 21:42:06,339 - INFO - Training neural_fp...
2025-12-09 21:42:24,894 - INFO - >>> BBB | neural_fp Test ROC-AUC: 0.7388       
2025-12-09 21:42:24,895 - INFO - 
Task: PPBR (PPBR_AZ) | Metric: MAE
Downloading...
100%|███████████████████████████████████████| 265k/265k [00:03<00:00, 85.6kiB/s]
Loading...
Done!
100%|█████████████████████████████████████| 1614/1614 [00:00<00:00, 1664.59it/s]
2025-12-09 21:42:32,609 - INFO - Training morgan_mlp...
2025-12-09 21:42:48,080 - INFO - >>> PPBR | morgan_mlp Test MAE: 11


FINAL RESULTS SUMMARY (Best Test Metrics)
Task            | Metric    | morgan_mlp |        cnn |        gcn |  neural_fp
-------------------------------------------------------------------------------
BBB             | ROC-AUC   |     0.8429 |     0.9058 |     0.4588 |     0.7388 | 
PPBR            | MAE       |    11.9905 |    10.4078 |    15.9136 |    16.9043 | 
VD              | MAE       |     2.4211 |     2.3123 |     2.5624 |     2.4784 | 
CYP2D6-I        | PR-AUC    |     0.5425 |     0.6008 |     0.1764 |     0.2612 | 
CYP3A4-I        | PR-AUC    |     0.7866 |     0.7958 |     0.5398 |     0.5812 | 
CYP2C9-I        | PR-AUC    |     0.7242 |     0.6923 |     0.5775 |     0.5417 | 
CYP2D6-S        | PR-AUC    |     0.7373 |     0.5490 |     0.3902 |     0.4028 | 
CYP3A4-S        | ROC-AUC   |     0.7239 |     0.6355 |     0.5769 |     0.5545 | 
CYP2C9-S        | PR-AUC    |     0.4223 |     0.3883 |     0.3713 |     0.3680 | 
Half-Life       | SPEARMAN  |     0.6276 |     0.0