# Setup
1. In Colab, go to **Runtime → Change runtime type** and select **T4 GPU**.
2. Install TerraTorch (plus `gdown`, which is used below to fetch the dataset).

In [None]:
!pip install terratorch==1.0.1 gdown

In [None]:
import os
import sys
import torch
import gdown
import terratorch
import albumentations
import lightning.pytorch as pl
import matplotlib.pyplot as plt
from pathlib import Path
from terratorch.datamodules import GenericNonGeoSegmentationDataModule

3. Mount Google Drive and download the dataset into it.

In [None]:
from google.colab import drive
import gdown
import zipfile
import os
import shutil

# Mount Google Drive; every derived artifact of this notebook is written below base_dir on it.
drive.mount('/content/drive')
base_dir = '/content/drive/MyDrive/terratorch_S3_CM'  # safer to write in 'MyDrive'
os.makedirs(base_dir, exist_ok=True)

# --- CONFIG ---
# Google Drive share link of the zipped patch archive.
shared_url = 'https://drive.google.com/file/d/1JeY917uXpGrHTyuWLvA5A8n4VG2us8gs/view?usp=sharing'
# Local path of the archive (relative to the current working directory).
zip_path = "patches_zip.zip"
# Scratch extraction directory; download_and_unzip removes it after moving the files.
extract_tmp = "/content/temp_extracted"
# Persistent destination on Drive for the extracted contents.
final_output_dir = os.path.join(base_dir, "downloaded_data")
# --------------
# --------------

def download_and_unzip(url, zip_file_path, extract_to, move_to):
    """Download a zip archive if needed, extract it, and move its contents into ``move_to``.

    The download is skipped when ``zip_file_path`` already holds a valid zip archive.
    A missing or corrupt file (e.g. an interrupted download) is (re-)downloaded with
    ``gdown``. Extraction happens in the scratch directory ``extract_to``, which is
    emptied first and deleted at the end. Each top-level entry of the archive replaces
    any same-named entry already present in ``move_to``, so running the cell twice
    works instead of failing with ``shutil.Error: Destination path ... already exists``.

    Args:
        url: Google Drive link, passed to ``gdown.download(..., fuzzy=True)``.
        zip_file_path: Local path where the archive is (or will be) stored.
        extract_to: Temporary extraction directory (removed when done).
        move_to: Final destination directory (created if missing).

    Raises:
        RuntimeError: If the download does not leave a valid zip archive behind.
    """
    # is_zipfile returns False for a missing path, so this also covers "not downloaded yet".
    if zipfile.is_zipfile(zip_file_path):
        print(f"✅ Found existing file: {zip_file_path}. Skipping download.")
    else:
        if os.path.isfile(zip_file_path):
            print(f"⚠️ {zip_file_path} is not a valid zip archive; downloading it again.")
            os.remove(zip_file_path)
        print(f"⬇️ Downloading from {url} ...")
        gdown.download(url, zip_file_path, quiet=False, fuzzy=True)
        if not zipfile.is_zipfile(zip_file_path):
            raise RuntimeError(
                f"Download from {url} did not produce a valid zip archive at {zip_file_path}"
            )

    print(f"📦 Extracting to {extract_to} ...")
    if os.path.isdir(extract_to):
        shutil.rmtree(extract_to)  # drop leftovers of an interrupted earlier run
    os.makedirs(extract_to, exist_ok=True)
    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
        zip_ref.extractall(extract_to)

    print(f"📁 Moving extracted files to: {move_to}")
    os.makedirs(move_to, exist_ok=True)
    for filename in os.listdir(extract_to):
        target = os.path.join(move_to, filename)
        # shutil.move raises if the destination entry already exists, so replace it
        # explicitly; this keeps the cell re-runnable against a persistent move_to.
        if os.path.isdir(target) and not os.path.islink(target):
            shutil.rmtree(target)
        elif os.path.lexists(target):
            os.remove(target)
        shutil.move(os.path.join(extract_to, filename), target)

    shutil.rmtree(extract_to)  # clean up
    print("✅ Done.")

# Fetch (or reuse) the archive and unpack it into the persistent Drive folder.
download_and_unzip(
    url=shared_url,
    zip_file_path=zip_path,
    extract_to=extract_tmp,
    move_to=final_output_dir,
)


In [None]:
import os
import shutil
import random

def _renumbered_name(filename, offset, suffix):
    """Shift the index field of a ``<timestamp>_<index>_<patch_size>_...`` filename by ``offset``.

    Returns:
        tuple[str, int]: ``"<timestamp>_<new_index:04d>_<patch_size><suffix>"`` and the new index.
    """
    parts = filename.split("_")
    new_index = offset + int(parts[1])
    return f"{parts[0]}_{new_index:04d}_{parts[2]}{suffix}", new_index


def merge_and_split_by_date(download_root, output_root, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, seed=43):
    """Assign whole date folders to train/val/test and copy their patches into flat split dirs.

    Every sub-directory of ``download_root`` is one source folder (one date). Folders, not
    individual patches, are shuffled and split, so patches from one folder never end up in
    two splits. Each split is written to ``<output_root>/<split>-data``.

    Filenames are parsed by splitting on "_" as ``<timestamp>_<index>_<patch_size>_...`` and
    only ``*_reflectance.tif`` / ``*_binary.tif`` files are copied. Within a split, each
    folder's indices are shifted past the largest index already used, so names stay unique
    even when a folder's indices are sparse.

    Args:
        download_root: Directory containing one sub-directory per date.
        output_root: Directory receiving ``train-data``, ``val-data`` and ``test-data``.
        train_ratio: Fraction of folders (floored) assigned to train.
        val_ratio: Fraction of folders (floored) assigned to val.
        test_ratio: Only sanity-checked; test receives all remaining folders.
        seed: Seed for the global ``random`` module used for the shuffle.
    """
    ratio_sum = train_ratio + val_ratio + test_ratio
    if abs(ratio_sum - 1.0) > 1e-6:
        print(f"⚠️ Split ratios sum to {ratio_sum:.3f}, not 1; the test split gets all remaining folders.")

    # Sort first: os.listdir order is filesystem-dependent, so shuffling it directly does not
    # give a reproducible split even with a fixed seed.
    source_folders = sorted(
        os.path.join(download_root, name) for name in os.listdir(download_root)
        if os.path.isdir(os.path.join(download_root, name))
    )

    # Seeds the global RNG (as before); the later capping step's random.shuffle presumably
    # relies on this seeding too.
    random.seed(seed)
    random.shuffle(source_folders)

    total = len(source_folders)
    train_end = int(train_ratio * total)
    val_end = train_end + int(val_ratio * total)

    split_map = {
        "train": source_folders[:train_end],
        "val": source_folders[train_end:val_end],
        "test": source_folders[val_end:],
    }

    for split, folders in split_map.items():
        split_dir = os.path.join(output_root, f"{split}-data")
        os.makedirs(split_dir, exist_ok=True)

        n_reflectance = 0
        index_offset = 0
        for folder in folders:
            max_new_index = index_offset - 1  # leaves the offset unchanged for empty folders
            for filename in sorted(os.listdir(folder)):
                if filename.endswith("_reflectance.tif"):
                    suffix = "_reflectance.tif"
                    n_reflectance += 1
                elif filename.endswith("_binary.tif"):
                    suffix = "_binary.tif"
                else:
                    continue
                new_name, new_index = _renumbered_name(filename, index_offset, suffix)
                shutil.copy(os.path.join(folder, filename), os.path.join(split_dir, new_name))
                max_new_index = max(max_new_index, new_index)
            # Next folder starts after the largest index used so far. For 0-based contiguous
            # indices this equals the previous count-based offset.
            index_offset = max_new_index + 1

        print(f"✅ {split.upper()} set: Merged and renamed {n_reflectance} reflectance files into {split_dir}")


In [None]:
import os

def create_split_files_from_existing_dirs(data_root, splits_dir):
    """Write ``<split>.txt`` files listing the patch base names of each ``<split>-data`` dir.

    A base name is a ``*_reflectance.tif`` filename with that suffix removed. Names are
    de-duplicated, sorted and written one per line. Split directories that do not exist
    are reported and skipped (no file is written for them).
    """
    suffix = "_reflectance.tif"
    os.makedirs(splits_dir, exist_ok=True)

    for split in ("train", "val", "test"):
        split_dir = os.path.join(data_root, f"{split}-data")
        if not os.path.isdir(split_dir):
            print(f"⚠️ Missing split directory: {split_dir}")
            continue

        stems = sorted({
            name.replace(suffix, "")
            for name in os.listdir(split_dir)
            if name.endswith(suffix)
        })

        with open(os.path.join(splits_dir, f"{split}.txt"), "w") as handle:
            handle.write("".join(stem + "\n" for stem in stems))

        print(f"✅ {split}.txt created with {len(stems)} entries.")





In [None]:
import rasterio
import numpy as np
import os
import matplotlib.pyplot as plt
import random
from collections import defaultdict

def _ratio_bin(ratio):
    """Return the lower edge of the 0.1-wide bin that holds ``ratio``.

    A ratio of exactly 1.0 is folded into the last bin (0.9), the same way the
    histograms count the right edge.
    """
    return min(np.floor(ratio * 10) / 10, 0.9)


def _scan_class_ratios(data_dir):
    """Read every ``*_binary.tif`` mask in ``data_dir`` and return per-patch pixel counts.

    Masks are visited in sorted filename order so that, for a fixed random seed, the later
    shuffle-and-cap step deletes the same patches regardless of filesystem listing order.
    Masks with no pixel equal to 0 or 1 are skipped because their ratio is undefined.

    Returns:
        list[dict]: one dict per mask with keys ``filename``, ``count_0``, ``count_1``,
        ``total`` and ``ratio`` (= count_1 / total).
    """
    patch_stats = []
    binary_files = sorted(f for f in os.listdir(data_dir) if f.endswith("_binary.tif"))
    for fname in binary_files:
        with rasterio.open(os.path.join(data_dir, fname)) as src:
            mask = src.read(1)

        count_0 = np.sum(mask == 0)
        count_1 = np.sum(mask == 1)
        total = count_0 + count_1
        if total == 0:
            continue

        patch_stats.append({
            "filename": fname,
            "count_0": count_0,
            "count_1": count_1,
            "total": total,
            "ratio": count_1 / total,
        })
    return patch_stats


def _plot_ratio_hist(ratios, bins, color, title, out_path):
    """Draw a class-1 ratio histogram on its own figure, save it, show it, then close it."""
    fig, ax = plt.subplots()
    ax.hist(ratios, bins=bins, color=color, edgecolor="black", align='mid', rwidth=0.8)
    ax.set_title(title)
    ax.set_xlabel("Proportion of Class 1 pixels")
    ax.set_ylabel("Number of patches")
    ax.set_xticks(bins)
    ax.grid(True)
    fig.tight_layout()
    fig.savefig(out_path)
    plt.show()
    plt.close(fig)  # do not accumulate open figures across calls


def _delete_patch_pair(data_dir, binary_name):
    """Delete a ``*_binary.tif`` mask and, if it exists, its matching ``*_reflectance.tif``."""
    os.remove(os.path.join(data_dir, binary_name))
    refl_path = os.path.join(data_dir, binary_name.replace("_binary.tif", "") + "_reflectance.tif")
    if os.path.exists(refl_path):
        os.remove(refl_path)


def balance_dataset_by_capping(data_dir, split, threshold=0.6, plots_dir=None):
    """Cap every class-1-ratio bin of a split at the size of the bin containing the mean ratio.

    Patches are binned by the fraction of class-1 pixels in their mask (0.1-wide bins).
    The reference mean is taken over ratios ``<= threshold`` only. Bins larger than the
    bin containing that mean are randomly down-sampled, using the global ``random`` state,
    to its size. Histograms before/after are saved to ``plots_dir`` and shown.

    Capping is skipped (nothing is deleted) when no ratio is ``<= threshold`` or when the
    mean's bin is empty. Either case would otherwise give a cap of 0 and delete every patch.

    Warning: deletes ``*_binary.tif`` / ``*_reflectance.tif`` pairs from ``data_dir`` in place.

    Args:
        data_dir: Split directory holding ``*_binary.tif`` and ``*_reflectance.tif`` files.
        split: Split name (e.g. ``"train-data"``); used in plot filenames without ``-data``.
        threshold: Upper bound on ratios included in the mean estimate.
        plots_dir: Output directory for histograms; defaults to ``<base_dir>/plots/data_hists``
            (``base_dir`` is the notebook-level global).
    """
    if plots_dir is None:
        plots_dir = f"{base_dir}/plots/data_hists"
    os.makedirs(plots_dir, exist_ok=True)
    split_name = split.replace('-data', '')
    dir_name = os.path.basename(data_dir)
    bins = np.arange(0, 1.1, 0.1)

    print(f"\n🔍 Scanning {data_dir} for class balance...")
    patch_stats = _scan_class_ratios(data_dir)
    ratios = [p["ratio"] for p in patch_stats]

    _plot_ratio_hist(
        ratios, bins, "skyblue",
        f"Class 1 ratio per patch (before balancing) - {dir_name}",
        os.path.join(plots_dir, f"Hist_before_balancing_{split_name}.png"),
    )

    # Only low-ratio patches define the reference mean. This is an important threshold!
    filtered_ratios = [r for r in ratios if r <= threshold]
    if not filtered_ratios:
        print(f"⚠️ No patches with ratio <= {threshold} in {dir_name}; skipping capping.")
        return
    mean_ratio = np.mean(filtered_ratios)
    print(f"📊 Mean ratio (<={threshold}): {mean_ratio:.4f}")

    ratio_bins = defaultdict(list)
    for patch in patch_stats:
        ratio_bins[_ratio_bin(patch["ratio"])].append(patch)

    # Use the same floor-based binning as the patches. The old round() could select the
    # neighbouring bin, or a non-existent 1.0 bin.
    mean_bin = _ratio_bin(mean_ratio)
    print('mean bin', mean_bin)
    max_bin_size = len(ratio_bins.get(mean_bin, []))
    print(f"📦 Max patches per ratio bin (based on mean bin): {max_bin_size}")
    if max_bin_size == 0:
        print(f"⚠️ Mean bin is empty; skipping capping so {dir_name} is not wiped out.")
        return

    kept_patches = []
    deleted = 0
    for patches in ratio_bins.values():
        if len(patches) > max_bin_size:
            random.shuffle(patches)
            for patch in patches[max_bin_size:]:
                _delete_patch_pair(data_dir, patch["filename"])
                deleted += 1
            patches = patches[:max_bin_size]
        kept_patches.extend(patches)

    _plot_ratio_hist(
        [p["ratio"] for p in kept_patches], bins, "lightcoral",
        f"Class 1 ratio per patch (after capping) - {dir_name}",
        os.path.join(plots_dir, f"Hist_after_balancing_{split_name}.png"),
    )

    print(f"🗑 Deleted {deleted} patches from {dir_name}")
    print(f"✅ Final dataset size: {len(kept_patches)} patches")




In [None]:
# Stage paths: raw date folders in; merged, split and balanced patches out (all under base_dir).
download_root = base_dir+"/downloaded_data"
merged_base = base_dir + "/merged"
output_root = os.path.join(merged_base, "data")
data_root = output_root
splits_dir = merged_base+"/data/splits"
os.makedirs(output_root, exist_ok=True)

# Start from a clean slate so that re-running this cell never mixes old and new splits.
for entry in os.listdir(output_root):
    path = os.path.join(output_root, entry)
    if not os.path.isdir(path):
        continue
    shutil.rmtree(path)
    print(f"🗑️ Deleted directory: {path}")

merge_and_split_by_date(download_root, output_root, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15)

# Cap over-represented class-ratio bins in every split directory that exists.
# NOTE(review): this also rebalances val/test, so evaluation no longer reflects the
# natural class distribution. Confirm that this is intended.
for split in ("train-data", "val-data", "test-data"):
    split_dir = os.path.join(data_root, split)
    if os.path.isdir(split_dir):
        balance_dataset_by_capping(split_dir, split, threshold=0.5)

# Write train/val/test .txt lists of patch base names (presumably consumed by the datamodule later).
create_split_files_from_existing_dirs(data_root, splits_dir)



In [None]:
import os
import glob
import numpy as np
import rasterio

# Computing the means and stds only on the training data is best practice to avoid information leakage.

N_BANDS = 6  # number of bands every reflectance patch is expected to have

# Directory containing the reflectance images
data_dir = merged_base + "/data/train-data"
output_path = os.path.join(os.path.dirname(data_dir), "means_stds.txt")

# Sorted so the accumulation order (and therefore float rounding) is reproducible.
file_paths = sorted(glob.glob(os.path.join(data_dir, "*_reflectance.tif")))
if not file_paths:
    raise FileNotFoundError(f"No *_reflectance.tif files found in {data_dir}")

# Running per-band accumulators over all valid pixels
band_sums = np.zeros(N_BANDS, dtype=np.float64)
band_squared_sums = np.zeros(N_BANDS, dtype=np.float64)
pixel_counts = np.zeros(N_BANDS, dtype=np.int64)

for path in file_paths:
    with rasterio.open(path) as src:
        img = src.read()  # Shape: (bands, height, width)

    if img.shape[0] != N_BANDS:
        print(f"⚠️ Skipping {path} — expected {N_BANDS} bands, got {img.shape[0]}")
        continue

    bands = img.reshape(N_BANDS, -1).astype(np.float64)
    valid = np.isfinite(bands)  # NaN and ±inf pixels are excluded from the statistics
    for b in np.flatnonzero(~valid.all(axis=1)):
        print(f"⚠️ Non-finite values found in file: {path}, band: {b}")
    clean = np.where(valid, bands, 0.0)
    band_sums += clean.sum(axis=1)
    band_squared_sums += (clean ** 2).sum(axis=1)
    pixel_counts += valid.sum(axis=1)

if np.any(pixel_counts == 0):
    raise ValueError(f"Some bands have no valid pixels (valid counts per band: {pixel_counts.tolist()})")

# Per-band mean and population std. E[x^2] - E[x]^2 can dip slightly below zero from
# floating-point cancellation, so clip it before taking the square root.
means = band_sums / pixel_counts
stds = np.sqrt(np.maximum(band_squared_sums / pixel_counts - means ** 2, 0.0))

# Save to file
with open(output_path, "w") as f:
    f.write("Band\tMean\t\tStd\n")
    f.write("-" * 30 + "\n")
    for i, (mean, std) in enumerate(zip(means, stds), start=1):
        f.write(f"{i}\t{mean:.6f}\t{std:.6f}\n")

print(f"\n✅ Per-band statistics saved to: {output_path}")
print("means:", means.tolist())
print("stds: ", stds.tolist())
