In [5]:
import torch
import torch.nn as nn
import pytorch_lightning as pl

import parnet

In [2]:
import gin

# Interactive mode lets this cell be re-run: gin would otherwise refuse to
# re-register the `train` configurable defined below.
gin.enter_interactive_mode()


@gin.configurable()
def train(**kwargs):
    """No-op placeholder registered with gin under the name ``train``.

    NOTE(review): presumably config.gin binds parameters of ``train``; a
    configurable of that name must exist for those bindings to parse.
    """
    del kwargs  # accepts any binding; intentionally unused
    return None


# Last expression: the parse result (imports/includes) is shown as the cell output.
gin.parse_config_file('../configs/config.gin')

ParsedConfigFileIncludesAndImports(filename='../configs/config.gin', imports=['torch', 'pytorch_lightning', 'parnet.networks', 'parnet.layers', 'parnet.metrics'], includes=[])

In [3]:
# Instantiate the network with no explicit arguments; its configuration presumably comes
# from the gin bindings parsed above (config.gin imports parnet.networks) — confirm.
network = parnet.networks.PanRBPNet()
network

PanRBPNet(
  (stem): StemConv1D(
    (conv1d): Conv1d(4, 256, kernel_size=(7,), stride=(1,), padding=same)
    (batch_norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): ReLU()
  )
  (body): Conv1DTower(
    (tower): Sequential(
      (0): Sequential(
        (0): ResConv1DBlock(
          (conv1d): Conv1d(256, 256, kernel_size=(4,), stride=(1,), padding=same)
          (batch_norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act): ReLU()
          (dropout): Dropout1d(p=0.25, inplace=False)
        )
        (1): ResConv1DBlock(
          (conv1d): Conv1d(256, 256, kernel_size=(4,), stride=(1,), padding=same)
          (batch_norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act): ReLU()
          (dropout): Dropout1d(p=0.25, inplace=False)
        )
        (2): ResConv1DBlock(
          (conv1d): Conv1d(256, 256, kernel_size=(4,), stride=(1

In [4]:
# Inspect the last registered child module of the network (IndexEmbeddingOutput in the output below).
list(network.children())[-1]

IndexEmbeddingOutput(
  (conv1d): Conv1d(128, 223, kernel_size=(1,), stride=(1,))
)

In [5]:
from parnet.losses import MultinomialNLLLossFromLogits
from parnet.metrics import PearsonCorrCoeff, FilteredPearsonCorrCoeff

class Model(pl.LightningModule):
    """LightningModule that trains ``network`` with a multinomial NLL loss on logits.

    Batches are expected as ``(inputs, targets)``; ``targets['total']`` is the target
    tensor and ``inputs`` is passed unchanged to ``network``.

    Args:
        network: the ``nn.Module`` being trained.
        _example_input: optional input used to trace ``network`` for the TensorBoard
            graph view in :meth:`on_train_start`. When ``None``, a random
            ``{'sequence': torch.rand(2, 4, 1000)}`` is used.
        metrics: optional mapping ``name -> metric class``. Each class is
            instantiated once for 'TRAIN' and once for 'VAL'.
        optimizer: optimizer class, constructed with ``self.parameters()`` only
            (i.e. with its own default hyper-parameters).
    """

    def __init__(self, network, _example_input=None, metrics=None, optimizer=torch.optim.Adam):
        super().__init__()
        self.network = network
        # Input used to trace the graph in on_train_start (None -> random default).
        self._graph_example_input = _example_input

        # loss: a separate instance per partition
        self.loss_fn = nn.ModuleDict({
            'TRAIN': MultinomialNLLLossFromLogits(dim=-1),
            'VAL': MultinomialNLLLossFromLogits(dim=-1),
        })

        # metrics: separate instances per partition so TRAIN and VAL never share metric state
        if metrics is None:
            metrics = {}
        self.metrics = nn.ModuleDict({
            'TRAIN': nn.ModuleDict({name: metric() for name, metric in metrics.items()}),
            'VAL': nn.ModuleDict({name: metric() for name, metric in metrics.items()}),
        })

        # optimizer class; instantiated in configure_optimizers
        self.optimizer_cls = optimizer

    def forward(self, *args, **kwargs):
        """Delegate directly to the wrapped network."""
        return self.network(*args, **kwargs)

    def configure_optimizers(self):
        """Build the optimizer over all parameters using the optimizer's defaults."""
        return self.optimizer_cls(self.parameters())

    def on_train_start(self) -> None:
        """Log the network graph via the logger's experiment, if it supports ``add_graph``.

        Graph logging is skipped when there is no logger or its experiment has no
        ``add_graph`` (e.g. a CSV logger). The example input is moved to this
        module's device first: a CPU tensor cannot be traced through a GPU network.
        """
        experiment = getattr(self.logger, 'experiment', None)
        if hasattr(experiment, 'add_graph'):
            example = self._graph_example_input
            if example is None:
                example = {'sequence': torch.rand(2, 4, 1000)}
            experiment.add_graph(self.network, self._example_to_device(example))
        return super().on_train_start()

    def _example_to_device(self, example):
        """Move a tensor, or the tensor values of a dict, to ``self.device``; other values pass through."""
        if isinstance(example, torch.Tensor):
            return example.to(self.device)
        if isinstance(example, dict):
            return {
                key: value.to(self.device) if isinstance(value, torch.Tensor) else value
                for key, value in example.items()
            }
        return example

    def training_step(self, batch, batch_idx=None, **kwargs):
        """Compute, log and return the training loss; also update and log TRAIN metrics."""
        inputs, y = batch
        y = y['total']
        y_pred = self.forward(inputs)
        loss = self.compute_and_log_loss(y, y_pred, partition='TRAIN')
        self.compute_and_log_metrics(y, y_pred, partition='TRAIN')
        return loss

    def validation_step(self, batch, batch_idx=None, **kwargs):
        """Compute and log the validation loss and VAL metrics (nothing is returned)."""
        inputs, y = batch
        y = y['total']
        y_pred = self.forward(inputs)
        self.compute_and_log_loss(y, y_pred, partition='VAL')
        self.compute_and_log_metrics(y, y_pred, partition='VAL')

    def compute_and_log_loss(self, y, y_pred, partition=None):
        """Compute the loss for ``partition`` ('TRAIN' or 'VAL') and log it as '<partition>/loss'.

        Logged both per step and per epoch. Returns the loss tensor.
        """
        loss = self.loss_fn[partition](y, y_pred)
        self.log(f'{partition}/loss', loss, on_step=True, on_epoch=True, prog_bar=False)
        return loss

    def compute_and_log_metrics(self, y, y_pred, partition=None):
        """Update every metric of ``partition`` and log each as '<partition>/<name>' (step and epoch)."""
        for name, metric in self.metrics[partition].items():
            metric(y, y_pred)
            self.log(f'{partition}/{name}', metric, on_step=True, on_epoch=True, prog_bar=False)

    # Backward-compatible alias for the original (misspelled) method name.
    compute_and_log_metics = compute_and_log_metrics

# Metric *classes* (not instances) are passed; Model instantiates one of each per partition.
model = Model(network, metrics={'pcc': PearsonCorrCoeff, 'filtered_pcc': FilteredPearsonCorrCoeff})
model

Model(
  (network): PanRBPNet(
    (stem): StemConv1D(
      (conv1d): Conv1d(4, 256, kernel_size=(7,), stride=(1,), padding=same)
      (batch_norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU()
    )
    (body): Conv1DTower(
      (tower): Sequential(
        (0): Sequential(
          (0): ResConv1DBlock(
            (conv1d): Conv1d(256, 256, kernel_size=(4,), stride=(1,), padding=same)
            (batch_norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (act): ReLU()
            (dropout): Dropout1d(p=0.25, inplace=False)
          )
          (1): ResConv1DBlock(
            (conv1d): Conv1d(256, 256, kernel_size=(4,), stride=(1,), padding=same)
            (batch_norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (act): ReLU()
            (dropout): Dropout1d(p=0.25, inplace=False)
          )
          (2): ResConv1DBlock(
  

In [6]:
# The dataset yields ready-made batches (batch_size=4 is set on the dataset itself), so the
# DataLoader's automatic batching is disabled with batch_size=None.
# NOTE(review): shuffle=1_000_000 looks like a shuffle-buffer size — confirm in TFIterableDataset.
dataloader = torch.utils.data.DataLoader(parnet.data.datasets.TFIterableDataset('../example/head.20.tfrecord', batch_size=4, shuffle=1_000_000), batch_size=None)
dataloader

<torch.utils.data.dataloader.DataLoader at 0x7f3fa51b1870>

In [7]:
# Shape of one batch's 'total' target. Presumably (batch, tracks, positions): 223 matches the
# output head's channel count, and 1000 the sequence length used for graph tracing — confirm.
next(iter(dataloader))[1]['total'].shape

torch.Size([4, 223, 1000])

In [8]:
import datetime
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, RichProgressBar, LearningRateMonitor

# Every run gets its own timestamped log directory.
root_log_dir = f'logs/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}'
loggers = [
    # NOTE: Lightning's `log_graph=True` relies on `model.example_input_array`, which Model does
    # not set; the graph is logged explicitly in Model.on_train_start instead.
    pl_loggers.TensorBoardLogger(root_log_dir + '/tensorboard', name='', version='', log_graph=True),
    # pl_loggers.CSVLogger(root_log_dir + '/tensorboard', name='', version=''),
]

# Checkpoints are written under the run directory; save_last=True also keeps a "last" checkpoint.
checkpoint_callback = ModelCheckpoint(dirpath=f'{root_log_dir}/checkpoints', every_n_epochs=1, save_last=True)

# Model logs validation loss as 'VAL/loss' (upper-case partition prefix). The monitor key must
# match exactly: with the default strict=True, EarlyStopping raises if the metric is missing.
# NOTE: not passed to the Trainer below, so early stopping is currently disabled.
early_stop_callback = EarlyStopping(monitor="VAL/loss", min_delta=0.00, patience=3, verbose=False, mode="min")

# NOTE: also not passed to the Trainer; the default progress bar is used.
bar = RichProgressBar()

trainer = Trainer(
    default_root_dir=root_log_dir,
    max_epochs=10,
    logger=loggers,
    callbacks=[checkpoint_callback, LearningRateMonitor('step', log_momentum=True)],
    log_every_n_steps=1,
)
# NOTE(review): the same dataloader serves as train and validation data, so VAL metrics
# are not a held-out estimate.
trainer.fit(model=model, train_dataloaders=dataloader, val_dataloaders=dataloader)

# Pickles the whole nn.Module (not a state_dict): reloading requires the parnet classes to be
# importable, and on torch>=2.6 `torch.load(..., weights_only=False)`.
torch.save(model.network, 'test.pt')

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(

  | Name    | Type       | Params
---------------------------------------
0 | network | PanRBPNet  | 3.3 M 
1 | loss_fn | ModuleDict | 0     
2 | metrics | ModuleDict | 0     
---------------------------------------
3.3 M     Trainable params
0         Non-trainable params
3.3 M     Total params
13.352    Total estimated model params size (MB)
  rank_zero_warn(


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(
  return F.conv1d(input, weight, bias, self.stride,


                                                                           

  rank_zero_warn(


Epoch 9: : 10it [00:07,  1.35it/s, loss=1.17, v_num=]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: : 10it [00:07,  1.31it/s, loss=1.17, v_num=]


In [9]:
# Reload the whole pickled network written by `torch.save(model.network, 'test.pt')`.
# NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which rejects pickled
# modules; on those versions pass weights_only=False (only for files you trust).
net = torch.load('test.pt')
net

PanRBPNet(
  (stem): StemConv1D(
    (conv1d): Conv1d(4, 256, kernel_size=(7,), stride=(1,), padding=same)
    (batch_norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): ReLU()
  )
  (body): Conv1DTower(
    (tower): Sequential(
      (0): Sequential(
        (0): ResConv1DBlock(
          (conv1d): Conv1d(256, 256, kernel_size=(4,), stride=(1,), padding=same)
          (batch_norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act): ReLU()
          (dropout): Dropout1d(p=0.25, inplace=False)
        )
        (1): ResConv1DBlock(
          (conv1d): Conv1d(256, 256, kernel_size=(4,), stride=(1,), padding=same)
          (batch_norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act): ReLU()
          (dropout): Dropout1d(p=0.25, inplace=False)
        )
        (2): ResConv1DBlock(
          (conv1d): Conv1d(256, 256, kernel_size=(4,), stride=(1

In [10]:
import torch
from parnet.losses import MultinomialNLLLossFromLogits, multinomial_nll_loss

# Sanity check: the module wrapper and the functional form must give the same loss.
torch.manual_seed(0)  # make the random inputs (and the printed values) reproducible
loss_module = MultinomialNLLLossFromLogits()

y = torch.randint(0, 10, size=(2, 7, 101), dtype=torch.float32)  # integer-valued counts
y_pred = torch.rand(2, 7, 101, dtype=torch.float32)  # treated as logits
module_loss = loss_module(y, y_pred)
functional_loss = multinomial_nll_loss(y, y_pred)
print(module_loss)
print(functional_loss)
torch.testing.assert_close(module_loss, functional_loss)

tensor(276.3673)
tensor(276.3673)


In [11]:
from parnet.data.datasets import TFIterableDataset, MaskedTFIterableDataset

In [12]:
# Masked variant of the TFRecord dataset, with mask file(s) given by path (here experiment-mask.K562.pt).
# NOTE(review): how the mask is applied is not visible here — confirm in parnet.data.datasets.
# This displays the whole 'total' tensor of the first batch; `.shape` would be more compact.
d = MaskedTFIterableDataset(mask_filepaths=['../example/experiment.masks/experiment-mask.K562.pt'], filepath='../example/data.matrix/head.tfrecord', batch_size=4, shuffle=1_000_000)
next(iter(d))[1]['total']

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 

In [2]:
import gin
# NOTE(review): duplicate of the earlier gin setup cell (re-executed; per its output, config.gin
# now also imports pytorch_lightning.loggers). Consider keeping only one copy.
# Interactive mode allows the `train` configurable to be re-registered on re-runs.
gin.enter_interactive_mode()

# No-op stub so config.gin can be parsed — presumably it binds parameters of `train`.
@gin.configurable()
def train(**kwargs):
    """No-op placeholder registered with gin under the name ``train``."""
    pass

gin.parse_config_file('../configs/config.gin')

  from .autonotebook import tqdm as notebook_tqdm


ParsedConfigFileIncludesAndImports(filename='../configs/config.gin', imports=['torch', 'pytorch_lightning', 'pytorch_lightning.loggers', 'parnet.networks', 'parnet.layers', 'parnet.metrics'], includes=[])

In [6]:
# Re-instantiates the network (repeats an earlier cell). Rebinding `network` does not change
# any Model already built from a previous instance (e.g. `model.network`).
network = parnet.networks.PanRBPNet()
network

PanRBPNet(
  (stem): StemConv1D(
    (conv1d): Conv1d(4, 256, kernel_size=(7,), stride=(1,), padding=same)
    (batch_norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): ReLU()
  )
  (body): Conv1DTower(
    (tower): Sequential(
      (0): Sequential(
        (0): ResConv1DBlock(
          (conv1d): Conv1d(256, 256, kernel_size=(4,), stride=(1,), padding=same)
          (batch_norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act): ReLU()
          (dropout): Dropout1d(p=0.25, inplace=False)
        )
        (1): ResConv1DBlock(
          (conv1d): Conv1d(256, 256, kernel_size=(4,), stride=(1,), padding=same)
          (batch_norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act): ReLU()
          (dropout): Dropout1d(p=0.25, inplace=False)
        )
        (2): ResConv1DBlock(
          (conv1d): Conv1d(256, 256, kernel_size=(4,), stride=(1

In [11]:
# Third-from-last child module; the output below shows it is a Conv1DTower.
# NOTE(review): the PanRBPNet repr names a `body` Conv1DTower — if this is the same module,
# `network.body` would be clearer than a positional index.
list(network.children())[-3]

Conv1DTower(
  (tower): Sequential(
    (0): Sequential(
      (0): ResConv1DBlock(
        (conv1d): Conv1d(256, 256, kernel_size=(4,), stride=(1,), padding=same)
        (batch_norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): ReLU()
        (dropout): Dropout1d(p=0.25, inplace=False)
      )
      (1): ResConv1DBlock(
        (conv1d): Conv1d(256, 256, kernel_size=(4,), stride=(1,), padding=same)
        (batch_norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): ReLU()
        (dropout): Dropout1d(p=0.25, inplace=False)
      )
      (2): ResConv1DBlock(
        (conv1d): Conv1d(256, 256, kernel_size=(4,), stride=(1,), padding=same)
        (batch_norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): ReLU()
        (dropout): Dropout1d(p=0.25, inplace=False)
      )
      (3): ResConv1DBlock(
        (conv1d): Conv1d(256, 256, kerne