# Master Training Notebook
To use this notebook, configure the global variables in the first cell to choose which model to train and which hyperparameters to train it with.

In [1]:
# --- Model and dataset selection ---
# Supported MULTIMOLECULE_MODEL values: rnafm, rnamsm, ernierna, utrlm-te_el, splicebert, rnabert.
MULTIMOLECULE_MODEL = "splicebert"
# Number of datapoints to sample; None means the full dataset is used.
SAMPLE_N_DATAPOINTS = None
# Seed handed to dataset creation and to the training arguments.
SEED = 32

# --- Training hyperparameters ---
BATCH_SIZE = 16
LEARNING_RATE = 3e-4
TRAIN_EPOCHS = 3
WEIGHT_DECAY = 0.001
OPTIMIZER = "adamw_torch"
# Output directory name is derived from the selected model.
MODEL_OUTPUT_DIRECTORY = f"multimolecule-{MULTIMOLECULE_MODEL}-finetuned-secondary-structure"
# Steps between successive loggings of eval metrics (use a large value to reduce training time).
LOGGING_STEPS = 1000

In [2]:
# Project root on the mounted Google Drive; the notebook changes into this directory.
WORKING_DIRECTORY = "/content/drive/MyDrive/epfl_ml_project"
# Dataset file (read as tab-separated), given as a path relative to the working directory.
DATASET_PATH = "data/fresh_dataset.txt"

In [3]:
%%capture
!pip install datasets evaluate multimolecule==0.0.5

In [4]:
import os
import pandas as pd
import torch
from transformers import (
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer
)
from datasets import Dataset
from google.colab import drive

import matplotlib.pyplot as plt

In [None]:
drive.mount('/content/drive')
%cd {WORKING_DIRECTORY}

In [None]:
# Project-local helpers from the BP_LM package.
# NOTE(review): the wildcard import below hides which names the notebook relies on; none of
# the cells visible here appear to use a name from it. Consider importing the needed names
# explicitly — TODO confirm against the rest of the notebook.
from BP_LM.scripts.data_preprocessing import *
from BP_LM.scripts.trainer_datasets_creation import create_dataset
from BP_LM.scripts.compute_metrics import compute_metrics, precision_recall_data
from BP_LM.scripts.model_choice import set_multimolecule_model

# Set WANDB_MODE so that Weights & Biases logging is disabled for this session.
os.environ["WANDB_MODE"] = "disabled"

# Report which device label applies: "cuda" when torch sees a CUDA device, "cpu" otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

In [None]:
# Load the chosen multimolecule model together with its tokenizer and maximum input size.
model, tokenizer, MODEL_MAX_INPUT_SIZE = set_multimolecule_model(MULTIMOLECULE_MODEL)

# Read the tab-separated dataset and build the train / validation / test datasets from it.
df = pd.read_csv(DATASET_PATH, sep="\t")
train_dataset, val_dataset, test_dataset = create_dataset(
    df,
    tokenizer,
    model,
    MODEL_MAX_INPUT_SIZE,
    SEED,
    SAMPLE_N_DATAPOINTS,
)

# Collator used by the Trainer to assemble token-classification batches.
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

In [None]:
# Run configuration for the Hugging Face Trainer.
# Checkpointing is switched off (save_strategy="no"); the model is saved once after training.
# eval_steps is not set here; per the TrainingArguments docs it falls back to logging_steps.
training_args = TrainingArguments(
    output_dir=MODEL_OUTPUT_DIRECTORY,
    seed=SEED,
    # Optimisation
    optim=OPTIMIZER,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    num_train_epochs=TRAIN_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    # Evaluation, logging and checkpointing cadence
    eval_strategy="steps",
    logging_strategy="steps",
    logging_steps=LOGGING_STEPS,
    save_strategy="no",
)


def metrics_fn(eval_prediction):
    """Pass the Trainer's evaluation output to compute_metrics with "test_metrics" as second argument."""
    return compute_metrics(eval_prediction, "test_metrics")


# Wire model, data and metrics together; evaluation during training uses the validation split.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=metrics_fn,
)

trainer.train()

# No checkpoints are written during training, so persist the final model explicitly.
final_model_dir = os.path.join(MODEL_OUTPUT_DIRECTORY, "final_model")
trainer.save_model(final_model_dir)

## Plot Metrics Over Course of Training

In [None]:
# Keep only the log records produced by evaluation runs (those carrying an 'eval_loss' key).
log_history = trainer.state.log_history
eval_entries = list(filter(lambda entry: 'eval_loss' in entry, log_history))

# One list per metric, in the column order written to the CSV.
metric_columns = ('step', 'eval_loss', 'eval_seq_accuracy', 'eval_F1')
data = {column: [entry[column] for entry in eval_entries] for column in metric_columns}

df = pd.DataFrame(data)
df.to_csv(f'{MODEL_OUTPUT_DIRECTORY}/eval_metrics.csv')

fig, ax1 = plt.subplots(figsize=(12, 6))

# Left y-axis: evaluation loss against training step.
ax1.plot(df['step'], df['eval_loss'], color="tab:blue", linestyle='-', label='Eval Loss')
ax1.set_xlabel('Step')
ax1.set_ylabel('Eval Loss')
ax1.tick_params(axis='y')
ax1.set_xlim(left=0)

# Right y-axis (shared x): sequence accuracy and F1, both drawn on a fixed 0-1 range.
ax2 = ax1.twinx()
ax2.set_ylabel('Accuracy & F1')
for column, colour, label in (
    ('eval_seq_accuracy', 'tab:orange', 'Eval Seq Accuracy'),
    ('eval_F1', "tab:green", 'Eval F1'),
):
    ax2.plot(df['step'], df[column], color=colour, linestyle='-', label=label)
ax2.tick_params(axis='y')
ax2.set_ylim(0, 1)

# Merge the handles of both axes into a single legend.
handles_1, labels_1 = ax1.get_legend_handles_labels()
handles_2, labels_2 = ax2.get_legend_handles_labels()
ax1.legend(handles_1 + handles_2, labels_1 + labels_2, loc='upper right')

plt.title(f'Evaluation Metrics per step for {MULTIMOLECULE_MODEL}')

fig.tight_layout()

plt.show()

In [None]:
# Take the most recent entry of precision_recall_data and split its pairs into two series:
# the first element of each pair is plotted as precision, the second as recall.
last_pr_curve = precision_recall_data[-1]
precision, recall = zip(*last_pr_curve)

fig, ax = plt.subplots(figsize=(10, 6), dpi=100)
ax.plot(recall, precision, color='b', linestyle='--', label='Precision-Recall Curve')

# Title and axis labels share the same font size.
ax.set_title(f"Precision-Recall Curve for {MULTIMOLECULE_MODEL}", fontsize=12)
ax.set_xlabel("Recall", fontsize=12)
ax.set_ylabel("Precision", fontsize=12)

ax.grid(True, linestyle='--', alpha=0.7)
fig.tight_layout()
plt.show()