In [1]:
pip install xarray netcdf4 dask

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


In [2]:
!pip install scipy>=1.12.0 --upgrade -v

In [3]:
pip install tqdm

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


In [4]:
import numpy as np
import scipy as sp
print(f"NumPy version: {np.__version__}")
print(f"SciPy version: {sp.__version__}")

NumPy version: 2.2.5
SciPy version: 1.15.3


In [5]:
!python -m pip install "dask[distributed]"

Defaulting to user installation because normal site-packages is not writeable


In [6]:
!pip install psutil

Defaulting to user installation because normal site-packages is not writeable


# Creating valid patches valid_patches_under30.npz

In [8]:
import os
import xarray as xr
import numpy as np
import time
from tqdm import tqdm



DATA_FILE = "../Switch2_Fuel_predictors/AWRA-L_e0_monthly_005_Aus.nc"
REFERENCE_MONTH = "2005_01_01"
PATCH_SIZE = 128
THRESHOLD = 30.0
OUTPUT_FILE = "valid_patches_under30.npz"

print("Opening dataset lazily...")
ds = xr.open_dataset(DATA_FILE, chunks={"lat": 1000, "lon": 1000})

start_time = time.time()

print(f"Selecting band '{REFERENCE_MONTH}'...")
data_slice = ds.sel(band=REFERENCE_MONTH)["e0"].load()  # Loads one month's data
ds.close()

lat_dim, lon_dim = data_slice.shape
num_lat_patches = lat_dim // PATCH_SIZE
num_lon_patches = lon_dim // PATCH_SIZE
total_patches = num_lat_patches * num_lon_patches
print(f"Data shape: {lat_dim} x {lon_dim}, total patches: {total_patches}")

patch_nan_percentages = np.zeros((num_lat_patches, num_lon_patches), dtype=np.float32)
data_values = data_slice.values 

print("Analyzing patches for NaN coverage...")
for i in tqdm(range(num_lat_patches), desc="Patch rows"):
    for j in range(num_lon_patches):
        lat_start, lon_start = i * PATCH_SIZE, j * PATCH_SIZE
        patch_data = data_values[lat_start:lat_start + PATCH_SIZE,
                                 lon_start:lon_start + PATCH_SIZE]
        nan_count = np.isnan(patch_data).sum()
        patch_nan_percentages[i, j] = (nan_count / (PATCH_SIZE**2)) * 100

valid_mask = patch_nan_percentages <= THRESHOLD
valid_indices = np.where(valid_mask)
num_valid = len(valid_indices[0])

print(f"\nValid patches (NaN% <= {THRESHOLD}): {num_valid} of {total_patches}")


lats, lons = data_slice.lat.values, data_slice.lon.values
valid_patches = []
for i, j in zip(*valid_indices):
    lat_start, lon_start = i * PATCH_SIZE, j * PATCH_SIZE
    lat_center_idx = lat_start + PATCH_SIZE // 2
    lon_center_idx = lon_start + PATCH_SIZE // 2
    lat_center_val = lats[lat_center_idx] if lat_center_idx < len(lats) else np.nan
    lon_center_val = lons[lon_center_idx] if lon_center_idx < len(lons) else np.nan
    valid_patches.append((
        i, j,
        lat_start, lat_start + PATCH_SIZE,
        lon_start, lon_start + PATCH_SIZE,
        lat_center_idx, lon_center_idx,
        lat_center_val, lon_center_val,
        patch_nan_percentages[i, j]
    ))

valid_patches = np.array(valid_patches, dtype=[
    ('patch_i', np.int32),
    ('patch_j', np.int32),
    ('lat_start', np.int32),
    ('lat_end', np.int32),
    ('lon_start', np.int32),
    ('lon_end', np.int32),
    ('lat_center_idx', np.int32),
    ('lon_center_idx', np.int32),
    ('lat_center_val', np.float32),
    ('lon_center_val', np.float32),
    ('nan_percentage', np.float32)
])

np.savez_compressed(
    OUTPUT_FILE,
    valid_patches=valid_patches,
    patch_size=PATCH_SIZE,
    lat_dim=lat_dim,
    lon_dim=lon_dim,
    threshold=THRESHOLD
)

elapsed = time.time() - start_time
print(f"Finished in {elapsed:.2f} seconds.")


Opening dataset lazily...
Selecting band '2005_01_01'...


  ds = xr.open_dataset(DATA_FILE, chunks={"lat": 1000, "lon": 1000})
  ds = xr.open_dataset(DATA_FILE, chunks={"lat": 1000, "lon": 1000})


Data shape: 6800 x 9000, total patches: 3710
Analyzing patches for NaN coverage...


Patch rows: 100%|██████████| 53/53 [00:00<00:00, 792.35it/s]


Valid patches (NaN% <= 30.0): 2051 of 3710
Finished in 33.23 seconds.





# Feature : Switch 1 Static

In [15]:
import os
import numpy as np
import xarray as xr
import gc
import time
from tqdm import tqdm

# List of static features
STATIC_FEATURES = [
    {
        "path": "/home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Aspect_nc_005_Aus.nc",
        "var": "aspect_degrees",
        "name": "Aspect_degrees"
    },
    {
        "path": "/home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Distance_to_major_roads_nc_Aus_005.nc",
        "var": "distance",
        "name": "Distance_to_major_roads"
    },
    {
        "path": "/home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Distance_to_railway_line_nc_Aus_005.nc",
        "var": "distance",
        "name": "Distance_to_railway_line"
    },
    {
        "path": "/home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Distance_to_transmission_line_nc_Aus_005.nc",
        "var": "distance",
        "name": "Distance_to_transmission_line"
    },
    {
        "path": "/home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Elevation_nc_005_Aus.nc",
        "var": "Elevation",
        "name": "Elevation"
    },
    {
        "path": "/home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Global_Landform_nc_Aus_005.nc",
        "var": "Landform",
        "name": "Global_Landform"
    },
    {
        "path": "/home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Iwashi_nc_Aus_005.nc",
        "var": "Landform",
        "name": "Iwashi_Landform"
    },
    {
        "path": "/home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Population_nc_Aus_005.nc",
        "var": "population",
        "name": "Population"
    },
    {
        "path": "/home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Slope_nc_Aus_005.nc",
        "var": "slope",
        "name": "Slope"
    },
    {
        "path": "/home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Vegetation_nc_Aus_005.nc",
        "var": "Vegetation_Type",
        "name": "Vegetation_Type"
    },
    {
        "path": "/home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/meybeck_nc_Aus_005.nc",
        "var": "Landform",
        "name": "meybeck_Landform"
    }
]

def main():
    start_time = time.time()
    
    # Load valid patches from your previously-created file 
    # (e.g. valid_patches_under30.npz or valid_patches_under70.npz)
    valid_patches_file = "valid_patches_under30.npz"
    if not os.path.exists(valid_patches_file):
        raise FileNotFoundError(f"Valid patches file '{valid_patches_file}' not found.")

    print(f"Loading valid patches metadata from '{valid_patches_file}'...")
    data = np.load(valid_patches_file, allow_pickle=True)

    # Convert structured array to list of dictionaries
    valid_patches_struct = data["valid_patches"]
    valid_patches_list = []
    for i in range(len(valid_patches_struct)):
        patch_dict = {}
        for field in valid_patches_struct.dtype.names:
            patch_dict[field] = valid_patches_struct[field][i]
        valid_patches_list.append(patch_dict)
    
    patch_size = int(data["patch_size"])
    lat_dim    = int(data["lat_dim"])
    lon_dim    = int(data["lon_dim"])
    threshold  = float(data["threshold"])
    
    num_valid_patches = len(valid_patches_list)
    print(f"Found {num_valid_patches} valid patches, patch_size={patch_size}, threshold={threshold}%")
    print(f"Domain: {lat_dim} x {lon_dim}\n")

    # We'll save all static features under a single directory
    out_root = "static_patches"
    os.makedirs(out_root, exist_ok=True)
    
    # Iterate over each static feature, slice out patches, and store as [patches, patch_size, patch_size].
    for idx, feature_info in enumerate(STATIC_FEATURES, start=1):
        feature_path = feature_info["path"]
        feature_var  = feature_info["var"]
        feature_name = feature_info["name"]
        
        print(f"[{idx}/{len(STATIC_FEATURES)}] Processing static feature: {feature_name}")
        
        if not os.path.exists(feature_path):
            print(f"  -> File not found: {feature_path}, skipping.")
            continue
        
        # Open the dataset
        ds = None
        try:
            ds = xr.open_dataset(feature_path)
            
            if feature_var not in ds.data_vars:
                print(f"  -> Variable '{feature_var}' not found in {feature_path}, skipping.")
                ds.close()
                continue
            
            # The shape should be [lat, lon] with no time dimension
            # Load the entire array into memory (assuming you have enough RAM)
            print(f"  -> Loading data for '{feature_var}' from {feature_path}...")
            data_2d = ds[feature_var].load().values  # shape: (lat_dim, lon_dim)
            ds.close()
            ds = None
            
            # Create an array [num_valid_patches, patch_size, patch_size]
            feature_patches = np.zeros((num_valid_patches, patch_size, patch_size), dtype=np.float32)
            feature_patches.fill(np.nan)
            
            print(f"  -> Extracting {num_valid_patches} patches of size {patch_size}x{patch_size}...")
            for p_idx, patch_info in enumerate(tqdm(valid_patches_list, desc="Patch Extraction", leave=False)):
                lat_start = patch_info["lat_start"]
                lat_end   = patch_info["lat_end"]
                lon_start = patch_info["lon_start"]
                lon_end   = patch_info["lon_end"]
                
                # Extract the patch, copy to avoid referencing original array
                patch_array = data_2d[lat_start:lat_end, lon_start:lon_end].copy()
                feature_patches[p_idx] = patch_array
            
            # Build metadata
            metadata = {
                "feature_name": feature_name,
                "feature_var": feature_var,
                "patch_size": patch_size,
                "num_patches": num_valid_patches,
                "threshold_nan": threshold
            }
            
            # Save as .npz
            out_file = os.path.join(out_root, f"{feature_name}.npz")
            np.savez_compressed(
                out_file,
                data=feature_patches,
                metadata=metadata
            )
            
            print(f"  -> Saved patches to '{out_file}'")
            
            # Cleanup
            del data_2d
            del feature_patches
            gc.collect()
            
        except Exception as e:
            print(f"  -> Error processing {feature_name}: {e}")
            if ds:
                ds.close()
    
    end_time = time.time()
    elapsed = end_time - start_time
    print(f"\nAll static features processed. Total time: {elapsed:.2f}s ({elapsed/60:.2f} min).")

if __name__ == "__main__":
    main()

Loading valid patches metadata from 'valid_patches_under30.npz'...
Found 2051 valid patches, patch_size=128, threshold=30.0%
Domain: 6800 x 9000

[1/11] Processing static feature: Aspect_degrees
  -> Loading data for 'aspect_degrees' from /home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Aspect_nc_005_Aus.nc...
  -> Extracting 2051 patches of size 128x128...


                                                          

  -> Saved patches to 'static_patches/Aspect_degrees.npz'
[2/11] Processing static feature: Distance_to_major_roads
  -> Loading data for 'distance' from /home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Distance_to_major_roads_nc_Aus_005.nc...
  -> Extracting 2051 patches of size 128x128...


                                                          

  -> Saved patches to 'static_patches/Distance_to_major_roads.npz'
[3/11] Processing static feature: Distance_to_railway_line
  -> Loading data for 'distance' from /home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Distance_to_railway_line_nc_Aus_005.nc...
  -> Extracting 2051 patches of size 128x128...


                                                          

  -> Saved patches to 'static_patches/Distance_to_railway_line.npz'
[4/11] Processing static feature: Distance_to_transmission_line
  -> Loading data for 'distance' from /home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Distance_to_transmission_line_nc_Aus_005.nc...
  -> Extracting 2051 patches of size 128x128...


                                                          

  -> Saved patches to 'static_patches/Distance_to_transmission_line.npz'
[5/11] Processing static feature: Elevation
  -> Loading data for 'Elevation' from /home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Elevation_nc_005_Aus.nc...
  -> Extracting 2051 patches of size 128x128...


                                                          

  -> Saved patches to 'static_patches/Elevation.npz'
[6/11] Processing static feature: Global_Landform
  -> Loading data for 'Landform' from /home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Global_Landform_nc_Aus_005.nc...
  -> Extracting 2051 patches of size 128x128...


                                                          

  -> Saved patches to 'static_patches/Global_Landform.npz'
[7/11] Processing static feature: Iwashi_Landform
  -> Loading data for 'Landform' from /home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Iwashi_nc_Aus_005.nc...
  -> Extracting 2051 patches of size 128x128...


                                                          

  -> Saved patches to 'static_patches/Iwashi_Landform.npz'
[8/11] Processing static feature: Population
  -> Loading data for 'population' from /home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Population_nc_Aus_005.nc...
  -> Extracting 2051 patches of size 128x128...


                                                          

  -> Saved patches to 'static_patches/Population.npz'
[9/11] Processing static feature: Slope
  -> Loading data for 'slope' from /home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Slope_nc_Aus_005.nc...
  -> Extracting 2051 patches of size 128x128...


                                                          

  -> Saved patches to 'static_patches/Slope.npz'
[10/11] Processing static feature: Vegetation_Type
  -> Loading data for 'Vegetation_Type' from /home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Vegetation_nc_Aus_005.nc...
  -> Extracting 2051 patches of size 128x128...


                                                          

  -> Saved patches to 'static_patches/Vegetation_Type.npz'
[11/11] Processing static feature: meybeck_Landform
  -> Loading data for 'Landform' from /home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/meybeck_nc_Aus_005.nc...
  -> Extracting 2051 patches of size 128x128...


                                                          

  -> Saved patches to 'static_patches/meybeck_Landform.npz'

All static features processed. Total time: 37.36s (0.62 min).


In [16]:
import os
import numpy as np
import time
from tqdm.auto import tqdm

# Define the parameters for static features
STATIC_FEATURES = [
    {"path": "/home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Aspect_nc_005_Aus.nc", "var": "aspect_degrees", "name": "Aspect_degrees"},
    {"path": "/home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Distance_to_major_roads_nc_Aus_005.nc", "var": "distance", "name": "Distance_to_major_roads"},
    {"path": "/home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Distance_to_railway_line_nc_Aus_005.nc", "var": "distance", "name": "Distance_to_railway_line"},
    {"path": "/home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Distance_to_transmission_line_nc_Aus_005.nc", "var": "distance", "name": "Distance_to_transmission_line"},
    {"path": "/home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Elevation_nc_005_Aus.nc", "var": "Elevation", "name": "Elevation"},
    {"path": "/home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Global_Landform_nc_Aus_005.nc", "var": "Landform", "name": "Global_Landform"},
    {"path": "/home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Iwashi_nc_Aus_005.nc", "var": "Landform", "name": "Iwashi_Landform"},
    {"path": "/home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Population_nc_Aus_005.nc", "var": "population", "name": "Population"},
    {"path": "/home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Slope_nc_Aus_005.nc", "var": "slope", "name": "Slope"},
    {"path": "/home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Vegetation_nc_Aus_005.nc", "var": "Vegetation_Type", "name": "Vegetation_Type"},
    {"path": "/home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/meybeck_nc_Aus_005.nc", "var": "Landform", "name": "meybeck_Landform"}
]

OUT_ROOT = "static_patches"

def validate_static_data_structure():
    """
    Validates that all expected static feature files exist and have the correct structure.
    """
    print("Starting validation of processed static feature data...")
    start_time = time.time()
    
    # Load the original valid patches metadata for verification
    valid_patches_data = np.load("valid_patches_under30.npz", allow_pickle=True)
    valid_patches = valid_patches_data["valid_patches"]
    expected_num_patches = len(valid_patches)
    expected_patch_size = int(valid_patches_data["patch_size"])
    threshold = float(valid_patches_data["threshold"])
    
    print(f"Original metadata: {expected_num_patches} valid patches, "
          f"patch size: {expected_patch_size}x{expected_patch_size}, threshold: {threshold}%")
    
    # Check if output directory exists
    if not os.path.exists(OUT_ROOT):
        print(f"WARNING: Output directory '{OUT_ROOT}' does not exist!")
        return
    else:
        print(f"✓ Output directory '{OUT_ROOT}' exists.")
    
    # Verify all static feature npz files exist
    feature_names = [feature["name"] for feature in STATIC_FEATURES]
    expected_files = [os.path.join(OUT_ROOT, f"{name}.npz") for name in feature_names]
    
    missing_files = [f for f in expected_files if not os.path.exists(f)]
    
    if missing_files:
        print(f"WARNING: {len(missing_files)}/{len(expected_files)} static feature files are missing!")
        for f in missing_files:
            print(f"  - {f}")
    else:
        print(f"✓ All {len(expected_files)} expected static feature files exist.")
    
    # Validate content of files
    existing_files = [f for f in expected_files if os.path.exists(f)]
    print(f"Validating content of {len(existing_files)} static feature files...")
    
    validation_results = []
    for file_path in tqdm(existing_files, desc="Validating Files"):
        result = validate_static_file_content(file_path, expected_num_patches, expected_patch_size, threshold)
        validation_results.append(result)
    
    # Summarize validation results
    success_count = sum(1 for r in validation_results if r["valid"])
    print(f"\nValidation Results: {success_count}/{len(validation_results)} files are valid.")
    
    if success_count < len(validation_results):
        print("\nIssues found:")
        for r in validation_results:
            if not r["valid"]:
                print(f"  - File: {r['file']}, Issue: {r['issue']}")
    
    # Perform a deeper validation on one file to check all patches
    if existing_files:
        print("\nPerforming detailed validation on a sample file...")
        sample_file = existing_files[0]
        detailed_result = validate_static_file_detailed(sample_file, valid_patches, threshold)
        
        print(f"Detailed validation of {sample_file}:")
        print(f"  - Total patches: {detailed_result['total_patches']}")
        print(f"  - Valid patches (NaN% <= {threshold}%): {detailed_result['valid_patches']}")
        print(f"  - Invalid patches: {detailed_result['invalid_patches']}")
        print(f"  - Highest NaN%: {detailed_result['max_nan_pct']:.2f}%")
        print(f"  - Average NaN%: {detailed_result['avg_nan_pct']:.2f}%")
        
        # Print NaN percentage distribution in text form
        if detailed_result['nan_percentages']:
            nan_percentages = detailed_result['nan_percentages']
            
            # Define ranges for text-based distribution
            ranges = [(0, 10), (10, 20), (20, 30), (30, 40), (40, 50), 
                      (50, 60), (60, 70), (70, 80), (80, 90), (90, 100)]
            
            print("\nNaN Percentage Distribution:")
            for low, high in ranges:
                count = sum(1 for p in nan_percentages if low <= p < high)
                percentage = (count / len(nan_percentages)) * 100
                bar_length = int(percentage / 2)  # Scale for display
                bar = '█' * bar_length
                print(f"  {low:2d}-{high:<3d}%: {count:4d} patches ({percentage:5.1f}%) {bar}")
    
    # Check for inconsistencies across static features
    if len(existing_files) >= 2:
        print("\nChecking for inconsistencies across static features...")
        inconsistency_results = check_cross_feature_consistency(existing_files, threshold)
        
        if inconsistency_results['inconsistent_files']:
            print("Inconsistencies detected between static features:")
            for issue in inconsistency_results['issues']:
                print(f"  - {issue}")
        else:
            print("✓ No inconsistencies detected across static features.")
    
    end_time = time.time()
    elapsed = end_time - start_time
    print(f"\nValidation completed in {elapsed:.2f} seconds ({elapsed/60:.2f} minutes).")

def validate_static_file_content(file_path, expected_num_patches, expected_patch_size, threshold):
    """
    Validates the content of a single static feature file.
    """
    result = {
        "file": file_path,
        "valid": True,
        "issue": None
    }
    
    try:
        # Load the file
        data = np.load(file_path, allow_pickle=True)
        
        # Check if required keys exist
        required_keys = ['data', 'metadata']
        missing_keys = [k for k in required_keys if k not in data]
        if missing_keys:
            result["valid"] = False
            result["issue"] = f"Missing keys: {', '.join(missing_keys)}"
            return result
        
        # Check data shape - static features should be 3D [patches, height, width]
        data_array = data['data']
        if len(data_array.shape) != 3:
            result["valid"] = False
            result["issue"] = f"Expected 3D array, got {len(data_array.shape)}D"
            return result
        
        num_patches, patch_height, patch_width = data_array.shape
        
        if num_patches != expected_num_patches:
            result["valid"] = False
            result["issue"] = f"Expected {expected_num_patches} patches, got {num_patches}"
            return result
        
        if patch_height != expected_patch_size or patch_width != expected_patch_size:
            result["valid"] = False
            result["issue"] = f"Expected {expected_patch_size}x{expected_patch_size} patch size, got {patch_height}x{patch_width}"
            return result
        
        # Check metadata structure
        metadata = data['metadata'].item() if isinstance(data['metadata'], np.ndarray) else data['metadata']
        if not isinstance(metadata, dict):
            result["valid"] = False
            result["issue"] = f"Metadata is not a dictionary"
            return result
        
        # Validate static feature metadata - check if required fields exist
        required_metadata = ["feature_name", "feature_var", "patch_size", "num_patches", "threshold_nan"]
        missing_metadata = [k for k in required_metadata if k not in metadata]
        if missing_metadata:
            result["valid"] = False
            result["issue"] = f"Missing metadata fields: {', '.join(missing_metadata)}"
            return result
        
        # Check for unrealistic values based on feature type
        feature_name = os.path.basename(file_path).replace('.npz', '')
        unrealistic_values = check_unrealistic_static_values(data_array, feature_name)
        if unrealistic_values:
            result["valid"] = False
            result["issue"] = f"Contains unrealistic values: {unrealistic_values}"
            return result
        
    except Exception as e:
        result["valid"] = False
        result["issue"] = f"Error loading/parsing file: {e}"
    
    return result

def check_unrealistic_static_values(data_array, feature_name):
    """
    Check for unrealistic values in static features based on the feature type.
    Returns description of problems found or None if values are reasonable.
    """
    # Flatten for easier analysis, ignoring NaNs
    valid_data = data_array[~np.isnan(data_array)]
    
    if len(valid_data) == 0:
        return "Data contains only NaN values"
    
    # Define reasonable ranges for different static features
    if "Aspect_degrees" in feature_name:
        # Aspect should be between 0 and 360 degrees
        if np.min(valid_data) < 0 or np.max(valid_data) > 360:
            return f"Aspect out of range 0-360: min={np.min(valid_data):.1f}, max={np.max(valid_data):.1f}"
    
    elif "Distance" in feature_name:
        # Distance should be >= 0
        if np.min(valid_data) < 0:
            return f"Negative distance values: min={np.min(valid_data):.1f}"
    
    elif "Elevation" in feature_name:
        # Elevation in Australia generally between -15m and 2228m (Mt Kosciuszko)
        # Allow a bit more range for potential errors/outliers
        if np.min(valid_data) < -20 or np.max(valid_data) > 2500:
            return f"Elevation out of realistic range: min={np.min(valid_data):.1f}, max={np.max(valid_data):.1f}"
    
    elif "Landform" in feature_name:
        # Landform is typically categorical data with integer values
        if not np.all(np.equal(np.mod(valid_data, 1), 0)):
            return "Landform contains non-integer values"
    
    elif "Population" in feature_name:
        # Population should be >= 0
        if np.min(valid_data) < 0:
            return f"Negative population values: min={np.min(valid_data):.1f}"
    
    elif "Slope" in feature_name:
        # Slope typically in degrees (0-90) or percent
        if np.min(valid_data) < 0:
            return f"Negative slope values: min={np.min(valid_data):.1f}"
        # If in degrees, shouldn't exceed 90
        if np.max(valid_data) <= 90:
            pass
        # If in percent, extremely steep slopes might go over 100% but rarely over 1000%
        elif np.max(valid_data) > 1000:
            return f"Extremely high slope values: max={np.max(valid_data):.1f}"
    
    elif "Vegetation_Type" in feature_name:
        # Vegetation type is typically categorical data with integer values
        if not np.all(np.equal(np.mod(valid_data, 1), 0)):
            return "Vegetation type contains non-integer values"
    
    # Check for other anomalies
    if np.all(valid_data == valid_data[0]):
        return "Data contains only a single repeated value"
    
    # Check for NaN percentage
    nan_percentage = np.isnan(data_array).sum() / data_array.size * 100
    if nan_percentage > 50:
        return f"High percentage of NaN values: {nan_percentage:.1f}%"
    
    return None

def validate_static_file_detailed(file_path, valid_patches, threshold):
    """
    Performs a detailed validation of one static feature file, checking all patches.
    """
    result = {
        'total_patches': 0,
        'valid_patches': 0,
        'invalid_patches': 0,
        'max_nan_pct': 0,
        'avg_nan_pct': 0,
        'nan_percentages': []
    }
    
    try:
        # Load the file
        data = np.load(file_path, allow_pickle=True)
        data_array = data['data']
        
        num_patches, patch_height, patch_width = data_array.shape
        result['total_patches'] = num_patches
        
        # Check NaN percentages for each patch
        nan_percentages = []
        for patch_idx in range(num_patches):
            patch = data_array[patch_idx]
            nan_count = np.isnan(patch).sum()
            total_cells = patch_height * patch_width
            nan_percentage = (nan_count / total_cells) * 100
            
            nan_percentages.append(nan_percentage)
            
            if nan_percentage <= threshold:
                result['valid_patches'] += 1
            else:
                result['invalid_patches'] += 1
        
        result['nan_percentages'] = nan_percentages
        result['max_nan_pct'] = max(nan_percentages) if nan_percentages else 0
        result['avg_nan_pct'] = sum(nan_percentages) / len(nan_percentages) if nan_percentages else 0
        
    except Exception as e:
        print(f"Error in detailed validation: {e}")
    
    return result

def check_cross_feature_consistency(file_paths, threshold):
    """
    Checks for inconsistencies across different static feature files.
    """
    result = {
        'inconsistent_files': False,
        'issues': []
    }
    
    try:
        # Load metadata from all files
        metadata_list = []
        nan_patterns = []
        
        for file_path in file_paths:
            data = np.load(file_path, allow_pickle=True)
            
            # Extract metadata
            metadata = data['metadata'].item() if isinstance(data['metadata'], np.ndarray) else data['metadata']
            metadata_list.append({
                'file': file_path,
                'name': metadata.get('feature_name', os.path.basename(file_path).replace('.npz', '')),
                'num_patches': metadata.get('num_patches', None),
                'patch_size': metadata.get('patch_size', None)
            })
            
            # Extract NaN pattern (where data is NaN)
            data_array = data['data']
            nan_pattern = np.isnan(data_array)
            
            # For memory efficiency, we'll just remember the total NaN count for each patch
            nan_counts = np.sum(nan_pattern, axis=(1, 2))  # Sum over height and width dimensions
            nan_patterns.append({
                'file': file_path,
                'nan_counts': nan_counts
            })
        
        # Check for metadata consistency
        num_patches_values = set(m['num_patches'] for m in metadata_list if m['num_patches'] is not None)
        if len(num_patches_values) > 1:
            result['inconsistent_files'] = True
            result['issues'].append(f"Inconsistent number of patches across files: {num_patches_values}")
        
        patch_size_values = set(m['patch_size'] for m in metadata_list if m['patch_size'] is not None)
        if len(patch_size_values) > 1:
            result['inconsistent_files'] = True
            result['issues'].append(f"Inconsistent patch sizes across files: {patch_size_values}")
        
        # Check for NaN pattern consistency
        # We expect all static features to have NaNs in the same locations (patches)
        if len(nan_patterns) >= 2:
            reference = nan_patterns[0]
            ref_file = os.path.basename(reference['file'])
            ref_counts = reference['nan_counts']
            
            for pattern in nan_patterns[1:]:
                curr_file = os.path.basename(pattern['file'])
                curr_counts = pattern['nan_counts']
                
                # Compare NaN counts for each patch
                if not np.array_equal(ref_counts, curr_counts):
                    # Find how many patches have different NaN counts
                    diff_patches = np.sum(ref_counts != curr_counts)
                    total_patches = len(ref_counts)
                    diff_percentage = (diff_patches / total_patches) * 100
                    
                    if diff_percentage > 1:  # Allow small differences (< 1%)
                        result['inconsistent_files'] = True
                        result['issues'].append(
                            f"Different NaN patterns between {ref_file} and {curr_file}: "
                            f"{diff_patches}/{total_patches} patches ({diff_percentage:.1f}%)"
                        )
        
    except Exception as e:
        result['inconsistent_files'] = True
        result['issues'].append(f"Error checking cross-feature consistency: {e}")
    
    return result

if __name__ == "__main__":
    # Set a random seed for reproducibility
    np.random.seed(42)
    validate_static_data_structure()

Starting validation of processed static feature data...
Original metadata: 2051 valid patches, patch size: 128x128, threshold: 30.0%
✓ Output directory 'static_patches' exists.
✓ All 11 expected static feature files exist.
Validating content of 11 static feature files...


Validating Files: 100%|██████████| 11/11 [00:05<00:00,  1.89it/s]



Validation Results: 10/11 files are valid.

Issues found:
  - File: static_patches/Elevation.npz, Issue: Contains unrealistic values: Elevation out of realistic range: min=-60.0, max=2172.0

Performing detailed validation on a sample file...
Detailed validation of static_patches/Aspect_degrees.npz:
  - Total patches: 2051
  - Valid patches (NaN% <= 30.0%): 1677
  - Invalid patches: 374
  - Highest NaN%: 100.00%
  - Average NaN%: 18.26%

NaN Percentage Distribution:
   0-10 %: 1665 patches ( 81.2%) ████████████████████████████████████████
  10-20 %:    0 patches (  0.0%) 
  20-30 %:   12 patches (  0.6%) 
  30-40 %:    0 patches (  0.0%) 
  40-50 %:    3 patches (  0.1%) 
  50-60 %:    0 patches (  0.0%) 
  60-70 %:    1 patches (  0.0%) 
  70-80 %:    0 patches (  0.0%) 
  80-90 %:    2 patches (  0.1%) 
  90-100%:    0 patches (  0.0%) 

Checking for inconsistencies across static features...
Inconsistencies detected between static features:
  - Different NaN patterns between Aspect_d

In [17]:
import os
import numpy as np
import xarray as xr
from tqdm.auto import tqdm

# Define the static features to analyze
STATIC_FEATURES = [
    {"path": "/home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Aspect_nc_005_Aus.nc", "var": "aspect_degrees", "name": "Aspect_degrees"},
    {"path": "/home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Distance_to_major_roads_nc_Aus_005.nc", "var": "distance", "name": "Distance_to_major_roads"},
    {"path": "/home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Distance_to_railway_line_nc_Aus_005.nc", "var": "distance", "name": "Distance_to_railway_line"},
    {"path": "/home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Distance_to_transmission_line_nc_Aus_005.nc", "var": "distance", "name": "Distance_to_transmission_line"},
    {"path": "/home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Elevation_nc_005_Aus.nc", "var": "Elevation", "name": "Elevation"},
    {"path": "/home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Global_Landform_nc_Aus_005.nc", "var": "Landform", "name": "Global_Landform"},
    {"path": "/home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Iwashi_nc_Aus_005.nc", "var": "Landform", "name": "Iwashi_Landform"},
    {"path": "/home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Population_nc_Aus_005.nc", "var": "population", "name": "Population"},
    {"path": "/home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Slope_nc_Aus_005.nc", "var": "slope", "name": "Slope"},
    {"path": "/home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/Vegetation_nc_Aus_005.nc", "var": "Vegetation_Type", "name": "Vegetation_Type"},
    {"path": "/home/ubuntu/Data-Seasonal-forecast/Switch1_Static_predictors/meybeck_nc_Aus_005.nc", "var": "Landform", "name": "meybeck_Landform"}
]

def analyze_nan_percentages():
    """
    Analyzes and prints the percentage of NaN values in each static feature file.
    """
    print("Analyzing NaN percentages in static feature files...\n")
    
    results = []
    
    for feature in tqdm(STATIC_FEATURES, desc="Processing Static Features"):
        feature_path = feature["path"]
        feature_var = feature["var"]
        feature_name = feature["name"]
        
        print(f"\nAnalyzing feature: {feature_name}")
        
        # Check if file exists
        if not os.path.exists(feature_path):
            print(f"  ERROR: File not found: {feature_path}")
            results.append({
                "name": feature_name, 
                "nan_pct": None, 
                "error": "File not found"
            })
            continue
        
        try:
            # Open dataset
            ds = xr.open_dataset(feature_path)
            
            # Check if variable exists
            if feature_var not in ds.data_vars:
                print(f"  ERROR: Variable '{feature_var}' not found in dataset")
                ds.close()
                results.append({
                    "name": feature_name, 
                    "nan_pct": None, 
                    "error": f"Variable '{feature_var}' not found"
                })
                continue
            
            # Load the data
            print(f"  Loading data for '{feature_var}'...")
            data = ds[feature_var].load().values
            
            # Calculate NaN percentage
            total_elements = data.size
            nan_count = np.sum(np.isnan(data))
            nan_percentage = (nan_count / total_elements) * 100
            
            print(f"  Data shape: {data.shape}")
            print(f"  Total elements: {total_elements:,}")
            print(f"  NaN count: {nan_count:,}")
            print(f"  NaN percentage: {nan_percentage:.2f}%")
            
            # Add to results
            results.append({
                "name": feature_name,
                "shape": data.shape,
                "total_elements": total_elements,
                "nan_count": nan_count,
                "nan_pct": nan_percentage,
                "error": None
            })
            
            # Close the dataset
            ds.close()
            
            # Get min/max values (excluding NaNs)
            valid_data = data[~np.isnan(data)]
            if len(valid_data) > 0:
                min_val = np.min(valid_data)
                max_val = np.max(valid_data)
                unique_vals = len(np.unique(valid_data))
                print(f"  Value range: [{min_val}, {max_val}]")
                print(f"  Number of unique values: {unique_vals:,}")
                
                # Store additional statistics
                results[-1]["min_val"] = min_val
                results[-1]["max_val"] = max_val
                results[-1]["unique_vals"] = unique_vals
            else:
                print("  WARNING: No valid data (all values are NaN)")
                
            # Clean up
            del data
            
        except Exception as e:
            print(f"  ERROR: Failed to process file: {e}")
            results.append({
                "name": feature_name, 
                "nan_pct": None, 
                "error": str(e)
            })
    
    # Print summary table
    print("\n" + "="*80)
    print(f"{'Feature Name':<30} {'Shape':<15} {'NaN %':<10} {'Value Range':<20} {'Unique Values':<15}")
    print("-"*80)
    
    for result in results:
        name = result["name"]
        
        if result["error"] is not None:
            print(f"{name:<30} ERROR: {result['error']}")
        else:
            shape_str = "x".join(map(str, result["shape"]))
            nan_pct = f"{result['nan_pct']:.2f}%" if result["nan_pct"] is not None else "N/A"
            value_range = f"[{result.get('min_val', 'N/A')}, {result.get('max_val', 'N/A')}]"
            unique_vals = f"{result.get('unique_vals', 'N/A'):,}" if "unique_vals" in result else "N/A"
            
            print(f"{name:<30} {shape_str:<15} {nan_pct:<10} {value_range:<20} {unique_vals:<15}")
    
    print("="*80)
    
    # Calculate overall statistics
    valid_results = [r for r in results if r["nan_pct"] is not None]
    if valid_results:
        avg_nan_pct = sum(r["nan_pct"] for r in valid_results) / len(valid_results)
        max_nan_pct = max(r["nan_pct"] for r in valid_results)
        min_nan_pct = min(r["nan_pct"] for r in valid_results)
        
        print(f"\nOverall Statistics:")
        print(f"  Average NaN percentage across all features: {avg_nan_pct:.2f}%")
        print(f"  Maximum NaN percentage: {max_nan_pct:.2f}%")
        print(f"  Minimum NaN percentage: {min_nan_pct:.2f}%")
        
        # Feature with highest NaN percentage
        highest_nan_feature = max(valid_results, key=lambda r: r["nan_pct"])
        print(f"  Feature with highest NaN percentage: {highest_nan_feature['name']} ({highest_nan_feature['nan_pct']:.2f}%)")
        
        # Feature with lowest NaN percentage
        lowest_nan_feature = min(valid_results, key=lambda r: r["nan_pct"])
        print(f"  Feature with lowest NaN percentage: {lowest_nan_feature['name']} ({lowest_nan_feature['nan_pct']:.2f}%)")
    
    # Check for suspicious patterns in NaN distributions
    print("\nChecking for suspicious NaN patterns...")
    
    # Categorize features by NaN percentage
    high_nan_features = [r["name"] for r in valid_results if r["nan_pct"] > 50]
    if high_nan_features:
        print(f"  WARNING: {len(high_nan_features)} features have more than 50% NaN values:")
        for name in high_nan_features:
            print(f"    - {name}")
    
    # Look for outliers in NaN percentages
    if len(valid_results) > 3:
        nan_percentages = [r["nan_pct"] for r in valid_results]
        avg = sum(nan_percentages) / len(nan_percentages)
        std_dev = np.std(nan_percentages)
        
        outliers = []
        for result in valid_results:
            if abs(result["nan_pct"] - avg) > 2 * std_dev:
                outliers.append(result["name"])
        
        if outliers:
            print(f"  NOTE: {len(outliers)} features have NaN percentages that are statistical outliers:")
            for name in outliers:
                feature_result = next(r for r in valid_results if r["name"] == name)
                print(f"    - {name}: {feature_result['nan_pct']:.2f}%")

if __name__ == "__main__":
    analyze_nan_percentages()

Analyzing NaN percentages in static feature files...



Processing Static Features:   0%|          | 0/11 [00:00<?, ?it/s]


Analyzing feature: Aspect_degrees
  Loading data for 'aspect_degrees'...
  Data shape: (6800, 9000)
  Total elements: 61,200,000
  NaN count: 28,668,148
  NaN percentage: 46.84%


Processing Static Features:   9%|▉         | 1/11 [00:00<00:08,  1.11it/s]

  Value range: [0.0, 359.892]
  Number of unique values: 142,201

Analyzing feature: Distance_to_major_roads
  Loading data for 'distance'...
  Data shape: (6800, 9000)
  Total elements: 61,200,000
  NaN count: 0
  NaN percentage: 0.00%


Processing Static Features:  18%|█▊        | 2/11 [00:03<00:14,  1.63s/it]

  Value range: [0.0, 1387.126801538223]
  Number of unique values: 58,924,284

Analyzing feature: Distance_to_railway_line
  Loading data for 'distance'...
  Data shape: (6800, 9000)
  Total elements: 61,200,000
  NaN count: 0
  NaN percentage: 0.00%


Processing Static Features:  27%|██▋       | 3/11 [00:05<00:14,  1.86s/it]

  Value range: [0.0, 1385.1285348620188]
  Number of unique values: 60,469,973

Analyzing feature: Distance_to_transmission_line
  Loading data for 'distance'...
  Data shape: (6800, 9000)
  Total elements: 61,200,000
  NaN count: 0
  NaN percentage: 0.00%


Processing Static Features:  36%|███▋      | 4/11 [00:07<00:13,  1.96s/it]

  Value range: [0.0, 1385.038332245596]
  Number of unique values: 60,422,301

Analyzing feature: Elevation
  Loading data for 'Elevation'...
  Data shape: (6800, 9000)
  Total elements: 61,200,000
  NaN count: 28,632,801
  NaN percentage: 46.79%


Processing Static Features:  45%|████▌     | 5/11 [00:08<00:09,  1.56s/it]

  Value range: [-60.0, 2172.0]
  Number of unique values: 2,108

Analyzing feature: Global_Landform
  Loading data for 'Landform'...


Processing Static Features:  55%|█████▍    | 6/11 [00:08<00:05,  1.14s/it]

  Data shape: (6800, 9000)
  Total elements: 61,200,000
  NaN count: 33,438,202
  NaN percentage: 54.64%
  Value range: [101.0, 242.0]
  Number of unique values: 9

Analyzing feature: Iwashi_Landform
  Loading data for 'Landform'...


Processing Static Features:  64%|██████▎   | 7/11 [00:08<00:03,  1.12it/s]

  Data shape: (6800, 9000)
  Total elements: 61,200,000
  NaN count: 33,527,006
  NaN percentage: 54.78%
  Value range: [1.0, 16.0]
  Number of unique values: 16

Analyzing feature: Population
  Loading data for 'population'...
  Data shape: (6800, 9000)
  Total elements: 61,200,000
  NaN count: 0
  NaN percentage: 0.00%


Processing Static Features:  73%|███████▎  | 8/11 [00:09<00:02,  1.20it/s]

  Value range: [0.0, 32560.697265625]
  Number of unique values: 80,225

Analyzing feature: Slope
  Loading data for 'slope'...
  Data shape: (6800, 9000)
  Total elements: 61,200,000
  NaN count: 28,668,148
  NaN percentage: 46.84%


Processing Static Features:  82%|████████▏ | 9/11 [00:10<00:01,  1.22it/s]

  Value range: [0.0, 89.83]
  Number of unique values: 618

Analyzing feature: Vegetation_Type
  Loading data for 'Vegetation_Type'...
  Data shape: (6800, 9000)
  Total elements: 61,200,000
  NaN count: 0
  NaN percentage: 0.00%


Processing Static Features:  91%|█████████ | 10/11 [00:10<00:00,  1.42it/s]

  Value range: [0, 99]
  Number of unique values: 33

Analyzing feature: meybeck_Landform
  Loading data for 'Landform'...


Processing Static Features: 100%|██████████| 11/11 [00:11<00:00,  1.01s/it]

  Data shape: (6800, 9000)
  Total elements: 61,200,000
  NaN count: 33,496,632
  NaN percentage: 54.73%
  Value range: [1.0, 14.0]
  Number of unique values: 13

Feature Name                   Shape           NaN %      Value Range          Unique Values  
--------------------------------------------------------------------------------
Aspect_degrees                 6800x9000       46.84%     [0.0, 359.892]       142,201        
Distance_to_major_roads        6800x9000       0.00%      [0.0, 1387.126801538223] 58,924,284     
Distance_to_railway_line       6800x9000       0.00%      [0.0, 1385.1285348620188] 60,469,973     
Distance_to_transmission_line  6800x9000       0.00%      [0.0, 1385.038332245596] 60,422,301     
Elevation                      6800x9000       46.79%     [-60.0, 2172.0]      2,108          
Global_Landform                6800x9000       54.64%     [101.0, 242.0]       9              
Iwashi_Landform                6800x9000       54.78%     [1.0, 16.0]         




# Feature : Switch 2 Fuel

In [10]:
import os
import numpy as np
import xarray as xr
import gc
import time
import multiprocessing as mp
from tqdm.auto import tqdm


FEATURES = [
    {"path": "../Switch2_Fuel_predictors/AWRA-L_e0_monthly_005_Aus.nc", "var": "e0",    "name": "E0"},
    {"path": "../Switch2_Fuel_predictors/NDVI_monthly_005_Aus.nc",      "var": "NDVI", "name": "NDVI"},
    {"path": "../Switch2_Fuel_predictors/AWRA-L_ss_monthly_005_Aus.nc", "var": "SS",   "name": "SS"},
    {"path": "../Switch2_Fuel_predictors/AWRA-L_sd_monthly_005_Aus.nc", "var": "Sd",   "name": "Sd"},
    {"path": "../Switch2_Fuel_predictors/AWRA-L_s0_monthly_005_Aus.nc", "var": "S0",   "name": "S0"},
    {"path": "../Switch2_Fuel_predictors/AWRA-L_qtot_monthly_005_Aus.nc","var": "qtot","name": "qtot"},
    {"path": "../Switch2_Fuel_predictors/AWRA-L_etot_monthly_005_Aus.nc","var": "etot","name": "etot"}
]

START_YEAR = 2005
END_YEAR   = 2022

def generate_required_bands(start_year, end_year):
    """
    Builds a list of required band strings for each month from start_year to end_year.
    """
    required = []
    for year in range(start_year, end_year + 1):
        for month in range(1, 13):
            band_str = f"{year:04d}_{month:02d}_01"
            required.append(band_str)
    return required

def process_feature(args):
    """
    Process a single feature file. All arguments are passed as a single tuple to
    support multiprocessing.
    """
    feature_info, valid_patches_list, patch_size, required_bands, out_root, proc_idx = args
    
    feature_path = feature_info["path"]
    feature_var  = feature_info["var"]
    feature_name = feature_info["name"]
    
    
    progress_file = os.path.join(out_root, f"progress_{feature_name}.txt")
    with open(progress_file, 'w') as f:
        f.write(f"0/{END_YEAR-START_YEAR+1}\n")
    
    if not os.path.exists(feature_path):
        with open(progress_file, 'w') as f:
            f.write(f"ERROR: File not found\n")
        return None
    
    num_valid = len(valid_patches_list)
    
    try:
      
        ds = xr.open_dataset(feature_path)
        
        if feature_var not in ds.data_vars:
            with open(progress_file, 'w') as f:
                f.write(f"ERROR: Variable not found\n")
            ds.close()
            return None
        
    
        actual_bands = set(str(b) for b in ds["band"].values)
        
       
        missing_bands = [b for b in required_bands if b not in actual_bands]
        if missing_bands:
            with open(progress_file, 'w') as f:
                f.write(f"ERROR: Missing bands\n")
            ds.close()
            return None
        
       ement
        for i, year in enumerate(range(START_YEAR, END_YEAR + 1)):
            # Update progress file
            with open(progress_file, 'w') as f:
                f.write(f"{i}/{END_YEAR-START_YEAR+1}\n")
            
   
            year_bands = [b for b in required_bands if b.startswith(f"{year}_")]
            
        
            months_in_year = len(year_bands)
            year_data = np.zeros((months_in_year, num_valid, patch_size, patch_size), dtype=np.float32)
            year_data.fill(np.nan)  
          
            for month_idx, band_str in enumerate(year_bands):
                try:
                   
                    band_data = ds[feature_var].sel(band=band_str).load().values
                    
                    # Extract all patches
                    for patch_idx, patch_info in enumerate(valid_patches_list):
                        lat_start = patch_info["lat_start"]
                        lat_end   = patch_info["lat_end"]
                        lon_start = patch_info["lon_start"]
                        lon_end   = patch_info["lon_end"]
                        
                        patch = band_data[lat_start:lat_end, lon_start:lon_end].copy()
                        year_data[month_idx, patch_idx] = patch
                    
                    # Clean up band data 
                    del band_data
                    
                except Exception as e:
                    with open(os.path.join(out_root, f"error_{feature_name}.txt"), 'a') as f:
                        f.write(f"Error processing band {band_str}: {e}\n")
                    continue
            
         
            year_dir = os.path.join(out_root, str(year))
            os.makedirs(year_dir, exist_ok=True)
            
     
            out_file = os.path.join(year_dir, f"{feature_name}.npz")
            
      
            metadata = {
                "feature_name": feature_name,
                "feature_var": feature_var,
                "year": year,
                "bands": year_bands,
                "patch_size": patch_size,
                "num_patches": num_valid
            }
            
        
            patch_indices = {}
            for i, patch_info in enumerate(valid_patches_list):
                patch_i = patch_info["patch_i"]
                patch_j = patch_info["patch_j"]
                patch_indices[f"{patch_i}_{patch_j}"] = i
            
            
            np.savez_compressed(
                out_file,
                data=year_data,
                metadata=metadata,
                patch_indices=patch_indices
            )
            
       
            del year_data
            gc.collect()
            
        ds.close()
        
       
        with open(progress_file, 'w') as f:
            f.write(f"COMPLETE\n")
            
        return feature_name
        
    except Exception as e:
        with open(os.path.join(out_root, f"error_{feature_name}.txt"), 'a') as f:
            f.write(f"Unexpected error: {e}\n")
        return None

def monitor_progress(out_root, features, total_years):
    """
    Monitor progress of all processes and display in a single progress bar.
    """
    features_done = set()
    
    with tqdm(total=len(features) * total_years, desc="Overall Progress") as pbar:
        previous_progress = 0
        
        while len(features_done) < len(features):
            current_progress = 0
            
            for feature in features:
                feature_name = feature["name"]
                if feature_name in features_done:
                    current_progress += total_years
                    continue
                    
                progress_file = os.path.join(out_root, f"progress_{feature_name}.txt")
                if os.path.exists(progress_file):
                    try:
                        with open(progress_file, 'r') as f:
                            content = f.read().strip()
                            if content == "COMPLETE":
                                features_done.add(feature_name)
                                current_progress += total_years
                            elif content.startswith("ERROR"):
                                features_done.add(feature_name)  
                            else:
                                try:
                                    progress, total = content.split('/')
                                    current_progress += int(progress)
                                except:
                                    pass
                    except:
                        pass
            
            # Update progress bar
            pbar.update(current_progress - previous_progress)
            previous_progress = current_progress
            
         
            time.sleep(0.5)
            
            if len(features_done) == len(features):
                # One final update
                pbar.update(len(features) * total_years - previous_progress)
                break

def main():
    start_time = time.time()
    print("Loading valid patches metadata (with <=30% NaNs)...")
    data = np.load("valid_patches_under30.npz", allow_pickle=True)
    
 
    valid_patches_struct = data["valid_patches"]
    valid_patches_list = []
    
    for i in range(len(valid_patches_struct)):
        patch = {}
        for field in valid_patches_struct.dtype.names:
            patch[field] = valid_patches_struct[field][i]
        valid_patches_list.append(patch)
    
    patch_size = int(data["patch_size"])
    lat_dim    = int(data["lat_dim"])
    lon_dim    = int(data["lon_dim"])
    threshold  = float(data["threshold"])
    
    num_valid = len(valid_patches_list)
    print(f"Found {num_valid} valid patches in 'valid_patches_under30.npz'")
    print(f"Patch size: {patch_size}, domain: {lat_dim}x{lon_dim}, threshold: {threshold}%")

  
    required_bands = generate_required_bands(START_YEAR, END_YEAR)
    num_bands = len(required_bands)
    print(f"Generated {num_bands} required band strings from {START_YEAR} to {END_YEAR}")

   
    out_root = "fire"
    os.makedirs(out_root, exist_ok=True)
    
 
    for year in range(START_YEAR, END_YEAR + 1):
        os.makedirs(os.path.join(out_root, str(year)), exist_ok=True)
    
  
    num_processes = min(4, len(FEATURES))
    print(f"Using {num_processes} parallel processes for {len(FEATURES)} features")
    
 
    process_args = []
    for i, feature in enumerate(FEATURES):
        args = (
            feature,               
            valid_patches_list,  
            patch_size,           
            required_bands,     
            out_root,            
            i                      
        )
        process_args.append(args)
    

    pool = mp.Pool(processes=num_processes)
    print("Starting parallel processing of features...")
    
 
    pool.map_async(process_feature, process_args)
    

    total_years = END_YEAR - START_YEAR + 1
    monitor_progress(out_root, FEATURES, total_years)
    

    pool.close()
    pool.join()
    
    
    completed_features = []
    for feature in FEATURES:
        feature_name = feature["name"]
        progress_file = os.path.join(out_root, f"progress_{feature_name}.txt")
        if os.path.exists(progress_file):
            with open(progress_file, 'r') as f:
                content = f.read().strip()
                if content == "COMPLETE":
                    completed_features.append(feature_name)
    
    print(f"\nSuccessfully processed {len(completed_features)} features: {', '.join(completed_features)}")
    
    end_time = time.time()
    elapsed = end_time - start_time
    print(f"\nAll features processed. Total time: {elapsed:.2f} seconds ({elapsed/60:.2f} minutes)")

if __name__ == "__main__":
  
    mp.freeze_support()
    main()

Loading valid patches metadata (with <=30% NaNs)...
Found 2051 valid patches in 'valid_patches_under30.npz'
Patch size: 128, domain: 6800x9000, threshold: 30.0%
Generated 216 required band strings from 2005 to 2022
Using 4 parallel processes for 7 features
Starting parallel processing of features...


Overall Progress: 100%|██████████| 126/126 [16:44<00:00,  7.97s/it]


Successfully processed 7 features: E0, NDVI, SS, Sd, S0, qtot, etot

All features processed. Total time: 1004.59 seconds (16.74 minutes)





## Validation of Patches : Switch 2 Fuel

In [12]:
import os
import numpy as np
import time
from tqdm.auto import tqdm


FEATURES = ["E0", "NDVI", "SS", "Sd", "S0", "qtot", "etot"]
START_YEAR = 2005
END_YEAR = 2022
PATCH_SIZE = 128
OUT_ROOT = "fire"

def validate_data_structure():
    """
    Validates that all expected files exist and have the correct structure.
    """
    print("Starting validation of processed data...")
    start_time = time.time()
    
  
    valid_patches_data = np.load("valid_patches_under30.npz", allow_pickle=True)
    valid_patches = valid_patches_data["valid_patches"]
    expected_num_patches = len(valid_patches)
    expected_patch_size = int(valid_patches_data["patch_size"])
    threshold = float(valid_patches_data["threshold"])
    
    print(f"Original metadata: {expected_num_patches} valid patches, "
          f"patch size: {expected_patch_size}x{expected_patch_size}, threshold: {threshold}%")
    
    # Check all required directories exist
    years = list(range(START_YEAR, END_YEAR + 1))
    missing_dirs = []
    
    for year in years:
        year_dir = os.path.join(OUT_ROOT, str(year))
        if not os.path.exists(year_dir):
            missing_dirs.append(year_dir)
    
    if missing_dirs:
        print(f"WARNING: {len(missing_dirs)} year directories are missing!")
        for d in missing_dirs[:5]:  
            print(f"  - {d}")
    else:
        print(f"✓ All {len(years)} year directories exist.")
    
   
    expected_files = []
    for year in years:
        for feature in FEATURES:
            expected_files.append(os.path.join(OUT_ROOT, str(year), f"{feature}.npz"))
    
    missing_files = [f for f in expected_files if not os.path.exists(f)]
    
    if missing_files:
        print(f"WARNING: {len(missing_files)} data files are missing!")
        for f in missing_files[:5]:  # Show only the first few
            print(f"  - {f}")
    else:
        print(f"✓ All {len(expected_files)} expected data files exist.")
    

    num_to_sample = min(20, len(expected_files))
    sample_files = np.random.choice(
        [f for f in expected_files if os.path.exists(f)], 
        size=num_to_sample, 
        replace=False
    )
    
    validation_results = []
    
    print(f"Validating content of {num_to_sample} randomly selected files...")
    for file_path in tqdm(sample_files, desc="Validating Files"):
        result = validate_file_content(file_path, expected_num_patches, expected_patch_size, threshold)
        validation_results.append(result)
    
   
    success_count = sum(1 for r in validation_results if r["valid"])
    print(f"\nValidation Results: {success_count}/{len(validation_results)} files are valid.")
    
    if success_count < len(validation_results):
        print("\nIssues found:")
        for r in validation_results:
            if not r["valid"]:
                print(f"  - File: {r['file']}, Issue: {r['issue']}")
    
   
    print("\nPerforming detailed validation on one file...")
    sample_file = expected_files[0]
    if os.path.exists(sample_file):
        detailed_result = validate_file_detailed(sample_file, valid_patches, threshold)
        print(f"Detailed validation of {sample_file}:")
        print(f"  - Total patches: {detailed_result['total_patches']}")
        print(f"  - Valid patches (NaN% <= {threshold}%): {detailed_result['valid_patches']}")
        print(f"  - Invalid patches: {detailed_result['invalid_patches']}")
        print(f"  - Highest NaN%: {detailed_result['max_nan_pct']:.2f}%")
        print(f"  - Average NaN%: {detailed_result['avg_nan_pct']:.2f}%")
        
     
        if detailed_result['nan_percentages']:
            nan_percentages = detailed_result['nan_percentages']
            
          
            ranges = [(0, 10), (10, 20), (20, 30), (30, 40), (40, 50), 
                      (50, 60), (60, 70), (70, 80), (80, 90), (90, 100)]
            
            print("\nNaN Percentage Distribution:")
            for low, high in ranges:
                count = sum(1 for p in nan_percentages if low <= p < high)
                percentage = (count / len(nan_percentages)) * 100
                bar_length = int(percentage / 2)  # Scale for display
                bar = '█' * bar_length
                print(f"  {low:2d}-{high:<3d}%: {count:4d} patches ({percentage:5.1f}%) {bar}")
    
 
    print("\nValidating consistency across months...")
    

    test_feature = FEATURES[0]
    test_year = START_YEAR
    test_file = os.path.join(OUT_ROOT, str(test_year), f"{test_feature}.npz")
    
    if os.path.exists(test_file):
        consistency_result = validate_month_consistency(test_file, threshold)
        print(f"Month-to-month consistency in {test_file}:")
        print(f"  - Number of months: {consistency_result['num_months']}")
        print(f"  - Consistent shape: {consistency_result['consistent_shape']}")
        print(f"  - Month-to-month NaN pattern consistency: {consistency_result['nan_consistency']:.2f}%")
        print(f"  - Are all patches valid (NaN% <= {threshold}%)? {consistency_result['all_valid']}")
    
    end_time = time.time()
    elapsed = end_time - start_time
    print(f"\nValidation completed in {elapsed:.2f} seconds ({elapsed/60:.2f} minutes).")

def validate_file_content(file_path, expected_num_patches, expected_patch_size, threshold):
    """
    Validates the content of a single data file.
    """
    result = {
        "file": file_path,
        "valid": True,
        "issue": None
    }
    
    try:
     
        data = np.load(file_path, allow_pickle=True)
        
    
        required_keys = ['data', 'metadata', 'patch_indices']
        missing_keys = [k for k in required_keys if k not in data]
        if missing_keys:
            result["valid"] = False
            result["issue"] = f"Missing keys: {', '.join(missing_keys)}"
            return result
        
        # Check data shape
        data_array = data['data']
        if len(data_array.shape) != 4:
            result["valid"] = False
            result["issue"] = f"Expected 4D array, got {len(data_array.shape)}D"
            return result
        
        _, num_patches, patch_height, patch_width = data_array.shape
        
        if num_patches != expected_num_patches:
            result["valid"] = False
            result["issue"] = f"Expected {expected_num_patches} patches, got {num_patches}"
            return result
        
        if patch_height != expected_patch_size or patch_width != expected_patch_size:
            result["valid"] = False
            result["issue"] = f"Expected {expected_patch_size}x{expected_patch_size} patch size, got {patch_height}x{patch_width}"
            return result
        
 
        metadata = data['metadata'].item() if isinstance(data['metadata'], np.ndarray) else data['metadata']
        if not isinstance(metadata, dict):
            result["valid"] = False
            result["issue"] = f"Metadata is not a dictionary"
            return result
        
 
        patch_indices = data['patch_indices'].item() if isinstance(data['patch_indices'], np.ndarray) else data['patch_indices']
        if not isinstance(patch_indices, dict):
            result["valid"] = False
            result["issue"] = f"Patch indices is not a dictionary"
            return result
        
    except Exception as e:
        result["valid"] = False
        result["issue"] = f"Error loading/parsing file: {e}"
    
    return result

def validate_file_detailed(file_path, valid_patches, threshold):
    """
    Performs a detailed validation of one file, checking all patches.
    """
    result = {
        'total_patches': 0,
        'valid_patches': 0,
        'invalid_patches': 0,
        'max_nan_pct': 0,
        'avg_nan_pct': 0,
        'nan_percentages': []
    }
    
    try:
      
        data = np.load(file_path, allow_pickle=True)
        data_array = data['data']
        
        num_months, num_patches, patch_height, patch_width = data_array.shape
        result['total_patches'] = num_patches
        
 
        month_idx = 0  # First month
        
        nan_percentages = []
        for patch_idx in range(num_patches):
            patch = data_array[month_idx, patch_idx]
            nan_count = np.isnan(patch).sum()
            total_cells = patch_height * patch_width
            nan_percentage = (nan_count / total_cells) * 100
            
            nan_percentages.append(nan_percentage)
            
            if nan_percentage <= threshold:
                result['valid_patches'] += 1
            else:
                result['invalid_patches'] += 1
        
        result['nan_percentages'] = nan_percentages
        result['max_nan_pct'] = max(nan_percentages) if nan_percentages else 0
        result['avg_nan_pct'] = sum(nan_percentages) / len(nan_percentages) if nan_percentages else 0
        
    except Exception as e:
        print(f"Error in detailed validation: {e}")
    
    return result

def validate_month_consistency(file_path, threshold):
    """
    Validates consistency across all months in a file.
    """
    result = {
        'num_months': 0,
        'consistent_shape': True,
        'nan_consistency': 100.0,  
        'all_valid': True
    }
    
    try:
        # Load the file
        data = np.load(file_path, allow_pickle=True)
        data_array = data['data']
        
        num_months, num_patches, patch_height, patch_width = data_array.shape
        result['num_months'] = num_months
        

        for month_idx in range(num_months):
            month_data = data_array[month_idx]
            if month_data.shape != (num_patches, patch_height, patch_width):
                result['consistent_shape'] = False
                break
        

        reference_month = np.isnan(data_array[0])
        total_cells = num_patches * patch_height * patch_width
        
        for month_idx in range(1, num_months):
            current_month = np.isnan(data_array[month_idx])
            differences = np.sum(reference_month != current_month)
            if differences > 0:
    
                consistency_pct = 100 - (differences / total_cells * 100)
                result['nan_consistency'] = min(result['nan_consistency'], consistency_pct)
        
  
        for month_idx in range(num_months):
            for patch_idx in range(num_patches):
                patch = data_array[month_idx, patch_idx]
                nan_count = np.isnan(patch).sum()
                total_cells_patch = patch_height * patch_width
                nan_percentage = (nan_count / total_cells_patch) * 100
                
                if nan_percentage > threshold:
                    result['all_valid'] = False
                    break
            
            if not result['all_valid']:
                break
        
    except Exception as e:
        print(f"Error in month consistency validation: {e}")
    
    return result

if __name__ == "__main__":
  
    np.random.seed(42)
    validate_data_structure()

Starting validation of processed data...
Original metadata: 2051 valid patches, patch size: 128x128, threshold: 30.0%
✓ All 18 year directories exist.
✓ All 126 expected data files exist.
Validating content of 20 randomly selected files...


Validating Files: 100%|██████████| 20/20 [00:57<00:00,  2.87s/it]



Validation Results: 20/20 files are valid.

Performing detailed validation on one file...
Detailed validation of fire/2005/E0.npz:
  - Total patches: 2051
  - Valid patches (NaN% <= 30.0%): 2051
  - Invalid patches: 0
  - Highest NaN%: 29.98%
  - Average NaN%: 1.11%

NaN Percentage Distribution:
   0-10 %: 1940 patches ( 94.6%) ███████████████████████████████████████████████
  10-20 %:   83 patches (  4.0%) ██
  20-30 %:   28 patches (  1.4%) 
  30-40 %:    0 patches (  0.0%) 
  40-50 %:    0 patches (  0.0%) 
  50-60 %:    0 patches (  0.0%) 
  60-70 %:    0 patches (  0.0%) 
  70-80 %:    0 patches (  0.0%) 
  80-90 %:    0 patches (  0.0%) 
  90-100%:    0 patches (  0.0%) 

Validating consistency across months...
Month-to-month consistency in fire/2005/E0.npz:
  - Number of months: 12
  - Consistent shape: True
  - Month-to-month NaN pattern consistency: 100.00%
  - Are all patches valid (NaN% <= 30.0%)? True

Validation completed in 63.08 seconds (1.05 minutes).


#  Feature : Switch 3 Climate

In [13]:
import os
import numpy as np
import xarray as xr
import gc
import time
import multiprocessing as mp
from tqdm.auto import tqdm


FEATURES = [
    {"path": "../Switch3_Climate_predictors/Llightning_nc_005_Aus.nc", "var": "lightning", "name": "Lightning"},
    {"path": "../Switch3_Climate_predictors/precip_monthly_005_Aus.nc", "var": "precip",   "name": "Precipitation"},
    {"path": "../Switch3_Climate_predictors/tmax_monthly_005_Aus.nc",   "var": "tmax",     "name": "Maximum_Temperature"},
    {"path": "../Switch3_Climate_predictors/tmin_monthly_005_Aus.nc",   "var": "tmin",     "name": "Minimum_Temperature"},
    {"path": "../Switch3_Climate_predictors/vapourpresh09_monthly_005_Aus.nc", "var": "vapourpres", "name": "Vapor_Pressure_09"},
    {"path": "../Switch3_Climate_predictors/vapourpresh15_monthly_005_Aus.nc", "var": "vapourpres", "name": "Vapor_Pressure_15"}
]

START_YEAR = 2005
END_YEAR   = 2022

def generate_required_bands(start_year, end_year):
    """
    Builds a list of required band strings for each month from start_year to end_year.
    """
    required = []
    for year in range(start_year, end_year + 1):
        for month in range(1, 13):
            band_str = f"{year:04d}_{month:02d}_01"
            required.append(band_str)
    return required

def process_feature(args):
    """
    Process a single feature file. All arguments are passed as a single tuple to
    support multiprocessing.
    """
    feature_info, valid_patches_list, patch_size, required_bands, out_root, proc_idx = args
    
    feature_path = feature_info["path"]
    feature_var  = feature_info["var"]
    feature_name = feature_info["name"]
    
 
    progress_file = os.path.join(out_root, f"progress_{feature_name}.txt")
    with open(progress_file, 'w') as f:
        f.write(f"0/{END_YEAR-START_YEAR+1}\n")
    
    if not os.path.exists(feature_path):
        with open(progress_file, 'w') as f:
            f.write(f"ERROR: File not found\n")
        return None
    
    num_valid = len(valid_patches_list)
    
    try:
       
        ds = xr.open_dataset(feature_path)
        
        if feature_var not in ds.data_vars:
            with open(progress_file, 'w') as f:
                f.write(f"ERROR: Variable not found\n")
            ds.close()
            return None
        

        actual_bands = set(str(b) for b in ds["band"].values)
        

        missing_bands = [b for b in required_bands if b not in actual_bands]
        if missing_bands:
            with open(progress_file, 'w') as f:
                f.write(f"ERROR: Missing bands\n")
            ds.close()
            return None
        

        for i, year in enumerate(range(START_YEAR, END_YEAR + 1)):
            # Update progress file
            with open(progress_file, 'w') as f:
                f.write(f"{i}/{END_YEAR-START_YEAR+1}\n")
            
        
            year_bands = [b for b in required_bands if b.startswith(f"{year}_")]
            
       
            months_in_year = len(year_bands)
            year_data = np.zeros((months_in_year, num_valid, patch_size, patch_size), dtype=np.float32)
            year_data.fill(np.nan)  # Initialize with NaNs
            
        
            for month_idx, band_str in enumerate(year_bands):
                try:
                
                    band_data = ds[feature_var].sel(band=band_str).load().values
                    
           
                    for patch_idx, patch_info in enumerate(valid_patches_list):
                        lat_start = patch_info["lat_start"]
                        lat_end   = patch_info["lat_end"]
                        lon_start = patch_info["lon_start"]
                        lon_end   = patch_info["lon_end"]
                        
                        patch = band_data[lat_start:lat_end, lon_start:lon_end].copy()
                        year_data[month_idx, patch_idx] = patch
                    
           
                    del band_data
                    
                except Exception as e:
                    with open(os.path.join(out_root, f"error_{feature_name}.txt"), 'a') as f:
                        f.write(f"Error processing band {band_str}: {e}\n")
                    continue
            
    
            year_dir = os.path.join(out_root, str(year))
            os.makedirs(year_dir, exist_ok=True)
            
 
            out_file = os.path.join(year_dir, f"{feature_name}.npz")
            
     
            metadata = {
                "feature_name": feature_name,
                "feature_var": feature_var,
                "year": year,
                "bands": year_bands,
                "patch_size": patch_size,
                "num_patches": num_valid
            }
            
     
            patch_indices = {}
            for i, patch_info in enumerate(valid_patches_list):
                patch_i = patch_info["patch_i"]
                patch_j = patch_info["patch_j"]
                patch_indices[f"{patch_i}_{patch_j}"] = i
            
    
            np.savez_compressed(
                out_file,
                data=year_data,
                metadata=metadata,
                patch_indices=patch_indices
            )
            
    
            del year_data
            gc.collect()
            
        ds.close()
        

        with open(progress_file, 'w') as f:
            f.write(f"COMPLETE\n")
            
        return feature_name
        
    except Exception as e:
        with open(os.path.join(out_root, f"error_{feature_name}.txt"), 'a') as f:
            f.write(f"Unexpected error: {e}\n")
        return None

def monitor_progress(out_root, features, total_years):
    """
    Monitor progress of all processes and display in a single progress bar.
    """
    features_done = set()
    
    with tqdm(total=len(features) * total_years, desc="Overall Progress") as pbar:
        previous_progress = 0
        
        while len(features_done) < len(features):
            current_progress = 0
            
            for feature in features:
                feature_name = feature["name"]
                if feature_name in features_done:
                    current_progress += total_years
                    continue
                    
                progress_file = os.path.join(out_root, f"progress_{feature_name}.txt")
                if os.path.exists(progress_file):
                    try:
                        with open(progress_file, 'r') as f:
                            content = f.read().strip()
                            if content == "COMPLETE":
                                features_done.add(feature_name)
                                current_progress += total_years
                            elif content.startswith("ERROR"):
                                features_done.add(feature_name)  # Count as done but errored
                            else:
                                try:
                                    progress, total = content.split('/')
                                    current_progress += int(progress)
                                except:
                                    pass
                    except:
                        pass
            
            # Update progress bar
            pbar.update(current_progress - previous_progress)
            previous_progress = current_progress
            
            # Don't update too frequently
            time.sleep(0.5)
            
            if len(features_done) == len(features):
                # One final update
                pbar.update(len(features) * total_years - previous_progress)
                break

def main():
    start_time = time.time()
    print("Loading valid patches metadata (with <=30% NaNs)...")
    data = np.load("valid_patches_under30.npz", allow_pickle=True)
    
 
    valid_patches_struct = data["valid_patches"]
    valid_patches_list = []
    
    for i in range(len(valid_patches_struct)):
        patch = {}
        for field in valid_patches_struct.dtype.names:
            patch[field] = valid_patches_struct[field][i]
        valid_patches_list.append(patch)
    
    patch_size = int(data["patch_size"])
    lat_dim    = int(data["lat_dim"])
    lon_dim    = int(data["lon_dim"])
    threshold  = float(data["threshold"])
    
    num_valid = len(valid_patches_list)
    print(f"Found {num_valid} valid patches in 'valid_patches_under30.npz'")
    print(f"Patch size: {patch_size}, domain: {lat_dim}x{lon_dim}, threshold: {threshold}%")


    required_bands = generate_required_bands(START_YEAR, END_YEAR)
    num_bands = len(required_bands)
    print(f"Generated {num_bands} required band strings from {START_YEAR} to {END_YEAR}")


    out_root = "climate"
    os.makedirs(out_root, exist_ok=True)
    
 
    for year in range(START_YEAR, END_YEAR + 1):
        os.makedirs(os.path.join(out_root, str(year)), exist_ok=True)
    

    num_processes = min(4, len(FEATURES))
    print(f"Using {num_processes} parallel processes for {len(FEATURES)} features")
    

    process_args = []
    for i, feature in enumerate(FEATURES):
        args = (
            feature,                # Feature info
            valid_patches_list,     # Valid patches
            patch_size,             # Patch size
            required_bands,         # Required bands
            out_root,               # Output root directory
            i                       # Process index
        )
        process_args.append(args)
    

    pool = mp.Pool(processes=num_processes)
    print("Starting parallel processing of features...")
    
  
    pool.map_async(process_feature, process_args)
    
  
    total_years = END_YEAR - START_YEAR + 1
    monitor_progress(out_root, FEATURES, total_years)
    

    pool.close()
    pool.join()
    

    completed_features = []
    for feature in FEATURES:
        feature_name = feature["name"]
        progress_file = os.path.join(out_root, f"progress_{feature_name}.txt")
        if os.path.exists(progress_file):
            with open(progress_file, 'r') as f:
                content = f.read().strip()
                if content == "COMPLETE":
                    completed_features.append(feature_name)
    
    print(f"\nSuccessfully processed {len(completed_features)} features: {', '.join(completed_features)}")
    
    end_time = time.time()
    elapsed = end_time - start_time
    print(f"\nAll features processed. Total time: {elapsed:.2f} seconds ({elapsed/60:.2f} minutes)")

if __name__ == "__main__":
  
    mp.freeze_support()
    main()

Loading valid patches metadata (with <=30% NaNs)...
Found 2051 valid patches in 'valid_patches_under30.npz'
Patch size: 128, domain: 6800x9000, threshold: 30.0%
Generated 216 required band strings from 2005 to 2022
Using 4 parallel processes for 6 features
Starting parallel processing of features...


Overall Progress: 100%|██████████| 108/108 [30:47<00:00, 17.11s/it]


Successfully processed 6 features: Lightning, Precipitation, Maximum_Temperature, Minimum_Temperature, Vapor_Pressure_09, Vapor_Pressure_15

All features processed. Total time: 1848.04 seconds (30.80 minutes)





In [14]:
import os
import numpy as np
import time
from tqdm.auto import tqdm

# Define the parameters that should match our processed climate data
FEATURES = ["Lightning", "Precipitation", "Maximum_Temperature", "Minimum_Temperature", 
            "Vapor_Pressure_09", "Vapor_Pressure_15"]
START_YEAR = 2005
END_YEAR = 2022
PATCH_SIZE = 128
OUT_ROOT = "climate"

def validate_data_structure():
    """
    Validates that all expected files exist and have the correct structure.
    """
    print("Starting validation of processed climate data...")
    start_time = time.time()
    
    # Load the original valid patches metadata for verification
    valid_patches_data = np.load("valid_patches_under30.npz", allow_pickle=True)
    valid_patches = valid_patches_data["valid_patches"]
    expected_num_patches = len(valid_patches)
    expected_patch_size = int(valid_patches_data["patch_size"])
    threshold = float(valid_patches_data["threshold"])
    
    print(f"Original metadata: {expected_num_patches} valid patches, "
          f"patch size: {expected_patch_size}x{expected_patch_size}, threshold: {threshold}%")
    
    # Check all required directories exist
    years = list(range(START_YEAR, END_YEAR + 1))
    missing_dirs = []
    
    for year in years:
        year_dir = os.path.join(OUT_ROOT, str(year))
        if not os.path.exists(year_dir):
            missing_dirs.append(year_dir)
    
    if missing_dirs:
        print(f"WARNING: {len(missing_dirs)} year directories are missing!")
        for d in missing_dirs[:5]:  # Show only the first few
            print(f"  - {d}")
    else:
        print(f"✓ All {len(years)} year directories exist.")
    
    # Verify all npz files exist
    expected_files = []
    for year in years:
        for feature in FEATURES:
            expected_files.append(os.path.join(OUT_ROOT, str(year), f"{feature}.npz"))
    
    missing_files = [f for f in expected_files if not os.path.exists(f)]
    
    if missing_files:
        print(f"WARNING: {len(missing_files)} data files are missing!")
        for f in missing_files[:5]:  # Show only the first few
            print(f"  - {f}")
    else:
        print(f"✓ All {len(expected_files)} expected data files exist.")
    
    # Validate content of files (sample a subset for efficiency)
    num_to_sample = min(20, len(expected_files))
    sample_files = np.random.choice(
        [f for f in expected_files if os.path.exists(f)], 
        size=num_to_sample, 
        replace=False
    )
    
    validation_results = []
    
    print(f"Validating content of {num_to_sample} randomly selected files...")
    for file_path in tqdm(sample_files, desc="Validating Files"):
        result = validate_file_content(file_path, expected_num_patches, expected_patch_size, threshold)
        validation_results.append(result)
    
    # Summarize validation results
    success_count = sum(1 for r in validation_results if r["valid"])
    print(f"\nValidation Results: {success_count}/{len(validation_results)} files are valid.")
    
    if success_count < len(validation_results):
        print("\nIssues found:")
        for r in validation_results:
            if not r["valid"]:
                print(f"  - File: {r['file']}, Issue: {r['issue']}")
    
    # Perform a deeper validation on one file to check all patches
    print("\nPerforming detailed validation on one file...")
    sample_file = expected_files[0]
    if os.path.exists(sample_file):
        detailed_result = validate_file_detailed(sample_file, valid_patches, threshold)
        print(f"Detailed validation of {sample_file}:")
        print(f"  - Total patches: {detailed_result['total_patches']}")
        print(f"  - Valid patches (NaN% <= {threshold}%): {detailed_result['valid_patches']}")
        print(f"  - Invalid patches: {detailed_result['invalid_patches']}")
        print(f"  - Highest NaN%: {detailed_result['max_nan_pct']:.2f}%")
        print(f"  - Average NaN%: {detailed_result['avg_nan_pct']:.2f}%")
        
        # Print NaN percentage distribution in text form instead of histogram
        if detailed_result['nan_percentages']:
            nan_percentages = detailed_result['nan_percentages']
            
            # Define ranges for text-based distribution
            ranges = [(0, 10), (10, 20), (20, 30), (30, 40), (40, 50), 
                      (50, 60), (60, 70), (70, 80), (80, 90), (90, 100)]
            
            print("\nNaN Percentage Distribution:")
            for low, high in ranges:
                count = sum(1 for p in nan_percentages if low <= p < high)
                percentage = (count / len(nan_percentages)) * 100
                bar_length = int(percentage / 2)  # Scale for display
                bar = '█' * bar_length
                print(f"  {low:2d}-{high:<3d}%: {count:4d} patches ({percentage:5.1f}%) {bar}")
    
    # Validate consistency across months
    print("\nValidating consistency across months...")
    
    # Select one feature and year for consistency check
    test_feature = FEATURES[0]
    test_year = START_YEAR
    test_file = os.path.join(OUT_ROOT, str(test_year), f"{test_feature}.npz")
    
    if os.path.exists(test_file):
        consistency_result = validate_month_consistency(test_file, threshold)
        print(f"Month-to-month consistency in {test_file}:")
        print(f"  - Number of months: {consistency_result['num_months']}")
        print(f"  - Consistent shape: {consistency_result['consistent_shape']}")
        print(f"  - Month-to-month NaN pattern consistency: {consistency_result['nan_consistency']:.2f}%")
        print(f"  - Are all patches valid (NaN% <= {threshold}%)? {consistency_result['all_valid']}")
    
    # Additional check for climate data: analyze temporal variation
    if os.path.exists(test_file):
        print("\nAnalyzing temporal variation in climate data...")
        temporal_result = analyze_temporal_variation(test_file)
        print(f"Temporal variation in {test_file}:")
        print(f"  - Mean month-to-month change: {temporal_result['mean_change']:.4f}")
        print(f"  - Maximum month-to-month change: {temporal_result['max_change']:.4f}")
        print(f"  - Standard deviation of values: {temporal_result['std_dev']:.4f}")
        
        # Check for seasonal patterns (simple check)
        if temporal_result['has_seasonal_pattern']:
            print("  - Detected likely seasonal pattern in the data")
        else:
            print("  - No clear seasonal pattern detected")
    
    end_time = time.time()
    elapsed = end_time - start_time
    print(f"\nValidation completed in {elapsed:.2f} seconds ({elapsed/60:.2f} minutes).")

def validate_file_content(file_path, expected_num_patches, expected_patch_size, threshold):
    """
    Validates the content of a single data file.
    """
    result = {
        "file": file_path,
        "valid": True,
        "issue": None
    }
    
    try:
        # Load the file
        data = np.load(file_path, allow_pickle=True)
        
        # Check if required keys exist
        required_keys = ['data', 'metadata', 'patch_indices']
        missing_keys = [k for k in required_keys if k not in data]
        if missing_keys:
            result["valid"] = False
            result["issue"] = f"Missing keys: {', '.join(missing_keys)}"
            return result
        
        # Check data shape
        data_array = data['data']
        if len(data_array.shape) != 4:
            result["valid"] = False
            result["issue"] = f"Expected 4D array, got {len(data_array.shape)}D"
            return result
        
        # Climate data typically has shape [months, patches, height, width]
        num_months, num_patches, patch_height, patch_width = data_array.shape
        
        if num_patches != expected_num_patches:
            result["valid"] = False
            result["issue"] = f"Expected {expected_num_patches} patches, got {num_patches}"
            return result
        
        if patch_height != expected_patch_size or patch_width != expected_patch_size:
            result["valid"] = False
            result["issue"] = f"Expected {expected_patch_size}x{expected_patch_size} patch size, got {patch_height}x{patch_width}"
            return result
        
        # Check metadata structure
        metadata = data['metadata'].item() if isinstance(data['metadata'], np.ndarray) else data['metadata']
        if not isinstance(metadata, dict):
            result["valid"] = False
            result["issue"] = f"Metadata is not a dictionary"
            return result
        
        # Validate climate metadata - check if required fields exist
        required_metadata = ["feature_name", "feature_var", "year", "bands", "patch_size", "num_patches"]
        missing_metadata = [k for k in required_metadata if k not in metadata]
        if missing_metadata:
            result["valid"] = False
            result["issue"] = f"Missing metadata fields: {', '.join(missing_metadata)}"
            return result
        
        # Check patch indices structure
        patch_indices = data['patch_indices'].item() if isinstance(data['patch_indices'], np.ndarray) else data['patch_indices']
        if not isinstance(patch_indices, dict):
            result["valid"] = False
            result["issue"] = f"Patch indices is not a dictionary"
            return result
        
        # Check if months match expected value (should be 12 for climate data, per year)
        if num_months != 12:
            result["valid"] = False
            result["issue"] = f"Expected 12 months per year, got {num_months}"
            return result
        
        # Climate-specific: check for unrealistic values
        feature_name = os.path.basename(file_path).replace('.npz', '')
        unrealistic_values = check_unrealistic_values(data_array, feature_name)
        if unrealistic_values:
            result["valid"] = False
            result["issue"] = f"Contains unrealistic values: {unrealistic_values}"
            return result
        
    except Exception as e:
        result["valid"] = False
        result["issue"] = f"Error loading/parsing file: {e}"
    
    return result

def check_unrealistic_values(data_array, feature_name):
    """
    Check for unrealistic values based on the feature type.
    Returns description of problems found or None if values are reasonable.
    """
    # Flatten for easier analysis, ignoring NaNs
    valid_data = data_array[~np.isnan(data_array)]
    
    if len(valid_data) == 0:
        return "Data contains only NaN values"
    
    # Define reasonable ranges for different climate features
    if "Temperature" in feature_name:
        # Temperature in degrees Celsius, realistic range from -80 to +60
        if np.min(valid_data) < -80 or np.max(valid_data) > 60:
            return f"Temperature out of realistic range: min={np.min(valid_data):.1f}, max={np.max(valid_data):.1f}"
    
    elif "Precipitation" in feature_name:
        # Precipitation should be >= 0 and generally < 2000mm per month
        if np.min(valid_data) < 0:
            return f"Negative precipitation values: min={np.min(valid_data):.1f}"
        if np.max(valid_data) > 2000:
            return f"Extremely high precipitation: max={np.max(valid_data):.1f}"
    
    elif "Lightning" in feature_name:
        # Lightning counts should be >= 0
        if np.min(valid_data) < 0:
            return f"Negative lightning counts: min={np.min(valid_data):.1f}"
    
    elif "Vapor_Pressure" in feature_name:
        # Vapor pressure typically between 0-101.3 kPa (1 atm)
        if np.min(valid_data) < 0:
            return f"Negative vapor pressure: min={np.min(valid_data):.1f}"
        if np.max(valid_data) > 101.3:
            return f"Vapor pressure exceeds 1 atm: max={np.max(valid_data):.1f}"
    
    # Check for other anomalies
    if np.all(valid_data == valid_data[0]):
        return "Data contains only a single repeated value"
    
    # Check for NaN percentage
    nan_percentage = np.isnan(data_array).sum() / data_array.size * 100
    if nan_percentage > 50:
        return f"High percentage of NaN values: {nan_percentage:.1f}%"
    
    return None

def validate_file_detailed(file_path, valid_patches, threshold):
    """
    Performs a detailed validation of one file, checking all patches.
    """
    result = {
        'total_patches': 0,
        'valid_patches': 0,
        'invalid_patches': 0,
        'max_nan_pct': 0,
        'avg_nan_pct': 0,
        'nan_percentages': []
    }
    
    try:
        # Load the file
        data = np.load(file_path, allow_pickle=True)
        data_array = data['data']
        
        num_months, num_patches, patch_height, patch_width = data_array.shape
        result['total_patches'] = num_patches
        
        # Sample a month to check NaN percentages
        month_idx = 0  # First month
        
        nan_percentages = []
        for patch_idx in range(num_patches):
            patch = data_array[month_idx, patch_idx]
            nan_count = np.isnan(patch).sum()
            total_cells = patch_height * patch_width
            nan_percentage = (nan_count / total_cells) * 100
            
            nan_percentages.append(nan_percentage)
            
            if nan_percentage <= threshold:
                result['valid_patches'] += 1
            else:
                result['invalid_patches'] += 1
        
        result['nan_percentages'] = nan_percentages
        result['max_nan_pct'] = max(nan_percentages) if nan_percentages else 0
        result['avg_nan_pct'] = sum(nan_percentages) / len(nan_percentages) if nan_percentages else 0
        
    except Exception as e:
        print(f"Error in detailed validation: {e}")
    
    return result

def validate_month_consistency(file_path, threshold):
    """
    Validates consistency across all months in a file.
    """
    result = {
        'num_months': 0,
        'consistent_shape': True,
        'nan_consistency': 100.0,  # Percentage of consistent NaN patterns
        'all_valid': True
    }
    
    try:
        # Load the file
        data = np.load(file_path, allow_pickle=True)
        data_array = data['data']
        
        num_months, num_patches, patch_height, patch_width = data_array.shape
        result['num_months'] = num_months
        
        # Check shape consistency across months
        for month_idx in range(num_months):
            month_data = data_array[month_idx]
            if month_data.shape != (num_patches, patch_height, patch_width):
                result['consistent_shape'] = False
                break
        
        # Check NaN pattern consistency across months
        # We'll compare each month to the first month
        reference_month = np.isnan(data_array[0])
        total_cells = num_patches * patch_height * patch_width
        
        for month_idx in range(1, num_months):
            current_month = np.isnan(data_array[month_idx])
            differences = np.sum(reference_month != current_month)
            if differences > 0:
                # There are some differences in NaN patterns
                consistency_pct = 100 - (differences / total_cells * 100)
                result['nan_consistency'] = min(result['nan_consistency'], consistency_pct)
        
        # Check if all patches in all months are valid
        for month_idx in range(num_months):
            for patch_idx in range(num_patches):
                patch = data_array[month_idx, patch_idx]
                nan_count = np.isnan(patch).sum()
                total_cells_patch = patch_height * patch_width
                nan_percentage = (nan_count / total_cells_patch) * 100
                
                if nan_percentage > threshold:
                    result['all_valid'] = False
                    break
            
            if not result['all_valid']:
                break
        
    except Exception as e:
        print(f"Error in month consistency validation: {e}")
    
    return result

def analyze_temporal_variation(file_path):
    """
    Analyzes temporal variation in the climate data.
    This is a climate-specific function to check for realistic seasonal patterns.
    """
    result = {
        'mean_change': 0.0,
        'max_change': 0.0,
        'std_dev': 0.0,
        'has_seasonal_pattern': False
    }
    
    try:
        # Load the file
        data = np.load(file_path, allow_pickle=True)
        data_array = data['data']
        feature_name = data['metadata'].item()['feature_name']
        
        num_months, num_patches, patch_height, patch_width = data_array.shape
        
        # For efficiency, sample a subset of patches
        num_to_sample = min(50, num_patches)
        patch_indices = np.random.choice(num_patches, num_to_sample, replace=False)
        
        # Calculate month-to-month changes
        changes = []
        for patch_idx in patch_indices:
            # Average values for each month in this patch
            monthly_means = []
            for month_idx in range(num_months):
                patch = data_array[month_idx, patch_idx]
                valid_values = patch[~np.isnan(patch)]
                if len(valid_values) > 0:
                    monthly_means.append(np.mean(valid_values))
                else:
                    monthly_means.append(np.nan)
            
            # Calculate changes between consecutive months
            monthly_means = np.array(monthly_means)
            valid_months = ~np.isnan(monthly_means)
            if np.sum(valid_months) > 1:
                valid_means = monthly_means[valid_months]
                month_to_month = np.abs(np.diff(valid_means))
                if len(month_to_month) > 0:
                    changes.extend(month_to_month)
        
        if changes:
            result['mean_change'] = np.mean(changes)
            result['max_change'] = np.max(changes)
            
            # Calculate overall standard deviation of the data
            all_valid_data = data_array[~np.isnan(data_array)]
            if len(all_valid_data) > 0:
                result['std_dev'] = np.std(all_valid_data)
            
            # Simple check for seasonal patterns based on feature type
            # For temperature: we expect higher values in summer months (seasonal)
            # For precipitation: may show seasonal patterns depending on location
            # For lightning: often follows seasonal patterns
            has_pattern = False
            
            # This is a simplified check - in reality would need more sophisticated analysis
            # and would depend on hemisphere (northern/southern) for seasonality timing
            if "Temperature" in feature_name or "Lightning" in feature_name:
                has_pattern = True
            elif "Precipitation" in feature_name or "Vapor_Pressure" in feature_name:
                # Check if there's reasonable variation throughout the year
                # Simple test: If std dev is > 10% of mean, likely has seasonal pattern
                if len(all_valid_data) > 0:
                    mean_value = np.mean(all_valid_data)
                    if mean_value != 0 and result['std_dev'] / mean_value > 0.1:
                        has_pattern = True
            
            result['has_seasonal_pattern'] = has_pattern
    
    except Exception as e:
        print(f"Error in temporal variation analysis: {e}")
    
    return result

if __name__ == "__main__":
    # Set a random seed for reproducibility
    np.random.seed(42)
    validate_data_structure()

Starting validation of processed climate data...
Original metadata: 2051 valid patches, patch size: 128x128, threshold: 30.0%
✓ All 18 year directories exist.
✓ All 108 expected data files exist.
Validating content of 20 randomly selected files...


Validating Files: 100%|██████████| 20/20 [01:49<00:00,  5.45s/it]



Validation Results: 20/20 files are valid.

Performing detailed validation on one file...
Detailed validation of climate/2005/Lightning.npz:
  - Total patches: 2051
  - Valid patches (NaN% <= 30.0%): 2051
  - Invalid patches: 0
  - Highest NaN%: 0.00%
  - Average NaN%: 0.00%

NaN Percentage Distribution:
   0-10 %: 2051 patches (100.0%) ██████████████████████████████████████████████████
  10-20 %:    0 patches (  0.0%) 
  20-30 %:    0 patches (  0.0%) 
  30-40 %:    0 patches (  0.0%) 
  40-50 %:    0 patches (  0.0%) 
  50-60 %:    0 patches (  0.0%) 
  60-70 %:    0 patches (  0.0%) 
  70-80 %:    0 patches (  0.0%) 
  80-90 %:    0 patches (  0.0%) 
  90-100%:    0 patches (  0.0%) 

Validating consistency across months...
Month-to-month consistency in climate/2005/Lightning.npz:
  - Number of months: 12
  - Consistent shape: True
  - Month-to-month NaN pattern consistency: 100.00%
  - Are all patches valid (NaN% <= 30.0%)? True

Analyzing temporal variation in climate data...
Tem

# TARGET VARIABLE : BURNED AREA 

In [None]:
import os
import numpy as np
import xarray as xr
import gc
import time
import multiprocessing as mp
from tqdm.auto import tqdm

# Path and variable name for your burned-area dataset
BURNED_DATASET = "/home/ubuntu/Data-Seasonal-forecast/Response-Burned_area/MODIS_BA_nc_005_Aus.nc"
BURNED_VARIABLE = "burned_area"
BURNED_NAME = "BurnedArea"

# Example year range
START_YEAR = 2005
END_YEAR = 2022

def generate_year_month_bands(start_year, end_year):
    """
    Builds a list of required band strings in the format 'YYYY-MM'.
    Example: '2001-01', '2001-02', ..., '2001-12', '2002-01', etc.
    """
    required = []
    for year in range(start_year, end_year + 1):
        for month in range(1, 13):
            band_str = f"{year:04d}-{month:02d}"
            required.append(band_str)
    return required

def process_year(args):
    """
    Process a single year of burned area data. All arguments are passed as a single tuple to
    support multiprocessing.
    """
    year, ds, required_bands, valid_patches_list, patch_size, out_root = args
    
    # Create a progress file for this process to track progress
    progress_file = os.path.join(out_root, f"progress_{year}.txt")
    with open(progress_file, 'w') as f:
        f.write(f"0/12\n")  # Assuming 12 months per year
    
    try:
        # Gather the band strings for that year
        year_str = f"{year:04d}-"
        year_bands = [b for b in required_bands if b.startswith(year_str)]
        months_in_year = len(year_bands)
        num_valid_patches = len(valid_patches_list)
        
        # Allocate [months, patches, patch_size, patch_size]
        year_data = np.zeros((months_in_year, num_valid_patches, patch_size, patch_size), dtype=np.float32)
        year_data.fill(np.nan)
        
        for month_idx, band_str in enumerate(year_bands):
            # Update progress
            with open(progress_file, 'w') as f:
                f.write(f"{month_idx + 1}/12\n")
                
            try:
                # Load the 2D array for that band
                band_2d = ds[BURNED_VARIABLE].sel(band=band_str).load().values  # shape=(lat_dim, lon_dim)
                
                # Extract each valid patch
                for p_idx, patch_info in enumerate(valid_patches_list):
                    lat_start = patch_info["lat_start"]
                    lat_end = patch_info["lat_end"]
                    lon_start = patch_info["lon_start"]
                    lon_end = patch_info["lon_end"]
                    
                    subarray = band_2d[lat_start:lat_end, lon_start:lon_end].copy()
                    year_data[month_idx, p_idx] = subarray
                
                del band_2d
            except Exception as e:
                with open(os.path.join(out_root, f"error_{year}.txt"), 'a') as f:
                    f.write(f"Could not load band '{band_str}': {e}\n")
        
        # Save one .npz file per year
        year_dir = os.path.join(out_root, f"{year}")
        os.makedirs(year_dir, exist_ok=True)
        
        out_file = os.path.join(year_dir, f"{BURNED_NAME}.npz")
        metadata = {
            "feature_name": BURNED_NAME,
            "feature_var": BURNED_VARIABLE,
            "year": year,
            "bands": year_bands,
            "patch_size": patch_size,
            "num_patches": num_valid_patches
        }
        
        # Build patch index map
        patch_indices = {}
        for i, patch_info in enumerate(valid_patches_list):
            pi = patch_info["patch_i"]
            pj = patch_info["patch_j"]
            patch_indices[f"{pi}_{pj}"] = i
        
        np.savez_compressed(
            out_file,
            data=year_data,
            metadata=metadata,
            patch_indices=patch_indices
        )
        
        del year_data
        gc.collect()
        
        # Mark as complete
        with open(progress_file, 'w') as f:
            f.write(f"COMPLETE\n")
            
        return year
        
    except Exception as e:
        with open(os.path.join(out_root, f"error_{year}.txt"), 'a') as f:
            f.write(f"Unexpected error: {e}\n")
        return None

def monitor_progress(out_root, years):
    """
    Monitor progress of all processes and display in a single progress bar.
    """
    years_done = set()
    total_years = len(years)
    
    with tqdm(total=total_years * 12, desc="Overall Progress") as pbar:
        previous_progress = 0
        
        while len(years_done) < total_years:
            current_progress = 0
            
            for year in years:
                if year in years_done:
                    current_progress += 12  # 12 months already counted
                    continue
                    
                progress_file = os.path.join(out_root, f"progress_{year}.txt")
                if os.path.exists(progress_file):
                    try:
                        with open(progress_file, 'r') as f:
                            content = f.read().strip()
                            if content == "COMPLETE":
                                years_done.add(year)
                                current_progress += 12
                            else:
                                try:
                                    progress, total = content.split('/')
                                    current_progress += int(progress)
                                except:
                                    pass
                    except:
                        pass
            
            # Update progress bar
            pbar.update(current_progress - previous_progress)
            previous_progress = current_progress
            
            # Don't update too frequently
            time.sleep(0.5)
            
            if len(years_done) == total_years:
                # One final update
                pbar.update(total_years * 12 - previous_progress)
                break

def main():
    start_time = time.time()

    # 1) Load valid patches from 'valid_patches_under30.npz'
    valid_patches_file = "valid_patches_under30.npz"
    if not os.path.exists(valid_patches_file):
        raise FileNotFoundError(f"Valid patches file '{valid_patches_file}' not found.")
    
    print(f"Loading valid patches from '{valid_patches_file}'...")
    data = np.load(valid_patches_file, allow_pickle=True)
    
    # Convert structured array to a list of dictionaries
    valid_struct = data["valid_patches"]
    valid_patches_list = []
    for i in range(len(valid_struct)):
        patch_dict = {}
        for field in valid_struct.dtype.names:
            patch_dict[field] = valid_struct[field][i]
        valid_patches_list.append(patch_dict)
    
    patch_size = int(data["patch_size"])
    threshold = float(data["threshold"])
    num_valid_patches = len(valid_patches_list)

    print(f"Found {num_valid_patches} valid patches (<= {threshold}% NaNs).")
    print(f"Patch size = {patch_size}x{patch_size}\n")

    # 2) Generate the list of band strings like 'YYYY-MM'
    required_bands = generate_year_month_bands(START_YEAR, END_YEAR)
    print(f"Generated {len(required_bands)} required bands from {START_YEAR} to {END_YEAR}.")

    # 3) Open the burned-area dataset
    if not os.path.exists(BURNED_DATASET):
        raise FileNotFoundError(f"Dataset '{BURNED_DATASET}' not found.")
    
    ds = xr.open_dataset(BURNED_DATASET)
    if BURNED_VARIABLE not in ds.data_vars:
        ds.close()
        raise KeyError(f"Variable '{BURNED_VARIABLE}' not found in the dataset.")
    
    # 4) Verify that all required bands exist
    actual_bands = set(str(b) for b in ds["band"].values)  # e.g. {"2000-11", "2000-12", "2001-01", ...}
    missing_bands = [b for b in required_bands if b not in actual_bands]
    if missing_bands:
        ds.close()
        print(f"Missing some required 'YYYY-MM' bands (showing up to 10): {missing_bands[:10]}")
        raise ValueError("Not all required year-month bands are present in the dataset.")

    # 5) Create an output folder
    out_root = "Target"
    os.makedirs(out_root, exist_ok=True)
    
    # Create year directories
    for year in range(START_YEAR, END_YEAR + 1):
        os.makedirs(os.path.join(out_root, str(year)), exist_ok=True)
    
    # 6) Prepare for parallel processing by year
    years = list(range(START_YEAR, END_YEAR + 1))
    
    # Determine number of processes to use (adjust according to memory requirements)
    # Since we're only dealing with one dataset at a time, we can use more processes
    num_processes = min(8, len(years), mp.cpu_count())
    print(f"Using {num_processes} parallel processes for {len(years)} years")
    
    # Prepare arguments for each process
    process_args = []
    for year in years:
        args = (
            year,              # Year to process
            ds,                # Dataset (shared)
            required_bands,    # Required bands
            valid_patches_list, # Valid patches
            patch_size,        # Patch size
            out_root          # Output root directory
        )
        process_args.append(args)
    
    # Start the multiprocessing pool
    pool = mp.Pool(processes=num_processes)
    print("Starting parallel processing of years...")
    
    # Start processes
    pool.map_async(process_year, process_args)
    
    # Start monitoring thread for progress display
    monitor_progress(out_root, years)
    
    # Wait for all processes to complete
    pool.close()
    pool.join()
    
    # Close the dataset
    ds.close()
    
    # Check which years were successfully processed
    completed_years = []
    for year in years:
        progress_file = os.path.join(out_root, f"progress_{year}.txt")
        if os.path.exists(progress_file):
            with open(progress_file, 'r') as f:
                content = f.read().strip()
                if content == "COMPLETE":
                    completed_years.append(year)
    
    print(f"\nSuccessfully processed {len(completed_years)} years out of {len(years)}")
    
    elapsed = time.time() - start_time
    print(f"\nAll done. Burned-area data processed for {START_YEAR}-{END_YEAR}, in 'YYYY-MM' format.")
    print(f"Total time: {elapsed:.2f}s ({elapsed/60:.2f} min).")

if __name__ == "__main__":
    # Required for Windows multiprocessing
    mp.freeze_support()
    main()

Loading valid patches from 'valid_patches_under30.npz'...
Found 2051 valid patches (<= 30.0% NaNs).
Patch size = 128x128

Generated 216 required bands from 2005 to 2022.
Using 8 parallel processes for 18 years
Starting parallel processing of years...


Overall Progress: 100%|██████████| 216/216 [02:03<00:00,  1.75it/s] 


In [1]:
import os
import numpy as np
from tqdm import tqdm

# Constants
TARGET_DIR = "Target"
START_YEAR = 2005
END_YEAR = 2022
FEATURE_NAME = "BurnedArea"

def analyze_burned_areas():
    """
    Analyze and print burned area data:
    1. For each band (e.g., "2005-01"), the percentage of patches with at least 1 burned pixel
    2. For patches with burned pixels, the percentage of pixels that are burned (have value 1)
    """
    print("\nBurned Area Analysis Results:")
    print("=" * 80)
    print("{:<10} | {:<35} | {:<35}".format(
        "Band", 
        "Patches with Fire (%)", 
        "Avg Burned Area in Affected Patches (%)"
    ))
    print("-" * 80)
    
    # Process each year
    for year in range(START_YEAR, END_YEAR + 1):
        year_dir = os.path.join(TARGET_DIR, str(year))
        npz_file = os.path.join(year_dir, f"{FEATURE_NAME}.npz")
        
        if not os.path.exists(npz_file):
            # print(f"Warning: File for year {year} not found: {npz_file}")
            continue
        
        # Load the data with allow_pickle=True
        data = np.load(npz_file, allow_pickle=True)
        year_data = data['data']  # Shape: [months, patches, patch_height, patch_width]
        
        metadata = data['metadata'].item()
        bands = metadata['bands']  # e.g., ['2005-01', '2005-02', ...]
        
        # Process each month/band
        for month_idx, band in enumerate(bands):
            # Get all patches for this month
            monthly_patches = year_data[month_idx]  # Shape: [patches, patch_height, patch_width]
            
            total_patches = monthly_patches.shape[0]
            patches_with_fire = 0
            fire_percentages = []
            
            # Analyze each patch
            for patch_idx in range(total_patches):
                patch = monthly_patches[patch_idx]  # Shape: [patch_height, patch_width]
                
                # Check if patch has at least one burned pixel
                if np.any(patch == 1):
                    patches_with_fire += 1
                    
                    # Calculate percentage of burned area within this patch
                    total_pixels = patch.size
                    burned_pixels = np.sum(patch == 1)
                    fire_percent = (burned_pixels / total_pixels) * 100
                    fire_percentages.append(fire_percent)
            
            # Calculate and print results for this band
            fire_presence_percent = (patches_with_fire / total_patches) * 100 if total_patches > 0 else 0
            avg_fire_percent = np.mean(fire_percentages) if fire_percentages else 0
            
            print("{:<10} | {:<8.2f}% ({:>5}/{:<5}) | {:<8.2f}%".format(
                band,
                fire_presence_percent,
                patches_with_fire, total_patches,
                avg_fire_percent
            ))

def main():
    # Check if the Target directory exists
    if not os.path.exists(TARGET_DIR):
        print(f"Error: Target directory '{TARGET_DIR}' not found.")
        return
    
    # Analyze and print the data
    analyze_burned_areas()

if __name__ == "__main__":
    main()


Burned Area Analysis Results:
Band       | Patches with Fire (%)               | Avg Burned Area in Affected Patches (%)
--------------------------------------------------------------------------------
2005-01    | 16.67   % (  342/2051 ) | 0.53    %
2005-02    | 16.97   % (  348/2051 ) | 0.54    %
2005-03    | 17.94   % (  368/2051 ) | 0.27    %
2005-04    | 14.97   % (  307/2051 ) | 1.21    %
2005-05    | 16.04   % (  329/2051 ) | 2.17    %
2005-06    | 7.95    % (  163/2051 ) | 1.68    %
2005-07    | 9.61    % (  197/2051 ) | 1.60    %
2005-08    | 9.80    % (  201/2051 ) | 1.98    %
2005-09    | 12.24   % (  251/2051 ) | 2.15    %
2005-10    | 18.87   % (  387/2051 ) | 2.45    %
2005-11    | 22.82   % (  468/2051 ) | 1.70    %
2005-12    | 21.36   % (  438/2051 ) | 0.86    %
2006-01    | 10.73   % (  220/2051 ) | 0.42    %
2006-02    | 7.85    % (  161/2051 ) | 0.15    %
2006-03    | 10.53   % (  216/2051 ) | 0.16    %
2006-04    | 11.85   % (  243/2051 ) | 0.56    %
2006-05    | 