# General Notebook on fine-tuning branch point prediction using any of MultiMolecule models

Any RNA model can be selected from the [MultiMolecule](https://multimolecule.danling.org/models/) website. Simple change the `MULTIMOLECULE_MODEL` variable in the cell below, and the the two cells under the tokenizer "Load the desired model and tokenizer" section.

In [70]:
# GLOBAL VARIABLES
WORKING_DIRECTORY = '/content/drive/MyDrive/epfl_ml_project'
DATASET_PATH = 'dataset/dataset.txt'

# For some reason, the splicebert model doesn't work when this is set to 1024
# (even tho max input length in docs say 1024 is max length)
MODEL_MAX_INPUT_SIZE = 512

MULTIMOLECULE_MODEL = "splicebert"
SAMPLE_N_DATAPOINTS = 100 # Sample a small subset of data for testing purposes. Set to None if training on full dataset
SEED = 66 # TODO: update code to seed everything

In [71]:
# %%capture
!pip install datasets evaluate multimolecule



In [72]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import average_precision_score, PrecisionRecallDisplay, precision_recall_curve
from scipy.special import softmax
import torch
from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoModelForMaskedLM,
    AutoModelForTokenClassification,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer
)
from datasets import Dataset

In [73]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cpu'

In [74]:
from google.colab import drive
drive.mount('/content/drive')

# Change working directory to Project folder, you may change this as needed
%cd {WORKING_DIRECTORY}

from BP_LM.data_preprocessing import *

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/epfl_ml_project


## Load the desired model and tokenizer

In [76]:
# Change this import depending on the model
from multimolecule import RnaTokenizer, SpliceBertForTokenPrediction

In [77]:
tokenizer = RnaTokenizer.from_pretrained(f'multimolecule/{MULTIMOLECULE_MODEL}')
# Change line below depending on what model we want
model = SpliceBertForTokenPrediction.from_pretrained(f'multimolecule/{MULTIMOLECULE_MODEL}')

# TODO: REMOVE
MODEL_PATH = "models/SpliceBERT.1024nt"
# tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForTokenClassification.from_pretrained(MODEL_PATH, num_labels = 2) # We want binary classification on tokens so num_labels = 2

Some weights of SpliceBertForTokenPrediction were not initialized from the model checkpoint at multimolecule/splicebert and are newly initialized: ['splicebert.pooler.dense.bias', 'splicebert.pooler.dense.weight', 'token_head.decoder.bias', 'token_head.decoder.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertForTokenClassification were not initialized from the model checkpoint at models/SpliceBERT.1024nt and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### Load the dataset and prepare the labels and sequences

In [78]:
# Load dataset
df = pd.read_csv(DATASET_PATH, sep='\t')

# Calculate BP_POS_WITHIN_STRAND
df['BP_POS_WITHIN_STRAND'] = df['IVS_SIZE'] + df['BP_ACC_DIST']

In [79]:
# Remove all data points where the BP is farther than
df = df[df['IVS_SIZE'] - df['BP_POS_WITHIN_STRAND'] <= MODEL_MAX_INPUT_SIZE]

if SAMPLE_N_DATAPOINTS:
    df = df.sample(n = SAMPLE_N_DATAPOINTS, random_state=SEED)

print(df.shape)

(100, 13)


In [80]:
# Create a split based on chromosome types
train_chrs = ["chr1", "chr2", "chr3", "chr4",
              "chr5", "chr6", "chr7",
              "chr12", "chr13", "chr14",
              "chr15", "chr16", "chr17", "chr18",
              "chr19", "chr20", "chr21", "chr22",
              "chrX", "chrY"]

# Keep chr6 and chr7 in train if we want a 90/10/10 train/val/test split
test_chrs = ["chr8", "chr11"]
val_chrs = ["chr9", "chr10"]

train_df, test_df, val_df = split_train_test_on_chr(df, train_chrs, val_chrs, test_chrs, shuffle=True)

Chromosomes in train set: {'chr5', 'chr16', 'chr1', 'chr13', 'chr15', 'chr19', 'chr3', 'chr12', 'chr22', 'chr14', 'chr4', 'chr17', 'chrX', 'chr18', 'chr2', 'chr7', 'chr6'}
Chromosomes in validation set: {'chr10', 'chr9'}
Chromosomes in test set: {'chr11', 'chr8'}

Total data points: 100
Train set contains 82 data points (82.00%)
Validation set contains 10 data points (10.00%)
Test set contains 8 data points (8.00%)


In [81]:
df

Unnamed: 0,CHR,START,END,STRAND,GENE,TRANSCRIPT,IVS,IVS_SIZE,BP_POS,BP_ACC_DIST,BP_ACC_SEQ,IVS_SEQ,BP_POS_WITHIN_STRAND
176262,chrX,130131800,130133312,-,AIFM1,ENST00000287295,IVS13,1513,130131826,-27,ACAGAACACCCTCTTCTCTCTTCTTAG,GTAATAGTGGTGTGCCTTGGAGACCCAACTGATTACAACTCAGTAT...,1486
100861,chr11,67519165,67519792,-,CABP2,ENST00000294288,IVS6,628,67519185,-21,ACCTTGGCTTCGTCCTGGCAG,GTACCCCCTGCCGCACATAGCAACACACGCCCTGGAAGGGTCCTAG...,607
73639,chr7,142867723,142867996,+,EPHB6,ENST00000652003,IVS12,274,142867980,-17,ACCGTTTTGTTCCTCAG,GTGAGTCCCCACCCCTGCCCAACTCTGCCCAGCACCATTAACTCCA...,257
121273,chr14,69320869,69324691,+,GALNT16,ENST00000448469,IVS2,3823,69324671,-21,ACCTCTGTGCTCCCTTCTCAG,GTACGGCCTCCATCGTGTCAGTGGAGGAAGAAAGACTGAGAGAAAG...,3802
130591,chr16,1198775,1200255,+,CACNA1H,ENST00000348261,IVS6,1481,1200220,-36,ATTGTACCTTTTGGCCCTGGCTGTGCCCATCCCCAG,GTGCCCAGGCCCCACCCCCGTGAGGCCCCTGCCCAGATGGCCCTGC...,1445
...,...,...,...,...,...,...,...,...,...,...,...,...,...
176696,chrX,139745858,139750024,-,ATP11C,ENST00000682941,IVS24,4167,139745885,-28,ATGGTTTTTTCCCCCCATTTCTCAATAG,GTAAGTTAAATATATAGTTAAAACCCCCTAAGACTTTCAGTGTTGT...,4139
54611,chr5,137880866,137881972,+,MYOT,ENST00000239926,IVS5,1107,137881942,-31,ACGCATAGTTGTTACCAAAATATTCTTGTAG,GTAAAAAATTTTAATTTTAAAGAAATGTATGTTTTCCTATCTAAAA...,1076
115989,chr13,35836142,35839091,-,DCLK1,ENST00000360631,IVS7,2950,35836165,-24,AAAAAGTATTAATGATTCTTTCAG,GTGATGAGCTTGGGACTTGGTATCTGAGGAGGCAGAGGATGGGTTG...,2926
68747,chr7,54752421,54755778,-,SEC61G,ENST00000352861,IVS3,3358,54752440,-20,AATTTTTCTTTCTTTTACAG,GTAAGTAAACTTTATGAAATAGACTAGGATTGAATGAAACTCAAAA...,3338


In [82]:
train_df

Unnamed: 0,CHR,START,END,STRAND,GENE,TRANSCRIPT,IVS,IVS_SIZE,BP_POS,BP_ACC_DIST,BP_ACC_SEQ,IVS_SEQ,BP_POS_WITHIN_STRAND
0,chrX,49076525,49076644,-,WDR45,ENST00000376372,IVS5,120,49076555,-31,ACGGCCATCCTGTGTTGTGTGATACCCACAG,GTGAGCCTGAGGAGGACCGGGGTGGGAGGTAGGAGGTCCCCACAGT...,89
1,chrX,130131800,130133312,-,AIFM1,ENST00000287295,IVS13,1513,130131826,-27,ACAGAACACCCTCTTCTCTCTTCTTAG,GTAATAGTGGTGTGCCTTGGAGACCCAACTGATTACAACTCAGTAT...,1486
2,chr6,73763686,73765929,+,CD109,ENST00000287097,IVS10,2244,73765907,-23,AGATGTTTTGTTTTGACCATTAG,GTAAGTTGGTATATTTATTTCCAGTCCATAGCAGTAAGTTCAGCTT...,2221
3,chr3,155140293,155141990,+,MME,ENST00000360490,IVS10,1698,155141965,-26,AATGTTTCATGCCTGCTTTTTTCCAG,GTAAGTGGTAAGTTTTTTGTGCTCTCTTATTGTGCCGTTTTCTAAA...,1672
4,chr14,23054867,23055058,-,CDH24,ENST00000487137,IVS3,192,23054891,-25,AATGGCGTGCTGGCATCCCCCACAG,GTGAGCACCCCAGCTCTGAATGCCCCAGTTCCCGTCTTCTGGAGCT...,167
...,...,...,...,...,...,...,...,...,...,...,...,...,...
77,chr7,90162079,90164476,+,STEAP1,ENST00000297205,IVS4,2398,90164453,-24,ACCAATTTTGTTTTTCTTTTGCAG,GTAAATAATATATAAAATAACCCTAAGAGGTAAATCTTCTTTTTGT...,2374
78,chr4,150850903,150851884,-,LRBA,ENST00000651943,IVS23,982,150850924,-22,ACTTTTTAAATTTCTGTATTAG,GTGATTTTATATATTGTCTTATTCAAGTACTTTTGCACTGAGTTAA...,960
79,chr18,21657153,21659169,-,ABHD3,ENST00000289119,IVS6,2017,21657176,-24,ACTTCCGAAATTCTTTTGTTCCAG,GTGAGTCATCTTTAAAATCTGTTGTCTCCAGGGCCGGGCCCGGTGG...,1993
80,chr16,57134827,57137148,+,CPNE2,ENST00000290776,IVS13,2322,57137125,-24,ATCCAGACTCTTCTCCCGAGGCAG,GTGAGTGTCAGACCCACCTGCAGCTGCCCTGTGTTTGCTACGGGCC...,2298


In [83]:
train_seqs, train_labels = extract_intron_seq_and_labels(train_df, max_model_input_size=MODEL_MAX_INPUT_SIZE, truncate=True)
test_seqs, test_labels = extract_intron_seq_and_labels(test_df, max_model_input_size=MODEL_MAX_INPUT_SIZE, truncate=True)
val_seqs, val_labels = extract_intron_seq_and_labels(val_df, max_model_input_size=MODEL_MAX_INPUT_SIZE, truncate=True)

In [84]:
# Show an example data pair
print(train_seqs[0])
print(train_labels[0])

GTGAGCCTGAGGAGGACCGGGGTGGGAGGTAGGAGGTCCCCACAGTAAGTGAGAGGGATAGTCCTCCCTGGGCATCCCGCCACCCCCTCACGGCCATCCTGTGTTGTGTGATACCCACAG
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


### Tokenize the training data and put it into the dataset format

In [85]:
# Tokenize the input data
# Shouldn't be any U's in the data anyways
train_seqs = [' '.join(list(seq.upper().replace("U", "T"))) for seq in train_seqs]
test_seqs = [' '.join(list(seq.upper().replace("U", "T"))) for seq in test_seqs]
val_seqs = [' '.join(list(seq.upper().replace("U", "T"))) for seq in val_seqs]

In [86]:
train_ids = tokenizer(train_seqs) #, padding = 'max_length', padding_side = 'left', max_length = 1024)
test_ids = tokenizer(test_seqs) #, padding = 'max_length', padding_side = 'left', max_length = 1024)
val_ids = tokenizer(val_seqs) #, padding = 'max_length', padding_side = 'left', max_length = 1024)

In [87]:
# Build the dataset structure that will be passed for training
train_dataset = Dataset.from_dict(train_ids)
train_dataset = train_dataset.add_column("labels", train_labels)

test_dataset = Dataset.from_dict(test_ids)
test_dataset = test_dataset.add_column("labels", test_labels)

val_dataset = Dataset.from_dict(val_ids)
val_dataset = val_dataset.add_column("labels", val_labels)

In [88]:
# Set up the collator (I think it does padding)
data_collator = DataCollatorForTokenClassification(tokenizer) #Unsure about how data_collator pads exactly so i specified padding in the tokenizer itself


In [89]:
def compute_accuracy(labels, categorical_predictions):
  sequence_matches = 0
  total_sequences = 0

  all_labels = []
  all_preds = []

  for label, cat_preds in zip(labels, categorical_predictions):
      # Ignore padded tokens
      cat_preds = cat_preds[label != -100]
      label = label[label != -100]

      # Sequence-level accuracy
      if np.array_equal(label, cat_preds):  # Entire sequence matches
          sequence_matches += 1
      total_sequences += 1

      # Token-level metrics
      all_labels.extend(label.tolist())
      all_preds.extend(cat_preds.tolist())

  # Sequence-level accuracy
  sequence_accuracy = sequence_matches / total_sequences if total_sequences > 0 else 0
  return sequence_accuracy


def compute_metrics(eval_pred):
    """
    Function to simultaneously evaluate accuracy, F1 and average precision (AP)

    The function does the evaluation per label for accuracy, and per token otherwise.

    average precision is the most interesting as it accounts for the fact that the
    ideal decision boundary may be something non trivial.

    F1 and accuracy are reported at the decision boundary which maximises the F1
    """
    raw_predictions, labels = eval_pred

    logits = raw_predictions # Discard hidden states and keep logits

    # Find predictions from logits
    predictions = softmax(logits, axis=2)[:,:,1] # Probability of a positive label

    # Reshape predictions and labels into long strings to compute metrics per token
    predictions_flat = predictions.reshape((-1,))
    labels_flat = labels.reshape((-1,))

    # Remove all the padded ones
    predictions_flat = predictions_flat[labels_flat!=-100]
    labels_flat = labels_flat[labels_flat!=-100]

    # Compute average precision
    AP = average_precision_score(labels_flat, predictions_flat)

    # Plot precision curves
    precision, recall, thresholds = precision_recall_curve(labels_flat, predictions_flat)
    fig, ax = plt.subplots(dpi = 300, figsize = (5,3))
    ax.set_ylabel("Precision")
    ax.set_xlabel("Recall")
    ax.set_title("Precision-Recall Curve at final epoch: splicebert", fontsize = 12)
    ax.plot(recall, precision)
    fig.savefig("Precision-Recall Curve at final epoch: splicebert.png")

    np.savetxt("pr_curve.txt", np.vstack((precision,recall)).T)

    # Compute ideal boundary and optimized F1
    ideal_threshold = thresholds[np.argmax(2 * (precision * recall) / (precision + recall))]
    F1 = np.max(2 * (precision * recall) / (precision + recall))

    # Calculate accuracy
    categorical_predictions = np.where(predictions>ideal_threshold, 1, 0)
    accuracy = compute_accuracy(labels, categorical_predictions)

    # Combine metrics
    dictionary = {"F1" : F1} | {"Accuracy" : accuracy} | {"AP" : AP} | {"ideal_threshold" : ideal_threshold}

    # Save the performance metrics to a text file
    with open('performance_metrics.txt', 'w') as f:
      print(dictionary, file=f)

    return dictionary

# Hidden states of all test samples is order 260GB
# Get rid of these during evaluation to not run out of RAM
def preprocess_logits_for_metrics(logits, labels):
  """
  The metric function needs the logits of all samples in the test set loaded in RAM
  simultaneously to evaluate performance. Model output contains both logits and
  Hidden states, hidden states take up much more space.

  We inject this function to the training loop to delete hidden states while
  evaluating the test samples. This frees tons of memory.
  """
  logits = logits[0] # Discard hidden states
  return logits

In [90]:
# Do not save to W&B
import os
os.environ["WANDB_MODE"] = "disabled"

In [91]:
# Define model training parameters
batch_size = 20

args = TrainingArguments(
    f"multimolecule-{MULTIMOLECULE_MODEL}-finetuned-secondary-structure",
    eval_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=3e-4,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=1,
    optim = "adamw_torch",
    weight_decay=0.001,
    load_best_model_at_end=True,
    metric_for_best_model="F1",
    #eval_accumulation_steps = 10,
    #push_to_hub=True,
)

In [None]:
trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics = preprocess_logits_for_metrics,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

trainer.train()


  trainer = Trainer(


In [None]:
val_df