In [1]:
import os

import torch
from tqdm import tqdm
from ffcv.fields import BytesField, IntField, RGBImageField
from ffcv.writer import DatasetWriter

from data_utils.data_stats import *
from data_utils.dataloader import get_loader
from utils.metrics import topk_acc, real_acc, AverageMeter
from models.networks import get_model
from data_utils.dataset_to_beton import get_dataset

OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.
objc[86151]: Class CaptureDelegate is implemented in both /opt/homebrew/Cellar/opencv/4.10.0_12/lib/libopencv_videoio.4.10.0.dylib (0x11e2f4880) and /Users/fxx1047/src/scaling_mlps/.venv/lib/python3.10/site-packages/cv2/cv2.abi3.so (0x2895d65d8). One of the two will be used. Which one is undefined.
objc[86151]: Class CVWindow is implemented in both /opt/homebrew/Cellar/opencv/4.10.0_12/lib/libopencv_highgui.4.10.0.dylib (0x11de30b28) and /Users/fxx1047/src/scaling_mlps/.venv/lib/python3.10/site-packages/cv2/cv2.abi3.so (0x2895d6628). One of the two will be used. Which one is undefined.
objc[86151]: Class CVView is implemented in both /opt/homebrew/Cellar/opencv/4.10.0_12/lib/libopencv_highgui.4.10.0.dylib (0x11de30b50) and /Users/fxx1047/src/scaling_mlps/.venv/lib/python3.10/site-packages/cv2/cv2.abi3.so (0x2895d6650). One of the two will be used. Which one is undefined.
objc[86151]: Class 

In [2]:
# --- Data configuration -------------------------------------------------------
# One of cifar10, cifar100, stl10, imagenet or imagenet21
# (the evaluation function below additionally special-cases 'imagenet_real').
dataset = 'cifar10'
architecture = 'B_6-Wi_512'
data_resolution = 32                # Resolution of data as it is stored
crop_resolution = 64                # Resolution of fine-tuned model (64 for all models we provide)
num_classes = CLASS_DICT[dataset]   # Number of output classes for the chosen dataset

# --- I/O and evaluation ------------------------------------------------------
data_path = './beton/'
eval_batch_size = 1024

# --- Pre-trained weights -----------------------------------------------------
# Network pre-trained on ImageNet21k and fine-tuned on CIFAR10.
checkpoint = 'in21k_cifar10'

In [3]:
# Permit TF32 for CUDA float32 matmuls (faster on recent NVIDIA GPUs, slightly
# lower precision); this flag only affects CUDA matmuls.
torch.backends.cuda.matmul.allow_tf32 = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Define the model and load the pre-trained weights named by the config cell.
# Reuse `checkpoint` and `num_classes` from the config instead of re-hardcoding
# them, so editing the config actually changes which weights are loaded.
model = get_model(
    architecture=architecture,
    resolution=crop_resolution,
    num_classes=num_classes,
    checkpoint=checkpoint,
)
# The test loader is built with `dev=device`, so the model must live on the same
# device. `.to` is harmless if get_model already placed it there.
model = model.to(device)

Weights already downloaded
Load_state output <All keys matched successfully>


  for k, v in torch.load(weight_path, map_location=device).items()


In [4]:
# Build the evaluation loader over the test split.
loader = get_loader(
    dataset,
    mode="test",
    bs=eval_batch_size,
    dev=device,
    # Evaluation uses the raw test images: no augmentation, no mixup.
    augment=False,
    mixup=0.0,
    # Location of the .beton files, the resolution they are stored at, and the
    # resolution presumably served to the model (should match its `resolution`).
    data_path=data_path,
    data_resolution=data_resolution,
    crop_resolution=crop_resolution,
)

Loading ./beton/cifar10/test/test_32.beton


In [5]:
# Define a test function that evaluates test accuracy
@torch.no_grad()
def test(model, loader, dataset_name=None):
    """Evaluate ``model`` over every batch of ``loader``; return (top-1, top-5) accuracy.

    Each image batch is flattened to ``(batch, -1)`` before the forward pass
    (the MLP models used here presumably expect flat input vectors).

    Args:
        model: network to evaluate; it is switched to eval mode here, and
            gradients are disabled by the ``torch.no_grad`` decorator.
        loader: iterable of ``(images, targets)`` batches, presumably already
            on the model's device.
        dataset_name: selects the metric. ``'imagenet_real'`` uses ``real_acc``
            (top-5 is not computed and is recorded as 0). Any other name uses
            ``topk_acc``. When omitted, the notebook-level ``dataset`` variable
            is used, matching the previous behaviour.

    Returns:
        Tuple ``(top1, top5)`` from ``AverageMeter.get_avg(percentage=True)``.
        Each batch is weighted by its size via ``update(value, batch_size)``.
    """
    if dataset_name is None:
        # Backward-compatible fallback to the config cell's global.
        dataset_name = dataset
    # The metric choice does not change between batches, so decide it once.
    use_real_labels = dataset_name == 'imagenet_real'

    model.eval()
    total_acc, total_top5 = AverageMeter(), AverageMeter()

    for ims, targs in tqdm(loader, desc="Evaluation"):
        batch_size = ims.shape[0]
        ims = torch.reshape(ims, (batch_size, -1))
        preds = model(ims)

        if use_real_labels:
            acc = real_acc(preds, targs, k=5, avg=True)
            top5 = 0  # only a single accuracy value is taken from real_acc
        else:
            acc, top5 = topk_acc(preds, targs, k=5, avg=True)

        total_acc.update(acc, batch_size)
        total_top5.update(top5, batch_size)

    return (
        total_acc.get_avg(percentage=True),
        total_top5.get_avg(percentage=True),
    )

In [6]:
# Run the evaluation once and keep both metrics.
test_acc, test_top5 = test(model, loader)

# Print all the stats, four decimal places each.
print("Test Accuracy        ", f"{test_acc:.4f}")
print("Top 5 Test Accuracy          ", f"{test_top5:.4f}")

Evaluation: 100%|██████████| 10/10 [00:01<00:00,  5.69it/s]

Test Accuracy         89.2200
Top 5 Test Accuracy           99.4400



