Skip to content

Commit e5a2641

Browse files
authored
watermark optional (replicate#1)
* watermark optional
1 parent 8d10402 commit e5a2641

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

cog.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ build:
1111
- "libsm6"
1212
- "libxext6"
1313
python_packages:
14-
- "git+https://github.com/huggingface/diffusers.git@34abee090750d22fef357bfb1ffd564c961b9e1d"
14+
- "diffusers==0.19.3"
1515
- "torch==2.0.1"
1616
- "transformers==4.31.0"
1717
- "invisible-watermark==0.2.0"

predict.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,10 @@ def predict(
249249
description="for base_image_refiner, the number of steps to refine, defaults to num_inference_steps",
250250
default=None,
251251
),
252+
apply_watermark: bool = Input(
253+
description="Applies a watermark to enable determining if an image is generated in downstream applications. If you have other provisions for generating or deploying images safely, you can use this to disable watermarking.",
254+
default=True
255+
)
252256
) -> List[Path]:
253257
"""Run a single prediction on the model"""
254258
if seed is None:
@@ -284,6 +288,12 @@ def predict(
284288
elif refine == "base_image_refiner":
285289
sdxl_kwargs["output_type"] = "latent"
286290

291+
if not apply_watermark:
292+
# toggles watermark for this prediction
293+
watermark_cache = pipe.watermark
294+
pipe.watermark = None
295+
self.refiner.watermark = None
296+
287297
pipe.scheduler = SCHEDULERS[scheduler].from_config(pipe.scheduler.config)
288298
generator = torch.Generator("cuda").manual_seed(seed)
289299

@@ -307,7 +317,11 @@ def predict(
307317
common_args["num_inference_steps"] = refine_steps
308318

309319
output = self.refiner(**common_args, **refiner_kwargs)
310-
320+
321+
if not apply_watermark:
322+
pipe.watermark = watermark_cache
323+
self.refiner.watermark = watermark_cache
324+
311325
_, has_nsfw_content = self.run_safety_checker(output.images)
312326

313327
output_paths = []

0 commit comments

Comments
 (0)