# Reproduce results on Cars & Deer

This notebook reproduces the results obtained on Cars & Deer for:
* Supervised model training
* Unsupervised model training

**Note: performing model training and inference on a GPU will lead to significant speedups.**

## Prerequisites

Before running this notebook, the Cars & Deer dataset images must be stored locally. To get them, extract the archive `./datasets/car_deer_noisy.tar.gz` from the Easy-ICD repository, then set `img_dir` below to the directory where you extracted the dataset. The "View learned representations" cells additionally read images from `./images/cifar_10_alt` (classes `car`, `deer` and `ship`), so that directory must also be available, or the path in those cells adjusted.

## Required Imports

In [293]:
import torch
import torch.nn as nn
import numpy as np
import os
import json
import matplotlib.pyplot as plt

from torch.utils.data import Dataset, DataLoader, RandomSampler
from torchvision import transforms, utils
from torchvision.io import read_image
from sklearn.manifold import TSNE

from easy_icd.utils.datasets import create_dataset
from easy_icd.training.losses import SimCLRLoss
from easy_icd.utils.augmentation import RandomImageAugmenter, augment_minibatch
from easy_icd.utils.models import ResNet, LinearClassifier
from easy_icd.training.train_models import train_model
from easy_icd.outlier_detection.detect_outliers import analyze_data
from easy_icd.outlier_removal.remove_outliers import remove_outliers

from PIL import Image

from typing import Optional, List, Tuple, Dict, Callable

In [220]:
# Housekeeping cell: reclaim memory before starting a new experiment.
import gc

# Collect unreachable Python objects first, then ask PyTorch to hand back the
# CUDA memory that its caching allocator holds but no longer uses.
gc.collect()
torch.cuda.empty_cache()

## Create dataset

In [160]:
# Class names and image folder handed to create_dataset.
class_names = ['car', 'deer']
img_dir = './images/car_deer_noisy'

# Eight entries, each equal to 0.2.
# NOTE(review): presumably one probability per augmentation type; the meaning
# of the (32, 32) and 2 arguments is defined by RandomImageAugmenter - confirm.
probs = torch.full((8,), 0.2)
augmenter = RandomImageAugmenter((32, 32), probs, 2)

# NOTE(review): 0.9 is presumably the train split fraction and False/True are
# positional flags of create_dataset - confirm against its signature.
train_ds, test_ds = create_dataset(img_dir, class_names, False, True, 0.9)

# Both loaders reshuffle on every pass.
train_dataloader = DataLoader(dataset=train_ds, batch_size=512, shuffle=True)
test_dataloader = DataLoader(dataset=test_ds, batch_size=200, shuffle=True)

## Create model

In [162]:
# ResNet configured with four entries for both num_blocks and out_channels;
# the channel widths double from 64 up to 512.
model = ResNet(
    num_blocks=[2] * 4,
    out_channels=[64, 128, 256, 512],
)

## Train unsupervised model (SimCLR)

In [None]:
# Train the model with the 'simclr' loss.
# NOTE(review): './car_deer_noisy/' is presumably the output directory and the
# positional values (100, None, 2, 0.1) are forwarded unchanged - their meaning
# is defined by train_model; confirm against its signature.
train_model(
    model, train_dataloader, test_dataloader, './car_deer_noisy/', augmenter,
    'simclr', 100, None, 2, 0.1,
    lr=2e-1, min_lr=5e-3, num_warmup_epochs=50, losses_name='',
    # Use the GPU when one is present rather than unconditionally requiring
    # CUDA; training on the CPU is possible, just slower.
    gpu=torch.cuda.is_available(),
    epoch_offset=0,
    # Per-channel means and standard deviations used for input normalisation.
    dataset_means_and_stds=[[0.4651, 0.4632, 0.4244], [0.2351, 0.2314, 0.2469]],
)

## View learned representations

In [167]:
from itertools import islice

from torchvision.transforms import Normalize

# Embed a fixed number of car / deer / ship batches with the trained model so
# the embeddings can be projected to 2-D and plotted in the following cells.
train_ds = create_dataset('./images/cifar_10_alt', ['car', 'deer', 'ship'], False, True)
train_dataloader = DataLoader(train_ds, batch_size=500, shuffle=False)

# Fall back to the CPU when CUDA is unavailable instead of failing on .to('cuda').
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model.to(device)
model.use_projection_head(True)
model.eval()

# Same per-channel mean / std values that are passed to train_model.
normalizer = Normalize([0.4651, 0.4632, 0.4244], [0.2351, 0.2314, 0.2469])

feats_list = []
labels_list = []

num_batches = 20

# Inference only: no_grad avoids recording an autograd graph for every batch,
# which would otherwise keep activations alive on the device.
with torch.no_grad():
    # islice stops after num_batches batches without fetching an extra one.
    for images, batch_labels in islice(train_dataloader, num_batches):
        labels_list.append(batch_labels.numpy())

        features = model(normalizer(images).to(device)).cpu()
        # Scale every embedding row to unit L2 norm.
        features = torch.div(features, torch.linalg.norm(features, dim=1, keepdim=True))
        feats_list.append(features.numpy())

# Stack the per-batch arrays; later cells read `feats_list` and `labels`.
feats_list = np.concatenate(feats_list, 0)
labels = np.concatenate(labels_list, 0)

In [168]:
# Project the embeddings to two dimensions for plotting. No random_state is
# passed, so the resulting layout can differ from run to run.
tsne = TSNE(n_components=2, perplexity=50)
small_feats = tsne.fit_transform(feats_list)

In [None]:
# Scatter plot of the 2-D embeddings with one scatter series per class label.
class_names = ['car', 'deer', 'ship']

fig, ax = plt.subplots(1, 1, figsize=(20, 20))

for i, name in enumerate(class_names):
    # Row indices of the samples whose label equals the current class index.
    sel_inds = np.argwhere(labels == i).ravel()
    ax.scatter(small_feats[sel_inds, 0], small_feats[sel_inds, 1], label=name, s=100)

# Large legend, and no axis ticks.
ax.legend(fontsize=36, ncol=1, loc=1)
ax.set_xticks([])
ax.set_yticks([])

## Train supervised model

In [None]:
# Class names and image folder handed to create_dataset for supervised training.
class_names = ['car', 'deer']
img_dir = './images/car_deer_noisy'

# Eight entries, each equal to 0.2.
# NOTE(review): presumably one probability per augmentation type; the meaning
# of the (32, 32) and 2 arguments is defined by RandomImageAugmenter - confirm.
probs = torch.full((8,), 0.2)
augmenter = RandomImageAugmenter((32, 32), probs, 2)

# NOTE(review): 0.9 is presumably the train split fraction and False/True are
# positional flags of create_dataset - confirm against its signature.
train_ds, test_ds = create_dataset(img_dir, class_names, False, True, 0.9)

# Both loaders reshuffle on every pass; the training batch size here is 128.
train_dataloader = DataLoader(dataset=train_ds, batch_size=128, shuffle=True)
test_dataloader = DataLoader(dataset=test_ds, batch_size=200, shuffle=True)

In [None]:
# Supervised ResNet: same block counts and channel widths (64 up to 512) as
# the model built earlier, plus linear layer sizes ending in 2 outputs.
model = ResNet(
    num_layers=4,
    num_blocks=[2] * 4,
    out_channels=[64, 128, 256, 512],
    linear_sizes=[256, 64, 2],
    supervised=True,
)

In [None]:
# Put the model in training mode before handing it to the training routine.
model.train()

# Train the supervised model with the 'ce' loss.
# NOTE(review): './models/supervised/car_deer_noisy_raw/' is presumably the
# output directory (a later cell loads model_state_epoch_150.pt from it) and
# the positional values (150, None, 2, 0.1) are forwarded unchanged - their
# meaning is defined by train_model; confirm against its signature.
train_model(
    model, train_dataloader, test_dataloader, './models/supervised/car_deer_noisy_raw/', augmenter,
    'ce', 150, None, 2, 0.1,
    lr=1e-1, min_lr=5e-3, num_warmup_epochs=30, losses_name='',
    # Use the GPU when one is present rather than unconditionally requiring
    # CUDA; training on the CPU is possible, just slower.
    gpu=torch.cuda.is_available(),
    epoch_offset=0,
    # Per-channel means and standard deviations used for input normalisation.
    dataset_means_and_stds=[[0.4651, 0.4632, 0.4244], [0.2351, 0.2314, 0.2469]],
)

## View learned representations

In [None]:
# Rebuild the supervised architecture so the saved weights can be loaded into it.
model = ResNet(
    num_layers=4,
    num_blocks=[2, 2, 2, 2],
    out_channels=[64 * (2 ** i) for i in range(4)],
    linear_sizes=[256, 64, 2],
    supervised=True,
)

# map_location='cpu' makes the checkpoint loadable on machines without CUDA
# even if it was saved from a GPU run; the model is moved to the compute
# device later. torch.load unpickles the file, so only load trusted checkpoints.
old_dict = torch.load(
    './models/supervised/car_deer_noisy_raw/model_state_epoch_150.pt',
    map_location='cpu',
)

# strict=False tolerates missing / unexpected keys; the returned report of such
# keys is this cell's output and should be checked.
model.load_state_dict(old_dict, strict=False)

In [None]:
from itertools import islice

from torchvision.transforms import Normalize

# Embed a fixed number of car / deer / ship batches with the supervised model
# so the embeddings can be projected to 2-D and plotted in the following cells.
train_ds = create_dataset('./images/cifar_10_alt', ['car', 'deer', 'ship'], False, True)
train_dataloader = DataLoader(train_ds, batch_size=500, shuffle=True)

# Fall back to the CPU when CUDA is unavailable instead of failing on .to('cuda').
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model.to(device)
model.use_projection_head(False)
model.eval()

# Same per-channel mean / std values that are passed to train_model.
normalizer = Normalize([0.4651, 0.4632, 0.4244], [0.2351, 0.2314, 0.2469])

feats_list = []
labels_list = []

num_batches = 10

# Inference only: no_grad avoids recording an autograd graph for every batch,
# which would otherwise keep activations alive on the device.
with torch.no_grad():
    # Take the batches from ONE shuffled pass over the loader. Calling
    # next(iter(train_dataloader)) per iteration would build a fresh shuffled
    # iterator each time, so the same sample could appear in several batches.
    for images, batch_labels in islice(train_dataloader, num_batches):
        labels_list.append(batch_labels.numpy())

        features = model(normalizer(images).to(device)).cpu()
        # Scale every embedding row to unit L2 norm.
        features = torch.div(features, torch.linalg.norm(features, dim=1, keepdim=True))
        feats_list.append(features.numpy())

# Stack the per-batch arrays; later cells read `feats_list` and `labels`.
feats_list = np.concatenate(feats_list, 0)
labels = np.concatenate(labels_list, 0)

In [None]:
from sklearn.manifold import TSNE

# Project the embeddings to two dimensions for plotting. No random_state is
# passed, so the resulting layout can differ from run to run.
tsne = TSNE(n_components=2, perplexity=50)
small_feats = tsne.fit_transform(feats_list)

In [None]:
# Scatter plot of the 2-D embeddings with one scatter series per class label.
class_names = ['car', 'deer', 'ship']

fig, ax = plt.subplots(1, 1, figsize=(20, 20))

for i, name in enumerate(class_names):
    # Row indices of the samples whose label equals the current class index.
    sel_inds = np.argwhere(labels == i).ravel()
    ax.scatter(small_feats[sel_inds, 0], small_feats[sel_inds, 1], label=name, s=100)

# Large legend, and no axis ticks.
ax.legend(fontsize=36, ncol=1, loc=1)
ax.set_xticks([])
ax.set_yticks([])