In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import torch.nn.init as init
import torch.utils.data as data
import matplotlib.pyplot as plt
import numpy as np

# Define MLP architecture
class MLP(nn.Module):
    """Fully connected classifier operating on flattened images.

    The input width is selected from ``channel``: 28 * 28 * 1 features when
    ``channel`` is 1, otherwise 32 * 32 * 3 features.  Two 128-unit hidden
    layers with ReLU activations feed a final linear layer that emits
    ``num_classes`` scores.
    """

    def __init__(self, channel, num_classes):
        super().__init__()
        if channel == 1:
            in_features = 28 * 28 * 1
        else:
            in_features = 32 * 32 * 3
        hidden_units = 128
        self.fc_1 = nn.Linear(in_features, hidden_units)
        self.fc_2 = nn.Linear(hidden_units, hidden_units)
        self.fc_3 = nn.Linear(hidden_units, num_classes)

    def forward(self, x):
        """Flatten each sample in the batch and return the raw class scores."""
        batch_size = x.size(0)
        hidden = x.view(batch_size, -1)
        hidden = F.relu(self.fc_1(hidden))
        hidden = F.relu(self.fc_2(hidden))
        return self.fc_3(hidden)

# Define LeNet architecture
class LeNet(nn.Module):
    """LeNet-style convolutional classifier.

    ``features`` is two rounds of 5x5 convolution, ReLU and 2x2 max pooling
    (6 then 16 output channels).  The first convolution pads by 2 pixels
    when ``channel`` is 1 and by 0 otherwise, so a 28x28 single-channel
    input and a 32x32 multi-channel input both reach the 16 * 5 * 5 feature
    map expected by ``fc_1``.  Three linear layers (120, 84,
    ``num_classes``) form the classifier head.
    """

    def __init__(self, channel, num_classes):
        super().__init__()
        first_conv_padding = 2 if channel == 1 else 0
        conv_stack = [
            nn.Conv2d(channel, 6, kernel_size=5, padding=first_conv_padding),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(6, 16, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.features = nn.Sequential(*conv_stack)
        self.fc_1 = nn.Linear(16 * 5 * 5, 120)
        self.fc_2 = nn.Linear(120, 84)
        self.fc_3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Run the convolutional stack, flatten, and return raw class scores."""
        feature_maps = self.features(x)
        flat = feature_maps.view(feature_maps.size(0), -1)
        hidden = F.relu(self.fc_1(flat))
        hidden = F.relu(self.fc_2(hidden))
        return self.fc_3(hidden)

# Gradient Matching Algorithm
def gradient_matching_algorithm(model, dataset, initialization_method='gaussian', num_iterations=10, num_opt_steps=1, lr=None, minibatch_size=256):
    """Repeatedly re-initialise ``model`` and train it on mini-batches of ``dataset``.

    NOTE: despite its name, this routine does not compare gradients and does
    not create or modify any images.  It only updates ``model`` in place with
    plain SGD on the cross-entropy loss.

    Args:
        model: Module to train; modified in place and also returned.
        dataset: Dataset handed to ``sample_mini_batch`` for every step.
        initialization_method: ``'gaussian'`` applies
            ``gaussian_weights_initialization``; any other value applies
            ``random_weights_initialization``.
        num_iterations: Number of outer iterations.  Each one starts by
            re-applying the initialiser, so updates to whatever the
            initialiser resets are discarded between iterations.
        num_opt_steps: SGD steps performed per outer iteration.
        lr: SGD learning rate.  ``None`` (the default) keeps the historical
            behaviour of reading the module-level ``lr_condensed``.
        minibatch_size: Samples requested per step (previously fixed at 256).

    Returns:
        The same ``model`` object.
    """
    if lr is None:
        # Backward compatibility: older callers configure the rate through
        # the module-level setting instead of an argument.
        lr = lr_condensed
    optimizer_condensed = optim.SGD(model.parameters(), lr=lr)

    for _ in range(num_iterations):
        if initialization_method == 'gaussian':
            model.apply(gaussian_weights_initialization)
        else:
            model.apply(random_weights_initialization)

        for _ in range(num_opt_steps):
            inputs, targets = sample_mini_batch(dataset, minibatch_size=minibatch_size)

            loss = F.cross_entropy(model(inputs), targets)

            optimizer_condensed.zero_grad()
            loss.backward()
            optimizer_condensed.step()

    return model

# Gaussian weights initialization
def gaussian_weights_initialization(module):
    """Re-initialise an ``nn.Linear`` layer in place; ignore any other module.

    Meant to be passed to ``model.apply(...)``.  Weights are redrawn with
    ``init.normal_`` using its default arguments and the bias, when the layer
    has one, is set to zero.

    Args:
        module: Any ``nn.Module``; only ``nn.Linear`` instances are modified.
    """
    if not isinstance(module, nn.Linear):
        return
    init.normal_(module.weight)
    # ``bias`` is None for layers built with ``bias=False``; there is nothing
    # to reset in that case.
    if module.bias is not None:
        init.constant_(module.bias, 0)

# Train the model
def train_model(model, train_loader, lr_model=0.01, num_opt_steps_model=50):
    """Train ``model`` in place with SGD on the cross-entropy loss.

    The weights are reset with ``random_weights_initialization`` once, before
    the first epoch.  Resetting inside the epoch loop would throw away
    everything learned in earlier epochs, leaving a model that reflects only
    the final pass over ``train_loader``.

    Args:
        model: Module to train; modified in place and also returned.
        train_loader: Iterable yielding ``(inputs, targets)`` batches.
        lr_model: SGD learning rate.
        num_opt_steps_model: Number of full passes over ``train_loader``.
            With zero passes the model is returned untouched.

    Returns:
        The same ``model`` object.
    """
    optimizer_model = optim.SGD(model.parameters(), lr=lr_model)

    if num_opt_steps_model > 0:
        # Start from a fresh initialisation exactly once so that every
        # subsequent epoch builds on the previous one.
        model.apply(random_weights_initialization)

    for _ in range(num_opt_steps_model):
        for inputs, targets in train_loader:
            optimizer_model.zero_grad()

            loss = F.cross_entropy(model(inputs), targets)

            loss.backward()
            optimizer_model.step()

    return model

# Dataset Condensation
def condense_dataset(dataset, num_images_per_class):
    """Select the first ``num_images_per_class`` samples of every class.

    No images are synthesised: the result is a ``Subset`` view of
    ``dataset`` whose indices are grouped by class in ascending label order,
    and within each class follow the original dataset order.

    Args:
        dataset: Object exposing a ``targets`` attribute with one integer
            label per sample, either as a tensor or as anything
            ``torch.as_tensor`` accepts (e.g. a plain list).
        num_images_per_class: Maximum number of samples kept per class;
            classes with fewer samples contribute all of them.

    Returns:
        ``torch.utils.data.Subset`` wrapping ``dataset``.
    """
    # Datasets do not agree on the container used for ``targets``; coercing
    # here lets list-backed labels work as well as tensor-backed ones.
    targets = torch.as_tensor(dataset.targets)

    condensed_indices = []
    for class_label in targets.unique():
        class_indices = (targets == class_label).nonzero(as_tuple=True)[0]
        condensed_indices.extend(class_indices[:num_images_per_class].tolist())

    return Subset(dataset, condensed_indices)

# Random weights initialization
def random_weights_initialization(module):
    """Re-initialise an ``nn.Linear`` layer in place; ignore any other module.

    Meant to be passed to ``model.apply(...)``.  Weights are redrawn with
    ``init.normal_`` using its default arguments and the bias, when the layer
    has one, is set to zero.  This is currently the same procedure as
    ``gaussian_weights_initialization``.

    Args:
        module: Any ``nn.Module``; only ``nn.Linear`` instances are modified.
    """
    if not isinstance(module, nn.Linear):
        return
    init.normal_(module.weight)
    # ``bias`` is None for layers built with ``bias=False``; there is nothing
    # to reset in that case.
    if module.bias is not None:
        init.constant_(module.bias, 0)

# Sample a mini-batch
def sample_mini_batch(dataset, minibatch_size):
    """Return one randomly drawn ``(inputs, targets)`` batch from ``dataset``.

    A throw-away shuffling ``DataLoader`` is built on every call and only
    its first batch is consumed, so the batch holds at most
    ``minibatch_size`` samples.
    """
    shuffled_loader = DataLoader(dataset, shuffle=True, batch_size=minibatch_size)
    first_batch = next(iter(shuffled_loader))
    inputs, targets = first_batch
    return inputs, targets

# Visualize condensed images per class
def visualize_condensed_images_per_class(dataset, num_images_per_class):
    """Plot a grid with one row per class and up to ``num_images_per_class`` images per row.

    The images are looked up in the wrapped dataset (``dataset.dataset``):
    for each class the first ``num_images_per_class`` samples of that class
    are shown, rendered with a grey colour map and titled with their label.

    Args:
        dataset: Wrapper (such as ``torch.utils.data.Subset``) exposing the
            underlying dataset as ``.dataset``; that dataset must provide
            ``targets`` and return ``(tensor_image, label)`` when indexed.
        num_images_per_class: Number of grid columns.
    """
    source = dataset.dataset
    # Accept list-backed as well as tensor-backed labels.
    targets = torch.as_tensor(source.targets)
    classes = targets.unique()

    # squeeze=False keeps ``axes`` two-dimensional even when there is a
    # single class or a single image per class, so ``axes[i, j]`` is always
    # valid.
    _, axes = plt.subplots(len(classes), num_images_per_class, figsize=(12, 12), squeeze=False)

    # Hide every frame up front so cells of classes with fewer samples than
    # columns stay blank instead of showing empty axes.
    for ax in axes.flat:
        ax.axis('off')

    for i, class_label in enumerate(classes):
        class_indices = (targets == class_label).nonzero(as_tuple=True)[0]
        # Plain ints are accepted by every dataset's __getitem__.
        for j, idx in enumerate(class_indices[:num_images_per_class].tolist()):
            img, label = source[idx]
            axes[i, j].imshow(img.numpy().squeeze(), cmap='gray')
            axes[i, j].set_title(f'Class {label}')

    plt.tight_layout()
    plt.show()

# Set parameters
# Maximum number of real samples kept per class by condense_dataset.
num_images_per_class = 10
# Batch size for the DataLoaders built below. It is not forwarded to
# gradient_matching_algorithm, which uses its own value of 256.
minibatch_size = 256
# SGD learning rate read as a module-level global inside
# gradient_matching_algorithm.
lr_condensed = 0.1
# Outer iterations and SGD steps per iteration for gradient_matching_algorithm.
num_iterations = 10
num_opt_steps_condensed = 1
# SGD learning rate and number of passes over the loader for train_model.
lr_model = 0.01
num_opt_steps_model = 50

# Load the MNIST training split (download=True asks torchvision to fetch the
# files into ./data). Images are converted to tensors and then normalised
# with mean 0.5 and std 0.5.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
mnist_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)

# Condense the dataset: keep a Subset holding the first
# `num_images_per_class` real samples of each class. No synthetic images are
# generated here.
condensed_dataset = condense_dataset(mnist_dataset, num_images_per_class=num_images_per_class)

# Initialize the MLP model (channel=1 selects the 28 * 28 input width)
mlp_model = MLP(channel=1, num_classes=10)

# Run gradient_matching_algorithm on the condensed subset with the 'random'
# initialiser. NOTE: this call updates the model in place and returns it; it
# does not create or modify any dataset.
mlp_model = gradient_matching_algorithm(mlp_model, condensed_dataset, initialization_method='random', num_iterations=num_iterations, num_opt_steps=num_opt_steps_condensed)

# Data loader over the condensed subset (the same `condensed_dataset` object
# used above), reshuffled on every pass
synthetic_loader = DataLoader(condensed_dataset, batch_size=minibatch_size, shuffle=True)

# Train the MLP model on the condensed subset
# NOTE(review): `trained_mlp_model` is not evaluated anywhere in the code
# shown here; only the LeNet model is tested below.
trained_mlp_model = train_model(mlp_model, synthetic_loader, lr_model=lr_model, num_opt_steps_model=num_opt_steps_model)

# Define LeNet model (channel=1 pads the first convolution by 2 pixels)
lenet_model = LeNet(channel=1, num_classes=10)

# Run gradient_matching_algorithm on the condensed subset with the 'random'
# initialiser. NOTE: this call updates the model in place and returns it; it
# does not create or modify any dataset. The initialiser only resets
# nn.Linear layers, so the convolutional layers are never re-initialised.
lenet_model = gradient_matching_algorithm(lenet_model, condensed_dataset, initialization_method='random', num_iterations=num_iterations, num_opt_steps=num_opt_steps_condensed)

# Data loader over the condensed subset (the same `condensed_dataset` object
# used for the MLP), reshuffled on every pass
synthetic_loader_lenet = DataLoader(condensed_dataset, batch_size=minibatch_size, shuffle=True)

# Train the LeNet model on the condensed subset
trained_lenet_model = train_model(lenet_model, synthetic_loader_lenet, lr_model=lr_model, num_opt_steps_model=num_opt_steps_model)

# Build the evaluation data for the trained LeNet model: the MNIST test split
# (train=False) with the same ToTensor + Normalize pipeline as the training
# data, served in order (shuffle=False).
transform_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
mnist_test_dataset = datasets.MNIST(root='./data', train=False, transform=transform_test, download=True)
test_loader_lenet = DataLoader(mnist_test_dataset, batch_size=minibatch_size, shuffle=False)

def test_model(model, test_loader):
    """Measure classification accuracy of ``model`` over ``test_loader``.

    The model is switched to evaluation mode (and left in it) and gradients
    are disabled while predicting.  The accuracy is printed and returned.
    The printed message always says "LeNet" regardless of which model is
    passed in.

    Args:
        model: Module mapping a batch of inputs to per-class scores.
        test_loader: Iterable yielding ``(inputs, targets)`` batches.

    Returns:
        Fraction of correctly classified samples in ``[0, 1]``; ``0.0`` when
        the loader yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            outputs = model(inputs)
            # Index of the highest score per sample is the predicted class.
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
    # Accuracy is undefined without samples; report 0.0 instead of raising
    # ZeroDivisionError.
    accuracy = correct / total if total else 0.0
    print(f'Test Accuracy of the LeNet model on the test set: {accuracy}')
    return accuracy


# Despite its ``test_`` prefix this is an evaluation helper, not a unit test;
# the flag keeps pytest from trying to collect it.
test_model.__test__ = False

# Test the LeNet model on the test set; test_model prints the accuracy as a
# fraction (correct / total), not a percentage
test_model(trained_lenet_model, test_loader_lenet)

