# In [None]:
import os
import zipfile
import pandas as pd
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
import copy
import torch.nn.functional as F

# Base path of the MHIST data. With a '.zip' suffix appended it names the
# archive that the image folder is unpacked from.
mhist_zip_path = '../ProjectA/mhist_dataset'

# Read the annotation table (comma-separated) and split its rows on the
# value of the 'Partition' column.
annotations_path = "../ProjectA/mhist_dataset/annotations.csv"
annotations_df = pd.read_csv(annotations_path)
train_annotations = annotations_df.loc[annotations_df['Partition'] == 'train']
test_annotations = annotations_df.loc[annotations_df['Partition'] == 'test']

# Folder holding the image files; on first run it is populated by extracting
# the whole archive into it.
images_dir = "../ProjectA/mhist_dataset/images"
if not os.path.exists(images_dir):
    with zipfile.ZipFile(f"{mhist_zip_path}.zip") as zip_ref:
        zip_ref.extractall(images_dir)

# Define a custom dataset class for MHIST
class MHISTDataset(Dataset):
    def __init__(self, annotations, images_dir, transform=None, synthetic_images=None):
        self.annotations = annotations
        self.images_dir = images_dir
        self.transform = transform
        self.synthetic_images = synthetic_images

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        if self.synthetic_images is not None:
            synthetic_image = self.synthetic_images[idx]
            label = int(self.annotations.iloc[idx, 2])  # Assuming the label column is at index 2
            return synthetic_image, label
        else:
            img_name = os.path.join(self.images_dir, self.annotations.iloc[idx, 0])
            image = Image.open(img_name).convert('RGB')  # Assuming images are RGB
            label = int(self.annotations.iloc[idx, 2])  # Assuming the label column is at index 2

            if self.transform:
                image = self.transform(image)

            return image, label

# Every image is resized to 224x224 and converted to a float tensor.
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])

# Mini-batch size shared by all loaders in this notebook.
batch_size = 128

# Real-image datasets for the two MHIST partitions; only training is shuffled.
train_dataset = MHISTDataset(train_annotations, images_dir, transform=transform)
test_dataset = MHISTDataset(test_annotations, images_dir, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Specify the selected model architecture
class SelectedModel(nn.Module):
    def __init__(self):
        super(SelectedModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(128 * 56 * 56, 256)
        self.fc2 = nn.Linear(256, 10)  # Assuming 10 classes for classification

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(-1, 128 * 56 * 56)  # Adjust the size based on your model architecture
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# (a) Train SelectedModel on the original (real) MHIST training images.
model_a = SelectedModel()
criterion = nn.CrossEntropyLoss()  # also reused by part (e) further down
optimizer_a = optim.SGD(model_a.parameters(), lr=0.01)
# A single cosine-annealing cycle spanning the 20 epochs; stepped once per epoch.
scheduler_a = CosineAnnealingLR(optimizer_a, T_max=20)

for epoch in range(20):
    model_a.train()
    for images, labels in train_loader:
        optimizer_a.zero_grad()
        outputs = model_a(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer_a.step()
    scheduler_a.step()

# Top-1 accuracy of model_a on the original test split.
model_a.eval()
correct, total = 0, 0
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model_a(images)
        predicted = outputs.argmax(dim=1)
        total += len(labels)
        correct += int((predicted == labels).sum())

accuracy_original = correct / total
print(f"Accuracy on original test set: {accuracy_original}")

# (b) Learn the synthetic dataset S using the Gradient Matching algorithm
def gradient_matching(model, synthetic_loader, eta_S, zeta_S, eta_theta, zeta_theta, K, T):
    synthetic_optimizer = optim.SGD(model.parameters(), lr=eta_S)

    synthetic_images_list = []
    for k in range(K):
        synthetic_images = torch.randn_like(next(iter(synthetic_loader))[0])  # Random initialization

        for t in range(T):
            model.train()
            synthetic_optimizer.zero_grad()
            synthetic_outputs = model(synthetic_images)
            synthetic_loss = synthetic_outputs.sum()  # Replace with your specific loss
            synthetic_loss.backward()
            synthetic_optimizer.step()

        model_copy = copy.deepcopy(model)

        inner_optimizer = optim.SGD(model.parameters(), lr=eta_theta)

        for inner_step in range(zeta_theta):
            real_images, real_labels = next(iter(train_loader))
            real_outputs = model(real_images)
            synthetic_outputs = model_copy(synthetic_images)
            gm_loss = F.mse_loss(real_outputs, synthetic_outputs)
            inner_optimizer.zero_grad()
            gm_loss.backward()
            inner_optimizer.step()

        synthetic_images_list.append(synthetic_images)

    return model, synthetic_images_list

# (b) Learn the synthetic dataset S from model_a, drawing real batches from train_loader.
# NOTE(review): gradient_matching trains and returns the model it is given, so
# model_b is the same object as model_a.
model_b, synthetic_images_b = gradient_matching(
    model_a,
    train_loader,
    eta_S=0.1,
    zeta_S=1,
    eta_theta=0.01,
    zeta_theta=50,
    K=200,
    T=10,
)

# Wrap the learned images in a Dataset/DataLoader; labels come positionally
# from train_annotations.
synthetic_dataset_b = MHISTDataset(
    train_annotations,
    images_dir,
    transform=transform,
    synthetic_images=synthetic_images_b,
)
synthetic_loader_b = DataLoader(dataset=synthetic_dataset_b, batch_size=batch_size, shuffle=True)

# Visualize condensed images
def visualize_condensed_images(model, synthetic_loader):
    model.eval()
    with torch.no_grad():
        for images, _ in synthetic_loader:
            synthetic_outputs = model(images)
            # Visualize or save the condensed images as needed
            # ...

# Show the condensed images learned in part (b).
visualize_condensed_images(model_b, synthetic_loader_b)

# (d) Repeat (b) and (c) with the condensed images initialised from Gaussian noise.
# NOTE(review): gradient_matching already initialises with torch.randn_like in
# part (b), and model_a was trained in place there, so this run starts from
# the (b)-updated weights -- confirm the intended difference between (b) and (d).
model_c, synthetic_images_c = gradient_matching(
    model_a, train_loader,
    eta_S=0.1, zeta_S=1, eta_theta=0.01, zeta_theta=50, K=200, T=10,
)

# Dataset/DataLoader over the part-(d) synthetic images.
synthetic_dataset_c = MHISTDataset(
    train_annotations,
    images_dir,
    transform=transform,
    synthetic_images=synthetic_images_c,
)
synthetic_loader_c = DataLoader(dataset=synthetic_dataset_c, batch_size=batch_size, shuffle=True)

visualize_condensed_images(model_c, synthetic_loader_c)

# (e) Train a fresh SelectedModel on the synthetic dataset learned in part (b).
model_e = SelectedModel()
optimizer_e = optim.SGD(model_e.parameters(), lr=0.01)
# One cosine-annealing cycle over the 20 epochs; stepped once per epoch.
scheduler_e = CosineAnnealingLR(optimizer_e, T_max=20)

for epoch in range(20):
    model_e.train()
    for synthetic_images, labels in synthetic_loader_b:
        optimizer_e.zero_grad()
        outputs = model_e(synthetic_images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer_e.step()
    scheduler_e.step()

# Evaluate the synthetic-trained model on the ORIGINAL (real) test split,
# i.e. the same test_loader used for model_a.
model_e.eval()
correct, total = 0, 0
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model_e(images)
        predicted = outputs.argmax(dim=1)
        total += len(labels)
        correct += int((predicted == labels).sum())

accuracy_synthetic = correct / total
print(f"Accuracy on synthetic test set: {accuracy_synthetic}")

# Side-by-side accuracy of the real-trained and synthetic-trained models
# (training time is not measured here).
print(f"Accuracy on original test set: {accuracy_original}")
print(f"Accuracy on synthetic test set: {accuracy_synthetic}")

