In [None]:
import sys
from pathlib import Path
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt
from ipywidgets import interact, FloatSlider, Dropdown
import warnings
# Silence every warning category for cleaner notebook output.
# Note that this also hides deprecation notices from the libraries used below.
warnings.filterwarnings('ignore')

# Add parent directory to path so the `src` package can be imported.
# Guarded so that re-running this cell does not keep appending duplicates.
_project_root = str(Path.cwd().parent)
if _project_root not in sys.path:
    sys.path.append(_project_root)

from src.config import Config
from src.utils.inference import ColorizationInference
from src.utils.video import colorize_video
from src.training.metrics import PerceptualMetrics


In [None]:
# Read the experiment configuration from the default YAML file
cfg = Config("../configs/default.yaml")

# Checkpoint file handed to the inference manager as `model_path`
checkpoint_path = Path("../artifacts/food101_color/checkpoints/best_val_loss.pt")

# Report what is about to be loaded and on which device
print(f"Loading model from {checkpoint_path}")
print(f"Using device: {cfg.device}")

# Gather the constructor arguments first, then build the inference manager
inference_kwargs = dict(
    model_path=checkpoint_path,
    centers_path=cfg.centers_path,
    config=cfg.config,
    device=cfg.device,
)
inference_mgr = ColorizationInference(**inference_kwargs)

print("✓ Inference manager ready!")

In [None]:
# Choose a test image
test_images_dir = Path("../data/Food-101/images")
test_img = test_images_dir / "apple_pie" / "1005649.jpg"  # Change this path as needed

# Load original image.
# `Image.open` is lazy and owns a file handle; the context manager releases it
# deterministically whatever the image format, while `convert` returns a
# separate in-memory copy that stays valid after the file is closed.
with Image.open(test_img) as source_image:
    original = source_image.convert('RGB')

@interact(temperature=FloatSlider(min=0.1, max=1.5, step=0.05, value=0.42, description='Temperature'))
def colorize_interactive(temperature):
    """Colorize ``test_img`` at the given temperature and plot four panels.

    Calls ``inference_mgr.colorize_image`` with ``return_entropy=True`` and
    draws, left to right: the ``'L'`` output in grayscale, the ``'rgb'``
    output, the module-level ``original`` image, and the ``'entropy'`` output
    with the ``hot`` colormap.

    Args:
        temperature: Value forwarded to ``colorize_image``; supplied by the
            slider attached through ``interact``.
    """
    outputs = inference_mgr.colorize_image(test_img, temperature=temperature, return_entropy=True)

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))

    # One entry per panel: (image, title, extra keyword arguments for imshow)
    panels = (
        (outputs['L'], 'Input (Grayscale)', {'cmap': 'gray'}),
        (outputs['rgb'], f'Colorized (T={temperature:.2f})', {}),
        (original, 'Ground Truth', {}),
        (outputs['entropy'], 'Prediction Uncertainty', {'cmap': 'hot'}),
    )
    for ax, (image, title, imshow_kwargs) in zip(axes, panels):
        ax.imshow(image, **imshow_kwargs)
        ax.set_title(title, fontsize=14)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Select multiple test images from different categories
test_samples = [
    test_images_dir / "apple_pie" / "1005649.jpg",
    test_images_dir / "pizza" / "1008104.jpg",
    test_images_dir / "sushi" / "1021019.jpg",
    test_images_dir / "hamburger" / "1015969.jpg",
    test_images_dir / "ice_cream" / "1020775.jpg",
    test_images_dir / "chocolate_cake" / "1007307.jpg",
]
# Keep only the files that exist on disk (at most six)
test_samples = [p for p in test_samples if p.exists()][:6]
if not test_samples:
    # Fail early with an actionable message instead of an opaque matplotlib
    # error from `plt.subplots(0, 3)` further down.
    raise FileNotFoundError(
        f"None of the sample images were found under {test_images_dir}; "
        "edit `test_samples` to point at existing files."
    )

# Colorize all images
results = []
for img_path in test_samples:
    result = inference_mgr.colorize_image(img_path, temperature=0.42)
    # Close the file handle deterministically; `convert` returns an in-memory copy.
    with Image.open(img_path) as gt_file:
        gt_image = gt_file.convert('RGB')
    results.append({
        'path': img_path,
        'L': result['L'],
        'rgb': result['rgb'],
        'gt': gt_image,
    })

# Create comparison grid: one row per image, columns = input / prediction / ground truth.
# `squeeze=False` keeps `axes` two-dimensional even when there is a single row.
n_images = len(results)
fig, axes = plt.subplots(n_images, 3, figsize=(12, 4 * n_images), squeeze=False)

# (result key, column title, extra keyword arguments for imshow)
grid_columns = [
    ('L', 'Input', {'cmap': 'gray'}),
    ('rgb', 'Predicted', {}),
    ('gt', 'Ground Truth', {}),
]
for row, result in enumerate(results):
    for col, (key, title, imshow_kwargs) in enumerate(grid_columns):
        ax = axes[row, col]
        ax.imshow(result[key], **imshow_kwargs)
        # Column titles are shown on the first row only
        ax.set_title(title if row == 0 else '', fontsize=12)
        ax.axis('off')
plt.tight_layout()
plt.show()

In [None]:

# Compute metrics for test images
import pandas as pd  # local import keeps this cell runnable on its own

if not results:
    # Without this guard an empty `results` only fails later with a KeyError
    # when the metric columns are selected from an empty DataFrame.
    raise ValueError("`results` is empty: colorize at least one test image before computing metrics.")

metrics_computer = PerceptualMetrics(
    device=cfg.device,
    use_lpips=True,
    use_ssim=True,
    use_psnr=True
)
print("Computing metrics for test images...\n")
all_metrics = []
for result in results:
    # Convert to batched channel-first tensors: permute moves the last axis
    # first, unsqueeze adds a leading batch axis.
    # NOTE(review): assumes result['rgb'] is an HxWx3 array with the same
    # resolution and value range as the scaled ground truth — confirm against
    # ColorizationInference.
    pred_rgb = torch.from_numpy(result['rgb']).permute(2, 0, 1).unsqueeze(0).to(cfg.device)
    gt_rgb = torch.from_numpy(np.array(result['gt']) / 255.0).permute(2, 0, 1).unsqueeze(0)
    # `np.array(...) / 255.0` is always float64. Align it with the prediction's
    # floating dtype on the CPU, before the device transfer, so both tensors
    # share a dtype and no float64 tensor is sent to a device that lacks it.
    if pred_rgb.is_floating_point():
        gt_rgb = gt_rgb.to(pred_rgb.dtype)
    gt_rgb = gt_rgb.to(cfg.device)
    # Compute metrics
    metrics = metrics_computer.compute_all(pred_rgb, gt_rgb)
    # Add a "<category>/<filename>" label for the table
    metrics['image'] = result['path'].parent.name + '/' + result['path'].name
    all_metrics.append(metrics)

# Display as table, with the image label first
df = pd.DataFrame(all_metrics)[['image', 'lpips', 'ssim', 'psnr']]
print(df.to_string(index=False))
print("\nAverage Metrics:")
print(f"  LPIPS: {df['lpips'].mean():.4f}")
print(f"  SSIM:  {df['ssim'].mean():.4f}")
print(f"  PSNR:  {df['psnr'].mean():.2f} dB")


In [None]:
# Compare different temperatures side-by-side
temperatures = [0.1, 0.3, 0.42, 0.7, 1.0, 1.5]
if not test_samples:
    # Clearer than the bare IndexError that `test_samples[0]` would raise
    raise FileNotFoundError("No test images available: `test_samples` is empty.")
sample_img = test_samples[0]

# Derive the grid from the number of temperatures rather than hard-coding 2x3,
# so editing `temperatures` does not break the plot.
n_cols = 3
n_rows = (len(temperatures) + n_cols - 1) // n_cols  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
axes = axes.flatten()
for ax, temp in zip(axes, temperatures):
    result = inference_mgr.colorize_image(sample_img, temperature=temp)
    ax.imshow(result['rgb'])
    ax.set_title(f'T = {temp:.2f}', fontsize=14)
# Hide the axis frame on every panel, including any unused trailing ones
for ax in axes:
    ax.axis('off')
plt.suptitle('Effect of Temperature on Colorization\n(Lower = more saturated, Higher = more conservative)',
             fontsize=16, y=0.98)
plt.tight_layout()
plt.show()

In [None]:
# Video colorization example
# Place a test video in ../data/test_video.mp4

video_input = Path("../data/test_video.mp4")
video_output = Path("../outputs/colorized_video.mp4")

if video_input.exists():
    # Create the output directory only when there is something to write.
    # `parents=True` keeps this working if `video_output` is moved to a nested path.
    video_output.parent.mkdir(parents=True, exist_ok=True)

    print(f"Colorizing video: {video_input}")
    print("This may take a few minutes depending on video length...")

    colorize_video(
        video_path=video_input,
        output_path=video_output,
        inference_manager=inference_mgr,
        temperature=0.42,
        max_frames=300,  # Limit frames for demo (remove for full video)
        create_gif=True,  # Also create animated GIF
        keep_frames=False
    )

    print(f"✓ Video saved to: {video_output}")
    # NOTE(review): assumes colorize_video writes the GIF next to the video
    # with a .gif suffix — confirm against src.utils.video.
    print(f"✓ GIF saved to: {video_output.with_suffix('.gif')}")
else:
    print(f"Video not found at {video_input}")
    print("To test video colorization, place a video file at the path above.")