In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import OneCycleLR, CosineAnnealingWarmRestarts
from sklearn.metrics import accuracy_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import time
import os
import numpy as np
import random

# ----------------------------
# Set Seed for Reproducibility
# ----------------------------
def set_seed(seed=42):
    """
    Seeds every random number generator this script draws from.

    Covers Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and all CUDA devices), and configures cuDNN for repeatable kernels.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every GPU explicitly rather than only the current device.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may select different kernels from run to run, which
    # would undo the deterministic flag above if something had enabled it.
    torch.backends.cudnn.benchmark = False

set_seed()

# ----------------------------
# Residual Block with Squeeze-and-Excitation (SE) and DropBlock
# ----------------------------
class ResidualBlock(nn.Module):
    """
    Two 3x3 conv + batch-norm stages whose output is gated by a
    Squeeze-and-Excitation module, added to a skip path, passed through ReLU
    and finally through DropBlock (or an identity when drop_prob is not > 0).
    """
    def __init__(self, in_channels, out_channels, stride=1, drop_prob=0.0):
        super().__init__()
        # Stage 1 carries the stride, so it is the only conv on the main path
        # that can change the spatial resolution.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        # Stage 2 keeps both resolution and width.
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.se = SqueezeExcitation(out_channels)

        # The skip path needs a 1x1 projection whenever the main path changes
        # the tensor shape; otherwise an empty Sequential passes x through.
        shape_changes = stride != 1 or in_channels != out_channels
        if shape_changes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
        else:
            self.shortcut = nn.Sequential()

        if drop_prob > 0:
            self.dropblock = DropBlock(drop_prob)
        else:
            self.dropblock = nn.Identity()

    def forward(self, x):
        # Main path: conv -> BN -> ReLU -> conv -> BN -> SE gate.
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.se(self.bn2(self.conv2(out)))

        # Residual sum, activation, then regularisation.
        out += self.shortcut(x)
        return self.dropblock(self.relu(out))

# ----------------------------
# Squeeze-and-Excitation Block
# ----------------------------
class SqueezeExcitation(nn.Module):
    """
    Channel gate: global-average-pools each channel, passes the result through
    a two-layer bottleneck MLP and a sigmoid, and rescales the input channels
    by the resulting per-channel factors.

    Args:
        channel: Number of channels of the feature map to gate.
        reduction: Bottleneck ratio. The hidden width is
            ``channel // reduction`` but never less than 1, so the gate stays
            trainable when ``channel < reduction``.
    """
    def __init__(self, channel, reduction=16):
        super().__init__()
        # Without the floor, channel < reduction yields zero-width Linear
        # layers and the gate collapses to a constant sigmoid(0) = 0.5.
        hidden = max(1, channel // reduction)
        self.fc1 = nn.Linear(channel, hidden, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(hidden, channel, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Returns x scaled channel-wise by gates in (0, 1); shape is unchanged."""
        b, c, _, _ = x.size()
        # Squeeze: one scalar per (sample, channel).
        y = F.adaptive_avg_pool2d(x, 1).view(b, c)
        # Excite: bottleneck MLP followed by a sigmoid gate.
        y = self.fc2(self.relu(self.fc1(y)))
        y = self.sigmoid(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

# ----------------------------
# DropBlock Regularization
# ----------------------------
class DropBlock(nn.Module):
    """
    DropBlock regularisation: during training, zeroes contiguous square
    regions of each feature map and rescales the surviving activations.

    Block origins are sampled independently per sample and channel, only at
    positions where a whole block fits inside the map. When the map is
    smaller than ``block_size`` the block is shrunk to the map size, so the
    layer stays well-defined at any resolution.

    Args:
        drop_prob: Target fraction of activations to drop.
        block_size: Side length of each dropped square.
    """
    def __init__(self, drop_prob=0.3, block_size=7):
        super().__init__()
        self.drop_prob = drop_prob
        self.block_size = block_size

    def forward(self, x):
        """Returns x unchanged in eval mode; otherwise x with blocks dropped."""
        if not self.training or self.drop_prob == 0.0:
            return x

        b, c, h, w = x.size()
        block = self._effective_block_size(h, w)
        gamma = self._compute_gamma(x)

        # Sample block origins on the grid of positions where a full block
        # fits; this is the grid the gamma formula assumes.
        seeds = (torch.rand((b, c, h - block + 1, w - block + 1),
                            device=x.device) < gamma).float()

        # Zero-padding by block - 1 and max-pooling with stride 1 spreads each
        # seed over the block x block window that starts at it, and brings the
        # mask back to (h, w).
        pad = block - 1
        dropped = F.max_pool2d(F.pad(seeds, (pad, pad, pad, pad)),
                               kernel_size=block, stride=1)
        keep = 1.0 - dropped

        # Rescale so the mean activation is preserved. The clamp avoids a
        # division by zero (and NaNs) when every position was dropped.
        scale = keep.numel() / keep.sum().clamp(min=1.0)
        return x * (keep * scale).to(x.dtype)

    def _effective_block_size(self, h, w):
        """Block side actually used: block_size clipped to fit an h x w map."""
        return max(1, min(self.block_size, h, w))

    def _compute_gamma(self, x):
        """
        Per-position seed probability that yields roughly drop_prob dropped
        activations: drop_prob * (h*w) / block^2 / (number of valid origins).
        """
        _, _, h, w = x.size()
        block = self._effective_block_size(h, w)
        valid_positions = (h - block + 1) * (w - block + 1)
        return self.drop_prob * (h * w) / (block ** 2) / valid_positions

# ----------------------------
# Global Feature Extractor
# ----------------------------
class GlobalFeatureExtractor(nn.Module):
    """
    Stack of residual blocks followed by global average pooling, producing one
    feature vector per sample.
    """
    def __init__(self, in_channels, out_channels, num_blocks=4, drop_prob=0.3):
        super().__init__()
        blocks = []
        width = in_channels
        for idx in range(num_blocks):
            # The first block keeps the resolution; every later block uses stride 2.
            block_stride = 2 if idx else 1
            blocks.append(ResidualBlock(width, out_channels, block_stride, drop_prob))
            width = out_channels
        self.layer = nn.Sequential(*blocks)
        self.pool = nn.AdaptiveAvgPool2d((1,1))

    def forward(self, x):
        pooled = self.pool(self.layer(x))
        # Collapse the 1x1 spatial grid: [B, out_channels]
        return pooled.view(pooled.size(0), -1)

# ----------------------------
# Local Feature Extractor
# ----------------------------
class LocalFeatureExtractor(nn.Module):
    """
    Stack of residual blocks followed by adaptive average pooling to a 2x2
    grid, so the output keeps a spatial layout.
    """
    def __init__(self, in_channels, out_channels, num_blocks=4, drop_prob=0.3):
        super().__init__()
        blocks = []
        width = in_channels
        for idx in range(num_blocks):
            # The first block keeps the resolution; every later block uses stride 2.
            block_stride = 2 if idx else 1
            blocks.append(ResidualBlock(width, out_channels, block_stride, drop_prob))
            width = out_channels
        self.layer = nn.Sequential(*blocks)
        self.pool = nn.AdaptiveAvgPool2d((2,2))  # fixed 2x2 output grid

    def forward(self, x):
        # [B, out_channels, 2, 2]
        return self.pool(self.layer(x))

# ----------------------------
# Enhanced Adaptive Attention Module with Multi-Head Attention
# ----------------------------
class EnhancedAdaptiveAttentionModule(nn.Module):
    """
    Produces two softmax-normalised scalars per sample - one weight for the
    global branch and one for the local branch - from the concatenated
    global and spatially averaged local features.
    """
    def __init__(self, global_dim, local_dim, num_heads=4):
        super().__init__()
        joint_dim = global_dim + local_dim
        # Projections from the joint descriptor into the attention space.
        self.query = nn.Linear(joint_dim, global_dim)
        self.key = nn.Linear(joint_dim, global_dim)
        self.value = nn.Linear(joint_dim, global_dim)
        self.multihead_attn = nn.MultiheadAttention(embed_dim=global_dim, num_heads=num_heads, batch_first=True)
        # Head that turns the attended vector into the two branch weights.
        self.fc = nn.Sequential(
            nn.Linear(global_dim, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 2),
            nn.Softmax(dim=1)
        )

    def forward(self, global_feat, local_feat):
        # global_feat: [B, global_dim]; local_feat: [B, local_dim, H, W]
        local_summary = local_feat.mean(dim=(2, 3))                # [B, local_dim]
        joint = torch.cat([global_feat, local_summary], dim=1)     # [B, global_dim + local_dim]

        # Each projection becomes a length-1 sequence: [B, 1, global_dim]
        q = self.query(joint).unsqueeze(1)
        k = self.key(joint).unsqueeze(1)
        v = self.value(joint).unsqueeze(1)

        attended, _ = self.multihead_attn(q, k, v)                 # [B, 1, global_dim]
        weights = self.fc(attended.squeeze(1))                     # [B, 2], rows sum to 1

        # Column 0 weights the global branch, column 1 the local branch.
        global_weight, local_weight = weights.split(1, dim=1)      # each [B, 1]
        return global_weight, local_weight

# ----------------------------
# Dynamic Holistic Perception Network
# ----------------------------
class DynamicHolisticPerceptionNetwork(nn.Module):
    """
    Classifier with a shared convolutional stem feeding a global and a local
    residual branch; an attention module weights the two branch descriptors
    before an MLP head produces the class logits.
    """
    def __init__(self, num_classes=10):
        super().__init__()

        def stem_stage(c_in, c_out):
            # conv -> BN -> ReLU -> 2x2 max-pool (halves height and width)
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
            ]

        # Stem: two stages, 3 -> 128 -> 128 channels, resolution divided by 4
        # (e.g. 32x32 inputs become 8x8).
        self.initial_conv = nn.Sequential(*stem_stage(3, 128), *stem_stage(128, 128))

        # Two residual branches over the same stem output.
        self.global_extractor = GlobalFeatureExtractor(in_channels=128, out_channels=256, num_blocks=4, drop_prob=0.3)  # -> [B,256]
        self.local_extractor = LocalFeatureExtractor(in_channels=128, out_channels=256, num_blocks=4, drop_prob=0.3)    # -> [B,256,2,2]

        # Produces one scalar weight per branch and sample.
        self.adaptive_attention = EnhancedAdaptiveAttentionModule(global_dim=256, local_dim=256, num_heads=4)

        # MLP head over the concatenated, weighted branch descriptors.
        self.fusion_fc = nn.Sequential(
            nn.Linear(256 + 256, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        stem_out = self.initial_conv(x)
        global_feat = self.global_extractor(stem_out)   # [B,256]
        local_feat = self.local_extractor(stem_out)     # [B,256,2,2]

        global_weight, local_weight = self.adaptive_attention(global_feat, local_feat)  # each [B,1]

        # Reduce the local grid to a vector so both branches are [B,256].
        batch = local_feat.size(0)
        local_vec = F.adaptive_avg_pool2d(local_feat, (1,1)).view(batch, -1)

        # Scale each branch by its attention weight and concatenate: [B,512]
        fused_feat = torch.cat(
            (global_feat * global_weight, local_vec * local_weight), dim=1)

        return self.fusion_fc(fused_feat)  # [B, num_classes]

# ----------------------------
# Mixup and CutMix Functions
# ----------------------------
def rand_bbox(size, lam):
    '''
    Samples a CutMix box over the last two axes of a 4-D shape.

    The box side along each axis is sqrt(1 - lam) times that axis length,
    centred on a uniformly drawn point and clipped to the tensor bounds.
    Returns (start2, start3, stop2, stop3), where 2 and 3 are indices into size.
    '''
    extent_a, extent_b = size[2], size[3]
    side_ratio = np.sqrt(1. - lam)
    half_a = int(extent_a * side_ratio) // 2
    half_b = int(extent_b * side_ratio) // 2

    # Box centre, drawn uniformly (axis 2 first, then axis 3).
    centre_a = np.random.randint(extent_a)
    centre_b = np.random.randint(extent_b)

    start_a = np.clip(centre_a - half_a, 0, extent_a)
    stop_a = np.clip(centre_a + half_a, 0, extent_a)
    start_b = np.clip(centre_b - half_b, 0, extent_b)
    stop_b = np.clip(centre_b + half_b, 0, extent_b)

    return start_a, start_b, stop_a, stop_b

def cutmix_data(x, y, alpha=1.0):
    '''
    Applies CutMix to a batch without modifying the caller's tensor.

    A random box (see rand_bbox) is copied from a shuffled version of the
    batch into a clone of x.

    Args:
        x: Batch tensor indexed as [sample, channel, axis2, axis3].
        y: Labels aligned with x along the first axis.
        alpha: Beta(alpha, alpha) parameter for the mixing ratio; a value
            <= 0 disables mixing (lam = 1, empty box).

    Returns:
        (mixed_x, y_a, y_b, lam): the mixed batch, the original labels, the
        labels of the pasted samples, and the fraction of each image that
        still belongs to y_a (recomputed from the clipped box area).
    '''
    lam = np.random.beta(alpha, alpha) if alpha > 0. else 1.

    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)

    bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam)

    # Work on a copy so the input batch is left intact, matching mixup_data,
    # which also returns a new tensor.
    mixed_x = x.clone()
    mixed_x[:, :, bbx1:bbx2, bby1:bby2] = x[index, :, bbx1:bbx2, bby1:bby2]

    # The box may have been clipped at the border, so derive lam from the
    # area actually replaced; return it as a plain Python float.
    replaced_area = (bbx2 - bbx1) * (bby2 - bby1)
    lam = float(1 - replaced_area / (x.size(-1) * x.size(-2)))
    return mixed_x, y, y[index], lam

def mixup_data(x, y, alpha=1.0):
    '''
    Applies MixUp: blends every sample with a randomly chosen partner.

    Returns (mixed_x, y_a, y_b, lam) where mixed_x = lam * x + (1 - lam) * x[perm],
    y_a are the original labels and y_b the partners' labels. lam is drawn
    from Beta(alpha, alpha), or fixed to 1 when alpha is not positive.
    '''
    lam = np.random.beta(alpha, alpha) if alpha > 0. else 1.

    # Random pairing of samples within the batch.
    perm = torch.randperm(x.size(0)).to(x.device)
    partners = x[perm, :]

    blended = lam * x + (1 - lam) * partners
    return blended, y, y[perm], lam

# ----------------------------
# Label Smoothing Cross Entropy Loss
# ----------------------------
class LabelSmoothingCrossEntropy(nn.Module):
    """
    Cross entropy against a smoothed target: the true class receives
    1 - smoothing and the remaining mass is split evenly over the other classes.
    """
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing

    def forward(self, x, target):
        # x: [B, num_classes] logits; target: [B] class indices.
        log_probs = F.log_softmax(x, dim=-1)

        # Start with the off-target share everywhere, then write the
        # confidence at each row's true class.
        off_target = self.smoothing / (x.size(1) - 1)
        true_dist = torch.full_like(log_probs, off_target)
        true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)

        per_sample = (-true_dist * log_probs).sum(dim=-1)
        return per_sample.mean()

# ----------------------------
# Evaluation Function
# ----------------------------
def evaluate_model(model, dataloader, device):
    """
    Evaluates the model on the given dataloader.

    Puts the model in eval mode (it is not switched back afterwards) and
    runs every batch without gradient tracking.

    Args:
        model: Network mapping an input batch to [B, num_classes] logits.
        dataloader: Iterable yielding (inputs, labels) batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        (accuracy in percent, confusion matrix, average cross-entropy loss
        per evaluated sample).
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    all_preds = []
    all_labels = []
    loss_sum = 0.0
    num_samples = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)  # [B, num_classes]
            batch = inputs.size(0)
            # criterion averages over the batch; weight by batch size so the
            # final mean is per sample even when the last batch is smaller.
            loss_sum += criterion(outputs, labels).item() * batch
            num_samples += batch
            _, preds = torch.max(outputs, 1)  # [B]
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    acc = accuracy_score(all_labels, all_preds) * 100
    cm = confusion_matrix(all_labels, all_preds)
    # Divide by the samples actually iterated: len(dataloader.dataset) is
    # wrong whenever the loader skips data (drop_last, a sampler, a subset).
    avg_loss = loss_sum / num_samples
    return acc, cm, avg_loss

# ----------------------------
# Visualization Functions
# ----------------------------
def plot_metrics(train_losses, val_accuracies, val_losses, train_accuracies):
    """
    Draws a two-panel figure: loss curves on the left and accuracy curves on
    the right, training in blue and validation in red, one point per epoch.
    """
    epochs = range(1, len(train_losses)+1)

    # (train series, validation series, train label, validation label, title, y-label)
    panels = (
        (train_losses, val_losses,
         'Training Loss', 'Validation Loss',
         'Training and Validation Loss', 'Loss'),
        (train_accuracies, val_accuracies,
         'Training Accuracy', 'Validation Accuracy',
         'Training and Validation Accuracy', 'Accuracy (%)'),
    )

    plt.figure(figsize=(14,5))
    for slot, (train_series, val_series, train_label, val_label, title, y_label) in enumerate(panels, start=1):
        plt.subplot(1,2,slot)
        plt.plot(epochs, train_series, 'b-', label=train_label)
        plt.plot(epochs, val_series, 'r-', label=val_label)
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()

    plt.show()

def plot_confusion_matrix(cm, classes):
    """
    Shows the confusion matrix as an annotated heatmap with the class names
    on both axes (rows: actual, columns: predicted).
    """
    heatmap_options = dict(
        annot=True,      # print the count in every cell
        fmt='d',         # counts are integers
        cmap='Blues',
        xticklabels=classes,
        yticklabels=classes,
    )
    plt.figure(figsize=(10,8))
    sns.heatmap(cm, **heatmap_options)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.show()

# ----------------------------
# Main Training and Evaluation Pipeline
# ----------------------------
def main():
    """
    Trains DynamicHolisticPerceptionNetwork on CIFAR-10 and reports the results.

    Steps: pick the device, build augmented train / plain test loaders
    (downloading the data to ./data), train with AdamW + OneCycleLR using
    CutMix or MixUp on every batch, evaluate after each epoch, checkpoint the
    best-scoring weights to 'best_model.pth', stop early after `patience`
    epochs without improvement, then reload the checkpoint, print the final
    accuracy and plot the confusion matrix and the training curves.

    NOTE(review): the per-epoch "validation" runs on the test loader, so the
    checkpoint choice and early stopping are tuned on the same data that is
    later reported as "Final Test Accuracy". A held-out split of the training
    set would give an unbiased number.
    """
    # ----------------------------
    # Device Configuration
    # ----------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # ----------------------------
    # Hyperparameters
    # ----------------------------
    num_epochs = 200  # Upper bound; early stopping may end training sooner
    batch_size = 128
    learning_rate = 0.001  # Also used as max_lr of the OneCycleLR schedule
    num_classes = 10  # For CIFAR-10
    patience = 30  # Epochs without validation improvement before stopping
    alpha = 1.0  # Beta(alpha, alpha) parameter for both MixUp and CutMix
    smoothing = 0.1  # Label smoothing

    # ----------------------------
    # Data Transformations with Advanced Augmentation
    # ----------------------------
    # Geometric and colour augmentations act on the PIL image; RandomErasing
    # comes last because it operates on the normalised tensor.
    # The Normalize constants are presumably CIFAR-10 per-channel mean/std - confirm.
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
        transforms.RandomErasing(p=0.1, scale=(0.02, 0.33), ratio=(0.3, 3.3), value='random')
    ])

    # Evaluation data only gets tensor conversion and the same normalisation.
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    # ----------------------------
    # Load Datasets
    # ----------------------------
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100,
                                             shuffle=False, num_workers=2)

    # ----------------------------
    # Initialize Model, Loss, Optimizer, Scheduler
    # ----------------------------
    model = DynamicHolisticPerceptionNetwork(num_classes=num_classes).to(device)
    print("Model initialized.")
    print(model)

    criterion = LabelSmoothingCrossEntropy(smoothing=smoothing)
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    # The schedule is sized for num_epochs * len(trainloader) optimizer steps,
    # so it must be stepped exactly once per batch (see the training loop).
    scheduler = OneCycleLR(optimizer, max_lr=learning_rate, steps_per_epoch=len(trainloader),
                           epochs=num_epochs, anneal_strategy='cos')

    # ----------------------------
    # Initialize Tracking Variables
    # ----------------------------
    best_val_acc = 0.0
    trigger_times = 0  # Consecutive epochs without a new best validation accuracy
    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []

    # ----------------------------
    # Class Names for CIFAR-10
    # ----------------------------
    classes = ['Airplane','Automobile','Bird','Cat','Deer',
               'Dog','Frog','Horse','Ship','Truck']

    # ----------------------------
    # Delete existing best_model.pth to avoid loading incompatible weights
    # ----------------------------
    # Without this, a run that never saves a checkpoint would reload a stale
    # file from an earlier run at the end.
    if os.path.exists('best_model.pth'):
        os.remove('best_model.pth')
        print("Existing 'best_model.pth' removed.")

    # ----------------------------
    # Training Loop with Mixup, CutMix, Label Smoothing, and Early Stopping
    # ----------------------------
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0  # Becomes a float: lam-weighted count of matches, see below
        total = 0
        start_time = time.time()

        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)

            # Every batch is mixed: CutMix with probability 0.5, MixUp otherwise.
            # Both return two label sets and the ratio lam used to blend the losses.
            r = np.random.rand()
            if r < 0.5:
                inputs, targets_a, targets_b, lam = cutmix_data(inputs, labels, alpha=alpha)
                outputs = model(inputs)
                loss = lam * criterion(outputs, targets_a) + (1 - lam) * criterion(outputs, targets_b)
            else:
                # Apply MixUp
                inputs, targets_a, targets_b, lam = mixup_data(inputs, labels, alpha=alpha)
                outputs = model(inputs)
                loss = lam * criterion(outputs, targets_a) + (1 - lam) * criterion(outputs, targets_b)

            # Clear gradients left over from the previous step
            optimizer.zero_grad()

            # Backward pass and optimization
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)  # Gradient Clipping
            optimizer.step()
            scheduler.step()  # Per-batch step, matching steps_per_epoch above

            # Statistics
            running_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            # Mixed samples have two labels, so a prediction is credited lam
            # for matching targets_a and (1 - lam) for matching targets_b.
            correct += (lam * predicted.eq(targets_a).sum().item() + (1 - lam) * predicted.eq(targets_b).sum().item())

        epoch_loss = running_loss / total
        epoch_acc = 100. * correct / total
        train_losses.append(epoch_loss)
        train_accuracies.append(epoch_acc)
        elapsed_time = time.time() - start_time  # Training time only; evaluation is not included

        # Validation
        # NOTE(review): this uses testloader, i.e. the test set doubles as the validation set.
        val_acc, val_cm, val_loss = evaluate_model(model, testloader, device)
        val_losses.append(val_loss)
        val_accuracies.append(val_acc)

        print(f"Epoch [{epoch+1}/{num_epochs}] - "
              f"Loss: {epoch_loss:.4f} - "
              f"Train Acc: {epoch_acc:.2f}% - "
              f"Val Acc: {val_acc:.2f}% - "
              f"Val Loss: {val_loss:.4f} - "
              f"Time: {elapsed_time:.2f}s")

        # Early Stopping and Saving Best Model
        # Only a strict improvement resets the patience counter.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            trigger_times = 0
            torch.save(model.state_dict(), 'best_model.pth')
            print(f"Best model saved with Val Acc: {best_val_acc:.2f}%")
        else:
            trigger_times += 1
            if trigger_times >= patience:
                print("Early stopping triggered!")
                break

    # ----------------------------
    # Final Evaluation
    # ----------------------------
    print("\nTraining Completed!")
    print(f"Best Validation Accuracy: {best_val_acc:.2f}%")

    # Load the best model
    # If loading fails or no checkpoint exists, the final evaluation below
    # falls back to the weights from the last epoch that ran.
    if os.path.exists('best_model.pth'):
        try:
            state_dict = torch.load('best_model.pth', map_location=device)
            model.load_state_dict(state_dict)
            print("Best model loaded successfully.")
        except Exception as e:
            print(f"Error loading best_model.pth: {e}")
            print("Please ensure that the model architecture matches exactly when saving and loading.")
    else:
        print("No best model found to load.")

    # Final evaluation (same loader as the per-epoch validation)
    final_acc, final_cm, final_loss = evaluate_model(model, testloader, device)
    print(f"\nFinal Test Accuracy: {final_acc:.2f}%")
    plot_confusion_matrix(final_cm, classes)

    # ----------------------------
    # Plot Training and Validation Metrics
    # ----------------------------
    plot_metrics(train_losses, val_accuracies, val_losses, train_accuracies)

if __name__ == '__main__':
    main()


Using device: cuda
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:03<00:00, 43.2MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Model initialized.
DynamicHolisticPerceptionNetwork(
  (initial_conv): Sequential(
    (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace=True)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (global_extractor): GlobalFeatureExtractor(
    (layer): Sequential(
      (0): ResidualBlock(
        (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, trac