# Semantic Segmentation of Persian Garden Images
## Case Study: Fin Garden, Kashan

This notebook implements a semantic segmentation pipeline using a U-Net
architecture (PyTorch) for Persian garden images.

The dataset consists of RGB images and RGB-encoded segmentation masks
annotated in Roboflow.

Classes:
- Trees
- Buildings
- Sky
- Steps
- Cover Plants
- Water


In [3]:
import sys
sys.executable


'/Users/hessam/Documents/GitHub/Persian_Garden_Segmentation_Analysis/venv/bin/python'

In [None]:

# Environment & Import libraries


import os
import time
import warnings

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from PIL import Image
import cv2

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2

import segmentation_models_pytorch as smp

warnings.filterwarnings("ignore")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


warnings.filterwarnings("ignore")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  data = fetch_version_info()


Using device: cpu


In [5]:
# Paths & file pairing
IMAGE_DIR = "/Users/hessam/Documents/GitHub/Persian_Garden_Segmentation_Analysis/Data/image"
MASK_DIR = "/Users/hessam/Documents/GitHub/Persian_Garden_Segmentation_Analysis/Data/mask"

image_files = sorted(os.listdir(IMAGE_DIR))
mask_files = sorted(os.listdir(MASK_DIR))

print(f"Images: {len(image_files)}")
print(f"Masks : {len(mask_files)}")

assert len(image_files) == len(mask_files), "Imageâ€“mask count mismatch"


Images: 0
Masks : 0


In [None]:
# Visual sanity check#Checking image-mask alignment and qualitzy by visualizing a sample pair. Adjust the index to view different pairs.

idx = 0

img = Image.open(os.path.join(IMAGE_DIR, image_files[idx])).convert("RGB")
mask = Image.open(os.path.join(MASK_DIR, mask_files[idx])).convert("RGB")

plt.figure(figsize=(12, 4))

plt.subplot(1, 3, 1)
plt.imshow(img)
plt.title("Image")
plt.axis("off")

plt.subplot(1, 3, 2)
plt.imshow(mask)
plt.title("RGB Mask")
plt.axis("off")

plt.subplot(1, 3, 3)
plt.imshow(img)
plt.imshow(mask, alpha=0.5)
plt.title("Overlay")
plt.axis("off")

plt.show()


IndexError: list index out of range

In [6]:
# Unique RGB colors
import numpy as np
mask_np = np.array(mask)
unique_colors = np.unique(mask_np.reshape(-1, 3), axis=0)

print("Unique RGB colors in this mask:")
unique_colors


NameError: name 'mask' is not defined

In [None]:
# This reads all file names inside the folder.
# We filter for common image formats and sort them to ensure consistent pairing with masks. The total count is printed to confirm the dataset size.
image_files = sorted([
    f for f in os.listdir(IMAGE_DIR)
    if f.lower().endswith((".png", ".jpg", ".jpeg"))
])

print("Total Images:", len(image_files))


NameError: name 'os' is not defined

In [None]:
CLASSES = ["tree", "building", "sky", "water", "path"]
n_classes = len(CLASSES)
# Map RGB to class index

In [None]:
n_classes = len(CLASSES)
# Define RGB to class index mapping

In [None]:
from sklearn.model_selection import train_test_split

# First split test
trainval_files, test_files = train_test_split(
    image_files,
    test_size=0.1,
    random_state=42
)

# Then split validation
train_files, val_files = train_test_split(
    trainval_files,
    test_size=0.15,
    random_state=42
)

print("Train Size:", len(train_files))
print("Val Size  :", len(val_files))
print("Test Size :", len(test_files))


In [None]:
# Visualize a sample from the training set to confirm splits and data integrity. Adjust the index to view different samples.
idx = 7  # or any number within range

img_name = image_files[idx]

img = Image.open(os.path.join(IMAGE_DIR, img_name)).convert("RGB")
mask = Image.open(os.path.join(MASK_DIR, img_name)).convert("RGB")

print("Image Size:", np.array(img).shape)
print("Mask Size :", np.array(mask).shape)

plt.figure(figsize=(6,6))
plt.imshow(img)
plt.imshow(mask, alpha=0.6)
plt.title("Image with Mask Overlay")
plt.axis("off")
plt.show()


In [None]:
class SegDataset(Dataset):
    
    def __init__(self, img_path, mask_path, file_list, color_map, mean, std, transform=None):
        self.img_path = img_path
        self.mask_path = mask_path
        self.files = file_list
        self.transform = transform
        self.color_map = color_map
        self.mean = mean
        self.std = std
        
    def __len__(self):
        return len(self.files)
    
    def rgb_to_class(self, mask):
        h, w, _ = mask.shape
        class_mask = np.zeros((h, w), dtype=np.uint8)
        
        for rgb, class_id in self.color_map.items():
            matches = np.all(mask == rgb, axis=-1)
            class_mask[matches] = class_id
            
        return class_mask
    
    def __getitem__(self, idx):
        img_name = self.files[idx]
        
        img = cv2.imread(self.img_path + img_name)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        mask = cv2.imread(self.mask_path + img_name)
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
        
        if self.transform is not None:
            aug = self.transform(image=img, mask=mask)
            img = aug['image']
            mask = aug['mask']
        
        mask = self.rgb_to_class(mask)
        
        img = Image.fromarray(img)
        t = T.Compose([
            T.ToTensor(),
            T.Normalize(self.mean, self.std)
        ])
        img = t(img)
        
        mask = torch.from_numpy(mask).long()
        
        return img, mask