This repository has been archived by the owner on Jul 3, 2023. It is now read-only.

Merge pull request #6 from Lightning-AI/make_deepspeed_opt_in
tchaton committed Jan 13, 2023
2 parents 03ae671 + 0506caa commit 8d78cc3
Showing 6 changed files with 111 additions and 39 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -159,4 +159,5 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
outputs/
profiles/
profiles/
v1-5-pruned-emaonly.ckpt
45 changes: 33 additions & 12 deletions ldm/deepspeed_replace.py
@@ -6,17 +6,30 @@
from functools import partial
from dataclasses import dataclass
import time
import deepspeed.ops.transformer as transformer_inference
from deepspeed.ops.transformer.inference.diffusers_attention import DeepSpeedDiffusersAttention
from deepspeed.ops.transformer.inference.diffusers_transformer_block import DeepSpeedDiffusersTransformerBlock
from deepspeed.ops.transformer.inference.diffusers_2d_transformer import Diffusers2DTransformerConfig
from lightning_utilities.core.imports import package_available

from ldm.modules.attention import CrossAttention, BasicTransformerBlock
from deepspeed.module_inject.replace_policy import UNetPolicy, DSPolicy
from ldm.models.diffusion.ddpm import DiffusionWrapper
from ldm.models.autoencoder import AutoencoderKL
from ldm.modules.encoders.modules import FrozenCLIPEmbedder
from deepspeed.inference.engine import InferenceEngine
from ldm.detect_target import _detect_cuda
import logging

if package_available("deepspeed"):
import deepspeed.ops.transformer as transformer_inference
from deepspeed.ops.transformer.inference.diffusers_attention import DeepSpeedDiffusersAttention
from deepspeed.ops.transformer.inference.diffusers_transformer_block import DeepSpeedDiffusersTransformerBlock
from deepspeed.ops.transformer.inference.diffusers_2d_transformer import Diffusers2DTransformerConfig
from deepspeed.inference.engine import InferenceEngine
from deepspeed.module_inject.replace_policy import UNetPolicy, DSPolicy
else:
class InferenceEngine:
pass

class DSPolicy:
pass

logger = logging.getLogger(__name__)

class InferenceEngine(InferenceEngine):

@@ -266,7 +279,11 @@ def _module_match(module):
return None

# Inspired from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/module_inject/replace_module.py#L201
def deepspeed_injection(module, fp16=False, enable_cuda_graph=True):
def deepspeed_injection(module, fp16=True, enable_cuda_graph=True):

if not torch.cuda.is_available():
logger.warn("You provided use_deepspeed=True but Deepspeed isn't supported on your architecture. Skipping...")
return

def replace_attn(child, policy):
policy_attn = policy.attention(child)
@@ -327,13 +344,17 @@ def _replace_module(module, policy):
for name, child in module.named_children():
_replace_module(child, policy)
if child.__class__ in new_policies:
replaced_module = new_policies[child.__class__](child,
policy)
replaced_module = new_policies[child.__class__](child, policy)
setattr(module, name, replaced_module)

_replace_module(sub_module, policy)
new_module = policy.apply(sub_module,
enable_cuda_graph=enable_cuda_graph)
if not package_available("deepspeed"):
logger.warn("You provided use_deepspeed=True but Deepspeed isn't installed. Skipping...")
if _detect_cuda() not in ["80"]:
logger.warn("You provided use_deepspeed=True but Deepspeed isn't supported on your architecture. Skipping...")
else:
_replace_module(sub_module, policy)

new_module = policy.apply(sub_module, enable_cuda_graph=enable_cuda_graph)
print(f"**** found and replaced {name} w. {type(new_module)}")
setattr(module, name, new_module)

23 changes: 23 additions & 0 deletions ldm/detect_target.py
@@ -0,0 +1,23 @@
import os
from subprocess import PIPE, Popen

# Credits to AITemplate Team
# https://github.com/facebookincubator/AITemplate/blob/main/python/aitemplate/testing/detect_target.py
def _detect_cuda():
try:
proc = Popen(
["nvidia-smi", "--query-gpu=gpu_name", "--format=csv"],
stdout=PIPE,
stderr=PIPE,
)
stdout, _ = proc.communicate()
stdout = stdout.decode("utf-8")
if "A100" in stdout or "RTX 30" in stdout or "A30" in stdout:
return "80"
if "V100" in stdout:
return "70"
if "T4" in stdout:
return "75"
return None
except Exception:
return None
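
As a quick illustration (not part of the commit), the helper above maps nvidia-smi output to CUDA compute-capability strings, which the DeepSpeed gate in ldm/deepspeed_replace.py consumes:

    from ldm.detect_target import _detect_cuda

    arch = _detect_cuda()            # "80" for A100/A30/RTX 30-series, "75" for T4, "70" for V100, None otherwise
    ampere_detected = arch == "80"   # deepspeed_injection only replaces modules when this holds
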
23 changes: 16 additions & 7 deletions ldm/lightning.py
@@ -16,6 +16,9 @@
from contextlib import nullcontext
from torch import autocast
from ldm.deepspeed_replace import deepspeed_injection, ReplayCudaGraphUnet
import logging

logger = logging.getLogger(__name__)


class PromptDataset(Dataset):
@@ -47,28 +50,30 @@ def __init__(
self,
config_path: str,
checkpoint_path: str,
device: torch.device,
device: str,
size: int = 512,
fp16: bool = True,
sampler: str = "ddim",
steps: Optional[int] = None,
use_deepspeed: bool = True,
use_deepspeed: bool = False,
enable_cuda_graph: bool = False,
use_inference_context: bool = False,
):
super().__init__()

if device in ("mps", "cpu") and fp16:
logger.warn(f"You provided fp16=True but it isn't supported on `{device}`. Skipping...")
fp16 = False

config = OmegaConf.load(f"{config_path}")
config.model.params.unet_config["params"]["use_fp16"] = False
config.model.params.cond_stage_config["params"] = {"device": device}

checkpoint = torch.load(checkpoint_path, map_location="cpu")
state_dict = checkpoint["state_dict"]
self.model = instantiate_from_config(config.model)
self.model.load_state_dict(state_dict, strict=False)

self.to(dtype=torch.float16)

if use_deepspeed:
if use_deepspeed or enable_cuda_graph:
deepspeed_injection(self.model, fp16=fp16, enable_cuda_graph=enable_cuda_graph)

# Replace with
@@ -79,12 +84,16 @@ def __init__(

self.to(device, dtype=torch.float16 if fp16 else torch.float32)
self.fp16 = fp16
self.use_inference_context = use_inference_context

def predict_step(self, prompts: List[str], batch_idx: int):
def predict_step(self, prompts: Union[List[str], str], batch_idx: int = 0):
if isinstance(prompts, str):
prompts = [prompts]
batch_size = len(prompts)

precision_scope = autocast if self.fp16 else nullcontext
inference = torch.inference_mode if torch.cuda.is_available() else torch.no_grad
inference = inference if self.use_inference_context else nullcontext
with inference():
with precision_scope("cuda"):
with self.model.ema_scope():
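A minimal usage sketch of the updated interface (paths and prompt are placeholders, not part of the commit): DeepSpeed now stays off unless requested, fp16 is downgraded to fp32 with a warning on "mps" and "cpu", and predict_step accepts a single prompt string as well as a list:

    import torch
    from ldm.lightning import LightningStableDiffusion

    model = LightningStableDiffusion(
        config_path="configs/stable-diffusion/v1-inference.yaml",  # placeholder path
        checkpoint_path="v1-5-pruned-emaonly.ckpt",                # placeholder path
        device="cuda" if torch.cuda.is_available() else "cpu",
        fp16=True,
        use_deepspeed=False,  # opt-in after this change
    )
    images = model.predict_step("a photograph of an astronaut riding a horse")
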
4 changes: 2 additions & 2 deletions requirements.txt
@@ -9,5 +9,5 @@ einops==0.3.0
transformers>4.19.2
open-clip-torch==2.7.0
lightning
triton>=2.0.0.dev20221005
deepspeed>=0.7.5
triton>=2.0.0.dev20221005; platform_system == "Linux"
deepspeed>=0.7.5; platform_system == "Linux"
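
For context, the trailing platform_system == "Linux" suffixes are standard PEP 508 environment markers: pip simply skips triton and deepspeed on macOS and Windows, and the package_available("deepspeed") guard added in ldm/deepspeed_replace.py handles their absence at runtime.
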
52 changes: 35 additions & 17 deletions scripts/txt2img_lightning.py
@@ -1,9 +1,11 @@
import argparse
import os
import time
import torch
from pytorch_lightning import seed_everything
from ldm.lightning import LightningStableDiffusion

def benchmark_fn(iters: int, warm_up_iters: int, function, *args, **kwargs) -> float:
def benchmark_fn(device, iters: int, warm_up_iters: int, function, *args, **kwargs) -> float:
"""
Function for benchmarking a pytorch function.
@@ -29,21 +31,29 @@ def benchmark_fn(iters: int, warm_up_iters: int, function, *args, **kwargs) -> float:
function(*args, **kwargs)

# Start benchmark.
torch.cuda.synchronize()
torch.cuda.empty_cache()
torch.cuda.reset_accumulated_memory_stats()
torch.cuda.reset_peak_memory_stats()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
torch.cuda.reset_peak_memory_stats()
if device == "cuda":
torch.cuda.synchronize()
torch.cuda.empty_cache()
torch.cuda.reset_accumulated_memory_stats()
torch.cuda.reset_peak_memory_stats()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
torch.cuda.reset_peak_memory_stats()
else:
t0 = time.time()

for _ in range(iters):
results.extend(function(*args, **kwargs))
max_memory = torch.cuda.max_memory_allocated(0)/2**20
end_event.record()
torch.cuda.synchronize()
# in ms
return (start_event.elapsed_time(end_event)) / iters, max_memory, results

if device == "cuda":
max_memory = torch.cuda.max_memory_allocated(0)/2**20
end_event.record()
torch.cuda.synchronize()
# in ms
return (start_event.elapsed_time(end_event)) / iters, max_memory, results
else:
return (time.time() - t0) / iters, None, results

def parse_args():
parser = argparse.ArgumentParser()
@@ -151,16 +161,24 @@ def main(opt):
os.makedirs(opt.outdir, exist_ok=True)
seed_everything(opt.seed)

device = "cuda" if torch.cuda.is_available() else "mps"

model = LightningStableDiffusion(
config_path=opt.config,
checkpoint_path=opt.ckpt,
device="cuda",
sampler=opt.sampler,
device=device,
fp16=True, # Supported on GPU and CPU only, skipped otherwise.
use_deepspeed=True, # Supported on Ampere and RTX, skipped otherwise.
enable_cuda_graph=True, # Currently enabled only for batch size 1.
use_inference_context=False,
steps=30,
)

for batch_size in [1, 2, 4]:
t, max_memory, images = benchmark_fn(10, 5, model.predict_step, prompts=[opt.prompt] * batch_size, batch_idx=0)
if batch_size == 1:
t, max_memory, images = benchmark_fn(device, 10, 5, model.predict_step, prompts=opt.prompt, batch_idx=0)
else:
t, max_memory, images = benchmark_fn(device, 10, 5, model.predict_step, prompts=[opt.prompt] * batch_size, batch_idx=0)
print(f"Average time {t} secs on batch size {batch_size}.")
print(f"Max GPU Memory cost is {max_memory} MB.")

