In [None]:
#default_exp lit_model

# LitModel
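> PyTorch Lightning wrappers (`LitModel`, `BinLitModel`), optimizer / LR-schedule helpers, and a `Trainer` factory.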

In [None]:
%load_ext autoreload
%autoreload 2

## Imports

In [None]:
#export
from loguru import logger
from pytorch_lightning.core.lightning import LightningModule
import torch
from datetime import datetime, timedelta
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
import os
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from lit_classifier.loss import FocalLoss, BinaryFocalLoss
import os.path as osp
from torch.optim.lr_scheduler import LambdaLR

In [None]:
# isinstance(torch.optim.Adam(model.parameters()), torch.optim.Optimizer)

## Get optim cfg

In [None]:

#export
def get_optim_cfg(epochs, steps_per_ep, lr=1e-3, init_lr=0.5, min_lr=0.2, interval='step', optim='Adam'):
    """Bundle optimizer / LR-schedule settings into the dict `LitModel` consumes.

    Params:
        epochs: number of training epochs.
        steps_per_ep: optimizer steps per epoch (e.g. `len(train_loader)`).
        lr: base learning rate handed to the optimizer.
        init_lr: LR multiplier used by `get_linear_scheduler` at the first warm-up step.
        min_lr: LR multiplier `get_linear_scheduler` decays to at the last step.
        interval: 'step' or 'epoch' granularity of the decay.
        optim: 'Adam', 'AdamW', or an already-built optimizer instance.
    Returns:
        dict with keys lr, init_lr, min_lr, steps (= epochs * steps_per_ep),
        epochs, interval, optim.
    """
    total_steps = epochs * steps_per_ep
    return {
        'lr': lr,
        'init_lr': init_lr,
        'min_lr': min_lr,
        'steps': total_steps,
        'epochs': epochs,
        'interval': interval,
        'optim': optim,
    }


## Lit

## Optim

In [None]:
#export

def get_linear_scheduler(optimizer, optim_cfg):
    """Build a linear warm-up / linear-decay LR schedule in Lightning's dict format.

    The multiplier applied by `LambdaLR` to the optimizer's base LR:
      * rises linearly from `init_lr` to 1 over the first `warmup_ratio`
        (default 0.15) of `optim_cfg['steps']`,
      * then falls linearly from 1 to `min_lr` at `optim_cfg['steps']` and
        stays at `min_lr` if training runs longer.
    With `interval == 'epoch'` the decay is evaluated at the first step of the
    current epoch, so after warm-up the LR is constant within an epoch.

    Params:
        optimizer: torch optimizer to schedule.
        optim_cfg: dict as returned by `get_optim_cfg`; reads 'steps', 'init_lr',
            'min_lr', 'epochs', 'interval' and the optional 'warmup_ratio'.
    Returns:
        dict(scheduler=LambdaLR, interval='step', frequency=1).
    """
    def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, init_lr, min_lr, num_epochs, interval):
        # Length of the decay phase; floored at 1 so a zero-step config cannot divide by zero.
        decay_steps = max(num_training_steps - num_warmup_steps, 1)

        def lr_lambda(current_step: int):
            if current_step < num_warmup_steps:
                return (1 - init_lr) * (current_step / num_warmup_steps) + init_lr
            if interval == 'epoch' and num_epochs:
                steps_per_ep = num_training_steps / num_epochs
                if steps_per_ep > 0:
                    # Snap to the first step of the current epoch.
                    current_step = steps_per_ep * (current_step // steps_per_ep)
            # Clamp progress to [0, 1]: epoch snapping can move the step back into
            # the warm-up window (progress < 0 -> multiplier > 1), and running past
            # `num_training_steps` would push the LR below min_lr, even negative.
            progress = min(max((current_step - num_warmup_steps) / decay_steps, 0.0), 1.0)
            return min_lr + (1 - min_lr) * (1 - progress)

        return LambdaLR(optimizer, lr_lambda, -1)

    total_steps = optim_cfg["steps"]
    scheduler = {
        "scheduler": get_linear_schedule_with_warmup(
            optimizer,
            total_steps * optim_cfg.get("warmup_ratio", 0.15),
            total_steps,
            optim_cfg["init_lr"],
            optim_cfg["min_lr"],
            optim_cfg["epochs"],
            optim_cfg['interval'],
        ),
        # The lambda is evaluated per optimizer step (epoch granularity is emulated
        # inside `lr_lambda`), so Lightning must step the scheduler every batch.
        "interval": 'step',
        "frequency": 1,
    }
    return scheduler

In [None]:
#export
class LitModel(LightningModule):
    """LightningModule wrapping a multi-class classifier.

    Params:
        model: nn.Module mapping inputs to per-class logits (class axis = dim 1).
        optim_cfg: dict from `get_optim_cfg`; required by `configure_optimizers`.
        loss: loss module called as `loss(logits, targets)`; defaults to a fresh
            `FocalLoss()` per instance.
        scheduler_fn: function to create the scheduler, called as
            `scheduler_fn(optimizer, optim_cfg)`. It may return a scheduler object
            or a Lightning scheduler dict (as `get_linear_scheduler` does).
            Pass None to train without a scheduler.
    """
    def __init__(self, model, optim_cfg=None, loss=None, scheduler_fn=get_linear_scheduler):
        super().__init__()
        self.model = model
        # Built per instance: a `loss=FocalLoss()` default would be created once at
        # definition time and shared by every LitModel.
        self.loss_fn = FocalLoss() if loss is None else loss
        self.optim_cfg = optim_cfg
        # NOTE(review): 10e-3 == 1e-2; configure_optimizers uses optim_cfg["lr"], not this.
        self.lr = 10e-3
        self.scheduler_fn = scheduler_fn

    def configure_optimizers(self):
        """
            Setup optimizer and scheduler from `self.optim_cfg`.

            Returns `([optimizer], [scheduler_dict])`, `[optimizer]` when
            `scheduler_fn` is None, or None (after a warning) when no optim cfg is set.
        """
        if self.optim_cfg is None:
            logger.warning('Please add optim cfg and re-init this object')
            return
        optim = self.optim_cfg['optim']
        if optim == 'Adam':
            optimizer = torch.optim.Adam(self.parameters(), lr=self.optim_cfg["lr"])
        elif optim == 'AdamW':
            optimizer = torch.optim.AdamW(self.parameters(), lr=self.optim_cfg["lr"],
                                          betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, amsgrad=False)
        else:
            # Pre-built optimizer; presumably constructed over this module's parameters.
            assert isinstance(optim, torch.optim.Optimizer), \
                f"optim_cfg['optim'] must be 'Adam', 'AdamW' or a torch.optim.Optimizer, got {optim!r}"
            optimizer = optim

        if self.scheduler_fn is None:
            return [optimizer]
        # Use the config stored on the instance, not a notebook-level global.
        scheduler = self.scheduler_fn(optimizer, self.optim_cfg)
        # `get_linear_scheduler` already returns a Lightning scheduler dict; only wrap
        # bare scheduler objects so Lightning never receives a dict nested in a dict.
        if not isinstance(scheduler, dict):
            scheduler = dict(
                scheduler=scheduler,
                interval="step",
                frequency=1,
            )
        return [optimizer], [scheduler]

    def forward(self, x):
        return self.model(x)

    def predict_step(self, batch, batch_idx):
        # Assumes the predict loader yields bare inputs (no labels) — TODO confirm.
        x = batch
        logits = self(x)
        # Independent per-class sigmoid scores (not softmax-normalized).
        scores = logits.sigmoid()
        return scores

    def validation_step(self, batch, batch_idx):
        x, y = batch[:2]
        logits = self(x)
        loss = self.loss_fn(logits, y)

        # softmax is monotonic, so argmax over logits gives the same class.
        preds = logits.argmax(1)
        accs = (y == preds).float().mean()

        self.log("val_loss", loss, rank_zero_only=True, prog_bar=True,
                    on_step=False, on_epoch=True)
        self.log("val_acc", accs, rank_zero_only=True, prog_bar=True,
                    on_step=False, on_epoch=True)

        return loss

    def training_step(self, batch, batch_idx):
        x, y = batch[:2]
        logits = self(x)
        loss = self.loss_fn(logits, y)

        preds = logits.argmax(1)
        accs = (y == preds).float().mean()

        self.log("training_loss", loss, prog_bar=True, rank_zero_only=True, on_epoch=True)
        self.log("training_accuracy", accs, prog_bar=True, rank_zero_only=True, on_epoch=True)
        return loss


class BinLitModel(LitModel):
    """LitModel variant for binary classification.

    Logits and targets are flattened to 1-D (so one logit per sample is expected);
    a sample counts as positive when sigmoid(logit) > 0.5.
    """

    def _binary_loss_and_acc(self, batch):
        # Only the first two batch entries (inputs, targets) are used.
        inputs, targets = batch[:2]
        flat_logits = self(inputs).reshape(-1)
        flat_targets = targets.reshape(flat_logits.shape)
        loss = self.loss_fn(flat_logits, flat_targets)
        predicted_pos = flat_logits.sigmoid() > 0.5
        acc = (flat_targets == predicted_pos).float().mean()
        return loss, acc

    def validation_step(self, batch, batch_idx):
        loss, acc = self._binary_loss_and_acc(batch)
        epoch_only = dict(rank_zero_only=True, on_step=False, on_epoch=True)
        self.log("val_loss", loss, **epoch_only)
        self.log("val_acc", acc, **epoch_only)
        return loss

    def training_step(self, batch, batch_idx):
        loss, acc = self._binary_loss_and_acc(batch)
        self.log("training_loss", loss, prog_bar=True, rank_zero_only=True)
        self.log("training_accuracy", acc, prog_bar=True, rank_zero_only=True)
        return loss
    
# from lit_classifier.lit_model import LitModel
# PLitModel = persistent_class(LitModel)

In [None]:
#export
def load_lit_state_dict(ckpt_path, map_location=None):
    """Load a Lightning checkpoint's weights in a form the bare network accepts.

    `LitModel` stores the wrapped network as `self.model`, so its checkpoint keys
    look like `model.<name>`; this strips that leading `model.` prefix.

    Params:
        ckpt_path: path to a checkpoint containing a 'state_dict' entry.
        map_location: forwarded to `torch.load` (e.g. 'cpu' to open a GPU
            checkpoint on a CPU-only machine). None keeps `torch.load`'s default.
    Returns:
        dict of tensors keyed by name with the leading prefix removed. Keys that
        do not start with the prefix (anything registered outside `self.model`)
        are returned unchanged.
    """
    prefix = 'model.'
    state_dict = torch.load(ckpt_path, map_location=map_location)['state_dict']
    # Strip only a *leading* prefix: `str.replace` also rewrote inner occurrences
    # (e.g. 'backbone.model.fc' -> 'backbone.fc', 'submodel.w' -> 'subw').
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }

In [None]:
#hide
import timm, matplotlib.pyplot as plt


# Sanity-check the LR schedule: plot the learning rate `get_linear_scheduler`
# produces for 10 epochs x 1000 steps.
optim_cfg = get_optim_cfg(10, 1000, interval='step', lr=0.01)
model = timm.create_model('resnet18')
lit_model = LitModel(model, optim_cfg)

demo_optimizer = torch.optim.SGD(lit_model.parameters(), lr=optim_cfg['lr'])
demo_scheduler = get_linear_scheduler(demo_optimizer, optim_cfg)['scheduler']
lrs = []
for _ in range(optim_cfg['steps']):
    lrs.append(demo_scheduler.get_last_lr()[0])
    # No gradients exist, so parameters are untouched; stepping the optimizer first
    # keeps PyTorch's scheduler/optimizer call order.
    demo_optimizer.step()
    demo_scheduler.step()

fig, ax = plt.subplots()
ax.plot(lrs)
ax.set(title='Linear warm-up / decay LR schedule', xlabel='step', ylabel='learning rate');

## Get trainer

In [None]:
#export
def get_trainer(exp_name, gpus=1, max_epochs=None, distributed=False,
        monitor=None, save_every_n_epochs=1, save_top_k=1, use_version=True,
    trainer_kwargs=None, optim_cfg=None, time_offset=timedelta(hours=7)):
    """Create a Lightning `Trainer` with checkpointing, LR monitoring and TensorBoard logs.

    Logs go to `lightning_logs/<exp_name>/<version>/` (or `lightning_logs/<exp_name>/`
    when `use_version=False`). `version` is `<NN>_<Mon><DD>_<HH>_<MM>`, where NN is
    the number of entries already in `lightning_logs/<exp_name>`.

    Params:
        exp_name: experiment name (sub-directory of `lightning_logs`).
        gpus: forwarded to `Trainer(gpus=...)`.
        max_epochs: number of epochs; if None it is taken from `optim_cfg`
            ('max_epoch' if present, otherwise 'epochs' as set by `get_optim_cfg`).
        distributed: use the 'ddp' strategy instead of 'dp'.
        monitor: dict(metric=..., mode=...) for checkpoint selection;
            defaults to dict(metric="val_acc", mode="max").
        save_every_n_epochs, save_top_k: forwarded to `ModelCheckpoint`.
        use_version: nest logs in a timestamped version directory.
        trainer_kwargs: extra keyword arguments for `Trainer`.
        optim_cfg: dict from `get_optim_cfg`; only used to infer `max_epochs`.
        time_offset: added to `datetime.now()` for the version timestamp
            (default +7h, as before).
    Returns:
        The configured `Trainer`.
    """
    # None defaults avoid sharing one mutable dict across calls.
    if monitor is None:
        monitor = dict(metric="val_acc", mode="max")
    if trainer_kwargs is None:
        trainer_kwargs = dict()
    if max_epochs is None:
        assert optim_cfg is not None, 'optim_cfg and max_epoch cannot be both None'
        # `get_optim_cfg` stores the count under 'epochs'; reading only 'max_epoch'
        # raised KeyError for every config it produced.
        max_epochs = optim_cfg['max_epoch'] if 'max_epoch' in optim_cfg else optim_cfg['epochs']

    now = datetime.now() + time_offset

    root_log_dir = osp.join("lightning_logs", exp_name)
    cur_num_exps = len(os.listdir(root_log_dir)) if osp.exists(root_log_dir) else 0
    version = now.strftime(f"{cur_num_exps:02d}_%b%d_%H_%M")
    if use_version:
        root_log_dir = osp.join(root_log_dir, version)
        logger.info('Root log directory: {}'.format(root_log_dir))
    # Checkpoint name template: epoch number plus the monitored metric (2 decimals).
    filename = "{epoch}-{" + monitor["metric"] + ":.2f}"

    callback_ckpt = ModelCheckpoint(
        dirpath=osp.join(root_log_dir, "ckpts"),
        monitor=monitor['metric'], mode=monitor['mode'],
        filename=filename,
        save_last=True,
        every_n_epochs=save_every_n_epochs,
        save_top_k=save_top_k,
    )

    callback_tqdm = TQDMProgressBar(refresh_rate=5)
    callback_lr_monitor = LearningRateMonitor(logging_interval="step")
    tb_logger = TensorBoardLogger(
        osp.join(root_log_dir, "tb_logs"), version=version
    )

    trainer = Trainer(
        gpus=gpus,
        max_epochs=max_epochs,
        strategy="dp" if not distributed else "ddp",
        callbacks=[callback_ckpt, callback_tqdm, callback_lr_monitor],
        logger=tb_logger, **trainer_kwargs,
    )
    return trainer

In [None]:
!nbdev_build_lib

zsh:1: command not found: nbdev_build_lib


# MNIST example

In [None]:
#hide
import os

import torch
from pytorch_lightning import LightningModule, Trainer
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import MNIST

BATCH_SIZE = 256  # training batch size (the template this came from used 64 without a GPU); the test loader uses 2 * BATCH_SIZE

In [None]:
from timm.data.transforms_factory import transforms_imagenet_train, transforms_imagenet_eval

In [None]:
import pytorch_lightning as pl
def to_rgb(x):
    """Return `x.convert('RGB')`.

    Used as the first transform so single-channel MNIST images get three
    channels (assumes `x` is a PIL image or anything with a `.convert(mode)` method).
    """
    rgb_image = x.convert('RGB')
    return rgb_image
# Training pipeline: convert to RGB, then append the steps of timm's
# `transforms_imagenet_train` pipeline built for 32px images.
train_transform = transforms.Compose(
    [transforms.Lambda(to_rgb)] + list(transforms_imagenet_train(32).transforms)
)
class ClassifierDataModule(pl.LightningDataModule):
    """MNIST LightningDataModule producing 3-channel 32px images.

    Params:
        data_dir: directory MNIST is downloaded to / read from.
        batch_size: batch size for every dataloader.
        train_tfms: transform for the training subset; defaults to the
            notebook-level `train_transform` (with augmentation).
        eval_tfms: transform for val/test/predict; defaults to RGB conversion plus
            timm's `transforms_imagenet_eval(32)`, so evaluation is not augmented.
        val_size: number of training images held out for validation.
        seed: seed of the train/val split, so every `setup` call sees the same split.
    """
    def __init__(self, data_dir: str = "./", batch_size: int = 32,
                 train_tfms=None, eval_tfms=None, val_size: int = 5000, seed: int = 42):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.train_tfms = train_transform if train_tfms is None else train_tfms
        if eval_tfms is None:
            eval_tfms = transforms.Compose(
                [transforms.Lambda(to_rgb), *transforms_imagenet_eval(32).transforms]
            )
        self.eval_tfms = eval_tfms
        self.val_size = val_size
        self.seed = seed

    def prepare_data(self):
        # Download once up front; `setup` then only reads from disk.
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Test/predict previously used the *training* transform (random augmentation
        # at evaluation time); use the deterministic eval transform instead.
        self.mnist_test = MNIST(self.data_dir, train=False, transform=self.eval_tfms)
        self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.eval_tfms)

        # Two views of the same training files: the validation subset gets the eval
        # transform while the training subset keeps augmentation.
        train_view = MNIST(self.data_dir, train=True, transform=self.train_tfms)
        eval_view = MNIST(self.data_dir, train=True, transform=self.eval_tfms)
        n_train = len(train_view) - self.val_size
        # Seeded so fit/validate/test `setup` calls do not re-draw (and leak) the split.
        generator = torch.Generator().manual_seed(self.seed)
        perm = torch.randperm(len(train_view), generator=generator).tolist()
        self.mnist_train = torch.utils.data.Subset(train_view, perm[:n_train])
        self.mnist_val = torch.utils.data.Subset(eval_view, perm[n_train:])

    def train_dataloader(self):
        # Reshuffle every epoch; the original loader always iterated in the same order.
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=self.batch_size)

# Usage: mnist = ClassifierDataModule(batch_size=BATCH_SIZE); trainer.fit(lit_model, datamodule=mnist)

In [None]:
from lit_classifier.persistance import persistent_class

In [None]:
# Init DataLoader from MNIST Dataset
from timm.data.transforms_factory import transforms_imagenet_train, transforms_imagenet_eval

# Deterministic preprocessing (RGB conversion + timm eval transform at 32px),
# applied to both splits here, i.e. no training augmentation.
eval_transform = transforms.Compose(
    [
        transforms.Lambda(to_rgb),
        *transforms_imagenet_eval(32).transforms,
    ]
)
PMNIST = persistent_class(MNIST)
train_ds = PMNIST(root='./', train=True, download=True, transform=eval_transform)
test_ds = PMNIST(root='./', train=False, download=True, transform=eval_transform)

# Shuffle the training set every epoch; evaluation order does not matter.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE*2, num_workers=4)

In [None]:
# next(iter(train_loader))

In [None]:
from lit_classifier import base_model
from lit_classifier.lit_model import *
from lit_classifier.loss import FocalLoss
import timm
# NOTE(review): `from lit_classifier.lit_model import *` above shadows this notebook's
# definitions with the last *built* library (the earlier `!nbdev_build_lib` failed),
# so the code exercised here may be stale — rebuild before trusting results.
# Lighter alternative: base_model.model_factory('mobilenetv2_035', 10, pretrained=False)
model = timm.create_model('resnet18', pretrained=True, num_classes=10)
optim_cfg = get_optim_cfg(10, len(train_loader), lr=1e-4, optim='AdamW')
lit_model = LitModel(model, optim_cfg, loss=FocalLoss())
# get_trainer asserts when neither max_epochs nor optim_cfg is given; pass the epoch
# count explicitly so training length matches the LR schedule's length.
trainer = get_trainer('test', gpus=[0], distributed=False, max_epochs=optim_cfg['epochs'])

2022-07-30 15:24:08.150 | INFO     | lit_classifier.lit_model:get_trainer:176 - Root log directory: lightning_logs/test/01_Jul30_22_24
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
# Display the configured Trainer (its repr).
trainer

In [None]:
# NOTE(review): this cell repeats the bare `trainer` display cell just before it.
trainer

<pytorch_lightning.trainer.trainer.Trainer at 0x7f6ec613d940>

In [None]:
# NOTE(review): `test_loader` is passed as the validation loader, so "val_acc" (the
# default ModelCheckpoint monitor in get_trainer) and best-checkpoint selection are
# computed on the MNIST test split; accuracy reported on that split will be optimistic.
trainer.fit(lit_model, train_loader, test_loader)

# Build the library

Export the `#export` cells of this notebook to `lit_classifier/lit_model.py` with nbdev.

In [None]:
!nbdev_build_lib