In [None]:
# Re-import necessary libraries after reset
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Path to the exported training curves. Raw string: the Windows backslashes
# ("\q", "\T") are not valid escape sequences and warn on recent Pythons.
new_file = r"D:\quent\Téléchargements\Training curves.csv"

# Load the CSV file
df_new = pd.read_csv(new_file)

# The columns used below are at 0-based positions 0, 1, 4 and 7, so 8 columns
# are enough (requiring 9 rejected valid files).
if df_new.shape[1] < 8:
    raise ValueError("The file does not contain at least 8 columns required.")

# Extract relevant columns (1-based indexing: 1=x, 2,5,8=y)
x = pd.to_numeric(df_new.iloc[:, 0], errors="coerce")
y_cols = {
    "hairs dataset": pd.to_numeric(df_new.iloc[:, 7], errors="coerce"),
    "medium dataset": pd.to_numeric(df_new.iloc[:, 1], errors="coerce"),
    "large dataset": pd.to_numeric(df_new.iloc[:, 4], errors="coerce"),
}

plt.figure(figsize=(8, 5))

for name, y in y_cols.items():
    xy = pd.DataFrame({"x": x, "y": y}).dropna()
    xy = xy[xy["y"] > 0]  # log is only defined for strictly positive values
    if len(xy) == 0:
        continue
    xy_sorted = xy.sort_values("x")
    y_log = np.log(xy_sorted["y"].values)

    # EMA smoothing (adjust=False: recursive form y_t = (1-a) y_{t-1} + a x_t)
    y_smooth = pd.Series(y_log).ewm(span=200, adjust=False).mean()

    plt.plot(xy_sorted["x"].values, y_smooth, label=name)

plt.xlabel("Steps", fontsize=18)
plt.ylabel("ln(loss) smoothed by EMA", fontsize=18)
plt.title("")
plt.legend(fontsize=16)
plt.tight_layout()

out_path_training = "rapport/figures/training curves.png"
# savefig does not create missing parent directories.
import os
os.makedirs(os.path.dirname(out_path_training), exist_ok=True)
plt.savefig(out_path_training, dpi=150)
out_path_training


In [None]:
# We'll read the two uploaded CSV files, interpret the first column as x and the natural log of the second column as y.
# Then we'll plot one curve per file on a single chart, and save the figure for download.

import os
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Input CSVs to compare. Raw strings keep the Windows backslashes literal
# instead of relying on invalid escape sequences such as "\q".
files = [
    r"D:\quent\Téléchargements\RGB training.csv",
    r"D:\quent\Téléchargements\RGBA training.csv",
]

def load_first_two_numeric_cols(path):
    """Load a CSV and return its first two columns as a numeric ``x``/``y`` frame.

    The file is first read with pandas' header inference, falling back to
    ``header=None`` if that read raises. Both leading columns are coerced to
    numbers, and rows where either one is NaN are dropped. If nothing survives,
    the file is re-read with ``header=None``. Finally, only rows with ``y > 0``
    are kept so that ``np.log(y)`` is defined.

    Raises:
        ValueError: if the file has fewer than two columns, if no row can be
            converted to numbers, or if every ``y`` value is <= 0.
    """
    def _numeric_pair(frame):
        # Coerce the first two columns; unparsable cells become NaN and are dropped.
        first, second = frame.columns[:2]
        pair = pd.DataFrame({
            "x": pd.to_numeric(frame[first], errors="coerce"),
            "y": pd.to_numeric(frame[second], errors="coerce"),
        })
        return pair.dropna()

    try:
        table = pd.read_csv(path)
    except Exception:
        table = pd.read_csv(path, header=None)

    if table.shape[1] < 2:
        raise ValueError(f"{path} n'a pas au moins 2 colonnes")

    xy = _numeric_pair(table)
    if xy.empty:
        # Nothing numeric survived (e.g. only a header row was parsed): retry headerless.
        xy = _numeric_pair(pd.read_csv(path, header=None))
    if xy.empty:
        raise ValueError(f"Impossible de convertir les deux premières colonnes en numériques pour {path}")

    positive = xy[xy["y"] > 0]
    if positive.empty:
        raise ValueError(f"Toutes les valeurs de la 2e colonne sont ≤ 0 dans {path}, log impossible.")
    return positive

curves = {}
for f in files:
    if os.path.exists(f):
        curves[Path(f).stem] = load_first_two_numeric_cols(f)
    else:
        # Surface missing inputs instead of silently plotting fewer curves.
        print(f"⚠️ Fichier introuvable, ignoré : {f}")

# Plot: x = first column, y = ln(second column), one line per file.
plt.figure(figsize=(8, 5))
for name, xy in curves.items():
    # Sort by x to draw clean lines
    xy_sorted = xy.sort_values("x")
    plt.plot(xy_sorted["x"].values, np.log(xy_sorted["y"].values), label=name)

plt.xlabel("Première colonne (abscisse)")
plt.ylabel("logarithme naturel de la deuxième colonne (ordonnée)")
plt.title("Courbes : x = col1, y = ln(col2)")
if curves:
    # legend() warns when there are no labelled lines to show.
    plt.legend()
plt.tight_layout()

out_path = "rapport/figures/courbes_log.png"
# savefig does not create missing parent directories.
os.makedirs(os.path.dirname(out_path), exist_ok=True)
plt.savefig(out_path, dpi=150)
out_path


In [None]:
import os

plt.figure(figsize=(8, 5))

for name, xy in curves.items():
    xy_sorted = xy.sort_values("x")
    y_log = np.log(xy_sorted["y"].values)
    # EMA smoothing of ln(y); adjust=False uses the recursive EMA form.
    y_smooth = pd.Series(y_log).ewm(span=200, adjust=False).mean()
    plt.plot(xy_sorted["x"].values, y_smooth, label=name)

# Restrict the x-axis to the step range covered by every curve, so runs are
# compared over the same interval. This is loop-invariant, so it is set once.
# It is skipped when no curve was loaded, since max()/min() of nothing raises.
if curves:
    plt.xlim(
        max(curve["x"].min() for curve in curves.values()),
        min(curve["x"].max() for curve in curves.values()),
    )

plt.xlabel("Steps", fontsize=18)
plt.ylabel("ln(loss) smoothed by EMA", fontsize=18)
plt.title("")
plt.legend(fontsize=16)
plt.tight_layout()

out_path_smooth = "rapport/figures/RGB vs RGBA.png"
# savefig does not create missing parent directories.
os.makedirs(os.path.dirname(out_path_smooth), exist_ok=True)
plt.savefig(out_path_smooth, dpi=150)
out_path_smooth


In [None]:
from layer_diffuse.models import DDIMNextTokenV1_Refactored

# Build the refactored DDIM next-token pipeline, then load the checkpoint for
# run "run_2025-07-18_16-12-39" at epoch 80 from the hub.
pipeline = DDIMNextTokenV1_Refactored.DDIMNextTokenV1PipelineRefactored()
pipeline.load_model_from_hub(run="run_2025-07-18_16-12-39", epoch=80)

In [None]:
from layer_diffuse.data_loaders import ModularCharatersDataLoader
from layer_diffuse.models import DDIMNextTokenV1_Refactored
import json
# Streaming loader over the train split of QLeca/modular_characters_large
# (128px images, batches of 8, conversionRGBA=True). The pipeline's class
# embedding count is then sized to the loader's vocabulary.
dataloader = ModularCharatersDataLoader.get_modular_char_dataloader(
    dataset_name='QLeca/modular_characters_large',
    split='train',
    image_size=128,
    batch_size=8,
    shuffle=True,
    streaming=True,
    conversionRGBA=True,
)
pipeline.set_num_class_embeds(len(dataloader.vocab))


In [None]:
%pip install PIL

In [None]:
# Display the dataset vocabulary. Its length sizes the class embeddings, and
# vocab.get(prompt) is used to build labels; presumably it maps prompt strings to integer class ids.
dataloader.vocab

In [None]:
import torch
import matplotlib.pyplot as plt
# Prompts applied one after another, each adding a layer to the character.
character_pattern = [
    'Tint6 Left Arm', 'Tint6 Right Arm', 'Tint6 Neck',
    'Tint6 Head', 'Tint6 Left Hand', 'Tint6 Right Hand',
]
def show_image(image, prompt, step=None):
    """Display a single image tensor with matplotlib.

    ``image`` is squeezed (dropping a leading batch dim of 1), permuted from
    (C, H, W) to (H, W, C) and moved to CPU for ``plt.imshow``. Values are
    shown as-is, so callers should map them to [0, 1] first.

    Args:
        image: image tensor of shape (1, C, H, W) or (C, H, W).
        prompt: prompt text shown in the title.
        step: optional step index appended to the title. When None it is
            omitted, so one-off calls without a step also work.
    """
    title = f"Prompt: {prompt}" if step is None else f"Prompt: {prompt}, step {step}"
    plt.imshow(image.squeeze().permute(1, 2, 0).cpu().numpy())
    plt.title(title)
    plt.axis('off')
    plt.show()
    
# Start from an all-ones canvas. In the [-1, 1] value space used here (mapped
# back with x * 0.5 + 0.5) this is white.
# NOTE(review): the canvas has 3 channels, but the dataloader was created with
# conversionRGBA=True; confirm the channel count the loaded model expects.
current_image = torch.ones((1, 3, 128, 128))
current_prompt = ''

for i, prompt in enumerate(character_pattern):
    # Show the canvas before applying `prompt`, titled with the prompt that produced it.
    image_to_show = (current_image * 0.5 + 0.5).clamp(0, 1).cpu()
    show_image(image_to_show, current_prompt, i)
    current_prompt = prompt
    # NOTE(review): prompts missing from the vocabulary map to label -1; confirm the model accepts that.
    current_label = torch.tensor(dataloader.vocab.get(current_prompt, -1), dtype=torch.long).unsqueeze(0)
    current_image = pipeline(current_image, current_label, 200)

# Final result: map it to [0, 1] like the intermediate frames (imshow would
# otherwise clip the negative values). It is step len(character_pattern),
# one past the last loop index.
final_image = (current_image * 0.5 + 0.5).clamp(0, 1).cpu()
show_image(final_image, current_prompt, len(character_pattern))



In [None]:
# NOTE(review): `input_image` and `label` are not defined in any visible cell,
# so this cell appears to fail on a fresh kernel. TODO: build them explicitly
# (e.g. from a dataloader batch) before running it.
output = pipeline(input_image, label, 50)
output_images = (output * 0.5 + 0.5).clamp(0, 1).cpu()
# show_image takes (image, prompt, step); this is a single pipeline pass.
show_image(output_images, prompt, 1)

In [None]:

# Feed the previous output back through the pipeline for another pass (third
# argument presumably the number of inference steps, as in the cell above).
# NOTE: this rebinds `output` to a function of itself, so every re-run refines
# further (non-idempotent). Re-run the previous cell to start over.
output = pipeline(output, label,50)

In [None]:
# List the checkpoint versions available for the loaded pipeline (presumably runs and their epochs).
pipeline.list_versions()

In [None]:
from layer_diffuse.models.BaseNextTokenPipeline import BaseNextTokenPipeline
# Query the hub repository directly. The widget code below reads each returned
# entry's 'name' and 'epochs' keys.
BaseNextTokenPipeline.get_model_versions('QLeca/DDIMNextTokenV1')

In [None]:
import torch
import torchvision
from torchvision.utils import make_grid

def show_image_grid(input_images, output_images, target_images):
    """Display inputs, outputs and targets as three stacked rows.

    Each batch is mapped from [-1, 1] to [0, 1] and moved to CPU. The three
    batches are concatenated and tiled by ``make_grid`` with one column per
    sample, so each row holds one kind of image.
    """
    def to_unit_range(images):
        return (images * 0.5 + 0.5).clamp(0, 1).cpu()

    rows = [to_unit_range(images) for images in (input_images, output_images, target_images)]
    tiled = make_grid(torch.concat(rows), nrow=input_images.shape[0])
    display(torchvision.transforms.ToPILImage()(tiled))


In [None]:
# Smoke test: run the pipeline on the first batch only and show
# input / generated / target rows.
for batch in dataloader:
    input_images, target_images, labels = batch['input'], batch['target'], batch['label']
    outputs = pipeline(input_images=input_images, class_labels=labels, num_inference_steps=50)
    show_image_grid(input_images, outputs, target_images)
    break  # one batch is enough

# Inference test Widgets

In [None]:
# Interactive Model Selection Widget
import ipywidgets as widgets
from IPython.display import display, clear_output
import wandb
import json
import os
from layer_diffuse.models import DDIMNextTokenV1_Refactored, DDPMNextTokenV1, DDPMNextTokenV2, DDPMNextTokenV3_Refactored, BaseNextTokenPipeline
from layer_diffuse.data_loaders import ModularCharatersDataLoader

# State filled in by load_configuration.
current_pipeline = None
current_dataloader = None
current_config = {}

# Display name -> pipeline class.
MODEL_TYPES = {
    "DDIM Next Token V1 (Refactored)": DDIMNextTokenV1_Refactored.DDIMNextTokenV1PipelineRefactored,
    "DDPM Next Token V1": DDPMNextTokenV1.DDPMNextTokenV1Pipeline,
    "DDPM Next Token V2": DDPMNextTokenV2.DDPMNextTokenV2Pipeline,
    "DDPM Next Token V3 (Refactored)": DDPMNextTokenV3_Refactored.DDPMNextTokenV3Pipeline,
}

# Display name -> Hugging Face repository queried for published versions.
HF_REPOSITORIES = {
    "DDIM Next Token V1 (Refactored)": "QLeca/DDIMNextTokenV1",
    "DDPM Next Token V1": "QLeca/DDPMNextTokenV1",
    "DDPM Next Token V2": "QLeca/DDPMNextTokenV2",
    "DDPM Next Token V3 (Refactored)": "QLeca/DDPMNextTokenV3",
}

# Display name -> wandb project searched for the run's dataset config.
WANDB_PROJECTS = {
    "DDIM Next Token V1 (Refactored)": "ddim-next-token-v1",
    "DDPM Next Token V1": "ddpm-next-token-v1",
    "DDPM Next Token V2": "ddpm-next-token-v2",
    "DDPM Next Token V3 (Refactored)": "ddpm-next-token-v3",
}

# Runtime caches: versions fetched per model type, and dataset display name ->
# full dataset id (auto-detected from wandb runs).
available_versions, DATASETS = {}, {}

def _placeholder_dropdown(placeholder, description):
    """Create a disabled, 400px-wide dropdown whose only option is ``placeholder``."""
    return widgets.Dropdown(
        options=[placeholder],
        value=placeholder,
        description=description,
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='400px'),
        disabled=True,
    )


# Only the model type is selectable at first. The run, epoch and dataset
# dropdowns are filled in by the update_* callbacks once a model type is chosen.
model_dropdown = widgets.Dropdown(
    options=list(MODEL_TYPES),
    value=next(iter(MODEL_TYPES)),
    description='Model Type:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='400px'),
)
dataset_dropdown = _placeholder_dropdown('Select a run first', 'Dataset:')
run_name_dropdown = _placeholder_dropdown('Select a model type first', 'Run Name:')
epoch_dropdown = _placeholder_dropdown('Select a run first', 'Epoch:')

# Dataloader parameters.
batch_size_slider = widgets.IntSlider(
    value=8, min=1, max=32, step=1,
    description='Batch Size:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='400px'),
)
image_size_dropdown = widgets.Dropdown(
    options=[64, 128, 256, 512], value=128,
    description='Image Size:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='400px'),
)
streaming_checkbox = widgets.Checkbox(
    value=True,
    description='Use Streaming',
    style={'description_width': 'initial'},
)

# Status/log area written to by all callbacks.
output = widgets.Output()

def get_dataset_from_run(model_type, run_name):
    """Return the dataset name recorded in a wandb run's config, or None.

    Looks up the wandb project registered for ``model_type`` in
    ``WANDB_PROJECTS``. It scans that project's runs for the first one whose
    ``name`` equals ``run_name`` and reads ``config['dataset']['name']``.
    Returns None when the run or a non-empty name is not found. Any exception
    is printed and also turned into None.
    """
    try:
        import wandb
        api = wandb.Api()

        # Keep only the text after the last '/' as the project name.
        project_name = WANDB_PROJECTS[model_type].split('/')[-1]

        matching_run = next(
            (candidate for candidate in api.runs(project_name) if candidate.name == run_name),
            None,
        )
        if matching_run is None:
            return None

        dataset_name = matching_run.config.get('dataset', {}).get('name', None)
        return dataset_name or None

    except Exception as e:
        print(f"Error getting dataset from run {run_name}: {e}")
        return None
def get_available_versions(model_type):
    """Return the published versions (runs with epochs) for ``model_type``.

    Delegates to ``BaseNextTokenPipeline.get_model_versions`` on the model's
    Hugging Face repository from ``HF_REPOSITORIES``. Any error (including an
    unknown model type) is printed and ``[]`` is returned.
    """
    try:
        repo_id = HF_REPOSITORIES[model_type]
        return BaseNextTokenPipeline.BaseNextTokenPipeline.get_model_versions(repo_id)
    except Exception as err:
        print(f"Error getting versions for {model_type}: {err}")
        return []

def update_runs_for_model(change=None):
    """Refresh the run, epoch and dataset dropdowns for the selected model type.

    Used as the ``model_dropdown`` value observer (``change`` is unused) and
    called directly. Fetched versions are cached in ``available_versions``. On
    success the first run is selected and its epochs and dataset are loaded.
    Otherwise every dependent dropdown is reset to a disabled placeholder.
    """
    model_type = model_dropdown.value

    def _lock(dropdown, placeholder):
        # Replace the choices with a single placeholder and disable the control.
        dropdown.options = [placeholder]
        dropdown.value = placeholder
        dropdown.disabled = True

    with output:
        clear_output()
        print(f"🔄 Loading available runs for {model_type}...")

    try:
        versions = get_available_versions(model_type)
        available_versions[model_type] = versions

        if not versions:
            _lock(run_name_dropdown, 'No runs found')
            _lock(epoch_dropdown, 'No epochs available')
            _lock(dataset_dropdown, 'No dataset available')
            with output:
                clear_output()
                print(f"❌ No runs found for {model_type}")
            return

        run_names = [version['name'] for version in versions]
        first_run = run_names[0]
        run_name_dropdown.options = run_names
        run_name_dropdown.value = first_run
        run_name_dropdown.disabled = False

        # Populate epochs, then dataset, for the newly selected first run.
        update_epochs_for_run({'new': first_run})
        update_dataset_for_run({'new': first_run})

        with output:
            clear_output()
            print(f"✅ Found {len(versions)} runs for {model_type}")
            for rank, version in enumerate(versions[:3], start=1):
                shown = ', '.join(map(str, version['epochs'][:5]))
                more = '...' if len(version['epochs']) > 5 else ''
                print(f"  {rank}. {version['name']} - epochs: [{shown}{more}]")
            if len(versions) > 3:
                print(f"  ... and {len(versions) - 3} more runs")

    except Exception as e:
        with output:
            clear_output()
            print(f"❌ Error loading runs for {model_type}: {e}")
        _lock(run_name_dropdown, 'Error loading runs')
        _lock(epoch_dropdown, 'Error loading epochs')
        _lock(dataset_dropdown, 'Error loading dataset')

def update_dataset_for_run(change):
    """Show the dataset recorded in the selected run's wandb config.

    Used as a ``run_name_dropdown`` value observer (``change['new']`` is the
    run name). It is also called directly with ``{'new': run_name}`` or a bare
    run name. The dataset dropdown is always left disabled, since the value is
    auto-detected. On success the full dataset id is stored in ``DATASETS``
    under its display name, so ``load_configuration`` can map it back. A status
    line is appended under the previous message in ``output``.
    """
    model_type = model_dropdown.value
    run_name = change['new'] if isinstance(change, dict) and 'new' in change else change

    def _last_output_text():
        # Text of the most recent output in the status widget, or "".
        # Non-stream entries (e.g. display_data) have no 'text' key. Indexing
        # them directly raised KeyError, which escaped even from the except branch.
        if not output.outputs:
            return ""
        return output.outputs[-1].get('text', "")

    def _append_status(message):
        # Keep the previous status text and add `message` underneath it.
        with output:
            previous = _last_output_text()
            clear_output()
            print(previous, end="")  # stream text already ends with its own newline
            print(message)

    try:
        dataset_name = get_dataset_from_run(model_type, run_name)

        if dataset_name:
            # Readable label, e.g. "QLeca/modular_characters_large" -> "Modular Characters Large".
            dataset_display_name = dataset_name.split('/')[-1].replace('_', ' ').title()

            dataset_dropdown.options = [dataset_display_name]
            dataset_dropdown.value = dataset_display_name
            dataset_dropdown.disabled = True  # auto-detected, not user-selectable

            # Remember the real dataset id for loading.
            DATASETS[dataset_display_name] = dataset_name

            _append_status(f"📊 Auto-detected dataset: {dataset_name}")
        else:
            dataset_dropdown.options = ['Dataset not found in run']
            dataset_dropdown.value = 'Dataset not found in run'
            dataset_dropdown.disabled = True

            _append_status(f"⚠️  Could not detect dataset for run {run_name}")

    except Exception as e:
        dataset_dropdown.options = ['Error loading dataset']
        dataset_dropdown.value = 'Error loading dataset'
        dataset_dropdown.disabled = True

        _append_status(f"❌ Error loading dataset for {run_name}: {e}")

def update_epochs_for_run(change):
    """Fill the epoch dropdown with the checkpoints published for the selected run.

    Used as a ``run_name_dropdown`` value observer and called directly with
    ``{'new': run_name}`` or a bare run name. It reads the versions cached in
    ``available_versions`` for the current model type and does nothing if
    none were fetched yet. Epochs are listed in descending sort order and the
    first one is preselected.
    """
    model_type = model_dropdown.value
    run_name = change['new'] if isinstance(change, dict) and 'new' in change else change

    if model_type not in available_versions:
        return

    try:
        selected_version = next(
            (version for version in available_versions[model_type] if version['name'] == run_name),
            None,
        )

        if selected_version and selected_version['epochs']:
            # Descending order, so the default selection is the largest epoch
            # (numerically largest when epochs are ints).
            epochs = [str(epoch) for epoch in sorted(selected_version['epochs'], reverse=True)]
            epoch_dropdown.options = epochs
            epoch_dropdown.value = epochs[0]
            epoch_dropdown.disabled = False
            message = f"✅ Available epochs for {run_name}: {', '.join(epochs)}"
        else:
            epoch_dropdown.options = ['No epochs available']
            epoch_dropdown.value = 'No epochs available'
            epoch_dropdown.disabled = True
            message = f"❌ No epochs found for run {run_name}"

        with output:
            clear_output()
            print(message)

    except Exception as e:
        with output:
            clear_output()
            print(f"❌ Error loading epochs for {run_name}: {e}")

        epoch_dropdown.options = ['Error loading epochs']
        epoch_dropdown.value = 'Error loading epochs'
        epoch_dropdown.disabled = True

# Action buttons: load the selected configuration, or re-query the runs for
# the current model type.
load_button = widgets.Button(
    description='🚀 Load Configuration', button_style='primary',
    layout=widgets.Layout(width='200px'),
)
refresh_runs_button = widgets.Button(
    description='🔄 Refresh Runs', button_style='info',
    layout=widgets.Layout(width='150px'),
)


def refresh_current_model_runs(b=None):
    """Button callback: re-fetch the runs for the selected model type (``b`` is unused)."""
    update_runs_for_model()

def load_configuration(b):
    """Button callback: build the pipeline and dataloader from the widget selection.

    Steps:
    1. Read every widget and check that a real dataset (and, for a selected
       run, a numeric epoch) was detected.
    2. Instantiate the selected pipeline class.
    3. Load the vocabulary from ``layer_diffuse/vocab.json``.
    4. Load the chosen run/epoch from the hub and size the class embeddings.
    5. Create the dataloader.

    Results are stored in the globals ``current_pipeline``,
    ``current_dataloader`` and ``current_config``. All progress and errors are
    printed into the ``output`` widget; exceptions are reported, not raised.
    """
    global current_pipeline, current_dataloader, current_config

    with output:
        clear_output()
        print("🔄 Loading configuration...")

        try:
            # Get selected values
            model_type = model_dropdown.value
            # DATASETS only holds auto-detected datasets. Placeholders such as
            # 'Dataset not found in run' are not keys, so use .get() and validate
            # below instead of failing with an opaque KeyError.
            dataset_name = DATASETS.get(dataset_dropdown.value)
            # A disabled run dropdown only shows a placeholder ('No runs found', ...).
            run_name = None if run_name_dropdown.disabled else run_name_dropdown.value
            epoch_value = str(epoch_dropdown.value)
            epoch = int(epoch_value) if epoch_value.isdigit() else 0
            batch_size = batch_size_slider.value
            image_size = image_size_dropdown.value
            streaming = streaming_checkbox.value

            print(f"📋 Configuration:")
            print(f"  - Model: {model_type}")
            print(f"  - Dataset: {dataset_dropdown.value} ({dataset_name})")
            print(f"  - Run: {run_name}")
            print(f"  - Epoch: {epoch}")
            print(f"  - Batch size: {batch_size}")
            print(f"  - Image size: {image_size}")
            print(f"  - Streaming: {streaming}")

            # Validate that we have a valid dataset
            if dataset_name is None:
                print("❌ No valid dataset detected. Please select a run with dataset information.")
                return
            # A selected run needs a real epoch; a placeholder would silently load epoch 0.
            if run_name and not epoch_value.isdigit():
                print("❌ No valid epoch selected for this run.")
                return

            # 1. Initialize pipeline
            print(f"\n🤖 Initializing {model_type} pipeline...")
            pipeline_class = MODEL_TYPES[model_type]
            current_pipeline = pipeline_class()

            # 2. Load vocabulary
            print(f"📚 Loading vocabulary...")
            vocab_file = "layer_diffuse/vocab.json"
            with open(vocab_file, 'r') as f:
                vocab = json.load(f)
            print(f"✅ Loaded vocabulary with {len(vocab)} classes")

            # 3. Load model from hub
            if run_name:
                print(f"📥 Loading model from hub...")
                current_pipeline.load_model_from_hub(run=run_name, epoch=epoch)
                current_pipeline.set_num_class_embeds(len(vocab))
                print(f"✅ Model loaded successfully")
            else:
                print("⚠️  No run name specified - using default model weights")

            # 4. Create dataloader
            print(f"🗂️  Creating dataloader...")
            current_dataloader = ModularCharatersDataLoader.get_modular_char_dataloader(
                dataset_name=dataset_name,
                split='train',
                image_size=image_size,
                batch_size=batch_size,
                shuffle=True,
                streaming=streaming,
                conversionRGBA=True,
                vocab=vocab
            )
            print(f"✅ Dataloader created successfully")

            # 5. Store configuration
            current_config = {
                'model_type': model_type,
                'dataset_name': dataset_name,
                'run_name': run_name,
                'epoch': epoch,
                'batch_size': batch_size,
                'image_size': image_size,
                'streaming': streaming,
                'vocab_size': len(vocab)
            }

            print(f"\n🎉 Configuration loaded successfully!")
            print(f"✅ Pipeline: {type(current_pipeline).__name__}")
            print(f"✅ Dataloader: Ready with {len(vocab)} classes")
            print(f"\n💡 You can now use 'current_pipeline' and 'current_dataloader' for inference!")

        except Exception as e:
            print(f"❌ Error loading configuration: {e}")
            import traceback
            traceback.print_exc()

# Wire the callbacks. Buttons trigger actions; dropdown changes cascade from
# model type -> runs, and from run -> epochs and dataset.
load_button.on_click(load_configuration)
refresh_runs_button.on_click(refresh_current_model_runs)
model_dropdown.observe(update_runs_for_model, names='value')
for _run_observer in (update_epochs_for_run, update_dataset_for_run):
    run_name_dropdown.observe(_run_observer, names='value')

# Render the control panel.
print("🎛️ Model Configuration Widget")
print("=" * 50)

config_panel_children = [
    widgets.HTML("<h3>🎛️ Model Configuration</h3>"),
    model_dropdown,
    dataset_dropdown,
    run_name_dropdown,
    epoch_dropdown,
    widgets.HTML("<h4>📊 Dataset Parameters</h4>"),
    batch_size_slider,
    image_size_dropdown,
    streaming_checkbox,
    widgets.HTML("<h4>🔧 Actions</h4>"),
    widgets.HBox([load_button, refresh_runs_button]),
    output,
]
display(widgets.VBox(config_panel_children))

# Populate the run list for the default model type. The observers above are
# already attached, so dropdown changes made here cascade too.
update_runs_for_model()

In [None]:
# Test Inference with Selected Configuration
import torch
import torchvision
from torchvision.utils import make_grid
from IPython.display import display
import matplotlib.pyplot as plt

def test_inference(num_samples=1, num_inference_steps=50):
    """Run the loaded pipeline on dataloader samples and display the results.

    Draws batches from ``current_dataloader`` until exactly ``num_samples``
    samples (or fewer, if the loader runs out) have been generated. Each batch
    is run through ``current_pipeline`` under ``torch.no_grad()``, and its
    input / generated / target images are plotted. Errors are printed with a
    traceback instead of raised.

    Args:
        num_samples: total number of samples to generate across batches.
        num_inference_steps: value passed to the pipeline's
            ``num_inference_steps`` argument.
    """
    if current_pipeline is None:
        print("❌ No pipeline loaded! Please load a configuration first.")
        return

    if current_dataloader is None:
        print("❌ No dataloader available! Please load a configuration first.")
        return

    print(f"🧪 Testing inference with current configuration...")
    print(f"📋 Config: {current_config['model_type']}")
    print(f"🗂️  Dataset: {current_config['dataset_name']}")
    print(f"🎯 Run: {current_config['run_name']} (epoch {current_config['epoch']})")
    print(f"🔢 Inference steps: {num_inference_steps}")

    try:
        sample_count = 0
        if num_samples > 0:
            for batch in current_dataloader:
                # Take only what is still needed, so the total never exceeds
                # num_samples even when batches are smaller than num_samples.
                remaining = num_samples - sample_count
                input_images = batch['input'][:remaining]
                target_images = batch['target'][:remaining]
                labels = batch['label'][:remaining]
                # Keep prompts aligned with the kept images.
                prompts = batch.get('prompt', ['Unknown'] * len(input_images))[:len(input_images)]

                print(f"\n🎨 Generating {input_images.shape[0]} image(s)...")
                print(f"📏 Input shape: {input_images.shape}")
                print(f"🏷️  Label shape: {labels.shape}")

                # Run inference
                with torch.no_grad():
                    outputs = current_pipeline(
                        input_images=input_images,
                        class_labels=labels,
                        num_inference_steps=num_inference_steps
                    )

                print(f"✅ Generated output shape: {outputs.shape}")

                # Display results
                show_inference_results(input_images, outputs, target_images, prompts)

                sample_count += input_images.shape[0]
                # Stop here rather than pulling one more (possibly streamed) batch.
                if sample_count >= num_samples:
                    break

        print(f"\n🎉 Inference completed successfully!")

    except Exception as e:
        print(f"❌ Error during inference: {e}")
        import traceback
        traceback.print_exc()

def show_inference_results(input_images, output_images, target_images, prompts=None):
    """Plot input / generated / target images side by side, one row per sample.

    The three batches are mapped from [-1, 1] to [0, 1], interleaved per sample
    and tiled by ``make_grid`` with 3 columns. They are then drawn with column
    headers and, if given, one prompt label per row.

    Args:
        input_images, output_images, target_images: image batches sharing the
            same leading batch dimension (assumed (B, C, H, W); TODO confirm).
        prompts: optional sequence of labels. Only the first ``B`` entries are
            drawn, one per row; extra entries would land outside the grid.
    """

    # Denormalize images (from [-1, 1] to [0, 1])
    def denormalize(tensor):
        return (tensor * 0.5 + 0.5).clamp(0, 1).cpu()

    input_norm = denormalize(input_images)
    output_norm = denormalize(output_images)
    target_norm = denormalize(target_images)

    batch_size = input_images.shape[0]

    # Interleave as [input1, output1, target1, input2, ...] so each grid row is one sample.
    all_images = []
    for i in range(batch_size):
        all_images.extend([input_norm[i], output_norm[i], target_norm[i]])
    grid_tensor = torch.stack(all_images)
    grid = make_grid(grid_tensor, nrow=3, padding=2, pad_value=1.0)
    img = torchvision.transforms.ToPILImage()(grid)

    fig, ax = plt.subplots(1, 1, figsize=(12, 4 * batch_size))
    ax.imshow(img)
    ax.axis('off')

    # Column headers, roughly centred over each third of the grid.
    for x_frac, header in ((0.17, 'Input'), (0.50, 'Generated'), (0.83, 'Target')):
        ax.text(img.width * x_frac, -20, header, ha='center', va='bottom',
                fontsize=12, fontweight='bold')

    if prompts:
        # zip with range(batch_size) limits labels to rows that actually exist.
        row_height = img.height / batch_size
        for i, prompt in zip(range(batch_size), prompts):
            ax.text(-50, (i + 0.5) * row_height, f"'{prompt}'", ha='right', va='center',
                    fontsize=10, rotation=90, fontweight='bold')

    ax.set_title(f"Inference Results - {current_config.get('model_type', 'Unknown Model')}",
                 fontsize=14, fontweight='bold', pad=30)
    plt.tight_layout()
    plt.show()

# Controls for the inference test: number of inference steps and sample count.
inference_steps_slider = widgets.IntSlider(
    value=50, min=1, max=100, step=1,
    description='Inference Steps:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='300px'),
)
num_samples_slider = widgets.IntSlider(
    value=1, min=1, max=8, step=1,
    description='Num Samples:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='300px'),
)
test_button = widgets.Button(
    description='🎨 Test Inference',
    button_style='success',
    layout=widgets.Layout(width='150px'),
)


def on_test_click(b):
    """Button callback: run test_inference with the current slider values (``b`` is unused)."""
    test_inference(
        num_samples=num_samples_slider.value,
        num_inference_steps=inference_steps_slider.value,
    )


test_button.on_click(on_test_click)

print("\n" + "=" * 50)
print("🎨 Inference Testing")
display(widgets.VBox([
    widgets.HTML("<h3>🎨 Test Inference</h3>"),
    inference_steps_slider,
    num_samples_slider,
    test_button,
]))


def quick_test():
    """Fast smoke test: 2 samples with 20 inference steps."""
    test_inference(num_samples=2, num_inference_steps=20)


usage_lines = (
    "1. Configure your model using the widget above",
    "2. Click '🚀 Load Configuration' to initialize",
    "3. Use the inference controls or call quick_test() for a fast test",
)
print("\n💡 Usage:")
for usage_line in usage_lines:
    print(usage_line)