In [99]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import math
import random
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
from federated_learning import *
from gradient_suppression import *

In [None]:
# Configuration
model = 1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
rng_seed = 0

# Clients
num_clients = 10
private_dataset_size = 1

# Dataset
num_classes = 10
input_shape = (1, 28, 28)  # MNIST grayscale images

# Attack params
input_model = 'trained_model_MNIST.pth'
target = 'all'
learning_rate = 0.01
loss_fn = torch.nn.CrossEntropyLoss()
batch_size = 1
epochs = 1

print(f"Using device: {device}")

Using device: cpu


In [101]:
# Make reproducible
random.seed(rng_seed)
torch.manual_seed(rng_seed)
np.random.seed(rng_seed)

In [102]:
# Load pretrained model if specified
if input_model != None:
    trained_model_state_dict = torch.load(input_model)
else:
    trained_model_state_dict = None

In [103]:
# Download and transform MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)) # normalization params for MNIST
])

mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

clients_dataloaders = []
for i in range(num_clients):
    client_indices = random.sample(range(len(mnist_train)), private_dataset_size)
    client_subset = Subset(mnist_train, client_indices)
    dataloader = DataLoader(client_subset, batch_size=batch_size, shuffle=True)
    clients_dataloaders.append(dataloader)

print(f"Prepared {num_clients} client dataloaders")

# inspect private datasets
for idx in range(len(clients_dataloaders)):
    data = []
    
    for dp, y in clients_dataloaders[idx]:
        data.append(int(y[0]))
    print('Private dataset ', idx, ': ', data, '. Missing digits: ', [x for x in range(10) if x not in data], sep='')

test_loader = DataLoader(dataset=mnist_test, batch_size=batch_size, shuffle=False)

Prepared 10 client dataloaders
Private dataset 0: [4]. Missing digits: [0, 1, 2, 3, 5, 6, 7, 8, 9]
Private dataset 1: [9]. Missing digits: [0, 1, 2, 3, 4, 5, 6, 7, 8]
Private dataset 2: [3]. Missing digits: [0, 1, 2, 4, 5, 6, 7, 8, 9]
Private dataset 3: [3]. Missing digits: [0, 1, 2, 4, 5, 6, 7, 8, 9]
Private dataset 4: [8]. Missing digits: [0, 1, 2, 3, 4, 5, 6, 7, 9]
Private dataset 5: [7]. Missing digits: [0, 1, 2, 3, 4, 5, 6, 8, 9]
Private dataset 6: [9]. Missing digits: [0, 1, 2, 3, 4, 5, 6, 7, 8]
Private dataset 7: [4]. Missing digits: [0, 1, 2, 3, 5, 6, 7, 8, 9]
Private dataset 8: [8]. Missing digits: [0, 1, 2, 3, 4, 5, 6, 7, 9]
Private dataset 9: [1]. Missing digits: [0, 2, 3, 4, 5, 6, 7, 8, 9]


In [104]:
# Attack round
global_model, local_updates = gradient_suppression(
    clients_dataloaders=clients_dataloaders,
    input_shape=input_shape,
    num_classes=num_classes,
    trained_model_state_dict = trained_model_state_dict,
    target=target,
    criterion=loss_fn,
    lr=learning_rate,
    epochs=epochs,
    device=device,
    model=model
)

print("Gradient Suppression Attack complete!")

if target == 'all':
    for i in range(num_clients):
        global_model, local_updates = gradient_suppression(
            clients_dataloaders=clients_dataloaders,
            input_shape=input_shape,
            num_classes=num_classes,
            trained_model_state_dict = trained_model_state_dict,
            target=i,
            criterion=loss_fn,
            lr=learning_rate,
            epochs=epochs,
            device=device,
            model=model
        )
        filename = 'target_' + str(i) + '.pth'
        torch.save(global_model.state_dict(), filename)

target is client 0
Comparing target update with global update:
Average MSE: 0.0081
Average cosine similarity: 0.9999
Gradient Suppression Attack complete!


In [105]:
# Debug
# for i in range(num_clients):
#     row = []
#     for j in range(num_clients):
#         avg_cos_sim = state_dicts_cosine_similarity(local_updates[i], local_updates[j])
#         # avg_cos_sim = i == j
#         row.append(avg_cos_sim)
#     print(row)

In [106]:
# save global update to file
filename = 'target_' + str(target) + '.pth'
torch.save(global_model.state_dict(), filename)