In [2]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import math
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
from federated_learning import *
from gradient_suppression import *

In [3]:
# Experiment configuration -- every tunable value used by the cells below.

# Checkpoint handed to torch.load in the next cell (None skips loading).
input_model = 'trained_model_MNIST.pth'

# Data layout
input_shape = (1, 28, 28)  # MNIST grayscale images
num_classes = 10
batch_size = 32

# Federated run
num_clients = 5
target = 0  # index of the client singled out as the target
epochs = 3
model = 2  # model selector; NOTE(review): meaning is defined by the project modules -- confirm

# Prefer the GPU when CUDA is usable, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

Using device: cpu


In [4]:
# Load the pre-trained weights when a checkpoint path is configured.
# map_location remaps tensors that were saved on a different device (for
# example a CUDA checkpoint opened on a CPU-only machine) onto the active one,
# so the load does not fail when the saving and loading hosts differ.
# NOTE(review): torch.load unpickles the file -- only open checkpoints from a
# trusted source (consider weights_only=True if the installed torch supports it).
if input_model is not None:
    trained_model_state_dict = torch.load(input_model, map_location=device)
else:
    trained_model_state_dict = None

In [5]:
# Fetch MNIST (downloading into ./data when missing) and apply the usual
# tensor conversion plus the MNIST normalization parameters.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Carve the training set into contiguous shards, one per client. Every shard
# holds `split_size` samples except the last, which also absorbs the remainder.
dataset_size = len(mnist_train)
indices = list(range(dataset_size))
split_size = dataset_size // num_clients

shard_starts = [client_id * split_size for client_id in range(num_clients)]
shard_ends = shard_starts[1:] + [dataset_size]

clients_dataloaders = [
    DataLoader(Subset(mnist_train, indices[lo:hi]), batch_size=batch_size, shuffle=True)
    for lo, hi in zip(shard_starts, shard_ends)
]

print(f"Prepared {num_clients} client dataloaders")

# Held-out test split, iterated in a fixed order.
test_loader = DataLoader(dataset=mnist_test, batch_size=64, shuffle=False)

Prepared 5 client dataloaders


In [6]:
# Run the gradient-suppression experiment on the prepared client loaders.
# The model selector is taken from the configuration cell rather than being
# repeated as a literal here, so changing `model` in the config actually
# takes effect.
global_model, local_updates = gradient_suppression(
    clients_dataloaders=clients_dataloaders,
    input_shape=input_shape,
    num_classes=num_classes,
    trained_model_state_dict=trained_model_state_dict,
    target=target,
    epochs=epochs,
    device=device,
    model=model,
)

print("Gradient Suppression Attack complete!")

target is client 0
Comparing target update with global update:
Average MSE: 0.0069
Fraction of exactly equal params: 0.0000
Average cosine similarity: 0.9828
Gradient Suppression Attack complete!


In [7]:
# Write the global model's state dict to 'global_update.pth'.
torch.save(obj=global_model.state_dict(), f='global_update.pth')

In [8]:
# Dump every entry of the global model's state dict for inspection.
# NOTE(review): this prints all tensors in full, which produces very long output.
global_state = global_model.state_dict()
print(global_state)

OrderedDict({'conv1.weight': tensor([[[[ 0.0001,  0.0163,  0.0125],
          [ 0.0461, -0.0072,  0.0118],
          [ 0.0422,  0.0687, -0.0317]]],


        [[[-0.0156,  0.0716, -0.0038],
          [ 0.0684, -0.0531,  0.0597],
          [-0.0050,  0.0194, -0.0124]]],


        [[[-0.0535, -0.0304,  0.0236],
          [-0.0634, -0.0221,  0.0305],
          [ 0.0431,  0.0693, -0.0243]]],


        [[[ 0.0495,  0.0736, -0.0440],
          [ 0.0654,  0.0822,  0.0807],
          [-0.0213,  0.0181, -0.0283]]],


        [[[-0.0060, -0.0403,  0.0228],
          [ 0.0034,  0.0327, -0.0064],
          [-0.0108, -0.0537, -0.0279]]],


        [[[ 0.0172, -0.0233,  0.0111],
          [ 0.0674,  0.0508, -0.0612],
          [-0.0645, -0.0426, -0.0457]]],


        [[[ 0.0917,  0.0383,  0.0932],
          [-0.0394,  0.0991,  0.0833],
          [ 0.0531,  0.0589,  0.0589]]],


        [[[-0.0053,  0.0153, -0.0425],
          [ 0.0116,  0.0006,  0.0484],
          [ 0.0517, -0.0519,  0.0497]]],


   