In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import random

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ----------------------------
# 1. Feature Extractor (ResNet-18)
# ----------------------------
class FeatureExtractor(nn.Module):
    def __init__(self):
        super(FeatureExtractor, self).__init__()
        resnet = models.resnet18(pretrained=True)
        self.feature_extractor = nn.Sequential(*list(resnet.children())[:-1])

    def forward(self, x):
        x = self.feature_extractor(x)
        return x.view(x.size(0), -1)

# ----------------------------
# 2. Gating Network (Expert Selection)
# ----------------------------
class GatingNetwork(nn.Module):
    def __init__(self, input_dim, num_experts):
        super(GatingNetwork, self).__init__()
        self.fc = nn.Linear(input_dim, num_experts)

    def forward(self, x):
        logits = self.fc(x)
        
        # Gumbel-Softmax for hard expert selection
        gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-10) + 1e-10)
        expert_scores = F.softmax((logits + gumbel_noise) / 0.5, dim=1)

        # Select the highest scoring expert
        chosen_expert = torch.argmax(expert_scores, dim=1)

        # One-hot encoding of the chosen expert
        hard_expert = F.one_hot(chosen_expert, num_classes=expert_scores.shape[1]).float()

        return hard_expert, expert_scores

# ----------------------------
# 3. Masked Classifier (Neuron Masking)
# ----------------------------
class MaskedFeatureExtractor(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(MaskedFeatureExtractor, self).__init__()
        
        # Feature extractor layers
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        
        # Store mask per task
        self.masks = []

    def apply_mask(self, x, task_id):
        """
        Apply the mask corresponding to the current task.
        """
        mask = self.masks[task_id] if task_id < len(self.masks) else torch.ones_like(x)
        return x * mask

    def forward(self, x, task_id):
        x = F.relu(self.fc1(x))
        x = self.apply_mask(x, task_id)  # Apply task-specific mask
        return x

class TaskSpecificClassifier(nn.Module):
    def __init__(self, hidden_dim, output_dim):
        super(TaskSpecificClassifier, self).__init__()
        
        # Final classification layer
        self.fc2 = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, x):
        return F.softmax(self.fc2(x), dim=1)


In [2]:
def prune_neurons(model, prune_percentage=0.2):
    """
    Prune neurons based on magnitude and create a mask for the task.
    
    Parameters:
    - model: The classifier being pruned
    - prune_percentage: Percentage of neurons to prune
    """
    with torch.no_grad():
        for name, param in model.named_parameters():
            # if 'fc1.weight' in name:
                # Get the absolute values of the weights
                weights = torch.abs(param.data)

                # Calculate threshold
                threshold = torch.quantile(weights, prune_percentage)

                # Create mask
                mask = (weights > threshold).float()

                # Store the mask
                model.masks.append(mask)
                
                # Zero out pruned connections
                param.data *= mask


In [3]:
class MixtureOfExperts(nn.Module):
    def __init__(self, feature_dim, hidden_dim, num_experts=3):
        super(MixtureOfExperts, self).__init__()
        
        self.feature_extractor = FeatureExtractor()
        self.gating_network = GatingNetwork(feature_dim, num_experts)
        self.masked_feature_extractor = MaskedFeatureExtractor(feature_dim, hidden_dim)
        
        self.num_experts = num_experts
        self.classifiers = nn.ModuleDict()  # Store different classifiers per task

    def add_task(self, task_id, output_dim):
        """
        Add a new task with a task-specific classifier.
        """
        self.classifiers[str(task_id)] = TaskSpecificClassifier(hidden_dim=self.masked_feature_extractor.fc1.out_features, output_dim=output_dim)
        self.classifiers[str(task_id)] = self.classifiers[str(task_id)].to(device)

    def forward(self, x, task_id):
        # 1. Extract features
        features = self.feature_extractor(x)

        # 2. Select expert via gating network
        hard_expert, expert_scores = self.gating_network(features)

        # 3. Apply mask to the extracted features
        masked_features = self.masked_feature_extractor(features, task_id)

        # 4. Classify using task-specific classifier
        if str(task_id) in self.classifiers:
            outputs = self.classifiers[str(task_id)](masked_features)
        else:
            raise ValueError(f"Classifier for task {task_id} not found!")

        return outputs, expert_scores, hard_expert

In [4]:
def full_training(model, dataloader, criterion, optimizer, device, num_epochs=10, task_id=0, log_interval=100):
    """
    Train the classifier fully on the current task before pruning.
    """
    model.train()

    for epoch in range(num_epochs):
        running_loss = 0.0
        correct, total = 0, 0  # For tracking accuracy

        for batch_idx, (images, labels) in enumerate(dataloader):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass
            outputs, _, _ = model(images, task_id)

            # Compute loss
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # Track accuracy
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            # Log periodically
            if batch_idx % log_interval == 0:
                print(f"Batch [{batch_idx}/{len(dataloader)}] (Task {task_id + 1}) - Loss: {loss.item():.4f}")

        # Epoch-level logging
        epoch_loss = running_loss / len(dataloader)
        epoch_accuracy = 100. * correct / total
        print(f"Epoch [{epoch+1}/{num_epochs}] (Task {task_id + 1}) - Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%")


In [5]:
def prune_and_freeze(model, prune_percentage=0.2):
    """
    Prune the model and freeze the mask for Task 1.
    """
    print("Pruning and freezing neurons for Task 1...")
    prune_neurons(model.classifier, prune_percentage)

def retrain_with_mask(model, dataloader, criterion, optimizer, device, task_id=0, num_epochs=5, log_interval=10):
    """
    Retrain the model on the given task using the frozen mask.
    """
    model.train()

    for epoch in range(num_epochs):
        running_loss = 0.0
        correct, total = 0, 0  # For accuracy tracking

        for batch_idx, (images, labels) in enumerate(dataloader):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with mask applied
            outputs, _, _ = model(images, task_id)

            # Compute loss
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # Track accuracy
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            # Log periodically
            if batch_idx % log_interval == 0:
                print(f"Batch [{batch_idx}/{len(dataloader)}] (Task {task_id + 1}) - Loss: {loss.item():.4f}")

        # Epoch-level logging
        epoch_loss = running_loss / len(dataloader)
        epoch_accuracy = 100. * correct / total
        print(f"Epoch [{epoch+1}/{num_epochs}] (Task {task_id + 1}) - Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%")



In [6]:
# CIFAR-10 dataset
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Device

# Initialize model
model = MixtureOfExperts(feature_dim=512, hidden_dim=256, num_experts=3).to(device)

# Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)



In [7]:
# Load training dataset
train_dataset = datasets.ImageFolder(root='cubs_cropped/train', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# Load test dataset
test_dataset = datasets.ImageFolder(root='cubs_cropped/test', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

class_names = train_dataset.classes
print(class_names)

['001.Black_footed_Albatross', '002.Laysan_Albatross', '003.Sooty_Albatross', '004.Groove_billed_Ani', '005.Crested_Auklet', '006.Least_Auklet', '007.Parakeet_Auklet', '008.Rhinoceros_Auklet', '009.Brewer_Blackbird', '010.Red_winged_Blackbird', '011.Rusty_Blackbird', '012.Yellow_headed_Blackbird', '013.Bobolink', '014.Indigo_Bunting', '015.Lazuli_Bunting', '016.Painted_Bunting', '017.Cardinal', '018.Spotted_Catbird', '019.Gray_Catbird', '020.Yellow_breasted_Chat', '021.Eastern_Towhee', '022.Chuck_will_Widow', '023.Brandt_Cormorant', '024.Red_faced_Cormorant', '025.Pelagic_Cormorant', '026.Bronzed_Cowbird', '027.Shiny_Cowbird', '028.Brown_Creeper', '029.American_Crow', '030.Fish_Crow', '031.Black_billed_Cuckoo', '032.Mangrove_Cuckoo', '033.Yellow_billed_Cuckoo', '034.Gray_crowned_Rosy_Finch', '035.Purple_Finch', '036.Northern_Flicker', '037.Acadian_Flycatcher', '038.Great_Crested_Flycatcher', '039.Least_Flycatcher', '040.Olive_sided_Flycatcher', '041.Scissor_tailed_Flycatcher', '042.Ver

In [8]:
# Task 1: Full training
task_1_output_dim = len(class_names)  # Number of classes for task 1

model.add_task(task_id=0, output_dim=task_1_output_dim)  # Add task-specific classifier
full_training(model, train_loader, criterion, optimizer, device, num_epochs=10, task_id=0)

# Task 1: Pruning and freezing
prune_and_freeze(model, prune_percentage=0.2, task_id=0)

# Task 1: Retraining with frozen mask
retrain_with_mask(model, train_loader, criterion, optimizer, device, task_id=0, num_epochs=5)

Batch [0/188] (Task 1) - Loss: 5.2985


KeyboardInterrupt: 

In [None]:
# Load training dataset
train_dataset = datasets.ImageFolder(root='flowers/train', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# Load test dataset
test_dataset = datasets.ImageFolder(root='flowers/test', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

class_names = train_dataset.classes
print(class_names)

# Task 2: Repeat the same pipeline (full training, pruning, and retraining)
task_2_output_dim = len(class_names)  # Number of classes for task 2 (change if needed)
model.add_task(task_id=1, output_dim=task_2_output_dim)  # Add classifier for Task 2
full_training(model, train_loader, criterion, optimizer, device, num_epochs=10, task_id=1)
prune_and_freeze(model, prune_percentage=0.2, task_id=1)
retrain_with_mask(model, train_loader, criterion, optimizer, device, task_id=1, num_epochs=5)

KeyboardInterrupt: 

In [None]:
# # Task 2: Repeat the same pipeline (full training, pruning, and retraining)
# full_training(model, train_loader, criterion, optimizer, device, num_epochs=10)
# prune_and_freeze(model, prune_percentage=0.75)
# retrain_with_mask(model, train_loader, criterion, optimizer, device, task_id=1, num_epochs=5)