# OOD Detection

The idea of this set of experiments is to evaluate a `ResNet50` backbone pretrained on ImageNet with CF data on OOD benchmarks such as `ImageNet Sketch`. If findings are worthy, we can extend it to other datasets.

In [1]:
%load_ext autoreload
%autoreload 2

In [21]:
import re
from os.path import join, basename, dirname
from glob import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from natsort import natsorted
from tqdm import tqdm
from PIL import Image
import torch
import torchvision
from torchvision.io import read_image
from torchvision.utils import make_grid
from torchvision import transforms

from gradcam.utils import visualize_cam
from gradcam import GradCAM, GradCAMpp

In [3]:
from experiment_utils import set_env, REPO_PATH, seed_everything
set_env()

In [31]:
from experiments.image_utils import denormalize, show_single_image
from experiments.imagenet_utils import IMModel, AverageEnsembleModel
from experiments.ood_utils import validate as ood_validate

In [5]:
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Computer Modern Roman"],
})

### Set environment

In [6]:
seed_everything(0)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### Load model

In [7]:
from imagenet.models.classifier_ensemble import InvariantEnsemble

In [8]:
model = InvariantEnsemble("resnet50", pretrained=True)

In [9]:
# load weights from a checkpoint
ckpt_path = "imagenet/experiments/classifier_2022_01_19_15_36_sample_run/model_best.pth"
ckpt = torch.load(ckpt_path, map_location="cpu")
ckpt_state_dict = ckpt["state_dict"]
ckpt_state_dict = {k.replace("module.", ""):v for k, v in ckpt_state_dict.items()}

In [10]:
model.load_state_dict(ckpt_state_dict)

<All keys matched successfully>

In [11]:
model = model.eval()

In [12]:
model = model.to(device)

In [27]:
shape_model = IMModel(base_model=model, mode="shape")
avg_model = AverageEnsembleModel(base_model=model)
pytorch_model = torchvision.models.resnet50(pretrained=True)

In [29]:
# check models
x = torch.randn((1, 3, 224, 224))

y = shape_model(x)
assert y.shape == torch.Size([1, 1000])

y = avg_model(x)
assert y.shape == torch.Size([1, 1000])

y = pytorch_model(x)
assert y.shape == torch.Size([1, 1000])

### Download data

Use `download_imagenet_sketch.sh` to download the data in folder `cgn_framework/imagenet/data/sketch/`

### Define the validation loader

In [18]:
valdir = join(join(REPO_PATH, "cgn_framework", "imagenet", "data/sketch"), 'val')

In [16]:
combined_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [22]:
val_loader = torch.utils.data.DataLoader(
    torchvision.datasets.ImageFolder(valdir, combined_transform),
    batch_size=36,
    shuffle=False,
    drop_last=False,
    num_workers=4,
    pin_memory=True,
)

In [23]:
len(val_loader)

1414

### Benchmark performance

In [33]:
acc1 = ood_validate(val_loader, pytorch_model)


Test: [   0/1414]	Time 14.957 (14.957)	Acc@1   5.56 (  5.56)
Test: [  10/1414]	Time  9.909 (10.761)	Acc@1   5.56 ( 13.64)
Test: [  20/1414]	Time 15.492 (11.546)	Acc@1   0.00 ( 15.48)
Test: [  30/1414]	Time 11.206 (11.750)	Acc@1  11.11 ( 18.28)
Test: [  40/1414]	Time  9.406 (11.604)	Acc@1  11.11 ( 16.73)
Test: [  50/1414]	Time 10.695 (11.611)	Acc@1  19.44 ( 17.16)



KeyboardInterrupt

