In [1]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import math
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
from model import Model
from federated_learning import *
from gradient_suppression import *

In [2]:
# Configuration
input_model = 'trained_model.pth'  # saved state dict to start from; set to None to skip loading
num_clients = 5
num_classes = 10
input_shape = (1, 28, 28)  # MNIST grayscale images
batch_size = 32
epochs = 3
target = 0  # index of the client targeted by the attack
seed = 0  # fixed RNG seed so client shuffling and training are reproducible
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the RNGs before any DataLoader is iterated: DataLoader(shuffle=True)
# without an explicit generator draws its permutation from torch's global RNG.
torch.manual_seed(seed)
np.random.seed(seed)

print(f"Using device: {device}")

Using device: cpu


In [3]:
# Load the pretrained state dict, if one is configured. map_location moves every
# tensor onto the active device, so a checkpoint saved on GPU still loads on a
# CPU-only machine (and vice versa).
# NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
if input_model is not None:
    trained_model_state_dict = torch.load(input_model, map_location=device)
else:
    trained_model_state_dict = None

In [4]:
# Fetch MNIST (train + test) and normalise pixels with the dataset's mean / std.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),  # MNIST mean / std
])

mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Cut the training set into contiguous shards of `split_size` samples, one per
# client; the last shard runs to the end so it absorbs the division remainder.
dataset_size = len(mnist_train)
indices = list(range(dataset_size))
split_size = dataset_size // num_clients

shard_starts = [client_id * split_size for client_id in range(num_clients)]
shard_ends = shard_starts[1:] + [dataset_size]

clients_dataloaders = [
    DataLoader(Subset(mnist_train, indices[lo:hi]), batch_size=batch_size, shuffle=True)
    for lo, hi in zip(shard_starts, shard_ends)
]

print(f"Prepared {num_clients} client dataloaders")

# Held-out evaluation loader; order is fixed (no shuffling).
test_loader = DataLoader(dataset=mnist_test, batch_size=64, shuffle=False)

Prepared 5 client dataloaders


In [None]:
# Launch the gradient-suppression attack against client `target`, starting from
# the optionally loaded pretrained weights. The first returned value is used
# later as a model (its state_dict() is saved); local_updates presumably holds
# the per-client updates -- TODO confirm against gradient_suppression.
attack_config = dict(
    clients_dataloaders=clients_dataloaders,
    input_shape=input_shape,
    num_classes=num_classes,
    trained_model_state_dict=trained_model_state_dict,
    target=target,
    epochs=epochs,
    device=device,
)
global_model, local_updates = gradient_suppression(**attack_config)

print("Gradient Suppression Attack complete!")

target is client 0
Comparing target update with global update:
Average MSE: 0.0071
Fraction of exactly equal params: 0.0116
Average cosine similarity: 0.9970
Gradient Suppression Attack complete!


In [None]:
# Write the global model's parameters (its state dict) to disk for later analysis.
global_state_dict = global_model.state_dict()
torch.save(global_state_dict, 'global_update.pth')