In [1]:
pip install pytdc

Collecting networkx>=2.7.1 (from scanpy>=1.9.2->tiledbsoma<2.0.0,>=1.7.2->pytdc)
  Downloading networkx-3.5-py3-none-any.whl.metadata (6.3 kB)
Downloading networkx-3.5-py3-none-any.whl (2.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m55.1 MB/s[0m eta [36m0:00:00[0m
[0mInstalling collected packages: networkx
Successfully installed networkx-3.5


In [2]:
pip install rdkit-pypi --upgrade




In [3]:
pip install mordred

Collecting mordred
  Using cached mordred-1.2.0-py3-none-any.whl
Collecting networkx==2.* (from mordred)
  Using cached networkx-2.8.8-py3-none-any.whl.metadata (5.1 kB)
Using cached networkx-2.8.8-py3-none-any.whl (2.0 MB)
Installing collected packages: networkx, mordred
  Attempting uninstall: networkx
    Found existing installation: networkx 3.5
    Uninstalling networkx-3.5:
      Successfully uninstalled networkx-3.5
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
scikit-image 0.25.2 requires networkx>=3.0, but you have networkx 2.8.8 which is incompatible.
nx-cugraph-cu12 25.6.0 requires networkx>=3.2, but you have networkx 2.8.8 which is incompatible.[0m[31m
[0mSuccessfully installed mordred-1.2.0 networkx-2.8.8


In [4]:
pip install torch



In [5]:
pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.1 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m63.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


# **DATA PREPARATION**

In [6]:
from tdc.single_pred import ADME
from tdc import utils
import pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem, rdmolops

# Step 1: Identify valid ADME tasks
exclude = {'caco2_wang', 'solubility_aqsoldb', 'half_life_obach'}
valid_adme = []

for task in utils.retrieve_benchmark_names('admet_group'):
    if task in exclude:
        continue
    try:
        ADME(name=task)  # Will throw if invalid
        valid_adme.append(task)
    except Exception as e:
        print(f"Skipping {task}: {e}")

# Step 2: Load data and detect task type
all_adme_data = {}
for task in valid_adme:
    df = ADME(name=task).get_data()
    if 'Y' not in df.columns:
        print(f"Skipping {task} — no 'Y' column")
        continue

    unique_vals = df['Y'].dropna().unique()
    if df['Y'].dtype in ['float64', 'int64']:
        if len(unique_vals) <= 10 and set(unique_vals).issubset({0, 1}):
            task_type = 'classification'
        else:
            task_type = 'regression'
    else:
        task_type = 'classification'

    all_adme_data[task] = {'data': df, 'type': task_type}

for task, content in all_adme_data.items():
    print(f"{task}: {content['type']}, n = {len(content['data'])}")

# Step 3: Find overlapping molecules from fine-tune tasks
fine_tune_tasks = ['solubility_aqsoldb', 'caco2_wang', 'half_life_obach']
exclude_smiles = set()

for task in fine_tune_tasks:
    df = ADME(name=task).get_data()
    exclude_smiles.update(df['Drug'].unique())

print(f"Excluding {len(exclude_smiles)} molecules from pretraining due to overlap")

# Step 4: Merge pretraining datasets
pretrain_tasks = [
    'hia_hou', 'pgp_broccatelli', 'bioavailability_ma', 'lipophilicity_astrazeneca', 'bbb_martins',
    'ppbr_az', 'vdss_lombardo', 'cyp2d6_veith', 'cyp3a4_veith', 'cyp2c9_veith',
    'cyp2d6_substrate_carbonmangels', 'cyp3a4_substrate_carbonmangels', 'cyp2c9_substrate_carbonmangels',
    'clearance_microsome_az', 'clearance_hepatocyte_az'
]

merged_df = []
for task in pretrain_tasks:
    df = ADME(name=task).get_data()
    df = df[~df['Drug'].isin(exclude_smiles)].copy()
    df['task'] = task
    df['task_type'] = all_adme_data[task]['type']
    merged_df.append(df[['Drug', 'Y', 'task', 'task_type']])

merged_df = pd.concat(merged_df, ignore_index=True)
print(f"Merged multitask pretraining dataset size: {merged_df.shape}")

# Step 5: Molecule preprocessing utilities
def neutralize_charges(mol):
    if mol is None:
        return None
    patterns = [
        ('[n+;H]', 'n'), ('[N+;!H0]', 'N'), ('[$([O-]);!$([O-][#7])]', 'O'),
        ('[O-;r6]', 'O'), ('[O-;R1]=[C;R1]', 'O'), ('[n-]', 'n'),
        ('[S-;X1]', 'S'), ('[$([N-;X2]S(=O)=O)]', 'N'), ('[C-](=O)', 'C'),
    ]
    for pattern, replacement in patterns:
        patt = Chem.MolFromSmarts(pattern)
        repl = Chem.MolFromSmiles(replacement)
        if patt and repl:
            mol = AllChem.ReplaceSubstructs(mol, patt, repl, replaceAll=True)[0]
    try:
        Chem.SanitizeMol(mol)
        return mol
    except:
        return None

def standardize_smiles(smiles):
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None
        mol = Chem.RemoveHs(mol)
        rdmolops.RemoveStereochemistry(mol)
        frags = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=True)
        mol = max(frags, key=lambda m: m.GetNumAtoms()) if frags else mol
        mol = neutralize_charges(mol)
        return Chem.MolToSmiles(mol, canonical=True) if mol else None
    except:
        return None

def preprocess_df(df, smiles_col='Drug', target_col='Y', task_col='task'):
    df = df.copy()
    df[smiles_col] = df[smiles_col].apply(standardize_smiles)
    n_invalid = df[smiles_col].isna().sum()
    print(f"Removed {n_invalid} invalid SMILES")
    df = df.dropna(subset=[smiles_col])
    before = len(df)
    df = df.drop_duplicates(subset=[smiles_col, task_col])
    print(f"Removed {before - len(df)} (smiles, task) duplicates")
    before = len(df)
    df = df.dropna(subset=[target_col])
    print(f"Removed {before - len(df)} rows with missing target Y")
    return df.reset_index(drop=True)

# Step 6: Preprocess merged dataset
processed_df = preprocess_df(merged_df, smiles_col='Drug', target_col='Y')


Downloading...
100%|██████████| 40.1k/40.1k [00:00<00:00, 183kiB/s] 
Loading...
Done!
Downloading...
100%|██████████| 126k/126k [00:00<00:00, 288kiB/s] 
Loading...
Done!
Downloading...
100%|██████████| 43.7k/43.7k [00:00<00:00, 204kiB/s] 
Loading...
Done!
Downloading...
100%|██████████| 298k/298k [00:00<00:00, 345kiB/s]
Loading...
Done!
Downloading...
100%|██████████| 138k/138k [00:00<00:00, 213kiB/s]
Loading...
Done!
Downloading...
100%|██████████| 265k/265k [00:00<00:00, 409kiB/s]
Loading...
Done!
Downloading...
100%|██████████| 89.9k/89.9k [00:00<00:00, 207kiB/s]
Loading...
Done!
Downloading...
100%|██████████| 800k/800k [00:01<00:00, 742kiB/s]
Loading...
Done!
Downloading...
100%|██████████| 746k/746k [00:01<00:00, 689kiB/s]
Loading...
Done!
Downloading...
100%|██████████| 740k/740k [00:01<00:00, 679kiB/s]
Loading...
Done!
Downloading...
100%|██████████| 45.4k/45.4k [00:00<00:00, 212kiB/s] 
Loading...
Done!
Downloading...
100%|██████████| 46.0k/46.0k [00:00<00:00, 214kiB/s] 
Loadin

Skipping herg: ('herg', 'does not match to available values. Please double check.')
Skipping ames: ('ames', 'does not match to available values. Please double check.')
Skipping dili: ('dili', 'does not match to available values. Please double check.')
Skipping ld50_zhu: ('ld50_zhu', 'does not match to available values. Please double check.')
hia_hou: classification, n = 578
pgp_broccatelli: classification, n = 1218
bioavailability_ma: classification, n = 640
lipophilicity_astrazeneca: regression, n = 4200
bbb_martins: classification, n = 2030
ppbr_az: regression, n = 1614
vdss_lombardo: regression, n = 1130
cyp2d6_veith: classification, n = 13130
cyp3a4_veith: classification, n = 12328
cyp2c9_veith: classification, n = 12092
cyp2d6_substrate_carbonmangels: classification, n = 667
cyp3a4_substrate_carbonmangels: classification, n = 670
cyp2c9_substrate_carbonmangels: classification, n = 669
clearance_microsome_az: regression, n = 1102
clearance_hepatocyte_az: regression, n = 1213


100%|██████████| 853k/853k [00:01<00:00, 656kiB/s]
Loading...
Done!
Downloading...
100%|██████████| 82.5k/82.5k [00:00<00:00, 192kiB/s]
Loading...
Done!
Downloading...
100%|██████████| 53.6k/53.6k [00:00<00:00, 245kiB/s] 
Loading...
Done!
Found local copy...
Loading...
Done!
Found local copy...
Loading...
Done!
Found local copy...
Loading...
Done!
Found local copy...
Loading...
Done!
Found local copy...
Loading...
Done!
Found local copy...
Loading...
Done!
Found local copy...
Loading...
Done!
Found local copy...
Loading...
Done!
Found local copy...
Loading...
Done!
Found local copy...
Loading...
Done!
Found local copy...
Loading...
Done!
Found local copy...


Excluding 11242 molecules from pretraining due to overlap


Loading...
Done!
Found local copy...
Loading...
Done!
Found local copy...
Loading...
Done!
Found local copy...
Loading...
Done!
[08:24:26] non-ring atom 0 marked aromatic
[08:24:26] non-ring atom 0 marked aromatic
[08:24:26] non-ring atom 0 marked aromatic
[08:24:26] non-ring atom 0 marked aromatic
[08:24:26] non-ring atom 0 marked aromatic
[08:24:26] non-ring atom 0 marked aromatic
[08:24:26] non-ring atom 0 marked aromatic
[08:24:26] non-ring atom 0 marked aromatic
[08:24:26] non-ring atom 0 marked aromatic
[08:24:26] non-ring atom 0 marked aromatic
[08:24:26] non-ring atom 0 marked aromatic
[08:24:26] non-ring atom 0 marked aromatic
[08:24:26] non-ring atom 0 marked aromatic
[08:24:26] non-ring atom 0 marked aromatic
[08:24:26] non-ring atom 0 marked aromatic
[08:24:26] non-ring atom 0 marked aromatic
[08:24:26] non-ring atom 0 marked aromatic
[08:24:26] non-ring atom 0 marked aromatic
[08:24:26] non-ring atom 0 marked aromatic
[08:24:26] non-ring atom 0 marked aromatic
[08:24:26] n

Merged multitask pretraining dataset size: (49758, 4)


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
[08:25:26] non-ring atom 0 marked aromatic
[08:25:26] non-ring atom 0 marked aromatic
[08:25:26] non-ring atom 0 marked aromatic
[08:25:26] non-ring atom 0 marked aromatic
[08:25:26] non-ring atom 0 marked aromatic
[08:25:26] non-ring atom 0 marked aromatic
[08:25:26] non-ring atom 0 marked aromatic
[08:25:26] non-ring atom 0 marked aromatic
[08:25:26] non-ring atom 0 marked aromatic
[08:25:26] non-ring atom 0 marked aromatic
[08:25:26] non-ring atom 0 marked aromatic
[08:25:26] non-ring atom 0 marked aromatic
[08:25:26] non-ring atom 0 marked aromatic
[08:25:26] non-ring atom 0 marked aromatic
[08:25:26] non-ring atom 0 marked aromatic
[08:25:26] non-ring atom 0 marked aromatic
[08:25:26] non-ring atom 0 marked aromatic
[08:25:26] non-ring atom 0 marked aromatic
[08:25:26] non-ring atom 0 marked aromatic
[08:25:26] non-ring atom 0 marked aromatic
[08:25:26] non-ring atom 0 marked aromatic
[08:25:26] non-ring atom 0 marke

Removed 24 invalid SMILES
Removed 1148 (smiles, task) duplicates
Removed 0 rows with missing target Y


[08:25:30] non-ring atom 0 marked aromatic
[08:25:30] non-ring atom 0 marked aromatic
[08:25:30] non-ring atom 0 marked aromatic
[08:25:30] non-ring atom 0 marked aromatic
[08:25:30] non-ring atom 0 marked aromatic
[08:25:30] non-ring atom 0 marked aromatic
[08:25:30] non-ring atom 0 marked aromatic
[08:25:30] non-ring atom 0 marked aromatic
[08:25:30] non-ring atom 0 marked aromatic
[08:25:30] non-ring atom 0 marked aromatic
[08:25:30] non-ring atom 0 marked aromatic
[08:25:30] non-ring atom 0 marked aromatic
[08:25:30] non-ring atom 0 marked aromatic
[08:25:30] non-ring atom 0 marked aromatic
[08:25:30] non-ring atom 0 marked aromatic
[08:25:30] non-ring atom 0 marked aromatic
[08:25:30] non-ring atom 0 marked aromatic
[08:25:30] non-ring atom 0 marked aromatic
[08:25:30] non-ring atom 0 marked aromatic
[08:25:30] non-ring atom 0 marked aromatic
[08:25:30] non-ring atom 0 marked aromatic
[08:25:30] non-ring atom 0 marked aromatic
[08:25:30] non-ring atom 0 marked aromatic
[08:25:30] 

In [7]:
processed_df.head()

Unnamed: 0,Drug,Y,task,task_type
0,C=C1C(=CC=C2CCCC3(C)C2CCC3C(C)CCCC(C)C)CC(O)CC1O,1.0,hia_hou,classification
1,Cc1c(N(C)C)c(=N)n(-c2ccccc2)n1C,1.0,hia_hou,classification
2,COc1ccccc1OCCNCC(O)c1ccc(C)cc1S(N)(=O)=O,1.0,hia_hou,classification
3,CC(C)(C#N)c1cc(Cn2cncn2)cc(C(C)(C)C#N)c1,1.0,hia_hou,classification
4,NS(=O)(=O)c1cc2c(cc1C(F)(F)F)NC(Cc1ccccc1)NS2(...,1.0,hia_hou,classification


In [8]:
from rdkit.Chem.Scaffolds import MurckoScaffold
import numpy as np

def scaffold_split(df, smiles_col='Drug', frac_train=0.8, frac_val=0.1, frac_test=0.1, seed=42):
    np.random.seed(seed)

    def get_scaffold(smiles):
        mol = Chem.MolFromSmiles(smiles)
        return MurckoScaffold.MurckoScaffoldSmiles(mol=mol) if mol else None

    df = df.copy()
    df['scaffold'] = df[smiles_col].apply(get_scaffold)
    scaffold_groups = df.groupby('scaffold').indices.values()
    scaffold_groups = sorted(scaffold_groups, key=len, reverse=True)

    n_total = len(df)
    n_train = int(frac_train * n_total)
    n_val = int(frac_val * n_total)
    n_test = n_total - n_train - n_val

    train_idx, val_idx, test_idx = [], [], []
    counts = [0, 0, 0]  # train, val, test

    for group in scaffold_groups:
        sizes = [len(train_idx), len(val_idx), len(test_idx)]
        target = np.argmin([
            sizes[0] / n_train,
            sizes[1] / n_val,
            sizes[2] / n_test
        ])
        if target == 0:
            train_idx.extend(group)
        elif target == 1:
            val_idx.extend(group)
        else:
            test_idx.extend(group)

    return (
        df.iloc[train_idx].drop(columns=['scaffold']).reset_index(drop=True),
        df.iloc[val_idx].drop(columns=['scaffold']).reset_index(drop=True),
        df.iloc[test_idx].drop(columns=['scaffold']).reset_index(drop=True)
    )


In [9]:
train_df, val_df, test_df = scaffold_split(processed_df, smiles_col='Drug')


In [10]:
train_df["task"].value_counts(), val_df["task"].value_counts(), test_df["task"].value_counts()

(task
 cyp2d6_veith                      10114
 cyp3a4_veith                       9493
 cyp2c9_veith                       9304
 lipophilicity_astrazeneca          3017
 bbb_martins                        1186
 ppbr_az                            1127
 clearance_microsome_az              813
 pgp_broccatelli                     805
 vdss_lombardo                       782
 clearance_hepatocyte_az             718
 hia_hou                             355
 cyp3a4_substrate_carbonmangels      331
 cyp2c9_substrate_carbonmangels      329
 cyp2d6_substrate_carbonmangels      325
 bioavailability_ma                  169
 Name: count, dtype: int64,
 task
 cyp2d6_veith                      1300
 cyp3a4_veith                      1212
 cyp2c9_veith                      1157
 lipophilicity_astrazeneca          392
 ppbr_az                            163
 bbb_martins                        139
 pgp_broccatelli                     98
 vdss_lombardo                       93
 clearance_hepatocyte_az 

In [11]:
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors, rdMolDescriptors, Crippen, QED
from sklearn.preprocessing import StandardScaler
from sklearn.feature_selection import VarianceThreshold
import numpy as np
import pandas as pd

# 1) Essential 2D descriptors
ESSENTIAL_DESCRIPTOR_FUNCS = {
    "MolWt": Descriptors.MolWt,
    "HeavyAtomCount": Descriptors.HeavyAtomCount,
    "ExactMolWt": Descriptors.ExactMolWt,
    "MolLogP": Descriptors.MolLogP,
    "MolMR": Descriptors.MolMR,
    "TPSA": rdMolDescriptors.CalcTPSA,
    "LabuteASA": rdMolDescriptors.CalcLabuteASA,
    "BalabanJ": Descriptors.BalabanJ,
    "BertzCT": Descriptors.BertzCT,
    "Chi0": Descriptors.Chi0, "Chi1": Descriptors.Chi1,
    "Chi2n": Descriptors.Chi2n, "Chi3n": Descriptors.Chi3n, "Chi4n": Descriptors.Chi4n,
    "Chi2v": Descriptors.Chi2v, "Chi3v": Descriptors.Chi3v, "Chi4v": Descriptors.Chi4v,
    "HallKierAlpha": Descriptors.HallKierAlpha,
    "Kappa1": Descriptors.Kappa1, "Kappa2": Descriptors.Kappa2, "Kappa3": Descriptors.Kappa3,
    "NumHAcceptors": rdMolDescriptors.CalcNumLipinskiHBA,
    "NumHDonors": rdMolDescriptors.CalcNumLipinskiHBD,
    "NumRotatableBonds": Descriptors.NumRotatableBonds,
    "NumRadicalElectrons": Descriptors.NumRadicalElectrons,
    "RingCount": Descriptors.RingCount,
    "NumAromaticRings": rdMolDescriptors.CalcNumAromaticRings,
    "NumAliphaticRings": rdMolDescriptors.CalcNumAliphaticRings,
    "NumSaturatedRings": rdMolDescriptors.CalcNumSaturatedRings,
    "NumHeterocycles": rdMolDescriptors.CalcNumHeterocycles,
    "FractionCSP3": rdMolDescriptors.CalcFractionCSP3,
    "NumSpiroAtoms": rdMolDescriptors.CalcNumSpiroAtoms,
    "NumBridgeheadAtoms": rdMolDescriptors.CalcNumBridgeheadAtoms,
    "FormalCharge": lambda m: sum(a.GetFormalCharge() for a in m.GetAtoms()) if m else np.nan,
    "NumValenceElectrons": Descriptors.NumValenceElectrons,
    "NumHeteroatoms": Descriptors.NumHeteroatoms,
    "NHOHCount": Descriptors.NHOHCount,
    "NOCount": Descriptors.NOCount,
    "NumAmideBonds": rdMolDescriptors.CalcNumAmideBonds,
    "QED": QED.qed,
    "CrippenLogP": Crippen.MolLogP,
    "CrippenMR": Crippen.MolMR,
}

# 2) Optional VSA descriptors
def compute_vsa_descriptors(mol):
    if mol is None:
        return {}
    out = {}
    try:
        for i, v in enumerate(rdMolDescriptors.PEOE_VSA_(mol), 1):
            out[f"PEOE_VSA_{i}"] = v
        for i, v in enumerate(rdMolDescriptors.SMR_VSA_(mol), 1):
            out[f"SMR_VSA_{i}"] = v
        for i, v in enumerate(rdMolDescriptors.SlogP_VSA_(mol), 1):
            out[f"SlogP_VSA_{i}"] = v
    except:
        pass
    return out

# 3) Fast 3D shape descriptors
def compute_fast3d_descriptors(mol):
    out = {k: np.nan for k in [
        "PMI1", "PMI2", "PMI3", "Asphericity", "Eccentricity",
        "RadiusOfGyration", "SpherocityIndex", "InertialShapeFactor"
    ]}
    if mol is None:
        return out
    try:
        if mol.GetNumConformers() == 0:
            AllChem.EmbedMolecule(mol, AllChem.ETKDGv3(), randomSeed=42)
        out["PMI1"], out["PMI2"], out["PMI3"] = rdMolDescriptors.CalcPrincipalMomentsOfInertia(mol)
        out["Asphericity"] = rdMolDescriptors.CalcAsphericity(mol)
        out["Eccentricity"] = rdMolDescriptors.CalcEccentricity(mol)
        out["RadiusOfGyration"] = rdMolDescriptors.CalcRadiusOfGyration(mol)
        out["SpherocityIndex"] = rdMolDescriptors.CalcSpherocityIndex(mol)
        out["InertialShapeFactor"] = rdMolDescriptors.CalcInertialShapeFactor(mol)
    except:
        pass
    return out

# 4) GNN atom/bond features
def get_atom_features(a):
    return [
        a.GetAtomicNum(), a.GetDegree(),
        int(a.GetHybridization()), int(a.GetIsAromatic()),
        a.GetFormalCharge(), int(a.IsInRing()),
    ]

def get_bond_features(b):
    return [
        b.GetBondTypeAsDouble(), int(b.GetIsConjugated()),
        int(b.IsInRing()), int(b.GetStereo() != Chem.rdchem.BondStereo.STEREONONE),
    ]

# 5) Compute descriptors for SMILES
def compute_descriptor_df(smiles_series, use_vsa=True, use_3d=True):
    rows = []
    for smi in smiles_series:
        mol = Chem.MolFromSmiles(smi) if pd.notnull(smi) else None
        rec = {}
        for name, fn in ESSENTIAL_DESCRIPTOR_FUNCS.items():
            try:
                rec[name] = fn(mol) if mol else np.nan
                if mol:
                    try:
                        mol_charges = Chem.AddHs(Chem.Mol(mol))
                        AllChem.ComputeGasteigerCharges(mol_charges)
                        charges = [float(a.GetProp('_GasteigerCharge'))
                                   for a in mol_charges.GetAtoms()
                                   if a.HasProp('_GasteigerCharge')]
                        if charges:
                            rec.update({
                                'GasteigerCharge_mean': np.mean(charges),
                                'GasteigerCharge_min': np.min(charges),
                                'GasteigerCharge_max': np.max(charges),
                            })
                        else:
                            rec.update({k: np.nan for k in [
                                'GasteigerCharge_mean', 'GasteigerCharge_min', 'GasteigerCharge_max'
                            ]})
                    except:
                        rec.update({k: np.nan for k in [
                            'GasteigerCharge_mean', 'GasteigerCharge_min', 'GasteigerCharge_max'
                        ]})
            except:
                rec[name] = np.nan
        if use_vsa:
            rec.update(compute_vsa_descriptors(mol))
        if use_3d and mol:
            rec.update(compute_fast3d_descriptors(Chem.AddHs(Chem.Mol(mol))))
        rows.append(rec)
    return pd.DataFrame(rows, index=getattr(smiles_series, "index", None))

# 6) Descriptor filtering (train only)
def filter_descriptors(df, missing_thresh=0.3, var_thresh=0.01, corr_thresh=0.95):
    keep = df.columns[df.isnull().mean() <= missing_thresh]
    df = df[keep].fillna(df.median(numeric_only=True))
    mask = VarianceThreshold(var_thresh).fit(df).get_support()
    df = df.loc[:, mask]
    if df.shape[1] > 1:
        corr = df.corr().abs()
        upper = corr.where(np.triu(np.ones(corr.shape), k=1).astype(bool))
        drop = [c for c in upper.columns if any(upper[c] > corr_thresh)]
        if drop:
            df = df.drop(columns=drop)
    return df, list(df.columns)

# 7) Prepare global scaler
all_train_smiles = pd.Index(sorted(train_df['Drug'].unique()))
raw_tr_all = compute_descriptor_df(all_train_smiles, use_vsa=True, use_3d=True)
tr_desc_filtered_all, kept_global = filter_descriptors(raw_tr_all)
train_medians_global = tr_desc_filtered_all.median(numeric_only=True)
global_scaler = StandardScaler().fit(tr_desc_filtered_all.values)

def align_and_scale(smiles_series):
    raw = compute_descriptor_df(smiles_series, use_vsa=True, use_3d=True)
    aligned = raw[kept_global].fillna(train_medians_global)
    return global_scaler.transform(aligned.values)

# 8) Process multitask splits
def process_multitask_splits_global(train_df, val_df, test_df):
    processed, targ_scalers = {}, {}
    tasks = sorted(train_df['task'].unique())

    def mol_graph_feats(series):
        A, B = [], []
        for smi in series:
            mol = Chem.MolFromSmiles(smi) if pd.notnull(smi) else None
            if mol:
                A.append([get_atom_features(a) for a in mol.GetAtoms()])
                B.append([get_bond_features(b) for b in mol.GetBonds()])
            else:
                A.append([]); B.append([])
        return A, B

    for task in tasks:
        tr, vl, te = (
            train_df[train_df['task'] == task].reset_index(drop=True),
            val_df[val_df['task'] == task].reset_index(drop=True),
            test_df[test_df['task'] == task].reset_index(drop=True),
        )

        tr_X, vl_X, te_X = map(align_and_scale, [tr['Drug'], vl['Drug'], te['Drug']])

        if tr.loc[0, 'task_type'] == 'regression':
            ts = StandardScaler().fit(tr[['Y']])
            targ_scalers[task] = ts
            tr_Y, vl_Y, te_Y = (ts.transform(tr[['Y']]).ravel(),
                                ts.transform(vl[['Y']]).ravel(),
                                ts.transform(te[['Y']]).ravel())
        else:
            tr_Y, vl_Y, te_Y = tr['Y'].values, vl['Y'].values, te['Y'].values

        tr_A, tr_B = mol_graph_feats(tr['Drug'])
        vl_A, vl_B = mol_graph_feats(vl['Drug'])
        te_A, te_B = mol_graph_feats(te['Drug'])

        processed[task] = {
            'train': {'desc': tr_X, 'Y': tr_Y, 'atoms': tr_A, 'bonds': tr_B},
            'val':   {'desc': vl_X, 'Y': vl_Y, 'atoms': vl_A, 'bonds': vl_B},
            'test':  {'desc': te_X, 'Y': te_Y, 'atoms': te_A, 'bonds': te_B},
        }

    return processed, targ_scalers, kept_global, global_scaler

processed, targ_scalers, kept_global, global_scaler = process_multitask_splits_global(
    train_df, val_df, test_df
)
descriptor_dim = len(kept_global)


# **Model Architecture**

In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GINEConv, GATConv, JumpingKnowledge, Set2Set
from torch_geometric.data import Data

class ResidualGINEGATBlock(nn.Module):
    def __init__(self, dim, drop):
        super().__init__()
        self.gine = GINEConv(
            nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim)),
            edge_dim=dim, train_eps=True
        )
        self.norm1 = nn.LayerNorm(dim)

        self.gat = GATConv(dim, dim // 4, heads=4, concat=True)
        self.norm2 = nn.LayerNorm(dim)

        self.drop = drop

    def forward(self, x, edge_index, edge_attr):
        h = self.gine(x, edge_index, edge_attr)
        h = self.norm1(F.relu(h))
        h = F.dropout(h, p=self.drop, training=self.training) + x

        g = self.gat(h, edge_index)
        g = self.norm2(F.relu(g))
        return F.dropout(g, p=self.drop, training=self.training) + h

class GNNEncoder(nn.Module):
    def __init__(self, atom_dim, bond_dim, hidden_dim=128, num_layers=3, drop=0.1):
        super().__init__()
        self.atom_emb = nn.Linear(atom_dim, hidden_dim)
        self.bond_emb = nn.Linear(bond_dim, hidden_dim)

        self.blocks = nn.ModuleList([
            ResidualGINEGATBlock(hidden_dim, drop) for _ in range(num_layers)
        ])
        self.jk = JumpingKnowledge(mode='cat')
        self.pool = Set2Set(hidden_dim * num_layers, processing_steps=3)

        self.out_dim = hidden_dim * num_layers * 2

    def forward(self, data: Data):
        x, ei, ea, batch = data.x, data.edge_index, data.edge_attr, data.batch
        x = F.relu(self.atom_emb(x))
        ea = F.relu(self.bond_emb(ea))

        xs = []
        for blk in self.blocks:
            x = blk(x, ei, ea)
            xs.append(x)

        x_jk = self.jk(xs)
        return self.pool(x_jk, batch)

class ADMEModel(nn.Module):
    def __init__(
        self,
        descriptor_dim,
        atom_dim=6,
        bond_dim=4,
        graph_hidden=128,
        desc_hidden=128,
        num_gnn_layers=4,
        fused_dim=256,
        drop=0.1,
        use_film=True
    ):
        super().__init__()
        self.use_film = use_film

        self.graph_encoder = GNNEncoder(atom_dim, bond_dim, graph_hidden, num_gnn_layers, drop)
        self.desc_encoder = nn.Sequential(
            nn.LayerNorm(descriptor_dim),
            nn.Linear(descriptor_dim, desc_hidden), nn.ReLU(), nn.Dropout(drop),
            nn.Linear(desc_hidden, desc_hidden), nn.ReLU(), nn.Dropout(drop)
        )

        if use_film:
            self.film = nn.Linear(desc_hidden, self.graph_encoder.out_dim * 2)

        self.gate = nn.Sequential(
            nn.Linear(self.graph_encoder.out_dim + desc_hidden, fused_dim),
            nn.ReLU(),
            nn.Linear(fused_dim, fused_dim),
            nn.Sigmoid()
        )
        self.fusion_proj = nn.Sequential(
            nn.Linear(self.graph_encoder.out_dim + desc_hidden, fused_dim),
            nn.ReLU(), nn.Dropout(drop),
            nn.Linear(fused_dim, fused_dim), nn.ReLU(), nn.Dropout(drop)
        )

        self.adapters = nn.ModuleDict()
        self.heads = nn.ModuleDict()
        self.fused_dim = fused_dim

        self.task_losses = {}

    def add_head(self, task_id, output_size=1):
        self.adapters[task_id] = nn.Sequential(
            nn.Linear(self.fused_dim, self.fused_dim // 2),
            nn.ReLU(),
            nn.Linear(self.fused_dim // 2, self.fused_dim),
            nn.ReLU(),
            nn.Dropout(0.1)
        )
        self.heads[task_id] = nn.Linear(self.fused_dim, output_size)

    def forward(self, graph_batch, desc, task_id):
        g = self.graph_encoder(graph_batch)
        d = self.desc_encoder(desc)

        if self.use_film:
            gamma, beta = self.film(d).chunk(2, dim=1)
            g = gamma * g + beta

        fused_in = torch.cat([g, d], dim=1)
        gate_val = self.gate(fused_in)
        z = self.fusion_proj(fused_in) * gate_val

        h = self.adapters[task_id](z)
        return self.heads[task_id](h)

    def freeze_backbone(self):
        for name, param in self.named_parameters():
            if not name.startswith(('adapters', 'heads')):
                param.requires_grad = False

    def unfreeze_backbone(self):
        for param in self.parameters():
            param.requires_grad = True

def count_params(module: nn.Module):
    total = sum(p.numel() for p in module.parameters())
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    return total, trainable

def print_param_table(model: nn.Module):
    rows = []

    def add(name, m):
        tot, tr = count_params(m)
        rows.append((name, tot, tr))

    add('graph_encoder', model.graph_encoder)
    add('desc_encoder', model.desc_encoder)
    add('gate', model.gate)
    add('fusion_proj', model.fusion_proj)
    for name, head in model.heads.items():
        add(f'head:{name}', head)

    w = max(len(r[0]) for r in rows)
    print(f'{"module".ljust(w)}  total     trainable')
    for name, tot, tr in rows:
        print(f'{name.ljust(w)}  {tot:8d}  {tr:9d}')
    tot_all, tr_all = count_params(model)
    print('-' * w)
    print(f'TOTAL: {tot_all:,}  TRAINABLE: {tr_all:,}')


# **Training Loop**

In [13]:
import os, time, json, random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import defaultdict, Counter
from torch.utils.data import DataLoader, WeightedRandomSampler
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader as GeoDataLoader
from rdkit import Chem
from scipy.stats import spearmanr

SPEARMAN_TASKS = set()


class TaskDataset(torch.utils.data.Dataset):
    def __init__(self, df_task, split_dict, augment=True):
        self.df = df_task.reset_index(drop=True)
        self.descs = split_dict['desc']
        self.targets = split_dict['Y']
        self.atoms = split_dict['atoms']
        self.bonds = split_dict['bonds']
        self.augment = augment

    def __len__(self): return len(self.df)

    def __getitem__(self, i):
        smi = self.df.loc[i, 'Drug']
        mol = Chem.MolFromSmiles(smi)
        if self.augment:
            try:
                smi = Chem.MolToSmiles(mol, doRandom=True)
                mol = Chem.MolFromSmiles(smi)
            except Exception:
                pass

        x = torch.tensor(self.atoms[i], dtype=torch.float)

        edge_index, edge_attr = [], []
        bonds_list = list(mol.GetBonds()) if mol is not None else []
        for feat, bond in zip(self.bonds[i], bonds_list):
            a, b = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            edge_index += [[a,b],[b,a]]
            edge_attr  += [feat, feat]
        if len(edge_index)==0:
            edge_index = torch.zeros((2,0), dtype=torch.long)
            edge_attr = torch.zeros((0, len(self.bonds[i][0]) if len(self.bonds[i])>0 else 4), dtype=torch.float)
        else:
            edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
            edge_attr  = torch.tensor(edge_attr, dtype=torch.float)

        y = torch.tensor(self.targets[i], dtype=torch.float).view(-1)
        data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
        desc = torch.tensor(self.descs[i], dtype=torch.float)
        return data, desc


task_types = {t: train_df[train_df['task']==t]['task_type'].iloc[0] for t in processed.keys()}


model = ADMEModel(
    descriptor_dim=descriptor_dim,
    atom_dim= len(get_atom_features(Chem.MolFromSmiles("CC").GetAtomWithIdx(0))),
    bond_dim=len(get_bond_features(Chem.MolFromSmiles("CC").GetBondWithIdx(0))),
    graph_hidden=128,
    desc_hidden=128,
    num_gnn_layers=4,
    fused_dim=256,
    drop=0.1
)

for task in processed.keys():
    if task not in model.heads:
        model.add_head(task, output_size=1)

FOCAL_TASKS = {
    'cyp2c9_substrate_carbonmangels': {'alpha': 0.6, 'gamma': 2.0},
    'cyp3a4_substrate_carbonmangels': {'alpha': 0.6, 'gamma': 2.0},
}
for name in list(model.heads.keys()):
    if name in FOCAL_TASKS:
        head = model.heads[name]
        in_dim = None
        out_dim = 1
        if isinstance(head, nn.Linear):
            in_dim = head.in_features; out_dim = head.out_features
        elif isinstance(head, nn.Sequential):
            for m in head:
                if isinstance(m, nn.Linear):
                    in_dim = m.in_features
                    break
            for m in reversed(head):
                if isinstance(m, nn.Linear):
                    out_dim = m.out_features
                    break
        if in_dim is None:
            in_dim = getattr(model, 'fused_dim', 256)
        model.heads[name] = nn.Sequential(
            nn.Linear(in_dim, max(in_dim//2, 4)),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(max(in_dim//2, 4), out_dim)
        )

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(DEVICE)


BATCH_SIZE = 32
VAL_BATCH = 64
EPOCHS = 80
PATIENCE = 20
BASE_LR = 1e-3
WEIGHT_DECAY = 1e-5
WARMUP_EPOCHS = 3
SAVE_TO_DRIVE = True
DRIVE_DIR = "/content/drive/MyDrive/molecular_pretraining"
IMBALANCE_EDGE = 0.35
ALPHA_GRADNORM = 1.5
WLR = 1e-3
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

def collate_fn(batch):
    datas, descs = zip(*batch)
    batch_graph = Batch.from_data_list(list(datas))
    descs_t = torch.stack([d if isinstance(d, torch.Tensor) else torch.tensor(d, dtype=torch.float) for d in descs], dim=0)
    return batch_graph, descs_t

def make_balanced_loader(ds, df_task, batch_size=32, task_name=None):
    y = df_task['Y'].astype(int).values
    pos_rate = max((y == 1).mean(), 1e-6)
    w_pos = 0.5 / pos_rate
    w_neg = 0.5 / (1.0 - pos_rate)
    if task_name == 'cyp2c9_substrate_carbonmangels':
        w_pos *= 3.0
    weights = torch.tensor([w_pos if yi==1 else w_neg for yi in y], dtype=torch.float)
    sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
    return DataLoader(ds, batch_size=batch_size, sampler=sampler, collate_fn=collate_fn)

def build_task_loaders(processed, df_split, split_name, batch_size=32, shuffle=False):
    loaders = {}
    for task in processed.keys():
        df_task = df_split[df_split['task'] == task].reset_index(drop=True)
        if len(df_task) == 0:
            continue
        ds = TaskDataset(df_task, processed[task][split_name], augment=(split_name=='train'))
        if split_name=='train' and task_types[task]=='classification':
            pos = df_task['Y'].mean()
            if (pos < IMBALANCE_EDGE) or (pos > 1 - IMBALANCE_EDGE):
                loaders[task] = make_balanced_loader(ds, df_task, batch_size=batch_size, task_name=task)
                continue
        loaders[task] = DataLoader(ds, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)
    return loaders

train_loaders = build_task_loaders(processed, train_df, 'train', batch_size=BATCH_SIZE, shuffle=True)
val_loaders   = build_task_loaders(processed, val_df,   'val',   batch_size=VAL_BATCH, shuffle=False)
test_loaders  = build_task_loaders(processed, test_df,  'test',  batch_size=VAL_BATCH, shuffle=False)


class BCEWithLogitsLossSmoothed(nn.Module):
    def __init__(self, pos_weight=None, smoothing=0.05):
        super().__init__()
        self.smoothing = smoothing
        self.pos_weight = pos_weight
        self.bce = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction='mean')
    def forward(self, logits, targets):
        s = self.smoothing
        targets = targets * (1 - 2*s) + s
        return self.bce(logits, targets)

class FocalBCEWithLogitsLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction
    def forward(self, logits, targets):
        p = torch.sigmoid(logits)
        ce = torch.nn.functional.binary_cross_entropy_with_logits(logits, targets, reduction='none')
        p_t = p*targets + (1-p)*(1-targets)
        w = (self.alpha*targets + (1-self.alpha)*(1-targets))
        loss = w * (1 - p_t).pow(self.gamma) * ce
        return loss.mean() if self.reduction=='mean' else loss.sum()

REGRESSION_CRIT = nn.SmoothL1Loss(reduction='mean')

criteria = {}
LABEL_SMOOTHING = 0.05
for t in processed.keys():
    if task_types[t] == 'classification':
        ytr = train_df[train_df['task']==t]['Y'].astype(int).values
        c = Counter(ytr); n0, n1 = c.get(0,0), c.get(1,0)
        pos_weight = None if n1==0 else torch.tensor([max(n0,1)/max(n1,1)], dtype=torch.float, device=DEVICE)
        if t in FOCAL_TASKS:
            cfg = FOCAL_TASKS[t]
            criteria[t] = FocalBCEWithLogitsLoss(alpha=cfg['alpha'], gamma=cfg['gamma'])
        else:
            criteria[t] = BCEWithLogitsLossSmoothed(pos_weight=pos_weight, smoothing=LABEL_SMOOTHING)
    else:
        criteria[t] = REGRESSION_CRIT

task_list = sorted([t for t in processed.keys()])
T = len(task_list)
task_to_idx = {t:i for i,t in enumerate(task_list)}
shared_module = getattr(model, 'graph_encoder', model)
shared_params = [p for p in shared_module.parameters() if p.requires_grad]
w_vec = torch.nn.Parameter(torch.ones(T, device=DEVICE), requires_grad=True)
L0 = None
first_epoch_done = False

model_params = [p for p in model.parameters() if p.requires_grad]
optimizer_model = torch.optim.AdamW(model_params, lr=BASE_LR, weight_decay=WEIGHT_DECAY)
optimizer_w = torch.optim.Adam([w_vec], lr=WLR)
scheduler = CosineAnnealingWarmRestarts(optimizer_model, T_0=10, T_mult=2)
scaler = GradScaler()

def weighted_loss_for_training(task_name, loss_tensor):
    idx = task_to_idx[task_name]
    w = w_vec[idx].detach()
    return w * loss_tensor

def mean(val_list):
    return float(np.mean(val_list)) if len(val_list) else 0.0

def regression_metrics(y_true, y_pred, task):
    mse = float(np.mean((y_true - y_pred)**2))
    rmse = float(np.sqrt(mse))
    mae = float(np.mean(np.abs(y_true - y_pred)))
    out = {"mse": mse, "rmse": rmse, "mae": mae}
    if task in SPEARMAN_TASKS:
        sr = spearmanr(y_true, y_pred).correlation
        out['spearman'] = float(sr) if np.isfinite(sr) else np.nan
    return out

def classification_metrics(y_true, logits):
    p = 1.0/(1.0+np.exp(-logits))
    preds = (p >= 0.5).astype(np.float32)
    acc = float((preds == y_true).mean())
    y = y_true.astype(int)
    if y.min()==0 and y.max()==1 and (y==1).sum()>0 and (y==0).sum()>0:
        order = np.argsort(p)
        ranks = np.empty_like(order, dtype=float); ranks[order] = np.arange(1, len(p)+1)
        n1 = float((y==1).sum()); n0 = float((y==0).sum())
        auroc = (ranks[y==1].sum() - n1*(n1+1)/2.0) / (n0*n1)
    else:
        auroc = float('nan')
    order = np.argsort(-p)
    y_sorted = y_true[order].astype(int)
    tp = np.cumsum(y_sorted==1).astype(float)
    fp = np.cumsum(y_sorted==0).astype(float)
    prec = tp / np.maximum(tp+fp, 1.0)
    rec  = tp / max((y_sorted==1).sum(), 1.0)
    auprc = float(np.trapz(prec, rec))
    return {"acc": acc, "auroc": float(auroc), "auprc": float(auprc)}

def prevalence_matched_threshold(y_true, logits, match_rate=None):
    p = 1.0/(1.0+np.exp(-logits))
    if match_rate is None:
        match_rate = (y_true == 1).mean()
    return float(np.quantile(p, 1.0 - match_rate))

def val_selection_score(val_metrics, task_types):
    rmses, one_minus_auroc, one_minus_auprc = [], [], []
    for t, m in val_metrics.items():
        if task_types[t] == 'regression':
            rmses.append(m['rmse'])
        else:
            if np.isfinite(m.get('auroc', np.nan)):
                one_minus_auroc.append(1.0 - m['auroc'])
            if np.isfinite(m.get('auprc', np.nan)):
                one_minus_auprc.append(1.0 - m['auprc'])
    a = np.mean(rmses) if rmses else 0.0
    b = np.mean(one_minus_auroc) if one_minus_auroc else 0.0
    c = np.mean(one_minus_auprc) if one_minus_auprc else 0.0
    return a + 0.5 * b + 1.0 * c


def run_epoch(loaders, train=True):
    if train:
        model.train()
    else:
        model.eval()

    reg_loss_sum=reg_loss_cnt=0
    clf_loss_sum=clf_loss_cnt=0
    all_preds = defaultdict(list); all_targets = defaultdict(list)
    # track per-task training loss accumulation (only used during training)
    epoch_task_losses = defaultdict(list)

    with torch.set_grad_enabled(train):
        for task, loader in loaders.items():
            for graph_batch, desc in loader:
                graph_batch = graph_batch.to(DEVICE)
                desc = desc.to(DEVICE)
                y = graph_batch.y.view(-1).to(DEVICE)

                with autocast(enabled=(DEVICE=='cuda')):
                    out = model(graph_batch, desc, task).view(-1)
                    if task_types[task]=='regression':
                        L_raw = REGRESSION_CRIT(out, y)
                    else:
                        L_raw = criteria[task](out, y)
                    if train:
                        weighted = weighted_loss_for_training(task, L_raw)
                    else:
                        weighted = L_raw

                if train:
                    optimizer_model.zero_grad()
                    scaler.scale(weighted).backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    scaler.step(optimizer_model)
                    scaler.update()
                    # store raw loss value (unpushed by w)
                    epoch_task_losses[task].append(float(L_raw.detach().cpu().item()))
                else:
                    all_preds[task].append(out.detach().cpu().numpy())
                    all_targets[task].append(y.detach().cpu().numpy())

                if task_types[task]=='regression':
                    reg_loss_sum += float(L_raw.item()); reg_loss_cnt += 1
                else:
                    clf_loss_sum += float(L_raw.item()); clf_loss_cnt += 1

    reg_loss = (reg_loss_sum/reg_loss_cnt) if reg_loss_cnt>0 else None
    clf_loss = (clf_loss_sum/clf_loss_cnt) if clf_loss_cnt>0 else None

    metrics = {}; raws = {}
    if not train:
        for t in all_preds:
            y_true = np.concatenate(all_targets[t])
            logits = np.concatenate(all_preds[t])
            raws[t] = (y_true, logits)
            if task_types[t]=='regression':
                metrics[t] = regression_metrics(y_true, logits, t)
            else:
                metrics[t] = classification_metrics(y_true, logits)
    return reg_loss, clf_loss, metrics, raws, epoch_task_losses

def sample_one_batch_per_task(train_loaders):
    samples = {}
    for t, loader in train_loaders.items():
        try:
            it = iter(loader)
            batch = next(it)
        except Exception:
            ds = loader.dataset
            idx = random.randrange(len(ds))
            batch = ds[idx:idx+1] if hasattr(ds, '__getitem__') else None
        samples[t] = batch
    return samples


def init_L0(epoch_task_losses):
    """Compute initial L0 mean losses for each task after first epoch."""
    L0_mean = {t: mean(vals) if len(vals) > 0 else 1.0 for t, vals in epoch_task_losses.items()}
    return torch.tensor([L0_mean.get(t, 1.0) for t in task_list], dtype=torch.float, device=DEVICE)

def gradnorm_step(train_loaders, epoch_task_losses, L0_vec):
    """
    Robust GradNorm: temporarily disable cuDNN for the forward passes used
    to build the differentiable losses (so double-backward is supported).
    """
    L0_vec = L0_vec.to(DEVICE)
    sample_batches = sample_one_batch_per_task(train_loaders)

    was_training = model.training
    model.train()

    try:
        optimizer_model.zero_grad()

        L_t_tensors = []
        with torch.backends.cudnn.flags(enabled=False):
            for t in task_list:
                batch = sample_batches.get(t, None)
                if batch is None:
                    L_t_tensors.append(torch.tensor(1.0, device=DEVICE, requires_grad=True))
                    continue

                graph_batch, desc = batch
                graph_batch, desc = graph_batch.to(DEVICE), desc.to(DEVICE)
                y = graph_batch.y.view(-1).to(DEVICE)

                with autocast(enabled=(DEVICE == 'cuda')):
                    out = model(graph_batch, desc, t).view(-1)
                    L_i = REGRESSION_CRIT(out, y) if task_types[t] == 'regression' else criteria[t](out, y)

                if not isinstance(L_i, torch.Tensor):
                    L_i = torch.tensor(float(L_i), device=DEVICE, requires_grad=True)
                elif not L_i.requires_grad:
                    L_i = L_i.clone().requires_grad_(True)

                L_t_tensors.append(L_i)

        L_mean_epoch = torch.tensor(
            [ mean(epoch_task_losses.get(t, [])) if len(epoch_task_losses.get(t, []))>0
              else float(L_t_tensors[i].detach().cpu().item())
              for i,t in enumerate(task_list) ],
            device=DEVICE, dtype=torch.float
        ) + 1e-9

        L0_vec = (L0_vec + 1e-9)
        r_i = L_mean_epoch / L0_vec
        r_bar = torch.mean(r_i)

        G_i = []
        for i, L_i in enumerate(L_t_tensors):
            weighted = w_vec[i] * L_i
            grads = torch.autograd.grad(weighted, shared_params, retain_graph=True, create_graph=True)
            g_norm = torch.sqrt(sum([(g.norm())**2 for g in grads if g is not None]) + 1e-12)
            G_i.append(g_norm)

        G = torch.stack(G_i)
        G_avg = torch.mean(G)

        G_star = G_avg * ((r_i / r_bar) ** ALPHA_GRADNORM)
        loss_gradnorm = torch.sum(torch.abs(G - G_star))

        optimizer_w.zero_grad()
        loss_gradnorm.backward()
        optimizer_w.step()

        with torch.no_grad():
            w_vec.data = w_vec.data * (float(T) / float(w_vec.data.sum() + 1e-12))

    finally:
        if not was_training:
            model.eval()

    return L0_vec


best_score = float('inf')
best_state, stale = None, 0
first_epoch_done = False

for epoch in range(1, EPOCHS + 1):
    if epoch <= WARMUP_EPOCHS:
        factor = epoch / float(WARMUP_EPOCHS)
        for g in optimizer_model.param_groups:
            g['lr'] = BASE_LR * factor

    tr_reg, tr_clf, _, _, epoch_task_losses = run_epoch(train_loaders, train=True)

    if not first_epoch_done:
        L0_vec = init_L0(epoch_task_losses)
        first_epoch_done = True

    val_reg, val_clf, val_metrics, val_raws, _ = run_epoch(val_loaders, train=False)
    val_score = val_selection_score(val_metrics, task_types)

    if val_score < best_score:
        best_score, best_state, stale = val_score, {k: v.detach().cpu() for k, v in model.state_dict().items()}, 0
    else:
        stale += 1

    print(f"Epoch {epoch:02d} | train_reg={tr_reg:.4f} train_clf={tr_clf:.4f} | val_score={val_score:.4f} | stale={stale}/{PATIENCE}")

    L0_vec = gradnorm_step(train_loaders, epoch_task_losses, L0_vec)


    scheduler.step(epoch)
    if stale >= PATIENCE:
        print("Early stopping.")
        break

if best_state is not None:
    model.load_state_dict(best_state)




Epoch 01 | train_reg=0.3256 train_clf=0.8215 | val_score=1.4873 | stale=0/12
Epoch 02 | train_reg=0.3077 train_clf=0.7781 | val_score=1.4544 | stale=0/12
Epoch 03 | train_reg=0.2992 train_clf=0.7345 | val_score=1.3773 | stale=0/12
Epoch 04 | train_reg=0.2832 train_clf=0.7168 | val_score=1.3485 | stale=0/12
Epoch 05 | train_reg=0.2739 train_clf=0.7042 | val_score=1.3282 | stale=0/12
Epoch 06 | train_reg=0.2609 train_clf=0.6840 | val_score=1.2702 | stale=0/12
Epoch 07 | train_reg=0.2463 train_clf=0.6645 | val_score=1.2591 | stale=0/12
Epoch 08 | train_reg=0.2375 train_clf=0.6502 | val_score=1.2544 | stale=0/12
Epoch 09 | train_reg=0.2368 train_clf=0.6421 | val_score=1.2481 | stale=0/12
Epoch 10 | train_reg=0.2254 train_clf=0.6484 | val_score=1.2439 | stale=0/12
Epoch 11 | train_reg=0.2497 train_clf=0.6776 | val_score=1.2693 | stale=1/12
Epoch 12 | train_reg=0.2397 train_clf=0.6617 | val_score=1.2484 | stale=2/12
Epoch 13 | train_reg=0.2371 train_clf=0.6546 | val_score=1.2689 | stale=3/12

In [14]:
import os
save_path = "checkpoints/best_gradnorm_pretrained.pth"
os.makedirs(os.path.dirname(save_path), exist_ok=True)  # Create directory if missing

torch.save({
    'model_state_dict': model.state_dict(),
    'gradnorm_w': w_vec.detach().cpu().tolist(),
    'task_list': task_list,
    'val_score': best_score
}, save_path)

print("Training complete.")
print_param_table(model)


Training complete.
module                               total     trainable
graph_encoder                         3418628    3418628
desc_encoder                            24310      24310
gate                                   360960     360960
fusion_proj                            360960     360960
head:bbb_martins                          257        257
head:bioavailability_ma                   257        257
head:clearance_hepatocyte_az              257        257
head:clearance_microsome_az               257        257
head:cyp2c9_substrate_carbonmangels     33025      33025
head:cyp2c9_veith                         257        257
head:cyp2d6_substrate_carbonmangels       257        257
head:cyp2d6_veith                         257        257
head:cyp3a4_substrate_carbonmangels     33025      33025
head:cyp3a4_veith                         257        257
head:hia_hou                              257        257
head:lipophilicity_astrazeneca            257        257
head:pgp_bro

In [15]:
def class_balance(df, task):
    y = df[df['task']==task]['Y'].astype(int).values
    p = (y==1).mean()
    print(f"{task}: n={len(y)}  pos%={100*p:.1f}")

for t in ['cyp2c9_substrate_carbonmangels','cyp2d6_substrate_carbonmangels']:
    for name, df in [('train',train_df),('val',val_df),('test',test_df)]:
        print(name, end=' | '); class_balance(df, t)


train | cyp2c9_substrate_carbonmangels: n=329  pos%=17.3
val | cyp2c9_substrate_carbonmangels: n=35  pos%=25.7
test | cyp2c9_substrate_carbonmangels: n=31  pos%=12.9
train | cyp2d6_substrate_carbonmangels: n=325  pos%=28.0
val | cyp2d6_substrate_carbonmangels: n=35  pos%=31.4
test | cyp2d6_substrate_carbonmangels: n=31  pos%=32.3


In [17]:
import os, time, json
import torch, numpy as np, pandas as pd
from torch_geometric.data import Batch
from torch.cuda.amp import autocast

CKPT_PATH = "checkpoints/best_gradnorm_pretrained.pth"
SAVE_TO_DRIVE = True
DRIVE_DIR = "/content/drive/MyDrive/molecular_pretraining"
OUT_DIR = "test_results"
os.makedirs(OUT_DIR, exist_ok=True)

def regression_metrics(y_true, y_pred, task):
    mse = float(np.mean((y_true - y_pred)**2))
    rmse = float(np.sqrt(mse))
    mae = float(np.mean(np.abs(y_true - y_pred)))
    out = {"mse": mse, "rmse": rmse, "mae": mae}
    return out

def classification_metrics(y_true, logits):
    p = 1.0/(1.0+np.exp(-logits))
    preds = (p >= 0.5).astype(np.int64)
    acc = float((preds == y_true).mean())
    y = y_true.astype(int)
    # AUROC
    if y.min()==0 and y.max()==1 and (y==1).sum()>0 and (y==0).sum()>0:
        order = np.argsort(p)
        ranks = np.empty_like(order, dtype=float); ranks[order] = np.arange(1, len(p)+1)
        n1 = float((y==1).sum()); n0 = float((y==0).sum())
        auroc = (ranks[y==1].sum() - n1*(n1+1)/2.0) / (n0*n1)
    else:
        auroc = float('nan')
    order = np.argsort(-p)
    y_sorted = y_true[order].astype(int)
    tp = np.cumsum(y_sorted==1).astype(float)
    fp = np.cumsum(y_sorted==0).astype(float)
    prec = tp / np.maximum(tp+fp, 1.0)
    rec  = tp / max((y_sorted==1).sum(), 1.0)
    auprc = float(np.trapz(prec, rec)) if len(rec)>1 else float('nan')
    return {"acc": acc, "auroc": float(auroc), "auprc": float(auprc)}

def best_threshold_from_val(y_true, logits, mode="mcc"):
    p = 1.0/(1.0+np.exp(-logits))
    ts = np.linspace(0,1,101)
    best_t, best_score = 0.5, -1.0
    for t in ts:
        pr = (p>=t).astype(int)
        tp = ((pr==1)&(y_true==1)).sum(); tn = ((pr==0)&(y_true==0)).sum()
        fp = ((pr==1)&(y_true==0)).sum(); fn = ((pr==0)&(y_true==1)).sum()
        if mode=="mcc":
            num = (tp*tn - fp*fn)
            den = np.sqrt((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn)) + 1e-9
            score = num/den
        else:
            prec = tp/(tp+fp+1e-9); rec = tp/(tp+fn+1e-9)
            score = 2*prec*rec/(prec+rec+1e-9)
        if score > best_score:
            best_score, best_t = score, t
    return best_t

if os.path.exists(CKPT_PATH):
    ckpt = torch.load(CKPT_PATH, map_location="cpu", weights_only=False)

    try:
        model.load_state_dict(ckpt['model_state_dict'])
        print("Loaded model weights from", CKPT_PATH)
    except Exception as e:
        print("Checkpoint load error:", e)
else:
    print("No checkpoint at", CKPT_PATH, "— using current model in memory")

model.to(DEVICE)
model.eval()

calib_T = globals().get('calib_T', {})
best_thresh = globals().get('best_thresh', {})
prev_thresh = globals().get('prev_thresh', {})
if 'calib_T' in ckpt if 'ckpt' in globals() else False:
    calib_T = ckpt.get('calib_T', calib_T)
    best_thresh = ckpt.get('best_thresh', best_thresh)
    prev_thresh = ckpt.get('prev_thresh', prev_thresh)

all_preds = {}
all_trues = {}
with torch.no_grad():
    for task, loader in test_loaders.items():
        preds_list, trues_list = [], []
        for graph_batch, desc in loader:
            graph_batch = graph_batch.to(DEVICE); desc = desc.to(DEVICE)
            with autocast(enabled=(DEVICE=='cuda')):
                out = model(graph_batch, desc, task).view(-1)
            preds_list.append(out.detach().cpu().numpy())
            trues_list.append(graph_batch.y.view(-1).cpu().numpy())
        if len(preds_list)==0:
            all_preds[task] = np.array([])
            all_trues[task] = np.array([])
        else:
            all_preds[task] = np.concatenate(preds_list)
            all_trues[task] = np.concatenate(trues_list)

summary_rows = []
diag_dir = os.path.join(OUT_DIR, "test_misdiagnosed")
os.makedirs(diag_dir, exist_ok=True)

for t in sorted(list(processed.keys())):
    logits = all_preds.get(t, np.array([]))
    y_true = all_trues.get(t, np.array([]))
    if logits.size == 0 or y_true.size == 0:
        print(f"[TEST][{t}] No samples")
        continue

    if task_types[t] == 'regression':
        y_pred = logits
        m = regression_metrics(y_true, y_pred, t)
        print(f"[TEST][{t}] RMSE={m['rmse']:.4f} MAE={m['mae']:.4f}")
        df = pd.DataFrame({
            'Drug': test_df[test_df['task']==t]['Drug'].values,
            'y_true': y_true,
            'y_pred': y_pred
        })
        df['error'] = np.abs(df['y_true'] - df['y_pred'])
        df_sorted = df.sort_values('error', ascending=False)
        path = os.path.join(diag_dir, f"{t}_regression_diag.csv")
        df_sorted.to_csv(path, index=False)
    else:
        Tval = float(calib_T.get(t, 1.0))
        logits_cal = logits / max(Tval, 1e-6)
        probs = 1.0/(1.0+np.exp(-logits_cal))
        thr_mcc = best_thresh.get(t, None)
        thr_prev = prev_thresh.get(t, None)
        if thr_mcc is None:
            thr_mcc = best_threshold_from_val(y_true.astype(int), logits_cal, mode='mcc')
        if thr_prev is None:
            thr_prev = np.quantile(probs, 1.0 - max(y_true.mean(), 1e-6))
        preds_mcc = (probs >= thr_mcc).astype(int)
        preds_prev = (probs >= thr_prev).astype(int)

        m = classification_metrics(y_true, logits_cal)
        auroc = m.get('auroc', np.nan); auprc = m.get('auprc', np.nan)
        acc_mcc = float((preds_mcc == y_true).mean())
        acc_prev = float((preds_prev == y_true).mean())
        print(f"[TEST][{t}] AUROC={auroc:.3f} AUPRC={auprc:.3f} ACC@MCC={acc_mcc:.3f} ACC@prev={acc_prev:.3f}")

        df = pd.DataFrame({
            'Drug': test_df[test_df['task']==t]['Drug'].values,
            'y_true': y_true,
            'logit': logits_cal,
            'prob': probs,
            'pred_mcc': preds_mcc,
            'pred_prev': preds_prev
        })
        mis = df[df['pred_mcc'] != df['y_true']]
        path = os.path.join(diag_dir, f"{t}_classification_diag.csv")
        df.to_csv(path, index=False)

    summary_rows.append({'task': t, 'type': task_types[t], 'n': int(len(y_true))})

summary_df = pd.DataFrame(summary_rows)
summary_df.to_csv(os.path.join(OUT_DIR, "test_summary_tasks.csv"), index=False)
print("Per-task diagnostics saved to:", diag_dir)

if SAVE_TO_DRIVE:
    try:
        from google.colab import drive
        drive.mount('/content/drive', force_remount=False)
        os.makedirs(DRIVE_DIR, exist_ok=True)
        ts = int(time.time())
        import shutil
        dest_dir = os.path.join(DRIVE_DIR, f"test_results_{ts}")
        shutil.copytree(OUT_DIR, dest_dir, dirs_exist_ok=True)
        ckpt_dest = os.path.join(DRIVE_DIR, f"best_model_test_snapshot_{ts}.pth")
        torch.save({'model_state_dict': model.state_dict(),
                    'calib_T': calib_T,
                    'best_thresh': best_thresh,
                    'prev_thresh': prev_thresh}, ckpt_dest)
        print("Pushed diagnostics + checkpoint to Drive:", dest_dir, ckpt_dest)
    except Exception as e:
        print("Drive upload failed:", e)



Loaded model weights from checkpoints/best_gradnorm_pretrained.pth
[TEST][bbb_martins] AUROC=0.862 AUPRC=0.869 ACC@MCC=0.895 ACC@prev=0.865
[TEST][bioavailability_ma] AUROC=0.667 AUPRC=0.790 ACC@MCC=0.842 ACC@prev=0.789
[TEST][clearance_hepatocyte_az] RMSE=0.7819 MAE=0.5229
[TEST][clearance_microsome_az] RMSE=0.6668 MAE=0.4169
[TEST][cyp2c9_substrate_carbonmangels] AUROC=0.528 AUPRC=0.136 ACC@MCC=0.903 ACC@prev=0.806
[TEST][cyp2c9_veith] AUROC=0.842 AUPRC=0.757 ACC@MCC=0.783 ACC@prev=0.783
[TEST][cyp2d6_substrate_carbonmangels] AUROC=0.795 AUPRC=0.618 ACC@MCC=0.806 ACC@prev=0.742
[TEST][cyp2d6_veith] AUROC=0.847 AUPRC=0.612 ACC@MCC=0.844 ACC@prev=0.847
[TEST][cyp3a4_substrate_carbonmangels] AUROC=0.746 AUPRC=0.816 ACC@MCC=0.700 ACC@prev=0.733
[TEST][cyp3a4_veith] AUROC=0.849 AUPRC=0.818 ACC@MCC=0.759 ACC@prev=0.763
[TEST][hia_hou] AUROC=0.974 AUPRC=0.972 ACC@MCC=0.975 ACC@prev=0.950
[TEST][lipophilicity_astrazeneca] RMSE=0.7582 MAE=0.5677
[TEST][pgp_broccatelli] AUROC=0.931 AUPRC=0.951