In [None]:
import numpy as np
import torch
import networkx as nx
import matplotlib.pyplot as plt
from ts2vg import NaturalVG
from tqdm import tqdm
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
import os
from PIL import Image
from pathlib import Path

# ─── SETTINGS ───────────────────────────────────────────────────────
# Source images are read recursively from INPUT_DIR; rendered graphs are
# written under OUTPUT_DIR, which is deleted and recreated on every run.
# Change these paths for your machine.
INPUT_DIR = "/home/mhs/research/thesis/balanced_img"
OUTPUT_DIR = "/home/mhs/research/thesis/visibility_graph"

# Upper bound on the number of pixel samples (and thus graph nodes) per image.
SAMPLE_SIZE = 500

# 10 in x 25.6 dpi gives a nominal 256 x 256 px canvas, before the
# tight-bbox crop that is applied when the figure is saved.
FIGSIZE = (10, 10)
DPI = 25.6

# Nodes and edges share one RGB hue; edges are drawn mostly transparent.
_VG_RGB = (0.695, 0.746, 0.0273)
NODE_COLOR = (*_VG_RGB, 1)
EDGE_COLOR = (*_VG_RGB, 0.25)

# Keyword arguments forwarded to networkx's drawing call.
graph_opts = dict(
    with_labels=False,
    node_size=2,
    node_color=[NODE_COLOR],
    edge_color=[EDGE_COLOR],
)

def gen_vg(img_np: np.ndarray, sample_size=SAMPLE_SIZE) -> nx.Graph:
    """Build a natural visibility graph from an array's values.

    The array is flattened in row-major (C) order into a 1-D series. If the
    series is longer than ``sample_size``, it is reduced to ``sample_size``
    evenly spaced samples (endpoints included) to keep construction fast.

    Args:
        img_np: Array of pixel intensities, any shape.
        sample_size: Maximum number of series points (graph nodes) to keep.

    Returns:
        The graph produced by ``NaturalVG.as_networkx()``.
    """
    series = img_np.flatten()
    n_points = series.size

    if n_points > sample_size:
        # Integer positions spread uniformly over the whole series.
        picks = np.linspace(0, n_points - 1, sample_size, dtype=int)
        series = series[picks]

    builder = NaturalVG()
    builder.build(series)
    return builder.as_networkx()

def process_image(img_path: Path, output_path: Path):
    """Render the visibility graph of one image and save it to *output_path*.

    The image is converted to 8-bit grayscale (PIL mode 'L') and turned into
    a natural visibility graph by ``gen_vg``. The graph is laid out with
    Kamada-Kawai and drawn on a black background using ``graph_opts``.
    Parent directories of *output_path* are created as needed.

    Args:
        img_path: Source image file readable by PIL.
        output_path: Destination file for the rendered graph.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Context manager closes the source file deterministically instead of
    # relying on PIL/garbage collection to release the handle.
    with Image.open(img_path) as img:
        img_np = np.array(img.convert('L'))

    G = gen_vg(img_np)

    fig, ax = plt.subplots(figsize=FIGSIZE)
    try:
        ax.set_facecolor("black")
        for sp in ax.spines.values():
            sp.set_visible(False)

        # NOTE(review): Kamada-Kawai layout is likely the dominant per-image
        # cost; profile before swapping it, since a different layout changes
        # the rendered output.
        pos = nx.kamada_kawai_layout(G)
        # Pass ax/fig explicitly rather than relying on pyplot's "current"
        # figure/axes global state.
        nx.draw_networkx(G, pos=pos, ax=ax, **graph_opts)

        fig.savefig(
            output_path,
            dpi=DPI,
            bbox_inches="tight",
            pad_inches=0,
            facecolor="black",
        )
    finally:
        # Always release the figure, even if layout/drawing/saving fails;
        # otherwise figures accumulate across the batch loop.
        plt.close(fig)

# Image suffixes picked up by ``main``, compared case-insensitively.
_IMAGE_SUFFIXES = frozenset({".jpg", ".jpeg", ".png"})


def _find_images(root: Path) -> list:
    """Return every image file under *root* (recursively), sorted by path."""
    # Filter on the lower-cased suffix: rglob patterns such as '*.jpg' are
    # case-sensitive on POSIX and would silently skip '.JPG' files.
    return sorted(
        p for p in root.rglob("*")
        if p.is_file() and p.suffix.lower() in _IMAGE_SUFFIXES
    )


def main():
    """Render a visibility-graph image for every image under INPUT_DIR.

    OUTPUT_DIR is deleted and recreated. Each .jpg/.jpeg/.png file (any
    letter case) found recursively under INPUT_DIR is rendered to a .png
    at the same relative path under OUTPUT_DIR.

    Note: inputs that differ only by suffix (e.g. ``x.jpg`` and ``x.png``
    in one folder) map to the same output file; the later one overwrites.

    Raises:
        ValueError: if OUTPUT_DIR is INPUT_DIR or one of its ancestors,
            because wiping it would delete the input images.
    """
    import shutil

    input_root = Path(INPUT_DIR)
    output_root = Path(OUTPUT_DIR)

    # Guard the destructive rmtree below against deleting the inputs.
    resolved_in = input_root.resolve()
    resolved_out = output_root.resolve()
    if resolved_out == resolved_in or resolved_out in resolved_in.parents:
        raise ValueError(
            f"OUTPUT_DIR ({output_root}) contains INPUT_DIR ({input_root}); "
            "refusing to delete it"
        )

    # Start from an empty output directory on every run.
    if output_root.exists():
        shutil.rmtree(output_root)
    output_root.mkdir(parents=True)

    image_files = _find_images(input_root)

    for img_path in tqdm(image_files, desc="Processing images"):
        # Mirror the input directory structure in the output tree.
        rel_path = img_path.relative_to(input_root)
        output_path = output_root / rel_path.with_suffix('.png')
        process_image(img_path, output_path)

    print(f"Processed {len(image_files)} images")


if __name__ == "__main__":
    main()

Processing images:   1%|          | 50/9096 [02:06<6:27:40,  2.57s/it]