# Models

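> Helpers for building client models, data blocks and loss functions, and for restoring saved model state between federated-learning rounds.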


In [None]:
#| default_exp learner_utils

In [None]:
#| hide
from nbdev.showdoc import *  # type: ignore # noqa: F403

In [None]:
#| export
import os
import torch.nn as nn
import torch
from fedai.vision.VisionBlock import VisionBlock
from fedai.vision.models import *
from fedai.text.models import *
from fedai.models import * # noqa: F403
from peft import *  # type: ignore # noqa: F403


In [None]:
#| export
import importlib
def get_cls(module_name, class_name):
    """Dynamically import ``module_name`` and return its attribute ``class_name``.

    Errors from the underlying calls (e.g. ``ModuleNotFoundError`` for a bad
    module path, ``AttributeError`` for a missing name) propagate unchanged.
    """
    return getattr(importlib.import_module(module_name), class_name)

In [None]:
#| export
def get_block(cfg, id, train=True):
    """Build the data block for client ``id`` based on ``cfg.data.modality``.

    Args:
        cfg: experiment config; ``cfg.data.modality`` selects the block type.
        id: client identifier, forwarded to the block constructor.
        train: forwarded to the block constructor (callers pass ``False`` to
            obtain the test block).

    Returns:
        A ``VisionBlock`` when ``cfg.data.modality == ['Vision']``.

    Raises:
        ValueError: if no data block exists for the configured modality.
            Previously this case tried to call ``None`` and failed with an
            opaque ``TypeError``.
    """
    modality = cfg.data.modality
    if modality == ['Vision']:
        return VisionBlock(cfg, id, train=train)
    raise ValueError(
        f"Unsupported data modality {modality!r}; "
        "only ['Vision'] currently has a data block."
    )

In [None]:
#| export
def get_model(cfg):
    """Instantiate the model named by ``cfg.model.name``.

    Names starting with ``"hf://"`` are delegated to ``get_hf_model``. Any other
    name is looked up in a registry of built-in architectures. Optional size
    hyper-parameters are read from ``cfg.model`` and fall back to the defaults
    written below.

    Only the requested model is constructed. The registry maps each name to a
    zero-argument factory, so a call no longer builds every architecture
    (including an 8-layer LSTM over a 50k vocabulary) just to return one.
    Because the other models are not built first, the random initialisation of
    the returned model under a fixed seed differs from the old eager version.

    Raises:
        ValueError: if the name is neither an ``hf://`` URI nor a registered model.
    """
    model_name = cfg.model.name

    # Hugging Face models are handled by a dedicated loader.
    if model_name.startswith("hf://"):
        return get_hf_model(cfg)  # type: ignore # Call your HF model loader function  # noqa: F405

    def opt(attr, default):
        # Optional hyper-parameter on cfg.model, with a fallback default.
        return getattr(cfg.model, attr, default)

    # name -> factory; nothing is built until the selected factory is called.
    factories = {
        "LogisticRegression": lambda: LogisticRegression(
            input_dim=opt("dim_in", 784),
            output_dim=opt("dim_out", 10),
        ),
        "MNISTCNN": lambda: MNISTCNN(num_classes=10),
        "CIFAR10CNN": lambda: CIFAR10CNN(num_classes=10),
        "MLP": lambda: MLP(
            dim_in=opt("dim_in", 784),
            dim_hidden=opt("dim_hidden", 128),
            dim_out=opt("dim_out", 10),
        ),
        "CharacterLSTM": lambda: CharacterLSTM(
            vocab_size=opt("vocab_size", 50000),
            embed_size=opt("embed_size", 512),
            hidden_size=opt("hidden_size", 512),
            num_layers=opt("num_layers", 8),
        ),
    }

    factory = factories.get(model_name)
    if factory is None:
        raise ValueError(f"Model '{model_name}' is not recognized, the available models are: {list(factories.keys())}")
    return factory()


In [None]:
#| export
def get_criterion(customm_fn):
    """Return the training loss.

    ``customm_fn`` is returned as-is whenever it is truthy (e.g. a user-supplied
    loss callable). Otherwise a new ``nn.CrossEntropyLoss()`` is created.
    """
    return customm_fn or nn.CrossEntropyLoss()

With single-model aggregation (`cfg.agg == "one_model"`), the aggregated global model is sent back to all clients. Personalized FL, federated multi-task learning (FMTL), and similar methods instead keep one model per client, so each client only needs to reload its own most recent local model.

In [None]:
#| export
def _restore_model_weights(model, saved_state):
    """Copy ``saved_state["model"]`` into ``model`` (plain module or PEFT model)."""
    # NOTE(review): peft's PeftModel subclasses torch.nn.Module, so PEFT models
    # also take the load_state_dict branch. Confirm the saved checkpoints hold
    # full state dicts rather than adapter-only weights.
    if isinstance(model, torch.nn.Module):
        model.load_state_dict(saved_state["model"])
    else:
        set_peft_model_state_dict(model,  # noqa: F405 # type: ignore
                                  saved_state["model"],
                                  "default")


def load_state_from_disk(cfg, state, latest_round, id, t):
    """Restore ``state["model"]`` from the checkpoint saved in an earlier round.

    Args:
        cfg: config; reads ``cfg.agg`` and ``cfg.save_dir``.
        state: client state dict; its ``"model"`` entry is updated in place.
        latest_round: mapping client id -> last round in which that client saved
            output (used only when ``cfg.agg != "one_model"``).
        id: client identifier.
        t: current communication round (used only when ``cfg.agg == "one_model"``).

    Returns:
        The same ``state`` dict, with the model weights loaded.

    With ``cfg.agg == "one_model"``, every client loads the global model saved
    in round ``t - 1``. Otherwise, each client loads its own latest local output.
    """
    if cfg.agg == "one_model":
        checkpoint_path = os.path.join(cfg.save_dir, str(t - 1),
                                       "global_model", "state.pth")
    else:
        checkpoint_path = os.path.join(cfg.save_dir, str(latest_round[id]),
                                       f"local_output_{id}", "state.pth")

    # Load onto CPU: load_state_dict copies values into the model's existing
    # parameters, so the source device does not matter. This also lets
    # GPU-saved checkpoints load on CPU-only workers.
    saved_state = torch.load(checkpoint_path, map_location="cpu")
    _restore_model_weights(state["model"], saved_state)
    return state

In [None]:
# from torch import nn
# def client_fn(client_cls, cfg, id, latest_round, t, loss_fn = None, optimizer = None):
    
#     model = get_model(cfg)
#     criterion = get_criterion(loss_fn)

#     train_block = get_block(cfg, id)
#     test_block = get_block(cfg, id, train=False)    
    
#     state = {'model': model, 'optimizer': None, 'criterion': criterion, 't': t}
    
#     if t > 0:
#         state = load_state_from_disk(cfg, state, latest_round, id, t)  # noqa: F405
#         state['optimizer'] = SophiaG(model.parameters(),
#                              lr=2e-4,
#                              betas=(0.965, 0.99),
#                              rho=0.01,
#                              weight_decay=1e-1)
        
#     return client_cls(id, cfg, state, block= [train_block, test_block])


In [None]:
# def client_fn(client_cls, cfg, id, latest_round, loss_fn):
#     model = get_model(cfg)
#     criterion = get_criterion(loss_fn)
#     # get train and test ds
#     train_block, test_block = get_block(cfg, id), get_block(cfg, id, train=False)

#     state = {'model': model, 'optimizer': None, 'criterion': criterion}

#     if id in latest_round:
#         comm_round = latest_round[id]
#         state['model'] = load_state_from_disk(cfg, model, id, comm_round)  # noqa: F405
    
#     return client_cls(id, cfg, state, block= [train_block, test_block])


In [None]:
#| hide
# Write this project's `#| export` notebook cells out to their Python modules.
from nbdev import nbdev_export
nbdev_export()  # type: ignore  # noqa: E702
