# Import Libraries

In [None]:
import torch
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score, roc_curve, auc
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
import os
from sklearn.metrics import roc_auc_score, roc_curve
import torch.nn.functional as F
import torch.nn as nn
from torchvision.datasets import ImageFolder
from tqdm import tqdm
from PIL import Image

# Common Functions

In [None]:
class ImageDataset(Dataset):
    """Dataset over the image files located directly inside one directory.

    Sub-directories are not searched. Every sample is returned with the
    placeholder label ``0``.
    """

    # File extensions treated as images; matched case-insensitively.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, directory, transform=None):
        """Index the image files found in ``directory``.

        Parameters:
        - directory: Path of the folder to list.
        - transform: Optional callable applied to each loaded PIL image.
        """
        self.directory = directory
        # os.listdir returns entries in arbitrary order; sorting makes the
        # sample order reproducible across runs and machines.
        # Lower-casing the name keeps files such as "x.JPG" or "y.JPEG".
        self.image_paths = [
            os.path.join(directory, fname)
            for fname in sorted(os.listdir(directory))
            if fname.lower().endswith(self.IMAGE_EXTENSIONS)
        ]
        self.transform = transform

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return ``(image, 0)``."""
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, 0

class RecursiveDataset(Dataset):
    """Dataset over every image file found anywhere below a root directory.

    Every sample is returned with the placeholder label ``0``.
    """

    # File extensions treated as images; matched case-insensitively.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, transform=None):
        """Walk ``root_dir`` recursively and index its image files.

        Parameters:
        - root_dir: Top-level folder to walk.
        - transform: Optional callable applied to each loaded PIL image.
        """
        self.image_paths = []
        for root, _, files in os.walk(root_dir):
            for fname in files:
                # Lower-casing the name keeps files such as "x.JPG" or "y.JPEG".
                if fname.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.image_paths.append(os.path.join(root, fname))
        # os.walk visits entries in arbitrary order; sorting makes the sample
        # order reproducible across runs and machines.
        self.image_paths.sort()
        self.transform = transform

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return ``(image, 0)``."""
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, 0

In [None]:
def _logsumexp_scores(model, dataloader, device, desc):
    """Return log-sum-exp (over dim 1) of the model outputs for every sample, as a 1-D NumPy array."""
    per_batch = []
    with torch.no_grad():
        for inputs, _ in tqdm(dataloader, desc=desc, leave=False):
            outputs = model(inputs.to(device))
            per_batch.append(torch.logsumexp(outputs, dim=1).cpu())
    if not per_batch:
        # Fail early with a clear message instead of an opaque metric error.
        raise ValueError(f"{desc}: dataloader yielded no batches")
    return torch.cat(per_batch).numpy()


def evaluate_model(model, id_dataloader, ood_dataloader, device):
    """
    Evaluates the model's ID-vs-OOD separation on two datasets.

    Each sample is scored with the log-sum-exp of the model outputs along
    dim 1 (i.e. the negated energy score), so higher scores are treated as
    more in-distribution. ID samples are the positive class (label 1) and
    OOD samples the negative class (label 0).

    Parameters:
    - model: Trained model; called on each input batch, put into eval mode here
    - id_dataloader: DataLoader for ID data, yielding (inputs, label) pairs
    - ood_dataloader: DataLoader for OOD data, yielding (inputs, label) pairs
    - device: Device the input batches are moved to before the forward pass

    Returns:
    - metrics: Dictionary with "AUROC" and "FPR95" (the false-positive rate
      at the first ROC point whose true-positive rate is >= 0.95)

    Raises:
    - ValueError: If either dataloader yields no batches
    """
    model.eval()

    print("Evaluating In-Distribution (ID) Dataset...")
    id_scores = _logsumexp_scores(model, id_dataloader, device, "ID Progress")

    print("Evaluating Out-of-Distribution (OOD) Dataset...")
    ood_scores = _logsumexp_scores(model, ood_dataloader, device, "OOD Progress")

    # Concatenate scores and true labels (ID = 1, OOD = 0)
    y_true = np.concatenate([np.ones(len(id_scores)), np.zeros(len(ood_scores))])
    y_scores = np.concatenate([id_scores, ood_scores])

    auroc = roc_auc_score(y_true, y_scores)

    # FPR95: FPR at the first operating point that reaches TPR >= 95%.
    fpr, tpr, _ = roc_curve(y_true, y_scores)
    reached = np.flatnonzero(tpr >= 0.95)
    # Fall back to the worst case if no operating point reaches 95% TPR.
    fpr95 = fpr[reached[0]] if reached.size else 1.0

    return {
        "AUROC": auroc,
        "FPR95": fpr95,
    }

# Create Dataloaders

In [None]:
# Preprocessing shared by every dataset in this notebook:
# resize to 224x224, convert to a tensor, then normalise per channel.
_norm_mean = (0.485, 0.456, 0.406)  # ImageNet mean
_norm_std = (0.229, 0.224, 0.225)
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
]
transform_cifar = transforms.Compose(_preprocessing_steps)

In [None]:
batch_size = 64

# Use at most 16 loader workers, but never more than the CPUs on this host:
# asking for more workers than cores makes DataLoader warn and can slow
# loading down (e.g. on small hosted-notebook machines).
num_workers = min(16, os.cpu_count() or 1)

# Settings shared by all evaluation loaders; shuffle stays off so the sample
# order is fixed.
loader_kwargs = dict(batch_size=batch_size, shuffle=False, num_workers=num_workers)

# Out-of-distribution images: a flat folder of image files.
ood_dataset_path = '/path/to/dataset'
ood_dataset = ImageDataset(ood_dataset_path, transform=transform_cifar)
ood_dataloader = DataLoader(ood_dataset, **loader_kwargs)

# CIFAR-10 test split (downloaded to the given root if missing).
cifar10_dataset_path = '/path/to/dataset'
cifar10_test = datasets.CIFAR10(root=cifar10_dataset_path, train=False, download=True, transform=transform_cifar)
cifar10_test_dataloader = DataLoader(cifar10_test, **loader_kwargs)

# CIFAR-100 test split (downloaded to the given root if missing).
cifar100_dataset_path = '/path/to/dataset'
cifar100_test = datasets.CIFAR100(root=cifar100_dataset_path, train=False, download=True, transform=transform_cifar)
cifar100_test_dataloader = DataLoader(cifar100_test, **loader_kwargs)



Files already downloaded and verified


# Create the Model and Load Weights

In [None]:
class ResNet18(nn.Module):
    """ResNet-18 backbone followed by a freshly created linear classifier.

    The backbone is every child module of torchvision's ``resnet18`` except
    its last one; ``fc`` maps the flattened backbone output to
    ``num_classes`` logits.
    """

    def __init__(self, num_classes=10):
        """Build a randomly initialised network with ``num_classes`` outputs."""
        super().__init__()
        # No arguments: random initialisation, without passing the
        # deprecated ``pretrained`` flag.
        base_model = models.resnet18()
        # Drop the backbone's final child (its own classifier layer).
        self.features = nn.Sequential(*list(base_model.children())[:-1])
        # Take the feature width from the backbone's classifier (512 for
        # ResNet-18) instead of hard-coding it.
        self.fc = nn.Linear(base_model.fc.in_features, num_classes)

    def forward(self, x):
        """Return the classifier logits for the input batch ``x``."""
        return self.fc(self.get_features(x))

    def get_features(self, x):
        """Return the backbone output flattened to ``(batch, -1)``."""
        features = self.features(x)
        return torch.flatten(features, 1)

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNet18(num_classes=10) # adjust for number of classes

# map_location lets a checkpoint saved on a GPU load on a CPU-only host.
# weights_only=True restricts unpickling to tensors and plain containers,
# which avoids executing arbitrary pickled code and silences the torch.load
# FutureWarning.
state_dict = torch.load("/path/to/model/weights", map_location=device, weights_only=True)

# create a new state dict if trained with multiple GPUs (optional):
# strip only a LEADING 'module.' from each key. str.replace would also
# remove later occurrences inside the key.
prefix = 'module.'
new_state_dict = {
    (k[len(prefix):] if k.startswith(prefix) else k): v
    for k, v in state_dict.items()
}

model.load_state_dict(new_state_dict)
model.to(device)

  state_dict = torch.load("/content/drive/MyDrive/Greg Greg+ Models/resnet_greg+_cifar10.pt")


ResNet18(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=Tru

# Evaluation Statistics

In [None]:
# Score the CIFAR-10 test split (ID) against the OOD images.
metrics = evaluate_model(
    model=model,
    id_dataloader=cifar10_test_dataloader,
    ood_dataloader=ood_dataloader,
    device=device,
)

Evaluating In-Distribution (ID) Dataset...




Evaluating Out-of-Distribution (OOD) Dataset...




In [None]:
# Report both metrics rounded to four decimal places.
summary = f"AUROC: {metrics['AUROC']:.4f}, FPR95: {metrics['FPR95']:.4f}"
print(summary)

AUROC: 0.3375, FPR95: 0.9919
