<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/_Advanced_One_Fine_Starstuff_V16.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
import argparse
import logging
from tqdm import tqdm
import numpy as np
from torchvision.datasets import CIFAR10
from torch.amp import GradScaler, autocast
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

# Configure root logging: INFO level and above, each record rendered as
# "<timestamp> - <message>".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# Dataset wrapper that randomly applies Mixup or CutMix to the samples it serves
class MixupCutMixDataset(torch.utils.data.Dataset):
    """Wrap a dataset and randomly blend each sample with a second one.

    Each access has a 50% chance (when ``alpha > 0``) of combining the
    requested sample with another, uniformly drawn sample, using either Mixup
    (pixel-wise interpolation) or CutMix (pasting a rectangular patch).

    Args:
        dataset: Indexable dataset yielding ``(img, label)`` pairs. ``img`` is
            used as a 3-axis tensor whose first axis is left untouched by the
            CutMix slicing (assumed to be the channel axis); ``label`` is a
            class index.
        alpha: Parameter of the symmetric ``Beta(alpha, alpha)`` distribution
            the mixing coefficient is drawn from. ``alpha <= 0`` disables
            mixing entirely.
        use_mixup: ``True`` selects Mixup, ``False`` selects CutMix.
        num_classes: Optional number of classes. When ``None`` (default) the
            returned label is a hard class index: the label of whichever
            sample contributes the larger share of the result. When given,
            every returned label is a float32 vector of that length holding
            the mixing weights of the two classes (a one-hot vector for
            unmixed samples), suitable for losses that accept class
            probabilities.
    """

    def __init__(self, dataset, alpha=1.0, use_mixup=True, num_classes=None):
        self.dataset = dataset
        self.alpha = alpha
        self.use_mixup = use_mixup
        self.num_classes = num_classes

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(img, label)`` for ``idx``, possibly mixed with another sample."""
        img, label = self.dataset[idx]
        # Defaults describe the "no mixing" case: the sample is 100% itself.
        lam = 1.0
        label2 = label

        if self.alpha > 0 and np.random.rand() < 0.5:
            idx2 = np.random.randint(0, len(self.dataset))
            img2, label2 = self.dataset[idx2]
            # Plain Python float keeps the image dtype unchanged when scaling.
            lam = float(np.random.beta(self.alpha, self.alpha))
            if self.use_mixup:
                img = lam * img + (1 - lam) * img2
            else:
                bbx1, bby1, bbx2, bby2 = self.rand_bbox(img.size(), lam)
                # Clone first: the wrapped dataset may hand out views of its
                # own storage, and an in-place paste would corrupt it.
                img = img.clone()
                img[:, bbx1:bbx2, bby1:bby2] = img2[:, bbx1:bbx2, bby1:bby2]
                # The box is clipped at the borders, so recompute the share of
                # the first image from the area that was actually replaced.
                patch_area = (bbx2 - bbx1) * (bby2 - bby1)
                lam = 1.0 - patch_area / (img.size(1) * img.size(2))

        return img, self._mix_labels(label, label2, lam)

    def _mix_labels(self, label, label2, lam):
        """Combine two class indices given the first sample's share ``lam``."""
        if self.num_classes is None:
            # Class indices are categorical, so they must never be averaged;
            # keep the label of the dominant sample instead.
            return label if lam >= 0.5 else label2
        soft = torch.zeros(self.num_classes, dtype=torch.float32)
        soft[int(label)] += lam
        soft[int(label2)] += 1.0 - lam
        return soft

    def rand_bbox(self, size, lam):
        """Draw a random box covering roughly ``1 - lam`` of the image area.

        Args:
            size: Image size; entries 1 and 2 are the extents of the two axes
                the box is defined on.
            lam: Share of the image that should remain untouched.

        Returns:
            ``(bbx1, bby1, bbx2, bby2)`` as Python ints, where ``bbx*`` bound
            axis 1 and ``bby*`` bound axis 2, clipped to the image extents.
        """
        extent1 = size[1]
        extent2 = size[2]
        # Clamp so a coefficient slightly above 1 cannot produce NaN.
        cut_rat = np.sqrt(max(0.0, 1. - lam))
        cut1 = int(extent1 * cut_rat)
        cut2 = int(extent2 * cut_rat)
        center1 = np.random.randint(extent1)
        center2 = np.random.randint(extent2)
        bbx1 = int(np.clip(center1 - cut1 // 2, 0, extent1))
        bby1 = int(np.clip(center2 - cut2 // 2, 0, extent2))
        bbx2 = int(np.clip(center1 + cut1 // 2, 0, extent1))
        bby2 = int(np.clip(center2 + cut2 // 2, 0, extent2))
        return bbx1, bby1, bbx2, bby2

# Define the CNN model (note: no Adaptive Gradient Clipping is implemented in this script)
class SimpleCNN(nn.Module):
    """Small CNN: two conv+ReLU+max-pool stages followed by two linear layers.

    The first linear layer expects ``128 * 8 * 8`` features, i.e. a
    ``(128, 8, 8)`` feature map after the two pooling stages, which is what a
    3-channel 32x32 input produces. The network outputs 10 logits per sample.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(128 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, 10)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Return class logits for ``x``.

        Args:
            x: Image tensor with 3 channels whose spatial size reduces to 8x8
                after two 2x2 poolings (e.g. 32x32).

        Returns:
            Tensor of shape ``(num_samples, 10)``.

        Raises:
            ValueError: If the convolutional feature map is not
                ``(128, 8, 8)`` per sample.
        """
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        # Without this check, view(-1, ...) would silently fold a spatial-size
        # mismatch into the batch dimension and return the wrong number of rows.
        if tuple(x.shape[-3:]) != (128, 8, 8):
            raise ValueError(
                f"SimpleCNN expects a (128, 8, 8) feature map per sample "
                f"(e.g. from 3x32x32 inputs), got {tuple(x.shape[-3:])}"
            )
        x = x.view(-1, 128 * 8 * 8)
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

def train(model, train_loader, test_loader, criterion, optimizer, scheduler, num_epochs=20, use_amp=True):
    """Train ``model`` and checkpoint the weights with the best test accuracy.

    Runs on CUDA when available, otherwise on CPU. After every epoch the
    scheduler is stepped, the model is evaluated on ``test_loader`` and, if the
    accuracy improved, its ``state_dict`` is written to ``best_model.pth``.

    Args:
        model: Network to optimise; moved to the selected device.
        train_loader: Iterable of ``(images, labels)`` training batches.
        test_loader: Iterable of ``(images, labels)`` evaluation batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer updating ``model``'s parameters.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        num_epochs: Number of passes over ``train_loader``.
        use_amp: Request CUDA automatic mixed precision. Ignored (full
            precision is used) when training runs on CPU.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # The scaler/autocast pair below targets CUDA, so only enable AMP there.
    amp_enabled = use_amp and device.type == "cuda"
    scaler = GradScaler("cuda") if amp_enabled else None
    best_acc = 0.0

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        num_batches = 0

        for images, labels in tqdm(train_loader):
            images = images.to(device)
            if labels.dim() > 1 and labels.is_floating_point():
                # Per-class probability targets (soft labels) must stay float;
                # casting them to long would truncate the weights.
                labels = labels.to(device, dtype=torch.float32)
            else:
                # Class-index targets must be of Long type for the loss.
                labels = labels.to(device, dtype=torch.long)
            optimizer.zero_grad()

            with autocast("cuda", enabled=amp_enabled):
                outputs = model(images)
                loss = criterion(outputs, labels)

            if amp_enabled:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        scheduler.step()
        acc = evaluate(model, test_loader, device)
        # Guard against an empty loader so the report does not divide by zero.
        mean_loss = running_loss / max(num_batches, 1)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {mean_loss:.4f}, Accuracy: {acc:.2f}%")

        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), 'best_model.pth')

def evaluate(model, test_loader, device):
    """Return the top-1 accuracy of ``model`` on ``test_loader`` in percent.

    The model is switched to evaluation mode and gradients are disabled for
    the duration of the pass.

    Args:
        model: Network mapping a batch of images to per-class scores.
        test_loader: Iterable of ``(images, labels)`` batches, where labels are
            class indices.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Accuracy in the range ``[0, 100]``; ``0.0`` if the loader yields no
        samples.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device, dtype=torch.long)  # Ensure labels are of Long type
            outputs = model(images)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        # Nothing to score: report 0% instead of dividing by zero.
        return 0.0
    return 100 * correct / total

# Argument parsing for training configurations
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Advanced CIFAR-10 Training')
    parser.add_argument('--num_epochs', type=int, default=30)
    parser.add_argument('--learning_rate', type=float, default=0.001)
    parser.add_argument('--batch_size', type=int, default=64)
    # AMP is on by default. --use_amp is kept so existing invocations still
    # parse; --no_amp is the switch that actually allows turning it off.
    parser.add_argument('--use_amp', dest='use_amp', action='store_true', default=True)
    parser.add_argument('--no_amp', dest='use_amp', action='store_false',
                        help='Disable automatic mixed precision')
    parser.add_argument('--alpha', type=float, default=1.0, help='Mixup/CutMix alpha')
    # Restricting the choices makes a typo fail at parse time instead of
    # leaving `optimizer` undefined further down.
    parser.add_argument('--optimizer', type=str, default='AdamW', choices=['AdamW', 'SGD'],
                        help='Optimizer (AdamW or SGD)')
    # parse_known_args ignores unrecognised argv entries instead of exiting.
    args, unknown = parser.parse_known_args()

    # Data pipeline: augmentation (flip + padded crop) is applied to the
    # training split only; both splits share the same normalisation constants.
    train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
    ])

    train_dataset = MixupCutMixDataset(CIFAR10(root='./data', train=True, download=True, transform=train_transform), alpha=args.alpha, use_mixup=True)
    test_dataset = CIFAR10(root='./data', train=False, download=True, transform=test_transform)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

    # Model, criterion, optimizer, and scheduler setup
    model = SimpleCNN()
    criterion = nn.CrossEntropyLoss()

    if args.optimizer == 'AdamW':
        optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=5e-4)
    else:  # 'SGD' (guaranteed by the argparse choices above)
        optimizer = optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=5e-4)

    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

    logging.info("Starting advanced training with Mixup/CutMix, AMP, and CosineAnnealingWarmRestarts...")
    train(model, train_loader, test_loader, criterion, optimizer, scheduler, num_epochs=args.num_epochs, use_amp=args.use_amp)
    logging.info("Training completed. Best model saved as 'best_model.pth'.")