In [1]:
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torch.utils.data as data
from torch.optim import SGD
import torchvision.transforms as transforms
import torchvision.models as models
import numpy as np
import random
import torch.utils.data as torchdata
from torch.utils.data import SubsetRandomSampler
import matplotlib.pyplot as plt
import time
# from train_model import train_model
# from test_model import test_model
%matplotlib inline
import pickle

In [2]:
import tent

In [3]:
import os

# Expose only physical GPU 1 to this process, then fall back to CPU if no CUDA device is visible.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

In [4]:
# Map 0-based Caltech-256 class indices to human-readable class names.
# Class folders are named "<NNN>.<name>" (e.g. "002.american-flag" -> nameDic_cal[1] == 'american-flag').
# Entries that do not follow that pattern (hidden files, stray files) are skipped instead of crashing int().
filePath = '/local/rcs/ll3504/datasets/256_ObjectCategories/'
namelist = os.listdir(filePath)
nameDic_cal = {}
for name in namelist:
    splits = name.split(".")
    if len(splits) < 2 or not splits[0].isdigit():
        continue
    nameDic_cal[int(splits[0])-1] = splits[1]
print(nameDic_cal[1])

american-flag


In [5]:
def get_dataset(path='/database', dataset_name='caltech-256-common'):
    '''
    Load one ImageFolder-style directory twice, as a "train" view and a "test" view.

    There is no held-out test split: both datasets index the same files, and
    ``split_dataset`` later carves disjoint index subsets out of them. Both views
    currently apply the SAME transform (ToTensor + ImageNet mean/std normalization);
    add any train-time augmentation to the train view here if needed.

    Args:
        path:         Root directory that contains the dataset folder. A trailing
                      slash is optional (paths are joined with ``os.path.join``).
        dataset_name: Sub-path of the dataset folder under ``path``.

    Returns:
        (tr_dataset, te_dataset): two ``torchvision.datasets.ImageFolder`` objects
        over the same directory.
    '''
    def _normalize_transform():
        # One place for the shared preprocessing; a fresh Compose per view keeps them independent.
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    # os.path.join avoids gluing "path" and "dataset_name" together when "path" has no trailing slash;
    # the final '' keeps the trailing separator the original string concatenation produced.
    dataset_dir = os.path.join(path, dataset_name, '')

    tr_dataset = datasets.ImageFolder(dataset_dir, _normalize_transform())
    te_dataset = datasets.ImageFolder(dataset_dir, _normalize_transform())

    return tr_dataset, te_dataset

In [6]:
def split_dataset(train_dataset, test_dataset, valid_size=0.02, batch_size=128, train_size = 128, num_workers=48):
    '''
    Split a dataset into disjoint train / validation index subsets and wrap them in DataLoaders.

    After a random shuffle (global ``random`` state), the first
    ``floor(valid_size * len(train_dataset))`` indices form the validation subset
    (read from ``test_dataset``), and the next ``train_size`` indices form the
    training subset (read from ``train_dataset``). The 'val' and 'test' loaders are
    the same object.

    Only ``len(train_dataset)`` is inspected, so both datasets are assumed to index
    the same samples in the same order (as ``get_dataset`` produces).

    Args:
        train_dataset: Dataset the 'train' loader draws from.
        test_dataset:  Dataset the 'val'/'test' loader draws from.
        valid_size:    Fraction of samples held out for validation.
        batch_size:    Batch size for every loader.
        train_size:    Maximum number of training samples (fewer if the dataset is too small).
        num_workers:   Worker processes per DataLoader (default 48, the previous hard-coded value).

    Returns:
        dataloaders:   dict with 'train', 'val', 'test' DataLoaders ('val' is 'test').
        dataset_sizes: dict with the number of samples each loader actually yields.
    '''
    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))
    random.shuffle(indices)
    train_idx, valid_idx = indices[split:split+train_size], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    # shuffle must stay False: ordering is already randomized by the samplers.
    train_loader = torchdata.DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True, drop_last=False, sampler=train_sampler)
    test_loader = torchdata.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True, drop_last=False, sampler=valid_sampler)
    dataloaders = {'train': train_loader,
                   'val': test_loader,
                   'test': test_loader}
    # Report what the samplers really hold: train_idx is shorter than train_size when
    # the dataset has fewer than split + train_size samples.
    dataset_sizes = {'train': len(train_idx),
                     'val': len(valid_idx),
                     'test': len(valid_idx)}
    return dataloaders, dataset_sizes

In [7]:
imagebase = "/local/rcs/ll3504/datasets/"  # root holding the imagenetc/ tree (trailing slash is relied on by string concatenation)

In [8]:

# ImageNet-C corruption folder names, grouped into the families used in this notebook.
corruption = "fog snow spatter gaussian_blur gaussian_noise brightness".split()
weather = "fog snow spatter frost".split()
# NOTE(review): an earlier comment said gaussian_noise is missing on disk, but the
# cross-validation outputs further down do evaluate it - confirm and drop this note.
digital = "gaussian_noise gaussian_blur brightness defocus_blur contrast".split()


In [9]:
def get_imagenetc(imagebase, batch_size=128, sample_size = 128, corruption=corruption):
    '''
    Build a clean ImageNet-val reference loader plus per-corruption, per-severity ImageNet-C loaders.

    Args:
        imagebase:   Root directory (string, expected to end with '/'). ImageNet is read from
                     ``imagebase + 'imagenetc/'`` and each corruption from
                     ``imagebase + 'imagenetc/<corr>/<severity>'``.
        batch_size:  Batch size for every DataLoader.
        sample_size: Number of images per corruption/severity placed in the 'train'
                     (adaptation) loader; see ``split_dataset``.
        corruption:  Iterable of corruption folder names to load (default: the module-level
                     ``corruption`` list as it was when this function was defined).

    Returns:
        clean_val_loader:      DataLoader over a random 2% subset of the ImageNet validation split
                               (unseeded: uses the global ``random`` state).
        corrupted_dataloaders: dict mapping 'clean' -> clean_val_loader, and each corruption name ->
                               list of 5 dicts (index = severity - 1), each holding the 'train',
                               'val' and 'test' DataLoaders returned by ``split_dataset``.
    '''
    corrupted_dataloaders = {}

    # NOTE: CenterCrop is a no-op here because Resize already yields exactly 224x224.
    imagenet_val = datasets.ImageNet(imagebase+'imagenetc/', split='val', transform=transforms.Compose([
            transforms.Resize([224, 224]),
            transforms.CenterCrop((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ]),
                                   target_transform=None)

    # Clean reference: a random 2% subset of the validation split.
    random_indices = random.sample(range(0, len(imagenet_val)), int(len(imagenet_val)*0.02))
    imagenet_val_subset = data.Subset(imagenet_val, random_indices)
    clean_val_loader = torch.utils.data.DataLoader(imagenet_val_subset,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=48)
    corrupted_dataloaders['clean'] = clean_val_loader

    # For every corruption, build train/val/test loaders for each severity folder 1..5.
    #   train: adaptation data, disjoint from val/test
    #   val/test: the same held-out loader
    for corr in corruption:
        dataloader_all_sev = []

        for severity in range(1, 6):
            dataset_name = 'imagenetc/' + corr + '/' + str(severity)
            corr_train_images, corr_test_images = get_dataset(imagebase, dataset_name)
            corr_dataloaders, _ = split_dataset(corr_train_images, corr_test_images, valid_size=0.02, batch_size=batch_size, train_size=sample_size)
            dataloader_all_sev.append(corr_dataloaders)

        # Store once all severities are built (previously re-assigned on every severity iteration).
        corrupted_dataloaders[corr] = dataloader_all_sev

    return clean_val_loader, corrupted_dataloaders




In [10]:
# Build every loader once: the clean ImageNet-val reference plus all weather + digital
# ImageNet-C corruptions at severities 1-5; each 'train' split gets sample_size=49000 images.
clean_loader, imc_loaders = get_imagenetc(
    imagebase,
    batch_size=64,
    sample_size=49000,
    corruption=weather + digital,
)
clean_loader, imc_loaders

(<torch.utils.data.dataloader.DataLoader at 0x7f978eab5820>,
 {'clean': <torch.utils.data.dataloader.DataLoader at 0x7f978eab5820>,
  'fog': [{'train': <torch.utils.data.dataloader.DataLoader at 0x7f9786b0fac0>,
    'val': <torch.utils.data.dataloader.DataLoader at 0x7f9786b0fb50>,
    'test': <torch.utils.data.dataloader.DataLoader at 0x7f9786b0fb50>},
   {'train': <torch.utils.data.dataloader.DataLoader at 0x7f9786b0ff40>,
    'val': <torch.utils.data.dataloader.DataLoader at 0x7f9786b0ffd0>,
    'test': <torch.utils.data.dataloader.DataLoader at 0x7f9786b0ffd0>},
   {'train': <torch.utils.data.dataloader.DataLoader at 0x7f97835c9400>,
    'val': <torch.utils.data.dataloader.DataLoader at 0x7f97835c9490>,
    'test': <torch.utils.data.dataloader.DataLoader at 0x7f97835c9490>},
   {'train': <torch.utils.data.dataloader.DataLoader at 0x7f97835c9880>,
    'val': <torch.utils.data.dataloader.DataLoader at 0x7f97835c9910>,
    'test': <torch.utils.data.dataloader.DataLoader at 0x7f97835c9

In [11]:
def get_imagenet(imagebase, batch_size=128, sample_size = 128):
    '''
    Return reference loaders over a random 2% subset of the ImageNet validation split.

    Args:
        imagebase:   Root directory (string, expected to end with '/'); ImageNet is read
                     from ``imagebase + 'imagenetc/'``.
        batch_size:  Batch size of the returned loader.
        sample_size: Currently unused; kept so existing callers keep working.

    Returns:
        dict {'val': loader, 'test': loader}; both keys refer to the same shuffled DataLoader.
    '''
    # NOTE: CenterCrop is a no-op here because Resize already yields exactly 224x224.
    imagenet_val = datasets.ImageNet(imagebase+'imagenetc/', split='val', transform=transforms.Compose([
            transforms.Resize([224, 224]),
            transforms.CenterCrop((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ]),
                                   target_transform=None)

    # Unseeded 2% subsample (global ``random`` state), so the subset differs between runs.
    random_indices = random.sample(range(0, len(imagenet_val)), int(len(imagenet_val)*0.02))
    imagenet_val_subset = data.Subset(imagenet_val, random_indices)
    val_loader = torch.utils.data.DataLoader(imagenet_val_subset,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=48)
    return {'val': val_loader,
            'test': val_loader}

In [12]:
def online_evaluate(corrupt_loaders, corrutpion, severity, lr):
    '''
    Tent-adapt a fresh ImageNet-pretrained ResNet-50 on one corruption/severity and score it online.

    Each batch of the corruption's 'train' loader is passed once through the Tent-wrapped
    model, and the predictions returned by that call are compared with the targets.

    Args:
        corrupt_loaders: dict from ``get_imagenetc``: corruption name -> list of per-severity
                         loader dicts (index = severity - 1).
        corrutpion:      Corruption name (key into ``corrupt_loaders``); name kept as-is for callers.
        severity:        1-based severity level.
        lr:              SGD learning rate for the parameters returned by ``tent.collect_params``.

    Returns:
        (tented_resnet50, accuracy, adapt_time): the adapted Tent model, online accuracy as a
        0-dim tensor, and wall-clock seconds spent in the adaptation loop only (model
        construction and pretrained-weight loading are excluded).

    Raises:
        ValueError: if the 'train' loader yields no samples.
    '''
    resnet50 = models.resnet50(pretrained=True)
    resnet50 = tent.configure_model(resnet50)
    params, _ = tent.collect_params(resnet50)
    optimizer = SGD(params, lr=lr)
    tented_resnet50 = tent.Tent(resnet50, optimizer).to(device)

    trainloader = corrupt_loaders[corrutpion][severity-1]['train']

    num_correct, num_samples = 0., 0.

    # Start the clock only now so the reported time measures adaptation, not model setup.
    # The per-batch .cpu() copy forces a device sync, so wall-clock timing is meaningful on GPU.
    start = time.time()
    for images, targets in trainloader:
        logits = tented_resnet50(images.to(device))
        predictions = logits.argmax(dim=1)
        num_correct += (predictions.detach().cpu() == targets).float().sum()
        num_samples += len(targets)
    adapt_time = time.time() - start

    if num_samples == 0:
        raise ValueError(f"'train' loader for {corrutpion!r}, severity {severity} yielded no samples")

    accuracy = num_correct / num_samples
    print(f"Acc =: {num_correct:#5.0f}/{num_samples:#5.0f} ({100 * accuracy: .2f} %), "
          f"Error =: {100 * (1 - accuracy): .2f} %")
    print(f"Adaptation time for one epoch on {num_samples} images takes {adapt_time}s")
    return tented_resnet50, accuracy, adapt_time


In [22]:
def offline_validate(model, corrupt_loaders, corruption, severity, baseline=False, clean_data=False, bn_adapt=False):
    '''
    Evaluate ``model`` on a held-out loader and print/return its accuracy.

    Args:
        model:           A plain ``nn.Module`` when ``baseline`` is True; otherwise a Tent
                         wrapper exposing ``.model`` and ``.offline_validation()``
                         (the latter is defined in the local ``tent`` module, not visible here -
                         presumably it disables adaptation; TODO confirm).
        corrupt_loaders: dict from ``get_imagenetc``.
        corruption:      Corruption name; ignored when ``clean_data`` is True.
        severity:        1-based severity; ignored when ``clean_data`` is True.
        baseline:        True for an un-adapted model; its forward passes run under ``torch.no_grad()``.
        clean_data:      Evaluate on ``corrupt_loaders['clean']`` instead of the corruption's 'val' loader.
        bn_adapt:        Baseline: put only the BatchNorm layers in train mode so they normalise
                         with per-batch statistics. The model's buffers (BN running stats) are
                         snapshotted and restored afterwards, and the model is returned to eval
                         mode, so later evaluations of the same model are not affected.
                         Tent model: skip forcing ``model.model.eval()``.

    Returns:
        Accuracy as a 0-dim tensor (correct / total).

    Raises:
        ValueError: if the selected loader yields no samples.
    '''
    import contextlib

    saved_buffers = None
    if baseline:
        model.eval()
        if bn_adapt:
            # Train-mode BN overwrites running_mean/var; keep a copy so the change does not leak.
            saved_buffers = {name: buf.detach().clone() for name, buf in model.named_buffers()}
            for module in model.modules():
                if isinstance(module, nn.modules.batchnorm._BatchNorm):
                    module.train()
    else:
        if not bn_adapt:  # If we don't adapt the BN statistics at validation time, set to eval mode.
            model.model.eval()
        model.offline_validation()

    if clean_data:
        valloader = corrupt_loaders['clean']
    else:
        valloader = corrupt_loaders[corruption][severity-1]['val']

    # The baseline never needs autograd. The Tent wrapper is left in whatever grad mode it
    # manages itself, since its forward may rely on gradients.
    grad_ctx = torch.no_grad() if baseline else contextlib.nullcontext()

    num_correct, num_samples = 0., 0.
    try:
        with grad_ctx:
            for images, targets in valloader:
                logits = model(images.to(device))
                predictions = logits.argmax(dim=1)
                num_correct += (predictions.detach().cpu() == targets).float().sum()
                num_samples += len(targets)
    finally:
        if saved_buffers is not None:
            with torch.no_grad():
                for name, buf in model.named_buffers():
                    buf.copy_(saved_buffers[name])
            model.eval()

    if num_samples == 0:
        raise ValueError("validation loader yielded no samples")

    accuracy = num_correct / num_samples
    print(f"Validation Acc =: {num_correct:#5.0f}/{num_samples:#5.0f} ({100 * accuracy: .2f} %), "
          f"Validation Error =: {100 * (1 - accuracy): .2f} %")

    return accuracy

# Replicating TENT

In [14]:
lr = 0.00025
bs = 64
severities = range(1, 6)

# ONLINE Tent adaptation on each corruption's "training" split (a fresh model per severity),
# followed by OFFLINE validation of the adapted model on the held-out 'val' split.
for corr in digital:
    online_acc_sum = 0
    validate_acc_sum = 0

    for severity in severities:
        print(f"Corruption: {corr}, Severity: {severity}")

        # online adaptation:
        adapted_model, accuracy, _ = online_evaluate(imc_loaders, corr, severity, lr)
        online_acc_sum += accuracy

        # offline validation:
        val_acc = offline_validate(adapted_model, imc_loaders, corr, severity)
        validate_acc_sum += val_acc

    # Average over however many severities were actually run (was a hard-coded 5).
    n_sev = len(severities)
    print(f"Average online accuracy for {corr}: {100 * (online_acc_sum / n_sev): .2f} %, "
          f"Average online error for {corr}: {100 * (1 - online_acc_sum / n_sev): .2f} % ")
    print()
    print(f"Average validation accuracy for {corr}: {100 * (validate_acc_sum / n_sev): .2f} %, "
          f"Average validation error for {corr}: {100 * (1 - validate_acc_sum / n_sev): .2f} % ")
    print("======================================================================================")
    


Corruption: gaussian_noise, Severity: 1


Exception ignored in: <function _releaseLock at 0x7f99a5baa5e0>
Traceback (most recent call last):
  File "/home/zw2774/bin/anaconda3/lib/python3.9/logging/__init__.py", line 227, in _releaseLock
    def _releaseLock():
KeyboardInterrupt: 


RuntimeError: DataLoader worker (pid(s) 15231, 15279, 15327, 15375, 15423, 15471, 15522, 15570, 15625, 15697, 15746, 15794, 15843, 15891, 15948, 15997, 16057, 16105, 16153, 16201, 16249, 16301, 16372, 16420, 16468, 16516, 16564, 16613, 16661, 16709, 16757, 16805, 16853, 16901, 16949, 16997, 17045, 17093, 17141, 17190, 17241, 17289, 17337) exited unexpectedly

In [None]:
# A deleted cell of code that ran Tent adaptation on snow.
# Its result is kept here for reference, for now.

# 7x7 Table - adapt + cross validate

In [15]:
# Corruptions used for the cross-validation table: three weather + three digital.
weather_us = "fog snow spatter".split()
digital_us = "gaussian_noise gaussian_blur brightness".split()
[*weather_us, *digital_us]

['fog', 'snow', 'spatter', 'gaussian_noise', 'gaussian_blur', 'brightness']

In [16]:
# Cross-validation table: Tent-adapt on one corruption (severity 3), then evaluate the adapted
# model offline on clean ImageNet-val and on the same severity of every corruption in the list.
cross_valid_acc = {}

for corr in weather_us + digital_us:
    online_acc_sum = 0
    cross_valid_acc[corr] = {}

    for severity in (3,):
        print(f"Corruption: {corr}, Severity: {severity}")

        # Online Tent adaptation on this corruption's 'train' split.
        adapted_model, accuracy, _ = online_evaluate(imc_loaders, corr, severity, lr)
        online_acc_sum += accuracy

        # Offline evaluation on the clean reference subset ...
        print("cross-validating on clean data")
        cross_valid_acc[corr]['clean'] = val_acc = offline_validate(
            adapted_model, imc_loaders, corruption=None, severity=None, clean_data=True)

        # ... and on the 'val' split of every corruption at the same severity.
        for val_corr in weather_us + digital_us:
            print(f"cross-validating on corruption = {val_corr}")
            cross_valid_acc[corr][val_corr] = val_acc = offline_validate(
                adapted_model, imc_loaders, val_corr, severity)

    print("======================================================================================")

Corruption: fog, Severity: 3
Acc =: 30823./49000. ( 62.90 %),             Error =:  37.10 %
Adaptation time for one epoch on 49000.0 images takes 133.26694416999817s
cross-validating on clean data
Validation Acc =:  741./1000. ( 74.10 %),             Validation Error =:  25.90 %
cross-validating on corruption = fog
Validation Acc =:  644./1000. ( 64.40 %),             Validation Error =:  35.60 %
cross-validating on corruption = snow
Validation Acc =:  494./1000. ( 49.40 %),             Validation Error =:  50.60 %
cross-validating on corruption = spatter
Validation Acc =:  578./1000. ( 57.80 %),             Validation Error =:  42.20 %
cross-validating on corruption = gaussian_noise
Validation Acc =:  472./1000. ( 47.20 %),             Validation Error =:  52.80 %
cross-validating on corruption = gaussian_blur
Validation Acc =:  456./1000. ( 45.60 %),             Validation Error =:  54.40 %
cross-validating on corruption = brightness
Validation Acc =:  737./1000. ( 73.70 %),         

In [19]:
# Baseline row: the un-adapted ImageNet-pretrained ResNet-50, evaluated at severity 3
# on the clean reference subset and on every cross-validation corruption.
severity = 3
baseline_valid_acc = {}

vanilla = models.resnet50(pretrained=True).to(device)

print("cross-validating on clean data")
baseline_valid_acc['clean'] = val_acc = offline_validate(
    vanilla, imc_loaders, corruption=None, severity=None, clean_data=True, baseline=True)

for val_corr in weather_us + digital_us:
    print(f"cross-validating on corruption = {val_corr}")
    baseline_valid_acc[val_corr] = val_acc = offline_validate(
        vanilla, imc_loaders, val_corr, severity, baseline=True)

Batch BN adaptation at test time:
Validation Acc =:  739./1000. ( 73.90 %),             Validation Error =:  26.10 %
cross-validating on clean data
Validation Acc =:  753./1000. ( 75.30 %),             Validation Error =:  24.70 %
cross-validating on corruption = fog
Validation Acc =:  381./1000. ( 38.10 %),             Validation Error =:  61.90 %
cross-validating on corruption = snow
Validation Acc =:  348./1000. ( 34.80 %),             Validation Error =:  65.20 %
cross-validating on corruption = spatter
Validation Acc =:  482./1000. ( 48.20 %),             Validation Error =:  51.80 %
cross-validating on corruption = gaussian_noise
Validation Acc =:  259./1000. ( 25.90 %),             Validation Error =:  74.10 %
cross-validating on corruption = gaussian_blur
Validation Acc =:  368./1000. ( 36.80 %),             Validation Error =:  63.20 %
cross-validating on corruption = brightness
Validation Acc =:  715./1000. ( 71.50 %),             Validation Error =:  28.50 %


In [23]:
# Baseline with test-time BN adaptation (offline_validate's bn_adapt=True): BatchNorm layers are
# put in train mode, so they normalise with the statistics of each test batch.
# Results go into their own dict; previously they overwrote the un-adapted baseline numbers
# stored in baseline_valid_acc[<corruption>] by the baseline cell.
# Uses `severity` as set in the baseline cell.
bn_adapt_valid_acc = {}

print(f'Batch BN adaptation at test time:')
bn_adapt_val_acc = offline_validate(vanilla, imc_loaders, corruption=None, severity=None, clean_data=True, baseline=True, bn_adapt=True)
bn_adapt_valid_acc['clean'] = bn_adapt_val_acc
baseline_valid_acc['BN_adapt'] = bn_adapt_val_acc  # kept so existing readers of this key still work

for val_corr in weather_us + digital_us:
    print(f"Batch BN adaptation on corruption = {val_corr}")
    val_acc = offline_validate(vanilla, imc_loaders, val_corr, severity, baseline=True, bn_adapt=True)
    bn_adapt_valid_acc[val_corr] = val_acc

Batch BN adaptation at test time:
Validation Acc =:  743./1000. ( 74.30 %),             Validation Error =:  25.70 %
Batch BN adaptation on corruption = fog
Validation Acc =:  618./1000. ( 61.80 %),             Validation Error =:  38.20 %
Batch BN adaptation on corruption = snow
Validation Acc =:  478./1000. ( 47.80 %),             Validation Error =:  52.20 %
Batch BN adaptation on corruption = spatter
Validation Acc =:  566./1000. ( 56.60 %),             Validation Error =:  43.40 %
Batch BN adaptation on corruption = gaussian_noise
Validation Acc =:  455./1000. ( 45.50 %),             Validation Error =:  54.50 %
Batch BN adaptation on corruption = gaussian_blur
Validation Acc =:  430./1000. ( 43.00 %),             Validation Error =:  57.00 %
Batch BN adaptation on corruption = brightness
Validation Acc =:  738./1000. ( 73.80 %),             Validation Error =:  26.20 %


In [19]:
print(repr(vanilla))  # dump the baseline ResNet-50 module tree

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 