# Imports

In [None]:
import os
import cv2
import numpy as np
from tqdm import tqdm
from sklearn.model_selection import train_test_split

In [None]:
from google.colab import drive

# Mount Google Drive so the dataset and output folders under /content/drive are reachable.
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


# Configuration

In [None]:
# Coastal locations to process. Each name is joined onto `base_dir` and must
# contain `Satellite/` and `EPR/` sub-folders, so it must match the folder name exactly.
# NOTE(review): "Unawatunaa" (double 'a') presumably mirrors the Drive folder
# name; the place is usually spelled "Unawatuna" — confirm before renaming.
locations = [
    "Arugambay",
    "Beruwala",
    "Negombo",
    "Nilavali",
    "Oluvil",
    "Panadura",
    "Unawatunaa",
    "Weligama"
]

# Base directory for your data (inside Google Drive)
base_dir = "/content/drive/MyDrive/DSGP/Datasets"

# Output directory for the final NumPy arrays (also inside Drive)
output_root = "/content/drive/MyDrive/DSGP/split_datasets"

# Splitting ratios, applied separately to each location. They must sum to 1;
# otherwise the split arithmetic silently hands the remainder to one split.
train_ratio = 0.70
val_ratio   = 0.15
test_ratio  = 0.15
assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-9, (
    "train_ratio + val_ratio + test_ratio must equal 1"
)

# Expected file extensions
image_ext = ".jpg"  # Satellite images
mask_ext  = ".png"  # EPR images

# Target size after cropping to square.
# Order is (width, height), as cv2.resize expects; the resulting arrays are (height, width).
resize_shape = (512, 512)  # (width, height)

In [None]:
def crop_center_square(img):
    """Return the largest square region centred in ``img``.

    Only the first two axes (rows, columns) are cropped, so this works for both
    2-D grayscale arrays and 3-D (rows, cols, channels) arrays. When the spare
    margin is odd, the extra row/column is dropped from the bottom/right.
    """
    height, width = img.shape[:2]
    side = height if height < width else width
    row0 = (height - side) // 2
    col0 = (width - side) // 2
    return img[row0:row0 + side, col0:col0 + side]

In [None]:
def load_and_preprocess(sat_path, epr_path, size=None):
    """Load a satellite image and its EPR mask, crop both to a centred square and resize.

    Args:
        sat_path: Path to the satellite image. It is decoded by ``cv2.imread`` in
            its default colour mode and converted from BGR to RGB.
        epr_path: Path to the EPR mask image, decoded as single-channel grayscale.
        size: Target ``(width, height)`` handed to ``cv2.resize``. Defaults to the
            notebook-level ``resize_shape``, looked up at call time, so existing
            two-argument calls behave exactly as before.

    Returns:
        ``(sat_img, epr_img)``:
        - ``sat_img`` is float32 RGB scaled by 1/255 into [0, 1], shape (height, width, 3).
        - ``epr_img`` is a uint8 binary mask of 0/1 values, shape (height, width).
        Returns ``(None, None)`` if either file cannot be read.
    """
    if size is None:
        size = resize_shape

    # Decode both files before doing any work, so an unreadable or missing mask
    # does not waste the satellite image's colour conversion and resize.
    sat_img = cv2.imread(sat_path)
    if sat_img is None:
        return None, None
    epr_img = cv2.imread(epr_path, cv2.IMREAD_GRAYSCALE)
    if epr_img is None:
        return None, None

    sat_img = cv2.cvtColor(sat_img, cv2.COLOR_BGR2RGB)  # OpenCV decodes colour as BGR
    sat_img = crop_center_square(sat_img)
    sat_img = cv2.resize(sat_img, size, interpolation=cv2.INTER_LINEAR)
    sat_img = sat_img.astype(np.float32) / 255.0  # Normalize to [0,1]

    epr_img = crop_center_square(epr_img)
    # Nearest-neighbour keeps mask values discrete (no interpolated edge values).
    epr_img = cv2.resize(epr_img, size, interpolation=cv2.INTER_NEAREST)

    # Convert mask to binary format [0,1] (for FCN-8): any non-zero pixel becomes 1.
    # NOTE(review): assumes the mask background is exactly 0 — confirm for the EPR PNGs.
    epr_img = (epr_img > 0).astype(np.uint8)

    return sat_img, epr_img

In [None]:
# Lists to store dataset
# Per-split accumulators: the processing loop appends one preprocessed
# image / mask array per sample to these lists.
train_images = []
train_masks = []
val_images = []
val_masks = []
test_images = []
test_masks = []

In [None]:
# Process each location
for location_name in locations:
    print("\n===============================================================")
    print(f"  Processing location: {location_name}")
    print("===============================================================\n")

    satellite_dir = os.path.join(base_dir, location_name, "Satellite")
    epr_dir = os.path.join(base_dir, location_name, "EPR")

    if not os.path.exists(satellite_dir) or not os.path.exists(epr_dir):
        print(f"[ERROR] Missing data for {location_name}. Skipping.")
        continue

    sat_filenames = sorted([f for f in os.listdir(satellite_dir) if f.endswith(image_ext)])
    valid_pairs = [(f, os.path.splitext(f)[0] + mask_ext) for f in sat_filenames
                   if os.path.exists(os.path.join(epr_dir, os.path.splitext(f)[0] + mask_ext))]

    if len(valid_pairs) == 0:
        print(f"  [WARNING] No valid pairs found for {location_name}, skipping.")
        continue

    sat_files, epr_files = zip(*valid_pairs)

    # Train/Val/Test Split
    train_sat, temp_sat, train_epr, temp_epr = train_test_split(
        sat_files, epr_files, test_size=(1 - train_ratio), shuffle=True, random_state=42
    )

    val_sat, test_sat, val_epr, test_epr = train_test_split(
        temp_sat, temp_epr, test_size=test_ratio / (val_ratio + test_ratio), shuffle=True, random_state=42
    )

    print(f"  -> {location_name} - Train: {len(train_sat)}, Val: {len(val_sat)}, Test: {len(test_sat)}")

    # Function to load and store data
    def process_data(img_list, mask_list, dataset_images, dataset_masks):
        for img_file, mask_file in tqdm(zip(img_list, mask_list), total=len(img_list)):
            img_path = os.path.join(satellite_dir, img_file)
            mask_path = os.path.join(epr_dir, mask_file)

            img, mask = load_and_preprocess(img_path, mask_path)
            if img is not None and mask is not None:
                dataset_images.append(img)
                dataset_masks.append(mask)

    # Process training, validation, and test sets
    print(f"  Loading Training Data for {location_name}...")
    process_data(train_sat, train_epr, train_images, train_masks)

    print(f"  Loading Validation Data for {location_name}...")
    process_data(val_sat, val_epr, val_images, val_masks)

    print(f"  Loading Testing Data for {location_name}...")
    process_data(test_sat, test_epr, test_images, test_masks)

# Convert to NumPy arrays
train_images = np.array(train_images, dtype=np.float32)
train_masks = np.array(train_masks, dtype=np.uint8).reshape(-1, resize_shape[0], resize_shape[1], 1)  # Reshape for FCN-8

val_images = np.array(val_images, dtype=np.float32)
val_masks = np.array(val_masks, dtype=np.uint8).reshape(-1, resize_shape[0], resize_shape[1], 1)

test_images = np.array(test_images, dtype=np.float32)
test_masks = np.array(test_masks, dtype=np.uint8).reshape(-1, resize_shape[0], resize_shape[1], 1)

# Ensure the output directory exists
os.makedirs(output_root, exist_ok=True)

# Save NumPy arrays
np.save(os.path.join(output_root, "train_images.npy"), train_images)
np.save(os.path.join(output_root, "train_masks.npy"), train_masks)
np.save(os.path.join(output_root, "val_images.npy"), val_images)
np.save(os.path.join(output_root, "val_masks.npy"), val_masks)
np.save(os.path.join(output_root, "test_images.npy"), test_images)
np.save(os.path.join(output_root, "test_masks.npy"), test_masks)

print("\n✅ Done! NumPy datasets saved in:")
print(f"   Train: images={train_images.shape}, masks={train_masks.shape}")
print(f"   Val: images={val_images.shape}, masks={val_masks.shape}")
print(f"   Test: images={test_images.shape}, masks={test_masks.shape}")


  Processing location: Arugambay

  -> Arugambay - Train: 64, Val: 14, Test: 14
  Loading Training Data for Arugambay...


100%|██████████| 64/64 [00:02<00:00, 23.44it/s]


  Loading Validation Data for Arugambay...


100%|██████████| 14/14 [00:00<00:00, 24.44it/s]


  Loading Testing Data for Arugambay...


100%|██████████| 14/14 [00:00<00:00, 25.47it/s]



  Processing location: Beruwala

  -> Beruwala - Train: 48, Val: 10, Test: 11
  Loading Training Data for Beruwala...


100%|██████████| 48/48 [00:01<00:00, 25.89it/s]


  Loading Validation Data for Beruwala...


100%|██████████| 10/10 [00:00<00:00, 26.95it/s]


  Loading Testing Data for Beruwala...


100%|██████████| 11/11 [00:00<00:00, 24.88it/s]



  Processing location: Negombo

  -> Negombo - Train: 51, Val: 11, Test: 12
  Loading Training Data for Negombo...


100%|██████████| 51/51 [00:02<00:00, 23.92it/s]


  Loading Validation Data for Negombo...


100%|██████████| 11/11 [00:00<00:00, 23.00it/s]


  Loading Testing Data for Negombo...


100%|██████████| 12/12 [00:00<00:00, 23.33it/s]



  Processing location: Nilavali

  -> Nilavali - Train: 53, Val: 11, Test: 12
  Loading Training Data for Nilavali...


100%|██████████| 53/53 [00:02<00:00, 18.55it/s]


  Loading Validation Data for Nilavali...


100%|██████████| 11/11 [00:00<00:00, 17.89it/s]


  Loading Testing Data for Nilavali...


100%|██████████| 12/12 [00:00<00:00, 18.30it/s]



  Processing location: Oluvil

  -> Oluvil - Train: 62, Val: 14, Test: 14
  Loading Training Data for Oluvil...


100%|██████████| 62/62 [00:02<00:00, 24.44it/s]


  Loading Validation Data for Oluvil...


100%|██████████| 14/14 [00:00<00:00, 23.77it/s]


  Loading Testing Data for Oluvil...


100%|██████████| 14/14 [00:00<00:00, 25.05it/s]



  Processing location: Panadura

  -> Panadura - Train: 38, Val: 8, Test: 9
  Loading Training Data for Panadura...


100%|██████████| 38/38 [00:01<00:00, 25.97it/s]


  Loading Validation Data for Panadura...


100%|██████████| 8/8 [00:00<00:00, 24.65it/s]


  Loading Testing Data for Panadura...


100%|██████████| 9/9 [00:00<00:00, 24.45it/s]



  Processing location: Unawatunaa

  -> Unawatunaa - Train: 68, Val: 15, Test: 15
  Loading Training Data for Unawatunaa...


100%|██████████| 68/68 [00:04<00:00, 16.65it/s]


  Loading Validation Data for Unawatunaa...


100%|██████████| 15/15 [00:01<00:00, 13.60it/s]


  Loading Testing Data for Unawatunaa...


100%|██████████| 15/15 [00:01<00:00, 13.43it/s]



  Processing location: Weligama

  -> Weligama - Train: 51, Val: 11, Test: 12
  Loading Training Data for Weligama...


100%|██████████| 51/51 [00:03<00:00, 14.27it/s]


  Loading Validation Data for Weligama...


100%|██████████| 11/11 [00:00<00:00, 14.86it/s]


  Loading Testing Data for Weligama...


100%|██████████| 12/12 [00:00<00:00, 15.78it/s]



✅ Done! NumPy datasets saved in:
   Train: images=(435, 512, 512, 3), masks=(435, 512, 512, 1)
   Val: images=(94, 512, 512, 3), masks=(94, 512, 512, 1)
   Test: images=(99, 512, 512, 3), masks=(99, 512, 512, 1)
