In [1]:
import json
from pathlib import Path
import sys

import pandas as pd
from lightning import pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
import numpy as np
from sklearn.metrics import r2_score, mean_absolute_error
import torch
import wandb

from chemprop import data, featurizers, models, nn

In [2]:
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed_all(RANDOM_SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
NUM_WORKERS = 0 # number of workers for dataloader. 0 means using main process for data loading
smiles_column = 'smiles_std'
TARGET_COLUMNS = ['LogHLM', 'LogMLM', 'LogD', 'LogKSOL', 'LogMDR1-MDCKII']

In [4]:
def prepare_data(input_df):
    train_data, val_data = [], []
    for _, row in input_df.iterrows():
        dp = data.MoleculeDatapoint.from_smi(row[smiles_column], row[TARGET_COLUMNS].values)
        if row['split'] == 'train':
            train_data.append(dp)
        elif row['split'] == 'val':
            val_data.append(dp)

    pred_data = []
    for _, row in input_df.iterrows():
        dp = data.MoleculeDatapoint.from_smi(row[smiles_column], row[TARGET_COLUMNS].values)
        pred_data.append(dp)

    featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

    train_dset = data.MoleculeDataset(train_data, featurizer)
    # scaler = train_dset.normalize_targets()

    val_dset = data.MoleculeDataset(val_data, featurizer)
    # val_dset.normalize_targets(scaler)

    pred_dset = data.MoleculeDataset(pred_data, featurizer)

    return train_dset, val_dset, pred_dset

In [5]:
# def train_model(config, train_dset, val_dset, num_workers, scaler):
def train_model(config, train_dset, val_dset, num_workers, save_dir):
    # config is a dictionary containing hyperparameters used for the trial
    depth = int(config["depth"])
    ffn_hidden_dim = int(config["ffn_hidden_dim"])
    ffn_num_layers = int(config["ffn_num_layers"])
    message_hidden_dim = int(config["message_hidden_dim"])
    dropout = 0.2

    train_loader = data.build_dataloader(train_dset, num_workers=num_workers, shuffle=True)
    val_loader = data.build_dataloader(val_dset, num_workers=num_workers, shuffle=False)

    mp = nn.BondMessagePassing(d_h=message_hidden_dim, depth=depth, dropout=dropout)
    agg = nn.MeanAggregation()
    # output_transform = nn.UnscaleTransform.from_standard_scaler(scaler)
    # ffn = nn.RegressionFFN(
    #     n_tasks=len(target_columns),
    #     output_transform=output_transform, input_dim=message_hidden_dim, hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers,
    #     dropout=dropout
    # )
    ffn = nn.RegressionFFN(
        n_tasks=len(TARGET_COLUMNS),
        output_transform=None, input_dim=message_hidden_dim, hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers,
        dropout=dropout
    )
    batch_norm = True
    metric_list = [nn.metrics.MAE(), nn.metrics.R2Score()]
    model = models.MPNN(mp, agg, ffn, batch_norm, metric_list)

    ckpt_callback = ModelCheckpoint(
        save_top_k=0,
        save_last=True
    )

    exp_name = f"chemprop_run_{RUN_IDX}"
    logger = WandbLogger(
        project="admet-challenge",
        name=exp_name,
        prefix=f"{save_dir.stem}",
        save_dir=f"../wandb/{exp_name}"
    )

    trainer = pl.Trainer(
        accelerator="auto",
        devices=1,
        max_epochs=200, # number of epochs to train for
        # below are needed for Ray and Lightning integration
        enable_progress_bar=True,
        callbacks=[ckpt_callback],
        default_root_dir=save_dir,
        logger=logger
    )

    try:
        trainer.fit(model, train_loader, val_loader)
    except Exception as e:
        logger.finalize("failed")
        wandb.finish()
        raise e
    else:
        logger.finalize("success")

    return model

def predict(model, pred_dset, num_workers):
    pred_loader = data.build_dataloader(pred_dset, num_workers=num_workers, shuffle=False)
    
    trainer = pl.Trainer(
        accelerator="auto",
        devices=1,
        enable_progress_bar=True
    )

    model.eval()
    preds = trainer.predict(model, pred_loader, return_predictions=True)
    preds = torch.cat(preds)

    return preds

In [6]:
MODEL_CONFIG = {
    "depth": 5,
    "ffn_hidden_dim": 1500,
    "ffn_num_layers": 2,
    "message_hidden_dim": 300
}

In [7]:
# copypaste from https://github.com/asapdiscovery/asap-polaris-blind-challenge-examples/blob/a613051bac57060f686d9993e201ecaa15e51009/evaulation.py
# with a log-transform fix according to this issue https://github.com/asapdiscovery/asap-polaris-blind-challenge-examples/issues/14

from collections import defaultdict
from typing import Tuple

def mask_nan(y_true, y_pred):
    mask = ~np.isnan(y_true)
    y_true = np.array(y_true)[mask]
    y_pred = np.array(y_pred)[mask]
    return y_true, y_pred

def eval_admet(preds: dict[str, list], refs: dict[str, list]) -> Tuple[dict[str, float], np.ndarray]:
    """
    Eval ADMET targets with MAE for pre-log10 transformed targets (LogD) and MALE  (MAE on log10 transformed dataset) on non-transformed data

    This provides a "relative" error metric that will not be as sensitive to the large outliers with huge errors. This is sometimes known as MALE.

    Parameters
    ----------
    preds : dict[str, list]
        Dictionary of predicted ADMET values.
    refs : dict[str, list]
        Dictionary of reference ADMET values.

    Returns
    -------
    dict[str, float]
        Returns a dictonary of summary statistics
    """
    keys = {
        "MLM",
        "HLM",
        "KSOL",
        "LogD",
        "MDR1-MDCKII",
    }
    # will be treated as is
    logscale_endpts = {"LogD"}

    collect = defaultdict(dict)

    for k in keys:
        if k not in preds.keys() or k not in refs.keys():
            raise ValueError("required key not present")

        ref, pred = mask_nan(refs[k], preds[k])

        if k in logscale_endpts:
            # already log10scaled
            mae = mean_absolute_error(ref, pred)
            r2 = r2_score(ref, pred)
        else:
            # clip to a detection limit
            # epsilon = 1e-8
            # pred = np.clip(pred, a_min=epsilon, a_max=None)
            # ref = np.clip(ref, a_min=epsilon, a_max=None)

            # transform both log10scale
            pred_log10s = np.log10(pred + 1.)
            ref_log10s = np.log10(ref + 1.)

            # compute MALE and R2 in log space
            mae = mean_absolute_error(ref_log10s, pred_log10s)
            r2 = r2_score(ref_log10s, pred_log10s)

        collect[k]["mean_absolute_error"] = mae
        collect[k]["r2"] = r2

    # compute macro average MAE
    macro_mae = np.mean([collect[k]["mean_absolute_error"] for k in keys])
    collect["aggregated"]["macro_mean_absolute_error"] = macro_mae

    # compute macro average R2
    macro_r2 = np.mean([collect[k]["r2"] for k in keys])
    collect["aggregated"]["macro_r2"] = macro_r2

    return collect

In [8]:
def extract_preds(preds: pd.DataFrame):
    preds_dict = {}
    for t in TARGET_COLUMNS:
        if t in ["LogHLM", "LogMLM", "LogKSOL", "LogMDR1-MDCKII"]:
            # transform back to non-log scale
            preds_dict[t[3:]] = np.power(10, preds.iloc[:, preds.columns.get_loc(f"pred_{t}")].values) - 1.
        else:
            preds_dict[t] = preds.iloc[:, preds.columns.get_loc(f"pred_{t}")].values
    
    return preds_dict

def extract_refs(refs: pd.DataFrame):
    refs_dict = {}
    for t in TARGET_COLUMNS:
        if t in ["LogHLM", "LogMLM", "LogKSOL", "LogMDR1-MDCKII"]:
            refs_dict[t[3:]] = refs.iloc[:, refs.columns.get_loc(t[3:])].values
        else:
            refs_dict[t] = refs.iloc[:, refs.columns.get_loc(t)].values
    
    return refs_dict

In [9]:
def train_and_eval(input_paths, save_dirs):
    for input_path, save_dir in zip(input_paths, save_dirs):
        print(f"Training and predicting on {input_path}")
        input_df = pd.read_csv(input_path)
        train_dset, val_dset, pred_dset = prepare_data(input_df)
        model = train_model(MODEL_CONFIG, train_dset, val_dset, NUM_WORKERS, save_dir)
        preds = predict(model, pred_dset, NUM_WORKERS)

        output_df = input_df.copy()
        output_df[["pred_" + t for t in TARGET_COLUMNS]] = preds
        save_dir.mkdir(parents=True, exist_ok=True)
        output_df.to_csv(save_dir / "predictions.csv", index=False)

        train_preds = extract_preds(output_df[input_df["split"] == "train"])
        train_refs = extract_refs(input_df[input_df["split"] == "train"])
        val_preds = extract_preds(output_df[input_df["split"] == "val"])
        val_refs = extract_refs(input_df[input_df["split"] == "val"])

        metrics = eval_admet(train_preds, train_refs)
        print("Train metrics:")
        print(json.dumps(metrics, indent=2))

        metrics = eval_admet(val_preds, val_refs)
        print("\nVal metrics:")
        print(json.dumps(metrics, indent=2))
    
    wandb.finish()

## Run 0

In [10]:
input_paths = [Path(f'../data/asap/datasets/rnd_splits/split_{k}.csv') for k in range(5)]
save_dirs = [Path(f'../output/asap/rnd_splits/chemprop/run_0/split_{k}') for k in range(5)]
RUN_IDX = 0

In [None]:
train_and_eval(input_paths, save_dirs)

## Cleaning up + run 1

In [11]:
def clean_data(input_paths, save_dirs, output_dir, remove_worst_pct):
    smiles_to_remove = defaultdict(set)

    for input_path, save_dir in zip(input_paths, save_dirs):
        input_df = pd.read_csv(input_path)
        input_val_df = input_df[input_df["split"] == "val"]
        output_df = pd.read_csv(save_dir / "predictions.csv")
        output_val_df = output_df[input_df["split"] == "val"]

        for t in TARGET_COLUMNS:
            # Sort by absolute error
            notna_mask = input_val_df[t].notna()
            input_val_df = input_val_df[notna_mask]
            output_val_df = output_val_df[notna_mask]

            mae = np.abs(input_val_df[t] - output_val_df[f"pred_{t}"])
            sorted_idx = np.argsort(mae)[::-1]
            smiles_to_remove[t].update(
                input_val_df.iloc[sorted_idx[:int(remove_worst_pct * len(sorted_idx))]]["cxsmiles_std"].tolist()
            )

    for input_path in input_paths:
        input_df = pd.read_csv(input_path)
        for t in TARGET_COLUMNS:
            input_df.loc[input_df["cxsmiles_std"].isin(smiles_to_remove[t]) & (input_df["split"] == "train"), t] = np.nan

        input_df.to_csv(output_dir / input_path.name, index=False)

In [24]:
output_dir = Path("../output/asap/rnd_splits/chemprop/run_0/cleaned")
output_dir.mkdir(parents=True, exist_ok=True)

In [25]:
clean_data(input_paths, save_dirs, output_dir, remove_worst_pct = 0.2)

In [15]:
pd.read_csv(input_paths[0])[TARGET_COLUMNS].isna().sum()

LogHLM            133
LogMLM            131
LogD               90
LogKSOL            69
LogMDR1-MDCKII      9
dtype: int64

In [16]:
pd.read_csv(output_dir / input_paths[0].name)[TARGET_COLUMNS].isna().sum()

LogHLM            185
LogMLM            179
LogD              128
LogKSOL           102
LogMDR1-MDCKII     45
dtype: int64

In [33]:
input_paths = [Path(f'../output/asap/rnd_splits/chemprop/run_0/cleaned/split_{k}.csv') for k in range(5)]
save_dirs = [Path(f'../output/asap/rnd_splits/chemprop/run_1/split_{k}') for k in range(5)]
RUN_IDX = "1_clean_worst_pct_0.2"

In [None]:
train_and_eval(input_paths, save_dirs)

## Cleaning up + run 2

In [34]:
output_dir = Path("../output/asap/rnd_splits/chemprop/run_1/cleaned")
output_dir.mkdir(parents=True, exist_ok=True)

In [35]:
clean_data(input_paths, save_dirs, output_dir, remove_worst_pct = 0.2)

In [10]:
input_paths = [Path(f'../output/asap/rnd_splits/chemprop/run_1/cleaned/split_{k}.csv') for k in range(5)]
save_dirs = [Path(f'../output/asap/rnd_splits/chemprop/run_2/split_{k}') for k in range(5)]
RUN_IDX = "2_clean_worst_pct_0.2"

In [None]:
train_and_eval(input_paths, save_dirs)

## Cleaning up stero impure + run 1

In [12]:
def clean_data(input_paths, save_dirs, output_dir, remove_worst_pct):
    smiles_to_remove = defaultdict(set)

    for input_path, save_dir in zip(input_paths, save_dirs):
        input_df = pd.read_csv(input_path)
        input_val_df = input_df[input_df["split"] == "val"]
        output_df = pd.read_csv(save_dir / "predictions.csv")
        output_val_df = output_df[input_df["split"] == "val"]

        for t in TARGET_COLUMNS:
            # Sort by absolute error
            notna_mask = input_val_df[t].notna()
            input_val_df = input_val_df[notna_mask]
            output_val_df = output_val_df[notna_mask]

            mae = np.abs(input_val_df[t] - output_val_df[f"pred_{t}"])
            sorted_idx = np.argsort(mae)[::-1]
            smiles_to_remove[t].update(
                input_val_df.iloc[sorted_idx[:int(remove_worst_pct * len(sorted_idx))]]["cxsmiles_std"].tolist()
            )

    for input_path in input_paths:
        input_df = pd.read_csv(input_path)
        for t in TARGET_COLUMNS:
            input_df.loc[
                input_df["cxsmiles_std"].isin(smiles_to_remove[t]) & \
                    ~input_df["smiles_ext"].isna() & \
                    (input_df["split"] == "train"),
                t
            ] = np.nan

        input_df.to_csv(output_dir / input_path.name, index=False)

In [13]:
output_dir = Path("../output/asap/rnd_splits/chemprop/run_0/cleaned")
output_dir.mkdir(parents=True, exist_ok=True)

In [14]:
clean_data(input_paths, save_dirs, output_dir, remove_worst_pct = 0.2)

In [20]:
pd.read_csv(input_paths[0])[TARGET_COLUMNS].isna().sum()

LogHLM            133
LogMLM            131
LogD               90
LogKSOL            69
LogMDR1-MDCKII      9
dtype: int64

In [18]:
tmp = pd.read_csv(input_paths[0])
tmp = tmp[tmp["smiles_ext"].isna()]
tmp[TARGET_COLUMNS].isna().sum()

LogHLM            71
LogMLM            55
LogD              45
LogKSOL           35
LogMDR1-MDCKII     7
dtype: int64

In [16]:
pd.read_csv(output_dir / input_paths[0].name)[TARGET_COLUMNS].isna().sum()

LogHLM            166
LogMLM            158
LogD              105
LogKSOL            88
LogMDR1-MDCKII     27
dtype: int64

In [17]:
tmp = pd.read_csv(output_dir / input_paths[0].name)
tmp = tmp[tmp["smiles_ext"].isna()]
tmp[TARGET_COLUMNS].isna().sum()

LogHLM            71
LogMLM            55
LogD              45
LogKSOL           35
LogMDR1-MDCKII     7
dtype: int64

In [21]:
input_paths = [Path(f'../output/asap/rnd_splits/chemprop/run_0/cleaned/split_{k}.csv') for k in range(5)]
save_dirs = [Path(f'../output/asap/rnd_splits/chemprop/run_1/split_{k}') for k in range(5)]
RUN_IDX = "1_clean_worst_pct_0.2_stereo_impure"

In [None]:
train_and_eval(input_paths, save_dirs)

## Removing all stereo impure + run 0

In [24]:
input_paths = [Path(f'../data/asap/datasets/rnd_splits/split_{k}.csv') for k in range(5)]
output_dir = Path("../data/asap/datasets/rnd_splits/stereo_pure")
output_dir.mkdir(parents=True, exist_ok=True)

for input_path in input_paths:
    input_df = pd.read_csv(input_path)
    input_df = pd.concat([
        input_df[(input_df["split"] == "train") & input_df["smiles_ext"].isna()],
        input_df[input_df["split"] == "val"]
    ])
    input_df.to_csv(output_dir / input_path.name, index=False)

In [None]:
input_paths = [Path(f'../data/asap/datasets/rnd_splits/stereo_pure/split_{k}.csv') for k in range(5)]
save_dirs = [Path(f'../output/asap/rnd_splits/chemprop/run_0/stereo_pure/split_{k}') for k in range(5)]
RUN_IDX = "0_stereo_pure"
train_and_eval(input_paths, save_dirs)