In [None]:
import torch
import matplotlib.pyplot as plt
from torchvision.transforms.functional import to_pil_image
import os
from torchvision import datasets, transforms

# How many MNIST training samples are attacked together as one batch.
num_data_points: int = 1

class SimpleCNN(torch.nn.Module):
    """Small single-channel image classifier producing 10 logits.

    Architecture: conv(1->32, 3x3, pad 1) -> ReLU -> conv(32->64, 3x3, pad 1)
    -> ReLU -> 2x2 max-pool -> linear(64*14*14 -> 128) -> ReLU -> linear(128 -> 10).
    fc1's input size corresponds to 28x28 inputs (padding keeps 28x28, the pool
    halves it to 14x14).

    The attribute names conv1/conv2/pool/fc1/fc2 become the state_dict keys, so
    they must stay stable for saved checkpoints to load.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the original order so that, under a fixed
        # seed, parameter initialisation draws the same random numbers.
        self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = torch.nn.MaxPool2d(2, 2)
        self.fc1 = torch.nn.Linear(64 * 14 * 14, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        """Return class logits of shape (batch, 10) for an input batch ``x``."""
        first = torch.relu(self.conv1(x))
        pooled = self.pool(torch.relu(self.conv2(first)))
        flat = pooled.view(pooled.shape[0], -1)
        hidden = torch.relu(self.fc1(flat))
        return self.fc2(hidden)

benign_model_path = './model/benign_model.pth'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pre-trained "benign" weights. weights_only=True limits unpickling to
# tensors / state-dict containers: it silences torch.load's FutureWarning and
# avoids executing arbitrary pickled code stored in the checkpoint.
benign_model = SimpleCNN().to(device)
benign_model.load_state_dict(
    torch.load(benign_model_path, map_location=device, weights_only=True)
)

# ToTensor gives [0, 1] pixels; Normalize((0.5,), (0.5,)) maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
mnist_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# NOTE(review): this materialises the entire training split as a single batch
# on `device` only to sample `num_data_points` images; indexing `mnist_dataset`
# directly would be much cheaper if `mnist_X` / `mnist_y` are not used later.
mnist_loader = torch.utils.data.DataLoader(mnist_dataset, batch_size=len(mnist_dataset), shuffle=False)

all_data = next(iter(mnist_loader))
mnist_X, mnist_y = all_data[0].to(device), all_data[1].to(device)

selected_indices = torch.randperm(len(mnist_X))[:num_data_points]
selected_X = mnist_X[selected_indices]
selected_y = mnist_y[selected_indices]

image_folder = './image'
os.makedirs(image_folder, exist_ok=True)


def _save_normalized_images(images, folder, prefix):
    """Save each image of a [-1, 1]-normalised batch as ``folder/prefix_<i>.png``.

    Undoes Normalize((0.5,), (0.5,)) and clamps to [0, 1]. to_pil_image scales
    float tensors by 255 and casts to uint8, so out-of-range values (typical
    of DLG reconstructions) would otherwise wrap around instead of saturating.
    """
    for i, img_tensor in enumerate(images):
        pixels = (img_tensor.detach() * 0.5 + 0.5).clamp(0, 1).cpu()
        to_pil_image(pixels).save(os.path.join(folder, f"{prefix}_{i}.png"))


_save_normalized_images(selected_X, image_folder, "selected_image")
print(f"Selected images saved in folder: {image_folder}")

# Victim model: an exact copy of the benign weights. Its gradient on the
# selected batch is the "leaked" gradient the attack tries to invert.
target_model = SimpleCNN().to(device)
target_model.load_state_dict(benign_model.state_dict())
target_model.train()

# NOTE(review): target_optimizer is never stepped; only its zero_grad is used.
target_optimizer = torch.optim.SGD(target_model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

# Use all selected data points in a single batch
target_optimizer.zero_grad()
outputs = target_model(selected_X)
loss = criterion(outputs, selected_y)
dy_dx = torch.autograd.grad(loss, target_model.parameters())
original_dy_dx = [g.detach().clone() for g in dy_dx]

torch.save(target_model.state_dict(), './model/target_model_model.pth')

# Deep Leakage from Gradients (DLG): optimise a dummy input and a dummy label so
# that the gradient they induce on target_model matches original_dy_dx.
x_prime = torch.randn(selected_X.size(), requires_grad=True, device=device)
# The dummy label is a (batch, num_classes) tensor of class logits, turned into
# a soft label with softmax. A single random float per sample cast with
# .long() is non-differentiable (the label never learns) and yields negative /
# out-of-range class indices. Those trip a CUDA device-side assert inside the
# loss, which surfaces as "cuDNN error: CUDNN_STATUS_EXECUTION_FAILED".
num_classes = outputs.size(1)
y_prime = torch.randn(selected_y.size(0), num_classes, requires_grad=True, device=device)

# Use LBFGS optimizer
optimizer = torch.optim.LBFGS([x_prime, y_prime], lr=0.1)


def _soft_label_cross_entropy(logits, label_logits):
    """Cross-entropy of ``logits`` against the soft label softmax(``label_logits``).

    Averaged over the batch, like CrossEntropyLoss's default 'mean' reduction
    for hard labels, and differentiable w.r.t. both arguments.
    """
    soft_targets = torch.softmax(label_logits, dim=-1)
    return torch.mean(torch.sum(-soft_targets * torch.log_softmax(logits, dim=-1), dim=-1))


def closure():
    """LBFGS closure: squared L2 distance between dummy and leaked gradients.

    Back-propagates that distance into x_prime and y_prime and returns it.
    """
    optimizer.zero_grad()

    # Gradients must come from the same model that produced original_dy_dx.
    dummy_outputs = target_model(x_prime)
    dummy_loss = _soft_label_cross_entropy(dummy_outputs, y_prime)
    # create_graph=True so the matching loss can be differentiated again
    # w.r.t. x_prime / y_prime.
    dummy_dy_dx = torch.autograd.grad(dummy_loss, target_model.parameters(), create_graph=True)

    grad_diff = sum(((gx - gy) ** 2).sum() for gx, gy in zip(dummy_dy_dx, original_dy_dx))
    # Only the dummy tensors are optimised; restricting backward to them stops
    # unused .grad from accumulating on the model parameters every call.
    grad_diff.backward(inputs=[x_prime, y_prime])
    return grad_diff


for step in range(11):  # Number of optimization steps
    # LBFGS.step returns the closure value from its first evaluation.
    grad_match_loss = optimizer.step(closure)
    print(
        f"Step {step}, Gradient Matching Loss: {grad_match_loss.item():.6f}, "
        f"Dummy Label: {y_prime.argmax(dim=-1).tolist()}"
    )

print(f"True label(s): {selected_y.tolist()}, recovered label(s): {y_prime.argmax(dim=-1).tolist()}")

# Save the reconstructed data
_save_normalized_images(x_prime, image_folder, "reconstructed_image")
print(f"Reconstructed images saved in folder: {image_folder}")

  warn(
  benign_model.load_state_dict(torch.load(benign_model_path, map_location=device))


Selected images saved in folder: ./image


RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED