# In [None]:
import os
import zipfile
import pandas as pd
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
import copy
import torch.nn.functional as F
from argparse import Namespace

# Root directory of the extracted MHIST dataset (a directory, not the .zip
# archive). The variable name is kept so other cells that use it still work.
mhist_zip_path = '../ProjectA/mhist_dataset'

# Load annotations from the CSV file inside the dataset root
annotations_path = os.path.join(mhist_zip_path, 'annotations.csv')
annotations_df = pd.read_csv(annotations_path, delimiter=',')

# Split the rows by the value of the 'Partition' column
train_annotations = annotations_df[annotations_df['Partition'] == 'train']
test_annotations = annotations_df[annotations_df['Partition'] == 'test']

# Directory containing the image files referenced by the annotations
images_dir = os.path.join(mhist_zip_path, 'images')
class MHISTDataset(Dataset):
    """Map-style dataset over image files listed in an annotations DataFrame.

    Args:
        annotations: DataFrame with one row per sample. Column 0 holds the
            image file name, relative to ``images_dir``. The label is read
            positionally from column ``label_column`` and cast to ``int``.
        images_dir: Directory that contains the image files.
        transform: Optional callable applied to the loaded RGB PIL image.
        label_column: Positional index of the label column. Defaults to 2,
            which preserves the original behaviour.

    Each item is an ``(image, label)`` pair.
    """

    def __init__(self, annotations, images_dir, transform=None, label_column=2):
        self.annotations = annotations
        self.images_dir = images_dir
        self.transform = transform
        # NOTE(review): confirm that this column really holds the class label.
        # In the public MHIST annotations.csv the majority-vote label (HP/SSA,
        # a string) is believed to be column 1, and column 2 a count of
        # annotators. If so, int() here yields counts rather than classes.
        self.label_column = label_column

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        img_path = os.path.join(self.images_dir, self.annotations.iloc[idx, 0])
        # The context manager releases the file handle deterministically.
        # convert() returns an independent copy, so it stays usable after close.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        label = int(self.annotations.iloc[idx, self.label_column])

        if self.transform is not None:
            image = self.transform(image)

        return image, label

# Preprocessing shared by every split: resize each image to 224x224 (the input
# size the model's fully connected layer is sized for), then convert to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

# Real MHIST splits, train first, then test
train_dataset, test_dataset = (
    MHISTDataset(split_annotations, images_dir, transform=transform)
    for split_annotations in (train_annotations, test_annotations)
)

batch_size = 128  # tune to fit available memory

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# NOTE(review): placeholder. This "synthetic" set is built from the real
# training annotations, so it contains exactly the same samples as
# train_dataset. Swap in the learned/distilled images before interpreting
# the part (e) results.
synthetic_dataset = MHISTDataset(train_annotations, images_dir, transform=transform)
synthetic_loader = DataLoader(synthetic_dataset, batch_size=batch_size, shuffle=True)

class SelectedModel(nn.Module):
    """Two-stage ConvNet classifier.

    Architecture:
        conv(3->64) -> ReLU -> 2x2 max-pool
        -> conv(64->128) -> ReLU -> 2x2 max-pool
        -> flatten -> linear(->256) -> ReLU -> linear(->num_classes)

    Args:
        num_classes: Number of output logits. Defaults to 10, as before.
            NOTE(review): set this to the number of distinct labels the
            dataset actually produces.
        input_size: Side length of the square RGB input images. Defaults to
            224, as before. The padded 3x3 convolutions keep the spatial size
            and each pool halves it, so fc1 receives
            128 * (input_size // 4) ** 2 features.
    """

    def __init__(self, num_classes=10, input_size=224):
        super(SelectedModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        feature_side = input_size // 4  # two stride-2 pools
        self.fc1 = nn.Linear(128 * feature_side * feature_side, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Map an (N, 3, input_size, input_size) batch to (N, num_classes) logits."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # flatten(x, 1) keeps the batch dimension. The previous
        # view(-1, 128*56*56) would silently fold a wrongly sized input into a
        # different batch size instead of raising a shape error in fc1.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

# (e) Train the selected network on the learned synthetic dataset, then
# evaluate it on the real MHIST test split.
# NOTE(review): as constructed earlier in this file, synthetic_loader reads the
# real training annotations. Replace it with the learned synthetic set.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_epochs = 100

model_e = SelectedModel().to(device)
optimizer_e = optim.SGD(model_e.parameters(), lr=0.01)
# Anneal over the whole run. With T_max=20 and 100 epochs the cosine schedule
# is periodic (period 2*T_max), so the LR climbed back up to 0.01 repeatedly
# instead of decaying once.
scheduler_e = CosineAnnealingLR(optimizer_e, T_max=num_epochs)
criterion = nn.CrossEntropyLoss()

# Training loop for part (e)
for epoch in range(num_epochs):
    model_e.train()
    training_info = Namespace(loss=0)
    for images, labels in synthetic_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer_e.zero_grad()
        outputs = model_e(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer_e.step()
        training_info.loss += loss.item()

    # Mean per-batch loss; max(..., 1) avoids ZeroDivisionError on an empty loader
    training_info.loss /= max(len(synthetic_loader), 1)
    scheduler_e.step()  # one scheduler step per epoch
    print(f"Epoch [{epoch+1:3d}/{num_epochs}] - Training Loss: {training_info.loss:.4f} ")

# Evaluate the synthetic-trained model on the real test split (test_loader)
model_e.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        predicted = model_e(images).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy_synthetic = correct / total if total else float('nan')

# `accuracy_original` is not assigned anywhere in this cell. Report it only when
# another part of the notebook has defined it, so the cell no longer ends in a
# NameError.
if 'accuracy_original' in globals():
    print(f"Accuracy on original test set: {accuracy_original}")
else:
    print("Accuracy on original test set: not computed in this session")
print(f"Accuracy on synthetic test set: {accuracy_synthetic}")