# MNIST Dataset Analysis Demo

This notebook demonstrates how to load the MNIST dataset, visualize sample images from both the training and test sets, and analyze the class distribution with counts for each digit.

In [None]:
from loading_dataset import load_mnist

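# Load the MNIST dataset; load_mnist is a helper from the local loading_dataset
# module and returns training images/labels followed by test images/labels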
X_train, y_train, X_test, y_test = load_mnist()

print(f"Training samples shape: {X_train.shape}")
print(f"Training labels shape: {y_train.shape}")
print(f"Test samples shape: {X_test.shape}")
print(f"Test labels shape: {y_test.shape}")
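
The shapes printed above determine whether the images can be passed directly to `plt.imshow`, which expects 2-D arrays for grayscale data. Because `load_mnist` is a local helper, its output layout is not known here. The sketch below assumes it may return either flattened rows of 784 pixels or 28x28 images; `as_image_stack` is a hypothetical helper (not part of the original notebook) that normalizes both layouts to `(N, 28, 28)`.

```python
import numpy as np

def as_image_stack(X, side=28):
    """Return X as an array of shape (N, side, side).

    Hypothetical helper: assumes X holds N square grayscale images, either
    flattened to (N, side*side), already shaped (N, side, side), or shaped
    (N, side, side, 1) with a trailing channel axis.
    """
    X = np.asarray(X)
    if X.ndim == 2 and X.shape[1] == side * side:
        # Flattened images: restore the 2-D pixel grid
        return X.reshape(X.shape[0], side, side)
    if X.ndim == 3 and X.shape[1:] == (side, side):
        # Already in image form
        return X
    if X.ndim == 4 and X.shape[1:] == (side, side, 1):
        # Drop the single-channel axis
        return X[..., 0]
    raise ValueError(f"Unexpected image array shape: {X.shape}")

# Stand-in data (zeros, not real MNIST) to show the behaviour
flat = np.zeros((5, 784), dtype=np.uint8)
print(as_image_stack(flat).shape)  # (5, 28, 28)
```

If the shapes printed by the loading cell are already `(N, 28, 28)`, this step is unnecessary and the arrays can be plotted as they are.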

In [None]:
import matplotlib.pyplot as plt

def visualize_samples(X, y, dataset_name="Dataset", num_samples=16):
    """
    Visualizes a grid of sample images from the dataset.
    
    Parameters:
      - X: Numpy array of images.
      - y: Corresponding labels.
      - dataset_name: Name of the dataset (used in the title).
      - num_samples: Number of images to display (default 16).
    """
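    # Never request more images than the dataset holds
    num_samples = min(num_samples, len(X))
    # Keep 4 columns and add rows as needed (ceiling division), so values
    # other than 16 do not overflow a fixed 4x4 grid
    cols = 4
    rows = max(1, -(-num_samples // cols))
    plt.figure(figsize=(8, 2 * rows))
    for i in range(num_samples):
        plt.subplot(rows, cols, i + 1)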
        plt.imshow(X[i], cmap='gray')
        plt.title(f"Label: {y[i]}")
        plt.axis('off')
    plt.suptitle(f"{dataset_name} Sample Images", fontsize=16)
    plt.tight_layout()
    plt.show()

# Visualize samples from the training dataset
visualize_samples(X_train, y_train, dataset_name="Training")

# Visualize samples from the test dataset
visualize_samples(X_test, y_test, dataset_name="Test")

In [None]:
import numpy as np
import seaborn as sns

def plot_class_distribution(y, dataset_name="Dataset"):
    """
    Plots the distribution of class labels with count annotations for each class.
    
    Parameters:
      - y: Numpy array of labels.
      - dataset_name: Name of the dataset (used in the title).
    """
    unique, counts = np.unique(y, return_counts=True)
    plt.figure(figsize=(8, 6))
    ax = sns.barplot(x=unique, y=counts, palette="viridis")
    plt.title(f"{dataset_name} Class Distribution")
    plt.xlabel("Digit Label")
    plt.ylabel("Frequency")
    
    # Annotate each bar with its count
    for p in ax.patches:
        height = p.get_height()
        ax.annotate(f'{int(height)}',
                    (p.get_x() + p.get_width() / 2., height),
                    ha='center', va='bottom',
                    xytext=(0, 5), textcoords='offset points')
    plt.show()

# Plot class distribution for training and test datasets
plot_class_distribution(y_train, dataset_name="Training")
plot_class_distribution(y_test, dataset_name="Test")
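
The bar charts show the counts visually; the same information can also be tabulated so that any class imbalance is easy to read off as percentages. This is a minimal sketch, assuming the labels are integer digits in a NumPy array. `class_distribution_table` is a hypothetical helper, and the small label array is made-up data, not MNIST.

```python
import numpy as np

def class_distribution_table(y):
    """Return a list of (label, count, percentage) tuples sorted by label."""
    labels, counts = np.unique(np.asarray(y), return_counts=True)
    total = counts.sum()
    return [
        (int(label), int(count), float(100.0 * count / total))
        for label, count in zip(labels, counts)
    ]

# Made-up labels for illustration only
toy_labels = np.array([0, 1, 1, 2, 2, 2, 2, 1])
for label, count, pct in class_distribution_table(toy_labels):
    print(f"digit {label}: {count:>3d} samples ({pct:5.1f}%)")
```

Calling `class_distribution_table(y_train)` and `class_distribution_table(y_test)` would give the per-digit share in each split, which can then be compared side by side.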

## Conclusion

In this notebook, we loaded the MNIST dataset, visualized sample images from the training and test sets, and plotted the class distribution with a count for each digit. These checks help confirm that images and labels line up and show how evenly the digits are represented, which is useful to know before training a classification model.