# Image Telephone

# Environment management
- import relevant modules
- make environment assertions
- set up variables for dataflow to use

You'll want to tune the variables for later use (the image, the s3 bucket/data directory...)

In [None]:
import os
import sys
import urllib
from io import BytesIO
from typing import Tuple

import boto3
from PIL import Image
from tenacity import retry, stop_after_delay
from run import determine_state

from hamilton import driver
from hamilton.io.materialization import to

# Stop right away if the OpenAI credential is missing from the process environment.
assert os.environ.get("OPENAI_API_KEY") is not None, "Must have OpenAI key set for this to work!"

# Ensure your state is correct

You'll want your initial image in png format -- change the `INITIAL_IMAGE_PATH` to specify it!

In [None]:
# --- What to run -------------------------------------------------------------
# Seed image (png) that starts the game, and the name under which its results are grouped.
INITIAL_IMAGE_PATH = "./seed_images/test_wikipedia_image_20231213.png"
UNIQUE_IMAGE_NAME = "test_wikipedia_image_20231213"
# Upper bound on the iteration counter used by the caption -> image loop.
NUM_ITERATIONS = 3
# Passed to the captioning dataflow as its "descriptiveness" input.
DESCRIPTIVENESS = "obsessively"

# --- Where results are written -----------------------------------------------
# Storage backend: "s3" or "local".
STORAGE_ENGINE = "local"
# Bucket used when STORAGE_ENGINE is "s3". TODO -- put your bucket
S3_BUCKET = "dagworks-image-telephone"
# Results directory used when STORAGE_ENGINE is "local".
DATA_DIR = "./results"

In [None]:
# Validate the setting that the selected backend needs, then derive the root
# location under which every artifact for this image is written.
if STORAGE_ENGINE == "local":
    assert DATA_DIR is not None, "Must provide data directory for results when using local mode"
    BASE_SAVE_LOCATION = os.path.join(DATA_DIR, UNIQUE_IMAGE_NAME)
    # Local results need the directory to exist before anything is saved into it.
    if not os.path.exists(BASE_SAVE_LOCATION):
        os.makedirs(BASE_SAVE_LOCATION)
else:
    if STORAGE_ENGINE == "s3":
        assert S3_BUCKET is not None, "Must provide S3_BUCKET to use S3"
    # Any non-local engine gets an s3:// style location; nothing is created up front.
    BASE_SAVE_LOCATION = f"s3://{S3_BUCKET}/{UNIQUE_IMAGE_NAME}"
    

# Pull dataflows from the Hub

These two dataflows have everything we need to play image telephone. We're going to download two dataflows:

1. `caption_images` -- this has the ability to provide a caption given an image
2. `generate_images` -- this has the ability to generate an image, given a caption

We use the hub API to download the modules, then do a quick visualization to ensure we're happy with what we've got. We've combined these into the same driver, although one could easily run two drivers, since the DAGs are actually independent.

In [None]:
from hamilton import dataflows

# Fetch the two dataflow modules from the Hamilton hub (user "elijahbenizzy").
caption_images = dataflows.import_module("caption_images", "elijahbenizzy")
generate_images = dataflows.import_module("generate_images", "elijahbenizzy")

# NOTE: do not follow this with `import caption_images` / `import generate_images`.
# Those statements rebind the two names above, so they are either redundant,
# fail with ModuleNotFoundError, or silently swap in a same-named local module
# instead of the hub dataflow this notebook is meant to run.

# One driver over both modules; "include_embeddings" is passed as driver config.
dr = driver.Driver({"include_embeddings": True}, caption_images, generate_images)
# Render every function in the combined DAG, laid out top-to-bottom.
dr.display_all_functions(orient="TB")

# Define our Capabilities (chains)

The code below runs components of the DAG in a loop, displaying the results in between to track progress. Each iteration makes two calls to `.materialize(...)`, which runs/executes the DAG:

1. Generate captions
2. Generate images

We then update the state, and run again!

In [None]:
# `import urllib` alone does not guarantee these submodules are loaded.
import urllib.parse
import urllib.request


def _read_image_bytes(location: str) -> bytes:
    """Return the raw bytes of an image referenced by an http(s) URL or a local path.

    :param location: http(s) URL or filesystem path of the image.
    :return: the undecoded image bytes.
    """
    # Only http/https are fetched over the network; anything else is treated as a path.
    if urllib.parse.urlparse(location).scheme in ("http", "https"):
        with urllib.request.urlopen(location) as response:
            return response.read()
    # Binary mode: image data is not text, and BytesIO requires bytes.
    with open(location, "rb") as image_file:
        return image_file.read()


# This allows execution to start where it left off
iteration, image_url, has_original = determine_state(
    INITIAL_IMAGE_PATH,
    STORAGE_ENGINE,
    UNIQUE_IMAGE_NAME,
    {"base_dir": DATA_DIR, "s3_bucket": S3_BUCKET},
)

# Loop until we're there
while iteration < NUM_ITERATIONS:
    print(f" Beginning iteration: {iteration} with image URL: {image_url}")
    metadata_save_path = os.path.join(BASE_SAVE_LOCATION, f"metadata_{iteration}.json")

    # Step 1: caption the current image and save the metadata.
    # The original image is additionally saved only while `has_original` is False.
    # NOTE(review): BASE_SAVE_LOCATION already ends with UNIQUE_IMAGE_NAME, so the
    # original lands in a nested "<name>/<name>/original.png" -- confirm this is
    # where `determine_state` looks for it.
    original_image_savers = [] if has_original else [
        to.image(
            path=os.path.join(BASE_SAVE_LOCATION, f"{UNIQUE_IMAGE_NAME}/original.png"),
            dependencies=["image_url"],
            id="save_original_image",
            format="png",
        )
    ]
    _, results = dr.materialize(
        to.json(
            path=metadata_save_path,
            dependencies=["metadata"],
            id="save_metadata",
        ),
        *original_image_savers,
        additional_vars=["generated_caption"],
        inputs={
            "image_url": image_url,
            "descriptiveness": DESCRIPTIVENESS,
            "additional_metadata": {
                "descriptiveness": DESCRIPTIVENESS,
                "iteration": iteration,
            },
        },
    )

    generated_caption = results["generated_caption"]
    print(f"Captioned image: {image_url} with caption: {generated_caption}. \n\n Saved metadata (caption + embeddings) at: {metadata_save_path}")
    image_save_path = os.path.join(BASE_SAVE_LOCATION, f"image_{iteration}.png")

    # Step 2: generate a new image from the caption and save it.
    _, results = dr.materialize(
        to.image(
            path=image_save_path,
            dependencies=["generated_image"],
            id="save_image",
            format="png",
        ),
        inputs={"image_generation_prompt": generated_caption},
        additional_vars=["generated_image"],
    )
    # Only "generated_image" was requested from this call, so the caption is not
    # re-read from `results` (that lookup would raise KeyError); the value from
    # step 1 is still in scope.
    generated_image = results["generated_image"]
    print(f"Generated image, saved at: {image_save_path}")

    # Advance the state: the image just saved becomes the next image to caption.
    iteration += 1
    image_url = image_save_path
    has_original = True

    # Show the new image inline.
    # NOTE(review): assumes `generated_image` is a URL or path string -- confirm
    # against the generate_images dataflow.
    img = Image.open(BytesIO(_read_image_bytes(generated_image)))
    display(img)