In [3]:
# modified version of notebook from https://github.com/lyeoni/pytorch-mnist-GAN


# prerequisites
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
from torch.autograd import Variable
from torchvision.utils import save_image
import copy

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
# Define hyperparameters
batch_size = 100
z_dim = 100
learning_rate = 0.0002
n_epoch = 10  # Training epochs for each GAN
num_tasks = 5  # Number of class-incremental tasks
mnist_dim = 28 * 28  # Dimension of the MNIST data

In [5]:
# MNIST Dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))])


train_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./mnist_data/', train=False, transform=transform, download=False)

# Split dataset into five tasks, each containing two classes
def split_dataset(dataset, num_tasks=5):
    task_indices = []
    targets = np.array(dataset.targets)
    classes = np.unique(targets)
    for i in range(num_tasks):
        idx = np.where(np.isin(targets, classes[i*2:i*2+2]))[0]
        task_indices.append(idx)
    return [Subset(dataset, idx) for idx in task_indices]

train_tasks = split_dataset(train_dataset, num_tasks)
test_tasks = split_dataset(test_dataset, num_tasks)

train_loaders = [DataLoader(task, batch_size=batch_size, shuffle=True) for task in train_tasks]
test_loaders = [DataLoader(task, batch_size=batch_size, shuffle=False) for task in test_tasks]

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 12567574.68it/s]


Extracting ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 368529.64it/s]


Extracting ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 3139947.51it/s]


Extracting ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 3253164.07it/s]

Extracting ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw






In [6]:
class Generator(nn.Module):
    def __init__(self, g_input_dim, g_output_dim):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(g_input_dim, 256)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features*2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features*2)
        self.fc4 = nn.Linear(self.fc3.out_features, g_output_dim)

    # forward method
    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.leaky_relu(self.fc3(x), 0.2)
        return torch.tanh(self.fc4(x))

class Discriminator(nn.Module):
    def __init__(self, d_input_dim):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(d_input_dim, 1024)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features//2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features//2)
        self.fc4 = nn.Linear(self.fc3.out_features, 1)

    # forward method
    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.dropout(x, 0.3)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.3)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.dropout(x, 0.3)
        return torch.sigmoid(self.fc4(x))

In [7]:
# Function to train GAN (Generator and Discriminator)
def train_gan(generator, discriminator, data_loader, epochs=5):
    criterion = nn.BCELoss()
    G_optimizer = optim.Adam(generator.parameters(), lr=learning_rate)
    D_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate)

    for epoch in range(epochs):
        for data, _ in data_loader:
            batch_size = data.size(0)
            data = data.view(batch_size, -1).to(device)

            # Train Discriminator
            real_labels = torch.ones(batch_size, 1).to(device)
            fake_labels = torch.zeros(batch_size, 1).to(device)

            D_optimizer.zero_grad()
            outputs = discriminator(data)
            d_loss_real = criterion(outputs, real_labels)
            d_loss_real.backward()

            z = torch.randn(batch_size, z_dim).to(device)
            fake_data = generator(z)
            outputs = discriminator(fake_data.detach())
            d_loss_fake = criterion(outputs, fake_labels)
            d_loss_fake.backward()
            D_optimizer.step()

            # Train Generator
            G_optimizer.zero_grad()
            outputs = discriminator(fake_data)
            g_loss = criterion(outputs, real_labels)
            g_loss.backward()
            G_optimizer.step()

In [8]:
# Define Classifier model
class Classifier(nn.Module):
    def __init__(self, input_dim=784, num_classes=10):
        super(Classifier, self).__init__()
        self.fc1 = nn.Linear(input_dim, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

In [9]:
# Function to train the Classifier
def train_classifier(model, optimizer, criterion, train_loader, epochs=5):
    model.train()
    for epoch in range(epochs):
        for data, targets in train_loader:
            data, targets = data.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

In [10]:
# Evaluation function to calculate accuracy
def evaluate(model, data_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, targets in data_loader:
            data, targets = data.to(device), targets.to(device)
            outputs = model(data)
            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
    accuracy = 100 * correct / total
    return accuracy

In [12]:
# Class Incremental Learning function
def class_incremental_learning(train_loaders, test_loaders, num_tasks=5):
    classifier = Classifier(num_classes=10).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(classifier.parameters(), lr=learning_rate)

    all_acc = []
    avg_acc = []

    for i in range(num_tasks):
        print(f"\nTraining on task {i+1}/{num_tasks}...")

        # Train GAN for the current task
        generator = Generator(g_input_dim=z_dim, g_output_dim=mnist_dim).to(device)
        discriminator = Discriminator(d_input_dim=mnist_dim).to(device)
        train_gan(generator, discriminator, train_loaders[i], epochs=n_epoch)

        # Generate synthetic data for previous tasks
        for j in range(i):
            num_samples = len(train_tasks[j])
            z = torch.randn(num_samples, z_dim).to(device)
            synthetic_data = generator(z).view(num_samples, 1, 28, 28).detach().cpu()
            synthetic_targets = torch.tensor([j for _ in range(num_samples)])
            synthetic_dataset = torch.utils.data.TensorDataset(synthetic_data, synthetic_targets)
            synthetic_loader = DataLoader(synthetic_dataset, batch_size=batch_size, shuffle=True)

            # Train classifier on synthetic data
            train_classifier(classifier, optimizer, criterion, synthetic_loader)

        # Train classifier on new task's real data
        train_classifier(classifier, optimizer, criterion, train_loaders[i])

        # Evaluate on all seen tasks
        acc_per_task = []
        for j in range(i+1):
            acc = evaluate(classifier, test_loaders[j])
            acc_per_task.append(acc)
            print(f"Accuracy on task {j+1}: {acc:.2f}%")

        avg_acc.append(np.mean(acc_per_task))
        all_acc.append(acc_per_task)

    return all_acc, avg_acc

In [14]:
# Run the class-incremental learning
all_acc, avg_acc = class_incremental_learning(train_loaders, test_loaders)

print("\nFinal Average Accuracy per Task:")
for i, acc in enumerate(avg_acc):
    print(f"Task {i+1}: {acc:.2f}%")

# Compute Forward Transfer and Backward Transfer
fwt = (avg_acc[-1] - avg_acc[0]) / avg_acc[0] * 100
bwt = np.mean([avg_acc[-1] - acc for acc in avg_acc[:-1]])

print(f"\nForward Transfer: {fwt:.2f}%")
print(f"Backward Transfer: {bwt:.2f}%")


Training on task 1/5...
Accuracy on task 1: 99.91%

Training on task 2/5...
Accuracy on task 1: 0.00%
Accuracy on task 2: 97.89%

Training on task 3/5...
Accuracy on task 1: 0.00%
Accuracy on task 2: 0.00%
Accuracy on task 3: 99.36%

Training on task 4/5...
Accuracy on task 1: 0.00%
Accuracy on task 2: 0.00%
Accuracy on task 3: 0.00%
Accuracy on task 4: 99.65%

Training on task 5/5...
Accuracy on task 1: 0.00%
Accuracy on task 2: 0.00%
Accuracy on task 3: 0.00%
Accuracy on task 4: 0.00%
Accuracy on task 5: 98.08%

Final Average Accuracy per Task:
Task 1: 99.91%
Task 2: 48.95%
Task 3: 33.12%
Task 4: 24.91%
Task 5: 19.62%

Forward Transfer: -80.36%
Backward Transfer: -32.10%


In [41]:
# Set example dimensions
g_input_dim = 100  # Example input dimension for Generator
g_output_dim = 784  # Example output dimension for Generator

input_dim = 784  # Example input dimension for Classifier
output_dim = 10  # Example output dimension for Classifier

# Create an instance of the Generator and Classifier
generator = Generator(g_input_dim, g_output_dim)
classifier = Classifier(input_dim, output_dim)

# 1. Measure the memory and (estimate) computational requirements for generative and vanilla replay with raw samples
gan_memory = sum(p.numel() for p in generator.parameters())
cla_memory = sum(p.numel() for p in classifier.parameters())

# Print memory requirements
print(f"Memory requirements:")
print(f"Generator: {gan_memory:.2f}")
print(f"Classifier: {cla_memory:.2f}")
print()

# Estimate computational requirements (FLOPS)
# This is a simplified estimate and may vary based on actual operations and hardware
# Assume input size for estimation
# Example input sizes
input_size_generator = (1, g_input_dim)  # Example input size for Generator
input_size_classifier = (1, input_dim)    # Example input size for Classifier

# Estimate computational requirements (FLOPS) - This is a simplified estimation
# Generator FLOPS estimation
dummy_input_generator = torch.randn(input_size_generator)
num_operations_generator = g_input_dim * g_output_dim * 2  # Assume simple linear operation
generator_flops = num_operations_generator

# Classifier FLOPS estimation
dummy_input_classifier = torch.randn(input_size_classifier)
num_operations_classifier = input_dim * output_dim * 2  # Assume simple linear operation
classifier_flops = num_operations_classifier

# Print computational requirements
print(f"Computational requirements (FLOPS):")
print(f"Generator: {generator_flops} FLOPS")
print(f"Classifier: {classifier_flops} FLOPS")
print()

# Comment on memory requirements comparison
if gan_memory == cla_memory:
    print("Memory requirements are comparable.")
else:
    print("Memory requirements are not comparable due to differences in parameter counts.")


Memory requirements:
Generator: 1486352.00
Classifier: 235146.00

Computational requirements (FLOPS):
Generator: 156800 FLOPS
Classifier: 15680 FLOPS

Memory requirements are not comparable due to differences in parameter counts.


In [44]:
# Initialize cumulative memory usage
cumulative_memory = 0

# Loop through each split (five splits of two classes each)
for split in range(5):
    # Calculate memory requirements for Generator and Classifier for this split
    generator = Generator(g_input_dim, g_output_dim)
    classifier = Classifier(input_dim, output_dim)

    # Calculate memory for this split
    generator_memory = sum(p.numel() for p in generator.parameters())
    classifier_memory = sum(p.numel() for p in classifier.parameters())
    split_memory = generator_memory + classifier_memory

    # Accumulate cumulative memory usage
    cumulative_memory += split_memory

    # Print memory usage for this split
    print(f"Memory usage for Split {split+1}: Generator = {generator_memory} parameters, Classifier = {classifier_memory} parameters")

# Print cumulative memory usage
print(f"\nCumulative memory usage over 5 splits: {cumulative_memory} parameters")


Memory usage for Split 1: Generator = 1486352 parameters, Classifier = 235146 parameters
Memory usage for Split 2: Generator = 1486352 parameters, Classifier = 235146 parameters
Memory usage for Split 3: Generator = 1486352 parameters, Classifier = 235146 parameters
Memory usage for Split 4: Generator = 1486352 parameters, Classifier = 235146 parameters
Memory usage for Split 5: Generator = 1486352 parameters, Classifier = 235146 parameters

Cumulative memory usage over 5 splits: 8607490 parameters


In [48]:
# 2. How would the solution scale with n tasks? Is it linear, quadratic, or something else?
# Varying number of tasks (n)
n_tasks = [1, 5, 10, 20, 50, 100]  # Example values of n

# Function to calculate memory requirements for a given number of tasks (splits)
def calculate_memory_usage(n_tasks, g_input_dim, g_output_dim, input_dim, output_dim):
    cumulative_memory = 0

for n in n_tasks:
    print(f"Number of tasks: {n}")

    # Create instances of Generator and Classifier for each task
    generator = Generator(g_input_dim, g_output_dim)
    classifier = Classifier(input_dim, output_dim)

    # Calculate memory for this task
    gan_memory = sum(p.numel() for p in generator.parameters())
    cla_memory = sum(p.numel() for p in classifier.parameters())
    task_memory = generator_memory + classifier_memory

    # Accumulate cumulative memory usage
    cumulative_memory += task_memory

    # Print memory usage for this task
    print(f"Memory usage for Task {n+1}: Generator = {generator_memory} parameters, Classifier = {classifier_memory} parameters")

# Print cumulative memory usage
print(f"\nCumulative memory usage over {num_tasks} tasks: {cumulative_memory} parameters")

# Example dimensions and number of tasks
g_input_dim = 100  # Example input dimension for Generator
g_output_dim = 784  # Example output dimension for Generator
input_dim = 784  # Example input dimension for Classifier
output_dim = 10  # Example output dimension for Classifier
num_tasks = 5  # Number of tasks (splits), each split adds two new classes

# Calculate and print memory usage for the specified number of tasks
calculate_memory_usage(n_tasks, g_input_dim, g_output_dim, input_dim, output_dim)


Number of tasks: 1
Memory usage for Task 2: Generator = 1486352 parameters, Classifier = 235146 parameters
Number of tasks: 5
Memory usage for Task 6: Generator = 1486352 parameters, Classifier = 235146 parameters
Number of tasks: 10
Memory usage for Task 11: Generator = 1486352 parameters, Classifier = 235146 parameters
Number of tasks: 20
Memory usage for Task 21: Generator = 1486352 parameters, Classifier = 235146 parameters
Number of tasks: 50
Memory usage for Task 51: Generator = 1486352 parameters, Classifier = 235146 parameters
Number of tasks: 100
Memory usage for Task 101: Generator = 1486352 parameters, Classifier = 235146 parameters

Cumulative memory usage over 5 tasks: 20657976 parameters


In [49]:
# 3. What are the downsides of this approach?
# Function to simulate memory usage across multiple tasks (splits)
def simulate_memory_usage(num_tasks, g_input_dim, g_output_dim, input_dim, output_dim):
    cumulative_memory = 0

    for task in range(num_tasks):
        # Example: Creating Generator and Classifier for each task
        generator = Generator(g_input_dim, g_output_dim)
        classifier = Classifier(input_dim, output_dim)

        # Measure memory usage of generator and classifier
        generator_memory = calculate_model_memory(generator)
        classifier_memory = calculate_model_memory(classifier)

        # Calculate total memory usage for this task
        task_memory = generator_memory + classifier_memory

        # Accumulate cumulative memory usage
        cumulative_memory += task_memory

        # Print memory usage for this task
        print(f"Task {task+1}: Generator Memory = {generator_memory} bytes, Classifier Memory = {classifier_memory} bytes")

    # Print cumulative memory usage
    print(f"\nCumulative Memory Usage over {num_tasks} tasks: {cumulative_memory} bytes")

# Function to calculate memory usage of a model
def calculate_model_memory(model):
    # Calculate total number of parameters
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    # Assume each parameter is stored as a float32 (4 bytes)
    memory_bytes = total_params * 4

    return memory_bytes

# Example dimensions and number of tasks
g_input_dim = 100  # Example input dimension for Generator
g_output_dim = 784  # Example output dimension for Generator
input_dim = 784  # Example input dimension for Classifier
output_dim = 10  # Example output dimension for Classifier
num_tasks = 5  # Number of tasks (splits), each split adds two new classes

# Simulate memory usage across the specified number of tasks
simulate_memory_usage(num_tasks, g_input_dim, g_output_dim, input_dim, output_dim)

Task 1: Generator Memory = 5945408 bytes, Classifier Memory = 940584 bytes
Task 2: Generator Memory = 5945408 bytes, Classifier Memory = 940584 bytes
Task 3: Generator Memory = 5945408 bytes, Classifier Memory = 940584 bytes
Task 4: Generator Memory = 5945408 bytes, Classifier Memory = 940584 bytes
Task 5: Generator Memory = 5945408 bytes, Classifier Memory = 940584 bytes

Cumulative Memory Usage over 5 tasks: 34429960 bytes
