In [15]:
import torch
import torchvision.models as models
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import random_split
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from torchvision.datasets import CIFAR100
import torch.nn as nn
from tqdm import tqdm

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [14]:
# Define transformations for CIFAR-100
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),  # Stronger augmentation
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
])

# Load the full training dataset with training transformations
full_train_dataset = CIFAR100(root='./data', train=True, download=True, transform=transform_train)

# Create a validation split from the training data
val_size = int(0.1 * len(full_train_dataset))  # 10% for validation
train_size = len(full_train_dataset) - val_size

# Use random_split to create the training and validation datasets
train_dataset, val_dataset = random_split(
    full_train_dataset, 
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42)  # For reproducibility
)

# For the validation set, we need to replace the transform
# Create a copy of the validation dataset with test transforms
val_dataset = CIFAR100(root='./data', train=True, download=False, transform=transform_test)

# Only use the validation indices
val_indices = val_dataset.indices if hasattr(val_dataset, 'indices') else range(len(full_train_dataset))[train_size:]
val_dataset = torch.utils.data.Subset(val_dataset, val_indices)

test_dataset = CIFAR100(root='./data', train=False, download=True, transform=transform_test)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4, pin_memory=True)

In [7]:
# Initialize ResNet18 model for CIFAR100 (100 classes)
model = models.resnet18(weights=None)  # Start from scratch for a true baseline

model.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)

# Adjust final fully connected layer for 100 classes
model.fc = torch.nn.Linear(model.fc.in_features, 100)
model = model.to(device)



In [8]:
# Updated improved_pgd_attack function
def improved_pgd_attack(model, images, labels, target_labels=None, epsilon=4/255, alpha=1/255, 
                       iters=3, random_start=True, untargeted=False):
    """
    Enhanced PGD attack with better performance and stability.
    Works with any input tensor shape.
    """
    device = next(model.parameters()).device
    images = images.clone().detach().to(device)
    labels = labels.to(device)
    
    if target_labels is not None:
        target_labels = target_labels.to(device)
    
    # Loss function with per-sample losses
    criterion = nn.CrossEntropyLoss(reduction='none')
    
    # Initialize adversarial images
    adv_images = images.clone().detach()
    
    # Start with random noise within epsilon ball if requested
    if random_start:
        noise = torch.FloatTensor(images.shape).uniform_(-epsilon, epsilon).to(device)
        adv_images = adv_images + noise
        adv_images = torch.clamp(adv_images, 0, 1)
    
    best_adv_images = adv_images.clone()
    best_loss = None  # Will be a tensor of shape (batch_size,)
    
    for i in range(iters):
        adv_images.requires_grad = True
        
        outputs = model(adv_images)
        
        # Compute per-sample loss
        if untargeted:
            # For untargeted, maximize loss on true labels
            loss_per_sample = criterion(outputs, labels)  # Shape: (batch_size,)
            loss = -loss_per_sample.mean()  # Scalar for gradient computation
        else:
            if target_labels is None:
                raise ValueError("Target labels must be provided for targeted attack")
            # For targeted, minimize loss on target labels
            loss_per_sample = criterion(outputs, target_labels)
            loss = loss_per_sample.mean()
        
        # Check for NaN or inf in loss
        if torch.isnan(loss) or torch.isinf(loss):
            print(f"Warning: Invalid loss detected at iteration {i}. Returning last valid images.")
            return best_adv_images.detach()
        
        # Clear gradients to prevent accumulation
        model.zero_grad()
        loss.backward()
        
        # Gradient step
        with torch.no_grad():
            grad = adv_images.grad.sign()
            if untargeted:
                adv_images = adv_images + alpha * grad  # Increase loss
            else:
                adv_images = adv_images - alpha * grad  # Decrease loss
            delta = torch.clamp(adv_images - images, min=-epsilon, max=epsilon)
            adv_images = torch.clamp(images + delta, min=0, max=1)
            
            # Check for NaN or inf in adv_images
            if torch.isnan(adv_images).any() or torch.isinf(adv_images).any():
                print(f"Warning: Invalid adv_images detected at iteration {i}. Returning last valid images.")
                return best_adv_images.detach()
        
        # Track best adversarial images
        with torch.no_grad():
            _, predicted = outputs.max(1)
            if untargeted:
                success_mask = (predicted != labels)
                # Maximize loss_per_sample
                if best_loss is None:
                    best_loss = loss_per_sample.clone()
                    best_adv_images = adv_images.clone()
                else:
                    improved = loss_per_sample > best_loss  # Want larger loss
                    update_indices = torch.nonzero(improved & success_mask).view(-1)
                    if len(update_indices) > 0:
                        best_adv_images[update_indices] = adv_images[update_indices]
                        best_loss[update_indices] = loss_per_sample[update_indices]
            else:
                success_mask = (predicted == target_labels)
                # Minimize loss_per_sample
                if best_loss is None:
                    best_loss = loss_per_sample.clone()
                    best_adv_images = adv_images.clone()
                else:
                    improved = loss_per_sample < best_loss  # Want smaller loss
                    update_indices = torch.nonzero(improved & success_mask).view(-1)
                    if len(update_indices) > 0:
                        best_adv_images[update_indices] = adv_images[update_indices]
                        best_loss[update_indices] = loss_per_sample[update_indices]
    
    return best_adv_images.detach()

def improved_evaluate_attack(model, testloader, epsilon=20/255, alpha=8/255, iters=100, subset_size=None):
    """
    Evaluates the model on clean and adversarial examples with improved attack
    """
    model.eval()
    clean_correct = 0
    adv_correct = 0
    total = 0
    
    # Handle subset processing
    if subset_size:
        try:
            subset_loader = torch.utils.data.DataLoader(
                testset, batch_size=128, shuffle=True, num_workers=4
            )
        except NameError:
            subset_loader = testloader
        max_batches = subset_size // 128 + 1
    else:
        subset_loader = testloader
        max_batches = len(testloader)
    
    for i, (images, labels) in enumerate(subset_loader):
        if i >= max_batches:
            break
            
        images, labels = images.to(device), labels.to(device)
        total += labels.size(0)
        
        # Clean accuracy
        with torch.no_grad():
            outputs = model(images)
            _, predicted = outputs.max(1)
            clean_correct += predicted.eq(labels).sum().item()
        
        # Generate adversarial examples using untargeted attack (more effective)
        adv_images = improved_pgd_attack(
            model, 
            images, 
            labels,
            target_labels=None, 
            epsilon=epsilon, 
            alpha=alpha, 
            iters=iters,
            random_start=True,
            untargeted=True  # Untargeted attacks are typically more successful
        )
        
        # Evaluate adversarial accuracy
        with torch.no_grad():
            adv_outputs = model(adv_images)
            _, adv_predicted = adv_outputs.max(1)
            adv_success = (adv_predicted != labels).sum().item()  # Count misclassifications
            adv_correct += adv_success
    
    clean_acc = 100. * clean_correct / total
    attack_success = 100. * adv_correct / total
    print(f"Clean Accuracy: {clean_acc:.2f}%")
    print(f"Attack Success Rate: {attack_success:.2f}%")
    return clean_acc, attack_success

In [9]:
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()

# SGD with momentum typically works better for adversarial training
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

# Better learning rate scheduler
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)


### Pretrain on Clean Data

In [10]:
# Training loop with tqdm
num_epochs = 20
best_acc = 0

for epoch in range(num_epochs):
    # Training phase
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0
    
    train_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]')
    for inputs, labels in train_bar:
        inputs, labels = inputs.cuda(), labels.cuda()
        
        # Zero the parameter gradients
        optimizer.zero_grad(set_to_none=True)
        
        # Forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        # Statistics
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        
        # Update progress bar
        train_bar.set_postfix({
            'loss': f'{train_loss/len(train_bar):.4f}',
            'acc': f'{100.*correct/total:.2f}%',
            'lr': f'{scheduler.get_last_lr()[0]:.6f}'
        })
    
    train_loss = train_loss / len(train_loader)
    train_acc = 100. * correct / total
    
    # Validation phase
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    
    test_bar = tqdm(test_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Test]')
    with torch.no_grad():
        for inputs, labels in test_bar:
            inputs, labels = inputs.cuda(), labels.cuda()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
            # Update progress bar
            test_bar.set_postfix({
                'loss': f'{test_loss/len(test_bar):.4f}',
                'acc': f'{100.*correct/total:.2f}%'
            })
    
    test_loss = test_loss / len(test_loader)
    test_acc = 100. * correct / total
    
    # Update learning rate
    scheduler.step()
    
    # Save checkpoint if it's the best model so far
    if test_acc > best_acc:
        best_acc = test_acc
        checkpoint = {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'best_acc': best_acc,
        }
        torch.save(checkpoint, 'best_resnet18_cifar100_untargeted_adv.pth')
        print(f'Checkpoint saved! New best accuracy: {best_acc:.2f}%')
    
    # Print epoch summary
    print(f'Epoch: {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}%')

print(f"Training completed! Best accuracy: {best_acc:.2f}%")

Epoch 1/20 [Train]: 100%|██████████| 391/391 [00:17<00:00, 22.89it/s, loss=4.2746, acc=5.40%, lr=0.100000]
Epoch 1/20 [Test]: 100%|██████████| 79/79 [00:01<00:00, 52.07it/s, loss=3.8234, acc=11.09%]


Checkpoint saved! New best accuracy: 11.09%
Epoch: 1/20 | Train Loss: 4.2746 | Train Acc: 5.40% | Test Loss: 3.8234 | Test Acc: 11.09%


Epoch 2/20 [Train]: 100%|██████████| 391/391 [00:16<00:00, 23.86it/s, loss=3.7627, acc=12.01%, lr=0.099994]
Epoch 2/20 [Test]: 100%|██████████| 79/79 [00:01<00:00, 52.39it/s, loss=3.2825, acc=19.82%]


Checkpoint saved! New best accuracy: 19.82%
Epoch: 2/20 | Train Loss: 3.7627 | Train Acc: 12.01% | Test Loss: 3.2825 | Test Acc: 19.82%


Epoch 3/20 [Train]: 100%|██████████| 391/391 [00:16<00:00, 23.90it/s, loss=3.4113, acc=18.00%, lr=0.099975]
Epoch 3/20 [Test]: 100%|██████████| 79/79 [00:01<00:00, 53.36it/s, loss=2.8535, acc=27.55%]


Checkpoint saved! New best accuracy: 27.55%
Epoch: 3/20 | Train Loss: 3.4113 | Train Acc: 18.00% | Test Loss: 2.8535 | Test Acc: 27.55%


Epoch 4/20 [Train]: 100%|██████████| 391/391 [00:16<00:00, 23.92it/s, loss=3.1465, acc=22.94%, lr=0.099944]
Epoch 4/20 [Test]: 100%|██████████| 79/79 [00:01<00:00, 52.83it/s, loss=2.7944, acc=29.00%]


Checkpoint saved! New best accuracy: 29.00%
Epoch: 4/20 | Train Loss: 3.1465 | Train Acc: 22.94% | Test Loss: 2.7944 | Test Acc: 29.00%


Epoch 5/20 [Train]: 100%|██████████| 391/391 [00:16<00:00, 23.74it/s, loss=2.9714, acc=25.98%, lr=0.099901]
Epoch 5/20 [Test]: 100%|██████████| 79/79 [00:01<00:00, 53.38it/s, loss=2.7871, acc=30.39%]


Checkpoint saved! New best accuracy: 30.39%
Epoch: 5/20 | Train Loss: 2.9714 | Train Acc: 25.98% | Test Loss: 2.7871 | Test Acc: 30.39%


Epoch 6/20 [Train]: 100%|██████████| 391/391 [00:16<00:00, 23.83it/s, loss=2.8145, acc=29.07%, lr=0.099846]
Epoch 6/20 [Test]: 100%|██████████| 79/79 [00:01<00:00, 52.76it/s, loss=2.5405, acc=34.18%]


Checkpoint saved! New best accuracy: 34.18%
Epoch: 6/20 | Train Loss: 2.8145 | Train Acc: 29.07% | Test Loss: 2.5405 | Test Acc: 34.18%


Epoch 7/20 [Train]: 100%|██████████| 391/391 [00:16<00:00, 23.97it/s, loss=2.7182, acc=31.32%, lr=0.099778]
Epoch 7/20 [Test]: 100%|██████████| 79/79 [00:01<00:00, 53.13it/s, loss=2.6800, acc=32.72%]


Epoch: 7/20 | Train Loss: 2.7182 | Train Acc: 31.32% | Test Loss: 2.6800 | Test Acc: 32.72%


Epoch 8/20 [Train]: 100%|██████████| 391/391 [00:16<00:00, 23.84it/s, loss=2.6112, acc=33.43%, lr=0.099698]
Epoch 8/20 [Test]: 100%|██████████| 79/79 [00:01<00:00, 53.28it/s, loss=2.4166, acc=37.68%]


Checkpoint saved! New best accuracy: 37.68%
Epoch: 8/20 | Train Loss: 2.6112 | Train Acc: 33.43% | Test Loss: 2.4166 | Test Acc: 37.68%


Epoch 9/20 [Train]: 100%|██████████| 391/391 [00:16<00:00, 23.86it/s, loss=2.5389, acc=35.10%, lr=0.099606]
Epoch 9/20 [Test]: 100%|██████████| 79/79 [00:01<00:00, 52.62it/s, loss=2.2593, acc=40.78%]


Checkpoint saved! New best accuracy: 40.78%
Epoch: 9/20 | Train Loss: 2.5389 | Train Acc: 35.10% | Test Loss: 2.2593 | Test Acc: 40.78%


Epoch 10/20 [Train]: 100%|██████████| 391/391 [00:16<00:00, 23.91it/s, loss=2.4755, acc=36.50%, lr=0.099501]
Epoch 10/20 [Test]: 100%|██████████| 79/79 [00:01<00:00, 52.84it/s, loss=2.4373, acc=37.53%]


Epoch: 10/20 | Train Loss: 2.4755 | Train Acc: 36.50% | Test Loss: 2.4373 | Test Acc: 37.53%


Epoch 11/20 [Train]: 100%|██████████| 391/391 [00:16<00:00, 23.93it/s, loss=2.4263, acc=37.15%, lr=0.099384]
Epoch 11/20 [Test]: 100%|██████████| 79/79 [00:01<00:00, 53.15it/s, loss=2.1466, acc=42.94%]


Checkpoint saved! New best accuracy: 42.94%
Epoch: 11/20 | Train Loss: 2.4263 | Train Acc: 37.15% | Test Loss: 2.1466 | Test Acc: 42.94%


Epoch 12/20 [Train]: 100%|██████████| 391/391 [00:16<00:00, 23.87it/s, loss=2.3680, acc=38.48%, lr=0.099255]
Epoch 12/20 [Test]: 100%|██████████| 79/79 [00:01<00:00, 53.46it/s, loss=2.0424, acc=45.10%]


Checkpoint saved! New best accuracy: 45.10%
Epoch: 12/20 | Train Loss: 2.3680 | Train Acc: 38.48% | Test Loss: 2.0424 | Test Acc: 45.10%


Epoch 13/20 [Train]: 100%|██████████| 391/391 [00:16<00:00, 23.98it/s, loss=2.3270, acc=39.81%, lr=0.099114]
Epoch 13/20 [Test]: 100%|██████████| 79/79 [00:01<00:00, 53.23it/s, loss=2.4313, acc=38.83%]


Epoch: 13/20 | Train Loss: 2.3270 | Train Acc: 39.81% | Test Loss: 2.4313 | Test Acc: 38.83%


Epoch 14/20 [Train]: 100%|██████████| 391/391 [00:16<00:00, 23.96it/s, loss=2.2965, acc=40.43%, lr=0.098961]
Epoch 14/20 [Test]: 100%|██████████| 79/79 [00:01<00:00, 51.95it/s, loss=2.2776, acc=41.22%]


Epoch: 14/20 | Train Loss: 2.2965 | Train Acc: 40.43% | Test Loss: 2.2776 | Test Acc: 41.22%


Epoch 15/20 [Train]: 100%|██████████| 391/391 [00:16<00:00, 23.76it/s, loss=2.2624, acc=40.92%, lr=0.098796]
Epoch 15/20 [Test]: 100%|██████████| 79/79 [00:01<00:00, 52.43it/s, loss=2.0078, acc=46.44%]


Checkpoint saved! New best accuracy: 46.44%
Epoch: 15/20 | Train Loss: 2.2624 | Train Acc: 40.92% | Test Loss: 2.0078 | Test Acc: 46.44%


Epoch 16/20 [Train]: 100%|██████████| 391/391 [00:16<00:00, 23.78it/s, loss=2.2243, acc=42.02%, lr=0.098618]
Epoch 16/20 [Test]: 100%|██████████| 79/79 [00:01<00:00, 53.34it/s, loss=2.3144, acc=40.54%]


Epoch: 16/20 | Train Loss: 2.2243 | Train Acc: 42.02% | Test Loss: 2.3144 | Test Acc: 40.54%


Epoch 17/20 [Train]: 100%|██████████| 391/391 [00:16<00:00, 23.88it/s, loss=2.2191, acc=41.83%, lr=0.098429]
Epoch 17/20 [Test]: 100%|██████████| 79/79 [00:01<00:00, 52.55it/s, loss=1.9776, acc=47.22%]


Checkpoint saved! New best accuracy: 47.22%
Epoch: 17/20 | Train Loss: 2.2191 | Train Acc: 41.83% | Test Loss: 1.9776 | Test Acc: 47.22%


Epoch 18/20 [Train]: 100%|██████████| 391/391 [00:16<00:00, 24.16it/s, loss=2.1772, acc=42.95%, lr=0.098228]
Epoch 18/20 [Test]: 100%|██████████| 79/79 [00:01<00:00, 52.31it/s, loss=2.1846, acc=43.80%]


Epoch: 18/20 | Train Loss: 2.1772 | Train Acc: 42.95% | Test Loss: 2.1846 | Test Acc: 43.80%


Epoch 19/20 [Train]: 100%|██████████| 391/391 [00:16<00:00, 23.86it/s, loss=2.1677, acc=43.24%, lr=0.098015]
Epoch 19/20 [Test]: 100%|██████████| 79/79 [00:01<00:00, 51.83it/s, loss=1.8595, acc=49.26%]


Checkpoint saved! New best accuracy: 49.26%
Epoch: 19/20 | Train Loss: 2.1677 | Train Acc: 43.24% | Test Loss: 1.8595 | Test Acc: 49.26%


Epoch 20/20 [Train]: 100%|██████████| 391/391 [00:16<00:00, 23.94it/s, loss=2.1471, acc=43.51%, lr=0.097790]
Epoch 20/20 [Test]: 100%|██████████| 79/79 [00:01<00:00, 53.22it/s, loss=1.8369, acc=50.77%]


Checkpoint saved! New best accuracy: 50.77%
Epoch: 20/20 | Train Loss: 2.1471 | Train Acc: 43.51% | Test Loss: 1.8369 | Test Acc: 50.77%
Training completed! Best accuracy: 50.77%


### Train on Adversarial Data

In [18]:
# TRADES loss function
def trades_loss(model, x_natural, y, optimizer, step_size=0.003, epsilon=0.031, perturb_steps=10, beta=6.0):
    criterion_kl = nn.KLDivLoss(reduction='batchmean')
    model.eval()  # Use evaluation mode for generating adversaries
    batch_size = len(x_natural)
    
    # Generate adversarial examples
    x_adv = x_natural.detach() + 0.001 * torch.randn(x_natural.shape).to(device).detach()
    
    for _ in range(perturb_steps):
        x_adv.requires_grad_()
        with torch.enable_grad():
            loss_kl = criterion_kl(
                F.log_softmax(model(x_adv), dim=1),
                F.softmax(model(x_natural), dim=1)
            )
        grad = torch.autograd.grad(loss_kl, [x_adv])[0]
        x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
        x_adv = torch.min(torch.max(x_adv, x_natural - epsilon), x_natural + epsilon)
        x_adv = torch.clamp(x_adv, 0.0, 1.0)
    
    model.train()  # Switch back to training mode
    
    # Calculate the TRADES loss
    logits_natural = model(x_natural)
    logits_adv = model(x_adv)
    loss_natural = F.cross_entropy(logits_natural, y)
    loss_robust = criterion_kl(
        F.log_softmax(logits_adv, dim=1),
        F.softmax(logits_natural, dim=1)
    )
    
    # Total loss
    loss = loss_natural + beta * loss_robust
    return loss, logits_natural, logits_adv, x_adv

# Function to compute class weights based on model confusion
def compute_class_weights(model, dataloader, num_classes=100):
    confusion = torch.zeros(num_classes, num_classes).to(device)
    model.eval()
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Computing class weights"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, preds = outputs.max(1)
            for t, p in zip(labels.view(-1), preds.view(-1)):
                confusion[t.long(), p.long()] += 1
    
    # Normalize by row sums
    confusion = confusion / (confusion.sum(dim=1, keepdim=True) + 1e-8)
    
    # Class weights: higher weight for more confused classes
    # We use the sum of off-diagonal elements as confusion score
    diag_indices = torch.arange(num_classes)
    confusion[diag_indices, diag_indices] = 0
    class_weights = confusion.sum(dim=1)
    
    # Normalize weights
    class_weights = class_weights / class_weights.mean()
    
    # Scale weights to be in a reasonable range
    class_weights = 0.5 + class_weights / 2
    
    return class_weights

# Evaluate adversarial robustness using PGD attack
def evaluate_robustness(model, dataloader, epsilon=8/255, alpha=2/255, iters=20):
    model.eval()
    robust_correct = 0
    clean_correct = 0
    total = 0
    
    for inputs, targets in tqdm(dataloader, desc="Evaluating robustness"):
        inputs, targets = inputs.to(device), targets.to(device)
        
        # Clean accuracy
        with torch.no_grad():
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            clean_correct += predicted.eq(targets).sum().item()
        
        # Create adversarial examples with PGD
        x_adv = inputs.clone().detach() + 0.001 * torch.randn(inputs.shape).to(device).detach()
        x_adv = torch.clamp(x_adv, 0.0, 1.0)
        
        for _ in range(iters):
            x_adv.requires_grad_()
            with torch.enable_grad():
                outputs_adv = model(x_adv)
                loss = F.cross_entropy(outputs_adv, targets)
            
            grad = torch.autograd.grad(loss, [x_adv])[0]
            x_adv = x_adv.detach() + alpha * torch.sign(grad.detach())
            delta = torch.clamp(x_adv - inputs, -epsilon, epsilon)
            x_adv = torch.clamp(inputs + delta, 0.0, 1.0)
        
        # Evaluate on adversarial examples
        with torch.no_grad():
            outputs = model(x_adv)
            _, predicted = outputs.max(1)
            robust_correct += predicted.eq(targets).sum().item()
        
        total += targets.size(0)
    
    clean_acc = 100.0 * clean_correct / total
    robust_acc = 100.0 * robust_correct / total
    
    return clean_acc, robust_acc

In [20]:
# Pretrained model
model = models.resnet18(weights=None)
model.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
model.fc = nn.Linear(model.fc.in_features, 100)  # 100 classes for CIFAR-100

pretrained_checkpoint = torch.load('best_resnet18_cifar100_untargeted_adv.pth')
model.load_state_dict(pretrained_checkpoint['state_dict'])
model = model.to(device)
print("Loaded pretrained model with clean accuracy around 50%")

# Since the model is already trained, compute class weights immediately
print("Computing initial class weights...")
class_weights = compute_class_weights(model, val_loader)
print(f"Class weights range: Min={class_weights.min().item():.4f}, Max={class_weights.max().item():.4f}")
weighted_criterion = nn.CrossEntropyLoss(weight=class_weights)

# Adjust learning rate since we're continuing from a pretrained model
for param_group in optimizer.param_groups:
    param_group['lr'] = param_group['lr'] * 0.5  # Reduce initial learning rate

# Main training loop
num_epochs = 200
best_acc = 0
best_robust_acc = 0
epsilon = 8/255  # PGD attack strength
alpha = 2/255    # PGD step size
iters = 7        # Number of PGD iterations

for epoch in range(num_epochs):
    # Update class weights every 10 epochs (after initial stabilization)
    if epoch % 15 == 0 and epoch > 0:
        print("Updating class weights...")
        class_weights = compute_class_weights(model, val_loader)
        print(f"Class weights range: Min={class_weights.min().item():.4f}, Max={class_weights.max().item():.4f}")
        weighted_criterion = nn.CrossEntropyLoss(weight=class_weights)
    
    model.train()
    train_loss = 0.0
    nat_correct = 0
    total = 0
    
    train_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]')
    for batch_idx, (inputs, labels) in enumerate(train_bar):
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad(set_to_none=True)
        
        # Use TRADES loss
        loss, logits_natural, logits_adv, adv_images = trades_loss(
            model, inputs, labels, optimizer, 
            step_size=alpha, epsilon=epsilon, 
            perturb_steps=iters, beta=8.0
        )
        
        # Check for NaN or inf in loss
        if torch.isnan(loss) or torch.isinf(loss):
            print(f"Warning: Invalid training loss at batch {batch_idx}. Skipping update.")
            continue
            
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        
        train_loss += loss.item()
        _, predicted = logits_natural.max(1)
        total += labels.size(0)
        nat_correct += predicted.eq(labels).sum().item()
        
        # Calculate adversarial accuracy for the batch
        _, adv_predicted = logits_adv.max(1)
        adv_correct = adv_predicted.eq(labels).sum().item()
        adv_acc = 100.0 * adv_correct / labels.size(0)
        
        # Update progress bar with detailed stats
        train_bar.set_postfix({
            'loss': f'{train_loss / (batch_idx + 1):.4f}',
            'nat_acc': f'{100.*nat_correct/total:.2f}%',
            'adv_acc': f'{adv_acc:.2f}%',
            'lr': f'{scheduler.get_last_lr()[0]:.6f}'
        })
    
    train_loss = train_loss / len(train_loader)
    train_acc = 100. * nat_correct / total
    
    # Validation phase
    model.eval()
    val_loss = 0
    val_correct = 0
    val_total = 0
    
    val_bar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Val]')
    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(val_bar):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = F.cross_entropy(outputs, labels)
            
            val_loss += loss.item()
            _, predicted = outputs.max(1)
            val_total += labels.size(0)
            val_correct += predicted.eq(labels).sum().item()
            
            val_bar.set_postfix({
                'loss': f'{val_loss / (batch_idx + 1):.4f}',
                'acc': f'{100.*val_correct/val_total:.2f}%'
            })
    
    val_loss = val_loss / len(val_loader)
    val_acc = 100. * val_correct / val_total
    
    # Evaluate robust accuracy on test set every 5 epochs
    if epoch % 3 == 0 or epoch == num_epochs - 1:
        clean_acc, robust_acc = evaluate_robustness(model, test_loader, epsilon, alpha, iters=20)
        print(f'Clean Test Acc: {clean_acc:.2f}% | Robust Test Acc: {robust_acc:.2f}%')
        
        # Save model if robust accuracy improves
        if robust_acc > best_robust_acc:
            best_robust_acc = robust_acc
            torch.save({
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'best_clean_acc': clean_acc,
                'best_robust_acc': robust_acc,
            }, 'best_robust_model.pth')
            print(f'Robust checkpoint saved! New best robust accuracy: {best_robust_acc:.2f}%')
    
    # Save model if validation accuracy improves
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save({
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'best_acc': best_acc,
        }, 'best_clean_model.pth')
        print(f'Clean checkpoint saved! New best accuracy: {best_acc:.2f}%')
    
    # Step the scheduler
    scheduler.step()
    
    print(f'Epoch: {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%')

print(f"Training completed! Best accuracy: {best_acc:.2f}%, Best robust accuracy: {best_robust_acc:.2f}%")


Loaded pretrained model with clean accuracy around 50%
Computing initial class weights...


Computing class weights: 100%|██████████| 40/40 [00:12<00:00,  3.33it/s]


Class weights range: Min=0.5553, Max=1.5848


Epoch 1/200 [Train]: 100%|██████████| 352/352 [02:00<00:00,  2.93it/s, loss=16.9556, nat_acc=45.35%, adv_acc=13.89%, lr=0.097553]
Epoch 1/200 [Val]: 100%|██████████| 40/40 [00:00<00:00, 47.49it/s, loss=1.9793, acc=47.18%]
Evaluating robustness: 100%|██████████| 79/79 [00:42<00:00,  1.87it/s]


Clean Test Acc: 44.56% | Robust Test Acc: 10.72%
Robust checkpoint saved! New best robust accuracy: 10.72%
Clean checkpoint saved! New best accuracy: 47.18%
Epoch: 1/200 | Train Loss: 16.9556 | Train Acc: 45.35% | Val Loss: 1.9793 | Val Acc: 47.18%


Epoch 2/200 [Train]: 100%|██████████| 352/352 [01:59<00:00,  2.96it/s, loss=16.9971, nat_acc=45.12%, adv_acc=22.22%, lr=0.048652]
Epoch 2/200 [Val]: 100%|██████████| 40/40 [00:00<00:00, 49.38it/s, loss=1.9632, acc=47.58%]


Clean checkpoint saved! New best accuracy: 47.58%
Epoch: 2/200 | Train Loss: 16.9971 | Train Acc: 45.12% | Val Loss: 1.9632 | Val Acc: 47.58%


Epoch 3/200 [Train]: 100%|██████████| 352/352 [01:59<00:00,  2.95it/s, loss=17.0234, nat_acc=44.96%, adv_acc=12.50%, lr=0.048522]
Epoch 3/200 [Val]: 100%|██████████| 40/40 [00:00<00:00, 46.78it/s, loss=1.9777, acc=46.90%]


Epoch: 3/200 | Train Loss: 17.0234 | Train Acc: 44.96% | Val Loss: 1.9777 | Val Acc: 46.90%


Epoch 4/200 [Train]: 100%|██████████| 352/352 [01:59<00:00,  2.96it/s, loss=17.0861, nat_acc=45.06%, adv_acc=16.67%, lr=0.048386]
Epoch 4/200 [Val]: 100%|██████████| 40/40 [00:00<00:00, 47.44it/s, loss=1.9615, acc=47.62%]
Evaluating robustness: 100%|██████████| 79/79 [00:42<00:00,  1.87it/s]


Clean Test Acc: 44.68% | Robust Test Acc: 10.64%
Clean checkpoint saved! New best accuracy: 47.62%
Epoch: 4/200 | Train Loss: 17.0861 | Train Acc: 45.06% | Val Loss: 1.9615 | Val Acc: 47.62%


Epoch 5/200 [Train]: 100%|██████████| 352/352 [01:59<00:00,  2.95it/s, loss=17.0022, nat_acc=45.42%, adv_acc=16.67%, lr=0.048244]
Epoch 5/200 [Val]: 100%|██████████| 40/40 [00:00<00:00, 47.24it/s, loss=1.9822, acc=46.94%]


Epoch: 5/200 | Train Loss: 17.0022 | Train Acc: 45.42% | Val Loss: 1.9822 | Val Acc: 46.94%


Epoch 6/200 [Train]: 100%|██████████| 352/352 [01:58<00:00,  2.96it/s, loss=16.9560, nat_acc=45.47%, adv_acc=23.61%, lr=0.048097]
Epoch 6/200 [Val]: 100%|██████████| 40/40 [00:00<00:00, 45.06it/s, loss=1.9770, acc=47.08%]


Epoch: 6/200 | Train Loss: 16.9560 | Train Acc: 45.47% | Val Loss: 1.9770 | Val Acc: 47.08%


Epoch 7/200 [Train]: 100%|██████████| 352/352 [01:59<00:00,  2.95it/s, loss=17.0599, nat_acc=45.51%, adv_acc=19.44%, lr=0.047944]
Epoch 7/200 [Val]: 100%|██████████| 40/40 [00:00<00:00, 46.74it/s, loss=1.9568, acc=47.62%]
Evaluating robustness: 100%|██████████| 79/79 [00:42<00:00,  1.86it/s]


Clean Test Acc: 44.58% | Robust Test Acc: 10.59%
Epoch: 7/200 | Train Loss: 17.0599 | Train Acc: 45.51% | Val Loss: 1.9568 | Val Acc: 47.62%


Epoch 8/200 [Train]: 100%|██████████| 352/352 [01:59<00:00,  2.96it/s, loss=17.1161, nat_acc=45.33%, adv_acc=18.06%, lr=0.047785]
Epoch 8/200 [Val]: 100%|██████████| 40/40 [00:00<00:00, 47.60it/s, loss=1.9660, acc=47.36%]


Epoch: 8/200 | Train Loss: 17.1161 | Train Acc: 45.33% | Val Loss: 1.9660 | Val Acc: 47.36%


Epoch 9/200 [Train]:  87%|████████▋ | 307/352 [01:43<00:15,  2.96it/s, loss=16.9784, nat_acc=45.42%, adv_acc=17.19%, lr=0.047621]


KeyboardInterrupt: 

### Targeted

In [None]:
# Adversarial training loop (targeted)
num_epochs = 200
best_acc = 0

for epoch in range(num_epochs):
    # Training phase
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0
    
    train_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]')
    for inputs, labels in train_bar:
        inputs, labels = inputs.to(device), labels.to(device)
        
        # Zero the parameter gradients
        optimizer.zero_grad(set_to_none=True)
        
        # Clean data forward pass
        clean_outputs = model(inputs)
        clean_loss = criterion(clean_outputs, labels)
        
        # Generate targeted adversarial examples
        target_labels = (labels + 1) % 100  # Simple target: next class
        adv_images = improved_pgd_attack(
            model, 
            inputs, 
            labels,
            target_labels=target_labels, 
            epsilon=8/255,  # Standard for CIFAR-100 training
            alpha=2/255,   # Smaller for stability
            iters=7,       # Fewer for training efficiency
            random_start=True,
            untargeted=False
        )
        
        # Adversarial data forward pass
        adv_outputs = model(adv_images)
        adv_loss = criterion(adv_outputs, labels)  # Train to predict true labels
        
        # Combined loss (50% clean, 50% adversarial)
        loss = 0.5 * clean_loss + 0.5 * adv_loss
        loss.backward()
        optimizer.step()
        
        # Statistics
        train_loss += loss.item()
        _, predicted = clean_outputs.max(1)  # Use clean outputs for accuracy
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        
        # Update progress bar
        train_bar.set_postfix({
            'loss': f'{train_loss/len(train_bar):.4f}',
            'acc': f'{100.*correct/total:.2f}%',
            'lr': f'{scheduler.get_last_lr()[0]:.6f}'
        })
    
    train_loss = train_loss / len(train_loader)
    train_acc = 100. * correct / total
    
    # Validation phase
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    
    test_bar = tqdm(test_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Test]')
    with torch.no_grad():
        for inputs, labels in test_bar:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
            # Update progress bar
            test_bar.set_postfix({
                'loss': f'{test_loss/len(test_bar):.4f}',
                'acc': f'{100.*correct/total:.2f}%'
            })
    
    test_loss = test_loss / len(test_loader)
    test_acc = 100. * correct / total
    
    # Update learning rate
    scheduler.step()
    
    # Save checkpoint if it's the best model so far
    if test_acc > best_acc:
        best_acc = test_acc
        checkpoint = {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'best_acc': best_acc,
        }
        torch.save(checkpoint, 'best_resnet18_cifar100_targeted_adv.pth')
        print(f'Checkpoint saved! New best accuracy: {best_acc:.2f}%')
    
    # Print epoch summary
    print(f'Epoch: {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}%')

print(f"Training completed! Best accuracy: {best_acc:.2f}%")