This repository has been archived by the owner on Jul 3, 2023. It is now read-only.

Skip DeepSpeed Optimization if not available or not supported #6

Merged
merged 10 commits on Jan 13, 2023
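At a glance, the change gates all DeepSpeed-specific imports and the module injection behind runtime checks. Below is a simplified sketch of the resulting decision, written as a hypothetical helper that consolidates the guards spread across `ldm/deepspeed_replace.py`; it is not a verbatim excerpt from the diff.

```python
import torch
from lightning_utilities.core.imports import package_available
from ldm.detect_target import _detect_cuda

def can_use_deepspeed() -> bool:
    # Hypothetical helper: mirrors the guards added in this PR.
    if not torch.cuda.is_available():
        return False  # no CUDA device, injection returns early
    if not package_available("deepspeed"):
        return False  # deepspeed not installed (e.g. on non-Linux platforms)
    if _detect_cuda() != "80":
        return False  # only Ampere-class GPUs (A100 / RTX 30 / A30) are supported
    return True
```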
3 changes: 2 additions & 1 deletion .gitignore
@@ -159,4 +159,5 @@ cython_debug/
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
 outputs/
-profiles/
+profiles/
+v1-5-pruned-emaonly.ckpt
45 changes: 33 additions & 12 deletions ldm/deepspeed_replace.py
@@ -6,17 +6,30 @@
 from functools import partial
 from dataclasses import dataclass
 import time
-import deepspeed.ops.transformer as transformer_inference
-from deepspeed.ops.transformer.inference.diffusers_attention import DeepSpeedDiffusersAttention
-from deepspeed.ops.transformer.inference.diffusers_transformer_block import DeepSpeedDiffusersTransformerBlock
-from deepspeed.ops.transformer.inference.diffusers_2d_transformer import Diffusers2DTransformerConfig
+from lightning_utilities.core.imports import package_available

 from ldm.modules.attention import CrossAttention, BasicTransformerBlock
-from deepspeed.module_inject.replace_policy import UNetPolicy, DSPolicy
 from ldm.models.diffusion.ddpm import DiffusionWrapper
 from ldm.models.autoencoder import AutoencoderKL
 from ldm.modules.encoders.modules import FrozenCLIPEmbedder
-from deepspeed.inference.engine import InferenceEngine
+from ldm.detect_target import _detect_cuda
+import logging
+
+if package_available("deepspeed"):
+    import deepspeed.ops.transformer as transformer_inference
+    from deepspeed.ops.transformer.inference.diffusers_attention import DeepSpeedDiffusersAttention
+    from deepspeed.ops.transformer.inference.diffusers_transformer_block import DeepSpeedDiffusersTransformerBlock
+    from deepspeed.ops.transformer.inference.diffusers_2d_transformer import Diffusers2DTransformerConfig
+    from deepspeed.inference.engine import InferenceEngine
+    from deepspeed.module_inject.replace_policy import UNetPolicy, DSPolicy
+else:
+    class InferenceEngine:
+        pass
+
+    class DSPolicy:
+        pass
+
+logger = logging.getLogger(__name__)

 class InferenceEngine(InferenceEngine):

@@ -266,7 +279,11 @@ def _module_match(module):
     return None

 # Inspired from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/module_inject/replace_module.py#L201
-def deepspeed_injection(module, fp16=False, enable_cuda_graph=True):
+def deepspeed_injection(module, fp16=True, enable_cuda_graph=True):
+
+    if not torch.cuda.is_available():
+        logger.warn("You provided use_deepspeed=True but Deepspeed isn't supported on your architecture. Skipping...")
+        return

     def replace_attn(child, policy):
         policy_attn = policy.attention(child)
@@ -327,13 +344,17 @@ def _replace_module(module, policy):
         for name, child in module.named_children():
             _replace_module(child, policy)
             if child.__class__ in new_policies:
-                replaced_module = new_policies[child.__class__](child,
-                                                                policy)
+                replaced_module = new_policies[child.__class__](child, policy)
                 setattr(module, name, replaced_module)

-    _replace_module(sub_module, policy)
-    new_module = policy.apply(sub_module,
-                              enable_cuda_graph=enable_cuda_graph)
+    if not package_available("deepspeed"):
+        logger.warn("You provided use_deepspeed=True but Deepspeed isn't installed. Skipping...")
+    if _detect_cuda() not in ["80"]:
+        logger.warn("You provided use_deepspeed=True but Deepspeed isn't supported on your architecture. Skipping...")
+    else:
+        _replace_module(sub_module, policy)
+
+    new_module = policy.apply(sub_module, enable_cuda_graph=enable_cuda_graph)
     print(f"**** found and replaced {name} w. {type(new_module)}")
     setattr(module, name, new_module)

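For reference, a minimal call-site sketch of the patched entry point, mirroring how `ldm/lightning.py` invokes it (`model` below is a placeholder for the instantiated diffusion `torch.nn.Module`). With this change, environments without CUDA return early with a warning, and non-Ampere GPUs skip the module replacement.

```python
from ldm.deepspeed_replace import deepspeed_injection

# `model` is a placeholder for the Stable Diffusion module to patch.
deepspeed_injection(model, fp16=True, enable_cuda_graph=False)
```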
23 changes: 23 additions & 0 deletions ldm/detect_target.py
@@ -0,0 +1,23 @@
+import os
+from subprocess import PIPE, Popen
+
+# Credits to AITemplate Team
+# https://github.com/facebookincubator/AITemplate/blob/main/python/aitemplate/testing/detect_target.py
+def _detect_cuda():
+    try:
+        proc = Popen(
+            ["nvidia-smi", "--query-gpu=gpu_name", "--format=csv"],
+            stdout=PIPE,
+            stderr=PIPE,
+        )
+        stdout, _ = proc.communicate()
+        stdout = stdout.decode("utf-8")
+        if "A100" in stdout or "RTX 30" in stdout or "A30" in stdout:
+            return "80"
+        if "V100" in stdout:
+            return "70"
+        if "T4" in stdout:
+            return "75"
+        return None
+    except Exception:
+        return None
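A short usage sketch (illustrative only, assuming the `ldm` package is importable and `nvidia-smi` is on the PATH); the string mapping follows the checks above:

```python
from ldm.detect_target import _detect_cuda

# "80" (A100 / RTX 30 / A30), "75" (T4), "70" (V100), or None when undetected.
arch = _detect_cuda()
# Mirrors the gate added in ldm/deepspeed_replace.py: only Ampere-class GPUs use DeepSpeed.
print(f"Detected architecture: {arch}; DeepSpeed eligible: {arch == '80'}")
```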
23 changes: 16 additions & 7 deletions ldm/lightning.py
@@ -16,6 +16,9 @@
 from contextlib import nullcontext
 from torch import autocast
 from ldm.deepspeed_replace import deepspeed_injection, ReplayCudaGraphUnet
+import logging
+
+logger = logging.getLogger(__name__)


 class PromptDataset(Dataset):
@@ -47,28 +50,30 @@ def __init__(
         self,
         config_path: str,
         checkpoint_path: str,
-        device: torch.device,
+        device: str,
         size: int = 512,
         fp16: bool = True,
         sampler: str = "ddim",
         steps: Optional[int] = None,
-        use_deepspeed: bool = True,
+        use_deepspeed: bool = False,
         enable_cuda_graph: bool = False,
+        use_inference_context: bool = False,
     ):
         super().__init__()

+        if device in ("mps", "cpu") and fp16:
+            logger.warn(f"You provided fp16=True but it isn't supported on `{device}`. Skipping...")
+            fp16 = False
+
         config = OmegaConf.load(f"{config_path}")
         config.model.params.unet_config["params"]["use_fp16"] = False
         config.model.params.cond_stage_config["params"] = {"device": device}

         checkpoint = torch.load(checkpoint_path, map_location="cpu")
         state_dict = checkpoint["state_dict"]
         self.model = instantiate_from_config(config.model)
         self.model.load_state_dict(state_dict, strict=False)

-        self.to(dtype=torch.float16)
-
-        if use_deepspeed:
+        if use_deepspeed or enable_cuda_graph:
             deepspeed_injection(self.model, fp16=fp16, enable_cuda_graph=enable_cuda_graph)

         # Replace with
@@ -79,12 +84,16 @@ def __init__(

         self.to(device, dtype=torch.float16 if fp16 else torch.float32)
         self.fp16 = fp16
+        self.use_inference_context = use_inference_context

-    def predict_step(self, prompts: List[str], batch_idx: int):
+    def predict_step(self, prompts: Union[List[str], str], batch_idx: int = 0):
+        if isinstance(prompts, str):
+            prompts = [prompts]
         batch_size = len(prompts)

         precision_scope = autocast if self.fp16 else nullcontext
         inference = torch.inference_mode if torch.cuda.is_available() else torch.no_grad
+        inference = inference if self.use_inference_context else nullcontext
         with inference():
             with precision_scope("cuda"):
                 with self.model.ema_scope():
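A small usage sketch of the updated `predict_step`; with this change a single prompt string is accepted in addition to a list of prompts, and `fp16` is disabled automatically on `mps`/`cpu`. The config and checkpoint paths below are placeholders, not values taken from this PR.

```python
import torch
from ldm.lightning import LightningStableDiffusion

# Paths are placeholders for an actual SD v1.5 inference config and checkpoint.
model = LightningStableDiffusion(
    config_path="configs/stable-diffusion/v1-inference.yaml",
    checkpoint_path="v1-5-pruned-emaonly.ckpt",
    device="cuda" if torch.cuda.is_available() else "cpu",
    use_deepspeed=False,  # now opt-in; injection is skipped when unsupported
)
images = model.predict_step("a corgi wearing a top hat")
```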
4 changes: 2 additions & 2 deletions requirements.txt
@@ -9,5 +9,5 @@ einops==0.3.0
 transformers>4.19.2
 open-clip-torch==2.7.0
 lightning
-triton>=2.0.0.dev20221005
-deepspeed>=0.7.5
+triton>=2.0.0.dev20221005; platform_system == "Linux"
+deepspeed>=0.7.5; platform_system == "Linux"
52 changes: 35 additions & 17 deletions scripts/txt2img_lightning.py
@@ -1,9 +1,11 @@
 import argparse
 import os
+import time
 import torch
 from pytorch_lightning import seed_everything
 from ldm.lightning import LightningStableDiffusion

-def benchmark_fn(iters: int, warm_up_iters: int, function, *args, **kwargs) -> float:
+def benchmark_fn(device, iters: int, warm_up_iters: int, function, *args, **kwargs) -> float:
     """
     Function for benchmarking a pytorch function.

@@ -29,21 +31,29 @@ def benchmark_fn(iters: int, warm_up_iters: int, function, *args, **kwargs) -> float:
         function(*args, **kwargs)

     # Start benchmark.
-    torch.cuda.synchronize()
-    torch.cuda.empty_cache()
-    torch.cuda.reset_accumulated_memory_stats()
-    torch.cuda.reset_peak_memory_stats()
-    start_event = torch.cuda.Event(enable_timing=True)
-    end_event = torch.cuda.Event(enable_timing=True)
-    start_event.record()
-    torch.cuda.reset_peak_memory_stats()
+    if device == "cuda":
+        torch.cuda.synchronize()
+        torch.cuda.empty_cache()
+        torch.cuda.reset_accumulated_memory_stats()
+        torch.cuda.reset_peak_memory_stats()
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+        start_event.record()
+        torch.cuda.reset_peak_memory_stats()
+    else:
+        t0 = time.time()

     for _ in range(iters):
         results.extend(function(*args, **kwargs))
-    max_memory = torch.cuda.max_memory_allocated(0)/2**20
-    end_event.record()
-    torch.cuda.synchronize()
-    # in ms
-    return (start_event.elapsed_time(end_event)) / iters, max_memory, results
+
+    if device == "cuda":
+        max_memory = torch.cuda.max_memory_allocated(0)/2**20
+        end_event.record()
+        torch.cuda.synchronize()
+        # in ms
+        return (start_event.elapsed_time(end_event)) / iters, max_memory, results
+    else:
+        return (time.time() - t0) / iters, None, results

 def parse_args():
     parser = argparse.ArgumentParser()
@@ -151,16 +161,24 @@ def main(opt):
     os.makedirs(opt.outdir, exist_ok=True)
     seed_everything(opt.seed)

+    device = "cuda" if torch.cuda.is_available() else "mps"
+
     model = LightningStableDiffusion(
         config_path=opt.config,
         checkpoint_path=opt.ckpt,
-        device="cuda",
-        sampler=opt.sampler,
+        device=device,
+        fp16=True, # Supported on GPU and CPU only, skipped otherwise.
+        use_deepspeed=True, # Supported on Ampere and RTX, skipped otherwise.
+        enable_cuda_graph=True, # Currently enabled only for batch size 1.
+        use_inference_context=False,
         steps=30,
     )

     for batch_size in [1, 2, 4]:
-        t, max_memory, images = benchmark_fn(10, 5, model.predict_step, prompts=[opt.prompt] * batch_size, batch_idx=0)
+        if batch_size == 1:
+            t, max_memory, images = benchmark_fn(device, 10, 5, model.predict_step, prompts=opt.prompt, batch_idx=0)
+        else:
+            t, max_memory, images = benchmark_fn(device, 10, 5, model.predict_step, prompts=[opt.prompt] * batch_size, batch_idx=0)
         print(f"Average time {t} secs on batch size {batch_size}.")
         print(f"Max GPU Memory cost is {max_memory} MB.")
