In [1]:
!pip install transformers



In [1]:
import pickle
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, TensorDataset
from transformers import BertForTokenClassification
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
from transformers import BertTokenizer
import torch
from torch.nn.utils.rnn import pad_sequence
from sklearn.metrics import accuracy_score, classification_report

In [2]:
# load dataset
with open("val_dataset.pkl", "rb") as f:
    val_ds = pickle.load(f)

In [3]:
dir(val_ds)

['__class__',
 '__class_getitem__',
 '__contains__',
 '__delattr__',
 '__delitem__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__ior__',
 '__iter__',
 '__le__',
 '__len__',
 '__lt__',
 '__ne__',
 '__new__',
 '__or__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__reversed__',
 '__ror__',
 '__setattr__',
 '__setitem__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 'clear',
 'copy',
 'fromkeys',
 'get',
 'items',
 'keys',
 'pop',
 'popitem',
 'setdefault',
 'update',
 'values']

In [4]:
print(type(val_ds))

# keys in the dictionary
print(val_ds.keys())

# inspect contents
for key in val_ds:
    print(f"Key: {key}")
    print(f"Value: {val_ds[key]}")

<class 'dict'>
dict_keys(['sequences', 'structures', 'nt_to_idx', 'struct_to_idx'])
Key: sequences
Value: ['GGGGCCAUAGCUCAGCUGGGAGAGCGCUUGCAUGGCAUGCAAGAGGUCGGCGGUUCGAUCCCGCCUGGCUCCACCA', 'AAAACCAAUAUUACUACCAUCGGUCCUCACCACUAGAUCGGUGUUAUGCUUUGUUGGGAUAGCAGGCCGUGCCAGUUGGACAGCCAAGGUCCACCUCUGGUUCGGCACACAU', 'AUAUAUCAUUGUAUUUUGCCUCGUAUAAUUGUGGAGAUAUGGUCCACGAGUUUCUACCGGACGGCCGUAAACUGUUCGACUACGGGGGAAACACCUGGGAUG', 'UAGAGUUUGAUUAUGGCUCAGAACGAACGCUGGCGGCAGGCUUAACACAUGCAAGUCGAACGCCCCGCAAGGGGAGUGGCAGACGGGUGAGUAACGCGUGGGAACAUACCCUUUCCUGCGGAAUAGSUCCGGGAAACUGGAAUUAAUACCGCAUACGCCCUACGGGGGAAAGAUUUAUCGGGGAAGGAUUGGCCCGCGUUGGAUUAGCUAGUUGGUGGGGUAAAGGCCUACCAAGGCGACGAUCCAUAGCUGGUCUGAGAGGAUGAUCAGCCACAUUGGGACUGAGACACGGCCCAAACUCCUACGGGAGGCAGCAGUGGGGAAUAUUGGACAAUGGGCGCAAGCCUGAUCCAGCCAUGCCGCGUGAGUGAUGAAGGCCUUAGGGUUGUAAAGCUCUUUCACCGGAGAAGAUAAUGACGGUAUCCGGAGAAGAAGCCCCGGCUAACUUCGUGCCAGCAGCCGCGGUAAUACGAAGGGGGCUAGCGUUGUUCGGAAUUACUGGGCGUAAAGCGCACGUAGGCGGAUAUUUAAGUCAGGGGUGAAAUCCCAGAGCUCAACUCUGGAACUGCCUUUGAUACUGGGUAUCUUGA

In [5]:
# get sequences and structures
sequences = val_ds['sequences']
structures = val_ds['structures']

print("first 3 sequences:", sequences[:3])
print("first 3 structures:", structures[:3])

first 3 sequences: ['GGGGCCAUAGCUCAGCUGGGAGAGCGCUUGCAUGGCAUGCAAGAGGUCGGCGGUUCGAUCCCGCCUGGCUCCACCA', 'AAAACCAAUAUUACUACCAUCGGUCCUCACCACUAGAUCGGUGUUAUGCUUUGUUGGGAUAGCAGGCCGUGCCAGUUGGACAGCCAAGGUCCACCUCUGGUUCGGCACACAU', 'AUAUAUCAUUGUAUUUUGCCUCGUAUAAUUGUGGAGAUAUGGUCCACGAGUUUCUACCGGACGGCCGUAAACUGUUCGACUACGGGGGAAACACCUGGGAUG']
first 3 structures: ['(((((((..((((........))))((((((.......))))))....(((((.......))))))))))))....', '....((((((....((.((((((((..........)))))))).)).....)))))).........(((.((((((.((((.......)))))....)))).))))......', '.................((((((((...(((((((.......)))))))........(((((((.......)))))))..))))))))..............']


In [6]:
# extract mappings
nt_to_idx = val_ds['nt_to_idx']
struct_to_idx = val_ds['struct_to_idx']

openings = list("({[<") + [chr(ord("A") + i) for i in range(26)]
closing = list(">]})") + [chr(ord("a") + i) for i in range(26)]
struct_to_idx = {".": 0}
# Add mappings for basic brackets
bracket_pairs = list(zip(openings, closing))
for i, (open_b, close_b) in enumerate(bracket_pairs, start=1):
    struct_to_idx[open_b] = 7    # Odd numbers for opening
    struct_to_idx[close_b] = 8       # Even numbers for closing

# hardcoding strcut mappings
struct_to_idx["("] = 1
struct_to_idx[")"] = 2
struct_to_idx["{"] = 3
struct_to_idx["}"] = 4
struct_to_idx["["] = 5
struct_to_idx["]"] = 6
struct_to_idx["<"] = 7
struct_to_idx[">"] = 8

# account for unknown tokens
nt_to_idx['UNK'] = len(nt_to_idx)
struct_to_idx['UNK'] = 9 #len(struct_to_idx)

# dot-bracket tokens
valid_structure_tokens = ['.', "UNK"] + list("({[<") + list(">]})")
valid_label_indices = [struct_to_idx[token] for token in valid_structure_tokens]
print(f"valid labels: {sorted(valid_label_indices)}")

# # convert sequences and structures to indices
# tokenized_sequences = [[nt_to_idx[nt] for nt in seq] for seq in sequences]
# tokenized_structures = [[struct_to_idx[char] for char in struct] for struct in structures]

# # print
# print("first tokenized sequence:", tokenized_sequences[0])
# print("first tokenized structure:", tokenized_structures[0])

valid labels: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


In [7]:
set(struct_to_idx.values())

{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}

In [8]:
# # get unique nucleotides in the sequences
# unique_nucleotides = set("".join(sequences))
# print("Unique nucleotides in the dataset:", unique_nucleotides)

# # get nucleotides missing in nt_to_idx
# missing_nucleotides = [nt for nt in unique_nucleotides if nt not in nt_to_idx]
# print("Nucleotides missing in nt_to_idx:", missing_nucleotides)

In [9]:
# # include more bases and characters to nt_to_idx
# valid_bases = ['A', 'C', 'G', 'U', 'N', 'R', 'Y', 'S', 'W', 'K', 'M', 'B', 'D', 'H', 'V']
# additional_tokens = {'.', '~', '_', 'x', 'X', 'a', 'u', 'g', 'c'}  # lowercase tokens

# # add to nt_to_idx
# for nt in valid_bases + list(additional_tokens):
#     if nt not in nt_to_idx:
#         nt_to_idx[nt] = len(nt_to_idx)

# # add token for unknown characters
# nt_to_idx['UNK'] = len(nt_to_idx)

# print("updated nt_to_idx mapping:", nt_to_idx)

In [10]:
# load bert tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# add RNA-specific tokens
nucleotide_tokens = list(nt_to_idx.keys())
tokenizer.add_tokens(nucleotide_tokens)

# convert sequences to token ids
def tokenize_and_truncate(sequences, labels, max_length=512):
    tokenized_sequences = []
    tokenized_labels = []

    for seq, label in zip(sequences, labels):
        tokenized_seq = [nt_to_idx.get(nt, nt_to_idx['UNK']) for nt in seq]
        tokenized_label = [struct_to_idx.get(struct, struct_to_idx['UNK']) for struct in label]

        # truncate to max length
        tokenized_sequences.append(tokenized_seq[:max_length])
        tokenized_labels.append(tokenized_label[:max_length])

    return tokenized_sequences, tokenized_labels

tokenized_sequences, tokenized_structures = tokenize_and_truncate(sequences, structures, max_length=512)

# convert to tensors and pad
sequence_tensors = pad_sequence([torch.tensor(seq) for seq in tokenized_sequences], batch_first=True, padding_value=nt_to_idx['UNK'])
structure_tensors = pad_sequence([torch.tensor(label) for label in tokenized_structures], batch_first=True, padding_value=struct_to_idx['UNK'])

# create attention masks
attention_masks = (sequence_tensors != nt_to_idx['UNK']).long()

print(f"sequence tensors shape: {sequence_tensors.shape}")
print(f"structure tensors shape: {structure_tensors.shape}")
print(f"attention masks shape: {attention_masks.shape}")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

sequence tensors shape: torch.Size([5116, 512])
structure tensors shape: torch.Size([5116, 512])
attention masks shape: torch.Size([5116, 512])


In [11]:
# tensor dataset and data loader
val_dataset = TensorDataset(sequence_tensors, attention_masks, structure_tensors)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

In [12]:
# bert for token classification
model = BertForTokenClassification.from_pretrained("bert-base-uncased", num_labels=len(struct_to_idx))

# add new tokens to the model
model.resize_token_embeddings(len(tokenizer))

# move model to gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# optimizer and loss function
optimizer = AdamW(model.parameters(), lr=5e-5)
loss_fn = CrossEntropyLoss()

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


In [13]:
# training loop
epochs = 3
model.train()

for epoch in range(epochs):
    epoch_loss = 0
    for batch in tqdm(val_loader, desc=f"Epoch {epoch + 1}/{epochs}"):
        input_ids, attention_mask, labels = [b.to(device) for b in batch]

        # zero gradients
        optimizer.zero_grad()

        # forward pass
        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits

        # compute loss
        loss = loss_fn(logits.view(-1, logits.shape[-1]), labels.view(-1))
        epoch_loss += loss.item()

        # backward pass and optimization
        loss.backward()
        optimizer.step()

    print(f"epoch {epoch + 1}, loss: {epoch_loss / len(val_loader)}")

Epoch 1/3: 100%|██████████| 160/160 [08:07<00:00,  3.05s/it]


epoch 1, loss: 0.7373357443138957


Epoch 2/3: 100%|██████████| 160/160 [08:10<00:00,  3.07s/it]


epoch 2, loss: 0.3723185620270669


Epoch 3/3: 100%|██████████| 160/160 [08:10<00:00,  3.07s/it]

epoch 3, loss: 0.35110589656978847





In [18]:
# evaluation loop
model.eval()
all_predictions = []
all_labels = []

with torch.no_grad():
    for batch in val_loader:
        input_ids, attention_mask, labels = [b.to(device) for b in batch]

        # forward pass
        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits

        # get predictions
        predictions = torch.argmax(logits, dim=-1)

        # flatten predictions and labels
        for pred, label in zip(predictions.view(-1), labels.view(-1)):
            if label != struct_to_idx['UNK']:  # ignore padding
                all_predictions.append(pred.item())
                all_labels.append(label.item())

# compute accuracy and classification report
accuracy = accuracy_score(all_labels, all_predictions)
print(f"Validation Accuracy: {accuracy:.4f}")

# Fix the classification report issue
# Get valid labels (keys that actually exist in the dataset)
valid_labels = sorted(struct_to_idx.values())

print(classification_report(
    all_labels,
    all_predictions,
    labels=valid_labels,
    target_names=[key for key, val in struct_to_idx.items() if val in valid_labels]
))

Validation Accuracy: 0.5442


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


              precision    recall  f1-score   support

           .       0.56      0.81      0.66    416710
           (       0.56      0.29      0.38    241393
           >       0.49      0.37      0.42    229229
           {       0.00      0.00      0.00        71
           ]       0.00      0.00      0.00        71
           [       0.00      0.00      0.00      6671
           }       0.00      0.00      0.00      5122
           <       0.00      0.00      0.00         5
           )       0.00      0.00      0.00         5
           A       0.00      0.00      0.00         5
           a       0.00      0.00      0.00         5
           B       0.00      0.00      0.00         5
           b       0.00      0.00      0.00         5
           C       0.00      0.00      0.00         5
           c       0.00      0.00      0.00         5
           D       0.00      0.00      0.00         5
           d       0.00      0.00      0.00         5
           E       0.00    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [None]:
# print(f"number of model classes: {len(struct_to_idx)}")
# print(f"keys in struct_to_idx: {struct_to_idx.keys()}")

In [None]:
# # save the model and tokenizer
# model.save_pretrained("./fine_tuned_bert_model")
# tokenizer.save_pretrained("./fine_tuned_bert_model")