# utils

> Utility functions to make programming easier

In [None]:
#| default_exp utils

In [None]:
#|hide
%load_ext autoreload
%autoreload 2

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import fastcore.all as fc
import random
import torch
from typing import Mapping
import matplotlib.pyplot as plt
import math
import numpy as np
from itertools import zip_longest
from datetime import timedelta
import sys,gc,traceback
import types
import inspect
import functools

  from .autonotebook import tqdm as notebook_tqdm


## Misc 

In [None]:
#|export
def set_seed(seed, deterministic=False):
    '''Make runs reproducible.

    Turns torch's deterministic-algorithms mode on or off according to `deterministic`,
    then seeds torch, Python's `random` module and numpy's global RNG with `seed`.
    '''
    torch.use_deterministic_algorithms(deterministic)
    for seeder in (torch.manual_seed, random.seed, np.random.seed):
        seeder(seed)

In [None]:
#|export
def inplace(f):
    '''Return the object passed to the function for in place mods.

    Wraps the in-place function `f`: the wrapper calls `f(b)` for its side effect on `b`,
    ignores whatever `f` returns, and returns `b` itself.
    '''
    # `functools.wraps` keeps `f`'s name/docstring on the wrapper instead of `_f`.
    @functools.wraps(f)
    def _f(b):
        f(b)
        return b
    return _f

In [None]:
#|export
def mask2idxs(mask):
    '''Return the positions of the elements of `mask` that compare equal to `True`.'''
    # Deliberately `== True` rather than truthiness: e.g. `2` is truthy but not equal to True.
    return [idx for idx, flag in enumerate(mask) if flag == True]  # noqa: E712

In [None]:
# Sanity checks: `mask2idxs` returns the indices of the True entries; `~` on a fastcore `L`
# negates it element-wise, so the complementary indices come back.
mask = fc.L(True, False, True, False, True)
fc.test_eq(mask2idxs(mask), [0, 2, 4])
fc.test_eq(mask2idxs(~mask), [1,3])

In [None]:
#| export
class PPDict(dict):
    '''A `dict` with a "pretty" `str`.

    Floats are rounded to 4 decimal places and `timedelta`s are rounded *up* to whole
    seconds; every other value is shown as-is. `repr` is the normal `dict` repr.
    '''
    def __str__(self):
        def _pretty(val):
            if isinstance(val, float): return round(val, 4)
            if isinstance(val, timedelta):
                whole_secs = math.ceil(val.total_seconds())
                return str(timedelta(seconds=whole_secs))
            return val
        return str({key: _pretty(val) for key, val in self.items()})

In [None]:
#| export
def retrieve_global_name(var, depth=2):
    '''Return the names bound (by identity, `is`) to `var` in the namespace of a calling frame.

    `depth` counts frames upward from this function: `depth=1` is the direct caller and
    the default `depth=2` is the caller's caller. So with the default, call it from inside a
    helper that received `var` as an argument, to learn what the helper's caller named it
    (at notebook top level that namespace is the globals). Returns `[]` if the stack is
    shallower than `depth` or frames are unavailable.
    '''
    frame = inspect.currentframe()
    try:
        for _ in range(depth):
            if frame is None: break
            frame = frame.f_back
        if frame is None: return []
        return [var_name for var_name, var_val in frame.f_locals.items() if var_val is var]
    finally:
        # Drop the frame reference so this frame <-> local cycle does not linger
        # (recommended by the `inspect` docs).
        del frame

## Memory

> Much of this code is taken from fastai's miniai library

In [None]:
#| export
def clean_ipython_hist():
    '''Blank IPython's input history and the `_i`, `_ii`, `_iii`, `_i<n>` input-cache variables
    so the objects they reference can be garbage collected. Does nothing outside IPython.'''
    # Code in this function mainly copied from IPython source
    # Look the name up instead of probing `globals()`. Once this function is exported into a
    # module, `globals()` is that module's namespace, which never holds `get_ipython`, so
    # the old check made the function a silent no-op. A plain lookup also falls through to
    # builtins, where IPython exposes `get_ipython` while code runs in the shell.
    try: ip = get_ipython()
    except NameError: return
    if ip is None: return
    user_ns = ip.user_ns
    ip.displayhook.flush()
    pc = ip.displayhook.prompt_count + 1
    for n in range(1, pc): user_ns.pop('_i'+repr(n),None)
    user_ns.update(dict(_i='',_ii='',_iii=''))
    hm = ip.history_manager
    hm.input_hist_parsed[:] = [''] * pc
    hm.input_hist_raw[:] = [''] * pc
    hm._i = hm._ii = hm._iii = hm._i00 =  ''

In [None]:
#| export
def clean_traceback():
    '''Release memory held by the last unhandled exception.

    Its traceback keeps every frame, and every object those frames reference (even CUDA
    tensors), alive. This clears those frames and drops the `sys.last_*` references.
    '''
    # h/t Piotr Czapla
    if hasattr(sys, 'last_traceback'):
        traceback.clear_frames(sys.last_traceback)
        delattr(sys, 'last_traceback')
    # Python 3.12+ also stores the last unhandled exception in `sys.last_exc`; its
    # `__traceback__` references the same frames, so it must be cleared too.
    if hasattr(sys, 'last_exc'):
        tb = getattr(sys.last_exc, '__traceback__', None)
        if tb is not None: traceback.clear_frames(tb)
        delattr(sys, 'last_exc')
    for attr in ('last_type', 'last_value'):
        if hasattr(sys, attr): delattr(sys, attr)

In [None]:
#| export
def clean_memory():
    '''Cleans all memory from hist and tracebacks.

    Clears the last traceback and IPython's input history, runs the garbage collector,
    then returns unoccupied cached accelerator memory (CUDA, and MPS when available).
    '''
    clean_traceback()
    clean_ipython_hist()
    gc.collect()
    torch.cuda.empty_cache()
    # Also release the Apple-GPU (MPS) allocator cache when that backend is in use;
    # guarded because `torch.mps.empty_cache` only exists in newer torch versions.
    mps = getattr(torch, 'mps', None)
    if torch.backends.mps.is_available() and hasattr(mps, 'empty_cache'): mps.empty_cache()

## Device Management

In [None]:
#| export
def_device = 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'

def to_device(x, device=def_device):
    '''Recursively move `x` to `device`.

    Tensors are moved with `.to(device)`. Mappings become plain `dict`s with each value
    moved recursively. Strings, bytes and other non-iterable values are returned unchanged.
    Any other iterable is rebuilt as `type(x)(...)` from its recursively moved elements.
    '''
    if isinstance(x, torch.Tensor): return x.to(device)
    # Recurse into values so nested containers and non-tensor values are handled.
    if isinstance(x, Mapping): return {k: to_device(v, device) for k, v in x.items()}
    # Non-tensor leaves pass through. Strings must be excluded explicitly: iterating a
    # 1-char string yields itself, which would recurse forever.
    if isinstance(x, (str, bytes)) or not hasattr(x, '__iter__'): return x
    return type(x)(to_device(o, device) for o in x)

def to_cpu(x):
    '''Recursively detach tensors in `x` and move them to the CPU.

    Mappings become plain `dict`s, lists stay lists and tuples stay tuples (elements are
    converted recursively). Anything exposing both `detach` and `cpu` (e.g. a tensor) is
    detached and moved; every other value (ints, None, strings, ...) is returned unchanged.
    '''
    if isinstance(x, Mapping): return {k:to_cpu(v) for k,v in x.items()}
    if isinstance(x, list): return [to_cpu(o) for o in x]
    if isinstance(x, tuple): return tuple(to_cpu(list(x)))
    if hasattr(x, 'detach') and hasattr(x, 'cpu'): return x.detach().cpu()
    return x

## Matplotlib Helpers

> Much of this code is taken from fastai's miniai library

In [None]:
#| export
@fc.delegates(plt.Axes.imshow)
def show_image(im, ax=None, figsize=None, title=None, noframe=True, **kwargs):
    '''Show a PIL image, numpy array or PyTorch tensor on `ax` and return `ax`.

    + tensors are detached and moved to the CPU; a 3-d tensor whose first dim is < 5 is
      permuted from channels-first to channels-last
    + other non-ndarray inputs (e.g. PIL images) are converted with `np.array`
    + a trailing channel dim of size 1 is dropped
    + a new figure of `figsize` is created when `ax` is None
    + ticks are removed; with `noframe` the whole axis is hidden as well
    Extra `kwargs` are forwarded to `ax.imshow`.
    '''
    is_tensor = fc.hasattrs(im, ('cpu','permute','detach'))
    if is_tensor:
        im = im.detach().cpu()
        channels_first = len(im.shape) == 3 and im.shape[0] < 5
        if channels_first: im = im.permute(1, 2, 0)
    elif not isinstance(im, np.ndarray):
        im = np.array(im)
    if im.shape[-1] == 1: im = im[..., 0]
    if ax is None: ax = plt.subplots(figsize=figsize)[1]
    ax.imshow(im, **kwargs)
    if title is not None: ax.set_title(title)
    for clear_ticks in (ax.set_xticks, ax.set_yticks): clear_ticks([])
    if noframe: ax.axis('off')
    return ax

In [None]:
#| export
@fc.delegates(plt.subplots, keep=True)
def subplots(
    nrows:int=1, # Rows in the returned axes grid
    ncols:int=1, # Columns in the returned axes grid
    figsize:tuple=None, # (width, height) of the figure in inches; derived from `imsize` if None
    imsize:int=3, # Inches per displayed image, used when `figsize` is None
    suptitle:str=None, # Optional figure-level title
    **kwargs
): # fig and axs
    "A figure and set of subplots to display images of `imsize` inches"
    size = figsize if figsize is not None else (ncols*imsize, nrows*imsize)
    fig, axs = plt.subplots(nrows, ncols, figsize=size, **kwargs)
    if suptitle is not None: fig.suptitle(suptitle)
    # A 1x1 grid comes back from matplotlib (with its default squeeze) as a bare Axes;
    # wrap it in an array so callers can always use array ops such as `.flat`.
    return fig, (np.array([axs]) if nrows*ncols == 1 else axs)

In [None]:
#| export
@fc.delegates(subplots)
def get_grid(
    n:int, # Number of axes
    nrows:int=None, # Number of rows, defaulting to `int(math.sqrt(n))` (at least 1)
    ncols:int=None, # Number of columns, defaulting to `ceil(n/nrows)` (at least 1)
    title:str=None, # If passed, title set to the figure
    weight:str='bold', # Title font weight
    size:int=14, # Title font size
    **kwargs,
): # fig and axs
    "Return an `nrows` by `ncols` grid of axes for `n` items; axes beyond the first `n` are hidden"
    # Round the missing dimension *up* so the grid has room for all `n` axes; rounding down
    # produced too few axes whenever `n` was not a multiple of the other side.
    # `max(1, ...)` keeps both sides positive when `n == 0`.
    if nrows: ncols = ncols or max(1, int(np.ceil(n/nrows)))
    elif ncols: nrows = max(1, int(np.ceil(n/ncols)))
    else:
        nrows = max(1, int(math.sqrt(n)))
        ncols = max(1, int(np.ceil(n/nrows)))
    fig,axs = subplots(nrows, ncols, **kwargs)
    for i in range(n, nrows*ncols): axs.flat[i].set_axis_off()
    if title is not None: fig.suptitle(title, weight=weight, size=size)
    return fig,axs

In [None]:
#| export
@fc.delegates(subplots)
def show_images(ims:list, # Images to show
                nrows:int|None=None, # Number of rows in grid
                ncols:int|None=None, # Number of columns in grid (auto-calculated if None)
                titles:list|None=None, # Optional list of titles for each image
                **kwargs):
    "Show all images `ims` in an `nrows` x `ncols` grid, labelled with `titles`"
    axs = list(get_grid(len(ims), nrows, ncols, **kwargs)[1].flat)
    titles = list(titles or [])
    # Iterate over the images only. Zipping to the longest of (ims, titles, axs) produced
    # `im=None` whenever the grid or the titles outnumbered the images, and
    # `show_image(None, ...)` crashes.
    for i, im in enumerate(ims):
        ax = axs[i] if i < len(axs) else None  # grid too small: `show_image` opens a new figure
        t = titles[i] if i < len(titles) else None
        show_image(im, ax=ax, title=t)

# Callbacks

In [None]:
#| export
def run_callbacks(callbacks, method_name, trainer=None):
    '''Call `method_name(trainer)` on each callback that defines it.

    Callbacks run in ascending `order` attribute (missing `order` counts as 0); the sort is
    stable, so ties keep their original order. Callbacks lacking the method are skipped.
    '''
    ordered = sorted(callbacks, key=lambda cb: getattr(cb, 'order', 0))
    for cb in ordered:
        method = getattr(cb, method_name, None)
        if method is None: continue
        method(trainer)

In [None]:
#| export
def add_callback(trainer,callback,force=False):
    '''Register `callback` on `trainer` under its class name.

    Ensures `trainer.callbacks` exists (an empty `fc.L` if missing), appends the class name
    to it and stores the instance as `trainer.<ClassName>`. If that name is already
    registered, `force=True` replaces it (via `remove_callback(..., delete=True)`);
    otherwise an `AssertionError` is raised. If the callback has its own `callbacks`
    attribute, those are registered through `trainer.add_callbacks`. `None` is ignored.
    '''
    trainer.callbacks = getattr(trainer,'callbacks',fc.L())
    if callback is None: return None
    cb_name = callback.__class__.__name__

    if cb_name in trainer.callbacks:
        # Explicit raise instead of `assert`, so the duplicate check survives `python -O`.
        # The type stays AssertionError for callers that already catch it.
        if not force: raise AssertionError(f'Callback {cb_name!r} is already registered; pass force=True to replace it')
        remove_callback(trainer,callback,True)

    trainer.callbacks.append(cb_name)
    setattr(trainer,cb_name,callback)

    _cb = getattr(trainer,cb_name)
    # NOTE(review): nested callbacks go through the trainer's own `add_callbacks` method,
    # presumably a wrapper around this module's function; confirm on the trainer class.
    if hasattr(_cb,'callbacks'): trainer.add_callbacks(getattr(_cb,'callbacks'))

    
def add_callbacks(trainer,callbacks,force=False):
    '''Register every callback in `callbacks` on `trainer` with `add_callback`,
    forwarding `force`. `trainer.callbacks` is initialised first (an empty `fc.L` if
    missing), so it exists even when `callbacks` is empty.'''
    existing = getattr(trainer, 'callbacks', fc.L())
    trainer.callbacks = existing
    for cb in callbacks:
        add_callback(trainer, cb, force)

In [None]:
#|export
def remove_callback(trainer,callback,delete=False):
    '''Unregister `callback` from `trainer.callbacks`.

    `callback` may be a callback instance (identified by its class name, which is how
    callbacks are registered) or that registered name as a string. With `delete=True` the
    `trainer.<name>` attribute holding the instance is removed too.
    '''
    # A str can never be a callback instance here (its class name would be 'str'), so
    # accepting names is a backward-compatible extension.
    cb_name = callback if isinstance(callback, str) else callback.__class__.__name__
    trainer.callbacks.remove(cb_name)
    if delete: delattr(trainer,cb_name)
    
def remove_callbacks(trainer,callbacks,delete=False):
    '''Unregister each entry of `callbacks` from `trainer` with `remove_callback`,
    forwarding `delete`.'''
    for cb in callbacks:
        remove_callback(trainer, cb, delete)

In [None]:
#| export
class with_cbs:
    '''Decorator wrapping a method `f(o, ...)` in callback events named after `nm`.

    Calls `o.run_callbacks('before_<nm>')`, then `f`, then `o.run_callbacks('after_<nm>')`.
    An `exception` raised anywhere in that sequence is swallowed (presumably a
    cancellation signal); `o.run_callbacks('cleanup_<nm>')` always runs last.
    The wrapper returns None.
    '''
    def __init__(self, nm, exception): fc.store_attr()
    def __call__(self, f):
        # `functools.wraps` keeps the decorated method's name/docstring visible
        # (help(), docs tooling, tracebacks) instead of showing `_f`.
        @functools.wraps(f)
        def _f(o, *args, **kwargs):
            try:
                o.run_callbacks(f'before_{self.nm}')
                f(o, *args, **kwargs)
                o.run_callbacks(f'after_{self.nm}')
            except self.exception: pass
            finally: o.run_callbacks(f'cleanup_{self.nm}')
        return _f

# Hooks

In [None]:
#| export
class Hook:
    '''Register `func` as a PyTorch forward hook on `module`.

    `func` is bound (via `fc.bind`) with this `Hook` as its first argument, so PyTorch ends
    up calling `func(hook, module, inputs, output)`, presumably so `func` can stash
    results on the hook object. The hook is removed by `remove()` or on garbage collection.
    '''
    def __init__(self,module,func): self.hook = module.register_forward_hook(fc.bind(func,self))
    def __del__(self): self.remove()
    def remove(self):
        '''Detach the hook from its module. Does nothing if `self.hook` was never set
        (e.g. `__init__` raised), so `__del__` on a half-built object does not error.'''
        handle = getattr(self, 'hook', None)
        if handle is not None: handle.remove()

In [None]:
#| export
class Hooks(list):
    '''A list holding one `Hook(module, func)` per module in `modules`.

    Usable as a context manager; all hooks are removed when the `with` block exits
    (and when the object is garbage collected).
    '''
    def __init__(self,modules, func): super().__init__([Hook(module,func) for module in modules])
    def __enter__(self): return self
    # The context-manager protocol passes (exc_type, exc_value, traceback); accept and
    # ignore them. Returning None lets any exception from the `with` body propagate.
    def __exit__(self, *args): self.remove()
    def __del__(self): self.remove()
    def remove(self):
        '''Remove every hook in this list from its module.'''
        for hook in self: hook.remove()

In [None]:
#| hide
# Export the `#| export` cells of the notebooks to their Python modules via nbdev.
import nbdev; nbdev.nbdev_export()