In [None]:
!pip install diffusers transformers accelerate torch --quiet

In [None]:
import os
import re
import json
import zipfile
import torch
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler

def parse_captions_file(filename):
    """
    Parse a text file made of numbered caption groups.

    Each group starts on a line that begins with a number and a dot, and is
    followed by five short captions, one per line:

        1. A long caption about the subject (for metadata only)
        A small caption 1
        A small caption 2
        A small caption 3
        A small caption 4
        A small caption 5

    Blank lines are ignored. A group with fewer than six non-blank lines is
    reported on stdout and skipped; lines after the sixth are dropped.

    Args:
        filename: Path of the UTF-8 text file to read.

    Returns:
        A list of dictionaries with keys:
          - "long_caption": the group's first line with its leading
            "<number>." removed
          - "small_captions": list of the 5 small captions that follow it
            (used for image generation)
    """
    with open(filename, "r", encoding="utf-8") as f:
        text = f.read()

    # A group header is only recognised at the START of a line. Without the
    # anchor, any "<digits>. " inside a caption (e.g. a sentence ending in
    # "1990.") splits the group, and a header such as "10. " is additionally
    # split before "0. ".
    groups_raw = re.split(r'^(?=[ \t]*\d+\.\s)', text, flags=re.MULTILINE)
    groups_raw = [group.strip() for group in groups_raw if group.strip()]

    groups = []
    for group in groups_raw:
        lines = [line.strip() for line in group.split("\n") if line.strip()]
        if len(lines) < 6:
            print("Skipping a group (not enough lines):", lines)
            continue

        # Remove only a leading "<number>." prefix; dots that appear later in
        # the caption, or in a first line without a number, are kept.
        long_caption = re.sub(r'^\d+\.\s*', '', lines[0], count=1)

        groups.append({
            "long_caption": long_caption,
            "small_captions": lines[1:6]
        })

    return groups


In [None]:
from google.colab import files


# Ask the user to upload the captions file.
# NOTE(review): assumes files.upload() returns a dict keyed by the name each
# uploaded file was saved under in the working directory; confirm.
uploaded = files.upload()

captions_file = "captions.txt"
if not os.path.exists(captions_file) and len(uploaded) == 1:
    # The single uploaded file may have been stored under a different name
    # (for example "captions (1).txt"); use it instead of giving up.
    captions_file = next(iter(uploaded))

if not os.path.exists(captions_file):
    # Stop here: later cells need `caption_groups`, and continuing would
    # either raise a NameError or silently reuse groups from an earlier run.
    raise FileNotFoundError(f"{captions_file} not found.")

caption_groups = parse_captions_file(captions_file)
print(f"Found {len(caption_groups)} caption groups.")


In [None]:
# Generation settings used for every image.
num_inference_steps = 100
guidance_scale = 7.5
width, height = 512, 512
seed = 42

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load half-precision weights only when running on the GPU. On the CPU
# fallback the pipeline is loaded in float32, since float16 inference on CPU
# is commonly unsupported.
torch_dtype = torch.float16 if device == "cuda" else torch.float32

model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
pipe = pipe.to(device)
# Replace the pipeline's default scheduler with LMS, reusing its configuration.
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
# One seeded generator shared by all images: results depend on generation order.
generator = torch.Generator(device).manual_seed(seed)
negative_prompt = "blurry, oversaturated, low resolution, deformed"


# Folder that receives one PNG per caption group.
output_folder = "generated_images"
if not os.path.exists(output_folder):
    os.makedirs(output_folder)
    print(f"Created folder: {output_folder}")

# One metadata line per (image, caption) pair:
#   "<image file name>#<caption index> <caption>"
metadata_lines = []
for idx, group in enumerate(caption_groups):
    small_captions = group["small_captions"]
    # All small captions of the group are joined into a single prompt.
    prompt = ", ".join(small_captions)
    print(f"Generating image {idx} with prompt: {prompt}")

    result = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
        width=width,
        height=height
    )
    image = result.images[0]
    image_filename = os.path.join(output_folder, f"generated_image_{idx}.png")
    image.save(image_filename)
    print(f"Saved image {idx} as {image_filename}")

    # Reference the image under the name it was actually saved with. The
    # previous code appended ".jpg" to the ".png" name, producing entries
    # such as "generated_image_0.png.jpg" that match no file in the folder.
    image_name = os.path.basename(image_filename)
    for branch_idx, caption_for_branch in enumerate(small_captions):
        metadata_lines.append(f"{image_name}#{branch_idx} {caption_for_branch}")


In [None]:
# Write the collected metadata lines, one per line, to a text file in the
# working directory.
metadata_filename = "output_metadata.txt"
with open(metadata_filename, "w", encoding="utf-8") as f:
    for line in metadata_lines:
        f.write(line + "\n")
print(f"Wrote metadata to {metadata_filename}")

# Pack everything under `output_folder` into a zip archive, storing paths
# relative to that folder. The metadata file above lives outside the folder
# and is therefore not part of the archive.
zip_filename = "generated_images.zip"
with zipfile.ZipFile(zip_filename, 'w', zipfile.ZIP_DEFLATED) as zipf:
    # Use `filenames`/`filename` rather than `files`/`file` so the loop does
    # not rebind the notebook-level name `files` (the google.colab module).
    for root, _, filenames in os.walk(output_folder):
        for filename in filenames:
            filepath = os.path.join(root, filename)
            arcname = os.path.relpath(filepath, output_folder)
            zipf.write(filepath, arcname)
print(f"Created zip file: {zip_filename}")


In [None]:
# Imported again in this cell so the call below does not depend on the name
# `files` still referring to the Colab module when the cell runs.
from google.colab import files
# Hand the zip archive created earlier in the notebook to Colab's download helper.
files.download(zip_filename)


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>