In [None]:
import torch
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR100
from torch.utils.data import DataLoader, Subset
from collections import deque
import random

# Load CIFAR-100 and split it into tasks for incremental learning
def get_task_dataloaders(num_tasks=10, batch_size=32):
    """Build one train and one test DataLoader per task from CIFAR-100.

    The dataset's classes are split into ``num_tasks`` contiguous groups of
    ``k = len(dataset.classes) // num_tasks`` classes: task ``t`` owns the class
    ids ``[t * k, (t + 1) * k)``. Classes left over when the division is not
    exact belong to no task. No target transform is applied, so the loaders
    yield the dataset's own class ids rather than per-task indices.

    Args:
        num_tasks: Number of tasks to split the classes into.
        batch_size: Batch size used by every returned DataLoader.

    Returns:
        ``(train_loaders, test_loaders)``: two lists of length ``num_tasks``.
        Training loaders shuffle; test loaders do not.

    Raises:
        ValueError: If ``num_tasks`` is not between 1 and the number of classes.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    dataset = CIFAR100(root='./data', train=True, download=True, transform=transform)
    test_dataset = CIFAR100(root='./data', train=False, download=True, transform=transform)

    num_classes = len(dataset.classes)
    if not 1 <= num_tasks <= num_classes:
        raise ValueError(
            f"num_tasks must be between 1 and {num_classes}, got {num_tasks}"
        )
    num_classes_per_task = num_classes // num_tasks

    def indices_by_task(targets):
        # One pass over the labels instead of one full scan per task.
        buckets = [[] for _ in range(num_tasks)]
        for index, label in enumerate(targets):
            owner = int(label) // num_classes_per_task
            # Labels past the last full group (inexact division) are dropped.
            if owner < num_tasks:
                buckets[owner].append(index)
        return buckets

    train_buckets = indices_by_task(dataset.targets)
    test_buckets = indices_by_task(test_dataset.targets)

    train_loaders, test_loaders = [], []
    for train_indices, test_indices in zip(train_buckets, test_buckets):
        train_loaders.append(
            DataLoader(Subset(dataset, train_indices), batch_size=batch_size, shuffle=True)
        )
        test_loaders.append(
            DataLoader(Subset(test_dataset, test_indices), batch_size=batch_size, shuffle=False)
        )

    print(f"Loaded {num_tasks} tasks with {num_classes_per_task} classes each.")
    return train_loaders, test_loaders

# Define replay buffer to hold samples from previous tasks
class ReplayBuffer:
    """Bounded store of past samples, kept in one FIFO queue per task.

    Each task gets ``max_size // num_tasks`` slots. Once a task's queue is full,
    adding a new item evicts that task's oldest item.
    """

    def __init__(self, max_size=2000, num_tasks=10):
        """Create ``num_tasks`` empty queues keyed ``0 .. num_tasks - 1``.

        Args:
            max_size: Total capacity, divided evenly between the tasks.
            num_tasks: Number of per-task queues to create.

        Raises:
            ValueError: If ``num_tasks`` is not positive.
        """
        if num_tasks <= 0:
            raise ValueError(f"num_tasks must be a positive integer, got {num_tasks}")
        self.max_size_per_task = max_size // num_tasks
        self.buffers = {task_id: deque(maxlen=self.max_size_per_task) for task_id in range(num_tasks)}

    def add(self, task_id, data):
        """Append ``data`` to the queue of ``task_id``.

        ``sample`` unpacks every stored item as an ``(input, label)`` pair and
        stacks the inputs, so items should be such pairs with same-shaped
        tensor inputs.

        Raises:
            KeyError: If ``task_id`` has no queue.
        """
        self.buffers[task_id].append(data)

    def sample(self, batch_size):
        """Draw up to ``batch_size`` stored samples, spread across tasks.

        The batch is shared evenly between the queues that currently hold data
        (at least one sample each), so empty queues do not shrink the batch and
        a ``batch_size`` smaller than the number of tasks still yields samples.
        If the per-task picks exceed ``batch_size``, a random subset of exactly
        ``batch_size`` is kept.

        Args:
            batch_size: Maximum number of samples to return.

        Returns:
            ``(inputs, labels)`` where ``inputs`` is the stacked input tensor and
            ``labels`` a tensor built from the stored labels, or ``(None, None)``
            when nothing is stored or ``batch_size`` is not positive.
        """
        populated = [buffer for buffer in self.buffers.values() if len(buffer) > 0]
        if batch_size <= 0 or not populated:
            return None, None

        # Split the quota among populated queues only; never round down to zero.
        samples_per_task = max(1, batch_size // len(populated))
        batch_data = []
        for buffer in populated:
            batch_data.extend(random.sample(list(buffer), min(len(buffer), samples_per_task)))

        # The one-per-task floor can overshoot a small batch_size; trim at random
        # so no task is systematically favoured.
        if len(batch_data) > batch_size:
            batch_data = random.sample(batch_data, batch_size)

        batch_inputs, batch_labels = zip(*batch_data)
        return torch.stack(batch_inputs), torch.tensor(batch_labels)

# Build the per-task CIFAR-100 loaders and the replay buffer. The buffer starts
# empty: this cell only constructs it and never calls `add`.
num_tasks = 10
batch_size = 32
train_loaders, test_loaders = get_task_dataloaders(num_tasks=num_tasks, batch_size=batch_size)
# max_size=2000 split over num_tasks queues gives 200 slots per task.
replay_buffer = ReplayBuffer(max_size=2000, num_tasks=num_tasks)



Files already downloaded and verified
Files already downloaded and verified
Loaded 10 tasks with 10 classes each.


In [None]:
import torch.nn as nn

class FeatureExtractor(nn.Module):
    """Two-layer MLP that maps a flattened input to a feature vector.

    Args:
        input_dim: Number of values per sample after flattening.
        feature_dim: Size of the returned feature vector.
        hidden_dim: Width of the hidden layer (defaults to 512).
    """

    def __init__(self, input_dim=3*32*32, feature_dim=256, hidden_dim=512):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, feature_dim)
        )

    def forward(self, x):
        """Flatten everything after the batch dimension and apply the MLP.

        Args:
            x: Batch tensor whose per-sample values total ``input_dim``.

        Returns:
            Tensor of shape ``(x.size(0), feature_dim)``.
        """
        # reshape() copies when the input is not contiguous (e.g. after a
        # permute), where view() would raise.
        flat = x.reshape(x.size(0), -1)
        return self.fc(flat)

class TaskHead(nn.Module):
    """A single linear layer that turns feature vectors into one head's logits."""

    def __init__(self, feature_dim, output_dim):
        """Create the ``feature_dim -> output_dim`` linear layer, stored as ``fc``."""
        super().__init__()
        self.fc = nn.Linear(in_features=feature_dim, out_features=output_dim)

    def forward(self, x):
        """Return the linear layer's output for ``x``."""
        logits = self.fc(x)
        return logits

class ROWModel(nn.Module):
    """Shared feature extractor with two heads per task.

    Every task has a "WP" head with ``num_classes_per_task`` outputs and an
    "OOD" head with one extra output (``num_classes_per_task + 1``).

    Args:
        input_dim: Flattened input size passed to the feature extractor.
        feature_dim: Feature size produced by the extractor and consumed by heads.
        num_classes_per_task: Number of classes each task's heads classify.
        num_tasks: Number of tasks, i.e. number of heads of each kind.
    """

    def __init__(self, input_dim=3*32*32, feature_dim=256, num_classes_per_task=10, num_tasks=10):
        super().__init__()
        self.feature_extractor = FeatureExtractor(input_dim, feature_dim)
        # WP heads for classification
        self.wp_heads = nn.ModuleList([TaskHead(feature_dim, num_classes_per_task) for _ in range(num_tasks)])
        # OOD heads for OOD detection
        self.ood_heads = nn.ModuleList([TaskHead(feature_dim, num_classes_per_task + 1) for _ in range(num_tasks)])

    def forward(self, x, task_id, head_type='wp'):
        """Run ``x`` through the extractor and the selected head of ``task_id``.

        Args:
            x: Input batch.
            task_id: Index of the task whose head is applied.
            head_type: ``'wp'`` for the WP head or ``'ood'`` for the OOD head.

        Returns:
            The selected head's output for the batch.

        Raises:
            ValueError: If ``head_type`` is neither ``'wp'`` nor ``'ood'``
                (previously this silently returned ``None``).
        """
        if head_type == 'wp':
            heads = self.wp_heads
        elif head_type == 'ood':
            heads = self.ood_heads
        else:
            raise ValueError(f"Unknown head_type {head_type!r}; expected 'wp' or 'ood'.")

        features = self.feature_extractor(x)
        return heads[task_id](features)



In [None]:
import torch.optim as optim
import torch.nn.functional as F

def train_feature_and_ood_head(model, task_id, optimizer, train_loader, replay_buffer, epochs=10, device=None):
    """Train one task's OOD head on its own data plus replayed samples.

    Each batch from ``train_loader`` is treated as in-distribution (IND). A
    batch sampled from ``replay_buffer`` is appended with every sample
    relabelled to the head's last output index, which serves as the OOD class.
    Which parameters are updated is decided by ``optimizer``.

    Args:
        model: Model exposing ``ood_heads[task_id].fc.out_features`` and
            callable as ``model(x, task_id, head_type='ood')``.
        task_id: Index of the task being trained.
        optimizer: Optimizer stepping the parameters to train.
        train_loader: Iterable of ``(inputs, labels)`` batches for this task.
            Labels are expected to be dataset-wide class ids from a contiguous
            split (task ``t`` owns ``[t * k, (t + 1) * k)``); they are shifted
            into the head's local range ``[0, k)``.
        replay_buffer: Object whose ``sample(n)`` returns ``(inputs, labels)``
            or ``(None, None)`` when it has nothing to offer.
        epochs: Number of passes over ``train_loader``.
        device: Device to run on. Defaults to the device of the model's
            parameters, so the function no longer depends on a global.

    NOTE(review): the replay batch is used as OOD data without checking which
    task it came from. If the buffer also stores samples of ``task_id``, they
    are trained as OOD here; confirm that is intended.
    """
    if device is None:
        device = next(model.parameters()).device

    # The OOD head has one output per IND class plus a final OOD output.
    ood_class = model.ood_heads[task_id].fc.out_features - 1
    # Dataset labels are global ids; the head only indexes this task's classes.
    label_offset = task_id * ood_class
    num_batches = len(train_loader)

    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        correct_ind = correct_ood = 0
        total_ind = total_ood = 0

        for ind_data, ind_labels in train_loader:
            ind_data = ind_data.to(device)
            ind_labels = ind_labels.to(device) - label_offset

            # Sample OOD data from replay buffer (may be empty).
            ood_data, ood_labels = replay_buffer.sample(len(ind_data))
            has_ood = ood_data is not None and ood_labels is not None

            if has_ood:
                ood_data = ood_data.to(device)
                # Build fresh targets instead of overwriting the sampled labels.
                ood_labels = torch.full(
                    (ood_data.size(0),), ood_class, dtype=ind_labels.dtype, device=device
                )
                combined_data = torch.cat([ind_data, ood_data])
                combined_labels = torch.cat([ind_labels, ood_labels])
            else:
                # Nothing to replay: train on IND data only for this batch.
                combined_data, combined_labels = ind_data, ind_labels

            # Forward pass and compute loss
            optimizer.zero_grad()
            ood_outputs = model(combined_data, task_id, head_type='ood')
            loss = F.cross_entropy(ood_outputs, combined_labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

            # IND samples come first in the combined batch, OOD samples after.
            predicted = ood_outputs.argmax(dim=1)
            num_ind = len(ind_labels)
            correct_ind += (predicted[:num_ind] == ind_labels).sum().item()
            total_ind += num_ind
            if has_ood:
                correct_ood += (predicted[num_ind:] == ood_labels).sum().item()
                total_ood += len(ood_labels)

        mean_loss = total_loss / num_batches if num_batches > 0 else 0.0
        ind_accuracy = 100 * correct_ind / total_ind if total_ind > 0 else 0
        ood_accuracy = 100 * correct_ood / total_ood if total_ood > 0 else 0
        print(f"Epoch [{epoch+1}/{epochs}] - Loss: {mean_loss:.4f}, "
              f"IND Accuracy: {ind_accuracy:.2f}%, OOD Accuracy: {ood_accuracy:.2f}%")

# Initialize model, optimizer, and device.
# Use the GPU when one is available, otherwise the CPU; the model is moved there.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ROWModel().to(device)
# This optimizer covers every parameter of the model (extractor and all heads).
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Run training for task 0.
# No earlier cell adds samples to replay_buffer, so no replayed (OOD) data is
# mixed in during this run.
train_feature_and_ood_head(
    model,
    task_id=0,
    optimizer=optimizer,
    train_loader=train_loaders[0],
    replay_buffer=replay_buffer,
    epochs=10
)



Epoch [1/10] - Loss: 1.8123, IND Accuracy: 38.10%, OOD Accuracy: 0.00%
Epoch [2/10] - Loss: 1.4943, IND Accuracy: 49.54%, OOD Accuracy: 0.00%
Epoch [3/10] - Loss: 1.3180, IND Accuracy: 55.34%, OOD Accuracy: 0.00%
Epoch [4/10] - Loss: 1.2134, IND Accuracy: 60.48%, OOD Accuracy: 0.00%
Epoch [5/10] - Loss: 1.1041, IND Accuracy: 64.08%, OOD Accuracy: 0.00%
Epoch [6/10] - Loss: 0.9851, IND Accuracy: 67.90%, OOD Accuracy: 0.00%
Epoch [7/10] - Loss: 0.8587, IND Accuracy: 72.80%, OOD Accuracy: 0.00%
Epoch [8/10] - Loss: 0.7603, IND Accuracy: 75.48%, OOD Accuracy: 0.00%
Epoch [9/10] - Loss: 0.6748, IND Accuracy: 78.42%, OOD Accuracy: 0.00%
Epoch [10/10] - Loss: 0.6448, IND Accuracy: 80.14%, OOD Accuracy: 0.00%


In [None]:
def fine_tune_wp_head(model, task_id, optimizer, train_loader, epochs=10, device=None):
    """
    Fine-tunes the WP head for a specific task using only in-distribution data.

    The feature extractor is frozen for the duration of the call. Afterwards
    every extractor parameter gets back the ``requires_grad`` value it had on
    entry, even if training raises.

    Args:
    - model: ROWModel instance.
    - task_id: ID of the current task.
    - optimizer: Optimizer for training the WP head.
    - train_loader: DataLoader for current task's IND data. Labels are expected
      to be dataset-wide class ids from a contiguous split (task ``t`` owns
      ``[t * k, (t + 1) * k)``); they are shifted into the head's range ``[0, k)``.
    - epochs: Number of epochs for fine-tuning.
    - device: Device to run on. Defaults to the device of the model's
      parameters, so the function no longer depends on a global.
    """
    if device is None:
        device = next(model.parameters()).device

    # Dataset labels are global ids; the WP head only indexes this task's classes.
    label_offset = task_id * model.wp_heads[task_id].fc.out_features
    num_batches = len(train_loader)

    # Freeze the feature extractor, remembering each parameter's previous state.
    extractor_params = list(model.feature_extractor.parameters())
    previous_flags = [param.requires_grad for param in extractor_params]
    for param in extractor_params:
        param.requires_grad = False

    try:
        model.train()
        for epoch in range(epochs):
            total_loss = 0.0
            correct, total = 0, 0

            for ind_data, ind_labels in train_loader:
                ind_data = ind_data.to(device)
                ind_labels = ind_labels.to(device) - label_offset

                # Forward pass through the WP head
                optimizer.zero_grad()
                wp_outputs = model(ind_data, task_id, head_type='wp')

                # Compute cross-entropy loss for WP head on IND data
                loss = F.cross_entropy(wp_outputs, ind_labels)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

                # Calculate accuracy
                predicted = wp_outputs.argmax(dim=1)
                correct += (predicted == ind_labels).sum().item()
                total += ind_labels.size(0)

            mean_loss = total_loss / num_batches if num_batches > 0 else 0.0
            accuracy = 100 * correct / total if total > 0 else 0
            print(f"Epoch [{epoch+1}/{epochs}] - WP Head Loss: {mean_loss:.4f}, "
                  f"Accuracy: {accuracy:.2f}%")
    finally:
        # Restore rather than blanket-unfreeze, so parameters that were already
        # frozen before the call stay frozen.
        for param, flag in zip(extractor_params, previous_flags):
            param.requires_grad = flag

# Example usage for fine-tuning WP head of task 0.
# This rebinds the notebook-level `optimizer` name to a new Adam instance that
# only holds the parameters of task 0's WP head.
optimizer = optim.Adam(model.wp_heads[0].parameters(), lr=0.001)  # Optimizer for only WP head parameters

# Fine-tune WP head for task 0
fine_tune_wp_head(
    model,
    task_id=0,
    optimizer=optimizer,
    train_loader=train_loaders[0],
    epochs=10
)


Epoch [1/10] - WP Head Loss: 0.5163, Accuracy: 85.28%
Epoch [2/10] - WP Head Loss: 0.2826, Accuracy: 91.34%
Epoch [3/10] - WP Head Loss: 0.2484, Accuracy: 92.18%
Epoch [4/10] - WP Head Loss: 0.2296, Accuracy: 92.84%
Epoch [5/10] - WP Head Loss: 0.2131, Accuracy: 93.50%
Epoch [6/10] - WP Head Loss: 0.2014, Accuracy: 93.78%
Epoch [7/10] - WP Head Loss: 0.1945, Accuracy: 94.26%
Epoch [8/10] - WP Head Loss: 0.1861, Accuracy: 94.38%
Epoch [9/10] - WP Head Loss: 0.1789, Accuracy: 94.78%
Epoch [10/10] - WP Head Loss: 0.1703, Accuracy: 95.14%


In [None]:
import torch
import torch.nn.functional as F
import torch.optim as optim

def fine_tune_previous_ood_heads(model, current_task_id, optimizer, replay_buffer, epochs=3, batch_size=32, device='cpu'):
    """
    Fine-tunes the OOD heads for previous tasks using replay data.

    For every task before ``current_task_id``, each "epoch" samples one batch
    from the replay buffer, labels every sample with that head's last output
    index (the OOD class) and takes a single optimizer step. Errors raised by
    the training step propagate to the caller instead of being printed and
    ignored.

    Args:
    - model: ROWModel instance.
    - current_task_id: The task ID that was most recently trained (all prior tasks will be fine-tuned).
    - optimizer: Optimizer for training.
    - replay_buffer: ReplayBuffer instance holding replay data.
    - epochs: Number of epochs to fine-tune.
    - batch_size: Size of the batch to sample from replay buffer.
    - device: The device to perform computations on (e.g., 'cpu' or 'cuda').

    NOTE(review): every replayed sample is labelled OOD for the head being
    tuned, whichever task it came from. If the buffer also holds samples of
    that head's own task, they are trained as OOD too; confirm this is intended.
    """
    model.train()

    # Loop over previous tasks
    for task_id in range(current_task_id):
        # The last output of the OOD head is the OOD class.
        ood_class = model.ood_heads[task_id].fc.out_features - 1

        for epoch in range(epochs):
            # Sample a batch from the replay buffer
            batch = replay_buffer.sample(batch_size=batch_size)

            # Nothing in this function adds to the buffer, so an empty result
            # will not change on later epochs: report once and move on.
            if batch is None or len(batch) < 2 or batch[0] is None or batch[1] is None:
                print(f"No valid replay data available for task {task_id}, skipping fine-tuning.")
                break

            ood_data, ood_labels = batch[0], batch[1]

            # Accept array-likes as well as tensors.
            if not isinstance(ood_data, torch.Tensor):
                ood_data = torch.tensor(ood_data, dtype=torch.float32)
            if not isinstance(ood_labels, torch.Tensor):
                ood_labels = torch.tensor(ood_labels, dtype=torch.long)

            ood_data = ood_data.to(device)
            # Build new targets rather than overwriting the sampled label tensor
            # in place (on CPU, `.to(device)` returns the caller's own tensor).
            ood_targets = torch.full_like(ood_labels.to(device), ood_class)

            optimizer.zero_grad()
            ood_outputs = model(ood_data, task_id, head_type='ood')

            # Compute loss for OOD head fine-tuning
            loss = F.cross_entropy(ood_outputs, ood_targets)
            loss.backward()
            optimizer.step()

            print(f"Task {task_id}, Epoch [{epoch + 1}/{epochs}], Fine-Tuning OOD Loss: {loss.item():.4f}")

# Example usage for fine-tuning previous OOD heads after training a new task.
# NOTE(review): this rebinds `optimizer` to one covering every model parameter,
# so the shared feature extractor is stepped as well; confirm that is intended.
optimizer = optim.Adam(model.parameters(), lr=0.001)
# With current_task_id=1 only task 0's OOD head is visited; epochs=10 overrides
# the function's default of 3.
fine_tune_previous_ood_heads(
    model,
    current_task_id=1,
    optimizer=optimizer,
    replay_buffer=replay_buffer,
    epochs=10,
    batch_size=32,
    device='cuda' if torch.cuda.is_available() else 'cpu'
)






Task 0, Epoch 1: Raw batch output: (None, None)
No valid replay data available for task 0, skipping fine-tuning.
Task 0, Epoch 2: Raw batch output: (None, None)
No valid replay data available for task 0, skipping fine-tuning.
Task 0, Epoch 3: Raw batch output: (None, None)
No valid replay data available for task 0, skipping fine-tuning.
Task 0, Epoch 4: Raw batch output: (None, None)
No valid replay data available for task 0, skipping fine-tuning.
Task 0, Epoch 5: Raw batch output: (None, None)
No valid replay data available for task 0, skipping fine-tuning.
Task 0, Epoch 6: Raw batch output: (None, None)
No valid replay data available for task 0, skipping fine-tuning.
Task 0, Epoch 7: Raw batch output: (None, None)
No valid replay data available for task 0, skipping fine-tuning.
Task 0, Epoch 8: Raw batch output: (None, None)
No valid replay data available for task 0, skipping fine-tuning.
Task 0, Epoch 9: Raw batch output: (None, None)
No valid replay data available for task 0, skipp