In [1]:
import sys
import os
import pandas as pd

# Add project root to Python path to find the 'src' directory
notebook_dir = os.getcwd()
project_root = os.path.abspath(os.path.join(notebook_dir, '..'))
if project_root not in sys.path:
    sys.path.append(project_root)
    print(f"Added project root to sys.path: {project_root}")

Added project root to sys.path: c:\Users\peera\Desktop\DroughtLSTM_oneday\src


In [3]:
import subprocess
from tqdm import tqdm
import os

DEST = "./1959-2023_era5.zarr"
SRC = "gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr"

if not os.path.exists(DEST) or not os.listdir(DEST):
    print("📥 Copying ERA5 dataset with full verification...")

    os.makedirs(DEST, exist_ok=True)  # Ensure destination folder exists

    with tqdm(total=1, desc="gsutil rsync", bar_format="{l_bar}{bar} [elapsed: {elapsed}]") as pbar:
        result = subprocess.run(
            ["gsutil", "-m", "rsync", "-r", "-c", SRC, DEST],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL
        )
        pbar.update(1)

    if result.returncode == 0:
        print("✅ Copy complete and verified.")
    else:
        print("❌ Copy failed. Please check your GCP permissions or bucket path.")
else:
    print("✅ Dataset already exists and appears non-empty. Skipping download.")


📥 Copying ERA5 dataset with full verification...


gsutil rsync:   0%|           [elapsed: 00:00]


FileNotFoundError: [WinError 2] The system cannot find the file specified

In [2]:
import torch
import torch.nn as nn
from src.mesanet.mesanet_datamanager import WeatherBench2DataLoader
from src.mesanet.mesanet_dataset import WeatherBench2Dataset
from src.mesanet import MESANet
from src.mesanet.mesanet_loss import MESANetLoss
from src.mesanet.mesanet_trainer import MESANetTrainer
from src.mesanet.mesanet_evaluator import MESANetEvaluator
from mesanet.state_machine import MemoryConfig

# MESA-Net Configuration and Training Script

def create_mesa_net_config():
    """Create default configuration for MESA-Net"""
    return {
        'model': {
            'input_channels': 18,  # Number of meteorological variables
            'num_layers': 3,
            'hidden_dim': 128,
            'memory_config': MemoryConfig(
                num_states=3,
                hidden_dim=128,
                learning_rates={
                    'alert': 0.1,
                    'normal': 0.01,
                    'suppressed': 0.001
                }
            )
        },
        'training': {
            'batch_size': 32,
            'learning_rate': 1e-4,
            'num_epochs': 100,
            'sequence_length': 12,  # 3 days of 6h data
            'forecast_horizon': 4,  # 1 day ahead
        },
        'data': {
            'zarr_path': "gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr",
            'variables': [
                'total_precipitation_6hr',
                '2m_temperature', '2m_dewpoint_temperature',
                'surface_pressure', 'mean_sea_level_pressure',
                '10m_u_component_of_wind', '10m_v_component_of_wind',
                'u_component_of_wind', 'v_component_of_wind',
                'specific_humidity', 'relative_humidity',
                'total_column_water_vapour', 'total_cloud_cover',
                'vertical_velocity', 'geopotential_at_surface'
            ],
            'europe_bounds': {
                'latitude': slice(75, 30),
                'longitude_mask': '(longitude >= 335) | (longitude <= 50)'
            }
        },
        'loss': {
            'alpha_prediction': 1.0,
            'alpha_state_entropy': 0.1,
            'alpha_transition_smooth': 0.01,
            'alpha_cross_memory': 0.05,
            'alpha_cross_layer': 0.05
        }
    }

def main():
    """Main function to demonstrate MESA-Net usage"""
    
    # Configuration
    config = create_mesa_net_config()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # 1. Initialize data loaders
    print("Initializing data loaders...")
    train_data_loader = WeatherBench2DataLoader(
        zarr_path=config['data']['zarr_path'],
        variables=config['data']['variables'],
        europe_bounds=config['data']['europe_bounds'],
        sequence_length=config['training']['sequence_length'],
        forecast_horizon=config['training']['forecast_horizon'],
        batch_size=config['training']['batch_size']
    )
    
    # Create validation data loader (same config, different time range)
    val_data_loader = WeatherBench2DataLoader(
        zarr_path=config['data']['zarr_path'],
        variables=config['data']['variables'],
        europe_bounds=config['data']['europe_bounds'],
        sequence_length=config['training']['sequence_length'],
        forecast_horizon=config['training']['forecast_horizon'],
        batch_size=config['training']['batch_size']
    )
    
    # 2. Initialize model
    print("Initializing MESA-Net model...")
    model = MESANet(
        input_channels=config['model']['input_channels'],
        num_layers=config['model']['num_layers'],
        hidden_dim=config['model']['hidden_dim'],
        memory_config=config['model']['memory_config']
    ).to(device)
    
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    
    # 3. Initialize loss function and optimizer
    loss_fn = MESANetLoss(
        alpha_prediction=config['loss']['alpha_prediction'],
        alpha_state_entropy=config['loss']['alpha_state_entropy'],
        alpha_transition_smooth=config['loss']['alpha_transition_smooth'],
        alpha_cross_memory=config['loss']['alpha_cross_memory'],
        alpha_cross_layer=config['loss']['alpha_cross_layer']
    )
    
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=config['training']['learning_rate'],
        weight_decay=1e-5
    )
    
    # 4. Initialize trainer
    trainer = MESANetTrainer(
        model=model,
        data_loader=train_data_loader,
        loss_fn=loss_fn,
        optimizer=optimizer,
        device=device,
        save_dir="./mesa_net_checkpoints"
    )
    
    # 5. Training
    print("Starting training...")
    trainer.train(
        num_epochs=config['training']['num_epochs'],
        val_data_loader=val_data_loader
    )
    
    # 6. Evaluation
    print("Starting evaluation...")
    evaluator = MESANetEvaluator(model, device)
    
    # Load test data (placeholder)
    test_data_loader = val_data_loader  # In practice, use separate test set
    
    # Evaluate model
    # metrics = evaluator._evaluate_model(model, test_data_loader)
    # print("Final metrics:", metrics)
    
    print("MESA-Net training and evaluation completed!")

if __name__ == "__main__":
    main()

ModuleNotFoundError: No module named 'src.mesanet'