first pass at openai's Consistency Decoder
anotherjesse committed Nov 6, 2023 · commit d31793c · 1 parent fab8ea5
Showing 5 changed files with 47 additions and 34 deletions.
.dockerignore — 1 change: 1 addition & 0 deletions
@@ -1 +1,2 @@
 .git
+*.png
.gitignore — 4 changes: 3 additions & 1 deletion
@@ -1,3 +1,5 @@
 .cog/
 __pycache__/
-diffusers-cache/
+diffusers-cache/
+consistencydecoder-cache/
+*.png
cog.yaml — 13 changes: 7 additions & 6 deletions
@@ -3,12 +3,13 @@ build:
   cuda: "11.8"
   python_version: "3.11.1"
   python_packages:
-    - "diffusers==0.11.1"
-    - "torch==1.13.0"
+    - "diffusers==0.22.1"
+    - "torch==2.1.0"
     - "ftfy==6.1.1"
-    - "scipy==1.9.3"
-    - "transformers==4.25.1"
-    - "accelerate==0.15.0"
-    - "huggingface-hub==0.13.2"
+    - "scipy==1.11.3"
+    - "transformers==4.35.0"
+    - "accelerate==0.24.1"
+    - "huggingface-hub"
+    - "git+https://github.com/openai/consistencydecoder.git"
 
 predict: "predict.py:Predictor"
predict.py — 42 changes: 26 additions & 16 deletions
@@ -1,8 +1,10 @@
 import os
+import time
 from typing import List
 
 import torch
 from cog import BasePredictor, Input, Path
+from consistencydecoder import ConsistencyDecoder, save_image
 from diffusers import (
     StableDiffusionPipeline,
     PNDMScheduler,
@@ -12,32 +14,26 @@
     EulerAncestralDiscreteScheduler,
     DPMSolverMultistepScheduler,
 )
-from diffusers.pipelines.stable_diffusion.safety_checker import (
-    StableDiffusionSafetyChecker,
-)
 
 # MODEL_ID refers to a diffusers-compatible model on HuggingFace
 # e.g. prompthero/openjourney-v2, wavymulder/Analog-Diffusion, etc
-MODEL_ID = "stabilityai/stable-diffusion-2-1"
+MODEL_ID = "runwayml/stable-diffusion-v1-5"
 MODEL_CACHE = "diffusers-cache"
-SAFETY_MODEL_ID = "CompVis/stable-diffusion-safety-checker"
 
 
 class Predictor(BasePredictor):
     def setup(self):
         """Load the model into memory to make running multiple predictions efficient"""
         print("Loading pipeline...")
-        safety_checker = StableDiffusionSafetyChecker.from_pretrained(
-            SAFETY_MODEL_ID,
-            cache_dir=MODEL_CACHE,
-            local_files_only=True,
-        )
         self.pipe = StableDiffusionPipeline.from_pretrained(
-            MODEL_ID,
-            safety_checker=safety_checker,
-            cache_dir=MODEL_CACHE,
+            MODEL_CACHE,
             local_files_only=True,
             torch_dtype=torch.float16,
         ).to("cuda")
 
+        print("Loading ConsistencyDecoder...")
+        self.consistency_decoder = ConsistencyDecoder(
+            device="cuda:0", download_root="/src/consistencydecoder-cache"
+        )
+
     @torch.inference_mode()
     def predict(
         self,
@@ -49,6 +45,10 @@ def predict(
description="Specify things to not see in the output",
default=None,
),
consistency_decoder: bool = Input(
description="Enable consistency decoder",
default=True,
),
width: int = Input(
description="Width of output image. Maximum size is 1024x768 or 768x1024 because of memory limits",
choices=[128, 256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024],
@@ -100,6 +100,7 @@ def predict(
         self.pipe.scheduler = make_scheduler(scheduler, self.pipe.scheduler.config)
 
         generator = torch.Generator("cuda").manual_seed(seed)
+        start = time.time()
         output = self.pipe(
             prompt=[prompt] * num_outputs if prompt is not None else None,
             negative_prompt=[negative_prompt] * num_outputs
@@ -110,15 +111,24 @@
             guidance_scale=guidance_scale,
             generator=generator,
             num_inference_steps=num_inference_steps,
+            output_type="latent" if consistency_decoder else None,
         )
+        print("Inference took", time.time() - start, "seconds")
 
         output_paths = []
         for i, sample in enumerate(output.images):
             if output.nsfw_content_detected and output.nsfw_content_detected[i]:
                 continue
 
             output_path = f"/tmp/out-{i}.png"
-            sample.save(output_path)
+            if consistency_decoder:
+                print("Running consistency decoder...")
+                start = time.time()
+                sample = self.consistency_decoder(sample.unsqueeze(0))
+                print("Consistency decoder took", time.time() - start, "seconds")
+                save_image(sample, output_path)
+            else:
+                sample.save(output_path)
             output_paths.append(Path(output_path))
 
         if len(output_paths) == 0:
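
Taken together, the predict.py changes swap the pipeline's built-in VAE decode for OpenAI's consistency decoder: the pipeline is asked for raw latents via output_type="latent", and each latent is then decoded and written out with the consistencydecoder helpers. Below is a condensed sketch of that decode path, mirroring the calls used in the diff; the prompt and output filename are only illustrative, and a CUDA device plus downloadable weights are assumed.

import torch
from diffusers import StableDiffusionPipeline
from consistencydecoder import ConsistencyDecoder, save_image

# Load SD 1.5 and the consistency decoder (assumes weights can be fetched).
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
decoder = ConsistencyDecoder(device="cuda:0")

# output_type="latent" skips the built-in VAE decode and returns latents
# of shape [num_outputs, 4, height/8, width/8].
latents = pipe("an astronaut riding a horse", output_type="latent").images

# Decode one latent at a time, as the loop in predict() does, then save it.
sample = decoder(latents[0].unsqueeze(0))
save_image(sample, "/tmp/out-0.png")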
script/download-weights — 21 changes: 10 additions & 11 deletions
@@ -3,26 +3,25 @@
 import os
 import shutil
 import sys
+import torch
 
 from diffusers import StableDiffusionPipeline
-from diffusers.pipelines.stable_diffusion.safety_checker import \
-    StableDiffusionSafetyChecker
+from consistencydecoder import ConsistencyDecoder
 
 
 # append project directory to path so predict.py can be imported
-sys.path.append('.')
+sys.path.append(".")
 
+from predict import MODEL_CACHE, MODEL_ID
 
-from predict import MODEL_CACHE, MODEL_ID, SAFETY_MODEL_ID
+ConsistencyDecoder(device="cuda:0", download_root="/src/consistencydecoder-cache")
 
 if os.path.exists(MODEL_CACHE):
     shutil.rmtree(MODEL_CACHE)
 os.makedirs(MODEL_CACHE, exist_ok=True)
 
-saftey_checker = StableDiffusionSafetyChecker.from_pretrained(
-    SAFETY_MODEL_ID,
-    cache_dir=MODEL_CACHE,
-)
 
 pipe = StableDiffusionPipeline.from_pretrained(
-    MODEL_ID,
-    cache_dir=MODEL_CACHE,
+    MODEL_ID, torch_dtype=torch.float16, device="cuda:0"
 )
 
 pipe.save_pretrained(MODEL_CACHE)
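
For context, a small sketch (not part of this commit) of the round trip the script sets up: weights are fetched once at build time into local caches, and setup() in predict.py then loads them offline. The calls mirror the script and setup() above; DECODER_CACHE is just shorthand for the download_root path used there, and the decoder is assumed to reuse files already present in that directory.

import torch
from diffusers import StableDiffusionPipeline
from consistencydecoder import ConsistencyDecoder

MODEL_ID = "runwayml/stable-diffusion-v1-5"
MODEL_CACHE = "diffusers-cache"
DECODER_CACHE = "/src/consistencydecoder-cache"

# Build time: download everything while network access is available.
ConsistencyDecoder(device="cuda:0", download_root=DECODER_CACHE)
StableDiffusionPipeline.from_pretrained(
    MODEL_ID, torch_dtype=torch.float16
).save_pretrained(MODEL_CACHE)

# Run time: load only from the local caches, as setup() does.
pipe = StableDiffusionPipeline.from_pretrained(
    MODEL_CACHE, local_files_only=True, torch_dtype=torch.float16
).to("cuda")
decoder = ConsistencyDecoder(device="cuda:0", download_root=DECODER_CACHE)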
