In [15]:
## Training loop for transferring to CIFAR100 or SVHN

# clear memory
from IPython import get_ipython
get_ipython().magic('reset -sf') 

import numpy as np
import torch
import time
timer = 0

from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler

import torch.nn as nn
import torch.nn.functional as F

import torch.nn as nn

from torch.utils.tensorboard import SummaryWriter



import os, random

# import argparse

# argument_parser = argparse.ArgumentParser()

# argument_parser.add_argument("--lr_init", type=float, help="Initial learning rate value, default=0.01. CAREFUL: this will be divided by beta, since the ERM term is multiplied by beta in the objective.")

# parsed_args = argument_parser.parse_args()


# Make sure validation splits are the same at all time (e.g. even after loading)
seed = 0

def seed_init_fn(seed=seed):
   np.random.seed(seed)
   random.seed(seed)
   torch.manual_seed(seed)
   return

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

num_workers = 0
# Make sure test_data is a multiple of batch_size_test
batch_size_train_and_valid = 128
batch_size_test = 200

# proportion of full training set used for validation
valid_size = 0.2














class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
#         print(x.size(), out.size())
        return out


def ResNet18():
    return ResNet(BasicBlock, [2, 2, 2, 2])


def prepare_model_for_finetuning(model, model_path, model_filename, dataset_name, num_output_classes, device=device):
    is_dataparallel = False
    if str(device) == "cuda" and torch.cuda.device_count() > 1:
        print("Using DataParallel")
        model = torch.nn.DataParallel(model)
        is_dataparallel = True
        
    checkpoint = torch.load(model_path)

    model_successfully_loaded = False
    try: 
        # Some models have been saved as DataParallel. For those we need to load to model, for others to model.module. Lessons
        # for the next project: always save module. :) Although this would still be a problem for for loop such as here as even
        # if I DataParallel in the for loop, the next iteration of the for loop tries to load non-DataParallel to a DataParalleled model
        try:
            model.load_state_dict(checkpoint['current_model'])
            print("Successfully loaded onto model.")
        except:
            print("Failed to load model onto model, attempting to load onto model.module...")
            model.module.load_state_dict(checkpoint['current_model'])
            print("Successfully loaded onto model.module.")
        model_successfully_loaded = True
        print("model_successfully_loaded:", model_successfully_loaded, flush=True)
    except:
        print("Model not stored on 'current_model' key")
        model_successfully_loaded = False
    if not model_successfully_loaded:
        try:
            # Loading the PAT model is slightly different
            checkpoint = torch.load(model_path)
            try:
                model.load_state_dict(checkpoint['model'])
            except:
                print("Failed to load model onto model, attempting to load onto model.module...")
                model.module.load_state_dict(checkpoint['model'])
                print("Successfully loaded onto model.module.")
        except:
            raise ValueError("Did not succeed in loading model -- check the key used to store model in the checkpoint loaded.")
        

    if is_dataparallel:
        for param in model.module.parameters():
            param.requires_grad = False
        model.module.linear = nn.Linear(512 * BasicBlock.expansion, num_output_classes)
    else:
        for param in model.parameters():
            param.requires_grad = False
        model.linear = nn.Linear(512 * BasicBlock.expansion, num_output_classes)        


    model.to(device)
    print("Loaded model ", model_filename, " on " + dataset_name)
    return model


def prepare_train_and_valid_dataloader(dataset_name, batch_size_train_and_valid=batch_size_train_and_valid, seed=seed):
    if dataset_name == "CIFAR100":
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762))])

        train_and_valid_data = datasets.CIFAR100(root = "data", train = True, download = True, transform = transform_train)

        lr_init = 0.1
        num_output_classes = 100

    elif dataset_name == "SVHN":
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 

        train_and_valid_data = datasets.SVHN(root = "data", split="train", download = True, transform = transform_train)

        lr_init = 0.1
        num_output_classes = 10

    else:
        raise ValueError("Unsupported dataset name. Supported datasets are CIFAR100, SVHN. You entered:", dataset_name)






    num_valid_samples = int(np.floor(valid_size * len(train_and_valid_data)))
    num_train_samples = len(train_and_valid_data) - num_valid_samples
    train_data, valid_data = torch.utils.data.random_split(train_and_valid_data, [num_train_samples, num_valid_samples], generator=torch.Generator().manual_seed(seed))

    train_loader = torch.utils.data.DataLoader(train_data, batch_size = batch_size_train_and_valid)
    valid_loader = torch.utils.data.DataLoader(valid_data, batch_size = batch_size_train_and_valid)
    return train_loader, valid_loader, lr_init, num_output_classes































    








# Careful: this is about loading pretrained models, not checkpoints of the model during CIFAR100 training !
WORKING_DIR = "results/CIFAR10/"
TRAINED_MODEL_PATH = WORKING_DIR + "models/"
for root, dirs, files in os.walk(TRAINED_MODEL_PATH):
    model_filenames = files
    model_paths = [TRAINED_MODEL_PATH + file for file in files]


    
num_epochs = 30    
# Controls how many time we repeat finetuning/training of models to get avg and std since it's not too expensive
num_training_loops = 3

dataset_names = ["SVHN", "CIFAR100"]

for dataset_name in dataset_names:
    TRAINING_OUTPUT_ROOT = "experiments/" + dataset_name + '/'
    for model_num, model_path in enumerate(model_paths):
        for training_loop in range(0, num_training_loops):
            TRAINING_OUTPUT_PATH = TRAINING_OUTPUT_ROOT + model_filenames[model_num] + '/'
            os.makedirs(TRAINING_OUTPUT_PATH, exist_ok=True)

            # Make sure validation splits are the same at all time (e.g. even after loading)
            seed = training_loop
            train_loader, valid_loader, lr_init, num_output_classes = prepare_train_and_valid_dataloader(dataset_name, batch_size_train_and_valid=batch_size_train_and_valid, seed=seed)

            model = ResNet18()
            model.to(device)
            model = prepare_model_for_finetuning(model, model_path, model_filenames[model_num], dataset_name, num_output_classes, device)


            writer = SummaryWriter(TRAINING_OUTPUT_PATH, comment="_loop_"+str(training_loop))

            optimizer = torch.optim.SGD(model.parameters(), lr = lr_init, momentum=0.9, weight_decay=5e-4)
            schedule = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20], gamma=0.1)

            valid_loss_min = np.Inf
            best_epoch = 0


            for epoch in range(0, num_epochs):
                train_loss = 0
                num_train_correct_preds = 0
                valid_loss = 0
                num_valid_correct_preds = 0

                ###################
                # Train the model #
                ###################
                model.train()
                for data, label in train_loader:
                    data, label = data.to(device), label.to(device)

                    optimizer.zero_grad()
                    preds = model(data)
                    loss = F.cross_entropy(preds, label)
                    loss.backward()
                    optimizer.step()
                    train_loss += loss.item() * data.size(0)
                    num_train_correct_preds += (torch.argmax(preds, dim=1) == label).sum().item()






                ######################    
                # Validate the model #
                ######################
                model.eval()
                with torch.no_grad():
                    for _, (data, label) in enumerate(valid_loader):
                        data, label = data.to(device), label.to(device)
                        preds = model(data)
                        loss = F.cross_entropy(preds, label)
                        valid_loss += loss.item() * data.size(0)
                        num_valid_correct_preds += (torch.argmax(preds, dim=1) == label).sum().item()


                # Average loss over epoch. Careful about computations where we average over batches; they will have a bias if dataset size not multiple of batch_size
                train_loss = train_loss / len(train_loader.sampler)
                # Handling validation terms differently based on stopping early or not (because when using the full set its size may not be divisible by batch_size) !
                valid_loss = valid_loss / len(valid_loader.sampler)
                training_acc = num_train_correct_preds / len(train_loader.sampler)
                valid_acc = num_valid_correct_preds / len(valid_loader.sampler)

                print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
                    epoch+1, 
                    train_loss,
                    valid_loss
                    ))

                print('Epoch: {} \tTraining accuracy: {:.6f} \tValidation accuracy: {:.6f}'.format(
                    epoch+1, 
                    training_acc,
                    valid_acc
                    ))

                epoch_lr = schedule.get_last_lr()[0]
                schedule.step()

                if valid_loss <= valid_loss_min:
                    print('Validation loss decreased ({:.6f} --> {:.6f}). '.format(
                    valid_loss_min,
                    valid_loss))
                    best_epoch = epoch
                    valid_loss_min = valid_loss


                    path_of_checkpoint = TRAINING_OUTPUT_PATH + dataset_name + '_loop_' + str(training_loop) + '.pt'

                    # lr to start from in checkpoint
                    lr_init = schedule.get_last_lr()[0]

                    checkpoint = {'current_model': model.module.state_dict(),
                                  'optimiser': optimizer.state_dict(),
                                  'schedule': schedule.state_dict(),
                                  'learning_rate': lr_init,
                                  'epoch': epoch + 1,
                                  'best_epoch': best_epoch,
                                  'seed': seed
                                 }

                    torch.save(checkpoint, path_of_checkpoint)

                writer.add_scalar('Learning_rate', epoch_lr, epoch+1)


                writer.add_scalar('Training_loss', train_loss, epoch+1)
                writer.add_scalar('Validation_loss', valid_loss, epoch+1)

                writer.add_scalar('Training_accuracy', training_acc, epoch+1)
                writer.add_scalar('Validation_accuracy', valid_acc, epoch+1)

            writer.close()


Using downloaded and verified file: data/train_32x32.mat
Using DataParallel
Failed to load model onto model, attempting to load onto model.module...
Successfully loaded onto model.module.
model_successfully_loaded: True
Starting training of model  model_MSD_ERM_51.pt  on SVHN
Epoch: 1 	Training Loss: 1.861225 	Validation Loss: 1.737258
Epoch: 1 	Training accuracy: 0.364570 	Validation accuracy: 0.415057
Validation loss decreased (inf --> 1.737258). 
Epoch: 2 	Training Loss: 1.711976 	Validation Loss: 1.675906
Epoch: 2 	Training accuracy: 0.425622 	Validation accuracy: 0.440448
Validation loss decreased (1.737258 --> 1.675906). 
Epoch: 3 	Training Loss: 1.671034 	Validation Loss: 1.648605
Epoch: 3 	Training accuracy: 0.440552 	Validation accuracy: 0.451915
Validation loss decreased (1.675906 --> 1.648605). 
Epoch: 4 	Training Loss: 1.645640 	Validation Loss: 1.636547
Epoch: 4 	Training accuracy: 0.450193 	Validation accuracy: 0.455464
Validation loss decreased (1.648605 --> 1.636547). 


KeyboardInterrupt: 

In [14]:
## Eval loop for fine tuned models on CIFAR100 or SVHN

# clear memory
from IPython import get_ipython
get_ipython().magic('reset -sf') 

import numpy as np
import torch
import time
timer = 0

from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler

import torch.nn as nn
import torch.nn.functional as F

import torch.nn as nn

from torch.utils.tensorboard import SummaryWriter



import os, random

# import argparse

# argument_parser = argparse.ArgumentParser()

# argument_parser.add_argument("--lr_init", type=float, help="Initial learning rate value, default=0.01. CAREFUL: this will be divided by beta, since the ERM term is multiplied by beta in the objective.")

# parsed_args = argument_parser.parse_args()


# Make sure validation splits are the same at all time (e.g. even after loading)
seed = 0

def seed_init_fn(seed=seed):
   np.random.seed(seed)
   random.seed(seed)
   torch.manual_seed(seed)
   return

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

num_workers = 0
# Make sure test_data is a multiple of batch_size_test
batch_size_train_and_valid = 128
batch_size_test = 200

# proportion of full training set used for validation
valid_size = 0.2














class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
#         print(x.size(), out.size())
        return out


def ResNet18():
    return ResNet(BasicBlock, [2, 2, 2, 2])


def model_loading_helper(model, model_path, device=device):
    is_dataparallel = False
    if str(device) == "cuda" and torch.cuda.device_count() > 1:
        print("Using DataParallel")
        model = torch.nn.DataParallel(model)
        is_dataparallel = True
        
    checkpoint = torch.load(model_path)

    model_successfully_loaded = False
    try: 
        # Some models have been saved as DataParallel. For those we need to load to model, for others to model.module. Lessons
        # for the next project: always save module. :) Although this would still be a problem for for loop such as here as even
        # if I DataParallel in the for loop, the next iteration of the for loop tries to load non-DataParallel to a DataParalleled model
        try:
            model.load_state_dict(checkpoint['current_model'])
            print("Successfully loaded onto model.")
        except:
            print("Failed to load model onto model, attempting to load onto model.module...")
            model.module.load_state_dict(checkpoint['current_model'])
            print("Successfully loaded onto model.module.")
        model_successfully_loaded = True
        print("model_successfully_loaded:", model_successfully_loaded, flush=True)
    except:
        print("Model not stored on 'current_model' key")
        model_successfully_loaded = False
    if not model_successfully_loaded:
        try:
            # Loading the PAT model is slightly different
            checkpoint = torch.load(model_path)
            try:
                model.load_state_dict(checkpoint['model'])
            except:
                print("Failed to load model onto model, attempting to load onto model.module...")
                model.module.load_state_dict(checkpoint['model'])
                print("Successfully loaded onto model.module.")
        except:
            raise ValueError("Did not succeed in loading model -- check the key used to store model in the checkpoint loaded.")
    return model, is_dataparallel


def prepare_model_for_finetuning(model, model_path, model_filename, dataset_name, num_output_classes, device=device):
    model, is_dataparallel = model_loading_helper(model, model_path, device)
        

    if is_dataparallel:
        for param in model.module.parameters():
            param.requires_grad = False
        model.module.linear = nn.Linear(512 * BasicBlock.expansion, num_output_classes)
    else:
        for param in model.parameters():
            param.requires_grad = False
        model.linear = nn.Linear(512 * BasicBlock.expansion, num_output_classes)        


    model.to(device)
    print("Loaded model ", model_filename, " on " + dataset_name)
    return model




def prepare_test_dataloader(dataset_name, batch_size_test=batch_size_test, seed=seed):
    if dataset_name == "CIFAR100":
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762))])


        test_data = datasets.CIFAR100(root = "data", train = False, download = True, transform = transform_test)
        num_output_classes = 100

    elif dataset_name == "SVHN":
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        test_data = datasets.SVHN(root = "data", split="test", download = True, transform = transform_test)
        num_output_classes = 10

    else:
        raise ValueError("Unsupported dataset name. Supported datasets are CIFAR100, SVHN. You entered:", dataset_name)


    test_loader = torch.utils.data.DataLoader(test_data, batch_size = batch_size_test, worker_init_fn=seed_init_fn(seed))
    return test_loader, num_output_classes




























    












    
num_epochs = 30    
# Controls how many time we repeat finetuning/training of models to get avg and std since it's not too expensive
num_training_loops = 3
top_k = 3

dataset_names = ["SVHN", "CIFAR100"]


for dataset_name in dataset_names:
    SAVE_DIR = "results/" + dataset_name + '/'
    TRAINED_MODEL_PATH = "experiments/" + dataset_name + '/'
    _, base_model_list, _ = next(os.walk(TRAINED_MODEL_PATH))
    for base_model in base_model_list:
        topk_accuracies_test = []


        _, _, base_model_iterates = next(os.walk(TRAINED_MODEL_PATH+base_model))
        base_model_iterates = [base_model_iterate for base_model_iterate in base_model_iterates if base_model_iterate.endswith(".pt")]
        # Iterate over all repeats with different seeds of a given model to aggregate statistics.
        for eval_loop, base_model_iterate in enumerate(base_model_iterates):
            model_path = TRAINED_MODEL_PATH + base_model + '/' + base_model_iterate
            # Make sure validation splits are the same at all time (e.g. even after loading)

            seed = eval_loop
            test_loader, num_output_classes = prepare_test_dataloader(dataset_name, batch_size_test=batch_size_test, seed=seed)

            model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_output_classes)
            model.to(device)
            model, _ = model_loading_helper(model, model_path, device)


            test_loss = 0
            # confusion_matrix_test.append(torch.zeros([num_output_classes, num_output_classes]).to(device))
            topk_accuracies_test.append(torch.zeros(top_k))


            ######################    
            # Test the model #
            ######################
            model.eval()
            for _, (data, label) in enumerate(test_loader):
                data, label = data.to(device), label.to(device)

                with torch.no_grad():
                    preds = model(data)#, y=label)
                    loss = F.cross_entropy(preds, label)
                    test_loss += loss.item() * data.size(0)

                    # Update count of number of correct predictions per domain
                    # pred_probabilities = F.softmax(preds, dim=1)
                    predicted_topk = torch.topk(preds, top_k, dim=1).indices
                    for iter_samples, (pred, target) in enumerate(zip(predicted_topk, label)):
                        # # confusion_matrix_test[int(target)] += pred_probabilities[iter_samples]
                        # confusion_matrix_test[int(target), int(pred[0])] += 1
                        for iter_topk in range(0, top_k):
                            if target in pred[:iter_topk+1]:
                                topk_accuracies_test[eval_loop][iter_topk] += 1






            test_loss = test_loss / len(test_loader.sampler)
            topk_accuracies_test[eval_loop] /= len(test_loader.sampler)
            # confusion_matrix_test[eval_loop] /= len(test_loader.sampler)


            print("Model: {} \tTest Loss: {:.6f}".format(
                base_model + '/' + base_model_iterate, 
                test_loss
                ))


        # Compute statistics over the runs
        topk_accuracies_test_std, topk_accuracies_test_mean = torch.std_mean(torch.stack(topk_accuracies_test), dim=0)

        for iter_topk in range(0, top_k):
            print("Top {} test accuracy: {:.6f} std {:.6f}".format(
                iter_topk+1,
                topk_accuracies_test_mean[iter_topk], topk_accuracies_test_std[iter_topk]
                ), flush=True)

        for i in range(len(topk_accuracies_test)):
            topk_accuracies_test[i] = topk_accuracies_test[i].detach().numpy()
        # confusion_matrix_test = confusion_matrix_test.cpu().numpy()
        results = {}
        results["topk_accuracies"] = topk_accuracies_test
        results["topk_accuracies_mean"] = topk_accuracies_test_mean.detach().numpy()
        results["topk_accuracies_std"] = topk_accuracies_test_std.detach().numpy()
        results["base_model"] = base_model
        results["number_of_iterations"] = len(base_model_iterates)

        # results["confusion_matrix"] = confusion_matrix_test


        # df_confmat = pd.DataFrame(confusion_matrix_test, index = domains, columns=domains)
        # fig = plt.figure(figsize=(15,10))#, dpi=1200)
        # heatmap = sbn.heatmap(df_confmat, annot=True)
        # heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=45, ha='right')#, fontsize=15)
        # heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right')#, fontsize=15)
        # plt.ylabel('True label')
        # plt.xlabel('Predicted label')

        working_dir_of_save = SAVE_DIR + "test_accs/"
        os.makedirs(SAVE_DIR + "test_accs/", exist_ok=True)
        np.save(working_dir_of_save + base_model, results)
        # fig.savefig(working_dir_of_save + model_filenames[model_num] + "_confusion_matrix.pdf", bbox_inches='tight')
        # plt.show()


Using downloaded and verified file: data/test_32x32.mat
Using DataParallel
Failed to load model onto model, attempting to load onto model.module...
Successfully loaded onto model.module.
model_successfully_loaded: True
Model: model_ERM_clean_56.pt/SVHN_loop_2.pt 	Test Loss: 1.705784
Using downloaded and verified file: data/test_32x32.mat
Using DataParallel
Failed to load model onto model, attempting to load onto model.module...
Successfully loaded onto model.module.
model_successfully_loaded: True
Model: model_ERM_clean_56.pt/SVHN_loop_1.pt 	Test Loss: 1.702524
Using downloaded and verified file: data/test_32x32.mat
Using DataParallel
Failed to load model onto model, attempting to load onto model.module...
Successfully loaded onto model.module.
model_successfully_loaded: True
Model: model_ERM_clean_56.pt/SVHN_loop_0.pt 	Test Loss: 1.701157
Top 1 test accuracy: 0.424311 std 0.001562
Top 2 test accuracy: 0.595331 std 0.001715
Top 3 test accuracy: 0.711919 std 0.000835
Using downloaded an

In [64]:
### Eval base models on CIFAR-10-C


# clear memory
from IPython import get_ipython
get_ipython().magic('reset -sf') 

import numpy as np
import torch
import time
timer = 0

from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler

import torch.nn as nn
import torch.nn.functional as F

import torch.nn as nn

from torch.utils.tensorboard import SummaryWriter
from PIL import Image




import os, random

# import argparse

# argument_parser = argparse.ArgumentParser()

# argument_parser.add_argument("--lr_init", type=float, help="Initial learning rate value, default=0.01. CAREFUL: this will be divided by beta, since the ERM term is multiplied by beta in the objective.")

# parsed_args = argument_parser.parse_args()


# Make sure validation splits are the same at all time (e.g. even after loading)
seed = 0

def seed_init_fn(seed=seed):
   np.random.seed(seed)
   random.seed(seed)
   torch.manual_seed(seed)
   return

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

num_workers = 0
batch_size_test = 200





class CIFAR10C(datasets.VisionDataset):
    corruptions = ['brightness', 'contrast', 'defocus_blur', 'elastic_transform',
                        'fog', 'frost', 'gaussian_blur', 'gaussian_noise', 'glass_blur',
                        'impulse_noise', 'jpeg_compression', 'motion_blur', 'pixelate',
                        'saturate', 'shot_noise', 'snow', 'spatter', 'speckle_noise',
                        'zoom_blur']
    def __init__(self, root :str, name :str, transform=None, target_transform=None):
        assert name in self.corruptions

        # Download the dataset if needed
        if not os.path.exists(root):
            import urllib.request
            from tqdm import tqdm
            import tarfile
            os.mkdir(root)
            url = "https://zenodo.org/record/2535967/files/CIFAR-10-C.tar?download=1"
            file_name = 'cifar-c.tar'
            file_path = os.path.join(root, file_name)

            # Check if .tar of CIFAR-C already downloaded, and download it if necessary
            if not os.path.exists(file_path):
                print('Downloading CIFAR-C dataset...')
                with tqdm(unit='B', unit_scale=True, desc=file_name, leave=True) as progress_bar:
                    urllib.request.urlretrieve(url, file_path, reporthook=lambda blocknum, blocksize, totalsize: progress_bar.update(blocknum * blocksize - progress_bar.n))
            else:
                print(file_name, " already downloaded.")

            # Extract the file
            with tarfile.open(file_path, 'r') as tar:
                tar.extractall(root)
            

        super(CIFAR10C, self).__init__(root, transform=transform, target_transform=target_transform)
        data_path = os.path.join(root, name + '.npy')
        target_path = os.path.join(root, 'labels.npy')
        
        self.data = np.load(data_path)
        self.targets = np.load(target_path)
        
    def __getitem__(self, index):
        img, targets = self.data[index], self.targets[index]
        img = Image.fromarray(img)
        targets = torch.tensor(targets, dtype=torch.long)
        
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            targets = self.target_transform(targets)
            
        return img, targets
    
    def __len__(self):
        return len(self.data)
    












class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
#         print(x.size(), out.size())
        return out


def ResNet18():
    return ResNet(BasicBlock, [2, 2, 2, 2])


def model_loading_helper(model, model_path, device=device):
    is_dataparallel = False
    if str(device) == "cuda" and torch.cuda.device_count() > 1:
        print("Using DataParallel")
        model = torch.nn.DataParallel(model)
        is_dataparallel = True
        
    checkpoint = torch.load(model_path)

    model_successfully_loaded = False
    try: 
        # Some models have been saved as DataParallel. For those we need to load to model, for others to model.module. Lessons
        # for the next project: always save module. :) Although this would still be a problem for for loop such as here as even
        # if I DataParallel in the for loop, the next iteration of the for loop tries to load non-DataParallel to a DataParalleled model
        try:
            model.load_state_dict(checkpoint['current_model'])
            print("Successfully loaded onto model.")
        except:
            print("Failed to load model onto model, attempting to load onto model.module...")
            model.module.load_state_dict(checkpoint['current_model'])
            print("Successfully loaded onto model.module.")
        model_successfully_loaded = True
        print("model_successfully_loaded:", model_successfully_loaded, flush=True)
    except:
        print("Model not stored on 'current_model' key")
        model_successfully_loaded = False
    if not model_successfully_loaded:
        try:
            # Loading the PAT model is slightly different
            checkpoint = torch.load(model_path)
            try:
                model.load_state_dict(checkpoint['model'])
            except:
                print("Failed to load model onto model, attempting to load onto model.module...")
                model.module.load_state_dict(checkpoint['model'])
                print("Successfully loaded onto model.module.")
        except:
            raise ValueError("Did not succeed in loading model -- check the key used to store model in the checkpoint loaded.")
    return model, is_dataparallel



def prepare_test_dataloader(dataset_name, batch_size_test=batch_size_test, seed=seed, corruption_name=None):
    if dataset_name == "CIFAR100":
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762))])


        test_data = datasets.CIFAR100(root = "data", train = False, download = True, transform = transform_test)
        num_output_classes = 100

    elif dataset_name == "SVHN":
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        test_data = datasets.SVHN(root = "data", split="test", download = True, transform = transform_test)
        num_output_classes = 10

    elif dataset_name == "CIFAR-10-C":
        assert corruption_name in CIFAR10C.corruptions
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.49139968,  0.48215841,  0.44653091), (0.24703223,  0.24348513,  0.26158784))])

        test_data = CIFAR10C(root="data/CIFAR-10-C/", name=corruption_name, transform=transform_test)
        num_output_classes = 10

    else:
        raise ValueError("Unsupported dataset name. Supported datasets are CIFAR100, SVHN, CIFAR-10-C. You entered:", dataset_name)


    test_loader = torch.utils.data.DataLoader(test_data, batch_size = batch_size_test, worker_init_fn=seed_init_fn(seed))
    return test_loader, num_output_classes




























    






WORKING_DIR = "results/CIFAR10/"
TRAINED_MODEL_PATH = WORKING_DIR + "models/"
for root, dirs, files in os.walk(TRAINED_MODEL_PATH):
    model_filenames = files
    model_paths = [TRAINED_MODEL_PATH + file for file in files]





    

top_k = 3

dataset_name = "CIFAR-10-C"
SAVE_DIR = "results/" + dataset_name + '/'

for model_num, model_path in enumerate(model_paths):
    topk_accuracies_test = {}
    for corruption_name in CIFAR10C.corruptions:



        seed = 0
        test_loader, num_output_classes = prepare_test_dataloader(dataset_name, batch_size_test=batch_size_test, corruption_name=corruption_name, seed=seed)

        model = ResNet18()
        model.to(device)
        model, _ = model_loading_helper(model, model_path, device)


        test_loss = 0
        # confusion_matrix_test.append(torch.zeros([num_output_classes, num_output_classes]).to(device))
        topk_accuracies_test[corruption_name] = torch.zeros(top_k)


        ######################    
        # Test the model #
        ######################
        model.eval()
        for _, (data, label) in enumerate(test_loader):
            data, label = data.to(device), label.to(device)

            with torch.no_grad():
                preds = model(data)#, y=label)
                loss = F.cross_entropy(preds, label)
                test_loss += loss.item() * data.size(0)

                # Update count of number of correct predictions per domain
                # pred_probabilities = F.softmax(preds, dim=1)
                predicted_topk = torch.topk(preds, top_k, dim=1).indices
                for iter_samples, (pred, target) in enumerate(zip(predicted_topk, label)):
                    # # confusion_matrix_test[int(target)] += pred_probabilities[iter_samples]
                    # confusion_matrix_test[int(target), int(pred[0])] += 1
                    for iter_topk in range(0, top_k):
                        if target in pred[:iter_topk+1]:
                            topk_accuracies_test[corruption_name][iter_topk] += 1






        test_loss = test_loss / len(test_loader.sampler)
        topk_accuracies_test[corruption_name] /= len(test_loader.sampler)
        # confusion_matrix_test[eval_loop] /= len(test_loader.sampler)
        topk_accuracies_test[corruption_name] = topk_accuracies_test[corruption_name].detach().numpy()

        print("Model: {} \tTest Loss: {:.6f}".format(
            model_filenames[model_num], 
            test_loss
            ))


        for iter_topk in range(0, top_k):
            print("Top {} test accuracy on corruption {}:".format(
                iter_topk+1,
                corruption_name
                ), 
                topk_accuracies_test[corruption_name][iter_topk], flush=True)


    # confusion_matrix_test = confusion_matrix_test.cpu().numpy()
    results = {}
    results["topk_accuracies"] = topk_accuracies_test
    results["model_name"] = model_filenames[model_num]
    results["top_k"] = top_k

    # results["confusion_matrix"] = confusion_matrix_test


    # df_confmat = pd.DataFrame(confusion_matrix_test, index = domains, columns=domains)
    # fig = plt.figure(figsize=(15,10))#, dpi=1200)
    # heatmap = sbn.heatmap(df_confmat, annot=True)
    # heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=45, ha='right')#, fontsize=15)
    # heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right')#, fontsize=15)
    # plt.ylabel('True label')
    # plt.xlabel('Predicted label')

    working_dir_of_save = SAVE_DIR + "test_accs/"
    os.makedirs(SAVE_DIR + "test_accs/", exist_ok=True)
    np.save(working_dir_of_save + model_filenames[model_num], results)
    # fig.savefig(working_dir_of_save + model_filenames[model_num] + "_confusion_matrix.pdf", bbox_inches='tight')
    # plt.show()


Using DataParallel
Failed to load model onto model, attempting to load onto model.module...
Successfully loaded onto model.module.
model_successfully_loaded: True
Model: model_ERM_clean_56.pt 	Test Loss: 2.403580
Top 1 test accuracy on corruption brightness: 0.31656
Top 2 test accuracy on corruption brightness: 0.4671
Top 3 test accuracy on corruption brightness: 0.58354
Using DataParallel
Failed to load model onto model, attempting to load onto model.module...
Successfully loaded onto model.module.
model_successfully_loaded: True
Model: model_ERM_clean_56.pt 	Test Loss: 2.119864
Top 1 test accuracy on corruption contrast: 0.35358
Top 2 test accuracy on corruption contrast: 0.47242
Top 3 test accuracy on corruption contrast: 0.56892
Using DataParallel
Failed to load model onto model, attempting to load onto model.module...
Successfully loaded onto model.module.
model_successfully_loaded: True
Model: model_ERM_clean_56.pt 	Test Loss: 2.551339
Top 1 test accuracy on corruption defocus_bl