In [1]:
import numpy as np
from sklearn.metrics import average_precision_score,roc_auc_score
import rna_model
import torch
import h5py
from torch.utils.data import DataLoader
from tqdm import tqdm
import glob
import pandas as pd
import data_preprocess
import h5py

In [2]:
# Build ./data/MT_Splice/psi_data.h5 from the GTEx PSI tables.
# For each split the file receives three datasets:
#   x_<split>    - left and right sequence windows, concatenated per exon
#   mean_<split> - the per-exon 'mean' value returned by the Ascot dataset
#   y_<split>    - the per-exon target returned by the Ascot dataset
# The on-disk split names are test / train / valid (the CSV for the last one
# is called 'val').
# The context manager guarantees the HDF5 file is flushed and closed even if
# a split fails half-way through.
with h5py.File('./data/MT_Splice/psi_data.h5', 'w') as h5f:
    for dataset in ['test','train','val']:
        data_dir = '/home/amber/multitask_RNA/data/MT_Splice/gtex_'+ dataset+'_psi.csv'
        data_class = data_preprocess.Ascot(ascot = data_dir,
                                      fasta_file = '/home/amber/ref/hg19/hg19.fa',
                                      pad_trim_same_l=False, mean_inpute=False,
                                      region_anno=False, length=400, flanking=300,
                                      seq_align='both', encode=True, flanking_exons=False,use_logit=False)
        # Four accumulators for four names (the original unpacked only three
        # lists into four names, which raises ValueError).
        fa, x, mean, y = [], [], [], []
        for i in tqdm(range(len(data_class))):
            item = data_class[i]
            # item[0] is a dict of inputs, item[1] is the target.
            fasta = item[0]['fasta']
            seq = np.concatenate([item[0]['seql'],item[0]['seqr']])
            miu = item[0]['mean']
            target = item[1]

            # NOTE(review): `fa` is collected but never written to the HDF5
            # file; kept in case a later cell relies on it.
            fa.append(fasta)
            x.append(seq)
            mean.append(miu)
            y.append(target)

        # Use a separate name for the on-disk split key instead of
        # overwriting the loop variable.
        split = 'valid' if dataset == 'val' else dataset
        h5f.create_dataset('x_'+split, data=np.asarray(x))
        h5f.create_dataset('mean_'+split, data=np.asarray(mean))
        h5f.create_dataset('y_'+split, data=np.asarray(y))
    

100%|█████████████████████████████| 11840/11840 [00:05<00:00, 2015.15it/s]
100%|█████████████████████████████| 38028/38028 [00:18<00:00, 2024.83it/s]
100%|███████████████████████████████| 1088/1088 [00:00<00:00, 2023.87it/s]


In [2]:
# Open the preprocessed PSI data read-only.
# NOTE(review): this handle is never closed explicitly, and a later cell
# reopens the same path in append mode — confirm the read-only handle is
# released before that cell runs.
file = h5py.File('/home/amber/multitask_RNA/data/MT_Splice/psi_data.h5','r')

In [3]:
# Grab the sequence dataset of every split from the open HDF5 file.
# These are h5py dataset handles backed by `file`, so `file` must stay open
# while they are in use.
x_train, x_test, x_valid = (
    file[key] for key in ('x_train', 'x_test', 'x_valid')
)

In [10]:
# Nucleotide for each channel index of the encoded sequences.
alphabet = np.array(['A','C','G','T'])
fa_train = []
fa_test = []
fa_valid = []

# Decode every validation sequence back to text and write it as one FASTA
# record ('>exon <index>' header followed by the sequence).
# argmax along axis 1 picks one channel per position, so this assumes each
# x_valid[i] is laid out as (position, channel) — TODO confirm.
# NOTE(review): a position whose channels are all equal (e.g. all-zero
# padding, if any exists) decodes to 'A', because argmax returns index 0.
with open('./data/MT_Splice/valid.fa', 'w') as f:
    for i in range(len(x_valid)):
        f.write('>exon '+ str(i)+'\n')
        f.write(''.join(alphabet[np.argmax(x_valid[i],axis = 1)])+'\n')
    # No explicit f.close(): the `with` block closes the file on exit.

In [84]:
# Append the structure score tracks to psi_data.h5 as 'fold_<split>'.
# For each split, four tab-separated text files (prefixes M, I, H, E) are
# read and stacked along axis 1, giving one array of shape
# (rows, 4, columns) per split.
# Both the HDF5 file and every text file are opened with context managers so
# nothing stays open if parsing fails.
# NOTE(review): any read-only handle still open on this HDF5 file should be
# closed before this cell runs — confirm against the earlier cells.
with h5py.File('/home/amber/multitask_RNA/data/MT_Splice/psi_data.h5','a') as file:
    for dataset in ['test','train','valid']:
        structure = []
        for run in ['M','I','H','E']:
            with open('./data/MT_Splice/'+run+'_'+dataset+'.txt','r') as score_file:
                score = score_file.readlines()
            # One list of tab-separated fields per line.
            score_list = np.char.split(np.array(score), sep ='\t')
            # Drop the last field of every line (presumably the trailing
            # newline token — TODO confirm against the file format).
            score_array = np.vstack(score_list)[:,:-1]

            structure.append(score_array)

        fold_array = np.stack(structure,axis=1)
        file.create_dataset('fold_'+dataset, data=fold_array.astype('float'))

NameError: name 'f' is not defined

In [98]:
# Inspect the validation targets stored in the HDF5 file.
# NOTE(review): `h5f` is closed at the end of the cell that builds
# psi_data.h5, so on a fresh top-to-bottom run this lookup needs a reopened
# handle — confirm which handle is intended here.
h5f['y_valid']

<HDF5 dataset "y_valid": shape (1088, 56), type "<f8">

## Transfer training

In [1]:
import rna_model
import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
from pytorch_lightning import Trainer
import numpy as np
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
import wandb

In [2]:
# Training and validation loaders read from the same HDF5 file and share
# identical loader settings; only the split name differs.
# NOTE(review): the training loader does not shuffle — confirm this is
# intended for rna_model.mt_splice_data.
dataset = './data/MT_Splice/psi_data.h5'
loader_options = dict(num_workers=4, pin_memory=True, batch_size=32)
train_loader = DataLoader(rna_model.mt_splice_data(dataset, 'train'), **loader_options)
valid_loader = DataLoader(rna_model.mt_splice_data(dataset, 'valid'), **loader_options)

In [3]:
import pytorch_lightning as pl
import rna_model

def nan_mask(pred, target):
    """Drop every position whose target is NaN (missing measurement).

    Args:
        pred: tensor of predictions, indexable by a boolean mask shaped
            like ``target``.
        target: tensor of targets in which NaN marks missing data.

    Returns:
        Tuple ``(pred, target)`` restricted to the positions where
        ``target`` is not NaN. Boolean-mask indexing flattens the result,
        so both tensors are 1-D.
    """
    observed = torch.isnan(target).logical_not()
    return pred[observed], target[observed]

class transfer_mt_model(pl.LightningModule):
    """Transfer-learning model that predicts one value in (0, 1) per experiment.

    Two pretrained networks (an RBP CNN and a Splice_AI_2K annotation model)
    are truncated and used as feature extractors. Their outputs are
    concatenated with the structure tracks along the channel axis and fed to
    a new convolutional output head ending in a sigmoid.

    Args:
        exp_num: number of outputs of the final linear layer.
        lr: learning rate of the output head. The pretrained backbones are
            fine-tuned with the smaller fixed rate ``BACKBONE_LR``.
        rbp_ckpt: optional path of the RBP checkpoint; defaults to
            ``RBP_CKPT``.
        annot_ckpt: optional path of the annotation checkpoint; defaults to
            ``ANNOT_CKPT``.
    """

    # Default checkpoints of the two pretrained backbones.
    RBP_CKPT = '/home/amber/multitask_RNA/wandb/run-20220617_144340-m9rrwfw8/files/best_model.ckpt'
    ANNOT_CKPT = '/home/amber/multitask_RNA/wandb/run-20220612_213207-zai2tsbz/files/best_model.ckpt'
    # Learning rate used for both pretrained backbones.
    BACKBONE_LR = 1e-5
    # Predictions are clamped to [EPS, 1 - EPS] before taking logs.
    EPS = 1e-7

    def __init__(self,exp_num,lr,rbp_ckpt=None,annot_ckpt=None):
        super().__init__()
        self.lr = lr
        rbp_ckpt = self.RBP_CKPT if rbp_ckpt is None else rbp_ckpt
        annot_ckpt = self.ANNOT_CKPT if annot_ckpt is None else annot_ckpt

        #RBP model
        rbp_module = rna_model.rbp_cnn(120,0.001)
        # map_location='cpu' lets the checkpoint load on machines that lack
        # the GPU it was saved from; Lightning moves the model to the
        # training device later.
        rbp_module.load_state_dict(torch.load(rbp_ckpt,map_location='cpu')['state_dict'])
        # Keep only a slice of the pretrained sub-modules; the indices depend
        # on the layout of rna_model.rbp_cnn.
        rbp_list = list(rbp_module.children())[2:-4]
        self.rbp = torch.nn.Sequential(*[rbp_list[0],rbp_list[1][:-5]])
        # The backbone is fine-tuned, not frozen.

        #Annotation model
        annot_model = rna_model.Splice_AI_2K((4,800),(3,800),32,0.001)
        annot_model.load_state_dict(torch.load(annot_ckpt,map_location='cpu')['state_dict'])
        self.annot = torch.nn.Sequential(*list(annot_model.children())[2:-3])
        # The backbone is fine-tuned, not frozen.

        #outputhead
        #(N,37,800)
        # Three conv / batch-norm / ReLU / max-pool(4) / dropout stages.
        # With a length-800 input the pooled length is 800 -> 200 -> 50 -> 12,
        # so the flattened size is 256 * 12 = 3072.
        self.outblock = torch.nn.Sequential(*[
            nn.Conv1d(37,128,8,padding = 'same'),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.MaxPool1d(4),
            nn.Dropout(0.2),
            nn.Conv1d(128,256,3,padding = 'same'),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.MaxPool1d(4),
            nn.Dropout(0.2),
            nn.Conv1d(256,256,3,padding = 'same'),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.MaxPool1d(4),
            nn.Dropout(0.2),
            nn.Flatten(),
            # NOTE(review): the three linear layers have no non-linearity
            # between them — confirm this is intended.
            nn.Linear(3072,256),
            nn.Linear(256,256),
            nn.Linear(256,exp_num),
            nn.Sigmoid()
        ])
        self.loss_func = torch.nn.KLDivLoss()

    def forward(self,x):
        """Compute predictions for a ``(sequence, structure)`` input pair.

        Args:
            x: indexable pair; ``x[0]`` is the sequence input given to both
                backbones and ``x[1]`` holds the structure tracks that are
                concatenated with the backbone features on dim 1.

        Returns:
            Sigmoid outputs of the output head, one column per experiment.
        """
        seq = x[0]
        seq_fold = x[1]
        rbp_pres = self.rbp(seq)
        # Zero-pad the second-to-last axis by 144 on each side, then swap the
        # last two axes so it can be concatenated on dim 1.
        # Presumably this aligns the RBP features with the annotation
        # features along the position axis — TODO confirm.
        rbp_pres = torch.nn.functional.pad(rbp_pres,(0,0,144,144,0,0))
        rbp_pres = rbp_pres.permute((0,2,1))
        annot_pres = self.annot(seq)
        pres = torch.cat((annot_pres,rbp_pres,seq_fold),dim=1)
        output = self.outblock(pres)
        return output

    def _masked_loss(self,batch):
        """Return the loss of one batch, ignoring entries with NaN targets.

        The loss is KLDivLoss(log y_hat, y) + KLDivLoss(log(1 - y_hat), 1 - y).
        Predictions are clamped to [EPS, 1 - EPS] first so that a saturated
        sigmoid output cannot produce log(0) = -inf.
        """
        x,y = batch
        y_hat = self(x)
        y_hat,y = nan_mask(y_hat,y)
        y_hat = y_hat.clamp(self.EPS,1-self.EPS)
        loss0 = self.loss_func(torch.log(y_hat),y)
        loss1 = self.loss_func(torch.log(1-y_hat),1-y)
        return loss0+loss1

    def training_step(self,batch,batch_idx):
        """Compute the batch loss and log it per epoch as ``train_loss``."""
        loss = self._masked_loss(batch)
        self.log("train_loss", loss,on_step = False, on_epoch = True)
        return loss

    def validation_step(self,batch,batch_idx):
        """Compute the batch loss and log it per epoch as ``val_loss``."""
        loss = self._masked_loss(batch)
        self.log("val_loss", loss,on_step = False, on_epoch = True)
        return loss

    def configure_optimizers(self):
        """Adam with per-group learning rates plus plateau-based LR decay.

        The backbones use ``BACKBONE_LR``. The output head uses the ``lr``
        given to the constructor (previously this value was ignored and
        1e-3 was hard-coded). The scheduler multiplies every group's rate by
        0.2 after 3 epochs without ``val_loss`` improvement, down to 1e-7.
        """
        self.opt=torch.optim.Adam([{"params":self.rbp.parameters(),"lr": self.BACKBONE_LR},
                                   {"params":self.annot.parameters(),"lr": self.BACKBONE_LR},
                                   {"params":self.outblock.parameters(),"lr": self.lr}])

        self.reduce_lr = torch.optim.lr_scheduler.ReduceLROnPlateau(self.opt,
                                                                     mode = 'min',
                                                                     factor = 0.2,
                                                                    patience = 3,
                                                                    min_lr = 1e-7,
                                                                    verbose = True)
        schedulers =  {'scheduler':self.reduce_lr,'monitor':"val_loss",}
        return [self.opt],schedulers


In [4]:
# Hyper-parameters are named once so that the values logged to W&B cannot
# drift from the values the model is actually built with.
EXP_NUM = 56        # number of model outputs
HEAD_LR = 1e-3      # learning rate passed to the model
BACKBONE_LR = 1e-5  # logged as 't_lr'; presumably the rate of the pretrained
                    # backbones — keep in sync with the model's optimizer setup

model = transfer_mt_model(EXP_NUM,HEAD_LR)
config={'model':'transfer_mt_model','lr':HEAD_LR,'t_lr':BACKBONE_LR}
wandb_logger = WandbLogger(project="mt_splice",config=config,log_model=True)
# Take the run directory from the logger's own run object: accessing
# `.experiment` starts the W&B run if needed, whereas `wandb.run` is None
# until a run has been initialised.
checkpoint_callback = ModelCheckpoint(save_top_k=1,
                                        monitor="val_loss",
                                        mode="min",
                                        dirpath=wandb_logger.experiment.dir,
                                        filename="best_model")
lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='epoch')
# Stop once val_loss has not improved for 10 consecutive validation checks.
earlystop = EarlyStopping(monitor="val_loss",
                            mode="min",patience=10)

[34m[1mwandb[0m: Currently logged in as: [33mambert[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [5]:
# Single-GPU trainer, capped at 100 epochs, with checkpointing, early
# stopping and LR monitoring.
# `accelerator`/`devices` replace the deprecated `gpus=1` flag.
# detect_anomaly=True turns on autograd anomaly detection, which helps
# locate NaN/inf gradients but slows training — consider disabling it once
# the run is stable.
trainer = pl.Trainer(accelerator='gpu',devices=1,detect_anomaly=True,max_epochs=100,
                    logger = wandb_logger,
                    callbacks=[checkpoint_callback,earlystop,lr_monitor])


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [6]:
# Run training; validation runs on `valid_loader`.
trainer.fit(
    model=model,
    train_dataloaders=train_loader,
    val_dataloaders=valid_loader,
)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [7]

  | Name      | Type       | Params
-----------------------------------------
0 | rbp       | Sequential | 2.2 M 
1 | annot     | Sequential | 358 K 
2 | outblock  | Sequential | 1.2 M 
3 | loss_func | KLDivLoss  | 0     
-----------------------------------------
3.7 M     Trainable params
0         Non-trainable params
3.7 M     Total params
14.880    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  return F.conv1d(input, weight, bias, self.stride,
                not been set for this class (_ResultMetric). The property determines if `update` by
                default needs access to the full metric state. If this is not the case, significant speedups can be
                achieved and we recommend setting this to `False`.
                We provide an checking function
                `from torchmetrics.utilities import check_forward_no_full_state`
                that can be used to check if the `full_state_update=True` (old and potential slower behaviour,
                default for now) or if `full_state_update=False` can be used safely.
                


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 00005: reducing learning rate of group 0 to 2.0000e-06.
Epoch 00005: reducing learning rate of group 1 to 2.0000e-06.
Epoch 00005: reducing learning rate of group 2 to 2.0000e-04.


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 00009: reducing learning rate of group 0 to 4.0000e-07.
Epoch 00009: reducing learning rate of group 1 to 4.0000e-07.
Epoch 00009: reducing learning rate of group 2 to 4.0000e-05.


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]