In [1]:
import timm
import torch
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.utils import accuracy, AverageMeter
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Resize, ToTensor, InterpolationMode, CenterCrop, Normalize
from torchvision.datasets import ImageFolder
from kaggle.api.kaggle_api_extended import KaggleApi
import os
from config import CACHE_DIR

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
BATCH_SIZE = 128
DATA_PATH = os.path.join(CACHE_DIR, 'imagenet1k_val')
os.makedirs(DATA_PATH, exist_ok=True)
IMG_SIZE = 224
INTERPOLATION = InterpolationMode.BICUBIC
PIN_MEMORY = True
NUM_WORKERS = 8
SHUFFLE = False
AMP_ENABLE = True

In [None]:
api = KaggleApi()
api.authenticate()
api.dataset_download_files(
    dataset='sautkin/imagenet1kvalid',
    path=DATA_PATH,
    unzip=True,            
    quiet=False            
)

In [4]:
size = int((256 / 224) * IMG_SIZE)
transform = Compose([
    Resize(size, interpolation=INTERPOLATION),
    CenterCrop(IMG_SIZE),
    ToTensor(),
    Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
])

In [5]:
dataset = ImageFolder(root=DATA_PATH, transform=transform)
data_loader_128 = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=SHUFFLE, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, drop_last=False)
data_loader_32 = DataLoader(dataset, batch_size=32, shuffle=SHUFFLE, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, drop_last=False)

In [6]:
@torch.no_grad()
def validate(model_name):
    model = timm.create_model(model_name, pretrained=True)
    model.to(device)
    model.eval()
    
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()
    
    if model_name == 'swin_large_patch4_window7_224':
        data_loader = data_loader_32
    else:
        data_loader = data_loader_128
    
    print_freq = len(data_loader) // 10

    for idx, (images, target) in enumerate(data_loader):
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # compute output
        with torch.cuda.amp.autocast(enabled=AMP_ENABLE):
            output = model(images)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        acc1_meter.update(acc1.item(), target.size(0))
        acc5_meter.update(acc5.item(), target.size(0))

        if idx % print_freq == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            print(
                f'Evaluated [{idx}/{len(data_loader)}]\t'
                f'Top 1 Accuracy: {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
                f'Top 5 Accuracy: {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
                f'Memory Used: {memory_used:.0f} MB')
            
    del model
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    
    print(f'*** Top 1 Accuracy {acc1_meter.avg:.3f}, Top 5 Accuracy {acc5_meter.avg:.3f} ***')
    return acc1_meter.avg, acc5_meter.avg

In [7]:
swin_b_acc1, swin_b_acc5 = validate('swin_base_patch4_window7_224')

  with torch.cuda.amp.autocast(enabled=AMP_ENABLE):


Evaluated [0/391]	Top 1 Accuracy: 98.438 (98.438)	Top 5 Accuracy: 98.438 (98.438)	Memory Used: 2441 MB
Evaluated [39/391]	Top 1 Accuracy: 88.281 (88.457)	Top 5 Accuracy: 99.219 (98.496)	Memory Used: 2441 MB
Evaluated [78/391]	Top 1 Accuracy: 84.375 (88.370)	Top 5 Accuracy: 99.219 (98.161)	Memory Used: 2441 MB
Evaluated [117/391]	Top 1 Accuracy: 95.312 (87.606)	Top 5 Accuracy: 100.000 (98.073)	Memory Used: 2441 MB
Evaluated [156/391]	Top 1 Accuracy: 77.344 (87.694)	Top 5 Accuracy: 96.875 (98.184)	Memory Used: 2441 MB
Evaluated [195/391]	Top 1 Accuracy: 79.688 (86.539)	Top 5 Accuracy: 94.531 (97.824)	Memory Used: 2441 MB
Evaluated [234/391]	Top 1 Accuracy: 78.125 (86.240)	Top 5 Accuracy: 92.188 (97.716)	Memory Used: 2441 MB
Evaluated [273/391]	Top 1 Accuracy: 92.969 (85.795)	Top 5 Accuracy: 98.438 (97.602)	Memory Used: 2441 MB
Evaluated [312/391]	Top 1 Accuracy: 82.812 (85.663)	Top 5 Accuracy: 93.750 (97.526)	Memory Used: 2441 MB
Evaluated [351/391]	Top 1 Accuracy: 72.656 (85.254)	Top 5 

In [8]:
swin_s_acc1, swin_s_acc5 = validate('swin_small_patch4_window7_224')

  with torch.cuda.amp.autocast(enabled=AMP_ENABLE):


Evaluated [0/391]	Top 1 Accuracy: 96.094 (96.094)	Top 5 Accuracy: 97.656 (97.656)	Memory Used: 2441 MB
Evaluated [39/391]	Top 1 Accuracy: 87.500 (87.793)	Top 5 Accuracy: 100.000 (98.086)	Memory Used: 2441 MB
Evaluated [78/391]	Top 1 Accuracy: 78.125 (86.748)	Top 5 Accuracy: 98.438 (97.765)	Memory Used: 2441 MB
Evaluated [117/391]	Top 1 Accuracy: 94.531 (85.765)	Top 5 Accuracy: 100.000 (97.623)	Memory Used: 2441 MB
Evaluated [156/391]	Top 1 Accuracy: 75.000 (86.062)	Top 5 Accuracy: 96.875 (97.741)	Memory Used: 2441 MB
Evaluated [195/391]	Top 1 Accuracy: 75.000 (84.774)	Top 5 Accuracy: 89.844 (97.373)	Memory Used: 2441 MB
Evaluated [234/391]	Top 1 Accuracy: 78.906 (84.468)	Top 5 Accuracy: 90.625 (97.234)	Memory Used: 2441 MB
Evaluated [273/391]	Top 1 Accuracy: 87.500 (83.873)	Top 5 Accuracy: 97.656 (97.129)	Memory Used: 2441 MB
Evaluated [312/391]	Top 1 Accuracy: 77.344 (83.766)	Top 5 Accuracy: 92.188 (97.005)	Memory Used: 2441 MB
Evaluated [351/391]	Top 1 Accuracy: 64.062 (83.290)	Top 5

In [9]:
swin_t_acc1, swin_t_acc5 = validate('swin_tiny_patch4_window7_224')

  with torch.cuda.amp.autocast(enabled=AMP_ENABLE):


Evaluated [0/391]	Top 1 Accuracy: 92.969 (92.969)	Top 5 Accuracy: 98.438 (98.438)	Memory Used: 2441 MB
Evaluated [39/391]	Top 1 Accuracy: 89.062 (86.680)	Top 5 Accuracy: 99.219 (97.539)	Memory Used: 2441 MB
Evaluated [78/391]	Top 1 Accuracy: 75.000 (86.155)	Top 5 Accuracy: 97.656 (97.280)	Memory Used: 2441 MB
Evaluated [117/391]	Top 1 Accuracy: 94.531 (85.818)	Top 5 Accuracy: 99.219 (97.338)	Memory Used: 2441 MB
Evaluated [156/391]	Top 1 Accuracy: 72.656 (85.783)	Top 5 Accuracy: 93.750 (97.358)	Memory Used: 2441 MB
Evaluated [195/391]	Top 1 Accuracy: 64.844 (83.917)	Top 5 Accuracy: 87.500 (96.596)	Memory Used: 2441 MB
Evaluated [234/391]	Top 1 Accuracy: 63.281 (83.235)	Top 5 Accuracy: 85.938 (96.280)	Memory Used: 2441 MB
Evaluated [273/391]	Top 1 Accuracy: 85.156 (82.376)	Top 5 Accuracy: 96.094 (95.980)	Memory Used: 2441 MB
Evaluated [312/391]	Top 1 Accuracy: 81.250 (81.916)	Top 5 Accuracy: 89.062 (95.699)	Memory Used: 2441 MB
Evaluated [351/391]	Top 1 Accuracy: 65.625 (81.292)	Top 5 A

In [10]:
swin_l_acc1, swin_l_acc5 = validate('swin_large_patch4_window7_224')

  with torch.cuda.amp.autocast(enabled=AMP_ENABLE):


Evaluated [0/1563]	Top 1 Accuracy: 96.875 (96.875)	Top 5 Accuracy: 96.875 (96.875)	Memory Used: 2441 MB
Evaluated [156/1563]	Top 1 Accuracy: 96.875 (89.172)	Top 5 Accuracy: 100.000 (98.786)	Memory Used: 2441 MB
Evaluated [312/1563]	Top 1 Accuracy: 90.625 (89.237)	Top 5 Accuracy: 100.000 (98.472)	Memory Used: 2441 MB
Evaluated [468/1563]	Top 1 Accuracy: 96.875 (88.373)	Top 5 Accuracy: 100.000 (98.274)	Memory Used: 2441 MB
Evaluated [624/1563]	Top 1 Accuracy: 100.000 (88.565)	Top 5 Accuracy: 100.000 (98.370)	Memory Used: 2441 MB
Evaluated [780/1563]	Top 1 Accuracy: 81.250 (87.624)	Top 5 Accuracy: 93.750 (98.155)	Memory Used: 2441 MB
Evaluated [936/1563]	Top 1 Accuracy: 87.500 (87.250)	Top 5 Accuracy: 93.750 (98.006)	Memory Used: 2441 MB
Evaluated [1092/1563]	Top 1 Accuracy: 100.000 (86.739)	Top 5 Accuracy: 100.000 (97.919)	Memory Used: 2441 MB
Evaluated [1248/1563]	Top 1 Accuracy: 75.000 (86.684)	Top 5 Accuracy: 96.875 (97.888)	Memory Used: 2441 MB
Evaluated [1404/1563]	Top 1 Accuracy: 6

| Model  | Input Size | Params | Our Top-1 Acc (%) | Paper Top-1 Acc (%) | Top-1 Acc Δ (%) | Our Top-5 Acc (%) | Paper Top-5 Acc (%) | Top-5 Acc Δ (%) |
|--------|------------|--------|-------------------|---------------------|-----------------|-------------------|---------------------|-----------------|
| Swin-T | 224²       | 29M    | 81.2              | 81.2                | +0.0            | 95.5              | 95.5                | +0.0            |
| Swin-S | 224²       | 50M    | 83.3              | 83.2                | +0.1            | 96.9              | 96.2                | +0.7            |
| Swin-B | 224²       | 88M    | 85.2              | 85.5                | -0.3            | 97.5              | 96.5                | +1.0            |
| Swin-L | 224²       | 197M   | 86.2              | 86.3                | -0.1            | 97.9              | 97.9                | +0.0            |