In [None]:
import os
import re
import pandas as pd

import matplotlib.pyplot as plt
from PIL import Image

In [None]:
def gather_generated_images_with_hyperparams(base_folder: str):
    """
    Walk `base_folder` and index every generated PNG whose filename encodes
    its hyperparameters.

    Expected layout (as in the benchmark_images_generations folder):
        <base_folder>/<category>/<sample description>/<generated image>.png
    with an optional baseline image named cp_bg_fg.jpg or cp_bg_fg.png in the
    same sample folder (the .jpg is preferred when both exist).

    Args:
        base_folder (str): Root folder to scan recursively.

    Returns:
        pd.DataFrame: One row per generated image with the parsed
        hyperparameters (alpha_noise, timesteps, inject_q, inject_k, inject_v,
        tau_alpha, tau_beta, guidance_scale, layers_for_injection) plus
        category, sample_description, image_path and baseline_image_path.
        The columns are present even when no image is found. Images lying
        directly in `base_folder` have no sample folder, so their category and
        sample_description are None.
    """
    # Regex to parse the filename and extract hyperparameters
    # Example: alphanoise0.05_timesteps50_QTrue_KTrue_VTrue_taua0.4_taub0.8_guidance3.0_all-layers.png
    pattern = re.compile(
        r"alphanoise(?P<alpha_noise>[\d.]+)_"
        r"timesteps(?P<timesteps>\d+)_"
        r"Q(?P<inject_q>True|False)_"
        r"K(?P<inject_k>True|False)_"
        r"V(?P<inject_v>True|False)_"
        r"taua(?P<tau_alpha>[\d.]+)_"
        r"taub(?P<tau_beta>[\d.]+)_"
        r"guidance(?P<guidance_scale>[\d.]+)_(?P<layers_for_injection>all|vital)-layers\.png"
    )
    columns = [
        'alpha_noise', 'timesteps', 'inject_q', 'inject_k', 'inject_v',
        'tau_alpha', 'tau_beta', 'guidance_scale', 'layers_for_injection',
        'category', 'sample_description', 'image_path', 'baseline_image_path',
    ]
    base_norm = os.path.normpath(base_folder)

    image_data = []
    for root, dirs, files in os.walk(base_folder):
        # Sorting in place makes os.walk descend in a reproducible order.
        dirs.sort()

        # Only generated images match the pattern; baselines and other files do not.
        matches = []
        for file in sorted(files):
            if not file.endswith(".png"):
                continue
            match = pattern.match(file)
            if match:
                matches.append((file, match))
        if not matches:
            continue

        # Category / sample description come from the folders that contain the
        # image, not from how many components the spelled-out path happens to
        # have (which depends on whether base_folder is relative or absolute).
        if os.path.normpath(root) == base_norm:
            category = sample_description = None
        else:
            sample_description = os.path.basename(os.path.normpath(root))
            category = os.path.basename(os.path.dirname(os.path.abspath(root)))

        # The baseline is shared by every image of the sample folder: look it up once.
        baseline_image_path = None
        for baseline_name in ('cp_bg_fg.jpg', 'cp_bg_fg.png'):
            candidate = os.path.join(root, baseline_name)
            if os.path.exists(candidate):
                baseline_image_path = candidate
                break

        for file, match in matches:
            raw = match.groupdict()
            full_path = os.path.join(root, file)
            try:
                hyperparams = {
                    'alpha_noise': float(raw['alpha_noise']),
                    'timesteps': int(raw['timesteps']),
                    'inject_q': raw['inject_q'] == 'True',
                    'inject_k': raw['inject_k'] == 'True',
                    'inject_v': raw['inject_v'] == 'True',
                    'tau_alpha': float(raw['tau_alpha']),
                    'tau_beta': float(raw['tau_beta']),
                    'guidance_scale': float(raw['guidance_scale']),
                    # layers_for_injection is already a string ('all' or 'vital')
                    'layers_for_injection': raw['layers_for_injection'],
                }
            except ValueError:
                # [\d.]+ also matches strings such as "0.0.5"; skip that one file
                # instead of aborting the whole scan.
                print(f"Skipping '{full_path}': could not parse numeric hyperparameters from the filename.")
                continue

            hyperparams['category'] = category
            hyperparams['sample_description'] = sample_description
            hyperparams['image_path'] = full_path
            hyperparams['baseline_image_path'] = baseline_image_path
            image_data.append(hyperparams)

    return pd.DataFrame(image_data, columns=columns)


# Fallback value for each hyperparameter encoded in a generated image's filename
# (per the original note, these mirror the defaults of run_with_params.py).
DEFAULT_HYPERPARAMS = dict(
    alpha_noise=0.05,
    timesteps=50,
    inject_q=False,
    inject_k=False,
    inject_v=False,
    tau_alpha=0.4,
    tau_beta=0.8,
    guidance_scale=3.0,
    layers_for_injection='vital',
)

# Names of every filename-encoded hyperparameter, in the order declared above
ALL_FILENAME_PARAMS = [*DEFAULT_HYPERPARAMS]

def _wrap_sample_description(description: str, words_per_line: int = 6) -> str:
    """Drop the 4-character prefix of a sample folder name and wrap the rest, `words_per_line` words per line."""
    # Assumes folder names start with a 4-character index such as "0000" — confirm against the dataset.
    words = description[4:].strip().split()
    return '\n'.join(
        ' '.join(words[i:i + words_per_line]) for i in range(0, len(words), words_per_line)
    )


def _try_show_image(ax, image_path):
    """Draw the image stored at `image_path` on `ax`; return the exception raised while loading, or None."""
    try:
        # The context manager releases the file handle once imshow has read the pixels.
        with Image.open(image_path) as img:
            ax.imshow(img)
    except Exception as exc:  # any failure is reported inside the grid cell by the caller
        return exc
    return None


def visualize_hyperparameter_variation(
    df: pd.DataFrame,
    fixed_hyperparams_input: dict,
    varying_hyperparam_name: str,
    hide_incomplete_rows: bool = False,
    w_pad: int = 0,
    output_dir: str = "../assets/",
) -> None:
    """
    Visualizes a grid of images, with a baseline image per sample,
    and then varying one hyperparameter while others are fixed.

    Each row is one sample (category + sample description); the first column
    is the sample's baseline image and every further column is one observed
    value of the varying hyperparameter. Missing images are shown as 'N/A'.
    The figure is displayed and, unless `output_dir` is None, saved as a PNG.

    Args:
        df (pd.DataFrame): DataFrame containing image paths, parsed hyperparameters, and baseline_image_path.
        fixed_hyperparams_input (dict): Dictionary of hyperparameter names and their fixed values.
                                        Parameters not in this dict (and not varying) will use defaults.
                                        Names that are unknown or equal to the varying one are ignored
                                        (a warning is printed).
        varying_hyperparam_name (str): Name of the hyperparameter to vary.
        hide_incomplete_rows (bool, optional): If True, rows (samples) with missing images
                                               for any variation will not be shown. Defaults to False.
        w_pad (int, optional): Horizontal padding for tight_layout. Values below 1 are raised to 1.0
                               when there is at least one variation column.
        output_dir (str, optional): Folder the figure is written to (created if needed).
                                    Pass None to skip saving. Defaults to "../assets/".
    """
    if df.empty:
        print("Input DataFrame is empty. Cannot generate visualization.")
        return

    if varying_hyperparam_name not in df.columns:
        print(f"Varying parameter '{varying_hyperparam_name}' not found in DataFrame columns. Cannot generate visualization.")
        return

    fixed_hyperparams_input = fixed_hyperparams_input or {}

    # Every filename hyperparameter except the varying one is pinned, either to
    # the caller's value or to its default.
    current_fixed_conditions = {
        param_name: fixed_hyperparams_input.get(param_name, DEFAULT_HYPERPARAMS[param_name])
        for param_name in ALL_FILENAME_PARAMS
        if param_name != varying_hyperparam_name
    }

    ignored_params = [name for name in fixed_hyperparams_input if name not in current_fixed_conditions]
    if ignored_params:
        print(f"Warning: ignoring fixed hyperparameters that are unknown or being varied: {ignored_params}")

    # Filter the DataFrame based on fixed conditions
    df_filtered = df
    for param, value in current_fixed_conditions.items():
        if param in df_filtered.columns:
            df_filtered = df_filtered[df_filtered[param] == value]
        else:
            print(f"Warning: Fixed parameter '{param}' not found in DataFrame columns. Skipping this filter.")

    if df_filtered.empty:
        print(f"No images found matching the fixed hyperparameters: {current_fixed_conditions}")
        return

    # Images without a category / sample description cannot be assigned to a row.
    num_matching = len(df_filtered)
    df_filtered = df_filtered.dropna(subset=['category', 'sample_description'])
    if len(df_filtered) < num_matching:
        print(f"Warning: ignoring {num_matching - len(df_filtered)} image(s) without category/sample description.")

    # Identify unique samples (rows), sorted by category then description for
    # consistent row ordering; the baseline path is kept with each sample.
    unique_samples_all = (
        df_filtered
        .sort_values(by=['category', 'sample_description'])
        .loc[:, ['category', 'sample_description', 'baseline_image_path']]
        .drop_duplicates(subset=['category', 'sample_description'])
        .reset_index(drop=True)
    )

    # Observed values of the varying hyperparameter (columns)
    varying_values = sorted(df_filtered[varying_hyperparam_name].unique())
    num_cols_variations = len(varying_values)

    # (category, description, varying value) -> path of the first matching image.
    # One pass over the data replaces a DataFrame filter per grid cell.
    image_lookup = {}
    for cell_key_and_path in zip(
        df_filtered['category'],
        df_filtered['sample_description'],
        df_filtered[varying_hyperparam_name],
        df_filtered['image_path'],
    ):
        image_lookup.setdefault(cell_key_and_path[:3], cell_key_and_path[3])

    if hide_incomplete_rows:
        is_complete = [
            all((category, description, value) in image_lookup for value in varying_values)
            for category, description in zip(
                unique_samples_all['category'], unique_samples_all['sample_description']
            )
        ]
        unique_samples = unique_samples_all.loc[is_complete].reset_index(drop=True)
    else:
        unique_samples = unique_samples_all

    num_rows = len(unique_samples)

    if num_rows == 0:
        print("No samples to display after filtering.")
        print(f"  Fixed conditions: {current_fixed_conditions}")
        print(f"  Varying parameter: {varying_hyperparam_name}")
        if hide_incomplete_rows and num_cols_variations > 0:
            print("  Incomplete rows might have been hidden.")
        return

    # At least one sample remains, so at least one varying value was observed.
    num_cols_plot = num_cols_variations + 1  # +1 for the baseline image column

    fig, axes = plt.subplots(num_rows, num_cols_plot, figsize=(num_cols_plot * 4, num_rows * 3.5), squeeze=False)

    for r, (current_category, current_sample_desc, baseline_path) in enumerate(
        unique_samples.itertuples(index=False, name=None)
    ):
        # Baseline image in the first column, labelled with the sample it belongs to
        ax_baseline = axes[r, 0]
        ylabel_text = f"{current_category.upper()}\n{_wrap_sample_description(current_sample_desc)}"
        ax_baseline.set_ylabel(ylabel_text, rotation=0, fontsize=18, ha='right', va='center', labelpad=60)

        # isinstance guards against a missing path stored as None/NaN
        if isinstance(baseline_path, str) and os.path.exists(baseline_path):
            error = _try_show_image(ax_baseline, baseline_path)
            if error is not None:
                ax_baseline.text(0.5, 0.5, f'Baseline Error:\n{error}', ha='center', va='center', fontsize=8, color='red')
        else:
            ax_baseline.text(0.5, 0.5, 'Baseline N/A', ha='center', va='center', fontsize=10)

        ax_baseline.set_xticks([])
        ax_baseline.set_yticks([])
        for spine in ax_baseline.spines.values():
            spine.set_color("blue")
            spine.set_linewidth(2)
        if r == 0:
            ax_baseline.set_title("Baseline", fontsize=18, pad=20, color="blue")

        # One column per observed value of the varying hyperparameter
        for c, current_varying_value in enumerate(varying_values):
            ax_variant = axes[r, c + 1]

            image_path = image_lookup.get((current_category, current_sample_desc, current_varying_value))
            if image_path is None:
                ax_variant.text(0.5, 0.5, 'N/A', ha='center', va='center', fontsize=10)
            else:
                error = _try_show_image(ax_variant, image_path)
                if isinstance(error, FileNotFoundError):
                    ax_variant.text(0.5, 0.5, 'Image not found', ha='center', va='center', fontsize=8, color='red')
                elif error is not None:
                    ax_variant.text(0.5, 0.5, f'Error loading:\n{error}', ha='center', va='center', fontsize=8, color='red')

            ax_variant.set_xticks([])
            ax_variant.set_yticks([])
            for spine in ax_variant.spines.values():
                spine.set_visible(False)

            if r == 0:  # Set column titles only for the first row
                ax_variant.set_title(f"{varying_hyperparam_name}={current_varying_value}", fontsize=18, pad=20)

    # Title stating the varying hyperparameter and the fixed parameters (4 per line)
    fixed_param_details = [f"{param}={value}" for param, value in current_fixed_conditions.items()]
    title_parts = [f"Varying '{varying_hyperparam_name}'"]
    title_parts.extend(
        ", ".join(fixed_param_details[i:i + 4]) for i in range(0, len(fixed_param_details), 4)
    )
    fig.suptitle("\n".join(title_parts), fontsize=20, y=1.05)

    # Keep a bit of horizontal space between the baseline and the generated images
    custom_w_pad = 1.0 if (num_cols_plot > 1 and w_pad < 1) else w_pad

    # rect leaves room for the suptitle and the y-labels
    fig.tight_layout(rect=[0, 0.02, 1, 0.96], w_pad=custom_w_pad, h_pad=2.0)

    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)
        output_name = "ablation_on_{}___params_{}.png".format(
            varying_hyperparam_name,
            '_'.join([f"{k}{v}" for k, v in current_fixed_conditions.items()])
        )
        fig.savefig(os.path.join(output_dir, output_name), bbox_inches='tight', dpi=300)
    plt.show()

In [None]:
# Root folder holding the generated benchmark images (relative to this notebook)
benchmark_folder = '../benchmark_images_generations/'

# Index every generated image together with its parsed hyperparameters
df_generated_images = gather_generated_images_with_hyperparams(base_folder=benchmark_folder)

print(f"Found {len(df_generated_images)} images.")

### Ablation on $\tau_{\alpha}$

In [None]:
# tau_alpha sweep with Q/K/V injection enabled on all layers
visualize_hyperparameter_variation(
    df=df_generated_images,
    fixed_hyperparams_input={
        'layers_for_injection': 'all',
        'inject_q': True,
        'inject_k': True,
        'inject_v': True,
    },
    varying_hyperparam_name='tau_alpha',
    hide_incomplete_rows=True,
    w_pad=0,
)


In [None]:
# tau_alpha sweep with Q/K/V injection enabled on the vital layers only
visualize_hyperparameter_variation(
    df=df_generated_images,
    fixed_hyperparams_input={
        'layers_for_injection': 'vital',
        'inject_q': True,
        'inject_k': True,
        'inject_v': True,
    },
    varying_hyperparam_name='tau_alpha',
    hide_incomplete_rows=True,
    w_pad=0,
)


### Ablation on $\tau_{\beta}$

In [None]:
# tau_beta sweep with Q/K/V injection enabled on the vital layers only
# (w_pad is left at its default)
visualize_hyperparameter_variation(
    df=df_generated_images,
    fixed_hyperparams_input={
        'layers_for_injection': 'vital',
        'inject_q': True,
        'inject_k': True,
        'inject_v': True,
    },
    varying_hyperparam_name='tau_beta',
    hide_incomplete_rows=True,
)