In [2]:
import json
import uuid
import os
from PIL import Image
from tqdm import tqdm

import mlx.core as mx

from stable_diffusion import StableDiffusion

# Generation settings: read by generate_image() and stored verbatim in the
# experiment log by archive_image().
args = dict(
    prompt="low resolution, 2d game tile, grass, 16x16, pixel art",
    n_images=1,  # number of latents generated (and images decoded)
    steps=20,  # passed to generate_latents as num_steps
    cfg=7.5,  # passed to generate_latents as cfg_weight
    negative_prompt="high resolution, photo-realistic image",
    n_rows=1,  # rows in the output grid image
    decoding_batch_size=1,  # latents decoded per sd.decode call
)

def _load_log(log_file, run_id):
    """Return the list of entries stored in *log_file*.

    A missing or blank file yields ``[]``. Content that is not valid JSON, or
    is JSON but not a list, is moved to ``<log_file>.corrupt-<run_id>`` rather
    than being silently overwritten, and ``[]`` is returned.
    """
    try:
        with open(log_file, "r", encoding="utf-8") as file:
            text = file.read()
    except FileNotFoundError:
        return []

    if not text.strip():
        return []

    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        data = None

    if isinstance(data, list):
        return data

    # Keep the unreadable log around for manual recovery instead of
    # discarding the whole experiment history.
    os.replace(log_file, f"{log_file}.corrupt-{run_id}")
    return []


def archive_image(args, image, *, image_dir_path="images", log_file="experiments_log.json"):
    """Save *image* under a fresh UUID and record *args* in a JSON experiment log.

    The image is written to ``<image_dir_path>/<uuid>.png`` and the entry
    ``{"uuid": <uuid>, "args": <args>}`` is appended to the JSON list in
    *log_file*.

    Args:
        args: JSON-serialisable settings to record alongside the image.
        image: Object with a ``save(path)`` method (e.g. a PIL image).
        image_dir_path: Directory receiving the image; created if missing.
        log_file: Path of the JSON log file (a list of entries).

    Returns:
        The hex UUID used both as the image file name and as the log key.

    Raises:
        Whatever ``image.save`` or ``json.dump`` raise; in that case the
        existing log file is left unchanged.
    """
    run_id = uuid.uuid4().hex

    os.makedirs(image_dir_path, exist_ok=True)
    # Save the image first so the log never references an image that failed
    # to be written.
    image.save(os.path.join(image_dir_path, f"{run_id}.png"))

    data = _load_log(log_file, run_id)
    data.append({
        'uuid': run_id,
        'args': args,
    })

    # Write to a temporary file and atomically swap it in: a failure or
    # interruption mid-dump can no longer truncate the existing log.
    tmp_path = f"{log_file}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as file:
        json.dump(data, file)
    os.replace(tmp_path, log_file)

    return run_id

def generate_image(params=None, sd=None) -> Image.Image:
    """Run Stable Diffusion and tile the decoded images into one grid image.

    Args:
        params: Settings dict with the keys "prompt", "n_images", "steps",
            "cfg", "negative_prompt", "n_rows" and "decoding_batch_size".
            Defaults to the module-level ``args``.
        sd: Optional ``StableDiffusion`` instance to reuse across calls; a new
            one is constructed when omitted.

    Returns:
        A PIL image with the decoded images laid out in ``n_rows`` rows, each
        image padded by 8 pixels on every side.

    Raises:
        ValueError: if ``n_images`` is not a positive multiple of ``n_rows``.
        RuntimeError: if ``generate_latents`` yields no latents.
    """
    if params is None:
        params = args
    n_images = params["n_images"]
    n_rows = params["n_rows"]
    batch_size = params["decoding_batch_size"]

    # Fail before the expensive diffusion loop if the grid reshape below
    # cannot succeed.
    if n_images < 1 or n_rows < 1 or n_images % n_rows != 0:
        raise ValueError(
            f"n_images ({n_images}) must be a positive multiple of n_rows ({n_rows})"
        )

    if sd is None:
        sd = StableDiffusion()

    # Generate the latent vectors using diffusion
    latents = sd.generate_latents(
        params["prompt"],
        n_images=n_images,
        cfg_weight=params["cfg"],
        num_steps=params["steps"],
        negative_text=params["negative_prompt"],
    )

    # The original called mx.simplify twice per step; once is enough.
    # NOTE(review): mx.simplify is absent from newer MLX releases, so it is
    # only called when available — confirm it is purely a graph optimisation
    # in the MLX version in use.
    simplify = getattr(mx, "simplify", None)
    x_t = None
    for x_t in tqdm(latents, total=params["steps"]):
        if simplify is not None:
            simplify(x_t)
        mx.eval(x_t)
    if x_t is None:
        raise RuntimeError("generate_latents yielded no latents; check 'steps'")

    # Decode the final latents into images, batch_size latents per call,
    # forcing evaluation of each batch as it is produced.
    decoded = []
    for start in tqdm(range(0, n_images, batch_size)):
        decoded.append(sd.decode(x_t[start : start + batch_size]))
        mx.eval(decoded[-1])

    # Arrange them on an n_rows x n_cols grid with 8 px padding per image.
    x = mx.concatenate(decoded, axis=0)
    x = mx.pad(x, [(0, 0), (8, 8), (8, 8), (0, 0)])
    B, H, W, C = x.shape
    n_cols = B // n_rows
    x = x.reshape(n_rows, n_cols, H, W, C).transpose(0, 2, 1, 3, 4)
    x = x.reshape(n_rows * H, n_cols * W, C)
    # assumes decoded pixel values lie in [0, 1] — TODO confirm against sd.decode
    x = (x * 255).astype(mx.uint8)

    # Convert to a PIL image; persisting it is left to the caller.
    return Image.fromarray(x.__array__())

# Generate with the module-level settings, then store the image and its
# settings in the experiment archive.
image = generate_image()
archive_image(args=args, image=image)


100%|██████████| 20/20 [01:46<00:00,  5.30s/it]
100%|██████████| 1/1 [00:02<00:00,  2.39s/it]
