# In [None]:
import os
import numpy as np
import pandas as pd
from PIL import Image
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras import layers, models, optimizers
from tqdm import tqdm
import matplotlib.pyplot as plt

# Root location of the MHIST dataset. The annotation file and the image
# folder configured below both sit under this same prefix.
mhist_zip_path = '../ProjectA/mhist_dataset'

# Folder holding the image files that the annotation table refers to
images_dir = "../ProjectA/mhist_dataset/images"

# Read the annotation table (comma-separated CSV)
annotations_path = "../ProjectA/mhist_dataset/annotations.csv"
annotations_df = pd.read_csv(annotations_path, sep=',')

# Separate the rows by the value of their 'Partition' column
train_annotations = annotations_df.loc[annotations_df['Partition'].eq('train')]
test_annotations = annotations_df.loc[annotations_df['Partition'].eq('test')]



# Hyperparameters from Table 1, grouped by the part of the script that reads them.

# --- Data / architecture ---
batch_size = 128   # Real images drawn per model-update step
# Width of the softmax output layer and number of rows in the final image grid.
# NOTE(review): MHIST is, as far as I know, a two-class dataset — confirm that
# 50 is really a class count and not an "images per class" setting.
num_classes = 50

# --- Outer loop ---
# Listed in Table 1 as the number of random weight initializations.
# NOTE(review): the script uses K as the number of condensed images it creates.
K = 200
T = 10             # Number of outer iterations

# --- Condensed-sample optimisation ---
ηS = 0.1           # Step size applied to the condensed samples
ζS = 1             # Condensed-sample update steps per outer iteration

# --- Model optimisation ---
ηθ = 0.01          # Learning rate for the model weights
ζθ = 50            # Model update steps per outer iteration

# Build the full path of every training image listed in the annotations
image_paths = [os.path.join(images_dir, img_name) for img_name in train_annotations['Image Name']]

# Take the class of each image from the 'Majority Vote Label' column.
# 'Partition' must not be used here: train_annotations was selected on
# Partition == 'train', so that column holds the same value in every row and
# encoding it would assign class 0 to every image.
labels = train_annotations['Majority Vote Label'].values

# Convert string labels to integers (one integer id per distinct class name)
label_encoder = LabelEncoder()
labels = label_encoder.fit_transform(labels)

# Hold out 20% of the training partition for validation; the fixed seed keeps
# the split reproducible. Paths and labels are split together so they stay aligned.
train_image_paths, val_image_paths, train_labels, val_labels = train_test_split(
    image_paths, labels, test_size=0.2, random_state=42
)

# Simple CNN classifier: three 3x3 convolutions (32 -> 64 -> 128 filters) with
# 2x2 max-pooling after the first two, then a 256-unit dense layer and a
# softmax over num_classes outputs. Inputs are 64x64 images with 3 channels.
model = models.Sequential([
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(64, 64, 3)),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(128, (3, 3), activation='relu'),
    layers.Flatten(),
    layers.Dense(256, activation='relu'),
    layers.Dense(num_classes, activation='softmax')
])

# Compile the model. Labels are integer class ids, hence the sparse loss.
# The learning rate is passed as `learning_rate`: the old `lr` alias is
# deprecated and is ignored or rejected by newer Keras optimizers.
model.compile(
    optimizer=optimizers.SGD(learning_rate=ηθ),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Repeat the process with Gaussian noise initialization
num_repeats = 1  # You can adjust the number of repeats as needed

# One optimizer instance serves every model update; rebuilding it on each step
# only adds overhead.
optimizer = tf.keras.optimizers.SGD(learning_rate=ηθ)

# Number of condensed images shown per row of the final grid
images_per_class = K // num_classes

for repeat in range(num_repeats):
    print(f"Repeat {repeat + 1} with Gaussian Noise Initialization")

    # Initialize K condensed 64x64x3 images with standard Gaussian noise
    condensed_images = np.random.normal(loc=0, scale=1, size=(K, 64, 64, 3))

    # Convert to a float32 TensorFlow tensor so gradients can be taken w.r.t. it
    condensed_images_tensor = tf.constant(condensed_images, dtype=tf.float32)

    # Alternate between updating the condensed samples and updating the model
    for iteration in tqdm(range(T)):
        # Update condensed samples with plain gradient descent (step size ηS)
        for _ in range(ζS):
            with tf.GradientTape() as tape:
                # The tensor is not a tf.Variable, so it must be watched explicitly
                tape.watch(condensed_images_tensor)
                # NOTE(review): the model ends in a softmax, so every output row
                # sums to 1 and this objective is constant (the number of
                # images). Its gradient w.r.t. the images is therefore ~0 and
                # the samples barely move. An actual gradient-matching loss
                # still has to be plugged in here.
                loss_S = tf.reduce_sum(model(condensed_images_tensor))
            grads_S = tape.gradient(loss_S, condensed_images_tensor)
            condensed_images_tensor = condensed_images_tensor - ηS * grads_S

        # Update model on random mini-batches of real training images
        for _ in range(ζθ):
            indices = np.random.choice(len(train_image_paths), batch_size, replace=False)
            # NOTE(review): load_and_preprocess_image is not defined in this part
            # of the file; presumably it returns a 64x64x3 array — confirm.
            batch_images = np.array([load_and_preprocess_image(train_image_paths[i]) for i in indices])
            # The indices address the training split, so the labels must come
            # from train_labels (the pre-split `labels` array is ordered differently).
            batch_labels = train_labels[indices]

            with tf.GradientTape() as tape:
                predictions = model(batch_images)
                # Average over the batch: differentiating the unreduced
                # per-sample vector would sum it and scale the step by batch_size.
                loss = tf.reduce_mean(
                    tf.keras.losses.sparse_categorical_crossentropy(batch_labels, predictions)
                )

            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    # Convert back to NumPy array for plotting
    condensed_images = condensed_images_tensor.numpy()

    # Visualize condensed images: one row per class slot, images_per_class per row.
    # squeeze=False keeps `axes` two-dimensional even for a single row or column.
    fig, axes = plt.subplots(num_classes, images_per_class, figsize=(15, 15), squeeze=False)

    for i in range(num_classes):
        class_images = condensed_images[i * images_per_class: (i + 1) * images_per_class]
        for j, img in enumerate(class_images):
            axes[i, j].imshow(img)
            axes[i, j].axis('off')

    plt.show()

