<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStardust/blob/main/_Advanced_One_Fine_Starstuff_V17.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import logging
import argparse
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import GradScaler, autocast

# Setup logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")

# Define an advanced CNN model
class AdvancedCNN(nn.Module):
    """Three conv/BN/ReLU/max-pool stages followed by a two-layer classifier head.

    The ``fc1`` input size (128 * 4 * 4) assumes 3-channel 32x32 inputs: the
    three 2x2 max-pools reduce the spatial size 32 -> 16 -> 8 -> 4.

    Args:
        num_classes: number of output logits (default 10, as for CIFAR-10).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.fc1 = nn.Linear(128 * 4 * 4, 256)
        self.fc2 = nn.Linear(256, num_classes)
        self.dropout = nn.Dropout(0.5)
        # One shared pooling module instead of constructing a new one per
        # forward call. It has no parameters or buffers, so it adds nothing to
        # state_dict and checkpoints saved by the previous version still load.
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        """Return unnormalised class logits of shape (batch, num_classes)."""
        x = self.pool(torch.relu(self.bn1(self.conv1(x))))
        x = self.pool(torch.relu(self.bn2(self.conv2(x))))
        x = self.pool(torch.relu(self.bn3(self.conv3(x))))
        x = torch.flatten(x, 1)  # keep the batch dimension, flatten the rest
        x = self.dropout(torch.relu(self.fc1(x)))
        return self.fc2(x)

# Data transformations. Random augmentation is applied to training data only;
# evaluation must be deterministic so validation metrics are comparable
# between epochs.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
    transforms.ToTensor(),
    # Per-channel (x - 0.5) / 0.5: maps ToTensor's [0, 1] range to [-1, 1].
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Same tensor conversion and normalization as training, without any random ops.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Backward-compatible alias: `transform` previously named the training pipeline.
transform = train_transform

# Datasets and loaders
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Model, criterion, optimizer, and scheduler
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AdvancedCNN().to(device)
# Label smoothing: targets are mixed with a uniform distribution (amount 0.1).
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)  # With weight decay

# LR schedule: linear warm-up (10% -> 100% of the base LR), then cosine
# annealing for the remaining epochs. T_max is derived from named constants
# rather than a hard-coded 15 so both phases stay consistent if one changes.
# NOTE: DEFAULT_NUM_EPOCHS mirrors the default of the --num_epochs CLI flag.
WARMUP_EPOCHS = 5
DEFAULT_NUM_EPOCHS = 20
scheduler = optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[
        optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, total_iters=WARMUP_EPOCHS),
        optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=DEFAULT_NUM_EPOCHS - WARMUP_EPOCHS),
    ],
    milestones=[WARMUP_EPOCHS],
)

# Mixed-precision gradient scaler. Enabled only on CUDA; on CPU the CUDA
# scaler would disable itself anyway, but with a runtime warning.
scaler = GradScaler(enabled=device.type == "cuda")

# TensorBoard writer (no log_dir given, so SummaryWriter uses its default).
writer = SummaryWriter()

# Training loop with early stopping and checkpoints
def train(model, train_loader, test_loader, criterion, optimizer, scheduler, num_epochs=20, patience=5,
          checkpoint_path="best_model.pth"):
    """Train ``model`` with mixed precision, per-epoch validation, checkpointing and early stopping.

    Uses the module-level ``device``, ``scaler`` (GradScaler), ``writer``
    (SummaryWriter) and ``evaluate``.

    Args:
        model: network trained in place.
        train_loader: iterable of ``(images, labels)`` training batches.
        test_loader: loader passed to ``evaluate`` after every epoch.
        criterion: loss applied to ``(outputs, labels)``.
        optimizer: optimizer stepped through ``scaler``.
        scheduler: LR scheduler stepped once per epoch.
        num_epochs: maximum number of epochs to run.
        patience: stop after this many consecutive epochs without a new best
            validation accuracy.
        checkpoint_path: file the best ``state_dict`` is saved to.

    Returns:
        The best validation accuracy (percent) observed, or 0.0 if no epoch ran.
    """
    best_val_acc = 0.0
    epochs_without_improvement = 0
    # Autocast only on CUDA. The deprecated torch.cuda.amp.autocast() used
    # before just disabled itself (with a warning) when running on CPU.
    use_amp = device.type == "cuda"

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        samples_seen = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()

            # Mixed-precision forward pass
            with torch.autocast(device_type=device.type, enabled=use_amp):
                outputs = model(images)
                loss = criterion(outputs, labels)

            # The scaler scales the loss before backward and unscales before
            # the optimizer step, skipping the step if gradients overflowed.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Assumes criterion returns a batch mean (reduction='mean', the
            # CrossEntropyLoss default); re-weight to get a per-sample sum.
            running_loss += loss.item() * images.size(0)
            samples_seen += images.size(0)

        # Average over the samples actually trained on (correct even with
        # drop_last or a subset sampler, unlike len(train_loader.dataset)).
        epoch_loss = running_loss / max(samples_seen, 1)
        val_loss, val_acc = evaluate(model, test_loader, criterion)
        current_lr = optimizer.param_groups[0]['lr']  # LR used during this epoch
        logging.info(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        # Log to TensorBoard
        writer.add_scalar('Training Loss', epoch_loss, epoch)
        writer.add_scalar('Validation Loss', val_loss, epoch)
        writer.add_scalar('Validation Accuracy', val_acc, epoch)
        writer.add_scalar('Learning Rate', current_lr, epoch)

        scheduler.step()  # Per-epoch learning-rate update

        # Early stopping and model checkpoint
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)
            logging.info("Best model saved.")
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                logging.info("Early stopping triggered.")
                break

    return best_val_acc

def evaluate(model, data_loader, criterion):
    model.eval()
    correct = 0
    total = 0
    val_loss = 0.0

    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_acc = 100 * correct / total
    return val_loss / len(data_loader.dataset), val_acc

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train and evaluate an advanced CNN on CIFAR-10.')
    parser.add_argument('--num_epochs', type=int, default=20, help='Number of epochs to train (default: 20)')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate (default: 0.001)')
    parser.add_argument('--batch_size', type=int, default=64, help='Batch size (default: 64)')
    parser.add_argument('--patience', type=int, default=5, help='Early stopping patience (default: 5)')
    # parse_known_args ignores unrecognised argv entries instead of failing.
    args, _ = parser.parse_known_args()

    # The module-level loaders, optimizer and scheduler were built before the
    # CLI was parsed, so --batch_size / --learning_rate / --num_epochs would
    # otherwise have no effect. Rebuild them from the parsed arguments; with
    # the default arguments they match the module-level configuration.
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=1e-5)
    warmup_epochs = 5
    scheduler = optim.lr_scheduler.SequentialLR(
        optimizer,
        schedulers=[
            optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, total_iters=warmup_epochs),
            # max(1, ...) avoids a zero period when num_epochs <= warmup_epochs.
            optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(1, args.num_epochs - warmup_epochs)),
        ],
        milestones=[warmup_epochs],
    )

    logging.info("Starting training...")
    try:
        train(model, train_loader, test_loader, criterion, optimizer, scheduler, args.num_epochs, args.patience)
        logging.info("Training completed. Best model saved as 'best_model.pth'.")
    finally:
        # Flush and close TensorBoard logs even if training raises.
        writer.close()