In [None]:
import pandas as pd
import numpy as np
import json
import pprint
import matplotlib.pyplot as plt

import deepchem as dc
from deepchem.models import GCNModel
from deepchem.feat.graph_data import GraphData

from rdkit import Chem
from rdkit.Chem import rdmolops, Draw

import networkx as nx
from torch_geometric.utils.convert import from_networkx
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader, ImbalancedSampler
from torch_geometric.nn import GCNConv, BatchNorm
from torch_geometric.nn import GATConv

import dgl
import dgllife
from dgllife.model import GCN
from dgllife.model import GAT
from dgllife.model.readout import WeightedSumAndMax

import torch
import torch.nn as nn
import torch.nn.functional as F
import optuna

import os
import datetime

# see torch version and cuda version
print(torch.__version__)
print(torch.version.cuda)

# run inputs
# Answers are normalised (surrounding whitespace stripped, upper-cased) so that e.g.
# "gcn " is accepted as "GCN" instead of being rejected.
model_type = input("Enter the model type GCN, GNN, GAT: ").strip().upper()
if model_type not in ["GCN", "GNN", "GAT"]:
    raise ValueError("Model type not supported")

# Any answer other than "TRUE" (case-insensitive) is read as False.
motif_used = input("Enter if motif is used, TRUE or FALSE: ").strip().upper() == "TRUE"
test_used = input("Enter if test is used, TRUE or FALSE: ").strip().upper() == "TRUE"
print("Model type: ", model_type)
print("Motif used: ", motif_used)
print("Test used: ", test_used)

# failure detection: later cells branch on these exact types
assert isinstance(model_type, str)
assert isinstance(motif_used, bool)
assert isinstance(test_used, bool)


In [None]:
# set seed
# `random` is imported here only to seed it.
# NOTE: torch is seeded with 42 while numpy and `random` use 0.
import random

torch.manual_seed(42)
np.random.seed(0)
random.seed(0)


In [2]:
# load data
# Molecule table; later cells read its "SMILES" and "Vores_approval" columns.
path = r"Færdig_data_med_clin_data.csv"
molecules = pd.read_csv(filepath_or_buffer=path)





In [None]:
# create molecule graph for each molecule (deepchem MolGraphConvFeaturizer, edge features on),
# then remove molecules that could not be converted to graphs (result is not a GraphData).
featurizer = dc.feat.MolGraphConvFeaturizer(use_edges=True)
molecules = (
    molecules
    .assign(atom_graph=[featurizer.featurize(smiles_string)[0] for smiles_string in molecules["SMILES"]])
    .loc[lambda frame: frame["atom_graph"].map(lambda graph: isinstance(graph, GraphData))]
)


In [None]:
# create a dataframe to store the molecule graphs
# molecules_graph_data maps a row label to its deepchem graph (and, below, its DGL graph).
molecules_graph_data = pd.DataFrame()

molecules_graph_data["atom_graph"] = molecules["atom_graph"]
# From here on molecules["atom_graph"] holds the row label into molecules_graph_data,
# not the graph itself.
molecules["atom_graph"] = molecules.index


# placeholder graph for motifs (not used): motif nodes still need an atom graph to point at
smiles = "C(c1ccc(cc1)N)(=O)OCC"
motif_atom_graph_placeholder = featurizer.featurize([(smiles)])[0]

motif_atom = pd.DataFrame({"atom_graph": [motif_atom_graph_placeholder]})

motif_atom_index = 50000

# The placeholder must not share a row label with a real molecule; otherwise
# `molecules_graph_data.loc[motif_atom_index, ...]` silently returns several rows.
if motif_atom_index in molecules_graph_data.index:
    raise ValueError(
        f"motif_atom_index={motif_atom_index} collides with an existing molecule row label; "
        "choose a value larger than every molecule index"
    )

# set index to 50000 to avoid overlap with the molecules
motif_atom.index = (motif_atom_index,)

# add the motif to the dataframe of graphs
molecules_graph_data = pd.concat([molecules_graph_data, motif_atom])

# convert to dgl graph which are used in the model later (a self loop is added to every graph)
dgl_g = [dgl.add_self_loop(graph.to_dgl_graph()) for graph in molecules_graph_data["atom_graph"]]
molecules_graph_data["dgl_graph"] = dgl_g



In [None]:



## gets the motifs from the molecules
def get_motifs(mol, motif_dict, ignore_dict=False):
    """Collect ring and acyclic-bond motifs of an RDKit molecule into ``motif_dict``.

    Motif keys are strings of the form ``str((sorted_atom_symbols, bond_label))``:

    * one per ring returned by ``rdmolops.GetSymmSSSR(mol)``. The bond label is
      ``"AROMATIC"`` if any ring bond is aromatic, ``"MIXED"`` if the ring has
      more than one bond type, otherwise that single bond type;
    * one per bond that is not in a ring, labelled with its bond type.

    Args:
        mol: RDKit molecule.
        motif_dict: dict that is updated in place and returned.
        ignore_dict: when True, every motif found is stored with value 1
            (presence only; new keys are added). When False, only keys that
            already exist in ``motif_dict`` are incremented, once per
            occurrence; unknown motifs are skipped.

    Returns:
        ``motif_dict`` (the same object).
    """
    motifs = motif_dict

    def record(motif_key):
        # Single update rule shared by ring and bond motifs.
        if ignore_dict:
            motifs[motif_key] = 1
        elif motif_key in motifs:
            motifs[motif_key] += 1

    for ring in rdmolops.GetSymmSSSR(mol):
        ring = list(ring)
        ring_size = len(ring)
        atom_symbols = tuple(sorted(mol.GetAtomWithIdx(idx).GetSymbol() for idx in ring))
        # assumes consecutive atoms of an SSSR ring (wrapping around) are bonded
        bond_types = {
            str(mol.GetBondBetweenAtoms(ring[i], ring[(i + 1) % ring_size]).GetBondType())
            for i in range(ring_size)
        }
        if "AROMATIC" in bond_types:
            bond_type = "AROMATIC"
        elif len(bond_types) > 1:
            bond_type = "MIXED"
        else:
            bond_type = next(iter(bond_types))
        record(str((atom_symbols, bond_type)))

    # Add bonds to the motifs (only bonds outside rings; ring bonds are covered above)
    for bond in mol.GetBonds():
        if bond.IsInRing():
            continue
        start_atom_symbol = mol.GetAtomWithIdx(bond.GetBeginAtomIdx()).GetSymbol()
        end_atom_symbol = mol.GetAtomWithIdx(bond.GetEndAtomIdx()).GetSymbol()
        atom_symbols = tuple(sorted([start_atom_symbol, end_atom_symbol]))
        record(str((atom_symbols, str(bond.GetBondType()))))

    return motifs

## process motifs for a single molecule
def process_single_molecule(smiles, motif_dict, ignore_dict=False):
    """Parse ``smiles`` with RDKit and return ``get_motifs`` for it.

    Returns an empty dict (after printing a message) when RDKit cannot parse
    the SMILES string.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        print(f"Failed to process molecule with SMILES: {smiles}")
        return {}

    return get_motifs(molecule, motif_dict, ignore_dict)



In [None]:
# motif count across all molecules: for every motif, the number of molecules that contain it
# (presence mode, ignore_dict=True, so each molecule contributes at most 1 per motif).
motif_count_dict = {}
for _, mol_row in molecules.iterrows():
    present_motifs = process_single_molecule(mol_row["SMILES"], {}, ignore_dict=True)
    for motif_key in present_motifs:
        motif_count_dict[motif_key] = motif_count_dict.get(motif_key, 0) + 1

# keep only motifs found in more than 5 molecules (counts <= 5 are dropped)
motif_count_dict_with_more_than_5 = {
    motif_key: count
    for motif_key, count in motif_count_dict.items()
    if count > 5
}


In [None]:



# One-hot encode the retained motifs: one row per motif, one float 0/1 column per motif.
motif_bag = pd.get_dummies(motif_count_dict_with_more_than_5.keys(), dtype=float)

# Zero-initialised motif counter with its keys in sorted order, so every molecule's
# motif vector uses the same key order.
motif_dict = {motif_key: 0 for motif_key in sorted(motif_count_dict_with_more_than_5)}

G = nx.Graph()

# One node per motif, named after the column that is 1 in its one-hot row;
# the node attributes are that whole one-hot row.
for _, one_hot_row in motif_bag.iterrows():
    features = one_hot_row.to_dict()
    hot_column = next((column for column, flag in features.items() if flag == 1.0), None)
    if hot_column is None:
        continue
    G.add_node(hot_column, **features)
    # motif nodes point at the placeholder atom graph row
    G.nodes[hot_column]["atom_graph"] = motif_atom_index
    G.nodes[hot_column]['node_type'] = 'motif'

In [None]:
# Add one node per molecule, named by its SMILES; its first attributes are the
# per-motif counts (keys of motif_dict, in the same order).
for _, mol_row in molecules.iterrows():
    smiles_key = mol_row['SMILES']
    motif_counts = process_single_molecule(smiles_key, motif_dict.copy())
    # unparsable SMILES come back as an empty dict: no node for them
    if not motif_counts:
        continue
    G.add_node(smiles_key, **motif_counts)
    node_attrs = G.nodes[smiles_key]
    node_attrs['atom_graph'] = mol_row["atom_graph"]
    node_attrs['node_type'] = 'molecule'
    node_attrs['Approval'] = mol_row["Vores_approval"]



def tf_idf(molecule, motif, molecules_count):
    """TF-IDF weight of ``motif`` in the molecule node ``molecule`` of ``G``.

    tf  = count of ``motif`` in the molecule / total count of all motifs in it
    idf = log(molecules_count / number of molecules that contain ``motif``)

    Args:
        molecule: name (SMILES) of a molecule node in ``G``.
        motif: motif key, present in ``motif_count_dict_with_more_than_5``.
        molecules_count: total number of molecules (the TF-IDF "documents").

    Returns:
        tf * idf as a float.
    """
    document_frequency = motif_count_dict_with_more_than_5[motif]
    node_attrs = G.nodes[molecule]
    # Sum over the motif keys only. The former positional slice
    # `list(values())[:-3]` silently summed wrong values as soon as the node
    # carried an extra attribute (e.g. "x"/"y" after re-running later cells).
    total_motif_count = sum(node_attrs[key] for key in motif_dict)
    tf = node_attrs[motif] / total_motif_count
    idf = np.log(molecules_count / document_frequency)
    return tf * idf









In [None]:
# For each molecule, add an edge to every motif it contains (its count for that
# motif is > 0); the edge weight is the TF-IDF of the motif in that molecule.
# Node lists are computed once, in G's node order, so edges are inserted in the
# same order as a full double scan of G.nodes() would produce. The former
# per-pair debug prints (several lines per molecule x motif pair) are gone.
molecule_names = [name for name, node_type in G.nodes(data="node_type") if node_type == "molecule"]
motif_names = [name for name, node_type in G.nodes(data="node_type") if node_type == "motif"]

for molecule in molecule_names:
    molecule_attrs = G.nodes[molecule]
    for motif in motif_names:
        if molecule_attrs.get(motif, 0) > 0:
            G.add_edge(motif, molecule, weight=tf_idf(molecule, motif, len(molecules)))
                
                

In [None]:
# create a co-occurence matrix for the pmi values (motif x motif, zero-initialised)
cooccurence_matrix = pd.DataFrame(0, index=motif_bag.columns, columns=motif_bag.columns, dtype=int)

# Dedicated names for the node lists: re-binding `molecules` here replaced the
# molecules DataFrame with a list and broke any re-run of the earlier cells.
molecule_node_names = [node for node in G.nodes() if G.nodes[node]['node_type'] == 'molecule']
motif_node_names = [node for node in G.nodes() if G.nodes[node]['node_type'] == 'motif']

# Off-diagonal: number of molecules linked to both motifs.
for molecule in molecule_node_names:
    connected_motifs = [motif for motif in motif_node_names if G.has_edge(motif, molecule)]
    for i, motif in enumerate(connected_motifs):
        for motif_2 in connected_motifs[i + 1:]:
            cooccurence_matrix.loc[motif, motif_2] += 1
            cooccurence_matrix.loc[motif_2, motif] += 1

# Diagonal: number of molecules that contain each motif.
for motif_key, count in motif_count_dict_with_more_than_5.items():
    cooccurence_matrix.loc[motif_key, motif_key] = count

# counts -> probabilities: divide each entry by the number of molecule nodes
cooccurence_matrix = cooccurence_matrix.div(len(molecule_node_names))


def pmi_value(motif_1, motif_2, cooccurence_matrix):
    """Pointwise mutual information log(p(m1, m2) / (p(m1) * p(m2))).

    Joint probabilities are read off-diagonal and marginals on the diagonal of
    ``cooccurence_matrix``. Returns -inf when the two motifs never co-occur.
    """
    joint = cooccurence_matrix.loc[motif_1, motif_2]
    marginal_product = cooccurence_matrix.loc[motif_1, motif_1] * cooccurence_matrix.loc[motif_2, motif_2]
    return np.log(joint / marginal_product)


# Add motif-motif edges weighted by PMI, only where PMI > 0 (pairs with PMI <= 0
# or NaN are skipped). PMI is computed once per pair; log(0) for motifs that
# never co-occur only yields -inf, so its divide-by-zero warning is silenced.
with np.errstate(divide="ignore"):
    for i, motif in enumerate(motif_node_names):
        for motif_2 in motif_node_names[i + 1:]:
            pmi = pmi_value(motif, motif_2, cooccurence_matrix)
            if pmi > 0:
                G.add_edge(motif, motif_2, weight=pmi)




In [None]:
# Every node needs an "Approval" label for the tensor conversion below; motif
# nodes get the dummy value -1 (never used as a training target).
for _, node_attrs in G.nodes(data=True):
    if node_attrs["node_type"] != "motif":
        continue
    node_attrs["Approval"] = -1


In [None]:
# convert features to tensors and store them as x and y in the graph
# Features are every attribute except the bookkeeping keys below: motif counts for
# molecule nodes, the one-hot row for motif nodes. Selecting by key (instead of
# the previous positional `[:-3]` slice) keeps this cell correct when re-run,
# after "x" and "y" have already been attached to the nodes.
non_feature_keys = {"atom_graph", "node_type", "Approval", "x", "y"}
for i, attr_ in G.nodes(data=True):
    features = [value for key, value in attr_.items() if key not in non_feature_keys]
    # float32 explicitly: molecule counts are ints and motif one-hot values are
    # floats, and the motif GNN layers need one float dtype for every node.
    G.nodes[i]["x"] = torch.tensor(features, dtype=torch.float)
    # label: the "Approval" attribute (-1 for motif nodes)
    G.nodes[i]["y"] = torch.tensor(attr_["Approval"])

In [None]:
# create Data object for training
# Node id in `data` = position of the node in G's node order.
node_features = []
node_labels = []
mask = []  # True for molecule nodes (prediction targets), False for motif nodes
atom_graphs = []  # row label in molecules_graph_data (motif nodes: the placeholder row)

for _, attr in G.nodes(data=True):
    node_features.append(attr["x"])
    node_labels.append(attr["y"])
    mask.append(attr["node_type"] == "molecule")
    atom_graphs.append(attr["atom_graph"])

# node name -> node id (covers molecule AND motif nodes, despite the name)
motif_indexes = {name: i for i, name in enumerate(G.nodes)}

# G is undirected, but a PyG edge_index is directed (source -> target). Store each
# edge in both directions with the same weight, as
# torch_geometric.utils.from_networkx does; with one direction only, messages
# flowed along just the orientation networkx happened to report for each edge.
edge_index = [[], []]
edge_weights = []
for u, v, weight in G.edges(data="weight"):
    src, dst = motif_indexes[u], motif_indexes[v]
    edge_index[0] += [src, dst]
    edge_index[1] += [dst, src]
    edge_weights += [weight, weight]

node_features = torch.stack(node_features)
node_labels = torch.tensor(node_labels)
edge_index = torch.tensor(edge_index, dtype=torch.long)
edge_weights = torch.tensor(edge_weights, dtype=torch.float)
mask = torch.tensor(mask, dtype=torch.bool)
atom_graphs = torch.tensor(atom_graphs, dtype=torch.long)


data = Data(x=node_features, edge_index=edge_index, y=node_labels, edge_weight=edge_weights, atom_graphs=atom_graphs)

split_sizes = (0.7, 0.2, 0.1)

# only use molecules for training, validation and testing, motifs are not used except for the motif graph
molecule_nodes = torch.nonzero(mask, as_tuple=False).view(-1)
num_mol_nodes = molecule_nodes.size(0)

train_size = int(split_sizes[0] * num_mol_nodes)
val_size = int(split_sizes[1] * num_mol_nodes)
# informational only: the test split below takes every remaining molecule
test_size = int(split_sizes[2] * num_mol_nodes)

shuffled_mol_indices = molecule_nodes[torch.randperm(num_mol_nodes)]
train_idx = shuffled_mol_indices[:train_size]
val_idx = shuffled_mol_indices[train_size: train_size + val_size]
test_idx = shuffled_mol_indices[train_size + val_size:]


data.train_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
data.val_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
data.test_mask = torch.zeros(data.num_nodes, dtype=torch.bool)


data.train_mask[train_idx] = True
data.val_mask[val_idx] = True
data.test_mask[test_idx] = True

In [None]:

class SimpleGNN(torch.nn.Module):
    """Stack of torch_geometric ``GCNConv`` layers applied to the motif graph.

    Args:
        in_feats: size of the input node features.
        hidden_feats: output size of each layer (one entry per layer); the last
            entry is the output size of the model.
        gnn_norm: whether each ``GCNConv`` normalises the adjacency
            (``normalize=``). Accepts a bool, a string (``"None"``/``"none"``
            means no normalisation, any other string such as ``"Both"`` means
            normalise), or a list of those with one entry per layer (a
            single-entry list applies to every layer).
        activation: list of callables; ``activation[i]`` follows layer ``i`` and
            ``activation[-1]`` follows the last layer.
        dropout: list of ``nn.Dropout`` modules or dropout probabilities,
            applied after every layer except the last (before the activation).
            An empty list or ``None`` disables dropout.
    """

    def __init__(self, in_feats, hidden_feats, gnn_norm, activation, dropout):
        super(SimpleGNN, self).__init__()
        if len(hidden_feats) == 0:
            raise ValueError("hidden_feats must contain at least one layer size")

        self.activation = activation
        # nn.ModuleList registers the dropout layers as sub-modules, so
        # model.train()/model.eval() switch them on and off. In a plain list
        # they stayed in training mode during validation.
        self.dropout = nn.ModuleList(
            [d if isinstance(d, nn.Module) else nn.Dropout(d) for d in (dropout or [])]
        )
        self.res_feats = []
        # nn.ModuleList (not a plain list) so the conv weights appear in
        # .parameters() and are actually optimised and moved with .to(device).
        self.convs = nn.ModuleList()

        input_dim = in_feats
        for hidden_dim, normalize in zip(hidden_feats, self._norm_flags(gnn_norm, len(hidden_feats))):
            self.convs.append(GCNConv(input_dim, hidden_dim, normalize=normalize))
            input_dim = hidden_dim

    @staticmethod
    def _norm_flags(gnn_norm, n_layers):
        """Return one bool per layer for ``GCNConv(normalize=...)``.

        A list such as ``[False, False]`` used to be passed straight to
        ``normalize``; being a non-empty list it counted as True, so the
        un-normalised variant was silently normalised.
        """
        def to_bool(flag):
            if isinstance(flag, str):
                return flag.strip().lower() != "none"
            return bool(flag)

        if isinstance(gnn_norm, (list, tuple)):
            if len(gnn_norm) == 1:
                return [to_bool(gnn_norm[0])] * n_layers
            if len(gnn_norm) != n_layers:
                raise ValueError(f"gnn_norm has {len(gnn_norm)} entries, expected 1 or {n_layers}")
            return [to_bool(flag) for flag in gnn_norm]
        return [to_bool(gnn_norm)] * n_layers

    def forward(self, x, edge_index, edge_weight):
        """conv -> dropout -> activation for every layer but the last; conv -> activation for the last."""
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, edge_index, edge_weight=edge_weight)
            if self.dropout:
                x = self.dropout[i](x)
            x = self.activation[i](x)
        x = self.convs[-1](x, edge_index, edge_weight=edge_weight)
        x = self.activation[-1](x)
        return x


In [None]:

class SimpleGAT(torch.nn.Module):
    """Stack of torch_geometric ``GATConv`` layers applied to the motif graph.

    Args:
        in_feats: size of the input node features.
        hidden_feats: per-head output size of each layer (one entry per layer).
        heads: number of attention heads per layer.
        activation: list of callables; ``activation[i]`` follows layer ``i``,
            ``activation[-1]`` follows the last layer.
        agg_modes: head aggregation. If the first entry is ``'flatten'`` every
            layer concatenates its heads; if it is ``'mean'`` every layer
            averages them; otherwise the entries are used directly as each
            layer's ``concat`` flag.
        dropout: per-layer attention dropout probability passed to ``GATConv``.
    """

    def __init__(self, in_feats, hidden_feats, heads, activation, agg_modes, dropout):
        super(SimpleGAT, self).__init__()

        if agg_modes[0] == 'flatten':
            agg_modes = [True] * len(hidden_feats)
        elif agg_modes[0] == 'mean':
            agg_modes = [False] * len(hidden_feats)

        self.activation = activation
        self.dropout = dropout

        # nn.ModuleList (not a plain list) so the attention layers are registered:
        # their weights are optimised, saved and moved with .to(device).
        self.convs = nn.ModuleList()

        input_dim = in_feats
        for i, hidden_dim in enumerate(hidden_feats[:-1]):
            self.convs.append(GATConv(input_dim, hidden_dim, heads=heads[i], concat=agg_modes[i], dropout=dropout[i]))
            # a concatenating layer outputs hidden_dim features per head
            input_dim = hidden_dim * heads[i] if agg_modes[i] else hidden_dim

        self.convs.append(GATConv(input_dim, hidden_feats[-1], heads=heads[-1], concat=agg_modes[-1], dropout=dropout[-1]))

    def forward(self, x, edge_index):
        """conv -> activation for every layer; edge weights are not used."""
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, edge_index)
            x = self.activation[i](x)
        x = self.convs[-1](x, edge_index)
        x = self.activation[-1](x)

        # NOTE(review): squeeze() also drops the node dimension for a 1-node batch
        return x.squeeze()



In [6]:
# function to batch the dgl graphs during training
def indexes_to_batch_dgl_graphs(indexes, batch):
    """Batch the DGL atom graphs of every node in a PyG mini-batch into one graph.

    The row labels are always read from ``batch.atom_graphs`` (the ``indexes``
    argument is not used) and looked up in the "dgl_graph" column of
    ``molecules_graph_data``.
    """
    row_labels = batch.atom_graphs.tolist()
    graph_column = molecules_graph_data["dgl_graph"]
    return dgl.batch([graph_column.loc[label] for label in row_labels])

# function to save the results of the trials to a json file during hyperparameter optimization
def save_trial_results(stopped_early, folder_path, trial, best_val_loss, epoch_train_predict_values, epoch_train_targets_values, epoch_val_predict_values, epoch_val_targets_values, all_train_loss, all_val_loss, epochs, test_used=False, epoch_test_predict_values=None, epoch_test_targets_values=None, test_loss=None):
    """Write the history of one Optuna trial to ``<folder_path>/trial_<trial.number>.json``.

    ``best_val_loss`` is a ``(loss, epoch)`` pair; ``stopped_early`` is stored
    as given (the callers pass "Pruned" or "Finished"). The three test entries
    are appended only when ``test_used`` is true. Keys are written exactly as
    listed, including the "trail_number" and "stoped_early" spellings.
    """
    record = dict(
        trail_number=trial.number,
        trial_params=trial.params,
        best_val_loss=best_val_loss[0],
        best_epoch=best_val_loss[1],
        epoch_train_predict_values=epoch_train_predict_values,
        epoch_train_targets_values=epoch_train_targets_values,
        epoch_val_predict_values=epoch_val_predict_values,
        epoch_val_targets_values=epoch_val_targets_values,
        all_train_loss=all_train_loss,
        all_val_loss=all_val_loss,
        stoped_early=stopped_early,
        epochs=epochs,
    )
    if test_used:
        record.update(
            test_predict_values=epoch_test_predict_values,
            test_targets_values=epoch_test_targets_values,
            test_loss=test_loss,
        )

    target_file = os.path.join(folder_path, f"trial_{trial.number}.json")
    with open(target_file, "w") as handle:
        json.dump(record, handle)



In [None]:

class Combined_model(torch.nn.Module):
    """Node classifier combining an atom-level GNN with an optional motif-graph GNN.

    Every node of a NeighborLoader mini-batch has its own DGL atom graph in
    ``batch_graph``. The atom model plus a WeightedSumAndMax readout gives one
    vector per graph (i.e. per node). When ``include_motif`` is set, that vector
    is concatenated with the motif model's output for the same node. A 2-layer
    MLP with a sigmoid then produces one probability per node.

    The architecture follows the module-level ``model_type`` ("GNN", "GCN" or
    "GAT"); the motif model's input size is read from the module-level ``data``.

    Args:
        motif_param_dict: hyper-parameters of the motif model (unused when
            ``include_motif`` is False).
        atom_param_dict: hyper-parameters of the atom-level model.
        concat_size: input size of the MLP (readout size [+ motif output size]).
        MLPsize: hidden size of the MLP.
        include_motif: whether the motif-graph model is built and used.
    """

    def __init__(self, motif_param_dict, atom_param_dict, concat_size, MLPsize, include_motif):
        super(Combined_model, self).__init__()
        self.nfeat_name = "x"
        self.include_motif = include_motif
        in_feats_motif = data.num_node_features
        # NOTE(review): hard-coded atom feature size of the DGL graphs; presumably the
        # MolGraphConvFeaturizer atom feature length — confirm if the featurizer changes.
        in_feats_atom = 30

        if model_type in ("GNN", "GCN"):
            # "GNN" and "GCN" share this architecture; they differ only in the
            # gnn_norm entries of the param dicts.
            if include_motif:
                self.motif_model = SimpleGNN(in_feats=in_feats_motif,
                                             hidden_feats=motif_param_dict["hidden_channels"],
                                             gnn_norm=motif_param_dict["gnn_norm"],
                                             activation=motif_param_dict["activation_function"],
                                             dropout=motif_param_dict["dropout"])

            self.atom_lvl_model = GCN(in_feats=in_feats_atom,
                                      hidden_feats=atom_param_dict["hidden_channels"],
                                      gnn_norm=atom_param_dict["gnn_norm"],
                                      activation=atom_param_dict["activation_function"],
                                      dropout=atom_param_dict["dropout"],
                                      allow_zero_in_degree=True)

        elif model_type == "GAT":
            if include_motif:
                self.motif_model = SimpleGAT(in_feats=in_feats_motif,
                                             hidden_feats=motif_param_dict["hidden_channels"],
                                             heads=motif_param_dict["heads"],
                                             activation=motif_param_dict["activation_function"],
                                             agg_modes=motif_param_dict["GAT_agg_modes"],
                                             dropout=motif_param_dict["dropout"])

            self.atom_lvl_model = GAT(in_feats=in_feats_atom,
                                      hidden_feats=atom_param_dict["hidden_channels"],
                                      num_heads=atom_param_dict["heads"],
                                      activations=atom_param_dict["activation_function"],
                                      agg_modes=atom_param_dict["GAT_agg_modes"],
                                      feat_drops=atom_param_dict["dropout"])

        # Readout input size = output size of the last atom-model layer, multiplied
        # by its number of heads when a GAT layer concatenates ("flatten") them.
        readout_size = atom_param_dict["hidden_channels"][-1]
        if model_type == "GAT" and atom_param_dict["GAT_agg_modes"][-1] == "flatten":
            readout_size *= atom_param_dict["heads"][-1]
        self.aggr = WeightedSumAndMax(readout_size)

        self.linear1 = torch.nn.Linear(concat_size, MLPsize)
        self.linear2 = torch.nn.Linear(MLPsize, 1)

    def forward(self, x, edge_index, edge_weight=None, batch_graph=None):
        """Return one probability per node of the mini-batch (1-D tensor).

        ``edge_weight`` is only used by the GNN/GCN motif model.
        """
        node_feats = batch_graph.ndata[self.nfeat_name]
        atom_repr = self.aggr(batch_graph, self.atom_lvl_model(batch_graph, node_feats))

        if self.include_motif:
            if model_type == "GAT":
                motif_repr = self.motif_model(x, edge_index)
            else:
                motif_repr = self.motif_model(x, edge_index, edge_weight)
            # SimpleGAT squeezes its output, which drops the node dimension of a
            # 1-node batch; restore it so the concatenation below still works.
            if motif_repr.dim() == 1:
                motif_repr = motif_repr.unsqueeze(0)
            hidden = torch.cat((atom_repr, motif_repr), dim=1)
        else:
            hidden = atom_repr

        hidden = F.relu(self.linear1(hidden))
        out = torch.sigmoid(self.linear2(hidden))
        # squeeze only the size-1 output column: a plain squeeze() turned a
        # 1-node batch into a 0-d tensor that boolean masks cannot index.
        return out.squeeze(-1)


def objective(trial):
    """Optuna objective: sample hyper-parameters, train a Combined_model, return the best validation loss.

    The validation loss of an epoch is the sum of the per-batch BCE losses over
    the validation loader. Relies on module-level names defined in earlier cells
    (model_type, motif_used, test_used, data, folder_path, Combined_model, ...).
    The trial history is written to ``folder_path`` via save_trial_results, also
    when the trial is pruned (optuna.TrialPruned is then raised).
    """

    # function to create the hyperparameters for the model
    def create_params():
        # Only sampled when the motif graph is used. It used to be left unbound
        # otherwise, so the `return` below raised UnboundLocalError whenever
        # motif_used was False.
        amount_neighbors = None

        # optimizer and sampler
        if motif_used:
            amount_neighbors = trial.suggest_categorical("amount_neighbors", [5, 10, 50, 100, 200])
        epochs = trial.suggest_int("epochs", 10, 100)
        lr = trial.suggest_float("lr", 0.0001, 0.1)
        weight_decay = trial.suggest_float("weight_decay", 0.0, 0.1)
        batch_size = trial.suggest_categorical("batch_size", [8, 16, 32, 64, 128, 256, 512, 1024])

        # last MLP
        MLP_size = trial.suggest_categorical("MLP_size", [8, 16, 32, 64, 128, 256, 512])

        # "GNN" turns graph-convolution normalisation off, "GCN" turns it on
        if model_type == "GNN":
            gnn_norm_motif = [False]
            gnn_norm_atom = ["none"]
        elif model_type == "GCN":
            gnn_norm_motif = [True]
            gnn_norm_atom = ["both"]

        if motif_used:
            motif_n_layers = trial.suggest_int("motif_n_layers", 1, 8)
            motif_hidden_channels = []
            for i in range(motif_n_layers):
                motif_hidden_channels.append(trial.suggest_categorical(f"motif_hidden_channels_{i+1}", [8, 16, 32, 64, 128, 256, 512]))

            motif_activations = [F.relu] * motif_n_layers

            if model_type == "GNN" or model_type == "GCN":
                motif_param_dict = {
                    "n_layers": motif_n_layers,
                    "activation_function": motif_activations,
                    "gnn_norm": gnn_norm_motif * motif_n_layers,
                    "hidden_channels": motif_hidden_channels,
                    "dropout": [nn.Dropout(trial.suggest_float("motif_dropout", 0.0, 0.5))] * motif_n_layers,
                }
            elif model_type == "GAT":
                motif_param_dict = {
                    "n_layers": motif_n_layers,
                    "activation_function": motif_activations,
                    "hidden_channels": motif_hidden_channels,
                    "dropout": [trial.suggest_float("motif_dropout", 0.0, 0.5)] * motif_n_layers,
                    "residual": [trial.suggest_categorical("motif_residual", [True, False])] * motif_n_layers,
                    "heads": [trial.suggest_int("motif_heads", 1, 3)] * motif_n_layers,
                    "GAT_agg_modes": [trial.suggest_categorical("motif_GAT_agg_modes", ["flatten", 'mean'])] * motif_n_layers  # flatten is concat for this model
                }
        else:
            motif_param_dict = {}

        # atom model params
        atom_n_layers = trial.suggest_int("atom_n_layers", 1, 8)
        atom_hidden_channels = []
        for i in range(atom_n_layers):
            atom_hidden_channels.append(trial.suggest_categorical(f"atom_hidden_channels_{i+1}", [8, 16, 32, 64, 128, 256, 512]))

        atom_activations = [F.relu] * atom_n_layers

        if model_type == "GNN" or model_type == "GCN":
            atom_param_dict = {
                "n_layers": atom_n_layers,
                "activation_function": atom_activations,
                "gnn_norm": gnn_norm_atom * atom_n_layers,
                "hidden_channels": atom_hidden_channels,
                "dropout": [trial.suggest_float("atom_dropout", 0.0, 0.5)] * atom_n_layers,
            }
        elif model_type == "GAT":
            atom_param_dict = {
                "n_layers": atom_n_layers,
                "activation_function": atom_activations,
                "hidden_channels": atom_hidden_channels,
                "dropout": [trial.suggest_float("atom_dropout", 0.0, 0.5)] * atom_n_layers,
                "residual": [trial.suggest_categorical("atom_residual", [True, False])] * atom_n_layers,
                "heads": [trial.suggest_int("atom_heads", 1, 3)] * atom_n_layers,
                "GAT_agg_modes": [trial.suggest_categorical("atom_GAT_agg_modes", ["flatten", 'mean'])] * atom_n_layers  # flatten is concat for this model
            }

        # Final MLP input size (concat size): the WeightedSumAndMax readout of the
        # atom model (2x its output size) [+ the motif model output]. A GAT that
        # concatenates ("flatten") its heads multiplies its output size by the
        # number of heads.
        atom_out_size = atom_hidden_channels[-1]
        if model_type == "GAT" and atom_param_dict["GAT_agg_modes"][0] == "flatten":
            atom_out_size *= atom_param_dict["heads"][0]
        concat_size = atom_out_size * 2
        if motif_used:
            motif_out_size = motif_hidden_channels[-1]
            if model_type == "GAT" and motif_param_dict["GAT_agg_modes"][0] == "flatten":
                motif_out_size *= motif_param_dict["heads"][0]
            concat_size += motif_out_size

        return motif_param_dict, atom_param_dict, concat_size, lr, weight_decay, epochs, batch_size, MLP_size, amount_neighbors

    # get the hyperparameters
    motif_param_dict, atom_param_dict, concat_size, lr, decay, epochs, batch_size, MLP_size, amount_neighbors = create_params()

    if motif_used:
        # first hop: all neighbours, then `amount_neighbors` per hop for each motif layer
        max_range_for_neighbor_sampler = [-1] + motif_param_dict["n_layers"] * [amount_neighbors]
    else:
        # if motif is not used, the motif graph is not used: sample no neighbours
        max_range_for_neighbor_sampler = [0]

    # class-balanced sampling of the training molecules
    sampler1 = ImbalancedSampler(data, data.train_mask)

    # neighbor loader for the training, validation and test data
    train_loader = NeighborLoader(data=data, sampler=sampler1, batch_size=batch_size, num_neighbors=max_range_for_neighbor_sampler, input_nodes=data.train_mask, weight_attr="edge_weight")
    val_loader = NeighborLoader(data=data, batch_size=batch_size, num_neighbors=max_range_for_neighbor_sampler, input_nodes=data.val_mask, weight_attr="edge_weight")
    test_loader = NeighborLoader(data=data, batch_size=batch_size, num_neighbors=max_range_for_neighbor_sampler, input_nodes=data.test_mask, weight_attr="edge_weight")

    # create the model
    model_combined = Combined_model(motif_param_dict, atom_param_dict, concat_size, MLP_size, include_motif=motif_used)

    # set the optimizer and loss function
    optimizer = torch.optim.Adam(model_combined.parameters(), lr=lr, weight_decay=decay)
    criterion = torch.nn.BCELoss()

    def run_model(batch):
        """Forward one NeighborLoader mini-batch (plus its batched atom graphs) through the model."""
        batch_graph = indexes_to_batch_dgl_graphs(batch.atom_graphs, batch)
        if model_type == "GAT":
            # the GAT motif model does not use edge weights
            return model_combined(batch.x, batch.edge_index, batch_graph=batch_graph)
        return model_combined(batch.x, batch.edge_index, batch.edge_weight, batch_graph)

    # all the lists to store the results (one entry per epoch)
    all_train_loss = []
    all_val_loss = []
    all_test_loss = []
    epoch_train_predict_values = []
    epoch_train_targets_values = []
    epoch_val_predict_values = []
    epoch_val_targets_values = []
    epoch_test_predict_values = []
    epoch_test_targets_values = []
    best_val_loss = (100000, 0)  # (loss, epoch)

    # NOTE(review): the batch masks select every train/val/test molecule present in a
    # sampled subgraph, not only the batch's seed nodes (its first `batch.batch_size`
    # rows), so a molecule reached as a neighbour may be counted several times per epoch.
    for epoch in range(epochs):
        # reset the epoch based values
        epoch_train_loss = 0
        epoch_val_loss = 0
        epoch_test_loss = 0
        epoch_train_p_values, epoch_train_t_values = [], []
        epoch_val_p_values, epoch_val_t_values = [], []
        epoch_test_p_values, epoch_test_t_values = [], []

        # iterate over the training data
        model_combined.train()
        for batch in train_loader:
            optimizer.zero_grad()
            out = run_model(batch)
            train_mask = batch.train_mask
            train_loss = criterion(out[train_mask], batch.y[train_mask].float())
            train_loss.backward()
            optimizer.step()

            epoch_train_p_values.extend(out[train_mask].tolist())
            epoch_train_t_values.extend(batch.y[train_mask].tolist())
            epoch_train_loss += train_loss.item()

        model_combined.eval()
        with torch.no_grad():
            # iterate over the test data, if test is used
            if test_used:
                for test_batch in test_loader:
                    test_out = run_model(test_batch)
                    test_mask = test_batch.test_mask
                    test_loss = criterion(test_out[test_mask], test_batch.y[test_mask].float())

                    epoch_test_p_values.extend(test_out[test_mask].tolist())
                    epoch_test_t_values.extend(test_batch.y[test_mask].tolist())
                    epoch_test_loss += test_loss.item()

            # iterate over the validation data
            for val_batch in val_loader:
                val_out = run_model(val_batch)
                val_mask = val_batch.val_mask
                val_loss = criterion(val_out[val_mask], val_batch.y[val_mask].float())

                epoch_val_loss += val_loss.item()
                epoch_val_p_values.extend(val_out[val_mask].tolist())
                epoch_val_t_values.extend(val_batch.y[val_mask].tolist())

            # send feedback to optuna
            trial.report(epoch_val_loss, epoch)
            if trial.should_prune():
                # save the results if the trial is pruned
                save_trial_results("Pruned", folder_path, trial, best_val_loss, epoch_train_predict_values, epoch_train_targets_values, epoch_val_predict_values, epoch_val_targets_values, all_train_loss, all_val_loss, epochs)
                raise optuna.TrialPruned()

            # update the best validation loss during training
            if epoch_val_loss < best_val_loss[0]:
                print(f"New best val_loss: {epoch_val_loss}, epoch: {epoch}, old best val_loss: {best_val_loss[0]}, epoch: {best_val_loss[1]}")
                best_val_loss = (epoch_val_loss, epoch)

        # append the loss values to the lists for each epoch
        all_train_loss.append(epoch_train_loss)
        all_val_loss.append(epoch_val_loss)
        epoch_train_predict_values.append(epoch_train_p_values)
        epoch_train_targets_values.append(epoch_train_t_values)
        epoch_val_predict_values.append(epoch_val_p_values)
        epoch_val_targets_values.append(epoch_val_t_values)
        if test_used:
            all_test_loss.append(epoch_test_loss)
            epoch_test_predict_values.append(epoch_test_p_values)
            epoch_test_targets_values.append(epoch_test_t_values)

    # test entries are written only when the test split was actually evaluated
    save_trial_results("Finished", folder_path, trial, best_val_loss, epoch_train_predict_values, epoch_train_targets_values, epoch_val_predict_values, epoch_val_targets_values, all_train_loss, all_val_loss, epochs, test_used=test_used, epoch_test_predict_values=epoch_test_predict_values, epoch_test_targets_values=epoch_test_targets_values, test_loss=all_test_loss)

    return best_val_loss[0]
        




# Timestamp that makes the study name (and its result folder) unique.
now = datetime.datetime.now().strftime("%H-%M-%S_%d-%m-%Y")
study_name = f"study_{now}"

# objective() writes one JSON per trial here; the final CSV goes here too.
folder_path = os.path.join(f"{model_type}_motif={motif_used}", study_name)
os.makedirs(folder_path, exist_ok=True)

# percentile=25.0 keeps the best 25% of trials (prunes the other 75%); pruning
# starts after 30 trials, after 10 warm-up epochs, and is checked every 10 epochs.
pruner = optuna.pruners.PercentilePruner(
    percentile=25.0,
    n_startup_trials=30,
    n_warmup_steps=10,
    interval_steps=10,
)

# Define the study (optuna object), stored in a local SQLite database.
study = optuna.create_study(
    storage="sqlite:///db.sqlite3",
    study_name=study_name,
    direction="minimize",
    pruner=pruner,
)

study.optimize(objective, n_trials=500, show_progress_bar=True)

print(f"Best value: {study.best_value} (params: {study.best_params})")

# Save all trials of the study as CSV next to the per-trial JSON files.
trials_csv_name = f"trials_{study_name}_{model_type}_{motif_used}.csv"
output_csv_path = os.path.join(folder_path, trials_csv_name)
study.trials_dataframe().to_csv(output_csv_path)
print(f"Study results saved to: {output_csv_path}")




