# Clients

> The core abstraction for different FL Clients.

In [None]:
#| default_exp clients

In [None]:
#| hide
from nbdev.showdoc import *
from fastcore.test import *

In [None]:
#| export
from fastcore.utils import *
import os
import torch
from collections import OrderedDict
from copy import deepcopy
from peft import *
from fedai.trainers import *
from fedai.utils import get_class
from fedai.data import LLMDataCollator
from transformers import AutoTokenizer
from omegaconf.dictconfig import DictConfig


## Base Client

In [None]:
#| export
import torch

class BaseClient:
    '''A base FL client.

    Args:
        data_dict: A dictionary that contains the train and test data sets. keys: (train, test).
            Every entry is also set as an attribute of the client, so e.g. `client.train`
            and `client.test` are available alongside `train_ds` / `test_ds`.
        model: the client's local model.
        criterion: the loss function.
        optimizer: the optimizer for `model`.
        idx: the client's index.
    '''
    def __init__(self,
                 data_dict: dict,
                 model: torch.nn.Module,
                 criterion,
                 optimizer: torch.optim.Optimizer,
                 idx: int) -> None:

        self.train_ds = data_dict['train']
        self.test_ds = data_dict['test']
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.idx = idx

        # expose every data set in data_dict as an attribute (client.train, client.test, ...)
        for key, value in data_dict.items():
            setattr(self, key, value)


We patch `__str__` to give the client abstraction a more meaningful string representation.

In [None]:
@patch  # fastcore: attach this method to an existing class
def __str__(self: BaseClient) -> str:
    """Human-readable summary: the client's class and index, plus the class names of its model, criterion and optimizer."""
    details = (
        ('Index : ', self.idx),
        ('Model: ', self.model.__class__.__name__),
        ('Criterion: ', self.criterion.__class__.__name__),
        ('Optimizer: ', self.optimizer.__class__.__name__),
    )
    header = f'Client: {self.__class__.__name__}'
    # every detail line is indented by four spaces under the header
    return '\n'.join([header, *(f'    {label}{value}' for label, value in details)])


Every client abstraction, whether it is a base client or any other type of federated client, initializes local training with a set of steps. These might include things like extracting the PEFT model out of the base model (in the case of LLM clients). Likewise, it terminates local training with some steps, such as saving the model state dictionary.

In [None]:
@patch
def init_local_train(self: BaseClient, *args, **kwargs):
    """Prepare the client for a round of local training.

    Abstract: subclasses must override this. Arbitrary arguments are accepted so that
    a call using a subclass signature (e.g. an output directory) still reaches the
    `NotImplementedError` instead of failing with a `TypeError`.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(f"{type(self).__name__} does not implement init_local_train")

@patch
def terminate_local_train(self: BaseClient, *args, **kwargs):
    """Finish a round of local training (e.g. persist the trained weights).

    Abstract: subclasses must override this. Arbitrary arguments are accepted for the
    same reason as in `init_local_train`.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(f"{type(self).__name__} does not implement terminate_local_train")

@patch
def clear_model(self: BaseClient):
    """Drop the client's reference to its model by setting `self.model` to None."""
    self.model = None

In [None]:
# Render BaseClient's signature and docstring as documentation (nbdev showdoc)
show_doc(BaseClient)

---

### BaseClient

>      BaseClient (DataDict:dict, model:torch.nn.modules.module.Module,
>                  criterion:int, optimizer:torch.optim.optimizer.Optimizer,
>                  idx:int)

*A base FL client.

DataDict: A dictionary that contains the train and test data sets. keys: (train, test)*

### Testing the BaseClient Functionalities

In [None]:
#| hide
import torch
from torch.utils.data import Dataset, random_split

class RandomTwoCaseDataset(Dataset):
    """Synthetic two-class dataset of random feature vectors.

    Each sample is an ``input_size`` vector drawn from a standard normal. With
    probability ``case_prob`` it gets label 1 and is shifted by +2.0; otherwise
    it gets label 0 and is scaled by 1.5. ``transform``, if truthy, is applied
    to the sample (not the label) on access.
    """

    def __init__(self, num_samples=1000, input_size=10, case_prob=0.5, transform=None):
        self.num_samples, self.input_size = num_samples, input_size
        self.case_prob, self.transform = case_prob, transform
        self.data, self.labels = self._generate_data()

    def _generate_data(self):
        """Draw features and 0/1 labels for the two cases."""
        features = torch.randn(self.num_samples, self.input_size)
        is_case_1 = torch.rand(self.num_samples) < self.case_prob
        labels = is_case_1.long()
        # case-1 rows are shifted by 2.0, case-0 rows are scaled by 1.5
        features = torch.where(is_case_1.unsqueeze(1), features + 2.0, features * 1.5)
        return features, labels

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        sample = self.data[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample, self.labels[idx]


In [None]:
#| hide
# Build one synthetic two-class dataset and split it 80/20 into train/test
full_dataset = RandomTwoCaseDataset(num_samples=2000, input_size=5, case_prob=0.5)

n_samples = len(full_dataset)
train_size = int(0.8 * n_samples)     # 80% of the samples go to training
test_size = n_samples - train_size    # the rest (20%) go to testing
train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])

# number of features in one sample (element 0 of the (sample, label) pair)
item_shape = train_dataset[0][0].shape[0]
item_shape

5

In [None]:
#| hide
from einops import repeat
first_sample = train_dataset[0][0]
test_item = repeat(first_sample, 'l -> b l', b=1)  # prepend a batch dimension of size 1

In [None]:
#| hide
from torch import nn
class SimpleModel(nn.Module):
    """Small MLP classifier used to exercise the client abstractions.

    Args:
        in_features: size of each input vector (5 matches the `input_size=5` test dataset).
        hidden: width of the two hidden layers.
        n_classes: number of output logits. Defaults to 2 so the output can be fed to
            `nn.CrossEntropyLoss` together with the 0/1 labels of `RandomTwoCaseDataset`.
    """

    def __init__(self, in_features=5, hidden=512, n_classes=2):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            # one logit per class: CrossEntropyLoss requires n_classes > the largest label
            nn.Linear(hidden, n_classes),
        )

    def forward(self, x):
        x = self.flatten(x)
        return self.linear_relu_stack(x)



In [None]:
#| hide
model = SimpleModel()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# data sets handed to the client; keys follow BaseClient's (train, test) convention
DataDict = dict(train=train_dataset, test=test_dataset)

In [None]:
#| hide
dummy_client = BaseClient(DataDict, model, criterion, optimizer, idx=0)
print(dummy_client)  # formatted by the patched __str__

Client: BaseClient
    Index : 0
    Model: SimpleModel
    Criterion: CrossEntropyLoss
    Optimizer: Adam


## MIRA Client

MIRA clients take more parameters. Since such a client is, in principle, meant for LLMs, we need to feed it the generation dataset (the dataset of text token ids at the output layer, not the logits). We also pass a tokenizer and a collate function, which are used for generation and for constructing the data loaders, as well as the run configuration.

In [None]:
#| export
class Client_mira(BaseClient):
    '''FL client for PEFT-adapted LLMs (MIRA).

    Besides everything `BaseClient` stores, it keeps the generation data sets,
    the tokenizer, the collate function and the run configuration.

    Args:
        gen_data_dict: dict with 'train' and 'test' generation data sets.
        tokenizer: tokenizer for the LLM.
        collat_fn: collate function used when building data loaders.
        cfg: run configuration (`init_local_train` reads `cfg.optimizer` and `cfg.lr`).
    '''
    def __init__(self, data_dict: dict, model: torch.nn.Module, criterion,
                 optimizer: torch.optim.Optimizer, idx: int, gen_data_dict: dict,
                 tokenizer: AutoTokenizer, collat_fn: LLMDataCollator,
                 cfg: DictConfig) -> None:
        super().__init__(data_dict, model, criterion, optimizer, idx)
        # generation data sets are kept next to the ones BaseClient stored
        self.train_ds_genr, self.test_ds_genr = gen_data_dict['train'], gen_data_dict['test']
        self.tokenizer, self.collat_fn, self.cfg = tokenizer, collat_fn, cfg

To save space, we override the model's `state_dict` so that it returns only the trainable PEFT (adapter) parameters instead of the full model's weights.

In [None]:
#| export
@patch
def init_local_train(self: Client_mira, out_dir):
    """Prepare this client for a local training round.

    - remembers `out_dir` (used later by `terminate_local_train`);
    - snapshots the parameters whose names contain "default" into
      `params_dict_old` (deep copy) and keeps detached tensors sharing storage
      with the live parameters in `params_dict_new`;
    - rebinds `self.model.state_dict` so it returns only those adapter weights;
    - builds a fresh optimizer from `cfg.optimizer` and `cfg.lr`.
    """
    self.output_dir = out_dir

    adapter_params = OrderedDict(
        (name, param.detach())
        for name, param in self.model.named_parameters()
        if "default" in name
    )
    self.params_dict_old = deepcopy(adapter_params)
    self.params_dict_new = adapter_params

    def _adapter_state_dict(instance, *_, **__):
        # ignores the usual state_dict() arguments; always returns adapter-only weights
        return get_peft_model_state_dict(instance, self.params_dict_new, "default")

    self.model.state_dict = _adapter_state_dict.__get__(self.model, type(self.model))

    optimizer_cls = get_class('torch.optim', self.cfg.optimizer)
    self.optimizer = optimizer_cls(self.model.parameters(), lr=self.cfg.lr)

In [None]:
#| export
@patch
def train(self: Client_mira, n_epochs):
    """Run `n_epochs` of local training through a `Trainer` and return its history."""
    self.model.train()  # switch the model to training mode first
    return Trainer(self).fit(n_epochs)

In [None]:
#| export
@patch
def clear_model(self: Client_mira):
    """Drop this client's reference to its model (`self.model` becomes None).

    NOTE(review): `self.optimizer` (built in `init_local_train` from
    `self.model.parameters()`) still references the parameters, so this alone
    may not free their memory — confirm whether that is intended.
    """
    setattr(self, 'model', None)

In [None]:
#| export
@patch
def terminate_local_train(self: Client_mira, epoch, local_dataset_len_dict, previously_selected_clients_set):
    """Finish a local training round for this client.

    Saves the trained adapter weights to
    `<output_dir>/<epoch>/local_output_<idx>/pytorch_model.bin`, then restores the
    model's "default" adapter to the snapshot taken in `init_local_train`.

    Args:
        epoch: current round; used as a sub-directory name.
        local_dataset_len_dict: mapping client idx -> train set size; updated in place.
        previously_selected_clients_set: indices of clients seen so far (not mutated).

    Returns:
        (model, local_dataset_len_dict, updated selected-clients set, this client's idx)
    """
    local_dataset_len_dict[self.idx] = len(self.train_ds)

    # state_dict() was rebound in init_local_train to return only the adapter weights
    new_adapter_weight = self.model.state_dict()
    single_output_dir = os.path.join(self.output_dir, str(epoch), f"local_output_{self.idx}")
    os.makedirs(single_output_dir, exist_ok=True)
    torch.save(new_adapter_weight, os.path.join(single_output_dir, "pytorch_model.bin"))

    # roll the adapter back to the pre-training snapshot
    # (presumably so the next client starts from the same weights — confirm)
    older_adapter_weight = get_peft_model_state_dict(self.model, self.params_dict_old, "default")
    set_peft_model_state_dict(self.model, older_adapter_weight, "default")

    # build a new set instead of mutating the caller's
    previously_selected_clients_set = previously_selected_clients_set | {self.idx}
    last_client_id = self.idx

    return self.model, local_dataset_len_dict, previously_selected_clients_set, last_client_id

### Testing Mira Client

We will do the following:
- Define a Mira client.
- Inspect the `init_local_train` and `terminate_local_train` methods and their effect on the model's parameters.

In [None]:
# #| hide
# from transformers import AutoModelForCausalLM
# gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")
# base_model = deepcopy(gpt2)

In [None]:
# #| hide
# config = LoraConfig(
#     r=8,# arbitrary numbr but usually 8, 16, 32, 64, 128
#     target_modules=['c_attn'],
#     lora_alpha=8,
#     lora_dropout=0.05,
#     bias="none",
#     task_type="CAUSAL_LM",
#     )

# peft_model = get_peft_model(gpt2, config)
# mira  = Client_mira(DataDict, peft_model, criterion, optimizer, 0, train_dataset, test_dataset, None, None, None, None)



Let us inspect the base model architecture:

In [None]:
# #| hide
# base_model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

Now observe how the architecture of `peft_model` differs from that of `base_model`.

In [None]:
# #| hide
# mira.init_local_train('')

In [None]:
# #| hide
# mira.model

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): GPT2LMHeadModel(
      (transformer): GPT2Model(
        (wte): Embedding(50257, 768)
        (wpe): Embedding(1024, 768)
        (drop): Dropout(p=0.1, inplace=False)
        (h): ModuleList(
          (0-11): 12 x GPT2Block(
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (attn): GPT2SdpaAttention(
              (c_attn): lora.Linear(
                (base_layer): Conv1D(nf=2304, nx=768)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=768, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=2304, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
      

If you print the number of keys in the state dictionaries of the two models, you will find that the LoRA model has far fewer keys. This is because `init_local_train` overrode its `state_dict` to return only the adapter weights, which are the only trainable parameters we have here.

In [None]:
# #| hide
# len(mira.model.state_dict()), len(base_model.state_dict())

(24, 149)

The keys of the PeftModel are as follows:

In [None]:
# #| hide
# from IPython.display import display, Markdown

# keys_list = "\n".join(f"- {key}" for key in mira.model.state_dict().keys())
# display(Markdown(keys_list))

- base_model.model.transformer.h.0.attn.c_attn.lora_A.weight
- base_model.model.transformer.h.0.attn.c_attn.lora_B.weight
- base_model.model.transformer.h.1.attn.c_attn.lora_A.weight
- base_model.model.transformer.h.1.attn.c_attn.lora_B.weight
- base_model.model.transformer.h.2.attn.c_attn.lora_A.weight
- base_model.model.transformer.h.2.attn.c_attn.lora_B.weight
- base_model.model.transformer.h.3.attn.c_attn.lora_A.weight
- base_model.model.transformer.h.3.attn.c_attn.lora_B.weight
- base_model.model.transformer.h.4.attn.c_attn.lora_A.weight
- base_model.model.transformer.h.4.attn.c_attn.lora_B.weight
- base_model.model.transformer.h.5.attn.c_attn.lora_A.weight
- base_model.model.transformer.h.5.attn.c_attn.lora_B.weight
- base_model.model.transformer.h.6.attn.c_attn.lora_A.weight
- base_model.model.transformer.h.6.attn.c_attn.lora_B.weight
- base_model.model.transformer.h.7.attn.c_attn.lora_A.weight
- base_model.model.transformer.h.7.attn.c_attn.lora_B.weight
- base_model.model.transformer.h.8.attn.c_attn.lora_A.weight
- base_model.model.transformer.h.8.attn.c_attn.lora_B.weight
- base_model.model.transformer.h.9.attn.c_attn.lora_A.weight
- base_model.model.transformer.h.9.attn.c_attn.lora_B.weight
- base_model.model.transformer.h.10.attn.c_attn.lora_A.weight
- base_model.model.transformer.h.10.attn.c_attn.lora_B.weight
- base_model.model.transformer.h.11.attn.c_attn.lora_A.weight
- base_model.model.transformer.h.11.attn.c_attn.lora_B.weight

In [None]:
# #| hide
# from IPython.display import display, Markdown

# keys_list = "\n".join(f"- {key}" for key in base_model.state_dict().keys())
# display(Markdown(keys_list))

- transformer.wte.weight
- transformer.wpe.weight
- transformer.h.0.ln_1.weight
- transformer.h.0.ln_1.bias
- transformer.h.0.attn.c_attn.weight
- transformer.h.0.attn.c_attn.bias
- transformer.h.0.attn.c_proj.weight
- transformer.h.0.attn.c_proj.bias
- transformer.h.0.ln_2.weight
- transformer.h.0.ln_2.bias
- transformer.h.0.mlp.c_fc.weight
- transformer.h.0.mlp.c_fc.bias
- transformer.h.0.mlp.c_proj.weight
- transformer.h.0.mlp.c_proj.bias
- transformer.h.1.ln_1.weight
- transformer.h.1.ln_1.bias
- transformer.h.1.attn.c_attn.weight
- transformer.h.1.attn.c_attn.bias
- transformer.h.1.attn.c_proj.weight
- transformer.h.1.attn.c_proj.bias
- transformer.h.1.ln_2.weight
- transformer.h.1.ln_2.bias
- transformer.h.1.mlp.c_fc.weight
- transformer.h.1.mlp.c_fc.bias
- transformer.h.1.mlp.c_proj.weight
- transformer.h.1.mlp.c_proj.bias
- transformer.h.2.ln_1.weight
- transformer.h.2.ln_1.bias
- transformer.h.2.attn.c_attn.weight
- transformer.h.2.attn.c_attn.bias
- transformer.h.2.attn.c_proj.weight
- transformer.h.2.attn.c_proj.bias
- transformer.h.2.ln_2.weight
- transformer.h.2.ln_2.bias
- transformer.h.2.mlp.c_fc.weight
- transformer.h.2.mlp.c_fc.bias
- transformer.h.2.mlp.c_proj.weight
- transformer.h.2.mlp.c_proj.bias
- transformer.h.3.ln_1.weight
- transformer.h.3.ln_1.bias
- transformer.h.3.attn.c_attn.weight
- transformer.h.3.attn.c_attn.bias
- transformer.h.3.attn.c_proj.weight
- transformer.h.3.attn.c_proj.bias
- transformer.h.3.ln_2.weight
- transformer.h.3.ln_2.bias
- transformer.h.3.mlp.c_fc.weight
- transformer.h.3.mlp.c_fc.bias
- transformer.h.3.mlp.c_proj.weight
- transformer.h.3.mlp.c_proj.bias
- transformer.h.4.ln_1.weight
- transformer.h.4.ln_1.bias
- transformer.h.4.attn.c_attn.weight
- transformer.h.4.attn.c_attn.bias
- transformer.h.4.attn.c_proj.weight
- transformer.h.4.attn.c_proj.bias
- transformer.h.4.ln_2.weight
- transformer.h.4.ln_2.bias
- transformer.h.4.mlp.c_fc.weight
- transformer.h.4.mlp.c_fc.bias
- transformer.h.4.mlp.c_proj.weight
- transformer.h.4.mlp.c_proj.bias
- transformer.h.5.ln_1.weight
- transformer.h.5.ln_1.bias
- transformer.h.5.attn.c_attn.weight
- transformer.h.5.attn.c_attn.bias
- transformer.h.5.attn.c_proj.weight
- transformer.h.5.attn.c_proj.bias
- transformer.h.5.ln_2.weight
- transformer.h.5.ln_2.bias
- transformer.h.5.mlp.c_fc.weight
- transformer.h.5.mlp.c_fc.bias
- transformer.h.5.mlp.c_proj.weight
- transformer.h.5.mlp.c_proj.bias
- transformer.h.6.ln_1.weight
- transformer.h.6.ln_1.bias
- transformer.h.6.attn.c_attn.weight
- transformer.h.6.attn.c_attn.bias
- transformer.h.6.attn.c_proj.weight
- transformer.h.6.attn.c_proj.bias
- transformer.h.6.ln_2.weight
- transformer.h.6.ln_2.bias
- transformer.h.6.mlp.c_fc.weight
- transformer.h.6.mlp.c_fc.bias
- transformer.h.6.mlp.c_proj.weight
- transformer.h.6.mlp.c_proj.bias
- transformer.h.7.ln_1.weight
- transformer.h.7.ln_1.bias
- transformer.h.7.attn.c_attn.weight
- transformer.h.7.attn.c_attn.bias
- transformer.h.7.attn.c_proj.weight
- transformer.h.7.attn.c_proj.bias
- transformer.h.7.ln_2.weight
- transformer.h.7.ln_2.bias
- transformer.h.7.mlp.c_fc.weight
- transformer.h.7.mlp.c_fc.bias
- transformer.h.7.mlp.c_proj.weight
- transformer.h.7.mlp.c_proj.bias
- transformer.h.8.ln_1.weight
- transformer.h.8.ln_1.bias
- transformer.h.8.attn.c_attn.weight
- transformer.h.8.attn.c_attn.bias
- transformer.h.8.attn.c_proj.weight
- transformer.h.8.attn.c_proj.bias
- transformer.h.8.ln_2.weight
- transformer.h.8.ln_2.bias
- transformer.h.8.mlp.c_fc.weight
- transformer.h.8.mlp.c_fc.bias
- transformer.h.8.mlp.c_proj.weight
- transformer.h.8.mlp.c_proj.bias
- transformer.h.9.ln_1.weight
- transformer.h.9.ln_1.bias
- transformer.h.9.attn.c_attn.weight
- transformer.h.9.attn.c_attn.bias
- transformer.h.9.attn.c_proj.weight
- transformer.h.9.attn.c_proj.bias
- transformer.h.9.ln_2.weight
- transformer.h.9.ln_2.bias
- transformer.h.9.mlp.c_fc.weight
- transformer.h.9.mlp.c_fc.bias
- transformer.h.9.mlp.c_proj.weight
- transformer.h.9.mlp.c_proj.bias
- transformer.h.10.ln_1.weight
- transformer.h.10.ln_1.bias
- transformer.h.10.attn.c_attn.weight
- transformer.h.10.attn.c_attn.bias
- transformer.h.10.attn.c_proj.weight
- transformer.h.10.attn.c_proj.bias
- transformer.h.10.ln_2.weight
- transformer.h.10.ln_2.bias
- transformer.h.10.mlp.c_fc.weight
- transformer.h.10.mlp.c_fc.bias
- transformer.h.10.mlp.c_proj.weight
- transformer.h.10.mlp.c_proj.bias
- transformer.h.11.ln_1.weight
- transformer.h.11.ln_1.bias
- transformer.h.11.attn.c_attn.weight
- transformer.h.11.attn.c_attn.bias
- transformer.h.11.attn.c_proj.weight
- transformer.h.11.attn.c_proj.bias
- transformer.h.11.ln_2.weight
- transformer.h.11.ln_2.bias
- transformer.h.11.mlp.c_fc.weight
- transformer.h.11.mlp.c_fc.bias
- transformer.h.11.mlp.c_proj.weight
- transformer.h.11.mlp.c_proj.bias
- transformer.ln_f.weight
- transformer.ln_f.bias
- lm_head.weight

## Client FedIT

In [None]:
#| hide
import nbdev
nbdev.nbdev_export()  # regenerate the library modules from this notebook's export cells