In [1]:
import os
import numpy as np
import pandas as pd
import re
import matplotlib.pyplot as plt
import cv2
import tensorflow as tf

class TerrainDataset:
    """Index-addressable collection of (height, separation, terrain) image triples.

    Each sample is identified by a filename stem shared by three files in
    ``dataset_dir``:

    * ``<stem>_h.png``  -- height image, read as grayscale
    * ``<stem>_i2.png`` -- separation image, read as grayscale
    * ``<stem>_t.png``  -- terrain image, read as colour (OpenCV BGR channel order)

    Files are only listed in ``__init__``; the images themselves are read and
    resized lazily in ``__getitem__``.
    """

    HEIGHT_SUFFIX = "_h.png"
    SEPARATION_SUFFIX = "_i2.png"
    TERRAIN_SUFFIX = "_t.png"

    def __init__(self, dataset_dir, image_size=128):
        """Scan ``dataset_dir`` and pair the three image kinds by filename stem.

        Args:
            dataset_dir: Directory containing the ``*_h.png``, ``*_i2.png`` and
                ``*_t.png`` files.
            image_size: Output size for ``cv2.resize``. An int gives a square
                ``image_size x image_size`` image; a ``(width, height)`` tuple or
                list is used as given. Defaults to 128.

        Raises:
            ValueError: If the three kinds differ in count, or if some stem is
                missing one of its three files.
        """
        self.dataset_dir = dataset_dir
        self.image_size = image_size

        all_files = os.listdir(dataset_dir)
        heights = self._files_by_stem(all_files, self.HEIGHT_SUFFIX)
        separations = self._files_by_stem(all_files, self.SEPARATION_SUFFIX)
        terrains = self._files_by_stem(all_files, self.TERRAIN_SUFFIX)

        # Debug: print counts
        print(f"Found {len(heights)} height images")
        print(f"Found {len(separations)} separation images")
        print(f"Found {len(terrains)} terrain images")

        if not (len(heights) == len(separations) == len(terrains)):
            raise ValueError("Mismatch in the number of height, separation, and terrain images.")

        # Equal counts are not enough. A file missing from one group can be offset
        # by an extra file in another. Also, sorting each suffix group on its own
        # can order stems differently: "a_h.png" < "a_j_h.png", but
        # "a_j_t.png" < "a_t.png". So pair explicitly by stem.
        all_stems = set(heights) | set(separations) | set(terrains)
        complete_stems = set(heights) & set(separations) & set(terrains)
        incomplete = sorted(all_stems - complete_stems)
        if incomplete:
            raise ValueError(
                "Height, separation, and terrain images do not pair up by name; "
                f"incomplete samples (first 5): {incomplete[:5]}"
            )

        # Order samples by height filename so that height_images (and therefore
        # the sample index order) matches a plain sorted() listing of the height files.
        stems = sorted(heights, key=heights.__getitem__)
        self.height_images = [heights[stem] for stem in stems]
        self.separation_images = [separations[stem] for stem in stems]
        self.terrain_images = [terrains[stem] for stem in stems]

    @staticmethod
    def _files_by_stem(filenames, suffix):
        """Map ``stem -> filename`` for every name in ``filenames`` ending with ``suffix``."""
        return {name[: -len(suffix)]: name for name in filenames if name.endswith(suffix)}

    def __len__(self):
        """Number of complete (height, separation, terrain) samples."""
        return len(self.height_images)

    def _load_image(self, filename, flags):
        """Read ``filename`` from the dataset directory with imread ``flags`` and resize it.

        Raises:
            OSError: If OpenCV cannot read the file.
        """
        path = os.path.join(self.dataset_dir, filename)
        image = cv2.imread(path, flags)
        if image is None:
            # cv2.imread returns None for a missing or unreadable file instead of
            # raising. Without this check, cv2.resize would fail with an opaque error.
            raise OSError(f"Could not read image file: {path}")

        size = self.image_size
        # cv2.resize expects dsize as (width, height).
        dsize = tuple(size) if isinstance(size, (tuple, list)) else (int(size), int(size))
        # NOTE(review): cv2.resize defaults to bilinear interpolation. If the
        # separation map holds discrete labels, cv2.INTER_NEAREST may be preferable.
        return cv2.resize(image, dsize)

    def __getitem__(self, index):
        """Load, resize and return the ``index``-th sample.

        Returns:
            ``(height_image, separation_image, terrain_image)``:
            - the first two are read with ``cv2.IMREAD_GRAYSCALE``;
            - the third is read with ``cv2.IMREAD_COLOR`` (BGR channel order).
            All three are resized to ``image_size``.

        Raises:
            IndexError: If ``index`` is out of range. This also ends plain
                ``for`` iteration over the dataset.
            OSError: If one of the files cannot be read.
        """
        height_image = self._load_image(self.height_images[index], cv2.IMREAD_GRAYSCALE)
        separation_image = self._load_image(self.separation_images[index], cv2.IMREAD_GRAYSCALE)
        terrain_image = self._load_image(self.terrain_images[index], cv2.IMREAD_COLOR)
        return height_image, separation_image, terrain_image




In [3]:
# Example Usage
# The dataset location can be overridden without editing the notebook by setting
# the TERRAIN_DATASET_DIR environment variable. The default is a relative path,
# resolved against the kernel's working directory.
dataset_dir = os.environ.get(
    "TERRAIN_DATASET_DIR",
    "earth-terrain-height-and-segmentation-map-images",
)
dataset = TerrainDataset(dataset_dir)



Found 5000 height images
Found 5000 separation images
Found 5000 terrain images


In [4]:

# Debug: Print first item and check that its images share one spatial size
first_item = dataset[0]
print("First item shapes:", [img.shape for img in first_item])
# shape[:2] is (rows, cols) for both 2-D grayscale and 3-D colour arrays,
# so this compares spatial size regardless of channel count.
assert len({img.shape[:2] for img in first_item}) == 1, "images in a sample differ in spatial size"


First item shapes: [(128, 128), (128, 128), (128, 128, 3)]
