In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as torchF
import torch.jit
import torch.optim as optim

import torchvision.transforms.functional as torchvisionF
from torchvision.transforms import ColorJitter, Compose, Lambda
from numpy import random

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
#from torchvision.models import resnet50, ResNet50_Weights

import matplotlib.pyplot as plt
import numpy as np

from PIL import Image

import os
from tqdm import tqdm
from copy import deepcopy
from time import time
import logging

# ResNet18 on ImageNet-C

In [3]:
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
# model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet34', pretrained=True)
# model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
# model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet101', pretrained=True)
# model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet152', pretrained=True)
model.eval()

Using cache found in C:\Users\duchu/.cache\torch\hub\pytorch_vision_v0.10.0


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [17]:
'''
ResNet model implementation
Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
'''


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion *
                               planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.fc = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out


def ResNet18():
    return ResNet(BasicBlock, [2, 2, 2, 2])


def ResNet34():
    return ResNet(BasicBlock, [3, 4, 6, 3])


def ResNet50():
    return ResNet(Bottleneck, [3, 4, 6, 3])


def ResNet101():
    return ResNet(Bottleneck, [3, 4, 23, 3])


def ResNet152():
    return ResNet(Bottleneck, [3, 8, 36, 3])


def test():
    net = ResNet18()
    y = net(torch.randn(1, 3, 32, 32))
    print(y.size())

### Imbalanced data experiment

In [4]:
dir_data = "./data/Tiny-ImageNet-C"
imbalanced_data_folders = ["brightness", "contrast", "defocus_blur", "elastic_transform", "fog", "frost", "gaussian_noise", "glass_blur", "impulse_noise", "jpeg_compression", "motion_blur", "pixelate", "shot_noise", "snow", "zoom_blur"]
dir_mapping = "mapping.txt"

In [5]:
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [6]:
mapping = {}
with open(os.path.join(dir_data, dir_mapping)) as mapping_file:
    num = 0
    for data in mapping_file:
        file_name = data.split(' ')[0]
        mapping[file_name] = num
        num += 1

In [7]:
def show(dir_image):
    image = Image.open(dir_image)
    image.show()

def get_accuracy(outputs, ground_truth):
    num = len(outputs)
    sum = torch.sum(torch.tensor(outputs) == torch.tensor(ground_truth))
    return sum / num

In [8]:
imagenet_dir_data = "./data/Tiny-ImageNet-C"
imagenet_imbalanced_data_folders = ["brightness", "contrast", "defocus_blur", "elastic_transform", "fog", "frost", "gaussian_noise", "glass_blur", "impulse_noise", "jpeg_compression", "motion_blur", "pixelate", "shot_noise", "snow", "zoom_blur"]
imagenet_dir_mapping = "mapping.txt"

all_imagenet = []
for imbalanced_data_folder in imbalanced_data_folders:
    for index in range(1, 6):
        dir_corrupt = os.path.join(dir_data, imbalanced_data_folder, str(index))
        dir_classes = os.listdir(dir_corrupt)
        for dir_class in tqdm(dir_classes):
            data_folder = os.path.join(dir_corrupt, dir_class)
            label = mapping[dir_class]
            images = os.listdir(data_folder)
            for image in images:
                dir_image = os.path.join(data_folder, image)
                all_imagenet.append([dir_image, label])
random.shuffle(all_imagenet)


  0%|          | 0/200 [00:00<?, ?it/s]

100%|██████████| 200/200 [00:00<00:00, 4184.18it/s]
100%|██████████| 200/200 [00:00<00:00, 3733.31it/s]
100%|██████████| 200/200 [00:00<00:00, 1682.52it/s]
100%|██████████| 200/200 [00:00<00:00, 3540.03it/s]
100%|██████████| 200/200 [00:00<00:00, 4425.16it/s]
100%|██████████| 200/200 [00:00<00:00, 4898.57it/s]
100%|██████████| 200/200 [00:00<00:00, 4085.13it/s]
100%|██████████| 200/200 [00:00<00:00, 4234.85it/s]
100%|██████████| 200/200 [00:00<00:00, 2584.39it/s]
100%|██████████| 200/200 [00:00<00:00, 4074.85it/s]
100%|██████████| 200/200 [00:00<00:00, 3982.34it/s]
100%|██████████| 200/200 [00:00<00:00, 4447.64it/s]
100%|██████████| 200/200 [00:00<00:00, 1557.05it/s]
100%|██████████| 200/200 [00:00<00:00, 3361.44it/s]
100%|██████████| 200/200 [00:00<00:00, 4052.88it/s]
100%|██████████| 200/200 [00:00<00:00, 5137.97it/s]
100%|██████████| 200/200 [00:00<00:00, 4776.05it/s]
100%|██████████| 200/200 [00:00<00:00, 3787.64it/s]
100%|██████████| 200/200 [00:00<00:00, 4221.17it/s]
100%|███████

In [31]:
"""
1. Evaluate the model on the corrupted data without using test-time adaptation technique
"""

outputs = []
ground_truth = []
for imbalanced_data_folder in imbalanced_data_folders:
    for index in range(1, 6):
        dir_corrupt = os.path.join(dir_data, imbalanced_data_folder)
        dir_corrupt = os.path.join(dir_corrupt, str(index))
        dir_classes = os.listdir(dir_corrupt)
        for dir_class in tqdm.tqdm(dir_classes):
            data_folder = os.path.join(dir_corrupt, dir_class)
            label = mapping[dir_class]
            images = os.listdir(data_folder)
            for image in images:
                dir_image = os.path.join(data_folder, image)
                image = Image.open(dir_image)

                input = preprocess(image).unsqueeze(0)

                with torch.no_grad():
                    output = model(input)

                    ground_truth.append(label)
                    outputs.append(output)

100%|██████████| 200/200 [03:29<00:00,  1.05s/it]
100%|██████████| 200/200 [03:34<00:00,  1.07s/it]
100%|██████████| 200/200 [03:35<00:00,  1.08s/it]
100%|██████████| 200/200 [03:34<00:00,  1.07s/it]
100%|██████████| 200/200 [03:35<00:00,  1.08s/it]
100%|██████████| 200/200 [03:36<00:00,  1.08s/it]
100%|██████████| 200/200 [03:40<00:00,  1.10s/it]
100%|██████████| 200/200 [03:39<00:00,  1.10s/it]
100%|██████████| 200/200 [03:40<00:00,  1.10s/it]
100%|██████████| 200/200 [03:38<00:00,  1.09s/it]
100%|██████████| 200/200 [03:39<00:00,  1.10s/it]
100%|██████████| 200/200 [03:40<00:00,  1.10s/it]
100%|██████████| 200/200 [03:41<00:00,  1.11s/it]
100%|██████████| 200/200 [03:39<00:00,  1.10s/it]
100%|██████████| 200/200 [03:41<00:00,  1.11s/it]
100%|██████████| 200/200 [03:41<00:00,  1.11s/it]
100%|██████████| 200/200 [03:40<00:00,  1.10s/it]
100%|██████████| 200/200 [03:41<00:00,  1.11s/it]
100%|██████████| 200/200 [03:41<00:00,  1.11s/it]
100%|██████████| 200/200 [03:40<00:00,  1.10s/it]


In [37]:
for i in range(len(outputs)):
    outputs[i] = np.argmax(outputs[i])

In [39]:
accuracy = get_accuracy(outputs, ground_truth)
print(accuracy)

tensor(0.0988)


### CoTTA (Continual Test-time Adaptation)

### TENT (Test-time Adaptation by Entropy Minimization)

In [9]:
"""
Builds upon: https://github.com/qinenergy/cotta
Corresponding paper: https://arxiv.org/abs/2006.10726
"""

class Tent(nn.Module):
    """
    Tent adapts a model by entropy minimization during testing.
    Once tented, a model adapts itself by updating on every forward.
    """

    def __init__(self, model, optimizer, steps = 1):
        super().__init__()
        self.model = model
        self.optimizer = optimizer
        self.steps = steps
    
    def forward(self, x):
        for _ in range(self.steps):
            outputs = tent_forward_and_adapt(x, self.model, self.optimizer)
        return outputs

@torch.jit.script
def tent_softmax_entropy(x: torch.Tensor) -> torch.Tensor:
    """Entropy of softmax distribution from logits."""
    return -(x.softmax(1) * x.log_softmax(1)).sum(1)


@torch.enable_grad()  # ensure grads in possible no grad context for testing
def tent_forward_and_adapt(x, model, optimizer):
    """Forward and adapt model on batch of data.

    Measure entropy of the model prediction, take gradients, and update params.
    """
    # forward
    outputs = model(x)
    # adapt
    loss = tent_softmax_entropy(outputs).mean(0)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return outputs

def tent_configure_model(model):
    """Configure model for use with tent."""
    # train mode, because tent optimizes the model to minimize entropy
    model.train()
    # disable grad, to (re-)enable only what tent updates
    model.requires_grad_(False)
    # configure norm for tent updates: enable grad + force batch statisics
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.requires_grad_(True)
            # force use of batch stats in train and eval modes
            m.track_running_stats = False
            m.running_mean = None
            m.running_var = None
    return model

def tent_collect_params(model):
    """Collect the affine scale + shift parameters from batch norms.

    Walk the model's modules and collect all batch normalization parameters.
    Return the parameters and their names.

    Note: other choices of parameterization are possible!
    """
    params = []
    names = []
    for nm, m in model.named_modules():
        if isinstance(m, nn.BatchNorm2d):
            for np, p in m.named_parameters():
                if np in ['weight', 'bias']:  # weight is scale, bias is shift
                    params.append(p)
                    names.append(f"{nm}.{np}")
    return params, names

def tent_check_model(model):
    """Check model for compatability with tent."""
    is_training = model.training
    assert is_training, "tent needs train mode: call model.train()"
    param_grads = [p.requires_grad for p in model.parameters()]
    has_any_params = any(param_grads)
    has_all_params = all(param_grads)
    assert has_any_params, "tent needs params to update: " \
                           "check which require grad"
    assert not has_all_params, "tent should not update all params: " \
                               "check which require grad"
    has_bn = any([isinstance(m, nn.BatchNorm2d) for m in model.modules()])
    assert has_bn, "tent needs normalization for its optimization"

In [14]:
def setup_raw(model):
    model.eval()
    return model

def setup_tent(model, steps = 1):
    """
    Set up tent adaptation.
    """
    model = tent_configure_model(model)
    params, param_names = tent_collect_params(model)
    optimizer = optim.Adam(params = params)
    tent_model = Tent(model = model, optimizer = optimizer, steps = steps)
    return tent_model

def setup_cotta(model, steps = 1):
    """
    Set up CoTTA adaptation.
    """
    model = cotta_configure_model(model)
    params, param_names = cotta_collect_params(model)
    optimizer = optim.Adam(params = params, lr = 0.01)
    cotta_model = CoTTA(model = model, optimizer = optimizer, steps = steps)
    return cotta_model

In [15]:
def evaluate_tiny_imagenet_C(model, batch_size = 100):
    # evaluate on each severity and type of corruption in turn
    outputs = torch.tensor([])
    ground_truth = []
    images = torch.tensor([])

    while(batch_size):
        dir_image, label = all_imagenet[random.randint(len(all_imagenet))]
        image = Image.open(dir_image)
        input = preprocess(image).unsqueeze(0)
        if input.numel() == 0:
            images = input
        else:
            images = torch.cat((images, input))
        
        ground_truth.append(label)
        batch_size -= 1
    
    images = torch.tensor(images)
    with torch.no_grad():
        output = model.forward(images)
        outputs = np.argmax(output, axis = 1)
    
    #print(outputs)
    #print(ground_truth)
    accuracy = get_accuracy(outputs, ground_truth)
    return accuracy

In [12]:
tent_model = setup_tent(model)
tent_check_model(tent_model)
accuracy = evaluate_tiny_imagenet_C(tent_model)
print(accuracy)

  images = torch.tensor(images)


tensor(0.0800)


  sum = torch.sum(torch.tensor(outputs) == torch.tensor(ground_truth))


In [16]:
tent_model = setup_tent(model, steps = 1)
tent_check_model(tent_model)
for epoch in range(100):
    accuracy = evaluate_tiny_imagenet_C(tent_model)
    print(accuracy)
print(accuracy)

  images = torch.tensor(images)
  sum = torch.sum(torch.tensor(outputs) == torch.tensor(ground_truth))


tensor(0.0600)
tensor(0.0800)
tensor(0.1100)
tensor(0.1000)
tensor(0.0500)
tensor(0.0600)
tensor(0.0700)
tensor(0.1500)
tensor(0.0700)
tensor(0.0200)
tensor(0.0900)
tensor(0.0400)
tensor(0.1000)
tensor(0.0900)
tensor(0.0700)
tensor(0.0700)
tensor(0.0900)
tensor(0.1000)
tensor(0.0700)
tensor(0.1200)
tensor(0.0900)
tensor(0.0600)
tensor(0.0600)
tensor(0.1200)
tensor(0.0700)
tensor(0.0800)
tensor(0.0800)
tensor(0.1300)
tensor(0.1000)
tensor(0.0500)
tensor(0.0700)
tensor(0.0500)
tensor(0.1400)
tensor(0.0800)
tensor(0.0300)
tensor(0.0700)
tensor(0.0700)
tensor(0.0600)
tensor(0.0500)
tensor(0.0900)
tensor(0.0800)
tensor(0.0500)
tensor(0.0200)
tensor(0.0500)
tensor(0.0600)
tensor(0.0200)
tensor(0.0700)
tensor(0.0600)
tensor(0.0500)
tensor(0.0700)
tensor(0.0500)
tensor(0.0500)
tensor(0.0500)
tensor(0.0700)
tensor(0.0900)
tensor(0.0600)
tensor(0.0800)
tensor(0.0700)
tensor(0.0700)
tensor(0.0900)
tensor(0.0200)
tensor(0.0200)
tensor(0.0300)
tensor(0.0700)
tensor(0.0600)
tensor(0.0500)
tensor(0.0