### Limit Scope to Only AGR Testing, No Sockets, No Implementation, Just Logits Aggregation

In [14]:
# Imports for model
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torchvision import transforms
import torchvision.datasets as datasets
from torch.utils.data import Dataset, DataLoader, random_split
import torch
from torch.optim import Adam, SGD

In [15]:
# Number of client models. The training set is split into NUM_MODELS + 1 equal
# parts (the extra part is the server's public data), so try to keep 60,000
# divisible by NUM_MODELS + 1.
NUM_MODELS = 9
# Passed as `epochs` to MnistModel.fit in every training round.
EPOCHS_PER_ROUND = 10
# NOTE(review): not referenced by any other visible cell — confirm whether it is still needed.
model_list = []

In [16]:
# Data fetching
transform = transforms.ToTensor()
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Randomly split the training set into NUM_MODELS + 1 parts.
# Split 0 is the server's public data; splits 1..NUM_MODELS are used by the models.
# random_split requires the lengths to sum to len(dataset), so when the dataset
# size is not divisible by the number of splits the remainder is handed out one
# sample at a time to the first splits instead of raising.
num_splits = NUM_MODELS + 1
base_size, remainder = divmod(len(mnist_trainset), num_splits)
training_sizes = [
    base_size + (1 if split_idx < remainder else 0)
    for split_idx in range(num_splits)
]
training_datasets = random_split(dataset=mnist_trainset, lengths=training_sizes)

batch_size = 16

# One shuffling loader per split. The loop variable deliberately has its own
# name so it does not shadow the `training_datasets` list that later cells index.
training_loaders = [
    DataLoader(split, batch_size=batch_size, shuffle=True)
    for split in training_datasets
]

# Load test set
testing_loader = DataLoader(mnist_testset, batch_size=batch_size, shuffle=True)


In [None]:
# FOR THE PUBLIC DATA SET
class PublicDistillationDataset(torch.utils.data.Dataset):
    """Dataset that pairs each input of a base dataset with a supplied soft label.

    Item ``idx`` is ``(x, soft_labels[idx])`` where ``x`` is the first element
    of ``base_dataset[idx]``; the base dataset's own label is discarded.
    ``soft_labels[idx]`` must therefore describe the same sample as
    ``base_dataset[idx]``.

    Args:
        base_dataset: indexable, sized dataset yielding ``(x, label)`` pairs.
        soft_labels: indexable, sized collection with one entry per base sample.

    Raises:
        ValueError: if ``soft_labels`` and ``base_dataset`` differ in length.
    """

    def __init__(self, base_dataset, soft_labels):
        # A length mismatch would otherwise surface late (IndexError mid-epoch)
        # or not at all (extra labels silently ignored).
        if len(soft_labels) != len(base_dataset):
            raise ValueError(
                f"soft_labels has {len(soft_labels)} entries but base_dataset "
                f"has {len(base_dataset)} samples"
            )
        self.base = base_dataset
        self.soft_labels = soft_labels

    def __len__(self):
        """Return the number of samples in the base dataset."""
        return len(self.base)

    def __getitem__(self, idx):
        """Return ``(x, soft_label)`` for sample ``idx``."""
        x, _ = self.base[idx]   # ignore original label
        return x, self.soft_labels[idx]

In [None]:
# FOR CRONUS AGR
def distillation_loss(student_logits, teacher_logits, T=1.0):
    """Temperature-scaled KL divergence between teacher and student predictions.

    Both inputs are raw logits; they are divided by ``T`` and normalised over
    dim 1 (student with log-softmax, teacher with softmax). The KL term uses
    ``reduction="batchmean"`` and is multiplied by ``T * T``.
    """
    student_log_probs = F.log_softmax(student_logits / T, dim=1)
    teacher_probs = F.softmax(teacher_logits / T, dim=1)
    divergence = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
    return divergence * (T * T)


In [None]:
# BENIGN MODEL DEFINITION
class MnistModel(nn.Module):
    """Three-layer fully connected classifier: 784 -> 256 -> 64 -> 10.

    ``forward`` expects already-flattened inputs of 784 features and returns
    raw (unnormalised) logits for 10 classes.
    """

    def __init__(self) -> None:
        super().__init__()
        self.lin1 = nn.Linear(784, 256)
        self.lin2 = nn.Linear(256, 64)
        self.lin3 = nn.Linear(64, 10)

    def forward(self, X):
        """Return class logits for a batch of flattened inputs."""
        hidden1 = F.relu(self.lin1(X))
        hidden2 = F.relu(self.lin2(hidden1))
        # No activation on the output layer: the cross-entropy and distillation
        # losses both apply (log-)softmax themselves, and a ReLU here would
        # clamp every negative logit to zero.
        return self.lin3(hidden2)

    # Fit function
    def fit(self, dataloader, optimizer, loss_fn, epochs, public_loader=None):
        """Train the model in place.

        Each epoch runs one pass over ``dataloader`` with ``loss_fn``. When
        ``public_loader`` is given, the same epoch then runs one pass over it
        using ``distillation_loss``, because its targets are soft-label
        vectors rather than single integer class ids.

        Args:
            dataloader: yields ``(x, y)`` batches; ``x`` is flattened to
                ``(batch, -1)`` before the forward pass.
            optimizer: optimizer over this model's parameters.
            loss_fn: callable ``loss_fn(preds, y)`` for the local batches.
            epochs: number of epochs to run.
            public_loader: optional loader yielding ``(x, teacher_logits)``.
        """
        # Callers may have switched the model to eval mode for inference.
        self.train()

        for _ in range(epochs):
            for x, y in dataloader:
                x = x.view(x.size(0), -1)
                optimizer.zero_grad()

                preds = self(x)
                loss = loss_fn(preds, y)

                loss.backward()
                optimizer.step()

            if public_loader is not None:
                for x, teacher_logits in public_loader:
                    x = x.view(x.size(0), -1)
                    optimizer.zero_grad()

                    preds = self(x)
                    loss = distillation_loss(preds, teacher_logits)

                    loss.backward()
                    optimizer.step()

#TODO: Malicious Models

In [None]:
# INITIAL TRAINING
# Index 0 is a placeholder so that models[k] lines up with training_loaders[k]
# (loader 0 holds the server's public data and has no model of its own).
models = [None]
for client_id in range(1, NUM_MODELS + 1):
    print(client_id)
    models.append(MnistModel())

# Planned schedule: Adam for the local training phase and the first 50 epochs of
# the collaborative phase, SGD for the last 50 epochs of the collaborative phase.
# NOTE: this cell currently trains with SGD.

# The same loss function is used in every epoch.
loss_fn = nn.CrossEntropyLoss()

# FIRST ROUND: each model trains only on its own private split.
for client_id in range(1, NUM_MODELS + 1):
    print(f"Training model {client_id}")

    client_model = models[client_id]
    client_model.fit(
        dataloader=training_loaders[client_id],
        optimizer=SGD(client_model.parameters(), lr=1e-3),
        loss_fn=loss_fn,
        epochs=EPOCHS_PER_ROUND,
    )

1
2
3
4
5
6
7
8
9
Training model 1
Training model 2
Training model 3
Training model 4
Training model 5
Training model 6
Training model 7
Training model 8
Training model 9


In [None]:
# INITIAL Predictions
# The public split must be visited in a fixed order: training_loaders[0]
# reshuffles on every pass, which would give each model a different sample
# order and break the row-k <-> training_datasets[0][k] correspondence that the
# per-sample aggregation and the distillation dataset rely on.
public_eval_loader = DataLoader(training_datasets[0], batch_size=batch_size, shuffle=False)

for i in range(1, NUM_MODELS + 1):
    models[i].eval()

all_model_preds = []

with torch.no_grad():
    for i in range(1, NUM_MODELS + 1):
        print(f'Getting predictions from model {i}\n')
        model = models[i]
        model_preds = []

        for x, _ in public_eval_loader:   # labels are not needed
            x = x.view(x.size(0), -1)
            preds = model(x)      # logits
            model_preds.append(preds)

        # shape: (num_public_samples, 10), rows in dataset order
        model_preds = torch.cat(model_preds, dim=0)
        all_model_preds.append(model_preds)

all_model_preds[0].shape

Getting predictions from model 1

Getting predictions from model 2

Getting predictions from model 3

Getting predictions from model 4

Getting predictions from model 5

Getting predictions from model 6

Getting predictions from model 7

Getting predictions from model 8

Getting predictions from model 9



torch.Size([6000, 10])

In [None]:
# AGR FOR LOGITS AND CUSTOM DATA LOADER TO PASS TO FIT

def cronus_robust_mean(
    Y,              # Tensor [n_models, num_classes]
    eps=0.2,        # fraction of adversaries (<= 0.5)
    iters=2
):
    """
    Cronus-style robust mean of the logits for ONE public sample.

    Y holds one row of logits per model. For at most `iters` rounds the rows
    that deviate most along the top principal direction are dropped; the mean
    of the surviving rows is returned. Y itself is not modified.
    """
    survivors = Y.clone()

    for _ in range(iters):
        n_rows = survivors.size(0)
        deviations = survivors - survivors.mean(dim=0)

        # covariance of the current survivors
        covariance = deviations.T @ deviations / n_rows

        # eigh returns eigenvalues in ascending order, so the last entry is the largest
        eigenvalues, eigenvectors = torch.linalg.eigh(covariance)
        top_direction = eigenvectors[:, -1]
        top_eigenvalue = eigenvalues[-1]

        # stop once the spread is small enough (paper uses threshold 9)
        if top_eigenvalue <= 9:
            break

        # size of each row's deviation along the principal direction
        scores = torch.abs(deviations @ top_direction)

        # keep the (1 - eps/2) fraction with the smallest scores (deterministic)
        n_keep = int((1 - eps/2) * n_rows)
        kept_rows = torch.topk(scores, n_keep, largest=False).indices
        survivors = survivors[kept_rows]

    return survivors.mean(dim=0)

def cronus_aggregate_all(
    all_model_preds,   # list of [num_public, num_classes]
    eps=0.2
):
    """
    Aggregate every model's logits into one soft label per public sample.

    For each sample index k, row k of every tensor in `all_model_preds` is
    stacked and passed to cronus_robust_mean. The number of samples and
    classes is taken from the first tensor.

    Returns Y_bar: Tensor [num_public, num_classes]
    """
    num_public, num_classes = all_model_preds[0].shape
    Y_bar = torch.zeros(num_public, num_classes)

    for sample_idx in range(num_public):
        # one row per model, all for the same public sample
        per_model_logits = torch.stack(
            [model_preds[sample_idx] for model_preds in all_model_preds]
        )
        Y_bar[sample_idx] = cronus_robust_mean(per_model_logits, eps=eps)

    return Y_bar


# Aggregated soft labels, one row per public sample.
Y_bar = cronus_aggregate_all(all_model_preds, eps=0.2)

# Pair the server/public split (index 0) with the aggregated soft labels;
# row k of Y_bar is served together with sample k of that split.
public_dataset = PublicDistillationDataset(training_datasets[0], Y_bar)

# Shuffled mini-batches of (input, soft label) for the distillation pass in fit.
public_loader = DataLoader(public_dataset, batch_size=16, shuffle=True)


In [None]:
# ROUND 2 TRAINING: every model trains on its own split and then distils from
# the aggregated public soft labels.
# TODO: Inf (original note; presumably "repeat for further rounds" — confirm)
for client_id in range(1, NUM_MODELS + 1):
    print(f"Training model {client_id}")

    client_model = models[client_id]
    # A fresh SGD optimizer is created for each model in each round.
    client_model.fit(
        dataloader=training_loaders[client_id],
        optimizer=SGD(client_model.parameters(), lr=1e-3),
        loss_fn=loss_fn,
        epochs=EPOCHS_PER_ROUND,
        public_loader=public_loader,
    )

Training model 1
Training model 2
Training model 3
Training model 4
Training model 5
Training model 6
Training model 7
Training model 8
Training model 9
