# Train GNN-Siamese model

Uses Optuna hyper-parameter optimization to find the best GNN-Siamese architecture,
then trains and saves the best model.

--------------------

### Imports

In [None]:
import os

import optuna
import torch
from torch.utils.data import random_split

import src.models.GNN_Siamese
# NOTE(review): `import src.models.GNN_Siamese` does not make `GNN_Siamese` importable as a
# top-level module; unless it is on sys.path this likely needs `from src.models.GNN_Siamese import ...`.
from GNN_Siamese import DifferentialSpearmanCorrCoef, GNNSiamese
# NOTE(review): the cells below also use `nn` (torch_geometric.nn pooling functions),
# `DataLoader` (torch_geometric.loader) and `joblib`, which are not imported here.

---------------

### Define the hyper-parameter ranges for the Optuna trials


In [None]:
# Search space for the Optuna study. Ranges are [min, max]; lists are categorical candidates.
# NOTE(review): the activation, criterion and readout candidates are objects, not primitives;
# Optuna warns about this, and such choices only work with in-memory study storage.
num_layers = [1, 3]  # Number of graph-convolution layers: [min, max]
batch_size = [10, 20, 30, 40, 50]  # Candidate batch sizes
lr = [1e-4, 1e-2]  # Learning rate: lr[0] is the min, lr[1] the max (sampled on a log scale)
activation_name = [torch.nn.Tanh(), torch.nn.ReLU(), torch.nn.Sigmoid()]  # Candidate activation functions
criterion_name = [torch.nn.MSELoss(), torch.nn.L1Loss()]  # Candidate loss functions
num_epochs = [5, 50]  # Number of training epochs: [min, max]
gat_heads = [1, 2, 3, 4, 5]  # Candidate numbers of GAT attention heads
readout_layer = [nn.global_mean_pool, nn.global_max_pool, nn.global_add_pool]  # Candidate pooling (readout) functions
num_trails = 50  # Number of Optuna trials to run (name kept as-is: the main cell reads it)

### Define auxiliary functions

- Optuna objective: trains and evaluates one trial's configuration, sampled from the ranges defined above, and returns its validation loss
- GNN-Siamese training function: trains the full GNN-Siamese model using the best hyper-parameters found by the Optuna study
- Prepare training data: builds the PyTorch Geometric train/validation DataLoaders used for training

In [None]:
def objective(trial, num_layers, batch_size, lr, activation_name, criterion_name, num_epochs,
              gat_heads, readout_layer, train_loader1, val_loader1, train_loader2, val_loader2):
    """Train one Optuna-sampled GNN-Siamese configuration and return its validation loss.

    Args:
        trial: the ``optuna.Trial`` being evaluated.
        num_layers: ``[min, max]`` number of graph-convolution layers.
        batch_size: candidate batch sizes (samples per step; each sample is ``n_graphs`` graphs).
        lr: ``[min, max]`` learning rate, sampled on a log scale.
        activation_name: candidate activation functions.
        criterion_name: candidate loss functions.
        num_epochs: ``[min, max]`` number of training epochs.
        gat_heads: candidate numbers of GAT attention heads.
        readout_layer: candidate graph pooling functions.
        train_loader1, val_loader1, train_loader2, val_loader2: torch_geometric DataLoaders for
            the two Siamese branches. Only their ``.dataset`` is reused. They are re-batched
            here so that every batch holds ``batch_size * n_graphs`` graphs for the *sampled*
            ``batch_size``.

    Returns:
        Mean validation loss over the validation batches, or ``inf`` if there are none.

    NOTE(review): ``n_graphs`` and ``device`` are read from notebook-level globals that are
    not defined in the visible cells; define them before running the study.
    """
    num_node_features = 13

    # Strictly decreasing hidden widths. Layer 0 lies in [9, 12]. Every later layer is
    # narrower than the previous one, and its lower bound drops by 2 per layer, so the
    # range can never be empty.
    n_layers = trial.suggest_int('num_conv_hidden_layers', num_layers[0], num_layers[1])
    conv_hidden_channels = []
    min_val = 9
    for i in range(n_layers):
        max_val = 12 if i == 0 else conv_hidden_channels[-1] - 1
        if max_val < min_val:
            raise ValueError("Cannot generate strictly decreasing channels within the range.")
        conv_hidden_channels.append(trial.suggest_int(f'conv_hidden_channels_{i}', min_val, max_val))
        min_val -= 2

    batch_size = trial.suggest_categorical('batch_size', batch_size)
    # suggest_loguniform is deprecated; suggest_float(log=True) samples the same distribution.
    lr = trial.suggest_float('lr', lr[0], lr[1], log=True)
    activation = trial.suggest_categorical('activation', activation_name)
    criterion = trial.suggest_categorical('criterion', criterion_name)
    num_epochs = trial.suggest_int('num_epochs', num_epochs[0], num_epochs[1])
    gat_heads = trial.suggest_categorical('gat_heads', gat_heads)
    readout_layer = trial.suggest_categorical('readout_layer', readout_layer)

    # The reshapes below need exactly batch_size * n_graphs graphs per batch. Rebuild the
    # loaders for the sampled batch_size rather than trusting the batch size they were
    # created with.
    graphs_per_batch = batch_size * n_graphs

    def rebatch(loader, shuffle):
        # Same loader class and dataset; only the batch size (and shuffling) changes.
        return type(loader)(loader.dataset, batch_size=graphs_per_batch,
                            shuffle=shuffle, drop_last=True)

    train1, train2 = rebatch(train_loader1, True), rebatch(train_loader2, True)
    val1, val2 = rebatch(val_loader1, False), rebatch(val_loader2, False)

    spearman_corrcoef = DifferentialSpearmanCorrCoef(num_outputs=batch_size)

    # Instantiate model
    model = GNNSiamese(conv_hidden_channels=conv_hidden_channels,
                       activation=activation,
                       batch_size=batch_size,
                       readout_layer=readout_layer,
                       num_node_features=num_node_features,
                       n_graphs=n_graphs,
                       gat_heads=gat_heads).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    def target_corr(data1, data2):
        # Labels of the batch's batch_size * n_graphs graphs -> (n_graphs, batch_size).
        # Presumably this gives one correlation per sample (column).
        y1 = data1.y.reshape(batch_size, n_graphs).T
        y2 = data2.y.reshape(batch_size, n_graphs).T  # branch 2's labels, not branch 1's
        return spearman_corrcoef(y1, y2).to(device)

    def predict(data1, data2):
        # Siamese network's predicted correlation for the pair of batches.
        return model(data1.x.to(device), data1.edge_index.to(device), data1.batch.to(device),
                     data2.x.to(device), data2.edge_index.to(device), data2.batch.to(device))

    # Training loop: run the sampled number of epochs.
    model.train()
    for _ in range(num_epochs):
        for data1, data2 in zip(train1, train2):
            optimizer.zero_grad()
            loss = criterion(predict(data1, data2), target_corr(data1, data2))
            loss.backward()
            optimizer.step()

    # Validation loss, averaged over the batches actually evaluated.
    model.eval()
    val_loss = 0.0
    n_batches = 0
    with torch.no_grad():
        for data1, data2 in zip(val1, val2):
            val_loss += criterion(predict(data1, data2), target_corr(data1, data2)).item()
            n_batches += 1
    if n_batches == 0:
        # The validation split cannot fill a single batch: report the worst possible score.
        return float('inf')
    return val_loss / n_batches


def train_best_GNN_Siamese(study_best_params, train_loader1, val_loader1, train_loader2, val_loader2,
                           n_graphs=None, num_node_features=13,
                           save_path='/models/GNN_siamese_best_model_state_dict.pth'):
    """Retrain the GNN-Siamese model with the best Optuna hyper-parameters and save it.

    Args:
        study_best_params: ``study.best_params`` of the Optuna study, i.e. the keys suggested
            in ``objective``.
        train_loader1, val_loader1, train_loader2, val_loader2: torch_geometric DataLoaders
            for the two Siamese branches. Only their ``.dataset`` is reused; they are
            re-batched so every batch holds ``batch_size * n_graphs`` graphs.
        n_graphs: graphs per Siamese sample. It defaults to ``study_best_params['n_graphs']``
            when present, otherwise to the notebook-level ``n_graphs``.
        num_node_features: node feature width; 13 matches the value used in ``objective``.
        save_path: where the checkpoint is written. It holds the state dict, the constructor
            params and the per-epoch loss curves.

    Returns:
        The trained model.

    Raises:
        ValueError: if ``n_graphs`` cannot be determined.

    NOTE(review): ``device`` is read from a notebook-level global, as in ``objective``.
    """
    if n_graphs is None:
        # objective() never suggests 'n_graphs', so it is normally absent from best_params.
        n_graphs = study_best_params.get('n_graphs', globals().get('n_graphs'))
    if n_graphs is None:
        raise ValueError("n_graphs is not in study_best_params and no global n_graphs is "
                         "defined; pass it explicitly.")

    batch_size = study_best_params['batch_size']
    gat_heads = study_best_params['gat_heads']
    num_epochs = study_best_params['num_epochs']
    lr = study_best_params['lr']
    conv_hidden_channels = [study_best_params[f'conv_hidden_channels_{i}']
                            for i in range(study_best_params['num_conv_hidden_layers'])]
    # The search space passes the module/function objects themselves to suggest_categorical,
    # so best_params already holds the chosen objects; no name lookup is needed.
    activation = study_best_params['activation']
    readout_layer = study_best_params['readout_layer']
    criterion = study_best_params['criterion']

    graphs_per_batch = batch_size * n_graphs

    def rebatch(loader, shuffle):
        # Same loader class and dataset; only the batch size (and shuffling) changes.
        return type(loader)(loader.dataset, batch_size=graphs_per_batch,
                            shuffle=shuffle, drop_last=True)

    train1, train2 = rebatch(train_loader1, True), rebatch(train_loader2, True)
    val1, val2 = rebatch(val_loader1, False), rebatch(val_loader2, False)

    spearman_corrcoef = DifferentialSpearmanCorrCoef(num_outputs=batch_size)

    # Create the best model (same constructor arguments as in objective()).
    best_model = GNNSiamese(conv_hidden_channels=conv_hidden_channels,
                            activation=activation,
                            batch_size=batch_size,
                            readout_layer=readout_layer,
                            num_node_features=num_node_features,
                            n_graphs=n_graphs,
                            gat_heads=gat_heads).to(device)
    optimizer = torch.optim.Adam(best_model.parameters(), lr=lr)

    def target_corr(data1, data2):
        # Labels of the batch's batch_size * n_graphs graphs -> (n_graphs, batch_size).
        y1 = data1.y.reshape(batch_size, n_graphs).T
        y2 = data2.y.reshape(batch_size, n_graphs).T  # branch 2's labels, not branch 1's
        return spearman_corrcoef(y1, y2).to(device)

    def predict(data1, data2):
        # Siamese network's predicted correlation for the pair of batches.
        return best_model(data1.x.to(device), data1.edge_index.to(device), data1.batch.to(device),
                          data2.x.to(device), data2.edge_index.to(device), data2.batch.to(device))

    train_losses = []  # mean training loss per epoch
    val_losses = []    # mean validation loss per epoch
    for _ in range(num_epochs):
        best_model.train()
        epoch_loss, n_train = 0.0, 0
        for data1, data2 in zip(train1, train2):
            optimizer.zero_grad()
            loss = criterion(predict(data1, data2), target_corr(data1, data2))
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            n_train += 1
        train_losses.append(epoch_loss / n_train if n_train else float('nan'))

        # Validate on the validation loaders (not the training ones).
        best_model.eval()
        epoch_val, n_val = 0.0, 0
        with torch.no_grad():
            for data1, data2 in zip(val1, val2):
                epoch_val += criterion(predict(data1, data2), target_corr(data1, data2)).item()
                n_val += 1
        val_losses.append(epoch_val / n_val if n_val else float('nan'))

    parameters = {"batch_size": batch_size, "n_graphs": n_graphs,
                  'conv_hidden_channels': conv_hidden_channels, 'activation': activation,
                  'readout_layer': readout_layer, 'num_node_features': num_node_features,
                  'gat_heads': gat_heads}
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save({"model_state_dict": best_model.state_dict(), "params": parameters,
                "train_losses": train_losses, "val_losses": val_losses}, save_path)

    return best_model

def prepare_training_data(batch_size=None, n_graphs=None):
    """Load, scale and split the two Siamese training sets into DataLoaders.

    Each dataset is scaled with the saved scaler and split 80/20 into train/validation. It is
    then wrapped in loaders whose batches hold ``batch_size * n_graphs`` graphs.

    Args:
        batch_size: samples per batch. Defaults to the notebook-level ``batch_size``; when that
            is the list of Optuna candidates, its first entry is used.
        n_graphs: graphs per sample. Defaults to the notebook-level ``n_graphs``.

    Returns:
        ``(train_loader1, val_loader1, train_loader2, val_loader2)``.

    Raises:
        ValueError: if ``batch_size`` or ``n_graphs`` cannot be determined.
    """
    if batch_size is None:
        batch_size = globals().get('batch_size')
    if isinstance(batch_size, (list, tuple)):
        # A list of candidates times an int would repeat the list instead of giving a size.
        batch_size = batch_size[0] if batch_size else None
    if n_graphs is None:
        n_graphs = globals().get('n_graphs')
    if batch_size is None or n_graphs is None:
        raise ValueError("batch_size and n_graphs must be passed or defined at notebook level.")
    graphs_per_batch = batch_size * n_graphs

    # NOTE(review): torch.load / joblib.load unpickle arbitrary objects; only load trusted files.
    # Recent torch versions default torch.load to weights_only=True, which may reject these
    # datasets. Confirm against the installed version.
    dataset1 = torch.load('/data/train/train_set1.pkl')
    dataset2 = torch.load('/data/train/train_set2.pkl')
    scaler = joblib.load('/data/scaler.pkl')

    # Scale each dataset on its own. Zipping the two datasets stopped at the shorter one and
    # left the rest unscaled. This assumes the items are mutated in place (e.g. a list of Data).
    for dataset in (dataset1, dataset2):
        for data in dataset:
            data.x = torch.tensor(scaler.transform(data.x)).to(torch.float32)

    def split(dataset):
        # 80/20 train/validation split sized from this dataset's own length.
        train_size = int(0.8 * len(dataset))
        return random_split(dataset, [train_size, len(dataset) - train_size])

    def make_loader(subset):
        return DataLoader(subset, batch_size=graphs_per_batch, shuffle=True, drop_last=True)

    train_dataset1, val_dataset1 = split(dataset1)
    train_dataset2, val_dataset2 = split(dataset2)

    return (make_loader(train_dataset1), make_loader(val_dataset1),
            make_loader(train_dataset2), make_loader(val_dataset2))

-------

### Main

Runs the auxiliary functions to train the GNN-Siamese model, then returns and saves the best model.

In [None]:
# Build the train/validation DataLoaders for both Siamese branches.
loaders = prepare_training_data()
train_loader1, val_loader1, train_loader2, val_loader2 = loaders

# Search space in the positional order objective() expects right after `trial`.
search_space = (num_layers, batch_size, lr, activation_name, criterion_name,
                num_epochs, gat_heads, readout_layer)


def run_trial(trial):
    """Evaluate one Optuna trial over the fixed search space and DataLoaders."""
    return objective(trial, *search_space, *loaders)


# Minimise the validation loss over `num_trails` trials, then retrain and save the best model.
study = optuna.create_study(direction='minimize')
study.optimize(run_trial, n_trials=num_trails)
study_best_params = study.best_params
GNN_Siamese = train_best_GNN_Siamese(study_best_params, *loaders)

