Skip to content

Commit

Permalink
Merge pull request #2 from daanelson/dan/cleanup
Browse files Browse the repository at this point in the history
now with resizing and proper download
  • Loading branch information
anotherjesse committed Jun 13, 2023
2 parents e4183e7 + f2fba46 commit 1194962
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
18 changes: 14 additions & 4 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,12 @@ def predict(
default=1,
),
image_resolution: int = Input(
description="Resolution of image (square)",
description="Resolution of image (smallest dimension)",
choices=[256, 512, 768],
default=512,
),
scheduler: str = Input(
default="DPMSolverMultistep",
default="DDIM",
choices=SCHEDULERS.keys(),
description="Choose a scheduler.",
),
Expand Down Expand Up @@ -211,13 +211,23 @@ def predict(
high_threshold=high_threshold,
)

scale = float(image_resolution) / (min(input_image.size))

def quick_rescale(dim, scale):
"""quick rescale to a multiple of 64, as per original controlnet"""
dim *= scale
return int(np.round(dim / 64.0)) * 64

width = quick_rescale(input_image.size[0], scale)
height = quick_rescale(input_image.size[1], scale)

generator = torch.Generator("cuda").manual_seed(seed)

outputs = pipe(
prompt,
input_image,
height=image_resolution,
width=image_resolution,
height=height,
width=width,
num_inference_steps=steps,
guidance_scale=scale,
eta=eta,
Expand Down
8 changes: 8 additions & 0 deletions script/download_weights
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import torch
from diffusers import ControlNetModel, StableDiffusionPipeline
from controlnet_aux import HEDdetector, OpenposeDetector, MLSDdetector, MidasDetector
from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
import torch

# append project directory to path so predict.py can be imported
sys.path.append(".")
Expand Down Expand Up @@ -53,3 +54,10 @@ MLSDdetector.from_pretrained("lllyasviel/ControlNet", cache_dir=PROCESSORS_CACHE
OpenposeDetector.from_pretrained("lllyasviel/Annotators", cache_dir=PROCESSORS_CACHE)

shutil.rmtree(TMP_CACHE)

if os.path.exists(SD15_WEIGHTS):
shutil.rmtree(SD15_WEIGHTS)
os.makedirs(SD15_WEIGHTS)

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.save_pretrained(SD15_WEIGHTS)

0 comments on commit 1194962

Please sign in to comment.