In [2]:
# modified version of notebook from https://github.com/lyeoni/pytorch-mnist-GAN
# CL_course-phd-data-science assignment solution code by student: Angela Pomaro


# Prerequisites
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
from torch.autograd import Variable
from torchvision.utils import save_image
import copy

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define hyperparameters
batch_size = 100
z_dim = 100
learning_rate = 0.0002
n_epoch = 10  # Training epochs for each GAN
num_tasks = 5  # Number of class-incremental tasks
mnist_dim = 28 * 28  # Dimension of the MNIST data

# MNIST Dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,))])

train_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./mnist_data/', train=False, transform=transform, download=False)

# Split dataset into five tasks, each containing two classes
def split_dataset(dataset, num_tasks=5):
    task_indices = []
    targets = np.array(dataset.targets)
    classes = np.unique(targets)
    for i in range(num_tasks):
        idx = np.where(np.isin(targets, classes[i*2:i*2+2]))[0]
        task_indices.append(idx)
    return [Subset(dataset, idx) for idx in task_indices]

class Generator(nn.Module):
    def __init__(self, g_input_dim, g_output_dim):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(g_input_dim, 256)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features*2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features*2)
        self.fc4 = nn.Linear(self.fc3.out_features, g_output_dim)

    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.leaky_relu(self.fc3(x), 0.2)
        return torch.tanh(self.fc4(x))

class Discriminator(nn.Module):
    def __init__(self, d_input_dim):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(d_input_dim, 1024)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features // 2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features // 2)
        self.fc4 = nn.Linear(self.fc3.out_features, 1)

    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.dropout(x, 0.3)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.3)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.dropout(x, 0.3)
        return torch.sigmoid(self.fc4(x))

def train_gan(generator, discriminator, data_loader, epochs=5):
    criterion = nn.BCELoss()
    G_optimizer = optim.Adam(generator.parameters(), lr=learning_rate)
    D_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate)

    for epoch in range(epochs):
        for data, _ in data_loader:
            batch_size = data.size(0)
            data = data.view(batch_size, -1).to(device)

            # Train Discriminator
            real_labels = torch.ones(batch_size, 1).to(device)
            fake_labels = torch.zeros(batch_size, 1).to(device)

            D_optimizer.zero_grad()
            outputs = discriminator(data)
            d_loss_real = criterion(outputs, real_labels)
            d_loss_real.backward()

            z = torch.randn(batch_size, z_dim).to(device)
            fake_data = generator(z)
            outputs = discriminator(fake_data.detach())
            d_loss_fake = criterion(outputs, fake_labels)
            d_loss_fake.backward()
            D_optimizer.step()

            # Train Generator
            G_optimizer.zero_grad()
            outputs = discriminator(fake_data)
            g_loss = criterion(outputs, real_labels)
            g_loss.backward()
            G_optimizer.step()

class Classifier(nn.Module):
    def __init__(self, input_dim=784, num_classes=10):
        super(Classifier, self).__init__()
        self.fc1 = nn.Linear(input_dim, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

def train_classifier(model, optimizer, criterion, train_loader, epochs=5):
    model.train()
    for epoch in range(epochs):
        for data, targets in train_loader:
            data, targets = data.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

def evaluate(model, data_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, targets in data_loader:
            data, targets = data.to(device), targets.to(device)
            outputs = model(data)
            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
    accuracy = 100 * correct / total
    return accuracy

def class_incremental_learning(train_dataset, test_dataset, num_tasks=5):
    classifier = Classifier(num_classes=10).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(classifier.parameters(), lr=learning_rate)

    all_acc = []
    avg_acc = []

    train_tasks = split_dataset(train_dataset, num_tasks)
    test_tasks = split_dataset(test_dataset, num_tasks)

    train_loaders = [DataLoader(task, batch_size=batch_size, shuffle=True) for task in train_tasks]
    test_loaders = [DataLoader(task, batch_size=batch_size, shuffle=False) for task in test_tasks]

    stored_generators = []
    synthetic_datasets = []

    for i in range(num_tasks):
        print(f"\nTraining on task {i+1}/{num_tasks}...")

        # Train GAN for the current task
        generator = Generator(g_input_dim=z_dim, g_output_dim=mnist_dim).to(device)
        discriminator = Discriminator(d_input_dim=mnist_dim).to(device)
        train_gan(generator, discriminator, train_loaders[i], epochs=n_epoch)

        # Store the current generator
        stored_generators.append(copy.deepcopy(generator))

        # Generate synthetic data for previous tasks
        if i > 0:
            for j in range(i):
                num_samples = len(train_tasks[j])
                z = torch.randn(num_samples, z_dim).to(device)
                synthetic_data = stored_generators[j](z).view(num_samples, 1, 28, 28).detach().cpu()
                synthetic_targets = torch.tensor([train_tasks[j][k][1] for k in range(num_samples)])
                synthetic_dataset = torch.utils.data.TensorDataset(synthetic_data, synthetic_targets)
                synthetic_datasets.append(synthetic_dataset)

        # Combine real data from current task with synthetic data from previous tasks
        combined_data = []
        combined_targets = []

        # Add real data from current task
        for data, targets in train_loaders[i]:
            combined_data.append(data)
            combined_targets.append(targets)

        # Add synthetic data from previous tasks
        for synthetic_dataset in synthetic_datasets:
            for data, targets in DataLoader(synthetic_dataset, batch_size=batch_size, shuffle=True):
                combined_data.append(data)
                combined_targets.append(targets)

        combined_data = torch.cat(combined_data)
        combined_targets = torch.cat(combined_targets)
        combined_dataset = torch.utils.data.TensorDataset(combined_data, combined_targets)
        combined_loader = DataLoader(combined_dataset, batch_size=batch_size, shuffle=True)

        # Train classifier on combined data
        train_classifier(classifier, optimizer, criterion, combined_loader, epochs=n_epoch)

        # Evaluate on all seen tasks
        acc_per_task = []
        for j in range(i+1):
            acc = evaluate(classifier, test_loaders[j])
            acc_per_task.append(acc)
            print(f"Accuracy on task {j+1}: {acc:.2f}%")

        avg_acc.append(np.mean(acc_per_task))
        all_acc.append(acc_per_task)

    return all_acc, avg_acc, classifier

# Run the class-incremental learning
all_acc, avg_acc, classifier = class_incremental_learning(train_dataset, test_dataset)

print("\nFinal Average Accuracy per Task:")
for i, acc in enumerate(avg_acc):
    print(f"Task {i+1}: {acc:.2f}%")

# Compute Forward Transfer and Backward Transfer
fwt = (avg_acc[-1] - avg_acc[0]) / avg_acc[0] * 100
bwt = np.mean([avg_acc[-1] - acc for acc in avg_acc[:-1]])

print(f"\nForward Transfer: {fwt:.2f}%")
print(f"Backward Transfer: {bwt:.2f}%")
print()
print()

# Memory and computational requirements
generator = Generator(g_input_dim=z_dim, g_output_dim=mnist_dim)
classifier = Classifier(input_dim=mnist_dim, num_classes=10)

gan_memory = sum(p.numel() for p in generator.parameters())
cla_memory = sum(p.numel() for p in classifier.parameters())

# Print memory requirements
print(f"Memory requirements:")
print(f"Generator: {gan_memory} parameters")
print(f"Classifier: {cla_memory} parameters")
print()

# Estimate computational requirements (FLOPS)
num_operations_generator = z_dim * mnist_dim * 2  # Assume simple linear operation
generator_flops = num_operations_generator

# Classifier FLOPS estimation
num_operations_classifier = mnist_dim * 10 * 2  # Assume simple linear operation
classifier_flops = num_operations_classifier

# Print computational requirements
print(f"Computational requirements (FLOPS):")
print(f"Generator: {generator_flops} FLOPS")
print(f"Classifier: {classifier_flops} FLOPS")
print()

# Cumulative memory usage over 5 tasks

# Initialize cumulative memory usage
cumulative_memory = 0

# Loop through each task (of two classes each)
for task in range(5):
    # Calculate memory requirements for Generator and Classifier for this task
    generator = Generator(g_input_dim=z_dim, g_output_dim=mnist_dim)
    classifier = Classifier(input_dim=mnist_dim, num_classes=10)

    # Calculate memory for this task
    generator_memory = sum(p.numel() for p in generator.parameters())
    classifier_memory = sum(p.numel() for p in classifier.parameters())
    task_memory = generator_memory + classifier_memory

    # Accumulate cumulative memory usage
    cumulative_memory += task_memory

    # Print memory usage for this task
    print(f"Memory usage for task {task+1}: Generator = {generator_memory} parameters, Classifier = {classifier_memory} parameters")

# Print cumulative memory usage
print(f"\nCumulative memory usage over 5 tasks: {cumulative_memory} parameters")
print()



Training on task 1/5...
Accuracy on task 1: 99.95%

Training on task 2/5...
Accuracy on task 1: 7.14%
Accuracy on task 2: 99.66%

Training on task 3/5...
Accuracy on task 1: 0.14%
Accuracy on task 2: 22.92%
Accuracy on task 3: 99.89%

Training on task 4/5...
Accuracy on task 1: 0.28%
Accuracy on task 2: 15.52%
Accuracy on task 3: 23.69%
Accuracy on task 4: 99.85%

Training on task 5/5...
Accuracy on task 1: 0.00%
Accuracy on task 2: 8.77%
Accuracy on task 3: 0.05%
Accuracy on task 4: 2.72%
Accuracy on task 5: 99.34%

Final Average Accuracy per Task:
Task 1: 99.95%
Task 2: 53.40%
Task 3: 40.98%
Task 4: 34.84%
Task 5: 22.18%

Forward Transfer: -77.81%
Backward Transfer: -35.12%


Memory requirements:
Generator: 1486352 parameters
Classifier: 235146 parameters

Computational requirements (FLOPS):
Generator: 156800 FLOPS
Classifier: 15680 FLOPS

Memory usage for task 1: Generator = 1486352 parameters, Classifier = 235146 parameters
Memory usage for task 2: Generator = 1486352 parameters, 