In [None]:
import torch
import pickle
from XAI import *
from matplotlib import pyplot as plt
from tqdm import tqdm
from tdc.benchmark_group import admet_group
import os
from torch import stack, tensor, Generator, cat, float32, nonzero, set_float32_matmul_precision
from tdc.metadata import admet_metrics, bm_metric_names
from tdc.evaluator import Evaluator
from sklearn.metrics import matthews_corrcoef as mcc
from sklearn.metrics import confusion_matrix
from Utilities import *
import math

# TDC ADMET benchmarks that are regression tasks. Kept for reference only:
# the loop below detects the task type from the label column instead.
regression_keys = [
    'caco2_wang',
    'lipophilicity_astrazeneca',
    'solubility_aqsoldb',
    'ppbr_az',
    'ld50_zhu'
]

# Where the fine-tuned checkpoints live and where aggregated results are written.
CHECKPOINT_ROOT = 'TDC_checkpoints_old'
RESULTS_DIR = './results_dicts_new'

group = admet_group(path = '../data_tdc/')
names = group.dataset_names
seeds = [1,2,3,4,5]
models = [
          'qm_all_20L_wide_def_2e-5_16p',
          'masking_20L_wide_def_2e-5_16p',
          'homo-lumo_20L_wide_def_2e-5_16p',
          'scratch_20L_wide_def_2e-5_16p',
          'charges_20L_wide_def_2e-5_16p',
          'nmr_20L_wide_def_2e-5_16p',
          'fukui_n_20L_wide_def_2e-5_16p',
          'fukui_e_20L_wide_def_2e-5_16p'
         ]


def _find_best_checkpoint(run_dir):
    """Return the file name of the 'epoch*' checkpoint stored in ``run_dir``.

    ``os.listdir`` returns entries in arbitrary order, so the candidates are
    sorted to make the choice deterministic when several files match.

    Raises:
        FileNotFoundError: if no file starting with 'epoch' is present.
    """
    candidates = sorted(c for c in os.listdir(run_dir) if c.startswith('epoch'))
    if not candidates:
        raise FileNotFoundError(f"no checkpoint starting with 'epoch' found in {run_dir!r}")
    return candidates[0]


def _drop_invalid(frame):
    """Return the rows of ``frame`` whose 'Drug' entry passes ``check_valid``."""
    return frame[frame['Drug'].apply(check_valid)]


os.makedirs(RESULTS_DIR, exist_ok=True)

for model_name in models:
    results_dict = {}
    for name in names:
        # Everything in this section depends only on the benchmark, not on the
        # seed, so it is computed once per benchmark instead of once per run.
        ev_func = Evaluator(admet_metrics[name]).evaluator_func
        benchmark = group.get(name)
        train_val = benchmark['train_val']
        test = _drop_invalid(benchmark['test'])
        test_str = [chython.smiles(sm).pack() for sm in test['Drug']]
        y_true = test['Y'].values.reshape(-1,1)
        # Exactly two distinct labels => treated as binary classification.
        is_classification = len(train_val['Y'].unique()) == 2

        # NOTE(review): group.evaluate_many presumably scores against the
        # benchmark's full test set; if check_valid ever drops a test row the
        # prediction length would no longer match - confirm.
        predictions_list = []

        for seed in seeds:
            run_dir = f'{CHECKPOINT_ROOT}/{model_name}/{name}_{seed}/'
            checkpoint_path = run_dir + _find_best_checkpoint(run_dir)
            print(checkpoint_path)

            # NOTE(review): the previous version first built
            # GT(task_name='tdc', shared_weights=False, d_model=256, nhead=32,
            #    num_layers=20, dim_feedforward=512, norm_first=True,
            #    post_norm=True, zero_bias=True, reg_loss='l1')
            # and called load_from_checkpoint on that instance. Assuming GT is a
            # Lightning module, load_from_checkpoint is a classmethod that
            # returns a new object, so that instance (and its arguments) was
            # discarded; newer Lightning releases reject instance calls
            # outright. Confirm the checkpoints carry their hyperparameters.
            model = GT.load_from_checkpoint(checkpoint_path, strict=False)
            model.eval()
            model.freeze()

            y_pred_test = get_test_predictions(model, test_str)

            if is_classification:
                print("classification")
            else:
                print("regression")
                # Predictions are mapped back to the original label scale with
                # a scaler fitted on this seed's (validity-filtered) train split.
                train, _ = group.get_train_valid_split(benchmark = name, split_type = 'default', seed = seed)
                scaler = StandardScaler()
                scaler.fit(_drop_invalid(train)[['Y']])
                print('scaling back...')
                y_pred_test = scaler.inverse_transform(np.array([y_pred_test]).T)

            print(ev_func(y_true, y_pred_test))

            with open(f'{run_dir}predictions_2.pkl', 'wb') as f:
                pickle.dump(y_pred_test, f)
            predictions_list.append({name: y_pred_test})

        # Aggregate the per-seed predictions for this benchmark.
        results_dict.update(group.evaluate_many(predictions_list))

    with open(f'{RESULTS_DIR}/{model_name}.pkl', 'wb') as f:
        pickle.dump(results_dict, f)

In [None]:
import pickle as pk

# Run names whose aggregated result pickles are loaded below.
files = [
    'scratch_20L_wide_def_2e-5_16p',
    'qm_all_20L_wide_def_2e-5_16p',
    'charges_20L_wide_def_2e-5_16p',
    'nmr_20L_wide_def_2e-5_16p',
    'fukui_n_20L_wide_def_2e-5_16p',
    'fukui_e_20L_wide_def_2e-5_16p',
    'masking_20L_wide_def_2e-5_16p',
    'homo-lumo_20L_wide_def_2e-5_16p',
]

# Maps each run name to the object unpickled from its results file.
collected_results = {}

for run_name in files:
    results_path = f'./results_dicts_new/{run_name}.pkl'
    with open(results_path, 'rb') as handle:
        collected_results[run_name] = pk.load(handle)

In [None]:
# ADMET data extracted from the website
# (presumably the TDC ADMET-group leaderboard - confirm source and date).
# Layout: dataset -> {"metric": name, "leaderboard": {rank: {model: "mean ± spread"}}}.
# Each leaderboard holds rank 1, rank 2 and the last-ranked entry.
admet_data = {
    "TDC.Caco2_Wang": {
        "metric": "MAE",
        "leaderboard": {
            1: {"MapLight": "0.276 ± 0.005"},
            2: {"BaseBoosting": "0.285 ± 0.005"},
            17: {"Morgan + MLP (DeepPurpose)": "0.908 ± 0.060"}
        }
    },
    "TDC.Bioavailability_Ma": {
        "metric": "AUROC",
        "leaderboard": {
            1: {"SimGCN": "0.748 ± 0.033"},
            2: {"MapLight + GNN": "0.742 ± 0.010"},
            16: {"Basic ML": "0.523 ± 0.011"}
        }
    },
    "TDC.Lipophilicity_AstraZeneca": {
        "metric": "MAE",
        "leaderboard": {
            1: {"Chemprop-RDKit": "0.467 ± 0.006"},
            2: {"Chemprop": "0.470 ± 0.009"},
            15: {"CNN (DeepPurpose)": "0.743 ± 0.020"}
        }
    },
    "TDC.Solubility_AqSolDB": {
        "metric": "MAE",
        "leaderboard": {
            1: {"Chemprop-RDKit": "0.761 ± 0.025"},
            2: {"AttentiveFP": "0.776 ± 0.008"},
            14: {"Morgan + MLP (DeepPurpose)": "1.203 ± 0.019"}
        }
    },
    "TDC.HIA_Hou": {
        "metric": "AUROC",
        "leaderboard": {
            1: {"MapLight + GNN": "0.989 ± 0.001"},
            2: {"RFStacker": "0.988 ± 0.002"},
            16: {"Morgan + MLP (DeepPurpose)": "0.807 ± 0.072"}
        }
    },
    "TDC.Pgp_Broccatelli": {
        "metric": "AUROC",
        "leaderboard": {
            1: {"MapLight + GNN": "0.938 ± 0.002"},
            2: {"ZairaChem": "0.935 ± 0.006"},
            16: {"Basic ML": "0.818 ± 0.000"}
        }
    },
    "TDC.BBB_Martins": {
        "metric": "AUROC",
        "leaderboard": {
            1: {"MapLight": "0.916 ± 0.001"},
            2: {"Lantern RADR Ensemble": "0.915 ± 0.002"},
            21: {"Euclia ML model": "0.725 ± 0.019"}
        }
    },
    "TDC.PPBR_AZ": {
        "metric": "MAE",
        "leaderboard": {
            1: {"MapLight + GNN": "7.526 ± 0.106"},
            2: {"MapLight": "7.660 ± 0.058"},
            15: {"Morgan + MLP (DeepPurpose)": "12.848 ± 0.362"}
        }
    },
    "TDC.VDss_Lombardo": {
        "metric": "Spearman",
        "leaderboard": {
            1: {"MapLight + GNN": "0.713 ± 0.007"},
            2: {"MapLight": "0.707 ± 0.009"},
            15: {"CNN (DeepPurpose)": "0.226 ± 0.114"}
        }
    },
    "TDC.CYP2C9_Veith": {
        "metric": "AUPRC",
        "leaderboard": {
            1: {"MapLight + GNN": "0.859 ± 0.001"},
            2: {"ContextPred": "0.839 ± 0.003"},
            15: {"Basic ML": "0.556 ± 0.000"}
        }
    },
    "TDC.CYP2D6_Veith": {
        "metric": "AUPRC",
        "leaderboard": {
            1: {"MapLight + GNN": "0.790 ± 0.001"},
            2: {"ContextPred": "0.739 ± 0.005"},
            15: {"Euclia ML model": "0.348 ± 0.004"}
        }
    },
    "TDC.CYP3A4_Veith": {
        "metric": "AUPRC",
        "leaderboard": {
            1: {"MapLight + GNN": "0.916 ± 0.000"},
            2: {"ContextPred": "0.904 ± 0.002"},
            15: {"Basic ML": "0.654 ± 0.000"}
        }
    },
    "TDC.CYP2C9_Substrate_CarbonMangels": {
        "metric": "AUPRC",
        "leaderboard": {
            # NOTE(review): the rank-1 row was listed twice under key 1 and no
            # rank-2 row is recorded; the duplicate (a silent dict-key
            # overwrite) has been removed. Fill in rank 2 if it is available.
            1: {"ZairaChem": "0.441 ± 0.033"},
            15: {"Euclia ML model": "0.347 ± 0.018"}
        }
    },
    "TDC.CYP2D6_Substrate_CarbonMangels": {
        "metric": "AUPRC",
        "leaderboard": {
            1: {"ContextPred": "0.736 ± 0.024"},
            2: {"MapLight + GNN": "0.720 ± 0.002"},
            15: {"Basic ML": "0.478 ± 0.018"}
        }
    },
    "TDC.CYP3A4_Substrate_CarbonMangels": {
        "metric": "AUROC",
        "leaderboard": {
            # NOTE(review): same situation as CYP2C9_Substrate - duplicated
            # rank-1 row removed, rank 2 is not recorded.
            1: {"CNN (DeepPurpose)": "0.662 ± 0.031"},
            15: {"NeuralFP": "0.578 ± 0.020"}
        }
    },
    "TDC.Half_Life_Obach": {
        "metric": "Spearman",
        "leaderboard": {
            1: {"CAF": "0.576 ± 0.025"},
            2: {"MapLight": "0.562 ± 0.008"},
            15: {"AttentiveFP": "0.085 ± 0.068"}
        }
    },
    "TDC.Clearance_Hepatocyte_AZ": {
        "metric": "Spearman",
        "leaderboard": {
            1: {"MapLight + GNN": "0.498 ± 0.009"},
            2: {"MapLight": "0.466 ± 0.012"},
            13: {"Morgan + MLP (DeepPurpose)": "0.272 ± 0.068"}
        }
    },
    "TDC.Clearance_Microsome_AZ": {
        "metric": "Spearman",
        "leaderboard": {
            1: {"MapLight + GNN": "0.630 ± 0.010"},
            2: {"MapLight": "0.626 ± 0.008"},
            16: {"CNN (DeepPurpose)": "0.252 ± 0.116"}
        }
    },
    "TDC.LD50_Zhu": {
        "metric": "MAE",
        "leaderboard": {
            1: {"BaseBoosting": "0.552 ± 0.009"},
            2: {"MACCS keys + autoML": "0.588 ± 0.005"},
            14: {"CNN (DeepPurpose)": "0.675 ± 0.011"}
        }
    },
    "TDC.hERG": {
        "metric": "AUROC",
        "leaderboard": {
            1: {"MapLight + GNN": "0.880 ± 0.002"},
            2: {"SimGCN": "0.874 ± 0.014"},
            10: {"CNN (DeepPurpose)": "0.754 ± 0.037"}
        }
    },
    "TDC.AMES": {
        "metric": "AUROC",
        "leaderboard": {
            1: {"ZairaChem": "0.871 ± 0.002"},
            2: {"MapLight + GNN": "0.869 ± 0.002"},
            13: {"CNN (DeepPurpose)": "0.776 ± 0.015"}
        }
    },
    "TDC.DILI": {
        "metric": "AUROC",
        "leaderboard": {
            1: {"ZairaChem": "0.925 ± 0.005"},
            2: {"AttrMasking": "0.919 ± 0.008"},
            15: {"CNN (DeepPurpose)": "0.792 ± 0.016"}
        }
    }
}

def parse_values(value_str):
    """Convert a ``"<mean> ± <spread>"`` string into ``[mean, spread]``.

    Both numbers are returned as floats; whitespace around either number is
    ignored. A string that does not contain exactly one "±" raises
    ``ValueError`` when the two fields are unpacked.
    """
    centre, spread = value_str.split("±")  # exactly two fields expected
    return [float(token.strip()) for token in (centre, spread)]

# final_dict[slot][dataset] holds [mean, spread] for the leaderboard's first,
# second and last entry; final_dict["Metrics"][dataset] holds the metric name.
# Dataset keys are normalised by dropping the "TDC." prefix and lower-casing
# (e.g. "TDC.Caco2_Wang" -> "caco2_wang").
final_dict = {"Rank1": {}, "Rank2": {}, "RankLast": {}, "Metrics": {}}

_TDC_PREFIX = "TDC."
# Ranks 1 and 2 get their own slot; any other rank is the last-place entry.
_RANK_SLOTS = {1: "Rank1", 2: "Rank2"}

# Loop variables are deliberately not called `data` / `models` / `model`, so
# running this cell does not clobber the globals of the evaluation cell.
for dataset_label, entry in admet_data.items():
    if dataset_label.startswith(_TDC_PREFIX):
        dataset_key = dataset_label[len(_TDC_PREFIX):].lower()
    else:
        dataset_key = dataset_label.lower()

    final_dict["Metrics"][dataset_key] = entry["metric"]
    for rank, scores in entry["leaderboard"].items():
        slot = _RANK_SLOTS.get(rank, "RankLast")
        for score_text in scores.values():
            final_dict[slot][dataset_key] = parse_values(score_text)

# Leaderboards that lack an entry for a slot get NaNs, so later lookups
# (e.g. plotting reference lines) do not fail on a missing key.
for slot in ("Rank1", "Rank2", "RankLast"):
    for dataset_key in final_dict["Metrics"]:
        if dataset_key not in final_dict[slot]:
            print('missing data nan filling')
            final_dict[slot][dataset_key] = [float('nan'), float('nan')]

In [None]:
# One bar chart per dataset: a bar (with error bar) per run, plus horizontal
# reference lines for the leaderboard entries stored in final_dict
# (red = Rank1, blue = Rank2, green = RankLast).
metric_dict = bm_metric_names['admet_group']

# Dataset names are taken from the first run's results.
datasets = list(collected_results[next(iter(collected_results))].keys())

num_datasets = len(datasets)
# Smallest square grid that fits every dataset (at least 1x1).
grid_size = max(1, math.ceil(math.sqrt(num_datasets)))

# squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.
fig, axes = plt.subplots(grid_size, grid_size, figsize=(15, 20), squeeze=False)
axes_flat = axes.flatten()

reference_lines = (('Rank1', 'r'), ('Rank2', 'b'), ('RankLast', 'g'))

for ax, dataset in zip(axes_flat, datasets):
    for main_key, per_dataset in collected_results.items():
        mean_value, std_dev = per_dataset[dataset]
        ax.bar(main_key, mean_value, yerr=std_dev, label=f'{main_key} (std: {std_dev:.3f})')

    # The reference values do not depend on the run, so draw them once per panel.
    for rank_key, colour in reference_lines:
        ax.axhline(final_dict[rank_key][dataset][0], c = colour, linestyle = '-.', linewidth = 1.5)

    ax.set_ylabel('Mean Value')
    ax.set_title(f'{dataset} ({metric_dict[dataset]})')

    # Rotate the tick labels created by the categorical bars instead of
    # re-assigning label text to ticks that were never fixed.
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

# Hide the panels of the square grid that have no dataset.
for unused_ax in axes_flat[num_datasets:]:
    unused_ax.set_visible(False)

plt.tight_layout()

plt.show()

In [None]:
import pandas as pd

def reorganize_dictionary(original_dict):
    """Swap the two levels of a nested mapping.

    Takes ``{model: {task: result}}`` and returns ``{task: {model: result}}``.
    Tasks and models keep the order in which they are first encountered, and
    the result objects themselves are reused, not copied.
    """
    by_task = {}
    for model_name, per_task in original_dict.items():
        for task_name, outcome in per_task.items():
            by_task.setdefault(task_name, {})[model_name] = outcome
    return by_task

# Build a task x model table of "mean ± std" strings and print it as LaTeX.
performance_dict = reorganize_dictionary(collected_results)

# One record per task; combinations that were not evaluated stay missing (NaN).
table_rows = [
    {
        model_name: f'{mean_value:.3f} ± {std_value:.3f}'
        for model_name, (mean_value, std_value) in per_model.items()
    }
    for per_model in performance_dict.values()
]

# Columns in the order models are first encountered, rows in task order.
table_columns = list(dict.fromkeys(
    model_name for row in table_rows for model_name in row
))

# The frame is built in one go rather than enlarged one cell at a time.
df = pd.DataFrame(table_rows, index=list(performance_dict), columns=table_columns)

print(df.to_latex())