In [6]:
import os
import numpy as np
import xarray as xr
from tqdm import tqdm
# === Sensitivity test flag ===
# When True, every region is filtered with STRICT_NAN_THRESHOLD instead of its
# entry in `region_thresholds`.
SENSITIVITY_MODE = True
STRICT_NAN_THRESHOLD = 0.7

# Settings
# Folder holding the "<REGION>_Combined_Consecutive.nc" files. Built with
# os.path.join (rather than a backslash literal) so it also resolves on
# non-Windows systems; on Windows the resulting string is unchanged.
RAW_DATA_DIR = os.path.join("data", "sensitivity test", "raw")
PATCH_SIZE = 32   # spatial patch edge length, in lat/lon grid cells
STRIDE = 16       # step between neighbouring patch origins along lat and lon
TIME_WINDOW = 10  # number of consecutive time steps in one sequence patch

# Region-specific NaN thresholds: the largest NaN fraction a patch may contain
# and still be kept. Only consulted when SENSITIVITY_MODE is False; regions not
# listed here fall back to 0.4 in count_sequence_patches.
region_thresholds = {
    "NET": 0.85,
    "CQC": 0.8,
    "WET": 0.75,
}

# Input variables followed by the label. The label must stay LAST: the patch
# counter drops variables[-1] and stacks only the remaining inputs.
variables = ["monthly_rain", "max_temp", "min_temp", "radiation", "spi_1", "consecutive_drought"]

def count_sequence_patches(region_file):
    """Count sliding-window sequence patches in one region file that pass the NaN filter.

    A PATCH_SIZE x PATCH_SIZE spatial window is moved by STRIDE along ``lat``
    and ``lon``. A TIME_WINDOW-step temporal window is advanced one step at a
    time along ``time``. Each patch stacks every input variable, i.e. all of
    ``variables`` except the last entry, which is the label. Cells of the
    spatial footprint that fall outside the grid count as NaN, exactly as if
    the region had been NaN-padded. Regions smaller than one patch are swept
    over a full PATCH_SIZE extent. A patch is retained when its NaN fraction is
    <= the active threshold.

    Args:
        region_file: File name (not a path) inside RAW_DATA_DIR ending in
            "_Combined_Consecutive.nc". The prefix is used as the region name.

    Returns:
        dict with keys "region", "threshold", "attempted", "retained",
        "skipped" and "retention_rate" (percent). The same summary is printed.
    """
    region_name = region_file.split("_Combined_Consecutive.nc")[0]

    # Select threshold: sensitivity runs use one strict value everywhere.
    if SENSITIVITY_MODE:
        nan_ratio_threshold = STRICT_NAN_THRESHOLD
    else:
        nan_ratio_threshold = region_thresholds.get(region_name, 0.4)

    feature_vars = variables[:-1]  # exclude label
    n_features = len(feature_vars)
    footprint = PATCH_SIZE * PATCH_SIZE
    patch_cells = TIME_WINDOW * footprint * n_features

    # Only the NaN pattern decides retention, so collapse all input variables
    # into one per-cell NaN count up front. The `with` closes the file handle.
    with xr.open_dataset(os.path.join(RAW_DATA_DIR, region_file)) as ds:
        n_time = ds.sizes["time"]
        nan_count = np.zeros(
            (n_time, ds.sizes["lat"], ds.sizes["lon"]), dtype=np.int64
        )
        for var in feature_vars:
            # Make the (time, lat, lon) axis order explicit instead of assumed.
            values = ds[var].transpose("time", "lat", "lon").values
            nan_count += np.isnan(values)

    # Small regions are swept over at least one full patch extent. Cells beyond
    # the grid are treated as NaN, which matches NaN-padding up to PATCH_SIZE.
    n_lat = max(nan_count.shape[1], PATCH_SIZE)
    n_lon = max(nan_count.shape[2], PATCH_SIZE)
    n_windows = max(0, n_time - TIME_WINDOW + 1)

    total_attempted = 0
    skipped = 0

    for lat in range(0, n_lat, STRIDE):
        for lon in range(0, n_lon, STRIDE):
            window = nan_count[:, lat:lat + PATCH_SIZE, lon:lon + PATCH_SIZE]
            # Footprint cells outside the grid are NaN for every variable at
            # every time step of the window.
            covered = window.shape[1] * window.shape[2]
            missing = TIME_WINDOW * (footprint - covered) * n_features
            nans_per_step = window.sum(axis=(1, 2))

            for t in range(n_windows):
                nan_cells = int(nans_per_step[t:t + TIME_WINDOW].sum()) + missing
                total_attempted += 1
                if nan_cells / patch_cells > nan_ratio_threshold:
                    skipped += 1

    retained = total_attempted - skipped
    retention_rate = (retained / total_attempted) * 100 if total_attempted > 0 else 0

    print(
        f"📊 {region_name} | "
        f"Threshold={nan_ratio_threshold:.2f} | "
        f"Attempted={total_attempted} | "
        f"Retained={retained} | "
        f"Skipped={skipped} | "
        f"Retention={retention_rate:.2f}%"
    )

    return {
        "region": region_name,
        "threshold": nan_ratio_threshold,
        "attempted": total_attempted,
        "retained": retained,
        "skipped": skipped,
        "retention_rate": retention_rate,
    }

if __name__ == "__main__":
    # Keep only the combined region files, and sort them: os.listdir order is
    # arbitrary, and sorting makes the run order and the printed log
    # reproducible. Filtering first also makes the tqdm total match the real
    # amount of work.
    region_files = sorted(
        name
        for name in os.listdir(RAW_DATA_DIR)
        if name.endswith("_Combined_Consecutive.nc")
    )
    for region_file in tqdm(region_files):
        count_sequence_patches(region_file)


 33%|███▎      | 1/3 [00:02<00:04,  2.45s/it]

📊 CQC | Threshold=0.70 | Attempted=4656 | Retained=291 | Skipped=4365 | Retention=6.25%


 67%|██████▋   | 2/3 [00:03<00:01,  1.36s/it]

📊 NET | Threshold=0.70 | Attempted=1164 | Retained=0 | Skipped=1164 | Retention=0.00%


100%|██████████| 3/3 [00:05<00:00,  1.75s/it]

📊 WET | Threshold=0.70 | Attempted=4365 | Retained=582 | Skipped=3783 | Retention=13.33%



