In [6]:
import os
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import torchmetrics
import wandb

In [7]:
class SignDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "dataset", batch_size: int = 16):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

        self.transform = transforms.Compose([
            transforms.Grayscale(),
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])


    def setup(self, stage=None):
        # Dataset in folder structure: dataset/class_name/*.jpg
        self.dataset = datasets.ImageFolder(self.data_dir, transform=self.transform)

        # 80/20 split
        train_size = int(0.8 * len(self.dataset))
        val_size = len(self.dataset) - train_size

        self.train_set, self.val_set = torch.utils.data.random_split(
            self.dataset, [train_size, val_size]
        )

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size)

In [8]:
from torchmetrics.classification import MulticlassConfusionMatrix
from sklearn.metrics import f1_score


class TrafficSignCNN(pl.LightningModule):
    def __init__(self, class_names, lr=1e-3):
        super().__init__()
        self.save_hyperparameters()

        self.class_names = class_names
        self.num_classes = len(class_names)

        self.model = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(32 * 7 * 7, 64), nn.ReLU(),
            nn.Linear(64, self.num_classes)
        )

        self.loss_fn = nn.CrossEntropyLoss()

        # --- METRICS ---
        # Initialize metrics for easy calculation
        self.f1 = torchmetrics.F1Score(task="multiclass", num_classes=self.num_classes, average='macro')
        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=self.num_classes)
        self.confmat = MulticlassConfusionMatrix(num_classes=self.num_classes)


    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)

        self.log("train_loss", loss, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        preds = logits.argmax(dim=1)

        self.log("val_loss", loss, prog_bar=True)

        # Update running metrics
        self.f1.update(preds, y)
        self.confmat.update(preds, y)
        self.accuracy.update(preds, y)

    def on_validation_epoch_end(self):
        # 1. Compute and Log F1
        f1_score = self.f1.compute()
        self.log("val_f1", f1_score, prog_bar=True)

        # 2. Plot and Log Confusion Matrix
        # .plot() returns a Matplotlib Figure and Axis
        fig, ax = self.confmat.plot(labels=self.class_names)

        # Log the figure to WandB as an image
        if self.logger:
            self.logger.experiment.log({
                "confusion_matrix": wandb.Image(fig),
                "global_step": self.global_step
            })

        # Close the figure to free memory
        import matplotlib.pyplot as plt
        plt.close(fig)

        # 3. Reset metrics for the next epoch
        self.f1.reset()
        self.confmat.reset()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

In [9]:
wandb_logger = WandbLogger(project="sign-cnn")

data = SignDataModule(data_dir="C:/Users/robbe/PycharmProjects/DAI_Autonomous_Vehicles/Data/traffic_signs", batch_size=16)
model = TrafficSignCNN(class_names=["90", "60", "30", "stop"])

trainer = pl.Trainer(
    max_epochs=20,
    accelerator="auto",
    logger=wandb_logger,
    deterministic=True,
)

trainer.fit(model, data)

ðŸ’¡ Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
C:\Users\robbe\.conda\envs\DataScienceEnv\Lib\site-packages\pytorch_lightning\callbacks\model_checkpoint.py:658: Checkpoint directory .\sign-cnn\ww24mik2\checkpoints exists and is not empty.

  | Name     | Type                      | Params | Mode 
---------------------------------------------------------------
0 | model    | Sequential                | 105 K  | train
1 | loss_fn  | CrossEntropyLoss          | 0      | train
2 | f1       | MulticlassF1Score         | 0      | train
3 | accuracy | MulticlassAccuracy        | 0      | train
4 | confmat  | MulticlassConfusionMatrix | 0      | train
---------------------------------------------------------------
105 K     Tra

                                                                           

C:\Users\robbe\.conda\envs\DataScienceEnv\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
C:\Users\robbe\.conda\envs\DataScienceEnv\Lib\site-packages\pytorch_lightning\loops\fit_loop.py:310: The number of training batches (6) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 0: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 6/6 [00:00<00:00, 25.62it/s, v_num=mik2, train_loss_step=1.630]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation:   0%|          | 0/2 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]
Validation DataLoader 0:  50%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆ     | 1/2 [00:00<00:00, 58.94it/s]
Validation DataLoader 0: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 2/2 [00:00<00:00, 60.51it/s]
Epoch 1: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 6/6 [00:00<00:00, 32.92it/s, v_num=mik2, train_loss_step=1.270, val_loss=1.190, val_f1=0.161, train_loss_epoch=1.290]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation:   0%|          | 0/2 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]
Validation DataLoader 0:  50%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆ     | 1/2 [00:00<00:00, 63.22it/s]
Validation DataLoader 0: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 2/2 [00:00<00:00, 62.47it/s]
Epoch 2: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 6/6 [00:00<00:00, 

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 6/6 [00:00<00:00, 11.57it/s, v_num=mik2, train_loss_step=0.00473, val_loss=0.022, val_f1=1.000, train_loss_epoch=0.0231]
