# Utility Functions

> 通用工具函数和路径设置
>
> General utility functions and path settings

## 简介/Description:

utils 模块包含通用工具函数和项目中的关键路径设置，如 data_path。这些工具函数与项目的各个模块没有直接耦合，提供了项目中可复用的常用功能。

The utils module contains general utility functions and key path settings for the project, such as data_path. These utility functions are decoupled from the project’s main modules and provide commonly used reusable functionality across the project.

## 主要符号/Main symbols:

- data_path: 数据存储路径的设置，用于配置数据集的根目录。

  data_path: Defines the data storage path, used for setting the dataset root directory.


- other_util_function: 其他工具函数，未来可以扩展。

  other_util_function: Placeholder for other utility functions, expandable for future needs.

In [None]:
#| default_exp utils

In [None]:
#| hide
%load_ext autoreload
%autoreload 2
from nbdev.showdoc import *

## 有关本库的一些信息 / Library information

In [None]:
#| export
from pathlib import Path
import inspect
import namable_classify
# Resolve the package's __init__ file, its directory, and the directory above it
# (treated as the repository root by the paths below).
lib_init_path = Path(inspect.getfile(namable_classify))
lib_directory_path = lib_init_path.parent
lib_repo_path = lib_directory_path.parent

# Working directories that live under the repository root.
runs_path = lib_repo_path.joinpath('runs')
runs_figs_path = runs_path.joinpath('figs')
data_path = lib_repo_path.joinpath('data')

# Create them at import time; already-existing directories are left alone.
runs_path.mkdir(parents=True, exist_ok=True)
runs_figs_path.mkdir(parents=True, exist_ok=True)
data_path.mkdir(parents=True, exist_ok=True)

In [None]:
#| export
# Use the repository README as the package docstring.
# The encoding is pinned to UTF-8 so the read does not depend on the platform's
# default locale encoding.
with open(lib_repo_path/"README.md", encoding="utf-8") as readme:
    namable_classify.__doc__ = readme.read()

In [None]:
# namable_classify?

## 日志模块 / Logging Module

我们结合 loguru 和 rich 的最佳实践，利用 richuru 库。

We combine the best practices of loguru and rich by using the richuru library.

In [None]:
#| export 
# How to set logger level in loguru?
# https://github.com/Delgan/loguru/issues/138
# Make faster? picologging
# import 
# def set_logger_level(level):
#     os
# How to add file handler to loguru logger?
# try:
import richuru
from rich.console import Console
from rich.theme import Theme
import logging
from rich.markdown import Markdown
import rich

# When this is used from a Python console, rich's pretty printing makes the output look nicer.
from rich import pretty
pretty.install()

# Console shared with richuru below; the theme adds styles for the
# `success` and `trace` level names and console markup is enabled.
rich_console = Console(
    theme=richuru.Theme(  # required, otherwise the color will be incorrect
        {
            'logging.level.success': 'green',
            'logging.level.trace': 'bright_black',
        }
    ), 
    markup=True
)
# Install richuru with the console above, a timestamp format that includes
# microseconds (%f), and logging.INFO as the level.
richuru.install(rich_console=rich_console, 
                time_format="%a %Y-%m-%d %H:%M:%S.%f", 
                level = logging.INFO
)
# except ImportError:
#     pass


In [None]:
Theme?
richuru.install?

[0;31mSignature:[0m
[0mrichuru[0m[0;34m.[0m[0minstall[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mrich_console[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mrich[0m[0;34m.[0m[0mconsole[0m[0;34m.[0m[0mConsole[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mexc_hook[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mCallable[0m[0;34m[[0m[0;34m[[0m[0mType[0m[0;34m[[0m[0mBaseException[0m[0;34m][0m[0;34m,[0m [0mBaseException[0m[0;34m,[0m [0mOptional[0m[0;34m[[0m[0mtraceback[0m[0;34m][0m[0;34m][0m[0;34m,[0m [0mAny[0m[0;34m][0m[0;34m][0m [0;34m=[0m [0;34m<[0m[0mfunction[0m [0m_loguru_exc_hook[0m [0mat[0m [0;36m0x7458661a60e0[0m[0;34m>[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mrich_traceback[0m[0;34m:[0m [0mbool[0m [0;34m=[0m [0;32mTrue[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mtb_ctx_lines[0m[0;34m:[0m [0mint[0m [0;34m=[0m [0;36m3[0m[0;34m,[0m[0;34m[0m
[0;34m

In [None]:
# richuru.Theme?
# %load_ext rich
# ["Rich and pretty", True]

In [None]:
#| export
from loguru import logger
# Keep a handle on the previous `print` so plain output remains available.
original_print = print
# Rebind the module-level name `print` so every call is forwarded, unchanged, to logger.info.
# NOTE(review): positional and keyword arguments are passed straight through, so
# print("a", "b") is not joined the way the builtin does it — presumably loguru
# treats the extras as format arguments; confirm before relying on multi-argument calls.
print = lambda *args, **kwargs: logger.info(*args, **kwargs)

In [None]:
%load_ext rich
print(["Hello World!", True])
original_print(["Hello World!", True])

The rich extension is already loaded. To reload it, use:
  %reload_ext rich


['Hello World!', True]


In [None]:
# from rich.logging import RichHandler
# import logging
# logger.info("setting up logger. ")
# logger.configure(handlers=[{"sink": RichHandler(markup=True, 
#                                                 log_time_format = "%a %Y-%m-%d %H:%M:%S.%f", 
#                                                 level=logging.INFO),
#                         #  "format": "[red]{function}[/red] {message}"}]
#                         "format":"<cyan>{name}</cyan>: <level>{message}</level>"
#                         }]
                #  )
# RichHandler?


In [None]:
# Demo of two log levels. The debug message now names the level it is emitted at
# (it previously said "info", copied from the line below).
# NOTE(review): with the INFO level passed to richuru.install earlier in this file,
# the debug line is presumably filtered out — confirm.
logger.debug("This is a debug statement")
logger.info("This is a info statement", style="bold blue")

In [None]:
logger.info("", rich=Markdown("---"))

## 检查PyTorch模型是否符合预期，是否为要训练的模型

In [None]:
#| export
from fastcore.basics import patch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from rich.table import Table

@patch
def inspect_model_parameters(model:nn.Module):
    """
    Count the parameters of the module and their storage size.

    Returns a 4-tuple ``(trainable_params, all_param, trainable_bytes, all_bytes)``
    where the byte figures are ``numel() * element_size()`` summed over the
    parameters, and "trainable" means ``requires_grad`` is set.
    """
    total_count = total_size = 0
    grad_count = grad_size = 0
    for _name, tensor in model.named_parameters():
        count = tensor.numel()
        size = count * tensor.element_size()
        total_count += count
        total_size += size
        if not tensor.requires_grad:
            continue  # frozen parameter: contributes to the totals only
        grad_count += count
        grad_size += size
    return grad_count, total_count, grad_size, total_size

@patch
def num_of_total_parameters(model:nn.Module):
    """
    Return the total number of parameters in the module.

    The count is the second element of ``inspect_model_parameters()``.
    """
    # Call it as a method: @patch attaches inspect_model_parameters to nn.Module
    # and does not leave a callable of that name at module level.
    return model.inspect_model_parameters()[1]

@patch
def num_of_trainable_parameters(model:nn.Module):
    """
    Return the number of parameters in the module that have ``requires_grad`` set.

    The count is the first element of ``inspect_model_parameters()``.
    """
    # Call it as a method: @patch attaches inspect_model_parameters to nn.Module
    # and does not leave a callable of that name at module level.
    return model.inspect_model_parameters()[0]

@patch
def print_trainable_parameters(model:nn.Module):
    """
    Log the number of trainable parameters in the model.

    Emits one ``logger.info`` record whose text carries the raw counts and the
    trainable percentage, with a rich ``Table`` (counts, byte sizes and the
    trainable ratio) attached through the ``rich`` keyword.

    A module with no parameters is reported with a ratio of 0 instead of
    raising ``ZeroDivisionError``.
    """
    trainable_params, all_param, trainable_bytes, all_bytes = model.inspect_model_parameters()
    # Guard the divisions: a parameter-free module has all_param == 0.
    if all_param:
        trainable_ratio = trainable_params / all_param
        trainable_percent = 100 * trainable_params / all_param
    else:
        trainable_ratio = 0.0
        trainable_percent = 0.0

    table = Table(title=f"Model {model.__class__.__name__}'s Trainable Parameters Inspection")
    table.add_column("Number of Trainable Parameters", justify="right", style="cyan", no_wrap=True)
    table.add_column("Number of Total Parameters", style="magenta")
    table.add_column("Trainable Ratio (0-1)", justify="right", style="green")
    table.add_row(
        f"{trainable_params:.3e} ({trainable_bytes:.3e} bytes)",
        f"{all_param:.3e} ({all_bytes:.3e} bytes)",
        f"{trainable_ratio:.3e}",
    )

    logger.info(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_percent:.2f}", 
        rich=table
    )

In [None]:
from transformers import AutoModel, AutoConfig
test_model = AutoModel.from_config(AutoConfig.from_pretrained("google/vit-base-patch16-224-in21k"))



In [None]:
test_model.print_trainable_parameters()

In [None]:
# @patch
# def print_model_pretty(self:nn.Module):
#     with console.capture() as capture:
#         module_tree = Visualization(self).structure_graph(printTree=False)
#         console.print(module_tree)
#     # console.print(capture.get())
#     # return capture.get()
#     # return module_tree

In [None]:
#| export
from bigmodelvis import Visualization
@patch
def model_rich_tree(self:nn.Module):
    """
    Return the structure graph that bigmodelvis builds for this module.

    ``printTree=False`` is passed to ``structure_graph``, and its result is
    handed back to the caller as-is.
    """
    return Visualization(self).structure_graph(printTree=False)

from rich.panel import Panel
@patch
def print_model_pretty(self:nn.Module):
    """
    Log this module with its structure tree wrapped in a titled rich ``Panel``.

    The record's text is ``str(self)``; the panel is attached through the
    ``rich`` keyword of ``logger.info``.
    """
    boxed_tree = Panel(
        self.model_rich_tree(),
        title=f"Model Tree for {self.__class__.__name__}",
    )
    logger.info(str(self), rich=boxed_tree)

In [None]:
test_model.print_model_pretty()

## 其他工具 / Other utilities

In [None]:
#| export
import warnings
class MuteWarnings:
    """
    Context manager that silences Python warnings while it is active.

    Usage::

        with MuteWarnings():
            noisy_call()

    ``mute()`` and ``resume()`` can also be called directly; each ``resume()``
    undoes the most recent outstanding ``mute()`` and is a no-op when nothing
    is muted. The warnings state (filters and display hook) that was in effect
    at ``mute()`` time is restored on ``resume()``.
    """

    def __init__(self):
        # One entered ``warnings.catch_warnings`` context per outstanding mute().
        self._saved_states = []

    def __enter__(self):
        self.mute()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore the previous warnings state; exceptions are not suppressed.
        self.resume()

    def mute(self):
        """Snapshot the current warnings state, then ignore all warnings."""
        state = warnings.catch_warnings()
        state.__enter__()
        # simplefilter inserts at the front, so the "ignore" rule takes
        # precedence over filters that were already installed.
        warnings.simplefilter("ignore")
        self._saved_states.append(state)

    def resume(self):
        """Restore the warnings state saved by the matching ``mute()`` call."""
        if not self._saved_states:
            return
        self._saved_states.pop().__exit__(None, None, None)
        

In [None]:
#| export
import torch
import numpy as np
def ensure_array(x: torch.TensorType | np.ndarray | list):
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    elif isinstance(x, np.ndarray):
        return x
    else: # list
        return np.array(x)

In [None]:
ensure_array([1, 2, 3])

array([1, 2, 3])

In [None]:
#| export
# from scipy.special import softmax
from decorator import decorator
# def default_on_exception(default_value=None):
#     def decorator(func):
#         def wrapper(*args, **kwargs):
#             try:
#                 result = func(*args, **kwargs)
#                 return result
#             except Exception as e:
#                 logger.warning(f"An exception occurred: {e}")
#                 return default_value
#         return wrapper
#     return decorator

@decorator
def default_on_exception(func, default_value=None, verbose=False, *args, **kwargs):
    """
    Call ``func(*args, **kwargs)`` and fall back to ``default_value`` if it raises.

    Used as a decorator with options, e.g.
    ``@default_on_exception(default_value=999, verbose=True)``.

    default_value: value returned when ``func`` raises an ``Exception``.
    verbose: when true, the caught exception is reported via ``logger.exception``.
    """
    try:
        return func(*args, **kwargs)
    except Exception as error:
        if verbose:
            logger.exception(error)
    return default_value

In [None]:
# Demo: the body always raises ZeroDivisionError, so the decorated call
# returns the configured default_value (999) instead of `a`.
@default_on_exception(default_value=999, verbose=True)
def test_default_on_exception(a=1):
    1/0  # deliberate failure to trigger the fallback
    return a
test_default_on_exception()

999

In [None]:
#| export
def append_dict_list(dict, name, value):
    """
    Add ``value`` to the list stored under ``dict[name]``, creating it if missing.

    The entry is rebound to a new list (old list + ``[value]``); a list object
    previously stored under ``name`` is not mutated. Returns ``None``.
    """
    previous = dict.get(name, [])
    dict[name] = previous + [value]

In [None]:
# TODO: this cannot be implemented with `decorator` for now; avoid using this API where possible.
def partial_with_self(method, *args, **kwargs):
    """
    Pre-fill arguments of ``method`` while leaving its first (``self``) slot open.

    The returned function takes ``self`` first, then extra positional arguments
    (placed after the pre-filled ones) and extra keyword arguments (which
    override pre-filled keywords of the same name).
    """
    def wrapped(self, *additional_args, **additional_kwargs):
        # Later keys win, so call-time keywords override the pre-filled ones.
        merged_kwargs = {**kwargs, **additional_kwargs}
        return method(self, *args, *additional_args, **merged_kwargs)
    return wrapped

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()