import subprocess
import time
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
-
+from compel import Compel, ReturnedEmbeddingsType
import numpy as np
import torch
from cog import BasePredictor, Input, Path
@@ -226,7 +226,6 @@ def setup(self, weights: Optional[Path] = None):
        )
        self.refiner.to("cuda")
        print("setup took: ", time.time() - start)
-        # self.txt2img_pipe.__class__.encode_prompt = new_encode_prompt

    def load_image(self, path):
        shutil.copyfile(path, "/tmp/image.png")
@@ -366,36 +365,73 @@ def predict(

        pipe.scheduler = SCHEDULERS[scheduler].from_config(pipe.scheduler.config)
        generator = torch.Generator("cuda").manual_seed(seed)
-
+        # Old prompt-based args, disabled in favor of the Compel embeddings below
+        """
        common_args = {
            "prompt": [prompt] * num_outputs,
            "negative_prompt": [negative_prompt] * num_outputs,
            "guidance_scale": guidance_scale,
            "generator": generator,
            "num_inference_steps": num_inference_steps,
        }
+        """
+
+        # Compel for the base pipeline
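+        # SDXL conditions on two text encoders, so Compel gets both
+        # tokenizer/encoder pairs, uses the penultimate (non-normalized) hidden
+        # states as SDXL does, and pools only the second (OpenCLIP) encoder.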
+        compel_base = Compel(
+            tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
+            text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
+            returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
+            requires_pooled=[False, True],
+        )
+        conditioning, pooled = compel_base(prompt)
+        conditioning_neg, pooled_neg = (
+            compel_base(negative_prompt) if negative_prompt is not None else (None, None)
+        )
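+        # Sketch, per Compel's documentation: if the weighted positive and
+        # negative prompts tokenize to different lengths (e.g. one runs past 77
+        # tokens), the embedding tensors must be padded to matching shapes
+        # before the pipeline will accept them.
+        if conditioning_neg is not None:
+            [conditioning, conditioning_neg] = compel_base.pad_conditioning_tensors_to_same_length(
+                [conditioning, conditioning_neg]
+            )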
+
+        common_args = {
+            "prompt_embeds": conditioning,
+            "pooled_prompt_embeds": pooled,
+            "negative_prompt_embeds": conditioning_neg,
+            "negative_pooled_prompt_embeds": pooled_neg,
+            # The pipeline duplicates a single-prompt embedding per image, so
+            # num_images_per_prompt yields num_outputs images
+            "num_images_per_prompt": num_outputs,
+            "guidance_scale": guidance_scale,
+            "generator": generator,
+            "num_inference_steps": num_inference_steps,
+            # Only stop the base early when the refiner continues from this point
+            "denoising_end": high_noise_frac if refine == "expert_ensemble_refiner" else None,
+        }
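+        # With embeddings coming from Compel, prompts can now carry weights,
+        # e.g. "a photo of a (red sports car)1.3 at night" or "a forest++ at dawn".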

        if self.is_lora:
            sdxl_kwargs["cross_attention_kwargs"] = {"scale": lora_scale}

+        # Run the base SDXL pipeline
        output = pipe(**common_args, **sdxl_kwargs)
-
+
+        # Refiner
        if refine in ["expert_ensemble_refiner", "base_image_refiner"]:
+            # Compel for the refiner pipeline
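+            # The SDXL refiner ships only the second text encoder, so Compel
+            # gets a single tokenizer/encoder pair and the pooled embedding is
+            # always required.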
+            compel_refiner = Compel(
+                tokenizer=self.refiner.tokenizer_2,
+                text_encoder=self.refiner.text_encoder_2,
+                returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
+                requires_pooled=True,
+            )
+            conditioning, pooled = compel_refiner(prompt)
+            conditioning_neg, pooled_neg = (
+                compel_refiner(negative_prompt) if negative_prompt is not None else (None, None)
+            )
+
            refiner_kwargs = {
                "image": output.images,
            }

+            refine_args = {
+                "prompt_embeds": conditioning,
+                "pooled_prompt_embeds": pooled,
+                "negative_prompt_embeds": conditioning_neg,
+                "negative_pooled_prompt_embeds": pooled_neg,
+                "num_images_per_prompt": num_outputs,
+                "guidance_scale": guidance_scale,
+                "generator": generator,
+                # Keep the step count aligned with the base pass; denoising_start
+                # (set below) tells the refiner where in that schedule to pick up
+                "num_inference_steps": num_inference_steps,
+            }
+
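+            # Expert-ensemble refining splits one denoising schedule between the
+            # two models: the base handles [0, high_noise_frac) and the refiner,
+            # started at denoising_start, finishes [high_noise_frac, 1].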
            if refine == "expert_ensemble_refiner":
                refiner_kwargs["denoising_start"] = high_noise_frac
            if refine == "base_image_refiner" and refine_steps:
-                common_args["num_inference_steps"] = refine_steps
+                # base_image_refiner can run its own step count via refine_steps
+                refine_args["num_inference_steps"] = refine_steps
+
+            output = self.refiner(**refine_args, **refiner_kwargs)

-            output = self.refiner(**common_args, **refiner_kwargs)

+        # Restore the watermarker that was stashed when apply_watermark=False
        if not apply_watermark:
            pipe.watermark = watermark_cache
            self.refiner.watermark = watermark_cache

+        # NSFW check
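+        # run_safety_checker returns one flag per generated image, so NSFW
+        # outputs can be skipped individually below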
        _, has_nsfw_content = self.run_safety_checker(output.images)

        output_paths = []
@@ -407,9 +443,10 @@ def predict(
            output.images[i].save(output_path)
            output_paths.append(Path(output_path))

-        if len(output_paths) == 0:
-            raise Exception(
-                f"NSFW content detected. Try running it again, or try a different prompt."
-            )
+        # Exception removed: an all-NSFW batch now returns an empty list
+        # instead of failing the whole prediction.
+        # if len(output_paths) == 0:
+        #     raise Exception(
+        #         f"NSFW content detected. Try running it again, or try a different prompt."
+        #     )

        return output_paths