In [None]:
%load_ext autoreload
# Reload edited project modules before each cell runs.
%autoreload 2
import sys
# Make the parent directory importable (presumably where the `rnacomp` package lives).
sys.path.append('..')
import os
# Pin the process to GPU 0. This only takes effect if set before CUDA is
# initialised, i.e. before the torch/rnacomp imports in the next cell.
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [None]:
import rnacomp
import rnacomp.models
from rnacomp.fastai_fit import *
from rnacomp.dataset import LenMatchBatchSampler, DeviceDataLoader
from rnacomp.utils import seed_everything, MAE, loss_laplace
import gc
import wandb 
from fastai.callback.wandb import WandbCallback
from fastxtend.vision.all import EMACallback

from itertools import combinations



In [3]:
class CFG:
    """Run configuration for evaluating the exp_32_ft checkpoint.

    Used as a plain namespace (``CFG.<name>``); never instantiated.
    """
    # --- data ---
    path = Path("../data/")                               # holds train_data.parquet
    pathbb = Path("../data/Ribonanza_bpp_files")          # searched recursively for *.txt BPP files
    pathss = Path("../eda/train_ss_vienna_rna.parquet")   # provides sequence_id, ss_full
    split_id = Path('../eda/fold_split.csv')              # split table merged on sequence_id
    bs = 64
    num_workers = 4
    device = 'cuda'
    seed = 2023
    out = 'exp_32_ft'                                     # experiment dir, also the Learner path
    dataset_name = 'RNA_DatasetBaselineSplitssbppV5BTTA'  # class name looked up in rnacomp.dataset

    # --- model ---
    model_name = 'RNA_ModelV25'                           # class name looked up in rnacomp.models
    model_kwargs = dict(dim=192 * 2,
        depth=4,
        head_size=32,
        drop_pat_dropout=0.2,
        dropout=0.2,
        # NOTE(review): spelling presumably matches the model's keyword argument — do not "fix" here alone.
        bpp_transfomer_depth = 4)

    # --- training / dataset options ---
    sn_train = True                                       # passed to the training dataset
    # Not referenced by the evaluation cells shown here.
    epoch = 16
    lr = 5e-6
    wd = 0.05
    pct_start = 0.01

    # Checkpoint path derived from `out` so the two cannot drift apart.
    md_wt = f'{out}/models/model.pth'
    
# Seed RNGs via rnacomp.utils.seed_everything before any data loading or model
# construction (presumably covers python/numpy/torch — confirm in rnacomp.utils).
seed_everything(CFG.seed)

In [4]:
def class_to_dict(cls):
    """Collect the attributes defined directly on ``cls`` into a dict.

    Names starting with a double underscore and callable values (e.g. methods)
    are skipped; inherited attributes are not included.
    """
    attrs = {}
    for name, attr in vars(cls).items():
        if name.startswith("__") or callable(attr):
            continue
        attrs[name] = attr
    return attrs

In [5]:
# Index BPP files by sequence_id. rglob yields files in filesystem order, so sort
# first: drop_duplicates keeps the first match, and which file wins for a
# duplicated sequence_id must not depend on directory-listing order.
fns = sorted(CFG.pathbb.rglob("*.txt"))
bpp_df = (pd.DataFrame({"bpp": fns})
          .assign(sequence_id=lambda d: d['bpp'].map(lambda p: p.stem))
          .drop_duplicates(subset='sequence_id'))
ss = pd.read_parquet(CFG.pathss)[["sequence_id", "ss_full"]]
df = pd.read_parquet(CFG.path/'train_data.parquet')
split = pd.read_csv(CFG.split_id)
# Inner joins: sequences lacking a split row, a BPP file or a secondary
# structure are dropped silently from both train and validation.
df = (df.merge(split, on='sequence_id')
        .merge(bpp_df, on='sequence_id')
        .merge(ss, on='sequence_id'))
#df = df.query("SN_filter==1").reset_index(drop=True)
df_train = df.query('is_train==True').reset_index(drop=True)
df_valid = df.query('is_train==False').reset_index(drop=True)


ds_train = getattr(rnacomp.dataset, CFG.dataset_name)(df_train, mode='train', sn_train=CFG.sn_train)
dl_train = DeviceDataLoader(torch.utils.data.DataLoader(ds_train,
            num_workers=CFG.num_workers, batch_size=CFG.bs,
            # DataLoader rejects persistent_workers=True when num_workers == 0
            persistent_workers=CFG.num_workers > 0, shuffle=True), CFG.device)


ds_val = getattr(rnacomp.dataset, CFG.dataset_name)(df_valid, mode='eval')
dl_val = DeviceDataLoader(torch.utils.data.DataLoader(ds_val,
               batch_size=CFG.bs, num_workers=CFG.num_workers), CFG.device)

data = DataLoaders(dl_train, dl_val)

# Free the intermediate frames; df_train / df_valid are kept.
del bpp_df, ss, split, df
gc.collect();


0

In [8]:
def get_combinations(lst):
    """Return every non-empty combination of the elements of ``lst``.

    Ordered by size (1, 2, ..., len(lst)), and in ``itertools.combinations``
    order within a size, e.g.
    ``[1, 2, 3] -> [(1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]``.
    An empty input gives an empty list.
    """
    return [combo for r in range(1, len(lst) + 1) for combo in combinations(lst, r)]


class BppModelWrapper(nn.Module):
    """Average a model's predictions over subsets of BPP channels.

    For every non-empty subset of ``list_of_idx``, ``batch['bb_matrix_full_prob_extra']``
    is replaced by the mean of the selected entries along dim 1, and the wrapped
    model is run. The stacked predictions are then averaged.

    Args:
        md: wrapped model, called as ``md(batch)``; must return a tensor of the
            same shape for every subset so results can be stacked.
        list_of_idx: indices along dim 1 of ``batch['bb_matrix_full_prob_extra']``
            to combine. NOTE(review): what each index holds is defined by the
            dataset — confirm.

    Raises:
        ValueError: if ``list_of_idx`` is empty (there would be nothing to average).
    """

    def __init__(self, md=None, list_of_idx=(1, 2, 3)):
        super().__init__()
        self.combs = get_combinations(list(list_of_idx))
        if not self.combs:
            raise ValueError("list_of_idx must contain at least one index")
        self.md = md

    def forward(self, batch):
        """Return the mean of ``self.md(batch)`` over all channel subsets.

        The key is swapped in place for each call to ``self.md``. It is always
        restored (even if ``self.md`` raises), so the caller's batch still holds
        the full tensor afterwards and a repeated forward gives the same result.
        """
        key = 'bb_matrix_full_prob_extra'
        full_bpp = batch[key]
        preds = []
        try:
            for comb in self.combs:
                # A list index forces advanced indexing that keeps dim 1, even for
                # single-index subsets, so .mean(1) always reduces the chosen entries.
                batch[key] = full_bpp[:, list(comb)].mean(1)
                preds.append(self.md(batch))
        finally:
            batch[key] = full_bpp
        return torch.stack(preds).mean(0)
        
        

In [9]:

# Validate the checkpoint on the full validation split. BppModelWrapper averages
# predictions over subsets of BPP channels; validate() returns [loss, MAE].
learn = Learner(data,
                BppModelWrapper(getattr(rnacomp.models, CFG.model_name)(**CFG.model_kwargs).to(CFG.device)),
                path=CFG.out,
                loss_func=loss_laplace,
                metrics=[MAE()]).to_fp16()
# map_location loads straight onto CFG.device, whatever device the checkpoint was saved from.
learn.model.md.load_state_dict(torch.load(CFG.md_wt, map_location=CFG.device))
learn.model.eval();
learn.validate()


(#2) [0.11581036448478699,0.1249439363984313]

In [None]:
# Validation MAE broken down by sequence length L.
# The wrapped model and its weights do not depend on L, so build and load them
# once. Each length gets its own dataset, length-matched batch sampler and Learner.
# Local names are used so the globals `data`/`learn`/`ds_val`/`dl_val` are not clobbered.
len_model = BppModelWrapper(getattr(rnacomp.models, CFG.model_name)(**CFG.model_kwargs).to(CFG.device))
len_model.md.load_state_dict(torch.load(CFG.md_wt, map_location=CFG.device))

mae_by_len = {}
for seq_len in sorted(df_valid["L"].unique()):
    df_len = df_valid.query("L == @seq_len")
    ds_len = getattr(rnacomp.dataset, CFG.dataset_name)(df_len.copy(), mode='eval')
    # mask-only copy of the dataset; used only to drive the length-matching sampler
    ds_len_mask = getattr(rnacomp.dataset, CFG.dataset_name)(df_len.copy(), mode='eval', mask_only=True)
    len_sampler = LenMatchBatchSampler(torch.utils.data.SequentialSampler(ds_len_mask),
                                       batch_size=CFG.bs, drop_last=False)
    dl_len = DeviceDataLoader(torch.utils.data.DataLoader(ds_len,
                              batch_sampler=len_sampler, num_workers=CFG.num_workers), CFG.device)
    learn_len = Learner(DataLoaders(dl_len, dl_len), len_model,
                        path=CFG.out,
                        loss_func=loss_laplace,
                        metrics=[MAE()]).to_fp16()
    _, mae = learn_len.validate()   # validate() returns [loss, MAE]
    mae_by_len[seq_len] = mae

# Keep the sorted frame (the previous sort_values() result was only displayed, never stored).
res = (pd.Series(mae_by_len, name='mae')
         .rename_axis('L')
         .reset_index()
         .sort_values(by='L')
         .reset_index(drop=True))
res

In [None]:
# Per-length validation MAE (columns: L, mae).
res