# In [None]:
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.datasets import cifar10
import seaborn as sns
import pandas as pd


# Fetch CIFAR-10 as (train, test) pairs of image and label arrays
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Human-readable name for each integer label, ordered by label index
class_names = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]

# --- Basic Dataset Information ---
# Each entry pairs the printed caption with the value shown after it.
dataset_summary = [
    ("Shape of x_train:", x_train.shape),       # expected (50000, 32, 32, 3)
    ("Shape of y_train:", y_train.shape),       # expected (50000, 1): labels are a 2D column
    ("Shape of x_test:", x_test.shape),         # expected (10000, 32, 32, 3)
    ("Shape of y_test:", y_test.shape),         # expected (10000, 1)
    ("Data type of x_train:", x_train.dtype),   # expected uint8 (values between 0 and 255)
    ("Data type of y_train:", y_train.dtype),   # expected uint8
]
for caption, value in dataset_summary:
    print(caption, value)

# --- Class Distribution ---
def plot_class_distribution(labels, class_names, dataset_name="Training Set"):
    """Draw a bar chart of how many samples carry each class label.

    Args:
        labels: Integer label array. It is flattened before counting, so
            both (N,) and (N, 1) layouts are accepted.
        class_names: Names used as bar labels; position i names label i.
        dataset_name: Text inserted into the plot title.
    """
    # minlength keeps a zero-height bar for classes that never occur.
    counts_per_class = np.bincount(labels.flatten(), minlength=len(class_names))

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.bar(class_names, counts_per_class)
    ax.set_title(f"Class Distribution in the {dataset_name}")
    ax.set_xlabel("Class")
    ax.set_ylabel("Number of Images")
    # Slant the class names so neighbouring labels do not overlap.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    fig.tight_layout()
    plt.show()


# Same chart for both splits, so their class balance can be compared.
for split_labels, split_title in ((y_train, "Training Set"), (y_test, "Test Set")):
    plot_class_distribution(split_labels, class_names, dataset_name=split_title)


# --- Pixel Value Analysis ---
def plot_pixel_value_distribution(images, channel='all'):
    """Plots the distribution of pixel values in a given image set.

    Non-negative integer images (e.g. uint8) are first tabulated into one
    count per intensity level, and seaborn draws the histogram from those
    counts. This avoids handing every raw pixel to seaborn, whose KDE fit
    is impractically slow on a full image set. Any other dtype is passed to
    seaborn unchanged.

    Args:
        images:  A NumPy array of shape (N, height, width, channels) containing the images.
        channel: 'red', 'green', 'blue', or 'all' (default: 'all') to specify the channel
                 to analyze. If 'all', all channels are combined.

    Raises:
        KeyError: If `channel` is not one of the supported names.
    """
    images = np.asarray(images)

    if channel == 'all':
        pixel_values = images.ravel()  # Combine all channels
        title = "Distribution of Pixel Values (All Channels)"
    else:
        channel_indices = {'red': 0, 'green': 1, 'blue': 2}
        if channel not in channel_indices:
            # Same exception type as a plain dict lookup, but with guidance.
            raise KeyError(
                f"Unknown channel {channel!r}; expected 'all', 'red', 'green' or 'blue'"
            )
        pixel_values = images[:, :, :, channel_indices[channel]].ravel()
        title = f"Distribution of Pixel Values (Channel: {channel})"

    # Only tabulate when the values are a modest set of non-negative integer
    # levels; otherwise keep the generic binned histogram.
    max_levels = 1 << 16
    can_tabulate = (
        pixel_values.size > 0
        and np.issubdtype(pixel_values.dtype, np.integer)
        and pixel_values.min() >= 0
        and pixel_values.max() < max_levels
    )

    plt.figure(figsize=(10, 6))
    if can_tabulate:
        # One unit-wide bin per integer level; always cover at least 0..255.
        top_level = max(int(pixel_values.max()), 255)
        level_counts, _ = np.histogram(
            pixel_values,
            bins=top_level + 1,
            range=(-0.5, top_level + 0.5),
        )
        levels = np.arange(top_level + 1)
        # NOTE(review): the counts are passed as weights for the bars and the
        # KDE overlay. The KDE bandwidth then comes from the weighted sample,
        # so the curve may look smoother than one fitted to every raw pixel.
        sns.histplot(x=levels, weights=level_counts, discrete=True, kde=True)
    else:
        sns.histplot(pixel_values, bins=256, kde=True)
    plt.xlabel("Pixel Value")
    plt.ylabel("Frequency")
    plt.title(title)
    plt.tight_layout()
    plt.show()


# One histogram for the combined channels, then one per colour channel.
for channel_name in ('all', 'red', 'green', 'blue'):
    plot_pixel_value_distribution(x_train, channel=channel_name)

# --- Analyze mean and standard deviation of pixel values ---
def analyze_pixel_stats(images):
    """Print the per-channel mean and standard deviation of an image batch.

    Args:
        images: Array indexed as (image, row, column, channel). The first
            three axes are reduced, leaving one statistic per channel.
    """
    # Collapse every axis except the trailing channel axis.
    reduce_axes = (0, 1, 2)
    channel_means, channel_stds = (
        reducer(images, axis=reduce_axes) for reducer in (np.mean, np.std)
    )

    print(f"Mean pixel values (R, G, B): {channel_means}")
    print(f"Standard deviation of pixel values (R, G, B): {channel_stds}")

# Print a header before each report; the statistics lines alone do not say
# which split they describe.
for split_title, split_images in (("Training set", x_train), ("Test set", x_test)):
    print(f"--- {split_title} ---")
    analyze_pixel_stats(split_images)



# --- Image Brightness Analysis ---
def plot_image_brightness_distribution(images):
    """Plot a histogram of each image's mean pixel value.

    Args:
        images: Array indexed as (image, row, column, channel). Axes 1-3
            are averaged, giving one brightness value per image.
    """
    per_image_mean = np.mean(images, axis=(1, 2, 3))

    fig, ax = plt.subplots(figsize=(10, 6))
    sns.histplot(per_image_mean, bins=50, kde=True, ax=ax)
    ax.set_title("Distribution of Average Image Brightness")
    ax.set_xlabel("Average Image Brightness")
    ax.set_ylabel("Frequency")
    fig.tight_layout()
    plt.show()

# Brightness histogram for the training images.
plot_image_brightness_distribution(images=x_train)


# --- Sample Images Visualization ---
def visualize_cifar10(images, labels, class_names, num_images=25):
    """
    Visualizes a grid of images from the CIFAR-10 dataset.

    At most `num_images` images are drawn. If fewer images or labels are
    supplied, the grid shrinks to what is available, and nothing is drawn
    when that count is zero.

    Args:
        images:  A NumPy array of shape (N, height, width, channels) containing the images.
        labels:  Labels aligned with `images`. Each entry must hold exactly one
                 integer, so both (N,) and (N, 1) layouts are accepted.
        class_names: A list of class names corresponding to the labels.
        num_images: The maximum number of images to display in the grid (default: 25).
    """

    # Never index past the data actually supplied.
    num_images = min(num_images, len(images), len(labels))
    if num_images <= 0:
        return  # Nothing to draw; also avoids a zero-row grid below.

    num_rows = int(np.sqrt(num_images))  # Near-square grid
    num_cols = int(np.ceil(num_images / num_rows))
    plt.figure(figsize=(2 * num_cols, 2 * num_rows))  # Roughly 2x2 inches per cell

    for i in range(num_images):
        plt.subplot(num_rows, num_cols, i + 1)
        plt.xticks([])  # Remove x-axis ticks
        plt.yticks([])  # Remove y-axis ticks
        plt.grid(False) # Remove grid lines
        # cmap only affects 2D (grayscale) arrays; RGB(A) images ignore it.
        plt.imshow(images[i], cmap=plt.cm.binary)
        # .item() extracts the single element explicitly; calling int() on a
        # length-1 array is deprecated in recent NumPy releases.
        label_index = int(np.asarray(labels[i]).item())
        plt.xlabel(class_names[label_index])  # Class name shown under the image
    plt.tight_layout()
    plt.show()


# Show the first 25 training images with their class names
visualize_cifar10(
    images=x_train,
    labels=y_train,
    class_names=class_names,
    num_images=25,
)


# Example of visualizing images of a specific class

def visualize_class(images, labels, class_names, class_index, num_images=10):
    """
    Visualizes images belonging to a specific class.

    If fewer than `num_images` samples carry the requested label, only the
    available ones are shown. If there are none, a message is printed and
    no figure is created.

    Args:
        images:  A NumPy array of shape (N, height, width, channels) containing the images.
        labels:  A NumPy array of integer labels aligned with `images`; it is
                 flattened, so both (N,) and (N, 1) layouts are accepted.
        class_names: A list of class names corresponding to the labels.
        class_index: The index of the class to visualize.
        num_images: The maximum number of images to display.
    """

    # Work on a 1D view so each selected label is a plain scalar.
    flat_labels = np.asarray(labels).reshape(-1)

    # Positions of the samples belonging to the requested class
    indices = np.flatnonzero(flat_labels == class_index)

    # Keep at most num_images of them
    selected_indices = indices[:num_images]
    if selected_indices.size == 0:
        print(f"No images to display for class index {class_index}.")
        return

    selected_images = images[selected_indices]
    selected_labels = flat_labels[selected_indices]

    # Pass the real selection size so the grid never indexes past the
    # selected images when the class has fewer than num_images samples.
    visualize_cifar10(
        selected_images,
        selected_labels,
        class_names,
        num_images=len(selected_indices),
    )

# Show 10 training images of the 'dog' class (label index 5)
visualize_class(
    x_train,
    y_train,
    class_names,
    class_index=class_names.index('dog'),
    num_images=10,
)