### Imports

In [1]:
import pandas as pd
import os, gc
import numpy as np
from sklearn.model_selection import KFold
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [2]:
# Run configuration shared by the cells below.
# Label for this run (not referenced in the cells shown here).
fname = 'example0'
# Directory that holds train_data.parquet.
PATH = 'stanford-ribonanza-rna-folding-converted/'
# Output directory (not referenced in the cells shown here).
OUT = './'
# Batch size handed to the length-matching batch samplers.
bs = 2
# Number of DataLoader worker processes.
num_workers = 4
# Random seed constant (same value as RNA_Dataset's default `seed`).
SEED = 2023
# Number of cross-validation folds.
nfolds = 4
# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

### Data

In [3]:
class RNA_Dataset(Dataset):
    """Paired 2A3 / DMS reactivity samples for one cross-validation fold.

    The input frame is split by ``experiment_type`` into its '2A3_MaP' and
    'DMS_MaP' rows. Both parts are indexed with the same positional indices,
    so they are assumed to be row-aligned (row i of each part describing the
    same sequence) -- TODO confirm against the data conversion step.

    Args:
        df: DataFrame with at least the columns ``sequence``,
            ``experiment_type``, ``SN_filter``, ``signal_to_noise`` and the
            per-position columns whose names contain ``reactivity_0`` /
            ``reactivity_error_0``. The frame is NOT modified.
        mode: 'train' selects the training indices of the fold; any other
            value selects the held-out indices.
        seed: random_state of the shuffled KFold split.
        fold: index of the fold to use.
        nfolds: number of KFold splits.
        mask_only: when True, ``__getitem__`` returns only the length masks
            (cheap path used for length-based batch sampling).
        **kwargs: accepted and ignored.

    Raises:
        ValueError: if the two experiment parts have different row counts,
            in which case a shared positional split would misalign them.
    """

    def __init__(self, df, mode='train', seed=2023, fold=0, nfolds=4, 
                 mask_only=False, **kwargs):
        self.seq_map = {'A':0,'C':1,'G':2,'U':3}
        # Every sample is padded to this fixed length.
        self.Lmax = 206
        self.mask_only = mask_only

        df_2A3 = df.loc[df.experiment_type=='2A3_MaP']
        df_DMS = df.loc[df.experiment_type=='DMS_MaP']
        if len(df_2A3) != len(df_DMS):
            raise ValueError(
                'expected the same number of 2A3_MaP and DMS_MaP rows, got '
                f'{len(df_2A3)} and {len(df_DMS)}')

        # One positional split, computed on the 2A3 rows, is applied to both parts.
        folds = list(KFold(n_splits=nfolds, random_state=seed,
                           shuffle=True).split(df_2A3))
        split = folds[fold][0 if mode=='train' else 1]
        df_2A3 = df_2A3.iloc[split].reset_index(drop=True)
        df_DMS = df_DMS.iloc[split].reset_index(drop=True)

        # Keep a sample only when BOTH experiments pass the SN filter.
        m = (df_2A3['SN_filter'].values > 0) & (df_DMS['SN_filter'].values > 0)
        df_2A3 = df_2A3.loc[m].reset_index(drop=True)
        df_DMS = df_DMS.loc[m].reset_index(drop=True)

        self.seq = df_2A3['sequence'].values
        # Sequence lengths are derived from the kept sequences instead of
        # writing an 'L' column into the caller's DataFrame.
        self.L = np.fromiter((len(s) for s in self.seq), dtype=np.int64,
                             count=len(self.seq))

        def columns_with(frame, tag):
            # Values of every column whose name contains `tag`.
            return frame[[c for c in frame.columns if tag in c]].values

        self.react_2A3 = columns_with(df_2A3, 'reactivity_0')
        self.react_DMS = columns_with(df_DMS, 'reactivity_0')
        self.react_err_2A3 = columns_with(df_2A3, 'reactivity_error_0')
        self.react_err_DMS = columns_with(df_DMS, 'reactivity_error_0')
        self.sn_2A3 = df_2A3['signal_to_noise'].values
        self.sn_DMS = df_DMS['signal_to_noise'].values
        
    def __len__(self):
        """Number of samples kept after the fold split and SN filtering."""
        return len(self.seq)  
    
    def __getitem__(self, idx):
        """Return an ``(inputs, targets)`` pair of dicts for sample ``idx``.

        ``inputs`` holds 'seq' (token ids zero-padded to ``Lmax``) and 'mask'
        (True on real positions). ``targets`` holds 'react' and 'react_err'
        (2A3 and DMS stacked on the last axis, in that order), 'sn' (the two
        signal-to-noise values) and the same 'mask'. With ``mask_only`` both
        dicts contain only 'mask'.
        """
        seq = self.seq[idx]
        mask = torch.zeros(self.Lmax, dtype=torch.bool)
        mask[:len(seq)] = True
        if self.mask_only:
            return {'mask':mask},{'mask':mask}

        tokens = np.array([self.seq_map[s] for s in seq])
        # Zero-padding reuses id 0 (also the id of 'A'); the mask tells them apart.
        tokens = np.pad(tokens,(0,self.Lmax-len(tokens)))
        
        react = torch.from_numpy(np.stack([self.react_2A3[idx],
                                           self.react_DMS[idx]],-1))
        react_err = torch.from_numpy(np.stack([self.react_err_2A3[idx],
                                               self.react_err_DMS[idx]],-1))
        sn = torch.FloatTensor([self.sn_2A3[idx],self.sn_DMS[idx]])
        
        return {'seq':torch.from_numpy(tokens), 'mask':mask}, \
               {'react':react, 'react_err':react_err,
                'sn':sn, 'mask':mask}
    
class LenMatchBatchSampler(torch.utils.data.BatchSampler):
    """Batch sampler that groups indices of similar sequence length.

    Each index drawn from the wrapped sampler is placed in a bucket chosen by
    the number of True entries in the sample's 'mask', in steps of 16
    positions. A bucket is emitted as a batch as soon as it holds
    ``batch_size`` indices. Indices still waiting when the sampler is
    exhausted are batched together in ascending bucket order; a final short
    batch is emitted only when ``drop_last`` is False.

    The wrapped sampler must expose ``data_source``, whose items are either a
    dict with a 'mask' entry or a tuple whose first element is such a dict.
    """

    def __iter__(self):
        # Buckets are created on demand, so any sequence length is supported
        # and no list object is ever shared between buckets.
        buckets = {}

        for idx in self.sampler:
            sample = self.sampler.data_source[idx]
            mask = sample[0]["mask"] if isinstance(sample, tuple) else sample["mask"]
            # Plain int key; lengths below 16 share bucket 1.
            key = max(1, int(mask.sum()) // 16)
            bucket = buckets.setdefault(key, [])
            bucket.append(idx)

            if len(bucket) == self.batch_size:
                buckets[key] = []
                yield bucket

        # Flush what is left, shortest buckets first.
        batch = []
        for key in sorted(buckets):
            for idx in buckets[key]:
                batch.append(idx)
                if len(batch) == self.batch_size:
                    yield batch
                    batch = []

        if len(batch) > 0 and not self.drop_last:
            yield batch
            
def dict_to(x, device='cuda'):
    """Return a new dict with every value of `x` moved to `device` via `.to`."""
    moved = {}
    for key in x:
        moved[key] = x[key].to(device)
    return moved

def to_device(x, device='cuda'):
    """Move each dict in the iterable `x` to `device`; returns a tuple of dicts."""
    moved = []
    for element in x:
        moved.append(dict_to(element, device))
    return tuple(moved)

class DeviceDataLoader:
    """Thin wrapper that moves every batch of a dataloader to a device.

    Each batch is expected to be an iterable of dicts; iteration yields a
    tuple of those dicts with all values moved to `device`.
    """

    def __init__(self, dataloader, device='cuda'):
        self.device = device
        self.dataloader = dataloader

    def __len__(self):
        """Number of batches reported by the wrapped dataloader."""
        return len(self.dataloader)

    def __iter__(self):
        for parts in self.dataloader:
            moved = [dict_to(part, self.device) for part in parts]
            yield tuple(moved)

In [4]:
# Load the full training table from the parquet file under PATH.
df = pd.read_parquet(os.path.join(PATH,'train_data.parquet'))
# Cross-validation fold used by the cells below.
fold = 0

In [5]:
# Training split of the selected fold, built here for inspection.
dataset = RNA_Dataset(df, mode='train', fold=fold, nfolds=nfolds)

In [6]:
# Print only the first sample: an (inputs, targets) pair of dicts.
for data in dataset:
    print(data)
    break

({'seq': tensor([2, 2, 2, 0, 0, 1, 2, 0, 1, 3, 1, 2, 0, 2, 3, 0, 2, 0, 2, 3, 1, 2, 0, 0,
        0, 0, 0, 1, 3, 3, 3, 2, 0, 3, 0, 3, 2, 2, 0, 3, 3, 3, 0, 1, 3, 1, 1, 2,
        0, 2, 2, 0, 2, 0, 1, 2, 0, 0, 1, 3, 0, 1, 1, 0, 1, 2, 0, 0, 1, 0, 2, 2,
        2, 2, 0, 0, 0, 1, 3, 1, 3, 0, 1, 1, 1, 2, 3, 2, 2, 1, 2, 3, 1, 3, 1, 1,
        2, 3, 3, 3, 2, 0, 1, 2, 0, 2, 3, 0, 0, 2, 3, 1, 1, 3, 0, 0, 2, 3, 1, 0,
        0, 1, 0, 3, 2, 1, 1, 0, 2, 2, 3, 0, 3, 3, 2, 0, 1, 3, 3, 1, 2, 2, 3, 1,
        0, 0, 3, 0, 1, 0, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0,
        0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), 'mask': tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True, 

In [7]:
# Token tensor is padded to the dataset's fixed length (Lmax = 206).
data[0]['seq'].shape

torch.Size([206])

In [8]:
# Reactivity targets: one row per position, columns are [2A3, DMS]; the output contains NaN entries.
data[1]['react']

tensor([[        nan,         nan],
        [        nan,         nan],
        [        nan,         nan],
        [        nan,         nan],
        [        nan,         nan],
        [        nan,         nan],
        [        nan,         nan],
        [        nan,         nan],
        [        nan,         nan],
        [        nan,         nan],
        [        nan,         nan],
        [        nan,         nan],
        [        nan,         nan],
        [        nan,         nan],
        [        nan,         nan],
        [        nan,         nan],
        [        nan,         nan],
        [        nan,         nan],
        [        nan,         nan],
        [        nan,         nan],
        [        nan,         nan],
        [        nan,         nan],
        [        nan,         nan],
        [        nan,         nan],
        [        nan,         nan],
        [        nan,         nan],
        [ 1.5600e-01,  1.8000e-02],
        [ 4.3500e-01,  1.452

### Model

In [15]:
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of numeric positions.

    Args:
        dim: embedding size; the output's last axis has ``2 * (dim // 2)`` entries.
        M: base of the geometric frequency progression.
    """

    def __init__(self, dim=16, M=10000):
        super().__init__()
        self.dim = dim
        self.M = M

    def forward(self, x):
        """Map `x` of shape (...) to (..., 2 * (dim // 2)): sines first, then cosines."""
        n_freqs = self.dim // 2
        log_step = math.log(self.M) / n_freqs
        # Frequencies decay geometrically: exp(-k * log(M) / n_freqs), k = 0..n_freqs-1.
        freqs = torch.exp(torch.arange(n_freqs, device=x.device) * (-log_step))
        angles = x[..., None] * freqs[None, ...]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)

class RNA_Model(nn.Module):
    """Transformer encoder that predicts two values per sequence position.

    Args:
        dim: embedding / model width.
        depth: number of layers in the encoder stack.
        head_size: width per attention head (``nhead = dim // head_size``).
        **kwargs: accepted and ignored.
    """

    def __init__(self, dim=192, depth=12, head_size=32, **kwargs):
        super().__init__()
        self.emb = nn.Embedding(4, dim)
        self.pos_enc = SinusoidalPosEmb(dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim, nhead=dim // head_size, dim_feedforward=4 * dim,
            dropout=0.1, activation=nn.GELU(), batch_first=True, norm_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, depth)
        self.proj_out = nn.Linear(dim, 2)

    def forward(self, x0):
        """Run the model on a batch dict with 'seq' and 'mask' entries.

        The batch is cropped to the longest unmasked length it contains, so
        the result has shape (batch, longest, 2).
        """
        keep = x0['mask']
        longest = keep.sum(-1).max()
        keep = keep[:, :longest]
        tokens = x0['seq'][:, :longest]

        positions = torch.arange(longest, device=tokens.device).unsqueeze(0)
        pos_emb = self.pos_enc(positions)
        hidden = self.emb(tokens) + pos_emb

        # NOTE(review): the same encoder stack (shared weights) is applied three
        # times in a row -- confirm this is intentional.
        padding = ~keep
        for _ in range(3):
            hidden = self.transformer(hidden, src_key_padding_mask=padding)

        # Positional embedding is added once more before the output projection.
        hidden = hidden + pos_emb

        return self.proj_out(hidden)

In [16]:
# Fold used for the train/validation loaders below.
fold = 0

In [17]:
# Training loader. The mask-only twin dataset lets the batch sampler read
# sequence lengths cheaply; SEED is passed explicitly so the config constant
# actually controls the fold split.
ds_train = RNA_Dataset(df, mode='train', seed=SEED, fold=fold, nfolds=nfolds)
ds_train_len = RNA_Dataset(df, mode='train', seed=SEED, fold=fold, 
            nfolds=nfolds, mask_only=True)
sampler_train = torch.utils.data.RandomSampler(ds_train_len)
len_sampler_train = LenMatchBatchSampler(sampler_train, batch_size=bs,
            drop_last=True)
# persistent_workers is only valid with at least one worker process.
dl_train = DeviceDataLoader(torch.utils.data.DataLoader(ds_train, 
            batch_sampler=len_sampler_train, num_workers=num_workers,
            persistent_workers=num_workers > 0), device)

# Validation loader: held-out part of the same fold, sequential order,
# keeping the final short batch.
ds_val = RNA_Dataset(df, mode='eval', seed=SEED, fold=fold, nfolds=nfolds)
ds_val_len = RNA_Dataset(df, mode='eval', seed=SEED, fold=fold, nfolds=nfolds, 
           mask_only=True)
sampler_val = torch.utils.data.SequentialSampler(ds_val_len)
len_sampler_val = LenMatchBatchSampler(sampler_val, batch_size=bs, 
           drop_last=False)
dl_val= DeviceDataLoader(torch.utils.data.DataLoader(ds_val, 
           batch_sampler=len_sampler_val, num_workers=num_workers), device)
gc.collect()



1733

In [18]:
# Instantiate the model with its default hyper-parameters and move it to the selected device.
model = RNA_Model()   
model = model.to(device)

In [19]:
# Grab a single training batch (rebinds `data`) for a quick smoke test.
for data in dl_train:
    break

In [20]:
# Smoke test on one batch. Call the module itself rather than .forward so
# that any registered hooks are run.
model(data[0]).shape

torch.Size([2, 177, 2])