# Chinese Classical Poetry Visualization System
This system generates visual interpretations of Chinese classical poems using AI models, combining modified SDXL for image generation and GLM-4 for poem analysis.

## Setup and Dependencies
The following cells mount Google Drive, install the required packages, and import the necessary libraries.

In [1]:
# Colab-only: mount Google Drive at /content/drive. Later cells read the poem
# database from, and write prompts/images to, paths under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# Required pip installations
!pip install transformers
!pip install diffusers
!pip install accelerate
!pip install zhipuai
!pip install moviepy
!pip install bayesian-optimization
!pip install xformers
!pip install safetensors
!pip install triton

# Imports
import os
import gc
import json
import time
import torch
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from PIL import Image, ImageDraw, ImageFont
from moviepy.editor import ImageClip, concatenate_videoclips
import traceback
import torch.nn.functional as F
from transformers import CLIPProcessor, CLIPModel
from diffusers import DiffusionPipeline, EulerDiscreteScheduler, StableDiffusionXLImg2ImgPipeline, StableVideoDiffusionPipeline
from diffusers.utils import load_image, export_to_video
from bayes_opt import BayesianOptimization
from zhipuai import ZhipuAI
from concurrent.futures import ThreadPoolExecutor
import re
import triton
import subprocess
import shutil

# Silence all warnings for the rest of the session
import warnings
warnings.filterwarnings('ignore')

# Pick the compute device: the GPU when CUDA is usable, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

Collecting zhipuai
  Downloading zhipuai-2.1.5.20230904-py3-none-any.whl.metadata (10 kB)
Collecting pyjwt<2.9.0,>=2.8.0 (from zhipuai)
  Downloading PyJWT-2.8.0-py3-none-any.whl.metadata (4.2 kB)
Downloading zhipuai-2.1.5.20230904-py3-none-any.whl (104 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m105.0/105.0 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading PyJWT-2.8.0-py3-none-any.whl (22 kB)
Installing collected packages: pyjwt, zhipuai
  Attempting uninstall: pyjwt
    Found existing installation: PyJWT 2.10.0
    Uninstalling PyJWT-2.10.0:
      Successfully uninstalled PyJWT-2.10.0
Successfully installed pyjwt-2.8.0 zhipuai-2.1.5.20230904
Collecting bayesian-optimization
  Downloading bayesian_optimization-2.0.0-py3-none-any.whl.metadata (8.9 kB)
Collecting colorama<0.5.0,>=0.4.6 (from bayesian-optimization)
  Downloading colorama-0.4.6-py2.py3-none-any.whl.metadata (17 kB)
Downloading bayesian_optimization-2.0.0-py3-none-any.whl (30 kB)
Downloa

  if event.key is 'enter':



In [None]:
# Registry of checkpoints to evaluate: display name -> model id handed to
# BayesianStableDiffusion when the experiment loads its models.
MODELS_TO_COMPARE = dict(
    SDXL="stabilityai/stable-diffusion-xl-base-1.0",
)

def find_available_font():
    """Placeholder for font discovery; the body is empty and always returns None."""
    return None

def clear_gpu_memory():
    """Run the garbage collector and release PyTorch's cached CUDA memory.

    When CUDA is available the call also waits for outstanding GPU work to
    finish; without CUDA only the collection and cache release are performed.
    """
    gc.collect()
    torch.cuda.empty_cache()
    if not torch.cuda.is_available():
        return
    torch.cuda.synchronize()

def save_to_drive(video_path):
    """Placeholder for saving the output at ``video_path`` to Google Drive.

    The body is currently empty: nothing is copied and None is returned.
    """
    return None

def load_poem_from_json(json_file_path, poem_title):
    """Look up a single poem by title in a JSON poem database.

    The file is read as UTF-8 and must contain an object with a 'poems' list
    whose entries carry a 'title' key.

    Returns:
        The first entry whose title equals ``poem_title``; None when no entry
        matches or when reading/parsing the file fails. A message is printed
        in both failure cases.
    """
    try:
        with open(json_file_path, 'r', encoding='utf-8') as handle:
            database = json.load(handle)

        match = next(
            (entry for entry in database['poems'] if entry['title'] == poem_title),
            None,
        )
        if match is not None:
            return match

        print(f"Poem '{poem_title}' not found in the database.")
        return None

    except Exception as e:
        print(f"Error loading poem data: {str(e)}")
        return None

class EnhancedPoemAnalyzer:
    def __init__(self, api_key="968cd0b672b9b5133d01741721558a95.xYFanKwaJ2ShpQuZ"):
        self.client = ZhipuAI(api_key=api_key)
        self.chunk_cache = {}

    def translate_to_english(self, chinese_text):
        prompt = f"""Please translate the following Chinese text to English, maintaining the visual and descriptive nature of the content:

        Chinese text:
        {chinese_text}

        Requirements:
        1. Translate to natural, fluent English
        2. Preserve all visual descriptions and imagery
        3. Keep any technical or specific terms
        4. Maintain the original structure where applicable
        """

        response = self.client.chat.completions.create(
            model="glm-4",
            messages=[{"role": "user", "content": prompt}]
        )

        return response.choices[0].message.content.strip()

    def interpret_cultural_terms(self, text, context):
        prompt = f"""请分析这句诗中的专有名词或文化意象，将其转换为具体的视觉描述：

        原诗上下文：
        {context}

        需要分析的句子：
        {text}

        请识别所有的专有名词、人物称谓或文化意象，并给出具体的视觉描述。
        格式要求：
        1. 每个词一行
        2. 用 "词：视觉描述" 的格式
        3. 描述必须是具体的、可视化的，避免抽象概念
        4. 描述要符合诗歌语境
        5. 请直接用英文描述

        示例：
        若"王孙"在此诗中表达隐居的文人，应描述为"a scholarly man in traditional robes meditating in nature"
        若"渔父"出现，应描述为"an old fisherman in simple clothes on a wooden boat"
        """

        response = self.client.chat.completions.create(
            model="glm-4",
            messages=[{"role": "user", "content": prompt}]
        )

        interpretations = {}
        for line in response.choices[0].message.content.strip().split('\n'):
            if '：' in line:
                term, desc = line.split('：', 1)
                interpretations[term.strip()] = desc.strip()

        interpreted_text = text
        for term, desc in interpretations.items():
            interpreted_text = interpreted_text.replace(term, desc)

        return interpreted_text

    def get_poem_understanding(self, full_poem):
        study_prompt = f"""请从视觉角度分析这首诗的整体意境、场景和情感，重点描述：
        1. 主要场景和环境特征
        2. 光线和时间的变化
        3. 人物的动作和状态
        4. 整体氛围和情感基调

        诗文：
        {full_poem}

        请用英文回答，使用具体的视觉语言描述，避免抽象概念。"""

        response = self.client.chat.completions.create(
            model="glm-4",
            messages=[{"role": "user", "content": study_prompt}]
        )

        return response.choices[0].message.content

    def analyze_chunk_detail(self, chunk, category, context_dict):
        cache_key = f"{chunk}_{category}"
        if cache_key in self.chunk_cache:
            return self.chunk_cache[cache_key]

        # Get results from other analysis functions
        english_translation = self.translate_to_english(chunk)
        interpreted_chunk = self.interpret_cultural_terms(chunk, context_dict['full_poem'])
        poem_understanding = self.get_poem_understanding(context_dict['full_poem'])
        previous_chunk = context_dict.get('previous_chunk', '')

        prompts = {
            "subject_action": f"""

            Analyze the subjects and their actions in this line/segment of poetry.

            The line to be focused on: "{chunk}",
            and its previous line: "{previous_chunk}"

            Return in this exact format. IT MUST ALIGN WITH THE MEANING OF THAT LINE!!!!!!!!!
            subjects: [concrete description of each person/animal/living being, their clothing/appearance if mentioned]
            actions: [specific descriptions of actions]

            Requirements:
            - Your analysis must align with:
              * The English translation provided above
              * The cultural interpretations provided above
              * The overall poem understanding provided above
            - For subjects, always include full description
            - If no explicit subject in current line, use the subject from previous line
            - If previous line has subject "孤鸿" (lonely goose), current line should use "the lonely goose" as subject
            - All actions must be visually concrete (e.g., for "不敢顾", show "lifting its wings away from")
            - Avoid literary descriptions and Chinese terms
            - If this is the first line ({context_dict['chunk_index'] == 0}), only include subjects explicitly mentioned
            - If subject is not further described in terms of appearance or clothing in this or previous line,
              do not mention something like "not further describe..." or "not explicitly stated..." or "imply/implies..." or "not further described in terms of..." in final output
            - Again if something is not explicitly describe, PLEASE do not let us know in the final output.
              For example, "featuring A character - not explicitly described in terms of appearance or clothing"
              Do not mention "not explicitly described in terms of appearance or clothing"
            - Do not put explanations or interpretations in brackets
            - List each subject and action exactly once
            - Use concise, visual descriptions
            - For temporal descriptions (like future events), describe the current state
            - For questions about return (like '归不归'), describe as 'contemplating return'
            - Always include emotional subjects even if implied
            """,

            "scene_setting": f"""

            Analyze the scene and environmental elements in this line: "{chunk}"

            Please return in this format. IT MUST ALIGN WITH THE MEANING OF THAT LINE!!!!!!!!!
            locations: [specific scene locations]
            objects: [specific physical objects, flora, fauna, celestial elements and architectural elements]

            Requirements:
            - Your analysis must align with:
              * The English translation provided above
              * The cultural interpretations provided above
              * The overall poem understanding provided above
            - Descriptions must be concrete and visual
            - Avoid display of Chinese in final output
            - All physical elements must be included! (e.g. 举杯邀明月 will have cup and moon; 巢在三珠树 will have nests on branches, three pearl trees)
            - Take note of the amounts too (e.g. three birds)
            - Be extremely specific about objects (e.g., "three pearl trees" rather than just "trees")
            - No explanations in the brackets
            - List each element exactly once
            - Use concise, visual terms
            - Include seasonal elements (e.g., spring grass should be listed as an object)
            - Include temporal indicators as physical manifestations
            """,

            "time_weather": f"""

            Analyze the time and weather elements in this line: "{chunk}"

            Please return in this format. IT MUST ALIGN WITH THE MEANING OF THAT LINE!!!!!!!!!
            time: [specific time, e.g., sunset, dawn]
            weather: [specific weather conditions]

            Requirements:
            - Your analysis must align with:
              * The English translation provided above
              * The cultural interpretations provided above
              * The overall poem understanding provided above
            - Use commonly recognized natural phenomena
            - Avoid display of Chinese in final output
            - If there is a moon mentioned, set time to night
            - No explanations in the brackets
            - List each element exactly once
            - Use concise, visual terms
            - Include seasonal timeframes
            - Include future time references as present atmospheric conditions
            """,

            "mood": f"""

            Analyze the visual atmosphere and emotional elements in this line: "{chunk}"

            Please return in this format. IT MUST ALIGN WITH THE MEANING OF THAT LINE!!!!!!!!!
            lighting: [specific lighting effects]
            atmosphere: [specific visual mood]
            color_tone: [main color tones]

            Requirements:
            - Your analysis must align with:
              * The English translation provided above
              * The cultural interpretations provided above
              * The overall poem understanding provided above
            - All descriptions should be directly usable for image generation
            - Avoid display of Chinese in final output
            - No explanations in the brackets
            - List each element exactly once
            - Use concise, visual terms
            - Capture emotional uncertainty in visual terms
            - Include seasonal color palettes
            - Transform abstract concepts into visual metaphors
            """
        }

        response = self.client.chat.completions.create(
            model="glm-4",
            messages=[{"role": "user", "content": prompts[category]}]
        )

        result = response.choices[0].message.content.strip()
        self.chunk_cache[cache_key] = result
        return result

    def analyze_chunk_parallel(self, chunk, context):
        with ThreadPoolExecutor(max_workers=4) as executor:
            futures = {
                "subject_action": executor.submit(self.analyze_chunk_detail, chunk, "subject_action", context),
                "scene_setting": executor.submit(self.analyze_chunk_detail, chunk, "scene_setting", context),
                "time_weather": executor.submit(self.analyze_chunk_detail, chunk, "time_weather", context),
                "mood": executor.submit(self.analyze_chunk_detail, chunk, "mood", context)
            }

            return {
                "text": chunk,
                "subject_action": futures["subject_action"].result(),
                "scene_setting": futures["scene_setting"].result(),
                "time_weather": futures["time_weather"].result(),
                "mood": futures["mood"].result()
            }

    def pack_chunk_to_prompt(self, chunk_analysis, overall_understanding):
        elements = {
            'primary_subjects': [],
            'secondary_subjects': [],
            'actions': [],
            'objects': [],
            'environment': [],
            'lighting': [],
            'atmosphere': [],
            'style': ['RTX realism with ray tracing', 'traditional China']
        }

        # Extract all elements from the analysis
        if 'scene_setting' in chunk_analysis:
            content = chunk_analysis['scene_setting']
            if 'locations: [' in content:
                elements['environment'].extend(
                    content.split('locations: [')[1].split(']')[0].split(', '))
            if 'objects: [' in content:
                elements['objects'].extend(
                    content.split('objects: [')[1].split(']')[0].split(', '))

        if 'mood' in chunk_analysis:
            content = chunk_analysis['mood']
            if 'lighting: [' in content:
                elements['lighting'].extend(
                    content.split('lighting: [')[1].split(']')[0].split(', '))
            if 'atmosphere: [' in content:
                elements['atmosphere'].extend(
                    content.split('atmosphere: [')[1].split(']')[0].split(', '))

        # Build a more comprehensive prompt
        prompt_parts = []

        # Add environmental elements first
        if elements['environment'] or elements['objects']:
            env_parts = []
            if elements['environment']:
                env_parts.extend(elements['environment'])
            if elements['objects']:
                env_parts.extend(elements['objects'])
            prompt_parts.append(f"featuring {', '.join(env_parts)}")

        # Add subjects and actions
        if elements['primary_subjects'] or elements['actions']:
            action_parts = []
            if elements['primary_subjects']:
                action_parts.extend(elements['primary_subjects'])
            if elements['actions']:
                action_parts.extend(elements['actions'])
            prompt_parts.append(f"with {', '.join(action_parts)}")

        # Add atmosphere and lighting
        if elements['atmosphere'] or elements['lighting']:
            mood_parts = []
            if elements['atmosphere']:
                mood_parts.extend(elements['atmosphere'])
            if elements['lighting']:
                mood_parts.extend(elements['lighting'])
            prompt_parts.append(f"in {', '.join(mood_parts)}")

        # Add style elements
        prompt_parts.append(', '.join(elements['style']))

        return ' | '.join(filter(None, prompt_parts))

    def get_contextual_analysis(self, full_poem, segment, context_dict):
        overall_understanding = self.get_poem_understanding(full_poem)
        detailed_analysis = self.analyze_chunk_parallel(segment, context_dict)
        compact_prompt = self.pack_chunk_to_prompt(detailed_analysis, overall_understanding)

        translated_poem = self.translate_to_english(full_poem)

        return {
            "original_poem": full_poem,
            "translated_poem": translated_poem,
            "overall_understanding": overall_understanding,
            "detailed_analysis": detailed_analysis,
            "compact_prompt": compact_prompt
        }

class BayesianStableDiffusion:
    def __init__(self, model_id="stabilityai/stable-diffusion-xl-base-1.0", num_inference_steps=50,
                 clip_model_name="openai/clip-vit-base-patch32"):
        """Load the SDXL base and refiner pipelines plus a CLIP scoring model.

        Args:
            model_id: checkpoint id passed to ``DiffusionPipeline.from_pretrained``
                for the base pipeline.
            num_inference_steps: stored on ``self.num_inference_steps``. Note that
                ``generate_images`` passes its own fixed step counts (30 and 20)
                and does not read this attribute.
            clip_model_name: checkpoint id for the CLIP processor and model used
                to score images against the prompt.

        Raises:
            Exception: any failure during loading is printed with its traceback
                and then re-raised.
        """
        # Everything runs on the GPU when CUDA is available, otherwise on the CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_id = model_id
        self.refiner_id = "stabilityai/stable-diffusion-xl-refiner-1.0"

        print(f"Initializing models on device: {self.device}")
        if torch.cuda.is_available():
            print(f"Available CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
            # Free cached memory before the large checkpoints are loaded.
            clear_gpu_memory()

        try:
            # Load base model
            print(f"Loading base model {model_id}...")
            self.base = DiffusionPipeline.from_pretrained(
                model_id,
                torch_dtype=torch.float16,
                variant="fp16",
                use_safetensors=True
            ).to(self.device)

            # Load refiner model; it reuses the base pipeline's second text
            # encoder and VAE instead of loading its own copies.
            print(f"Loading refiner model...")
            self.refiner = DiffusionPipeline.from_pretrained(
                self.refiner_id,
                torch_dtype=torch.float16,
                variant="fp16",
                use_safetensors=True,
                text_encoder_2=self.base.text_encoder_2,
                vae=self.base.vae,
            ).to(self.device)

            # Configure schedulers: both pipelines switch to the Euler discrete
            # scheduler with Karras sigmas, built from their existing configs.
            self.base.scheduler = EulerDiscreteScheduler.from_config(
                self.base.scheduler.config,
                use_karras_sigmas=True
            )
            self.refiner.scheduler = EulerDiscreteScheduler.from_config(
                self.refiner.scheduler.config,
                use_karras_sigmas=True
            )

            # Enable optimizations for both models. This is best-effort: the
            # first call that fails skips the remaining ones for that pipeline
            # and only prints a warning.
            for pipe in [self.base, self.refiner]:
                try:
                    pipe.enable_attention_slicing(slice_size="auto")
                    pipe.enable_vae_slicing()
                    pipe.enable_xformers_memory_efficient_attention()
                except Exception as e:
                    print(f"Warning: Could not enable some optimizations: {e}")

            print("Loading CLIP model...")
            self.num_inference_steps = num_inference_steps
            self.clip_processor = CLIPProcessor.from_pretrained(clip_model_name)
            self.clip_model = CLIPModel.from_pretrained(clip_model_name).to(self.device)
            # Put CLIP in evaluation mode; it is only used for scoring.
            self.clip_model.eval()

            print("Model initialization completed")

        except Exception as e:
            print(f"Error initializing model: {str(e)}")
            traceback.print_exc()
            raise

    def generate_images(self, prompt, negative_prompt="", num_samples=5, guidance_scale=9, temperature=1.0):
        """Generate ``num_samples`` images for ``prompt`` and score them with CLIP.

        The base pipeline performs the first 80% of the denoising and hands its
        latents to the refiner, which finishes the remaining 20%. The latents
        are passed on directly (``output_type="latent"``) as the diffusers
        base-to-refiner workflow requires; decoding the partially denoised
        result to an image and re-encoding it in the refiner loses information.

        Args:
            prompt: text prompt used for both passes and for scoring.
            negative_prompt: negative prompt used for both passes.
            num_samples: number of images to generate.
            guidance_scale: classifier-free guidance scale for both passes.
            temperature: accepted for interface compatibility; currently unused.

        Returns:
            ``(images, likelihoods)``: a list of RGB PIL images and a numpy
            array with one CLIP-based score per image. On any error the
            traceback is printed and ``([], np.array([]))`` is returned.
        """
        try:
            clear_gpu_memory()

            print(f"Generating {num_samples} images with prompt: {prompt}")

            # First pass with base model: stop at 80% and keep the latents.
            base_latents = self.base(
                prompt=[prompt] * num_samples,
                negative_prompt=[negative_prompt] * num_samples,
                num_inference_steps=30,
                denoising_end=0.8,
                guidance_scale=guidance_scale,
                width=1024,
                height=1024,
                output_type="latent",
            ).images

            # Second pass with refiner, one sample at a time to bound memory use.
            refined_images = []
            for index in range(len(base_latents)):
                refined = self.refiner(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    # Slice (not index) so the batch dimension is preserved.
                    image=base_latents[index:index + 1],
                    num_inference_steps=20,
                    denoising_start=0.8,
                    guidance_scale=guidance_scale,
                ).images[0]
                refined_images.append(refined)

            if not refined_images:
                raise ValueError("No images were generated")

            # Ensure all images are in RGB mode
            images = [img.convert('RGB') if isinstance(img, Image.Image) else Image.fromarray(img).convert('RGB')
                    for img in refined_images]

            # Compute CLIP scores
            likelihoods = self.compute_clip_likelihoods(images, prompt)

            clear_gpu_memory()
            return images, likelihoods

        except Exception as e:
            print(f"Error in generate_images: {str(e)}")
            traceback.print_exc()
            return [], np.array([])

    def compute_clip_likelihoods(self, images, prompt):
        try:
            inputs = self.clip_processor(
                text=[prompt] * len(images),
                images=images,
                return_tensors="pt",
                padding=True
            ).to(self.device)

            with torch.no_grad():
                outputs = self.clip_model(**inputs)
                image_embeds = F.normalize(outputs.image_embeds, p=2, dim=1)
                text_embeds = F.normalize(outputs.text_embeds, p=2, dim=1)
                cosine_similarity = F.cosine_similarity(image_embeds, text_embeds, dim=1)
                likelihoods = (cosine_similarity + 1) / 2
            return likelihoods.cpu().numpy()

        except Exception as e:
            print(f"Error in compute_clip_likelihoods: {str(e)}")
            traceback.print_exc()
            return np.array([0.0] * len(images))

    def compute_mean_and_variance(self, images):
        if isinstance(images[0], Image.Image):
            images = [np.array(img) for img in images]
        images_array = np.array(images) / 255.0
        mean_image = np.mean(images_array, axis=0)
        variance_image = np.var(images_array, axis=0)
        return mean_image, variance_image

class ModelComparisonExperiment:
    def __init__(self):
        self.models = {}
        self.results = {}
        self.best_images_sequence = {}
        self.poem_analyzer = EnhancedPoemAnalyzer()
        self.load_models()

    def load_models(self):
        for model_name, model_id in MODELS_TO_COMPARE.items():
            print(f"Loading {model_name}...")
            try:
                self.models[model_name] = BayesianStableDiffusion(
                    model_id=model_id,
                    num_inference_steps=50
                )
                print(f"Successfully loaded {model_name}")
            except Exception as e:
                print(f"Error loading {model_name}: {str(e)}")

    def run_comparison(self, poem):
        print("\nStarting poem analysis...")

        # Get overall poem understanding first
        overall_analysis = self.poem_analyzer.get_poem_understanding(poem)

        # Split poem into chunks and keep track of their order
        chunks = [chunk.strip() for chunk in re.split('[，。？！]', poem) if chunk.strip()]

        # Initialize results structure
        results = {
            model_name: {
                'images': [],
                'scores': [],
                'generation_times': [],
                'clip_scores': [],
                'optimization_results': [],
                'best_images': [],
                'prompts': [],  # Added to store prompts for video generation
                'chunk_texts': []  # Added to store original chunk texts
            } for model_name in self.models.keys()
        }

        # Create context dictionary for each chunk
        chunk_contexts = {}
        for i, chunk in enumerate(chunks):
            chunk_contexts[chunk] = {
                'full_poem': poem,
                'previous_chunk': chunks[i-1] if i > 0 else None,
                'chunk_index': i
            }

        # Phase 1: Generate images for all chunks
        print("\nPhase 1: Generating images for all half-stanzas...")
        for chunk in tqdm(chunks, desc="Processing half-stanzas"):
            print(f"\nProcessing half-stanza: {chunk}")

            # Get contextual analysis for the chunk
            chunk_analysis = self.poem_analyzer.get_contextual_analysis(
                poem,
                chunk,
                chunk_contexts[chunk]
            )

            for model_name, model in self.models.items():
                print(f"\nUsing model: {model_name}")

                try:
                    start_time = time.time()

                    # Generate prompt
                    main_prompt = chunk_analysis["compact_prompt"]
                    negative_prompt = "low quality, blurry, bad anatomy, bad composition, deformed, split image, collage, grid, multiple panels, comic panels, storyboard"

                    print(f"Generated prompt: {main_prompt}")

                    promptstxt = "/content/drive/MyDrive/Colab Notebooks/Capstone/Finalized Prototype/objective-1-temp-imgGen-metadata/prompt.txt"

                    try:
                        with open(promptstxt, 'a') as file:  # 'a' mode for appending
                            file.write(main_prompt + '\n')
                        print(f"Successfully appended to {promptstxt}")
                    except FileNotFoundError:
                        print(f"Error: File '{promptstxt}' not found.")
                    except Exception as e:
                        print(f"An error occurred: {e}")

                    # Number of image samples per chunk
                    num_samples = 5

                    # Optimize guidance scale
                    optimal_scale = optimize_guidance_scale(
                        model,
                        main_prompt,
                        negative_prompt,
                        num_samples
                    )

                    print(f"Generating {num_samples} images with prompt: {main_prompt}")

                    # Generate images
                    images, likelihoods = model.generate_images(
                        main_prompt,
                        negative_prompt=negative_prompt,
                        num_samples=num_samples,
                        guidance_scale=optimal_scale
                    )

                    if images and len(images) > 0 and len(likelihoods) > 0:
                        generation_time = time.time() - start_time
                        best_idx = np.argmax(likelihoods)
                        best_image = images[best_idx]

                        # Save results
                        results[model_name]['best_images'].append({
                            'image': best_image,
                            'text': chunk,
                            'prompt': main_prompt,
                            'likelihood': likelihoods[best_idx],
                            'analysis': chunk_analysis
                        })

                        # Save image to disk for video generation
                        image_path = f"/content/drive/MyDrive/Colab Notebooks/Capstone/Finalized Prototype/objective-1-temp-imgGen-metadata/halfStanza_{len(results[model_name]['images'])+1}.png"
                        best_image.save(image_path)

                        results[model_name]['images'].append(image_path)
                        results[model_name]['prompts'].append(main_prompt)
                        results[model_name]['chunk_texts'].append(chunk)
                        results[model_name]['scores'].append(likelihoods[best_idx])
                        results[model_name]['generation_times'].append(generation_time)
                        results[model_name]['clip_scores'].append(np.mean(likelihoods))
                        results[model_name]['optimization_results'].append(optimal_scale)

                        # Display results
                        self.display_model_comparison(
                            images,
                            likelihoods,
                            model_name,
                            main_prompt,
                            generation_time,
                            optimal_scale,
                            chunk_analysis
                        )
                    else:
                        print(f"No valid images generated for {model_name}")

                except Exception as e:
                    print(f"Error processing chunk with {model_name}: {str(e)}")
                    traceback.print_exc()
                    continue

        return results

    def generate_videos(self, results):
        """
        Generate videos from the generated images.
        This is called after all images have been generated.
        """
        pass

    def add_subtitles(self, video_path, text):
        pass

    def display_model_comparison(self, images, likelihoods, model_name, prompt, generation_time, guidance_scale, analysis):
        """Show all samples for one segment next to their mean and variance images.

        The figure has a prompt banner, one panel per sample (the highest
        scoring one is highlighted), a mean image, a variance heat map and a
        metrics footer.

        Args:
            images: the generated samples.
            likelihoods: one score per sample.
            model_name: key into ``self.models``; also shown in the banner.
            prompt: the generation prompt shown in the banner.
            generation_time: elapsed seconds shown in the footer.
            guidance_scale: guidance scale shown in the footer.
            analysis: accepted for interface compatibility; not displayed.
        """
        mean_image, variance_image = self.models[model_name].compute_mean_and_variance(images)

        # imshow ignores ``cmap`` for RGB(A) data, so a colour variance array
        # would render as a near-black picture. Averaging over the channel axis
        # gives a scalar map the colormap can actually be applied to.
        if np.ndim(variance_image) == 3:
            variance_map = np.mean(variance_image, axis=-1)
        else:
            variance_map = variance_image

        n = len(images) + 2
        fig = plt.figure(figsize=(5*n, 12))
        gs = gridspec.GridSpec(4, n, height_ratios=[1, 1, 8, 1])

        prompt_ax = fig.add_subplot(gs[1, :])
        prompt_ax.axis('off')
        prompt_ax.text(0.5, 0.5, f"Model: {model_name}\nPrompt: {prompt}",
                      ha='center', va='center', wrap=True,
                      fontsize=12)

        axes = [fig.add_subplot(gs[2, i]) for i in range(n)]
        best_idx = np.argmax(likelihoods)

        for i, (ax, img) in enumerate(zip(axes[:len(images)], images)):
            ax.imshow(img)
            ax.axis('off')

            if i == best_idx:
                title = f"Selected Image\nLikelihood: {likelihoods[i]:.3f}"
                ax.set_title(title, color='green', fontweight='bold')
            else:
                title = f"Sample {i+1}\nLikelihood: {likelihoods[i]:.3f}"
                ax.set_title(title)

        axes[-2].imshow(mean_image)
        axes[-2].axis('off')
        axes[-2].set_title("Mean Image")

        axes[-1].imshow(variance_map, cmap='viridis')
        axes[-1].axis('off')
        axes[-1].set_title("Variance Image")

        metrics_ax = fig.add_subplot(gs[3, :])
        metrics_ax.axis('off')
        metrics_text = f"Generation Time: {generation_time:.2f}s | "
        metrics_text += f"Mean CLIP Score: {np.mean(likelihoods):.3f} | "
        metrics_text += f"Optimal Guidance Scale: {guidance_scale:.2f}"
        metrics_ax.text(0.5, 0.5, metrics_text,
                       ha='center', va='center',
                       fontsize=10)

        fig.tight_layout()
        plt.show()
        # One figure is created per segment and model; release it explicitly
        # so figures do not pile up over a long run.
        plt.close(fig)

def optimize_guidance_scale(model, prompt, negative_prompt="", num_samples=5,
                            bounds=(7.0, 12.0), init_points=2, n_iter=5,
                            fallback_scale=5.0):
    """Search for the guidance scale that maximises the mean CLIP-based score.

    Every evaluation generates ``num_samples`` images with
    ``model.generate_images`` and uses the mean of the returned scores as the
    objective, so the total cost is ``init_points + n_iter`` full batches.

    Args:
        model: object exposing ``generate_images(prompt, negative_prompt=...,
            num_samples=..., guidance_scale=...)`` returning
            ``(images, likelihoods)``.
        prompt: generation prompt.
        negative_prompt: negative prompt forwarded to the model.
        num_samples: images generated per evaluation.
        bounds: ``(low, high)`` search interval for the guidance scale.
        init_points: number of initial random evaluations.
        n_iter: number of Bayesian optimisation steps after the initial ones.
        fallback_scale: value returned when the optimisation itself fails.

    Returns:
        The best guidance scale found, as a float; ``fallback_scale`` if the
        optimiser raises. Note the default fallback lies outside the default
        search bounds.
    """
    def objective(guidance_scale):
        # A failed or empty generation scores 0.0 so the search can continue.
        try:
            images, likelihoods = model.generate_images(
                prompt,
                negative_prompt=negative_prompt,
                num_samples=num_samples,
                guidance_scale=guidance_scale
            )
            return np.mean(likelihoods) if len(likelihoods) > 0 else 0.0
        except Exception as e:
            print(f"Error in objective function: {str(e)}")
            return 0.0

    try:
        optimizer = BayesianOptimization(
            f=objective,
            pbounds={"guidance_scale": tuple(bounds)},
            random_state=42,
            verbose=0
        )

        optimizer.maximize(
            init_points=init_points,
            n_iter=n_iter
        )
        # Return a plain Python float rather than a numpy scalar.
        return float(optimizer.max['params']['guidance_scale'])
    except Exception as e:
        print(f"Error in optimization: {str(e)}")
        return fallback_scale

def main():
    """Interactive entry point: pick a poem, visualise it, and report metrics.

    Lists the poems in the JSON database, asks for a title, writes the poem
    text and a cleared prompt log to Drive, runs the model comparison, then
    prints and plots a per-model summary. Any error is printed with its
    traceback instead of being raised.
    """
    try:
        json_file_path = '/content/drive/MyDrive/Colab Notebooks/Capstone/Poem Database/poem_database.json'  # Adjust path as needed
        try:
            with open(json_file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
                print("\nAvailable poems:")
                for poem in data['poems']:
                    print(f"- {poem['title']}")
        except Exception as e:
            print(f"Error loading poems file: {e}")
            return

        promptstxt = "/content/drive/MyDrive/Colab Notebooks/Capstone/Finalized Prototype/objective-1-temp-imgGen-metadata/prompt.txt"

        # Opening in 'w' mode truncates the prompt log left over from an earlier run.
        with open(promptstxt, 'w', encoding='utf-8') as file:
            file.write("")

        poem_title = input("\nWhat poem would you like to visualize? ")
        poem_data = load_poem_from_json(json_file_path, poem_title)
        if not poem_data:
            print("Failed to load poem data.")
            return

        poem = poem_data['content']
        poem_txt = "/content/drive/MyDrive/Colab Notebooks/Capstone/Finalized Prototype/objective-1-temp-imgGen-metadata/poem.txt"

        try:
            # The poem is Chinese text: write it as UTF-8 explicitly instead of
            # relying on the platform's default encoding.
            with open(poem_txt, 'w', encoding='utf-8') as f:
                f.write(poem)
            print(f"Successfully wrote to {poem_txt}")
        except Exception as e:
            print(f"Could not write poem content to file: {e}")

        print(f"\nLoaded poem: {poem_data['title']}")
        print(f"Author: {poem_data['author']}")
        print(f"Content: {poem}")

        experiment = ModelComparisonExperiment()
        results = experiment.run_comparison(poem)

        # One row per model; every metric falls back to 0.0 when nothing was generated.
        report = pd.DataFrame({
            model_name: {
                'Mean CLIP Score': np.mean(stats['clip_scores']) if stats['clip_scores'] else 0.0,
                'Mean Generation Time': np.mean(stats['generation_times']) if stats['generation_times'] else 0.0,
                'Mean Optimal Scale': np.mean(stats['optimization_results']) if stats['optimization_results'] else 0.0,
                'Best Score': max(stats['scores']) if stats['scores'] else 0.0,
                'Worst Score': min(stats['scores']) if stats['scores'] else 0.0
            }
            for model_name, stats in results.items()
        }).T

        print("\nModel Comparison Report:")
        print(report)

        plt.figure(figsize=(15, 5))
        metrics = ['Mean CLIP Score', 'Mean Generation Time', 'Mean Optimal Scale']
        for i, metric in enumerate(metrics, 1):
            plt.subplot(1, 3, i)
            report[metric].plot(kind='bar')
            plt.title(metric)
            plt.xticks(rotation=45)

        plt.tight_layout()
        plt.show()

    except Exception as e:
        print(f"An error occurred in main: {str(e)}")
        traceback.print_exc()

if __name__ == "__main__":
    main()

Output hidden; open in https://colab.research.google.com to view.