In [29]:
import os

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.utils.data import DataLoader, TensorDataset
from avalanche.benchmarks.datasets import MNIST
from avalanche.benchmarks.datasets.dataset_utils import default_dataset_location
from avalanche.benchmarks.utils import as_classification_dataset, AvalancheDataset


from avalanche.benchmarks.classic import SplitMNIST
from avalanche.benchmarks.generators import nc_benchmark
from avalanche.models import SimpleMLP
from avalanche.training import Naive
from avalanche.training.plugins import (
    ReplayPlugin,
    EWCPlugin,
    AGEMPlugin,
    EvaluationPlugin,
)
from avalanche.evaluation.metrics import (
    forgetting_metrics,
    accuracy_metrics,
    loss_metrics,
    timing_metrics,
    cpu_usage_metrics,
    confusion_matrix_metrics,
    disk_usage_metrics,
)
from avalanche.logging import InteractiveLogger

from pytorch_ood.detector import OpenMax, EnergyBased, Entropy
from pytorch_ood.utils import OODMetrics
from torch.utils.data import DataLoader


import numpy as np
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
from copy import deepcopy
from pytorch_ood.model import WideResNet

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from torchvision.transforms import ToTensor
from avalanche.benchmarks.datasets import MNIST
from avalanche.benchmarks.datasets.dataset_utils import default_dataset_location
import torch
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset, ConcatDataset
from pytorch_ood.utils import OODMetrics
from pytorch_ood.dataset.img import Textures
from pytorch_ood.utils import ToUnknown
from pytorch_ood.model import WideResNet
from avalanche.models import SimpleMLP
import torch.nn as nn
import torch.optim as optim

In [121]:
import torch
import numpy as np
from torchvision.datasets import MNIST
from torch.utils.data import TensorDataset, DataLoader
from torchvision.transforms.functional import to_tensor


def remove_channel_dimension(dataset):
    modified_data = []
    for data in dataset:
        modified_data.append(data.squeeze(-1))  # Remove the last channel dimension
    return torch.stack(modified_data)  # Stack the modified data and return as a tensor


def load_data(batch_size=64, seed=0):
    # Location to save/load the MNIST dataset
    datadir = default_dataset_location("mnist")

    # Load the non-corrupted MNIST dataset
    train_MNIST = MNIST(datadir, train=True, download=True)
    test_MNIST = MNIST(datadir, train=False, download=True)

    # Extract train and test data/labels
    train_data = train_MNIST.data.float() / 255  # Normalize data to [0, 1]
    train_labels = train_MNIST.targets
    test_data = test_MNIST.data.float() / 255  # Normalize data to [0, 1]
    test_labels = test_MNIST.targets

    # Load corrupted data and labels
    c_test_images = (
        np.load("./brightness/test/test_images.npy").astype(np.float32) / 255
    )  # Normalize
    c_test_labels = np.load("./brightness/test/test_labels.npy")
    # c_train_images = (
    #     np.load("./brightness/train/train_images.npy").astype(np.float32) / 255
    # )  # Normalize
    # c_train_labels = np.load("./brightness/train/train_labels.npy")

    # Convert NumPy arrays to tensors and remove channel dimension for images
    c_test_images_tensor = remove_channel_dimension(torch.tensor(c_test_images))
    # c_train_images_tensor = remove_channel_dimension(torch.tensor(c_train_images))

    # Apply the specified mapping to the corrupted labels
    def map_labels(labels):
        return torch.tensor([10 if label == 0 else label+10 for label in labels])

    c_test_labels_tensor = map_labels(c_test_labels)
    # c_train_labels_tensor = map_labels(c_train_labels)

    # Combine non-corrupted and corrupted data
    combined_test_data = torch.cat(
        [test_data, c_test_images_tensor], dim=0
    )  # Add channel dimension
    combined_test_labels = torch.cat([test_labels, c_test_labels_tensor], dim=0)
    
    # combined_train_data = torch.cat(
    #     [train_data, c_train_images_tensor], dim=0
    # )  # Add channel dimension
    # combined_train_labels = torch.cat([train_labels, c_train_labels_tensor], dim=0)

    # Create TensorDataset objects
    # train_dataset = TensorDataset(combined_train_data, combined_train_labels)
    
    combined_test_dataset = TensorDataset(combined_test_data, combined_test_labels)

    # Create a train Dataset
    train_dataset = TensorDataset(train_data, train_labels)
    # Create DataLoader objects
    train_dataLoader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    
    
    
    test_dataLoader = DataLoader(
        combined_test_dataset, batch_size=batch_size, shuffle=False
    )

    # desired_order = [0, -10, 1, -1, 2, -2, 3, -3, 4, -4, 5, -5, 6, -6, 7, -7, 8, -8, 9, -9]
    
    desired_order = [0, 10, 1, 11, 2, 12, 3, 13, 4, 14, 5, 15, 6, 16, 7, 17, 8, 18, 9, 19]

    scenario = nc_benchmark(
        combined_test_dataset,
        combined_test_dataset,
        n_experiences=10,
        shuffle=False,
        seed=seed,
        fixed_class_order=desired_order,
        task_labels=True,
    )

    return train_dataLoader, test_dataLoader, scenario 


# Example usage
batch_size = 64
train_dataLoader, test_dataLoader, scenario= load_data(batch_size=batch_size)

# labels_tensor = test_dataLoader.dataset.tensors[1]

# # Count the number of occurrences of each class
# unique_classes, counts = torch.unique(labels_tensor, return_counts=True)

# # Print the counts for each class
# for class_index, count in zip(unique_classes, counts):
#     print(f"Class {class_index.item()}: {count.item()} samples")
print(scenario.original_classes_in_exp)
print(scenario.task_labels)
print("Train DataLoader and Test DataLoader have been created successfully")

[{0, 10}, {1, 11}, {2, 12}, {3, 13}, {4, 14}, {5, 15}, {16, 6}, {17, 7}, {8, 18}, {9, 19}]
[[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]
Train DataLoader and Test DataLoader have been created successfully


In [127]:
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(
            64, 10
        )  # Output size matches the number of desired classes

    def forward(self, x):
        x = x.view(-1, 28 * 28)  # Flatten the input images
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

def train_model(train_loader):
    # model = SimpleNN()
    model = SimpleMLP()

    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Training loop
    num_epochs = 10
    for epoch in range(num_epochs):
        running_loss = 0.0
        for images, labels in train_loader:
            optimizer.zero_grad()  # Zero the gradients
            outputs = model(images)  # Forward pass
            loss = criterion(outputs, labels)  # Calculate the loss
            loss.backward()  # Backward pass
            optimizer.step()  # Update weights
            running_loss += loss.item()

        # # Print average loss per epoch
        # print(
        #     f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / len(train_loader)}"
        # )
    return model

In [129]:
def map_test(labels):
    return torch.tensor([-1 if label > 9 else label for label in labels])


test_stream = scenario.test_stream

large_metrics = OODMetrics()
model = train_model(train_dataLoader)
detector = OpenMax(model, tailsize=25, alpha=5, euclid_weight=0.5)
detector.fit(train_dataLoader)

print(scenario.original_classes_in_exp)

for exp in scenario.train_stream:
    print(exp.classes_in_this_experience)
    test_loader = DataLoader(exp.dataset, batch_size=128, shuffle=True)
    narrow_metrics = OODMetrics()
    for batch in test_loader:
        if len(batch) == 2:
            x, y = batch
        else:
            x, y, *_ = batch
        y = map_test(y)
        large_metrics.update(detector(x),y)
        narrow_metrics.update(detector(x),y)
    print(narrow_metrics.compute())
print(large_metrics.compute())

[{0, 10}, {1, 11}, {2, 12}, {3, 13}, {4, 14}, {5, 15}, {16, 6}, {17, 7}, {8, 18}, {9, 19}]
[0, 10]
{'AUROC': 0.9558725357055664, 'AUPR-IN': 0.9455128908157349, 'AUPR-OUT': 0.9594336748123169, 'FPR95TPR': 0.19081632792949677}
[1, 11]
{'AUROC': 0.7978178858757019, 'AUPR-IN': 0.848194420337677, 'AUPR-OUT': 0.7332760691642761, 'FPR95TPR': 0.8440528512001038}
[2, 12]
{'AUROC': 0.8337783813476562, 'AUPR-IN': 0.8581138849258423, 'AUPR-OUT': 0.8255239725112915, 'FPR95TPR': 0.6346899271011353}
[3, 13]
{'AUROC': 0.8885226249694824, 'AUPR-IN': 0.9102345705032349, 'AUPR-OUT': 0.8757404088973999, 'FPR95TPR': 0.5267326831817627}
[4, 14]
{'AUROC': 0.7842042446136475, 'AUPR-IN': 0.7521378397941589, 'AUPR-OUT': 0.8108094334602356, 'FPR95TPR': 0.5997963547706604}
[5, 15]
{'AUROC': 0.8753204941749573, 'AUPR-IN': 0.8914667963981628, 'AUPR-OUT': 0.8674235343933105, 'FPR95TPR': 0.5728699564933777}
[16, 6]
{'AUROC': 0.8973678350448608, 'AUPR-IN': 0.8991154432296753, 'AUPR-OUT': 0.8994457721710205, 'FPR95TPR'