In [None]:
import torch
from torch_geometric.data import Data
from tqdm import tqdm
import os
from PIL import Image

from utils import *
import torch
import torchvision.models as models
from PIL import Image
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt


# Build the PyTorch Geometric dataset using grid-based approach
def build_dataset(dataset_path, output_path, nb_per_class=200,apply_transform=True):
    """
    Convert every class folder under ``dataset_path`` into grid graphs and save them.

    Each sub-directory of ``dataset_path`` is treated as one class. Class folders are
    processed in sorted order so that the integer label assigned to a class does not
    depend on the (unspecified) order returned by ``os.listdir``.

    Args:
        dataset_path (str): Directory containing one sub-directory per class.
        output_path (str): File the list of graphs is written to with ``torch.save``.
        nb_per_class (int): Number of images kept per class after shuffling.
            ``0`` keeps every image.
        apply_transform (bool): Forwarded to ``image_to_graph``.

    Returns:
        None. The list of graphs is written to ``output_path``.
    """
    image_extensions = ('.png', '.jpg', '.jpeg', '.tiff')
    dataset = []
    class_folders = sorted(
        d for d in os.listdir(dataset_path) if os.path.isdir(os.path.join(dataset_path, d))
    )

    # One progress bar for the whole run: `total=` must be passed by keyword,
    # a bare integer would be interpreted as the iterable to wrap.
    with tqdm(total=len(class_folders)) as pbar:
        for label, class_folder in enumerate(class_folders):
            pbar.set_description(f"Contructing graph data for Class #{label}: {class_folder} ... ")
            class_path = os.path.join(dataset_path, class_folder)

            image_files = shuffle_dataset(
                [f for f in os.listdir(class_path) if f.lower().endswith(image_extensions)]
            )
            if nb_per_class != 0:
                image_files = image_files[:nb_per_class]

            for position, img_file in enumerate(image_files, start=1):
                img_path = os.path.join(class_path, img_file)
                data = image_to_graph(img_path, label,apply_transform)
                dataset.append(data)
                # Only the first two graphs of each class are printed and plotted.
                if position < 3:
                    print(f"Graph {position} for Class #{label} ({class_folder}): {data} \n")
                    plot_image_with_nodes(
                        img_path,
                        data,
                        f"{config['param']['result_folder']}/ImageAndGraph/{label}/{position}",
                    )

            pbar.set_description(f"Contructed {len(image_files)} graphs  for Class #{label}: {class_folder} ")
            pbar.update(1)

    torch.save(dataset, output_path)


def image_to_graph(img_path, label,apply_transforms=True):
    """
    Load an image and turn it into a grid graph with one node per pixel.

    Args:
        img_path (str): Path of the image file.
        label (int): Class index stored in the graph's ``y`` attribute.
        apply_transforms (bool): When True the "train" pipeline returned by
            ``transform`` is used, otherwise the "test" pipeline.

    Returns:
        torch_geometric.data.Data: Graph with node features ``x``, ``edge_index`` of
        shape (2, num_edges), ``y`` of shape (1,) and ``image_features`` holding one
        row of 3 channel values per pixel.
    """
    # Close the file handle as soon as the pixels have been decoded.
    with Image.open(img_path) as raw_image:
        img = raw_image.convert('RGB')

    transform_pipeline = transform(type_data="train" if apply_transforms else "test")
    img = transform_pipeline(img)

    x, edge_index = get_node_features_and_edge_list(img)
    # The helper yields one (source, target) pair per row; Data expects the
    # transposed layout with sources in row 0 and targets in row 1.
    edge_index = edge_index.reshape(-1, 2).t().contiguous()

    y = torch.tensor([label], dtype=torch.long)

    # The transformed image is unpacked as (channels, height, width) downstream;
    # move channels last before flattening so each row is one pixel.
    image_features = img.permute(1, 2, 0).reshape(-1, 3)
    return Data(x=x, edge_index=edge_index, y=y, image_features=image_features)


def get_node_features_and_edge_list(image):
    """
    Convert an image into a graph representation using an edge list.

    Parameters:
        image (torch.Tensor): An input image of shape (channels, height, width),
                              where 'channels' is typically 3 for RGB images.
    Returns:
        node_features (torch.Tensor): Node features of size (num_pixels, channels);
            row ``i * width + j`` holds the channel values of pixel (i, j), matching
            the node numbering used by ``compute_edges``.
        edges (torch.Tensor): Whatever ``compute_edges(height, width)`` returns
            (one (node1, node2) pair per row).
    Raises:
        ValueError: Propagated from ``validate_image`` when the tensor is not 3-D.
    """

    image = validate_image(image)
    channels, height, width = image.shape

    # Channels must be moved last before flattening: a plain view(-1, channels) on a
    # channel-first tensor would group neighbouring values of the same channel
    # instead of the channel values of one pixel.
    node_features = image.permute(1, 2, 0).reshape(-1, channels)  # Shape: (num_pixels, channels)

    # Create an edge list for the graph
    edges = compute_edges(height, width)

    return  node_features, edges


def validate_image(image):
    """
    Check that ``image`` is a 3-dimensional tensor and hand it back unchanged.

    Parameters:
        image (torch.Tensor): Candidate image tensor.
    Returns:
        torch.Tensor: The same ``image`` object.
    Raises:
        ValueError: If ``image.dim()`` is not 3.
    """
    if image.dim() != 3:
        # The caller unpacks the shape as (channels, height, width), so report that layout.
        raise ValueError("Input image must have 3 dimensions: (channels, height, width)")
    return image


def compute_edges(height, width):
    """
    Build the 8-connectivity edge list of a ``height`` x ``width`` pixel grid.

    Every pixel is linked to each of its in-bounds neighbours, so each undirected
    adjacency appears twice (once per direction). Nodes are numbered with
    ``pixel_to_index``.

    Returns:
        torch.Tensor: ``torch.long`` tensor built from the list of
        (source, target) pairs, one pair per row.
    """
    # Same emission order as the neighbour list: top, bottom, left, right,
    # then top-left, top-right, bottom-left, bottom-right.
    offsets = (
        (-1, 0), (1, 0), (0, -1), (0, 1),
        (-1, -1), (-1, 1), (1, -1), (1, 1),
    )
    pairs = []
    for row in range(height):
        for col in range(width):
            source = pixel_to_index(row, col, width)
            for d_row, d_col in offsets:
                n_row = row + d_row
                n_col = col + d_col
                # Skip neighbours that fall outside the image.
                if n_row < 0 or n_row >= height or n_col < 0 or n_col >= width:
                    continue
                pairs.append((source, pixel_to_index(n_row, n_col, width)))
    return torch.tensor(pairs, dtype=torch.long)



def pixel_to_index(x, y, width):
    """Return the row-major flat index of the pixel at row ``x``, column ``y``."""
    row_offset = x * width
    return row_offset + y



In [None]:
from PIL import Image


# Classify one image file with an already-trained model
def predict_single_image(model, image_path, transform, device, class_names):
    """
    Run a trained model on one image and return the name of the winning class.

    Args:
        model (torch.nn.Module): Trained model; it is switched to eval mode here.
        image_path (str): Path to the image to classify.
        transform (callable): Preprocessing applied to the RGB PIL image; its result
            must support ``unsqueeze``.
        device (torch.device): Device the input batch is moved to.
        class_names (list): Class names indexed by the model's output position.

    Returns:
        str: ``class_names`` entry at the index of the largest model output.
    """
    # Force three channels, preprocess, add the batch axis and move to the device.
    rgb_image = Image.open(image_path).convert("RGB")
    batch = transform(rgb_image).unsqueeze(0).to(device)

    # Inference only: no dropout/batch-norm updates and no gradient bookkeeping.
    model.eval()
    with torch.no_grad():
        scores = model(batch)
        best = torch.max(scores, 1)
        predicted = best[1]

    return class_names[predicted.item()]


In [None]:
from concurrent.futures import ProcessPoolExecutor
from threading import Lock
import os
from tqdm import tqdm
import torch


lock = Lock()  # Guards console/plot output; kept at module level for existing users


def build_dataset(dataset_path, output_path, nb_per_class=200, apply_transform=True):
    """
    Build grid graphs for every class folder, converting the images of a class concurrently.

    Images are converted on a thread pool. A process pool cannot be used with the
    closures defined here because callables submitted to it must be picklable, and a
    ``threading.Lock`` does not synchronise separate processes. Printing and plotting
    of the first two graphs of each class happen on the calling thread as results arrive.

    Args:
        dataset_path (str): Directory containing one sub-directory per class.
        output_path (str): File the list of graphs is written to with ``torch.save``.
        nb_per_class (int): Maximum number of images per class after shuffling;
            values <= 0 keep every image.
        apply_transform (bool): Forwarded to ``image_to_graph``.

    Returns:
        None. The list of graphs is written to ``output_path``.
    """
    # Local import: the surrounding cell only imports ProcessPoolExecutor.
    from concurrent.futures import ThreadPoolExecutor

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.tiff')  # File extension constant
    dataset = []
    # Sorted so the label assigned to each class does not depend on listing order.
    class_folders = sorted(
        d for d in os.listdir(dataset_path) if os.path.isdir(os.path.join(dataset_path, d))
    )

    def process_class_data(class_folder, label):
        """Return the graphs of one class, in the order of the shuffled file list."""
        class_path = os.path.join(dataset_path, class_folder)
        image_files = shuffle_dataset([
            f for f in os.listdir(class_path) if f.lower().endswith(IMAGE_EXTENSIONS)
        ])
        if nb_per_class > 0:
            image_files = image_files[:nb_per_class]
        image_paths = [os.path.join(class_path, img_file) for img_file in image_files]

        results = []
        with ThreadPoolExecutor() as executor:
            # map() yields results in submission order, so the running counter below
            # identifies the first and second image without any list lookup.
            graphs = executor.map(
                lambda path: image_to_graph(path, label, apply_transform),
                image_paths,
            )
            progress = tqdm(graphs, total=len(image_paths), desc=f"Processing Class {label}")
            for graph_counter, (img_path, graph_data) in enumerate(zip(image_paths, progress), start=1):
                if graph_counter <= 2:
                    with lock:
                        print(f"Graph {graph_counter} for Class #{label} ({class_folder}): {graph_data} \n")
                        plot_image_with_nodes(
                            img_path,
                            graph_data,
                            f"{config['param']['result_folder']}/ImageAndGraph/{label}/{graph_counter}"
                        )
                results.append(graph_data)
        return results

    # Iterate over classes and process them
    for label, class_folder in enumerate(tqdm(class_folders, desc="Processing classes")):
        dataset.extend(process_class_data(class_folder, label))

    # Save the dataset after processing
    torch.save(dataset, output_path)




In [1]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data, DataLoader

# Two-layer GCN whose classifier sees the activations of both layers
class YourGNNModel(torch.nn.Module):
    """
    Node classifier made of two ``GCNConv`` layers and a linear head.

    The ReLU outputs of both graph convolutions are concatenated along the feature
    axis and passed to the linear layer; ``forward`` returns log-probabilities.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        # The head receives both hidden representations side by side.
        self.fc = torch.nn.Linear(2 * hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Return per-node log-softmax scores for features ``x`` and graph ``edge_index``."""
        first_hidden = F.relu(self.conv1(x, edge_index))
        second_hidden = F.relu(self.conv2(first_hidden, edge_index))

        # Skip-style concatenation of the two layers along the feature dimension.
        stacked = torch.cat([first_hidden, second_hidden], dim=1)

        logits = self.fc(stacked)
        return F.log_softmax(logits, dim=1)


# Tiny hand-written graph used to smoke-test the model
def build_dummy_data():
    """
    Build a 4-node example graph for exercising the GNN model.

    Nodes carry 3 float features each and a label in {0, 1}. The directed edges are
    0->1, 1->2, 2->3, 3->0, 0->2 and 2->0.

    Returns: Torch Geometric Data object
    """
    feature_rows = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 1]]
    directed_edges = [(0, 1), (1, 2), (2, 3), (3, 0), (0, 2), (2, 0)]
    node_labels = [0, 1, 0, 1]

    # Split the pairs into the two-row (sources, targets) layout.
    sources = [source for source, _ in directed_edges]
    targets = [target for _, target in directed_edges]

    return Data(
        x=torch.tensor(feature_rows, dtype=torch.float),
        edge_index=torch.tensor([sources, targets], dtype=torch.long),
        y=torch.tensor(node_labels, dtype=torch.long),
    )


# Fit the GNN on the dummy graph and report the loss
def train_gnn():
    """
    Train ``YourGNNModel`` on the graph from ``build_dummy_data`` for 50 epochs.

    Uses Adam with a learning rate of 0.01 and the negative log-likelihood loss on
    the model's log-softmax output. The loss is printed when the zero-based epoch is
    a multiple of 10 and on the final epoch. Nothing is returned.
    """
    # Hyperparameters
    in_channels, hidden_channels, out_channels = 3, 4, 2
    epochs = 50
    learning_rate = 0.01

    model = YourGNNModel(in_channels, hidden_channels, out_channels)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    data = build_dummy_data()

    model.train()
    final_epoch = epochs - 1
    for epoch in range(epochs):
        optimizer.zero_grad()

        # Forward pass and loss in one go, then the usual backward/step.
        loss = F.nll_loss(model(data.x, data.edge_index), data.y)
        loss.backward()
        optimizer.step()

        should_report = epoch == final_epoch or epoch % 10 == 0
        if should_report:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item()}")

    print("Training complete.")


train_gnn()

Epoch 1/50, Loss: 0.7582110166549683
Epoch 11/50, Loss: 0.6493818759918213
Epoch 21/50, Loss: 0.622076153755188
Epoch 31/50, Loss: 0.5875752568244934
Epoch 41/50, Loss: 0.5498992204666138
Epoch 50/50, Loss: 0.5072420239448547
Training complete.


In [46]:
from torchvision import transforms

def transform(type_data="train"):
    """
    Build the torchvision preprocessing pipeline for the requested split.

    Args:
        type_data (str): ``"train"`` prepends random augmentations (horizontal flip,
            rotation up to 45 degrees, random resized crop) to the common steps.
            Any other value (e.g. ``"test"``) yields only the common steps.

    Returns:
        torchvision.transforms.Compose: Pipeline ending with resize to 224x224,
        ``ToTensor`` and normalisation with the mean/std listed below.
    """
    # Common preprocessing steps
    preprocessing = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]

    if type_data != "train":
        # Evaluation data gets the deterministic steps only. Previously this path
        # fell off the end of the function and handed callers None.
        return transforms.Compose(preprocessing)

    # Additional augmentations for training
    augmentations = [
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(45),
        transforms.RandomResizedCrop(
            size=(224, 224),
            scale=(0.8, 1.0),
            ratio=(0.9, 1.1),
            interpolation=transforms.InterpolationMode.BILINEAR
        )
    ]
    return transforms.Compose(augmentations + preprocessing)


In [47]:
import cv2
import numpy as np
import torch
from torch_geometric.data import Data
import torchvision.models as models
from PIL import Image
import os



# Row/column offsets of the neighbours inspected for each supported connectivity.
_NEIGHBOR_OFFSETS = {
    '4-connectivity': ((-1, 0), (1, 0), (0, -1), (0, 1)),
    '8-connectivity': (
        (-1, 0), (1, 0), (0, -1), (0, 1),
        (-1, -1), (-1, 1), (1, -1), (1, 1),
    ),
}

# Filled on first use so the pretrained network is built once, not once per image.
_DENSENET_FEATURE_EXTRACTOR = None


def _get_densenet_feature_extractor():
    """Return the shared DenseNet-121 feature layers in eval mode, building them on first use."""
    global _DENSENET_FEATURE_EXTRACTOR
    if _DENSENET_FEATURE_EXTRACTOR is None:
        foundation_model = models.densenet121(weights='DenseNet121_Weights.IMAGENET1K_V1')
        extractor = torch.nn.Sequential(*list(foundation_model.features.children()))
        extractor.eval()
        _DENSENET_FEATURE_EXTRACTOR = extractor
    return _DENSENET_FEATURE_EXTRACTOR


def _detect_corner_positions(image_path, k, threshold):
    """Return the (row, col) positions of strong Harris corners in the 128x128 grayscale image."""
    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        raise ValueError(f"Unable to load image from {image_path}")
    image = cv2.resize(image, (128, 128))

    harris_response = cv2.cornerHarris(image, blockSize=3, ksize=3, k=k)

    # A pixel is a corner when its response exceeds the given fraction of the maximum.
    corners = np.zeros_like(harris_response, dtype=np.uint8)
    corners[harris_response > threshold * harris_response.max()] = 1
    return np.argwhere(corners == 1)


def _connect_adjacent_corners(corner_positions, edge_type):
    """Return (source, target) node-index pairs for corners that are grid neighbours."""
    if edge_type not in _NEIGHBOR_OFFSETS:
        raise ValueError(f"Invalid edge_type {edge_type}")
    offsets = _NEIGHBOR_OFFSETS[edge_type]

    corner_indices = {
        (int(row), int(col)): idx for idx, (row, col) in enumerate(corner_positions)
    }
    edges = []
    for (row, col), source in corner_indices.items():
        for d_row, d_col in offsets:
            # Positions outside the image are never keys, so no bounds check is needed.
            target = corner_indices.get((row + d_row, col + d_col))
            if target is not None:
                edges.append((source, target))
    return edges


def build_graph_from_image_harris(image_path, label,label_name,apply_transforms=True, output_path="dataset/test_graph_data.pt",k=0.04, threshold=0.005, edge_type='4-connectivity'):
    """
    Build a PyTorch graph from an image based on Harris corner detection.

    Nodes are the strong Harris corners of the image resized to 128x128 (grayscale);
    two nodes are connected when their pixels are neighbours under ``edge_type``.
    DenseNet-121 features of the RGB image are attached as ``image_features``.
    The resulting graph is also written to ``output_path``.

    Args:
        image_path (str): Path to the image.
        label: Class label stored unchanged in the graph's ``y`` attribute.
        label_name (str): Class name stored in the graph's ``label_name`` attribute.
        apply_transforms (bool): Use the "train" pipeline from ``transform`` when
            True, the "test" pipeline otherwise.
        output_path (str): File the graph is saved to; its parent directory is
            created when missing.
        k (float): Harris detector free parameter for detecting corners.
        threshold (float): Fraction of the maximum Harris response a pixel must
            exceed to count as a corner.
        edge_type (str): Type of graph edges. Options: '4-connectivity', '8-connectivity'.

    Returns:
        torch_geometric.data.Data: Graph whose node features ``x`` are the
        (row, col) positions of the corners.

    Raises:
        ValueError: If the image cannot be read or ``edge_type`` is not supported.
    """
    # Steps 1-4: corner detection on the grayscale image.
    corner_positions = _detect_corner_positions(image_path, k, threshold)

    # Step 5: edges between neighbouring corners.
    edges = _connect_adjacent_corners(corner_positions, edge_type)
    edge_index = torch.tensor(edges, dtype=torch.long).t() if edges else torch.empty((2, 0), dtype=torch.long)

    # Node features are the corner positions themselves.
    x = torch.tensor(corner_positions, dtype=torch.float)

    # Global image descriptor from the pretrained DenseNet-121 feature layers.
    feature_extractor = _get_densenet_feature_extractor()
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')

    transform_pipeline = transform(type_data="train" if apply_transforms else "test")
    img = transform_pipeline(image).unsqueeze(dim=0)

    with torch.no_grad():
        features = feature_extractor(img)
    img_features = torch.flatten(features, start_dim=1)

    data = Data(x=x, edge_index=edge_index, y=label, image_features=img_features,label_name=label_name)

    # The default output path lives in a sub-directory that may not exist yet.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    torch.save(data, output_path)

    return data



In [49]:
def plot_image_with_nodes(img_path, data, output_folder="GraphImage"):
    """
    Save a side-by-side figure of an image and the graph built from it.

    The left panel shows the image, the right panel draws the graph with each node
    placed at the first two values of its feature row. The figure is written to
    ``<output_folder>/graph.png``, overwriting any previous file of that name.

    Args:
        img_path (str): Path of the image to display.
        data: Graph object exposing ``x`` (node features with at least two columns)
            and ``edge_index`` of shape (2, num_edges).
        output_folder (str): Directory the figure is written to; created if missing.

    Returns:
        None.
    """
    # exist_ok avoids the check-then-create race of a separate existence test.
    os.makedirs(output_folder, exist_ok=True)

    # NOTE(review): `io` is not imported in the visible cells; presumably it is an
    # image-reading module supplied by the notebook namespace. Confirm.
    img = io.imread(img_path)

    G = nx.Graph()
    for i, feat in enumerate(data.x.cpu().numpy()):
        G.add_node(i, pos=(feat[0], feat[1]))
    G.add_edges_from(data.edge_index.t().cpu().tolist())

    fig, (image_ax, graph_ax) = plt.subplots(1, 2, figsize=(10, 5))
    try:
        image_ax.imshow(img)
        image_ax.axis('off')

        pos = nx.get_node_attributes(G, 'pos')
        nx.draw(G, pos, ax=graph_ax, with_labels=False, node_size=10, node_color='r')
        graph_ax.axis('off')

        output_path = os.path.join(output_folder, "graph.png")
        fig.savefig(output_path, bbox_inches='tight')
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)


In [50]:
# Example usage: build a Harris-corner graph for one image, print it and plot it.
# NOTE: the image path below is a hard-coded absolute Windows path; point it at an
# image that exists on your machine before running this cell.
im=r"C:\Users\au783153\Documents\OBM\CODES\HeathlandSpeciesClassifier\dataset\images\test\amm\im63_64_4.1.0.jpg"
# output_path is not passed, so the graph is also saved to the function's default
# location ("dataset/test_graph_data.pt").
graph_data = build_graph_from_image_harris(im,label=2, label_name="tree", k=0.04, threshold=0.1)
print(graph_data)
# output_folder is not passed, so the figure goes to "GraphImage/graph.png".
plot_image_with_nodes(im, graph_data)

Data(x=[536, 2], edge_index=[2, 1038], y=2, image_features=[1, 50176], label_name='tree')


In [None]:
import torch
# Load a previously saved object from disk and display it as the cell output.
# NOTE: hard-coded absolute Windows path; adjust it to your environment.
# NOTE(review): presumably a single grid graph saved by the dataset pipeline
# (the path contains graphs\grid\train); confirm what the file holds.
f=torch.load(r"C:\Users\au783153\Documents\OBM\CODES\HeathlandSpeciesClassifier\dataset\graphs\grid\train\0_68.pt")
f