# Final Model Training
**Purpose:** This script trains the final models. It can also be used to try out other model configurations.

**Dependencies:** `hyperparameter_sweep.ipynb`, `Kfold_crossvalidation_sweep.ipynb`. The model configurations specified in this script were determined by the model development sweeps (`hyperparameter_sweep.ipynb`). The number of epochs used to train each final model was chosen by examining the 10x10 K-fold cross-validation runs (`Kfold_crossvalidation_sweep.ipynb`). It varies between model versions because some datasets are more prone to overfitting than others.

**Consecutive scripts:** After running this script, the following scripts can be executed: `push_model_to_huggingface.ipynb`, `download_wandb_artifacts.ipynb`.
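The epoch counts used further down were read off the K-fold cross-validation runs. As a rough illustration of that kind of selection (not necessarily the exact procedure used in `Kfold_crossvalidation_sweep.ipynb`), the sketch below averages hypothetical per-fold validation-loss curves and picks the epoch with the lowest mean loss. `pick_epochs` and the toy curves are illustrative assumptions.

```python
from statistics import mean

def pick_epochs(val_curves):
    """Return the 1-based epoch with the lowest mean validation loss.

    val_curves: hypothetical list of per-fold curves, i.e. one list of
    validation losses per fold, all with the same number of epochs.
    """
    n_epochs = len(val_curves[0])
    if any(len(curve) != n_epochs for curve in val_curves):
        raise ValueError("all folds must have the same number of epochs")
    # Average the validation loss over folds, epoch by epoch
    mean_curve = [mean(fold[e] for fold in val_curves) for e in range(n_epochs)]
    # Index of the lowest mean loss; converted to a 1-based epoch number
    best = min(range(n_epochs), key=mean_curve.__getitem__)
    return best + 1, mean_curve

# Toy curves for three folds: the loss falls, then rises again (overfitting)
folds = [
    [0.90, 0.70, 0.60, 0.65, 0.72],
    [0.95, 0.72, 0.58, 0.61, 0.70],
    [0.88, 0.69, 0.62, 0.63, 0.75],
]
best_epoch, curve = pick_epochs(folds)
print(best_epoch)  # 3
```

With 10x10 cross-validation the same idea applies to 100 curves; datasets that overfit earlier simply produce an earlier minimum.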

## Imports

In [None]:
from transformers import AutoModel, AutoTokenizer
from transformers import get_linear_schedule_with_warmup

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, SequentialSampler

import wandb

from tqdm.notebook import tqdm
import random
import pandas as pd
import numpy as np
import time

import sys
sys.path.insert(1, '/cephyr/users/skall/Alvis/TRIDENT/development/development_utils/')
import os
os.chdir('/cephyr/users/skall/Alvis/TRIDENT/development/')

from development_utils.preprocessing.Get_data_for_model import PreprocessData
from development_utils.training.Build_Pytorch_Dataset_and_DataLoader import BuildDataLoader_with_trainval_ratio
from development_utils.training.Build_Pytorch_model import TRIDENT, DNN_module, GPUinfo, Modify_architecture
from development_utils.training.PerformanceCalculations import CalculateWeightedAverage

In [None]:
#%matplotlib inline

import plotly.express as px
import plotly.graph_objects as go

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

GPUinfo(device)

## wandb configuration

In [None]:
ENTITYNAME = 'ecotoxformer'
PROJECTNAME = 'Final_model'

In [None]:
wandb.login()


In [None]:
wandb.init(entity=ENTITYNAME, project=PROJECTNAME, notes='', dir = '/mimer/NOBACKUP/groups/snic2022-22-552/skall/wandb/')

**PAY ATTENTION**
The following parameters depend on the results of other scripts and take different values depending on which model version is being fine-tuned.

The values used in the publication are the following:

| **Model version** | **species_groups** | **endpoints** | **effects** | **epochs** | **lr** |
|---|---|---|---|---|---|
| F-M50 | ['fish'] | ['EC50'] | ['MOR'] | 35 | 0.00015 |
| F-M10 | ['fish'] | ['EC10', 'NOEC'] | ['MOR','DVP','ITX','REP','MPH','POP','GRO'] | 35 | 0.0005 |
| F-M5010 | ['fish'] | ['EC50', 'EC10', 'NOEC'] | ['MOR','DVP','ITX','REP','MPH','POP','GRO'] | 25 | 0.0002 |
| A-M50 | ['algae'] | ['EC50'] | ['POP'] | 25 | 0.00015 |
| A-M10 | ['algae'] | ['EC10','NOEC'] | ['POP'] | 30 | 0.0005 |
| A-M5010 | ['algae'] | ['EC50','EC10','NOEC'] | ['POP'] | 35 | 0.0002 |
| I-M50 | ['crustaceans'] | ['EC50'] | ['MOR','ITX'] | 35 | 0.00015 |
| I-M10 | ['crustaceans'] | ['EC10','NOEC'] | ['MOR','DVP','ITX','REP','MPH','POP'] | 35 | 0.0005 |
| I-M5010 | ['crustaceans'] | ['EC50','EC10','NOEC'] | ['MOR','DVP','ITX','REP','MPH','POP'] | 35 | 0.0002 |


Change the config below according to the specifications above. Values that should change are marked by a #* comment.
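Instead of editing the `#*` lines by hand, the publication values listed above could be kept in a dictionary keyed by model version and applied in one step. The sketch below is a hypothetical helper: `MODEL_VERSIONS` and `apply_model_version` are not part of the original notebook. It operates on the plain `config` dict, i.e. before the dict is converted with `Dict2Class`.

```python
# Hypothetical helper (not part of the original notebook) mirroring the
# publication values for the version-specific (#*) config entries
FISH_EFFECTS = ['MOR', 'DVP', 'ITX', 'REP', 'MPH', 'POP', 'GRO']
CRUSTACEAN_EFFECTS = ['MOR', 'DVP', 'ITX', 'REP', 'MPH', 'POP']

MODEL_VERSIONS = {
    'F-M50':   dict(species_groups=['fish'], endpoints=['EC50'], effects=['MOR'], epochs=35, lr=0.00015),
    'F-M10':   dict(species_groups=['fish'], endpoints=['EC10', 'NOEC'], effects=FISH_EFFECTS, epochs=35, lr=0.0005),
    'F-M5010': dict(species_groups=['fish'], endpoints=['EC50', 'EC10', 'NOEC'], effects=FISH_EFFECTS, epochs=25, lr=0.0002),
    'A-M50':   dict(species_groups=['algae'], endpoints=['EC50'], effects=['POP'], epochs=25, lr=0.00015),
    'A-M10':   dict(species_groups=['algae'], endpoints=['EC10', 'NOEC'], effects=['POP'], epochs=30, lr=0.0005),
    'A-M5010': dict(species_groups=['algae'], endpoints=['EC50', 'EC10', 'NOEC'], effects=['POP'], epochs=35, lr=0.0002),
    'I-M50':   dict(species_groups=['crustaceans'], endpoints=['EC50'], effects=['MOR', 'ITX'], epochs=35, lr=0.00015),
    'I-M10':   dict(species_groups=['crustaceans'], endpoints=['EC10', 'NOEC'], effects=CRUSTACEAN_EFFECTS, epochs=35, lr=0.0005),
    'I-M5010': dict(species_groups=['crustaceans'], endpoints=['EC50', 'EC10', 'NOEC'], effects=CRUSTACEAN_EFFECTS, epochs=35, lr=0.0002),
}

def apply_model_version(config, version):
    """Overwrite the version-specific entries of a plain config dict."""
    if version not in MODEL_VERSIONS:
        raise KeyError(f"Unknown model version {version!r}; choose from {sorted(MODEL_VERSIONS)}")
    for key, value in MODEL_VERSIONS[version].items():
        # Copy lists so later edits to config do not mutate the lookup table
        config[key] = list(value) if isinstance(value, list) else value
    return config

# Usage: all other entries (here batch_size) are left untouched
demo = apply_model_version({'batch_size': 512}, 'A-M5010')
```

In the notebook this would be called as `config = apply_model_version(config, 'A-M5010')` at the end of the config cell, so the chosen version is also what gets logged to wandb.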

In [None]:
config = {}

# TRAINING ######################################
config['batch_size'] = 512     
config['epochs'] = 35 #*      
config['lr'] = 0.0002 #*
config['seed'] = 42            
config['max_token_length'] = 100
config['sampling_procedure'] = 'WRS_sqrt'
config['sampler_weight_args'] = ['SMILES_Canonical_RDKit','effect','endpoint']
config['optimizer'] = 'AdamW'
config['loss_fun'] = 'L1Loss'

# MODEL ############################################
config['pretrained_model'] = "seyonec/PubChem10M_SMILES_BPE_450k"
config['n_hidden_layers'] = 3
config['hidden_layer_size'] = [700, 500, 300]
config['dropout'] = 0.2
config['inputs']=['SMILES_Canonical_RDKit', 'Duration_Value', 'OneHotEnc_concatenated']
config['label'] = 'mgperL'
config['species_classes'] = []
config['reinit_n_layers'] = 0

# MODIFICATIONS ###########################################
config['n_frozen_layers'] = 0 
config['freeze_embedding'] = False
config['add_roberta_layer'] = False
config['use_cls'] = True

# DATA #######################################################
config['conc_thresh'] = 500
config['species_groups'] = ['algae'] #*
config['endpoints'] = ['EC50','EC10','NOEC'] #*
config['effects'] = ['POP'] #*
config['dataset'] = 'large'
config['concentration_sign'] = '='
config['log_data'] = True

if config['n_hidden_layers'] != len(config['hidden_layer_size']):
    print('Warning: n_hidden_layers does not match len(hidden_layer_size)!')

In [None]:
class Dict2Class(object):
    def __init__(self, my_dict):
        for key in my_dict:
            setattr(self, key, my_dict[key])
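`Dict2Class` only provides attribute access (`config.lr` instead of `config['lr']`). The standard library's `types.SimpleNamespace` behaves the same way for this purpose and could serve as an alternative; a minimal comparison:

```python
from types import SimpleNamespace

class Dict2Class(object):
    # Same helper as in the notebook: copy dict entries to attributes
    def __init__(self, my_dict):
        for key in my_dict:
            setattr(self, key, my_dict[key])

settings = {'lr': 0.0002, 'epochs': 35}
a = Dict2Class(settings)
b = SimpleNamespace(**settings)

# Both support attribute access, and vars() turns either back into a dict
print(a.lr, b.lr)           # 0.0002 0.0002
print(vars(b) == settings)  # True
```

Both objects also accept new attributes later on, as the notebook does with `config.fc1 = fc1`.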

In [None]:
config = Dict2Class(config)

# Log the hyperparameters and inputs to the wandb run config
wandb.config.update(config)

## Set seed

In [None]:
def SetSeed(seed):
    torch.manual_seed(seed) # pytorch random seed
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True

In [None]:
SetSeed(config.seed)

## Load ChemBERTa

In [None]:
tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)

chemberta = AutoModel.from_pretrained(config.pretrained_model)

print(f'Trainable parameters: {chemberta.num_parameters(only_trainable=True)}')

## Data

### Loading and pre-process

In [None]:
def GetData(data, config):
    # Preprocesses data for training
    processor = PreprocessData(dataframe=data)

    processor.FilterData(
        concentration_thresh=config.conc_thresh,
        endpoint=config.endpoints,
        effect=config.effects,
        species_groups=config.species_groups,
        log_data=config.log_data,
        concentration_sign=config.concentration_sign)

    processor.GetPubchemCID(drop_missing_entries=False)
    processor.GetMetadata(list_of_metadata=['cmpdname'])
    processor.GetCanonicalSMILES()
    processor.ConcatenateOneHotEnc(list_of_endpoints=config.endpoints, list_of_effects=config.effects)

    data = processor.dataframe
    # Get the number of neurons needed for one hot encoding
    fc1 = len(data.OneHotEnc_concatenated.iloc[0])
    
    return data, fc1
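`fc1` is the length of the concatenated one-hot vector, which is why it can be read from the first row: it sets the number of input neurons the DNN module needs for that vector. The internals of `ConcatenateOneHotEnc` are not shown here. Assuming it concatenates a one-hot encoding of the endpoint with a one-hot encoding of the effect (an assumption, not verified against `development_utils`), the width would be `len(endpoints) + len(effects)`, as in this hypothetical sketch:

```python
def concat_one_hot(endpoint, effect, endpoints, effects):
    """Hypothetical layout: [one-hot over endpoints] + [one-hot over effects]."""
    if endpoint not in endpoints or effect not in effects:
        raise ValueError(f"unexpected endpoint/effect: {endpoint}/{effect}")
    endpoint_part = [1 if e == endpoint else 0 for e in endpoints]
    effect_part = [1 if e == effect else 0 for e in effects]
    return endpoint_part + effect_part

# Endpoints and effects as in the A-M5010 configuration
vec = concat_one_hot('EC10', 'POP', ['EC50', 'EC10', 'NOEC'], ['POP'])
print(vec)      # [0, 1, 0, 1]
fc1 = len(vec)  # 4 input neurons under this assumed layout
```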

In [None]:
datadir = '../data/development/'
data = pd.read_excel(datadir+'Preprocessed_complete_data_2023.xlsx', sheet_name='dataset')
data, fc1 = GetData(data, config)

In [None]:
config.fc1 = fc1
wandb.config.update(config)

## Define dataloader

In [None]:
# Build Pytorch train dataloader
# test_size = 0 ensures that the entire dataset is used as the training set
DataLoaders = BuildDataLoader_with_trainval_ratio(
                                    df = data, 
                                    wandb_config = config,
                                    label = config.label, 
                                    batch_size = config.batch_size, 
                                    max_length = config.max_token_length, 
                                    seed = config.seed,
                                    test_size = 0,
                                    tokenizer = tokenizer)
        
train_dataloader = DataLoaders.BuildTrainingLoader(sampler_choice=config.sampling_procedure, num_workers=2, weight_args=config.sampler_weight_args)
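The sampler is selected by `sampling_procedure='WRS_sqrt'` and weighted by the columns in `sampler_weight_args`; its implementation lives in `development_utils` and is not shown. One plausible reading of the name, stated here as an assumption, is weighted random sampling where each row's weight is the inverse square root of how often its (SMILES, effect, endpoint) combination occurs, so that frequently tested chemicals are down-weighted without being ignored. The standard-library sketch below illustrates that assumed rule; `inverse_sqrt_weights` is a hypothetical name.

```python
import math
import random
from collections import Counter

def inverse_sqrt_weights(rows, keys):
    """Weight each row by 1/sqrt(frequency of its key combination) (assumed rule)."""
    groups = [tuple(row[k] for k in keys) for row in rows]
    counts = Counter(groups)
    return [1.0 / math.sqrt(counts[g]) for g in groups]

keys = ['SMILES_Canonical_RDKit', 'effect', 'endpoint']
rows = [{'SMILES_Canonical_RDKit': 'CCO', 'effect': 'POP', 'endpoint': 'EC50'} for _ in range(4)]
rows.append({'SMILES_Canonical_RDKit': 'c1ccccc1', 'effect': 'POP', 'endpoint': 'EC10'})

weights = inverse_sqrt_weights(rows, keys)
print(weights)  # [0.5, 0.5, 0.5, 0.5, 1.0]

# Draw one "epoch" of row indices with replacement, proportional to the weights
rng = random.Random(42)
epoch_indices = rng.choices(range(len(rows)), weights=weights, k=len(rows))
```

Here the four ethanol rows together carry weight 2.0 instead of 4.0, so the single benzene row is drawn about a third of the time rather than a fifth. In PyTorch such weights would typically be passed to `torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)`.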

## Architecture

In [None]:
# Build the model (consisting of a DNN(-module) and ChemBERTa)
dnn_module = DNN_module(
                        one_hot_enc_len=fc1,
                        n_hidden_layers=config.n_hidden_layers,
                        layer_sizes=config.hidden_layer_size,
                        dropout=config.dropout)

model = TRIDENT(roberta=chemberta, dnn=dnn_module)

model = Modify_architecture(model).FreezeModel(model, config.n_frozen_layers, config.freeze_embedding)
model = Modify_architecture(model).ReinitializeEncoderLayers(model, reinit_n_layers=config.reinit_n_layers)
model = model.to(device)

In [None]:
print(f'Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}')

## Define train and validation functions

In [None]:
# Train the model for one epoch
def train(args, model, dataloader, optimizer, scheduler, loss_fun, batch_num, epoch, global_step):
    model.train()
    model.train()
    
    print("\nTraining...")
    total_loss = 0
    total_preds=[]
    total_labels=[]
    # iterate over batches
    for step, batch in enumerate(tqdm(dataloader)):
        # Extract batch samples
        batch = [r.to(device) for r in batch.values()]
        sent_id, mask, duration, onehot, labels = batch
        
        # Zero gradients
        optimizer.zero_grad()

        # Predict batch
        preds, _ = model(sent_id, mask, duration, onehot)

        # Calculate batch loss
        loss = loss_fun(preds, labels)
        total_loss += loss.item()
        loss.backward()

        # Clip gradient to prevent exploding gradients and update weights
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        # Log batch results
        preds = preds.detach().cpu().numpy()
        labels = labels.detach().cpu().numpy()
        total_preds.append(preds)
        total_labels.append(labels)
        
        wandb.log({
            "Training Batch Loss": loss.item(),
            "Learning Rate": optimizer.param_groups[0]["lr"], 
            'training batch': batch_num[0]
        })
        batch_num[0] += 1

    # compute the training loss of the epoch
    avg_loss = total_loss / len(dataloader)
    total_preds = np.concatenate(total_preds, axis=0)
    total_labels  = np.concatenate(total_labels, axis=0)
    median_loss = np.median(abs(total_preds - total_labels))

    wandb.log({
        "Training Loss function": avg_loss,
        "Training Mean Loss": np.mean(abs(total_preds - total_labels)), 
        'training epoch': epoch,
        "Training Median Loss": np.median(abs(total_preds - total_labels)),
        "Training RMSE Loss": np.sqrt(np.mean((total_labels-total_preds)**2)),
        'global_step': global_step})
    
    return avg_loss, median_loss, total_preds, total_labels, batch_num
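`torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)` computes the L2 norm over all gradients together and, if it exceeds `max_norm`, rescales every gradient by the same factor, which preserves the update direction. The pure-Python sketch below reproduces that arithmetic on plain lists (PyTorch adds a small epsilon, 1e-6, to the denominator); it is an illustration, not the PyTorch implementation.

```python
import math

def clip_grad_norm(grads, max_norm, eps=1e-6):
    """Scale a list of gradient vectors so their joint L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    # Same rule as PyTorch: scale = max_norm / (total_norm + eps), capped at 1
    scale = min(1.0, max_norm / (total_norm + eps))
    clipped = [[g * scale for g in vec] for vec in grads]
    return clipped, total_norm

# Two "parameters" whose gradients have joint norm sqrt(3^2 + 4^2) = 5
grads = [[3.0], [4.0]]
clipped, norm = clip_grad_norm(grads, max_norm=1.0)
print(norm)     # 5.0
print(clipped)  # approximately [[0.6], [0.8]]
```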

## Train the model

In [None]:
# Apply layer-wise learning rate decay (LLRD)
model_parameters = Modify_architecture(model).LLRD(model, init_lr = config.lr)

optimizer = torch.optim.AdamW(model_parameters, lr=config.lr)

# Linear warmup over the first 10% of all training steps, then linear decay to zero
num_training_steps = config.epochs * len(train_dataloader)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(0.1 * num_training_steps), num_training_steps=num_training_steps)
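The `LLRD` helper in `Modify_architecture` is not shown, so its exact decay factor and parameter grouping are unknown. Layer-wise learning rate decay typically gives the top encoder layer the base learning rate and multiplies it by a decay factor for each layer further down, so the embeddings and early layers change least. The linear warmup schedule then scales every parameter group's rate by the same multiplier: it rises linearly over the first 10% of steps (as configured in the cell above) and decays linearly to zero afterwards. The sketch below reproduces both rules in plain Python; the decay factor 0.9, the 6-layer depth and 20 batches per epoch are illustrative assumptions.

```python
def llrd_learning_rates(init_lr, n_layers, decay=0.9):
    """Per-layer learning rates, top encoder layer first (decay=0.9 is an assumed value)."""
    return [init_lr * decay ** depth for depth in range(n_layers)]

def linear_warmup_multiplier(step, num_warmup_steps, num_training_steps):
    """LR multiplier of a linear warmup followed by a linear decay to zero."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    return max(0.0, (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps))

# Example: 35 epochs x 20 batches per epoch (hypothetical), 10% warmup
num_training_steps = 35 * 20
num_warmup_steps = int(0.1 * num_training_steps)

layer_lrs = llrd_learning_rates(init_lr=0.0002, n_layers=6)
for step in (0, 35, 70, 385, 700):
    m = linear_warmup_multiplier(step, num_warmup_steps, num_training_steps)
    # Effective rate of the two topmost layers at this step
    print(step, round(m, 3), [f"{lr * m:.2e}" for lr in layer_lrs[:2]])
```

Because the scheduler multiplies each group's initial rate, the ratios between layers set by LLRD are preserved throughout training.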

In [None]:
if config.loss_fun == 'L1Loss':
    loss_fun = nn.L1Loss()
elif config.loss_fun == 'MSELoss':
    loss_fun = nn.MSELoss()
else:
    raise ValueError(f"Unsupported loss function: {config.loss_fun}")

wandb.watch(model, log="all")

In [None]:
# Set the file name (combination of endpoints and species group)
if len(config.endpoints) == 1:
    name=f'EC50_{config.species_groups[0]}'
elif len(config.endpoints) == 2:
    name=f'EC10_{config.species_groups[0]}'
elif len(config.endpoints) == 3:
    name=f'EC50EC10_{config.species_groups[0]}_withoverlap'
else:
    raise ValueError(f"Unexpected number of endpoints: {config.endpoints}")

In [None]:
save_name = f'../TRIDENT/final_model_2023_{name}'

In [None]:
# Function to save fine-tuned ChemBERTa and DNN-module
def save_ckp(model, checkpoint_dir):
    torch.save(model.dnn.state_dict(), checkpoint_dir+'_dnn_saved_weights.pt')
    torch.save(model.roberta.state_dict(), checkpoint_dir+'_roberta_saved_weights.pt')
    #wandb.save(checkpoint_dir+'_dnn_saved_weights.pt')
    #wandb.save(checkpoint_dir+'_roberta_saved_weights.pt')

### Run training epochs

In [None]:
batch_num = [0,0]
global_step = 0

# Time training
start_time = time.time()

# Run training epochs
for epoch in tqdm(range(config.epochs)):
    print('\n Epoch {:} / {:}'.format(epoch + 1, config.epochs))
    avg_loss, median_loss, total_preds, total_labels, batch_num = train(config, model, train_dataloader, optimizer, scheduler, loss_fun, batch_num, epoch, global_step)
    # No validation epochs since this is the final model
    print(f'\nTraining median absolute error: {median_loss:.3f}')

train_time = (time.time() - start_time)/60

wandb.log({'Total train time (min)': train_time,
            'epoch time (s)': train_time/config.epochs*60})

In [None]:
save_ckp(model, save_name)

## Evaluate resulting model
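NOTE: These results are computed on the training data itself, so they are optimistic: the model is partially overfitted to these data. This is expected and not a problem; for new chemicals, the expected accuracy is the one presented in the publication.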

In [None]:
# Build new pytorch dataset and dataloader manually
dataset = DataLoaders.BuildDataset(DataLoaders.train)
sampler = SequentialSampler(dataset)
dataloader = DataLoader(dataset, sampler=sampler, batch_size=config.batch_size, collate_fn=DataLoaders.collator, num_workers=2)

In [None]:
# Deactivate dropout
model.eval()

# predict the entire training set again and save CLS-embeddings
results = DataLoaders.train.copy()
predictions = []
cls_embeddings = []
for step, batch in enumerate(tqdm(dataloader)):
    batch = [r.to(device) for r in batch.values()]
    sent_id, mask, duration, onehot, labels = batch

    with torch.no_grad():
        # Predict 
        preds, cls = model(sent_id, mask, duration, onehot)
    predictions.append(preds.detach().cpu().numpy())
    cls_embeddings.append(cls.detach().cpu().numpy())

In [None]:
results['preds'] = np.concatenate(predictions, axis=0)
results['CLS_embeddings'] = np.concatenate(cls_embeddings, axis=0).tolist()
results['residuals'] = results.mgperL-results.preds
results['absolute_error'] = abs(results.mgperL-results.preds)
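The training loop logs the mean and median absolute error and the RMSE; the same summary can be computed from the `residuals` column once the predictions are in `results`. A dependency-free sketch with toy values (in the notebook, `results.residuals` would be passed instead):

```python
import math
from statistics import mean, median

def error_summary(residuals):
    """Mean/median absolute error and RMSE from residuals (label - prediction)."""
    abs_err = [abs(r) for r in residuals]
    return {
        'mean_abs_error': mean(abs_err),
        'median_abs_error': median(abs_err),
        'rmse': math.sqrt(mean(r * r for r in residuals)),
    }

summary = error_summary([0.5, -1.0, 2.0, -0.5])
print(summary)
```

Since the labels are log10-transformed (`log_data=True`), an absolute error of 1.0 corresponds to a factor of ten in mg/L.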

In [None]:
# Save results locally
results.to_pickle(f'../data/results/{name}_final_model_training_data_RDkit.zip', compression='zip')

In [None]:
# Save results to weights and biases
art = wandb.Artifact(
            f"Training_data_final_model_{name}", type="results_dataset",
            description=f"{name}",
            metadata={"source": "Preprocessed_complete_data_2023.xlsx",
                      "sizes": len(results)})

art.add_file(local_path=f'../data/results/{name}_final_model_training_data_RDkit.zip')

wandb.log_artifact(art)

In [None]:
# Aggregate the results to one weighted-average prediction per chemical and save locally
results['labels'] = results.mgperL
results_normalized = CalculateWeightedAverage(results)
results_normalized.to_pickle(f'../data/results/{name}_weighted_Avg_Training_data_final_model.zip', compression='zip')

In [None]:
# Also save the weighted-average results to Weights & Biases
art = wandb.Artifact(
            f"Weighted_Avg_Training_data_final_model_{name}", type="weighted_results_dataset",
            description=f"{name}",
            metadata={"source": f"Training_data_final_model_{name}",
                      "sizes": len(results_normalized)})

art.add_file(local_path=f'../data/results/{name}_weighted_Avg_Training_data_final_model.zip')

wandb.log_artifact(art)

### Plot results

In [None]:
fig = px.scatter(results_normalized, x='mgperL', y='preds', hover_data=['SMILES_Canonical_RDKit'], trendline='ols', trendline_color_override="black")

fig.update_traces(marker=dict(line_width=0.5, line_color='Black'))
fig.update_yaxes(title_text="Predicted Concentration [Log10(mg/L)]", range=[-4,4])
fig.update_xaxes(title_text="Actual Concentration [Log10(mg/L)]", range=[-4,4])
fig.update_xaxes(showline=True, linewidth=2, linecolor='grey')
fig.update_yaxes(showline=True, linewidth=2, linecolor='grey')
fig.update_layout(
    width=700,
    height=700,
    title=f'Predictions vs. actual labels, one per SMILES n={len(results_normalized)}')

fig.update_layout(
        font_family='Serif',
        font=dict(size=16), 
        plot_bgcolor='rgba(0, 0, 0, 0)')

fig.show()

wandb.log({'Prediction vs target (one per chemical)': fig}, commit=True)

In [None]:
wandb.finish()