In [None]:
import os
# Select the GPU before torch is imported (CUDA device visibility must be fixed
# before CUDA is initialised). setdefault keeps a value the user already exported
# instead of silently overwriting it.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
from torchvision import datasets
from torchvision.transforms import Compose, Resize, Lambda, ToTensor, Grayscale, ToPILImage
from torch.utils.data import DataLoader

from imagenet_c import corrupt
from imagenet_c.corruptions import (
    gaussian_noise, shot_noise, impulse_noise, defocus_blur,
    glass_blur, motion_blur, zoom_blur, snow, frost, fog,
    brightness, contrast, elastic_transform, pixelate, jpeg_compression,
    speckle_noise, gaussian_blur, spatter, saturate)
from models import MLP, CNN_MNIST
from utils import test

# Seed both RNGs used below: torch (model/data) and NumPy (the ImageNet-C
# corruptions and the random severity draw).
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using {device} device")

In [None]:
# Registry of ImageNet-C corruption functions, keyed by function name; the keys are
# what gets passed to `corrupt(..., corruption_name=...)` below.
corruption_tuple = (gaussian_noise, shot_noise, impulse_noise, defocus_blur,
                    glass_blur, motion_blur,
                    zoom_blur, snow, frost, fog,
                    brightness, contrast, elastic_transform, pixelate, jpeg_compression,
                    speckle_noise, gaussian_blur, spatter, saturate)

corruption_dict = {corr_func.__name__: corr_func for corr_func in corruption_tuple}

# Load the trained CNN classifier in eval mode.
# (Alternative: MLP() with 'models/MLP_MNIST_weights_20211124_1035.pth'.)
model = CNN_MNIST().to(device)
# map_location lets a checkpoint saved on GPU load on the CPU fallback selected above.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
model.load_state_dict(torch.load('models/CNN_MNIST_weights_20220411_0826.pth', map_location=device))
model.eval()

# Evaluation batch size
BATCH_SIZE = 64


# Test classifier accuracy for different corruptions and severities

In [None]:
def _evaluate_accuracy(transform):
    """Evaluate `model` on the MNIST test split preprocessed with `transform`.

    Args:
        transform: torchvision transform applied to every test image as loaded by
            ``datasets.MNIST``; it must end in a tensor conversion.

    Returns:
        The accuracy value returned by ``utils.test`` (its scale is defined there).
    """
    dataset = datasets.MNIST(root='data', train=False, download=True, transform=transform)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE)
    _, acc = test(loader, model, torch.nn.CrossEntropyLoss(), device)
    return acc


SEVERITIES = range(1, 6)  # ImageNet-C severities 1..5 inclusive
df_results_all = pd.DataFrame(
    index=['uncorrupted'] + list(corruption_dict.keys()),
    columns=[f'severity={s}' for s in SEVERITIES],
)

# Uncorrupted baseline: a single accuracy copied into every severity column so the
# row lines up with the corrupted rows.
df_results_all.loc['uncorrupted', :] = _evaluate_accuracy(Compose([ToTensor()]))

# Corrupted: best-effort sweep. Some corruption/severity combinations may fail on
# single-channel 28x28 digits; those runs are recorded rather than aborting the sweep.
failed_runs = {}
for corruption in tqdm(corruption_dict.keys()):
    for severity in SEVERITIES:
        # Bind the loop variables as lambda defaults so the transform always uses the
        # values of this iteration, no matter when it is invoked.
        corruption_transform = Lambda(
            lambda x, sev=severity, name=corruption: corrupt(np.uint8(x), severity=sev, corruption_name=name)
        )
        try:
            accuracy = _evaluate_accuracy(Compose([corruption_transform, ToTensor()]))
        except Exception as err:  # not a bare `except:` — KeyboardInterrupt must still stop the loop
            failed_runs[(corruption, severity)] = f'{type(err).__name__}: {err}'
            continue
        df_results_all.loc[corruption, f'severity={severity}'] = accuracy

if failed_runs:
    print(f'{len(failed_runs)} corruption/severity runs failed; '
          f'affected corruptions: {sorted({name for name, _ in failed_runs})}')

# Keep only corruptions that produced a result at every severity.
df_results = df_results_all.dropna()
df_results

# Visualize corruptions

In [None]:
# One panel per surviving row of the results table: the raw first test digit for
# 'uncorrupted', otherwise that digit corrupted at the strongest severity (5).
valid_corruptions = df_results.index.tolist()
n_panels = len(valid_corruptions)

test_data = datasets.MNIST(root='data', train=False)
img = np.uint8(test_data[0][0])  # first test image as a uint8 array

plt.figure(figsize=(30, 10))
for panel, corruption in enumerate(valid_corruptions, start=1):
    img_c = img if corruption == 'uncorrupted' else corrupt(img, severity=5, corruption_name=corruption)
    plt.subplot(1, n_panels, panel)
    plt.imshow(img_c, cmap='gray', vmin=0, vmax=255)
    plt.axis('off')
    plt.title(corruption)

In [None]:
import torchvision  # NOTE(review): unused in this cell; move to the imports cell or drop if nothing else needs it

# Sanity check: run the model on the first raw MNIST test digit.
test_data = datasets.MNIST(
                root='data',
                train=False,
            )

img, label = test_data[0]
img = np.uint8(img)

# Swap in a corrupted variant to inspect a single corruption, e.g.:
# img_c = corrupt(img, severity=5, corruption_name='gaussian_noise')
img_c = img

# ToTensor maps uint8 [0, 255] -> float [0, 1] and adds the channel axis, matching the
# preprocessing used in the accuracy sweep. The previous torch.from_numpy(...).float()
# fed unscaled [0, 255] values to the model.
img_c = ToTensor()(img_c).unsqueeze(0).to(device)  # (1, 1, H, W)

with torch.no_grad():
    logits = model(img_c)
print(f"label={label}, predicted={logits.argmax(dim=1).item()}")
logits

# Generate dataset including corruption with random severity

In [None]:
# Corruption applied to the training set; must be one of the keys of `corruption_dict`.
TRAIN_CORRUPTION = 'contrast'
assert TRAIN_CORRUPTION in corruption_dict, f'unknown corruption: {TRAIN_CORRUPTION}'

# Severity is drawn uniformly from 1..5 each time a sample is fetched, so the same
# index can come back with a different severity on every access.
# NOTE(review): with DataLoader(num_workers>0) workers may share NumPy RNG state and
# draw identical severities — use a worker_init_fn that reseeds if that matters.
corruption_transform = Lambda(
    lambda x, name=TRAIN_CORRUPTION: corrupt(np.uint8(x), severity=np.random.randint(1, 6), corruption_name=name)
)

train_data_corrupted = datasets.MNIST(
    root='data',
    train=True,
    download=True,
    transform=Compose([
        corruption_transform,
    ])
)

# Preview the first training sample (its severity was drawn at random).
img = np.uint8(train_data_corrupted[0][0])

plt.figure()
plt.imshow(img, cmap='gray', vmin=0, vmax=255)
plt.axis('off')
plt.title(f'{TRAIN_CORRUPTION} (random severity)')
plt.show()