# Polygon Color UNet - Inference and Testing

This notebook demonstrates how to use the trained UNet model to generate colored polygons.

## Setup and Imports

In [None]:
import sys
import os

# Add src directory to path
sys.path.append('src')

import torch
import torch.nn.functional as F
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image
import json

# Import our modules
from model import PolygonColorUNet
from dataset import COLOR_NAMES, COLOR_TO_IDX, create_color_onehot
from utils import (
    visualize_results, 
    preprocess_image_for_inference, 
    postprocess_output,
    create_sample_polygon,
    create_color_legend
)

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Set matplotlib style
plt.style.use('default')
%matplotlib inline

## Model Loading

Load the trained model from checkpoint.

In [None]:
# Initialize model
model = PolygonColorUNet(n_colors=len(COLOR_NAMES)).to(device)

# Load trained weights
checkpoint_path = 'checkpoints/best_model.pth'  # Update this path as needed

if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded model from {checkpoint_path}")
    print(f"Best IoU: {checkpoint.get('best_iou', 'N/A'):.4f}")
else:
    print(f"Checkpoint not found at {checkpoint_path}")
    print("Please train the model first or update the checkpoint path")

# Set model to evaluation mode
model.eval()
print("Model loaded and set to evaluation mode")

## Available Colors

Display all available colors that the model can generate.

In [None]:
print("Available colors:")
for i, color in enumerate(COLOR_NAMES):
    print(f"{i}: {color}")

# Show color legend
create_color_legend()

## Helper Functions for Inference

In [None]:
def predict_colored_polygon(model, input_image, color_name, device):
    """
    Predict colored polygon from input image and color name
    
    Args:
        model: Trained UNet model
        input_image: Input polygon image (can be path, numpy array, or tensor)
        color_name: Name of the desired color
        device: Device to run inference on
    
    Returns:
        Predicted colored polygon as numpy array
    """
    # Preprocess input image
    if isinstance(input_image, str):
        # If path is provided
        input_tensor = preprocess_image_for_inference(input_image)
    elif isinstance(input_image, np.ndarray):
        # If numpy array is provided
        if input_image.ndim == 2:  # Grayscale
            input_image = input_image.astype(np.float32) / 255.0
            input_tensor = torch.from_numpy(input_image).unsqueeze(0).unsqueeze(0)
        else:
            raise ValueError("Input image should be grayscale (2D array)")
    elif torch.is_tensor(input_image):
        input_tensor = input_image
    else:
        raise ValueError("Input image should be path, numpy array, or tensor")
    
    # Create color one-hot encoding
    if color_name not in COLOR_TO_IDX:
        raise ValueError(f"Unknown color: {color_name}. Available colors: {list(COLOR_TO_IDX.keys())}")
    
    color_onehot = create_color_onehot(color_name)
    color_tensor = torch.from_numpy(color_onehot).unsqueeze(0)  # Add batch dimension
    
    # Move to device
    input_tensor = input_tensor.to(device)
    color_tensor = color_tensor.to(device)
    
    # Inference
    with torch.no_grad():
        output = model(input_tensor, color_tensor)
    
    # Postprocess output
    result = postprocess_output(output)
    
    return result, input_tensor.cpu(), output.cpu()

def batch_inference(model, input_images, colors, device):
    """
    Perform batch inference on multiple images
    """
    results = []
    
    for img, color in zip(input_images, colors):
        result, input_tensor, output_tensor = predict_colored_polygon(model, img, color, device)
        results.append((result, input_tensor, output_tensor, color))
    
    return results

## Example 1: Generate Sample Polygons and Test

Create some sample polygons and test the model with different colors.

In [None]:
# Create sample polygons
shapes = ['triangle', 'square', 'pentagon', 'hexagon', 'octagon']
test_colors = ['red', 'blue', 'yellow', 'green', 'purple']

# Generate and visualize results
fig, axes = plt.subplots(len(shapes), 3, figsize=(12, 15))

for i, (shape, color) in enumerate(zip(shapes, test_colors)):
    # Create sample polygon
    polygon = create_sample_polygon(shape, size=(256, 256))
    
    # Predict colored version
    try:
        colored_result, input_tensor, output_tensor = predict_colored_polygon(
            model, polygon, color, device
        )
        
        # Display results
        axes[i, 0].imshow(polygon, cmap='gray')
        axes[i, 0].set_title(f'Input: {shape.capitalize()}')
        axes[i, 0].axis('off')
        
        axes[i, 1].imshow(colored_result)
        axes[i, 1].set_title(f'Predicted: {color.capitalize()}')
        axes[i, 1].axis('off')
        
        # Show a different color for comparison
        alt_color = 'cyan' if color != 'cyan' else 'magenta'
        alt_result, _, _ = predict_colored_polygon(model, polygon, alt_color, device)
        axes[i, 2].imshow(alt_result)
        axes[i, 2].set_title(f'Alternative: {alt_color.capitalize()}')
        axes[i, 2].axis('off')
        
    except Exception as e:
        print(f"Error processing {shape} with {color}: {e}")
        for j in range(3):
            axes[i, j].text(0.5, 0.5, 'Error', ha='center', va='center', transform=axes[i, j].transAxes)
            axes[i, j].axis('off')

plt.tight_layout()
plt.show()

## Example 2: Test with Validation Dataset

If you have the validation dataset available, test on real data.

In [None]:
# Check if validation data exists
val_data_path = 'data/dataset/validation/data.json'

if os.path.exists(val_data_path):
    with open(val_data_path, 'r') as f:
        val_data = json.load(f)
    
    print(f"Found {len(val_data)} validation samples")
    
    # Test on first few samples
    num_samples = min(5, len(val_data))
    
    fig, axes = plt.subplots(num_samples, 4, figsize=(16, 4*num_samples))
    
    for i in range(num_samples):
        sample = val_data[i]
        
        # Load input and target images
        input_path = os.path.join('data/dataset/validation/inputs', sample['input'])
        target_path = os.path.join('data/dataset/validation/outputs', sample['output'])
        
        if os.path.exists(input_path) and os.path.exists(target_path):
            # Load images
            input_img = cv2.imread(input_path, cv2.IMREAD_GRAYSCALE)
            target_img = cv2.imread(target_path, cv2.IMREAD_COLOR)
            target_img = cv2.cvtColor(target_img, cv2.COLOR_BGR2RGB)
            
            # Predict
            try:
                pred_result, _, _ = predict_colored_polygon(
                    model, input_img, sample['color'], device
                )
                
                # Display
                axes[i, 0].imshow(input_img, cmap='gray')
                axes[i, 0].set_title('Input')
                axes[i, 0].axis('off')
                
                axes[i, 1].imshow(target_img)
                axes[i, 1].set_title(f'Target ({sample["color"]})') 
                axes[i, 1].axis('off')
                
                axes[i, 2].imshow(pred_result)
                axes[i, 2].set_title(f'Predicted ({sample["color"]})') 
                axes[i, 2].axis('off')
                
                # Difference
                target_norm = target_img.astype(np.float32) / 255.0
                diff = np.abs(target_norm - pred_result)
                axes[i, 3].imshow(diff)
                axes[i, 3].set_title('Difference')
                axes[i, 3].axis('off')
                
            except Exception as e:
                print(f"Error processing sample {i}: {e}")
        else:
            print(f"Sample {i} files not found")
    
    plt.tight_layout()
    plt.show()
    
else:
    print("Validation dataset not found. Please ensure the dataset is properly extracted.")

## Example 3: Interactive Testing

Create an interactive interface to test different color combinations.

In [None]:
# Test all colors on a single shape
test_shape = 'hexagon'
polygon = create_sample_polygon(test_shape, size=(256, 256))

# Create grid for all colors
n_colors = len(COLOR_NAMES)
cols = 5
rows = (n_colors + cols - 1) // cols

fig, axes = plt.subplots(rows, cols, figsize=(15, 3*rows))
axes = axes.flatten()

for i, color in enumerate(COLOR_NAMES):
    try:
        colored_result, _, _ = predict_colored_polygon(model, polygon, color, device)
        axes[i].imshow(colored_result)
        axes[i].set_title(f'{color.capitalize()}')
        axes[i].axis('off')
    except Exception as e:
        print(f"Error with color {color}: {e}")
        axes[i].text(0.5, 0.5, 'Error', ha='center', va='center', transform=axes[i].transAxes)
        axes[i].axis('off')

# Hide unused subplots
for i in range(n_colors, len(axes)):
    axes[i].axis('off')

plt.suptitle(f'All Colors Applied to {test_shape.capitalize()}')
plt.tight_layout()
plt.show()

## Example 4: Custom Image Upload

Test with your own polygon images.

In [None]:
# Function to test custom image
def test_custom_image(image_path, color_name):
    """Test with a custom uploaded image"""
    if not os.path.exists(image_path):
        print(f"Image not found: {image_path}")
        return
    
    try:
        # Load and preprocess
        input_img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if input_img is None:
            print(f"Could not load image: {image_path}")
            return
        
        # Resize if needed
        if input_img.shape != (256, 256):
            input_img = cv2.resize(input_img, (256, 256))
        
        # Predict
        colored_result, _, _ = predict_colored_polygon(model, input_img, color_name, device)
        
        # Visualize
        fig, axes = plt.subplots(1, 2, figsize=(10, 5))
        
        axes[0].imshow(input_img, cmap='gray')
        axes[0].set_title('Input Image')
        axes[0].axis('off')
        
        axes[1].imshow(colored_result)
        axes[1].set_title(f'Colored Result ({color_name})')
        axes[1].axis('off')
        
        plt.tight_layout()
        plt.show()
        
    except Exception as e:
        print(f"Error processing image: {e}")

# Example usage (update the path to your image)
# test_custom_image('path/to/your/polygon.png', 'blue')

print("To test your own image, use:")
print("test_custom_image('path/to/your/image.png', 'desired_color')")
print(f"Available colors: {', '.join(COLOR_NAMES)}")

## Model Analysis and Performance

Analyze model performance and failure cases.

In [None]:
# Analyze model behavior with different shapes and colors
def analyze_model_performance():
    """Analyze model performance across different shapes and colors"""
    shapes = ['triangle', 'square', 'pentagon', 'hexagon', 'octagon']
    colors = COLOR_NAMES
    
    # Test systematic combinations
    results = {}
    
    print("Testing model on all shape-color combinations...")
    
    for shape in shapes:
        results[shape] = {}
        polygon = create_sample_polygon(shape, size=(256, 256))
        
        for color in colors:
            try:
                colored_result, _, output_tensor = predict_colored_polygon(
                    model, polygon, color, device
                )
                
                # Simple quality metric: check if output has the expected color
                # This is a simplified check - in practice, you'd want more sophisticated metrics
                mean_color = colored_result.mean(axis=(0,1))
                results[shape][color] = {
                    'success': True,
                    'mean_rgb': mean_color.tolist()
                }
                
            except Exception as e:
                results[shape][color] = {
                    'success': False,
                    'error': str(e)
                }
    
    # Summary
    total_tests = len(shapes) * len(colors)
    successful_tests = sum([
        sum([results[shape][color]['success'] for color in colors])
        for shape in shapes
    ])
    
    print(f"\nPerformance Summary:")
    print(f"Total tests: {total_tests}")
    print(f"Successful: {successful_tests}")
    print(f"Success rate: {successful_tests/total_tests*100:.1f}%")
    
    return results

# Run analysis (uncomment to run)
# performance_results = analyze_model_performance()

## Failure Cases and Edge Cases

Test the model on challenging cases to understand its limitations.

In [None]:
# Test edge cases
def test_edge_cases():
    """Test model on edge cases"""
    
    # Create challenging test cases
    edge_cases = []
    
    # 1. Very small polygon
    small_polygon = create_sample_polygon('triangle', size=(256, 256), thickness=1)
    # Make it smaller
    small_polygon = cv2.resize(small_polygon, (64, 64))
    small_polygon = cv2.resize(small_polygon, (256, 256))
    edge_cases.append(('Small Triangle', small_polygon))
    
    # 2. Thick polygon
    thick_polygon = create_sample_polygon('square', size=(256, 256), thickness=20)
    edge_cases.append(('Thick Square', thick_polygon))
    
    # 3. Noisy polygon
    noisy_polygon = create_sample_polygon('pentagon', size=(256, 256))
    noise = np.random.randint(0, 50, size=noisy_polygon.shape, dtype=np.uint8)
    noisy_polygon = np.clip(noisy_polygon.astype(int) + noise, 0, 255).astype(np.uint8)
    edge_cases.append(('Noisy Pentagon', noisy_polygon))
    
    # 4. Multiple shapes (this will likely fail as model expects single polygon)
    multi_shape = np.zeros((256, 256), dtype=np.uint8)
    triangle = create_sample_polygon('triangle', size=(128, 128))
    square = create_sample_polygon('square', size=(128, 128))
    multi_shape[50:178, 50:178] = triangle
    multi_shape[128:256, 128:256] = square
    edge_cases.append(('Multiple Shapes', multi_shape))
    
    # Test each edge case
    fig, axes = plt.subplots(len(edge_cases), 3, figsize=(12, 4*len(edge_cases)))
    
    for i, (case_name, test_polygon) in enumerate(edge_cases):
        # Test with red color
        try:
            colored_result, _, _ = predict_colored_polygon(model, test_polygon, 'red', device)
            
            axes[i, 0].imshow(test_polygon, cmap='gray')
            axes[i, 0].set_title(f'Input: {case_name}')
            axes[i, 0].axis('off')
            
            axes[i, 1].imshow(colored_result)
            axes[i, 1].set_title('Prediction (Red)')
            axes[i, 1].axis('off')
            
            # Test with another color
            colored_result2, _, _ = predict_colored_polygon(model, test_polygon, 'blue', device)
            axes[i, 2].imshow(colored_result2)
            axes[i, 2].set_title('Prediction (Blue)')
            axes[i, 2].axis('off')
            
        except Exception as e:
            print(f"Error with {case_name}: {e}")
            for j in range(3):
                axes[i, j].text(0.5, 0.5, f'Error\n{case_name}', 
                               ha='center', va='center', transform=axes[i, j].transAxes)
                axes[i, j].axis('off')
    
    plt.tight_layout()
    plt.show()

# Run edge case tests
test_edge_cases()

## Conclusion

This notebook demonstrated:

1. **Model Loading**: How to load the trained UNet model
2. **Basic Inference**: Single image prediction with different colors
3. **Batch Processing**: Testing multiple images and colors
4. **Validation Testing**: Comparison with ground truth (if available)
5. **Edge Cases**: Understanding model limitations

### Key Observations:

- The model performs well on clean, well-defined polygons
- Color conditioning works effectively across different polygon shapes
- Performance may degrade on noisy inputs or complex multi-shape images
- The FiLM conditioning mechanism successfully guides color generation

### Future Improvements:

1. **Data Augmentation**: More diverse training data with various noise levels
2. **Architecture**: Experiment with attention mechanisms or deeper networks
3. **Loss Functions**: Perceptual loss or adversarial training for better quality
4. **Multi-object Support**: Extend to handle multiple polygons
5. **Color Interpolation**: Support for custom colors or color mixing