In [None]:
!pip install adversarial-robustness-toolbox

In [None]:
import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from art.attacks.evasion import ProjectedGradientDescent
from art.estimators.classification import PyTorchClassifier
from torch.utils.data import DataLoader

# =========================================================
# Config
# =========================================================
# Prefer the GPU whenever PyTorch can see one.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Optimisation
BATCH_SIZE = 32
EPOCHS = 10
LR = 1e-4  # Adam learning rate for the purifier

# Geometry / labels
IMG_SIZE = 224     # side of the square center crop
NUM_CLASSES = 100  # ImageNet-100

# Loss weights
LAMBDA_RECON = 1.0   # L1 pixel reconstruction
LAMBDA_FEAT = 0.5    # L1 feature matching on ResNet layer1 / layer3
LAMBDA_CLS = 0.2     # cross-entropy of the frozen classifier on purified images
LAMBDA_RESID = 0.05  # residual magnitude penalty

# Probability that a sample is fed to the purifier clean instead of attacked
# (identity preservation).
CLEAN_PROB = 0.7

print("Using device:", DEVICE)

# =========================================================
# ImageNet normalization
# =========================================================
# Standard per-channel ImageNet RGB mean/std, stored as (1, 3, 1, 1) so they
# broadcast over an (N, 3, H, W) batch.
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).reshape(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).reshape(1, 3, 1, 1)


def normalize_imagenet(x):
    """Standardize image batch ``x`` channel-wise with the ImageNet statistics.

    ``x`` is expected to hold pixel values in [0, 1], as produced in this
    script by ``ToTensor`` and by the purifier's clamp.
    """
    centered = x - IMAGENET_MEAN
    return centered / IMAGENET_STD

# =========================================================
# Dataset
# =========================================================
# Deterministic resize + center crop, then ToTensor (pixels in [0, 1]).
# No normalization here: the purifier and the attack work in raw pixel space,
# and ImageNet normalization is applied only where the classifier is called.
transform_train = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
])

DATA_ROOT = "/kaggle/input/imagenet100"

# ImageFolder turns every immediate sub-directory of DATA_ROOT into a class.
# NOTE(review): confirm DATA_ROOT's children are the class folders themselves,
# not split/shard folders that in turn contain the classes.
train_dataset = torchvision.datasets.ImageFolder(
    root=DATA_ROOT,
    transform=transform_train,
)

# The classifier head and the cross-entropy targets assume exactly
# NUM_CLASSES labels. Fail fast here rather than crashing inside the loss
# (label >= NUM_CLASSES) or silently leaving head outputs unused.
if len(train_dataset.classes) != NUM_CLASSES:
    raise ValueError(
        f"Expected {NUM_CLASSES} class folders under {DATA_ROOT!r}, "
        f"found {len(train_dataset.classes)}."
    )

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    # Pinned host memory only helps host-to-GPU copies; on CPU it is wasted.
    pin_memory=DEVICE.type == "cuda",
)

# =========================================================
# Frozen Classifier
# =========================================================
# Optional path to a ResNet-18 state_dict (saved from this same architecture)
# whose fc layer was trained on the same NUM_CLASSES classes, in the same
# ImageFolder class order, as the training data. Set this before running.
CLASSIFIER_CKPT = None

classifier = torchvision.models.resnet18(weights="IMAGENET1K_V1")
# Replace the pretrained head with a NUM_CLASSES-way one (512 inputs for ResNet-18).
classifier.fc = nn.Linear(classifier.fc.in_features, NUM_CLASSES)

if CLASSIFIER_CKPT is not None:
    state = torch.load(CLASSIFIER_CKPT, map_location="cpu", weights_only=True)
    classifier.load_state_dict(state)
else:
    # Without a checkpoint the new head keeps its random init and, being
    # frozen below, is never trained. Its logits carry no class information,
    # so the classification loss and the PGD attack become meaningless.
    warnings.warn(
        "CLASSIFIER_CKPT is not set: the frozen classifier has a randomly "
        "initialized fc head, so the classification loss and the PGD attack "
        "target meaningless logits.",
        stacklevel=1,
    )

classifier = classifier.to(DEVICE)
classifier.eval()  # inference mode, e.g. BatchNorm uses its running statistics
# Freeze every weight. Gradients can still flow through to the *input*.
classifier.requires_grad_(False)

# =========================================================
# ART PGD (attack only)
# =========================================================
criterion = nn.CrossEntropyLoss()

art_classifier = PyTorchClassifier(
    model=classifier,
    loss=criterion,
    optimizer=None,
    input_shape=(3, IMG_SIZE, IMG_SIZE),
    nb_classes=NUM_CLASSES,
    # Inputs are raw pixels in [0, 1]. Without clip values ART does not clip,
    # and adversarial images could leave the valid pixel range.
    clip_values=(0.0, 1.0),
    # The classifier expects ImageNet-normalized input (the training loop calls
    # it on normalize_imagenet(...)). Let ART apply the same standardization
    # before the model, so the attack sees the model as it is used while eps
    # stays measured in raw pixel units. Shape (3, 1, 1) broadcasts over NCHW.
    preprocessing=(
        IMAGENET_MEAN.squeeze(0).cpu().numpy(),
        IMAGENET_STD.squeeze(0).cpu().numpy(),
    ),
    # ART documents "gpu" / "cpu" for this argument (not torch's "cuda").
    device_type="gpu" if DEVICE.type == "cuda" else "cpu",
)

pgd_attack = ProjectedGradientDescent(
    estimator=art_classifier,
    norm=np.inf,
    eps=8 / 255,
    eps_step=2 / 255,
    max_iter=10,
    batch_size=BATCH_SIZE,
    verbose=False,  # called once per training batch; suppress per-call progress bars
)

# =========================================================
# UNet Purifier (Artifact-free, Residual)
# =========================================================
class PurifierUNet(nn.Module):
    """Three-level U-Net that predicts an additive residual for an RGB batch.

    The output is a 3-channel tensor with the same spatial size as the input.
    The caller adds it to the input rather than using it as the image
    directly. Height and width must be divisible by 8 (three 2x max-pools
    mirrored by three 2x upsamplings), otherwise the skip concatenations
    will not line up.
    """

    def __init__(self):
        super().__init__()

        def conv_stage(c_in, c_out):
            # Two (3x3 conv -> GroupNorm(8 groups) -> ReLU) layers; size preserved.
            layers = []
            for src_channels in (c_in, c_out):
                layers += [
                    nn.Conv2d(src_channels, c_out, 3, padding=1),
                    nn.GroupNorm(8, c_out),
                    nn.ReLU(inplace=True),
                ]
            return nn.Sequential(*layers)

        def upsample_stage(c_in, c_out):
            # Bilinear 2x upsampling followed by a 3x3 conv. Presumably chosen
            # over transposed convolutions to avoid checkerboard artifacts.
            return nn.Sequential(
                nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
                nn.Conv2d(c_in, c_out, 3, padding=1),
                nn.GroupNorm(8, c_out),
                nn.ReLU(inplace=True),
            )

        # Encoder (full, 1/2 and 1/4 resolution).
        self.enc1 = conv_stage(3, 64)
        self.enc2 = conv_stage(64, 128)
        self.enc3 = conv_stage(128, 256)

        self.pool = nn.MaxPool2d(2)

        # Bottleneck at 1/8 resolution.
        self.bottleneck = conv_stage(256, 512)

        # Decoder: each decode stage sees [upsampled, skip] concatenated, which
        # doubles its input channels.
        self.up3 = upsample_stage(512, 256)
        self.dec3 = conv_stage(2 * 256, 256)

        self.up2 = upsample_stage(256, 128)
        self.dec2 = conv_stage(2 * 128, 128)

        self.up1 = upsample_stage(128, 64)
        # Final stage: no normalization and no output activation, so the
        # residual can take any sign and magnitude.
        self.dec1 = nn.Sequential(
            nn.Conv2d(2 * 64, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 3, 3, padding=1),
        )

    def forward(self, x):
        """Return the predicted residual for ``x`` (same shape as ``x``)."""
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool(skip1))
        skip3 = self.enc3(self.pool(skip2))

        h = self.bottleneck(self.pool(skip3))

        decode_path = (
            (self.up3, self.dec3, skip3),
            (self.up2, self.dec2, skip2),
            (self.up1, self.dec1, skip1),
        )
        for upsample, decode, skip in decode_path:
            h = decode(torch.cat([upsample(h), skip], dim=1))

        return h

purifier = PurifierUNet().to(DEVICE)

# =========================================================
# Optimizer & Losses
# =========================================================
# Only the purifier's parameters are optimized; the classifier is frozen.
optimizer = optim.Adam(params=purifier.parameters(), lr=LR)
recon_loss = nn.L1Loss(reduction="mean")          # used for pixels and features
cls_loss = nn.CrossEntropyLoss(reduction="mean")

# =========================================================
# Multi-layer perceptual features (ResNet)
# =========================================================
# Both extractors share the frozen classifier's own modules (no copies), so
# they use exactly the classifier's weights. feat_l1 stops after layer1, and
# feat_l3 continues through layer2 and layer3.
_stem_to_layer1 = (
    classifier.conv1,
    classifier.bn1,
    classifier.relu,
    classifier.maxpool,
    classifier.layer1,
)
feat_l1 = nn.Sequential(*_stem_to_layer1).eval()
feat_l3 = nn.Sequential(*_stem_to_layer1, classifier.layer2, classifier.layer3).eval()

# =========================================================
# Training Loop
# =========================================================
def _backbone_features(x):
    """Return the frozen ResNet's (layer1, layer3) activations for pixels ``x``.

    Gives the same result as
    ``(feat_l1(normalize_imagenet(x)), feat_l3(normalize_imagenet(x)))``,
    but the shared stem + layer1 run once instead of twice.
    """
    f1 = feat_l1(normalize_imagenet(x))
    f3 = classifier.layer3(classifier.layer2(f1))
    return f1, f3


def _classifier_head(f3):
    """Finish the frozen ResNet forward pass from layer3 activations to logits.

    ``_classifier_head(_backbone_features(x)[1])`` matches
    ``classifier(normalize_imagenet(x))``: torchvision's ResNet applies
    layer4, avgpool, flatten and fc after layer3. Reusing the layer3
    activations means the purified batch goes through the backbone once.
    """
    out = classifier.layer4(f3)
    out = torch.flatten(classifier.avgpool(out), 1)
    return classifier.fc(out)


for epoch in range(EPOCHS):
    purifier.train()
    total_loss = 0.0

    for x_clean_cpu, y_cpu in train_loader:
        # --------------------
        # Generate adversarial
        # --------------------
        # ART works on NumPy arrays, so attack the CPU batch directly instead of
        # copying it to the device and back. No labels are passed, so ART picks
        # the labels to attack itself (per its docs, the model's predictions).
        x_adv = pgd_attack.generate(x_clean_cpu.numpy())

        x_clean = x_clean_cpu.to(DEVICE, non_blocking=True)
        y = y_cpu.to(DEVICE, non_blocking=True)
        x_adv = torch.from_numpy(x_adv).to(DEVICE, non_blocking=True)

        # --------------------
        # Mix clean & adv
        # --------------------
        # Per-sample coin flip: with probability CLEAN_PROB the purifier sees
        # the clean image, so it also learns to leave clean inputs alone.
        mask = torch.rand(x_clean.size(0), 1, 1, 1, device=DEVICE) < CLEAN_PROB
        x_in = torch.where(mask, x_clean, x_adv)

        # --------------------
        # Residual purification
        # --------------------
        residual = purifier(x_in)
        x_pur = torch.clamp(x_in + residual, 0, 1)

        # --------------------
        # Losses
        # --------------------
        loss_recon = recon_loss(x_pur, x_clean)

        # Perceptual: the clean-image targets are constants, so build no graph.
        with torch.no_grad():
            f1_clean, f3_clean = _backbone_features(x_clean)
        f1_pur, f3_pur = _backbone_features(x_pur)

        loss_feat = (
            recon_loss(f1_pur, f1_clean) +
            0.5 * recon_loss(f3_pur, f3_clean)
        )

        # Classification through the frozen classifier. This must NOT run under
        # torch.no_grad(): that detaches the logits, and loss_cls would then add
        # nothing to the gradient. The classifier's own parameters have
        # requires_grad=False, so only the purifier is updated.
        logits = _classifier_head(f3_pur)
        loss_cls = cls_loss(logits, y)

        # Residual magnitude regularization
        loss_resid = residual.abs().mean()

        # Total
        loss = (
            LAMBDA_RECON * loss_recon +
            LAMBDA_FEAT  * loss_feat +
            LAMBDA_CLS   * loss_cls +
            LAMBDA_RESID * loss_resid
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"[Epoch {epoch+1}/{EPOCHS}] Loss: {total_loss / len(train_loader):.4f}")

# =========================================================
# Save
# =========================================================
# Persist only the purifier's weights (its state_dict, not the module object).
PURIFIER_CKPT_PATH = "purifier_unet_imagenet_corrected.pth"
torch.save(obj=purifier.state_dict(), f=PURIFIER_CKPT_PATH)
print("Corrected ImageNet purifier saved.")