In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Network

In [49]:
class LeNet(nn.Module):
    def __init__(self, in_channels: int = 1, num_classes: int =10):
        super(LeNet, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        
        self.fc1 = nn.Linear(16*5*5, 120) # 16*5*5 for CIFAR10 and 16*4*4 for MNIST
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)
        
        
    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        
        return x
        

In [10]:
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        
        self.features = nn.Sequential(
            nn.Conv2d(1, 6, 5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2,2)),
            
            nn.Conv2d(6, 16, 5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2,2)),
        )
        
        
        self.fc = nn.Sequential(
        nn.Linear(16*5*5, 120),
        nn.ReLU(inplace=True),
        nn.Linear(120, 84),
        nn.ReLU(inplace=True),
        nn.Linear(84, 10),
        )
        
        
    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1) # can't use nn.Flatten(x, 1)
        x = self.fc(x)
        return x

# Data

### MNIST

In [11]:
import torchvision

# Load the MNIST dataset
mnist_train = torchvision.datasets.MNIST(
    root='../data',
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor()
)

# Load the MNIST test dataset
mnist_test = torchvision.datasets.MNIST(
    root='../data',
    train=False,
    download=True,
    transform=torchvision.transforms.ToTensor()
)

train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=16, shuffle=True)

### CIFAR10

In [38]:
import torchvision

# Load the cifar10 dataset
cifar10_train = torchvision.datasets.CIFAR10(
    root='../data',
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor()
)

# Load the cifar10 test dataset
cifar10_test = torchvision.datasets.CIFAR10(
    root='../data',
    train=False,
    download=True,
    transform=torchvision.transforms.ToTensor()
)

train_loader = torch.utils.data.DataLoader(cifar10_train, batch_size=16, shuffle=True)

Files already downloaded and verified
Files already downloaded and verified


# Train

In [52]:
import torch.optim as optim

net = LeNet(in_channels=3)

optimizer = optim.SGD(net.parameters(), lr=0.01)
# criterion = nn.MSELoss()
criterion = nn.CrossEntropyLoss()

epochs = 10
for epoch in range(epochs):
    epoch_loss = 0
    
    for batch in train_loader:
        X, y = batch[0], batch[1]

        optimizer.zero_grad()

        output = net(X)

        loss = criterion(output, y)
        loss.backward()
        optimizer.step()
        epoch_loss += loss

    print(f'Epoch: {epoch}, Loss: {epoch_loss}')


Epoch: 0, Loss: 7184.0224609375
Epoch: 1, Loss: 6282.46337890625
Epoch: 2, Loss: 5324.2880859375
Epoch: 3, Loss: 4848.81396484375
Epoch: 4, Loss: 4504.10302734375
Epoch: 5, Loss: 4227.78369140625
Epoch: 6, Loss: 4002.423828125
Epoch: 7, Loss: 3818.681884765625
Epoch: 8, Loss: 3655.15673828125
Epoch: 9, Loss: 3518.828125
