In [1]:
import os
if os.path.basename(os.getcwd()) == "notebooks":
    os.chdir("..")
%cd .
%load_ext autoreload
%autoreload 2

/home/jh/code/til/til-23-cv


## Training Suspect Recognition

In [2]:
import timm
import torch
from til_23_cv import LitImEncoder, ArcMarginProduct, LitImClsDataModule
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import (
    ModelCheckpoint,
    DeviceStatsMonitor,
    EarlyStopping,
    LearningRateMonitor,
)

### Settings

Pick the model backbone, image size, batch size, and number of dataloader workers.

In [3]:
# NOTE: timm pretrained model names follow the pattern
# {model}_{size}_{patch size}_{im size}.{train method}_{dataset}
# m38m is Merged-38M, which combines A LOT of datasets.
# ft refers to fine-tuning on a smaller dataset later,
# so m38m_ft_in22k_in1k means pretrained on Merged-38M, then finetuned on
# ImageNet-22k followed by ImageNet-1k.
# CLIP models might be useful for their zero-shot capabilities.

# Seed all RNGs up front, before the data module is set up and the model is
# built, so that any randomness in those steps is reproducible as well.
# Same seed as the call made right before the Trainer is created.
seed_everything(42, workers=True)

# List the pretrained DINO-family checkpoints timm offers as backbone candidates.
display(timm.list_models(pretrained=True, filter="*dino*"))
# backbone = "eva02_large_patch14_448.mim_m38m_ft_in22k_in1k"
# backbone = "eva02_tiny_patch14_336.mim_in22k_ft_in1k"
# https://huggingface.co/timm/vit_small_patch14_dinov2.lvd142m
backbone = "vit_small_patch14_dinov2.lvd142m"

# Input resolution (pixels per side) and dataloader settings shared by the
# data module and the model below.
im_size = 224
batch_size = 144
num_workers = 12

['resmlp_12_224.fb_dino',
 'resmlp_24_224.fb_dino',
 'vit_base_patch8_224.dino',
 'vit_base_patch14_dinov2.lvd142m',
 'vit_base_patch16_224.dino',
 'vit_giant_patch14_dinov2.lvd142m',
 'vit_large_patch14_dinov2.lvd142m',
 'vit_small_patch8_224.dino',
 'vit_small_patch14_dinov2.lvd142m',
 'vit_small_patch16_224.dino']

### Data Module

In [4]:
# Root folder of the dataset, relative to the working directory.
dataset_root = "data/til23reid"
data = LitImClsDataModule(
    dataset_root, im_size=im_size, batch_size=batch_size, num_workers=num_workers
)

# The trainer would call `setup()` by itself later on; it is invoked here
# already because the number of classes is needed to build the model head.
data.setup("fit")
class_names = data.train_ds.classes
nclasses = len(class_names)
print(nclasses)

200


### Preview Augmentations

In [5]:
import cv2
import numpy as np

# Flip to True to step through augmented samples in an OpenCV window:
# any key shows the next sample, "q" quits. Needs an OpenCV build with GUI
# support, which is why it is disabled by default.
show_preview = False

if show_preview:
    try:
        while True:
            im = data.preview_transform(1)[0]
            im = im.resize((1024, 1024))
            # Reverse the channel axis before handing the array to OpenCV.
            # NOTE(review): assumes `preview_transform` yields RGB PIL images — confirm.
            im = np.array(im)[..., ::-1]
            cv2.imshow("example", im)
            key = cv2.waitKey(0)
            # `waitKey` can return -1 when no key was read; `chr(-1)` would
            # raise ValueError, so compare key codes instead. Masking to the
            # low byte ignores modifier bits that some backends set.
            if key < 0 or (key & 0xFF) == ord("q"):
                break
    finally:
        # Close the preview window even if a sample fails to render.
        cv2.destroyAllWindows()

### Model

In [6]:
# Margin-based classification head with one output per training class.
# NOTE(review): 384 is hard-coded; presumably it must equal the embedding
# width of the chosen backbone — confirm when switching backbones.
margin_head = ArcMarginProduct(384, nclasses, s=32, m=0.5)

model = LitImEncoder(
    backbone, head=margin_head, pretrained=True, nclasses=nclasses, im_size=im_size
)

# Show the backbone's pretrained config; the dataloader's image normalization
# should be set to the corresponding values.
display(model.model.pretrained_cfg)

{'url': 'https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth',
 'hf_hub_id': 'timm/vit_small_patch14_dinov2.lvd142m',
 'architecture': 'vit_small_patch14_dinov2',
 'tag': 'lvd142m',
 'custom_load': False,
 'input_size': (3, 518, 518),
 'fixed_input_size': True,
 'interpolation': 'bicubic',
 'crop_pct': 1.0,
 'crop_mode': 'center',
 'mean': (0.485, 0.456, 0.406),
 'std': (0.229, 0.224, 0.225),
 'num_classes': 0,
 'pool_size': None,
 'first_conv': 'patch_embed.proj',
 'classifier': 'head',
 'license': 'cc-by-nc-4.0'}

### Lit Callbacks

In [13]:
# https://lightning.ai/docs/pytorch/stable/extensions/callbacks.html
# Early stopping and checkpointing track the same validation metric, and for
# both a higher value counts as better.
monitored_metric = "val_sil_score"

# Logs the learning rate at every optimizer step.
lr_monitor = LearningRateMonitor(logging_interval="step")

# Stops training after 20 validation checks without improvement; `strict`
# makes a missing metric an error instead of being ignored.
early_stopping = EarlyStopping(
    monitor=monitored_metric, patience=20, mode="max", strict=True
)

# Keeps the three best checkpoints plus the most recent one.
checkpointing = ModelCheckpoint(
    filename="{epoch}-{val_sil_score:.2f}",
    monitor=monitored_metric,
    save_last=True,
    save_top_k=3,
    mode="max",
)

# DeviceStatsMonitor() is left out on purpose; add it to the list when
# hardware statistics are needed.
callbacks = [lr_monitor, early_stopping, checkpointing]

### Training

In [14]:
# Fix every RNG (including dataloader workers) right before training.
seed_everything(42, workers=True)

# Trade some float32 matmul precision for speed.
torch.set_float32_matmul_precision("medium")
# torch.set_float32_matmul_precision("high")

# TODO: This is the part that would love Hydra config system.
trainer_options = dict(
    # default_root_dir="temp",
    benchmark=True,
    # Glitchy, see https://github.com/Lightning-AI/lightning/issues/5558.
    # precision="16-mixed",
    callbacks=callbacks,
    log_every_n_steps=10,
    check_val_every_n_epoch=1,  # 3
    max_steps=10000,
    # detect_anomaly=True,
    # fast_dev_run=True, # Debugging.
)
trainer = Trainer(**trainer_options)

Global seed set to 42
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
# Run training; the data module supplies the train and validation dataloaders.
trainer.fit(model=model, datamodule=data)

### Model Export

In [18]:
# Where the exported TorchScript module is written.
save_path = "reid.torchscript"
# Checkpoint whose weights get exported.
# NOTE(review): relative to the working directory; presumably a checkpoint
# produced by the ModelCheckpoint callback and copied here — confirm.
ckpt_path = "epoch=8-val_sil_score=0.43.pt"

In [19]:
# Load checkpoint.
# SECURITY: `torch.load` unpickles the file, so only load checkpoints from a
# trusted source.
# `map_location="cpu"` lets a checkpoint that was saved from GPU tensors load
# on a machine without that GPU; `load_state_dict` copies the values into the
# model's existing parameters, wherever those live.
ckpt = torch.load(ckpt_path, map_location="cpu")
# Last expression on purpose: the notebook displays the key-matching result.
model.load_state_dict(ckpt["state_dict"])

<All keys matched successfully>

In [20]:
# Trace & save model.
# Tracing freezes whatever train/eval-dependent behaviour the module has at
# trace time, so switch to eval mode first to export inference behaviour.
model.model.eval()
example_input = torch.rand(1, 3, im_size, im_size)
# No gradients are needed for export; skip autograd bookkeeping while tracing.
with torch.no_grad():
    traced = torch.jit.trace(model.model, example_input)
torch.jit.save(traced, save_path)

  assert condition, message


In [21]:
# Check equality.
# Reload the export from disk so the saved file itself is what gets verified.
traced = torch.jit.load(save_path)
# Compare against the reference module in eval mode, i.e. inference behaviour.
model.model.eval()
x = torch.rand(1, 3, im_size, im_size)
with torch.no_grad():
    outputs_match = torch.isclose(traced(x), model.model(x)).all()
print(outputs_match)
# Fail loudly instead of relying on someone reading the printed value.
assert outputs_match, "TorchScript export does not reproduce the model outputs"

tensor(True)
