In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt

In [7]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Number of pixels in a flattened MNIST image, and the side length derived from it.
MNIST_INPUT_DIM = 784
MNIST_IMG_SIZE = int(np.sqrt(MNIST_INPUT_DIM))

In [8]:
class ExperimentConfig:
    """Bundle of hyperparameters and task definitions for one VCL run.

    Every setting is a plain instance attribute so it can be overridden after
    construction; call ``validate()`` once the overrides are in place.
    """

    def __init__(self):
        # --- model ---------------------------------------------------------
        self.prior_type = 'gaussian'  # 'gaussian' or 'exponential'
        self.task_type = 'classification'
        self.init_prior_mu = 0.0
        self.init_prior_scale = 0.01
        self.input_dim = 784
        self.hidden_dim = 256
        self.num_samples = 10  # forward passes drawn per batch

        # --- optimisation --------------------------------------------------
        self.num_epochs = 100
        self.batch_size = 256
        self.learning_rate = 0.001
        self.coreset_method = "random"
        self.coreset_size = 200
        self.patience = 5  # epochs without improvement before stopping
        self.early_stop_threshold = 1e-4  # minimum drop that counts as improvement

        # --- tasks: one list of class labels per task ----------------------
        self.tasks = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]

    def validate(self):
        """Assert that the prior type is supported and at least one task exists."""
        supported_priors = ['gaussian', 'exponential']
        assert self.prior_type in supported_priors
        assert len(self.tasks) > 0

    def get_coreset_fn(self):
        """Return the coreset attachment function selected by ``coreset_method``.

        "kcenter" selects ``attach_kcenter_coreset``; any other value selects
        ``attach_random_coreset``. Both names are looked up when this method
        is called and must be defined elsewhere in the notebook.
        """
        use_kcenter = self.coreset_method == "kcenter"
        return attach_kcenter_coreset if use_kcenter else attach_random_coreset

In [9]:
def kl_gaussian(q, p):
    """Element-wise KL(q || p) between univariate Gaussian distributions.

    Args:
        q: dict with tensors 'mu' and 'sigma' (standard deviation) of the
            approximate posterior.
        p: dict with tensors 'mu' and 'sigma' of the prior; must broadcast
            against the entries of ``q``.

    Returns:
        Tensor of per-element divergences,

            0.5 * (s + ((mu_q - mu_p) / sigma_p)**2 - 1 - log(s)),

        where ``s = (sigma_q / sigma_p)**2``. Every entry is non-negative and
        is zero exactly where q and p coincide.
    """
    q_mu, q_sigma = q['mu'], q['sigma']
    p_mu, p_sigma = p['mu'], p['sigma']

    var_ratio = (q_sigma / p_sigma) ** 2
    mean_term = ((q_mu - p_mu) / p_sigma) ** 2
    # The log term enters with a minus sign: it is log(sigma_p^2 / sigma_q^2).
    # Adding it instead makes the "divergence" negative whenever q is narrower
    # than p, which rewards collapsing the posterior variance.
    return 0.5 * (var_ratio + mean_term - 1.0 - torch.log(var_ratio))

def kl_exponential(q, p):
    """Element-wise KL(q || p) between Exponential distributions.

    Only the 'sigma' entry of each dict is read; it is interpreted as the
    scale (mean) of the distribution, so the rate is ``1 / sigma``. The 'mu'
    entries are ignored.

    Args:
        q: dict whose 'sigma' tensor holds the scale of the approximate posterior.
        p: dict whose 'sigma' tensor holds the scale of the prior; must
            broadcast against ``q['sigma']``.

    Returns:
        Tensor of per-element divergences

            log(rate_q / rate_p) + rate_p / rate_q - 1,

        which is non-negative and zero exactly where the scales coincide.
    """
    q_rate = 1 / q['sigma']  # Convert scale to rate
    p_rate = 1 / p['sigma']

    # KL(q || p) = E_q[log q - log p] with E_q[x] = 1 / rate_q. Swapping the
    # two ratios would give KL(p || q), i.e. the divergence in the wrong
    # direction for a variational objective.
    log_term = torch.log(q_rate / p_rate)
    ratio_term = p_rate / q_rate
    return log_term + ratio_term - 1

In [10]:
class VariationalLayer(nn.Module):
    """Fully connected layer with a factorised variational posterior over its parameters.

    Each weight and bias has a learnable mean (``*_mu``) and an unconstrained
    parameter (``*_rho``) that is mapped to a positive scale through softplus.
    The prior is stored per layer in ``W_prior`` / ``b_prior`` as dicts with
    keys 'mu' and 'sigma'.

    Args:
        input_dim: number of input features.
        output_dim: number of output features.
        config: object providing ``init_prior_mu``, ``init_prior_scale`` and
            ``prior_type`` ('exponential' selects ``kl_exponential``; any
            other value selects ``kl_gaussian``).
    """
    def __init__(self, input_dim, output_dim, config):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.config = config

        # Weight parameters (allocated uninitialised; filled by reset_parameters)
        self.W_mu = nn.Parameter(torch.empty(output_dim, input_dim))
        self.W_rho = nn.Parameter(torch.empty(output_dim, input_dim))

        # Bias parameters
        self.b_mu = nn.Parameter(torch.empty(output_dim))
        self.b_rho = nn.Parameter(torch.empty(output_dim))

        # Initial priors: 0-dim tensors that broadcast against the parameters.
        # They live in plain dict attributes, not parameters or buffers, so
        # Module.to() and state_dict() do not touch them.
        self.W_prior = {
            'mu': torch.tensor(config.init_prior_mu),
            'sigma': torch.tensor(config.init_prior_scale)
        }
        self.b_prior = {
            'mu': torch.tensor(config.init_prior_mu),
            'sigma': torch.tensor(config.init_prior_scale)
        }

        self.reset_parameters()

    def reset_parameters(self):
        """Draw the means from N(0, 0.1^2) and set every rho to -3."""
        # Initialize means
        nn.init.normal_(self.W_mu, mean=0.0, std=0.1)
        nn.init.normal_(self.b_mu, mean=0.0, std=0.1)

        # Initialize rho for variance; softplus(-3) is roughly 0.049
        nn.init.constant_(self.W_rho, -3.0)
        nn.init.constant_(self.b_rho, -3.0)

    @property
    def W_sigma(self):
        """Weight scale, softplus(W_rho)."""
        # F.softplus stays finite for large rho, where log1p(exp(rho))
        # overflows to inf.
        return F.softplus(self.W_rho)

    @property
    def b_sigma(self):
        """Bias scale, softplus(b_rho)."""
        return F.softplus(self.b_rho)

    def forward(self, x, sample=True):
        """Apply the layer, sampling pre-activations via local reparameterisation.

        Args:
            x: input tensor whose last dimension is ``input_dim``.
            sample: when False and the module is in eval mode, return the
                mean activation without noise. In training mode noise is
                always added.

        Returns:
            Tensor whose last dimension is ``output_dim``.
        """
        # Calculate activations
        act_mu = F.linear(x, self.W_mu, self.b_mu)

        if self.training or sample:
            # Variance of the pre-activation under independent weights and biases.
            act_var = F.linear(x**2, self.W_sigma**2, self.b_sigma**2)
            # The small constant keeps the sqrt gradient finite at zero variance.
            act_std = torch.sqrt(act_var + 1e-16)
            noise = torch.randn_like(act_mu)
            return act_mu + act_std * noise
        return act_mu

    def kl_loss(self):
        """Return the summed divergence between posterior and prior for weights and biases."""
        W_params = {'mu': self.W_mu, 'sigma': self.W_sigma}
        b_params = {'mu': self.b_mu, 'sigma': self.b_sigma}

        if self.config.prior_type == 'exponential':
            kl_func = kl_exponential
        else:
            kl_func = kl_gaussian

        weight_kl = torch.sum(kl_func(W_params, self.W_prior))
        bias_kl = torch.sum(kl_func(b_params, self.b_prior))
        return weight_kl + bias_kl

In [11]:
class BaseModel(nn.Module):
    """Shared skeleton of the multi-head networks used in the SplitMNIST runs.

    Subclasses are expected to define ``hidden_layers`` (applied in order,
    each followed by ReLU) and ``task_heads`` (one output layer per task).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.tasks = config.tasks

    def forward(self, x, task_id):
        """Flatten ``x`` to rows of ``config.input_dim`` and return the output of head ``task_id``."""
        features = x.view(-1, self.config.input_dim)
        for hidden in self.hidden_layers:
            features = F.relu(hidden(features))
        head = self.task_heads[task_id]
        return head(features)

class VanillaModel(BaseModel):
    """Deterministic multi-head baseline built from ordinary ``nn.Linear`` layers."""

    def __init__(self, config):
        super().__init__(config)
        in_dim, width = config.input_dim, config.hidden_dim

        # Shared trunk: two fully connected layers of equal width.
        first_hidden = nn.Linear(in_dim, width)
        second_hidden = nn.Linear(width, width)
        self.hidden_layers = nn.ModuleList([first_hidden, second_hidden])

        # One output layer per task, sized to that task's number of classes.
        heads = [nn.Linear(width, len(task)) for task in self.tasks]
        self.task_heads = nn.ModuleList(heads)

In [12]:
class VCLModel(BaseModel):
    """Variational Continual Learning model: a multi-head network of ``VariationalLayer``s."""
    def __init__(self, config):
        super().__init__(config)

        # Hidden layers (shared across tasks)
        self.hidden_layers = nn.ModuleList([
            VariationalLayer(config.input_dim, config.hidden_dim, config),
            VariationalLayer(config.hidden_dim, config.hidden_dim, config)
        ])

        # Task-specific heads, one per task, sized to that task's class count
        self.task_heads = nn.ModuleList([
            VariationalLayer(config.hidden_dim, len(task), config) for task in self.tasks
        ])

    def update_priors(self, task_id=None):
        """Replace layer priors with detached copies of the current posteriors.

        Args:
            task_id: optional index of the task that was just learned. When
                given, only the shared hidden layers and that task's head are
                updated, so heads that have not been trained keep the prior
                they were created with. When omitted, every head is updated
                (the original behaviour), which also overwrites the priors of
                untrained heads with their random initialisation.
        """
        if task_id is None:
            heads = list(self.task_heads)
        else:
            heads = [self.task_heads[task_id]]

        # Build a plain list instead of `ModuleList + list`, which depends on
        # ModuleList implementing __add__.
        for layer in list(self.hidden_layers) + heads:
            layer.W_prior = {
                'mu': layer.W_mu.detach().clone(),
                'sigma': layer.W_sigma.detach().clone()
            }
            layer.b_prior = {
                'mu': layer.b_mu.detach().clone(),
                'sigma': layer.b_sigma.detach().clone()
            }

    def total_kl_loss(self, task_id):
        """Sum the KL terms of the shared hidden layers and the head for ``task_id``.

        Heads belonging to other tasks do not contribute.
        """
        kl_loss = 0.0
        for layer in self.hidden_layers:
            kl_loss += layer.kl_loss()
        kl_loss += self.task_heads[task_id].kl_loss()
        return kl_loss

In [19]:
def compute_loss(outputs, targets, config):
    """Data-fit loss computed from a stack of Monte Carlo network outputs.

    Args:
        outputs: tensor whose last axis indexes the Monte Carlo samples and
            whose second-to-last axis indexes the network outputs; the
            callers in this notebook build it as
            (batch, num_classes, num_samples).
        targets: class indices for classification, or regression targets
            matching ``outputs.mean(-1)``.
        config: object with ``task_type``; 'regression' selects MSE on the
            sample mean, anything else selects the classification loss.

    Returns:
        Scalar loss tensor. For classification this is the negative log of
        the sample-averaged class probability of the target, so it is
        bounded below by zero.
    """
    if config.task_type == 'regression':  # Gaussian likelihood - MSE loss
        return F.mse_loss(outputs.mean(-1), targets)

    # Categorical likelihood - NLL loss.
    # nll_loss expects log-probabilities. Feeding it raw logits gives an
    # objective that can be driven to -inf by inflating the target logit, so
    # normalise over the class axis first. log_softmax is idempotent, so
    # inputs that are already log-probabilities pass through unchanged.
    log_probs = F.log_softmax(outputs, dim=-2)
    # Log of the mean probability over the samples actually present.
    num_samples = outputs.size(-1)
    log_predictive = torch.logsumexp(log_probs, dim=-1) - np.log(num_samples)
    return F.nll_loss(log_predictive, targets)

In [20]:
def train_model(model, dataloader, task_id, config):
    """Train ``model`` on one task with Adam and loss-based early stopping.

    Args:
        model: network exposing ``tasks`` and callable as ``model(inputs, task_id)``.
            When it is a ``VCLModel`` a KL regulariser is added to the loss.
        dataloader: iterable of ``(inputs, targets)`` batches for the task.
        task_id: index into ``model.tasks``.
        config: object with ``learning_rate``, ``num_epochs``, ``num_samples``,
            ``task_type``, ``patience`` and ``early_stop_threshold``.

    Training stops after ``config.num_epochs`` epochs, or earlier once the
    summed epoch loss has failed to improve by more than
    ``config.early_stop_threshold`` for ``config.patience`` consecutive epochs.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    model.train()

    # Loop invariants, computed once rather than on every batch.
    is_vcl = isinstance(model, VCLModel)
    is_classification = config.task_type != 'regression'
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    # Labels are reindexed by subtracting the task's first class, which
    # assumes each task's classes are consecutive integers.
    label_offset = model.tasks[task_id][0]

    best_loss = float('inf')
    num_consec_worse_epochs = 0

    for epoch in range(config.num_epochs):
        epoch_loss = 0.0

        for inputs, targets in dataloader:
            inputs = inputs.to(DEVICE)
            # Out-of-place subtraction leaves the batch tensor untouched.
            targets = targets.to(DEVICE) - label_offset

            optimizer.zero_grad()

            # Monte Carlo sampling: stack the draws on a trailing sample axis.
            samples = []
            for _ in range(config.num_samples):
                net_out = model(inputs, task_id)
                if is_classification:
                    # The NLL in compute_loss needs log-probabilities; raw
                    # logits would make the objective unbounded below.
                    # log_softmax is idempotent, so normalising here is safe
                    # whatever compute_loss does with its input.
                    net_out = F.log_softmax(net_out, dim=-1)
                samples.append(net_out)
            outputs = torch.stack(samples, dim=-1)

            loss = compute_loss(outputs, targets, config)

            # Add KL loss for VCL
            if is_vcl:
                # NOTE(review): the KL term is scaled by the trainable
                # parameter count; the usual per-example ELBO scales it by
                # the number of training examples. Confirm this is intended.
                loss = loss + model.total_kl_loss(task_id) / num_params

            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        # Early stopping
        if epoch_loss < best_loss - config.early_stop_threshold:
            best_loss = epoch_loss
            num_consec_worse_epochs = 0
        else:
            num_consec_worse_epochs += 1
            if num_consec_worse_epochs >= config.patience:
                break

In [21]:
def evaluate_model(model, dataloader, task_id, config):
    """Evaluate ``model`` on one task and return a single scalar metric.

    Args:
        model: network exposing ``tasks`` and callable as ``model(inputs, task_id)``.
        dataloader: iterable of ``(inputs, targets)`` batches for the task.
        task_id: index into ``model.tasks``.
        config: object with ``task_type`` and ``num_samples``.

    Returns:
        Accuracy over all evaluated examples for classification, or RMSE
        against one-hot targets over all evaluated elements when
        ``config.task_type == 'regression'``. Both are pooled over the whole
        loader, so a short final batch is not over-weighted. The value is
        NaN when the loader yields no examples.
    """
    model.eval()
    task = model.tasks[task_id]
    is_regression = config.task_type == 'regression'

    # Running numerator (sum of squared errors, or number of correct
    # predictions) and denominator (elements or examples seen).
    metric_sum = 0.0
    metric_count = 0

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(DEVICE)
            # Reindex labels to 0..len(task)-1; assumes the task's classes
            # are consecutive integers.
            targets = targets.to(DEVICE) - task[0]

            samples = [model(inputs, task_id) for _ in range(config.num_samples)]
            outputs = torch.stack(samples, dim=-1)

            if is_regression:
                # The head emits len(task) values, so the one-hot width must
                # be len(task) too; a fixed width of 10 only matches a
                # single ten-class task.
                one_hot = F.one_hot(targets, num_classes=len(task)).float()
                pred = outputs.mean(-1)
                metric_sum += F.mse_loss(pred, one_hot, reduction='sum').item()
                metric_count += one_hot.numel()
            else:
                # Average class probabilities, not raw logits, over the samples.
                log_probs = F.log_softmax(outputs, dim=-2)
                log_output = torch.logsumexp(log_probs, dim=-1) - np.log(outputs.size(-1))
                metric_sum += (log_output.argmax(-1) == targets).sum().item()
                metric_count += targets.numel()

    # The numpy scalar keeps the return type of the earlier np.mean-based
    # version and yields NaN (with a warning) for an empty loader.
    pooled = np.float64(metric_sum) / metric_count
    if is_regression:
        return np.sqrt(pooled)  # RMSE
    return pooled  # accuracy

In [22]:
from torch.utils.data import DataLoader, SubsetRandomSampler

def get_class_indices(dataset, target_classes):
    """Return a boolean mask over ``dataset.targets`` that is True where the label is in ``target_classes``."""
    labels = dataset.targets
    mask = torch.zeros_like(labels, dtype=torch.bool)
    for wanted in target_classes:
        mask = torch.logical_or(mask, labels == wanted)
    return mask

def create_split_dataloaders(class_distribution, batch_size=256):
    """Build one ``(train_loader, test_loader)`` pair per task from MNIST.

    Args:
        class_distribution: iterable of class-label lists, one per task.
        batch_size: batch size used for every loader.

    Returns:
        List of ``(train_loader, test_loader)`` tuples in task order. Each
        loader draws, in random order, only the examples whose label belongs
        to the task.
    """
    side = MNIST_IMG_SIZE
    preprocessing = transforms.Compose([
        transforms.Resize((side, side)),
        transforms.ToTensor(),
    ])

    # Both splits share the same root directory, download flag and transform.
    mnist_kwargs = dict(root="./data", download=True, transform=preprocessing)
    train_set = torchvision.datasets.MNIST(train=True, **mnist_kwargs)
    test_set = torchvision.datasets.MNIST(train=False, **mnist_kwargs)

    def loader_for(dataset, classes):
        # Restrict the dataset to the positions whose label is in `classes`.
        selected = torch.where(get_class_indices(dataset, classes))[0]
        return DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=SubsetRandomSampler(selected),
        )

    return [
        (loader_for(train_set, classes), loader_for(test_set, classes))
        for classes in class_distribution
    ]

In [25]:
def run_experiment(config):
    """Train a vanilla and a VCL model task by task and record their metrics.

    Args:
        config: experiment configuration providing ``tasks`` and
            ``batch_size`` (and whatever the models, ``train_model`` and
            ``evaluate_model`` read from it).

    Returns:
        Dict with keys 'vanilla' and 'vcl'. Each value is a
        (num_tasks, num_tasks) array whose entry ``[i, j]`` is the metric on
        task ``j`` measured after training through task ``i``. Entries with
        ``j > i`` are never written and stay at zero.
    """
    # Prepare data
    dataloaders = create_split_dataloaders(config.tasks, config.batch_size)

    # Initialize models
    vanilla_model = VanillaModel(config).to(DEVICE)
    vcl_model = VCLModel(config).to(DEVICE)

    # Storage for results
    num_tasks = len(config.tasks)
    results = {
        'vanilla': np.zeros((num_tasks, num_tasks)),
        'vcl': np.zeros((num_tasks, num_tasks))
    }

    # Train and evaluate
    for task_id in range(num_tasks):
        print(f"\n*** Training on Task {task_id+1} ***")
        train_loader = dataloaders[task_id][0]

        # Train vanilla model
        train_model(vanilla_model, train_loader, task_id, config)

        # Train VCL model, then make its posterior the prior for the next task
        train_model(vcl_model, train_loader, task_id, config)
        vcl_model.update_priors()

        # Evaluate on every task seen so far
        for eval_task in range(task_id + 1):
            _, test_loader = dataloaders[eval_task]

            # The evaluation function defined in this notebook is
            # evaluate_model. Keep the two metrics in separate names so the
            # progress line reports each model's own value.
            vanilla_acc = evaluate_model(vanilla_model, test_loader, eval_task, config)
            results['vanilla'][task_id, eval_task] = vanilla_acc

            vcl_acc = evaluate_model(vcl_model, test_loader, eval_task, config)
            results['vcl'][task_id, eval_task] = vcl_acc

            print(f"Task {eval_task+1} Accuracy - Vanilla: {vanilla_acc:.4f}, VCL: {vcl_acc:.4f}")

    return results

def plot_results(results):
    """Plot average and final per-task accuracy for each model.

    Args:
        results: mapping from model name to a 2-D array where entry
            ``[i, j]`` is the accuracy on task ``j`` after training through
            task ``i``. Entries with ``j > i`` are assumed to be unfilled
            placeholders, as in the arrays built by ``run_experiment``.

    The left panel shows, for each training stage, the mean accuracy over the
    tasks seen so far. The right panel shows the last row, i.e. the accuracy
    on every task after all training.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # Average accuracy plot
    for model in results:
        matrix = np.asarray(results[model])
        # Average row i over its first i + 1 entries only. Averaging the
        # whole row would include the zero placeholders of unseen tasks and
        # understate the early stages.
        avg_acc = np.array([row[:stage + 1].mean() for stage, row in enumerate(matrix)])
        ax1.plot(np.arange(len(avg_acc)) + 1, avg_acc, label=model)
    ax1.set_title("Average Accuracy")
    ax1.set_xlabel("Number of Tasks")
    ax1.set_ylabel("Accuracy")
    ax1.legend()

    # Final accuracy plot
    for model in results:
        final_acc = results[model][-1]
        ax2.plot(np.arange(len(final_acc)) + 1, final_acc, label=model)
    ax2.set_title("Final Accuracy")
    ax2.set_xlabel("Task Number")
    ax2.set_ylabel("Accuracy")
    ax2.legend()

    plt.tight_layout()
    plt.show()

In [None]:
if __name__ == "__main__":
    # Fix the random seeds first so every later step is reproducible.
    np.random.seed(0)
    torch.manual_seed(0)

    # Build and check the configuration. 'gaussian' is the original VCL
    # prior; 'exponential' is the other value validate() accepts.
    config = ExperimentConfig()
    config.prior_type = 'gaussian'
    config.validate()

    # Train both models on every task, then plot the comparison.
    results = run_experiment(config)
    plot_results(results)


*** Training on Task 1 ***
