
Commit b6d9827

Commit message: added - to be tested locally

1 parent: ca9883b

File tree: .gitignore, cog.yaml, predict.py

3 files changed: +49 -11 lines

.gitignore

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-
+STATUS.md
 refiner-cache
 sdxl-cache
 safety-cache

cog.yaml

Lines changed: 1 addition & 0 deletions

@@ -23,6 +23,7 @@ build:
     - "fire==0.5.0"
     - "opencv-python>=4.1.0.25"
     - "mediapipe==0.10.2"
+    - "compel==2.0.2"

  run:
    - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.0.1/pget" && chmod +x /usr/local/bin/pget
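
For context, compel is a prompt-weighting library for diffusers pipelines: it parses emphasis syntax such as word++ or (phrase)1.3 into conditioning tensors that are passed to a pipeline in place of the raw prompt string. A minimal sketch of the basic, non-SDXL usage follows; the model id and prompt are illustrative assumptions, not taken from this repo.

# Minimal compel sketch for a plain Stable Diffusion pipeline.
# Model id and prompt are illustrative assumptions, not from this commit.
import torch
from compel import Compel
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

compel = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)

# "++" upweights a token; "(...)1.3" applies an explicit weight.
conditioning = compel("a cat++ playing with a (red ball)1.3")

# Pass the embeddings instead of a raw prompt string.
image = pipe(prompt_embeds=conditioning, num_inference_steps=30).images[0]
image.save("cat.png")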

predict.py

Lines changed: 47 additions & 10 deletions

@@ -5,7 +5,7 @@
 import subprocess
 import time
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
-
+from compel import Compel, ReturnedEmbeddingsType
 import numpy as np
 import torch
 from cog import BasePredictor, Input, Path
@@ -226,7 +226,6 @@ def setup(self, weights: Optional[Path] = None):
         )
         self.refiner.to("cuda")
         print("setup took: ", time.time() - start)
-        # self.txt2img_pipe.__class__.encode_prompt = new_encode_prompt

     def load_image(self, path):
         shutil.copyfile(path, "/tmp/image.png")
@@ -366,36 +365,73 @@ def predict(

         pipe.scheduler = SCHEDULERS[scheduler].from_config(pipe.scheduler.config)
         generator = torch.Generator("cuda").manual_seed(seed)
-
+        """
         common_args = {
             "prompt": [prompt] * num_outputs,
             "negative_prompt": [negative_prompt] * num_outputs,
             "guidance_scale": guidance_scale,
             "generator": generator,
             "num_inference_steps": num_inference_steps,
         }
+        """
+
+        # Compel for Base Pipeline
+        compel_base = Compel(tokenizer=[pipe.tokenizer, pipe.tokenizer_2], text_encoder=[pipe.text_encoder, pipe.text_encoder_2], returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, requires_pooled=[False, True])
+        conditioning, pooled = compel_base(prompt)
+        conditioning_neg, pooled_neg = compel_base(negative_prompt) if negative_prompt is not None else (None, None)
+
+        common_args = {
+            "prompt_embeds": [conditioning] * num_outputs,
+            "pooled_prompt_embeds": [pooled] * num_outputs,
+            "negative_prompt_embeds": [conditioning_neg] * num_outputs,
+            "negative_pooled_prompt_embeds": [pooled_neg] * num_outputs,
+            "guidance_scale": guidance_scale,
+            "generator": generator,
+            "num_inference_steps": num_inference_steps,
+            "denoising_end": high_noise_frac
+        }

         if self.is_lora:
             sdxl_kwargs["cross_attention_kwargs"] = {"scale": lora_scale}

+        ## START BASE PIPELINE
         output = pipe(**common_args, **sdxl_kwargs)
-
+
+        # Refiner
         if refine in ["expert_ensemble_refiner", "base_image_refiner"]:
+            # Compel for Refiner Pipeline
+            compel_refiner = Compel(tokenizer=self.refiner.tokenizer_2, text_encoder=self.refiner.text_encoder_2, returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, requires_pooled=True)
+            conditioning, pooled = compel_refiner(prompt)
+            conditioning_neg, pooled_neg = compel_refiner(negative_prompt) if negative_prompt is not None else (None, None)
+
             refiner_kwargs = {
                 "image": output.images,
             }

+            refine_args = {
+                "prompt_embeds": [conditioning] * num_outputs,
+                "pooled_prompt_embeds": [pooled] * num_outputs,
+                "negative_prompt_embeds": [conditioning_neg] * num_outputs,
+                "negative_pooled_prompt_embeds": [pooled_neg] * num_outputs,
+                "guidance_scale": guidance_scale,
+                "generator": generator,
+                "denoising_end": high_noise_frac
+            }
+
             if refine == "expert_ensemble_refiner":
                 refiner_kwargs["denoising_start"] = high_noise_frac
             if refine == "base_image_refiner" and refine_steps:
-                common_args["num_inference_steps"] = refine_steps
+                refine_args["num_inference_steps"] = refine_steps
+
+            output = self.refiner(**refine_args, **refiner_kwargs)

-            output = self.refiner(**common_args, **refiner_kwargs)

+        # Check for Watermark
         if not apply_watermark:
             pipe.watermark = watermark_cache
             self.refiner.watermark = watermark_cache

+        # NSFW Check
         _, has_nsfw_content = self.run_safety_checker(output.images)

         output_paths = []
@@ -407,9 +443,10 @@ def predict(
             output.images[i].save(output_path)
             output_paths.append(Path(output_path))

-        if len(output_paths) == 0:
-            raise Exception(
-                f"NSFW content detected. Try running it again, or try a different prompt."
-            )
+        # Remove exception, get the content the same
+        # if len(output_paths) == 0:
+        #     raise Exception(
+        #         f"NSFW content detected. Try running it again, or try a different prompt."
+        #     )

         return output_paths
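
Taken together, the predict.py changes route prompting through compel embeddings instead of raw prompt strings, for both the SDXL base pipeline and the refiner. A self-contained sketch of the same pattern against the stock diffusers SDXL checkpoints; model ids, the prompt, the split fraction, and the single-image batch are illustrative assumptions, not this repo's code.

# Standalone sketch of the compel + SDXL base/refiner pattern used above.
import torch
from compel import Compel, ReturnedEmbeddingsType
from diffusers import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline

base = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16
).to("cuda")

prompt = "an astronaut++ riding a (green horse)1.2"  # illustrative
high_noise_frac = 0.8  # illustrative base/refiner split point

# The base pipeline has two text encoders; only the second returns pooled embeds.
compel_base = Compel(
    tokenizer=[base.tokenizer, base.tokenizer_2],
    text_encoder=[base.text_encoder, base.text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True],
)
conditioning, pooled = compel_base(prompt)

# Run the base pipeline up to the split point and hand off latents.
latent = base(
    prompt_embeds=conditioning,
    pooled_prompt_embeds=pooled,
    num_inference_steps=40,
    denoising_end=high_noise_frac,
    output_type="latent",
).images

# The refiner only carries the second tokenizer/text-encoder pair.
compel_refiner = Compel(
    tokenizer=refiner.tokenizer_2,
    text_encoder=refiner.text_encoder_2,
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=True,
)
conditioning_ref, pooled_ref = compel_refiner(prompt)

image = refiner(
    prompt_embeds=conditioning_ref,
    pooled_prompt_embeds=pooled_ref,
    image=latent,
    num_inference_steps=40,
    denoising_start=high_noise_frac,
).images[0]
image.save("astronaut.png")

One caveat on the committed code: diffusers expects prompt_embeds and its companions as batched tensors, so the list multiplication ([conditioning] * num_outputs) will likely need to become a tensor repeat, or num_images_per_prompt=num_outputs, once this is tested locally.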
