In [1]:
!mkdir focal focal/CIFAR10 focal/CIFAR100 focal/TinyImageNet
!git clone https://github.com/torrvision/focal_calibration

Cloning into 'focal_calibration'...
remote: Enumerating objects: 213, done.[K
remote: Total 213 (delta 0), reused 0 (delta 0), pack-reused 213[K
Receiving objects: 100% (213/213), 1.50 MiB | 4.04 MiB/s, done.
Resolving deltas: 100% (101/101), done.


In [2]:
import os

# Work from inside the cloned repository; the `Data.*` / `Net.*` imports in the
# following cell are resolved relative to the current directory.
# The guard makes the cell idempotent: a second run must not try to descend again.
if os.path.basename(os.getcwd()) != 'focal_calibration':
    os.chdir('./focal_calibration')

In [3]:
import torch
import random
from collections import OrderedDict

# Import dataloaders
import Data.cifar10 as cifar10
import Data.cifar100 as cifar100
import Data.tiny_imagenet as tiny_imagenet

# Import network architectures
from Net.resnet_tiny_imagenet import resnet50 as resnet50_ti
from Net.resnet import resnet50, resnet110
from Net.wide_resnet import wide_resnet_cifar
from Net.densenet import densenet121

# Number of target classes for each supported dataset
dataset_num_classes = dict(zip(
    ('CIFAR10', 'CIFAR100', 'TinyImageNet'),
    (10, 100, 200),
))

# Dataset name -> data-loading module imported from the repository
dataset_loader = dict(zip(
    ('CIFAR10', 'CIFAR100', 'TinyImageNet'),
    (cifar10, cifar100, tiny_imagenet),
))

# Architecture name -> constructor; insertion order is the order in which
# the evaluation loops further down iterate over the architectures.
models = dict([
    ('resnet50', resnet50),
    ('resnet110', resnet110),
    ('wide_resnet', wide_resnet_cifar),
    ('densenet121', densenet121),
])

# Loss identifiers; they are interpolated into checkpoint and output file names below.
loss_names = (
    ['cross_entropy', 'cross_entropy_smoothed_smoothing_0.05']
    + ['focal_loss_gamma_1.0', 'focal_loss_gamma_2.0', 'focal_loss_gamma_3.0']
)

In [4]:
from tqdm.auto import tqdm

def save_logits(model, dataloader, path):
    """Evaluate `model` on `dataloader`, print its accuracy and save the logits.

    Parameters
    ----------
    model: torch.nn.Module
        Network to evaluate. If it carries a `device` attribute (as attached by
        `set_model`) that device is used; otherwise the device of the model's
        first parameter is used.
    dataloader: iterable
        Yields `(inputs, labels)` batches and exposes `.dataset` (its length is
        the denominator of the reported accuracy).
    path: string
        Destination handed to `torch.save` for the concatenated logits.

    Returns
    -------
    (all_logits, all_labels)
        Logits (moved to CPU) and labels, each concatenated over all batches.
    """
    model.eval()
    # Resolve the target device once instead of re-reading `model.device` per batch,
    # and do not require the non-standard attribute to exist.
    device = getattr(model, 'device', None)
    if device is None:
        device = next(model.parameters()).device

    progress = tqdm(dataloader, desc='eval')
    logit_chunks = []
    label_chunks = []
    num_correct = 0
    with torch.no_grad():
        for batch_x, batch_y in progress:
            logits = model(batch_x.to(device)).cpu()
            logit_chunks.append(logits)
            label_chunks.append(batch_y)
            # Compare on the CPU copy so the labels never have to travel to the GPU.
            hits = logits.argmax(dim=1) == batch_y.to(logits.device)
            num_correct += hits.sum().item()
            progress.set_postfix({'batch_acc': hits.float().mean().item()})
    total_acc = num_correct * 100 / len(dataloader.dataset)
    print(f'accuracy: {total_acc:.2f}')
    all_logits = torch.cat(logit_chunks)
    all_labels = torch.cat(label_chunks)
    torch.save(all_logits, path)
    return all_logits, all_labels

In [5]:
import subprocess

def download_model(dataset_name, saved_model_name):
    """Fetch one pretrained checkpoint with `wget` and return its local path.

    Parameters
    ----------
    dataset_name: string
        Sub-directory on the server and under `model_weights/`.
    saved_model_name: string
        File name of the checkpoint.

    Returns
    -------
    string
        `model_weights/<dataset_name>/<saved_model_name>`.

    Raises
    ------
    RuntimeError
        If `wget` exits with a non-zero status. Previously the failure was
        ignored and the caller received a path to a file that may not exist.
    """
    url = 'https://www.robots.ox.ac.uk/~viveka/focal_calibration'
    target_dir = f'model_weights/{dataset_name}'
    # Pass an argument list instead of splitting a formatted string, so a name
    # containing whitespace cannot be broken into several arguments.
    argv = ['wget', '-q', '-P', target_dir, f'{url}/{dataset_name}/{saved_model_name}']
    print(' '.join(argv))
    returncode = subprocess.run(argv).returncode
    if returncode != 0:
        raise RuntimeError(
            f'wget exited with status {returncode} while downloading '
            f'{dataset_name}/{saved_model_name}')
    print(f'Successfully downloaded {dataset_name}/{saved_model_name}')
    return f'{target_dir}/{saved_model_name}'

def set_model(model, device, path):
    """Move `model` to `device` and load the checkpoint stored at `path` into it.

    Parameters
    ----------
    model: torch.nn.Module
        Freshly constructed network; modified in place. A `device` attribute is
        attached to it so that later code can read `model.device`.
    device: torch.device
        Target device for the model and for the loaded tensors.
    path: string
        File containing a state dict. A leading `module.` on a key
        (presumably left by a `DataParallel`-style wrapper -- TODO confirm) is
        stripped; keys without the prefix are accepted unchanged.

    The result of `load_state_dict` is printed; nothing is returned.
    """
    model.to(device)
    model.device = device
    # NOTE(security): torch.load unpickles the file; only use it on checkpoints
    # from a trusted source (consider `weights_only=True` where the installed
    # torch version supports it).
    checkpoint = torch.load(path, map_location=device)
    prefix = 'module.'
    state_dict = OrderedDict()
    for old_key, weights in checkpoint.items():
        # Check the prefix explicitly rather than with `assert`, which is
        # stripped under `python -O` and used to reject unprefixed checkpoints.
        new_key = old_key[len(prefix):] if old_key.startswith(prefix) else old_key
        state_dict[new_key] = weights
    res = model.load_state_dict(state_dict)
    print(res)

In [6]:
# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

## CIFAR10

In [7]:
# Select the dataset and build its unshuffled test loader.
dataset_name, batch_size = 'CIFAR10', 100
num_classes = dataset_num_classes[dataset_name]
test_loader = dataset_loader[dataset_name].get_test_loader(
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=False,
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


0it [00:00, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data


In [8]:
# Write the test-set targets, one per line (no trailing newline), into the dataset's output directory.
with open(f'../focal/{dataset_name}/targets.txt', 'w') as fout:
    fout.write('\n'.join(map(str, test_loader.dataset.targets)))

In [9]:
# For every (architecture, loss) pair: download the checkpoint, load it,
# run it over the test loader and store the logits.
for model_name, model_class in models.items():
    for loss_name in loss_names:
        saved_model_name = f'{model_name}_{loss_name}_350.model'
        logits_path = f'../focal/{dataset_name}/{model_name}_{loss_name}.pt'

        state_path = download_model(dataset_name, saved_model_name)
        model = model_class(num_classes=num_classes, temp=1.0)
        set_model(model, device, state_path)

        all_logits, all_labels = save_logits(
            model=model, dataloader=test_loader, path=logits_path)
        print()

wget -q -P model_weights/CIFAR10 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR10/resnet50_cross_entropy_350.model
Successfully downloaded CIFAR10/resnet50_cross_entropy_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 95.05

wget -q -P model_weights/CIFAR10 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR10/resnet50_cross_entropy_smoothed_smoothing_0.05_350.model
Successfully downloaded CIFAR10/resnet50_cross_entropy_smoothed_smoothing_0.05_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 94.71

wget -q -P model_weights/CIFAR10 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR10/resnet50_focal_loss_gamma_1.0_350.model
Successfully downloaded CIFAR10/resnet50_focal_loss_gamma_1.0_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 95.07

wget -q -P model_weights/CIFAR10 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR10/resnet50_focal_loss_gamma_2.0_350.model
Successfully downloaded CIFAR10/resnet50_focal_loss_gamma_2.0_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 95.02

wget -q -P model_weights/CIFAR10 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR10/resnet50_focal_loss_gamma_3.0_350.model
Successfully downloaded CIFAR10/resnet50_focal_loss_gamma_3.0_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 94.75

wget -q -P model_weights/CIFAR10 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR10/resnet110_cross_entropy_350.model
Successfully downloaded CIFAR10/resnet110_cross_entropy_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 95.11

wget -q -P model_weights/CIFAR10 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR10/resnet110_cross_entropy_smoothed_smoothing_0.05_350.model
Successfully downloaded CIFAR10/resnet110_cross_entropy_smoothed_smoothing_0.05_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 94.48

wget -q -P model_weights/CIFAR10 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR10/resnet110_focal_loss_gamma_1.0_350.model
Successfully downloaded CIFAR10/resnet110_focal_loss_gamma_1.0_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 95.22

wget -q -P model_weights/CIFAR10 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR10/resnet110_focal_loss_gamma_2.0_350.model
Successfully downloaded CIFAR10/resnet110_focal_loss_gamma_2.0_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 94.94

wget -q -P model_weights/CIFAR10 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR10/resnet110_focal_loss_gamma_3.0_350.model
Successfully downloaded CIFAR10/resnet110_focal_loss_gamma_3.0_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 94.92

wget -q -P model_weights/CIFAR10 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR10/wide_resnet_cross_entropy_350.model
Successfully downloaded CIFAR10/wide_resnet_cross_entropy_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 96.14

wget -q -P model_weights/CIFAR10 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR10/wide_resnet_cross_entropy_smoothed_smoothing_0.05_350.model
Successfully downloaded CIFAR10/wide_resnet_cross_entropy_smoothed_smoothing_0.05_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 95.80

wget -q -P model_weights/CIFAR10 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR10/wide_resnet_focal_loss_gamma_1.0_350.model
Successfully downloaded CIFAR10/wide_resnet_focal_loss_gamma_1.0_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 95.73

wget -q -P model_weights/CIFAR10 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR10/wide_resnet_focal_loss_gamma_2.0_350.model
Successfully downloaded CIFAR10/wide_resnet_focal_loss_gamma_2.0_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 95.73

wget -q -P model_weights/CIFAR10 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR10/wide_resnet_focal_loss_gamma_3.0_350.model
Successfully downloaded CIFAR10/wide_resnet_focal_loss_gamma_3.0_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 95.87

wget -q -P model_weights/CIFAR10 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR10/densenet121_cross_entropy_350.model
Successfully downloaded CIFAR10/densenet121_cross_entropy_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 95.00

wget -q -P model_weights/CIFAR10 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR10/densenet121_cross_entropy_smoothed_smoothing_0.05_350.model
Successfully downloaded CIFAR10/densenet121_cross_entropy_smoothed_smoothing_0.05_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 94.91

wget -q -P model_weights/CIFAR10 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR10/densenet121_focal_loss_gamma_1.0_350.model
Successfully downloaded CIFAR10/densenet121_focal_loss_gamma_1.0_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 94.91

wget -q -P model_weights/CIFAR10 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR10/densenet121_focal_loss_gamma_2.0_350.model
Successfully downloaded CIFAR10/densenet121_focal_loss_gamma_2.0_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 95.16

wget -q -P model_weights/CIFAR10 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR10/densenet121_focal_loss_gamma_3.0_350.model
Successfully downloaded CIFAR10/densenet121_focal_loss_gamma_3.0_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 94.67



In [10]:
# Show the on-disk size of every file downloaded into model_weights/CIFAR10.
!du -h model_weights/CIFAR10/*

27M	model_weights/CIFAR10/densenet121_cross_entropy_350.model
27M	model_weights/CIFAR10/densenet121_cross_entropy_smoothed_smoothing_0.05_350.model
27M	model_weights/CIFAR10/densenet121_focal_loss_gamma_1.0_350.model
27M	model_weights/CIFAR10/densenet121_focal_loss_gamma_2.0_350.model
27M	model_weights/CIFAR10/densenet121_focal_loss_gamma_3.0_350.model
176M	model_weights/CIFAR10/resnet110_cross_entropy_350.model
176M	model_weights/CIFAR10/resnet110_cross_entropy_smoothed_smoothing_0.05_350.model
176M	model_weights/CIFAR10/resnet110_focal_loss_gamma_1.0_350.model
176M	model_weights/CIFAR10/resnet110_focal_loss_gamma_2.0_350.model
176M	model_weights/CIFAR10/resnet110_focal_loss_gamma_3.0_350.model
90M	model_weights/CIFAR10/resnet50_cross_entropy_350.model
90M	model_weights/CIFAR10/resnet50_cross_entropy_smoothed_smoothing_0.05_350.model
90M	model_weights/CIFAR10/resnet50_focal_loss_gamma_1.0_350.model
90M	model_weights/CIFAR10/resnet50_focal_loss_gamma_2.0_350.model
90M	mod

In [11]:
# Delete all downloaded checkpoints (presumably to free disk space before the next dataset).
!rm -rf model_weights

## CIFAR100

In [12]:
# Select the dataset and build its unshuffled test loader.
dataset_name, batch_size = 'CIFAR100', 100
num_classes = dataset_num_classes[dataset_name]
test_loader = dataset_loader[dataset_name].get_test_loader(
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=False,
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


0it [00:00, ?it/s]

Extracting ./data/cifar-100-python.tar.gz to ./data


In [13]:
# Write the test-set targets, one per line (no trailing newline), into the dataset's output directory.
with open(f'../focal/{dataset_name}/targets.txt', 'w') as fout:
    fout.write('\n'.join(map(str, test_loader.dataset.targets)))

In [14]:
# For every (architecture, loss) pair: download the checkpoint, load it,
# run it over the test loader and store the logits.
for model_name, model_class in models.items():
    for loss_name in loss_names:
        saved_model_name = f'{model_name}_{loss_name}_350.model'
        logits_path = f'../focal/{dataset_name}/{model_name}_{loss_name}.pt'

        state_path = download_model(dataset_name, saved_model_name)
        model = model_class(num_classes=num_classes, temp=1.0)
        set_model(model, device, state_path)

        all_logits, all_labels = save_logits(
            model=model, dataloader=test_loader, path=logits_path)
        print()

wget -q -P model_weights/CIFAR100 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR100/resnet50_cross_entropy_350.model
Successfully downloaded CIFAR100/resnet50_cross_entropy_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 76.70

wget -q -P model_weights/CIFAR100 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR100/resnet50_cross_entropy_smoothed_smoothing_0.05_350.model
Successfully downloaded CIFAR100/resnet50_cross_entropy_smoothed_smoothing_0.05_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 76.57

wget -q -P model_weights/CIFAR100 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR100/resnet50_focal_loss_gamma_1.0_350.model
Successfully downloaded CIFAR100/resnet50_focal_loss_gamma_1.0_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 77.20

wget -q -P model_weights/CIFAR100 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR100/resnet50_focal_loss_gamma_2.0_350.model
Successfully downloaded CIFAR100/resnet50_focal_loss_gamma_2.0_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 76.85

wget -q -P model_weights/CIFAR100 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR100/resnet50_focal_loss_gamma_3.0_350.model
Successfully downloaded CIFAR100/resnet50_focal_loss_gamma_3.0_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 77.25

wget -q -P model_weights/CIFAR100 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR100/resnet110_cross_entropy_350.model
Successfully downloaded CIFAR100/resnet110_cross_entropy_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 77.27

wget -q -P model_weights/CIFAR100 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR100/resnet110_cross_entropy_smoothed_smoothing_0.05_350.model
Successfully downloaded CIFAR100/resnet110_cross_entropy_smoothed_smoothing_0.05_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 76.57

wget -q -P model_weights/CIFAR100 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR100/resnet110_focal_loss_gamma_1.0_350.model
Successfully downloaded CIFAR100/resnet110_focal_loss_gamma_1.0_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 77.64

wget -q -P model_weights/CIFAR100 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR100/resnet110_focal_loss_gamma_2.0_350.model
Successfully downloaded CIFAR100/resnet110_focal_loss_gamma_2.0_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 77.47

wget -q -P model_weights/CIFAR100 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR100/resnet110_focal_loss_gamma_3.0_350.model
Successfully downloaded CIFAR100/resnet110_focal_loss_gamma_3.0_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 77.08

wget -q -P model_weights/CIFAR100 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR100/wide_resnet_cross_entropy_350.model
Successfully downloaded CIFAR100/wide_resnet_cross_entropy_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 79.30

wget -q -P model_weights/CIFAR100 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR100/wide_resnet_cross_entropy_smoothed_smoothing_0.05_350.model
Successfully downloaded CIFAR100/wide_resnet_cross_entropy_smoothed_smoothing_0.05_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 78.81

wget -q -P model_weights/CIFAR100 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR100/wide_resnet_focal_loss_gamma_1.0_350.model
Successfully downloaded CIFAR100/wide_resnet_focal_loss_gamma_1.0_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 80.39

wget -q -P model_weights/CIFAR100 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR100/wide_resnet_focal_loss_gamma_2.0_350.model
Successfully downloaded CIFAR100/wide_resnet_focal_loss_gamma_2.0_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 79.99

wget -q -P model_weights/CIFAR100 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR100/wide_resnet_focal_loss_gamma_3.0_350.model
Successfully downloaded CIFAR100/wide_resnet_focal_loss_gamma_3.0_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 80.31

wget -q -P model_weights/CIFAR100 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR100/densenet121_cross_entropy_350.model
Successfully downloaded CIFAR100/densenet121_cross_entropy_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 75.48

wget -q -P model_weights/CIFAR100 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR100/densenet121_cross_entropy_smoothed_smoothing_0.05_350.model
Successfully downloaded CIFAR100/densenet121_cross_entropy_smoothed_smoothing_0.05_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 75.95

wget -q -P model_weights/CIFAR100 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR100/densenet121_focal_loss_gamma_1.0_350.model
Successfully downloaded CIFAR100/densenet121_focal_loss_gamma_1.0_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 76.18

wget -q -P model_weights/CIFAR100 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR100/densenet121_focal_loss_gamma_2.0_350.model
Successfully downloaded CIFAR100/densenet121_focal_loss_gamma_2.0_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 76.81

wget -q -P model_weights/CIFAR100 https://www.robots.ox.ac.uk/~viveka/focal_calibration/CIFAR100/densenet121_focal_loss_gamma_3.0_350.model
Successfully downloaded CIFAR100/densenet121_focal_loss_gamma_3.0_350.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 76.75



In [15]:
# Delete all downloaded checkpoints (presumably to free disk space before the next dataset).
!rm -rf model_weights

## Tiny ImageNet

In [16]:
import os
import glob
from torch.utils.data import Dataset
from PIL import Image

EXTENSION = 'JPEG'
NUM_IMAGES_PER_CLASS = 500
CLASS_LIST_FILE = 'wnids.txt'
VAL_ANNOTATION_FILE = 'val_annotations.txt'


class TinyImageNet(Dataset):
    """Tiny ImageNet data set available from `http://cs231n.stanford.edu/tiny-imagenet-200.zip`.
    Parameters
    ----------
    root: string
        Root directory including `train`, `test` and `val` subdirectories.
        A leading `~` is expanded.
    split: string
        Indicating which split to return as a data set.
        Valid option: [`train`, `test`, `val`]
    transform: torchvision.transforms
        A (series) of valid transformation(s) applied to every image.
    target_transform: callable
        Optional function applied to the integer label before it is returned.
    in_memory: bool
        Set to True if there is enough memory (about 5G) and want to minimize disk IO overhead.
    """
    def __init__(self, root, split='train', transform=None, target_transform=None, in_memory=False):
        self.root = os.path.expanduser(root)
        self.split = split
        self.transform = transform
        self.target_transform = target_transform
        self.in_memory = in_memory
        # Build on the *expanded* root; the raw argument breaks `~`-relative roots.
        self.split_dir = os.path.join(self.root, self.split)
        self.image_paths = sorted(glob.iglob(os.path.join(self.split_dir, '**', '*.%s' % EXTENSION), recursive=True))
        self.labels = {}  # fname - label number mapping
        self.images = []  # used for in-memory processing

        # build class label - number mapping (blank lines in the list are ignored)
        with open(os.path.join(self.root, CLASS_LIST_FILE), 'r') as fp:
            self.label_texts = sorted(text.strip() for text in fp if text.strip())
        self.label_text_to_number = {text: i for i, text in enumerate(self.label_texts)}

        if self.split == 'train':
            # training files are named <wnid>_<index>.<EXTENSION>
            for label_text, i in self.label_text_to_number.items():
                for cnt in range(NUM_IMAGES_PER_CLASS):
                    self.labels['%s_%d.%s' % (label_text, cnt, EXTENSION)] = i
        elif self.split == 'val':
            # each annotation line is tab separated: file name, wnid, then further fields
            with open(os.path.join(self.split_dir, VAL_ANNOTATION_FILE), 'r') as fp:
                for line in fp:
                    terms = line.split('\t')
                    if len(terms) < 2:
                        continue  # blank or malformed line
                    file_name, label_text = terms[0], terms[1].strip()
                    self.labels[file_name] = self.label_text_to_number[label_text]

        # read all images into memory up front to minimize disk IO overhead
        if self.in_memory:
            self.images = [self.read_image(path) for path in self.image_paths]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        """Return `img` for the `test` split, otherwise `(img, label)`."""
        file_path = self.image_paths[index]

        if self.in_memory:
            img = self.images[index]
        else:
            img = self.read_image(file_path)

        if self.split == 'test':
            return img

        target = self.labels[os.path.basename(file_path)]
        # `target_transform` used to be stored but never applied.
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        tmp = self.split
        fmt_str += '    Split: {}\n'.format(tmp)
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str

    def read_image(self, path):
        """Load the image at `path` as RGB and apply `self.transform` if one is set."""
        # The context manager closes the underlying file handle, which matters
        # when thousands of images are opened for `in_memory=True`.
        # `convert` returns a fully loaded copy, so the result outlives the handle;
        # every non-RGB mode (not only grayscale 'L') is converted to three channels.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        return self.transform(img) if self.transform is not None else img

In [17]:
from torchvision import transforms
from torch.utils.data import DataLoader

dataset_name = 'TinyImageNet'
# Set explicitly: this cell previously relied on the `batch_size` left behind
# by the CIFAR cells and failed with a NameError when run on its own.
batch_size = 100
num_classes = dataset_num_classes[dataset_name]

# Tensor conversion followed by per-channel normalisation
# (the mean/std values are the commonly used ImageNet statistics).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# NOTE(review): absolute Kaggle input path -- adjust when running elsewhere.
dataset = TinyImageNet('/kaggle/input/tinyimagenet/tiny-imagenet-200',
                       split='val',
                       transform=transform,
                       in_memory=True)

test_loader = DataLoader(dataset, batch_size=batch_size,
                         num_workers=2, shuffle=False)

In [18]:
# Architecture used for Tiny ImageNet: the repository's dedicated ResNet-50 variant.
model_name, model_class = 'resnet50', resnet50_ti

In [19]:
# For every loss: download the checkpoint, load it, run it over the
# validation loader and store the logits.
for loss_name in loss_names:
    saved_model_name = f'{model_name}_{loss_name}_100.model'
    logits_path = f'../focal/{dataset_name}/{model_name}_{loss_name}.pt'

    state_path = download_model(dataset_name, saved_model_name)
    model = model_class(num_classes=num_classes, temp=1.0)
    set_model(model, device, state_path)

    all_logits, all_labels = save_logits(
        model=model, dataloader=test_loader, path=logits_path)
    print()

wget -q -P model_weights/TinyImageNet https://www.robots.ox.ac.uk/~viveka/focal_calibration/TinyImageNet/resnet50_cross_entropy_100.model
Successfully downloaded TinyImageNet/resnet50_cross_entropy_100.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 50.19

wget -q -P model_weights/TinyImageNet https://www.robots.ox.ac.uk/~viveka/focal_calibration/TinyImageNet/resnet50_cross_entropy_smoothed_smoothing_0.05_100.model
Successfully downloaded TinyImageNet/resnet50_cross_entropy_smoothed_smoothing_0.05_100.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 52.88

wget -q -P model_weights/TinyImageNet https://www.robots.ox.ac.uk/~viveka/focal_calibration/TinyImageNet/resnet50_focal_loss_gamma_1.0_100.model
Successfully downloaded TinyImageNet/resnet50_focal_loss_gamma_1.0_100.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 49.94

wget -q -P model_weights/TinyImageNet https://www.robots.ox.ac.uk/~viveka/focal_calibration/TinyImageNet/resnet50_focal_loss_gamma_2.0_100.model
Successfully downloaded TinyImageNet/resnet50_focal_loss_gamma_2.0_100.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 52.30

wget -q -P model_weights/TinyImageNet https://www.robots.ox.ac.uk/~viveka/focal_calibration/TinyImageNet/resnet50_focal_loss_gamma_3.0_100.model
Successfully downloaded TinyImageNet/resnet50_focal_loss_gamma_3.0_100.model
<All keys matched successfully>


eval:   0%|          | 0/100 [00:00<?, ?it/s]

accuracy: 50.31



In [20]:
# Persist the labels collected by the last evaluation pass, one per line (no trailing newline).
with open(f'../focal/{dataset_name}/targets.txt', 'w') as fout:
    fout.write('\n'.join(all_labels.numpy().astype(str).tolist()))

In [21]:
os.chdir('..')
!zip -r logits_focal.zip focal
!rm -rf focal_calibration focal

  adding: focal/ (stored 0%)
  adding: focal/TinyImageNet/ (stored 0%)
  adding: focal/TinyImageNet/targets.txt (deflated 62%)
  adding: focal/TinyImageNet/resnet50_cross_entropy_smoothed_smoothing_0.05.pt (deflated 7%)
  adding: focal/TinyImageNet/resnet50_focal_loss_gamma_2.0.pt (deflated 7%)
  adding: focal/TinyImageNet/resnet50_focal_loss_gamma_3.0.pt (deflated 7%)
  adding: focal/TinyImageNet/resnet50_focal_loss_gamma_1.0.pt (deflated 7%)
  adding: focal/TinyImageNet/resnet50_cross_entropy.pt (deflated 7%)
  adding: focal/CIFAR100/ (stored 0%)
  adding: focal/CIFAR100/densenet121_focal_loss_gamma_3.0.pt (deflated 8%)
  adding: focal/CIFAR100/targets.txt (deflated 61%)
  adding: focal/CIFAR100/resnet50_cross_entropy_smoothed_smoothing_0.05.pt (deflated 8%)
  adding: focal/CIFAR100/resnet110_focal_loss_gamma_2.0.pt (deflated 8%)
  adding: focal/CIFAR100/wide_resnet_cross_entropy_smoothed_smoothing_0.05.pt (deflated 9%)
  adding: focal/CIFAR100/densenet121_cross_entropy