In [1]:
import torch
import numpy as np
import pandas as pd
from torch import Tensor, nn
from tqdm.notebook import tqdm
from torchaudio.models import Conformer
from torch.utils.data import Dataset, DataLoader
from transformers import Trainer, TrainingArguments
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'



In [2]:
df = pd.read_parquet('/kaggle/input/stanford-ribonanza-rna-folding-converted/train_data.parquet')
df = df.replace(np.nan, 0.0)
df = df.drop([c for c in df.columns if 'reactivity_error' in c], axis='columns')
df['L'] = df.sequence.apply(len)
df = df[df.SN_filter == 1]
df = df.sort_values(by='signal_to_noise', ignore_index=True, ascending=False)
df = df.drop_duplicates(['sequence_id', 'experiment_type'])
df = df[df.sequence_id.duplicated(keep=False) == True].reset_index(drop=True)
df = df.drop(['dataset_name', 'reads', 'SN_filter', 'signal_to_noise'], axis='columns')

In [3]:
class RNA_Dataset(Dataset):
    def __init__(self, df, size=206):
        self.seq_map = {'A':1,'C':2,'G':3,'U':4}
        self.size = size
        
        df_2A3 = df.loc[df.experiment_type=='2A3_MaP'].reset_index(drop=True)
        df_DMS = df.loc[df.experiment_type=='DMS_MaP'].reset_index(drop=True)
        
        self.seq = df_2A3['sequence'].values
        self.react_2A3 = torch.from_numpy(df_2A3[[c for c in df_2A3.columns if 'reactivity_0' in c]].values)
        self.react_DMS = torch.from_numpy(df_DMS[[c for c in df_DMS.columns if 'reactivity_0' in c]].values)

        df = None
        df_2A3 = None
        df_DMS = None

    def __len__(self):
        return len(self.seq)
    
    def __getitem__(self, idx):
        seq = self.seq[idx]
        length = len(seq)
        seq = [self.seq_map[s] for s in seq]
        seq = torch.Tensor(seq).to(torch.int32)
        mask = torch.zeros(self.size, dtype=torch.bool)
        mask[:len(seq)] = 1
        
        output = dict()
        output['input_ids'] = nn.functional.pad(seq,(0,self.size-length))
        output['length'] = length
        output['mask'] = mask
        output['labels'] = nn.functional.pad(torch.stack([self.react_DMS[idx], self.react_2A3[idx]],-1), (0, 0, 0, 2))
        
        return output

In [4]:
class RNA_testset(Dataset):
    def __init__(self, df, size=457):
        self.size = size
        self.df = df
        self.seq_map = {'A':1,'C':2,'G':3,'U':4}

        self.seq = df['sequence'].values
        self.id_min = df.id_min

    def __len__(self):
        return len(self.seq)
    
    def __getitem__(self, idx):
        seq = self.seq[idx]
        length = len(seq)
        seq = [self.seq_map[s] for s in seq]
        seq = torch.Tensor(seq).to(torch.int32)
        mask = torch.zeros(self.size, dtype=torch.bool)
        mask[:len(seq)] = 1
        seq = nn.functional.pad(seq,(0,self.size-length))

        return seq, length, mask, self.id_min.iloc[idx]

In [5]:
class conformer(nn.Module) :
    def __init__(self, kernel_size: int, num_channels: int, num_layers: int, feed_forward = 1024, num_heads = 16) :
        super().__init__()

        self.postional_embedding = nn.Sequential(nn.Embedding(457, num_channels//4), nn.Sigmoid(), nn.Linear(num_channels//4, num_channels//2), nn.ReLU())        
        self.base_embedding = nn.Sequential(nn.Embedding(5, num_channels//4), nn.Sigmoid(), nn.Linear(num_channels//4, num_channels//2), nn.ReLU())
        
        self.feed_forward = nn.Sequential(nn.Linear(num_channels, num_channels*2), nn.Sigmoid(), nn.Linear(num_channels*2, num_channels), nn.ReLU())
        
        self.encoder =  Conformer(num_channels, num_heads, feed_forward, num_layers, kernel_size)     
        self.result = nn.Sequential(nn.Sigmoid(), nn.Linear(num_channels, num_channels//2), nn.Linear(num_channels//2, num_channels//4), 
                                    nn.ReLU(), nn.Linear(num_channels//4, num_channels//8), nn.ReLU(), nn.Linear(num_channels//8, 2), nn.ReLU())
        self.loss = nn.L1Loss()
        
    def forward(self, input_ids, length, mask, labels=None) :

        mask = torch.unsqueeze(mask, dim=-1)
        max_len = torch.max(length)
        mask = mask[:, :max_len]
        input_ids = input_ids[:, :max_len]
        
        positional_embedding = self.postional_embedding(input_ids)*mask
        base_embedding = self.base_embedding(input_ids)*mask
        embedding = torch.concat((positional_embedding, base_embedding), -1)
        feed_forward = self.feed_forward(embedding) + embedding
        
        encoded, _ = self.encoder(feed_forward, length)
        output = self.result(encoded)*mask
        
        if labels is not None :

            y = labels[:, :max_len]
            cover = y != 0
            output *= cover
            loss = torch.unsqueeze( self.loss(output, y), dim=0)
            return loss
        
        return output

In [6]:
torch.cuda.empty_cache()
training_args = TrainingArguments(
    output_dir="/kaggle/working/",
    report_to = 'none',
    lr_scheduler_type='cosine',
    per_device_train_batch_size=32,
    gradient_accumulation_steps=1,
    learning_rate=5e-5,
    weight_decay = 1e-5,
    warmup_steps=50,
    num_train_epochs=20,
    save_strategy='epoch',
    fp16=True,
    logging_steps=10,
    save_steps = 1.0,
    torch_compile=True
)
dataset = RNA_Dataset(df, size=206)
model = conformer(kernel_size = 7, num_channels = 256, num_layers =20, num_heads = 16, feed_forward = 1024)
trainer = Trainer(model = model, args=training_args, train_dataset = dataset)

The speedups for torchdynamo mostly come wih GPU Ampere or higher and which is not detected here.


In [7]:
trainer.train()

Step,Training Loss
10,0.2168
20,0.2114
30,0.2043
40,0.1962
50,0.2007
60,0.1821
70,0.183


KeyboardInterrupt: 