In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve

# Experiment configuration shared by the cells below.
batch_size = 128  # mini-batch size handed to every DataLoader
epochs = 10  # number of passes over the training set
lr = 1e-3  # learning rate handed to the optimizer in train()
mc_samples = 20  # presumably the number of stochastic MC Dropout forward passes — confirm against the call site
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # prefer the GPU when one is available
# Seed the torch and numpy global RNGs.
torch.manual_seed(0)
np.random.seed(0)



In [2]:
class CNN(nn.Module):
    """Small convolutional classifier with dropout after every stage.

    The network has three ``Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2, 2)
    -> Dropout`` stages (3 -> 32 -> 64 -> 128 channels) followed by a
    two-layer fully connected head. ``fc1`` expects ``128 * 4 * 4`` inputs,
    i.e. the three 2x2 poolings must leave 4x4 feature maps (32x32 images).

    Args:
        dropout_p: Drop probability of the single shared ``nn.Dropout`` module.
        num_classes: Number of output logits.
    """

    def __init__(self, dropout_p=0.3, num_classes=10):
        super().__init__()

        # Feature extractor; the spatial size is preserved by padding=1.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)

        # Parameter-free layers reused after every stage.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout = nn.Dropout(p=dropout_p)

        # Classification head.
        flat_features = 128 * 4 * 4
        self.fc1 = nn.Linear(flat_features, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return the class logits, shape ``(batch, num_classes)``."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.dropout(self.pool(F.relu(conv(x))))

        flat = x.view(x.size(0), -1)
        hidden = self.dropout(F.relu(self.fc1(flat)))
        return self.fc2(hidden)

In [3]:
# CIFAR-10 images are only converted to tensors (no further transform is applied here).
transform_cifar = transforms.Compose([
    transforms.ToTensor()
])

# MNIST images are resized to 32x32 and their leading (channel) dimension is
# expanded to 3 so that they match the 3-channel 32x32 input the CNN is built for.
transform_mnist = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.expand(3, -1, -1))
])

# In-distribution (ID) data: CIFAR-10 train/test splits. OOD data: MNIST test split.
train_id = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_cifar)
test_id = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_cifar)
test_ood = datasets.MNIST(root='./data', train=False, download=True, transform=transform_mnist)

# Only the training loader is shuffled; evaluation loaders keep dataset order.
train_id_loader = DataLoader(train_id, batch_size=batch_size, shuffle=True)
test_id_loader = DataLoader(test_id, batch_size=batch_size, shuffle=False)
test_ood_loader = DataLoader(test_ood, batch_size=batch_size, shuffle=False)

100%|██████████| 170M/170M [00:03<00:00, 47.6MB/s]
100%|██████████| 9.91M/9.91M [00:00<00:00, 17.7MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 481kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.49MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 10.7MB/s]


In [4]:
def train(model, train_loader, epochs, lr):
    """Fit ``model`` with Adam and cross-entropy, printing one summary line per epoch.

    The model is moved to the module-level ``device`` and left in training
    mode. The reported loss is the mean of the per-batch losses; the reported
    accuracy is measured on the training batches themselves, i.e. with the
    model in training mode and the weights changing between batches.

    Args:
        model: Classifier that maps a batch of inputs to logits.
        train_loader: Iterable of ``(inputs, targets)`` batches that also
            exposes ``len()`` and a ``dataset`` attribute with ``len()``.
        epochs: Number of passes over ``train_loader``.
        lr: Learning rate for ``optim.Adam``.

    Returns:
        None. The model is updated in place.
    """
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    adam = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(1, epochs + 1):
        model.train()
        running_loss, n_correct = 0.0, 0

        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            adam.zero_grad()
            outputs = model(inputs)
            batch_loss = criterion(outputs, targets)
            batch_loss.backward()
            adam.step()

            running_loss += batch_loss.item()
            n_correct += (outputs.argmax(dim=1) == targets).sum().item()

        mean_loss = running_loss / len(train_loader)
        accuracy_pct = n_correct / len(train_loader.dataset) * 100
        print(f'Epoch {epoch}; Train loss {mean_loss}; Accuracy {accuracy_pct}')

In [5]:
def compute_ood_metrics(id_scores, ood_scores):
    """Print and return AUROC, AUPR and FPR@95%TPR for an OOD detector.

    OOD samples are the positive class (label 1) and in-distribution samples
    the negative class (label 0), so a *higher* score must mean "more likely
    OOD".

    Args:
        id_scores: Array of scores for in-distribution samples.
        ood_scores: Array of scores for out-of-distribution samples.

    Returns:
        Tuple ``(auroc, aupr, fpr95)`` where ``fpr95`` is the false positive
        rate at the first ROC point whose true positive rate is >= 0.95
        (1.0 if no such point exists).
    """
    labels = np.concatenate([np.zeros_like(id_scores), np.ones_like(ood_scores)])
    all_scores = np.concatenate([id_scores, ood_scores])

    auroc = roc_auc_score(labels, all_scores)
    aupr = average_precision_score(labels, all_scores)

    fpr, tpr, _ = roc_curve(labels, all_scores)
    # Indices of ROC points that already detect at least 95% of the OOD samples.
    reached = np.flatnonzero(tpr >= 0.95)
    fpr95 = fpr[reached[0]] if len(reached) > 0 else 1.0

    for label, value in (('AUROC', auroc), ('AUPR', aupr), ('FPR@95%TPR', fpr95)):
        print(f'{label} {value}')

    return auroc, aupr, fpr95

In [6]:
def get_softmax_ood_scores(model, id_loader, ood_loader):
    """Score samples with the maximum-softmax-probability baseline.

    The model is moved to the module-level ``device`` and switched to eval
    mode. Each sample's score is ``1 - max_c softmax(logits)_c``, so a higher
    score means a less confident prediction.

    Args:
        model: Classifier mapping a batch of inputs to logits.
        id_loader: Iterable of ``(inputs, _)`` in-distribution batches.
        ood_loader: Iterable of ``(inputs, _)`` out-of-distribution batches.

    Returns:
        Tuple ``(id_scores, ood_scores)`` of 1-D numpy arrays.
    """
    model.to(device)
    model.eval()

    def uncertainty_per_batch(loader):
        # One numpy array of (1 - max softmax probability) per batch.
        per_batch = []
        with torch.no_grad():
            for batch, _ in loader:
                probabilities = F.softmax(model(batch.to(device)), dim=1)
                confidence = probabilities.max(dim=1).values
                per_batch.append((1.0 - confidence).cpu().numpy())
        return per_batch

    id_batches = uncertainty_per_batch(id_loader)
    ood_batches = uncertainty_per_batch(ood_loader)

    return np.concatenate(id_batches), np.concatenate(ood_batches)


def get_mcd_ood_entropy(model, x, T=20):
    """Predictive entropy of the MC Dropout mean prediction for one batch.

    Runs ``T`` stochastic forward passes, averages the softmax outputs over
    the passes and returns the entropy of that mean distribution per sample.

    Only the stochastic layers are meant to be active: the model is put in
    training mode, but batch-norm layers are kept in eval mode so their
    running statistics are neither used in batch mode nor updated by the
    sampling passes. The training/eval mode of every sub-module is restored
    before returning, so a model handed over in eval mode stays in eval mode.

    Args:
        model: Classifier mapping a batch of inputs to logits. It is moved to
            the module-level ``device``.
        x: Input batch. It is used as given; the caller is expected to have
            placed it on the same device as the model.
        T: Number of stochastic forward passes (must be >= 1).

    Returns:
        1-D tensor with one entropy value per sample in ``x``.

    Raises:
        ValueError: If ``T`` is smaller than 1.
    """
    if T < 1:
        raise ValueError(f"T must be at least 1, got {T}")

    model.to(device)

    # Remember each sub-module's mode so it can be restored exactly afterwards.
    saved_modes = [(module, module.training) for module in model.modules()]
    batchnorm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)

    model.train()
    for module in model.modules():
        if isinstance(module, batchnorm_types):
            # Training-mode batch norm would overwrite its running statistics.
            module.eval()

    try:
        with torch.no_grad():
            probs_T = torch.stack(
                [F.softmax(model(x), dim=1) for _ in range(T)],
                dim=0,
            )
    finally:
        # Set the flag directly: Module.train() would also recurse into children.
        for module, was_training in saved_modes:
            module.training = was_training

    p_mean = probs_T.mean(dim=0)

    # eps keeps log() finite for classes whose mean probability is exactly 0.
    eps = 1e-8
    entropy = -torch.sum(p_mean * torch.log(p_mean + eps), dim=1)

    return entropy


def get_mcd_ood_scores(model, id_loader, ood_loader, T=20):
    """Collect MC Dropout entropy scores for ID and OOD data.

    Every batch is moved to the module-level ``device`` and scored with
    ``get_mcd_ood_entropy`` using ``T`` forward passes.

    Args:
        model: Classifier mapping a batch of inputs to logits.
        id_loader: Iterable of ``(inputs, _)`` in-distribution batches.
        ood_loader: Iterable of ``(inputs, _)`` out-of-distribution batches.
        T: Number of forward passes forwarded to ``get_mcd_ood_entropy``.

    Returns:
        Tuple ``(id_scores, ood_scores)`` of concatenated numpy arrays.
    """
    model.to(device)

    def entropy_per_batch(loader):
        # One numpy array of per-sample entropies for each batch of `loader`.
        return [
            get_mcd_ood_entropy(model, batch.to(device), T=T).cpu().numpy()
            for batch, _ in loader
        ]

    id_batches = entropy_per_batch(id_loader)
    ood_batches = entropy_per_batch(ood_loader)

    return np.concatenate(id_batches), np.concatenate(ood_batches)

In [7]:
# Train the baseline CNN on the CIFAR-10 training loader with the configured epochs and learning rate.
model = CNN(dropout_p=0.3, num_classes=10)
train(model, train_id_loader, epochs=epochs, lr=lr)

Epoch 1; Train loss 1.7258383511277415; Accuracy 36.298
Epoch 2; Train loss 1.369394351454342; Accuracy 50.404
Epoch 3; Train loss 1.2248715022030998; Accuracy 56.257999999999996
Epoch 4; Train loss 1.128293522331111; Accuracy 59.858
Epoch 5; Train loss 1.0697537838955364; Accuracy 61.946
Epoch 6; Train loss 1.007762765945376; Accuracy 64.336
Epoch 7; Train loss 0.9637599073712478; Accuracy 65.994
Epoch 8; Train loss 0.9319869461266891; Accuracy 67.078
Epoch 9; Train loss 0.8966699090150311; Accuracy 68.396
Epoch 10; Train loss 0.8633091169245103; Accuracy 69.80199999999999


In [8]:
# Baseline: max-softmax-probability scores (1 - max prob) on the CIFAR-10 test set (ID) and MNIST test set (OOD).
softmax_id_scores, softmax_ood_scores = get_softmax_ood_scores(model, test_id_loader, test_ood_loader)


In [9]:
# AUROC / AUPR / FPR@95%TPR of the softmax baseline (OOD is the positive class).
auroc, aupr, fpr95 = compute_ood_metrics(softmax_id_scores, softmax_ood_scores)

AUROC 0.64310357
AUPR 0.5806254718909908
FPR@95%TPR 0.7788


In [10]:
# MC Dropout entropy scores; the number of stochastic passes comes from the
# config cell (mc_samples) instead of silently relying on the function default.
mcd_id_scores, mcd_ood_scores = get_mcd_ood_scores(model, test_id_loader, test_ood_loader, T=mc_samples)

In [11]:
# AUROC / AUPR / FPR@95%TPR of the MC Dropout entropy scores (OOD is the positive class).
mcd_auroc, mcd_aupr, mcd_fpr95 = compute_ood_metrics(mcd_id_scores, mcd_ood_scores)

AUROC 0.717056355
AUPR 0.6321278668951666
FPR@95%TPR 0.6845


In [12]:
class CNN_ReAct(CNN):
    """``CNN`` whose forward pass can also expose the penultimate features.

    The layers are inherited unchanged; only ``forward`` is overridden so
    that the activations feeding ``fc2`` (``fc1 -> ReLU -> Dropout``) can be
    returned alongside the logits.
    """

    def forward(self, x, return_features=False):
        """Return the logits, or ``(logits, features)`` if ``return_features`` is true."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.dropout(self.pool(F.relu(conv(x))))

        flat = x.view(x.size(0), -1)
        features = self.dropout(F.relu(self.fc1(flat)))
        logits = self.fc2(features)

        return (logits, features) if return_features else logits

In [18]:
def compute_react_threshold(model, train_loader, quantile=0.9):
    """Compute and print the ReAct clipping threshold.

    The model is moved to the module-level ``device`` and switched to eval
    mode. The features returned by ``model(x, return_features=True)`` are
    gathered for every batch, and the requested quantile is taken over all
    feature values at once (a single global threshold, not one per unit).

    Args:
        model: Model whose forward accepts ``return_features=True`` and then
            returns ``(logits, features)``.
        train_loader: Iterable of ``(inputs, _)`` batches.
        quantile: Quantile in ``[0, 1]`` passed to ``np.quantile``.

    Returns:
        The threshold as returned by ``np.quantile``.
    """
    model.to(device)
    model.eval()

    feature_batches = []
    with torch.no_grad():
        for batch, _ in train_loader:
            _logits, penultimate = model(batch.to(device), return_features=True)
            feature_batches.append(penultimate.cpu().numpy())

    all_features = np.concatenate(feature_batches, axis=0)
    threshold = np.quantile(all_features, quantile)

    print(f"ReAct threshold (q={quantile}): {threshold}")
    return threshold

In [28]:
def get_react_energy_ood_scores(model, id_loader, ood_loader, threshold, T=1.0):
    """
    ReAct + Energy OOD scores.

    For every batch:
    1) take the features the model returns with ``return_features=True``
       (the activations that feed ``fc2``);
    2) clip them from above at ``threshold`` (ReAct);
    3) recompute the logits as ``model.fc2(clipped_features)``;
    4) compute the energy ``E(x) = -T * logsumexp(logits / T)``.

    The OOD score is ``E(x)`` itself. Confident inputs with large logits get
    a low (very negative) energy, so a HIGHER score means the sample is MORE
    likely OOD. This matches the convention of ``compute_ood_metrics``, where
    OOD is the positive class.

    The model is moved to the module-level ``device`` and switched to eval
    mode.

    Args:
        model: Model with an ``fc2`` layer whose forward accepts
            ``return_features=True`` and then returns ``(logits, features)``.
        id_loader: Iterable of ``(inputs, _)`` in-distribution batches.
        ood_loader: Iterable of ``(inputs, _)`` out-of-distribution batches.
        threshold: Upper clipping value for the features.
        T: Energy temperature; must be positive.

    Returns:
        Tuple ``(id_scores, ood_scores)`` of 1-D numpy arrays of energies.

    Raises:
        ValueError: If ``T`` is not positive (``T == 0`` would divide by zero
            and a negative ``T`` would flip the score direction).
    """
    if T <= 0:
        raise ValueError(f"T must be positive, got {T}")

    model.to(device)
    model.eval()

    def collect(loader):
        # One numpy array of per-sample energies for each batch of `loader`.
        batches = []
        with torch.no_grad():
            for x, _ in loader:
                x = x.to(device)

                # Only the features are needed; the unclipped logits are discarded.
                _, feats = model(x, return_features=True)

                # ReAct: clip the activations from above.
                feats = torch.clamp(feats, max=threshold)

                # Recompute the logits from the clipped features.
                logits = model.fc2(feats)

                # Energy score (higher = more OOD-like).
                energy = -T * torch.logsumexp(logits / T, dim=1)
                batches.append(energy.cpu().numpy())
        return batches

    id_scores = collect(id_loader)
    ood_scores = collect(ood_loader)

    return np.concatenate(id_scores), np.concatenate(ood_scores)

In [29]:
# Model variant that can return features; it reuses the weights of the trained baseline model
react_model = CNN_ReAct(dropout_p=0.3, num_classes=10)
react_model.load_state_dict(model.state_dict())

# ReAct threshold (quantile of the features over the training loader)
react_threshold = compute_react_threshold(react_model, train_id_loader, quantile=0.9)

# ReAct + Energy scores
react_energy_id_scores, react_energy_ood_scores = get_react_energy_ood_scores(
    react_model,
    test_id_loader,
    test_ood_loader,
    threshold=react_threshold,
    T=1.0
)

# metrics
react_energy_auroc, react_energy_aupr, react_energy_fpr95 = compute_ood_metrics(
    react_energy_id_scores,
    react_energy_ood_scores
)

ReAct threshold (q=0.9): 0.9684548377990723
AUROC 0.89320782
AUPR 0.8177956025178723
FPR@95%TPR 0.2679


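ReAct + Energy показал лучшее качество OOD-детекции среди протестированных методов: AUROC 0.893 против 0.717 у MC Dropout и 0.643 у Softmax, FPR@95%TPR 0.268 против 0.685 и 0.779 соответственно. При этом он не такой тяжёлый, как MC Dropout: ему нужен один прямой проход на пример вместо 20 стохастических. В итоге это самый простой в применении и самый эффективный вариант из всех протестированных.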