# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models

class Expert(nn.Module):
    """Single expert: a torchvision backbone with a replaced classification head.

    Args:
        model_type: ``"resnet18"`` or ``"vgg16"``; selects the backbone.
        num_classes: Output width of the replacement final linear layer.

    Raises:
        ValueError: If ``model_type`` is neither of the supported names.
    """

    def __init__(self, model_type="resnet18", num_classes=10):
        super().__init__()

        # Reject unknown architectures before building anything.
        if model_type not in ("resnet18", "vgg16"):
            raise ValueError("Unsupported model type. Choose 'resnet18' or 'vgg16'")

        if model_type == "vgg16":
            backbone = models.vgg16(pretrained=True)
            # The VGG head is the last entry of the `classifier` container.
            head_in = backbone.classifier[-1].in_features
            backbone.classifier[-1] = nn.Linear(head_in, num_classes)
        else:
            backbone = models.resnet18(pretrained=True)
            # The ResNet head is the `fc` attribute.
            head_in = backbone.fc.in_features
            backbone.fc = nn.Linear(head_in, num_classes)

        self.model = backbone

    def forward(self, x):
        """Return the backbone output normalised with softmax over the last dim."""
        logits = self.model(x)
        return F.softmax(logits, dim=-1)

class Gate(nn.Module):
    """Two-layer gating network producing one softmax weight per expert.

    Args:
        n_experts: Width of the output layer (one score per expert).
        input_dim: Width of the input expected by the first linear layer.
    """

    def __init__(self, n_experts, input_dim=512):
        super().__init__()
        hidden_dim = 128
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, n_experts)

    def forward(self, x):
        """Return softmax-normalised expert scores over the last dimension."""
        hidden = torch.relu(self.fc1(x))
        scores = self.fc2(hidden)
        return F.softmax(scores, dim=-1)

class MOEModel(nn.Module):
    """Mixture of experts: gate-weighted sum of the experts' softmax outputs.

    The module owns its own Adam optimizer over all of its parameters, so
    `step` performs a complete training update.

    Args:
        n_experts: Number of `Expert` instances to create.
        model_type: Backbone name forwarded to every `Expert`.
        num_classes: Number of output classes of every expert.
        lr: Learning rate of the internal Adam optimizer.
    """

    def __init__(self, n_experts=3, model_type="resnet18", num_classes=10, lr=0.001):
        super().__init__()
        expert_list = [Expert(model_type, num_classes) for _ in range(n_experts)]
        self.experts = nn.ModuleList(expert_list)

        gate_input_dim = 512 if model_type == "resnet18" else 4096
        self.gate = Gate(n_experts, gate_input_dim)

        # Created last so that it sees both the expert and the gate parameters.
        self.optimizer = optim.Adam(self.parameters(), lr=lr)

    def forward(self, x):
        """Combine the expert outputs using the gate weights.

        NOTE(review): the gate receives the same raw `x` as the experts, yet it
        is built with an input width of 512/4096 — confirm that callers pass
        inputs whose last dimension matches, or whether backbone features were
        meant to be fed to the gate instead.
        """
        # Trailing axis added so the weights act as a column vector in matmul;
        # presumably (batch, n_experts, 1) — confirm against the gate input.
        gate_weights = self.gate(x).unsqueeze(-1)
        # Experts are stacked along a new last axis.
        per_expert = [expert(x) for expert in self.experts]
        stacked = torch.stack(per_expert, dim=-1)
        mixture = torch.matmul(stacked, gate_weights)
        return mixture.squeeze(-1)

    def probabilities(self, x, y):
        """Return sum over dim 1 of ``y * forward(x)``.

        Presumably `y` is one-hot, making this the probability assigned to the
        target class — verify against the caller.
        """
        mixture = self.forward(x)
        return (y * mixture).sum(dim=1)

    def calculate_loss(self, x, y):
        """Return the mean negative log of `probabilities(x, y)`."""
        target_probs = self.probabilities(x, y)
        # The small constant keeps log() finite when a probability is zero.
        log_probs = torch.log(target_probs + 1e-9)
        return -log_probs.mean()

    def grad(self, x, y):
        """Clear old gradients, back-propagate the loss and return it."""
        self.optimizer.zero_grad()
        loss_value = self.calculate_loss(x, y)
        loss_value.backward()
        return loss_value

    def step(self, x, y):
        """Run one optimizer update on (x, y) and return the pre-update loss."""
        loss_value = self.grad(x, y)
        self.optimizer.step()
        return loss_value


# In [None]:
import torch
import torch.nn as nn

class SparsePruner:
    """Performs magnitude pruning on the experts of an MOE model.

    Masks are dictionaries keyed by ``(expert_idx, layer_idx)``, where
    ``layer_idx`` is the position of a ``Conv2d``/``Linear`` module in
    ``expert.modules()``. In a mask, ``0`` marks a pruned weight and any other
    value is the index of the dataset the weight is assigned to.

    Args:
        model: Object exposing an iterable ``experts`` of ``nn.Module``s.
        prune_perc: Fraction (0..1) of the current dataset's weights to prune
            in every layer.
        previous_masks: Mask dictionary as described above.
        train_bias: When false, bias gradients of prunable layers are zeroed
            by `make_grads_zero`.
        train_bn: When false, BatchNorm weight/bias gradients of the experts
            are zeroed by `make_grads_zero`.
    """

    def __init__(self, model, prune_perc, previous_masks, train_bias, train_bn):
        self.model = model
        self.prune_perc = prune_perc
        self.train_bias = train_bias
        self.train_bn = train_bn
        self.current_masks = None
        self.previous_masks = previous_masks

        # The current dataset index is the largest value in the first mask.
        # Stored as a plain int so it can be compared with masks on any device.
        valid_key = list(previous_masks.keys())[0]
        self.current_dataset_idx = int(previous_masks[valid_key].max())

    def _prunable_layers(self):
        """Yield ``((expert_idx, layer_idx), module)`` for each Conv2d/Linear."""
        for expert_idx, expert in enumerate(self.model.experts):
            for layer_idx, module in enumerate(expert.modules()):
                if isinstance(module, (nn.Conv2d, nn.Linear)):
                    yield (expert_idx, layer_idx), module

    def pruning_mask(self, weights, previous_mask, layer_idx):
        """Computes a pruning mask based on weight magnitudes.

        Among the weights assigned to the current dataset, those whose
        magnitude is at or below the ``prune_perc`` quantile are set to 0 in
        the mask.

        Args:
            weights: Weight tensor of the layer.
            previous_mask: Mask tensor with the same shape as ``weights``.
            layer_idx: Layer index, used for logging only.

        Returns:
            The updated mask on the device of ``weights``. When
            ``previous_mask`` already lives on that device it is modified in
            place and returned.
        """
        # Follow the weights' device instead of forcing CUDA.
        mask = previous_mask.to(weights.device)
        owned = mask.eq(self.current_dataset_idx)
        candidates = weights[owned]
        total = candidates.numel()
        cutoff_rank = round(self.prune_perc * total)

        # kthvalue requires k >= 1; a rank of 0 means there is nothing to prune.
        if cutoff_rank >= 1:
            cutoff_value = candidates.abs().view(-1).cpu().kthvalue(cutoff_rank)[0].item()
            remove_mask = weights.abs().le(cutoff_value) & owned
            mask[remove_mask] = 0

        pruned = int(mask.eq(0).sum())
        percent = 100 * pruned / total if total else 0.0
        print(f'Layer #{layer_idx}, pruned {pruned}/{total} ({percent:.2f}%)')
        return mask

    def prune(self):
        """Prunes only the expert networks while keeping the gating network untouched.

        Fills ``current_masks`` and zeroes the pruned weights. Must be called
        at most once per pruner (asserted).
        """
        print(f'Pruning experts for dataset idx: {self.current_dataset_idx}')
        assert self.current_masks is None, 'Pruning twice?'
        self.current_masks = {}

        print(f'Pruning each expert layer by removing {100 * self.prune_perc:.2f}% of values')

        for key, module in self._prunable_layers():
            mask = self.pruning_mask(module.weight.data, self.previous_masks[key], key[1])
            self.current_masks[key] = mask
            module.weight.data[mask.eq(0)] = 0.0

    def make_grads_zero(self):
        """Sets gradients of non-trainable weights to zero for the expert networks.

        Weight gradients are zeroed wherever the mask differs from the current
        dataset index. Bias gradients are zeroed when ``train_bias`` is false,
        and BatchNorm gradients when ``train_bn`` is false. Parameters without
        a gradient are skipped.
        """
        assert self.current_masks is not None

        for key, module in self._prunable_layers():
            grad = module.weight.grad
            if grad is None:
                continue
            layer_mask = self.current_masks[key].to(grad.device)
            grad.data[layer_mask.ne(self.current_dataset_idx)] = 0
            if not self.train_bias and module.bias is not None and module.bias.grad is not None:
                module.bias.grad.data.fill_(0)

        if not self.train_bn:
            batchnorm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
            for expert in self.model.experts:
                for module in expert.modules():
                    if not isinstance(module, batchnorm_types):
                        continue
                    for param in (module.weight, module.bias):
                        if param is not None and param.grad is not None:
                            param.grad.data.fill_(0)

    def make_pruned_zero(self):
        """Forces pruned weights to remain zero in the expert networks."""
        assert self.current_masks is not None

        for key, module in self._prunable_layers():
            weight = module.weight.data
            weight[self.current_masks[key].to(weight.device).eq(0)] = 0.0

    def apply_mask(self, dataset_idx):
        """Applies the pruning mask for a specific dataset.

        Zeroes every weight that is pruned (mask 0) or assigned to a dataset
        with an index greater than ``dataset_idx``.
        """
        for key, module in self._prunable_layers():
            weight = module.weight.data
            mask = self.previous_masks[key].to(weight.device)
            weight[mask.eq(0)] = 0.0
            weight[mask.gt(dataset_idx)] = 0.0

    def make_finetuning_mask(self):
        """Allows previously pruned weights to be trainable for the new dataset.

        Increments the current dataset index, assigns every pruned (0) entry of
        ``previous_masks`` to it in place, and makes ``current_masks`` refer to
        ``previous_masks``.
        """
        assert self.previous_masks is not None
        self.current_dataset_idx += 1

        for key, _ in self._prunable_layers():
            mask = self.previous_masks[key]
            mask[mask.eq(0)] = self.current_dataset_idx

        self.current_masks = self.previous_masks
