# Setup and Dependencies

In [3]:
!pip install torch numpy torchvision matplotlib

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [3]:
# Required Libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
import numpy as np
import matplotlib.pyplot as plt
import os
import random
from scipy import linalg
from sklearn.neighbors import NearestNeighbors
from torchvision.models import resnet18


In [4]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Set random seed for reproducibility
seed = 42
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

# Data Preparation

In [5]:
# Data Transforms for CIFAR-10
transform = transforms.Compose([
    transforms.Resize(128),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Load CIFAR-10 Dataset
public_data = datasets.CIFAR10(root='data', train=True, transform=transform, download=True)

# DataLoader
batch_size = 64
public_loader = DataLoader(public_data, batch_size=batch_size, shuffle=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:09<00:00, 18.7MB/s]


Extracting data/cifar-10-python.tar.gz to data


# Model Architechture

### Generator

In [6]:
class Generator(nn.Module):
    def __init__(self, z_dim=100, img_channels=3, feature_g=64):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(z_dim, feature_g * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(feature_g * 8),
            nn.ReLU(True),
            nn.ConvTranspose2d(feature_g * 8, feature_g * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_g * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(feature_g * 4, feature_g * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_g * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(feature_g * 2, img_channels, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, z):
        return self.net(z)

### Discriminator

In [7]:
class Discriminator(nn.Module):
    def __init__(self, img_channels=3, feature_d=64, num_classes=11):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(img_channels, feature_d, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(feature_d, feature_d * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_d * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(feature_d * 2, feature_d * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_d * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(feature_d * 4, num_classes, 4, 1, 0, bias=False)
        )

    def forward(self, x):
        return self.net(x).view(x.size(0), -1)

# Training Functions

### Inversion-Specific GAN Training

In [8]:
def train_inversion_gan(generator, discriminator, target_model, public_loader, num_epochs=50, lr=0.0002):
    criterion = nn.CrossEntropyLoss()
    optimizer_g = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizer_d = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

    for epoch in range(num_epochs):
        for i, (images, _) in enumerate(public_loader):
            images = images.to(device)
            batch_size = images.size(0)
            noise = torch.randn(batch_size, 100, 1, 1, device=device)

            # Labels from target model (soft labels)
            soft_labels = target_model(images).detach()

            # Train Discriminator
            real_output = discriminator(images)
            fake_images = generator(noise)
            fake_output = discriminator(fake_images.detach())

            d_loss_real = criterion(real_output, torch.argmax(soft_labels, dim=1))
            d_loss_fake = criterion(fake_output, torch.full((batch_size,), 10, dtype=torch.long, device=device))
            d_loss = d_loss_real + d_loss_fake

            optimizer_d.zero_grad()
            d_loss.backward()
            optimizer_d.step()

            # Train Generator
            fake_output = discriminator(fake_images)
            g_loss = criterion(fake_output, torch.argmax(soft_labels, dim=1))

            optimizer_g.zero_grad()
            g_loss.backward()
            optimizer_g.step()

            if i % 100 == 0:
                print(f"Epoch [{epoch}/{num_epochs}], Step [{i}/{len(public_loader)}], D Loss: {d_loss.item()}, G Loss: {g_loss.item()}")

        # Save generated samples for monitoring
        save_image(fake_images, f'samples/epoch_{epoch}.png', normalize=True)

# Distributional Recovery

In [9]:
class DistributionalRecovery:
    def __init__(self, generator, discriminator, target_model, num_classes, device):
        self.generator = generator
        self.discriminator = discriminator
        self.target_model = target_model
        self.num_classes = num_classes
        self.device = device
        self.mu = nn.Parameter(torch.zeros(1, 100, 1, 1, device=device), requires_grad=True)
        self.sigma = nn.Parameter(torch.ones(1, 100, 1, 1, device=device), requires_grad=True)
        self.optimizer = optim.Adam([self.mu, self.sigma], lr=0.01)

    def sample_latent(self, num_samples):
        epsilon = torch.randn(num_samples, 100, 1, 1, device=self.device)
        return self.mu + self.sigma * epsilon

    def distributional_loss(self, target_label, lambda_id=100):
        z = self.sample_latent(64)
        generated_images = self.generator(z)

        # Prior Loss: Real vs Fake
        realness = self.discriminator(generated_images)
        Lprior = -torch.mean(torch.log(torch.sigmoid(realness[:, -1])))

        # Identity Loss: Classification Confidence under Target Model
        target_confidence = self.target_model(generated_images)
        Lid = -torch.mean(torch.log(target_confidence[:, target_label]))

        return Lprior + lambda_id * Lid

    def update_distribution(self, target_label, num_steps=1500):
        for step in range(num_steps):
            loss = self.distributional_loss(target_label)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if step % 100 == 0:
                print(f"Step [{step}/{num_steps}], Loss: {loss.item()}")

    def generate_samples(self, num_samples=64):
        z = self.sample_latent(num_samples)
        return self.generator(z)

# Evaluation Metrics

### Attack Accuracy

In [10]:
def calculate_attack_accuracy(generator, evaluation_classifier, target_label, num_samples=100):
    evaluation_classifier.eval()
    correct = 0

    with torch.no_grad():
        # Generate Samples for the Target Label
        z = torch.randn(num_samples, 100, 1, 1, device=device)
        generated_images = generator(z)

        # Get Predictions from the Evaluation Classifier
        outputs = evaluation_classifier(generated_images)
        _, predicted = torch.max(outputs, 1)

        # Calculate Accuracy for the Target Label
        correct += (predicted == target_label).sum().item()

    attack_accuracy = correct / num_samples
    return attack_accuracy

### K-Nearest Neighbor Distance (KNN Dist)

In [11]:
def extract_features(images, model):
    model.eval()
    with torch.no_grad():
        features = model(images)
    return features.cpu().numpy()

def calculate_knn_distance(generator, evaluation_classifier, real_loader, target_label, num_samples=100):
    evaluation_classifier.eval()

    # Step 1: Get Features for Real Images of the Target Label
    real_features = []
    for images, labels in real_loader:
        images = images.to(device)
        labels = labels.to(device)

        # Filter real images of the target label
        target_images = images[labels == target_label]
        if len(target_images) > 0:
            real_features.append(extract_features(target_images, evaluation_classifier))

    real_features = np.vstack(real_features)

    # Step 2: Get Features for Generated Images
    z = torch.randn(num_samples, 100, 1, 1, device=device)
    generated_images = generator(z)
    fake_features = extract_features(generated_images, evaluation_classifier)

    # Step 3: Calculate K-Nearest Neighbor Distance
    knn = NearestNeighbors(n_neighbors=1)
    knn.fit(real_features)
    distances, _ = knn.kneighbors(fake_features)

    knn_dist = np.mean(distances)
    return knn_dist

In [12]:
# Helper Function: Calculate Feature Means and Covariances
def calculate_statistics(features):
    mu = np.mean(features, axis=0)
    sigma = np.cov(features, rowvar=False)
    return mu, sigma

# FID Score Calculation
def calculate_fid(real_features, fake_features):
    mu_real, sigma_real = calculate_statistics(real_features)
    mu_fake, sigma_fake = calculate_statistics(fake_features)

    # Calculate FID
    diff = mu_real - mu_fake
    covmean = linalg.sqrtm(sigma_real.dot(sigma_fake))

    # Check for imaginary component from sqrtm
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = np.sum(diff**2) + np.trace(sigma_real + sigma_fake - 2*covmean)
    return fid

In [13]:
# Ensure the samples directory exists
os.makedirs('samples', exist_ok=True)

# Main Training Loop and Evaluation Loop

In [14]:
def evaluate_attack(generator, evaluation_classifier, real_loader, target_label=0, num_samples=100):
    # Calculate Attack Accuracy
    attack_accuracy = calculate_attack_accuracy(generator, evaluation_classifier, target_label, num_samples)

    # Calculate KNN Distance
    knn_dist = calculate_knn_distance(generator, evaluation_classifier, real_loader, target_label, num_samples)

    # Generate Samples for FID Calculation
    z = torch.randn(num_samples, 100, 1, 1, device=device)
    generated_images = generator(z)
    fake_features = extract_features(generated_images, evaluation_classifier)

    # Get Real Features
    real_features = []
    for images, labels in real_loader:
        images = images.to(device)
        labels = labels.to(device)
        target_images = images[labels == target_label]
        if len(target_images) > 0:
            real_features.append(extract_features(target_images, evaluation_classifier))

    real_features = np.vstack(real_features)
    fid_score = calculate_fid(real_features, fake_features)

    # Print All Metrics
    print(f"\nEvaluation Results for Class {target_label}:")
    print(f"Attack Accuracy: {attack_accuracy * 100:.2f}%")
    print(f"KNN Distance: {knn_dist:.4f}")
    print(f"FID Score: {fid_score:.4f}")

# Initialize Models
generator = Generator().to(device)
discriminator = Discriminator().to(device)
# CIFAR-10 Specific Classifier
target_model = resnet18(pretrained=False, num_classes=10).to(device)

# Train Inversion-Specific GAN
train_inversion_gan(generator, discriminator, target_model, public_loader)

# Loop over all 10 classes in CIFAR-10 (0 to 9)
for target_label in range(10):
    print(f"\nEvaluating for Class {target_label}...")

    # Distributional Recovery for the current class
    recovery = DistributionalRecovery(generator, discriminator, target_model, num_classes=10, device=device)
    recovery.update_distribution(target_label=target_label)

    # Evaluate Attack Performance for the current class
    evaluate_attack(generator, target_model, public_loader, target_label=target_label)



Epoch [0/50], Step [0/782], D Loss: 9.899887084960938, G Loss: 2.918546199798584
Epoch [0/50], Step [100/782], D Loss: 2.5068628787994385, G Loss: 7.250673770904541
Epoch [0/50], Step [200/782], D Loss: 2.383089542388916, G Loss: 7.997137069702148
Epoch [0/50], Step [300/782], D Loss: 2.392315149307251, G Loss: 8.797613143920898
Epoch [0/50], Step [400/782], D Loss: 2.4219870567321777, G Loss: 9.412561416625977
Epoch [0/50], Step [500/782], D Loss: 2.424818515777588, G Loss: 9.975582122802734
Epoch [0/50], Step [600/782], D Loss: 2.426813840866089, G Loss: 9.196457862854004
Epoch [0/50], Step [700/782], D Loss: 2.450277805328369, G Loss: 9.454392433166504
Epoch [1/50], Step [0/782], D Loss: 2.4408950805664062, G Loss: 11.310224533081055
Epoch [1/50], Step [100/782], D Loss: 2.418503522872925, G Loss: 12.271249771118164
Epoch [1/50], Step [200/782], D Loss: 2.403623104095459, G Loss: 11.711448669433594
Epoch [1/50], Step [300/782], D Loss: 2.3976011276245117, G Loss: 11.001505851745605
