In [0]:
# %pip install -q -r requirements.txt

In [0]:
import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import easyocr

from src.data_loader import SingleImageDataset, VAL_TRANSFORM
from src.unet_model import load_unet, infer_masks
from src.geo_export import export_geopackage
from src.mask_utils import hsv_red_mask, postprocess_boundary_mask, extract_red_text_regions

# Load the source image. cv2.imread yields BGR channel order, so convert to
# RGB once here; the HSV filter, OCR and plots further down all use img_rgb.
img_path = "data/stockton_1.png"
img = cv2.imread(img_path)
if img is None:
    # cv2.imread reports a missing/unreadable file by returning None rather
    # than raising; fail here with a clear message instead of inside cvtColor.
    raise FileNotFoundError(f"Could not read image: {img_path}")
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# Original height/width, used later to scale the predicted masks back up.
orig_h, orig_w = img.shape[:2]

# Wrap the single image in a dataset/loader pair so it can be fed to the UNet
# (no ground-truth mask, validation-time transform).
dataset = SingleImageDataset(img_path, mask_path=None, transform=VAL_TRANSFORM)
dataloader = DataLoader(dataset, batch_size=1)

# Pick the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Restore the network from the last checkpoint; classes / in_channels should
# match the settings used at training time.
model = load_unet(
    "checkpoints/unet_last.pt",
    classes=3,
    in_channels=3,
    device=device,
)

# UNet prediction: infer_masks yields two masks (boundary, text) at the
# model's output resolution.
boundary_512, text_512 = infer_masks(model, dataloader, device)


def scale(m):
    """Resize mask ``m`` to the original image size with nearest-neighbour.

    ``interpolation`` has to be passed by keyword: the third positional
    parameter of ``cv2.resize`` is ``dst``, so a positional
    ``cv2.INTER_NEAREST`` is not taken as the interpolation mode.
    Nearest-neighbour avoids blending mask values along region edges.
    """
    return cv2.resize(m, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)


# Bring both predictions back to the original (orig_w x orig_h) size.
boundary_mask = scale(boundary_512)
text_mask = scale(text_512)

# Boundary clean-up: a pixel survives only if the UNet flagged it AND the HSV
# test marks it as red in the source image; the combined 0/1 mask is then
# passed through postprocess_boundary_mask (min_size=100) to remove noise.
red_hsv_mask = hsv_red_mask(img_rgb)
boundary_mask = postprocess_boundary_mask(
    np.logical_and(boundary_mask, red_hsv_mask).astype(np.uint8),
    min_size=100,
)

# OCR ON STRICTLY RED TEXT REGIONS ONLY
# NOTE(review): this rebinds text_mask, so the UNet text prediction computed
# earlier is replaced by the mask from extract_red_text_regions — confirm
# that this is intended.
text_mask, red_text_bboxes = extract_red_text_regions(img_rgb, min_area=30)

reader = easyocr.Reader(['en'], gpu=(device == 'cuda'))
img_with_ocr = img_rgb.copy()  # annotate a copy; img_rgb itself stays clean
# One (x, y, w, h, text) tuple per box, in bbox order; text is "" when OCR
# returns nothing for that box.
ocr_results = []

# Annotation style. img_with_ocr is RGB, so (0, 0, 255) renders as blue.
LABEL_COLOR = (0, 0, 255)
LABEL_FONT = cv2.FONT_HERSHEY_SIMPLEX
LABEL_SCALE = 1.1
LABEL_THICKNESS = 2

for (x, y, w, h) in red_text_bboxes:
    roi = img_rgb[y:y+h, x:x+w]
    result = reader.readtext(roi)
    # Only the first detection's string is kept (entries are indexed as
    # result[i][1] for the recognised text).
    text = result[0][1] if result else ""
    if result:
        # Draw bounding box and text for visualization
        cv2.rectangle(img_with_ocr, (x, y), (x+w, y+h), LABEL_COLOR, LABEL_THICKNESS)
        # Place the label just above the box, but never let its baseline rise
        # so far that the glyphs fall off the top edge of the image.
        (_, label_h), _ = cv2.getTextSize(text, LABEL_FONT, LABEL_SCALE, LABEL_THICKNESS)
        label_y = max(y - 5, label_h)
        cv2.putText(img_with_ocr, text, (x, label_y), LABEL_FONT, LABEL_SCALE, LABEL_COLOR, LABEL_THICKNESS)
    ocr_results.append((x, y, w, h, text))

# EXPORT GEO DATA
# Hand both masks, the original image size and one OCR string per red-text
# box (the last field of every ocr_results tuple) to the exporter.
export_geopackage(
    boundary_mask,
    text_mask,
    orig_w,
    orig_h,
    "output/segments.gpkg",
    text_labels=[entry[-1] for entry in ocr_results],
)

In [0]:
# --- VISUALISATION: original | boundary (mask) | text with OCR ---
# Three side-by-side panels, drawn through the explicit Figure/Axes interface.
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# (title, image, extra imshow keyword arguments) for each panel, left to right.
panels = (
    ("Original", img_rgb, {}),
    ("Predicted Boundary Mask (HSV-filtered)", boundary_mask, {"cmap": "gray"}),
    ("Text with OCR (Red Only)", img_with_ocr, {}),
)
for ax, (title, image, imshow_kwargs) in zip(axes, panels):
    ax.set_title(title)
    ax.imshow(image, **imshow_kwargs)
    ax.axis('off')

fig.tight_layout()
plt.show()