In [1]:
import torch, os
import torch.nn as nn
import torchvision.datasets as datasets
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.models as models
import numpy as np
import random
import torch.utils.data as torchdata
from torch.utils.data import SubsetRandomSampler
# from train_model import train_model
# from test_model import test_model
%matplotlib inline

In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]='3'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
filePath = '/local/rcs/ll3504/datasets/256_ObjectCategories/'
namelist = os.listdir(filePath)
nameDic_cal = {}
for name in namelist:
    splits = name.split(".")
    nameDic_cal[int(splits[0])-1] = splits[1]
print(nameDic_cal[1])

american-flag


In [4]:
def get_dataset(path='/database', dataset_name='caltech-256-common'):
    # No holdout testing data. train and test data are the same, but different transformation
    data_transforms = {
        'train': transforms.Compose([
#             transforms.Resize([256, 256]),
#             transforms.RandomCrop(224),
#             transforms.RandomRotation(20),
#             transforms.RandomHorizontalFlip(0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ]),
        'test': transforms.Compose([
#             transforms.Resize([224, 224]),
#             transforms.CenterCrop((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    }

    tr_dataset = datasets.ImageFolder(path + dataset_name + '/', data_transforms['train'])
    te_dataset = datasets.ImageFolder(path + dataset_name + '/', data_transforms['test'])
#     print('{} train set size: {}'.format(dataset_name, len(tr_dataset)))
#     print('{} test set size: {}'.format(dataset_name, len(te_dataset)))

    return tr_dataset, te_dataset

In [5]:
def split_dataset(train_dataset, test_dataset, valid_size=0.2, batch_size=128, train_size = 128):
    '''
    This function splits dataset into train, val, and test sets, and return train, val, test dataloaders.
    Val and Test loaders are the same

    '''
    
    # what does the len function gives?
    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))
    random.shuffle(indices)
    train_idx, valid_idx = indices[split:split+train_size], indices[:split]
#     print("DEBUGGING: train_idx =", train_idx, "valid_idx =", valid_idx)
    
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)
    
#     print("DEBUGGING: the train_ind are:", len(train_idx))


    train_loader = torchdata.DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=48, pin_memory=True, drop_last=True, sampler = train_sampler)
    test_loader = torchdata.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=48, pin_memory=True, drop_last=True, sampler = valid_sampler)
    dataloaders = {'train': train_loader,
                   'val': test_loader,
                   'test': test_loader}
    dataset_sizes ={'train': train_size, #int(np.floor((1-valid_size) * num_train)),
                    'val': int(np.floor(valid_size * num_train)),
                    'test': int(np.floor(valid_size * num_train))}
    return dataloaders, dataset_sizes

In [6]:
imagebase = '/local/rcs/ll3504/datasets/'

In [7]:
corruption = ['zoom_blur', 'speckle_noise', 'spatter',
                       'snow', 'glass_blur', 'motion_blur', 'saturate',
                       'gaussian_blur', 'frost', 'fog', 'brightness', 'contrast',
                       'elastic_transform', 'pixelate', 'jpeg_compression', 'defocus_blur']

# ImageNetC

In [8]:
def get_imagenetc(imagebase, severity=1, batch_size=128, sample_size = 128):
    '''
    Returns:
        ref_dataloaders:          ImageNet original validation data, as a reference
        ref_dataset_sizes:        1000, not the sizes of the real dataset in the ref_loader, probs used downstream
        corrupted_dataloaders:    A list of corrupted dataloaders, each element in a list represetns the data loaders
                                  for one corruption type. Each element contains ['train']['val']['test'] loaders
        corrupted_dataset_sizes:  A list of dictionaries of the sizes of each loaders for each corruption
        corruption:               A list of corruption names, in the same order of the corrupted_dataloaders
    '''
    corruption = ['zoom_blur', 'speckle_noise', 'spatter',
                       'snow', 'glass_blur', 'motion_blur', 'saturate',
                       'gaussian_blur', 'frost', 'fog', 'brightness', 'contrast',
                       'elastic_transform', 'pixelate', 'jpeg_compression', 'defocus_blur']
    corrupted_dataloaders = []
    corrupted_dataset_sizes = []
    
    # this is the imageNet validation data
    imagenet_val = datasets.ImageNet(imagebase+'imagenetc/', split='val', transform=transforms.Compose([
            transforms.Resize([224, 224]),
            transforms.CenterCrop((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ]),
                                   target_transform=None)#, download=False)
    
    # TODO: subsample some size of ImageNet training data as source
        # Doesn't need this step
#     print("DEBUGGING: imagenet_val size is:", len(imagenet_val))
    
    random_indices = random.sample(range(0, len(imagenet_val)), int(len(imagenet_val)*0.02))
#     print("DEBUGGING: random indices are:", len(random_indices))
    imagenet_val_subset = data.Subset(imagenet_val, random_indices)
    val_loader = torch.utils.data.DataLoader(imagenet_val_subset,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=48)
    ref_dataloaders = { 'val': val_loader,
                       'test': val_loader}
    ref_dataset_sizes ={'val': int(len(val_loader.dataset)),
                        'test': int(len(val_loader.dataset))}
    
    # for every type of corruption, go to the specified severity folder
    for corr in corruption:
        dataset_name = 'imagenetc/' + corr + '/' + str(severity)
        # Get dataset from folder
        corr_trian_images, corr_test_images = get_dataset(imagebase, dataset_name)
        # get corruption-specific train, val, test loader
        # train: training data, non-overlap with val/test
        # val: non-overlap with train, same as test
        # test: non-overlap with train, same as test
        corr_dataloaders, corr_dataset_sizes = split_dataset(corr_trian_images, corr_test_images, valid_size=0.02, batch_size=batch_size, train_size=sample_size)
        corrupted_dataloaders.append(corr_dataloaders)
        corrupted_dataset_sizes.append(corr_dataset_sizes)
    return ref_dataloaders, ref_dataset_sizes, corrupted_dataloaders, corrupted_dataset_sizes, corruption

In [None]:
ref_dataloaders, ref_dataset_sizes, corrupted_dataloaders, corrupted_dataset_sizes, corruption \
= get_imagenetc(imagebase, 5, 32, 64)



In [None]:
# in total 16 corruptions, each have 3 dataloaders ['train']['val']['test']
len(corrupted_dataloaders[15]['val']) # why is this 31??

In [None]:
corrupted_dataset_sizes[0]

In [None]:
ref_dataset_sizes['val']

In [None]:
len(ref_dataloaders['val'])

# Vanila Resnet-50 Baseline

Test top-1 accuracy for all corruptions, for different batch sizes, average accuracy over batches, (inference time)
on reference data (IM-val) and corruption test data


In [None]:
import time

In [None]:
batchsizes = [32, 64, 128, 256, 512, 1000]
severity = [1,2,3,4,5]
# a dictionary that maps 'coruption_type' to a table of 5x6, 
# recording the performance of all corruption for all batchsizes
# baseline_performances = {}

# # in addition to all the corruptions, save performance for reference (IM-val) too
# baseline_performances['ref'] = np.zeros(shape=(len(severity), len(batchsizes)))
# for c in corruption:
#     baseline_performances[c] = np.zeros(shape=(len(severity), len(batchsizes)))

baseline_performances = np.zeros(shape=(len(corruption), len(severity), len(batchsizes)))
baseline_inf_times = np.zeros(shape=(len(corruption), len(severity), len(batchsizes)))



In [None]:
model = models.resnet50(pretrained=True)
model.eval()
model.to(device)

In [None]:
for severity_ind in range(len(severity)):
    for bs_ind in range(len(batchsizes)):
        ref_dataloaders, ref_dataset_sizes, corrupted_dataloaders, corrupted_dataset_sizes, corruption \
        = get_imagenetc(imagebase, severity[severity_ind], batchsizes[bs_ind])
        for cor_ind in range(len(corrupted_dataloaders)):
            with torch.no_grad():
                # TOOD: also record average batch inference time.
                batch_accs = []
                batch_inf_times = []
                for data in corrupted_dataloaders[cor_ind]['test']:
                    start = time.time()
                    images, labels = data[0].to(device), data[1].to(device)
                    outputs = model(images)
                    _, predicted = torch.max(outputs.data, 1)
                    inference_time = time.time() - start
                    batch_inf_times.append(inference_time)
                    batch_accuracy = ((predicted == labels).sum().item()) / labels.size(0)
                    batch_accs.append(batch_accuracy)
            print(f'Vanilla resnet50 on 1000 test images of corruption {corruption[cor_ind]} with severity {severity[severity_ind]} and batch_size {batchsizes[bs_ind]}, average batch_size = {np.mean(batch_accs)}')
            baseline_performances[cor_ind, severity_ind, bs_ind] = np.mean(batch_accs)
            print(f'Average Inference time = {np.mean(batch_inf_times)}')
            baseline_inf_times[cor_ind, severity_ind, bs_ind] = np.mean(batch_inf_times)
            
        with open('ResNet50_performances.npy', 'wb') as f:
            np.save(f, baseline_performances)
        with open('ResNet50_times.npy', 'wb') as f:
            np.save(f, baseline_inf_times)
        
            
            
            

In [None]:
# baseline_performances
baseline_inf_times

In [None]:
baseline_inf_times

In [None]:
with open('ResNet50_performances.npy', 'wb') as f:
    np.save(f, baseline_performances)
with open('ResNet50_times.npy', 'wb') as f:
    np.save(f, baseline_inf_times)

In [6]:
with open('ResNet50_performances.npy', 'rb') as f:
    baseline_performances = np.load(f)
with open('ResNet50_times.npy', 'rb') as f:
    baseline_inf_times = np.load(f)


In [15]:
1-baseline_performances[0,:,0].mean()

0.6334677419354839

## Graphing the performance

# Robust Pseudo label

In [11]:
import time

In [12]:
import torch
from torchvision import models, datasets, transforms

def gce(logits, target, q = 0.8):
    """ Generalized cross entropy.
    
    Reference: https://arxiv.org/abs/1805.07836
    """
    probs = torch.nn.functional.softmax(logits, dim=1)
    probs_with_correct_idx = probs.index_select(-1, target).diag()
    loss = (1. - probs_with_correct_idx**q) / q
    return loss.mean()

def adapt_batchnorm(model):
    model.eval()
    parameters = []
    for module in model.modules():
        if isinstance(module, torch.nn.BatchNorm2d):
            parameters.extend(module.parameters())
            module.train()
    return parameters


# ---



In [None]:
def adapt(
#         datadir = '/data/imagenetc/gaussian_blur/3',
        baseline,
        model,
        dataloader,
        num_epochs = 1, # followed their findings in the paper
        batch_size = 32,
        learning_rate = 1e-3,
        gce_q = 0.8,
    ):
    
    # model = models.resnet50(pretrained = True).to(device)
    parameters = adapt_batchnorm(model)
    
    # TODO change this loader
    # val_loader = get_dataset_loader(
    #    datadir,
    #    batch_size = batch_size,
    #    shuffle = True
    # )
    
    optimizer = torch.optim.SGD(
        parameters, lr = learning_rate
    )
    
    b_correct, num_correct, num_samples = 0., 0., 0.
    for epoch in range(num_epochs):
        predictions = []
        batch_accs = []
        for images, labels in dataloader:
            # start = time.time()
            
            outputs = model(images.to(device))
            predictions = outputs.argmax(dim = 1)
            # _, predicted = torch.max(outputs.data, 1)

            # inference_time = time.time() - start
            
            # TODO: in our scenario, do we want to revert back to original model after adapting in each step?
            loss = gce(outputs, predictions, q = gce_q)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            b_outputs = baseline(images.to(device))
            b_predictions = b_outputs.argmax(dim = 1)

            num_correct += (predictions.detach().cpu() == labels).float().sum()
            b_correct += (b_predictions.detach().cpu() == labels).float().sum()

            num_samples += len(labels)

            print(f"Baseline Correct: {b_correct:#5.0f}/{num_samples:#5.0f} ({100 * b_correct / num_samples:.2f} %)")
            print(f"Adapt Correct: {num_correct:#5.0f}/{num_samples:#5.0f} ({100 * num_correct / num_samples:.2f} %)")
    
    return num_correct/num_samples, b_correct/num_samples
                    


In [None]:
target_sizes = [32, 64, 128, 256, 512]
severity = [1,2,3,4,5]

rpl_corrupt_train_acc = np.zeros(shape=(len(corruption), len(severity), len(target_sizes)))
rpl_corrupt_val_acc = np.zeros(shape=(len(corruption), len(severity), len(target_sizes)))
rpl_ref_acc = np.zeros(shape=(len(corruption), len(severity), len(target_sizes)))


# rpl_adpt_times = np.zeros(shape=(len(corruption), len(severity), len(target_sizes)))
baseline_train_acc = np.zeros(shape=(len(corruption), len(severity), len(target_sizes)))
baseline_val_acc = np.zeros(shape=(len(corruption), len(severity), len(target_sizes)))
baseline_ref_acc = np.zeros(shape=(len(corruption), len(severity), len(target_sizes)))





In [None]:
baseline = models.resnet50(pretrained=True)
baseline.to(device)
baseline.eval()
for severity_ind in range(len(severity)):
    for ts_ind in range(len(target_sizes)):
        ref_dataloaders, ref_dataset_sizes, corrupted_dataloaders, corrupted_dataset_sizes, corruption \
        = get_imagenetc(imagebase, severity[severity_ind], 32, target_sizes[ts_ind])
        
        # adapt the model using the rpl methods
        # evaluate the adapted model on the test corrupted (target) dataset & the reference (source) dataset
        for cor_ind in range(len(corrupted_dataloaders)):
            
            print(f'## Experiment: Severity = {severity[severity_ind]}, target_size = {target_sizes[ts_ind]}, corruption = {corruption[cor_ind]}')
            model = models.resnet50(pretrained=True)
            model.to(device)
            
            # adapt the model
            # start = time.time()
            train_acc, base_acc = adapt(baseline, model, corrupted_dataloaders[cor_ind]['train'])
            # adapt_time = time.time() - start
            # print(f"Adaptation time: {adapt_time}")
            
            # rpl_adpt_times[cor_ind, severity_ind, ts_ind] = adapt_time
            rpl_corrupt_train_acc[cor_ind, severity_ind, ts_ind] = train_acc
            baseline_train_acc[cor_ind, severity_ind, ts_ind] = base_acc
            
            # evaluate both the adapted model & baseline model on the, corrupted val data, and reference (soure) data
            # In validation step I treat the method as offline adaptation, then set BN layers to eval()
            model.eval()
            with torch.no_grad():
                base_val_correct, crpt_val_correct, crpt_val_samples, base_ref_correct, ref_correct, ref_samples = 0, 0, 0, 0, 0, 0
                for images, labels in corrupted_dataloaders[cor_ind]['val']:
                    labels = labels.to(device)
                    outputs = model(images.to(device))
                    _, predicted = torch.max(outputs.data, 1)
                    crpt_val_correct += (predicted == labels).sum().item()
                    crpt_val_samples += len(labels)
                    
                    b_outputs = baseline(images.to(device))
                    _, b_predicted = torch.max(b_outputs.data, 1)
                    base_val_correct += (b_predicted == labels).sum().item()
                    
                    
                print(f"Corrupt Val Adapt Accuracy: {crpt_val_correct:#5.0f}/{crpt_val_samples:#5.0f} ({crpt_val_correct / crpt_val_samples})")
                print(f"Corrupt Val Base Accuracy: {base_val_correct:#5.0f}/{crpt_val_samples:#5.0f} ({base_val_correct / crpt_val_samples})")
                
                rpl_corrupt_val_acc[cor_ind, severity_ind, ts_ind] = crpt_val_correct / crpt_val_samples
                baseline_val_acc[cor_ind, severity_ind, ts_ind] = base_val_correct / crpt_val_samples
                

                for images, labels in ref_dataloaders['val']:
                    labels = labels.to(device)
                    outputs = model(images.to(device))
                    _, predicted = torch.max(outputs.data, 1)
                    ref_correct += (predicted == labels).sum().item()
                    ref_samples += len(labels)
                    
                    b_outputs = baseline(images.to(device))
                    _, b_predicted = torch.max(b_outputs.data, 1)
                    base_ref_correct += (b_predicted == labels).sum().item()
                    
                print(f"Ref Adapt Accuracy: {ref_correct:#5.0f}/{ref_samples:#5.0f} ({ref_correct / ref_samples})")
                print(f"Ref Base Accuracy: {base_ref_correct:#5.0f}/{ref_samples:#5.0f} ({base_ref_correct / ref_samples})")
                
                rpl_ref_acc[cor_ind, severity_ind, ts_ind] = ref_correct / ref_samples
                baseline_ref_acc[cor_ind, severity_ind, ts_ind] = base_ref_correct / ref_samples


        with open('RPL_corrupt_train_acc.npy', 'wb') as f:
            np.save(f, rpl_corrupt_train_acc)
        with open('RPL_corrupt_validation_acc.npy', 'wb') as f:
            np.save(f, rpl_corrupt_val_acc)
        with open('RPL_source_acc.npy', 'wb') as f:
            np.save(f, rpl_ref_acc)
            
        with open('Base_corrupt_train_acc.npy', 'wb') as f:
            np.save(f, baseline_train_acc)
        with open('Base_corrupt_validation_acc.npy', 'wb') as f:
            np.save(f, baseline_val_acc)
        with open('Base_source_acc.npy', 'wb') as f:
            np.save(f, baseline_ref_acc)
#         with open('RPL_adapt_time.npy', 'wb') as f:
#             np.save(f, rpl_adpt_times)
        
        
            
            
            

In [None]:
with open('RPL_corrupt_train_acc.npy', 'wb') as f:
    np.save(f, rpl_corrupt_train_acc)
with open('RPL_corrupt_validation_acc.npy', 'wb') as f:
    np.save(f, rpl_corrupt_val_acc)
with open('RPL_source_acc.npy', 'wb') as f:
    np.save(f, rpl_ref_acc)

with open('Base_corrupt_train_acc.npy', 'wb') as f:
    np.save(f, baseline_train_acc)
with open('Base_corrupt_validation_acc.npy', 'wb') as f:
    np.save(f, baseline_val_acc)
with open('Base_source_acc.npy', 'wb') as f:
    np.save(f, baseline_ref_acc)


In [None]:
with open('RPL_corrupt_train_acc.npy', 'rb') as f:
    rpl_corrupt_train_acc = np.load(f)
with open('RPL_corrupt_validation_acc.npy', 'rb') as f:
    rpl_corrupt_val_acc = np.load(f)
with open('RPL_target_acc.npy', 'rb') as f:
    rpl_ref_acc = np.load(f)
# with open('RPL_adapt_time.npy', 'rb') as f:
#     rpl_adpt_times = np.load(f)
    
with open('Base_corrupt_train_acc.npy', 'rb') as f:
    baseline_train_acc = np.load(f)
with open('Base_corrupt_validation_acc.npy', 'rb') as f:
    baseline_val_acc = np.load(f)
with open('Base_source_acc.npy', 'rb') as f:
    baseline_ref_acc = np.load(f)



In [None]:
rpl_corrupt_val_acc

In [None]:
baseline_val_acc

In [None]:
rpl_corrupt_train_acc

In [None]:
baseline_train_acc

## Comparison with TENT

In [None]:
# validation accuracy averaged over all corruptions and all batch sizes
# sev1, 2, 3, 4, 5
rpl_corrupt_val_acc.transpose(2, 1, 0).mean(axis=0).mean(axis=1)

In [None]:
# validation accuracy (batch size = 32) averaged over all corruptions
# sev1, 2, 3, 4, 5
rpl_corrupt_val_acc.transpose(2, 1, 0)[0].mean(axis=1)

In [None]:
# validation error rate averaged over all corruptions and all batch sizes, shuffled
# sev5, 4, 3, 2, 1
(1 - rpl_corrupt_val_acc.transpose(2, 1, 0).mean(axis=0).mean(axis=1))[::-1]

In [None]:
# validation error rate (batch size = 32) averaged over all corruptions, shuffled
# sev5, 4, 3, 2, 1
1 - rpl_corrupt_val_acc.transpose(2, 1, 0)[0].mean(axis=1)[::-1]

In [None]:
# training / online error rate averaged over all corruptions and all batch sizes, shuffled
# sev5, 4, 3, 2, 1
(1 - rpl_corrupt_train_acc.transpose(2, 1, 0).mean(axis=0).mean(axis=1))[::-1]

In [None]:
# baseline randomed error rate
# sev5, 4, 3, 2, 1
(1 - baseline_performances.transpose(2, 1, 0).mean(axis=0).mean(axis=1))[::-1]

## Comparison with RPL Paper

In [None]:
print(corruption)
1-rpl_corrupt_val_acc.mean(axis=1).mean(axis=1)

In [None]:
1 - baseline_performances.mean(axis=1).mean(axis=1)

# Graphs

In [None]:
import matplotlib.pyplot as plt

In [None]:
with open('ResNet50_performances.npy', 'rb') as f:
    baseline_performances = np.load(f)
with open('ResNet50_times.npy', 'rb') as f:
    baseline_inf_times = np.load(f)
# with open('RPL_corrupt_train_acc.npy', 'rb') as f:
#     rpl_corrupt_train_acc = np.load(f)
# with open('RPL_corrupt_validation_acc.npy', 'rb') as f:
#     rpl_corrupt_val_acc = np.load(f)
# with open('RPL_target_acc.npy', 'rb') as f:
#     rpl_ref_acc = np.load(f)
# with open('RPL_adapt_time.npy', 'rb') as f:
#     rpl_adpt_times = np.load(f)


In [None]:
# severity 5, 16 graphs for 16 corruptions, 
# 6 lines for 6 batch sizes, an extra line for baseline corruption, an extra line for ref

In [None]:
target_sizes = [32, 64, 128, 256, 512]
severity = [1,2,3,4,5]

In [None]:
fig, ax = plt.subplots(4,4, figsize=(18, 18))
fig.suptitle('Accuracy on corrupted (severity = 3) training set')
for cor_ind in range(len(corruption)):
    ax[cor_ind//4][cor_ind%4].plot(target_sizes, baseline_train_acc[cor_ind][2], label = 'Resnet50 acc train set')
    ax[cor_ind//4][cor_ind%4].plot(target_sizes, rpl_corrupt_train_acc[cor_ind][2], label = 'Adapted train acc')
    
    ax[cor_ind//4][cor_ind%4].legend()
    ax[cor_ind//4][cor_ind%4].set_title('Corruption: '+corruption[cor_ind])
    ax[cor_ind//4][cor_ind%4].set_xticks(target_sizes)
    ax[cor_ind//4][cor_ind%4].set_ylim([0, 0.7])


    
for a in ax.flat:
    a.set(xlabel='Target adaptation batch size', ylabel='Accuracy')
    
# Hide x labels and tick labels for top plots and y ticks for right plots.
for a in ax.flat:
    a.label_outer()

In [None]:
fig, ax = plt.subplots(4,4, figsize=(18, 18))
fig.suptitle('Accuracy on corrupted (severity = 3) validation set')
for cor_ind in range(len(corruption)):
    ax[cor_ind//4][cor_ind%4].plot(target_sizes, baseline_val_acc[cor_ind][2], label = 'Resnet50 val acc')
    ax[cor_ind//4][cor_ind%4].plot(target_sizes, rpl_corrupt_val_acc[cor_ind][2], label='Adapted val acc')
    
    ax[cor_ind//4][cor_ind%4].legend()
    ax[cor_ind//4][cor_ind%4].set_title('Corruption: '+corruption[cor_ind])
    ax[cor_ind//4][cor_ind%4].set_xticks(target_sizes)
    ax[cor_ind//4][cor_ind%4].set_ylim([0, 0.9])


    
for a in ax.flat:
    a.set(xlabel='Target adaptation batch size', ylabel='Accuracy')
    
# Hide x labels and tick labels for top plots and y ticks for right plots.
for a in ax.flat:
    a.label_outer()

In [None]:
fig, ax = plt.subplots(4,4, figsize=(18, 18))
fig.suptitle('Accuracy on corrupted (severity = 3) train & validation set')
for cor_ind in range(len(corruption)):
    ax[cor_ind//4][cor_ind%4].plot(target_sizes, baseline_val_acc[cor_ind][2], label = 'Resnet50 val acc')
    ax[cor_ind//4][cor_ind%4].plot(target_sizes, rpl_corrupt_val_acc[cor_ind][2], label='Adapted val acc')
    ax[cor_ind//4][cor_ind%4].plot(target_sizes, baseline_performances[cor_ind][2][:-1], label = 'Resnet50_random')
    
    ax[cor_ind//4][cor_ind%4].plot(target_sizes, baseline_train_acc[cor_ind][2], label = 'Resnet50 train acc')
    ax[cor_ind//4][cor_ind%4].plot(target_sizes, rpl_corrupt_train_acc[cor_ind][2], label = 'Adapted train acc')
    
    ax[cor_ind//4][cor_ind%4].legend()
    ax[cor_ind//4][cor_ind%4].set_title('Corruption: '+corruption[cor_ind])
    ax[cor_ind//4][cor_ind%4].set_xticks(target_sizes)
    ax[cor_ind//4][cor_ind%4].set_ylim([0, 0.9])


    
for a in ax.flat:
    a.set(xlabel='Target adaptation batch size', ylabel='Accuracy')
    
# Hide x labels and tick labels for top plots and y ticks for right plots.
for a in ax.flat:
    a.label_outer()