# Substation Viewer

Interactive widget to select an **OSM id** and display:
- Satellite image + Transformer overlay
- Satellite image + Component detections
- Transformer capacity + substation score

Steps to run:
1. Activate .venv as the notebook kernel
2. Click on "Run All"
3. Use the dropdown to select an **OSM id**, or type one into the text box, then click **Show**

## Setup

In [1]:
# --- Imports ---
import os, re, glob, json
from pathlib import Path

import numpy as np
import pandas as pd

from PIL import Image, ImageDraw
import matplotlib.pyplot as plt

import ipywidgets as widgets
from IPython.display import display, clear_output


In [2]:
# --- Config (EDIT THESE PATHS) ---
# Everything below is resolved relative to the current working directory.
ROOT = Path('.').resolve()

# Snapshot images (looked up by OSM id in find_snapshot_path).
SNAPSHOTS_DIR = ROOT.joinpath('data', 'snapshots')

# UNet results: overlay images and the detections table.
UNET_OUT_DIR = ROOT.joinpath('output', 'unet_results')
UNET_OVERLAYS = UNET_OUT_DIR.joinpath('overlays')
UNET_CSV = UNET_OUT_DIR.joinpath('transformer_detections.csv')

# Capacity results: per-transformer rows and the per-substation summary.
CAPACITY_OUT_DIR = ROOT.joinpath('output', 'capacity_results')
CAPACITY_COMPONENT_CSV = CAPACITY_OUT_DIR.joinpath('transformers_with_capacity.csv')
CAPACITY_SUMMARY_CSV = CAPACITY_OUT_DIR.joinpath('substations_capacity_summary.csv')

# YOLO results: detections table and pre-rendered annotated images.
YOLO_OUT_DIR = ROOT.joinpath('output', 'yolo_results')
YOLO_CSV = YOLO_OUT_DIR.joinpath('yolo_detections.csv')
YOLO_RENDERED_DIR = YOLO_OUT_DIR.joinpath('annotated')

# Score table and input manifest.
SCORE_CSV = ROOT.joinpath('output', 'score_results', 'substations_scored.csv')
MANIFEST_CSV = ROOT.joinpath('data', 'manifests', 'substations_manifest.csv')

print('ROOT:', ROOT)


ROOT: /Users/seif.daknou/Documents/full_pipeline


In [3]:
# --- Helpers ---
def parse_osm_id_digits(x):
    """Extract an integer OSM id from a number, an id string, or a filename.

    The value is stringified and the LAST run of decimal digits is returned,
    so ``'way/123'``, ``'123.png'`` and ``'123_overlay.png'`` all yield 123.

    Args:
        x: Any value; typically an int, a float read from a CSV column, or a
            string such as an id, a join key, or an image filename.

    Returns:
        The parsed id as ``int``, or ``None`` when ``x`` is ``None``, blank,
        NaN/inf, or contains no digits.
    """
    if x is None:
        return None

    # pandas promotes an integer column that contains missing values to
    # float, so an id arrives as 12345.0. Stringifying that gives '12345.0',
    # whose last digit run is '0' -- collapse integral floats to int first.
    if isinstance(x, (float, np.floating)):
        if not np.isfinite(x):
            return None
        if float(x).is_integer():
            x = int(x)

    s = str(x).strip()
    if not s:
        return None
    m = re.findall(r'\d+', s)
    if not m:
        return None
    try:
        return int(m[-1])
    except ValueError:
        # int() can refuse extremely long digit strings on newer Pythons.
        return None

def read_csv_maybe(path: Path, comment='#'):
    """Best-effort CSV loader.

    Args:
        path: Location of the CSV file; a falsy value is treated as "no file".
        comment: Comment marker forwarded to ``pandas.read_csv``.

    Returns:
        The loaded DataFrame, or ``None`` when the path is falsy, the file
        does not exist, or reading fails (a warning is printed in that case).
    """
    if not path:
        return None
    if not Path(path).exists():
        return None
    try:
        frame = pd.read_csv(path, comment=comment)
    except Exception as e:
        # Deliberately broad: the viewer should keep working without this table.
        print(f'⚠️ Failed reading {path}: {e}')
        return None
    return frame

def find_snapshot_path(osm_id: int):
    """Locate the snapshot file for ``osm_id`` inside ``SNAPSHOTS_DIR``.

    ``<osm_id>.png`` is preferred; otherwise the first file named
    ``<osm_id>.<anything>`` is returned. Returns ``None`` when neither exists.
    """
    png_path = SNAPSHOTS_DIR / f'{osm_id}.png'
    if png_path.exists():
        return png_path
    # Fall back to the same stem with any other extension.
    return next(iter(SNAPSHOTS_DIR.glob(f'{osm_id}.*')), None)

def find_unet_overlay_path(osm_id: int):
    """Locate the UNet overlay image for ``osm_id`` inside ``UNET_OVERLAYS``.

    ``<osm_id>_overlay.png`` is preferred. Otherwise overlay files whose name
    contains the id are considered, but only when the id appears as a whole
    number: a bare ``*123*`` glob would also match ``51234_overlay.png`` and
    show another substation's overlay.

    Returns:
        The matching ``Path``, or ``None`` when no overlay is found. When
        several fallback files match, the alphabetically first one is returned
        so the choice is stable between runs.
    """
    p = UNET_OVERLAYS / f'{osm_id}_overlay.png'
    if p.exists():
        return p

    # The id must not be preceded or followed by another digit.
    whole_id = re.compile(rf'(?<!\d){re.escape(str(osm_id))}(?!\d).*_overlay')
    hits = sorted(
        h for h in UNET_OVERLAYS.glob(f'*{osm_id}*_overlay*.png')
        if whole_id.search(h.name)
    )
    return hits[0] if hits else None

def find_yolo_render_path(osm_id: int):
    """Locate a pre-rendered YOLO image for ``osm_id`` in ``YOLO_RENDERED_DIR``.

    Patterns are tried in order: ``<osm_id>.png``, then any ``.png`` and any
    ``.jpg`` whose name contains the id. The id has to appear as a whole
    number, because a bare ``*123*`` glob would also match ``51234.png`` and
    show another substation's detections.

    Returns:
        The matching ``Path``, or ``None`` when the directory is missing or
        nothing matches. When several files match one pattern, the
        alphabetically first one is returned so the choice is stable.
    """
    if not YOLO_RENDERED_DIR.exists():
        return None

    # The id must not be preceded or followed by another digit.
    whole_id = re.compile(rf'(?<!\d){re.escape(str(osm_id))}(?!\d)')
    for pat in [f'{osm_id}.png', f'*{osm_id}*.png', f'*{osm_id}*.jpg']:
        hits = sorted(
            h for h in YOLO_RENDERED_DIR.glob(pat)
            if whole_id.search(h.name)
        )
        if hits:
            return hits[0]
    return None

def _infer_bbox_columns(df: pd.DataFrame):
    """Guess which columns of ``df`` hold bounding-box coordinates.

    Returns:
        ``(columns, mode)`` where ``mode`` is ``'corners'`` for an
        (x1, y1, x2, y2)-style layout or ``'centerwh'`` for a
        (x_center, y_center, width, height)-style layout. Corner layouts are
        checked first. ``(None, None)`` when no known layout is fully present.
    """
    known_layouts = (
        ('corners', (
            ('x1', 'y1', 'x2', 'y2'),
            ('xmin', 'ymin', 'xmax', 'ymax'),
            ('left', 'top', 'right', 'bottom'),
        )),
        ('centerwh', (
            ('x_center', 'y_center', 'width', 'height'),
            ('xcenter', 'ycenter', 'w', 'h'),
            ('xc', 'yc', 'w', 'h'),
        )),
    )
    present = set(df.columns)
    for mode, candidates in known_layouts:
        for names in candidates:
            if present.issuperset(names):
                return names, mode
    return None, None

def draw_yolo_boxes_on_image(image_path: Path, yolo_df: pd.DataFrame, osm_id: int):
    """Draw the detections recorded for ``osm_id`` onto the image at ``image_path``.

    Args:
        image_path: Image file to open; it is converted to RGB.
        yolo_df: Detections table. Rows are matched to ``osm_id`` through the
            first of ``osm_id``/``join_key``/``image``/``image_name``/
            ``filename`` that exists, parsed with ``parse_osm_id_digits``.
        osm_id: Numeric id of the substation to annotate.

    Returns:
        The RGB ``PIL.Image``. It is returned without annotations when the
        table is empty, has no identifying column, has no rows for this id,
        or has no recognised bounding-box columns.
    """
    img = Image.open(image_path).convert('RGB')
    W, H = img.size
    draw = ImageDraw.Draw(img)

    if yolo_df is None or yolo_df.empty:
        return img

    # Filter rows for this osm_id using the first identifying column present.
    key_col = next(
        (c for c in ['osm_id', 'join_key', 'image', 'image_name', 'filename']
         if c in yolo_df.columns),
        None,
    )
    if key_col is None:
        return img
    sub = yolo_df[yolo_df[key_col].apply(parse_osm_id_digits) == osm_id]
    if sub.empty:
        return img

    bbox_cols, mode = _infer_bbox_columns(sub)
    if bbox_cols is None:
        return img

    label_col = next((c for c in ['class_name', 'label', 'name', 'cls', 'class'] if c in sub.columns), None)
    conf_col = next((c for c in ['confidence', 'conf', 'score'] if c in sub.columns), None)

    # Heuristic for center/size boxes: when every x-center and width is at
    # most 1.5 in magnitude, treat the values as fractions of the image size.
    normalized = False
    if mode == 'centerwh':
        try:
            mx = sub[bbox_cols[0]].astype(float).abs().max()
            mw = sub[bbox_cols[2]].astype(float).abs().max()
            normalized = bool(mx <= 1.5 and mw <= 1.5)
        except (TypeError, ValueError):
            normalized = False

    def to_finite_float(v):
        """Return ``v`` as a finite float, or None when it is missing or unparsable."""
        try:
            out = float(v)
        except (TypeError, ValueError, OverflowError):
            return None
        # NaN/inf would otherwise survive clamping and draw a meaningless box.
        return out if np.isfinite(out) else None

    def clamp(v, size):
        """Clamp a coordinate into the valid pixel range [0, size - 1]."""
        return max(0, min(size - 1, v))

    for _, r in sub.iterrows():
        vals = [to_finite_float(r[c]) for c in bbox_cols]
        if None in vals:
            continue

        if mode == 'corners':
            x1, y1, x2, y2 = vals
        else:
            xc, yc, w, h = vals
            if normalized:
                xc, yc, w, h = xc * W, yc * H, w * W, h * H
            x1, y1 = xc - w / 2, yc - h / 2
            x2, y2 = xc + w / 2, yc + h / 2

        # Order the corners after clamping: recent Pillow releases reject a
        # rectangle whose second corner lies before the first.
        x1, x2 = sorted((clamp(x1, W), clamp(x2, W)))
        y1, y2 = sorted((clamp(y1, H), clamp(y2, H)))

        draw.rectangle([x1, y1, x2, y2], width=3)

        parts = []
        if label_col:
            parts.append(str(r[label_col]))
        if conf_col:
            try:
                parts.append(f'{float(r[conf_col]):.2f}')
            except (TypeError, ValueError):
                # An unparsable confidence is simply left out of the label.
                pass
        if parts:
            draw.text((x1 + 4, y1 + 4), ' '.join(parts))

    return img

def show_side_by_side(img_left, img_right, title_left, title_right):
    """Show two images next to each other in one 14x6 figure, axes hidden."""
    plt.figure(figsize=(14,6))
    panels = ((img_left, title_left), (img_right, title_right))
    for position, (image, title) in enumerate(panels, start=1):
        axis = plt.subplot(1, 2, position)
        axis.imshow(image)
        axis.set_title(title)
        axis.axis('off')
    plt.tight_layout()
    plt.show()


In [4]:
# --- Load data tables (best-effort) ---
# Each table is None when its CSV is missing or unreadable (see read_csv_maybe).
unet_df, cap_comp_df, cap_sum_df, yolo_df, score_df, manifest_df = (
    read_csv_maybe(csv_path)
    for csv_path in (
        UNET_CSV,
        CAPACITY_COMPONENT_CSV,
        CAPACITY_SUMMARY_CSV,
        YOLO_CSV,
        SCORE_CSV,
        MANIFEST_CSV,
    )
)

def _add_osm_id_col(df, source_cols):
    """Add an ``osm_id_digits`` column to ``df`` (in place) and return it.

    The id is parsed with ``parse_osm_id_digits`` from the ``osm_id`` column
    when that column exists and has at least one non-null value; otherwise
    from the first entry of ``source_cols`` present in ``df``. When no source
    column is found the new column is filled with ``None``.

    Returns ``None`` when ``df`` is ``None``.
    """
    if df is None:
        return None

    present = df.columns
    if 'osm_id' in present and df['osm_id'].notna().any():
        chosen = 'osm_id'
    else:
        chosen = next((name for name in source_cols if name in present), None)

    if chosen is None:
        df['osm_id_digits'] = None
    else:
        df['osm_id_digits'] = df[chosen].apply(parse_osm_id_digits)
    return df

# Give every loaded table a numeric `osm_id_digits` column.
unet_df = _add_osm_id_col(unet_df, ['join_key','image_name','image'])
cap_comp_df = _add_osm_id_col(cap_comp_df, ['join_key','image_name','image'])
cap_sum_df = _add_osm_id_col(cap_sum_df, ['osm_id','join_key'])
yolo_df = _add_osm_id_col(yolo_df, ['join_key','image_name','image','filename'])
score_df = _add_osm_id_col(score_df, ['osm_id','join_key','id'])
manifest_df = _add_osm_id_col(manifest_df, ['osm_id','Id','osm_ref'])

# Collect every id seen in the PNG snapshot filenames or in the tables below.
available = set()
if SNAPSHOTS_DIR.exists():
    snapshot_ids = (parse_osm_id_digits(png.name) for png in SNAPSHOTS_DIR.glob('*.png'))
    available.update(found for found in snapshot_ids if found is not None)
for df in (unet_df, cap_sum_df, yolo_df, score_df, manifest_df):
    if df is None or 'osm_id_digits' not in df.columns:
        continue
    available.update(int(x) for x in df['osm_id_digits'].dropna().unique())
available = sorted(available)
print(f'✅ available osm ids: {len(available)}')


✅ available osm ids: 5020


In [5]:

# ---- Table display helpers (prevents truncation) ----
from IPython.display import HTML
from typing import List, Optional

def _pick_cols(df: pd.DataFrame, preferred: List[str], extras: Optional[List[str]] = None) -> List[str]:
    """Choose which columns of ``df`` to display.

    Returns the names from ``preferred`` followed by ``extras`` that exist in
    ``df``, in that order and without duplicates. When none of them exist,
    the first 12 columns of ``df`` are returned so something stays visible.
    """
    wanted = preferred + (extras or [])
    # dict.fromkeys keeps first-seen order while dropping repeated names.
    chosen = list(dict.fromkeys(name for name in wanted if name in df.columns))
    if chosen:
        return chosen
    return list(df.columns[:12])

def display_df_scrollable(df: pd.DataFrame, cols: Optional[List[str]] = None, max_rows: int = 30, title: Optional[str] = None):
    """Render ``df`` as an HTML table inside a horizontally scrollable box.

    Args:
        df: Table to show. ``None`` or an empty table shows only the title
            with a "No rows" note (nothing at all when there is no title).
        cols: Columns to show; all columns when ``None``.
        max_rows: Maximum number of leading rows to render.
        title: Optional bold heading placed above the table.
    """
    if df is None or df.empty:
        if title:
            display(HTML(f"<b>{title}</b><div style='color:#777'>No rows</div>"))
        return

    selected = list(df.columns) if cols is None else cols
    show = df[selected].head(max_rows).copy()

    # Round the known numeric columns to two decimals for readability.
    rounded_cols = ("area_m2", "pred_capacity_mva", "meters_per_px", "capacity_mva_sum", "capacity_mva_median")
    for name in rounded_cols:
        if name not in show.columns:
            continue
        show[name] = pd.to_numeric(show[name], errors="coerce").round(2)

    pieces = [
        f"<b>{title}</b><br/>" if title else "",
        "<div style='max-width: 100%; overflow-x: auto; border: 1px solid #ddd; padding: 6px;'>",
        show.to_html(index=False, escape=False),
        "</div>",
    ]
    display(HTML("".join(pieces)))

# --- Viewer ---
def view_osm(osm_id: int):
    """Render everything the notebook knows about one substation.

    In order: the manifest row (if any), the snapshot next to its UNet
    overlay, the snapshot next to the YOLO detections (a pre-rendered image
    if one exists, otherwise boxes drawn from the CSV), the capacity tables,
    and the score row. Stops after the manifest row when no snapshot image
    is found for ``osm_id``.

    All tables go through ``display_df_scrollable`` so wide frames scroll
    instead of being truncated.

    Args:
        osm_id: Numeric OSM id, compared against each table's
            ``osm_id_digits`` column.
    """
    snap_path = find_snapshot_path(osm_id)
    unet_ov_path = find_unet_overlay_path(osm_id)
    yolo_render_path = find_yolo_render_path(osm_id)

    if manifest_df is not None:
        mrow = manifest_df[manifest_df['osm_id_digits'] == osm_id]
        if not mrow.empty:
            display_df_scrollable(mrow, max_rows=1, title='Manifest row')

    if snap_path is None:
        print(f'❌ No snapshot for osm_id={osm_id} in {SNAPSHOTS_DIR}')
        return

    snap = Image.open(snap_path).convert('RGB')

    # UNet
    if unet_ov_path is not None:
        unet_ov = Image.open(unet_ov_path).convert('RGB')
        show_side_by_side(snap, unet_ov, f'Snapshot ({snap_path.name})', f'UNet overlay ({unet_ov_path.name})')
    else:
        plt.figure(figsize=(7,7))
        plt.imshow(snap); plt.title(f'Snapshot ({snap_path.name}) — no UNet overlay'); plt.axis('off')
        plt.show()

    # YOLO: prefer a pre-rendered image, else draw the boxes from the CSV.
    if yolo_render_path is not None:
        yimg = Image.open(yolo_render_path).convert('RGB')
        show_side_by_side(snap, yimg, 'Snapshot', f'YOLO rendered ({yolo_render_path.name})')
    elif yolo_df is not None and not yolo_df.empty:
        yimg = draw_yolo_boxes_on_image(snap_path, yolo_df, osm_id)
        show_side_by_side(snap, yimg, 'Snapshot', 'YOLO boxes (drawn from CSV)')
    else:
        print('ℹ️ No YOLO CSV and no rendered YOLO image.')

    # Capacity
    print('\n--- Capacity ---')
    shown = False

    # The UNet table only counts as a capacity source when it carries predictions.
    if unet_df is not None and 'pred_capacity_mva' in unet_df.columns:
        sub = unet_df[unet_df['osm_id_digits'] == osm_id]
        if not sub.empty:
            cols = _pick_cols(sub, [
                'image_name','component_id','area_m2','pred_capacity_mva','capacity_class','capacity_reason','voltage_used'
            ], extras=['meters_per_px','scale_source'])
            display_df_scrollable(sub, cols=cols, max_rows=50, title='Capacity rows (subset)')
            shown = True

    if cap_comp_df is not None:
        sub = cap_comp_df[cap_comp_df['osm_id_digits'] == osm_id]
        if not sub.empty:
            display_df_scrollable(sub, max_rows=50, title='Transformers with capacity')
            shown = True

    if cap_sum_df is not None:
        sub = cap_sum_df[cap_sum_df['osm_id_digits'] == osm_id]
        if not sub.empty:
            display_df_scrollable(sub, max_rows=5, title='Substation capacity summary')
            shown = True

    if not shown:
        print('(no capacity rows found)')

    # Score
    print('\n--- Score ---')
    if score_df is None:
        print('(score CSV not loaded)')
        return
    sub = score_df[score_df['osm_id_digits'] == osm_id]
    if sub.empty:
        print('(no score row found)')
    else:
        display_df_scrollable(sub, max_rows=5, title='Score row')

In [6]:
# --- Widget UI ---
dropdown = widgets.Dropdown(
    options=[(str(x), x) for x in available[:5000]],
    description='OSM id:',
    value=available[0] if available else None,
    layout=widgets.Layout(width='320px')
)
text = widgets.Text(
    value=str(available[0]) if available else '',
    description='Type:',
    placeholder='e.g. 18921995',
    layout=widgets.Layout(width='320px')
)
btn = widgets.Button(description='Show', button_style='primary')
out = widgets.Output()

def _sync(change):
    """Copy a newly selected dropdown value into the text box."""
    selected = change['new']
    if selected is None:
        return
    text.value = str(selected)
dropdown.observe(_sync, names='value')

def _go(_):
    """Button handler: parse the id in the text box and render it into `out`."""
    with out:
        # Replace the previous rendering rather than appending to it.
        clear_output(wait=True)
        oid = parse_osm_id_digits(text.value)
        if oid is not None:
            view_osm(int(oid))
        else:
            print('❌ Could not parse an osm id from input.')

## Visualisation

In [7]:
# Wire the button, then show the control row followed by the output area.
btn.on_click(_go)
display(widgets.HBox([dropdown, text, btn]), out)

# Render the pre-filled id right away so the notebook shows something after "Run All".
if available:
    _go(None)


HBox(children=(Dropdown(description='OSM id:', layout=Layout(width='320px'), options=(('4826893', 4826893), ('…

Output()