In [None]:
import os
import numpy as np
import xarray as xr
from scipy import stats

# Directories.
# Built with os.path.join so the separator is correct on every OS; on Windows
# these evaluate to exactly the same strings as the previous hard-coded
# backslash paths.
RAW_DATA_DIR = os.path.join("data", "raw")
PATCHES_DIR = os.path.join("data", "patches with padding")
PATCH_SIZE = 32  # side length of each (square) patch, in grid cells
STRIDE = 16  # step between the top-left corners of neighbouring patches

# Ensure patch directory exists
os.makedirs(PATCHES_DIR, exist_ok=True)

def extract_patches(region_file):
    """Cut one region's NetCDF file into fixed-size patches and save them.

    For every index along the ``time`` dimension, the six variables listed
    below are stacked along a new last axis and a window of
    ``PATCH_SIZE`` x ``PATCH_SIZE`` cells is slid over the first two axes
    with step ``STRIDE``. Windows that run past the edge of the grid are
    zero-padded up to the full patch size. Each patch is written as
    ``<region>_t<time>_<row>_<col>.npy`` together with a scalar label file
    ``<region>_t<time>_<row>_<col>_label.npy`` inside
    ``PATCHES_DIR/<region>``.

    The label is the most frequent ``consecutive_drought`` value among the
    *real* (un-padded) cells of the patch, so padding zeros never vote.

    Args:
        region_file: File name (not a full path) inside ``RAW_DATA_DIR``.
            The region name is the part before ``_Combined_Consecutive.nc``.

    Returns:
        None. The results are written to disk.
    """
    region_name = region_file.split("_Combined_Consecutive.nc")[0]
    region_dir = os.path.join(PATCHES_DIR, region_name)
    os.makedirs(region_dir, exist_ok=True)

    # Channel order of the saved patches; the label variable must stay last
    # because the label is read from channel -1 below.
    variables = ["monthly_rain", "max_temp", "min_temp", "radiation", "spi_1", "consecutive_drought"]

    # Open the dataset as a context manager so the file handle is released
    # even if something fails part-way through the region.
    with xr.open_dataset(os.path.join(RAW_DATA_DIR, region_file)) as ds:
        # Dataset.sizes is the name -> length mapping; using Dataset.dims for
        # this purpose is deprecated and emits a FutureWarning.
        n_times = ds.sizes["time"]

        for time in range(n_times):
            # NOTE(review): assumes each variable is 2-D once a time step is
            # selected, presumably (lat, lon) - confirm against the raw files.
            stacked_data = [ds[var].isel(time=time).values for var in variables]
            data = np.stack(stacked_data, axis=-1)

            lat_size, lon_size = data.shape[0], data.shape[1]

            # Patch extraction (with padding if necessary)
            for lat in range(0, lat_size, STRIDE):
                for lon in range(0, lon_size, STRIDE):
                    # Slicing clips at the array edge, so edge patches may be
                    # smaller than PATCH_SIZE; they are never empty because
                    # lat < lat_size and lon < lon_size.
                    patch = data[lat:lat + PATCH_SIZE, lon:lon + PATCH_SIZE, :]

                    # Apply zero-padding for smaller patches
                    padded_patch = np.zeros((PATCH_SIZE, PATCH_SIZE, patch.shape[2]))
                    padded_patch[:patch.shape[0], :patch.shape[1], :] = patch

                    # Label = mode of consecutive_drought over the real cells
                    # only. Taking it from the padded array would count the
                    # padding zeros as votes and bias edge patches toward 0.
                    label_values = patch[:, :, -1]
                    patch_label = stats.mode(label_values.flatten(), keepdims=False).mode

                    # Save patch and label
                    patch_filename = f"{region_name}_t{time}_{lat}_{lon}.npy"
                    label_filename = f"{region_name}_t{time}_{lat}_{lon}_label.npy"

                    np.save(os.path.join(region_dir, patch_filename), padded_patch)
                    np.save(os.path.join(region_dir, label_filename), patch_label)

if __name__ == "__main__":
    # Run the patch extraction once for every combined region file found
    # directly inside the raw-data folder; anything else is ignored.
    region_suffix = "_Combined_Consecutive.nc"
    for entry in os.listdir(RAW_DATA_DIR):
        if not entry.endswith(region_suffix):
            continue
        print(f"Processing {entry}...")
        extract_patches(entry)
        print(f"Finished processing {entry}.")


Processing BRB_Combined_Consecutive.nc...


  for time in range(ds.dims['time']):


Finished processing BRB_Combined_Consecutive.nc.
Processing CHC_Combined_Consecutive.nc...


  for time in range(ds.dims['time']):


Finished processing CHC_Combined_Consecutive.nc.
Processing CQC_Combined_Consecutive.nc...


  for time in range(ds.dims['time']):


Finished processing CQC_Combined_Consecutive.nc.
Processing CYP_Combined_Consecutive.nc...


  for time in range(ds.dims['time']):


Finished processing CYP_Combined_Consecutive.nc.
Processing DEU_Combined_Consecutive.nc...


  for time in range(ds.dims['time']):


Finished processing DEU_Combined_Consecutive.nc.
Processing EIU_Combined_Consecutive.nc...


  for time in range(ds.dims['time']):


Finished processing EIU_Combined_Consecutive.nc.
Processing GUP_Combined_Consecutive.nc...


  for time in range(ds.dims['time']):


Finished processing GUP_Combined_Consecutive.nc.
Processing MGD_Combined_Consecutive.nc...


  for time in range(ds.dims['time']):


Finished processing MGD_Combined_Consecutive.nc.
Processing MUL_Combined_Consecutive.nc...


  for time in range(ds.dims['time']):


Finished processing MUL_Combined_Consecutive.nc.
Processing NET_Combined_Consecutive.nc...


  for time in range(ds.dims['time']):


Finished processing NET_Combined_Consecutive.nc.
Processing NWH_Combined_Consecutive.nc...


  for time in range(ds.dims['time']):


Finished processing NWH_Combined_Consecutive.nc.
Processing SEQ_Combined_Consecutive.nc...


  for time in range(ds.dims['time']):


Finished processing SEQ_Combined_Consecutive.nc.
Processing WET_Combined_Consecutive.nc...


  for time in range(ds.dims['time']):


Finished processing WET_Combined_Consecutive.nc.
