In [None]:
import torch
from torch_geometric.data import Data
from tqdm import tqdm
import os
from PIL import Image

from utils import *
import torch
import torchvision.models as models
from PIL import Image
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt


# Build the PyTorch Geometric dataset using grid-based approach
def build_dataset(dataset_path, output_path, nb_per_class=200, apply_transform=True):
    """Convert a folder-per-class image dataset into PyG graphs and save them.

    Each immediate sub-directory of ``dataset_path`` is treated as one class and
    labelled by its position in ``os.listdir`` order. The first two graphs of
    every class are printed and plotted as a sanity check.

    Args:
        dataset_path: root directory containing one sub-directory per class.
        output_path: destination file for ``torch.save`` of the list of graphs.
        nb_per_class: maximum number of images kept per class after shuffling;
            ``0`` (or any non-positive value) keeps every image.
        apply_transform: forwarded to ``image_to_graph``.
    """
    image_extensions = ('.png', '.jpg', '.jpeg', '.tiff')
    dataset = []
    class_folders = [d for d in os.listdir(dataset_path) if os.path.isdir(os.path.join(dataset_path, d))]

    # One progress bar for the whole run, advanced once per class and closed on exit.
    with tqdm(total=len(class_folders)) as pbar:
        for label, class_folder in enumerate(class_folders):
            pbar.set_description(f"Constructing graph data for Class #{label}: {class_folder} ... ")
            class_path = os.path.join(dataset_path, class_folder)
            image_files = shuffle_dataset(
                [f for f in os.listdir(class_path) if f.lower().endswith(image_extensions)]
            )
            if nb_per_class > 0:
                image_files = image_files[:nb_per_class]

            for preview_idx, img_file in enumerate(image_files, start=1):
                img_path = os.path.join(class_path, img_file)
                data = image_to_graph(img_path, label, apply_transform)
                dataset.append(data)
                # Only the first two graphs of each class are printed/plotted.
                if preview_idx < 3:
                    print(f"Graph {preview_idx} for Class #{label} ({class_folder}): {data} \n")
                    plot_image_with_nodes(
                        img_path,
                        data,
                        f"{config['param']['result_folder']}/ImageAndGraph/{label}/{preview_idx}",
                    )

            pbar.set_description(f"Constructed {len(image_files)} graphs for Class #{label}: {class_folder} ")
            pbar.update(1)

    torch.save(dataset, output_path)


def image_to_graph(img_path, label, apply_transforms=True):
    """Load one image and turn it into a pixel-grid PyG ``Data`` graph.

    Args:
        img_path: path to the image file; it is converted to RGB on load.
        label: integer class label, stored as ``y`` (shape ``(1,)``, long).
        apply_transforms: if True use ``transform(type_data="train")``,
            otherwise ``transform(type_data="test")``.

    Returns:
        torch_geometric.data.Data with ``x`` (one row per pixel),
        ``edge_index`` of shape ``[2, num_edges]``, ``y`` and
        ``image_features`` (one RGB row per pixel, row-major order).
    """
    img = Image.open(img_path).convert('RGB')

    # NOTE(review): assumes transform(...) yields a (channels, height, width)
    # tensor (e.g. ends with ToTensor) — confirm in utils.
    transform_pipeline = transform(type_data="train" if apply_transforms else "test")
    img = transform_pipeline(img)
    print(f"Image shape: {img.shape}")

    x, edges = get_node_features_and_edge_list(img)
    # PyG expects COO connectivity of shape [2, num_edges]; the helper returns
    # one (src, dst) row per edge, so transpose. view(-1, 2) also normalises an
    # empty edge list to shape (0, 2).
    edge_index = edges.view(-1, 2).t().contiguous()
    y = torch.tensor([label], dtype=torch.long)

    # Move channels last before flattening so that each row is one pixel.
    image_features = img.permute(1, 2, 0).reshape(-1, img.shape[0])
    return Data(x=x, edge_index=edge_index, y=y, image_features=image_features)


def get_node_features_and_edge_list(image):
    """
    Convert an image tensor into the components of a pixel-grid graph.

    Parameters:
        image (torch.Tensor): image of shape (channels, height, width).
    Returns:
        node_features (torch.Tensor): tensor of shape (height * width, channels);
            row ``i * width + j`` holds the channel values of pixel (i, j).
        edges (torch.Tensor): long tensor from ``compute_edges(height, width)``,
            one (src, dst) pair per 8-connected neighbour relation.
    """
    image = validate_image(image)
    channels, height, width = image.shape

    # Put channels last so that each flattened row is one pixel. A plain
    # view(-1, channels) on a channel-first tensor would group consecutive
    # values of the same channel instead.
    node_features = image.permute(1, 2, 0).reshape(-1, channels)

    # Edge list indexed consistently with node_features (row-major pixels).
    edges = compute_edges(height, width)

    return node_features, edges


def validate_image(image):
    """Check that ``image`` is a 3-D tensor and return it unchanged.

    The expected layout is (channels, height, width).

    Raises:
        ValueError: if ``image`` does not have exactly three dimensions.
    """
    if image.dim() != 3:
        raise ValueError(
            "Input image must have 3 dimensions: (channels, height, width), "
            f"got shape {tuple(image.shape)}"
        )
    return image


def compute_edges(height, width):
    """Generate an edge list for a height x width pixel grid with 8-connectivity.

    Pixel (i, j) has node index ``i * width + j``. For every pixel in row-major
    order, its in-bounds neighbours are emitted in the order top, bottom,
    left, right, top-left, top-right, bottom-left, bottom-right.

    Returns:
        torch.Tensor: long tensor of shape (num_edges, 2) with rows
        (src, dst). Both directions of each neighbour pair are included. A grid
        without edges yields shape (0, 2).
    """
    # Neighbour offsets (d_row, d_col) in emission order.
    offsets = torch.tensor(
        [(-1, 0), (1, 0), (0, -1), (0, 1), (-1, -1), (-1, 1), (1, -1), (1, 1)],
        dtype=torch.long,
    )
    rows = torch.arange(height, dtype=torch.long).view(-1, 1, 1)  # (H, 1, 1)
    cols = torch.arange(width, dtype=torch.long).view(1, -1, 1)   # (1, W, 1)

    nbr_rows = rows + offsets[:, 0]  # (H, 1, 8)
    nbr_cols = cols + offsets[:, 1]  # (1, W, 8)
    # Broadcasts to (H, W, 8): True where the neighbour lies inside the grid.
    valid = (nbr_rows >= 0) & (nbr_rows < height) & (nbr_cols >= 0) & (nbr_cols < width)

    src = (rows * width + cols).expand(height, width, offsets.shape[0])
    dst = nbr_rows * width + nbr_cols  # (H, W, 8)

    # Boolean masking flattens in C order (row, col, offset), which reproduces
    # the per-pixel, per-neighbour emission order described above.
    return torch.stack((src[valid], dst[valid]), dim=1)



def pixel_to_index(x, y, width):
    """Map a pixel position to its row-major flat index.

    ``x`` is the row, ``y`` the column and ``width`` the number of columns
    per row.
    """
    row_start = x * width
    return row_start + y



In [None]:
from PIL import Image


# Function to predict a single image
def predict_single_image(model, image_path, transform, device, class_names):
    """
    Classify a single image file with a trained model.

    Args:
        model (torch.nn.Module): trained classifier; it is switched to eval mode.
        image_path (str): path of the image to classify.
        transform (callable): preprocessing applied to the RGB PIL image.
        device (torch.device): device the input tensor is moved to.
        class_names (list): class name for each output index.

    Returns:
        str: ``class_names[k]`` where ``k`` is the index of the largest model
        output along dimension 1.
    """
    # Decode as RGB, preprocess, add a batch dimension of size 1, move to device.
    pil_image = Image.open(image_path).convert("RGB")
    batch = transform(pil_image).unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        logits = model(batch)
        best_index = torch.max(logits, 1).indices

    return class_names[best_index.item()]


In [None]:
from concurrent.futures import ProcessPoolExecutor
from threading import Lock
import os
from tqdm import tqdm
import torch


# Module-level lock. A threading.Lock only serialises threads inside a single
# process; it is not shared with ProcessPoolExecutor worker processes.
lock = Lock()  # Lock for thread-safe I/O operations


def build_dataset(dataset_path, output_path, nb_per_class=200, apply_transform=True):
    """Build PyG graphs for every class folder concurrently and save them.

    Each immediate sub-directory of ``dataset_path`` is one class, labelled by
    its position in ``os.listdir`` order. Graphs are returned in the (shuffled)
    file order of each class, classes in folder order. The first two graphs of
    each class are printed and plotted.

    Args:
        dataset_path: root directory containing one sub-directory per class.
        output_path: destination file for ``torch.save`` of the list of graphs.
        nb_per_class: maximum number of images per class after shuffling;
            ``0`` (or any non-positive value) keeps every image.
        apply_transform: forwarded to ``image_to_graph``.
    """
    # Threads instead of processes. ProcessPoolExecutor must pickle the
    # callable, and lambdas and nested closures are not picklable. Threads can
    # call image_to_graph directly and share memory with this process.
    from concurrent.futures import ThreadPoolExecutor

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.tiff')  # File extension constant
    dataset = []
    class_folders = [d for d in os.listdir(dataset_path) if os.path.isdir(os.path.join(dataset_path, d))]

    def preview(img_path, graph_data, label, class_folder, graph_counter):
        """Print one graph summary and plot it (runs on the calling thread)."""
        print(f"Graph {graph_counter} for Class #{label} ({class_folder}): {graph_data} \n")
        plot_image_with_nodes(
            img_path,
            graph_data,
            f"{config['param']['result_folder']}/ImageAndGraph/{label}/{graph_counter}"
        )

    def process_class_data(class_folder, label):
        """Convert all selected images of one class; results keep input order."""
        class_path = os.path.join(dataset_path, class_folder)
        image_files = shuffle_dataset([
            f for f in os.listdir(class_path) if f.lower().endswith(IMAGE_EXTENSIONS)
        ])
        if nb_per_class > 0:
            image_files = image_files[:nb_per_class]
        img_paths = [os.path.join(class_path, f) for f in image_files]
        n_images = len(img_paths)

        # Executor.map yields results in submission order.
        with ThreadPoolExecutor() as executor:
            results = list(tqdm(
                executor.map(image_to_graph, img_paths, [label] * n_images, [apply_transform] * n_images),
                total=n_images,
                desc=f"Processing Class {label}"
            ))

        # The preview stays off the worker threads. NOTE(review):
        # plot_image_with_nodes presumably uses matplotlib, which is not
        # thread-safe.
        for graph_counter, (img_path, graph_data) in enumerate(zip(img_paths[:2], results[:2]), start=1):
            preview(img_path, graph_data, label, class_folder, graph_counter)
        return results

    # Iterate over classes and process them
    for label, class_folder in enumerate(tqdm(class_folders, desc="Processing classes")):
        dataset.extend(process_class_data(class_folder, label))

    # Save the dataset after processing
    torch.save(dataset, output_path)


