# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import os
import numpy as np
import random

def set_seed(seed):
    """Seed every random number generator this script relies on.

    Args:
        seed: Integer applied to Python's ``random`` module, NumPy's global
            RNG, and PyTorch's CPU and (all) CUDA generators.

    Also switches cuDNN to deterministic algorithms and turns off its
    autotuner so repeated runs pick the same kernels.
    """
    # cuDNN settings first: request deterministic kernels and disable
    # benchmark-mode autotuning (which may choose different algorithms per run).
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

class StandardAlexNet(nn.Module):
    """AlexNet-style convolutional classifier sized for 32x32 RGB images.

    Five 3x3 convolutions (padding 1, so spatial size is preserved) with
    ReLU, interleaved with three 2x2 max-pools that shrink a 32x32 input to
    4x4. The classifier therefore consumes ``256 * 4 * 4`` features.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        def conv_relu(in_channels, out_channels):
            # 3x3 "same" convolution followed by an in-place ReLU.
            return [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
            ]

        def halve():
            return nn.MaxPool2d(kernel_size=2, stride=2)

        # Layer order (and hence the Sequential indices / state_dict keys)
        # is: conv-relu-pool, conv-relu-pool, conv-relu, conv-relu, conv-relu-pool.
        self.features = nn.Sequential(
            *conv_relu(3, 64), halve(),
            *conv_relu(64, 192), halve(),
            *conv_relu(192, 384),
            *conv_relu(384, 256),
            *conv_relu(256, 256), halve(),
        )

        hidden_units = 4096
        flat_features = 256 * 4 * 4
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(flat_features, hidden_units),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(hidden_units, num_classes),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialize conv and linear layers, zeroing every bias."""
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, 0, 0.01)
                nn.init.constant_(layer.bias, 0)
                continue
            if not isinstance(layer, nn.Conv2d):
                continue
            # He initialization scaled by fan-out, matched to the ReLU activations.
            nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, num_classes)."""
        feature_maps = self.features(x)
        return self.classifier(torch.flatten(feature_maps, 1))

class Trainer:
    """Train and evaluate ``StandardAlexNet`` on CIFAR-10.

    Construction builds the data pipeline (train/val/test loaders), the model,
    an SGD optimizer and a StepLR scheduler from ``config``. Call ``train()``
    to run the training loop (best and final checkpoints are written to
    ``config.output_path``) and ``evaluate_model()`` for test-set metrics.

    ``config`` must provide: ``seed``, ``input_path``, ``output_path``,
    ``num_workers``, ``batch_size``, ``val_split``, ``learning_rate``,
    ``momentum``, ``weight_decay``, ``num_epochs``. Optional:
    ``lr_step_size`` (default 30) and ``lr_gamma`` (default 0.1).
    """

    BEST_CHECKPOINT = "alexnet_cifar10_best.pt"
    FINAL_CHECKPOINT = "alexnet_cifar10_final.pt"

    # Channel-wise (R, G, B) statistics passed to Normalize; these are the
    # commonly cited CIFAR-10 training-set mean/std.
    CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
    CIFAR10_STD = (0.2470, 0.2435, 0.2616)

    def __init__(self, config):
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Data transformations with basic augmentations
        self.setup_data_transforms()

        # Load and split dataset
        self.load_dataset()

        # Create model, optimizer, and loss function
        self.model = StandardAlexNet().to(self.device)
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.SGD(
            self.model.parameters(),
            lr=config.learning_rate,
            momentum=config.momentum,
            weight_decay=config.weight_decay,
        )

        # Multiply the learning rate by `lr_gamma` every `lr_step_size` epochs.
        self.scheduler = optim.lr_scheduler.StepLR(
            self.optimizer,
            step_size=getattr(config, "lr_step_size", 30),
            gamma=getattr(config, "lr_gamma", 0.1),
        )

    def setup_data_transforms(self):
        """Build ``transform_train`` (random crop + flip) and ``transform_test``."""
        normalize = transforms.Normalize(self.CIFAR10_MEAN, self.CIFAR10_STD)

        # Basic data augmentation for training
        self.transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

        # Deterministic transform for validation and test
        self.transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

    def _build_cifar10(self, root, download):
        """Return (augmented train set, un-augmented train set, test set) under ``root``.

        The training split is instantiated twice so that the validation
        subset can use the deterministic test transform.
        """
        train_augmented = torchvision.datasets.CIFAR10(
            root=root, train=True, download=download, transform=self.transform_train
        )
        # Files are guaranteed present once the call above succeeded.
        train_plain = torchvision.datasets.CIFAR10(
            root=root, train=True, download=False, transform=self.transform_test
        )
        testset = torchvision.datasets.CIFAR10(
            root=root, train=False, download=download, transform=self.transform_test
        )
        return train_augmented, train_plain, testset

    def load_dataset(self):
        """Create ``trainloader``, ``valloader``, ``testloader`` and ``classes``.

        Uses ``<input_path>/cifar-10`` if it exists, otherwise downloads into
        ``output_path``; any loading error falls back to downloading.
        """
        try:
            dataset_path = os.path.join(self.config.input_path, "cifar-10")
            if os.path.exists(dataset_path):
                print(f"Loading CIFAR-10 from dataset at {dataset_path}")
                datasets = self._build_cifar10(dataset_path, download=False)
            else:
                print("CIFAR-10 dataset not found, downloading to working directory")
                datasets = self._build_cifar10(self.config.output_path, download=True)
        except Exception as e:
            print(f"Error loading dataset: {e}")
            print("Falling back to downloading the dataset")
            datasets = self._build_cifar10(self.config.output_path, download=True)
        full_trainset, full_trainset_plain, testset = datasets

        val_size = int(len(full_trainset) * self.config.val_split)
        train_size = len(full_trainset) - val_size

        # Seeded generator => the same train/val partition on every run.
        trainset, val_augmented = random_split(
            full_trainset, [train_size, val_size],
            generator=torch.Generator().manual_seed(self.config.seed),
        )
        # Point the validation indices at the un-augmented copy so validation
        # metrics (and best-checkpoint selection) are not computed on
        # randomly cropped/flipped images.
        valset = torch.utils.data.Subset(full_trainset_plain, val_augmented.indices)

        # Pinned host memory only helps host->GPU copies.
        pin_memory = self.device.type == "cuda"

        def make_loader(dataset, shuffle):
            return DataLoader(
                dataset, batch_size=self.config.batch_size, shuffle=shuffle,
                num_workers=self.config.num_workers, pin_memory=pin_memory,
            )

        self.trainloader = make_loader(trainset, shuffle=True)
        self.valloader = make_loader(valset, shuffle=False)
        self.testloader = make_loader(testset, shuffle=False)

        self.classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    @staticmethod
    def _finalize_metrics(loss_sum, correct, total, split):
        """Turn accumulated sums into (mean per-sample loss, accuracy %).

        Raises:
            ValueError: if no samples were seen (empty loader).
        """
        if total == 0:
            raise ValueError(f"Cannot compute {split} metrics: the data loader yielded no samples")
        return loss_sum / total, 100.0 * correct / total

    def train_epoch(self, epoch):
        """Run one optimization pass over ``self.trainloader``.

        Args:
            epoch: Zero-based epoch index, used only for progress logging.

        Returns:
            (mean per-sample training loss, training accuracy in percent).
        """
        self.model.train()
        loss_sum = 0.0
        correct = 0
        total = 0
        num_batches = len(self.trainloader)

        for batch_idx, (inputs, targets) in enumerate(self.trainloader):
            inputs, targets = inputs.to(self.device), targets.to(self.device)

            # Forward pass
            outputs = self.model(inputs)
            loss = self.criterion(outputs, targets)

            # Backward and optimize
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # The criterion returns the batch mean; weight it by batch size so
            # a smaller final batch does not skew the epoch average.
            batch_size = targets.size(0)
            loss_sum += loss.item() * batch_size
            total += batch_size
            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()

            # Print progress
            if batch_idx % 100 == 0:
                print(f"Epoch: {epoch+1} | Batch: {batch_idx}/{num_batches} | "
                      f"Loss: {loss_sum/total:.4f} | "
                      f"Acc: {100.*correct/total:.2f}%")

        return self._finalize_metrics(loss_sum, correct, total, "training")

    def evaluate(self, dataloader):
        """Compute (mean per-sample loss, accuracy %) of the model on ``dataloader``.

        Runs in eval mode without gradient tracking; the model is left in
        eval mode.

        Raises:
            ValueError: if ``dataloader`` yields no samples.
        """
        self.model.eval()
        loss_sum = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, targets in dataloader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)

                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)

                batch_size = targets.size(0)
                loss_sum += loss.item() * batch_size
                total += batch_size
                _, predicted = outputs.max(1)
                correct += predicted.eq(targets).sum().item()

        return self._finalize_metrics(loss_sum, correct, total, "evaluation")

    def save_checkpoint(self, epoch, val_loss, val_acc, filename):
        """Write model/optimizer/scheduler state plus validation metrics to ``filename``."""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'val_loss': val_loss,
            'val_acc': val_acc,
        }
        torch.save(checkpoint, filename)

    def train(self):
        """Train for ``config.num_epochs`` epochs, checkpointing the best model.

        The checkpoint with the lowest validation loss is saved as
        ``BEST_CHECKPOINT`` and the last epoch as ``FINAL_CHECKPOINT``, both in
        ``config.output_path``.

        Returns:
            (best validation loss, validation accuracy of the final epoch).

        Raises:
            ValueError: if ``config.num_epochs`` is less than 1.
        """
        num_epochs = self.config.num_epochs
        if num_epochs < 1:
            raise ValueError(f"num_epochs must be at least 1, got {num_epochs}")

        print(f"Starting training for {num_epochs} epochs")
        best_val_loss = float('inf')
        best_path = os.path.join(self.config.output_path, self.BEST_CHECKPOINT)

        for epoch in range(num_epochs):
            train_loss, train_acc = self.train_epoch(epoch)

            val_loss, val_acc = self.evaluate(self.valloader)

            self.scheduler.step()

            print(f"Epoch {epoch+1}/{num_epochs}")
            print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
            print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                print(f"Saving best model with validation loss: {val_loss:.4f}")
                self.save_checkpoint(epoch, val_loss, val_acc, best_path)

        self.save_checkpoint(
            epoch, val_loss, val_acc,
            os.path.join(self.config.output_path, self.FINAL_CHECKPOINT),
        )

        return best_val_loss, val_acc

    def evaluate_model(self):
        """Load the best checkpoint and return test accuracy in percent.

        Falls back to the final checkpoint when no best checkpoint exists
        (no best checkpoint is written if validation loss never dropped
        below infinity, e.g. when it is NaN).
        """
        checkpoint_path = os.path.join(self.config.output_path, self.BEST_CHECKPOINT)
        if not os.path.exists(checkpoint_path):
            fallback_path = os.path.join(self.config.output_path, self.FINAL_CHECKPOINT)
            print(f"No best checkpoint at {checkpoint_path}; using {fallback_path}")
            checkpoint_path = fallback_path

        # map_location lets a GPU-trained checkpoint load on a CPU-only host;
        # weights_only=True restricts unpickling to tensors/containers/primitives.
        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=True)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        print(f"Loaded best model from epoch {checkpoint['epoch']+1}")

        test_loss, test_acc = self.evaluate(self.testloader)
        print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.2f}%")

        return test_acc

class Config:
    """Hyperparameters and filesystem locations for one training run.

    Every field has a default; keyword arguments override individual fields,
    e.g. ``Config(num_epochs=5, batch_size=64)``. Unknown keywords raise
    ``TypeError`` so a misspelled option is not silently ignored.
    """

    def __init__(self, **overrides):
        self.seed = 42
        self.input_path = "./data"  # a "cifar-10" folder here is used if present
        self.output_path = "./output"  # download location and checkpoint directory
        self.num_workers = 4
        self.batch_size = 128
        self.val_split = 0.1  # 10% of training data for validation
        self.learning_rate = 0.01
        self.weight_decay = 5e-4
        self.momentum = 0.9
        self.num_epochs = 60

        for name, value in overrides.items():
            if not hasattr(self, name):
                raise TypeError(f"Config() got an unexpected keyword argument {name!r}")
            setattr(self, name, value)

def main():
    """Run the pipeline end to end.

    Steps: build the default config, create the output directory, seed all
    RNGs, train, then report the test accuracy of the best checkpoint.
    """
    cfg = Config()
    os.makedirs(cfg.output_path, exist_ok=True)
    set_seed(cfg.seed)

    runner = Trainer(cfg)
    runner.train()
    print(f"Final test accuracy: {runner.evaluate_model():.2f}%")


if __name__ == '__main__':
    main()

# --- Captured notebook output (kept for reference; not executable) ---
# Using device: cuda
# CIFAR-10 dataset not found, downloading to working directory
# Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./output/cifar-10-python.tar.gz
#
#
# 100%|██████████| 170M/170M [00:04<00:00, 35.1MB/s] 
#
#
# Extracting ./output/cifar-10-python.tar.gz to ./output
# Files already downloaded and verified
# Starting training for 60 epochs
# Epoch: 1 | Batch: 0/352 | Loss: 2.3201 | Acc: 8.59%
# Epoch: 1 | Batch: 100/352 | Loss: 2.1198 | Acc: 20.87%
# Epoch: 1 | Batch: 200/352 | Loss: 1.9748 | Acc: 26.32%
# Epoch: 1 | Batch: 300/352 | Loss: 1.8725 | Acc: 30.39%
# Epoch 1/60
# Train Loss: 1.8294, Train Acc: 31.92%
# Val Loss: 1.5636, Val Acc: 42.00%
# Saving best model with validation loss: 1.5636
# Epoch: 2 | Batch: 0/352 | Loss: 1.5084 | Acc: 46.09%
# Epoch: 2 | Batch: 100/352 | Loss: 1.5258 | Acc: 43.50%
# Epoch: 2 | Batch: 200/352 | Loss: 1.4827 | Acc: 45.16%
# Epoch: 2 | Batch: 300/352 | Loss: 1.4433 | Acc: 46.73%
# Epoch 2/60
# Train Loss: 1.4280, Train Acc: 47.50%
# Val Loss: 1.3389, Val Acc: 51.22%
# Saving best model with validation loss: 1.3389
# Epoch: 3 | Batch: 0/352 | Loss: 1.3097 | Acc: 46.09%
# Epoch: 3 | Batch: 100/352 | Loss: 1.2965 | Acc: 52.65%
# Epoch: 3 | Batch: 200/352 | Loss: 1.2574 | Acc: 54.40%
# Epoch: 3 | Batch:
# [... output truncated in this export ...]
#
#   checkpoint = torch.load(checkpoint_path)
#
#
# Test Loss: 0.3928, Test Accuracy: 88.12%
# Final test accuracy: 88.12%
