<a href="https://colab.research.google.com/github/KangaOnGit/CNN-Demo/blob/main/MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [33]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

In [34]:
ROOT = './data'
BATCH_SIZE = 128
VALID_RATIO = 0.1

# Tải dữ liệu
train_dataset = datasets.MNIST(root=ROOT, train=True, download=True)
test_dataset = datasets.MNIST(root=ROOT, train=False, download=True)

In [35]:
n_train = int(len(train_dataset) * (1 - VALID_RATIO))
n_valid = len(train_dataset) - n_train
train_dataset, valid_dataset = random_split(train_dataset, [n_train, n_valid])

In [36]:
mean = train_dataset.dataset.data.float().mean() / 255
std = train_dataset.dataset.data.float().std() / 255

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[mean], std=[std])
])

train_dataset.dataset.transform = transform
valid_dataset.dataset.transform = transform
test_dataset.transform = transform

In [37]:
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

In [38]:
class LeNet(nn.Module):
    def __init__(self, num_classes=10):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=2)
        self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = x.view(-1, 16 * 5 * 5)  # Flatten
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [39]:
def train(model, loader, optimizer, criterion, device):
    model.train()
    total_loss, total_correct = 0, 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * inputs.size(0)
        total_correct += (outputs.argmax(1) == labels).sum().item()
    return total_loss / len(loader.dataset), total_correct / len(loader.dataset)

In [40]:
def evaluate(model, loader, criterion, device):
    model.eval()
    total_loss, total_correct = 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            total_loss += loss.item() * inputs.size(0)
            total_correct += (outputs.argmax(1) == labels).sum().item()
    return total_loss / len(loader.dataset), total_correct / len(loader.dataset)

In [41]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = LeNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 10
best_valid_loss = float('inf')

for epoch in range(1, num_epochs + 1):
    train_loss, train_acc = train(model, train_loader, optimizer, criterion, device)
    valid_loss, valid_acc = evaluate(model, valid_loader, criterion, device)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'lenet_mnist.pt')

    print(f'Epoch {epoch:02}: Train Loss={train_loss:.4f}, Train Acc={train_acc:.4f}, '
          f'Valid Loss={valid_loss:.4f}, Valid Acc={valid_acc:.4f}')

Epoch 01: Train Loss=0.4009, Train Acc=0.8789, Valid Loss=0.1602, Valid Acc=0.9517
Epoch 02: Train Loss=0.1115, Train Acc=0.9658, Valid Loss=0.1045, Valid Acc=0.9680
Epoch 03: Train Loss=0.0744, Train Acc=0.9770, Valid Loss=0.0715, Valid Acc=0.9782
Epoch 04: Train Loss=0.0568, Train Acc=0.9823, Valid Loss=0.0594, Valid Acc=0.9808
Epoch 05: Train Loss=0.0458, Train Acc=0.9855, Valid Loss=0.0517, Valid Acc=0.9828
Epoch 06: Train Loss=0.0401, Train Acc=0.9872, Valid Loss=0.0484, Valid Acc=0.9840
Epoch 07: Train Loss=0.0335, Train Acc=0.9896, Valid Loss=0.0547, Valid Acc=0.9823
Epoch 08: Train Loss=0.0303, Train Acc=0.9905, Valid Loss=0.0420, Valid Acc=0.9863
Epoch 09: Train Loss=0.0263, Train Acc=0.9916, Valid Loss=0.0408, Valid Acc=0.9865
Epoch 10: Train Loss=0.0231, Train Acc=0.9923, Valid Loss=0.0381, Valid Acc=0.9878


In [42]:
model.load_state_dict(torch.load('lenet_mnist.pt'))
test_loss, test_acc = evaluate(model, test_loader, criterion, device)
print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}')

  model.load_state_dict(torch.load('lenet_mnist.pt'))


Test Loss: 0.0316, Test Accuracy: 0.9897
