In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
from sklearn.metrics import roc_auc_score, roc_curve
import numpy as np
import matplotlib.pyplot as plt
from openood.networks import ResNet18_32x32

In [2]:
train_dir10 = './data/images_classic/cifar10/cifar10/train'
test_dir10 = './data/images_classic/cifar10/cifar10/test'

In [3]:
# Data transformations (augmentation and normalization)
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616])
    ]),
    'test': transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616])
    ]),
}

In [4]:
# Define the four classes and the remaining six classes
classes_subset1 = ['airplane', 'automobile', 'ship', 'truck', 'cat', 'dog'] 
classes_subset2 =  ['frog', 'horse', 'bird', 'deer'] # Remaining classes

# Load full CIFAR-10 dataset (assuming data_transforms is already defined)
train_dir10 = './data/images_classic/cifar10/cifar10/train'
test_dir10 = './data/images_classic/cifar10/cifar10/test'
image_datasets10 = {
    'train': datasets.ImageFolder(train_dir10, transform=data_transforms['train']),
    'test': datasets.ImageFolder(test_dir10, transform=data_transforms['test'])
}

# Function to filter dataset by classes
def filter_by_classes(dataset, classes_to_include):
    class_indices = [dataset.class_to_idx[cls] for cls in classes_to_include]
    indices = [i for i, (_, label) in enumerate(dataset.samples) if label in class_indices]
    return Subset(dataset, indices)

# Create subsets for the four classes and the other six classes
subset1 = {'train':filter_by_classes(image_datasets10['train'], classes_subset1),
              'test':filter_by_classes(image_datasets10['test'], classes_subset1)}
subset2 = {'train':filter_by_classes(image_datasets10['train'], classes_subset2),
            'test':filter_by_classes(image_datasets10['test'], classes_subset2)}

# Create DataLoaders for each subset
dataloaders_subset1 = {
    'train': DataLoader(subset1['train'], batch_size=64, shuffle=True, num_workers=4),
    'test': DataLoader(subset1['test'], batch_size=64, shuffle=False, num_workers=4)
}
dataloaders_subset2 = {
    'train': DataLoader(subset2['train'], batch_size=64, shuffle=True, num_workers=4),
    'test': DataLoader(subset2['test'], batch_size=64, shuffle=False, num_workers=4)
}

In [24]:
class ARPL(nn.Module):
    def __init__(self, num_classes, feature_dim=512):
        super(ARPL, self).__init__()
        self.feature_extractor = ResNet18_32x32(num_classes=10)
        self.prototype_layer = nn.Parameter(torch.randn(num_classes, feature_dim))

    def forward(self, x,rf=False):
        x,features = self.feature_extractor(x,return_feature=True)
        if rf:
            return x, features
        else:
            return x
        

    def compute_loss(self, features, labels):
        # Contrastive loss for prototype learning
        prototypes = self.prototype_layer
        distances = torch.cdist(features, prototypes)  # Pairwise distances
        logits = -distances

        ce_loss = nn.CrossEntropyLoss()(logits, labels)
        return ce_loss

In [50]:
# Training Loop
def train_model(model, train_loader):
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    model.train()

    for epoch in range(20):
        total_loss = 0
        for images, labels in train_loader:
            images, labels = images.to("cuda"), labels.to("cuda")

            optimizer.zero_grad()
            _,features = model(images,rf=True)
            loss = model.compute_loss(features, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch [{epoch + 1}/{20}], Loss: {total_loss / len(train_loader):.4f}")

In [26]:
model = ARPL(num_classes=10)
model

ARPL(
  (feature_extractor): ResNet18_32x32(
    (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential()
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=Fal

In [27]:
model.to('cuda')

ARPL(
  (feature_extractor): ResNet18_32x32(
    (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential()
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=Fal

In [51]:
train_model(model, dataloaders_subset1['train'])

Epoch [1/20], Loss: 0.1453
Epoch [2/20], Loss: 0.1333
Epoch [3/20], Loss: 0.1293
Epoch [4/20], Loss: 0.1191
Epoch [5/20], Loss: 0.1106
Epoch [6/20], Loss: 0.1085
Epoch [7/20], Loss: 0.1006
Epoch [8/20], Loss: 0.1012
Epoch [9/20], Loss: 0.0897
Epoch [10/20], Loss: 0.0875
Epoch [11/20], Loss: 0.0821
Epoch [12/20], Loss: 0.0768
Epoch [13/20], Loss: 0.0731
Epoch [14/20], Loss: 0.0678
Epoch [15/20], Loss: 0.0669
Epoch [16/20], Loss: 0.0627
Epoch [17/20], Loss: 0.0617
Epoch [18/20], Loss: 0.0571
Epoch [19/20], Loss: 0.0551
Epoch [20/20], Loss: 0.0516


In [34]:
import os
import sys
import numpy as np

def get_curve_online(known, novel, stypes=['Bas']):
    tp, fp = dict(), dict()
    tnr_at_tpr95 = dict()
    for stype in stypes:
        known.sort()
        novel.sort()
        end = np.max([np.max(known), np.max(novel)])
        start = np.min([np.min(known), np.min(novel)])
        num_k = known.shape[0]
        num_n = novel.shape[0]
        tp[stype] = -np.ones([num_k+num_n+1], dtype=int)
        fp[stype] = -np.ones([num_k+num_n+1], dtype=int)
        tp[stype][0], fp[stype][0] = num_k, num_n
        k, n = 0, 0
        for l in range(num_k+num_n):
            if k == num_k:
                tp[stype][l+1:] = tp[stype][l]
                fp[stype][l+1:] = np.arange(fp[stype][l]-1, -1, -1)
                break
            elif n == num_n:
                tp[stype][l+1:] = np.arange(tp[stype][l]-1, -1, -1)
                fp[stype][l+1:] = fp[stype][l]
                break
            else:
                if novel[n] < known[k]:
                    n += 1
                    tp[stype][l+1] = tp[stype][l]
                    fp[stype][l+1] = fp[stype][l] - 1
                else:
                    k += 1
                    tp[stype][l+1] = tp[stype][l] - 1
                    fp[stype][l+1] = fp[stype][l]
        tpr95_pos = np.abs(tp[stype] / num_k - .95).argmin()
        tnr_at_tpr95[stype] = 1. - fp[stype][tpr95_pos] / num_n
    return tp, fp, tnr_at_tpr95

def metric_ood(x1, x2, stypes=['Bas'], verbose=True):
    tp, fp, tnr_at_tpr95 = get_curve_online(x1, x2, stypes)
    results = dict()
    mtypes = ['TNR', 'AUROC', 'DTACC', 'AUIN', 'AUOUT']
    if verbose:
        print('      ', end='')
        for mtype in mtypes:
            print(' {mtype:6s}'.format(mtype=mtype), end='')
        print('')
        
    for stype in stypes:
        if verbose:
            print('{stype:5s} '.format(stype=stype), end='')
        results[stype] = dict()
        
        # TNR
        mtype = 'TNR'
        results[stype][mtype] = 100.*tnr_at_tpr95[stype]
        if verbose:
            print(' {val:6.3f}'.format(val=results[stype][mtype]), end='')
        
        # AUROC
        mtype = 'AUROC'
        tpr = np.concatenate([[1.], tp[stype]/tp[stype][0], [0.]])
        fpr = np.concatenate([[1.], fp[stype]/fp[stype][0], [0.]])
        results[stype][mtype] = 100.*(-np.trapz(1.-fpr, tpr))
        if verbose:
            print(' {val:6.3f}'.format(val=results[stype][mtype]), end='')
        
        # DTACC
        mtype = 'DTACC'
        results[stype][mtype] = 100.*(.5 * (tp[stype]/tp[stype][0] + 1.-fp[stype]/fp[stype][0]).max())
        if verbose:
            print(' {val:6.3f}'.format(val=results[stype][mtype]), end='')
        
        # AUIN
        mtype = 'AUIN'
        denom = tp[stype]+fp[stype]
        denom[denom == 0.] = -1.
        pin_ind = np.concatenate([[True], denom > 0., [True]])
        pin = np.concatenate([[.5], tp[stype]/denom, [0.]])
        results[stype][mtype] = 100.*(-np.trapz(pin[pin_ind], tpr[pin_ind]))
        if verbose:
            print(' {val:6.3f}'.format(val=results[stype][mtype]), end='')
        
        # AUOUT
        mtype = 'AUOUT'
        denom = tp[stype][0]-tp[stype]+fp[stype][0]-fp[stype]
        denom[denom == 0.] = -1.
        pout_ind = np.concatenate([[True], denom > 0., [True]])
        pout = np.concatenate([[0.], (fp[stype][0]-fp[stype])/denom, [.5]])
        results[stype][mtype] = 100.*(np.trapz(pout[pout_ind], 1.-fpr[pout_ind]))
        if verbose:
            print(' {val:6.3f}'.format(val=results[stype][mtype]), end='')
            print('')
    
    return results

def compute_oscr(pred_k, pred_u, labels):
    x1, x2 = np.max(pred_k, axis=1), np.max(pred_u, axis=1)
    pred = np.argmax(pred_k, axis=1)
    correct = (pred == labels)
    m_x1 = np.zeros(len(x1))
    m_x1[pred == labels] = 1
    k_target = np.concatenate((m_x1, np.zeros(len(x2))), axis=0)
    u_target = np.concatenate((np.zeros(len(x1)), np.ones(len(x2))), axis=0)
    predict = np.concatenate((x1, x2), axis=0)
    n = len(predict)

    # Cutoffs are of prediction values
    
    CCR = [0 for x in range(n+2)]
    FPR = [0 for x in range(n+2)] 

    idx = predict.argsort()

    s_k_target = k_target[idx]
    s_u_target = u_target[idx]

    for k in range(n-1):
        CC = s_k_target[k+1:].sum()
        FP = s_u_target[k:].sum()

        # True Positive Rate
        CCR[k] = float(CC) / float(len(x1))
        # False Positive Rate
        FPR[k] = float(FP) / float(len(x2))

    CCR[n] = 0.0
    FPR[n] = 0.0
    CCR[n+1] = 1.0
    FPR[n+1] = 1.0

    # Positions of ROC curve (FPR, TPR)
    ROC = sorted(zip(FPR, CCR), reverse=True)

    OSCR = 0

    # Compute AUROC Using Trapezoidal Rule
    for j in range(n+1):
        h = ROC[j][0] - ROC[j+1][0]
        w = (ROC[j][1] + ROC[j+1][1]) / 2.0

        OSCR = OSCR + h*w

    return OSCR

In [36]:
class ARPLoss(nn.Module):
    def __init__(self, feat_dim=512, num_classes=10, **kwargs):
        super(ARPLoss, self).__init__()
        self.prototype_layer = nn.Parameter(torch.randn(num_classes, feat_dim))

    def forward(self, features, labels=None):
        prototypes = self.prototype_layer.to(features.device)  
        distances = torch.cdist(features, prototypes)  # Pairwise distances
        logits = -distances

        if labels is None:
            return logits, 0

        ce_loss = nn.CrossEntropyLoss()(logits, labels)
        return logits, ce_loss

In [41]:
criterion = ARPLoss()

# Define options
options = {'use_gpu': True}

def test(net, criterion, testloader, outloader, epoch=None, **options):
    net.eval()
    correct, total = 0, 0

    torch.cuda.empty_cache()

    _pred_k, _pred_u, _labels = [], [], []

    with torch.no_grad():
        for data, labels in testloader:
            if options['use_gpu']:
                data, labels = data.cuda(), labels.cuda()

            with torch.set_grad_enabled(False):
                y, x = net(data, True)
                logits, _ = criterion(x, y)
                predictions = logits.data.max(1)[1]
                total += labels.size(0)
                correct += (predictions == labels.data).sum()

                _pred_k.append(logits.data.cpu().numpy())
                _labels.append(labels.data.cpu().numpy())

        for batch_idx, (data, labels) in enumerate(outloader):
            if options['use_gpu']:
                data, labels = data.cuda(), labels.cuda()

            with torch.set_grad_enabled(False):
                y, x = net(data, True)
                logits, _ = criterion(x, y)
                _pred_u.append(logits.data.cpu().numpy())

    # Accuracy
    acc = float(correct) * 100. / float(total)
    print('Acc: {:.5f}'.format(acc))

    _pred_k = np.concatenate(_pred_k, 0)
    _pred_u = np.concatenate(_pred_u, 0)
    _labels = np.concatenate(_labels, 0)

    # Out-of-Distribution detection evaluation
    x1, x2 = np.max(_pred_k, axis=1), np.max(_pred_u, axis=1)
    results = metric_ood(x1, x2)['Bas']

    # OSCR
    _oscr_score = compute_oscr(_pred_k, _pred_u, _labels)

    results['ACC'] = acc
    results['OSCR'] = _oscr_score * 100.

    return results

In [42]:
results = test(model, criterion, dataloaders_subset1['test'], dataloaders_subset2['test'], **options)

Acc: 0.70000
       TNR    AUROC  DTACC  AUIN   AUOUT 
Bas     0.025 20.417 50.000 43.519 26.247


In [52]:
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
model.eval()

# Initialize variables to track accuracy per class
correct_preds = {classname: 0 for classname in class_names}
total_preds = {classname: 0 for classname in class_names}

# Evaluation loop
with torch.no_grad():
    for inputs, labels in dataloaders_subset1['test']:
        inputs = inputs.to('cuda')
        labels = labels.to('cuda')

        outputs= model(inputs)
        _, preds = torch.max(outputs, 1)

        # Track accuracy for each class
        for label, pred in zip(labels, preds):
            if pred == label:
                correct_preds[class_names[label]] += 1
            total_preds[class_names[label]] += 1

for classname, correct_count in correct_preds.items():
    if total_preds[classname] > 0:
        accuracy = 100 * float(correct_count) / total_preds[classname]
        print(f'Accuracy for class {classname}: {accuracy:.2f}%')
    else:
        print(f'Accuracy for class {classname}: No samples')

Accuracy for class airplane: 0.10%
Accuracy for class automobile: 0.00%
Accuracy for class bird: No samples
Accuracy for class cat: 0.10%
Accuracy for class deer: No samples
Accuracy for class dog: 0.10%
Accuracy for class frog: No samples
Accuracy for class horse: No samples
Accuracy for class ship: 0.00%
Accuracy for class truck: 0.00%


In [49]:
results['OSCR']

0.09941666666666202