# __PROTACSplitter__


# Import packages

In [None]:
from protacdataset import ProtacLoader
from protacsplitter import PROTACSplitter

import torch
from torch_geometric.loader import DataLoader

import os
import random

import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from datetime import datetime
import copy
import torch.nn as nn
import torch.optim as optim
from data_curation_augmentation_splitting_functions import (compute_countMorgFP,
                                                            make_graph_with_pos)
from pipeline_functions import (aggregate_output_all_epochs,
                                aggregate_metrics_at_epoch, 
                                get_best_epoch,
                                avg,
                                aggregate_metrics_for_crossfold,
                                crossfolds_avg_std_at_epoch,
                                geo_mean)
import numpy as np
from rdkit import Chem
import networkx as nx
import statistics


import matplotlib.ticker as mticker

# Widen pandas' console rendering so wide tables are printed in full.
pd.set_option('display.max_columns', 1000)
pd.set_option('display.width', 2000)
pd.set_option('display.max_colwidth', 100)




# Load data

In [None]:
# Dataset sizes noted by the author:
#   Train:      952, 2856, 9520
#   Validation: 238,  714, 2380

split_idx = 0           # index of the test split whose files are loaded
butina_cutoff = "0.33"  # Butina cluster cutoff exactly as written in the file names

# File-name ingredients of the train/validation sets, keyed by display name.
path_substr = {
    "Train": {'id': "train", 'n': "952", 'butina': butina_cutoff},
    "Validation": {'id': "val", 'n': "238", 'butina': butina_cutoff},
}

# One number per test-set id and split index (0, 1, 2).
# NOTE(review): presumably the size of each test set; not referenced by the
# path construction below - confirm.
test_paths_substr_split_to_n = {
    "test_protac": {0: 108, 1: 108, 2: 108},
    "test_poi": {0: 70, 1: 70, 2: 60},
    "test_linker": {0: 135, 1: 135, 2: 135},
    "test_e3": {0: 75, 1: 50, 2: 50},
    "test_poilinker": {0: 135, 1: 135, 2: 135},
    "test_poie3": {0: 70, 1: 70, 2: 60},
    "test_e3linker": {0: 135, 1: 135, 2: 135},
}

# File-name ingredients of the seven test sets, keyed by display name.
test_paths_substr = {
    display_name: {'id': file_id, 'butina': butina_cutoff}
    for display_name, file_id in (
        ("Test PROTAC", "test_protac"),
        ("Test Warhead", "test_poi"),
        ("Test Linker", "test_linker"),
        ("Test E3", "test_e3"),
        ("Test W-Linker", "test_poilinker"),
        ("Test W-E3", "test_poie3"),
        ("Test E3Linker", "test_e3linker"),
    )
}

# CSV file-name patterns; the placeholders are filled from the dicts above.
_TRAIN_VAL_PATH_TEMPLATE = "../data/augmented/{id}_{n}_ButinaClusterCutoff_{butina}.csv"
_TEST_PATH_TEMPLATE = "../data/augmented/{id}_split{split}_ButinaClusterCutoff_{butina}.csv"

dataset_test_paths = {
    display_name: _TEST_PATH_TEMPLATE.format(split=split_idx, **parts)
    for display_name, parts in test_paths_substr.items()
}
# Train/validation paths first, test paths after (insertion order is kept).
dataset_paths = {
    **{
        display_name: _TRAIN_VAL_PATH_TEMPLATE.format(**parts)
        for display_name, parts in path_substr.items()
    },
    **dataset_test_paths,
}

chosen_dataset_paths = dataset_paths
dataset_names = chosen_dataset_paths.keys()

# Model / data-loading switches.
model_type = "link_pred"                          # node_pred, link_pred or boundary_pred
node_descriptors = "rdkit"                        # ones, empty or rdkit
graph_descriptors = ["betweenness", "closeness"]  # "betweenness", "closeness", "local_eigenvectors_x"
model_crossfold = False
save_datasets = False
load_datasets = False
precompute_splits = False
precompute_splits=False


In [None]:
# Five folds when cross-folding is enabled, a single fold otherwise.
num_crossfolds = 5 if model_crossfold else 1

# Loader configured with the chosen CSV paths and descriptor settings.
ProtacDataset_loader = ProtacLoader(
    dataset_paths=chosen_dataset_paths,
    model_type=model_type,
    node_descriptors=node_descriptors,
    graph_descriptors=graph_descriptors,
    model_crossfold=model_crossfold,
)

# Build the datasets from the loaded dataframes unless they are going to be
# read back from a previously saved file instead.
if not load_datasets:
    datasets_dict = ProtacDataset_loader.initialize_datasets(
        dataframes=ProtacDataset_loader.load_dataframes(),
        num_crossfolds=num_crossfolds,
        precompute_splits=precompute_splits,
    )

In [None]:
# Minute-resolution timestamp used to tag files written by this notebook.
now = datetime.now()
now_str = now.strftime("%Y_%m_%d_%H%M")

# Optionally persist the freshly built datasets ...
if save_datasets:
    save_path = f'data_{model_type}_{graph_descriptors}_split{split_idx}_{now_str}.pickle'
    ProtacDataset_loader.save_datasets_to_file(dataset_path=save_path)

# ... or read datasets back from an earlier run.
# NOTE(review): the file name read here does not follow the pattern written
# above (no "data_" prefix, hard-coded "link_pred", no timestamp) - confirm
# that a file with this name exists before enabling load_datasets.
if load_datasets:
    load_path = f"link_pred_{graph_descriptors}_split{split_idx}.pickle"
    datasets_dict = ProtacDataset_loader.load_datasets_from_file(dataset_path=load_path)

# Define parameters

In [None]:
# --------------------- Hyperparameters ---------------------

num_layers = 9
layer_sizes = 250
# Every GNN layer gets the same width.
# Alternative used during tuning:
#   [trial.suggest_categorical(f'layer_dim_{i}', layer_sizes) for i in range(num_layers)]
layer_dims = [layer_sizes] * num_layers
dropout_rate = 5.9e-3

gnn_layer_type = 'TransformerConv'

use_batch_normalization = False
use_skip_connections = False
use_graph_normalization = True
# Edge information is only switched on for the TransformerConv layer type.
use_edge_information = gnn_layer_type == 'TransformerConv'
output_depth = 3

lr = 8e-5
batch_size = 1

weight_class_loss = 0.1


# --------------------- Training parameters ---------------------

epochs = 2
max_epochs = 1000

# SMILES libraries of E3 and POI substructures, read from CSV and handed to
# the training routine through train_params.
# NOTE(review): these files are read from '../../data/' whereas the dataset
# CSVs are read from '../data/' - confirm the relative paths.
e3_library = pd.read_csv('../../data/e3_trainval_substructures_without_attachment.csv')["E3 SMILES"].to_list()
poi_library = pd.read_csv('../../data/poi_trainval_substructures_without_attachment.csv')["POI SMILES"].to_list()

# Early stopping parameters.
# Author's note: the model looks back 2*n epochs to decide whether to continue
# or stop, depending on validation loss; if the median has increased => stop.
val_early_stopping = True
median_over_n_val_losses = None
min_over_n_val_losses = 10

# Vary `shift` by hand to set different early-stopping criteria per model.
# Author's note: set to 1 to disable this feature.
shift = 0.6
# Abort criteria "below x validation accuracy at n epochs": one entry per
# epoch 1..9, each threshold lowered by `shift`.
# (Author's note: 0.42 at 1 epoch is a good cutoff for node predictions.)
stop_if_val_acc_below_x_at_n_list = [
    {'val_frac_criteria': threshold - shift, 'n_epochs': n_epochs}
    for n_epochs, threshold in enumerate(
        (0.5, 0.65, 0.7, 0.75, 0.8, 0.825, 0.85, 0.875, 0.9), start=1
    )
]

# Model training parameters
compute_pretrained_values = False
compute_rand_accuracy = False

# Metrics to optimise: any of "accuracy", "recall", "precision", "f1".
# If set to [None], all evaluation metrics will be calculated.
# Fixed: a trailing comma previously turned this list into a one-element
# tuple wrapping the list, i.e. (["accuracy"],).
param_to_opt = ["accuracy"]

# ------------

save_model_and_output = False

# Datasets only need to be (re)computed by the training cells when no
# datasets_dict exists yet in this session.
compute_datasets_dict = 'datasets_dict' not in locals()

states = []
outputs = []

# Train model (validate, no testsets)

In [None]:
# Train a single model on the "Train" set (validating on "Validation"),
# report timing, and optionally pickle the trained model with its output.
import pickle

if compute_datasets_dict:
    # Rebuild the train/validation paths with the same file-name pattern used
    # when the paths were first built (underscore before the Butina cutoff),
    # and hand the rebuilt paths to the loader. Previously the rebuilt dict
    # was discarded because the loader was given the stale
    # chosen_dataset_paths.
    dataset_paths = {
        dataset_name: f"../data/augmented/{substr_dict['id']}_{substr_dict['n']}_ButinaClusterCutoff_{substr_dict['butina']}.csv"
        for dataset_name, substr_dict in path_substr.items()
    }
    dataset_paths.update(dataset_test_paths)
    chosen_dataset_paths = dataset_paths

    ProtacDataset_loader = ProtacLoader(
        dataset_paths=chosen_dataset_paths,
        model_type=model_type,
        node_descriptors=node_descriptors,
        graph_descriptors=graph_descriptors,
        model_crossfold=model_crossfold,
    )
    datasets_dict = ProtacDataset_loader.initialize_datasets(
        dataframes=ProtacDataset_loader.load_dataframes(),
        num_crossfolds=1,
        precompute_splits=precompute_splits,
    )
    now = datetime.now()
    now_str = now.strftime("%Y_%m_%d_%H%M")


# --------------------- Model initialization and training ---------------------

# Drop a model left over from an earlier run of a training cell.
if 'model' in locals():
    del model
# Author's note: on GPU the epoch time was the same (14 s) for small and large
# networks, whereas on CPU it was 7 s vs 1 min 30 s.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_key = "Train"
val_key = "Validation"
node_feature_dim = datasets_dict[train_key].num_node_features

output_layer_config = [{
    'type': "linear_symmetric",  # Indicates symmetric processing but not global
    'in_features': -1,  # Input features to the first output layer
    'intermediate_features': layer_dims[0],  # Size for the middle layer in the sequence
    'out_features': 3,  # Size for the final classification/prediction layer
    'depth': output_depth,  # Total number of layers in this sequence
}]

boundary_layer_config = [{
    'type': "linear_symmetric",  # Indicates the use of symmetric and global adjustments
    'in_features': -1,  # Input features to the first boundary layer
    'intermediate_features': layer_dims[0],  # Intermediate layer size (if applicable)
    'out_features': 2,  # Final output size for boundary prediction
    'depth': output_depth,  # Total number of layers in this sequence
}]

init_params = {
    'model_type': model_type,
    'node_feature_dim': node_feature_dim,
    'device': device,
    'num_predicted_classes': 3,

    'num_layers': num_layers,
    'layer_dims': layer_dims,
    'final_layer_dim': layer_dims[-1] if model_type == "link_pred" else 3,
    'gnn_layer_type': gnn_layer_type,

    # TransformerConv parameters. Currently set to default values.
    'TransformerConvHeads': 1,
    'TransformerConvBeta': False,
    'TransformerConvDropout': 0,

    # Possible to add the L1 and L2 norm of the model parameters to the loss
    # (normalized by the number of model parameters beforehand).
    'regularization_types': [],  # ['L1_model_params', 'L2_model_params']
    'dropout_rate': dropout_rate,

    'use_skip_connections': use_skip_connections,
    'use_graph_normalization': use_graph_normalization,
    'use_edge_information': use_edge_information,
    'use_batch_normalization': use_batch_normalization,

    'output_layer_config': output_layer_config,
    'boundary_layer_config': boundary_layer_config,
}

train_params = {
    'datasets_dict': datasets_dict,
    'optimizer': optim.Adam,
    'lr': lr,
    'batch_size': batch_size,
    'criterion': nn.CrossEntropyLoss(reduction='sum'),
    'epochs': epochs,
    'max_epochs': max_epochs,
    'val_early_stopping': val_early_stopping,
    'median_over_n_val_losses': median_over_n_val_losses,
    'min_over_n_val_losses': min_over_n_val_losses,
    'stop_if_val_acc_below_x_at_n_list': stop_if_val_acc_below_x_at_n_list,
    'print_every_n_epochs': 1,
    'compute_pretrained_values': compute_pretrained_values,
    'compute_rand_accuracy': compute_rand_accuracy,
    'e3_library': e3_library,
    'poi_library': poi_library,
    'fp_function': compute_countMorgFP,
    # accuracy, recall, precision, f1; if set to [None], all evaluation
    # metrics will be calculated.
    'param_to_opt': param_to_opt,
    # Weighting between row-loss (bondclass-loss) and column-loss.
    'weight_class_loss': weight_class_loss,
}

# Batch normalization is forced on whenever more than one graph is batched.
if train_params['batch_size'] > 1:
    init_params['use_batch_normalization'] = True


def create_model(init_params):
    """Build a PROTACSplitter from init_params and move it to its device."""
    model = PROTACSplitter(init_params=init_params)
    model.to(init_params['device'])
    return model


# The model is built once, after init_params is final (a throw-away instance
# used to be constructed before the batch-normalization adjustment).
start_time = datetime.now()
model = create_model(init_params)
output = model.train_model(train_params=train_params)
outputs.append(output)
finishing_time = datetime.now()
timedelta_to_train_model = finishing_time - start_time
# total_seconds() covers the full duration; .seconds would drop whole days.
minutes_to_train_model = timedelta_to_train_model.total_seconds() / (60 * num_crossfolds)
epochs_timed = model.epoch + int(train_params['compute_pretrained_values'])
# No epoch completed => there is no meaningful per-epoch time.
minutes_per_epoch = minutes_to_train_model / epochs_timed if epochs_timed else float('nan')

print(f"Training started at {start_time} and finished at {finishing_time}")
print(f"Minutes per epoch: {minutes_per_epoch}")

# New max-epoch: later models are not trained for more epochs than this one
# stopped at, since those extra epochs would be discarded anyway.
epochs = model.epoch

now = datetime.now()
now_str = now.strftime("%Y_%m_%d_%H%M")
output_comment = f"{model_type}"

if save_model_and_output:
    data = {'model': model, 'output': output}
    with open(f"../data/model_outputs/{model.model_type}_{graph_descriptors}_{output_comment}_{now_str}", 'wb') as file:
        pickle.dump(data, file)


# Train model and test on the 3 splits

In [None]:

# Train one model per test split (0, 1, 2), collect each run's output in
# `outputs`, and optionally pickle every trained model with its output.
import pickle


def create_model(init_params):
    """Build a PROTACSplitter from init_params and move it to its device."""
    model = PROTACSplitter(init_params=init_params)
    model.to(init_params['device'])
    return model


# Author's note: on GPU the epoch time was the same (14 s) for small and large
# networks, whereas on CPU it was 7 s vs 1 min 30 s.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_key = "Train"
val_key = "Validation"

output_layer_config = [{
    'type': "linear_symmetric",  # Indicates symmetric processing but not global
    'in_features': -1,  # Input features to the first output layer
    'intermediate_features': layer_dims[0],  # Size for the middle layer in the sequence
    'out_features': 3,  # Size for the final classification/prediction layer
    'depth': output_depth,  # Total number of layers in this sequence
}]

boundary_layer_config = [{
    'type': "linear_symmetric",  # Indicates the use of symmetric and global adjustments
    'in_features': -1,  # Input features to the first boundary layer
    'intermediate_features': layer_dims[0],  # Intermediate layer size (if applicable)
    'out_features': 2,  # Final output size for boundary prediction
    'depth': output_depth,  # Total number of layers in this sequence
}]

for split_idx in [0, 1, 2]:
    # NOTE(review): when compute_datasets_dict is False the datasets already in
    # memory are reused for every split, so all three runs see the same test
    # sets - confirm that this is intended.
    if compute_datasets_dict:
        # Rebuild the paths for the current split with the same file-name
        # pattern used when the paths were first built (underscore before the
        # Butina cutoff), and hand them to the loader. Previously the loader
        # received the stale chosen_dataset_paths, so every iteration loaded
        # the files of the split those paths were originally built for.
        dataset_paths = {
            dataset_name: f"../data/augmented/{substr_dict['id']}_{substr_dict['n']}_ButinaClusterCutoff_{substr_dict['butina']}.csv"
            for dataset_name, substr_dict in path_substr.items()
        }
        dataset_test_paths = {
            dataset_name: f"../data/augmented/{substr_dict['id']}_split{split_idx}_ButinaClusterCutoff_{substr_dict['butina']}.csv"
            for dataset_name, substr_dict in test_paths_substr.items()
        }
        dataset_paths.update(dataset_test_paths)
        chosen_dataset_paths = dataset_paths

        ProtacDataset_loader = ProtacLoader(
            dataset_paths=chosen_dataset_paths,
            model_type=model_type,
            node_descriptors=node_descriptors,
            graph_descriptors=graph_descriptors,
            model_crossfold=model_crossfold,
        )
        datasets_dict = ProtacDataset_loader.initialize_datasets(
            dataframes=ProtacDataset_loader.load_dataframes(),
            num_crossfolds=1,
            precompute_splits=precompute_splits,
        )
        now = datetime.now()
        now_str = now.strftime("%Y_%m_%d_%H%M")
        if save_datasets:
            ProtacDataset_loader.save_datasets_to_file(dataset_path=f'data_{model_type}_{graph_descriptors}_split{split_idx}_{now_str}.pickle')

    # --------------------- Model initialization and training ---------------------

    # Drop the model of the previous split (or of an earlier cell run).
    if 'model' in locals():
        del model

    node_feature_dim = datasets_dict[train_key].num_node_features

    init_params = {
        'model_type': model_type,
        'node_feature_dim': node_feature_dim,
        'device': device,
        'num_predicted_classes': 3,

        'num_layers': num_layers,
        'layer_dims': layer_dims,
        'final_layer_dim': layer_dims[-1] if model_type == "link_pred" else 3,
        'gnn_layer_type': gnn_layer_type,

        # TransformerConv parameters. Currently set to default values.
        'TransformerConvHeads': 1,
        'TransformerConvBeta': False,
        'TransformerConvDropout': 0,

        # Possible to add the L1 and L2 norm of the model parameters to the
        # loss (normalized by the number of model parameters beforehand).
        'regularization_types': [],  # ['L1_model_params', 'L2_model_params']
        'dropout_rate': dropout_rate,

        'use_skip_connections': use_skip_connections,
        'use_graph_normalization': use_graph_normalization,
        'use_edge_information': use_edge_information,
        'use_batch_normalization': use_batch_normalization,

        'output_layer_config': output_layer_config,
        'boundary_layer_config': boundary_layer_config,
    }

    train_params = {
        'datasets_dict': datasets_dict,
        'optimizer': optim.Adam,
        'lr': lr,
        'batch_size': batch_size,
        'criterion': nn.CrossEntropyLoss(reduction='sum'),
        # `epochs` is lowered after every split (see the end of the loop body).
        'epochs': epochs,
        'max_epochs': max_epochs,
        'val_early_stopping': val_early_stopping,
        'median_over_n_val_losses': median_over_n_val_losses,
        'min_over_n_val_losses': min_over_n_val_losses,
        'stop_if_val_acc_below_x_at_n_list': stop_if_val_acc_below_x_at_n_list,
        'print_every_n_epochs': 1,
        'compute_pretrained_values': compute_pretrained_values,
        'compute_rand_accuracy': compute_rand_accuracy,
        'e3_library': e3_library,
        'poi_library': poi_library,
        'fp_function': compute_countMorgFP,
        'param_to_opt': param_to_opt,
        # Weighting between row-loss (bondclass-loss) and column-loss.
        'weight_class_loss': weight_class_loss,
    }

    # Batch normalization is forced on whenever more than one graph is batched.
    if train_params['batch_size'] > 1:
        init_params['use_batch_normalization'] = True

    # The model is built once, after init_params is final (a throw-away
    # instance used to be constructed before the adjustment above).
    start_time = datetime.now()
    model = create_model(init_params)
    output = model.train_model(train_params=train_params)
    outputs.append(output)
    finishing_time = datetime.now()
    timedelta_to_train_model = finishing_time - start_time
    # total_seconds() covers the full duration; .seconds would drop whole days.
    minutes_to_train_model = timedelta_to_train_model.total_seconds() / (60 * num_crossfolds)
    epochs_timed = model.epoch + int(train_params['compute_pretrained_values'])
    # No epoch completed => there is no meaningful per-epoch time.
    minutes_per_epoch = minutes_to_train_model / epochs_timed if epochs_timed else float('nan')

    print(f"Training started at {start_time} and finished at {finishing_time}")
    print(f"Minutes per epoch: {minutes_per_epoch}")

    # New max-epoch: later models are not trained for more epochs than this
    # one stopped at, since those extra epochs would be discarded anyway.
    epochs = model.epoch

    now = datetime.now()
    now_str = now.strftime("%Y_%m_%d_%H%M")
    output_comment = f"{model_type}_split{split_idx}"

    if save_model_and_output:
        data = {'model': model, 'output': output}
        with open(f"../data/model_outputs/{model.model_type}_{graph_descriptors}_{output_comment}_{now_str}", 'wb') as file:
            pickle.dump(data, file)

# From here on the three runs are analysed together. The flag is flipped after
# the loop so that every split configures its loader with the same
# model_crossfold value (it used to change after the first iteration).
model_crossfold = True

In [None]:
# Reload the pickled outputs of the per-split runs and keep the model of the
# last one.
if save_model_and_output and model_crossfold:  # load the saved 3 splits
    import pickle
    import gc

    identifying_comment = "besthypparams_woLocEigen"
    # The loop below iterates a {split_idx: file_name} mapping.
    # NOTE(review): only a single name is given here; list the three per-split
    # file names to load all splits.
    chosen_experiment = "link_pred_bestparams_woLocEigen_allmetrics"
    if isinstance(chosen_experiment, str):
        # A bare string has no .items(); treat it as a one-entry mapping.
        chosen_experiment = {0: chosen_experiment}

    outputs = []
    studies = {}
    model = None
    experiment_items = list(chosen_experiment.items())
    for position, (split_idx, study_path) in enumerate(experiment_items):
        # NOTE(review): the training cells pickle to '../data/model_outputs/'
        # but files are read from '../../data/model_outputs/' here - confirm.
        p = f"../../data/model_outputs/{study_path}"
        torch.cuda.empty_cache()
        with open(p, 'rb') as file:
            # pickle can execute arbitrary code on load: only open files
            # written by this notebook.
            x = pickle.load(file)
        outputs.append(x["output"])
        # Keep the model of the last entry only (whatever its key is).
        if position == len(experiment_items) - 1:
            model = x["model"]
        # Author's note: sometimes all 3 models and outputs are too big for the
        # GPU memory => load and clean iteratively. gc.collect is now actually
        # called (it was only referenced before).
        del x
        gc.collect()
        print(split_idx)

# Best (average) epoch metrics

# Aggregate metrics

In [None]:
# Free-text tag; in this notebook it only appears in the commented-out
# figure file names of the plotting cells.
identifying_comment = ""

In [None]:
# A single run is aggregated over its epochs; several runs are aggregated
# across runs into average, standard deviation and concatenated measures.
if not model_crossfold:
    aggregated_output = aggregate_output_all_epochs(output=output)
else:
    measures_avg, measures_std, measures_concat = aggregate_metrics_for_crossfold(
        outputs=outputs,
        model=model,
    )

## Get best average epoch

In [None]:
# Pick the epoch with the best model accuracy for whole PROTACs on the
# validation set, from the cross-run average or the single aggregated run.
best_epoch_source = measures_avg if model_crossfold else aggregated_output
best_epoch = get_best_epoch(
    output=best_epoch_source,
    dataset=model.val_set_name,
    accuracy_origin="model",
    structure="PROTAC",
    metric_type="Accuracy",
    model_crossfold=model_crossfold,
)

print(f'Best epoch: {best_epoch}')

## Save model params

In [None]:
# Restore the weights stored for the best epoch, then write the model to disk.
# NOTE(review): states are indexed with best_epoch - 1 while the metric dicts
# in the plotting cells are indexed with best_epoch - confirm the offset.
# NOTE(review): torch.save receives the model object itself, not only its
# state dict, despite the "_params" file name.
best_epoch_state = model.states[best_epoch - 1]
model.load_state_dict(best_epoch_state)
torch.save(model, 'protacsplitter_params.pt')

## Display evaluation metrics table

In [None]:
# Print the evaluation metrics at the best epoch: average/std over the runs
# when cross-folding, otherwise one row per dataset of the single run.
if not model_crossfold:
    aggregated_metrics = aggregate_metrics_at_epoch(output=output, epoch=best_epoch)
    row_names = list(output['metrics'].keys())
    aggregated_metrics_df = pd.DataFrame(aggregated_metrics)
    aggregated_metrics_df.index = row_names
    print(aggregated_metrics_df.round(1))
else:
    avg_std_metrics = crossfolds_avg_std_at_epoch(
        outputs=outputs,
        epoch=best_epoch,
        dataset_names=list(outputs[0]["metrics"].keys()),
        print_df=True,
    )

# Plot distributions at best epoch

## Plot PROTAC vs LIGANDS+LINKER accuracy distribution
Investigate the accuracy of the model type prediction and of the split prediction.

In [None]:
# One panel per dataset: overlaid histograms of the per-sample accuracy for
# the whole PROTAC (red) and for ligands + linker (blue) at the best epoch.
import matplotlib.pyplot as plt
import numpy as np

acc_to_plot = measures_concat if model_crossfold else output

num_datasets = len(acc_to_plot['metrics'])

# squeeze=False always returns a 2-D array of axes, so a single dataset no
# longer yields a bare Axes object that cannot be zipped over.
fig, axs = plt.subplots(1, num_datasets, figsize=(5*num_datasets, 4), sharey=False, squeeze=False)
axs = axs[0]

# Fixed bins of width 0.01 on [0, 1], shared by every panel.
binwidth = 0.01
bins = list(np.arange(0, 1 + binwidth, binwidth))

for dataset_name, ax in zip(acc_to_plot['metrics'].keys(), axs):
    model_metrics = acc_to_plot['metrics'][dataset_name]["model"]
    protac_accuracies = model_metrics["PROTAC"]["Accuracy"][best_epoch]
    ligand_linker_accuracies = model_metrics["LIGANDS"]["Accuracy"][best_epoch]

    ax.hist(protac_accuracies, range=(0, 1), bins=bins, alpha=0.5, label='PROTAC', color='red')
    ax.hist(ligand_linker_accuracies, range=(0, 1), bins=bins, alpha=0.5, label='LIGANDS+LINKER', color='blue')

    # Only the shape of the distributions matters here; hide the counts.
    ax.get_yaxis().set_visible(False)

    ax.title.set_text(f'Accuracy for {dataset_name} dataset')

plt.show()

## Violin plot atoms wrong (TP+FN)
Gives overview of the performance

In [None]:
# Violin plot per dataset of the number of wrongly predicted atoms at the best
# epoch, for the whole PROTAC and for ligands + linker.
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

acc_to_plot = measures_concat if model_crossfold else output

dataset_names = ["Validation", "Test PROTAC"]

for dataset_name in dataset_names:

    accuracy_origin = "model"

    structures_to_plot = ["PROTAC", "LIGANDS"]
    renaming_dict = {"PROTAC": "PROTAC", "LIGANDS": "Ligands & Linker"}

    origin_metrics = acc_to_plot["metrics"][dataset_name][accuracy_origin]

    # Expand each {atoms_wrong: count} histogram into a flat list of
    # observations, one entry per prediction.
    substructure_atoms_wrong = {}
    for substructure, substructure_metrics in origin_metrics.items():
        if substructure not in structures_to_plot or "Atoms_wrong" not in substructure_metrics:
            continue
        observations = []
        for num_wrong, count in substructure_metrics["Atoms_wrong"][best_epoch].items():
            observations.extend([num_wrong] * count)
        substructure_atoms_wrong[substructure] = observations

    # Long format: one row per observation. The 'Atoms_wrong' column holds the
    # (renamed) structure name, 'Count wrong atoms' the observed value.
    # pd.Series pads structures with fewer observations with NaN.
    substructure_atoms_wrong_df = pd.DataFrame({renaming_dict[k]: pd.Series(v) for k, v in substructure_atoms_wrong.items()})
    substructure_atoms_wrong_df_long = pd.melt(substructure_atoms_wrong_df, var_name='Atoms_wrong', value_name='Count wrong atoms')

    fig, ax = plt.subplots(1, 1, figsize=(10, 6))
    # cut=0 keeps the violins within the range of the observed values.
    sns.violinplot(x='Atoms_wrong', y='Count wrong atoms', data=substructure_atoms_wrong_df_long, cut=0, ax=ax)

    ax.set_xlabel('')  # Remove the overall x-axis label
    ax.tick_params(axis='x', length=0)  # Remove the x-ticks but keep the labels
    # Keep the tick labels unrotated
    for label in ax.get_xticklabels():
        label.set_rotation(0)

    # Remove spines for a cleaner look
    for side in ('top', 'right', 'left', 'bottom'):
        ax.spines[side].set_visible(False)

    #fig.suptitle(f'Frequency & count of wrongly predicted atoms \n{dataset_name}', fontsize=16)

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()


    #fig.savefig(f'fig_results/{model.model_type}_{identifying_comment}_violin_{dataset_name}_{structures_to_plot}.svg', format='svg', dpi=1200, bbox_inches='tight')
    #fig.savefig(f'fig_results/{model.model_type}_{identifying_comment}_violin_{dataset_name}_{structures_to_plot}.png', format='png', dpi=1200, bbox_inches='tight')

## Cumulative barplot atoms wrong (normalized FP+FN)
Answers detailed questions on how wrong/good it is. 
E.g. What fraction of predictions for a given structure have at most x atoms wrongly predicted?

In [None]:
# Bar plot per dataset of the cumulative fraction of predictions with at most
# x wrongly predicted atoms (or, with invert=True, with more than x wrong
# atoms), averaged over the available runs with the standard deviation as
# error bars.
dataset_names = ["Validation", "Test PROTAC"]  # list(acc_to_plot["metrics"].keys())
invert = True
max_atoms_wrong_to_plot = 10
accuracy_origin = "model"
substructure = "LIGANDS"

acc_to_plot = measures_concat if model_crossfold else output

# One entry per run: all collected outputs when cross-folding, otherwise just
# the single run.
atoms_wrong_to_plot = outputs if model_crossfold else [output]

# The x axis is always 0..max_atoms_wrong_to_plot, independent of which counts
# happen to occur in a given run.
num_atoms_wrong = list(range(max_atoms_wrong_to_plot + 1))

for dataset_name in dataset_names:

    fractions_across_splits = {}
    # Dedicated loop names: the global split_idx and acc_to_plot used to be
    # overwritten here.
    for run_idx, run_output in enumerate(atoms_wrong_to_plot):
        wrong_counts = run_output["metrics"][dataset_name][accuracy_origin][substructure]["Atoms_wrong"][best_epoch]
        total_number_of_protacs = sum(wrong_counts.values())
        if not total_number_of_protacs:
            raise ValueError(f"No 'Atoms_wrong' counts for {dataset_name} at epoch {best_epoch}")

        # Cumulative fraction for every x from 0 upwards, so that position i
        # always means "at most i atoms wrong". Previously the list started at
        # the smallest observed count, which misaligned runs without perfect
        # predictions. Counting in integers makes the value exactly 1 once all
        # predictions are covered.
        cumulative_count = 0
        fractions = []
        for i in num_atoms_wrong:
            cumulative_count += wrong_counts.get(i, 0)
            fractions.append(cumulative_count / total_number_of_protacs)

        fractions_across_splits[run_idx] = fractions

    average_fractions = []
    std_fractions = []
    for i in num_atoms_wrong:
        per_run = [fractions[i] for fractions in fractions_across_splits.values()]
        mean_fraction = avg(per_run)
        average_fractions.append(1 - mean_fraction if invert else mean_fraction)
        # statistics.stdev needs at least two values; a single run has no spread.
        std_fractions.append(statistics.stdev(per_run) if len(per_run) > 1 else 0.0)

    # print(f'Fraction of predictions with 6 or fewer mispredicted atoms for {dataset_name}: {(1-average_fractions[6])*100}')

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.bar(num_atoms_wrong, average_fractions, color="skyblue")
    ax.errorbar(num_atoms_wrong, average_fractions, yerr=std_fractions, fmt=".", color="k")
    ax.set_ylim(bottom=0)
    # Scale the y axis to the tallest bar: the first one when inverted, the
    # last one otherwise.
    tallest = max(num_atoms_wrong, key=average_fractions.__getitem__)
    if average_fractions[tallest] > 0.5:
        ax.set_ylim(top=1)
    else:
        y_top = (average_fractions[tallest] + std_fractions[tallest]) * 1.3
        if y_top > 0:
            ax.set_ylim(top=y_top)
    #plt.title("Cumulative fraction with at most x atoms wrong")

    ax.set_xlabel("Atoms wrong")
    ax.set_ylabel("Fraction of predictions")
    ax.set_xticks(range(0, max_atoms_wrong_to_plot + 1, 1))

    plt.show()

    invert_str = "_inverted" if invert else ""

    #fig.savefig(f'fig_results/{model.model_type}_{identifying_comment}_cumulative_{dataset_name}_{substructure}{invert_str}.svg', format='svg', dpi=1200, bbox_inches='tight')
    #fig.savefig(f'fig_results/{model.model_type}_{identifying_comment}_cumulative_{dataset_name}_{substructure}{invert_str}.png', format='png', dpi=1200, bbox_inches='tight')

## Plot distribution of Precision, Recall, F1 for substructures
Plotted at the best epoch.

In [None]:
# Grid of histograms for the training set at the best epoch: one row per
# substructure, one column per metric (precision, recall).
# Relies on model, best_epoch and output / measures_concat from earlier cells.

dataset_name = model.train_set_name
accuracy_origin = "model"

acc_to_plot = measures_concat if model_crossfold else output

substructure_list = ['LIGANDS', 'POI', 'LINKER', 'E3']
metrics_list = ['Precision', 'Recall']
num_substructures = len(substructure_list)
num_metrics = len(metrics_list)
nrows = num_substructures
ncols = num_metrics

# squeeze=False always yields an (nrows, ncols) array of axes, so each row can
# be indexed by column even when there is a single metric column.
fig, axs = plt.subplots(nrows, ncols, figsize=(5*ncols, 4*num_substructures), sharey=True, sharex=True, squeeze=False)

# Fixed bins of width 0.01 on [0, 1], shared by every panel.
binwidth = 0.01
bins = list(np.arange(0, 1 + binwidth, binwidth))

for axs_structure, structure_type in zip(axs, substructure_list):
    structure_metrics = acc_to_plot["metrics"][dataset_name][accuracy_origin][structure_type]
    # Columns follow the order in which the metrics are stored for the
    # structure; metrics outside metrics_list are skipped.
    selected_metrics = [
        (metric_type, epoch_metric_dict)
        for metric_type, epoch_metric_dict in structure_metrics.items()
        if metric_type in metrics_list
    ]

    for ax, (metric_type, epoch_metric_dict) in zip(axs_structure, selected_metrics):
        metric_values = epoch_metric_dict[best_epoch]
        ax.hist(metric_values, range=(0, 1), bins=bins)

        ax.set_title(f"{metric_type} for {structure_type} in {dataset_name} from {accuracy_origin}.\nAvg: {round(np.mean(metric_values),2)}, Median: {round(np.median(metric_values),2)}")

plt.tight_layout()
# Fixed: plt.show was referenced without being called.
plt.show()

## Plot barplot of Precision and recall for substructures

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Grouped bar plot of Precision (top) and Recall (bottom) at `best_epoch`:
# one group per dataset, one bar per substructure. Error bars are drawn only
# for cross-fold results, where a standard deviation is available.

if model_crossfold:
    acc_to_plot = measures_avg
    std_to_plot = measures_std
else:
    acc_to_plot = output
    # A single run has no spread; without this the name would be undefined
    # (or stale from an earlier cross-fold run of this cell).
    std_to_plot = None

group_labels = list(acc_to_plot["validity_fraction"].keys())
subcategories = ['Warhead', 'Linker', 'E3']  # For the legend
colors = ['red', 'gray', 'blue']
# Legend name -> key used in the metrics dict
substructures_to_substructures = {'Warhead': "POI", 'Linker':"LINKER", 'E3': "E3"}
metrics = ["Precision", "Recall"]

# Set up the figure
fig, axs = plt.subplots(figsize=(15, 10), nrows = 2, ncols=1, sharex=True)

# Number of groups
n_groups = len(group_labels)

# Width of each bar (one bar-width of spacing is left per group)
bar_width = 1/(1+len(subcategories))

# The x location for the groups
indices = np.arange(n_groups)

x_factor = 1.2

# One axes per metric
for metric, ax in zip(metrics, axs):

    for i, category in enumerate(subcategories):
        structure_key = substructures_to_substructures[category]

        vals = []
        stds = []
        for dataset_name in acc_to_plot["metrics"].keys():
            vals.append(acc_to_plot["metrics"][dataset_name]["model"][structure_key][metric][best_epoch])
            if std_to_plot is not None:
                stds.append(std_to_plot["metrics"][dataset_name]["model"][structure_key][metric][best_epoch])

        # The x location for the bars within each group
        bar_positions = x_factor*indices + i * bar_width +bar_width/2
        ax.bar(bar_positions, vals, width=bar_width, label=category, color=(colors[i],0.8))
        if std_to_plot is not None:
            ax.errorbar(bar_positions, vals, yerr=stds, fmt=".", color=("k", 0.7))

    # Set the ticks on this axes explicitly instead of on pyplot's implicit
    # "current" axes; the tick sits at the centre of each group of bars.
    ax.set_xticks(x_factor*indices + bar_width * 1.5, labels=group_labels)

    # Adding labels and title
    ax.set_ylim([0, 1])
    ax.title.set_text(metric)

    # Adding a legend
    ax.legend()

#fig.savefig(f'fig_results/{model.model_type}_{identifying_comment}_precision_recall.svg', format='svg', dpi=1200, bbox_inches='tight')
#fig.savefig(f'fig_results/{model.model_type}_{identifying_comment}_precision_recall.png', format='png', dpi=1200, bbox_inches='tight')



## Plot PROTAC accuracy and ligands linker accuracy

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Grouped bar plot of the Accuracy at `best_epoch` for the 'PROTAC' and
# 'LIGANDS' entries of the metrics dict, one group of bars per dataset.
# Error bars are drawn only for cross-fold results.

if model_crossfold:
    acc_to_plot = measures_avg
    std_to_plot = measures_std
else:
    acc_to_plot = output
    # A single run has no spread; without this the name would be undefined
    # (or stale from an earlier cross-fold run of this cell).
    std_to_plot = None

group_labels = list(acc_to_plot["validity_fraction"].keys())
subcategories = ['PROTAC', 'LIGANDS']  # Keys in the metrics dict
colors = ['green', 'purple']
# Metrics-dict key -> legend label
substructures_to_substructures = {'PROTAC': "PROTAC", 'LIGANDS':"Ligands&Linker"}
metrics = ["Accuracy"]

# Set up the figure
fig, ax = plt.subplots(figsize=(15, 5), nrows = 1, ncols=1, sharex=True)

# Number of groups
n_groups = len(group_labels)

# Width of each bar (one bar-width of spacing is left per group)
bar_width = 1/(1+len(subcategories))

# The x location for the groups
indices = np.arange(n_groups)

x_factor = 1.4

# Plotting each group (a single metric, drawn on the single axes)
for metric in metrics:

    for i, category in enumerate(subcategories):

        vals = []
        stds = []
        for dataset_name in acc_to_plot["metrics"].keys():
            vals.append(acc_to_plot["metrics"][dataset_name]["model"][category][metric][best_epoch])
            if std_to_plot is not None:
                stds.append(std_to_plot["metrics"][dataset_name]["model"][category][metric][best_epoch])

        # The x location for the bars within each group
        bar_positions = x_factor*indices + i * bar_width +bar_width
        ax.bar(bar_positions, vals, width=bar_width, label=substructures_to_substructures[category], color=(colors[i],0.9))
        if std_to_plot is not None:
            ax.errorbar(bar_positions, vals, yerr=stds, fmt=".", color=("k", 0.7))

    # Tick at the centre of each group, set on the axes that was drawn on
    ax.set_xticks(x_factor*indices + bar_width * 1.5, labels=group_labels)

    # Adding labels and title
    ax.set_ylim([0, 1])
    ax.title.set_text(metric)

    # Adding a legend
    ax.legend()

#fig.savefig(f'fig_results/{model.model_type}_{identifying_comment}_protac_ligandslinker_accuracy.svg', format='svg', dpi=1200, bbox_inches='tight')
#fig.savefig(f'fig_results/{model.model_type}_{identifying_comment}_protac_ligandslinker_accuracy.png', format='png', dpi=1200, bbox_inches='tight')



## Plot frequency of flips

In [None]:
import matplotlib.pyplot as plt

# Bar plot of the "flip_fraction" value of every dataset at `best_epoch`.
# Error bars are drawn only for cross-fold results.

if model_crossfold:
    acc_to_plot = measures_avg
    std_to_plot = measures_std
else:
    acc_to_plot = output
    # A single run has no spread; without this the name would be undefined
    # (or stale from an earlier cross-fold run of this cell).
    std_to_plot = None

dataset_names = list(acc_to_plot["flip_fraction"].keys())
flip_fractions = [acc_to_plot["flip_fraction"][dataset_name][best_epoch] for dataset_name in dataset_names]

# Creating the bar plot
plt.figure(figsize=(15, 6))  # Optional: Specifies the figure size
plt.bar(dataset_names, flip_fractions, color='skyblue')  # Creates the bar plot with names on the x-axis and values on the y-axis

if std_to_plot is not None:
    # Look the deviations up by dataset name so each one is paired with its own
    # bar even if the two dicts list the datasets in a different order.
    flip_stds = [std_to_plot["flip_fraction"][dataset_name][best_epoch] for dataset_name in dataset_names]
    plt.errorbar(dataset_names, flip_fractions, yerr=flip_stds, fmt="o", color="k")

# Adding labels and title
plt.xlabel('')  # X-axis label
plt.ylabel('')  # Y-axis label
plt.ylim(bottom=0)
#plt.title('Flipped fraction of PROTACs')  # Plot title

#plt.savefig(f'fig_results/{model.model_type}_{identifying_comment}_flipfraction.svg', format='svg', dpi=1200, bbox_inches='tight')
#plt.savefig(f'fig_results/{model.model_type}_{identifying_comment}_flipfraction.png', format='png', dpi=1200, bbox_inches='tight')


# Display the plot
plt.show()


## Plot valid SMILES & splits

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Grouped bar plot of the "validity_fraction" values at `best_epoch`:
# one group per dataset, one bar per validity type.

if model_crossfold:
    acc_to_plot = measures_avg
else:
    acc_to_plot = output

group_labels = list(acc_to_plot["validity_fraction"].keys())
subcategories = ['VALID SPLIT', 'POI SMILES', 'LINKER SMILES', 'E3 SMILES']  # For the legend
colors = ['purple', 'red', 'green', 'blue']

# Set up the figure
plt.figure(figsize=(15, 6))

# Number of groups
n_groups = len(group_labels)

# Width of each bar (one bar-width of spacing is left per group)
bar_width = 1/(1+len(subcategories))

# The x location for the groups
indices = np.arange(n_groups)

# Plotting each group
for i, validity_type in enumerate(subcategories):

    # One value per dataset, in the same order as `group_labels`
    data = [validity_dict[validity_type][best_epoch] for validity_dict in acc_to_plot["validity_fraction"].values()]
    # The x location for the bars within each group
    bar_positions = indices + i * bar_width
    plt.bar(bar_positions, data, width=bar_width, label=validity_type, color=colors[i])

# Setting the x-ticks and labels for each group (tick at the group centre)
plt.xticks(indices + bar_width * 1.5, group_labels)

# Adding labels and title
plt.xlabel('')
plt.ylabel('Fraction valid')
plt.title('Valid split & SMILES')

# Adding a legend
plt.legend()

# Display the plot
plt.show()

# _---------------Plots during training-------------------_

# Plot average metrics during training

Get the average over PROTACs for an epoch, dataset, and metric

## Plot each dataset

In [None]:
# Training curves, one figure per dataset: the selected metrics of the "model"
# predictions for the selected structures against the epoch number, a constant
# "Dummy" baseline, an optional +/- std band (cross-fold only) and, for the
# first two datasets, the loss curves on a second (log-scaled) y-axis.

metrics_to_include_in_plot = ["Accuracy"] #Atoms_wrong, Accuracy, Precision, Recall, F1, 

# NOTE(review): "Linker" differs in case from the 'LINKER' key used in
# structure_type_to_color below; if the metrics dict uses 'LINKER' the linker
# curve is never selected. Confirm which spelling is intended.
structures_to_include_in_plot = ["PROTAC", "POI", "Linker", "E3", "LIGANDS"]

if model_crossfold:
    acc_to_plot = measures_avg
    std_to_plot = measures_std
else:
    acc_to_plot = aggregated_output
keys = list(acc_to_plot['loss'].keys())
train_dataset_name = keys[0]
val_dataset_name = keys[1]

structure_type_to_color = {"PROTAC": 'red', 'POI': 'red', 'LINKER': 'black', 'E3': 'blue', 'LIGANDS': 'blue', 'LIGANDS+LINKER': 'purple'}
# "Test Warhead" is included so this mapping agrees with the one used in the
# "Plot accuracy for each structure" cell.
dataset_to_color = {"Train": 'black', "train_CV": 'black', "Validation": 'green', 'val_CV': 'green', "Test PROTAC": 'purple', "Test POI": 'red', "Test Warhead": 'red', "Test Linker": 'gray', "Test E3": 'blue', 'Dummy': 'orange'}

structure_type_to_label = {"PROTAC": "PROTAC", "LIGANDS": "Ligands & Linker", "POI": "Warhead"}

# Debug toggle: set to True inside the loop to stop after the first figure.
done=False


for dataset_idx, dataset_name in enumerate(acc_to_plot['metrics'].keys()):
    for accuracy_origin in acc_to_plot['metrics'][dataset_name].keys():

        if accuracy_origin != "model":
            continue


        #Make a plot for all accuracies
        fig, ax1 = plt.subplots()
        ax1.set_xlabel('Epochs', color="black")
        ax1.set_ylabel('Node accuracy', color="black", rotation='vertical')

        #ax1.yaxis.set_label_coords(-.18, 0.5)

        #Fix x-axis and get adaptive tick step size
        # The epoch axis starts at 0 when pretrained values were computed, else at 1.
        x = list(range(1-int(train_params["compute_pretrained_values"]),len(acc_to_plot['loss'][train_dataset_name])+1-int(train_params["compute_pretrained_values"])))
        ax1.ticklabel_format(style='plain', axis='x', useOffset=False)
        allowed_tick_sizes = [1, 2, 5, 10, 25, 50, 100]
        desired_num_tick_steps = 10
        # Pick the allowed step size that gives a tick count closest to the desired one.
        deviation_num_tick_steps = [abs(len(x)//tick_size-desired_num_tick_steps) for tick_size in allowed_tick_sizes]
        chosen_tick_step_size = allowed_tick_sizes[deviation_num_tick_steps.index(min(deviation_num_tick_steps))]
        fig.gca().xaxis.set_major_locator(mticker.MultipleLocator(chosen_tick_step_size))

        accuracy_plots_list = []
        loss_plot_list = []
        plotted_dummy_metrics = []
        for structure_type in acc_to_plot['metrics'][dataset_name][accuracy_origin].keys():
            if structure_type not in structures_to_include_in_plot:
                continue

            for metrics_type, metrics in acc_to_plot['metrics'][dataset_name][accuracy_origin][structure_type].items():
                if metrics_type not in metrics_to_include_in_plot:
                    continue

                # Structures without a dedicated legend label (e.g. "E3") fall
                # back to their own name instead of raising a KeyError. The
                # dict view is materialised so matplotlib gets a plain sequence.
                accuracy_plots_list.append(ax1.plot(x, list(metrics.values()), color=structure_type_to_color[structure_type], label=structure_type_to_label.get(structure_type, structure_type))) #  f'{accuracy_origin}_{structure_type}_{metrics_type}'

        # Constant baseline: the mean over epochs of the "Dummy" dataset's metric.
        # NOTE(review): `structure_type` and `metrics_type` here are whatever the
        # loops above left behind (the last keys visited, regardless of the
        # include filters) - confirm this is the intended baseline.
        if len(metrics_to_include_in_plot)>1:
            dummy_label = f'Dummy_{metrics_type}'
        else:
            dummy_label = "Dummy"
        if dummy_label not in plotted_dummy_metrics:
            dummy_values = list(acc_to_plot["metrics"]["Dummy"][accuracy_origin][structure_type][metrics_type].values())
            average_dummy_value_list = [avg(dummy_values)] * len(dummy_values)
            accuracy_plots_list.append(ax1.plot(x, average_dummy_value_list, color="orange", label=dummy_label)) #  f'{accuracy_origin}_{structure_type}_{metrics_type}'
            plotted_dummy_metrics.append(dummy_label)

        if model_crossfold:
            # Shade mean +/- std across folds around every plotted curve.
            x_arr = np.array(x)
            for structure_type in acc_to_plot['metrics'][dataset_name][accuracy_origin].keys():
                if structure_type not in structures_to_include_in_plot:
                    continue

                avg_tmp_dict = acc_to_plot['metrics'][dataset_name][accuracy_origin][structure_type]
                std_tmp_dict = std_to_plot['metrics'][dataset_name][accuracy_origin][structure_type]
                for metrics_type, avg_vals in avg_tmp_dict.items():
                    if metrics_type not in metrics_to_include_in_plot:
                        continue

                    # Pair mean and std by metric name rather than by dict position.
                    std_vals = std_tmp_dict[metrics_type]
                    acc_arr = np.array(list(avg_vals.values()))
                    std_arr = np.array(list(std_vals.values()))
                    ax1.fill_between(x_arr, acc_arr-std_arr, acc_arr+std_arr, color=structure_type_to_color[structure_type], alpha=0.3)


        ax1.tick_params(axis='y', labelcolor="black")
        ax1.set_ylim(bottom=0, top=1)
        ax1.set_xlim(left=min(x), right=max(x))
        #ax1.legend(loc='lower left')
        ax1.set_title(label=f"Node accuracy of {accuracy_origin} for {dataset_name} dataset\ndescriptors: {[node_descriptors] + model.graph_descriptor_list}")

        ax2 = ax1.twinx()  # instantiate a second axes that shares the same x-axis

        # Loss curves are only drawn on the figures of the first two datasets.
        if dataset_idx<2:
            for dataset_name_for_loss in acc_to_plot['metrics'].keys():
                loss = list(acc_to_plot['loss'][dataset_name_for_loss].values())
                # Datasets whose loss averages to zero are left out.
                if avg(loss) == 0:
                    continue

                loss_plot_list.append(ax2.plot(x, loss, color=dataset_to_color[dataset_name_for_loss], linestyle='dotted'))

        if len(loss_plot_list) > 0:
            #lns = ax2.plot(x, train_loss, color="black", linestyle='dotted', label='Loss')
            ax2.set_yscale(value="log")
            ax2.set_ylabel('Loss', color="black", rotation='vertical')  # we already handled the x-label with ax1
            #ax2.yaxis.set_label_coords(1.12, 0.5)
            ax2.tick_params(axis='y')

        # Extra tick-less twin axes that can host additional legends
        ax3 = ax1.twinx()
        ax3.set_yticks([])
        ax4_dummy = ax1.twinx()
        ax4_dummy.set_yticks([])


        # Only one legend position/title is defined, so the zip below draws a
        # legend for the first non-empty plot list only.
        legend_positions = [1-((len(accuracy_plots_list)+1)*0.03)]
        available_axies = [ax1, ax2, ax3, ax4_dummy]
        legend_titles = metrics_to_include_in_plot #['Metrics']
        plot_list = [l for l in [accuracy_plots_list, loss_plot_list] if len(l)>0]
        if len(plot_list) == 1:
            del available_axies[1]


        plot_legend =True
        if plot_legend:
            lgnd_str = ""
        else:
            lgnd_str = "_noLegend"

        for chosen_plot_list, legend_pos, chosen_ax, leg_title in zip(plot_list, legend_positions, available_axies, legend_titles):
            # Each entry is the list returned by one ax.plot call: take the
            # label of its first line and collect all line handles.
            lables = [plot[0].get_label() for plot in chosen_plot_list]
            plots = [line for plot in chosen_plot_list for line in plot]
            if plot_legend:
                leg = chosen_ax.legend(plots, lables, loc='center left', bbox_to_anchor=(1.2, legend_pos), frameon=True, title=leg_title)
                leg.get_frame().set_linewidth(1.0)
                leg.get_frame().set_edgecolor('k')




        #fig.tight_layout()  # otherwise the right y-label is slightly clipped
        fig.show()
        #fig.savefig(f'fig_butinacutoff/NodePred_final{lgnd_str}.svg', format='svg', dpi=1200, bbox_inches='tight')
        #fig.savefig(f'fig_butinacutoff/NodePred_final{lgnd_str}.png', format='png', dpi=1200, bbox_inches='tight')
        done = False

        if done:
            break
    if done:
        break



## Plot accuracy for each structure

In [None]:
from scipy.ndimage import gaussian_filter1d

# Training curves, one figure per selected structure: the Gaussian-smoothed
# metric of every dataset against the epoch number, truncated at `best_epoch`,
# with the smoothed train/validation loss on a second (log-scaled) y-axis.
# With cross-fold results a +/- std band is shaded around every curve.

sigma = 1 #smoothing



dataset_to_color = {"Train": 'black', "train_CV": 'black', "Validation": 'green', 'val_CV': 'green', "Test PROTAC": 'purple', "Test POI": 'red', "Test Warhead": 'red', "Test Linker": 'gray', "Test E3": 'blue', 'Dummy': 'orange'}

metrics_to_include_in_plot = ["Accuracy"] #Atoms_wrong, Accuracy, Precision, Recall, F1, 
structures_to_include_in_plot = ["PROTAC"] #, "POI", "Linker", "E3", "LIGANDS"]


structure_type_to_label = {"PROTAC": "PROTAC", "LIGANDS": "Ligands & Linker", "POI": "Warhead"}


plot_check_library = False

if model_crossfold:
    acc_to_plot = measures_avg
    std_to_plot = measures_std
else:
    acc_to_plot = aggregated_output


def _until_best_epoch(values):
    """Return the leading entries of *values* whose position is below `best_epoch`."""
    return [val for idx, val in enumerate(values) if idx < best_epoch]


keys = list(acc_to_plot['loss'].keys())
train_dataset_name = keys[0]
val_dataset_name = keys[1]

# Figures created by this cell, keyed by structure type
figs = {}

for structure_type in acc_to_plot['metrics'][train_dataset_name]['model'].keys():
    if structure_type not in structures_to_include_in_plot:
        continue

    #Make a plot for all accuracies
    fig, ax1 = plt.subplots()
    ax1.set_xlabel('Epochs', color="black")
    ax1.set_ylabel('Node accuracy', color="black", rotation='vertical')

    #ax1.yaxis.set_label_coords(-.18, 0.5)


    #get adaptive tick step size
    # The epoch axis starts at 0 when pretrained values were computed, else at 1.
    x = list(range(1-int(model.compute_pretrained_values),len(acc_to_plot['loss'][train_dataset_name])+1-int(model.compute_pretrained_values)))
    x = _until_best_epoch(x)
    ax1.ticklabel_format(style='plain', axis='x', useOffset=False)
    allowed_tick_sizes = [1, 2, 5, 10, 25, 50, 100]
    desired_num_tick_steps = 10
    # Pick the allowed step size that gives a tick count closest to the desired one.
    deviation_num_tick_steps = [abs(len(x)//tick_size-desired_num_tick_steps) for tick_size in allowed_tick_sizes]
    chosen_tick_step_size = allowed_tick_sizes[deviation_num_tick_steps.index(min(deviation_num_tick_steps))]
    fig.gca().xaxis.set_major_locator(mticker.MultipleLocator(chosen_tick_step_size))


    ax1.tick_params(axis='y', labelcolor="black")
    ax1.set_ylim(bottom=0, top=1)
    ax1.set_xlim(left=min(x), right=max(x))
    #ax1.set_title(label=f"{metrics_to_include_in_plot} of {structure_type}\ndescriptors: {[node_descriptors] + model.graph_descriptor_list}")


    accuracy_plots_list = []
    accuracy_plots_check_library_list = []


    for dataset_name in acc_to_plot['metrics'].keys():
        for metric_type, acc_tmp in acc_to_plot['metrics'][dataset_name]['model'][structure_type].items():
            if metric_type not in metrics_to_include_in_plot:
                continue
            elif dataset_name not in dataset_to_color:
                continue

            if dataset_name == "Dummy":
                # The dummy baseline is drawn as a constant line at its mean over epochs.
                dummy_values = list(acc_to_plot["metrics"]["Dummy"]["model"][structure_type][metric_type].values())
                average_dummy_value_list = [avg(dummy_values)] * len(x)
                accuracy_plots_list.append(ax1.plot(x, average_dummy_value_list, color="orange", label="Dummy"))
                continue

            # Smooth over the full curve first, then cut at the best epoch.
            acc_tmp = list(acc_tmp.values())
            acc_tmp = gaussian_filter1d(acc_tmp, sigma=sigma)
            acc_tmp = _until_best_epoch(acc_tmp)
            accuracy_plots_list.append(ax1.plot(x, acc_tmp, color=dataset_to_color[dataset_name], label=f'{dataset_name}'))

            if plot_check_library:
                acc_check_library_tmp = list(acc_to_plot['metrics'][dataset_name]['model+check_library'][structure_type][metric_type].values())
                # Cut to the same length as `x`; the untruncated curve would
                # not match the truncated epoch axis.
                acc_check_library_tmp = _until_best_epoch(acc_check_library_tmp)
                accuracy_plots_check_library_list.append(ax1.plot(x, acc_check_library_tmp, color=dataset_to_color[dataset_name], label=f'{dataset_name}', linestyle='dashed'))

            if model_crossfold:

                std_tmp = list(std_to_plot['metrics'][dataset_name]['model'][structure_type][metric_type].values())
                std_tmp = gaussian_filter1d(std_tmp, sigma=sigma)
                std_tmp = _until_best_epoch(std_tmp)

                x_arr = np.array(x)
                acc_tmp_np = np.array(acc_tmp)
                std_tmp_np = np.array(std_tmp)
                ax1.fill_between(x_arr, acc_tmp_np-std_tmp_np, acc_tmp_np+std_tmp_np, color=dataset_to_color[dataset_name], alpha=0.3)

                if plot_check_library:
                    # Only touch the 'model+check_library' results when they are plotted.
                    std_check_library_tmp = list(std_to_plot['metrics'][dataset_name]['model+check_library'][structure_type][metric_type].values())
                    std_check_library_tmp = _until_best_epoch(std_check_library_tmp)
                    acc_check_library_tmp_np = np.array(acc_check_library_tmp)
                    std_check_library_tmp_np = np.array(std_check_library_tmp)
                    ax1.fill_between(x_arr, acc_check_library_tmp_np-std_check_library_tmp_np, acc_check_library_tmp_np+std_check_library_tmp_np, color=dataset_to_color[dataset_name], alpha=0.3)

    loss_plot_list = []
    ax2 = ax1.twinx()  # instantiate a second axes that shares the same x-axis
    ax2.set_yscale(value="log")
    ax2.set_ylabel('Loss', color="black", rotation='vertical')  # we already handled the x-label with ax1
    #ax2.yaxis.set_label_coords(1.12, 0.5)
    ax2.tick_params(axis='y')
    train_loss = list(acc_to_plot['loss'][train_dataset_name].values())
    val_loss = list(acc_to_plot['loss'][val_dataset_name].values())
    train_loss = _until_best_epoch(gaussian_filter1d(train_loss, sigma=sigma))
    val_loss = _until_best_epoch(gaussian_filter1d(val_loss, sigma=sigma))
    loss_plot_list.append(ax2.plot(x, train_loss, color="black", linestyle='dotted', label='Train')) #marker='--'
    loss_plot_list.append(ax2.plot(x, val_loss, color="green", linestyle='dotted', label='Validation')) #marker='--'


    # Extra tick-less twin axes that can host additional legends
    ax3 = ax1.twinx()
    ax3.set_yticks([])
    ax4_dummy = ax1.twinx()
    ax4_dummy.set_yticks([])


    legend_positions = [0.8, 0.35, -0.02]
    available_axies = [ax1, ax2, ax3, ax4_dummy]
    # NOTE(review): there is no title for the loss legend, so the zip below
    # never reaches the loss curves - confirm whether a 'Loss' legend is wanted.
    legend_titles = ['Accuracy', 'Model+Check Library accuracy']
    if plot_check_library is False:
        legend_titles.remove('Model+Check Library accuracy')
    plot_list = [accuracy_plots_list, accuracy_plots_check_library_list, loss_plot_list]
    plot_list = [l for l in plot_list if len(l)>0]
    if len(plot_list) == 1:
        del available_axies[1]

    plot_legend = False

    if plot_legend:
        for chosen_plot_list, legend_pos, chosen_ax, leg_title in zip(plot_list, legend_positions, available_axies, legend_titles):
            # Flatten the lists returned by the ax.plot calls into line handles.
            plots = [line for plot in chosen_plot_list for line in plot]
            lables = [line.get_label() for line in plots]
            leg = chosen_ax.legend(plots, lables, loc='center left', bbox_to_anchor=(1.2, legend_pos), frameon=True, title=leg_title)
            leg.get_frame().set_linewidth(1.0)
            leg.get_frame().set_edgecolor('k')

    # Register and show the figure once per structure (not once per legend).
    figs[structure_type] = fig
    fig.show()

    legend_str=""
    if not plot_legend:
        legend_str="_noLegend"
    #fig.savefig(f'fig_finaltraincurves/{model.model_type}_{identifying_comment}_{structure_type}_{metrics_to_include_in_plot}{legend_str}.svg', format='svg', dpi=1200, bbox_inches='tight')
    #fig.savefig(f'fig_finaltraincurves/{model.model_type}_{identifying_comment}_{structure_type}_{metrics_to_include_in_plot}{legend_str}.png', format='png', dpi=1200, bbox_inches='tight')
    

# ---- Plot predicted graphs ----

In [None]:
# Draw, for a few PROTACs of one dataset, the molecular graph coloured by the
# ground-truth node labels (left column) next to the same graph coloured by the
# model's predicted node classes (second column), plus an optional column for
# the library-checked predictions.

use_library = False

num_protacs_per_set = 5

chosen_dataset = "Validation"

fp_function = compute_countMorgFP

e3_database_fps = None #compute_countMorgFP(e3_library)
poi_database_fps = None #compute_countMorgFP(poi_library)

# NOTE(review): this flag reserves an extra column, but nothing in this cell
# draws into it.
plot_postprocessed_graphs = False


# SMILES to plot instead of random samples; leave empty to sample randomly.
protac_smi_to_plot = []#['COc1ccc(OC)c2c1c(OC)cc1c(=O)cc(-c3ccc(CCn4cc(-c5ccc6c(c5)C(c5ccc(Cl)cc5)=NC5(CC5)c5nnc(C)n5-6)cn4)cc3)oc12',
                     # 'COc1ccc(Cl)c(S(=O)(=O)Nc2ccc(-c3nc(OCC4CN(c5ccc(N=Nc6ccc(C(=O)c7ccc(-c8cc(=O)c9cc(OC)c%10c(OC)ccc(OC)c%10c9o8)cc7)cc6)cc5)CCO4)c4c(C)n[nH]c4n3)cc2)c1']#['COc1ccc(OC)c2c1c(OC)cc1c(=O)cc(-c3ccc(NC=NC4CC5(C)CC(Oc6ccc(C#N)c(Cl)c6)CC5(C)C4)cc3)oc12',
                     # 'Cc1ccc(F)c(S(=O)(=O)Nc2ccc(-c3nc(OCC4CN(C(=O)c5cnc(N6CCN(CCCCCN=COc7ccc8c(c7)CCCN8C(=O)CCl)CC6)nc5)CCO4)c4cn[nH]c4n3)cc2)c1']


def _substructure_colors(labels):
    """Map node class labels to colours: 0 -> red, 1 -> gray, anything else -> blue."""
    return ['red' if label == 0 else 'gray' if label == 1 else 'blue' for label in labels]


for dataset_name, dataset  in datasets_dict.items():
    if dataset_name != chosen_dataset:
        continue

    # One row per PROTAC; columns: ground truth, prediction and the optional extras.
    num_columns = 2 + int(plot_postprocessed_graphs) + int(use_library)
    # squeeze=False keeps `axarr` two-dimensional even for a single row, so the
    # axarr[i, k] indexing below always works.
    fig, axarr = plt.subplots(num_protacs_per_set, num_columns, figsize=(num_columns*5, num_protacs_per_set * 5), squeeze=False)

    loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False)

    for i in range(num_protacs_per_set):
        protac_idx = random.randint(0, len(dataset) - 1)

        # Walk the loader until the wanted sample is reached: the next SMILES
        # from `protac_smi_to_plot` if any are left, else the random index.
        # `enumerate` keeps `j` equal to the sample position in both modes.
        for j, batch in enumerate(loader):
            if protac_smi_to_plot:
                if batch.smiles[0] in protac_smi_to_plot:       #assumes batch_size == 1
                    protac_smi_to_plot.remove(batch.smiles[0])
                    print(f'SMILES was found from in the SMILES_to_plot_list: {batch.smiles[0]}')
                    protac_idx = j
                    break
            elif j == protac_idx:
                break

        batch = batch.to(model.device)

        ground_truth_colors = _substructure_colors(batch.substructure_labels)

        G, pos = make_graph_with_pos(batch.smiles[0])

        # Plot ground truth on the left column
        nx.draw_networkx(G, pos=pos, ax=axarr[i, 0], node_color=ground_truth_colors, with_labels=False, node_size=50)
        #axarr[i, 0].set_title(f"Ground truth {dataset_name} {protac_idx}")
        axarr[i, 0].axis('equal')

        # The "Dummy" dataset is evaluated with randomize_y=True.
        randomize_y = dataset_name == "Dummy"
        y_type, y_location, predictions_and_prob, boundary_bond_probs_dataset = model.batch_to_output_and_classpredictions(batch=batch, e3_database_fps=e3_database_fps, poi_database_fps=poi_database_fps, fp_function=fp_function, randomize_y=randomize_y)

        # Predictions for the single sample in the batch
        class_predictions = predictions_and_prob["model"][0]
        predicted_colors = _substructure_colors(class_predictions)

        nx.draw_networkx(G, pos=pos, ax=axarr[i, 1], node_color=predicted_colors, with_labels=False, node_size=50)
        #axarr[i, 1].set_title(f"Predicted {dataset_name} {protac_idx}")
        axarr[i, 1].axis('equal')

        if use_library:
            class_predictions_library = predictions_and_prob["model+check_library"][0]
            predicted_colors_library = _substructure_colors(class_predictions_library)

            nx.draw_networkx(G, pos=pos, ax=axarr[i, 2], node_color=predicted_colors_library, with_labels=False, node_size=50)
            axarr[i, 2].set_title(f"Predicted & librarychecked {dataset_name} {protac_idx}")
            axarr[i, 2].axis('equal')

    plt.tight_layout()
    plt.show()


    #fig.savefig(f'fig_results/{model.model_type}_{identifying_comment}_{chosen_dataset}_{split_idx}.svg', format='svg', dpi=1200, bbox_inches='tight')
