# Utils

> Helper utilities: dynamic class lookup, server/model/loss construction, config loading, data loaders, and lazy client instantiation.


In [None]:
#| default_exp utils

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import os
from fastcore.utils import *  # noqa: F403
from torch.utils.data import DataLoader
import numpy as np
import yaml
import torch
from fedai.data import *
import numpy as np
from fedai.vision.models import *


In [None]:
#| export
import importlib
def get_class(module_name, class_name):
    """Import ``module_name`` and return its attribute named ``class_name``.

    Used to resolve classes (servers, clients, losses, optimizers) from
    strings stored in the config.

    Raises:
        ModuleNotFoundError: if ``module_name`` cannot be imported.
        AttributeError: if the module has no attribute ``class_name``.
    """
    return getattr(importlib.import_module(module_name), class_name)

In [None]:
#| export
def get_server(cfg, lst_data_dict, model, holdout_ds, **kwargs):
    """Construct the federated server for the algorithm named ``cfg.name``.

    The server class is ``fedai.servers.Server_<cfg.name>``. The client class
    it will use is ``fedai.clients.Client_<cfg.name>``; it is passed to the
    server's constructor after ``holdout_ds``.

    Extra keyword arguments are forwarded to the server constructor unchanged.
    """
    server_cls = get_class('fedai.servers', f'Server_{cfg.name}')
    client_cls = get_class('fedai.clients', f'Client_{cfg.name}')
    server_args = (cfg, lst_data_dict, model, holdout_ds, client_cls)
    return server_cls(*server_args, **kwargs)

In [None]:
#| export
def load_config(file_path):
    """Read a YAML configuration file and return its parsed contents.

    Args:
        file_path: Path (str or path-like) of the YAML file.

    Returns:
        The object produced by ``yaml.safe_load`` for the document
        (e.g. a dict when the top-level node is a mapping).
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the platform's
    # locale encoding (which differs between Linux, macOS and Windows).
    with open(file_path, 'r', encoding='utf-8') as file:
        return yaml.safe_load(file)

In [None]:
#| export
def save_space(client) -> None:
    """Release memory held by a client before it is discarded.

    Calls ``client.clear_model()``, removes the client's ``optimizer``
    attribute, runs the garbage collector and, when CUDA is available,
    empties PyTorch's CUDA cache. Safe to call more than once on the same
    client.
    """
    import gc

    client.clear_model()
    # A client that never had an optimizer attribute, or was already cleaned
    # up by an earlier call, must not make this raise AttributeError.
    try:
        del client.optimizer
    except AttributeError:
        pass
    # This only drops the function's own reference; the caller must drop
    # its reference too for the client object itself to become collectable.
    del client
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

Because preparing the dataset can differ between single-device and multi-device training, we handle this step in a separate function. `prepare_dl` prepares the dataloader needed for the trainer's type; at present it builds a standard PyTorch `DataLoader` using the configured batch size.

In [None]:
#| export
def prepare_dl(cfg, ds, shuffle=True, collate_fn=None):
    """Wrap ``ds`` in a ``DataLoader`` sized by ``cfg.data.batch_size``.

    Args:
        cfg: Config object; only ``cfg.data.batch_size`` is read.
        ds: Dataset to iterate over.
        shuffle: Whether the loader reshuffles every epoch.
        collate_fn: Optional batch-collation function (``None`` keeps the
            PyTorch default).
    """
    loader_kwargs = dict(
        batch_size=cfg.data.batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
    )
    return DataLoader(ds, **loader_kwargs)

## Learner's utils

In [None]:
#| export
def get_model(cfg):
    """Instantiate the model named by ``cfg.model.name``.

    Names starting with ``"hf://"`` are delegated to ``get_hf_model(cfg)``.
    Otherwise the name must be one of the registered architectures below.
    Only the selected architecture is constructed, and only its own config
    fields are read:

    - ``MNISTCNN`` / ``CIFAR10CNN``: no extra fields (10 classes).
    - ``MLP``: ``cfg.model.dim_in``, ``dim_hidden``, ``dim_out``.
    - ``CharacterLSTM``: ``cfg.model.vocab_size``, ``embed_size``,
      ``hidden_size``, ``num_layers``.

    Raises:
        ValueError: if the name is not registered.
    """
    model_name = cfg.model.name

    # Hugging Face models are resolved by a dedicated loader.
    if model_name.startswith("hf://"):
        return get_hf_model(cfg)  # type: ignore  # noqa: F405

    # Factories rather than instances: building every model up front wasted
    # memory and required every model's config fields to be present.
    factories = {
        "MNISTCNN": lambda: MNISTCNN(num_classes=10),  # noqa: F405 # type: ignore
        "CIFAR10CNN": lambda: CIFAR10CNN(num_classes=10),  # noqa: F405 # type: ignore
        "MLP": lambda: MLP(  # noqa: F405 # type: ignore
            dim_in=cfg.model.dim_in,
            dim_hidden=cfg.model.dim_hidden,
            dim_out=cfg.model.dim_out,
        ),
        "CharacterLSTM": lambda: CharacterLSTM(  # noqa: F405 # type: ignore
            vocab_size=cfg.model.vocab_size,
            embed_size=cfg.model.embed_size,
            hidden_size=cfg.model.hidden_size,
            num_layers=cfg.model.num_layers,
        ),
    }

    try:
        factory = factories[model_name]
    except KeyError:
        raise ValueError(
            f"Model '{model_name}' is not recognized, the available models are: {list(factories.keys())}"
        ) from None
    return factory()


In [None]:
#| export
def get_criterion(cfg, customm_fn=None):
    """Return the loss function to train with.

    Args:
        cfg: Config whose ``criterion`` attribute names a class in
            ``fedai.loss``; only read when no custom loss is given.
        customm_fn: Optional user-supplied loss. When not ``None`` it is
            returned unchanged.

    Returns:
        ``customm_fn`` if provided, otherwise a new instance of
        ``fedai.loss.<cfg.criterion>``.
    """
    # Compare with None explicitly: a loss object that defines __len__ or
    # __bool__ (e.g. an empty container module) is falsy and would otherwise
    # be silently replaced by the config's criterion.
    if customm_fn is not None:
        return customm_fn
    return get_class('fedai.loss', cfg.criterion)()

In [None]:
#| export
def load_state_from_disk(cfg, model, id, comm_round):
    """Load a client's saved weights for a communication round, if present.

    The checkpoint path is
    ``<cfg.save_dir>/<comm_round>/local_output_<id>/pytorch_model.bin``.
    If that file does not exist, ``model`` is returned untouched.

    Args:
        cfg: Config; only ``cfg.save_dir`` is read.
        model: Model to load into (modified in place).
        id: Client identifier used in the directory name.
        comm_round: Communication round used in the directory name.

    Returns:
        ``model`` (the same object).
    """
    model_path = os.path.join(cfg.save_dir,
                              str(comm_round),
                              f"local_output_{id}",
                              "pytorch_model.bin")

    if os.path.exists(model_path):
        if isinstance(model, torch.nn.Module):
            # map_location must be a device/str/dict/callable. Passing the
            # module itself made torch.load call model(storage, location).
            # Load onto the device the model's parameters live on; fall back
            # to CPU for a parameterless model.
            first_param = next(model.parameters(), None)
            device = first_param.device if first_param is not None else torch.device("cpu")
            model.load_state_dict(torch.load(model_path, map_location=device))
        else:
            # NOTE(review): set_peft_model_state_dict is not imported in this
            # view (presumably from peft via a star import) — confirm. Also
            # verify this branch is reachable: if PEFT model objects subclass
            # torch.nn.Module they take the branch above instead.
            set_peft_model_state_dict(model,
                                      torch.load(model_path, map_location=model.device),
                                      "default")
    return model

## A Lazy initializer 

To save memory, we don't need to instantiate all of the client objects at once. Instead, we can instantiate them **lazily**, in the spirit of Python `generators`: a client is only created in memory when it is first accessed. We achieve this by defining a class and overriding its `__getitem__` method. Rather than using a generator directly, the class keeps a cache object (a dictionary keyed by client index) from which already-created clients are retrieved.

In [None]:
#| export
class LazyList:
    """Index-addressable collection of clients that are created on demand.

    Clients are built on first access by the ``__getitem__`` patched onto this
    class and memoized in ``client_cache`` (index -> client instance).
    """

    def __init__(self, server, client_cls):
        # The server supplies per-client data and shared settings; client_cls
        # is the class instantiated for each index.
        self.server, self.client_cls = server, client_cls
        # index -> already-constructed client, filled lazily.
        self.client_cache = {}

    def clear_cache(self):
        """Forget every cached client by rebinding ``client_cache`` to a new dict."""
        self.client_cache = dict()


In [None]:
#| export
@patch
def __getitem__(self: LazyList, idx):
    """Return the client for ``idx``, constructing and caching it on first use.

    The client receives its slice of the server's data
    (``lst_data_dict[idx]`` and ``lst_gen_data_dict[idx]``) plus the server's
    shared criterion, tokenizer, collate function and config.
    """
    if idx in self.client_cache:
        return self.client_cache[idx]

    srv = self.server
    # Model and optimizer are deliberately passed as None so creating a client
    # stays cheap; presumably they are attached later — confirm with the
    # server/client code.
    fresh_client = self.client_cls(
        data_dict=srv.lst_data_dict[idx],
        model=None,
        criterion=srv.criterion,
        optimizer=None,
        idx=idx,
        gen_data_dict=srv.lst_gen_data_dict[idx],
        tokenizer=srv.tokenizer,
        collat_fn=srv.collat_fn,
        cfg=srv.cfg,
    )
    self.client_cache[idx] = fresh_client
    return self.client_cache[idx]

Let us see how this can be used in a real example.

In [None]:
#| hide
import nbdev
# Export the notebook cells tagged `#| export` to the package's Python
# modules (this notebook targets the module named in `#| default_exp`).
nbdev.nbdev_export()