NOTE: THE PIPELINE IS VERY LENGTHY (>20 HOURS OF RUNTIME). TO FULLY REPRODUCE IT, PLEASE ADAPT THE FOLDER PATHS TO YOUR LOCAL STRUCTURE. THIS NOTEBOOK SHOWS THE MAIN STEPS TAKEN.

# Data Standardization

## Dataset 1 & 3

In [None]:
import os
from PIL import Image
from typing import Tuple

def crop_center(image: Image.Image, crop_size: Tuple[int, int] = (640, 640)) -> Image.Image:
    """Trim fixed side margins, then cut a centred window of ``crop_size``.

    The first crop drops 80 px from the left edge and 320 px (4 * 80) from the
    right edge while keeping the full height. A ``crop_size`` window is then
    taken from the middle of what remains. If the trimmed image is smaller than
    ``crop_size``, the window box extends past the trimmed image's border.

    Args:
        image: Source PIL image.
        crop_size: ``(width, height)`` of the returned window.

    Returns:
        The centred ``crop_size`` crop of the margin-trimmed image.
    """
    margin = 80
    full_w, full_h = image.size
    trimmed = image.crop((margin, 0, full_w - 4 * margin, full_h))

    target_w, target_h = crop_size
    avail_w, avail_h = trimmed.size
    x0 = (avail_w - target_w) // 2
    y0 = (avail_h - target_h) // 2
    return trimmed.crop((x0, y0, x0 + target_w, y0 + target_h))

def process_images_in_folder(
    root_folder: str,
    crop_size: Tuple[int, int] = (640, 640),
    *,
    output_prefix: str = "cropped_",
):
    """Center-crop every supported image under ``root_folder`` in place.

    Each image is passed through ``crop_center`` and the result is saved next
    to the source file as ``<output_prefix><original name>``. Files whose name
    already starts with ``output_prefix`` are skipped, so re-running this cell
    does not crop earlier outputs a second time (which would otherwise create
    ``cropped_cropped_*`` files).

    Failures on individual files are printed and the walk continues.

    Args:
        root_folder: Directory walked recursively with ``os.walk``.
        crop_size: ``(width, height)`` forwarded to ``crop_center``.
        output_prefix: Filename prefix for outputs; also used to recognise
            (and skip) outputs of a previous run. An empty prefix disables
            skipping.
    """
    supported_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp')

    for subdir, _, files in os.walk(root_folder):
        for file in files:
            if not file.lower().endswith(supported_extensions):
                continue
            if output_prefix and file.startswith(output_prefix):
                # Output of a previous run - never crop it twice.
                continue

            file_path = os.path.join(subdir, file)
            save_path = os.path.join(subdir, f"{output_prefix}{file}")
            try:
                with Image.open(file_path) as img:
                    crop_center(img, crop_size).save(save_path)
            except Exception as e:
                # Best-effort batch job: report the failing file and keep going.
                print(f"Failed to process {file_path}: {e}")
            else:
                print(f"Saved: {save_path}")

# Example usage - point DATASET_1_3_ROOT at the folder holding the dataset 1 & 3
# images (adapt to your local structure). os.walk silently yields nothing for a
# missing folder, so fail loudly instead.
DATASET_1_3_ROOT = "path/to/dataset_1_and_3"
if not os.path.isdir(DATASET_1_3_ROOT):
    raise FileNotFoundError(f"Set DATASET_1_3_ROOT to an existing folder (got {DATASET_1_3_ROOT!r})")
process_images_in_folder(DATASET_1_3_ROOT)


## Dataset 2

In [None]:
import os
from PIL import Image
from typing import Tuple

def crop_center(image: Image.Image, crop_size: Tuple[int, int] = (640, 640)) -> Image.Image:
    """Trim fixed side margins and resize the remainder to ``crop_size``.

    Despite the name, no centre window is cut here. The function removes 80 px
    from the left edge and 320 px (4 * 80) from the right edge, keeping the
    full height. It then resizes the result to exactly ``crop_size`` with
    Lanczos resampling, so the aspect ratio is not preserved.

    Args:
        image: Source PIL image (in this notebook, one of the three vertical
            chunks produced by ``split_by_three``).
        crop_size: ``(width, height)`` of the output image.

    Returns:
        The trimmed and resized image.
    """
    left_margin = 80
    right_margin = 4 * left_margin
    w, h = image.size
    trimmed = image.crop((left_margin, 0, w - right_margin, h))
    resized = trimmed.resize(crop_size, Image.Resampling.LANCZOS)
    return resized

def split_by_three(image: Image.Image) -> Tuple[Image.Image, Image.Image, Image.Image]:
    """Cut ``image`` into three side-by-side vertical strips.

    The strip width is ``width // 3``. Any leftover columns (``width % 3``) go
    to the last strip, so together the three strips cover the whole image.

    Returns:
        ``(left, middle, right)`` crops, each spanning the full height.
    """
    w, h = image.size
    step = w // 3
    bounds = (0, step, 2 * step, w)
    left, middle, right = (
        image.crop((bounds[k], 0, bounds[k + 1], h)) for k in range(3)
    )
    return left, middle, right

def process_images_in_folder(
    root_folder: str,
    crop_size: Tuple[int, int] = (640, 640),
    *,
    output_prefix: str = "cropped_",
):
    """Split every supported image under ``root_folder`` in three and crop each part.

    Each image is cut into three vertical chunks with ``split_by_three``. Every
    chunk is passed through ``crop_center`` and saved next to the source file
    as ``<output_prefix><chunk index>_<original name>``. Files whose name
    already starts with ``output_prefix`` are skipped, so re-running this cell
    does not split and crop earlier outputs again.

    Failures on individual files are printed and the walk continues.

    Args:
        root_folder: Directory walked recursively with ``os.walk``.
        crop_size: ``(width, height)`` forwarded to ``crop_center``.
        output_prefix: Filename prefix for outputs; also used to recognise
            (and skip) outputs of a previous run. An empty prefix disables
            skipping.
    """
    supported_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp')

    for subdir, _, files in os.walk(root_folder):
        for file in files:
            if not file.lower().endswith(supported_extensions):
                continue
            if output_prefix and file.startswith(output_prefix):
                # Output of a previous run - never process it twice.
                continue

            file_path = os.path.join(subdir, file)
            try:
                with Image.open(file_path) as img:
                    for i, chunk in enumerate(split_by_three(img)):
                        cropped_img = crop_center(chunk, crop_size)
                        save_path = os.path.join(subdir, f"{output_prefix}{i}_{file}")
                        cropped_img.save(save_path)
                        print(f"Saved: {save_path}")
            except Exception as e:
                # Best-effort batch job: report the failing file and keep going.
                print(f"Failed to process {file_path}: {e}")

# Example usage - point DATASET_2_ROOT at the folder holding the dataset 2
# images (adapt to your local structure). os.walk silently yields nothing for a
# missing folder, so fail loudly instead.
DATASET_2_ROOT = "path/to/dataset_2"
if not os.path.isdir(DATASET_2_ROOT):
    raise FileNotFoundError(f"Set DATASET_2_ROOT to an existing folder (got {DATASET_2_ROOT!r})")
process_images_in_folder(DATASET_2_ROOT)


# Segmentation

## Segment using SegFormer

In [1]:
import os
from PIL import Image
import torch
import numpy as np
from tqdm import tqdm
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation

  from .autonotebook import tqdm as notebook_tqdm
2025-05-23 20:03:29.202551: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1748023409.221709  772778 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1748023409.227485  772778 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1748023409.243207  772778 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1748023409.243224  772778 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1748023409.243226  772778

In [None]:
# === CONFIG ===
IMAGE_DIR = "input_images"        # root folder scanned (recursively) for input images
OUTPUT_DIR = "segmented_output"   # per-class masked images are written below this folder
MODEL_NAME = "nvidia/segformer-b0-finetuned-cityscapes-768-768"

# Cityscapes class index -> class name (19 classes).
# NOTE(review): presumably matches model.config.id2label - verify.
CITYSCAPES_ID2LABEL = {
    0: 'road', 1: 'sidewalk', 2: 'building', 3: 'wall', 4: 'fence',
    5: 'pole', 6: 'traffic_light', 7: 'traffic_sign', 8: 'vegetation', 9: 'terrain',
    10: 'sky', 11: 'person', 12: 'rider', 13: 'car', 14: 'truck',
    15: 'bus', 16: 'train', 17: 'motorcycle', 18: 'bicycle',
}

# === SETUP ===
# The output root first, then one sub-folder per class (e.g. segmented_output/road).
folders_to_create = [OUTPUT_DIR]
folders_to_create.extend(os.path.join(OUTPUT_DIR, label) for label in CITYSCAPES_ID2LABEL.values())
for folder in folders_to_create:
    os.makedirs(folder, exist_ok=True)

In [None]:
# Load the SegFormer preprocessing pipeline and the Cityscapes-finetuned model.
feature_extractor = SegformerFeatureExtractor.from_pretrained(MODEL_NAME)
model = SegformerForSemanticSegmentation.from_pretrained(MODEL_NAME)
model.eval()  # inference mode (affects dropout / batch-norm layers)

In [None]:
def segment_and_save(image_path):
    """Segment one image with SegFormer and save one masked copy per class.

    For every class in ``CITYSCAPES_ID2LABEL`` that is predicted for at least
    one pixel, the function writes a copy of the RGB image in which every pixel
    *not* assigned to that class is set to 0 (black). The copy goes to
    ``OUTPUT_DIR/<class>/<stem>_<class>.png``. Classes with no predicted pixel
    produce no file.

    The preprocessed inputs are moved to ``model.device``, so this works whether
    the model is on the CPU or has been moved to a GPU.

    Args:
        image_path: Path of the image to segment.

    Note:
        Output names use only the file stem, so two images with the same
        name in different sub-folders overwrite each other's outputs.
    """
    # Decode inside a context manager so the file handle is released promptly;
    # convert() returns an independent image that outlives the `with`.
    with Image.open(image_path) as src:
        image = src.convert("RGB")
    image_np = np.array(image)

    # BatchFeature.to() moves every tensor to the model's device.
    inputs = feature_extractor(images=image, return_tensors="pt").to(model.device)

    with torch.no_grad():
        logits = model(**inputs).logits  # (1, num_classes, h/4, w/4) of the preprocessed input
        # Upsample to the original resolution: PIL size is (W, H), torch wants (H, W).
        upsampled_logits = torch.nn.functional.interpolate(
            logits,
            size=image.size[::-1],
            mode="bilinear",
            align_corners=False,
        )
        predicted = upsampled_logits.argmax(dim=1)[0].cpu().numpy()  # (H, W) class ids

    base_name = os.path.splitext(os.path.basename(image_path))[0]

    for class_idx, class_name in CITYSCAPES_ID2LABEL.items():
        keep = predicted == class_idx  # boolean (H, W) mask of this class's pixels
        if not keep.any():
            continue  # class absent in this image - write nothing

        masked_img = image_np.copy()
        masked_img[~keep] = 0  # zero out everything except the target class

        save_path = os.path.join(OUTPUT_DIR, class_name, f"{base_name}_{class_name}.png")
        Image.fromarray(masked_img).save(save_path)


In [None]:
def process_directory(directory):
    """Run ``segment_and_save`` on every PNG/JPEG file found under ``directory``.

    Sub-folders are visited recursively. Each folder gets its own progress bar,
    labelled with the folder's base name.
    """
    image_exts = ('.png', '.jpg', '.jpeg')
    for folder, _subdirs, names in os.walk(directory):
        selected = [name for name in names if name.lower().endswith(image_exts)]
        progress = tqdm(selected, desc=f"Processing {os.path.basename(folder)}")
        for name in progress:
            segment_and_save(os.path.join(folder, name))

# Segment every image below IMAGE_DIR (recursively).
process_directory(IMAGE_DIR)
print("Done. Masks saved in:", OUTPUT_DIR)


## Remove mostly black pixels

In [None]:
import cv2
import numpy as np
from pathlib import Path
import os

def check_black_pixels(image_path, threshold=0.95):
    """Return True when the fraction of near-black pixels exceeds ``threshold``.

    The image is converted to grayscale, and a pixel counts as black when its
    value is below 10. Files that OpenCV cannot read are reported and the
    function returns False for them.
    """
    img = cv2.imread(str(image_path))
    if img is None:
        print(f"Could not read image: {image_path}")
        return False

    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    dark_pixels = (gray < 10).sum()
    dark_fraction = dark_pixels / gray.size
    return dark_fraction > threshold

segments = CITYSCAPES_ID2LABEL.values()
base_path = ""  # adapt: root folder containing one sub-folder per segment class
# Images whose fraction of near-black pixels exceeds this are deleted. The same
# value drives both the check and the log messages.
BLACK_THRESHOLD = 0.95

# Walk <base_path>/<segment>/final_datasets/<split>/<country>/*.png and delete
# images that are almost entirely black.
for segment in segments:
    print(f"\nProcessing {segment} segment...")
    dataset_path = Path(base_path) / segment / "final_datasets"

    for split in ["train", "test", "val"]:
        split_path = dataset_path / split
        if not split_path.exists():
            continue

        # sorted() makes the processing / log order reproducible across file systems.
        countries = sorted(d for d in os.listdir(split_path) if os.path.isdir(split_path / d))

        for country in countries:
            country_path = split_path / country
            print(f"\nChecking {split}/{country}...")

            # Collect first, delete afterwards: never mutate the folder while globbing it.
            black_images = [
                img_path
                for img_path in sorted(country_path.glob("*.png"))
                if check_black_pixels(img_path, threshold=BLACK_THRESHOLD)
            ]

            if black_images:
                print(f"Found {len(black_images)} images with >{BLACK_THRESHOLD:.0%} black pixels in {country}")
                for img_path in black_images:
                    os.remove(img_path)
                    print(f"  - Deleted: {img_path.name}")
            else:
                print(f"No images with >{BLACK_THRESHOLD:.0%} black pixels found in {country}")

print("\n✅ Black pixel cleanup complete!")

# Rank by segments

In [None]:
# Score every original image by the number of segment folders that contain a
# masked copy of it (higher = more segment classes survived the cleanup).
image_scores = {}

# Register every original image (<base_path>/<country>/*.png) with score 0.
main_dataset_path = Path(base_path)
for country in os.listdir(main_dataset_path):
    country_path = main_dataset_path / country
    if not country_path.is_dir():
        continue
    for img_path in country_path.glob("*.png"):
        image_scores[img_path] = 0


def _original_name(seg_stem, segment):
    """Map a segment-image stem ``<name>_<segment>`` back to ``<name>.png``.

    Only a *trailing* ``_<segment>`` is removed. Names that contain the segment
    text elsewhere therefore map correctly: ``img_cartoon_car`` with segment
    ``car`` gives ``img_cartoon.png``, not ``img.png``. A stem without the
    suffix is kept unchanged.
    """
    suffix = f"_{segment}"
    stem = seg_stem[: -len(suffix)] if seg_stem.endswith(suffix) else seg_stem
    return stem + ".png"


# +1 for every masked copy found under <base_path>/<segment>/<country>/.
for segment in segments:
    segment_path = Path(base_path) / segment
    if not segment_path.exists():
        continue

    for country in os.listdir(segment_path):
        country_segment_path = segment_path / country
        if not country_segment_path.is_dir():
            continue

        for seg_img_path in country_segment_path.glob("*.png"):
            original_path = main_dataset_path / country / _original_name(seg_img_path.stem, segment)
            # Only images registered above are scored.
            if original_path in image_scores:
                image_scores[original_path] += 1


# Sort image scores in descending order (highest score first; ties keep insertion order).
sorted_scores = sorted(image_scores.items(), key=lambda x: x[1], reverse=True)

# Create final datasets

In [None]:
# Build the "final_dataset" folder from the highest-ranked images.
import shutil


final_dataset_path = Path(base_path) / "final_dataset"
final_dataset_path.mkdir(exist_ok=True)

# sorted_scores is ordered highest score first; keep the first 450.
# NOTE(review): the 450 are taken across all countries, not per country -
# confirm that this is intended, since it may skew the country distribution.
top_images = sorted_scores[:450]

for img_path, _score in tqdm(top_images, desc="Copying top images"):
    # Mirror the source layout: <final_dataset>/<country>/<image>.png
    destination = final_dataset_path / img_path.parent.name
    destination.mkdir(exist_ok=True)
    shutil.copy2(img_path, destination / img_path.name)

print(f"Created final dataset with {len(top_images)} images")


In [None]:
# Create train, test, val directories
train_path = final_dataset_path / "train"
test_path = final_dataset_path / "test"
val_path = final_dataset_path / "val"

for path in [train_path, test_path, val_path]:
    path.mkdir(exist_ok=True)

# Collect the per-country copies (<final_dataset>/<country>/*.png).
# Split copies live at <final_dataset>/<split>/<country>/*.png, so their
# *parent* is the country folder and the split name is the first path
# component below final_dataset_path. Checking that component (not parent.name)
# keeps a re-run of this cell from treating earlier split copies as new inputs.
SPLIT_DIRS = {"train", "test", "val"}
all_images = sorted(
    img
    for img in final_dataset_path.glob("**/*.png")
    if img.relative_to(final_dataset_path).parts[0] not in SPLIT_DIRS
)

# Shuffle images. glob() order depends on the file system, so the list is
# sorted above; with the fixed seed, the split is reproducible across machines.
import random
random.seed(42)
random.shuffle(all_images)

# Calculate split sizes: 80 / 10 / 10, with the rounding remainder going to val.
# NOTE(review): the split is global, not stratified per country - confirm.
total_images = len(all_images)
train_size = int(0.8 * total_images)
test_size = int(0.1 * total_images)
val_size = total_images - train_size - test_size

# Split images
train_images = all_images[:train_size]
test_images = all_images[train_size:train_size + test_size]
val_images = all_images[train_size + test_size:]


def copy_to_split(images, split_path):
    """Copy each image to ``split_path/<country>/<name>`` (country = the image's parent folder name)."""
    for img_path in tqdm(images, desc=f"Copying to {split_path.name}"):
        country_dir = split_path / img_path.parent.name
        country_dir.mkdir(exist_ok=True)
        shutil.copy2(img_path, country_dir / img_path.name)


# Copy images to respective splits
copy_to_split(train_images, train_path)
copy_to_split(test_images, test_path)
copy_to_split(val_images, val_path)

print("Created splits:")
print(f"Train: {len(train_images)} images")
print(f"Test: {len(test_images)} images")
print(f"Val: {len(val_images)} images")


In [None]:
from pathlib import Path
import shutil
from tqdm import tqdm

# Paths - adapt to your local structure.
segmented_base = Path("")       # layout: <segment_class>/<country>/<image>.png
original_split_base = Path("")  # layout: <split>/<country>/<image>.png (train/test/val)
segmented_final = Path("")      # destination: <segment_class>/<split>/<country>/<image>.png

# Segment classes whose masked images are mirrored into the split layout.
segment_classes = ['road', 'vegetation', 'terrain']


def _mirror_split(segment_class, split):
    """Copy the ``segment_class`` version of every image in ``split``; warn about missing ones."""
    split_images = list((original_split_base / split).glob("*/*.png"))
    for img_path in tqdm(split_images, desc=f"{segment_class} → {split}"):
        country, img_name = img_path.parent.name, img_path.name
        source = segmented_base / segment_class / country / img_name
        if not source.exists():
            print(f"⚠️ Missing segmented image: {source}")
            continue
        dest_dir = segmented_final / segment_class / split / country
        dest_dir.mkdir(parents=True, exist_ok=True)
        shutil.copy2(source, dest_dir / img_name)


for segment_class in segment_classes:
    print(f"\nProcessing segment class: {segment_class}")
    for split in ['train', 'test', 'val']:
        _mirror_split(segment_class, split)
    print(f"✅ Done with segment class: {segment_class}")


# Balance classes

In [None]:
def balance_dataset(base_path, dataset_type="countries"):
    """
    Balance datasets by reducing all classes to match the minimum class size
    while keeping an 80:10:10 train/val/test split.

    Expected layout: ``<base_path>/<segment>/final_datasets/<split>/<class>/<files>``.
    For each segment, the smallest class ``m`` (files summed over train/val/test)
    sets the per-split targets ``int(0.8*m)``, ``int(0.1*m)`` and ``int(0.1*m)``.
    Every split folder holding more files than its target is randomly reduced
    to the target; the surplus files are deleted from disk.

    Args:
        base_path: Root folder containing one sub-folder per segment.
        dataset_type: ``"countries"`` processes the fixed segments road /
            vegetation / terrain; any other value processes every entry of
            ``base_path``.

    Notes:
        - Classes are the sub-*directories* of ``train``; stray files there
          (e.g. ``.DS_Store``) are ignored.
        - If the smallest class is empty, the segment is skipped instead of
          deleting every image of every class.
        - File lists are sorted before sampling, so for a given ``random`` state
          the same files are kept on every machine.
    """
    import random  # explicit: do not rely on an import made in another cell

    print(f"\nBalancing {dataset_type} dataset...")

    if dataset_type == "countries":
        segments = ['road', 'vegetation', 'terrain']
    else:
        # os.listdir("") fails; an empty base_path means the current directory.
        segments = os.listdir(base_path or ".")

    for segment in segments:
        print(f"\nProcessing {segment}...")
        data_path = os.path.join(base_path, segment, "final_datasets")

        train_path = os.path.join(data_path, "train")
        if not os.path.exists(train_path):
            continue

        classes = sorted(
            c for c in os.listdir(train_path)
            if os.path.isdir(os.path.join(train_path, c))
        )

        # Find the minimum total class size across all splits.
        min_total = float('inf')
        min_class = None
        for class_name in classes:
            total_images = 0
            for split in ['train', 'val', 'test']:
                split_path = os.path.join(data_path, split, class_name)
                if os.path.isdir(split_path):
                    total_images += len(os.listdir(split_path))
            if total_images < min_total:
                min_total = total_images
                min_class = class_name

        if min_total == float('inf'):
            print(f"No valid classes found for {segment}")
            continue

        print(f"Minimum total class size: {min_total} ({min_class})")

        if min_total == 0:
            # Balancing to zero would delete every image of every class.
            print(f"Skipping {segment}: class {min_class} is empty")
            continue

        # Target sizes for each split.
        train_target = int(min_total * 0.8)
        val_target = int(min_total * 0.1)
        test_target = int(min_total * 0.1)

        # Balance each class.
        for class_name in classes:
            for split, target_size in [('train', train_target), ('val', val_target), ('test', test_target)]:
                split_path = os.path.join(data_path, split, class_name)
                if not os.path.isdir(split_path):
                    continue

                images = sorted(os.listdir(split_path))
                if len(images) > target_size:
                    # Randomly select the images to keep; a set makes the membership test O(1).
                    images_to_keep = set(random.sample(images, target_size))

                    for img in images:
                        if img not in images_to_keep:
                            os.remove(os.path.join(split_path, img))

                    print(f"Reduced {class_name} {split} from {len(images)} to {target_size} images")

# Balance the class folders of each segment (road / vegetation / terrain)
balance_dataset(base_path, dataset_type="countries")

print("\n✅ Dataset balancing complete!")
