In [2]:
import os
if os.path.basename(os.getcwd()) == "notebooks":
    os.chdir("..")
%cd .
# %load_ext autoreload
# %autoreload 2

/home/jh/code/til/til-23-cv


In [3]:
%%sh
# Run the repository's setup script. Per the recorded output below, it exits
# early when it is not running on the competition platform.
./setup.sh

Not on competition platform, exiting...


## Training Suspect Recognition

In [4]:
import timm
import torch
from til_23_cv.reid import cli_main

In [4]:
# Drive LightningCLI from Python instead of the command line, see:
# https://lightning.ai/docs/pytorch/stable/cli/lightning_cli_advanced_3.html#run-from-python
cli = cli_main({"config": "cfg/reid.yaml"})
# Pull out the pieces the later cells work with.
trainer, model, data = cli.trainer, cli.model, cli.datamodule

  rank_zero_warn(
Global seed set to 42
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


### `timm` models

In [5]:
# NOTE: how to read a timm pretrained-model name:
#   {model}_{size}_{patch size}_{im size}.{train method}_{dataset}
# - "m38m" is Merged-38M, which combines A LOT of datasets.
# - "ft" means fine-tuned afterwards on a smaller dataset, so
#   "m38m_ft_in22k_in1k" reads: pretrained on Merged-38M, then fine-tuned on
#   ImageNet-22k followed by ImageNet-1k.
# - CLIP models might be useful for their zero-shot capabilities.
dino_models = timm.list_models(filter="*dino*", pretrained=True)
display(dino_models)
# Candidate backbones:
# backbone = "eva02_large_patch14_448.mim_m38m_ft_in22k_in1k"
# backbone = "eva02_tiny_patch14_336.mim_in22k_ft_in1k"
# https://huggingface.co/timm/vit_small_patch14_dinov2.lvd142m
# backbone = "vit_small_patch14_dinov2.lvd142m"

['resmlp_12_224.fb_dino',
 'resmlp_24_224.fb_dino',
 'vit_base_patch8_224.dino',
 'vit_base_patch14_dinov2.lvd142m',
 'vit_base_patch16_224.dino',
 'vit_giant_patch14_dinov2.lvd142m',
 'vit_large_patch14_dinov2.lvd142m',
 'vit_small_patch8_224.dino',
 'vit_small_patch14_dinov2.lvd142m',
 'vit_small_patch16_224.dino']

### Preview Augmentations

See `notebooks/data.ipynb` for how to convert the `til23plush` dataset to `til23reid`.

In [6]:
# Sanity check: class count reported by the datamodule.
print(data.nclasses)

200


In [7]:
import cv2
import numpy as np

# Flip to True to step through augmented samples in an OpenCV window (this
# needs a display). Left False so the notebook can run unattended.
SHOW_PREVIEW = False

while SHOW_PREVIEW:
    im = data.preview_transform(1)[0]
    im = im.resize((1024, 1024))
    # Reverse the channel axis for OpenCV, which displays BGR.
    # NOTE(review): assumes preview_transform yields RGB PIL images - confirm.
    im = np.array(im)[..., ::-1]
    cv2.imshow("example", im)
    # waitKey can return -1 (no key) or a code with modifier bits set;
    # chr() would raise ValueError on -1, so compare the low byte instead.
    key = cv2.waitKey(0) & 0xFF
    if key == ord("q"):
        break
    # Any other key falls through and shows the next sample.
cv2.destroyAllWindows()

### Training

In [5]:
# Fit the model on the datamodule, using the trainer built by `cli_main` above.
trainer.fit(model, datamodule=data)

**NOTE: The cell above saves checkpoints to `runs/lightning_logs/version_N/checkpoints`; pick the best one for export!**

### Hyperparameter Search

A hyperparameter search is only feasible here because the model converges and overfits quickly.

In [None]:
# Need master branch for lightning 2.0 support.
# NOTE(review): this installs whatever `master` points to at install time, so
# the environment is not reproducible - consider pinning a tag or commit hash.
%pip install git+https://git@github.com/optuna/optuna.git@master

In [5]:
import optuna
import warnings
from IPython.display import clear_output
from til_23_cv.utils import OptunaCallback

In [7]:
# Define search objective.
def objective(trial):
    """Train one Optuna-sampled configuration and return its score.

    Samples `arc_s`, `arc_m` and `lr` from `trial`, builds a fresh CLI from
    `cfg/reid.yaml` with those values as model overrides, fits for at most 80
    epochs and returns `model.best_score`.

    An `OptunaCallback` monitoring `"val_sil_score"` is attached to the
    trainer; presumably it reports that metric back to `trial` so the study's
    pruner can stop weak runs early (see `til_23_cv.utils` to confirm).

    Args:
        trial: The `optuna` trial to draw hyperparameters from.

    Returns:
        `best_score` of the trained model.
    """
    # Drop the previous trial's logs once this trial starts printing.
    clear_output(wait=True)

    # Search space. Dict order is also the order the values are sampled in.
    sampled = {
        "arc_s": trial.suggest_float("arc_s", 0.1, 15.0),
        "arc_m": trial.suggest_float("arc_m", 0.15, 0.45),
        "lr": trial.suggest_float("lr", 5e-6, 5e-5, log=True),
    }
    # Optional extra dimension, to be passed as data=dict(batch_size=batch_size):
    # batch_size = trial.suggest_categorical("batch_size", [64, 96, 128])

    overrides = dict(
        config="cfg/reid.yaml",
        # Set sched_steps to -1 to disable OneCycle for a better read on the
        # optimal LR.
        model=dict(**sampled, sched_steps=10000),
        trainer=dict(
            callbacks=[OptunaCallback(trial=trial, monitor="val_sil_score")],
            # Each trial logs and checkpoints into its own directory.
            default_root_dir=f"runs/optuna/trial_{trial.number}",
            max_epochs=80,
        ),
    )
    cli = cli_main(overrides)

    print("Hyperparameters:")
    print(" ".join(f"{name}: {value}" for name, value in sampled.items()))

    module = cli.model
    cli.trainer.fit(module, datamodule=cli.datamodule)
    return module.best_score

In [8]:
# Search for hyperparameters.
# A median pruner wrapped in a PatientPruner with patience=8.
median_pruner = optuna.pruners.MedianPruner(
    n_startup_trials=8, n_warmup_steps=8, n_min_trials=2
)
pruner = optuna.pruners.PatientPruner(median_pruner, patience=8)
study = optuna.create_study(direction="maximize", pruner=pruner)

# Stop after 500 trials or 12 hours, whichever comes first.
time_budget_s = 12 * 60 * 60
with warnings.catch_warnings():
    # Silence all warnings while the search runs to keep the output readable.
    warnings.simplefilter("ignore")
    study.optimize(objective, n_trials=500, timeout=time_budget_s)

Global seed set to 42
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: runs/optuna/trial_74/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type              | Params
--------------------------------------------
0 | model | VisionTransformer | 21.6 M
1 | head  | ArcMarginProduct  | 76.8 K
2 | loss  | CrossEntropyLoss  | 0     
--------------------------------------------
21.7 M    Trainable params
0         Non-trainable params
21.7 M    Total params
86.822    Total estimated model params size (MB)


Hyperparameters:
arc_s: 12.076299801184774 arc_m: 0.28574758664587907 lr: 2.0280211776804368e-05


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[I 2023-05-31 09:47:52,848] Trial 74 finished with value: 0.4753156304359436 and parameters: {'arc_s': 12.076299801184774, 'arc_m': 0.28574758664587907, 'lr': 2.0280211776804368e-05}. Best is trial 64 with value: 0.5075626969337463.


In [10]:
# Summarise the search: trial count, then the best trial's id, score and
# hyperparameters. (The f-strings were already fully formatted, so the
# trailing `.format()` calls were no-ops and have been removed.)
print(f"Number of finished trials: {len(study.trials)}")
print("Best trial:")
trial = study.best_trial
print(f"  ID: {trial.number}")
print(f"  Value: {trial.value}")
print("  Params: ")
for k, v in trial.params.items():
    print(f"    {k}: {v}")

Number of finished trials: 75
Best trial:
  ID: 64
  Value: 0.5075626969337463
  Params: 
    arc_s: 10.200518950277152
    arc_m: 0.2956878768033506
    lr: 1.735784634541476e-05


### Model Export

In [6]:
# Lightning checkpoint to export (picked by hand from the training run).
ckpt_path = "runs/lightning_logs/version_4/checkpoints/epoch=11-val_sil_score=0.434.ckpt"
# Destination for the traced TorchScript encoder.
save_path = "models/reid.pt"

In [8]:
# Load checkpoint.
# map_location="cpu" lets a checkpoint written during GPU training be read on
# a machine without CUDA. load_state_dict copies the values into the model's
# existing parameters, so the model stays on whatever device it is already on.
ckpt = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(ckpt["state_dict"])

<All keys matched successfully>

In [16]:
# Trace & save model.
# A traced module keeps the train/eval behaviour it had while being traced,
# so switch the encoder to eval mode first; the export is for inference.
# (Side effect: the encoder stays in eval mode after this cell.)
encoder = model.model.eval()
sz = model.hparams.im_size
# Trace with a dummy batch of one 3-channel sz x sz image.
traced = torch.jit.trace(encoder, torch.rand(1, 3, sz, sz))
torch.jit.save(traced, save_path)

  assert condition, message


In [18]:
# Check equality: reload the exported file and compare it with the in-memory
# encoder on one random input.
traced = torch.jit.load(save_path)
x = torch.rand(1, 3, sz, sz)
exported_out = traced(x)
reference_out = model.model(x)
outputs_match = torch.isclose(exported_out, reference_out).all()
print(outputs_match)

tensor(True)
