In [None]:
import sys
from pathlib import Path


def _locate_base_dir():
	"""Return the directory containing this code.

	Inside a notebook `__file__` is not defined, so fall back to the kernel's
	current working directory.
	"""
	try:
		return Path(__file__).resolve().parent
	except NameError:
		return Path.cwd()


base_dir = _locate_base_dir()

# A notebook kept in a 'notebooks' folder treats that folder's parent as the repository root.
if base_dir.name == "notebooks":
	repo_root = base_dir.parent
else:
	repo_root = base_dir

# Put the repo root first on the import path so top-level project modules imported in
# later cells (such as CNN_Transformer_Model) can be found.
sys.path.insert(0, str(repo_root))

In [None]:
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from skimage import color
from CNN_Transformer_Model import CNNTransformerColourizer

# Report whether PyTorch can see a CUDA GPU and which device that implies for inference.
_cuda_ok = torch.cuda.is_available()
print("CUDA Available:", _cuda_ok)
print("Device:", torch.device("cuda" if _cuda_ok else "cpu"))

In [None]:
# Path to the trained model checkpoint. It is resolved from the repository root computed
# in the first cell rather than a machine-specific absolute path.
MODEL_PATH = repo_root / "models" / "final.pth"
if not MODEL_PATH.is_file():
	raise FileNotFoundError(f"Model checkpoint not found: {MODEL_PATH}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNNTransformerColourizer().to(device)

# map_location remaps stored tensors onto `device`, so a GPU-saved checkpoint loads on CPU.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
checkpoint = torch.load(MODEL_PATH, map_location=device)

# Accept either a bare state_dict or a training checkpoint dict that wraps it under
# "model_state_dict".
if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
	state_dict = checkpoint["model_state_dict"]
else:
	state_dict = checkpoint
model.load_state_dict(state_dict)

# Switch to inference mode (affects layers such as dropout/batch-norm, if the model has any).
model.eval()

print("Model loaded successfully.")

In [None]:
# Path to the input image, resolved from the repository root computed in the first cell.
# Any PIL-readable mode works: grayscale, RGB, RGBA, palette, ...
IMAGE_PATH = repo_root / "tests" / "sample_9.jpg"

img = Image.open(IMAGE_PATH)

# Force single-channel 8-bit grayscale. Checking only for 'RGB' would let RGBA, palette
# ('P'), 'LA' or CMYK images through as multi-channel arrays.
if img.mode != "L":
	img = img.convert("L")

# Scale 8-bit intensities [0, 255] to exactly [-1, 1] (255 / 127.5 == 2). Dividing by 128
# stopped at 0.992 and did not match the [-1, 1] -> [0, 100] inverse used later.
# NOTE(review): grayscale intensity stands in for CIELAB L here; confirm this matches the
# preprocessing used when the model was trained.
L = np.asarray(img, dtype=np.float32)
L = L / 127.5 - 1.0

# Add a leading dimension: [H, W] -> [1, H, W]. Inference adds one more -> [1, 1, H, W].
L = torch.from_numpy(L).unsqueeze(0)

print("Input image shape:", L.shape)

In [None]:
# Get the model's a/b prediction for the prepared L tensor.
# L is [1, H, W]; add a batch dimension -> [1, 1, H, W] and move it to the model's device.
L_tensor = L.unsqueeze(0).to(device)

with torch.no_grad():
    pred_AB = model(L_tensor)

# The rest of this cell expects one image with two (a, b) channels at the input resolution.
# Fail with a clear message instead of a broadcasting error further down.
_expected_hw = tuple(L_tensor.shape[-2:])
if pred_AB.dim() != 4 or tuple(pred_AB.shape[:2]) != (1, 2) or tuple(pred_AB.shape[-2:]) != _expected_hw:
    raise ValueError(
        f"Expected model output of shape (1, 2, {_expected_hw[0]}, {_expected_hw[1]}), "
        f"got {tuple(pred_AB.shape)}"
    )

# Saturation boost applied in the model's normalised output space, then clamped to [-1, 1].
# NOTE(review): assumes the model's a/b head outputs [-1, 1] (e.g. tanh) — confirm.
SATURATION_BOOST = 1.2
pred_AB_boosted = torch.clamp(pred_AB * SATURATION_BOOST, -1.0, 1.0)

# Index batch/channel explicitly rather than squeeze(), which would also drop an H or W of 1.
L_np = L_tensor[0, 0].cpu().numpy()                                               # (H, W)
pred_AB_np = np.transpose(pred_AB[0].cpu().numpy(), (1, 2, 0))                    # (2, H, W) -> (H, W, 2)
pred_AB_boosted_np = np.transpose(pred_AB_boosted[0].cpu().numpy(), (1, 2, 0))    # (2, H, W) -> (H, W, 2)

# Convert back to CIELAB ranges: L from [-1, 1] to [0, 100].
L_lab = ((L_np + 1.0) / 2.0) * 100.0

# a/b from [-1, 1] to roughly [-128, 128] (same scale for the boosted copy), then clip to
# [-127, 127] to avoid extreme values.
AB_SCALE = 128.0
pred_AB_lab = np.clip(pred_AB_np * AB_SCALE, -127.0, 127.0)
pred_AB_boosted_lab = np.clip(pred_AB_boosted_np * AB_SCALE, -127.0, 127.0)


def _lab_to_rgb(l_plane, ab_planes):
    """Stack an L plane (H, W) with a/b planes (H, W, 2) and convert LAB -> RGB.

    Returns (lab, rgb_float, rgb_uint8): the float64 LAB image, the RGB image clipped
    to [0, 1], and that RGB image rounded to uint8 in [0, 255].
    """
    lab = np.empty((*l_plane.shape, 3), dtype=np.float64)
    lab[..., 0] = l_plane
    lab[..., 1:] = ab_planes
    rgb_float = np.clip(color.lab2rgb(lab), 0.0, 1.0)
    rgb_uint8 = (rgb_float * 255.0).round().astype(np.uint8)
    return lab, rgb_float, rgb_uint8


# Convert LAB -> RGB for both the plain and the boosted prediction.
lab_out, rgb_out_f, rgb_out = _lab_to_rgb(L_lab, pred_AB_lab)
lab_boosted, rgb_boosted_f, rgb_boosted = _lab_to_rgb(L_lab, pred_AB_boosted_lab)

print("Colorized output shape:", rgb_out.shape)

In [None]:
# Post-processing: optional extra PIL saturation on top of the LAB-space boost, then save.
from PIL import Image, ImageEnhance

# Additional saturation multiplier applied with PIL:
# 1.0 keeps the LAB-boosted colours unchanged, > 1.0 saturates further, < 1.0 desaturates.
SATURATION_FACTOR = 1.0

# Start from the LAB-boosted result (rgb_boosted_f, float RGB in [0, 1]). Building from the
# unboosted rgb_out here would silently discard that boost. Reading rgb_boosted_f rather
# than rgb_boosted also keeps this cell idempotent, because rgb_boosted is reassigned below.
img_colorized = Image.fromarray((rgb_boosted_f * 255.0).round().astype(np.uint8))
img_boosted = ImageEnhance.Color(img_colorized).enhance(SATURATION_FACTOR)

# Back to numpy for display.
rgb_boosted = np.array(img_boosted)

# Save the boosted image to the repo-level results folder.
output_path = repo_root / "results" / "colorized_boosted.jpg"
output_path.parent.mkdir(parents=True, exist_ok=True)
img_boosted.save(output_path)
print(f"Boosted saturation by {SATURATION_FACTOR} and saved to: {output_path}")

In [None]:
# Side-by-side comparison: grayscale input, model output, saturation-boosted output.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# `img` is the PIL image from the loading cell (the previous `img_np` was never defined).
# Show it with a gray colormap and a fixed 0-255 range; a 2-D array would otherwise be
# drawn with matplotlib's default false-colour colormap.
panels = [
	("Grayscale Input", np.asarray(img), {"cmap": "gray", "vmin": 0, "vmax": 255}),
	("Colourised Output", rgb_out, {}),
	("Boosted Saturation", rgb_boosted, {}),
]
for ax, (title, image, imshow_kwargs) in zip(axes, panels):
	ax.imshow(image, **imshow_kwargs)
	ax.set_title(title)
	ax.axis("off")

plt.show()

In [1]:
# Demo: launch the Gradio web app in a background process and open it in the browser.
import subprocess
import sys
import time
import urllib.error
import urllib.request
import webbrowser
from pathlib import Path

# Re-derive the repo root so this cell can run on its own in a fresh kernel.
try:
	base_dir = Path(__file__).resolve().parent
except NameError:
	base_dir = Path.cwd()
repo_root = base_dir.parent if base_dir.name == "notebooks" else base_dir
if str(repo_root) not in sys.path:
	sys.path.insert(0, str(repo_root))

# Address the app is expected to serve on (Gradio's default port).
# NOTE(review): confirm src/gradio_app.py does not launch on a different host/port.
GRADIO_URL = "http://localhost:7860"
STARTUP_TIMEOUT_S = 60  # seconds to wait for the server before giving up


def _wait_for_server(url, proc, timeout):
	"""Poll `url` until it responds, `proc` exits, or `timeout` seconds elapse.

	Returns True once the server answers (any HTTP status counts as "up").
	Returns False if the process exited or the timeout expired.
	"""
	deadline = time.monotonic() + timeout
	while time.monotonic() < deadline:
		if proc.poll() is not None:
			return False
		try:
			with urllib.request.urlopen(url, timeout=2):
				return True
		except urllib.error.HTTPError:
			# An HTTP error status still means the server is listening.
			return True
		except OSError:
			# Connection refused / timed out: not up yet.
			time.sleep(0.5)
	return False


print("Starting Gradio app...")
process = subprocess.Popen([
	sys.executable,
	str(repo_root / "src" / "gradio_app.py"),
])

# Wait until the server actually accepts requests instead of sleeping for a fixed time.
if _wait_for_server(GRADIO_URL, process, STARTUP_TIMEOUT_S):
	webbrowser.open(GRADIO_URL)
elif process.poll() is not None:
	print(f"Gradio app exited early with return code {process.returncode}.")
else:
	print(f"Gradio app did not respond at {GRADIO_URL} within {STARTUP_TIMEOUT_S} s.")

# The server keeps running in the background; call process.terminate() to stop it.

Starting Gradio app...


True