In [1]:
import lightning as L
import torch
import timm
import torch.nn.functional as F
import torchmetrics
from typing import Union
from pathlib import Path
from torchvision.datasets import Food101
from torch.utils.data import random_split, DataLoader
# Trade a little float32 matmul precision for speed (the 'high' setting lets
# PyTorch use faster reduced-precision kernels, e.g. TF32, where the hardware has them).
torch.set_float32_matmul_precision('high')

class Food101DataModule(L.LightningDataModule):
    """LightningDataModule around torchvision's ``Food101`` dataset.

    The official ``train`` split is divided 80/20 into training and validation
    subsets; the official ``test`` split backs both the test and predict loaders.
    The same ``transform`` is applied to every split.

    Args:
        transform: Callable applied to each image by ``Food101``.
        data_dir: Root directory the dataset is downloaded to / read from.
        batch_size: Batch size used by every dataloader.
        num_workers: Worker processes used by every dataloader.
        seed: Seed for the train/validation split, so the partition is the
            same in every process and on every run.
    """

    def __init__(
        self,
        transform,
        data_dir: Union[str, Path] = "data",
        batch_size: int = 128,
        num_workers: int = 4,
        seed: int = 42,
    ) -> None:
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transform
        self.num_workers = num_workers
        self.seed = seed

    def prepare_data(self):
        """Download both official splits (no transform; nothing is kept on ``self``)."""
        Food101(self.data_dir, split='train', download=True) # type: ignore
        Food101(self.data_dir, split='test', download=True) # type: ignore

    def setup(self, stage: str = 'fit'):
        """Build the datasets needed for ``stage``.

        ``'fit'`` and ``'validate'`` both build the train/val subsets, so that
        ``val_dataloader`` also works when only validation is run.
        """
        # 'validate' must be handled too: without it ``food101_val`` would not
        # exist when ``val_dataloader`` is called for a validation-only run.
        if stage in ('fit', 'validate'):
            food101_full = Food101(self.data_dir, split='train', download=True, transform=self.transform) # type: ignore
            # A seeded generator keeps the split deterministic. An unseeded split
            # differs per process/run, which mixes validation images into training
            # (across DDP ranks, or when training is resumed).
            split_rng = torch.Generator().manual_seed(self.seed)
            self.food101_train, self.food101_val = random_split(food101_full, [0.8, 0.2], generator=split_rng) # type: ignore

        if stage == 'test':
            self.food101_test = Food101(self.data_dir, split='test', download=True, transform=self.transform) # type: ignore

        if stage == "predict":
            self.food101_predict = Food101(self.data_dir, split='test', download=True, transform=self.transform) # type: ignore

    def _make_loader(self, dataset, shuffle: bool = False):
        """Create a DataLoader with the module-wide batch size and worker settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def train_dataloader(self):
        """Training loader; reshuffled every epoch."""
        return self._make_loader(self.food101_train, shuffle=True)

    def val_dataloader(self):
        """Validation loader (fixed order)."""
        return self._make_loader(self.food101_val)

    def test_dataloader(self):
        """Test loader (fixed order)."""
        return self._make_loader(self.food101_test)

    def predict_dataloader(self):
        """Prediction loader over the official test split (fixed order)."""
        return self._make_loader(self.food101_predict)

class Food101Classifier(L.LightningModule):
    """Image classifier that fine-tunes a timm backbone with cross-entropy loss.

    Args:
        model_name: Identifier passed to ``timm.create_model``.
        num_classes: Size of the classification head and of every metric.
        pretrained: Whether ``timm.create_model`` should load pretrained
            weights. Pass ``False`` when the weights are about to be
            overwritten anyway (e.g. when restoring from a checkpoint).
    """

    def __init__(
        self,
        model_name: str = "hf_hub:timm/levit_256.fb_dist_in1k",
        num_classes: int = 101,
        pretrained: bool = True,
    ) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.num_classes = num_classes
        self.model = timm.create_model(model_name, pretrained=pretrained, num_classes=num_classes)
        # One metric object per stage so accumulated state is never shared
        # between training, validation and testing.
        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.valid_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.f1_metric = torchmetrics.F1Score(task="multiclass", num_classes=num_classes)
        self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.test_f1 = torchmetrics.F1Score(task="multiclass", num_classes=num_classes)

    def forward(self, x):
        """Return the backbone's raw class logits for ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the cross-entropy loss for one batch and log loss/accuracy."""
        inputs, labels = batch
        outputs = self(inputs)
        preds = torch.argmax(outputs, 1)
        loss = F.cross_entropy(outputs, labels)
        self.log("train_loss", loss, on_epoch=True, prog_bar=True, sync_dist=True)
        self.train_acc(preds, labels)
        self.log('train_acc', self.train_acc, on_epoch=True, prog_bar=True, sync_dist=True)
        return loss

    def _shared_eval_step(self, batch, stage: str, acc_metric, f1_metric) -> None:
        """Log ``{stage}_loss``, ``{stage}_acc`` and ``{stage}_f1`` for one batch.

        No explicit ``self.model.eval()`` here: the module is expected to be
        put in eval mode by the Trainer around validation/test loops.
        """
        inputs, labels = batch
        outputs = self(inputs)
        preds = torch.argmax(outputs, 1)
        loss = F.cross_entropy(outputs, labels)
        self.log(f"{stage}_loss", loss, prog_bar=True, sync_dist=True)
        acc_metric(preds, labels)
        self.log(f"{stage}_acc", acc_metric, prog_bar=True, sync_dist=True)
        f1_metric(preds, labels)
        self.log(f"{stage}_f1", f1_metric, prog_bar=True, sync_dist=True)

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss``, ``val_acc`` and ``val_f1`` for one validation batch."""
        self._shared_eval_step(batch, "val", self.valid_acc, self.f1_metric)

    def test_step(self, batch, batch_idx):
        """Log ``test_loss``, ``test_acc`` and ``test_f1`` for one test batch."""
        self._shared_eval_step(batch, "test", self.test_acc, self.test_f1)

    def configure_optimizers(self):
        """AdamW over all parameters with a fixed learning rate of 1e-3."""
        return torch.optim.AdamW(self.parameters(), lr=0.001, foreach=True)

In [2]:
# Restore the fine-tuned classifier directly onto the CPU. `map_location` remaps the
# checkpoint's tensors while loading, so the weights never have to be materialised
# on the device they were saved from before being moved.
model = Food101Classifier.load_from_checkpoint(
    "~/SeeFood102/models/levit_256.fb_dist_in1k/checkpoints.ckpt",
    map_location="cpu",
)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [3]:
# Display the module tree of the restored classifier.
model

Food101Classifier(
  (model): LevitDistilled(
    (stem): Stem16(
      (conv1): ConvNorm(
        (linear): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (act1): Hardswish()
      (conv2): ConvNorm(
        (linear): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (act2): Hardswish()
      (conv3): ConvNorm(
        (linear): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (act3): Hardswish()
      (conv4): ConvNorm(
        (linear): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=

In [4]:
from torchvision import transforms
# Evaluation-time preprocessing: resize to 256, centre-crop to 224x224,
# convert to a tensor, then normalise each channel with the usual ImageNet statistics.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

_resize_and_crop = [transforms.Resize(256), transforms.CenterCrop(224)]
_to_normalised_tensor = [
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
]
preprocess = transforms.Compose(_resize_and_crop + _to_normalised_tensor)

In [5]:
# Official Food-101 test split, preprocessed with `preprocess` and served in
# fixed-order batches of 128 by four worker processes.
test_data = Food101(
    root='/root/SeeFood102/data',
    split='test',
    transform=preprocess,
)
test_loader = DataLoader(
    dataset=test_data,
    batch_size=128,
    num_workers=4,
    pin_memory=True,
)

In [6]:
from tqdm import tqdm


def test_step(model, data_loader, accuracy_fn: Union[torchmetrics.Metric, None] = None):
    """Evaluate ``model`` on ``data_loader`` and return its overall accuracy.

    Accuracy is accumulated over every sample via the metric's
    ``update``/``compute`` cycle, so a smaller final batch carries weight in
    proportion to its size (averaging per-batch accuracies would over-weight it).

    Args:
        model: Module to evaluate. It is switched to eval mode and frozen
            (``model.freeze()``), and it stays that way after the call.
        data_loader: Iterable of ``(images, labels)`` batches; must support ``len``.
        accuracy_fn: Metric used to score predictions. When ``None``, a fresh
            101-class multiclass ``torchmetrics.Accuracy`` is created for this
            call. A metric passed in by the caller is reset first, so the
            result reflects only ``data_loader``.

    Returns:
        The accuracy as a CPU tensor.
    """
    # Build the default here, not in the signature: a Metric is stateful, and a
    # default-argument instance would be created once and shared by every call.
    if accuracy_fn is None:
        accuracy_fn = torchmetrics.Accuracy('multiclass', num_classes=101)
    accuracy_fn.reset()

    model.eval()
    model.freeze()
    # Turn on inference context manager
    with torch.inference_mode():
        for images, labels in tqdm(data_loader,
                                    total=len(data_loader),
                                    desc='Making predictions:'):
            # 1. Forward pass
            preds = model(images)

            # 2. Accumulate correct/total counts for this batch
            accuracy_fn.update(preds.argmax(dim=1), labels)

        # Accuracy over all samples seen; there is no division by the batch count.
        test_acc = accuracy_fn.compute()
        print(f"Test accuracy: {test_acc:.2f}")
    return test_acc.cpu()

In [7]:
# Score the restored classifier on the Food-101 test loader; the returned
# accuracy tensor is shown as the cell output.
test_step(model, test_loader)

Making predictions:: 100%|██████████| 198/198 [02:40<00:00,  1.23it/s]


Test accuracy: 0.66


tensor(0.6628)