In [None]:
# colab.ipynb
import sys
import os

# Add the root folder to the Python path.
# os.path.abspath("") is evaluated against the current working directory; the
# result is appended only when it is not already on sys.path. Presumably this is
# what lets the project-local imports further down (config_colab, models,
# img_gen_logic_colab) resolve -- confirm against the notebook's layout.
root_path = os.path.abspath("")
if root_path not in sys.path:
    sys.path.append(root_path)

import random
from huggingface_hub import InferenceClient
from PIL import Image
from google.colab import userdata
from IPython.display import display, clear_output
import ipywidgets as widgets
from datetime import datetime
from config_colab import api_token
from models import models
from img_gen_logic_colab import generate_image, on_generate_button_clicked


# Initialize the InferenceClient with the default model: the "name" entry of the
# first item in `models`, passing `api_token` as the token.
# NOTE(review): generate_image() below constructs its own InferenceClient on
# every call, so this module-level instance appears unused within this file.
client = InferenceClient(models[0]["name"], token=api_token)


# Style shared by every labelled control below.
_DESCRIPTION_STYLE = {"description_width": "initial"}

# Left castle hit points: integer slider over 0..100, starting at 100.
left_hp_input = widgets.IntSlider(
    description="Left Castle HP:",
    value=100,
    min=0,
    max=100,
    step=1,
    style=_DESCRIPTION_STYLE,
)

# Right castle hit points: same range and default as the left castle.
right_hp_input = widgets.IntSlider(
    description="Right Castle HP:",
    value=100,
    min=0,
    max=100,
    step=1,
    style=_DESCRIPTION_STYLE,
)

# Requested image height (integer text box, default 512).
height_input = widgets.IntText(
    description="Height:",
    value=512,
    style=_DESCRIPTION_STYLE,
)

# Requested image width (integer text box, default 1024).
width_input = widgets.IntText(
    description="Width:",
    value=1024,
    style=_DESCRIPTION_STYLE,
)

# Number of inference steps: integer slider over 10..100, default 20.
num_inference_steps_input = widgets.IntSlider(
    description="Inference Steps:",
    value=20,
    min=10,
    max=100,
    step=1,
    style=_DESCRIPTION_STYLE,
)

# Guidance scale: float slider over 1.0..20.0 in steps of 0.5, default 2.0.
guidance_scale_input = widgets.FloatSlider(
    description="Guidance Scale:",
    value=2.0,
    min=1.0,
    max=20.0,
    step=0.5,
    style=_DESCRIPTION_STYLE,
)

# Seed text box, pre-filled with a random integer drawn once at creation time.
seed_input = widgets.IntText(
    description="Seed:",
    value=random.randint(0, 1000000),
    style=_DESCRIPTION_STYLE,
)

# When ticked (the default), generate_image() draws a new seed on every call.
randomize_seed_checkbox = widgets.Checkbox(
    description="Randomize Seed",
    value=True,
    style=_DESCRIPTION_STYLE,
)

# Button that triggers image generation.
generate_button = widgets.Button(
    description="Generate Image",
    button_style="success",
)


# Output area that receives the log lines and the generated image.
output = widgets.Output()


# Function to generate images based on the HP values
def generate_image(left_hp, right_hp, height, width, num_inference_steps, guidance_scale, seed):
    """Generate an image for the given castle HP values via the Inference API.

    A prompt is built from ``left_hp`` and ``right_hp`` with
    ``generate_prompt``. If ``randomize_seed_checkbox`` is ticked, ``seed`` is
    replaced by a fresh random integer in ``[0, 1000000]`` and written back to
    ``seed_input``. A new ``InferenceClient`` for ``models[0]["name"]`` is then
    created and ``text_to_image`` is called with the remaining parameters.

    NOTE(review): this definition shadows the ``generate_image`` imported from
    ``img_gen_logic_colab`` at the top of the file, and ``generate_prompt`` is
    neither defined nor imported in the visible file -- confirm where it is
    expected to come from.

    Args:
        left_hp: Left castle HP, passed to ``generate_prompt``.
        right_hp: Right castle HP, passed to ``generate_prompt``.
        height: Forwarded to ``text_to_image`` as ``height``.
        width: Forwarded to ``text_to_image`` as ``width``.
        num_inference_steps: Forwarded as ``num_inference_steps``.
        guidance_scale: Forwarded as ``guidance_scale``.
        seed: Seed to use unless the randomize checkbox overrides it.

    Returns:
        Whatever ``client.text_to_image`` returns on success, or the string
        ``"An error occurred: <exception>"`` if any step raised.
    """
    try:
        # Build the prompt inside the try block so that a failure here is
        # reported through the same error-string channel as API failures,
        # which is the only failure mode the caller checks for.
        prompt = generate_prompt(left_hp, right_hp)

        # Randomize the seed if the checkbox is checked
        if randomize_seed_checkbox.value:
            seed = random.randint(0, 1000000)
            seed_input.value = seed  # Keep the seed input box in sync

        print(f"Using seed: {seed}")

        # Debug: Indicate that the image is being generated
        print("Generating image... Please wait.")

        # Initialize the InferenceClient with the selected model
        client = InferenceClient(models[0]["name"], token=api_token)

        # Generate the image using the Inference API with parameters
        image = client.text_to_image(
            prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            seed=seed,
        )
        return image
    except Exception as e:
        # Failures are returned as a string rather than raised; the click
        # handler distinguishes success from failure with isinstance(..., str).
        return f"An error occurred: {e}"

# Function to handle button click event
def on_generate_button_clicked(b):
    """Handle a click on the generate button.

    Reads the current value of every parameter widget, echoes them into the
    output area, calls ``generate_image`` and then either prints the returned
    error string or displays the image and saves it as a timestamped PNG in
    the current directory. The ``b`` argument (the clicked button) is unused.
    """
    with output:
        # Start from a clean output area on every click.
        clear_output(wait=True)

        # Snapshot all controls once so the log and the request agree.
        hp_left = left_hp_input.value
        hp_right = right_hp_input.value
        img_height = height_input.value
        img_width = width_input.value
        steps = num_inference_steps_input.value
        scale = guidance_scale_input.value
        start_seed = seed_input.value

        # Debug: dump the selected parameters, one per line.
        for label, value in (
            ("Left Castle HP", hp_left),
            ("Right Castle HP", hp_right),
            ("Height", img_height),
            ("Width", img_width),
            ("Inference Steps", steps),
            ("Guidance Scale", scale),
            ("Seed", start_seed),
        ):
            print(f"{label}: {value}")

        result = generate_image(
            hp_left, hp_right, img_height, img_width, steps, scale, start_seed
        )

        # A string result is treated as an error message: show it and stop.
        if isinstance(result, str):
            print(result)
            return

        print("Image generated successfully!")
        print("Displaying image...")
        display(result)

        # File name: <timestamp>_left_<hp>_right_<hp>.png
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        target = f"{stamp}_left_{hp_left}_right_{hp_right}.png"
        print(f"Saving image as {target}...")
        result.save(target)
        print(f"Image saved as {target}")


# Attach the button click event handler
generate_button.on_click(on_generate_button_clicked)

# Display the widgets
# Disabled alternative that would show every control:
#display(left_hp_input, right_hp_input, height_input, width_input, num_inference_steps_input, guidance_scale_input, seed_input, randomize_seed_checkbox, generate_button, output)

# Only the two HP sliders, the button and the output area are shown. The
# remaining controls are still created above, and the click handler still
# reads their values even though they are not displayed here.
display(left_hp_input, right_hp_input, generate_button, output)