# Servers

> Server-side building blocks for federated learning: the generic `BaseServer` and the MIRA-specific `Server_mira`.

In [None]:
#| default_exp servers

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#|export
from copy import deepcopy
import os
import numpy as np
from collections import defaultdict
import torch
from fastcore.utils import *
from peft import *
from fedai.models import *
from fedai.utils import *
from fedai.clients import Client_mira, BaseClient

## BaseServer

The `BaseServer` class can be seen as an abstract base class (ABC), that enables us to define the core structure of the server and operations carried out by it. The server has the following attributes:
- `cfg`: An object that contains the configurations tied to this server (things like learning rate of the clients, optimizer, log directory and so on).
- `model`: The base model used for all other child classes. Gets transmitted to clients at the beginning of every round in traditional FL. In MTL-based or personalization settings, a deep copy of the model is distributed only at the first round, and clients iterate over their respective model accordingly after that.
- `holdout_ds`: A dataset to evaluate the performance at the server. This is optional and can be passed as **None**.
- `lst_data_dict`: A list of dictionaries. Every dictionary belongs to one client and contains two keys `train` and `test`.
- `client_list`: An object of type `LazyList` that holds all clients. Clients are instantiated lazily, which means that a client is only created when it is accessed (via its index). This is the preferred approach when working with large models. It can be indexed like a regular list; once accessed, the client is loaded into memory.



In [None]:
#| export
class BaseServer:
    """Common state shared by every federated-learning server.

    Attributes set by the constructor:

    - `cfg`: configuration object, stored as given.
    - `lst_data_dict`: per-client data dictionaries, stored as given.
    - `model`: the server-side model.
    - `holdout_ds`: server-side evaluation dataset, stored as given.
    - `client_list`: a `LazyList` built from this server and `client_class`.
    - `latest_model_iter`: dictionary, empty at construction; `send` reads it
      as a mapping from client index to the communication round in which that
      client's latest checkpoint was written.
    """

    def __init__(self, cfg, lst_data_dict, model, holdout_ds, client_class):
        self.cfg = cfg
        self.lst_data_dict = lst_data_dict
        self.model = model
        self.holdout_ds = holdout_ds
        self.client_list = LazyList(self, client_class)  # type: ignore # noqa: F405
        self.latest_model_iter = dict()

    def __repr__(self) -> str:
        """Return a short label containing the concrete server class name."""
        return f'''Server: {self.__class__.__name__}'''

    # Special methods are looked up on the class, not the instance, so the
    # alias has to live here for `str()`, `repr()` and direct calls to agree.
    __str__ = __repr__

The following two methods are common among all servers:
- `send`: Send a model from the server to the client.
- `aggregate`: The aggregation function, i.e. what happens when the server receives the updates from the clients.

In [None]:
#| export
@patch
def send(self: BaseServer, client: BaseClient):  # noqa: F811
    """Give `client` a copy of the server model, restoring its own checkpoint if one exists.

    The client always receives a deep copy of `self.model`. If `client.idx` has
    an entry in `self.latest_model_iter` and the corresponding file
    `cfg.output_dir/<round>/local_output_<idx>/pytorch_model.pth` exists, the
    weights stored there are loaded (on CPU) into the copy.

    Returns the model now attached to `client`.
    """
    if client.idx in self.latest_model_iter:
        comm_round = self.latest_model_iter[client.idx]
        model_path = os.path.join(self.cfg.output_dir, str(comm_round),
                                  "local_output_{}".format(client.idx),
                                  "pytorch_model.pth")
    else:
        # No checkpoint recorded for this client; the empty path never exists.
        model_path = ''

    # A deep copy keeps the client from sharing parameters with the server model.
    with torch.no_grad():
        client.model = deepcopy(self.model)

    if os.path.exists(model_path):
        state_dict = torch.load(model_path, map_location='cpu')
        # PeftModel is itself a torch.nn.Module, so it has to be tested first;
        # otherwise the PEFT branch can never be reached.
        if isinstance(client.model, PeftModel):  # noqa: F405
            set_peft_model_state_dict(client.model, state_dict, "default")  # noqa: F405
        elif isinstance(client.model, torch.nn.Module):
            client.model.load_state_dict(state_dict)
    return client.model

- `client_selection`: Client selection is done in this function. The bare minimum is a uniform random selection. Returns a list with one entry per round; every entry holds the indices of the clients selected at that specific round.


In [None]:
#| export
@patch
def client_selection(self: BaseServer):
    """Pick the participating clients for every round, uniformly at random.

    Returns a list with `cfg.rounds` entries. Each entry is a NumPy array of
    `int(cfg.num_clients * cfg.m)` indices drawn without replacement from
    `0 .. cfg.num_clients - 1`.
    """
    cfg = self.cfg
    return [
        np.random.choice(
            a=np.arange(cfg.num_clients),
            size=int(cfg.num_clients * cfg.m),
            replace=False,
        )
        for _round in range(cfg.rounds)
    ]

- `get_selected_client`: Access the selected clients. Takes a list of the clients selected at the current round and returns a generator that yields the respective clients. Since we are initializing clients in a lazy manner, we just need to write a function that returns a `generator`.

In [None]:
#| export
@patch
def get_selected_client(self: BaseServer,
                        client_indices: list) : # a list of current round's selected clients
    """Yield the client object for each index in `client_indices`, one at a time.

    Each client is looked up in `self.client_list` only when the generator
    reaches it, so no client is touched before it is requested.
    """
    yield from (self.client_list[client_idx] for client_idx in client_indices)


## Server MIRA

MIRA extends the `BaseServer`'s capabilities by initializing the model using `get_model`, which instantiates an LLM from the `HuggingFace` library.

When working with large models in the case of MTL (where every client has a unique model), most of the time you cannot hold more than one or two models in memory at the same time, and you must start each client's local training from that client's latest locally trained model. One way to implement such a constraint is to **offload** the model between the *memory* and the *disk*. Although this might be a little slow, it might be the only possibility in certain cases, especially with very large models.



In [None]:
#| export
class Server_mira(BaseServer):
    """Server for MIRA.

    After the base initializer runs, `self.model` is replaced by the model
    returned from `get_model(self.cfg)`; the `model` argument is only passed
    through to `BaseServer.__init__`. Every extra keyword argument becomes an
    instance attribute of the same name. The initializer reads
    `self.tokenizer`, which the base class does not set, so it is expected to
    arrive through `kwargs`.
    """

    def __init__(self, cfg, lst_data_dict, model, holdout_ds, client_class, **kwargs):
        super().__init__(cfg, lst_data_dict, model, holdout_ds, client_class)

        # Build the server model from the configuration.
        self.model = get_model(self.cfg)

        # Expose each extra keyword argument as an attribute on the server.
        for attr_name in kwargs:
            setattr(self, attr_name, kwargs[attr_name])

        # Make the embedding table match the tokenizer's length.
        vocab_size = len(self.tokenizer)
        self.model.resize_token_embeddings(vocab_size)

Several parameters make up `MIRA`. Since it operates on a graph, there is the Laplacian matrix and the regularization parameter (how much weight we give to collaboration versus non-collaboration). The following function, `init_sim_matrix`, is responsible for initializing the values of those parameters.

In [None]:
#| export
@patch
def init_sim_matrix(self: Server_mira):
    N = self.cfg.num_clients
    b = np.random.uniform(0,1,size=(N,N))
    b_symm = (b + b.T)/2
    b_symm[b_symm < 0.25] = 0
    self.alk_connection = b_symm


`send`: Sends a model from the server to the given client. It takes the given client as an input and returns the client's model. Uses `set_peft_model_state_dict` to change the `state_dict` of the client's model to the `state_dict` from either the server (typical FL) or one loaded from the disk (personalized/multi-task learning). Below is a detailed explanation of how the loading happens.

### Offloading models

To implement this offloading, you need to keep track of the paths of the clients' models, i.e. the directory in which the latest model of each client resides. `latest_model_iter` is a dictionary that does exactly this. It contains the keys of the clients that participated in the training process along with the directory of the latest model of those clients. The path of the models is chosen as `self.cfg.output_dir/comm_round/local_output_{client.idx}/pytorch_model.pth` and the `comm_round` is the value in the dictionary `latest_model_iter`. To give an example of this, the following is a dictionary after the second communication round (assuming we sample two clients per round):

```python 
latest_model_iter = {
                        5: 1,
                        3: 2,
                        4: 1,
                        6: 2
                    }

```

This can be interpreted as: in the first round, clients (5, 4) were selected for training and, after they trained, their latest models reside in `output_dir/1/local_output_5/pytorch_model.pth` and `output_dir/1/local_output_4/pytorch_model.pth` respectively. On the other hand, in the second round, clients (3, 6) were selected for training, and after finishing training, their latest models reside in `output_dir/2/local_output_3/pytorch_model.pth` and `output_dir/2/local_output_6/pytorch_model.pth`. Note that there are no keys for client number `1` or `2`, which indicates that they have not appeared in the training process up until the current round.

> NOTE: In the next version releases, we will refactor this to use the approach [here](https://pytorch.org/tutorials/intermediate/autograd_saved_tensors_hooks_tutorial.html). This might be less costly, but we need to test it.

### Aggregation

The aggregation at the server is done using the following formula

$$\boldsymbol{W}_k^{(t+1)} = \boldsymbol{W}_{k, R}^{(t)}-\eta \lambda \sum_{\ell \in \mathcal{N}_k} a_{k \ell}\left( \boldsymbol{W}_{k, R}^{(t)}- \boldsymbol{W}_{\ell, R}^{(t)}\right).$$

In [None]:
#| export
@patch
def aggregate(self: Server_mira, selected_clients_indices, comm_round):
    global_lr = float(self.cfg.lr) * float(self.cfg.local_step)

    for i, client_id in enumerate(selected_clients_indices):
        client_path = os.path.join(self.cfg.output_dir, str(comm_round), f"local_output_{client_id}", "pytorch_model.pth")
        client_state_dict = torch.load(client_path, map_location=self.device)

        client_diff = defaultdict(lambda: torch.tensor(0.0).to(self.device))

        for key in client_state_dict.keys():
            client_diff[key] = torch.zeros_like(client_state_dict[key]).to(self.device)

        for j, other_client_id in enumerate(selected_clients_indices):
            if i != j:
                other_client_path = os.path.join(self.cfg.output_dir, str(comm_round), f"local_output_{other_client_id}", "pytorch_model.pth")
                other_client_state_dict = torch.load(other_client_path, map_location=self.device)

                weight = self.alk_connection[int(client_id)][int(other_client_id)]
                for key in client_state_dict.keys():
                    client_diff[key].data += weight * (client_state_dict[key].data.clone() - other_client_state_dict[key].data.clone())

        for key in client_state_dict:
            client_state_dict[key].data -=  global_lr * self.cfg.lambda_ * client_diff[key].data

        self.update(client_state_dict, comm_round, client_id)


In [None]:
#| export
@patch
def update(self: Server_mira, client_state_dict: dict, comm_round: int, client_id: int) -> None:
    save_dir = os.path.join(self.cfg.output_dir, str(comm_round + 1), f"local_output_{client_id}")
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, "pytorch_model.pth")
    torch.save(client_state_dict, save_path)
    set_peft_model_state_dict(self.model, client_state_dict, "default")  # noqa: F405
    self.model.save_pretrained(save_dir)


### Testing the trained model(s)

`test` is responsible for the final evaluation on different **text-based** metrics like `rouge` and `BLEU`. After all the federated rounds are done, we end up with a model per client (or the initial global model for clients that have not participated in the FL training). This function loops over the clients to evaluate each client and then reports a dictionary of metric values for every client.

In [None]:
#| hide
# Run nbdev's export step for this project.
import nbdev
nbdev.nbdev_export()