In [131]:
# Imports

import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D, Dropout
from tensorflow.keras.optimizers import Adam
from sklearn.decomposition import PCA
import numpy as np
import plotly_express as px
import plotly.graph_objects as go
import pandas as pd

In [115]:
# Main

# Load MNIST as (images, labels) pairs for the train and test splits
(X_train, y_train), (X_test, y_test) = mnist.load_data()
n_train, n_test = len(X_train), len(X_test)

# Rescale pixel values onto [-1, 1]; assumes raw pixels are in [0, 255] (0 -> -1, 255 -> 1)
X_train = X_train / 127.5 - 1
X_test = X_test / 127.5 - 1

# Flatten every 28x28 image into a single 784-dimensional feature vector
features_count = np.prod(X_train.shape[1:])
X_train_flatened, X_test_flatened = (
    images.reshape(count, features_count)
    for images, count in ((X_train, n_train), (X_test, n_test))
)


#### --- Task 1 --- ####

In [None]:
# PCA and Centroids
# The centroid of a class is the mean of the coordinates of all points in that class. This single
# point summarises where the class sits in the projected space.

# Reduce the dimensionality of the data to 2 dimensions.
# random_state pins the randomized SVD solver that some scikit-learn versions select for this input
# size, so re-running the notebook reproduces the same projection.
pca = PCA(n_components=2, random_state=0)
X_train_pca = pca.fit_transform(X_train_flatened)

# Create a DataFrame with the PCA data and digit labels
df_pca = pd.DataFrame(X_train_pca, columns=['PC1', 'PC2'])
df_pca['digit'] = y_train

# Compute centroids for each class by taking the mean of PC1 and PC2
centroids = df_pca.groupby('digit')[['PC1', 'PC2']].mean()

# Build one explicit digit -> colour mapping from the default Plotly Express palette. The point cloud
# and the centroid markers both use it, so every centroid matches the colour of its own class.
color_sequence = px.colors.qualitative.Plotly
unique_digits = sorted(df_pca['digit'].unique())
color_map = {digit: color_sequence[i % len(color_sequence)] for i, digit in enumerate(unique_digits)}

# Scatter plot of the PCA data, coloured by digit.
# The label is passed as a string so Plotly treats it as a discrete category: a numeric column would
# be drawn on a continuous colour scale that the centroid colours could never match. The explicit
# colour map and category order stop Plotly from assigning colours in order of first appearance.
fig = px.scatter(
    df_pca.assign(digit=df_pca['digit'].astype(str)),
    x='PC1',
    y='PC2',
    color='digit',
    color_discrete_map={str(digit): colour for digit, colour in color_map.items()},
    category_orders={'digit': [str(digit) for digit in unique_digits]},
    title='PCA plot of the MNIST Dataset',
    width=1000,
    height=600,
)
fig.update_layout(xaxis_title='Principal Component 1', yaxis_title='Principal Component 2')

# Add centroids as larger markers, each coloured according to its class. The dark outline keeps a
# centroid visible on top of the same-coloured points of its class.
for digit, row in centroids.iterrows():
    fig.add_trace(
        go.Scatter(
            x=[row['PC1']],
            y=[row['PC2']],
            mode='markers',
            marker=dict(
                color=color_map[digit],
                size=15,
                symbol='diamond',
                line=dict(color='black', width=1.5),
            ),
            name=f'Centroid {digit}'
        )
    )

fig.show()

In [None]:
# Scree Plot - Shows the percentage of variance explained by each principal component

# Only the first 50 of the 784 possible components are fitted, so the cumulative line shows the
# variance retained by the top 50 components, not the full 100%.
# random_state pins the randomized SVD solver that some scikit-learn versions select for this
# input size, so the explained-variance figures are reproducible across re-runs.
pca_full = PCA(n_components=50, random_state=0)
pca_full.fit(X_train_flatened)
variance_ratios = pca_full.explained_variance_ratio_
components = np.arange(1, len(variance_ratios) + 1)

# Per-component and running-total explained variance, both expressed in percent
df = pd.DataFrame({'Principal Component': components, 'Explained Variance': variance_ratios * 100})
df['Cumulative Variance'] = df['Explained Variance'].cumsum()

# Bars: variance of each component. Red line: cumulative variance, which shares the same percent y-axis.
fig = px.bar(df, x='Principal Component', y='Explained Variance', title='Scree Plot & Cumulative Variance', labels={'Explained Variance': 'Percentage of Variance Explained'}, width=1000, height=500)
fig.add_scatter(x=df['Principal Component'], y=df['Cumulative Variance'], mode='lines+markers', name='Cumulative Variance', line=dict(color='red'))
fig.show()

In [None]:
# Questions / Notes

# Why is PCA a good option to visualise data?
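# PCA is a good option to visualise data because it projects high-dimensional data onto the orthogonal directions of greatest variance. Keeping only the first 2 components gives the linear 2D projection that preserves the most variance, and this can be plotted and inspected directly.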

# Observations
# Clustering of Classes - Projecting from 784D down to 2D lets us see the differences between classes on a plot, which is not possible directly in the original 784D space.
# Separation of Certain Classes - Some digits form more isolated clusters. E.g. digit 1 forms a tight cluster, likely because its shape is relatively simple and distinct from the other digits.
# Overlap Among Other Classes - Digit classes, such as 3, 5, and 8, have clusters that overlap considerably. This suggests that their differences may not be well captured by a linear projection onto the first two principal components.

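# Q - Which classes can be linearly separated?
# A - In the 2D PCA projection, 0 and 1 appear (approximately) linearly separable from each other, whereas 3, 5 and 8 overlap heavily and cannot be separated by a straight line.
#     Note: heavy overlap in this 2D projection does not prove the classes are inseparable in the original 784D space. PCA discards the variance outside the first two components, which may be exactly what distinguishes them.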

# PCA Notes
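# When you have high-dimensional data, there are many directions in which the data can vary. The first principal component is the direction along which the data varies the most; each subsequent component is the direction of greatest remaining variance, orthogonal to the previous ones.
# By projecting your data onto the space defined by the top few principal components (often just two for visualisation), you reduce the dimensionality while retaining as much of the variance as any linear projection of that size can. For MNIST, however, the first two components explain well under half of the total variance (see the cumulative line in the scree plot), so the 2D view is a lossy summary.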