In [None]:
import sys
sys.path.append('../src')

import json
import os

import matplotlib.pyplot as plt
from infer_sdxl import SDXLInferencer


In [None]:
# The story to visualize is a hard-coded sample for this demo; the attributes
# are loaded from a previous pipeline step when available, else a fallback is used.
sample_story = """
In the heart of the ancient mountains, where emerald forests stretched endlessly toward azure skies, 
a tale of wonder began to unfold. The peaceful valley held secrets whispered by the wind through 
towering pines, and every sunrise painted the landscape in hues of gold and crimson.
"""

ATTRIBUTES_PATH = '../outputs/extracted_attributes.json'
FALLBACK_ATTRIBUTES = {"caption": "a serene mountain landscape"}

# Try to load actual results if available
try:
    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
    with open(ATTRIBUTES_PATH, 'r', encoding='utf-8') as f:
        attributes = json.load(f)
    print("Loaded attributes from previous analysis")
except FileNotFoundError:
    attributes = dict(FALLBACK_ATTRIBUTES)
    print("Using sample attributes")
except json.JSONDecodeError as err:
    # A truncated/corrupt file (e.g. from an interrupted earlier run) shouldn't block the demo.
    attributes = dict(FALLBACK_ATTRIBUTES)
    print(f"Could not parse {ATTRIBUTES_PATH} ({err}); using sample attributes")

# Later code calls `attributes.get(...)`, so a non-object JSON payload must be rejected here.
if not isinstance(attributes, dict):
    print(f"Expected a JSON object in {ATTRIBUTES_PATH}, got {type(attributes).__name__}; using sample attributes")
    attributes = dict(FALLBACK_ATTRIBUTES)

# Preview without the leading newline, and only mark truncation when it actually happens.
story_preview = sample_story.strip()
if len(story_preview) > 100:
    story_preview = story_preview[:100] + "..."
print(f"Story to visualize: {story_preview}")
print(f"Key attributes: {attributes.get('caption', 'N/A')}")


In [None]:
# Initialize SDXL inferencer.
# SDXLInferencer is imported from the `infer_sdxl` module under ../src, which the
# first cell makes importable via sys.path.append('../src').
# Kept in its own cell so the generation cell below can be re-run without
# constructing a new inferencer each time.
# Note: This requires significant GPU memory and may take time to load
sdxl = SDXLInferencer()


In [None]:
# Generate images based on the story.
# NUM_IMAGES is the single place to change how many images are requested; the plot
# grid below is sized from what actually comes back, so the two can't drift apart.
NUM_IMAGES = 3
OUTPUT_DIR = '../outputs'

# list() so the result can be len()'d and iterated twice even if a generator is returned.
generated_images = list(sdxl.generate_story_images(
    sample_story,
    attributes,
    num_images=NUM_IMAGES,
    num_inference_steps=20,  # Faster generation for demo
))
if not generated_images:
    raise RuntimeError("generate_story_images returned no images; nothing to display or save")

# Display the generated images, one panel per returned image.
# squeeze=False keeps `axes` 2-D, so indexing works even for a single image.
n_images = len(generated_images)
fig, axes = plt.subplots(1, n_images, figsize=(5 * n_images, 5), squeeze=False)
for idx, (ax, img) in enumerate(zip(axes[0], generated_images), start=1):
    ax.imshow(img)
    ax.axis('off')
    ax.set_title(f'Generated Image {idx}')

plt.tight_layout()
plt.show()

# Save images. NOTE(review): assumes each item exposes a PIL-style `.save(path)`
# method — confirm against SDXLInferencer.generate_story_images.
os.makedirs(OUTPUT_DIR, exist_ok=True)
for idx, img in enumerate(generated_images, start=1):
    img.save(os.path.join(OUTPUT_DIR, f'generated_image_{idx}.png'))

print(f"Generated and saved {n_images} images to {OUTPUT_DIR}/")
