In [20]:
import copy

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm


In [4]:

# Define the feature extractor (f)
class FeatureExtractor(nn.Module):
    """Convolutional encoder mapping single-channel images to 128-dimensional features."""

    def __init__(self):
        super().__init__()
        # Two conv -> ReLU -> 2x2 max-pool stages. The 3x3 convolutions use
        # padding=1, so only the pooling layers shrink the spatial size.
        stages = [
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        self.conv = nn.Sequential(*stages)
        self.flatten = nn.Flatten()
        # 64 channels on a 7x7 grid: the size a 28x28 input has after two 2x2 pools.
        self.fc = nn.Linear(64 * 7 * 7, 128)

    def forward(self, x):
        """Run the conv stack, flatten the result and project it to 128 features."""
        return self.fc(self.flatten(self.conv(x)))

# Define the classification head (θ)
class ClassificationHead(nn.Module):
    """Single linear layer that turns feature vectors into per-class logits."""

    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.fc = nn.Linear(in_features=input_dim, out_features=num_classes)

    def forward(self, x):
        """Return the raw (un-normalised) class scores for ``x``."""
        logits = self.fc(x)
        return logits

# Synthesizer: A simple generator for model inversion
class Synthesizer(nn.Module):
    """Linear generator that maps latent feature vectors to images in [0, 1].

    Args:
        feature_dim: Size of each latent input vector.
        image_shape: Shape of one generated image, e.g. ``(1, 28, 28)``.
    """

    def __init__(self, feature_dim, image_shape):
        super().__init__()
        # np.prod returns a numpy scalar; cast it so nn.Linear stores a plain
        # Python int for out_features.
        num_pixels = int(np.prod(image_shape))
        self.fc = nn.Linear(feature_dim, num_pixels)
        self.image_shape = image_shape

    def forward(self, features):
        """Return synthetic images of shape ``(batch, *image_shape)``."""
        images = self.fc(features).view(-1, *self.image_shape)
        # The sigmoid keeps every generated pixel inside [0, 1].
        return torch.sigmoid(images)


In [13]:
def create_mnist_data(num_classes, task_index, csv_file="mnist.csv", image_shape=(1, 28, 28)):
    """Build the dataset for one incremental task from a CSV of flattened images.

    The sorted unique labels are split into consecutive groups of
    ``num_classes``; the group at position ``task_index`` defines which rows
    are kept.

    Args:
        num_classes: Number of classes in this task.
        task_index: Zero-based position of the task in the class ordering.
        csv_file: Path to a CSV file with a ``label`` column; every other
            column is treated as a pixel value.
        image_shape: Shape each flattened row is reshaped to.

    Returns:
        ``TensorDataset`` of ``(images, labels)`` with float32 images scaled by
        1/255 and int64 labels holding the original label values.
    """
    # Close the file handle as soon as the CSV has been parsed.
    with open(csv_file) as mnist_file:
        df = pd.read_csv(mnist_file)

    # Determine the classes for this task
    all_classes = sorted(df['label'].unique())  # Get all unique classes
    start_class = task_index * num_classes      # Starting class for this task
    end_class = start_class + num_classes       # Ending class for this task
    valid_classes = all_classes[start_class:end_class]

    # Filter the dataset for the valid classes
    filtered_df = df[df['label'].isin(valid_classes)]

    # Separate features and labels
    X = filtered_df.drop(columns=['label']).values  # Features (flattened images)
    y = filtered_df['label'].values                 # Labels

    # Normalize and reshape the images
    X = torch.tensor(X / 255.0, dtype=torch.float32)  # Scale pixel values by 1/255
    X = X.view(-1, *image_shape)                      # Reshape to (batch, *image_shape)
    y = torch.tensor(y, dtype=torch.long)             # Convert labels to tensors

    return TensorDataset(X, y)

In [22]:

# Knowledge distillation loss
def knowledge_distillation_loss(pred, target, temperature=2):
    """Temperature-scaled KL divergence between two sets of logits.

    Both inputs are divided by ``temperature`` before the (log-)softmax over
    dim 1, and the batch-mean KL divergence is multiplied by ``temperature ** 2``.
    """
    student_log_probs = F.log_softmax(pred / temperature, dim=1)
    teacher_probs = F.softmax(target / temperature, dim=1)
    divergence = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
    return divergence * (temperature ** 2)

from tqdm import tqdm

# Training loop with progress bars
def train_r_dfcil(task_data, num_classes_per_task, feature_dim=128, image_shape=(1, 28, 28), epochs=5,
                  batch_size=32, lr=0.001):
    """Train a feature extractor and a growing classification head over a sequence of tasks.

    For every task after the first, the head is widened to cover all classes
    seen so far, the synthesizer is fitted against a frozen snapshot of the
    previous model, and a knowledge-distillation term on the old-class logits
    is added to the cross-entropy loss.

    Args:
        task_data: Sequence of ``(train_dataset, test_dataset)`` pairs, one per
            task. Only the training dataset is consumed here.
        num_classes_per_task: Number of new classes introduced by each task.
        feature_dim: Input width of the classification head and synthesizer.
        image_shape: Shape of one image, forwarded to the synthesizer.
        epochs: Epochs per task; also the number of synthesizer update steps.
        batch_size: Mini-batch size for the data loader and synthesizer batches.
        lr: Learning rate for every Adam optimizer created here.

    Returns:
        Tuple ``(feature_extractor, classification_head, synthesizer)`` as they
        stand after the last task.
    """
    feature_extractor = FeatureExtractor()
    classification_head = ClassificationHead(feature_dim, num_classes_per_task[0])
    synthesizer = Synthesizer(feature_dim, image_shape)

    old_model = None
    criterion = nn.CrossEntropyLoss()

    for task_idx, (train_data, _test_data) in enumerate(task_data):
        print(f"\nTraining on Task {task_idx + 1}/{len(task_data)}")
        train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
        num_classes = sum(num_classes_per_task[:task_idx + 1])
        num_old_classes = num_classes - num_classes_per_task[task_idx]

        if task_idx > 0:
            # Expand the classification head for new classes
            classification_head = _expand_classification_head(classification_head, feature_dim, num_classes)

        # A fresh optimizer is needed per task because the head's parameters change.
        optimizer = optim.Adam(
            list(feature_extractor.parameters()) + list(classification_head.parameters()), lr=lr
        )

        # Train synthesizer using the old model.
        # NOTE(review): the synthesizer is fitted here, but the training loop
        # below never samples from it; only real images are used.
        if old_model is not None:
            print("Training synthesizer...")
            _train_synthesizer(synthesizer, old_model, criterion, num_old_classes,
                               feature_dim, epochs, batch_size, lr)

        # Train new task
        print("Training model...")
        for epoch in tqdm(range(epochs), desc=f"Task {task_idx + 1} Training", leave=False):
            feature_extractor.train()
            classification_head.train()
            epoch_loss = 0
            for images, labels in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}", leave=False):
                optimizer.zero_grad()

                # Forward pass
                features = feature_extractor(images)
                outputs = classification_head(features)

                # Loss calculation
                loss = criterion(outputs, labels)

                # Add knowledge distillation loss if old model exists
                if old_model is not None:
                    with torch.no_grad():
                        old_outputs = old_model[1](old_model[0](images))
                    # Only the logits of previously seen classes are distilled.
                    loss = loss + knowledge_distillation_loss(outputs[:, :num_old_classes], old_outputs)

                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()

            # max(..., 1) avoids a ZeroDivisionError when the loader yields no batches.
            tqdm.write(f"Epoch {epoch + 1}/{epochs} Loss: {epoch_loss / max(len(train_loader), 1):.4f}")

        # Freeze old model: keep an independent copy so later training of the
        # live modules cannot change the distillation target.
        old_model = _frozen_snapshot(feature_extractor, classification_head)
        print(f"Finished Task {task_idx + 1}/{len(task_data)}")

    return feature_extractor, classification_head, synthesizer


def _expand_classification_head(old_head, feature_dim, num_classes):
    """Return a head with ``num_classes`` outputs whose first rows reuse ``old_head``'s weights."""
    new_head = ClassificationHead(feature_dim, num_classes)
    num_old = old_head.fc.out_features
    with torch.no_grad():
        new_head.fc.weight[:num_old].copy_(old_head.fc.weight)
        new_head.fc.bias[:num_old].copy_(old_head.fc.bias)
    return new_head


def _train_synthesizer(synthesizer, old_model, criterion, num_old_classes, feature_dim, steps, batch_size, lr):
    """Update the synthesizer so the old model assigns random old-class labels to its images."""
    old_extractor, old_head = old_model
    synth_optimizer = optim.Adam(synthesizer.parameters(), lr=lr)
    for _ in tqdm(range(steps), desc="Synthesizer Training", leave=False):
        latent = torch.randn(batch_size, feature_dim)  # Random latent features
        target_labels = torch.randint(0, num_old_classes, (batch_size,))
        synth_optimizer.zero_grad()
        preds = old_head(old_extractor(synthesizer(latent)))
        loss = criterion(preds, target_labels)
        loss.backward()
        synth_optimizer.step()


def _frozen_snapshot(feature_extractor, classification_head):
    """Return eval-mode deep copies of both modules with gradients disabled."""
    snapshot = []
    for module in (feature_extractor, classification_head):
        frozen = copy.deepcopy(module).eval()
        frozen.requires_grad_(False)
        snapshot.append(frozen)
    return tuple(snapshot)



In [23]:
if __name__ == "__main__":
    # Two incremental tasks, each introducing five classes.
    num_classes_per_task = [5, 5]
    image_shape = (1, 28, 28)

    # Load one MNIST subset per task, in task order.
    task1_data, task2_data = [
        create_mnist_data(class_count, task_index=position, image_shape=image_shape)
        for position, class_count in enumerate(num_classes_per_task)
    ]

    # Each entry is (train, test); the same dataset is used for both roles.
    task_data = [(dataset, dataset) for dataset in (task1_data, task2_data)]

    # Train R-DFCIL
    train_r_dfcil(task_data, num_classes_per_task, image_shape=image_shape)


Training on Task 1/2
Training model...


                                                      

KeyboardInterrupt: 