In [1]:
from itertools import chain

import pandas as pd
import timm
import torch
import torchvision.transforms as T
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR

from wildlife_tools.data import WildlifeDataset, SplitMetadata
from wildlife_tools.train import ArcFaceLoss, BasicTrainer

In [3]:
# Seed torch's global RNG so repeated runs are comparable: randomly initialised
# parameters (the ArcFace objective has its own, see the optimizer below) and,
# presumably, the trainer's batch shuffling draw from it.
SEED = 0
torch.manual_seed(SEED)

# Training settings; these mirror the YAML config later in the notebook.
EPOCHS = 20
DEVICE = 'cpu'

image_root = 'ExampleDataset'
metadata = pd.read_csv(f'{image_root}/metadata.csv')

# Resize so the shorter image side is 256 px, take a 224x224 center crop,
# then convert to a tensor.
transform = T.Compose([
    T.Resize(size=256),
    T.CenterCrop(size=(224, 224)),
    T.ToTensor(),
])

# Only the rows selected by SplitMetadata('split', 'train') are used
# (presumably those whose `split` column equals 'train').
dataset = WildlifeDataset(
    metadata=metadata,
    root=image_root,
    split=SplitMetadata('split', 'train'),
    transform=transform,
)

# MegaDescriptor-T backbone from the HuggingFace Hub. num_classes=0 drops the
# classification head, so the model outputs embeddings instead of logits.
backbone = timm.create_model('hf-hub:BVRA/wildlife-mega-tiny', num_classes=0, pretrained=True)

# ArcFace needs the backbone's output size and the number of classes.
# The embedding size is read from the backbone instead of hard-coding 768,
# so swapping the backbone cannot silently mismatch the loss.
objective = ArcFaceLoss(
    num_classes=dataset.num_classes,
    embedding_size=backbone.num_features,
    margin=0.5,
    scale=64,
)

# One optimizer updates both the backbone and the objective's parameters.
params = chain(backbone.parameters(), objective.parameters())
optimizer = SGD(params=params, lr=0.001, momentum=0.9)

trainer = BasicTrainer(
    dataset=dataset,
    model=backbone,
    objective=objective,
    optimizer=optimizer,
    epochs=EPOCHS,
    device=DEVICE,
)

trainer.train()

Epoch 0: : 1it [00:01,  1.47s/it]
Epoch 1: : 1it [00:02,  2.04s/it]
Epoch 2: : 1it [00:02,  2.57s/it]
Epoch 3: : 1it [00:02,  2.15s/it]
Epoch 4: : 1it [00:02,  2.35s/it]
Epoch 5: : 1it [00:02,  2.05s/it]
Epoch 6: : 1it [00:02,  2.79s/it]
Epoch 7: : 1it [00:02,  2.54s/it]
Epoch 8: : 1it [00:02,  2.96s/it]
Epoch 9: : 1it [00:02,  2.06s/it]
Epoch 10: : 1it [00:02,  2.32s/it]
Epoch 11: : 1it [00:02,  2.12s/it]
Epoch 12: : 1it [00:02,  2.07s/it]
Epoch 13: : 1it [00:02,  2.09s/it]
Epoch 14: : 1it [00:02,  2.22s/it]
Epoch 15: : 1it [00:02,  2.08s/it]
Epoch 16: : 1it [00:02,  2.09s/it]
Epoch 17: : 1it [00:02,  2.23s/it]
Epoch 18: : 1it [00:02,  2.30s/it]
Epoch 19: : 1it [00:02,  2.06s/it]


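# YAML config
An equivalent training setup can be described declaratively with the following YAML configuration; `parse_yaml` reads it and `realize` builds the trainer from its `trainer` section: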

In [15]:
from wildlife_tools.tools import parse_yaml, realize

# Declarative counterpart of the Python training cell above: same data, split,
# transform, ArcFace objective, SGD settings and epochs. The backbone points at
# the MegaDescriptor-T weights on the HuggingFace Hub. The previous plain
# `swin_tiny_patch4_window7_224` name would load a stock timm checkpoint
# instead, so the two runs were not actually equivalent. The value is quoted
# because it contains a colon.
yaml_config = """
trainer:
  method: EmbeddingTrainer
  device: cpu
  epochs: 20

  dataset:
    method: WildlifeDataset
    metadata: ExampleDataset/metadata.csv
    root: ExampleDataset
    split:
      method: SplitMetadata
      col: split
      value: train
    transform:
      method: TransformTorchvision
      compose:
        - Resize(size=256)
        - CenterCrop(size=(224, 224))
        - ToTensor()

  objective:
    method: ArcFaceLoss
    margin: 0.5
    scale: 64

  optimizer:
    method: OptimizerSGD
    lr: 0.001
    momentum: 0.9

  backbone:
    method: TimmBackbone
    model_name: 'hf-hub:BVRA/wildlife-mega-tiny'
    pretrained: true
"""


# Turn the YAML text into a config object, instantiate the trainer described
# under its `trainer` key, and run training.
config = parse_yaml(yaml_config)
trainer_section = config['trainer']
trainer = realize(trainer_section)
trainer.train()

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
