In [1]:
import lightning as L
import torch
import timm
import torch.nn.functional as F
import torchmetrics
import gc

from typing import Union
from pathlib import Path
from torchvision.datasets import Food101
from torch.utils.data import random_split, DataLoader
from pytorch_lightning.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint
torch.set_float32_matmul_precision('high')

In [2]:
class Food101DataModule(L.LightningDataModule):
    def __init__(self, transform, data_dir: Union[str, Path] = "data", batch_size: int = 128) -> None:
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transform

    def prepare_data(self):
        Food101(self.data_dir, split='train', download=True)
        Food101(self.data_dir, split='test', download=True)

    def setup(self, stage: str = 'fit'):
        if stage == 'fit':
            food101_full = Food101(self.data_dir, split='train', download=True, transform=self.transform)
            self.food101_train, self.food101_val = random_split(food101_full, [0.8, 0.2])

        if stage == 'test':
            self.food101_test = Food101(self.data_dir, split='test', download=True, transform=self.transform)

        if stage == "predict":
            self.food101_predict = Food101(self.data_dir, split='test', download=True, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.food101_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.food101_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.food101_test, batch_size=self.batch_size)

    def predict_dataloader(self):
        return DataLoader(self.food101_predict, batch_size=self.batch_size)

In [3]:
class Food101Classifier(L.LightningModule):
    def __init__(self, model_name: str, epochs: int) -> None:
        super().__init__()
        self.num_classes = 101
        self.epochs = epochs
        self.model = timm.create_model(model_name, pretrained=True, num_classes=101)
        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=101)
        self.valid_acc = torchmetrics.Accuracy(task="multiclass", num_classes=101)
        self.f1_metric = torchmetrics.F1Score(task="multiclass", num_classes=101)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self.forward(inputs)
        preds = torch.argmax(outputs, 1)
        loss = F.cross_entropy(outputs, labels)
        self.log("train_loss", loss, on_epoch=True, prog_bar=True)
        self.train_acc(preds, labels)
        self.log('train_acc', self.train_acc, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, labels = batch
        self.model.eval()
        outputs = self.forward(inputs)
        preds = torch.argmax(outputs, 1)
        loss = F.cross_entropy(outputs, labels)
        self.log("val_loss", loss, prog_bar=True)
        self.valid_acc(preds, labels)
        self.log('val_acc', self.valid_acc, prog_bar=True)
        self.f1_metric(preds, labels)
        self.log("val_f1", self.f1_metric, prog_bar=True)

    def test_step(self, batch, batch_idx):
        inputs, labels = batch
        self.model.eval()
        outputs = self.forward(inputs)
        preds = torch.argmax(outputs, 1)
        loss = F.cross_entropy(outputs, labels)
        self.log("test_loss", loss, prog_bar=True)
        self.valid_acc(preds, labels)
        self.log('test_acc', self.valid_acc, prog_bar=True)
        self.f1_metric(preds, labels)
        self.log("test_f1", self.f1_metric, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=0.001, foreach=True)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, 0.01, steps_per_epoch=947, epochs=self.epochs)
        scheduler = {"scheduler": scheduler, "interval" : "step" }
        return [optimizer], [scheduler]

In [4]:
models = ["levit_128s.fb_dist_in1k", "levit_192.fb_dist_in1k", "levit_256.fb_dist_in1k", "levit_384.fb_dist_in1k",
          "convnextv2_nano.fcmae_ft_in22k_in1k_384", "convnextv2_tiny.fcmae_ft_in22k_in1k_384", "convnextv2_base.fcmae_ft_in22k_in1k_384",
          "convnextv2_large.fcmae_ft_in22k_in1k_384", "tf_efficientnetv2_s.in21k_ft_in1k", "tf_efficientnetv2_m.in21k_ft_in1k",
          "tf_efficientnetv2_l.in21k_ft_in1k", "tf_efficientnetv2_b3.in21k_ft_in1k", "tf_efficientnet_b2.ns_jft_in1k",
          "beitv2_large_patch16_224.in1k_ft_in22k_in1k", "beitv2_base_patch16_224.in1k_ft_in22k_in1k", "vit_base_patch14_dinov2.lvd142m",
          "vit_large_patch14_dinov2.lvd142m", "vit_small_patch14_dinov2.lvd142m", "vit_large_patch14_clip_336.laion2b_ft_in12k_in1k_inat21",
          "vit_large_patch14_clip_336.datacompxl_ft_inat21", "eva02_large_patch14_clip_336.merged2b_ft_inat21", "vit_relpos_medium_patch16_rpn_224.sw_in1k",
          "swinv2_tiny_window8_256.ms_in1k", "swinv2_small_window8_256.ms_in1k", "swinv2_base_window8_256.ms_in1k", "timm/swinv2_large_window12to16_192to256.ms_in22k_ft_in1k"]

In [5]:
for model in models[2:3]:
    gc.collect()
    torch.cuda.empty_cache()
    print(model)
    logger = TensorBoardLogger("runs", version=1, name=f"{model}/logs")
    food_model = Food101Classifier("hf_hub:timm/"+model, 3)
    data_cfg = timm.data.resolve_data_config(food_model.model.pretrained_cfg)
    transform = timm.data.create_transform(**data_cfg)
    food_data = Food101DataModule(transform)
    checkpoint_callback = ModelCheckpoint(monitor="val_acc", dirpath="models", filename=f"{model}/checkpoints")
    trainer = L.Trainer(
        logger=logger,
        accelerator='gpu',
        devices=1,
        precision="16-mixed",
        accumulate_grad_batches=1,
        enable_checkpointing=True,
        callbacks=[checkpoint_callback],
        max_epochs=3,
        fast_dev_run=False,
        profiler="advanced",
    )
    trainer.fit(food_model, food_data)

levit_256.fb_dist_in1k


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name      | Type               | Params
-------------------------------------------------
0 | model     | LevitDistilled     | 18.0 M
1 | train_acc | MulticlassAccuracy | 0     
2 | valid_acc | MulticlassAccuracy | 0     
3 | f1_metric | MulticlassF1Score  | 0     
-------------------------------------------------
18.0 M    Trainable params
0         Non-trainable params
18.0 M    Total params
71.886    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

