# Core

> Some basic functions and classes.

In [None]:
#| default_exp core

In [None]:
#| hide

from nbdev.showdoc import *
from fastcore.test import *

%nbdev_skip_test
%matplotlib inline
%load_ext autoreload
%autoreload 2
%reload_ext autoreload

UsageError: Line magic function `%nbdev_skip_test` not found.


In [None]:
#| export

from dreamai.imports import *

In [None]:
#| export

def flatten_list(l):
    "Recursively flatten nested lists inside `l` into one flat list."
    flat = []
    for item in l:
        # Nested lists are expanded in place; every other value is kept as-is.
        flat.extend(flatten_list(item) if is_list(item) else [item])
    return flat

def noop(x=None, **kwargs):
    "Return `x` untouched; any keyword arguments are accepted and ignored."
    return x

def is_list(x):
    "Check whether `x` is a `list` (subclasses included)."
    return isinstance(x, list)

def is_tuple(x):
    "Check whether `x` is a `tuple` (subclasses included)."
    return isinstance(x, tuple)

def list_or_tuple(x):
    "Check whether `x` is either a `list` or a `tuple`."
    if is_list(x):
        return True
    return is_tuple(x)

def is_iter(o):
    "Test whether `o` can be used in a `for` loop. Always returns a `bool`."
    if not isinstance(o, (Iterable, Generator)):
        return False
    # Rank 0 tensors in PyTorch are not really iterable: they report `ndim == 0`.
    # Objects without an `ndim` attribute default to 1, i.e. iterable.
    # `bool(...)` keeps the result a real boolean instead of leaking the rank itself.
    return bool(getattr(o, 'ndim', 1))

def is_dict(x):
    "Check whether `x` is a `dict` (subclasses included)."
    return isinstance(x, dict)

def is_df(x):
    "Check whether `x` is a pandas `DataFrame`."
    # `pd.DataFrame` is the public alias of `pd.core.frame.DataFrame`.
    return isinstance(x, pd.DataFrame)

def is_str(x):
    "Check whether `x` is a `str`."
    return isinstance(x, str)

def is_int(x):
    "Check whether `x` is an `int`. Note that `bool` values also count, since `bool` subclasses `int`."
    return isinstance(x, int)

def is_float(x):
    "Check whether `x` is a Python `float`."
    return isinstance(x, float)

def is_array(x):
    "Check whether `x` is a NumPy `ndarray`."
    return isinstance(x, np.ndarray)

def is_pilimage(x):
    "Check whether the type of `x` mentions `PIL` in its string form."
    # Matching on the type's text avoids needing PIL imported here.
    type_text = str(type(x))
    return 'PIL' in type_text

def is_tensor(x):
    "Check whether `x` is a `torch.Tensor`."
    return isinstance(x, torch.Tensor)

def is_set(x):
    "Check whether `x` is a `set`. A `frozenset` does not count."
    return isinstance(x, set)

def is_path(x):
    "Check whether `x` is a `pathlib.Path` object. Plain strings do not count."
    return isinstance(x, Path)

def path_or_str(x):
    "Check whether `x` is a string or a `Path`."
    if is_str(x):
        return True
    return is_path(x)

def is_norm(x):
    "Check whether the class of `x` is named `Normalize`."
    # Only the class name is compared, so no particular library needs to be imported.
    class_name = type(x).__name__
    return class_name == 'Normalize'

def params(m):
    "Collect everything yielded by `m.parameters()` into a list."
    return list(m.parameters())

def is_frozen(model):
    "Check that no parameter of `model` has `requires_grad` set. Returns a NumPy boolean."
    frozen_flags = [not p.requires_grad for p in params(model)]
    # `.all()` on an empty array is True, so a model without parameters counts as frozen.
    return np.array(frozen_flags).all()

def is_unfrozen(model):
    "Check that every parameter of `model` has `requires_grad` set. Returns a NumPy boolean."
    trainable_flags = [p.requires_grad for p in params(model)]
    # `.all()` on an empty array is True, so a model without parameters counts as unfrozen too.
    return np.array(trainable_flags).all()

def is_subscriptable(x):
    "Check whether `x` exposes `__getitem__`, i.e. supports `x[...]`."
    return hasattr(x, '__getitem__')

def is_sequential(x):
    "Check whether `x` is an `nn.Sequential` container."
    return isinstance(x, nn.Sequential)

def is_clip(x):
    "Check whether `x` is a `ProntoClip` (by class name) or an object whose type mentions `moviepy`."
    kind = type(x)
    if kind.__name__ == 'ProntoClip':
        return True
    # Fall back to a textual match so moviepy itself does not have to be imported.
    return 'moviepy' in str(kind)

def path_name(x):
    "Return the final component of path `x`, extension included."
    as_path = Path(x)
    return as_path.name

def path_stem(x):
    "Return the final component of path `x` without its last suffix."
    as_path = Path(x)
    return as_path.stem

def extend_path_name(p, s='_2'):
    "Insert `s` between the stem and the last suffix of path `p`; the parent folder is kept."
    original = Path(p)
    new_name = f'{original.stem}{s}{original.suffix}'
    return original.parent / new_name

def end_of_path(p, n=2):
    "Get the last `n` parts of a path `p` (a `str` or a `Path`) as a `Path`."
    # Accept strings too, like the other path helpers in this module.
    parts = Path(p).parts
    # "The last zero parts" is the empty path; `parts[-0:]` would wrongly mean "everything".
    if n <= 0:
        return Path()
    # Slicing clamps on its own, so asking for more parts than exist returns the whole path.
    return Path(*parts[-n:])

def add_ext_to_path(p, ext='pkl'):
    "Add an extension to a path `p` if it doesn't have one."
    # `ext` may be given with or without the leading dot.
    # Returns `p` unchanged (same type) when it already has an extension or `ext` is empty;
    # otherwise returns a `str` with the extension appended.
    if not ext:
        return p
    if not ext.startswith('.'):
        ext = '.' + ext
    # A lone trailing dot is treated as "no extension"; depending on the Python
    # version, pathlib may report it as the suffix '.' rather than ''.
    if Path(p).suffix not in ('', '.'):
        return p
    p = str(p)
    # Avoid producing a double dot when `p` already ends with one ('f1.' -> 'f1.png').
    if p.endswith('.'):
        return p + ext[1:]
    return p + ext

def last_modified(x):
    "Get the last modified time of a file `x` (a `str` or a `Path`) as a POSIX timestamp in seconds."
    # `st_mtime` is the content-modification time. `st_ctime`, used previously, is the
    # metadata-change time on Unix and the creation time on Windows.
    return Path(x).stat().st_mtime

def load_yaml(file):
    "Parse the YAML `file` and return its contents."
    # NOTE(review): `load`/`Loader` come from the star import; if `Loader` is PyYAML's
    # full loader it can construct arbitrary Python objects, so only use this on trusted files.
    with open(file) as stream:
        return load(stream, Loader=Loader)

def save_obj(path, obj):
    "Pickle `obj` to the file at `path` using the highest available protocol."
    with open(path, 'wb') as sink:
        pickle.dump(obj, sink, protocol=pickle.HIGHEST_PROTOCOL)

def load_obj(path):
    "Load and return the pickled object stored at `path`."
    # Unpickling can execute arbitrary code; only load files from trusted sources.
    with open(path, 'rb') as source:
        obj = pickle.load(source)
    return obj

def yml_to_pip(yml):
    "Get pip packages from a conda environment `yml` file as one space-separated string, with `==` pins relaxed to `>=`."
    env = load_yaml(yml) or {}
    pip_pkgs = []
    # The pip section is the dependency entry that is a mapping with a `pip` key.
    # Search for it instead of assuming it is the last entry, so files without a
    # pip section (or with it elsewhere in the list) yield a result instead of crashing.
    for dep in env.get('dependencies') or []:
        if isinstance(dep, dict) and 'pip' in dep:
            pip_pkgs.extend(dep['pip'] or [])
    return " ".join(pip_pkgs).replace('==', '>=')

def merge_dicts(d1,d2):
    "Return a new dict with the items of `d1` and `d2`; values from `d2` win on shared keys."
    return {**d1, **d2}

def dict_values(d):
    "Return the values of dictionary `d`, ordered by their sorted keys."
    return [d[key] for key in sorted(d)]

def dict_keys(d):
    "Return the keys of dictionary `d` as a sorted list."
    return sorted(d)

def sort_dict(d, by_value=False):
    "Return a new dict ordered by key, or by value when `by_value` is True."
    # Position inside each (key, value) pair to sort on: 0 -> key, 1 -> value.
    position = int(by_value)
    ordered_items = sorted(d.items(), key=lambda pair: pair[position])
    return dict(ordered_items)

def locals_to_params(l, omit=[], expand=['kwargs']):
    "Convert all the local variables to a dictionary of parameters."
    if 'kwargs' not in expand:
        expand.append('kwargs')
    l = copy.deepcopy(l)
    if 'self' in l.keys():
        del l['self']
    if '__class__' in l.keys():
        del l['__class__']
    keys = dict_keys(l)
    for k in keys:
        if k in expand:
            for k2 in l[k]:
                if k2 not in l.keys():
                    l[k2] = l[k][k2]
            del l[k]
        if k in omit:
            del l[k]
    return l

def list_map(l, m):
    "Apply `m` to each element of `l` and return the results as a list."
    # A plain comprehension hands every element to `m` unchanged. Routing the data
    # through a pandas Series coerces types first (e.g. `None` -> NaN, ints -> floats).
    return [m(x) for x in l]

def next_batch(dl):
    "Return the first item produced by a fresh iterator over dataloader `dl`."
    batches = iter(dl)
    return next(batches)

def model_children(model):
    "Collect everything yielded by `model.children()` into a list."
    children = model.children()
    return list(children)

def replace_dict_key(d:dict, x='', y='',
                     strict=False): # If True, replace if `x` == key. If False, replace if `x` in key.
    "Replace key `x` with `y` in dictionary `d`."
    # Returns `d` itself when there is nothing to replace, otherwise a new dict.
    # Replacing something with itself is a no-op (this also covers x == y == '').
    if x == y:
        return d

    def rename(k):
        if strict:
            # Whole-key match: the key becomes `y`.
            return y if k == x else k
        # Substring replacement only applies to string keys; others (ints, tuples, ...)
        # are kept as they are instead of raising.
        if isinstance(k, str) and isinstance(x, str) and x in k:
            return k.replace(x, y)
        return k

    return {rename(k): v for k, v in d.items()}

def proc_fn(fn):
    "Turn `fn` into a key predicate: a string becomes an exact-match test, anything else is returned unchanged."
    if not is_str(fn):
        return fn
    # Capture the string under its own name so the predicate keeps comparing against it.
    t = copy.deepcopy(fn)
    return lambda x: x==t

def filter_dict(d,
                fn): # Can be a string to match or a function that takes a key and returns True/False.
    "Keep only the entries of `d` whose key is accepted by `fn`; the result is ordered by sorted key."
    keep = proc_fn(fn)
    return {key: d[key] for key in dict_keys(d) if keep(key)}

def setify(o):
    "Turn `o` into a set. An existing set is returned as is; a single string (or bytes) counts as one element."
    if isinstance(o, set):
        return o
    # Without this, '.png' would become {'.', 'p', 'n', 'g'} and never match anything.
    if isinstance(o, (str, bytes)):
        return {o}
    return set(o)

def _get_files(p, fs, extensions=None):
    p = Path(p)
    res = [p/f for f in fs if not f.startswith('.')
           and ((not extensions) or f'.{f.split(".")[-1].lower()}' in extensions)]
    return res

def get_files(path, extensions=None, recurse=True, folders=None, followlinks=True, make_str=False):
    "Get all the files in `path` with optional `extensions`, optionally with `recurse`, only in `folders`, if specified."
    # `extensions` : one extension or a collection of them, with or without the leading dot, any case.
    # `folders`    : one name or a collection of top-level sub-folder names to restrict a recursive
    #                search to; include '.' to also keep the files directly inside `path`.
    # `make_str`   : return plain strings instead of `Path` objects.
    # Hidden files and, below the top level, hidden folders (names starting with '.') are skipped.

    # A bare string means one folder / one extension, not a collection of characters.
    if folders is None:
        folders = []
    elif isinstance(folders, str):
        folders = [folders]
    path = Path(path)
    if extensions is not None:
        if isinstance(extensions, str):
            extensions = [extensions]
        # `_get_files` compares against lower-cased extensions that include the leading dot,
        # so normalise 'py' / '.PY' to '.py'. Empty strings are left untouched.
        extensions = {e.lower() if (not e or e.startswith('.')) else f'.{e.lower()}'
                      for e in setify(extensions)}
    if recurse:
        res = []
        # os.walk yields (dirpath, dirnames, filenames); editing `dirnames` in place prunes the walk.
        for i, (dirpath, dirnames, filenames) in enumerate(os.walk(path, followlinks=followlinks)):
            restrict = len(folders) != 0 and i == 0
            if restrict:
                dirnames[:] = [o for o in dirnames if o in folders]
            else:
                dirnames[:] = [o for o in dirnames if not o.startswith('.')]
            # With a folder restriction, top-level files are only kept when '.' was requested.
            if restrict and '.' not in folders:
                continue
            res += _get_files(dirpath, filenames, extensions)
    else:
        # Close the scandir iterator deterministically instead of leaving it to the GC.
        with os.scandir(path) as entries:
            names = [o.name for o in entries if o.is_file()]
        res = _get_files(path, names, extensions)
    if make_str:
        res = [str(o) for o in res]
    return list(res)

## Some usage examples:

Get all the files in a directory.

In [None]:
path = '../'
files = get_files(path)
files[-11:-8]

[Path('../dreamai/vision.py'),
 Path('../dreamai/core.py'),
 Path('../dreamai/imports.py')]

Only get the files from the `nbs` folder.

In [None]:
nb_files = get_files(path, folders=['nbs'])
nb_files

[Path('../nbs/index.ipynb'),
 Path('../nbs/nbdev.yml'),
 Path('../nbs/sidebar.yml'),
 Path('../nbs/01_vision.ipynb'),
 Path('../nbs/_quarto.yml'),
 Path('../nbs/styles.css'),
 Path('../nbs/00_core.ipynb')]

Checking a file path. All tests should pass.

In [None]:
file = str(nb_files[-1])
test_eq(is_str(file), True)
test_eq(is_path(file), False)
test_eq(path_or_str(file), True)
file

'../nbs/00_core.ipynb'

In [None]:
file = nb_files[-1]
test_eq(is_str(file), False)
test_eq(is_path(file), True)
test_eq(path_or_str(file), True)
file

Path('../nbs/00_core.ipynb')

Checking a numpy array.

In [None]:
x = np.array([1,2,3])
test_eq(is_array(x), True)
test_eq(is_list(x), False)
test_eq(is_iter(x), True)

Add an extension to a file path if it doesn't already have one.

In [None]:
l = ['f1', 'f2.jpeg', 'f3']
ext = '.png'
l2 = list_map(l, partial(add_ext_to_path, ext=ext))
print(l)
print(l2)

['f1', 'f2.jpeg', 'f3']
['f1.png', 'f2.jpeg', 'f3.png']


Some dictionary examples.

In [None]:
d = {'apple':1, 'apple_pie':2, 'cake':3}

In [None]:
# Sort dict.

print(sort_dict(d)) # by key
print(sort_dict(d, by_value=True)) # by value

{'apple': 1, 'apple_pie': 2, 'cake': 3}
{'apple': 1, 'apple_pie': 2, 'cake': 3}


In [None]:
# Replace 'apple' with 'pumpkin'
print(replace_dict_key(d, x='apple', y='pumpkin', strict=True))

# Replace all instances of 'apple' with 'pumpkin'
print(replace_dict_key(d, x='apple', y='pumpkin', strict=False))

{'pumpkin': 1, 'apple_pie': 2, 'cake': 3}
{'pumpkin': 1, 'pumpkin_pie': 2, 'cake': 3}


In [None]:
# Remove keys that don't have 'apple' in them.
d2 = filter_dict(d, fn=lambda x: 'apple' in x)
d2

{'apple': 1, 'apple_pie': 2}

In [None]:
#| hide

test_eq(dict_keys(d2), ['apple', 'apple_pie'])

Flatten a list.

In [None]:
l = [[1, 2, 3], [4, [5, 6]], [7], [8, 9], 10]
l2 = flatten_list(l)
l2

[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

In [None]:
test_eq(l2, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

Get the local variables/parameters in a function.

In [None]:
def fn2(z=30, **kwargs):
    "Demo callee: print the parameters it received and return `z`."
    # `lp` holds `z` together with everything that arrived through `kwargs`.
    lp = locals_to_params(locals())
    print(f'fn2 local params: {lp}')
    z_value = lp['z']
    return z_value

def fn(x=10, y=20, **kwargs):
    "Demo caller: print its own parameters, forward all of them to `fn2` and add the result to `x + y`."
    # `lp` holds `x`, `y` and everything that arrived through `kwargs`.
    lp = locals_to_params(locals())
    print(f'fn local params: {lp}')
    forwarded = fn2(**lp)
    return x + y + forwarded

In [None]:
res = fn()  # fn2 will receive `x=10`, `y=20` and `z=30` by default.
print(f'Result: {res}')

fn local params: {'x': 10, 'y': 20}
fn2 local params: {'z': 30, 'x': 10, 'y': 20}
Result: 60


In [None]:
# Use a dedicated name: rebinding `params` here would shadow the `params()` helper
# in the notebook namespace and break `is_frozen` / `is_unfrozen` for the rest of the session.
fn_kwargs = {'x': 50, 'z': 300} # fn2 will receive `x=50`, `y=20`, `z=300`.
res = fn(**fn_kwargs)
print(f'Result: {res}')

fn local params: {'x': 50, 'y': 20, 'z': 300}
fn2 local params: {'z': 300, 'x': 50, 'y': 20}
Result: 370


In [None]:
#| hide
import nbdev; nbdev.nbdev_export()