In [1]:
%cd /data/gunsbrother/prjs/ltvu/ours/
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import lightning as L
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import (
    EarlyStopping,
    ModelCheckpoint,
    DeviceStatsMonitor,
    LearningRateMonitor,
    Callback,
)
from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger

# Fix the global seed up front so repeated runs of this notebook are comparable.
# NOTE(review): `workers=True` presumably also seeds dataloader workers -- confirm in the Lightning docs.
seed_everything(42, workers=True)


/data/gunsbrother/prjs/ltvu/ours


Seed set to 42


42

In [None]:
class LitModule(L.LightningModule):
    """Skeleton LightningModule whose hooks are all placeholders.

    The network (``self.model``), the losses and the optimizer are stubs
    (``None`` / ``0``); fill them in before launching a real run.
    """

    def __init__(self):
        super().__init__()
        self.model = None  # TODO: build the actual network here

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch (currently a constant stub)."""
        loss = 0  # TODO: compute the real loss from `batch`
        return loss

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` and ``val_acc`` for one batch and return the loss."""
        loss = 0  # TODO: compute the real loss from `batch`
        val_acc = 0  # TODO: compute the real accuracy from `batch`
        # sync_dist=True reduces the logged values across processes, so that
        # callbacks monitoring 'val_acc' / 'val_loss' see one consistent value
        # when training on several devices. It is a no-op on a single device.
        self.log_dict({'val_loss': loss, 'val_acc': val_acc}, sync_dist=True)
        return loss

    def configure_optimizers(self):
        """Return the optimizer to train with (currently ``None``)."""
        optim = None  # TODO: e.g. an optimizer over self.parameters()
        return optim

model = LitModule()

In [None]:

# Ready-made keyword presets for debugging runs. They are not referenced by
# the code visible in this notebook; they are kept here as a quick reference.

# vs. limit_*: runs the first 10 batches without any side effects
debug_options01 = {'fast_dev_run': 10}

debug_options03 = {'overfit_batches': None}

# when an error occurred while overfitting sanity check
# or for rapid idea iteration
debug_options99 = {
    'limit_train_batches': None,  # with max_epochs=1 when doing idea validation
    'limit_val_batches': None,
    'limit_test_batches': None,
    'limit_predict_batches': None,
}

# Shared by every logger so that they all write under the same directory.
common_logger_options = {
    'save_dir': 'results/',
    'name': 'exp',
    # None: will automatically assign the next version number;
    # a str is regarded as another sub_dir
    'version': None,
}
# Keep a handle on the checkpoint callback so that `best_model_path` can be
# read from it once training has finished.
ckpt_callback = ModelCheckpoint(
    # dirpath='checkpoints/',
    monitor='val_acc', mode='max',
)

trainer = Trainer(
    # important training args
    gradient_clip_val=1.,
    accumulate_grad_batches=4,
    precision='bf16-mixed',

    # stopping
    max_epochs=10,
    min_epochs=None,
    max_time=None,

    # for logging
    logger=[
        TensorBoardLogger(
            sub_dir='tb/',  # TensorBoard logs will be saved in /save_dir/name/version/sub_dir
            **common_logger_options),
        CSVLogger(**common_logger_options),
    ],
    check_val_every_n_epoch=1,
    val_check_interval=1.,  # every {int} batches or every {float in [0, 1]} epoch
    log_every_n_steps=50,
    enable_progress_bar=True,
    profiler=None,  # simple, advanced
    # Must be True here: Lightning raises a MisconfigurationException when a
    # ModelCheckpoint callback is supplied together with enable_checkpointing=False.
    enable_checkpointing=True,

    # for on-training sanity check
    num_sanity_val_steps=2,  # checks the first 2 batches of the val set at the start of training

    # when ddp is used
    num_nodes=1,
    accelerator='auto',
    devices=8,

    # for reproducibility
    deterministic=True,
    benchmark=False,

    # others
    callbacks=[
        # EarlyStopping(monitor='val_loss'),
        ckpt_callback,
        LearningRateMonitor("epoch"),
        DeviceStatsMonitor(),
    ],
)


# Training

In [None]:
# Option A: hand the dataloaders to `fit` directly (placeholders here).
trainer.fit(model=model, train_dataloaders=None, val_dataloaders=None)
# or
# Option B: hand over a LightningDataModule (placeholder here). It is passed by
# keyword because the second positional parameter of `fit` is `train_dataloaders`.
datamodule = None
trainer.fit(model, datamodule=datamodule)

In [None]:
# after training: show the checkpoint path that `ckpt_callback` recorded as its
# best one (that callback monitors 'val_acc' with mode='max')
ckpt_callback.best_model_path

In [None]:
# load a model
# NOTE: placeholder path -- replace it with a real checkpoint file,
# e.g. `ckpt_callback.best_model_path` after training.
p_ckpt = 'asdasdasdasd'
# Restore through the LightningModule class defined in this notebook; the
# previously referenced name was never defined or imported (NameError).
model_pretrained = LitModule.load_from_checkpoint(
    p_ckpt,
    #**overriding_args,
)
model_pretrained.freeze()  # eval mode + with no grad


# Non-essentials

In [None]:
import lightning as L
from torch.utils.data import random_split, DataLoader

# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST
from torchvision import transforms


class MNISTDataModule(L.LightningDataModule):
    """LightningDataModule serving MNIST for fit / validate / test / predict.

    Args:
        data_dir: Directory the MNIST files are downloaded to and read from.
        batch_size: Batch size used by every dataloader (defaults to 32).
    """

    def __init__(self, data_dir: str = "./", batch_size: int = 32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

    def prepare_data(self):  # guarantee that only one process in DDP will download the data
        """Download both MNIST splits to ``data_dir``."""
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)
        # or tokenize and save the data to disk

        # WARNING: DO NOT assign any state (i.e. self.x = y) here
        # because this function will be called from the main process only,
        # so the other processes would never see that state

    def setup(self, stage: str):
        """Create the datasets required by ``stage``.

        Args:
            stage: One of ``"fit"``, ``"validate"``, ``"test"`` or ``"predict"``.
        """
        # Assign train/val datasets for use in dataloaders.
        # "validate" is handled too, so that validating on a fresh datamodule
        # (without a preceding fit) still finds `self.mnist_val`.
        if stage in ("fit", "validate"):
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            # Fixed generator seed keeps the 55000/5000 split identical across runs.
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
            )

        # Assign test dataset for use in dataloader(s)
        if stage == "test":
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

        if stage == "predict":
            self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Return the dataloader over the training split."""
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        """Return the dataloader over the validation split."""
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        """Return the dataloader over the test set."""
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

    def predict_dataloader(self):
        """Return the dataloader over the prediction set."""
        return DataLoader(self.mnist_predict, batch_size=self.batch_size)

dm = MNISTDataModule()
# Use the LightningModule defined in this notebook; the previously referenced
# `Model` name was never defined (NameError).
model = LitModule()
trainer.fit(model, datamodule=dm)
trainer.test(datamodule=dm)
trainer.validate(datamodule=dm)
trainer.predict(datamodule=dm)

In [None]:
class MyPrintingCallback(Callback):
    """Minimal callback that prints a line when training starts and ends."""

    @staticmethod
    def _announce(message):
        # Single place through which both hooks write to stdout.
        print(message)

    def on_train_start(self, trainer, pl_module):
        """Print a notice at the start of training."""
        self._announce("Training is starting")

    def on_train_end(self, trainer, pl_module):
        """Print a notice at the end of training."""
        self._announce("Training is ending")