In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import models, transforms, datasets
from torch.utils.data import DataLoader
from tqdm import tqdm
import io
from PIL import Image

# -------------------------
# SETTINGS
# -------------------------
ADV_EPSILON = 0.1   # FGSM step size, in normalized-tensor units (applied after Normalize), not raw [0, 1] pixels
ADV_EPOCHS = 5
SAVE_PATH = "cifake_resnet18_adv_trained.pth"
SEED = 42

# Seed torch's RNG (all devices) so shuffling and random augmentations are
# reproducible across runs. cuDNN kernels may still be non-deterministic.
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------------
# PATHS
# -------------------------
TRAIN_DIR = "/content/cifake_data/train"
TEST_DIR = "/content/cifake_data/test"

# -------------------------
# DATA TRANSFORMATIONS
# -------------------------
# (per-channel means, per-channel stds), unpacked into transforms.Normalize(*stats).
# The values match the usual ImageNet normalization constants.
stats = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

def simulate_jpeg_compression(image_pil, quality=10):
    """Round-trip a PIL image through an in-memory JPEG encode/decode.

    Args:
        image_pil: PIL image. "RGB" and "L" images are encoded as-is; any
            other mode (e.g. RGBA, P) is converted to RGB first, because
            Pillow's JPEG writer rejects modes such as RGBA and P.
        quality: JPEG quality passed to ``Image.save`` (lower = more artifacts).

    Returns:
        The decoded (compressed) image, fully loaded into memory.
    """
    if image_pil.mode not in ("RGB", "L"):
        image_pil = image_pil.convert("RGB")

    buffer = io.BytesIO()
    image_pil.save(buffer, "JPEG", quality=quality)
    buffer.seek(0)

    compressed = Image.open(buffer)
    # Image.open is lazy; decode now so the pixels no longer depend on the buffer.
    compressed.load()
    return compressed

def _jpeg_q30(image):
    # Fixed-quality (30) JPEG round-trip applied to every training image.
    return simulate_jpeg_compression(image, quality=30)


# Training pipeline: JPEG degradation -> resize -> random flip -> tensor -> normalize.
train_transform = transforms.Compose(
    [
        transforms.Lambda(_jpeg_q30),
        transforms.Resize((32, 32)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(*stats),
    ]
)

# Evaluation pipeline: same resize/normalize, but no JPEG degradation or flipping.
test_transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(*stats),
    ]
)

# -------------------------
# DATA LOADERS
# -------------------------
# ImageFolder assigns class indices in sorted folder-name order (see .class_to_idx).
train_set = datasets.ImageFolder(root=TRAIN_DIR, transform=train_transform)
test_set = datasets.ImageFolder(root=TEST_DIR, transform=test_transform)

# Both splits must map the same folder names to the same label indices,
# otherwise evaluation labels silently disagree with training labels.
assert train_set.class_to_idx == test_set.class_to_idx, (
    f"class mismatch: {train_set.class_to_idx} vs {test_set.class_to_idx}"
)

# Pinned host memory only speeds up host->GPU copies; on a CPU-only runtime it
# is useless and PyTorch warns about it.
_pin = device.type == "cuda"
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2, pin_memory=_pin)
test_loader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=2, pin_memory=_pin)

# -------------------------
# LOAD BASE MODEL (Phase 1 weights)
# -------------------------
model = models.resnet18()
model.fc = nn.Linear(model.fc.in_features, 2)  # 2-way output head

# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state_dict needs (no arbitrary code execution on load).
model.load_state_dict(
    torch.load("cifake_resnet18_latest.pth", map_location=device, weights_only=True)
)
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)

print("Loaded base model for adversarial training")

# -------------------------
# FGSM FUNCTION
# -------------------------
def fgsm_attack(image, epsilon, grad, clip_min=None, clip_max=None):
    """Fast Gradient Sign Method: step ``epsilon`` along sign(grad), then clip.

    Args:
        image: input batch the gradient was taken with respect to.
        epsilon: L-infinity step size, in the same units as ``image``.
        grad: gradient of the loss w.r.t. ``image`` (same shape as ``image``).
        clip_min, clip_max: optional lower/upper bounds, as scalars or tensors
            broadcastable to ``image`` (e.g. per-channel bounds of shape
            (C, 1, 1)). When omitted, the batch's own global min/max are used,
            which only approximates the valid range of normalized images.

    Returns:
        The adversarial batch, with no autograd history.
    """
    with torch.no_grad():
        adv_image = image + epsilon * grad.sign()
        if clip_min is None:
            lower = image.min()
        else:
            lower = torch.as_tensor(clip_min, dtype=image.dtype, device=image.device)
        if clip_max is None:
            upper = image.max()
        else:
            upper = torch.as_tensor(clip_max, dtype=image.dtype, device=image.device)
        return torch.clamp(adv_image, lower, upper)

# -------------------------
# ADVERSARIAL TRAIN LOOP
# -------------------------
def adversarial_train_epoch(model, loader, epsilon):
    """Run one epoch of FGSM adversarial training.

    For every batch, an FGSM copy is generated against the current model, and
    the optimizer takes one step on the concatenation of the clean and
    adversarial images (both with the original labels).

    Relies on the notebook-level ``device``, ``criterion``, ``optimizer`` and
    ``fgsm_attack``.

    Args:
        model: classifier being trained (switched to train mode).
        loader: iterable of (images, labels) batches.
        epsilon: FGSM step size, in the units of the loader's image tensors.

    Returns:
        (mean loss per batch, accuracy in percent over the combined
        clean + adversarial samples).
    """
    model.train()

    running_loss = 0.0
    correct = 0
    total = 0

    for imgs, labels in tqdm(loader, desc="Adv Training"):
        imgs, labels = imgs.to(device), labels.to(device)

        # ---------- STEP 1: create adversarial images ----------
        # Only d(loss)/d(input) is needed, so use autograd.grad instead of
        # loss.backward(): parameter .grad buffers are neither computed nor touched.
        # NOTE: this forward pass runs in train mode, so BatchNorm running
        # statistics are also updated by the clean-only batch.
        imgs.requires_grad_(True)
        attack_loss = criterion(model(imgs), labels)
        (input_grad,) = torch.autograd.grad(attack_loss, imgs)
        imgs = imgs.detach()

        adv_imgs = fgsm_attack(imgs, epsilon, input_grad).detach()

        # ---------- STEP 2: combine clean + adv ----------
        combined_imgs = torch.cat([imgs, adv_imgs])
        combined_labels = torch.cat([labels, labels])

        # ---------- STEP 3: train on combined batch ----------
        optimizer.zero_grad()

        preds = model(combined_imgs)
        adv_loss = criterion(preds, combined_labels)

        adv_loss.backward()
        optimizer.step()

        running_loss += adv_loss.item()

        _, predicted = preds.max(1)
        total += combined_labels.size(0)
        correct += predicted.eq(combined_labels).sum().item()

    return running_loss / len(loader), 100 * correct / total


# -------------------------
# EVALUATION FUNCTION
# -------------------------
def evaluate(model, loader, device=None):
    """Return top-1 accuracy (percent) of ``model`` on ``loader``.

    Args:
        model: classifier; switched to eval mode (and left in it).
        loader: iterable of (images, labels) batches.
        device: where each batch is moved. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Raises:
        ValueError: if the loader yields no samples.
    """
    if device is None:
        # Follow the model instead of a notebook-level global.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    correct, total = 0, 0

    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)

            outputs = model(imgs)
            _, preds = outputs.max(1)

            total += labels.size(0)
            correct += preds.eq(labels).sum().item()

    if total == 0:
        raise ValueError("evaluate: loader yielded no samples")
    return 100 * correct / total


# -------------------------
# MAIN TRAINING LOOP
# -------------------------
# Each epoch: one adversarial-training pass over train_loader, then clean
# accuracy on test_loader (the test split doubles as the validation set).
# Only the weights after the final epoch are saved.
print("\nStarting Adversarial Training...\n")

for epoch in range(ADV_EPOCHS):
    loss, train_acc = adversarial_train_epoch(model, train_loader, ADV_EPSILON)
    val_acc = evaluate(model, test_loader)
    print(
        f"Epoch [{epoch + 1}/{ADV_EPOCHS}] | Loss: {loss:.4f} | "
        f"Train Acc: {train_acc:.2f}% | Val Acc: {val_acc:.2f}%"
    )

# -------------------------
# SAVE ROBUST MODEL
# -------------------------
torch.save(model.state_dict(), SAVE_PATH)
print(f"\n✅ Adversarially trained model saved → {SAVE_PATH}")

In [None]:
# =========================================
# PHASE 3 TESTING — ROBUSTNESS CHECK
# =========================================

import torch
import torch.nn.functional as F
from torchvision import models
import torch.nn as nn
from tqdm import tqdm

# Robustness-check configuration. NOTE: this cell still reuses `test_loader`
# created in the training cell.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

MODEL_PATH = "cifake_resnet18_adv_trained.pth"  # weights written by the adversarial-training cell
EPSILON = 0.1  # FGSM step size (normalized-tensor units), same value as used in training


# -------------------------
# LOAD ROBUST MODEL
# -------------------------
model = models.resnet18()
model.fc = nn.Linear(model.fc.in_features, 2)  # same 2-way head as in training

# weights_only=True: loading a state_dict needs no arbitrary unpickling.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device, weights_only=True))
model.to(device)
model.eval()

print("Robust model loaded ✅")


# -------------------------
# FGSM ATTACK
# -------------------------
def fgsm_attack(image, epsilon, grad, clip_min=None, clip_max=None):
    """Fast Gradient Sign Method: step ``epsilon`` along sign(grad), then clip.

    Redefined in this cell so the robustness check does not rely on the
    training cell's definition.

    Args:
        image: input batch the gradient was taken with respect to.
        epsilon: L-infinity step size, in the same units as ``image``.
        grad: gradient of the loss w.r.t. ``image`` (same shape as ``image``).
        clip_min, clip_max: optional lower/upper bounds, as scalars or tensors
            broadcastable to ``image`` (e.g. per-channel bounds of shape
            (C, 1, 1)). When omitted, the batch's own global min/max are used,
            which only approximates the valid range of normalized images.

    Returns:
        The adversarial batch, with no autograd history.
    """
    with torch.no_grad():
        adv_image = image + epsilon * grad.sign()
        if clip_min is None:
            lower = image.min()
        else:
            lower = torch.as_tensor(clip_min, dtype=image.dtype, device=image.device)
        if clip_max is None:
            upper = image.max()
        else:
            upper = torch.as_tensor(clip_max, dtype=image.dtype, device=image.device)
        return torch.clamp(adv_image, lower, upper)


# -------------------------
# CLEAN ACCURACY
# -------------------------
def clean_accuracy(model, loader, device=None):
    """Top-1 accuracy (percent) of ``model`` on unperturbed inputs.

    Args:
        model: classifier; switched to eval mode (and left in it).
        loader: iterable of (images, labels) batches.
        device: where each batch is moved. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Raises:
        ValueError: if the loader yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    # Set eval mode explicitly instead of trusting the caller: in train mode,
    # BatchNorm layers would use batch statistics and update their running
    # averages on the test data.
    model.eval()
    correct, total = 0, 0

    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)

            preds = model(imgs).argmax(1)

            total += labels.size(0)
            correct += (preds == labels).sum().item()

    if total == 0:
        raise ValueError("clean_accuracy: loader yielded no samples")
    return 100 * correct / total


# -------------------------
# ADVERSARIAL ACCURACY
# -------------------------
def adversarial_accuracy(model, loader, epsilon, fake_label=0, device=None):
    """Accuracy under a white-box FGSM attack, plus the fake-to-real evasion rate.

    Each batch is perturbed with one FGSM step computed against ``model`` itself
    (via the notebook-level ``fgsm_attack``), then re-classified.

    Args:
        model: classifier under attack; switched to eval mode (and left in it).
        loader: iterable of (images, labels) batches.
        epsilon: FGSM step size, in the units of the loader's image tensors.
        fake_label: class index treated as "fake". NOTE(review): the default 0
            assumes the fake folder sorts first in ImageFolder's alphabetical
            ordering; confirm against ``test_set.class_to_idx``.
        device: where each batch is moved. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        (adversarial accuracy %, evasion rate %). The evasion rate is the share
        of fake-labelled samples predicted as another class after the attack.
        It includes fakes that were already misclassified without the attack,
        and is NaN when the loader contains no fake samples.

    Raises:
        ValueError: if the loader yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    # Eval mode so BatchNorm uses running stats and the attack passes do not
    # modify them.
    model.eval()

    correct, total = 0, 0
    flipped = 0
    fake_count = 0

    for imgs, labels in tqdm(loader, desc="Testing Attack"):
        imgs, labels = imgs.to(device), labels.to(device)

        # Gradient of the loss w.r.t. the inputs only; parameter grads are untouched.
        imgs.requires_grad_(True)
        loss = F.cross_entropy(model(imgs), labels)
        (input_grad,) = torch.autograd.grad(loss, imgs)

        adv_imgs = fgsm_attack(imgs.detach(), epsilon, input_grad)

        with torch.no_grad():
            preds = model(adv_imgs).argmax(1)

            total += labels.size(0)
            correct += (preds == labels).sum().item()

            fakes = labels == fake_label
            fake_count += fakes.sum().item()
            flipped += ((preds != fake_label) & fakes).sum().item()

    if total == 0:
        raise ValueError("adversarial_accuracy: loader yielded no samples")

    adv_acc = 100 * correct / total
    evasion_rate = 100 * flipped / fake_count if fake_count else float("nan")

    return adv_acc, evasion_rate


# =========================================
# RUN TESTS
# =========================================
# NOTE: `test_loader` is created in the training cell, which must run first.
clean_acc = clean_accuracy(model, test_loader)
adv_acc, evasion = adversarial_accuracy(model, test_loader, EPSILON)

# One print call; sep="\n" reproduces four consecutive print() lines.
print(
    "\n========= PHASE 3 RESULTS =========",
    f"Clean Accuracy: {clean_acc:.2f}%",
    f"Adversarial Accuracy (FGSM ε={EPSILON}): {adv_acc:.2f}%",
    f"Evasion Rate (Fake → Real): {evasion:.2f}%",
    sep="\n",
)