# All Convolutional Net Regularization


In [1]:
import lightning as L
import torch
import torch.nn.functional as F
import torchmetrics

from torch import nn


class AllConvNet(L.LightningModule):
    """All-convolutional image classifier with dropout regularization.

    The network is a stack of nine convolutions. Spatial downsampling is done
    by two stride-2 convolutions (``conv3`` and ``conv6``) rather than pooling
    layers, and dropout is applied after ``conv2`` and ``conv5``. The last
    1x1 convolution emits one feature map per class; global average pooling
    reduces each map to a single score, so ``forward`` returns raw logits of
    shape ``(batch, num_classes)``.

    Args:
        num_classes: Number of output classes (channels of the final 1x1
            convolution and size of the accuracy metric).
        dropout_rate: Drop probability used by both dropout layers.
        in_channels: Number of channels in the input images. Defaults to 1
            (single-channel input), which was previously hard-coded.
    """

    def __init__(self, num_classes: int = 10, dropout_rate: float = 0.5, in_channels: int = 1):
        super().__init__()
        # Layer names and creation order are kept stable so saved state dicts
        # and seeded initialisation stay compatible.
        self.conv1 = nn.Conv2d(in_channels, 96, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(96, 96, kernel_size=3, padding=1)
        self.dropout1 = nn.Dropout(dropout_rate)  # First dropout layer
        self.conv3 = nn.Conv2d(
            96, 96, kernel_size=3, stride=2, padding=1
        )  # Stride 2 for downsampling
        self.conv4 = nn.Conv2d(96, 192, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
        self.dropout2 = nn.Dropout(dropout_rate)  # Second dropout layer
        self.conv6 = nn.Conv2d(
            192, 192, kernel_size=3, stride=2, padding=1
        )  # Another downsampling step
        self.conv7 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
        self.conv8 = nn.Conv2d(192, 192, kernel_size=1)  # 1x1 convolutions
        self.conv9 = nn.Conv2d(
            192, num_classes, kernel_size=1
        )  # Final 1x1 convolution for class scores

        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.accuracy = torchmetrics.Accuracy(
            num_classes=num_classes, average="macro", task="multiclass"
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for a batch of images."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.dropout1(x)  # Apply dropout after second conv layer
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        x = self.dropout2(x)  # Apply dropout after fifth conv layer
        x = F.relu(self.conv6(x))
        x = F.relu(self.conv7(x))
        x = F.relu(self.conv8(x))
        x = self.conv9(x)  # No activation: these are the raw class scores
        x = self.global_avg_pool(x)  # (batch, num_classes, 1, 1)
        return x.flatten(1)  # Drop the two singleton spatial dimensions

    def training_step(self, batch, batch_idx):
        """Compute, log and return the cross-entropy loss for one training batch."""
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log cross-entropy loss and macro accuracy for one validation batch."""
        x, y = batch
        y_hat = self(x)
        val_loss = F.cross_entropy(y_hat, y)
        self.accuracy(y_hat, y)
        self.log("val_loss", val_loss, on_epoch=True, prog_bar=True)
        # The metric object itself is logged (not a computed value).
        self.log("val_accuracy", self.accuracy, on_epoch=True, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Log cross-entropy loss and macro accuracy for one test batch."""
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)

        self.accuracy(y_hat, y)

        self.log("test_accuracy", self.accuracy)
        self.log("test_loss", loss)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with learning rate 1e-3."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

The Imagenette dataset is a small subset of ImageNet consisting of 10 easily classified classes. It can be downloaded through `torchvision`, as shown in the cell below. The images are available in 3 different sizes (`"full"`, `"320px"`, and `"160px"`). Feel free to use whichever version you prefer; the choice may affect the performance of your model.

**Note: On the first run, set `download=True` in the cell below to fetch the Imagenette dataset. After it has been downloaded, set `download=False` again (as the cell is currently written) to avoid errors.**

In [2]:
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from torchvision import transforms
from torchvision.datasets import Imagenette


# Prepare the dataset with augmentations for training
train_transforms = transforms.Compose(
    [
        transforms.CenterCrop(160),
        transforms.Resize(64),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),  # Rotate images by up to 10 degrees
        # Random crop covering 80-100% of the image, resized back to 64x64 so
        # training images keep the same resolution as the evaluation pipeline
        # (previously they were upsampled to 256x256).
        transforms.RandomResizedCrop(64, scale=(0.8, 1.0)),
        transforms.ToTensor(),
        # NOTE(review): these look like the commonly quoted CIFAR-10 channel
        # statistics rather than ImageNet/Imagenette ones -- confirm.
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
        # Collapse RGB to a single channel for the single-channel network input
        transforms.Grayscale(),
    ]
)

# Deterministic pipeline for validation/test: same crop, size, normalization
# and grayscale conversion as training, without the random augmentations.
test_transforms = transforms.Compose(
    [
        transforms.CenterCrop(160),
        transforms.Resize(64),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
        transforms.Grayscale(),
    ]
)

# Load the training images twice: once with the augmenting transforms and once
# with the deterministic evaluation transforms. Both subsets returned by
# random_split share one underlying dataset object, so overwriting
# `subset.dataset.transform` for validation would also strip the augmentations
# from the training subset.
train_dataset = Imagenette(
    "data/imagenette/train/",
    split="train",
    size="160px",
    download=False,
    transform=train_transforms,
)
val_source = Imagenette(
    "data/imagenette/train/",
    split="train",
    size="160px",
    download=False,
    transform=test_transforms,
)

# Use 10% of the training set for validation
train_set_size = int(len(train_dataset) * 0.9)
val_set_size = len(train_dataset) - train_set_size

seed = torch.Generator().manual_seed(42)
train_dataset, val_split = torch.utils.data.random_split(
    train_dataset, [train_set_size, val_set_size], generator=seed
)
# Same held-out indices as the split above, but read through the dataset that
# applies the evaluation transforms.
val_dataset = torch.utils.data.Subset(val_source, val_split.indices)

# Use DataLoader to load the dataset
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=128, num_workers=8, shuffle=True, persistent_workers=True
)
val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=128, num_workers=8, shuffle=False, persistent_workers=True
)

# Configure the test dataset
test_dataset = Imagenette(
    "data/imagenette/test/",
    split="val",
    size="160px",
    download=False,
    transform=test_transforms,
)

# Build the network with its default arguments
model = AllConvNet()

# Early stopping: give up once val_loss has gone 5 checks without improving
early_stop_callback = EarlyStopping(monitor="val_loss", patience=5, mode="min")

# Checkpointing: rank saved checkpoints by val_loss (lower is better)
checkpoint_callback = ModelCheckpoint(mode="min", monitor="val_loss")

In [None]:
from pathlib import Path


# Relax float32 matmul precision to "high" before training
torch.set_float32_matmul_precision("high")

# No epoch limit (max_epochs=-1); the early-stopping callback ends the run
trainer = L.Trainer(
    max_epochs=-1,
    callbacks=[early_stop_callback, checkpoint_callback],
)
trainer.fit(
    model=model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

# Write the trained weights to models/all_convnet.pth so the
# Transfer Learning Model can reuse them
dir_path = Path("models")
model_path = dir_path / "all_convnet.pth"
dir_path.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), model_path)

In [None]:
# Run the fitted model over the test split, in order, in batches of 256
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    shuffle=False,
    batch_size=256,
    num_workers=8,
    persistent_workers=True,
)
trainer.test(model=model, dataloaders=test_loader)