In [1]:
import os
if os.path.basename(os.getcwd()) == "notebooks":
    os.chdir("..")
%cd .
# %load_ext autoreload
# %autoreload 2

/home/jh/code/til/til-23-cv


In [2]:
%%sh
./setup.sh

Not on competition platform, exiting...


## Training Suspect Recognition

In [3]:
import timm
import torch
from til_23_cv.reid import cli_main

In [4]:
# https://lightning.ai/docs/pytorch/stable/cli/lightning_cli_advanced_3.html#run-from-python
cli = cli_main(config="cfg/reid.yaml")
trainer = cli.trainer
model = cli.model
data = cli.datamodule

  rank_zero_warn(
Global seed set to 42
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [5]:
cli.config.as_dict()

{'config': None,
 'seed_everything': 42,
 'trainer': {'accelerator': 'auto',
  'strategy': 'auto',
  'devices': 'auto',
  'num_nodes': 1,
  'logger': None,
  'callbacks': None,
  'fast_dev_run': False,
  'max_epochs': None,
  'min_epochs': None,
  'min_steps': None,
  'max_time': None,
  'limit_train_batches': None,
  'limit_val_batches': None,
  'limit_test_batches': None,
  'limit_predict_batches': None,
  'overfit_batches': 0.0,
  'val_check_interval': None,
  'check_val_every_n_epoch': 1,
  'num_sanity_val_steps': None,
  'log_every_n_steps': 1,
  'enable_checkpointing': None,
  'enable_progress_bar': None,
  'enable_model_summary': None,
  'accumulate_grad_batches': 1,
  'gradient_clip_val': None,
  'gradient_clip_algorithm': None,
  'deterministic': False,
  'benchmark': True,
  'inference_mode': True,
  'use_distributed_sampler': True,
  'profiler': None,
  'detect_anomaly': False,
  'barebones': False,
  'plugins': [<lightning.pytorch.plugins.precision.amp.MixedPrecisionPlugin 

### `timm` models

In [5]:
# NOTE:
# {model}_{size}_{patch size}_{im size}.{train method}_{dataset}
# m38m is Merged-38M, combines A LOT of datasets.
# ft refers to fine-tuning on a smaller dataset later.
# so m38m_ft_in22k_in1k means pretrained on Merged-38M, then finetuned on
# ImageNet-22k followed by ImageNet-1k.
# Clip models might be useful for their zero-shot capabilities.
display(timm.list_models(pretrained=True, filter="*dino*"))
# backbone = "eva02_large_patch14_448.mim_m38m_ft_in22k_in1k"
# backbone = "eva02_tiny_patch14_336.mim_in22k_ft_in1k"
# https://huggingface.co/timm/vit_small_patch14_dinov2.lvd142m
# backbone = "vit_small_patch14_dinov2.lvd142m"

['resmlp_12_224.fb_dino',
 'resmlp_24_224.fb_dino',
 'vit_base_patch8_224.dino',
 'vit_base_patch14_dinov2.lvd142m',
 'vit_base_patch16_224.dino',
 'vit_giant_patch14_dinov2.lvd142m',
 'vit_large_patch14_dinov2.lvd142m',
 'vit_small_patch8_224.dino',
 'vit_small_patch14_dinov2.lvd142m',
 'vit_small_patch16_224.dino']

### Preview Augmentations

See `notebooks/data.ipynb` for how to convert `til23plush` dataset to `til23reid`.

In [6]:
print(data.nclasses)

200


In [7]:
import cv2
import numpy as np
while False:
    im = data.preview_transform(1)[0]
    im = im.resize((1024, 1024))
    im = np.array(im)[...,::-1]
    cv2.imshow("example", im)
    key = chr(cv2.waitKey(0))
    if key == "q":
        break
cv2.destroyAllWindows()

### Training

In [6]:
trainer.fit(model, datamodule=data)

**NOTE: Above saves checkpoints to `runs/lightning_logs/version_N/checkpoints`, pick the best for export!**

### Hyperparameter Search

This is only possible because the model converges and overfits quickly.

In [None]:
# Need master branch for lightning 2.0 support.
%pip install git+https://git@github.com/optuna/optuna.git@master

In [3]:
import optuna
import warnings
from IPython.display import clear_output
from til_23_cv.utils import OptunaCallback

In [4]:
STEPS_PER_EPOCH = 97 # 97 is batch size 128 for given dataset only.
MAX_EPOCHS = 32

In [None]:
# Define search objective.
def objective(trial):
    clear_output(wait=True)

    arc_s = trial.suggest_float("arc_s", 0.1, 10.0)
    arc_m = trial.suggest_float("arc_m", 0.1, 0.8)
    lr = trial.suggest_float("lr", 1e-6, 3e-5, log=True)
    # batch_size = trial.suggest_categorical("batch_size", [64, 96, 128])

    cli = cli_main(
        dict(
            model=dict(
                model_name="vit_base_patch14_dinov2.lvd142m",
                arc_s=arc_s,
                arc_m=arc_m,
                lr=lr,
                # Set sched_steps to -1 to disable OneCycle for better read on optimal LR.
                sched_steps=MAX_EPOCHS*STEPS_PER_EPOCH,
            ),
            # data=dict(batch_size=batch_size),
            trainer=dict(
                callbacks=[OptunaCallback(trial=trial, monitor="val_sil_score")],
                default_root_dir=f"runs/optuna/trial_{trial.number}",
                max_epochs=MAX_EPOCHS,
            ),
        )
    ) 

    print("Hyperparameters:")
    print(
        "arc_s:", arc_s,
        "arc_m:", arc_m,
        "lr:", lr,
        # "batch_size:", batch_size
    )

    trainer = cli.trainer
    model = cli.model
    data = cli.datamodule
    trainer.fit(model, datamodule=data)

    return model.best_score

In [None]:
# Search for hyperparameters.
pruner = optuna.pruners.PatientPruner(
    optuna.pruners.MedianPruner(
        # 0.3 is default pct for OneCycle scheduler.
        n_startup_trials=4, n_warmup_steps=int(MAX_EPOCHS*0.3), n_min_trials=2
    ),
    patience=4,
)
study = optuna.create_study(direction="maximize", pruner=pruner)
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    study.optimize(objective, timeout=int(60*60*6))

In [None]:
print(f"Number of finished trials: {len(study.trials)}")
print("Best trial:")
trial = study.best_trial
print(f"  ID: {trial.number}".format())
print(f"  Value: {trial.value}".format())
print(f"  Params: ")
for k, v in trial.params.items():
    print(f"    {k}: {v}")

Number of finished trials: 75
Best trial:
  ID: 64
  Value: 0.5075626969337463
  Params: 
    arc_s: 10.200518950277152
    arc_m: 0.2956878768033506
    lr: 1.735784634541476e-05


### Model Export

In [37]:
# ckpt_path = trainer.checkpoint_callback.best_model_path
ckpt_path = "runs/lightning_logs/version_4/checkpoints/epoch=11-val_sil_score=0.434.ckpt"
save_path = "models/reid.pt"

In [38]:
torch.set_float32_matmul_precision("highest")
torch.use_deterministic_algorithms(True, warn_only=True)

In [39]:
# Load checkpoint.
ckpt = torch.load(ckpt_path)
model.load_state_dict(ckpt["state_dict"])

<All keys matched successfully>

In [40]:
# Trace & save model.
encoder = model.model
sz = model.hparams.im_size

x = torch.rand(1, 3, sz, sz).cuda()
encoder.cuda().eval()
traced = torch.jit.trace(encoder, x)
torch.jit.save(traced, save_path)

  assert condition, message


In [41]:
# Check equality.
with torch.inference_mode():
    traced = torch.jit.load(save_path).eval()
    x = torch.rand(1, 3, sz, sz).cuda()
    print(torch.isclose(traced(x), encoder(x)).all())
    print(abs(traced(x) - encoder(x)).max())

tensor(False, device='cuda:0')
tensor(7.9870e-06, device='cuda:0')
