# skinfra.experiment: 实验原子能力基建

> 配置文件的读取与保存（JSON / YAML / TOML；读取另支持 INI），以及沿目录层级逐级叠加加载配置。

In [24]:
#| default_exp experiment.nucleus

In [25]:
#| hide
%load_ext autoreload
%autoreload 2
from nbdev.showdoc import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
#| export
from skinfra.rv_args.nucleus import experiment_setting, RandomVariable

In [27]:
#| export
import functools
import warnings
from typing import Optional


def deprecated_alias_of(old_function, *, old_name: Optional[str] = None):
    """Return a deprecated alias that forwards to *old_function*.

    Despite its historical name, ``old_function`` is the *replacement* that
    callers should switch to. The returned wrapper emits a
    ``DeprecationWarning`` naming ``old_function.__name__`` and then calls
    ``old_function`` with the same arguments, returning its result.

    Args:
        old_function: The function callers should use instead of the alias.
        old_name: Optional name of the deprecated alias; when given it is
            included in the warning message.

    Returns:
        A wrapper whose name/docstring metadata is copied from
        ``old_function`` via ``functools.wraps``.
    """
    target_name = old_function.__name__
    # The message must name the *replacement*; previously it always said
    # "load_config", which was wrong for every alias except get_config.
    if old_name:
        message = f"{old_name} 已弃用，请使用 {target_name}"
    else:
        message = f"此别名已弃用，请使用 {target_name}"

    # Deliberately a hand-written wrapper instead of the third-party
    # `decorator` library: it is more explicit, and (per the original author)
    # that library did not trigger the deprecation warning.
    @functools.wraps(old_function)
    def wrapper(*args, **kwargs):
        # stacklevel=2 attributes the warning to the alias's caller.
        warnings.warn(message, DeprecationWarning, stacklevel=2)
        return old_function(*args, **kwargs)

    return wrapper

In [28]:
# | export
import yaml
import json
import toml

from pathlib import Path

from typing import Optional
from configparser import ConfigParser
from typing import Union

def load_config(config_path: Union[str, Path], raise_error: bool = True) -> Optional[dict]:
    """Load a configuration file, choosing the parser from its extension.

    Supported extensions (case-insensitive): ``.json``, ``.yaml`` / ``.yml``,
    ``.toml`` and ``.ini``. An INI file is returned as
    ``{section: {key: value}}`` with string values. ConfigParser's default
    %-interpolation is applied, and keys from ``[DEFAULT]`` are merged into
    every section.

    Args:
        config_path: Path to the configuration file.
        raise_error: If True (default), a missing path, a directory, or an
            unsupported extension raises. If False, ``None`` is returned
            instead. Parse errors from the underlying parsers always propagate.

    Returns:
        The parsed configuration. Returns ``None`` when ``raise_error`` is
        False and the file cannot be loaded for one of the reasons above.

    Raises:
        FileNotFoundError: the path does not exist.
        IsADirectoryError: the path is a directory.
        ValueError: the extension is not supported.
    """
    config_path = Path(config_path)
    if not config_path.exists():
        if raise_error:
            raise FileNotFoundError(f"配置文件不存在: {config_path}")
        return None
    # A directory passes exists() but would crash in open(), ignoring raise_error.
    if config_path.is_dir():
        if raise_error:
            raise IsADirectoryError(f"配置路径是目录而不是文件: {config_path}")
        return None

    # Validate the format before touching the file.
    suffix = config_path.suffix.lower()
    if suffix not in (".json", ".yaml", ".yml", ".toml", ".ini"):
        if raise_error:
            raise ValueError(f"不支持的配置文件格式: {config_path}")
        return None

    with config_path.open("r", encoding="utf-8") as f:
        if suffix == ".json":
            return json.load(f)
        if suffix in (".yaml", ".yml"):
            # NOTE(review): full_load resolves more tags than safe_load; prefer
            # yaml.safe_load if config files may come from untrusted sources.
            return yaml.full_load(f)
        if suffix == ".toml":
            return toml.load(f)
        # Remaining case: ".ini".
        parser = ConfigParser()
        parser.read_file(f)
        return {section: dict(parser.items(section)) for section in parser.sections()}



# Deprecated alias of ``load_config``, kept for backward compatibility:
# each call emits a DeprecationWarning and then forwards to ``load_config``.
get_config = deprecated_alias_of(load_config)


def save_config(config: dict, config_path: Union[str, Path], raise_error: bool = True, **kwargs) -> bool:
    """Save *config* to *config_path*, choosing the format from the extension.

    Supported extensions (case-insensitive): ``.json``, ``.yaml`` / ``.yml``
    and ``.toml``. ``.ini`` is recognized but not implemented yet.

    The whole document is serialized in memory before the file is opened.
    An unsupported extension, an ``.ini`` target, or a serialization error
    therefore leaves any existing file untouched instead of truncating it.

    Args:
        config: Non-empty mapping to save.
        config_path: Destination path. Its parent directory must already exist.
        raise_error: If True (default), failures raise. If False, they return
            False.
        **kwargs: Forwarded to the serializer (``json.dumps`` / ``yaml.dump`` /
            ``toml.dumps``). JSON defaults to ``indent=2`` and YAML to
            ``sort_keys=False`` unless overridden.

    Returns:
        True on success, or False on failure when ``raise_error`` is False.

    Raises:
        ValueError: empty config or unsupported extension.
        NotImplementedError: ``.ini`` target.
        Exception: anything raised by the serializer or the file write.
    """
    config_path = Path(config_path)
    if not config:
        if raise_error:
            raise ValueError("配置字典为空")
        return False

    suffix = config_path.suffix.lower()
    try:
        if suffix == ".json":
            kwargs.setdefault("indent", 2)
            text = json.dumps(config, ensure_ascii=False, **kwargs)
        elif suffix in (".yaml", ".yml"):
            kwargs.setdefault("sort_keys", False)
            # Without a stream argument, yaml.dump returns the document as str.
            text = yaml.dump(config, allow_unicode=True, **kwargs)
        elif suffix == ".toml":
            text = toml.dumps(config, **kwargs)
        elif suffix == ".ini":
            # TODO: map nested dicts to ConfigParser sections (flat dicts to
            # the DEFAULT section) and write them with ConfigParser.write.
            raise NotImplementedError("INI 格式的配置文件保存暂不支持嵌套字典结构")
        else:
            if raise_error:
                raise ValueError(f"不支持的配置文件格式: {config_path}")
            return False

        # Only open (and thereby truncate) the target once serialization succeeded.
        with config_path.open("w", encoding="utf-8") as f:
            f.write(text)
        return True
    except Exception:
        if raise_error:
            raise
        return False

In [29]:
import os
import tempfile
import warnings

# Round-trip smoke test: every supported format must save and reload the same dict.
# ".ini" is excluded because save_config does not support writing INI yet.
with tempfile.TemporaryDirectory() as tmpdir:
    config_data = {
        "learning_rate": 1e-4,
        "batch_size": 32,
        "epochs": 10,
    }
    for ext in [".json", ".toml", ".yaml"]:
        config_path = os.path.join(tmpdir, f"config{ext}")
        save_config(config_data, config_path)
        loaded_config = load_config(config_path)
        assert loaded_config == config_data, (ext, loaded_config)
        # The deprecated alias must still work, and must announce its deprecation.
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            assert get_config(config_path) == loaded_config
        assert any(issubclass(w.category, DeprecationWarning) for w in caught)
        print(f"已读取临时配置{ext}:", loaded_config)


已读取临时配置.json: {'learning_rate': 0.0001, 'batch_size': 32, 'epochs': 10}
已读取临时配置.toml: {'learning_rate': 0.0001, 'batch_size': 32, 'epochs': 10}
已读取临时配置.yaml: {'learning_rate': 0.0001, 'batch_size': 32, 'epochs': 10}


  loaded_config = get_config(config_path)


In [30]:
#| export
def iterate_path_hierarchy(path_str, ensure_self=False):
    """Return every prefix of *path_str*, from the shallowest to the full path.

    ``"a/b/c"`` yields ``["a", "a/b", "a/b/c"]``. When *ensure_self* is true,
    ``"."`` is prepended to the list. Each entry is the ``str`` of a ``Path``
    built from the leading components of *path_str*.
    """
    components = Path(path_str).parts
    prefixes = [
        str(Path(*components[:depth]))
        for depth in range(1, len(components) + 1)
    ]
    if ensure_self:
        return ["."] + prefixes
    return prefixes


def load_overlaying_config(
    path_str: str, config_filename: str, ensure_self=False, verbose=False
) -> Optional[dict]:
    """Load and merge every *config_filename* found along the hierarchy of *path_str*.

    Candidate directories come from ``iterate_path_hierarchy(path_str,
    ensure_self=ensure_self)``, shallowest first. Each regular file named
    *config_filename* is loaded with ``load_config`` and merged with
    ``dict.update``, so deeper files override shallower ones. The merge is
    shallow: a top-level key from a deeper file replaces the whole value.

    Args:
        path_str: Target path whose prefixes are searched.
        config_filename: File name looked up in each directory.
        ensure_self: Also search ``"."`` first, at the lowest priority.
        verbose: Print each config file that is found.

    Returns:
        The merged dict, or None if no config file was found at any level.
        A file that loads as None (e.g. an empty YAML file) still counts as
        found and contributes nothing.
    """
    merged: dict = {}
    found_any = False
    directories = iterate_path_hierarchy(path_str, ensure_self=ensure_self)
    for priority, directory in enumerate(directories):
        config_path = Path(directory) / config_filename
        # is_file(): a directory that happens to carry the config name is skipped.
        if not config_path.is_file():
            continue
        if verbose:
            print(f"Found config at: {config_path}, priority is {priority} (the higher the priority, the later it is loaded).")
        found_any = True
        # Call load_config directly: the deprecated get_config alias would emit
        # a DeprecationWarning on every internal call.
        merged.update(load_config(config_path) or {})
    return merged if found_any else None

# Deprecated alias of ``load_overlaying_config``, kept for backward compatibility:
# each call emits a DeprecationWarning and then forwards to ``load_overlaying_config``.
read_overlaying_config = deprecated_alias_of(load_overlaying_config)
