# Laboratory #4: Adversarial Learning and OOD Detection

In this laboratory session we will develop a methodology for detecting OOD samples and measuring the quality of OOD detection. We will also experiment with incorporating adversarial examples during training to render models more robust to adversarial attacks.

---
## Exercise 1: OOD Detection and Performance Evaluation
In this first exercise you will build a simple OOD detection pipeline and implement some metrics to evaluate its performance.

### Exercise 1.1: Build a simple OOD detection pipeline

Implement an OOD detection pipeline (like in the Flipped Activity notebook) using an ID and an OOD dataset of your choice. Some options:

+ CIFAR-10 (ID), Subset of CIFAR-100 (OOD). You will need to wrap CIFAR-100 in some way to select a subset of classes that are *not* in CIFAR-10 (see `torch.utils.data.Subset`).
+ Labeled Faces in the Wild (ID), CIFAR-10 or FakeData (OOD). The LfW dataset is available in Scikit-learn (see `sklearn.datasets.fetch_lfw_people`).
+ Something else, but if using images keep the images reasonably small!

In this exercise your *OOD Detector* should produce a score representing how "out of distribution" a test sample is. We will implement some metrics in the next exercise, but for now use the techniques from the flipped activity notebook to judge how well OOD scoring is working (i.e. histograms).


**Note**: Make sure you make a validation split of your ID dataset for testing.
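
One possible way to build the CIFAR-100 subset mentioned above is to filter sample indices by class name and pass them to `torch.utils.data.Subset`. The sketch below uses plain Python lists so that it runs without downloading anything. The helper name `indices_for_classes` and the toy labels are illustrative assumptions, and which CIFAR-100 classes count as "not in CIFAR-10" remains a judgment call.

```python
def indices_for_classes(targets, classes, keep_names):
    """Return the indices of all samples whose class name is in keep_names.

    targets: per-sample integer labels (e.g. a dataset's `.targets` list)
    classes: list mapping label id -> class name (e.g. a dataset's `.classes`)
    """
    keep_ids = {classes.index(name) for name in keep_names}
    return [i for i, t in enumerate(targets) if t in keep_ids]


# Toy stand-in for a labelled dataset (hypothetical values).
toy_classes = ["apple", "bus", "castle", "rose"]
toy_targets = [0, 1, 2, 3, 1, 0, 2]

ood_indices = indices_for_classes(toy_targets, toy_classes, ["apple", "rose"])
print(ood_indices)  # [0, 3, 5]

# With torchvision, the same indices could then be wrapped as:
#   ood_subset = torch.utils.data.Subset(cifar100_testset, ood_indices)
```

Filtering by name rather than by numeric id keeps the selection readable when the list of kept classes is revised.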

### NOTE: For this lab, some parts are already done, so I will not repeat them. I will only comment on my changes and additions.

In [1]:
import torch
import torchvision
from torchvision.datasets import FakeData
from torchvision import transforms
import matplotlib.pyplot as plt
from torch import nn
from torch import optim
import numpy as np
import random
import torch.nn.functional as F
from tqdm.notebook import tqdm
from torch.utils.data import Subset
from sklearn import metrics as skm
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from sklearn.metrics import RocCurveDisplay, PrecisionRecallDisplay

In [2]:
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
torch.cuda.is_available()

True

In [4]:

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 128

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=8, persistent_workers= True)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=8, persistent_workers= True)

fakeset = FakeData(size=len(testset), image_size=(3, 32, 32), transform=transform)
fakeloader = torch.utils.data.DataLoader(fakeset, batch_size=batch_size, shuffle=False, num_workers=8, persistent_workers= True)




#### Goal
To split a single, existing dataset into two separate, non-overlapping sets (one for training (90%) and one for validation (10%)) and then to create data loaders for both.

#### Method
1.  **Calculate Split Sizes**: The code first determines the exact number of samples that will go into the training set (90% of the total) and the validation set (the remaining 10%).

2.  **Define Indices**: It then creates two simple lists of indices. `train_indices` contains the indices for the first 90% of the data, and `val_indices` contains the indices for the last 10%.

3.  **Create `Subset` objects**: The code creates `Subset` objects by passing the `trainset` and the selected indices. A `Subset` is a lightweight wrapper that keeps a reference to the original `trainset` and stores only the indices, so the data is never duplicated. This makes it highly memory-efficient.
4.  **Create `DataLoader`s**: Finally, two `DataLoader` objects are created from these new `Subset`s. These loaders will handle the process of fetching data in batches for the training loop.
    *   **`trainloader`**: Has `shuffle=True`. This is important for training, as it ensures the model sees the data in a different random order each epoch, which helps it generalize better.
    *   **`valloader`**: Has `shuffle=False`. This is a best practice for evaluation. The validation data should always be presented in the same order so that you can get consistent, comparable performance metrics from one epoch to the next.
    *   `num_workers` and `persistent_workers` are performance optimizations that use multiple CPU processes to load data in the background, preventing data loading from becoming a bottleneck during training.

#### Result
The code produces two `DataLoader` objects:
*   `trainloader`: Ready to be used in the training loop, it will serve shuffled batches of data drawn from the first 90% of the original dataset.
*   `valloader`: Ready to be used in the evaluation loop, it will serve unshuffled batches of data drawn from the last 10% of the original dataset.
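
The split described here takes the first 90% and the last 10% of the dataset in storage order. If a dataset happened to be stored in some meaningful order (for example, grouped by class), a seeded random permutation of the indices would be a safer choice. A minimal sketch, assuming only NumPy; `split_indices` is a hypothetical helper, and its outputs could be passed to `Subset` in the same way as `train_indices` and `val_indices`:

```python
import numpy as np


def split_indices(n_samples, val_fraction=0.1, seed=42):
    """Shuffle the indices 0..n_samples-1 and split them into train/val lists."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_samples)
    n_val = int(round(val_fraction * n_samples))
    val_idx = perm[:n_val]
    train_idx = perm[n_val:]
    return train_idx.tolist(), val_idx.tolist()


train_idx, val_idx = split_indices(50_000)
print(len(train_idx), len(val_idx))  # 45000 5000
```

Fixing the seed keeps the split reproducible across runs, which matters when validation losses from different experiments are compared.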

In [None]:


train_size = int(0.9 * len(trainset))  # 90% train
val_size = len(trainset) - train_size  # 10% validation

train_indices = list(range(0, train_size))
val_indices = list(range(train_size, len(trainset)))
train_subset = Subset(trainset, train_indices)
val_subset = Subset(trainset, val_indices)

trainloader = torch.utils.data.DataLoader(train_subset, batch_size=batch_size,
                                          shuffle=True, num_workers=8, persistent_workers=True)

valloader = torch.utils.data.DataLoader(val_subset, batch_size=batch_size,
                                        shuffle=False, num_workers=8, persistent_workers=True)


In [None]:
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1)  # downsample
        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)  # downsample
        
        self.flatten_dim = 256 * 8 * 8  # assuming input is 32x32
        self.fc1 = nn.Linear(self.flatten_dim, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))   # -> 32x32x32
        x = F.relu(self.conv2(x))   # -> 32x32x64
        x = F.relu(self.conv3(x))   # -> 32x32x128
        x = F.relu(self.conv4(x))   # -> 16x16x128
        x = F.relu(self.conv5(x))   # -> 8x8x256
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

#### Goal
To train the CNN model while monitoring its performance on a validation set, and to save the model weights from the epoch that achieves the lowest validation loss. Keeping the best checkpoint instead of the final one limits the effect of overfitting in later epochs.

#### Method

1.  **Initialization**: Before we call the function, several key components are set up:
    *   **Hyperparameters**: The number of `epochs`, the `loss_fn` (Cross-Entropy Loss for classification), and the `optimizer` are defined. The cell below uses SGD with weight decay; an AdamW configuration is left commented out.
    *   **Inside `train`**: Variables like `best_val_loss` (initialized to infinity), `best_model_wts`, and `best_epoch` are created to keep track of the best model found so far; the `path` argument says where to save it.

2.  **Outer Loop (Epochs)**: The main `for` loop iterates through the specified number of epochs, with a `tqdm` progress bar to provide a visual of the overall training progress.

3.  **Training Phase (per epoch)**:
    *   The code loops through the `trainloader`. For each batch of data, it performs the standard training steps:
        1.  Move data to the correct `device` (GPU).
        2.  `optimizer.zero_grad()`: Clear any gradients from the previous step.
        3.  `yp = model(x)`: Perform a forward pass to get the model's predictions.
        4.  `l = loss_fn(yp, y)`: Calculate the loss between the predictions and the true labels.
        5.  `l.backward()`: Perform backpropagation to compute the gradients.
        6.  `optimizer.step()`: Update the model's weights using the computed gradients.
    *   The `train_loss` for the epoch is calculated as the average of the batch losses.

4.  **Validation Phase (per epoch)**:
    *   `with torch.no_grad()`: This context manager disables gradient calculations. It reduces memory consumption and speeds up the process, as we are not learning in this phase.
    *   The code loops through the `valloader`, calculating the `val_loss` and the `val_acc` (validation accuracy) over all batches. No `optimizer.step()` is called here.

5.  **Best Model Checkpointing**:
    *   After each epoch, the current `val_loss` is compared to the `best_val_loss` found so far.
    *   If the current loss is lower, it means we've found a new best model. The script then:
        1.  Updates `best_val_loss` and records the current `best_epoch`.
        2.  Saves a copy of the model's current weights (`model.state_dict()`) to the specified `path`. This is the **checkpointing** step.

6.  **Logging and Final Loading**:
    *   The `tqdm` progress bar is updated with all the relevant metrics (`train_loss`, `val_loss`, `val_acc`, and the best scores so far).
    *   After the entire training loop finishes, the script loads the saved best weights back into the `model` object. This ensures that the model available for use after training is the best-performing one, not just the one from the final epoch.
#### Result
*   **A Saved Checkpoint**: A file named `cifar10_best_P.pth` is created on disk. This file contains the weights of the model at the epoch where it achieved the lowest validation loss.
*   **A Trained Model**: The `model` object in the Python session is loaded with these best weights, making it immediately ready for inference on a test set.
*   **Log**: The `tqdm` progress bar reports the latest training and validation metrics together with the best validation loss, epoch, and accuracy found so far.

In [None]:

#1e-5 001
def train(model, trainloader, valloader, loss_fn, optimizer, epochs, path):
        
    best_val_loss = float("inf") 
    best_model_wts = None
    best_epoch = 0  
    acc_of_best = 0
    
    epoch_pbar = tqdm(range(epochs), desc="Training Progress")
    
    for e in epoch_pbar:
        model.train()
        running_loss = 0
        for data in trainloader:
            x, y = data[0].to(device), data[1].to(device)
            optimizer.zero_grad()
            yp = model(x)
            l = loss_fn(yp, y)
    
            l.backward()
            optimizer.step()
            running_loss += l.item()
        train_loss = running_loss / len(trainloader)
        
        
        model.eval()
        val_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for data in valloader:
                x, y = data[0].to(device), data[1].to(device)
                yp = model(x)
                l = loss_fn(yp, y)
                _, predicted = yp.max(1)
                val_loss += l.item()
                
                total += y.size(0)
                correct += predicted.eq(y).sum().item()
        
        #    scheduler.step()
        val_loss /= len(valloader)
        val_acc = 100. * correct / total
        
        
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = e + 1 
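            # Clone every tensor: dict.copy() is shallow and would keep tracking the live weights.
            best_model_wts = {k: v.detach().clone() for k, v in model.state_dict().items()}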
            torch.save(best_model_wts, path)
            acc_of_best = val_acc
        
        metrics = {
            'train_loss': f"{train_loss:.4f}", 
            'val_loss': f"{val_loss:.4f}", 
            'val_acc': f"{val_acc:.2f}%",
        }
        
        if best_epoch > 0:
            metrics['best_loss'] = f"{best_val_loss:.4f}"
            metrics['best_epoch'] = best_epoch
            metrics['best_acc'] = f"{acc_of_best:.2f}%" 
        epoch_pbar.set_postfix(metrics)

   
    if best_model_wts is not None:
        model.load_state_dict(best_model_wts)
        print(f"\nBest model from epoch {best_epoch} loaded (Val Loss: {best_val_loss:.4f}).")
    return model
        


model = CNN().to(device)
loss_fn = nn.CrossEntropyLoss()
#epochs = 50
#optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-5)
epochs = 100
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay = 5e-4)
#scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0)
path =  './cifar10_best_P.pth'
base_loss = nn.CrossEntropyLoss()
model = train(model, trainloader, valloader, loss_fn, optimizer, epochs, path)

In [None]:
model = CNN().to(device)
model.load_state_dict(torch.load(path))
model.eval()
y_gt, y_pred = [], []
with torch.no_grad():  # inference only, no gradients needed
    for x, y in testloader:
        x, y = x.to(device), y.to(device)
        yp = model(x)
        y_pred.append(yp.argmax(1))
        y_gt.append(y)

In [None]:
y_pred_t = torch.cat(y_pred)
y_gt_t = torch.cat(y_gt)

accuracy = (y_pred_t == y_gt_t).float().mean().item()
print(f'Accuracy: {accuracy:.4f}')
cmn = skm.confusion_matrix(y_gt_t.cpu(), y_pred_t.cpu(), normalize='true')
cmn = (100*cmn).astype(np.int32)
# sklearn.metrics was imported as `skm`; the bare name `metrics` is not defined here.
disp = skm.ConfusionMatrixDisplay(cmn, display_labels=testset.classes)
disp.plot()
plt.show()


#### Goal
To test the effectiveness of using the maximum logit score as a simple confidence metric to distinguish between in-distribution (`testloader`) and out-of-distribution (`fakeloader`) data. This will serve as the initial benchmark to improve upon.

#### Method
1.  **The Scoring Function (`compute_scores`)**: A function is defined to calculate a confidence score for each image in a dataset.
    *   It takes a trained `model` and a `data_loader` as input.
    *   Inside the function, it iterates through the data and passes each batch through the model to get the raw output logits.
    *   For each image, the score is calculated as `logit.max(dim=1)[0]`, i.e. the **maximum value of the logit vector**, which belongs to the class the model considers most likely. This is usually called the max-logit score. It is closely related to, but not the same as, the Maximum Softmax Probability (MSP) baseline: `softmax` preserves the order of the logits within a sample, so both scores agree on the predicted class, but a larger maximum logit does not always mean a larger softmax probability, so the two scores can rank samples differently.

2.  **Applying the Score Function**:
    *   The best-performing model (saved previously at `path`) is loaded.
    *   The `compute_scores` function is called twice:
        *   Once on the `testloader` to get the confidence scores for **in-distribution (ID)** data.
        *   Once on the `fakeloader` to get the confidence scores for **out-of-distribution (OOD)** data.

3.  **Visualization**:
    *   The two sets of scores (`scores_test` and `scores_fake`) are sorted independently.
    *   They are then plotted on the same graph. This allows for a direct visual comparison of the entire distribution of confidence scores for ID vs. OOD data.

#### Result
*   The output is a plot containing two curves. The x-axis can be seen as the percentile of the data, and the y-axis is the confidence score.
*   **Ideal Outcome**: If this method works well, the blue curve (`test`) should be consistently **above** the orange curve (`fake`). This would visually confirm that the model is indeed more confident on in-distribution data than on out-of-distribution data.
*   **The Baseline Reality**: In practice, while there will likely be a separation, the two curves will also have a significant **overlap**. This overlap represents the failure region where this simple detector cannot distinguish between ID and OOD samples based on confidence alone. This imperfect separation is why this method is considered a **baseline**. The rest of the lab will introduce more advanced techniques designed to increase the separation between these two curves and create a more reliable OOD detector.
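
The max-logit score and the maximum softmax probability (MSP) always agree on the predicted class, but they need not order samples in the same way. The small NumPy sketch below illustrates this with two made-up logit vectors; the values are assumptions chosen for the illustration, not outputs of the trained CNN.

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax with the usual max-subtraction for numerical stability."""
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


logits = np.array([
    [10.0, 9.0, 0.0],  # large logits, but two classes compete
    [5.0, 0.0, 0.0],   # smaller logits, one clear winner
])

max_logit = logits.max(axis=1)
msp = softmax(logits).max(axis=1)

print(max_logit)  # the first sample has the larger max logit
print(msp)        # roughly 0.73 vs 0.99: the second sample has the larger MSP
```

Which of the two scores separates ID from OOD data better is an empirical question, so it can be worth evaluating both with the same metrics.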

**NOTE**: From now on, I’ll use my trained CNN. If a different one is needed, change the path.

In [None]:
def compute_scores(model, data_loader, device, **kwargs):
    """Return the max-logit score of every sample (higher = more ID-like)."""
    model.eval()
    scores = []
    with torch.no_grad():
        for x, _ in data_loader:
            logit = model(x.to(device))
            scores.append(logit.max(dim=1)[0])
    return torch.cat(scores)


path = './cifar10_best_P.pth'
model.load_state_dict(torch.load(path))
scores_test = compute_scores(model, testloader, device)
scores_fake = compute_scores(model, fakeloader, device)


fig, axes = plt.subplots(1, 2, figsize=(12, 5))


axes[0].plot(sorted(scores_test.cpu()), label='test', color='C0')
axes[0].plot(sorted(scores_fake.cpu()), label='fake', color='C1')
axes[0].set_title("Sorted Score Curves")
axes[0].set_xlabel("Samples (sorted index)")
axes[0].set_ylabel("Score")
axes[0].legend()

axes[1].hist(scores_test.cpu(), density=True, alpha=0.5, bins=25, label='test', color='C0')
axes[1].hist(scores_fake.cpu(), density=True, alpha=0.5, bins=25, label='fake', color='C1')
axes[1].set_title("Score Distributions (Histogram)")
axes[1].set_xlabel("Score")
axes[1].set_ylabel("Density")
axes[1].legend()

plt.tight_layout()
plt.show()

In [None]:
class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        # Input size: [batch, 3, 32, 32]
        # Output size: [batch, 3, 32, 32]
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 12, 4, stride=2, padding=1),            # [batch, 12, 16, 16]
            nn.ReLU(),
            nn.Conv2d(12, 24, 4, stride=2, padding=1),           # [batch, 24, 8, 8]
            nn.ReLU(),
            nn.Conv2d(24, 48, 4, stride=2, padding=1),           # [batch, 48, 4, 4]
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(48, 24, 4, stride=2, padding=1),  # [batch, 24, 8, 8]
            nn.ReLU(),
            nn.ConvTranspose2d(24, 12, 4, stride=2, padding=1),  # [batch, 12, 16, 16]
            nn.ReLU(),
            nn.ConvTranspose2d(12, 3, 4, stride=2, padding=1),   # [batch, 3, 32, 32]
            # NOTE: Sigmoid outputs lie in [0, 1], while inputs normalized with
            # mean 0.5 and std 0.5 lie in [-1, 1]; nn.Tanh() would match that range.
            nn.Sigmoid(),
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded

#### Goal
To train an Autoencoder to accurately reconstruct images from the in-distribution training set (CIFAR-10). The objective is to save the model that achieves the lowest reconstruction error on the validation set, creating a specialist model whose reconstruction quality can be used to detect anomalies.

#### Method
1.  **Model and Loss Function**: An `Autoencoder` model is initialized. The chosen loss function is `nn.MSELoss` (Mean Squared Error). This is the critical difference from a classification task. Instead of comparing model outputs to class labels, MSE directly measures the average squared difference between the pixel values of the original input image (`x`) and the reconstructed output image (`x_rec`). The optimizer's goal is to minimize this pixel-wise error.

2.  **Training Loop**: The script iterates for a set number of epochs. In each epoch:
    *   **Training Phase**: The model processes batches from the `trainloader`. For each batch, it performs a forward pass to get the reconstructed image (`x_rec`), calculates the `MSELoss` between the original and the reconstruction, and uses backpropagation to update the model weights.
    *   **Validation Phase**: The model then evaluates its performance on the `valloader`. It calculates the average reconstruction loss on this unseen portion of the in-distribution data.

3.  **Checkpointing**: Just like in the classifier training, the script keeps track of the `best_loss` on the validation set. If the model achieves a new lowest validation loss in the current epoch, it saves the model's weights (`state_dict`) to disk. This ensures the final model is the one that generalizes best to unseen ID data.
#### Result
*   The script produces a trained `Autoencoder` model, with the best-performing weights saved to `cifar10_ae.pth`.
*   This model is now a specialist, highly optimized to compress and decompress images from the CIFAR-10 dataset with minimal reconstruction error.
*   The autoencoder did not overfit in this run. A likely explanation is that the **information bottleneck** regularizes the model, and that the task itself is harder to memorize, since it changes from **classification to pixel-wise reconstruction**.
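
To use the autoencoder as an OOD detector, each image needs a single score derived from its reconstruction error. A minimal NumPy sketch of that step, assuming image batches of shape `[batch, channels, height, width]`; the function name `reconstruction_scores` and the synthetic "reconstructions" are illustrative assumptions, not the trained model's output:

```python
import numpy as np


def reconstruction_scores(x, x_rec):
    """Negative per-sample MSE, so that a higher score means 'more in-distribution'."""
    per_sample_mse = ((x - x_rec) ** 2).mean(axis=(1, 2, 3))
    return -per_sample_mse


rng = np.random.default_rng(0)
x = rng.uniform(-1, 1, size=(4, 3, 32, 32))

# Stand-ins for a good (ID-like) and a poor (OOD-like) reconstruction.
good_rec = x + 0.01 * rng.standard_normal(x.shape)
bad_rec = x + 0.50 * rng.standard_normal(x.shape)

s_good = reconstruction_scores(x, good_rec)
s_bad = reconstruction_scores(x, bad_rec)
print(s_good.mean(), s_bad.mean())  # the good reconstructions score closer to 0
```

Negating the error keeps the convention used for the classifier score: larger values indicate in-distribution samples, so the same evaluation code can handle both detectors.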

In [None]:
model_ae = Autoencoder().to(device)
mse_loss = nn.MSELoss()
optimizer = optim.Adam(model_ae.parameters(), lr=0.001)
epochs = 100
best_loss = float('inf')
path = './cifar10_ae.pth'
val_loss = 0
epoch = 0 
ae_bar = tqdm(range(epochs), desc="AE Training Progress")

for e in ae_bar:
    running_loss = 0
    for data in trainloader:
        x, y = data
        x, y = x.to(device), y.to(device)
        
        z, x_rec = model_ae(x)
        l = mse_loss(x, x_rec)
        
        optimizer.zero_grad()
        l.backward()
        optimizer.step()
        running_loss += l.item()
    val_loss = 0

    with torch.no_grad():  # validation only: no gradients needed
        for data in valloader:
            x, y = data
            x, y = x.to(device), y.to(device)
            z, x_rec = model_ae(x)
            l = mse_loss(x, x_rec)
            val_loss += l.item()
    val_loss /= len(valloader)
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(model_ae.state_dict(), path)
        epoch = e 
    ae_bar.set_postfix({'train_loss': f"{running_loss/len(trainloader):.4f}", 'val_loss': f"{val_loss:.4f}", 'best_loss': f"{best_loss:.4f}", 'best_epoch': f"{epoch}"}, )
model_ae.load_state_dict(torch.load(path))


In [None]:
path = './cifar10_ae.pth'
loss = nn.MSELoss(reduction='none') 
model_ae = Autoencoder().to(device)
model_ae.load_state_dict(torch.load(path))

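def compute_scores_ae(model_ae, data_loader, device, **kwargs):
    """Return the negative per-sample reconstruction MSE (higher = more ID-like)."""
    model_ae.eval()
    scores = []
    with torch.no_grad():
        for x, _ in data_loader:
            x = x.to(device)
            _, xr = model_ae(x)
            l = loss(x, xr)              # element-wise squared error (reduction='none')
            score = l.mean([1, 2, 3])    # average over channels and pixels
            scores.append(-score)        # negate so that ID samples score higher
    return torch.cat(scores)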

scores_fake_ae = compute_scores_ae(model_ae, fakeloader, device)
scores_test_ae = compute_scores_ae(model_ae, testloader, device)



In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 5))


axes[0].plot(sorted(scores_test_ae.cpu()), label='test', color='C0')
axes[0].plot(sorted(scores_fake_ae.cpu()), label='fake', color='C1')
axes[0].set_title("Sorted Score Curves")
axes[0].set_xlabel("Samples (sorted index)")
axes[0].set_ylabel("Score")
axes[0].legend()

axes[1].hist(scores_test_ae.cpu(), density=True, alpha=0.5, bins=25, label='test', color='C0')
axes[1].hist(scores_fake_ae.cpu(), density=True, alpha=0.5, bins=25, label='fake', color='C1')
axes[1].set_title("Score Distributions (Histogram)")
axes[1].set_xlabel("Score")
axes[1].set_ylabel("Density")
axes[1].legend()

plt.tight_layout()
plt.show()

### Exercise 1.2: Measure your OOD detection performance

There are several metrics used to evaluate OOD detection performance. We will concentrate on two threshold-free approaches: the area under the Receiver Operating Characteristic (ROC) curve for ID classification, and the area under the Precision-Recall curve for *both* ID and OOD scoring. See [the ODIN paper](https://arxiv.org/pdf/1706.02690.pdf) section 4.3 for a description of OOD metrics.

Use the functions in `sklearn.metrics` to produce ROC and PR curves for your OOD detector. Some useful functions:

+ [`sklearn.metrics.RocCurveDisplay.from_predictions`](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.RocCurveDisplay.html)
+ [`sklearn.metrics.PrecisionRecallDisplay`](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.PrecisionRecallDisplay.html)



#### Goal
To create a single, reusable function that calculates a suite of standard OOD performance metrics: **AUROC**, **AUPR**, and **FPR at 95% TPR**. These metrics are taken from **[the ODIN paper](https://arxiv.org/pdf/1706.02690.pdf)** and provide a comprehensive way to measure how well a given scoring method can separate ID and OOD data.

#### Method

The function's methodology is based on reframing the OOD detection problem as a **binary classification task**:

1.  **Score Calculation**: It first calls the provided `score_fn` (e.g., `compute_scores`) to get a single numerical score for every sample in both the ID and OOD data loaders.

2.  **Ground Truth Setup**: It then creates the ground truth labels (`y_true`). By convention, in-distribution samples are the "positive" class, so they are assigned a label of `1`. Out-of-distribution samples are the "negative" class, assigned a label of `0`.

3.  **Metric Calculation**: With the scores (`y_score`) and true labels (`y_true`), it uses `sklearn.metrics` to compute several key metrics:
    *   **AUROC (Area Under the Receiver Operating Characteristic Curve)**: Measures the overall ability of the score to separate the two distributions. An AUROC of 1.0 means perfect separation (all ID scores are higher than all OOD scores), while 0.5 means the score is no better than random chance.
    *   **AUPR (Area Under the Precision-Recall Curve)**: Like AUROC, this metric summarizes the detector's performance across all thresholds. It is particularly informative when there is a large imbalance between the number of ID and OOD samples. A score of 1.0 is perfect.
    *   **FPR@95TPR (False Positive Rate at 95% True Positive Rate)**: Measures the fraction of out-of-distribution samples that are mistakenly accepted when the detector is tuned to correctly identify 95% of in-distribution samples. A lower value means better performance.
    *   **Detection Error**: Another point-based metric, calculated at the same 95% TPR threshold. It is the average of the two types of errors at that threshold: the false positive rate (`fpr_at_95`) and the false negative rate (`1 - tpr`). A lower value is better.
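
One way to build intuition for AUROC is its pairwise interpretation: it equals the probability that a randomly chosen ID sample receives a higher score than a randomly chosen OOD sample, with ties counted as one half. The sketch below computes it directly from that definition. The name `auroc_pairwise` is a hypothetical helper, and the all-pairs comparison is only practical for small arrays; `sklearn.metrics` remains the sensible choice for real evaluations.

```python
import numpy as np


def auroc_pairwise(id_scores, ood_scores):
    """AUROC as P(ID score > OOD score) + 0.5 * P(tie), over all ID/OOD pairs."""
    id_col = np.asarray(id_scores, dtype=float)[:, None]
    ood_row = np.asarray(ood_scores, dtype=float)[None, :]
    wins = (id_col > ood_row).mean()
    ties = (id_col == ood_row).mean()
    return wins + 0.5 * ties


print(auroc_pairwise([3.0, 2.0, 1.5], [1.0, 0.5]))  # 1.0  (perfect separation)
print(auroc_pairwise([3.0, 1.0], [2.0, 0.0]))       # 0.75 (3 of 4 pairs ordered correctly)
```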

#### Result
The function returns a **dictionary** containing all the calculated metrics for a given OOD detection experiment. This dictionary can be easily logged, compared across different scoring methods, or used for further analysis.
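
The two point-based entries of that dictionary can also be derived directly from a threshold on the scores, which makes their meaning explicit. A minimal NumPy sketch under the convention used here (higher score means more ID-like); `fpr_at_tpr` is a hypothetical helper that picks the threshold from a quantile of the ID scores, which is a slightly different procedure from selecting the nearest point on the ROC curve.

```python
import numpy as np


def fpr_at_tpr(id_scores, ood_scores, target_tpr=0.95):
    """Return (FPR, achieved TPR, detection error) at the threshold that
    accepts roughly `target_tpr` of the ID samples."""
    id_scores = np.asarray(id_scores, dtype=float)
    ood_scores = np.asarray(ood_scores, dtype=float)

    # Samples with score >= threshold are accepted as in-distribution.
    threshold = np.quantile(id_scores, 1.0 - target_tpr)
    tpr = (id_scores >= threshold).mean()   # ID samples correctly accepted
    fpr = (ood_scores >= threshold).mean()  # OOD samples wrongly accepted
    detection_error = 0.5 * (1.0 - tpr) + 0.5 * fpr
    return fpr, tpr, detection_error


id_scores = np.arange(1, 101)   # toy ID scores 1..100
ood_scores = np.arange(0, 20)   # toy OOD scores 0..19
fpr, tpr, err = fpr_at_tpr(id_scores, ood_scores)
print(fpr, tpr, err)  # about 0.7, 0.95 and 0.375
```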


In [None]:


def eval_OOD(model, idloader, oodloader, device, score_fn=compute_scores, T=100, eps=0.01):
    # Extra keyword arguments (T, eps, filter_correct_only) are forwarded to score_fn;
    # scorers that do not need them absorb them through **kwargs.
    id_scores = score_fn(model, idloader, device, T=T, eps=eps, filter_correct_only=True).cpu().numpy()
    ood_scores = score_fn(model, oodloader, device, T=T, eps=eps, filter_correct_only=False).cpu().numpy()

    # Use the same number of ID and OOD samples so that the two classes are balanced.
    num_samples = min(len(id_scores), len(ood_scores))
    id_scores = id_scores[:num_samples]
    ood_scores = ood_scores[:num_samples]

    # ID samples are the positive class (1), OOD samples the negative class (0).
    y_true = np.concatenate([np.ones_like(id_scores), np.zeros_like(ood_scores)], axis=0)
    y_score = np.concatenate([id_scores, ood_scores], axis=0)
    label = "FakeDataset-OOD"

    fpr, tpr, thresholds = skm.roc_curve(y_true, y_score)
    auroc = skm.auc(fpr, tpr)
    precision, recall, _ = skm.precision_recall_curve(y_true, y_score)
    aupr = skm.auc(recall, precision)

    # Operating point on the ROC curve whose TPR is closest to 95%.
    target_tpr = 0.95
    idx = int(np.abs(tpr - target_tpr).argmin())
    fpr_at_95 = float(fpr[idx])
    detection_error = 0.5 * (1 - target_tpr) + 0.5 * fpr_at_95
    
   
    
    return {
    "AUROC": auroc,
    "AUPR": aupr,
    "FPR@95TPR": fpr_at_95,
    "DetectionError@95TPR": detection_error,
    "roc_curve": (fpr, tpr),
    "pr_curve": (recall, precision),
    "label": label
    }
    

In [None]:
def plot_curve(metrics_dict, title_prefix="",):
    auc_digits = 4
    fpr, tpr = metrics_dict["roc_curve"]
    recall, precision = metrics_dict["pr_curve"]

    plt.figure(figsize=(12, 5))
    

    plt.subplot(1, 2, 1)
    disp = RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=metrics_dict["AUROC"])
    disp.plot(ax=plt.gca())
    plt.title(f"{title_prefix} ROC Curve")
    
    new_label = [ f"AUROC ={metrics_dict['AUROC']:.{auc_digits}f}" ]
    plt.legend(new_label, loc="lower right")
    

    plt.subplot(1, 2, 2)
    disp_pr = PrecisionRecallDisplay(precision=precision, recall=recall)
    disp_pr.plot(ax=plt.gca())
    plt.title(f"{title_prefix} Precision-Recall Curve")
    
    new_label = [ f"AUPR ={metrics_dict['AUPR']:.{auc_digits}f}" ]
    plt.legend( new_label, loc="lower left")
    
    plt.tight_layout()
    plt.show()


In [None]:
model.load_state_dict(torch.load(f'./cifar10_best_P.pth'))
#model.load_state_dict(torch.load(path))
metric = eval_OOD(model, testloader, fakeloader, device, compute_scores)
print(f"  FPR at 95 TPR: {metric['FPR@95TPR']:.4f} | Detection Error at 95 tpr: {metric['DetectionError@95TPR']:.4f}")
plot_curve(metric, title_prefix="CNN")

In [None]:


metric = eval_OOD(model_ae, testloader, fakeloader,device, compute_scores_ae)
print(f"  FPR at 95 TPR: {metric['FPR@95TPR']:.4f} | Detection Error at 95 tpr: {metric['DetectionError@95TPR']:.4f}")
plot_curve(metric, title_prefix="Autoencoder")



When comparing the CNN and the Autoencoder for out-of-distribution detection, their behaviors differ in important ways. The CNN delivers more stable performance, with smoother and more consistent detection at 95% TPR. The Autoencoder, on the other hand, achieves higher overall discriminative power but concentrates its errors in a narrow region. It therefore performs extremely well in a conservative regime, but its performance drops sharply if the system requires a very high TPR.


## Exercise 2: Enhancing Robustness to Adversarial Attack

In this second exercise we will experiment with enhancing our base model to be (more) robust to adversarial attacks. 

### Exercise 2.1: Implement FGSM and generate adversarial examples

Recall that the Fast Gradient Sign Method (FGSM) perturbs samples in the direction of the sign of the loss gradient with respect to the input $\mathbf{x}$:
$$ \boldsymbol{\eta}(\mathbf{x}) = \varepsilon \, \mathrm{sign}\left(\nabla_{\mathbf{x}} \mathcal{L}(\boldsymbol{\theta}, \mathbf{x}, y)\right) $$
Implement FGSM and generate some *adversarial examples* using your trained ID model. Evaluate these samples qualitatively and quantitatively. Evaluate how strongly the quality of these samples depends on $\varepsilon$.
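
To make the formula concrete, here is a minimal single-step FGSM sketch on a logistic-regression model, where the input gradient has a closed form and no autograd is needed. The weights, the sample, and the helper names (`sigmoid`, `bce_loss`, `fgsm`) are assumptions made for illustration; with a PyTorch model the gradient would instead come from `loss.backward()` on an input that requires gradients.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def bce_loss(w, b, x, y):
    """Binary cross-entropy of a logistic-regression model on one sample."""
    p = sigmoid(x @ w + b)
    return -(y * np.log(p) + (1 - y) * np.log(1 - p))


def fgsm(w, b, x, y, eps):
    """Single-step untargeted FGSM: move x by eps along the sign of dL/dx."""
    grad_x = (sigmoid(x @ w + b) - y) * w  # closed-form gradient of the loss w.r.t. x
    return x + eps * np.sign(grad_x)


# Hypothetical model and sample (true label 1, classified correctly).
w, b = np.array([2.0, -1.0]), 0.0
x, y = np.array([1.0, 0.5]), 1.0

losses = []
for eps in (0.0, 0.1, 0.5, 1.0):
    x_adv = fgsm(w, b, x, y, eps)
    losses.append(bce_loss(w, b, x_adv, y))
    print(f"eps={eps:.1f}  x_adv={x_adv}  p(y=1)={sigmoid(x_adv @ w + b):.3f}")
```

In this toy setting the loss grows with $\varepsilon$ and the prediction flips for the largest value, which is the kind of dependence on $\varepsilon$ the exercise asks you to study on real images.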

In [None]:
class NormalizeInverse(torchvision.transforms.Normalize):

    def __init__(self, mean, std):
        mean = torch.as_tensor(mean)
        std = torch.as_tensor(std)
        std_inv = 1 / (std + 1e-7)
        mean_inv = -mean * std_inv
        super().__init__(mean=mean_inv, std=std_inv)

    def __call__(self, tensor):
        return super().__call__(tensor.clone())

inv = NormalizeInverse((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

In [None]:
for i,c in enumerate(testset.classes):
    print(i, c)

In [None]:
import torch
import torch.nn as nn

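def fgsm_attack(model, testloader, device, sample_id=0, eps=1/255, targeted_attack=True, target_label=None, max_steps=255):
    """Iterative FGSM on one test sample: take steps of size eps until the prediction changes."""
    if targeted_attack and target_label is None:
        raise ValueError('A targeted attack needs a target_label.')
    loss_fn = nn.CrossEntropyLoss()
    model.eval()  # attack the model in inference mode

    # Pick the first batch whose sample `sample_id` is classified correctly
    # and (for targeted attacks) whose label differs from the target.
    x = y = None
    for xb, yb in testloader:
        xc = xb[sample_id].unsqueeze(0).to(device)
        yc = yb[sample_id].unsqueeze(0).to(device)
        with torch.no_grad():
            correct = model(xc).argmax().item() == yc.item()
        if correct and (not targeted_attack or yc.item() != target_label):
            x, y = xc, yc
            break
        print('Classifier is already wrong or target label matches ground truth.')
    if x is None:
        raise RuntimeError('No suitable sample found in testloader.')

    print('Attack!')
    original = x.detach().clone()
    target = torch.tensor([target_label], device=device) if targeted_attack else y
    n = 0
    while True:
        x = x.detach().requires_grad_(True)  # fresh leaf so the graph does not grow
        output = model(x)
        pred = output.argmax().item()

        # Check success before stepping, so x, output and the budget stay consistent.
        if not targeted_attack and pred != y.item():
            print(f'Untargeted attack success! budget: {round(255 * n * eps)}/255')
            break
        if targeted_attack and pred == target.item():
            print(f'Targeted attack ({pred}) success! budget: {round(255 * n * eps)}/255')
            break
        if n >= max_steps:  # safety cap (added) to avoid an endless loop
            print(f'Attack did not succeed within {max_steps} steps.')
            break

        model.zero_grad()
        loss = loss_fn(output, target)
        loss.backward()

        # Targeted: descend the loss towards the target class.
        # Untargeted: ascend the loss on the true class.
        step = eps * torch.sign(x.grad)
        x = x - step if targeted_attack else x + step
        n += 1

    return original, x.detach(), y, output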


# MAIN
if __name__ == "__main__":
    import matplotlib.pyplot as plt
    print(testset.classes)
    original, adv_x, y, output = fgsm_attack(model, testloader, device, sample_id=0, eps=1/255, targeted_attack=True, target_label=testset.classes.index('deer'))

    img = inv(adv_x.squeeze())
    plt.imshow(img.permute(1, 2, 0).detach().cpu())
    plt.title(testset.classes[output.argmax()])
    plt.show()

    diff = adv_x - original
    diff_img = inv(diff[0])
    plt.imshow(diff_img.permute(1, 2, 0).detach().cpu())
    plt.title('Difference')
    plt.show()

    diff_flat = diff.flatten()
    plt.hist(diff_flat.detach().cpu())
    plt.show()

In [None]:
x.shape

In [None]:
diff.squeeze().mean(0).shape
plt.imshow(255*diff.cpu().detach().squeeze().mean(0))
plt.colorbar()
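
To look at the dependence on $\varepsilon$ quantitatively, one option is to sweep over several budgets and record the accuracy on the perturbed inputs. The following is a minimal sketch under stated assumptions: `fgsm_step`, `accuracy_vs_eps`, and the toy linear model are hypothetical stand-ins introduced here for illustration, and in practice the trained CNN and a real test batch would take their place.

```python
import torch
import torch.nn as nn

def fgsm_step(model, x, y, eps):
    """Single-step untargeted FGSM: x + eps * sign(grad_x loss), clipped to [-1, 1]."""
    x_adv = x.detach().clone().requires_grad_(True)
    loss = nn.functional.cross_entropy(model(x_adv), y)
    grad, = torch.autograd.grad(loss, x_adv)
    return (x_adv + eps * grad.sign()).clamp(-1, 1).detach()

def accuracy_vs_eps(model, x, y, eps_values):
    """Return {eps: accuracy on FGSM-perturbed inputs} for one fixed batch."""
    model.eval()
    results = {}
    for eps in eps_values:
        x_adv = fgsm_step(model, x, y, eps) if eps > 0 else x
        with torch.no_grad():
            preds = model(x_adv).argmax(dim=1)
        results[eps] = (preds == y).float().mean().item()
    return results

# Toy stand-in for the trained CNN and a test batch (assumption, for illustration only).
torch.manual_seed(0)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
x = torch.rand(16, 3, 32, 32) * 2 - 1   # inputs in [-1, 1], like the normalized images
y = toy_model(x).argmax(dim=1)          # use the model's own predictions as labels

sweep = accuracy_vs_eps(toy_model, x, y, [0, 1/255, 4/255, 8/255, 16/255])
for eps, acc in sweep.items():
    print(f"eps = {eps * 255:>4.0f}/255 -> accuracy {acc:.3f}")
```

Plotting `sweep` as accuracy against `eps` gives a compact robustness curve that can be compared before and after adversarial training.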

### Exercise 2.2: Augment training with adversarial examples

Use your implementation of FGSM to augment your training dataset with adversarial samples. Ideally, you should implement this data augmentation *on the fly* so that the adversarial samples are always generated using the current model. Evaluate whether the model is more (or less) robust to ID samples using your OOD detection pipeline and metrics you implemented in Exercise 1.

#### Goal
To create a batch of adversarial examples by performing a **single-step gradient ascent** in the input space. The goal is to find the direction that will **maximize the model's loss** as quickly as possible and then take a small step in that direction, hopefully fooling the classifier.

#### Method
The function executes the FGSM algorithm, which can be broken down into these key steps:

1.  **Prepare for Gradient Calculation**:
    *   A detached clone of the input tensor `x` is created, and its `requires_grad` attribute is set to `True`. This step is critical because it designates the input image's pixels, rather than the model's weights, as the variables with respect to which the gradient is computed.

2.  **Calculate the Loss**:
    *   A standard forward pass is performed (`logits = model(x_adv)`), and the loss is calculated against the *true* labels `y`. 

3.  **Find the Direction of Attack**:
    *   `torch.autograd.grad(loss, x_adv, ...)`: This is the core of the attack. It computes the gradient of the `loss` with respect to the input image `x_adv`. This gradient, `∇_x L(θ, x, y)`, is a tensor that has the same shape as the image and points in the direction in the pixel space that will cause the **greatest increase** in the loss.
    *   `torch.sign(...)`: This is the "Sign" part of FGSM. Instead of using the full gradient, the algorithm takes only its sign (`-1` if negative, `+1` if positive). This identifies the general direction of the steepest ascent for each pixel, simplifying the attack and creating a uniform perturbation.

4.  **Apply the Perturbation**:
    *   `x_adv = x_adv + eps * grad_sign`: The original image is modified by adding the signed gradient, scaled by a small constant `eps` (epsilon). Epsilon controls the magnitude of the attack; a larger `eps` creates a more noticeable change in the image but makes the attack more potent. 

5.  **Ensure Image Validity**:
    *   `torch.clamp(x_adv, -1, 1)`: The perturbation might push some pixel values outside their valid range (e.g., below -1 or above 1). This line clips the values to ensure that the resulting `x_adv` is still a valid image tensor.

#### Result
*   The function returns `x_adv`, a tensor representing a batch of **adversarial images**.
*   These images are visually almost indistinguishable from the original images in `x`, as the changes are subtle (controlled by `eps`).
*   However, when this `x_adv` tensor is fed back into the `model`, the model is more likely to misclassify it, and increasingly so as `eps` grows.


In [None]:
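def fgsm_attack_batch(model, x, y, eps=1/255, loss_fn=None):
    if loss_fn is None:
        loss_fn = nn.CrossEntropyLoss()
    # Remember the current mode: calling eval() here would otherwise leave the
    # model in eval mode for the rest of a training loop.
    was_training = model.training
    model.eval()
    # Gradients are taken with respect to a detached copy of the input.
    x_adv = x.detach().clone()
    x_adv.requires_grad_(True)
    logits = model(x_adv)
    loss = loss_fn(logits, y)
    grad_sign = torch.sign(torch.autograd.grad(loss, x_adv, retain_graph=False, create_graph=False)[0])
    # One signed-gradient ascent step, clipped back to the normalized image range.
    x_adv = x_adv + eps * grad_sign
    x_adv = torch.clamp(x_adv, -1, 1)
    model.train(was_training)
    return x_adv.detach()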

#### Goal  
The purpose of this code is to evaluate how well a model performs under normal conditions and when exposed to adversarial attacks. It measures three main things:  
- The model’s *accuracy on clean data*.  
- The model’s *accuracy on adversarially perturbed data*.  
- The *average loss* the model produces during evaluation.  

#### Method  
1. **Preparation**  
   - Tracking variables are initialized for counting correctly classified samples, total samples, adversarial results, and losses.  
   - If no custom loss function is provided, the code uses `CrossEntropyLoss` with reduction set to `"sum"`.  

2. **Main Loop (per batch of data)**  
   - Input images `x` and labels `y` are moved onto the appropriate device (GPU if available, otherwise CPU).  
   - Predictions are generated on **clean inputs** without gradient calculations (`torch.no_grad()`).  
   - The number of correct predictions and the total loss are updated.  

3. **Adversarial Evaluation**  
   - A mask (`correct_mask`) selects only the samples the model classified correctly in the clean setting.  
   - If adversarial attacks are enabled (`eps > 0`) and there are correctly classified samples, these images are perturbed using the **Fast Gradient Sign Method (FGSM)** through the function `fgsm_attack_batch`.  
   - The perturbed images `x_adv` are passed through the model, and predictions are compared with the true labels again to measure adversarial robustness.  

4. **Final Aggregation**  
   - Clean accuracy is computed as the ratio of correct clean predictions to total samples.  
   - Adversarial accuracy is computed as the ratio of correct adversarial predictions to total samples.  
   - Average loss is computed by dividing the accumulated loss by the total number of samples.  


#### Results  
At the end, the function returns three key metrics:  
- **Clean Accuracy (`clean_acc`)**: the percentage of samples correctly classified without modification.  
- **Adversarial Accuracy (`adv_acc`)**: the percentage of samples correctly classified after adversarial perturbation. This gives insight into model robustness.  
- **Average Loss (`avg_loss`)**: the mean loss across all clean samples, representing overall prediction quality.  


In [None]:
def evaluate_model(model, loader, device, eps, loss_fn=None):
    model.eval()
    
    total_clean_correct = 0
    total_adv_correct = 0
    total_samples = 0
    total_loss = 0.0

    if loss_fn is None:
        loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")
    
    for x, y in loader:
        x, y = x.to(device), y.to(device)

     
        with torch.no_grad():
            logits_clean = model(x)
            preds_clean = torch.argmax(logits_clean, dim=1)
        total_clean_correct += (preds_clean == y).sum().item()
        total_loss += loss_fn(logits_clean, y).item()
        total_samples += y.size(0)

        
        correct_mask = preds_clean.eq(y)  
        
        if correct_mask.any() and eps > 0:
            x_to_attack = x[correct_mask]
            y_to_attack = y[correct_mask]

            x_adv = fgsm_attack_batch(model, x_to_attack, y_to_attack, eps=eps)

            with torch.no_grad():
                logits_adv = model(x_adv)
                preds_adv = torch.argmax(logits_adv, dim=1)
            total_adv_correct += (preds_adv == y_to_attack).sum().item()
    
    clean_acc = total_clean_correct / max(total_samples, 1)
    adv_acc = total_adv_correct / max(total_samples, 1)
    avg_loss = total_loss / max(total_samples, 1)
    
    return clean_acc, adv_acc, avg_loss


#### Goal  
This function aims to train a model with adversarial training using the Fast Gradient Sign Method (FGSM). It improves the robustness of the model by using a combined loss on both clean inputs and their adversarially perturbed counterparts. The training also monitors performance on a validation set, including clean accuracy, adversarial accuracy, and out-of-distribution (OOD) detection metrics, while saving the best model based on validation loss.  

#### Method   
- AdamW optimizer is initialized with weight decay, and a cosine annealing learning rate scheduler is set to decay the learning rate smoothly across the total epochs. Cross-entropy loss serves as the criterion.  
- For each training epoch:  
  - The model trains on batches from the training loader. For each batch, adversarial examples are generated using FGSM (`fgsm_attack_batch`) by perturbing inputs slightly in the direction of the gradient to maximize loss.  
  - The loss is calculated as a weighted sum of clean input loss (`loss_clean`) and adversarial input loss (`loss_adv`), controlled by `alpha`.
  - Backpropagation and optimizer step update model weights. Running metrics for loss and accuracy on clean inputs are tracked.  
- After each epoch, the scheduler updates learning rate according to a cosine annealing schedule.  
- Validation is performed using the provided `evaluate_model` function to get clean accuracy, adversarial accuracy, and average loss metrics. Additionally, an OOD evaluation function returns metrics such as AUROC and FPR@95TPR.  
- Metrics are logged to TensorBoard via `SummaryWriter` for visualization.  
- The model with the lowest validation loss is saved for future use.  
- After all epochs, the best saved model is loaded and the training history dictionary with various tracked metric lists is returned.  
 
#### Results  
- The training history contains epoch-wise loss and accuracy on both training and validation sets (clean and adversarial), plus validation loss.  
- Per epoch console output provides detailed information including OOD evaluation metrics (AUROC and FPR@95TPR), clean accuracy, and validation loss to help analyze robustness improvements.  
- The learning rate follows a smooth cosine decay over epochs, which is intended to stabilize convergence.  
- The final model loaded is the checkpoint with the lowest validation loss on clean inputs; adversarial accuracy is tracked but not used for model selection.  
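
As a small illustration of two ingredients described above, the sketch below prints the learning rates produced by `CosineAnnealingLR` next to the closed form $\mathrm{lr}_t = \tfrac{1}{2}\,\mathrm{lr}_0\,(1 + \cos(\pi t / T_{\max}))$ (valid for the default `eta_min=0`), and then evaluates the `alpha`-weighted loss. The dummy parameter and the two loss values are made up for the example; they are not results from this lab.

```python
import math
import torch
from torch import nn, optim

epochs, base_lr, alpha = 5, 0.003, 0.5

# A throwaway parameter so that the optimizer has something to manage.
param = nn.Parameter(torch.zeros(1))
optimizer = optim.AdamW([param], lr=base_lr, weight_decay=5e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

lrs = []
for epoch in range(epochs):
    lrs.append(optimizer.param_groups[0]["lr"])
    optimizer.step()    # placeholder for one epoch of training
    scheduler.step()    # one scheduler step per epoch, as in the training loop

# Closed form with eta_min = 0.
expected = [base_lr * (1 + math.cos(math.pi * t / epochs)) / 2 for t in range(epochs)]
for t, (lr_t, ref) in enumerate(zip(lrs, expected)):
    print(f"epoch {t}: lr = {lr_t:.6f} (closed form {ref:.6f})")

# The training objective mixes the clean and adversarial cross-entropy terms.
loss_clean, loss_adv = torch.tensor(0.40), torch.tensor(1.20)  # made-up values
loss = alpha * loss_clean + (1.0 - alpha) * loss_adv
print(f"combined loss = {loss.item():.2f}")
```

With `alpha=0.5` both terms contribute equally; moving `alpha` towards 1 favors clean accuracy, while moving it towards 0 puts more weight on the adversarial term.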


In [None]:
import math
import os
from datetime import datetime  # needed for the timestamped TensorBoard log directory

def train_with_fgsm(model, trainloader, valloader, device,
                    epochs=10, eps=1/255, alpha=0.5, lr=0.1, wd =5e-4):

    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    criterion = nn.CrossEntropyLoss()
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    history = {"train_loss": [], "train_acc": [], "val_acc": [], "adv_val_acc": [], "val_loss": []}

    best_val =math.inf
    path = './cifar10_best_advtrained.pth'
    best_epoch = 0
    best_auroc = 0 
    logdir = os.path.join("fgsm", f"FGSM_{datetime.now().strftime('%Y%m%d_%H%M%S')}")
    writer = SummaryWriter(log_dir=logdir)

    epoch_pbar = tqdm(range(1, epochs + 1), desc="Training", unit="epoch")

    for epoch in epoch_pbar:
        model.train()
        running_loss, running_acc, n_samples = 0.0, 0.0, 0

        for x, y in trainloader:
            x, y = x.to(device), y.to(device)
            x_adv = fgsm_attack_batch(model, x, y, eps=eps)

            optimizer.zero_grad()
            
            logits_clean = model(x)
            loss_clean = criterion(logits_clean, y)
            
            logits_adv = model(x_adv)
            loss_adv = criterion(logits_adv, y)
            
            loss = alpha *  loss_clean  + (1.0 - alpha) * loss_adv
            loss.backward()
            optimizer.step()

            preds = torch.argmax(logits_clean, dim=1)
            acc = (preds == y).float().mean().item()
            running_loss += loss.item() * x.size(0)
            running_acc += acc * x.size(0)
            n_samples += x.size(0)

        scheduler.step()
        
        train_loss = running_loss / n_samples
        train_acc = running_acc / n_samples
        # Rely on evaluate_model's default sum-reduced loss so that avg_loss is a true per-sample mean.
        val_acc_clean, val_acc_adv, avg_loss = evaluate_model(model, valloader, device, eps)
        metric = eval_OOD(model, valloader, fakeloader, device)
        print(f"{epoch}| AUROC: {metric['AUROC']:.4f} |  FPR@95TPR: {metric['FPR@95TPR']:.4f} | Accuracy: {val_acc_clean:.4f} | Val Loss: {avg_loss:.4f}")
        auroc = metric['AUROC']
        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_acc"].append(val_acc_clean)
        history["adv_val_acc"].append(val_acc_adv)
        history["val_loss"].append(avg_loss)
        
        writer.add_scalar("Loss/train", train_loss, epoch)
        writer.add_scalar("Accuracy/train", train_acc, epoch)
        writer.add_scalar("Accuracy/val_clean", val_acc_clean, epoch)
        writer.add_scalar("Accuracy/val_adv", val_acc_adv, epoch)
        writer.add_scalar("Loss/val", avg_loss, epoch)
        writer.add_scalar("AUROC FakeDataset-OOD",auroc, epoch)
        if avg_loss < best_val:
            best_val = avg_loss
            torch.save(model.state_dict(), path)
            best_epoch = epoch
            best_auroc = auroc 
        epoch_pbar.set_postfix([
           ("train_loss", f"{train_loss:.4f}"),
           ("val_acc", f"{val_acc_clean:.4f}"),
           ("adv_val_acc", f"{val_acc_adv:.4f}"),
            ("val_loss", f"{avg_loss:.4f}"),
           ("best_val", f"{best_val:.4f}"),
           ("best_epoch", f"{best_epoch}"),
           ("best_auroc", f"{best_auroc:.4f}")
        ])

    writer.close()
    print("\nTraining complete.")
    print(f"Loading the best model (validation loss: {best_val:.4f})")
    model.load_state_dict(torch.load(path))
    
    return history

In [None]:





if __name__ == "__main__":


    device ="cuda" if torch.cuda.is_available() else "cpu"
    #location = './cifar10_best_my.pth'
    model = CNN().to(device)
    #model.load_state_dict(torch.load(location))
    acc_clean,accuracy_adv ,_  = evaluate_model(model, testloader, device,eps=8/255)
    print(f'Accuracy on adversarial samples before FGSM training: {accuracy_adv:.4f}')
    metric= eval_OOD(model,testloader, fakeloader, device)
    print(f"AUROC: {metric['AUROC']:.4f} |  FPR@95TPR: {metric['FPR@95TPR']:.4f} | Accuracy: {acc_clean:.4f}")
   # plot_robustness(metric, title_prefix="FakeDataset-OOD device,")

    history = train_with_fgsm(model, trainloader, valloader, device,
                               epochs=30, eps=32/255,
                              alpha=0.5 , lr=0.003, 
                              wd =5e-4)

    acc_clean, accuracy_adv, _ = evaluate_model(model, testloader, device,eps=8/255)
    print(f'Accuracy on adversarial samples after FGSM training: {accuracy_adv:.4f}')
    metric = eval_OOD(model, testloader, fakeloader, device)
    print(f"AUROC: {metric['AUROC']:.4f} |  FPR@95TPR: {metric['FPR@95TPR']:.4f} | Accuracy: {acc_clean:.4f}")
    plot_curve(metric, title_prefix="FakeDataset-OOD")


In [None]:
a, accuracy_adv, b = evaluate_model(model, testloader, device, eps=8/255)
print(f'Accuracy on adversarial samples after FGSM training: {accuracy_adv:.4f}')
print(a, b)
metrics_fake = eval_robustness(model, testloader, fakeloader)
plot_robustness(metrics_fake, title_prefix="FakeDataset-OOD")

In [None]:
scores_id = compute_scores(model, testloader, max_logit).cpu().numpy()
scores_ood  = compute_scores(model, fakeloader, max_logit).cpu().numpy()
plt.hist(scores_id, density=True, alpha=0.5, bins=25)
plt.hist(scores_ood, density=True, alpha=0.5, bins=25)
plt.show()

---
## Exercise 3: Wildcard

You know the drill. Pick *ONE* of the following exercises to complete.

### Exercise 3.1: Implement ODIN for OOD detection
ODIN is a very simple approach, and you can already start experimenting by implementing a temperature hyperparameter in your base model and doing a grid search on $T$ and $\varepsilon$.


In [None]:

device ="cuda" if torch.cuda.is_available() else "cpu"
model = CNN().to(device)
model.load_state_dict(torch.load('./cifar10_best_P.pth'))  # model provided by the professor


#### Goal 
To compute the ODIN OOD detection score by actively amplifying the confidence scores of in-distribution (ID) samples more than out-of-distribution (OOD) samples. The hypothesis is that a small, gradient-based perturbation that aims to increase the model's confidence will have a much larger effect on an ID sample than on an OOD sample, thus creating a wider gap between their final scores. 
#### Method 
The function calculates the ODIN score for each image in a batch by following a multi-step process: 
1. **Gradient Calculation Setup**: Just like in an adversarial attack, the function enables gradient computation with respect to the input image `x` by setting `x.requires_grad = True`.
2. **Part 1: Temperature Scaling (`T`)**:
    * The model's output logits are divided by a temperature value `T` (typically `T > 1`).
    * **Effect**: A high temperature "softens" the softmax distribution, making the probabilities less peaked. This might seem counterintuitive, but the ODIN paper reports that, combined with the perturbation in the next step, it helps separate the softmax scores of ID and OOD samples.
    * The loss is then calculated between these scaled logits and the model's *own prediction* (`y_pred_correct`).
3. **Part 2: Input Pre-processing (`eps`)**:
    * **Gradient Computation**: The function computes the gradient of the loss with respect to the input image `x`. This gradient points in the direction that would *increase* the loss.
    * **Perturbation**: The key step of ODIN is `perturbed_x = x_correct - eps * gradient_sign`. This looks like FGSM but with a critical difference: the gradient is **subtracted**, not added.
    * **Reasoning**: We are performing **gradient descent** on the loss with respect to the input. This means we are subtly modifying the image to push it *further into* the region of its predicted class, away from the decision boundary, with the goal of **increasing the softmax score**. The core hypothesis of ODIN is that this small push has a much larger effect on ID samples than on OOD samples.
4. **Final Score Calculation**:
   * A final forward pass is performed on the newly created `perturbed_x`.
   * The logits from this pass are again scaled by the temperature `T`.
   * The final ODIN score is the maximum value of the softmax output of this **perturbed, temperature-scaled** result.
#### Result 
* The function returns a tensor of ODIN scores, one for each image.
* These scores are not just a passive measure of the model's confidence; they are the result of an **active process** designed to separate ID and OOD samples.
* The expected outcome is that the distribution of these ODIN scores for ID data will be much higher and more clearly separated from the distribution of scores for OOD data when compared to the simple MSP baseline. This should translate directly to superior quantitative metrics (higher AUROC, lower FPR@95TPR).
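
To make the "softening" effect of the temperature concrete, here is a tiny sketch that computes the maximum softmax probability $\max_i \mathrm{softmax}(\mathbf{z}/T)_i$ for several values of `T`. The logits are made up for the example and do not come from the lab's model.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[8.0, 2.0, 1.0, 0.5]])  # made-up logits for one sample

max_probs = []
for T in [1, 10, 100, 1000]:
    probs = F.softmax(logits / T, dim=1)
    max_probs.append(probs.max().item())
    print(f"T = {T:>4}: max softmax = {max_probs[-1]:.4f}")
```

As `T` grows, the maximum probability shrinks towards the uniform value `1 / num_classes`, so the raw scores of all samples are compressed into a narrow range. What matters for OOD detection is the ranking of ID versus OOD scores, which is why `T` is tuned together with `eps`.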

In [None]:
def compute_scores_ODIN(model, dataloader, device, T, eps, filter_correct_only=False):
    model.eval()
    all_scores = []
    for x, y_true in dataloader:
        x, y_true = x.to(device), y_true.to(device)
        # Gradients are taken with respect to the input, not the weights.
        x.requires_grad = True
        logits = model(x)
        y_pred = logits.argmax(dim=1)
        if filter_correct_only:
            # Keep only the samples that the model classifies correctly.
            correct_mask = y_pred.eq(y_true)
        else:
            correct_mask = torch.ones(len(x), dtype=torch.bool, device=x.device)

        if correct_mask.sum() == 0:
            continue

        logits_correct = logits[correct_mask]
        y_pred_correct = y_pred[correct_mask]

        # Temperature-scaled loss against the model's own prediction.
        loss = F.cross_entropy(logits_correct / T, y_pred_correct, reduction='sum')
        grad_full, = torch.autograd.grad(loss, x)
        x_correct = x[correct_mask]
        grad_correct = grad_full[correct_mask]

        with torch.no_grad():
            gradient_sign = grad_correct.sign()
            # Step against the gradient. The division by 0.5 matches the
            # normalization std used for the images, so eps is in pixel units.
            perturbed_x = x_correct - eps * gradient_sign / 0.5

            final_logits = model(perturbed_x)
            prob = F.softmax(final_logits / T, dim=1)
            score = prob.max(dim=1)[0]

            all_scores.append(score.cpu())

    return torch.cat(all_scores)



#### Goal
To create a single, standardized figure that provides a comprehensive visual comparison of two OOD detection methods. The function should display both the **ROC curves** and the **Precision-Recall curves** for each method on shared axes, with clear legends that include the quantitative summary metrics (AUROC and AUPR).

#### Method

The function uses `matplotlib` and the convenient display objects from `sklearn.metrics` to build the visualization step-by-step:

1.  **Figure Setup**: It begins by creating a figure with one row and two columns of subplots (`plt.subplots(1, 2, ...)`). This sets up the side-by-side layout for the ROC and Precision-Recall comparisons.

2.  **Data Extraction**: It unpacks the raw data points (`fpr`, `tpr`, `recall`, `precision`) from the two input metric dictionaries. These raw points are what will be used to draw the curves.

3.  **ROC Curve Plotting (Left Subplot)**:
    *   It uses `sklearn.metrics.RocCurveDisplay`. This is a helper object that takes the `fpr`, `tpr`, and the pre-computed `roc_auc` score.
    *   It plots the curve for the first method on the left subplot (`ax=ax1`) with a solid line.
    *   It then plots the curve for the second method on the **same subplot**, but with a dashed line (`linestyle='--'`) to visually distinguish it.
    *   **Informative Legend**: Instead of a simple legend, it creates a custom one that includes not just the label of each method but also its corresponding AUROC score, formatted to four decimal places. This adds crucial quantitative information directly to the plot.

4.  **Precision-Recall Curve Plotting (Right Subplot)**:
    *   It follows the exact same logic as the ROC plotting, but this time using `sklearn.metrics.PrecisionRecallDisplay` and targeting the right subplot (`ax=ax2`).
    *   It overlays the two curves using solid and dashed lines and creates a custom legend that includes the AUPR score for each method.

5.  **Final Touches**: `plt.tight_layout()` is called to automatically adjust the spacing and prevent titles and labels from overlapping, and `plt.show()` displays the final, completed figure.

#### Result
The function does not return any data; its result is the **visualization itself**. The output is a single, clear figure containing two plots:

*   **Left Plot (ROC Comparison)**: This plot allows for an immediate visual assessment of which method is better at separating the two classes. The curve that is "higher and to the left" represents the superior detector. The legend provides the precise AUROC value to confirm the visual impression.
*   **Right Plot (Precision-Recall Comparison)**: This plot provides a complementary view of performance, which is especially useful when datasets are imbalanced. The curve that is "higher and to the right" is better. The legend provides the AUPR value.



In [None]:
from sklearn.metrics import RocCurveDisplay, PrecisionRecallDisplay

def plot_curve_comparison(metrics_dict1, metrics_dict2, label1="Method 1", label2="Method 2", title_prefix=""):

    auc_digits = 4
    
    fpr1, tpr1 = metrics_dict1["roc_curve"]
    recall1, precision1 = metrics_dict1["pr_curve"]
    
    fpr2, tpr2 = metrics_dict2["roc_curve"]
    recall2, precision2 = metrics_dict2["pr_curve"]
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
 
    disp1 = RocCurveDisplay(fpr=fpr1, tpr=tpr1, roc_auc=metrics_dict1["AUROC"])
    disp1.plot(ax=ax1, name=f'{label1}', curve_kwargs={'linestyle': '-'})
    
    disp2 = RocCurveDisplay(fpr=fpr2, tpr=tpr2, roc_auc=metrics_dict2["AUROC"])
    disp2.plot(ax=ax1, name=f'{label2}', curve_kwargs={'linestyle': '--'})
    
    ax1.set_title(f"{title_prefix} ROC Curve Comparison")
    ax1.legend([
        f"{label1} (AUROC = {metrics_dict1['AUROC']:.{auc_digits}f})",
        f"{label2} (AUROC = {metrics_dict2['AUROC']:.{auc_digits}f})"
    ], loc="lower right")
    
    disp_pr1 = PrecisionRecallDisplay(precision=precision1, recall=recall1)
    disp_pr1.plot(ax=ax2, name=f'{label1}', linestyle='-')
    
    disp_pr2 = PrecisionRecallDisplay(precision=precision2, recall=recall2)
    disp_pr2.plot(ax=ax2, name=f'{label2}', linestyle='--')
    
    ax2.set_title(f"{title_prefix} Precision-Recall Curve Comparison")
    ax2.legend([
        f"{label1} (AUPR = {metrics_dict1['AUPR']:.{auc_digits}f})",
        f"{label2} (AUPR = {metrics_dict2['AUPR']:.{auc_digits}f})"
    ], loc="lower left")
    
    plt.tight_layout()
    plt.show()
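
For reference, the dictionaries consumed by `plot_curve_comparison` need the keys `"roc_curve"`, `"pr_curve"`, `"AUROC"`, and `"AUPR"`. The sketch below shows one way such a dictionary could be assembled from raw scores. `build_metrics` is a hypothetical helper written for this illustration (it is not the `eval_OOD` function used in this lab), it assumes that ID samples are the positive class, and the Gaussian scores are synthetic.

```python
import numpy as np
from sklearn import metrics as skm

def build_metrics(scores_id, scores_ood):
    """Hypothetical helper: ID samples are the positive class (label 1)."""
    y_true = np.concatenate([np.ones(len(scores_id)), np.zeros(len(scores_ood))])
    y_score = np.concatenate([scores_id, scores_ood])
    fpr, tpr, _ = skm.roc_curve(y_true, y_score)
    precision, recall, _ = skm.precision_recall_curve(y_true, y_score)
    return {
        "roc_curve": (fpr, tpr),          # unpacked as (fpr, tpr)
        "pr_curve": (recall, precision),  # unpacked as (recall, precision)
        "AUROC": skm.auc(fpr, tpr),
        "AUPR": skm.average_precision_score(y_true, y_score),
    }

# Synthetic scores: ID scores are higher on average than OOD scores.
rng = np.random.default_rng(0)
demo = build_metrics(rng.normal(2.0, 1.0, 500), rng.normal(0.0, 1.0, 500))
print(f"AUROC = {demo['AUROC']:.4f}, AUPR = {demo['AUPR']:.4f}")
```

Two dictionaries built this way, one per scoring method, can then be passed as `metrics_dict1` and `metrics_dict2`.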



In [None]:
model.load_state_dict(torch.load(f'./cifar10_best_P.pth'))
metric1 = eval_OOD(model, testloader, fakeloader, device, compute_scores_ODIN)
print(f"  FPR at 95 TPR: {metric1['FPR@95TPR']:.4f} | Detection Error at 95 tpr: {metric1['DetectionError@95TPR']:.4f}")

metric2 = eval_OOD(model, testloader, fakeloader, device,compute_scores)
print(f"  FPR at 95 TPR: {metric2['FPR@95TPR']:.4f} | Detection Error at 95 tpr: {metric2['DetectionError@95TPR']:.4f}")
plot_curve_comparison(metric1,metric2,"ODIN", "Baseline", title_prefix="CNN")

#### Goal
To systematically test a wide range of `temperature` and `eps` values to find the single best pair that minimizes the False Positive Rate when the True Positive Rate is fixed at 95% (FPR@95TPR). This process tunes the ODIN algorithm to its peak performance for the specific model and dataset being used.

#### Method
The script implements a classic **grid search** algorithm:

1.  **Define the Search Space**: Two lists of potential hyperparameter values are created: `temperature` and `all_eps`. These lists define the "grid" of all possible combinations that will be tested.

2.  **Iterate Through the Grid**: The code uses nested `for` loops to iterate through every single `(T, eps)` pair in the search space.

3.  **Evaluate Each Combination**: For each pair of hyperparameters, the full OOD evaluation is performed:
    *   The `compute_scores_ODIN` function is called for both the in-distribution (`testloader`) and out-of-distribution (`fakeloader`) data using the current `T` and `eps`.
    *   The ground truth labels and predicted scores are assembled.
    *   The ROC curve is calculated, and the specific metric of interest, `fpr_at_95_ODIN`, is extracted. This metric represents the performance of the current `(T, eps)` configuration.

4.  **Track the Best Performance**: A variable, `best_fpr`, is initialized to a very high value. After each evaluation, the script checks:
    *   `if fpr_at_95_ODIN < best_fpr:`
    *   If the current combination of hyperparameters yields a lower (better) FPR than the best one found so far, the script updates `best_fpr` and stores the current `T` and `eps` as the new best values (`bestT`, `bestEps`).
    *   The print statements provide a real-time log of the search, indicating when a new best combination has been found.

#### Result
The primary result of this script is not a trained model, but a set of **optimal hyperparameters**.

*   **`bestT` and `bestEps`**: The final output provides the single `Temperature` and `Epsilon` values that produced the best OOD detection performance out of all the combinations tested.
*   **`best_fpr`**: The script also outputs the best achievable performance score (the minimum FPR@95TPR). This value represents the peak performance of the ODIN method after it has been properly tuned.
*   **Scores for Plotting (`best_ypred`, `best_ytrue`)**: The script also saves the labels and scores from the single best run. These can be passed to `roc_curve` from `sklearn.metrics` to compute the ROC curve and AUC of the **optimized** ODIN method and compare it against the baseline.

In [None]:
temperature = [1, 2, 5, 10, 20, 50, 100, 200, 500, 1000]
all_eps =  np.linspace(0, 0.004, 21)

std_dev = 0.5
best_fpr = 100
n=0
tot = len (temperature) * len(all_eps) 
for T in temperature:
    for eps in all_eps:

        
        scores_test_ODIN = []
        model.eval()
        
        scores_test_ODIN = compute_scores_ODIN(model, testloader,device, T, eps, filter_correct_only=True)
        scores_fake_ODIN = compute_scores_ODIN(model, fakeloader,device, T, eps, filter_correct_only=False)    
        

        y_fake_ODIN = torch.zeros_like(scores_fake_ODIN)
        y_test_ODIN = torch.ones_like(scores_test_ODIN)
        ypred_ODIN = torch.cat((scores_test_ODIN, scores_fake_ODIN))
        ytrue_ODIN = torch.cat((y_test_ODIN, y_fake_ODIN))
        fpr_ODIN, tpr_ODIN, thresholds_ODIN = skm.roc_curve(ytrue_ODIN, ypred_ODIN.detach().numpy())
        idx_ODIN = np.argmin(np.abs(tpr_ODIN - 0.95))
        fpr_at_95_ODIN = fpr_ODIN[idx_ODIN]
        n += 1
       
        if fpr_at_95_ODIN < best_fpr:
            best_ypred = ypred_ODIN
            best_ytrue = ytrue_ODIN
            bestT = T
            bestEps = eps
            best_fpr = fpr_at_95_ODIN
            print(f'Temperature: {T}, Epsilon: {eps:.4f}, FPR at 95% TPR: {fpr_at_95_ODIN:.6f}  {n}/{tot}  new Best') 
        else :
            print(f'Temperature: {T}, Epsilon: {eps:.4f}, FPR at 95% TPR: {fpr_at_95_ODIN:.6f}  {n}/{tot}') 
            
    print()
    print()

print( f'Best performance Temperature:{bestT}  Epsilon: {bestEps:.4f}, FPR at 95% TPR:{best_fpr}')
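
The grid search above picks the ROC point whose TPR is closest to 0.95, which can land slightly below the target. A stricter alternative, sketched here, takes the first ROC point whose TPR is at least 0.95. `fpr_at_tpr` is a hypothetical helper introduced for illustration, and the labels and scores are a hand-made toy example.

```python
import numpy as np
from sklearn import metrics as skm

def fpr_at_tpr(y_true, y_score, target_tpr=0.95):
    """Hypothetical helper: FPR at the first ROC point whose TPR reaches target_tpr."""
    fpr, tpr, _ = skm.roc_curve(y_true, y_score)
    # tpr is non-decreasing and ends at 1.0, so the index is always valid.
    idx = np.searchsorted(tpr, target_tpr, side="left")
    return fpr[idx]

# Toy example: four ID samples (label 1) and four OOD samples (label 0).
y_true = np.array([1, 1, 1, 1, 0, 0, 0, 0])
y_score = np.array([0.9, 0.8, 0.7, 0.2, 0.6, 0.3, 0.1, 0.05])
print(fpr_at_tpr(y_true, y_score))
```

In this toy case, reaching a TPR of at least 0.95 requires accepting the ID sample with score 0.2, which also lets in the two OOD samples with scores 0.6 and 0.3.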


In [None]:

# sklearn.metrics is imported as `skm` at the top of the notebook
fpr_ODIN, tpr_ODIN, _ = skm.roc_curve(best_ytrue, best_ypred.detach().cpu().numpy())
roc_auc_ODIN = skm.auc(fpr_ODIN, tpr_ODIN)

# Prepare true labels and predictions for regular scores
y_fake = torch.zeros_like(scores_fake)
y_test = torch.ones_like(scores_test)
ypred = torch.cat((scores_test, scores_fake))
ytrue = torch.cat((y_test, y_fake))
print(ytrue, ypred)
ytrue = ytrue.cpu()
ypred = ypred.cpu()
# Compute ROC curve and AUC for regular scores
fpr, tpr, _ = skm.roc_curve(ytrue, ypred)
roc_auc = skm.auc(fpr, tpr)

# Plot both ROC curves on the same figure
plt.figure(figsize=(8, 6))
plt.plot(fpr_ODIN, tpr_ODIN, label=f'ODIN ROC (AUC = {roc_auc_ODIN:.4f})')
plt.plot(fpr, tpr, label=f'Baseline ROC (AUC = {roc_auc:.4f})')

# Plot diagonal line
plt.plot([0, 1], [0, 1], 'k--')

# Plot formatting
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve Comparison')
plt.legend(loc='lower right')
plt.grid(True)
plt.tight_layout()
plt.show()  