In [0]:
# %pip install -q -r requirements.txt

In [0]:
from pathlib import Path

import cv2, torch, matplotlib.pyplot as plt
import easyocr
import numpy as np
from torch.utils.data import DataLoader

from src.data_loader import SingleImageDataset, VAL_TRANSFORM
from src.unet_model import load_unet, infer_masks
from src.geo_export import export_geopackage
from src.mask_utils import hsv_red_mask

img_path = "data/stockton_1.png"
# cv2.imread signals a missing/unreadable file by returning None instead of
# raising; fail loudly here rather than with an opaque cvtColor error.
img_bgr = cv2.imread(img_path)
if img_bgr is None:
    raise FileNotFoundError(f"Could not read image: {img_path}")
# OpenCV loads images as BGR; convert to RGB (matplotlib's channel order).
img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
orig_h, orig_w = img.shape[:2]

# Wrap the single input image in a Dataset (no ground-truth mask is passed,
# mask_path=None) using the validation transform, and iterate it one sample
# at a time in file order.
dataset = SingleImageDataset(img_path, mask_path=None, transform=VAL_TRANSFORM)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the U-Net checkpoint (3 output classes, 3 input channels).
checkpoint_path = "checkpoints/unet_last.pt"
model = load_unet(checkpoint_path, classes=3, in_channels=3, device=device)

# infer_masks returns two masks (boundaries, text) at the model's working
# resolution; they still need resizing to the source image size.
boundary_mask, text_mask = infer_masks(model, dataloader, device)

# Upscale model outputs back to the original image size (orig_h, orig_w were
# taken from `img` when it was loaded). Nearest-neighbour sampling only copies
# existing pixel values, so the masks stay binary along edges.
def scale(mask):
    """Resize a 2-D mask to (orig_h, orig_w) with nearest-neighbour sampling.

    cv2.resize does not accept NumPy bool arrays, so those are cast to uint8
    first; every other dtype is passed through unchanged.
    """
    mask = np.asarray(mask)
    if mask.dtype == np.bool_:
        mask = mask.astype(np.uint8)
    # cv2.resize takes dsize as (width, height), not (rows, cols).
    return cv2.resize(mask, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)


boundary_mask = scale(boundary_mask)
text_mask = scale(text_mask)

# Optional: HSV filtering to remove false positives for boundaries.
# Keep only the boundary pixels that hsv_red_mask(img) also flags (presumably
# pixels whose HSV colour falls in a red range — see src.mask_utils). The
# result is a 0/1 uint8 mask.
red_mask = hsv_red_mask(img)
boundary_keep = np.logical_and(boundary_mask, red_mask)
boundary_mask = boundary_keep.astype(np.uint8)

# OCR on text regions: crop the bounding box of every external contour of the
# predicted text mask, read it with EasyOCR, and draw the result on a copy of
# the image for the visualization below.
reader = easyocr.Reader(['en'])
contours, _ = cv2.findContours(text_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
img_with_ocr = img.copy()
# text_labels[i] is the OCR text for contours[i] ("" when nothing was read).
# NOTE(review): export_geopackage only receives this list, so it presumably
# re-derives the text polygons from text_mask in this same contour order — confirm.
text_labels = []
for cnt in contours:
    x, y, w, h = cv2.boundingRect(cnt)
    roi = img[y:y+h, x:x+w]
    result = reader.readtext(roi)
    # readtext returns one entry per detected fragment, with the text at
    # index 1. Keep every fragment, not just the first, so multi-word labels
    # are not truncated.
    text = " ".join(fragment[1] for fragment in result).strip()
    text_labels.append(text)
    if text:
        # img is RGB, so (0, 0, 255) renders as blue here (it would be red in BGR).
        cv2.putText(img_with_ocr, text, (x, y+h//2), cv2.FONT_HERSHEY_SIMPLEX, 1.1, (0,0,255), 3)

# Export geopackage. Create the output directory first in case
# export_geopackage does not; otherwise a fresh checkout fails on a missing output/.
output_path = Path("output/segments.gpkg")
output_path.parent.mkdir(parents=True, exist_ok=True)
export_geopackage(boundary_mask, text_mask, orig_w, orig_h, str(output_path), text_labels=text_labels)

# Visualize all side-by-side: the original image, the HSV-filtered boundary
# mask in grayscale, and the image annotated with OCR text.
panels = [
    ("Original", img, {}),
    ("Predicted Boundary Mask (HSV-filtered)", boundary_mask, {"cmap": "gray"}),
    ("Text with OCR", img_with_ocr, {}),
]
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
for ax, (title, image, imshow_kwargs) in zip(axes, panels):
    ax.set_title(title)
    ax.imshow(image, **imshow_kwargs)
    ax.axis('off')
fig.tight_layout()
plt.show()