# Introduction
- Task: Multi-class Classification    
- Method: Convolutional Neural Network
- Library: PyTorch
- Dataset: MNIST
    - https://yann.lecun.com/exdb/mnist/ 
    - https://pytorch.org/vision/main/generated/torchvision.datasets.MNIST.html
    - The MNIST database of handwritten digits has a training set of 60,000 examples and a test set of 10,000 examples.
    - Images: grayscale, 28 × 28 pixels
    - Classes: {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}

# Prepare Data

In [None]:
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from torchvision.datasets import MNIST

class CustomDataset(Dataset):
    """
    Thin PyTorch ``Dataset`` wrapper around torchvision's MNIST dataset.

    Length and indexing are delegated to the wrapped ``MNIST`` object, so
    samples are returned exactly as ``MNIST`` produces them (after
    ``transform`` has been applied).

    Attributes:
        data (MNIST): The wrapped MNIST dataset object.
    """

    def __init__(self, root, train=True, transform=None, *, download=True):
        """
        Load one MNIST split from ``root``.

        Args:
            root (str): Root directory where the MNIST dataset is (or will be) stored.
            train (bool): If True, use the training split; otherwise the test split.
            transform (callable, optional): A function/transform that takes in a PIL
                image and returns a transformed version.
            download (bool): Forwarded to ``MNIST``. When True (the default, matching
                the original hard-coded behaviour), the dataset is downloaded into
                ``root`` if it is not already there. Pass False to use a
                pre-populated ``root`` without network access.

        Returns:
            None
        """
        super().__init__()
        self.data = MNIST(root=root, train=train, download=download, transform=transform)

    def __len__(self):
        """
        Return the total number of samples in the dataset.

        Returns:
            int: Number of samples in the wrapped ``data``.
        """
        return len(self.data)

    def __getitem__(self, idx):
        """
        Retrieve the sample at ``idx``.

        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            tuple: (image, label), where image is the transformed image and label
            is its corresponding class, exactly as returned by ``self.data[idx]``.
        """
        return self.data[idx]

In [None]:
# Preprocessing pipeline applied to every MNIST sample, in order:
#   1. ToTensor: PIL image -> float tensor with pixel values scaled to [0, 1].
#   2. Normalize: per-channel (x - mean) / std. 0.1307 / 0.3081 are the values
#      commonly quoted as the MNIST training-set pixel mean / std, giving inputs
#      that are roughly zero-mean with unit variance.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

In [None]:
# DataLoader was never imported in this notebook; import it here, cell-locally,
# like the model cell does with ``nn``.
from torch.utils.data import DataLoader

# Both splits share the same preprocessing pipeline.
train_dataset = CustomDataset(root='./data', train=True, transform=transform)
test_dataset = CustomDataset(root='./data', train=False, transform=transform)

# Shuffle only the training data; keep the test order fixed so evaluation is reproducible.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=1000, shuffle=False)

# Define Model

In [None]:
from torch import nn

class CustomModule(nn.Module):
    """
    Placeholder model for the MNIST classifier.

    ``forward`` currently returns its input unchanged; the convolutional layers
    named in the introduction have not been added yet.
    """

    def __init__(self):
        # nn.Module.__init__ (reached via super()) must run before any submodule
        # or parameter is assigned on self; otherwise instantiation/registration fails.
        super().__init__()

    def forward(self, x):
        """
        Run the model on a batch.

        Args:
            x (torch.Tensor): Input batch. Presumably (N, 1, 28, 28) normalized
                MNIST images — TODO confirm once real layers are added.

        Returns:
            torch.Tensor: Currently ``x`` itself (identity placeholder).
        """
        # TODO: replace the identity with conv/pool/linear layers producing 10 class logits.
        return x

# Train

# Test

# Inference

# Visualization