In [None]:
# Mount Google Drive so that the data folders under MyDrive become reachable
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# Install required libraries
!pip install -q xarray netCDF4 matplotlib seaborn scipy

# GCM Downscaling Data Inspection Notebook

**Environment**: Google Colab  
**Purpose**: Inspect NetCDF climate data files (CRU, ERA5, GCMs) to verify variable names, coordinates, temporal coverage, and spatial grids.

## Setup Instructions:
1. Mount Google Drive (first cell)
2. Run the installation cell to install the required packages
3. Run the setup cell and verify that every data path is reported as found
4. Execute the inspection cells sequentially

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Data directories on the mounted Google Drive
base_path = Path('/content/drive/MyDrive/Downscaling ML CEP/AI_GCMs')
cru_path, era5_path, gcm_path = (base_path / sub for sub in ('CRU', 'ERA5', 'GCMs'))

# Report which of the expected directories are present
print("Verifying data paths...")
expected_dirs = {'Base': base_path, 'CRU': cru_path, 'ERA5': era5_path, 'GCM': gcm_path}
for path_name, path in expected_dirs.items():
    if not path.exists():
        print(f"✗ {path_name} path NOT FOUND: {path}")
        print(f"  Please create this directory in Google Drive")
        continue
    print(f"✓ {path_name} path exists: {path}")

# Make sure the directory for saved figures exists (no error if already there)
output_path = Path('/content/drive/MyDrive/Downscaling ML CEP/outputs/figures')
output_path.mkdir(parents=True, exist_ok=True)
print(f"\n✓ Output directory ready: {output_path}")

## 1. Inspect CRU Reference Data

In [None]:
def inspect_netcdf(filepath, show_sample=True):
    """Print a structured report about a NetCDF file and return the dataset.

    The report lists the dimensions, every data variable (shape, dtype,
    dims, selected attributes and basic statistics), every coordinate
    (shape, dtype and a value range or sample) and the global attributes.

    Parameters
    ----------
    filepath : pathlib.Path or str
        Location of the NetCDF file.
    show_sample : bool, default True
        When True, additionally print the dataset sliced at index 0 of the
        ``time`` dimension (or of the first dimension if there is no
        ``time`` dimension).

    Returns
    -------
    xarray.Dataset or None
        The opened dataset. It is deliberately left open so the caller can
        keep using it; the caller is responsible for closing it.
        None is returned when the file cannot be opened.
    """
    # Works for Path objects (uses .name) as well as plain string paths.
    display_name = getattr(filepath, 'name', filepath)

    print(f"\n{'='*80}")
    print(f"FILE: {display_name}")
    print(f"{'='*80}")

    try:
        ds = xr.open_dataset(filepath)
    except FileNotFoundError:
        print(f"ERROR: File not found at {filepath}")
        return None
    except Exception as e:
        print(f"ERROR: Could not open file - {e}")
        return None

    # Dimensions
    print("\nDIMENSIONS:")
    for dim, size in ds.sizes.items():
        print(f"  {dim}: {size}")

    # Data variables
    print("\nDATA VARIABLES:")
    for var in ds.data_vars:
        var_obj = ds[var]
        print(f"  {var}:")
        print(f"    Shape: {var_obj.shape}")
        print(f"    Dtype: {var_obj.dtype}")
        print(f"    Dims: {var_obj.dims}")

        # Metadata: read from .attrs explicitly instead of relying on
        # attribute-style access, which can also resolve to other members.
        for attr in ('units', 'long_name', 'standard_name'):
            if attr in var_obj.attrs:
                print(f"    {attr}: {var_obj.attrs[attr]}")

        # Statistics (only for numeric data)
        if np.issubdtype(var_obj.dtype, np.number):
            try:
                print(f"    Min: {float(var_obj.min().values):.4f}")
                print(f"    Max: {float(var_obj.max().values):.4f}")
                print(f"    Mean: {float(var_obj.mean().values):.4f}")
                print(f"    NaN count: {int(var_obj.isnull().sum())}")
            except Exception as e:
                print(f"    Statistics: Unable to compute ({e})")
        else:
            print(f"    Data type: Non-numeric ({var_obj.dtype})")
            print(f"    Sample values: {var_obj.values.flat[:3]}")

    # Coordinates
    print("\nCOORDINATES:")
    for coord in ds.coords:
        coord_obj = ds[coord]
        print(f"  {coord}:")
        print(f"    Shape: {coord_obj.shape}")
        print(f"    Dtype: {coord_obj.dtype}")

        if coord_obj.size > 0:
            if 'time' in str(coord).lower():
                # Flatten so scalar (0-d) time coordinates can be indexed too.
                time_values = coord_obj.values.ravel()
                first, last = time_values[0], time_values[-1]
                try:
                    print(f"    Range: {pd.to_datetime(first)} to {pd.to_datetime(last)}")
                    print(f"    First 3: {[pd.to_datetime(t) for t in time_values[:3]]}")
                except Exception:
                    # pandas cannot convert every time type (e.g. the cftime
                    # objects xarray uses for non-standard calendars); still
                    # report the range using the raw values.
                    print(f"    Range: {first} to {last} (raw values, time parsing failed)")
                    print(f"    First 3: {[str(t) for t in time_values[:3]]}")
            elif coord_obj.size <= 10:
                print(f"    Values: {coord_obj.values}")
            else:
                if np.issubdtype(coord_obj.dtype, np.number):
                    print(f"    Range: {coord_obj.values.min():.4f} to {coord_obj.values.max():.4f}")
                print(f"    First 3: {coord_obj.values[:3]}")
                print(f"    Last 3: {coord_obj.values[-3:]}")

        # Units
        if 'units' in coord_obj.attrs:
            print(f"    Units: {coord_obj.attrs['units']}")

    # Global attributes (values truncated to keep the report readable)
    print("\nGLOBAL ATTRIBUTES:")
    for attr, value in ds.attrs.items():
        print(f"  {attr}: {str(value)[:100]}")

    if show_sample:
        print("\nSAMPLE DATA (first time slice):")
        # Prefer the time dimension; otherwise fall back to the first one.
        sample_dim = 'time' if 'time' in ds.sizes else next(iter(ds.sizes), None)
        if sample_dim is not None and ds.sizes[sample_dim] > 0:
            print(ds.isel({sample_dim: 0}))
        else:
            # No dimension (or an empty one) to slice: show the dataset as is.
            print(ds)

    return ds

In [None]:
# CRU reference data: temperature file
cru_tmp = inspect_netcdf(
    filepath=cru_path.joinpath('cru_tmp.1901.2024.0.25deg.pakistan.nc'),
    show_sample=False,
)

In [None]:
# CRU reference data: precipitation file
cru_pre = inspect_netcdf(
    filepath=cru_path.joinpath('cru_pre.1901.2024.0.25deg.pakistan.nc'),
    show_sample=False,
)

## 2. Inspect ERA5 Target Data

In [None]:
# Inspect ERA5 file 1 (stepType "avgua").
# NOTE(review): "avgua" presumably denotes monthly averages of un-accumulated
# (instantaneous) fields such as 2 m temperature rather than "upward" -
# confirm against the variables printed by this cell.
era5_ua = inspect_netcdf(
    filepath=era5_path.joinpath('data_stream-moda_stepType-avgua.nc'),
    show_sample=False,
)

In [None]:
# Inspect ERA5 file 2 (stepType "avgad").
# NOTE(review): "avgad" presumably denotes monthly averages of accumulated
# fields such as total precipitation rather than "downward" - confirm against
# the variables printed by this cell.
era5_ad = inspect_netcdf(
    filepath=era5_path.joinpath('data_stream-moda_stepType-avgad.nc'),
    show_sample=False,
)

## 3. Inspect GCM Historical and Future Data

In [None]:
# Sample GCM (BCC-CSM2-MR): historical near-surface temperature file
hash_rule = "#" * 80
print("\n" + hash_rule)
print("# GCM HISTORICAL DATA")
print(hash_rule)

gcm_hist_tas = inspect_netcdf(
    filepath=gcm_path.joinpath('BCC-CSM2-MR_hist_tas.nc'),
    show_sample=False,
)

In [None]:
# Sample GCM (BCC-CSM2-MR): historical precipitation file
gcm_hist_pr = inspect_netcdf(
    filepath=gcm_path.joinpath('BCC-CSM2-MR_hist_pr.nc'),
    show_sample=False,
)

In [None]:
# Sample GCM (BCC-CSM2-MR): SSP1-2.6 temperature file
hash_rule = "#" * 80
print("\n" + hash_rule)
print("# GCM FUTURE SCENARIOS")
print(hash_rule)

gcm_ssp126_tas = inspect_netcdf(
    filepath=gcm_path.joinpath('BCC-CSM2-MR_ssp126_tas.nc'),
    show_sample=False,
)

In [None]:
# Sample GCM (BCC-CSM2-MR): SSP5-8.5 precipitation file
gcm_ssp585_pr = inspect_netcdf(
    filepath=gcm_path.joinpath('BCC-CSM2-MR_ssp585_pr.nc'),
    show_sample=False,
)

## 4. Check another GCM for consistency

In [None]:
# Second model (CanESM5) as a cross-check of the GCM file layout
hash_rule = "#" * 80
print("\n" + hash_rule)
print("# VERIFICATION: CanESM5")
print(hash_rule)

canesm_hist = inspect_netcdf(
    filepath=gcm_path.joinpath('CanESM5_hist_tas.nc'),
    show_sample=False,
)

## 5. Temporal Overlap Analysis

In [None]:
# Analyze temporal overlap for 1980-2014 training period
def get_time_range(filepath):
    """Return the first time value, last time value and length of a file's time axis.

    The time coordinate is the first coordinate whose name contains
    "time" (case-insensitive).

    Parameters
    ----------
    filepath : pathlib.Path or str
        Location of the NetCDF file.

    Returns
    -------
    tuple
        ``(start, end, n_timesteps)``. ``start`` and ``end`` are pandas
        timestamps when the values can be converted, otherwise the raw
        values as decoded by xarray (e.g. cftime objects for non-standard
        calendars). ``(None, None, 0)`` is returned when the file cannot be
        read, has no time coordinate, or the time coordinate is empty.
    """
    try:
        # The context manager closes the file even if reading fails part-way.
        with xr.open_dataset(filepath) as ds:
            time_coord = next(
                (name for name in ds.coords if 'time' in str(name).lower()),
                None,
            )
            if time_coord is None:
                return None, None, 0
            # Materialise the values before the file is closed; ravel() also
            # makes a scalar (0-d) time coordinate indexable.
            raw_times = ds[time_coord].values.ravel()
    except Exception as e:
        print(f"Error reading {getattr(filepath, 'name', filepath)}: {e}")
        return None, None, 0

    if raw_times.size == 0:
        return None, None, 0

    try:
        time_vals = pd.to_datetime(raw_times)
    except Exception:
        # Not convertible by pandas (e.g. cftime dates): keep the raw values so
        # the dataset still shows up in the coverage summary.
        time_vals = raw_times

    return time_vals[0], time_vals[-1], len(time_vals)

# Files whose temporal coverage is compared (also reused by later cells)
datasets = {
    'CRU Temperature': cru_path / 'cru_tmp.1901.2024.0.25deg.pakistan.nc',
    'CRU Precipitation': cru_path / 'cru_pre.1901.2024.0.25deg.pakistan.nc',
    'ERA5 avgua': era5_path / 'data_stream-moda_stepType-avgua.nc',
    'ERA5 avgad': era5_path / 'data_stream-moda_stepType-avgad.nc',
    'BCC-CSM2-MR hist tas': gcm_path / 'BCC-CSM2-MR_hist_tas.nc',
    'BCC-CSM2-MR hist pr': gcm_path / 'BCC-CSM2-MR_hist_pr.nc',
    'BCC-CSM2-MR ssp126 tas': gcm_path / 'BCC-CSM2-MR_ssp126_tas.nc',
    'BCC-CSM2-MR ssp585 pr': gcm_path / 'BCC-CSM2-MR_ssp585_pr.nc',
}

print("\n" + "="*80)
print("TEMPORAL COVERAGE SUMMARY")
print("="*80)
print(f"{'Dataset':<30} {'Start':<12} {'End':<12} {'N timesteps':>12}")
print("-"*80)

for name, path in datasets.items():
    start, end, n = get_time_range(path)
    if start is not None:
        # Keep only the date part (YYYY-MM-DD) of the string representation.
        print(f"{name:<30} {str(start)[:10]:<12} {str(end)[:10]:<12} {n:>12}")
    else:
        # Keep the row so a missing or unreadable file is visible in the table
        # instead of silently disappearing from it.
        print(f"{name:<30} {'n/a':<12} {'n/a':<12} {n:>12}")

print("\n" + "="*80)
print("TARGET TRAINING PERIOD: 1980-01-01 to 2014-12-31")
print("="*80)

## 6. Spatial Grid Comparison

In [None]:
# Compare spatial grids
def get_grid_info(filepath):
    """Summarise the latitude/longitude grid of a NetCDF file.

    Latitude and longitude are the coordinates whose names contain "lat" and
    "lon" (case-insensitive); if several coordinates match, the last one
    encountered is used.

    Parameters
    ----------
    filepath : pathlib.Path or str
        Location of the NetCDF file.

    Returns
    -------
    dict or None
        Keys ``lat_min``, ``lat_max``, ``lon_min``, ``lon_max``, ``n_lat``,
        ``n_lon``, ``lat_res`` and ``lon_res``. The resolutions are the mean
        difference between consecutive coordinate values (signed; 0 when
        there is a single value). None is returned when the file cannot be
        read or no latitude/longitude coordinate is found.
    """
    try:
        # The context manager closes the file even if reading fails part-way.
        with xr.open_dataset(filepath) as ds:
            lat_coord = None
            lon_coord = None

            # Find lat/lon coordinates
            for coord in ds.coords:
                coord_name = str(coord).lower()
                if 'lat' in coord_name:
                    lat_coord = coord
                if 'lon' in coord_name:
                    lon_coord = coord

            if not (lat_coord and lon_coord):
                return None

            # Materialise the values before the file is closed.
            lat = ds[lat_coord].values
            lon = ds[lon_coord].values

        lat_res = np.diff(lat).mean() if len(lat) > 1 else 0
        lon_res = np.diff(lon).mean() if len(lon) > 1 else 0

        return {
            'lat_min': lat.min(),
            'lat_max': lat.max(),
            'lon_min': lon.min(),
            'lon_max': lon.max(),
            'n_lat': len(lat),
            'n_lon': len(lon),
            'lat_res': lat_res,
            'lon_res': lon_res,
        }
    except Exception as e:
        print(f"Error reading {getattr(filepath, 'name', filepath)}: {e}")
        return None

table_rule = "=" * 80

print("\n" + table_rule)
print("SPATIAL GRID COMPARISON")
print(table_rule)
print(f"{'Dataset':<30} {'Lat Range':<20} {'Lon Range':<20} {'Grid':>15}")
print("-"*80)

# Future scenarios (names containing "ssp") are left out for now
non_scenario_files = {
    label: nc_file for label, nc_file in datasets.items() if 'ssp' not in label.lower()
}
for name, path in non_scenario_files.items():
    grid = get_grid_info(path)
    if not grid:
        continue
    lat_range = f"{grid['lat_min']:.2f} to {grid['lat_max']:.2f}"
    lon_range = f"{grid['lon_min']:.2f} to {grid['lon_max']:.2f}"
    # Grid size plus the absolute latitude spacing
    grid_str = f"{grid['n_lat']} x {grid['n_lon']} ({abs(grid['lat_res']):.3f}°)"
    print(f"{name:<30} {lat_range:<20} {lon_range:<20} {grid_str:>15}")

print("\n" + table_rule)
print("TARGET GRID: CRU 0.25° Pakistan (lat 23-38°N, lon 60-78°E)")
print(table_rule)

## 7. Visualize Sample Data

In [None]:
# Quick visualization of spatial patterns: the first time step of one
# variable from each of four files, drawn on a 2 x 2 grid of panels.

def _pick_map_variable(ds):
    """Return the name of the data variable to draw as a map.

    Prefers the first variable that has a ``time`` dimension and at least
    three dimensions in total, so that lower-dimensional auxiliary variables
    are not plotted by accident; falls back to the first data variable.
    """
    for var_name, var in ds.data_vars.items():
        if 'time' in var.dims and var.ndim >= 3:
            return var_name
    return list(ds.data_vars)[0]


# (panel position, file, panel title, colormap)
panels = [
    ((0, 0), cru_path / 'cru_tmp.1901.2024.0.25deg.pakistan.nc',
     'CRU Temperature (first timestep)', 'RdYlBu_r'),
    ((0, 1), cru_path / 'cru_pre.1901.2024.0.25deg.pakistan.nc',
     'CRU Precipitation (first timestep)', 'YlGnBu'),
    ((1, 0), gcm_path / 'BCC-CSM2-MR_hist_tas.nc',
     'GCM Temperature (BCC-CSM2-MR, first timestep)', 'RdYlBu_r'),
    ((1, 1), gcm_path / 'BCC-CSM2-MR_hist_pr.nc',
     'GCM Precipitation (BCC-CSM2-MR, first timestep)', 'YlGnBu'),
]

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
fig.suptitle('Sample Spatial Patterns (First Time Step)', fontsize=14, fontweight='bold')

try:
    for (row, col), nc_file, title, cmap in panels:
        # The context manager closes each file even if plotting fails.
        with xr.open_dataset(nc_file) as ds_panel:
            map_var = _pick_map_variable(ds_panel)
            ds_panel[map_var].isel(time=0).plot(ax=axes[row, col], cmap=cmap)
        # Set the title after plotting so it replaces xarray's automatic one.
        axes[row, col].set_title(title)

    plt.tight_layout()

    # Save to Google Drive
    save_path = output_path / '00_initial_spatial_patterns.png'
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    print(f"\n✓ Figure saved to: {save_path}")
    plt.show()

except Exception as e:
    print(f"Error creating visualization: {e}")
    print("Please verify all NetCDF files exist in the specified paths")

## 8. Summary and Next Steps

In [None]:
# Static recap of the inspection plus the paths used in this session
summary_rule = "=" * 80

summary_text = """
Key Findings from Data Inspection:

1. CRU Variables:
   - Temperature: 'tmp' (degrees Celsius)
   - Precipitation: 'pre' (mm/month)
   - Coordinates: time, lat, lon
   - Resolution: 0.25° (60 x 72 grid)
   - Coverage: 1901-2024 (1488 timesteps)

2. ERA5 Variables:
   - Check avgua file for temperature variable
   - Check avgad file for precipitation variable
   - May need unit conversion (K→°C, m→mm)

3. GCM Variables:
   - Temperature: 'tas' (likely Kelvin)
   - Precipitation: 'pr' (likely kg m⁻²s⁻¹)
   - Resolution: Coarser than CRU (needs regridding)

4. Coordinate Conventions:
   - Use ds.sizes instead of ds.dims.items() (deprecated)
   - Standard naming: time, lat, lon

5. Required Preprocessing:
   - Subset to 1980-2014 training period
   - Regrid GCMs to CRU 0.25° grid
   - Convert units (K→°C, kg m⁻²s⁻¹→mm/month)
   - Align temporal indices across datasets

Next Steps:
1. Upload this notebook's output to verify ERA5 variable names
2. Build preprocessing pipeline (src/data/preprocessors.py)
3. Implement regridding using xESMF or xarray.interp
4. Create feature engineering pipeline
5. Train ML models
"""

print("\n" + summary_rule)
print("INSPECTION SUMMARY")
print(summary_rule)
print(summary_text)
print(summary_rule)

print("\n" + summary_rule)
print("GOOGLE COLAB PATHS")
print(summary_rule)
print(f"Base path: {base_path}")
print(f"Output path: {output_path}")
print(summary_rule)