Installing Dependencies

In [None]:
!pip install rasterio



Mounting Google Drive

In [None]:
from google.colab import drive
import os
import shutil

# drive.mount() rejects a mountpoint that already contains files. That happens
# when an earlier run wrote under /content/drive *before* Drive was mounted, so
# such local leftovers are cleared here. If Drive is already mounted, the
# directory's contents ARE the user's Google Drive and must never be deleted.
mountpoint = '/content/drive'
if os.path.ismount(mountpoint):
  print(f"{mountpoint} is already mounted; skipping cleanup.")
elif os.path.isdir(mountpoint) and os.listdir(mountpoint):
  print(f"Mountpoint {mountpoint} is not empty. Clearing contents...")
  for item in os.listdir(mountpoint):
    item_path = os.path.join(mountpoint, item)
    try:
      if os.path.isfile(item_path) or os.path.islink(item_path):
        os.unlink(item_path)
      elif os.path.isdir(item_path):
        shutil.rmtree(item_path)
    except Exception as e:
      # Best-effort cleanup: report and keep going; drive.mount() will complain
      # below if anything that blocks the mount is still there.
      print(f"Error removing {item_path}: {e}")

drive.mount(mountpoint)

Mounted at /content/drive


Preprocessing Satellite Images and Their Respective Color Masks

In [None]:
import os
import warnings

import numpy as np
from PIL import Image
import rasterio
from tqdm import tqdm

# === 1. Define Input and Output Directories ===
input_dir = "/content/drive/MyDrive/deep_globe/train"
output_dir = "/content/drive/MyDrive/deep_globe/processed_npy"

# Fail fast if Drive is not mounted (or the dataset path is wrong). Without this
# check, os.makedirs below would create a *local* /content/drive/MyDrive/... tree
# on the VM, which later makes the Drive mount refuse the non-empty mountpoint.
if not os.path.isdir(input_dir):
    raise FileNotFoundError(
        f"Input directory {input_dir} not found - mount Google Drive before running this cell."
    )
os.makedirs(output_dir, exist_ok=True)

# === 2. Define RGB Color → Class Index Mapping ===
# Mask colour (R, G, B) -> integer class id written into the label masks.
color_map = {
    (0, 255, 255): 0,     # Urban land
    (255, 255, 0): 1,     # Agriculture land
    (255, 0, 255): 2,     # Rangeland
    (0, 255, 0): 3,       # Forest land
    (0, 0, 255): 4,       # Water
    (255, 255, 255): 5,   # Barren land
    (0, 0, 0): 6          # Unknown
}

# === 3. Convert RGB mask to class-indexed mask ===
def convert_mask_to_label(mask_rgb, threshold=128, colors=None, unknown_class=6):
    """Convert a colour-coded mask into a 2-D array of class indices.

    Args:
        mask_rgb: Array-like of shape (H, W, C) with C >= 3; only the first
            three channels (R, G, B) are used, so RGBA input also works.
        threshold: Each channel >= ``threshold`` is snapped to 255 and every
            other value to 0 before matching, so slightly-off colours (e.g. 254
            or 3) still map to their class. Pass ``None`` for exact matching.
        colors: Mapping of (R, G, B) tuple -> class index. Defaults to the
            module-level ``color_map``.
        unknown_class: Index given to pixels whose colour is not in ``colors``.
            Defaults to 6, the "Unknown" id used for black in ``color_map``.

    Returns:
        np.ndarray of dtype uint8 and shape (H, W).

    Raises:
        ValueError: if ``mask_rgb`` is not an (H, W, >=3) array.
    """
    if colors is None:
        colors = color_map

    mask = np.asarray(mask_rgb)
    if mask.ndim != 3 or mask.shape[-1] < 3:
        raise ValueError(f"Expected an (H, W, 3) mask, got shape {mask.shape}")
    rgb = mask[..., :3]

    if threshold is not None:
        # The DeepGlobe dataset notes that mask values may not be exactly 0/255
        # and recommends binarizing at 128 before mapping colours to labels.
        rgb = (rgb >= threshold).astype(np.uint8) * 255

    # Start every pixel at the Unknown class. The old np.zeros start silently
    # labelled any unmatched pixel as class 0 (Urban land).
    label_mask = np.full(rgb.shape[:2], unknown_class, dtype=np.uint8)
    for color, class_id in colors.items():
        label_mask[np.all(rgb == color, axis=-1)] = class_id
    return label_mask

# === 4. Process All Files ===
# Each satellite tile "<base>_sat.<ext>" is paired with "<base>_mask.png" in the
# same folder; both are saved as .npy files under output_dir. Problem pairs are
# reported and skipped instead of aborting this long-running pass.
# NOTE(review): images are saved as float32 (4x the size of uint8 pixels);
# consider storing uint8 and normalizing at load time if Drive space matters.
SAT_EXTENSIONS = (".tif", ".jpg", ".png")

for file in tqdm(sorted(os.listdir(input_dir)), desc="Preprocessing"):
    stem, ext = os.path.splitext(file)
    # Match the "_sat" suffix and extension explicitly (extension is
    # case-insensitive) so the derived base name is always right,
    # e.g. "123_sat.JPG" -> "123".
    if ext.lower() not in SAT_EXTENSIONS or not stem.endswith("_sat"):
        continue
    base = stem[: -len("_sat")]
    sat_path = os.path.join(input_dir, file)
    mask_path = os.path.join(input_dir, f"{base}_mask.png")

    if not os.path.exists(mask_path):
        print(f"Mask not found for: {file}")
        continue

    # Read satellite image using Rasterio. Plain JPG/PNG tiles carry no
    # georeferencing, so rasterio's NotGeoreferencedWarning is expected noise.
    try:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", rasterio.errors.NotGeoreferencedWarning)
            with rasterio.open(sat_path) as src:
                img = src.read()  # shape: (bands, H, W)
        img = np.transpose(img, (1, 2, 0))  # (H, W, C)
    except Exception as e:
        print(f"Could not read image {sat_path}: {e}")
        continue

    # Normalize image to [0, 1]. uint8 rasters are always divided by 255. The
    # old value-based test alone left very dark tiles (max 0 or 1) unscaled,
    # mixing two ranges in one dataset. Other dtypes keep that heuristic.
    raw_dtype = img.dtype
    img = img.astype(np.float32)
    if raw_dtype == np.uint8 or img.max() > 1.0:
        img /= 255.0

    # Read and convert mask (context manager closes the file handle promptly).
    try:
        with Image.open(mask_path) as mask_img:
            mask_rgb = np.array(mask_img.convert("RGB"))
        label_mask = convert_mask_to_label(mask_rgb)
    except Exception as e:
        print(f"Could not read/convert mask {mask_path}: {e}")
        continue

    # A mask that does not line up pixel-for-pixel with its image would yield
    # wrong training labels, so skip the pair rather than saving it.
    if label_mask.shape != img.shape[:2]:
        print(f"Size mismatch for {base}: image {img.shape[:2]} vs mask {label_mask.shape}")
        continue

    # Save .npy files
    np.save(os.path.join(output_dir, f"{base}_image.npy"), img)
    np.save(os.path.join(output_dir, f"{base}_mask.npy"), label_mask)

  dataset = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)
100%|██████████| 1606/1606 [49:23<00:00,  1.84s/it]
