In [1]:
import torchvision
import torch


# import NMIST data
def get_data(batch_size=128):
    train_data = torchvision.datasets.MNIST(
        root="../../data",
        train=True,
        download=True,
        transform=torchvision.transforms.ToTensor(),
    )
    test_data = torchvision.datasets.MNIST(
        root="../../data",
        train=False,
        download=True,
        transform=torchvision.transforms.ToTensor(),
    )
    # truncate the remaining data that doesn't make a full batch

    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=batch_size, shuffle=True, drop_last=True
    )
    test_loader = torch.utils.data.DataLoader(
        test_data, batch_size=batch_size, shuffle=False, drop_last=True
    )
    return train_loader, test_loader


# make a function that gives a mnist dataloader that gives a continous data of only classes 0-4 and after that 5-9
def get_data_separate(batch_size=128):
    train_data = torchvision.datasets.MNIST(
        root="../../data",
        train=True,
        download=True,
        transform=torchvision.transforms.ToTensor(),
    )
    test_data = torchvision.datasets.MNIST(
        root="../../data",
        train=False,
        download=True,
        transform=torchvision.transforms.ToTensor(),
    )

    train_data_1 = []
    train_data_2 = []
    test_data_1 = []
    test_data_2 = []
    for data, target in train_data:
        if target < 5:
            train_data_1.append((data, target))
        else:
            train_data_2.append((data, target))
    for data, target in test_data:
        if target < 5:
            test_data_1.append((data, target))
        else:
            test_data_2.append((data, target))

    train_loader_1 = torch.utils.data.DataLoader(
        train_data_1, batch_size=batch_size, shuffle=True, drop_last=True
    )
    train_loader_2 = torch.utils.data.DataLoader(
        train_data_2, batch_size=batch_size, shuffle=True, drop_last=True
    )
    test_loader_1 = torch.utils.data.DataLoader(
        test_data_1, batch_size=batch_size, shuffle=False, drop_last=True
    )
    test_loader_2 = torch.utils.data.DataLoader(
        test_data_2, batch_size=batch_size, shuffle=False, drop_last=True
    )

    test_loader = torch.utils.data.DataLoader(
        test_data, batch_size=batch_size, shuffle=False, drop_last=True
    )

    return train_loader_1, train_loader_2, test_loader_1, test_loader_2, test_loader

def get_domain_inc_data(batch_size=128):
    train_data = torchvision.datasets.MNIST(
        root="../../data",
        train=True,
        download=True,
        transform=torchvision.transforms.ToTensor(),
    )
    test_data = torchvision.datasets.MNIST(
        root="../../data",
        train=False,
        download=True,
        transform=torchvision.transforms.ToTensor(),
    )

    transforms = torchvision.transforms.Compose(
        [   torchvision.transforms.ToTensor(),
            torchvision.transforms.RandomRotation(10),
            torchvision.transforms.RandomAffine(0, translate=(0.1, 0.1)),
            torchvision.transforms.RandomAffine(0, shear=10),
            torchvision.transforms.RandomAffine(0, scale=(0.8, 1.2)),
            # add random noise
            torchvision.transforms.Lambda(lambda x: x + 0.01 * torch.randn_like(x)),
         ]
    )

    transformed_train_data = torchvision.datasets.MNIST(
        root="../../data",
        train=True,
        download=True,
        transform=transforms,
    )

    transformed_test_data = torchvision.datasets.MNIST(
        root="../../data",
        train=False,
        download=True,
        transform=transforms,
    )

    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=batch_size, shuffle=True, drop_last=True
    )
    test_loader = torch.utils.data.DataLoader(
        test_data, batch_size=batch_size, shuffle=False, drop_last=True
    )

    transformed_train_loader = torch.utils.data.DataLoader(
        transformed_train_data, batch_size=batch_size, shuffle=True, drop_last=True
    )
    transformed_test_loader = torch.utils.data.DataLoader(
        transformed_test_data, batch_size=batch_size, shuffle=False, drop_last=True
    )

    return train_loader, transformed_train_loader, test_loader, transformed_test_loader

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
from torch.optim.lr_scheduler import StepLR


class NN(nn.Module):
    """
    Neural network class with Hebbian learning mechanisms.
    """

    def __init__(self, input_size, output_size, indexes, inhibition_strength=0.01):
        """
        Initializes the network layers, Hebbian parameters, and hooks for gradient freezing.

        Args:
            input_size (int): Size of the input layer.
            output_size (int): Size of the output layer.
            indexes (list): List of neuron indices to freeze during gradient updates.
        """

        super(NN, self).__init__()

        self.k = 5
        self.inhibition_strength = inhibition_strength
        self.percent_winner = 0.5

        self.linear = nn.ModuleList(
            [
                nn.Linear(input_size, 256),
                nn.Linear(256, 128),
                nn.Linear(128, 64),
                nn.Linear(64, output_size),
            ]
        )

        # Define the Hebbian parameters corresponding to each layer
        self.hebb_params = nn.ModuleList(
            [
                nn.Linear(input_size, 256, bias=False),
                nn.Linear(256, 128, bias=False),
                nn.Linear(128, 64, bias=False),
                nn.Linear(64, output_size, bias=False),
            ]
        )

        for i, heb_param in enumerate(self.hebb_params):
            nn.init.kaiming_normal_(heb_param.weight)
            heb_param.weight.requires_grad = False

        self.indexes = indexes
        self.hidden_size_array = [256, 128, 64, output_size]

        if indexes != [[], [], []]:
            self._register_gradient_hooks(self.indexes)

    def forward(self, x, scalers=None, indexes=None, masks=None, indices_old=None, target=None):
        """
        Defines the forward pass of the network.

        Args:
            x (torch.Tensor): Input tensor.
            indexes (list, optional): New indexes to update for freezing gradients.
            masks (list, optional): Masking values applied to specific layers during forward pass.

        Returns:
            torch.Tensor: Output tensor.
            list: Hebbian scores for each layer.
            list: Hebbian masks for each layer.
        """

        if scalers is not None:
            self.update_indexes(scalers)

        hebbian_scores = []
        hebbian_masks = []
        hebbian_indices = []

        for i, layer in enumerate(self.linear):
            x1 = layer(x)
            is_final_layer = (i == len(self.linear) - 1)

            if masks is not None: # check later why multiplying the mask of the last layer as well causes a drop is accuracy values.
                x1 = torch.mul(x1, masks[i])

            x1 = F.relu(x1) if not is_final_layer else x1

            hebbian_score, hebbian_index, hebbian_mask = self.hebbian_update(
                x, x1, i, indices_old=indices_old[i], target=target if is_final_layer else None
            )

            hebbian_scores.append(hebbian_score)
            hebbian_masks.append(hebbian_mask)
            hebbian_indices.append(hebbian_index)

            x = x1

        x = nn.Softmax(dim=1)(x)

        return x, hebbian_scores, hebbian_indices, hebbian_masks

    def hebb_forward(self, x, indexes=None):
        hebbian_scores = []
        hebbian_masks = []

        for i, layer in enumerate(self.hebb_params):
            x1 = layer(x)
            x1 = F.relu(x1)
            if i < len(self.linear) - 1:
                if indexes is not None:
                    hebbian_score, hebbian_mask = self.hebbian_update(
                        x, x1, i, indices_old=indexes[i]
                    )
                else:
                    hebbian_score, hebbian_mask = self.hebbian_update(x, x1, i)
                hebbian_scores.append(hebbian_score)
                hebbian_masks.append(hebbian_mask)
            x = x1

        return x, hebbian_scores, hebbian_masks

    def hebbian_update(self, x, y, layer_idx, lr=0.00005, threshold=0.5, indices_old=None, target=None):
        """
        Calculates Hebbian-derived scores, masks, and scales for a layer.
        Handles final layer differently using the one-hot target.
        """

        gd_layer = self.linear[layer_idx]
        x_size = self.hidden_size_array[layer_idx] # Size of the output dimension of the layer
        # input_size = gd_layer.weight.size(1) # Size of the input dimension of the layer (x's features)
        batch_size = x.size(0)

        # Check if this is the final layer (or equivalently, if target is provided)
        is_final_layer = (target is not None) # Assuming target is only non-None for the final layer

        # --- Calculate delta_w using appropriate rule (Unsupervised for hidden, Supervised for final) ---
        if not is_final_layer:
            # Using raw y for delta_w calculation as in your last version
            post_T = y.t() # Shape: (output_size, batch_size)
            pre = x        # Shape: (batch_size, input_size)

            y_x = torch.mm(post_T, pre) / batch_size # Shape: (output_size, input_size) - Pre-post correlation average

            y_y_T = torch.mm(post_T, y) / batch_size # Shape: (output_size, output_size) - Post-post correlation average
            # Applying lower triangle for Oja-like / competitive term
            heb_mask_tril = torch.tril(torch.ones(y_y_T.size(), device=y_y_T.device))
            y_y_T_lower = y_y_T * heb_mask_tril

            # Lateral term using current linear weights
            lateral_term = torch.mm(y_y_T_lower, gd_layer.weight.data) # Shape: (output_size, input_size)

            # Hebbian weight update delta
            delta_w = lr * (y_x - lateral_term)

            modified_weights = gd_layer.weight.data + delta_w
            with torch.no_grad():
                 # Normalize rows (incoming weights for each output neuron)
                 norm = torch.norm(modified_weights, p=2, dim=1, keepdim=True) # Shape: (output_size, 1)
                 norm = torch.clamp(norm, min=1e-8) # Avoid division by zero
                 normalized_modified_weights = modified_weights / norm


            # Score for each output neuron is the norm of its incoming weight vector
            hebbian_scores = torch.norm(normalized_modified_weights.detach(), p=2, dim=1) # Shape: (output_size)

            # Apply inhibition from old indices to Hebbian scores before selecting top K
            if indices_old is not None:
                 # scatter expects index to be long tensor
                 hebbian_scores = hebbian_scores.scatter(0, indices_old.long(), float('-inf'))

            # Select top K based on Hebbian scores
            num_winners = int(self.percent_winner * x_size)
            if num_winners == 0 and x_size > 0: num_winners = 1 # Ensure at least one winner if layer exists
            elif x_size == 0: num_winners = 0 # Handle empty layer gracefully

            if num_winners > 0:
                 # topk_indices_hebbian contains the indices (position in the layer's output dimension)
                 _, topk_indices_hebbian = torch.topk(hebbian_scores, num_winners) # Shape: (num_winners)
            else:
                 topk_indices_hebbian = torch.tensor([], dtype=torch.long, device=y.device)


            # Create the Hebbian-based winner mask (shape: 1, output_size) for activation masking
            # This mask will be applied in the *next* forward pass for this task
            hebbian_mask = torch.zeros(1, x_size, device=y.device)
            if num_winners > 0:
                 # Scatter needs index dimension to match self dimension (dim=1 here)
                 hebbian_mask.scatter_(1, topk_indices_hebbian.unsqueeze(0), 1.0) # Indices need shape (1, num_winners)

            # Get the indices of the non-selected neurons (for indices_old in next iter)
            all_indices = torch.arange(x_size, device=y.device)
            indices_non_winners = all_indices[hebbian_mask.squeeze(0) == 0] # Select indices where mask is 0

            scale = torch.zeros_like(gd_layer.weight.data) # Shape: (output_size, input_size)

            if num_winners > 0:
                 scale[topk_indices_hebbian] = normalized_modified_weights[topk_indices_hebbian]

            return scale, indices_non_winners, hebbian_mask


        else:
            x_size = self.hidden_size_array[layer_idx] # Size of the output dimension of the layer
            if target is None:
                 print("Warning: Target is None for final layer Hebbian update.")
                 scale_output = torch.zeros_like(gd_layer.weight.data)
                 hebbian_mask = torch.ones(1, x_size, device=y.device)
                 indices_non_winners = torch.tensor([], dtype=torch.long, device=y.device)
                 return scale_output, indices_non_winners, hebbian_mask

            target_onehot = target
            post_T_supervised = target_onehot.t() # Shape: (output_size, batch_size)
            pre_supervised = x                     # Shape: (batch_size, input_size)

            # Correlation term averaged over batch
            correlation_term = torch.mm(post_T_supervised, pre_supervised) / batch_size # Shape: (output_size, input_size)

            scale_output = correlation_term.detach() # Shape: (output_size, input_size)
            # Normalize scale for gradients between 0 and 1 (optional, but good practice)
            min_scale = torch.min(scale_output)
            max_scale = torch.max(scale_output)
            if max_scale - min_scale > 1e-8:
                 scale_output = (scale_output - min_scale) / (max_scale - min_scale)

            hebbian_scores = torch.norm(scale_output, p=2, dim=1) # Shape: (output_size)
            if indices_old is not None:
                hebbian_scores = hebbian_scores.scatter(0, indices_old.long(), float('-inf'))
            _, topk_indices_hebbian = torch.topk(hebbian_scores, int(self.percent_winner * x_size)) # Shape: (1)
            hebbian_mask = torch.zeros(1, x_size, device=y.device)
            hebbian_mask.scatter_(1, topk_indices_hebbian.unsqueeze(0), 1.0)
            indices_non_winners = torch.arange(x_size, device=y.device)[hebbian_mask.squeeze(0) == 0] # Select indices where mask is 0

            return scale_output, indices_non_winners, hebbian_mask

    def scale_grad(self, scalers):
        """
        Scales gradients for neurons specified by indexes.

        Args:
            indexes (list): List of neuron indices to scale during gradient updates.

        Returns:
            function: Hook function for modifying gradients during backpropagation.
        """

        def hook(grad):
            if len(scalers) > 0:
                grad *= scalers
            return grad

        return hook

    def freeze_grad(self, indexes):
        """
        Freezes gradients for neurons specified by indexes.

        Args:
            indexes (list): List of neuron indices to freeze during gradient updates.

        Returns:
            function: Hook function for modifying gradients during backpropagation.
        """

        def hook(grad):
            if len(indexes) > 0:
                indexes_arr = (
                    indexes.cpu().numpy()
                    if isinstance(indexes, torch.Tensor)
                    else indexes
                )
                grad[indexes_arr] = 0
            return grad

        return hook

    def _register_gradient_hooks(self, indexes):
        """
        Registers hooks for freezing gradients on specified neurons.

        Args:
            indexes (list): List of neuron indices to freeze during gradient updates.
        """
        for i, layer in enumerate(self.linear):
            # Check if the layer already has hooks registered and clear them if they exist
            if layer.weight._backward_hooks is not None:
                layer.weight._backward_hooks.clear()
            layer.weight.register_hook(self.scale_grad(indexes[i]))

    def update_indexes(self, new_indexes):
        """
        Updates the indexes of neurons for freezing and re-registers gradient hooks.

        Args:
            new_indexes (list): New list of neuron indexes to freeze.
        """

        self.indexes = new_indexes
        self._register_gradient_hooks(new_indexes)

    def reinitialize_hebbian_parameters(self, init_type="zero"):
        """
        Reinitializes the Hebbian parameters.

        Args:
            init_type (str, optional): Initialization type ('zero' or 'normal').
        """

        for param in self.hebb_params.parameters():
            if init_type == "zero":
                nn.init.constant_(param, 0)
            elif init_type == "normal":
                nn.init.kaiming_normal_(param)

In [3]:
import torch
from tqdm import tqdm
from torch import nn, optim


def get_excess_neurons(indices1, indices2, layer_sizes=[256, 128, 64]):
    """
    Identifies neurons in each layer that are present in indices2 but not in indices1.

    Args:
        indices1 (list of lists): Indices of neurons selected for a particular task or layer.
        indices2 (list of lists): Indices of neurons for comparison with indices1.
        layer_sizes (list, optional): List of neuron counts per layer.

    Returns:
        list of torch.Tensor: List of indices representing neurons not present in either indices1 or indices2.
    """

    # layer_sizes = [6,6,6]
    excess_neurons = []
    for i in range(len(indices1)):
        excess_neurons.append([j for j in indices2[i] if j not in indices1[i]])

    all_indices = [torch.arange(layer_sizes[i]) for i in range(len(layer_sizes))]

    if excess_neurons == [[], [], []]:
        for i in range(len(all_indices)):
            # delete the indicces present in indices2 from all_indices
            all_indices[i] = torch.tensor(
                [j for j in all_indices[i] if j not in indices2[i]]
            )
        return all_indices

    for i in range(len(indices1)):
        all_indices[i] = torch.tensor(
            [j for j in all_indices[i] if j not in excess_neurons[i]]
        )

    return all_indices


def get_merge_mask(mask1, mask2):
    """
    Merges two sets of binary masks using logical OR operation.

    Args:
        mask1 (list of torch.Tensor): First list of masks.
        mask2 (list of torch.Tensor): Second list of masks.

    Returns:
        list of torch.Tensor: List of merged masks, where each mask is the result of logical OR operation.
    """

    merge_mask = []
    for i in range(len(mask1)):
        merge_mask.append(torch.logical_or(mask1[i], mask2[i]).int())
    return merge_mask


def calc_percentage_of_zero_grad(masks):
    """
    Calculates the percentage of neurons with non-zero gradients in the given masks.

    Args:
        masks (list of torch.Tensor): List of masks for each layer.

    Returns:
        float: Percentage of neurons with non-zero gradients.
    """

    total = 0
    zero = 0
    for mask in masks:
        total += mask.numel()
        zero += torch.sum(mask == 0).item()
    return (1 - zero / total) * 100


def forwardprop_and_backprop(
    model,
    lr,
    data_loader,
    continual=False,
    list_of_indexes=None,
    masks=None,
    scheduler=None,
    optimizer=None,
    indices_old = None,
    task_id=None,
):
    """
    Performs forward and backward propagation over a dataset with optional continual learning.

    Args:
        model (nn.Module): Neural network model.
        lr (float): Learning rate for optimizer.
        data_loader (DataLoader): DataLoader for the training data.
        continual (bool, optional): Flag indicating whether continual learning is applied.
        list_of_indexes (list, optional): List of indexes for selective neuron training.
        masks (list, optional): List of masks for each layer.
        scheduler (torch.optim.lr_scheduler, optional): Learning rate scheduler.
        optimizer (torch.optim.Optimizer, optional): Optimizer for the model.

    Returns:
        tuple: Updated list of indexes, masks, model, and optimizer after training.
    """

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    criterion = nn.CrossEntropyLoss()
    if optimizer is None:
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    loss_total = 0
    model.train()

    for i, (data, target) in enumerate(tqdm(data_loader)):
        optimizer.zero_grad()
        data = data.view(-1, 784)
        data, target = data.to(device), target.to(device)
        scalers = None
        one_hot_target = torch.zeros(target.size(0), 10).to(device)
        one_hot_target.scatter_(1, target.view(-1, 1), 1)

        if not continual:
            indices_old = [None] * len(list_of_indexes)

            output, scalers, list_of_indexes, masks = model(
                data, scalers, indexes=list_of_indexes, masks=masks, indices_old = indices_old, target=one_hot_target
            )

        else:
            output, scalers, list_of_indexes, masks = model(
                data,
                scalers,
                indexes=list_of_indexes,
                masks=masks,
                indices_old=indices_old,
                target=one_hot_target,
            )

        # if task_id is not None:
        #     output = output[:, 5*(task_id-1):5*task_id]
        #     target = target % 5

        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        loss_total += loss.item()

    scheduler.step()

    print("Avg loss: ", loss_total / len(data_loader))
    return list_of_indexes, masks, model, optimizer

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
from torch.optim.lr_scheduler import StepLR


seed = 924  # verified
print("Seed: ", seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)


data_loader_1, data_loader_2, test_loader_1, test_loader_2, test_loader = get_data_separate(
    batch_size=64
)
list_of_indexes = [[], [], [],[]]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
masks = [
    torch.ones(256).to(device),
    torch.ones(128).to(device),
    torch.ones(64).to(device),
    torch.ones(10).to(device),
]


original_model = NN(784, 10, indexes=list_of_indexes).to(device)
optimizer = optim.SGD(original_model.parameters(), lr=0.1, momentum=0.9)
scheduler = StepLR(optimizer, step_size=5, gamma=0.1)
for i in range(10):
    task1_indices, task1_masks, task1_model, optimizer = forwardprop_and_backprop(
        original_model,
        0.1,
        data_loader_1,
        list_of_indexes=list_of_indexes,
        masks=masks,
        optimizer=optimizer,
        scheduler=scheduler,
        task_id=1,
    )
    list_of_indexes = task1_indices
    # print("percentage of zero gradients: ",calc_percentage_of_zero_grad(original_model))

indices = []
new_masks = []
layer_sizes = [256, 128, 64, 10]
for i in range(len(layer_sizes)):
    indices.append(
        torch.tensor(
            [j for j in range(layer_sizes[i]) if j not in task1_indices[i]]
        ).to(device)
    )
    mask = torch.tensor(
        [1 if k in task1_indices[i] else 0 for k in range(layer_sizes[i])]
    ).to(device)
    new_masks.append(mask)

print("Task 1 indices: ", task1_indices)
print("Task 1 masks: ", task1_masks)
print("Percentage of frozen neurons: ", calc_percentage_of_zero_grad(task1_masks))

print("### Task 2 ###")
for i in range(10):
    task2_indices, task2_masks, task2_model, optimizer = forwardprop_and_backprop(
        task1_model,
        0.1,
        data_loader_2,
        list_of_indexes=task1_indices,
        masks=new_masks,
        continual=True,
        optimizer=None,
        scheduler=scheduler,
        indices_old=indices,
        task_id=2,
    )
# print("Percentage of frozen neurons: ", calc_percentage_of_zero_grad(task2_masks))
# print("percentage of zero gradients: ",calc_percentage_of_zero_grad(original_model))

print("Task 2 indices: ", task2_indices)
print("Task 2 masks: ", task2_masks)
print("Percentage of frozen neurons: ", calc_percentage_of_zero_grad(task2_masks))

all_masks = [task1_masks, task2_masks]
correct = 0
accuracies = []
original_model.eval()

print("### Testing both Tasks ###")



def mc_entropy_weighted_ensemble(model, data, masks, mc_runs=10, temperature=1.0, device='cuda'):
    """
    Do MC Dropout forward passes for each subnetwork, compute predictive entropy,
    then fuse outputs by softmax of negative entropy (lower entropy → higher weight).
    """
    model.train()
    batch_size = data.size(0)
    n_models = len(masks)
    all_probs = torch.zeros(n_models, mc_runs, batch_size, model.linear[-1].out_features, device=device)


    for m, mask in enumerate(masks):
        for t in range(mc_runs):
            with torch.no_grad():
                out, *_ = model(data, masks=mask, indices_old=[None]*len(mask))
            all_probs[m, t] = out

    # average to get predictive p̂(x)
    p_mc = all_probs.mean(dim=1)

    # compute entropy for each model & sample
    entropies = -torch.sum(p_mc * torch.log(p_mc + 1e-10), dim=2)  # shape: (n_models, batch)

    model.eval()

    # lower entropy => higher weight
    weights = F.softmax(-entropies / temperature, dim=0)
    weights = weights.unsqueeze(2)

    # 5) fuse the averaged probabilities
    fused = torch.sum(weights * p_mc, dim=0)
    return fused



correct = 0
original_model.eval()

for data, target in test_loader:
    data = data.view(-1, 784).to(device)
    target = target.to(device)

    fused_probs = mc_entropy_weighted_ensemble(
        original_model,
        data,
        masks=all_masks,
        mc_runs=20,
        temperature=0.5,
        device=device,
    )
    preds = fused_probs.argmax(dim=1)
    correct += preds.eq(target).sum().item()

acc = 100.0 * correct / len(test_loader.dataset)
print(f"MC‐Dropout + Entropy‐Weighted Ensemble accuracy: {acc:.2f}%")



print(f"Accuracy for both Tasks: {100 * correct / len(test_loader.dataset):.2f}%")

correct = 0
print("### Testing Task 1###")
task_id = 1
for data, target in test_loader_1:
    data = data.view(-1, 784)
    data, target = data.to(device), target.to(device)
    output, scalers, indices, masks = task1_model(data, masks=task1_masks, indices_old=[None]*len(indices))
    # check the accuracy
    predicted = output.argmax(dim=1, keepdim=True)
    correct += predicted.eq(target.view_as(predicted)).sum().item()

print(f"Accuracy for Task 1: {100* correct/len(test_loader_1.dataset)}%")
accuracies.append(100 * correct / len(test_loader_1.dataset))

# task2_masks = get_merge_mask(task1_masks, task2_masks)

correct = 0
print("### Testing Task 2###")
task_id = 2
for data, target in test_loader_2:
    data = data.view(-1, 784)
    data, target = data.to(device), target.to(device)
    output, scalers, indices, masks = task2_model(data, masks=task2_masks, indices_old=[None]*len(indices))
    # check the accuracy
    predicted = output.argmax(dim=1, keepdim=True)
    # target = target % 5
    correct += predicted.eq(target.view_as(predicted)).sum().item()

print(f"Accuracy for Task 2: {100* correct/len(test_loader_2.dataset)}%")
accuracies.append(100 * correct / len(test_loader_2.dataset))


# import matplotlib.pyplot as plt
# import numpy as np

# hebbian_weights = task1_model.hebb_params[0].weight.data.cpu().numpy()
# model_weights = task2_model.linear[0].weight.data.cpu().numpy()
# model_weights1 = task2_model.linear[1].weight.data.cpu().numpy()

# model_neurons = np.random.choice(256, 20)
# model_neurons1 = np.random.choice(128, 20)
# # select random 20 neurons
# neurons = np.random.choice(256, 20)


# plt.figure(figsize=(20, 10))
# for i, neuron in enumerate(neurons):
#     plt.subplot(4, 5, i + 1)
#     plt.imshow(hebbian_weights[neuron].reshape(28, 28), cmap="gray")
#     plt.axis("off")

# plt.show()

# plt.figure(figsize=(20, 10))
# for i, neuron in enumerate(model_neurons):
#     idx = neuron
#     # idx = task2_indices[0][neuron]
#     plt.subplot(4, 5, i + 1)
#     plt.imshow(model_weights[idx].reshape(28, 28), cmap="gray")
#     plt.axis("off")

# plt.show()

# plt.figure(figsize=(20, 10))
# for i, neuron in enumerate(model_neurons1):
#     idx = neuron
#     # idx = task2_indices[1][neuron]
#     plt.subplot(4, 5, i + 1)
#     plt.imshow(model_weights1[idx].reshape(16, 16), cmap="gray")
#     plt.axis("off")

# plt.show()

Seed:  924


100%|██████████| 9.91M/9.91M [00:00<00:00, 12.9MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 339kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 3.19MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.71MB/s]
100%|██████████| 478/478 [00:03<00:00, 133.90it/s]


Avg loss:  1.9166911259355903


100%|██████████| 478/478 [00:02<00:00, 219.45it/s]


Avg loss:  1.5095521478473393


100%|██████████| 478/478 [00:02<00:00, 220.38it/s]


Avg loss:  1.497603766838377


100%|██████████| 478/478 [00:02<00:00, 217.58it/s]


Avg loss:  1.490234434105861


100%|██████████| 478/478 [00:02<00:00, 224.47it/s]


Avg loss:  1.4888367388537738


100%|██████████| 478/478 [00:02<00:00, 184.02it/s]


Avg loss:  1.4823020477175213


100%|██████████| 478/478 [00:02<00:00, 224.53it/s]


Avg loss:  1.4802387399154726


100%|██████████| 478/478 [00:02<00:00, 226.43it/s]


Avg loss:  1.4796531128085308


100%|██████████| 478/478 [00:02<00:00, 218.79it/s]


Avg loss:  1.479382449114173


100%|██████████| 478/478 [00:02<00:00, 222.28it/s]


Avg loss:  1.478153832038576
Task 1 indices:  [tensor([  2,   4,   9,  14,  16,  18,  20,  21,  22,  23,  24,  26,  27,  28,
         31,  35,  36,  40,  42,  44,  47,  51,  57,  58,  63,  67,  69,  70,
         71,  76,  78,  85,  89,  92,  93,  97, 101, 102, 104, 107, 108, 109,
        111, 115, 117, 118, 121, 124, 126, 132, 133, 134, 135, 141, 142, 149,
        150, 152, 168, 170, 172, 177, 178, 179, 180, 182, 183, 188, 189, 195,
        196, 197, 198, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255], device='cuda:0'), tensor([  1,   4,   7,   9,  10,  11,  15,  16,  17,  21,  23,  25,  26,  27,
         28,  34,  47,  49,  50,  51,  58,  60,  61,  62,  64,  73,  75,  79,
         82,  90,  93,  95,  96,  97,  98,  99, 100, 101, 

100%|██████████| 459/459 [00:02<00:00, 186.18it/s]


Avg loss:  1.6643502107113275


100%|██████████| 459/459 [00:02<00:00, 202.50it/s]


Avg loss:  1.5130411430641457


100%|██████████| 459/459 [00:02<00:00, 217.41it/s]


Avg loss:  1.5052824493067456


100%|██████████| 459/459 [00:02<00:00, 218.45it/s]


Avg loss:  1.5036155990525788


100%|██████████| 459/459 [00:02<00:00, 160.06it/s]


Avg loss:  1.5011190970738728


100%|██████████| 459/459 [00:02<00:00, 194.42it/s]


Avg loss:  1.499105511667422


100%|██████████| 459/459 [00:02<00:00, 198.44it/s]


Avg loss:  1.5034695907875344


100%|██████████| 459/459 [00:02<00:00, 216.73it/s]


Avg loss:  1.5018870962990656


100%|██████████| 459/459 [00:02<00:00, 216.72it/s]


Avg loss:  1.49945105420738


100%|██████████| 459/459 [00:02<00:00, 218.53it/s]


Avg loss:  1.4993251021910856
Task 2 indices:  [tensor([  0,   1,   3,   5,   6,   7,   8,  10,  11,  12,  13,  15,  17,  19,
         25,  29,  30,  32,  33,  34,  37,  38,  39,  41,  43,  45,  46,  48,
         49,  50,  52,  53,  54,  55,  56,  59,  60,  61,  62,  64,  65,  66,
         68,  72,  73,  74,  75,  77,  79,  80,  81,  82,  83,  84,  86,  87,
         88,  90,  91,  94,  95,  96,  98,  99, 100, 103, 105, 106, 110, 112,
        113, 114, 116, 119, 120, 122, 123, 125, 127, 128, 129, 130, 131, 136,
        137, 138, 139, 140, 143, 144, 145, 146, 147, 148, 151, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 169, 171,
        173, 174, 175, 176, 181, 184, 185, 186, 187, 190, 191, 192, 193, 194,
        199, 200], device='cuda:0'), tensor([ 0,  2,  3,  5,  6,  8, 12, 13, 14, 18, 19, 20, 22, 24, 29, 30, 31, 32,
        33, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 48, 52, 53, 54, 55,
        56, 57, 59, 63, 65, 66, 67, 68, 69, 70, 71, 7

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
from tqdm import tqdm


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#  Router (Actor-Critic
class Router(nn.Module):
    def __init__(self, input_dim, hidden_dim, n_experts=2):
        super().__init__()
        self.fc1       = nn.Linear(input_dim, hidden_dim)
        self.actor_fc  = nn.Linear(hidden_dim, n_experts)
        self.critic_fc = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        # x: (batch, input_dim)
        h = F.relu(self.fc1(x))
        logits = self.actor_fc(h)     # (batch, n_experts)
        value  = self.critic_fc(h).squeeze(-1)
        return logits, value


def train_router(router,
                 expert_model,
                 masks,
                 data_loader,
                 n_epochs=5,
                 gamma=0.99,
                 lr=1e-3):
    optimizer = optim.Adam(router.parameters(), lr=lr)

    for epoch in range(n_epochs):
        router.train()
        epoch_reward = 0
        for data, target in tqdm(data_loader, desc=f"Epoch {epoch+1}/{n_epochs}"):
            data, target = data.view(-1,784).to(device), target.to(device)


            logits, values = router(data)
            dist    = Categorical(logits=logits)
            actions = dist.sample()           # (batch,)
            logp    = dist.log_prob(actions)  # (batch,)

            # route through chosen expert, get per-sample reward
            rewards = []
            for i in range(data.size(0)):
                x_i   = data[i:i+1]
                a     = actions[i].item()
                m     = masks[a]
                out, *_ = expert_model(
                    x_i,
                    masks=m,                       # pass it directly
                    indices_old=[None]*len(m)     # same length as layers
                )
                pred = out.argmax(dim=1)
                rewards.append((pred == target[i]).float().item())

            rewards = torch.tensor(rewards, device=device)  # (batch,)

            # compute advantages (one‐step)
            returns    = rewards
            advantages = returns - values.detach()

            # losses
            policy_loss = -(logp * advantages).mean()
            value_loss  = F.mse_loss(values, returns)
            loss        = policy_loss + 0.5 * value_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_reward += rewards.sum().item()

        avg_reward = epoch_reward / len(data_loader.dataset)
        print(f"Epoch {epoch+1}: Avg routing reward = {avg_reward:.4f}")


#  reuse  pretrained expert
original_model.eval()
for p in original_model.parameters():
    p.requires_grad = False


router = Router(input_dim=784, hidden_dim=128, n_experts=2).to(device)


full_train = torch.utils.data.ConcatDataset([data_loader_1.dataset,
                                              data_loader_2.dataset])
full_loader = torch.utils.data.DataLoader(full_train,
                                          batch_size=64,
                                          shuffle=True)
train_router(router, original_model, all_masks, full_loader,
              n_epochs=10, gamma=0.99, lr=1e-3)



# evalutaion

router.eval()
correct = 0
total   = 0
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.view(-1,784).to(device), target.to(device)
        logits, _   = router(data)
        expert_ids  = logits.argmax(dim=1)  # pick the best expert
        for i in range(data.size(0)):
            x_i = data[i:i+1]
            a   = expert_ids[i].item()
            m   = all_masks[a]
            out, *_ = original_model(
                x_i,
                masks=m,
                indices_old=[None]*len(m)
            )
            pred = out.argmax(dim=1)
            correct += int(pred == target[i])
        total += data.size(0)

acc = 100.0 * correct / total
print(f"RL-router ensemble accuracy: {acc:.2f}%")


Epoch 1/10: 100%|██████████| 938/938 [02:59<00:00,  5.23it/s]


Epoch 1: Avg routing reward = 0.8415


Epoch 2/10: 100%|██████████| 938/938 [02:58<00:00,  5.27it/s]


Epoch 2: Avg routing reward = 0.9182


Epoch 3/10: 100%|██████████| 938/938 [02:57<00:00,  5.27it/s]


Epoch 3: Avg routing reward = 0.9310


Epoch 4/10: 100%|██████████| 938/938 [02:57<00:00,  5.27it/s]


Epoch 4: Avg routing reward = 0.9381


Epoch 5/10: 100%|██████████| 938/938 [02:57<00:00,  5.29it/s]


Epoch 5: Avg routing reward = 0.9409


Epoch 6/10: 100%|██████████| 938/938 [02:56<00:00,  5.31it/s]


Epoch 6: Avg routing reward = 0.9449


Epoch 7/10: 100%|██████████| 938/938 [02:57<00:00,  5.27it/s]


Epoch 7: Avg routing reward = 0.9464


Epoch 8/10: 100%|██████████| 938/938 [02:57<00:00,  5.29it/s]


Epoch 8: Avg routing reward = 0.9465


Epoch 9/10: 100%|██████████| 938/938 [02:56<00:00,  5.31it/s]


Epoch 9: Avg routing reward = 0.9494


Epoch 10/10: 100%|██████████| 938/938 [02:56<00:00,  5.33it/s]


Epoch 10: Avg routing reward = 0.9506
RL-router ensemble accuracy: 94.81%
