
In [None]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.
import kagglehub
team93_rnafm_tutorial_path = kagglehub.dataset_download('team93/rnafm-tutorial')

print('Data source import complete.')


# RNA-FM Tutorial

### RNA-FM

An RNA foundation model (RNA-FM) was pre-trained on 23 million non-coding RNA sequences through self-supervised learning. As a result, the pre-trained RNA-FM can infer sequential and evolutionary information about non-coding RNAs without using any labels.

Reference:
Chen, Jiayang, et al. "Interpretable RNA foundation model from unannotated data for highly accurate RNA structure and function predictions." arXiv preprint arXiv:2204.00300 (2022).

Link:

Paper: https://arxiv.org/abs/2204.00300

GitHub: https://github.com/ml4bio/RNA-FM

### Workflow of our tutorial

**Preparation**
1. install the RNA-FM package
2. load the necessary libraries

**Task 1. RNA family clustering**

Goal: to demonstrate that RNA-FM embeddings contain biologically meaningful information

1. read RNA sequences for each family from FASTA files
2. generate the RNA-FM embeddings for each sequence
3. t-SNE dimension reduction on the generated embeddings
4. plot the embeddings in the 2D space

**Task 2. RNA type classification**

Goal: to demonstrate how to use RNA-FM for downstream applications

1. read RNA sequences for each type from a FASTA file
2. generate RNA-FM embeddings for each sequence
3. build the dataset and model
4. train and validate the model
5. test the model on a dataset excluded from training

## Install RNA-FM

In [None]:
!python --version

Note: If the following pip installation does not work, the most likely cause is that Internet access is disabled for this Kaggle notebook. To enable it, you need to verify your Kaggle account with a phone number.

To summarize:
verify your account, refresh the page, then reopen this notebook and try again.

In [None]:
!pip install rna-fm

!pip install biopython

If pip fails to install the required packages, you can instead uncomment the following cell to install RNA-FM from source.

In [None]:
# !git clone https://github.com/ml4bio/RNA-FM.git

# !pwd
# !ls
# %cd ./RNA-FM

# !python setup.py install

In [None]:
import fm  # for development with RNA-FM

from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import numpy as np

from Bio import SeqIO  # for file parsing

from sklearn.manifold import TSNE  # for dimension reduction

from sklearn.model_selection import train_test_split  # for splitting train/val/test

from tqdm.notebook import tqdm  # for showing progress

import matplotlib.pyplot as plt

In [None]:
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

print(f'using {device} device')

data_dir = '../input/rnafm-tutorial/'
!ls '../input/rnafm-tutorial/'

## Task 1. RNA Family Clustering

### Load data

In [None]:
# load fasta data
fasta_paths = list(Path(data_dir).glob('RF*.fasta'))
fasta_paths.sort()
print(fasta_paths)

In [None]:
rfam_list = []  # list of RNA families

seqs = []  # list of (sequence ID, sequence) tuples
labels = []  # RNA family label for each entry in seqs

for fasta_path in fasta_paths:
    rfam = fasta_path.stem
    rfam_list.append(rfam)
    print(rfam)

    records = list(SeqIO.parse(fasta_path, 'fasta'))

    fasta_seqs = [str(record.seq) for record in records]
    fasta_seq_names = [record.id for record in records]

    seqs += [(seq_name, seq) for seq_name, seq in zip(fasta_seq_names, fasta_seqs)]

    labels += [rfam] * len(fasta_seq_names)

    print(len(seqs), len(labels))

In [None]:
# examine the data
print(seqs[:2])

### Load the pretrained model

In [None]:
# !gdown 1zflX5hHTxuwqcZm6A1npq7ubP8m7LdNX

In [None]:
# Load RNA-FM model
fm_model, alphabet = fm.pretrained.rna_fm_t12(Path(data_dir, 'RNA-FM_pretrained.pth'))
batch_converter = alphabet.get_batch_converter()

fm_model.to(device)  # use GPU if available

fm_model.eval()  # disables dropout for deterministic results

### Retrieve RNA-FM embeddings

In [None]:
chunk_size = 20

# pre-allocate one zero-padded array so that each chunk can be written in place
token_embeddings = np.zeros((len(labels), 1024, 640))

# divide all the sequences into chunks for processing due to the GPU memory limit
for i in tqdm(range(0, len(seqs), chunk_size)):
    data = seqs[i:i+chunk_size]

    batch_labels, batch_strs, batch_tokens = batch_converter(data)

    # run inference without tracking gradients (on the GPU if available)
    with torch.no_grad():
        results = fm_model(batch_tokens.to(device), repr_layers=[12])

    emb = results['representations'][12].cpu().numpy()

    token_embeddings[i:i+chunk_size, :emb.shape[1], :] = emb


print(token_embeddings.shape)

### Dimension reduction

In [None]:
token_embeddings = np.mean(token_embeddings, axis=1)

print(token_embeddings.shape)
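Note that the mean above is taken over all 1024 pre-allocated positions, so for short sequences most of the averaged rows are padding, and the pooled vector shrinks toward zero as the sequence gets shorter. One alternative, sketched here on toy arrays, is to average only over the positions that belong to each sequence. The `masked_mean` helper and the `lengths` input are assumptions introduced for illustration and are not part of the RNA-FM package. How `lengths` maps onto RNA-FM token positions (for example, whether special start/end tokens are counted) should be checked against the batch converter output.

```python
import numpy as np


def masked_mean(token_embeddings, lengths):
    """Average each sequence's token embeddings over its valid positions only.

    token_embeddings: (N, L_max, D) array, padded along axis 1
    lengths: (N,) number of valid positions per sequence (hypothetical input)
    """
    lengths = np.asarray(lengths)
    positions = np.arange(token_embeddings.shape[1])
    # mask[n, l] is True where position l is a valid token of sequence n
    mask = positions[None, :] < lengths[:, None]
    summed = (token_embeddings * mask[:, :, None]).sum(axis=1)
    return summed / lengths[:, None]


# toy data: 2 sequences, padded to 4 positions, embedding dimension 3
toy = np.zeros((2, 4, 3))
toy[0, :2] = 1.0  # first sequence has 2 valid positions
toy[1, :4] = 2.0  # second sequence has 4 valid positions

pooled = masked_mean(toy, lengths=[2, 4])
print(pooled.shape)          # (2, 3)
print(toy.mean(axis=1)[0])   # plain mean, diluted by the padded rows
print(pooled[0])             # masked mean, unaffected by padding
```

With this kind of pooling, the embedding scale no longer depends on sequence length, which may matter for distance-based methods such as t-SNE.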

In [None]:
# t-SNE
tsne = TSNE(n_components=2, random_state=42)  # n_components is the dimension of the reduced data

embeddings = tsne.fit_transform(token_embeddings)

print(embeddings.shape)

### Visualization

In [None]:
colors = ['tab:blue', 'tab:orange', 'tab:red']
rfam_dict = {'RF00001': '5S_rRNA', 'RF00005': 'tRNA', 'RF00010': 'RNaseP_bact_a'}  # map the Rfam index to name

plt.figure(figsize=(8, 8))

for i, label in enumerate(sorted(list(set(labels)))):
    # find the data points corresponding to the current label
    indices = [j for j, l in enumerate(labels) if l == label]
    plt.scatter(embeddings[indices, 0], embeddings[indices, 1], color=colors[i], s=5, alpha=0.5, label=rfam_dict[label])


plt.legend()
plt.xticks([])
plt.yticks([])

plt.show()

## Task 2. RNA Type Classification

### Load data

In [None]:
fasta_path = Path(data_dir, 'format_rnacentral_active.100.sample-Max50.fasta')

records = list(SeqIO.parse(fasta_path, 'fasta'))

fasta_seqs = [str(record.seq) for record in records]
fasta_seq_names = [record.id for record in records]

print(len(fasta_seqs), len(fasta_seq_names))

labels = [record.description.split()[1] for record in records]

seqs = [(seq_name, seq) for seq_name, seq in zip(fasta_seq_names, fasta_seqs)]

num_class = len(set(labels))

print(len(seqs), len(labels))
print(f'number of classes: {num_class}')

label_to_num = {'miRNA': 0, 'snRNA': 1, 'other': 2, 'hammerhead_ribozyme': 3,
                'telomerase_RNA': 4, 'antisense_RNA': 5, 'precursor_RNA': 6,
                'tRNA': 7, 'snoRNA': 8, 'RNase_P_RNA': 9, 'pre_miRNA': 10,
                'misc_RNA': 11, 'rRNA': 12, 'siRNA': 13, 'vault_RNA': 14,
                'autocatalytically_spliced_intron': 15, 'guide_RNA': 16,
                'Y_RNA': 17, 'scRNA': 18, 'sRNA': 19, 'scaRNA': 20,
                'RNase_MRP_RNA': 21, 'tmRNA': 22, 'lncRNA': 23, 'ncRNA': 24,
                'piRNA': 25, 'ribozyme': 26, 'SRP_RNA': 27}

labels = [label_to_num[label] for label in labels]
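The mapping above is hard-coded for this dataset. As a hedged alternative, the same kind of mapping can be derived from the labels themselves, which keeps working if classes are added or removed. The sketch below uses made-up description lines in the `<id> <type>` layout that the parsing code assumes. The IDs are hypothetical, and the resulting indices will differ from the hard-coded dictionary because the class names are sorted alphabetically.

```python
# hypothetical FASTA description lines in the "<id> <type>" layout
descriptions = [
    'URS0000000001 tRNA',
    'URS0000000002 miRNA',
    'URS0000000003 tRNA',
]

# the second whitespace-separated field is taken as the RNA type
raw_labels = [description.split()[1] for description in descriptions]

# sort the class names so that the mapping is reproducible across runs
label_to_num = {label: i for i, label in enumerate(sorted(set(raw_labels)))}
num_to_label = {i: label for label, i in label_to_num.items()}

encoded = [label_to_num[label] for label in raw_labels]
print(label_to_num)
print(encoded)
```

The inverse dictionary `num_to_label` is handy later for turning predicted class indices back into readable RNA type names.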

### Retrieve RNA-FM embeddings

You don't need to download the pretrained model again if you have already done so for the previous task.

In [None]:
# !gdown 1zflX5hHTxuwqcZm6A1npq7ubP8m7LdNX  # for Colab only

In [None]:
# Load RNA-FM model
fm_model, alphabet = fm.pretrained.rna_fm_t12(Path(data_dir, 'RNA-FM_pretrained.pth'))
batch_converter = alphabet.get_batch_converter()

fm_model.to(device)

fm_model.eval()  # disables dropout for deterministic results

In [None]:
chunk_size = 50

# pre-allocate one zero-padded array so that each chunk can be written in place
token_embeddings = np.zeros((len(labels), 1024, 640))

# divide all the sequences into chunks for processing due to the GPU memory limit
for i in tqdm(range(0, len(seqs), chunk_size)):
    data = seqs[i:i+chunk_size]

    batch_labels, batch_strs, batch_tokens = batch_converter(data)

    # run inference without tracking gradients (on the GPU if available)
    with torch.no_grad():
        results = fm_model(batch_tokens.to(device), repr_layers=[12])

    emb = results["representations"][12].cpu().numpy()

    token_embeddings[i:i+chunk_size, :emb.shape[1], :] = emb


print(token_embeddings.shape)

### Construct the dataset and classifier

In [None]:
class RNATypeDataset(Dataset):
    def __init__(self, embeddings, labels):
        self.embeddings = embeddings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # use the mean of the RNA-FM embedding along the (zero-padded) sequence dimension
        # so that each embedding is reduced from (1024, 640) to (640,)
        return np.mean(self.embeddings[idx], axis=0), self.labels[idx]

In [None]:
class RNATypeClassifier(nn.Module):
    def __init__(self, num_class):
        super().__init__()
        self.fc = nn.Linear(640, num_class)

    def forward(self, x):
        x = self.fc(x)

        return x

In [None]:
# with 50 samples per class, the two stratified splits give 32 for train, 8 for val, and 10 for test per class
x_train_val, x_test, y_train_val, y_test = train_test_split(token_embeddings, labels,
                                                            test_size=0.2, random_state=42, stratify=labels)

x_train, x_val, y_train, y_val = train_test_split(x_train_val, y_train_val,
                                                  test_size=0.2, random_state=42, stratify=y_train_val)

print(x_train.shape, x_val.shape, x_test.shape)
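The per-class sizes produced by the two-stage split can be worked out directly. The key point is that the second `test_size=0.2` applies to the 80% that remains after the first split, not to the full dataset. The short calculation below assumes 50 samples per class, as stated in this tutorial, and the variable names are illustrative only.

```python
samples_per_class = 50  # assumption: every class has 50 samples
test_fraction = 0.2     # first split: hold out the test set
val_fraction = 0.2      # second split: applied to what remains

n_test = round(samples_per_class * test_fraction)
n_train_val = samples_per_class - n_test
n_val = round(n_train_val * val_fraction)
n_train = n_train_val - n_val

print(n_train, n_val, n_test)
```

Under this assumption the split is 32/8/10 per class, that is 64%, 16%, and 20% of the data.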

In [None]:
def class_distribution(labels):
    classes, counts = np.unique(labels, return_counts=True)
    distribution = counts / counts.sum()
    return dict(zip(classes, distribution))


train_dist = class_distribution(y_train)
val_dist = class_distribution(y_val)
test_dist = class_distribution(y_test)

print(train_dist)
print(val_dist)
print(test_dist)

In [None]:
# hyper-parameters

batch_size = 8
lr = 1e-3
epochs = 100

In [None]:
train_dataset = RNATypeDataset(x_train, y_train)
val_dataset = RNATypeDataset(x_val, y_val)
test_dataset = RNATypeDataset(x_test, y_test)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
model = RNATypeClassifier(num_class).to(device)
print(model)

criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(model.parameters(), lr=lr)

### Train the model

In [None]:
max_val_acc = -1
best_epoch = -1

train_loss_history = []
val_loss_history = []

train_acc_history = []
val_acc_history = []

for epoch in tqdm(range(epochs)):

    # train the model
    train_losses = []
    train_preds = []
    train_targets = []

    model.train()

    for batch in train_loader:
        x, y = batch
        x, y = x.to(device).float(), y.to(device).long()

        # no need to apply the softmax function since it is included in the loss function
        y_pred = model(x)

        # y_pred: (B, C) raw logits (unnormalized scores), y: (B,) class indices
        loss = criterion(y_pred, y)

        train_losses.append(loss.item())
        train_preds.append(torch.max(y_pred.detach(),1)[1])
        train_targets.append(y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # validate the model
    val_losses = []
    val_preds = []
    val_targets = []

    model.eval()

    with torch.no_grad():  # gradients are not needed for validation
        for batch in val_loader:
            x, y = batch
            x, y = x.to(device).float(), y.to(device).long()

            y_pred = model(x)

            # y_pred: (B, C) raw logits, y: (B,) class indices
            loss = criterion(y_pred, y)

            val_losses.append(loss.item())
            val_preds.append(torch.max(y_pred, 1)[1])
            val_targets.append(y)

    # calculate the accuracy
    train_preds = torch.cat(train_preds, dim=0)
    train_targets = torch.cat(train_targets, dim=0)
    train_acc = (train_preds == train_targets).float().mean().item()  # plain Python float

    val_preds = torch.cat(val_preds, dim=0)
    val_targets = torch.cat(val_targets, dim=0)
    val_acc = (val_preds == val_targets).float().mean().item()  # plain Python float

    train_acc_history.append(train_acc)
    val_acc_history.append(val_acc)

    # save the model checkpoint for the best validation accuracy
    if val_acc > max_val_acc:
        torch.save({'model_state_dict': model.state_dict()}, 'rna_type_checkpoint.pt')
        best_epoch = epoch
        max_val_acc = val_acc

    # record the mean loss of this epoch before reporting it
    train_loss_history.append(np.mean(train_losses))
    val_loss_history.append(np.mean(val_losses))

    # show intermediate steps (losses are those of the current epoch)
    if epoch % 20 == 0:
        tqdm.write(f'epoch {epoch}/{epochs}: train loss={train_loss_history[-1]:.6f}, '
                   f'train acc={train_acc:.6f}, '
                   f'val loss={val_loss_history[-1]:.6f}, '
                   f'val acc={val_acc:.6f}')
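The training loop passes the raw outputs of the linear layer to `nn.CrossEntropyLoss`, which applies a log-softmax internally before computing the negative log-likelihood. To make that step concrete, here is a small NumPy sketch of the same computation on toy logits. The function name `cross_entropy_from_logits` and the toy values are illustrative assumptions. This is not PyTorch's implementation, only the underlying arithmetic.

```python
import numpy as np


def cross_entropy_from_logits(logits, targets):
    """Mean cross-entropy from raw logits (B, C) and class indices (B,)."""
    # subtract the row-wise max for numerical stability before exponentiating
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # pick the log-probability of the true class for each sample
    return -log_probs[np.arange(len(targets)), targets].mean()


logits = np.array([[2.0, 0.5, -1.0],
                   [0.0, 0.0, 0.0]])
targets = np.array([0, 2])

loss = cross_entropy_from_logits(logits, targets)
predictions = logits.argmax(axis=1)  # index of the largest logit per sample

print(loss)
print(predictions)
```

Because softmax preserves the ordering of the logits, taking the argmax of the logits gives the same predicted class as taking the argmax of the probabilities, which is why the loop never needs to compute probabilities explicitly.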

### Visualize training results

In [None]:
plt.figure(figsize=(8, 6))

plt.plot(train_loss_history, label='train loss')
plt.plot(val_loss_history, label='val loss')

# the epoch with the best validation accuracy (the saved checkpoint)
plt.axvline(x=best_epoch, color='r', linestyle='--', alpha=0.8)

plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Loss History')

plt.legend()

plt.show()

In [None]:
plt.figure(figsize=(8, 6))

plt.plot(train_acc_history, label='train accuracy')
plt.plot(val_acc_history, label='val accuracy')

# the epoch with best validation accuracy
plt.axvline(x=best_epoch, color='r', linestyle='--', alpha=0.8)

plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.title('Accuracy History')

plt.legend()

plt.show()

### Test the model

In [None]:
# test the model
test_preds = []

# restore the checkpoint with the best validation accuracy
model.load_state_dict(torch.load('rna_type_checkpoint.pt', map_location=device)['model_state_dict'])

model.eval()

with torch.no_grad():  # gradients are not needed for inference
    for batch in test_loader:
        x, y = batch
        x = x.to(device).float()

        output = model(x)

        y_pred = torch.argmax(output, dim=1)  # predicted class index for each sample

        test_preds.append(y_pred.cpu().numpy())


test_preds = np.concatenate(test_preds)

total = len(y_test)
correct = np.sum(test_preds == y_test)

print(f'total number of test data: {total}, correct={correct}, test acc={correct/total:.4f}')
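Overall accuracy can hide classes that the model handles poorly. As a possible follow-up, the sketch below computes the accuracy per class on toy labels. The `per_class_accuracy` helper is a hypothetical name introduced here, and the toy lists stand in for `y_test` and `test_preds`.

```python
import numpy as np


def per_class_accuracy(y_true, y_pred, num_class):
    """Return the fraction of correct predictions for each class."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    accuracies = np.full(num_class, np.nan)  # NaN marks classes absent from y_true
    for c in range(num_class):
        in_class = y_true == c
        if in_class.any():
            accuracies[c] = (y_pred[in_class] == c).mean()
    return accuracies


# toy labels standing in for y_test and test_preds
toy_true = [0, 0, 1, 1, 2, 2]
toy_pred = [0, 1, 1, 1, 0, 2]

toy_acc = per_class_accuracy(toy_true, toy_pred, num_class=3)
print(toy_acc)
```

In the notebook, the same function could be called with `y_test`, `test_preds`, and `num_class` to see which RNA types are most often misclassified.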

### Applying RNA-FM for other interesting downstream tasks?

The general workflow is the same as above:
- get your sequence,
- generate the FM embeddings, and
- use the FM embeddings as inputs to your own downstream model.

In the above example, we have only used one fully-connected layer in our downstream model. However, as the problem becomes more complex and the dataset's size grows, we can also adopt more sophisticated models or combine the embeddings with other data to further improve the performance.
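As one illustration of combining the embeddings with other data, the sketch below concatenates pooled embeddings with two simple sequence-level features. Everything here is an assumption made for the example: the random array merely stands in for mean-pooled RNA-FM embeddings of shape (N, 640), and GC content and length are arbitrary feature choices, not something the RNA-FM authors prescribe.

```python
import numpy as np


def sequence_features(seq):
    """Two hand-crafted features (illustrative choices): GC content and length."""
    seq = seq.upper()
    gc = (seq.count('G') + seq.count('C')) / max(len(seq), 1)
    return np.array([gc, len(seq)], dtype=np.float32)


# placeholder standing in for mean-pooled RNA-FM embeddings of shape (N, 640)
toy_seqs = ['GGCUA', 'AUAUAUAU']
rng = np.random.default_rng(0)
pooled_embeddings = rng.normal(size=(len(toy_seqs), 640)).astype(np.float32)

extra = np.stack([sequence_features(s) for s in toy_seqs])       # (N, 2)
combined = np.concatenate([pooled_embeddings, extra], axis=1)    # (N, 642)

print(combined.shape)
```

A downstream model trained on `combined` would need its input size changed from 640 to 642, and features on very different scales (such as raw length) would usually be normalized first.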