# In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist

class Image:
    """Load the MNIST dataset and provide simple visual summaries of it.

    Attributes (populated in ``__init__`` from ``mnist.load_data()``):
        train_images, train_labels: arrays for the training split.
        test_images, test_labels: arrays for the test split.
    """

    def __init__(self):
        # Load the dataset; load_data() yields ((train_x, train_y), (test_x, test_y)).
        (self.train_images, self.train_labels), (self.test_images, self.test_labels) = mnist.load_data()

    def display_statistics(self):
        """Plot a bar chart of how often each label occurs in the training set."""
        _, ax = plt.subplots()
        digits, counts = np.unique(self.train_labels, return_counts=True)
        ax.bar(digits, counts)
        ax.set_title('Distribution of Digits in Training Data')
        ax.set_xlabel('Digit')
        ax.set_ylabel('Counts')
        plt.show()

    def display_image(self, dataset, index):
        """Show one image from the chosen split.

        Args:
            dataset: ``'train'`` or ``'test'``.
            index: position of the image within that split.

        Raises:
            ValueError: if ``dataset`` is neither ``'train'`` nor ``'test'``.
                (Previously this fell through and failed with UnboundLocalError.)
        """
        splits = {'train': self.train_images, 'test': self.test_images}
        if dataset not in splits:
            raise ValueError(f"dataset must be 'train' or 'test', got {dataset!r}")
        # Fold flattened images (after reshape_images) back to 2-D for imshow.
        image = self._to_display_shape(splits[dataset][index])
        plt.imshow(image, cmap='gray')
        plt.title(f'Image from {dataset} set at index {index}')
        plt.show()

    def display_mean_images(self):
        """Show the pixel-wise mean training image for each digit 0-9 side by side."""
        _, axes = plt.subplots(1, 10, figsize=(15, 1.5))
        for digit, ax in enumerate(axes):
            mean_image = np.mean(self.train_images[self.train_labels == digit], axis=0)
            # Works both before and after reshape_images() has flattened the data.
            ax.imshow(self._to_display_shape(mean_image), cmap='gray')
            ax.set_title(f'Digit {digit}')
            ax.axis('off')
        plt.show()

    def reshape_images(self):
        """Flatten every image in both splits into a single row.

        The resulting shape is (n_images, pixels_per_image). It is derived
        from each array rather than hard-coded, so any split size works.
        Calling this again on already-flat data leaves the shape unchanged.
        """
        self.train_images = self._flatten(self.train_images)
        self.test_images = self._flatten(self.test_images)

    @staticmethod
    def _flatten(images):
        """Return ``images`` reshaped to (n_images, product of the remaining dims)."""
        # Explicit column count (not -1) so an empty split reshapes cleanly too.
        return images.reshape(images.shape[0], int(np.prod(images.shape[1:])))

    @staticmethod
    def _to_display_shape(image):
        """Return ``image`` in a form ``imshow`` accepts.

        A 1-D vector whose length is a perfect square (e.g. 784) is folded
        into a square (e.g. 28x28). Anything else is returned unchanged.
        """
        image = np.asarray(image)
        if image.ndim == 1:
            side = int(round(np.sqrt(image.size)))
            if side * side == image.size:
                return image.reshape(side, side)
        return image

# Usage: show the label distribution, two sample images and the per-digit
# mean images, then flatten the image arrays.
# Guarded so that importing this module does not load MNIST or open plot
# windows. __name__ is "__main__" in a notebook/script, so this still runs
# there, and data_handler stays a module-level name.
if __name__ == "__main__":
    data_handler = Image()
    data_handler.display_statistics()
    data_handler.display_image('train', 123)
    data_handler.display_image('test', 123)
    data_handler.display_mean_images()
    data_handler.reshape_images()