In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, ConcatDataset
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.nn.init as init
import torch.utils.data as data
import matplotlib.pyplot as plt

# MLP model definition
class MLP(nn.Module):
    """Fully connected classifier with two 128-unit hidden layers.

    Args:
        channel: Number of image channels. ``1`` selects a 28x28x1 input
            (784 features); any other value selects a 32x32x3 input
            (3072 features).
        num_classes: Width of the output layer (one score per class).
    """

    def __init__(self, channel, num_classes):
        super().__init__()
        hidden_units = 128
        if channel == 1:
            in_features = 28 * 28 * 1
        else:
            in_features = 32 * 32 * 3
        self.fc_1 = nn.Linear(in_features, hidden_units)
        self.fc_2 = nn.Linear(hidden_units, hidden_units)
        self.fc_3 = nn.Linear(hidden_units, num_classes)

    def forward(self, x):
        """Flatten every sample in the batch and return raw class scores."""
        flat = x.view(x.size(0), -1)
        hidden = F.relu(self.fc_1(flat))
        hidden = F.relu(self.fc_2(hidden))
        return self.fc_3(hidden)

# Gradient matching algorithm
def gradient_matching_algorithm(
    model,
    dataset,
    lr_condensed=0.1,
    num_iterations=10,
    num_opt_steps=1,
    minibatch_size=256,
):
    """Repeatedly re-initialise ``model`` and take SGD steps on ``dataset``.

    NOTE(review): despite the name, the optimiser created here is built over
    ``model.parameters()`` only, so this routine updates the model's weights
    and never modifies the samples in ``dataset``.

    Args:
        model: Module to optimise; it is modified in place.
        dataset: Dataset of ``(input, target)`` pairs to draw batches from.
        lr_condensed: Learning rate of the SGD optimiser.
        num_iterations: Number of outer iterations. Each one starts by
            re-randomising the model weights.
        num_opt_steps: Number of SGD steps taken per outer iteration.
        minibatch_size: Number of samples requested for every step
            (previously hard-coded to 256).

    Returns:
        The same ``model`` object that was passed in.
    """
    optimizer_condensed = optim.SGD(model.parameters(), lr=lr_condensed)

    for _ in range(num_iterations):
        # Weights are overwritten in place, so the parameter references held
        # by the optimiser stay valid across re-initialisations.
        model.apply(random_weights_initialization)

        for _ in range(num_opt_steps):
            inputs, targets = sample_mini_batch(dataset, minibatch_size=minibatch_size)
            loss = F.cross_entropy(model(inputs), targets)

            optimizer_condensed.zero_grad()
            loss.backward()
            optimizer_condensed.step()

    return model

# Training the model
def train_model(model, train_loader, lr_model=0.01, num_opt_steps_model=50, reinitialize=True):
    """Train ``model`` with plain SGD and cross-entropy loss.

    The weights used to be re-randomised at the start of *every* epoch, which
    threw away all earlier progress and made any number of epochs equivalent
    to a single one. Re-initialisation now happens once, before the first
    epoch, so the epochs accumulate.

    Args:
        model: Module to train; it is modified in place.
        train_loader: Iterable yielding ``(inputs, targets)`` batches.
        lr_model: Learning rate of the SGD optimiser.
        num_opt_steps_model: Number of passes over ``train_loader``.
        reinitialize: When true (the default), the incoming weights are
            discarded and training starts from a fresh random
            initialisation. Pass ``False`` to continue from the weights the
            model already has.

    Returns:
        The same ``model`` object that was passed in, left in training mode.
    """
    optimizer_model = optim.SGD(model.parameters(), lr=lr_model)

    if reinitialize:
        model.apply(random_weights_initialization)

    # The caller may have switched the model to eval mode for an earlier
    # evaluation pass; make sure training-time behaviour is active again.
    model.train()

    for _ in range(num_opt_steps_model):
        # `inputs` rather than `data`, which would shadow the module-level
        # `torch.utils.data` alias inside this function.
        for inputs, target in train_loader:
            optimizer_model.zero_grad()
            loss = F.cross_entropy(model(inputs), target)
            loss.backward()
            optimizer_model.step()

    return model

# Condense dataset function
def condense_dataset(dataset, num_images_per_class):
    """Keep the first ``num_images_per_class`` samples of every class.

    Args:
        dataset: Dataset exposing a ``targets`` attribute with one label per
            sample. The labels may be a tensor or any sequence accepted by
            ``torch.as_tensor`` (for example a plain list of ints).
        num_images_per_class: Maximum number of samples kept per class.
            Classes with fewer samples contribute all of them.

    Returns:
        A ``Subset`` of ``dataset`` whose indices are grouped by class, with
        classes in ascending label order and dataset order within a class.
    """
    # Normalise once: not every dataset stores `targets` as a tensor, and a
    # plain list has neither `.unique()` nor element-wise `==`.
    targets = torch.as_tensor(dataset.targets)

    condensed_indices = []
    for class_label in targets.unique():
        class_indices = (targets == class_label).nonzero(as_tuple=True)[0]
        condensed_indices.extend(class_indices[:num_images_per_class].tolist())

    return Subset(dataset, condensed_indices)

# Other utility functions
def random_weights_initialization(module):
    """Re-initialise an ``nn.Linear`` layer in place; ignore anything else.

    Intended to be passed to ``nn.Module.apply``. Weights are drawn from
    ``init.normal_`` with its default arguments and the bias, when the layer
    has one, is set to zero.

    Args:
        module: Any module. Only ``nn.Linear`` instances are modified.
    """
    if not isinstance(module, nn.Linear):
        return

    init.normal_(module.weight)
    # Layers built with ``bias=False`` carry ``bias = None``.
    if module.bias is not None:
        init.constant_(module.bias, 0)

def sample_mini_batch(dataset, minibatch_size):
    """Draw one shuffled mini-batch from ``dataset``.

    A throwaway shuffling ``DataLoader`` is built on every call and only its
    first batch is consumed.

    Args:
        dataset: Dataset whose samples are ``(input, target)`` pairs.
        minibatch_size: Batch size handed to the loader; the batch is smaller
            when the dataset holds fewer samples.

    Returns:
        Tuple ``(inputs, targets)`` holding the first batch of the loader.
    """
    loader = data.DataLoader(dataset, shuffle=True, batch_size=minibatch_size)
    batch_inputs, batch_targets = next(iter(loader))
    return batch_inputs, batch_targets

def visualize_condensed_images_per_class(dataset, num_images_per_class):
    """Plot a grid with one row per class and one column per image.

    The images are looked up in the *underlying* dataset
    (``dataset.dataset``): for every class the first
    ``num_images_per_class`` samples of that class are shown. The index list
    of ``dataset`` itself is not consulted.

    Args:
        dataset: Wrapper object (e.g. a ``Subset``) whose ``dataset``
            attribute exposes tensor ``targets`` and returns
            ``(image_tensor, label)`` pairs when indexed.
        num_images_per_class: Number of grid columns, i.e. the maximum number
            of images drawn per class.
    """
    source = dataset.dataset
    targets = source.targets
    classes = targets.unique()

    # squeeze=False keeps `axes` two-dimensional even when there is a single
    # class or a single image per class; otherwise `axes[i, j]` would fail.
    fig, axes = plt.subplots(
        len(classes), num_images_per_class, figsize=(12, 12), squeeze=False
    )

    for i, class_label in enumerate(classes):
        class_indices = (targets == class_label).nonzero(as_tuple=True)[0]
        selected_indices = class_indices[:num_images_per_class].tolist()

        for j, idx in enumerate(selected_indices):
            img, label = source[idx]
            # Drop singleton dimensions so a 1xHxW image becomes HxW.
            img = img.numpy().squeeze()
            axes[i, j].imshow(img, cmap='gray')
            axes[i, j].axis('off')
            axes[i, j].set_title(f'Class {label}')

    plt.tight_layout()
    plt.show()

def _count_correct(model, loader):
    """Return ``(correct, total)`` prediction counts of ``model`` on ``loader``.

    Gradients are disabled for the duration of the pass. The caller decides
    whether the model is in eval mode.
    """
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            outputs = model(inputs)
            # Predicted class = index of the largest score in each row.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct, total


# Create the condensed ("synthetic") dataset: a fixed number of MNIST training
# images per class, selected by condense_dataset.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
mnist_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
condensed_dataset = condense_dataset(mnist_dataset, num_images_per_class=10)

# Initialize the model
model = MLP(channel=1, num_classes=10)

# Run the gradient matching routine on the condensed dataset
model = gradient_matching_algorithm(model, condensed_dataset, lr_condensed=0.1, num_iterations=10, num_opt_steps=1)

# Define data loader for the synthetic dataset
synthetic_loader = DataLoader(condensed_dataset, batch_size=256, shuffle=True)

# Train the model on the synthetic dataset
trained_model = train_model(model, synthetic_loader, lr_model=0.01, num_opt_steps_model=50)

# Visualize condensed images per class
visualize_condensed_images_per_class(condensed_dataset, num_images_per_class=10)

# Evaluate the trained model on the original MNIST test dataset
mnist_test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
test_loader_original = DataLoader(mnist_test_dataset, batch_size=64, shuffle=False)

trained_model.eval()

correct_original, total_original = _count_correct(trained_model, test_loader_original)

accuracy_original = correct_original / total_original
print(f'Original Test Accuracy: {100 * accuracy_original:.2f}%')

# Continual learning with a "new task" dataset.
# NOTE: no separate new-task data exists in this script. `new_task_dataset` is
# another per-class selection from the same MNIST training set, so it overlaps
# with `condensed_dataset`.
# NOTE: `num_classes_new_task` is not referenced anywhere in this section.
num_classes_new_task = 2
num_images_per_class_new_task = 5

new_task_dataset = condense_dataset(mnist_dataset, num_images_per_class=num_images_per_class_new_task)

# Concatenate the synthetic dataset with the new task dataset
continual_dataset = ConcatDataset([synthetic_loader.dataset, new_task_dataset])

# Update the synthetic loader with the combined dataset
synthetic_loader = DataLoader(continual_dataset, batch_size=256, shuffle=True)

# Train the model on the combined dataset. train_model returns the object it
# was given, so `trained_model_continual` and `trained_model` are one model.
trained_model_continual = train_model(trained_model, synthetic_loader, lr_model=0.01, num_opt_steps_model=50)

# Evaluate the trained model on the MNIST test dataset
test_loader_continual = DataLoader(mnist_test_dataset, batch_size=64, shuffle=False)

trained_model_continual.eval()

correct_continual, total_continual = _count_correct(trained_model_continual, test_loader_continual)

accuracy_continual = correct_continual / total_continual
print(f'Continual Test Accuracy: {100 * accuracy_continual:.2f}%')

