# Clients

> The core abstraction for different FL Clients.

In [None]:
#| default_exp federated.agents

In [None]:
#| hide
from nbdev.showdoc import *
from fastcore.test import *

In [None]:
#| export
from fastcore.utils import *
import os
import json
from collections import defaultdict,OrderedDict
from copy import deepcopy
from enum import Enum
import torch
from peft import *
from fedai.trainers import *
from fedai.utils import *
from fedai.data.core import LLMDataCollator
from transformers import AutoTokenizer
from omegaconf.dictconfig import DictConfig


In [None]:
#| export
class AgentRole(Enum):
    """Closed set of roles an agent can be given."""

    SERVER = 1  # server-side agent
    CLIENT = 2  # client-side agent (the default role for `Agent`)
    MARL = 3    # multi-agent reinforcement-learning agent

## Base Agent

An agent is an entity that has a state and exists in an environment. In the case of Federated Learning (FL), the agent's state is defined as its own model, data, criterion, and optimizer. FL focuses on distributed model training across multiple clients (agents), each with its local data. Clients **collaborate** to improve a global or shared model while keeping their data private. Communication is often periodic (e.g., every few training rounds). On the other hand, multi-agent RL (MARL) systems involve multiple agents interacting with an environment to learn policies for specific tasks (e.g., navigation, resource allocation). Each agent also has a state, but the state representation might differ slightly from that of an FL agent. The data is often not preloaded as in FL; rather, it is collected from the environment.

In [None]:
#| export
class Agent:
    """Base agent: an identified entity holding a configuration, a state and a role."""

    def __init__(self,
                 id,
                 cfg,
                 state= None,
                 role= AgentRole.CLIENT):
        # Identity and role.
        self.id, self.role = id, role  # unique agent id; role is either a client or a server
        # Configuration and mutable state.
        self.cfg = cfg  # all the configuration needed by the agent/trainer
        self.state = state  # dictionary containing the state of the agent (None when not given)

In [None]:
#| export
@patch
def init_agent(self: Agent):
    """Initialize the agent's state; the base class has no implementation.

    For an FL agent this means making any adjustments to the
    model/optimizer/state_dict/etc. Subclasses must override it.
    """
    raise NotImplementedError()

In [None]:
#| export
@patch
def communicate(self: Agent, msg):
    """Placeholder for agent communication; always raises `NotImplementedError`.

    NOTE: a later cell in this notebook patches `Agent.communicate` again with
    the signature `(another_agent, comm_round)`, which replaces this stub.
    """
    raise NotImplementedError()

In [None]:
#| export
@patch
def update_state(self: Agent):
    """Placeholder for updating the agent's state; always raises `NotImplementedError`."""
    raise NotImplementedError()

In [None]:
#| export
@patch
def save_state(self: Agent):
    """Save the agent's state (id, model, optimizer, loss_fn) to a file on disk.

    The base class has no implementation and always raises `NotImplementedError`.
    """
    raise NotImplementedError()

In [None]:
#| export
@patch
def clear_model(self: Agent):
    """Drop the agent's reference to its model by setting `self.model` to None."""
    setattr(self, "model", None)

## MARL Agent

In [None]:
class MARLAgent(Agent):
    """Skeleton of a multi-agent RL agent built around sense / decide / act hooks."""

    def _sense(self, state):
        """Sense the environment: store the given `state` as the agent's state."""
        self.state = state

    def _decide(self):
        """Compute the next action(s) based on the current state and observations.

        Placeholder: currently does nothing and returns None.
        """
        return None

    def _act(self):
        """Placeholder: currently does nothing and returns None."""
        return None
    
    

In [None]:
# Render the documentation entry for `MARLAgent` under a level-3 heading.
show_doc(MARLAgent, title_level=3)

---

### MARLAgent

>      MARLAgent (cfg, id, state=None, role='client')

*Initialize self.  See help(type(self)) for accurate signature.*

## FL Agent

In [None]:
#| export
class FLAgent(Agent):
    """A Federated Learning agent that can be used to train a model in a federated learning setting."""

    def __init__(self,
                 id, # the id of the agent
                 cfg, # the configuration of the agent.
                 state= None, # the state of the agent (model, optimizer, loss_fn), etc.
                 role= AgentRole.CLIENT, # the role of the agent (client or server)
                 block= None): # The data block (local data of the FL Agent).

        super().__init__(id, cfg, state, role)

        # A truthy block is read as (train, test); otherwise the dataset attributes stay unset.
        if block:
            train_split, test_split = block[0], block[1]
            self.train_ds = train_split
            self.test_ds = test_split

        # With no (or an empty) state there is nothing to expose and no initialization to run.
        if not self.state:
            return

        # Expose every state entry as an attribute of the agent, then run agent-specific set-up.
        for attr_name, attr_value in self.state.items():
            setattr(self, attr_name, attr_value)
        self.init_agent()

Since data blocks are already on the disk, and since RL agents don't have preloaded data blocks, we don't include the data in the FL agent's state. Another rationale behind this decision is that the state should contain dynamic objects that change over the interaction of the agents, whereas data blocks are static in the case of FL agents, unless you are doing FL-RL agents.

For every client abstraction, whether it is a base client or any other type of federated client, local training is initialized with a set of steps. This might include things like extracting the PEFT model out of the base model (in the case of LLM clients). Likewise, local training is terminated with some steps, like saving the model state dictionary and so on.

We will adjust the string representation of the client abstraction to make it more meaningful.

In [None]:
#| export
@patch
def __str__(self: FLAgent) -> str:
    """Return a multi-line summary: agent class, id, and the class names of model, criterion and optimizer."""
    summary_lines = (
        f'FLAgent: {self.__class__.__name__}',
        f'Index : {self.id}',
        f'Model: {self.model.__class__.__name__}',
        f'Criterion: {self.criterion.__class__.__name__}',
        f'Optimizer: {self.optimizer.__class__.__name__}',
    )
    # Every line after the first is indented by four spaces.
    return '\n    '.join(summary_lines)


In [None]:
#| export
@patch
def init_agent(self: FLAgent):  # noqa: F811
    """Build the agent's optimizer from `cfg.optimizer.name` and `cfg.lr` over the model's parameters."""
    # `get_class` is a project helper; presumably it resolves the named class inside `torch.optim`.
    optimizer_cls = get_class('torch.optim', self.cfg.optimizer.name)  # noqa: F405
    self.optimizer = optimizer_cls(self.model.parameters(), lr=self.cfg.lr)

In [None]:
#| export
@patch
def clear_model(self: FLAgent):
    """Drop the agent's reference to its model by setting `self.model` to None.

    The previous expression `None if hasattr(self, 'model') else None` produced
    None on both branches, so the attribute check had no effect. The attribute
    is therefore set unconditionally, exactly as before.
    """
    self.model = None

In [None]:
#| export
@patch
def save_state(self: FLAgent, state_dict, comm_round):  # noqa: F811
    """Write `state_dict` to `<cfg.save_dir>/<comm_round>/local_output_<id>/pytorch_model.pth`.

    The output directory is created if needed. Afterwards `save_space(self)` is called.
    """
    round_dir = os.path.join(self.cfg.save_dir, str(comm_round))
    output_dir = os.path.join(round_dir, f"local_output_{self.id}")
    os.makedirs(output_dir, exist_ok=True)

    checkpoint_file = os.path.join(output_dir, "pytorch_model.pth")
    torch.save(state_dict, checkpoint_file)

    # Project helper; presumably releases memory held by the agent -- confirm in fedai.utils.
    save_space(self)  # noqa: F405


> To do: implement the communication process in **Protobuf**.

### Communication

Communication refers to the process of downloading models from the server to the client and uploading them from the client to the server. Since we are safeguarding against memory issues, we use sequential client processing and disk checkpointing as our tools.

In [None]:
#| export
@patch
def communicate(self: Agent, another_agent: Agent, comm_round):  # noqa: F811
    """Checkpoint this agent's model for `comm_round` when the agent is a client.

    Agents with any other role do nothing. `another_agent` is currently unused.
    """
    is_client = self.role == AgentRole.CLIENT
    if not is_client:
        return
    self.save_state(self.model.state_dict(), comm_round)

In [None]:
#| export
@patch
def aggregate(self: FLAgent, lst_active_ids, comm_round, len_clients_ds):
    """Load the checkpoints of the agents in `lst_active_ids`, `FedAvg` them and save the result.

    Each agent's checkpoint is read from
    `<cfg.save_dir>/<comm_round>/local_output_<id>/pytorch_model.pth` and weighted by
    `len_clients_ds[i] / sum(len_clients_ds)`, where `i` is the agent's position in
    `lst_active_ids`. The aggregated state dict is then written once through
    `self.save_state`, i.e. under this agent's own id.

    Assumes `len_clients_ds` lists the dataset sizes of exactly the active agents, in the
    same order as `lst_active_ids`, so that the weights sum to 1 -- TODO confirm with callers.

    Fixes over the previous version:
    * the weighted sum is no longer divided by the number of agents a second time
      (the weights are already normalized, so that shrank the average by that factor);
    * an empty `lst_active_ids` raises a clear `ValueError` instead of an `UnboundLocalError`;
    * the aggregate is saved once rather than once per active agent to the same file.

    Raises:
        ValueError: if `lst_active_ids` is empty.
    """
    if not lst_active_ids:
        raise ValueError("aggregate() requires at least one active agent id")

    # Loop-invariant normalizer for the per-agent weights.
    total_samples = sum(len_clients_ds)

    averaged = None
    with torch.no_grad():  # pure tensor arithmetic; never track gradients here
        for position, agent_id in enumerate(lst_active_ids):
            checkpoint_file = os.path.join(self.cfg.save_dir,
                                           str(comm_round),
                                           f"local_output_{agent_id}",
                                           "pytorch_model.pth")
            # NOTE(review): torch.load unpickles the file; only load checkpoints written by this pipeline.
            client_state_dict = torch.load(checkpoint_file, map_location='cpu')

            if averaged is None:
                # The first checkpoint fixes the keys and tensor shapes of the accumulator.
                averaged = {
                    key: torch.zeros_like(value)
                    for key, value in client_state_dict.items()
                }

            weight = len_clients_ds[position] / total_samples
            for key, value in client_state_dict.items():
                averaged[key] += weight * value

    self.save_state(averaged, comm_round)
    

In [None]:
from collections import defaultdict
import torch
# Scratch experiment: a defaultdict whose factory builds a scalar zero tensor moved to the CPU.
client_avg = defaultdict(lambda: torch.tensor(0.0).to('cpu'))

In [None]:
# Store an explicit tensor under the key '1'.
client_avg['1'] = torch.tensor(1.0)

In [None]:
# Display the mapping (the cell output shows only the explicitly assigned key).
client_avg

defaultdict(<function __main__.<lambda>()>, {'1': tensor(1.)})

## PEFT Agent

In [None]:
#| export
class PeftAgent(FLAgent):
    """FL agent for models with PEFT adapters.

    `role` may be an `AgentRole` member or its name as a string (e.g. the default
    "client"); strings are converted to the enum so that comparisons against
    `AgentRole` members work. Extra `**adapter_params` are accepted but currently unused.

    Raises:
        ValueError: if `role` is a string that names no `AgentRole` member.
    """

    def __init__(self,
                 cfg,
                 block,
                 id,
                 state= None,
                 role= "client",
                 **adapter_params):
        # A plain string never compares equal to an enum member, so normalize it first.
        if isinstance(role, str):
            try:
                role = AgentRole[role.upper()]
            except KeyError:
                raise ValueError(f"unknown agent role: {role!r}") from None

        # FLAgent.__init__ is (id, cfg, state, role, block). Pass by keyword: this class
        # takes its arguments in a different order, and the previous positional call
        # bound every value to the wrong parameter.
        super().__init__(id=id, cfg=cfg, state=state, role=role, block=block)


In [None]:
#| export
@patch
def peftify(self: PeftAgent):
    """Snapshot the adapter parameters and make `model.state_dict()` return only the adapter state.

    Parameters whose name contains "default" are detached and deep-copied into
    `params_dict_old`; `params_dict_new` is a further deep copy. The model's
    `state_dict` is then replaced by a bound function that returns
    `get_peft_model_state_dict(model, self.params_dict_new, "default")`.
    """
    adapter_params = OrderedDict()
    for param_name, param in self.model.named_parameters():
        if "default" in param_name:
            adapter_params[param_name] = param.detach()

    self.params_dict_old = deepcopy(adapter_params)
    self.params_dict_new = deepcopy(self.params_dict_old)

    def _adapter_state_dict(instance, *_, **__):
        # `self` is the agent (captured from the enclosing call); `instance` is the model.
        return get_peft_model_state_dict(  # noqa: F405
            instance, self.params_dict_new, "default"
        )

    # Bind the function to the model so it is called like a regular method.
    self.model.state_dict = _adapter_state_dict.__get__(self.model, type(self.model))

In [None]:
#| export
@patch
def init_agent(self: PeftAgent):  # noqa: F811
    """Prepare the adapter bookkeeping (`peftify`) and build the optimizer.

    The optimizer is stored on `self.optimizer`, the attribute that
    `FLAgent.init_agent` sets and `FLAgent.__str__` reads. Previously it was
    written only to `self.state['optimizer']`, leaving `self.optimizer` unset
    or stale. The state entry is still updated when a state dict exists, and a
    missing state no longer raises a `TypeError`.
    """
    self.peftify()

    # `get_class` is a project helper; presumably it resolves the named class inside `torch.optim`.
    optimizer_cls = get_class('torch.optim', self.cfg.optimizer.name)  # noqa: F405
    self.optimizer = optimizer_cls(self.model.parameters(), lr=self.cfg.lr)

    if self.state is not None:
        self.state['optimizer'] = self.optimizer

In [None]:
#| export
@patch
def save_state_(self: PeftAgent, epoch, local_dataset_len_dict, previously_selected_clients_set):  # noqa: F811
    """Save the new adapter weights for `epoch`, then restore the old adapter weights.

    Steps:
    1. write `self.model.state_dict()` to disk through `save_state`;
    2. record the size of the local training set in `local_dataset_len_dict[self.id]`
       (the dict is modified in place);
    3. load the adapter weights captured in `params_dict_old` back into the model;
    4. add this agent's id to the set of previously selected clients (a new set is
       built; the argument itself is not mutated).

    Fixes over the previous version: `save_state` is now given the state dict it
    requires (it was called with `epoch` only), and the dataset size is taken from
    `self.train_ds`, because `FLAgent` never sets a `block` attribute.

    Returns:
        `(model, local_dataset_len_dict, previously_selected_clients_set, last_client_id)`
    """
    # NOTE(review): `save_state` also calls `save_space(self)`; if that helper releases
    # `self.model`, the restore below would fail -- confirm in fedai.utils.
    self.save_state(self.model.state_dict(), epoch)

    local_dataset_len_dict[self.id] = len(self.train_ds)

    # Put the pre-training adapter weights back into the model.
    older_adapter_weight = get_peft_model_state_dict(self.model, self.params_dict_old, "default")  # noqa: F405
    set_peft_model_state_dict(self.model, older_adapter_weight, "default")  # noqa: F405

    previously_selected_clients_set = previously_selected_clients_set | {self.id}
    last_client_id = self.id

    return self.model, local_dataset_len_dict, previously_selected_clients_set, last_client_id

In [None]:
#| export
@patch
def strategy(self: PeftAgent):
    """Aggregation strategy for the agent when it is a server.

    Placeholder: not implemented yet; does nothing and returns None.
    """
    return None

## MIRA Agent

Mira clients have more parameters. Since it is in principle a client for LLMs, we need to feed it the generation dataset (the dataset of text ids at the end layer, not the logits), as well as a tokenizer and a collate function that will be used for generation and for constructing the data loader.

In [None]:
#| export
class AgentMira(FLAgent):
    """FL agent for LLM clients; also carries generation datasets, a tokenizer and a collate function.

    The previous constructor forwarded `(data_dict, model, criterion, optimizer, id)`
    positionally to `FLAgent.__init__`, whose parameters are `(id, cfg, state, role, block)`.
    Every value was therefore bound to the wrong parameter. The base class is now
    initialized by keyword and the remaining attributes are set explicitly.
    """

    def __init__(self,
                 data_dict: dict,
                 model: torch.nn.Module,
                 criterion,
                 optimizer: torch.optim.Optimizer,
                 id: int,
                 gen_data_dict: dict,
                 tokenizer: AutoTokenizer,
                 collat_fn: LLMDataCollator,
                 cfg: DictConfig) -> None:

        # No state is passed to the base class on purpose: a non-empty state makes
        # FLAgent.__init__ call `init_agent()`, which would replace the optimizer
        # supplied by the caller with one built from `cfg`.
        super().__init__(id=id, cfg=cfg)

        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.state = {'model': model, 'criterion': criterion, 'optimizer': optimizer}

        # Assumes `data_dict` has the same 'train'/'test' layout as `gen_data_dict` -- TODO confirm.
        self.train_ds = data_dict['train']
        self.test_ds = data_dict['test']

        self.train_ds_genr = gen_data_dict['train']
        self.test_ds_genr = gen_data_dict['test']
        self.tokenizer = tokenizer
        self.collat_fn = collat_fn

In order for us to save space, we will replace the original model with only the trainable peft model parameters. 

### Testing Mira Client

We will do the following:
- Define a Mira client.
- inspect the `init_local_train` and `terminate_local_train` methods and their effect on the model's parameters.

In [None]:
# #| hide
# from transformers import AutoModelForCausalLM
# gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")
# base_model = deepcopy(gpt2)

In [None]:
# #| hide
# config = LoraConfig(
#     r=8,# arbitrary numbr but usually 8, 16, 32, 64, 128
#     target_modules=['c_attn'],
#     lora_alpha=8,
#     lora_dropout=0.05,
#     bias="none",
#     task_type="CAUSAL_LM",
#     )

# peft_model = get_peft_model(gpt2, config)
# mira  = AgentMira(DataDict, peft_model, criterion, optimizer, 0, train_dataset, test_dataset, None, None, None, None)



Let us inspect the model architecture:

In [None]:
# #| hide
# base_model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

Now try to observe the difference of architecture that we get from peft_model vs base_model.

In [None]:
# #| hide
# mira.init_local_train('')

In [None]:
# #| hide
# mira.model

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): GPT2LMHeadModel(
      (transformer): GPT2Model(
        (wte): Embedding(50257, 768)
        (wpe): Embedding(1024, 768)
        (drop): Dropout(p=0.1, inplace=False)
        (h): ModuleList(
          (0-11): 12 x GPT2Block(
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (attn): GPT2SdpaAttention(
              (c_attn): lora.Linear(
                (base_layer): Conv1D(nf=2304, nx=768)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=768, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=2304, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
      

If you print the number of keys in the state dictionaries of the two models, you will find that the LoRA model has fewer keys. In fact, those are the only trainable parameters that we have here.

In [None]:
# #| hide
# len(mira.model.state_dict()), len(base_model.state_dict())

(24, 149)

The keys of the PeftModel are as follows:

In [None]:
# #| hide
# from IPython.display import display, Markdown

# keys_list = "\n".join(f"- {key}" for key in mira.model.state_dict().keys())
# display(Markdown(keys_list))

- base_model.model.transformer.h.0.attn.c_attn.lora_A.weight
- base_model.model.transformer.h.0.attn.c_attn.lora_B.weight
- base_model.model.transformer.h.1.attn.c_attn.lora_A.weight
- base_model.model.transformer.h.1.attn.c_attn.lora_B.weight
- base_model.model.transformer.h.2.attn.c_attn.lora_A.weight
- base_model.model.transformer.h.2.attn.c_attn.lora_B.weight
- base_model.model.transformer.h.3.attn.c_attn.lora_A.weight
- base_model.model.transformer.h.3.attn.c_attn.lora_B.weight
- base_model.model.transformer.h.4.attn.c_attn.lora_A.weight
- base_model.model.transformer.h.4.attn.c_attn.lora_B.weight
- base_model.model.transformer.h.5.attn.c_attn.lora_A.weight
- base_model.model.transformer.h.5.attn.c_attn.lora_B.weight
- base_model.model.transformer.h.6.attn.c_attn.lora_A.weight
- base_model.model.transformer.h.6.attn.c_attn.lora_B.weight
- base_model.model.transformer.h.7.attn.c_attn.lora_A.weight
- base_model.model.transformer.h.7.attn.c_attn.lora_B.weight
- base_model.model.transformer.h.8.attn.c_attn.lora_A.weight
- base_model.model.transformer.h.8.attn.c_attn.lora_B.weight
- base_model.model.transformer.h.9.attn.c_attn.lora_A.weight
- base_model.model.transformer.h.9.attn.c_attn.lora_B.weight
- base_model.model.transformer.h.10.attn.c_attn.lora_A.weight
- base_model.model.transformer.h.10.attn.c_attn.lora_B.weight
- base_model.model.transformer.h.11.attn.c_attn.lora_A.weight
- base_model.model.transformer.h.11.attn.c_attn.lora_B.weight

In [None]:
# #| hide
# from IPython.display import display, Markdown

# keys_list = "\n".join(f"- {key}" for key in base_model.state_dict().keys())
# display(Markdown(keys_list))

- transformer.wte.weight
- transformer.wpe.weight
- transformer.h.0.ln_1.weight
- transformer.h.0.ln_1.bias
- transformer.h.0.attn.c_attn.weight
- transformer.h.0.attn.c_attn.bias
- transformer.h.0.attn.c_proj.weight
- transformer.h.0.attn.c_proj.bias
- transformer.h.0.ln_2.weight
- transformer.h.0.ln_2.bias
- transformer.h.0.mlp.c_fc.weight
- transformer.h.0.mlp.c_fc.bias
- transformer.h.0.mlp.c_proj.weight
- transformer.h.0.mlp.c_proj.bias
- transformer.h.1.ln_1.weight
- transformer.h.1.ln_1.bias
- transformer.h.1.attn.c_attn.weight
- transformer.h.1.attn.c_attn.bias
- transformer.h.1.attn.c_proj.weight
- transformer.h.1.attn.c_proj.bias
- transformer.h.1.ln_2.weight
- transformer.h.1.ln_2.bias
- transformer.h.1.mlp.c_fc.weight
- transformer.h.1.mlp.c_fc.bias
- transformer.h.1.mlp.c_proj.weight
- transformer.h.1.mlp.c_proj.bias
- transformer.h.2.ln_1.weight
- transformer.h.2.ln_1.bias
- transformer.h.2.attn.c_attn.weight
- transformer.h.2.attn.c_attn.bias
- transformer.h.2.attn.c_proj.weight
- transformer.h.2.attn.c_proj.bias
- transformer.h.2.ln_2.weight
- transformer.h.2.ln_2.bias
- transformer.h.2.mlp.c_fc.weight
- transformer.h.2.mlp.c_fc.bias
- transformer.h.2.mlp.c_proj.weight
- transformer.h.2.mlp.c_proj.bias
- transformer.h.3.ln_1.weight
- transformer.h.3.ln_1.bias
- transformer.h.3.attn.c_attn.weight
- transformer.h.3.attn.c_attn.bias
- transformer.h.3.attn.c_proj.weight
- transformer.h.3.attn.c_proj.bias
- transformer.h.3.ln_2.weight
- transformer.h.3.ln_2.bias
- transformer.h.3.mlp.c_fc.weight
- transformer.h.3.mlp.c_fc.bias
- transformer.h.3.mlp.c_proj.weight
- transformer.h.3.mlp.c_proj.bias
- transformer.h.4.ln_1.weight
- transformer.h.4.ln_1.bias
- transformer.h.4.attn.c_attn.weight
- transformer.h.4.attn.c_attn.bias
- transformer.h.4.attn.c_proj.weight
- transformer.h.4.attn.c_proj.bias
- transformer.h.4.ln_2.weight
- transformer.h.4.ln_2.bias
- transformer.h.4.mlp.c_fc.weight
- transformer.h.4.mlp.c_fc.bias
- transformer.h.4.mlp.c_proj.weight
- transformer.h.4.mlp.c_proj.bias
- transformer.h.5.ln_1.weight
- transformer.h.5.ln_1.bias
- transformer.h.5.attn.c_attn.weight
- transformer.h.5.attn.c_attn.bias
- transformer.h.5.attn.c_proj.weight
- transformer.h.5.attn.c_proj.bias
- transformer.h.5.ln_2.weight
- transformer.h.5.ln_2.bias
- transformer.h.5.mlp.c_fc.weight
- transformer.h.5.mlp.c_fc.bias
- transformer.h.5.mlp.c_proj.weight
- transformer.h.5.mlp.c_proj.bias
- transformer.h.6.ln_1.weight
- transformer.h.6.ln_1.bias
- transformer.h.6.attn.c_attn.weight
- transformer.h.6.attn.c_attn.bias
- transformer.h.6.attn.c_proj.weight
- transformer.h.6.attn.c_proj.bias
- transformer.h.6.ln_2.weight
- transformer.h.6.ln_2.bias
- transformer.h.6.mlp.c_fc.weight
- transformer.h.6.mlp.c_fc.bias
- transformer.h.6.mlp.c_proj.weight
- transformer.h.6.mlp.c_proj.bias
- transformer.h.7.ln_1.weight
- transformer.h.7.ln_1.bias
- transformer.h.7.attn.c_attn.weight
- transformer.h.7.attn.c_attn.bias
- transformer.h.7.attn.c_proj.weight
- transformer.h.7.attn.c_proj.bias
- transformer.h.7.ln_2.weight
- transformer.h.7.ln_2.bias
- transformer.h.7.mlp.c_fc.weight
- transformer.h.7.mlp.c_fc.bias
- transformer.h.7.mlp.c_proj.weight
- transformer.h.7.mlp.c_proj.bias
- transformer.h.8.ln_1.weight
- transformer.h.8.ln_1.bias
- transformer.h.8.attn.c_attn.weight
- transformer.h.8.attn.c_attn.bias
- transformer.h.8.attn.c_proj.weight
- transformer.h.8.attn.c_proj.bias
- transformer.h.8.ln_2.weight
- transformer.h.8.ln_2.bias
- transformer.h.8.mlp.c_fc.weight
- transformer.h.8.mlp.c_fc.bias
- transformer.h.8.mlp.c_proj.weight
- transformer.h.8.mlp.c_proj.bias
- transformer.h.9.ln_1.weight
- transformer.h.9.ln_1.bias
- transformer.h.9.attn.c_attn.weight
- transformer.h.9.attn.c_attn.bias
- transformer.h.9.attn.c_proj.weight
- transformer.h.9.attn.c_proj.bias
- transformer.h.9.ln_2.weight
- transformer.h.9.ln_2.bias
- transformer.h.9.mlp.c_fc.weight
- transformer.h.9.mlp.c_fc.bias
- transformer.h.9.mlp.c_proj.weight
- transformer.h.9.mlp.c_proj.bias
- transformer.h.10.ln_1.weight
- transformer.h.10.ln_1.bias
- transformer.h.10.attn.c_attn.weight
- transformer.h.10.attn.c_attn.bias
- transformer.h.10.attn.c_proj.weight
- transformer.h.10.attn.c_proj.bias
- transformer.h.10.ln_2.weight
- transformer.h.10.ln_2.bias
- transformer.h.10.mlp.c_fc.weight
- transformer.h.10.mlp.c_fc.bias
- transformer.h.10.mlp.c_proj.weight
- transformer.h.10.mlp.c_proj.bias
- transformer.h.11.ln_1.weight
- transformer.h.11.ln_1.bias
- transformer.h.11.attn.c_attn.weight
- transformer.h.11.attn.c_attn.bias
- transformer.h.11.attn.c_proj.weight
- transformer.h.11.attn.c_proj.bias
- transformer.h.11.ln_2.weight
- transformer.h.11.ln_2.bias
- transformer.h.11.mlp.c_fc.weight
- transformer.h.11.mlp.c_fc.bias
- transformer.h.11.mlp.c_proj.weight
- transformer.h.11.mlp.c_proj.bias
- transformer.ln_f.weight
- transformer.ln_f.bias
- lm_head.weight

In [None]:
#| hide
# Run nbdev's export step for this project.
import nbdev; nbdev.nbdev_export()