In [1]:
!pip install torchmetrics plotly-express

Collecting torchmetrics
  Downloading torchmetrics-1.4.2-py3-none-any.whl.metadata (19 kB)
Collecting plotly-express
  Downloading plotly_express-0.4.1-py2.py3-none-any.whl.metadata (1.7 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.11.7-py3-none-any.whl.metadata (5.2 kB)
Downloading torchmetrics-1.4.2-py3-none-any.whl (869 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m869.2/869.2 kB[0m [31m32.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading plotly_express-0.4.1-py2.py3-none-any.whl (2.9 kB)
Downloading lightning_utilities-0.11.7-py3-none-any.whl (26 kB)
Installing collected packages: lightning-utilities, torchmetrics, plotly-express
Successfully installed lightning-utilities-0.11.7 plotly-express-0.4.1 torchmetrics-1.4.2


In [2]:
import os

import torch
from torch.utils.data import DataLoader

from torchmetrics import MetricCollection
from torchmetrics.classification import Accuracy, JaccardIndex

import numpy as np
from sklearn.metrics import roc_auc_score, log_loss, accuracy_score

from tqdm import tqdm

import PIL
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.colors as colors
import plotly
import plotly.io as pio
pio.renderers.default = 'colab'
import plotly.express as px
import plotly.graph_objects as go
from plotly.io import write_image

In [3]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [4]:
def load_npz_files(dir_path, convert_to_torch=True):
    """Helper function to load and concatenate data from all saved npz files in a directory."""
    # Ensure directory exists
    if not os.path.exists(dir_path):
        raise ValueError(f"Directory '{dir_path}' does not exist.")

    all_data = []
    for file_name in sorted(os.listdir(dir_path)):
        if file_name.endswith(".npz"):
            file_path = os.path.join(dir_path, file_name)
            data = np.load(file_path)["data"]
            all_data.append(data)

    # Ensure there is data to concatenate
    if len(all_data) == 0:
        raise ValueError(f"No .npz files found in directory '{dir_path}'.")

    # Concatenate and convert data
    if convert_to_torch:
        concatenated_data = torch.from_numpy(np.concatenate(all_data, axis=0))
    else:
        concatenated_data = np.concatenate(all_data, axis=0)

    return concatenated_data

In [5]:
class LoadMisclassificationDataset(torch.utils.data.Dataset):
    """
    Args:
            conf_file (str): Path to the npz files directory of max_pred_probs.
            label_file (str): Path to the npz file directory of labels.
    """
    def __init__(self, conf_file, label_file):
        print("Initializing Misclassification Dataset...")
        # Load all npz files from the ID directory
        self.conf_scores = load_npz_files(conf_file)

        print("[INFO]: Getting the preds for OOD dataset")
        self.preds = (self.conf_scores > 0.7).int()

        self.labels = load_npz_files(label_file)

        # Create a boolean mask where True indicates a misclassification
        misclassification_mask = self.preds != self.labels

        # Set to 1 where misclassification occurs, 0 for corrrect
        self.targets = misclassification_mask.int()

    def __len__(self):
        return len(self.conf_scores)

    def __getitem__(self, idx):
        conf_score = self.conf_scores[idx]
        target = self.targets[idx]
        return {"scores": conf_score, "label": target}

In [None]:
# from typing import List

# import torch
# from torch import Tensor
# from torchmetrics import Metric
# from torchmetrics.functional.classification import binary_auroc, binary_precision_recall_curve, binary_roc
# from torchmetrics.utilities import rank_zero_warn
# from torchmetrics.utilities.compute import auc
# from torchmetrics.utilities.data import dim_zero_cat

# import numpy as np
# import sklearn.metrics as sk
# import plotly.graph_objects as go


# class OODMetrics(Metric):
#     """
#     Class to calculate OOD Metrics - FPR@recalllevel, AUROC
#     """
#     is_differentiable: bool = False
#     higher_is_better: bool = False
#     full_state_update: bool = False

#     conf: List[Tensor]
#     targets: List[Tensor]

#     def __init__(self, recall_level: float, pos_label: int, ignore_index: int=None, **kwargs) -> None:
#         """The False Positive Rate at x% Recall or TPR metric.

#         Args:
#             recall_level (float): The recall level at which to compute the FPR. Usually 0.95 or 0.99.
#             pos_label (int): The positive label.
#             ignore_index (int, optional): Index to ignore in calculations. Defaults to None.
#             kwargs: Additional arguments to pass to the metric class.

#         Reference:
#             Ref Link:
#             - https://github.com/hendrycks/anomaly-seg
#             - https://github.com/ENSTA-U2IS-AI/torch-uncertainty/blob/1c906132748b5ea7fe2e1436de163397ebf4aa01/torch_uncertainty/metrics/classification/fpr95.py
#         """
#         super().__init__(**kwargs)

#         if recall_level < 0 or recall_level > 1:
#             raise ValueError(f"Recall level must be between 0 and 1. Got {recall_level}.")
#         self.recall_level = recall_level
#         self.pos_label = pos_label
#         self.ignore_index = ignore_index
#         self.add_state("conf", [], dist_reduce_fx="cat")
#         self.add_state("targets", [], dist_reduce_fx="cat")

#         rank_zero_warn(f"Metric `FPR{int(recall_level*100)}` will save all targets and predictions in buffer. For large datasets this may lead to large memory footprint.")

#     def fpr_at_tpr(self, scores, labels, recall_level=0.95):
#         """
#         Calculate the False Positive Rate at a certain True Positive Rate
#         """
#         # results will be sorted in reverse order
#         fpr, tpr, thresholds = binary_roc(scores, labels)  # thresholds=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
#         idx = torch.searchsorted(tpr, recall_level)
#         if idx == fpr.shape[0]:
#             return fpr[idx - 1], thresholds[idx - 1], fpr, tpr, thresholds
#         idx = np.searchsorted(tpr, recall_level, side='right')
#         return fpr[idx], thresholds[idx], fpr, tpr, thresholds

#     def update(self, conf: Tensor, target: Tensor) -> None:
#         """Update the metric state.

#         Args:
#             conf (Tensor): The confidence scores.
#             target (Tensor): The target labels.
#         """
#         self.conf.append(conf.contiguous().view(-1))
#         self.targets.append(target.contiguous().view(-1))

#     def compute(self) -> Tensor:
#         """Compute the actual False Positive Rate at x% Recall(TPR), AUROC, AUPR

#         Returns:
#             Tensor: The value of the FPRx.
#         """
#         print("[INFO]: Computing the OOD Metrics")
#         roc_curve_data, aupr_in_curve_data, aupr_out_curve_data = [], [], []
#         conf = dim_zero_cat(self.conf).cpu() # .numpy()
#         targets = dim_zero_cat(self.targets).cpu() # .numpy()
#         print("\nOOD Metrics")
#         print(f"conf.shape: {conf.shape}")
#         print(f"targets.shape: {targets.shape}")

#         if self.ignore_index is not None:
#             mask = targets != self.ignore_index
#             conf = conf * mask
#             targets = targets * mask

#         scores, idxs = torch.sort(conf, stable=True)
#         labels = targets[idxs]

#         print(f"scores.shape: {scores.device}, labels.device: {labels.device}")

#         auroc = binary_auroc(scores, labels)

#         fpr_at_recall_level, threshold_at_recall_level, fpr, tpr, thresholds_roc = self.fpr_at_tpr(scores, labels)  # (recall / TPR / Sensitivtiy) vs (FPR / 1-specificity) for different thresholds of classification scores.
#         roc_curve_data.append(fpr.tolist())
#         roc_curve_data.append(tpr.tolist())
#         roc_curve_data.append(thresholds_roc.tolist())

#         return fpr_at_recall_level.item(), threshold_at_recall_level.item(), auroc.item(), roc_curve_data

#     def reset(self) -> None:
#         """Reset the metric state."""
#         self.conf = []
#         self.targets = []


In [6]:
from typing import List

import torch
from torch import Tensor
from torchmetrics import Metric
from torchmetrics.functional.classification import binary_auroc, binary_precision_recall_curve, binary_roc
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.compute import auc
from torchmetrics.utilities.data import dim_zero_cat

import numpy as np
import sklearn.metrics as sk
import plotly.graph_objects as go

class OODMetrics(Metric):
    """
    Class to calculate OOD Metrics - FPR@recall_level, AUROC
    """
    is_differentiable: bool = False
    higher_is_better: bool = True
    full_state_update: bool = False

    conf: List[Tensor]
    targets: List[Tensor]

    def __init__(self, recall_level: float, pos_label: int, ignore_index: int = None, **kwargs) -> None:
        """The False Positive Rate at x% Recall (TPR) metric.

        Args:
            recall_level (float): The recall level (TPR) at which to compute the FPR. Usually 0.95 or 0.99.
            pos_label (int): The positive label.
            ignore_index (int, optional): Index to ignore in calculations. Defaults to None.
            kwargs: Additional arguments to pass to the metric class.
        """
        super().__init__(**kwargs)

        if not (0 <= recall_level <= 1):
            raise ValueError(f"Recall level must be between 0 and 1. Got {recall_level}.")
        self.recall_level = recall_level
        self.pos_label = pos_label
        self.ignore_index = ignore_index

        self.add_state("conf", [], dist_reduce_fx="cat")
        self.add_state("targets", [], dist_reduce_fx="cat")

        rank_zero_warn(
            f"Metric `FPR@{int(recall_level*100)}% Recall` will save all targets and predictions in memory. "
            "For large datasets, this may lead to large memory usage."
        )

    def fpr_at_tpr(self, scores: Tensor, labels: Tensor, recall_level: float = 0.95):
        """
        Calculate the False Positive Rate at a certain True Positive Rate (recall level).
        """
        fpr, tpr, thresholds = binary_roc(scores, labels)

        # Search for the index where TPR is equal to or just above the recall level
        idx = torch.searchsorted(tpr, torch.tensor(recall_level), right=False)

        # Handle edge case where TPR does not reach the recall level
        if idx >= len(fpr):
            idx = len(fpr) - 1
        return fpr[idx].item(), thresholds[idx].item(), fpr, tpr, thresholds

    def update(self, conf: Tensor, target: Tensor) -> None:
        """Update the metric state with new confidence scores and target labels.

        Args:
            conf (Tensor): The confidence scores.
            target (Tensor): The target labels (0 for correctly classified, 1 for misclassified).
        """
        self.conf.append(conf.contiguous().view(-1))
        self.targets.append(target.contiguous().view(-1))

    def compute(self) -> Tensor:
        """Compute the actual False Positive Rate at x% Recall (TPR), AUROC, and ROC curve data.

        Returns:
            Tuple: (fpr_at_recall_level, threshold_at_recall_level, auroc, roc_curve_data)
        """
        print("[INFO]: Computing the OOD Metrics")

        # Concatenate all the confidence scores and target labels across updates
        conf = dim_zero_cat(self.conf).cpu()
        targets = dim_zero_cat(self.targets).cpu()

        print(f"conf.shape: {conf.shape}, targets.shape: {targets.shape}")

        # Handle ignored indices, if applicable
        if self.ignore_index is not None:
            valid_mask = targets != self.ignore_index
            conf = conf[valid_mask]
            targets = targets[valid_mask]

        # Sort scores and corresponding labels for ROC calculation
        sorted_scores, idxs = torch.sort(conf, descending=True)
        sorted_labels = targets[idxs]

        # Calculate AUROC
        auroc = binary_auroc(sorted_scores, sorted_labels)

        # Calculate FPR at the specified recall level (TPR)
        fpr_at_recall_level, threshold_at_recall_level, fpr, tpr, thresholds_roc = self.fpr_at_tpr(
            sorted_scores, sorted_labels, recall_level=self.recall_level
        )

        # Gather ROC curve data (FPR, TPR, thresholds) for plotting
        roc_curve_data = {
            "fpr": fpr.tolist(),
            "tpr": tpr.tolist(),
            "thresholds": thresholds_roc.tolist(),
        }
        return fpr_at_recall_level, threshold_at_recall_level, auroc.item(), roc_curve_data

    def reset(self) -> None:
        """Reset the metric state."""
        self.conf = []
        self.targets = []

In [16]:
def plot_roc_curve(roc_curve_data, auroc, fpr_at_recall_level, recall_level):
    """
    Plot AUROC curve for OOD Test.
    """
    print("[INFO]: Plotting ROC curve...")
    fig = go.Figure()
    # ROC Curve data Unpacking
    fpr, tpr = roc_curve_data["fpr"], roc_curve_data["tpr"]
    # Plot ROC Curve
    fig.add_trace(go.Scatter(x=fpr, y=tpr, mode='lines', name=f'ROC Curve (area = {auroc:.2f})'))
    fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1], mode='lines', name='Random Classifier', line=dict(dash='dash')))
    fig.add_trace(go.Scatter(x=[fpr_at_recall_level], y=[recall_level], mode='markers', marker=dict(color='red', size=10),
                             name=f'FPR@Recall ({recall_level}): {fpr_at_recall_level:.2f}'))
    # Annotation for AUROC
    fig.add_annotation(x=0.95, y=0.1, text=f'AUROC = {auroc:.3f}', showarrow=False, font=dict(size=12, color="black"))

    # Update layout
    fig.update_layout(title="ROC Curve", xaxis_title="False Positive Rate", yaxis_title="True Positive Rate",)
    return fig

In [8]:
class TestMetrics:
    def __init__(self, args, gpu_id):
        self.device = gpu_id
        self.dataset_name = args.eval_dataset_name
        self.ignore_index = args.ignore_index

        self.misclassification_dir = args.misclassification_dir
        self.labels_dir = args.labels_dir
        self.misclassification_name = args.misclassification_name.lower()
        self.is_miscclass_ood = True
        self.misclass_recall_level = args.misclass_recall_level

        self.misclassification_dataset = LoadMisclassificationDataset(conf_file=f"{self.misclassification_dir}/max_pred_probs", label_file=f"{self.labels_dir}")

        self.misclassification_metrics = OODMetrics(recall_level=self.misclass_recall_level, pos_label=1, ignore_index=self.ignore_index)                                      # AUROC, AUPR, FPR@thresh
        self.scalar_metrics_to_log = [f"fpr_at_tpr_{self.misclassification_name}_misclass", f"auroc_{self.misclassification_name}_misclass"]

    def misclassification_evaluation(self):
        print("[INFO]: Misclassification Evaluation...")
        self.misclassification_metrics.to(self.device)
        # Load the confidence scores and ood labels from the whole dataset with from the outputs of the given path
        self.misclassification_dataloader = DataLoader(self.misclassification_dataset, batch_size=1, pin_memory=True, shuffle=False)
        for i, data_dict in enumerate(tqdm(self.misclassification_dataloader)):
            conf_score = data_dict["scores"].squeeze(dim=1).to(self.device)
            target = data_dict["label"].squeeze(dim=1).to(self.device)
            # Update OOD Metrics
            self.misclassification_metrics.update(conf_score, target)
            print(f"updated: {i}")
            del data_dict
        print("computing")
        # Compute OOD Metrics
        fpr_at_recall_level, threshold_at_recall_level, auroc, roc_curve_data = self.misclassification_metrics.compute()
        # Plot each curve
        fig_roc = plot_roc_curve(roc_curve_data, auroc, fpr_at_recall_level, self.misclass_recall_level)
        self.misclassification_metrics.reset()
        return fpr_at_recall_level, auroc, fig_roc

    def main(self):
        summary_metrics = {}
        figs = {}

        # Calculate Misclassifications
        print("[INFO]: Calculating Misclassifications")
        fpr_at_recall_level, auroc, fig_roc = self.misclassification_evaluation()
        summary_metrics[f"fpr_at_tpr_{self.misclassification_name}_misclass"] = fpr_at_recall_level
        summary_metrics[f"auroc_{self.misclassification_name}_misclass"] = auroc

        print(f"[INFO]: Summary Metrics:\n {summary_metrics}")
        print(f"[INFO]: Plotting fig_roc_{self.misclassification_name}_misclass....")
        return fig_roc

### Labels

In [9]:
!cp -r /content/drive/MyDrive/Colab\ Notebooks/Master\ Thesis/runs/fishyscapes_labels .
!cp -r /content/drive/MyDrive/Colab\ Notebooks/Master\ Thesis/runs/lostandfound_labels .

## BASELINE

In [10]:
!cp -r /content/drive/MyDrive/Colab\ Notebooks/Master\ Thesis/runs/Baseline .

### Fishyscapes

In [11]:
from argparse import Namespace
import torch


# Create a Namespace instance
args = Namespace()

# Manually add arguments
args.eval_dataset_name = "fishyscapes"
args.ignore_index = 255
args.misclassification_dir = "/content/Baseline/fishyscapes"
args.labels_dir = "/content/fishyscapes_labels"
args.misclassification_name = "fishyscapes"
args.misclass_recall_level = 0.95

if torch.cuda.is_available():
    gpu_id = 0
else:
    gpu_id = "cpu"

In [17]:
calc_metrics = TestMetrics(args, gpu_id)
fig_roc = calc_metrics.main()


Initializing Misclassification Dataset...
[INFO]: Getting the preds for OOD dataset



Metric `FPR@95% Recall` will save all targets and predictions in memory. For large datasets, this may lead to large memory usage.



[INFO]: Calculating Misclassifications
[INFO]: Misclassification Evaluation...


100%|██████████| 30/30 [00:00<00:00, 201.94it/s]


updated: 0
updated: 1
updated: 2
updated: 3
updated: 4
updated: 5
updated: 6
updated: 7
updated: 8
updated: 9
updated: 10
updated: 11
updated: 12
updated: 13
updated: 14
updated: 15
updated: 16
updated: 17
updated: 18
updated: 19
updated: 20
updated: 21
updated: 22
updated: 23
updated: 24
updated: 25
updated: 26
updated: 27
updated: 28
updated: 29
computing
[INFO]: Computing the OOD Metrics
conf.shape: torch.Size([62914560]), targets.shape: torch.Size([62914560])
[INFO]: Plotting ROC curve...
[INFO]: Summary Metrics:
 {'fpr_at_tpr_fishyscapes_misclass': 0.11174800992012024, 'auroc_fishyscapes_misclass': 0.9728127717971802}
[INFO]: Plotting fig_roc_fishyscapes_misclass....


## EDL

In [18]:
!cp -r /content/drive/MyDrive/Colab\ Notebooks/Master\ Thesis/runs/EDL .

### Fishyscapes

In [19]:
from argparse import Namespace
import torch


# Create a Namespace instance
args = Namespace()

# Manually add arguments
args.eval_dataset_name = "fishyscapes"
args.ignore_index = 255
args.misclassification_dir = "/content/EDL/fishyscapes"
args.labels_dir = "/content/fishyscapes_labels"
args.misclassification_name = "fishyscapes"
args.misclass_recall_level = 0.95

if torch.cuda.is_available():
    gpu_id = 0
else:
    gpu_id = "cpu"

In [20]:
calc_metrics = TestMetrics(args, gpu_id)
fig_roc = calc_metrics.main()

Initializing Misclassification Dataset...
[INFO]: Getting the preds for OOD dataset
[INFO]: Calculating Misclassifications
[INFO]: Misclassification Evaluation...


100%|██████████| 30/30 [00:00<00:00, 219.65it/s]

updated: 0
updated: 1
updated: 2
updated: 3
updated: 4
updated: 5
updated: 6
updated: 7
updated: 8
updated: 9
updated: 10
updated: 11
updated: 12
updated: 13
updated: 14
updated: 15
updated: 16
updated: 17
updated: 18
updated: 19
updated: 20
updated: 21
updated: 22
updated: 23
updated: 24
updated: 25
updated: 26
updated: 27
updated: 28
updated: 29
computing
[INFO]: Computing the OOD Metrics





conf.shape: torch.Size([62914560]), targets.shape: torch.Size([62914560])
[INFO]: Plotting ROC curve...
[INFO]: Summary Metrics:
 {'fpr_at_tpr_fishyscapes_misclass': 0.7618088722229004, 'auroc_fishyscapes_misclass': 0.8926395177841187}
[INFO]: Plotting fig_roc_fishyscapes_misclass....


### Lostandfound

In [22]:
from argparse import Namespace


# Create a Namespace instance
args = Namespace()

# Manually add arguments
args.eval_dataset_name = "lostandfound"
args.ignore_index = 255
args.misclassification_dir = "/content/EDL/lostandfound"
args.labels_dir = "/content/lostandfound_labels"
args.misclassification_name = "lostandfound"
args.misclass_recall_level = 0.95

if torch.cuda.is_available():
    gpu_id = 0
else:
    gpu_id = "cpu"

In [23]:
calc_metrics = TestMetrics(args, gpu_id)
fig_roc = calc_metrics.main()
# fig_roc.show()

Initializing Misclassification Dataset...
[INFO]: Getting the preds for OOD dataset
[INFO]: Calculating Misclassifications
[INFO]: Misclassification Evaluation...


 23%|██▎       | 23/100 [00:00<00:00, 196.47it/s]

updated: 0
updated: 1
updated: 2
updated: 3
updated: 4
updated: 5
updated: 6
updated: 7
updated: 8
updated: 9
updated: 10
updated: 11
updated: 12
updated: 13
updated: 14
updated: 15
updated: 16
updated: 17
updated: 18
updated: 19
updated: 20
updated: 21
updated: 22
updated: 23
updated: 24
updated: 25
updated: 26
updated: 27
updated: 28
updated: 29
updated: 30
updated: 31
updated: 32
updated: 33
updated: 34
updated: 35
updated: 36
updated: 37
updated: 38
updated: 39
updated: 40


 65%|██████▌   | 65/100 [00:00<00:00, 203.16it/s]

updated: 41
updated: 42
updated: 43
updated: 44
updated: 45
updated: 46
updated: 47
updated: 48
updated: 49
updated: 50
updated: 51
updated: 52
updated: 53
updated: 54
updated: 55
updated: 56
updated: 57
updated: 58
updated: 59
updated: 60
updated: 61
updated: 62
updated: 63
updated: 64
updated: 65
updated: 66
updated: 67
updated: 68
updated: 69
updated: 70
updated: 71
updated: 72
updated: 73
updated: 74
updated: 75
updated: 76
updated: 77
updated: 78
updated: 79
updated: 80


100%|██████████| 100/100 [00:00<00:00, 195.77it/s]


updated: 81
updated: 82
updated: 83
updated: 84
updated: 85
updated: 86
updated: 87
updated: 88
updated: 89
updated: 90
updated: 91
updated: 92
updated: 93
updated: 94
updated: 95
updated: 96
updated: 97
updated: 98
updated: 99
computing
[INFO]: Computing the OOD Metrics
conf.shape: torch.Size([209715200]), targets.shape: torch.Size([209715200])
[INFO]: Plotting ROC curve...
[INFO]: Summary Metrics:
 {'fpr_at_tpr_lostandfound_misclass': 0.7984106540679932, 'auroc_lostandfound_misclass': 0.8848692774772644}
[INFO]: Plotting fig_roc_lostandfound_misclass....


## EDL_BASELINE_FINETUNE

In [24]:
!cp -r /content/drive/MyDrive/Colab\ Notebooks/Master\ Thesis/runs/EDL_BASELINE_FINETUNE .

### Fishyscapes

In [25]:
from argparse import Namespace
import torch


# Create a Namespace instance
args = Namespace()

# Manually add arguments
args.eval_dataset_name = "fishyscapes"
args.ignore_index = 255
args.misclassification_dir = "/content/EDL_BASELINE_FINETUNE/fishyscapes"
args.labels_dir = "/content/fishyscapes_labels"
args.misclassification_name = "fishyscapes"
args.misclass_recall_level = 0.95

if torch.cuda.is_available():
    gpu_id = 0
else:
    gpu_id = "cpu"

In [26]:
calc_metrics = TestMetrics(args, gpu_id)
fig_roc = calc_metrics.main()

Initializing Misclassification Dataset...
[INFO]: Getting the preds for OOD dataset
[INFO]: Calculating Misclassifications
[INFO]: Misclassification Evaluation...


100%|██████████| 30/30 [00:00<00:00, 241.11it/s]

updated: 0
updated: 1
updated: 2
updated: 3
updated: 4
updated: 5
updated: 6
updated: 7
updated: 8
updated: 9
updated: 10
updated: 11
updated: 12
updated: 13
updated: 14
updated: 15
updated: 16
updated: 17
updated: 18
updated: 19
updated: 20
updated: 21
updated: 22
updated: 23
updated: 24
updated: 25
updated: 26
updated: 27
updated: 28
updated: 29
computing
[INFO]: Computing the OOD Metrics





conf.shape: torch.Size([62914560]), targets.shape: torch.Size([62914560])
[INFO]: Plotting ROC curve...
[INFO]: Summary Metrics:
 {'fpr_at_tpr_fishyscapes_misclass': 0.9458062052726746, 'auroc_fishyscapes_misclass': 0.8877522945404053}
[INFO]: Plotting fig_roc_fishyscapes_misclass....


### Lostandfound

In [None]:
from argparse import Namespace


# Create a Namespace instance
args = Namespace()

# Manually add arguments
args.eval_dataset_name = "lostandfound"
args.ignore_index = 255
args.misclassification_dir = "/content/EDL_BASELINE_FINETUNE/lostandfound"
args.labels_dir = "/content/lostandfound_labels"
args.misclassification_name = "lostandfound"
args.misclass_recall_level = 0.95

if torch.cuda.is_available():
    gpu_id = 0
else:
    gpu_id = "cpu"

In [None]:
calc_metrics = TestMetrics(args, gpu_id)
fig_roc = calc_metrics.main()
# fig_roc.show()

Initializing Misclassification Dataset...
[INFO]: Getting the preds for OOD dataset



Metric `FPR95` will save all targets and predictions in buffer. For large datasets this may lead to large memory footprint.



[INFO]: Calculating Misclassifications
[INFO]: Misclassification Evaluation...


  0%|          | 0/100 [00:00<?, ?it/s]

updated: 0
updated: 1
updated: 2
updated: 3
updated: 4
updated: 5
updated: 6
updated: 7
updated: 8


 34%|███▍      | 34/100 [00:00<00:00, 164.41it/s]

updated: 9
updated: 10
updated: 11
updated: 12
updated: 13
updated: 14
updated: 15
updated: 16
updated: 17
updated: 18
updated: 19
updated: 20
updated: 21
updated: 22
updated: 23
updated: 24
updated: 25
updated: 26
updated: 27
updated: 28
updated: 29
updated: 30
updated: 31
updated: 32
updated: 33
updated: 34
updated: 35
updated: 36
updated: 37
updated: 38
updated: 39
updated: 40
updated: 41
updated: 42
updated: 43


 70%|███████   | 70/100 [00:00<00:00, 169.40it/s]

updated: 44
updated: 45
updated: 46
updated: 47
updated: 48
updated: 49
updated: 50
updated: 51
updated: 52
updated: 53
updated: 54
updated: 55
updated: 56
updated: 57
updated: 58
updated: 59
updated: 60
updated: 61
updated: 62
updated: 63
updated: 64
updated: 65
updated: 66
updated: 67
updated: 68
updated: 69
updated: 70
updated: 71
updated: 72
updated: 73
updated: 74
updated: 75
updated: 76
updated: 77
updated: 78


100%|██████████| 100/100 [00:00<00:00, 171.06it/s]


updated: 79
updated: 80
updated: 81
updated: 82
updated: 83
updated: 84
updated: 85
updated: 86
updated: 87
updated: 88
updated: 89
updated: 90
updated: 91
updated: 92
updated: 93
updated: 94
updated: 95
updated: 96
updated: 97
updated: 98
updated: 99
computing
[INFO]: Computing the OOD Metrics

OOD Metrics
conf.shape: torch.Size([209715200])
targets.shape: torch.Size([209715200])
scores.shape: cpu, labels.device: cpu
[INFO]: Plotting ROC curve...
[INFO]: Summary Metrics:
 {'fpr_at_tpr_lostandfound_misclass': 0.9620679020881653, 'auroc_lostandfound_misclass': 0.8787328004837036}
[INFO]: Plotting fig_roc_lostandfound_misclass....
