### Installs

In [None]:
# pip install seaborn
# pip install numpy==1.23.5
# pip install scikit-learn
# python version: 3.9.6
# pip install gymnasium

### Imports

In [None]:
import time
import warnings
import numpy as np
import matplotlib.pyplot as plt

from sklearn.exceptions import ConvergenceWarning
from sklearn.preprocessing import StandardScaler
from sklearn.neural_network import MLPClassifier

from sklearn.metrics import accuracy_score
from sklearn.datasets import load_digits
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

#### Helper Functions

In [None]:
"""
Shows the distribution of data in array data. Data must be an array of integers.
"""
def distributionOfData(data):
    # Define bins
    bin_edges = np.arange(0, 11, 1)  # Bin edges at integer intervals (-4 to 4)
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2  # Compute bin centers

    # Create the histogram
    plt.hist(data, bins=bin_edges, color='blue', edgecolor='black', alpha=0.7, label="Data Distribution")

    # Add labels and title
    plt.xlabel('Value Ranges (Standard Deviations)', fontsize=12)
    plt.ylabel('Frequency (Count of Data Points)', fontsize=12)
    plt.title('Histogram of Randomly Generated Data', fontsize=14)

    # Set x-axis ticks at bin centers
    plt.xticks(bin_centers, labels=[f"{x}" for x in bin_edges[:-1]])

    # Add grid for better readability
    plt.grid(axis='y', linestyle='--', alpha=0.6)

    # Add legend
    plt.legend()

    # Show the plot
    plt.show()

"""
Shows the images of digits used in the data set.
"""
def showVisualDigitsData(X_data, y_data):
    # Print data
    print("Example of Data:")

    plt.imshow(X_data[0].reshape(8, 8), cmap="gray")
    plt.title(f"Example Image (Label: {y_data[0]})")
    plt.show()

"""
Shows numerical representation of each digit. Each digit is represented as a len-64 array with each value ranging between 0 an 16 where 0
correlates to black and 16 correlates to white.
"""
def showNumericalDigitsData(X_data, y_data):
    for i in range(3):
        print(f"Y-value: {X_data[i]}")
        print(f"X-Value: {y_data.data[i]}\n")

"""
Filters the dataset to include digits < max_val
"""
def filteredDigits(max_val=5):
        # Load the digits dataset
    digits = load_digits()

    # Extract the labels (target) from the dataset
    y = digits.target

    # Use np.where to find the indices of labels that are less than 5
    indices_less_than_max = np.where(y < max_val)
    
    return digits.data[indices_less_than_max], digits.target[indices_less_than_max]

"""
Gets all the digits.
"""
def allDigits():
    digits = load_digits()
    return digits.data, digits.target

"""
Creates the supervised model. no-print means no statistics for each epoch will be printed
"""
def supervisedModel(X_train, X_test, y_test, y_train, epochs=10, batch_size=64, no_print=False):
    warnings.filterwarnings("ignore", category=ConvergenceWarning)
    # Scale data
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)

    # Define MLP model (warm_start=True allows incremental training)
    model = MLPClassifier(hidden_layer_sizes=(128, 64), max_iter=1, warm_start=True, verbose=False, batch_size=batch_size)

    # Lists to store accuracy values
    train_accuracies = []
    test_accuracies = []
    epoch_times = []  # List to store epoch times

    # Training loop
    for epoch in range(epochs):
        start_time = time.time()  # Start time of the epoch

        model.fit(X_train, y_train)
        
        train_acc = accuracy_score(y_train, model.predict(X_train))
        test_acc = accuracy_score(y_test, model.predict(X_test))
        
        train_accuracies.append(train_acc)
        test_accuracies.append(test_acc)
        if not (no_print):
            print(f"Epoch {epoch+1}/{epochs} - Training Accuracy: {train_acc:.4f}, Testing Accuracy: {test_acc:.4f}")
        end_time = time.time()  # End time of the epoch
        epoch_time = end_time - start_time  # Time taken for the epoch
        epoch_times.append(epoch_time)  # Store epoch time
    return model, train_accuracies, test_accuracies, epoch_times

"""
Graphs test vs training loss.
"""
def graphAccuracies(train_accuracies, test_accuracies):
    epochs = range(1, len(train_accuracies) + 1)  # Create an epoch range
    
    plt.figure(figsize=(8, 5))
    plt.plot(epochs, train_accuracies, label="Training Accuracy", marker='o')
    plt.plot(epochs, test_accuracies, label="Testing Accuracy", marker='s')
    
    plt.xlabel("Epochs")
    plt.ylabel("Accuracy")
    plt.title("Training vs Testing Accuracy Over Epochs")
    plt.legend()
    plt.grid(True)
    plt.show()

"""
Shows images of the correctly predicted digits.
"""
def showCorrectPredictions(model, X_test, y_test):
    model_preds = model.predict(X_test)
    correct_indices = np.where(model_preds == y_test)[0][:10]  # Select 10 correctly classified digits

    fig, axes = plt.subplots(2, 5, figsize=(10, 5))
    for i, ax in enumerate(axes.flat):
        ax.imshow(X_test[correct_indices[i]].reshape(8, 8), cmap='gray')
        ax.set_title(f"Pred: {model_preds[correct_indices[i]]}, True: {y_test[correct_indices[i]]}")
        ax.axis('off')
    plt.show()

"""
Shows images of the incorrectly predicted digits.
"""
def showIncorrectPredictions(model, X_test, y_test):
    model_preds = model.predict(X_test)
    incorrect_indices = np.where(model_preds != y_test)[0][:10]  # Select 10 correctly classified digits

    fig, axes = plt.subplots(2, 5, figsize=(10, 5))
    for i, ax in enumerate(axes.flat):
        ax.imshow(X_test[incorrect_indices[i]].reshape(8, 8), cmap='gray')
        ax.set_title(f"Pred: {model_preds[incorrect_indices[i]]}, True: {y_test[incorrect_indices[i]]}")
        ax.axis('off')
    plt.show()

"""
Creates feature importance graph.
"""
def showFeatureImportance(model, X_test, y_test):
    
    # Assuming 'model' is your trained MLPClassifier model
    result = permutation_importance(model, X_test, y_test, n_repeats=10, random_state=42)

    # Get the feature importances
    importances = result.importances_mean

    norm_importances = (importances - np.min(importances)) / (np.max(importances) - np.min(importances))

    importance_matrix = norm_importances.reshape(8, 8)
    fig, ax = plt.subplots(figsize=(8, 8))
    cax = ax.matshow(importance_matrix, cmap='Blues')  # Using the Blues colormap

    # Add colorbar to the plot
    fig.colorbar(cax)

    # Set axis labels (just for clarity, with indices 1 to 8 for both X and Y)
    ax.set_xticks(np.arange(8))
    ax.set_yticks(np.arange(8))
    ax.set_xticklabels(np.arange(1, 9))
    ax.set_yticklabels(np.arange(1, 9))

    # Set title
    ax.set_title("Feature Importance Heatmap (Permutation Importance)")

    # Show the plot
    plt.show()

"""
Compares the average time per epoch, the max train accuracy, and the max test accuracy across many batch sizes.
"""
def compareBatchSizes(X_train, X_test, y_test, y_train, batch_sizes):
    epoch_times_list = []
    total_times_list = []
    train_accuracies_list = []
    test_accuracies_list = []
    idx = []
    
    for batch_size in batch_sizes:
        _, train_accuracies, test_accuracies, epoch_times = supervisedModel(X_train, X_test, y_test, y_train, epochs=10, batch_size=batch_size, no_print=True)
        epoch_times_list.append(sum(epoch_times)/len(epoch_times))
        total_times_list.append(sum(epoch_times))
        test_accuracies_list.append(max(test_accuracies))
        train_accuracies_list.append(max(train_accuracies))
        idx.append(batch_size)

    createPlot(idx, epoch_times_list, "# Batches", "Time (sec)", "Average Training Time per Batch #")
    createPlot(idx, train_accuracies_list, "# Batches", "Max Train Accuracy", "Max Train Accuracy per Batch #")
    createPlot(idx, test_accuracies_list, "# Batches", "Max Test Accuracy", "Max Test Accuracy per Batch #")
    createPlot(idx, total_times_list, "# Batches", "Total Training Time (sec)", "Total Training Time per Batch #")

"""
Helper function to create plots.
"""
def createPlot(x, y, xlabel, ylabel, title):
    plt.plot(x, y)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.show()
        

### Preparing Data

In [None]:
# The model trains only on digits below MAX_TRAIN_DIGIT; 20% of those are held out for testing.
MAX_TRAIN_DIGIT = 5
X_data, y_data = filteredDigits(max_val=MAX_TRAIN_DIGIT)
X_train, X_heldout, y_train, y_heldout = train_test_split(X_data, y_data, test_size=0.2, random_state=42)

# The test set is the held-out low digits plus every digit the model never trains on.
# This keeps all training samples out of the test set.
X_all, y_all = allDigits()
unseen = y_all >= MAX_TRAIN_DIGIT
X_test = np.concatenate([X_heldout, X_all[unseen]])
y_test = np.concatenate([y_heldout, y_all[unseen]])

print(f"Length of training data: {len(X_train)}")
print(f"Length of testing data: {len(X_test)}")


#### Your turn: Show one of the digits using above functions.

In [None]:
### Your code

#### Your turn: Show the numerical representations of the data. How is the image converted to a numerical representation?

In [None]:
### Your code

### Training Supervised Model

In [None]:
# Keyword arguments guard against supervisedModel's unusual (X_test, y_test, y_train) parameter order.
model, train_accuracies, test_accuracies, epoch_times = supervisedModel(
    X_train=X_train,
    X_test=X_test,
    y_test=y_test,
    y_train=y_train,
    epochs=10,
    batch_size=500,
)


#### Peek into the results

In [None]:
# Plot the per-epoch accuracies recorded during training.
graphAccuracies(
    train_accuracies=train_accuracies,
    test_accuracies=test_accuracies,
)

#### Why is the model performing so badly? What can we fix?

In [None]:
### Your code

#### Deep dive into the results

In [None]:
# TODO(student): fill in the arguments - this needs the trained model plus the
# test features and labels. Note: supervisedModel standardizes features with a
# StandardScaler fitted inside that function, so raw X_test is on a different
# scale than the data the model was trained on.
showCorrectPredictions()

In [None]:
# TODO(student): fill in the arguments - this needs the trained model plus the
# test features and labels. Note: supervisedModel standardizes features with a
# StandardScaler fitted inside that function, so raw X_test is on a different
# scale than the data the model was trained on.
showIncorrectPredictions()

In [None]:
# TODO(student): fill in the arguments - this needs the trained model plus the
# test features and labels used to measure permutation importance. Note:
# supervisedModel standardizes features with a StandardScaler fitted inside
# that function, so raw X_test is on a different scale than the training data.
showFeatureImportance()

#### Your turn: Try different batch sizes. Use the above functions to see how batch size affects model performance (latency and accuracy)

In [None]:
### Your code