In [1]:
import torch

# Report whether CUDA is visible to PyTorch, naming device 0 when it is.
gpu_found = torch.cuda.is_available()
status_message = (
    f"Success! PyTorch is using GPU: {torch.cuda.get_device_name(0)}"
    if gpu_found
    else "PyTorch is installed, but it is using the CPU."
)
print(status_message)

Success! PyTorch is using GPU: NVIDIA GeForce RTX 3050 6GB Laptop GPU


In [None]:
import os
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch_geometric.data import Data
from skimage.segmentation import slic
from skimage.measure import regionprops
from skimage.graph import rag_mean_color
import numpy as np
from tqdm import tqdm

In [3]:
# Base folder: MNIST is downloaded here and the output folder lives under it.
ROOT_DATA_DIR = './'

# Folder that receives one serialized graph file per image.
PROCESSED_DATA_DIR = os.path.join(ROOT_DATA_DIR, "processed_mnist_graphs")

# Target number of superpixels requested from SLIC for each image.
N_SEGMENTS = 75

In [9]:
def image_to_graph(image_np, label, n_segments, undirected=True):
    """
    Converts a single grayscale image to a PyTorch Geometric graph Data object.

    The image is over-segmented into superpixels with SLIC. Every superpixel
    becomes a node and every pair of adjacent superpixels becomes an edge.

    Args:
        image_np (numpy.ndarray): A 2D array representing a grayscale image.
        label (int): The class label of the image.
        n_segments (int): The target number of superpixels.
        undirected (bool): When True (default) every adjacency is stored in
            both directions, which is how PyTorch Geometric represents an
            undirected graph. Pass False to keep a single (u, v) entry per
            adjacent pair.

    Returns:
        torch_geometric.data.Data: A graph object with
            x: float tensor of shape [num_nodes, 3] holding
               [mean intensity, centroid row / height, centroid col / width],
            edge_index: long tensor of shape [2, num_edges],
            y: long tensor holding the label.
    """

    # 1. Generate superpixels using SLIC (channel_axis=None marks the input as grayscale)
    segments = slic(image_np, n_segments=n_segments, compactness=10, sigma=1,
                    start_label=1, channel_axis=None)

    # 2. Get properties for each superpixel (node)
    regions = regionprops(segments, intensity_image=image_np)
    height, width = image_np.shape[0], image_np.shape[1]

    # Map every superpixel label to the row it occupies in `x`, so edge
    # indices stay valid without assuming the labels are exactly 1..K.
    label_to_index = {}
    node_features = []
    for node_index, props in enumerate(regions):
        label_to_index[int(props.label)] = node_index

        # props.centroid is (row, col); normalize both to the [0, 1] range
        row, col = props.centroid
        node_features.append([props.mean_intensity, row / height, col / width])

    # reshape keeps the [num_nodes, 3] layout even if no region was found
    x = torch.tensor(node_features, dtype=torch.float).reshape(-1, 3)

    # 3. Create edges for adjacent superpixels
    rag = rag_mean_color(image_np, segments)
    pairs = [(label_to_index[int(u)], label_to_index[int(v)]) for u, v in rag.edges]

    if pairs:
        edge_index = torch.tensor(pairs, dtype=torch.long).t()
        if undirected:
            # The adjacency graph lists each neighbouring pair once; add the
            # reversed pairs so message passing works in both directions.
            edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)
        edge_index = edge_index.contiguous()
    else:
        # A graph without edges (e.g. a single superpixel) still needs a
        # [2, 0] edge_index rather than a 1-D empty tensor.
        edge_index = torch.empty((2, 0), dtype=torch.long)

    # 4. Create the PyTorch Geometric Data object
    data = Data(x=x, edge_index=edge_index, y=torch.tensor(label, dtype=torch.long))

    return data


In [10]:
if __name__ == '__main__':
    # Make sure the output folder is in place before any file is written.
    os.makedirs(PROCESSED_DATA_DIR, exist_ok=True)

    # MNIST training split, loaded without a ToTensor transform so each item
    # stays a PIL image that can be turned into a numpy array for scikit-image.
    dataset = MNIST(root=ROOT_DATA_DIR, train=True, download=True)
    num_images = len(dataset)

    print(f"Starting conversion of {num_images} MNIST images to graphs...")

    # Convert every image and write one graph file per dataset index.
    for idx in tqdm(range(num_images)):
        pil_image, digit = dataset[idx]

        # PIL image -> numpy array -> graph Data object
        pixels = np.array(pil_image)
        graph = image_to_graph(pixels, digit, n_segments=N_SEGMENTS)

        # Persist the graph under a name derived from its dataset index
        target_path = os.path.join(PROCESSED_DATA_DIR, f'data_{idx}.pt')
        torch.save(graph, target_path)

    print("-" * 30)
    print("Processing complete.")
    print(f"Saved {num_images} graph files in '{PROCESSED_DATA_DIR}'")

Starting conversion of 60000 MNIST images to graphs...


100%|██████████| 60000/60000 [20:37<00:00, 48.50it/s]   

------------------------------
Processing complete.
Saved 60000 graph files in './processed_mnist_graphs'



