In [None]:
import numpy as np
import torch
import torchvision.transforms as transforms
import torchvision
import os
import yaml
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset, Subset
from torchvision.datasets import ImageFolder
from utils import *
import matplotlib.pyplot as plt

from pytorch_lightning import Trainer
from baselines import Baseline
from linear_probe import Linear

import warnings
warnings.filterwarnings("ignore")


## Distortions

Below, one sample image is shown under each distortion used to train our models: Gaussian blur, random masking, and additive Gaussian noise.

In [None]:
# Seed the global RNGs so the sampled image (shuffle=True) and the random
# masks / noise drawn in the next cell are identical on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Distortion instances shown in the figure below (RandomMask / GaussianNoise
# come from the project's utils module; fixed=False is passed straight through).
r50 = RandomMask(percent_missing=0.5, fixed=False)
r75 = RandomMask(percent_missing=0.75, fixed=False)
r90 = RandomMask(percent_missing=0.90, fixed=False)

blur21 = transforms.GaussianBlur(kernel_size=21, sigma=5)
blur37 = transforms.GaussianBlur(kernel_size=37, sigma=9)

noise01 = GaussianNoise(std=0.1, fixed=False)
noise03 = GaussianNoise(std=0.3, fixed=False)
noise05 = GaussianNoise(std=0.5, fixed=False)

# Dataset root: set the IMAGENET100_ROOT environment variable instead of editing
# this machine-specific absolute path.
DATA_ROOT = os.environ.get('IMAGENET100_ROOT', '/home/sriram/Projects/Datasets/ImageNet100')

# Plain preprocessing (resize, 224x224 center crop, to tensor) — no distortion.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor()
])

# NOTE: despite its name, this draws from the *train* split; it is only used to
# pick one example image for display.
val_dataset = ImageNet100(
    root=DATA_ROOT,
    split='train',
    transform=transform
)

# A single image is needed, so load it in the main process: starting 24 worker
# processes (each prefetching batches) and pinning memory just to read one
# sample for plotting is wasted work.
val_dataloader = DataLoader(val_dataset, batch_size=1, num_workers=0, shuffle=True)

# First element of the batch, then the first (only) image in it.
# Presumably the batch is an (images, labels) pair — confirm in utils.ImageNet100.
pic = next(iter(val_dataloader))[0][0]

fig, ax = plt.subplots()
ax.axis('off')
ax.set_title("Original", fontsize=20)
ax.imshow(pic.permute(1, 2, 0))  # CHW tensor -> HWC for matplotlib
plt.show()

In [None]:
# Apply every distortion defined in the previous cell to the same image `pic`.
rand50 = r50(pic)
rand75 = r75(pic)
rand90 = r90(pic)

b21 = blur21(pic)
b37 = blur37(pic)

n01 = noise01(pic)
n03 = noise03(pic)
n05 = noise05(pic)

# Panels in row-major order: original + blurs, random masks, Gaussian noise.
imgs = [pic, b21, b37, rand50, rand75, rand90, n01, n03, n05]
labels = ["ORIGINAL", "BLUR N=21", "BLUR N=37", "RANDOM MASK 50%",
          "RANDOM MASK 75%", "RANDOM MASK 90%", "GAUSSIAN NOISE σ=0.1",
          "GAUSSIAN NOISE σ=0.3", "GAUSSIAN NOISE σ=0.5"]

fig, axs = plt.subplots(3, 3, figsize=(12, 13))
fig.subplots_adjust(wspace=0.05, hspace=0.15)
for img, label, ax in zip(imgs, labels, axs.flat):
    # Additive noise can push pixel values outside [0, 1]; clamp explicitly
    # instead of relying on imshow's (warning-emitting) clipping.
    ax.imshow(img.clamp(0, 1).permute(1, 2, 0))
    ax.set_title(label, fontsize=16)
    ax.axis('off')
plt.show()

## Evaluation of Baseline and LinearProbe

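Below is a small evaluation script that tests the Baseline and LinearProbe models on images with 90% of their content randomly masked. Before running it, download the checkpoints as described in the README and place them in the notebook's working directory (they are loaded by relative path). Make sure a CUDA-capable GPU is available, since the trainer runs on one GPU.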

In [None]:
from pytorch_lightning import Trainer
from baselines import Baseline
from linear_probe import LinearProbe

# Checkpoints are loaded by relative path, i.e. from the notebook's working directory.
baseline_file = 'rand90_baseline.ckpt'
linear_file = 'rand90_linear.ckpt'

# Dataset root: set the IMAGENET100_ROOT environment variable instead of editing
# this machine-specific absolute path.
DATA_ROOT = os.environ.get('IMAGENET100_ROOT', '/home/sriram/Projects/Datasets/ImageNet100')

# NOTE(review): evaluation currently runs on the *train* split. If ImageNet100
# provides a held-out split (e.g. 'val'), switch EVAL_SPLIT to it — TODO confirm
# the split names accepted by utils.ImageNet100.
EVAL_SPLIT = 'train'

# Same preprocessing as the display cell, followed by the test-time distortion:
# 90% random masking (matching the "rand90" checkpoint names).
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    RandomMask(percent_missing=0.90, fixed=False)
])

val_dataset = ImageNet100(
    root=DATA_ROOT,
    split=EVAL_SPLIT,
    transform=transform
)

test_dataloader = DataLoader(val_dataset, batch_size=64, num_workers=24,
                             pin_memory=True, shuffle=False)

# NOTE(review): `gpus=` here and `test_dataloaders=` in the test cells are older
# pytorch_lightning APIs that later releases deprecated/removed — confirm the
# pinned PL version before upgrading.
trainer = Trainer(gpus=1)

In [None]:
baseline_model = Baseline.load_from_checkpoint(baseline_file)
# The test transform applies RandomMask(fixed=False), so the masks differ on every
# pass. Reseed *after* building the model (weight init consumes torch's RNG) so the
# DataLoader derives the same worker seeds each run and the score is reproducible.
# Presumably RandomMask draws from an RNG the workers seed from that base seed —
# confirm in utils.
torch.manual_seed(0)
trainer.test(baseline_model, test_dataloaders=test_dataloader)

In [None]:
# The linear-probe checkpoint must be restored through LinearProbe (imported in the
# evaluation setup cell): load_from_checkpoint rebuilds the class it is called
# on, so going through Baseline would construct the wrong model.
linear_model = LinearProbe.load_from_checkpoint(linear_file)
# Reseed after building the model (weight init consumes torch's RNG) so the
# random test-time masks, and hence the score, are reproducible across runs.
torch.manual_seed(0)
trainer.test(linear_model, test_dataloaders=test_dataloader)