Skip to content

Commit

Permalink
🔊 Add timing debug logs (#1650)
Browse files Browse the repository at this point in the history
* 🔊 Add timing debug logs

* 🐛 Use custom metaclass in DEBUG only
  • Loading branch information
huchenlei committed Jul 3, 2023
1 parent 68e8037 commit eceeec7
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 3 deletions.
11 changes: 9 additions & 2 deletions scripts/controlnet.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import gc
import os
import logging
from collections import OrderedDict
from copy import copy
from typing import Dict, Optional, Tuple
Expand All @@ -8,6 +9,7 @@
from modules import shared, devices, script_callbacks, processing, masking, images
import gradio as gr


from einops import rearrange
from scripts import global_state, hook, external_code, processor, batch_hijack, controlnet_version, utils
from scripts.controlnet_ui import controlnet_ui_group
Expand Down Expand Up @@ -206,7 +208,9 @@ def set_numpy_seed(p: processing.StableDiffusionProcessing) -> Optional[int]:
return None


class Script(scripts.Script):
class Script(scripts.Script, metaclass=(
utils.TimeMeta if logger.level == logging.DEBUG else type)):

model_cache = OrderedDict()

def __init__(self) -> None:
Expand Down Expand Up @@ -757,7 +761,7 @@ def process(self, p, *args):

if 'reference' not in unit.module and issubclass(type(p), StableDiffusionProcessingImg2Img) \
and p.inpaint_full_res and a1111_mask_image is not None:

logger.debug("A1111 inpaint mask START")
input_image = [input_image[:, :, i] for i in range(input_image.shape[2])]
input_image = [Image.fromarray(x) for x in input_image]

Expand All @@ -779,13 +783,16 @@ def process(self, p, *args):

input_image = [np.asarray(x)[:, :, 0] for x in input_image]
input_image = np.stack(input_image, axis=2)
logger.debug("A1111 inpaint mask END")

if 'inpaint_only' == unit.module and issubclass(type(p), StableDiffusionProcessingImg2Img) and p.image_mask is not None:
logger.warning('A1111 inpaint and ControlNet inpaint duplicated. ControlNet support enabled.')
unit.module = 'inpaint'

# safe numpy
logger.debug("Safe numpy convertion START")
input_image = np.ascontiguousarray(input_image.copy()).copy()
logger.debug("Safe numpy convertion END")

logger.info(f"Loading preprocessor: {unit.module}")
preprocessor = self.preprocessor[unit.module]
Expand Down
35 changes: 34 additions & 1 deletion scripts/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import torch
import os
import functools
import time
import base64
import numpy as np
import gradio as gr
import logging

from typing import Any, Callable, Dict

from scripts.logging import logger


def load_state_dict(ckpt_path, location="cpu"):
_, extension = os.path.splitext(ckpt_path)
if extension.lower() == ".safetensors":
Expand Down Expand Up @@ -80,6 +83,35 @@ def convert_item(item: Any):
return decorator


def timer_decorator(func):
"""Time the decorated function and output the result to debug logger."""
if logger.level != logging.DEBUG:
return func

@functools.wraps(func)
def wrapper(*args, **kwargs):
start_time = time.time()
result = func(*args, **kwargs)
end_time = time.time()
duration = end_time - start_time
# Only report function that are significant enough.
if duration > 1e-3:
logger.debug(f"{func.__name__} ran in: {duration} sec")
return result

return wrapper


class TimeMeta(type):
""" Metaclass to record execution time on all methods of the
child class. """
def __new__(cls, name, bases, attrs):
for attr_name, attr_value in attrs.items():
if callable(attr_value):
attrs[attr_name] = timer_decorator(attr_value)
return super().__new__(cls, name, bases, attrs)


# svgsupports
svgsupport = False
try:
Expand Down Expand Up @@ -108,11 +140,12 @@ def svg_preprocess(inputs: Dict, preprocess: Callable):
inputs["image"] = base64_str
return preprocess(inputs)


def get_unique_axis0(data):
arr = np.asanyarray(data)
idxs = np.lexsort(arr.T)
arr = arr[idxs]
unique_idxs = np.empty(len(arr), dtype=np.bool_)
unique_idxs[:1] = True
unique_idxs[1:] = np.any(arr[:-1, :] != arr[1:, :], axis=-1)
return arr[unique_idxs]
return arr[unique_idxs]

0 comments on commit eceeec7

Please sign in to comment.