In [1]:
import torch
from torch import nn
from torchvision import datasets
import fastai 
from torchvision.transforms import ToTensor
# from fastai.data.core import DataLoader
from torch.utils.data import DataLoader
from fastai.data.core import DataLoaders
from fastai.callback.core import Callback
from fastai.vision.all import Learner, Metric
from fastai import optimizer
import torch.nn.functional as F
from torch.utils.data import Subset


In [2]:
training_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

In [3]:
batch_size = 256

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

Shape of X [N, C, H, W]: torch.Size([256, 1, 28, 28])
Shape of y: torch.Size([256]) torch.int64


In [4]:
# Get cpu, gpu or mps device for training.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using cuda device


In [5]:
num_clients = 5
train_size = len(training_data)
# indices = list(range(train_size))

RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
torch.random.manual_seed(RANDOM_SEED)
indices = torch.randperm(train_size).tolist()

subset_size = train_size // num_clients
client_subsets = [] 
for i in range(num_clients):
    start_idx = i * subset_size
    end_idx = start_idx + subset_size

    if i == num_clients - 1:
        end_idx = train_size

    subset_indices = indices[start_idx:end_idx]
    client_subsets.append(Subset(training_data, subset_indices))

client_loaders = [DataLoader(sub, batch_size=batch_size, shuffle=True) for sub in client_subsets]

In [6]:
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(), 
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits 

model = NeuralNetwork().to(device)
model

NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)

In [28]:
# We now have 5 different datasets, each with some sort of representation of the data that is unknown, ie, we have no 
# statistical information on the data that each of these clients would have
# We now need to implement variations of the 3 protocols, namely, the encoding protocol, the communication protocol and the decoding protocol


# Encoders
def variable_size_encoder(grad_vectors, mu, p=0.1):
    # Lets take p = 0.1
    new_grad_vectors = []
    with torch.no_grad():
        for i in range(len(grad_vectors)):
            mask = torch.rand_like(grad_vectors[i], device=grad_vectors[i].device) < p
            Y = torch.empty_like(grad_vectors[i], device=grad_vectors[i].device)
            Y[mask] = (grad_vectors[i][mask] - mu[i] * (1-p))/p
            Y[~mask] = mu[i]
            new_grad_vectors.append(Y)
    return new_grad_vectors

def fixed_size_encoder(grad_vectors, mu, k=10):
    # k can vary
    new_grad_vectors = []
    with torch.no_grad():
        for i in range(len(grad_vectors)):
            shape = grad_vectors[i].shape
            # Flattening the parameters to permutate over them
            flat_grad = grad_vectors[i].view(-1)
            C = shape[-1]
            # Get the length of the flat_grad array
            d = flat_grad.numel()
            # Shuffle the list [1, 2, ... d] and get the first k elements
            indices = torch.randperm(C, device=flat_grad.device)[:k]
             
            mask = torch.zeros(d, dtype=torch.bool, device=flat_grad.device)
            mask[indices] = True
            
            Y = torch.empty_like(flat_grad)
            # Encode the parameters
            chosen_vals = (d/k)*flat_grad[mask] - ((d-k)/k)*mu[i]
            Y[mask] = chosen_vals
            Y[~mask] = mu[i]
            Y = Y.view(shape)
            new_grad_vectors.append(Y)
    return new_grad_vectors
            
            
# Decoders 
def averaging_decoder(grad_vectors_list):
    if isinstance(grad_vectors_list, list):
        grad_vectors_list = torch.stack(grad_vectors_list, dim=0)
    return torch.mean(grad_vectors_list, dim=0)

# Communication protocols
def sparse_for_variable_size_encoder(encoded_vectors, mu):
    final_vectors = []
    with torch.no_grad():
        for i in range(len(encoded_vectors)):
            flat_vector = encoded_vectors[i].view(-1)
            mask = flat_vector != mu[i]
            # vals = encoded_vectors[i][mask]
            indices = torch.nonzero(mask, as_tuple=False).view(-1)
            values = flat_vector[mask]
            final_vectors.append(list(zip(indices, values)))

    
    return final_vectors, mu
    
def sparse_for_fixed_size_encoder(grad_vectors):
    pass

parameters = list(model.parameters())
mu_1 = []
with torch.no_grad():
    for p in parameters:
        mu_1.append(torch.mean(p))

encoded_vectors = fixed_size_encoder(parameters, mu_1)
sparse_for_variable_size_encoder(encoded_vectors, mu_1)

([[(tensor(23, device='cuda:0'), tensor(-627.2519, device='cuda:0')),
   (tensor(28, device='cuda:0'), tensor(864.1521, device='cuda:0')),
   (tensor(87, device='cuda:0'), tensor(977.1844, device='cuda:0')),
   (tensor(166, device='cuda:0'), tensor(-1088.3890, device='cuda:0')),
   (tensor(225, device='cuda:0'), tensor(1006.3676, device='cuda:0')),
   (tensor(367, device='cuda:0'), tensor(-491.5223, device='cuda:0')),
   (tensor(400, device='cuda:0'), tensor(302.8898, device='cuda:0')),
   (tensor(584, device='cuda:0'), tensor(1104.7645, device='cuda:0')),
   (tensor(600, device='cuda:0'), tensor(-789.7195, device='cuda:0')),
   (tensor(694, device='cuda:0'), tensor(-803.2892, device='cuda:0'))],
  [(tensor(53, device='cuda:0'), tensor(-1.5444, device='cuda:0')),
   (tensor(181, device='cuda:0'), tensor(-0.3436, device='cuda:0')),
   (tensor(208, device='cuda:0'), tensor(1.4901, device='cuda:0')),
   (tensor(229, device='cuda:0'), tensor(0.2433, device='cuda:0')),
   (tensor(249, devic

In [11]:
a = torch.FloatTensor([[1, 2, 3], [2, 3, 4], [4, 5, 6]])
averaging_decoder(a)

tensor([2.3333, 3.3333, 4.3333])

In [15]:
parameters = list(model.parameters())
variable_size_encoder(parameters)

torch.Size([512, 784])
torch.Size([512, 784])
torch.Size([512])
torch.Size([512])
torch.Size([512, 512])
torch.Size([512, 512])
torch.Size([512])
torch.Size([512])
torch.Size([10, 512])
torch.Size([10, 512])
torch.Size([10])
torch.Size([10])


In [42]:
models = []


6