In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from tqdm import tqdm
import numpy as np
import torch.nn.functional as F
import random
import os
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Define Bayesian Convolutional Layer (unchanged)

# Define Bayesian Convolutional Layer to output both mu and log_sigma
import torch
import torch.nn as nn
import torch.nn.functional as F

# Define Bayesian Convolutional Layer to output both mu and log_sigma
import torch
import torch.nn as nn
import torch.nn.functional as F

# Define Bayesian Convolutional Layer to output both mu and log_sigma
class BayesianConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True):
        super(BayesianConv2d, self).__init__()
        self.stride = stride
        self.padding = padding
        
        # Parameters for weight mean and log standard deviation
        self.weight_mu = nn.Parameter(torch.Tensor(out_channels, in_channels, kernel_size, kernel_size).normal_(0, 0.1))
        self.weight_log_sigma = nn.Parameter(torch.Tensor(out_channels, in_channels, kernel_size, kernel_size).normal_(-3, 0.1))
        
        if bias:
            self.bias_mu = nn.Parameter(torch.Tensor(out_channels).normal_(0, 0.1))
            self.bias_log_sigma = nn.Parameter(torch.Tensor(out_channels).normal_(-3, 0.1))
        else:
            self.bias_mu = None
            self.bias_log_sigma = None

    def forward(self, x):
        # Sample weights and biases
        weight = self.weight_mu + torch.exp(self.weight_log_sigma) * torch.randn_like(self.weight_log_sigma)
        bias = self.bias_mu + torch.exp(self.bias_log_sigma) * torch.randn_like(self.bias_log_sigma) if self.bias_mu is not None else None
        
        # Convolution operation
        predicted_mu = F.conv2d(x, weight, bias, stride=self.stride, padding=self.padding)
        predicted_log_sigma = self.weight_log_sigma.mean()  # Simplify to single log_sigma value for this layer
        
        # Return both mu and log_sigma
        return predicted_mu, predicted_log_sigma

# Define Bayesian Fully Connected Layer to output mu and log_sigma
class BayesianLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super(BayesianLinear, self).__init__()
        
        # Parameters for weight mean and log standard deviation
        self.weight_mu = nn.Parameter(torch.Tensor(out_features, in_features).normal_(0, 0.1))
        self.weight_log_sigma = nn.Parameter(torch.Tensor(out_features, in_features).normal_(-3, 0.1))
        
        self.bias_mu = nn.Parameter(torch.Tensor(out_features).normal_(0, 0.1))
        self.bias_log_sigma = nn.Parameter(torch.Tensor(out_features).normal_(-3, 0.1))

    def forward(self, x):
        # Sample weights and biases
        weight = self.weight_mu + torch.exp(self.weight_log_sigma) * torch.randn_like(self.weight_log_sigma)
        bias = self.bias_mu + torch.exp(self.bias_log_sigma) * torch.randn_like(self.bias_log_sigma)
        
        # Compute the predicted mean and log_sigma
        predicted_mu = F.linear(x, weight, bias)
        predicted_log_sigma = self.weight_log_sigma.mean()  # Simplify to a single log sigma value
        
        return predicted_mu, predicted_log_sigma

# Define Bayesian ResNet Block
class BayesianBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BayesianBasicBlock, self).__init__()
        self.conv1 = BayesianConv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = BayesianConv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=True)
        self.bn2 = nn.BatchNorm2d(planes)
        
        self.shortcut = None
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                BayesianConv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x_tuple):
        x, prev_log_sigma = x_tuple
        
        # First convolutional layer
        mu1, log_sigma1 = self.conv1(x)
        out = F.relu(self.bn1(mu1))
        
        # Second convolutional layer
        mu2, log_sigma2 = self.conv2(out)
        out = self.bn2(mu2)
        
        # Handle shortcut connection
        if self.shortcut:
            identity = x
            for layer in self.shortcut:
                if isinstance(layer, BayesianConv2d):
                    identity, _ = layer(identity)
                else:
                    identity = layer(identity)
        else:
            identity = x
            
        out += identity
        out = F.relu(out)
        
        # Average the log_sigma values
        current_log_sigma = (log_sigma1 + log_sigma2) / 2
        if prev_log_sigma is not None:
            current_log_sigma = (current_log_sigma + prev_log_sigma) / 2
            
        return out, current_log_sigma

class BayesianResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(BayesianResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = BayesianConv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=True)
        self.bn1 = nn.BatchNorm2d(64)
        
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        
        # Calculate output size
        temp_input = torch.randn(1, 3, 32, 32)
        temp_output, _ = self._forward_conv_layers(temp_input)
        self.flattened_size = temp_output.view(temp_output.size(0), -1).size(1)
        
        self.linear = BayesianLinear(self.flattened_size, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.ModuleList(layers)

    def _forward_conv_layers(self, x):
        mu1, log_sigma1 = self.conv1(x)
        x = F.relu(self.bn1(mu1))
        
        log_sigmas = [log_sigma1]
        current_log_sigma = log_sigma1
        
        for layer_blocks in [self.layer1, self.layer2, self.layer3, self.layer4]:
            for block in layer_blocks:
                x, current_log_sigma = block((x, current_log_sigma))
                log_sigmas.append(current_log_sigma)
        
        x = F.avg_pool2d(x, 4)
        avg_log_sigma = sum(log_sigmas) / len(log_sigmas)
        
        return x, avg_log_sigma

    def forward(self, x):
        x, log_sigma_conv = self._forward_conv_layers(x)
        x = x.view(x.size(0), -1)
        
        mu, log_sigma_fc = self.linear(x)
        combined_log_sigma = (log_sigma_conv + log_sigma_fc) / 2
        
        return mu, combined_log_sigma

# Define BayesianResNet18 for CIFAR-10 with 10 classes
def BayesianResNet18(num_classes=10):
    return BayesianResNet(BayesianBasicBlock, [2, 2, 2, 2], num_classes=num_classes)



# Experience Replay Buffer (unchanged)
class ExperienceReplayBuffer:
    def __init__(self, buffer_size=500):
        self.buffer = []
        self.buffer_size = buffer_size

    def add_samples(self, samples):
        self.buffer.extend(samples)
        if len(self.buffer) > self.buffer_size:
            self.buffer = self.buffer[-self.buffer_size:]

    def get_samples(self, batch_size):
        return random.sample(self.buffer, min(len(self.buffer), batch_size))
    

def nll_gaussian_loss(predicted_logits, predicted_log_sigma, targets):
    """
    Cross entropy loss with uncertainty
    predicted_logits: predicted class logits (batch_size x num_classes)
    predicted_log_sigma: predicted log standard deviation
    targets: ground truth labels (batch_size)
    """
    # Convert targets to one-hot encoding
    num_classes = predicted_logits.size(1)
    targets_one_hot = F.one_hot(targets, num_classes).float()
    
    # Apply softmax to get probabilities
    probs = F.softmax(predicted_logits, dim=1)
    
    # Calculate cross entropy loss
    ce_loss = F.cross_entropy(predicted_logits, targets)
    
    # Add uncertainty penalty
    uncertainty_penalty = torch.mean(torch.exp(-predicted_log_sigma))
    
    return ce_loss + 0.1 * uncertainty_penalty

def elbo_loss(predicted_logits, predicted_log_sigma, targets, model, kl_weight=1e-6):
    """
    ELBO loss function for Bayesian classification
    """
    # Likelihood loss (cross entropy)
    nll = nll_gaussian_loss(predicted_logits, predicted_log_sigma, targets)
    
    # KL divergence regularization term
    kl = 0.0
    for module in model.modules():
        if isinstance(module, (BayesianLinear, BayesianConv2d)):
            kl += torch.sum(-0.5 * torch.sum(1 + module.weight_log_sigma - 
                                           module.weight_mu**2 - 
                                           module.weight_log_sigma.exp()))
    
    return nll + kl_weight * kl

# Training and testing adjustments for class-incremental learning
def train_model(model, train_loader, buffer, optimizer, epochs=5, batch_size=64):
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for images, labels in tqdm(train_loader):
            images, labels = images.cuda(), labels.cuda()
            
            # Handle experience replay
            if len(buffer.buffer) > batch_size:
                replay_samples = buffer.get_samples(batch_size // 2)
                replay_images, replay_labels = zip(*replay_samples)
                replay_images = torch.stack(replay_images).cuda()
                replay_labels = torch.tensor(replay_labels, dtype=torch.long).cuda()
                images = torch.cat([images, replay_images], dim=0)
                labels = torch.cat([labels, replay_labels], dim=0)

            optimizer.zero_grad()
            
            # Get model predictions
            predicted_logits, predicted_log_sigma = model(images)
            
            # Calculate losses
            loss_nll = nll_gaussian_loss(predicted_logits, predicted_log_sigma, labels)
            loss = elbo_loss(predicted_logits, predicted_log_sigma, labels, model)
            
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss/len(train_loader):.4f}")

def test_model(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.cuda(), labels.cuda()
            predicted_logits, _ = model(images)
            _, predicted = torch.max(predicted_logits, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    return accuracy

# Prepare CIFAR-10 dataset for class-incremental learning
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
cifar10_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Split CIFAR-10 into 5 class-incremental tasks (each task has 2 new classes)
tasks = []
test_tasks = []
for i in range(5):
    indices = [j for j in range(len(cifar10_dataset)) if cifar10_dataset.targets[j] in range(i * 2, (i + 1) * 2)]
    tasks.append(Subset(cifar10_dataset, indices))
    
    test_indices = [j for j in range(len(test_dataset)) if test_dataset.targets[j] in range(i * 2, (i + 1) * 2)]
    test_tasks.append(Subset(test_dataset, test_indices))

# Train and test across class-incremental tasks
model = BayesianResNet18(num_classes=10).cuda()
buffer = ExperienceReplayBuffer(buffer_size=500)

for task_id, task_data in enumerate(tasks):
    print(f"\nTraining on Class Increment {task_id + 1}")

    # Prepare data loaders for training and testing
    train_loader = DataLoader(task_data, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_tasks[task_id], batch_size=64, shuffle=False)

    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    # Train the model on the current task
    train_model(model, train_loader, buffer, optimizer, epochs=20, batch_size=64)

    # Update experience replay buffer
    for images, labels in train_loader:
        buffer.add_samples([(image, label) for image, label in zip(images, labels)])
    
    # Test on all test sets after each task
    print(f"\nTesting after Task {task_id + 1}")
    for i in range(5):
        test_loader = DataLoader(test_tasks[i], batch_size=64, shuffle=False)
        accuracy = test_model(model, test_loader)
        print(f"Accuracy on Test Set {i + 1}: {accuracy:.2f}%")

print("Training and testing across all class increments completed.")

Files already downloaded and verified
Files already downloaded and verified

Training on Class Increment 1


100%|██████████| 157/157 [00:06<00:00, 23.50it/s]


Epoch [1/1], Loss: 14.1313

Testing after Task 1
Accuracy on Test Set 1: 66.30%
Accuracy on Test Set 2: 0.00%
Accuracy on Test Set 3: 0.00%
Accuracy on Test Set 4: 0.00%
Accuracy on Test Set 5: 0.00%

Training on Class Increment 2


100%|██████████| 157/157 [00:07<00:00, 20.11it/s]


Epoch [1/1], Loss: 14.5749

Testing after Task 2
Accuracy on Test Set 1: 48.70%
Accuracy on Test Set 2: 62.60%
Accuracy on Test Set 3: 0.00%
Accuracy on Test Set 4: 0.00%
Accuracy on Test Set 5: 0.00%

Training on Class Increment 3


100%|██████████| 157/157 [00:07<00:00, 21.76it/s]


Epoch [1/1], Loss: 14.7957

Testing after Task 3
Accuracy on Test Set 1: 0.00%
Accuracy on Test Set 2: 8.20%
Accuracy on Test Set 3: 64.20%
Accuracy on Test Set 4: 0.00%
Accuracy on Test Set 5: 0.00%

Training on Class Increment 4


100%|██████████| 157/157 [00:06<00:00, 23.31it/s]


Epoch [1/1], Loss: 14.9370

Testing after Task 4
Accuracy on Test Set 1: 0.00%
Accuracy on Test Set 2: 0.00%
Accuracy on Test Set 3: 19.10%
Accuracy on Test Set 4: 72.30%
Accuracy on Test Set 5: 0.00%

Training on Class Increment 5


100%|██████████| 157/157 [00:07<00:00, 20.83it/s]


Epoch [1/1], Loss: 14.1569

Testing after Task 5
Accuracy on Test Set 1: 0.00%
Accuracy on Test Set 2: 0.00%
Accuracy on Test Set 3: 0.00%
Accuracy on Test Set 4: 62.25%
Accuracy on Test Set 5: 64.75%
Training and testing across all class increments completed.


: 