In [None]:
from diffusers import StableDiffusionPipeline
import torch
from torchvision.transforms import v2

import os

from images import load_original_images, load_generated_images, generate_and_save_images
from unet import load_unet, save_unet, train_unet, save_train_loss

In [None]:
# Load the SD-Turbo text-to-image pipeline (fp16 weight variant, half-precision
# tensors) and move it onto the CUDA device
pipeline = StableDiffusionPipeline.from_pretrained(
    "stabilityai/sd-turbo",
    variant="fp16",
    torch_dtype=torch.float16,
)
pipeline.to("cuda")

In [None]:
def transform(examples: dict) -> dict:
    """
    Formatting transform for a text-to-image dataset.

    Preprocesses every image in the batch and tokenizes the accompanying text
    with the pipeline's tokenizer.

    Args:
        examples (`dict`):
            Batch of examples. Must provide an `"image"` column and a `"text"`
            column.

    Returns:
        `dict`: Batch of transformed examples with two entries:
        `"pixel_values"`, a list with one preprocessed image tensor per input
        image, and `"input_ids"`, a tensor of token ids with one row per text,
        padded/truncated to the tokenizer's maximum length.
    """

    # Function to preprocess images before VAE encoding
    preprocess = v2.Compose([
        v2.Resize((512, 512)),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        # (x - 0.5) / 0.5: shifts values scaled to [0, 1] into [-1, 1]
        v2.Normalize([0.5], [0.5]),
    ])

    # Preprocess images
    pixel_values = [preprocess(image) for image in examples["image"]]

    # Tokenize text. Every row is padded/truncated to the tokenizer's maximum
    # length: without this, texts that tokenize to different lengths cannot be
    # stacked into a single "pt" tensor, and over-long texts are not cut off.
    tokenizer = pipeline.tokenizer
    input_ids = tokenizer(
        examples["text"],
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )["input_ids"]

    return {"pixel_values": pixel_values, "input_ids": input_ids}

In [None]:
# Text prompts used as pipeline inputs, keyed by a short prompt name
prompts = dict(
    main="a photo capturing student life on campus at the University of Toronto",
    similar="a photo capturing student life on campus at the University of Waterloo",
    different="a wide shot of Santa Monica Beach",
)

# Number of images to generate for each prompt name in every generation
num_images = dict(main=75, similar=25, different=25)

In [None]:
# Number of train -> generate cycles to run
NUM_GENERATIONS = 20

# Generation loop: every generation trains the UNet on a dataset, saves the
# loss and the UNet, generates images for each prompt and archives the results
for generation in range(NUM_GENERATIONS):
    print(f"Generation: {generation}")

    # The first generation loads the original dataset; every later generation
    # loads the "main" images saved under the previous generation's index
    print(" - Loading Dataset")
    if generation == 0:
        dataset = load_original_images(prompts, "main")
    else:
        dataset = load_generated_images(prompts, "main", generation - 1)
    dataset.set_transform(transform)

    # Load UNet from previous generation (nothing to load for the first one)
    if generation > 0:
        print(" - Loading UNet")
        load_unet(pipeline, generation - 1)

    # Training loop
    print(" - Training UNet")
    train_loss = train_unet(pipeline, dataset)

    # Save training loss
    print(" - Saving Train Loss")
    save_train_loss(train_loss, generation)

    # Save UNet state
    print(" - Saving UNet")
    save_unet(pipeline, generation)

    # Generate and save images
    print(" - Generating and Saving Images")
    for prompt_name in prompts:
        generate_and_save_images(pipeline, prompts, prompt_name, generation, num_images[prompt_name])

    # Compress all data from this generation. `generation` is an int produced
    # by range(), so interpolating it into the shell command is safe.
    print(" - Compressing")
    status = os.system(f"tar -C data -czf data/gen_{generation}.tar.gz gen_{generation}")

    # A failed archive is reported instead of silently ignored; it is kept
    # non-fatal so that a compression problem does not abort the whole run
    if status != 0:
        print(f" - WARNING: compressing gen_{generation} failed (status {status})")