# MNIST EDA

<a href="https://colab.research.google.com/github/BU-Spark/ml-549-course/blob/main/phase3_EDA/mnist-eda.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Here we do some basic exploratory data analysis on the MNIST dataset. MNIST is very uniform: every image has the same dimensions, so there is no need to summarize image sizes. Other image datasets may contain images of varying sizes; in that case it is worth reporting statistics on the image dimensions as well.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist

In [None]:
# Load the MNIST dataset as (images, labels) pairs for the train and test splits
(x_train, y_train), (x_test, y_test) = mnist.load_data()

In [None]:
# Report the array shapes of each split: images first, then labels.
for caption, (split_images, split_labels) in (
    ('Training data shape : ', (x_train, y_train)),
    ('Testing data shape : ', (x_test, y_test)),
):
    print(caption, split_images.shape, split_labels.shape)

In [None]:

# Visualize some images from the dataset
def visualize_images(images, labels, classes, num_images=10):
    """Plot a random sample of images in a single row, titled by class name.

    Args:
        images: Array of images supporting integer-array indexing; each
            ``images[i]`` is drawn with ``imshow`` in grayscale.
        labels: Integer class labels aligned with ``images``.
        classes: Sequence mapping a label value to its display name.
        num_images: How many images to show; capped at ``len(images)``.

    Raises:
        ValueError: If ``num_images`` is less than 1.
    """
    if num_images < 1:
        raise ValueError(f"num_images must be >= 1, got {num_images}")
    num_images = min(num_images, len(images))
    if num_images == 0:
        return  # empty dataset: nothing to draw

    # A fresh random sample on every call. Slicing the head of a full
    # permutation consumes the RNG exactly like shuffling everything, but
    # only the selected images/labels are copied, not the whole dataset.
    chosen = np.random.permutation(len(images))[:num_images]
    sample_images = images[chosen]
    sample_labels = labels[chosen]

    # squeeze=False keeps `axes` a 2-D array even when num_images == 1,
    # where plt.subplots would otherwise return a single Axes object.
    _, axes = plt.subplots(1, num_images, figsize=(15, 3), squeeze=False)
    for ax, image, label in zip(axes[0], sample_images, sample_labels):
        ax.imshow(image, cmap='gray')
        ax.set_title(classes[label])
        ax.axis('off')
    plt.show()

# MNIST class names: one string per digit, so classes[label] is the
# display name for an integer label.
classes = [str(digit) for digit in range(10)]

visualize_images(x_train, y_train, classes)


In [None]:

# Display the distribution of classes
def plot_class_distribution(labels, classes, title='Distribution of classes in MNIST'):
    """Draw a bar chart of how many examples fall into each class.

    Args:
        labels: Non-negative integer class labels (any shape; flattened).
        classes: Display names; ``classes[k]`` labels the bar for class ``k``.
        title: Plot title. Defaults to the MNIST title for this notebook.

    Raises:
        ValueError: If a label is negative or not smaller than ``len(classes)``.
    """
    flat_labels = np.asarray(labels).ravel()
    # bincount with minlength yields exactly one count per class, in class
    # order, including zeros for classes missing from `labels`. np.unique
    # would silently drop missing classes and misalign counts with names.
    counts = np.bincount(flat_labels, minlength=len(classes))
    if counts.size > len(classes):
        raise ValueError(
            f"found label {counts.size - 1}, but only {len(classes)} class names were given"
        )
    plt.bar(classes, counts)
    plt.xlabel('Classes')
    plt.ylabel('Number of examples')
    plt.title(title)
    plt.xticks(rotation=45)
    plt.show()

plot_class_distribution(labels=y_train.flatten(), classes=classes)  # 1-D training labels


In [None]:

# Basic statistics about the images
def image_statistics(images):
    """Print the mean, standard deviation, minimum and maximum over all pixels.

    Each statistic is computed over the whole array at once (no axis), so it
    summarises every pixel of every image together. Values are printed one
    per line; nothing is returned.
    """
    # Each statistic is evaluated just before it is printed, in this order.
    for caption, reducer in (
        ("Mean", np.mean),
        ("Standard Deviation", np.std),
        ("Min Pixel Value", np.min),
        ("Max Pixel Value", np.max),
    ):
        print(f"{caption}: {reducer(images)}")

image_statistics(images=x_train)  # pixel statistics over the full training set
