Mount your Google Drive to use datasets (For colab users only)

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Import Libraries

In [14]:
import torch
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score, roc_curve, auc
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
import os
from sklearn.metrics import roc_auc_score, roc_curve
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from tqdm import tqdm
from PIL import Image

# Model and Test Datasets

In [3]:
class iSUNDataset(Dataset):
    def __init__(self, directory, transform=None):
        self.directory = directory
        self.image_paths = [os.path.join(directory, fname) for fname in os.listdir(directory) if fname.endswith(('.png', '.jpg', '.jpeg'))]
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')  # Ensure images are in RGB mode
        if self.transform:
            image = self.transform(image)
        return image, 0  # Dummy label for OOD dataset


In [7]:
transform_cifar = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406),  # ImageNet mean
                         (0.229, 0.224, 0.225))
])

In [8]:
batch_size = 64

isun_dataset_path = '/path/to/iSUN'
isun_dataset = iSUNDataset(isun_dataset_path, transform=transform_cifar)
isun_dataloader = DataLoader(isun_dataset, batch_size=batch_size, shuffle=False)

cifar10_test = datasets.CIFAR10(root='/path/to/cifar10', train=False, download=True, transform=transform_cifar)
cifar10_test_dataloader = DataLoader(cifar10_test, batch_size=batch_size, shuffle=False)

Load the pretrained model

In [10]:
device = torch.device("cuda")
model = models.resnet18(pretrained=True)
model.fc = torch.nn.Linear(model.fc.in_features, 10) # number of classes if Cifar-10
model.load_state_dict(torch.load("/path/to/resnet18.pt"))
model.to(device)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 172MB/s]
  model.load_state_dict(torch.load("/content/drive/MyDrive/resnet18_cifar10.pt"))


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

# Evaluation

In [12]:
def evaluate_model(model, id_dataloader, ood_dataloader, device):
    """
    Evaluates the model's performance on ID and OOD datasets.

    Parameters:
    - model: Trained ResNet-18 model
    - id_dataloader: DataLoader for in-distribution data (Cifar-10)
    - ood_dataloader: DataLoader for out-of-distribution data (iSUN)
    - device: Device to perform computation on

    Returns:
    - metrics: Dictionary with AUROC and FPR95 metrics
    """
    model.eval()
    id_scores = []
    ood_scores = []

    # Compute ID scores with progress tracking
    print("Evaluating In-Distribution (ID) Dataset...")
    for id_inputs, _ in tqdm(id_dataloader, desc="ID Progress", leave=False):
        id_inputs = id_inputs.to(device)
        with torch.no_grad():
            id_outputs = model(id_inputs)
            id_outputs = id_outputs
            id_energy_scores = -torch.logsumexp(id_outputs, dim=1)
            id_scores.extend(id_energy_scores.cpu().numpy())

    # Compute OOD scores with progress tracking
    print("Evaluating Out-of-Distribution (OOD) Dataset...")
    for ood_inputs, _ in tqdm(ood_dataloader, desc="OOD Progress", leave=False):
        ood_inputs = ood_inputs.to(device)
        with torch.no_grad():
            ood_outputs = model(ood_inputs)
            ood_outputs = ood_outputs
            ood_energy_scores = -torch.logsumexp(ood_outputs, dim=1)
            ood_scores.extend(ood_energy_scores.cpu().numpy())

    # Invert energy scores so that higher scores correspond to ID
    id_scores = -np.array(id_scores)
    ood_scores = -np.array(ood_scores)

    # Concatenate scores and true labels
    y_true = np.concatenate([np.ones(len(id_scores)), np.zeros(len(ood_scores))])
    y_scores = np.concatenate([id_scores, ood_scores])

    # Calculate AUROC
    auroc = roc_auc_score(y_true, y_scores)

    # Calculate FPR95
    fpr, tpr, thresholds = roc_curve(y_true, y_scores)
    # Find the FPR where TPR >= 95%
    try:
        idx = np.where(tpr >= 0.95)[0][0]
        fpr95 = fpr[idx]
    except IndexError:
        fpr95 = 1.0  # If TPR never reaches 95%

    metrics = {
        "AUROC": auroc,
        "FPR95": fpr95
    }

    return metrics

In [15]:
metrics = evaluate_model(model, cifar10_test_dataloader, isun_dataloader, device)

Evaluating In-Distribution (ID) Dataset...




Evaluating Out-of-Distribution (OOD) Dataset...




In [16]:
print(f"AUROC: {metrics['AUROC']:.4f}, FPR95: {metrics['FPR95']:.4f}")

AUROC: 0.6600, FPR95: 0.9472
