In [8]:
import os
import torch 
import sys
import numpy as np
device = torch.device("cuda")
sys.path.append('src')
from omegaconf import OmegaConf 
from diffusers import StableDiffusionPipeline 
from utils_model import load_model_from_config 
import random
import json
seed = 42

np.random.seed(seed)
random.seed(seed)


ldm_config = "ckpt/v2-inference.yaml"
ldm_ckpt = "ckpt/v2-1_512-ema-pruned.ckpt"

In [2]:
from pycocotools.coco import COCO
import random

ann_path = 'coco/annotations/captions_train2014.json'
coco = COCO(ann_path)

img_ids = coco.getImgIds()
random.seed(42)
random.shuffle(img_ids)


captions = []
for img_id in img_ids[:16]:  
    ann_ids = coco.getAnnIds(imgIds=img_id)
    anns = coco.loadAnns(ann_ids)
    if anns:
        captions.append(anns[0]['caption'])  

print(f"Total images selected: {len(captions)}")
print("Example captions:")
for i, cap in enumerate(captions[:5]):
    print(f"{i+1}: {cap}")

loading annotations into memory...
Done (t=0.56s)
creating index...
index created!
Total images selected: 16
Example captions:
1: A large blue clock tower above an old brick building.

2: Large group of stuffed animals sitting on top of a red bed. 
3: People swimming in the ocean, one is surfing.
4: a woman is sitting up in her bed 
5: A large truck made for hauling loads is stopped at a stop sign.


In [3]:
print(f'>>> Building LDM model with config {ldm_config} and weights from {ldm_ckpt}...')
config = OmegaConf.load(f"{ldm_config}")
ldm_ae = load_model_from_config(config, ldm_ckpt)
ldm_aef = ldm_ae.first_stage_model
ldm_aef.eval()

# loading the fine-tuned decoder weights
state_dict = torch.load("ckpt/sd2_decoder.pth")
unexpected_keys = ldm_aef.load_state_dict(state_dict, strict=False)
print(unexpected_keys)
print("you should check that the decoder keys are correctly matched")

# loading the pipeline, and replacing the decode function of the pipe
model = "stabilityai/stable-diffusion-2"
pipe = StableDiffusionPipeline.from_pretrained(model).to(device)


>>> Building LDM model with config ckpt/v2-inference.yaml and weights from ckpt/v2-1_512-ema-pruned.ckpt...
Loading model from ckpt/v2-1_512-ema-pruned.ckpt


  pl_sd = torch.load(ckpt, map_location="cpu")


Global Step: 220000
No module 'xformers'. Proceeding without it.
LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 865.91 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels


  state_dict = torch.load("ckpt/sd2_decoder.pth")
Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: 
```
pip install accelerate
```
.


_IncompatibleKeys(missing_keys=['encoder.conv_in.weight', 'encoder.conv_in.bias', 'encoder.down.0.block.0.norm1.weight', 'encoder.down.0.block.0.norm1.bias', 'encoder.down.0.block.0.conv1.weight', 'encoder.down.0.block.0.conv1.bias', 'encoder.down.0.block.0.norm2.weight', 'encoder.down.0.block.0.norm2.bias', 'encoder.down.0.block.0.conv2.weight', 'encoder.down.0.block.0.conv2.bias', 'encoder.down.0.block.1.norm1.weight', 'encoder.down.0.block.1.norm1.bias', 'encoder.down.0.block.1.conv1.weight', 'encoder.down.0.block.1.conv1.bias', 'encoder.down.0.block.1.norm2.weight', 'encoder.down.0.block.1.norm2.bias', 'encoder.down.0.block.1.conv2.weight', 'encoder.down.0.block.1.conv2.bias', 'encoder.down.0.downsample.conv.weight', 'encoder.down.0.downsample.conv.bias', 'encoder.down.1.block.0.norm1.weight', 'encoder.down.1.block.0.norm1.bias', 'encoder.down.1.block.0.conv1.weight', 'encoder.down.1.block.0.conv1.bias', 'encoder.down.1.block.0.norm2.weight', 'encoder.down.1.block.0.norm2.bias', 'e

  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(
Loading pipeline components...: 100%|██████████| 6/6 [00:08<00:00,  1.47s/it]


In [4]:
# prompt = "the cat drinks water."

# generator = torch.manual_seed(seed)
# img_orig = pipe(prompt,generator = generator).images[0]
# img_orig.save("cat_original.png")

# pipe.vae.decode = (lambda x,  *args, **kwargs: ldm_aef.decode(x).unsqueeze(0))



# generator = torch.manual_seed(seed)
# img = pipe(prompt,generator = generator).images[0]
# img.save("cat_watermarked.png")

In [5]:
# generator = torch.manual_seed(seed)
# prompts = [ "a pig","a dragon"]
# results = pipe(prompts, generator=generator).images

# for i, img in enumerate(results):
#     img.save(f"output_{i}.png")

In [None]:

output_dir = "test"
os.makedirs(f"{output_dir}/original", exist_ok=True)
os.makedirs(f"{output_dir}/watermarked", exist_ok=True)

In [None]:
batch_size = 4

dataset_records = []
original_decode = pipe.vae.decode
for idx in range(0, len(captions), batch_size):
    batch_prompts = captions[idx:idx+batch_size]
    generator = torch.Generator(device="cuda").manual_seed(seed + idx)
    pipe.vae.decode = original_decode  
    image_origs = pipe(batch_prompts, generator=generator).images

    orig_paths = []
    for i, img in enumerate(image_origs):
        orig_path = f"{output_dir}/original/img_{idx + i:04d}.png"
        img.save(orig_path)
        orig_paths.append(orig_path)


    generator = torch.Generator(device="cuda").manual_seed(seed + idx)
    pipe.vae.decode = lambda x, *args, **kwargs: ldm_aef.decode(x).unsqueeze(0)
    image_watermarked = pipe(batch_prompts, generator=generator).images

    for i, img in enumerate(image_origs):
        wm_path = f"{output_dir}/watermarked/img_{idx + i:04d}.png"
        img.save(wm_path)
    
        dataset_records.append({
            "id": idx + i,
            "prompt": batch_prompts[i],
            "original": orig_paths[i],
            "watermarked": wm_path
        })
    



100%|██████████| 50/50 [01:20<00:00,  1.62s/it]
100%|██████████| 50/50 [01:22<00:00,  1.65s/it]
100%|██████████| 50/50 [01:22<00:00,  1.65s/it]
100%|██████████| 50/50 [01:22<00:00,  1.66s/it]
100%|██████████| 50/50 [01:22<00:00,  1.66s/it]
100%|██████████| 50/50 [01:22<00:00,  1.65s/it]
100%|██████████| 50/50 [01:22<00:00,  1.65s/it]
100%|██████████| 50/50 [01:22<00:00,  1.65s/it]


NameError: name 'json' is not defined

In [None]:
with open(f"{output_dir}/metadata.json", "w") as f:
    json.dump(dataset_records, f, indent=2)

print(f"✅ Done! Total samples: {len(dataset_records)}")

✅ Done! Total samples: 16
