 import json
 import os
-import re
+import requests
+from hashlib import sha512
 import shutil
 import subprocess
 import time
 
 from dataset_and_utils import TokenEmbeddingsHandler
 
+from lora_diffusion import LoRAManager, monkeypatch_remove_lora
+
 SDXL_MODEL_CACHE = "./sdxl-cache"
 REFINER_MODEL_CACHE = "./refiner-cache"
 SAFETY_CACHE = "./safety-cache"
 )
 SAFETY_URL = "https://weights.replicate.delivery/default/sdxl/safety-1.0.tar"
 
+# fall back to CPU when no GPU is available so the model can still load
+device = "cuda" if torch.cuda.is_available() else "cpu"
 
 class KarrasDPM:
     def from_config(config):
@@ -65,8 +68,32 @@ def download_weights(url, dest):
     subprocess.check_call(["pget", "-x", url, dest])
     print("downloading took: ", time.time() - start)
6770
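+# Hash the URL into a deterministic, filesystem-safe local filename so each
+# LoRA is downloaded once and reused from the working directory afterwards.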
+def url_local_fn(url):
+    return sha512(url.encode()).hexdigest() + ".safetensors"
 
 class Predictor(BasePredictor):
+
+    def download_lora(self, url):
+        # TODO: allow-list of domains
+        fn = url_local_fn(url)
+
+        if not os.path.exists(fn):
+            print("Downloading LoRA model... from", url)
+            # stream the file to disk in chunks rather than buffering it in memory
+            with requests.get(url, stream=True) as r:
+                r.raise_for_status()
+                with open(fn, "wb") as f:
+                    for chunk in r.iter_content(chunk_size=8192):
+                        f.write(chunk)
+        else:
+            print("Using disk cache...")
+
+        return fn
+
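+    # usage sketch (hypothetical URL): self.download_lora("https://example.com/style.safetensors")
+    # downloads on the first call and returns the cached sha512-named path afterwards
+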
     def load_trained_weights(self, weights, pipe):
         local_weights_cache = "./trained-model"
         if not os.path.exists(local_weights_cache):
@@ -135,7 +162,7 @@ def load_trained_weights(self, weights, pipe):
                 cross_attention_dim=cross_attention_dim,
                 rank=name_rank_map[name],
             )
-            unet_lora_attn_procs[name] = module.to("cuda")
+            unet_lora_attn_procs[name] = module.to(device)
 
         unet.set_attn_processor(unet_lora_attn_procs)
+        # tensors holds only the LoRA weights, so strict=False skips missing UNet keys
         unet.load_state_dict(tensors, strict=False)
@@ -157,13 +184,14 @@ def setup(self, weights: Optional[Path] = None):
157184 """Load the model into memory to make running multiple predictions efficient"""
158185 start = time .time ()
159186 self .tuned_model = False
187+ self .lora_manager = None
160188
161189 print ("Loading safety checker..." )
162190 if not os .path .exists (SAFETY_CACHE ):
163191 download_weights (SAFETY_URL , SAFETY_CACHE )
164192 self .safety_checker = StableDiffusionSafetyChecker .from_pretrained (
165193 SAFETY_CACHE , torch_dtype = torch .float16
166- ).to ("cuda" )
194+ ).to (device )
167195 self .feature_extractor = CLIPImageProcessor .from_pretrained (FEATURE_EXTRACTOR )
168196
         if not os.path.exists(SDXL_MODEL_CACHE):
@@ -180,7 +208,7 @@ def setup(self, weights: Optional[Path] = None):
         if weights or os.path.exists("./trained-model"):
             self.load_trained_weights(weights, self.txt2img_pipe)
 
-        self.txt2img_pipe.to("cuda")
+        self.txt2img_pipe.to(device)
 
         print("Loading SDXL img2img pipeline...")
         self.img2img_pipe = StableDiffusionXLImg2ImgPipeline(
@@ -192,7 +220,7 @@ def setup(self, weights: Optional[Path] = None):
             unet=self.txt2img_pipe.unet,
             scheduler=self.txt2img_pipe.scheduler,
         )
-        self.img2img_pipe.to("cuda")
+        self.img2img_pipe.to(device)
 
         print("Loading SDXL inpaint pipeline...")
         self.inpaint_pipe = StableDiffusionXLInpaintPipeline(
@@ -204,7 +232,7 @@ def setup(self, weights: Optional[Path] = None):
             unet=self.txt2img_pipe.unet,
             scheduler=self.txt2img_pipe.scheduler,
         )
-        self.inpaint_pipe.to("cuda")
+        self.inpaint_pipe.to(device)
 
         print("Loading SDXL refiner pipeline...")
         # FIXME(ja): should the vae/text_encoder_2 be loaded from SDXL always?
@@ -224,7 +252,7 @@ def setup(self, weights: Optional[Path] = None):
             use_safetensors=True,
             variant="fp16",
         )
-        self.refiner.to("cuda")
+        self.refiner.to(device)
         print("setup took: ", time.time() - start)
         # self.txt2img_pipe.__class__.encode_prompt = new_encode_prompt
@@ -233,16 +261,34 @@ def load_image(self, path):
         return load_image("/tmp/image.png").convert("RGB")
 
     def run_safety_checker(self, image):
-        safety_checker_input = self.feature_extractor(image, return_tensors="pt").to(
-            "cuda"
-        )
+        safety_checker_input = self.feature_extractor(image, return_tensors="pt").to(device)
         np_image = [np.array(val) for val in image]
         image, has_nsfw_concept = self.safety_checker(
             images=np_image,
             clip_input=safety_checker_input.pixel_values.to(torch.float16),
         )
         return image, has_nsfw_concept
 
+    def set_lora(self, urllists: List[str], scales: List[float]):
+        assert len(urllists) == len(scales), "Number of LoRAs and scales must match."
+
+        # cache key for this particular combination of LoRAs
+        merged_fn = url_local_fn("-".join(urllists))
+
+        if self.loaded == merged_fn:
+            print("The requested LoRAs are loaded.")
+            assert self.lora_manager is not None
+        else:
+            st = time.time()
+            self.lora_manager = LoRAManager(
+                [self.download_lora(url) for url in urllists], self.txt2img_pipe
+            )
+            self.loaded = merged_fn
+            print(f"merging time: {time.time() - st}")
+
+        self.lora_manager.tune(scales)
+
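+    # usage sketch (hypothetical values):
+    #   self.set_lora(["https://a/x.safetensors", "https://b/y.safetensors"], [0.8, 0.5])
+    # merges both adapters into the shared pipeline once; repeated calls with the
+    # same URL combination only re-tune the scales
+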
     @torch.inference_mode()
     def predict(
         self,
@@ -315,6 +361,10 @@ def predict(
             description="Applies a watermark to enable determining if an image is generated in downstream applications. If you have other provisions for generating or deploying images safely, you can use this to disable watermarking.",
             default=True,
         ),
+        lora_urls: str = Input(
+            description="List of URLs to LoRA safetensors files, separated with | .",
+            default="",
+        ),
+        lora_scales: str = Input(
+            description="List of scales for the LoRA models, separated with | .",
+            default="",
+        ),
         lora_scale: float = Input(
             description="LoRA additive scale. Only applicable on trained models.",
             ge=0.0,
@@ -365,7 +415,18 @@ def predict(
         self.refiner.watermark = None
 
         pipe.scheduler = SCHEDULERS[scheduler].from_config(pipe.scheduler.config)
-        generator = torch.Generator("cuda").manual_seed(seed)
+        generator = torch.Generator(device).manual_seed(seed)
+
+        # apply user-supplied LoRAs; otherwise strip any previously patched LoRA layers
+        if len(lora_urls) > 0:
+            lora_urls = [u.strip() for u in lora_urls.split("|")]
+            lora_scales = [float(s.strip()) for s in lora_scales.split("|")]
+            self.set_lora(lora_urls, lora_scales)
+            # let the LoRA manager add its trained token(s) to the prompt
+            prompt = self.lora_manager.prompt(prompt)
+        else:
+            print("No LoRA models provided, using default model...")
+            monkeypatch_remove_lora(self.txt2img_pipe.unet)
+            monkeypatch_remove_lora(self.txt2img_pipe.text_encoder)
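+        # example request values (hypothetical): lora_urls="https://a/x.safetensors|https://b/y.safetensors"
+        # paired with lora_scales="0.8|0.5" applies both adapters at those strengths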
 
         common_args = {
             "prompt": [prompt] * num_outputs,