In [1]:
# Standard library first: the GPU selection below has to happen before any
# CUDA-aware package is imported.
import os
import sys

# Expose only GPU index 1 to this process. CUDA reads this variable when it is
# first initialised, so it is set ahead of the torch / Lightning / project
# imports to guarantee it takes effect no matter what those imports do.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import h5py
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from torchinfo import summary

# Make the project's evaluation helpers importable (this is where
# `mpra_model` is loaded from).
sys.path.append('/home/amber/multitask_RNA/evaluation/')
import mpra_model

# Newly created floating-point tensors default to float32.
torch.set_default_dtype(torch.float32)

In [2]:
# HDF5 file handed to `mpra_model.onehot_dataset` in the next cell
# (judging by the filename: one-hot encoded K562 lenti-MPRA data).
file = '/home/amber/multitask_RNA/data/lenti_MPRA/K562_onehot.h5'

# Network hyperparameters, one per line so they are easy to tweak and diff.
# NOTE(review): the meaning of each argument is defined in `mpra_model`,
# which is not visible from this notebook.
model = mpra_model.MPRA_model(
    filter_num=192,
    kernel_size=7,
    residual_num=2,
    exp_num=1,
)

In [3]:
# Fixed seed so the train/validation partition is identical on every run;
# without it the held-out set changes each time the kernel restarts and
# validation losses from different runs are not comparable.
SPLIT_SEED = 42

dataset = mpra_model.onehot_dataset(file)

# 90 % / 10 % split. The validation size is the remainder, so the two lengths
# always sum to len(dataset) as `random_split` requires.
n_train = int(len(dataset) * 0.9)
n_valid = len(dataset) - n_train
train, valid = torch.utils.data.random_split(
    dataset,
    [n_train, n_valid],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Reshuffle the training samples every epoch so the model does not see the
# same batches in the same order each time. Validation order is irrelevant,
# so it is left unshuffled.
train_loader = DataLoader(train, num_workers=4, batch_size=256, shuffle=True)
valid_loader = DataLoader(valid, num_workers=4, batch_size=256)

In [5]:
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint,LearningRateMonitor

# Keep only the single best checkpoint (lowest "val_loss") in the K562 model
# directory.
# NOTE(review): assumes the LightningModule logs a metric named "val_loss" —
# confirm in `mpra_model`.
checkpoint_callback = ModelCheckpoint(
    dirpath='/home/amber/multitask_RNA/model/lenti_MPRA_onehot/K562',
    save_top_k=1,
    monitor="val_loss",
    mode="min",
)

# Stop once "val_loss" has failed to improve for 7 consecutive checks.
earlystop = EarlyStopping(monitor="val_loss", mode="min", patience=7)

# Record the learning rate once per epoch.
lr_monitor = LearningRateMonitor(logging_interval='epoch')

# `accelerator="cpu"` is the non-deprecated equivalent of the former `gpus=0`,
# so the run stays on CPU as before; switch to accelerator="gpu", devices=1 to
# train on the visible GPU.
# `checkpoint_callback` must be listed here: it was previously constructed but
# never registered, so the best model was never written to the directory above.
# `detect_anomaly=True` enables autograd anomaly detection, a debugging aid
# that slows training.
trainer = pl.Trainer(
    accelerator="cpu",
    detect_anomaly=True,
    max_epochs=100,
    callbacks=[checkpoint_callback, earlystop, lr_monitor],
)

# Per the Lightning docs `Trainer.fit` returns None, so `history` carries no
# metrics; the name is kept only in case later cells reference it.
history = trainer.fit(
    model=model,
    train_dataloaders=train_loader,
    val_dataloaders=valid_loader,
)

  rank_zero_deprecation(
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(

   | Name         | Type        | Params
----------------------------------------------
0  | loss_func    | MSELoss     | 0     
1  | first_conv   | Conv1d      | 5.6 K 
2  | batch1       | BatchNorm1d | 384   
3  | maxpool      | MaxPool1d   | 0     
4  | residual     | Sequential  | 775 K 
5  | flat         | Flatten     | 0     
6  | linear       | Linear      | 258 K 
7  | batch2       | BatchNorm1d | 384   
8  | last_act     | ReLU        | 0     
9  | last_dropout | Dropout     | 0     
10 | dense        | Linear      | 24.7 K
11 | dense2       | Linear      | 129   
----------------------------------------------
1.1 M     Trainable params
0         Non-trainable params
1.1 M     Total params
4.261     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [6]:
# Layer-by-layer summary for a dummy input of shape (128, 4, 230), run on the
# GPU. Judging by the Conv1d output shape in the table this is
# (batch, channels, length) — TODO confirm against `mpra_model`.
# Left as the cell's final expression so the summary table is displayed.
summary(
    model,
    input_size=(128, 4, 230),
    device='cuda',
)

torch.Size([128, 192, 7])
torch.Size([128, 1344])


Layer (type:depth-idx)                   Output Shape              Param #
MPRA_model                               [128, 1]                  --
├─MSELoss: 1-1                           --                        --
├─Conv1d: 1-2                            [128, 192, 230]           5,568
├─BatchNorm1d: 1-3                       [128, 192, 230]           384
├─MaxPool1d: 1-4                         [128, 192, 28]            --
├─Sequential: 1-5                        [128, 192, 7]             --
│    └─dilated_residual: 2-1             [128, 192, 28]            --
│    │    └─Sequential: 3-1              [128, 192, 28]            775,872
│    │    └─ReLU: 3-2                    [128, 192, 28]            --
│    └─MaxPool1d: 2-2                    [128, 192, 7]             --
├─Flatten: 1-6                           [128, 1344]               --
├─Linear: 1-7                            [128, 192]                258,240
├─BatchNorm1d: 1-8                       [128, 192]                384
