In [None]:
import torch

In [None]:
import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers

In [None]:
from torchrbpnet import layers

# Quick sanity check of the stem layer: build one and read the output channel
# count of its first child module.
# NOTE(review): the meaning of the positional argument `4` is not visible here —
# presumably the number of input channels (one-hot nucleotides); confirm.
stem = layers.StemConv1D(4)
next(stem.children()).out_channels

In [None]:
import os
import torch
# Show the working directory: the TFRecord path used below is relative to it.
print(os.getcwd())

from torchrbpnet.data import tfrecord_to_dataloader, dummy_dataloader
from torchrbpnet.data import datasets

# TFIterableDataset is given batch_size=2, so presumably it yields ready-made batches;
# DataLoader(batch_size=None) disables automatic batching so they pass through unchanged.
# NOTE(review): `shuffle=100` is presumably a shuffle-buffer size — confirm.
tfrecord_dataset = datasets.TFIterableDataset(
    '../example/data.matrix/head.tfrecord',
    batch_size=2,
    shuffle=100,
)
dataloader = torch.utils.data.DataLoader(tfrecord_dataset, batch_size=None)

In [None]:
# %%
import torch
import torch.nn as nn
import pytorch_lightning as pl

import torchmetrics
from torchrbpnet.losses import MultinomialNLLLossFromLogits
from torchrbpnet.metrics import MultinomialNLLFromLogits #BatchedPCC
from torchrbpnet.networks import MultiRBPNet

class Model(pl.LightningModule):
    """LightningModule that trains ``network`` with ``MultinomialNLLLossFromLogits``.

    Args:
        network: ``nn.Module`` called as ``network(inputs)``; its output is passed to
            the loss as ``y_pred``.
        metrics: optional mapping ``name -> metric``. Each metric must be a
            torchmetrics-style ``Metric`` (callable as ``metric(y_pred, y)``, with
            ``compute()``/``reset()``), because the metric object itself is handed to
            ``self.log``. Metrics are logged as ``'{partition}/{name}'``. Train and
            validation each get an independent deep copy so their states never mix.
        optimizer: optimizer class, called as ``optimizer(self.parameters())``
            (i.e. with its default hyper-parameters) in ``configure_optimizers``.

    Batches are expected as ``(inputs, targets)`` with ``targets['total']`` used as the
    target for loss and metrics.
    """

    def __init__(self, network, metrics=None, optimizer=torch.optim.Adam):
        import copy

        super().__init__()
        self.network = network
        self.loss_fn = MultinomialNLLLossFromLogits()

        # metrics: `self.metrics` keeps the user-supplied objects (original attribute;
        # also used for any partition other than 'train'/'val'). Train and val get their
        # own deep copies so validation updates cannot leak into the training state.
        self.metrics = nn.ModuleDict(metrics if metrics is not None else {})
        self.train_metrics = nn.ModuleDict({name: copy.deepcopy(m) for name, m in self.metrics.items()})
        self.val_metrics = nn.ModuleDict({name: copy.deepcopy(m) for name, m in self.metrics.items()})

        # optimizer class; instantiated in configure_optimizers
        self.optimizer_cls = optimizer

    def forward(self, *args, **kwargs):
        """Delegate directly to the wrapped network."""
        return self.network(*args, **kwargs)

    def configure_optimizers(self):
        """Build the optimizer over all parameters with its default settings."""
        optimizer = self.optimizer_cls(self.parameters())
        return optimizer

    def _shared_step(self, batch, partition):
        """Run the network on one batch, log loss and metrics, and return the loss."""
        inputs, y = batch
        y = y['total']
        y_pred = self(inputs)
        # NOTE(review): argument order (target, prediction) and dim=-2 kept from the
        # original call; confirm against MultinomialNLLLossFromLogits' signature.
        loss = self.loss_fn(y, y_pred, dim=-2)
        # For 'val' (on_step=False) the logged key is exactly 'val/loss', which is what
        # an EarlyStopping(monitor="val/loss") callback needs; for 'train' Lightning
        # logs 'train/loss_step' and 'train/loss_epoch'.
        self.log(f'{partition}/loss', loss, on_step=(partition == 'train'), on_epoch=True, prog_bar=True)
        self.compute_and_log_metrics(y_pred, y, partition=partition)
        return loss

    def training_step(self, batch, batch_idx, **kwargs):
        """Training step: returns the loss Lightning backpropagates."""
        return self._shared_step(batch, partition='train')

    def validation_step(self, batch, batch_idx):
        """Validation step: logs 'val/loss' and validation metrics."""
        self._shared_step(batch, partition='val')

    def _metrics_for(self, partition):
        """Return the metric collection that owns the accumulated state for ``partition``."""
        if partition == 'train':
            return self.train_metrics
        if partition == 'val':
            return self.val_metrics
        return self.metrics

    def compute_and_log_metrics(self, y_pred, y, partition=None):
        """Update the partition's metrics with one batch and log them.

        The metric objects themselves are passed to ``self.log``: Lightning logs the
        per-batch value on step (train only) and, at epoch end, calls ``compute()``
        and ``reset()``. This yields a true epoch-level value instead of an average of
        running values, so no manual epoch-end reset is needed (a manual reset in an
        epoch-end hook would run before Lightning's epoch aggregation).
        """
        on_step = partition == 'train'
        for name, metric in self._metrics_for(partition).items():
            metric(y_pred, y)  # update accumulated state and cache the batch value
            self.log(f'{partition}/{name}', metric, on_step=on_step, on_epoch=True, prog_bar=False)

    # Backward-compatible alias for the original (misspelled) method name.
    compute_and_log_metics = compute_and_log_metrics

    def _reset_metrics(self):
        """Reset the accumulated state of every metric copy (template, train and val)."""
        for collection in (self.metrics, self.train_metrics, self.val_metrics):
            for metric in collection.values():
                metric.reset()

# Wrap MultiRBPNet in the LightningModule defined above (default Adam optimizer, no metrics).
# NOTE(review): n_tasks=223 is a magic number — presumably the number of target tracks in
# the TFRecord data; confirm it matches before training.
model = Model(network=MultiRBPNet(n_tasks=223))

In [None]:
# Display the module tree via the nn.Module repr.
model

In [None]:
# Flatten every submodule (model itself included, in registration order) into a list.
mods = list(model.modules())

In [None]:
# NOTE(review): index -5 is positional and depends on module registration order, so it is
# fragile; presumably it targets a specific layer near the network head — confirm.
mods[-5].weight.shape

In [None]:
import datetime
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, RichProgressBar, LearningRateMonitor

# Output layout: logs/<timestamp>/{tensorboard,checkpoints}
root_log_dir = f'logs/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}'
loggers = [
    # NOTE: log_graph=True only logs a graph if the model sets `example_input_array`.
    pl_loggers.TensorBoardLogger(f'{root_log_dir}/tensorboard', name='', version='', log_graph=True),
    # pl_loggers.CSVLogger(root_log_dir+'/tensorboard', name='', version=''),
]

MAX_EPOCHS = 2

checkpoint_callback = ModelCheckpoint(dirpath=f'{root_log_dir}/checkpoints', every_n_epochs=1, save_last=True)

# Requires the model to log 'val/loss' during validation. With patience=3 it cannot fire
# within MAX_EPOCHS=2 validation rounds; raise MAX_EPOCHS for real runs.
early_stop_callback = EarlyStopping(monitor="val/loss", min_delta=0.00, patience=3, verbose=False, mode="min")

bar = RichProgressBar()
lr_monitor = LearningRateMonitor('step', log_momentum=True)

trainer = Trainer(
    default_root_dir=root_log_dir,
    max_epochs=MAX_EPOCHS,
    logger=loggers,
    # `bar` was previously created but never registered, so it had no effect.
    callbacks=[checkpoint_callback, early_stop_callback, lr_monitor, bar],
)
# NOTE(review): the same dataloader is used for training and validation, so 'val/*'
# values measure fit on the training data, not generalization — use a held-out split.
trainer.fit(model=model, train_dataloaders=dataloader, val_dataloaders=dataloader)
# Saves the whole pickled module (loading it requires torchrbpnet to be importable);
# consider saving model.network.state_dict() instead.
torch.save(model.network, 'test.pt')

In [None]:
# Inspect the optimizer Lightning built from configure_optimizers (Adam, the Model default).
# NOTE(review): called after fit; whether the trainer still exposes optimizers here may
# depend on the Lightning version — confirm.
model.optimizers()