In [1]:
# Question 6: Custom Autoencoder Model for Complex Dataset
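# Description: Build a custom autoencoder for a more complex dataset, like CIFAR-10.
# Note: this cell uses PCA as a linear stand-in for the autoencoder -- pca.transform plays the
# encoder role and pca.inverse_transform the decoder role; no neural network is trained here.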
import numpy as np
from sklearn.decomposition import PCA
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt

# Load CIFAR-10 dataset
from sklearn import datasets
import numpy as np

# --- Data loading ---
# Pull the 'CIFAR_10_small' dataset from OpenML via scikit-learn's dataset helpers.
cifar10 = datasets.fetch_openml(name='CIFAR_10_small')

# --- Preprocessing ---
# Cast to float32 and divide by 255 to normalize pixel values to [0, 1].
X = cifar10.data.astype(np.float32) / 255.0

# Hold out 20% of the rows for evaluation; the fixed seed keeps the split repeatable.
X_train, X_test = train_test_split(X, random_state=42, test_size=0.2)

# Standardize every feature before PCA. The scaler statistics are learned from the
# training rows only and then reused unchanged for the held-out rows.
scaler = StandardScaler().fit(X_train)
X_train_scaled = scaler.transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Apply PCA for dimensionality reduction.
# PCA is fitted on the standardized training rows only; the same projection is
# then applied to the held-out rows.
n_components = 50  # Size of the compressed representation; a hyperparameter you can adjust
# Seed the decomposition: with the default svd_solver='auto', scikit-learn may pick
# the randomized SVD solver for large inputs, which would make the components (and
# the MSE values printed below) differ from run to run.
pca = PCA(n_components=n_components, random_state=42)
X_train_pca = pca.fit_transform(X_train_scaled)
X_test_pca = pca.transform(X_test_scaled)

# Reconstruct data from PCA (maps the compressed codes back to the standardized feature space)
X_train_reconstructed = pca.inverse_transform(X_train_pca)
X_test_reconstructed = pca.inverse_transform(X_test_pca)

# Calculate mean squared error (MSE) as a reconstruction error.
# Both arguments are in standardized space, so the error is in standardized
# units, not raw pixel intensities.
train_mse = mean_squared_error(X_train_scaled, X_train_reconstructed)
test_mse = mean_squared_error(X_test_scaled, X_test_reconstructed)

print(f"Train Reconstruction MSE: {train_mse}")
print(f"Test Reconstruction MSE: {test_mse}")

# Visualize the original and reconstructed data (for a few samples)
n_samples = 5
# squeeze=False keeps `axes` two-dimensional even if n_samples is set to 1.
fig, axes = plt.subplots(2, n_samples, figsize=(10, 4), squeeze=False)

# X_test may be a pandas DataFrame (fetch_openml's default for dense data) or an
# ndarray; slicing the first rows and converting with np.asarray handles both.
# (A pandas Series has no .reshape, so indexing rows with .iloc and reshaping fails.)
originals = np.asarray(X_test[:n_samples], dtype=np.float32)

# The PCA reconstruction lives in standardized space. Undo the StandardScaler step
# to get back to [0, 1] pixel intensities, then clip the small overshoots that the
# lossy reconstruction produces so imshow receives valid RGB values.
reconstructions = np.clip(
    scaler.inverse_transform(np.asarray(X_test_reconstructed[:n_samples])),
    0.0,
    1.0,
)


def _as_rgb_image(flat_pixels):
    """Turn one flat 3072-value row into a (32, 32, 3) array for imshow.

    NOTE(review): assumes rows use the original CIFAR-10 layout (1024 red values,
    then 1024 green, then 1024 blue), hence the channel-first reshape followed by
    a transpose -- confirm against the OpenML dataset description.
    """
    return np.asarray(flat_pixels).reshape(3, 32, 32).transpose(1, 2, 0)


for i in range(n_samples):
    # Original images (top row)
    axes[0, i].imshow(_as_rgb_image(originals[i]))
    axes[0, i].set_title("Original")
    axes[0, i].axis('off')

    # Reconstructed images (bottom row)
    axes[1, i].imshow(_as_rgb_image(reconstructions[i]))
    axes[1, i].set_title("Reconstructed")
    axes[1, i].axis('off')

fig.tight_layout()
plt.show()



: 