In [1]:
%load_ext autoreload
%autoreload 2
import os
import sys

# Let this notebook import the project package (presumably `rnacomp`, imported in the
# next cell) from the parent directory.
sys.path.append('..')

# Pin the run to GPU 0 and cap the MKL / numexpr / OpenMP thread pools at 8.
# Done here, before the next cell imports the ML libraries, because such variables are
# typically only read when those libraries initialise.
os.environ.update({
    "CUDA_VISIBLE_DEVICES": "0",
    "MKL_NUM_THREADS": "8",
    "NUMEXPR_NUM_THREADS": "8",
    "OMP_NUM_THREADS": "8",
})

In [2]:
import rnacomp
import rnacomp.models
from rnacomp.fastai_fit import *
from rnacomp.dataset import LenMatchBatchSampler, DeviceDataLoader
from rnacomp.utils import seed_everything, MAE, loss_laplace, LossDict, MAEDict
import gc
import wandb 
from fastai.callback.wandb import WandbCallback
from fastxtend.vision.all import EMACallback

In [3]:
class CFG:
    """Run configuration, used as a namespace (attributes are read as ``CFG.<name>``).

    Presumably fine-tunes the exp_32_psd_v3 model (see ``md_wt``) with the extra
    RMDB data in ``path_extra``; outputs go to ``out``.
    """

    # ---- data locations --------------------------------------------------------
    path = Path("../data/")                              # root of the competition data
    pathbb = Path("../data/Ribonanza_bpp_files")         # bpp files (presumably read by the dataset class)
    pathss = Path("../eda/train_ss_vienna_rna.parquet")  # secondary-structure table (presumably ViennaRNA)
    split_id = Path('../eda/fold_split.csv')             # per-sequence_id train/valid assignment
    path_extra = Path('../data/rmdb_data.v1.3.0.csv')    # external RMDB dataset

    # ---- loading / runtime -----------------------------------------------------
    bs = 32
    num_workers = 4
    device = 'cuda'
    seed = 2023
    out = 'exp_32_psd_v3_ex_ft'  # output directory, also the Learner's path

    # ---- datasets (class names resolved via getattr on rnacomp.dataset) --------
    dataset_name = 'RNA_DatasetBaselineSplitssbppV6SAVEDwithFM'
    dataset_external = 'RNA_DatasetEXV0'
    sn_train = False

    # ---- model (class name resolved via getattr on rnacomp.models) -------------
    model_name = 'RNA_ModelV25External'
    model_kwargs = {
        "dim": 192 * 2,
        "depth": 4,
        "head_size": 32,
        "drop_pat_dropout": 0.2,
        "dropout": 0.2,
        # (sic) spelling kept verbatim: it is passed to the model as a keyword argument.
        "bpp_transfomer_depth": 4,
    }

    # ---- optimisation (fit_one_cycle) ------------------------------------------
    epoch = 9
    lr = 4e-5
    wd = 0.05
    pct_start = 0.05

    # Checkpoint whose weights initialise the model before training.
    md_wt = 'exp_32_psd_v3/models/model.pth'
    
# Seed the RNGs via the project's seed_everything, and make sure the output directory
# (also passed to the Learner as its path) exists.
seed_everything(CFG.seed)
os.makedirs(CFG.out, exist_ok=True)

# Weights & Biases logging is not enabled in this run (WandbCallback is commented out in
# the training cell). A wandb quick-start placeholder used to sit here whose metadata
# ("CNN", "CIFAR-100", lr=0.02, 10 epochs) contradicted CFG. To enable logging, record the
# real settings instead, e.g.
#     wandb.init(project="<project>", config=class_to_dict(CFG))
# (class_to_dict is defined in the next cell) and uncomment WandbCallback.

In [4]:
def class_to_dict(cls):
    """Return the plain attributes defined directly on ``cls`` as a ``{name: value}`` dict.

    Meant for flattening a config namespace such as ``CFG`` (e.g. for a logger's
    ``config=``). Only the class's own ``__dict__`` is read, so inherited attributes
    are not included. Order follows definition order.

    Skipped entries:
      * dunder names (``__module__``, ``__doc__``, ...);
      * callables (plain functions/methods, nested classes);
      * ``staticmethod`` / ``classmethod`` / ``property`` wrappers. These are not
        callable on every Python version (classmethod/property never are), so the
        ``callable`` check alone let them leak through as non-config values.

    Args:
        cls: the class to read.

    Returns:
        dict mapping attribute name to value.
    """
    descriptor_types = (staticmethod, classmethod, property)
    return {
        key: value
        for key, value in vars(cls).items()
        if not key.startswith("__")
        and not callable(value)
        and not isinstance(value, descriptor_types)
    }

In [5]:



# Attach the train/valid assignment (`is_train`) from the fold-split file to the
# corrected training table, then split on it.
split = pd.read_csv(CFG.split_id)
df = pd.read_parquet(CFG.path/'train_corrected.parquet')
# validate='many_to_one' raises if a sequence_id is listed more than once in the split
# file, which would otherwise silently duplicate rows in the merged table.
# NOTE(review): the default inner join drops sequences absent from the split file — confirm intended.
df = pd.merge(df, split, on='sequence_id', validate='many_to_one')
#df = df.query("SN_filter==1").reset_index(drop=True)
df_train = df.query('is_train==True').reset_index(drop=True)
df_valid = df.query('is_train==False').reset_index(drop=True)
# The two queries are disjoint; this catches rows whose `is_train` is missing or
# non-boolean, which would otherwise vanish from both splits without warning.
assert len(df_train) + len(df_valid) == len(df), "some rows have is_train neither True nor False"

# Training data: the competition table plus the external CSV dataset (built with repeat=4).
ds_train = torch.utils.data.ConcatDataset([
    getattr(rnacomp.dataset, CFG.dataset_name)(df_train, mode='train', sn_train=CFG.sn_train, Lmax=433),
    getattr(rnacomp.dataset, CFG.dataset_external)(pd.read_csv(CFG.path_extra),
                                                    mode='train',
                                                    sn_train=CFG.sn_train,
                                                    repeat=4),
])

dl_train = DeviceDataLoader(
    torch.utils.data.DataLoader(ds_train,
                                batch_size=CFG.bs,
                                drop_last=True,
                                shuffle=True,
                                num_workers=CFG.num_workers,
                                persistent_workers=True),
    CFG.device)

ds_val = getattr(rnacomp.dataset, CFG.dataset_name)(df_valid, mode='eval')
dl_val = DeviceDataLoader(
    torch.utils.data.DataLoader(ds_val,
                                batch_size=CFG.bs,
                                drop_last=False,
                                num_workers=CFG.num_workers),
    CFG.device)

data = DataLoaders(dl_train, dl_val)

# Release the full merged table; df_train / df_valid are kept.
del split, df
gc.collect();  # trailing ';' keeps the collected-object count out of the cell output



59

In [6]:
# The training was interrupted at epoch 3, so it was restarted from scratch with the
# model initialised from the weights at CFG.md_wt.
learn = Learner(
    data,
    getattr(rnacomp.models, CFG.model_name)(**CFG.model_kwargs).to(CFG.device),
    path=CFG.out,
    loss_func=LossDict(),
    cbs=[
        GradientClip(3.0),
        #WandbCallback(log_preds=False),  # disabled: no wandb.init() is run in this notebook
        CSVLogger(),
        EMACallback(replace_weights=True),
        # Saves whenever mae_dict improves.
        # NOTE(review): with at_end=True fastai also saves at the end of fit, which presumably
        # overwrites the best checkpoint with the final weights — confirm that is intended.
        SaveModelCallback(monitor='mae_dict', comp=np.less, at_end=True),
    ],
    metrics=[MAEDict()],
).to_fp16()

# strict=False tolerates parameters present in only one of checkpoint/model (presumably
# because the checkpoint comes from a different model variant), but it also loads
# nothing, silently, if the key names don't line up at all — so report and check.
ckpt_state = torch.load(CFG.md_wt, map_location='cpu')
load_result = learn.model.load_state_dict(ckpt_state, strict=False)
print(f"{CFG.md_wt}: {len(load_result.missing_keys)} missing keys, "
      f"{len(load_result.unexpected_keys)} unexpected keys")
assert len(load_result.missing_keys) < len(learn.model.state_dict()), \
    f"no parameters from {CFG.md_wt} matched the model"
del ckpt_state

learn.fit_one_cycle(CFG.epoch, lr_max=CFG.lr, wd=CFG.wd, pct_start=CFG.pct_start)

# Only close a W&B run if one was actually started.
if wandb.run is not None:
    wandb.finish()




epoch,train_loss,valid_loss,mae_dict,time
0,0.242033,0.118288,0.125524,1:35:07
1,0.22873,0.118072,0.125382,1:35:29
2,0.22884,0.117948,0.12528,1:35:45
3,0.225624,0.117886,0.125236,1:35:56
4,0.225496,0.117798,0.125142,1:35:11
5,0.22241,0.117762,0.1251,1:35:19
6,0.216657,0.117732,0.125064,1:35:34
7,0.218843,0.117748,0.125074,1:35:19
8,0.21506,0.117749,0.125075,1:35:40


Better model found at epoch 0 with mae_dict value: 0.125523548437544.
Better model found at epoch 1 with mae_dict value: 0.12538198475967718.
Better model found at epoch 2 with mae_dict value: 0.12527960232412202.
Better model found at epoch 3 with mae_dict value: 0.1252358562368593.
Better model found at epoch 4 with mae_dict value: 0.12514241541387486.
Better model found at epoch 5 with mae_dict value: 0.1250995370265668.
Better model found at epoch 6 with mae_dict value: 0.1250642769168491.


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [8]:
# Per-sequence-length validation MAE — currently DISABLED (every line is commented out).
# It would rebuild a validation loader for each distinct L in df_valid, reload the saved
# 'model' checkpoint and collect learn.validate() scores into `res` (one row per L).
# NOTE(review): this Learner uses loss_laplace / MAE(), whereas the training cell uses
# LossDict() / MAEDict(); update the loss and metric before re-enabling.
# `res` is only defined if this cell actually runs.
# res = dict()
# for l in df_valid["L"].unique():
#     ds_val = getattr(rnacomp.dataset, CFG.dataset_name)(df_valid.query("L==@l").copy(), mode='eval')
#     ds_val_len = getattr(rnacomp.dataset, CFG.dataset_name)(df_valid.query("L==@l").copy(), mode='eval', mask_only=True)
#     sampler_val = torch.utils.data.SequentialSampler(ds_val_len)
#     len_sampler_val = LenMatchBatchSampler(sampler_val, batch_size=CFG.bs, 
#                    drop_last=False)
#     dl_val= DeviceDataLoader(torch.utils.data.DataLoader(ds_val, 
#                    batch_sampler=len_sampler_val, num_workers=CFG.num_workers), CFG.device)

#     data = DataLoaders(dl_val,dl_val)
#     learn = Learner(data,
#                     getattr(rnacomp.models, CFG.model_name)(**CFG.model_kwargs).cuda(),
#                     path = CFG.out, 
#                     loss_func=loss_laplace,
#                     metrics=[MAE()]).to_fp16() 
#     learn.load('model')
#     learn.eval()
#     loss_, score_ = learn.validate()
#     res[l]= score_

# res = pd.DataFrame(pd.Series(res)).reset_index()
# res.columns = ['L', 'mae']
# res.sort_values(by="L")

In [None]:
# `res` is only created by the (commented-out) per-length evaluation cell above; guard the
# display so Restart & Run All does not stop here with a NameError.
res if 'res' in globals() else None

In [None]:
# Intentionally empty: this was a duplicate bare `res` display, which raised NameError on a
# fresh kernel because `res` is only defined by the commented-out evaluation cell.