In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, random_split

In [2]:
def get_transforms(augment=False):
    """Build the image preprocessing pipeline.

    Args:
        augment: When true, a random horizontal flip and a random 32-pixel
            crop (with 4 pixels of padding) are applied before conversion.

    Returns:
        A ``transforms.Compose`` that always ends with tensor conversion
        followed by normalization using mean 0.5 and std 0.5.
    """
    pipeline = []

    # Augmentation steps operate on the image before it becomes a tensor.
    if augment:
        pipeline.append(transforms.RandomHorizontalFlip())
        pipeline.append(transforms.RandomCrop(32, padding=4))

    # Steps shared by both the augmented and the plain pipeline.
    pipeline.append(transforms.ToTensor())
    pipeline.append(transforms.Normalize((0.5,), (0.5,)))

    return transforms.Compose(pipeline)

In [3]:
class CIFARCNN(nn.Module):
    """Small LeNet-style CNN for 3x32x32 images with 10 output classes.

    Layout: conv(5x5) -> ReLU -> [dropout] -> max-pool, twice, followed by
    three fully connected layers. With 32x32 inputs the spatial size goes
    32 -> 28 -> 14 -> 10 -> 5, which is what ``fc1`` (16*5*5 inputs) expects.

    Args:
        use_dropout: When true, ``nn.Dropout2d(p=0.3)`` is applied after each
            convolution's ReLU; otherwise those slots are ``nn.Identity``.
    """

    def __init__(self, use_dropout=False):
        super().__init__()

        # Dropout regularization (a no-op module when disabled, so forward()
        # does not need to branch on the flag).
        def dropout_layer(p=0.3):
            return nn.Dropout2d(p=p) if use_dropout else nn.Identity()

        # Conv layers
        # image size (N, RGB(3), 32, 32)
        self.conv1 = nn.Conv2d(3, 6, 5)  # (input channel, output channel, kernel size)
        self.dropout1 = dropout_layer()
        self.pool = nn.MaxPool2d(2, 2)  # kernel size=2, stride = 2
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.dropout2 = dropout_layer()

        # Fully connected layers
        self.fc1 = nn.Linear(16 * 5 * 5, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)  # 10 output for the CIFAR10 classes

    def forward(self, x):
        """Return class logits of shape (N, 10) for a batch of (N, 3, 32, 32) images."""
        # (N, 3, 32, 32) -> (N, 6, 28, 28) -> (N, 6, 14, 14)
        x = self.pool(self.dropout1(F.relu(self.conv1(x))))
        # (N, 6, 14, 14) -> (N, 16, 10, 10) -> (N, 16, 5, 5)
        # The second pooling step is required: without it the feature map has
        # 1600 values per sample and no longer matches fc1's 400 inputs.
        x = self.pool(self.dropout2(F.relu(self.conv2(x))))

        # Flatten everything except the batch dimension. Unlike view(-1, 400)
        # this can never fold extra features into the batch axis; a wrong
        # input size fails loudly in fc1 instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


In [4]:
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Make one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    The model is updated only when ``optimizer`` is given; otherwise the pass
    runs in eval mode with gradients disabled. Both metrics are averaged over
    the number of samples actually seen, so a smaller final batch (or
    ``drop_last=True``) does not skew them.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum, correct, seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            batch_size = images.size(0)
            # The loss is re-weighted by batch size before summing, which
            # assumes the criterion returns a per-batch mean.
            loss_sum += loss.item() * batch_size
            correct += (outputs.argmax(1) == labels).sum().item()
            seen += batch_size

    if seen == 0:
        # Empty loader: nothing to average over.
        return 0.0, 0.0
    return loss_sum / seen, correct / seen


def train_model(model,
                train_loader,
                val_loader,
                optimizer,
                criterion,
                device,
                num_epochs,
                early_stop=False):
    """Train ``model`` and validate it after every epoch.

    Args:
        model: Network to optimize (already placed on ``device``).
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        optimizer: Optimizer bound to ``model``'s parameters.
        criterion: Loss callable taking ``(outputs, labels)``.
        device: Device each batch is moved to.
        num_epochs: Maximum number of epochs to run.
        early_stop: When true, stop once the validation loss has failed to
            improve for 3 consecutive epochs.

    Returns:
        A dict with the keys ``train_loss``, ``train_acc``, ``val_loss`` and
        ``val_acc``, each a list with one per-sample average per completed
        epoch.
    """
    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}
    best_val_loss = float("inf")
    patience, patience_counter = 3, 0

    for epoch in range(num_epochs):
        train_loss, train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
        val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)

        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)

        print(f'Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

        if early_stop:
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print("Early stopping triggered!")
                    break

    return history

In [6]:
def visualize_test_data(images, labels, model, classes, fig_name="visualize_test"):
    """Plot up to 10 images in a 2x5 grid titled with predicted and true class.

    Args:
        images: Batch tensor of images in (N, C, H, W) layout, normalized with
            mean 0.5 / std 0.5 as done by ``get_transforms``.
        labels: Tensor of true class indices, one per image.
        model: Classifier used to produce the predictions.
        classes: Sequence mapping a class index to its display name.
        fig_name: Currently unused; kept so existing callers keep working.
    """
    # Run the forward pass on whichever device the model lives on, so CPU
    # batches from a DataLoader also work with a model that was moved to GPU.
    try:
        model_device = next(model.parameters()).device
    except StopIteration:
        model_device = images.device

    model.eval()
    with torch.no_grad():
        outputs = model(images.to(model_device))
        preds = outputs.argmax(1).cpu()

    fig, axes = plt.subplots(2, 5, figsize=(10, 5))
    axes = axes.flatten()
    # Batches with fewer than 10 images leave the remaining panels blank
    # instead of raising IndexError.
    num_shown = min(len(images), len(axes))
    for i, ax in enumerate(axes):
        ax.axis('off')
        if i >= num_shown:
            continue
        # Undo Normalize(0.5, 0.5) so pixel values are back in [0, 1];
        # otherwise imshow clips the negative half of the value range.
        img = (images[i].detach().cpu() * 0.5 + 0.5).clamp(0, 1)
        img = img.permute(1, 2, 0).numpy()
        ax.imshow(img, cmap='gray')
        ax.set_title(f'Pred: {classes[preds[i].item()]}, True: {classes[labels[i].item()]}')
    plt.tight_layout()
    plt.show()

In [None]:
def run_cases(case, device):
    """Run 5-fold cross-validation on CIFAR-10 for one regularization setup.

    Args:
        case: 1-based index of the configuration to run:
            1 baseline, 2 dropout, 3 weight decay, 4 early stopping,
            5 early stopping with data augmentation.
        device: Device the model and batches are placed on.

    Returns:
        A list with the test-set accuracy obtained after training each fold.

    Raises:
        IndexError: If ``case`` is outside ``1..5``.
    """
    cases = [
        {'dropout': False, 'decay': 0, 'early_stop': False, 'augment': False},
        {'dropout': True, 'decay': 0, 'early_stop': False, 'augment': False},
        {'dropout': False, 'decay': 0.0001, 'early_stop': False, 'augment': False},
        {'dropout': False, 'decay': 0, 'early_stop': True, 'augment': False},
        {'dropout': False, 'decay': 0, 'early_stop': True, 'augment': True},
    ]
    # Reject 0 and negatives explicitly: cases[case - 1] would otherwise wrap
    # around and silently run a different configuration.
    if not 1 <= case <= len(cases):
        raise IndexError(f"case must be between 1 and {len(cases)}, got {case}")
    config = cases[case - 1]
    print(f"Running Case {case}: {config}")

    # Augmentation belongs to the training data only. Validation folds and
    # the test set always use the deterministic pipeline so their metrics are
    # comparable across cases.
    train_transform = get_transforms(config['augment'])
    eval_transform = get_transforms(False)

    train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
    if config['augment']:
        # Same underlying images, but without random flips and crops.
        val_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=eval_transform)
    else:
        val_dataset = train_dataset
    test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=eval_transform)

    # The test loader does not depend on the fold, so build it once.
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

    fold_test_acc = []
    kfold = KFold(n_splits=5, shuffle=True, random_state=42)
    # Split over plain indices so scikit-learn never has to inspect the dataset.
    for fold, (train_idx, val_idx) in enumerate(kfold.split(np.arange(len(train_dataset)))):
        print(f"Fold {fold+1}")
        train_subset = torch.utils.data.Subset(train_dataset, train_idx.tolist())
        val_subset = torch.utils.data.Subset(val_dataset, val_idx.tolist())
        train_loader = DataLoader(train_subset, batch_size=64, shuffle=True)
        val_loader = DataLoader(val_subset, batch_size=64, shuffle=False)

        # Each fold starts from a freshly initialized model and optimizer.
        model = CIFARCNN(use_dropout=config['dropout']).to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=config['decay'])

        train_model(model, train_loader, val_loader, optimizer, criterion, device, 10, config['early_stop'])

        model.eval()
        correct = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                correct += (outputs.argmax(1) == labels).sum().item()
        test_acc = correct / len(test_loader.dataset)
        fold_test_acc.append(test_acc)
        print(f'Test Accuracy for Fold {fold+1}: {test_acc:.4f}')

    return fold_test_acc

In [None]:
if __name__ == "__main__":
    # Use the GPU when CUDA is available, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # Run every experiment configuration in order (cases are numbered 1-5).
    for case in (1, 2, 3, 4, 5):
        run_cases(case, device)
