In [0]:
# %pip install -q -r requirements.txt

In [0]:
from pathlib import Path

import cv2
import easyocr
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader

from src.data_loader import SingleImageDataset, VAL_TRANSFORM
from src.unet_model import load_unet, infer_masks
from src.geo_export import export_geopackage
from src.mask_utils import hsv_red_mask, postprocess_boundary_mask

# --- Load image and model ---
img_path = "data/stockton_1.png"
img = cv2.imread(img_path)
if img is None:
    # cv2.imread returns None (instead of raising) when the file is missing or unreadable;
    # fail here with a clear message rather than deep inside cvtColor.
    raise FileNotFoundError(f"Could not read image: {img_path}")
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
orig_h, orig_w = img.shape[:2]

dataset = SingleImageDataset(img_path, mask_path=None, transform=VAL_TRANSFORM)
dataloader = DataLoader(dataset, batch_size=1)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = load_unet("checkpoints/unet_last.pt", classes=3, in_channels=3, device=device)

# --- Tunables for post-processing and OCR annotation ---
MIN_BOUNDARY_BLOB_SIZE = 100      # passed as min_size to postprocess_boundary_mask
OCR_BOX_COLOR = (0, 0, 255)       # img_rgb is RGB, so this draws pure blue
OCR_FONT = cv2.FONT_HERSHEY_SIMPLEX
OCR_FONT_SCALE = 1.1
OCR_LINE_THICKNESS = 2


def resize_mask(mask, width, height):
    """Resize a mask to (width, height) with nearest-neighbour interpolation.

    Nearest-neighbour keeps label values intact (no blending at region edges).
    ``interpolation`` must be passed by keyword: the third positional parameter of
    ``cv2.resize`` is ``dst``, so a positionally passed flag is not used as the
    interpolation mode.
    """
    return cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)


def ocr_text_regions(image_rgb, text_mask, reader):
    """Read the text inside each connected region of ``text_mask``.

    For every external contour of ``text_mask``, the bounding box is cropped from
    ``image_rgb`` and passed to ``reader.readtext``. All strings it detects in the crop
    are joined with single spaces; the result is an empty string if nothing was read.

    Args:
        image_rgb: RGB image the mask was computed for; it is not modified.
        text_mask: uint8 binary mask of text pixels with the same height/width as
            ``image_rgb``.
        reader: an ``easyocr.Reader``.

    Returns:
        ``(annotated, results)``. ``annotated`` is a copy of ``image_rgb`` with each box
        and its text drawn on it. ``results`` is a list of ``(x, y, w, h, text)`` tuples,
        in contour order.
    """
    annotated = image_rgb.copy()
    # OpenCV 3 returns (image, contours, hierarchy); OpenCV 4 returns (contours, hierarchy).
    contours = cv2.findContours(text_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    results = []
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        roi = image_rgb[y:y + h, x:x + w]
        # Each readtext detection carries its text at index 1; keep every detection,
        # not only the first one.
        text = " ".join(detection[1] for detection in reader.readtext(roi))

        cv2.rectangle(annotated, (x, y), (x + w, y + h), OCR_BOX_COLOR, OCR_LINE_THICKNESS)
        (_, text_h), _ = cv2.getTextSize(text, OCR_FONT, OCR_FONT_SCALE, OCR_LINE_THICKNESS)
        # Draw the label just above the box, but never with its top cut off by the image edge.
        label_y = max(y - 5, text_h)
        cv2.putText(annotated, text, (x, label_y), OCR_FONT, OCR_FONT_SCALE,
                    OCR_BOX_COLOR, OCR_LINE_THICKNESS)
        results.append((x, y, w, h, text))
    return annotated, results


# --- Model predictions (raw, not post-processed) ---
boundary_512, text_512 = infer_masks(model, dataloader, device)
# Resize back to the original image size.
boundary_pred = resize_mask(boundary_512, orig_w, orig_h)   # model's boundary prediction (0/1 mask)
text_pred = resize_mask(text_512, orig_w, orig_h)           # model's text prediction (0/1 mask)

# --- HSV mask (for red) ---
red_hsv_mask = hsv_red_mask(img_rgb)    # binary mask where HSV says "red"

# --- Filtered masks: model AND HSV ---
boundary_masked = np.logical_and(boundary_pred, red_hsv_mask).astype(np.uint8)
boundary_final = postprocess_boundary_mask(boundary_masked, min_size=MIN_BOUNDARY_BLOB_SIZE)  # remove small blobs
text_final = np.logical_and(text_pred, red_hsv_mask).astype(np.uint8)

# --- OCR only where model+HSV say it's text ---
reader = easyocr.Reader(['en'], gpu=(device == 'cuda'))
img_with_ocr, ocr_results = ocr_text_regions(img_rgb, text_final, reader)

In [0]:
# --- EXPORT GEO DATA ---
# Save the filtered boundary and text masks as polygons, with the OCR label of each text
# region, to a GeoPackage.
# NOTE(review): labels are matched to text regions purely by position in `ocr_results`,
# which follows cv2.findContours order on `text_final`; confirm export_geopackage
# enumerates the text regions in that same order.
gpkg_path = Path("output") / "segments.gpkg"
# Create the output folder up front; the writer presumably does not create missing
# parent directories itself — TODO confirm.
gpkg_path.parent.mkdir(parents=True, exist_ok=True)
export_geopackage(
    boundary_final,         # binary mask for boundary (filtered)
    text_final,             # binary mask for text (filtered)
    orig_w, orig_h,
    str(gpkg_path),
    text_labels=[label for *_, label in ocr_results],   # OCR text for each text region
)
print(f"GeoPackage saved: {gpkg_path}")

In [0]:
# --- Visualise all intermediate results ---
# Each panel is (title, image, colormap); masks use grayscale, RGB images use none.
panels = [
    ("Original Image", img_rgb, None),
    ("Boundary: Model Prediction", boundary_pred, "gray"),
    ("HSV Red Mask", red_hsv_mask, "gray"),
    ("Boundary: Model AND HSV", boundary_final, "gray"),
    ("Text: Model Prediction", text_pred, "gray"),
    ("Text: Model AND HSV + OCR", img_with_ocr, None),
]

fig, axes = plt.subplots(2, 3, figsize=(18, 12))
for ax, (title, image, cmap) in zip(axes.flat, panels):
    ax.imshow(image, cmap=cmap)
    ax.set_title(title)
    ax.axis("off")

fig.tight_layout()
plt.show()