In [8]:
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

In [9]:
# --- 1. Configuration and Hyperparameters ---
# This section contains all the tunable parameters for the model and training process.
class Config:
    """Single home for every tunable value used by the notebook.

    Everything is stored as a plain class attribute, so settings can be read
    from the class or from the module-level ``config`` instance built below.
    """

    # --- Data ---
    DATASET_PATH: str = "./data"  # root directory passed to the CIFAR-10 dataset loader

    # --- Training ---
    BATCH_SIZE: int = 64
    EPOCHS: int = 20  # upper bound on epochs; early stopping may end training sooner
    LEARNING_RATE: float = 1e-3

    # --- Model architecture ---
    RES_BLOCK_CHANNELS: list = [32, 64]  # output channels of the residual blocks
    NUM_CLASSES: int = 10  # CIFAR-10 has 10 classes

    # --- Scheduler & early stopping ---
    SCHEDULER_PATIENCE: int = 5  # epochs without improvement before the LR is reduced
    SCHEDULER_FACTOR: float = 0.1  # multiplier applied to the LR on each reduction
    EARLY_STOPPING_PATIENCE: int = 7  # epochs without improvement before training stops

    # --- System ---
    DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"


config = Config()
print(f"Using device: {config.DEVICE}")

Using device: cuda


In [10]:
# --- 2. Data Loading and Preprocessing (Corrected) ---
# We define transformations for our data. The training split is augmented (random crop
# and horizontal flip) to make the model more robust. The validation and test splits
# are only converted to tensors and normalized.
class _TransformedSubset(torch.utils.data.Dataset):
    """Dataset wrapper that applies a transform to the samples of another dataset.

    It lets two subsets of the same un-transformed dataset use different
    transforms. The class lives at module level (not inside
    ``get_data_loaders``) because DataLoader worker processes started with the
    "spawn" method must pickle the dataset, and classes defined inside a
    function cannot be pickled.
    """

    def __init__(self, subset, transform=None):
        self.subset = subset        # any indexable dataset yielding (sample, label)
        self.transform = transform  # callable applied to the sample, or None

    def __getitem__(self, index):
        # The wrapped dataset returns the raw (sample, label) pair.
        x, y = self.subset[index]
        if self.transform is not None:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.subset)


def get_data_loaders(train_fraction=0.85, seed=42, num_workers=2):
    """
    Prepares and returns the CIFAR-10 data loaders for training, validation, and testing.

    The official training set is downloaded without transforms and split into
    train/validation subsets; each subset is then wrapped so that only the
    training subset receives augmentation.

    Args:
        train_fraction: Share of the official training set used for training;
            the remainder becomes the validation set. Must lie strictly
            between 0 and 1.
        seed: Seed of the generator that drives the train/validation split,
            so the split is reproducible.
        num_workers: Number of worker processes used by each DataLoader.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``. Only the training
        loader shuffles its data.

    Raises:
        ValueError: If ``train_fraction`` is not strictly between 0 and 1.
    """
    if not 0.0 < train_fraction < 1.0:
        raise ValueError(f"train_fraction must be in (0, 1), got {train_fraction!r}")

    # The same normalization is applied to every split.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    # Training data additionally gets random augmentation to introduce variability.
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    # Validation and test data are only converted to tensors and normalized.
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    # Load the full training set WITHOUT transforms so that each split can
    # receive its own transform afterwards.
    raw_train_dataset = torchvision.datasets.CIFAR10(
        root=config.DATASET_PATH, train=True, download=True, transform=None
    )

    # Reproducible train/validation split driven by a seeded generator.
    train_size = int(train_fraction * len(raw_train_dataset))
    val_size = len(raw_train_dataset) - train_size
    train_subset, val_subset = random_split(
        raw_train_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(seed),
    )

    # Attach the split-specific transforms through the wrapper.
    train_dataset = _TransformedSubset(train_subset, transform=train_transform)
    val_dataset = _TransformedSubset(val_subset, transform=test_transform)

    # The test set is never split, so it can take its transform directly.
    test_dataset = torchvision.datasets.CIFAR10(
        root=config.DATASET_PATH, train=False, download=True, transform=test_transform
    )

    train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE,
                              shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE,
                            shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=config.BATCH_SIZE,
                             shuffle=False, num_workers=num_workers)

    return train_loader, val_loader, test_loader

In [11]:
# --- 3. Model Architecture (Mini-ResNet) ---

class ResidualBlock(nn.Module):
    """
    Two-convolution residual unit.

    Main branch: conv3x3 -> BatchNorm -> ReLU -> conv3x3 -> BatchNorm.
    The block input is passed through a skip branch and added to the main
    branch result, then a final ReLU is applied. The skip branch is an empty
    ``nn.Sequential`` when the channel counts match, otherwise a 1x1
    convolution followed by BatchNorm that maps ``in_channels`` to
    ``out_channels``.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Settings shared by both 3x3 convolutions of the main branch.
        conv3x3_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=False)

        # Main branch
        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv3x3_kwargs)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **conv3x3_kwargs)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Skip branch: only needs layers when the channel count changes.
        if in_channels == out_channels:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        # Main branch, one layer per line.
        residual = self.conv1(x)
        residual = self.bn1(residual)
        residual = self.relu(residual)
        residual = self.conv2(residual)
        residual = self.bn2(residual)

        # Merge with the skip branch, then apply the closing activation.
        residual += self.shortcut(x)
        return self.relu(residual)

class MiniResNet(nn.Module):
    """
    Small ResNet-style classifier.

    Layout: a 3x3 convolutional stem, two residual blocks each followed by
    2x2 max pooling, global average pooling, and a linear layer producing
    ``num_classes`` outputs. Channel widths are read from
    ``config.RES_BLOCK_CHANNELS``.
    """
    def __init__(self, num_classes=10):
        super().__init__()

        first_width = config.RES_BLOCK_CHANNELS[0]
        second_width = config.RES_BLOCK_CHANNELS[1]

        # Stem: maps the 3 input channels to the first block's width.
        self.in_conv = nn.Sequential(
            nn.Conv2d(3, first_width, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(first_width),
            nn.ReLU(inplace=True),
        )

        # Stage 1 keeps the width; stage 2 widens to the second configured width.
        self.res_block1 = ResidualBlock(first_width, first_width)
        self.pool1 = nn.MaxPool2d(2)
        self.res_block2 = ResidualBlock(first_width, second_width)
        self.pool2 = nn.MaxPool2d(2)

        # Classification head
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(second_width, num_classes)

    def forward(self, x):
        # The network is a straight pipeline, so the stages are applied in order.
        pipeline = (
            self.in_conv,
            self.res_block1,
            self.pool1,
            self.res_block2,
            self.pool2,
            self.avg_pool,
            self.flatten,
            self.fc,
        )
        for stage in pipeline:
            x = stage(x)
        return x

In [12]:
# --- 4. Training and Evaluation Loops ---

def train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimization pass over ``loader``.

    Puts the model in training mode, performs a forward/backward/step cycle
    for every batch, and returns ``(mean_loss, accuracy)`` where both values
    are averaged over the number of samples seen (not the number of batches).
    """
    model.train()
    loss_sum = 0.0
    num_correct = 0
    num_seen = 0

    # tqdm wraps the loader to show a progress bar that disappears when done.
    batches = tqdm(loader, desc="Training", leave=False)
    for inputs, labels in batches:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Clear stale gradients, then forward / backward / update.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by batch size so the epoch average is per-sample.
        batch_loss = loss.item()
        loss_sum += batch_loss * inputs.size(0)
        num_seen += labels.size(0)
        predicted = outputs.detach().argmax(dim=1)
        num_correct += (predicted == labels).sum().item()

        batches.set_postfix(loss=batch_loss)

    return loss_sum / num_seen, num_correct / num_seen


def validate_one_epoch(model, loader, criterion, device):
    """Evaluate the model on every batch of ``loader`` without updating it.

    Puts the model in evaluation mode, disables gradient tracking, and
    returns ``(mean_loss, accuracy)`` averaged over the number of samples.
    """
    model.eval()
    loss_sum, num_correct, num_seen = 0.0, 0, 0

    with torch.no_grad():
        for inputs, labels in tqdm(loader, desc="Validating", leave=False):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            batch_loss = criterion(outputs, labels).item()

            # Weight by batch size so the final average is per-sample.
            loss_sum += batch_loss * inputs.size(0)
            num_seen += labels.size(0)
            predicted = outputs.argmax(dim=1)
            num_correct += (predicted == labels).sum().item()

    return loss_sum / num_seen, num_correct / num_seen

In [13]:
# --- 5. Main Execution ---

if __name__ == "__main__":
    # Build the train / validation / test loaders.
    train_loader, val_loader, test_loader = get_data_loaders()

    # Model, loss and optimizer.
    model = MiniResNet(num_classes=config.NUM_CLASSES).to(config.DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config.LEARNING_RATE)

    # Reduce the learning rate when the validation loss stops decreasing.
    # The `verbose` flag is deprecated in recent PyTorch releases, so learning
    # rate changes are reported explicitly inside the loop instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',  # the monitored quantity (validation loss) should go down
        patience=config.SCHEDULER_PATIENCE,
        factor=config.SCHEDULER_FACTOR,
    )

    # Early-stopping bookkeeping.
    best_val_loss = float('inf')
    epochs_no_improve = 0
    best_model_weights = None

    print("\n--- Starting Training ---")
    for epoch in range(config.EPOCHS):
        print(f"\nEpoch {epoch+1}/{config.EPOCHS}")

        train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, config.DEVICE)
        val_loss, val_acc = validate_one_epoch(model, val_loader, criterion, config.DEVICE)

        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}%")
        print(f"Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc*100:.2f}%")

        # Let the scheduler see the new validation loss and report any change.
        lr_before = optimizer.param_groups[0]["lr"]
        scheduler.step(val_loss)
        lr_after = optimizer.param_groups[0]["lr"]
        if lr_after != lr_before:
            print(f"Learning rate reduced from {lr_before:.2e} to {lr_after:.2e}")

        # Early stopping check
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            # state_dict() returns references to the live parameter tensors, so
            # it must be deep-copied; otherwise later optimizer steps would
            # overwrite the "best" snapshot in place.
            best_model_weights = copy.deepcopy(model.state_dict())
            print("Validation loss improved. Saving model.")
        else:
            epochs_no_improve += 1
            print(f"Validation loss did not improve. Counter: {epochs_no_improve}/{config.EARLY_STOPPING_PATIENCE}")

        if epochs_no_improve >= config.EARLY_STOPPING_PATIENCE:
            print("\nEarly stopping triggered!")
            break

    print("\n--- Training Finished ---")

    # Restore the best snapshot (if any epoch produced one) before testing.
    if best_model_weights is not None:
        print("\nLoading best model weights for final testing.")
        model.load_state_dict(best_model_weights)
    else:
        print("\nNo best model weights found, using the last model state.")

    # Final evaluation on the held-out test set.
    test_loss, test_acc = validate_one_epoch(model, test_loader, criterion, config.DEVICE)
    print("\n--- Final Test Results ---")
    print(f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc*100:.2f}%")




--- Starting Training ---

Epoch 1/20




Train Loss: 1.5045 | Train Acc: 45.51%
Val Loss:   1.5176 | Val Acc:   45.77%
Validation loss improved. Saving model.

Epoch 2/20




Train Loss: 1.1601 | Train Acc: 58.74%
Val Loss:   1.2316 | Val Acc:   57.91%
Validation loss improved. Saving model.

Epoch 3/20




Train Loss: 1.0165 | Train Acc: 64.22%
Val Loss:   1.2967 | Val Acc:   55.41%
Validation loss did not improve. Counter: 1/7

Epoch 4/20




Train Loss: 0.9313 | Train Acc: 67.27%
Val Loss:   0.9974 | Val Acc:   64.24%
Validation loss improved. Saving model.

Epoch 5/20




Train Loss: 0.8702 | Train Acc: 69.37%
Val Loss:   0.9082 | Val Acc:   68.44%
Validation loss improved. Saving model.

Epoch 6/20




Train Loss: 0.8234 | Train Acc: 71.43%
Val Loss:   1.0214 | Val Acc:   66.19%
Validation loss did not improve. Counter: 1/7

Epoch 7/20




Train Loss: 0.7820 | Train Acc: 72.85%
Val Loss:   0.7859 | Val Acc:   71.99%
Validation loss improved. Saving model.

Epoch 8/20




Train Loss: 0.7519 | Train Acc: 73.88%
Val Loss:   0.8638 | Val Acc:   69.84%
Validation loss did not improve. Counter: 1/7

Epoch 9/20




Train Loss: 0.7268 | Train Acc: 74.96%
Val Loss:   0.9173 | Val Acc:   69.03%
Validation loss did not improve. Counter: 2/7

Epoch 10/20




Train Loss: 0.7013 | Train Acc: 75.67%
Val Loss:   0.7990 | Val Acc:   72.19%
Validation loss did not improve. Counter: 3/7

Epoch 11/20




Train Loss: 0.6851 | Train Acc: 76.16%
Val Loss:   0.8136 | Val Acc:   72.48%
Validation loss did not improve. Counter: 4/7

Epoch 12/20




Train Loss: 0.6676 | Train Acc: 76.87%
Val Loss:   0.7429 | Val Acc:   74.24%
Validation loss improved. Saving model.

Epoch 13/20




Train Loss: 0.6497 | Train Acc: 77.51%
Val Loss:   1.1367 | Val Acc:   63.49%
Validation loss did not improve. Counter: 1/7

Epoch 14/20




Train Loss: 0.6357 | Train Acc: 78.01%
Val Loss:   0.9856 | Val Acc:   67.88%
Validation loss did not improve. Counter: 2/7

Epoch 15/20




Train Loss: 0.6201 | Train Acc: 78.61%
Val Loss:   0.7749 | Val Acc:   74.36%
Validation loss did not improve. Counter: 3/7

Epoch 16/20




Train Loss: 0.6124 | Train Acc: 78.72%
Val Loss:   0.6991 | Val Acc:   76.64%
Validation loss improved. Saving model.

Epoch 17/20




Train Loss: 0.5987 | Train Acc: 79.43%
Val Loss:   0.8401 | Val Acc:   72.23%
Validation loss did not improve. Counter: 1/7

Epoch 18/20




Train Loss: 0.5880 | Train Acc: 79.79%
Val Loss:   0.8798 | Val Acc:   71.96%
Validation loss did not improve. Counter: 2/7

Epoch 19/20




Train Loss: 0.5794 | Train Acc: 80.25%
Val Loss:   0.6539 | Val Acc:   77.24%
Validation loss improved. Saving model.

Epoch 20/20




Train Loss: 0.5690 | Train Acc: 80.42%
Val Loss:   0.6432 | Val Acc:   77.96%
Validation loss improved. Saving model.

--- Training Finished ---

Loading best model weights for final testing.


                                                             


--- Final Test Results ---
Test Loss: 0.6553 | Test Acc: 78.05%


