# In [None]:
import numpy as np
import matplotlib.pyplot as plt
from minisom import MiniSom
from skimage.metrics import mean_squared_error, structural_similarity
import tensorflow as tf

# Step 1: Load CIFAR-10 and rescale pixel intensities
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Cast each split to float32, then divide by 255 so 8-bit intensities land in [0, 1]
x_train, x_test = (split.astype('float32') / 255.0 for split in (x_train, x_test))

# Step 2: Cut every training image into non-overlapping patch_size x patch_size
# tiles and flatten each tile into one row of x_train_patches.
#
# This is done with a single reshape/transpose instead of a Python loop. The
# loop allocated one small array per patch for the whole training set and then
# copied all of them again in np.array(), roughly doubling peak memory.
# Row order is unchanged: image-major, then tile row, then tile column, and each
# tile flattened in (y, x, channel) order, exactly like ndarray.flatten().
patch_size = 8
n_images, img_h, img_w, n_channels = x_train.shape
tile_rows, tile_cols = img_h // patch_size, img_w // patch_size

# Only whole tiles are kept. A partial tile at the bottom/right edge (when the
# image size is not a multiple of patch_size) has the wrong length for a SOM
# with a fixed input dimension, and would make the patch array ragged.
x_train_patches = (
    x_train[:, :tile_rows * patch_size, :tile_cols * patch_size]
    .reshape(n_images, tile_rows, patch_size, tile_cols, patch_size, n_channels)
    .transpose(0, 1, 3, 2, 4, 5)  # -> (image, tile_row, tile_col, y, x, channel)
    .reshape(-1, patch_size * patch_size * n_channels)
)

# Step 3: Build a SOM whose input length equals one flattened RGB patch
som_grid_size = (40, 40)  # 40 x 40 = 1600 codebook vectors; a larger map captures more patch variety
initial_learning_rate = 0.5
initial_sigma = 2.0  # neighborhood spread at the start of training
num_iterations = 5000  # number of single-sample training updates

grid_rows, grid_cols = som_grid_size
patch_input_len = patch_size * patch_size * 3  # patch_size x patch_size x 3 color channels, flattened
som = MiniSom(
    grid_rows,
    grid_cols,
    patch_input_len,
    sigma=initial_sigma,
    learning_rate=initial_learning_rate,
)
# Initialize the codebook from the training patches
som.random_weights_init(x_train_patches)

# Step 4: Train the SOM on patches
def train_som(som, data, num_iterations):
    """Run ``num_iterations`` online SOM updates on randomly drawn rows of ``data``.

    Each iteration draws one sample uniformly at random (with replacement),
    finds its best-matching unit (BMU) and applies one ``som.update`` step.
    MiniSom decays the learning rate and neighborhood radius itself. It starts
    from the ``learning_rate``/``sigma`` the SOM was constructed with and
    follows the SOM's configured decay function.

    Args:
        som: MiniSom instance. Any object exposing ``winner(x)`` and
            ``update(x, win, t, max_iteration)`` also works.
        data: Array (or sequence) of training samples, one per row.
        num_iterations: Number of single-sample updates to perform.

    Raises:
        ValueError: If ``num_iterations > 0`` but ``data`` is empty.
    """
    n_samples = len(data)
    if num_iterations > 0 and n_samples == 0:
        raise ValueError("train_som: cannot draw samples from empty data")

    for t in range(num_iterations):
        sample = data[np.random.randint(0, n_samples)]
        # MiniSom.update's 3rd/4th arguments are the iteration index and the
        # total number of iterations; it derives the decayed learning rate and
        # sigma from them. Do not pass pre-decayed lr/sigma here. Their ratio
        # stays constant, so with asymptotic decay the effective learning rate
        # and radius would never shrink during training.
        som.update(sample, som.winner(sample), t, num_iterations)

# Fit the SOM to the flattened training patches
train_som(som=som, data=x_train_patches, num_iterations=num_iterations)

# Step 5: Compress image using SOM on patches instead of the entire image
def compress_image_by_patches(image, som, patch_size=8):
    """Reconstruct ``image`` from the SOM codebook, one patch at a time.

    The image is tiled into ``patch_size`` x ``patch_size`` patches. Each patch
    is flattened and matched to its best-matching unit (BMU) on the SOM. In the
    output it is replaced by that unit's weight vector, reshaped back to a patch.

    When the image height or width is not a multiple of ``patch_size``, edge
    tiles come out smaller than ``patch_size``. Such a tile is edge-padded to
    full size for the BMU search, and its reconstruction is cropped back to
    the tile's actual size.

    Args:
        image: Array of shape (height, width, channels).
        som: Trained MiniSom (or any object with ``winner`` and ``get_weights``)
            whose input length is ``patch_size * patch_size * channels``.
        patch_size: Side length of the square patches.

    Returns:
        Array with the same shape and dtype as ``image`` holding the
        patch-wise SOM reconstruction.
    """
    h, w, c = image.shape
    weights = som.get_weights()  # codebook does not change inside the loop
    compressed_image = np.zeros_like(image)

    for i in range(0, h, patch_size):
        for j in range(0, w, patch_size):
            tile = image[i:i + patch_size, j:j + patch_size]
            tile_h, tile_w = tile.shape[:2]
            if tile_h < patch_size or tile_w < patch_size:
                # Pad partial edge tiles so they match the SOM's input length.
                tile = np.pad(
                    tile,
                    ((0, patch_size - tile_h), (0, patch_size - tile_w), (0, 0)),
                    mode='edge',
                )

            # Find the BMU for the flattened patch and decode its weight vector.
            bmu = som.winner(tile.flatten())
            codeword = weights[bmu].reshape(patch_size, patch_size, c)

            # Crop back to the real tile size (a no-op for interior tiles).
            compressed_image[i:i + tile_h, j:j + tile_w] = codeword[:tile_h, :tile_w]

    return compressed_image

# Step 6: Compress a sample image from the test set using the patch-based method.
# Pass the global patch_size so the patches match the SOM input length
# (patch_size * patch_size * 3) configured in Step 3. A hard-coded 8 would
# break as soon as patch_size is changed.
image_to_compress = x_test[0]  # first image of the test set
compressed_image = compress_image_by_patches(image_to_compress, som, patch_size=patch_size)

# Step 7: Show the original and reconstructed images side by side
fig, (ax_original, ax_compressed) = plt.subplots(1, 2, figsize=(10, 5))

ax_original.set_title('Original Image')
ax_original.imshow(image_to_compress)

ax_compressed.set_title('Compressed Image (with patches)')
ax_compressed.imshow(compressed_image)

plt.show()

# Step 8: Quantify reconstruction quality
# MSE: mean squared per-value error between original and reconstruction.
mse = mean_squared_error(image_to_compress, compressed_image)
print(f'Mean Squared Error (MSE): {mse}')

# SSIM over a 5x5 window with channels on the last axis; data_range=1 because
# pixel values were scaled into [0, 1] in Step 1.
ssim = structural_similarity(
    image_to_compress,
    compressed_image,
    win_size=5,
    channel_axis=-1,
    data_range=1,
)
print(f'Structural Similarity Index (SSIM): {ssim}')

# Step 9: Size comparison
# image_to_compress and compressed_image are float32 arrays of identical shape,
# so their nbytes are always equal and say nothing about compression. Per
# image, the patch-based scheme only needs to store one BMU index per patch.
# The SOM codebook is shared by every image and is reported separately.
img_h, img_w = image_to_compress.shape[:2]
n_patches = int(np.ceil(img_h / patch_size)) * int(np.ceil(img_w / patch_size))
n_codewords = som_grid_size[0] * som_grid_size[1]
bits_per_index = max(1, int(np.ceil(np.log2(n_codewords))))

# The /255 scaling in Step 1 implies 8-bit source pixels: one byte per channel value.
original_size = image_to_compress.size
# Bit-packed BMU indices, rounded up to whole bytes.
compressed_size = int(np.ceil(n_patches * bits_per_index / 8))
# One-time cost, shared across all images compressed with this SOM.
codebook_size = som.get_weights().nbytes

print(f"Original image size: {original_size} bytes")
print(f"Compressed image size: {compressed_size} bytes")
print(f"Compression ratio: {original_size / compressed_size:.1f}x "
      f"(excluding the shared SOM codebook of {codebook_size} bytes)")

# Step 10: Visualize the two quality metrics as single-bar charts
fig, (ax_mse, ax_ssim) = plt.subplots(1, 2, figsize=(6, 4))

ax_mse.bar(['MSE'], [mse])
ax_mse.set_title('Mean Squared Error')

ax_ssim.bar(['SSIM'], [ssim])
ax_ssim.set_title('Structural Similarity Index')

fig.tight_layout()
plt.show()