In [None]:
!pip install torch



In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

class FeatureExtractor(nn.Module):
    """Small MLP that turns raw inputs into feature vectors.

    Architecture: Linear(input_dim, 128) -> ReLU -> Linear(128, feature_dim).
    """

    def __init__(self, input_dim, feature_dim):
        super().__init__()
        hidden_dim = 128
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, feature_dim),
        ]
        # Stored as `fc` so parameter names remain `fc.0.*` and `fc.2.*`.
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return the features the MLP computes for `x`."""
        features = self.fc(x)
        return features

class TaskHead(nn.Module):
    """Single linear layer that maps feature vectors to per-class logits."""

    def __init__(self, feature_dim, output_dim):
        super().__init__()
        # One output unit per class; the layer is exposed under the name `fc`.
        self.fc = nn.Linear(in_features=feature_dim, out_features=output_dim)

    def forward(self, x):
        """Return the logits for the feature batch `x`."""
        logits = self.fc(x)
        return logits

class ROWModel(nn.Module):
    """Shared feature extractor with one WP head and one OOD head per task.

    Args:
    - input_dim: Size of each input vector.
    - feature_dim: Size of the feature vector produced by the extractor.
    - num_tasks: Number of tasks; one WP head and one OOD head are created per task.
    - task_classes: Sequence giving the number of classes of each task.
    """

    def __init__(self, input_dim, feature_dim, num_tasks, task_classes):
        super(ROWModel, self).__init__()
        self.feature_extractor = FeatureExtractor(input_dim, feature_dim)
        # WP head i has task_classes[i] outputs; OOD head i has one extra output.
        self.wp_heads = nn.ModuleList([TaskHead(feature_dim, task_classes[i]) for i in range(num_tasks)])
        self.ood_heads = nn.ModuleList([TaskHead(feature_dim, task_classes[i] + 1) for i in range(num_tasks)])

    def forward(self, x, task_id, head_type='wp'):
        """Return the logits of the selected head of task `task_id` for input `x`.

        Args:
        - x: Input batch fed to the feature extractor.
        - task_id: Index of the task whose head is used.
        - head_type: 'wp' for the WP head or 'ood' for the OOD head.

        Raises:
        - ValueError: If `head_type` is neither 'wp' nor 'ood'.
        """
        # Reject unknown head types explicitly instead of silently returning None.
        if head_type == 'wp':
            heads = self.wp_heads
        elif head_type == 'ood':
            heads = self.ood_heads
        else:
            raise ValueError(f"Unknown head_type {head_type!r}; expected 'wp' or 'ood'.")

        features = self.feature_extractor(x)
        return heads[task_id](features)

# Accuracy helper for a single task head
def evaluate_accuracy(model, data_loader, task_id, head_type='wp', device='cpu'):
    """
    Computes, prints and returns the classification accuracy (in percent) of one head.

    Args:
    - model: Callable module invoked as model(inputs, task_id, head_type=...).
    - data_loader: Iterable yielding (inputs, labels) batches.
    - task_id: Task whose head is evaluated.
    - head_type: Head selector forwarded to the model.
    - device: Device the batches are moved to before the forward pass.

    Returns:
    - Accuracy in percent, or 0 when the loader yields no samples.
    """
    model.eval()  # evaluation mode
    n_correct = 0
    n_seen = 0

    with torch.no_grad():  # no gradients needed while scoring
        for batch_inputs, batch_labels in data_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            logits = model(batch_inputs, task_id, head_type=head_type)
            predictions = logits.argmax(dim=1)
            n_correct += torch.eq(predictions, batch_labels).sum().item()
            n_seen += batch_labels.size(0)

    if n_seen > 0:
        accuracy = 100 * n_correct / n_seen
    else:
        accuracy = 0
    print(f"Accuracy for Task {task_id} ({head_type}): {accuracy:.2f}%")
    return accuracy

# Example parameters (adjust as needed)
input_dim = 784  # Example for a flattened 28x28 input
feature_dim = 256
num_tasks = 5
task_classes = [2, 2, 2, 2, 2]  # Number of classes per task; one entry per task

# Initialize the model on the GPU when CUDA is available, otherwise on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ROWModel(input_dim, feature_dim, num_tasks, task_classes).to(device)

# Example data loader for testing accuracy (replace with actual data loader)
from torch.utils.data import DataLoader, TensorDataset

# Dummy data for testing: random inputs paired with random labels.
# No random seed is set in this cell, so the reported accuracy varies from run to run.
x_test = torch.randn(100, input_dim)  # 100 samples, each of size `input_dim`
y_test = torch.randint(0, 2, (100,))  # 100 labels drawn from {0, 1}

# Create a DataLoader (batches of 32, no shuffling)
test_loader = DataLoader(TensorDataset(x_test, y_test), batch_size=32)

# Check accuracy for task 0 on WP head; the model is untrained at this point
evaluate_accuracy(model, test_loader, task_id=0, head_type='wp', device=device)



Accuracy for Task 0 (wp): 52.00%


52.0

In [None]:
import torch
import torch.nn.functional as F
import random
from torch.utils.data import DataLoader, TensorDataset

# Replay buffer used to keep samples from earlier tasks
class ReplayBuffer:
    """Bounded first-in-first-out store of (data, label) pairs."""

    def __init__(self, max_size):
        self.max_size = max_size
        self.buffer = []

    def add(self, data):
        """Append `data` and evict the oldest entry if capacity is exceeded."""
        self.buffer.append(data)
        if len(self.buffer) > self.max_size:
            del self.buffer[0]

    def sample(self, batch_size):
        """Draw up to `batch_size` distinct entries at random.

        Returns:
        - (data, labels): stacked data tensor and label tensor, or (None, None)
          when the buffer is empty.
        """
        if not self.buffer:
            return None, None
        n_draw = min(len(self.buffer), batch_size)
        picked = random.sample(self.buffer, n_draw)
        data, labels = zip(*picked)
        stacked_data = torch.stack(data)
        label_tensor = torch.tensor(labels)
        return stacked_data, label_tensor

# Training function for Step 1
def train_feature_and_ood(model, task_id, optimizer, current_task_data, replay_buffer, epochs=5, batch_size=32):
    """
    Trains the feature extractor and OOD head using current task data (IND) and replay buffer (OOD).

    Replay samples are relabelled with the last output index of the task's OOD head.
    When the replay buffer is empty (e.g. for the first task) the batch is still
    trained on the IND data alone instead of being skipped.

    Args:
    - model: ROWModel instance.
    - task_id: Current task identifier.
    - optimizer: Optimizer for training.
    - current_task_data: DataLoader containing current task's data.
    - replay_buffer: ReplayBuffer instance with data from previous tasks.
    - epochs: Number of epochs to train.
    - batch_size: Number of replay samples requested per IND batch.
    """
    # Follow the model's own device instead of depending on a notebook-level `device` global.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else None

    def to_model_device(tensor):
        return tensor if model_device is None else tensor.to(model_device)

    # The last output of the OOD head is the class used for replayed (OOD) samples.
    ood_class_index = model.ood_heads[task_id].fc.out_features - 1

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0
        batch_count = 0

        for ind_data, ind_labels in current_task_data:
            ind_data, ind_labels = to_model_device(ind_data), to_model_device(ind_labels)

            # Sample OOD data from replay buffer
            ood_data, ood_labels = replay_buffer.sample(batch_size)
            if ood_data is None or ood_labels is None:
                # No replay data available: keep learning from the IND batch only.
                print("Replay buffer is empty, skipping OOD training for this batch.")
                combined_data = ind_data
                combined_labels = ind_labels
            else:
                ood_data, ood_labels = to_model_device(ood_data), to_model_device(ood_labels)

                # Combine IND and OOD data
                combined_data = torch.cat([ind_data, ood_data])
                combined_labels = torch.cat([ind_labels, ood_labels])

                # Set OOD labels to a unique class index in the OOD head
                combined_labels[len(ind_labels):] = ood_class_index

            # Train on OOD head
            optimizer.zero_grad()
            ood_outputs = model(combined_data, task_id, head_type='ood')

            # Compute cross-entropy loss over the combined batch
            loss = F.cross_entropy(ood_outputs, combined_labels)
            loss.backward()
            optimizer.step()

            # Accumulate loss for epoch-level summary
            epoch_loss += loss.item()
            batch_count += 1

            # Print batch-level loss
            print(f"Epoch [{epoch+1}/{epochs}], Batch [{batch_count}], Batch Loss: {loss.item():.4f}")

        # Print average loss for the epoch
        avg_epoch_loss = epoch_loss / batch_count if batch_count > 0 else 0
        print(f"Epoch [{epoch+1}/{epochs}] completed. Average Loss: {avg_epoch_loss:.4f}\n")

# Example usage: one run of Step 1 training on dummy data
# Set device (GPU when CUDA is available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Assume `model` is already defined (it comes from an earlier cell) and move it to the device
model.to(device)

# Define optimizer over all model parameters (feature extractor and every head)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Define batch size (used for the DataLoader and for replay sampling)
batch_size = 32

# Dummy current task data and replay buffer
input_dim = 784  # Example input dimension, such as a flattened 28x28 image
x_current = torch.randn(100, input_dim)  # 100 samples
y_current = torch.randint(0, 2, (100,))  # Labels drawn from {0, 1}
current_task_data = DataLoader(TensorDataset(x_current, y_current), batch_size=batch_size)

# Initialize replay buffer and add some dummy data (50 random vectors)
replay_buffer = ReplayBuffer(max_size=200)
for i in range(50):
    x_replay = torch.randn(input_dim)
    y_replay = torch.tensor(2)  # Example OOD label (2 is the last output index of a 3-way OOD head)
    replay_buffer.add((x_replay, y_replay))

# Train model on the task (task 0, 5 epochs)
train_feature_and_ood(model, task_id=0, optimizer=optimizer, current_task_data=current_task_data, replay_buffer=replay_buffer, epochs=5, batch_size=batch_size)




Epoch [1/5], Batch [1], Batch Loss: 1.0778
Epoch [1/5], Batch [2], Batch Loss: 1.0185
Epoch [1/5], Batch [3], Batch Loss: 0.9639
Epoch [1/5], Batch [4], Batch Loss: 0.5875
Epoch [1/5] completed. Average Loss: 0.9119

Epoch [2/5], Batch [1], Batch Loss: 0.7373
Epoch [2/5], Batch [2], Batch Loss: 0.6982
Epoch [2/5], Batch [3], Batch Loss: 0.7139
Epoch [2/5], Batch [4], Batch Loss: 0.2400
Epoch [2/5] completed. Average Loss: 0.5973

Epoch [3/5], Batch [1], Batch Loss: 0.4897
Epoch [3/5], Batch [2], Batch Loss: 0.4232
Epoch [3/5], Batch [3], Batch Loss: 0.4251
Epoch [3/5], Batch [4], Batch Loss: 0.0931
Epoch [3/5] completed. Average Loss: 0.3578

Epoch [4/5], Batch [1], Batch Loss: 0.2345
Epoch [4/5], Batch [2], Batch Loss: 0.1825
Epoch [4/5], Batch [3], Batch Loss: 0.1864
Epoch [4/5], Batch [4], Batch Loss: 0.0294
Epoch [4/5] completed. Average Loss: 0.1582

Epoch [5/5], Batch [1], Batch Loss: 0.0873
Epoch [5/5], Batch [2], Batch Loss: 0.0632
Epoch [5/5], Batch [3], Batch Loss: 0.0657
Epo

In [None]:
import torch
import torch.nn.functional as F

def fine_tune_wp_head(model, task_id, optimizer, current_task_data, device, epochs=5):
    """
    Fine-tunes the WP head for the current task using only the in-distribution data.

    The feature extractor is frozen for the duration of the call and its
    parameters' previous `requires_grad` flags are restored afterwards, even if
    training raises.

    Args:
    - model: ROWModel instance.
    - task_id: Current task identifier.
    - optimizer: Optimizer for training.
    - current_task_data: DataLoader containing current task's data.
    - device: Torch device (CPU or GPU).
    - epochs: Number of epochs to train.
    """
    # Check for data BEFORE freezing, so an early return cannot leave the extractor frozen.
    if len(current_task_data) == 0:
        print("DataLoader is empty.")
        return

    # Freeze the feature extractor, remembering each parameter's previous flag.
    backbone_params = list(model.feature_extractor.parameters())
    previous_flags = [param.requires_grad for param in backbone_params]
    for param in backbone_params:
        param.requires_grad = False

    try:
        # Begin fine-tuning process
        for epoch in range(epochs):
            model.train()
            epoch_loss = 0
            correct, total = 0, 0

            for batch_idx, (ind_data, ind_labels) in enumerate(current_task_data):
                ind_data, ind_labels = ind_data.to(device), ind_labels.to(device)

                # Drop gradients entirely (None rather than zeros) so that frozen
                # parameters are skipped by the optimizer instead of being moved by
                # momentum/Adam state accumulated in earlier training.
                optimizer.zero_grad(set_to_none=True)
                wp_outputs = model(ind_data, task_id, head_type='wp')

                # Compute cross-entropy loss for WP head on IND data
                loss = F.cross_entropy(wp_outputs, ind_labels)
                loss.backward()
                optimizer.step()

                # Accumulate loss for epoch
                epoch_loss += loss.item()

                # Calculate accuracy for the current batch
                _, predicted = torch.max(wp_outputs, 1)
                correct += (predicted == ind_labels).sum().item()
                total += ind_labels.size(0)

                # Print batch-level debug information
                print(f"Epoch [{epoch+1}/{epochs}], Batch [{batch_idx+1}], Batch Loss: {loss.item():.4f}")

            # Calculate and print average loss and accuracy for the epoch
            avg_loss = epoch_loss / len(current_task_data)
            accuracy = 100 * correct / total if total > 0 else 0
            print(f"Epoch [{epoch+1}/{epochs}] completed. WP Head Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")
    finally:
        # Restore the feature extractor to the state it was in before this call.
        for param, flag in zip(backbone_params, previous_flags):
            param.requires_grad = flag


In [None]:
def fine_tune_ood_heads(model, optimizer, replay_buffer, current_task_id, epochs=3, batch_size=32):
    """
    Fine-tunes the OOD heads of all previous tasks using the replay buffer.

    For every task id below `current_task_id`, each epoch draws
    ceil(len(buffer) / batch_size) random batches from the replay buffer, labels
    every drawn sample with the last output index of that task's OOD head, and
    takes one optimizer step per batch.

    Args:
    - model: ROWModel instance.
    - optimizer: Optimizer for training.
    - replay_buffer: ReplayBuffer instance with data from previous tasks.
    - current_task_id: The current task identifier.
    - epochs: Number of epochs to train.
    - batch_size: Batch size for training.

    Raises:
    - ValueError: If `batch_size` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    num_stored = len(replay_buffer.buffer)
    if num_stored == 0:
        print("Replay buffer is empty, skipping OOD head fine-tuning.")
        return

    # Follow the model's own device instead of depending on a notebook-level `device` global.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else None

    def to_model_device(tensor):
        return tensor if model_device is None else tensor.to(model_device)

    # Number of sampled batches that make up one epoch (ceil division).
    batches_per_epoch = (num_stored + batch_size - 1) // batch_size

    # NOTE(review): every replayed sample is labelled as OOD for every previous head,
    # including samples that may originate from that head's own task, and the feature
    # extractor is not frozen here - confirm both match the intended training recipe.
    model.train()
    for task_id in range(current_task_id):  # Only fine-tune up to the previous task heads
        # TaskHead wraps its linear layer as `.fc`; the last output is the OOD class.
        ood_class_index = model.ood_heads[task_id].fc.out_features - 1

        for epoch in range(epochs):
            epoch_loss = 0
            batches_run = 0
            for _step in range(batches_per_epoch):
                # sample() returns ONE (data, labels) batch, not an iterable of batches.
                ood_data, ood_labels = replay_buffer.sample(batch_size)
                if ood_data is None or ood_labels is None:
                    break
                ood_data, ood_labels = to_model_device(ood_data), to_model_device(ood_labels)

                # Set OOD labels to the special OOD class index
                ood_labels[:] = ood_class_index

                optimizer.zero_grad()
                ood_outputs = model(ood_data, task_id, head_type='ood')

                # Calculate OOD loss
                loss = F.cross_entropy(ood_outputs, ood_labels)
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                batches_run += 1

            # Average over the batches actually processed, not over stored samples.
            avg_loss = epoch_loss / batches_run if batches_run > 0 else 0
            print(f"Task {task_id}, Epoch [{epoch+1}/{epochs}], OOD Head Loss: {avg_loss:.4f}")

# Example usage for fine-tuning OOD heads after learning each new task
# fine_tune_ood_heads(model, optimizer, replay_buffer, current_task_id=current_task_id)


In [None]:
import torch.nn.functional as F

def inference(model, x, num_tasks):
    """
    Makes a prediction using both WP and OOD heads for all tasks.

    For each task the WP softmax is multiplied element-wise with the OOD softmax
    (with the OOD head's last, out-of-distribution column removed). The class with
    the highest product across all tasks is selected per sample; on ties the
    earlier task wins.

    Args:
    - model: ROWModel instance.
    - x: Input tensor for inference, with the batch on dimension 0.
    - num_tasks: Total number of tasks learned so far.

    Returns:
    - For a single-sample batch: the predicted class index as a Python int.
    - For a larger batch: a tensor with one predicted class index per sample.
    - None when `num_tasks` is 0.
    The index is the position within the winning task's head (it is not offset
    by task).
    """
    # Follow the model's own device instead of depending on a notebook-level `device` global.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        x = x.to(first_param.device)
    model.eval()

    best_prob = None
    best_label = None

    with torch.no_grad():
        for task_id in range(num_tasks):
            # Calculate WP probability for the task
            wp_output = model(x, task_id, head_type='wp')
            wp_prob = F.softmax(wp_output, dim=1)

            # Calculate OOD probability for the task
            ood_output = model(x, task_id, head_type='ood')
            ood_prob = F.softmax(ood_output, dim=1)

            # Multiply WP and OOD probabilities; the OOD class column is dropped from
            # ood_prob so both tensors cover the same in-distribution classes.
            combined_prob = wp_prob * ood_prob[:, :-1]
            max_task_prob, task_pred = combined_prob.max(dim=1)

            if best_prob is None:
                best_prob, best_label = max_task_prob, task_pred
            else:
                # Per-sample update: only replace where this task scores strictly higher.
                improved = max_task_prob > best_prob
                best_prob = torch.where(improved, max_task_prob, best_prob)
                best_label = torch.where(improved, task_pred, best_label)

    if best_label is None:
        return None
    # Keep the scalar return value for single-sample calls.
    if best_label.numel() == 1:
        return best_label.item()
    return best_label


In [None]:
def evaluate_model(model, test_data_loaders, num_tasks):
    """
    Evaluates the model on test data for each task.

    Every sample is classified with `inference` (which considers the heads of all
    `num_tasks` tasks) and the prediction is compared with the sample's label.

    Args:
    - model: ROWModel instance.
    - test_data_loaders: A list of DataLoaders, one for each task.
    - num_tasks: Number of tasks the model has learned.

    Returns:
    - List with the accuracy (in percent) for each task; 0 for a task whose
      loader yields no samples.
    """
    # NOTE(review): `inference` yields an index within the winning task's head, while
    # the loaders' labels are compared as-is - confirm both use the same label space.
    accuracies = []

    for task_id in range(num_tasks):
        correct = 0
        total = 0
        for x, y in test_data_loaders[task_id]:
            # Classify one sample at a time so that any loader batch size works.
            for sample_idx in range(y.size(0)):
                predicted_label = inference(model, x[sample_idx:sample_idx + 1], num_tasks)
                correct += int(predicted_label == y[sample_idx].item())
            total += y.size(0)

        # Guard against empty loaders instead of dividing by zero.
        accuracy = 100 * correct / total if total > 0 else 0
        accuracies.append(accuracy)
        print(f"Accuracy for Task {task_id}: {accuracy:.2f}%")

    return accuracies


In [None]:
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR100
from torch.utils.data import DataLoader, Subset

def get_task_dataloaders(num_tasks, batch_size=32, root='./data'):
    """
    Splits CIFAR-100 into multiple tasks, each with a distinct set of classes.

    Task `t` receives the contiguous class range
    [t * k, (t + 1) * k) with k = number_of_classes // num_tasks. Classes left over
    when the class count is not divisible by `num_tasks` are not assigned to any
    task. Labels are left as the dataset's original class indices (they are not
    remapped to per-task indices).

    Args:
    - num_tasks: Number of tasks to split the dataset into.
    - batch_size: Batch size for data loaders.
    - root: Directory the dataset is downloaded to and read from.

    Returns:
    - (train_loaders, test_loaders): two lists of DataLoaders, each with one
      entry per task. Train loaders shuffle; test loaders do not.

    Raises:
    - ValueError: If `num_tasks` is not positive or exceeds the number of classes.
    """
    # Validate before any download work is started.
    if num_tasks <= 0:
        raise ValueError(f"num_tasks must be a positive integer, got {num_tasks}")

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Load CIFAR-100 dataset
    dataset = CIFAR100(root=root, train=True, download=True, transform=transform)
    test_dataset = CIFAR100(root=root, train=False, download=True, transform=transform)

    num_classes = len(dataset.classes)
    if num_tasks > num_classes:
        raise ValueError(f"num_tasks ({num_tasks}) exceeds the number of classes ({num_classes})")
    num_classes_per_task = num_classes // num_tasks

    def group_indices_by_task(targets):
        # Single pass over the labels; sample order within each task is preserved.
        per_task = [[] for _ in range(num_tasks)]
        for sample_idx, label in enumerate(targets):
            task = label // num_classes_per_task
            if task < num_tasks:  # leftover classes belong to no task
                per_task[task].append(sample_idx)
        return per_task

    train_indices_per_task = group_indices_by_task(dataset.targets)
    test_indices_per_task = group_indices_by_task(test_dataset.targets)

    task_dataloaders = []
    test_dataloaders = []

    for train_indices, test_indices in zip(train_indices_per_task, test_indices_per_task):
        train_subset = Subset(dataset, train_indices)
        test_subset = Subset(test_dataset, test_indices)

        # Create DataLoader for this task
        train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
        test_loader = DataLoader(test_subset, batch_size=batch_size, shuffle=False)

        task_dataloaders.append(train_loader)
        test_dataloaders.append(test_loader)

    return task_dataloaders, test_dataloaders

# Example usage: build the per-task CIFAR-100 train and test loaders
num_tasks = 5  # Define the number of tasks (reassigns the notebook-level `num_tasks`)
batch_size = 32  # Batch size used for every train and test loader
train_loaders, test_loaders = get_task_dataloaders(num_tasks, batch_size)


Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|██████████| 169M/169M [00:02<00:00, 69.4MB/s]


Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified
