In [3]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
from torchvision.datasets import CIFAR100
import numpy as np
from tqdm import tqdm

In [4]:
class CIFAR100DataModule(pl.LightningDataModule):
    """CIFAR-100 LightningDataModule with a fixed, seeded train/validation split.

    The official 50k training set is split into a train subset (augmented transform) and a
    validation subset (deterministic test transform); the official 10k test set is used for
    testing.

    NOTE: both transforms end with ``Normalize(MEAN, STD)``, so the tensors produced here are
    NOT in the [0, 1] pixel range. Code that perturbs or clamps inputs in pixel space (e.g.
    adversarial attacks) must account for this normalization.
    """

    # Per-channel statistics used by the ``Normalize`` transform below.
    MEAN = (0.5071, 0.4867, 0.4408)
    STD = (0.2675, 0.2565, 0.2761)

    def __init__(self, data_dir='./data', batch_size=128, num_workers=4,
                 train_fraction=0.9, split_seed=42):
        """
        Args:
            data_dir: directory where CIFAR-100 is (or will be) downloaded.
            batch_size: batch size for every dataloader.
            num_workers: worker processes per dataloader.
            train_fraction: fraction of the official training set kept for training; the
                remainder becomes the validation split.
            split_seed: seed of the generator that decides the train/validation split.
        """
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_fraction = train_fraction
        self.split_seed = split_seed

        normalize = transforms.Normalize(self.MEAN, self.STD)
        self.transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            # The CIFAR-10 AutoAugment policy is applied to CIFAR-100 here.
            transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
            transforms.ToTensor(),
            normalize,
        ])
        self.transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

    def prepare_data(self):
        """Download the train and test splits if they are not already on disk."""
        CIFAR100(self.data_dir, train=True, download=True)
        CIFAR100(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the train, validation and test datasets.

        Train and validation index two views of the same official training set (augmented
        vs. deterministic transform) using ONE seeded permutation, so the subsets are
        disjoint and identical in every process.
        """
        train_view = CIFAR100(self.data_dir, train=True, transform=self.transform_train)
        eval_view = CIFAR100(self.data_dir, train=True, transform=self.transform_test)

        n_total = len(train_view)
        n_train = int(self.train_fraction * n_total)
        n_val = n_total - n_train
        train_split, val_split = random_split(
            range(n_total), [n_train, n_val],
            generator=torch.Generator().manual_seed(self.split_seed),
        )
        self.train_dataset = torch.utils.data.Subset(train_view, train_split.indices)
        self.val_dataset = torch.utils.data.Subset(eval_view, val_split.indices)

        self.test_dataset = CIFAR100(self.data_dir, train=False, transform=self.transform_test)

    def _loader(self, dataset, shuffle):
        """DataLoader sharing this module's batch size / worker settings, with pinned memory."""
        return DataLoader(
            dataset, batch_size=self.batch_size, shuffle=shuffle,
            num_workers=self.num_workers, pin_memory=True
        )

    def train_dataloader(self):
        return self._loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        return self._loader(self.test_dataset, shuffle=False)


In [5]:
class AdversarialTrainingModule(pl.LightningModule):
    def __init__(self, 
                 pretrained_path=None,
                 epsilon=8/255, 
                 alpha=2/255, 
                 attack_iters=7, 
                 trades_beta=8.0,
                 learning_rate=0.05,
                 momentum=0.9,
                 weight_decay=5e-4,
                 lr_scheduler="cosine",
                 max_epochs=200):
        super().__init__()
        self.save_hyperparameters()
        
        # Model definition
        self.model = models.resnet18(weights=None)
        self.model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.model.fc = nn.Linear(self.model.fc.in_features, 100)
        
        # Load pretrained model if provided
        if pretrained_path:
            checkpoint = torch.load(pretrained_path)
            self.model.load_state_dict(checkpoint['state_dict'])
            print(f"Loaded pretrained model from {pretrained_path}")
        
        # Save clean and robust accuracy as attributes
        self.best_clean_acc = 0.0
        self.best_robust_acc = 0.0
        
        # Class weights for balanced loss
        self.class_weights = None
        self.weight_update_counter = 0
    
    def forward(self, x):
        return self.model(x)
    
    def trades_loss(self, x_natural, y, perturb_steps=10):
        """TRADES loss for adversarial robustness"""
        # Loss functions
        criterion_kl = nn.KLDivLoss(reduction='batchmean')
        
        # Generate adversarial examples
        x_adv = x_natural.detach() + 0.001 * torch.randn(x_natural.shape, device=self.device)
        
        for _ in range(perturb_steps):
            x_adv.requires_grad_()
            with torch.enable_grad():
                loss_kl = criterion_kl(
                    F.log_softmax(self.model(x_adv), dim=1),
                    F.softmax(self.model(x_natural), dim=1)
                )
            grad = torch.autograd.grad(loss_kl, [x_adv])[0]
            x_adv = x_adv.detach() + self.hparams.alpha * torch.sign(grad.detach())
            x_adv = torch.min(torch.max(x_adv, x_natural - self.hparams.epsilon), 
                             x_natural + self.hparams.epsilon)
            x_adv = torch.clamp(x_adv, 0.0, 1.0)
        
        # Calculate the TRADES loss
        logits_natural = self.model(x_natural)
        logits_adv = self.model(x_adv)
        
        if self.class_weights is not None and self.class_weights.device != self.device:
            self.class_weights = self.class_weights.to(self.device)
        
        # Natural loss (possibly with class weights)
        if self.class_weights is not None:
            loss_natural = F.cross_entropy(logits_natural, y, weight=self.class_weights)
        else:
            loss_natural = F.cross_entropy(logits_natural, y)
        
        # KL divergence
        loss_robust = criterion_kl(
            F.log_softmax(logits_adv, dim=1),
            F.softmax(logits_natural, dim=1)
        )
        
        # Combined loss
        loss = loss_natural + self.hparams.trades_beta * loss_robust
        
        return loss, logits_natural, logits_adv, x_adv
    
    def compute_class_weights(self, dataloader):
        """Compute class weights based on model confusion"""
        confusion = torch.zeros(100, 100, device=self.device)
        self.eval()
        with torch.no_grad():
            for inputs, labels in dataloader:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                outputs = self.model(inputs)
                _, preds = outputs.max(1)
                for t, p in zip(labels.view(-1), preds.view(-1)):
                    confusion[t.long(), p.long()] += 1
        
        # Normalize and compute weights
        confusion = confusion / (confusion.sum(dim=1, keepdim=True) + 1e-8)
        diag_indices = torch.arange(100, device=self.device)
        confusion[diag_indices, diag_indices] = 0
        class_weights = confusion.sum(dim=1)
        
        # Normalize weights
        class_weights = class_weights / class_weights.mean()
        class_weights = 0.5 + class_weights / 2
        
        return class_weights
    
    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        
        # Update class weights occasionally
        if self.trainer.current_epoch > 0 and self.trainer.current_epoch % 15 == 0 and self.weight_update_counter != self.trainer.current_epoch:
            self.weight_update_counter = self.trainer.current_epoch
            self.class_weights = self.compute_class_weights(self.trainer.val_dataloaders[0])
            self.log('class_weights_min', self.class_weights.min())
            self.log('class_weights_max', self.class_weights.max())
        
        # Use TRADES loss
        loss, logits_natural, logits_adv, adv_images = self.trades_loss(
            inputs, labels, perturb_steps=self.hparams.attack_iters
        )
        
        # Calculate metrics
        _, nat_preds = logits_natural.max(1)
        nat_acc = nat_preds.eq(labels).float().mean()
        
        _, adv_preds = logits_adv.max(1)
        adv_acc = adv_preds.eq(labels).float().mean()
        
        # Log metrics
        self.log('train_loss', loss, prog_bar=True, sync_dist=True)
        self.log('train_nat_acc', nat_acc * 100.0, prog_bar=True, sync_dist=True)
        self.log('train_adv_acc', adv_acc * 100.0, prog_bar=True, sync_dist=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self.model(inputs)
        loss = F.cross_entropy(outputs, labels)
        
        _, preds = outputs.max(1)
        acc = preds.eq(labels).float().mean()
        
        # Log metrics
        self.log('val_loss', loss, prog_bar=True, sync_dist=True)
        self.log('val_acc', acc * 100.0, prog_bar=True, sync_dist=True)
        
        return {'val_loss': loss, 'val_acc': acc}
    
    def test_step(self, batch, batch_idx):
        inputs, labels = batch
        
        # Clean accuracy
        with torch.no_grad():
            outputs = self.model(inputs)
            _, preds = outputs.max(1)
            clean_acc = preds.eq(labels).float().mean()
        
        # Generate adversarial examples with PGD
        x_adv = inputs.clone() + 0.001 * torch.randn(inputs.shape, device=self.device)
        x_adv = torch.clamp(x_adv, 0.0, 1.0)
        
        for _ in range(20):  # More iterations for test-time evaluation
            x_adv.requires_grad_()
            with torch.enable_grad():
                outputs_adv = self.model(x_adv)
                loss = F.cross_entropy(outputs_adv, labels)
            
            grad = torch.autograd.grad(loss, [x_adv])[0]
            x_adv = x_adv.detach() + self.hparams.alpha * torch.sign(grad.detach())
            delta = torch.clamp(x_adv - inputs, -self.hparams.epsilon, self.hparams.epsilon)
            x_adv = torch.clamp(inputs + delta, 0.0, 1.0)
        
        # Evaluate on adversarial examples
        with torch.no_grad():
            outputs = self.model(x_adv)
            _, preds = outputs.max(1)
            robust_acc = preds.eq(labels).float().mean()
        
        # Log metrics
        self.log('test_clean_acc', clean_acc * 100.0, prog_bar=True, sync_dist=True)
        self.log('test_robust_acc', robust_acc * 100.0, prog_bar=True, sync_dist=True)
        
        return {'test_clean_acc': clean_acc, 'test_robust_acc': robust_acc}
    
    def on_test_epoch_end(self):
        # This is called automatically at the end of the test epoch
        pass
    
    def configure_optimizers(self):
        # Initialize optimizer with reduced learning rate for pretrained model
        optimizer = optim.SGD(
            self.parameters(),
            lr=self.hparams.learning_rate,
            momentum=self.hparams.momentum,
            weight_decay=self.hparams.weight_decay
        )
        
        # Configure learning rate scheduler
        if self.hparams.lr_scheduler == "cosine":
            scheduler = optim.lr_scheduler.CosineAnnealingLR(
                optimizer, 
                T_max=self.hparams.max_epochs
            )
        else:
            scheduler = optim.lr_scheduler.MultiStepLR(
                optimizer,
                milestones=[60, 120, 160],
                gamma=0.2
            )
        
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss"
            }
        }


In [6]:
def train_adversarial_model(
    pretrained_path='best_resnet18_cifar100_untargeted_adv.pth',
    batch_size=128,
    max_epochs=200,
    learning_rate=0.05,
    epsilon=8/255,
    alpha=2/255,
    attack_iters=7,
    trades_beta=8.0,
    num_workers=4
):
    # Set up data module
    data_module = CIFAR100DataModule(
        batch_size=batch_size,
        num_workers=num_workers
    )
    
    # Initialize model
    model = AdversarialTrainingModule(
        pretrained_path=pretrained_path,
        epsilon=epsilon,
        alpha=alpha,
        attack_iters=attack_iters,
        trades_beta=trades_beta,
        learning_rate=learning_rate,
        max_epochs=max_epochs
    )
    
    # Define callbacks
    checkpoint_callback_clean = ModelCheckpoint(
        monitor='val_acc',
        filename='best-clean-{epoch:02d}-{val_acc:.2f}',
        save_top_k=1,
        mode='max',
        save_last=True
    )
    
    checkpoint_callback_robust = ModelCheckpoint(
        monitor='test_robust_acc',
        filename='best-robust-{epoch:02d}-{test_robust_acc:.2f}',
        save_top_k=1,
        mode='max'
    )
    
    lr_monitor = LearningRateMonitor(logging_interval='epoch')
    
    # Set up logger
    logger = TensorBoardLogger("lightning_logs", name="adversarial_training")
    
    # Initialize trainer with GPU strategy
    trainer = pl.Trainer(
        accelerator='gpu',
        devices=2,  # Using 2 GPUs
        strategy='ddp_spawn',  # Distributed Data Parallel
        max_epochs=max_epochs,
        callbacks=[checkpoint_callback_clean, checkpoint_callback_robust, lr_monitor],
        logger=logger,
        precision=16,  # You can try 16 for mixed precision training
        log_every_n_steps=50,
        # Run testing every 3 epochs
        check_val_every_n_epoch=3
    )
    
    # Train the model
    trainer.fit(model, data_module)
    
    # Evaluate on test set
    trainer.test(model, data_module)
    
    return model, trainer

In [7]:
if __name__ == "__main__":
    # Seed the RNGs for reproducibility.
    pl.seed_everything(42)

    model, trainer = train_adversarial_model(
        pretrained_path='best_resnet18_cifar100_untargeted_adv.pth',
        batch_size=128,
        max_epochs=200,
        learning_rate=0.05,  # Reduced from 0.1 since we're using a pretrained model
        epsilon=8/255,
        alpha=2/255,
        attack_iters=7,
        trades_beta=8.0
    )

    print("Training completed!")
    # Report the metrics logged by trainer.test(). Attributes set on the module inside worker
    # processes (e.g. best_clean_acc) are presumably not propagated back to this process,
    # while logged metrics are exposed via trainer.callback_metrics — verify for your version.
    metrics = trainer.callback_metrics
    for key, label in (("test_clean_acc", "Test clean accuracy"),
                       ("test_robust_acc", "Test robust accuracy")):
        value = metrics.get(key)
        if value is None:
            print(f"{label}: n/a")
        else:
            print(f"{label}: {float(value):.2f}%")


Seed set to 42
/home/pnagaraj/miniconda3/lib/python3.12/site-packages/lightning_fabric/connector.py:571: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
Using 16bit Automatic Mixed Precision (AMP)


Loaded pretrained model from best_resnet18_cifar100_untargeted_adv.pth


MisconfigurationException: `Trainer(strategy='ddp_spawn')` is not compatible with an interactive environment. Run your code as a script, or choose a notebook-compatible strategy: `Trainer(strategy='ddp_notebook')`. In case you are spawning processes yourself, make sure to include the Trainer creation inside the worker function.