# Assignment 8: Open-Set Classification

In this assignment, we develop a network that correctly classifies known classes while rejecting unknown samples that occur at inference time.
To showcase this capability, we use the MNIST dataset, which we artificially split into known, negative and unknown classes; this allows us to train a network without requiring expensive hardware.
Known and negative classes are used during training, and unknown classes appear only in the testing set.

## Dataset
We split the MNIST dataset into 4 known classes, 4 negative classes (used for training) and 2 unknown classes (used only for testing).
While several splits are possible, we restrict ourselves to the following:
* Known class indexes: (1, 4, 7, 9)
* Negative class indexes: (0, 2, 3, 5)
* Unknown class indexes: (6,8)

Please note that, in PyTorch, class indexing starts at 0 (unlike in the lecture, where class indexing starts at 1).

We rely on the `torchvision.datasets.MNIST` implementation of the MNIST dataset, which we adapt to our needs.
The constructor of our Dataset class takes one parameter that defines the purpose of this dataset (`"train", "validation", "test"`).
The `"train"` partition uses the training samples of the *known* and the *negative* classes.
The `"validation"` partition uses the test samples of the *known* and the *negative* classes.
Finally, the `"test"` partition uses the test samples of the *known* and the *unknown* classes.
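
The three partitions can be summarized as a small lookup table. The following is only a sketch of that mapping; the names `PARTITIONS` and `partition_spec` are hypothetical and not part of the assignment template.

```python
# Class splits as defined in the Dataset section
KNOWN = (1, 4, 7, 9)
NEGATIVE = (0, 2, 3, 5)
UNKNOWN = (6, 8)

# purpose -> (use the MNIST training split?, classes kept in this partition)
PARTITIONS = {
    "train": (True, KNOWN + NEGATIVE),
    "validation": (False, KNOWN + NEGATIVE),
    "test": (False, KNOWN + UNKNOWN),
}

def partition_spec(purpose):
    """Return (train_flag, classes) for a purpose; raise for anything else."""
    if purpose not in PARTITIONS:
        raise ValueError(f"unknown purpose: {purpose!r}")
    return PARTITIONS[purpose]

train_flag, classes = partition_spec("test")
print(train_flag, sorted(classes))
```

Keeping the mapping in one place makes it easy to check that unknown classes never leak into the training or validation partitions.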

In our implementation of the Dataset class, we need to implement two functions.
* First, the constructor `__init__(self, purpose)` selects the data based on our purpose. 
* Second, the index function `__getitem__(self, n)` returns a pair $(X^n, \vec t{\,}^{n})$ for the sample with the index $n$, where $X \in \mathbb R^{1\times28\times28}$ with values in range $[0,1]$ and $\vec t \in \mathbb R^{O}$, see below.

Since our loss function (cf. Task 5) requires our target vectors to be in vector format, we need to convert the target index $\tau^n$ into its vector representation $\vec t{\,}^n$.
Particularly, we need to provide the following target vectors: 

<center> 

 $\tau^n = 1 : \vec t{\,}^n = (1,0,0,0)$ 

 $\tau^n = 4 : \vec t{\,}^n = (0,1,0,0)$ 
 
 $\tau^n = 7 : \vec t{\,}^n = (0,0,1,0)$
 
 $\tau^n = 9 : \vec t{\,}^n = (0,0,0,1)$

 else: $\vec t{\,}^n = (\frac14,\frac14,\frac14,\frac14)$

</center>


### Task 1: Target Vectors

Implement a function that generates a target vector for any of the ten different classes according to the description above. The return value should be a `torch.tensor` of type float.



In [None]:
import torch
import torchvision

In [None]:

# define the three types of classes
known_classes = (1,4,7,9)
negative_classes = (0,2,3,5)
unknown_classes = (6, 8)
O = len(known_classes)

# define one-hot vectors; one_hot returns int64, so convert to float as required by the task
labels_known = torch.nn.functional.one_hot(torch.tensor([0,1,2,3]), num_classes=O).float()
label_unknown = torch.tensor([1/O,1/O,1/O,1/O])

def target_vector(index):
    # select correct one-hot vector for known classes, and the 1/O-vector for negative and unknown classes
    if index in known_classes:
        return labels_known[known_classes.index(index)]
    else:
        return label_unknown
    
for i in range(0,10,1):
    print(i, target_vector(i))

### Test 1: Check your Target Vectors

Test that your target vectors are correct for all types of known and unknown samples.


In [None]:
# check that the target vectors for known classes are correct
for index in known_classes:
  t = target_vector(index)
  print(index, t) 
  assert max(t) == 1
  assert sum(t) == 1

# check that the target vectors for negative and unknown classes are correct
for index in negative_classes + unknown_classes:
  t = target_vector(index)
  print(index, t)
  assert max(t) == 0.25
  assert sum(t) == 1

### Tasks 2 and 3: Dataset Construction and Dataset Item Selection

Write a dataset class that derives from `torchvision.datasets.MNIST` in `PyTorch` and adapts some parts of it. 
In the constructor, make sure that you let `PyTorch` load the dataset by calling the base class constructor `super` with the desired parameters. Afterward, the `self.data` and `self.targets` are populated with all samples and target indexes.
From these, we need to sub-select the samples that fit our current `purpose` and store them back to `self.data` and `self.targets`.

Second, we need to implement the index function of our dataset, where we need to return both the image and the target vector.
The images in `self.data` were originally stored as `uint8` values in the dimension $\mathbb N^{N\times28\times28}$ with values in $[0, 255]$.
The targets in `self.targets` were originally stored as class indexes in the dimension $\mathbb N^N$. Make sure that you return both in the desired format.
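
As an illustration of the image conversion described above, the following sketch uses a random `uint8` tensor as a stand-in for one entry of `self.data` (the variable names `raw` and `x` are made up for this example):

```python
import torch

# Stand-in for one raw MNIST image: uint8 values in [0, 255], shape (28, 28)
raw = torch.randint(0, 256, (28, 28), dtype=torch.uint8)

# Scale to [0, 1] as float32 and add the channel dimension -> (1, 28, 28)
x = (raw.float() / 255.0).unsqueeze(0)

print(x.shape, x.dtype)
```

The same conversion can also be applied once to the whole `self.data` tensor in the constructor, in which case the channel dimension has to be inserted at position 1 instead of 0.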

Notes:

* Since Jupyter Notebook does not allow splitting classes over several code boxes, the two tasks are required to be solved in the same code box.
* **The definition below is just one possibility.** There are many ways to implement this dataset interface. 
* With a clever implementation of the constructor, there is no need to overwrite the `__getitem__(self,index)` function.
* Depending on your implementation, you might also need to overwrite the `__len__(self)` function.

In [None]:
import torch.nn.functional as F

class DataSet(torchvision.datasets.MNIST):
  def __init__(self, purpose="train"):
    """
    Initializes a custom MNIST dataset.
    
    Args:
      purpose (str): One of "train", "validation", or "test".
                     Determines which subset of classes to load.
    """
    # Define which classes and which MNIST split (train/test) to use
    if purpose == "train":
        # training samples of the known and negative classes
        self.valid_classes = list(known_classes + negative_classes)
        use_train_split = True
    elif purpose == "validation":
        # test samples of the known and negative classes
        self.valid_classes = list(known_classes + negative_classes)
        use_train_split = False
    elif purpose == "test":
        # test samples of the known and unknown classes
        self.valid_classes = list(known_classes + unknown_classes)
        use_train_split = False
    else:
        raise ValueError("`purpose` must be one of 'train', 'validation', or 'test'")

    # Call the base class constructor to load the appropriate MNIST split
    super().__init__(root="./data/MNIST",
                     train=use_train_split,
                     download=True)

    # Filter the data and targets to include only the valid classes
    # We create a boolean mask to select the samples.
    mask = torch.isin(self.targets, torch.tensor(self.valid_classes))
    
    # Apply the mask to keep only the desired samples
    self.data = self.data[mask]
    self.targets = self.targets[mask]

  def __getitem__(self, index):
    """
    Retrieves and processes a single data sample.

    Args:
      index (int): The index of the sample to retrieve.

    Returns:
      tuple: A tuple containing the processed image tensor and the target vector.
    """
    # Retrieve the raw image and target index using the provided index
    # The parent class __getitem__ would work, but we need to apply transformations.
    img, target_idx = self.data[index], self.targets[index]

    # Process the image:
    # 1. Convert from uint8 [0, 255] to float32 [0.0, 1.0]
    # 2. Add a channel dimension: (28, 28) -> (1, 28, 28)
    input_tensor = (img.float() / 255.0).unsqueeze(0)

    # Process the target:
    # Convert the class index into its O-dimensional target vector:
    # one-hot for known classes, (1/O, ..., 1/O) for negative/unknown classes.
    target_tensor = target_vector(int(target_idx)).float()
    
    return input_tensor, target_tensor


### Test 2: Data Sets


Instantiate the training dataset.
Implement a data loader for the training dataset with a batch size of 64.
Ensure that all inputs are of the desired type and shape.
Assert that the target values are in the correct format, and the sum of the target values per sample is one.


In [None]:
# instantiate the training dataset
train_set = DataSet(purpose="train")
train_loader = torch.utils.data.DataLoader(train_set, 64, shuffle=True)
actual_len = len(train_set)
print(f"Length of the filtered training set: {actual_len}")

# assert that we have not filtered out all samples
assert 60000 > actual_len > 45000

# check the batch and assert valid data and sizes
for x,t in train_loader:
  assert len(x) <= 64
  assert len(t) == len(x)
  assert torch.all(torch.sum(t, axis = 1) == 1)
  assert x.shape == torch.Size([x.shape[0], 1, 28, 28])
  assert x.dtype == torch.float32
  assert torch.max(x) <= 1

### Task 4: Utility Function

Implement a function that splits a batch of samples into known and negative/unknown parts. For the known parts, also provide the target vectors.
How can we know which of the data samples are known samples, and which are negative/unknown?

This function needs to return three elements:
* First, the samples from the batch that belong to known classes.
* Second, the target vectors that belong to the known classes.
* Finally, the samples from the batch that belong to negative/unknown classes.

In [None]:
def split_known_unknown(batch, targets):
  """
  Splits a batch into known and unknown samples.

  Args:
    batch (torch.Tensor): The input data samples.
    targets (torch.Tensor): The target vectors (one-hot for known, uniform 1/O for negative/unknown samples).

  Returns:
    tuple: A tuple containing (known_samples, known_targets, unknown_samples).
  """
  # A target is "known" if it is a one-hot vector, i.e., its maximum entry is 1.
  # Negative/unknown targets are (1/O, ..., 1/O), so their maximum is 1/O < 1.
  # Both kinds of target vectors sum to 1, so the sum cannot separate them.
  target_max = torch.max(targets, dim=1).values
  
  # Create boolean masks to select the indexes
  known = (target_max == 1)
  unknown = ~known
  
  # Return the sliced tensors based on the masks
  return batch[known], targets[known], batch[unknown]


## Loss Function and Confidence

We write our own PyTorch implementation of our loss function.
Particularly, we implement a manual way to define the derivative of our loss function via `torch.autograd.Function`, which allows us to define the forward and backward pass on our own.
For this purpose, we need to implement two `static` functions in our loss.
The function `forward(ctx, logits, targets)` is required to compute the loss value and allows us to store some variables in the context of the backward pass.
The `backward(ctx, result)` function receives the gradient with respect to the output of the forward function (when `backward()` is called directly on the scalar loss, this is simply 1) as well as the context with our stored variables.
Here, we need to return the derivative of the loss with respect to both of the inputs to the forward function (which might look a bit confusing), i.e., $\frac{\partial \mathbf{J}^{CCE}}{\partial \mathbf{Z}}$ and $\frac{\partial \mathbf{J}^{CCE}}{\partial \mathbf{T}}$.
Since the latter is not required, we can also simply return `None` for the second derivative.
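
For reference, the derivative with respect to the logits can be worked out as follows. This is a sketch that assumes the lecture's loss is the categorical cross-entropy on SoftMax outputs, written for a single sample with logits $z_o$, outputs $y_o$ and targets $t_o$ that sum to one; the lecture's notation may differ in details such as the batch average.

$$
\mathcal J = -\sum_{o'=1}^{O} t_{o'} \ln y_{o'}, \qquad y_{o'} = \frac{e^{z_{o'}}}{\sum_{o''=1}^{O} e^{z_{o''}}}, \qquad \frac{\partial \ln y_{o'}}{\partial z_o} = \delta_{oo'} - y_o
$$

$$
\frac{\partial \mathcal J}{\partial z_o} = -\sum_{o'=1}^{O} t_{o'} \left(\delta_{oo'} - y_o\right) = y_o \sum_{o'=1}^{O} t_{o'} - t_o = y_o - t_o
$$

The last step only uses $\sum_{o'} t_{o'} = 1$, so the result holds both for the one-hot targets of known classes and for the $\frac1O$ targets of negative classes.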

<font color=#FF000>Hint: if you think the implementation of the loss function is too hard, you can also use cross-entropy as your loss function (**since PyTorch version 1.11**).</font>

### Task 5: Loss Function Implementation

Implement a `torch.autograd.Function` class for the adapted SoftMax function according to the equations provided in the lecture.
You might want to compute the log of the network output $\ln y_o$ from the logits $z_o$ via `torch.nn.functional.log_softmax`.
Store all the data required for the backward pass in the context during `forward`, and extract these from the context during `backward`.
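
A hand-written backward pass can be verified numerically with `torch.autograd.gradcheck`. The sketch below does this for a compact cross-entropy function with soft targets; the class name `SoftTargetCE` is hypothetical, and the loss is assumed to be the batch-averaged categorical cross-entropy, which may differ from the exact form in the lecture.

```python
import torch
import torch.nn.functional as F

class SoftTargetCE(torch.autograd.Function):
    """Hypothetical example: cross-entropy with soft targets and manual gradient."""

    @staticmethod
    def forward(ctx, logits, targets):
        log_probs = F.log_softmax(logits, dim=1)
        # keep SoftMax outputs and targets for the backward pass
        ctx.save_for_backward(log_probs.exp(), targets)
        return -(targets * log_probs).sum() / logits.shape[0]

    @staticmethod
    def backward(ctx, grad_output):
        probs, targets = ctx.saved_tensors
        # dJ/dz = (y - t) / N, scaled by the upstream gradient; no gradient for targets
        return grad_output * (probs - targets) / probs.shape[0], None

torch.manual_seed(0)
# gradcheck compares against finite differences and needs double precision
logits = torch.randn(5, 4, dtype=torch.double, requires_grad=True)
targets = torch.full((5, 4), 0.25, dtype=torch.double)       # negative-style targets
targets[0] = torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.double)  # one known sample

gradcheck_ok = torch.autograd.gradcheck(SoftTargetCE.apply, (logits, targets))
print(gradcheck_ok)
```

If the manual gradient were wrong, `gradcheck` would raise an error that reports the mismatch between the analytical and the numerical Jacobian.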

In [None]:
class AdaptedSoftMax(torch.autograd.Function):

  # implement the forward propagation
  @staticmethod
  def forward(ctx, logits, targets):
    """
    Computes the forward pass for the cross-entropy loss.
    """
    # compute the log probabilities via log_softmax
    log_probs = F.log_softmax(logits, dim=1)
    
    # We need the softmax probabilities (y) and targets (t) for the backward pass,
    # as the gradient is (y - t).
    probs = torch.exp(log_probs)
    ctx.save_for_backward(probs, targets)
    
    # Compute the cross-entropy loss, averaged over the batch.
    # J = - (1/N) * sum(targets * log_probs)
    loss = -torch.sum(targets * log_probs) / logits.shape[0]
    return loss

  # implement Jacobian (backward pass)
  @staticmethod
  def backward(ctx, grad_output):
    """
    Computes the backward pass, correctly applying the chain rule.
    """
    probs, targets = ctx.saved_tensors
    batch_size = probs.shape[0]
    
    # This is the local gradient of the loss w.r.t. the logits
    local_grad = (probs - targets) / batch_size
    
    # The final gradient is the local gradient multiplied by the upstream gradient
    # (grad_output). This correctly implements the chain rule.
    dJ_dz = local_grad * grad_output
    
    # the gradient with respect to the targets is not required
    return dJ_dz, None

# DO NOT REMOVE!
# here we set the adapted softmax function to be used later
adapted_softmax = AdaptedSoftMax.apply

### Task 5a: Alternative Loss Function

In case the loss function is too difficult to implement, you can also choose to rely on PyTorch's automatic gradient computation and simply define your loss function without the backward pass.

In this case, we only need to define the forward pass. A simple function `adapted_softmax(logits, targets)` is sufficient.

You can implement any variant of the categorical cross-entropy loss function on top of SoftMax activations as defined in the lecture.
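
As a sanity check for such a forward-only loss, its value can be compared with `torch.nn.functional.cross_entropy`, which accepts class-probability targets in recent PyTorch versions. The sketch below assumes such a version is installed and uses made-up logits:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(3, 4)
# two one-hot targets (known) and one uniform target (negative)
targets = torch.tensor([[1.0, 0.0, 0.0, 0.0],
                        [0.0, 0.0, 1.0, 0.0],
                        [0.25, 0.25, 0.25, 0.25]])

# manual cross-entropy, averaged over the batch
manual = -(targets * F.log_softmax(logits, dim=1)).sum() / logits.shape[0]
# built-in version with probability targets and the default mean reduction
builtin = F.cross_entropy(logits, targets)

print(torch.allclose(manual, builtin))
```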


In [None]:
def adapted_softmax_alt(logits, targets):
  """
  Computes cross-entropy loss using PyTorch's autograd.
  """
  # Compute the log probabilities (numerically stable)
  log_probs = F.log_softmax(logits, dim=1)
  
  # Compute the cross-entropy loss, averaged over the batch.
  # PyTorch will automatically compute the gradient for this computation.
  loss = -torch.sum(targets * log_probs) / logits.shape[0]
  
  return loss

### Task 6: Confidence Evaluation


Implement a function to compute the confidence value for a given batch of samples. 
Compute Softmax confidence and split these confidences between known and negative/unknown classes. 
For samples from known classes, sum up the SoftMax confidences of the correct class. 
For negative/unknown samples, sum 1 minus the maximum confidence for any of the known classes; also apply the $\frac1O$ correction for the minimum possible SoftMax confidence.
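
To see why the $\frac1O$ correction is needed, consider the following small worked example in plain Python. The helper `softmax` and the logit values are made up for illustration only.

```python
import math

def softmax(z):
    """Numerically stable SoftMax for a list of logits."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

O = 4

# Known sample of the first known class with a confident prediction
y_known = softmax([10.0, 0.0, 0.0, 0.0])
conf_known = y_known[0]

# Negative sample with the ideal, perfectly uniform output
y_negative = softmax([0.0, 0.0, 0.0, 0.0])
conf_negative = 1.0 - max(y_negative) + 1.0 / O

print(round(conf_known, 4), conf_negative)
```

Without the correction, the best possible confidence for a negative sample would be $1 - \frac1O$, since the maximum SoftMax output can never drop below $\frac1O$.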

In [None]:
import torch
import torch.nn.functional as F

def confidence(logits, targets):
    """
    Compute confidence values for a batch of samples.
    - Known samples: sum of softmax confidence at the correct class.
    - Unknown samples: sum of (1 - max softmax confidence) + 1/O correction.
    """
    probs = F.softmax(logits, dim=1)
    O = logits.size(1)

    # Identify known vs unknown
    is_known = (targets.max(dim=1).values == 1.0)
    is_unknown = ~is_known

    # --- Known confidence ---
    conf_known = torch.sum(torch.sum(probs * targets, dim=1)[is_known])

    # --- Unknown confidence ---
    conf_unknown = torch.tensor(0.0, device=logits.device)
    if is_unknown.any():
        max_conf, _ = torch.max(probs[is_unknown], dim=1)
        conf_unknown = torch.sum(1.0 - max_conf + 1.0 / O)

    return conf_known + conf_unknown


### Test 3: Check Confidence Implementation

Test that your confidence implementation does what it is supposed to do.

Note that the confidence value of each sample should always be between 0 and 1; other values indicate an issue in the implementation.

In [None]:
# select good logit vectors for known and unknown classes
logits = torch.tensor([[10., 0., 0., 0.], [-10., 0, -10., -10.], [0.,0.,0.,0.]])
# select the corresponding target vectors for these classes
targets = torch.stack([target_vector(known_classes[0]), target_vector(known_classes[1]), target_vector(negative_classes[0])])


# the confidence should be close to 1 for all cases
assert 3 - confidence(logits, targets) < 1e-3

## Network and Training

We make use of the same convolutional network as utilized in Assignment 6, to which we append a final fully-connected layer with $K$ inputs and $O$ outputs.
Additionally, we replace the $\tanh$ activation function with the better-performing ReLU function.

The topology can be found in the following:
1. 2D convolutional layer with $Q_1$ channels, kernel size $5\times5$, stride 1 and padding 2
2. 2D maximum pooling layer with kernel size $2\times2$ and stride 2
3. Activation function ReLU
4. 2D convolutional layer with $Q_2$ channels, kernel size $5\times5$, stride 1 and padding 2
5. 2D maximum pooling layer with kernel size $2\times2$ and stride 2
6. Activation function ReLU
7. Flatten layer to convert the convolution output into a vector
8. Fully-connected layer with the correct number of inputs and $K$ outputs
9. Fully-connected layer with $K$ inputs and $O$ outputs
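
The "correct number of inputs" in step 8 follows from the usual output-size formula for convolution and pooling layers. The sketch below computes it in plain Python; the helper names `conv_out` and `flattened_features` are hypothetical, and $Q_2=32$ is the value used in Task 7.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1

def flattened_features(Q2, size=28):
    """Number of inputs to the first fully-connected layer."""
    size = conv_out(size, kernel=5, stride=1, padding=2)  # conv 1 keeps 28
    size = conv_out(size, kernel=2, stride=2, padding=0)  # pool 1 -> 14
    size = conv_out(size, kernel=5, stride=1, padding=2)  # conv 2 keeps 14
    size = conv_out(size, kernel=2, stride=2, padding=0)  # pool 2 -> 7
    return Q2 * size * size

print(flattened_features(Q2=32))
```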

However, instead of relying on the `torch.nn.Sequential` class, we need to define our own network class, which we need to derive from `torch.nn.Module` -- since our network has two outputs.
We basically need to implement two methods in our network.
The constructor `__init__(self, Q1, Q2, K, O)` needs to call the base class constructor and initialize all required layers of our network.
The `forward(self, x)` function then passes the input through all of our layers and returns both the deep features (extracted at the first fully-connected layer) and the logits (extracted from the second fully-connected layer).


### Task 7: Network Definition

We define our own small-scale network to classify known and unknown samples for MNIST.
We basically use the same convolutional network as in Assignment 6, with some small adaptations.
However, this time we need to implement our own network model since we need to modify our network output.

Implement a network class, including the layers as provided above.
Implement both the constructor and the forward function.
Instantiate the network with $Q_1=16, Q_2=32, K=20, O=4$.


In [None]:
import torch.nn as nn

class Network(torch.nn.Module):
  def __init__(self, Q1, Q2, K, O):
    # call base class constructor
    super(Network, self).__init__()

    # --- Layer Definitions ---
    # Input for MNIST is 1 channel. Output is Q1 channels.
    # Padding=2 with Kernel=5 and Stride=1 preserves the image dimensions (28x28).
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=Q1, kernel_size=5, stride=1, padding=2)
    # Input is Q1 channels from conv1. Output is Q2 channels.
    self.conv2 = nn.Conv2d(in_channels=Q1, out_channels=Q2, kernel_size=5, stride=1, padding=2)
    
    # Pooling and activation functions can be defined once and re-used
    self.pool = nn.MaxPool2d(kernel_size=2, stride=2) # Halves the dimensions
    self.act = nn.ReLU()
    self.flatten = nn.Flatten()
    
    # After two 2x2 pooling layers, a 28x28 image becomes 7x7.
    # The number of input features to the first FC layer is channels * height * width.
    fc1_in_features = Q2 * 7 * 7
    self.fc1 = nn.Linear(in_features=fc1_in_features, out_features=K)
    self.fc2 = nn.Linear(in_features=K, out_features=O)

  def forward(self, x):
    # --- Forward Pass ---
    # compute first block of convolution, pooling and activation
    # (N, 1, 28, 28) -> (N, Q1, 28, 28) -> (N, Q1, 14, 14) -> (N, Q1, 14, 14)
    a = self.act(self.pool(self.conv1(x)))
    
    # compute second block of convolution, pooling and activation
    # (N, Q1, 14, 14) -> (N, Q2, 14, 14) -> (N, Q2, 7, 7) -> (N, Q2, 7, 7)
    a = self.act(self.pool(self.conv2(a)))
    
    # Flatten the output for the fully-connected layers
    a_flat = self.flatten(a)
    
    # get the deep features as the output of the first fully-connected layer
    deep_features = self.fc1(a_flat)
    
    # get the logits as the output of the second fully-connected layer
    logits = self.fc2(deep_features)
    
    # return both the logits and the deep features
    return logits, deep_features

# --- Network Instantiation ---
# Check for CUDA availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# create network with specified dimensions
network = Network(Q1=16, Q2=32, K=20, O=4).to(device)

# You can print the network to verify its structure
print(network)

### Task 8: Training Loop

Instantiate the training and validation set and the corresponding data loaders.
Instantiate an SGD optimizer with an appropriate learning rate (the optimal learning rate might depend on your loss function implementation and can vary between 0.1 and 0.00001).
Implement the training and validation loop for 10 epochs (you can also train for 100 epochs if you want).
Compute the training set confidence during the epoch.
At the end of each epoch, also compute the validation set confidence measure.
Print both the training set and validation set confidence scores to the console.


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# --- Hyperparameters ---
batch_size = 64
learning_rate = 0.01   # adjust if loss diverges
epochs = 10

# --- Datasets & Dataloaders ---
train_set = DataSet(purpose="train")
validation_set = DataSet(purpose="validation")

train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=batch_size, shuffle=False)

# --- Network & Optimizer ---
network = Network(Q1=16, Q2=32, K=20, O=4).to(device)
optimizer = torch.optim.SGD(network.parameters(), lr=learning_rate)
loss_fn = adapted_softmax  # adapted SoftMax loss from Task 5; handles one-hot and 1/O targets

# --- Training Loop ---
for epoch in range(epochs):
    network.train()
    train_conf, validation_conf = 0.0, 0.0

    # Training phase
    for x, t in train_loader:
        x, t = x.to(device), t.to(device)

        # forward
        logits, features = network(x)

        # compute loss directly on the target vectors;
        # converting them via argmax would map all negative samples to class 0
        loss = loss_fn(logits, t)

        # backward + update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # accumulate confidence
        train_conf += confidence(logits.detach(), t).item()

    # Validation phase
    network.eval()
    with torch.no_grad():
        for x, t in validation_loader:
            x, t = x.to(device), t.to(device)
            logits, features = network(x)
            validation_conf += confidence(logits, t).item()

    # normalize by dataset sizes
    train_conf /= len(train_set)
    validation_conf /= len(validation_set)

    print(f"Epoch {epoch+1}; "
          f"train confidence: {train_conf:.5f}, "
          f"val confidence: {validation_conf:.5f}")


## Evaluation

For evaluation, we test two different things.
First, we check whether our intuition is correct, and the training helps reduce the deep feature magnitudes of unknown samples while maintaining magnitudes for known samples.
It is also interesting to see whether there is a difference between samples of the negative classes that were seen during training, and unknown classes that were not.
For this purpose, we extract the deep features for the validation and test sets, compute their magnitudes, and plot them in a histogram.

The second evaluation computes Correct Classification Rates (CCR) and False Positive Rates (FPR) for a given confidence threshold $\zeta=0.98$ (based on your training results, you might want to vary this threshold).
For the known samples, we compute how often the correct class was classified with a confidence above the threshold.
For unknown samples, we assess how often one of the known classes was predicted with a confidence larger than $\zeta$.
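
The two rates can be illustrated on a handful of made-up predictions. The following sketch is plain Python; the function name `ccr_fpr` and all numbers are invented for this example.

```python
def ccr_fpr(known, unknown, zeta):
    """
    known:   list of (predicted_class, target_class, max_confidence)
    unknown: list of max_confidence values of unknown samples
    """
    # known samples count as correct only if the class matches AND confidence exceeds zeta
    correct = sum(1 for pred, target, conf in known if pred == target and conf > zeta)
    # unknown samples are false positives if any known class exceeds zeta
    false_positives = sum(1 for conf in unknown if conf > zeta)
    return correct / len(known), false_positives / len(unknown)

known = [(0, 0, 0.99), (1, 1, 0.97), (2, 3, 0.995), (3, 3, 0.999)]
unknown = [0.30, 0.985, 0.60, 0.26]

ccr, fpr = ccr_fpr(known, unknown, zeta=0.98)
print(ccr, fpr)
```

Here the second known sample is predicted correctly but rejected because its confidence is too low, and the third is confident but wrong, so neither counts towards the CCR.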



### Task 9: Feature Magnitude Plot

Extract deep features for validation and test set samples and compute their magnitudes. Split them into known, negative (validation set), and unknown (test set). Plot a histogram for each of the three types of samples.
Note that the minimum magnitude is 0, and the maximum magnitude can depend on your network training success.
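
The magnitude of a deep feature vector is its Euclidean norm. A minimal sketch with made-up two-dimensional features (the real features have $K$ dimensions) and a hypothetical known/negative mask:

```python
import torch

# Made-up deep features for three samples (2 dimensions only to keep numbers readable)
deep_features = torch.tensor([[3.0, 4.0], [0.0, 0.0], [0.6, 0.8]])
# Hypothetical mask: first sample is known, the others are negative
is_known = torch.tensor([True, False, False])

# Euclidean norm per sample (along the feature dimension)
magnitudes = torch.norm(deep_features, p=2, dim=1)

known_magnitudes = magnitudes[is_known].tolist()
other_magnitudes = magnitudes[~is_known].tolist()
print(known_magnitudes, other_magnitudes)
```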

In [None]:
import torch
import matplotlib.pyplot as plt

# --- Instantiate test set and dataloader ---
test_set = DataSet(purpose="test")
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=False)

# --- Collect feature magnitudes ---
known, negative, unknown = [], [], []

network.eval()
with torch.no_grad():
    # Validation set → contains known + negative
    for x, t in validation_loader:
        x, t = x.to(device), t.to(device)

        logits, deep_features = network(x)
        norms = torch.norm(deep_features, p=2, dim=1)

        # Split into known vs negative
        is_known = (t.max(dim=1).values == 1.0)
        is_negative = ~is_known   # validation negatives

        known.extend(norms[is_known].cpu().tolist())
        negative.extend(norms[is_negative].cpu().tolist())

    # Test set → contains known + unknown
    for x, t in test_loader:
        x, t = x.to(device), t.to(device)

        logits, deep_features = network(x)
        norms = torch.norm(deep_features, p=2, dim=1)

        # Split into known vs unknown
        is_known = (t.max(dim=1).values == 1.0)
        is_unknown = ~is_known

        known.extend(norms[is_known].cpu().tolist())
        unknown.extend(norms[is_unknown].cpu().tolist())

# --- Plot histograms ---
plt.figure(figsize=(6,4))
max_mag = 20  # adjust depending on training results

plt.hist(known, bins=100, range=(0,max_mag), density=True,
         color="g", histtype="step", label="Known")
plt.hist(negative, bins=100, range=(0,max_mag), density=True,
         color="b", histtype="step", label="Negative")
plt.hist(unknown, bins=100, range=(0,max_mag), density=True,
         color="r", histtype="step", label="Unknown")

plt.legend()
plt.xlabel("Deep Feature Magnitude")
plt.ylabel("Density")
plt.title("Feature Magnitudes: Known vs Negative vs Unknown")
plt.show()


### Task 10: Classification Evaluation

For a fixed threshold of $\zeta=0.98$, compute CCR and FPR for the test set.
A well-trained network can achieve a CCR of > 90% for an FPR < 10%.
You might need to vary the threshold.

In [None]:
import torch
import torch.nn.functional as F

zeta = 0.98

correct, known = 0, 0
false, unknown = 0, 0

network.eval()
with torch.no_grad():
    for x, t in test_loader:
        x, t = x.to(device), t.to(device)

        # forward pass
        logits, deep_features = network(x)
        probs = F.softmax(logits, dim=1)

        # predicted class and confidence
        max_conf, pred_class = torch.max(probs, dim=1)

        # split known vs unknown (same trick as before)
        is_known = (t.max(dim=1).values == 1.0)
        is_unknown = ~is_known

        # --- Known samples ---
        if is_known.any():
            targets_idx = torch.argmax(t[is_known], dim=1)
            preds = pred_class[is_known]
            confs = max_conf[is_known]

            # correct if predicted class = target and confidence >= zeta
            correct += torch.sum((preds == targets_idx) & (confs >= zeta)).item()
            known += len(targets_idx)

        # --- Unknown samples ---
        if is_unknown.any():
            confs = max_conf[is_unknown]

            # false positives = accepted as known if confidence >= zeta
            false += torch.sum(confs >= zeta).item()
            unknown += len(confs)

# --- Report results ---
print(f"CCR: {correct} of {known} = {correct/known*100:.2f}%")
print(f"FPR: {false} of {unknown} = {false/unknown*100:.2f}%")
