# Getting inference working on a timm-pretrained model

In [1]:
# --- Notebook configuration: edit these values, then Restart Kernel & Run All ---

# Directory handed to the CIFAR10 dataset as its `root` argument.
ROOT_DIR = "~/Code/Projects/PVG Experiments/data/image_classification/cifar10/raw/"

# Model identifier passed to timm's `create_model`.
MODEL_NAME = "hf_hub:SamAdamDay/resnet18_cifar10"

# When True, the device cell selects the CPU even if CUDA is available.
FORCE_CPU = True

# Number of samples per batch in the evaluation DataLoader.
BATCH_SIZE = 256

## Setup

In [2]:
from contextlib import suppress

import torch
from torch.utils.data import DataLoader

from torchvision.datasets import CIFAR10

import timm
from timm import create_model
from timm.models import ResNet
from timm.data import (
    resolve_data_config,
    create_transform,
    create_dataset,
    create_loader,
)
from timm.utils import accuracy

In [3]:
# Use CUDA only when it is available AND the FORCE_CPU override is off;
# otherwise fall back to the CPU. The bare name on the last line displays the choice.
device = torch.device(
    "cuda" if (torch.cuda.is_available() and not FORCE_CPU) else "cpu"
)
device

device(type='cuda')

## Load

In [4]:
# Instantiate MODEL_NAME with its pretrained weights and place it on `device`.
# nn.Module.to() moves the parameters in place and returns the module itself,
# so chaining it keeps `model` bound to the same object.
model = create_model(MODEL_NAME, pretrained=True).to(device)
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, m

In [5]:
# Number of output classes reported by the loaded model.
# NOTE(review): CIFAR-10 has 10 classes — confirm the value displayed here
# matches what the checkpoint was actually trained for.
num_classes = model.num_classes
num_classes

1000

In [6]:
# Ask timm for the preprocessing settings (input size, interpolation, mean/std,
# crop) associated with this model, so the transform is built from them.
# NOTE(review): use_test_size=True presumably prefers the model's test-time
# input size over its train-time one — confirm against the timm docs.
data_config = resolve_data_config(model=model, use_test_size=True, verbose=True)
data_config

{'input_size': [3, 288, 288],
 'interpolation': 'bicubic',
 'mean': [0.485, 0.456, 0.406],
 'std': [0.229, 0.224, 0.225],
 'crop_pct': 1.0,
 'crop_mode': 'center'}

In [7]:
# Build the evaluation-time preprocessing pipeline from the resolved data config.
#
# Only the arguments that shape the eval pipeline are passed. The training-only
# augmentation knobs (flips, colour jitter, random erasing, auto-augment, ...)
# are left at timm's defaults: with is_training=False they are not applied, and
# spelling them out (e.g. hflip=0.5) wrongly suggests random augmentation at
# evaluation time.
transform = create_transform(
    input_size=data_config["input_size"],
    is_training=False,
    interpolation=data_config["interpolation"],
    mean=data_config["mean"],
    std=data_config["std"],
    crop_pct=data_config["crop_pct"],
    crop_mode=data_config["crop_mode"],
)
transform

Compose(
    Resize(size=288, interpolation=bicubic, max_size=None, antialias=True)
    CenterCrop(size=[288, 288])
    ToTensor()
    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)

In [8]:
# CIFAR-10 test split (train=False) loaded through torchvision, with the timm
# evaluation transform applied to every image. download=True fetches the data
# into ROOT_DIR if it is not already there.
dataset = CIFAR10(
    root=ROOT_DIR,
    train=False,
    download=True,
    transform=transform,
)
dataset

Files already downloaded and verified


Dataset CIFAR10
    Number of datapoints: 10000
    Root location: /home/sam/Code/Projects/PVG Experiments/data/image_classification/cifar10/raw/
    Split: Test
    StandardTransform
Transform: Compose(
               Resize(size=288, interpolation=bicubic, max_size=None, antialias=True)
               CenterCrop(size=[288, 288])
               ToTensor()
               Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
           )

In [9]:
# Plain PyTorch DataLoader over the already-transformed dataset.
# shuffle=False keeps a fixed sample order and drop_last=False keeps the final
# partial batch, so every test sample is evaluated exactly once.
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=1,
    drop_last=False,
)
loader

<torch.utils.data.dataloader.DataLoader at 0x7f7aa012fa10>

In [10]:
# Sanity check: display the transform attached to the dataset.
dataset.transform

Compose(
    Resize(size=288, interpolation=bicubic, max_size=None, antialias=True)
    CenterCrop(size=[288, 288])
    ToTensor()
    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)

## Eval

In [11]:
# Evaluate the model over the whole loader, printing per-batch top-1/top-5
# accuracy and a sample-weighted overall figure at the end.
model.eval()  # put layers such as BatchNorm into inference behaviour

top1_sum = 0.0  # running sum of (batch top-1 accuracy * batch size)
top5_sum = 0.0  # running sum of (batch top-5 accuracy * batch size)
num_seen = 0    # number of samples evaluated so far

with torch.no_grad():  # gradients are not needed for evaluation
    # The batch is named `images`, not `input`, so the built-in input() is not
    # shadowed in the kernel's global namespace.
    for batch_idx, (images, target) in enumerate(loader):
        target = target.to(device)
        images = images.to(device)
        output = model(images)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        print(  # noqa: T201
            f"Batch [{batch_idx+1}/{len(loader)}]: "
            f"Top-1 accuracy: {acc1.item()}, "
            f"Top-5 accuracy: {acc5.item()}"
        )

        # Weight each batch by its size so a smaller final batch does not skew
        # the overall mean.
        n_batch = target.size(0)
        top1_sum += acc1.item() * n_batch
        top5_sum += acc5.item() * n_batch
        num_seen += n_batch

# Skip the summary for an empty loader to avoid dividing by zero.
if num_seen:
    print(  # noqa: T201
        f"Overall [{num_seen} samples]: "
        f"Top-1 accuracy: {top1_sum / num_seen}, "
        f"Top-5 accuracy: {top5_sum / num_seen}"
    )

Batch [1/40]: Top-1 accuracy: 93.75, Top-5 accuracy: 99.609375
Batch [2/40]: Top-1 accuracy: 92.96875, Top-5 accuracy: 100.0


KeyboardInterrupt: 