# Exploring MesoGEOS Wildfire Dataset

This notebook explores the `mesogeos_cube.zarr` dataset to understand its structure, dimensions, and variables.

In [None]:
# Import required libraries
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import xarray as xr

# Set plotting style
# The matplotlib style sheet is applied first, then the seaborn palette.
# NOTE(review): both calls presumably write matplotlib's default colour cycle,
# so the order likely matters — confirm before reordering.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette('husl')
# IPython magic: render figures inline in the notebook output
%matplotlib inline

## 1. Load the Zarr Dataset

Zarr is a format for storing chunked, compressed N-dimensional arrays. We'll use xarray to open the store and explore the data; `xr.open_zarr` reads only the metadata up front and loads array values on demand.

In [None]:
# Relative path of the Zarr store holding the datacube
data_path = '../data/raw/mesogeos/mesogeos_cube.zarr'

# Open the store as an xarray Dataset
ds = xr.open_zarr(data_path)

# Confirm the load and report the type of the returned object
print(
    "Dataset loaded successfully!",
    f"Dataset type: {type(ds)}",
    sep="\n",
)

## 2. Dataset Overview

In [None]:
# Print a banner heading, then show the dataset's rich repr
print("\n".join(["=" * 80, "DATASET STRUCTURE", "=" * 80]))
display(ds)

In [None]:
# Get dimensions
# `ds.sizes` is the name -> length mapping. xarray has deprecated the mapping
# behaviour of `Dataset.dims` (it is slated to become a set of names), so
# `ds.dims.items()` is not safe across versions.
print("\nDimensions:")
print("-" * 40)
for dim, size in ds.sizes.items():
    print(f"{dim}: {size}")

In [None]:
# Report the shape and dtype of every coordinate in the dataset
print("\nCoordinates:")
print("-" * 40)
for coord_name, coord_da in ds.coords.items():
    print(f"{coord_name}: {coord_da.shape} - {coord_da.dtype}")

In [None]:
# Summarise shape, dimension names and dtype for each data variable
print("\nData Variables:")
print("-" * 40)
for var_name, var_da in ds.data_vars.items():
    report = [
        f"{var_name}:",
        f"  Shape: {var_da.shape}",
        f"  Dimensions: {var_da.dims}",
        f"  Data type: {var_da.dtype}",
        "",  # trailing blank line between variables
    ]
    print("\n".join(report))

## 3. Explore Individual Variables

In [None]:
# Dump the metadata attributes attached to each data variable
print("Variable Attributes:")
print("=" * 80)

for var_name, var_da in ds.data_vars.items():
    print(f"\n{var_name}:")
    print("-" * 40)
    # Variables without any attributes get a placeholder line instead
    if not var_da.attrs:
        print("  No attributes found")
        continue
    for key, val in var_da.attrs.items():
        print(f"  {key}: {val}")

In [None]:
# Get statistics for each variable
# The reductions for one variable are bundled into a single Dataset and
# evaluated with one `.compute()` call, so lazily loaded (dask-backed) data
# does not have to be scanned once per statistic.
# The missing-value count uses `isnull().sum()` on the DataArray rather than
# `np.isnan(data.values)`, which would pull the entire variable into memory.
print("\nVariable Statistics:")
print("=" * 80)

for var in ds.data_vars:
    print(f"\n{var}:")
    print("-" * 40)
    try:
        data = ds[var]
        stats = xr.Dataset({
            'min': data.min(),
            'max': data.max(),
            'mean': data.mean(),
            'std': data.std(),
            'nan_count': data.isnull().sum(),
        }).compute()

        print(f"  Min: {float(stats['min'].values):.4f}")
        print(f"  Max: {float(stats['max'].values):.4f}")
        print(f"  Mean: {float(stats['mean'].values):.4f}")
        print(f"  Std: {float(stats['std'].values):.4f}")

        # Check for NaN values
        nan_count = int(stats['nan_count'].values)
        total_count = data.size
        print(f"  NaN values: {nan_count} / {total_count} ({100*nan_count/total_count:.2f}%)")
    except Exception as e:
        # Best-effort report: a variable that cannot be reduced (e.g. a
        # non-numeric dtype) is noted and the loop moves on to the next one.
        print(f"  Error computing statistics: {e}")

## 4. Temporal Analysis

Explore the temporal dimension if it exists.

In [None]:
# Check for time dimension
# Candidate names are tried in order; the first one that is a dimension wins.
# `ds.sizes` is used instead of `ds.dims` because the mapping behaviour of
# `Dataset.dims` (e.g. `.keys()`) is deprecated in xarray.
time_dims = ['time', 'Time', 'date', 'timestamp']
time_coord = next((dim for dim in time_dims if dim in ds.sizes), None)

if time_coord:
    time_values = ds[time_coord].values
    print(f"Time coordinate found: {time_coord}")
    print(f"Number of timesteps: {len(time_values)}")
    print(f"\nFirst timestep: {time_values[0]}")
    print(f"Last timestep: {time_values[-1]}")

    # Show sample of time values
    print(f"\nSample time values (first 10):")
    print(time_values[:10])
else:
    print("No standard time dimension found in dataset")
    print(f"Available dimensions: {list(ds.sizes)}")

## 5. Spatial Analysis

Explore spatial dimensions (latitude, longitude, x, y).

In [None]:
# Look for commonly used spatial coordinate names and describe each one found
spatial_coords = ['lat', 'latitude', 'lon', 'longitude', 'x', 'y']

print("Spatial Coordinates:")
print("=" * 80)

for coord in spatial_coords:
    # Skip names that are neither a coordinate nor a dimension
    if coord not in ds.coords and coord not in ds.dims:
        continue

    print(f"\n{coord}:")

    # Prefer the coordinate values; otherwise fall back to a dataset lookup
    if coord in ds.coords:
        values = ds.coords[coord].values
    elif coord in ds:
        values = ds[coord].values
    else:
        values = None

    if values is None:
        continue

    print(f"  Range: [{values.min():.4f}, {values.max():.4f}]")
    print(f"  Shape: {values.shape}")
    # Resolution is the mean spacing between consecutive values
    if len(values) > 1:
        print(f"  Resolution: {np.diff(values).mean():.6f}")
    else:
        print("  Single value")

## 6. Visualization

Create visualizations of the data.

In [None]:
# Plot a sample of the first variable
# Both panels are drawn from one slice of the variable. The histogram
# previously called `.values` on the full variable, which materialises the
# whole array in memory; restricting it to the plotted slice keeps the cell
# usable on arrays that are larger than RAM.
if len(ds.data_vars) > 0:
    first_var = list(ds.data_vars)[0]
    var_da = ds[first_var]
    print(f"Visualizing: {first_var}")

    # Create figure
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    if var_da.ndim >= 2:
        # Take index 0 along every dimension except the last two.
        # For 2D data the selection is empty, so the whole array is used.
        data_slice = var_da.isel({dim: 0 for dim in var_da.dims[:-2]}).compute()

        # Plot 1: 2D map of the slice
        data_slice.plot(ax=axes[0], cmap='RdYlBu_r', add_colorbar=True)
        axes[0].set_title(f'{first_var} - Spatial Distribution')
    else:
        # 0D/1D variable: there is no 2D map to draw, so the left panel
        # stays empty and the histogram covers all of its values.
        data_slice = var_da.compute()

    # Plot 2: Histogram of the values in the slice
    hist_title = f'{first_var} - Distribution'
    if var_da.ndim > 2:
        # Make it explicit that only the plotted slice is summarised
        hist_title += ' (plotted slice)'
    axes[1].hist(data_slice.values.ravel(), bins=50, edgecolor='black', alpha=0.7)
    axes[1].set_xlabel('Value')
    axes[1].set_ylabel('Frequency')
    axes[1].set_title(hist_title)
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()
else:
    print("No data variables available for visualization")

## 7. Dataset Summary

Create a comprehensive summary of the dataset for documentation.

In [None]:
# Create a summary dictionary
# 'dimensions' is built from `ds.sizes` (name -> length); `dict(ds.dims)`
# depends on the deprecated mapping behaviour of `Dataset.dims`.
# 'total_size_bytes' sums `nbytes` over the data variables only.
summary = {
    'dataset_path': data_path,
    'dimensions': dict(ds.sizes),
    'coordinates': list(ds.coords),
    'data_variables': list(ds.data_vars),
    'total_size_bytes': sum(ds[var].nbytes for var in ds.data_vars),
}

print("\nDATASET SUMMARY")
print("=" * 80)
print(f"Path: {summary['dataset_path']}")
print(f"\nDimensions: {summary['dimensions']}")
print(f"\nCoordinates: {summary['coordinates']}")
print(f"\nData Variables ({len(summary['data_variables'])}): {summary['data_variables']}")
print(f"\nTotal Size: {summary['total_size_bytes'] / (1024**3):.2f} GB")

# Global attributes
if ds.attrs:
    print("\nGlobal Attributes:")
    print("-" * 40)
    for attr_name, attr_value in ds.attrs.items():
        print(f"{attr_name}: {attr_value}")

## 8. Export Summary to CSV

Export key information for future reference.

In [None]:
# Create a dataframe with variable information
# One record per data variable: its layout (dims, shape, dtype) plus basic
# statistics. The four reductions are bundled into one Dataset and evaluated
# with a single `.compute()` so lazily loaded data is not scanned four times.
var_info = []

for var in ds.data_vars:
    var_da = ds[var]
    stats = xr.Dataset({
        'min': var_da.min(),
        'max': var_da.max(),
        'mean': var_da.mean(),
        'std': var_da.std(),
    }).compute()

    var_info.append({
        'variable': var,
        'dimensions': str(var_da.dims),
        'shape': str(var_da.shape),
        'dtype': str(var_da.dtype),
        'min': float(stats['min'].values),
        'max': float(stats['max'].values),
        'mean': float(stats['mean'].values),
        'std': float(stats['std'].values),
    })

df_summary = pd.DataFrame(var_info)
print("\nVariable Summary Table:")
display(df_summary)

# Save to CSV
output_path = '../data/processed/mesogeos_summary.csv'
# `to_csv` does not create missing directories, so make sure the target
# folder exists before writing.
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
df_summary.to_csv(output_path, index=False)
print(f"\nSummary saved to: {output_path}")

## 9. Next Steps

Based on the exploration above:
1. Identify relevant features for wildfire prediction (temperature, wind, humidity, vegetation, etc.)
2. Check data quality and completeness
3. Plan preprocessing steps (normalization, handling missing values, etc.)
4. Design the model architecture based on spatial-temporal structure
5. Create train/validation/test splits