In [None]:
from transformers import BertForSequenceClassification, BertTokenizer, BertForMaskedLM, AdamW
from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np
import matplotlib.pyplot as plt
import re
from scipy.spatial.distance import jensenshannon
import os

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
training_path = '/content/drive/MyDrive/IGI BERT/final_23S_LM_Train.fa'
test_path = '/content/drive/MyDrive/IGI BERT/final_23S_LM_Test.fa'

training_data = open(training_path)
test_data = open(test_path)

read_training_data = training_data.read()

read_training_data = read_training_data[:-1]
condensed_training_string = read_training_data.replace('\n', '') #removes the new line character
print(len(condensed_training_string))

print(condensed_training_string[:10000])

89205523
>RS_GCF_000217715.1UGAACGUGGCUACUGUGCCAGCUGGUGGAUCGCUCGGCUUGAGAGCUGAAGAAGGACGUGCCAAGCUGCGAUAAGCCUCAGGGACCCGCACGGAGGGGAAGAACUGAGGAUUUCCGAAUGGGAAUCCCCACCGCAAUUGCUUCGCGCAAUGGGGAACGCCGAGAAUUGAAACAUCUUAGUAUCGGCAGGAAGAGAAAACGUAACCGUGAUGUCGUUAGUAACGGCGAGUGAACGCGACACAGUCCAAACCGAAGCCUUCGGGCAAUGUGGUGUUCGGACUGACAACCACUCCGCGAAAGGCUGCAAGAAGUCUCUUGGAAUAGAGCACGAAACAGGGUGAUAGUCCCGUACUGCAGGCGAGUAUCGGACGCGUCAGCUCCAGAGUAUCGGGGGUUGGAUAUCCCUCGUGAAUAUUGCAGGCAUCGACUGCAAAGACUAAACACUCCUCAAGACCGAUAGCGAACAAGUAGCGUGAGCGAACGCUGAAAAGCACCCCACGAAGGGAGGUGCAAUAGGGCGUGAAAUCAGUUGGCGAUCGAACGACGGGGCAUACAAGGUCCUCUACACAAUGACCGUAGCGCGAGCUACCAGUAAGAAGUAGAGGAAGCCGAUGUUCCGUCGUACGUUUUGAAAAACGAACCAGGGAGUGUGUCUGAUUGGCGAGUCUAACCUGAUUAUCAGGGAAGGCGUAGGGAAACCGACAUGGCCGCAGCAUUGCGAGGGCCGCCGUGUUCAAGCGCGGGGAGUCAAUCGGACACGACCCGAAACCGGAUGAUCUAGACAUGGGCAAGACGAAGCGUGCCGAAAGGCACGUGGAGGUCUGCUAGCGUUGGUGUCCUACAAUACCCUCGCGUGACCUAUGUCUAGGGGUGAAAGGCCCAUCGAAUCCGGAAACAGCUGGUUCCGACCGAAACAUGUCGAAGCAUGACCUCUGCCGAGGUAGUCUGUGGGGUAGAGCGACGGAUUGGGGGA

In [None]:
#cleaning the data -> getting rid of the >RS_GCF_000217715.1 and its equivalents
cleaning_pattern = r">[A-Z]+_[A-Z]+_[0-9]+.[0-9](\w+)"
read_training_data_lst = re.findall(cleaning_pattern, condensed_training_string)
print(len(read_training_data_lst))

30486


In [None]:
read_training_data_lst = read_training_data_lst[0:6000]
print(len(read_training_data_lst))
# considering only a small sample for training purpose, lack of compute

6000


In [None]:
def m_tokenizer(string, size):
    '''
    Inputs:
    -> string: the string that you want tokenized
    -> size: this function only accepts one, two or three. Pass this is in as a string format and not integer
    '''
    tokenized_lst = []

    if size.lower() == 'one':
        single_token_lst = []
        for i in range(len(string)):
            single_token = string[i]
            single_token_lst.append(single_token)
        tokenized_lst = single_token_lst


    elif size.lower() == 'two':
        double_split_lst = []
        for i in range(len(string)):
            double_token = string[i: i+2]
            if len(double_token) == 2:
                double_split_lst.append(double_token)
        tokenized_lst = double_split_lst


    elif size.lower() == 'three':
        triple_split_lst = []
        for i in range(len(string)):
            triple_token = string[i: i+3]
            if len(triple_token) == 3:
                triple_split_lst.append(triple_token)
        tokenized_lst = triple_split_lst

    return np.array(tokenized_lst)

In [None]:
single_tokens = []
for i in read_training_data_lst:
    single_tokens.append(m_tokenizer(i, 'one'))

In [None]:
double_tokens = []
for i in read_training_data_lst:
    double_tokens.append(m_tokenizer(i, 'two'))

In [None]:
triple_tokens = []
for i in read_training_data_lst:
    triple_tokens.append(m_tokenizer(i, 'three'))

In [None]:
print(single_tokens)

[array(['U', 'G', 'A', ..., 'A', 'C', 'U'], dtype='<U1'), array(['U', 'G', 'A', ..., 'A', 'C', 'U'], dtype='<U1'), array(['G', 'G', 'A', ..., 'U', 'G', 'C'], dtype='<U1'), array(['G', 'A', 'G', ..., 'C', 'U', 'C'], dtype='<U1'), array(['A', 'A', 'U', ..., 'A', 'U', 'U'], dtype='<U1'), array(['U', 'G', 'A', ..., 'A', 'U', 'A'], dtype='<U1'), array(['U', 'G', 'G', ..., 'A', 'G', 'C'], dtype='<U1'), array(['G', 'G', 'A', ..., 'U', 'G', 'C'], dtype='<U1'), array(['U', 'C', 'A', ..., 'U', 'G', 'A'], dtype='<U1'), array(['G', 'C', 'C', ..., 'G', 'G', 'G'], dtype='<U1'), array(['C', 'C', 'A', ..., 'A', 'A', 'U'], dtype='<U1'), array(['C', 'C', 'A', ..., 'A', 'C', 'U'], dtype='<U1'), array(['G', 'G', 'G', ..., 'C', 'U', 'A'], dtype='<U1'), array(['G', 'G', 'G', ..., 'C', 'C', 'C'], dtype='<U1'), array(['C', 'U', 'A', ..., 'U', 'A', 'G'], dtype='<U1'), array(['C', 'A', 'C', ..., 'U', 'C', 'A'], dtype='<U1'), array(['C', 'G', 'A', ..., 'U', 'C', 'A'], dtype='<U1'), array(['G', 'C', 'C', ..., 'G'

In [None]:
print(double_tokens)

[array(['UG', 'GA', 'AA', ..., 'CA', 'AC', 'CU'], dtype='<U2'), array(['UG', 'GA', 'AA', ..., 'CA', 'AC', 'CU'], dtype='<U2'), array(['GG', 'GA', 'AG', ..., 'CU', 'UG', 'GC'], dtype='<U2'), array(['GA', 'AG', 'GC', ..., 'GC', 'CU', 'UC'], dtype='<U2'), array(['AA', 'AU', 'UC', ..., 'CA', 'AU', 'UU'], dtype='<U2'), array(['UG', 'GA', 'AA', ..., 'CA', 'AU', 'UA'], dtype='<U2'), array(['UG', 'GG', 'GC', ..., 'CA', 'AG', 'GC'], dtype='<U2'), array(['GG', 'GA', 'AC', ..., 'CU', 'UG', 'GC'], dtype='<U2'), array(['UC', 'CA', 'AA', ..., 'UU', 'UG', 'GA'], dtype='<U2'), array(['GC', 'CC', 'CC', ..., 'AG', 'GG', 'GG'], dtype='<U2'), array(['CC', 'CA', 'AA', ..., 'CA', 'AA', 'AU'], dtype='<U2'), array(['CC', 'CA', 'AA', ..., 'CA', 'AC', 'CU'], dtype='<U2'), array(['GG', 'GG', 'GC', ..., 'GC', 'CU', 'UA'], dtype='<U2'), array(['GG', 'GG', 'GG', ..., 'CC', 'CC', 'CC'], dtype='<U2'), array(['CU', 'UA', 'AC', ..., 'GU', 'UA', 'AG'], dtype='<U2'), array(['CA', 'AC', 'CU', ..., 'GU', 'UC', 'CA'], dtype

In [None]:
#sanity check
print(len(single_tokens) == len(read_training_data_lst))
print(len(double_tokens) == len(read_training_data_lst))
print(len(triple_tokens) == len(read_training_data_lst))

True
True
True


In [None]:
class NucleotideDataset(Dataset):
    def __init__(self, tokenized_data, tokenizer, max_len = 512):
        '''
        tokenized_data -> The tokenized RNA Sequence
        tokenizer -> BERT Tokenizer
        max_len -> max length of the sequence
        Note: any sequences longer than the max_len will be truncated, and longer ones will be padded
        '''
        self.tokenized_data = tokenized_data
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.tokenized_data)

    def __getitem__(self, idx):
        tokens = self.tokenized_data[idx] # get the tokenized sequence at a particular index
        sequence = ' '.join(tokens)  # joins tokens into a single string

        # Tokenizing the sequence using BERT tokenizer
        inputs = self.tokenizer(sequence, max_length=self.max_len, padding='max_length', truncation=True, return_tensors="pt")
        input_ids = inputs['input_ids'].squeeze(0)
        attention_mask = inputs['attention_mask'].squeeze(0)

        return input_ids, attention_mask

In [None]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')

dataset = NucleotideDataset(single_tokens, tokenizer)
train_loader = DataLoader(dataset, batch_size=10, shuffle=True)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.train()

optimizer = AdamW(model.parameters(), lr=5e-5)

In [None]:
import random

def mask_tokens(inputs, tokenizer, mlm_probability=0.15):
    """
    Prepare masked tokens inputs/labels for masked language modeling.
    """
    device = inputs.device  # Ensure everything is on the same device
    labels = inputs.clone().to(device)
    probability_matrix = torch.full(labels.shape, mlm_probability, device=device)

    # Create the special tokens mask on the same device
    special_tokens_mask = [
        tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.cpu().tolist()
    ]
    special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool, device=device)

    probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100  # We only compute loss on masked tokens

    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8, device=device)).bool() & masked_indices
    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5, device=device)).bool() & masked_indices & ~indices_replaced
    random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long, device=device)
    inputs[indices_random] = random_words[indices_random]

    return inputs, labels

In [None]:
checkpoint_dir = '/content/drive/MyDrive/IGI BERT/Checkpoints'

def save_checkpoint(epoch, model, optimizer, loss, path):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, path)


def load_checkpoint(path, model, optimizer):
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    return epoch, loss

In [None]:
checkpoint_path = '/content/drive/MyDrive/checkpoints/latest_checkpoint.pt'
if os.path.exists(checkpoint_path):
    start_epoch, _ = load_checkpoint(checkpoint_path, model, optimizer)
    start_epoch += 1  # Continue from the next epoch
else:
    start_epoch = 0

epochs = 3
for epoch in range(start_epoch, epochs):
    model.train()
    total_loss = 0
    for step, batch in enumerate(train_loader):
        input_ids, attention_mask = batch
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        true_labels = input_ids.clone()  # Store the original input as true labels

        inputs, labels = mask_tokens(input_ids, tokenizer)  # Masked tokens and labels
        outputs = model(input_ids=inputs, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        total_loss += loss.item()
        if step % 10 == 0:
            print(f"Epoch: {epoch + 1}, Step: {step}, Loss: {loss.item()}")
            #clear_memory()  # Clear memory periodically

        # Save checkpoint periodically
        if step % 100 == 0:
            checkpoint_path = os.path.join(checkpoint_dir, 'latest_checkpoint.pt')
            save_checkpoint(epoch, model, optimizer, total_loss / (step + 1), checkpoint_path)

    print(f"Epoch: {epoch + 1}, Average Loss: {total_loss / len(train_loader)}")
    #clear_memory()

Epoch: 1, Step: 0, Loss: 3.098503351211548
Epoch: 1, Step: 10, Loss: 1.7019952535629272
Epoch: 1, Step: 20, Loss: 1.0696778297424316
Epoch: 1, Step: 30, Loss: 0.6610444784164429
Epoch: 1, Step: 40, Loss: 0.36047548055648804
Epoch: 1, Step: 50, Loss: 0.20060035586357117
Epoch: 1, Step: 60, Loss: 0.18791206181049347
Epoch: 1, Step: 70, Loss: 0.13093963265419006
Epoch: 1, Step: 80, Loss: 0.1496947556734085
Epoch: 1, Step: 90, Loss: 0.1448204070329666
Epoch: 1, Step: 100, Loss: 0.09531091153621674
Epoch: 1, Step: 110, Loss: 0.07837588340044022
Epoch: 1, Step: 120, Loss: 0.09920457750558853
Epoch: 1, Step: 130, Loss: 0.07166744768619537
Epoch: 1, Step: 140, Loss: 0.11320969462394714
Epoch: 1, Step: 150, Loss: 0.09628915786743164
Epoch: 1, Step: 160, Loss: 0.058947596698999405
Epoch: 1, Step: 170, Loss: 0.06366852670907974
Epoch: 1, Step: 180, Loss: 0.06312243640422821
Epoch: 1, Step: 190, Loss: 0.03142556920647621
Epoch: 1, Step: 200, Loss: 0.07776006311178207
Epoch: 1, Step: 210, Loss: 0.0

In [None]:
test_path = '/content/drive/MyDrive/IGI BERT/final_23S_LM_Test.fa'

training_data = open(training_path)
test_data = open(test_path)

read_test_data = test_data.read()

read_test_data = read_test_data[:-1]
condensed_test_string = read_test_data.replace('\n', '') #removes the new line character


#cleaning the data -> getting rid of the >RS_GCF_000217715.1 and its equivalents
cleaning_pattern = r">[A-Z]+_[A-Z]+_[0-9]+.[0-9](\w+)"
read_test_data_lst = re.findall(cleaning_pattern, condensed_test_string)
print(len(read_test_data_lst))

triple_test_tokens = []
for i in read_test_data_lst:
    triple_test_tokens.append(m_tokenizer(i, 'three'))


test_dataset = NucleotideDataset(triple_test_tokens, tokenizer)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)

1831


In [None]:
import torch
from scipy.spatial.distance import jensenshannon
import numpy as np

def calculate_jsd(predictions, true_labels, vocab_size):
    """Calculate the Jensen-Shannon Divergence between the predictions and the true labels."""
    predictions = torch.softmax(predictions, dim=-1).cpu().detach().numpy()

    # One-hot encode the true labels
    true_labels_one_hot = np.zeros((true_labels.size(0), true_labels.size(1), vocab_size))
    true_labels_one_hot[np.arange(true_labels.size(0))[:, None], np.arange(true_labels.size(1)), true_labels.cpu().numpy()] = 1

    jsd = np.mean([jensenshannon(pred, true, axis=-1) for pred, true in zip(predictions, true_labels_one_hot)])
    return jsd

model.eval()
with torch.no_grad():
    total_jsd = 0
    count = 0
    for step, batch in enumerate(test_loader):
        input_ids, attention_mask = batch
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits

        jsd = calculate_jsd(logits, input_ids, logits.size(-1))

        total_jsd += jsd
        count += 1

    average_jsd = total_jsd / count
    print(f"Average Jensen-Shannon Divergence on Test Data: {average_jsd}")

Average Jensen-Shannon Divergence on Test Data: 0.0050884016261092384
