diff --git a/.gitignore b/.gitignore index 6c2e3e6..471947b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ .idea/ +test-lang.py # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/setup.py b/setup.py index a901b41..b03ae78 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ from setuptools import setup, find_packages -VERSION = '0.10.3' +VERSION = '0.11.0' DESCRIPTION = 'Machine Learning project startup utilities' LONG_DESCRIPTION = 'My commonly used utilities for machine learning projects' @@ -13,7 +13,7 @@ description=DESCRIPTION, long_description=LONG_DESCRIPTION, url='https://github.com/StefanHeng/stef-util', - download_url='https://github.com/StefanHeng/stef-util/archive/refs/tags/0.10.3.tar.gz', + download_url='https://github.com/StefanHeng/stef-util/archive/refs/tags/v0.11.0.tar.gz', packages=find_packages(), include_package_data=True, install_requires=[ diff --git a/stefutil/__init__.py b/stefutil/__init__.py index 43bc920..7fa090b 100644 --- a/stefutil/__init__.py +++ b/stefutil/__init__.py @@ -1,9 +1,8 @@ from .built_in import * from .os_n_file import * from .primitive import * -from .prettier import * from .container import * -from .check_args import * +from .prettier import * from .concurrency import * from .function import * from .plot import * diff --git a/stefutil/check_args.py b/stefutil/check_args.py deleted file mode 100644 index ff16747..0000000 --- a/stefutil/check_args.py +++ /dev/null @@ -1,46 +0,0 @@ -""" -An easy, readable interface for checking string arguments as effectively enums - -Intended for high-level arguments instead of actual data processing as not as efficient -""" - - -from typing import List - -from stefutil.prettier import logi - - -__all__ = ['CheckArg', 'ca'] - - -class CheckArg: - """ - Raise errors when common arguments don't match the expected values - """ - - @staticmethod - def check_mismatch(display_name: str, val: str, accepted_values: List[str]): - if val not in accepted_values: - raise ValueError(f'Unexpected {logi(display_name)}: ' - f'expect one of {logi(accepted_values)}, got {logi(val)}') - - def __init__(self): - self.d_name2func = dict() - - def __call__(self, **kwargs): - for k in kwargs: - self.d_name2func[k](kwargs[k]) - - def cache_mismatch(self, display_name: str, attr_name: str, accepted_values: List[str]): - self.d_name2func[attr_name] = lambda x: CheckArg.check_mismatch(display_name, x, accepted_values) - - -ca = CheckArg() -ca.cache_mismatch( # See `stefutil::plot.py` - 'Bar Plot Orientation', attr_name='bar_orient', accepted_values=['v', 'h', 'vertical', 'horizontal'] -) - - -if __name__ == '__main__': - ori = 'v' - ca(bar_orient=ori) diff --git a/stefutil/concurrency.py b/stefutil/concurrency.py index 54e4867..80f6cb1 100644 --- a/stefutil/concurrency.py +++ b/stefutil/concurrency.py @@ -13,7 +13,7 @@ from tqdm.auto import tqdm from tqdm.contrib import concurrent as tqdm_concurrent -from stefutil.check_args import ca +from stefutil.prettier import ca __all__ = ['conc_map', 'batched_conc_map'] diff --git a/stefutil/plot.py b/stefutil/plot.py index 5f6bbfd..d32d042 100644 --- a/stefutil/plot.py +++ b/stefutil/plot.py @@ -11,7 +11,7 @@ import matplotlib.pyplot as plt import seaborn as sns -from stefutil.check_args import ca +from stefutil.prettier import ca from stefutil.container import df_col2cat_col diff --git a/stefutil/prettier.py b/stefutil/prettier.py index b2cb3b5..062d01d 100644 --- a/stefutil/prettier.py +++ b/stefutil/prettier.py @@ -26,10 +26,12 @@ __all__ = [ - 'fmt_num', 'fmt_sizeof', 'fmt_delta', 'sec2mmss', 'round_up_1digit', 'nth_sig_digit', 'now', + 'fmt_num', 'fmt_sizeof', 'fmt_delta', 'sec2mmss', 'round_up_1digit', 'nth_sig_digit', 'MyIceCreamDebugger', 'mic', 'log', 'log_s', 'logi', 'log_list', 'log_dict', 'log_dict_nc', 'log_dict_id', 'log_dict_pg', 'log_dict_p', 'hex2rgb', 'MyTheme', 'MyFormatter', 'get_logger', + 'CheckArg', 'ca', + 'now', 'MlPrettier', 'MyProgressCallback' ] @@ -93,17 +95,6 @@ def nth_sig_digit(flt: float, n: int = 1) -> float: return float('{:.{p}g}'.format(flt, p=n)) -def now(as_str=True, for_path=False) -> Union[datetime.datetime, str]: - """ - # Considering file output path - :param as_str: If true, returns string; otherwise, returns datetime object - :param for_path: If true, the string returned is formatted as intended for file system path - """ - d = datetime.datetime.now() - fmt = '%Y-%m-%d_%H-%M-%S' if for_path else '%Y-%m-%d %H:%M:%S' - return d.strftime(fmt) if as_str else d - - class MyIceCreamDebugger(IceCreamDebugger): def __init__(self, output_width: int = 120, **kwargs): self._output_width = output_width @@ -378,6 +369,7 @@ def get_logger(name: str, typ: str = 'stdout', file_path: str = None) -> logging handler.setLevel(logging.DEBUG) handler.setFormatter(MyFormatter(with_color=typ == 'stdout')) logger.addHandler(handler) + logger.propagate = False return logger @@ -520,16 +512,88 @@ def on_train_end(self, args, state, control, **kwargs): pass +class CheckArg: + """ + An easy, readable interface for checking string arguments as effectively enums + + Intended for high-level arguments instead of actual data processing as not as efficient + + Raise errors when common arguments don't match the expected values + """ + + @staticmethod + def check_mismatch(display_name: str, val: str, accepted_values: List[str]): + if val not in accepted_values: + raise ValueError(f'Unexpected {logi(display_name)}: ' + f'expect one of {logi(accepted_values)}, got {logi(val)}') + + def __init__(self): + self.d_name2func = dict() + + def __call__(self, **kwargs): + for k in kwargs: + self.d_name2func[k](kwargs[k]) + + def cache_mismatch(self, display_name: str, attr_name: str, accepted_values: List[str]): + self.d_name2func[attr_name] = lambda x: CheckArg.check_mismatch(display_name, x, accepted_values) + + +ca = CheckArg() +ca.cache_mismatch( # See `stefutil::plot.py` + 'Bar Plot Orientation', attr_name='bar_orient', accepted_values=['v', 'h', 'vertical', 'horizontal'] +) + + +def now(as_str=True, for_path=False, fmt: str = 'full') -> Union[datetime.datetime, str]: + """ + # Considering file output path + :param as_str: If true, returns string; otherwise, returns datetime object + :param for_path: If true, the string returned is formatted as intended for file system path + relevant only when as_str is True + :param fmt: One of [`full`, `date`, `short-date`] + relevant only when as_str is True + """ + d = datetime.datetime.now() + if as_str: + ca.check_mismatch('Date Format', fmt, ['full', 'date', 'short-date']) + if fmt == 'full': + fmt_tm = '%Y-%m-%d_%H-%M-%S' if for_path else '%Y-%m-%d %H:%M:%S.%f' + else: + fmt_tm = '%Y-%m-%d' + ret = d.strftime(fmt_tm) + if fmt == 'short-date': # year in 2-digits + ret = ret[2:] + return ret + else: + return d + + if __name__ == '__main__': # lg = get_logger('test') # lg.info('test') - def check_log_lst(): - lst = ['sda', 'asd'] - print(log_list(lst)) - # check_log_lst() - - def check_logi(): - d = dict(a=1, b=2) - print(logi(d)) - check_logi() + # def check_log_lst(): + # lst = ['sda', 'asd'] + # print(log_list(lst)) + # # check_log_lst() + # + # def check_logi(): + # d = dict(a=1, b=2) + # print(logi(d)) + # check_logi() + + def check_logger(): + logger = get_logger('blah') + logger.info('should appear once') + # check_logger() + + def check_now(): + mic(now(fmt='full')) + mic(now(fmt='date')) + mic(now(fmt='short-date')) + check_now() + + def check_ca(): + ori = 'v' + ca(bar_orient=ori) + # check_ca()