In [1]:
import torch 
import torch.nn as nn
from tqdm import tqdm
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import Subset, DataLoader
import copy

### 1. Création du dataset

In [2]:
# Baseline preprocessing: convert images to tensors, then normalize the single
# channel with mean 0.5 and standard deviation 0.5.
original_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# MNIST training split, downloaded on first use into the user's cache folder.
original_trainset = datasets.MNIST(
    '~/.pytorch/MNIST_data/',
    train=True,
    download=True,
    transform=original_transform,
)

# Function to create rotated dataset
def create_rotated_dataset(dataset, rotation_angle):
    """Return a deep copy of ``dataset`` whose images are rotated by a fixed angle.

    Parameters:
    dataset: dataset object exposing a ``transform`` attribute; it is left untouched.
    rotation_angle (number): rotation, in degrees, applied to every image.

    Returns:
    A deep copy of ``dataset`` whose ``transform`` is ToTensor -> rotation -> Normalize.
    """
    # RandomRotation(a) samples an angle in [-a, +a]; passing the degenerate
    # range (a, a) makes every image rotate by exactly ``rotation_angle``.
    fixed_angle = (rotation_angle, rotation_angle)
    rotated_transform = transforms.Compose([
        transforms.ToTensor(),
        # Rotate before normalizing so that pixels filled in with 0 by the
        # rotation get normalized like the rest of the image.
        transforms.RandomRotation(fixed_angle),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    rotated_dataset = copy.deepcopy(dataset)
    rotated_dataset.transform = rotated_transform
    return rotated_dataset

# Function to create private label dataset
def create_private_label_dataset(dataset, cluster_index, num_classes=10):
    """Return a deep copy of ``dataset`` whose labels are cyclically shifted.

    Every label ``y`` becomes ``(y + cluster_index) % num_classes``.

    Parameters:
    dataset: dataset object exposing a ``targets`` attribute (tensor or sequence
        of integer labels); it is left untouched.
    cluster_index (int): shift applied to every label.
    num_classes (int): number of distinct labels; defaults to 10.

    Returns:
    A deep copy of ``dataset`` with ``targets`` replaced by the shifted labels
    (as a tensor).
    """
    private_label_dataset = copy.deepcopy(dataset)
    # as_tensor is a no-op for tensors and lets plain lists of labels work too.
    targets = torch.as_tensor(private_label_dataset.targets)
    private_label_dataset.targets = (targets + cluster_index) % num_classes
    return private_label_dataset

# Create rotated datasets for each cluster
rotated_datasets = [create_rotated_dataset(original_trainset, k * 90) for k in range(4)]

# Create private label datasets for each cluster
datasets = [create_private_label_dataset(rotated_datasets[k], k) for k in range(4)]

# Split each dataset into clusters
cluster_size = len(original_trainset) // 5

clustered_datasets = [Subset(datasets[k], range(i * cluster_size, (i + 1) * cluster_size)) for k in range(4) for i in range(5)]

# Create dataloaders for each cluster
# rotated_dataloaders = [DataLoader(dataset, batch_size=64, shuffle=True) for dataset in clustered_rotated_datasets]
dataloaders = [DataLoader(dataset, batch_size=64, shuffle=True) for dataset in clustered_datasets]
print(len(dataloaders))

20


### 2. ALGORITHME DU PAPIER

In [6]:
class Model(nn.Module):
    """Small convolutional classifier for 28x28 single-channel images.

    Three 5x5 convolutions (4, 8 and 16 output channels, padding 2) with ReLU
    activations; the first two are each followed by 2x2 max pooling. A final
    linear layer maps the flattened 16x7x7 features to 10 outputs.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=5, padding=2)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(in_channels=4, out_channels=8, kernel_size=5, padding=2)
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.conv3 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=5, padding=2)
        self.fc = nn.Linear(in_features=16 * 7 * 7, out_features=10)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return the 10 raw class scores for each image in ``x``."""
        # Accept any layout that can be viewed as (batch, 1, 28, 28).
        images = x.view(-1, 1, 28, 28)
        features = self.pool1(self.relu(self.conv1(images)))
        features = self.pool2(self.relu(self.conv2(features)))
        features = self.relu(self.conv3(features))
        # Flatten everything but the batch dimension before the linear layer.
        flat = features.view(features.size(0), -1)
        return self.fc(flat)

# One independently initialised model per client, keyed 'client0' ... 'client19'.
models = {f'client{i}': Model() for i in range(20)}

In [7]:
# THRESHOLD CLUSTERING

def threshold_clustering(Z, K, initial_centers):
    """
    Perform one round of Threshold-Clustering on a set of points using PyTorch.

    For each cluster k, the radius tau_k is the distance from the current
    center to its (N // K)-th closest point. The center is then replaced by the
    mean over all N points, where points outside the radius contribute the old
    center instead of themselves.

    Parameters:
    Z (List[Tensor]): the N points (moments) to be clustered, each of shape (d,)
    K (int): Number of clusters
    initial_centers (List[Tensor] or Tensor): Initial cluster centers, shape (K, d)

    Returns:
    Tuple[List[Tensor], Tensor]: the K updated cluster centers (each of shape
    (d,)) and, for every point, the index of its closest updated center
    (integer tensor of shape (N,)).

    Raises:
    ValueError: if Z is empty.
    """
    if len(Z) == 0:
        raise ValueError("threshold_clustering needs at least one point")

    points = torch.stack(list(Z))  # (N, d)
    n_points = points.shape[0]
    # Number of closest points that define the radius; at least 1 so that
    # having fewer points than clusters does not crash topk / indexing.
    n_closest = max(1, n_points // K)

    # Work on copies so the caller's initial centers are never modified.
    centers = [center.clone() for center in initial_centers]

    for k in range(K):
        distances = (points - centers[k]).norm(dim=1)
        # largest=False: topk returns the largest values by default, which would
        # give the distance to one of the farthest points instead of the
        # n_closest-th closest one.
        tau = torch.topk(distances, n_closest, largest=False).values[-1]
        within_radius = (distances <= tau).unsqueeze(1)
        # Points outside the radius are replaced by the current center before averaging.
        centers[k] = torch.where(within_radius, points, centers[k].expand_as(points)).mean(dim=0)

    # Find the closest (updated) center for each point
    distances_to_centers = torch.stack([(points - center).norm(dim=1) for center in centers])
    closest_center_indices = torch.argmin(distances_to_centers, dim=0)
    return centers, closest_center_indices

In [13]:
def _flat_grad(model):
    """Concatenate the gradients of all parameters of ``model`` into one 1-D tensor (zeros where a gradient is missing)."""
    return torch.cat([
        (param.grad if param.grad is not None else torch.zeros_like(param)).reshape(-1)
        for param in model.parameters()
    ])


def trainer(dataloaders, models, epoch=1, batch_size=16, rate=1e-3, momentum=0.9,
            n_clusters=4, recluster_every=10, max_steps=100):
    """Train every client model with cluster-averaged momentum gradients.

    Each client keeps a running moment of its own flattened gradient. Every
    ``recluster_every`` steps the moments are clustered with
    ``threshold_clustering``; at every step each client replaces its gradient by
    the center of the cluster it was last assigned to and takes one SGD step.

    Parameters:
    dataloaders (list): one iterable of (sample, target) batches per client.
    models (dict): maps 'client{i}' to the model of client i, for i in range(len(dataloaders)).
    epoch (int): number of passes over the (zipped) dataloaders.
    batch_size (int): unused; kept for backward compatibility (the batch size is
        set by the dataloaders).
    rate (float): SGD learning rate.
    momentum (float): weight of the newest gradient in the moment update
        ``m = momentum * g + (1 - momentum) * m``. The SGD optimizers keep their
        own fixed momentum of 0.9, independent of this value.
    n_clusters (int): number of clusters passed to ``threshold_clustering``.
    recluster_every (int): number of steps between two clustering rounds.
    max_steps (int): maximum number of steps per epoch.

    Returns:
    None. The models are updated in place.

    Raises:
    ValueError: if the dataloaders do not yield a single batch.
    """
    n_clients = len(dataloaders)
    client_models = [models[f'client{i}'] for i in range(n_clients)]

    optimizers = [torch.optim.SGD(model.parameters(), lr=rate, momentum=0.9) for model in client_models]

    loss_fn = nn.CrossEntropyLoss()

    # Initialise every client's moment with the gradient of its first batch.
    first_batches = next(iter(zip(*dataloaders)), None)
    if first_batches is None:
        raise ValueError("dataloaders yielded no batches")

    moments = []
    for i, (sample, target) in enumerate(first_batches):
        optimizers[i].zero_grad()
        loss_fn(client_models[i](sample), target).backward()
        moments.append(_flat_grad(client_models[i]))

    # Start from the moments of randomly chosen, distinct clients.
    center_clusters = [moments[i] for i in np.random.choice(n_clients, n_clusters, replace=False)]

    for _ in range(epoch):
        step = 0
        for batches in tqdm(zip(*dataloaders)):
            for i, (sample, target) in enumerate(batches):
                optimizers[i].zero_grad()
                loss = loss_fn(client_models[i](sample), target)
                # The gradient must be read AFTER backward(); reading it right
                # after zero_grad() would always feed zeros into the moments.
                loss.backward()
                moments[i] = momentum * _flat_grad(client_models[i]) + (1 - momentum) * moments[i]

            if step % recluster_every == 0:
                center_clusters, clustering = threshold_clustering(moments, n_clusters, center_clusters)

            for i, model in enumerate(client_models):
                center = center_clusters[int(clustering[i])]
                offset = 0
                # Overwrite the gradient with the matching slice of the cluster center.
                for param in model.parameters():
                    n_values = param.numel()
                    # clone(): the center is shared by all clients of the cluster,
                    # so no in-place operation on .grad may write into it.
                    param.grad = center[offset:offset + n_values].reshape(param.shape).clone()
                    offset += n_values

                optimizers[i].step()

            step += 1

            if step == max_steps:
                break

In [14]:
# Run the clustered training loop on the client dataloaders and models.
trainer(
    dataloaders,
    models,
    epoch=1,
    batch_size=16,
    rate=1e-3,
    momentum=0.9,
)

99it [01:07,  1.46it/s]
