In [1]:
!mkdir logits

import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [2]:
import os
from tqdm.notebook import tqdm

def save_logits(model, dataloader):
    model.eval()
    device = model.device
    tqdm_dataloader = tqdm(dataloader, desc=model.name,
                           leave=True)
    all_logits = []
    total_acc = 0
    with torch.no_grad():
        for batch_x, batch_y in tqdm_dataloader:
            logits = model(batch_x.to(model.device))
            all_logits.append(logits.cpu())
            acc = (logits.argmax(dim=1) == batch_y.to(model.device))
            total_acc += acc.sum().item()
            acc = acc.float().mean().item()
            tqdm_dataloader.set_postfix({'batch_acc': acc})
    total_acc *= 100 / len(dataloader.dataset)
    print(f'{model.name} accuracy: {total_acc:.2f}')
    all_logits = torch.cat(all_logits)
    torch.save(all_logits, os.path.join('logits', dataloader.name, model.name + '.pt'))

# CIFAR-10 v1 °˖✧◝(⁰▿⁰)◜✧˖°

In [7]:
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2471, 0.2435, 0.2616]),
])

datadir = './datasets/cifar10/'
dataset = CIFAR10(train=False, download=True,
                  root=datadir, transform=transform)
dataloader = DataLoader(dataset, batch_size=256, num_workers=2, shuffle=False)
dataloader.name = 'cifar10/v1'

Files already downloaded and verified


In [4]:
!mkdir logits/cifar10
with open('logits/cifar10/targets.txt', 'w') as fout:
    fout.write('\n'.join(str(target) for target in dataset.targets))

In [5]:
# https://github.com/huyvnphan/PyTorch_CIFAR10
!git clone https://github.com/huyvnphan/PyTorch_CIFAR10.git
!wget -q https://rutgers.box.com/shared/static/gkw08ecs797j2et1ksmbg1w5t3idf5r5.zip -O huyvnphan_models.zip
!unzip huyvnphan_models.zip -d PyTorch_CIFAR10/cifar10_models

Cloning into 'PyTorch_CIFAR10'...
remote: Enumerating objects: 640, done.[K
remote: Counting objects: 100% (88/88), done.[K
remote: Compressing objects: 100% (68/68), done.[K
remote: Total 640 (delta 41), reused 49 (delta 20), pack-reused 552[K
Receiving objects: 100% (640/640), 6.59 MiB | 12.38 MiB/s, done.
Resolving deltas: 100% (224/224), done.
Archive:  huyvnphan_models.zip
   creating: PyTorch_CIFAR10/cifar10_models/state_dicts/
  inflating: PyTorch_CIFAR10/cifar10_models/state_dicts/googlenet.pt  
  inflating: PyTorch_CIFAR10/cifar10_models/state_dicts/vgg11_bn.pt  
  inflating: PyTorch_CIFAR10/cifar10_models/state_dicts/vgg13_bn.pt  
  inflating: PyTorch_CIFAR10/cifar10_models/state_dicts/resnet18.pt  
  inflating: PyTorch_CIFAR10/cifar10_models/state_dicts/vgg19_bn.pt  
  inflating: PyTorch_CIFAR10/cifar10_models/state_dicts/vgg16_bn.pt  
  inflating: PyTorch_CIFAR10/cifar10_models/state_dicts/mobilenet_v2.pt  
  inflating: PyTorch_CIFAR10/cifar10_models/state_dicts/incepti

In [6]:
!mkdir logits/cifar10/v1

In [8]:
from PyTorch_CIFAR10.cifar10_models.densenet import densenet121, densenet161, densenet169
from PyTorch_CIFAR10.cifar10_models.googlenet import googlenet
from PyTorch_CIFAR10.cifar10_models.inception import inception_v3
from PyTorch_CIFAR10.cifar10_models.mobilenetv2 import mobilenet_v2
from PyTorch_CIFAR10.cifar10_models.resnet import resnet18, resnet34, resnet50
from PyTorch_CIFAR10.cifar10_models.vgg import vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn

model_names = ['densenet121', 'densenet161', 'densenet169',
               'googlenet',
               'inception_v3',
               'mobilenet_v2',
               'resnet18', 'resnet34', 'resnet50',
               'vgg11_bn', 'vgg13_bn', 'vgg16_bn', 'vgg19_bn']

for name in model_names:
    Model = globals()[name]
    model = Model(pretrained=True).to(device)
    model.name = name
    model.device = device
    save_logits(model, dataloader)

HBox(children=(FloatProgress(value=0.0, description='densenet121', max=40.0, style=ProgressStyle(description_w…


densenet121 accuracy: 94.06


HBox(children=(FloatProgress(value=0.0, description='densenet161', max=40.0, style=ProgressStyle(description_w…


densenet161 accuracy: 94.07


HBox(children=(FloatProgress(value=0.0, description='densenet169', max=40.0, style=ProgressStyle(description_w…


densenet169 accuracy: 94.05


HBox(children=(FloatProgress(value=0.0, description='googlenet', max=40.0, style=ProgressStyle(description_wid…


googlenet accuracy: 92.85


HBox(children=(FloatProgress(value=0.0, description='inception_v3', max=40.0, style=ProgressStyle(description_…


inception_v3 accuracy: 93.74


HBox(children=(FloatProgress(value=0.0, description='mobilenet_v2', max=40.0, style=ProgressStyle(description_…


mobilenet_v2 accuracy: 93.91


HBox(children=(FloatProgress(value=0.0, description='resnet18', max=40.0, style=ProgressStyle(description_widt…


resnet18 accuracy: 93.07


HBox(children=(FloatProgress(value=0.0, description='resnet34', max=40.0, style=ProgressStyle(description_widt…


resnet34 accuracy: 93.33


HBox(children=(FloatProgress(value=0.0, description='resnet50', max=40.0, style=ProgressStyle(description_widt…


resnet50 accuracy: 93.65


HBox(children=(FloatProgress(value=0.0, description='vgg11_bn', max=40.0, style=ProgressStyle(description_widt…


vgg11_bn accuracy: 92.39


HBox(children=(FloatProgress(value=0.0, description='vgg13_bn', max=40.0, style=ProgressStyle(description_widt…


vgg13_bn accuracy: 94.21


HBox(children=(FloatProgress(value=0.0, description='vgg16_bn', max=40.0, style=ProgressStyle(description_widt…


vgg16_bn accuracy: 94.00


HBox(children=(FloatProgress(value=0.0, description='vgg19_bn', max=40.0, style=ProgressStyle(description_widt…


vgg19_bn accuracy: 93.95


In [9]:
del model
torch.cuda.empty_cache()
!rm -rf PyTorch_CIFAR10 huyvnphan_models.zip

# CIFAR-10 v2

In [10]:
# https://github.com/chenyaofo/pytorch-cifar-models

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
])

datadir = './datasets/cifar10/'
dataset = CIFAR10(train=False, download=True,
                  root=datadir, transform=transform)
dataloader = DataLoader(dataset, batch_size=256, num_workers=2, shuffle=False)
dataloader.name = 'cifar10/v2'

Files already downloaded and verified


In [11]:
!mkdir logits/cifar10/v2

In [12]:
model_names = ['mobilenetv2_x0_5', 'mobilenetv2_x1_0', 'mobilenetv2_x1_4',
            #    'repvgg_a0', 'repvgg_a1', 'repvgg_a2',
               'resnet20', 'resnet32', 'resnet44', 'resnet56',
               'shufflenetv2_x0_5', 'shufflenetv2_x1_0', 'shufflenetv2_x1_5', 'shufflenetv2_x2_0',
               'vgg11_bn', 'vgg13_bn', 'vgg16_bn', 'vgg19_bn']
            #    'vit_b16', 'vit_b32', 'vit_h14', 'vit_l16', 'vit_l32']

for name in model_names:
    model = torch.hub.load('chenyaofo/pytorch-cifar-models', f'cifar10_{name}',
                           pretrained=True).to(device)
    model.name = name
    model.device = device
    save_logits(model, dataloader)

Downloading: "https://github.com/chenyaofo/pytorch-cifar-models/archive/master.zip" to /root/.cache/torch/hub/master.zip
Downloading: "https://github.com/chenyaofo/pytorch-cifar-models/releases/download/mobilenetv2/cifar10_mobilenetv2_x0_5-ca14ced9.pt" to /root/.cache/torch/hub/checkpoints/cifar10_mobilenetv2_x0_5-ca14ced9.pt


HBox(children=(FloatProgress(value=0.0, max=2986233.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='mobilenetv2_x0_5', max=40.0, style=ProgressStyle(descript…


mobilenetv2_x0_5 accuracy: 93.12


Using cache found in /root/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master
Downloading: "https://github.com/chenyaofo/pytorch-cifar-models/releases/download/mobilenetv2/cifar10_mobilenetv2_x1_0-fe6a5b48.pt" to /root/.cache/torch/hub/checkpoints/cifar10_mobilenetv2_x1_0-fe6a5b48.pt


HBox(children=(FloatProgress(value=0.0, max=9193273.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='mobilenetv2_x1_0', max=40.0, style=ProgressStyle(descript…


mobilenetv2_x1_0 accuracy: 94.05


Using cache found in /root/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master
Downloading: "https://github.com/chenyaofo/pytorch-cifar-models/releases/download/mobilenetv2/cifar10_mobilenetv2_x1_4-3bbbd6e2.pt" to /root/.cache/torch/hub/checkpoints/cifar10_mobilenetv2_x1_4-3bbbd6e2.pt


HBox(children=(FloatProgress(value=0.0, max=17635961.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='mobilenetv2_x1_4', max=40.0, style=ProgressStyle(descript…


mobilenetv2_x1_4 accuracy: 94.21


Using cache found in /root/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master
Downloading: "https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar10_resnet20-4118986f.pt" to /root/.cache/torch/hub/checkpoints/cifar10_resnet20-4118986f.pt


HBox(children=(FloatProgress(value=0.0, max=1139055.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='resnet20', max=40.0, style=ProgressStyle(description_widt…


resnet20 accuracy: 92.60


Using cache found in /root/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master
Downloading: "https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar10_resnet32-ef93fc4d.pt" to /root/.cache/torch/hub/checkpoints/cifar10_resnet32-ef93fc4d.pt


HBox(children=(FloatProgress(value=0.0, max=1944567.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='resnet32', max=40.0, style=ProgressStyle(description_widt…


resnet32 accuracy: 93.53


Using cache found in /root/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master
Downloading: "https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar10_resnet44-2a3cabcb.pt" to /root/.cache/torch/hub/checkpoints/cifar10_resnet44-2a3cabcb.pt


HBox(children=(FloatProgress(value=0.0, max=2750015.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='resnet44', max=40.0, style=ProgressStyle(description_widt…


resnet44 accuracy: 94.01


Using cache found in /root/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master
Downloading: "https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar10_resnet56-187c023a.pt" to /root/.cache/torch/hub/checkpoints/cifar10_resnet56-187c023a.pt


HBox(children=(FloatProgress(value=0.0, max=3555463.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='resnet56', max=40.0, style=ProgressStyle(description_widt…


resnet56 accuracy: 94.37


Using cache found in /root/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master
Downloading: "https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar10_shufflenetv2_x0_5-1308b4e9.pt" to /root/.cache/torch/hub/checkpoints/cifar10_shufflenetv2_x0_5-1308b4e9.pt


HBox(children=(FloatProgress(value=0.0, max=1554833.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='shufflenetv2_x0_5', max=40.0, style=ProgressStyle(descrip…


shufflenetv2_x0_5 accuracy: 90.65


Using cache found in /root/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master
Downloading: "https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar10_shufflenetv2_x1_0-98807be3.pt" to /root/.cache/torch/hub/checkpoints/cifar10_shufflenetv2_x1_0-98807be3.pt


HBox(children=(FloatProgress(value=0.0, max=5230673.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='shufflenetv2_x1_0', max=40.0, style=ProgressStyle(descrip…


shufflenetv2_x1_0 accuracy: 93.30


Using cache found in /root/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master
Downloading: "https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar10_shufflenetv2_x1_5-296694dd.pt" to /root/.cache/torch/hub/checkpoints/cifar10_shufflenetv2_x1_5-296694dd.pt


HBox(children=(FloatProgress(value=0.0, max=10164241.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='shufflenetv2_x1_5', max=40.0, style=ProgressStyle(descrip…


shufflenetv2_x1_5 accuracy: 93.57


Using cache found in /root/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master
Downloading: "https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar10_shufflenetv2_x2_0-ec31611c.pt" to /root/.cache/torch/hub/checkpoints/cifar10_shufflenetv2_x2_0-ec31611c.pt


HBox(children=(FloatProgress(value=0.0, max=21707473.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='shufflenetv2_x2_0', max=40.0, style=ProgressStyle(descrip…


shufflenetv2_x2_0 accuracy: 93.98


Using cache found in /root/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master
Downloading: "https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar10_vgg11_bn-eaeebf42.pt" to /root/.cache/torch/hub/checkpoints/cifar10_vgg11_bn-eaeebf42.pt


HBox(children=(FloatProgress(value=0.0, max=39068509.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='vgg11_bn', max=40.0, style=ProgressStyle(description_widt…


vgg11_bn accuracy: 92.79


Using cache found in /root/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master
Downloading: "https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar10_vgg13_bn-c01e4a43.pt" to /root/.cache/torch/hub/checkpoints/cifar10_vgg13_bn-c01e4a43.pt


HBox(children=(FloatProgress(value=0.0, max=39814235.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='vgg13_bn', max=40.0, style=ProgressStyle(description_widt…


vgg13_bn accuracy: 94.00


Using cache found in /root/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master
Downloading: "https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar10_vgg16_bn-6ee7ea24.pt" to /root/.cache/torch/hub/checkpoints/cifar10_vgg16_bn-6ee7ea24.pt


HBox(children=(FloatProgress(value=0.0, max=61080472.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='vgg16_bn', max=40.0, style=ProgressStyle(description_widt…


vgg16_bn accuracy: 94.16


Using cache found in /root/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master
Downloading: "https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar10_vgg19_bn-57191229.pt" to /root/.cache/torch/hub/checkpoints/cifar10_vgg19_bn-57191229.pt


HBox(children=(FloatProgress(value=0.0, max=82346709.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='vgg19_bn', max=40.0, style=ProgressStyle(description_widt…


vgg19_bn accuracy: 93.91


# CIFAR-100

In [17]:
from torchvision.datasets import CIFAR100

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.507, 0.4865, 0.4409], [0.2673, 0.2564, 0.2761]),
])

datadir = './datasets/cifar100/'
dataset = CIFAR100(train=False, download=True,
                   root=datadir, transform=transform)
dataloader = DataLoader(dataset, batch_size=256, num_workers=2, shuffle=False)
dataloader.name = 'cifar100'

Files already downloaded and verified


In [14]:
!mkdir logits/cifar100
with open('logits/cifar100/targets.txt', 'w') as fout:
    fout.write('\n'.join(str(target) for target in dataset.targets))

In [15]:
model_names = ['mobilenetv2_x0_5', 'mobilenetv2_x1_0', 'mobilenetv2_x1_4',
            #    'repvgg_a0', 'repvgg_a1', 'repvgg_a2',
               'resnet20', 'resnet32', 'resnet44', 'resnet56',
               'shufflenetv2_x0_5', 'shufflenetv2_x1_0', 'shufflenetv2_x1_5', 'shufflenetv2_x2_0',
               'vgg11_bn', 'vgg13_bn', 'vgg16_bn', 'vgg19_bn']
            #    'vit_b16', 'vit_b32', 'vit_h14', 'vit_l16', 'vit_l32']

In [18]:
for name in model_names:
    model = torch.hub.load('chenyaofo/pytorch-cifar-models', f'cifar100_{name}',
                           pretrained=True).to(device)
    model.name = name
    model.device = device
    save_logits(model, dataloader)

Using cache found in /root/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master


HBox(children=(FloatProgress(value=0.0, description='mobilenetv2_x0_5', max=40.0, style=ProgressStyle(descript…


mobilenetv2_x0_5 accuracy: 71.17


Using cache found in /root/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master


HBox(children=(FloatProgress(value=0.0, description='mobilenetv2_x1_0', max=40.0, style=ProgressStyle(descript…


mobilenetv2_x1_0 accuracy: 74.29


Using cache found in /root/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master


HBox(children=(FloatProgress(value=0.0, description='mobilenetv2_x1_4', max=40.0, style=ProgressStyle(descript…


mobilenetv2_x1_4 accuracy: 76.29


Using cache found in /root/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master


HBox(children=(FloatProgress(value=0.0, description='resnet20', max=40.0, style=ProgressStyle(description_widt…


resnet20 accuracy: 68.83


Using cache found in /root/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master


HBox(children=(FloatProgress(value=0.0, description='resnet32', max=40.0, style=ProgressStyle(description_widt…


resnet32 accuracy: 70.16


Using cache found in /root/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master


HBox(children=(FloatProgress(value=0.0, description='resnet44', max=40.0, style=ProgressStyle(description_widt…


resnet44 accuracy: 71.63


Using cache found in /root/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master


HBox(children=(FloatProgress(value=0.0, description='resnet56', max=40.0, style=ProgressStyle(description_widt…


resnet56 accuracy: 72.63


Using cache found in /root/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master


HBox(children=(FloatProgress(value=0.0, description='shufflenetv2_x0_5', max=40.0, style=ProgressStyle(descrip…


shufflenetv2_x0_5 accuracy: 67.82


Using cache found in /root/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master


HBox(children=(FloatProgress(value=0.0, description='shufflenetv2_x1_0', max=40.0, style=ProgressStyle(descrip…


shufflenetv2_x1_0 accuracy: 72.58


Using cache found in /root/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master


HBox(children=(FloatProgress(value=0.0, description='shufflenetv2_x1_5', max=40.0, style=ProgressStyle(descrip…


shufflenetv2_x1_5 accuracy: 74.23


Using cache found in /root/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master


HBox(children=(FloatProgress(value=0.0, description='shufflenetv2_x2_0', max=40.0, style=ProgressStyle(descrip…


shufflenetv2_x2_0 accuracy: 75.48


Using cache found in /root/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master


HBox(children=(FloatProgress(value=0.0, description='vgg11_bn', max=40.0, style=ProgressStyle(description_widt…


vgg11_bn accuracy: 70.78


Using cache found in /root/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master


HBox(children=(FloatProgress(value=0.0, description='vgg13_bn', max=40.0, style=ProgressStyle(description_widt…


vgg13_bn accuracy: 74.63


Using cache found in /root/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master


HBox(children=(FloatProgress(value=0.0, description='vgg16_bn', max=40.0, style=ProgressStyle(description_widt…


vgg16_bn accuracy: 74.00


Using cache found in /root/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master


HBox(children=(FloatProgress(value=0.0, description='vgg19_bn', max=40.0, style=ProgressStyle(description_widt…


vgg19_bn accuracy: 73.87


In [20]:
!zip -r logits.zip logits

  adding: logits/ (stored 0%)
  adding: logits/cifar10/ (stored 0%)
  adding: logits/cifar10/v2/ (stored 0%)
  adding: logits/cifar10/v2/mobilenetv2_x1_0.pt (deflated 8%)
  adding: logits/cifar10/v2/vgg11_bn.pt (deflated 8%)
  adding: logits/cifar10/v2/shufflenetv2_x1_0.pt (deflated 8%)
  adding: logits/cifar10/v2/resnet32.pt (deflated 8%)
  adding: logits/cifar10/v2/vgg16_bn.pt (deflated 7%)
  adding: logits/cifar10/v2/resnet44.pt (deflated 9%)
  adding: logits/cifar10/v2/vgg19_bn.pt (deflated 7%)
  adding: logits/cifar10/v2/mobilenetv2_x0_5.pt (deflated 8%)
  adding: logits/cifar10/v2/shufflenetv2_x2_0.pt (deflated 9%)
  adding: logits/cifar10/v2/shufflenetv2_x0_5.pt (deflated 8%)
  adding: logits/cifar10/v2/resnet20.pt (deflated 8%)
  adding: logits/cifar10/v2/resnet56.pt (deflated 9%)
  adding: logits/cifar10/v2/mobilenetv2_x1_4.pt (deflated 9%)
  adding: logits/cifar10/v2/shufflenetv2_x1_5.pt (deflated 8%)
  adding: logits/cifar10/v2/vgg13_bn.pt (deflated 8%)
  adding: logits/cifa

# ImageNet 〜(＞＜)〜

In [None]:
# from torchvision.datasets import ImageFolder
# from torchvision import transforms
# from torch.utils.data import DataLoader

# transform = transforms.Compose([
#     transforms.Resize(256),
#     transforms.CenterCrop(224),
#     transforms.ToTensor(),
#     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
# ])

# datadir = '../input/imagenetpt/val'
# dataset = ImageFolder(root=datadir, transform=transform)
# dataloader = DataLoader(dataset, batch_size=100, num_workers=2, shuffle=False)
# dataloader.name = 'imagenet'
# from torchvision.models import alexnet
# model = alexnet(pretrained=True)
# model.name = 'alexnet'
# model.device = device
# batch_x, batch_y = next(iter(dataloader))
# (model(batch_x) == batch_y)

In [None]:
!git clone https://github.com/rwightman/pytorch-image-models.git
!pip install git+https://github.com/rwightman/pytorch-image-models.git

Cloning into 'pytorch-image-models'...
remote: Enumerating objects: 6369, done.[K
remote: Counting objects: 100% (546/546), done.[K
remote: Compressing objects: 100% (224/224), done.[K
remote: Total 6369 (delta 345), reused 460 (delta 316), pack-reused 5823[K
Receiving objects: 100% (6369/6369), 17.13 MiB | 27.36 MiB/s, done.
Resolving deltas: 100% (4614/4614), done.
Collecting git+https://github.com/rwightman/pytorch-image-models.git
  Cloning https://github.com/rwightman/pytorch-image-models.git to /tmp/pip-req-build-suerldda
  Running command git clone -q https://github.com/rwightman/pytorch-image-models.git /tmp/pip-req-build-suerldda
Building wheels for collected packages: timm
  Building wheel for timm (setup.py) ... [?25ldone
[?25h  Created wheel for timm: filename=timm-0.4.8-py3-none-any.whl size=338238 sha256=a7ed44757ce92d796e463b8aa0bb856fcddf6d42937973b31442fcb53e157ec8
  Stored in directory: /tmp/pip-ephem-wheel-cache-mwagxw5c/wheels/a0/ec/5f/289118b747739bb1e02e36cf

In [None]:
# timm.list_models(pretrained=True)

In [None]:
%%writefile pytorch-image-models/val_logits.py
#!/usr/bin/env python3
""" ImageNet Validation Script

This is intended to be a lean and easily modifiable ImageNet validation script for evaluating pretrained
models or training checkpoints against ImageNet or similarly organized image datasets. It prioritizes
canonical PyTorch, standard Python style, and good performance. Repurpose as you see fit.

Hacked together by Ross Wightman (https://github.com/rwightman)

https://github.com/rwightman/pytorch-image-models/blob/master/validate.py
"""
import argparse
import os
import csv
import glob
import time
import logging
import torch
import torch.nn as nn
import torch.nn.parallel
from collections import OrderedDict
from contextlib import suppress

from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy

has_apex = False
try:
    from apex import amp
    has_apex = True
except ImportError:
    pass

has_native_amp = False
try:
    if getattr(torch.cuda.amp, 'autocast') is not None:
        has_native_amp = True
except AttributeError:
    pass

torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('validate')


parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--dataset', '-d', metavar='NAME', default='',
                    help='dataset type (default: ImageFolder/ImageTar if empty)')
parser.add_argument('--split', metavar='NAME', default='validation',
                    help='dataset split (default: validation)')
parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
                    help='model architecture (default: dpn92)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 2)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=None, type=int,
                    metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--input-size', default=None, nargs=3, type=int,
                    metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--crop-pct', default=None, type=float,
                    metavar='N', help='Input image center crop pct')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                    help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float,  nargs='+', default=None, metavar='STD',
                    help='Override std deviation of of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
parser.add_argument('--num-classes', type=int, default=None,
                    help='Number classes in dataset')
parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
                    help='path to class to idx mapping file (default: "")')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
                    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
parser.add_argument('--log-freq', default=10, type=int,
                    metavar='N', help='batch logging frequency (default: 10)')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--num-gpu', type=int, default=1,
                    help='Number of GPUS to use')
parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',
                    help='disable test time pool')
parser.add_argument('--no-prefetcher', action='store_true', default=False,
                    help='disable fast prefetcher')
parser.add_argument('--pin-mem', action='store_true', default=False,
                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--channels-last', action='store_true', default=False,
                    help='Use channels_last memory layout')
parser.add_argument('--amp', action='store_true', default=False,
                    help='Use AMP mixed precision. Defaults to Apex, fallback to native Torch AMP.')
parser.add_argument('--apex-amp', action='store_true', default=False,
                    help='Use NVIDIA Apex AMP mixed precision')
parser.add_argument('--native-amp', action='store_true', default=False,
                    help='Use Native Torch AMP mixed precision')
parser.add_argument('--tf-preprocessing', action='store_true', default=False,
                    help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
                    help='use ema version of weights if present')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
                    help='convert model torchscript for inference')
parser.add_argument('--legacy-jit', dest='legacy_jit', action='store_true',
                    help='use legacy jit mode for pytorch 1.5/1.5.1/1.6 to get back fusion performance')
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
                    help='Output csv file for validation results (summary)')
parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
                    help='Real labels JSON file for imagenet evaluation')
parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME',
                    help='Valid label indices txt file for validation of partial label space')


def validate(args):
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher
    amp_autocast = suppress  # do nothing
    if args.amp:
        if has_native_amp:
            args.native_amp = True
        elif has_apex:
            args.apex_amp = True
        else:
            _logger.warning("Neither APEX or Native Torch AMP is available.")
    assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
    if args.native_amp:
        amp_autocast = torch.cuda.amp.autocast
        _logger.info('Validating in mixed precision with native PyTorch AMP.')
    elif args.apex_amp:
        _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
    else:
        _logger.info('Validating in float32. AMP not enabled.')

    if args.legacy_jit:
        set_jit_legacy()

    # create model
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        in_chans=3,
        global_pool=args.gp,
        scriptable=args.torchscript)
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])
    _logger.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model, use_test_size=True, verbose=True)
    test_time_pool = False
    if not args.no_test_pool:
        model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    model = model.cuda()
    if args.apex_amp:
        model = amp.initialize(model, opt_level='O1')

    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss().cuda()

    dataset = create_dataset(
        root=args.data, name=args.dataset, split=args.split,
        load_bytes=args.tf_preprocessing, class_map=args.class_map)

    if args.valid_labels:
        with open(args.valid_labels, 'r') as f:
            valid_labels = {int(line.rstrip()) for line in f}
            valid_labels = [i in valid_labels for i in range(args.num_classes)]
    else:
        valid_labels = None

    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
    else:
        real_labels = None

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    all_logits = []  # !!!
    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda()
        if args.channels_last:
            input = input.contiguous(memory_format=torch.channels_last)
        model(input)
        end = time.time()
        for batch_idx, (input, target) in enumerate(loader):
            if args.no_prefetcher:
                target = target.cuda()
                input = input.cuda()
            if args.channels_last:
                input = input.contiguous(memory_format=torch.channels_last)

            # compute output
            with amp_autocast():
                output = model(input)

            if valid_labels is not None:
                output = output[:, valid_labels]
            loss = criterion(output, target)

            if real_labels is not None:
                real_labels.add_result(output)

            # measure accuracy and record loss
#             print(target)  # !!! — yes, they are [0] * 50, [1] * 50, ...
            acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(acc1.item(), input.size(0))
            top5.update(acc5.item(), input.size(0))
            
            # append logits
            all_logits.append(output.cpu())
            
            
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_freq == 0:
                _logger.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        batch_idx, len(loader), batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses, top1=top1, top5=top5))
    # save logits
    all_logits = torch.cat(all_logits)  # !!!
    torch.save(all_logits, f'/kaggle/working/logits/imagenet/{args.model}.pt')  # !!!
    
    if real_labels is not None:
        # real labels mode replaces topk values at the end
        top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)
    else:
        top1a, top5a = top1.avg, top5.avg
    results = OrderedDict(
        top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
        top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        cropt_pct=crop_pct,
        interpolation=data_config['interpolation'])

    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
       results['top1'], results['top1_err'], results['top5'], results['top5_err']))

    return results


def main():
    setup_default_logging()
    args = parser.parse_args()
    model_cfgs = []
    model_names = []
    if os.path.isdir(args.checkpoint):
        # validate all checkpoints in a path with same model
        checkpoints = glob.glob(args.checkpoint + '/*.pth.tar')
        checkpoints += glob.glob(args.checkpoint + '/*.pth')
        model_names = list_models(args.model)
        model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)]
    else:
        if args.model == 'all':
            # validate all models in a list of names with pretrained checkpoints
            args.pretrained = True
            model_names = list_models(pretrained=True, exclude_filters=['*_in21k', '*_in22k'])
            model_cfgs = [(n, '') for n in model_names]
        elif not is_model(args.model):
            # model name doesn't exist, try as wildcard filter
            model_names = list_models(args.model)
            model_cfgs = [(n, '') for n in model_names]

    if len(model_cfgs):
        results_file = args.results_file or './results-all.csv'
        _logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
        results = []
        try:
            start_batch_size = args.batch_size
            for m, c in model_cfgs:
                batch_size = start_batch_size
                args.model = m
                args.checkpoint = c
                result = OrderedDict(model=args.model)
                r = {}
                while not r and batch_size >= args.num_gpu:
                    torch.cuda.empty_cache()
                    try:
                        args.batch_size = batch_size
                        print('Validating with batch size: %d' % args.batch_size)
                        r = validate(args)
                    except RuntimeError as e:
                        if batch_size <= args.num_gpu:
                            print("Validation failed with no ability to reduce batch size. Exiting.")
                            raise e
                        batch_size = max(batch_size // 2, args.num_gpu)
                        print("Validation failed, reducing batch size by 50%")
                result.update(r)
                if args.checkpoint:
                    result['checkpoint'] = args.checkpoint
                results.append(result)
        except KeyboardInterrupt as e:
            pass
        results = sorted(results, key=lambda x: x['top1'], reverse=True)
        if len(results):
            write_results(results_file, results)
    else:
        validate(args)


def write_results(results_file, results):
    with open(results_file, mode='w') as cf:
        dw = csv.DictWriter(cf, fieldnames=results[0].keys())
        dw.writeheader()
        for r in results:
            dw.writerow(r)
        cf.flush()


if __name__ == '__main__':
    main()

Writing pytorch-image-models/val_logits.py


### tf_efficientnet_b8, vgg19_bn, repvgg_b3, mobilenetv2_120d

In [None]:
!mkdir logits/imagenet

In [None]:
!nvidia-smi --gpu-reset

Error occurred during reset of GPU 00000000:00:04.0: Unknown Error

1 device did not complete reset successfully, and may be in an unstable state. Please reboot your system.


In [None]:
!python pytorch-image-models/val_logits.py ../input/imagenetpt/val --model tf_efficientnet_b8 -b 50 --pretrained

Validating in float32. AMP not enabled.
Loading pretrained weights from url (https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth)
Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth" to /root/.cache/torch/hub/checkpoints/tf_efficientnet_b8_ra-572d5dd9.pth
Model tf_efficientnet_b8 created, param count: 87413142
Data processing configuration for current model + dataset:
	input_size: (3, 672, 672)
	interpolation: bicubic
	mean: (0.485, 0.456, 0.406)
	std: (0.229, 0.224, 0.225)
	crop_pct: 0.954
Test: [   0/1000]  Time: 7.374s (7.374s,    6.78/s)  Loss:  0.2992 (0.2992)  Acc@1:  98.000 ( 98.000)  Acc@5:  98.000 ( 98.000)
Test: [  10/1000]  Time: 3.027s (3.456s,   14.47/s)  Loss:  0.1843 (0.3876)  Acc@1: 100.000 ( 94.182)  Acc@5: 100.000 ( 98.727)
Test: [  20/1000]  Time: 3.026s (3.250s,   15.38/s)  Loss:  0.6154 (0.3099)  Acc@1:  90.000 ( 95.714)  Acc

In [None]:
!python pytorch-image-models/val_logits.py ../input/imagenetpt/val --model vgg19_bn --pretrained

Validating in float32. AMP not enabled.
Loading pretrained weights from url (https://download.pytorch.org/models/vgg19_bn-c79401a0.pth)
Downloading: "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth" to /root/.cache/torch/hub/checkpoints/vgg19_bn-c79401a0.pth
Model vgg19_bn created, param count: 143678248
Data processing configuration for current model + dataset:
	input_size: (3, 224, 224)
	interpolation: bilinear
	mean: (0.485, 0.456, 0.406)
	std: (0.229, 0.224, 0.225)
	crop_pct: 0.875
Test: [   0/196]  Time: 8.288s (8.288s,   30.89/s)  Loss:  0.5989 (0.5989)  Acc@1:  82.422 ( 82.422)  Acc@5:  96.484 ( 96.484)
Test: [  10/196]  Time: 0.778s (2.414s,  106.03/s)  Loss:  0.9983 (0.7153)  Acc@1:  74.219 ( 81.463)  Acc@5:  94.922 ( 95.348)
Test: [  20/196]  Time: 0.804s (2.215s,  115.55/s)  Loss:  0.8656 (0.7308)  Acc@1:  82.031 ( 81.473)  Acc@5:  91.406 ( 95.089)
Test: [  30/196]  Time: 0.762s (2.038s,  125.62/s)  Loss:  0.7318 (0.7009)  Acc@1:  80.859 ( 82.157)  Acc@5:  96.094 (

In [None]:
!python pytorch-image-models/val_logits.py ../input/imagenetpt/val --model repvgg_b3 --pretrained

Validating in float32. AMP not enabled.
Loading pretrained weights from url (https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-repvgg-weights/repvgg_b3-199bc50d.pth)
Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-repvgg-weights/repvgg_b3-199bc50d.pth" to /root/.cache/torch/hub/checkpoints/repvgg_b3-199bc50d.pth
Model repvgg_b3 created, param count: 123085288
Data processing configuration for current model + dataset:
	input_size: (3, 224, 224)
	interpolation: bilinear
	mean: (0.485, 0.456, 0.406)
	std: (0.229, 0.224, 0.225)
	crop_pct: 0.875
Test: [   0/196]  Time: 8.449s (8.449s,   30.30/s)  Loss:  0.4018 (0.4018)  Acc@1:  91.406 ( 91.406)  Acc@5:  98.047 ( 98.047)
Test: [  10/196]  Time: 1.103s (2.409s,  106.27/s)  Loss:  1.0000 (0.5926)  Acc@1:  78.516 ( 85.440)  Acc@5:  95.703 ( 97.266)
Test: [  20/196]  Time: 1.108s (2.161s,  118.47/s)  Loss:  0.5559 (0.5951)  Acc@1:  90.234 ( 85.324)  Acc@5:  95.703 ( 97.080)
Test: [  3

In [None]:
!python pytorch-image-models/val_logits.py ../input/imagenetpt/val --model mobilenetv2_120d --pretrained

Validating in float32. AMP not enabled.
Loading pretrained weights from url (https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_120d_ra-5987e2ed.pth)
Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_120d_ra-5987e2ed.pth" to /root/.cache/torch/hub/checkpoints/mobilenetv2_120d_ra-5987e2ed.pth
Model mobilenetv2_120d created, param count: 5831144
Data processing configuration for current model + dataset:
	input_size: (3, 224, 224)
	interpolation: bicubic
	mean: (0.485, 0.456, 0.406)
	std: (0.229, 0.224, 0.225)
	crop_pct: 0.875
Test: [   0/196]  Time: 9.087s (9.087s,   28.17/s)  Loss:  0.5475 (0.5475)  Acc@1:  87.109 ( 87.109)  Acc@5:  97.266 ( 97.266)
Test: [  10/196]  Time: 0.380s (2.290s,  111.77/s)  Loss:  0.9457 (0.6748)  Acc@1:  78.516 ( 83.629)  Acc@5:  93.750 ( 96.520)
Test: [  20/196]  Time: 0.310s (2.218s,  115.40/s)  Loss:  0.7070 (0.6962)  Acc@1:  87.109 ( 83.612)  Acc@5:  93.750 

In [None]:
!zip -r logits.zip logits

  adding: logits/ (stored 0%)
  adding: logits/imagenet/ (stored 0%)
  adding: logits/imagenet/tf_efficientnet_b8.pt (deflated 7%)
  adding: logits/imagenet/vgg19_bn.pt (deflated 7%)
  adding: logits/imagenet/mobilenetv2_120d.pt (deflated 7%)
  adding: logits/imagenet/repvgg_b3.pt (deflated 7%)
  adding: logits/cifar10/ (stored 0%)
  adding: logits/cifar10/targets.txt (deflated 70%)
  adding: logits/cifar10/v1/ (stored 0%)
  adding: logits/cifar10/v1/mobilenet_v2.pt (deflated 11%)
  adding: logits/cifar10/v1/resnet18.pt (deflated 10%)
  adding: logits/cifar10/v1/vgg11_bn.pt (deflated 7%)
  adding: logits/cifar10/v1/densenet121.pt (deflated 11%)
  adding: logits/cifar10/v1/resnet50.pt (deflated 10%)
  adding: logits/cifar10/v1/vgg16_bn.pt (deflated 7%)
  adding: logits/cifar10/v1/googlenet.pt (deflated 10%)
  adding: logits/cifar10/v1/vgg19_bn.pt (deflated 6%)
  adding: logits/cifar10/v1/inception_v3.pt (deflated 11%)
  adding: logits/cifar10/v1/vgg13_bn.pt (deflated 7%)
  adding: logit

In [None]:
!rm -rf logits pytorch-image-models datasets

In [None]:
!du -h logits.zip

713M	logits.zip
