# In [None]:  (notebook cell marker left over from the Jupyter export)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from torch.autograd import grad
import matplotlib.pyplot as plt
from torchvision.transforms import ToPILImage

# Neural network architecture (LeNet with BatchNorm)
class LeNet(nn.Module):
    """LeNet-style convolutional classifier with batch normalization.

    Two conv -> batch-norm -> ReLU -> 2x2 max-pool stages are followed by
    three fully connected layers producing 10 output scores. The flatten
    step assumes the pooled feature map has 16 * 5 * 5 elements per sample,
    which is what the two unpadded 5x5 convolutions and two 2x2 poolings
    yield for 3x32x32 inputs.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 3 -> 6 -> 16 channels, 5x5 kernels, no padding.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.bn1 = nn.BatchNorm2d(num_features=6)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.bn2 = nn.BatchNorm2d(num_features=16)
        # Classifier head: 400 -> 120 -> 84 -> 10.
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Return the raw (un-normalized) class scores for the batch ``x``."""
        stage1 = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), 2)
        stage2 = F.max_pool2d(F.relu(self.bn2(self.conv2(stage1))), 2)
        # Collapse channel and spatial dimensions into one feature vector.
        flat = stage2.view(-1, 16 * 5 * 5)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

# Load the CIFAR-10 dataset and define the preprocessing pipeline.
# Every channel is normalized with mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5.
_normalize = transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3)
transform = transforms.Compose([transforms.ToTensor(), _normalize])

train_data = datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform
)
test_data = datasets.CIFAR10(
    root="./data", train=False, download=True, transform=transform
)

# Only the training loader shuffles; the test loader keeps a fixed order.
train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=64, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=64, shuffle=False
)

# Scale applied to standard-normal samples before they are added to gradients.
NOISE_LEVEL = 0.99

# Noise tensors recorded by add_and_store_noise_to_gradients (when asked to
# store them) and read back by perform_dlg_attack.
stored_noise = []

# Add Gaussian noise to gradients and store the noise
def add_and_store_noise_to_gradients(gradients, noise_level, store_noise=False):
    """Return copies of ``gradients`` perturbed with zero-mean Gaussian noise.

    Args:
        gradients: Iterable of gradient tensors.
        noise_level: Factor multiplied onto ``torch.randn_like`` samples,
            i.e. the standard deviation of the added noise.
        store_noise: When true, the module-level ``stored_noise`` list is
            overwritten in place with the noise tensors drawn by this call
            (one per gradient, in the same order).

    Returns:
        A new list holding ``g + noise`` for every input gradient; the input
        tensors themselves are not modified.
    """
    gradients = list(gradients)
    noise_samples = [torch.randn_like(g) * noise_level for g in gradients]
    if store_noise:
        # Keep only the most recent set. Appending on every call grew the
        # list by one tensor per parameter per training batch, although the
        # only consumer reads a single set (one tensor per parameter).
        # Slice assignment keeps the same list object for other references.
        stored_noise[:] = noise_samples
    return [g + noise for g, noise in zip(gradients, noise_samples)]

#Train the model
def train(model, train_loader, optimizer, criterion, device, scheduler, add_noise=False, num_epochs=10):
    """Train ``model`` for ``num_epochs`` passes over ``train_loader``.

    Args:
        model: Module to optimise; it is put into training mode.
        train_loader: Sized iterable yielding ``(data, target)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss callable invoked as ``criterion(output, target)``.
        device: Device each batch is moved to before the forward pass.
        scheduler: Learning-rate scheduler stepped once per epoch.
        add_noise: When true, every parameter gradient is replaced by a
            noisy copy (scale ``NOISE_LEVEL``) before the optimizer step,
            and the noise is recorded via
            ``add_and_store_noise_to_gradients(..., store_noise=True)``.
        num_epochs: Number of epochs to run (defaults to the previously
            hard-coded 10).

    The mean loss of each completed group of 200 batches is printed.
    """
    model.train()
    for epoch in range(num_epochs):
        # Reset per epoch so the batches left over after the last report of
        # one epoch do not inflate the first average of the next epoch.
        running_loss = 0.0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            if add_noise:
                # Add Gaussian noise to gradients and store the noise
                params = list(model.parameters())
                noisy_gradients = add_and_store_noise_to_gradients(
                    [p.grad for p in params], NOISE_LEVEL, store_noise=True
                )
                for p, noisy_grad in zip(params, noisy_gradients):
                    p.grad = noisy_grad
            optimizer.step()
            running_loss += loss.item()
            if batch_idx % 200 == 199:
                print(f"Epoch [{epoch + 1}/{num_epochs}], Batch {batch_idx + 1}/{len(train_loader)}, Loss: {running_loss / 200:.4f}")
                running_loss = 0.0
        # Step the schedule only after this epoch's optimizer updates;
        # stepping first skips the initial learning rate.
        scheduler.step()

# Test the model
def test(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    accuracy = 100.0 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

# DLG attack and image reconstruction
def perform_dlg_attack(model, criterion, test_loader, device, add_noise=False):
    """Run a gradient-based reconstruction on one test batch and plot it.

    The batch at index ``img_batch_index`` (0) of ``test_loader`` is pushed
    through ``model`` (in eval mode), the parameter gradients of the loss
    are captured, optionally perturbed with the tensors in the module-level
    ``stored_noise`` list, and handed to ``reconstruct_from_gradients``.
    The original and reconstructed batches are then shown with
    ``plot_images``.

    Args:
        model: Network whose gradients are attacked.
        criterion: Loss callable invoked as ``criterion(output, target)``.
        test_loader: Iterable yielding ``(data, target)`` batches.
        device: Device the selected batch is moved to.
        add_noise: When true, the first ``len(gradients)`` entries of
            ``stored_noise`` are added element-wise to the gradients.

    Raises:
        ValueError: If ``test_loader`` has no batch at the selected index.
        RuntimeError: If ``add_noise`` is set but ``stored_noise`` holds
            fewer tensors than there are parameter gradients.
    """
    model.eval()
    # Batch selected for image reconstruction.
    img_batch_index = 0

    selected = None
    for batch_idx, (data, target) in enumerate(test_loader):
        if batch_idx == img_batch_index:
            selected = (data.to(device), target.to(device))
            break
    if selected is None:
        raise ValueError(f"test_loader has no batch at index {img_batch_index}")
    data, target = selected

    # Forward/backward pass to obtain the gradients an attacker would see.
    output = model(data)
    loss = criterion(output, target)
    model.zero_grad()
    loss.backward()
    gradients = [p.grad.clone() for p in model.parameters()]

    if add_noise:
        # zip() would silently drop gradients without a matching noise
        # tensor, so insist on one noise tensor per gradient up front.
        if len(stored_noise) < len(gradients):
            raise RuntimeError(
                f"stored_noise holds {len(stored_noise)} tensors but "
                f"{len(gradients)} are required; train with add_noise=True first"
            )
        gradients = [g + n for g, n in zip(gradients, stored_noise)]

    reconstructed_data = reconstruct_from_gradients(model, gradients, data, target, device, criterion)

    # Plot original and reconstructed images
    plot_images(data.cpu().detach(), reconstructed_data.cpu().detach())

def reconstruct_from_gradients(model, gradients, data, target, device, criterion, num_steps=10):
    """Reconstruct an input batch by matching the observed gradients.

    A dummy input is optimised with L-BFGS so that the parameter gradients
    it produces are as close as possible (squared L2 distance summed over
    all parameters) to ``gradients``.

    Args:
        model: Network whose parameter gradients are matched.
        gradients: Observed (possibly noisy) gradients, one tensor per
            entry of ``model.parameters()`` and in the same order.
        data: Original batch; only its shape and dtype are used, to size
            the dummy input.
        target: Labels fed to ``criterion`` for the dummy input.
        device: Device on which the dummy input is created.
        criterion: Loss callable invoked as ``criterion(output, target)``.
        num_steps: Number of ``optimizer.step`` calls (previously fixed
            at 10).

    Returns:
        The optimised dummy input, a leaf tensor with ``requires_grad=True``.

    Raises:
        ValueError: If the number of gradients differs from the number of
            model parameters.
    """
    params = list(model.parameters())
    target_gradients = [g.detach() for g in gradients]
    if len(target_gradients) != len(params):
        raise ValueError(
            f"expected {len(params)} gradient tensors, got {len(target_gradients)}"
        )

    # Start from random noise, not from the private batch. Initialising with
    # `data` hands the attack the answer: for noise-free gradients the
    # matching loss is already zero and nothing is reconstructed.
    reconstructed_data = torch.randn_like(data, device=device).requires_grad_(True)
    optimizer = torch.optim.LBFGS([reconstructed_data])

    def closure():
        optimizer.zero_grad()
        output = model(reconstructed_data)
        loss = criterion(output, target)
        # create_graph=True keeps these gradients differentiable w.r.t. the input.
        dummy_gradients = grad(loss, params, create_graph=True)
        gradient_diff_loss = sum(
            ((dg - tg) ** 2).sum() for dg, tg in zip(dummy_gradients, target_gradients)
        )
        # Differentiate w.r.t. the dummy input only; a plain backward() would
        # also accumulate into every model parameter's .grad.
        reconstructed_data.grad = grad(gradient_diff_loss, reconstructed_data)[0]
        return gradient_diff_loss

    for _ in range(num_steps):
        optimizer.step(closure)
    return reconstructed_data

def plot_images(original_images, reconstructed_images):
    """Show original and reconstructed images side by side, one row per image.

    Args:
        original_images: Indexable batch of image tensors normalized with
            mean 0.5 / std 0.5 (undone here before display).
        reconstructed_images: Batch with at least ``len(original_images)``
            entries, normalized the same way.

    Pixel values are clamped to [0, 1] after un-normalizing, so
    out-of-range reconstructions are clipped rather than rejected.
    """
    to_pil = ToPILImage()

    def reverse_normalize(tensor):
        # Inverse of (x - 0.5) / 0.5, clipped to the valid image range.
        return torch.clamp(tensor * 0.5 + 0.5, 0, 1)

    batch_size = len(original_images)
    # squeeze=False keeps `axes` two-dimensional even for a single row;
    # otherwise axes[i, col] fails for a batch of one image.
    _, axes = plt.subplots(
        nrows=batch_size, ncols=2, figsize=(10, 5 * batch_size), squeeze=False
    )
    columns = (("Original", original_images), ("Reconstructed", reconstructed_images))
    for i in range(batch_size):
        for col, (label, images) in enumerate(columns):
            # Reverse normalization and convert to a PIL image for display.
            axis = axes[i, col]
            axis.imshow(to_pil(reverse_normalize(images[i])))
            axis.set_title(f'{label} Image {i + 1}')
            axis.axis('off')
    plt.show()

# Shared device and loss for both experiments.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = nn.CrossEntropyLoss()


def _build_training_state():
    """Return a freshly initialized (model, optimizer, scheduler) triple."""
    fresh_model = LeNet().to(device)
    fresh_optimizer = torch.optim.Adam(fresh_model.parameters(), lr=0.005)
    # Multiply the learning rate by 0.1 every 5 epochs.
    fresh_scheduler = StepLR(fresh_optimizer, step_size=5, gamma=0.1)
    return fresh_model, fresh_optimizer, fresh_scheduler


# Initialize the model, optimizer, and learning rate scheduler
model, optimizer, scheduler = _build_training_state()

# Train the model without noise
print("Training the model without noise...")
train(model, train_loader, optimizer, criterion, device, scheduler, add_noise=False)

# Test the model without noise
print("\nTesting the model without noise...")
test_accuracy = test(model, test_loader, device)

# Perform DLG attack and image reconstruction without noise
print("\nPerforming DLG attack and reconstruction without noise...")
perform_dlg_attack(model, criterion, test_loader, device, add_noise=False)

# Start the noisy run from scratch. Reusing the already-trained model and the
# already-decayed optimizer/scheduler would make the "with noise" results a
# fine-tuning of the clean model rather than a comparable experiment.
model, optimizer, scheduler = _build_training_state()

# Train the model with noise
print("\nTraining the model with noise...")
train(model, train_loader, optimizer, criterion, device, scheduler, add_noise=True)

# Test the model with noise
print("\nTesting the model with noise...")
test_accuracy = test(model, test_loader, device)

# Perform DLG attack and image reconstruction with noise
print("\nPerforming DLG attack and reconstruction with noise...")
perform_dlg_attack(model, criterion, test_loader, device, add_noise=True)