In [17]:
import torch
from torch import nn
from torchvision import datasets
import fastai 
from torchvision.transforms import ToTensor
# from fastai.data.core import DataLoader
from torch.utils.data import DataLoader
from fastai.data.core import DataLoaders
from fastai.callback.core import Callback
from fastai.vision.all import Learner, Metric
from fastai import optimizer
import torch.nn.functional as F
from torch.utils.data import Subset
import copy


In [2]:
training_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

In [3]:
batch_size = 256

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

Shape of X [N, C, H, W]: torch.Size([256, 1, 28, 28])
Shape of y: torch.Size([256]) torch.int64


In [4]:
# Get cpu, gpu or mps device for training.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using cuda device


In [5]:
num_clients = 5
train_size = len(training_data)
# indices = list(range(train_size))

RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
torch.random.manual_seed(RANDOM_SEED)
indices = torch.randperm(train_size).tolist()

subset_size = train_size // num_clients
client_subsets = [] 
for i in range(num_clients):
    start_idx = i * subset_size
    end_idx = start_idx + subset_size

    if i == num_clients - 1:
        end_idx = train_size

    subset_indices = indices[start_idx:end_idx]
    client_subsets.append(Subset(training_data, subset_indices))

client_loaders = [DataLoader(sub, batch_size=batch_size, shuffle=True) for sub in client_subsets]

In [7]:
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(), 
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits 

model = NeuralNetwork().to(device)
original_shapes = []
for p in model.parameters():
    original_shapes.append(p.shape)
model

NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)

In [8]:
# We now have 5 different datasets, each with some sort of representation of the data that is unknown, ie, we have no 
# statistical information on the data that each of these clients would have
# We now need to implement variations of the 3 protocols, namely, the encoding protocol, the communication protocol and the decoding protocol

# For communication protocol for fixed size encoder, we set the seed. So the seed is communicated with the values. 
SEED = 42
torch.manual_seed(SEED)

# Encoders
def variable_size_encoder(grad_vectors, mu, p=0.1):
    # Lets take p = 0.1
    new_grad_vectors = []
    with torch.no_grad():
        for i in range(len(grad_vectors)):
            mask = torch.rand_like(grad_vectors[i], device=grad_vectors[i].device) < p
            Y = torch.empty_like(grad_vectors[i], device=grad_vectors[i].device)
            Y[mask] = (grad_vectors[i][mask] - mu[i] * (1-p))/p
            Y[~mask] = mu[i]
            new_grad_vectors.append(Y)
    return new_grad_vectors

def fixed_size_encoder(grad_vectors, mu, k=10):
    # k can vary
    new_grad_vectors = []
    with torch.no_grad():
        for i in range(len(grad_vectors)):
            shape = grad_vectors[i].shape
            # Flattening the parameters to permutate over them
            flat_grad = grad_vectors[i].view(-1)
            C = shape[-1]
            # Get the length of the flat_grad array
            d = flat_grad.numel()
            # Shuffle the list [1, 2, ... d] and get the first k elements
            indices = torch.randperm(C, device=flat_grad.device)[:k]
             
            mask = torch.zeros(d, dtype=torch.bool, device=flat_grad.device)
            mask[indices] = True
            
            Y = torch.empty_like(flat_grad)
            # Encode the parameters
            chosen_vals = (d/k)*flat_grad[mask] - ((d-k)/k)*mu[i]
            Y[mask] = chosen_vals
            Y[~mask] = mu[i]
            Y = Y.view(shape)
            new_grad_vectors.append(Y)
    return new_grad_vectors
            
            
# Decoders : I wont be making use of this later on
def averaging_decoder(grad_vectors_list):
    if isinstance(grad_vectors_list, list):
        grad_vectors_list = torch.stack(grad_vectors_list, dim=0)
    return torch.mean(grad_vectors_list, dim=0)

# Communication protocols
def sparse_for_variable_size_encoder(encoded_vectors, mu):
    final_vectors = []
    with torch.no_grad():
        for i in range(len(encoded_vectors)):
            flat_vector = encoded_vectors[i].view(-1)
            mask = flat_vector != mu[i]
            # vals = encoded_vectors[i][mask]
            indices = torch.nonzero(mask, as_tuple=False).view(-1)
            values = flat_vector[mask]
            final_vectors.append(list(zip(indices, values)))

    
    return final_vectors, mu
    
def sparse_for_fixed_size_encoder(encoded_vectors, mu):
    final_vectors = []
    with torch.no_grad():
        for i in range(len(encoded_vectors)):
            flat_vector = encoded_vectors[i].view(-1)
            mask = torch.zeros(len(flat_vector), dtype=torch.bool, device=flat_vector.device)
            mask[flat_vector != mu[i]] = True
            values = flat_vector[mask]
            final_vectors.append(values)

    return final_vectors, mu, SEED

def rebuild_from_protocol_1(final_vectors, mu, original_shapes):
    rebuilt_vectors = []
    with torch.no_grad():
        for i, vec_data in enumerate(final_vectors):
            num_elements = 1
            
            for dim_size in original_shapes[i]:
                num_elements *= dim_size
    
            
            Y_flat = torch.full((num_elements,), mu[i], dtype=torch.float32, device=mu[i].device)
    
            indices = torch.tensor([pair[0] for pair in vec_data], dtype=torch.long, device=Y_flat.device)
            values = torch.tensor([pair[1] for pair in vec_data], dtype=Y_flat.dtype, device=Y_flat.device)
            Y_flat[indices] = values
            Y = Y_flat.view(original_shapes[i])
            rebuilt_vectors.append(Y)
    return rebuilt_vectors

def rebuild_from_protocol_2(final_vectors, mu, SEED, original_shapes):
    rebuilt_vectors = []
    with torch.no_grad():
        for i, values in enumerate(final_vectors):
            
            num_elements = 1
            for dim_size in original_shapes[i]:
                num_elements *= dim_size
    

            torch.manual_seed(SEED) 
            k = len(final_vectors[i])
            indices = torch.randperm(num_elements)[:k]
    
            Y_flat = torch.full((num_elements,), mu[i], dtype=torch.float32, device=mu[i].device)
    
            # Place the chosen values
            Y_flat[indices] = values
    
            # Reshape to original shape
            Y = Y_flat.view(original_shapes[i])
            rebuilt_vectors.append(Y)
    return rebuilt_vectors

parameters = list(model.parameters())
mu_1 = []
with torch.no_grad():
    for p in parameters:
        mu_1.append(torch.mean(p))

encoded_vectors = fixed_size_encoder(parameters, mu_1)
final_vectors, mu, SEED = sparse_for_fixed_size_encoder(encoded_vectors, mu_1)
# rebuild_from_protocol_2(final_vectors, mu_1, SEED, original_shapes)

In [9]:
a = torch.FloatTensor([[1, 2, 3], [2, 3, 4], [4, 5, 6]])
averaging_decoder(a)

tensor([2.3333, 3.3333, 4.3333])

In [10]:
parameters = list(model.parameters())
variable_size_encoder(parameters)

TypeError: variable_size_encoder() missing 1 required positional argument: 'mu'

In [126]:
class ProxSGDWithLinearSearch:
    def __init__(self, params, lr):
        self.params, self.lr = list(params), lr
        self.state = {p: {} for p in self.params}
        self.hypers = [{'lr': lr}]
        self.max_iter = 5
        self.eta = 1e-5
        
    def soft_threshold(self, x, eta):
        # Apply the soft-thresholding operator
        return F.softshrink(x, lambd=eta)
        
    def prox_operator(self, x):
        # Use the soft-thresholding operator as the proximal step
        return self.soft_threshold(x, self.eta)

    def Gt(self, x, step_size, x_grad):
        return (1/step_size) * (x - self.prox_operator(x - step_size * x_grad))
        
    def step(self, *args, **kwargs):
        model = kwargs.get("model")
        loss_fn = kwargs.get("loss_fn")
        X = kwargs.get("X")
        y = kwargs.get("y")
        
        orig_params = [p.data.clone() for p in self.params]
        step_size = self.lr
        with torch.no_grad():
            pred = model(X)
            old_loss = loss_fn(pred, y)
        flag = True
        for _ in range(self.max_iter):
            for p in self.params:
                if p.grad is not None: 
                    Gt_val = self.Gt(p.data, step_size, p.grad.data)
                    p.data = p.data - step_size * Gt_val
            with torch.no_grad():
                pred = model(X)
                new_loss = loss_fn(pred, y)
            if new_loss < old_loss:
                flag = False
                break
            else:
                for i, j in zip(self.params, orig_params):
                    i.data.copy_(j)
                step_size *= 0.5
        if flag: 
            for p in self.params:
                if p.grad is not None: 
                    Gt_val = self.Gt(p.data, step_size, p.grad.data)
                    p.data = p.data - step_size * Gt_val
        else:
            self.lr = step_size
        # print(self.lr)
    def zero_grad(self, *args, **kwargs):
        for p in self.params:
            p.grad = None

    def set_hypers(self, **kwargs):
        if 'lr' in kwargs:
            self.lr = kwargs['lr']
            self.hypers[0]['lr'] = kwargs['lr']


    

In [81]:
loss_fn = nn.CrossEntropyLoss()

In [82]:
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step(model=model, loss_fn=loss_fn, X=X, y=y)
        optimizer.zero_grad()

        if batch % 100 == 0:
            print(f"BATCH: {batch} of {size/batch_size} batches")
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

In [83]:
def test(dataloader, model, loss_fn, number="main"):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error for client {number}: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [150]:
class Client:
    def __init__(self, model, train_dataloader, loss_fn, mu, max_iter=2):
        self.model = model
        # self.optimizer = ProxSGDWithLinearSearch(model.parameters(), 20)
        self.train_dataloader = train_dataloader
        self.max_iter = max_iter
        self.loss_fn = loss_fn
        self.mu = mu
        
    def train(self):
        for i in range(0, self.max_iter):
            train(self.train_dataloader, self.model, self.loss_fn, ProxSGDWithLinearSearch(self.model.parameters(), 20))

    def test(self, test_dataloader, number):
        test(test_dataloader, self.model, self.loss_fn, number)

    def set_parameters(self, model):
        self.model = copy.deepcopy(model)
            
    def get_encoded_1(self, p):
        encoded = variable_size_encoder(list(self.model.parameters()), self.mu, p)
        final_vectors, mu = sparse_for_variable_size_encoder(encoded, self.mu)
        return final_vectors, mu

    def get_encoded_2(self):
        encoded = fixed_size_encoder(list(self.model.parameters()), self.mu)
        final_vectors, mu, seed = sparse_for_fixed_size_encoder(encoded, self.mu)
        return final_vectors, mu, seed

    # Built the member function to test if continuous update of the mean is helpful or not
    def update_mean(self):
        with torch.no_grad():
            params = list(self.model.parameters())
            self.mu = [torch.mean(p) for p in params]

In [151]:
class Master:
    def __init__(self, model, mu, loss_fn):
        self.model = model
        self.original_shapes = [p.shape for p in model.parameters()]
        self.mu = mu
        self.loss_fn = loss_fn

    def set_mean(self, mu):
        self.mu = copy.deepcopy(mu)
        
    def update_global_model_from_protocol_1(self, clients_data):
        # clients_data is a list of tuples (final_vectors, mu, p) from each client
        # Decode each client's parameters and then compute updates
        decoded_params_list = []
        for (final_vectors, mu) in clients_data:
            decoded = rebuild_from_protocol_1(final_vectors, mu, self.original_shapes)
            decoded_params_list.append(decoded)

        # Now decoded_params_list is a list of parameter lists from each client
        # Convert each client's param list into a tensor stack and average updates
        # First, get master_params for reference
        master_params = list(self.model.parameters())
        
        # Compute updates: (client_params - master_params) for each client, then average
        all_updates = []
        with torch.no_grad():
            for decoded_params in decoded_params_list:
                updates = [(dp - mp) for dp, mp in zip(decoded_params, master_params)]
                all_updates.append(updates)

            # Average updates across clients
            # Stack each parameter across clients and mean
            averaged_updates = []
            num_clients = len(all_updates)
            for param_idx in range(len(master_params)):
                # Gather this param_idx from all clients
                stack = torch.stack([all_updates[c][param_idx] for c in range(num_clients)], dim=0)
                avg = torch.mean(stack, dim=0)
                averaged_updates.append(avg)

            # Apply averaged updates to master model
            for mp, au in zip(master_params, averaged_updates):
                mp.data.add_(au)

    def update_global_model_from_protocol_2(self, clients_data):
        # clients_data is a list of tuples (final_vectors, mu, seed, k) from each client
        decoded_params_list = []
        for (final_vectors, mu, seed, k) in clients_data:
            decoded = rebuild_from_protocol_2(final_vectors, mu, seed, self.original_shapes)
            decoded_params_list.append(decoded)

        master_params = list(self.model.parameters())
        all_updates = []
        with torch.no_grad():
            for decoded_params in decoded_params_list:
                updates = [(dp - mp) for dp, mp in zip(decoded_params, master_params)]
                all_updates.append(updates)

            num_clients = len(all_updates)
            averaged_updates = []
            for param_idx in range(len(master_params)):
                stack = torch.stack([all_updates[c][param_idx] for c in range(num_clients)], dim=0)
                avg = torch.mean(stack, dim=0)
                averaged_updates.append(avg)

            # Apply averaged updates to master model
            for mp, au in zip(master_params, averaged_updates):
                mp.data.add_(au)

    def test(self, test_dataloader):
        test(test_dataloader, self.model, self.loss_fn, "master")


In [178]:
model = NeuralNetwork().to(device)

parameters = list(model.parameters())
mu_1 = []
with torch.no_grad():
    for p in parameters:
        mu_1.append(0.0)
# mu_1 = torch.zeros(len(parameters), device=device)        
master = Master(model, mu_1, loss_fn) 
clients = [Client(NeuralNetwork().to(device), client_loaders[i], loss_fn, mu_1, 2) for i in range(num_clients)]


In [187]:

# Master sends global parameters to the client
for client in clients:
    client.set_parameters(master.model)

# Clients are trained locally
for client in clients:
    client.train()

clients_data_protocol_1 = []
for client in clients:
    final_vectors, mu = client.get_encoded_1(p=0.1)
    clients_data_protocol_1.append((final_vectors, mu))

master.update_global_model_from_protocol_1(clients_data_protocol_1)

BATCH: 0 of 46.875 batches
loss: 65854772204273664.000000  [  256/12000]
BATCH: 0 of 46.875 batches
loss: 2.314150  [  256/12000]
BATCH: 0 of 46.875 batches
loss: 316963329404829696.000000  [  256/12000]
BATCH: 0 of 46.875 batches
loss: 5913.462402  [  256/12000]
BATCH: 0 of 46.875 batches
loss: 416057776930816000.000000  [  256/12000]
BATCH: 0 of 46.875 batches
loss: 2.322959  [  256/12000]


KeyboardInterrupt: 

In [None]:
master.test(test_dataloader)

In [164]:
for i in range(len(clients)):
    clients[i].test(test_dataloader, i)

Test Error for client 0: 
 Accuracy: 93.8%, Avg loss: 0.199948 

Test Error for client 1: 
 Accuracy: 94.3%, Avg loss: 0.184125 

Test Error for client 2: 
 Accuracy: 94.2%, Avg loss: 0.184488 

Test Error for client 3: 
 Accuracy: 94.6%, Avg loss: 0.170945 

Test Error for client 4: 
 Accuracy: 94.0%, Avg loss: 0.196233 

