# Set Up

In [1]:
from torch_geometric.data import Data
from typing import List
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_geometric.nn as pyg_nn
import torch_geometric.utils as pyg_utils
import torch_geometric.transforms as T

import sklearn.metrics as metrics
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import glob
from PIL import Image

In [2]:
def _select_device():
    """Pick the compute device for this notebook.

    Preference order is MPS (Apple Silicon GPU), then CUDA (NVIDIA GPU),
    then CPU. When an accelerator is chosen, a one-element tensor is
    allocated on it as a quick smoke test; the tensor is discarded rather
    than left behind as a global.

    Returns:
        torch.device: the selected device.
    """
    # Older PyTorch builds do not ship `torch.backends.mps`; look it up
    # defensively so this cell does not fail with AttributeError there.
    mps_backend = getattr(torch.backends, "mps", None)

    if mps_backend is not None and mps_backend.is_available():
        chosen = torch.device("mps")
    elif torch.cuda.is_available():
        chosen = torch.device("cuda")
    else:
        print ("MPS and CUDA device not found.")
        return torch.device("cpu")

    # Smoke test: make sure a tensor can actually be placed on the device.
    torch.ones(1, device=chosen)
    return chosen


device = _select_device()

# Load Data

In [10]:
# Input folders, given relative to the notebook's working directory.
SEGM_DIR = "../data/segm/"      # segmentation maps, looked up as "<image name>_segm.png"
IMAGE_DIR = "../data/images/"   # source images

In [22]:

def get_corresponding_segm_path(image_path, segm_dir=None):
    """Return the path of the segmentation file that belongs to an image.

    The segmentation file name is the image's base name without its
    extension, followed by ``_segm.png``.

    Args:
        image_path: Path to the image file; only its base name is used.
        segm_dir: Directory holding the segmentation files. Defaults to the
            notebook-level ``SEGM_DIR`` when omitted.

    Returns:
        The joined path ``<segm_dir>/<name>_segm.png``. The file is not
        checked for existence here.
    """
    if segm_dir is None:
        # Resolved at call time so the default follows the notebook's config.
        segm_dir = SEGM_DIR
    # splitext strips only the final extension, so "a.b.jpg" -> "a.b".
    stem, _ = os.path.splitext(os.path.basename(image_path))
    return os.path.join(segm_dir, f'{stem}_segm.png')

def load_image(image_path):
    """Read an image file and return it as an RGB NumPy array.

    Args:
        image_path: Path to the image file.

    Returns:
        numpy.ndarray built from the image after conversion to RGB mode.
    """
    # The context manager closes the underlying file as soon as the pixel
    # data has been copied out, instead of leaving that to Pillow internals.
    with Image.open(image_path) as img:
        return np.array(img.convert('RGB'))

def load_segm(segm_path):
    """Read a segmentation file and return its contents as a NumPy array.

    No mode conversion is applied, so the array holds whatever values the
    file stores.

    Args:
        segm_path: Path to the segmentation image.

    Returns:
        numpy.ndarray built directly from the opened image.
    """
    # The context manager closes the underlying file as soon as the pixel
    # data has been copied out, instead of leaving that to Pillow internals.
    with Image.open(segm_path) as segm_img:
        return np.array(segm_img)

# Build the sample list: one (image, mask, label) tuple for every kept label
# found in an image's segmentation map. Images without a segmentation file
# are counted in `skipped` and left out.
labels_to_exclude = {0, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 22, 23}  # background and unwanted labels
skipped = 0
dataset = []

# glob returns paths in arbitrary, filesystem-dependent order; sorting makes
# dataset indices (e.g. dataset[0]) reproducible across runs and machines.
image_paths = sorted(glob.glob(os.path.join(IMAGE_DIR, '*')))

for image_path in image_paths:
    segm_path = get_corresponding_segm_path(image_path)
    if not os.path.exists(segm_path):
        skipped += 1
        continue

    image = load_image(image_path)
    segm = load_segm(segm_path)

    for label in np.unique(segm):
        if label in labels_to_exclude:  # exclude the background and unwanted labels
            continue
        # Cast the boolean comparison straight to uint8 (1 where the pixel
        # carries this label, 0 elsewhere) without building a wider integer
        # array first.
        mask = (segm == label).astype(np.uint8)
        # Note: every sample from the same image shares one `image` array
        # object; it is not copied per label.
        dataset.append((image, mask, label))

print(f'Total samples in dataset: {len(dataset)}')
print(f'Total skipped images: {skipped}')

Total samples in dataset: 25212
Total skipped images: 31395


In [31]:
# Spot-check two samples: print the image shape, mask shape and label of each.
for sample_idx in (0, 20):
    sample_image, sample_mask, sample_label = dataset[sample_idx]
    print(sample_image.shape, sample_mask.shape, sample_label)


(1101, 750, 3) (1101, 750) 1
(1101, 750, 3) (1101, 750) 5


# Model

Hyperparameters

Image Classification Model Class

# Train

# Test