In [None]:
RETRAIN_XGBOOST = False
RETRAIN_FP_MODEL = False
RETRAIN_GNN_MODEL = False
RETRAIN_SSL_MODEL = False
RETRAIN_BERT_MODEL = False

# Machine Learning for Predicting Targeted Protein Degradation

## Notes

### Machine Learning Model

The model will try to predict whether a given PROTAC is active or not, effectively making it a binary classification task.

### Biochemistry Notes

Some notes about the biochemistry behind the PROTACs (it might contain some errors/silly statements, as I'm not a biochemist):

* A gene is a portion of the DNA in a chromosome
* A gene starts and ends with a specific sequence
* A gene is "copied" (transcribed) into an mRNA, which the ribosome then translates into the protein
* Three bases (a codon) encode one amino acid in the protein. An amino acid can be encoded by several triplets (side note: the more triplets encode the same amino acid, the less likely it is that a mutation results in a different amino acid being encoded)
* Genes can be slightly different in different organisms; that's why we have different UniProt IDs even though the gene reported in the entry is the same
* The cell type _might_ refer to the cell line used to conduct the experiments. Some cells can be difficult to handle/grow in the lab. Also, even though different cell types may express the same protein, the protein can be slightly different in different cells. Finally, the cell itself can influence the PROTAC response and, in turn, result in different $DC_{50}$ values
* Intuitively, $IC_{50}$ (the half-maximal inhibitory concentration) measures how well two molecules bind together. That's why it is reported for different pairs, like E3-e3_ligase, Warhead-POI, et cetera.
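
To make the triplet point above concrete, here is a toy translation sketch. The codon table is deliberately partial (a real one has 64 entries) and is only meant for illustration: it shows that changing the third base within a four-fold degenerate family (Ala: GCT/GCC/GCA/GCG) leaves the protein unchanged, while a single-base change in the only Trp codon (TGG) turns it into a stop codon.

```python
# Partial codon table (DNA coding strand -> one-letter amino acid code).
# '*' marks a stop codon. Only a handful of the 64 codons are listed here.
CODON_TABLE = {
    'ATG': 'M',                                      # Met (also the start codon)
    'GCT': 'A', 'GCC': 'A', 'GCA': 'A', 'GCG': 'A',  # Ala: 4 synonymous codons
    'GGT': 'G', 'GGC': 'G', 'GGA': 'G', 'GGG': 'G',  # Gly: 4 synonymous codons
    'TTT': 'F', 'TTC': 'F',                          # Phe: 2 synonymous codons
    'TGG': 'W',                                      # Trp: a single codon
    'TAA': '*', 'TAG': '*', 'TGA': '*',              # stop codons
}


def translate(dna: str) -> str:
    """Read the sequence three bases at a time until a stop codon is reached."""
    protein = []
    for i in range(0, len(dna) - len(dna) % 3, 3):
        aa = CODON_TABLE[dna[i:i + 3]]
        if aa == '*':
            break
        protein.append(aa)
    return ''.join(protein)


print(translate('ATGGCTGCATGGTAA'))  # MAAW
# Point mutation GCT -> GCC (3rd base): silent, Ala is still encoded
print(translate('ATGGCCGCATGGTAA'))  # MAAW
# Point mutation TGG -> TGA: Trp becomes a stop codon
print(translate('ATGGCTGCATGATAA'))  # MAA
```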

### Technical Stuff

* Almost everything is stored in categorical form; PyTorch might not handle categorical data directly, so it may need to be encoded first...

### TODOs on Technical Stuff

* Is there a framework for automatically conducting ablation studies?
* Shall I share information/work via jupyter/colab/github?
    * Yes, I'm working on public data, so it's fine (but do it at the end of the project)
* ~~Better organize the checkpoints, the evaluation results and the plots. Maybe a common CSV file?~~
* ~~Add single mapping argument for wrapper model inside `train_model`~~
* ~~Add `get_smiles_embedding_size` method to the sub-models, instead of a wrapper-model argument~~
* ~~Is Optuna really popular to be used in larger projects?~~
    * Yes, AZ even has its in-house development of Optuna (Eva can point me to the persons working on it to discuss things)
* ~~Implement an Optuna callback for deleting all but the best model checkpoints~~

### TODOs

* What loss function shall we use? The Huber loss seems to be a good candidate, but how do we set its $\delta$ parameter? How many "outliers" are there?
* The final model shall include some representation of the SMILES (either fingerprints or BERT-based or GNN-based), together with other features like E3 ligase, cell type, et cetera

* When encoding SMILES to graphs, what about using the binding affinity as node features?
* Encode/pass to the model the **binding pockets**, _i.e._, the amino acids binding to the PROTAC
* Encode/pass to the model the gene ID
    * The POI sequence itself is not "useful": if we were to extract the embeddings from AlphaFold, which is _trained_ to generate the 3D structure, then we might gain something from the POI sequence. Also, it would be useful to leverage information from the 3D structure of the _complex_, which we currently don't have 
    * How do we represent the POI amino acid sequence anyway?
        * Count vectorizer?
        * Specific tokenizer?
        * MSA?
        * Custom embedding?
        * AlphaFold?
* Web scraping degradation percentages from the Western Blot figures, which are only available online
* Instead of predicting a value, what about returning the function parameters of the $DC_{50}$ response?
    * Interesting idea, but we don't have that much information from the dataset unfortunately
        * We might have it actually, but only for a restricted number of entries...
* At AZ, we have some PROTAC patents which might contain extra data we can use
* ~~Include mutations for different genes for the same Uniprot ID~~
* ~~Reproduce Rocío's student model~~
* ~~Explore additional fingerprints~~:
    * ~~Morgan (already included)~~
    * ~~MACCS~~
    * ~~Path-based (at different lengths to eventually capture how things are connected to each other and the long linker atoms)~~
* ~~Predict the concentration value (**in log base 10!**)~~
* ~~Check why number of Uniprot ID is different from number of gene entries~~
    * I need to update the entries to include some _mutations_: they are not captured by the Uniprot ID, but should be easy to include
* ~~Normalize the concentration:~~
    1. ~~Convert nM to M~~
    2. ~~Take the negative logarithm~~
* ~~The cell type might be case insensitive, double check it with Eva~~
    * It is case sensitive
* ~~Check if DS biased towards a certain E3~~
    * Yes, it is
* ~~Double check if the current DB is the "full one"~~
    * Yes, it is
* ~~Get finer details like the canonical SMILES representation (RDKit can compute the canonical one)~~
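
The concentration normalization listed above (convert nM to M, then take the negative logarithm) boils down to $pDC_{50} = -\log_{10}(DC_{50}\,[\mathrm{M}]) = 9 - \log_{10}(DC_{50}\,[\mathrm{nM}])$. A minimal NumPy sketch, assuming the raw $DC_{50}$ values are in nM; the helper name `dc50_nm_to_pdc50` is hypothetical and not necessarily how the processed CSV files were generated:

```python
import numpy as np


def dc50_nm_to_pdc50(dc50_nm):
    """Hypothetical helper: convert DC50 values in nM to pDC50 = -log10(DC50 in M)."""
    dc50_m = np.asarray(dc50_nm, dtype=float) * 1e-9  # step 1: nM -> M
    return -np.log10(dc50_m)                          # step 2: negative base-10 log


# 1 nM, 100 nM and 1 uM map to pDC50 values of (approximately) 9, 7 and 6
print(dc50_nm_to_pdc50([1.0, 100.0, 1000.0]))
```

Higher $pDC_{50}$ therefore means a more potent degrader, and each unit corresponds to a tenfold change in concentration.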

## Future Work Ideas

* Data augmentations:
    * Apply upsampling w.r.t. features like E3 ligase, cell type, et cetera
    * Scramble the SMILES
* SSL:
    * Contrastive learning 
    * Apply active learning and [semi-supervised learning](https://lilianweng.github.io/posts/2021-12-05-semi-supervised/) techniques
* Explainability:
    * Detailed analysis of feature importance from the XGBoost model and SSL Transformer models
    * Chemical space analysis of PROTAC-DB vs. PROTAC-pedia (UMAP, PCA, t-SNE, et cetera, of fingerprints and SMILES (show the variance in the plots!))
    * Plot the model's predictions on the $D_{max}$ vs. $pDC_{50}$ graph. I suspect (and hope) that the model will struggle to predict points on the border between active and inactive PROTACs. (Add two dotted lines to the plot to show the activity thresholds)
* Model-related ideas:
    * Try residual connections
    * Try even more types of fingerprints, like [CDDD](https://github.com/jrwnter/cddd)
    * Combine/add multiple fingerprints, like Morgan and path-based
    * Advanced embeddings for other features, like the POI sequence
        * How to deal with the different cell types?
            * One-hot encoding?
            * Embedding?
            * Other?
* Ensemble Methods: Ensemble methods involve training multiple models independently and then combining their predictions to make final decisions. Techniques such as bagging, boosting, and stacking are commonly used to aggregate the predictions of multiple models.
* Using PROTAC-pedia entries to train on their definition of active/inactive and then finetune on PROTAC-DB and a more stringent definition of active/inactive
* Predict the regression task

* Activity cliff prediction?
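
As one concrete instance of the ensemble idea above, here is a minimal soft-voting sketch: it averages the predicted probability of being active from several already-trained classifiers and thresholds the average at 0.5. The probabilities are made-up numbers and the optional per-model `weights` are an assumption; bagging, boosting and stacking would replace the plain average with resampling or a learned meta-model.

```python
import numpy as np


def soft_vote(probas, weights=None):
    """Average per-model P(active) and threshold the average at 0.5.

    probas: array-like of shape (n_models, n_samples) with predicted probabilities.
    weights: optional per-model weights (e.g., validation scores).
    """
    probas = np.asarray(probas, dtype=float)
    avg = np.average(probas, axis=0, weights=weights)
    return avg, (avg >= 0.5).astype(int)


# Three hypothetical models (e.g., XGBoost, fingerprint MLP, GNN) on four PROTACs
probas = [[0.9, 0.4, 0.2, 0.6],
          [0.8, 0.6, 0.1, 0.4],
          [0.7, 0.7, 0.3, 0.3]]
avg, labels = soft_vote(probas)
print(avg, labels)
```

Soft voting keeps each model's confidence, so a model that is very sure about one PROTAC can outweigh two that are barely on the other side of 0.5.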

## Setup and Imports

In [None]:
import optuna

optuna.logging.set_verbosity(optuna.logging.WARN) #INFO, WARN

In [None]:
# import ray
# from ray import air, tune
# from ray.air import session
# from ray.tune import CLIReporter
# from ray.tune.schedulers import (ASHAScheduler,
#                                  PopulationBasedTraining,
#                                  HyperBandScheduler)
# from ray.tune.integration.pytorch_lightning import (TuneReportCallback,
#                                                     TuneReportCheckpointCallback)

# from ray.tune.search import ConcurrencyLimiter
# from ray.tune.search.optuna import OptunaSearch
# from ray.air import CheckpointConfig

In [None]:
from IPython.display import display_html
# from IPython.core.interactiveshell import InteractiveShell
# InteractiveShell.ast_node_interactivity = 'all'

import collections
import itertools
import re
import gc
import math
import numpy as np
import pandas as pd
import pickle
import requests as r
import matplotlib.pyplot as plt
import seaborn as sns
import shutil
import random
import copy
import os

import typing
from typing import Mapping, Literal, Callable, List, ClassVar, Any, Tuple, Type

from uuid import uuid4
from rdkit import Chem
from rdkit.Chem import AllChem, DataStructs, MACCSkeys, Draw
from rdkit.Chem.Draw import IPythonConsole
from datetime import date
from scipy.sparse import csr_matrix, vstack
from tqdm import tqdm

import sklearn
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.model_selection import train_test_split, GroupShuffleSplit
from sklearn import preprocessing
from sklearn import metrics
from sklearn.metrics import classification_report, f1_score, roc_auc_score
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.utils import resample, class_weight

pd.set_option('display.max_columns', 1000, 'display.width', 2000, 'display.max_colwidth', 100)

In [None]:
import xgboost as xgb

In [None]:
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger

# import lightning as pl
# from lightning import LightningModule, Trainer, seed_everything
# from pytorch_lightning.callbacks import ModelCheckpoint
# from pytorch_lightning.callbacks.progress import TQDMProgressBar
# from pytorch_lightning.callbacks.early_stopping import EarlyStopping
# from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger


import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader, random_split

from torchvision.ops import MLP

import torch_geometric
import torch_geometric.nn as geom_nn
import torch_geometric.data as geom_data
from torch_geometric.utils.smiles import from_smiles

from torchmetrics import (Accuracy,
                          AUROC,
                          ROC,
                          Precision,
                          Recall,
                          F1Score,
                          MeanAbsoluteError,
                          MeanSquaredError)
from torchmetrics.functional import (mean_absolute_error,
                                     mean_squared_error,
                                     mean_squared_log_error,
                                     pearson_corrcoef,
                                     r2_score)
from torchmetrics.functional.classification import (binary_accuracy,
                                                    binary_auroc,
                                                    binary_precision,
                                                    binary_recall,
                                                    binary_f1_score)

# Sets seeds for numpy, torch and python.random.
torch.manual_seed(42)
np.random.seed(42)
pl.seed_everything(42, workers=True)
# torch.use_deterministic_algorithms(True) # TODO: This is a GPU-related thing..

Reduce logging information:

In [None]:
import logging
import warnings
import re

def set_global_logging_level(level=logging.ERROR, prefices=[""]):
    """
    Override logging levels of different modules based on their name as a prefix.
    It needs to be invoked after the modules have been loaded so that their loggers have been initialized.

    Args:
        - level: desired level. e.g. logging.INFO. Optional. Default is logging.ERROR
        - prefices: list of one or more str prefices to match (e.g. ["transformers", "torch"]). Optional.
          Default is `[""]` to match all active loggers.
          The match is a case-sensitive `module_name.startswith(prefix)`
    """
    prefix_re = re.compile(fr'^(?:{ "|".join(prefices) })')
    for name in logging.root.manager.loggerDict:
        if re.match(prefix_re, name):
            logging.getLogger(name).setLevel(level)

# Filter out annoying Pytorch Lightning printouts
warnings.filterwarnings('ignore', '.*does not have many workers.*')
warnings.filterwarnings('ignore', '.*Checkpoint directory.*')
warnings.filterwarnings('ignore', '.*The number of training batches.*')
warnings.filterwarnings('ignore', '.*is an instance of.*')
logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
logging.getLogger("pytorch_lightning.utilities.rank_zero_warn").setLevel(logging.ERROR)
set_global_logging_level(logging.ERROR, ['transformers', 'nlp', 'torch', 'tensorflow', 'tensorboard', 'wandb', 'xgboost'])
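
Because the match is a plain "starts with" check, a prefix such as `'torch'` also silences `torchmetrics`, `torchvision` and `torch_geometric`, which is convenient here. The sketch below makes this visible on throwaway loggers (the logger names are hypothetical). It re-declares the helper with one change of our own, `re.escape`, so that prefixes containing regex metacharacters such as `.` are matched literally; a tuple default also avoids a mutable default argument.

```python
import logging
import re


def set_global_logging_level(level=logging.ERROR, prefices=("",)):
    # Same logic as the helper above, but prefixes are escaped (our addition)
    prefix_re = re.compile(r'^(?:' + '|'.join(map(re.escape, prefices)) + r')')
    for name in logging.root.manager.loggerDict:
        if prefix_re.match(name):
            logging.getLogger(name).setLevel(level)


# Hypothetical logger names, created only for this demo
demo_names = ['torch.demo', 'torchmetrics.demo', 'mytorch.demo']
for name in demo_names:
    logging.getLogger(name).setLevel(logging.INFO)

set_global_logging_level(logging.ERROR, ['torch'])
for name in demo_names:
    print(name, logging.getLevelName(logging.getLogger(name).level))
```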

Setup directories:

In [None]:
data_dir = os.path.join(os.getcwd(), '..', 'data')
src_dir = os.path.join(os.getcwd(), '..', 'src')
fig_dir = os.path.join(data_dir, 'figures')
checkpoint_dir = os.path.join(os.getcwd(), '..', 'checkpoints')
# Create the directories if they don't exist yet
for d in [data_dir, src_dir, fig_dir, checkpoint_dir]:
    os.makedirs(d, exist_ok=True)

In [None]:
import networkx as nx



# edge_index = torch.tensor([[0, 1, 1, 2],
#                            [1, 0, 2, 1]], dtype=torch.long)
# x = torch.tensor([[-1], [0], [1]], dtype=torch.float)

# data = torch_geometric.data.Data(x=x, edge_index=edge_index)

data = from_smiles('CN1C=NC2=C1C(=O)N(C(=O)N2C)C')
# g = torch_geometric.utils.to_networkx(data, to_undirected=True)
# nx.draw(g)


import networkx as nx
from matplotlib import pyplot as plt
from torch_geometric.nn import to_hetero

g = torch_geometric.utils.to_networkx(data)
# Networkx seems to create extra nodes from our heterogeneous graph, so I remove them
isolated_nodes = [node for node in g.nodes() if g.out_degree(node) == 0]
[g.remove_node(i_n) for i_n in isolated_nodes]
# Plot the graph
nx.draw(g, with_labels=True)
plt.show()

plt.figure(figsize=(10, 10))
plt.matshow(data.x.numpy())
plt.title('Node feature matrix')
plt.ylabel('Node index')
plt.xlabel('Node features')
plt.colorbar()
plt.show()

plt.figure(figsize=(10, 10))
plt.matshow(data.edge_index.numpy())
plt.title('Edge matrix')
plt.xlabel('Edge index')
plt.ylabel('Edge')
plt.show()


In [None]:
from torch_geometric.utils.convert import to_scipy_sparse_matrix, from_scipy_sparse_matrix

adj = to_scipy_sparse_matrix(data.edge_index)
plt.spy(adj)
plt.title('Adjacency matrix')
plt.show()
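
`plt.spy` above shows the nonzero pattern of the adjacency matrix derived from `edge_index`, which PyG stores in coordinate (COO) format: row 0 holds source nodes, row 1 target nodes, and each bond appears once per direction. A minimal NumPy sketch of that conversion on a hypothetical three-node path graph (0-1-2), showing why the resulting adjacency matrix is symmetric:

```python
import numpy as np

# Toy COO edge_index for the path graph 0-1-2, stored in both directions
# (row 0 = source nodes, row 1 = target nodes)
edge_index = np.array([[0, 1, 1, 2],
                       [1, 0, 2, 1]])
num_nodes = 3

# Dense adjacency matrix: one nonzero entry per directed edge
adj = np.zeros((num_nodes, num_nodes), dtype=int)
adj[edge_index[0], edge_index[1]] = 1

print(adj)
print('symmetric:', (adj == adj.T).all())
```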

## Load Datasets

Load PROTAC-DB, used for training and validation:

In [None]:
df_file = os.path.join(data_dir, 'processed', 'protac-db_dc50_dmax.csv')
protac_df = pd.read_csv(df_file)
protac_db_df = pd.concat([
    protac_df['Smiles'],
    protac_df['Smiles_nostereo'],
    protac_df['DC50'].astype(float),
    protac_df['pDC50'].astype(float),
    protac_df['Dmax'].astype(float),
    protac_df['poi_gene_id'],
    protac_df['poi_seq'],
    protac_df['cell_type'],
    protac_df['e3_ligase'],
    # protac_df['treatment_hours'], # NOTE: Not used in this analysis...
    protac_df['active'],
], axis=1).reset_index(drop=True)
print('protac_db_df: {:,} x {:,}'.format(*protac_db_df.shape))

dataframes = {}
files = [
    'protac-db_DC50_Dmax',
    'protac-db_ssl',
    'protac-db_no_degradation',
    'protac-db_interpolated',
]
relevant_cols = [
    'DC50',
    'pDC50',
    'Dmax',
    'poi_gene_id',
    'poi_seq',
    'cell_type',
    'e3_ligase',
    'Smiles',
    'Smiles_nostereo',
    # 'treatment_hours', # NOTE: Not used in our work...
    'active',
]
for f in files:
    df_file = os.path.join(data_dir, 'processed', f + '.csv')
    print(f'Loading "{f}"...')
    dataframes[f] = pd.read_csv(df_file, usecols=relevant_cols).reset_index(drop=True)
    display(dataframes[f])

Load PROTAC-Pedia, used for testing:

In [None]:
df_file = os.path.join(data_dir, 'processed', 'protac-pedia_dc50_dmax.csv')
protac_df = pd.read_csv(df_file)
protac_pedia_df = pd.concat([
    protac_df['DC50'].astype(float),
    protac_df['pDC50'].astype(float),
    protac_df['Dmax'].astype(float),
    protac_df['poi_seq'],
    protac_df['cell_type'],
    protac_df['active'],
    protac_df['Active/Inactive'],
], axis=1).reset_index(drop=True)
protac_pedia_df['e3_ligase'] = protac_df['E3 Ligase']
protac_pedia_df['poi_gene_id'] = 'Unknown'
protac_pedia_df['Smiles'] = protac_df['PROTAC SMILES']
protac_pedia_df['Smiles_nostereo'] = protac_df['PROTAC SMILES_nostereo']
print('protac_pedia_df: {:,} x {:,}'.format(*protac_pedia_df.shape))

dataframes['protac-pedia'] = protac_pedia_df
dataframes['protac-pedia_DC50_Dmax'] = protac_pedia_df[~protac_pedia_df['active'].isna()]
dataframes['protac-pedia_ssl'] = protac_pedia_df[protac_pedia_df['active'].isna()]

print('protac-pedia_DC50_Dmax:')
display(dataframes['protac-pedia_DC50_Dmax'])
print('protac-pedia_ssl:')
display(dataframes['protac-pedia_ssl'])

## Assemble Train/Val/Test Sets

### Train/Test Split (TODO)

There are two approaches or strategies we might want to pursue:

* Given a PROTAC structure, we want to make the best $DC_{50}$ prediction
* Given a POI, we want the best PROTAC that targets it

In our case, the end goal is to design efficient PROTACs, so we follow the first paradigm.
Because of that, we now make sure that the same SMILES/PROTAC structure ends up in a single dataset, either in the training or test set.
<!-- Since we want to predict the degradation percentage PROTAC-wise, we split the train-test sets according to the PROTACs' SMILES representations. -->

The following is inspired by this [Stackoverflow question](https://stackoverflow.com/questions/54797508/how-to-generate-a-train-test-split-based-on-a-group-id).

TODO: Make sure that both PROTAC and POI are "separate".
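
A detail worth keeping in mind for the split below: `GroupShuffleSplit` interprets `test_size` as a fraction of *groups* (here, unique SMILES), not of rows, so the realized row share can deviate noticeably from the requested percentage when groups have different sizes. A minimal sketch on made-up data with ten groups of very different sizes:

```python
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit

# Hypothetical toy data: 10 groups ("SMILES"), 23 rows in total
df = pd.DataFrame({'Smiles_nostereo': ['A'] * 10 + ['B'] * 5 + list('CDEFGHIJ')})

splitter = GroupShuffleSplit(test_size=0.2, n_splits=1, random_state=42)
train_idx, test_idx = next(splitter.split(df, groups=df['Smiles_nostereo']))
train_groups = set(df.iloc[train_idx]['Smiles_nostereo'])
test_groups = set(df.iloc[test_idx]['Smiles_nostereo'])

print('groups in test:', sorted(test_groups))      # 20% of the 10 groups
print('overlap:', train_groups & test_groups)      # always empty
print(f'test rows: {len(test_idx)} / {len(df)}')  # row share depends on group sizes
```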

In [None]:
TRAIN_SPLIT_PERC = 0.9
TEST_SPLIT_PERC = 1.0 - TRAIN_SPLIT_PERC

def split_df(df, test_perc=TEST_SPLIT_PERC):
    tmp = df
    if isinstance(df, list):
        tmp = pd.concat(df, axis=0, ignore_index=True)
    splitter = GroupShuffleSplit(test_size=test_perc,
                                 n_splits=2,
                                 random_state=42)
    split = splitter.split(tmp, groups=tmp['Smiles_nostereo'])
    train_inds, val_inds = next(split)
    train_df, val_df = tmp.iloc[train_inds], tmp.iloc[val_inds]
    return train_df, val_df

Assemble the train/val/test sets for the binary classification task:

NOTE: PROTAC-Pedia is used for testing, so we remove all its entries which are also present in PROTAC-DB. We also do NOT include them in the SSL set either, so they are "lost forever", unfortunately.

In [None]:
# Assemble PROTAC-DB and PROTAC-Pedia for binary classification task
protac_db = pd.concat([dataframes['protac-db_DC50_Dmax'],
                 dataframes['protac-db_no_degradation']],
                axis=0, ignore_index=True)
protac_pedia = dataframes['protac-pedia']
# Split PROTAC-DB into train, val sets
# NOTE: We specify 15% split w/out interpolated values, so the final amount will
# be lower than 15% of the total dataset
train, val = split_df(protac_db, test_perc=0.15)
# Get test set from PROTAC-Pedia entries NOT in PROTAC-DB and with active labels
test = protac_pedia[~protac_pedia['Smiles_nostereo'].isin(protac_db['Smiles_nostereo'])]
test = test[~test['active'].isna()]

# Remove interpolated entries which are in validation set
interpolated = dataframes['protac-db_interpolated']
not_in_val = interpolated[~interpolated['Smiles_nostereo'].isin(val['Smiles_nostereo'])]
in_val = interpolated[interpolated['Smiles_nostereo'].isin(val['Smiles_nostereo'])]

# Add interpolated entries to training set if not in validation, else to
# the validation set
print(f'train len before adding interpolated: {len(train)}')
train = pd.concat([train, not_in_val], axis=0, ignore_index=True)
print(f'train len after adding interpolated: {len(train)}')

print(f'val len before adding interpolated: {len(val)}')
val = pd.concat([val, in_val], axis=0, ignore_index=True)
print(f'val len after adding interpolated: {len(val)}')

dataframes['train_bin'] = train
dataframes['val_bin'] = val
dataframes['test_bin'] = test
           
print(f'train len: {len(train)} ({len(train) / (len(train) + len(val)) * 100:.1f}%)')
print(f'val len: {len(val)} ({len(val) / (len(val) + len(train)) * 100:.1f}%)')
print(f'test len: {len(test)}')

Assemble the train/val/test sets for the regression task:

NOTE: PROTAC-Pedia is used for testing, so we remove all its entries which are also present in PROTAC-DB. We also do NOT include them in the SSL set either, so they are "lost forever", unfortunately.

In [None]:
# Assemble PROTAC-DB and PROTAC-Pedia for regression task
protac_db = pd.concat([dataframes['protac-db_DC50_Dmax'],
                       dataframes['protac-db_interpolated']
                       ], axis=0, ignore_index=True)
protac_pedia = dataframes['protac-pedia']
# Split PROTAC-DB into train, val sets
train, val = split_df(protac_db)
# Get test set from PROTAC-Pedia entries NOT in PROTAC-DB and with active labels
test = protac_pedia[~protac_pedia['Smiles_nostereo'].isin(protac_db['Smiles_nostereo'])]
test = test[~test['active'].isna()]

dataframes['train_regr'] = train
dataframes['val_regr'] = val
dataframes['test_regr'] = test
           
display(train)
display(val)
display(test)

Assemble the dataset for the SSL binary classification task:

In [None]:
ssl_df = pd.concat([dataframes['protac-db_ssl'],
                    dataframes['protac-db_DC50_Dmax'],
                    dataframes['protac-db_interpolated'],
                    dataframes['protac-db_no_degradation']],
                axis=0, ignore_index=True)
val_df = dataframes['val_bin']
test_df = dataframes['test_bin']
# Remove entries in val and test sets from ssl_df
print(f'SSL len before removal: {len(ssl_df)}')
ssl_df = ssl_df[~ssl_df['Smiles_nostereo'].isin(val_df['Smiles_nostereo'])]
ssl_df = ssl_df[~ssl_df['Smiles_nostereo'].isin(test_df['Smiles_nostereo'])]
print(f'SSL len after removal: {len(ssl_df)}')
# Store and display
dataframes['ssl_bin'] = ssl_df
display(ssl_df)

Assemble the dataset for the SSL regression task:

In [None]:
ssl_df = pd.concat([dataframes['protac-db_ssl'],
                    dataframes['protac-db_DC50_Dmax'],
                    dataframes['protac-db_interpolated'],
                    dataframes['protac-db_no_degradation']],
                axis=0, ignore_index=True)
val_df = dataframes['val_regr']
test_df = dataframes['test_regr']
# Remove entries in val and test sets from ssl_df
ssl_df = ssl_df[~ssl_df['Smiles_nostereo'].isin(val_df['Smiles_nostereo'])]
ssl_df = ssl_df[~ssl_df['Smiles_nostereo'].isin(test_df['Smiles_nostereo'])]
# Store and display
dataframes['ssl_regr'] = ssl_df
display(ssl_df)

Check for data leakage:

In [None]:
def check_data_leakage(train_df, val_df):
    train_smiles = train_df['Smiles_nostereo'].tolist()
    drop_indices = []
    for index, row in list(val_df.iterrows()):
        if row['Smiles_nostereo'] in train_smiles:
            drop_indices.append(index)
    drop_indices = list(set(drop_indices))
    if len(drop_indices) == 0:
        print('No data leakage detected.')
    else:
        print(f'Detected {len(drop_indices)} leaking entries.')

checks = [('train', 'val'), ('train', 'test'), ('ssl', 'val'), ('ssl', 'test')]
# NOTE: Use dedicated loop names to avoid overwriting the `train`/`test` dataframes
for train_name, eval_name in checks:
    for task in ['_bin', '_regr']:
        print(f'Checking leakage between {train_name + task} and {eval_name + task}...', end=' ')
        check_data_leakage(dataframes[train_name + task], dataframes[eval_name + task])
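
`check_data_leakage` loops over the evaluation rows with `iterrows` and tests membership in a Python list, which scales with the product of the two dataset sizes. A vectorized alternative, sketched below with hypothetical helper names and toy SMILES, counts the same leaking rows via `Series.isin` (assuming the evaluation dataframe has a unique index):

```python
import pandas as pd


def leaking_smiles(train_df, eval_df, col='Smiles_nostereo'):
    """Hypothetical helper: the set of `col` values shared by both dataframes."""
    return set(train_df[col]).intersection(eval_df[col])


def count_leaking_rows(train_df, eval_df, col='Smiles_nostereo'):
    """Hypothetical helper: number of eval_df rows whose `col` value is in train_df."""
    return int(eval_df[col].isin(train_df[col]).sum())


# Toy example with placeholder SMILES
train_toy = pd.DataFrame({'Smiles_nostereo': ['CCO', 'CCN', 'c1ccccc1']})
val_toy = pd.DataFrame({'Smiles_nostereo': ['CCO', 'CCO', 'CCCl']})
print(leaking_smiles(train_toy, val_toy))      # {'CCO'}
print(count_leaking_rows(train_toy, val_toy))  # 2
```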

### Removing Class Imbalance

If we aim at predicting the percentage degradation, the dataset in its current form is heavily imbalanced towards 50% degradation. Here, we try to compensate for this imbalance via class weighting and up-/down-sampling.

* [Motivations and intuition behind using class weights versus up- or down-sampling](https://datascience.stackexchange.com/questions/44755/why-doesnt-class-weight-resolve-the-imbalanced-classification-problem).
* [Getting the class weights](https://datascience.stackexchange.com/questions/13490/how-to-set-class-weights-for-imbalanced-classes-in-keras).
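
For reference, the "balanced" heuristic weights each class $c$ by $w_c = n_{\text{samples}} / (n_{\text{classes}} \cdot n_c)$, which is the formula scikit-learn's `class_weight.compute_class_weight('balanced', ...)` uses. A minimal NumPy sketch on hypothetical labels (75 active, 25 inactive): inactive entries get weight $100 / (2 \cdot 25) = 2.0$ and active ones $100 / (2 \cdot 75) \approx 0.67$, so both classes contribute equally to the weighted loss.

```python
import numpy as np


def balanced_class_weights(y):
    """w_c = n_samples / (n_classes * n_c), the 'balanced' heuristic."""
    classes, counts = np.unique(y, return_counts=True)
    weights = len(y) / (len(classes) * counts)
    return dict(zip(classes.tolist(), weights.tolist()))


# Hypothetical labels: 75 active (1), 25 inactive (0)
y = np.array([1] * 75 + [0] * 25)
w = balanced_class_weights(y)
print(w)
```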

In [None]:
def plot_active_inactive(val_bin_df, train_bin_df, test_bin_df, descr='countplot_active_entries_train_val_test'):
    # Work on copies, so that the caller's dataframes are left untouched
    val_bin_df = val_bin_df.copy()
    train_bin_df = train_bin_df.copy()
    test_bin_df = test_bin_df.copy()

    val_bin_df['Dataset'] = 'Validation'
    train_bin_df['Dataset'] = 'Train'
    test_bin_df['Dataset'] = 'Test'

    val_bin_df['active'] = val_bin_df['active'].astype(bool)
    train_bin_df['active'] = train_bin_df['active'].astype(bool)
    test_bin_df['active'] = test_bin_df['active'].astype(bool)

    tmp = pd.concat([train_bin_df, val_bin_df, test_bin_df], axis=0)
    tmp['Active/Inactive'] = tmp['active'].apply(lambda x: 'Active' if x else 'Inactive')

    ax = sns.countplot(data=tmp, hue='Active/Inactive', x='Dataset', hue_order=['Inactive', 'Active'])

    for bars_group in ax.containers:
        ax.bar_label(bars_group, padding=1) # fontsize=12

    plt.grid(axis='y', alpha=0.7)
    plt.title('Active and inactive entries in train/val/test datasets')
    f = os.path.join(fig_dir, descr)
    plt.savefig(f + '.pdf', bbox_inches='tight')
    plt.savefig(f + '.png', bbox_inches='tight')
    plt.show()
    plt.close()

val_bin_df = dataframes['val_bin'].copy()
train_bin_df = dataframes['train_bin'].copy()
test_bin_df = dataframes['test_bin'].copy()
plot_active_inactive(val_bin_df, train_bin_df, test_bin_df)
del val_bin_df
del train_bin_df
del test_bin_df

#### Upsampling/Downsampling

Based on this [blog post](https://towardsdatascience.com/heres-what-i-ve-learnt-about-sklearn-resample-ab735ae1abc4).

TODO: While resampling, introduce variations into the data, for example by giving each new sample a different SMILES representation (of the same molecule, of course).
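
The upsampling in the next cells draws, with replacement, as many extra minority-class rows as needed to match the majority class, then concatenates them with the original training set. A minimal sketch of the same idea using pandas `DataFrame.sample` instead of `sklearn.utils.resample`, on a hypothetical toy dataframe (the helper name `upsample_minority` is ours):

```python
import pandas as pd


def upsample_minority(df, label_col='active', random_state=42):
    """Duplicate random minority-class rows until both classes have equal size."""
    counts = df[label_col].value_counts()
    minority_label = counts.idxmin()
    n_extra = int(counts.max() - counts.min())
    extra = df[df[label_col] == minority_label].sample(
        n=n_extra, replace=True, random_state=random_state)
    return pd.concat([df, extra], axis=0, ignore_index=True)


toy = pd.DataFrame({'active': [True] * 6 + [False] * 2,
                    'Smiles_nostereo': [f'S{i}' for i in range(8)]})
balanced = upsample_minority(toy)
print(balanced['active'].value_counts())
```

In the notebook, the duplicated rows additionally get their SMILES rewritten by `scramble_smiles`, so that the copies are not byte-identical strings.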

In [None]:
def shuffle_smiles(smiles):
    rand_config = {
        'isomericSmiles': False, # random.choice([True, False]),
        'kekuleSmiles': random.choice([True, False]),
        # 'rootedAtAtom': (optional) if non-negative, this forces the SMILES to start at a particular atom. Defaults to -1.
        'canonical': random.choice([True, False]),
        'allBondsExplicit': random.choice([True, False]),
        'allHsExplicit': random.choice([True, False]),
    }
    mol = Chem.MolFromSmiles(smiles)
    return Chem.MolToSmiles(mol, **rand_config)

def scramble_smiles(smiles, plot_mol=False):
    # Convert SMILES string to RDKit molecule object
    mol = Chem.MolFromSmiles(smiles)
    new_mol = copy.deepcopy(mol)
    # Step 1: Kekulize, i.e., assign explicit single/double bonds to aromatic rings
    Chem.Kekulize(new_mol)
    # Step 2: shuffle the atom indices
    atom_indices = list(range(new_mol.GetNumAtoms()))
    random.shuffle(atom_indices)
    new_mol = Chem.RenumberAtoms(new_mol, atom_indices)
    # Step 3: Kekulize the renumbered molecule again
    Chem.Kekulize(new_mol)
    # Generate a non-canonical SMILES string, which follows the shuffled atom order
    new_smiles = Chem.MolToSmiles(new_mol, isomericSmiles=False, canonical=False)
    # Sanity check: both SMILES must map to the same canonical SMILES
    canon_smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles), canonical=True)
    canon_new_smiles = Chem.MolToSmiles(Chem.MolFromSmiles(new_smiles), canonical=True)
    if plot_mol:
        print('Original molecule:')
        display(mol)
        print('Transformed molecule:')
        display(new_mol)
    if canon_smiles != canon_new_smiles:
        # Scrambling changed the molecule: fall back to the original SMILES
        print(f'original/scrambled:\n{smiles}\n{new_smiles}')
        print('-' * 80)
        return smiles
    return new_smiles

examples = dataframes['train_bin'].at[0, 'Smiles_nostereo']
print(f'Original SMILES: {examples}')
print(f'Transformed SMILES: {scramble_smiles(examples, plot_mol=True)}')

In [None]:
train_df = dataframes['train_bin']
active_df = train_df[train_df['active'] == True]
inactive_df = train_df[train_df['active'] == False]
# Set majority and minority classes datasets
if len(active_df) > len(inactive_df):
    majority_df = active_df
    minority_df = inactive_df
else:
    majority_df = inactive_df
    minority_df = active_df
# Upsample the minority class
n_samples = abs(len(active_df) - len(inactive_df))
minority_upsampled_df = resample(minority_df, random_state=42,
                                 n_samples=n_samples, replace=True)
# Transform SMILES strings of the upsampled class
minority_upsampled_df['Smiles_nostereo'] = minority_upsampled_df['Smiles_nostereo'].apply(scramble_smiles)
# Concatenate the upsampled dataframe
train_upsampled_bin_df = pd.concat([minority_upsampled_df, train_df], axis=0)
dataframes['train_upsampled_bin'] = train_upsampled_bin_df

val_bin_df = dataframes['val_bin'].copy()
test_bin_df = dataframes['test_bin'].copy()
plot_active_inactive(val_bin_df, train_upsampled_bin_df, test_bin_df)
del val_bin_df
del test_bin_df

The rest of this section is legacy code.

Assemble the train/val/test sets and the SSL set in the old (deprecated) way:

In [None]:
# protac_db_df['poi_gene_id'] = input_df['poi_gene_id']
splitter = GroupShuffleSplit(test_size=TEST_SPLIT_PERC, n_splits=2, random_state=42)
# Split "entire" dataset
split = splitter.split(protac_db_df, groups=protac_db_df['Smiles'])
train_inds, val_inds = next(split)
train_df, val_df = protac_db_df.iloc[train_inds], protac_db_df.iloc[val_inds]
# Split datasets for binary classification
bin_df = protac_db_df.dropna(subset=['active'])
split = splitter.split(bin_df, groups=bin_df['Smiles'])
train_inds, val_inds = next(split)
train_bin_df, val_bin_df = bin_df.iloc[train_inds], bin_df.iloc[val_inds]
# Get test dataset from PROTAC-Pedia entries which are NOT in PROTAC-DB
tmp = protac_pedia_df.dropna(subset=['active'])
test_df = tmp[~tmp['Smiles_nostereo'].isin(protac_db_df['Smiles_nostereo'])]
test_bin_df = test_df
# Reporting
print(f'Len(PROTAC-Pedia): {len(protac_pedia_df)}')
print(f'Len(PROTAC-Pedia) with active/inactive: {len(tmp)}')
print(f'Train data len.: {len(train_df)}')
print(f'Val data len.: {len(val_df)}')
print(f'Test data len.: {len(test_df)}')
print(f'Train data len.: {len(train_bin_df)} (binary classification)')
print(f'Val data len.: {len(val_bin_df)} (binary classification)')
print(f'Test data len.: {len(test_bin_df)} (binary classification)')

not_in_val = ~protac_pedia_df['Smiles_nostereo'].isin(val_bin_df['Smiles_nostereo'])
not_in_test = ~protac_pedia_df['Smiles_nostereo'].isin(protac_db_df['Smiles_nostereo'])
tmp = protac_pedia_df[not_in_val & not_in_test].copy()
tmp['active'] = tmp['Active/Inactive']
print(f'Len(PROTAC-Pedia) with active/inactive: {len(tmp)}')
splitter = GroupShuffleSplit(test_size=0.1, n_splits=2, random_state=42)
split = splitter.split(tmp, groups=tmp['Smiles'])
train_inds, val_inds = next(split)
train_bin_protac_pedia_df = tmp.iloc[train_inds]
val_bin_protac_pedia_df = tmp.iloc[val_inds]
print(f'Train data len.: {len(train_bin_protac_pedia_df)} (binary classification PROTAC-Pedia)')
print(f'Val data len.: {len(val_bin_protac_pedia_df)} (binary classification PROTAC-Pedia)')

Checking for data leakage.

In [None]:
print(f'[Before data leaking check] Train data len.: {len(train_df)}')
print(f'[Before data leaking check] Val data len.: {len(val_df)}')

train_smiles = set(train_df['Smiles'])
# train_genes = set(train_df['poi_gene_id'])
# Validation and test entries whose SMILES also appear in the training set
# NOTE: val and test are checked separately, since their indices may overlap
val_leak = val_df['Smiles'].isin(train_smiles)
test_leak = test_df['Smiles'].isin(train_smiles)
n_leaking = int(val_leak.sum() + test_leak.sum())
if n_leaking == 0:
    print('No data leakage detected.')
else:
    print(f'Detected {n_leaking} leaking entries.')
    # Move the leaking validation entries to the training set
    train_df = pd.concat([train_df, val_df[val_leak]], axis=0)
    val_df = val_df[~val_leak]
    print(f'[After data leaking check] Train data len.: {len(train_df)}')
    print(f'[After data leaking check] Val data len.: {len(val_df)}')

In [None]:
print('Binary Classification')
print(f'[Before data leaking check] Train data len.: {len(train_bin_df)}')
print(f'[Before data leaking check] Val data len.: {len(val_bin_df)}')

train_smiles = set(train_bin_df['Smiles'])  # set for fast membership tests
# train_genes = set(train_bin_df['poi_gene_id'])
# Check the binary val/test splits against the binary train split
val_leaks = val_bin_df.index[val_bin_df['Smiles'].isin(train_smiles)].unique().tolist()
test_leaks = test_bin_df.index[test_bin_df['Smiles'].isin(train_smiles)].unique().tolist()
# val_leaks += val_bin_df.index[val_bin_df['poi_gene_id'].isin(train_genes)].tolist()
if len(val_leaks) + len(test_leaks) == 0:
    print('No data leakage detected.')
else:
    print(f'Detected {len(val_leaks)} leaking val entries and {len(test_leaks)} leaking test entries.')
#     train_bin_df = pd.concat([train_bin_df, val_bin_df.loc[val_leaks]], axis=0)
#     val_bin_df = val_bin_df.drop(val_leaks)
#     print(f'[After data leaking check] Train data len.: {len(train_bin_df)}')
#     print(f'[After data leaking check] Val data len.: {len(val_bin_df)}')

In [None]:
# Set the majority class (active entries) to a separate dataframe
active_df = train_bin_df[train_bin_df['active']]
# Set the minority class (inactive entries) to a separate dataframe
inactive_df = train_bin_df[train_bin_df['active'] == False]
# Upsample the minority class (sampling with replacement) to the size of the majority class
n_samples = len(active_df) # - len(inactive_df)
inactive_df_upsampled = resample(inactive_df, random_state=42,
                                 n_samples=n_samples, replace=True)
# Scramble the SMILES of the upsampled entries (see scramble_smiles) to diversify the duplicates
inactive_df_upsampled['Smiles_nostereo'] = inactive_df_upsampled['Smiles_nostereo'].apply(scramble_smiles)
# Concatenate the upsampled minority class with the majority class
train_upsampled_bin_df = pd.concat([inactive_df_upsampled, active_df])
print(f'inactive_df_upsampled len: {len(inactive_df_upsampled)}')
print(f'active_df len: {len(active_df)}')
print(f'train_upsampled len: {len(train_upsampled_bin_df)}')

In [None]:
# sns.histplot(data=val_bin_df['active'].astype(float))
val_bin_df['Dataset'] = 'Validation'
train_bin_df['Dataset'] = 'Train'
test_bin_df['Dataset'] = 'Test'

tmp = pd.concat([train_bin_df, val_bin_df, test_bin_df], axis=0)
tmp['Active/Inactive'] = tmp['active'].apply(lambda x: 'Active' if x else 'Inactive')

top_n = tmp['Active/Inactive'].value_counts().index
ax = sns.countplot(data=tmp, hue='Active/Inactive', x='Dataset')

# plt.xticks(rotation=90)
for bars_group in ax.containers:
    ax.bar_label(bars_group, padding=1) # fontsize=12

plt.grid(axis='y', alpha=0.7)
plt.title('Active and inactive entries in train/val/test datasets')
f = os.path.join(fig_dir, 'countplot_active_entries_train_val_test')
plt.savefig(f + '.pdf', bbox_inches='tight')
plt.savefig(f + '.png', bbox_inches='tight')
plt.show()
plt.close()

del val_bin_df['Dataset']
del train_bin_df['Dataset']
del test_bin_df['Dataset']

In [None]:
# sns.histplot(data=val_bin_df['active'].astype(float))
val_bin_protac_pedia_df['Type'] = 'Val'
train_bin_protac_pedia_df['Type'] = 'Train'
test_bin_df['Type'] = 'Test'

tmp = pd.concat([train_bin_protac_pedia_df, val_bin_protac_pedia_df, test_bin_df], axis=0)
tmp['active'] = tmp['active'].astype(int).copy()
sns.histplot(data=tmp, x='active', hue='Type', multiple='dodge')
# plt.xticks(rotation=90)
plt.grid(axis='y', alpha=0.9)
plt.title('Active entries in datasets')
# plt.show()
plt.close()

del val_bin_protac_pedia_df['Type']
del train_bin_protac_pedia_df['Type']
del test_bin_df['Type']

In [None]:
print('-' * 80)
print('PROTAC-DB statistics:')
print('-' * 80)
for df, name in zip([train_bin_df, val_bin_df, test_bin_df], ['Train', 'Val', 'Test']):
    n_active = len(df[df['active'] == True])
    n_inactive = len(df[df['active'] == False])
    n_active_perc = n_active / len(df) * 100
    n_inactive_perc = n_inactive / len(df) * 100
    print(f'{name} dataset num. active entries:\t{n_active:4d} ({n_active_perc:2.1f}%)')
    print(f'{name} dataset num. inactive entries:\t{n_inactive:4d} ({n_inactive_perc:2.1f}%)')
print('-' * 80)
print('PROTAC-Pedia statistics:')
print('-' * 80)
for df, name in zip([train_bin_protac_pedia_df, val_bin_protac_pedia_df], ['Train', 'Val']):
    n_active = len(df[df['active'] == True])
    n_inactive = len(df[df['active'] == False])
    n_active_perc = n_active / len(df) * 100
    n_inactive_perc = n_inactive / len(df) * 100
    print(f'{name} dataset num. active entries:\t{n_active:4d} ({n_active_perc:2.1f}%)')
    print(f'{name} dataset num. inactive entries:\t{n_inactive:4d} ({n_inactive_perc:2.1f}%)')
print('-' * 80)

In [None]:
sns.histplot(data=train_bin_df['active'].astype(float))
# plt.xticks(rotation=90)
plt.grid(axis='y', alpha=0.9)
plt.title('Active entries in train dataset (unbalanced)')
# plt.savefig(os.path.join(fig_dir, 'active_entries_hist.pdf'), bbox_inches='tight')
# plt.show()
plt.close()

In [None]:
sns.histplot(data=train_upsampled_bin_df['active'].astype(float))
# plt.xticks(rotation=90)
plt.grid(axis='y', alpha=0.9)
plt.title('Active entries in train dataset (balanced, i.e., upsampled)')
# plt.savefig(os.path.join(fig_dir, 'active_entries_hist.pdf'), bbox_inches='tight')
# plt.show()
plt.close()

#### Class Weights (TODO)

In [None]:
# from sklearn.utils import class_weight

# class_weights = class_weight.compute_class_weight(class_weight='balanced',
#                                                   classes=np.unique(train_bin_df['active']),
#                                                   y=train_bin_df['active'])
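A minimal sketch of the class-weight TODO, assuming the target is the binary `active` label (rather than `Dmax`) and scikit-learn's `compute_class_weight` with the `'balanced'` heuristic. The labels below are toy data; whether weighting works better than the upsampling above is left open.

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Toy binary labels (made up); in the notebook these would be
# train_bin_df['active'].astype(int), with 1 = active and 0 = inactive
y = np.array([1, 1, 1, 0])

classes = np.unique(y)  # array([0, 1])
weights = compute_class_weight(class_weight='balanced', classes=classes, y=y)
# 'balanced' weights are n_samples / (n_classes * count(class)):
# class 0 -> 4 / (2 * 1) = 2.0, class 1 -> 4 / (2 * 3) = 0.667
class_weights = dict(zip(classes.tolist(), weights.tolist()))
print(class_weights)

# With a single-logit binary loss (e.g. torch.nn.BCEWithLogitsLoss), a common
# alternative is pos_weight = n_negative / n_positive, which scales the loss
# of the positive (active) entries
pos_weight = (y == 0).sum() / (y == 1).sum()
print(pos_weight)
```

Since active entries are the majority class in this dataset, both approaches end up down-weighting them relative to the inactive ones.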

## POI Sequence Encoding

#### POI Sequence to $N_{grams}$

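Count-vectorize the POI amino acid sequence: each sequence is represented by the counts of its character $n$-grams, i.e., of its runs of $n$ consecutive amino acids.

(Not ideal and very simple, since the position of each $n$-gram in the sequence is lost, but it's a start.)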

In [None]:
import joblib

ngram_min_range = 2 # Original: 3
ngram_max_range = 2 # Original: 3
poi_vectorizer = CountVectorizer(analyzer='char', ngram_range=(ngram_min_range, ngram_max_range))
X = poi_vectorizer.fit_transform(protac_db_df['poi_seq'].tolist())
rec_n_grams_df = pd.DataFrame(X.toarray(), columns=list(s.replace(' ', '') for s in poi_vectorizer.get_feature_names_out()))
print(f'POI embedding size: {rec_n_grams_df.shape[-1]}')

# Load the pre-trained CountVectorizer if it exists, otherwise train it
poi_encoder_filepath = os.path.join(checkpoint_dir, 'poi_encoder.joblib')
if os.path.exists(poi_encoder_filepath):
    print('Loading pre-trained POI vectorizer...')
    poi_encoder = joblib.load(poi_encoder_filepath)
else:
    print('Training POI vectorizer...')
    poi_encoder = CountVectorizer(analyzer='char', ngram_range=(ngram_min_range, ngram_max_range))
    X = poi_encoder.fit_transform(dataframes['ssl_bin']['poi_seq'].tolist())
    rec_n_grams_df = pd.DataFrame(X.toarray(), columns=list(s.replace(' ', '') for s in poi_encoder.get_feature_names_out()))
    print(f'POI embedding size: {rec_n_grams_df.shape[-1]}')
    joblib.dump(poi_encoder, poi_encoder_filepath)

#### POI Gene Ordinal Encoding

Unseen or missing POI genes are mapped to an extra "unknown" class, encoded as -1.

Since genes ultimately encode proteins, we can use the gene ID as a categorical feature to include information about the POIs.

(The information loss is considerable, since the gene ID is much less informative than the entire amino acid sequence.)

In [None]:
poi_gene_enc = preprocessing.OrdinalEncoder(handle_unknown='use_encoded_value',
                                            unknown_value=-1,
                                            encoded_missing_value=-1)
poi_gene_id = protac_db_df['poi_gene_id'].to_numpy().reshape(-1, 1)
poi_gene_enc.fit(poi_gene_id)

## E3 Ligase and Cell Type Ordinal Encoding

Note that the "other E3" entries were dropped in the previous steps, leaving only 5 possible E3 ligases.

In [None]:
e3_ligase_enc = preprocessing.OrdinalEncoder(handle_unknown='use_encoded_value',
                                             unknown_value=-1,
                                             encoded_missing_value=-1)
e3_ligase = protac_db_df['e3_ligase'].to_numpy().reshape(-1, 1)
e3_ligase_enc.fit(e3_ligase)

# Load the pre-trained ordinal encoder if it exists, otherwise train it
e3_encoder_filepath = os.path.join(checkpoint_dir, 'e3_encoder.joblib')
if os.path.exists(e3_encoder_filepath):
    print('Loading pre-trained E3 encoder...')
    e3_encoder = joblib.load(e3_encoder_filepath)
else:
    print('Training E3 encoder...')
    e3_encoder = preprocessing.OrdinalEncoder(handle_unknown='use_encoded_value',
                                              unknown_value=-1,
                                              encoded_missing_value=-1)
    e3_ligase = dataframes['ssl_bin']['e3_ligase'].to_numpy().reshape(-1, 1)
    e3_encoder.fit(e3_ligase)
    joblib.dump(e3_encoder, e3_encoder_filepath)
print('Done!')

In [None]:
cell_type_enc = preprocessing.OrdinalEncoder(handle_unknown='use_encoded_value',
                                             unknown_value=-1,
                                             encoded_missing_value=-1)
cell_type = protac_db_df['cell_type'].to_numpy().reshape(-1, 1)
cell_type_enc.fit(cell_type)

# Load the pre-trained ordinal encoder if it exists, otherwise train it
cell_encoder_filepath = os.path.join(checkpoint_dir, 'cell_encoder.joblib')
if os.path.exists(cell_encoder_filepath):
    print('Loading pre-trained cell type encoder...')
    cell_encoder = joblib.load(cell_encoder_filepath)
else:
    print('Training cell type encoder...')
    cell_encoder = preprocessing.OrdinalEncoder(handle_unknown='use_encoded_value',
                                                unknown_value=-1,
                                                encoded_missing_value=-1)
    ssl_cell_type = dataframes['ssl_bin']['cell_type'].to_numpy().reshape(-1, 1)
    cell_encoder.fit(ssl_cell_type)
    joblib.dump(cell_encoder, cell_encoder_filepath)
print('Done!')

## Molecular Fingerprints

SMILES $\rightarrow$ molecule is unique.

molecule $\not\to$ SMILES is _not_ unique.

By construction, a SMILES encodes a unique molecule. However, one molecule can be encoded by multiple SMILES representations.

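On the other hand, Morgan fingerprints are computed from the molecular graph, so all the SMILES of the same molecule map to the same fingerprint. The converse does not hold: atom environments are hashed into a fixed number of bits, so two different molecules can share a fingerprint (bit collisions). Using more bits makes collisions less likely, but does not rule them out.

Fingerprints are nonetheless very informative descriptors of molecules. The Tanimoto similarity between two fingerprints, i.e., the number of bits set in both divided by the number of bits set in at least one, gives an approximate measure of the similarity between two molecules.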
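For two binary fingerprints with sets of on-bits $A$ and $B$, the Tanimoto similarity is $T(A, B) = |A \cap B| / |A \cup B|$. Below is a small NumPy sketch for fingerprints stored as 0/1 arrays, like those returned by `get_fingerprint` in this notebook; RDKit offers `DataStructs.TanimotoSimilarity` for its own bit-vector objects. Returning 0 for two all-zero fingerprints is a convention assumed here.

```python
import numpy as np


def tanimoto(fp_a: np.ndarray, fp_b: np.ndarray) -> float:
    """Tanimoto similarity between two binary fingerprints given as 0/1 arrays."""
    a = fp_a.astype(bool)
    b = fp_b.astype(bool)
    union = np.logical_or(a, b).sum()
    if union == 0:
        return 0.0  # both fingerprints empty: convention assumed in this sketch
    return float(np.logical_and(a, b).sum() / union)


# Two made-up 8-bit fingerprints
fp_1 = np.array([1, 1, 0, 1, 0, 0, 1, 0], dtype=np.int8)
fp_2 = np.array([1, 0, 0, 1, 1, 0, 1, 0], dtype=np.int8)
# On-bits: {0, 1, 3, 6} and {0, 3, 4, 6} -> 3 shared / 5 in either = 0.6
print(tanimoto(fp_1, fp_2))
```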

Refer to this [tutorial](https://chem.libretexts.org/Courses/Intercollegiate_Courses/Cheminformatics/06%3A_Molecular_Similarity/6.04%3A_Python_Assignment).

[Pytorch, convert to binary tensor](https://stackoverflow.com/questions/55918468/convert-integer-to-pytorch-tensor-of-binary-bits).

In [None]:
def get_fingerprint(smiles: str, n_bits: int = 1024, fp_type: Literal['morgan', 'maccs', 'path'] = 'morgan',
                    min_path: int = 1, max_path: int = 2,
                    atomic_radius: int = 2) -> np.ndarray:
    """Returns the molecular fingerprint of a given molecule SMILES.

    Args:
        smiles (str): SMILES string to convert.
        n_bits (int, optional): Number of bits of the generated fingerprint. Ignored for MACCS keys, which always have 167 bits. Defaults to 1024.
        fp_type (Literal['morgan', 'maccs', 'path'], optional): Fingerprint type to generate. Defaults to 'morgan'.
        min_path (int, optional): Minimum path length for path-based fingerprints. Defaults to 1.
        max_path (int, optional): Maximum path length for path-based fingerprints. Defaults to 2.
        atomic_radius (int, optional): Atomic radius for Morgan fingerprints. Defaults to 2.

    Raises:
        ValueError: When the SMILES cannot be parsed or a wrong fingerprint type is requested.

    Returns:
        np.ndarray: The generated fingerprint.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:  # RDKit returns None for SMILES it cannot parse
        raise ValueError(f'Could not parse SMILES: "{smiles}"')
    if fp_type == 'morgan':
        fingerprint = AllChem.GetMorganFingerprintAsBitVect(mol, atomic_radius,
                                                            nBits=n_bits)
    elif fp_type == 'maccs':
        fingerprint = MACCSkeys.GenMACCSKeys(mol)
    elif fp_type == 'path':
        fingerprint = Chem.rdmolops.RDKFingerprint(mol, fpSize=n_bits,
                                                   minPath=min_path,
                                                   maxPath=max_path)
    else:
        raise ValueError(f'Wrong type of fingerprint requested. Received "{fp_type}", expected one in: [morgan|maccs|path]')
    array = np.zeros((0,), dtype=np.int8)
    DataStructs.ConvertToNumpyArray(fingerprint, array)
    return array

fp2str = lambda fp: ''.join([str(x) for x in fp])

#### Compress and Store Fingerprints

Code inspired by this [StackOverflow question](https://stackoverflow.com/questions/71621513/how-do-i-compress-a-rather-long-binary-string-in-python-so-that-i-will-be-able-t).

In [None]:
import zlib

# Numpy array -> Binary string
num2str = lambda fp: ''.join([str(x) for x in fp])
# Binary string -> Hexadecimal (compressed) string
str2hex = lambda fp: zlib.compress(fp.encode()).hex()
# Hexadecimal (compressed) string -> Binary string
hex2str = lambda fp: zlib.decompress(bytearray.fromhex(fp)).decode()
# Binary string -> Numpy array
str2num = lambda fp: np.array([x for x in fp], dtype=np.int8)

fp = get_fingerprint(protac_db_df.iloc[17]['Smiles'], n_bits=1024, fp_type='maccs')
fp_dec = str2num(hex2str(str2hex(num2str(fp))))
np.allclose(fp, fp_dec)
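The string-based scheme above stores one character per bit before compression. A possible alternative, sketched below with NumPy's `np.packbits`/`np.unpackbits` (a swapped-in technique, not the one used in this notebook), packs 8 bits per byte first. Its drawback is that the original number of bits must be known when decoding, because `np.packbits` pads the last byte with zeros.

```python
import zlib

import numpy as np


def fp_to_hex(fp: np.ndarray) -> str:
    """0/1 fingerprint -> packed bytes (8 bits per byte) -> zlib -> hex string."""
    packed = np.packbits(fp.astype(np.uint8))
    return zlib.compress(packed.tobytes()).hex()


def hex_to_fp(hex_str: str, n_bits: int) -> np.ndarray:
    """Inverse of fp_to_hex; n_bits drops the zero padding added by packbits."""
    packed = np.frombuffer(zlib.decompress(bytes.fromhex(hex_str)), dtype=np.uint8)
    return np.unpackbits(packed, count=n_bits).astype(np.int8)


# Round trip on a random 167-bit vector (the length of MACCS keys)
rng = np.random.default_rng(0)
fp = rng.integers(0, 2, size=167).astype(np.int8)
fp_dec = hex_to_fp(fp_to_hex(fp), n_bits=167)
print(np.array_equal(fp, fp_dec))
```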

#### Morgan Fingerprints

The notebook now adds a column containing the Morgan fingerprint of each entry, with $n$ bits, stored as a compressed hexadecimal string as described above. In the cell below, $n = 32$.

**We obtain the fingerprint from the "removed stereochemistry" SMILES, in order to further avoid overlaps.**

In [None]:
n_bits = 32
compress = lambda s: str2hex(num2str(get_fingerprint(s, n_bits=n_bits)))
protac_db_df[f'morgan_{n_bits}bits'] = protac_db_df['Smiles_nostereo'].apply(compress)

#### MACCS Fingerprints

Inspired by this [tutorial](https://projects.volkamerlab.org/teachopencadd/talktorials/T004_compound_similarity.html).

In [None]:
smiles_example = protac_db_df.iloc[0]['Smiles']

MACCS_BITWIDTH = len(fp2str(get_fingerprint(protac_db_df.iloc[31]['Smiles'], n_bits=128, fp_type='maccs')))
MACCS_BITWIDTH

In [None]:
n_bits = 167
compress = lambda s: str2hex(num2str(get_fingerprint(s, n_bits=n_bits, fp_type='maccs')))
# input_df[f'maccs_{n_bits}bits'] = input_df['Smiles_nostereo'].apply(compress)

#### Path-Based Fingerprints

In [None]:
n_bits = 32
compress = lambda s: str2hex(num2str(get_fingerprint(s, n_bits=n_bits, fp_type='path')))
# input_df[f'path_{n_bits}bits'] = input_df['Smiles_nostereo'].apply(compress)
# input_df[f'path_{n_bits}bits']

## PyTorch Dataset and DataLoader

In [None]:
# TODO: Cast everything to np.float16??? Maybe not, CPUs might not support it...

class ProtacDataset(Dataset):

    def __init__(self,
                 dataframe,
                 task: Literal['predict_active_inactive', 'predict_pDC50_and_Dmax'] = 'predict_active_inactive',
                 scale_concentration: bool = False,
                 scale_degradation: bool = False,
                 include_smiles_as_str: bool = False,
                 include_smiles_as_graphs: bool = False,
                 smiles_tokenizer: Any = None,
                 include_poi_seq: bool = True,
                 poi_vectorizer: Any = None,
                 ngram_range: Tuple[int, int] = (2, 2),      
                 tokenize_poi_seq: bool = False,
                 poi_tokenizer: Any = None,
                 include_poi_gene: bool = False,
                 precompute_smiles_as_graphs: bool = False,
                 precompute_fingerprints: bool = False,
                 precompute_poi_seq: bool = False,
                 use_for_ssl: bool = False,
                 return_tensors: str | None = None,
                 use_morgan_fp: bool = False,
                 morgan_bits: int = 1024,
                 morgan_atomic_radius: int = 2,
                 use_maccs_fp: bool = False,
                 use_path_fp: bool = False,
                 path_bits: int = 1024,
                 fp_min_path: int = 1,
                 fp_max_path: int = 16,
                 poi_gene_enc: sklearn.preprocessing.OrdinalEncoder | sklearn.preprocessing.OneHotEncoder | None = None,
                 e3_ligase_enc: sklearn.preprocessing.OrdinalEncoder | sklearn.preprocessing.OneHotEncoder | None = None,
                 cell_type_enc: sklearn.preprocessing.OrdinalEncoder | sklearn.preprocessing.OneHotEncoder | None = None,
                 normalize_extra_features: bool = False):
        """Pytorch Dataset for PROTAC data. Each element will consist of a dictionary of different processed features.
        When processed by a DataLoader, the dictionary structure will remain, but each value will be converted to a batch of tensors.

        Parameters
        ----------
        dataframe : pandas.DataFrame
            Pandas dataframe containing the data.
        task : Literal['predict_active_inactive', 'predict_pDC50_and_Dmax'], default='predict_active_inactive'
            The task for which to load the dataset.
        scale_concentration : bool, default=False
            Whether to scale the pDC50 values by a factor of 0.1.
        scale_degradation : bool, default=False
            Whether to scale the Dmax values. NOTE: no scaling is currently applied (see TODO below).
        include_smiles_as_str : bool, default=False
            Whether to include the SMILES as a string.
        include_smiles_as_graphs : bool, default=False
            Whether to include the SMILES as graphs.
        smiles_tokenizer : Any, default=None
            The SMILES tokenizer to use. If None, the SMILES are not tokenized.
        poi_vectorizer : Any, default=None
            The POI vectorizer to use. If None, a character n-gram CountVectorizer with `ngram_range` is fitted on the POI sequences of `dataframe`.
        poi_tokenizer : Any, default=None
            The POI tokenizer to use. If None and `tokenize_poi_seq` is True, the POI sequences are kept as raw strings.
        include_poi_seq : bool, default=True
            Whether to include the POI sequence.
        precompute_smiles_as_graphs : bool, default=False
            Whether to precompute the SMILES graphs.
        precompute_fingerprints : bool, default=False
            Whether to precompute the fingerprints.
        use_for_ssl : bool, default=False
            Whether to use the dataset for self-supervised learning.
        use_morgan_fp : bool, default=False
            Whether to use Morgan fingerprints.
        morgan_bits : int, default=1024
            The number of bits to use for the Morgan fingerprints.
        use_maccs_fp : bool, default=False
            Whether to use MACCS fingerprints.
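        use_path_fp : bool, default=False
            Whether to use path fingerprints.
        path_bits : int, default=1024
            The number of bits to use for the path fingerprints.
        fp_min_path, fp_max_path : int, default=1 and 16
            The minimum and maximum path lengths of the path fingerprints.
        morgan_atomic_radius : int, default=2
            The atomic radius of the Morgan fingerprints.
        poi_gene_enc, e3_ligase_enc, cell_type_enc : OrdinalEncoder | OneHotEncoder | None, default=None
            Pre-fitted encoders for the POI gene, E3 ligase, and cell type. If None, an OrdinalEncoder is fitted on `dataframe`.
        normalize_extra_features : bool, default=False
            Whether to divide the encoded POI gene, E3 ligase, and cell type by their number of categories. Only applied when pre-fitted encoders are passed.
        """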
        self.__dict__.update(locals()) # Add arguments as attributes
        self.hparams = {k: v for k, v in locals().items() if k != 'dataframe' and k != 'self'} # Store hyperparameters
        self.maccs_bits = 167 # Hardcoded, see RDKit documentation
        self.dataset_len = len(self.dataframe)
        if task == 'predict_pDC50' or task == 'predict_active_inactive':
            if not use_for_ssl:
                self.dataframe = dataframe.dropna(subset=['active'])
        self.smiles = self.dataframe['Smiles_nostereo']
        # if include_selfies:
        #     self.selfies = [sf.encoder(s) for s in self.smiles]
        
        # TODO: The SSL dataframe has the same columns, so there is no point in
        # having a separate if-else clause for it. Just use the same code for both.
        if not use_for_ssl:
            # Get POI gene as integer classes and normalize them
            if poi_gene_enc is not None:
                self.gene = self.dataframe['poi_gene_id'].to_numpy().reshape(-1, 1)
                self.gene = poi_gene_enc.transform(self.gene)
                self.gene = self.gene.flatten().astype(np.float32)
                if normalize_extra_features:
                    # categories_ holds one array per feature: use the size of the single feature
                    self.gene /= len(poi_gene_enc.categories_[0])
            else:
                self.poi_gene_enc = preprocessing.OrdinalEncoder(
                    handle_unknown='use_encoded_value',
                    unknown_value=-1,
                    encoded_missing_value=-1
                )
                tmp = self.dataframe['poi_gene_id'].to_numpy().reshape(-1, 1)
                tmp = self.poi_gene_enc.fit_transform(tmp)
                self.gene = tmp.astype(np.float32).flatten()
            # Get E3 ligase as integer classes and normalize them
            if e3_ligase_enc is not None:
                self.e3_ligase = self.dataframe['e3_ligase'].to_numpy().reshape(-1, 1)
                self.e3_ligase = e3_ligase_enc.transform(self.e3_ligase)
                self.e3_ligase = self.e3_ligase.flatten().astype(np.float32)
                if normalize_extra_features:
                    # categories_ holds one array per feature: use the size of the single feature
                    self.e3_ligase /= len(e3_ligase_enc.categories_[0])
            else:
                self.e3_ligase_enc = preprocessing.OrdinalEncoder(
                    handle_unknown='use_encoded_value',
                    unknown_value=-1,
                    encoded_missing_value=-1
                )
                tmp = self.dataframe['e3_ligase'].to_numpy().reshape(-1, 1)
                tmp = self.e3_ligase_enc.fit_transform(tmp)
                self.e3_ligase = tmp.astype(np.float32).flatten()
            # Get cell type as integer classes and normalize them
            if cell_type_enc is not None:
                self.cell_type = self.dataframe['cell_type'].to_numpy().reshape(-1, 1)
                self.cell_type = cell_type_enc.transform(self.cell_type)
                self.cell_type = self.cell_type.flatten().astype(np.float32)
                if normalize_extra_features:
                    # categories_ holds one array per feature: use the size of the single feature
                    self.cell_type /= len(cell_type_enc.categories_[0])
            else:
                self.cell_type_enc = preprocessing.OrdinalEncoder(
                    handle_unknown='use_encoded_value',
                    unknown_value=-1,
                    encoded_missing_value=-1
                )
                tmp = self.dataframe['cell_type'].to_numpy().reshape(-1, 1)
                tmp = self.cell_type_enc.fit_transform(tmp)
                self.cell_type = tmp.astype(np.float32).flatten()
            # Get the POI sequence
            if include_poi_seq:
                self.poi_seq = self.dataframe['poi_seq'].to_list()
                if poi_vectorizer is None:
                    self.poi_vectorizer = CountVectorizer(analyzer='char',
                                                          ngram_range=ngram_range)
                    self.poi_vectorizer.fit(self.poi_seq)
                if precompute_poi_seq:
                    self.poi_seq = self.poi_vectorizer.transform(self.poi_seq)
                    self.poi_seq = self.poi_seq.toarray().astype(np.float32)
            # Get the concentration and degradation values
            self.active = self.dataframe['active'].astype(np.float32)
            # TODO: Scaling the concentrations and degradations???
            if scale_concentration:
                self.pDC50 = (self.dataframe['pDC50'] * 0.1).astype(np.float32)
            else:
                self.pDC50 = self.dataframe['pDC50'].astype(np.float32)
            if scale_degradation:
                self.Dmax = (self.dataframe['Dmax']).astype(np.float32)
            else:
                self.Dmax = self.dataframe['Dmax'].astype(np.float32)
            # Tokenize the POI sequence (for example for BERT-based models)
            if tokenize_poi_seq:
                if poi_tokenizer is None:
                    self.poi_seq = self.dataframe['poi_seq']
                else:
                    self.poi_seq = [self.poi_tokenizer(seq, padding='max_length', truncation=True, return_tensors='pt') for seq in self.dataframe['poi_seq']]
                
        if precompute_fingerprints:
            if self.use_morgan_fp:
                self.morgan_fp = np.array([get_fingerprint(s, n_bits=self.morgan_bits, fp_type='morgan', atomic_radius=morgan_atomic_radius).astype(np.float32) for s in self.smiles])
            if self.use_maccs_fp:
                self.maccs_fp = np.array([get_fingerprint(s, fp_type='maccs').astype(np.float32) for s in self.smiles])
            if self.use_path_fp:
                self.path_fp = np.array([get_fingerprint(s, n_bits=self.path_bits, fp_type='path', min_path=self.fp_min_path, max_path=self.fp_max_path).astype(np.float32) for s in self.smiles])
        if include_smiles_as_graphs or precompute_smiles_as_graphs:
            # NOTE: self.graph_smiles is a list of PytorchGeometric Data objects
            self.graph_smiles = [from_smiles(s) for s in self.smiles]
        if smiles_tokenizer:
            # NOTE: Do NOT return tensors when doing SSL, i.e., MLM, as reported
            # in this conversation: https://discuss.huggingface.co/t/extra-dimension-with-datacollatorfor-languagemodeling-into-bertformaskedlm/6400/6
            if use_for_ssl:
                self.smiles_tokenized = [
                    smiles_tokenizer(s, padding='max_length', truncation=True) for s in self.smiles
                ]
                assert len(self.smiles_tokenized) == len(self.smiles), (
                    f'ERROR. Len tokenized {len(self.smiles_tokenized)} /= len SMILES {len(self.smiles)}'
                )
            else:
                self.smiles_tokenized = [
                    smiles_tokenizer(s, padding='max_length', truncation=True, return_tensors='pt') for s in self.smiles
                ]

    @staticmethod
    def load(pt_file):
        # TODO: Work in progress
        return torch.load(pt_file)

    def __len__(self):
        return len(self.smiles)

    def __getitem__(self, idx):
        smiles = self.smiles.iloc[idx]
        if self.use_for_ssl:
            elem = {}
            if self.smiles_tokenizer:
                smiles_tokenized = self.smiles_tokenized[idx]
                elem['input_ids'] = smiles_tokenized['input_ids']
                elem['attention_mask'] = smiles_tokenized['attention_mask']
                elem['labels'] = smiles_tokenized['input_ids'].copy()
            else:
                elem['smiles'] = smiles
            return elem
        elem = {
            'e3_ligase': self.e3_ligase[idx][..., None],
            'cell_type': self.cell_type[idx][..., None],
        }
        if self.include_poi_gene:
            elem['poi_gene_id'] = self.gene[idx][..., None]
        if self.task == 'predict_active_inactive':
            elem['labels'] = self.active.iloc[idx][..., None]
        elif self.task == 'predict_pDC50':
            elem['labels'] = self.pDC50.iloc[idx][..., None]
        elif self.task == 'predict_pDC50_and_Dmax':
            Dmax = self.Dmax.iloc[idx]
            pDC50 = self.pDC50.iloc[idx]
            elem['labels'] = np.array([Dmax, pDC50])
        else:
            raise ValueError(f'Task "{self.task}" not recognized. Available: "predict_pDC50" | "predict_active_inactive" | "predict_pDC50_and_Dmax"')
        if self.include_smiles_as_graphs or self.precompute_smiles_as_graphs:
            if self.precompute_smiles_as_graphs:
                elem['smiles_graph'] = self.graph_smiles[idx]
            else:
                elem['smiles_graph'] = from_smiles(smiles)
        if self.smiles_tokenizer:
            elem['smiles_tokenized'] = self.smiles_tokenized[idx]
        if self.include_poi_seq or self.tokenize_poi_seq:
            if self.precompute_poi_seq:
                elem['poi_seq'] = self.poi_seq[idx]
            else:
                poi_seq = self.poi_vectorizer.transform([self.poi_seq[idx]])
                poi_seq = poi_seq.toarray().flatten().astype(np.float32)
                elem['poi_seq'] = poi_seq
        if self.include_smiles_as_str:
            elem['smiles'] = smiles
        if self.use_morgan_fp:
            if self.precompute_fingerprints:
                fp = self.morgan_fp[idx].copy()
            else:
                fp = get_fingerprint(smiles, n_bits=self.morgan_bits, atomic_radius=self.morgan_atomic_radius).astype(np.float32)
            elem['morgan_fp'] = fp
        if self.use_maccs_fp:
            if self.precompute_fingerprints:
                fp = self.maccs_fp[idx].copy()
            else:
                fp = get_fingerprint(smiles, fp_type='maccs').astype(np.float32)
            elem['maccs_fp'] = fp
        if self.use_path_fp:
            if self.precompute_fingerprints:
                fp = self.path_fp[idx].copy()
            else:
                fp = get_fingerprint(smiles, n_bits=self.path_bits,
                                     fp_type='path',
                                     min_path=self.fp_min_path,
                                     max_path=self.fp_max_path).astype(np.float32)
            elem['path_fp'] = fp
        return elem

    def get_fingerprint(self, fp_type: Literal['morgan_fp', 'maccs_fp', 'path_fp'] = 'morgan_fp'):
        # TODO: Add the proper checks if fingerprints are used
        if self.precompute_fingerprints:
            if fp_type == 'morgan_fp':
                return self.morgan_fp
            elif fp_type == 'maccs_fp':
                return self.maccs_fp
            elif fp_type == 'path_fp':
                return self.path_fp
            else:
                raise ValueError(f'Fingerprint type "{fp_type}" not recognized. Available: "morgan_fp" | "maccs_fp" | "path_fp"')
        else:
            smiles = self.smiles
            if fp_type == 'morgan_fp':
                return np.array([get_fingerprint(s, n_bits=self.morgan_bits, atomic_radius=self.morgan_atomic_radius).astype(np.float32) for s in smiles])
            elif fp_type == 'maccs_fp':
                return np.array([get_fingerprint(s, fp_type='maccs').astype(np.float32) for s in smiles])
            elif fp_type == 'path_fp':
                return np.array([get_fingerprint(s, n_bits=self.path_bits, fp_type='path', min_path=self.fp_min_path, max_path=self.fp_max_path).astype(np.float32) for s in smiles])
            else:
                raise ValueError(f'Fingerprint type "{fp_type}" not recognized. Available: "morgan_fp" | "maccs_fp" | "path_fp"')
    
    def get_poi_seq_emb_size(self):
        if self.include_poi_seq:
            return len(self.poi_vectorizer.get_feature_names_out())
        else:
            return 0
    
    def __str__(self) -> str:
        return f'ProtacDataset for {self.task} task with {len(self)} samples.'

In order to have batches of graphs _alongside the other features_, we need to _extend_ the default collate function to handle PyTorch Geometric `Data` objects.

Refer to [this documentation](https://pytorch.org/docs/stable/data.html#torch.utils.data.default_collate).

In [None]:
from torch.utils.data._utils.collate import collate
from torch.utils.data._utils.collate import default_collate_fn_map

def graph_collate(batch, *, collate_fn_map=None):
    # Handle graph data separately: graph representation and computation can be
    # greatly optimized due to their sparse nature. In fact, multiple graphs in
    # a batch can be seen as a "big" graph of unconnected sub-graphs. Hence,
    # their adjacency matrices can be combined into a single block-diagonal one.
    return torch_geometric.data.Batch.from_data_list(batch)

def custom_collate(batch):
    collate_map = default_collate_fn_map.copy()
    collate_map.update({torch_geometric.data.Data: graph_collate})
    return collate(batch, collate_fn_map=collate_map)
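The comment in `graph_collate` describes how PyTorch Geometric batches graphs. The NumPy sketch below illustrates the idea with a hypothetical helper `batch_graphs` (not PyG's actual implementation): node features are stacked, each graph's `edge_index` is shifted by the number of nodes that precede it, and a `batch` vector records which graph each node belongs to. PyG's `Batch` objects expose `x`, `edge_index`, and `batch` attributes built along these lines.

```python
import numpy as np


def batch_graphs(node_feats, edge_indices):
    """Merge several graphs into a single disconnected graph (illustrative only).

    node_feats: list of (num_nodes_i, num_features) arrays.
    edge_indices: list of (2, num_edges_i) arrays with node ids local to each graph.
    """
    xs, eis, batch = [], [], []
    offset = 0
    for graph_id, (x, ei) in enumerate(zip(node_feats, edge_indices)):
        xs.append(x)
        eis.append(ei + offset)  # shift node ids so graphs do not overlap
        batch.append(np.full(x.shape[0], graph_id))  # graph membership of each node
        offset += x.shape[0]
    return np.concatenate(xs), np.concatenate(eis, axis=1), np.concatenate(batch)


# Graph A: path 0-1-2 (edges in both directions); graph B: single edge 0-1
x_a, ei_a = np.zeros((3, 4)), np.array([[0, 1, 1, 2], [1, 0, 2, 1]])
x_b, ei_b = np.ones((2, 4)), np.array([[0, 1], [1, 0]])

x, edge_index, batch = batch_graphs([x_a, x_b], [ei_a, ei_b])
print(x.shape)     # 5 nodes in total, 4 features each
print(edge_index)  # graph B's nodes were renumbered to 3 and 4
print(batch)       # first three nodes belong to graph 0, last two to graph 1
```

Since no edge connects nodes of different graphs, the adjacency matrix of the merged graph is block-diagonal, so message passing never mixes information across graphs; the `batch` vector is what graph-level pooling uses to split the nodes back per graph.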

In [None]:
import inspect

protac_dataset_src = '''
import torch
from torch.utils.data import Dataset, DataLoader

import torch_geometric
import torch_geometric.nn as geom_nn
import torch_geometric.data as geom_data
from torch_geometric.utils.smiles import from_smiles
from torch.utils.data._utils.collate import collate
from torch.utils.data._utils.collate import default_collate_fn_map

import sklearn
from sklearn import preprocessing
from sklearn.feature_extraction.text import CountVectorizer
from rdkit import Chem
from rdkit.Chem import AllChem, DataStructs, MACCSkeys

import torchvision
import typing
import re
import numpy as np
import pandas as pd

from typing import Literal, Callable, List, ClassVar, Any, Mapping, Tuple

MACCS_BITWIDTH = 167

'''

protac_dataset_src += inspect.getsource(get_fingerprint) + '\n'
protac_dataset_src += inspect.getsource(graph_collate) + '\n'
protac_dataset_src += inspect.getsource(custom_collate) + '\n'
protac_dataset_src += 'class ProtacDataset(Dataset): \n'
protac_dataset_src += inspect.getsource(ProtacDataset.__init__) + '\n'
protac_dataset_src += inspect.getsource(ProtacDataset.load) + '\n'
protac_dataset_src += inspect.getsource(ProtacDataset.__len__) + '\n'
protac_dataset_src += inspect.getsource(ProtacDataset.__getitem__)
with open(os.path.join(src_dir, 'protac_dataset.py'), 'w') as f:
    f.write(protac_dataset_src)

**TODO: The number of node features might not be correct...**

In [None]:
test_dataset = ProtacDataset(val_df, include_smiles_as_graphs=True)
test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=True, collate_fn=custom_collate)
batch = next(iter(test_dataloader))
NUM_NODE_FEATURES = num_node_features = batch['smiles_graph'].x.size()[-1]
NODE_EDGE_DIM = node_edge_dim = batch['smiles_graph'].edge_attr.size()[-1]

print(batch['e3_ligase'].size())
print(batch['poi_seq'].size())
print(batch)
print(f'Number of node features: {NUM_NODE_FEATURES}')

In [None]:
test_dataset = ProtacDataset(val_bin_protac_pedia_df, include_smiles_as_graphs=True)
test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=True, collate_fn=custom_collate)
batch = next(iter(test_dataloader))

print(test_dataset.hparams)
print(batch['e3_ligase'].size())
print(batch['poi_seq'].size())
print(batch['labels'].size())
print(batch)

Retrieve the train, validation, and test PROTAC datasets:

In [None]:
def get_datasets(task: Literal['predict_active_inactive', 'predict_pDC50_and_Dmax'] = 'predict_active_inactive',
                 use_upsampled: bool = False,
                 regenerate_datasets: bool = False,
                 dataset_name: str = '', 
                 get_DMatrix: bool = False,
                 on_missing: Literal['error', 'ignore', 'regenerate'] = 'regenerate',
                 save_datasets: bool = True,
                 **protac_ds_kwargs):
    '''Get train, validation, and test datasets for the given task.
    Args:
        task (str): Task to perform. One of 'predict_active_inactive', 'predict_pDC50_and_Dmax'.
        use_upsampled (bool): Whether to use the upsampled training dataset for 'predict_active_inactive' task. NOTE: currently it only changes the dataset filename; the upsampled training dataframe is always used.
        regenerate_datasets (bool): Whether to regenerate the datasets if already present on disk.
        dataset_name (str): Trailing name of the dataset to generate. Useful for differentiating between different models. Example: f'train_pDC50_Dmax_dataset{dataset_name}.pt'
        get_DMatrix (bool): TODO: Add support for XGBoost.
        on_missing (str): What to do if a dataset is not found on disk: 'error' raises a FileNotFoundError, 'ignore' skips it, 'regenerate' generates it.
        save_datasets (bool): Whether to save newly generated datasets to disk.
        protac_ds_kwargs (dict): Keyword arguments to pass to the ProtacDataset class. Passed to all train/val/test datasets to be generated.
    Returns:
        dict: Mapping from split name (e.g., 'train', 'val', 'test') to ProtacDataset.
    '''
    # Function to rename task names
    # TODO: I shall simply convert 'active_inactive' to 'bin' and the rest to 'regr'...
    rename = lambda x: x.replace('predict_', '').replace('active_inactive', 'bin').replace('pDC50_and_Dmax', 'pDC50_Dmax')
    # Setup dataframes
    datasets = {
        'train': dataframes['train_upsampled_bin'], # train_bin_df if not use_upsampled else train_upsampled_bin_df,
        'val': dataframes['val_bin'], # val_bin_df,
        'test': dataframes['test_bin'], # test_bin_df,
        'train_protac_pedia': train_bin_protac_pedia_df,
        'val_protac_pedia': val_bin_protac_pedia_df,
    }
    # Get regression datasets if required
    if task == 'predict_pDC50_and_Dmax':
        datasets['train'] = dataframes['train_regr'] # train_df
        datasets['val'] = dataframes['val_regr'] # val_df
        datasets['test'] = dataframes['test_regr'] # test_df
    # Start generation/retrieval of datasets by "overwriting" dataframe keys
    ret = {}
    for k, df in datasets.items():
        upsampled = '_upsampled' if use_upsampled else ''
        filename = f'{k}{upsampled}_{rename(task)}_dataset{dataset_name}.pt'
        filename = os.path.join(data_dir, 'protac', filename)
        if os.path.exists(filename):
            if regenerate_datasets:
                ret[k] = ProtacDataset(df, task=task, **protac_ds_kwargs)
                if save_datasets:
                    torch.save(ret[k], filename)
            else:
                ret[k] = torch.load(filename)
        else:
            if on_missing == 'error':
                raise FileNotFoundError(f'{k} dataset not found at: {filename}')
            elif on_missing == 'ignore':
                pass
            else:
                ret[k] = ProtacDataset(df, task=task, **protac_ds_kwargs)
                # TODO: Add support for XGBoost datasets here
                if save_datasets:
                    torch.save(ret[k], filename)
    return ret


    # # Get namings
    # if task == 'predict_pDC50_and_Dmax':
    #     train_ds = os.path.join(data_dir, 'protac', f'train_pDC50_Dmax_dataset{dataset_name}.pt')
    #     val_ds = os.path.join(data_dir, 'protac', f'val_pDC50_Dmax_dataset{dataset_name}.pt')
    #     test_ds = os.path.join(data_dir, 'protac', f'test_pDC50_Dmax_dataset{dataset_name}.pt')
    #     train_tmp_df = train_df
    #     val_tmp_df = val_df
    #     test_tmp_df = test_df
    # if task == 'predict_pDC50':
    #     # TODO
    #     raise NotImplementedError
    # elif task == 'predict_active_inactive':
    #     val_ds = os.path.join(data_dir, 'protac', f'val_bin_dataset{dataset_name}.pt')
    #     test_ds = os.path.join(data_dir, 'protac', f'test_bin_dataset{dataset_name}.pt')
    #     val_tmp_df = val_bin_df
    #     test_tmp_df = test_bin_df
    #     if use_upsampled:
    #         train_ds = os.path.join(data_dir, 'protac', f'train_upsampled_bin_dataset{dataset_name}.pt')
    #         train_tmp_df = train_upsampled_bin_df
    #     else:
    #         train_ds = os.path.join(data_dir, 'protac', f'train_bin_dataset{dataset_name}.pt')
    #         train_tmp_df = train_bin_df
    # # Generate datasets if required or not already on disk
    # if not os.path.exists(train_ds) and on_missing == 'error' and not regenerate_datasets:
    #     raise FileNotFoundError(f'Train dataset not found at: {train_ds}')
    # else:
    #     if regenerate_datasets or not on_missing == 'error':
    #         train_dataset = ProtacDataset(train_tmp_df, task=task,
    #                                     **protac_ds_kwargs)
    #         torch.save(train_dataset, train_ds)
    #     # TODO: Add support for XGBoost
    #     # if get_DMatrix:
    #     #     dtrain = xgb.DMatrix(train_dataset.morgan_fp, label=train_dataset.active)
    #     else:
    #         train_dataset = torch.load(train_ds)
    # if not os.path.exists(val_ds) and on_missing == 'error' and not regenerate_datasets:
    #     raise FileNotFoundError(f'Val dataset not found at: {val_ds}')
    # else:
    #     if regenerate_datasets or not on_missing == 'error':
    #         val_dataset = ProtacDataset(val_tmp_df, task=task,
    #                                     **protac_ds_kwargs)
    #         torch.save(val_dataset, val_ds)
    #     else:
    #         val_dataset = torch.load(val_ds)
    # if not os.path.exists(test_ds) and on_missing == 'error' and not regenerate_datasets:
    #     raise FileNotFoundError(f'Test dataset not found at: {test_ds}')
    # else:
    #     if regenerate_datasets or not on_missing == 'error':
    #         test_dataset = ProtacDataset(test_tmp_df, task=task,
    #                                     **protac_ds_kwargs)
    #         torch.save(test_dataset, test_ds)
    #     else:
    #         test_dataset = torch.load(test_ds)
    # return {
    #     'train': train_dataset,
    #     'val': val_dataset,
    #     'test': test_dataset,
    # }
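The load-or-build logic of `get_datasets` reduces to a small decision table over three inputs: whether the file exists, `regenerate_datasets`, and `on_missing`. The sketch below isolates that table, together with the file-naming scheme, as pure functions, so both can be checked without building any dataset. `dataset_filename` and `dataset_action` are hypothetical helpers that mirror the code above; they are not part of the notebook, and the `'data'` root directory is an assumption.

```python
import os


def dataset_filename(split: str, task: str, use_upsampled: bool = False,
                     dataset_name: str = '', root: str = 'data') -> str:
    # Same renaming as the `rename` lambda in get_datasets
    short_task = (task.replace('predict_', '')
                      .replace('active_inactive', 'bin')
                      .replace('pDC50_and_Dmax', 'pDC50_Dmax'))
    upsampled = '_upsampled' if use_upsampled else ''
    return os.path.join(root, 'protac', f'{split}{upsampled}_{short_task}_dataset{dataset_name}.pt')


def dataset_action(exists: bool, regenerate: bool, on_missing: str) -> str:
    """Return 'load', 'build' or 'skip', following the branches of get_datasets."""
    if exists:
        return 'build' if regenerate else 'load'
    if on_missing == 'error':
        raise FileNotFoundError('dataset not found on disk')
    if on_missing == 'ignore':
        return 'skip'
    return 'build'  # on_missing == 'regenerate'


print(dataset_filename('val', 'predict_active_inactive'))
print(dataset_action(exists=False, regenerate=True, on_missing='ignore'))  # skip
```

Writing the branches out this way exposes a corner case. As coded, `regenerate_datasets=True` only affects files that already exist, so a missing file with `on_missing='error'` still raises.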

## Plotting Datasets

In [None]:
tmp = train_df[['pDC50', 'Dmax', 'cell_type', 'poi_gene_id', 'active']]
# tmp.loc[tmp['active'].isna(), 'active'] = -1
cols = {
    'pDC50': '$pDC_{50}$',
    'Dmax': '$D_{max}$ (%)',
    'cell_type': 'Cell type',
    'poi_gene_id': 'POI/Target',
    'active': 'Active/Inactive',
}
tmp = tmp.rename(columns=cols)
ax = sns.pairplot(data=tmp, diag_kind='hist', hue='Active/Inactive', palette='Set2', corner=False)
# plt.show()
plt.close()

In [None]:
dfs = [
    (protac_db_df.copy(), 'PROTAC-DB'),
    (protac_pedia_df.copy(), 'PROTAC-Pedia'),
    (test_df.copy(), 'PROTAC-Pedia-Test'),
]
for tmp, df_name in dfs:
    tmp['pDC50'] = tmp['pDC50'].astype(float)
    tmp['Dmax'] = tmp['Dmax'].astype(float) * 100
    tmp['active'] = tmp['active'].replace({True: 'Active', False: 'Inactive'})
    if 'PROTAC-Pedia' in df_name:
        del tmp['Active/Inactive']
    old2new_names = {
        'pDC50': '$pDC_{50}$ ($-Log_{10}(M)$)',
        'Dmax': '$D_{max}$ (%)',
        'active': 'Active/Inactive',
    }
    tmp = tmp.rename(columns=old2new_names)

    sns.scatterplot(data=tmp, x='$pDC_{50}$ ($-Log_{10}(M)$)', y='$D_{max}$ (%)',
                    hue='Active/Inactive', hue_order=['Inactive', 'Active'])
    plt.grid(axis='both', alpha=0.8)
    plt.title(f'{df_name}' + ': $D_{max}$ vs. $pDC_{50}$')
    # plt.xticks(rotation=90)
    # plt.legend(loc='upper left', bbox_to_anchor=(1, 1), ncol=1) #, fancybox=True, shadow=True)

    f = os.path.join(fig_dir, f'scatter_{df_name.lower()}_active_entries')
    plt.savefig(f + '.pdf', bbox_inches='tight')
    plt.savefig(f + '.png', bbox_inches='tight')
    plt.show()
    plt.close()
del dfs

TODO:

* try different bin sizes
* try an ANOVA test to compare the distributions of the different bins. In this way, we might get more information on how to split up the data
* try a beeswarm (swarm) plot
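For the ANOVA item, the sketch below runs a one-way ANOVA on $D_{max}$ values grouped into $pDC_{50}$ bins, in pure Python. It assumes fixed-width bins indexed by `floor(pDC50 / width)`. `bin_by` and `one_way_anova_f` are hypothetical helpers, and the data are toy values, not measurements. The code computes only the F statistic. If SciPy is available, `scipy.stats.f_oneway(*groups)` returns the statistic together with its p-value, and `scipy.stats.kruskal(*groups)` is a rank-based alternative for when the normality assumption behind ANOVA is doubtful.

```python
import math
from collections import defaultdict
from statistics import mean
from typing import Dict, List, Sequence


def bin_by(xs: Sequence[float], ys: Sequence[float], width: float) -> Dict[int, List[float]]:
    """Group ys by the fixed-width bin their x falls into (bin index = floor(x / width))."""
    groups: Dict[int, List[float]] = defaultdict(list)
    for x, y in zip(xs, ys):
        groups[math.floor(x / width)].append(y)
    return dict(groups)


def one_way_anova_f(groups: List[List[float]]) -> float:
    """One-way ANOVA F statistic: between-group over within-group mean squares.

    Needs at least two groups and more samples than groups.
    """
    k = len(groups)
    n = sum(len(g) for g in groups)
    grand_mean = mean(y for g in groups for y in g)
    ss_between = sum(len(g) * (mean(g) - grand_mean) ** 2 for g in groups)
    ss_within = sum((y - mean(g)) ** 2 for g in groups for y in g)
    return (ss_between / (k - 1)) / (ss_within / (n - k))


# Toy data: pDC50 values and matching Dmax (%) values
pdc50 = [6.1, 6.3, 6.4, 7.2, 7.5, 7.8]
dmax = [40.0, 50.0, 60.0, 70.0, 80.0, 90.0]
groups = list(bin_by(pdc50, dmax, width=1.0).values())
print(one_way_anova_f(groups))
```

Bins with very few entries make the test unreliable, so sparse bins would have to be merged or dropped. The `width` parameter is where the bin-size question from the first TODO item comes in.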

In [None]:
dfs = [
    (protac_db_df.copy(), 'PROTAC-DB'),
    (protac_pedia_df.copy(), 'PROTAC-Pedia'),
    (test_df.copy(), 'PROTAC-Pedia-Test'),
]
for tmp, df_name in dfs:
    tmp['pDC50'] = tmp['pDC50'].astype(float)
    tmp['Dmax'] = tmp['Dmax'].astype(float) * 100
    # Change the bin size by rounding
    tmp = (tmp[['pDC50', 'Dmax']]).round(1)
    old2new_names = {
        'pDC50': '$pDC_{50}$ ($-Log_{10}(M)$)',
        'Dmax': '$D_{max}$ (%)',
    }
    tmp = tmp.rename(columns=old2new_names)
    # Change plot/figure size
    plt.figure(figsize=(9, 6))
    ax = sns.boxplot(data=tmp, x='$pDC_{50}$ ($-Log_{10}(M)$)', y='$D_{max}$ (%)')
    plt.xticks(rotation=90)
    # plt.xlim(0.0, 50.0)
    plt.ylim(0., 105.)
    plt.grid(axis='y', alpha=0.7)
    plt.title(f'{df_name}' + ': $D_{max}$ vs. $pDC_{50}$')
    f = os.path.join(fig_dir, f'boxplot_{df_name.lower()}_pDC50_vs_Dmax')
    plt.savefig(f + '.pdf', bbox_inches='tight')
    plt.savefig(f + '.png', bbox_inches='tight')
    plt.show()
    plt.close()
del dfs

In [None]:
# For plotting the distribution of POI sequences, we employ an OrdinalEncoder
# trained on all the POI sequences in all datasets. In this way, when we
# transform the sequences of each dataset, we get the same encoding, so the
# classes are consistent across datasets.
tmp = pd.concat([protac_db_df, protac_pedia_df], axis=0)
poi_seq_enc = sklearn.preprocessing.OrdinalEncoder()
poi_seq_enc.fit(tmp['poi_seq'].to_numpy().reshape(-1, 1))

dfs = [
    (protac_db_df.copy(), 'PROTAC-DB'),
    (protac_pedia_df.copy(), 'PROTAC-Pedia'),
    (test_df.copy(), 'PROTAC-Pedia-Test'),
]
common_poi = []
for col in ['POI Class', 'E3 Ligase', 'Cell Type']:
    for tmp, df_name in dfs:
        tmp['cell_type'] = tmp['cell_type'].str.replace('esophagealcancercellline', '') # it's just too long to print...
        tmp['POI Class'] = poi_seq_enc.transform(tmp['poi_seq'].to_numpy().reshape(-1, 1))
        old2new_names = {
            'e3_ligase': 'E3 Ligase',
            'cell_type': 'Cell Type',
        }
        tmp = tmp.rename(columns=old2new_names)
        print(f'Number of unique {col} in {df_name}: {len(tmp[col].unique())}')
        # # Change plot/figure size
        # if col == 'Cell Type':
        #     plt.figure(figsize=(15, 6))
        n = 20
        top_n = tmp[col].value_counts().index[:n]
        # print([poi_seq_enc.inverse_transform([[tmp.at[i, 'POI Class']]])[0][0] for i in top_n])
        
        if col == 'POI Class' and df_name != 'PROTAC-Pedia-Test':
            pois = []
            for i in top_n:
                # poi_name = poi_seq_enc.inverse_transform([[int(i)]])[0][0]
                # print(i, poi_name)
                pois.append(i)
            if common_poi == []:
                common_poi = pois
            common_poi = list(set(common_poi) & set(pois))
        
        ax = sns.countplot(data=tmp, x=col, order=top_n)
        fmt = '{:0.0f}'
        rot = 0 # 90 if col == 'Cell Type' else 0
        for bars_group in ax.containers:
            ax.bar_label(bars_group, padding=2, fmt=fmt, rotation=rot) # fontsize=12
        plt.xticks(rotation=0 if col == 'E3 Ligase' else 90)
        plt.grid(axis='y', alpha=0.7)
        top_n = f' (Top-{n} most frequent)' if col != 'E3 Ligase' else ''
        plt.title(f'{df_name}: {col} Distribution{top_n}')
        t = col.lower().replace(' ', '_')
        f = os.path.join(fig_dir, f'countplot_{df_name.lower()}_{t}_distribution')
        plt.savefig(f + '.pdf', bbox_inches='tight')
        plt.savefig(f + '.png', bbox_inches='tight')
        plt.show()
        plt.close()
del dfs

for i in common_poi:
    poi_name = poi_seq_enc.inverse_transform([[int(i)]])[0][0]
    print(i, poi_name)
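The cell above does two things: it fits one encoder on the union of all POI sequences and transforms each dataset with it, and it intersects the top-N classes of each dataset to build `common_poi`. Both steps can be sketched in pure Python. The helpers `fit_ordinal` and `top_n_codes` are hypothetical. Sorting the unique values should match what scikit-learn's `OrdinalEncoder` does with `categories='auto'`, so the codes ought to line up, but treat that as an assumption.

```python
from collections import Counter
from typing import Dict, Iterable, List


def fit_ordinal(*datasets: Iterable[str]) -> Dict[str, int]:
    """Map every category seen in any dataset to an integer code (sorted order)."""
    categories = sorted(set().union(*datasets))
    return {cat: code for code, cat in enumerate(categories)}


def top_n_codes(values: List[str], mapping: Dict[str, int], n: int) -> List[int]:
    """Codes of the n most frequent categories in one dataset."""
    return [mapping[v] for v, _ in Counter(values).most_common(n)]


# Toy POI sequences (placeholders, not real proteins)
db = ['SEQ_A', 'SEQ_A', 'SEQ_B', 'SEQ_C', 'SEQ_C', 'SEQ_C']
pedia = ['SEQ_C', 'SEQ_C', 'SEQ_D', 'SEQ_D', 'SEQ_D', 'SEQ_A']
mapping = fit_ordinal(db, pedia)
# The mapping is shared, so the same code means the same sequence in every dataset
common = set(top_n_codes(db, mapping, 2)) & set(top_n_codes(pedia, mapping, 2))
print(mapping, sorted(common))
```

Fitting a separate encoder per dataset would break this. `SEQ_C` could then receive different codes in the two datasets, and the intersection would compare unrelated classes.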

### Dimensionality Reduction

In [None]:
# import pandas as pd
# from rdkit import Chem
# from rdkit.Chem import Descriptors
# from rdkit.ML.Descriptors import MoleculeDescriptors
# from sklearn.decomposition import PCA
# import matplotlib.pyplot as plt

# # List of SMILES
# smiles_list = ["CCO", "CCN(C)C=O", "C1=CC=CC=C1"]

# smiles_list = protac_db_df['Smiles_nostereo'].to_list()

# # Convert SMILES to RDKit molecules
# mols = [Chem.MolFromSmiles(smiles) for smiles in smiles_list]

# # Calculate molecular descriptors
# descriptor_names = [desc[0] for desc in Descriptors._descList]
# descriptor_calculator = MoleculeDescriptors.MolecularDescriptorCalculator(descriptor_names)
# descriptors = [descriptor_calculator.CalcDescriptors(mol) for mol in mols]

# # Perform Principal Component Analysis (PCA)
# pca = PCA(n_components=2)
# pca.fit(descriptors)
# transformed = pca.transform(descriptors)

# # Create a DataFrame for plotting
# df = pd.DataFrame(transformed, columns=['PC1', 'PC2'])
# df['SMILES'] = smiles_list

In [None]:
# %matplotlib
# # Plot the chemical space
# fig,ax = plt.subplots()

# sc = plt.scatter(df['PC1'], df['PC2'])

# annot = ax.annotate("", xy=(0,0), xytext=(20,20), textcoords="offset points",
#                     bbox=dict(boxstyle="round", fc="w"),
#                     arrowprops=dict(arrowstyle="->"))
# annot.set_visible(False)


# def update_annot(ind, prev_ind=-1):
#     if prev_ind == -1:
#         prev_ind = ind["ind"][0]
#     if ind["ind"][0] == prev_ind:
#         prev_ind = ind["ind"][0]
#         pos = sc.get_offsets()[ind["ind"][0]]
#         annot.xy = pos
#         # text = "{}, {}".format(" ".join(list(map(str,ind["ind"]))), 
#         #                     " ".join([df.at[n, 'SMILES'] for n in ind["ind"]]))
#         text = " ".join([df.at[n, 'SMILES'] for n in ind["ind"]])
#         annot.set_text(text)
#         # annot.get_bbox_patch().set_facecolor(cmap(norm(c[ind["ind"][0]])))
#         annot.get_bbox_patch().set_alpha(0.4)
    

# prev_ind = -1
# def hover(event):
#     vis = annot.get_visible()
#     if event.inaxes == ax:
#         cont, ind = sc.contains(event)
#         if cont:
#             update_annot(ind, prev_ind)
#             annot.set_visible(True)
#             fig.canvas.draw_idle()
#         else:
#             if vis:
#                 annot.set_visible(False)
#                 fig.canvas.draw_idle()

# fig.canvas.mpl_connect("motion_notify_event", hover)

# plt.xlabel('PC1')
# plt.ylabel('PC2')
# plt.title('Chemical Space')
# plt.show()

RDKit molecular descriptors:

In [None]:
import os
import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit.ML.Descriptors import MoleculeDescriptors
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

def get_mol_descriptors(smiles_list):
    # Convert SMILES to RDKit molecules
    mols = [Chem.MolFromSmiles(smiles) for smiles in smiles_list]
    # Calculate molecular descriptors
    descriptor_names = [desc[0] for desc in Descriptors._descList]
    descriptor_calculator = MoleculeDescriptors.MolecularDescriptorCalculator(descriptor_names)
    descriptors = [descriptor_calculator.CalcDescriptors(mol) for mol in mols]
    return np.array(descriptors)

print('Calculating molecular descriptors...')

smiles_db_list = protac_db_df['Smiles_nostereo'].to_list()
f = os.path.join(data_dir, 'protac', 'protac_db_descr.npy')
if os.path.exists(f):
    protac_db_descr = np.load(f)
else:
    protac_db_descr = get_mol_descriptors(smiles_db_list)
    np.save(f, protac_db_descr)

smiles_pedia_list = protac_pedia_df['Smiles_nostereo'].to_list()
f = os.path.join(data_dir, 'protac', 'protac_pedia_descr.npy')
if os.path.exists(f):
    protac_pedia_descr = np.load(f)
else:
    protac_pedia_descr = get_mol_descriptors(smiles_pedia_list)
    np.save(f, protac_pedia_descr)

print('Done.')
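The descriptor cell caches its results with `np.save`/`np.load`. Below is a minimal sketch of that compute-once pattern, assuming NumPy; `load_or_compute` is a hypothetical helper. `np.save(path, array)` takes the file path and the array as two separate arguments, and it appends `.npy` to the filename when that suffix is missing.

```python
import os
import tempfile
from typing import Callable

import numpy as np


def load_or_compute(path: str, compute: Callable[[], np.ndarray]) -> np.ndarray:
    """Load an array from `path` if it is cached, otherwise compute and cache it.

    `path` should end in '.npy': np.save appends that suffix when it is missing,
    and os.path.exists would then never find the cached file.
    """
    if os.path.exists(path):
        return np.load(path)
    array = np.asarray(compute())
    os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
    np.save(path, array)  # file path first, then the array to store
    return array


# Usage with a toy "expensive" computation that records how often it runs
calls = []

def expensive() -> np.ndarray:
    calls.append(1)
    return np.arange(6, dtype=np.float64).reshape(2, 3)

with tempfile.TemporaryDirectory() as tmp_dir:
    cache = os.path.join(tmp_dir, 'protac', 'toy_descr.npy')
    first = load_or_compute(cache, expensive)   # computed and saved
    second = load_or_compute(cache, expensive)  # loaded from disk
print(len(calls))
```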

In [None]:
# plt.ioff() # Turn off interactive mode like %matplotlib magic...


def plot_tSNE(df, perplexity=0, descr='PROTAC-DB'):
    fig, ax = plt.subplots()
    # Get sub-plot and scatter plot
    sc = plt.scatter(df['Dimension 1'], df['Dimension 2'], s=5, alpha=0.6)
    # Create annotation
    annot = ax.annotate('', xy=(0,0), xytext=(20,20),
                        textcoords='offset points',
                        bbox=dict(boxstyle='round', fc='w'),
                        arrowprops=dict(arrowstyle='->'))
    annot.set_visible(False)
    # Function to update annotation text
    def update_annot(ind, prev_ind=-1):
        if prev_ind == -1:
            prev_ind = ind['ind'][0]
        if ind['ind'][0] == prev_ind:
            prev_ind = ind['ind'][0]
            pos = sc.get_offsets()[ind['ind'][0]]
            annot.xy = pos
            # text = '{}, {}'.format(' '.join(list(map(str,ind['ind']))), 
            #                     ' '.join([df.at[n, 'SMILES'] for n in ind['ind']]))
            text = ' '.join([df.at[n, 'SMILES'] for n in ind['ind']])
            annot.set_text(text)
            # annot.get_bbox_patch().set_facecolor(cmap(norm(c[ind['ind'][0]])))
            annot.get_bbox_patch().set_alpha(0.4)
    # Function to show annotation on hover
    prev_ind = -1
    def hover(event):
        vis = annot.get_visible()
        if event.inaxes == ax:
            cont, ind = sc.contains(event)
            if cont:
                update_annot(ind, prev_ind)
                annot.set_visible(True)
                fig.canvas.draw_idle()
            else:
                if vis:
                    annot.set_visible(False)
                    fig.canvas.draw_idle()
    # fig.canvas.mpl_connect('motion_notify_event', hover)
    plt.xlabel('Dimension 1')
    plt.ylabel('Dimension 2')
    plt.title(f'Chemical Space for {descr} (t-SNE), Perplexity={perplexity}')
    plt.grid(alpha=0.8)
    plt.show()

In [None]:
from bhtsne import tsne

# Perplexity values to try
perplexities = [5, 10, 20, 30, 50]

# Plot chemical space for each perplexity
for perplexity in perplexities:
    # Embed both datasets jointly so that they share the same coordinate
    # system: t-SNE embeddings computed separately cannot be overlaid.
    all_descr = np.concatenate([protac_db_descr, protac_pedia_descr], axis=0)
    n_db = len(protac_db_descr)
    transformed = tsne(all_descr,
                        perplexity=perplexity,
                        dimensions=2,
                        theta=0.5, # Original: 0.5
                        rand_seed=42)
    # Create a DataFrame for plotting
    df = pd.DataFrame(transformed, columns=['Dimension 1', 'Dimension 2'])
    plt.scatter(df['Dimension 1'].iloc[:n_db], df['Dimension 2'].iloc[:n_db], s=5, alpha=0.6, label='PROTAC-DB')
    plt.scatter(df['Dimension 1'].iloc[n_db:], df['Dimension 2'].iloc[n_db:], s=5, alpha=0.6, label='PROTAC-Pedia')
    
    plt.xlabel('Dimension 1')
    plt.ylabel('Dimension 2')
    plt.title(f'Chemical Space (t-SNE), Perplexity={perplexity}')
    plt.legend()
    plt.grid(alpha=0.8)
    plt.show()

    # for descr, descriptors, smiles_list in zip(['PROTAC-DB', 'PROTAC-PEDIA'], 
    #                                            [protac_db_descr, protac_pedia_descr], 
    #                                            [smiles_db_list, smiles_pedia_list]):
    #     print(f'Plotting chemical space for {descr} (t-SNE), Perplexity={perplexity}...')
    #     # Perform t-SNE
    #     # tsne = TSNE(n_components=2,
    #     #             n_iter=500,
    #     #             n_iter_without_progress=20,
    #     #             perplexity=perplexity,
    #     #             random_state=42,
    #     #             learning_rate='auto',
    #     #             init='pca',
    #     #             n_jobs=-1)
    #     # transformed = tsne.fit_transform(np.array(descriptors))
        
    #     plot_tSNE(df, perplexity=perplexity, descr=descr)  
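The perplexity swept above roughly sets the effective number of neighbours each point considers. In standard t-SNE, each point gets its own Gaussian bandwidth, found by binary search so that the Shannon entropy $H$ of its neighbour distribution satisfies $2^{H} = \text{perplexity}$ with $H$ in bits. Equivalently, $e^{H} = \text{perplexity}$ with natural logs, which is the form the code uses. The pure-Python sketch below performs that search for a single point, given its squared distances to the other points. It illustrates the calibration step only and is not the bhtsne or scikit-learn code. The target perplexity must be smaller than the number of other points, since the entropy can never exceed the log of that count.

```python
import math
from typing import List


def calibrate_row(sq_dists: List[float], perplexity: float,
                  tol: float = 1e-6, max_iter: int = 200) -> List[float]:
    """Conditional probabilities p_{j|i} for one point, matched to `perplexity`.

    `sq_dists` are squared distances from point i to every *other* point.
    beta = 1 / (2 * sigma^2) is the Gaussian precision, found by bisection.
    """
    target = math.log(perplexity)  # natural-log entropy target: exp(H) = perplexity
    d_min = min(sq_dists)
    beta, lo, hi = 1.0, 0.0, math.inf
    for _ in range(max_iter):
        # Subtracting d_min avoids underflow and leaves the normalised probs unchanged
        weights = [math.exp(-(d - d_min) * beta) for d in sq_dists]
        total = sum(weights)
        probs = [w / total for w in weights]
        entropy = -sum(p * math.log(p) for p in probs if p > 0)
        if abs(entropy - target) < tol:
            break
        if entropy > target:
            # Too spread out: narrow the Gaussian (increase beta)
            lo = beta
            beta = beta * 2 if hi == math.inf else (beta + hi) / 2
        else:
            # Too concentrated: widen the Gaussian (decrease beta)
            hi = beta
            beta = (beta + lo) / 2
    return probs


def effective_perplexity(probs: List[float]) -> float:
    """exp of the (natural-log) Shannon entropy of a distribution."""
    return math.exp(-sum(p * math.log(p) for p in probs if p > 0))


# Squared distances from one point to ten neighbours (toy values)
sq_dists = [float(d) for d in range(1, 11)]
results = {perp: effective_perplexity(calibrate_row(sq_dists, perp)) for perp in (2, 5, 8)}
print(results)
```

Small perplexities concentrate the probability mass on the closest points and emphasise local structure. Larger values spread it over more neighbours, which is why the plots change visibly across the sweep.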

t-SNE (scikit-learn) on molecular fingerprints:

In [None]:
# get_fingerprint(smiles: str, n_bits: int = 1024, fp_type: Literal['morgan', 'maccs', 'path'] = 'morgan',
#                     min_path: int = 1, max_path: int = 2,
#                     atomic_radius: int = 2)

for fp_type in ['morgan', 'maccs']:
    for n_bits in [1024, 2048, 4096]:
        if fp_type == 'maccs' and n_bits != 1024:
            # MACCS keys only have 167 bits
            continue
        # Get fingerprints
        protac_db_descr = np.array([get_fingerprint(smiles, n_bits=n_bits, fp_type=fp_type) for smiles in smiles_db_list]).astype(np.float64)
        protac_pedia_descr = np.array([get_fingerprint(smiles, n_bits=n_bits, fp_type=fp_type) for smiles in smiles_pedia_list]).astype(np.float64)
        # Perplexity values to try
        perplexities = [5, 10, 20, 30, 50]
        # Plot chemical space for each perplexity
        for perplexity in perplexities:
            tsne = TSNE(n_components=2,
                        n_iter=1000,
                        perplexity=perplexity,
                        random_state=42,
                        learning_rate='auto',
                        init='pca',
                        n_jobs=-1)
            # Fit t-SNE on both datasets together so that they share the same
            # embedding space (separate fits cannot be overlaid on one plot)
            all_descr = np.concatenate([protac_db_descr, protac_pedia_descr], axis=0)
            n_db = len(protac_db_descr)
            transformed = tsne.fit_transform(all_descr)
            # Create a DataFrame for plotting
            df = pd.DataFrame(transformed, columns=['Dimension 1', 'Dimension 2'])
            plt.scatter(df['Dimension 1'].iloc[:n_db], df['Dimension 2'].iloc[:n_db], s=5, alpha=0.6, label='PROTAC-DB')
            plt.scatter(df['Dimension 1'].iloc[n_db:], df['Dimension 2'].iloc[n_db:], s=5, alpha=0.6, label='PROTAC-Pedia')
            
            descr = 'MACCS' if fp_type == 'maccs' else f'Morgan {n_bits} bits'
            plt.title(f'Chemical Space from {descr} Fingerprints (t-SNE), Perplexity={perplexity}')
            plt.xlabel('Dimension 1')
            plt.ylabel('Dimension 2')
            plt.legend()
            plt.grid(alpha=0.8)
            plt.show()


TruncatedSVD:

In [None]:
from sklearn.decomposition import TruncatedSVD

# get_fingerprint(smiles: str, n_bits: int = 1024, fp_type: Literal['morgan', 'maccs', 'path'] = 'morgan',
#                     min_path: int = 1, max_path: int = 2,
#                     atomic_radius: int = 2)

for fp_type in ['morgan', 'maccs']:
    for n_bits in [1024, 2048, 4096]:
        if fp_type == 'maccs' and n_bits != 1024:
            # MACCS keys only have 167 bits
            continue
        # Get fingerprints
        protac_db_descr = np.array([get_fingerprint(smiles, n_bits=n_bits, fp_type=fp_type) for smiles in smiles_db_list]).astype(np.float64)
        protac_pedia_descr = np.array([get_fingerprint(smiles, n_bits=n_bits, fp_type=fp_type) for smiles in smiles_pedia_list]).astype(np.float64)
        # Numbers of components to try
        n_components = [2, 3]
        # Plot chemical space for each number of components
        for n in n_components:
            # Perform TruncatedSVD, fitted on both datasets together so that
            # both are projected onto the same components
            svd = TruncatedSVD(n_components=n, random_state=42, n_iter=100)
            svd.fit(np.concatenate([protac_db_descr, protac_pedia_descr], axis=0))
            column_names = [f'Component {i+1}' for i in range(n)]
            # Create DataFrames for plotting
            df_db = pd.DataFrame(svd.transform(protac_db_descr), columns=column_names)
            df_pedia = pd.DataFrame(svd.transform(protac_pedia_descr), columns=column_names)

            # Plot the chemical space
            if n == 2:
                plt.scatter(df_db['Component 1'], df_db['Component 2'], label='PROTAC-DB', s=5, alpha=0.6)
                plt.scatter(df_pedia['Component 1'], df_pedia['Component 2'], label='PROTAC-Pedia', s=5, alpha=0.6)
                # for i, row in df.iterrows():
                #     plt.annotate(row['SMILES'], (row['Component 1'], row['Component 2']))
                plt.xlabel('Component 1')
                plt.ylabel('Component 2')
                plt.legend()
                plt.grid(alpha=0.8)
                descr = 'MACCS' if fp_type == 'maccs' else f'Morgan {n_bits} bits'
                plt.title(f'Chemical Space from {descr} Fingerprints (TruncatedSVD), n.components={n}')
            elif n == 3:
                fig = plt.figure()
                ax = fig.add_subplot(111, projection='3d')
                ax.scatter(df_db['Component 1'], df_db['Component 2'], df_db['Component 3'], label='PROTAC-DB', s=5, alpha=0.6)
                ax.scatter(df_pedia['Component 1'], df_pedia['Component 2'], df_pedia['Component 3'], label='PROTAC-Pedia', s=5, alpha=0.6)
                # for i, row in df_db.iterrows():
                #     ax.text(row['Component 1'], row['Component 2'], row['Component 3'], row['SMILES'])
                ax.set_xlabel('Component 1')
                ax.set_ylabel('Component 2')
                ax.set_zlabel('Component 3')
                plt.legend()
                descr = 'MACCS' if fp_type == 'maccs' else f'Morgan {n_bits} bits'
                plt.title(f'Chemical Space from {descr} Fingerprints (TruncatedSVD), n.components={n}')
            else:
                print(f"Cannot plot chemical space for {n} components.")
            plt.show()
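Unlike PCA, TruncatedSVD does not centre the data, which is one reason it is used on sparse, binary fingerprint matrices. The NumPy sketch below shows what it computes, assuming an exact SVD; scikit-learn's default `algorithm='randomized'` approximates the same subspace. The top-$k$ right singular vectors $V_k$ of the fitted matrix define the components, and any matrix $X'$ is projected as $X' V_k$. Fitting on the stacked PROTAC-DB and PROTAC-Pedia fingerprints therefore gives both datasets the same axes. `fit_svd_components` is hypothetical, and the fingerprints are random bits.

```python
import numpy as np


def fit_svd_components(X: np.ndarray, k: int) -> np.ndarray:
    """Top-k right singular vectors of X (no centring), shape (n_features, k)."""
    _, _, vt = np.linalg.svd(X, full_matrices=False)
    return vt[:k].T


rng = np.random.default_rng(42)
# Toy binary "fingerprints" for two datasets (random bits, not real molecules)
db = rng.integers(0, 2, size=(30, 16)).astype(np.float64)
pedia = rng.integers(0, 2, size=(20, 16)).astype(np.float64)

components = fit_svd_components(np.vstack([db, pedia]), k=2)
db_2d = db @ components        # both datasets are projected onto the same two axes
pedia_2d = pedia @ components
print(db_2d.shape, pedia_2d.shape)
```

Since $X = U S V^\top$, projecting the fitted matrix gives $X V_k = U_k S_k$, i.e. the singular-value-scaled left singular vectors.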

## ML Models

Define a global dictionary for the results:

In [None]:
experiments_results = {}

In [None]:
def save_results(result, result_name):
    with open(os.path.join(checkpoint_dir, result_name + '.pkl'), 'wb') as fp:
        pickle.dump(result, fp)
        print(f'Results {result_name} saved successfully to file.')

def load_result(result_name):
    if not os.path.exists(os.path.join(checkpoint_dir, result_name + '.pkl')):
        print(f'WARNING: File {result_name} not found.')
        return {}
    with open(os.path.join(checkpoint_dir, result_name + '.pkl'), 'rb') as fp:
        return pickle.load(fp)

def print_dict(title, d, filter_keys=True):
    print(f'{title}')
    filters = ['prediction', 'labels', 'logits', 'fpr', 'tpr', 'confusion']
    for k, v in d.items():
        if filter_keys:
            if not any([f in k for f in filters]):
                print(f'\t* {k}: {v}')
        else:
            print(f'\t* {k}: {v}')

### Setup/Shared Functions

Evaluation functions:

In [None]:
def get_eval_results(model: nn.Module,
                     task: str = 'predict_active_inactive',
                     num_gpus: int = 0,
                     return_logits: bool = True,
                     return_logits_only: bool = False,
                     phase: tuple[Literal['train', 'val', 'test'], ...] = ('val', 'test'),
                     run_lightning_eval: bool = True,
                     trainer: pl.Trainer | None = None) -> dict:
    """Get evaluation results and predictions from a model.

    Args:
        model (nn.Module): Model to evaluate.
        task (str, optional): Task to perform. Defaults to 'predict_active_inactive'.
        num_gpus (int, optional): Number of available GPUs. Defaults to 0.
        return_logits (bool, optional): Also return predictions and derived metrics, e.g., logits, binary predictions, TPR, FPR, confusion matrix. Defaults to True.
        return_logits_only (bool, optional): If True, only return the raw logits and skip the derived metrics. Defaults to False.
        phase (tuple, optional): Dataset splits to evaluate, any of 'val' and 'test'. Defaults to ('val', 'test').
        run_lightning_eval (bool, optional): Whether to run the PyTorch Lightning validation/test loops. Set it to False to only retrieve predictions. Defaults to True.
        trainer (pl.Trainer | None, optional): PyTorch Lightning Trainer to handle the evaluation. If not supplied, it will be instantiated automatically. Defaults to None.

    Returns:
        dict: Collected evaluation results.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() and num_gpus else 'cpu')
    model.eval()
    if num_gpus > 0:
        model = model.to(device)
    if trainer is None:
        if torch.cuda.is_available() and num_gpus > 0:
            pl_devices = math.ceil(num_gpus)
            accelerator = 'gpu'
            precision = '16-mixed'
        else:
            pl_devices = 4
            accelerator = 'auto'
            precision = '32'
        trainer = pl.Trainer(accelerator=accelerator,
                             devices=pl_devices,
                             precision=precision,
                             enable_checkpointing=False,
                             enable_progress_bar=False,
                             enable_model_summary=False)
    # NOTE: "The length of the list corresponds to the number of test
    # dataloaders used."
    eval_results = {}
    if run_lightning_eval:
        if 'val' in phase:
            eval_results.update(trainer.validate(model, verbose=0)[0])
        if 'test' in phase:
            eval_results.update(trainer.test(model, verbose=0)[0])
    # trainer.test(model, model.test_dataloader())
    if task == 'predict_pDC50_and_Dmax':
        # TODO: Regression task. Example:
        # preds = torch.concat(trainer.predict(model, model.test_dataloader()))
        # degradation_predictions = np.array(preds[:, 0])
        # concentration_predictions = np.array(preds[:, 1])
        # degradation_labels = np.array(model.test_dataset.Dmax)
        # concentration_labels = np.array(model.test_dataset.pDC50)
        pass
    elif task == 'predict_active_inactive':
        if return_logits:
            def get_logits(ds_name):
                ds = model.val_dataset if ds_name == 'val' else model.test_dataset
                dl = DataLoader(ds, batch_size=model.batch_size, shuffle=False,
                                collate_fn=custom_collate)
                preds = torch.concat(trainer.predict(model, dl)).to(device)
                if return_logits_only:
                    ret = {f'{ds_name}_logits': preds.cpu().numpy()}
                    return ret
                print(f'Getting additional information for {ds_name} set')
                y = torch.concat([batch['labels'] for batch in dl])
                y = y.to(device)
                # Obtain binary predictions and ROC curve
                sigmoid = nn.Sigmoid().to(device)
                roc = ROC(task='binary').to(device)
                if torch.cuda.is_available() and num_gpus > 0:
                    y = y.to(torch.half)
                    preds = preds.to(torch.half)
                    sigmoid = sigmoid.to(torch.half)
                    roc = roc.to(torch.half)
                bin_preds = sigmoid(preds) >= 0.5
                fpr, tpr, _ = roc(preds, y.to(torch.long))
                cm = confusion_matrix(y.cpu().numpy().astype(int),
                                      bin_preds.cpu().numpy().astype(int))
                ret = {
                    f'{ds_name}_prediction': bin_preds.cpu().numpy().astype(int),
                    f'{ds_name}_logits': preds.cpu().numpy(),
                    f'{ds_name}_labels': y.cpu().numpy(),
                    f'{ds_name}_fpr': fpr.cpu().numpy(),
                    f'{ds_name}_tpr': tpr.cpu().numpy(),
                    f'{ds_name}_confusion_matrix': cm,
                }
                return ret
            if 'val' in phase:
                eval_results.update(get_logits('val'))
            if 'test' in phase:
                eval_results.update(get_logits('test'))
    return eval_results
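In `get_logits`, binary predictions come from thresholding `sigmoid(logits)` at 0.5. The sigmoid is monotonic and equals 0.5 exactly at 0, so this is equivalent to thresholding the raw logits at 0. Working on the logits directly also avoids rounding probabilities very close to 0.5 when tensors are cast to half precision. A minimal, dependency-free sketch of the equivalence (the helper names are hypothetical):

```python
import math

def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)

def predict_from_probs(logits, threshold=0.5):
    # Threshold the probabilities, as done in get_logits
    return [int(sigmoid(x) >= threshold) for x in logits]

def predict_from_logits(logits):
    # Equivalent rule on raw logits: sigmoid(x) >= 0.5 <=> x >= 0
    return [int(x >= 0.0) for x in logits]

toy_logits = [-3.2, -0.01, 0.0, 0.01, 4.7]
print(predict_from_probs(toy_logits))   # [0, 0, 1, 1, 1]
print(predict_from_logits(toy_logits))  # [0, 0, 1, 1, 1]
```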

Functions to plot the model predictions against the reference values:

In [None]:
def plot_eval_degradation(degr_pred, degr_labels, score, descr='', filename=None):
    sorted_idx = np.argsort(degr_labels)
    line_descr = f'{descr}Predicted degr.(%) - Loss: {score:.5f}'
    plt.plot(degr_pred[sorted_idx], label=line_descr)
    plt.plot(degr_labels[sorted_idx], label='Reference degradation (%)')
    plt.legend(loc='upper left', bbox_to_anchor=(1, 1), ncol=1) #, fancybox=True, shadow=True)
    plt.grid(alpha=0.8)
    plt.xlabel('Test ID (sorted by degradation)')
    plt.ylabel('Degradation (%)')
    if filename is not None:
        plt.savefig(filename)
        plt.close()
    else:
        plt.show()

def plot_eval_concentration(conc_pred, conc_labels, score, descr='', filename=None):
    sorted_idx = np.argsort(conc_labels)
    line_descr = f'{descr}Predicted conc. (-log10(M)) - Loss: {score:.5f}'
    plt.plot(conc_pred[sorted_idx], label=line_descr)
    plt.plot(conc_labels[sorted_idx], label='Reference conc. (-log10(M))')
    plt.legend(loc='upper left', bbox_to_anchor=(1, 1), ncol=1) #, fancybox=True, shadow=True)
    plt.grid(alpha=0.8)
    plt.xlabel('Test ID (sorted by concentration)')
    plt.ylabel('Concentration (-log10(M))')
    if filename is not None:
        plt.savefig(filename)
        plt.close()
    else:
        plt.show()

def plot_eval_classification(pred, labels, score, descr='', filename=None):
    sorted_idx = np.argsort(labels)
    line_descr = f'{descr}Predicted - Accuracy: {score * 100:.2f}%'
    plt.plot(pred[sorted_idx], label=line_descr)
    plt.plot(labels[sorted_idx], label='Reference')
    # dummy_acc = labels[labels > 0].sum() / len(labels)
    # line_descr = f'Dummy Accuracy: {dummy_acc * 100:.2f}%'
    # plt.plot([dummy_acc] * len(labels), '--', label=line_descr, color='black')
    plt.legend(loc='upper left', bbox_to_anchor=(1, 1), ncol=1) #, fancybox=True, shadow=True)
    plt.grid(alpha=0.8)
    plt.xlabel('Test ID (sorted by activity)')
    plt.ylabel('Active/Inactive')
    if filename is not None:
        plt.savefig(filename)
        plt.close()
    else:
        plt.show()
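All three plotting functions reorder the samples with `np.argsort(labels)`, so that the reference curve is monotonic and deviations of the predictions are easy to spot. The same reordering is sketched below in plain Python (the helper `sort_by_reference` is hypothetical):

```python
def sort_by_reference(preds, labels):
    """Reorder predictions and labels by increasing reference value.

    Mirrors the np.argsort-based ordering used by the plot_eval_* functions.
    Python's sort is stable; np.argsort's default kind ('quicksort') is not
    guaranteed to be, so samples with equal labels may be ordered differently.
    """
    order = sorted(range(len(labels)), key=lambda i: labels[i])
    return [preds[i] for i in order], [labels[i] for i in order]

toy_preds = [0.9, 0.2, 0.7, 0.1]
toy_labels = [1, 0, 1, 0]
sorted_preds, sorted_labels = sort_by_reference(toy_preds, toy_labels)
print(sorted_labels)  # [0, 0, 1, 1]
print(sorted_preds)   # [0.2, 0.1, 0.9, 0.7]
```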

Remove the checkpoints (and, optionally, the logs) of non-optimal models:

In [None]:
from os import listdir
from os.path import isfile, isdir, join

def del_non_optimal_ckpt(checkpoint_root_dir: str, best_trial_names: List,
                         model_name: str, logs_root_dir: str | None = None,
                         verbose: int = 0):
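    """Delete all checkpoints (and, optionally, logs) that do not belong to the best trials.

    Args:
        checkpoint_root_dir (str): Checkpoint root directory.
        best_trial_names (List): Names of the best trials, whose checkpoints are kept.
        model_name (str): Only entries whose name contains this string are candidates for removal.
        logs_root_dir (str | None, optional): Logs root directory. If given, the `logs_<trial>` entries of non-best trials are removed as well. Defaults to None.
        verbose (int, optional): If non-zero, print every removed path. Defaults to 0.
    """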
    all_ckpt = [f for f in listdir(checkpoint_root_dir) if isfile(join(checkpoint_root_dir, f)) or isdir(join(checkpoint_root_dir, f))]
    for ckpt in all_ckpt:
        if not any([trial in ckpt for trial in best_trial_names]):
            if model_name in ckpt:
                filepath = os.path.join(checkpoint_root_dir, ckpt)
                if verbose:
                    print(f'Removing {filepath}')
                if os.path.isdir(filepath):
                    shutil.rmtree(filepath)
                elif os.path.isfile(filepath):
                    os.remove(filepath)
                else:
                    print(f'WARNING. "{filepath}" is a special file (socket, FIFO, device file)')
    if logs_root_dir is not None:
        all_logs = [f for f in listdir(logs_root_dir) if isfile(join(logs_root_dir, f)) or isdir(join(logs_root_dir, f))]
        for logf in all_logs:
            if not any([f'logs_{trial}' in logf for trial in best_trial_names]):
                if model_name in logf:
                    filepath = os.path.join(logs_root_dir, logf)
                    if verbose:
                        print(f'Removing {filepath}')
                    if os.path.isdir(filepath):
                        shutil.rmtree(filepath)
                    elif os.path.isfile(filepath):
                        os.remove(filepath)
                    else:  
                        print(f'WARNING. "{filepath}" is a special file (socket, FIFO, device file)')
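The selection rule in `del_non_optimal_ckpt` is: an entry is removed if its name contains `model_name` and none of the best trial names. The self-contained sketch below reproduces that rule on a temporary directory. The file names are made up for illustration, and `prune_entries` is a hypothetical, simplified helper (it does not handle the logs directory).

```python
import os
import shutil
import tempfile

def prune_entries(root: str, keep: list[str], model_name: str) -> list[str]:
    """Remove entries of `root` that contain `model_name` but none of `keep`.

    Returns the sorted names of the removed entries.
    """
    removed = []
    for name in sorted(os.listdir(root)):
        if model_name not in name or any(k in name for k in keep):
            continue  # Not a candidate, or belongs to a best trial
        path = os.path.join(root, name)
        if os.path.isdir(path):
            shutil.rmtree(path)
        else:
            os.remove(path)
        removed.append(name)
    return removed

with tempfile.TemporaryDirectory() as root:
    # Hypothetical names following the '<model_name>-<trial_name>.bin' pattern
    for name in ['xgb_model-trial_a.bin', 'xgb_model-trial_b.bin', 'other_model-trial_c.bin']:
        open(os.path.join(root, name), 'w').close()
    os.mkdir(os.path.join(root, 'xgb_model-trial_d'))
    removed = prune_entries(root, keep=['trial_a'], model_name='xgb_model')
    remaining = sorted(os.listdir(root))

print(removed)    # ['xgb_model-trial_b.bin', 'xgb_model-trial_d']
print(remaining)  # ['other_model-trial_c.bin', 'xgb_model-trial_a.bin']
```

Since matching is by substring, a short trial name could accidentally protect unrelated entries; in this notebook the trial names include a UUID, which makes such collisions unlikely.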
        

Generic Ray Tune function to tune hyperparameters, given a trainable function:

In [None]:
def tune_model(train_fn_with_parameters: Callable,
               config: dict,
               num_epochs: int = 10,
               num_samples: int = 10,
               task: Literal['predict_active_inactive', 'predict_pDC50_and_Dmax'] = 'predict_pDC50_and_Dmax',
               ray_local_dir: str = os.path.join(checkpoint_dir, 'ray'),
               ray_run_name: str = 'tune_model',
               params2report: List[str] | None = None,
               gpus_per_trial: int = 0):
    """Tune hyperparameters of a model using Ray Tune.

    Args:
        train_fn_with_parameters (Callable): Trainable function that takes a dictionary of hyperparameters as input and returns a dictionary of metrics.
        config (dict): Dictionary of hyperparameters to tune.
        num_epochs (int, optional): Number of epochs to train. Defaults to 10.
        num_samples (int, optional): Number of samples to explore/run in Ray Tune. Defaults to 10.
        task (Literal['predict_active_inactive', 'predict_pDC50_and_Dmax'], optional): Task to train the model for. Defaults to 'predict_pDC50_and_Dmax'.
        ray_local_dir (str, optional): Ray checkpoint directory. Defaults to os.path.join(checkpoint_dir, 'ray').
        ray_run_name (str, optional): Run-specific name. Defaults to 'tune_model'.
        params2report (List[str] | None, optional): Parameters to report in logging. Defaults to None.
        gpus_per_trial (int, optional): GPUs per single trial. Defaults to 0.

    Returns:
        ResultGrid: Ray Tune result grid wrapper.
    """
    optim_metric = 'opt_score' if task == 'predict_active_inactive' else 'val_loss'
    # optim_metric = 'acc' if task == 'predict_active_inactive' else 'loss'
    optim_mode = 'max' if task == 'predict_active_inactive' else 'min'
    # Setup reporting and logging
    if task == 'predict_active_inactive':
        metric_columns = [
            'val_loss',
            'val_acc',
            'roc_auc',
            'precision',
            'recall',
            'f1_score',
            'training_iteration'
        ]
    else:
        metric_columns = ['val_loss', 'training_iteration']
    params2report = config.keys() if params2report is None else params2report
    reporter = CLIReporter(
        parameter_columns=params2report,
        metric_columns=metric_columns,
        # metric=optim_metric,
        # mode=optim_mode,
        # sort_by_metric=True,
        )
    # Setup scheduler
    # scheduler = HyperBandScheduler(time_attr='training_iteration', max_t=200)
    scheduler = ASHAScheduler(
        max_t=num_epochs,
        grace_period=1,
        reduction_factor=2)
    # Instantiate Ray Tune Tuner
    tuner = tune.Tuner(
        tune.with_resources(
            train_fn_with_parameters,
            resources={'cpu': 1, 'gpu': gpus_per_trial}
        ),
        tune_config=tune.TuneConfig(
            metric=optim_metric,
            mode=optim_mode,
            scheduler=scheduler,
            search_alg=ConcurrencyLimiter(OptunaSearch(), max_concurrent=8),
            num_samples=num_samples
        ),
        run_config=air.RunConfig(
            local_dir=ray_local_dir,
            name=ray_run_name,
            progress_reporter=reporter,
            verbose=0, # Default: (Comment this line)
            checkpoint_config=air.CheckpointConfig(
                num_to_keep=1,
                checkpoint_score_attribute=optim_metric,
                checkpoint_score_order=optim_mode)
        ),
        param_space=config,
    )
    results = tuner.fit()
    print(f'Best hyperparameters found were:')
    for k, v in results.get_best_result().config.items():
        print(f'\t* {k}: {v}')
    print(f'\nBest metrics achieved:')
    for k, v in results.get_best_result().metrics.items():
        print(f'\t* {k}: {v}')
    
    return results
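`tune_model` uses `ASHAScheduler(max_t=num_epochs, grace_period=1, reduction_factor=2)`. ASHA evaluates trials at geometrically spaced milestones ("rungs") and, at each rung, lets a trial continue only if it is roughly within the top `1 / reduction_factor` of the results recorded at that rung so far. The sketch below computes the milestones as `grace_period * reduction_factor**k` up to `max_t` and applies a simplified top-fraction rule. It is an illustration under these assumptions, not Ray's exact implementation (which, for instance, uses a percentile cutoff).

```python
def asha_rungs(max_t: int, grace_period: int = 1, reduction_factor: int = 2) -> list[int]:
    """Milestones grace_period * reduction_factor**k that do not exceed max_t."""
    rungs = []
    t = grace_period
    while t <= max_t:
        rungs.append(t)
        t *= reduction_factor
    return rungs

def continues(score: float, rung_scores: list[float], reduction_factor: int = 2,
              mode: str = 'max') -> bool:
    """Simplified promotion rule: keep the trial if it ranks within the top
    1 / reduction_factor of the scores recorded at this rung (itself included)."""
    all_scores = sorted(rung_scores + [score], reverse=(mode == 'max'))
    n_keep = max(1, len(all_scores) // reduction_factor)
    return all_scores.index(score) < n_keep

print(asha_rungs(10))                          # [1, 2, 4, 8]
print(continues(0.8, [0.6, 0.7, 0.9]))         # True  (2nd of 4, top 2 kept)
print(continues(0.65, [0.6, 0.7, 0.9]))        # False (3rd of 4)
print(continues(0.3, [0.5, 0.4], mode='min'))  # True  (lowest loss)
```

With the default `num_epochs=10`, trials should therefore be compared after epochs 1, 2, 4 and 8.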


In [None]:
# import logging

# !export RAY_RUNTIME_ENV_WORKING_DIR_CACHE_SIZE_GB=0

# ray.shutdown()
# ray.init(runtime_env={'working_dir': '.', 'excludes': ['data/protac', 'data/protac/*.pt']}, logging_level=logging.ERROR)

### Custom PyTorch Lightning Callbacks

Optuna pruning callback:

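(This callback had to be adapted to the current version of PyTorch Lightning. Only one line was changed, namely the DDP broadcast of `should_stop` in `on_validation_end`, which is now commented out. This should preserve the original behavior for single-process training.)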

In [None]:
import warnings

from packaging import version

import optuna
from optuna.storages._cached_storage import _CachedStorage
from optuna.storages._rdb.storage import RDBStorage

import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback

# Define key names of `Trial.system_attrs`.
_PRUNED_KEY = "ddp_pl:pruned"
_EPOCH_KEY = "ddp_pl:epoch"

class CustomPyTorchLightningPruningCallback(Callback):
    """PyTorch Lightning callback to prune unpromising trials.

    See `the example <https://github.com/optuna/optuna-examples/blob/
    main/pytorch/pytorch_lightning_simple.py>`__
    if you want to add a pruning callback which observes accuracy.

    Args:
        trial:
            A :class:`~optuna.trial.Trial` corresponding to the current evaluation of the
            objective function.
        monitor:
            An evaluation metric for pruning, e.g., ``val_loss`` or
            ``val_acc``. The metrics are obtained from the returned dictionaries from e.g.
            ``pytorch_lightning.LightningModule.training_step`` or
            ``pytorch_lightning.LightningModule.validation_epoch_end`` and the names thus depend on
            how this dictionary is formatted.

    .. note::
        For the distributed data parallel training, the version of PyTorchLightning needs to be
        higher than or equal to v1.5.0. In addition, :class:`~optuna.study.Study` should be
        instantiated with RDB storage.
    """

    def __init__(self, trial: optuna.trial.Trial, monitor: str) -> None:
        super().__init__()

        self._trial = trial
        self.monitor = monitor
        self.is_ddp_backend = False

    def on_init_start(self, trainer: Trainer) -> None:
        self.is_ddp_backend = (
            trainer._accelerator_connector.distributed_backend is not None  # type: ignore
        )
        if self.is_ddp_backend:
            if version.parse(pl.__version__) < version.parse("1.5.0"):  # type: ignore
                raise ValueError("PyTorch Lightning>=1.5.0 is required in DDP.")
            if not (
                isinstance(self._trial.study._storage, _CachedStorage)
                and isinstance(self._trial.study._storage._backend, RDBStorage)
            ):
                raise ValueError(
                    "optuna.integration.PyTorchLightningPruningCallback"
                    " supports only optuna.storages.RDBStorage in DDP."
                )

    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:

        # When the trainer calls `on_validation_end` for sanity check,
        # do not call `trial.report` to avoid calling `trial.report` multiple times
        # at epoch 0. The related page is
        # https://github.com/PyTorchLightning/pytorch-lightning/issues/1391.
        if trainer.sanity_checking:
            return

        epoch = pl_module.current_epoch

        current_score = trainer.callback_metrics.get(self.monitor)
        if current_score is None:
            message = (
                "The metric '{}' is not in the evaluation logs for pruning. "
                "Please make sure you set the correct metric name.".format(self.monitor)
            )
            warnings.warn(message)
            return

        should_stop = False
        if trainer.is_global_zero:
            self._trial.report(current_score.item(), step=epoch)
            should_stop = self._trial.should_prune()
        # NOTE: The following line breaks with the current version of PyTorch
        # Lightning, which no longer provides `trainer.training_type_plugin`.
        # The broadcast only synchronizes `should_stop` across processes in
        # distributed (DDP) training, so it should not matter for single-process
        # runs.
        # should_stop = trainer.training_type_plugin.broadcast(should_stop)
        trainer.should_stop = should_stop
        if not should_stop:
            return

        if not self.is_ddp_backend:
            message = "Trial was pruned at epoch {}.".format(epoch)
            raise optuna.TrialPruned(message)
        else:
            # Stop every DDP process if global rank 0 process decides to stop.
            trainer.should_stop = True
            if trainer.is_global_zero:
                self._trial.storage.set_trial_system_attr(self._trial._trial_id, _PRUNED_KEY, True)
                self._trial.storage.set_trial_system_attr(self._trial._trial_id, _EPOCH_KEY, epoch)

    def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        if not self.is_ddp_backend:
            return

        # Because on_validation_end is executed in spawned processes,
        # _trial.report is necessary to update the memory in main process, not to update the RDB.
        _trial_id = self._trial._trial_id
        _study = self._trial.study
        _trial = _study._storage._backend.get_trial(_trial_id)  # type: ignore
        _trial_system_attrs = _study._storage.get_trial_system_attrs(_trial_id)
        is_pruned = _trial_system_attrs.get(_PRUNED_KEY)
        epoch = _trial_system_attrs.get(_EPOCH_KEY)
        intermediate_values = _trial.intermediate_values
        for step, value in intermediate_values.items():
            self._trial.report(value, step=step)

        if is_pruned:
            message = "Trial was pruned at epoch {}.".format(epoch)
            raise optuna.TrialPruned(message)
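At every validation epoch the callback calls `trial.report(score, step=epoch)` and then `trial.should_prune()`, which delegates the decision to the study's pruner. As an illustration, the sketch below implements a simplified median rule in the spirit of Optuna's `MedianPruner`: a trial is pruned at a step if the best value it has reported so far is worse than the median of the values that completed trials reported at the same step. Warm-up steps, startup trials and Optuna's other pruners are deliberately left out; `should_prune_median` is hypothetical.

```python
import statistics

def should_prune_median(history: list[float], others_at_step: list[float],
                        direction: str = 'maximize') -> bool:
    """Simplified median pruning rule.

    history: intermediate values reported by the current trial so far.
    others_at_step: values reported by completed trials at the current step.
    """
    if not others_at_step:
        return False  # Nothing to compare against yet
    median = statistics.median(others_at_step)
    if direction == 'maximize':
        return max(history) < median
    return min(history) > median

# Validation accuracy reported by the current trial over the first epochs
print(should_prune_median([0.55, 0.58, 0.60], [0.70, 0.72, 0.65]))  # True
print(should_prune_median([0.60, 0.74], [0.70, 0.72, 0.65]))        # False
print(should_prune_median([0.40, 0.35], [0.30, 0.50], 'minimize'))  # False
```

The XGBoost experiments below use `HyperbandPruner` instead, whose decision rule differs.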

Thresholded early stopping callback:

In [None]:
from pytorch_lightning.callbacks import Callback

class ThresholdEarlyStopping(Callback):
    """PyTorch Lightning callback to stop training when a monitored metric drops
    below a threshold after having reached it at least once.
    """
    
    def __init__(self, monitor: str = 'val_acc', threshold: float = 0.9):
        super().__init__()
        self.monitor = monitor
        self.threshold = threshold
        self.reached_threshold = False

    def on_validation_end(self, trainer, pl_module):
        metrics = trainer.callback_metrics
        score = metrics.get(self.monitor)

        if score is not None:
            if not self.reached_threshold and score >= self.threshold:
                self.reached_threshold = True
                # print(f"Metric {self.monitor} reached the threshold ({self.threshold}).")
            elif self.reached_threshold and score < self.threshold:
                trainer.should_stop = True
                # print(f"Metric {self.monitor} dropped below the threshold ({self.threshold}). Training will be stopped.")


### XGBoost

Generate the fingerprint-specific datasets, one per combination of fingerprint size, maximum path length and Morgan radius:

In [None]:
tasks = [
    'predict_active_inactive',
    # 'predict_pDC50_and_Dmax',
    # 'predict_pDC50',
    ]
fp_bits_options = [1024, 2048, 4096]
upsampled = [False] # [True, False]
fp_max_paths = list(range(8, 11))
radii = list(range(2, 11))
use_extra_features = [True, False]
experiments = (tasks, fp_bits_options, fp_max_paths, radii, upsampled, use_extra_features)

for subset in tqdm(list(itertools.product(*experiments))):
    task, fp_bits, fp_max_path, radius, use_upsampled, use_extra_features = subset
    protac_ds_kwargs = {
        'precompute_fingerprints': True,
        'use_morgan_fp': True,
        'use_maccs_fp': True,
        'use_path_fp': True,
        'morgan_atomic_radius': radius,
        'morgan_bits': fp_bits,
        'path_bits': fp_bits,
        'fp_min_path': 1,
        'fp_max_path': fp_max_path,
        'poi_gene_enc': poi_gene_enc,
        'poi_vectorizer': poi_encoder, # poi_vectorizer,
        'e3_ligase_enc': e3_encoder, # e3_ligase_enc,
        'cell_type_enc': cell_encoder, # cell_type_enc,
    }
    dataset_name = f'_fp{fp_bits}_radius{radius}_path1-{fp_max_path}'
    get_datasets(task,
                 use_upsampled,
                 dataset_name=dataset_name,
                 regenerate_datasets=False,
                 **protac_ds_kwargs)
    # print(f'Finished generating "protac-db{dataset_name}" datasets.')
print('Datasets are ready to use.')
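The loop above iterates over the Cartesian product of the option lists. Since `use_extra_features` does not enter `dataset_name`, every dataset is requested twice; the second request is presumably served from disk, given `regenerate_datasets=False`. The sketch below reproduces only the naming scheme (the single-valued `tasks` and `upsampled` lists are omitted) to count loop iterations versus distinct datasets; `dataset_names` is hypothetical.

```python
import itertools

def dataset_names(fp_bits_options, fp_max_paths, radii, use_extra_features):
    """Dataset names requested by the generation loop, one per iteration."""
    names = []
    for fp_bits, fp_max_path, radius, _ in itertools.product(
            fp_bits_options, fp_max_paths, radii, use_extra_features):
        # Same naming scheme as above: use_extra_features is not part of it
        names.append(f'_fp{fp_bits}_radius{radius}_path1-{fp_max_path}')
    return names

names = dataset_names([1024, 2048, 4096], range(8, 11), range(2, 11), [True, False])
print(len(names))       # 162 loop iterations
print(len(set(names)))  # 81 distinct datasets
```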

Define Optuna objective:

In [None]:
from optuna.integration import XGBoostPruningCallback

def xgb_objective(trial,
                  fp_bits: int,
                  num_round: int = 10,
                  use_extra_features: bool | None = None,
                  task: Literal['predict_active_inactive', 'predict_pDC50_and_Dmax'] = 'predict_active_inactive'):
    # Inspired by: https://github.com/optuna/optuna-examples/blob/main/xgboost/xgboost_integration.py
    config = {
        # Fingerprint-specific
        'fp_type': trial.suggest_categorical('fp_type', ['morgan_fp', 'maccs_fp', 'path_fp']),
        'fp_radius': trial.suggest_int('fp_radius', 2, 10),
        'fp_max_path': trial.suggest_int('fp_max_path', 8, 10),
        # XGBoost-specific
        'booster': trial.suggest_categorical('booster', ['gbtree', 'gblinear', 'dart']),
        'lambda': trial.suggest_float('lambda', 1e-8, 1.0, log=True),
        'alpha': trial.suggest_float('alpha', 1e-8, 1.0, log=True),
    }
    if config['booster'] == 'gbtree' or config['booster'] == 'dart':
        config['max_depth'] = trial.suggest_int('max_depth', 1, 16)
        config['eta'] = trial.suggest_float('eta', 1e-8, 1.0, log=True)
        config['gamma'] = trial.suggest_float('gamma', 1e-8, 1.0, log=True)
        config['grow_policy'] = trial.suggest_categorical('grow_policy', ['depthwise', 'lossguide'])
    if config['booster'] == 'dart':
        config['sample_type'] = trial.suggest_categorical('sample_type', ['uniform', 'weighted'])
        config['normalize_type'] = trial.suggest_categorical('normalize_type', ['tree', 'forest'])
        config['rate_drop'] = trial.suggest_float('rate_drop', 1e-8, 1.0, log=True)
        config['skip_drop'] = trial.suggest_float('skip_drop', 1e-8, 1.0, log=True)
    if use_extra_features is None:
        config['fp_use_extra_features'] = trial.suggest_categorical('fp_use_extra_features', [True, False])
    else:
        config['fp_use_extra_features'] = use_extra_features
    # Reporting
    model_name = 'xgb_model'
    fp_bits = MACCS_BITWIDTH if config['fp_type'] == 'maccs_fp' else fp_bits
    eventid = f'{trial.datetime_start.strftime("%Y%m-%d%H-%M%S-")}{uuid4()}'
    trial_name = f'{config["fp_type"]}-{fp_bits}-{trial.number}-{eventid}'    
    model_checkpoint_dir = os.path.join(checkpoint_dir, 'xgboost')
    model_checkpoint = f'{model_name}-{trial_name}.bin'
    trial.set_user_attr('trial_name', trial_name)
    trial.set_user_attr('model_checkpoint', model_checkpoint)
    trial.set_user_attr('model_checkpoint_dir', model_checkpoint_dir)
    trial.set_user_attr('model_name', model_name)
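    # Retrieve specific datasets. MACCS keys have a fixed length and are
    # precomputed in every dataset, so the 1024-bit one is used for them.
    ds_fp_bits = 1024 if config['fp_type'] == 'maccs_fp' else fp_bits
    dataset_name = f'_fp{ds_fp_bits}_radius{config["fp_radius"]}_path1-{config["fp_max_path"]}'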
    trial.set_user_attr('dataset_name', dataset_name)
    ds = get_datasets(task, use_upsampled, dataset_name=dataset_name)
    train_ds = ds['train']
    val_ds = ds['val']
    test_ds = ds['test']
    # Setup XGBoost-specific datasets
    train_fp_data = train_ds.get_fingerprint(config['fp_type'])
    val_fp_data = val_ds.get_fingerprint(config['fp_type'])
    test_fp_data = test_ds.get_fingerprint(config['fp_type'])
    if config['fp_use_extra_features']:
        # Get POI sequence
        poi_seq_train = train_ds.poi_vectorizer.transform(train_ds.poi_seq)
        poi_seq_val = val_ds.poi_vectorizer.transform(val_ds.poi_seq)
        poi_seq_test = test_ds.poi_vectorizer.transform(test_ds.poi_seq)
        poi_seq_train = poi_seq_train.toarray().astype(np.int32)
        poi_seq_val = poi_seq_val.toarray().astype(np.int32)
        poi_seq_test = poi_seq_test.toarray().astype(np.int32)
        # Concatenate extra features
        train_fp_data = np.concatenate([
                                        train_fp_data,
                                        np.array(train_ds.e3_ligase)[:, np.newaxis],
                                        np.array(train_ds.cell_type)[:, np.newaxis],
                                        poi_seq_train,
                                        ], axis=-1)
        test_fp_data = np.concatenate([
                                        test_fp_data,
                                        np.array(test_ds.e3_ligase)[:, np.newaxis],
                                        np.array(test_ds.cell_type)[:, np.newaxis],
                                        poi_seq_test,
                                        ], axis=-1)
        val_fp_data = np.concatenate([
                                        val_fp_data,
                                        np.array(val_ds.e3_ligase)[:, np.newaxis],
                                        np.array(val_ds.cell_type)[:, np.newaxis],
                                        poi_seq_val,
                                        ], axis=-1)
    dtrain = xgb.DMatrix(train_fp_data, label=train_ds.active)
    dtest = xgb.DMatrix(test_fp_data, label=test_ds.active)
    dval = xgb.DMatrix(val_fp_data, label=val_ds.active)
    # Setup XGBoost-specific parameters
    xgb_params = {
        'objective': 'binary:logistic',
        'eval_metric': ['logloss', 'auc'],
        'seed': 42,
        'verbosity': 0,  # Replaces the deprecated 'silent' parameter
    }
    xgb_params.update({k: config[k] for k in config.keys() if 'fp_' not in k})
    evallist = [(dtrain, 'train'), (dval, 'eval')]
    # Setup Optuna callback
    pruning_callback = XGBoostPruningCallback(trial, 'eval-auc')
    # Train XGBoost model via its training function
    bst = xgb.train(params=xgb_params,
                    dtrain=dtrain,
                    num_boost_round=num_round,
                    evals=evallist,
                    early_stopping_rounds=5,
                    callbacks=[pruning_callback],
                    verbose_eval=False)
    bst.save_model(os.path.join(model_checkpoint_dir, model_checkpoint))
    # Get and report score metrics
    def get_scores(phase):
        ds = val_ds if phase == 'val' else test_ds
        # Integer targets, as expected by the torchmetrics binary metrics
        y = torch.tensor(ds.active.to_numpy()).flatten().to(torch.long)
        dxgb = dval if phase == 'val' else dtest
        # Predicted probabilities of the 'binary:logistic' objective
        preds = torch.tensor(bst.predict(dxgb)).flatten()
        # Store plain Python floats, which any Optuna storage can serialize
        acc = binary_accuracy(preds, y, threshold=0.5).item()
        roc_auc = binary_auroc(preds, y).item()
        precision = binary_precision(preds, y).item()
        recall = binary_recall(preds, y).item()
        f1_score = binary_f1_score(preds, y).item()
        trial.set_user_attr(f'{phase}_acc', acc)
        trial.set_user_attr(f'{phase}_roc_auc', roc_auc)
        trial.set_user_attr(f'{phase}_precision', precision)
        trial.set_user_attr(f'{phase}_recall', recall)
        trial.set_user_attr(f'{phase}_f1_score', f1_score)
        # Optimization target in [0, 2]: accuracy plus ROC AUC
        opt_score = acc + roc_auc
        return opt_score
    get_scores('test')
    return get_scores('val')
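`xgb_objective` samples a conditional search space: tree-specific parameters are suggested only for the `gbtree` and `dart` boosters, DART-specific ones only for `dart`, and all `fp_*` entries are filtered out before the configuration reaches XGBoost. The sketch below mimics a subset of that logic with a stand-in for the Optuna trial, so it runs without Optuna or XGBoost. `FixedChoices` and `build_xgb_params` are hypothetical, and the values are arbitrary.

```python
class FixedChoices:
    """Stand-in for an Optuna trial that returns pre-set values from suggest_*."""
    def __init__(self, values: dict):
        self.values = values

    def suggest_categorical(self, name, choices):
        assert self.values[name] in choices
        return self.values[name]

    def suggest_int(self, name, low, high):
        return self.values[name]

    def suggest_float(self, name, low, high, log=False):
        return self.values[name]

def build_xgb_params(trial) -> dict:
    # Same structure as the config in xgb_objective (subset of the parameters)
    config = {
        'fp_type': trial.suggest_categorical('fp_type', ['morgan_fp', 'maccs_fp', 'path_fp']),
        'booster': trial.suggest_categorical('booster', ['gbtree', 'gblinear', 'dart']),
        'lambda': trial.suggest_float('lambda', 1e-8, 1.0, log=True),
    }
    if config['booster'] in ('gbtree', 'dart'):
        config['max_depth'] = trial.suggest_int('max_depth', 1, 16)
    if config['booster'] == 'dart':
        config['rate_drop'] = trial.suggest_float('rate_drop', 1e-8, 1.0, log=True)
    params = {'objective': 'binary:logistic', 'eval_metric': ['logloss', 'auc']}
    # Fingerprint settings select the dataset; they are not XGBoost parameters
    params.update({k: v for k, v in config.items() if 'fp_' not in k})
    return params

choices = {'fp_type': 'morgan_fp', 'booster': 'gblinear', 'lambda': 0.1,
           'max_depth': 6, 'rate_drop': 0.2}
linear_params = build_xgb_params(FixedChoices(choices))
print(sorted(linear_params))  # ['booster', 'eval_metric', 'lambda', 'objective']
dart_params = build_xgb_params(FixedChoices({**choices, 'booster': 'dart'}))
print(sorted(dart_params))    # ['booster', 'eval_metric', 'lambda', 'max_depth', 'objective', 'rate_drop']
```

To test the real objective, Optuna's `optuna.trial.FixedTrial` serves the same purpose as this stand-in.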

Run experiments:

In [None]:
os.makedirs(os.path.join(checkpoint_dir, 'xgboost'), exist_ok=True)

# Define experiments design points
tasks = [
    'predict_active_inactive',
    # 'predict_pDC50_and_Dmax',
    # 'predict_pDC50',
    ]
fp_bits_options = [1024, 2048, 4096]
upsampled = [False] # [True, False]
use_extra_features = [True, False]
experiments = (tasks, fp_bits_options, upsampled, use_extra_features)
# Get all experiments combinations
n_experiments = 0
for subset in itertools.product(*experiments):
    task, fp_bits, use_upsampled, use_extra_features = subset
    if use_upsampled and task != 'predict_active_inactive':
        continue
    n_experiments += 1
# Set fixed parameters
num_round = 20
num_samples = 50
n_gpus = 1 if torch.cuda.is_available() else 0
# Define specific results dictionary in the global one
experiments_results['results_xgb'] = load_result('results_xgb')
if RETRAIN_XGBOOST or not experiments_results['results_xgb']:
    # Run experiments
    pl.utilities.memory.garbage_collection_cuda()
    i = 0
    best_ckpt = []
    for experiment_id in itertools.product(*experiments):
        task, fp_bits, use_upsampled, use_extra_features = experiment_id
        if use_upsampled and task != 'predict_active_inactive':
            continue
        print('-' * 80)
        print(f'Experiment n.{i + 1} ({i / n_experiments * 100.0:.2f}% complete):')
        print(f'\ttask: {task}')
        print(f'\tfp_bits: {fp_bits}')
        print(f'\tuse_upsampled: {use_upsampled}')
        print(f'\tuse_extra_features: {use_extra_features}')
        print('-' * 80)
        # Run Optuna study
        direction = 'maximize' if task == 'predict_active_inactive' else 'minimize'
        # optuna_pruner = optuna.pruners.MedianPruner(n_warmup_steps=10)
        optuna_pruner = optuna.pruners.HyperbandPruner(min_resource=5,
                                                       max_resource=num_round,
                                                       reduction_factor=3)
        optuna_sampler = optuna.samplers.TPESampler(seed=42)
        study = optuna.create_study(direction=direction,
                                    pruner=optuna_pruner,
                                    sampler=optuna_sampler)
        study.optimize(lambda trial: xgb_objective(trial,
                                                   task=task,
                                                   fp_bits=fp_bits,
                                                   use_extra_features=use_extra_features,
                                                   num_round=num_round),
                       n_trials=num_samples,
                       timeout=600)
        trial = study.best_trial
        experiments_results['results_xgb'][experiment_id] = {}
        experiments_results['results_xgb'][experiment_id]['trial'] = trial
        experiments_results['results_xgb'][experiment_id]['fp_bits'] = fp_bits
        experiments_results['results_xgb'][experiment_id]['task'] = task
        experiments_results['results_xgb'][experiment_id]['use_upsampled'] = use_upsampled
        experiments_results['results_xgb'][experiment_id]['use_extra_features'] = use_extra_features
        # Reporting
        print('-' * 80)
        print(f'Experiment n.{i + 1} done ({(i + 1) / n_experiments * 100.0:.2f}% complete)')
        print('Number of finished trials: {}'.format(len(study.trials)))
        print(f'Best trial score: {trial.value}:')
        print_dict('Experiment:', experiments_results['results_xgb'][experiment_id])
        print_dict('Params:', trial.params)
        print_dict('Attributes:', trial.user_attrs)
        i += 1
        # Remove non-optimal checkpoints
        model_name = trial.user_attrs['model_name']
        checkpoint_root_dir = trial.user_attrs['model_checkpoint_dir']
        best_ckpt.append(trial.user_attrs['trial_name'])
        del_non_optimal_ckpt(checkpoint_root_dir, best_ckpt, model_name)
    # Save results
    save_results(experiments_results['results_xgb'], result_name='results_xgb')
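Each experiment runs one Optuna study with `HyperbandPruner(min_resource=5, max_resource=num_round, reduction_factor=3)`, where the resource is the number of boosting rounds reported by `XGBoostPruningCallback`. Assuming the bracket count documented for Optuna's `HyperbandPruner`, i.e. floor(log_{reduction_factor}(max_resource / min_resource)) + 1, the sketch below computes it with integer arithmetic, together with the number of experiments the loop runs. With `num_round = 20` there are only 2 brackets, so Hyperband likely adds comparatively little over a single successive-halving schedule. Both helpers are hypothetical.

```python
import itertools

def hyperband_n_brackets(min_resource: int, max_resource: int,
                         reduction_factor: int) -> int:
    """floor(log_rf(max_resource / min_resource)) + 1, computed with integers."""
    n_brackets = 0
    resource = min_resource
    # Count the k >= 0 such that min_resource * reduction_factor**k <= max_resource
    while resource <= max_resource:
        n_brackets += 1
        resource *= reduction_factor
    return n_brackets

def count_experiments(tasks, fp_bits_options, upsampled, use_extra_features) -> int:
    # Same filter as the experiment loop above
    return sum(
        1 for task, _, use_upsampled, _ in itertools.product(
            tasks, fp_bits_options, upsampled, use_extra_features)
        if not (use_upsampled and task != 'predict_active_inactive'))

n_exp = count_experiments(['predict_active_inactive'], [1024, 2048, 4096], [False], [True, False])
print(hyperband_n_brackets(5, 20, 3))  # 2
print(n_exp)                           # 6
print(n_exp * 50)                      # 300: upper bound on trials (num_samples = 50)
```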

In [None]:
for experiment_id, design_points in experiments_results['results_xgb'].items():
    task = design_points['task']
    trial = design_points['trial']
    print(f'Experiment: {experiment_id}')
    print_dict('Hyperparams:', trial.params)
    print_dict('Attributes:', trial.user_attrs)
    model = xgb.Booster()
    model_checkpoint_dir = trial.user_attrs['model_checkpoint_dir']
    model_checkpoint = trial.user_attrs['model_checkpoint']
    model.load_model(os.path.join(model_checkpoint_dir, model_checkpoint))
    print('-' * 80)

Evaluation:

In [None]:
auc_results = {}
confusion_matrices = {}

for experiment_id, design_points in experiments_results['results_xgb'].items():
    task = design_points['task']
    trial = design_points['trial']
    fp_bits = design_points['fp_bits']
    use_upsampled = design_points['use_upsampled']
    # Retrieve specific datasets
    ds = get_datasets(task,
                      use_upsampled,
                      dataset_name=trial.user_attrs['dataset_name'],
                      regenerate_datasets=False,
                      **protac_ds_kwargs)
    train_ds = ds['train']
    val_ds = ds['val']
    test_ds = ds['test']
    # Load model
    model_checkpoint_dir = trial.user_attrs['model_checkpoint_dir']
    model_checkpoint = trial.user_attrs['model_checkpoint']
    best_model = xgb.Booster()
    best_model.load_model(os.path.join(model_checkpoint_dir, model_checkpoint))
    # Setup XGBoost-specific datasets
    test_fp_data = test_ds.get_fingerprint(trial.params['fp_type'])
    val_fp_data = val_ds.get_fingerprint(trial.params['fp_type'])
    if trial.params.get('fp_use_extra_features', False) or design_points.get('use_extra_features', False):
        use_extra_features = True
        # Encode the POI sequences (the training set is not needed for evaluation)
        poi_seq_val = val_ds.poi_vectorizer.transform(val_ds.poi_seq)
        poi_seq_test = test_ds.poi_vectorizer.transform(test_ds.poi_seq)
        poi_seq_val = poi_seq_val.toarray().astype(np.int32)
        poi_seq_test = poi_seq_test.toarray().astype(np.int32)
        test_fp_data = np.concatenate([
            test_fp_data,
            np.array(test_ds.e3_ligase)[:, np.newaxis],
            np.array(test_ds.cell_type)[:, np.newaxis],
            poi_seq_test,
        ], axis=-1)
        val_fp_data = np.concatenate([
            val_fp_data,
            np.array(val_ds.e3_ligase)[:, np.newaxis],
            np.array(val_ds.cell_type)[:, np.newaxis],
            poi_seq_val,
        ], axis=-1)
    else:
        use_extra_features = False
    auc_results[experiment_id] = {}
    for phase in ['val', 'test']:
        if phase == 'val':
            bin_labels = val_ds.active.to_numpy().flatten()
            dval = xgb.DMatrix(val_fp_data, label=val_ds.active)
            # Get predictions
            logits = best_model.predict(dval, iteration_range=(0, best_model.best_iteration + 1)).flatten()
        if phase == 'test':
            bin_labels = test_ds.active.to_numpy().flatten()
            dtest = xgb.DMatrix(test_fp_data, label=test_ds.active)
            # Get predictions
            logits = best_model.predict(dtest, iteration_range=(0, best_model.best_iteration + 1)).flatten()
        # Plot results
        fp_type = trial.params['fp_type']
        if task == 'predict_pDC50_and_Dmax':
            pass
        elif task == 'predict_pDC50':
            pass
        elif task == 'predict_active_inactive':
            bin_pred = logits >= 0.5
            descr = f'XGBoost [{fp_type.replace("_fp", "").upper()} {fp_bits}bits'
            descr += f'{" + Extra Feat.]" if use_extra_features else "]"}'
            # Get confusion matrix data
            cm = confusion_matrix(bin_labels.flatten().astype(int),
                                  bin_pred.flatten().astype(int))
            confusion_matrices[experiment_id] = (ConfusionMatrixDisplay(cm), descr)
            # Get ROC AUC data
            fpr, tpr, _ = sklearn.metrics.roc_curve(bin_labels.flatten(), logits)
            auc = sklearn.metrics.auc(fpr, tpr)
            auc_results[experiment_id][phase] = (fpr, tpr, auc)
print('All evaluation results gathered.')
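For reference, the ROC AUC gathered above (via `sklearn.metrics.roc_curve` and `sklearn.metrics.auc`) equals the probability that a randomly chosen active PROTAC receives a higher predicted score than a randomly chosen inactive one, with ties counting as half. Below is a minimal stdlib-only sketch of this pairwise (Mann-Whitney) view. `roc_auc_rank` and the toy labels/scores are hypothetical and not part of the notebook:

```python
def roc_auc_rank(labels, scores):
    """ROC AUC as the fraction of (positive, negative) pairs ranked correctly.

    Ties count as half a correct pair, which matches the trapezoidal area
    under the ROC curve computed by scikit-learn.
    """
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError('ROC AUC is undefined when only one class is present.')
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Toy example: made-up labels and XGBoost-style probabilities
labels = [1, 0, 1, 1, 0]
probs = [0.9, 0.3, 0.6, 0.4, 0.5]
# 5 of the 6 positive/negative pairs are ranked correctly, so AUC = 5/6
auc = roc_auc_rank(labels, probs)
```

The quadratic loop is only meant to make the definition explicit. For real evaluations, the scikit-learn calls used above are the way to go.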

Plot the ROC curves and confusion matrices:

In [None]:
for phase in ['val', 'test']:
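    # Dummy classifier (constant prediction). NOTE: Its labels must come from the
    # split being plotted, not from the last iteration of the evaluation loop above.
    # This assumes all experiments share the same val/test splits.
    eval_ds = val_ds if phase == 'val' else test_ds
    bin_labels = eval_ds.active.to_numpy().flatten()
    dummy_pred = torch.tensor([1.0] * len(bin_labels))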
    fpr, tpr, _ = sklearn.metrics.roc_curve(bin_labels.flatten(), dummy_pred)
    auc = sklearn.metrics.roc_auc_score(bin_labels.flatten(), dummy_pred)
    acc = binary_accuracy(dummy_pred, torch.tensor(bin_labels), threshold=0.5)
    plt.plot(fpr, tpr, label=f'Dummy (AUC = {auc:.2f}, Accuracy = {acc * 100:.2f}%)', color='black', lw=0.8, linestyle='--')

    # Plot other models results
    for experiment_id, design_points in experiments_results['results_xgb'].items():
        task = design_points['task']
        trial = design_points['trial']
        fp_bits = design_points['fp_bits']
        use_upsampled = design_points['use_upsampled']
        if task != 'predict_active_inactive':
            continue
        fp_type = trial.params['fp_type']
        acc = trial.user_attrs[f'{phase}_acc']
        fpr, tpr, auc = auc_results[experiment_id][phase]

        if trial.params.get('fp_use_extra_features', False) or design_points.get('use_extra_features', False):
            use_extra_features = True
        else:
            use_extra_features = False
        descr = f'XGBoost [{fp_type.replace("_fp", "").upper()} {fp_bits}bits'
        descr += f'{" + Extra Feat.]" if use_extra_features else "]"}'
        plt.plot(fpr, tpr, label=f'{descr} (AUC: {auc:.2f}, Accuracy: {acc * 100:.2f}%)')
    plt.grid('both', alpha=0.7)
    plt.legend(loc='upper left', bbox_to_anchor=(1, 1), ncol=1, fancybox=True) #, shadow=True)
    plt.title(f'XGBoost {"Validation" if phase == "val" else "Test"} Set ROC Curve')

    filename = os.path.join(fig_dir, f'roc_curve_{phase}_xgb')
    plt.savefig(filename + '.pdf', bbox_inches='tight')
    plt.savefig(filename + '.png', bbox_inches='tight')
    plt.show()
    plt.close()

    # Plot confusion matrices
    for i, (_, cm) in enumerate(confusion_matrices.items()):
        disp, descr = cm
        disp.plot(cmap=plt.cm.Blues)
        plt.title(f'{descr}')
        plt.savefig(os.path.join(fig_dir, f'confusion_matrix_{phase}_xgb_n{i}.pdf'), bbox_inches='tight')
        plt.savefig(os.path.join(fig_dir, f'confusion_matrix_{phase}_xgb_n{i}.png'), bbox_inches='tight')
        # plt.show()
        plt.close()
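Note that the evaluation loop assigns `confusion_matrices[experiment_id]` once per phase, so the test-set matrix overwrites the validation one, and the plotting loop then saves the same (test) matrices under both the `val` and `test` filenames. A minimal sketch of keying the dictionary by `(experiment_id, phase)` instead; `confusion_counts` and the `predictions` data are hypothetical stand-ins for `confusion_matrix` and the real model outputs:

```python
def confusion_counts(labels, probs, threshold=0.5):
    """Return [[TN, FP], [FN, TP]], i.e., scikit-learn's layout for labels 0/1."""
    cm = [[0, 0], [0, 0]]
    for y, p in zip(labels, probs):
        pred = int(p >= threshold)
        cm[int(y)][pred] += 1
    return cm


# Hypothetical per-experiment (labels, probabilities) for both splits
predictions = {
    'exp0': {'val': ([1, 0, 1], [0.8, 0.6, 0.2]), 'test': ([0, 1], [0.1, 0.9])},
}
# Key by (experiment_id, phase) so that the test matrix does not overwrite the val one
confusion_matrices = {}
for experiment_id, phases in predictions.items():
    for phase, (labels, probs) in phases.items():
        confusion_matrices[(experiment_id, phase)] = confusion_counts(labels, probs)

# When plotting a given split, select only that split's matrices
val_only = {k: v for k, v in confusion_matrices.items() if k[1] == 'val'}
```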

Extract the substructures (as SMILES) behind the bits set in a Morgan fingerprint:

In [None]:
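from rdkit import Chem
from rdkit.Chem import AllChem

# NOTE: Adapted from untested Google Bard output. RDKit bit vectors do not store
# which atoms set each bit (there is no `GetBitInfoMap()` method on them), so the
# fingerprint is recomputed here with the `bitInfo` output argument.
def get_smiles_parts_from_morgan_fingerprint(smiles: str, radius: int = 2, n_bits: int = 2048) -> dict:
    """
    Given a SMILES, returns the substructures (as SMILES) behind each bit set to 1 in its Morgan fingerprint.

    Args:
        smiles: The SMILES of the molecule.
        radius: Morgan fingerprint radius. Must match the one used to build the model features.
        n_bits: Morgan fingerprint length. Must match the one used to build the model features.

    Returns:
        A dictionary mapping each set bit to the sorted list of SMILES of the substructures that set it.
        Several substructures can map to the same bit because of hash folding.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f'Invalid SMILES: {smiles}')
    # bit_info maps each bit id to a tuple of (center atom index, radius) pairs
    bit_info = {}
    AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits, bitInfo=bit_info)
    smiles_parts = {}
    for bit_id, environments in bit_info.items():
        parts = set()
        for atom_idx, env_radius in environments:
            if env_radius == 0:
                # Radius-0 environments are single atoms, i.e., there are no bonds to extract
                parts.add(mol.GetAtomWithIdx(atom_idx).GetSymbol())
                continue
            # Bond indices of the circular environment around the center atom
            env = Chem.FindAtomEnvironmentOfRadiusN(mol, env_radius, atom_idx)
            submol = Chem.PathToSubmol(mol, env)
            parts.add(Chem.MolToSmiles(submol))
        smiles_parts[bit_id] = sorted(parts)
    return smiles_parts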
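A single fingerprint bit can correspond to several substructures: Morgan fingerprints hash each atom environment into a large integer identifier and then fold it into `n_bits` positions. My understanding is that RDKit folds with a modulo (`identifier % n_bits`), but treat that as an assumption about its internals. This is one reason why the bit info maps a bit to a list of `(atom, radius)` pairs (the other being the same substructure occurring at several atoms). A toy illustration with made-up identifiers; `fold_identifiers` is hypothetical:

```python
def fold_identifiers(env_ids, n_bits):
    """Fold (hypothetical) integer environment identifiers into an n_bits-long bit vector.

    Returns a dict mapping each set bit to the identifiers that landed on it.
    """
    bits = {}
    for env_id in env_ids:
        bits.setdefault(env_id % n_bits, []).append(env_id)
    return bits


# Made-up identifiers: 1024 and 3072 collide on bit 0 when n_bits=1024
folded = fold_identifiers([1024, 3072, 5, 2047], n_bits=1024)
```

With `n_bits=4096`, the same two identifiers land on different bits, which is one motivation for trying several `fp_bits` values.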

### Generic Wrapper Model

[Instructions](https://github.com/Lightning-AI/lightning/discussions/7249#discussioncomment-677516) on how to save the hyperparameters of a combined model (particularly useful when dealing with Transformers):

1. Generate the sub-model within the wrapper model via a generator function
2. Save the sub-model as a separate model checkpoint
3. Supply the sub-model checkpoint to the generator function in order to retrieve it
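The three steps above boil down to saving only plain, serializable hyperparameters in the wrapper and rebuilding the sub-model from them. Below is a minimal stdlib-only sketch of that idea. `ToyEncoder`, `ENCODERS`, and `build_encoder` are hypothetical stand-ins for a real `nn.Module` and its generator function, and no weights are actually loaded. (The `WrapperModel` below takes a slightly different route and receives the sub-model class plus its arguments directly.)

```python
import json
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ToyEncoder:
    """Hypothetical stand-in for a SMILES encoder sub-model (no real weights)."""
    fp_bits: int = 1024
    hidden_channels: List[int] = field(default_factory=lambda: [128, 128])
    # Where the pretrained weights would be loaded from (step 3)
    checkpoint_path: Optional[str] = None


# Hypothetical registry used by the generator function to look up sub-model classes
ENCODERS = {'ToyEncoder': ToyEncoder}


def build_encoder(name: str, args: dict, checkpoint_path: Optional[str] = None) -> ToyEncoder:
    """Generator function (step 1): rebuild the sub-model from serializable hyperparameters."""
    return ENCODERS[name](**args, checkpoint_path=checkpoint_path)


# The wrapper saves only plain, JSON-serializable hyperparameters...
hparams = {
    'encoder_name': 'ToyEncoder',
    'encoder_args': {'fp_bits': 2048, 'hidden_channels': [64]},
    'encoder_checkpoint': 'smiles_encoder.ckpt',  # step 2: sub-model saved separately
}
saved = json.dumps(hparams)
# ...and the sub-model is regenerated from them when the wrapper is reloaded
loaded = json.loads(saved)
encoder = build_encoder(loaded['encoder_name'], loaded['encoder_args'], loaded['encoder_checkpoint'])
```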

In [None]:
from torchmetrics import MetricCollection
from typing import Type

class WrapperModel(pl.LightningModule):

    def __init__(self,
                 smiles_encoder: Type[nn.Module],
                 smiles_encoder_args: Mapping | None = {},
                 train_dataset: ProtacDataset = None,
                 val_dataset: ProtacDataset = None,
                 test_dataset: ProtacDataset = None,
                 use_extra_features: bool = False,
                 poi_seq_embedding_size: int = 0,
                 hidden_channels_extra_features: List[int] = [32, 32],
                 norm_layer: object = nn.BatchNorm1d,
                 task: Literal['predict_active_inactive', 'predict_pDC50_and_Dmax'] = 'predict_active_inactive',
                 freeze_smiles_encoder: bool = False,
                 dropout: float = 0.5,
                 batch_size: int = 64,
                 learning_rate: float = 1e-3,
                 loss_function: Callable | object = nn.HuberLoss(),
                 **model_kwargs):
        """Wrapper class to make predictions on PROTAC data.

        Args:
            smiles_encoder (Type[nn.Module]): Class of the SMILES encoder sub-model. It is instantiated as `smiles_encoder(**smiles_encoder_args)` and must implement `get_smiles_embedding_size()`.
            smiles_encoder_args (Mapping | None, optional): Keyword arguments for the SMILES encoder constructor. NOTE: The arguments will be saved as hyperparameters. Defaults to {}.
            train_dataset (ProtacDataset, optional): Train dataset. Defaults to None.
            val_dataset (ProtacDataset, optional): Validation dataset. Defaults to None.
            test_dataset (ProtacDataset, optional): Test dataset. Defaults to None.
            use_extra_features (bool, optional): Whether to include an additional MLP branch to process extra features (E3 ligase, cell type and, optionally, the POI sequence). Defaults to False.
            poi_seq_embedding_size (int, optional): Size of the POI sequence embedding. If <= 0 and extra features are used, it is inferred from the datasets. Defaults to 0.
            hidden_channels_extra_features (List[int], optional): MLP hidden channels sizes of the extra features branch. Defaults to [32, 32].
            norm_layer (object, optional): Normalization layer to use in the extra features branch. Defaults to nn.BatchNorm1d.
            task (Literal['predict_active_inactive', 'predict_pDC50', 'predict_pDC50_and_Dmax'], optional): Task to perform. Defaults to 'predict_active_inactive'.
            freeze_smiles_encoder (bool, optional): Whether to freeze the SMILES encoder parameters, i.e., not train them. Defaults to False.
            dropout (float, optional): Dropout factor for the extra features branch. Defaults to 0.5.
            batch_size (int, optional): Batch size. Defaults to 64.
            learning_rate (float, optional): Learning rate. Defaults to 1e-3.
            loss_function (Callable | object, optional): Loss function to be used for regression tasks. Defaults to nn.HuberLoss().
            **model_kwargs: Additional keyword arguments. They are not used by the model itself.
        """
        super().__init__()
        # Set our init args as class attributes
        self.__dict__.update(locals()) # Add arguments as attributes
        # Save the arguments passed to init
        ignore_args_as_hyperparams = [
            'train_dataset',
            'test_dataset',
            'val_dataset',
            'loss_function',
        ]
        self.save_hyperparameters(ignore=ignore_args_as_hyperparams) 
        # Instantiate the SMILES encoder sub-model from its class and arguments
        self.smiles_encoder = smiles_encoder(**(smiles_encoder_args or {}))
        if freeze_smiles_encoder:
            self.smiles_encoder.freeze()
        # Define sub-model branch for processing "extra features"
        if use_extra_features:
            extra_features_size = 2 # Cell type and E3 ligase
            if self.poi_seq_embedding_size <= 0:
                if train_dataset is not None:
                    self.poi_seq_embedding_size = train_dataset.get_poi_seq_emb_size()
                if test_dataset is not None:
                    assert train_dataset.get_poi_seq_emb_size() == test_dataset.get_poi_seq_emb_size(), 'POI sequence embedding size mismatch between train and test datasets.'
                    self.poi_seq_embedding_size = test_dataset.get_poi_seq_emb_size()
                if val_dataset is not None:
                    assert train_dataset.get_poi_seq_emb_size() == val_dataset.get_poi_seq_emb_size(), 'POI sequence embedding size mismatch between train and val datasets.'
                    self.poi_seq_embedding_size = val_dataset.get_poi_seq_emb_size()
            extra_features_size += self.poi_seq_embedding_size
            self.extra_features_encoder = MLP(in_channels=extra_features_size,
                                              hidden_channels=hidden_channels_extra_features,
                                              norm_layer=norm_layer,
                                              inplace=False,
                                              dropout=dropout)
        # Define prediction head
        head_inputs = self.smiles_encoder.get_smiles_embedding_size()
        if use_extra_features:
            head_inputs += hidden_channels_extra_features[-1]
        num_outputs = 2 if task == 'predict_pDC50_and_Dmax' else 1
        self.head = nn.Linear(head_inputs, num_outputs)
        self.sigmoid = nn.Sigmoid()
        # Losses
        self.regr_loss = loss_function
        self.bin_loss = nn.BCEWithLogitsLoss()
        # Metrics, a separate metrics collection is defined for each phase
        # NOTE: According to the TorchMetrics docs, metrics sharing the same
        # underlying state (e.g., accuracy, precision, recall and F1 score) should
        # be grouped within a MetricCollection, so that their state is computed once.
        phases = ['train_metrics', 'val_metrics', 'test_metrics']
        self.metrics = nn.ModuleDict({p: MetricCollection({
            'acc': Accuracy(task='binary'),
            'roc_auc': AUROC(task='binary'),
            'precision': Precision(task='binary'),
            'recall': Recall(task='binary'),
            'f1_score': F1Score(task='binary'),
            'opt_score': Accuracy(task='binary') + F1Score(task='binary'),
            'hp_metric': Accuracy(task='binary'),
        }, prefix=p.replace('metrics', '')) for p in phases})
        # Misc
        self.missing_dataset_error = \
            '''Class variable `{0}` is None. If the model was loaded from a checkpoint, the dataset must be set manually:
            
            model = {1}.load_from_checkpoint('checkpoint.ckpt')
            model.{0} = my_{0}
            '''

    def forward(self, x_in):
        if self.use_extra_features:
            e3_ligase = x_in['e3_ligase']
            cell_type = x_in['cell_type']
            if self.poi_seq_embedding_size > 0:
                x = torch.cat((e3_ligase, cell_type, x_in['poi_seq']), dim=-1)
            else:
                x = torch.cat((e3_ligase, cell_type), dim=-1)
            extra_emb = self.extra_features_encoder(x)
            smiles_emb = self.smiles_encoder(x_in)
            x = torch.cat((extra_emb, smiles_emb), dim=-1)
        else:
            x = self.smiles_encoder(x_in)
        return self.head(x)

    def step(self, batch, phase='train'):
        y = batch['labels']
        preds = self.forward(batch)
        if self.task == 'predict_active_inactive':
            loss = self.bin_loss(preds, y)
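            # Pass probabilities explicitly: TorchMetrics only auto-applies a
            # sigmoid when some predictions fall outside the [0, 1] range
            self.metrics[f'{phase}_metrics'].update(self.sigmoid(preds), y)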
            self.log(f'{phase}_loss', loss, on_epoch=True, prog_bar=True)
            self.log_dict(self.metrics[f'{phase}_metrics'], on_epoch=True)
        else:
            loss = self.regr_loss(preds, y)
            self.log(f'{phase}_loss', loss, on_epoch=True, prog_bar=True)
            if phase == 'val':
                self.log('hp_metric', loss)
        return loss

    def training_step(self, batch, batch_idx):
        return self.step(batch, phase='train')

    def validation_step(self, batch, batch_idx):
        return self.step(batch, phase='val')

    def test_step(self, batch, batch_idx):
        return self.step(batch, phase='test')

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def load_smiles_encoder(self, checkpoint_path):
        ckpt = torch.load(checkpoint_path, map_location=self.device)
        self.smiles_encoder.load_state_dict(ckpt, strict=False)

    # def prepare_data(self):
    #     train_ds = os.path.join(data_dir, 'protac', f'train_dataset_fp{self.fp_bits}.pt')
    #     test_ds = os.path.join(data_dir, 'protac', f'test_dataset_fp{self.fp_bits}.pt')
    #     self.train_dataset = torch.load(train_ds)
    #     self.test_dataset = torch.load(test_ds)

    def train_dataloader(self):
        if self.train_dataset is None:
            format = 'train_dataset', self.__class__.__name__
            raise ValueError(self.missing_dataset_error.format(*format))
        return DataLoader(self.train_dataset, batch_size=self.batch_size,
                          shuffle=True, collate_fn=custom_collate,
                          drop_last=True)

    def val_dataloader(self):
        if self.val_dataset is None:
            format = 'val_dataset', self.__class__.__name__
            raise ValueError(self.missing_dataset_error.format(*format))
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                          shuffle=False, collate_fn=custom_collate)

    def test_dataloader(self):
        if self.test_dataset is None:
            format = 'test_dataset', self.__class__.__name__
            raise ValueError(self.missing_dataset_error.format(*format))
        return DataLoader(self.test_dataset, batch_size=self.batch_size,
                          shuffle=False, collate_fn=custom_collate)
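The layer sizes in `WrapperModel.__init__` follow simple bookkeeping. The prediction head sees the SMILES embedding concatenated, when `use_extra_features=True`, with the last hidden layer of the extra-features MLP. That MLP's input is 2 (E3 ligase and cell type, assuming each is a single encoded value, as `extra_features_size = 2` implies) plus the POI sequence embedding size. A pure-Python sketch that mirrors this arithmetic; `wrapper_layer_sizes` is hypothetical:

```python
def wrapper_layer_sizes(smiles_embedding_size, use_extra_features=False,
                        poi_seq_embedding_size=0, hidden_channels_extra_features=(32, 32),
                        task='predict_active_inactive'):
    """Mirror the size bookkeeping of WrapperModel.__init__ (sketch, no layers built)."""
    extra_in = None
    head_in = smiles_embedding_size
    if use_extra_features:
        # E3 ligase + cell type (one value each) + POI sequence embedding
        extra_in = 2 + poi_seq_embedding_size
        # The extra-features embedding is concatenated to the SMILES embedding
        head_in += hidden_channels_extra_features[-1]
    # Two regression outputs for pDC50 and Dmax, a single logit otherwise
    num_outputs = 2 if task == 'predict_pDC50_and_Dmax' else 1
    return {'extra_features_in': extra_in, 'head_in': head_in, 'head_out': num_outputs}


# E.g., a 128-dim fingerprint embedding plus extra features with a 20-dim POI embedding
sizes = wrapper_layer_sizes(128, use_extra_features=True, poi_seq_embedding_size=20)
```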

Training function for the generic wrapper model, to be used either standalone for plain PyTorch Lightning training or for hyperparameter tuning via Optuna (or Ray Tune):

In [None]:
def trial_set_dict(trial: optuna.trial.Trial, d: dict) -> None:
    """Set a dictionary of user attributes on an Optuna trial object."""
    for k, v in d.items():
        trial.set_user_attr(k, v)
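Optuna documents that user-attribute values should be JSON serializable. If the dictionaries passed to `trial_set_dict` ever contain NumPy arrays or scalars (e.g., ROC curve points; I have not checked what `get_eval_results` returns), a small conversion step avoids problems with the trial storage. Below is a stdlib-only sketch. `to_json_safe`, `trial_set_dict_safe`, and `FakeTrial` (a stand-in that only mimics `set_user_attr`, so the sketch runs without Optuna) are all hypothetical:

```python
import json


def to_json_safe(value):
    """Convert NumPy-like arrays/scalars (anything exposing .tolist()) and tuples to plain Python."""
    if hasattr(value, 'tolist'):
        return value.tolist()
    if isinstance(value, dict):
        return {k: to_json_safe(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return [to_json_safe(v) for v in value]
    return value


class FakeTrial:
    """Hypothetical stand-in exposing only `set_user_attr`, for testing without Optuna."""

    def __init__(self):
        self.user_attrs = {}

    def set_user_attr(self, key, value):
        json.dumps(value)  # Fail early if the value is not JSON serializable
        self.user_attrs[key] = value


def trial_set_dict_safe(trial, d):
    """Like `trial_set_dict`, but converts values to JSON-friendly types first."""
    for k, v in d.items():
        trial.set_user_attr(k, to_json_safe(v))


trial = FakeTrial()
trial_set_dict_safe(trial, {'val_acc': 0.8, 'val_fpr': (0.0, 0.5, 1.0), 'trial_name': 'fp-1024'})
```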

In [None]:
def train_model(reporting_config: dict,
                smiles_encoder: Type[nn.Module],
                smiles_encoder_args: Mapping,
                train_dataset: ProtacDataset,
                val_dataset: ProtacDataset,
                test_dataset: ProtacDataset,
                num_epochs: int = 10,
                task: Literal['predict_active_inactive', 'predict_pDC50_and_Dmax'] = 'predict_active_inactive',
                loss_func: Callable | object = nn.HuberLoss(),
                num_gpus: int = 0,
                use_raytune: bool = False,
                trial: optuna.trial.Trial | None = None,
                accumulate_grad_batches: int = 1,
                enable_checkpointing: bool = True,
                enable_tensorboard_logging: bool = False,
                model: nn.Module | None = None,
                **model_kwargs: Mapping | None) -> float | None:
    """Training function for the generic wrapper model, to be used either standalone for plain PyTorch Lightning training or for hyperparameter tuning via Optuna (or Ray Tune).
    TODO: Add kwargs for PyTorch Lightning Trainer class.

    Args:
        reporting_config (dict): Dictionary for checkpoint naming and reporting. Must contain the keys 'lightning_dir', 'model_checkpoint_dir', 'model_checkpoint', 'tensorboard_dir', 'trial_name' and 'model_name'.
        smiles_encoder (Type[nn.Module]): Class of the SMILES encoder sub-model, to be passed to the wrapper model.
        smiles_encoder_args (Mapping): Keyword arguments for the SMILES encoder constructor.
        train_dataset (ProtacDataset): The training dataset.
        val_dataset (ProtacDataset): The validation dataset.
        test_dataset (ProtacDataset): The test dataset.
        num_epochs (int, optional): Maximum number of epochs to train the model for. Defaults to 10.
        task (Literal['predict_active_inactive', 'predict_pDC50', 'predict_pDC50_and_Dmax'], optional): Task to configure the model for. Defaults to 'predict_active_inactive'.
        loss_func (Callable | object, optional): Loss function for regression tasks. Defaults to nn.HuberLoss().
        num_gpus (int, optional): Number of GPUs to use for training. Defaults to 0.
        use_raytune (bool, optional): Use Ray Tune for training. DEPRECATED. Defaults to False.
        trial (optuna.trial.Trial | None, optional): Optuna trial object. Defaults to None.
        accumulate_grad_batches (int, optional): Number of batches over which to accumulate gradients. Defaults to 1.
        enable_checkpointing (bool, optional): Whether to save the best model checkpoint. Defaults to True.
        enable_tensorboard_logging (bool, optional): Whether to also log to TensorBoard. Defaults to False.
        model (nn.Module | None, optional): Pre-built model to train. If None, a new WrapperModel is created. Defaults to None.
        **model_kwargs: Additional keyword arguments passed to WrapperModel.

    Returns:
        float | None: If an Optuna trial is given (and Ray Tune is not used), the score to optimize, i.e., the validation accuracy for classification or the validation loss otherwise. None in all other cases.
    """
    # TODO: The dataset is currently not loaded from disk but passed as
    # argument. The problem lies in the definition of ProtacDataset, which is
    # not recognized by ray.tune. This should be fixed in the future.

    # Namings
    # if use_raytune:
    #     trial_name = f'{config["fp_type"]}-{fp_bits}-{tune.get_trial_id()}'
    lightning_dir = reporting_config['lightning_dir']
    model_checkpoint_dir = reporting_config['model_checkpoint_dir']
    model_checkpoint = reporting_config['model_checkpoint']
    tensorboard_dir = reporting_config['tensorboard_dir']
    trial_name = reporting_config['trial_name']
    model_name = reporting_config['model_name']
    # Setup Pytorch Lightning wrapper model
    poi_seq_embedding_size = 0
    if model_kwargs is not None:
        if model_kwargs.get('use_extra_features', False):
            assert train_dataset.get_poi_seq_emb_size() == val_dataset.get_poi_seq_emb_size() == test_dataset.get_poi_seq_emb_size(), f'POI sequence embedding sizes of train, val and test dataset must match.'
            poi_seq_embedding_size = train_dataset.get_poi_seq_emb_size()
    if model is None:
        model = WrapperModel(task=task,
                            smiles_encoder=smiles_encoder,
                            smiles_encoder_args=smiles_encoder_args,
                            poi_seq_embedding_size=poi_seq_embedding_size,
                            train_dataset=train_dataset,
                            val_dataset=val_dataset,
                            test_dataset=test_dataset,
                            loss_function=loss_func,
                            norm_layer=nn.BatchNorm1d,
                            **model_kwargs)
        # TODO: Pytorch compile not yet supported on Windows
        if os.name != 'nt':
            model = torch.compile(model)
    # Configure Pytorch Lightning loggers
    loggers = []
    if enable_tensorboard_logging:
        loggers.append(TensorBoardLogger(save_dir=tensorboard_dir,
                                         name=trial_name,
                                         version='.',
                                         default_hp_metric=True))
    if not use_raytune and trial is not None:
        save_dir = os.path.join(lightning_dir, 'logs', f'logs_{model_name}-{trial_name}')
        csv_logger = CSVLogger(save_dir=save_dir)
        loggers.append(csv_logger)
    # Setup Pytorch Lightning callbacks
    if task == 'predict_active_inactive':
        # Keep track of additional metrics when predicting active/inactive
        # NOTE: Key == Ray Tune name, Value == Pytorch Lightning name
        metrics = {
            'train_loss': 'train_loss',
            'train_acc': 'train_acc',
            'val_loss': 'val_loss',
            'val_acc': 'val_acc',
            'val_opt_score': 'val_opt_score',
            'val_roc_auc': 'val_roc_auc',
            'val_precision': 'val_precision',
            'val_recall': 'val_recall',
            'val_f1_score': 'val_f1_score',
        }
        monitor = 'val_acc'
        mode = 'max'
    else:
        metrics = {'val_loss': 'val_loss', 'train_loss': 'train_loss'}
        monitor = 'val_loss'
        mode = 'min'
    callbacks = []
    if use_raytune:
        callbacks.append(
            TuneReportCheckpointCallback(metrics,
                                         filename='checkpoint.ckpt',
                                         on='validation_end'),
            # TuneReportCallback(metrics, on='validation_end'),
        )
    else:
        save_top_k = 1 if enable_checkpointing else 0
        callbacks.append(ModelCheckpoint(dirpath=model_checkpoint_dir,
                                         filename=model_checkpoint,
                                         save_top_k=save_top_k,
                                         monitor=monitor,
                                         mode=mode))
        if trial is not None:
            callbacks.append(
                CustomPyTorchLightningPruningCallback(trial, monitor=monitor)
            )
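    callbacks.append(EarlyStopping(monitor='val_loss', mode='min', patience=5, check_finite=True))
    # NOTE: `val_acc` is only logged for binary classification. Monitoring it in
    # regression tasks would make EarlyStopping fail on a missing metric.
    if task == 'predict_active_inactive':
        callbacks.append(EarlyStopping(monitor='val_acc', mode='max', patience=5))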
    # Instantiate Pytorch Lightning Trainer
    if torch.cuda.is_available() and num_gpus > 0:
        pl_devices = math.ceil(num_gpus)
        accelerator = 'gpu'
        precision = '16-mixed'
    else:
        pl_devices = 4
        accelerator = 'cpu' # 'auto' could pick a GPU/MPS device even though num_gpus == 0
        precision = '32'
    trainer = pl.Trainer(max_epochs=num_epochs,
                         gradient_clip_val=1.0,
                         gradient_clip_algorithm='norm',
                         accumulate_grad_batches=accumulate_grad_batches,
                         log_every_n_steps=8,
                         callbacks=callbacks,
                         accelerator=accelerator,
                         devices=pl_devices,
                         precision=precision,
                         enable_checkpointing=True, # not use_raytune,
                         logger=loggers,
                         enable_progress_bar=False,
                         enable_model_summary=False)
    # Finally, start training
    trainer.fit(model)

    # # Reload best model checkpoint
    # # NOTE: We need to reload the best model because the checkpointing callback
    # # automatically saves the model with best accuracy achieved DURING training,
    # # not necessarily the model AT THE END of training.
    # if enable_checkpointing:
    #     best_model_path = trainer.checkpoint_callback.best_model_path
    #     model = WrapperModel.load_from_checkpoint(best_model_path)
    #     model.train_dataset = train_dataset
    #     model.val_dataset = val_dataset
    #     model.test_dataset = test_dataset

    # Report metrics and, if using Optuna, return the trained model score
    if not use_raytune and trial is not None:
        trainer_log_dir = trainer.loggers[-1].log_dir
        trial.set_user_attr('trainer_log_dir', trainer_log_dir)
        for metric in metrics.values():
            trial.set_user_attr(metric, trainer.callback_metrics[metric].item())
        if enable_checkpointing:
            best_model_path = trainer.checkpoint_callback.best_model_path
            trial.set_user_attr('model_checkpoint', best_model_path)
        # Return in case the function is used as an optimization func. in Optuna
        if task == 'predict_active_inactive':
            if enable_checkpointing:
                eval_results = get_eval_results(model, task, num_gpus,
                                            trainer=trainer, return_logits=False)
                trial_set_dict(trial, eval_results)
                return eval_results['val_acc']
            else:
                metrics_path = os.path.join(trainer_log_dir, 'metrics.csv')
                best_score = pd.read_csv(metrics_path)[monitor].astype(float).max()
                return best_score
        else:
            return trainer.callback_metrics['val_loss'].item()
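`train_model` registers a patience-5 `EarlyStopping` callback on `val_loss` and another one on `val_acc`; training ends as soon as either of them triggers. The sketch below is a simplified pure-Python model of such a patience counter: the counter resets on improvement and training stops after `patience` consecutive epochs without improvement. It is not Lightning's implementation, and `epochs_until_early_stop`, `combined_stop_epoch`, and the metric histories are hypothetical:

```python
def epochs_until_early_stop(history, mode='min', patience=5, min_delta=0.0):
    """Return the 1-based epoch at which patience-based early stopping triggers, or None."""
    best = None
    wait = 0
    for epoch, value in enumerate(history, start=1):
        if mode == 'min':
            improved = best is None or value < best - min_delta
        else:
            improved = best is None or value > best + min_delta
        if improved:
            best, wait = value, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


def combined_stop_epoch(val_loss, val_acc, patience=5):
    """Training stops at the first epoch at which any of the two callbacks triggers."""
    stops = [e for e in (epochs_until_early_stop(val_loss, 'min', patience),
                         epochs_until_early_stop(val_acc, 'max', patience)) if e is not None]
    return min(stops) if stops else None


# Made-up validation curves: the accuracy plateau triggers one epoch before the loss one
val_loss = [0.70, 0.60, 0.55, 0.56, 0.57, 0.58, 0.59, 0.60]
val_acc = [0.60, 0.72, 0.70, 0.70, 0.70, 0.70, 0.70, 0.75]
stop_epoch = combined_stop_epoch(val_loss, val_acc)
```

In the example, the late accuracy improvement at epoch 8 is never seen, which is worth keeping in mind when choosing the patience.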

Generic objective body:

In [None]:
def objective_body(trial: optuna.trial.Trial,
                   model_name: str,
                   trial_name: str,
                   dataset_name: str,
                   smiles_encoder: Type[nn.Module],
                   smiles_encoder_gen_args: Mapping,
                   model_kwargs: Mapping,
                   num_epochs: int = 10,
                   task: Literal['predict_active_inactive', 'predict_pDC50_and_Dmax'] = 'predict_active_inactive',
                   enable_checkpointing: bool = True,
                   loss_func: Callable | object = nn.HuberLoss(),
                   use_upsampled: bool = False,
                   num_gpus: int = 0,
                   ) -> float:
    ds = get_datasets(task, use_upsampled, dataset_name=dataset_name)
    train_dataset = ds['train']
    val_dataset = ds['val']
    test_dataset = ds['test']
    trial.set_user_attr('dataset_name', dataset_name)
    # Standard namings for reporting
    lightning_dir = os.path.join(checkpoint_dir, 'lightning')
    model_checkpoint_dir = os.path.join(lightning_dir, 'models')
    model_checkpoint = f'{model_name}-{trial_name}'
    tensorboard_dir = os.path.join(lightning_dir, 'tensorboard', f'{model_name}_{task}')
    reporting_config = {
        'trial_name': trial_name,
        'model_name': model_name,
        'lightning_dir': lightning_dir,
        'model_checkpoint_dir': model_checkpoint_dir,
        'model_checkpoint': model_checkpoint,
        'tensorboard_dir': tensorboard_dir,
    }
    trial_set_dict(trial, reporting_config)
    # Train model via its generic function
    return train_model(reporting_config=reporting_config,
                       smiles_encoder=smiles_encoder,
                       smiles_encoder_args=smiles_encoder_gen_args,
                       train_dataset=train_dataset,
                       val_dataset=val_dataset,
                       test_dataset=test_dataset,
                       # Optional arguments
                       num_epochs=num_epochs,
                       task=task,
                       loss_func=loss_func,
                       num_gpus=num_gpus,
                       use_raytune=False,
                       enable_checkpointing=enable_checkpointing,
                       trial=trial,
                       **model_kwargs)

Plot training curves:

In [None]:
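def plot_training_curves(trainer_logs: str, experiment: str = '', figpath: str | None = None):
    """Plot per-epoch loss and accuracy curves from a Lightning CSVLogger directory."""
    metrics = pd.read_csv(os.path.join(trainer_logs, 'metrics.csv'))
    # Drop bookkeeping columns, ignoring them if they were never logged
    metrics = metrics.drop(columns=['step', 'val_hp_metric'], errors='ignore')
    # Train and val metrics are logged on separate rows: merge them per epoch
    metrics = metrics.set_index('epoch').groupby('epoch').max()
    train_cols = [
        'train_loss_epoch',
        'train_acc_epoch',
    ]
    val_cols = [
        'val_loss',
        'val_acc',
    ]
    test_cols = [
        'test_loss',
        'test_acc',
    ]
    # Keep only the logged columns (e.g., test metrics exist only after `trainer.test()`)
    train_cols = [c for c in train_cols if c in metrics.columns]
    val_cols = [c for c in val_cols if c in metrics.columns]
    test_cols = [c for c in test_cols if c in metrics.columns]
    # display(metrics[train_cols + val_cols + test_cols].dropna(axis=1, how='all'))
    print(metrics[train_cols + val_cols + test_cols].dropna(axis=1, how='all'))
    g = sns.relplot(data=metrics[train_cols + val_cols], kind='line')
    # Replace seaborn's figure-level legend with a standard one placed outside the axes
    g._legend.remove()
    plt.legend(loc='upper left', bbox_to_anchor=(1, 1), ncol=1) #, fancybox=True, shadow=True)
    plt.title(f'Training curves{experiment}')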
    plt.grid(alpha=0.7)
    if figpath is not None:
        plt.savefig(figpath + '.pdf', bbox_inches='tight')
        plt.savefig(figpath + '.png', bbox_inches='tight')
        plt.close()
    else:
        plt.show()
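The `groupby('epoch').max()` step works because Lightning's `CSVLogger` typically writes training-epoch and validation metrics on different rows of `metrics.csv`, leaving the other columns empty, and pandas skips missing values when aggregating. Below is a stdlib-only sketch of the same merge on made-up rows. `merge_rows_by_epoch` is hypothetical, and `None` plays the role of NaN:

```python
def merge_rows_by_epoch(rows):
    """Collapse rows sharing an epoch, keeping the max of the non-missing values per column."""
    merged = {}
    for row in rows:
        target = merged.setdefault(row['epoch'], {})
        for col, value in row.items():
            if col == 'epoch' or value is None:
                continue  # Missing entries (NaN in pandas) are skipped
            target[col] = value if col not in target else max(target[col], value)
    return merged


# Made-up CSVLogger-like rows: train and val metrics end up on separate rows
rows = [
    {'epoch': 0, 'train_loss_epoch': None, 'val_loss': 0.62},
    {'epoch': 0, 'train_loss_epoch': 0.70, 'val_loss': None},
    {'epoch': 1, 'train_loss_epoch': None, 'val_loss': 0.55},
    {'epoch': 1, 'train_loss_epoch': 0.58, 'val_loss': None},
]
merged = merge_rows_by_epoch(rows)
```

Taking the maximum is harmless here only because each epoch-level column holds at most one value per epoch; for step-level columns it would keep the largest value of the epoch.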

Evaluation function:

In [None]:
n_gpus = 1 if torch.cuda.is_available() else 0

def evaluate_experiment(dataset_name: str,
                        task: str,
                        descr: str = '',
                        plot_dummy: bool = False,
                        model_checkpoint: str | None = None,
                        plot_auc: bool = True,
                        phase: Literal['train', 'val', 'test'] = ['val', 'test'],
                        best_model: nn.Module | None = None) -> Tuple | None:
    """Evaluate a trained model on a specific dataset.

    Args:
        descr (str): Model description. Used for plotting.
        dataset_name (str): Dataset to retrieve via `get_datasets`.
        task (str): Task to evaluate.
        plot_dummy (bool, optional): Whether to plot the Dummy classifier. Defaults to False.
        model_checkpoint (str | None, optional): Model checkpoint to load. Defaults to None.
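        plot_auc (bool, optional): Whether to plot the ROC curve. Defaults to True.
        phase (str, optional): Dataset split to evaluate and plot, e.g., 'test'. NOTE: The plotting code below expects a single split name. Defaults to ['val', 'test'].
        best_model (nn.Module | None, optional): Model to evaluate. Defaults to None.

    Raises:
        ValueError: In case neither a model nor a checkpoint is provided.

    Returns:
        Tuple | None: For binary classification, the dictionary of evaluation results and the confusion matrix display. None for regression tasks.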
    """
    # Load model
    if best_model is None:
        if model_checkpoint is not None:
            best_model = WrapperModel.load_from_checkpoint(model_checkpoint)
        else:
            raise ValueError('Either best_model or model_checkpoint must be provided.')
    # Retrieve specific datasets and set up the train/val/test datasets to run evaluation
    # NOTE: `use_upsampled` is read from the global scope
    ds = get_datasets(task, use_upsampled, dataset_name=dataset_name)
    best_model.train_dataset = ds['train']
    best_model.val_dataset = ds['val']
    best_model.test_dataset = ds['test']
    # Get predictions
    preds = get_eval_results(best_model, task, phase=phase, num_gpus=n_gpus)
    # Plot results
    if task == 'predict_pDC50_and_Dmax':
        # TODO: Regression task
        pass
    elif task == 'predict_active_inactive':
        best_guess = 1.0 if phase == 'test' else 0.0
        bin_labels = preds[f'{phase}_labels'].flatten().astype(int)
        if plot_auc:
            # Plot Dummy classifier
            if plot_dummy:
                dummy_pred = torch.tensor([best_guess] * len(bin_labels))
                fpr, tpr, _ = sklearn.metrics.roc_curve(bin_labels, dummy_pred)
                auc = sklearn.metrics.roc_auc_score(bin_labels, dummy_pred)
                acc = binary_accuracy(dummy_pred, torch.tensor(bin_labels))
                dummy_descr = f'Dummy (AUC = {auc:.2f}, Accuracy = {acc * 100:.2f}%)'
                plt.plot(fpr, tpr, label=dummy_descr, color='black', lw=0.8, linestyle='--')
            # ROC-AUC curve plotting
            l = f'{descr} (AUC: {preds[f"{phase}_roc_auc"]:.2f}, Accuracy: {preds[f"{phase}_acc"] * 100:.2f}%)'
            plt.plot(preds[f"{phase}_fpr"], preds[f"{phase}_tpr"], label=l)
        return preds, ConfusionMatrixDisplay(preds[f'{phase}_confusion_matrix'])
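The dummy classifier plotted above outputs the same score (`best_guess`) for every sample. Its ROC "curve" is therefore the diagonal (AUC = 0.5), and its only informative number is accuracy, which equals the share of samples in the predicted class. Presumably `best_guess` reflects the majority class of each split, but that is an assumption. Below is a pure-Python sketch that picks the majority class from the labels instead; `dummy_baseline` is hypothetical:

```python
def dummy_baseline(labels, constant=None, threshold=0.5):
    """Accuracy and ROC AUC of a classifier that outputs the same score for every sample.

    If `constant` is None, the majority class of `labels` is used as the "best guess".
    """
    n_pos = sum(labels)
    n_neg = len(labels) - n_pos
    if constant is None:
        constant = 1.0 if n_pos >= n_neg else 0.0
    pred = int(constant >= threshold)
    accuracy = sum(int(pred == y) for y in labels) / len(labels)
    # A constant score ties every positive/negative pair: the ROC curve is the diagonal
    # from (0, 0) to (1, 1), hence AUC = 0.5 (undefined when only one class is present)
    auc = 0.5 if n_pos and n_neg else float('nan')
    return accuracy, auc


# Made-up labels with 3 actives out of 4: predicting "active" everywhere gives 75% accuracy
acc, auc = dummy_baseline([1, 1, 0, 1])
```

Any model worth keeping should beat this accuracy on the same split, not just an AUC of 0.5.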

### SMILES as Fingerprints - MLPs

Define the MLP-based fingerprint sub-model:

In [None]:
class FingerprintSubModel(pl.LightningModule):

    def __init__(self,
                 fp_type: Literal['morgan_fp', 'maccs_fp', 'path_fp'] = 'morgan_fp',
                 fp_bits: int = 1024,
                 hidden_channels: List[int] = [128, 128],
                 norm_layer: object = nn.BatchNorm1d,
                 dropout: float = 0.5):
        super().__init__()
        # Store init args as attributes (fp_bits is set below)
        self.fp_type = fp_type
        self.hidden_channels = hidden_channels
        self.norm_layer = norm_layer
        self.dropout = dropout
        self.save_hyperparameters()
        self.fp_bits = MACCS_BITWIDTH if fp_type == 'maccs_fp' else fp_bits
        # Define PyTorch model
        self.fp_encoder = MLP(in_channels=self.fp_bits,
                              hidden_channels=hidden_channels,
                              norm_layer=norm_layer,
                              inplace=False,
                              dropout=dropout)

    def forward(self, x_in):
        return self.fp_encoder(x_in[self.fp_type])

    def get_smiles_embedding_size(self):
        return self.hidden_channels[-1]
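
The sub-model feeds one fingerprint bit vector into a stack of fully connected layers, so its input width is the fingerprint length and its output (the SMILES embedding) has size `hidden_channels[-1]`. MACCS keys have a fixed length, which is why `fp_bits` is overridden for `maccs_fp`. A small sketch of the resulting layer shapes, assuming `MACCS_BITWIDTH = 167` (the length of RDKit's MACCS key vectors; the notebook's own constant is defined elsewhere) and one linear layer per entry of `hidden_channels`, as in torchvision's `MLP`:

```python
from typing import List, Tuple

MACCS_BITWIDTH = 167  # assumption: RDKit MACCS keys are 167-bit vectors


def mlp_linear_shapes(fp_type: str, fp_bits: int,
                      hidden_channels: List[int]) -> List[Tuple[int, int]]:
    """(in_features, out_features) of each Linear layer in the fingerprint MLP."""
    in_dim = MACCS_BITWIDTH if fp_type == 'maccs_fp' else fp_bits
    shapes = []
    for out_dim in hidden_channels:
        shapes.append((in_dim, out_dim))
        in_dim = out_dim  # each layer feeds the next one
    return shapes


print(mlp_linear_shapes('morgan_fp', 2048, [128, 128]))  # [(2048, 128), (128, 128)]
print(mlp_linear_shapes('maccs_fp', 2048, [256, 64]))    # [(167, 256), (256, 64)]
```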

Generate experiment-specific datasets:

In [None]:
tasks = [
    'predict_active_inactive',
    # 'predict_pDC50_and_Dmax',
    # 'predict_pDC50',
    ]
fp_bits_options = [1024, 2048, 4096]
upsampled = [False] # [True, False]
fp_max_paths = list(range(8, 11))
radii = list(range(2, 11))
experiments = (tasks, fp_bits_options, fp_max_paths, radii, upsampled)

for subset in tqdm(list(itertools.product(*experiments))):
    task, fp_bits, fp_max_path, radius, use_upsampled = subset
    protac_ds_kwargs = {
        'precompute_fingerprints': True,
        'use_morgan_fp': True,
        'use_maccs_fp': True,
        'use_path_fp': True,
        'morgan_atomic_radius': radius,
        'morgan_bits': fp_bits,
        'path_bits': fp_bits,
        'fp_min_path': 1,
        'fp_max_path': fp_max_path,
        'poi_gene_enc': poi_gene_enc,
        'poi_vectorizer': poi_encoder, # poi_vectorizer,
        'e3_ligase_enc': e3_encoder, # e3_ligase_enc,
        'cell_type_enc': cell_encoder, # cell_type_enc,
    }
    dataset_name = f'_fp{fp_bits}_radius{radius}_path1-{fp_max_path}'
    get_datasets(task,
                 use_upsampled,
                 dataset_name=dataset_name,
                 regenerate_datasets=False,
                 **protac_ds_kwargs)
    # print(f'Finished generating "protac-db{dataset_name}" datasets.')
print('Datasets are ready to use.')
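
The cell above precomputes one dataset per combination of fingerprint size, maximum path length and Morgan radius. A quick sketch enumerating the same design points (copied from the cell) shows how many dataset variants end up on disk and how they are named:

```python
import itertools

# Same design points as in the dataset-generation cell above
tasks = ['predict_active_inactive']
fp_bits_options = [1024, 2048, 4096]
upsampled = [False]
fp_max_paths = list(range(8, 11))  # 8, 9, 10
radii = list(range(2, 11))         # 2, ..., 10

names = [
    f'_fp{fp_bits}_radius{radius}_path1-{fp_max_path}'
    for _task, fp_bits, fp_max_path, radius, _ups in itertools.product(
        tasks, fp_bits_options, fp_max_paths, radii, upsampled)
]
print(len(names))  # 81
print(names[0])    # _fp1024_radius2_path1-8
```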

Define Optuna objective:

In [None]:
def fp_objective(trial,
                 fp_bits: int,
                 num_epochs: int = 10,
                 task: Literal['predict_active_inactive', 'predict_pDC50_and_Dmax'] = 'predict_active_inactive',
                 loss_func: Callable | object = nn.HuberLoss(),
                 enable_checkpointing: bool = True,
                 use_upsampled: bool = False,
                 num_gpus: int = 0) -> float:
    # ==========================================================================
    # Model-specific objective code
    # ==========================================================================
    fp_radius = trial.suggest_int('radius', 2, 10)
    fp_max_path = trial.suggest_int('fp_max_path', 8, 10)
    # Setup SMILES Encoder arguments
    num_layers = trial.suggest_int('num_layers', 2, 8)
    hidden_channels = [
        trial.suggest_int(f'smiles_enc_kwargs_layer_{i}_size', 32, 1024, step=32) for i in range(num_layers)
    ]
    smiles_encoder_gen_args = {
        'fp_type': trial.suggest_categorical('smiles_enc_kwargs_fp_type', ['morgan_fp', 'maccs_fp', 'path_fp']),
        'fp_bits': fp_bits,
        'hidden_channels': hidden_channels,
        'norm_layer': nn.BatchNorm1d,
        'dropout': trial.suggest_float('smiles_enc_kwargs_dropout', 0.1, 0.8),
    }
    smiles_encoder = FingerprintSubModel
    # Setup Wrapper Model arguments
    num_layers_extra = trial.suggest_int('num_layers_extra', 2, 8)
    hidden_channels_extra_features = [
        trial.suggest_int(f'model_kwargs_layer_{i}_size', 32, 1024, step=32) for i in range(num_layers_extra)
    ]
    model_kwargs = {
        'use_extra_features': True, # trial.suggest_categorical('use_extra_features', [True, False]),
        'hidden_channels_extra_features': hidden_channels_extra_features,
        'dropout': trial.suggest_float('model_kwargs_dropout', 0.1, 0.8),
        'learning_rate': trial.suggest_float('model_kwargs_learning_rate', 1e-5, 1e-2, log=True),
        'batch_size': trial.suggest_categorical('model_kwargs_batch_size', [16, 32, 64, 128]),
    }
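    # Retrieve specific datasets. MACCS keys have a fixed length, so MACCS
    # trials reuse the datasets generated with 1024-bit fingerprints.
    dataset_fp_bits = 1024 if smiles_encoder_gen_args['fp_type'] == 'maccs_fp' else fp_bits
    dataset_name = f'_fp{dataset_fp_bits}_radius{fp_radius}_path1-{fp_max_path}'
    # Model-specific namings for reporting (MACCS models report their actual bit width)
    model_name = 'fp_model'
    fp_bits = MACCS_BITWIDTH if smiles_encoder_gen_args['fp_type'] == 'maccs_fp' else fp_bits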
    eventid = f'{trial.datetime_start.strftime("%Y%m%d-%H-%M-%S-")}{uuid4()}'
    trial_name = f'{smiles_encoder_gen_args["fp_type"]}-{fp_bits}-{trial.number}-{eventid}'
    
    
    # ==========================================================================
    # Standard and common code for all objectives
    # ==========================================================================
    ds = get_datasets(task, use_upsampled, dataset_name=dataset_name)
    train_dataset = ds['train']
    val_dataset = ds['val']
    test_dataset = ds['test']
    trial.set_user_attr('dataset_name', dataset_name)
    # Standard namings for reporting
    lightning_dir = os.path.join(checkpoint_dir, 'lightning')
    model_checkpoint_dir = os.path.join(lightning_dir, 'models')
    model_checkpoint = f'{model_name}-{trial_name}'
    tensorboard_dir = os.path.join(lightning_dir, 'tensorboard', f'{model_name}_{task}')
    reporting_config = {
        'trial_name': trial_name,
        'model_name': model_name,
        'lightning_dir': lightning_dir,
        'model_checkpoint_dir': model_checkpoint_dir,
        'model_checkpoint': model_checkpoint,
        'tensorboard_dir': tensorboard_dir,
    }
    trial_set_dict(trial, reporting_config)
    # Train model via its generic function
    return train_model(reporting_config=reporting_config,
                       smiles_encoder=smiles_encoder,
                       smiles_encoder_args=smiles_encoder_gen_args,
                       train_dataset=train_dataset,
                       val_dataset=val_dataset,
                       test_dataset=test_dataset,
                       # Optional arguments
                       num_epochs=num_epochs,
                       task=task,
                       loss_func=loss_func,
                       num_gpus=num_gpus,
                       use_raytune=False,
                       enable_checkpointing=enable_checkpointing,
                       trial=trial,
                       **model_kwargs)
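
The dataset lookup in `fp_objective` depends on the sampled fingerprint type. A standalone sketch of that mapping, again assuming `MACCS_BITWIDTH = 167` (RDKit's MACCS length); `fp_dataset_name` and `fp_input_width` are hypothetical helpers that mirror the objective's logic:

```python
MACCS_BITWIDTH = 167  # assumption: length of RDKit MACCS key vectors


def fp_dataset_name(fp_type: str, fp_bits: int, radius: int, max_path: int) -> str:
    """Dataset name selected by fp_objective for a given trial."""
    # MACCS keys ignore fp_bits, so MACCS trials reuse the 1024-bit datasets
    stored_bits = 1024 if fp_type == 'maccs_fp' else fp_bits
    return f'_fp{stored_bits}_radius{radius}_path1-{max_path}'


def fp_input_width(fp_type: str, fp_bits: int) -> int:
    """Input width of the fingerprint MLP (what FingerprintSubModel uses)."""
    return MACCS_BITWIDTH if fp_type == 'maccs_fp' else fp_bits


# Two hypothetical trials from the 4096-bit experiment
print(fp_dataset_name('maccs_fp', 4096, radius=3, max_path=9), fp_input_width('maccs_fp', 4096))
# _fp1024_radius3_path1-9 167
print(fp_dataset_name('path_fp', 4096, radius=3, max_path=9), fp_input_width('path_fp', 4096))
# _fp4096_radius3_path1-9 4096
```

One consequence worth keeping in mind: the radius only affects Morgan fingerprints and the maximum path length only affects path fingerprints, so for the other fingerprint types the sampler still explores these two parameters even though they presumably do not change the model input.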

Run experiments:

* NOTE: Somehow, automagically, RayTune takes care of the PyTorch Lightning dataloaders, eventually creating a mess.
* NOTE: A batch size of 32 creates a last batch of size 1, which breaks the batch-norm layer (it cannot compute batch statistics from a single sample in training mode). Setting `drop_last=True` should help, but somehow RayTune fails anyway...
* NOTE: A batch size that is too large also breaks RayTune... Maybe because it is unable to reduce it to a size that does not exceed the actual dataset size (as PyTorch does by default).

That's why I abandoned RayTune for now and I'm using just Optuna instead.
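
The single-sample batch problem is plain arithmetic on the dataset size. A sketch of the batch sizes a sequential DataLoader yields per epoch, with and without `drop_last`, plus a filter for candidate batch sizes that never leave a lone sample; the training-set size of 321 is hypothetical and the helper names are illustrative:

```python
from typing import List


def batch_sizes(n_samples: int, batch_size: int, drop_last: bool = False) -> List[int]:
    """Sizes of the batches a sequential DataLoader would yield for one epoch."""
    full, rest = divmod(n_samples, batch_size)
    sizes = [batch_size] * full
    if rest and not drop_last:
        sizes.append(rest)  # incomplete last batch
    return sizes


def safe_batch_sizes(n_samples: int, candidates: List[int]) -> List[int]:
    """Candidate batch sizes that never produce a single-sample batch."""
    return [bs for bs in candidates if 1 not in batch_sizes(n_samples, bs)]


n_train = 321  # hypothetical training-set size with n_train % 32 == 1
print(batch_sizes(n_train, 32)[-1])                  # 1 -> BatchNorm1d fails in training mode
print(batch_sizes(n_train, 32, drop_last=True)[-1])  # 32
print(safe_batch_sizes(n_train, [16, 32, 64, 128]))  # [128]
```

Filtering the batch-size choices this way is an alternative to `drop_last=True` that keeps every training sample in each epoch.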

In [None]:
# Define experiments design points
tasks = [
    'predict_active_inactive',
    # 'predict_pDC50_and_Dmax',
    # 'predict_pDC50',
    ]
fp_bits_options = [1024, 2048, 4096]
upsampled = [False] # [True, False]
experiments = (tasks, fp_bits_options, upsampled)
# Get all experiments combinations
n_experiments = 0
for subset in itertools.product(*experiments):
    task, fp_bits, use_upsampled = subset  
    if use_upsampled and task != 'predict_active_inactive':
        continue
    n_experiments += 1
# Set fixed parameters
num_epochs = 15
num_samples = 20 # 1000
n_gpus = 1 if torch.cuda.is_available() else 0
# loss_func = mean_absolute_error
# loss_func = mean_squared_error
loss_func = nn.HuberLoss(reduction='mean', delta=0.8) # Default 1.0
# Define specific results dictionary in the global one
experiments_results['results_fp'] = load_result('results_fp')
if RETRAIN_FP_MODEL or not experiments_results['results_fp']:
    # Run experiments
    i = 0
    best_ckpt = [] # Used for removing non-optimal checkpoints
    for experiment_id in itertools.product(*experiments):
        task, fp_bits, use_upsampled = experiment_id
        if use_upsampled and task != 'predict_active_inactive':
            continue
        print(f'-' * 80)
        print(f'Experiment n.{i + 1} ({i / n_experiments * 100.0:.2f}% complete):')
        print(f'\ttask: {task}')
        print(f'\tfp_bits: {fp_bits}')
        print(f'\tuse_upsampled: {use_upsampled}')
        print(f'-' * 80)
        # Run Optuna study
        direction = 'maximize' if task == 'predict_active_inactive' else 'minimize'
        # optuna_pruner = optuna.pruners.MedianPruner(n_warmup_steps=10)
        optuna_pruner = optuna.pruners.HyperbandPruner(min_resource=2,
                                                       max_resource=num_epochs,
                                                       reduction_factor=3)
        optuna_sampler = optuna.samplers.TPESampler(seed=42)
        study = optuna.create_study(direction=direction,
                                    pruner=optuna_pruner,
                                    sampler=optuna_sampler)    
        study.optimize(lambda trial: fp_objective(trial,
                                                  task=task,
                                                  fp_bits=fp_bits,
                                                  num_epochs=num_epochs,
                                                  loss_func=loss_func,
                                                  enable_checkpointing=False,
                                                  use_upsampled=use_upsampled,
                                                  num_gpus=n_gpus),
                       n_trials=num_samples,
                       timeout=600 * 2)
        trial = study.best_trial
        experiments_results['results_fp'][experiment_id] = {}
        experiments_results['results_fp'][experiment_id]['trial'] = trial
        experiments_results['results_fp'][experiment_id]['fp_bits'] = fp_bits
        experiments_results['results_fp'][experiment_id]['task'] = task
        experiments_results['results_fp'][experiment_id]['use_upsampled'] = use_upsampled
        # ======================================================================
        # Retrain best model and store its checkpoint
        # ======================================================================
        # Setup SMILES Encoder arguments
        num_layers = trial.params['num_layers']
        hidden_channels = [
            trial.params[f'smiles_enc_kwargs_layer_{i}_size'] for i in range(num_layers)
        ]
        smiles_encoder_gen_args = {
            'fp_type': trial.params['smiles_enc_kwargs_fp_type'],
            'fp_bits': fp_bits,
            'hidden_channels': hidden_channels,
            'norm_layer': nn.BatchNorm1d,
            'dropout': trial.params['smiles_enc_kwargs_dropout'],
        }
        smiles_encoder = FingerprintSubModel
        smiles_embedding_size = hidden_channels[-1]
        # Setup Wrapper Model arguments
        num_layers_extra = trial.params['num_layers_extra']
        hidden_channels_extra_features = [
            trial.params[f'model_kwargs_layer_{i}_size'] for i in range(num_layers_extra)
        ]
        model_kwargs = {
            'use_extra_features': True,
            'hidden_channels_extra_features': hidden_channels_extra_features,
            'dropout': trial.params['model_kwargs_dropout'],
            'learning_rate': trial.params['model_kwargs_learning_rate'],
            'batch_size': trial.params['model_kwargs_batch_size'],
        }
        ds = get_datasets(task, use_upsampled, dataset_name=trial.user_attrs['dataset_name'])
        reporting_config = {
            'trial_name': trial.user_attrs['trial_name'],
            'model_name': trial.user_attrs['model_name'],
            'lightning_dir': trial.user_attrs['lightning_dir'],
            'model_checkpoint_dir': trial.user_attrs['model_checkpoint_dir'],
            'model_checkpoint': trial.user_attrs['model_checkpoint'],
            'tensorboard_dir': trial.user_attrs['tensorboard_dir'],
        }
        train_model(reporting_config=reporting_config,
                    smiles_encoder=smiles_encoder,
                    smiles_encoder_args=smiles_encoder_gen_args,
                    train_dataset=ds['train'],
                    val_dataset=ds['val'],
                    test_dataset=ds['test'],
                    # Optional arguments
                    num_epochs=num_epochs,
                    task=task,
                    loss_func=loss_func,
                    num_gpus=n_gpus,
                    use_raytune=False,
                    enable_checkpointing=True,
                    trial=trial,
                    **model_kwargs)
        # Reporting
        print('-' * 80)
        print(f'Experiment n.{i + 1} done ({(i + 1) / n_experiments * 100.0:.2f}% complete)')
        print('Number of finished trials: {}'.format(len(study.trials)))
        print(f'Best trial score: {trial.value}:')
        print_dict('Experiment:', experiments_results['results_fp'][experiment_id])
        print_dict('Params:', trial.params)
        print_dict('Attributes:', trial.user_attrs)
        i += 1
        # Remove non-optimal checkpoints
        model_name = trial.user_attrs['model_name']
        checkpoint_root_dir = trial.user_attrs['model_checkpoint_dir']
        best_ckpt.append(trial.user_attrs['trial_name'])
        del_non_optimal_ckpt(checkpoint_root_dir, best_ckpt, model_name)
        # Plotting training curves
        trainer_logs = trial.user_attrs['trainer_log_dir']
        descr = f' for {fp_bits}bit'
        figpath = os.path.join(fig_dir, f'training_curves_{task}_fp{fp_bits}')
        plot_training_curves(trainer_logs, experiment=descr, figpath=figpath)
    save_results(experiments_results['results_fp'], result_name='results_fp')
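
The `HyperbandPruner` above runs successive-halving brackets between `min_resource` and `max_resource` (here, epochs). As a rough, generic sketch of successive halving (not Optuna's exact internals, which also distribute trials across several brackets), trials are compared at rung budgets that grow geometrically by `reduction_factor`, and only about the top `1 / reduction_factor` of them keep training at each rung:

```python
from typing import List


def rung_budgets(min_resource: int, max_resource: int, reduction_factor: int) -> List[int]:
    """Epoch budgets at which a generic successive-halving bracket compares trials."""
    budgets = []
    budget = min_resource
    while budget <= max_resource:
        budgets.append(budget)
        budget *= reduction_factor
    return budgets


def survivors(n_trials: int, n_rungs: int, reduction_factor: int) -> List[int]:
    """Trials still alive at each rung if only the top 1/eta are promoted."""
    alive = [n_trials]
    for _ in range(n_rungs - 1):
        alive.append(max(1, alive[-1] // reduction_factor))
    return alive


print(rung_budgets(2, 15, 3))  # [2, 6]      (FP experiments: num_epochs = 15)
print(rung_budgets(2, 50, 3))  # [2, 6, 18]  (GNN experiments: num_epochs = 50)
print(survivors(20, 2, 3))     # [20, 6]
```

With only 15 epochs there are few comparison points, so pruning mostly acts early in training.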

Evaluation:

In [None]:
for phase in ['val', 'test']:
    confusion_matrices = {}
    plot_dummy = True

    for experiment_id, design_points in experiments_results['results_fp'].items():
        trial = design_points['trial']
        task = design_points['task']
        # Model-specific description
        fp_bits = design_points['fp_bits']
        use_upsampled = design_points['use_upsampled']
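        fp_type = trial.params['smiles_enc_kwargs_fp_type']
        # MACCS keys have a fixed width, regardless of the experiment's fp_bits
        n_bits = MACCS_BITWIDTH if fp_type == 'maccs_fp' else fp_bits
        descr = f'MLP [{fp_type.replace("_fp", "").upper()} {n_bits} bits]'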
        preds, cm = evaluate_experiment(task=task,
                                        descr=descr,
                                        dataset_name=trial.user_attrs['dataset_name'],
                                        model_checkpoint=trial.user_attrs['model_checkpoint'],
                                        plot_dummy=plot_dummy,
                                        phase=phase,
                                        plot_auc=True)
        experiments_results['results_fp'][experiment_id]['trial'].user_attrs.update(preds)
        save_results(experiments_results['results_fp'],
                     result_name='results_fp')
        confusion_matrices[experiment_id] = (cm, descr)
        plot_dummy = False
        print_dict(f'Evaluation results for {descr}:', preds)
        print('-' * 80)
    plt.grid(True, axis='both', alpha=0.7)
    plt.legend(loc='upper left', bbox_to_anchor=(1, 1), ncol=1, fancybox=True) #, shadow=True)
    plt.title(f'MLP {"Validation" if phase == "val" else "Test"} Set ROC Curve')

    filename = os.path.join(fig_dir, f'roc_curve_{phase}_fp')
    plt.savefig(filename + '.pdf', bbox_inches='tight')
    plt.savefig(filename + '.png', bbox_inches='tight')
    plt.show()
    plt.close()

    # Plot confusion matrices:
    for i, (_, (disp, descr)) in enumerate(confusion_matrices.items()):
        disp.plot(cmap=plt.cm.Blues)
        plt.title(f'{descr}')
        filename = os.path.join(fig_dir, f'confusion_matrix_{phase}_fp_n{i}')
        plt.savefig(filename + '.pdf', bbox_inches='tight')
        plt.savefig(filename + '.png', bbox_inches='tight')
        # plt.show()
        plt.close()
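
`ConfusionMatrixDisplay` renders the matrix returned by `get_eval_results`. For binary labels, scikit-learn's `confusion_matrix` orders rows by true label and columns by predicted label, giving `[[TN, FP], [FN, TP]]`. A dependency-free sketch of the same counting, on hypothetical labels and predictions:

```python
from typing import List, Sequence


def binary_confusion_matrix(y_true: Sequence[int], y_pred: Sequence[int]) -> List[List[int]]:
    """Rows = true class (0, 1), columns = predicted class (0, 1): [[TN, FP], [FN, TP]]."""
    cm = [[0, 0], [0, 0]]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm


y_true = [1, 0, 1, 1, 0, 0, 1]
y_pred = [1, 0, 0, 1, 1, 0, 1]
cm = binary_confusion_matrix(y_true, y_pred)
print(cm)  # [[2, 1], [1, 3]]
(tn, fp), (fn, tp) = cm
print((tp + tn) / len(y_true))  # accuracy from the matrix diagonal
```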

In [None]:
# %tensorboard --logdir {checkpoint_dir}/lightning/tensorboard/

### SMILES as Graphs - GNNs

#### Pre-made Models

Define the GNN-based sub-model:

In [None]:
class GnnSubModel(pl.LightningModule):

    def __init__(self,
                 model_type: Literal['gin', 'gat', 'gcn', 'attentivefp'] = 'gin',
                 hidden_channels: int = 32,
                 num_layers: int = 3,
                 out_channels: int = 8,
                 dropout: float = 0.1,
                 act: Literal['relu', 'elu'] = 'relu',
                 jk: Literal['max', 'last', 'cat', 'lstm'] = 'max',
                 norm: Literal['batch', 'layer'] = 'batch',
                 num_timesteps: int = 16):
        super().__init__()
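        # Store init args as attributes (self.model_type is used in forward)
        self.model_type = model_type
        self.hidden_channels = hidden_channels
        self.num_layers = num_layers
        self.out_channels = out_channels
        self.dropout = dropout
        self.act = act
        self.jk = jk
        self.norm = norm
        self.num_timesteps = num_timesteps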
        self.save_hyperparameters()
        self.smiles_embedding_size = out_channels
        if model_type == 'gin':
            self.smiles_embedding_size = hidden_channels
            self.gnn = geom_nn.models.GIN(in_channels=num_node_features,
                                          hidden_channels=hidden_channels,
                                          num_layers=num_layers,
                                          dropout=dropout,
                                          act=act,
                                          norm=norm,
                                          jk=jk)
        elif model_type == 'gat':
            self.gnn = geom_nn.models.GAT(in_channels=num_node_features,
                                          hidden_channels=hidden_channels,
                                          num_layers=num_layers,
                                          out_channels=out_channels,
                                          dropout=dropout,
                                          act=act,
                                          norm=norm,
                                          jk=jk)
        elif model_type == 'gcn':
            self.gnn = geom_nn.models.GCN(in_channels=num_node_features,
                                          hidden_channels=hidden_channels,
                                          num_layers=num_layers,
                                          out_channels=out_channels,
                                          dropout=dropout,
                                          act=act,
                                          norm=norm,
                                          jk=jk)
        elif model_type == 'attentivefp':
            self.gnn = geom_nn.models.AttentiveFP(in_channels=num_node_features,
                                                  hidden_channels=hidden_channels,
                                                  out_channels=out_channels,
                                                  edge_dim=node_edge_dim,
                                                  num_layers=num_layers,
                                                  num_timesteps=num_timesteps,
                                                  dropout=dropout)
        else:
            raise ValueError(f'Unknown model type: {model_type}. Available: gin, gat, gcn, attentivefp')
        
        
    def forward(self, batch):
        if self.model_type == 'gin':
            x = self.gnn(batch['smiles_graph'].x.to(torch.float),
                         batch['smiles_graph'].edge_index)
            smiles_emb = geom_nn.global_add_pool(x, batch['smiles_graph'].batch)
        elif self.model_type == 'gat':
            x = self.gnn(x=batch['smiles_graph'].x.to(torch.float),
                         edge_index=batch['smiles_graph'].edge_index,
                         edge_attr=batch['smiles_graph'].edge_attr)
            smiles_emb = geom_nn.global_add_pool(x, batch['smiles_graph'].batch)
        elif self.model_type == 'gcn':
            x = self.gnn(x=batch['smiles_graph'].x.to(torch.float),
                         edge_index=batch['smiles_graph'].edge_index,
                         edge_attr=batch['smiles_graph'].edge_attr)
            smiles_emb = geom_nn.global_add_pool(x, batch['smiles_graph'].batch)
        elif self.model_type == 'attentivefp':
            smiles_emb = self.gnn(batch['smiles_graph'].x.to(torch.float),
                                  batch['smiles_graph'].edge_index,
                                  batch['smiles_graph'].edge_attr,
                                  batch['smiles_graph'].batch)
        return smiles_emb
    
    def get_smiles_embedding_size(self):
        return self.smiles_embedding_size
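
For GIN, GAT and GCN, the sub-model turns node embeddings into one vector per molecule with `global_add_pool`, which sums the rows of `x` that share the same entry in the `batch` vector (AttentiveFP does its own graph-level readout, so it receives `batch` directly). A list-based sketch of that sum readout, on hypothetical values:

```python
from typing import List, Sequence


def global_add_pool(x: Sequence[Sequence[float]], batch: Sequence[int]) -> List[List[float]]:
    """Sum node feature rows per graph; batch[i] is the graph index of node i."""
    n_graphs = max(batch) + 1
    dim = len(x[0])
    pooled = [[0.0] * dim for _ in range(n_graphs)]
    for row, graph_idx in zip(x, batch):
        for j, value in enumerate(row):
            pooled[graph_idx][j] += value
    return pooled


# Two hypothetical molecules: nodes 0-2 belong to graph 0, nodes 3-4 to graph 1
x = [[1.0, 0.0], [2.0, 1.0], [0.5, 0.5], [3.0, 2.0], [1.0, 1.0]]
batch = [0, 0, 0, 1, 1]
print(global_add_pool(x, batch))  # [[3.5, 1.5], [4.0, 3.0]]
```

The pooled vector keeps the GNN output width, which is why `get_smiles_embedding_size` returns `hidden_channels` for GIN (no `out_channels` is passed) and `out_channels` otherwise. Note that sum pooling grows with the number of atoms, so larger molecules yield embeddings with larger magnitudes.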

Generate experiment-specific datasets:

In [None]:
tasks = [
    'predict_active_inactive',
    # 'predict_pDC50_and_Dmax',
    # 'predict_pDC50',
    ]
upsampled = [False] # [True, False]
experiments = (tasks, upsampled)

for subset in itertools.product(*experiments):
    task, use_upsampled = subset
    protac_ds_kwargs = {
        'precompute_smiles_as_graphs': True,
        'poi_vectorizer': poi_vectorizer,
        'e3_ligase_enc': e3_ligase_enc,
        'poi_gene_enc': poi_gene_enc,
        'cell_type_enc': cell_type_enc,
    }
    dataset_name = f'_{task}{"_upsampled" if use_upsampled else ""}_graph'
    get_datasets(task,
                 use_upsampled,
                 dataset_name=dataset_name,
                 regenerate_datasets=False,
                 **protac_ds_kwargs)
print('Datasets are ready to use.')

Define Optuna objective:

In [None]:
def gnn_objective(trial,
                  gnn_type: Literal['gin', 'gat', 'gcn', 'attentivefp'] = 'gin',
                  num_epochs: int = 10,
                  task: Literal['predict_active_inactive', 'predict_pDC50_and_Dmax'] = 'predict_active_inactive',
                  loss_func: Callable | object = nn.HuberLoss(),
                  enable_checkpointing: bool = True,
                  num_gpus: int = 0) -> float:
    # ==========================================================================
    # Model-specific objective code
    # ==========================================================================
    # Setup SMILES Encoder arguments
    trial.set_user_attr('gnn_type', gnn_type)
    layer_sizes = [64, 128, 256, 512, 768]
    smiles_encoder_gen_args = {
        'hidden_channels': trial.suggest_int('hidden_channels', 64, 768, step=64), # trial.suggest_categorical('hidden_channels', layer_sizes),
        'num_layers': trial.suggest_int('num_layers', 2, 8),
        'dropout': trial.suggest_float('dropout', 0.01, 0.8),
    }
    if gnn_type != 'gin':
        smiles_encoder_gen_args['out_channels'] = trial.suggest_categorical('out_channels', layer_sizes)    
    if gnn_type == 'attentivefp':
        smiles_encoder_gen_args['num_timesteps'] = trial.suggest_categorical('num_timesteps', [8, 16, 32] + layer_sizes)
    else:
        smiles_encoder_gen_args['jk'] = trial.suggest_categorical('jk', ['max', 'last', 'cat', 'lstm'])
    smiles_encoder_gen_args['model_type'] = gnn_type
    smiles_encoder = GnnSubModel
    # Setup Wrapper Model arguments
    num_layers_extra = trial.suggest_int('num_layers_extra', 2, 8)
    hidden_channels_extra_features = [
        trial.suggest_int(f'model_kwargs_layer_{i}_size', 64, 512, step=32) for i in range(num_layers_extra)
    ]
    model_kwargs = {
        'use_extra_features': True, # trial.suggest_categorical('use_extra_features', [True, False]),
        'hidden_channels_extra_features': hidden_channels_extra_features,
        'dropout': trial.suggest_float('model_kwargs_dropout', 0.01, 0.8),
        'learning_rate': trial.suggest_float('model_kwargs_learning_rate', 1e-5, 1e-2, log=True),
        'batch_size': trial.suggest_categorical('model_kwargs_batch_size', [4, 8]),
    }
    accumulate_grad_batches = trial.suggest_categorical('accumulate_grad_batches', [1, 2, 4, 8])
    # Retrieve specific datasets
    use_upsampled = False
    dataset_name = f'_{task}{"_upsampled" if use_upsampled else ""}_graph'
    trial.set_user_attr('dataset_name', dataset_name)
    ds = get_datasets(task, use_upsampled, dataset_name=dataset_name)
    train_dataset = ds['train']
    val_dataset = ds['val']
    test_dataset = ds['test']
    # Model-specific namings for reporting
    model_name = 'gnn_model'
    eventid = f'{trial.datetime_start.strftime("%Y%m%d-%H-%M-%S-")}{uuid4()}'
    trial_name = f'{gnn_type}-{trial.number}-{eventid}'
    # ==========================================================================
    # Standard and common code for all objectives
    # ==========================================================================
    # Standard namings for reporting
    lightning_dir = os.path.join(checkpoint_dir, 'lightning')
    model_checkpoint_dir = os.path.join(lightning_dir, 'models')
    model_checkpoint = f'{model_name}-{trial_name}'
    tensorboard_dir = os.path.join(lightning_dir, 'tensorboard', f'{model_name}_{task}')
    reporting_config = {
        'trial_name': trial_name,
        'model_name': model_name,
        'lightning_dir': lightning_dir,
        'model_checkpoint_dir': model_checkpoint_dir,
        'model_checkpoint': model_checkpoint,
        'tensorboard_dir': tensorboard_dir,
    }
    trial_set_dict(trial, reporting_config)
    # Train model via its generic function
    return train_model(reporting_config=reporting_config,
                       smiles_encoder=smiles_encoder,
                       smiles_encoder_args=smiles_encoder_gen_args,
                       train_dataset=train_dataset,
                       val_dataset=val_dataset,
                       test_dataset=test_dataset,
                       # Optional arguments
                       num_epochs=num_epochs,
                       task=task,
                       loss_func=loss_func,
                       num_gpus=num_gpus,
                       use_raytune=False,
                       enable_checkpointing=enable_checkpointing,
                       accumulate_grad_batches=accumulate_grad_batches,
                       trial=trial,
                       **model_kwargs)
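
The GNN objective samples small batches (4 or 8 graphs) together with `accumulate_grad_batches`, and Lightning accumulates gradients over that many batches before each optimizer step. The effective batch size seen by the optimizer is therefore the product of the two choices, as this enumeration of the search space shows:

```python
import itertools

batch_size_choices = [4, 8]       # 'model_kwargs_batch_size' choices
accumulate_choices = [1, 2, 4, 8]  # 'accumulate_grad_batches' choices

effective = sorted({bs * acc for bs, acc in itertools.product(batch_size_choices, accumulate_choices)})
print(effective)  # [4, 8, 16, 32, 64]
```

Gradient accumulation does not change the batch-norm statistics, though: with the default `norm='batch'`, they are still computed on micro-batches of 4 or 8 graphs.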

Run experiments:

In [None]:
# Define experiments design points
tasks = [
    'predict_active_inactive',
    # 'predict_pDC50_and_Dmax',
    # 'predict_pDC50',
    ]
upsampled = [False] # [True, False]
gnn_types = [
    'attentivefp',
    'gat',
    'gcn',
    'gin',
]
experiments = (tasks, upsampled, gnn_types)
# Get all experiments combinations
n_experiments = 0
for subset in itertools.product(*experiments):
    task, use_upsampled, gnn_type = subset
    if use_upsampled and task != 'predict_active_inactive':
        continue
    n_experiments += 1
# Set fixed parameters
num_epochs = 50
num_samples = 1000
n_gpus = 1 if torch.cuda.is_available() else 0
# loss_func = mean_absolute_error
# loss_func = mean_squared_error
loss_func = nn.HuberLoss(reduction='mean', delta=0.8) # Default 1.0
# Define specific results dictionary in the global one
experiments_results['results_gnn'] = load_result('results_gnn')
if RETRAIN_GNN_MODEL or not experiments_results['results_gnn']:
    # Run experiments
    pl.utilities.memory.garbage_collection_cuda()
    i = 0
    best_ckpt = []
    for experiment_id in itertools.product(*experiments):
        task, use_upsampled, gnn_type = experiment_id
        if use_upsampled and task != 'predict_active_inactive':
            continue
        print(f'-' * 80)
        print(f'Experiment n.{i + 1} ({i / n_experiments * 100.0:.2f}% complete):')
        experiments_results['results_gnn'][experiment_id] = {}
        experiments_results['results_gnn'][experiment_id]['task'] = task
        experiments_results['results_gnn'][experiment_id]['use_upsampled'] = use_upsampled
        experiments_results['results_gnn'][experiment_id]['gnn_type'] = gnn_type
        print_dict('Experiment:', experiments_results['results_gnn'][experiment_id])
        print(f'-' * 80)
        # Run Optuna study
        direction = 'maximize' if task == 'predict_active_inactive' else 'minimize'
        # optuna_pruner = optuna.pruners.MedianPruner(n_warmup_steps=10)
        optuna_pruner = optuna.pruners.HyperbandPruner(min_resource=2,
                                                       max_resource=num_epochs,
                                                       reduction_factor=3)
        optuna_sampler = optuna.samplers.TPESampler(seed=42)
        study = optuna.create_study(direction=direction,
                                    pruner=optuna_pruner,
                                    sampler=optuna_sampler)
        study.optimize(lambda trial: gnn_objective(trial,
                                                   task=task,
                                                   gnn_type=gnn_type,
                                                   num_epochs=num_epochs,
                                                   loss_func=loss_func,
                                                   enable_checkpointing=True,
                                                   num_gpus=n_gpus),
                       n_trials=num_samples,
                       timeout=600)
        trial = study.best_trial
        experiments_results['results_gnn'][experiment_id]['trial'] = trial
        # Reporting
        print('-' * 80)
        print(f'Experiment n.{i + 1} done ({(i + 1) / n_experiments * 100.0:.2f}% complete)')
        print('Number of finished trials: {}'.format(len(study.trials)))
        print(f'Best trial score: {trial.value}:')
        print_dict('Experiment:', experiments_results['results_gnn'][experiment_id])
        print_dict('Params:', trial.params)
        print_dict('Attributes:', trial.user_attrs)
        # Remove non-optimal checkpoints
        model_name = trial.user_attrs['model_name']
        checkpoint_root_dir = trial.user_attrs['model_checkpoint_dir']
        best_ckpt.append(trial.user_attrs['trial_name'])
        del_non_optimal_ckpt(checkpoint_root_dir, best_ckpt, model_name)
        # Plotting training curves
        trainer_logs = trial.user_attrs['trainer_log_dir']
        descr = f' for GNN ({gnn_type.upper()})'
        figpath = os.path.join(fig_dir, f'training_curves_{task}_{gnn_type}')
        plot_training_curves(trainer_logs, experiment=descr, figpath=figpath)
        i += 1
    save_results(experiments_results['results_gnn'], result_name='results_gnn')
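
`del_non_optimal_ckpt` is defined elsewhere in the notebook. A hypothetical stand-in illustrating the idea, assuming checkpoint files are named after `f'{model_name}-{trial_name}'` (as in the objectives) and that any file of that model whose name contains none of the kept trial names can go; `remove_non_optimal_checkpoints` and the `.ckpt` file names are illustrative only:

```python
import os
import tempfile
from typing import List


def remove_non_optimal_checkpoints(ckpt_dir: str, keep: List[str], model_name: str) -> List[str]:
    """Delete '<model_name>-*' files whose name contains none of the `keep` trial names."""
    removed = []
    for fname in sorted(os.listdir(ckpt_dir)):
        if not fname.startswith(f'{model_name}-'):
            continue  # leave other models' checkpoints untouched
        if any(trial_name in fname for trial_name in keep):
            continue  # best trial of some experiment: keep it
        os.remove(os.path.join(ckpt_dir, fname))
        removed.append(fname)
    return removed


with tempfile.TemporaryDirectory() as ckpt_dir:
    for name in ['gnn_model-gin-3-a.ckpt', 'gnn_model-gin-7-b.ckpt', 'fp_model-x.ckpt']:
        with open(os.path.join(ckpt_dir, name), 'w'):
            pass  # create an empty placeholder file
    removed = remove_non_optimal_checkpoints(ckpt_dir, keep=['gin-7-b'], model_name='gnn_model')
    remaining = sorted(os.listdir(ckpt_dir))
print(removed)    # ['gnn_model-gin-3-a.ckpt']
print(remaining)  # ['fp_model-x.ckpt', 'gnn_model-gin-7-b.ckpt']
```

Substring matching is safe here only because the real trial names embed a UUID; with plain trial numbers, `gin-1` would also match `gin-10`.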

Evaluation:

In [None]:
for phase in ['val', 'test']:
    confusion_matrices = {}
    plot_dummy = True

    for experiment_id, design_points in experiments_results['results_gnn'].items():
        trial = design_points['trial']
        task = design_points['task']
        # Model-specific description
        gnn_type = trial.user_attrs["gnn_type"]
        descr = f'GNN ({gnn_type[0].upper()}{gnn_type[1:]})'
        preds, cm = evaluate_experiment(task=task,
                                        descr=descr,
                                        dataset_name=trial.user_attrs['dataset_name'],
                                        model_checkpoint=trial.user_attrs['model_checkpoint'],
                                        plot_dummy=plot_dummy,
                                        phase=phase,
                                        plot_auc=True)
        experiments_results['results_gnn'][experiment_id]['trial'].user_attrs.update(preds)
        save_results(experiments_results['results_gnn'],
                     result_name='results_gnn')
        confusion_matrices[experiment_id] = (cm, descr)
        plot_dummy = False
        print_dict(f'Evaluation results for {descr}:', preds)
        print('-' * 80)
    plt.grid(True, axis='both', alpha=0.7)
    plt.legend(loc='upper left', bbox_to_anchor=(1, 1), ncol=1, fancybox=True) #, shadow=True)
    plt.title(f'GNN {"Validation" if phase == "val" else "Test"} Set ROC Curve')

    filename = os.path.join(fig_dir, f'roc_curve_{phase}_gnn')
    plt.savefig(filename + '.pdf', bbox_inches='tight')
    plt.savefig(filename + '.png', bbox_inches='tight')
    plt.show()
    plt.close()

    # Plot confusion matrices:
    for i, (_, (disp, descr)) in enumerate(confusion_matrices.items()):
        disp.plot(cmap=plt.cm.Blues)
        plt.title(f'{descr}')
        filename = os.path.join(fig_dir, f'confusion_matrix_{phase}_gnn_n{i}')
        plt.savefig(filename + '.pdf', bbox_inches='tight')
        plt.savefig(filename + '.png', bbox_inches='tight')
        # plt.show()
        plt.close()

### SMILES as Sentences - Transformers

#### SSL via Finetuning MLM

In [None]:
ssl_df = pd.read_csv(os.path.join(data_dir, 'protac', 'protac-db_ssl.csv'))
ssl_df = ssl_df.dropna(subset=['Smiles_nostereo'])
print(f'SSL dataframe shape: {ssl_df.shape}')
print(f'Training set length: {len(train_bin_df)}')
# Add the training PROTACs to the SSL corpus, then drop duplicated SMILES
ssl_df = pd.concat([ssl_df, train_bin_df], axis=0)
print(f'Length of SSL dataframe before removing SMILES duplicates: {len(ssl_df)}')
ssl_df = ssl_df.drop_duplicates(subset=['Smiles_nostereo'])
print(f'Length of SSL dataframe after removing SMILES duplicates: {len(ssl_df)}')
print(ssl_df.shape)

Import the Hugging Face classes and define the pretrained BERT models to fine-tune:

In [None]:
from transformers import (
    AutoTokenizer,
    AutoModelForMaskedLM,
    TrainingArguments, Trainer,
    DataCollatorForLanguageModeling,
    RobertaTokenizerFast,
    RobertaForMaskedLM,    
)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

PRETRAINED_NLP_MODEL = 'seyonec/ChemBERTa-zinc-base-v1'
# PRETRAINED_NLP_MODEL = 'DeepChem/ChemBERTa-10M-MTR'
CHEMBERT_MLM_FOR_PROTACS = os.path.join(checkpoint_dir, 'chembert_mlm_for_protacs_' + PRETRAINED_NLP_MODEL.split('/')[-1])

pretrained_bert_models = [
    'entropy/roberta_zinc_480m',
    'seyonec/ChemBERTa-zinc-base-v1',
    'DeepChem/ChemBERTa-10M-MTR',
]
ssl_bert_models = ['SSL_' + b.split('/')[-1] for b in pretrained_bert_models]

##### Perplexity Score

Refer to this [StackOverflow question](https://stackoverflow.com/questions/70464428/how-to-calculate-perplexity-of-a-sentence-using-huggingface-masked-language-mode).

> From the huggingface documentation [here](https://huggingface.co/docs/transformers/perplexity) they mentioned that perplexity "is not well defined for masked language models like BERT".
>
> There is a paper [Masked Language Model Scoring](https://arxiv.org/abs/1910.14659) that explores pseudo-perplexity from masked language models and shows that pseudo-perplexity, while not being theoretically well justified, still performs well for comparing "naturalness" of texts.
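The `perplexity_score` cells that follow compute pseudo-perplexity by building one copy of the tokenized SMILES per maskable position, each with exactly one token replaced by the mask token, and setting every other label to -100 so the loss only covers the masked position. A minimal pure-Python sketch of that masking step is shown here; the helper name `one_token_masked_copies`, the token ids, and the `mask_id` value are hypothetical, and it assumes (as the notebook code does) that the sequence starts and ends with exactly one special token.

```python
def one_token_masked_copies(token_ids, mask_id, ignore_index=-100):
    """Return (masked_inputs, labels) with one masked copy per inner token.

    The first and last positions are assumed to be special tokens
    (e.g. <s> and </s>) and are never masked.
    """
    masked_inputs, labels = [], []
    for pos in range(1, len(token_ids) - 1):
        masked = list(token_ids)
        masked[pos] = mask_id
        # Only the masked position keeps its true token id as label;
        # all other positions are ignored by the cross-entropy loss
        label = [ignore_index] * len(token_ids)
        label[pos] = token_ids[pos]
        masked_inputs.append(masked)
        labels.append(label)
    return masked_inputs, labels


# Hypothetical token ids: 0 = <s>, 2 = </s>, 4 = <mask>
inputs, labels = one_token_masked_copies([0, 11, 12, 13, 2], mask_id=4)
for row_in, row_lab in zip(inputs, labels):
    print(row_in, row_lab)
```

Running the model once on all copies as a batch and reading the mean loss gives the average negative log-likelihood of each token given the rest of the SMILES, which is what the notebook exponentiates.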

In [None]:
# device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# def perplexity_score(model, tokenizer, sentence):
#     tensor_input = tokenizer.encode(sentence, return_tensors='pt')
#     print(f'tensor_input: {tensor_input}')
#     repeat_input = tensor_input.repeat(tensor_input.size(-1) - 2, 1)
#     mask = torch.ones(tensor_input.size(-1) - 1).diag(1)[:-2]
#     masked_input = repeat_input.masked_fill(mask == 1, tokenizer.mask_token_id)
#     labels = repeat_input.masked_fill(masked_input != tokenizer.mask_token_id, -100)
    
#     masked_input = masked_input.to(device)
#     labels = labels.to(device)
#     with torch.inference_mode():
#         loss = model(masked_input, labels=labels).loss
#     return np.exp(loss.item())

Pseudo-perplexity over an entire dataset, obtained by averaging the per-SMILES losses before exponentiating:

> Typically, averaging occurs before exponentiation (which corresponds to the geometric average of exponentiated losses). The rationale is that we consider individual sentences as statistically independent, and so their joint probability is the product of their individual probability. Thus, by computing the geometric average of individual perplexities, we in some sense spread this joint probability evenly across sentences.
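Because perplexity is the exponential of a mean log-loss, averaging the per-sentence losses first and exponentiating once yields exactly the geometric mean of the per-sentence perplexities. This is what the `perplexity_score` function in the next cell does, with each SMILES weighted equally regardless of its token count. A small numeric check with made-up loss values (the helper names are illustrative only):

```python
import math


def corpus_pseudo_perplexity(sentence_losses):
    """Exponentiate the mean of per-sentence mean token losses (in nats)."""
    return math.exp(sum(sentence_losses) / len(sentence_losses))


def geometric_mean(values):
    """Geometric mean computed in log space for numerical stability."""
    return math.exp(sum(math.log(v) for v in values) / len(values))


# Hypothetical per-sentence mean masked-token losses
losses = [0.5, 1.0, 2.0]
per_sentence_ppl = [math.exp(loss) for loss in losses]

ppl = corpus_pseudo_perplexity(losses)
print(f'exp(mean loss):          {ppl:.4f}')
print(f'geometric mean of PPLs:  {geometric_mean(per_sentence_ppl):.4f}')
print(f'arithmetic mean of PPLs: {sum(per_sentence_ppl) / len(per_sentence_ppl):.4f}')
```

The first two printed values coincide, while the arithmetic mean of the perplexities is larger, since a few hard sentences with large perplexities dominate an arithmetic average.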

In [None]:
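def perplexity_score(model, tokenizer, dataset):
    """Pseudo-perplexity of a masked language model over a SMILES dataset.

    Each SMILES is scored by masking its tokens one at a time (the first and
    last special tokens are never masked); the per-SMILES mean losses are
    averaged over the dataset and then exponentiated.
    """
    total_loss = 0.0
    for elem in dataset:
        tensor_input = tokenizer.encode(elem['smiles'], return_tensors='pt')
        # One copy of the sequence per maskable (non-special) token
        repeat_input = tensor_input.repeat(tensor_input.size(-1) - 2, 1)
        # Row i masks token i + 1 (a shifted identity matrix)
        mask = torch.ones(tensor_input.size(-1) - 1).diag(1)[:-2]
        masked_input = repeat_input.masked_fill(mask == 1, tokenizer.mask_token_id)
        # Ignore (-100) every position except the masked one
        labels = repeat_input.masked_fill(masked_input != tokenizer.mask_token_id, -100)
        masked_input = masked_input.to(device)
        labels = labels.to(device)
        with torch.inference_mode():
            total_loss += model(masked_input, labels=labels).loss.item()
    return np.exp(total_loss / len(dataset))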

In [None]:
# from datasets import load_metric
# import evaluate

# # metric = evaluate.load('perplexity')
# metric = load_metric('perplexity')
# metric.compute(model_id=PRETRAINED_NLP_MODEL, input_texts=test_dataset)['mean_perplexity']

##### SSL Training

In [None]:
experiments_results['results_transformer_ssl'] = load_result('results_transformer_ssl')
# experiments_results['results_transformer_ssl'] = None

if RETRAIN_SSL_MODEL or experiments_results['results_transformer_ssl'] is None:
    experiments_results['results_transformer_ssl'] = {}
    for bert_model, ssl_bert_model in zip(pretrained_bert_models, ssl_bert_models):
        chembert_mlm_for_protacs = os.path.join(checkpoint_dir, ssl_bert_model)
        # if os.path.exists(chembert_mlm_for_protacs):
        #     continue

        # Load pretrained model and tokenizer
        if bert_model == 'entropy/roberta_zinc_480m':
            tokenizer = RobertaTokenizerFast.from_pretrained(bert_model,
                                                             max_len=128)
            model = RobertaForMaskedLM.from_pretrained(bert_model)
        else:
            tokenizer = AutoTokenizer.from_pretrained(bert_model)
            model = AutoModelForMaskedLM.from_pretrained(bert_model,
                                                         output_hidden_states=True)
        # Create a data collator (used in Trainer)
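        # RoBERTa-style tokenizers already define a <pad> token; fall back to
        # the EOS token only if a padding token is missing
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token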
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer,
                                                        mlm_probability=0.15)
        # Generate PROTAC-Datasets
        ssl_df['active'] = None
        ssl_dataset = ProtacDataset(ssl_df,
                                    use_for_ssl=True,
                                    smiles_tokenizer=tokenizer)
        val_dataset = ProtacDataset(val_df,
                                    use_for_ssl=True,
                                    smiles_tokenizer=tokenizer)
        # Move the model to the selected device
        model.to(device)
        # Setup the Trainer
        training_args = TrainingArguments(
            output_dir=chembert_mlm_for_protacs,
            evaluation_strategy='epoch',
            learning_rate=2e-5,
            num_train_epochs=5,
            weight_decay=0.01,
            optim='adamw_torch',
            gradient_accumulation_steps=4,
            per_device_train_batch_size=32,
            per_device_eval_batch_size=16,
            log_level='info',
            logging_strategy='steps',
            logging_steps=20,
            push_to_hub=False,
            fp16=torch.cuda.is_available(),
            report_to='all',
            seed=42,
        )
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=ssl_dataset,
            eval_dataset=val_dataset,
            data_collator=data_collator,
        )
        # Evaluate model BEFORE training
        model.eval()
        eval_results = trainer.evaluate()
        train_perplexity = perplexity_score(model, tokenizer, ProtacDataset(train_df, use_for_ssl=True))
        val_perplexity = perplexity_score(model, tokenizer, ProtacDataset(val_df, use_for_ssl=True))
        print(f"{bert_model} MLM Perplexity BEFORE training (Huggingface): {math.exp(eval_results['eval_loss']):.2f}")
        print(f'{bert_model} MLM perplexity on train BEFORE training: {train_perplexity:.2f}')
        print(f'{bert_model} MLM perplexity on validation BEFORE training: {val_perplexity:.2f}')
        experiments_results['results_transformer_ssl'][ssl_bert_model] = {
            'perplexity_huggingface_before': math.exp(eval_results['eval_loss']),
            'train_perplexity_before': train_perplexity,
            'val_perplexity_before': val_perplexity,
        }
        # Train the model and save it
        pl.utilities.memory.garbage_collection_cuda()
        model.train()
        trainer.train()
        trainer.save_model()
        tokenizer.save_pretrained(chembert_mlm_for_protacs)
        # Evaluate model AFTER training
        model.eval()
        eval_results = trainer.evaluate()
        train_perplexity = perplexity_score(model, tokenizer, ProtacDataset(train_df, use_for_ssl=True))
        val_perplexity = perplexity_score(model, tokenizer, ProtacDataset(val_df, use_for_ssl=True))
        tmp = {
            'perplexity_huggingface_after': math.exp(eval_results['eval_loss']),
            'train_perplexity_after': train_perplexity,
            'val_perplexity_after': val_perplexity,
        }
        experiments_results['results_transformer_ssl'][ssl_bert_model].update(tmp)
        print_dict('Training scores:', eval_results)
        print(f"{bert_model} MLM Perplexity AFTER training (Huggingface): {math.exp(eval_results['eval_loss']):.2f}")
        print(f'{bert_model} MLM perplexity on train dataset AFTER training: {train_perplexity:.2f}')
        print(f'{bert_model} MLM perplexity on validation dataset AFTER training: {val_perplexity:.2f}')
        print_dict('Perplexity:', experiments_results['results_transformer_ssl'][ssl_bert_model])
        print('-' * 80)
    save_results(experiments_results['results_transformer_ssl'], result_name='results_transformer_ssl')
else:
    print_dict('SSL Results:', experiments_results['results_transformer_ssl'])

In [None]:
for k, v in load_result('results_transformer_ssl').items():
    print_dict(k, v)

#### Train BERT-based Model

Generate experiment-specific datasets:

In [None]:
for task in ['predict_active_inactive']: # 'predict_pDC50_and_Dmax',
    for bert_model, ssl_bert_model in zip(pretrained_bert_models, ssl_bert_models):
        # Generate PROTAC-Datasets for pretrained BERT model
        if bert_model == 'entropy/roberta_zinc_480m':
            tokenizer = RobertaTokenizerFast.from_pretrained(bert_model,
                                                             max_len=128)
        else:
            tokenizer = AutoTokenizer.from_pretrained(bert_model)
        print(f'Tokenizer: {tokenizer}')
        dataset_name = f'_tokenized_{bert_model.split("/")[-1]}'
        protac_ds_kwargs = {
            'poi_vectorizer': poi_vectorizer,
            'e3_ligase_enc': e3_ligase_enc,
            'poi_gene_enc': poi_gene_enc,
            'cell_type_enc': cell_type_enc,
        }
        if task == 'predict_active_inactive':
            for upsampled in [False]: # [False, True]
                protac_ds_kwargs['smiles_tokenizer'] = tokenizer
                ds = get_datasets(task,
                                  upsampled,
                                  dataset_name=dataset_name,
                                  regenerate_datasets=True,
                                  **protac_ds_kwargs)
                dl = DataLoader(ds['train'], batch_size=8, collate_fn=custom_collate, drop_last=True)
                # batch = next(iter(dl))
                # for k, v in batch.items():
                #     if k == 'smiles_tokenized':
                #         # print_dict(k, v)
                #         for k, v in v.items():
                #             print(k, v.size())
                #     else:
                #         print(k, v.size())
                # print('')
        else:
            protac_ds_kwargs['smiles_tokenizer'] = tokenizer
            ds = get_datasets(task,
                              dataset_name=dataset_name,
                              regenerate_datasets=True,
                              **protac_ds_kwargs)
        # Generate PROTAC-Datasets for SSL-ed BERT model
        chembert_mlm_for_protacs = os.path.join(checkpoint_dir, ssl_bert_model)
        tokenizer = AutoTokenizer.from_pretrained(chembert_mlm_for_protacs)
        print(f'Tokenizer (from SSL): {tokenizer}')
        
        dataset_name = f'_tokenized_{ssl_bert_model}'
        protac_ds_kwargs['smiles_tokenizer'] = tokenizer
        if task == 'predict_active_inactive':
            for upsampled in [False]: # [False, True]
                ds = get_datasets(task,
                                  upsampled,
                                  dataset_name=dataset_name,
                                  regenerate_datasets=True,
                                  **protac_ds_kwargs)
                dl = DataLoader(ds['train'], batch_size=1, collate_fn=custom_collate, drop_last=True)
                # batch = next(iter(dl))
                # for k, v in batch.items():
                #     if k == 'smiles_tokenized':
                #         # print_dict(k, v)
                #         for k, v in v.items():
                #             print(k, v.size())
                #     else:
                #         print(k, v.size())
                # print('')
        else:
            get_datasets(task,
                         dataset_name=dataset_name,
                         regenerate_datasets=True,
                         **protac_ds_kwargs)
print('Datasets ready to use.')    

Define the mean-pooling helper and the ChemBERTa-based SMILES encoder:

(Check this [question](https://datascience.stackexchange.com/questions/107212/get-sentence-embeddings-of-transformer-based-models) for the intuition and implementation behind `mean_pooling()`, which collapses the per-token embeddings produced by the RoBERTa model into a single SMILES embedding.)
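`mean_pooling()` multiplies each token embedding by its attention-mask value, sums over the sequence, and divides by the number of real tokens, so padding positions do not shift the SMILES embedding. The same arithmetic for a single sequence, written in plain Python for readability (the notebook version works on batched tensors; the helper name `masked_mean_pooling` and the toy embeddings are hypothetical):

```python
def masked_mean_pooling(token_embeddings, attention_mask, eps=1e-9):
    """Average token embeddings, ignoring positions where attention_mask == 0.

    token_embeddings: list of seq_len vectors, each a list of hidden floats
    attention_mask: list of seq_len ints (1 = real token, 0 = padding)
    """
    hidden = len(token_embeddings[0])
    summed = [0.0] * hidden
    for emb, m in zip(token_embeddings, attention_mask):
        for j in range(hidden):
            summed[j] += emb[j] * m  # padding tokens contribute zero
    # Clamp the token count, like torch.clamp(..., min=1e-9) in mean_pooling()
    count = max(sum(attention_mask), eps)
    return [s / count for s in summed]


# Two real tokens and one padding token (hypothetical 2-d embeddings)
embeddings = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
mask = [1, 1, 0]
print(masked_mean_pooling(embeddings, mask))  # → [2.0, 3.0]
```

The large values of the padding token do not leak into the result; without the mask, the plain mean would be pulled towards them.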

In [None]:
from transformers import AutoConfig, AutoModelForSequenceClassification

def mean_pooling(model_output, attention_mask):
    # First element of model_output contains all token embeddings
    token_embeddings = model_output['last_hidden_state']
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask


class TransformerSubModel(pl.LightningModule):

    def __init__(self, checkpoint_path: str = 'seyonec/ChemBERTa-zinc-base-v1'):
        super().__init__()
        # Save the arguments passed to init
        self.save_hyperparameters()
        self.checkpoint_path = checkpoint_path  # Keep the init argument as an attribute
        # ChemBERT for SMILES
        self.config = AutoConfig.from_pretrained(checkpoint_path,
                                                 output_hidden_states=True,
                                                 num_labels=1)
        self.chembert = AutoModelForSequenceClassification.from_pretrained(
            checkpoint_path,
            config=self.config
        ).roberta

    def forward(self, x_in):
        # Run ChemBERTa over the tokenized SMILES
        input_ids = x_in['smiles_tokenized']['input_ids'].squeeze(dim=1)
        attention_mask = x_in['smiles_tokenized']['attention_mask'].squeeze(dim=1)
        smiles_embedding = self.chembert(input_ids, attention_mask)
        # NOTE: The Transformer encoder outputs a sequence of hidden states,
        # one contextual embedding per input token. The following averages the
        # non-padding token embeddings to get a single SMILES embedding.
        smiles_embedding = mean_pooling(smiles_embedding, attention_mask)
        return smiles_embedding
    
    def get_smiles_embedding_size(self):
        return self.config.to_dict()['hidden_size']

Define Optuna objective:

In [None]:
def train_bert_model(config,
                     trial_name,
                     trial = None,
                     enable_checkpointing: bool = False,
                     bert_model: str = 'seyonec/ChemBERTa-zinc-base-v1',
                     num_epochs: int = 5,
                     task: Literal['predict_active_inactive', 'predict_pDC50_and_Dmax'] = 'predict_active_inactive',
                     loss_func: Callable | object = nn.HuberLoss(),
                     num_gpus: int = 0):
    # Namings for reporting
    model_name = 'transformer_model'
    lightning_dir = os.path.join(checkpoint_dir, 'lightning')
    model_checkpoint_dir = os.path.join(lightning_dir, 'models')
    model_checkpoint = f'{model_name}-{trial_name}'
    tensorboard_dir = os.path.join(lightning_dir, 'tensorboard', f'{model_name}_{task}')
    reporting_config = {
        'trial_name': trial_name,
        'lightning_dir': lightning_dir,
        'model_name': model_name,
        'model_checkpoint_dir': model_checkpoint_dir,
        'model_checkpoint': model_checkpoint,
        'tensorboard_dir': tensorboard_dir,
    }
    if trial is not None:
        trial_set_dict(trial, reporting_config)
    # Setup arguments for Wrapper model
    model_kwargs = {
        'freeze_smiles_encoder': config['freeze_smiles_encoder'],
        'use_extra_features': config['use_extra_features'],
        'hidden_channels_extra_features': config['hidden_channels_extra_features'],
        'dropout': config['dropout'],
        'learning_rate': config['learning_rate'],
        'batch_size': config['batch_size'],
    }
    # Setup arguments for SMILES Encoder generator function
    bert_model_path = os.path.join(checkpoint_dir, bert_model)
    if os.path.exists(bert_model_path):
        generator_args = {'checkpoint_path': bert_model_path}
    else:
        generator_args = {'checkpoint_path': bert_model}
    # Get specific tokenized datasets for current BERT model
    # NOTE: `use_upsampled` is read from the global scope (set by the experiments loop)
    bert_model_stripped = bert_model.split('/')[-1]
    dataset_name = f'_tokenized_{bert_model_stripped}'
    ds = get_datasets(task, use_upsampled, dataset_name=dataset_name)
    # Train model via its generic function
    return train_model(reporting_config=reporting_config,
                       smiles_encoder=TransformerSubModel,
                       smiles_encoder_args=generator_args,
                       train_dataset=ds['train'],
                       val_dataset=ds['val'],
                       test_dataset=ds['test'],
                       num_epochs=num_epochs,
                       task=task,
                       loss_func=loss_func,
                       num_gpus=num_gpus,
                       use_raytune=False,
                       trial=trial,
                       accumulate_grad_batches=config['accumulate_grad_batches'],
                       enable_checkpointing=enable_checkpointing,
                       **model_kwargs)
    

def bert_objective(trial,
                   bert_model: str = 'seyonec/ChemBERTa-zinc-base-v1',
                   num_epochs: int = 5,
                   task: Literal['predict_active_inactive', 'predict_pDC50_and_Dmax'] = 'predict_active_inactive',
                   loss_func: Callable | object = nn.HuberLoss(),
                   num_gpus: int = 0):
    # Setup Wrapper Model arguments
    num_layers_extra = trial.suggest_int('num_layers_extra', 2, 8)
    hidden_channels_extra_features = [
        trial.suggest_int(f'layer_{i}_size', 64, 512, step=32) for i in range(num_layers_extra)
    ]
    config = {
        'use_extra_features': True, # trial.suggest_categorical('use_extra_features', [True, False]),
        'freeze_smiles_encoder': False, # trial.suggest_categorical('freeze_smiles_encoder', [True, False]),
        'hidden_channels_extra_features': hidden_channels_extra_features,
        'dropout': trial.suggest_float('dropout', 0.01, 0.8),
        'learning_rate': trial.suggest_float('learning_rate', 1e-5, 1e-2, log=True),
        'batch_size': trial.suggest_categorical('batch_size', [4, 8]),
        'accumulate_grad_batches': trial.suggest_categorical('accumulate_grad_batches', [1, 2, 4, 8]),
    }
    # Namings for reporting
    eventid = f'{trial.datetime_start.strftime("%Y%m-%d%H-%M%S-")}{uuid4()}'
    trial_name = f'{"freezed-" if config["freeze_smiles_encoder"] else ""}{trial.number}-{eventid}'
    bert_model_stripped = bert_model.split('/')[-1]
    dataset_name = f'_tokenized_{bert_model_stripped}'
    trial.set_user_attr('bert_type', bert_model_stripped)
    trial.set_user_attr('dataset_name', dataset_name)
    # Finally train the BERT-based model
    return train_bert_model(config=config,
                            trial=trial,
                            trial_name=trial_name,
                            enable_checkpointing=False,
                            bert_model=bert_model,
                            num_epochs=num_epochs,
                            task=task,
                            loss_func=loss_func,
                            num_gpus=num_gpus)

Run experiments:

In [None]:
# Define experiments design points
tasks = [
    'predict_active_inactive',
    # 'predict_pDC50_and_Dmax',
    # 'predict_pDC50',
    ]
upsampled = [False] # [True, False]
bert_models_list = ssl_bert_models + pretrained_bert_models
experiments = (tasks, upsampled, bert_models_list)
# Get all experiments combinations
n_experiments = 0
for subset in itertools.product(*experiments):
    task, use_upsampled, _ = subset  
    if use_upsampled and task != 'predict_active_inactive':
        continue
    n_experiments += 1
# Set fixed parameters
num_epochs = 15
num_samples = 1000
n_gpus = 1 if torch.cuda.is_available() else 0
# loss_func = mean_absolute_error
# loss_func = mean_squared_error
loss_func = nn.HuberLoss(reduction='mean', delta=0.8) # Default 1.0
# Define specific results dictionary in the global one
experiments_results['results_transformer'] = load_result('results_transformer')
if RETRAIN_BERT_MODEL or experiments_results['results_transformer'] is None:
    # Run experiments
    experiments_results['results_transformer'] = {}
    i = 0
    best_ckpt = []
    for experiment_id in itertools.product(*experiments):
        task, use_upsampled, bert_model = experiment_id
        if use_upsampled and task != 'predict_active_inactive':
            continue
        print(f'-' * 80)
        print(f'Experiment n.{i + 1}/{n_experiments} ({i / n_experiments * 100.0:.2f}% complete):')
        print(f'\ttask: {task}')
        print(f'\tuse_upsampled: {use_upsampled}')
        print(f'\tbert_model: {bert_model}')
        print(f'-' * 80)
        # Run Optuna study
        direction = 'maximize' if task == 'predict_active_inactive' else 'minimize'
        # optuna_pruner = optuna.pruners.MedianPruner(n_warmup_steps=10)
        optuna_pruner = optuna.pruners.HyperbandPruner(min_resource=2,
                                                       max_resource=num_epochs,
                                                       reduction_factor=3)
        optuna_sampler = optuna.samplers.TPESampler(seed=42)
        study = optuna.create_study(direction=direction,
                                    pruner=optuna_pruner,
                                    sampler=optuna_sampler)
        study.optimize(lambda trial: bert_objective(trial,
                                                    task=task,
                                                    bert_model=bert_model,
                                                    num_epochs=num_epochs,
                                                    loss_func=loss_func,
                                                    num_gpus=n_gpus),
                       n_trials=num_samples,
                       timeout=600 * 2)
        best_trial = study.best_trial
        # Reporting
        print('-' * 80)
        print(f'Experiment n.{i + 1}/{n_experiments} done ({(i + 1) / n_experiments * 100.0:.2f}% complete)')
        print(f'Number of finished trials: {len(study.trials)}')
        print(f'Best trial value: {best_trial.value}')
        # Retrain model with best hyperparameters
        # NOTE: We are training at the end in order to not pollute the disk with
        # non-optimal checkpoints
        # NOTE: Set the non-optimized parameters to the fixed ones
        print_dict('Hyperparams:', best_trial.params)
        print('Retraining model with best hyperparameters...', end='')
        config = best_trial.params.copy()
        num_layers_extra = config['num_layers_extra']
        hidden_channels_extra_features = [
            config[f'layer_{i}_size'] for i in range(num_layers_extra)
        ]
        config['use_extra_features'] = True
        config['freeze_smiles_encoder'] = False
        config['hidden_channels_extra_features'] = hidden_channels_extra_features
        train_bert_model(config=config,
                         trial_name=best_trial.user_attrs['trial_name'],
                         trial=best_trial,
                         enable_checkpointing=True, # This time enable checkpointing
                         bert_model=bert_model,
                         num_epochs=num_epochs,
                         task=task,
                         loss_func=loss_func,
                         num_gpus=n_gpus)
        print('done')
        # Update experiment results
        # NOTE: The function `train_bert_model` updates the `best_trial` object
        experiments_results['results_transformer'][experiment_id] = {}
        experiments_results['results_transformer'][experiment_id]['task'] = task
        experiments_results['results_transformer'][experiment_id]['trial'] = best_trial
        experiments_results['results_transformer'][experiment_id]['use_upsampled'] = use_upsampled
        print_dict('Experiment:', experiments_results['results_transformer'][experiment_id])
        print_dict('Hyperparams:', best_trial.params)
        print_dict('Attributes:', best_trial.user_attrs)
        # Plotting training curves
        trainer_logs = best_trial.user_attrs['trainer_log_dir']
        descr = f' for Transformer ({best_trial.user_attrs["bert_type"]})'
        figpath = os.path.join(fig_dir, f'training_curves_{task}_{best_trial.user_attrs["bert_type"]}')
        plot_training_curves(trainer_logs, experiment=descr, figpath=figpath)
        # Remove non-optimal checkpoints
        model_name = best_trial.user_attrs['model_name']
        checkpoint_root_dir = best_trial.user_attrs['model_checkpoint_dir']
        best_ckpt.append(best_trial.user_attrs['trial_name'])
        del_non_optimal_ckpt(checkpoint_root_dir, best_ckpt, model_name)
        i += 1
    # Save results
    save_results(experiments_results['results_transformer'],
                 result_name='results_transformer')

In [None]:
experiments_results['results_transformer'] = load_result('results_transformer')
if experiments_results['results_transformer'] is not None:
    for experiment_id, design_points in experiments_results['results_transformer'].items():
        task = design_points['task']
        trial = design_points['trial']
        use_upsampled = design_points['use_upsampled']
        print(f'Experiment: {experiment_id}')
        print_dict('Hyperparams:', trial.params)
        print_dict('Attributes:', trial.user_attrs)
        model_checkpoint = trial.user_attrs['model_checkpoint']
        model = WrapperModel.load_from_checkpoint(model_checkpoint)
        print('-' * 80)

Evaluation:

In [None]:
for phase in ['val', 'test']:
    confusion_matrices = {}
    plot_dummy = True

    for experiment_id, design_points in experiments_results['results_transformer'].items():
        trial = design_points['trial']
        task = design_points['task']
        # Model-specific description
        descr = f'Transformer [{trial.user_attrs["bert_type"]}]'
        preds, cm = evaluate_experiment(task=task,
                                        descr=descr,
                                        dataset_name=trial.user_attrs['dataset_name'],
                                        model_checkpoint=trial.user_attrs['model_checkpoint'],
                                        plot_dummy=plot_dummy,
                                        phase=phase,
                                        plot_auc=True)
        experiments_results['results_transformer'][experiment_id]['trial'].user_attrs.update(preds)
        save_results(experiments_results['results_transformer'],
                     result_name='results_transformer')
        confusion_matrices[experiment_id] = (cm, descr)
        plot_dummy = False
        print_dict(f'Evaluation results for {descr}:', preds)
        print('-' * 80)
    plt.grid(True, axis='both', alpha=0.7)
    plt.legend(loc='upper left', bbox_to_anchor=(1, 1), ncol=1, fancybox=True) #, shadow=True)
    plt.title(f'Transformers {"Validation" if phase == "val" else "Test"} Set ROC Curve')

    filename = os.path.join(fig_dir, f'roc_curve_{phase}_transformers')
    plt.savefig(filename + '.pdf', bbox_inches='tight')
    plt.savefig(filename + '.png', bbox_inches='tight')
    plt.show()
    plt.close()

    # Plot confusion matrices:
    for i, (_, (disp, descr)) in enumerate(confusion_matrices.items()):
        disp.plot(cmap=plt.cm.Blues)
        plt.title(f'{descr}')
        filename = os.path.join(fig_dir, f'confusion_matrix_{phase}_transformers_n{i}')
        plt.savefig(filename + '.pdf', bbox_inches='tight')
        plt.savefig(filename + '.png', bbox_inches='tight')
        # plt.show()
        plt.close()

### Active Learning and Semi-Supervised Learning

[Interesting thesis about the subject](https://odr.chalmers.se/server/api/core/bitstreams/356f3738-b743-4c5a-ab5e-233503f69024/content)

TODOs:

* The SSL data are missing `poi_seq` and `cell_type`: maybe move the parsing to the end of the data cleaning process?
* Implement the training loop suggested by ChatGPT
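The `calculate_uncertainty_scores` helper further down scores unlabeled PROTACs by the binary entropy of the predicted activity probability. A minimal sketch of how such scores could drive uncertainty sampling, i.e. picking the `k` least certain unlabeled samples to label next, is shown here; the function names, `k`, and the logits are hypothetical, and the actual active learning loop in this notebook is still a TODO.

```python
import math


def binary_entropy(prob, eps=1e-8):
    """Entropy (in nats) of a Bernoulli prediction; maximal at prob = 0.5."""
    return -(prob * math.log(prob + eps) + (1 - prob) * math.log(1 - prob + eps))


def select_most_uncertain(logits, k):
    """Return indices of the k samples with the highest predictive entropy."""
    probs = [1.0 / (1.0 + math.exp(-z)) for z in logits]  # sigmoid
    scores = [binary_entropy(p) for p in probs]
    ranked = sorted(range(len(logits)), key=lambda i: scores[i], reverse=True)
    return ranked[:k]


# Hypothetical model logits for five unlabeled PROTACs
logits = [4.0, -0.1, 2.5, 0.05, -3.0]
print(select_most_uncertain(logits, k=2))  # → [3, 1]
```

Samples whose logits are closest to zero (probability near 0.5) are selected first; in a full loop they would be labeled, added to the training set, and the model retrained.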

In [None]:
ssl_df = pd.read_csv(os.path.join(data_dir, 'protac', 'protac-db_ssl.csv'))
ssl_df = ssl_df.dropna(subset=['Smiles_nostereo'])
ssl_df = ssl_df.rename(columns={'E3ligase': 'e3_ligase'})
ssl_df['active'] = False
protac_ds_kwargs = {
    'precompute_fingerprints': False,
    'use_morgan_fp': True,
    'morgan_bits': 4096,
    'morgan_atomic_radius': 2,
    'poi_vectorizer': poi_vectorizer,
    'e3_ligase_enc': e3_ligase_enc,
    'poi_gene_enc': poi_gene_enc,
    'cell_type_enc': cell_type_enc,
}
# train_dataset = ProtacDataset(train_bin_df, **protac_ds_kwargs)
# val_dataset = ProtacDataset(ssl_df, **protac_ds_kwargs)
# test_dataset = ProtacDataset(test_bin_df, **protac_ds_kwargs)

In [None]:
model_name = 'fp_model_protac_pedia'
trial_name = 'v0'
dataset_name = '_fp2048_radius2_path1-8'
num_epochs = 15
num_samples = 20
# Setup SMILES Encoder arguments
smiles_encoder_gen_args = {
    'fp_type': 'morgan_fp',
    'fp_bits': 2048,
    'hidden_channels': [128, 128, 128],
    'norm_layer': nn.BatchNorm1d,
    'dropout': 0.3,
}
smiles_encoder = FingerprintSubModel
# Setup Wrapper Model arguments
model_kwargs = {
    'use_extra_features': True, # trial.suggest_categorical('use_extra_features', [True, False]),
    'hidden_channels_extra_features': [128, 64],
    'dropout': 0.3,
    'learning_rate': 1e-4,
    'batch_size': 128,
}
protac_ds_kwargs = {
    'precompute_fingerprints': True,
    'use_morgan_fp': True,
    'use_maccs_fp': True,
    'use_path_fp': True,
    'morgan_atomic_radius': 2,
    'morgan_bits': 2048,
    'path_bits': 2048,
    'fp_min_path': 1,
    'fp_max_path': 8,
    'poi_vectorizer': poi_vectorizer,
    'e3_ligase_enc': e3_ligase_enc,
    'poi_gene_enc': poi_gene_enc,
    'cell_type_enc': cell_type_enc,
}
ds = get_datasets(task, use_upsampled, dataset_name=dataset_name, regenerate_datasets=False, **protac_ds_kwargs)
# Standard namings for reporting
lightning_dir = os.path.join(checkpoint_dir, 'lightning')
model_checkpoint_dir = os.path.join(lightning_dir, 'models')
model_checkpoint = f'{model_name}-{trial_name}'
tensorboard_dir = os.path.join(lightning_dir, 'tensorboard', f'{model_name}_{task}')
reporting_config = {
    'trial_name': trial_name,
    'model_name': model_name,
    'lightning_dir': lightning_dir,
    'model_checkpoint_dir': model_checkpoint_dir,
    'model_checkpoint': model_checkpoint,
    'tensorboard_dir': tensorboard_dir,
}
trial = optuna.trial.FixedTrial({})
# Train model via its generic function
train_model(reporting_config=reporting_config,
            smiles_encoder=smiles_encoder,
            smiles_encoder_args=smiles_encoder_gen_args,
            train_dataset=ds['train'],
            val_dataset=ds['val'],
            test_dataset=ds['test'],
            # Optional arguments
            num_epochs=num_epochs,
            num_gpus=1,
            use_raytune=False,
            enable_checkpointing=True,
            trial=trial,
            **model_kwargs)
descr = ' MLP'
plot_training_curves(trial.user_attrs['trainer_log_dir'], experiment=descr)

model = WrapperModel.load_from_checkpoint(trial.user_attrs['model_checkpoint'])
model.train_dataset = ds['train']
model.val_dataset = ds['val']
model.test_dataset = ds['test']
# model.learning_rate = 1e-3
# num_epochs = 10

# train_model(reporting_config=reporting_config,
#             smiles_encoder=smiles_encoder,
#             smiles_encoder_args=smiles_encoder_gen_args,
#             train_dataset=ds['train'],
#             val_dataset=ds['val'],
#             test_dataset=ds['test'],
#             model=model,
#             # Optional arguments
#             num_epochs=num_epochs,
#             num_gpus=1,
#             enable_checkpointing=True,
#             trial=trial,
#             **model_kwargs)
# descr = ' for MLP'
# plot_training_curves(trial.user_attrs['trainer_log_dir'], experiment=descr)

In [None]:
# Uncertainty score: binary entropy of the predicted probabilities (highest at p = 0.5).
# `predictions` are raw logits and must be 2-D, e.g. shape (N, 1): averaging over dim=1 gives one score per sample.
def calculate_uncertainty_scores(predictions):
    probabilities = torch.sigmoid(predictions)
    entropy = -torch.mean(probabilities * torch.log(probabilities + 1e-8)
                          + (1 - probabilities) * torch.log(1 - probabilities + 1e-8), dim=1)
    return entropy

model.train_dataset = ds['train']
model.val_dataset = ds['val_protac_pedia']
model.test_dataset = ds['test']
preds = get_eval_results(model, num_gpus=1, run_lightning_eval=False, return_logits_only=True)
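`calculate_uncertainty_scores` computes the binary entropy $H(p) = -[p \log p + (1-p)\log(1-p)]$ of $p = \sigma(\text{logit})$: it peaks at $\log 2$ when $p = 0.5$ (logit 0) and drops towards 0 as the model becomes confident either way. Because it averages over `dim=1`, it needs 2-D logits. A minimal NumPy sketch (the function name is hypothetical) that accepts either `(N,)` or `(N, 1)` logits:

```python
import numpy as np

def binary_entropy_from_logits(logits, eps=1e-8):
    """Per-sample binary entropy (in nats) of sigmoid(logits)."""
    logits = np.asarray(logits, dtype=float).reshape(-1)  # accept (N,) or (N, 1)
    p = 1.0 / (1.0 + np.exp(-logits))                      # sigmoid
    return -(p * np.log(p + eps) + (1.0 - p) * np.log(1.0 - p + eps))

scores = binary_entropy_from_logits([[0.0], [2.0], [-2.0], [6.0]])
print(scores.round(3))
```

Logits of +2 and -2 get the same score, since entropy only measures how far the probability is from 0.5, not in which direction.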

In [None]:
from torch.utils.data import ConcatDataset

len(ConcatDataset([ds['train'], ds['train_protac_pedia']]))
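The length of a `ConcatDataset` is the sum of its parts; indexing maps a global index to a (dataset, local index) pair through cumulative sizes and a binary search, which is how `torch.utils.data.ConcatDataset` resolves indices. A plain-Python sketch of that idea (the class name is hypothetical, not torch's code):

```python
import bisect
from itertools import accumulate

class ConcatSketch:
    """Minimal ConcatDataset-style indexing over a list of sequences (illustrative only)."""

    def __init__(self, datasets):
        self.datasets = list(datasets)
        # cumulative_sizes[k] = total number of samples in datasets[0..k]
        self.cumulative_sizes = list(accumulate(len(d) for d in self.datasets))

    def __len__(self):
        return self.cumulative_sizes[-1] if self.cumulative_sizes else 0

    def __getitem__(self, idx):
        if idx < 0:
            idx += len(self)
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        # First dataset whose cumulative size exceeds idx
        ds_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        offset = self.cumulative_sizes[ds_idx - 1] if ds_idx > 0 else 0
        return self.datasets[ds_idx][idx - offset]

combined = ConcatSketch([['a', 'b', 'c'], ['d', 'e']])
print(len(combined), combined[2], combined[3], combined[-1])  # 5 c d e
```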

[Maybe a relevant paper...](https://ieeexplore.ieee.org/document/9533839)

In [None]:
res = load_result('results_fp')[('predict_active_inactive', 4096, False)]
trial = res['trial']
oracle = WrapperModel.load_from_checkpoint(trial.user_attrs['model_checkpoint'])
print(oracle.smiles_encoder.hparams)
protac_ds_kwargs = {
    'precompute_fingerprints': True,
    'use_morgan_fp': True,
    'use_maccs_fp': True,
    'use_path_fp': True,
    'morgan_atomic_radius': trial.params['radius'],
    'morgan_bits': oracle.smiles_encoder.hparams['fp_bits'],
    'path_bits': oracle.smiles_encoder.hparams['fp_bits'],
    'fp_min_path': 1,
    'fp_max_path': trial.params['fp_max_path'],
    'poi_vectorizer': poi_vectorizer,
    'e3_ligase_enc': e3_ligase_enc,
    'poi_gene_enc': poi_gene_enc,
    'cell_type_enc': cell_type_enc,
}
ds = get_datasets(dataset_name=trial.user_attrs['dataset_name'], regenerate_datasets=False, **protac_ds_kwargs)
# oracle.train_dataset = ds['train']
# oracle.val_dataset = ds['val']
# oracle.test_dataset = ds['test']

# preds = get_eval_results(oracle, num_gpus=1, run_lightning_eval=False, return_logits_only=True)
# predictions = torch.sigmoid(torch.Tensor(preds['val_logits']))
# # predictions

In [None]:
tmp = ProtacDataset(ds['train_protac_pedia'].dataframe[:25], **ds['train_protac_pedia'].hparams)
print(tmp)

In [None]:
from torch.utils.data import DataLoader, SubsetRandomSampler

def active_learning(num_iterations=1, confidence_threshold=0.7, warmup_steps=5):
    # Initialize the oracle model (pretrained model)
    res = load_result('results_fp')[('predict_active_inactive', 4096, False)]
    trial = res['trial']
    oracle = WrapperModel.load_from_checkpoint(trial.user_attrs['model_checkpoint'])
    # oracle.smiles_encoder.hparams
    ds = get_datasets(dataset_name=trial.user_attrs['dataset_name'])
    train_dataset = ds['train']
    val_dataset = ds['val']
    test_dataset = ds['test']
    train_dataset_al = ds['train_protac_pedia']
    val_dataset_al = ds['val_protac_pedia']
    
    num_epochs = 10
    smiles_encoder_gen_args = {
        'fp_type': 'morgan_fp',
        'fp_bits': 4096,
        'hidden_channels': [128, 128, 128],
        'norm_layer': nn.BatchNorm1d,
        'dropout': 0.3,
    }
    smiles_encoder = FingerprintSubModel
    # Setup Wrapper Model arguments
    model_kwargs = {
        'use_extra_features': True, # trial.suggest_categorical('use_extra_features', [True, False]),
        'hidden_channels_extra_features': [128, 64],
        'dropout': 0.3,
        'learning_rate': 1e-4,
        'batch_size': 128,
    }
    train_model(reporting_config=reporting_config,
            smiles_encoder=smiles_encoder,
            smiles_encoder_args=smiles_encoder_gen_args,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            test_dataset=test_dataset,
            # Optional arguments
            num_epochs=num_epochs,
            num_gpus=1,
            use_raytune=False,
            enable_checkpointing=True,
            trial=trial,
            **model_kwargs)
    descr = ' MLP (starting)'
    plot_training_curves(trial.user_attrs['trainer_log_dir'], experiment=descr)
    initial_model = WrapperModel.load_from_checkpoint(trial.user_attrs['model_checkpoint'])
    
    for i in range(num_iterations):
        # Setup Oracle val dataset as AL dataset to get predictions
        oracle.val_dataset = train_dataset_al
        oracle.test_dataset = test_dataset
        # Get predictions from oracle
        preds = get_eval_results(oracle, num_gpus=1, run_lightning_eval=False, return_logits_only=True)
        predictions = torch.sigmoid(torch.Tensor(preds['val_logits'])).flatten()
        # Get pseudo-labels from the oracle on AL dataset (-1 = not confident enough):
        pseudo_labels = torch.full_like(predictions, -1.0)
        pseudo_labels[predictions > confidence_threshold] = 1.0
        pseudo_labels[(1 - predictions) > confidence_threshold] = 0.0
        print(f'Iter.N.{i}) Active:   {int((pseudo_labels == 1.0).sum())}')
        print(f'Iter.N.{i}) Inactive: {int((pseudo_labels == 0.0).sum())}')

        high_confidence_idx = (pseudo_labels != -1).flatten().numpy()
        # Update labels accordingly. Use .loc with a boolean mask so the dataframe itself
        # is modified: chained indexing like df.iloc[mask]['active'] = ... only writes to a copy
        train_dataset_al.dataframe.loc[(pseudo_labels == 1.0).numpy(), 'active'] = True
        train_dataset_al.dataframe.loc[(pseudo_labels == 0.0).numpy(), 'active'] = False
        high_confidence_df = train_dataset_al.dataframe[high_confidence_idx]
        low_confidence_df = train_dataset_al.dataframe[~high_confidence_idx]
        
        new_labeled_ds = ProtacDataset(high_confidence_df, **train_dataset_al.hparams)
        train_dataset_al = ProtacDataset(low_confidence_df, **train_dataset_al.hparams)
        
        train_df = pd.concat([train_dataset.dataframe, new_labeled_ds.dataframe])
        train_dataset = ProtacDataset(train_df, **train_dataset.hparams)
        print(f'Iter.N.{i}) Len Train dataset: {len(train_dataset)}')
        
        if i == 0:
            oracle = initial_model
        oracle.train_dataset = train_dataset
        oracle.val_dataset = val_dataset
        oracle.test_dataset = test_dataset
        
        train_model(reporting_config=reporting_config,
                    smiles_encoder=smiles_encoder,
                    smiles_encoder_args=smiles_encoder_gen_args,
                    train_dataset=train_dataset,
                    val_dataset=val_dataset,
                    test_dataset=test_dataset,
                    model=oracle,
                    # Optional arguments
                    num_epochs=num_epochs,
                    num_gpus=1,
                    enable_checkpointing=True,
                    trial=trial,
                    **model_kwargs)
        descr = ' for MLP'
        plot_training_curves(trial.user_attrs['trainer_log_dir'], experiment=descr)
        oracle = WrapperModel.load_from_checkpoint(trial.user_attrs['model_checkpoint'])
        # Update train dataset with confident pseudo-labels
        # Remove confident pseudo-labels from AL train dataset
        # Train AL model on updated train dataset
        # if AL model accuracy > Oracle accuracy:
        #   break or oracle = AL model?
        
        
    
    
#     oracle.train_dataset = ds['train']
#     oracle.val_dataset = ds['val']
#     oracle.test_dataset = ds['test']
    
    
#     # Randomly select initial samples
#     initial_indices = torch.randperm(len(train_dataset))[:num_initial_samples]
#     train_loader = DataLoader(train_dataset, batch_size=64, sampler=SubsetRandomSampler(initial_indices))
    
#     # Train the oracle model on initial samples
#     trainer = pl.Trainer(max_epochs=num_epochs)
#     trainer.fit(oracle, train_loader)
    
#     # Extend the dataset using pseudo-labeling from the oracle
#     unlabeled_indices = torch.tensor(list(set(range(len(train_dataset))) - set(initial_indices.tolist()))))
#     unlabeled_loader = DataLoader(train_dataset, batch_size=64, sampler=SubsetRandomSampler(unlabeled_indices))
    
#     pseudo_labels = []
#     confident_indices = []
    
#     for images, _ in unlabeled_loader:
#         with torch.no_grad():
#             outputs = oracle(images)
#             probabilities = torch.softmax(outputs, dim=1)
        
#         confident_mask = (probabilities[:, 1] > confidence_threshold)  # Select confident positive examples
#         confident_indices.extend(unlabeled_indices[confident_mask].tolist())
#         pseudo_labels.extend((probabilities[:, 1] >= 0.5).long()[confident_mask].tolist())
        
#     train_dataset.targets = torch.cat((train_dataset.targets, torch.tensor(pseudo_labels)))
#     train_dataset.data = torch.cat((train_dataset.data, unlabeled_loader.dataset.data[confident_indices]))
    
#     # Train the model on the extended dataset
#     model = Classifier()  # Replace with your binary classification model
#     train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
#     trainer.fit(model, train_loader)
    
#     return model

# # Run the active learning loop
# confidence_threshold = 0.9  # Adjust the threshold as needed
# model = active_learning(num_initial_samples=100, num_queries=10, num_epochs=5, confidence_threshold=confidence_threshold)

active_learning(num_iterations=3, confidence_threshold=0.75)
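The labelling rule in the loop above assigns confidence-thresholded pseudo-labels, as in self-training: a sample becomes active if $p > t$, inactive if $1 - p > t$, and stays unlabelled otherwise. For $t \ge 0.5$ the two masks cannot overlap. A NumPy sketch of just that rule (the function name is hypothetical):

```python
import numpy as np

def confidence_pseudo_labels(probs, threshold=0.75):
    """Return 1 (active), 0 (inactive) or -1 (too uncertain) per sample."""
    probs = np.asarray(probs, dtype=float).reshape(-1)
    labels = np.full(probs.shape, -1, dtype=int)
    labels[probs > threshold] = 1          # confidently active
    labels[(1.0 - probs) > threshold] = 0  # confidently inactive
    return labels

probs = np.array([0.95, 0.80, 0.50, 0.30, 0.10])
labels = confidence_pseudo_labels(probs, threshold=0.75)
confident = labels != -1
print(labels, confident.sum())
```

Note that 0.30 stays unlabelled at `threshold=0.75` because $1 - 0.30 = 0.70$ does not clear the threshold, so raising the threshold shrinks both labelled groups.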

In [None]:
# Uncertainty-sampling active learning loop (sketch).
# NOTE: relies on helpers and globals defined earlier in the notebook
# (get_eval_results, calculate_uncertainty_scores, train_model, ProtacDataset,
# WrapperModel, reporting_config, smiles_encoder, smiles_encoder_gen_args, model_kwargs)
# and assumes the pool samples already carry labels (simulated oracle).
def active_learning(datasets, model, num_iterations, batch_size, num_epochs=10):
    train_dataset = datasets['train']
    val_dataset = datasets['val']
    test_dataset = datasets['test']
    # Pool of candidate samples to query
    pool_dataset = datasets['train_protac_pedia']
    # Selected samples, one dataframe per iteration
    selected_samples = []

    for i in range(num_iterations):
        if len(pool_dataset) == 0:
            break
        # Get predictions for all pool samples (passed via the validation slot)
        model.train_dataset = train_dataset
        model.val_dataset = pool_dataset
        model.test_dataset = test_dataset
        preds = get_eval_results(model, num_gpus=1, run_lightning_eval=False,
                                 return_logits_only=True)
        predictions = torch.Tensor(preds['val_logits'])

        # Calculate uncertainty scores (entropy) for each prediction
        uncertainty_scores = calculate_uncertainty_scores(predictions)

        # Select the samples with the highest uncertainty scores
        num_samples_to_select = min(batch_size, len(pool_dataset))
        selected_indices = torch.topk(uncertainty_scores, num_samples_to_select).indices.numpy()
        selected_mask = np.zeros(len(pool_dataset), dtype=bool)
        selected_mask[selected_indices] = True

        # Move the selected samples from the pool to the training set
        pool_df = pool_dataset.dataframe
        selected_samples.append(pool_df[selected_mask])
        pool_dataset = ProtacDataset(pool_df[~selected_mask], **pool_dataset.hparams)
        train_df = pd.concat([train_dataset.dataframe, pool_df[selected_mask]])
        train_dataset = ProtacDataset(train_df, **train_dataset.hparams)
        print(f'Iter.N.{i}) Len Train dataset: {len(train_dataset)}')

        # Retrain the model on the updated training set
        trial = optuna.trial.FixedTrial({})
        train_model(reporting_config=reporting_config,
                    smiles_encoder=smiles_encoder,
                    smiles_encoder_args=smiles_encoder_gen_args,
                    train_dataset=train_dataset,
                    val_dataset=val_dataset,
                    test_dataset=test_dataset,
                    model=model,
                    # Optional arguments
                    num_epochs=num_epochs,
                    num_gpus=1,
                    use_raytune=False,
                    enable_checkpointing=True,
                    trial=trial,
                    **model_kwargs)
        descr = ' MLP (active learning)'
        plot_training_curves(trial.user_attrs['trainer_log_dir'], experiment=descr)
        model = WrapperModel.load_from_checkpoint(trial.user_attrs['model_checkpoint'])

    return selected_samples
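Each round of pool-based uncertainty sampling picks the $k$ pool samples with the highest uncertainty score (what `torch.topk` does on the entropy scores) and moves them from the pool to the training set. A NumPy sketch of that bookkeeping, with hypothetical function names and assuming labels for the selected samples come from an oracle:

```python
import numpy as np

def select_most_uncertain(scores, k):
    """Indices of the k highest uncertainty scores (earlier index wins ties)."""
    scores = np.asarray(scores, dtype=float)
    k = min(k, scores.size)
    # Stable argsort on negated scores gives descending order
    return np.argsort(-scores, kind='stable')[:k]

def split_pool(pool, selected_idx):
    """Split a pool into (selected, remaining) with a boolean mask, keeping original order."""
    pool = np.asarray(pool)
    mask = np.zeros(len(pool), dtype=bool)
    mask[selected_idx] = True
    return pool[mask], pool[~mask]

scores = np.array([0.10, 0.69, 0.35, 0.60, 0.05])
idx = select_most_uncertain(scores, k=2)
selected, remaining = split_pool(np.array(['p0', 'p1', 'p2', 'p3', 'p4']), idx)
print(idx, selected, remaining)
```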

### SMILES as Graphs - GNNs

#### Via GNN Layers

Following this [tutorial](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial7/GNN_overview.html).
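For reference, the `GCN` option (`geom_nn.GCNConv`, after Kipf & Welling) aggregates neighbour features with a symmetric normalization, $H' = \hat{D}^{-1/2}\hat{A}\hat{D}^{-1/2} H W$, where $\hat{A} = A + I$ adds self-loops and $\hat{D}$ is the degree matrix of $\hat{A}$. PyG works on a sparse `edge_index` and adds a learnable bias; the dense NumPy sketch below (function name hypothetical) only illustrates the math on a toy 3-atom chain. `GATConv` and `GraphConv` use different aggregation rules.

```python
import numpy as np

def gcn_layer(A, H, W):
    """One GCN propagation step: D^-1/2 (A + I) D^-1/2 H W (no bias, no activation)."""
    A_hat = A + np.eye(A.shape[0])           # add self-loops
    deg = A_hat.sum(axis=1)                  # degrees including the self-loop
    D_inv_sqrt = np.diag(1.0 / np.sqrt(deg))
    return D_inv_sqrt @ A_hat @ D_inv_sqrt @ H @ W

# Toy chain 0-1-2 with 2 input features per node
A = np.array([[0, 1, 0],
              [1, 0, 1],
              [0, 1, 0]], dtype=float)
H = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [1.0, 1.0]])
W = np.eye(2)  # identity weights, so the output is just the normalized aggregation
out = gcn_layer(A, H, W)
print(out.round(3))
```

Node 0 has degree 2 (itself plus node 1) and node 1 has degree 3, so row 0 of the output is $H_0/2 + H_1/\sqrt{6}$.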

In [None]:
gnn_layer_by_name = {
    'GCN': geom_nn.GCNConv,
    'GAT': geom_nn.GATConv,
    'GraphConv': geom_nn.GraphConv
}

In [None]:
class GNNModel(nn.Module):

    def __init__(self,
                 c_in: int,
                 c_hidden: int,
                 c_out: int,
                 num_layers: int = 2,
                 layer_name: Literal['GCN', 'GAT', 'GraphConv'] = 'GCN',
                 dp_rate: float = 0.1,
                 **kwargs):
        """
        Inputs:
            c_in - Dimension of input features
            c_hidden - Dimension of hidden features
            c_out - Dimension of the output features. Usually number of classes in classification
            num_layers - Number of "hidden" graph layers
            layer_name - String of the graph layer to use
            dp_rate - Dropout rate to apply throughout the network
            kwargs - Additional arguments for the graph layer (e.g. number of heads for GAT)
        """
        super().__init__()
        gnn_layer = gnn_layer_by_name[layer_name]

        layers = []
        in_channels, out_channels = c_in, c_hidden
        # Interleave graph layers with ReLU and Dropout
        for l_idx in range(num_layers - 1):
            layers += [
                gnn_layer(in_channels=in_channels,
                          out_channels=out_channels,
                          **kwargs),
                nn.Dropout(dp_rate),
                nn.ReLU(inplace=True),
            ]
            # if 'heads' in kwargs and 'GAT' in layer_name:
            #     # For GAT, we need to update the input dimensionality
            #     if kwargs.get('concat', True):
            #         in_channels = kwargs['heads'] * c_hidden
            # else:
            #     in_channels = c_hidden
            in_channels = c_hidden
        layers += [gnn_layer(in_channels=in_channels,
                             out_channels=c_out,
                             **kwargs)]
        self.layers = nn.ModuleList(layers)

    def forward(self, x, edge_index):
        """
        Inputs:
            x - Input features per node
            edge_index - List of vertex index pairs representing the edges in the graph (PyTorch geometric notation)
        """
        for l in self.layers:
            # NOTE: For graph layers, we need to add the "edge_index" tensor as
            # additional input. All PyTorch Geometric graph layers inherit the
            # class "MessagePassing", hence we can simply check the class type.
            if isinstance(l, geom_nn.MessagePassing):
                x = l(x, edge_index)
            else:
                x = l(x)
        return x
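`GNNModel.forward` expects `edge_index` in PyTorch Geometric's COO layout: a `2 x E` integer array whose first row holds source nodes and second row target nodes (the default `source_to_target` flow). Chemical bonds are undirected, so each bond is typically stored in both directions. A NumPy sketch of building it from a bond list (function name hypothetical; the notebook's own SMILES-to-graph featurization is not shown here):

```python
import numpy as np

def bonds_to_edge_index(bonds):
    """Turn undirected (i, j) bond pairs into a (2, 2 * num_bonds) edge_index array."""
    bonds = np.asarray(bonds, dtype=np.int64).reshape(-1, 2)
    forward = bonds.T             # i -> j
    backward = bonds[:, ::-1].T   # j -> i
    return np.concatenate([forward, backward], axis=1)

# Ethanol heavy atoms C-C-O: bonds (0, 1) and (1, 2)
edge_index = bonds_to_edge_index([(0, 1), (1, 2)])
print(edge_index)
```

`torch.from_numpy(edge_index)` then gives the `int64` (long) tensor that PyG layers expect.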

Define GNN sub-model:

(TODO: Maybe use a predefined model? [Models available in PyTorch Geometric.](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#models))

* GatedGraphConv
* GIN
* Aggregation functions in general (AttentionalAggregation seems to be the most promising one)
* DepthSumPooling
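Both readouts used in this notebook collapse node features into one vector per molecule via the `batch` vector: `GnnSubModel` uses `global_add_pool` (keeps size information), while the deprecated `GraphGNNModel` uses `global_mean_pool` (size-normalized). `AttentionalAggregation` instead learns a softmax weight per node within each graph and returns the weighted sum. A NumPy sketch of the three; the attention gate here is a fixed linear scoring vector standing in for the learned gate network, an assumption made for illustration:

```python
import numpy as np

def global_pool(x, batch, num_graphs, mode='add'):
    """Pool node features x (N, F) into graph features (G, F) using batch (N,)."""
    out = np.zeros((num_graphs, x.shape[1]))
    np.add.at(out, batch, x)                      # per-graph sum
    if mode == 'mean':
        counts = np.bincount(batch, minlength=num_graphs).reshape(-1, 1)
        out = out / np.maximum(counts, 1)         # guard against empty graphs
    return out

def attention_pool(x, batch, num_graphs, gate_w):
    """Simplified attentional aggregation: per-graph softmax of x @ gate_w, then weighted sum."""
    gate = x @ gate_w                             # one scalar score per node
    weights = np.empty_like(gate)
    for g in range(num_graphs):
        m = batch == g
        if m.any():
            e = np.exp(gate[m] - gate[m].max())   # numerically stable softmax
            weights[m] = e / e.sum()
    return global_pool(x * weights[:, None], batch, num_graphs, mode='add')

# Two molecules in one batch: nodes 0-1 -> graph 0, nodes 2-4 -> graph 1
x = np.array([[1.0, 0.0], [3.0, 2.0], [0.0, 1.0], [2.0, 2.0], [4.0, 0.0]])
batch = np.array([0, 0, 1, 1, 1])
print(global_pool(x, batch, 2, 'add'))
print(global_pool(x, batch, 2, 'mean'))
print(attention_pool(x, batch, 2, gate_w=np.array([1.0, 0.0])))
```

In this toy batch both molecules have the same mean embedding but different sums, which is one reason sum pooling can be more expressive for molecules of different sizes. With an all-zero gate the attention weights are uniform and the readout reduces to mean pooling.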

In [None]:
class GnnSubModel(pl.LightningModule):

    def __init__(self,
                 c_hidden: int,
                 c_out: int,
                 num_layers: int = 2,
                 layer_name: Literal['GCN', 'GAT', 'GraphConv'] = 'GCN',
                 dp_rate: float = 0.1,
                 **kwargs):
        super().__init__()
        # Set our init args as class attributes
        self.__dict__.update(locals()) # Add arguments as attributes
        self.save_hyperparameters()
        # Define PyTorch model
        # NOTE: `num_node_features` definition is near ProtacDataset definition
        self.graph_encoder = GNNModel(num_node_features, c_hidden, c_out,
                                      num_layers, layer_name, dp_rate, **kwargs)

    def forward(self, x_in):
        # Get the graph input
        x = x_in['smiles_graph'].x.float()
        edge_index = x_in['smiles_graph'].edge_index
        batch_idx = x_in['smiles_graph'].batch
        # Run the GNN sub-model
        x = self.graph_encoder(x, edge_index)
        smiles_emb = geom_nn.global_add_pool(x, batch_idx)
        return smiles_emb


def generate_gnn_submodel(c_hidden: int,
                          c_out: int,
                          num_layers: int = 2,
                          layer_name: Literal['GCN', 'GAT', 'GraphConv'] = 'GCN',
                          dp_rate: float = 0.1,
                          dp_rate_linear: float = 0.5,
                          checkpoint_path: str | None = None,
                          **kwargs):
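    # NOTE: `dp_rate_linear` is accepted for interface compatibility but is not used by
    # GnnSubModel: passing it as a sixth positional argument would raise a TypeError.
    model = GnnSubModel(c_hidden, c_out, num_layers, layer_name, dp_rate, **kwargs)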
    if checkpoint_path is not None:
        model.load_state_dict(torch.load(checkpoint_path))
    return model

Define Optuna objective:

In [None]:
def gnn_objective(trial,
                  num_epochs:int = 10,
                  task: Literal['predict_active_inactive', 'predict_pDC50_and_Dmax'] = 'predict_active_inactive',
                  loss_func: Callable | object = nn.HuberLoss(),
                  num_gpus: int = 0):
    config = {
        # SMILES encoder
        'c_hidden': trial.suggest_categorical('c_hidden', [64, 128, 256, 512, 768]),
        'c_out': trial.suggest_categorical('c_out', [64, 128, 256, 512, 768]),
        'num_layers': trial.suggest_int('num_layers', 3, 11),
        'layer_name': trial.suggest_categorical('layer_name', ['GCN', 'GAT', 'GraphConv']),
        'dp_rate': trial.suggest_float('dp_rate', 0.1, 0.8),
        'dp_rate_linear': trial.suggest_float('dp_rate_linear', 0.1, 0.8),
        # Extra features branch
        'use_extra_features': True, # trial.suggest_categorical('use_extra_features', [True, False]),
        'layer_1_size_extra': trial.suggest_categorical('layer_1_size_extra', [8, 16, 32, 64, 128, 256, 512]),
        'layer_2_size_extra': trial.suggest_categorical('layer_2_size_extra', [8, 16, 32, 64, 128, 256, 512]),
        'dropout': trial.suggest_float('dropout', 0.1, 0.8),
        'learning_rate': trial.suggest_float('learning_rate', 1e-5, 1e-3, log=True),
        # TODO: add code to divide the batch size in the training function according to the accumulate grad config
        'batch_size': trial.suggest_categorical('batch_size', [16, 32, 64]),
    }
    # Namings for reporting
    model_name = 'gnn_model'
    eventid = f'{trial.datetime_start.strftime("%Y%m%d-%H-%M-%S-")}{uuid4()}'
    trial_name = f'{config["layer_name"]}-{trial.number}-{eventid}'    
    lightning_dir = os.path.join(checkpoint_dir, 'lightning')
    model_checkpoint_dir = os.path.join(lightning_dir, 'models')
    model_checkpoint = f'{model_name}-{trial_name}'
    tensorboard_dir = os.path.join(lightning_dir, 'tensorboard', f'{model_name}_{task}')
    reporting_config = {
        'trial_name': trial_name,
        'lightning_dir': lightning_dir,
        'model_name': model_name,
        'model_checkpoint_dir': model_checkpoint_dir,
        'model_checkpoint': model_checkpoint,
        'tensorboard_dir': tensorboard_dir,
    }
    # Setup sub-model arguments
    generator_args = {
        'c_hidden': config['c_hidden'],
        'c_out': config['c_out'],
        'num_layers': config['num_layers'],
        'layer_name': config['layer_name'],
        'dp_rate': config['dp_rate'],
        'dp_rate_linear': config['dp_rate_linear'],
    }
    if config['layer_name'] == 'GAT':
        generator_args.update({
            'heads': 8,
            'concat': False,
        })
    smiles_embedding_size = config['c_out']
    # Retrieve specific datasets
    dataset_name = '_graph'
    trial.set_user_attr('dataset_name', dataset_name)
    protac_ds_kwargs = {
        'precompute_smiles_as_graphs': True,
    }
    ds = get_datasets(task,
                      use_upsampled,
                      dataset_name=dataset_name,
                      regenerate_datasets=False,
                      **protac_ds_kwargs)
    # `get_datasets` returns a dict of splits, so select them by key
    train_dataset, test_dataset = ds['train'], ds['test']
    # Train model via its generic function
    return train_model(config=config,
                       reporting_config=reporting_config,
                       smiles_encoder_generator=generate_gnn_submodel,
                       smiles_encoder_generator_args=generator_args,
                       smiles_embedding_size=smiles_embedding_size,
                       train_dataset=train_dataset,
                       test_dataset=test_dataset,
                       num_epochs=num_epochs,
                       task=task,
                       loss_func=loss_func,
                       num_gpus=num_gpus,
                       use_raytune=False,
                       trial=trial)

Run experiments:

In [None]:
# Define experiments design points
tasks = [
    'predict_active_inactive',
    # 'predict_pDC50_and_Dmax',
    # 'predict_pDC50',
    ]
upsampled = [False] # [True, False]
experiments = (tasks, upsampled)
# Get all experiments combinations
n_experiments = 0
for subset in itertools.product(*experiments):
    task, use_upsampled = subset  
    if use_upsampled and task != 'predict_active_inactive':
        continue
    n_experiments += 1
# Set fixed parameters
num_epochs = 10
num_samples = 20
n_gpus = 1 if torch.cuda.is_available() else 0
# loss_func = mean_absolute_error
# loss_func = mean_squared_error
loss_func = nn.HuberLoss(reduction='mean', delta=0.8) # Default 1.0
# Define specific results dictionary in the global one
experiments_results['results_gnn'] = {}
# Run experiments
pl.utilities.memory.garbage_collection_cuda()
i = 0
best_ckpt = []
for experiment_id in itertools.product(*experiments):
    task, use_upsampled = experiment_id
    if use_upsampled and task != 'predict_active_inactive':
        continue
    print(f'-' * 80)
    print(f'Experiment n.{i + 1} ({i / n_experiments * 100.0:.2f}% complete):')
    print(f'\ttask: {task}')
    print(f'\tuse_upsampled: {use_upsampled}')
    print(f'-' * 80)
    # Run Optuna study
    direction = 'maximize' if task == 'predict_active_inactive' else 'minimize'
    optuna_pruner = optuna.pruners.MedianPruner(n_warmup_steps=10)
    optuna_sampler = optuna.samplers.TPESampler(seed=42)
    study = optuna.create_study(direction=direction,
                                pruner=optuna_pruner,
                                sampler=optuna_sampler)
    study.optimize(lambda trial: gnn_objective(trial,
                                               task=task,
                                               num_epochs=num_epochs,
                                               loss_func=loss_func,
                                               num_gpus=n_gpus),
                   n_trials=num_samples,
                   timeout=600)
    trial = study.best_trial
    experiments_results['results_gnn'][experiment_id] = {}
    experiments_results['results_gnn'][experiment_id]['trial'] = trial
    experiments_results['results_gnn'][experiment_id]['task'] = task
    experiments_results['results_gnn'][experiment_id]['use_upsampled'] = use_upsampled
    # Reporting
    print('-' * 80)
    print(f'Experiment n.{i + 1} done ({(i + 1) / n_experiments * 100.0:.2f}% complete)')
    print(f'Number of finished trials: {len(study.trials)}')
    print(f'Best trial score: {trial.value}')
    print_dict('Experiment:', experiments_results['results_gnn'][experiment_id])
    print_dict('Params:', trial.params)
    print_dict('Attributes:', trial.user_attrs)
    # Remove non-optimal checkpoints
    model_name = trial.user_attrs['model_name']
    checkpoint_root_dir = trial.user_attrs['model_checkpoint_dir']
    best_ckpt.append(trial.user_attrs['trial_name'])
    del_non_optimal_ckpt(checkpoint_root_dir, best_ckpt, model_name)
    # Plotting training curves
    trainer_logs = trial.user_attrs['trainer_log_dir']
    descr = f' for GNN ({trial.params["layer_name"]})'
    plot_training_curves(trainer_logs, experiment=descr)
    i += 1
save_results(experiments_results['results_gnn'], result_name='results_gnn')
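The study uses `MedianPruner(n_warmup_steps=10)` together with `num_epochs = 10`. Assuming the objective reports one intermediate value per epoch (steps 0 to 9) through `trial.report` and checks `trial.should_prune()`, which this section does not show, no step ever reaches the warm-up threshold, so pruning would never trigger. The plain-Python sketch below illustrates the median rule. The function name is hypothetical, and this is not Optuna's implementation, which also handles `interval_steps` and `n_min_trials`. `n_startup_trials=5` mirrors Optuna's default.

```python
import statistics

def should_prune(trial_best_value, step, completed_curves, n_warmup_steps=10,
                 n_startup_trials=5, direction='maximize'):
    """Simplified median-stopping rule.

    trial_best_value: best intermediate value of the running trial up to `step`.
    completed_curves: list of {step: intermediate_value} dicts from finished trials.
    """
    if len(completed_curves) < n_startup_trials or step < n_warmup_steps:
        return False  # too early to judge
    others = [curve[step] for curve in completed_curves if step in curve]
    if not others:
        return False
    median = statistics.median(others)
    if direction == 'maximize':
        return trial_best_value < median
    return trial_best_value > median

# Five finished trials that each reported a value at epochs 0..9
history = [{s: 0.60 + 0.01 * s + 0.02 * t for s in range(10)} for t in range(5)]
print(any(should_prune(0.0, s, history) for s in range(10)))  # False: every step < 10
print(should_prune(0.0, 9, history, n_warmup_steps=5))        # True: 0.0 < median 0.73
```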

Evaluation:

In [None]:
confusion_matrices = {}
plot_dummy = True

for experiment_id, design_points in experiments_results['results_gnn'].items():
    trial = design_points['trial']
    task = design_points['task']
    # Model-specific description
    layer_name = trial.params['layer_name']
    descr = f'GNN ({layer_name})'
    evaluate_experiment(descr, experiment_id, trial, task, confusion_matrices, plot_dummy)
    plot_dummy = False
plt.grid(axis='both', alpha=0.7)
plt.legend(loc='upper left', bbox_to_anchor=(1, 1), ncol=1, fancybox=True) #, shadow=True)
plt.show()

# Plot confusion matrices:
for experiment, (disp, descr) in confusion_matrices.items():
    disp.plot(cmap=plt.cm.Blues)
    plt.title(f'{descr}')
    plt.show()
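The `disp` objects are presumably scikit-learn `ConfusionMatrixDisplay` instances, which follow the `confusion_matrix` convention: rows are true labels and columns are predicted labels, so for binary labels `ravel()` yields `(tn, fp, fn, tp)`. A NumPy sketch of the same counts (the function name is hypothetical):

```python
import numpy as np

def binary_confusion_matrix(y_true, y_pred):
    """2x2 matrix: rows = true class (0, 1), columns = predicted class (0, 1)."""
    y_true = np.asarray(y_true, dtype=int)
    y_pred = np.asarray(y_pred, dtype=int)
    cm = np.zeros((2, 2), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)   # count each (true, predicted) pair
    return cm

y_true = [1, 0, 1, 1, 0, 0, 1]
y_pred = [1, 0, 0, 1, 1, 0, 1]
cm = binary_confusion_matrix(y_true, y_pred)
tn, fp, fn, tp = cm.ravel()
accuracy = (tp + tn) / cm.sum()
print(cm, accuracy)
```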

In [None]:
# %tensorboard --logdir {checkpoint_dir}/lightning/tensorboard/

#### GNN Model as standalone (DEPRECATED)

In [None]:
class GraphGNNModel(nn.Module):

    def __init__(self,
                 c_in: int,
                 c_hidden: int,
                 c_out: int,
                 dp_rate_linear: float = 0.5,
                 **kwargs):
        """
        Inputs:
            c_in - Dimension of input features
            c_hidden - Dimension of hidden features
            c_out - Dimension of output features (usually number of classes)
            dp_rate_linear - Dropout rate before the linear layer (usually much higher than inside the GNN)
            kwargs - Additional arguments for the GNNModel object
        """
        super().__init__()
        self.GNN = GNNModel(c_in=c_in, c_hidden=c_hidden, c_out=c_hidden,
                            **kwargs)
        self.head = nn.Sequential(
            nn.Dropout(dp_rate_linear),
            nn.Linear(c_hidden, c_out)
        )

    def forward(self, x, edge_index, batch_idx):
        """
        Inputs:
            x - Input features per node
            edge_index - List of vertex index pairs representing the edges in the graph (PyTorch geometric notation)
            batch_idx - Index of batch element for each node
        """
        x = self.GNN(x, edge_index)
        x = geom_nn.global_mean_pool(x, batch_idx) # Average pooling
        x = self.head(x)
        return x

In [None]:
class GraphLevelGNN(pl.LightningModule):

    def __init__(self, binary_classification=True, **model_kwargs):
        super().__init__()
        # Saving hyperparameters
        self.save_hyperparameters()

        self.model = GraphGNNModel(**model_kwargs)
        if not binary_classification:
            self.loss_module = F.mse_loss
        else:
            self.loss_module = nn.BCEWithLogitsLoss() if self.hparams.c_out == 1 else nn.CrossEntropyLoss()
        self.binary_classification = binary_classification


    def forward(self, data, y, mode='train'):
        x, edge_index, batch_idx = data.x.float(), data.edge_index, data.batch
        x = self.model(x, edge_index, batch_idx)
        x = x.squeeze(dim=-1)
        y = y.squeeze(dim=-1)

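        if not self.binary_classification:
            # Regression: the raw output is the prediction
            preds = x
            y = y.float()
        elif self.hparams.c_out == 1:
            preds = (x > 0).float()
            y = y.float()
        else:
            preds = x.argmax(dim=-1)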
        loss = self.loss_module(x, y)
        
        if self.binary_classification:
            acc = (preds == y).sum().float() / preds.shape[0]
        else:
            acc = ((preds >= 50) == (y >= 50)).sum().float() / preds.shape[0]
        return loss, preds, acc

    def configure_optimizers(self):
        # optimizer = torch.optim.AdamW(self.parameters(), lr=1e-2, weight_decay=0.0) # High lr because of small dataset and small model
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3) # High lr because of small dataset and small model
        return optimizer

    def training_step(self, batch, batch_idx):
        loss, _, acc = self.forward(batch['smiles_graph'], batch['labels'], mode='train')
        self.log('train_loss', loss)
        self.log('train_acc', acc)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, _, acc = self.forward(batch['smiles_graph'], batch['labels'], mode='val')
        self.log('val_loss', loss)
        self.log('val_acc', acc)

    def test_step(self, batch, batch_idx):
        loss, _, acc = self.forward(batch['smiles_graph'], batch['labels'], mode='test')
        self.log('test_loss', loss)
        self.log('test_acc', acc)
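With `c_out == 1`, `GraphLevelGNN` trains with `BCEWithLogitsLoss`, which applies the sigmoid internally, so the network outputs raw logits. Thresholding a logit at 0 is therefore the same decision as thresholding $\sigma(\text{logit})$ at 0.5. A NumPy check of that equivalence and of the accuracy formula used in `forward` (the function name is hypothetical):

```python
import numpy as np

def binary_accuracy_from_logits(logits, targets):
    """Accuracy when a single logit is thresholded at 0, i.e. sigmoid(logit) > 0.5."""
    logits = np.asarray(logits, dtype=float).reshape(-1)
    targets = np.asarray(targets, dtype=float).reshape(-1)
    preds = (logits > 0).astype(float)
    return float((preds == targets).mean())

logits = np.array([2.3, -0.7, 0.1, -3.0])
probs = 1.0 / (1.0 + np.exp(-logits))
# Same decisions whether we threshold logits at 0 or probabilities at 0.5
assert np.array_equal(logits > 0, probs > 0.5)
print(binary_accuracy_from_logits(logits, [1, 0, 0, 0]))  # 0.75
```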

In [None]:
BINARY_CLASSIFICATION = False

train_dataset = ProtacDataset(train_df,
                              include_smiles_as_graphs=True,
                              task='predict_pDC50')
test_dataset = ProtacDataset(val_df,
                             include_smiles_as_graphs=True,
                             task='predict_pDC50')

In [None]:
graph_train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, collate_fn=custom_collate)
graph_val_loader = DataLoader(test_dataset, batch_size=128, collate_fn=custom_collate)
graph_test_loader = DataLoader(test_dataset, batch_size=128, collate_fn=custom_collate)
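`custom_collate` is defined elsewhere in the notebook. For graph inputs, PyTorch Geometric batches a list of graphs into one large disconnected graph: node features are stacked, each graph's `edge_index` is shifted by the number of nodes before it, and a `batch` vector records which graph each node belongs to (later consumed by `global_*_pool`). A NumPy sketch of that idea (the function name is hypothetical, and this is not the notebook's `custom_collate`):

```python
import numpy as np

def collate_graphs(graphs):
    """Merge (x, edge_index) graphs into one disjoint graph plus a batch vector.

    x: (num_nodes, F) node features; edge_index: (2, E) integer array.
    """
    xs, edges, batch = [], [], []
    offset = 0
    for g_id, (x, edge_index) in enumerate(graphs):
        xs.append(x)
        edges.append(edge_index + offset)      # shift node ids past previous graphs
        batch.append(np.full(len(x), g_id))    # graph id of every node
        offset += len(x)
    return np.concatenate(xs), np.concatenate(edges, axis=1), np.concatenate(batch)

g0 = (np.ones((2, 3)), np.array([[0, 1], [1, 0]]))
g1 = (np.zeros((3, 3)), np.array([[0, 1, 1, 2], [1, 0, 2, 1]]))
x, edge_index, batch = collate_graphs([g0, g1])
print(x.shape, edge_index, batch)
```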

In [None]:
CHECKPOINT_PATH = '.'
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

def train_graph_classifier(model_name, **model_kwargs):
    pl.seed_everything(42)

    # Create a PyTorch Lightning trainer that keeps the checkpoint with the best validation accuracy
    root_dir = os.path.join(checkpoint_dir, 'GraphLevel' + model_name)
    os.makedirs(root_dir, exist_ok=True)
    trainer = pl.Trainer(default_root_dir=root_dir,
                         callbacks=[ModelCheckpoint(save_weights_only=True, mode='max', monitor='val_acc')],
                         accelerator='gpu' if str(device).startswith('cuda') else 'cpu',
                         devices=1,
                         max_epochs=3, # 500,
                         log_every_n_steps=8,
                         logger=CSVLogger(save_dir='logs/'),
                         enable_progress_bar=True)
    # trainer.logger._default_hp_metric = None # Optional logging argument that we don't need

    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(checkpoint_dir, f'GraphLevel{model_name}.ckpt')
    if os.path.isfile(pretrained_filename):
        print('Found pretrained model, loading...')
        model = GraphLevelGNN.load_from_checkpoint(pretrained_filename)
    else:
        pl.seed_everything(42)
        model = GraphLevelGNN(c_in=num_node_features,
                              c_out=1,  # a single output unit
                              binary_classification=BINARY_CLASSIFICATION,
                              **model_kwargs)
        trainer.fit(model, graph_train_loader, graph_val_loader)
        model = GraphLevelGNN.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
    # Evaluate the best model on the training and test sets
    train_result = trainer.test(model, graph_train_loader, verbose=False)
    test_result = trainer.test(model, graph_test_loader, verbose=False)
    result = {'test': test_result[0]['test_acc'], 'train': train_result[0]['test_acc']}
    return model, result, trainer

In [None]:
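# `!VAR=1` would only set the variable in a throw-away subshell; `%env` sets it
# in the kernel itself. Both only take effect if set before CUDA is initialized.
%env PYTORCH_NO_CUDA_MEMORY_CACHING=1
%env CUDA_LAUNCH_BLOCKING=1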
model, result, trainer = train_graph_classifier(model_name='GraphConv',
                                                c_hidden=128,
                                                layer_name='GraphConv',
                                                num_layers=3,
                                                dp_rate_linear=0.5,
                                                dp_rate=0.0)
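Why `!VAR=1` has no effect on the kernel: IPython runs each `!` line in a separate child shell, and environment changes made by a child process never propagate back to its parent. The sketch below reproduces this with a plain Python subprocess; `DEMO_FLAG` and `set_env_in_child` are hypothetical names used only for illustration.

```python
import os
import subprocess
import sys


def set_env_in_child(name: str, value: str) -> None:
    """Set an environment variable inside a short-lived child process."""
    code = f"import os; os.environ[{name!r}] = {value!r}"
    subprocess.run([sys.executable, "-c", code], check=True)


os.environ.pop("DEMO_FLAG", None)
set_env_in_child("DEMO_FLAG", "1")
# The child's change is lost when it exits, just like `!DEMO_FLAG=1`
print(os.environ.get("DEMO_FLAG"))  # None

# Setting the variable in the current process (what `%env` does) persists
os.environ["DEMO_FLAG"] = "1"
print(os.environ.get("DEMO_FLAG"))  # 1
```

For CUDA-related variables such as `CUDA_LAUNCH_BLOCKING`, the value also has to be in place before CUDA is initialized in the kernel, so setting them in the first cell (or before starting Jupyter) is the safest option.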

In [None]:
print(f"Train performance: {100.0*result['train']:4.2f}%")
print(f"Test performance:  {100.0*result['test']:4.2f}%")

In [None]:
metrics = pd.read_csv(f'{trainer.logger.log_dir}/metrics.csv')
del metrics['step']
metrics.set_index('epoch', inplace=True)
display(metrics.dropna(axis=1, how='all').head())
g = sns.relplot(data=metrics, kind='line')
g = plt.grid(alpha=0.7)
plt.show()

In [None]:
predictions = []
y = []
# Make predictions and plot
with torch.no_grad():
    _ = model.eval()
    for batch in graph_test_loader:
        _, preds, _ = model(batch['smiles'], batch['labels'])
        predictions.extend(preds.detach().tolist())
        y.extend(batch['labels'].detach().tolist())
predictions = np.array(predictions).flatten()
y = np.array(y).flatten()
sorted_idx = np.argsort(y)

# g = plt.scatter(np.arange(len(y)), predictions[sorted_idx], label='predictions', marker='d')
# g = plt.scatter(np.arange(len(y)), y[sorted_idx], label='labels', marker='x')
g = plt.plot(np.arange(len(y)), predictions[sorted_idx], label='predictions')
g = plt.plot(np.arange(len(y)), y[sorted_idx], label='labels')
g = plt.legend()
g = plt.grid(alpha=0.8)
g = plt.xlabel('Test ID (sorted by reference pDC50)')
g = plt.ylabel('pDC50')
plt.show()
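The plot above reorders both arrays with the same `np.argsort(y)` permutation, so the reference curve becomes monotonic while each prediction stays aligned with its own label. A tiny example with made-up numbers (not data from this notebook):

```python
import numpy as np

y = np.array([7.1, 5.3, 8.4, 6.0])            # reference values (hypothetical)
predictions = np.array([6.8, 5.9, 7.7, 6.2])  # model outputs for the same entries

# Permutation that sorts y in ascending order; 'stable' keeps ties in their
# original order, which matters when many labels are equal (e.g., 0/1 labels)
sorted_idx = np.argsort(y, kind='stable')
print(sorted_idx)  # [1 3 0 2]

y_sorted = y[sorted_idx]                # ascending reference values
pred_sorted = predictions[sorted_idx]   # predictions, still paired with their labels
```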

### SSL via Dive-Into-Graphs (DIG)

The intuition behind Graph Contrastive Learning (CL) is summarized in this [poster](https://yyou1996.github.io/files/neurips2020_graphcl_poster.pdf).

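GraphCL trains the encoder so that two augmented views of the same molecule map to nearby embeddings, while views of different molecules are pushed apart; the GraphCL paper uses the NT-Xent (normalized temperature-scaled cross-entropy) loss for this. Below is a NumPy sketch of that loss, written for illustration and not taken from DIG; `tau` is the temperature and row `i` of `z1` and `z2` is assumed to be the positive pair.

```python
import numpy as np


def nt_xent(z1: np.ndarray, z2: np.ndarray, tau: float = 0.5) -> float:
    """NT-Xent loss between two batches of view embeddings of shape (N, d)."""
    # L2-normalize so that dot products are cosine similarities
    z1 = z1 / np.linalg.norm(z1, axis=1, keepdims=True)
    z2 = z2 / np.linalg.norm(z2, axis=1, keepdims=True)
    z = np.concatenate([z1, z2], axis=0)   # (2N, d)
    sim = z @ z.T / tau                    # (2N, 2N) scaled similarities
    np.fill_diagonal(sim, -np.inf)         # a view is never its own negative
    n = z1.shape[0]
    # Index of the positive partner (the other view) for each of the 2N rows
    pos = np.concatenate([np.arange(n, 2 * n), np.arange(0, n)])
    # Numerically stable log-softmax of each row, read at the positive index
    row_max = sim.max(axis=1, keepdims=True)
    log_denom = row_max[:, 0] + np.log(np.exp(sim - row_max).sum(axis=1))
    log_prob_pos = sim[np.arange(2 * n), pos] - log_denom
    return float(-log_prob_pos.mean())


views = np.eye(4)
print(nt_xent(views, views))                      # aligned views: low loss
print(nt_xent(views, np.roll(views, 1, axis=0)))  # mismatched views: higher loss
```
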
To store the molecules as graphs, we construct a PyTorch Geometric dataset following [this guide](https://pytorch-geometric.readthedocs.io/en/latest/tutorial/create_dataset.html).

In [None]:
import os
import shutil

import numpy as np
import pandas as pd
import torch
from torch_geometric.data import InMemoryDataset, download_url
from torch_geometric.utils import from_smiles

class ProtacGeomDataset(InMemoryDataset):
    def __init__(self, root=None, transform=None, pre_transform=None,
                 pre_filter=None,
                 protac_db_csv='protac-db_leftovers.csv',
                 use_for_ssl=True,
                 binary_classification=False):
        self.__dict__.update(locals()) # Add arguments as attributes
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return [self.protac_db_csv]

    @property
    def processed_file_names(self):
        # Return a bare file name: PyG joins it with `self.processed_dir`
        protac_db_file = os.path.splitext(self.protac_db_csv)[0]
        protac_pt_file = protac_db_file + ('_bin' if self.binary_classification else '') + '.pt'
        return [protac_pt_file]

    def download(self):
        # Download to `self.raw_dir`.
        src = os.path.join(self.root, self.protac_db_csv)
        dst = os.path.join(self.raw_dir, self.protac_db_csv)
        shutil.copy(src, dst)
        # TODO: Implement this method when the project goes opensource
        # download_url(url, self.raw_dir)

    def process(self):
        # Read data into huge `Data` list.
        df_file = os.path.join(self.raw_dir, self.protac_db_csv)
        dataframe = pd.read_csv(df_file)
        dataframe = dataframe.dropna(subset=['Smiles_nostereo'])
        if self.use_for_ssl:
            dataframe = dataframe.drop_duplicates(subset=['Smiles_nostereo'])
        smiles = dataframe['Smiles_nostereo']
        data_list = [from_smiles(s) for s in smiles]
        # Convert x features to float
        for d in data_list:
            d.x = d.x.to(torch.float32)
        # Store labels for each graph
        if not self.use_for_ssl:
            if self.binary_classification:
                y_data = dataframe['active'].astype(np.int64).to_numpy()
            else:
                y_data = (dataframe['degradation'].astype(np.float32) * 0.01).to_numpy()[..., None]
            for d, y in zip(data_list, y_data):
                d.y = y
        # Filter entries
        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]
        # Apply pre-transformation
        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]
        # Finally save the Pytorch Geometric Data entries
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
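PyTorch Geometric builds `processed_paths` by joining `self.processed_dir` with each entry of `processed_file_names`, so those entries should be bare file names. `os.path.join` discards every component that precedes an absolute path, which is why returning a full path sends the cached `.pt` file somewhere other than the processed directory. The paths and the `processed_name` helper below are hypothetical (POSIX paths assumed):

```python
import os

processed_dir = "/data/protac/processed"

# A bare file name ends up inside processed_dir
inside = os.path.join(processed_dir, "protac-db_ssl_bin.pt")
print(inside)   # /data/protac/processed/protac-db_ssl_bin.pt

# An absolute path silently replaces processed_dir
outside = os.path.join(processed_dir, "/data/protac/raw/protac-db_ssl_bin.pt")
print(outside)  # /data/protac/raw/protac-db_ssl_bin.pt


def processed_name(csv_name: str, binary_classification: bool) -> str:
    """Bare .pt file name derived from the CSV name (hypothetical helper)."""
    stem = os.path.splitext(os.path.basename(csv_name))[0]
    return stem + ('_bin' if binary_classification else '') + '.pt'
```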

In [None]:
# Save the train/validation DataFrames to CSV (the validation split is used as test set)
train_bin_df.to_csv(os.path.join(data_dir, 'protac', 'protac-db_curated_train_bin.csv'), index=False)
train_df.to_csv(os.path.join(data_dir, 'protac', 'protac-db_curated_train.csv'), index=False)
train_upsampled_bin_df.to_csv(os.path.join(data_dir, 'protac', 'protac-db_curated_train_upsampled_bin.csv'), index=False)
val_bin_df.to_csv(os.path.join(data_dir, 'protac', 'protac-db_curated_test_bin.csv'), index=False)
val_df.to_csv(os.path.join(data_dir, 'protac', 'protac-db_curated_test.csv'), index=False)

In [None]:
train_dataset_bin = ProtacGeomDataset(os.path.join(data_dir, 'protac'),
                                      protac_db_csv='protac-db_curated_train_bin.csv',
                                      use_for_ssl=False,
                                      binary_classification=True)
test_dataset_bin = ProtacGeomDataset(os.path.join(data_dir, 'protac'),
                                     protac_db_csv='protac-db_curated_test_bin.csv',
                                     use_for_ssl=False,
                                     binary_classification=True)
ssl_dataset_bin = ProtacGeomDataset(os.path.join(data_dir, 'protac'),
                                    protac_db_csv='protac-db_ssl.csv',
                                    use_for_ssl=True,
                                    binary_classification=True)
ssl_dataset = ProtacGeomDataset(os.path.join(data_dir, 'protac'),
                                protac_db_csv='protac-db_ssl.csv',
                                use_for_ssl=True,
                                binary_classification=False)
# TODO: only binary classification is supported for now
# train_dataset = ProtacGeomDataset(os.path.join(data_dir, 'protac'),
#                                   protac_db_csv='protac-db_curated_train.csv',
#                                   use_for_ssl=False,
#                                   binary_classification=False)
# test_dataset = ProtacGeomDataset(os.path.join(data_dir, 'protac'),
#                                  protac_db_csv='protac-db_curated_test.csv',
#                                  use_for_ssl=False,
#                                  binary_classification=False)

The following SSL code comes from [this guide](https://diveintographs.readthedocs.io/en/latest/tutorials/sslgraph.html#id10).

In [None]:
from dig.sslgraph.utils import Encoder
from dig.sslgraph.method import GraphCL, GRACE
from dig.sslgraph.evaluation import GraphSemisupervised, GraphUnsupervised

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

feat_dim = ssl_dataset[0].x.shape[1]
gnn_type = 'resgcn' # Possible values: resgcn | gcn | gin
n_layers = 4
embed_dim = 128
# Define SMILES encoder
encoder = Encoder(feat_dim, embed_dim, n_layers=n_layers, gnn=gnn_type)
# Define prediction head as a linear layer
# NOTE: Unless using 'resgcn', the other predefined GNNs return an output of
# shape: (batch_size, embedded_size * n_layers)
embed_dim = embed_dim if gnn_type == 'resgcn' else embed_dim * n_layers
prediction_head = nn.Linear(embed_dim, 2)
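The `embed_dim` adjustment above can be captured in a small helper. The sketch assumes what the NOTE states: DIG's 'resgcn' encoder returns `embed_dim` features, while 'gcn' and 'gin' concatenate one `embed_dim` block per layer. `head_input_dim` is a hypothetical name, not part of DIG.

```python
def head_input_dim(gnn_type: str, embed_dim: int, n_layers: int) -> int:
    """Number of features the prediction head receives from the encoder."""
    if gnn_type not in {"resgcn", "gcn", "gin"}:
        raise ValueError(f"Unsupported gnn_type: {gnn_type!r}")
    # 'resgcn' returns a single embedding, the others concatenate per-layer ones
    return embed_dim if gnn_type == "resgcn" else embed_dim * n_layers


print(head_input_dim("resgcn", 128, 4))  # 128
print(head_input_dim("gin", 128, 4))     # 512
```

Using the same value for both the `nn.Linear` head and the contrastive model's dimension argument keeps the two consistent when switching encoder types.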

[List](https://diveintographs.readthedocs.io/en/latest/sslgraph/method.html#dig.sslgraph.method.GraphCL) of Graph Contrastive Learning (CL) techniques available in DIG.

A [poster](https://yyou1996.github.io/files/neurips2020_graphcl_poster.pdf) with the intuitions behind Graph CL.

In [None]:
# graphcl = GraphCL(embed_dim, aug_1='subgraph', aug_2='dropN')
graphcl = GRACE(embed_dim, dropE_rate_1=0.5, dropE_rate_2=0.5, maskN_rate_1=0.5, maskN_rate_2=0.5)

In [None]:
# # Define the evaluator, i.e., trainer, and its configuration parameters
# evaluator = GraphUnsupervised(train_dataset, log_interval=10, device=device,
#                               p_optim='Adam',
#                               p_lr=1e-3,
#                               p_weight_decay=0.9,
#                               p_epoch=10)
# evaluator.evaluate(learning_model=graphcl,
#                    encoder=encoder)

[List](https://diveintographs.readthedocs.io/en/latest/sslgraph/evaluation.html#dig-sslgraph-evaluation) of Graph CL trainers.

In [None]:
evaluator = GraphSemisupervised(train_dataset_bin, ssl_dataset_bin, label_rate=0.3,
                                n_folds=10, device=device)
evaluator.setup_train_config(batch_size=256,
                             p_lr=1e-5,
                             p_epoch=5,
                             f_lr=1e-5,
                             f_epoch=5)
test_acc, _ = evaluator.evaluate(learning_model=graphcl,
                                 encoder=encoder,
                                 pred_head=prediction_head,
                                 fold_seed=42)
print(f'Test accuracy: {test_acc:.2f}')
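Conceptually, `GraphSemisupervised` pre-trains on the unlabelled `ssl_dataset_bin` and then fine-tunes on the labelled `train_dataset_bin`, where `label_rate=0.3` keeps only a fraction of the training labels and `n_folds=10` repeats the evaluation over ten train/test splits. The NumPy sketch below illustrates that idea with plain (non-stratified) folds; it is an assumption-laden illustration, not DIG's splitting code, which may stratify or order things differently. `semisupervised_splits` is a hypothetical name.

```python
import numpy as np


def semisupervised_splits(n_samples: int, n_folds: int, label_rate: float, seed: int = 42):
    """Yield (labelled_train_idx, test_idx) pairs for k-fold, reduced-label evaluation."""
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(n_samples), n_folds)
    for k in range(n_folds):
        test_idx = folds[k]
        train_idx = np.concatenate([f for i, f in enumerate(folds) if i != k])
        # Keep only a `label_rate` fraction of the training labels for fine-tuning
        n_labelled = max(1, int(round(label_rate * len(train_idx))))
        labelled_idx = rng.choice(train_idx, size=n_labelled, replace=False)
        yield labelled_idx, test_idx


for labelled, test in semisupervised_splits(100, n_folds=10, label_rate=0.3):
    print(len(labelled), len(test))  # 27 10 for every fold
```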

In [None]:
model = nn.Sequential(encoder, prediction_head, nn.Softmax(dim=1))

predictions = []
targets = []
# Make predictions
with torch.no_grad():
    model.eval()
    for batch in test_dataset_bin:
        model_pred = model(batch.to(device)).argmax(dim=1).detach().tolist()
        predictions.extend(model_pred)
        targets.extend(batch.y.detach().tolist())
predictions = np.array(predictions).flatten()
targets = np.array(targets).flatten()
sorted_idx = np.argsort(targets)
# Get accuracy
accuracy = Accuracy(task='binary')
acc = accuracy(torch.Tensor(predictions), torch.Tensor(targets))
# Plot predictions (sorted by reference label)
plt.plot(predictions[sorted_idx], label='Predicted activity')
plt.plot(targets[sorted_idx], label='Reference activity')
plt.legend()
plt.grid(alpha=0.8)
plt.xlabel('Test ID (sorted by reference label)')
plt.ylabel('Active (1) / inactive (0)')
plt.title(f'Predicting PROTAC activity (accuracy: {acc:.2f})')
plt.show()
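For hard 0/1 predictions, like the `argmax` outputs above, binary accuracy reduces to the fraction of matching labels, so it can be cross-checked with plain NumPy. A minimal sketch with toy values (`binary_accuracy` is a hypothetical helper, not a torchmetrics function):

```python
import numpy as np


def binary_accuracy(predictions: np.ndarray, targets: np.ndarray) -> float:
    """Fraction of 0/1 predictions that match the 0/1 targets."""
    predictions = np.asarray(predictions).astype(int)
    targets = np.asarray(targets).astype(int)
    if predictions.shape != targets.shape:
        raise ValueError("predictions and targets must have the same shape")
    return float((predictions == targets).mean())


preds = np.array([1, 0, 1, 1, 0])
refs = np.array([1, 0, 0, 1, 1])
print(binary_accuracy(preds, refs))  # 0.6
```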

Define the Optuna objective:

In [None]:
feat_dim = ssl_dataset[0].x.shape[1]
accuracy = Accuracy(task='binary')

def run_objective_body(**kwargs):
    batch_size = kwargs['batch_size']
    label_rate = kwargs['label_rate']
    p_lr = kwargs['p_lr']
    p_epoch = kwargs['p_epoch']
    f_lr = kwargs['f_lr']
    f_epoch = kwargs['f_epoch']
    gnn_type = kwargs['gnn_type']
    n_layers = kwargs['n_layers']
    embed_dim = kwargs['embed_dim']
    # Define SMILES encoder
    encoder = Encoder(feat_dim, embed_dim, n_layers=n_layers, gnn=gnn_type)
    # Define prediction head as a linear layer
    # NOTE: Unless using 'resgcn', the other predefined GNNs return an output of
    # shape: (batch_size, embedded_size * n_layers)
    embed_dim = embed_dim if gnn_type == 'resgcn' else embed_dim * n_layers
    prediction_head = nn.Linear(embed_dim, 2)
    # Define Contrastive Learning framework
    cl_framework = kwargs['cl_framework']
    if cl_framework == 'graphcl':
        graphcl = GraphCL(embed_dim, aug_1='subgraph', aug_2='dropN')
    elif cl_framework == 'grace':
        graphcl = GRACE(embed_dim, dropE_rate_1=0.5, dropE_rate_2=0.5, maskN_rate_1=0.5, maskN_rate_2=0.5)
    else:
        raise ValueError(f'Unknown contrastive learning framework: {cl_framework}')
    # Define trainer
    evaluator = GraphSemisupervised(train_dataset_bin, ssl_dataset_bin,
                                    label_rate=label_rate,
                                    device=device)
    evaluator.setup_train_config(batch_size=batch_size,
                                 p_lr=p_lr,
                                 p_epoch=p_epoch,
                                 f_lr=f_lr,
                                 f_epoch=f_epoch)
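    # NOTE: DIG's evaluate() reports the accuracy averaged over its internal
    # folds of the training data, not an accuracy on `test_dataset_bin`
    train_acc, _ = evaluator.evaluate(learning_model=graphcl,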
                                      encoder=encoder,
                                      pred_head=prediction_head)
    # Test model
    model = nn.Sequential(encoder, prediction_head, nn.Softmax(dim=1))
    predictions = []
    targets = []
    # Make predictions
    with torch.no_grad():
        model.eval()
        for batch in test_dataset_bin:
            model_pred = model(batch.to(device)).argmax(dim=1).detach().tolist()
            predictions.extend(model_pred)
            targets.extend(batch.y.detach().tolist())
    predictions = np.array(predictions).flatten()
    targets = np.array(targets).flatten()
    # Get test accuracy as a plain float, so it can be stored by Optuna
    test_acc = accuracy(torch.Tensor(predictions), torch.Tensor(targets)).item()
    return train_acc, test_acc, predictions, targets

def objective(trial):
    params = {
        'label_rate': trial.suggest_float('label_rate', 0.3, 1.0),
        'p_lr': trial.suggest_float('p_lr', 1e-6, 1e-3, log=True),
        'p_epoch': trial.suggest_int('p_epoch', 10, 31),
        'f_lr': trial.suggest_float('f_lr', 1e-6, 1e-3, log=True),
        'f_epoch': trial.suggest_int('f_epoch', 10, 31),
        'gnn_type': trial.suggest_categorical('gnn_type', ['resgcn', 'gcn', 'gin']),
        'n_layers': trial.suggest_int('n_layers', 3, 10),
        'embed_dim': trial.suggest_categorical('embed_dim', [64 + 128 * i for i in range(6)]),
        'batch_size': trial.suggest_categorical('batch_size', [64, 128, 256]),
        'cl_framework': trial.suggest_categorical('cl_framework', ['graphcl', 'grace']),
    }
    train_acc, test_acc, predictions, targets = run_objective_body(**params)
    # Log parameters
    trial.set_user_attr('train_acc', train_acc)
    trial.set_user_attr('test_acc', test_acc)
    print(f'Training accuracy: {train_acc:.2f}')
    print(f'Test accuracy: {test_acc:.2f}')
    return test_acc

In [None]:
study = optuna.create_study(direction='maximize',
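                            # NOTE: the pruner only acts on values sent via `trial.report()`,
                            # which `objective` does not call, so no trial is pruned
                            pruner=optuna.pruners.MedianPruner(n_warmup_steps=5),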
                            sampler=optuna.samplers.TPESampler(seed=42))
study.optimize(objective, n_trials=10, timeout=600)
print('-' * 80)
print('Number of finished trials: {}'.format(len(study.trials)))
print('Best trial:')
print(f'\tValue: {study.best_trial.value}')
print('\tParams: ')
for key, value in study.best_trial.params.items():
    print(f'\t\t* {key}: {value}')
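Optuna's `MedianPruner` can only stop a trial based on intermediate values reported with `trial.report(value, step)` and checked with `trial.should_prune()`; since `objective` reports nothing, every trial above runs to completion. The sketch below is a simplified re-implementation of the median rule for a maximized metric, meant only to show the logic. It is not Optuna's code and ignores details such as `n_startup_trials` and interval steps; `median_rule_prunes` is a hypothetical name.

```python
import statistics


def median_rule_prunes(trial_values: dict, completed: list, step: int,
                       n_warmup_steps: int = 5) -> bool:
    """Simplified median pruning check for a maximization study.

    trial_values: {step: value} reported so far by the running trial
    completed: list of {step: value} dicts from finished trials
    """
    if not trial_values or step < n_warmup_steps:
        return False  # nothing reported yet, or still warming up
    others = [history[step] for history in completed if step in history]
    if not others:
        return False  # no reference values at this step
    # Prune if even the best value so far is below the median of the others
    return max(trial_values.values()) < statistics.median(others)
```

To make the pruner useful, the objective would have to report a per-epoch metric (e.g., `trial.report(val_acc, epoch)`) and raise `optuna.TrialPruned()` when `trial.should_prune()` returns True; with DIG's `evaluate()` running all epochs internally, that hook is not directly available.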

In [None]:
ret = run_objective_body(**study.best_trial.params)
train_acc, test_acc, predictions, targets = ret
sorted_idx = np.argsort(targets)
# Plot predictions (sorted by reference label)
plt.plot(predictions[sorted_idx], label='Predicted activity')
plt.plot(targets[sorted_idx], label='Reference activity')
plt.legend()
plt.grid(alpha=0.8)
plt.xlabel('Test ID (sorted by reference label)')
plt.ylabel('Active (1) / inactive (0)')
plt.title(f'Predicting PROTAC activity (test accuracy: {test_acc:.2f})')
plt.show()

#### GraphCL

Define a custom PyTorch Geometric dataset for SSL, following [this guide](https://pytorch-geometric.readthedocs.io/en/latest/tutorial/create_dataset.html) on creating PyTorch Geometric datasets.

In [None]:
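import os
import shutil
from typing import Callable, Literal

import numpy as np
import pandas as pd
import torch
from torch_geometric.data import InMemoryDataset, download_url
from torch_geometric.utils import from_smiles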

class ProtacGeomDataset(InMemoryDataset):

    def __init__(self,
                 root: str = None,
                 transform: Callable | None = None,
                 pre_transform: Callable | None = None,
                 pre_filter: Callable | None = None,
                 task: Literal['predict_active_inactive', 'predict_pDC50', 'predict_pDC50_and_Dmax'] = 'predict_active_inactive',
                 protac_df: pd.DataFrame | None = None,
                 protac_db_csv: str = 'protac-db_ssl.csv',
                 use_for_ssl: bool = True):
        """PROTAC PyTorch Geometric dataset for SSL and supervised training.

        Args:
            root (string): Root directory where the dataset should be saved.
            transform (callable, optional): A function/transform that takes in
                an :obj:`torch_geometric.data.Data` object and returns a
                transformed version. The data object will be transformed before
                every access. (default: :obj:`None`)
            pre_transform (callable, optional): A function/transform that takes
                in an :obj:`torch_geometric.data.Data` object and returns a
                transformed version. The data object will be transformed before
                being saved to disk. (default: :obj:`None`)
            pre_filter (callable, optional): A function that takes in an
                :obj:`torch_geometric.data.Data` object and returns a boolean
                value, indicating whether the data object should be included in
                the final dataset. (default: :obj:`None`)
            task (str): Which labels to store in each entry's `y`:
                'predict_active_inactive', 'predict_pDC50' or
                'predict_pDC50_and_Dmax'. Ignored if `use_for_ssl` is True.
                (default: :obj:`'predict_active_inactive'`)
            protac_df (pd.DataFrame, optional): DataFrame to build the dataset
                from. If :obj:`None`, `protac_db_csv` is read from the raw
                directory instead. (default: :obj:`None`)
            protac_db_csv (str): Name of the PROTAC database file in CSV format.
                The data must contain at least the 'Smiles_nostereo' column,
                plus the label columns required by `task` ('active', 'pDC50',
                'Dmax').
            use_for_ssl (bool): Whether to use the dataset for self-supervised
                learning. If True, duplicated SMILES are dropped and no labels
                are stored; if False, each entry has its `y` element assigned.
                (default: :obj:`True`)
        """
        self.__dict__.update(locals()) # Add arguments as attributes
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return [self.protac_db_csv]

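    @property
    def processed_file_names(self):
        # Return a bare file name: PyG joins it with `self.processed_dir`
        if self.protac_df is not None:
            import hashlib
            # Content hash, so that different DataFrames (e.g., train and test)
            # do not end up sharing the same cached file
            df_bytes = pd.util.hash_pandas_object(self.protac_df, index=False).values.tobytes()
            base_name = 'protac_df_' + hashlib.md5(df_bytes).hexdigest()[:12]
        else:
            base_name = os.path.splitext(self.protac_db_csv)[0]
        # SSL and supervised datasets (and each task) get separate cache files
        suffix = '_ssl' if self.use_for_ssl else f'_{self.task}'
        return [base_name + suffix + '.pt']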

    def download(self):
        # Nothing to fetch when the data is passed in as a DataFrame
        if self.protac_df is not None:
            return
        # Copy the CSV file from `self.root` to `self.raw_dir`
        src = os.path.join(self.root, self.protac_db_csv)
        dst = os.path.join(self.raw_dir, self.protac_db_csv)
        shutil.copy(src, dst)
        # TODO: Implement this method when the project goes opensource
        # download_url(url, self.raw_dir)

    def process(self):
        # Use the DataFrame passed to the constructor, if any; otherwise read the CSV
        if self.protac_df is None:
            df_file = os.path.join(self.raw_dir, self.protac_db_csv)
            dataframe = pd.read_csv(df_file)
        else:
            dataframe = self.protac_df.copy()
        dataframe = dataframe.dropna(subset=['Smiles_nostereo'])
        if self.use_for_ssl:
            dataframe = dataframe.drop_duplicates(subset=['Smiles_nostereo'])
        # Read data into huge list of `Data` type elements.
        smiles = dataframe['Smiles_nostereo']
        data_list = [from_smiles(s) for s in smiles]
        # Convert x features to float
        for d in data_list:
            d.x = d.x.to(torch.float32)
        # Store labels for each graph
        if not self.use_for_ssl:
            if self.task == 'predict_active_inactive':
                y_data = dataframe['active'].astype(np.int64).to_numpy()
            elif self.task == 'predict_pDC50':
                y_data = dataframe.pDC50.to_numpy()[..., None]
            elif self.task == 'predict_pDC50_and_Dmax':
                Dmax = dataframe.Dmax.to_numpy()
                pDC50 = dataframe.pDC50.to_numpy()
                # One (Dmax, pDC50) pair per entry, i.e., shape (N, 2)
                y_data = np.stack([Dmax, pDC50], axis=1)
            else:
                raise ValueError(f'Task "{self.task}" not recognized. Available: "predict_pDC50" | "predict_active_inactive" | "predict_pDC50_and_Dmax"')
            for d, y in zip(data_list, y_data):
                d.y = y
        # Filter entries
        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]
        # Apply pre-transformation
        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]
        # Finally save the Pytorch Geometric Data entries
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
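For the `predict_pDC50_and_Dmax` task, each graph needs its own (Dmax, pDC50) pair. `np.array([Dmax, pDC50])` has shape (2, N), so zipping it with N graphs pairs only the first two graphs with whole columns, while `np.stack(..., axis=1)` gives shape (N, 2) with one row per graph. Toy values (not from the dataset):

```python
import numpy as np

Dmax = np.array([90.0, 45.0, 70.0])   # hypothetical values, one per PROTAC
pDC50 = np.array([7.2, 6.1, 8.0])

wrong = np.array([Dmax, pDC50])          # shape (2, 3): one row per *target*
right = np.stack([Dmax, pDC50], axis=1)  # shape (3, 2): one row per *PROTAC*

print(wrong.shape, right.shape)          # (2, 3) (3, 2)
# zip() stops at the shorter input: with `wrong`, the third graph gets no label
print(len(list(zip(range(3), wrong))))   # 2
print(len(list(zip(range(3), right))))   # 3
```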

In [None]:
def get_train_test_geom_datasets(task, use_for_ssl=False, dataset_name='', use_upsampled=False):
    # Get naming convention for the datasets
    if task == 'predict_pDC50_and_Dmax':
        train_ds = os.path.join(data_dir, 'protac', f'train_pDC50_Dmax_geom_dataset{dataset_name}.pt')
        test_ds = os.path.join(data_dir, 'protac', f'test_pDC50_Dmax_geom_dataset{dataset_name}.pt')
    elif task == 'predict_pDC50':
        train_ds = os.path.join(data_dir, 'protac', f'train_pDC50_geom_dataset{dataset_name}.pt')
        test_ds = os.path.join(data_dir, 'protac', f'test_pDC50_geom_dataset{dataset_name}.pt')
    elif task == 'predict_active_inactive':
        test_ds = os.path.join(data_dir, 'protac', f'test_bin_geom_dataset{dataset_name}.pt')
        if use_upsampled:
            train_ds = os.path.join(data_dir, 'protac', f'train_upsampled_bin_geom_dataset{dataset_name}.pt')
        else:
            train_ds = os.path.join(data_dir, 'protac', f'train_bin_geom_dataset{dataset_name}.pt')
    else:
        raise ValueError(f'Task "{task}" not recognized')
    # Get specific dataframes according to the task
    if task == 'predict_active_inactive':
        train_df_tmp = train_upsampled_bin_df if use_upsampled else train_bin_df
        val_df_tmp = val_bin_df
    else:
        train_df_tmp = train_df
        val_df_tmp = val_df
    # Generate and save Pytorch Geometric specific datasets
    train_geom_dataset = ProtacGeomDataset(os.path.join(data_dir, 'protac'),
                                           protac_df=train_df_tmp,
                                           task=task,
                                           use_for_ssl=use_for_ssl)
    test_geom_dataset = ProtacGeomDataset(os.path.join(data_dir, 'protac'),
                                          protac_df=val_df_tmp,
                                          task=task,
                                          use_for_ssl=use_for_ssl)
    torch.save(train_geom_dataset, train_ds)
    torch.save(test_geom_dataset, test_ds)
    return train_geom_dataset, test_geom_dataset
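The file-naming branches above can also be written as a lookup table, which makes an unsupported task fail loudly instead of leaving `train_ds`/`test_ds` undefined. A sketch with a hypothetical helper name (`geom_dataset_paths`) that reproduces the same file names:

```python
import os

# Middle part of the file names for each task (as used in the notebook)
_TASK_TAGS = {
    'predict_pDC50_and_Dmax': 'pDC50_Dmax',
    'predict_pDC50': 'pDC50',
    'predict_active_inactive': 'bin',
}


def geom_dataset_paths(data_dir: str, task: str, dataset_name: str = '',
                       use_upsampled: bool = False) -> tuple[str, str]:
    """Return the (train, test) .pt paths for a given task."""
    if task not in _TASK_TAGS:
        raise ValueError(f'Unknown task {task!r}; expected one of {sorted(_TASK_TAGS)}')
    tag = _TASK_TAGS[task]
    # Upsampling is only defined for the binary task
    upsampled = use_upsampled and task == 'predict_active_inactive'
    train_prefix = 'train_upsampled' if upsampled else 'train'
    folder = os.path.join(data_dir, 'protac')
    train_ds = os.path.join(folder, f'{train_prefix}_{tag}_geom_dataset{dataset_name}.pt')
    test_ds = os.path.join(folder, f'test_{tag}_geom_dataset{dataset_name}.pt')
    return train_ds, test_ds
```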

In [None]:
from dig.sslgraph.utils import Encoder
from dig.sslgraph.method import GraphCL, GRACE
from dig.sslgraph.evaluation import GraphSemisupervised, GraphUnsupervised
from dig.sslgraph.method import Contrastive

class SSLTrainer(pl.LightningModule):

    def __init__(self,
                 feat_dim:int=datasets['ssl'][0].x.shape[1],
                 gnn_type:str='resgcn', # Possible values: resgcn | gcn | gin
                 n_layers:int=4,
                 embed_dim:int=768,
                 binary_classification:bool=True,
                 batch_size:int=64,
                 learning_rate:float=1e-3,
                 dropE_rate_1=0.5,
                 dropE_rate_2=0.5,
                 maskN_rate_1=0.5,
                 maskN_rate_2=0.5,
                 train_dataset=datasets['ssl'],
                 test_dataset=datasets['test_ssl'],
                 **kwargs):
        super().__init__()
        # Save the arguments passed to init
        self.save_hyperparameters()
        # Set our init args as class attributes
        self.__dict__.update(locals()) # Add arguments as attributes
        # Define PyTorch models
        self.encoder = Encoder(feat_dim, embed_dim, n_layers=n_layers,
                               gnn=gnn_type)
        self.out_dim = 2 if binary_classification else 2
        if gnn_type == 'resgcn':
            self.embed_dim = embed_dim
        else:
            self.embed_dim = embed_dim * n_layers
        # The head takes the (possibly concatenated) encoder output
        self.head = nn.Linear(self.embed_dim, self.out_dim)
        # Loss and evaluation metrics
        self.val_acc = Accuracy('binary')
        self.test_acc = Accuracy('binary')
        self.val_mse = MeanSquaredError()
        self.test_mse = MeanSquaredError()

        # # Contrastive Learning model
        # self.graphcl = GRACE(self.embed_dim, dropE_rate_1=0.5, dropE_rate_2=0.5,
        #                      maskN_rate_1=0.5, maskN_rate_2=0.5)
        self.graphcl = GraphCL(self.embed_dim, aug_1='subgraph', aug_2='dropN')
        # Override the train() method in graphcl and allow Lightning to take
        # care of the entire training. The replacement is a no-op that accepts
        # any arguments (e.g., the `mode` flag passed when modules are switched
        # to training mode)
        def _noop_train(*args, **kwargs):
            return None
        self.graphcl.train = _noop_train
        # Get projection head dimensions
        if self.graphcl.z_n_dim is None:
            self.graphcl.proj_out_dim = self.graphcl.z_dim
        else:
            self.graphcl.proj_out_dim = self.graphcl.z_n_dim
        # Get graph-level projection head
        if self.graphcl.graph_level and self.graphcl.proj is not None:
            self.graphcl.proj_head_g = self.graphcl._get_proj(self.graphcl.proj,
                                                              self.graphcl.z_dim)
        elif self.graphcl.graph_level:
            self.graphcl.proj_head_g = lambda x: x
        else:
            self.graphcl.proj_head_g = None
        # Get node-level projection head
        if self.graphcl.node_level and self.graphcl.proj_n is not None:
            self.graphcl.proj_head_n = self.graphcl._get_proj(self.graphcl.proj_n,
                                                              self.graphcl.z_n_dim)
        elif self.graphcl.node_level:
            self.graphcl.proj_head_n = lambda x: x
        else:
            self.graphcl.proj_head_n = None

    def forward(self, x_in):
        x = self.encoder(x_in)
        return self.head(x)

    def training_step(self, data, batch_idx):
        # output of each encoder should be Tensor for graph-level embedding
        if isinstance(self.encoder, list):
            assert len(self.encoder) == len(self.graphcl.views_fn)
            encoders = self.encoder
            [enc.train() for enc in encoders]
        else:
            self.encoder.train()
            encoders = [self.encoder] * len(self.graphcl.views_fn)
        # Update projection heads, if possible
        try:
            if self.graphcl.node_level and self.graphcl.graph_level:
                self.graphcl.proj_head_g.train()
                self.graphcl.proj_head_n.train()
            elif self.graphcl.graph_level:
                self.graphcl.proj_head_g.train()
            else:
                self.graphcl.proj_head_n.train()
        except AttributeError:
            # Identity (lambda) or missing (None) projection heads have no .train()
            pass
        # Assemble graph views
        if None in self.graphcl.views_fn:
            views = []
            for v_fn in self.graphcl.views_fn:
                # For view fn that returns multiple views
                if v_fn is not None:
                    views += [*v_fn(data)]
            assert len(views) == len(encoders)
        else:
            views = [v_fn(data) for v_fn in self.graphcl.views_fn]
        # Get embeddings per views
        zs_n, zs_g = [], []
        for view, enc in zip(views, encoders):
            # Run encoder
            if self.graphcl.node_level and self.graphcl.graph_level:
                z_g, z_n = self.graphcl._get_embed(enc, view)
                zs_n.append(self.graphcl.proj_head_n(z_n))
                zs_g.append(self.graphcl.proj_head_g(z_g))
            elif self.graphcl.graph_level:
                z_g = self.graphcl._get_embed(enc, view)
                zs_g.append(self.graphcl.proj_head_g(z_g))
            else:
                z_n = self.graphcl._get_embed(enc, view)
                zs_n.append(self.graphcl.proj_head_n(z_n))
        # Get loss
        if self.graphcl.node_level and self.graphcl.graph_level:
            loss = self.graphcl.loss_fn(zs_g, zs_n=zs_n, batch=data.batch,
                                        neg_by_crpt=self.graphcl.neg_by_crpt,
                                        tau=self.graphcl.tau)
        elif self.graphcl.graph_level:
            loss = self.graphcl.loss_fn(zs_g,
                                        neg_by_crpt=self.graphcl.neg_by_crpt,
                                        tau=self.graphcl.tau)
        else:
            loss = self.graphcl.loss_fn(zs_g=None, zs_n=zs_n, batch=data.batch,
                                        neg_by_crpt=self.graphcl.neg_by_crpt,
                                        tau=self.graphcl.tau)
        # Reporting and return
        self.log(f'train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x = self.encoder(batch)
        y_hat = self.head(x)
        if self.binary_classification:
            y_hat = nn.functional.softmax(y_hat, dim=1).argmax(dim=1)
            acc = self.val_acc(y_hat, batch.y)
            self.log(f'val_acc', acc, prog_bar=True)
            return acc
        else:
            y_hat = y_hat.squeeze(-1)
            y = torch.as_tensor(batch.y, dtype=y_hat.dtype, device=y_hat.device).view_as(y_hat)
            mse = self.val_mse(y_hat, y)
            self.log(f'val_mse', mse, prog_bar=True)
            return mse

    def test_step(self, batch, batch_idx):
        x = self.encoder(batch)
        y_hat = self.head(x)
        if self.binary_classification:
            y_hat = nn.functional.softmax(y_hat, dim=1).argmax(dim=1)
            acc = self.test_acc(y_hat, batch.y)
            self.log(f'test_acc', acc, prog_bar=True)
            return acc
        else:
            y_hat = y_hat.squeeze(-1)
            y = torch.as_tensor(batch.y, dtype=y_hat.dtype, device=y_hat.device).view_as(y_hat)
            mse = self.test_mse(y_hat, y)
            self.log(f'test_mse', mse, prog_bar=True)
            return mse

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=custom_collate)

    def val_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, collate_fn=custom_collate)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, collate_fn=custom_collate)

    def predict_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, collate_fn=custom_collate)
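`GraphCL(..., aug_1='subgraph', aug_2='dropN')` builds the two views used in `training_step` by augmenting each molecular graph. As a conceptual illustration of the node-dropping ('dropN') idea only, and not DIG's implementation, the NumPy sketch below removes a random fraction of nodes from a graph stored as a PyG-style `edge_index` of shape `[2, E]`, deletes the edges touching them and re-indexes the surviving nodes. `drop_nodes` and `drop_ratio` are hypothetical names.

```python
import numpy as np


def drop_nodes(x: np.ndarray, edge_index: np.ndarray, drop_ratio: float = 0.2, seed: int = 0):
    """Return (x_kept, edge_index_kept) after randomly dropping nodes."""
    rng = np.random.default_rng(seed)
    n_nodes = x.shape[0]
    n_drop = int(n_nodes * drop_ratio)
    dropped = rng.choice(n_nodes, size=n_drop, replace=False)
    keep_mask = np.ones(n_nodes, dtype=bool)
    keep_mask[dropped] = False
    # Keep only edges whose two endpoints both survive
    edge_mask = keep_mask[edge_index[0]] & keep_mask[edge_index[1]]
    kept_edges = edge_index[:, edge_mask]
    # Map old node ids to new contiguous ids 0..n_kept-1
    new_id = -np.ones(n_nodes, dtype=int)
    new_id[keep_mask] = np.arange(keep_mask.sum())
    return x[keep_mask], new_id[kept_edges]


# A 10-node path graph (edges in both directions), node features = original ids
x = np.arange(10).reshape(10, 1)
src = list(range(9)) + list(range(1, 10))
dst = list(range(1, 10)) + list(range(9))
x_view, edge_view = drop_nodes(x, np.array([src, dst]), drop_ratio=0.2)
```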

In [None]:
from pytorch_lightning.callbacks import TQDMProgressBar
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from dig.sslgraph.method.contrastive.views_fn import NodeAttrMask, EdgePerturbation, Sequential

dropE_rate_1 = 0.5
dropE_rate_2 = 0.5
maskN_rate_1 = 0.5
maskN_rate_2 = 0.5
model = SSLTrainer(embed_dim=32, batch_size=128, binary_classification=False)

callbacks = [
    TQDMProgressBar(refresh_rate=20),
    # EarlyStopping(monitor='val_loss', mode='min'),
    # ModelCheckpoint(save_weights_only=True, mode='min', monitor='val_loss'),
]

trainer = pl.Trainer(max_epochs=2,
                     gradient_clip_val=1.0,
                     gradient_clip_algorithm='norm',
                     accelerator='auto',
                     devices=1,
                     log_every_n_steps=8,
                     callbacks=callbacks,
                     logger=CSVLogger(save_dir='logs/'),
                     deterministic=True)
trainer.fit(model)

#### PyTorch Lightning

In [None]:
from dig.sslgraph.utils import Encoder
from dig.sslgraph.method import GraphCL, GRACE
from dig.sslgraph.evaluation import GraphSemisupervised, GraphUnsupervised
from dig.sslgraph.method import Contrastive

class SSLTrainer(pl.LightningModule):

    def __init__(self,
                 feat_dim:int=ssl_dataset[0].x.shape[1],
                 gnn_type:str='resgcn', # Possible values: resgcn | gcn | gin
                 n_layers:int=4,
                 embed_dim:int=768,
                 binary_classification:bool=True,
                 batch_size:int=64,
                 learning_rate:float=1e-3,
                 dropE_rate_1=0.5,
                 dropE_rate_2=0.5,
                 maskN_rate_1=0.5,
                 maskN_rate_2=0.5,
                 **kwargs):
        super().__init__()
        # Save the arguments passed to init
        self.save_hyperparameters()
        # Set our init args as class attributes
        self.__dict__.update(locals()) # Add arguments as attributes
        # Define PyTorch models
        self.encoder = Encoder(feat_dim, embed_dim, n_layers=n_layers,
                               gnn=gnn_type)
        self.out_dim = 2 if binary_classification else 1
        # DIG's Encoder returns embed_dim features for 'resgcn', while 'gcn' and
        # 'gin' concatenate the pooled output of every layer (embed_dim * n_layers)
        if gnn_type == 'resgcn':
            self.embed_dim = embed_dim
        else:
            self.embed_dim = embed_dim * n_layers
        self.head = nn.Linear(self.embed_dim, self.out_dim)
        # Loss and evaluation metrics
        self.val_acc = Accuracy('binary')
        self.test_acc = Accuracy('binary')
        self.val_mse = MeanSquaredError()
        self.test_mse = MeanSquaredError()

        # # Contrastive Learning model
        # self.graphcl = GRACE(self.embed_dim, dropE_rate_1=0.5, dropE_rate_2=0.5,
        #                      maskN_rate_1=0.5, maskN_rate_2=0.5)
        self.graphcl = GraphCL(self.embed_dim, aug_1='subgraph', aug_2='dropN')
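        # DIG's Contrastive.train() implements DIG's own training loop, but
        # nn.Module.train(mode), called by Lightning via model.train()/eval(),
        # forwards `mode` to every child's train(). Shadow it with a no-op on
        # the instance so Lightning can take care of the entire training. Being
        # an unbound function, it receives `mode` as its only argument.
        def _skip_train(mode=True):
            pass
        self.graphcl.train = _skip_train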
        # Get projection head dimensions
        if self.graphcl.z_n_dim is None:
            self.graphcl.proj_out_dim = self.graphcl.z_dim
        else:
            self.graphcl.proj_out_dim = self.graphcl.z_n_dim
        # Get graph-level projection head
        if self.graphcl.graph_level and self.graphcl.proj is not None:
            self.graphcl.proj_head_g = self.graphcl._get_proj(self.graphcl.proj,
                                                              self.graphcl.z_dim)
        elif self.graphcl.graph_level:
            self.graphcl.proj_head_g = lambda x: x
        else:
            self.graphcl.proj_head_g = None
        # Get node-level projection head
        if self.graphcl.node_level and self.graphcl.proj_n is not None:
            self.graphcl.proj_head_n = self.graphcl._get_proj(self.graphcl.proj_n,
                                                              self.graphcl.z_n_dim)
        elif self.graphcl.node_level:
            self.graphcl.proj_head_n = lambda x: x
        else:
            self.graphcl.proj_head_n = None

    def forward(self, x_in):
        x = self.encoder(x_in)
        return self.head(x)

    def training_step(self, data, batch_idx):
        # output of each encoder should be Tensor for graph-level embedding
        if isinstance(self.encoder, list):
            assert len(self.encoder) == len(self.graphcl.views_fn)
            encoders = self.encoder
            for enc in encoders:
                enc.train()
        else:
            self.encoder.train()
            encoders = [self.encoder] * len(self.graphcl.views_fn)
        # Put the projection heads in training mode (identity heads are plain
        # lambdas without a train() method, hence the AttributeError guard)
        try:
            if self.graphcl.node_level and self.graphcl.graph_level:
                self.graphcl.proj_head_g.train()
                self.graphcl.proj_head_n.train()
            elif self.graphcl.graph_level:
                self.graphcl.proj_head_g.train()
            else:
                self.graphcl.proj_head_n.train()
        except AttributeError:
            pass
        # Assemble graph views
        if None in self.graphcl.views_fn:
            views = []
            for v_fn in self.graphcl.views_fn:
                # For view fn that returns multiple views
                if v_fn is not None:
                    views += [*v_fn(data)]
            assert len(views) == len(encoders)
        else:
            views = [v_fn(data) for v_fn in self.graphcl.views_fn]
        # Get embeddings per views
        zs_n, zs_g = [], []
        for view, enc in zip(views, encoders):
            # Run encoder
            if self.graphcl.node_level and self.graphcl.graph_level:
                z_g, z_n = self.graphcl._get_embed(enc, view)
                zs_n.append(self.graphcl.proj_head_n(z_n))
                zs_g.append(self.graphcl.proj_head_g(z_g))
            elif self.graphcl.graph_level:
                z_g = self.graphcl._get_embed(enc, view)
                zs_g.append(self.graphcl.proj_head_g(z_g))
            else:
                z_n = self.graphcl._get_embed(enc, view)
                zs_n.append(self.graphcl.proj_head_n(z_n))
        # Get loss
        if self.graphcl.node_level and self.graphcl.graph_level:
            loss = self.graphcl.loss_fn(zs_g, zs_n=zs_n, batch=data.batch,
                                        neg_by_crpt=self.graphcl.neg_by_crpt,
                                        tau=self.graphcl.tau)
        elif self.graphcl.graph_level:
            loss = self.graphcl.loss_fn(zs_g,
                                        neg_by_crpt=self.graphcl.neg_by_crpt,
                                        tau=self.graphcl.tau)
        else:
            loss = self.graphcl.loss_fn(zs_g=None, zs_n=zs_n, batch=data.batch,
                                        neg_by_crpt=self.graphcl.neg_by_crpt,
                                        tau=self.graphcl.tau)
        # Reporting and return
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x = self.encoder(batch)
        y_hat = self.head(x)
        # print(f'[validation] y_hat/y: {y_hat} / {batch.y}')
        if self.binary_classification:
            y_hat = nn.functional.softmax(y_hat, dim=1).argmax(dim=1)
            acc = self.val_acc(y_hat, batch.y)
            self.log('val_acc', acc, prog_bar=True)
            return acc
        else:
            # Reshape the (batch, 1) prediction to the target's shape before MSE
            y = batch.y.float()
            mse = self.val_mse(y_hat.view_as(y), y)
            self.log('val_mse', mse, prog_bar=True)
            return mse

    def test_step(self, batch, batch_idx):
        x = self.encoder(batch)
        y_hat = self.head(x)
        # print(f'[test] y_hat/y: {y_hat} / {batch.y}')
        if self.binary_classification:
            y_hat = nn.functional.softmax(y_hat, dim=1).argmax(dim=1)
            acc = self.test_acc(y_hat, batch.y)
            self.log('test_acc', acc, prog_bar=True)
            return acc
        else:
            # Reshape the (batch, 1) prediction to the target's shape before MSE
            y = batch.y.float()
            mse = self.test_mse(y_hat.view_as(y), y)
            self.log('test_mse', mse, prog_bar=True)
            return mse

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    ####################
    # DATA RELATED HOOKS
    ####################

    def train_dataloader(self):
        if self.binary_classification:
            return torch_geometric.loader.DataLoader(ssl_dataset_bin, batch_size=self.batch_size, shuffle=True)
        else:
            return torch_geometric.loader.DataLoader(ssl_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        if self.binary_classification:
            return torch_geometric.loader.DataLoader(train_dataset_bin, batch_size=self.batch_size)
        else:
            return torch_geometric.loader.DataLoader(train_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        if self.binary_classification:
            return torch_geometric.loader.DataLoader(test_dataset_bin, batch_size=self.batch_size)
        else:
            return torch_geometric.loader.DataLoader(test_dataset, batch_size=self.batch_size)
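
`training_step` delegates the objective to `self.graphcl.loss_fn(zs_g, ..., tau=self.graphcl.tau)`. For GraphCL this is a contrastive loss over the two augmented views: the two embeddings of the same graph form the positive pair, and the other graphs in the batch act as negatives. The NumPy sketch below implements the per-graph loss in the form given in the GraphCL paper, $\ell_n = -\log \frac{\exp(\mathrm{sim}(z_n, z'_n)/\tau)}{\sum_{m \neq n} \exp(\mathrm{sim}(z_n, z'_m)/\tau)}$ with cosine similarity. It is an illustration, not DIG's code; DIG's implementation may differ in details such as normalisation or symmetrisation, and the function name `nt_xent` is hypothetical.

```python
import numpy as np


def nt_xent(z1: np.ndarray, z2: np.ndarray, tau: float = 0.5) -> float:
    """Contrastive loss between two views z1, z2 of shape (N, D).

    Row n of z1 and row n of z2 embed the same graph (positive pair); all
    other rows of z2 are treated as negatives for row n of z1.
    """
    # Cosine similarity matrix: sim[n, m] = cos(z1[n], z2[m]), scaled by 1/tau
    z1 = z1 / np.linalg.norm(z1, axis=1, keepdims=True)
    z2 = z2 / np.linalg.norm(z2, axis=1, keepdims=True)
    sim = np.exp(z1 @ z2.T / tau)
    pos = np.diag(sim)
    # Denominator excludes the positive pair (sum over m != n)
    neg = sim.sum(axis=1) - pos
    return float(np.mean(-np.log(pos / neg)))


rng = np.random.default_rng(0)
z = rng.normal(size=(8, 16))
# Views that agree row by row (small noise) should give a low loss ...
aligned = nt_xent(z, z + 0.01 * rng.normal(size=z.shape), tau=0.2)
# ... while shifting the rows breaks every positive pair and raises it
shifted = nt_xent(z, np.roll(z, 1, axis=0), tau=0.2)
```

Because the positive pair is excluded from the denominator, this formulation can go below zero once positives clearly dominate the negatives. The temperature `tau` controls how sharply the loss concentrates on the hardest negatives.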

In [None]:
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from dig.sslgraph.method.contrastive.views_fn import NodeAttrMask, EdgePerturbation, Sequential

dropE_rate_1 = 0.5
dropE_rate_2 = 0.5
maskN_rate_1 = 0.5
maskN_rate_2 = 0.5
model = SSLTrainer(embed_dim=256, batch_size=128, binary_classification=True)

callbacks = [
    TQDMProgressBar(refresh_rate=20),
    # EarlyStopping(monitor='val_loss', mode='min'),
    # ModelCheckpoint(save_weights_only=True, mode='min', monitor='val_loss'),
]
torch.use_deterministic_algorithms(False)

# NOTE: the DIG library is not designed to have its training automated by PyTorch Lightning
trainer = pl.Trainer(max_epochs=5,
                     gradient_clip_val=1.0,
                     gradient_clip_algorithm='norm',
                     enable_progress_bar=True,
                     accelerator='cpu', # 'gpu' if torch.cuda.is_available() else 'auto',
                     precision='32', # '16-mixed' if torch.cuda.is_available() else '32',
                     log_every_n_steps=8,
                     callbacks=callbacks,
                     logger=CSVLogger(save_dir='logs/'))
trainer.fit(model)
metrics = pd.read_csv(f'{trainer.logger.log_dir}/metrics.csv')
del metrics['step']
metrics.set_index('epoch', inplace=True)
display(metrics.dropna(axis=1, how='all').head())
sns.relplot(data=metrics, kind='line')
plt.grid(alpha=0.7)
plt.show()
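
With Lightning's `CSVLogger`, metrics logged per step in `training_step` and per epoch in `validation_step` typically land on different rows of `metrics.csv`, with NaN in the columns not logged at that step. seaborn's line plot averages duplicate x values, so the plot above already aggregates within each epoch; the sketch below makes that explicit by collapsing the log to one row per epoch. The frame is synthetic (made-up numbers shaped like the logger's output, not results from this notebook).

```python
import numpy as np
import pandas as pd

# Synthetic stand-in for metrics.csv: two training rows and one validation
# row per epoch (values are made up for illustration)
log = pd.DataFrame({
    "epoch":      [0, 0, 0, 1, 1, 1],
    "step":       [7, 15, 15, 23, 31, 31],
    "train_loss": [4.0, 3.0, np.nan, 2.5, 1.5, np.nan],
    "val_acc":    [np.nan, np.nan, 0.60, np.nan, np.nan, 0.70],
})

# One row per epoch: mean() skips NaN, so each column averages only the
# values that were actually logged during that epoch
per_epoch = log.drop(columns="step").groupby("epoch").mean()
print(per_epoch)
```

`per_epoch` can then be passed to `sns.relplot(data=per_epoch, kind='line')` in place of `metrics` to plot one point per epoch without confidence bands.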