In [None]:
# default_exp utils

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
# export
import os
from glob import glob
import cv2
import os.path as osp
from tqdm import tqdm
import mmcv
from fastcore.script import call_parse, Param
from avcv.process import multi_thread
from loguru import logger

def get_name(path):
    """Return the file name of `path` without its last extension.

    'a/b/clip.tar.gz' -> 'clip.tar', 'video.mp4' -> 'video'. A name without
    an extension is returned unchanged (previously it collapsed to '').
    """
    return osp.splitext(osp.basename(path))[0]


def find_contours(thresh):
    """
        Get contours of a binary image.
            Arguments:
                thresh: binary image, passed straight to cv2.findContours
            Returns:
                contours: the contours found by cv2.findContours (RETR_TREE,
                    CHAIN_APPROX_SIMPLE)
                hierarchy: hierarchy[0] of cv2.findContours' output, i.e. one
                    row per contour (see OpenCV docs for the row layout)
                (None, None) when no contour is found or OpenCV rejects the input.
    """
    try:
        # OpenCV 3.x returns (image, contours, hierarchy) while 2.x/4.x return
        # (contours, hierarchy); keeping the last two handles every version.
        contours, hierarchy = cv2.findContours(
            thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:]
    except (cv2.error, TypeError):
        # Invalid input (wrong dtype/shape/type): keep the best-effort contract.
        return None, None
    if hierarchy is None:
        # No contour at all: OpenCV returns an empty contour list and None.
        return None, None
    return contours, hierarchy[0]


@call_parse
def download_file_from_google_drive(id_or_link: Param("Link or file id"), destination: Param("Path to the save file")):
    """Download a Google Drive file to `destination` and return its absolute path.

    `id_or_link` is either a bare file id or a share link such as
    'https://drive.google.com/file/d/<id>/view?usp=sharing' or
    'https://drive.google.com/open?id=<id>'.

    Raises:
        ValueError: if `id_or_link` is a link that contains no file id.
        requests.HTTPError: if Google Drive answers with an HTTP error status
            (previously the error page was silently saved as the file).
    """
    # NOTE(review): when this is used as a console_scripts entry point, the
    # returned path is handed to sys.exit() and yields exit status 1 — confirm.
    import requests
    from urllib.parse import parse_qs, urlparse

    def parse_file_id(value):
        # Bare ids never contain a URL scheme separator.
        if "://" not in value:
            return value
        parsed = urlparse(value)
        # '/file/d/<id>/view' style: the id follows the 'd' path segment.
        parts = parsed.path.split("/")
        if "d" in parts and parts.index("d") + 1 < len(parts):
            return parts[parts.index("d") + 1]
        # 'open?id=<id>' / 'uc?id=<id>' style.
        ids = parse_qs(parsed.query).get("id")
        if ids:
            return ids[0]
        raise ValueError(f"Cannot find a Google Drive file id in: {value}")

    def get_confirm_token(response):
        # A 'download_warning*' cookie carries the token needed to confirm
        # the download; None when no confirmation is requested.
        for key, value in response.cookies.items():
            if key.startswith('download_warning'):
                return value
        return None

    def save_response_content(response, destination):
        CHUNK_SIZE = 32768

        with open(destination, "wb") as f:
            for chunk in response.iter_content(CHUNK_SIZE):
                if chunk:  # filter out keep-alive new chunks
                    f.write(chunk)

    file_id = parse_file_id(id_or_link)
    logger.info(f"Download from id: {file_id}")

    URL = "https://docs.google.com/uc?export=download"

    # The context manager closes pooled connections even if a step fails.
    with requests.Session() as session:
        response = session.get(URL, params={'id': file_id}, stream=True)
        token = get_confirm_token(response)
        if token:
            response = session.get(URL, params={'id': file_id, 'confirm': token}, stream=True)
        response.raise_for_status()
        save_response_content(response, destination)

    logger.info(f"Done -> {destination}")
    return osp.abspath(destination)


def mkdir(path):
    """Create directory `path` together with any missing parents.

    Does nothing if the directory already exists (like `mkdir -p`).
    """
    os.makedirs(name=path, exist_ok=True)


def put_text(image, pos, text, color=(255, 255, 255)):
    """Draw `text` on `image` with its bottom-left corner at `pos` and return the image.

    Uses Hershey Simplex at scale 1.0 and thickness 2; `color` defaults to white.
    """
    font_face = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1.0
    thickness = 2
    return cv2.putText(image, text, pos, font_face, font_scale, color, thickness)


def images_to_video(
        images,
        out_path=None,
        fps: int = 30,
        no_sort=False,
        max_num_frame=10e12,
        resize_rate=1,
        with_text=False,
        text_is_date=False,
        verbose=True,
):
    """Write a sequence of images to a video file.

    Args:
        images: a directory containing .jpg/.png/.jpeg files, or a list of
            image paths and/or already-loaded images accepted by `mmcv.imread`.
        out_path: output path ending in '.mp4' or '.avi'. When None, `images`
            must be a directory path and the video is written to '<dir>.mp4'.
        fps: frames per second of the output video.
        no_sort: when False (default) and `images` are paths, sort them by the
            number formed from the digits in their file names; names without
            digits go last, ordered by name.
        max_num_frame: write at most this many frames.
        resize_rate: scale factor applied to the first image's size to obtain
            the output frame size.
        with_text: draw the file name on each frame (path inputs only).
        text_is_date: with `with_text`, parse names like '<sec>_<frac>.ext' as a
            Unix timestamp and draw the resulting local date/time instead.
        verbose: show progress bars and log the output size.

    Raises:
        ValueError: if `images` is a string that is not a directory, or no
            images are found.
        NotImplementedError: if `out_path` has an unsupported extension.
    """
    if out_path is None:
        assert isinstance(
            images, str), "No out_path specify, you need to input a string to a directory"
        # Strip trailing separators so 'dir/' yields 'dir.mp4', not 'dir/.mp4'.
        out_path = images.rstrip('/\\') + '.mp4'
    if isinstance(images, str):
        if not os.path.isdir(images):
            raise ValueError(f"images must be a directory or a list of images, got: {images}")
        images = glob(os.path.join(images, "*.jpg")) + \
            glob(os.path.join(images, "*.png")) + \
            glob(os.path.join(images, "*.jpeg"))
    if len(images) == 0:
        raise ValueError("No images to write")

    def sort_key(path):
        # Number formed from all decimal digits of the file name. The leading
        # 0/1 tag keeps ints and names from ever being compared to each other.
        name = os.path.basename(path)
        digits = ''.join(c for c in name if c.isdecimal())
        return (0, int(digits)) if digits else (1, name)

    def load_frame(img_or_path):
        # Called via multi_thread below; `output_size` is bound before that call.
        if not isinstance(img_or_path, str):
            return img_or_path
        name = os.path.basename(img_or_path)
        img = mmcv.imread(img_or_path)
        # Check before resizing: cv2.resize(None, ...) would fail first.
        assert img is not None, img_or_path
        img = cv2.resize(img, output_size)
        if with_text:
            if text_is_date:
                from datetime import datetime
                parts = name.split('.')[0].split('_')
                timestamp = float('{}.{}'.format(*parts))
                name = str(datetime.fromtimestamp(timestamp))
            img = put_text(img, (20, 20), name)
        return img

    if not no_sort and isinstance(images[0], str):
        images = sorted(images, key=sort_key)

    max_num_frame = min(len(images), int(max_num_frame))
    images = images[:max_num_frame]

    h, w = mmcv.imread(images[0]).shape[:2]
    output_size = (int(w * resize_rate), int(h * resize_rate))
    if out_path.endswith('.mp4'):
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    elif out_path.endswith('.avi'):
        fourcc = cv2.VideoWriter_fourcc(*'DIVX')
    else:
        raise NotImplementedError(f"Unsupported video extension (use .mp4 or .avi): {out_path}")

    frames = multi_thread(load_frame, images, verbose=verbose)
    if verbose:
        logger.info(f"Write video, output_size: {output_size}")
        pbar = mmcv.ProgressBar(len(frames))
    logger.info(f"out_path: {out_path}")
    out = cv2.VideoWriter(out_path, fourcc, fps, output_size)
    try:
        for img in frames:
            # Array inputs were not resized in load_frame, so resize here.
            out.write(cv2.resize(img, output_size))
            if verbose:
                pbar.update()
    finally:
        # Always finalize the container, even if a frame fails.
        out.release()

In [None]:
#export
@call_parse
def av_i2v(
            images: Param("Path to the images folder or list of images"),
            out_path: Param("Output video path (.mp4 or .avi)", str)=None,
            fps: Param("Frames per second", int) = 30,
            no_sort: Param("Do not sort images by the number in their file names", bool) = False,
            max_num_frame: Param("Max number of frames", int) = 10e12,
            resize_rate: Param("Resize rate", float) = 1,
            with_text: Param("Draw the image file name on each frame", bool) = False,
            text_is_date: Param("Draw the file name as a date (names like '<sec>_<frac>.jpg')", bool) = False,
            verbose: Param("Show progress", bool) = True,
        ):
    """Compose a video from a folder of images (CLI wrapper of `images_to_video`)."""
    # Keyword arguments keep this wrapper correct if the callee's order changes.
    return images_to_video(images, out_path=out_path, fps=fps,
                           no_sort=no_sort, max_num_frame=max_num_frame,
                           resize_rate=resize_rate, with_text=with_text,
                           text_is_date=text_is_date, verbose=verbose)


In [None]:
#export
def video_to_images(input_video, output_dir=None, skip=1):
    """
        Extract video frames to JPEG images.
            inputs:
                input_video: path to video
                output_dir: directory for the frames; defaults to
                    '.cache/video_to_images/<video name>'
                skip: keep every `skip`-th frame (1 = every frame); must be >= 1
            Frames are written as '<frame index, zero-padded to 5 digits>.jpg'.
            Frames that cannot be read or written are logged and skipped.
            raises:
                ValueError: if skip < 1
    """
    if skip < 1:
        # range() would raise for 0 and silently do nothing for negatives.
        raise ValueError(f"skip must be >= 1, got {skip}")

    if output_dir is None:
        vname = get_name(input_video).split('.')[0]
        output_dir = osp.join('.cache/video_to_images/', vname)
        logger.info(f'Set output_dir = {output_dir}')

    video = mmcv.video.VideoReader(input_video)
    frame_indices = range(0, len(video), skip)
    # Size the bar by the frames actually visited, not the total frame count.
    pbar = mmcv.ProgressBar(len(frame_indices))
    for i in frame_indices:
        try:
            img = video[i]
            out_img_path = os.path.join(output_dir, f'{i:05d}.jpg')
            mmcv.imwrite(img, out_img_path)
        except Exception as e:
            logger.warning(f"Cannot write image index {i}, exception: {e}")
            continue
        pbar.update()
    
@call_parse
def av_v2i(input_video:Param("Path to the input video", str),
           output_dir:Param("Directory for the frames (default: .cache/video_to_images/<video name>)", str)=None,
           skip:Param("Keep every skip-th frame", int)=1):
    """Extract video frames to JPEG images (CLI wrapper of `video_to_images`)."""
    return video_to_images(input_video, output_dir, skip)

In [None]:
# Example: extract every frame of the sample clip into .cache/video_to_images/hcm_5s/
video_to_images('../asset/hcm_5s.mp4')

## CLI examples
+ Download a file from a Google Drive link

        gdown --help
        gdown "https://drive.google.com/file/d/1xOb92Yx3hoOsMsAiI2mnkcyoQatRQNBf/view?usp=sharing" test.mp3
        
This should produce a playable mp3 file (`test.mp3`).
+ Compose a video from a folder of images

        i2v --help
        i2v PATH_TO_DIR out.mp4

+ Extract the frames of a video as images

        v2i --help # helper
        v2i test.mp4 test-img/

# Time logger

In [None]:
#export
from mmcv import Timer
import pandas as pd
import numpy as np

class TimeLoger:
    """Accumulate named durations and render a per-stage percentage summary.

    Call `start()` to (re)start the timer, then `update(name)` after each
    stage: the time since the previous check is appended under `name`.
    `str(obj)` lists each stage's share of the total, mean duration and count.
    """

    def __init__(self):
        self.timer = Timer()
        # stage name -> list of recorded durations, in first-seen order
        self.time_dict = dict()

    def start(self):
        """(Re)start the underlying timer."""
        self.timer.start()

    def update(self, name):
        """Record the time elapsed since the last check under `name`."""
        elapsed = self.timer.since_last_check()
        self.time_dict.setdefault(name, []).append(elapsed)

    def __str__(self):
        grand_total = np.sum([np.sum(durations) for durations in self.time_dict.values()])
        rows = [f"------------------Time Loger Summary : Total {grand_total:0.2f} ---------------------:\n"]
        for stage, durations in self.time_dict.items():
            share = np.sum(durations) * 100 / grand_total
            mean_duration = np.mean(durations)
            rows.append(f'\t\t{stage}:  \t\t{share:0.2f}% ({mean_duration:0.4f}s) | Times: {len(durations)} \n')
        return ''.join(rows)

# Memoize

In [None]:
#export
import xxhash
import pickle

def identify(x):
    '''Return the xxh64 (seed 0) hex digest of `x`'s pickled bytes.'''
    payload = pickle.dumps(x)
    hasher = xxhash.xxh64(payload, seed=0)
    return hasher.hexdigest()


def memoize(func):
    '''Cache the result of a function call on disk.

    Supports multiple positional and keyword arguments. The cache key hashes
    the function's source code together with its arguments, so editing the
    function invalidates old entries. Results are pickled to
    `.cache/<func name>_<hash>`.

    The cached value is ignored (and overwritten) when an environment variable
    named after the function, or `BUST_CACHE`, is set. If the key cannot be
    computed (e.g. unpicklable arguments, source unavailable) or the cache
    cannot be read or written, a warning is logged and the function just runs.
    The wrapped function is called at most once per call and its own
    exceptions propagate unchanged.

    SECURITY: cache files are loaded with `pickle`; never let untrusted data
    be placed in the cache directory.
    '''
    import inspect
    import os
    import pickle
    from functools import wraps

    cache_dir = '.cache'

    @wraps(func)
    def memoized_func(*args, **kwargs):
        # 1. Build the cache key; fall back to a plain call if that fails.
        try:
            func_id = identify((inspect.getsource(func), args, kwargs))
        except Exception as e:
            logger.warning(f'Exception: {e}, use default function call')
            return func(*args, **kwargs)
        cache_path = os.path.join(cache_dir, func.__name__ + '_' + func_id)
        use_cache = (func.__name__ not in os.environ
                     and 'BUST_CACHE' not in os.environ)

        # 2. Cache hit: load the stored result; an unreadable file falls through.
        if use_cache and os.path.exists(cache_path):
            try:
                with open(cache_path, 'rb') as fh:
                    return pickle.load(fh)
            except Exception as e:
                logger.warning(f'Exception: {e}, use default function call')

        # 3. Miss: call the function exactly once, outside any try, so its own
        #    exceptions are never swallowed and it is never re-run.
        result = func(*args, **kwargs)
        tmp_path = cache_path + '.tmp'
        try:
            os.makedirs(cache_dir, exist_ok=True)
            with open(tmp_path, 'wb') as fh:
                pickle.dump(result, fh)
            # Rename into place so an interrupted write never leaves a
            # truncated file at cache_path.
            os.replace(tmp_path, cache_path)
        except Exception as e:
            logger.warning(f'Exception: {e}, result not cached')
            try:
                os.remove(tmp_path)
            except OSError:
                pass
        return result
    return memoized_func

In [3]:
#export
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Facilities for pickling Python code alongside other data.

The pickled code is automatically imported into a separate Python module
during unpickling. This way, any previously exported pickles will remain
usable even if the original code is no longer available, or if the current
version of the code is not consistent with what was originally pickled."""

import sys
import pickle
import io
import inspect
import copy
import uuid
import types
from typing import Any
class EasyDict(dict):
    """A `dict` whose items can also be read, set and deleted as attributes.

    `d.key` maps to `d['key']`. Reading a missing key as an attribute raises
    AttributeError (so `hasattr`/`getattr(..., default)` work); deleting a
    missing one raises KeyError, as `del d[key]` does.
    """

    def __getattr__(self, name: str) -> Any:
        # Only reached when normal attribute lookup has already failed.
        try:
            value = self[name]
        except KeyError:
            raise AttributeError(name)
        return value

    def __setattr__(self, name: str, value: Any) -> None:
        self.__setitem__(name, value)

    def __delattr__(self, name: str) -> None:
        self.__delitem__(name)


#----------------------------------------------------------------------------

# Shared state of the persistence helpers in this cell.
_version            = 6         # internal version number; stored in each pickle and must match exactly on load
_decorators         = set()     # {decorator_class, ...}: classes produced by persistent_class
_import_hooks       = []        # [hook_function, ...]: registered by import_hook, applied to every unpickled object
_module_to_src_dict = dict()    # {module: src, ...}: cache of module source code
_src_to_module_dict = dict()    # {src: module, ...}: reverse cache, incl. modules rebuilt from pickled source

#----------------------------------------------------------------------------

def persistent_class(orig_class):
    r"""Class decorator that extends a given class to save its source code
    when pickled.

    Example:

        from torch_utils import persistence

        @persistence.persistent_class
        class MyNetwork(torch.nn.Module):
            def __init__(self, num_inputs, num_outputs):
                super().__init__()
                self.fc = MyLayer(num_inputs, num_outputs)
                ...

        @persistence.persistent_class
        class MyLayer(torch.nn.Module):
            ...

    When pickled, any instance of `MyNetwork` and `MyLayer` will save its
    source code alongside other internal state (e.g., parameters, buffers,
    and submodules). This way, any previously exported pickle will remain
    usable even if the class definitions have been modified or are no
    longer available.

    The decorator saves the source code of the entire Python module
    containing the decorated class. It does *not* save the source code of
    any imported modules. Thus, the imported modules must be available
    during unpickling, also including the module that defines
    `persistent_class` itself (originally `torch_utils.persistence`).

    It is ok to call functions defined in the same module from the
    decorated class. However, if the decorated class depends on other
    classes defined in the same module, they must be decorated as well.
    This is illustrated in the above example in the case of `MyLayer`.

    It is also possible to employ the decorator just-in-time before
    calling the constructor. For example:

        cls = MyLayer
        if want_to_make_it_persistent:
            cls = persistence.persistent_class(cls)
        layer = cls(num_inputs, num_outputs)

    As an additional feature, the decorator also keeps track of the
    arguments that were used to construct each instance of the decorated
    class. The arguments can be queried via `obj.init_args` and
    `obj.init_kwargs`, and they are automatically pickled alongside other
    object state. A typical use case is to first unpickle a previous
    instance of a persistent class, and then upgrade it to use the latest
    version of the source code:

        with open('old_pickle.pkl', 'rb') as f:
            old_net = pickle.load(f)
        new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
        misc.copy_params_and_buffers(old_net, new_net, require_all=True)
    """
    assert isinstance(orig_class, type)
    if is_persistent(orig_class):
        return orig_class

    assert orig_class.__module__ in sys.modules
    orig_module = sys.modules[orig_class.__module__]
    orig_module_src = _module_to_src(orig_module)

    class Decorator(orig_class):
        _orig_module_src = orig_module_src
        _orig_class_name = orig_class.__name__

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Deep copies so later mutation of the caller's arguments does not
            # change what gets pickled.
            self._init_args = copy.deepcopy(args)
            self._init_kwargs = copy.deepcopy(kwargs)
            # The class must be reachable by name in its module, since
            # unpickling looks it up there.
            assert orig_class.__name__ in orig_module.__dict__
            # Fail early (at construction) rather than at pickling time.
            _check_pickleable(self.__reduce__())

        @property
        def init_args(self):
            return copy.deepcopy(self._init_args)

        @property
        def init_kwargs(self):
            # Use the EasyDict defined in this module; `dnnlib` is not
            # imported here, so `dnnlib.EasyDict` raised NameError.
            return EasyDict(copy.deepcopy(self._init_kwargs))

        def __reduce__(self):
            fields = list(super().__reduce__())
            fields += [None] * max(3 - len(fields), 0)
            if fields[0] is not _reconstruct_persistent_obj:
                meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
                fields[0] = _reconstruct_persistent_obj # reconstruct func
                fields[1] = (meta,) # reconstruct args
                fields[2] = None # state dict
            return tuple(fields)

    Decorator.__name__ = orig_class.__name__
    _decorators.add(Decorator)
    return Decorator

#----------------------------------------------------------------------------

def is_persistent(obj):
    r"""Return True if `obj` is a persistent class or an instance of one,
    i.e. it will save its source code when pickled.
    """
    try:
        is_decorator = obj in _decorators
    except TypeError:
        # Unhashable objects cannot be set members; check their type instead.
        is_decorator = False
    return is_decorator or type(obj) in _decorators # pylint: disable=unidiomatic-typecheck

#----------------------------------------------------------------------------

def import_hook(hook):
    r"""Register an import hook that is called whenever a persistent object
    is being unpickled. A typical use case is to patch the pickled source
    code to avoid errors and inconsistencies when the API of some imported
    module has changed.

    Hooks run in registration order and must return the (possibly modified)
    meta; returning None fails an assertion during unpickling.

    The hook should have the following signature:

        hook(meta) -> modified meta

    `meta` is an instance of `EasyDict` with the following fields:

        type:       Type of the persistent object, e.g. `'class'`.
        version:    Internal version number of the persistence format.
        module_src: Original source code of the Python module.
        class_name: Class name in the original Python module.
        state:      Internal state of the object.

    Returns `hook` unchanged, so it can be used as a decorator without
    rebinding the decorated name to None.

    Example:

        @import_hook
        def wreck_my_network(meta):
            if meta.class_name == 'MyNetwork':
                print('MyNetwork is being imported. I will wreck it!')
                meta.module_src = meta.module_src.replace("True", "False")
            return meta
    """
    assert callable(hook)
    _import_hooks.append(hook)
    return hook

#----------------------------------------------------------------------------

def _reconstruct_persistent_obj(meta):
    r"""Hook that is called internally by the `pickle` module to unpickle
    a persistent object.

    `meta` is the dict built by a persistent class's `__reduce__` (type,
    version, module_src, class_name, state). Registered import hooks may
    rewrite it first. The class is then looked up in the module rebuilt
    from `module_src`, re-wrapped with `persistent_class`, and a new
    instance is created without calling `__init__` and given the pickled
    state.
    """
    meta = EasyDict(meta)
    # NOTE(review): EasyDict(None) raises TypeError if the reduced object had no state — confirm this cannot happen.
    meta.state = EasyDict(meta.state)
    for hook in _import_hooks:
        meta = hook(meta)
        assert meta is not None  # hooks must return the (possibly modified) meta

    assert meta.version == _version  # reject pickles written with a different format version
    module = _src_to_module(meta.module_src)  # reuse a cached module or exec the pickled source

    assert meta.type == 'class'
    orig_class = module.__dict__[meta.class_name]
    decorator_class = persistent_class(orig_class)  # re-wrap so the object stays persistent when re-pickled
    obj = decorator_class.__new__(decorator_class)  # bypass __init__; state comes from the pickle

    setstate = getattr(obj, '__setstate__', None)
    if callable(setstate):
        setstate(meta.state) # pylint: disable=not-callable
    else:
        obj.__dict__.update(meta.state)
    return obj

#----------------------------------------------------------------------------

def _module_to_src(module):
    r"""Return the source code of `module`, caching it in both lookup tables."""
    cached = _module_to_src_dict.get(module, None)
    if cached is not None:
        return cached
    source = inspect.getsource(module)
    _module_to_src_dict[module] = source
    _src_to_module_dict[source] = module
    return source

def _src_to_module(src):
    r"""Return the module built from `src`, creating it on a cache miss.

    A new module gets a unique `_imported_module_<hex>` name. It is registered
    in `sys.modules` and in both caches *before* its source is executed.
    SECURITY: `src` comes from the pickle being loaded and is run with
    `exec`, so unpickling untrusted data executes arbitrary code.
    """
    cached = _src_to_module_dict.get(src, None)
    if cached is not None:
        return cached
    fresh_name = "_imported_module_" + uuid.uuid4().hex
    fresh_module = types.ModuleType(fresh_name)
    sys.modules[fresh_name] = fresh_module
    _module_to_src_dict[fresh_module] = src
    _src_to_module_dict[src] = fresh_module
    exec(src, fresh_module.__dict__) # pylint: disable=exec-used
    return fresh_module

#----------------------------------------------------------------------------

def _check_pickleable(obj):
    r"""Check that the given object is pickleable, raising an exception if
    it is not. This function is expected to be considerably more efficient
    than actually pickling the object.
    """
    def recurse(obj):
        if isinstance(obj, (list, tuple, set)):
            return [recurse(x) for x in obj]
        if isinstance(obj, dict):
            return [[recurse(x), recurse(y)] for x, y in obj.items()]
        if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
            return None # Python primitive types are pickleable.
        if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']:
            return None # NumPy arrays and PyTorch tensors are pickleable.
        if is_persistent(obj):
            return None # Persistent objects are pickleable, by virtue of the constructor check.
        return obj
    with io.BytesIO() as f:
        pickle.dump(recurse(obj), f)

#----------------------------------------------------------------------------


In [4]:
# !pip install dnnlib