# Imports

In [None]:
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
pd.set_option('display.max_columns', 1000, 'display.width', 2000, 'display.max_colwidth', 100)


import torch
import torch.nn as nn
import torch.optim as optim

import optuna

from protacdataset import ProtacLoader
from protacsplitter import PROTACSplitter
from data_curation_augmentation_splitting_functions import (compute_countMorgFP)
from pipeline_functions import (aggregate_output_all_epochs, 
                                get_best_epoch,
                                avg)

In [None]:
# Show which device PyTorch will use: CUDA when a GPU is available, CPU otherwise.
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load data

Loads training and validation data

Specify the Butina clustering cutoff (`butina_cutoff`).

Select the dataset files by setting the number of PROTACs `n` in the training and validation sets.

In [None]:
# PROTAC counts per split that were used in the thesis:
#   Train:      952, 2856, 9520
#   Validation: 238,  714, 2380

butina_cutoff = "0.33"  # options noted by the author: 0.0, 0.33, 0.67

# Filename components of every split; they are assembled into CSV paths below.
path_substr = dict(
    Train=dict(id="train", n="952", butina=butina_cutoff),
    Validation=dict(id="val", n="238", butina=butina_cutoff),
)

# Map each split name to the CSV file of the augmented dataset.
dataset_paths = {
    split: f"../data/augmented/{parts['id']}_{parts['n']}_ButinaClusterCutoff_{parts['butina']}.csv"
    for split, parts in path_substr.items()
}

chosen_dataset_paths = dataset_paths
dataset_names = chosen_dataset_paths.keys()

# Model / featurisation settings (alternatives listed are the author's notes).
model_type = "link_pred"            # node_pred | link_pred | boundary_pred
node_descriptors = "rdkit"          # ones | empty | rdkit
graph_descriptors = [               # e.g. "betweenness", "closeness", "eigenvector", "local_eigenvectors_x"
    "betweenness",
    "closeness",
]

# Dataset handling switches.
model_crossfold = False
save_datasets = False
load_datasets = False
precompute_splits = False


In [None]:
# Create the loader for the selected CSV files and featurisation settings.
ProtacDataset_loader = ProtacLoader(
    dataset_paths=chosen_dataset_paths,
    model_type=model_type,
    node_descriptors=node_descriptors,
    graph_descriptors=graph_descriptors,
    model_crossfold=model_crossfold,
)

# Build the datasets from the raw dataframes unless loading was requested.
# NOTE(review): when `load_datasets` is True this cell assigns nothing, so
# `datasets_dict` must already exist in the kernel - confirm intended workflow.
if not load_datasets:
    datasets_dict = ProtacDataset_loader.initialize_datasets(
        dataframes=ProtacDataset_loader.load_dataframes(),
        num_crossfolds=1,
        precompute_splits=precompute_splits,
    )

# Optuna

In [None]:
def objective(trial, max_epochs=15, model_type=model_type):
    """Optuna objective: train one sampled PROTACSplitter configuration and score it.

    Hyperparameters are sampled from ``trial``, a ``PROTACSplitter`` is built and
    trained on ``datasets_dict["Train"]`` with ``datasets_dict["Validation"]`` as
    validation set, and the validation PROTAC accuracy at the best epoch is
    returned. Reads the notebook-level globals ``datasets_dict`` and
    ``model_crossfold``.

    Args:
        trial: Optuna trial used to sample the hyperparameters.
        max_epochs: Passed to training as both 'epochs' and 'max_epochs'.
        model_type: Model variant handed to ``PROTACSplitter``. The default is
            the value of the notebook-level ``model_type`` when this cell runs.

    Returns:
        ``avg(...)`` of the validation "PROTAC" / "Accuracy" metric at the epoch
        selected by ``get_best_epoch`` (the study maximises this value).
    """

    # --------------------- Training parameters ---------------------

    # Libraries set to None, as they are not important during optimization.
    e3_library = None   # pd.read_csv('../../data/e3_trainval_substructures_without_attachment.csv')["E3 SMILES"].to_list()
    poi_library = None  # pd.read_csv('../../data/poi_trainval_substructures_without_attachment.csv')["POI SMILES"].to_list()

    # Early stopping parameters.
    # Author's note: the model looks back 2*n epochs to decide whether to
    # continue, depending on validation loss; if the median has increased => stop.
    val_early_stopping = True
    median_over_n_val_losses = None
    min_over_n_val_losses = None

    # Accuracy-based abort schedule: entry i holds the threshold checked at
    # epoch i (1-based). `shift` lowers every threshold; vary it by hand for
    # different models (author's note: set to 1 to disable this feature).
    # Author's note: 0.42 at epoch 1 is a good cutoff for node predictions.
    shift = 0.6
    accuracy_floor_per_epoch = [0.5, 0.65, 0.7, 0.75, 0.8, 0.825, 0.85, 0.875, 0.9]
    stop_if_val_acc_below_x_at_n_list = [
        {'val_frac_criteria': floor - shift, 'n_epochs': epoch}
        for epoch, floor in enumerate(accuracy_floor_per_epoch, start=1)
    ]

    # Model training parameters
    compute_pretrained_values = False
    compute_rand_accuracy = False

    # --------------------- Hyperparameters ---------------------
    # The sampling order and parameter names are kept stable so that pickled
    # studies remain compatible.

    num_layers = 9  # trial.suggest_int('num_layers', 3, 9)
    layer_sizes = trial.suggest_categorical('layer_dims', [100, 250, 500, 750, 1000])  # Allowed layer sizes
    layer_dims = [layer_sizes] * num_layers  # every layer shares the sampled width
    dropout_rate = trial.suggest_float("dropout_rate", 0, 0.5)

    gnn_layer_type = trial.suggest_categorical('gnn_layer_type', ['GraphConv', 'GCNConv', 'SAGEConv', 'GATConv', 'TransformerConv'])

    use_batch_normalization = trial.suggest_categorical('use_batch_normalization', [True, False])
    use_skip_connections = trial.suggest_categorical('use_skip_connections', [True, False])
    use_graph_normalization = trial.suggest_categorical('use_graph_normalization', [True, False])
    # Edge information is only enabled for the TransformerConv layer type.
    use_edge_information = gnn_layer_type == 'TransformerConv'
    output_depth = trial.suggest_int('output_depth', 2, 4)

    lr = trial.suggest_float("lr", 1e-7, 1e-2, log=True)
    batch_size = trial.suggest_categorical("batch_size", [1, 8, 64])

    weight_class_loss = 0.1

    # --------------------- Model initialization and training ---------------------

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_key = "Train"
    val_key = "Validation"
    node_feature_dim = datasets_dict[train_key].num_node_features

    output_layer_config = [{
            'type': "linear_symmetric",  # Indicates symmetric processing but not global
            'in_features': -1,  # Input features to the first output layer
            'intermediate_features': layer_dims[0],  # Size for the middle layer in the sequence
            'out_features': 3,  # Size for the final classification/prediction layer
            'depth': output_depth,  # Total number of layers in this sequence
        }]

    boundary_layer_config = [{
            'type': "linear_symmetric",  # Indicates the use of symmetric and global adjustments
            'in_features': -1,  # Input features to the first boundary layer
            'intermediate_features': layer_dims[0],  # Intermediate layer size (if applicable)
            'out_features': 2,  # Final output size for boundary prediction
            'depth': output_depth,  # Total number of layers in this sequence
        }]

    init_params = {
        'model_type': model_type,
        'node_feature_dim': node_feature_dim,
        'device': device,
        'num_predicted_classes': 3,

        'num_layers': num_layers,
        'layer_dims': layer_dims,
        'final_layer_dim': layer_dims[-1] if model_type == "link_pred" else 3,
        'gnn_layer_type': gnn_layer_type,

        # TransformerConv parameters. Currently set to default values.
        'TransformerConvHeads': 1,
        'TransformerConvBeta': False,
        'TransformerConvDropout': 0,

        # Possible to add the L1 and L2 norm of model parameters to the loss
        # (normalized by the number of model parameters).
        'regularization_types': [],  # ['L1_model_params', 'L2_model_params']
        'dropout_rate': dropout_rate,

        'use_skip_connections': use_skip_connections,
        'use_graph_normalization': use_graph_normalization,
        'use_edge_information': use_edge_information,
        'use_batch_normalization': use_batch_normalization,

        'output_layer_config': output_layer_config,
        'boundary_layer_config': boundary_layer_config
    }

    model = PROTACSplitter(init_params=init_params)

    # --------------------- Assemble training parameters ---------------------

    train_params = {
        'datasets_dict': {train_key: datasets_dict[train_key], val_key: datasets_dict[val_key]},
        'optimizer': optim.Adam,
        'lr': lr,
        'batch_size': batch_size,
        'criterion': nn.CrossEntropyLoss(reduction='sum'),
        'epochs': max_epochs,
        'max_epochs': max_epochs,
        'val_early_stopping': val_early_stopping,
        'median_over_n_val_losses': median_over_n_val_losses,
        'min_over_n_val_losses': min_over_n_val_losses,
        'stop_if_val_acc_below_x_at_n_list': stop_if_val_acc_below_x_at_n_list,
        'print_every_n_epochs': 1,
        'compute_pretrained_values': compute_pretrained_values,
        'compute_rand_accuracy': compute_rand_accuracy,
        'e3_library': e3_library,
        'poi_library': poi_library,
        'fp_function': compute_countMorgFP,
        'param_to_opt': ['accuracy'],  # only compute the accuracy when optimizing hyperparameters
        'weight_class_loss': weight_class_loss  # weighting between row-loss (bondclass-loss) and column-loss.
    }

    # --------------------- Train model and get output ---------------------

    model.to(device=device)
    output = model.train_model(train_params=train_params)

    # Pick the best epoch from the aggregated output, then read the metric
    # that the study optimises (validation accuracy) at that epoch.
    aggregated_output = aggregate_output_all_epochs(output=output)
    best_epoch = get_best_epoch(output=aggregated_output,
                                dataset=model.val_set_name,
                                accuracy_origin="model",
                                structure="PROTAC",
                                metric_type="Accuracy",
                                model_crossfold=model_crossfold)

    best_val_acc = avg(output['metrics'][model.val_set_name]["model"]["PROTAC"]["Accuracy"][best_epoch])

    return best_val_acc


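The study is pickled into a folder named "pickle_optuna" after every trial. Create this folder before running the cell below if it does not already exist.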

In [None]:
# Run the hyperparameter search, checkpointing the study to disk after every
# trial so that an interrupted run can be resumed by re-running this cell.
existing_study_paths = []

n_trials = 1000
pickle_str = f'{model_type}_study1.pickle'  # set to '' to run without checkpointing
study_path = f'pickle_optuna/{pickle_str}'

if pickle_str:
    # Make sure the checkpoint folder exists *before* spending time on a trial.
    os.makedirs(os.path.dirname(study_path), exist_ok=True)

    if os.path.exists(study_path):
        # Resume a previous run.
        # NOTE: pickle.load can execute arbitrary code - only open trusted files.
        with open(study_path, 'rb') as file:
            study = pickle.load(file)
    else:
        study = optuna.create_study(direction='maximize')  # Use 'minimize' if you are returning a loss

    tmp_study_path = f'{study_path}.tmp'
    for _ in range(n_trials):
        study.optimize(objective, n_trials=1)
        # Write to a temporary file first and swap it in, so an interrupt
        # during the dump cannot corrupt the existing checkpoint.
        with open(tmp_study_path, 'wb') as file:
            pickle.dump(study, file)
        os.replace(tmp_study_path, study_path)
else:
    study = optuna.create_study(direction='maximize')  # Use 'minimize' if you are returning a loss
    study.optimize(objective, n_trials=n_trials)


print("Number of finished trials: ", len(study.trials))
print("Best trial:")
trial = study.best_trial

print("  Value: ", trial.value)
print("  Params: ")
for key, value in trial.params.items():
    print("    {}: {}".format(key, value))

In [None]:
# Display the parameters of the best trials.

load_pickled_study = True  # when False, this whole cell does nothing

if load_pickled_study:

    # Load the study from disk when the pickle exists; otherwise the `study`
    # object that is already present in the kernel is used.
    pickle_str = 'your_study_name.pickle' # ''
    study_path = 'pickle_optuna/' + pickle_str
    if os.path.exists(study_path):
        with open(study_path, 'rb') as file:
            study = pickle.load(file)

    # Tabulate all trials and show the 20 with the highest objective value.
    trials_df = study.trials_dataframe()
    trials_df_sorted = trials_df.sort_values("value", ascending=False)
    print("Best trials and their parameters:")
    print(len(trials_df))
    display(trials_df_sorted.head(20))
        

In [None]:
# Display Optuna's built-in visualizations of the study plus a custom scatter
# plot of validation accuracy vs. learning rate, colored by batch size.

load_pickled_study = False  # when False, this whole cell does nothing

if load_pickled_study:

    # Load the study from disk when the pickle exists; otherwise the `study`
    # object that is already present in the kernel is used.
    # NOTE: pickle.load can execute arbitrary code - only open trusted files.
    pickle_str = 'your_study_name.pickle' # ''
    study_path = f'pickle_optuna/{pickle_str}'
    if os.path.exists(study_path):
        with open(study_path, 'rb') as file:
            study = pickle.load(file)

    # Table of the ten best trials.
    trials_df = study.trials_dataframe()
    trials_df_sorted = trials_df.sort_values(by=["value"], ascending=False)
    print("Best trials and their parameters:")
    display(trials_df_sorted.head(10))

    # Collect every parameter name that occurs in any trial, in first-seen order.
    best_params = study.best_params
    study_parameters = []
    for s in study.get_trials():
        for k in s.params:
            if k not in study_parameters:
                study_parameters.append(k)
    # NOTE(review): this 'lr'-free list is not passed to any plot below -
    # confirm whether one of the plots was meant to use it.
    study_parameters_2 = [p for p in study_parameters if p != 'lr']

    # Optuna's built-in plots.
    display(optuna.visualization.plot_parallel_coordinate(study))
    display(optuna.visualization.plot_param_importances(study, params=study_parameters))
    display(optuna.visualization.plot_slice(study, params=study_parameters))
    display(optuna.visualization.plot_contour(study))

    # ---- accuracy vs. lr, colored by batch size ----
    optuna_study_param = 'lr'
    optuna_study_param_to_color_by = 'batch_size'
    optuna_study_val_accs = []
    optuna_study_param_vals = []
    optuna_study_param_vals_to_color_by = []
    for s in study.get_trials():
        # Only completed trials carry an objective value; failed or pruned
        # trials may also lack some of the parameters.
        if s.state != optuna.trial.TrialState.COMPLETE:
            continue
        if optuna_study_param not in s.params or optuna_study_param_to_color_by not in s.params:
            continue
        optuna_study_val_accs.append(s.value)  # scalar objective value
        optuna_study_param_vals.append(s.params[optuna_study_param])
        optuna_study_param_vals_to_color_by.append(s.params[optuna_study_param_to_color_by])

    optuna_study_param_vals_to_color_by_set = sorted(set(optuna_study_param_vals_to_color_by))
    colors = plt.cm.jet(np.linspace(0, 1, len(optuna_study_param_vals_to_color_by_set)))

    fig, ax = plt.subplots(figsize=(10, 6))
    # One scatter call per batch size gives exactly one legend entry per group.
    for group_color, group_value in zip(colors, optuna_study_param_vals_to_color_by_set):
        group_points = [
            (param_val, val_acc)
            for param_val, val_acc, color_by_val in zip(optuna_study_param_vals,
                                                        optuna_study_val_accs,
                                                        optuna_study_param_vals_to_color_by)
            if color_by_val == group_value
        ]
        ax.scatter([p[0] for p in group_points],
                   [p[1] for p in group_points],
                   color=tuple(group_color),
                   label=f'Batch Size: {group_value}')

    ax.set_xlabel('Learning Rate (lr)')
    ax.set_ylabel('Validation Accuracy')
    ax.set_title('Validation Accuracy vs Learning Rate for Different Batch Sizes')
    ax.set_xscale('log')  # Since lr values are usually in a log scale
    if optuna_study_param_vals_to_color_by_set:
        ax.legend(loc='upper left')
    ax.grid(True)
    plt.show()

