<a href="https://colab.research.google.com/github/aWolander/Cryptography/blob/main/FederatedLearning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from __future__ import annotations
import copy
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Subset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import random
import numpy as np
# vits16 = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')

# print(vits16)


In [None]:
def main():
    """Download CIFAR-100, split its training portion across clients, and visualise the split.

    Uses 2 clients with 10 classes each; only the training portion of the
    70/15/15 random split is distributed.
    """
    num_clients = 2
    classes_per_client = 10
    cifar100 = datasets.CIFAR100(root="data", download=True, transform=ToTensor())
    fractions = [0.7, 0.15, 0.15]
    train_dataset, validate_dataset, test_dataset = torch.utils.data.random_split(cifar100, fractions)
    per_client = split_data_non_iid(train_dataset, num_clients, classes_per_client)
    test_split(per_client)

def _dataset_labels(dataset):
    """Return the label of every sample in *dataset*, in index order.

    Fast path: for a ``Subset`` whose underlying dataset exposes a ``targets``
    sequence and no ``target_transform`` (as torchvision's CIFAR datasets do),
    labels are read directly instead of loading and transforming every image.
    Otherwise every sample is fetched and its second element is used as the label.
    """
    if isinstance(dataset, Subset):
        base = dataset.dataset
        targets = getattr(base, "targets", None)
        if targets is not None and getattr(base, "target_transform", None) is None:
            return [int(targets[i]) for i in dataset.indices]
    return [label for _, label in dataset]


def create_label_indexing(dataset, num_classes=100):
    """Group the positions of *dataset*'s samples by label.

    Args:
        dataset: indexable dataset yielding ``(sample, label)`` pairs with integer
            labels in ``[0, num_classes)``.
        num_classes: number of label buckets to create (100 for CIFAR-100).

    Returns:
        dict mapping every label ``0..num_classes-1`` to the list of positions
        (indices into *dataset*, ascending) whose sample has that label. Labels
        with no samples map to an empty list.

    Raises:
        KeyError: if a label falls outside ``[0, num_classes)``.
    """
    label_index = {label: [] for label in range(num_classes)}
    for idx, label in enumerate(_dataset_labels(dataset)):
        label_index[label].append(idx)
    return label_index




def split_data_non_iid(dataset, K, N_c=20, seed=42):
    """Split *dataset* into K non-IID client subsets, each holding exactly N_c classes.

    Classes are dealt to clients round-robin over a seeded random permutation of
    the labels. Each client receives N_c *distinct* classes, and the K * N_c class
    slots are spread as evenly as possible over all labels. Each class's samples
    are then divided among the clients holding that class with ``split``. A class
    assigned to no client (possible when K * N_c is smaller than the number of
    classes) is left out of every subset.

    Args:
        dataset: indexable dataset of ``(sample, label)`` pairs.
        K: number of clients.
        N_c: number of distinct classes per client.
        seed: seed for a private RNG (the global ``random`` state is untouched).

    Returns:
        list of K ``Subset`` objects; each client's indices are shuffled.

    Raises:
        ValueError: if N_c is not between 1 and the number of classes.
    """
    label_index = create_label_indexing(dataset)
    labels = list(label_index)
    if not 1 <= N_c <= len(labels):
        raise ValueError(f"N_c must be between 1 and {len(labels)}, got {N_c}")

    rng = random.Random(seed)
    order = rng.sample(labels, len(labels))  # random permutation of the labels

    # Client c takes N_c consecutive entries (mod len(order)) of the permutation.
    # Because N_c <= len(order), these are distinct, so c holds exactly N_c classes.
    class_to_clients = {label: [] for label in labels}
    for client in range(K):
        for offset in range(N_c):
            label = order[(client * N_c + offset) % len(order)]
            class_to_clients[label].append(client)

    client_indices = [[] for _ in range(K)]
    for label, holders in class_to_clients.items():
        if not holders:
            continue  # class not assigned to any client
        for client, part in zip(holders, split(label_index[label], len(holders))):
            client_indices[client].extend(part)

    # Mix the classes within each client's index list.
    for indices in client_indices:
        rng.shuffle(indices)

    return [Subset(dataset, indices) for indices in client_indices]

def split(l: list, n: int) -> list[list]:
    """Deal the items of *l* round-robin into *n* slices.

    Slice ``k`` holds items ``k, k + n, k + 2n, ...``, so slice sizes differ by at
    most one. A non-positive *n* yields an empty list.
    """
    return [l[start::n] for start in range(n)]

def split_data_iid(dataset, K, seed=None):
    """Split *dataset* into K IID client subsets with near-identical class balance.

    The samples of every class are dealt round-robin to the K clients (via
    ``split``), so each client receives roughly 1/K of every class.

    Args:
        dataset: indexable dataset of ``(sample, label)`` pairs.
        K: number of clients.
        seed: optional seed for reproducible shuffling. With the default ``None``,
            the global ``random`` module is used, as before.

    Returns:
        list of K ``Subset`` objects whose indices are shuffled per client.
    """
    rng = random if seed is None else random.Random(seed)
    client_indices = [[] for _ in range(K)]

    # Iterate whatever labels the index contains instead of assuming 100 classes.
    for indices in create_label_indexing(dataset).values():
        for client, part in enumerate(split(indices, K)):
            client_indices[client].extend(part)

    for indices in client_indices:
        rng.shuffle(indices)
    return [Subset(dataset, indices) for indices in client_indices]

def test_split(client_datasets, num_classes=100):
    """Print per-client label statistics and show a stacked per-class histogram.

    For every client this prints how many classes it holds (nonzero label counts)
    and the mean / standard deviation of those nonzero counts. One bar series per
    client is stacked on top of the previous clients' series.

    Args:
        client_datasets: iterable of datasets yielding ``(sample, label)`` pairs.
        num_classes: number of label bins on the x axis.
    """
    bottom = np.zeros(num_classes)
    for client_id, client_dataset in enumerate(client_datasets):
        occurrences = np.zeros(num_classes)
        for sample in client_dataset:
            occurrences[sample[1]] += 1

        present = occurrences[occurrences > 0]
        # Avoid numpy's empty-slice warnings for a client without any samples.
        mean, std = (present.mean(), present.std()) if present.size else (float("nan"), float("nan"))
        print(f"Client {client_id}: No. of nonzero elements: {np.count_nonzero(occurrences)} | Avg., stdev. of nonzero elements: {mean}, {std}")

        plt.bar(range(num_classes), occurrences, bottom=bottom, label=f"Client {client_id}")
        bottom += occurrences

    plt.xlabel("Class label")
    plt.ylabel("Number of samples")
    plt.legend()  # without this the per-client bar labels are never displayed
    plt.show()

# Run the client-split sanity check defined in main() when this cell executes.
main()

# NOTE(review): stale, unused helper kept for reference. Iterating a client
# Subset yields (image, label) samples rather than indices, so
# `dataset[index]` below would not work as written; test_split() above
# performs the same per-client label count.
# def test_split_index(dataset):
#     K= 2
#     N = 5
#     client_data_split = split_data_non_iid(dataset, K, N)
#     for client_data in client_data_split:
#         occurences = {}
#         for index in client_data:
#             occurences[dataset[index][1]] = occurences.setdefault(dataset[index][1], 0) + 1
#         print(occurences)
#         plt.hist(occurences, stacked=True, bins=100)
#     plt.show()


In [None]:
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
class DinoFullModel(nn.Module):
    """DINO ViT-S/16 backbone plus a linear classification head, with a train/eval loop.

    The ``dino_vits16`` torch.hub entry point provides the backbone only: its
    forward pass returns one embedding vector per image, not class scores. A
    ``Linear(embed_dim, num_classes)`` head is therefore appended. Both parts live
    in ``self.model`` (an ``nn.Sequential``), so any code that calls
    ``self.model(x)`` or walks ``self.model.named_parameters()`` sees the complete
    classifier.

    NOTE(review): in this notebook the inputs come from ``ToTensor()`` only (no
    resize or mean/std normalisation). Confirm this matches what the pretrained
    backbone expects.
    """

    def __init__(self, num_classes=100):
        super().__init__()
        backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
        head = nn.Linear(backbone.embed_dim, num_classes)
        self.model = nn.Sequential(backbone, head).to(device)

        self.learning_rate = 1e-2
        self.epochs = 10

        self.loss_fn = nn.CrossEntropyLoss()
        # Built after self.model so the optimiser covers both backbone and head.
        self.optimizer = torch.optim.SGD(self.parameters(), lr=self.learning_rate)
        # The cosine schedule is stepped once per epoch, over self.epochs epochs.
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=self.epochs
        )

    def forward(self, x):
        """Return class logits for the image batch *x*."""
        return self.model(x)

    def train_model(self, dataloader):
        """Train for ``self.epochs`` epochs over *dataloader*, printing progress.

        Args:
            dataloader: DataLoader yielding ``(images, labels)`` batches.
        """
        size = len(dataloader.dataset)
        # Enable training-mode behaviour (e.g. dropout) in all submodules.
        self.train()
        for epoch in range(self.epochs):
            current = 0
            print(f"-------------------------------\nEpoch {epoch+1}\n-------------------------------")
            for X, y in dataloader:
                X = X.to(device)
                y = y.to(device)
                loss = self.loss_fn(self(X), y)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                current += len(X)
                if current % 5000 == 0:
                    print(f"loss: {loss.item():>7f}  [{current:>5d}/{size:>5d}]")

            self.scheduler.step()

    def test_model(self, dataloader):
        """Evaluate on *dataloader* without gradient tracking.

        Returns:
            ``(avg_loss, total_correct, accuracy)``: the per-sample mean loss, the
            number of correct top-1 predictions, and the fraction correct.

        Raises:
            ZeroDivisionError: if *dataloader* yields no samples.
        """
        self.model.eval()
        total_loss = 0.0
        total_correct = 0
        total_samples = 0

        with torch.no_grad():
            for inputs, targets in dataloader:
                inputs, targets = inputs.to(device), targets.to(device)

                outputs = self(inputs)
                loss = self.loss_fn(outputs, targets)
                # loss is a batch mean; re-weight by batch size for a per-sample mean.
                total_loss += loss.item() * inputs.size(0)

                preds = outputs.argmax(dim=1)
                total_correct += (preds == targets).sum().item()
                total_samples += targets.size(0)

        avg_loss = total_loss / total_samples
        accuracy = total_correct / total_samples
        return avg_loss, total_correct, accuracy

def main():
    """Centralised (non-federated) baseline: fine-tune DinoFullModel on CIFAR-100 and report test metrics."""
    # Load the dataset here. The module-level CIFAR100 object is only created in a
    # later cell, so referencing it from this cell raised NameError.
    cifar100 = datasets.CIFAR100(root="data", download=True, transform=ToTensor())
    train_dataset, validate_dataset, test_dataset = torch.utils.data.random_split(cifar100, [0.7, 0.15, 0.15])

    # Reshuffle the training data every epoch.
    train_dataloader = DataLoader(train_dataset, batch_size=100, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=100)
    validate_dataloader = DataLoader(validate_dataset, batch_size=100)  # not used yet

    model = DinoFullModel().to(device)
    model.train_model(train_dataloader)

    avg_loss, total_correct, accuracy = model.test_model(test_dataloader)
    print(f"Test Loss: {avg_loss:.4f}, Correct: {total_correct} out of {len(test_dataset)}, Accuracy: {accuracy:.2%}")
    print("Done!")

# In a notebook kernel __name__ is "__main__", so executing this cell also runs
# the full centralised training baseline in main() above.
if __name__=="__main__":
  main()

In [None]:
class FL_client(DinoFullModel):
  '''
  A DinoFullModel with parameter-wise arithmetic (+, -, unary -, scalar *) so that
  FedAvg can be written as ``(1/n) * sum(n_k * client_k)``.

  Every operator returns a NEW client: a deep copy of the left operand whose
  parameters are overwritten with the result. The operands are never modified,
  and the copy shares no tensors with them.
  '''
  def __init__(self):
    super().__init__()

  def __add__(self, other: FL_client) -> FL_client:
    '''Return a client whose every parameter is ``self_param + other_param``.'''
    if not isinstance(other, FL_client):
      return NotImplemented
    result = copy.deepcopy(self)
    with torch.no_grad():
      # parameters() covers every learnable tensor, including ones whose names
      # contain neither "weight" nor "bias". Both clients share one architecture,
      # so zip pairs corresponding tensors in registration order.
      for result_param, other_param in zip(result.parameters(), other.parameters()):
        result_param.add_(other_param.to(result_param.device))
    return result

  def __mul__(self, multiplier: float|int) -> FL_client:
    '''Return a client whose every parameter is ``multiplier * self_param``.'''
    if not isinstance(multiplier, (int, float)):
      return NotImplemented
    result = copy.deepcopy(self)
    with torch.no_grad():
      for param in result.parameters():
        param.mul_(multiplier)
    return result

  def __rmul__(self, multiplier: float|int) -> FL_client:
    return self.__mul__(multiplier)

  def __neg__(self) -> FL_client:
    '''Return a client with every parameter negated.'''
    return self.__mul__(-1)

  def __sub__(self, other: FL_client) -> FL_client:
    '''Return a client whose every parameter is ``self_param - other_param``.'''
    if not isinstance(other, FL_client):
      return NotImplemented
    return self + (-other)



class FL_server():
  '''
  FedAvg coordinator: owns K FL_client instances and the aggregated global model.

  ``self.model`` is itself an FL_client with the same architecture as every client.
  This lets the global weights be broadcast with ``load_state_dict``, and lets the
  aggregate be formed with FL_client's ``+`` and ``*`` operators.
  '''
  def __init__(self, K):
    if K < 1:
      raise ValueError(f"K must be at least 1, got {K}")
    self.K = K
    self.training_rounds = 10

    self.loss_fn = nn.CrossEntropyLoss()
    # Independent clients: each FL_client() builds its own model and optimizer.
    self.clients = [FL_client() for _ in range(self.K)]
    # Initial global model. Selected clients are synchronised to it at the start of
    # every round, so local training always begins from the same weights.
    self.model = copy.deepcopy(self.clients[0])

  def forward(self, x):
    return self.model(x)

  def FedAvg(self, dataloaders_list, validate_dataloader, C):
    '''
    Run ``self.training_rounds`` rounds of Federated Averaging.

    Args:
      dataloaders_list: one DataLoader per client, indexed by client id.
      validate_dataloader: DataLoader used to evaluate the global model after each round.
      C: fraction of clients selected per round; m = max(int(C*K), 1), capped at K.

    Raises:
      ValueError: if the selected clients hold no training samples in a round.
    '''
    # random.sample needs an integer count, so C*K (a float for fractional C) is truncated.
    m = min(max(int(C * self.K), 1), self.K)
    for round_idx in range(self.training_rounds):
      selected_clients = random.sample(range(self.K), m)
      global_state = self.model.state_dict()
      weighted_sum = None
      m_t = 0  # total training samples across this round's selected clients
      for client_id in selected_clients:
        client = self.clients[client_id]
        loader = dataloaders_list[client_id]

        # Broadcast: local training starts from the current global model.
        # NOTE(review): only module weights are copied; each client's optimizer
        # and LR scheduler keep their state from previous rounds.
        client.load_state_dict(global_state)

        print(f"----------- TRAIN LOOP FOR CLIENT {client_id}. -----------")
        client.train_model(loader)
        n_k = len(loader.dataset)
        m_t += n_k
        contribution = n_k * client
        weighted_sum = contribution if weighted_sum is None else weighted_sum + contribution

      if m_t == 0:
        raise ValueError(f"selected clients hold no training samples in round {round_idx}")
      # The sample-size-weighted mean of the client models becomes the new global model.
      self.model = (1 / m_t) * weighted_sum
      print(f"----------- VALIDATION FOR SERVER. -----------")
      self.test_model(validate_dataloader)

  def test_model(self, dataloader):
    '''
    Evaluate the global model on *dataloader* and print the metrics.

    Returns:
      ``(avg_loss, total_correct, accuracy)`` over all samples in *dataloader*.
    '''
    self.model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
      for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)

        outputs = self.model(inputs)
        loss = self.loss_fn(outputs, targets)
        # loss is a batch mean; re-weight by batch size for a per-sample mean.
        total_loss += loss.item() * inputs.size(0)

        preds = outputs.argmax(dim=1)
        total_correct += (preds == targets).sum().item()
        total_samples += targets.size(0)

    avg_loss = total_loss / total_samples
    accuracy = total_correct / total_samples

    # Report the sample count, not len(dataloader), which is the number of batches.
    print(f"Validation Loss: {avg_loss:.4f}, Correct: {total_correct} out of {total_samples}, Accuracy: {accuracy:.2%}")

    return avg_loss, total_correct, accuracy



In [None]:
# TODO:
# - Normalization: the author recalls reading that the facebook (DINO) model needs
#   it. Confirm whether this means batch norm or input mean/std normalization.
# - Version control and checkpointing.
# - Testing and hyperparameter tuning.
# - Gradient mask (TaLoS), and extending it to the FL setting.
# - Plot accuracy and loss method in DinoFullModel.

# Resolve the device before any model is built.
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"

K = 5
FL = FL_server(K)

CIFAR100 = datasets.CIFAR100(
  root="data",
  download=True,
  transform=ToTensor()
)
train_dataset, validate_dataset, test_dataset = torch.utils.data.random_split(CIFAR100, [0.7, 0.15, 0.15])

split_dataset = split_data_iid(train_dataset, K)

# One loader per client. shuffle=True reorders each client's data every epoch;
# split_data_iid only shuffles once.
train_dataloader = [DataLoader(subset, batch_size=100, shuffle=True) for subset in split_dataset]

test_dataloader = DataLoader(test_dataset, batch_size=100)
validate_dataloader = DataLoader(validate_dataset, batch_size=100)

# C=1: every client participates in every round.
FL.FedAvg(train_dataloader, validate_dataloader, 1)
FL.test_model(test_dataloader)

