In [None]:
#| default_exp utils

In [None]:
#| export 
import pydoc
import PIL 
import SimpleITK as sitk
import numpy as np
from PIL import Image
from functools import partial
from collections import OrderedDict
from loguru import logger

In [None]:
#| export 
def image_grid(imgs, rows, cols):
    """Tile ``imgs`` row-major onto a single ``rows`` x ``cols`` RGB canvas.

    Every cell is sized from the first image (``imgs[0].size``), so all images
    are expected to share that size. Image ``i`` goes to row ``i // cols`` and
    column ``i % cols``. Images past ``rows * cols`` land outside the canvas
    (presumably clipped by ``Image.paste`` — NOTE(review): confirm intended).
    """
    cell_w, cell_h = imgs[0].size
    canvas = Image.new('RGB', size=(cols * cell_w, rows * cell_h))
    for idx, tile in enumerate(imgs):
        row, col = divmod(idx, cols)
        canvas.paste(tile, box=(col * cell_w, row * cell_h))
    return canvas

In [None]:
#| export 
def thumbnail(img, size=256):
    """Resize ``img`` to ``size`` pixels wide, keeping its aspect ratio.

    ``img`` may be a PIL image or anything ``Image.fromarray`` accepts
    (e.g. a 2-D uint8 numpy array). The new height is ``int(size * h / w)``,
    i.e. truncated toward zero.
    """
    pil_img = img if isinstance(img, PIL.Image.Image) else Image.fromarray(img)
    width, height = pil_img.size
    new_height = int(size * (height / width))
    return pil_img.resize((size, new_height))

In [None]:
#| export 
def windowing(ww, wl):
    """Build a function that maps intensities through a display window onto [0, 1].

    Parameters
    ----------
    ww : float
        Window width; must be positive.
    wl : float
        Window level (the centre of the window).

    Returns
    -------
    callable
        ``f(img)`` which linearly maps ``[wl - ww/2, wl + ww/2]`` to ``[0, 1]``
        and clips values outside that range. ``img`` may be a numpy array
        (subclasses are preserved) or any array-like; the input is not modified.

    Raises
    ------
    ValueError
        If ``ww <= 0`` (a zero width would divide by zero; a negative one
        would silently invert the image).
    """
    if ww <= 0:
        raise ValueError(f"window width must be positive, got {ww}")
    low = wl - ww / 2
    high = wl + ww / 2

    def _window(img):
        # The subtraction already allocates a new array, so no explicit copy is
        # needed to keep the caller's data intact (saves a full volume of memory).
        scaled = (np.asanyarray(img) - low) / (high - low)
        return np.clip(scaled, 0, 1)

    return _window

In [None]:
#| export
# Lung display window: width 1600 centred at level -600, so values in
# [-1400, 200] (HU, per the name) are stretched to [0, 1] and the rest clipped.
hu_to_lung_window = windowing(ww=1600, wl=-600)

In [None]:
#| export 
def vis(img, size, window=True, seed=42, rows=5, cols=5):
    """Render a ``rows`` x ``cols`` grid of randomly chosen slices of a volume.

    Parameters
    ----------
    img : numpy.ndarray
        Volume indexed by slice along axis 0 (presumably (slices, H, W) in HU —
        confirm against callers).
    size : int
        Width in pixels of each thumbnail.
    window : bool
        If True, apply the lung window and convert to uint8 first. If False,
        ``img`` must already be something ``Image.fromarray`` accepts.
    seed : int
        Seed for slice selection. A private RNG is used, so numpy's global
        random state is left untouched.
    rows, cols : int
        Grid shape; ``rows * cols`` slice indices are drawn.

    Returns
    -------
    PIL.Image.Image
        The tiled grid.

    Notes
    -----
    Indices are drawn with replacement, so a slice may appear more than once.
    """
    if window:
        img = np.uint8(hu_to_lung_window(img) * 255)
    # A RandomState seeded with ``seed`` yields exactly the same stream as the
    # old ``np.random.seed(seed); np.random.randint(...)`` pair, without
    # clobbering the global RNG for the rest of the notebook.
    rng = np.random.RandomState(seed)
    slice_idx = np.sort(rng.randint(img.shape[0], size=rows * cols))
    tiles = [thumbnail(img[i], size) for i in slice_idx]
    return image_grid(tiles, rows, cols)

In [None]:
#| export 
def load_deeplake_nodule_data(
    ds,
    idx,
    keys=("images", "boxes", "labels", "series_id", "mask", "spacing"),
):
    """Read sample ``idx`` of a dataset into a plain dict of numpy values.

    Parameters
    ----------
    ds : dataset-like
        Presumably a Deep Lake dataset: it must expose ``ds.tensors`` (a mapping
        keyed by tensor name) and ``ds[key][idx].numpy()``.
    idx : int
        Index of the sample to read.
    keys : sequence of str, optional
        Tensor names to read. Names missing from ``ds.tensors`` are skipped.

    Returns
    -------
    dict
        ``{key: ds[key][idx].numpy()}`` for each present key. ``"series_id"`` is
        reduced to its first element.
    """
    available = ds.tensors.keys()
    record = {}
    for key in keys:
        if key not in available:
            # Skip absent tensors. This also guards the series_id unwrap below,
            # which previously ran unconditionally and raised KeyError.
            continue
        value = ds[key][idx].numpy()
        if key == "series_id":
            value = value[0]
        record[key] = value
    return record

In [None]:
#| export 
def load_sitk_img(series_path, series_id=None):
    """Load a volume with SimpleITK into a dict of array, spacing and id.

    ``series_path`` may be an already-loaded ``sitk.Image`` or a path that
    ``sitk.ReadImage`` can open. Spacing is reversed so that it follows the
    axis order of the numpy array returned by ``sitk.GetArrayFromImage``
    (which SimpleITK returns with axes reversed relative to the image index).
    ``series_id`` defaults to the empty string.
    """
    if isinstance(series_path, sitk.Image):
        img = series_path
    else:
        img = sitk.ReadImage(series_path)
    return {
        "images": sitk.GetArrayFromImage(img),
        "spacing": img.GetSpacing()[::-1],
        "series_id": "" if series_id is None else series_id,
    }

In [None]:
# Smoke test: load the bundled sample volume (presumably CT, given the HU lung
# window) and show a 5x5 grid of randomly chosen, lung-windowed slices as
# 64-px-wide thumbnails. NOTE(review): the relative path assumes the notebook
# runs from its own directory.
ds_path = "../resources/1.3.6.1.4.1.14519.5.2.1.6279.6001.309564220265302089123180126785.nii.gz"
f = load_sitk_img(ds_path)
vis(f["images"], size=64, window=True)

In [None]:
#| export
def import_module(d, parent=None, **default_kwargs):
    """Instantiate an object from a config mapping with a ``"type"`` key.

    NOTE(review): the original comment read "copied from" with no source;
    record the upstream reference here.

    Parameters
    ----------
    d : mapping
        ``d["type"]`` names the factory; every other entry is passed as a
        keyword argument. ``d`` itself is not modified (a shallow copy is used).
    parent : object, optional
        If given, the factory is looked up as ``getattr(parent, d["type"])``.
        Otherwise ``d["type"]`` is a dotted path resolved with ``pydoc.locate``.
    **default_kwargs
        Keyword arguments used only when ``d`` does not already provide them.

    Returns
    -------
    object
        The result of calling the factory with the merged keyword arguments.

    Raises
    ------
    ImportError
        If the dotted path cannot be located.
    Exception
        Any error from the lookup or from the factory itself is logged and
        re-raised unchanged. Previously these were masked by an
        UnboundLocalError/NameError.
    """
    kwargs = d.copy()
    object_type = kwargs.pop("type")
    for key, value in default_kwargs.items():
        kwargs.setdefault(key, value)

    try:
        if parent is not None:
            factory = getattr(parent, object_type)  # skipcq PTC-W0034
        else:
            factory = pydoc.locate(object_type)
            if factory is None:
                # pydoc.locate returns None rather than raising on unknown paths.
                raise ImportError(f"Could not locate '{object_type}'")
        return factory(**kwargs)
    except Exception as e:
        logger.error(f"Cannot load {object_type}. Error: {str(e)}")
        raise

In [None]:
#| export
def locate_cls(transforms: dict, return_partial=False):
    """Build an object from a dict carrying a ``"__class_fullname__"`` dotted path.

    All other entries are passed as keyword arguments. With
    ``return_partial=True`` a ``functools.partial`` is returned instead of an
    instance.

    On any failure (unknown path, bad arguments) the error is logged and the
    input dict itself is returned unchanged. Callers must be prepared to
    receive that dict back. A missing ``"__class_fullname__"`` key still raises
    KeyError.
    """
    name = transforms["__class_fullname__"]
    targs = {key: val for key, val in transforms.items() if key != "__class_fullname__"}
    try:
        cls = pydoc.locate(name)
        return partial(cls, **targs) if return_partial else cls(**targs)
    except Exception as e:
        logger.error(f"Cannot load {name}. Error: {str(e)}")
    return transforms

In [None]:
#| export 
def clean_state_dict(state_dict):
    """Normalise checkpoint keys so they load into a bare (unwrapped) model.

    Two rewrites are applied to each key:

    - A leading ``"module."`` (left by parallel training wrappers) is stripped once.
    - Every ``"model"`` path component that is followed by another component
      is dropped (``"model.backbone.w"`` -> ``"backbone.w"``).

    Only whole components are matched, so keys such as ``"text_model.encoder.w"``
    are kept intact. The previous ``str.replace("model.", "")`` corrupted them
    into ``"text_encoder.w"``.

    Returns a new OrderedDict in the original key order. Values are not copied.
    If two keys normalise to the same name, the later one wins.
    """
    prefix = "module."
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        parts = key.split(".")
        last = len(parts) - 1
        # Keep a trailing "model" component: the old replace() only removed
        # "model." (with a following dot), so e.g. "head.model" was untouched.
        kept = [part for i, part in enumerate(parts) if not (part == "model" and i < last)]
        cleaned[".".join(kept)] = value
    return cleaned

In [None]:
#| export 
def mmcv_config_to_omegaconf(cfg):
    """Convert an mmengine ``Config``/``ConfigDict`` into an OmegaConf config.

    Top-level ``ConfigDict`` values are turned into plain dicts via
    ``to_dict()``. Top-level lists have any ``ConfigDict`` elements converted
    the same way. All other values pass through unchanged. Imports are local so
    the module loads without mmengine/omegaconf installed.
    """
    from mmengine.config import ConfigDict
    from omegaconf import OmegaConf

    def _plain(value):
        if isinstance(value, ConfigDict):
            return value.to_dict()
        if isinstance(value, list):
            return [item.to_dict() if isinstance(item, ConfigDict) else item for item in value]
        return value

    return OmegaConf.create({key: _plain(value) for key, value in cfg.items()})

In [1]:
#| hide
# Write this notebook's `#| export` cells back to the library package via nbdev.
import nbdev; nbdev.nbdev_export()