In [127]:
# Install datasets as it is not already installed on colab
!pip install datasets



In [128]:
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM, AutoModelForTokenClassification
from datasets import Dataset
from sklearn.metrics import accuracy_score
import torch.nn.functional as F



In [129]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [130]:
from google.colab import drive
drive.mount('/content/drive')

# Change working directory to Project folder, you may change this as needed
%cd "/content/drive/MyDrive/Machine Learning (CS-433)/Project 2/BP_LM"

from data_preprocessing import *

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/Machine Learning (CS-433)/Project 2/BP_LM


In [141]:
file_path = 'dataset/dataset1.txt'

# Load dataset
df = pd.read_csv(file_path, sep='\t')

trunc_df = df[df['IVS_SIZE'] < 1000]
df = trunc_df.sample(n=min(1000, len(trunc_df)))

# Calculate BP_POS_WITHIN_STRAND
df['BP_POS_WITHIN_STRAND'] = df['IVS_SIZE'] + df['BP_ACC_DIST']


In [142]:
# Create a split based on chromosome types (Alis idea)
train_chrs = ["chr1", "chr2", "chr3", "chr4",
              "chr5","chr10",
              "chr11", "chr12", "chr13", "chr14",
              "chr15", "chr16", "chr17", "chr18",
              "chr19", "chr22",
              "chrX", "chrY"]

# Keep chr6 and chr7 in train if we want a 90/10/10 train/val/test split
test_chrs = ["chr8", "chr20", "chr6"]
val_chrs = ["chr9", "chr21", "chr7"]

train_df, test_df, val_df = split_train_test_on_chr(df, train_chrs, val_chrs, test_chrs, shuffle=True)

Chromosomes in train set: {'chr3', 'chr15', 'chr4', 'chr22', 'chrX', 'chr2', 'chr5', 'chrY', 'chr19', 'chr16', 'chr18', 'chr17', 'chr14', 'chr11', 'chr13', 'chr10', 'chr12', 'chr1'}
Chromosomes in validation set: {'chr9', 'chr21', 'chr7'}
Chromosomes in test set: {'chr20', 'chr8', 'chr6'}

Total data points: 1000
Train set contains 826 data points (82.60%)
Validation set contains 80 data points (8.00%)
Test set contains 94 data points (9.40%)


In [143]:
train_seqs, train_labels = extract_intron_seq_and_labels(train_df, max_model_input_size=1024, truncate=True)
test_seqs, test_labels = extract_intron_seq_and_labels(test_df, max_model_input_size=1024, truncate=True)
val_seqs, val_labels = extract_intron_seq_and_labels(val_df, max_model_input_size=1024, truncate=True)

In [144]:
SPLICEBERT_PATH = "models/SpliceBERT.1024nt"  # set the path to the folder of pre-trained SpliceBERT

# load tokenizer
tokenizer = AutoTokenizer.from_pretrained(SPLICEBERT_PATH)

# finetuning SpliceBERT for token classification tasks
model = AutoModelForTokenClassification.from_pretrained(SPLICEBERT_PATH, num_labels = 2) # We want binary classification on tokens so num_labels = 2

Some weights of BertForTokenClassification were not initialized from the model checkpoint at models/SpliceBERT.1024nt and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Now we do it for our data

In [145]:
# Tokenize the input data
train_seqs = [' '.join(list(seq.upper().replace("U", "T"))) for seq in train_seqs] #There shouldn't be any "U"s in the training data, but I have kept the replacer line
test_seqs = [' '.join(list(seq.upper().replace("U", "T"))) for seq in test_seqs]
val_seqs = [' '.join(list(seq.upper().replace("U", "T"))) for seq in val_seqs]

In [146]:
def pad_labels(labels, max_length, pad_label=-100):
    """
    Pads labels with -100 which is apparenty standard in HuggingFace
    """
    padded_labels = []
    for label in labels:
        if len(label) < max_length:
            padded = label + [pad_label] * (max_length - len(label))
        else:
            padded = label[:max_length]
        padded_labels.append(padded)
    return padded_labels

max_length = 1024  # Ensure this matches the tokenizer's max_length

train_labels_padded = pad_labels(train_labels, max_length)
test_labels_padded = pad_labels(test_labels, max_length)
val_labels_padded = pad_labels(val_labels, max_length)

train_ids = tokenizer(train_seqs, padding='max_length', padding_side='left', max_length=max_length, truncation=True)
test_ids = tokenizer(test_seqs, padding='max_length', padding_side='left', max_length=max_length, truncation=True)
val_ids = tokenizer(val_seqs, padding='max_length', padding_side='left', max_length=max_length, truncation=True)

# Create Datasets
train_dataset = Dataset.from_dict(train_ids)
train_dataset = train_dataset.add_column("labels", train_labels_padded)

test_dataset = Dataset.from_dict(test_ids)
test_dataset = test_dataset.add_column("labels", test_labels_padded)

val_dataset = Dataset.from_dict(val_ids)
val_dataset = val_dataset.add_column("labels", val_labels_padded)

# Set up the collator (I think it does padding)
from transformers import DataCollatorForTokenClassification, TrainingArguments, Trainer
data_collator = DataCollatorForTokenClassification(tokenizer)




In [None]:
model_name = SPLICEBERT_PATH.split("/")[-1]
batch_size = 8

def compute_metrics(pred):
    predictions, labels = pred
    predictions = predictions[0]

    predictions = np.array(predictions)
    labels = np.array(labels)

    preds = np.argmax(predictions, axis=-1)

    sequence_matches = 0
    total_sequences = 0

    for label, prediction in zip(labels, preds):
        nonpadded_indices = label != -100 # Only consider non-padded tokens
        nonpadded_labels = label[nonpadded_indices]
        preds = prediction[nonpadded_indices]


        if np.array_equal(nonpadded_labels, preds): # If the entire label matches, count it as correct
            sequence_matches += 1

        total_sequences += 1

    acc = sequence_matches / total_sequences if total_sequences > 0 else 0

    return {"accuracy": acc}


args = TrainingArguments(
    f"{model_name}-finetuned-secondary-structure",
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=3e-4,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    weight_decay=0.001,
    metric_for_best_model="accuracy",
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

trainer.train()

# Evaluate the model
evaluation_results = trainer.evaluate()
print(evaluation_results)


Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.016152,0.0
2,No log,0.016158,0.0


In [None]:
trained_model = AutoModelForTokenClassification.from_pretrained(f"{model_name}-finetuned-secondary-structure/checkpoint-2991") #make sure you are loading the right checkpoint
trained_model = trained_model.to(device)

In [140]:
showcase_seq = test_seqs[20]
showcase_ids = tokenizer.encode(showcase_seq) # N -> 5, A -> 6, C -> 7, G -> 8, T(U) -> 9. NOTE: a [CLS] and a [SEP] token will be added to the start and the end of seq
print(showcase_ids)
showcase_ids = torch.as_tensor(showcase_ids)
test_id = showcase_ids.unsqueeze(0)
test_id = test_id.to(device)

test_logit = trained_model(test_id, output_hidden_states=False).logits
test_probs = torch.sigmoid(test_logit)
class1_probs = test_probs[..., 1]
max_indices = class1_probs.argmax(dim=-1)
predicted_classes = torch.zeros_like(class1_probs)
predicted_classes[torch.arange(test_logit.size(0)), max_indices] = 1

predicted_classes = predicted_classes.squeeze(0)

print(predicted_classes)

print(sum(predicted_classes))
print(predicted_classes.argmax(dim = -1))

print(torch.as_tensor(train_labels[20][predicted_classes.argmax(dim = -1)]))


[2, 9, 7, 6, 8, 6, 6, 7, 9, 6, 7, 6, 8, 8, 9, 8, 6, 6, 6, 9, 9, 6, 7, 6, 8, 6, 8, 9, 6, 9, 6, 9, 9, 9, 6, 9, 6, 6, 6, 9, 7, 6, 9, 9, 7, 6, 9, 9, 6, 6, 6, 9, 9, 8, 7, 6, 8, 6, 9, 7, 9, 6, 6, 9, 6, 9, 6, 7, 6, 8, 6, 9, 7, 6, 6, 6, 9, 9, 9, 7, 9, 6, 7, 6, 6, 8, 9, 8, 9, 7, 6, 9, 6, 8, 7, 6, 8, 9, 8, 6, 6, 9, 9, 9, 9, 8, 9, 6, 9, 9, 9, 9, 9, 9, 9, 9, 9, 7, 9, 7, 7, 6, 9, 7, 7, 7, 9, 7, 9, 9, 8, 6, 8, 8, 9, 9, 9, 9, 8, 9, 8, 6, 7, 7, 7, 6, 9, 9, 9, 9, 7, 6, 9, 7, 6, 8, 6, 6, 6, 9, 6, 8, 9, 8, 6, 8, 9, 9, 8, 7, 9, 9, 7, 6, 8, 7, 6, 8, 9, 8, 6, 9, 9, 6, 9, 7, 6, 8, 6, 9, 9, 8, 8, 6, 8, 6, 8, 8, 6, 8, 6, 6, 7, 9, 8, 7, 7, 7, 7, 7, 6, 7, 7, 7, 9, 7, 6, 9, 8, 7, 7, 9, 7, 6, 9, 9, 7, 7, 6, 8, 6, 8, 6, 6, 9, 8, 9, 6, 9, 7, 6, 9, 6, 9, 8, 9, 9, 7, 9, 8, 8, 8, 8, 8, 6, 8, 9, 8, 7, 6, 9, 8, 9, 6, 6, 9, 9, 9, 8, 6, 6, 6, 6, 6, 8, 6, 9, 8, 7, 7, 9, 7, 8, 8, 9, 9, 6, 9, 9, 7, 9, 8, 6, 9, 6, 7, 7, 6, 8, 8, 7, 6, 9, 9, 9, 7, 9, 8, 7, 7, 9, 7, 7, 6, 9, 7, 9, 7, 9, 7, 6, 6, 8, 8, 6, 6, 7, 7, 6, 9, 6, 9, 7, 

IndexError: list index out of range