In [11]:
# Environment setup: silence warnings, normalise the working directory, make the
# project package importable, and report which compute device will be used.
import os
import sys
import warnings

warnings.filterwarnings('ignore')

# A kernel launched from notebooks/ steps up one level so that the relative
# paths used later (src/, checkpoints/, data/) resolve from the project root.
if os.path.basename(os.getcwd()) == 'notebooks':
    os.chdir('..')

# Assumes the kernel now sits at the project root; expose it for `import src...`.
sys.path.append(os.path.abspath("."))

import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"✅ Using device: {device}")
print(f"✅ Current directory: {os.getcwd()}")

# Verify the project inference module is importable before the later cells use it.
try:
    from src.inference import ColorizeInference, find_latest_checkpoint
except ImportError as import_err:
    print(f"❌ Import error: {import_err}")
    print("Please make sure you have the corrected src/inference.py file")
else:
    print("✅ Inference modules imported successfully")

✅ Using device: cpu
✅ Current directory: D:\colorization_task3\conditional-colorization-project
✅ Inference modules imported successfully


In [15]:
# Find the latest checkpoint and load the colorization model.
# After this cell `colorizer` is a ColorizeInference instance on success and
# None on any failure; later cells check `colorizer is not None` before use.
import traceback

CHECKPOINT_DIR = 'checkpoints'
checkpoint_path = None
colorizer = None

# Announce the search *before* running it so the log reads in order.
print(f"🔍 Searching for checkpoints in: {os.path.abspath(CHECKPOINT_DIR)}")

try:
    # NOTE(review): find_latest_checkpoint() is called without arguments, so
    # CHECKPOINT_DIR is assumed to match its default search directory — confirm
    # in src/inference.py. A falsy return does not necessarily mean the
    # directory is empty: it has been observed logging an internal error and
    # returning nothing while ckpt_epoch_*.pth files were present.
    checkpoint_path = find_latest_checkpoint()

    if checkpoint_path:
        print(f"✅ Found checkpoint: {checkpoint_path}")

        try:
            colorizer = ColorizeInference(checkpoint_path, device)
            print("✅ Model loaded successfully!")
            print(f"✅ Model image size: {colorizer.size}")
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            traceback.print_exc()
            # Reset in case construction succeeded but a later line failed.
            colorizer = None
    else:
        print("❌ No checkpoint found.")
        # List what is actually on disk to help diagnose the miss.
        if os.path.exists(CHECKPOINT_DIR):
            files = os.listdir(CHECKPOINT_DIR)
            print(f"Files in checkpoints directory: {files}")
        else:
            print("Checkpoints directory doesn't exist. Please train the model first.")

except Exception as e:
    print(f"❌ Error finding checkpoint: {e}")
    traceback.print_exc()
    colorizer = None

ERROR:src.inference:Error finding latest checkpoint: int() argument must be a string, a bytes-like object or a real number, not 'list'


🔍 Searching for checkpoints in: D:\colorization_task3\conditional-colorization-project\checkpoints
❌ No checkpoint found.
Files in checkpoints directory: ['1014.jpg', 'best_model.pth', 'ckpt_epoch_1.pth', 'ckpt_epoch_2.pth', 'ckpt_epoch_3.pth', 'samples_epoch_1.png']


In [16]:
# Smoke-test automatic colorization (no mask, no color hint) on one sample image
# and display it next to the input. Defines `test_image_path` for the next cell.
import traceback

# globals().get avoids a NameError when the model-loading cell was not run.
if globals().get('colorizer') is not None:
    # Take the alphabetically first image from the first data directory that has
    # any; sorting makes the choice reproducible (os.listdir order is arbitrary).
    test_dirs = ['data/val', 'data/train']
    image_exts = ('.jpg', '.jpeg', '.png')
    test_image_path = None

    for test_dir in test_dirs:
        if os.path.exists(test_dir):
            images = sorted(f for f in os.listdir(test_dir)
                            if f.lower().endswith(image_exts))
            if images:
                test_image_path = os.path.join(test_dir, images[0])
                break

    if test_image_path:
        print(f"🖼️  Testing with image: {test_image_path}")

        try:
            output_path = "notebooks/test_colorized.png"
            os.makedirs(os.path.dirname(output_path), exist_ok=True)

            result_path = colorizer.colorize(
                image_path=test_image_path,
                output_path=output_path,
                mask_path=None,  # No mask for automatic colorization
                color_hex=None
            )

            # plt.subplots(1, 2) returns an array of two Axes, so each panel is
            # addressed individually (the array itself has no .imshow).
            fig, axes = plt.subplots(1, 2, figsize=(12, 6))
            panels = [
                ('Original Image', test_image_path),
                ('Colorized Image', result_path),
            ]
            for ax, (title, path) in zip(axes, panels):
                # convert('RGB') so a grayscale input renders in gray instead of
                # through matplotlib's default colormap; `with` closes the file.
                with Image.open(path) as img:
                    ax.imshow(img.convert('RGB'))
                ax.set_title(title)
                ax.axis('off')

            plt.tight_layout()
            plt.show()

            print(f"✅ Inference successful! Result saved to: {result_path}")

        except Exception as e:
            print(f"❌ Inference failed: {e}")
            traceback.print_exc()
    else:
        print("❌ No test images found in data directories")
        # List what is in the data directories to help diagnose.
        for test_dir in test_dirs:
            if os.path.exists(test_dir):
                files = os.listdir(test_dir)
                print(f"Files in {test_dir}: {files}")
else:
    print("❌ Model not available for testing")

❌ Model not available for testing


In [4]:
# Colorize the test image with a circular hint mask in several colors and show
# the results in a grid: the original first, then one panel per hint color.
import traceback
from PIL import ImageDraw

# globals().get avoids a NameError when an earlier cell was skipped.
if globals().get('colorizer') is not None and globals().get('test_image_path'):
    print("🎨 Testing with color hints...")

    try:
        # NOTE(review): assumes colorizer.size is a single int edge length and
        # that colorize() aligns this square mask with the (resized) input image
        # — confirm in src/inference.py.
        mask_w = mask_h = colorizer.size
        mask = Image.new('L', (mask_w, mask_h), 0)  # Black background

        # White filled circle in the centre marks where the color hint applies.
        draw = ImageDraw.Draw(mask)
        center_x, center_y = mask_w // 2, mask_h // 2
        radius = min(mask_w, mask_h) // 4
        draw.ellipse([
            center_x - radius, center_y - radius,
            center_x + radius, center_y + radius
        ], fill=255)

        # Save mask (create the directory: the previous cell only does so on success).
        mask_path = "notebooks/test_mask.png"
        os.makedirs(os.path.dirname(mask_path), exist_ok=True)
        mask.save(mask_path)
        print(f"✅ Mask created and saved to: {mask_path}")

        # Test different colors
        colors = ["#ff6b6b", "#4ecdc4", "#45b7d1", "#96ceb4", "#feca57"]

        # Size the grid from the palette: one panel for the original plus one per
        # color, 3 per row (5 colors -> 2x3 at 15x10 inches). squeeze=False keeps
        # a 2-D array so flatten() works even for a single row.
        ncols = 3
        nrows = -(-(len(colors) + 1) // ncols)  # ceiling division
        fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 5 * nrows),
                                 squeeze=False)
        axes = axes.flatten()

        # Panel 0: the original (RGB so a grayscale input renders in gray).
        with Image.open(test_image_path) as original:
            axes[0].imshow(original.convert('RGB'))
        axes[0].set_title('Original')
        axes[0].axis('off')

        success_count = 0
        for i, color in enumerate(colors):
            ax = axes[i + 1]
            try:
                output_path = f"notebooks/test_colored_{i}.png"

                # Run inference with color hint
                colorizer.colorize(
                    image_path=test_image_path,
                    output_path=output_path,
                    mask_path=mask_path,
                    color_hex=color
                )

                with Image.open(output_path) as colorized:
                    ax.imshow(colorized.convert('RGB'))
                ax.set_title(f'Color: {color}')
                ax.axis('off')
                success_count += 1

            except Exception as e:
                # One failing color should not abort the rest of the grid.
                print(f"❌ Error with color {color}: {e}")
                ax.text(0.5, 0.5, f'Error\n{color}',
                        ha='center', va='center', transform=ax.transAxes)
                ax.axis('off')

        # Blank out any grid cells left unused when the palette does not fill the last row.
        for ax in axes[len(colors) + 1:]:
            ax.axis('off')

        plt.tight_layout()
        plt.show()

        print(f"✅ Color hint testing complete! {success_count}/{len(colors)} colors succeeded.")

    except Exception as e:
        print(f"❌ Error in color hint testing: {e}")
        traceback.print_exc()
else:
    print("❌ Model or test image not available for color hint testing")