In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tqdm import tqdm
import numpy as np
import os
from scipy.linalg import sqrtm
from torchvision.models import inception_v3, Inception_V3_Weights, resnet18, ResNet18_Weights
from sklearn.metrics import accuracy_score

# Configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DATASET = "cifar10"  # "mnist" or "cifar10"
LATENT_DIM = 128
IMG_SIZE = 32  # Inception v3 requires 299x299
CHANNELS = 3  # Force 3 channels for all datasets
BATCH_SIZE = 1024
EPOCHS = 100
EPSILON = 0.1
LR = 2e-4
METRIC_FREQ = 5

# Create directories
os.makedirs("images", exist_ok=True)
os.makedirs("metrics", exist_ok=True)

# Data Loading with MNIST->RGB conversion
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*CHANNELS, [0.5]*CHANNELS)
])

if DATASET == "cifar10":
    dataset = datasets.CIFAR10(
        root='./data',
        train=True,
        download=True,
        transform=transform
    )
elif DATASET == "mnist":
    dataset = datasets.MNIST(
        root='./data',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            transforms.Grayscale(num_output_channels=3),  # Convert to 3 channels
            transforms.ToTensor(),
            transforms.Normalize([0.5]*3, [0.5]*3)
        ])
    )

dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

# Model Architectures
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(LATENT_DIM, 512 * 4 * 4),
            nn.Unflatten(1, (512, 4, 4)),
            nn.ConvTranspose2d(512, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, CHANNELS, 4, 2, 1),
            nn.Tanh()
        )

    def forward(self, z):
        return self.net(z)

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(CHANNELS, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            nn.AdaptiveAvgPool2d((4, 4)),
            nn.Flatten(),
            nn.Linear(512*4*4, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        return self.net(img)

class Attacker(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(CHANNELS, 32, 3, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, CHANNELS, 3, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        return self.net(x) * EPSILON

# Initialize models
generator = Generator().to(DEVICE)
discriminator = Discriminator().to(DEVICE)
attacker = Attacker().to(DEVICE)

# Optimizers
opt_G = optim.Adam(generator.parameters(), lr=LR, betas=(0.5, 0.999))
opt_D = optim.Adam(discriminator.parameters(), lr=LR, betas=(0.5, 0.999))
opt_A = optim.Adam(attacker.parameters(), lr=LR, betas=(0.5, 0.999))

# AMP Scaler
scaler = torch.amp.GradScaler('cuda')

# Metrics setup
inception_model = resnet18(weights=ResNet18_Weights).to(DEVICE)
inception_model.aux_logits = False
inception_model.eval()

def calculate_fid(real_imgs, fake_imgs, batch_size=100):
    """Calculate Frechet Inception Distance"""
    # Normalize to [0,1]
    real_imgs = (real_imgs + 1) * 0.5
    fake_imgs = (fake_imgs + 1) * 0.5

    # Normalize with Inception v3 stats
    mean = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).view(1, 3, 1, 1)

    real_imgs = (real_imgs - mean) / std
    fake_imgs = (fake_imgs - mean) / std

    real_features, fake_features = [], []

    with torch.no_grad():
        # Process real images
        for i in range(0, len(real_imgs), batch_size):
            batch = real_imgs[i:i+batch_size]
            features = inception_model(batch)
            real_features.append(features.cpu().numpy())

        # Process fake images
        for i in range(0, len(fake_imgs), batch_size):
            batch = fake_imgs[i:i+batch_size]
            features = inception_model(batch)
            fake_features.append(features.cpu().numpy())

    # Calculate statistics
    real_features = np.concatenate(real_features, axis=0)
    fake_features = np.concatenate(fake_features, axis=0)

    mu_real, sigma_real = np.mean(real_features, axis=0), np.cov(real_features, rowvar=False)
    mu_fake, sigma_fake = np.mean(fake_features, axis=0), np.cov(fake_features, rowvar=False)

    # FID calculation
    ssdiff = np.sum((mu_real - mu_fake)**2)
    covmean = sqrtm(sigma_real.dot(sigma_fake))
    fid = ssdiff + np.trace(sigma_real + sigma_fake - 2*covmean.real)

    return fid

def pgd_attack(model, images, epsilon=0.1, alpha=0.01, iters=10):
    """PGD attack validation"""
    orig_images = images.detach().clone()
    delta = torch.zeros_like(images).uniform_(-epsilon, epsilon)

    for _ in range(iters):
        delta.requires_grad = True
        outputs = model(orig_images + delta)
        loss = -torch.mean(outputs)

        model.zero_grad()
        loss.backward()

        data_grad = delta.grad.detach()
        delta = delta.detach() + alpha * data_grad.sign()
        delta = torch.clamp(delta, -epsilon, epsilon)
        delta = torch.clamp(orig_images + delta, -1, 1) - orig_images

    return orig_images + delta

def save_samples(epoch, generator, n_samples=100):
    """Save generated images"""
    with torch.no_grad():
        z = torch.randn(n_samples, LATENT_DIM, device=DEVICE)
        gen_imgs = generator(z)
        gen_imgs = gen_imgs * 0.5 + 0.5  # Denormalize
        save_image(gen_imgs, f"images/epoch_{epoch+1}.png", nrow=10)

# Precompute real samples for FID
real_samples = torch.stack([dataset[i][0] for i in range(5000)]).to(DEVICE)
metrics = {'fid': [], 'attack_success': []}

# Training loop
for epoch in range(EPOCHS):
    generator.train()
    discriminator.train()
    attacker.train()

    loop = tqdm(dataloader, desc=f"Epoch {epoch+1}/{EPOCHS}", leave=False)
    for real_imgs, _ in loop:
        real_imgs = real_imgs.to(DEVICE)
        batch_size = real_imgs.size(0)

        # Train Attacker
        with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
            z = torch.randn(batch_size, LATENT_DIM, device=DEVICE)
            fake_imgs = generator(z)
            pert_real = attacker(real_imgs)
            pert_fake = attacker(fake_imgs.detach())

            d_real = discriminator(real_imgs + pert_real)
            d_fake = discriminator(fake_imgs + pert_fake)
            loss_A = -(torch.log(d_real).mean() + torch.log(1 - d_fake).mean())

        opt_A.zero_grad()
        scaler.scale(loss_A).backward()
        scaler.step(opt_A)

        # Train Discriminator
        with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
            z = torch.randn(batch_size, LATENT_DIM, device=DEVICE)
            fake_imgs = generator(z)

            with torch.no_grad():
                pert_real = attacker(real_imgs)
                pert_fake = attacker(fake_imgs)

            d_real = discriminator(real_imgs + pert_real)
            d_fake = discriminator(fake_imgs + pert_fake)
            loss_D = -(torch.log(d_real).mean() + torch.log(1 - d_fake).mean())

        opt_D.zero_grad()
        scaler.scale(loss_D).backward()
        scaler.step(opt_D)

        # Train Generator
        with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
            z = torch.randn(batch_size, LATENT_DIM, device=DEVICE)
            fake_imgs = generator(z)

            with torch.no_grad():
                pert_fake = attacker(fake_imgs)

            d_fake = discriminator(fake_imgs + pert_fake)
            loss_G = -torch.log(d_fake).mean()

        opt_G.zero_grad()
        scaler.scale(loss_G).backward()
        scaler.step(opt_G)

        scaler.update()
        loop.set_postfix({
            'D': f"{loss_D.item():.4f}",
            'G': f"{loss_G.item():.4f}",
            'A': f"{loss_A.item():.4f}"
        })

    # Validation and metrics
    if epoch % METRIC_FREQ == 0:
        generator.eval()
        with torch.no_grad():
            z = torch.randn(5000, LATENT_DIM, device=DEVICE)
            fake_imgs = generator(z)

        # Calculate FID
        fid_score = calculate_fid(real_samples, fake_imgs)
        metrics['fid'].append(fid_score)

        # Attack validation
        test_loader = DataLoader(dataset, batch_size=100, shuffle=True)
        real_batch, _ = next(iter(test_loader))
        real_batch = real_batch.to(DEVICE)
        adv_real = pgd_attack(discriminator, real_batch)

        with torch.no_grad():
            pred_real = discriminator(adv_real).cpu().numpy()

        attack_success = 1 - accuracy_score(np.ones(100), (pred_real > 0.5).astype(int))
        metrics['attack_success'].append(attack_success)

        # Save progress
        np.save("metrics/metrics.npy", metrics)
        save_samples(epoch, generator)

        print(f"\nEpoch {epoch+1}:")
        print(f"FID: {fid_score:.2f} | Attack Success: {attack_success*100:.1f}%")

print("Training complete!")

100%|██████████| 170M/170M [00:07<00:00, 22.5MB/s]
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 191MB/s]



Epoch 1:
FID: 10632.96 | Attack Success: 0.0%





Epoch 6:
FID: 16036.78 | Attack Success: 0.0%





Epoch 11:
FID: 10956.68 | Attack Success: 0.0%





Epoch 16:
FID: 13331.34 | Attack Success: 0.0%





Epoch 21:
FID: 11709.27 | Attack Success: 0.0%





Epoch 26:
FID: 14220.30 | Attack Success: 0.0%





Epoch 31:
FID: 15596.25 | Attack Success: 0.0%





Epoch 36:
FID: 16567.50 | Attack Success: 1.0%





Epoch 41:
FID: 14926.65 | Attack Success: 0.0%





Epoch 46:
FID: 16451.21 | Attack Success: 1.0%





Epoch 51:
FID: 16274.94 | Attack Success: 2.0%





Epoch 56:
FID: 16063.77 | Attack Success: 7.0%





Epoch 61:
FID: 15328.15 | Attack Success: 15.0%





Epoch 66:
FID: 13961.23 | Attack Success: 17.0%





Epoch 71:
FID: 18245.22 | Attack Success: 25.0%





Epoch 76:
FID: 17375.49 | Attack Success: 27.0%





Epoch 81:
FID: 12786.17 | Attack Success: 29.0%





Epoch 86:
FID: 12577.43 | Attack Success: 21.0%





Epoch 91:
FID: 12154.69 | Attack Success: 23.0%





Epoch 96:
FID: 15210.74 | Attack Success: 27.0%


                                                                                             

Training complete!


