In [1]:
import torch, os
import torch.nn as nn
import torchvision.datasets as datasets
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.models as models
import numpy as np
import random
import torch.utils.data as torchdata
from torch.utils.data import SubsetRandomSampler
import matplotlib.pyplot as plt
import time
# from train_model import train_model
# from test_model import test_model
%matplotlib inline
import pickle

In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]='0'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
filePath = '/local/rcs/ll3504/datasets/256_ObjectCategories/'
namelist = os.listdir(filePath)
nameDic_cal = {}
for name in namelist:
    splits = name.split(".")
    nameDic_cal[int(splits[0])-1] = splits[1]
print(nameDic_cal[1])

american-flag


In [4]:
def get_dataset(path='/database', dataset_name='caltech-256-common'):
    # No holdout testing data. train and test data are the same, but different transformation
    data_transforms = {
        'train': transforms.Compose([
#             transforms.Resize([256, 256]),
#             transforms.RandomCrop(224),
#             transforms.RandomRotation(20),
#             transforms.RandomHorizontalFlip(0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ]),
        'test': transforms.Compose([
#             transforms.Resize([224, 224]),
#             transforms.CenterCrop((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    }

    tr_dataset = datasets.ImageFolder(path + dataset_name + '/', data_transforms['train'])
    te_dataset = datasets.ImageFolder(path + dataset_name + '/', data_transforms['test'])
#     print('{} train set size: {}'.format(dataset_name, len(tr_dataset)))
#     print('{} test set size: {}'.format(dataset_name, len(te_dataset)))

    return tr_dataset, te_dataset

In [5]:
def split_dataset(train_dataset, test_dataset, valid_size=0.2, batch_size=128, train_size = 128):
    '''
    This function splits dataset into train, val, and test sets, and return train, val, test dataloaders.
    Val and Test loaders are the same

    '''
    
    # what does the len function gives?
    num_train = len(train_dataset)
    # print("DEBUGGING: overall training data size =", num_train)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))
    random.shuffle(indices)
    train_idx, valid_idx = indices[split:split+train_size], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)
    
    # print("DEBUGGING: the train_ind are:", len(train_idx))


    train_loader = torchdata.DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=48, pin_memory=True, drop_last=True, sampler = train_sampler)
    test_loader = torchdata.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=48, pin_memory=True, drop_last=True, sampler = valid_sampler)
    dataloaders = {'train': train_loader,
                   'val': test_loader,
                   'test': test_loader}
    dataset_sizes ={'train': train_size, #int(np.floor((1-valid_size) * num_train)),
                    'val': int(np.floor(valid_size * num_train)),
                    'test': int(np.floor(valid_size * num_train))}
    return dataloaders, dataset_sizes

In [6]:
imagebase = '/local/rcs/ll3504/datasets/'

In [15]:
corruption = ['zoom_blur', 'speckle_noise', 'spatter',
                       'snow', 'glass_blur', 'motion_blur', 'saturate',
                       'gaussian_blur', 'frost', 'fog', 'brightness', 'contrast',
                       'elastic_transform', 'pixelate', 'jpeg_compression', 'defocus_blur']
weather  = ['snow', 'frost', 'fog', 'spatter']

In [8]:
def get_imagenetc(imagebase, severity=1, batch_size=128, sample_size = 128):
    '''
    Returns:
        ref_dataloaders:          ImageNet original validation data, as a reference
        ref_dataset_sizes:        1000, not the sizes of the real dataset in the ref_loader, probs used downstream
        corrupted_dataloaders:    A list of corrupted dataloaders, each element in a list represetns the data loaders
                                  for one corruption type. Each element contains ['train']['val']['test'] loaders
        corrupted_dataset_sizes:  A list of dictionaries of the sizes of each loaders for each corruption
        corruption:               A list of corruption names, in the same order of the corrupted_dataloaders
    '''
    corruption = ['zoom_blur', 'speckle_noise', 'spatter',
                       'snow', 'glass_blur', 'motion_blur', 'saturate',
                       'gaussian_blur', 'frost', 'fog', 'brightness', 'contrast',
                       'elastic_transform', 'pixelate', 'jpeg_compression', 'defocus_blur']
    corrupted_dataloaders = []
    corrupted_dataset_sizes = []
    
    # this is the imageNet validation data
    imagenet_val = datasets.ImageNet(imagebase+'imagenetc/', split='val', transform=transforms.Compose([
            transforms.Resize([224, 224]),
            transforms.CenterCrop((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ]),
                                   target_transform=None)#, download=False)
    
    # TODO: subsample some size of ImageNet training data as source
        # Doesn't need this step
#     print("DEBUGGING: imagenet_val size is:", len(imagenet_val))
    
    random_indices = random.sample(range(0, len(imagenet_val)), int(len(imagenet_val)*0.02))
#     print("DEBUGGING: random indices are:", len(random_indices))
    imagenet_val_subset = data.Subset(imagenet_val, random_indices)
    val_loader = torch.utils.data.DataLoader(imagenet_val_subset,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=48)
    ref_dataloaders = { 'val': val_loader,
                       'test': val_loader}
    ref_dataset_sizes ={'val': int(len(val_loader.dataset)),
                        'test': int(len(val_loader.dataset))}
    
    # for every type of corruption, go to the specified severity folder
    for corr in corruption:
        dataset_name = 'imagenetc/' + corr + '/' + str(severity)
        # Get dataset from folder
        corr_trian_images, corr_test_images = get_dataset(imagebase, dataset_name)
        # get corruption-specific train, val, test loader
        # train: training data, non-overlap with val/test
        # val: non-overlap with train, same as test
        # test: non-overlap with train, same as test
        corr_dataloaders, corr_dataset_sizes = split_dataset(corr_trian_images, corr_test_images, valid_size=0.02, batch_size=batch_size, train_size=sample_size)
        corrupted_dataloaders.append(corr_dataloaders)
        corrupted_dataset_sizes.append(corr_dataset_sizes)
    return ref_dataloaders, ref_dataset_sizes, corrupted_dataloaders, corrupted_dataset_sizes, corruption

In [13]:
import torch
from torchvision import models, datasets, transforms

def get_dataset_loader(valdir, batch_size, shuffle):
    val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        ),
    ]))
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, shuffle=shuffle
    )
    return val_loader

def gce(logits, target, q = 0.8):
    """ Generalized cross entropy.
    
    Reference: https://arxiv.org/abs/1805.07836
    """
    probs = torch.nn.functional.softmax(logits, dim=1)
    probs_with_correct_idx = probs.index_select(-1, target).diag()
    loss = (1. - probs_with_correct_idx**q) / q
    return loss.mean()

def adapt_batchnorm(model):
    model.eval()
    parameters = []
    for module in model.modules():
        if isinstance(module, torch.nn.BatchNorm2d):
            parameters.extend(module.parameters())
            module.train()
    return parameters


# ---

def evaluate(
        datadir = '/data/imagenetc/gaussian_blur/3',
        num_epochs = 5,
        batch_size = 128,
        learning_rate = 1e-3,
        gce_q = 0.8
    ):
    
    model = models.resnet50(pretrained = True).cuda()
    parameters = adapt_batchnorm(model)
    val_loader = get_dataset_loader(
        datadir,
        batch_size = batch_size,
        shuffle = True
    )
    optimizer = torch.optim.SGD(
        model.parameters(), lr = learning_rate
    )
    
    num_correct, num_samples = 0., 0.
    for epoch in range(num_epochs):
        predictions = []
        for images, targets in val_loader:

            logits = model(images.cuda())
            predictions = logits.argmax(dim = 1)
            loss = gce(logits, predictions, q = gce_q)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            num_correct += (predictions.detach().cpu() == targets).float().sum()
            num_samples += len(targets)
        print(f"Correct: {num_correct:#5.0f}/{num_samples:#5.0f} ({100 * num_correct / num_samples:.2f} %)")
        
    return num_correct / num_samples


In [16]:
severity = [1,2,3,4,5]
results = np.zeros((3, 5))
for cor_ind in range(3):
    corr = weather[cor_ind]
    for sev in severity:
        print("CORRUPTION:", corr, "; SEVERITY:", sev)
        dataset_name = 'imagenetc/' + corr + '/' + str(sev)
        result = evaluate(datadir=imagebase+dataset_name+'/')
        results[cor_ind][sev-1] = result
        print("---------------------------------------------------")
    print("===============================================================")

CORRUPTION: snow ; SEVERITY: 1
Correct: 31750./50000. (63.50 %)
Correct: 64426./100000. (64.43 %)
Correct: 97530./150000. (65.02 %)
Correct: 130809./200000. (65.40 %)
Correct: 164196./250000. (65.68 %)
---------------------------------------------------
CORRUPTION: snow ; SEVERITY: 2
Correct: 25539./50000. (51.08 %)
Correct: 53112./100000. (53.11 %)
Correct: 81562./150000. (54.37 %)
Correct: 110457./200000. (55.23 %)
Correct: 139590./250000. (55.84 %)
---------------------------------------------------
CORRUPTION: snow ; SEVERITY: 3
Correct: 25887./50000. (51.77 %)
Correct: 53676./100000. (53.68 %)
Correct: 82304./150000. (54.87 %)
Correct: 111390./200000. (55.69 %)
Correct: 140780./250000. (56.31 %)
---------------------------------------------------
CORRUPTION: snow ; SEVERITY: 4
Correct: 21520./50000. (43.04 %)
Correct: 45658./100000. (45.66 %)
Correct: 70768./150000. (47.18 %)
Correct: 96521./200000. (48.26 %)
Correct: 122670./250000. (49.07 %)
-------------------------------------

In [20]:
results

array([[0.656784  , 0.55835998, 0.56312001, 0.49068001, 0.47912401],
       [0.66921198, 0.56897199, 0.478232  , 0.455392  , 0.38220799],
       [0.70635599, 0.68982798, 0.66353202, 0.64086002, 0.57903999]])

In [22]:
1-results.mean(axis=1)

array([0.4503864 , 0.48919681, 0.3440768 ])

In [None]:
print("hello")

In [10]:
## running the original experiment using the paper hyperparams

base_cum_acc = np.zeros((5, len(corruption)))
adapted_cum_acc = np.zeros((5, len(corruption)))


acc_record_sev = {}
batch_size = 96

for severity in range(1, 6):
    
    acc_record_sev[severity] = {}

In [13]:
for severity in range(1, 6):

    ref_dataloaders, ref_dataset_sizes, corrupted_dataloaders, corrupted_dataset_sizes, corruption \
            = get_imagenetc(imagebase, severity, batch_size, 49000)
    
    for cor in weather[1:]:
    
        print("Running experiment for severity =", severity, "corruption =", cor)
        
#         start = time.time()
        base_acc, base_acc_epoch, adapted_train_acc_overall, adapted_train_acc_per_epoch, \
            adapted_train_acc_per_iter, model, baseline, training_time = evaluate(corrupted_dataloaders[corruption.index(cor)]['train'])
#         adapt_time = time.time() - start

        base_cum_acc[severity-1][corruption.index(cor)] = base_acc_epoch[0]
        adapted_cum_acc[severity-1][corruption.index(cor)] = adapted_train_acc_overall[-1]

        ## evaluate on validation set
        model.eval()
        baseline.eval()
        with torch.no_grad():
            base_val_correct, crpt_val_correct, crpt_val_samples = 0, 0, 0
            for images, labels in corrupted_dataloaders[corruption.index(cor)]['val']:
                labels = labels.to(device)
                outputs = model(images.to(device))
                _, predicted = torch.max(outputs.data, 1)
                crpt_val_correct += (predicted == labels).sum().item()
                crpt_val_samples += len(labels)

                b_outputs = baseline(images.to(device))
                _, b_predicted = torch.max(b_outputs.data, 1)
                base_val_correct += (b_predicted == labels).sum().item()


        print(f"Corrupt Val Adapt Accuracy: {crpt_val_correct:#5.0f}/{crpt_val_samples:#5.0f} ({crpt_val_correct / crpt_val_samples})")
        print(f"Corrupt Val Base Accuracy: {base_val_correct:#5.0f}/{crpt_val_samples:#5.0f} ({base_val_correct / crpt_val_samples})")

        acc_record_sev[severity][cor] = [base_acc, base_acc_epoch, adapted_train_acc_overall, adapted_train_acc_per_epoch, \
                                    adapted_train_acc_per_iter, (base_val_correct / crpt_val_samples), \
                                    (crpt_val_correct / crpt_val_samples), training_time]

        print("===================================================================================================")
        
        with open('BASE_cum_acc_replicate.npy', 'wb') as f:
            np.save(f, base_cum_acc)
        with open('RPL_cum_acc_replicate.npy', 'wb') as f:
            np.save(f, adapted_cum_acc)
            
        with open('PERFORMANCE_dic_replicate.pkl', 'wb') as f:
            pickle.dump(acc_record_sev, f)
        




Running experiment for severity = 1 corruption = snow
Epoch 0
Baseline Correct: 26716./48960. (54.57 %)
Adapted Correct: 30224./48960. (61.73 %)
Baseline Epoch Correct: 26716./48960. (54.57 %)
Adapted Epoch Correct: 30224./48960. (61.73 %)
Epoch 1
Baseline Correct: 53429./97920. (54.56 %)
Adapted Correct: 61011./97920. (62.31 %)
Baseline Epoch Correct: 26713./48960. (54.56 %)
Adapted Epoch Correct: 30787./48960. (62.88 %)
Epoch 2
Baseline Correct: 80149./146880. (54.57 %)
Adapted Correct: 92202./146880. (62.77 %)
Baseline Epoch Correct: 26720./48960. (54.58 %)
Adapted Epoch Correct: 31191./48960. (63.71 %)
Epoch 3
Baseline Correct: 106873./195840. (54.57 %)
Adapted Correct: 123588./195840. (63.11 %)
Baseline Epoch Correct: 26724./48960. (54.58 %)
Adapted Epoch Correct: 31386./48960. (64.11 %)
Epoch 4
Baseline Correct: 133595./244800. (54.57 %)
Adapted Correct: 155255./244800. (63.42 %)
Baseline Epoch Correct: 26722./48960. (54.58 %)
Adapted Epoch Correct: 31667./48960. (64.68 %)
Corrup

Corrupt Val Adapt Accuracy:  535./ 960. (0.5572916666666666)
Corrupt Val Base Accuracy:  354./ 960. (0.36875)
Running experiment for severity = 3 corruption = frost
Epoch 0
Baseline Correct: 15719./48960. (32.11 %)
Adapted Correct: 20703./48960. (42.29 %)
Baseline Epoch Correct: 15719./48960. (32.11 %)
Adapted Epoch Correct: 20703./48960. (42.29 %)
Epoch 1
Baseline Correct: 31441./97920. (32.11 %)
Adapted Correct: 42305./97920. (43.20 %)
Baseline Epoch Correct: 15722./48960. (32.11 %)
Adapted Epoch Correct: 21602./48960. (44.12 %)
Epoch 2
Baseline Correct: 47162./146880. (32.11 %)
Adapted Correct: 64514./146880. (43.92 %)
Baseline Epoch Correct: 15721./48960. (32.11 %)
Adapted Epoch Correct: 22209./48960. (45.36 %)
Epoch 3
Baseline Correct: 62887./195840. (32.11 %)
Adapted Correct: 87258./195840. (44.56 %)
Baseline Epoch Correct: 15725./48960. (32.12 %)
Adapted Epoch Correct: 22744./48960. (46.45 %)
Epoch 4
Baseline Correct: 78607./244800. (32.11 %)
Adapted Correct: 110501./244800. (45

Baseline Correct: 57096./244800. (23.32 %)
Adapted Correct: 92579./244800. (37.82 %)
Baseline Epoch Correct: 11419./48960. (23.32 %)
Adapted Epoch Correct: 19697./48960. (40.23 %)
Corrupt Val Adapt Accuracy:  365./ 960. (0.3802083333333333)
Corrupt Val Base Accuracy:  217./ 960. (0.22604166666666667)
Running experiment for severity = 5 corruption = fog
Epoch 0
Baseline Correct: 11942./48960. (24.39 %)
Adapted Correct: 24478./48960. (50.00 %)
Baseline Epoch Correct: 11942./48960. (24.39 %)
Adapted Epoch Correct: 24478./48960. (50.00 %)
Epoch 1
Baseline Correct: 23892./97920. (24.40 %)
Adapted Correct: 50369./97920. (51.44 %)
Baseline Epoch Correct: 11950./48960. (24.41 %)
Adapted Epoch Correct: 25891./48960. (52.88 %)
Epoch 2
Baseline Correct: 35841./146880. (24.40 %)
Adapted Correct: 77110./146880. (52.50 %)
Baseline Epoch Correct: 11949./48960. (24.41 %)
Adapted Epoch Correct: 26741./48960. (54.62 %)
Epoch 3
Baseline Correct: 47786./195840. (24.40 %)
Adapted Correct: 104409./195840. (

In [8]:
with open('BASE_cum_acc_replicate.npy', 'rb') as f:
#     np.save(f, base_cum_acc)
    base_cum_acc = np.load(f)
with open('RPL_cum_acc_replicate.npy', 'rb') as f:
#     np.save(f, adapted_cum_acc)
    adapted_cum_acc = np.load(f)
    

In [None]:
## running the original experiment using the paper hyperparams

base_cum_acc = np.zeros((5, len(corruption)))
adapted_cum_acc = np.zeros((5, len(corruption)))


acc_record_sev = {}
batch_size = 96

for severity in range(1, 6):
    
    acc_record_sev[severity] = {}

for severity in range(1, 6):

    ref_dataloaders, ref_dataset_sizes, corrupted_dataloaders, corrupted_dataset_sizes, corruption \
            = get_imagenetc(imagebase, severity, batch_size, 49000)
    
    for cor in weather[1:]:
    
        print("Running experiment for severity =", severity, "corruption =", cor)
        
#         start = time.time()
        base_acc, base_acc_epoch, adapted_train_acc_overall, adapted_train_acc_per_epoch, \
            adapted_train_acc_per_iter, model, baseline, training_time = evaluate(corrupted_dataloaders[corruption.index(cor)]['train'])
#         adapt_time = time.time() - start

        base_cum_acc[severity-1][corruption.index(cor)] = base_acc_epoch[0]
        adapted_cum_acc[severity-1][corruption.index(cor)] = adapted_train_acc_overall[-1]

        ## evaluate on validation set
        model.eval()
        baseline.eval()
        with torch.no_grad():
            base_val_correct, crpt_val_correct, crpt_val_samples = 0, 0, 0
            for images, labels in corrupted_dataloaders[corruption.index(cor)]['val']:
                labels = labels.to(device)
                outputs = model(images.to(device))
                _, predicted = torch.max(outputs.data, 1)
                crpt_val_correct += (predicted == labels).sum().item()
                crpt_val_samples += len(labels)

                b_outputs = baseline(images.to(device))
                _, b_predicted = torch.max(b_outputs.data, 1)
                base_val_correct += (b_predicted == labels).sum().item()


        print(f"Corrupt Val Adapt Accuracy: {crpt_val_correct:#5.0f}/{crpt_val_samples:#5.0f} ({crpt_val_correct / crpt_val_samples})")
        print(f"Corrupt Val Base Accuracy: {base_val_correct:#5.0f}/{crpt_val_samples:#5.0f} ({base_val_correct / crpt_val_samples})")

        acc_record_sev[severity][cor] = [base_acc, base_acc_epoch, adapted_train_acc_overall, adapted_train_acc_per_epoch, \
                                    adapted_train_acc_per_iter, (base_val_correct / crpt_val_samples), \
                                    (crpt_val_correct / crpt_val_samples), training_time]

        print("===================================================================================================")
        
        with open('BASE_cum_acc_replicate.npy', 'wb') as f:
            np.save(f, base_cum_acc)
        with open('RPL_cum_acc_replicate.npy', 'wb') as f:
            np.save(f, adapted_cum_acc)
            
        with open('PERFORMANCE_dic_replicate.pkl', 'wb') as f:
            pickle.dump(acc_record_sev, f)
        




with open('BASE_cum_acc_replicate.npy', 'rb') as f:
#     np.save(f, base_cum_acc)
    base_cum_acc = np.load(f)
with open('RPL_cum_acc_replicate.npy', 'rb') as f:
#     np.save(f, adapted_cum_acc)
    adapted_cum_acc = np.load(f)
    