Commit d4d35d4

Update predict.py
1 parent 6ef7d21 commit d4d35d4

File tree

1 file changed: +72 −11 lines changed

predict.py

Lines changed: 72 additions & 11 deletions
@@ -1,6 +1,6 @@
 import json
 import os
-import re
+import requests
 import shutil
 import subprocess
 import time
@@ -31,6 +31,8 @@
 
 from dataset_and_utils import TokenEmbeddingsHandler
 
+from lora_diffusion import LoRAManager, monkeypatch_remove_lora
+
 SDXL_MODEL_CACHE = "./sdxl-cache"
 REFINER_MODEL_CACHE = "./refiner-cache"
 SAFETY_CACHE = "./safety-cache"
@@ -41,6 +43,7 @@
 )
 SAFETY_URL = "https://weights.replicate.delivery/default/sdxl/safety-1.0.tar"
 
+device = "cuda" if torch.cuda.is_available() else "cpu"
 
 class KarrasDPM:
     def from_config(config):
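This one-line fallback drives most of the remaining hunks: every hard-coded `.to("cuda")` in the file becomes `.to(device)`. A minimal sketch of the effect (only torch assumed):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    t = torch.zeros(1).to(device)  # lands on the GPU when available, CPU otherwise
    print(t.device)                # "cuda:0" or "cpu"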
@@ -65,8 +68,32 @@ def download_weights(url, dest):
     subprocess.check_call(["pget", "-x", url, dest])
     print("downloading took: ", time.time() - start)
 
+def url_local_fn(url):
+    return sha512(url.encode()).hexdigest() + ".safetensors"
 
 class Predictor(BasePredictor):
+
+
+    def download_lora(url):
+        # TODO: allow-list of domains
+
+        fn = url_local_fn(url)
+
+        if not os.path.exists(fn):
+            print("Downloading LoRA model... from", url)
+            # stream chunks of the file to disk
+            with requests.get(url, stream=True) as r:
+                r.raise_for_status()
+                with open(fn, "wb") as f:
+                    for chunk in r.iter_content(chunk_size=8192):
+                        f.write(chunk)
+
+        else:
+            print("Using disk cache...")
+
+        return fn
+
+
     def load_trained_weights(self, weights, pipe):
         local_weights_cache = "./trained-model"
         if not os.path.exists(local_weights_cache):
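The two helpers give each LoRA a content-addressed cache filename derived from its URL; `sha512` is presumably imported from `hashlib` elsewhere in the file, since the hunk context does not show it. Note also that `download_lora` is added inside `Predictor` without a `self` parameter, yet `set_lora` below calls it as a bare name; since a class body is not an enclosing scope for its methods, the function likely belongs at module level (or behind `@staticmethod`). A standalone sketch of the naming scheme:

    from hashlib import sha512

    def url_local_fn(url: str) -> str:
        # one deterministic filename per URL, so repeated requests hit the disk cache
        return sha512(url.encode()).hexdigest() + ".safetensors"

    name = url_local_fn("https://example.com/my-lora.safetensors")  # placeholder URL
    print(len(name))  # 140: a 128-hex-char digest plus ".safetensors"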
@@ -135,7 +162,7 @@ def load_trained_weights(self, weights, pipe):
                 cross_attention_dim=cross_attention_dim,
                 rank=name_rank_map[name],
             )
-            unet_lora_attn_procs[name] = module.to("cuda")
+            unet_lora_attn_procs[name] = module.to(device)
 
         unet.set_attn_processor(unet_lora_attn_procs)
         unet.load_state_dict(tensors, strict=False)
@@ -157,13 +184,14 @@ def setup(self, weights: Optional[Path] = None):
         """Load the model into memory to make running multiple predictions efficient"""
         start = time.time()
         self.tuned_model = False
+        self.lora_manager = None
 
         print("Loading safety checker...")
         if not os.path.exists(SAFETY_CACHE):
             download_weights(SAFETY_URL, SAFETY_CACHE)
         self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
             SAFETY_CACHE, torch_dtype=torch.float16
-        ).to("cuda")
+        ).to(device)
         self.feature_extractor = CLIPImageProcessor.from_pretrained(FEATURE_EXTRACTOR)
 
         if not os.path.exists(SDXL_MODEL_CACHE):
@@ -180,7 +208,7 @@ def setup(self, weights: Optional[Path] = None):
         if weights or os.path.exists("./trained-model"):
             self.load_trained_weights(weights, self.txt2img_pipe)
 
-        self.txt2img_pipe.to("cuda")
+        self.txt2img_pipe.to(device)
 
         print("Loading SDXL img2img pipeline...")
         self.img2img_pipe = StableDiffusionXLImg2ImgPipeline(
@@ -192,7 +220,7 @@ def setup(self, weights: Optional[Path] = None):
             unet=self.txt2img_pipe.unet,
             scheduler=self.txt2img_pipe.scheduler,
         )
-        self.img2img_pipe.to("cuda")
+        self.img2img_pipe.to(device)
 
         print("Loading SDXL inpaint pipeline...")
         self.inpaint_pipe = StableDiffusionXLInpaintPipeline(
@@ -204,7 +232,7 @@ def setup(self, weights: Optional[Path] = None):
             unet=self.txt2img_pipe.unet,
             scheduler=self.txt2img_pipe.scheduler,
         )
-        self.inpaint_pipe.to("cuda")
+        self.inpaint_pipe.to(device)
 
         print("Loading SDXL refiner pipeline...")
         # FIXME(ja): should the vae/text_encoder_2 be loaded from SDXL always?
@@ -224,7 +252,7 @@ def setup(self, weights: Optional[Path] = None):
             use_safetensors=True,
             variant="fp16",
         )
-        self.refiner.to("cuda")
+        self.refiner.to(device)
         print("setup took: ", time.time() - start)
         # self.txt2img_pipe.__class__.encode_prompt = new_encode_prompt
@@ -233,16 +261,34 @@ def load_image(self, path):
         return load_image("/tmp/image.png").convert("RGB")
 
     def run_safety_checker(self, image):
-        safety_checker_input = self.feature_extractor(image, return_tensors="pt").to(
-            "cuda"
-        )
+        safety_checker_input = self.feature_extractor(image, return_tensors="pt").to(device)
         np_image = [np.array(val) for val in image]
         image, has_nsfw_concept = self.safety_checker(
             images=np_image,
             clip_input=safety_checker_input.pixel_values.to(torch.float16),
         )
         return image, has_nsfw_concept
 
+    def set_lora(self, urllists: List[str], scales: List[float]):
+        assert len(urllists) == len(scales), "Number of LoRAs and scales must match."
+
+        merged_fn = url_local_fn(f"{'-'.join(urllists)}")
+
+        if self.loaded == merged_fn:
+            print("The requested LoRAs are loaded.")
+            assert self.lora_manager is not None
+        else:
+
+            st = time.time()
+            self.lora_manager = LoRAManager(
+                [download_lora(url) for url in urllists], self.pipe
+            )
+            self.loaded = merged_fn
+            print(f"merging time: {time.time() - st}")
+
+        self.lora_manager.tune(scales)
+
+
     @torch.inference_mode()
     def predict(
         self,
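`set_lora` keys its merge cache on the join of all requested URLs, so any different combination (or ordering) triggers a fresh `LoRAManager` merge. Two attributes it relies on, `self.loaded` and `self.pipe`, are not initialized anywhere in the hunks shown; `setup` only adds `self.lora_manager = None`. A hypothetical pair of `setup` lines that would satisfy the checks above (names assumed, not from this commit):

    # hypothetical setup() additions implied by set_lora()
    self.loaded = None               # cache key of the currently merged LoRA set
    self.pipe = self.txt2img_pipe    # pipeline handed to LoRAManager for patching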
@@ -315,6 +361,10 @@ def predict(
             description="Applies a watermark to enable determining if an image is generated in downstream applications. If you have other provisions for generating or deploying images safely, you can use this to disable watermarking.",
             default=True,
         ),
+        lora_urls: str = Input(
+            description="List of urls for safetensors of lora models, separated with | .",
+            default="",
+        ),
         lora_scale: float = Input(
             description="LoRA additive scale. Only applicable on trained models.",
             ge=0.0,
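The new `lora_urls` input packs several LoRA locations into one pipe-separated string. A hypothetical value showing the expected shape (URLs are placeholders):

    # two LoRAs in one request; whitespace around "|" is stripped at parse time
    lora_urls = "https://host/style.safetensors | https://host/face.safetensors"
    urls = [u.strip() for u in lora_urls.split("|")]
    # -> ["https://host/style.safetensors", "https://host/face.safetensors"]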
@@ -365,7 +415,18 @@ def predict(
             self.refiner.watermark = None
 
         pipe.scheduler = SCHEDULERS[scheduler].from_config(pipe.scheduler.config)
-        generator = torch.Generator("cuda").manual_seed(seed)
+        generator = torch.Generator(device).manual_seed(seed)
+
+        # check if LoRA urls are provided
+        if len(lora_urls) > 0:
+            lora_urls = [u.strip() for u in lora_urls.split("|")]
+            lora_scales = [float(s.strip()) for s in lora_scales.split("|")]
+            self.set_lora(lora_urls, lora_scales)
+            prompt = self.lora_manager.prompt(prompt)
+        else:
+            print("No LoRA models provided, using default model...")
+            monkeypatch_remove_lora(self.txt2img_pipe.unet)
+            monkeypatch_remove_lora(self.txt2img_pipe.text_encoder)
 
         common_args = {
             "prompt": [prompt] * num_outputs,
