In [None]:
# default_exp funcs.utils

In [None]:
%load_ext autoreload
%autoreload 2

# Random assortment of helper functions

In [None]:
#export
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
from decode_fish.imports import *
from itertools import product as iter_product

import gc

In [None]:
#export
def free_mem():
    '''Free memory that is no longer in use.

    Runs the Python garbage collector and afterwards asks PyTorch to empty
    its CUDA cache. Returns nothing.
    '''
    # Order matters: collect Python garbage first, then empty the CUDA cache.
    gc.collect()
    torch.cuda.empty_cache()
    
def crop_vol(vol, fxyz_sl=np.s_[:,:,:,:], px_size=[1.,1.,1.]):
    '''Crop `vol` with slices that are given in (f, x, y, z) order.

    The four entries of `fxyz_sl` are reordered to (f, z, y, x) before they
    are used as an index. A 3-dimensional `vol` is indexed with the (z, y, x)
    part only; any other `vol` is indexed with all four entries.

    Args:
        vol: object with an `ndim` attribute that supports tuple indexing.
        fxyz_sl: indexable holding four index objects in (f, x, y, z) order.
        px_size: not used by this function.

    Returns:
        The indexed (cropped) volume.
    '''
    f_sl, x_sl, y_sl, z_sl = fxyz_sl[0], fxyz_sl[1], fxyz_sl[2], fxyz_sl[3]
    zyx_sl = (z_sl, y_sl, x_sl)

    # Purely spatial volume: there is no leading axis for the f slice.
    if vol.ndim == 3:
        return vol[zyx_sl]

    return vol[(f_sl,) + zyx_sl]

def center_crop(volume, zyx_ext):
    '''Crop the last three axes of `volume` around their centers.

    For each of the last three axes the crop covers
    `[c - floor(ext/2), c + ceil(ext/2))` with `c = size // 2`. If the
    requested extent is larger than the axis, the whole axis is returned
    (the start index is clamped at 0 so it cannot wrap around to the end of
    the axis).

    Args:
        volume: array-like with a `shape` attribute and at least three axes;
            any leading axes are kept untouched.
        zyx_ext: three crop extents, one for each of the last three axes.

    Returns:
        The cropped volume.
    '''
    shape_3d = volume.shape[-3:]

    crop_sl = []
    for axis in range(3):
        mid = shape_3d[axis] // 2
        ext = zyx_ext[axis]
        # A negative start would be read as "counted from the end" and
        # silently yield a wrong, mostly empty crop, hence the clamp.
        start = max(mid - math.floor(ext / 2), 0)
        stop = mid + math.ceil(ext / 2)
        crop_sl.append(slice(start, stop))

    return volume[(Ellipsis, *crop_sl)]

def smooth(x,window_len=11,window='flat'):
    '''Smooth a 1D signal by convolving it with a normalized window.

    The signal is extended at both ends with mirrored copies of itself
    before the convolution. Note that the result is therefore longer than
    the input: for `len(x) >= window_len` it has `len(x) + window_len - 1`
    elements.

    Args:
        x: 1D sequence (array, list or tuple) of values.
        window_len: size of the smoothing window. For values below 3 the
            input is returned unchanged.
        window: 'flat' for a moving average, or one of 'hanning', 'hamming',
            'bartlett', 'blackman' to use the NumPy window of that name.

    Returns:
        `x` itself if `window_len < 3`, otherwise the smoothed numpy array.

    Raises:
        ValueError: if `window` is not one of the supported names.
    '''
    if window_len<3:
        return x

    if window == 'flat': # moving average
        w = np.ones(window_len,'d')
    elif window in ('hanning', 'hamming', 'bartlett', 'blackman'):
        w = getattr(np, window)(window_len)
    else:
        raise ValueError("window must be one of 'flat', 'hanning', 'hamming', "
                         f"'bartlett', 'blackman', got {window!r}")

    # Mirror the signal at both ends so the window always has data to cover.
    s = np.r_[x[window_len-1:0:-1],x,x[-2:-window_len-1:-1]]

    return np.convolve(w/w.sum(),s,mode='valid')

def plot_tb_logs(exps, metric='Sim. Metrics/eff_3d', window_len=1):
    '''Plot one scalar metric from several TensorBoard logs into one figure.

    All logs are read first; plotting only starts once every log was loaded.
    Each curve is passed through `smooth` and labelled with the part of its
    path that follows the last '/'. Values are plotted against their index,
    not against the recorded step number.

    Args:
        exps: paths (strings) that are handed to `EventAccumulator`.
        metric: tag of the scalar to read from every log.
        window_len: window length forwarded to `smooth`.
    '''
    curves = []
    for log_dir in exps:
        accumulator = EventAccumulator(log_dir)
        accumulator.Reload()
        # Every scalar event unpacks into three fields; only the last one
        # (the value) is plotted.
        _, _, values = zip(*accumulator.Scalars(metric))
        curves.append(values)

    for values, log_dir in zip(curves, exps):
        smoothed = smooth(values, window_len)
        plt.plot(smoothed, label=log_dir.split('/')[-1])

    plt.legend()

In [None]:
#export
def gpu(x):
    '''Transforms a numpy array or torch tensor into a torch.cuda.FloatTensor'''
    # NOTE(review): FloatTensor is provided by the project imports and is
    # presumably torch.FloatTensor - confirm.
    float_tensor = FloatTensor(x)
    return float_tensor.cuda()

def cpu(x):
    '''Transforms a torch tensor into a numpy array.

    Tensors are moved to the CPU, detached and converted with `.numpy()`.
    Anything that is not a torch tensor is returned as it is.
    '''
    if not torch.is_tensor(x):
        return x

    on_host = x.cpu()
    return on_host.detach().numpy()

In [None]:
#export
def zip_longest_special(*iterables):
    '''Zip iterables of different lengths, padding with each one's first value.

    Works like `itertools.zip_longest`, but once an iterable is exhausted its
    position is filled with the value it contributed to the first tuple. An
    iterable that is empty from the start is filled with `None`.

    Yields nothing when called without iterables or when all of them are
    empty.

    Example:
        list(zip_longest_special([1, 2, 3], ['a']))
        -> [(1, 'a'), (2, 'a'), (3, 'a')]
    '''
    # Private marker, so that legitimate values such as None are never
    # mistaken for padding.
    missing = object()

    def fill(items, defaults):
        return tuple(d if i is missing else i for i, d in zip(items, defaults))

    rows = itertools.zip_longest(*iterables, fillvalue=missing)

    # A bare next() would raise StopIteration inside this generator, which
    # Python turns into a RuntimeError (PEP 479).
    try:
        first = next(rows)
    except StopIteration:
        return

    # Resolve markers in the first row *before* using it as defaults,
    # otherwise the marker object itself would show up in later rows.
    first = fill(first, [None] * len(first))
    yield first

    for row in rows:
        yield fill(row, first)

class param_iter(object):
    '''Collects named groups of values and expands them into parameter dicts.

    Attributes are two parallel lists: `keys[i]` is the name that belongs to
    the tuple of candidate values stored in `vals[i]`.
    '''

    def __init__(self):
        self.keys, self.vals = [], []

    def add(self, name, *args):
        '''Register the parameter `name` with the candidate values `args`.'''
        self.keys += [name]
        self.vals += [args]

    def _as_dicts(self, combos):
        '''Map every tuple in `combos` to a dict keyed by `self.keys`.'''
        return [dict(zip(self.keys, combo)) for combo in combos]

    def param_product(self):
        '''Return one dict for every combination of the registered values.'''
        return self._as_dicts(iter_product(*self.vals))

    def param_zip(self):
        '''Return one dict per position, pairing the registered values in order.

        The values are combined with `zip_longest_special`, which also decides
        how parameters with fewer values are padded.
        '''
        return self._as_dicts(zip_longest_special(*self.vals))

In [None]:
# Usage example: two parameters with two candidate values each.
variable_col = param_iter()
variable_col.add('lr', 1e-3, 5e-3)
variable_col.add('lasso_mat', True, False)

# param_product: one dict per combination of values (2 x 2 = 4 dicts).
par_prod = variable_col.param_product()
print(par_prod)

# param_zip: values are paired by position (2 dicts).
par_zip = variable_col.param_zip()
print(par_zip)

[{'lr': 0.001, 'lasso_mat': True}, {'lr': 0.001, 'lasso_mat': False}, {'lr': 0.005, 'lasso_mat': True}, {'lr': 0.005, 'lasso_mat': False}]
[{'lr': 0.001, 'lasso_mat': True}, {'lr': 0.005, 'lasso_mat': False}]


In [None]:
!nbdev_build_lib

Converted 00_models.ipynb.
Converted 01_psf.ipynb.
Converted 02_microscope.ipynb.
Converted 03_noise.ipynb.
Converted 04_pointsource.ipynb.
Converted 05_gmm_loss.ipynb.
Converted 06_plotting.ipynb.
Converted 07_file_io.ipynb.
Converted 08_dataset.ipynb.
Converted 09_output_trafo.ipynb.
Converted 10_evaluation.ipynb.
Converted 11_emitter_io.ipynb.
Converted 12_utils.ipynb.
Converted 13_train.ipynb.
Converted 15_fit_psf.ipynb.
Converted 16_visualization.ipynb.
Converted index.ipynb.
