In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Network

In [53]:
class LeNet(nn.Module):
    def __init__(self, in_channels: int = 1, num_classes: int =10):
        super(LeNet, self).__init__()


        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=6, kernel_size=5) 
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5) 
        
        self.fc1 = nn.Linear(16*4*4, 120) # 16*5*5 for CIFAR10 and 16*4*4 for MNIST
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)
        
        
    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        
        return x
        

In [10]:
class LeNet(nn.Module):

        '''
        Conv layer / MaxPool output shape formula: [(Width − KernelSize + 2Padding) / Stride] + 1

        Default Stride for nn.Conv2D == 1
        //       //     // nn.MaxPool2D == Kernel Size



        MNIST for Example:
        '''

    def __init__(self, num_classes: int =10, in_channels: int = 1):
        super(LeNet, self).__init__()
    
        self.features = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=6, kernel_size=(5,5)), # (28-5 + 2*0 / 1) + 1 == 24 --> 24*24*6
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2,2)), # (24-2 + 2*0 / 2) + 1 == 12 --> 12*12*6
            
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=(5,5)), # (12-5 + 2*0 / 1) + 1 == 9 --> 9*9*16
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2,2)), # (9 - 2 + 2*0 / 2 + 1 == 4 --> 4*4*16)
        )
        
        
        self.fc = nn.Sequential(
        nn.Linear(16*4*4, 120),
        nn.ReLU(inplace=True),
        nn.Linear(120, 84),
        nn.ReLU(inplace=True),
        nn.Linear(84, num_classes),
        )
        
        
    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1) # can't use nn.Flatten(x, 1)
        x = self.fc(x)
        return x

# Data

### MNIST

In [11]:
import torchvision

# Load the MNIST dataset
mnist_train = torchvision.datasets.MNIST(
    root='../data',
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor()
)

# Load the MNIST test dataset
mnist_test = torchvision.datasets.MNIST(
    root='../data',
    train=False,
    download=True,
    transform=torchvision.transforms.ToTensor()
)

train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=16, shuffle=True)

### CIFAR10

In [38]:
import torchvision

# Load the cifar10 dataset
cifar10_train = torchvision.datasets.CIFAR10(
    root='../data',
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor()
)

# Load the cifar10 test dataset
cifar10_test = torchvision.datasets.CIFAR10(
    root='../data',
    train=False,
    download=True,
    transform=torchvision.transforms.ToTensor()
)

train_loader = torch.utils.data.DataLoader(cifar10_train, batch_size=16, shuffle=True)

Files already downloaded and verified
Files already downloaded and verified


# Train

In [54]:
import torch.optim as optim

net = LeNet(in_channels=3)

optimizer = optim.SGD(net.parameters(), lr=0.01)
# criterion = nn.MSELoss()
criterion = nn.CrossEntropyLoss()

epochs = 10
for epoch in range(epochs):
    epoch_loss = 0
    
    for batch in train_loader:
        X, y = batch[0], batch[1]

        optimizer.zero_grad()

        output = net(X)

        loss = criterion(output, y)
        loss.backward()
        optimizer.step()
        epoch_loss += loss

    print(f'Epoch: {epoch}, Loss: {epoch_loss}')


Epoch: 0, Loss: 7105.69140625
Epoch: 1, Loss: 6044.060546875
Epoch: 2, Loss: 5220.2109375
Epoch: 3, Loss: 4693.55029296875
Epoch: 4, Loss: 4418.4443359375
Epoch: 5, Loss: 4216.78662109375
Epoch: 6, Loss: 4047.238037109375
Epoch: 7, Loss: 3898.958251953125
Epoch: 8, Loss: 3766.83984375
Epoch: 9, Loss: 3640.83642578125
