In [1]:
from typing import Tuple, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch import Tensor
from torchvision.datasets import CIFAR10

In [2]:
DATA_ROOT = "./dataset"

In [24]:
# pylint: disable=unsubscriptable-object
class Net(nn.Module):
    """Simple CNN adapted from 'PyTorch: A 60 Minute Blitz'."""

    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.bn1 = nn.BatchNorm2d(6)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.bn2 = nn.BatchNorm2d(16)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.bn3 = nn.BatchNorm1d(120)
        self.fc2 = nn.Linear(120, 84)
        self.bn4 = nn.BatchNorm1d(84)
        self.fc3 = nn.Linear(84, 10)

    # pylint: disable=arguments-differ,invalid-name
    def forward(self, x: Tensor) -> Tensor:
        """Compute forward pass."""
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.bn3(self.fc1(x)))
        x = F.relu(self.bn4(self.fc2(x)))
        x = self.fc3(x)
        return x

In [25]:
def load_data() -> (
    Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, Dict]
):
    """Load CIFAR-10 (training and test set)."""
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = CIFAR10(DATA_ROOT, train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)
    testset = CIFAR10(DATA_ROOT, train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False)
    num_examples = {"trainset": len(trainset), "testset": len(testset)}
    return trainloader, testloader, num_examples

In [26]:
UsedNet = Net

def train(
    net: UsedNet,
    trainloader: torch.utils.data.DataLoader,
    epochs: int,
    device: torch.device,  # pylint: disable=no-member
) -> None:
    """Train the network."""
    # Define loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)

    print(f"Training {epochs} epoch(s) w/ {len(trainloader)} batches each")

    # Train the network
    net.to(device)
    net.train()
    for epoch in range(epochs):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            images, labels = data[0].to(device), data[1].to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % 100 == 99:  # print every 100 mini-batches
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1, running_loss / 2000))
                running_loss = 0.0


def test(
    net: UsedNet,
    testloader: torch.utils.data.DataLoader,
    device: torch.device,  # pylint: disable=no-member
) -> Tuple[float, float]:
    """Validate the network on the entire test set."""
    # Define loss and metrics
    criterion = nn.CrossEntropyLoss()
    correct, loss = 0, 0.0

    # Evaluate the network
    net.to(device)
    net.eval()
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)  # pylint: disable=no-member
            correct += (predicted == labels).sum().item()
    accuracy = correct / len(testloader.dataset)
    return loss, accuracy

In [27]:
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Centralized PyTorch training")
print("Load data")
trainloader, testloader, _ = load_data()
net = Net().to(DEVICE)
net.eval()

Centralized PyTorch training
Load data
Files already downloaded and verified
Files already downloaded and verified


Net(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (bn1): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (bn3): BatchNorm1d(120, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (bn4): BatchNorm1d(84, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

In [28]:
print("Start training")
train(net=net, trainloader=trainloader, epochs=10, device=DEVICE)

Start training
Training 10 epoch(s) w/ 1563 batches each
[1,   100] loss: 0.115
[1,   200] loss: 0.112
[1,   300] loss: 0.107
[1,   400] loss: 0.106
[1,   500] loss: 0.103
[1,   600] loss: 0.101
[1,   700] loss: 0.100
[1,   800] loss: 0.099
[1,   900] loss: 0.097
[1,  1000] loss: 0.096
[1,  1100] loss: 0.096
[1,  1200] loss: 0.095
[1,  1300] loss: 0.094
[1,  1400] loss: 0.093
[1,  1500] loss: 0.092
[2,   100] loss: 0.092
[2,   200] loss: 0.091
[2,   300] loss: 0.090
[2,   400] loss: 0.089
[2,   500] loss: 0.088
[2,   600] loss: 0.088
[2,   700] loss: 0.087
[2,   800] loss: 0.087
[2,   900] loss: 0.086
[2,  1000] loss: 0.086
[2,  1100] loss: 0.086
[2,  1200] loss: 0.084
[2,  1300] loss: 0.085
[2,  1400] loss: 0.085
[2,  1500] loss: 0.085
[3,   100] loss: 0.083
[3,   200] loss: 0.082
[3,   300] loss: 0.082
[3,   400] loss: 0.083
[3,   500] loss: 0.082
[3,   600] loss: 0.081
[3,   700] loss: 0.082
[3,   800] loss: 0.081
[3,   900] loss: 0.081
[3,  1000] loss: 0.080
[3,  1100] loss: 0.080


In [None]:
# Save the trained model
torch.save(net.state_dict(), 'trained_model.pth')
print("Trained model saved as 'trained_model.pth'")

Trained model saved as 'trained_model.pth'


In [31]:
# Load the trained model
loaded_net = Net()
loaded_net.load_state_dict(torch.load('trained_model.pth'))
loaded_net.eval()
print("Loaded 'trained_model.pth' as model")

Loaded 'trained_model.pth' as model


In [32]:
print("Evaluate model")
loss, accuracy = test(net=loaded_net, testloader=testloader, device=DEVICE)
print("Loss: ", loss)
print("Accuracy: ", accuracy)


Evaluate model
Loss:  378.02625554800034
Accuracy:  0.6593
