In [None]:
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import (
    Compose, ToTensor, Normalize,
    RandomHorizontalFlip, RandomCrop
)
from tqdm import tqdm
import numpy as np
from sklearn.metrics import (
    accuracy_score, precision_score,
    recall_score, f1_score
)

# ================== Device & AMP Setup ==================
# Prefer the GPU when one is visible; the rest of the script moves tensors here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# Loss scaling only matters for CUDA mixed precision. Building the scaler
# disabled on CPU avoids the "CUDA is not available" warning; a disabled
# scaler's scale()/step()/update() simply pass through to the optimizer.
scaler = GradScaler(enabled=(device.type == 'cuda'))

# ================== Model Architecture ==================
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers wrapped by an additive skip connection.

    The skip path is an empty ``nn.Sequential`` (identity) unless the block
    changes resolution (``stride != 1``) or channel count
    (``in_planes != planes``); in that case a strided 1x1 conv followed by
    batch-norm projects the input so it can be added to the main path.
    """

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        # Main path: only the first conv may downsample.
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)

        # Skip path: project only when the shapes would otherwise not match.
        needs_projection = stride != 1 or in_planes != planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        skip = self.shortcut(x)
        hidden += skip
        return F.relu(hidden)

class WideResNet(nn.Module):
    """Three-stage residual classifier built from ``ResidualBlock``.

    A 3x3 stem conv (3 -> 16 channels) feeds three stages of widths
    ``16 * widen_factor``, ``32 * widen_factor`` and ``64 * widen_factor``
    with strides 1, 2 and 2. Features are globally average-pooled and passed
    to a linear classifier.

    Args:
        depth: Nominal network depth; each stage is asked for
            ``(depth - 4) // 6`` blocks (at least one block is always built).
        widen_factor: Channel multiplier applied to every stage.
        num_classes: Size of the output logit vector.
    """

    def __init__(self, depth=28, widen_factor=10, num_classes=10):
        super().__init__()
        # Running channel count; _make_layer() reads and updates it.
        self.in_planes = 16
        n = (depth - 4) // 6

        self.conv1 = nn.Conv2d(3, 16, 3, 1, 1, bias=False)
        self.bn1   = nn.BatchNorm2d(16)

        self.layer1 = self._make_layer(n, 16 * widen_factor, stride=1)
        self.layer2 = self._make_layer(n, 32 * widen_factor, stride=2)
        self.layer3 = self._make_layer(n, 64 * widen_factor, stride=2)

        self.linear = nn.Linear(64 * widen_factor, num_classes)

    def _make_layer(self, num_blocks, planes, stride):
        """Stack residual blocks; only the first one applies ``stride``."""
        layers = [ResidualBlock(self.in_planes, planes, stride)]
        self.in_planes = planes
        for _ in range(1, num_blocks):
            layers.append(ResidualBlock(self.in_planes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        # Global average pooling. For 32x32 inputs the final map is 8x8, so
        # this equals the previous fixed 8x8 pooling; unlike a fixed kernel it
        # always yields a 1x1 map, so other input sizes still match `linear`.
        x = F.adaptive_avg_pool2d(x, 1)
        x = torch.flatten(x, 1)
        return self.linear(x)

# ================== Data Loading ==================
def get_loaders(batch_size=128, num_workers=8, prefetch_factor=2,
                test_batch_size=256):
    """Build the CIFAR-10 train and test ``DataLoader`` objects.

    Training images get random-crop (padding 4) and horizontal-flip
    augmentation. Both splits are converted to tensors and then normalized
    per channel, so the tensors the loaders yield are NOT in the [0, 1]
    pixel range.

    Args:
        batch_size: Batch size of the (shuffled) training loader.
        num_workers: Worker processes per loader; ``0`` loads in the main
            process.
        prefetch_factor: Batches pre-loaded per worker. Only forwarded when
            ``num_workers > 0``, because ``DataLoader`` rejects the option
            when no worker processes are used.
        test_batch_size: Batch size of the (unshuffled) test loader.

    Returns:
        ``(train_loader, test_loader)``.
    """
    normalize = Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    transform_train = Compose([
        RandomCrop(32, padding=4),
        RandomHorizontalFlip(),
        ToTensor(),
        normalize,
    ])
    transform_test = Compose([
        ToTensor(),
        normalize,
    ])

    train_set = CIFAR10(
        root='./data', train=True, download=True,
        transform=transform_train
    )
    test_set = CIFAR10(
        root='./data', train=False, download=True,
        transform=transform_test
    )

    loader_kwargs = {
        'num_workers': num_workers,
        # Pinned host memory only helps host-to-GPU copies.
        'pin_memory': torch.cuda.is_available(),
    }
    if num_workers > 0:
        loader_kwargs['prefetch_factor'] = prefetch_factor

    train_loader = DataLoader(
        train_set, batch_size=batch_size, shuffle=True, **loader_kwargs
    )
    test_loader = DataLoader(
        test_set, batch_size=test_batch_size, shuffle=False, **loader_kwargs
    )

    return train_loader, test_loader

# ================== Attack Implementations ==================
def fgsm_attack(model, images, labels, epsilon):
    """Craft single-step FGSM adversarial examples.

    Computes ``clamp(images + epsilon * sign(d loss / d images), 0, 1)`` for
    the cross-entropy loss of ``model``.

    The result is clamped to [0, 1], so ``images`` is expected to hold
    un-normalized pixels in that range; ``model`` must accept such inputs.

    Args:
        model: Callable mapping an image batch to class logits.
        images: Input batch. It is neither modified nor marked as
            requiring grad.
        labels: Ground-truth class indices for ``images``.
        epsilon: L-infinity step size, in the same units as ``images``.

    Returns:
        A detached tensor of adversarial images with the shape of ``images``.
    """
    # Differentiate w.r.t. a private leaf so the caller's tensor stays untouched.
    inputs = images.detach().clone().requires_grad_(True)
    # enable_grad() lets the attack run even when the caller is under no_grad().
    with torch.enable_grad():
        with autocast():
            outputs = model(inputs)
            # Only the gradient's sign is used, and the sign is invariant to
            # positive scaling. Summing instead of averaging keeps per-sample
            # gradients from shrinking with the batch size, which helps
            # against fp16 underflow without borrowing the training scaler.
            loss = F.cross_entropy(outputs, labels, reduction='sum')
        # autograd.grad returns the input gradient directly and does not
        # accumulate anything into the model parameters' .grad.
        grad, = torch.autograd.grad(loss, inputs)
    adv = inputs.detach() + epsilon * grad.sign()
    return torch.clamp(adv, 0, 1).detach()

def pgd_attack(model, images, labels, epsilon, alpha, steps):
    """Craft L-infinity PGD adversarial examples with a random start.

    Starts from a uniform random point in the ``epsilon`` ball around
    ``images``, then repeats ``steps`` times: take a signed-gradient ascent
    step of size ``alpha`` on the cross-entropy loss, project back into the
    ``epsilon`` ball, and clamp to [0, 1].

    Because of the [0, 1] clamp, ``images`` is expected to hold un-normalized
    pixels in that range; ``model`` must accept such inputs.

    Args:
        model: Callable mapping an image batch to class logits.
        images: Clean input batch (not modified).
        labels: Ground-truth class indices for ``images``.
        epsilon: Radius of the L-infinity ball.
        alpha: Step size per iteration.
        steps: Number of iterations; ``0`` returns the clamped random start.

    Returns:
        A detached tensor of adversarial images with the shape of ``images``.
    """
    clean = images.detach()
    # The projection bounds do not change between iterations.
    lower = clean - epsilon
    upper = clean + epsilon

    # random start
    adv = clean + torch.empty_like(clean).uniform_(-epsilon, epsilon)
    adv = torch.clamp(adv, 0, 1).detach()

    for _ in range(steps):
        adv.requires_grad_(True)
        # enable_grad() lets the attack run even when the caller is under no_grad().
        with torch.enable_grad():
            with autocast():
                outputs = model(adv)
                # Only the gradient's sign is used, so any positive scaling
                # is equivalent. Summing instead of averaging keeps
                # per-sample gradients from shrinking with the batch size,
                # which makes fp16 underflow (sign == 0) less likely.
                loss = F.cross_entropy(outputs, labels, reduction='sum')
            grad, = torch.autograd.grad(loss, adv)
        adv = adv.detach() + alpha * grad.sign()
        adv = torch.max(torch.min(adv, upper), lower)
        adv = torch.clamp(adv, 0, 1).detach()
    return adv

# ================== Training & Evaluation ==================
def train_model(model, loader, optimizer, scheduler, scaler,
                use_defense=False, epsilon=8/255, alpha=2/255, steps=7,
                pixel_mean=(0.4914, 0.4822, 0.4465),
                pixel_std=(0.2023, 0.1994, 0.2010)):
    """Run one mixed-precision training epoch and return the mean batch loss.

    With ``use_defense=True`` each batch is replaced by PGD adversarial
    examples before the training step (adversarial training).

    ``pgd_attack`` clamps to [0, 1] and takes ``epsilon``/``alpha`` in pixel
    units, whereas the loaders in this file yield per-channel normalized
    tensors. The attack is therefore run in pixel space: the batch is
    de-normalized, attacked through a wrapper that re-normalizes before
    calling ``model``, and the adversarial pixels are normalized again.

    Note that ``scheduler.step()`` is called once per batch, so the
    scheduler's period must be expressed in iterations, not epochs.

    Args:
        model: Network to train; put into ``train()`` mode.
        loader: Iterable of ``(images, labels)`` batches.
        optimizer: Optimizer updating ``model``'s parameters.
        scheduler: LR scheduler, stepped after every batch.
        scaler: ``GradScaler`` used for the scaled backward pass and step.
        use_defense: Train on PGD adversarial examples instead of clean ones.
        epsilon: PGD L-infinity radius in [0, 1] pixel units.
        alpha: PGD step size in [0, 1] pixel units.
        steps: Number of PGD iterations.
        pixel_mean: Per-channel mean the loader normalized with. Pass
            ``None`` if the loader already yields [0, 1] pixels.
        pixel_std: Per-channel std matching ``pixel_mean``.

    Returns:
        Average loss over the processed batches (``0.0`` for an empty loader).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    craft_in_pixel_space = use_defense and pixel_mean is not None
    if craft_in_pixel_space:
        mean = torch.tensor(pixel_mean, device=device).view(1, -1, 1, 1)
        std = torch.tensor(pixel_std, device=device).view(1, -1, 1, 1)

        def pixel_model(pixels):
            # Lets the attack work on [0, 1] pixels while the network still
            # receives the normalized inputs it is trained on.
            return model((pixels - mean) / std)

    for images, labels in tqdm(loader, desc="Training"):
        # non_blocking only has an effect for pinned host memory.
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad()

        with autocast():
            if not use_defense:
                inputs = images
            elif craft_in_pixel_space:
                adv_pixels = pgd_attack(
                    pixel_model, images * std + mean, labels,
                    epsilon=epsilon, alpha=alpha, steps=steps
                )
                inputs = (adv_pixels - mean) / std
            else:
                inputs = pgd_attack(
                    model, images, labels,
                    epsilon=epsilon, alpha=alpha, steps=steps
                )
            outputs = model(inputs)
            loss = F.cross_entropy(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        total_loss += loss.item()
        num_batches += 1

    # Count the batches actually seen: avoids dividing by zero on an empty
    # loader and does not require the loader to define __len__.
    return total_loss / num_batches if num_batches else 0.0

def _classification_metrics(labels, preds):
    """Return accuracy and macro-averaged precision/recall/F1 as a dict."""
    return {
        'accuracy': accuracy_score(labels, preds),
        'precision': precision_score(labels, preds, average='macro'),
        'recall': recall_score(labels, preds, average='macro'),
        'f1': f1_score(labels, preds, average='macro'),
    }


def evaluate(model, loader,
             pixel_mean=(0.4914, 0.4822, 0.4465),
             pixel_std=(0.2023, 0.1994, 0.2010)):
    """Measure clean and adversarial classification metrics of ``model``.

    The model is put into ``eval()`` mode and scored on the clean data, then
    on adversarial versions of the same data produced by FGSM (epsilon 4/255
    and 8/255) and PGD (4/255 with 5 steps, 8/255 with 10 steps).

    The attacks clamp to [0, 1] and take epsilon in pixel units, whereas the
    loaders in this file yield per-channel normalized tensors. Each batch is
    therefore de-normalized, attacked through a wrapper that re-normalizes
    before calling ``model``, and the adversarial pixels are normalized
    again before the final prediction.

    Args:
        model: Network to evaluate.
        loader: Iterable of ``(images, labels)`` batches.
        pixel_mean: Per-channel mean the loader normalized with. Pass
            ``None`` if the loader already yields [0, 1] pixels.
        pixel_std: Per-channel std matching ``pixel_mean``.

    Returns:
        Dict keyed by ``'clean'``, ``'fgsm_4'``, ``'fgsm_8'``, ``'pgd_4_5'``
        and ``'pgd_8_10'``; each value holds ``accuracy``, ``precision``,
        ``recall`` and ``f1``.
    """
    model.eval()
    results = {}

    attack_in_pixel_space = pixel_mean is not None
    if attack_in_pixel_space:
        mean = torch.tensor(pixel_mean, device=device).view(1, -1, 1, 1)
        std = torch.tensor(pixel_std, device=device).view(1, -1, 1, 1)

        def pixel_model(pixels):
            # Lets the attacks work on [0, 1] pixels while the network still
            # receives the normalized inputs it was trained on.
            return model((pixels - mean) / std)

    # clean
    all_preds, all_labels = [], []
    with torch.no_grad():
        for imgs, lbls in loader:
            imgs, lbls = imgs.to(device), lbls.to(device)
            outputs = model(imgs)
            all_preds.extend(outputs.argmax(1).cpu().numpy())
            all_labels.extend(lbls.cpu().numpy())
    results['clean'] = _classification_metrics(all_labels, all_preds)

    # adversarial
    attacks = {
        'fgsm_4':   (fgsm_attack, {'epsilon': 4/255}),
        'fgsm_8':   (fgsm_attack, {'epsilon': 8/255}),
        'pgd_4_5':  (pgd_attack, {'epsilon':4/255,'alpha':1/255,'steps':5}),
        'pgd_8_10': (pgd_attack, {'epsilon':8/255,'alpha':2/255,'steps':10}),
    }
    for name, (fn, params) in attacks.items():
        preds, labels = [], []
        for imgs, lbls in loader:
            imgs, lbls = imgs.to(device), lbls.to(device)
            if attack_in_pixel_space:
                adv_pixels = fn(pixel_model, imgs * std + mean, lbls, **params)
                adv = (adv_pixels - mean) / std
            else:
                adv = fn(model, imgs, lbls, **params)
            with torch.no_grad():
                out = model(adv)
            preds.extend(out.argmax(1).cpu().numpy())
            labels.extend(lbls.cpu().numpy())
        results[name] = _classification_metrics(labels, preds)

    # An attack may back-propagate through the model; drop any parameter
    # gradients it left behind so evaluation has no lingering side effects.
    model.zero_grad(set_to_none=True)

    return results

# ================== Main Script ==================
if __name__ == "__main__":
    import traceback

    RESULTS_PATH = 'results_run2.json'
    NUM_EPOCHS = 25

    train_loader, test_loader = get_loaders(batch_size=128)
    results = {}

    def _write_results():
        """Persist whatever metrics have been collected so far."""
        with open(RESULTS_PATH, 'w') as f:
            json.dump(results, f, indent=4)

    # The two phases differ only in whether training is adversarial and in
    # the names used for logs, checkpoints and result keys.
    phases = [
        {
            'title': "Phase 1: Training Clean Model", 'tag': "Clean",
            'label': "clean", 'use_defense': False,
            'checkpoint': 'clean_model.pth', 'key': 'clean_model',
        },
        {
            'title': "Phase 2: Training Defense Model", 'tag': "Defense",
            'label': "defense", 'use_defense': True,
            'checkpoint': 'defense_model.pth', 'key': 'defense_model',
        },
    ]

    model = optimizer = scheduler = None
    try:
        for phase in phases:
            # Release the previous phase's model before building the next one.
            model = optimizer = scheduler = None
            torch.cuda.empty_cache()

            print(f"\n=== {phase['title']} ===")
            model = WideResNet().to(device)
            if torch.cuda.device_count() > 1:
                model = nn.DataParallel(model)

            optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=5e-4)
            # train_model() steps the scheduler once per batch, so the cosine
            # period is given in iterations. T_0=NUM_EPOCHS would restart the
            # learning rate every NUM_EPOCHS batches instead.
            scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer, T_0=max(1, NUM_EPOCHS * len(train_loader)), T_mult=1
            )

            for epoch in range(NUM_EPOCHS):
                loss = train_model(
                    model, train_loader,
                    optimizer, scheduler,
                    scaler, use_defense=phase['use_defense'],
                    epsilon=8/255, alpha=2/255, steps=7
                )
                torch.cuda.empty_cache()
                print(f"[{phase['tag']}] Epoch {epoch+1:02d}/{NUM_EPOCHS} — Loss: {loss:.4f}")

            # Save the weights first so a failure during evaluation does not
            # throw away the training run.
            torch.save(model.state_dict(), phase['checkpoint'])

            print(f"\nEvaluating {phase['label']} model…")
            results[phase['key']] = evaluate(model, test_loader)
            _write_results()

    except Exception as e:
        print(f"\nError: {e}\nSaving partial results…")
        # Keep the full stack trace; the message alone rarely locates the failure.
        traceback.print_exc()
        _write_results()
        # SystemExit does not depend on the site-provided `exit` helper.
        raise SystemExit(1) from e

    # Saving final results
    _write_results()
    print(f"\n=== All done! Results saved to {RESULTS_PATH} ===")


Using device: cuda


  scaler = GradScaler()


Files already downloaded and verified
Files already downloaded and verified

=== Phase 1: Training Clean Model ===


  with autocast():
Training: 100%|██████████| 391/391 [02:16<00:00,  2.86it/s]


[Clean] Epoch 01/25 — Loss: 1.6660


Training: 100%|██████████| 391/391 [02:06<00:00,  3.08it/s]


[Clean] Epoch 02/25 — Loss: 1.1803


Training: 100%|██████████| 391/391 [02:13<00:00,  2.92it/s]


[Clean] Epoch 03/25 — Loss: 0.9849


Training: 100%|██████████| 391/391 [02:08<00:00,  3.05it/s]


[Clean] Epoch 04/25 — Loss: 0.8490


Training: 100%|██████████| 391/391 [02:15<00:00,  2.89it/s]


[Clean] Epoch 05/25 — Loss: 0.7382


Training: 100%|██████████| 391/391 [02:26<00:00,  2.68it/s]


[Clean] Epoch 06/25 — Loss: 0.6498


Training: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]


[Clean] Epoch 07/25 — Loss: 0.5703


Training: 100%|██████████| 391/391 [02:33<00:00,  2.55it/s]


[Clean] Epoch 08/25 — Loss: 0.5257


Training: 100%|██████████| 391/391 [02:33<00:00,  2.55it/s]


[Clean] Epoch 09/25 — Loss: 0.4819


Training: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]


[Clean] Epoch 10/25 — Loss: 0.4347


Training: 100%|██████████| 391/391 [02:33<00:00,  2.55it/s]


[Clean] Epoch 11/25 — Loss: 0.4038


Training: 100%|██████████| 391/391 [02:30<00:00,  2.60it/s]


[Clean] Epoch 12/25 — Loss: 0.3739


Training: 100%|██████████| 391/391 [02:26<00:00,  2.66it/s]


[Clean] Epoch 13/25 — Loss: 0.3415


Training: 100%|██████████| 391/391 [02:26<00:00,  2.67it/s]


[Clean] Epoch 14/25 — Loss: 0.3198


Training: 100%|██████████| 391/391 [02:26<00:00,  2.66it/s]


[Clean] Epoch 15/25 — Loss: 0.2965


Training: 100%|██████████| 391/391 [02:26<00:00,  2.66it/s]


[Clean] Epoch 16/25 — Loss: 0.2731


Training: 100%|██████████| 391/391 [02:26<00:00,  2.66it/s]


[Clean] Epoch 17/25 — Loss: 0.2527


Training: 100%|██████████| 391/391 [02:26<00:00,  2.67it/s]


[Clean] Epoch 18/25 — Loss: 0.2351


Training: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]


[Clean] Epoch 19/25 — Loss: 0.2175


Training: 100%|██████████| 391/391 [02:32<00:00,  2.57it/s]


[Clean] Epoch 20/25 — Loss: 0.2010


Training: 100%|██████████| 391/391 [02:24<00:00,  2.71it/s]


[Clean] Epoch 21/25 — Loss: 0.1942


Training: 100%|██████████| 391/391 [02:07<00:00,  3.08it/s]


[Clean] Epoch 22/25 — Loss: 0.1756


Training: 100%|██████████| 391/391 [02:06<00:00,  3.08it/s]


[Clean] Epoch 23/25 — Loss: 0.1682


Training: 100%|██████████| 391/391 [02:07<00:00,  3.07it/s]


[Clean] Epoch 24/25 — Loss: 0.1522


Training: 100%|██████████| 391/391 [02:07<00:00,  3.07it/s]


[Clean] Epoch 25/25 — Loss: 0.1499

Evaluating clean model…


  with autocast():
  with autocast():
  with autocast():
  with autocast():



=== Phase 2: Training Defense Model ===


  with autocast():
  with autocast():
Training: 100%|██████████| 391/391 [09:52<00:00,  1.52s/it]


[Defense] Epoch 01/25 — Loss: 1.9927


Training: 100%|██████████| 391/391 [09:48<00:00,  1.50s/it]


[Defense] Epoch 02/25 — Loss: 1.7299


Training: 100%|██████████| 391/391 [09:40<00:00,  1.48s/it]


[Defense] Epoch 03/25 — Loss: 1.5906


Training: 100%|██████████| 391/391 [09:39<00:00,  1.48s/it]


[Defense] Epoch 04/25 — Loss: 1.4842


Training: 100%|██████████| 391/391 [09:38<00:00,  1.48s/it]


[Defense] Epoch 05/25 — Loss: 1.4104


Training: 100%|██████████| 391/391 [09:38<00:00,  1.48s/it]


[Defense] Epoch 06/25 — Loss: 1.3321


Training: 100%|██████████| 391/391 [09:38<00:00,  1.48s/it]


[Defense] Epoch 07/25 — Loss: 1.2630


Training: 100%|██████████| 391/391 [09:39<00:00,  1.48s/it]


[Defense] Epoch 08/25 — Loss: 1.1973


Training: 100%|██████████| 391/391 [09:25<00:00,  1.45s/it]


[Defense] Epoch 09/25 — Loss: 1.1563


Training: 100%|██████████| 391/391 [09:22<00:00,  1.44s/it]


[Defense] Epoch 10/25 — Loss: 1.0952


Training: 100%|██████████| 391/391 [09:25<00:00,  1.45s/it]


[Defense] Epoch 11/25 — Loss: 1.0512


Training: 100%|██████████| 391/391 [09:22<00:00,  1.44s/it]


[Defense] Epoch 12/25 — Loss: 1.0216


Training: 100%|██████████| 391/391 [09:21<00:00,  1.44s/it]


[Defense] Epoch 13/25 — Loss: 0.9789


Training: 100%|██████████| 391/391 [09:21<00:00,  1.44s/it]


[Defense] Epoch 14/25 — Loss: 0.9428


Training: 100%|██████████| 391/391 [09:22<00:00,  1.44s/it]


[Defense] Epoch 15/25 — Loss: 0.9049


Training: 100%|██████████| 391/391 [09:22<00:00,  1.44s/it]


[Defense] Epoch 16/25 — Loss: 0.8781


Training: 100%|██████████| 391/391 [09:22<00:00,  1.44s/it]


[Defense] Epoch 17/25 — Loss: 0.8475


Training: 100%|██████████| 391/391 [09:21<00:00,  1.44s/it]


[Defense] Epoch 18/25 — Loss: 0.8182


Training: 100%|██████████| 391/391 [09:23<00:00,  1.44s/it]


[Defense] Epoch 19/25 — Loss: 0.7938


Training: 100%|██████████| 391/391 [09:25<00:00,  1.45s/it]


[Defense] Epoch 20/25 — Loss: 0.7663


Training: 100%|██████████| 391/391 [09:25<00:00,  1.45s/it]


[Defense] Epoch 21/25 — Loss: 0.7397


Training: 100%|██████████| 391/391 [09:22<00:00,  1.44s/it]


[Defense] Epoch 22/25 — Loss: 0.7151


Training: 100%|██████████| 391/391 [09:23<00:00,  1.44s/it]


[Defense] Epoch 23/25 — Loss: 0.6921


Training: 100%|██████████| 391/391 [09:23<00:00,  1.44s/it]


[Defense] Epoch 24/25 — Loss: 0.6638


Training: 100%|██████████| 391/391 [10:37<00:00,  1.63s/it]


[Defense] Epoch 25/25 — Loss: 0.6341

Evaluating defense model…


  with autocast():
  with autocast():
  with autocast():
  with autocast():



=== All done! Results saved to results.json ===
