In [1]:
# necessary imports
import torch
import os
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc
import torch.optim as optim
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader
from openood.evaluation_api import Evaluator
from openood.networks import ResNet18_32x32

In [2]:
import faiss
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

In [3]:
imagesize = 32

transform_test = transforms.Compose([
    transforms.Resize((imagesize, imagesize)),
    transforms.CenterCrop(imagesize),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010)),
    # transforms.Normalize([x/255.0 for x in [125.3, 123.0, 113.9]],
    #                     [x/255.0 for x in [63.0, 62.1, 66.7]]),
])

transform_train = transforms.Compose([
    # transforms.RandomCrop(imagesize, padding=4),
    transforms.RandomResizedCrop(size=imagesize, scale=(0.2, 1.)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010)),
    # transforms.Normalize([x / 255.0 for x in [125.3, 123.0, 113.9]],
    #                      [x / 255.0 for x in [63.0, 62.1, 66.7]]),
])

In [4]:
train_dir4 = './data/images_classic/cifar4/cifar4/train'
test_dir4 = './data/images_classic/cifar4/cifar4/test'
train_dir6 = './data/images_classic/cifar6/cifar6/train'
test_dir6 = './data/images_classic/cifar6/cifar6/test'
train_dir10 = './data/images_classic/cifar10/cifar10/train'
test_dir10 = './data/images_classic/cifar10/cifar10/test'

In [5]:
train_dir10 = './data/images_classic/cifar10/cifar10/train'
test_dir10 = './data/images_classic/cifar10/cifar10/test'
train_dir100 = './data/images_classic/cifar100/cifar100/train'
test_dir100 = './data/images_classic/cifar100/cifar100/test'

In [14]:
net = ResNet18_32x32(num_classes=10)
net.load_state_dict(torch.load('resnet_cifar4.pth'))
net.cuda()
net.eval()

  net.load_state_dict(torch.load('resnet_cifar4.pth'))


ResNet18_32x32(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1,

In [7]:
transform_cifar10 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),
])
transform_cifar100 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761]),
])

In [8]:
def evaluate(model,dataloaders):
    # Set the model to evaluation mode
    model.eval()

    # Initialize variables to track accuracy per class
    correct_preds = {classname: 0 for classname in class_names}
    total_preds = {classname: 0 for classname in class_names}

    # Evaluation loop
    with torch.no_grad():
        for inputs, labels in dataloaders['test']:
            inputs = inputs.to('cuda')
            labels = labels.to('cuda')

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            # Track accuracy for each class
            for label, pred in zip(labels, preds):
                if pred == label:
                    correct_preds[class_names[label]] += 1
                total_preds[class_names[label]] += 1

    # Calculate and print accuracy for each class
    for classname, correct_count in correct_preds.items():
        accuracy = 100 * float(correct_count) / total_preds[classname]
        print(f'Accuracy for class {classname}: {accuracy:.2f}%')

In [9]:
train_dir4 = './data/images_classic/cifar4/cifar4/train'
test_dir4 = './data/images_classic/cifar4/cifar4/test'
train_dir6b = './data/images_classic/cifar6b/cifar6b/train'
test_dir6b = './data/images_classic/cifar6b/cifar6b/test'
train_dir10 = './data/images_classic/cifar10/cifar10/train'
test_dir10 = './data/images_classic/cifar10/cifar10/test'
train_dir100 = './data/images_classic/cifar100/cifar100/train'
test_dir100 = './data/images_classic/cifar100/cifar100/test'

In [11]:
# Load train and test datasets
image_datasets4 = {
    'train': datasets.ImageFolder(train_dir4, transform=transform_cifar10),
    'test': datasets.ImageFolder(test_dir4, transform=transform_cifar10)
}

# Create DataLoaders
dataloaders4 = {
    'train': DataLoader(image_datasets4['train'], batch_size=64, shuffle=True, num_workers=4),
    'test': DataLoader(image_datasets4['test'], batch_size=64, shuffle=False, num_workers=4)
}
# Load train and test datasets
image_datasets6b = {
    'train': datasets.ImageFolder(train_dir6b, transform=transform_cifar10),
    'test': datasets.ImageFolder(test_dir6b, transform=transform_cifar10)
}

# Create DataLoaders
dataloaders6b = {
    'train': DataLoader(image_datasets6b['train'], batch_size=64, shuffle=True, num_workers=4),
    'test': DataLoader(image_datasets6b['test'], batch_size=64, shuffle=False, num_workers=4)
}
# Load train and test datasets
image_datasets10 = {
    'train': datasets.ImageFolder(train_dir10, transform=transform_cifar10),
    'test': datasets.ImageFolder(test_dir10, transform=transform_cifar10)
}

# Create DataLoaders
dataloaders10 = {
    'train': DataLoader(image_datasets10['train'], batch_size=64, shuffle=True, num_workers=4),
    'test': DataLoader(image_datasets10['test'], batch_size=64, shuffle=False, num_workers=4)
}

In [12]:
# Class names in CIFAR-10
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

In [15]:
a=evaluate(net,dataloaders10)

Accuracy for class airplane: 90.40%
Accuracy for class automobile: 99.50%
Accuracy for class bird: 85.30%
Accuracy for class cat: 80.40%
Accuracy for class deer: 0.00%
Accuracy for class dog: 0.00%
Accuracy for class frog: 0.00%
Accuracy for class horse: 0.00%
Accuracy for class ship: 0.00%
Accuracy for class truck: 0.00%


## feat_extract

In [16]:
torch.manual_seed(1)
torch.cuda.manual_seed(1)
np.random.seed(1)
device = 'cuda' 

In [17]:
batch_size = 64

In [18]:
FORCE_RUN = False

In [19]:
# Normalizer function to ensure unit norm
normalizer = lambda x: x / np.linalg.norm(x, axis=-1, keepdims=True) + 1e-10

In [20]:
class KNNPostprocessor():
    def __init__(self, K):
        self.K = K
        self.activation_log = None

    def setup(self, net: nn.Module, id_loader_dict, ood_loader_dict):
            activation_log = []
            net.eval()
            with torch.no_grad():
                for batch in tqdm(id_loader_dict['train'],
                                  desc='Setup: ',
                                  position=0,
                                  leave=True):
                    data = batch[0].cuda()
                    data = data.float()

                    _, feature = net(data, return_feature=True)
                    activation_log.append(
                        normalizer(feature.data.cpu().numpy()))

            self.activation_log = np.concatenate(activation_log, axis=0)
            self.index = faiss.IndexFlatL2(feature.shape[1])
            self.index.add(self.activation_log)
    '''    def postprocess(self, net: nn.Module, data):
        output, feature = net(data, return_feature=True)
        feature_normed = normalizer(feature.data.cpu().numpy())
        D, _ = self.index.search(
            feature_normed,
            self.K,
        )
        kth_dist = -D[:, -1]
        _, pred = torch.max(torch.softmax(output, dim=1), dim=1)
        return pred, torch.from_numpy(kth_dist)'''
    def detect_ood(self, net: nn.Module, ood_loader):
        """
        Perform KNN-based OOD detection by computing distances to the nearest neighbors.
        
        Args:
        - net (nn.Module): The neural network model used for feature extraction.
        - ood_loader (DataLoader): DataLoader for the out-of-distribution (OOD) dataset (test data).

        Returns:
        - scores (torch.Tensor): OOD scores based on K-th nearest neighbor distances.
        """
        ood_scores = []
        net.eval()

        # Extract features for OOD samples and compute distances to nearest neighbors
        for batch in tqdm(ood_loader, desc='Processing OOD data', position=0, leave=True):
            data = batch[0].cuda()  # Get the input data and move to GPU
            _, feature = net(data, return_feature=True)  # Extract features

            # Normalize the features
            feature_normed = normalizer(feature.data.cpu().numpy())

            # Search K nearest neighbors and get distances
            D, _ = self.index.search(feature_normed, self.K)
            
            # Use the K-th distance as the OOD score
            kth_dist = -D[:, -1]  # Negative distance for consistency with OOD scoring
            ood_scores.append(torch.from_numpy(kth_dist))

        # Concatenate all OOD scores
        ood_scores = torch.cat(ood_scores, dim=0)
        return ood_scores
    def evaluate(self, id_loader, ood_loader, net):
        """
        Evaluate OOD detection performance by comparing ID and OOD samples.
        
        Args:
        - id_loader (DataLoader): DataLoader for the in-distribution dataset (test data).
        - ood_loader (DataLoader): DataLoader for the out-of-distribution dataset (test data).
        - net (nn.Module): The neural network model used for feature extraction.

        Returns:
        - fpr_at_95_tpr (float): False positive rate at 95% true positive rate.
        """
        # Extract ID and OOD features
        self.id_scores = self.detect_ood(net, id_loader)
        self.ood_scores = self.detect_ood(net, ood_loader)

        # Calculate FPR at 95% TPR
        fpr_at_95_tpr = self.calculate_fpr_at_95_tpr(self.id_scores, self.ood_scores)
        return fpr_at_95_tpr
    def calculate_fpr_at_95_tpr(self,id_scores, ood_scores):
        """
        Calculate FPR at 95% TPR for OOD detection.

        Args:
        - id_scores (torch.Tensor): Scores for in-distribution samples.
        - ood_scores (torch.Tensor): Scores for out-of-distribution samples.

        Returns:
        - fpr_at_95_tpr (float): False positive rate at 95% true positive rate.
        """
        labels = np.concatenate([np.zeros_like(id_scores), np.ones_like(ood_scores)])
        scores = np.concatenate([id_scores, ood_scores])

        # Sort scores and labels based on the score threshold
        sorted_indices = np.argsort(scores)
        sorted_labels = labels[sorted_indices]

        # Calculate TPR and FPR
        tpr = np.cumsum(sorted_labels) / np.sum(sorted_labels)
        fpr = np.cumsum(1 - sorted_labels) / np.sum(1 - sorted_labels)

        # Find FPR where TPR is closest to 95%
        idx = np.searchsorted(tpr, 0.95)
        fpr_at_95_tpr = fpr[idx]
        return fpr_at_95_tpr

In [22]:
k50 = KNNPostprocessor(50)
# Setup Faiss index with in-distribution training data
k50.setup(net, dataloaders4,dataloaders6b )
# Evaluate OOD detection performance
fpr_at_95_tpr = k50.evaluate(dataloaders4['test'], dataloaders6b['test'], net)
print(f"FPR at 95% TPR: {fpr_at_95_tpr:.4f}")

Setup: 100%|█████████████████████████████████████████████████████████████████████████| 313/313 [00:09<00:00, 34.18it/s]
Processing OOD data: 100%|█████████████████████████████████████████████████████████████| 63/63 [00:36<00:00,  1.72it/s]
Processing OOD data: 100%|█████████████████████████████████████████████████████████████| 94/94 [00:50<00:00,  1.85it/s]

FPR at 95% TPR: 0.7535





In [24]:
k10 = KNNPostprocessor(10)
# Setup Faiss index with in-distribution training data
k10.setup(net, dataloaders4,dataloaders6b )
# Evaluate OOD detection performance
fpr_at_95_tpr = k10.evaluate(dataloaders4['test'], dataloaders6b['test'], net)
print(f"FPR at 95% TPR: {fpr_at_95_tpr:.4f}")

Setup: 100%|█████████████████████████████████████████████████████████████████████████| 313/313 [00:09<00:00, 33.81it/s]
Processing OOD data: 100%|█████████████████████████████████████████████████████████████| 63/63 [00:36<00:00,  1.72it/s]
Processing OOD data: 100%|█████████████████████████████████████████████████████████████| 94/94 [00:51<00:00,  1.81it/s]

FPR at 95% TPR: 0.7538





In [25]:
k100 = KNNPostprocessor(100)
# Setup Faiss index with in-distribution training data
k100.setup(net, dataloaders4,dataloaders6b )
# Evaluate OOD detection performance
fpr_at_95_tpr = k10.evaluate(dataloaders4['test'], dataloaders6b['test'], net)
print(f"FPR at 95% TPR: {fpr_at_95_tpr:.4f}")

Setup: 100%|█████████████████████████████████████████████████████████████████████████| 313/313 [00:09<00:00, 34.29it/s]
Processing OOD data: 100%|█████████████████████████████████████████████████████████████| 63/63 [00:37<00:00,  1.67it/s]
Processing OOD data: 100%|█████████████████████████████████████████████████████████████| 94/94 [00:50<00:00,  1.85it/s]

FPR at 95% TPR: 0.7538





In [28]:
k50 = KNNPostprocessor(50)
k50.setup(net, dataloaders4,dataloaders6b )

Setup: 100%|█████████████████████████████████████████████████████████████████████████| 313/313 [00:09<00:00, 34.54it/s]


In [29]:
classes = ('deer', 'dog', 'frog', 'horse', 'ship', 'truck') # Adjust based on actual classes in dataset6b

# Loop through each class in dataset6b
for i, cls_name in enumerate(classes):
    # Filter test data to include only the current class for OOD evaluation
    class_indices = [idx for idx, (_, label) in enumerate(image_datasets6b['test'].imgs) if label == i]
    class_subset = torch.utils.data.Subset(image_datasets6b['test'], class_indices)
    class_loader = DataLoader(class_subset, batch_size=64, shuffle=False, num_workers=4)
    
    # Evaluate OOD detection performance for the current class
    fpr_at_95_tpr = k50.evaluate(dataloaders4['test'], class_loader, net)
    print(f"FPR at 95% TPR for class '{cls_name}': {fpr_at_95_tpr:.4f}")

Processing OOD data: 100%|█████████████████████████████████████████████████████████████| 63/63 [00:38<00:00,  1.66it/s]
Processing OOD data: 100%|█████████████████████████████████████████████████████████████| 16/16 [00:14<00:00,  1.08it/s]


FPR at 95% TPR for class 'deer': 0.6143


Processing OOD data: 100%|█████████████████████████████████████████████████████████████| 63/63 [00:35<00:00,  1.77it/s]
Processing OOD data: 100%|█████████████████████████████████████████████████████████████| 16/16 [00:14<00:00,  1.08it/s]


FPR at 95% TPR for class 'dog': 0.6127


Processing OOD data: 100%|█████████████████████████████████████████████████████████████| 63/63 [00:36<00:00,  1.73it/s]
Processing OOD data: 100%|█████████████████████████████████████████████████████████████| 16/16 [00:15<00:00,  1.06it/s]


FPR at 95% TPR for class 'frog': 0.6058


Processing OOD data: 100%|█████████████████████████████████████████████████████████████| 63/63 [00:36<00:00,  1.72it/s]
Processing OOD data: 100%|█████████████████████████████████████████████████████████████| 16/16 [00:14<00:00,  1.07it/s]


FPR at 95% TPR for class 'horse': 0.4608


Processing OOD data: 100%|█████████████████████████████████████████████████████████████| 63/63 [00:35<00:00,  1.77it/s]
Processing OOD data: 100%|█████████████████████████████████████████████████████████████| 16/16 [00:15<00:00,  1.04it/s]


FPR at 95% TPR for class 'ship': 0.9575


Processing OOD data: 100%|█████████████████████████████████████████████████████████████| 63/63 [00:36<00:00,  1.72it/s]
Processing OOD data: 100%|█████████████████████████████████████████████████████████████| 16/16 [00:14<00:00,  1.11it/s]

FPR at 95% TPR for class 'truck': 0.8045



