In [2]:
import os 
import sys
import json
import time
import random
import cv2
import torch
import numpy as np
import pandas as pd
from pathlib import Path
from collections import defaultdict
from torch_geometric.data import Data

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")) + '/lib/')

from lib.blob_extraction import img_preprocess_mask, blob_detect, get_nodes_pos
from lib.graph_generate import Delaunay_graph_generate
from lib.voronoi_generate import cal_3d_Voronoi

Parameter sweep for robust blob detection

In [3]:
# Reference frame: build_dataset reconstructs its surface (Z_ref), and
# frame_to_graph_and_depth subtracts that surface from every other frame's.
# Built with os.path.join so the separators are right on every OS; on Windows
# the resulting string is identical to the previous hard-coded backslash form.
REF_IMAGE_PATH = os.path.join(
    os.path.abspath(os.path.join(os.getcwd(), "..")),
    "data", "trans", "2_1_1_1", "frames_bw", "frame_0_0.png",
)

# Extra keyword arguments forwarded to cal_3d_Voronoi.
ZPARAM = dict(pool_neighbours=3, num_interp_points=50)

# Candidate preprocessing / blob-detector settings. detect_pins_with_sweep tries
# them in this order and keeps the first one whose pin count is within range.
#   ppm         -> keyword arguments for img_preprocess_mask
#   mask_radius -> circular mask radius in pixels; None means "derive from image size"
#   blob        -> keyword arguments for blob_detect
PARAM_SETS = [
    dict(ppm=dict(erosion=False, resize_x=300, resize_y=300, kernel_size=1,
                  binary_threshold=50, circle_x_bias=0, circle_y_bias=0, circle_radius_bias=0),
         mask_radius=None,
         blob=dict(minArea=5, blobColor=0, minCircularity=0.1, minConvexity=0.01, minInertiaRatio=0.01,
                   thresholdStep=5, minDistBetweenBlobs=2.0, minRepeatability=2)),
    dict(ppm=dict(erosion=False, resize_x=360, resize_y=360, kernel_size=1,
                  binary_threshold=80, circle_x_bias=0, circle_y_bias=0, circle_radius_bias=0),
         mask_radius=None,
         blob=dict(minArea=7, blobColor=0, minCircularity=0.08, minConvexity=0.01, minInertiaRatio=0.01,
                   thresholdStep=5, minDistBetweenBlobs=2.5, minRepeatability=2)),
    dict(ppm=dict(erosion=False, resize_x=380, resize_y=380, kernel_size=1,
                  binary_threshold=65, circle_x_bias=0, circle_y_bias=0, circle_radius_bias=0),
         mask_radius=None,
         blob=dict(minArea=6, blobColor=0, minCircularity=0.08, minConvexity=0.01, minInertiaRatio=0.01,
                   thresholdStep=5, minDistBetweenBlobs=2.0, minRepeatability=2)),
]

Utility functions

In [4]:
def material_key_from_dir(dir_name: str) -> str:
    """Derive the material label from a sample directory name.

    The label is the first two underscore-separated fields joined by "_"
    (for example "2_1_1_1" -> "2_1"). A name that contains no underscore is
    returned unchanged.
    """
    leading_fields = dir_name.split("_", 2)[:2]
    return "_".join(leading_fields)

def load_targets_csv(csv_path: str):
    """Read the pose_2 / pose_6 targets of one sample folder.

    Args:
        csv_path: Path of the CSV file. Column names are matched
            case-insensitively and must include ``image_name``, ``pose_2``
            and ``pose_6``.

    Returns:
        Dict mapping the image file stem (name without directory or extension)
        to a ``(pose_2, pose_6)`` tuple of floats. Values that cannot be parsed
        as numbers become NaN. An empty dict is returned when the file does not
        exist. If a stem occurs more than once, the last row wins.

    Raises:
        RuntimeError: If one of the required columns is missing.
    """
    if not os.path.exists(csv_path):
        return {}

    table = pd.read_csv(csv_path)

    # Lower-cased header -> header exactly as written in the file.
    header_by_lower = {col.lower(): col for col in table.columns}
    required = ("image_name", "pose_2", "pose_6")
    k = next((key for key in required if key not in header_by_lower), None)
    if k is not None:
        raise RuntimeError(f"{csv_path} 缺少列: {k}")

    def _float_or_nan(value):
        return float(value) if pd.notna(value) else np.nan

    stems = table[header_by_lower["image_name"]].astype(str).apply(lambda s: Path(s).stem)
    # errors="coerce" turns unparsable cells into NaN instead of raising.
    pose_2 = pd.to_numeric(table[header_by_lower["pose_2"]], errors="coerce")
    pose_6 = pd.to_numeric(table[header_by_lower["pose_6"]], errors="coerce")

    return {
        stem: (_float_or_nan(v2), _float_or_nan(v6))
        for stem, v2, v6 in zip(stems, pose_2, pose_6)
    }

In [5]:
def stratified_split_by_class(data_list, y_list, test_ratio=0.25, seed=42):
    """Split samples into train and test lists, class by class.

    Args:
        data_list: Samples to split.
        y_list: Class label of each sample, aligned with ``data_list``; every
            label is converted with ``int()``.
        test_ratio: Fraction of each class sent to the test list.
        seed: Seed of the private ``random.Random`` used for shuffling.

    Returns:
        ``(train_items, test_items)``. Classes are processed in order of first
        appearance. Every class contributes at least one test sample, so a class
        with a single sample ends up entirely in the test list.
    """
    shuffler = random.Random(seed)

    # Sample positions grouped by class label, in first-seen class order.
    members = {}
    for position, label in enumerate(y_list):
        members.setdefault(int(label), []).append(position)

    train_positions, test_positions = [], []
    for positions in members.values():
        shuffler.shuffle(positions)
        cut = max(1, int(round(len(positions) * test_ratio)))
        test_positions += positions[:cut]
        train_positions += positions[cut:]

    train_items = [data_list[pos] for pos in train_positions]
    test_items = [data_list[pos] for pos in test_positions]
    return train_items, test_items

In [6]:
#Auto-estimate circular mask radius
def _auto_mask_radius(h, w, scale=0.48):
    return int(scale * min(h, w))


In [7]:
def preprocess_and_detect_with_params(img_bgr, ppm, mask_radius, blob):
    """Preprocess one image, apply a centred circular mask and detect blobs.

    Args:
        img_bgr: Input image handed to ``img_preprocess_mask``.
        ppm: Dict of preprocessing settings (see ``_PPM_KEYS`` below).
        mask_radius: Radius of the circular mask in pixels; ``None`` derives it
            from the preprocessed image size via ``_auto_mask_radius``.
        blob: Dict of blob-detector settings (see ``_BLOB_KEYS`` below).

    Returns:
        ``(nodes_pos, shape)``: the output of ``get_nodes_pos`` for the detected
        keypoints and the shape of the masked image.
    """
    _PPM_KEYS = (
        "erosion", "resize_x", "resize_y", "kernel_size", "binary_threshold",
        "circle_x_bias", "circle_y_bias", "circle_radius_bias",
    )
    _BLOB_KEYS = (
        "minArea", "blobColor", "minCircularity", "minConvexity",
        "minInertiaRatio", "thresholdStep", "minDistBetweenBlobs", "minRepeatability",
    )

    # img_preprocess_mask returns three values; only the processed image is used here.
    processed, _white_mask, _erosion = img_preprocess_mask(
        img_bgr, **{key: ppm[key] for key in _PPM_KEYS}
    )

    height, width = processed.shape[0], processed.shape[1]
    radius = _auto_mask_radius(height, width) if mask_radius is None else mask_radius

    # Filled circle centred on the image; pixels outside it are zeroed by bitwise_and.
    circle_mask = np.zeros(processed.shape[:2], dtype=np.uint8)
    cv2.circle(circle_mask, (width // 2, height // 2), radius, 255, -1)
    masked = cv2.bitwise_and(processed, processed, mask=circle_mask)

    keypoints = blob_detect(masked, **{key: blob[key] for key in _BLOB_KEYS})
    # Per the original note, get_nodes_pos yields an [N, 2] array of (x, y) - TODO confirm.
    return get_nodes_pos(keypoints), masked.shape

In [8]:
def detect_pins_with_sweep(img_bgr, min_pins=100, max_pins=400):
    """Try each entry of ``PARAM_SETS`` in order until the pin count is acceptable.

    Args:
        img_bgr: Image passed to ``preprocess_and_detect_with_params``.
        min_pins: Smallest accepted number of detected pins (inclusive).
        max_pins: Largest accepted number of detected pins (inclusive).

    Returns:
        ``(nodes_pos, set_index, shape)`` for the first parameter set whose pin
        count lies in ``[min_pins, max_pins]``; ``(None, None, last_shape)``
        when no set qualifies, where ``last_shape`` is the shape reported by the
        last set tried (``None`` if ``PARAM_SETS`` is empty).
    """
    shape_seen = None
    for set_index, params in enumerate(PARAM_SETS):
        candidates, shape_seen = preprocess_and_detect_with_params(
            img_bgr, params["ppm"], params["mask_radius"], params["blob"]
        )
        pin_count = candidates.shape[0]
        if pin_count < min_pins or pin_count > max_pins:
            continue
        return candidates, set_index, shape_seen
    return None, None, shape_seen

In [9]:
def classify_layout(n):
    """Map a detected pin count to a layout label.

    Returns '127', '137' or '331' when ``n`` falls inside the matching window,
    otherwise 'generic'. Windows are tested top to bottom and the first match
    wins.
    """
    # NOTE(review): the first two windows overlap on 132-134; because of the
    # ordering those counts are labelled '127', so '137' effectively starts at
    # 135. Confirm this is the intended boundary.
    windows = (
        (120, 134, '127'),
        (132, 142, '137'),
        (320, 340, '331'),
    )
    for low, high, label in windows:
        if low <= n <= high:
            return label
    return 'generic'

Auto template-based Voronoi transform

In [10]:
class TransformVoronoi_Auto:
    """Select a Voronoi transform from the number of detected pins.

    ``transform`` classifies the pin count with ``classify_layout`` and delegates
    to ``TransformVoronoi_127`` / ``TransformVoronoi_331`` / ``TransformVoronoi_137``
    from ``lib.voronoi_generate`` when the matching class could be imported.
    In every other case it falls back to a generic tessellation computed with
    OpenCV's ``Subdiv2D``.
    """

    def __init__(self, borderScale=1.1):
        """Store ``borderScale`` and look up the optional template classes.

        Args:
            borderScale: Passed to the template classes; in the generic fallback
                it scales the bounding rectangle of the points.
        """
        self.borderScale = float(borderScale)
        self.T127 = self.T331 = self.T137 = None
        # The templates are optional: a class missing from lib.voronoi_generate
        # simply disables that template. Only ImportError is tolerated so that
        # unrelated failures are not hidden.
        try:
            from lib.voronoi_generate import TransformVoronoi_127 as T127
            self.T127 = T127
        except ImportError:
            pass
        try:
            from lib.voronoi_generate import TransformVoronoi_331 as T331
            self.T331 = T331
        except ImportError:
            pass
        try:
            from lib.voronoi_generate import TransformVoronoi_137 as T137
            self.T137 = T137
        except ImportError:
            pass

    @staticmethod
    def _normalize_xy(xy):
        """Centre points on their mean and divide by their mean distance to it.

        When the mean distance is below 1e-6 (all points coincide) the points
        are only centred, which avoids a division by (almost) zero.
        """
        xy = np.asarray(xy, dtype=np.float32)
        c = xy.mean(axis=0, keepdims=True)
        centered = xy - c
        r = np.sqrt((centered**2).sum(axis=1)).mean()
        if r < 1e-6:
            r = 1.0
        return centered / r

    def _generic_voronoi(self, nodes_pos):
        """Voronoi cell area and centroid for every point, via ``cv2.Subdiv2D``.

        Args:
            nodes_pos: Array-like of shape (N, 2) holding (x, y) positions.

        Returns:
            ``(areas, Cx, Cy, XY)``: per-point cell area, cell centroid x and y
            (float32 arrays of length N), and the normalised positions from
            ``_normalize_xy``. A point that could not be inserted into the
            subdivision keeps area 0 and its own position as centroid.
        """
        pts = np.asarray(nodes_pos, dtype=np.float32)
        if pts.size == 0:
            # Nothing to tessellate; min()/max() below would raise on an empty array.
            empty = np.zeros(0, dtype=np.float32)
            return empty, empty.copy(), empty.copy(), np.zeros((0, 2), dtype=np.float32)

        xmin, ymin = pts.min(axis=0)
        xmax, ymax = pts.max(axis=0)
        padx = (self.borderScale - 1.0) * (xmax - xmin) * 0.5
        pady = (self.borderScale - 1.0) * (ymax - ymin) * 0.5

        # cv2.Subdiv2D expects (x, y, width, height), not two corners. The right
        # and bottom edges are exclusive for insert(), hence the "+ 1".
        x0 = int(np.floor(xmin - padx))
        y0 = int(np.floor(ymin - pady))
        x1 = int(np.floor(xmax + padx)) + 1
        y1 = int(np.floor(ymax + pady)) + 1
        subdiv = cv2.Subdiv2D((x0, y0, x1 - x0, y1 - y0))

        n_pts = len(pts)
        areas = np.zeros(n_pts, dtype=np.float32)
        # Default centroid is the point itself; overwritten below when a facet exists.
        Cx = pts[:, 0].astype(np.float32, copy=True)
        Cy = pts[:, 1].astype(np.float32, copy=True)

        # insert() returns the subdivision's vertex ID for the point. Those IDs
        # are NOT 0..N-1 (the first IDs belong to internal vertices), so they
        # must be remembered and used when querying the facets.
        point_rows, vertex_ids = [], []
        for row, (x, y) in enumerate(pts):
            try:
                vid = subdiv.insert((float(x), float(y)))
            except cv2.error:
                # Best effort: a point outside the rectangle is left at its defaults.
                continue
            point_rows.append(row)
            vertex_ids.append(int(vid))

        if vertex_ids:
            facets, centers = subdiv.getVoronoiFacetList(vertex_ids)
            for row, facet, c in zip(point_rows, facets, centers):
                if facet is None or len(facet) == 0:
                    areas[row] = 0.0
                    Cx[row], Cy[row] = float(c[0]), float(c[1])
                    continue
                poly = np.array(facet, dtype=np.float32)
                areas[row] = float(cv2.contourArea(poly))
                M = cv2.moments(poly)
                if M["m00"] != 0:
                    Cx[row] = float(M["m10"] / M["m00"])
                    Cy[row] = float(M["m01"] / M["m00"])
                else:
                    # Degenerate polygon: fall back to the facet centre reported by OpenCV.
                    Cx[row], Cy[row] = float(c[0]), float(c[1])

        XY = self._normalize_xy(pts)  # Normalization
        return areas, Cx, Cy, XY

    def transform(self, nodes_pos):
        """Run the template matching the pin count, or the generic fallback.

        Args:
            nodes_pos: Array with one row per detected pin (``shape[0]`` is the
                pin count).

        Returns:
            Whatever the selected template's ``transform`` returns; for the
            generic fallback, the ``(areas, Cx, Cy, XY)`` tuple of
            ``_generic_voronoi``.
        """
        n = nodes_pos.shape[0]
        layout = classify_layout(n)
        if layout == '127' and self.T127 is not None:
            return self.T127(borderScale=self.borderScale).transform(nodes_pos)
        if layout == '331' and self.T331 is not None:
            return self.T331(borderScale=self.borderScale).transform(nodes_pos)
        if layout == '137' and self.T137 is not None:
            return self.T137(borderScale=self.borderScale).transform(nodes_pos)

        return self._generic_voronoi(nodes_pos)

Convert one image to graph + depth

In [11]:

def frame_to_graph_and_depth(img_bgr, transformer, Z_ref):
    """Turn one image into graph tensors plus a scalar depth offset.

    Args:
        img_bgr: Image passed to ``detect_pins_with_sweep``.
        transformer: Object whose ``transform(nodes_pos)`` returns
            ``(areas, Cx, Cy, XY)``.
        Z_ref: Reference surface subtracted from this frame's surface.

    Returns:
        ``(x, edge_index, depth_scalar, shape, set_idx, n_pins)``. ``x`` is a
        float tensor with one row per pin holding ``XY[:, 0]``, ``XY[:, 1]`` and
        the cell area; ``edge_index`` is a long tensor built from the transposed
        output of ``Delaunay_graph_generate``; ``depth_scalar`` is the mean of
        ``Z - Z_ref``. When pin detection fails the result is
        ``(None, None, None, shape, set_idx, None)``.
    """
    pins, param_idx, proc_shape = detect_pins_with_sweep(img_bgr, min_pins=100, max_pins=400)
    if pins is None:
        return None, None, None, proc_shape, param_idx, None

    cell_area, cell_cx, cell_cy, node_xy = transformer.transform(pins)
    # Only the surface values are needed; the two grid outputs are discarded.
    _grid_x, _grid_y, surface = cal_3d_Voronoi(cell_area, cell_cx, cell_cy, **ZPARAM)
    mean_offset = float(np.mean(surface - Z_ref))

    # Node features: one row per pin -> (xy[0], xy[1], cell area).
    node_columns = [node_xy[:, 0], node_xy[:, 1], cell_area[:]]
    node_features = torch.tensor(np.asarray(node_columns).T, dtype=torch.float)

    # Edges: transpose the Delaunay output before converting it to a long tensor.
    delaunay_edges = np.array(Delaunay_graph_generate(pins)).T
    edge_index = torch.tensor(delaunay_edges, dtype=torch.long).contiguous()

    return node_features, edge_index, mean_offset, proc_shape, param_idx, pins.shape[0]


Main dataset build function

In [12]:
def build_dataset(root_dir: str, train_dir: str, test_dir: str, seed: int = 42):
    """Build the graph dataset from every sample folder and save a train/test split.

    Each sub-directory of ``root_dir`` is one sample folder. Its class is the
    material key returned by ``material_key_from_dir``; its images are the
    ``*.png`` files in ``frames_bw/``; its optional ``targets.csv`` supplies
    pose_2 / pose_6 per image (NaN when a frame has no row or no file exists).

    Args:
        root_dir: Directory containing the sample folders.
        train_dir: Output directory for ``Train_val_data_list.pt`` and
            ``material_id_to_idx.json`` (created if missing).
        test_dir: Output directory for ``Test_data_list.pt`` (created if missing).
        seed: Seed for ``numpy``/``random`` and for the stratified split.

    Returns:
        ``(train_val, test)``: two lists of ``torch_geometric.data.Data``.

    Raises:
        FileNotFoundError: If ``REF_IMAGE_PATH`` does not exist.
        RuntimeError: If the reference frame cannot be decoded or no valid pin
            set is detected in it.
    """
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(test_dir, exist_ok=True)
    np.random.seed(seed)
    random.seed(seed)

    root = Path(root_dir).resolve()
    # iterdir() order depends on the filesystem; sorting keeps the graph order,
    # and therefore the seeded split, reproducible across machines.
    subdirs = sorted(p for p in root.iterdir() if p.is_dir())

    material_keys = sorted({material_key_from_dir(p.name) for p in subdirs})
    mat2idx = {mk: i for i, mk in enumerate(material_keys)}
    num_classes = len(mat2idx)

    transformer = TransformVoronoi_Auto(borderScale=1.1)

    # Reference frame
    if not os.path.exists(REF_IMAGE_PATH):
        raise FileNotFoundError(f"Reference frame not found: {REF_IMAGE_PATH}")
    ref_img = cv2.imread(REF_IMAGE_PATH)
    if ref_img is None:
        # cv2.imread signals an undecodable file by returning None, not by raising.
        raise RuntimeError(f"Failed to read reference frame: {REF_IMAGE_PATH}")
    ref_nodes, _ref_set, ref_shape = detect_pins_with_sweep(ref_img, min_pins=100, max_pins=400)
    if ref_nodes is None:
        raise RuntimeError("Failed to detect valid pins in the reference frame")
    A_ref, Cx_ref, Cy_ref, _XY_ref = transformer.transform(ref_nodes)
    _Xg_ref, _Yg_ref, Z_ref = cal_3d_Voronoi(A_ref, Cx_ref, Cy_ref, **ZPARAM)
    tip_num_ref = int(ref_nodes.shape[0])

    # Generation
    data_list, y_list = [], []
    pin_fail, used_set_hist = 0, [0] * len(PARAM_SETS)
    read_fail = 0
    last_shape = ref_shape
    total_png = 0
    pin_hist = defaultdict(int)
    t0 = time.time()

    for mat_dir in subdirs:
        mk = material_key_from_dir(mat_dir.name)
        y_idx = mat2idx[mk]

        frames_dir = mat_dir / "frames_bw"
        csv_path = mat_dir / "targets.csv"
        if not frames_dir.is_dir():
            print(f"[WARN] Skipping {mat_dir}: no frames_bw/ directory")
            continue

        labels = load_targets_csv(str(csv_path))
        for img_path in sorted(frames_dir.glob("*.png")):
            total_png += 1
            # Frames without a row in targets.csv keep NaN pose targets.
            p2, p6 = labels.get(img_path.stem, (np.nan, np.nan))

            img = cv2.imread(str(img_path))
            if img is None:
                # Unreadable or corrupt file: skip it instead of crashing in preprocessing.
                read_fail += 1
                continue

            x, edge, depth_scalar, shp, set_idx, n_pins = frame_to_graph_and_depth(img, transformer, Z_ref)
            last_shape = shp

            if x is None:
                pin_fail += 1
                continue

            used_set_hist[set_idx] += 1
            pin_hist[int(n_pins)] += 1

            # t = [pose_2, pose_6, depth offset, class index]; y = class index.
            t = torch.tensor([p2, p6, depth_scalar, float(y_idx)], dtype=torch.float)
            y = torch.tensor(y_idx, dtype=torch.long)
            data_list.append(Data(x=x, edge_index=edge, t=t, y=y))
            y_list.append(y_idx)

    # Save
    train_val, test = stratified_split_by_class(data_list, y_list, test_ratio=0.25, seed=seed)

    torch.save(train_val, str(Path(train_dir) / "Train_val_data_list.pt"))
    torch.save(test, str(Path(test_dir) / "Test_data_list.pt"))
    with open(Path(train_dir) / "material_id_to_idx.json", "w", encoding="utf-8") as f:
        json.dump({
            "material_id_to_idx": {k: int(v) for k, v in mat2idx.items()},
            "num_classes": int(num_classes),
            "tip_num_ref": int(tip_num_ref),
            "ref_image_path": REF_IMAGE_PATH,
            "z_params": ZPARAM,
            "param_sets": PARAM_SETS,
            "pin_hist": {str(k): int(v) for k, v in sorted(pin_hist.items())}
        }, f, ensure_ascii=False, indent=2)

    dt = time.time() - t0
    print("\nVoronoi Graph Dataset Loading Complete!!!")
    print("Example processed image shape:", last_shape)
    print(f"Materials (classes): {num_classes}")
    print(f"Images scanned: {total_png} | Graphs built: {len(data_list)}")
    print(f"Train/Val: {len(train_val)} | Test: {len(test)}")
    print(f"Pin extract fail: {pin_fail} | Used param sets: " +
          ", ".join([f"S{i}:{c}" for i, c in enumerate(used_set_hist)]))
    print(f"Unreadable images skipped: {read_fail}")
    print("Pin count histogram:", dict(sorted(pin_hist.items())))
    print(f"Time cost: {dt:.2f}s | Speed: {(len(data_list)/dt) if dt>0 else 0:.2f} graphs/s")

    return train_val, test

In [13]:
if __name__ == "__main__":
    # Paths are relative to the working directory. os.path.join keeps them valid
    # on non-Windows systems; on Windows the strings are the same as the former
    # hard-coded backslash paths.
    root_dir = os.path.join("..", "data", "trans")
    train_dir = os.path.join("..", "result", "train")
    test_dir = os.path.join("..", "result", "test")
    build_dataset(root_dir, train_dir, test_dir, seed=42)



Voronoi Graph Dataset Loading Complete!!!
Example processed image shape: (300, 300)
Materials (classes): 50
Images scanned: 25251 | Graphs built: 24539
Train/Val: 18414 | Test: 6125
Pin extract fail: 712 | Used param sets: S0:16145, S1:8013, S2:381
Pin count histogram: {100: 280, 101: 296, 102: 324, 103: 340, 104: 365, 105: 398, 106: 409, 107: 422, 108: 447, 109: 472, 110: 467, 111: 471, 112: 472, 113: 491, 114: 462, 115: 489, 116: 466, 117: 451, 118: 418, 119: 418, 120: 452, 121: 410, 122: 404, 123: 438, 124: 425, 125: 406, 126: 423, 127: 466, 128: 456, 129: 465, 130: 487, 131: 484, 132: 528, 133: 512, 134: 477, 135: 505, 136: 469, 137: 454, 138: 419, 139: 382, 140: 387, 141: 359, 142: 343, 143: 326, 144: 321, 145: 296, 146: 293, 147: 253, 148: 262, 149: 237, 150: 244, 151: 204, 152: 234, 153: 193, 154: 226, 155: 188, 156: 194, 157: 178, 158: 186, 159: 171, 160: 151, 161: 166, 162: 155, 163: 151, 164: 119, 165: 111, 166: 102, 167: 116, 168: 102, 169: 87, 170: 88, 171: 86, 172: 65, 17