# Graph Representation

# Plot the ego network (everything within 2 hops) around the first node of `graph`.
a = [*graph.nodes()]
centre = a[0]
sub_graph = nx.ego_graph(graph, centre, radius=2)
nx.draw_networkx(sub_graph)

# In [2]:
import numpy as np
import torch
from dhg import Graph
from scipy.io import loadmat
from skimage.segmentation import slic
from skimage.measure import regionprops

class DATA:
    """Load a hyperspectral image + ground-truth MAT pair and build a graph from it.

    The graph construction is selected by ``TYPE``:

    * ``TYPE == 0``: one node per pixel, edges to the 8-connected neighbours.
    * ``TYPE == 1``: one node per SLIC superpixel, edges between superpixels
      whose pixels touch (4-connectivity).
    * any other value: only the arrays are loaded; no graph is built.

    After construction the instance exposes ``x`` (node features), ``y``
    (node labels, long tensor), ``edge_index`` (``(2, E)`` long tensor),
    ``graph`` (``dhg.Graph``) and random ``train_mask`` / ``test_mask``
    (each node goes to train with probability 0.8).
    """

    # (row, col) offsets of the 8-connected neighbourhood, in scan order.
    PIXEL_OFFSETS = (
        (-1, -1), (-1, 0), (-1, 1),
        (0, -1),           (0, 1),
        (1, -1),  (1, 0),  (1, 1),
    )

    def __init__(self, path1, path2, TYPE=0, *, num_superpixels=90,
                 compactness=1, image_key="paviaU", label_key="paviaU_gt",
                 seed=None):
        """Load the data and build the requested graph.

        Args:
            path1: path to the MAT file holding the image.
            path2: path to the MAT file holding the ground-truth labels.
            TYPE: 0 builds a pixel graph, 1 a superpixel graph; any other
                value only loads the data.
            num_superpixels: ``n_segments`` passed to SLIC (TYPE 1 only).
            compactness: ``compactness`` passed to SLIC (TYPE 1 only).
            image_key: variable name of the image inside ``path1``.
            label_key: variable name of the labels inside ``path2``.
            seed: if not None, seeds the random train/test split so it is
                reproducible; None uses torch's global RNG.
        """
        self.PATH_TO_IMAGE = path1
        self.PATH_TO_LABEL = path2
        self.image_key = image_key
        self.label_key = label_key
        self.graph = None
        self.image = None
        self.label = None
        self.x = None
        self.edge_index = None
        self.y = None
        self.train_mask = None
        self.test_mask = None
        self.num_superpixels = num_superpixels
        self.compactness = compactness
        self.seed = seed
        self.loadMAT_file()
        if TYPE == 0:
            self.image_to_graph()
        elif TYPE == 1:
            self.image_to_superpixel_graph()

    @staticmethod
    def _load_mat_variable(path, key):
        """Return variable ``key`` from the MAT file at ``path``.

        Raises:
            KeyError: if the file does not contain ``key``; the message lists
                the variables that are present.
        """
        contents = loadmat(path)
        try:
            return contents[key]
        except KeyError:
            available = sorted(k for k in contents if not k.startswith("__"))
            raise KeyError(
                f"Variable {key!r} not found in {path!r}; available: {available}"
            ) from None

    def loadMAT_file(self):
        """Load ``self.image`` (float32) and ``self.label`` (int64) from disk.

        NaN / infinite entries are replaced with 0 before the dtype casts.

        Raises:
            KeyError: if a MAT file lacks the expected variable.
        """
        image = self._load_mat_variable(self.PATH_TO_IMAGE, self.image_key)
        label = self._load_mat_variable(self.PATH_TO_LABEL, self.label_key)

        # Clean before casting: NaN cannot be represented once cast to int64.
        self.missing_values(image, "image")
        self.missing_values(label, "label")

        self.image = np.array(image, dtype=np.float32)
        self.label = np.array(label, dtype=np.int64)

    def missing_values(self, array, name):
        """Replace NaN and +/-inf entries of ``array`` with 0, in place.

        Prints a warning naming ``name`` for each kind of value replaced.
        """
        # Integer and boolean arrays cannot hold NaN/inf; avoid the float cast.
        if array.dtype.kind in "biu":
            return
        nan_mask = np.isnan(array)
        if nan_mask.any():
            print(f"Warning: {name} contains NaN values. Replacing with 0.")
            array[nan_mask] = 0
        inf_mask = np.isinf(array)
        if inf_mask.any():
            print(f"Warning: {name} contains infinite values. Replacing with 0.")
            array[inf_mask] = 0

    def _image_dims(self):
        """Return ``(height, width, channels)``; a 2-D image counts as 1 channel."""
        if self.image.ndim == 3:
            return self.image.shape
        return (*self.image.shape, 1)

    @staticmethod
    def _pixel_edges(height, width, offsets=PIXEL_OFFSETS):
        """Return an ``(E, 2)`` int64 array of (pixel, neighbour) index pairs.

        Pixels are numbered row-major. Pairs are ordered by source pixel, then
        by position in ``offsets``, which matches a nested row/col/offset scan.
        Every neighbour relation therefore appears in both directions.
        """
        offsets = np.asarray(offsets, dtype=np.int64).reshape(-1, 2)
        sources = np.arange(height * width, dtype=np.int64)
        rows, cols = np.divmod(sources, width)
        # (num_pixels, num_offsets) candidate neighbour coordinates.
        nbr_rows = rows[:, None] + offsets[:, 0]
        nbr_cols = cols[:, None] + offsets[:, 1]
        inside = ((nbr_rows >= 0) & (nbr_rows < height)
                  & (nbr_cols >= 0) & (nbr_cols < width))
        src = np.broadcast_to(sources[:, None], inside.shape)[inside]
        dst = (nbr_rows * width + nbr_cols)[inside]
        return np.stack([src, dst], axis=1)

    @staticmethod
    def _region_adjacency(labels):
        """Return the sorted, unique ``(E, 2)`` int64 array of touching label pairs.

        Two labels are adjacent when they occupy 4-connected neighbouring
        pixels of the 2-D ``labels`` map. Each adjacency is listed in both
        directions, exactly once per direction.
        """
        labels = np.asarray(labels)
        chunks = []
        for a, b in ((labels[:, :-1], labels[:, 1:]),   # horizontal neighbours
                     (labels[:-1, :], labels[1:, :])):  # vertical neighbours
            differs = a != b
            a, b = a[differs], b[differs]
            chunks.append(np.stack([a, b], axis=1))
            chunks.append(np.stack([b, a], axis=1))
        pairs = np.concatenate(chunks).astype(np.int64, copy=False)
        if len(pairs) == 0:
            return pairs.reshape(0, 2)
        return np.unique(pairs, axis=0)

    def _random_split(self, num_nodes, train_fraction=0.8):
        """Return random boolean ``(train_mask, test_mask)`` of length ``num_nodes``.

        Each node is in train with probability ``train_fraction``; the test
        mask is the complement. Uses ``self.seed`` when it is not None.
        """
        generator = None
        if self.seed is not None:
            generator = torch.Generator().manual_seed(self.seed)
        train_mask = torch.rand(num_nodes, generator=generator) < train_fraction
        return train_mask, ~train_mask

    def image_to_graph(self):
        """Build a pixel graph: one node per pixel, 8-connected edges.

        ``self.x`` becomes the ``(H*W, C)`` numpy feature matrix and ``self.y``
        the row-major flattened labels as a long tensor.

        Raises:
            ValueError: if the image yields no edges (e.g. a single pixel).
        """
        height, width, channels = self._image_dims()
        self.x = self.image.reshape(-1, channels)

        edges = self._pixel_edges(height, width)
        if len(edges) == 0:
            raise ValueError("Error: Edge list is empty. Check your image loading!")

        self.edge_index = torch.as_tensor(edges, dtype=torch.long).T
        self.y = torch.tensor(self.label.flatten(), dtype=torch.long)
        self.train_mask, self.test_mask = self._random_split(len(self.y))

        print(f"Creating graph with {height * width} nodes and {len(edges)} edges.")

        # dhg.Graph gets (u, v) tuples of plain Python ints.
        self.graph = Graph(num_v=height * width)
        self.graph.add_edges([tuple(e) for e in edges.tolist()])

    def image_to_superpixel_graph(self):
        """Build a superpixel graph from a SLIC segmentation of the image.

        Node features are the mean pixel vector of each superpixel. Node labels
        are the most frequent pixel label inside it. Edges join superpixels
        that touch, without duplicates, in both directions.
        """
        _, _, channels = self._image_dims()

        superpixels = slic(self.image, n_segments=self.num_superpixels,
                           compactness=self.compactness, start_label=0)

        # Relabel to consecutive node ids 0..num_nodes-1 so ids can index the
        # feature/label arrays even if SLIC's label range has gaps.
        unique_ids, node_ids = np.unique(superpixels, return_inverse=True)
        node_map = node_ids.reshape(superpixels.shape)
        num_nodes = len(unique_ids)

        x = np.zeros((num_nodes, channels), dtype=np.float32)
        y = np.zeros(num_nodes, dtype=np.int64)
        for node in range(num_nodes):
            mask = node_map == node
            x[node] = np.mean(self.image[mask], axis=0)
            y[node] = np.argmax(np.bincount(self.label[mask]))

        edges = self._region_adjacency(node_map)
        # (0, 2) -> (2, 0) when there are no edges, keeping edge_index 2-row.
        self.edge_index = torch.as_tensor(edges, dtype=torch.long).T

        self.x = torch.tensor(x, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)
        self.train_mask, self.test_mask = self._random_split(num_nodes)

        self.graph = Graph(num_v=num_nodes)
        if len(edges):
            self.graph.add_edges([tuple(e) for e in edges.tolist()])

        print(f"Created superpixel graph with {num_nodes} nodes and {len(edges)} edges.")

    def get_data_for_gcn(self):
        """Bundle features, graph, labels and masks into a dict.

        ``'x'`` is always a fresh float32 tensor, whether ``self.x`` is a numpy
        array (pixel graph) or already a tensor (superpixel graph).
        ``'edge_index'`` is always None here; the ``(2, E)`` tensor remains
        available as ``self.edge_index``.
        """
        if isinstance(self.x, torch.Tensor):
            # torch.tensor(<tensor>) warns; clone explicitly to keep copy semantics.
            x = self.x.detach().clone().to(torch.float32)
        else:
            x = torch.tensor(self.x, dtype=torch.float32)
        return {
            'x': x,
            'graph': self.graph,
            'y': self.y,
            'train_mask': self.train_mask,
            'test_mask': self.test_mask,
            'edge_index': None
        }
    