In [6]:
import sys
import os
import pandas as pd

# Add project root to Python path to find the 'src' directory
notebook_dir = os.getcwd()
project_root = os.path.abspath(os.path.join(notebook_dir, '..'))
if project_root not in sys.path:
    sys.path.append(project_root)
    print(f"Added project root to sys.path: {project_root}")

In [7]:
from google.cloud import storage

In [8]:
gcs_client = storage.Client.create_anonymous_client()

In [9]:
import xarray as xr
ds = xr.open_zarr(
    "gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr",
    consolidated=True,storage_options={"token": "anon"},
)

In [10]:
for var_name in ds.data_vars:
    var = ds[var_name]
    print(f"{var_name}: shape={var.shape}, dims={var.dims}")
    

10m_u_component_of_wind: shape=(93544, 721, 1440), dims=('time', 'latitude', 'longitude')
10m_v_component_of_wind: shape=(93544, 721, 1440), dims=('time', 'latitude', 'longitude')
10m_wind_speed: shape=(93544, 721, 1440), dims=('time', 'latitude', 'longitude')
2m_dewpoint_temperature: shape=(93544, 721, 1440), dims=('time', 'latitude', 'longitude')
2m_temperature: shape=(93544, 721, 1440), dims=('time', 'latitude', 'longitude')
above_ground: shape=(93544, 13, 721, 1440), dims=('time', 'level', 'latitude', 'longitude')
ageostrophic_wind_speed: shape=(93544, 13, 721, 1440), dims=('time', 'level', 'latitude', 'longitude')
angle_of_sub_gridscale_orography: shape=(721, 1440), dims=('latitude', 'longitude')
anisotropy_of_sub_gridscale_orography: shape=(721, 1440), dims=('latitude', 'longitude')
boundary_layer_height: shape=(93544, 721, 1440), dims=('time', 'latitude', 'longitude')
divergence: shape=(93544, 13, 721, 1440), dims=('time', 'level', 'latitude', 'longitude')
eddy_kinetic_energy: s

In [11]:
import xarray as xr
from dask.diagnostics import ProgressBar
from pathlib import Path # For checking file size locally
variables = [
    # TARGET VARIABLE
    'total_precipitation_6hr',              # Our main target
    
    # CORE ATMOSPHERIC VARIABLES  
    '2m_temperature',                       # Surface temperature
    '2m_dewpoint_temperature',              # Surface moisture content
    'surface_pressure',                     # Surface pressure
    'mean_sea_level_pressure',              # Synoptic pressure patterns
    
    # WIND FIELDS
    '10m_u_component_of_wind',              # Surface wind U
    '10m_v_component_of_wind',              # Surface wind V
    '10m_wind_speed',                       # Surface wind magnitude
    'u_component_of_wind',                  # Upper-level winds (averaged over pressure levels)
    'v_component_of_wind',                  # Upper-level winds (averaged over pressure levels)
    
    # MOISTURE & THERMODYNAMICS
    'total_column_water_vapour',            # Atmospheric moisture content
    'integrated_vapor_transport',           # Moisture transport
    'boundary_layer_height',                # PBL structure
    'specific_humidity',                    # Atmospheric humidity (averaged over pressure levels)
    
    # CLOUDS & RADIATION
    'total_cloud_cover',                    # Cloud coverage
    'mean_surface_net_short_wave_radiation_flux',  # Solar heating
    'mean_surface_latent_heat_flux',        # Evaporation
    'mean_surface_sensible_heat_flux',      # Surface heating
    
    # SURFACE CONDITIONS
    'snow_depth',                           # Snow coverage
    'sea_surface_temperature',              # SST for coastal areas
    'volumetric_soil_water_layer_1',        # Surface soil moisture
    
    # ATMOSPHERIC DYNAMICS
    'mean_vertically_integrated_moisture_divergence',  # Moisture convergence
    'eddy_kinetic_energy',                  # Turbulence measure
    
]

# --- IMPORTANT NOTES FOR LOCAL VS CODE ---
# 1. No 'google.colab.auth' is needed; authentication is handled by 'gcloud auth application-default login' run in your terminal.
# 2. No '!pip install' commands are needed in the script itself; packages are installed once into your virtual environment.
# 3. Output will be saved directly to your local file system.
# ---

# 1. Load the public Zarr dataset anonymously
# 'storage_options={"token": "cloud"}' tells gcsfs to use the Application Default Credentials
# you set up with 'gcloud auth application-default login'.
print("🚀 Attempting to load dataset from Google Cloud Storage...")
ds = xr.open_zarr(
    "gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr",
)
print("✅ Dataset loaded successfully (metadata).")


# 2. Define variables to keep
vars_to_keep = variables

# 3. Select time, space, variables
# Using method='nearest' for robustness in case slice endpoints don't exactly match coordinates.
print("✂️ Subsetting data...")
ds_subset = ds[vars_to_keep].sel(
    time=slice("1959", "2023"),
    latitude=slice(75, 30), # Assuming latitude is ordered from North to South (e.g., 90 to -90)
    longitude=slice(-25 % 360, 50 % 360), method='nearest' # Ensure correct wraparound and selection
)

# --- Debugging Prints (Very helpful to confirm data before saving) ---
print("\n--- ds_subset Information ---")
print(ds_subset) # This shows the data structure, dimensions, and variables
print("\nds_subset dimensions:")
for dim, size in ds_subset.dims.items():
    print(f"  {dim}: {size}")

if all(size > 0 for size in ds_subset.dims.values()):
    print("\nAll dimensions have size > 0. Proceeding with save.")
else:
    print("\n❌ WARNING: One or more dimensions have a size of 0. The subset is likely empty and will result in a 0-byte file.")
# --- End Debugging Prints ---


🚀 Attempting to load dataset from Google Cloud Storage...
✅ Dataset loaded successfully (metadata).
✂️ Subsetting data...


NotImplementedError: cannot use ``method`` argument if any indexers are slice objects

In [None]:
#!/usr/bin/env python3
"""
Test script for MESA-Net data loading pipeline
Tests WeatherBench2Dataset and identifies issues

Run this first to see what breaks in our data pipeline.
"""

import sys
import traceback
import torch
import xarray as xr
import numpy as np
from typing import List, Dict, Tuple

# Add your mesa_net package to path if needed
# sys.path.append('/path/to/your/mesa_net/')

# Import your classes (adjust imports based on your file structure)
try:
    from src.mesanet.mesanet_dataset import WeatherBench2Dataset
    from src.mesanet.mesanet_datamanager import WeatherBench2DataManager
    print("✅ Imports successful")
except ImportError as e:
    print(f"❌ Import error: {e}")
    print("Please adjust the import paths based on your file structure")
    sys.exit(1)

class DataLoadingTester:
    """Test data loading step by step"""
    
    def __init__(self):
        self.zarr_path = "gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr"
        
        # Start with minimal variables to test
        self.test_variables = variables
        
        self.results = {}
    
    def test_1_basic_zarr_access(self):
        """Test 1: Can we access WeatherBench2 zarr directly?"""
        print("\n" + "="*50)
        print("TEST 1: Basic WeatherBench2 Access")
        print("="*50)
        
        try:
            # Try to open the zarr store
            ds = xr.open_zarr(
                self.zarr_path,
                consolidated=True,
                storage_options={"token": "anon"},
                chunks={'time': 10}
            )
            
            print("✅ Successfully opened WeatherBench2 zarr")
            print(f"   Dataset dimensions: {dict(ds.dims)}")
            print(f"   Available variables: {len(list(ds.data_vars.keys()))}")
            print(f"   Time range: {ds.time.values[0]} to {ds.time.values[-1]}")
            
            # Check which of our test variables exist
            available_vars = list(ds.data_vars.keys())
            print(f"\n   Variable availability check:")
            for var in self.test_variables:
                exists = var in available_vars
                status = "✅" if exists else "❌"
                print(f"   {status} {var}")
            
            self.results['zarr_access'] = True
            self.results['available_variables'] = available_vars
            return ds
            
        except Exception as e:
            print(f"❌ Failed to access WeatherBench2: {e}")
            print(f"   Error type: {type(e).__name__}")
            traceback.print_exc()
            self.results['zarr_access'] = False
            return None
    
    def test_2_variable_inspection(self, ds):
        """Test 2: Inspect variables in detail"""
        print("\n" + "="*50)
        print("TEST 2: Variable Inspection")
        print("="*50)
        
        if ds is None:
            print("❌ Skipping - no dataset available")
            return
        
        try:
            # Print first 20 variables with their dimensions
            variables = list(ds.data_vars.keys())[:20]
            print(f"First 20 variables and their dimensions:")
            
            for var in variables:
                dims = ds[var].dims
                shape = ds[var].shape
                print(f"   {var}: {dims} -> {shape}")
            
            # Test specific variables we need
            print(f"\nTesting our required variables:")
            for var in self.test_variables:
                if var in ds.data_vars:
                    var_data = ds[var]
                    print(f"   ✅ {var}: {var_data.dims} -> {var_data.shape}")
                    print(f"      Data type: {var_data.dtype}")
                    #print(f"      Min/Max: {float(var_data.min())} / {float(var_data.max())}")
                else:
                    print(f"   ❌ {var}: NOT FOUND")
            
            self.results['variable_inspection'] = True
            
        except Exception as e:
            print(f"❌ Variable inspection failed: {e}")
            traceback.print_exc()
            self.results['variable_inspection'] = False
    
    def test_3_geographic_subsetting(self, ds):
        """Test 3: Geographic subsetting for Europe"""
        print("\n" + "="*50)
        print("TEST 3: Geographic Subsetting")
        print("="*50)
        
        if ds is None:
            print("❌ Skipping - no dataset available")
            return None
        
        try:
            print(f"Original longitude range: {ds.longitude.min().values} to {ds.longitude.max().values}")
            print(f"Original latitude range: {ds.latitude.min().values} to {ds.latitude.max().values}")
            
            # Apply Europe bounds
            ds_europe = ds.where(
                (ds.longitude >= 335) | (ds.longitude <= 50),
                drop=True
            ).sel(latitude=slice(75, 30))
            
            print(f"Europe longitude range: {ds_europe.longitude.min().values} to {ds_europe.longitude.max().values}")
            print(f"Europe latitude range: {ds_europe.latitude.min().values} to {ds_europe.latitude.max().values}")
            print(f"Europe grid shape: lat={len(ds_europe.latitude)}, lon={len(ds_europe.longitude)}")
            
            self.results['geographic_subsetting'] = True
            return ds_europe
            
        except Exception as e:
            print(f"❌ Geographic subsetting failed: {e}")
            traceback.print_exc()
            self.results['geographic_subsetting'] = False
            return None
    
    def test_4_time_subsetting(self, ds):
        """Test 4: Time subsetting and indexing"""
        print("\n" + "="*50)
        print("TEST 4: Time Subsetting")
        print("="*50)
        
        if ds is None:
            print("❌ Skipping - no dataset available")
            return None
        
        try:
            # Test recent time period
            ds_recent = ds.sel(time=slice("2023", "2023"))
            print(f"✅ 2023 data: {len(ds_recent.time)} time steps")
            print(f"   Time range: {ds_recent.time.values[0]} to {ds_recent.time.values[-1]}")
            
            # Test sequence creation indices
            total_time_steps = len(ds_recent.time)
            sequence_length = 12
            forecast_horizon = 4
            
            valid_indices = np.arange(
                sequence_length,
                total_time_steps - forecast_horizon
            )
            
            print(f"   Valid sequence indices: {len(valid_indices)} out of {total_time_steps}")
            
            if len(valid_indices) > 0:
                print(f"   First valid index: {valid_indices[0]}")
                print(f"   Last valid index: {valid_indices[-1]}")
            else:
                print("   ❌ No valid sequences found!")
            
            self.results['time_subsetting'] = True
            return ds_recent, valid_indices
            
        except Exception as e:
            print(f"❌ Time subsetting failed: {e}")
            traceback.print_exc()
            self.results['time_subsetting'] = False
            return None, None
    
    def test_5_single_sample_extraction(self, ds, valid_indices):
        """Test 5: Extract a single training sample"""
        print("\n" + "="*50)
        print("TEST 5: Single Sample Extraction")
        print("="*50)
        
        if ds is None or len(valid_indices) == 0:
            print("❌ Skipping - no dataset or valid indices available")
            return
        
        try:
            sequence_length = 12
            forecast_horizon = 4
            time_idx = valid_indices[0]
            
            print(f"Testing with time index: {time_idx}")
            
            # Input sequence
            input_slice = ds.isel(
                time=slice(time_idx - sequence_length, time_idx)
            )
            print(f"✅ Input slice shape: {input_slice.sizes}")
            print(f"✅ Input dims: {dict(input_slice.dims)}")

            for dim, size in input_slice.sizes.items():
                print(f"    {dim}: {size}")
            
            # Target sequence  
            if 'total_precipitation_6hr' in ds.data_vars:
                target_slice = ds['total_precipitation_6hr'].isel(
                    time=slice(time_idx, time_idx + forecast_horizon)
                )
                print(f"✅ Target slice dims: {target_slice.dims}")
                print(f"✅ Target slice shape: {target_slice.shape}")
                for dim, size in target_slice.sizes.items():
                    print(f"    {dim}: {size}")
            else:
                print("❌ total_precipitation_6hr not found")
                return
            
            # Test loading the data
            print("Loading data into memory...")
            input_loaded = input_slice.load()
            target_loaded = target_slice.load()
            
            print(f"✅ Successfully loaded sample")
            print(f"   Input time steps: {len(input_loaded.time)}")
            print(f"   Target time steps: {len(target_loaded.time)}")
            
            self.results['sample_extraction'] = True
            return input_loaded, target_loaded
            
        except Exception as e:
            print(f"❌ Sample extraction failed: {e}")
            traceback.print_exc()
            self.results['sample_extraction'] = False
            return None, None
    
    def test_6_tensor_conversion(self, input_data, target_data):
        """Test 6: Convert xarray to PyTorch tensors"""
        print("\n" + "="*50)
        print("TEST 6: Tensor Conversion")
        print("="*50)
        
        if input_data is None or target_data is None:
            print("❌ Skipping - no data available")
            return
        
        try:
            # Test basic tensor conversion for target (single variable)
            print("Testing target tensor conversion...")
            target_array = target_data.values
            target_tensor = torch.tensor(target_array, dtype=torch.float32)
            print(f"✅ Target tensor shape: {target_tensor.shape}")
            print(f"   Data type: {target_tensor.dtype}")
            print(f"   Value range: {target_tensor.min().item():.4f} to {target_tensor.max().item():.4f}")
            
            # Test multi-variable input conversion (this is where we'll likely have issues)
            print("\nTesting input tensor conversion...")
            var_arrays = []
            
            for var in self.test_variables:
                if var in input_data.data_vars:
                    var_data = input_data[var].values
                    print(f"   Processing {var}: shape {var_data.shape}")
                    
                    # Handle different dimensionalities
                    if var_data.ndim == 2:  # (lat, lon) - single time step
                        var_data = var_data[None, ...]  # Add time dimension
                    elif var_data.ndim == 4:  # (time, level, lat, lon) - multi-level
                        # Average over pressure levels
                        var_data = np.mean(var_data, axis=1)
                        print(f"     Averaged over pressure levels: {var_data.shape}")
                    
                    var_arrays.append(var_data)
                    print(f"   ✅ {var}: final shape {var_data.shape}")
                else:
                    print(f"   ❌ {var}: not found in data")
            
            if var_arrays:
                # Stack variables
                stacked_array = np.stack(var_arrays, axis=1)
                input_tensor = torch.tensor(stacked_array, dtype=torch.float32)
                
                print(f"✅ Input tensor shape: {input_tensor.shape}")
                print(f"   Expected format: (time, vars, lat, lon)")
                print(f"   Data type: {input_tensor.dtype}")
                
                self.results['tensor_conversion'] = True
                return input_tensor, target_tensor
            else:
                print("❌ No variables could be processed")
                self.results['tensor_conversion'] = False
                return None, None
            
        except Exception as e:
            print(f"❌ Tensor conversion failed: {e}")
            traceback.print_exc()
            self.results['tensor_conversion'] = False
            return None, None
    
    def test_7_dataset_class(self):
        """Test 7: Our WeatherBench2Dataset class"""
        print("\n" + "="*50)
        print("TEST 7: WeatherBench2Dataset Class")
        print("="*50)
        
        try:
            # Create dataset with minimal configuration
            dataset = WeatherBench2Dataset(
                zarr_path=self.zarr_path,
                variables=self.test_variables,
                time_range=slice("2023", "2023"),  # Just 2023 for testing
                split="train",
                sequence_length=12,
                forecast_horizon=4,
                normalize=False  # Skip normalization for now
            )
            
            print(f"✅ Dataset created successfully")
            print(f"   Dataset length: {len(dataset)}")
            
            # Test getting one sample
            if len(dataset) > 0:
                print("Testing sample retrieval...")
                input_seq, target_seq, geo_features = dataset[0]
                
                print(f"✅ Sample retrieved successfully")
                print(f"   Input sequence shape: {input_seq.shape}")
                print(f"   Target sequence shape: {target_seq.shape}")
                print(f"   Geographic features shape: {geo_features.shape}")
                
                self.results['dataset_class'] = True
            else:
                print("❌ Dataset is empty")
                self.results['dataset_class'] = False
            
        except Exception as e:
            print(f"❌ Dataset class test failed: {e}")
            traceback.print_exc()
            self.results['dataset_class'] = False
    
    def run_all_tests(self):
        """Run all tests in sequence"""
        print("🚀 Starting MESA-Net Data Loading Tests")
        print("="*60)
        
        # Test 1: Basic access
        ds = self.test_1_basic_zarr_access()
        
        # Test 2: Variable inspection
        self.test_2_variable_inspection(ds)
        
        # Test 3: Geographic subsetting
        ds_europe = self.test_3_geographic_subsetting(ds)
        
        # Test 4: Time subsetting
        ds_recent, valid_indices = self.test_4_time_subsetting(ds_europe)
        
        # Test 5: Sample extraction
        input_data, target_data = self.test_5_single_sample_extraction(ds_recent, valid_indices)
        
        # Test 6: Tensor conversion
        input_tensor, target_tensor = self.test_6_tensor_conversion(input_data, target_data)
        
        # Test 7: Dataset class
        self.test_7_dataset_class()
        
        # Summary
        self.print_summary()
        return self.print_summary()  # ✅ Add this line
    
    def print_summary(self):
        """Print test results summary"""
        print("\n" + "="*60)
        print("🏁 TEST RESULTS SUMMARY")
        print("="*60)
        
        tests = [
            ('zarr_access', 'WeatherBench2 Access'),
            ('variable_inspection', 'Variable Inspection'),
            ('geographic_subsetting', 'Geographic Subsetting'),
            ('time_subsetting', 'Time Subsetting'),
            ('sample_extraction', 'Sample Extraction'),
            ('tensor_conversion', 'Tensor Conversion'),
            ('dataset_class', 'Dataset Class'),
        ]
        
        passed = 0
        total = len(tests)
        
        for test_key, test_name in tests:
            if test_key in self.results:
                status = "✅ PASS" if self.results[test_key] else "❌ FAIL"
                if self.results[test_key]:
                    passed += 1
            else:
                status = "⏭️ SKIP"
            
            print(f"{status} {test_name}")
        
        print(f"\nOverall: {passed}/{total} tests passed")
        
        if passed == total:
            print("🎉 All tests passed! Data pipeline is working.")
        else:
            print("🔧 Some tests failed. Check the errors above and fix issues.")
            
        return passed == total

if __name__ == "__main__":
    tester = DataLoadingTester()
    success = tester.run_all_tests()
    
    if success:
        print("\n✅ Ready to proceed to next step!")
    else:
        print("\n❌ Fix data loading issues before proceeding.")
        sys.exit(1)

✅ Imports successful
🚀 Starting MESA-Net Data Loading Tests

TEST 1: Basic WeatherBench2 Access


  from .autonotebook import tqdm as notebook_tqdm


✅ Successfully opened WeatherBench2 zarr
   Dataset dimensions: {'time': 93544, 'latitude': 721, 'longitude': 1440, 'level': 13}
   Available variables: 62
   Time range: 1959-01-01T00:00:00.000000000 to 2023-01-10T18:00:00.000000000

   Variable availability check:
   ✅ total_precipitation_6hr
   ✅ 2m_temperature
   ✅ surface_pressure
   ✅ mean_sea_level_pressure

TEST 2: Variable Inspection
First 20 variables and their dimensions:
   10m_u_component_of_wind: ('time', 'latitude', 'longitude') -> (93544, 721, 1440)
   10m_v_component_of_wind: ('time', 'latitude', 'longitude') -> (93544, 721, 1440)
   10m_wind_speed: ('time', 'latitude', 'longitude') -> (93544, 721, 1440)
   2m_dewpoint_temperature: ('time', 'latitude', 'longitude') -> (93544, 721, 1440)
   2m_temperature: ('time', 'latitude', 'longitude') -> (93544, 721, 1440)
   above_ground: ('time', 'level', 'latitude', 'longitude') -> (93544, 13, 721, 1440)
   ageostrophic_wind_speed: ('time', 'level', 'latitude', 'longitude') -> (

  print(f"   Dataset dimensions: {dict(ds.dims)}")


Original longitude range: 0.0 to 359.75
Original latitude range: -90.0 to 90.0
Europe longitude range: 0.0 to 359.75
Europe latitude range: 30.0 to 75.0
Europe grid shape: lat=181, lon=301

TEST 4: Time Subsetting
✅ 2023 data: 40 time steps
   Time range: 2023-01-01T00:00:00.000000000 to 2023-01-10T18:00:00.000000000
   Valid sequence indices: 24 out of 40
   First valid index: 12
   Last valid index: 35

TEST 5: Single Sample Extraction
Testing with time index: 12
✅ Input slice shape: Frozen({'time': 12, 'latitude': 181, 'longitude': 301, 'level': 13})
✅ Input dims: {'time': 12, 'latitude': 181, 'longitude': 301, 'level': 13}
    time: 12
    latitude: 181
    longitude: 301
    level: 13
✅ Target slice dims: ('time', 'latitude', 'longitude')
✅ Target slice shape: (4, 181, 301)
    time: 4
    latitude: 181
    longitude: 301
Loading data into memory...


  print(f"✅ Input dims: {dict(input_slice.dims)}")


: 

In [None]:
#!/usr/bin/env python3
"""
Comprehensive test for fixed MESA-Net data loading pipeline
Tests all edge cases and error handling

File: test_fixed_data_pipeline.py
"""

import sys
import traceback
import torch
import numpy as np
from torch.utils.data import DataLoader
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Import your fixed classes
try:
    # Update these imports based on your file structure
    from src.mesanet.mesanet_dataset import WeatherBench2Dataset, WeatherBench2DataManager
    print("✅ Fixed imports successful")
except ImportError as e:
    print(f"❌ Import error: {e}")
    print("Please adjust the import paths and ensure you've updated the dataset file")
    sys.exit(1)

class ComprehensiveDataTester:
    """Comprehensive test for fixed data pipeline"""
    
    def __init__(self):
        self.zarr_path = "gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr"
        
        # Test with realistic variable list
        self.test_variables = [
            'total_precipitation_6hr',
            '2m_temperature',
            'surface_pressure', 
            'mean_sea_level_pressure',
            '10m_u_component_of_wind',
            '10m_v_component_of_wind'
        ]
        
        # Also test with some invalid variables
        self.test_variables_with_invalid = self.test_variables + [
            'fake_variable_1',
            'non_existent_var',
            'invalid_precipitation'
        ]
        
        self.results = {}
    
    def test_1_data_manager_connection(self):
        """Test 1: Data manager connection and variable validation"""
        print("\n" + "="*60)
        print("TEST 1: Data Manager Connection & Variable Validation")
        print("="*60)
        
        try:
            # Test with valid variables
            print("Testing with valid variables...")
            data_manager = WeatherBench2DataManager(
                zarr_path=self.zarr_path,
                variables=self.test_variables,
                sequence_length=12,
                forecast_horizon=4
            )
            print("✅ Data manager created with valid variables")
            
            # Test with some invalid variables
            print("\nTesting with mixed valid/invalid variables...")
            data_manager_mixed = WeatherBench2DataManager(
                zarr_path=self.zarr_path,
                variables=self.test_variables_with_invalid,
                sequence_length=12,
                forecast_horizon=4
            )
            print("✅ Data manager handles invalid variables gracefully")
            
            self.data_manager = data_manager
            self.results['data_manager'] = True
            return True
            
        except Exception as e:
            print(f"❌ Data manager test failed: {e}")
            traceback.print_exc()
            self.results['data_manager'] = False
            return False
    
    def test_2_dataset_creation_validation(self):
        """Test 2: Dataset creation with variable validation"""
        print("\n" + "="*60)
        print("TEST 2: Dataset Creation & Variable Validation")
        print("="*60)
        
        if not self.results.get('data_manager', False):
            print("❌ Skipping - data manager test failed")
            return False
        
        try:
            # Create dataset with mixed variables
            dataset = WeatherBench2Dataset(
                zarr_path=self.zarr_path,
                variables=self.test_variables_with_invalid,
                time_range=slice("2023", "2023"),
                split="train",
                sequence_length=12,
                forecast_horizon=4,
                normalize=True
            )
            
            print(f"✅ Dataset created successfully")
            print(f"   Original variables requested: {len(self.test_variables_with_invalid)}")
            print(f"   Valid variables found: {len(dataset.available_variables)}")
            print(f"   Available variables: {dataset.available_variables}")
            print(f"   Dataset length: {len(dataset)}")
            
            # Get variable info
            info = dataset.get_variable_info()
            print(f"\n   Dataset info:")
            print(f"   - Shape: {info['dataset_shape']}")
            print(f"   - Geo features shape: {info['geographic_features_shape']}")
            print(f"   - Normalization stats computed: {len(info['normalization_stats'])}")
            
            self.dataset = dataset
            self.results['dataset_creation'] = True
            return True
            
        except Exception as e:
            print(f"❌ Dataset creation test failed: {e}")
            traceback.print_exc()
            self.results['dataset_creation'] = False
            return False
    
    def test_3_single_sample_retrieval(self):
        """Test 3: Single sample retrieval and tensor shapes"""
        print("\n" + "="*60)
        print("TEST 3: Single Sample Retrieval & Tensor Shapes")
        print("="*60)
        
        if not self.results.get('dataset_creation', False):
            print("❌ Skipping - dataset creation failed")
            return False
        
        try:
            if len(self.dataset) == 0:
                print("❌ Dataset is empty - cannot test sample retrieval")
                self.results['sample_retrieval'] = False
                return False
            
            # Test multiple samples
            for i in [0, len(self.dataset)//2, len(self.dataset)-1]:
                print(f"\nTesting sample {i}...")
                
                input_seq, target_seq, geo_features = self.dataset[i]
                
                print(f"✅ Sample {i} retrieved successfully")
                print(f"   Input sequence shape: {input_seq.shape}")
                print(f"   Target sequence shape: {target_seq.shape}")
                print(f"   Geographic features shape: {geo_features.shape}")
                
                # Validate tensor properties
                assert input_seq.dtype == torch.float32, f"Wrong input dtype: {input_seq.dtype}"
                assert target_seq.dtype == torch.float32, f"Wrong target dtype: {target_seq.dtype}"
                assert geo_features.dtype == torch.float32, f"Wrong geo dtype: {geo_features.dtype}"
                
                # Check for NaN values
                assert not torch.isnan(input_seq).any(), "Input contains NaN values"
                assert not torch.isnan(target_seq).any(), "Target contains NaN values"
                assert not torch.isnan(geo_features).any(), "Geo features contain NaN values"
                
                # Check value ranges are reasonable
                print(f"   Input value range: {input_seq.min():.4f} to {input_seq.max():.4f}")
                print(f"   Target value range: {target_seq.min():.4f} to {target_seq.max():.4f}")
                print(f"   Geo features range: {geo_features.min():.4f} to {geo_features.max():.4f}")
            
            print("✅ All sample retrievals successful")
            self.results['sample_retrieval'] = True
            return True
            
        except Exception as e:
            print(f"❌ Sample retrieval test failed: {e}")
            traceback.print_exc()
            self.results['sample_retrieval'] = False
            return False
    
    def test_4_dataloader_functionality(self):
        """Test 4: DataLoader functionality with error handling"""
        print("\n" + "="*60)
        print("TEST 4: DataLoader Functionality & Error Handling")
        print("="*60)
        
        if not self.results.get('dataset_creation', False):
            print("❌ Skipping - dataset creation failed")
            return False
        
        try:
            # Create datasets
            train_dataset, val_dataset, test_dataset = self.data_manager.create_datasets(
                time_range=slice("2023", "2023"),
                normalize=True
            )
            
            print(f"✅ Datasets created:")
            print(f"   Train: {len(train_dataset)} samples")
            print(f"   Val: {len(val_dataset)} samples")
            print(f"   Test: {len(test_dataset)} samples")
            
            # Create data loaders
            train_loader, val_loader, test_loader = self.data_manager.create_data_loaders(
                datasets=(train_dataset, val_dataset, test_dataset),
                batch_size=4,  # Small batch for testing
                num_workers=0  # No multiprocessing for testing
            )
            
            print(f"✅ DataLoaders created:")
            print(f"   Train batches: {len(train_loader)}")
            print(f"   Val batches: {len(val_loader)}")
            print(f"   Test batches: {len(test_loader)}")
            
            # Test loading batches
            print(f"\nTesting batch loading...")
            
            for loader_name, loader in [("Train", train_loader), ("Val", val_loader)]:
                if len(loader) > 0:
                    batch_input, batch_target, batch_geo = next(iter(loader))
                    
                    print(f"✅ {loader_name} batch loaded:")
                    print(f"   Batch input shape: {batch_input.shape}")
                    print(f"   Batch target shape: {batch_target.shape}")
                    print(f"   Batch geo shape: {batch_geo.shape}")
                    
                    # Validate batch properties
                    assert batch_input.dtype == torch.float32
                    assert batch_target.dtype == torch.float32
                    assert batch_geo.dtype == torch.float32
                    
                    # Check for NaN values
                    assert not torch.isnan(batch_input).any()
                    assert not torch.isnan(batch_target).any()
                    assert not torch.isnan(batch_geo).any()
                    
                    print(f"   ✅ Batch validation passed")
                else:
                    print(f"   ⚠️ {loader_name} loader is empty")
            
            self.train_loader = train_loader
            self.results['dataloader'] = True
            return True
            
        except Exception as e:
            print(f"❌ DataLoader test failed: {e}")
            traceback.print_exc()
            self.results['dataloader'] = False
            return False
    
    def test_5_normalization_effectiveness(self):
        """Test 5: Normalization effectiveness"""
        print("\n" + "="*60)
        print("TEST 5: Normalization Effectiveness")
        print("="*60)
        
        if not self.results.get('dataloader', False):
            print("❌ Skipping - dataloader test failed")
            return False
        
        try:
            # Test with normalization
            print("Testing WITH normalization...")
            dataset_norm = WeatherBench2Dataset(
                zarr_path=self.zarr_path,
                variables=self.test_variables[:3],  # Use first 3 variables
                time_range=slice("2023", "2023"),
                split="train",
                sequence_length=12,
                forecast_horizon=4,
                normalize=True
            )
            
            # Test without normalization
            print("Testing WITHOUT normalization...")
            dataset_no_norm = WeatherBench2Dataset(
                zarr_path=self.zarr_path,
                variables=self.test_variables[:3],  # Use first 3 variables
                time_range=slice("2023", "2023"),
                split="train",
                sequence_length=12,
                forecast_horizon=4,
                normalize=False
            )
            
            if len(dataset_norm) > 0 and len(dataset_no_norm) > 0:
                # Get sample from each
                input_norm, _, _ = dataset_norm[0]
                input_no_norm, _, _ = dataset_no_norm[0]
                
                print(f"\nNormalization comparison:")
                print(f"   Normalized - Mean: {input_norm.mean():.4f}, Std: {input_norm.std():.4f}")
                print(f"   Not normalized - Mean: {input_no_norm.mean():.4f}, Std: {input_no_norm.std():.4f}")
                print(f"   Normalized range: {input_norm.min():.4f} to {input_norm.max():.4f}")
                print(f"   Not normalized range: {input_no_norm.min():.4f} to {input_no_norm.max():.4f}")
                
                # Check that normalization actually changed the data
                if not torch.allclose(input_norm, input_no_norm, atol=1e-3):
                    print("✅ Normalization is working correctly")
                else:
                    print("⚠️ Normalization may not be applied correctly")
            
            self.results['normalization'] = True
            return True
            
        except Exception as e:
            print(f"❌ Normalization test failed: {e}")
            traceback.print_exc()
            self.results['normalization'] = False
            return False
    
    def test_6_memory_efficiency(self):
        """Test 6: Memory efficiency and batch processing"""
        print("\n" + "="*60)
        print("TEST 6: Memory Efficiency & Batch Processing")
        print("="*60)
        
        if not self.results.get('dataloader', False):
            print("❌ Skipping - dataloader test failed")
            return False
        
        try:
            # Test processing multiple batches
            print("Testing multiple batch processing...")
            
            batch_count = 0
            total_samples = 0
            
            for batch_idx, (input_batch, target_batch, geo_batch) in enumerate(self.train_loader):
                batch_count += 1
                total_samples += input_batch.shape[0]
                
                print(f"   Batch {batch_idx}: {input_batch.shape[0]} samples")
                
                # Test memory usage
                if torch.cuda.is_available():
                    print(f"   GPU memory: {torch.cuda.memory_allocated()/1e6:.1f}MB")
                
                # Limit test to first few batches
                if batch_count >= 3:
                    break
            
            print(f"✅ Processed {batch_count} batches, {total_samples} total samples")
            
            # Test iterator restart
            print("\nTesting DataLoader iterator restart...")
            batch_count_2 = sum(1 for _ in self.train_loader)
            print(f"✅ Second iteration: {batch_count_2} batches")
            
            self.results['memory_efficiency'] = True
            return True
            
        except Exception as e:
            print(f"❌ Memory efficiency test failed: {e}")
            traceback.print_exc()
            self.results['memory_efficiency'] = False
            return False
    
    def test_7_edge_cases(self):
        """Test 7: Edge cases and error handling"""
        print("\n" + "="*60)
        print("TEST 7: Edge Cases & Error Handling")
        print("="*60)
        
        try:
            print("Testing edge cases...")
            
            # Test with empty variable list
            print("\n1. Testing with empty variable list...")
            try:
                empty_dataset = WeatherBench2Dataset(
                    zarr_path=self.zarr_path,
                    variables=[],
                    time_range=slice("2023", "2023"),
                    split="train",
                    sequence_length=12,
                    forecast_horizon=4,
                    normalize=False
                )
                print("❌ Should have failed with empty variables")
            except Exception:
                print("✅ Correctly handled empty variable list")
            
            # Test with very small time range
            print("\n2. Testing with minimal time range...")
            try:
                minimal_dataset = WeatherBench2Dataset(
                    zarr_path=self.zarr_path,
                    variables=self.test_variables[:2],
                    time_range=slice("2023-01-01", "2023-01-02"),
                    split="train",
                    sequence_length=12,
                    forecast_horizon=4,
                    normalize=False
                )
                print(f"✅ Minimal dataset created: {len(minimal_dataset)} samples")
            except Exception as e:
                print(f"⚠️ Minimal time range failed: {e}")
            
            # Test with large sequence length
            print("\n3. Testing with large sequence length...")
            try:
                large_seq_dataset = WeatherBench2Dataset(
                    zarr_path=self.zarr_path,
                    variables=self.test_variables[:2],
                    time_range=slice("2023", "2023"),
                    split="train",
                    sequence_length=50,  # Very large
                    forecast_horizon=4,
                    normalize=False
                )
                print(f"✅ Large sequence dataset: {len(large_seq_dataset)} samples")
            except Exception as e:
                print(f"⚠️ Large sequence failed: {e}")
            
            # Test invalid time range
            print("\n4. Testing with invalid time range...")
            try:
                invalid_dataset = WeatherBench2Dataset(
                    zarr_path=self.zarr_path,
                    variables=self.test_variables[:2],
                    time_range=slice("2030", "2031"),  # Future dates
                    split="train",
                    sequence_length=12,
                    forecast_horizon=4,
                    normalize=False
                )
                print(f"⚠️ Invalid time range accepted: {len(invalid_dataset)} samples")
            except Exception:
                print("✅ Correctly handled invalid time range")
            
            self.results['edge_cases'] = True
            return True
            
        except Exception as e:
            print(f"❌ Edge cases test failed: {e}")
            traceback.print_exc()
            self.results['edge_cases'] = False
            return False
    
    def run_all_tests(self):
        """Run all comprehensive tests"""
        print("🚀 Starting Comprehensive MESA-Net Data Pipeline Tests")
        print("="*70)
        
        # Run all tests
        self.test_1_data_manager_connection()
        self.test_2_dataset_creation_validation()
        self.test_3_single_sample_retrieval()
        self.test_4_dataloader_functionality()
        self.test_5_normalization_effectiveness()
        self.test_6_memory_efficiency()
        self.test_7_edge_cases()
        
        # Print summary
        return self.print_summary()
    
    def print_summary(self):
        """Print comprehensive test results"""
        print("\n" + "="*70)
        print("🏁 COMPREHENSIVE DATA PIPELINE TEST RESULTS")
        print("="*70)
        
        tests = [
            ('data_manager', 'Data Manager Connection'),
            ('dataset_creation', 'Dataset Creation & Validation'),
            ('sample_retrieval', 'Single Sample Retrieval'),
            ('dataloader', 'DataLoader Functionality'),
            ('normalization', 'Normalization Effectiveness'),
            ('memory_efficiency', 'Memory Efficiency'),
            ('edge_cases', 'Edge Cases & Error Handling'),
        ]
        
        passed = 0
        total = len(tests)
        
        for test_key, test_name in tests:
            if test_key in self.results:
                status = "✅ PASS" if self.results[test_key] else "❌ FAIL"
                if self.results[test_key]:
                    passed += 1
            else:
                status = "⏭️ SKIP"
            
            print(f"{status} {test_name}")
        
        print(f"\nOverall: {passed}/{total} tests passed")
        
        if passed == total:
            print("🎉 ALL COMPREHENSIVE TESTS PASSED!")
            print("✅ Data pipeline is robust and ready for production")
            print("✅ Ready to proceed to model testing")
        elif passed >= total * 0.8:
            print("⚠️ Most tests passed - minor issues to fix")
            print("✅ Can proceed with caution")
        else:
            print("❌ Major issues found - need to fix before proceeding")
            
        return passed >= total * 0.8  # 80% pass rate required

if __name__ == "__main__":
    print("="*70)
    print("COMPREHENSIVE MESA-NET DATA PIPELINE TEST")
    print("="*70)
    print("This test validates the fixed data loading implementation")
    print("It checks error handling, edge cases, and production readiness")
    print("="*70)
    
    tester = ComprehensiveDataTester()
    success = tester.run_all_tests()
    
    if success:
        print("\n🎉 DATA PIPELINE IS PRODUCTION READY!")
        print("✅ All critical functionality working")
        print("✅ Error handling robust")
        print("✅ Ready for Step 2: Basic Model Testing")
    else:
        print("\n❌ DATA PIPELINE NEEDS FIXES")
        print("Please review the failed tests above")
        print("Fix issues before proceeding to model testing")
        sys.exit(1)