# Inference Notebook: UNet Polygon Coloring
This notebook demonstrates how to load a trained UNet model and run inference on a grayscale polygon image. The model is conditioned on a chosen color name and produces a colored output image.

In [None]:
import torch
from model import create_model
from dataset import preprocess_image
from PIL import Image
import matplotlib.pyplot as plt

In [None]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Rebuild the architecture before restoring its weights.
# NOTE(review): these hyper-parameters presumably have to match the ones used
# at training time; num_colors should equal the number of entries in the
# color-name mapping defined below — confirm.
model = create_model(
    n_channels=1,
    n_classes=3,
    num_colors=10,
)
model = model.to(device)

# map_location remaps any tensors saved on a different device onto `device`.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load('checkpoints/best_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

# Put the network in evaluation mode for inference.
model.eval()

In [None]:
# Color name -> integer index passed to the model as its color input.
# Fixed: this was previously bound to the misspelled name `OLOR_MAP`, which
# made every lookup of `COLOR_MAP` in run_inference raise NameError.
# NOTE(review): the index order presumably must match the mapping used during
# training, and the number of entries must equal `num_colors` given to
# create_model — confirm.
COLOR_MAP = {
    'red': 0,
    'green': 1,
    'blue': 2,
    'yellow': 3,
    'purple': 4,
    'cyan': 5,
    'orange': 6,
    'black': 7,
    'white': 8,
    'pink': 9,
}

def run_inference(image_path, color_name):
    """Color a grayscale polygon image with the trained model and plot the result.

    Displays a two-panel figure: the input image (converted to grayscale) on
    the left and the model's output for ``color_name`` on the right.

    Args:
        image_path: Path to the input image. It is read once by
            ``preprocess_image`` to build the model input and once by PIL for
            display.
        color_name: A key of ``COLOR_MAP`` selecting the target color.

    Raises:
        ValueError: If ``color_name`` is not a key of ``COLOR_MAP``.
    """
    if color_name not in COLOR_MAP:
        raise ValueError(f"Unknown color name '{color_name}'. Valid options: {list(COLOR_MAP.keys())}")

    # NOTE(review): assumes preprocess_image returns an unbatched (C, H, W)
    # tensor; the batch dimension of size 1 is added here — confirm.
    input_tensor = preprocess_image(image_path).unsqueeze(0).to(device)
    color_index = torch.tensor([COLOR_MAP[color_name]], dtype=torch.long, device=device)

    # Only the forward pass needs gradient tracking disabled.
    with torch.no_grad():
        output = model(input_tensor, color_index)

    # Drop the batch dim and reorder (C, H, W) -> (H, W, C) for imshow.
    prediction = output.squeeze(0).cpu().numpy().transpose(1, 2, 0)
    # NOTE(review): scaling by 255 assumes the model emits values in [0, 1]
    # (e.g. a final sigmoid); anything outside that range is clipped.
    prediction = (prediction * 255).clip(0, 255).astype('uint8')

    # Close the source file deterministically instead of relying on garbage
    # collection (Pillow keeps the handle open for multi-frame formats).
    with Image.open(image_path) as source_image:
        display_image = source_image.convert('L')

    _, axs = plt.subplots(1, 2, figsize=(8, 4))
    axs[0].imshow(display_image, cmap='gray')
    axs[0].set_title('Input Polygon')
    axs[0].axis('off')
    axs[1].imshow(prediction)
    axs[1].set_title(f'Output: {color_name}')
    axs[1].axis('off')
    plt.show()

In [None]:
# Example call: replace the placeholder path with a real grayscale polygon image.
run_inference(image_path='path/to/grayscale_polygon.png', color_name='red')
