# Generate and launch parameter and hyperparameter SWEEP on multiple GPUs

**Purpose:** This script performs 5-fold cross-validation (hyper)parameter sweeps. It was used to run the following sweeps:
1. ChemBERTa version and loss function
    - variable parameters: [learning rate, loss function, ChemBERTa version]
    - config file: `BPEtok10M.yaml` and `SMILEStok1M.yaml`
2. Batch size
    - variable parameters: [learning rate, batch size]
    - config file: `batch_size.yaml`
3. M50 (EC50) Model configuration
    - variable parameters: [learning rate, number of reinitialized encoders, number of hidden layers, hidden layer sizes]
    - config file: `2_hidden_layers_hpsweep.yaml`, `3_hidden_layers_hpsweep.yaml` and `4_hidden_layers_hpsweep.yaml`
4. M10 (EC10) Model configuration
    - variable parameters: [learning rate, number of reinitialized encoders, number of hidden layers, hidden layer sizes]
    - config file: `2_hidden_layers_hpsweep.yaml`, `3_hidden_layers_hpsweep.yaml` and `4_hidden_layers_hpsweep.yaml`
5. M5010 (EC50 and EC10 combo) Model configuration
    - variable parameters: [learning rate, number of reinitialized encoders, number of hidden layers, hidden layer sizes]
    - config file: `2_hidden_layers_hpsweep.yaml`, `3_hidden_layers_hpsweep.yaml` and `4_hidden_layers_hpsweep.yaml`

**Dependency:** None (apart from initial model testing)

**Consecutive scripts:** After running this script, the following script may be executed: `Kfold_crossvalidation_sweep.ipynb`

## Imports

In [None]:
from transformers import AutoModel, AutoTokenizer
from transformers import get_linear_schedule_with_warmup

import torch
import torch.nn as nn

import wandb

from tqdm.notebook import tqdm
import random
import pandas as pd
import numpy as np

from development_utils.preprocessing.Get_data_for_model import PreprocessData
from development_utils.training.Build_Pytorch_Dataset_and_DataLoader import BuildDataLoader_KFold, Make_KFolds
from development_utils.training.Build_Pytorch_model import ecoCAIT, DNN_module, GPUinfo, Modify_architecture
from development_utils.training.PerformanceCalculations import CalculateWeightedAverage

In [None]:
# Select the compute device. Build a torch.device on both branches so that
# `device` has the same type whether or not a GPU is present.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# GPUinfo is only called when a CUDA device was selected
if device.type == "cuda":
    GPUinfo(device)

GPUs on node: NVIDIA A100-SXM4-40GB
Number of GPUs available: 1
Using cuda:0 device
42.35 Gb free on CUDA


## wandb configuration

In [None]:
# Weights & Biases identifiers of the sweep this notebook attaches to
ENTITYNAME = "ecotoxformer"      # wandb entity
PROJECTNAME = "testing_github"   # wandb project
SWEEPID = "b9s38isu"             # wandb sweep id

In [None]:
# Log in to Weights & Biases before any run or agent is started
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mstyrbjornkall[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

## Define helper functions

In [None]:
def GetData(data, config):
    """Preprocess the raw dataframe for training.

    Args:
        data: Dataframe handed to ``PreprocessData``.
        config: Object exposing ``conc_thresh``, ``endpoints``, ``effects``,
            ``species_groups`` and ``concentration_sign``.

    Returns:
        Tuple ``(dataframe, fc1)``: the processed dataframe and the length of
        the first row's ``OneHotEnc_concatenated`` entry.
    """
    prep = PreprocessData(dataframe=data)

    # Filter the records according to the sweep configuration
    filter_kwargs = dict(
        concentration_thresh=config.conc_thresh,
        endpoint=config.endpoints,
        effect=config.effects,
        species_groups=config.species_groups,
        log_data=True,
        concentration_sign=config.concentration_sign,
    )
    prep.FilterData(**filter_kwargs)

    # Run the remaining preprocessing steps in order
    prep.GetPubchemCID()
    prep.GetMetadata(list_of_metadata=['cmpdname'])
    prep.GetCanonicalSMILES()
    prep.ConcatenateOneHotEnc(
        list_of_endpoints=config.endpoints,
        list_of_effects=config.effects,
    )

    processed = prep.dataframe
    # Number of input neurons needed for the one hot encoding
    one_hot_len = len(processed['OneHotEnc_concatenated'].iloc[0])

    return processed, one_hot_len

In [6]:
def SetSeed(seed):
    """Seed every random number generator used here for deterministic training.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices)
    and switches cuDNN to deterministic mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Python and NumPy generators
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators: CPU first, then every CUDA device
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    torch.backends.cudnn.deterministic = True

In [7]:
def GetLayers(config):
    """Build the list of hidden-layer sizes described by ``config``.

    Reads ``config.layer_1`` ... ``config.layer_<n>``, where ``n`` is
    ``config.n_hidden_layers``. Any positive number of layers is supported as
    long as the matching ``layer_<i>`` attributes exist.

    Args:
        config: Object exposing ``n_hidden_layers`` and one ``layer_<i>``
            attribute per hidden layer.

    Returns:
        List with the size of each hidden layer, first layer first.

    Raises:
        ValueError: If ``n_hidden_layers`` is not a positive whole number.
    """
    requested = config.n_hidden_layers
    n_layers = int(requested)

    # Fail loudly instead of implicitly returning None for unsupported counts
    if n_layers != requested or n_layers < 1:
        raise ValueError(
            f'n_hidden_layers must be a positive integer, got {requested!r}')

    return [getattr(config, f'layer_{i}') for i in range(1, n_layers + 1)]

## Training config

### Train and validation functions

In [8]:
def RunTrainingEpochs(data, folds, fold_id, config, fc1, global_step):
    """Train and validate a freshly built model on one cross-validation fold.

    The best epoch is the one with the lowest normalized validation median
    loss. The running best values are logged to wandb after every epoch.

    Args:
        data: Preprocessed dataframe passed to ``BuildDataLoader_KFold``.
        folds: Folds passed to ``BuildDataLoader_KFold``.
        fold_id: Identifier of the fold to run.
        config: Run configuration (``wandb.config`` when called by ``trainer``).
        fc1: Length of the one hot encoding, forwarded to ``DNN_module``.
        global_step: Counter logged as ``global_step``; the pre-training
            evaluation is logged at ``global_step - 1``.

    Returns:
        Tuple of
        (best validation median loss after each epoch,
         best normalized validation median loss after each epoch,
         best normalized validation mean loss after each epoch,
         ``global_step`` advanced by one per epoch,
         validation results of the best epoch).
    """
    # Load ChemBERTa
    chemberta = AutoModel.from_pretrained(config.base_model)
    tokenizer = AutoTokenizer.from_pretrained(config.base_model)

    # Build Pytorch train and validation dataloader based on fold
    DataLoaders = BuildDataLoader_KFold(
        df=data,
        folds=folds,
        fold_id=fold_id,
        wandb_config=config,
        label=config.label,
        batch_size=config.batch_size,
        max_length=config.max_token_length,
        seed=config.seed,
        tokenizer=tokenizer)

    train_dataloader = DataLoaders.BuildTrainingLoader(
        sampler_choice=config.sampling_procedure,
        num_workers=2,
        weight_args=['SMILES_Canonical_RDKit', 'effect', 'endpoint'])
    val_dataloader = DataLoaders.BuildValidationLoader(sampler_choice='SequentialSampler', num_workers=2)
    print('Successfully built dataloader')

    # Sanity check: number of SMILES present in both the train and validation split
    train_smiles = set(DataLoaders.train.SMILES_Canonical_RDKit.tolist())
    val_smiles = set(DataLoaders.val.SMILES_Canonical_RDKit.tolist())
    print(f'SMILES overlap train/validation: {len(train_smiles & val_smiles)}')

    ######## MODEL ########
    # Build the model (consisting of a DNN(-module) and ChemBERTa)
    dnn_module = DNN_module(
        one_hot_enc_len=fc1,
        n_hidden_layers=config.n_hidden_layers,
        layer_sizes=GetLayers(config),
        dropout=config.dropout)

    model = ecoCAIT(roberta=chemberta, dnn=dnn_module)

    model = Modify_architecture(model).FreezeModel(model, config.n_frozen_layers, config.freeze_embedding)
    model = Modify_architecture(model).ReinitializeEncoderLayers(model, reinit_n_layers=config.reinit_n_layers)
    model = model.to(device)

    ######## TRAINING CONFIG ########
    # Apply Layer wise Learning Rate Decay
    model_parameters = Modify_architecture(model).LLRD(model, init_lr=config.lr)

    optimizer = torch.optim.AdamW(model_parameters, lr=config.lr)
    # Linear warmup over the first 10% of all optimizer steps, then linear decay
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0.1 * config.epochs * len(train_dataloader),
        num_training_steps=config.epochs * len(train_dataloader))
    print('Successfully built optimizer')

    # Any value other than 'MSELoss' selects the L1 loss
    if config.loss_fun == 'MSELoss':
        loss_fun = nn.MSELoss()
    else:
        loss_fun = nn.L1Loss()

    best_val_loss = np.inf
    best_val_loss_norm = np.inf

    # [training batch counter, validation batch counter]
    batch_num = [0, 0]

    ######## RUN TRAINING ########
    # Log initial validation loss
    _, avg_loss_norm, median_loss, _, _, batch_num, val_results = evaluate(
        model, val_dataloader, DataLoaders.val, loss_fun, batch_num, -1, global_step - 1)

    if median_loss < best_val_loss:
        best_val_loss = median_loss

    # Seed the "best" bookkeeping with the pre-training evaluation. The update
    # in the epoch loop is skipped when the normalized median loss is NaN
    # (e.g. a diverged run) and never runs when epochs == 0; without these
    # fallbacks the logging and return statements below would raise
    # UnboundLocalError. best_val_loss_norm stays at inf so that the first
    # epoch with a finite loss always replaces these values.
    best_validation_results = val_results
    best_validation_mean_norm_loss = avg_loss_norm

    # Initialize arrays in which to store this fold's performance
    Best_Validation_Median_Loss = []
    Best_Validation_Median_Loss_Normalized = []
    Best_Validation_Mean_Loss_Normalized = []

    # Run epochs
    print("\nRunning epochs...")
    for epoch in tqdm(range(config.epochs)):

        # Only the updated batch counters are needed from the training pass
        *_, batch_num = train(config, model, train_dataloader, optimizer, scheduler, loss_fun, batch_num, epoch, global_step)

        _, avg_loss_norm, median_loss, median_loss_norm, _, batch_num, val_results = evaluate(
            model, val_dataloader, DataLoaders.val, loss_fun, batch_num, epoch, global_step)

        # Update and log epoch results
        if median_loss_norm < best_val_loss_norm:
            best_val_loss = median_loss
            best_val_loss_norm = median_loss_norm
            best_validation_results = val_results
            best_validation_mean_norm_loss = avg_loss_norm

        wandb.log({'Best Validation Median Loss': best_val_loss,
                   'Best Validation Median Loss Normalized': best_val_loss_norm,
                   'Best Validation Mean Loss Normalized': best_validation_mean_norm_loss,
                   'global_step': global_step})

        Best_Validation_Median_Loss.append(best_val_loss)
        Best_Validation_Median_Loss_Normalized.append(best_val_loss_norm)
        Best_Validation_Mean_Loss_Normalized.append(best_validation_mean_norm_loss)

        global_step += 1

    ######## DELETE FOLD PARAMETERS ########
    del model
    del optimizer
    del loss_fun
    del chemberta
    del tokenizer

    return Best_Validation_Median_Loss, Best_Validation_Median_Loss_Normalized, Best_Validation_Mean_Loss_Normalized, global_step, best_validation_results

In [9]:
# function to train the model on epoch
def train(args, model, dataloader, optimizer, scheduler, loss_fun, batch_num, epoch, global_step):
    """Run one training epoch and log batch and epoch metrics to wandb.

    Args:
        args: Run configuration (not read inside this function).
        model: Model called as ``model(sent_id, mask, duration, onehot)``.
        dataloader: Iterable of dict batches whose values are, in order,
            token ids, attention mask, duration, one hot encoding and labels.
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning rate scheduler stepped once per batch.
        loss_fun: Loss applied to ``(preds, labels)``.
        batch_num: Two-element list; index 0 counts training batches and is
            incremented in place.
        epoch: Epoch index, logged as ``training epoch``.
        global_step: Value logged as ``global_step``.

    Returns:
        Tuple ``(avg_loss, median_loss, total_preds, total_labels, batch_num)``.
    """
    from tqdm.notebook import tqdm
    model.train()

    print("\nTraining...")
    running_loss = 0
    epoch_preds, epoch_labels = [], []

    for batch in tqdm(dataloader):
        # Move every tensor of the batch to the training device
        sent_id, mask, duration, onehot, labels = [tensor.to(device) for tensor in batch.values()]

        optimizer.zero_grad()

        # Forward pass and batch loss
        preds, _ = model(sent_id, mask, duration, onehot)
        loss = loss_fun(preds, labels)
        batch_loss = loss.item()
        running_loss += batch_loss
        loss.backward()

        # Clip gradient to prevent exploding gradients and update weights
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        # Keep predictions and labels on the CPU for the epoch summary
        epoch_preds.append(preds.detach().cpu().numpy())
        epoch_labels.append(labels.detach().cpu().numpy())

        wandb.log({
            "Training Batch Loss": batch_loss,
            "Learning Rate": optimizer.param_groups[0]["lr"],
            'training batch': batch_num[0]
        })
        batch_num[0] += 1

    # Epoch level metrics
    avg_loss = running_loss / len(dataloader)
    total_preds = np.concatenate(epoch_preds, axis=0)
    total_labels = np.concatenate(epoch_labels, axis=0)
    abs_errors = abs(total_preds - total_labels)
    median_loss = np.median(abs_errors)

    wandb.log({
        "Training Loss function": avg_loss,
        "Training Mean Loss": np.mean(abs_errors),
        'training epoch': epoch,
        "Training Median Loss": median_loss,
        "Training RMSE Loss": np.sqrt(np.mean((total_labels - total_preds)**2)),
        'global_step': global_step})

    return avg_loss, median_loss, total_preds, total_labels, batch_num

In [10]:
# function to validate the model on epoch
def evaluate(model, dataloader, dataset, loss_fun, batch_num, epoch, global_step):
    """Evaluate the model on the validation data and log the metrics to wandb.

    Args:
        model: Model called as ``model(sent_id, mask, duration, onehot)``.
        dataloader: Iterable of dict batches whose values are, in order,
            token ids, attention mask, duration, one hot encoding and labels.
        dataset: Dataframe of the validation samples; a copy is extended with
            the prediction columns.
        loss_fun: Loss applied to ``(preds, labels)``.
        batch_num: Two-element list; index 1 counts validation batches and is
            incremented in place.
        epoch: Epoch index, logged as ``validation epoch``.
        global_step: Value logged as ``global_step``.

    Returns:
        Tuple ``(avg_loss, avg_loss_norm, median_loss, median_loss_norm,
        total_preds, batch_num, val_results)``.
    """
    from tqdm.notebook import tqdm

    print("\nEvaluating...")
    model.eval()
    running_loss = 0
    batch_preds, batch_labels, batch_embeddings = [], [], []

    # Work on a copy so the caller's dataframe is left untouched
    val_results = dataset.copy()

    for batch in tqdm(dataloader):
        # Move every tensor of the batch to the evaluation device
        sent_id, mask, duration, onehot, labels = [tensor.to(device) for tensor in batch.values()]

        with torch.no_grad():
            preds, roberta_output = model(sent_id, mask, duration, onehot)
            running_loss += loss_fun(preds, labels).item()

            # Collect the model outputs on the CPU
            batch_embeddings.append(roberta_output.detach().cpu().numpy())
            batch_preds.append(preds.detach().cpu().numpy())
            batch_labels.append(labels.detach().cpu().numpy())
        batch_num[1] += 1

    # compute the validation loss of the epoch
    avg_loss = running_loss / len(dataloader)
    total_preds = np.concatenate(batch_preds, axis=0)
    total_labels = np.concatenate(batch_labels, axis=0)

    # Attach the per-sample outputs to the results dataframe
    val_results['CLS_embeddings'] = np.concatenate(batch_embeddings, axis=0).tolist()
    val_results['labels'] = total_labels
    val_results['preds'] = total_preds
    val_results['residuals'] = val_results['labels'] - val_results['preds']
    val_results['L1Error'] = abs(total_labels - total_preds)
    median_loss = val_results['L1Error'].median()

    # Metrics on the output of CalculateWeightedAverage
    val_results_normalized = CalculateWeightedAverage(val_results)
    abs_norm_residuals = abs(val_results_normalized.residuals)
    median_loss_norm = abs_norm_residuals.median()
    avg_loss_norm = abs_norm_residuals.mean()
    squared_norm_errors = (val_results_normalized.labels - val_results_normalized.preds)**2

    wandb.log({
        "Validation Loss function": avg_loss,
        "Validation Mean Loss": val_results['L1Error'].mean(),
        "Validation Median Loss": median_loss,
        "Validation Loss Normalized": median_loss_norm,
        "Validation Mean Loss Normalized": avg_loss_norm,
        "Validation RMSE Loss Normalized": np.sqrt(squared_norm_errors.mean()),
        'validation epoch': epoch,
        'global_step': global_step
        })

    return avg_loss, avg_loss_norm, median_loss, median_loss_norm, total_preds, batch_num, val_results

### Define Trainer function

In [13]:
def trainer(config=None):
    """Run one sweep trial: K-fold cross-validation of a single configuration.

    This is the function handed to ``wandb.agent``. It loads and preprocesses
    the data, trains one model per fold with ``RunTrainingEpochs`` and finally
    logs, for every epoch, the mean and standard deviation across folds of the
    best validation losses.

    Args:
        config: Optional configuration forwarded to ``wandb.init``. When the
            function is called by a wandb agent the Sweep Controller sets the
            configuration.
    """
    from tqdm.notebook import tqdm
    # Set random seeds and deterministic pytorch for reproducibility
    SetSeed(42)

    # Initialize a new wandb run
    with wandb.init(config=config):

        # If called by wandb.agent, as below, this config will be set by Sweep Controller
        sweepconfig = wandb.config

        ######## DATA ########
        # Load dataframe
        datadir = '../data/development/'
        data = pd.read_excel(datadir+'Preprocessed_complete_data.xlsx', sheet_name='dataset')
        data, fc1 = GetData(data, sweepconfig)
        print('Successfully loaded data')

        # Build K-folds based on SMILES (each fold has a unique set of SMILES)
        folds = Make_KFolds().Split(data[sweepconfig.smiles_col_name], k_folds=sweepconfig.k_folds, seed=sweepconfig.seed)
        print('Successfully built folds')

        # One row per fold, one column per epoch
        result_shape = (sweepconfig.k_folds, sweepconfig.epochs)
        Avg_Best_Validation_Median_Loss = np.zeros(result_shape)
        Avg_Best_Validation_Median_Loss_Normalized = np.zeros(result_shape)
        Avg_Best_Validation_Mean_Loss_Normalized = np.zeros(result_shape)

        # Run K-fold Cross validation
        print(f'\n Running {sweepconfig.k_folds} folds')
        global_step = 0

        for fold_id in tqdm(range(1, sweepconfig.k_folds+1, 1)):
            # Run one fold
            fold_median, fold_median_norm, fold_mean_norm, global_step, best_validation_results = RunTrainingEpochs(
                data, folds, fold_id, sweepconfig, fc1, global_step)

            # RunTrainingEpochs logs its pre-training evaluation at
            # global_step - 1, so skip one step to keep that slot free
            global_step += 1

            # Log results
            if sweepconfig.save_results == True:
                wandb.log({f"Best Validation Results {fold_id}": wandb.Table(dataframe=best_validation_results)})
            Avg_Best_Validation_Median_Loss[fold_id-1, :] = fold_median
            Avg_Best_Validation_Median_Loss_Normalized[fold_id-1, :] = fold_median_norm
            Avg_Best_Validation_Mean_Loss_Normalized[fold_id-1, :] = fold_mean_norm

        # Reduce across folds once instead of once per logged epoch
        median_norm_mean = Avg_Best_Validation_Median_Loss_Normalized.mean(axis=0)
        median_norm_std = np.std(Avg_Best_Validation_Median_Loss_Normalized, axis=0)
        median_mean = Avg_Best_Validation_Median_Loss.mean(axis=0)
        median_std = np.std(Avg_Best_Validation_Median_Loss, axis=0)
        mean_norm_mean = Avg_Best_Validation_Mean_Loss_Normalized.mean(axis=0)

        # Log the sweep metric as the average performance
        for i in range(sweepconfig.epochs):
            wandb.log({
                'Avg Best Validation Median Loss Normalized': median_norm_mean[i],
                'Avg Best Validation Median Loss Normalized std': median_norm_std[i],
                'Avg Best Validation Median Loss': median_mean[i],
                'Avg Best Validation Median Loss std': median_std[i],
                'Avg Best Validation Mean Loss Normalized': mean_norm_mean[i],
                'epoch': i})

## Train the model

In [None]:
# Run wandb agent (runs the script): every trial requested by the sweep calls `trainer`
wandb.agent(f'{ENTITYNAME}/{PROJECTNAME}/{SWEEPID}', trainer)