In [1]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('..')
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [2]:
import rnacomp
import rnacomp.models
from rnacomp.fastai_fit import *
from rnacomp.dataset import LenMatchBatchSampler, DeviceDataLoader
from rnacomp.utils import seed_everything, MAE, loss_laplace
import gc



In [3]:
class CFG:
    path = Path("../data/")
    pathbb = Path("../data/Ribonanza_bpp_files")
    pathss = Path("../eda/train_ss_vienna_rna.parquet")
    split_id = Path('../eda/fold_split.csv')
    bs = 64
    num_workers = 12
    device = 'cuda'
    seed = 2023
    out = 'exp_10'
    dataset_name = 'RNA_DatasetBaselineSplitssV1'
    sn_train = False
    
    model_name = 'RNA_ModelV2SS'
    dim = 192
    depth = 12
    dim_head = 32
   
    epoch = 64
    lr = 5e-4
    wd = 0.05
    pct_start = 0.02
    
seed_everything(CFG.seed)
os.makedirs(CFG.out, exist_ok=True)

In [4]:
fns = list(CFG.pathbb.rglob("*.txt"))
ss = pd.read_parquet(CFG.pathss)[["sequence_id", "ss_full"]]
df = pd.read_parquet(CFG.path/'train_data.parquet')
split = pd.read_csv(CFG.split_id)
df = pd.merge(df, split, on='sequence_id')
df = pd.merge(df, ss, on='sequence_id')
df_train = df.query('is_train==True').reset_index(drop=True)
df_valid = df.query('is_train==False').reset_index(drop=True)

ds_train = getattr(rnacomp.dataset, CFG.dataset_name)(df_train, mode='train',sn_train=CFG.sn_train)
ds_train_len = getattr(rnacomp.dataset, CFG.dataset_name)(df_train, mode='train', mask_only=True,sn_train=CFG.sn_train)
sampler_train = torch.utils.data.RandomSampler(ds_train_len)
len_sampler_train = LenMatchBatchSampler(sampler_train, batch_size=CFG.bs,
            drop_last=True)
dl_train = DeviceDataLoader(torch.utils.data.DataLoader(ds_train, 
            batch_sampler=len_sampler_train, num_workers=CFG.num_workers,
            persistent_workers=True), CFG.device)


ds_val = getattr(rnacomp.dataset, CFG.dataset_name)(df_valid, mode='eval')
ds_val_len = getattr(rnacomp.dataset, CFG.dataset_name)(df_valid, mode='eval', mask_only=True)
sampler_val = torch.utils.data.SequentialSampler(ds_val_len)
len_sampler_val = LenMatchBatchSampler(sampler_val, batch_size=CFG.bs, 
               drop_last=False)
dl_val= DeviceDataLoader(torch.utils.data.DataLoader(ds_val, 
               batch_sampler=len_sampler_val, num_workers=CFG.num_workers), CFG.device)

data = DataLoaders(dl_train,dl_val)
gc.collect()

16

In [None]:
learn = Learner(data,
                getattr(rnacomp.models, CFG.model_name)().cuda(), 
                path = CFG.out, 
                loss_func=loss_laplace,
                cbs=[GradientClip(3.0),
                    CSVLogger(),
                    SaveModelCallback(monitor='mae',comp=np.less,at_end=True)],
                metrics=[MAE()]).to_fp16() 
learn.fit_one_cycle(CFG.epoch, lr_max=CFG.lr, wd=CFG.wd, pct_start=CFG.pct_start)

epoch,train_loss,valid_loss,mae,time
0,0.129364,0.150852,0.162502,13:24
1,0.125005,0.146319,0.155701,13:27
2,0.121534,0.138519,0.1479,13:24
3,0.120992,0.136938,0.146753,13:29
4,0.120525,0.136677,0.146051,13:28
5,0.120109,0.134987,0.144529,13:30
6,0.118847,0.133615,0.143586,13:31
7,0.120376,0.134623,0.143011,13:44
8,0.118422,0.133353,0.143125,13:33
9,0.117844,0.131863,0.142052,13:35


Better model found at epoch 0 with mae value: 0.16250160996929314.
Better model found at epoch 1 with mae value: 0.15570108656365705.
Better model found at epoch 2 with mae value: 0.14789972458082956.
Better model found at epoch 3 with mae value: 0.14675329386964994.
Better model found at epoch 4 with mae value: 0.14605135974653943.
Better model found at epoch 5 with mae value: 0.14452896208019458.
Better model found at epoch 6 with mae value: 0.14358586932288508.
Better model found at epoch 7 with mae value: 0.14301126081108123.
Better model found at epoch 9 with mae value: 0.14205235057333346.
Better model found at epoch 10 with mae value: 0.14199062850046595.
Better model found at epoch 11 with mae value: 0.14186941386148985.
Better model found at epoch 12 with mae value: 0.14126443611326503.
Better model found at epoch 14 with mae value: 0.14063776745227743.
Better model found at epoch 16 with mae value: 0.14018476660744686.
Better model found at epoch 20 with mae value: 0.13925806

In [5]:
res = dict()
for l in df_valid["L"].unique():
    ds_val = getattr(rnacomp.dataset, CFG.dataset_name)(df_valid.query("L==@l").copy(), mode='eval')
    ds_val_len = getattr(rnacomp.dataset, CFG.dataset_name)(df_valid.query("L==@l").copy(), mode='eval', mask_only=True)
    sampler_val = torch.utils.data.SequentialSampler(ds_val_len)
    len_sampler_val = LenMatchBatchSampler(sampler_val, batch_size=CFG.bs, 
                   drop_last=False)
    dl_val= DeviceDataLoader(torch.utils.data.DataLoader(ds_val, 
                   batch_sampler=len_sampler_val, num_workers=CFG.num_workers), CFG.device)

    data = DataLoaders(dl_val,dl_val)
    learn = Learner(data,
                    getattr(rnacomp.models, CFG.model_name)().cuda(), 
                    path = CFG.out, 
                    loss_func=loss_laplace,
                    metrics=[MAE()]).to_fp16() 
    learn.load('model')
    learn.eval()
    loss_, score_ = learn.validate()
    res[l]= score_

res = pd.DataFrame(pd.Series(res)).reset_index()
res.columns = ['L', 'mae']
res.sort_values(by="L")

  elif with_opt: warn("Saved filed doesn't contain an optimizer state.")


  elif with_opt: warn("Saved filed doesn't contain an optimizer state.")


  elif with_opt: warn("Saved filed doesn't contain an optimizer state.")


  elif with_opt: warn("Saved filed doesn't contain an optimizer state.")


  elif with_opt: warn("Saved filed doesn't contain an optimizer state.")


Unnamed: 0,L,mae
2,115,0.09983
3,155,0.108777
0,170,0.161112
1,177,0.132897
4,206,0.106686


In [6]:
res

Unnamed: 0,L,mae
0,170,0.161112
1,177,0.132897
2,115,0.09983
3,155,0.108777
4,206,0.106686


In [None]:
res