# Config loader

Helpers that use `omegaconf` to load and save YAML configuration files and to instantiate objects from configs.

In [None]:
#| default_exp config_loader

In [None]:
#| export
from genQC.imports import *
from omegaconf import OmegaConf

## IO

In [None]:
#| export
def class_to_str(cls):
    """Return the dotted import path of `cls`, e.g. ``"omegaconf.omegaconf.OmegaConf"``.

    The result is meant to be stored as a config ``target`` string (see the
    ``MyConfig`` example) and later resolved by `get_obj_from_str`.

    The path is built from ``__module__`` / ``__qualname__``. The previous
    ``str(cls)[8:-2]`` approach assumed a repr of the form ``<class '...'>``.
    It therefore produced garbage for functions and for classes whose
    metaclass customises ``__repr__`` (e.g. ``enum.Enum`` subclasses print as
    ``<enum 'Name'>``).

    For ordinary classes the output is unchanged. That includes builtins,
    which are returned without a ``builtins.`` prefix (``int`` -> ``"int"``).
    """
    qualname = getattr(cls, "__qualname__", None)
    if qualname is None:
        # Not a class or function (e.g. an instance): keep the legacy behaviour.
        return str(cls)[8:-2]
    module = getattr(cls, "__module__", None)
    # Mirror type.__repr__, which omits the module name for builtins.
    if module is None or module == "builtins":
        return qualname
    return f"{module}.{qualname}"

In [None]:
#| export
def load_config(file_path):
    """Read a YAML config file and return it as an OmegaConf config object.

    `file_path` is converted to text first (str or `pathlib.Path` both work).
    An already-open file object is therefore not accepted here.
    """
    path_text = format(file_path, "")  # same conversion an f-string performs
    return OmegaConf.load(path_text)

In [None]:
#| export
def config_to_dict(config, resolve=False):
    """Convert an OmegaConf config into plain Python containers.

    Args:
        config: a `DictConfig` / `ListConfig` (e.g. as returned by `load_config`).
        resolve: if True, interpolations such as ``${a.b}`` are resolved in the
            output. The default (False) keeps the previous behaviour of
            leaving them as unresolved strings.

    Returns:
        A `dict` (or `list` for a `ListConfig`) of primitive Python values.
    """
    return OmegaConf.to_container(config, resolve=resolve)

In [None]:
#| export
def save_dataclass_yaml(data_obj, file_path):
    """Serialize a dataclass (instance or type) to a YAML file via OmegaConf.

    Args:
        data_obj: dataclass instance or type, passed to `OmegaConf.structured`.
        file_path: destination path; an existing file is overwritten.

    The file is written as UTF-8, the encoding `OmegaConf.load` (used by
    `load_config`) reads with. Relying on the platform default encoding could
    corrupt non-ASCII values on systems where it is not UTF-8.
    """
    conf = OmegaConf.structured(data_obj)
    with open(file_path, "w", encoding="utf-8") as f:
        OmegaConf.save(config=conf, f=f)

In [None]:
#| export
def save_dict_yaml(dict_obj, file_path):
    """Serialize a plain mapping (or list) to a YAML file via OmegaConf.

    Args:
        dict_obj: object accepted by `OmegaConf.create`, typically a `dict`.
        file_path: destination path; an existing file is overwritten.

    The file is written as UTF-8, the encoding `OmegaConf.load` (used by
    `load_config`) reads with. Relying on the platform default encoding could
    corrupt non-ASCII values on systems where it is not UTF-8.
    """
    conf = OmegaConf.create(dict_obj)
    with open(file_path, "w", encoding="utf-8") as f:
        OmegaConf.save(config=conf, f=f)

Test: build a structured config from a dataclass instance.

In [None]:
from typing import Optional

@dataclass
class MyConfig:
    target: str = class_to_str(OmegaConf)
    clr_dim: int = 80
    # Annotated Optional because the default is None: OmegaConf refuses to
    # build a structured config whose non-Optional list field holds None.
    features: Optional[list[int]] = None

c = MyConfig()
c.features = [1, 2, 3]

OmegaConf.structured(c)

{'target': 'omegaconf.omegaconf.OmegaConf', 'clr_dim': 80, 'features': [1, 2, 3]}

## Object config load

Mostly adapted from the Stable Diffusion repository: https://github.com/Stability-AI/stablediffusion

In [None]:
#| export
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``"package.module.Name"`` to the object it names.

    Args:
        string: ``"<module path>.<attribute>"``. A bare name without a dot
            (e.g. ``"int"``, as `class_to_str` produces for builtins) is
            looked up in the `builtins` module.
        reload: if True, re-execute the module with `importlib.reload` before
            the attribute lookup, so edits to its source are picked up.

    Raises:
        ModuleNotFoundError: if the module part cannot be imported.
        AttributeError: if the module has no attribute of that name.
    """
    module_name, _, attr = string.rpartition(".")
    # No dot: treat the string as a builtin name rather than failing on unpacking.
    module_imp = importlib.import_module(module_name or "builtins")
    if reload:
        module_imp = importlib.reload(module_imp)
    return getattr(module_imp, attr)

In [None]:
#| export
def instantiate_from_config(config):
    """Build an object from a config mapping with a ``target`` and optional ``params`` key.

    ``config["target"]`` is a dotted import path resolved with
    `get_obj_from_str`. The callable it names is invoked with
    ``config["params"]`` unpacked as keyword arguments.

    Args:
        config: dict-like object (plain `dict` or OmegaConf `DictConfig`).

    Returns:
        Whatever the target callable returns (typically a new instance).

    Raises:
        KeyError: if ``target`` is missing.
    """
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    if "params" not in config:
        print("[WARNING] Expected key `params` to instantiate.")
    # An empty `params:` entry in YAML loads as None; treat it as "no
    # parameters" instead of failing on `**None`.
    params = config.get("params") or {}
    return get_obj_from_str(config["target"])(**params)

In [None]:
#| export
def load_model_from_config(config, ckpt, device):
    """Instantiate ``config.model`` and load the weights stored in checkpoint `ckpt`.

    Args:
        config: config object with a ``model`` entry understood by
            `instantiate_from_config` (``target`` + ``params``).
        ckpt: path to a checkpoint readable by `torch.load`. The weights must
            be stored under a ``"state_dict"`` key.
        device: target device (str or `torch.device`), e.g. ``"cuda:1"``.

    Returns:
        The model with the checkpoint weights loaded, moved to `device`.

    Raises:
        KeyError: if the checkpoint has no ``"state_dict"`` entry.
        RuntimeError: if the checkpoint keys/shapes do not match the model
            exactly (``strict=True``).
    """
    print(f"Loading model from {ckpt}")
    # Map straight onto `device`. Mapping onto `torch.device(device).type`
    # dropped the index, so e.g. "cuda:1" first landed on the current CUDA device.
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    pl_sd = torch.load(ckpt, map_location=torch.device(device))

    model = instantiate_from_config(config.model)

    # strict=True raises on any missing/unexpected key, so the returned
    # (missing_keys, unexpected_keys) result carries no information here.
    model.load_state_dict(pl_sd["state_dict"], strict=True)

    return model.to(device)

# Export -

In [None]:
#| hide
import nbdev
# Run nbdev's exporter so the `#| export` cells end up in the generated module.
nbdev.nbdev_export()