In [1]:
import os

import torch
from tqdm import tqdm
from ffcv.fields import BytesField, IntField, RGBImageField
from ffcv.writer import DatasetWriter

from data_utils.data_stats import *
from data_utils.dataloader import get_loader
from utils.metrics import topk_acc, real_acc, AverageMeter
from models.networks import get_model
from data_utils.dataset_to_beton import get_dataset

from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms

import ast
import pickle
import timm
import matplotlib.pyplot as plt
import tarfile
from torchvision.io import read_file
from PIL import Image
from imagenet1k import Imagenet1k
from classes import IMAGENET2012_CLASSES

  from .autonotebook import tqdm as notebook_tqdm


In [18]:
dataset = 'imagenet'                 # One of cifar10, cifar100, stl10, imagenet or imagenet21
architecture = 'B_12-Wi_1024'
data_resolution = 224                # Resolution of data as it is stored
crop_resolution = 224               # Resolution of fine-tuned model (64 for all models we provide)
num_classes = CLASS_DICT[dataset]
data_path = './beton/'
eval_batch_size = 400
checkpoint = 'in21k_imagenet'  #'in21k_cifar100'        # This means you want the network pre-trained on ImageNet21k and finetuned on CIFAR10

#set these if you dont use mlp
ttimmModel = True #set this to true if you use a tim model
timmodelName = "resnet18.a3_in1k" #resnet18.a3_in1k , vit_small_patch16_224.augreg_in1k

In [4]:
# If you did not yet, produce .beton file for CIFAR10 (check README for how to do that for ImageNet)
def create_beton(dataset, mode, data_path, res):
    dataset = get_dataset(dataset, mode, data_path)

    write_path = os.path.join(
        write_path, dataset, mode, f"{mode}_{res}.beton"
    )

    os.makedirs(os.path.dirname(write_path), exist_ok=True)

    writer = DatasetWriter(
        write_path,
        {
            "image": RGBImageField(write_mode="smart", max_resolution=res),
            "label": IntField(),
        },
        num_workers=0,
    )

    writer.from_indexed_dataset(dataset, chunksize=100)

path = "C:/mlp/scaling_mlps/beton/imagenetOriginal/val"
create_beton(dataset, 'test', path, data_resolution)

FileNotFoundError: Couldn't find any class folder in C:/mlp/scaling_mlps/beton/imagenetOriginal/val.

In [19]:
torch.backends.cuda.matmul.allow_tf32 = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


if ttimmModel:
    model = timm.create_model(timmodelName, pretrained=True)
    data_cfg = timm.data.resolve_data_config(model.pretrained_cfg)

    #set the transform to the transform of the model
    transform = timm.data.create_transform(**data_cfg)

else:
    model = get_model(architecture=architecture, resolution=crop_resolution, num_classes=CLASS_DICT[dataset],
                  checkpoint=checkpoint)
model.cuda()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, m

In [67]:
# Get the test loader
from torchvision.transforms import ToTensor
data_path = "C:\\mlp\\scaling_mlps\\beton"
loader = get_loader(
    dataset,
    bs=eval_batch_size,
    mode="test",
    augment=False,
    dev=device,
    mixup=0.0,
    data_path=data_path,
    data_resolution=data_resolution,
    crop_resolution=crop_resolution,
)

len(loader)


Loading C:\mlp\scaling_mlps\beton\imagenet\ffcv\val\val_64.beton


625

In [20]:
#NEED TO RUN FOR IMAGENET DATASETS

#create wordnet to label dict
newdict = {}
currentind = 0
for x in IMAGENET2012_CLASSES:
    newdict[x] = currentind
    currentind +=1

In [21]:
#create custom dataset for Imagenet A
#loads ImagenetA as loader

mean = MEAN_DICT["imagenet"]
std = STD_DICT["imagenet"]


if ttimmModel == False:
    transform =transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
                transforms.Resize((data_resolution, data_resolution))
            ])


pathToImagenet = 'C:/mlp/scaling_mlps/beton/imageneta/indexed/imagenet-a'
dataset = ImageFolder(root = pathToImagenet, transform=transform)

d,v = dataset[100]
dataset.class_to_idx = newdict
loader = DataLoader(dataset,batch_size=eval_batch_size)



In [70]:
#create imagenet dataset downloaded from huggingface
#loads imagenet as loader. 
#The getitem method is really slow but that way the pictures dont get read into the ram 

mean = MEAN_DICT["imagenet"]
std = STD_DICT["imagenet"]

class Imnet(Dataset):

    def __init__(self, path):
        self.imnet = Imagenet1k()

        with open('wordnetToLabel.txt', 'r') as file:
            data = file.read()

        wordnetToLabel = ast.literal_eval(data) 
        self.class_to_idx = wordnetToLabel


        if ttimmModel == False:
            self.transform = transforms.Compose([
                transforms.Normalize(mean, std),
                transforms.Resize((data_resolution, data_resolution))
                
            ])
        else:
            self.transform = transform

        self.files = os.listdir(path)
        
        self.path = path
    
    def __len__(self):
        return 50000
    
    def __getitem__(self, idx):
        image_path = os.path.join(self.path, self.files[idx])
        image = Image.open(image_path)
        if image.mode == 'L':
            image = image.convert('RGB')
        im = self.transform(image)
        synset_id = self.files[idx][-14:-5]
        label = newdict[synset_id]
        sample = [ im.to("cuda"), torch.tensor(label).to("cuda")]

        return sample
    

path = "C:/mlp/scaling_mlps/beton/imagenetOriginal/val"
ImDataset = Imnet(path)
loader = DataLoader(ImDataset,batch_size=eval_batch_size)
print(len(loader))


50000
625


In [16]:
# Define a test function that evaluates test accuracy
@torch.no_grad()
def test(model, loader):
    debug = True

    model.eval()
    total_acc, total_top5 = AverageMeter(), AverageMeter()

    for ims, targs in tqdm(loader, desc="Evaluation"):

        #added to("cuda") to add data to gpu
        ims = ims.cuda()
        if ttimmModel == False:
            ims = torch.reshape(ims, (ims.shape[0], -1))


        
        preds = model(ims).cuda()
        
   
      
        targs = targs.to("cuda")
        if dataset != 'imagenet_real':
            acc, top5 = topk_acc(preds, targs, k=5, avg=True)
        else:
            acc = real_acc(preds, targs, k=5, avg=True)
            top5 = 0

        total_acc.update(acc, ims.shape[0])
        total_top5.update(top5, ims.shape[0])


    return (
        total_acc.get_avg(percentage=True),
        total_top5.get_avg(percentage=True),
    )





In [22]:
test_acc, test_top5 = test(model, loader)

# Print all the stats
print("Test Accuracy        ", "{:.4f}".format(test_acc))
print("Top 5 Test Accuracy          ", "{:.4f}".format(test_top5))

Evaluation: 100%|██████████| 19/19 [00:39<00:00,  2.09s/it]

Test Accuracy         0.1200
Top 5 Test Accuracy           0.6133



