In [1]:
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.models as models
import numpy as np
import random
import torch.utils.data as torchdata
from torch.utils.data import SubsetRandomSampler
import matplotlib.pyplot as plt
import time
# from train_model import train_model
# from test_model import test_model
%matplotlib inline
import pickle

In [2]:
import tent

In [3]:
import os

# Limit the CUDA devices visible to this process to index 1, then pick the
# GPU when PyTorch reports one and fall back to the CPU otherwise.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [4]:
# Build a lookup from zero-based class index to class label using the folder
# names found under the Caltech-256 directory.
filePath = '/local/rcs/ll3504/datasets/256_ObjectCategories/'
namelist = os.listdir(filePath)
nameDic_cal = {}
for name in namelist:
    # Category folders are expected to look like "<number>.<label>".
    # Anything else (hidden files, stray archives, ...) is skipped instead of
    # crashing the loop on int('') or a missing second field.
    splits = name.split(".")
    if len(splits) < 2 or not splits[0].isdigit():
        continue
    # Folder numbers start at 1; the dictionary keys start at 0.
    nameDic_cal[int(splits[0])-1] = splits[1]
print(nameDic_cal[1])

american-flag


In [5]:
def get_dataset(path='/database', dataset_name='caltech-256-common'):
    '''
    Open one image folder twice, once as the "train" dataset and once as the
    "test" dataset.

    There is no held-out test data: both returned datasets read the same
    directory. The two transform pipelines are currently identical
    (ToTensor followed by per-channel normalisation); they are kept as
    separate entries so that they can diverge later.

    Args:
        path:          base directory that contains the dataset folder.
                       A trailing slash is optional.
        dataset_name:  dataset folder, relative to `path` (may contain
                       sub-directories such as 'imagenetc/fog/1').

    Returns:
        (tr_dataset, te_dataset): two `datasets.ImageFolder` objects rooted
        at the same directory.
    '''
    def _make_transform():
        # Convert to a tensor, then normalise each channel with the fixed
        # mean / std constants below.
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    data_transforms = {
        'train': _make_transform(),
        'test': _make_transform()
    }

    # Join path components instead of concatenating strings, so that a base
    # path without a trailing slash (including the default '/database') still
    # resolves to '<path>/<dataset_name>/'. Leading slashes are stripped from
    # dataset_name so that os.path.join does not discard `path`.
    root = os.path.join(path, dataset_name.lstrip('/'), '')

    tr_dataset = datasets.ImageFolder(root, data_transforms['train'])
    te_dataset = datasets.ImageFolder(root, data_transforms['test'])

    return tr_dataset, te_dataset

In [6]:
def split_dataset(train_dataset, test_dataset, valid_size=0.2, batch_size=128, train_size = 128, num_workers=48):
    '''
    Split a dataset into disjoint train and validation index sets and wrap
    them in dataloaders.

    The indices 0..len(train_dataset)-1 are shuffled with the global `random`
    module (unseeded here, so the split differs between calls). The first
    floor(valid_size * N) shuffled indices form the validation set; the next
    `train_size` indices (or fewer, if not enough remain) form the train set.
    The validation indices are applied to `test_dataset`, so both datasets
    must be index-aligned.

    Args:
        train_dataset: dataset sampled by the 'train' loader.
        test_dataset:  dataset sampled by the 'val' / 'test' loader.
        valid_size:    fraction of the data reserved for validation.
        batch_size:    batch size for both loaders (incomplete last batches
                       are dropped).
        train_size:    maximum number of training samples.
        num_workers:   worker processes per DataLoader.

    Returns:
        dataloaders:   {'train', 'val', 'test'}; 'val' and 'test' are the
                       same DataLoader object.
        dataset_sizes: number of indices actually assigned to each split
                       (before drop_last truncation).
    '''
    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))
    random.shuffle(indices)
    train_idx, valid_idx = indices[split:split+train_size], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    # Options shared by both loaders; shuffling is delegated to the samplers.
    loader_kwargs = dict(batch_size=batch_size, shuffle=False, num_workers=num_workers,
                         pin_memory=True, drop_last=True)

    train_loader = torchdata.DataLoader(train_dataset, sampler=train_sampler, **loader_kwargs)
    test_loader = torchdata.DataLoader(test_dataset, sampler=valid_sampler, **loader_kwargs)
    dataloaders = {'train': train_loader,
                   'val': test_loader,
                   'test': test_loader}
    # Report the real split sizes: the slice above silently yields fewer than
    # `train_size` indices when the dataset is too small.
    dataset_sizes = {'train': len(train_idx),
                     'val': len(valid_idx),
                     'test': len(valid_idx)}
    return dataloaders, dataset_sizes

In [7]:
# Base directory for the datasets; passed to get_imagenetc(), which appends
# 'imagenetc/<corruption>/<severity>' to it.
imagebase = '/local/rcs/ll3504/datasets/'

In [13]:

# Weather-type corruptions.
weather = ["fog", "frost", "snow", "spatter"]
# Blur / noise / brightness corruptions.
digit = ["gaussian_blur", "gaussian_noise", "brightness"]
# Default corruption set for get_imagenetc(); note that it does not contain "frost".
corruption = [
    "fog",
    "snow",
    "spatter",
    "gaussian_blur",
    "gaussian_noise",
    "brightness",
]


In [15]:
def get_imagenetc(imagebase, batch_size=128, sample_size = 128, corruption=corruption):
    '''
    Build dataloaders for every severity level (1 to 5) of each requested
    corruption type.

    For each corruption and severity the images are read from
    '<imagebase>imagenetc/<corruption>/<severity>' via get_dataset() and
    split with split_dataset() using valid_size=0.02.

    Args:
        imagebase:    base directory that contains the 'imagenetc' folder.
        batch_size:   batch size of every loader.
        sample_size:  maximum number of samples in each 'train' split.
        corruption:   iterable of corruption names to load.

    Returns:
        corrupted_dataloaders: dict mapping each corruption name to a list of
            five dicts, ordered by severity (index = severity - 1). Each dict
            holds the 'train', 'val' and 'test' loaders returned by
            split_dataset(); 'val' and 'test' are the same loader and do not
            overlap with 'train'.
    '''

    corrupted_dataloaders = {}

    for corr in corruption:
        dataloader_all_sev = []

        for severity in range(1,6):
            # Folder layout: imagenetc/<corruption>/<severity>
            dataset_name = 'imagenetc/' + corr + '/' + str(severity)
            corr_train_images, corr_test_images = get_dataset(imagebase, dataset_name)

            # The split sizes returned alongside the loaders are not needed here.
            corr_dataloaders, _ = split_dataset(corr_train_images, corr_test_images, valid_size=0.02, batch_size=batch_size, train_size=sample_size)

            dataloader_all_sev.append(corr_dataloaders)

        # Register the finished per-severity list once per corruption.
        corrupted_dataloaders[corr] = dataloader_all_sev

    return corrupted_dataloaders




In [16]:
# Build the loaders for all weather corruptions, then display the
# per-severity loader list for 'frost' as a quick sanity check.
imc_loaders = get_imagenetc(
    imagebase,
    corruption=weather,
    batch_size=64,
    sample_size=49000,
)
imc_loaders['frost']

[{'train': <torch.utils.data.dataloader.DataLoader at 0x7f37234a91c0>,
  'val': <torch.utils.data.dataloader.DataLoader at 0x7f37234a9130>,
  'test': <torch.utils.data.dataloader.DataLoader at 0x7f37234a9130>},
 {'train': <torch.utils.data.dataloader.DataLoader at 0x7f37234a9610>,
  'val': <torch.utils.data.dataloader.DataLoader at 0x7f37234a9280>,
  'test': <torch.utils.data.dataloader.DataLoader at 0x7f37234a9280>},
 {'train': <torch.utils.data.dataloader.DataLoader at 0x7f37234a9c10>,
  'val': <torch.utils.data.dataloader.DataLoader at 0x7f37234a90d0>,
  'test': <torch.utils.data.dataloader.DataLoader at 0x7f37234a90d0>},
 {'train': <torch.utils.data.dataloader.DataLoader at 0x7f37045a7160>,
  'val': <torch.utils.data.dataloader.DataLoader at 0x7f37045a71f0>,
  'test': <torch.utils.data.dataloader.DataLoader at 0x7f37045a71f0>},
 {'train': <torch.utils.data.dataloader.DataLoader at 0x7f37045a75e0>,
  'val': <torch.utils.data.dataloader.DataLoader at 0x7f37045a7670>,
  'test': <torch

In [22]:
def online_evaluate(corrupt_loaders, corrutpion, severity, lr):
    '''
    Wrap a fresh pretrained ResNet-50 in Tent and run it once over the
    'train' loader of the given corruption / severity, measuring top-1
    accuracy of the outputs produced along the way.

    NOTE(review): the wrapper is presumably updated on every forward pass
    (hence "online"); that behaviour lives in the `tent` module and should
    be confirmed there.

    Args:
        corrupt_loaders: dict corruption name -> list of loader dicts indexed
                         by severity - 1 (as built by get_imagenetc).
        corrutpion:      corruption name to evaluate (parameter name kept
                         as-is for compatibility with existing callers).
        severity:        severity level, 1-based.
        lr:              learning rate of the SGD optimizer handed to Tent.

    Returns:
        (tented_resnet50, accuracy): the Tent-wrapped model after the pass
        and the accuracy over all processed samples (a 0-dim tensor once at
        least one batch was seen).

    Raises:
        ZeroDivisionError: if the loader yields no batches.
    '''
    resnet50 = models.resnet50(pretrained=True)
    resnet50 = tent.configure_model(resnet50)
    params, _ = tent.collect_params(resnet50)
    # `SGD` is not imported anywhere in the notebook; reference it through
    # the already-imported torch package so the cell works on a fresh kernel.
    optimizer = torch.optim.SGD(params, lr=lr)
    tented_resnet50 = tent.Tent(resnet50, optimizer).to(device)

    num_correct, num_samples = 0., 0.

    trainloader = corrupt_loaders[corrutpion][severity-1]['train']

    for images, targets in trainloader:
        logits = tented_resnet50(images.to(device))
        predictions = logits.argmax(dim=1)
        # Compare on the CPU because `targets` is never moved to the device.
        num_correct += (predictions.detach().cpu() == targets).float().sum()
        num_samples += len(targets)

    accuracy = num_correct / num_samples
    print(f"Acc =: {num_correct:#5.0f}/{num_samples:#5.0f} ({100 * accuracy: .2f} %), \
            Error =: {100 * (1 - accuracy): .2f} %")

    return tented_resnet50, accuracy

In [32]:
def offline_validate(model, corrupt_loaders, corruption, severity, baseline=False):
    '''
    Measure top-1 accuracy of `model` on the 'val' loader of the given
    corruption / severity.

    Args:
        model:           model to evaluate. With baseline=True it is switched
                         to eval mode via model.eval(); otherwise
                         model.offline_validation() is called first.
        corrupt_loaders: dict corruption name -> list of loader dicts indexed
                         by severity - 1 (as built by get_imagenetc).
        corruption:      corruption name to validate on.
        severity:        severity level, 1-based.
        baseline:        True for a plain (non-Tent) model.

    Returns:
        accuracy over all processed validation samples (a 0-dim tensor once
        at least one batch was seen).

    Raises:
        ZeroDivisionError: if the loader yields no batches.
    '''
    if baseline:
        model.eval()
    else:
        # NOTE(review): assumed to put the Tent wrapper into a mode that does
        # not adapt / backpropagate - confirm against the tent module.
        model.offline_validation()
    valloader = corrupt_loaders[corruption][severity-1]['val']

    num_correct, num_samples = 0., 0.

    # Validation needs no gradients; disabling autograd avoids building a
    # graph for every batch. (`model.no_grad()` does not exist - the context
    # manager lives on the torch package.)
    with torch.no_grad():
        for images, targets in valloader:
            logits = model(images.to(device))
            predictions = logits.argmax(dim=1)
            # Compare on the CPU because `targets` is never moved to the device.
            num_correct += (predictions.detach().cpu() == targets).float().sum()
            num_samples += len(targets)

    accuracy = num_correct / num_samples
    print(f"Validation Acc =: {num_correct:#5.0f}/{num_samples:#5.0f} ({100 * accuracy: .2f} %), \
            Validation Error =: {100 * (1 - accuracy): .2f} %")

    return accuracy

In [24]:
lr = 0.00025
bs = 64
# Severity levels to evaluate; the averages below are taken over exactly
# these levels instead of a separately hard-coded count.
severity_levels = range(1, 6)

# ONLINE adaptation on the "training set", followed by offline validation
# of the adapted model on the held-out split of the same corruption.

for corr in weather:
    online_acc_sum = 0
    validate_acc_sum = 0

    for severity in severity_levels:
        print(f"Corruption: {corr}, Severity: {severity}")

        # online adaptation:
        adapted_model, accuracy = online_evaluate(imc_loaders, corr, severity, lr)
        online_acc_sum += accuracy

        # offline validation:
        val_acc = offline_validate(adapted_model, imc_loaders, corr, severity)
        validate_acc_sum += val_acc

    num_levels = len(severity_levels)
    print(f"Averag online accuracy for {corr}: {100 * (online_acc_sum / num_levels): .2f} %, \
            Averag online error for {corr}: {100 * (1 - online_acc_sum / num_levels): .2f} % ")
    print()
    print(f"Averag validation accuracy for {corr}: {100 * (validate_acc_sum / num_levels): .2f} %, \
            Averag validation error for {corr}: {100 * (1 - validate_acc_sum / num_levels): .2f} % ")
    print("======================================================================================")
    


Corruption: fog, Severity: 1
Acc =: 33715./48960. ( 68.86 %),             Error =:  31.14 %
Validation Acc =:  656./ 960. ( 68.33 %),             Validation Error =:  31.67 %
Corruption: fog, Severity: 2
Acc =: 32627./48960. ( 66.64 %),             Error =:  33.36 %
Validation Acc =:  601./ 960. ( 62.60 %),             Validation Error =:  37.40 %
Corruption: fog, Severity: 3
Acc =: 30822./48960. ( 62.95 %),             Error =:  37.05 %
Validation Acc =:  621./ 960. ( 64.69 %),             Validation Error =:  35.31 %
Corruption: fog, Severity: 4
Acc =: 29398./48960. ( 60.04 %),             Error =:  39.96 %
Validation Acc =:  592./ 960. ( 61.67 %),             Validation Error =:  38.33 %
Corruption: fog, Severity: 5
Acc =: 25396./48960. ( 51.87 %),             Error =:  48.13 %
Validation Acc =:  526./ 960. ( 54.79 %),             Validation Error =:  45.21 %
Averag online accuracy for fog:  62.07 %,             Averag online error for fog:  37.93 % 

Averag validation accuracy for 

In [25]:
# The code of this cell (adapting on snow) has been deleted.
# Its output is kept below for reference, for now.

Corruption: snow, Severity: 1
Acc =: 30466./48960. ( 62.23 %),             Error =:  37.77 %
Validation Acc =:  589./ 960. ( 61.35 %),             Validation Error =:  38.65 %
Corruption: snow, Severity: 2
Acc =: 24128./48960. ( 49.28 %),             Error =:  50.72 %
Validation Acc =:  519./ 960. ( 54.06 %),             Validation Error =:  45.94 %
Corruption: snow, Severity: 3
Acc =: 24433./48960. ( 49.90 %),             Error =:  50.10 %
Validation Acc =:  500./ 960. ( 52.08 %),             Validation Error =:  47.92 %
Corruption: snow, Severity: 4
Acc =: 20096./48960. ( 41.05 %),             Error =:  58.95 %
Validation Acc =:  406./ 960. ( 42.29 %),             Validation Error =:  57.71 %
Corruption: snow, Severity: 5
Acc =: 18757./48960. ( 38.31 %),             Error =:  61.69 %
Validation Acc =:  384./ 960. ( 40.00 %),             Validation Error =:  60.00 %
Averag online accuracy for snow:  48.15 %,             Averag online error for snow:  51.85 % 

Averag validation accura

In [26]:
# Corruptions used for the cross-validation experiment below: the entries of
# `weather` without 'frost'.
weather_us = ['fog', 'snow', 'spatter']


In [29]:
# Cross-corruption validation: adapt a model online on one corruption, then
# validate that adapted model on every corruption in `weather_us`.
# Result layout: cross_valid_acc[adapted_on][validated_on] = accuracy.
cross_valid_acc = {}

# NOTE: `corr`, `severity`, `adapted_model` and the other loop variables stay
# bound at kernel level with their last values after this cell finishes.
for corr in weather_us:
    # Only read by the commented-out summary at the bottom of this cell.
    online_acc_sum = 0
    cross_valid_acc[corr] = {}
    
    # Only severity level 3 is evaluated in this experiment.
    for severity in [3]:
        print(f"Corruption: {corr}, Severity: {severity}")

        # online adaptation:
        adapted_model, accuracy = online_evaluate(imc_loaders, corr, severity, lr)
        online_acc_sum += accuracy
        
        # offline validation:
        for val_corr in weather_us:
            print(f"cross-validating on corruption = {val_corr}")
            val_acc = offline_validate(adapted_model, imc_loaders, val_corr, severity)
            cross_valid_acc[corr][val_corr] = val_acc

#     print(f"Averag online accuracy for {corr}: {100 * (online_acc_sum / 5): .2f} %, \
#             Averag online error for {corr}: {100 * (1 - online_acc_sum / 5): .2f} % ")
#     print()
#     print(f"Averag validation accuracy for {corr}: {100 * (validate_acc_sum / 5): .2f} %, \
#             Averag validation error for {corr}: {100 * (1 - validate_acc_sum / 5): .2f} % ")
    print("======================================================================================")

Corruption: fog, Severity: 3
Acc =: 30799./48960. ( 62.91 %),             Error =:  37.09 %
cross-validating on corruption = fog
Validation Acc =:  632./ 960. ( 65.83 %),             Validation Error =:  34.17 %
cross-validating on corruption = snow
Validation Acc =:  465./ 960. ( 48.44 %),             Validation Error =:  51.56 %
cross-validating on corruption = spatter
Validation Acc =:  568./ 960. ( 59.17 %),             Validation Error =:  40.83 %
Corruption: snow, Severity: 3
Acc =: 24380./48960. ( 49.80 %),             Error =:  50.20 %
cross-validating on corruption = fog
Validation Acc =:  616./ 960. ( 64.17 %),             Validation Error =:  35.83 %
cross-validating on corruption = snow
Validation Acc =:  493./ 960. ( 51.35 %),             Validation Error =:  48.65 %
cross-validating on corruption = spatter
Validation Acc =:  583./ 960. ( 60.73 %),             Validation Error =:  39.27 %
Corruption: spatter, Severity: 3
Acc =: 29558./48960. ( 60.37 %),             Error =

In [35]:
# Baseline: validate an unadapted pretrained ResNet-50 on the same corruptions.
vanilla = models.resnet50(pretrained=True).to(device)

# Severity is stated explicitly instead of relying on whatever value an
# earlier loop happened to leave in `severity`.
baseline_severity = 3

# Baseline results get their own dictionary. Writing them to
# cross_valid_acc[corr] would overwrite the adapted-model results of whichever
# corruption `corr` was last bound to.
baseline_valid_acc = {}
for val_corr in weather_us:
    print(f"cross-validating on corruption = {val_corr}")
    val_acc = offline_validate(vanilla, imc_loaders, val_corr, baseline_severity, baseline=True)
    baseline_valid_acc[val_corr] = val_acc

cross-validating on corruption = fog
Validation Acc =:  449./ 960. ( 46.77 %),             Validation Error =:  53.23 %
cross-validating on corruption = snow
Validation Acc =:  343./ 960. ( 35.73 %),             Validation Error =:  64.27 %
cross-validating on corruption = spatter
Validation Acc =:  483./ 960. ( 50.31 %),             Validation Error =:  49.69 %
