<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStardust/blob/main/_Advanced_One_Fine_Starstuff_V16.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import logging
import argparse
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import GradScaler, autocast
import numpy as np
import sys

# Configure the root logger once: INFO and above, formatted as
# "<timestamp> | <level> | <message>".
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(message)s",
    level=logging.INFO,
)

# Define an advanced CNN model
class AdvancedCNN(nn.Module):
    """Three conv-BN-ReLU-maxpool stages followed by a two-layer MLP head.

    Each 3x3 convolution (padding=1) keeps the spatial size and each 2x max-pool
    halves it, so a 32x32 input (CIFAR-10) reaches the classifier as a
    128 x 4 x 4 feature map. ``fc1`` is sized for exactly that, so inputs whose
    spatial size does not reduce to 4x4 after three poolings fail at ``fc1``.

    Args:
        num_classes: number of output logits (default 10, as for CIFAR-10).
        in_channels: number of input image channels (default 3, RGB).
    """

    def __init__(self, num_classes: int = 10, in_channels: int = 3):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.fc1 = nn.Linear(128 * 4 * 4, 256)  # 32x32 input -> 4x4 after three 2x pools
        self.fc2 = nn.Linear(256, num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes).

        Expects ``x`` shaped (batch, in_channels, H, W) with H and W reducing to
        4 after three 2x poolings (e.g. 32x32 CIFAR-10 images).
        """
        # Functional pooling is stateless; the original constructed a fresh
        # nn.MaxPool2d module on every call for the same result.
        x = nn.functional.max_pool2d(torch.relu(self.bn1(self.conv1(x))), 2)
        x = nn.functional.max_pool2d(torch.relu(self.bn2(self.conv2(x))), 2)
        x = nn.functional.max_pool2d(torch.relu(self.bn3(self.conv3(x))), 2)
        # Flatten all but the batch dim; unlike .view() this also works on
        # non-contiguous (e.g. channels_last) tensors.
        x = torch.flatten(x, 1)
        x = self.dropout(torch.relu(self.fc1(x)))
        return self.fc2(x)

# Define Label Smoothing Loss
class LabelSmoothingLoss(nn.Module):
    """Cross-entropy against label-smoothed targets, averaged over the batch.

    For each sample the loss is
    ``(1 - smoothing) * (-log p[target]) + smoothing * mean_k(-log p[k])``,
    which is cross-entropy with target distribution
    ``(1 - smoothing) * one_hot + smoothing / K`` over the K classes.
    """

    def __init__(self, smoothing=0.1):
        super().__init__()
        # Fraction of the target probability mass spread uniformly over all classes.
        self.smoothing = smoothing

    def forward(self, output, target):
        """Compute the mean smoothed loss.

        ``output`` holds (batch, K) logits; ``target`` holds (batch,) integer
        class indices (it is gathered along the last dim of ``output``).
        """
        log_p = nn.functional.log_softmax(output, dim=-1)
        picked = log_p.gather(-1, target.unsqueeze(1)).squeeze(1)
        target_term = -picked
        uniform_term = -log_p.mean(dim=-1)
        eps = self.smoothing
        return ((1.0 - eps) * target_term + eps * uniform_term).mean()

# Mixup data augmentation
def mixup_data(x, y, alpha=0.2):
    """Blend every sample of a batch with a randomly chosen partner (mixup).

    The mixing weight ``lam`` is drawn from Beta(alpha, alpha) using NumPy's
    global RNG. When ``alpha <= 0`` it is fixed at 1, so the output equals ``x``
    for finite inputs. Partners come from a ``torch.randperm`` of the batch.

    Returns:
        (mixed_x, y_a, y_b, lam), where ``mixed_x = lam * x + (1 - lam) * x[perm]``,
        ``y_a`` is ``y`` and ``y_b`` is ``y[perm]``. Callers weight the losses
        against ``y_a`` and ``y_b`` by ``lam`` and ``1 - lam``.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1
    perm = torch.randperm(x.size(0)).to(x.device)
    partners = x[perm, :]
    mixed_x = lam * x + (1 - lam) * partners
    return mixed_x, y, y[perm], lam

# Image preprocessing: random horizontal flip and padded random crop
# (augmentation), then conversion to a tensor and per-channel normalization
# with mean 0.5 / std 0.5, mapping [0, 1] values to [-1, 1].
transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

def setup_data_loaders(data_dir: str = './data', batch_size: int = 64, num_workers: int = 0) -> (DataLoader, DataLoader):
    """Create CIFAR-10 train and test loaders, downloading into ``data_dir`` if needed.

    The training split uses the module-level augmenting ``transform``. The test
    split uses only that pipeline's deterministic steps, so evaluation metrics
    are repeatable.

    Args:
        data_dir: root directory for the CIFAR-10 files.
        batch_size: batch size for both loaders.
        num_workers: DataLoader worker processes (0 = load in the main process,
            which is the DataLoader default and the previous behavior).

    Returns:
        (train_loader, test_loader). Only the training loader shuffles.
    """
    # Random augmentations must not be applied to the evaluation split: they
    # make validation loss/accuracy noisy and non-deterministic. Keep the
    # remaining (deterministic) steps of `transform` so normalization stays
    # identical to training.
    augmentations = (transforms.RandomHorizontalFlip, transforms.RandomCrop)
    eval_transform = transforms.Compose(
        [t for t in transform.transforms if not isinstance(t, augmentations)]
    )

    train_dataset = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=transform)
    test_dataset = torchvision.datasets.CIFAR10(root=data_dir, train=False, download=True, transform=eval_transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, test_loader

# Model, device, optimizer and AMP grad-scaler construction
def setup_model_device(criterion: nn.Module, learning_rate: float = 0.001) -> (nn.Module, torch.device, optim.Optimizer, GradScaler, torch.cuda.amp.GradScaler):
    """Build the model on the best available device, plus its optimizer and grad scaler.

    Uses CUDA when available, otherwise CPU. The optimizer is Adam with
    ``weight_decay=1e-5``. A ``GradScaler`` is created only when CUDA is
    available; otherwise ``None`` is returned in its place.

    NOTE(review): ``criterion`` is not used here, and the scaler is returned
    twice in the 5-tuple. Presumably both are kept so existing callers can
    unpack five values; confirm before changing either.
    """
    cuda_ok = torch.cuda.is_available()
    device = torch.device("cuda" if cuda_ok else "cpu")
    model = AdvancedCNN().to(device)
    optimizer = optim.Adam(
        model.parameters(),
        lr=learning_rate,
        weight_decay=1e-5,  # mild L2 regularization
    )
    scaler = GradScaler() if cuda_ok else None
    return model, device, optimizer, scaler, scaler

def setup_scheduler(optimizer: optim.Optimizer, num_warmup_epochs: int = 5, total_epochs: int = 20) -> optim.lr_scheduler:
    """Build a per-epoch LR schedule: linear warmup from 0.1x, then cosine annealing.

    For ``num_warmup_epochs`` steps the LR ramps linearly from 0.1x to 1x the
    optimizer's base LR. It then follows a cosine decay lasting
    ``total_epochs - num_warmup_epochs`` steps (at least 1). With
    ``num_warmup_epochs <= 0`` there is no warmup and a plain cosine schedule
    over ``total_epochs`` (at least 1) is returned.

    Args:
        optimizer: optimizer whose learning rate is scheduled.
        num_warmup_epochs: number of scheduler steps spent in linear warmup.
        total_epochs: total number of scheduler steps the schedule spans.

    Returns:
        A scheduler meant to be stepped once per epoch.
    """
    # T_max must be >= 1: with T_max == 0 (warmup >= total epochs), the cosine
    # schedule divides by zero as soon as SequentialLR switches to it.
    cosine_epochs = max(1, total_epochs - num_warmup_epochs)

    if num_warmup_epochs <= 0:
        # LinearLR with total_iters=0 would still pin the first epoch at 0.1x and
        # leave the cosine phase scaled down; with no warmup, use cosine alone.
        return optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cosine_epochs)

    warmup_scheduler = optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, total_iters=num_warmup_epochs)
    cosine_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cosine_epochs)
    return optim.lr_scheduler.SequentialLR(
        optimizer,
        schedulers=[warmup_scheduler, cosine_scheduler],
        milestones=[num_warmup_epochs],
    )

# Training loop with early stopping and checkpoints
def _train_one_epoch(model: nn.Module, train_loader: DataLoader, criterion: nn.Module,
                     optimizer: optim.Optimizer, device: torch.device, scaler) -> float:
    """Run one mixup training pass over ``train_loader`` and return the mean per-sample loss.

    Autocast (mixed precision) is enabled only when ``device`` is a CUDA device.
    When ``scaler`` is not None, backward and step go through it; otherwise a
    plain ``loss.backward()`` / ``optimizer.step()`` is used.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    model.train()
    use_amp = torch.device(device).type == "cuda"
    running_loss = 0.0
    seen = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()

        # Mixup: blend each image with a shuffled partner and weight the losses
        # against both label sets by the same coefficient.
        images, labels_a, labels_b, lam = mixup_data(images, labels)
        # With enabled=False autocast is disabled for the region, which is the
        # plain full-precision path, so one code path serves both cases.
        with autocast(enabled=use_amp):
            outputs = model(images)
            loss = lam * criterion(outputs, labels_a) + (1 - lam) * criterion(outputs, labels_b)

        if scaler is not None:
            # Scale the loss so small fp16 gradients do not underflow.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        batch_size = images.size(0)
        running_loss += loss.item() * batch_size
        seen += batch_size

    if seen == 0:
        raise ValueError("train_loader yielded no samples")
    return running_loss / seen


def train(model: nn.Module, train_loader: DataLoader, test_loader: DataLoader,
          criterion: nn.Module, optimizer: optim.Optimizer, scheduler: optim.lr_scheduler,
          device: torch.device, scaler, num_epochs: int = 20, patience: int = 5,
          checkpoint_path: str = "best_model.pth"):
    """Train with mixup, per-epoch LR scheduling, checkpointing and early stopping.

    After each epoch the model is evaluated on ``test_loader``. Whenever
    validation accuracy strictly improves, ``model.state_dict()`` is saved to
    ``checkpoint_path``. Training stops after ``patience`` consecutive epochs
    without improvement. Training/validation loss, validation accuracy and the
    learning rate are logged to TensorBoard, and the writer is closed even if
    training raises.

    Args:
        model, criterion, optimizer, device, scaler: see ``_train_one_epoch``.
        train_loader: training batches of (images, labels).
        test_loader: evaluation batches used for model selection.
        scheduler: LR scheduler stepped once per epoch.
        num_epochs: maximum number of epochs.
        patience: epochs without improvement before stopping early.
        checkpoint_path: where the best model's state_dict is written.
    """
    best_val_acc = 0.0
    epochs_without_improvement = 0
    writer = SummaryWriter()

    try:
        for epoch in range(num_epochs):
            epoch_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device, scaler)
            val_loss, val_acc = evaluate(model, test_loader, criterion, device)
            logging.info(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

            # Log to TensorBoard
            writer.add_scalar('Training Loss', epoch_loss, epoch)
            writer.add_scalar('Validation Loss', val_loss, epoch)
            writer.add_scalar('Validation Accuracy', val_acc, epoch)
            # LR in effect during this epoch (read before the scheduler advances it).
            writer.add_scalar('Learning Rate', optimizer.param_groups[0]['lr'], epoch)

            scheduler.step()  # per-epoch learning-rate schedule

            # Early stopping and model checkpoint
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(model.state_dict(), checkpoint_path)
                logging.info("Best model saved.")
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1
                if epochs_without_improvement >= patience:
                    logging.info("Early stopping triggered.")
                    break
    finally:
        # Flush and release TensorBoard event files even on error.
        writer.close()

# Evaluation function
def evaluate(model: nn.Module, data_loader: DataLoader, criterion: nn.Module, device: torch.device):
    """Compute the mean per-sample loss and the accuracy of ``model`` over ``data_loader``.

    Switches ``model`` to eval mode (and leaves it there) and runs without autograd.

    Args:
        model: classifier; its output is arg-maxed along dim 1, so (batch,
            num_classes) logits are expected.
        data_loader: any iterable of (inputs, labels) batches.
        criterion: loss function. It is assumed to return a per-batch mean,
            since it is re-weighted by batch size here.
        device: device each batch is moved to.

    Returns:
        (mean loss per sample, accuracy in percent), both normalized by the
        number of samples actually evaluated.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    loss_sum = 0.0

    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            batch_size = labels.size(0)
            # Weight the batch-mean loss by batch size so a short final batch counts correctly.
            loss_sum += criterion(outputs, labels).item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("evaluate() received a data_loader that yielded no samples")
    # Normalize loss and accuracy by the same count: samples actually seen, not
    # len(data_loader.dataset), which can differ (drop_last, custom samplers).
    return loss_sum / total, 100 * correct / total

# Argument parsing and execution
if __name__ == '__main__':
    if 'ipykernel' in sys.modules:
        # Running in a notebook environment
        args = argparse.Namespace(
            num_epochs=20,
            learning_rate=0.001,
            batch_size=64,
            patience=5
        )
    else:
        parser = argparse.ArgumentParser(description='Train and evaluate an advanced CNN on CIFAR-10.')
        parser.add_argument('--num_epochs', type=int, default=20, help='Number of epochs to train (default: 20)')
        parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate (default: 0.001)')
        parser.add_argument('--batch_size', type=int, default=64, help='Batch size (default: 64)')
        parser.add_argument('--patience', type=int, default=5, help='Early stopping patience (default: 5)')