# Train CNN

## Setup

### Imports

In [19]:
import os
from typing import Any
from itertools import pairwise, repeat

import torch
import torchvision
import pandas as pd
import plotly.express as px
from torch import nn, Tensor

In [2]:
# Class folder names, listed in the order that defines their integer label.
_CLASS_NAMES = (
    "Apple_Black_rot",
    "Apple_rust",
    "Apple_healthy",
    "Apple_scab",
    "Grape_Black_rot",
    "Grape_Esca",
    "Grape_healthy",
    "Grape_spot",
)

# Maps a class folder name to the integer id used as the training target.
LABEL2ID = {class_name: class_id for class_id, class_name in enumerate(_CLASS_NAMES)}

### dataset_dict

In [3]:
# Decode every image under dataset/<class_name>/ and record its class index.
imgs_lst: list[Tensor] = []
labels_lst: list[int] = []
# sorted() makes the sample order reproducible; os.listdir() order is arbitrary.
for img_class in sorted(os.listdir("dataset")):
    class_dir = os.path.join("dataset", img_class)
    # Skip stray files and hidden folders (e.g. .DS_Store, .ipynb_checkpoints)
    # that would otherwise fail the LABEL2ID lookup below.
    if img_class.startswith(".") or not os.path.isdir(class_dir):
        continue
    # A real but unknown class folder still raises KeyError, on purpose.
    class_idx = LABEL2ID[img_class]
    for img_name in sorted(os.listdir(class_dir)):
        img_pth = os.path.join(class_dir, img_name)
        # Hidden files and nested folders are not images; do not try to decode them.
        if img_name.startswith(".") or not os.path.isfile(img_pth):
            continue
        imgs_lst.append(torchvision.io.decode_image(img_pth))
        labels_lst.append(class_idx)

In [4]:
# Collate the per-image tensors along a new leading axis: (dataset_size, C, H, W).
raw_imgs = torch.stack(imgs_lst)
# One int32 class index per image, in the same order as raw_imgs.
labels = torch.tensor(labels_lst, dtype=torch.int32)

In [5]:
# Sanity check on the stacked batch: show its shape and its element dtype.
display(raw_imgs.shape)
display(raw_imgs.dtype)

torch.Size([7221, 3, 256, 256])

torch.uint8

In [6]:
# Cast the uint8 images to bfloat16. The layout stays (dataset_size, C, H, W)
# because the permute that used to move channels last is disabled.
preprocessed_imgs = raw_imgs.to(dtype=torch.bfloat16)

# Standardise along the last axis, i.e. each pixel row of each channel.
# NOTE(review): with the permute disabled dim=3 is W, not C - confirm that
# row-wise normalisation is the intended axis.
row_mean = preprocessed_imgs.mean(dim=3, keepdim=True)
# A constant row has std == 0, which would turn the division into 0/0 = NaN
# and poison training. Clamping maps such rows to 0 and leaves every
# non-degenerate row unchanged.
row_std = preprocessed_imgs.std(dim=3, keepdim=True).clamp_min(1e-6)
preprocessed_imgs = (preprocessed_imgs - row_mean) / row_std

display(preprocessed_imgs.shape)
display(preprocessed_imgs.dtype)

torch.Size([7221, 3, 256, 256])

torch.bfloat16

### preprocess dataset

In [7]:
# Report how many samples each class contributes (class balance check).
for class_name, class_id in LABEL2ID.items():
    n_samples = labels.eq(class_id).sum()
    print(class_name, n_samples)

Apple_Black_rot tensor(620)
Apple_rust tensor(275)
Apple_healthy tensor(1640)
Apple_scab tensor(629)
Grape_Black_rot tensor(1178)
Grape_Esca tensor(1382)
Grape_healthy tensor(422)
Grape_spot tensor(1075)


Wrap the preprocessed image tensor and the label tensor in a single `TensorDataset`.

In [8]:
# Pair every preprocessed image with its label so both are indexed together.
dataset = torch.utils.data.TensorDataset(preprocessed_imgs, labels)

In [9]:
# Peek at the first 10 samples: slicing the dataset yields an (images, labels) pair.
img, label = dataset[:10]
print(img.shape, label)

torch.Size([10, 3, 256, 256]) tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=torch.int32)


In [10]:
# px.imshow(img.to(dtype=torch.float), facet_col=0)

In [None]:

# for batch_idx, (raw_imgs, labels) in enumerate(data_loader):
#     print("batch_idx:", batch_idx)
#     px.imshow(raw_imgs.permute(0, 2, 3, 1), facet_col=0).show()
#     if batch_idx == 2:
#         break

In [12]:

class CNN(nn.Module):
    def __init__(
        self,
        kernels_per_layer: list[int],
        mlp_width: int,
        mlp_depth: int,
        n_classes: int,
    ):
        super().__init__()
        conv_layers = []
        channels = [3] + kernels_per_layer
        for in_channels, out_channels in pairwise(channels):
            conv_layer = nn.Conv2d(
                in_channels,
                out_channels,
                5,
                padding=2,
            )
            conv_layers.append(conv_layer)
        self.conv_layers = nn.ModuleList(conv_layers)
        
        self.linear_layers = []
        for width in repeat(mlp_width, mlp_depth - 1):
            self.linear_layers.append(nn.LazyLinear(width))
        self.linear_layers.append(nn.LazyLinear(n_classes))
        self.linear_layers = nn.ModuleList(self.linear_layers)
    
    def forward(self, x: Tensor) -> Tensor:
        for layer_idx, conv in enumerate(self.conv_layers):
            x = conv(x)
            x = nn.functional.relu(x)
            x = nn.functional.max_pool2d(x, 2)

        x = x.flatten(1)
        for layer_idx, linear_layer in enumerate(self.linear_layers):
            x = linear_layer(x)
            if layer_idx != len(self.linear_layers) - 2:
                x = nn.functional.relu(x)
        return x

# with torch.no_grad():
#     imgs, lalbes = dataset[:2]
#     print(imgs.shape)
#     model(imgs.to(dtype=torch.float32)).shape

In [21]:
from tqdm import tqdm

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing on the first .to(device=DEVICE).
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def training_step(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    criterion,
    x: Tensor,
    y: Tensor,
) -> dict[str, float]:
    """Run one optimisation step on a single batch and report its metrics.

    Args:
        model: Network to update; called as ``model(x)``.
        optimizer: Optimizer holding the model's parameters.
        criterion: Callable taking ``(model_output, y)`` and returning a
            scalar loss tensor.
        x: Input batch; cast to float32 before the forward pass.
        y: Target class indices; cast to int64 before the loss is computed.

    Returns:
        ``{"loss": ..., "accuracy": ...}`` as Python floats, where accuracy is
        the fraction of samples whose arg-max output equals the target.
    """
    # Send the batch to wherever the model's weights live, so the step cannot
    # fail on a device mismatch. A model without parameters leaves the batch
    # on its current device.
    first_param = next(model.parameters(), None)
    device = x.device if first_param is None else first_param.device
    x = x.to(dtype=torch.float32, device=device)
    y = y.to(dtype=torch.long, device=device)

    optimizer.zero_grad()
    model_output = model(x)
    loss = criterion(model_output, y)
    loss.backward()
    optimizer.step()

    # Metric only: no gradient tracking needed.
    with torch.no_grad():
        accuracy = (model_output.argmax(dim=-1) == y).float().mean()
    return {
        "loss": loss.item(),
        "accuracy": accuracy.item(),
    }

def train_model_for_single_epoch(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    data_loader: torch.utils.data.DataLoader,
    criterion,
) -> list[dict]:
    """Train on every batch of ``data_loader`` once.

    Returns:
        The metrics dict produced by ``training_step`` for each batch, in
        iteration order.
    """
    return [
        training_step(model, optimizer, criterion, batch_x, batch_y)
        for batch_x, batch_y in data_loader
    ]

def train_model(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    data_loader: torch.utils.data.DataLoader,
    criterion,  # callable taking (model_output, targets) and returning the loss
    n_epochs: int,
) -> pd.DataFrame:
    """Train ``model`` for ``n_epochs`` passes over ``data_loader``.

    Returns:
        A DataFrame with one row per optimisation step, across all epochs in
        chronological order, built from the per-step metrics dicts.
    """
    step_records = [
        step_metrics
        for _ in range(n_epochs)
        for step_metrics in train_model_for_single_epoch(
            model, optimizer, data_loader, criterion
        )
    ]
    return pd.DataFrame.from_records(step_records)

In [23]:
model = CNN(
    kernels_per_layer=[32, 64, 128, 256],
    mlp_width=128,
    mlp_depth=3,
    n_classes=len(LABEL2ID),
).to(device=DEVICE)

# Dry run: the LazyLinear layers only get real weights on their first forward
# pass, and PyTorch's lazy-module docs ask for this to happen before an
# optimizer is attached.
with torch.no_grad():
    dry_run_imgs, _ = dataset[:1]
    model(dry_run_imgs.to(dtype=torch.float32, device=DEVICE))

BATCH_SIZE = 32
data_loader = torch.utils.data.DataLoader(
    dataset,
    BATCH_SIZE,
    shuffle=True,
)

optimizer = torch.optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()
# NOTE: despite its name, step_dicts is the DataFrame returned by train_model
# (one row per optimisation step); the plotting cells rely on this name.
step_dicts = train_model(model, optimizer, data_loader, criterion, 10)

In [24]:
# Training accuracy recorded at every optimisation step.
px.line(step_dicts, y="accuracy")

In [27]:
# Training loss recorded at every optimisation step, on a logarithmic y axis.
px.line(step_dicts, y="loss", log_y=True)