In [None]:
# Classification model on an MCU NN Library

In [15]:
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from torch import nn
import torch

from nn_deployment_course.mcu.utils import transform_cifar10, sample_from_class, SimpleTrainer


## Dataset

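First, we load CIFAR-10 and split it into training, validation, and test sets. The validation set is sampled from the official training split with `sample_from_class`, and the official CIFAR-10 test split serves as the test set.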

In [16]:
# Fix the torch RNG so weight initialisation and DataLoader shuffling are
# reproducible across runs. NOTE(review): if `sample_from_class` draws from
# Python's `random` or NumPy instead of torch, seed those generators too.
SEED = 0
torch.manual_seed(SEED)

DATA_ROOT = "./data/data_cifar10/"
# Size argument passed to `sample_from_class`; presumably samples per class
# (5 000 images in total for CIFAR-10's 10 classes) — confirm in the helper.
VAL_SAMPLES = 500

train_set = CIFAR10(
    root=DATA_ROOT,
    train=True,
    transform=transform_cifar10(),
    download=True,
)
# Carve the validation split out of the official training set; `tr_set`
# presumably holds the remaining samples used for training.
val_set, tr_set = sample_from_class(train_set, VAL_SAMPLES)
test_set = CIFAR10(
    root=DATA_ROOT,
    train=False,
    transform=transform_cifar10(),
    download=True,
)

Files already downloaded and verified
Files already downloaded and verified


Next, we build one dataloader per split (train, validation, test) with a batch size of 8.

In [17]:
batch_size = 8
datasets = [tr_set, val_set, test_set]
# One DataLoader per split. Only the training split is shuffled: evaluation
# on val/test gains nothing from shuffling and a fixed order keeps those
# passes deterministic.
dataloaders = {
    split: DataLoader(
        split_set,
        batch_size=batch_size,
        shuffle=(split == "train"),
        num_workers=4,
    )
    for split, split_set in zip(["train", "val", "test"], datasets)
}

## Network

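We can now define our network: a small CNN with three convolution → pooling → ReLU stages, followed by a single linear layer producing scores for the 10 CIFAR-10 classes.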

In [18]:
class SampleCNN(nn.Module):
    """Small three-stage CNN classifier for image inputs.

    Each stage is a 3x3 convolution (no padding) -> 2x2 pooling -> ReLU; the
    first stage uses max pooling, the other two average pooling. The final
    feature map is flattened and fed to one linear layer with 10 outputs.

    Args:
        shape: Input shape as (channels, height, width), without the batch dim.
        batch_size: Batch size of the random dummy tensor used once at
            construction time to infer the flattened feature size. It does not
            constrain the batch size accepted by ``forward``.
    """

    def __init__(self, shape=(3, 32, 32), batch_size=4):
        super().__init__()
        self.input_shape = shape
        self.batch_size = batch_size

        self.conv1 = nn.Conv2d(in_channels=shape[0], out_channels=64, kernel_size=3)
        self.pool1 = nn.MaxPool2d(2)
        self.relu1 = nn.ReLU()

        self.conv2 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3)
        self.pool2 = nn.AvgPool2d(2)
        self.relu2 = nn.ReLU()

        self.conv3 = nn.Conv2d(in_channels=32, out_channels=16, kernel_size=3)
        self.pool3 = nn.AvgPool2d(2)
        self.relu3 = nn.ReLU()

        self.flatten = nn.Flatten()
        # Infer the feature-map shape with a dummy pass so the linear layer's
        # input width adapts to whatever `shape` is given.
        self.interface_shape = self.get_shape()
        self.interface = nn.Linear(in_features=self.interface_shape.numel(), out_features=10)

    def _features(self, x):
        """Apply the three conv -> pool -> ReLU stages (everything before flatten).

        Shared by ``get_shape`` and ``forward`` so the layer sequence is
        defined in exactly one place.
        """
        for conv, pool, relu in (
            (self.conv1, self.pool1, self.relu1),
            (self.conv2, self.pool2, self.relu2),
            (self.conv3, self.pool3, self.relu3),
        ):
            x = relu(pool(conv(x)))
        return x

    def get_shape(self):
        """Return the per-sample (C, H, W) shape of the convolutional output.

        Pushes a random batch of ``self.batch_size`` samples through the
        convolutional stages. Gradient tracking is disabled because only the
        output shape is needed.
        """
        # Keep torch.randn (rather than zeros) so the RNG is consumed exactly
        # as before and seeded weight initialisation stays unchanged.
        sample = torch.randn(size=(self.batch_size, *self.input_shape))
        with torch.no_grad():
            out = self._features(sample)
        return out.shape[1:]

    def forward(self, x):
        """Compute class logits of shape (batch, 10) for an input batch ``x``."""
        out = self._features(x)
        out = self.flatten(out)
        return self.interface(out)


Let's instantiate the network for 3×32×32 CIFAR-10 images.

In [20]:
input_shape = (3, 32, 32)
# Reuse the dataloader batch size instead of repeating the literal 8. Inside
# SampleCNN it only sizes the dummy batch used to infer the linear layer width.
cnn = SampleCNN(shape=input_shape, batch_size=batch_size)

## Train

Before training, we declare the trainer and the training hyperparameters.

In [19]:
# The trainer receives both the raw splits and their DataLoaders.
trainer = SimpleTrainer(
    dataloaders=dataloaders,
    datasets=datasets,
)

In [None]:
# Optimisation settings passed to `trainer.train`. The step/gamma pair
# presumably configures a step-decay learning-rate schedule — confirm in
# SimpleTrainer.
hyperparam = dict(
    learning_rate=1e-3,
    learning_step=5000,
    learning_gamma=0.99,
    epochs=20,
)

In [None]:
# Train with the `hyperparam` dict declared in the previous cell (there is no
# `config` variable in this notebook). `hyperparam` has no "name" key, so
# `.get` passes None as the run name; add a "name" entry to tag the run.
cnn = trainer.train(cnn, hyperparam, hyperparam.get("name"))