In [1]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.utils.tensorboard import SummaryWriter
import os

# Convert PIL images to float tensors in [0, 1], then standardise with the
# per-channel mean / std constants used for MNIST.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# Both splits are cached under ./data and downloaded on first use.
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

# Shuffle only the training data; evaluation order is irrelevant and kept fixed.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

  from .autonotebook import tqdm as notebook_tqdm


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:17<00:00, 566948.55it/s]


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 123883.05it/s]


Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:01<00:00, 831922.38it/s] 


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<?, ?it/s]


Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw



In [2]:
class MLP(nn.Module):
    """Fully connected MNIST classifier: 784 -> 512 -> 256 -> 128 -> 64 -> 10.

    Every hidden layer is followed by a ReLU and dropout. The last layer returns
    raw logits (no softmax), which is what ``nn.CrossEntropyLoss`` expects.

    Args:
        dropout_p: dropout probability applied after each hidden layer
            (default 0.2, the value that was previously hard-coded).
    """

    def __init__(self, dropout_p=0.2):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, 64)
        self.fc5 = nn.Linear(64, 10)
        # One parameter-free Dropout module is reused after every hidden layer.
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, x):
        """Map a batch of images to class logits of shape (batch, 10).

        Assumes each sample flattens to 28*28 values, e.g. input of shape
        (batch, 1, 28, 28).
        """
        x = self.flatten(x)
        # torch.relu avoids constructing a fresh nn.ReLU module on every call.
        # Layer attribute names (fc1..fc5) are unchanged, so state_dicts saved
        # by the previous version still load.
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            x = self.dropout(torch.relu(layer(x)))
        return self.fc5(x)


In [3]:
# Seed torch's global RNG right before building the model so that weight
# initialisation, dropout masks and the training DataLoader's shuffle order
# are reproducible under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

model = MLP()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# NOTE(review): re-running this cell writes another event file into the same
# log directory, so TensorBoard may show curves from separate runs together.
writer = SummaryWriter('runs/mnist_experiment')

checkpoint_dir = './checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

# Early-stopping state, read and updated by validate() through `global`.
patience = 3                # epochs without validation-loss improvement before stopping
best_loss = float('inf')    # lowest validation loss seen so far
early_stop_counter = 0      # consecutive epochs without improvement


In [4]:
def train(model, train_loader, criterion, optimizer, epoch, log_interval=100, summary_writer=None):
    """Run one training epoch and log metrics to TensorBoard.

    Args:
        model: network to optimise; put in training mode (dropout active).
        train_loader: iterable of ``(inputs, targets)`` batches with a ``len``.
        criterion: loss function applied as ``criterion(outputs, targets)``.
        optimizer: optimizer over ``model``'s parameters.
        epoch: zero-based epoch index, used to compute the global logging step.
        log_interval: log the mean training loss whenever ``batch_idx`` is a
            multiple of this value (default 100, as before).
        summary_writer: object exposing ``add_scalar``; defaults to the
            module-level ``writer`` when omitted.

    Returns:
        Training accuracy over the epoch, in percent. It is measured with dropout
        active while the weights are still changing, so it is not directly
        comparable with validation accuracy.
    """
    tb = writer if summary_writer is None else summary_writer
    model.train()
    steps_per_epoch = len(train_loader)
    running_loss = 0.0
    batches_in_window = 0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        batches_in_window += 1
        correct += outputs.argmax(dim=1).eq(targets).sum().item()
        total += targets.size(0)

        if batch_idx % log_interval == 0:
            # Average over the batches actually accumulated. The window ending
            # at batch 0 holds a single batch, not `log_interval` of them, so
            # dividing by a constant under-reported that point.
            tb.add_scalar('training loss',
                          running_loss / batches_in_window,
                          epoch * steps_per_epoch + batch_idx)
            running_loss = 0.0
            batches_in_window = 0

    train_accuracy = 100. * correct / total
    tb.add_scalar('training accuracy', train_accuracy, epoch)
    print(f'Epoch {epoch}, Training Accuracy: {train_accuracy}%')
    return train_accuracy

def validate(model, test_loader, criterion, epoch, summary_writer=None):
    """Evaluate ``model``, checkpoint on improvement, and apply early stopping.

    Reads and updates the module-level ``best_loss`` and ``early_stop_counter``.
    Reads ``patience`` and ``checkpoint_dir``, plus ``writer`` unless
    ``summary_writer`` is given. On improvement, the weights are saved to
    ``checkpoint_dir/best_model.pth``.

    NOTE(review): the training loop passes the *test* loader here, so the test
    set drives checkpoint selection and early stopping. Any accuracy later
    reported on that same loader is therefore optimistically biased; a
    validation split held out from the training data would avoid this.

    Args:
        model: network to evaluate; switched to eval mode.
        test_loader: iterable of ``(inputs, targets)`` batches used for validation.
        criterion: loss function; assumed to use ``reduction='mean'`` (the
            default) so per-batch losses can be re-weighted by batch size.
        epoch: zero-based epoch index used as the logging step.
        summary_writer: object exposing ``add_scalar``; defaults to the
            module-level ``writer`` when omitted.

    Returns:
        False when early stopping is triggered, True otherwise.
    """
    global best_loss, early_stop_counter
    tb = writer if summary_writer is None else summary_writer
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in test_loader:
            outputs = model(inputs)
            batch_size = targets.size(0)
            # Weight each batch's mean loss by its size so a smaller final batch
            # does not skew the epoch average.
            loss_sum += criterion(outputs, targets).item() * batch_size
            correct += outputs.argmax(dim=1).eq(targets).sum().item()
            total += batch_size

    validation_loss = loss_sum / total
    validation_accuracy = 100. * correct / total
    tb.add_scalar('validation loss', validation_loss, epoch)
    tb.add_scalar('validation accuracy', validation_accuracy, epoch)

    # Checkpointing
    improved = validation_loss < best_loss
    if improved:
        print(f'Validation loss improved from {best_loss} to {validation_loss}. Saving model...')
        best_loss = validation_loss
        torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'best_model.pth'))
        early_stop_counter = 0
    else:
        early_stop_counter += 1

    # Report this epoch's accuracy before deciding whether to stop, so the
    # final epoch's result is not silently dropped.
    print(f'Epoch {epoch}, Validation Accuracy: {validation_accuracy}%')

    if not improved and early_stop_counter >= patience:
        print('Early stopping triggered')
        return False
    return True

num_epochs = 50  # upper bound; validate() can end training earlier via early stopping
try:
    for epoch in range(num_epochs):
        train(model, train_loader, criterion, optimizer, epoch)
        if not validate(model, test_loader, criterion, epoch):
            break
finally:
    # Close the TensorBoard writer even if training raises or is interrupted
    # (e.g. KeyboardInterrupt from the notebook's stop button).
    writer.close()


Epoch 0, Training Accuracy: 90.05666666666667%
Validation loss improved from inf to 0.1258832547813654. Saving model...
Epoch 0, Validation Accuracy: 96.28%
Epoch 1, Training Accuracy: 95.78%
Validation loss improved from 0.1258832547813654 to 0.11237807124853134. Saving model...
Epoch 1, Validation Accuracy: 97.0%
Epoch 2, Training Accuracy: 96.90666666666667%
Validation loss improved from 0.11237807124853134 to 0.09089073650538922. Saving model...
Epoch 2, Validation Accuracy: 97.6%
Epoch 3, Training Accuracy: 97.38%
Validation loss improved from 0.09089073650538922 to 0.07846759352833033. Saving model...
Epoch 3, Validation Accuracy: 97.73%
Epoch 4, Training Accuracy: 97.72833333333334%
Epoch 4, Validation Accuracy: 97.72%
Epoch 5, Training Accuracy: 97.89833333333333%
Epoch 5, Validation Accuracy: 97.61%
Epoch 6, Training Accuracy: 97.99333333333334%
Validation loss improved from 0.07846759352833033 to 0.06957596009597182. Saving model...
Epoch 6, Validation Accuracy: 98.17%
Epoch 

In [5]:
# Reload the weights with the lowest validation loss, saved during training to
# checkpoint_dir/best_model.pth. weights_only=True restricts unpickling to
# tensors and plain containers (all a state_dict needs) rather than arbitrary
# pickled objects.
model.load_state_dict(
    torch.load(os.path.join(checkpoint_dir, 'best_model.pth'), weights_only=True)
)

# NOTE(review): test_loader is also what the training loop passes to validate()
# for checkpoint selection and early stopping, so this is not an unbiased
# held-out estimate. Carve a validation split out of the training data to fix.
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for inputs, targets in test_loader:
        outputs = model(inputs)
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

test_accuracy = 100. * correct / total
print(f'Test Accuracy: {test_accuracy}%')


Test Accuracy: 98.17%


In [None]:
# To view the TensorBoard logs, run this command in a terminal: tensorboard --logdir=runs
