diff --git a/.gitignore b/.gitignore
index e7f397d4..45587922 100644
--- a/.gitignore
+++ b/.gitignore
@@ -159,4 +159,5 @@ cython_debug/
 #  option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
 outputs/
-profiles/
\ No newline at end of file
+profiles/
+v1-5-pruned-emaonly.ckpt
\ No newline at end of file
diff --git a/ldm/deepspeed_replace.py b/ldm/deepspeed_replace.py
index 5ab6d1de..94cb17c1 100644
--- a/ldm/deepspeed_replace.py
+++ b/ldm/deepspeed_replace.py
@@ -6,17 +6,30 @@
 from functools import partial
 from dataclasses import dataclass
 import time

-import deepspeed.ops.transformer as transformer_inference
-from deepspeed.ops.transformer.inference.diffusers_attention import DeepSpeedDiffusersAttention
-from deepspeed.ops.transformer.inference.diffusers_transformer_block import DeepSpeedDiffusersTransformerBlock
-from deepspeed.ops.transformer.inference.diffusers_2d_transformer import Diffusers2DTransformerConfig
+from lightning_utilities.core.imports import package_available
+
 from ldm.modules.attention import CrossAttention, BasicTransformerBlock
-from deepspeed.module_inject.replace_policy import UNetPolicy, DSPolicy
 from ldm.models.diffusion.ddpm import DiffusionWrapper
 from ldm.models.autoencoder import AutoencoderKL
 from ldm.modules.encoders.modules import FrozenCLIPEmbedder
-from deepspeed.inference.engine import InferenceEngine
+from ldm.detect_target import _detect_cuda
+import logging
+
+if package_available("deepspeed"):
+    import deepspeed.ops.transformer as transformer_inference
+    from deepspeed.ops.transformer.inference.diffusers_attention import DeepSpeedDiffusersAttention
+    from deepspeed.ops.transformer.inference.diffusers_transformer_block import DeepSpeedDiffusersTransformerBlock
+    from deepspeed.ops.transformer.inference.diffusers_2d_transformer import Diffusers2DTransformerConfig
+    from deepspeed.inference.engine import InferenceEngine
+    from deepspeed.module_inject.replace_policy import UNetPolicy, DSPolicy
+else:
+    class InferenceEngine:
+        pass
+
+    class DSPolicy:
+        pass
+
+logger = logging.getLogger(__name__)


 class InferenceEngine(InferenceEngine):
@@ -266,7 +279,11 @@ def _module_match(module):
     return None


 # Inspired from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/module_inject/replace_module.py#L201
-def deepspeed_injection(module, fp16=False, enable_cuda_graph=True):
+def deepspeed_injection(module, fp16=True, enable_cuda_graph=True):
+
+    if not torch.cuda.is_available():
+        logger.warning("You provided use_deepspeed=True but CUDA isn't available on this machine. Skipping...")
+        return

     def replace_attn(child, policy):
         policy_attn = policy.attention(child)
@@ -327,13 +344,17 @@ def _replace_module(module, policy):
         for name, child in module.named_children():
             _replace_module(child, policy)
             if child.__class__ in new_policies:
-                replaced_module = new_policies[child.__class__](child,
-                                                                policy)
+                replaced_module = new_policies[child.__class__](child, policy)
                 setattr(module, name, replaced_module)

-    _replace_module(sub_module, policy)
-    new_module = policy.apply(sub_module,
-                              enable_cuda_graph=enable_cuda_graph)
+    if not package_available("deepspeed"):
+        logger.warning("You provided use_deepspeed=True but DeepSpeed isn't installed. Skipping...")
+    elif _detect_cuda() not in ["80"]:
+        logger.warning("You provided use_deepspeed=True but DeepSpeed isn't supported on your architecture. Skipping...")
+    else:
+        _replace_module(sub_module, policy)
+
+    new_module = policy.apply(sub_module, enable_cuda_graph=enable_cuda_graph)
     print(f"**** found and replaced {name} w. {type(new_module)}")
     setattr(module, name, new_module)
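Reviewer note: a minimal smoke test of the new gating, assuming `ldm.deepspeed_replace` only touches the DeepSpeed symbols inside functions so the module imports cleanly without DeepSpeed; the bare `torch.nn.Module()` is just a stand-in for the real diffusion model.

```python
import logging

import torch

from ldm.deepspeed_replace import deepspeed_injection

logging.basicConfig(level=logging.WARNING)

stand_in = torch.nn.Module()  # placeholder; the real caller passes the loaded diffusion model
# On a machine without CUDA, the patched function is expected to log a warning and return early
# instead of raising, because the DeepSpeed imports are now guarded by package_available("deepspeed").
deepspeed_injection(stand_in, fp16=True, enable_cuda_graph=False)
```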
diff --git a/ldm/detect_target.py b/ldm/detect_target.py
new file mode 100644
index 00000000..f9daf09b
--- /dev/null
+++ b/ldm/detect_target.py
@@ -0,0 +1,23 @@
+import os
+from subprocess import PIPE, Popen
+
+# Credits to AITemplate Team
+# https://github.com/facebookincubator/AITemplate/blob/main/python/aitemplate/testing/detect_target.py
+def _detect_cuda():
+    try:
+        proc = Popen(
+            ["nvidia-smi", "--query-gpu=gpu_name", "--format=csv"],
+            stdout=PIPE,
+            stderr=PIPE,
+        )
+        stdout, _ = proc.communicate()
+        stdout = stdout.decode("utf-8")
+        if "A100" in stdout or "RTX 30" in stdout or "A30" in stdout:
+            return "80"
+        if "V100" in stdout:
+            return "70"
+        if "T4" in stdout:
+            return "75"
+        return None
+    except Exception:
+        return None
\ No newline at end of file
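Reviewer note: a quick, local way to exercise the new probe (a sketch; it assumes `nvidia-smi` is on PATH, and the helper returns None when it isn't).

```python
from ldm.detect_target import _detect_cuda

arch = _detect_cuda()  # "80" for A100/A30/RTX 30, "75" for T4, "70" for V100, None otherwise
if arch == "80":
    print("Ampere-class GPU detected: DeepSpeed injection will be attempted.")
else:
    print(f"Detected architecture {arch!r}: DeepSpeed injection will be skipped.")
```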
diff --git a/ldm/lightning.py b/ldm/lightning.py
index c7eb80eb..71251dbc 100644
--- a/ldm/lightning.py
+++ b/ldm/lightning.py
@@ -16,6 +16,9 @@
 from contextlib import nullcontext
 from torch import autocast
 from ldm.deepspeed_replace import deepspeed_injection, ReplayCudaGraphUnet
+import logging
+
+logger = logging.getLogger(__name__)


 class PromptDataset(Dataset):
@@ -47,28 +50,30 @@ def __init__(
         self,
         config_path: str,
         checkpoint_path: str,
-        device: torch.device,
+        device: str,
         size: int = 512,
         fp16: bool = True,
         sampler: str = "ddim",
         steps: Optional[int] = None,
-        use_deepspeed: bool = True,
+        use_deepspeed: bool = False,
         enable_cuda_graph: bool = False,
+        use_inference_context: bool = False,
     ):
         super().__init__()

+        if device in ("mps", "cpu") and fp16:
+            logger.warning(f"You provided fp16=True but it isn't supported on `{device}`. Using fp32 instead.")
+            fp16 = False
+
         config = OmegaConf.load(f"{config_path}")
         config.model.params.unet_config["params"]["use_fp16"] = False
         config.model.params.cond_stage_config["params"] = {"device": device}
-
         checkpoint = torch.load(checkpoint_path, map_location="cpu")
         state_dict = checkpoint["state_dict"]
         self.model = instantiate_from_config(config.model)
         self.model.load_state_dict(state_dict, strict=False)
-        self.to(dtype=torch.float16)
-
-        if use_deepspeed:
+        if use_deepspeed or enable_cuda_graph:
             deepspeed_injection(self.model, fp16=fp16, enable_cuda_graph=enable_cuda_graph)

             # Replace with
@@ -79,12 +84,16 @@
         self.to(device, dtype=torch.float16 if fp16 else torch.float32)
         self.fp16 = fp16
+        self.use_inference_context = use_inference_context

-    def predict_step(self, prompts: List[str], batch_idx: int):
+    def predict_step(self, prompts: Union[List[str], str], batch_idx: int = 0):
+        if isinstance(prompts, str):
+            prompts = [prompts]
         batch_size = len(prompts)
         precision_scope = autocast if self.fp16 else nullcontext
         inference = torch.inference_mode if torch.cuda.is_available() else torch.no_grad
+        inference = inference if self.use_inference_context else nullcontext
         with inference():
             with precision_scope("cuda"):
                 with self.model.ema_scope():
diff --git a/requirements.txt b/requirements.txt
index 0da0d75a..fe2a899d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,5 +9,5 @@ einops==0.3.0
 transformers>4.19.2
 open-clip-torch==2.7.0
 lightning
-triton>=2.0.0.dev20221005
-deepspeed>=0.7.5
\ No newline at end of file
+triton>=2.0.0.dev20221005; platform_system == "Linux"
+deepspeed>=0.7.5; platform_system == "Linux"
\ No newline at end of file
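Reviewer note: a minimal usage sketch of the reworked constructor and `predict_step`; the config/checkpoint paths and the prompt are placeholders, and the flags simply restate the new defaults plus the new single-string prompt support.

```python
import torch

from ldm.lightning import LightningStableDiffusion

device = "cuda" if torch.cuda.is_available() else "cpu"
model = LightningStableDiffusion(
    config_path="configs/stable-diffusion/v1-inference.yaml",  # placeholder path
    checkpoint_path="v1-5-pruned-emaonly.ckpt",                # matches the new .gitignore entry
    device=device,
    fp16=True,                # automatically downgraded to fp32 on mps/cpu
    use_deepspeed=False,      # new default
    enable_cuda_graph=False,
    use_inference_context=True,
    steps=30,
)
# predict_step now accepts a single prompt string and defaults batch_idx to 0.
images = model.predict_step("a watercolor painting of a lighthouse at dawn")
```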
diff --git a/scripts/txt2img_lightning.py b/scripts/txt2img_lightning.py
index 9bcb27f8..124a5805 100644
--- a/scripts/txt2img_lightning.py
+++ b/scripts/txt2img_lightning.py
@@ -1,9 +1,11 @@
 import argparse
 import os
+import time

+import torch
 from pytorch_lightning import seed_everything
 from ldm.lightning import LightningStableDiffusion

-def benchmark_fn(iters: int, warm_up_iters: int, function, *args, **kwargs) -> float:
+def benchmark_fn(device, iters: int, warm_up_iters: int, function, *args, **kwargs) -> float:
     """
     Function for benchmarking a pytorch function.
@@ -29,21 +31,29 @@ def benchmark_fn(device, iters: int, warm_up_iters: int, function, *args, **kwargs) -> f
         function(*args, **kwargs)

     # Start benchmark.
-    torch.cuda.synchronize()
-    torch.cuda.empty_cache()
-    torch.cuda.reset_accumulated_memory_stats()
-    torch.cuda.reset_peak_memory_stats()
-    start_event = torch.cuda.Event(enable_timing=True)
-    end_event = torch.cuda.Event(enable_timing=True)
-    start_event.record()
-    torch.cuda.reset_peak_memory_stats()
+    if device == "cuda":
+        torch.cuda.synchronize()
+        torch.cuda.empty_cache()
+        torch.cuda.reset_accumulated_memory_stats()
+        torch.cuda.reset_peak_memory_stats()
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+        start_event.record()
+        torch.cuda.reset_peak_memory_stats()
+    else:
+        t0 = time.time()
+
     for _ in range(iters):
         results.extend(function(*args, **kwargs))
-    max_memory = torch.cuda.max_memory_allocated(0)/2**20
-    end_event.record()
-    torch.cuda.synchronize()
-    # in ms
-    return (start_event.elapsed_time(end_event)) / iters, max_memory, results
+
+    if device == "cuda":
+        max_memory = torch.cuda.max_memory_allocated(0) / 2**20
+        end_event.record()
+        torch.cuda.synchronize()
+        # elapsed_time is in milliseconds; convert to seconds to match the non-CUDA branch
+        return start_event.elapsed_time(end_event) / 1000 / iters, max_memory, results
+    else:
+        return (time.time() - t0) / iters, None, results

 def parse_args():
     parser = argparse.ArgumentParser()
@@ -151,16 +161,24 @@ def main(opt):
     os.makedirs(opt.outdir, exist_ok=True)
     seed_everything(opt.seed)

+    device = "cuda" if torch.cuda.is_available() else "mps"
+
     model = LightningStableDiffusion(
         config_path=opt.config,
         checkpoint_path=opt.ckpt,
-        device="cuda",
-        sampler=opt.sampler,
+        device=device,
+        fp16=True,  # Only effective on CUDA; automatically disabled on mps/cpu.
+        use_deepspeed=True,  # Requires DeepSpeed and an Ampere GPU (A100, A30, RTX 30); skipped otherwise.
+        enable_cuda_graph=True,  # Currently enabled only for batch size 1.
+        use_inference_context=False,
         steps=30,
     )

     for batch_size in [1, 2, 4]:
-        t, max_memory, images = benchmark_fn(10, 5, model.predict_step, prompts=[opt.prompt] * batch_size, batch_idx=0)
+        if batch_size == 1:
+            t, max_memory, images = benchmark_fn(device, 10, 5, model.predict_step, prompts=opt.prompt, batch_idx=0)
+        else:
+            t, max_memory, images = benchmark_fn(device, 10, 5, model.predict_step, prompts=[opt.prompt] * batch_size, batch_idx=0)
         print(f"Average time {t} secs on batch size {batch_size}.")
         print(f"Max GPU Memory cost is {max_memory} MB.")
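Reviewer note: both branches of `benchmark_fn` report seconds per iteration (the CUDA branch converts from CUDA-event milliseconds), so the `secs` print in `main` is consistent across devices. A self-contained sanity check of that conversion, with a hypothetical timing value:

```python
# Hypothetical numbers: 10 timed iterations measured at 31,400 ms total by CUDA events.
total_ms = 31_400.0
iters = 10
seconds_per_iter = total_ms / 1000 / iters  # same conversion as the CUDA branch of benchmark_fn
print(f"Average time {seconds_per_iter} secs on batch size 1.")  # -> 3.14
```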