In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist

In [None]:
# ---------------------------
# 1. Load & preprocess MNIST
# ---------------------------
# Only the training images are used; the labels and the test split are discarded.
(X_train, _), _ = mnist.load_data()

# Keep the first 5000 images so the factorization below runs quickly.
X_train = X_train[:5000]

# One row per image: flatten 28x28 -> 784, cast to float and rescale to [0, 1].
X = X_train.reshape(len(X_train), -1).astype(float) / 255.0


In [None]:
# ---------------------------
# 2. Implement NMF (from scratch)
# ---------------------------
def nmf(X, n_components, max_iter=200, tol=1e-4, random_state=42):
    """Factorize ``X ~= W @ H`` using multiplicative updates.

    Both factors start as absolute values of standard-normal draws and are
    refined by alternating multiplicative updates of ``H`` and ``W``. The
    Frobenius norm of the residual is tracked after every iteration.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Data matrix. The multiplicative updates only keep ``W`` and ``H``
        non-negative when ``X`` itself has no negative entries.
    n_components : int
        Number of columns of ``W`` / rows of ``H``.
    max_iter : int, default=200
        Maximum number of update sweeps. With ``max_iter <= 0`` the random
        initial factors are returned together with their error.
    tol : float, default=1e-4
        Stop early once the absolute change in the Frobenius error between
        two consecutive iterations drops below this value.
    random_state : int or None, default=42
        Seed for the private random generator used for initialization. The
        default reproduces the values obtained from ``np.random.seed(42)``.

    Returns
    -------
    W : ndarray of shape (n_samples, n_components)
    H : ndarray of shape (n_components, n_features)
    error : float
        Frobenius norm of ``X - W @ H`` for the returned factors.
    """
    X = np.asarray(X, dtype=float)
    n_samples, n_features = X.shape

    # A private generator keeps the call reproducible without reseeding the
    # global NumPy RNG, which would silently affect unrelated code.
    rng = np.random.RandomState(random_state)
    W = np.abs(rng.randn(n_samples, n_components))
    H = np.abs(rng.randn(n_components, n_features))

    eps = 1e-10  # guards the element-wise divisions against zero denominators
    error = None
    previous_error = None

    for _ in range(max_iter):
        # Update H with W fixed, then W with the freshly updated H.
        H *= (W.T @ X) / (W.T @ W @ H + eps)
        W *= (X @ H.T) / (W @ (H @ H.T) + eps)

        # Reconstruction error after this sweep.
        error = np.linalg.norm(X - W @ H, 'fro')

        # Early stopping needs two errors, so it cannot trigger on the first sweep.
        if previous_error is not None and abs(previous_error - error) < tol:
            break
        previous_error = error

    if error is None:
        # No sweep ran (max_iter <= 0): report the error of the initial factors
        # instead of failing on an empty error history.
        error = np.linalg.norm(X - W @ H, 'fro')

    return W, H, error

In [None]:
# ---------------------------
# 3. Hyperparameter tuning
# ---------------------------
# Final reconstruction error for every candidate rank k = 2..20.
errors_dict = {}
for k in range(2, 21):
    print(f"Training NMF with n_components={k}")
    W, H, error = nmf(X, max_iter=100, n_components=k)
    errors_dict.update({k: error})

# Select the rank with the smallest error (the first one wins on ties).
# NOTE(review): training reconstruction error usually shrinks as k grows, so
# this rule will tend to pick the largest candidate - confirm that is intended.
best_k, best_error = min(errors_dict.items(), key=lambda item: item[1])
print("Best n_components:", best_k, "with error:", best_error)

In [None]:

# ---------------------------
# 4. Visualize reconstructions
# ---------------------------
# Refit at the selected rank and rebuild the approximation X ~= W @ H.
W, H, _ = nmf(X, n_components=best_k, max_iter=200)
X_reconstructed = W @ H

# Top row: the first n_show inputs; bottom row: their reconstructions.
n_show = 10
fig = plt.figure(figsize=(20, 4))
panels = ((X, "Original"), (X_reconstructed, "Reconstructed"))
for row, (images, row_title) in enumerate(panels):
    for i in range(n_show):
        ax = fig.add_subplot(2, n_show, row * n_show + i + 1)
        ax.imshow(images[i].reshape(28, 28), cmap="gray")
        ax.axis("off")
        # Only the leftmost panel of each row carries the row label.
        if i == 0:
            ax.set_title(row_title)

plt.show()

# ---------------------------
# 5. Plot reconstruction error vs components
# ---------------------------
# Final reconstruction error of each fit against its number of components.
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(list(errors_dict), list(errors_dict.values()), marker='o')
ax.set_xlabel("Number of Components (k)")
ax.set_ylabel("Reconstruction Error (Frobenius Norm)")
ax.set_title("NMF Hyperparameter Tuning")
ax.grid()
plt.show()