In [None]:
import os
import json
import time
import hashlib
import requests
from concurrent.futures import ThreadPoolExecutor
from PIL import Image
from io import BytesIO
import base64
import threading

# === Config ===
# Endpoint that receives the generation requests, and the identifier sent
# with each request in the X-Client-ID header.
CONTROLLER = "http://127.0.0.1:9000/generate"
CLIENT_ID = "my-client-1"
# Directory that receives the saved image / metadata file pairs.
SAVE_DIR = "outputs"
PROMPT_FILE = "prompts.txt"  # optional: use instead of PROMPTS
NUM_WORKERS = 4  # Number of parallel requests
MAX_IN_FLIGHT = NUM_WORKERS * 2  # Allow some buffer

# === Read prompts ===
# One prompt per non-blank line of PROMPT_FILE when that path exists;
# otherwise fall back to a small built-in list repeated four times.
if not os.path.exists(PROMPT_FILE):
    PROMPTS = ["a cat", "a dog", "a castle", "a dragon"] * 4
else:
    with open(PROMPT_FILE, encoding='utf-8') as f:
        PROMPTS = [text for text in map(str.strip, f) if text]

# === Ensure save dir ===
os.makedirs(SAVE_DIR, exist_ok=True)

# === Save function ===
def save_image_and_metadata(img_b64, metadata, index):
    """Write one generated image and its metadata side by side in SAVE_DIR.

    Both files share the stem ``<index>_<digest>``. The index is zero-padded
    to the number of digits in ``len(PROMPTS)``; the digest is the first
    8 hex characters of the SHA-1 of ``metadata`` serialized with sorted keys.

    Args:
        img_b64: Base64-encoded image data. Everything up to and including
            the last comma (e.g. a ``data:...;base64,`` prefix) is discarded
            before decoding.
        metadata: JSON-serializable dict describing the image. It is not
            modified; the timestamp is added to a copy.
        index: Position of the prompt, used as the filename prefix.

    Raises:
        Whatever base64 decoding, PIL, ``json`` or the filesystem raise on
        bad input; nothing is caught here.
    """
    prefix = f"{index:0{len(str(len(PROMPTS)))}}"

    # Hash the metadata as received (before the client timestamp is added)
    # so the digest depends only on what the server sent.
    meta_str = json.dumps(metadata, sort_keys=True)
    digest = hashlib.sha1(meta_str.encode()).hexdigest()[:8]

    img_filename = f"{prefix}_{digest}.png"
    json_filename = f"{prefix}_{digest}.json"

    # Decode and save image
    payload = img_b64.rpartition(',')[2]
    image = Image.open(BytesIO(base64.b64decode(payload)))
    image.save(os.path.join(SAVE_DIR, img_filename))

    # Save metadata. Build a new dict instead of writing into the caller's
    # object, so the response data is left untouched.
    record = {
        **metadata,
        "client_timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
    }
    # Explicit encoding keeps the output independent of the platform locale.
    with open(os.path.join(SAVE_DIR, json_filename), "w", encoding="utf-8") as f:
        json.dump(record, f, indent=2)

    print(f"Saved: {img_filename}, {json_filename}")

# === Request function with semaphore ===
# Caps how many calls to send() may be inside the request section at once.
semaphore = threading.Semaphore(MAX_IN_FLIGHT)


def send(index, prompt, width, height):
    """Request one image from CONTROLLER and save the result.

    Posts the prompt and dimensions as JSON with the client id header. On a
    successful response that carries both "image" and "metadata", the pair is
    passed to save_image_and_metadata. Any failure is reported with print()
    and swallowed, so one bad prompt does not abort the rest of the batch.

    Args:
        index: Position of the prompt; forwarded as the filename index and
            shown (1-based) in error messages.
        prompt: Text prompt to send.
        width: Requested image width.
        height: Requested image height.
    """
    payload = {"prompt": prompt, "width": width, "height": height}
    with semaphore:
        try:
            response = requests.post(
                CONTROLLER,
                json=payload,
                headers={"X-Client-ID": CLIENT_ID},
                timeout=300,
            )
            response.raise_for_status()
            body = response.json()
            # Guard clause: a reply lacking either field cannot be saved.
            if "image" not in body or "metadata" not in body:
                print(f"Failed: Missing image or metadata for prompt: {prompt}")
                return
            save_image_and_metadata(body["image"], body["metadata"], index)
        except Exception as e:
            print(f"Error at line {index+1} ('{prompt}'): {e}")

# === Main Execution ===
if __name__ == "__main__":
    # Submit one 1024x1024 request per prompt to a pool of NUM_WORKERS
    # threads, wait for all of them, then report the wall-clock time.
    start = time.time()
    with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
        futures = []
        for idx, p in enumerate(PROMPTS):
            futures.append(executor.submit(send, idx, p, 1024, 1024))
        # Block on each future in submission order.
        for pending in futures:
            pending.result()
    elapsed = round(time.time() - start, 2)
    print("Done in", elapsed, "seconds.")

Saved: 0000_70885e8c.png, 0000_70885e8c.json
Saved: 0001_a5dc8d0a.png, 0001_a5dc8d0a.json
Saved: 0003_cfabfebb.png, 0003_cfabfebb.json
Saved: 0002_659558cd.png, 0002_659558cd.json
Saved: 0004_68a602fa.png, 0004_68a602fa.json
Saved: 0005_13bf0c23.png, 0005_13bf0c23.json
Saved: 0007_d1dc5070.png, 0007_d1dc5070.json
Saved: 0006_b45ff736.png, 0006_b45ff736.json
Saved: 0008_ef1309da.png, 0008_ef1309da.json
Saved: 0009_fe68c53b.png, 0009_fe68c53b.json
Saved: 0010_34e1afe6.png, 0010_34e1afe6.json
Saved: 0011_74d35a53.png, 0011_74d35a53.json
Saved: 0012_8675a082.png, 0012_8675a082.json
Saved: 0013_6ea7c167.png, 0013_6ea7c167.json
Saved: 0014_6b2e07a1.png, 0014_6b2e07a1.json
Saved: 0015_a32dc61e.png, 0015_a32dc61e.json
Done in 95.19 seconds.
