## Imports and Functions 

In [352]:
"""
A complete implementation and training of a CIFAR10 classifier.
The prompt is to create another LearningRateScheduler.
"""
import time
from typing import Callable, Optional, Tuple

import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, ToTensor
from tqdm import tqdm

from model import Model
from config import CONFIG
# Device configuration: run on the GPU when PyTorch can see one, else on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [353]:
def get_cifar10_data(
    root: str = "data/cifar10",
    batch_size: Optional[int] = None,
    transform: Optional[Callable] = None,
) -> Tuple[DataLoader, DataLoader]:
    """
    Get the CIFAR10 data from torchvision (downloading it if needed).
    Arguments:
        root (str): Directory the dataset is stored in / downloaded to.
        batch_size (int, optional): Batch size for both loaders. Defaults to
            ``CONFIG.batch_size`` as resolved at call time.
        transform (callable, optional): Transform applied to every image.
            Defaults to ``CONFIG.transforms`` as resolved at call time.
    Returns:
        train_loader (DataLoader): The training data loader (shuffled).
        test_loader (DataLoader): The test data loader (fixed order).
    """
    if batch_size is None:
        batch_size = CONFIG.batch_size
    if transform is None:
        transform = CONFIG.transforms

    train_data = CIFAR10(root=root, train=True, download=True, transform=transform)
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

    test_data = CIFAR10(root=root, train=False, download=True, transform=transform)
    # Evaluation does not need shuffling: accuracy is order-independent and a
    # fixed order keeps test passes deterministic.
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader


def train(
    model: torch.nn.Module,
    train_loader: DataLoader,
    test_loader: DataLoader,
    num_epochs: int,
    optimizer: torch.optim.Optimizer,
    criterion: torch.nn.Module,
    device: torch.device = device,
    learning_rate_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    accuracy_threshold: Optional[float] = None,
) -> None:
    """
    Train a model on the data, evaluating on the test set after every epoch.
    Arguments:
        model (torch.nn.Module): The model to train.
        train_loader (DataLoader): The training data loader.
        test_loader (DataLoader): The test data loader.
        num_epochs (int): The maximum number of epochs to train for.
        optimizer (torch.optim.Optimizer): The optimizer to use.
        criterion (torch.nn.Module): The loss function to use.
        device (torch.device): The device to use for training.
        learning_rate_scheduler (torch.optim.lr_scheduler._LRScheduler, optional):
            If given, its ``step()`` is called once at the end of every epoch.
        accuracy_threshold (float, optional): Stop early once test accuracy
            (a fraction) exceeds this value. When omitted, the notebook-level
            ``ACCURACY_THRESHOLD`` is used if defined; otherwise all
            ``num_epochs`` are run.
    Returns:
        None
    """
    if accuracy_threshold is None:
        # Backward compatibility: earlier versions read this global directly.
        accuracy_threshold = globals().get("ACCURACY_THRESHOLD")

    model.to(device)
    for epoch in range(num_epochs):
        model.train()
        for x, y in tqdm(train_loader):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            y_hat = model(x)
            loss = criterion(y_hat, y)
            loss.backward()
            optimizer.step()

        # Epoch-level schedulers step after the optimizer has run for the epoch.
        if learning_rate_scheduler is not None:
            learning_rate_scheduler.step()

        model.eval()
        accuracy = compute_accuracy(model, test_loader, device)
        # Report before the early-stop check so the final epoch is logged too.
        print(f"Epoch {epoch + 1} | Test Accuracy: {accuracy:.2f}")
        if accuracy_threshold is not None and accuracy > accuracy_threshold:
            break


def compute_accuracy(
    model: torch.nn.Module,
    data_loader: DataLoader,
    device: Optional[torch.device] = None,
) -> float:
    """
    Compute the classification accuracy of a model on some data.
    Arguments:
        model (torch.nn.Module): The model to evaluate. It is switched to
            evaluation mode and left in it.
        data_loader (DataLoader): Yields ``(x, y)`` batches, where ``y`` holds
            labels comparable with the argmax over the model's output dim 1.
        device (torch.device, optional): Device to move each batch to.
            Defaults to CUDA when available, otherwise CPU.
    Returns:
        accuracy (float): Fraction of the samples actually yielded by
            ``data_loader`` whose prediction matches the label.
    Raises:
        ValueError: If ``data_loader`` yields no samples.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.eval()
    num_correct = 0
    # Count samples actually seen so loaders using drop_last/samplers are
    # scored correctly (len(dataset) may overcount).
    num_seen = 0
    # No autograd graph is needed for evaluation; skipping it saves memory/time.
    with torch.no_grad():
        for x, y in data_loader:
            x, y = x.to(device), y.to(device)
            predictions = torch.argmax(model(x), dim=1)
            num_correct += (predictions == y).sum().item()
            num_seen += y.size(0)

    if num_seen == 0:
        raise ValueError("data_loader yielded no samples; accuracy is undefined")
    return num_correct / num_seen

## Configuration, Model, and Training

In [356]:
# Early-stopping target read as a global by train(): training stops once the
# test accuracy (num_correct / dataset size, a fraction) exceeds this value.
ACCURACY_THRESHOLD = 0.55

In [357]:
# NOTE(review): this cell runs before the notebook-local `class CONFIG` cell
# below, so get_cifar10_data() reads the `CONFIG` imported from `config` at the
# top, not the local redefinition (batch_size = 64*4). The 782-batch progress
# bar in the training output is consistent with a batch size of 64 - confirm.
train_loader, test_loader = get_cifar10_data()

Files already downloaded and verified
Files already downloaded and verified


In [358]:
class CONFIG:
    """
    Notebook-local hyperparameters.

    NOTE(review): this shadows the `CONFIG` imported from `config`; cells that
    ran before this one (e.g. the data-loader cell) used the imported version.
    """

    # Four times the earlier batch size of 64.
    batch_size = 256
    num_epochs = 2
    transforms = Compose([ToTensor()])

    @staticmethod
    def optimizer_factory(model: nn.Module) -> torch.optim.Optimizer:
        """Build an Adam optimizer (lr=3e-3) over all of ``model``'s parameters."""
        return torch.optim.Adam(model.parameters(), lr=3e-3)

In [410]:
class Model(torch.nn.Module):
    """
    Small CNN classifier: two (3x3 conv -> batch-norm -> ReLU -> 2x2 max-pool)
    stages followed by a two-layer fully connected head.

    Arguments:
        num_channels (int): Number of input image channels.
        num_classes (int): Number of output logits.
        image_size (int): Height/width of the square input images; used to size
            ``fc1``. The default of 32 gives the original 576 input features.
    Raises:
        ValueError: If ``image_size`` is too small to survive both stages.
    """

    def __init__(self, num_channels: int, num_classes: int, image_size: int = 32) -> None:
        super().__init__()
        # Layers are created in the original order so random initialisation
        # consumes the RNG identically.
        self.conv1 = nn.Conv2d(in_channels=num_channels, out_channels=16, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3)
        # Each unpadded 3x3 conv shrinks a side by 2; each 2x2/stride-2 pool floors it by half.
        side = ((image_size - 2) // 2 - 2) // 2
        if side < 1:
            raise ValueError(
                f"image_size={image_size} is too small for two conv+pool stages"
            )
        self.fc1 = nn.Linear(16 * side * side, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.pool = nn.MaxPool2d(2, 2)
        # NOTE(review): dropout is constructed but never applied in forward().
        self.dropout = nn.Dropout(0.5)
        self.bn1 = nn.BatchNorm2d(16)
        self.bn2 = nn.BatchNorm2d(16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute class logits for a batch of images.

        Arguments:
            x (torch.Tensor): Images of shape
                (batch, num_channels, image_size, image_size).
        Returns:
            torch.Tensor: Logits of shape (batch, num_classes).
        """
        x = self.pool(torch.relu(self.bn1(self.conv1(x))))
        x = self.pool(torch.relu(self.bn2(self.conv2(x))))
        x = torch.flatten(x, start_dim=1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

In [411]:
# Seed PyTorch's global RNG so the model's random weight initialisation is
# reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
# Create the model:
model = Model(num_channels=3, num_classes=10)
# Create the optimizer:
optimizer = CONFIG.optimizer_factory(model)
# Create the loss function:
criterion = torch.nn.CrossEntropyLoss()

In [412]:
# Time the training run. perf_counter is monotonic and high-resolution, so it
# is the right clock for measuring durations (time.time can jump).
tic = time.perf_counter()
train(
    model,
    train_loader,
    test_loader,
    num_epochs=CONFIG.num_epochs,
    optimizer=optimizer,
    criterion=criterion,
)
toc = time.perf_counter()

  0%|                                  | 0/782 [00:00<?, ?it/s]


RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x576 and 288x128)

In [409]:
# Report wall-clock training time alongside accuracy on the test set.
elapsed_seconds = toc - tic
final_accuracy = compute_accuracy(model, test_loader)
print(f"Training time: {elapsed_seconds:.2f} seconds, final accuracy: {final_accuracy:.2f}")

Training time: 37.74 seconds, final accuracy: 0.56


In [None]:
# B64, C16-bn-pool, C32-bn-pool, 
 


In [137]:
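# Experiment log. B<n> = batch size; C1/C2-<channels>-k<kernel> = conv layer width and kernel size;
# trailing number = training time in seconds (LR-<n> presumably marks a learning-rate variant).
# B32, C1-16-k3, C2-32-k3, 37s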
# B32, C1-32-k3, C2-32-k3, 95s
# B64, C1-32-k3, C2-32-k3, 45s
# LR-1 -> B64, C1-32-k3, C2-32-k3 -> 47
# LR-15 -> B64, C1-32-k3, C2-32-k3 -> 46



# B32, C1-16-k5, C2-32-k3, 67.41 - 2batches
# B32, C1-32-k5, C2-32-k3, 87.68


