In [None]:
!pip install adversarial-robustness-toolbox

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import numpy as np
from art.attacks.evasion import ProjectedGradientDescent
from art.estimators.classification import PyTorchClassifier

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 128   # samples per batch (data loader and PGD attack)
EPOCHS = 10        # passes over the training set when fitting the purifier
LR = 1e-3          # Adam learning rate for the purifier
LAMBDA_CLS = 0.5   # weight of the classification term in the purifier loss

print("Using device:", DEVICE)

# Per-channel statistics used to normalize images before they are fed to the
# classifier. Shaped (1, 3, 1, 1) so they broadcast over (N, 3, H, W) batches.
CIFAR_MEAN = torch.tensor([0.4914, 0.4822, 0.4465]).view(1,3,1,1).to(DEVICE)
CIFAR_STD  = torch.tensor([0.2023, 0.1994, 0.2010]).view(1,3,1,1).to(DEVICE)

def normalize_cifar10(x):
    """Standardize a batch of images with the per-channel mean/std above.

    Args:
        x: Tensor broadcastable against shape (1, 3, 1, 1), i.e. with the
            channel axis at dim 1 (presumably (N, 3, H, W) images in [0, 1]).

    Returns:
        ``(x - CIFAR_MEAN) / CIFAR_STD`` computed on ``x``'s device.
    """
    # Follow the input's device so that CPU tensors can be normalized even
    # when DEVICE is a GPU; ``.to`` is a no-op when the devices already match.
    mean = CIFAR_MEAN.to(x.device)
    std = CIFAR_STD.to(x.device)
    return (x - mean) / std

# Only convert images to tensors here; mean/std normalization is applied
# separately (see normalize_cifar10) right before the classifier.
transform = transforms.Compose([transforms.ToTensor()])

# CIFAR-10 training split, downloaded into ./data when not already present.
train_dataset = torchvision.datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of BATCH_SIZE samples, loaded by two worker processes.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)

def make_resnet18_cifar10():
    """Build a randomly initialised torchvision ResNet-18 with a 10-way head.

    The stock stem is replaced by a 3x3, stride-1, bias-free convolution and
    the stem max-pool is replaced by an identity, so the stem does not
    downsample its input. The final fully connected layer is swapped for one
    with 10 outputs.

    Returns:
        The modified ``torchvision.models.ResNet`` instance (no pretrained
        weights are loaded).
    """
    net = torchvision.models.resnet18(weights=None)

    # Stem: 3 input channels -> 64 feature maps, spatial size preserved.
    net.conv1 = nn.Conv2d(
        in_channels=3,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )
    net.maxpool = nn.Identity()

    # Classification head: reuse the existing feature width, emit 10 logits.
    feature_dim = net.fc.in_features
    net.fc = nn.Linear(feature_dim, 10)
    return net

# Baseline classifier: load the pre-trained weights and keep it in eval mode.
classifier = make_resnet18_cifar10().to(DEVICE)
classifier.load_state_dict(
    torch.load("/kaggle/input/baseline-model/pytorch/default/1/temp_baseline_resnet_cifar.pth", map_location=DEVICE)
)
classifier.eval()

criterion = nn.CrossEntropyLoss()


class _NormalizedClassifier(nn.Module):
    """Apply ``normalize_cifar10`` to the input, then run the wrapped model.

    The purifier's classification loss feeds the classifier with
    ``normalize_cifar10(x)``. Wrapping the classifier this way makes the
    attack operate on the same input -> logits function, instead of attacking
    the classifier on un-normalized pixels.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        return self.model(normalize_cifar10(x))


# ART estimator used only for crafting adversarial examples (no optimizer:
# the classifier is never trained through ART).
art_classifier = PyTorchClassifier(
    model=_NormalizedClassifier(classifier),
    loss=criterion,
    optimizer=None,
    input_shape=(3,32,32),
    nb_classes=10,
    # Keep adversarial examples inside the valid pixel range of the
    # ToTensor() images; without this the perturbed pixels are not clipped.
    clip_values=(0.0, 1.0),
    # ART documents "gpu" / "cpu" as the accepted values.
    device_type="gpu" if DEVICE.type == "cuda" else "cpu"
)

# PGD attack: perturbation budget 8/255, step size 2/255, 10 iterations
# (norm and other settings are left at ART's defaults).
pgd_attack = ProjectedGradientDescent(
    estimator=art_classifier,
    eps=8/255,
    eps_step=2/255,
    max_iter=10,
    batch_size=BATCH_SIZE
)

class PurifierUNet(nn.Module):
    """Small two-level U-Net mapping a 3-channel image to a 3-channel image.

    Layout: two encoder stages (64 and 128 channels), each followed by a 2x2
    max-pool; a 256-channel bottleneck; two decoder stages that upsample with
    a stride-2 transposed convolution and concatenate the matching encoder
    features (skip connections). The last layer is a sigmoid, so every output
    value lies in [0, 1].

    The input height and width must survive two halvings and two doublings
    unchanged (i.e. be divisible by 4), otherwise the skip concatenations
    fail on mismatched spatial sizes.
    """

    @staticmethod
    def _conv_pair(in_ch, mid_ch, out_ch, final_act):
        """Return conv3x3 -> ReLU -> conv3x3 -> ``final_act`` as a Sequential."""
        return nn.Sequential(
            nn.Conv2d(in_ch, mid_ch, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1),
            final_act,
        )

    def __init__(self):
        super().__init__()
        pair = self._conv_pair

        # Encoder (contracting path).
        self.enc1 = pair(3, 64, 64, nn.ReLU())
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.enc2 = pair(64, 128, 128, nn.ReLU())
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        # Bottleneck at a quarter of the input resolution.
        self.bottleneck = pair(128, 256, 256, nn.ReLU())

        # Decoder (expanding path). Each stage's first conv takes twice the
        # upsampled channel count because encoder features are concatenated.
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = pair(256, 128, 128, nn.ReLU())
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = pair(128, 64, 3, nn.Sigmoid())

    def forward(self, x):
        """Run the U-Net on ``x`` and return the sigmoid-activated output."""
        # Contracting path; keep the pre-pool features for the skips.
        skip_full = self.enc1(x)
        skip_half = self.enc2(self.pool1(skip_full))

        deepest = self.bottleneck(self.pool2(skip_half))

        # Expanding path: upsample, concatenate the skip on the channel axis,
        # then refine.
        merged_half = torch.cat([self.up2(deepest), skip_half], dim=1)
        refined_half = self.dec2(merged_half)

        merged_full = torch.cat([self.up1(refined_half), skip_full], dim=1)
        return self.dec1(merged_full)

# ---------------------------------------------------------------------------
# Purifier training
# ---------------------------------------------------------------------------
# The purifier learns to map PGD-perturbed images back to their clean
# versions. The loss combines pixel reconstruction (MSE against the clean
# image) with the classifier's cross-entropy on the purified image.
purifier = PurifierUNet().to(DEVICE)

optimizer = optim.Adam(purifier.parameters(), lr=LR)
recon_loss = nn.MSELoss()
cls_loss = nn.CrossEntropyLoss()

# Only the purifier is optimized. Freezing the classifier's parameters stops
# every backward() from computing and accumulating gradients that are never
# used or cleared. Gradients still flow *through* the classifier to the
# purifier output (and to the attack's inputs).
# NOTE(review): assumes no later cell fine-tunes `classifier` — confirm.
for param in classifier.parameters():
    param.requires_grad_(False)

for epoch in range(EPOCHS):
    purifier.train()
    running_loss = 0.0

    for x_clean, y in train_loader:
        # ART consumes NumPy arrays, so craft the adversarial batch straight
        # from the CPU tensor yielded by the loader instead of moving it to
        # DEVICE and back. No labels are passed to generate().
        x_adv = pgd_attack.generate(x_clean.numpy())
        x_adv = torch.as_tensor(x_adv, dtype=torch.float32, device=DEVICE)

        x_clean = x_clean.to(DEVICE)
        y = y.to(DEVICE)

        x_pur = purifier(x_adv)

        # Reconstruction: purified image should match the clean image.
        loss_recon = recon_loss(x_pur, x_clean)
        # Classification: purified image should be classified as its label.
        logits = classifier(normalize_cifar10(x_pur))
        loss_cls = cls_loss(logits, y)

        loss = loss_recon + LAMBDA_CLS * loss_cls

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean per-batch loss over the epoch.
    print(f"[Epoch {epoch+1}/{EPOCHS}] Loss: {running_loss/len(train_loader):.4f}")

torch.save(purifier.state_dict(), "purifier_unet_cifar10.pth")
print("Purifier model saved.")