In [1]:
import os
import joblib
import numpy as np
import networkx as nx
from glob import glob
from natsort import natsorted
from scipy.spatial import Delaunay
from skimage.measure import label, regionprops
import torch
from torch_geometric.data import Data
from sklearn.preprocessing import StandardScaler

# --- Compute Morphological Features ---
def compute_morph(contour, box):
    x1, y1, x2, y2 = map(int, box)
    pts = contour - np.array([x1, y1])
    h, w = y2 - y1 + 1, x2 - x1 + 1
    mask = np.zeros((h, w), dtype=np.uint8)
    rr = pts[:, 1].astype(int)
    cc = pts[:, 0].astype(int)
    mask[rr, cc] = 1
    lbl = label(mask)
    props = regionprops(lbl)
    if not props:
        return {'area': 0, 'perimeter': 0, 'eccentricity': 0,
                'solidity': 0, 'circularity': 0}
    r = props[0]
    return {
        'area': r.area,
        'perimeter': r.perimeter,
        'eccentricity': r.eccentricity,
        'solidity': r.solidity,
        'circularity': 4 * np.pi * r.area / (r.perimeter**2 + 1e-6)
    }

# --- Generate NetworkX Graph from .dat ---
def generate_graph_from_dat(dat_path):
    data = joblib.load(dat_path)
    nodes = []
    for idx, nucleus in enumerate(data):  # Iterate over list indices
        contour = np.array(nucleus['contour'])
        centroid = tuple(nucleus['centroid'])
        ntype = nucleus['type']
        x_min, y_min = contour.min(axis=0)
        x_max, y_max = contour.max(axis=0)
        box = (x_min, y_min, x_max, y_max)
        morph = compute_morph(contour, box)
        nodes.append({'id': idx,  # Use index as node ID
                      'centroid': centroid,
                      'type': ntype,
                      **morph})

    G = nx.Graph()
    for node in nodes:
        G.add_node(node['id'],
                   x=node['centroid'][0],
                   y=node['centroid'][1],
                   type=node['type'],
                   area=node['area'],
                   perimeter=node['perimeter'],
                   eccentricity=node['eccentricity'],
                   solidity=node['solidity'],
                   circularity=node['circularity'])
    points = np.array([n['centroid'] for n in nodes])
    if len(points) >= 3:
        tri = Delaunay(points)
        edges = set()
        for simplex in tri.simplices:
            for i in range(3):
                u, v = sorted([simplex[i], simplex[(i+1) % 3]])
                edges.add((u, v))
        for u_idx, v_idx in edges:
            p1, p2 = points[u_idx], points[v_idx]
            dist = np.linalg.norm(p1 - p2)
            inv_dist = 1.0 / (dist + 1e-6)
            factor = 1.0 if nodes[u_idx]['type'] == nodes[v_idx]['type'] else 0.5
            weight = inv_dist * factor
            G.add_edge(nodes[u_idx]['id'], nodes[v_idx]['id'], weight=weight)
    return G

# --- Convert NetworkX Graph to PyG Data ---
def convert_nx_to_pyg(G, label):
    # Convert node labels to integers
    G = nx.convert_node_labels_to_integers(G)
    
    # Extract node features
    node_features = []
    feature_names = ['x', 'y', 'type', 'area', 'perimeter', 'eccentricity', 'solidity', 'circularity']
    
    for node, data in G.nodes(data=True):
        features = []
        for f in feature_names:
            value = data.get(f, 0)
            try:
                features.append(float(value))
            except (ValueError, TypeError):
                print(f"Warning: Invalid value for {f} in node {node}: {value}. Using 0.")
                features.append(0.0)
        node_features.append(features)
    
    # Convert to numpy array
    node_features_np = np.array(node_features)
    
    # Debug: Check feature ranges
    for i, name in enumerate(feature_names):
        unique_values = np.unique(node_features_np[:, i])
        if len(unique_values) == 1:
            print(f"Warning: Feature '{name}' has identical values: {unique_values[0]}")
        else:
            print(f"Feature '{name}' range: {unique_values.min()} to {unique_values.max()}")
    
    # Normalize only area, perimeter, eccentricity, solidity, circularity (indices 3-7)
    scaler = StandardScaler()
    features_to_normalize = node_features_np[:, 3:]  # Select columns 3-7
    normalized_features = scaler.fit_transform(features_to_normalize)
    
    # Combine unnormalized (x, y, type) and normalized features
    node_features_combined = np.hstack([
        node_features_np[:, :3],  # x, y, type (unnormalized)
        normalized_features       # area, perimeter, eccentricity, solidity, circularity (normalized)
    ])
    
    # Convert to 2D tensor
    x = torch.tensor(node_features_combined, dtype=torch.float)
    if x.dim() != 2 or x.size(1) != len(feature_names):
        raise ValueError(f"Expected x to be 2D with shape [num_nodes, {len(feature_names)}], got shape {x.shape}")
    
    # Extract edge indices and attributes
    edge_index = []
    edge_attr = []
    for u, v, data in G.edges(data=True):
        edge_index.append([u, v])
        weight = float(data.get('weight', 1.0))
        edge_attr.append([weight])
    
    # Debug: Check edge weights
    edge_weights = np.array(edge_attr)
    unique_weights = np.unique(edge_weights)
    if len(unique_weights) == 1:
        print(f"Warning: All edge weights are identical: {unique_weights[0]}")
    else:
        print(f"Edge weight range: {unique_weights.min()} to {unique_weights.max()}")
    
    edge_index = torch.tensor(edge_index, dtype=torch.long).T.contiguous()
    edge_attr = torch.tensor(edge_attr, dtype=torch.float)
    
    # Create PyG Data object
    data = Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        y=torch.tensor([label], dtype=torch.long),
        original_node_indices=torch.arange(G.number_of_nodes(), dtype=torch.long)
    )
    
    # Add individual node features as separate attributes
    for i, feature_name in enumerate(feature_names):
        data[feature_name] = x[:, i]
    
    # Debug: Verify x tensor
    print(f"x tensor shape: {data.x.shape}")
    return data

# --- Generate and Save Graphs as .pt ---
def generate_and_save_graphs(dat_folder, graph_folder, label=0):
    os.makedirs(graph_folder, exist_ok=True)
    dats = natsorted(glob(os.path.join(dat_folder, "*.dat")))
    
    for dat_path in dats:
        # Generate NetworkX graph
        G = generate_graph_from_dat(dat_path)
        
        # Convert to PyG Data object
        graph = convert_nx_to_pyg(G, label)
        
        # Save full graph as .pt
        base = os.path.splitext(os.path.basename(dat_path))[0]
        graph_path = os.path.join(graph_folder, f"{base}.pt")
        torch.save(graph, graph_path)
        print(f"Saved full graph to {graph_path}")
    
    print(f"Generated {len(dats)} graphs in {graph_folder}")



In [None]:


# if __name__ == "__main__":
#     # Test visualization on a sample of images
#     subtype = 'Invasive'
#     print(f"Testing visualization for subtype: {subtype}")
    # generate_and_visualize_graphs(
    #     image_folder=f"./dataset/data/Photos/{subtype}/",
    #     dat_folder=f"./n_detected_pannuke/{subtype}/",
    #     graph_folder=f"./graphs_new_pannuke_visual/{subtype}/",
#         num_images=3
#     )
#     subtype = 'Normal'
#     print(f"Testing visualization for subtype: {subtype}")
#     generate_and_visualize_graphs(
#         image_folder=f"./dataset/data/Photos/{subtype}/",
#         dat_folder=f"./n_detected_pannuke/{subtype}/",
#         graph_folder=f"./graph_new/_pannuke_visual/{subtype}/",
#         num_images=3
#     )
#     subtype = 'InSitu'
#     print(f"Testing visualization for subtype: {subtype}")
#     generate_and_visualize_graphs(
#         image_folder=f"./dataset/data/Photos/{subtype}/",
#         dat_folder=f"./n_detected_pannuke/{subtype}/",
#         graph_folder=f"./graphs_new_pannuke_visual/{subtype}/",
#         num_images=3
#     )
#     subtype = 'Benign'
#     print(f"Testing visualization for subtype: {subtype}")
#     generate_and_visualize_graphs(
#         image_folder=f"./dataset/data/Photos/{subtype}/",
#         dat_folder=f"./n_detected_pannuke/{subtype}/",
#         graph_folder=f"./graphs_new_pannuke_visual/{subtype}/",
#         num_images=3
#     )


In [6]:
subtypes = {
    'Invasive': 3,
    'Benign': 1,
    'InSitu': 2,
    'Normal': 0
}

for st in subtypes.items():
        print(f"Saving all graphs for subtype: {st}")
        print(st[1])
        # generate_and_save_graphs(
        #     dat_folder=f"./n_detected_pannuke/{st}/",
        #     graph_folder=f"./graphs_new_pannuke/{st}/",
        #     label=i
        # )
        # i+=1


Saving all graphs for subtype: ('Invasive', 3)
3
Saving all graphs for subtype: ('Benign', 1)
1
Saving all graphs for subtype: ('InSitu', 2)
2
Saving all graphs for subtype: ('Normal', 0)
0
