In [3]:
import torch.nn as nn
from skimage import segmentation

In [4]:
class ComputeSegments(nn.Module):
    """Transform that attaches a scikit-image segmentation to a sample dict.

    Parameters:
        segment_algorithm (str): Name of the ``skimage.segmentation`` routine
            to run: 'quickshift', 'slic' or 'felzenszwalb'. The name is only
            checked when ``forward`` is called, not at construction time.
        **kwargs: Extra keyword arguments forwarded unchanged to the chosen
            segmentation function.
    """

    # Names accepted by ``forward``; each one is also the attribute name of the
    # matching function on the ``skimage.segmentation`` module.
    _SUPPORTED_ALGORITHMS = ('quickshift', 'slic', 'felzenszwalb')

    def __init__(self, segment_algorithm='quickshift',**kwargs):
        super().__init__()
        self.segment_algorithm = segment_algorithm
        self.kwargs = kwargs

    def forward(self, sample):
        """Segment ``sample['image']`` and store the result in ``sample['segments']``.

        The input dict is modified in place and also returned.

        Raises:
            ValueError: If ``segment_algorithm`` is not one of the supported names.
        """
        image = sample['image']

        # Resolve the configured name against the supported ones.
        algorithm_name = next(
            (name for name in self._SUPPORTED_ALGORITHMS
             if self.segment_algorithm == name),
            None,
        )
        if algorithm_name is None:
            raise ValueError("Unsupported segment algorithm. Choose from 'quickshift', 'slic', or 'felzenszwalb'.")

        segment_fn = getattr(segmentation, algorithm_name)
        sample['segments'] = segment_fn(image, **self.kwargs)
        return sample

In [1]:
import torch
import numpy as np

def image_to_graph(image):
    """
    Convert an image into a dense 8-connected pixel graph.

    Every pixel becomes one node (row-major order, index = row * W + col) and
    is linked to each of its up-to-eight in-bounds neighbours.

    Parameters:
        image (torch.Tensor or np.ndarray): An input image of shape (H, W, C),
            where H is the height, W is the width, and C is the number of
            channels (e.g., 3 for RGB). NumPy arrays are converted with
            ``torch.from_numpy``. Non-contiguous inputs (for example a cropped
            slice of a larger image) are accepted.

    Returns:
        adjacency_matrix (torch.Tensor): float32 matrix of size (N, N) where
            N = H*W; entry [a, b] is 1 when pixels a and b are neighbours,
            0 otherwise. The matrix is symmetric with a zero diagonal.
        node_features (torch.Tensor): Node features of the graph of size (N, C).

    Raises:
        ValueError: If the input does not have exactly 3 dimensions.

    Note:
        The adjacency matrix is dense, so memory grows with (H*W)**2.
    """
    if isinstance(image, np.ndarray):
        # Convert numpy array to torch tensor if needed
        image = torch.from_numpy(image)

    if image.dim() != 3:
        raise ValueError("Input image must have 3 dimensions: (H, W, C)")

    H, W, C = image.shape
    N = H * W  # Total number of nodes/pixels

    # Flatten the image into a list of nodes. reshape() is used instead of
    # view() because view() raises on non-contiguous tensors; for contiguous
    # inputs reshape() still returns a view without copying.
    node_features = image.reshape(-1, C)  # Shape: (N, C)

    adjacency_matrix = torch.zeros((N, N), dtype=torch.float32)

    # index_grid[i, j] is the flattened node index of pixel (i, j).
    index_grid = torch.arange(N).view(H, W)

    # For each of the 8 neighbour directions, pair every pixel that has an
    # in-bounds neighbour in that direction with that neighbour by slicing
    # two shifted windows of the index grid.
    for d_row in (-1, 0, 1):
        for d_col in (-1, 0, 1):
            if d_row == 0 and d_col == 0:
                continue  # a pixel is not its own neighbour
            source = index_grid[max(0, -d_row):H - max(0, d_row),
                                max(0, -d_col):W - max(0, d_col)]
            target = index_grid[max(0, d_row):H - max(0, -d_row),
                                max(0, d_col):W - max(0, -d_col)]
            adjacency_matrix[source.reshape(-1), target.reshape(-1)] = 1

    return adjacency_matrix, node_features

# Example usage
if __name__ == "__main__":
    # Random 5x5 RGB image with uint8 intensities in [0, 255]
    example_image = torch.randint(0, 256, (5, 5, 3), dtype=torch.uint8)

    # Build the pixel graph for the demo image
    adjacency_matrix, node_features = image_to_graph(example_image)

    # Each heading is followed by its tensor on the next line
    print("Adjacency Matrix:", adjacency_matrix, sep="\n")
    print("\nNode Features:", node_features, sep="\n")

Adjacency Matrix:
tensor([[0., 1., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 1., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 1., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 1., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 1., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 1., 0., 1., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 1., 1., 0., 0., 1., 0., 1., 0., 0., 1., 1., 1., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 1., 1., 0

In [8]:
from skimage import io

# Sample training image used by the graph-conversion cells further down.
# NOTE(review): this absolute path only exists on the original author's machine.
SAMPLE_IMAGE_PATH = r"C:\Users\au783153\Documents\OBM\CODES\HeathlandSpeciesClassifier\dataset\images\train\calluna\im107_104_4.1.0.jpg"
image = io.imread(SAMPLE_IMAGE_PATH)

In [12]:
import torch
import numpy as np
from skimage import io


def validate_image(image):
    """Return the image as a 3-D torch tensor.

    NumPy arrays are converted with ``torch.from_numpy``; any other input is
    passed through unchanged.

    Raises:
        ValueError: If the (converted) image does not have exactly 3 dimensions.
    """
    tensor = torch.from_numpy(image) if isinstance(image, np.ndarray) else image
    if tensor.dim() == 3:
        return tensor
    raise ValueError("Input image must have 3 dimensions: (height, width, channels)")


def pixel_to_index(x, y, width):
    """Map 2-D pixel coordinates to a row-major flat index.

    ``x`` is multiplied by ``width`` (so it selects the row) and ``y`` is the
    offset within that row.
    """
    row_start = x * width
    return row_start + y


def compute_edges(height, width):
    """Build the directed edge list of an 8-connected pixel grid.

    Pixels are visited in row-major order. For each pixel, its in-bounds
    neighbours are emitted in the order top, bottom, left, right, top-left,
    top-right, bottom-left, bottom-right. Every adjacency therefore appears
    twice, once from each endpoint.

    Returns:
        list of tuple: ``(pixel_index, neighbor_index)`` pairs of flat indices.
    """
    # (row delta, column delta) for each of the 8 neighbours, in emission order
    offsets = (
        (-1, 0), (1, 0), (0, -1), (0, 1),
        (-1, -1), (-1, 1), (1, -1), (1, 1),
    )
    edge_list = []
    for row in range(height):
        for col in range(width):
            source = pixel_to_index(row, col, width)
            for d_row, d_col in offsets:
                n_row, n_col = row + d_row, col + d_col
                # Skip neighbours that fall outside the image
                if not (0 <= n_row < height and 0 <= n_col < width):
                    continue
                edge_list.append((source, pixel_to_index(n_row, n_col, width)))
    return edge_list


def image_to_graph(image):
    """
    Convert an image into a graph representation using an edge list.

    Parameters:
        image (torch.Tensor or np.ndarray): An input image of shape
            (height, width, channels), where 'channels' is typically 3 for RGB
            images. Validation and NumPy conversion are delegated to
            ``validate_image``. Non-contiguous inputs (for example a cropped
            slice of a larger image) are accepted.

    Returns:
        edges (list of tuple): List of edges, where each edge is a tuple
            (node1, node2) of flattened pixel indices, as produced by
            ``compute_edges``.
        node_features (torch.Tensor): Node features of the graph of size
            (num_pixels, channels).

    Raises:
        ValueError: Propagated from ``validate_image`` when the input is not 3-D.
    """
    image = validate_image(image)
    height, width, channels = image.shape

    # Flatten the image into a list of nodes. reshape() is used instead of
    # view() because view() raises on non-contiguous tensors; for contiguous
    # inputs reshape() still returns a view without copying.
    node_features = image.reshape(-1, channels)  # Shape: (num_pixels, channels)

    # Create an edge list for the graph
    edges = compute_edges(height, width)

    return edges, node_features


# Example usage
if __name__ == "__main__":
    # Reuse the image that was loaded with io.imread in an earlier cell
    example_image = image

    # Build the edge-list graph for that image
    edges, node_features = image_to_graph(example_image)

    # Report only the edge count, then the per-pixel features
    print("Edge List:", len(edges), sep="\n")
    print("\nNode Features:", node_features, sep="\n")


Edge List:
178204

Node Features:
tensor([[ 73,  71,  48],
        [ 72,  70,  47],
        [ 75,  73,  50],
        ...,
        [113, 105,  82],
        [103,  92,  70],
        [120, 108,  86]], dtype=torch.uint8)


In [11]:
import torch
import numpy as np
from skimage import io

# (row delta, column delta) for each of the 8 neighbours of a pixel.
# The order of this list determines the order in which neighbours are visited.
NEIGHBOR_OFFSETS = [
    # Edge-adjacent: top, bottom, left, right
    (-1, 0), (1, 0), (0, -1), (0, 1),
    # Diagonal: top-left, top-right, bottom-left, bottom-right
    (-1, -1), (-1, 1), (1, -1), (1, 1),
]


def validate_image(input_image):
    """Return the image as a 3-D torch tensor.

    A NumPy array is first converted with ``torch.from_numpy``; anything else
    is used as given.

    Raises:
        ValueError: If the (converted) image does not have exactly 3 dimensions.
    """
    came_from_numpy = isinstance(input_image, np.ndarray)
    candidate = torch.from_numpy(input_image) if came_from_numpy else input_image

    # Guard clause: reject anything that is not (height, width, channels)
    if candidate.dim() != 3:
        raise ValueError("Input image must have 3 dimensions: (height, width, channels)")
    return candidate


def pixel_to_index(x, y, width):
    """Map 2-D pixel coordinates to a row-major flat index.

    ``x`` is scaled by ``width`` and ``y`` is added as the offset within
    that row.
    """
    offset_of_row = x * width
    flat_index = offset_of_row + y
    return flat_index


def generate_neighbors(row, col, height, width):
    """List the in-bounds neighbour coordinates of pixel ``(row, col)``.

    Candidates are produced by adding each entry of ``NEIGHBOR_OFFSETS`` to the
    pixel position; those falling outside ``0 <= r < height`` or
    ``0 <= c < width`` are dropped. The surviving ``(row, col)`` tuples keep
    the order of ``NEIGHBOR_OFFSETS``.
    """
    candidates = (
        (row + d_row, col + d_col) for d_row, d_col in NEIGHBOR_OFFSETS
    )
    return [
        (r, c)
        for r, c in candidates
        if 0 <= r < height and 0 <= c < width
    ]


def compute_edges(height, width):
    """Build the directed edge list of an 8-connected pixel grid.

    Pixels are visited in row-major order. Each one contributes a
    ``(pixel_index, neighbor_index)`` pair for every neighbour returned by
    ``generate_neighbors``, with both ends expressed as flat indices from
    ``pixel_to_index``.
    """
    edge_list = []
    for row in range(height):
        for col in range(width):
            source = pixel_to_index(row, col, width)
            # Append one edge per valid neighbour of this pixel
            edge_list.extend(
                (source, pixel_to_index(n_row, n_col, width))
                for n_row, n_col in generate_neighbors(row, col, height, width)
            )
    return edge_list


def image_to_graph(image):
    """
    Convert an image into a graph representation using an edge list.

    Parameters:
        image (torch.Tensor or np.ndarray): An input image of shape
            (height, width, channels), where 'channels' is typically 3 for RGB
            images. Validation and NumPy conversion are delegated to
            ``validate_image``. Non-contiguous inputs (for example a cropped
            slice of a larger image) are accepted.

    Returns:
        edges (list of tuple): List of edges, where each edge is a tuple
            (node1, node2) of flattened pixel indices, as produced by
            ``compute_edges``.
        flattened_pixels (torch.Tensor): Node features of the graph of size
            (num_pixels, channels).

    Raises:
        ValueError: Propagated from ``validate_image`` when the input is not 3-D.
    """
    image = validate_image(image)
    height, width, channels = image.shape
    # Flatten the image into a list of nodes. reshape() is used instead of
    # view() because view() raises on non-contiguous tensors; for contiguous
    # inputs reshape() still returns a view without copying.
    flattened_pixels = image.reshape(-1, channels)  # Shape: (num_pixels, channels)
    # Create an edge list for the graph
    edges = compute_edges(height, width)
    return edges, flattened_pixels


# Example usage
if __name__ == "__main__":
    # Reuse the image that was loaded with io.imread in an earlier cell
    example_image = image
    try:
        # Build the edge-list graph, then report the edge count and features
        edges, flattened_pixels = image_to_graph(example_image)
        print("Edge List:", len(edges), sep="\n")
        print("\nNode Features:", flattened_pixels, sep="\n")
    except Exception as e:
        # Any failure during conversion or printing is reported, not re-raised
        print(f"Error: {e}")


Edge List:
178204

Node Features:
tensor([[ 73,  71,  48],
        [ 72,  70,  47],
        [ 75,  73,  50],
        ...,
        [113, 105,  82],
        [103,  92,  70],
        [120, 108,  86]], dtype=torch.uint8)


In [4]:
# NOTE(review): `e` is not defined in any cell visible here, so this relies on
# leftover kernel state. Presumably it is the loaded image array (the recorded
# output is a 3-tuple) — confirm, or replace with an explicitly defined name.
e.shape

(150, 150, 3)