<a href="https://colab.research.google.com/github/Knightler/PyTorch-projects/blob/main/pytorch_practice01.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **MNIST Handwritten Digit Classifier**

In [6]:
import torch
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn

In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
# Per-channel normalization statistics shared by both pipelines
# (the values commonly used for MNIST).
MNIST_MEAN = (0.1307,)
MNIST_STD = (0.3081,)

# Training pipeline: light geometric augmentation, applied as a random rotation
# of up to +/-10 degrees followed by a random shift (translate fraction 0.1 on
# each axis), then tensor conversion and normalization.
train_transform = transforms.Compose(
    [
        transforms.RandomRotation(10),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.ToTensor(),
        transforms.Normalize(mean=MNIST_MEAN, std=MNIST_STD),
    ]
)

# Evaluation pipeline: no augmentation, only the same tensor conversion and
# normalization as training, so test images are scored as stored.
test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=MNIST_MEAN, std=MNIST_STD),
    ]
)

In [4]:
# Training split, with the augmenting train_transform applied to each sample.
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=train_transform)
# Held-out test split: train=False selects MNIST's separate test images.
# Loading train=True here would evaluate the model on the very images it was
# trained on, making the reported test accuracy meaningless.
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=test_transform)

In [5]:
# Mini-batch size used by both loaders.
BATCH_SIZE = 32

# The training loader reshuffles every epoch so batch composition varies.
# The test loader keeps dataset order so evaluation is deterministic.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [7]:
class CNN(nn.Module):
    """Two-stage convolutional classifier sized for 28x28 inputs.

    Each stage is a 3x3 convolution (stride 1, padding 1, so height and width
    are preserved), a ReLU, and a 2x2 max-pool that halves height and width:
    28x28 -> 14x14 -> 7x7. The 7x7 feature map is flattened and projected to
    ``output_layers`` logits by a single linear layer, whose input size is
    fixed at ``hidden_layers * 7 * 7``.

    Args:
        input_channels: Number of channels in the input image.
        hidden_layers: Number of feature maps produced by each convolution.
        output_layers: Number of output logits.
    """

    def __init__(self,
                 input_channels,
                 hidden_layers,
                 output_layers):
        super().__init__()
        # stride=1 keeps spatial size; all downsampling is done by the pools.
        # With stride=2 here as well, a 28x28 input would shrink to 2x2 and
        # no longer match the 7x7 that self.fc is sized for.
        self.conv1 = nn.Conv2d(in_channels=input_channels, out_channels=hidden_layers, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=hidden_layers, out_channels=hidden_layers, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(hidden_layers * 7 * 7, output_layers)

    def forward(self, x):
        """Map a batch of images to unnormalized class scores.

        Args:
            x: Tensor of shape (batch, input_channels, 28, 28).

        Returns:
            Tensor of shape (batch, output_layers) holding raw logits
            (no softmax is applied).
        """
        x = self.maxpool(self.relu(self.conv1(x)))  # -> (batch, hidden, 14, 14)
        x = self.maxpool(self.relu(self.conv2(x)))  # -> (batch, hidden, 7, 7)
        x = self.flatten(x)                          # -> (batch, hidden * 49)
        return self.fc(x)