## Introduction ##

- What is our 'product'?
- What issues are we attempting to solve/assist? Who are we helping? Why are we doing this?
- Data description (What data do we need to be able to build our model?)
- Overview of our upcoming sections of code. What is our process?
  - API
  - Building the predictions (Time series -> Survival Analysis)





## Helper functions ##

In [None]:
import json
import geopandas as gpd
import pandas as pd
import numpy as np
from bokeh.models import GeoJSONDataSource, LinearColorMapper
from bokeh.palettes import Viridis256
import os
from pathlib import Path
from pyproj import Transformer
from shapely.ops import transform
import fiona

# Will take a model and geodata and apply the survival function to the data so that it is in GeoJSON format for the map

def get_data_dir():
    """Return the path to the raw data directory (``<project root>/data/raw``).

    The project root is taken to be two levels above the directory that
    contains this file. ``__file__`` is not defined when the code runs inside
    a Jupyter notebook cell or an interactive session, so in that case the
    current working directory stands in for the file's directory.

    Returns:
        pathlib.Path: Location of the raw data directory. The path is not
        checked for existence.
    """
    try:
        base_dir = Path(__file__).parent
    except NameError:
        # Notebook / REPL execution: there is no module file to anchor on.
        base_dir = Path.cwd()
    project_root = base_dir.parent.parent
    return project_root / "data" / "raw"

def _read_geojson_with_fiona(file_path):
    """Read all features with fiona and assemble them into a GeoDataFrame."""
    import fiona
    import shapely.geometry

    with fiona.open(file_path, 'r') as src:
        crs = src.crs
        features = list(src)

    geoms = [shapely.geometry.shape(feature['geometry']) for feature in features]
    properties = [feature['properties'] for feature in features]
    return gpd.GeoDataFrame(properties, geometry=geoms, crs=crs)


def _read_geojson_manually(file_path):
    """Parse the file as plain JSON and assemble a GeoDataFrame from its features."""
    import json
    from shapely.geometry import shape

    # GeoJSON is UTF-8 by specification; do not depend on the locale default.
    with open(file_path, 'r', encoding='utf-8') as f:
        geojson_dict = json.load(f)

    features = geojson_dict.get('features', [])
    geoms = [shape(feature['geometry']) for feature in features]
    properties = [feature['properties'] for feature in features]

    gdf = gpd.GeoDataFrame(properties, geometry=geoms)
    if 'crs' in geojson_dict:
        crs = geojson_dict['crs']
        # A legacy GeoJSON "named CRS" member looks like
        # {"type": "name", "properties": {"name": "urn:ogc:def:crs:..."}}.
        # That wrapper dict is not a CRS definition itself; the name inside is.
        if isinstance(crs, dict) and crs.get('type') == 'name':
            crs = (crs.get('properties') or {}).get('name', crs)
        gdf.crs = crs
    return gdf


def load_geojson(file_name):
    """Load a GeoJSON file from the data directory as a GeoDataFrame.

    Loading strategy:

    1. ``geopandas.read_file``.
    2. If that raises an ``AttributeError`` mentioning a missing ``pyogrio``
       attribute: read with fiona, and if that fails too, parse the JSON by
       hand.
    3. If it raises any other non-``AttributeError`` exception: read with
       fiona.

    Progress and errors are reported with ``print``.

    Parameters:
    -----------
    file_name : str
        File name relative to the directory returned by ``get_data_dir()``.

    Returns:
    --------
    geopandas.GeoDataFrame

    Raises:
    -------
    AttributeError
        Re-raised unchanged when it is not the recognised pyogrio error.
    Exception
        The manual-parsing error when every pyogrio-branch fallback fails, or
        the original ``read_file`` error when the general fiona fallback fails.
    """
    data_dir = get_data_dir()
    file_path = os.path.join(data_dir, file_name)
    print(f"Loading GeoJSON from {file_path}")
    try:
        # First try standard geopandas approach
        gdf = gpd.read_file(file_path)
        print(f"Loaded GeoJSON from {file_name} using standard GeoPandas")
        return gdf
    except AttributeError as e:
        if "module 'pyogrio' has no attribute" not in str(e):
            print(f"ERROR Could not load GeoJSON {file_name}: {e}")
            raise
        print(f"ERROR loading '{file_name}' due to pyogrio error: {e}")
        try:
            gdf = _read_geojson_with_fiona(file_path)
            print(f"Loaded GeoJSON from {file_name} using fiona engine")
            return gdf
        except Exception as fiona_error:
            print(f"ERROR Fiona method failed also: {fiona_error}")
            try:
                # Last resort: manually parse JSON
                gdf = _read_geojson_manually(file_path)
                print(f"Loaded GeoJSON from {file_name} using manual JSON parsing")
                return gdf
            except Exception as json_error:
                print(f"ERROR Manual JSON parsing failed: {json_error}")
                print(f"ERROR: Could not load {file_name}")
                raise
    except Exception as general_error:
        print(f"ERROR Could not load GeoJSON {file_name}: {general_error}")
        try:
            # Try alternate method with fiona as a general fallback
            gdf = _read_geojson_with_fiona(file_path)
            print(f"Loaded GeoJSON from {file_name} using fallback method")
            return gdf
        except Exception as e2:
            print(f"ERROR Alternative loading also failed: {e2}")
            print(f"ERROR: Could not load {file_name}")
            # Surface the original failure; the fallback error was printed above.
            raise general_error

def load_gpkg(file_name, layer=None):
    """Read a GeoPackage from the data directory.

    When ``layer`` is given (and truthy) that layer is read. Otherwise the
    available layers are listed and the first one is read; if listing or that
    read fails, the file is read without naming a layer.
    """
    file_path = os.path.join(get_data_dir(), file_name)

    if not layer:
        try:
            # Discover which layers the package contains
            available = fiona.listlayers(file_path)
            if not len(available) > 0:
                return gpd.read_file(file_path)
            print(f"Available layers in {file_name}: {available}")
            return gpd.read_file(file_path, layer=available[0])
        except Exception as e:
            print(f"Error getting layers: {e}")
            return gpd.read_file(file_path)

    return gpd.read_file(file_path, layer=layer)

def load_terrain_data(file_name):
    """Load terrain data, choosing the loader from the file extension.

    ``.geojson`` files go through ``load_geojson`` and ``.gpkg`` files through
    ``load_gpkg``; any other extension raises ``ValueError``. The extension
    check is case-sensitive.
    """
    if not file_name.endswith(('.geojson', '.gpkg')):
        raise ValueError(f"Unsupported file format for {file_name}")

    if file_name.endswith('.geojson'):
        return load_geojson(file_name)
    return load_gpkg(file_name)

def wgs84_to_web_mercator(df):
    """Return a copy of *df* with geometries reprojected from WGS84 to Web Mercator.

    Coordinates are transformed from EPSG:4326 to EPSG:3857 in lon/lat
    (x, y) order. The input frame is not modified.

    Parameters:
    -----------
    df : geopandas.GeoDataFrame
        Frame with a ``geometry`` column holding WGS84 geometries.

    Returns:
    --------
    geopandas.GeoDataFrame
        Copy of *df* whose ``geometry`` column holds the transformed geometries.
    """
    transformer = Transformer.from_crs("EPSG:4326", "EPSG:3857", always_xy=True)

    # Create new geometry column with transformed coordinates.
    # The bound method is passed directly (rather than a two-argument lambda)
    # so geometries carrying a Z coordinate work too: shapely calls the
    # function with (x, y, z) for 3D geometries.
    df = df.copy()
    df['geometry'] = df['geometry'].apply(
        lambda geom: transform(transformer.transform, geom)
    )
    # NOTE(review): only the coordinates are changed here; the frame's CRS
    # metadata is not explicitly updated — confirm whether callers rely on it.
    return df

def convert_to_serializable(obj):
    """Convert a value into something ``json.dumps`` can serialize.

    Conversion rules, checked in this order:

    * ``None`` is returned unchanged (serializes as JSON ``null``).
    * NumPy booleans / integers / floats become ``bool`` / ``int`` / ``float``.
    * NumPy arrays become (nested) lists.
    * ``pandas.Timestamp``, ``datetime`` and ``date`` become ISO-8601 strings.
    * Native ``bool`` / ``int`` / ``float`` / ``str`` values pass through.
    * Lists and tuples become lists, and dicts keep their keys, with every
      element / value converted recursively.
    * Objects exposing ``to_dict()`` are replaced by its result.
    * Anything else falls back to ``str(obj)``.

    Parameters:
    -----------
    obj : Any
        Value to convert.

    Returns:
    --------
    A JSON-friendly representation of *obj*.
    """
    import numpy as np
    import pandas as pd
    from datetime import datetime, date

    if obj is None:
        return None
    # NumPy scalars first: np.float64 also subclasses the built-in float, and
    # it should still be normalised to a plain float.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, (np.integer, np.int64, np.int32)):
        return int(obj)
    if isinstance(obj, (np.floating, np.float64, np.float32)):
        return float(obj)
    if isinstance(obj, (np.ndarray,)):
        return obj.tolist()
    if isinstance(obj, (pd.Timestamp, datetime, date)):
        return obj.isoformat()
    # Native JSON scalars must not be stringified (5 -> "5", True -> "True").
    if isinstance(obj, (bool, int, float)):
        return obj
    if isinstance(obj, str):
        return str(obj)
    if isinstance(obj, (list, tuple)):
        return [convert_to_serializable(item) for item in obj]
    if isinstance(obj, dict):
        return {key: convert_to_serializable(value) for key, value in obj.items()}
    if hasattr(obj, 'to_dict'):
        return obj.to_dict()
    return str(obj)

def gdf_to_geojson(gdf):
    """
    Serialize a GeoDataFrame to a GeoJSON string.

    Datetime columns are turned into strings and numeric columns are mapped
    element-wise to native numbers before the frame's own ``to_json`` output
    is re-encoded. If that fails, the features are assembled by hand, and
    rows that cannot be converted are reported and left out.

    Parameters:
    -----------
    gdf : geopandas.GeoDataFrame
        GeoDataFrame to convert

    Returns:
    --------
    str
        GeoJSON string
    """
    import json
    import pandas as pd
    import numpy as np
    from datetime import datetime

    type_checks = pd.api.types

    def _to_native(value):
        # Map a single cell to a native float/int where its type says so
        if type_checks.is_float_dtype(type(value)):
            return float(value)
        if type_checks.is_integer_dtype(type(value)):
            return int(value)
        return value

    class CustomEncoder(json.JSONEncoder):
        def default(self, obj):
            if isinstance(obj, (np.integer, np.int64, np.int32)):
                return int(obj)
            if isinstance(obj, (np.floating, np.float64, np.float32)):
                return float(obj)
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            if isinstance(obj, (pd.Timestamp, datetime)):
                return obj.isoformat()
            if hasattr(obj, 'to_dict'):
                return obj.to_dict()
            return super().default(obj)

    # Work on a copy so the caller's frame is left untouched
    frame = gdf.copy()

    for column in frame.columns:
        if column == 'geometry':
            continue
        series = frame[column]
        if type_checks.is_datetime64_any_dtype(series):
            # Timestamps are written out as strings
            frame[column] = series.astype(str)
        elif type_checks.is_numeric_dtype(series):
            frame[column] = series.apply(_to_native)

    try:
        # Round-trip through the frame's own JSON output, then re-encode
        geo_dict = json.loads(frame.to_json())
        return json.dumps(geo_dict, cls=CustomEncoder)

    except Exception as e:
        print(f"Error in GeoJSON serialization: {e}")

        # Fallback approach - build the FeatureCollection row by row
        property_columns = [name for name in frame.columns if name != 'geometry']
        features = []
        for idx, row in frame.iterrows():
            try:
                feature = {
                    "type": "Feature",
                    "properties": {
                        name: convert_to_serializable(row[name])
                        for name in property_columns
                    },
                    "geometry": row['geometry'].__geo_interface__
                }
                features.append(feature)
            except Exception as feat_e:
                print(f"Error processing feature {idx}: {feat_e}")

        return json.dumps({
            "type": "FeatureCollection",
            "features": features
        })


## Utilizing an API to retrieve precipitation data for all of Denmark ##

To retrieve the necessary precipitation data, we used the Meteorological Observation API from DMI (Danmarks Meteorologiske Institut, https://dmiapi.govcloud.dk). This API gave us access to 86 weather observation stations located all across Denmark. The extraction was extensive, which required us to split it into multiple JSON files and finally collect the data into a single Parquet file. The API also provided the longitude and latitude of each station, which we used in the subsequent mapping of the stations.

In [None]:

#=================================================================================================#
# This script retrieves weather data from the DMI API for all stations in Denmark.
# It saves the data to JSON files and creates a map visualization using Folium.
# It handles pagination, retries, and error handling for API requests.
# It also includes logging for better tracking of the process.
#=================================================================================================#

import requests, os, json, folium, logging, time, random, gc, time, logging
from tqdm import tqdm
import pandas as pd
from datetime import datetime

# Configure logging: every message goes both to dmi_api.log and to the console
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("dmi_api.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger()

# Define your API key.
# A key supplied through the DMI_API_KEY environment variable takes precedence,
# so a personal key never has to be written into the source; the literal below
# is only the fallback.
api_key = os.environ.get('DMI_API_KEY') or 'd111ba1d-a1f5-43a5-98c6-347e9c2729b2'  # Replace with your actual DMI API key

# Memory and performance settings
MAX_MEMORY_USAGE_GB = 14     # Increased for maximum performance
MAX_THREADS = 8              # Increased for better parallelism
RATE_LIMIT_DELAY = 0.5       # Seconds to wait between consecutive API requests
MAX_RETRIES = 100             # High number of retries for resilience
RETRY_DELAY = 20             # Base delay in seconds before retrying a failed request
EXPONENTIAL_BACKOFF = True   # Still using exponential backoff to handle rate limits

# Define the output directory
output_dir = './dmi_data_daily'  # Changed to relative path for portability
# Create the directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

# Define the bounding box for Denmark [min_lon, min_lat, max_lon, max_lat]
denmark_bbox = [7.253075, 54.303704, 13.321548, 57.809651]

# Define the parameters to retrieve
parameters = [
    "precip_past1h"  # Precipitation in the last hour
]

# Function to save current state of all station data - MOVED UP before it's used
def save_station_data(stations_dict, output_directory, current_parameter):
    """Write the collected station data and a progress marker as JSON files.

    Three kinds of files are written below ``output_directory``:

    * ``parameters/parameter_<current_parameter>.json`` with, for every
      station holding ``current_parameter``, its id, name, location and that
      parameter's data only;
    * ``stations/station_<id>.json`` with the full record of every station
      that has any parameter data;
    * ``parameter_progress.json`` recording ``current_parameter``, the current
      time and the number of station files written.

    Returns the number of station files written.
    """
    logger.info("Saving current data to files...")

    # Parameter-specific snapshot (kept for recovery if needed)
    params_dir = os.path.join(output_directory, 'parameters')
    os.makedirs(params_dir, exist_ok=True)

    stations_with_param = {
        station_id: {
            'stationId': record.get('stationId'),
            'name': record.get('name', ''),
            'location': record.get('location', {}),
            'parameters': {current_parameter: record['parameters'][current_parameter]}
        }
        for station_id, record in stations_dict.items()
        if current_parameter in record.get('parameters', {})
    }

    param_file = os.path.join(params_dir, f'parameter_{current_parameter}.json')
    with open(param_file, 'w', encoding='utf-8') as f:
        json.dump(stations_with_param, f)

    # One file per station, holding everything accumulated so far
    stations_dir = os.path.join(output_directory, 'stations')
    os.makedirs(stations_dir, exist_ok=True)

    saved_count = 0
    for station_id, record in tqdm(stations_dict.items(), desc="Saving station data"):
        if not record.get('parameters'):
            continue  # nothing collected for this station
        filename = os.path.join(stations_dir, f'station_{station_id}.json')
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(record, f, indent=2)
        saved_count += 1

    logger.info(f"Saved data for {saved_count} stations to {output_directory}")

    # Progress marker describing this save
    progress_file = os.path.join(output_directory, 'parameter_progress.json')
    with open(progress_file, 'w', encoding='utf-8') as f:
        progress = {
            'last_processed_parameter': current_parameter,
            'timestamp': datetime.now().isoformat(),
            'stations_saved': saved_count
        }
        json.dump(progress, f, indent=2)

    return saved_count

# Time frame for the extraction: the 30 years leading up to "now" (UTC),
# written as a "<start>/<end>" interval of ISO-8601 timestamps
end_time = pd.Timestamp.now(tz='UTC')
start_time = end_time - pd.DateOffset(years=30)
datetime_str = "/".join([start_time.isoformat(), end_time.isoformat()])

# Function to retrieve all stations with retry logic
def get_all_stations(api_key):
    """Fetch every station feature from the DMI station collection.

    Pages are followed through the ``next`` link of each response. A failed
    request is retried after ``RETRY_DELAY`` seconds; once ``MAX_RETRIES``
    failures have accumulated, whatever was collected so far is returned.
    """
    url = 'https://dmigw.govcloud.dk/v2/metObs/collections/station/items'
    params = {'api-key': api_key, 'limit': '10000'}
    stations = []

    failures = 0
    while failures < MAX_RETRIES:
        try:
            logger.info(f"Retrieving stations (attempt {failures + 1}/{MAX_RETRIES})...")
            response = requests.get(url, params=params, timeout=60)
            response.raise_for_status()
            payload = response.json()
            stations.extend(payload['features'])

            # Look for a link to the following page
            next_link = None
            for link in payload['links']:
                if link['rel'] == 'next':
                    next_link = link
                    break

            if not next_link:
                logger.info(f"Station retrieval complete. Found {len(stations)} stations.")
                break

            logger.info("Found next page link, continuing pagination...")
            url = next_link['href']
            params = {}  # the original query parameters are not re-sent with the link

            # Pause between pages to avoid rate limiting
            time.sleep(RATE_LIMIT_DELAY)

        except requests.RequestException as e:
            logger.error(f"Error retrieving stations: {e}")
            failures += 1
            if failures >= MAX_RETRIES:
                logger.error("Maximum retry attempts reached. Proceeding with collected stations.")
                break
            logger.info(f"Retrying in {RETRY_DELAY} seconds...")
            time.sleep(RETRY_DELAY)

    return stations

# Function to get data for a specific parameter with improved error handling
def get_data_for_parameter(parameter_id, datetime_str, api_key, bbox=None, time_chunks=1):
    """Retrieve observations for one parameter from the DMI observation endpoint.

    The ``<start>/<end>`` interval in *datetime_str* is split into
    *time_chunks* consecutive sub-intervals. Each sub-interval is paged
    through with ``limit``/``offset`` requests. Records are written to disk
    through ``save_parameter_batch`` whenever at least 100000 have
    accumulated, and once more at the end of each sub-interval.

    Failed requests are retried up to ``MAX_RETRIES`` times, with exponential
    backoff plus jitter when ``EXPONENTIAL_BACKOFF`` is set and linear backoff
    otherwise. When the retries for a page are exhausted, the rest of that
    sub-interval is skipped.

    Parameters:
    -----------
    parameter_id : str
        Value sent as the ``parameterId`` query parameter.
    datetime_str : str
        Interval of the form ``"<start>/<end>"``; both parts must be parseable
        by ``pandas.Timestamp``.
    api_key : str
        Value sent as the ``api-key`` query parameter.
    bbox : sequence of 4 numbers, optional
        Sent as a comma-separated ``bbox`` query parameter when given.
    time_chunks : int
        Number of sub-intervals to split the time frame into.

    Returns:
    --------
    list
        Every feature record retrieved, across all sub-intervals.

    NOTE: the returned list keeps every record in memory even though batches
    are also flushed to disk; callers that only need the files can ignore it.
    """
    all_data = []

    # Parse start and end times
    times = datetime_str.split('/')
    start_time = pd.Timestamp(times[0])
    end_time = pd.Timestamp(times[1])

    # Calculate the time delta for each chunk
    total_days = (end_time - start_time).days
    days_per_chunk = max(1, total_days // time_chunks)

    logger.info(f"Splitting timeframe into {time_chunks} chunks of approximately {days_per_chunk} days each")

    # Process each time chunk
    for i in range(time_chunks):
        chunk_start = start_time + pd.Timedelta(days=i * days_per_chunk)
        # The last chunk always runs to end_time so integer division drops nothing
        chunk_end = start_time + pd.Timedelta(days=(i+1) * days_per_chunk) if i < time_chunks - 1 else end_time
        chunk_datetime_str = f"{chunk_start.isoformat()}/{chunk_end.isoformat()}"

        logger.info(f"Processing time chunk {i+1}/{time_chunks}: {chunk_start.date()} to {chunk_end.date()}")

        # Set up the request parameters
        url = 'https://dmigw.govcloud.dk/v2/metObs/collections/observation/items'
        params = {
            'api-key': api_key,
            'datetime': chunk_datetime_str,
            'parameterId': parameter_id,
            'limit': '10000'  # Page size requested from the API
        }

        # Add bbox parameter if provided
        if bbox:
            params['bbox'] = f"{bbox[0]},{bbox[1]},{bbox[2]},{bbox[3]}"

        # Variables to track pagination
        offset = 0
        max_offset = 490000  # Stay below the 500,000 limit
        has_more = True
        chunk_data = []

        # One iteration per page; the inner loop retries that page
        while has_more and offset < max_offset:
            retry_count = 0
            success = False

            while retry_count < MAX_RETRIES and not success:
                try:
                    # Add the offset parameter for pagination
                    if offset > 0:
                        params['offset'] = str(offset)

                    logger.info(f"Making request to: {url} with offset {offset} (attempt {retry_count + 1}/{MAX_RETRIES})")
                    r = requests.get(url, params=params, timeout=120)
                    r.raise_for_status()
                    json_data = r.json()

                    if 'features' in json_data:
                        batch_size = len(json_data['features'])
                        logger.info(f"Retrieved {batch_size} records in this batch")
                        chunk_data.extend(json_data['features'])

                        # Save data periodically to avoid memory issues
                        if len(chunk_data) >= 100000:
                            logger.info("Saving intermediate batch to avoid memory issues...")
                            save_parameter_batch(parameter_id, chunk_data, chunk_datetime_str)
                            all_data.extend(chunk_data)  # Add to total count
                            chunk_data = []  # Clear for next batch
                            gc.collect()  # Force garbage collection

                        # A short page means this was the last one
                        if batch_size < int(params['limit']):
                            has_more = False
                        else:
                            offset += batch_size

                        success = True
                    else:
                        logger.warning(f"No 'features' in response. Response: {json_data}")
                        has_more = False
                        success = True

                except requests.RequestException as e:
                    error_msg = str(e)
                    logger.error(f"Error retrieving data: {error_msg}")

                    if hasattr(e, 'response') and e.response is not None:
                        status_code = e.response.status_code
                        response_text = e.response.text if hasattr(e.response, 'text') else "No response text"
                        logger.error(f"Status code: {status_code}, Response: {response_text}")

                        # Handle specific errors
                        if status_code == 429:  # Too Many Requests
                            logger.warning("Rate limit exceeded. Increasing wait time.")
                            time.sleep(RETRY_DELAY * 2)  # Extra wait on top of the backoff below
                        elif status_code in [502, 503, 504]:  # Server errors
                            logger.warning(f"Server error {status_code}. Will retry.")
                        elif status_code == 400 and 'Offset cannot be greater than 500000' in response_text:
                            logger.warning("Hit offset limit. Saving current batch and continuing with next time chunk.")
                            has_more = False
                            success = True  # Ends the retry loop; has_more ends the page loop

                    # The offset-limit case above is handled, not failed: it
                    # must not be counted or reported as an exhausted retry.
                    if not success:
                        retry_count += 1
                        if retry_count < MAX_RETRIES:
                            if EXPONENTIAL_BACKOFF:
                                # Exponential backoff with jitter, capped at 600 seconds
                                wait_time = min(RETRY_DELAY * (2 ** (retry_count - 1)) + (random.randint(0, 1000) / 1000), 600)
                            else:
                                wait_time = RETRY_DELAY * retry_count  # Linear backoff

                            logger.info(f"Retrying in {wait_time:.1f} seconds...")
                            time.sleep(wait_time)
                        else:
                            logger.error("Maximum retry attempts reached for this batch.")
                            has_more = False  # Stop trying this chunk

                # Add a delay between requests to avoid rate limiting
                time.sleep(RATE_LIMIT_DELAY)

        # Leaving the page loop with has_more still set means the offset cap
        # stopped pagination: the rest of this chunk was not fetched.
        if has_more and offset >= max_offset:
            logger.warning(
                f"Time chunk {i+1}/{time_chunks} reached the offset cap ({max_offset}); "
                "records beyond it were not retrieved. Use more time_chunks to avoid truncation."
            )

        # Save any remaining data from this chunk
        if chunk_data:
            logger.info(f"Saving final batch of {len(chunk_data)} records for time chunk {i+1}")
            save_parameter_batch(parameter_id, chunk_data, chunk_datetime_str)
            all_data.extend(chunk_data)
            chunk_data = []
            gc.collect()

    total_records = len(all_data)
    logger.info(f"Total records retrieved for {parameter_id}: {total_records}")
    return all_data

# Function to save a batch of parameter data
def save_parameter_batch(parameter_id, batch_data, datetime_str):
    """Write one batch of observation records to its own JSON file.

    The file is created under ``<output_dir>/parameter_batches/<parameter_id>/``
    and contains the records together with metadata (parameter id, datetime
    range, record count, creation time). The batch is then registered through
    ``update_parameter_tracking``. An empty batch is ignored.

    Parameters:
    -----------
    parameter_id : str
        Parameter the records belong to; used in directory and file names.
    batch_data : list
        Records to store; must be JSON serializable.
    datetime_str : str
        Datetime range the records were requested for (stored as metadata).
    """
    if not batch_data:
        return

    batch_size = len(batch_data)
    logger.info(f"Saving batch of {batch_size} records for {parameter_id}...")

    # Create a batch directory
    batch_dir = os.path.join(output_dir, 'parameter_batches', parameter_id)
    os.makedirs(batch_dir, exist_ok=True)

    # Generate a unique batch filename based on timestamp. Microseconds are
    # included so two batches written within the same second cannot overwrite
    # each other.
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
    batch_file = os.path.join(batch_dir, f'{parameter_id}_batch_{timestamp}.json')

    # Save the batch with metadata
    batch_metadata = {
        'parameter_id': parameter_id,
        'datetime_range': datetime_str,
        'record_count': batch_size,
        'created_at': datetime.now().isoformat(),
        'data': batch_data
    }

    with open(batch_file, 'w', encoding='utf-8') as f:
        json.dump(batch_metadata, f)

    logger.info(f"Saved batch to {batch_file}")

    # Update the parameter tracking file
    update_parameter_tracking(parameter_id, batch_size, batch_file, datetime_str)

# Function to update the parameter tracking file
def update_parameter_tracking(parameter_id, batch_size, batch_file, datetime_str):
    """Record a newly written batch in ``<output_dir>/parameter_tracking.json``.

    The tracking file maps each parameter id to a running record total and a
    list of batch entries (file name, record count, datetime range, creation
    time). If an existing tracking file cannot be read, the error is logged
    and tracking starts again from an empty mapping.
    """
    tracking_file = os.path.join(output_dir, 'parameter_tracking.json')

    # Start from what is already on disk, if anything
    tracking_data = {}
    if os.path.exists(tracking_file):
        try:
            with open(tracking_file, 'r', encoding='utf-8') as f:
                tracking_data = json.load(f)
        except Exception as e:
            logger.error(f"Error reading tracking file: {e}")

    # First batch for this parameter: create its entry
    if parameter_id not in tracking_data:
        tracking_data[parameter_id] = {
            'total_records': 0,
            'batches': []
        }

    entry = tracking_data[parameter_id]
    entry['total_records'] += batch_size
    batch_record = {
        'file': os.path.basename(batch_file),
        'records': batch_size,
        'datetime_range': datetime_str,
        'created_at': datetime.now().isoformat()
    }
    entry['batches'].append(batch_record)

    # Write the updated mapping back
    with open(tracking_file, 'w', encoding='utf-8') as f:
        json.dump(tracking_data, f, indent=2)

# Function to combine parameter batches
def combine_parameter_batches(parameter_id):
    """Combine all stored batch files of a parameter into per-station data.

    Every ``.json`` file in ``<output_dir>/parameter_batches/<parameter_id>/``
    is read in sorted file-name order, and each record in its ``data`` list is
    grouped under the record's ``properties.stationId``.

    Parameters:
    -----------
    parameter_id : str
        Parameter whose batch directory should be combined.

    Returns:
    --------
    dict
        ``{station_id: {'stationId': station_id,
        'parameters': {parameter_id: [record, ...]}}}``. Empty when the batch
        directory does not exist or holds no ``.json`` files. A batch file
        that fails to load or process is logged and skipped; records added
        from it before the failure are kept.
    """
    batch_dir = os.path.join(output_dir, 'parameter_batches', parameter_id)
    if not os.path.exists(batch_dir):
        logger.warning(f"No batch directory found for parameter {parameter_id}")
        return {}

    # Dictionary to hold station data
    stations_data = {}

    # os.listdir returns entries in arbitrary order; sorting makes the order
    # of the combined records reproducible from run to run.
    batch_files = sorted(f for f in os.listdir(batch_dir) if f.endswith('.json'))
    if not batch_files:
        logger.warning(f"No batch files found for parameter {parameter_id}")
        return {}

    logger.info(f"Found {len(batch_files)} batch files for parameter {parameter_id}")

    # Process each batch file
    for batch_file in tqdm(batch_files, desc=f"Processing {parameter_id} batches"):
        try:
            with open(os.path.join(batch_dir, batch_file), 'r', encoding='utf-8') as f:
                batch_data = json.load(f)

            # Process each record in the batch
            for record in batch_data['data']:
                station_id = record['properties']['stationId']

                # Create the station entry and its parameter list on first sight
                station_entry = stations_data.setdefault(
                    station_id,
                    {'stationId': station_id, 'parameters': {}}
                )
                station_entry['parameters'].setdefault(parameter_id, []).append(record)
        except Exception as e:
            logger.error(f"Error processing batch file {batch_file}: {e}")

    return stations_data

# Main execution function
def _stations_within_bbox(stations, bbox):
    """Return ``{stationId: record}`` for the station features located inside *bbox*."""
    min_lon, min_lat, max_lon, max_lat = bbox
    selected = {}
    for station in stations:
        geometry = station.get('geometry') or {}
        coords = geometry.get('coordinates') or []
        if len(coords) < 2:
            # No usable position: the station cannot be placed inside the box
            continue
        lon, lat = coords[0], coords[1]
        if min_lon <= lon <= max_lon and min_lat <= lat <= max_lat:
            station_id = station['properties']['stationId']
            selected[station_id] = {
                'stationId': station_id,
                'name': station['properties'].get('name', ''),
                'location': {
                    'longitude': lon,
                    'latitude': lat
                },
                'parameters': {}
            }
    return selected


def _last_processed_index(progress_file):
    """Return the index in ``parameters`` recorded in *progress_file*, or -1."""
    if not os.path.exists(progress_file):
        return -1
    try:
        with open(progress_file, 'r', encoding='utf-8') as f:
            progress_data = json.load(f)
        last_param = progress_data.get('last_processed_parameter')
        if last_param in parameters:
            last_processed_idx = parameters.index(last_param)
            logger.info(f"Resuming from parameter: {last_param} (index {last_processed_idx})")
            return last_processed_idx
    except Exception as e:
        logger.error(f"Error reading progress file: {e}")
    return -1


# Main execution function
def main():
    """Run the full extraction: stations, observations per parameter, station files.

    Steps:

    1. Retrieve all stations and keep those inside ``denmark_bbox``.
    2. Read ``parameter_progress.json`` (if present) to find the parameter to
       resume after.
    3. For each remaining parameter: download its observations in batches,
       combine the batches per station, merge them into the station records
       and save everything. If a parameter fails, the error is logged and the
       current state is saved under ``"<parameter>_error"``.
    4. Perform a final save.
    """
    # Retrieve all stations
    logger.info("Starting DMI data collection process")
    all_stations = get_all_stations(api_key)

    # Filter stations within the bounding box
    filtered_stations = _stations_within_bbox(all_stations, denmark_bbox)
    logger.info(f"Found {len(filtered_stations)} stations within the bounding box.")

    # Check for existing progress file (-1 means: start from the beginning)
    progress_file = os.path.join(output_dir, 'parameter_progress.json')
    last_processed_idx = _last_processed_index(progress_file)
    # NOTE(review): parameters skipped on resume are not reloaded into
    # filtered_stations, so the station files written below will not contain
    # their data — confirm whether that is intended.

    # Create directory for parameter batches
    batch_dir = os.path.join(output_dir, 'parameter_batches')
    os.makedirs(batch_dir, exist_ok=True)

    # Process each parameter for all stations at once, starting after the last processed parameter
    for i, parameter in enumerate(parameters[last_processed_idx+1:], start=last_processed_idx+1):
        logger.info(f"\nRetrieving {parameter} data for all stations... ({i+1}/{len(parameters)})")

        # Use smaller time chunks for better handling
        time_chunks = 12  # Split into smaller chunks to reduce server load and handle errors better

        # First retrieve parameter data in batches
        try:
            logger.info(f"Retrieving data for parameter {parameter} in batches...")
            get_data_for_parameter(parameter, datetime_str, api_key, denmark_bbox, time_chunks=time_chunks)

            # Now combine all batches for this parameter
            logger.info(f"Combining batches for parameter {parameter}...")
            parameter_stations = combine_parameter_batches(parameter)

            # Merge parameter data into the known stations; observations from
            # stations outside the filtered set are ignored.
            for station_id, station_data in parameter_stations.items():
                if station_id in filtered_stations and parameter in station_data['parameters']:
                    filtered_stations[station_id]['parameters'][parameter] = station_data['parameters'][parameter]

            # Save all station data after each parameter is processed
            logger.info(f"Completed processing parameter: {parameter}")
            save_station_data(filtered_stations, output_dir, parameter)

            # Free up memory
            parameter_stations = None
            gc.collect()

        except Exception as e:
            logger.error(f"Error processing parameter {parameter}: {e}", exc_info=True)
            # Still save progress even if we encounter an error
            save_station_data(filtered_stations, output_dir, f"{parameter}_error")

    # Final save to ensure everything is written
    logger.info("\nPerforming final data save...")
    save_count = save_station_data(filtered_stations, output_dir, "final")

    logger.info(f"Data extraction complete. Saved data for {save_count} stations.")

# Entry point: run the extraction and report how it ended
if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        # Log the failure with its traceback instead of letting it propagate
        logger.critical(f"Unexpected error: {e}", exc_info=True)
        logger.info("Program terminated with errors")
    else:
        logger.info("Program completed successfully")





#=======================================================================================================#
# This script extracts station information from the DMI API and saves it to a Parquet file. It also     #
# creates a map visualization of the stations using Folium. The script first attempts to access the     #
# stations endpoint directly. If that fails, it falls back to extracting unique station IDs from        #
# observation data. The script handles potential issues with duplicate columns and missing coordinates. #
# It also includes error handling for API requests and file saving. The map visualization is saved as   #
# an HTML file in the specified directory. The script is designed to be run in a Python environment     #
# with the necessary libraries installed.                                                               #
#=======================================================================================================#



import requests
import pandas as pd
import folium
import os

# DMI API key. Set the DMI_API_KEY environment variable to supply your own key;
# the literal below is kept only as a backward-compatible fallback.
# NOTE(review): a credential is committed to source control here - rotate this key
# and drop the fallback once every user provides DMI_API_KEY.
api_key = os.environ.get('DMI_API_KEY') or 'd111ba1d-a1f5-43a5-98c6-347e9c2729b2'

# Method 1: Try to access stations endpoint directly
stations_url = 'https://dmigw.govcloud.dk/v2/metObs/collections/station/items'

def get_all_stations_method1():
    """
    Attempt to get all stations using the dedicated station endpoint.

    Sends one GET request to ``stations_url`` (passing ``api_key`` as the
    ``api-key`` query parameter) and flattens every feature of the returned
    JSON into one row.

    Returns:
        pandas.DataFrame | None: One row per feature with ``id``,
        ``feature_type``, ``longitude``/``latitude`` and ``geometry_type``
        (when the feature carries them) plus every key found under
        ``properties`` (``type`` is stored as ``station_type``). ``None`` is
        returned when the status code is not 200 or any error occurs.
    """
    try:
        # A timeout keeps the script from hanging forever on a stalled connection;
        # a timeout error is reported by the handler below like any other failure.
        response = requests.get(stations_url, params={'api-key': api_key}, timeout=60)
        if response.status_code != 200:
            print(f"Failed to access station endpoint. Status code: {response.status_code}")
            return None

        features = response.json().get('features') or []

        # Print the first feature to debug the structure
        if features:
            print("First station feature structure:")
            print(features[0])

        stations_list = []
        for feature in features:
            station = {
                'id': feature.get('id'),
                'feature_type': feature.get('type'),
            }

            # GeoJSON allows "geometry": null, so treat a missing and a null
            # geometry alike instead of failing the whole request on one feature.
            geometry = feature.get('geometry') or {}
            coords = geometry.get('coordinates')
            if coords and len(coords) >= 2:
                station['longitude'] = coords[0]
                station['latitude'] = coords[1]
            if 'type' in geometry:
                station['geometry_type'] = geometry['type']

            for key, value in (feature.get('properties') or {}).items():
                # Store the property 'type' under its own name so it is not
                # confused with the feature type or the geometry type.
                station['station_type' if key == 'type' else key] = value

            stations_list.append(station)

        return pd.DataFrame(stations_list)
    except Exception as e:
        print(f"Error accessing station endpoint: {e}")
        return None

# Method 2: Extract unique stations from observation data
def get_all_stations_method2():
    """
    Get all stations by extracting unique station IDs from observation data.

    Requests up to 300000 observation records in a single call and keeps the
    first record seen for every ``stationId``.

    Returns:
        pandas.DataFrame | None: One row per distinct station with
        ``stationId`` and, when the observation carries coordinates,
        ``longitude`` and ``latitude``. ``None`` is returned when the status
        code is not 200 or any error occurs.
    """
    dmi_url = 'https://dmigw.govcloud.dk/v2/metObs/collections/observation/items'

    try:
        # Request with a high limit to get as many records as possible
        params = {
            'api-key': api_key,
            'limit': '300000'  # Maximum allowed limit
        }

        # A timeout keeps the script from hanging forever on a stalled connection;
        # it is generous because the response can be large.
        response = requests.get(dmi_url, params=params, timeout=300)
        if response.status_code != 200:
            print(f"Failed to retrieve data. Status code: {response.status_code}")
            return None

        json_data = response.json()

        # Insertion-ordered mapping: the first observation of each station wins
        stations_by_id = {}
        for feature in json_data.get('features') or []:
            # "properties" / "geometry" may be missing or null; neither case
            # should abort the extraction of the remaining stations.
            properties = feature.get('properties') or {}
            if 'stationId' not in properties:
                continue

            station_id = properties['stationId']
            if station_id in stations_by_id:
                continue

            station = {'stationId': station_id}
            coords = (feature.get('geometry') or {}).get('coordinates')
            if coords and len(coords) >= 2:
                station['longitude'] = coords[0]
                station['latitude'] = coords[1]

            stations_by_id[station_id] = station

        return pd.DataFrame(list(stations_by_id.values()))

    except Exception as e:
        print(f"Error extracting stations from observation data: {e}")
        return None

# Try both methods and use the one that works
from html import escape  # popup text comes from the API and must be HTML-escaped

print("Attempting to get stations using Method 1...")
stations_df = get_all_stations_method1()
if stations_df is None or stations_df.empty:
    print("Falling back to method 2...")
    stations_df = get_all_stations_method2()

if stations_df is not None and not stations_df.empty:
    print(f"Successfully retrieved {len(stations_df)} stations")
    print("\nColumns in the DataFrame:")
    print(stations_df.columns.tolist())
    print("\nFirst 10 stations:")
    print(stations_df.head(10))

    # Check for duplicate column names
    if len(stations_df.columns) != len(set(stations_df.columns)):
        duplicate_cols = [col for col in stations_df.columns if list(stations_df.columns).count(col) > 1]
        print(f"Warning: Found duplicate columns: {duplicate_cols}")

    # Save to Parquet in the specified directory
    # NOTE(review): machine-specific absolute path - adjust before running elsewhere.
    output_dir = '/Users/maks/Documents/GitHub/aba_flooding/dmi_data_daily'

    # Ensure the directory exists
    os.makedirs(output_dir, exist_ok=True)

    output_path = os.path.join(output_dir, 'dmi_stations.parquet')

    try:
        stations_df.to_parquet(output_path, index=False)
        print(f"\nSaved station information to '{output_path}'")

    except Exception as e:
        print(f"Error saving to Parquet: {e}")
        print("Detailed column information for debugging:")
        # Pair labels with dtypes positionally: stations_df[col] would return a
        # DataFrame (which has no .dtype) when a column label is duplicated.
        for i, (col, dtype) in enumerate(zip(stations_df.columns, stations_df.dtypes)):
            print(f"{i}: {col} - type: {dtype}")

        print("\nPlease fix the column issues and try again.")

    # Create a map visualization if coordinates are available
    if 'longitude' in stations_df.columns and 'latitude' in stations_df.columns:
        # Filter out rows with missing coordinates
        map_df = stations_df.dropna(subset=['longitude', 'latitude'])

        if len(map_df) > 0:
            # Center the map on the mean station position
            center_lat = map_df['latitude'].mean()
            center_lon = map_df['longitude'].mean()

            m = folium.Map(location=[center_lat, center_lon], zoom_start=7)

            # Add one marker per station
            for _, row in map_df.iterrows():
                # Popup with station info; every value is escaped so that
                # characters such as '<' or '&' cannot break the markup.
                popup_html = f"""
                <b>Station ID:</b> {escape(str(row.get('stationId', 'N/A')))}<br>
                """

                # Add name if available
                if 'name' in row and pd.notna(row['name']):
                    popup_html += f"<b>Name:</b> {escape(str(row['name']))}<br>"

                # Add station type if available
                if 'station_type' in row and pd.notna(row['station_type']):
                    popup_html += f"<b>Type:</b> {escape(str(row['station_type']))}<br>"

                folium.Marker(
                    location=[row['latitude'], row['longitude']],
                    popup=folium.Popup(popup_html, max_width=300),
                    icon=folium.Icon(color='blue')
                ).add_to(m)

            # Save the map to the same directory
            map_path = os.path.join(output_dir, 'dmi_stations_map.html')
            m.save(map_path)
            print(f"Created map visualization in '{map_path}'")
else:
    print("Failed to retrieve station information using both methods.")
          



#=======================================================================================================#
# This script extracts station information from the DMI API and saves it to a Parquet file. It also     #
# creates a map visualization of the stations using Folium. The script first attempts to access the     #
# stations endpoint directly. If that fails, it falls back to extracting unique station IDs from        #
# observation data. The script handles potential issues with duplicate columns and missing coordinates. #
# It also includes error handling for API requests and file saving. The map visualization is saved as   #
# an HTML file in the specified directory. The script is designed to be run in a Python environment     #
# with the necessary libraries installed.                                                               #
#=======================================================================================================#



import requests
import pandas as pd
import folium
import os

# DMI API key. Set the DMI_API_KEY environment variable to supply your own key;
# the literal below is kept only as a backward-compatible fallback.
# NOTE(review): a credential is committed to source control here - rotate this key
# and drop the fallback once every user provides DMI_API_KEY.
api_key = os.environ.get('DMI_API_KEY') or 'd111ba1d-a1f5-43a5-98c6-347e9c2729b2'

# Method 1: Try to access stations endpoint directly
stations_url = 'https://dmigw.govcloud.dk/v2/metObs/collections/station/items'

def get_all_stations_method1():
    """
    Attempt to get all stations using the dedicated station endpoint.

    Sends one GET request to ``stations_url`` (passing ``api_key`` as the
    ``api-key`` query parameter) and flattens every feature of the returned
    JSON into one row.

    Returns:
        pandas.DataFrame | None: One row per feature with ``id``,
        ``feature_type``, ``longitude``/``latitude`` and ``geometry_type``
        (when the feature carries them) plus every key found under
        ``properties`` (``type`` is stored as ``station_type``). ``None`` is
        returned when the status code is not 200 or any error occurs.
    """
    try:
        # A timeout keeps the script from hanging forever on a stalled connection;
        # a timeout error is reported by the handler below like any other failure.
        response = requests.get(stations_url, params={'api-key': api_key}, timeout=60)
        if response.status_code != 200:
            print(f"Failed to access station endpoint. Status code: {response.status_code}")
            return None

        features = response.json().get('features') or []

        # Print the first feature to debug the structure
        if features:
            print("First station feature structure:")
            print(features[0])

        stations_list = []
        for feature in features:
            station = {
                'id': feature.get('id'),
                'feature_type': feature.get('type'),
            }

            # GeoJSON allows "geometry": null, so treat a missing and a null
            # geometry alike instead of failing the whole request on one feature.
            geometry = feature.get('geometry') or {}
            coords = geometry.get('coordinates')
            if coords and len(coords) >= 2:
                station['longitude'] = coords[0]
                station['latitude'] = coords[1]
            if 'type' in geometry:
                station['geometry_type'] = geometry['type']

            for key, value in (feature.get('properties') or {}).items():
                # Store the property 'type' under its own name so it is not
                # confused with the feature type or the geometry type.
                station['station_type' if key == 'type' else key] = value

            stations_list.append(station)

        return pd.DataFrame(stations_list)
    except Exception as e:
        print(f"Error accessing station endpoint: {e}")
        return None

# Method 2: Extract unique stations from observation data
def get_all_stations_method2():
    """
    Get all stations by extracting unique station IDs from observation data.

    Requests up to 300000 observation records in a single call and keeps the
    first record seen for every ``stationId``.

    Returns:
        pandas.DataFrame | None: One row per distinct station with
        ``stationId`` and, when the observation carries coordinates,
        ``longitude`` and ``latitude``. ``None`` is returned when the status
        code is not 200 or any error occurs.
    """
    dmi_url = 'https://dmigw.govcloud.dk/v2/metObs/collections/observation/items'

    try:
        # Request with a high limit to get as many records as possible
        params = {
            'api-key': api_key,
            'limit': '300000'  # Maximum allowed limit
        }

        # A timeout keeps the script from hanging forever on a stalled connection;
        # it is generous because the response can be large.
        response = requests.get(dmi_url, params=params, timeout=300)
        if response.status_code != 200:
            print(f"Failed to retrieve data. Status code: {response.status_code}")
            return None

        json_data = response.json()

        # Insertion-ordered mapping: the first observation of each station wins
        stations_by_id = {}
        for feature in json_data.get('features') or []:
            # "properties" / "geometry" may be missing or null; neither case
            # should abort the extraction of the remaining stations.
            properties = feature.get('properties') or {}
            if 'stationId' not in properties:
                continue

            station_id = properties['stationId']
            if station_id in stations_by_id:
                continue

            station = {'stationId': station_id}
            coords = (feature.get('geometry') or {}).get('coordinates')
            if coords and len(coords) >= 2:
                station['longitude'] = coords[0]
                station['latitude'] = coords[1]

            stations_by_id[station_id] = station

        return pd.DataFrame(list(stations_by_id.values()))

    except Exception as e:
        print(f"Error extracting stations from observation data: {e}")
        return None

# Try both methods and use the one that works
from html import escape  # popup text comes from the API and must be HTML-escaped

print("Attempting to get stations using Method 1...")
stations_df = get_all_stations_method1()
if stations_df is None or stations_df.empty:
    print("Falling back to method 2...")
    stations_df = get_all_stations_method2()

if stations_df is not None and not stations_df.empty:
    print(f"Successfully retrieved {len(stations_df)} stations")
    print("\nColumns in the DataFrame:")
    print(stations_df.columns.tolist())
    print("\nFirst 10 stations:")
    print(stations_df.head(10))

    # Check for duplicate column names
    if len(stations_df.columns) != len(set(stations_df.columns)):
        duplicate_cols = [col for col in stations_df.columns if list(stations_df.columns).count(col) > 1]
        print(f"Warning: Found duplicate columns: {duplicate_cols}")

    # Save to Parquet in the specified directory
    # NOTE(review): machine-specific absolute path - adjust before running elsewhere.
    output_dir = '/Users/maks/Documents/GitHub/aba_flooding/dmi_data_daily'

    # Ensure the directory exists
    os.makedirs(output_dir, exist_ok=True)

    output_path = os.path.join(output_dir, 'dmi_stations.parquet')

    try:
        stations_df.to_parquet(output_path, index=False)
        print(f"\nSaved station information to '{output_path}'")

    except Exception as e:
        print(f"Error saving to Parquet: {e}")
        print("Detailed column information for debugging:")
        # Pair labels with dtypes positionally: stations_df[col] would return a
        # DataFrame (which has no .dtype) when a column label is duplicated.
        for i, (col, dtype) in enumerate(zip(stations_df.columns, stations_df.dtypes)):
            print(f"{i}: {col} - type: {dtype}")

        print("\nPlease fix the column issues and try again.")

    # Create a map visualization if coordinates are available
    if 'longitude' in stations_df.columns and 'latitude' in stations_df.columns:
        # Filter out rows with missing coordinates
        map_df = stations_df.dropna(subset=['longitude', 'latitude'])

        if len(map_df) > 0:
            # Center the map on the mean station position
            center_lat = map_df['latitude'].mean()
            center_lon = map_df['longitude'].mean()

            m = folium.Map(location=[center_lat, center_lon], zoom_start=7)

            # Add one marker per station
            for _, row in map_df.iterrows():
                # Popup with station info; every value is escaped so that
                # characters such as '<' or '&' cannot break the markup.
                popup_html = f"""
                <b>Station ID:</b> {escape(str(row.get('stationId', 'N/A')))}<br>
                """

                # Add name if available
                if 'name' in row and pd.notna(row['name']):
                    popup_html += f"<b>Name:</b> {escape(str(row['name']))}<br>"

                # Add station type if available
                if 'station_type' in row and pd.notna(row['station_type']):
                    popup_html += f"<b>Type:</b> {escape(str(row['station_type']))}<br>"

                folium.Marker(
                    location=[row['latitude'], row['longitude']],
                    popup=folium.Popup(popup_html, max_width=300),
                    icon=folium.Icon(color='blue')
                ).add_to(m)

            # Save the map to the same directory
            map_path = os.path.join(output_dir, 'dmi_stations_map.html')
            m.save(map_path)
            print(f"Created map visualization in '{map_path}'")
else:
    print("Failed to retrieve station information using both methods.")



 
#=================================================================================================#
# This script processes JSON files containing precipitation data, extracts relevant information,  #
# and saves the data into a Parquet file. It also logs the processing steps and statistics.       #
#=================================================================================================#


import os
import json
import pandas as pd
from datetime import datetime
import glob
import logging
from tqdm import tqdm

# Set up logging
# Records at INFO level and above are written both to the file
# "precipitation_processing.log" and to a stream handler, each formatted as
# "<time> - <level> - <message>".
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("precipitation_processing.log"),
        logging.StreamHandler()
    ]
)
# Module-level logger used by the code below
logger = logging.getLogger(__name__)

def process_json_files(folder_path):
    """Combine the observation JSON files of a folder into one table.

    Every ``*.json`` file directly inside ``folder_path`` is expected to hold a
    top-level ``"data"`` list whose items carry ``properties.observed``
    (formatted ``YYYY-MM-DDTHH:MM:SSZ``), ``properties.stationId`` and
    ``properties.value``. A file that cannot be read or parsed is logged and
    skipped; records stored before the failure are kept.

    Args:
        folder_path: Directory containing the JSON files.

    Returns:
        pandas.DataFrame: Rows indexed by observation time (naive ``datetime``
        objects, sorted ascending), one column per station ID, cells holding
        the observed value. When the same timestamp/station pair occurs more
        than once, the value processed last wins.
    """
    # glob returns files in arbitrary, OS-dependent order; sorting makes both the
    # "last value wins" rule and the resulting column order reproducible.
    json_files = sorted(glob.glob(os.path.join(folder_path, "*.json")))
    logger.info(f"Found {len(json_files)} JSON files to process")

    # all_data[timestamp][station_id] = value
    all_data = {}
    total_records = 0

    # Process each file with a progress bar
    for file_path in tqdm(json_files, desc="Processing JSON files"):
        file_name = os.path.basename(file_path)
        logger.info(f"Processing file: {file_name}")

        file_records = 0
        try:
            # Read as UTF-8 explicitly rather than with the platform default encoding
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            for item in data['data']:
                properties = item['properties']
                station_id = properties['stationId']
                value = properties['value']
                timestamp = datetime.strptime(properties['observed'], "%Y-%m-%dT%H:%M:%SZ")

                all_data.setdefault(timestamp, {})[station_id] = value
                file_records += 1

            logger.info(f"Extracted {file_records} records from {file_name}")

        except Exception as e:
            logger.error(f"Error processing {file_name}: {str(e)}")
        finally:
            # Count records stored before a mid-file failure as well: they are
            # already part of all_data and therefore of the returned table.
            total_records += file_records

    logger.info(f"Total records processed: {total_records}")
    logger.info("Converting to DataFrame...")

    # One row per timestamp, one column per station
    df = pd.DataFrame.from_dict(all_data, orient='index')

    # Sort index (timestamps) chronologically
    df.sort_index(inplace=True)

    logger.info(f"DataFrame created with shape: {df.shape}")

    return df

if __name__ == "__main__":
    # Folder holding the raw JSON batches to convert
    folder_path = "/Users/maks/Documents/GitHub/aba_flooding/dmi_data_daily/parameter_batches/precip_past1h_test"

    logger.info("Starting precipitation data processing")

    # Build the timestamp-by-station table from every JSON file in the folder
    precipitation_df = process_json_files(folder_path)

    # Summary figures, gathered before saving and reported afterwards
    memory_usage_mb = precipitation_df.memory_usage(deep=True).sum() / (1024 * 1024)
    num_timestamps, num_stations = precipitation_df.shape

    # Persist the table next to the script
    logger.info("Saving data to Parquet file...")
    precipitation_df.to_parquet("precipitation_data_test.parquet")

    logger.info("Processing complete. Data saved to precipitation_data_test.parquet")
    logger.info(f"DataFrame shape: {precipitation_df.shape}")
    logger.info(f"Number of timestamps: {num_timestamps}")
    logger.info(f"Number of stations: {num_stations}")
    logger.info(f"DataFrame memory usage: {memory_usage_mb:.2f} MB")
    logger.info(f"Date range: {precipitation_df.index.min()} to {precipitation_df.index.max()}")

    # Show the first rows on the console as a quick sanity check
    print("\nSample of the data:")
    print(precipitation_df.head())

## Extracting sediment data ##

- Explain process of extracting sediment data via QGIS

## Preprocessing the data ##

In [None]:
import pandas as pd
import aba_flooding.perculation_mapping as pm
import aba_flooding.geo_utils as gu
# import perculation_mapping as pm
# import geo_utils as gu
from shapely.geometry import Point, Polygon
from scipy.spatial import Voronoi
import numpy as np
from shapely.geometry import Polygon
import geopandas as gpd
import os
import matplotlib.pyplot as plt
###########################################
# SECTION 1: GEOGRAPHIC DATA PROCESSING   #
###########################################

def voronoi_finite_polygons_2d(vor, radius=None):
    """Convert a 2D Voronoi diagram into finite polygons.

    Regions that extend to infinity are closed by adding, for every infinite
    ridge, a "far point" located ``radius`` away from the ridge's finite
    vertex along the ridge (perpendicular to the line joining the two input
    points the ridge separates, pointing away from the centre of the input).

    Args:
        vor: A ``scipy.spatial.Voronoi`` diagram built from 2D points.
        radius: Distance from a finite ridge vertex to its far point.
            Defaults to twice the largest coordinate extent of the input
            points.

    Returns:
        tuple: ``(regions, vertices)``. ``regions[i]`` is the list of vertex
        indices of the polygon of input point ``i``; reconstructed regions are
        ordered counterclockwise, already-finite regions are returned as
        stored in ``vor.regions``. A point without any ridge gets an empty
        list so that positions stay aligned with ``vor.points``.
        ``vertices`` holds the original Voronoi vertices followed by the
        added far points.

    Raises:
        ValueError: If the diagram is not two-dimensional.
    """
    if vor.points.shape[1] != 2:
        raise ValueError("Requires 2D input")

    new_regions = []
    new_vertices = vor.vertices.tolist()

    center = vor.points.mean(axis=0)
    radius = np.ptp(vor.points, axis=0).max() * 2 if radius is None else radius

    # Construct a map of all ridges for a given point
    all_ridges = {}
    for (p1, p2), (v1, v2) in zip(vor.ridge_points, vor.ridge_vertices):
        all_ridges.setdefault(p1, []).append((p2, v1, v2))
        all_ridges.setdefault(p2, []).append((p1, v1, v2))

    # Reconstruct infinite regions
    for p1, region in enumerate(vor.point_region):
        if p1 not in all_ridges:
            # Keep an empty placeholder instead of dropping the entry, so that
            # regions[i] still belongs to input point i for every later point.
            print(f"Skipping point {p1} which has no ridges")
            new_regions.append([])
            continue

        vertices = vor.regions[region]
        if all(v >= 0 for v in vertices):
            # Finite region
            new_regions.append(vertices)
            continue

        # Reconstruct a non-finite region
        ridges = all_ridges[p1]
        new_region = [v for v in vertices if v >= 0]

        for p2, v1, v2 in ridges:
            if v2 < 0:
                v1, v2 = v2, v1
            if v1 >= 0:
                # Finite ridge: both end points are already in the region
                continue

            # Infinite ridge: v2 is its only finite vertex
            t = vor.points[p2] - vor.points[p1]  # tangent
            t /= np.linalg.norm(t)
            n = np.array([-t[1], t[0]])  # normal

            midpoint = vor.points[[p1, p2]].mean(axis=0)
            direction = np.sign(np.dot(midpoint - center, n)) * n
            # The far point must start from the ridge's own finite vertex;
            # any other anchor leaves the ridge and distorts the cell.
            far_point = vor.vertices[v2] + direction * radius

            new_region.append(len(new_vertices))
            new_vertices.append(far_point.tolist())

        # Sort region counterclockwise
        vs = np.asarray([new_vertices[v] for v in new_region])
        c = vs.mean(axis=0)
        angles = np.arctan2(vs[:, 1] - c[1], vs[:, 0] - c[0])
        new_region = np.array(new_region)[np.argsort(angles)]

        new_regions.append(new_region.tolist())

    return new_regions, np.asarray(new_vertices)

def _save_avg_precipitation_histogram():
    """Print and plot the average precipitation per column of the imputed data.

    Reads 'data/raw/precipitation_imputed_data.parquet', clips every value to
    the range 0-100, averages each column and saves a histogram of those
    averages to 'outputs/plots/avg_precipitation_histogram.png'.
    """
    precipitation_data = pd.read_parquet('data/raw/precipitation_imputed_data.parquet')
    precipitation_data = precipitation_data.clip(lower=0, upper=100)

    # presumably one column per station ID holding mm values - verify against the file
    avg_prec = precipitation_data.mean(axis=0, skipna=True)
    print(f"Highest average precipitation: {avg_prec.max()}")
    print(f"Lowest average precipitation: {avg_prec.min()}")

    # Check the average precipitation values as a histogram
    plt.figure(figsize=(10, 6))
    plt.title("Average Hourly Precipitation Histogram")
    plt.xlabel("Average Hourly Precipitation (mm)")
    plt.ylabel("Frequency")
    plt.hist(avg_prec, bins=30, color='blue', alpha=0.7)
    plt.grid(axis='y', alpha=0.75)
    plt.tight_layout()
    # savefig does not create missing directories
    os.makedirs("outputs/plots", exist_ok=True)
    plt.savefig("outputs/plots/avg_precipitation_histogram.png")


def create_precipitation_coverage(denmark_gdf):
    """
    Create Voronoi polygons for precipitation stations that cover Denmark without overlap.

    Station coordinates are read from 'data/raw/dmi_stations.parquet',
    projected to EPSG:3857, turned into a Voronoi diagram and each station's
    cell is clipped to the union of the geometries in ``denmark_gdf``.
    Afterwards a histogram of average precipitation is saved as a diagnostic;
    a failure in that step is reported but does not discard the coverage.

    Args:
        denmark_gdf: GeoDataFrame whose geometries define the area to cover.
            NOTE(review): assumed to be in EPSG:3857 like the projected
            stations - it is not reprojected here; confirm at the caller.

    Returns:
        tuple: ``(coverage_gdf, stations_gdf)`` - the clipped coverage polygons
        with a ``station_id`` column, and the (de-duplicated) station points,
        both in EPSG:3857. ``(None, None)`` when building the coverage fails.

    Called by: create_full_coverage()
    Calls: voronoi_finite_polygons_2d()
    """
    try:
        # Load dmi station data - stations are rows with longitude and latitude columns
        print(f"Loading dmi station data from data/raw/dmi_stations.parquet...")
        station_data = pd.read_parquet('data/raw/dmi_stations.parquet')
        print(f"Loaded coordinate data of the station with {len(station_data)} stations")

        # Determine the station ID column
        if 'stationId' in station_data.columns:
            id_column = 'stationId'
        else:
            print("WARNING: No 'stationId' column found, using first column as ID")
            id_column = station_data.columns[0]  # Use the first column as ID

        # Find longitude and latitude columns
        lon_col = next((col for col in station_data.columns if col.lower() in ['longitude', 'lon', 'long']), None)
        lat_col = next((col for col in station_data.columns if col.lower() in ['latitude', 'lat']), None)

        if lon_col is None or lat_col is None:
            raise ValueError(f"Could not identify longitude and latitude columns. Available columns: {station_data.columns.tolist()}")
        print(f"Using column '{lon_col}' for longitude, '{lat_col}' for latitude, and '{id_column}' for station IDs")

        # Create GeoDataFrame for stations
        print("Creating GeoDataFrame for stations containing their coordinates")
        stations_gdf = gpd.GeoDataFrame(
            station_data,
            geometry=gpd.points_from_xy(station_data[lon_col], station_data[lat_col]),
            crs="EPSG:4326"  # WGS84, the longitude/latitude system used by GPS
        )

        # Reproject so the Voronoi diagram is computed on planar X/Y coordinates
        # in meters rather than on degrees of longitude/latitude.
        print("Reprojecting to Web Mercator EPSG:3857 for Voronoi diagram calculations")
        stations_gdf = stations_gdf.to_crs("EPSG:3857")  # Web Mercator

        # Create Voronoi diagram
        print("Creating Voronoi diagram...")
        coords = np.array([(p.x, p.y) for p in stations_gdf.geometry])
        print(f"Number of station coordinates: {len(coords)}")

        # Check for duplicate or very close points
        _, unique_indices = np.unique(np.round(coords, decimals=5), axis=0, return_index=True)
        if len(unique_indices) < len(coords):
            print(f"WARNING: Found {len(coords) - len(unique_indices)} potential duplicate stations. Using only unique locations.")
            keep = np.sort(unique_indices)
            coords = coords[keep]
            # Adjust stations_gdf to match unique points
            stations_gdf = stations_gdf.iloc[keep].copy()

        # Union of the boundary geometries, computed once and reused for the
        # bounding box and for clipping every Voronoi cell.
        denmark_shape = denmark_gdf.geometry.union_all()
        boundary = denmark_shape.bounds
        print(f"Denmark bounds:")
        print(f"\tSW corner city: {boundary[0]}, {boundary[1]}")
        print(f"\tNE corner city: {boundary[2]}, {boundary[3]}")

        boundary_width = boundary[2] - boundary[0]
        boundary_height = boundary[3] - boundary[1]

        # Add four helper points one boundary width/height outside the bounding
        # box to ensure complete coverage
        corner_points = [
            [boundary[0] - boundary_width, boundary[1] - boundary_height],
            [boundary[2] + boundary_width, boundary[1] - boundary_height],
            [boundary[0] - boundary_width, boundary[3] + boundary_height],
            [boundary[2] + boundary_width, boundary[3] + boundary_height]
        ]

        all_points = np.vstack([coords, corner_points])
        print(f"Total points for Voronoi (including corners): {len(all_points)}")

        try:
            vor = Voronoi(all_points)
            print(f"Voronoi diagram created with {len(vor.points)} points and {len(vor.vertices)} vertices")
        except Exception as vor_error:
            print(f"ERROR creating Voronoi diagram: {vor_error}")
            # Add jitter to points to avoid collinearity issues
            jitter = np.random.normal(0, 0.00001, all_points.shape)
            all_points = all_points + jitter
            print("Added small jitter to points to avoid numerical issues, retrying...")
            vor = Voronoi(all_points)

        # Get Voronoi polygons
        print("Converting Voronoi diagram to polygons...")
        regions, vertices = voronoi_finite_polygons_2d(vor)
        print(f"Created {len(regions)} Voronoi regions")

        # Voronoi cells can reach far beyond the area of interest, so each
        # station's cell is clipped to the Denmark boundary.
        print("Creating clipped polygons so that they are within the Denmark boundary...")
        voronoi_polygons = []
        valid_station_ids = []

        # Only the first len(coords) regions belong to stations; the rest
        # belong to the helper corner points and are skipped.
        for i, region in enumerate(regions[:len(coords)]):
            try:
                polygon = Polygon([vertices[v] for v in region])
                clipped_polygon = polygon.intersection(denmark_shape)
                if not clipped_polygon.is_empty:
                    voronoi_polygons.append(clipped_polygon)
                    valid_station_ids.append(stations_gdf.iloc[i][id_column])
            except Exception as poly_error:
                print(f"ERROR creating polygon for region {i}: {poly_error}")
                continue

        print(f"Created {len(voronoi_polygons)} valid polygons")

        # The clipped polygons are the coverage areas of the stations
        coverage_gdf = gpd.GeoDataFrame(
            {'station_id': valid_station_ids},
            geometry=voronoi_polygons,
            crs=stations_gdf.crs
        )

        # TODO: Add avg_precipitation data to coverage areas if available
        print("Adding avg preciptation to the station data of the polygons...")
        try:
            # Diagnostic only: its result is not attached to coverage_gdf, so a
            # missing data file or plot problem must not discard the coverage.
            _save_avg_precipitation_histogram()
        except Exception as stats_error:
            print(f"WARNING: could not create precipitation histogram: {stats_error}")

        # Return the coverage GeoDataFrame and stations GeoDataFrame
        # File saving is handled in create_full_coverage() to avoid duplication
        return coverage_gdf, stations_gdf

    except Exception as e:
        print(f"ERROR creating precipitation coverage: {e}")
        import traceback
        traceback.print_exc()
        return None, None

def create_full_coverage():
    """
    Create coverage areas for precipitation stations across Denmark.

    Builds a rectangular bounding box around Denmark (EPSG:4326, reprojected
    to EPSG:3857), passes it to create_precipitation_coverage() and tries to
    persist the resulting coverage polygons to
    data/raw/precipitation_coverage.geojson. A failure to save is reported on
    stdout but does not prevent the in-memory result from being returned.

    Returns:
    --------
    tuple: (GeoDataFrame, GeoDataFrame)
        coverage GeoDataFrame - as returned by create_precipitation_coverage()
        stations GeoDataFrame - as returned by create_precipitation_coverage()
        (None, None) when no valid coverage areas were created.

    Called by: main
    Calls: create_precipitation_coverage(), gu.gdf_to_geojson()
    """
    output_path = "data/raw/precipitation_coverage.geojson"

    # Create directories for output files
    os.makedirs("data/processed", exist_ok=True)
    os.makedirs("data/raw", exist_ok=True)

    # Create a simplified Denmark boundary manually
    print("Creating simplified Denmark boundary...")
    # Rough bounding box around Denmark in EPSG:4326 (WGS84), as (lon, lat)
    denmark_coords = [
        (8.0, 54.5),   # Southwest
        (8.0, 57.8),   # Northwest
        (13.0, 57.8),  # Northeast
        (13.0, 54.5),  # Southeast
        (8.0, 54.5),   # back to the start to close the polygon
    ]

    # Wrap the polygon in a GeoDataFrame and reproject it to EPSG:3857
    denmark_polygon_gdf = gpd.GeoDataFrame(
        {'name': ['Denmark']},
        geometry=[Polygon(denmark_coords)],
        crs="EPSG:4326"
    ).to_crs(epsg=3857)
    print("Using simplified Denmark boundary")

    # Create station coverage areas
    print("\nCreating station coverage areas")
    coverage_geojson_gdf, stations_gdf = create_precipitation_coverage(denmark_polygon_gdf)

    if coverage_geojson_gdf is None or coverage_geojson_gdf.empty:
        print("No valid coverage areas created. Skipping GeoJSON creation.")
        # Same arity as the success path: callers unpack exactly two values.
        return None, None

    print(f"Successfully created polygon coverage GeoDataFrame with {len(coverage_geojson_gdf)} polygons")

    # Save the GeoJSON file, falling back to other writers if the default one fails
    try:
        coverage_geojson_gdf.to_file(output_path, driver="GeoJSON")
        print("Saved coverage GeoDataFrame to data/raw/precipitation_coverage.geojson")
    except AttributeError as e:
        if "module 'pyogrio' has no attribute 'write_dataframe'" in str(e):
            print("ERROR saving 'precipitation_coverage.geojson' due to pyogrio error has no attribute 'write_dataframe'")
            try:
                # Second attempt: ask geopandas to use the fiona engine
                import fiona
                coverage_geojson_gdf.to_file(
                    output_path,
                    driver="GeoJSON",
                    engine="fiona"
                )
                print("Saved coverage GeoDataFrame using fiona engine")
            except Exception as fiona_error:
                print(f"ERROR Fiona method failed also: {fiona_error}")
                try:
                    # Last resort: serialise through gu.gdf_to_geojson and write the JSON ourselves
                    import json
                    geojson_dict = json.loads(gu.gdf_to_geojson(coverage_geojson_gdf))
                    with open(output_path, "w") as f:
                        json.dump(geojson_dict, f)
                    print("Saved coverage GeoDataFrame using manual JSON conversion")
                except Exception as json_error:
                    print(f"ERROR Manual JSON conversion failed: {json_error}")
                    print("ERROR: Could not save precipitation coverage file")
        else:
            print(f"ERROR Could not save precipitation coverage areas: {e}")
    except Exception as general_error:
        print(f"ERROR Could not save precipitation coverage areas: {general_error}")

    return coverage_geojson_gdf, stations_gdf

In [None]:
###########################################
# SECTION 2: SOIL AND SEDIMENT ANALYSIS   #
###########################################

def sediment_types_for_station(stationId, precipitationCoverageStations, sedimentCoverage):
    """
    Collect the distinct soil types whose sediment features intersect a station's coverage geometry.

    Parameters:
    -----------
    stationId : str
        ID of the station to look up
    precipitationCoverageStations : geoDataFrame
        Coverage polygons; the station is matched on the 'stationId' column
        when present, otherwise on 'station_id'
    sedimentCoverage : geoDataFrame
        Sediment features; reprojected to the coverage CRS when the two differ

    Returns:
    --------
    list : Unique values of the soil type column ('tsym', or the first column
           whose name contains 'type', 'sym' or 'soil') for the intersecting
           features. Empty list when the station is unknown, nothing
           intersects, or no soil type column can be identified.

    Called by: load_process_data()
    """
    coverage_crs = precipitationCoverageStations.crs
    sediment_crs = sedimentCoverage.crs

    # Debug information about the datasets
    print(f"Precipitation coverage CRS: {coverage_crs}")
    print(f"Sediment coverage CRS: {sediment_crs}")

    # Prefer 'stationId' when the coverage data has it, otherwise 'station_id'
    id_column = 'stationId' if 'stationId' in precipitationCoverageStations.columns else 'station_id'
    print(f"Using {id_column} to identify stations.")

    # Rows of the coverage data belonging to the requested station
    is_requested = precipitationCoverageStations[id_column] == stationId
    station_rows = precipitationCoverageStations[is_requested]
    if station_rows.empty:
        print(f"Warning: No station found with ID {stationId}")
        return []

    # Only the first matching row's geometry is used
    station_geometry = station_rows.geometry.iloc[0]
    print(f"Station geometry type: {station_geometry.geom_type}")
    print(f"Station geometry bounds: {station_geometry.bounds}")

    # Bring the sediment layer into the coverage CRS before any spatial test
    if coverage_crs != sediment_crs:
        sedimentCoverage = sedimentCoverage.to_crs(coverage_crs)

    # buffer(0) is used as the repair step for an invalid geometry
    if not station_geometry.is_valid:
        print("Station geometry is invalid! Attempting to fix...")
        station_geometry = station_geometry.buffer(0)

    # First attempt: a 1-unit buffer to absorb precision issues
    # (the original notes call this 1 meter, i.e. assumes a metric CRS)
    padded_geometry = station_geometry.buffer(1)
    hits = sedimentCoverage[sedimentCoverage.intersects(padded_geometry)]

    if not hits.empty:
        print(f"Found {len(hits)} sediment features using buffered geometry")
    else:
        # Second attempt: the unbuffered geometry
        hits = sedimentCoverage[sedimentCoverage.intersects(station_geometry)]

    if hits.empty:
        # Diagnostics only: look within 1000 units to tell a projection or
        # precision problem apart from a station outside the sediment layer
        buffer_distance = 1000
        wide_geometry = station_geometry.buffer(buffer_distance)
        nearby_sediments = sedimentCoverage[sedimentCoverage.intersects(wide_geometry)]

        if nearby_sediments.empty:
            print(f"No sediment features found even within {buffer_distance}m")
            print("The station might be outside the sediment coverage area")
        else:
            print(f"Found {len(nearby_sediments)} sediment features within {buffer_distance}m")
            print("The issue might be with projection or precision")

        # Show one sediment geometry's bounds for comparison with the station bounds
        if not sedimentCoverage.empty:
            print(f"Sample sediment bounds: {sedimentCoverage.iloc[0].geometry.bounds}")

        return []

    # Decide which column carries the soil type
    if 'tsym' in hits.columns:
        soil_type_column = 'tsym'
    else:
        print(f"Available columns in sediment data: {hits.columns.tolist()}")
        candidates = [
            col for col in hits.columns
            if any(tag in col.lower() for tag in ('type', 'sym', 'soil'))
        ]
        if not candidates:
            print("No suitable soil type column found")
            return []
        soil_type_column = candidates[0]
        print(f"Using '{soil_type_column}' instead of 'tsym' for soil types")

    soil_types = hits[soil_type_column].unique().tolist()
    print(f"Found soil types: {soil_types}")
    return soil_types

def gather_soil_types(purculation_mapping):
    """
    Create a dictionary of soil types with their average percolation rates.

    The average is the midpoint of each soil type's 'min' and 'max' rate.
    A 'min' of exactly 0 is replaced by 0.0001 and a 'max' of exactly 1 by
    0.9999 before averaging.

    Parameters:
    -----------
    purculation_mapping : dict
        Dictionary with soil types as keys and dicts holding 'min' and 'max'
        percolation rates as values

    Returns:
    --------
    dict
        Dictionary with soil types as keys and average percolation rates as values

    Called by: load_process_data()
    """
    def _midpoint(bounds):
        # Local names deliberately avoid shadowing the builtins min/max.
        lower = 0.0001 if bounds['min'] == 0 else bounds['min']
        upper = 0.9999 if bounds['max'] == 1 else bounds['max']
        return (lower + upper) / 2

    return {soil_type: _midpoint(bounds) for soil_type, bounds in purculation_mapping.items()}


In [None]:
###########################################
# SECTION 3: WATER CALCULATIONS           #
###########################################

def calculate_water_on_ground(df, soil_types, absorbtions, station, threshold=5):
    """
    Calculate water on ground for specific soil types and station.

    For every soil type with a known absorption rate, four columns are added:

    * ``{station}_WOG_{soil}``      - water on ground, from the recurrence
      ``wog[i] = max(0, wog[i-1] * (1 - rate) + precipitation[i])``
    * ``{station}_{soil}_observed`` - 1 where water on ground exceeds ``threshold``, else 0
    * ``{station}_{soil}_TTE``      - 0 at an event; between two events the number
      of rows until the next event; ``len(df)`` for rows before the first event
      and after the last one
    * ``{station}_{soil}_duration`` - counter starting at 1 on each event row and
      increasing until the next event; after the last event it restarts at 1 on
      the following row. Rows before the first event, and the last event row
      itself, keep the default ``len(df)``.

    Parameters:
    -----------
    df : pandas.DataFrame
        Dataframe containing precipitation data with a 'Nedbor' column
    soil_types : list
        List of soil types to calculate water on ground for
    absorbtions : dict
        Dictionary with soil types and their absorption rates
    station : str
        Station ID to use in column naming
    threshold : float, optional
        Water-on-ground level above which a row counts as an event (default 5)

    Returns:
    --------
    pandas.DataFrame: The input columns plus the new columns described above.
        A plain copy of ``df`` when none of ``soil_types`` has an absorption rate.

    Called by: load_process_data()
    """
    precip_array = df['Nedbor'].values
    n = len(precip_array)

    # Keep first-occurrence order and drop repeated soil types
    valid_soil_types = list(dict.fromkeys(st for st in soil_types if st in absorbtions))

    if not valid_soil_types:
        print(f"No valid soil types with known absorption rates for station {station}")
        return df.copy()

    new_columns = {}  # collected first so the result frame is built in one go

    for soil_type in valid_soil_types:
        rate = absorbtions[soil_type]

        # Water on ground: each step depends on the previous one, so this
        # recurrence stays an explicit loop.
        wog = np.zeros(n)
        if n > 0:  # an empty frame has no first time step
            wog[0] = max(0, precip_array[0])
        for i in range(1, n):
            wog[i] = max(0, wog[i-1] * (1 - rate) + precip_array[i])

        observed = (wog > threshold).astype(int)
        event_indices = np.flatnonzero(observed)

        # Defaults of n mark rows with no following event inside the data
        tte = np.full(n, n)
        durations = np.full(n, n)
        tte[event_indices] = 0

        # Fill the stretch between each pair of consecutive events
        for start_idx, end_idx in zip(event_indices[:-1], event_indices[1:]):
            time_between = end_idx - start_idx
            # Rows strictly between the events count down to the next event
            tte[start_idx+1:end_idx] = np.arange(time_between - 1, 0, -1)
            # Duration counts up from 1 starting on the event row itself
            durations[start_idx:end_idx] = np.arange(1, time_between + 1)

        # Rows after the last event count up from 1 with no closing event
        if len(event_indices) > 0:
            last_event = event_indices[-1]
            durations[last_event+1:] = np.arange(1, n - last_event)

        new_columns[f'{station}_WOG_{soil_type}'] = wog
        new_columns[f"{station}_{soil_type}_observed"] = observed
        new_columns[f'{station}_{soil_type}_TTE'] = tte
        new_columns[f'{station}_{soil_type}_duration'] = durations

    # Create new DataFrame with all columns at once and append it to the input
    new_df = pd.DataFrame(new_columns, index=df.index)
    return pd.concat([df, new_df], axis=1)

In [None]:
###########################################
# SECTION 4: DATA LOADING AND SAVING      #
###########################################

def load_process_data(coverage_data=None, sediment_data=None):
    """
    Run the per-station water-on-ground pipeline and save one result file per station.

    Reads the imputed precipitation table, clips it to [0, 100], and for each
    station column: drops NaN rows, looks up the station's sediment types,
    computes the water-on-ground columns and saves them via
    save_preprocessed_data() under data/processed/survival_data_<station>.csv.
    Stations without data or without sediment types are skipped; an error
    while processing one station is printed and the loop moves on.

    Parameters:
    -----------
    coverage_data : GeoDataFrame, optional
        Pre-loaded precipitation coverage data; loaded from
        precipitation_coverage.geojson when omitted
    sediment_data : GeoDataFrame, optional
        Pre-loaded sediment coverage data; loaded from
        Sediment_wgs84.geojson when omitted

    Returns:
    --------
    None when the run completes (results are written to disk, not returned);
    an empty pandas.DataFrame when an unexpected error aborts the run.

    Called by: main
    Calls: gather_soil_types(), sediment_types_for_station(), calculate_water_on_ground(), save_preprocessed_data()
    """
    try:
        print("\nLoading precipitation imputed data...")
        df = pd.read_parquet("data/raw/precipitation_imputed_data.parquet")
        print(f"This is a total number of columns: {len(df.columns)}, which is the number of stations")

        # CSV copy of the raw table for easier debugging
        # (original notes: one row per hour, one column per station)
        df.to_csv("data/raw/precipitation_imputed_data.csv", index=False)
        print(f"Precipitation data shape: {df.shape}")

        # NaN overview: grand total, then the worst stations as a percentage.
        # Dividing by len(df) works because every column has the same length.
        nan_counts = df.isna().sum()
        print(f"Total number of NaNs in the data: {nan_counts.sum()}")
        stations_with_most_nans = nan_counts.sort_values(ascending=False)
        nan_share = stations_with_most_nans / len(df) * 100
        print(f"Top 5 Stations with the most NaNs in procent of that station: {nan_share.head(5)}")

        # Limit values to the range [0, 100]; clip leaves NaNs untouched.
        # Sources the original authors collected for choosing the upper bound:
        #   https://international.kk.dk/sites/default/files/2021-09/Cloudburst%20Management%20plan%202010.pdf?utm_source=chatgpt.com
        #     - precipitation measured close to 100 mm in one hour
        #   https://web.archive.org/web/20140913151609/http://vejret.tv2.dk/artikel/id-32909558:et-af-de-kraftigste-regnvejr-nogensinde.html
        #     - over 100 mm in 24 hours and private measurements of 160 mm in 124 hours
        #   https://ui.adsabs.harvard.edu/abs/2021AGUFMGC45G0892C/abstract
        #     - between 90 and 135 mm recorded in less than 2 hours
        #   https://vejr.tv2.dk/2019-12-28-her-er-de-danske-vejrrekorder-fra-de-seneste-10-aar
        #     - record of 63 mm in 30 minutes
        #   https://vejr.tv2.dk/2016-07-02-husker-du-vejret-den-2-juli-2011-historisk-skybrud-ramte-koebenhavn
        #     - heavy rain: more than 24 mm within at most six hours;
        #       cloudburst: more than 15 mm within at most 30 minutes
        # The authors note the highest events are the 2011 and 2014 cloudbursts.
        df = df.clip(lower=0, upper=100)

        # Precipitation coverage: load from file unless the caller supplied it
        if coverage_data is None:
            print("Loading precipitation coverage stations...")
            precipitationCoverageStations = gu.load_geojson("precipitation_coverage.geojson")
            print(f"Loaded precipitation coverage with {len(precipitationCoverageStations)} stations")
        else:
            precipitationCoverageStations = coverage_data
            print(f"Using provided precipitation coverage with {len(precipitationCoverageStations)} stations")

        # Sediment coverage: same pattern
        if sediment_data is None:
            print("\nLoading sediment coverage...")
            sedimentCoverage = gu.load_geojson("Sediment_wgs84.geojson")
            print(f"Loaded sediment coverage with {len(sedimentCoverage)} features")
        else:
            sedimentCoverage = sediment_data
            print(f"Using provided sediment coverage with {len(sedimentCoverage)} features")

        # Average absorption rate per soil type
        absorbtions = gather_soil_types(pm.percolation_rates_updated)
        print(f"Gathered absorption rates for {len(absorbtions)} soil types")

        print(f"\nPrecipitaion dataset columns: {df.columns.tolist()}")

        # Every column of the precipitation table is treated as one station
        for station in df.columns:
            print(f"Processing station {station}...")

            # Single-column frame for this station, renamed to 'Nedbor'
            # (precipitation), the name calculate_water_on_ground() reads
            df_station = df[[station]].rename(columns={station: 'Nedbor'})

            rows_before = len(df_station)
            print(f"  • Precipitation data {station} length before dropping NaN: {rows_before}")

            # Only rows that actually carry a precipitation value are used
            df_station = df_station.dropna()
            removed = rows_before - len(df_station)
            print(f"  • Removed {removed} NaN values (procent {removed / rows_before * 100:.2f}%)")

            if df_station.empty:
                print(f"No data for station {station}, skipping...")
                continue

            sediment_types = sediment_types_for_station(station, precipitationCoverageStations, sedimentCoverage)
            if not sediment_types:
                print(f"No sediment types found for station {station}, skipping...")
                continue

            print(f"Found {len(sediment_types)} sediment types for station {station}")

            # A failure here is reported and the next station is processed
            try:
                df_processed = calculate_water_on_ground(df_station, sediment_types, absorbtions, station)
                # Only the derived columns are stored, not the raw 'Nedbor' column
                result_columns = df_processed.drop(columns=['Nedbor'], errors='ignore')
                if not result_columns.empty:
                    save_preprocessed_data(result_columns, f"data/processed/survival_data_{station}.csv")
                    print(f"Saved processed data for station {station}")
            except Exception as e:
                print(f"ERROR processing station {station}: {e}")

        # Results live on disk; None signals that the run completed
        return None

    except Exception as e:
        print(f"ERROR in load_process_data: {e}")
        import traceback
        traceback.print_exc()
        return pd.DataFrame()  # empty DataFrame signals an aborted run

def save_preprocessed_data(survival_df, output_path="data/processed/survival_data.csv"):
    """
    Save the processed survival data, preferring Parquet and falling back to CSV.

    The Parquet file is written next to ``output_path`` with every '.csv' in
    the path replaced by '.parquet'. The CSV at ``output_path`` is written only
    if the Parquet write raises.

    Parameters:
    -----------
    survival_df : DataFrame
        DataFrame containing survival data for different soil types and stations
    output_path : str
        CSV path for the output; may be a bare file name without a directory

    Returns:
    --------
    DataFrame or None
        ``survival_df`` when one of the two writes succeeded, None when both failed.

    Called by: load_process_data()
    """
    import os

    # Ensure the output directory exists. A bare file name has no directory
    # part, and os.makedirs('') would raise, so only create when there is one.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    parquet_path = output_path.replace('.csv', '.parquet')

    # Parquet first, for better performance
    try:
        survival_df.to_parquet(parquet_path, index=False)
    except Exception as parquet_error:
        print(f"ERROR saving to Parquet: {parquet_error}")
        print("Falling back to CSV format")
        try:
            survival_df.to_csv(output_path, index=False)
        except Exception as csv_error:
            print(f"ERROR saving to CSV: {csv_error}")
            print("Failed to save preprocessed data")
            return None
        print(f"Saved preprocessed survival data to {output_path}")
    else:
        print(f"Saved preprocessed survival data to {parquet_path}")

    return survival_df

def load_saved_data(file_path="data/processed/survival_data.csv"):
    """
    Load previously saved preprocessed data, preferring the Parquet copy.

    The Parquet path is ``file_path`` with '.csv' replaced by '.parquet'. If
    that file exists and loads, it is returned; if it is missing or fails to
    load, the CSV at ``file_path`` is read instead.

    Parameters:
    -----------
    file_path : str
        CSV path of the saved data file

    Returns:
    --------
    pandas.DataFrame or None
        The loaded data, or None when the CSV could not be read either

    Called by: main
    """
    import os

    parquet_path = file_path.replace('.csv', '.parquet')

    # Attempt 1: Parquet, only when the file is actually there
    try:
        if os.path.exists(parquet_path):
            loaded = pd.read_parquet(parquet_path)
            print(f"Loaded preprocessed data from {parquet_path}")
            return loaded
    except Exception as e:
        print(f"ERROR loading Parquet data: {e}")

    # Attempt 2: the CSV file itself
    try:
        return pd.read_csv(file_path)
    except Exception as e:
        print(f"ERROR loading data from {file_path}: {e}")
        return None


In [None]:

###########################################
# MAIN EXECUTION                          #
###########################################

if __name__ == "__main__":

    # Step 1: Create coverage areas for precipitation stations.
    # create_full_coverage() also saves the coverage to precipitation_coverage.geojson.
    # Only the first two entries are taken: the failure path of
    # create_full_coverage() has returned a longer all-None tuple, and a plain
    # two-name unpacking would raise before the fallback below could run.
    coverage_geojson_gdf, stations_gdf = create_full_coverage()[:2]
    if coverage_geojson_gdf is None:
        print("ERROR No valid coverage data created. Attemption to load from file.")
        try:
            coverage_geojson_gdf = gu.load_geojson("precipitation_coverage.geojson")
            print("Loaded coverage data from file.")
        except Exception as e:
            print(f"ERROR loading coverage data from file: {e}")
            print("Exiting.")
            # SystemExit is what exit() raises, without depending on the site helper
            raise SystemExit(1)

    # Step 2: Load sediment data - this was the layer that was exported from the QGIS project
    print("Loading Sediment_wgs84.geojson...")
    sedimentCoverage = gu.load_geojson("Sediment_wgs84.geojson")
    if sedimentCoverage is None:
        print("No valid sediment data loaded. Exiting.")
        raise SystemExit(1)

    # Step 3: Process precipitation and soil data
    # - this gathers the soil types
    # For each station, it calculates
    # * sediment types for the given station
    # * water on ground for each soil type
    # * saves the processed data to one file per station
    load_process_data(coverage_data=coverage_geojson_gdf, sediment_data=sedimentCoverage)

    # Step 4: TESTING - Load and display sample processed data for a specific station
    station_id = '06058'  # Example station ID
    df = load_saved_data(f'data/processed/survival_data_{station_id}.csv')
    if df is not None:
        print(f"\nSample data for station {station_id}:")
        print(df.head())
    else:
        print(f"No data found for station {station_id}")

## Modelling predictions of extreme events ##

### Time Series Forecasting ###

To generate predictions for future precipitation events, we need a model capable of forecasting precipitation based on historical data. Initially, we planned to use a traditional time series forecasting model. This approach can be extended to include deep learning methods by leveraging pre-built forecasting frameworks, which are easy to use and provide strong results without requiring extensive manual tuning of parameterized models.

After reviewing available tools, we chose Meta’s NeuralProphet. NeuralProphet is a hybrid forecasting model that combines deep learning techniques with classical time series components, such as ARIMA-like autoregression and exponential smoothing (ETS). The framework is built on PyTorch, is highly scalable, and is user-friendly. In our implementation, we apply NeuralProphet to our dataset, using monthly accumulated precipitation as the forecast target. This transformation helps illustrate both the model’s capabilities and its limitations when applied to our specific use case.

In [None]:
# Monthly precipitation forecast with NeuralProphet.
# Expects `df` (time-indexed precipitation table, one column per station) and
# `station` (column name) to exist from earlier cells, along with
# NeuralProphet, plt, mean_absolute_error and r2_score.
import warnings
warnings.filterwarnings('ignore')

if station not in df.columns:
    raise ValueError("Expected column station not found.")

# --- Build the monthly series: one row per month with the summed precipitation ---
df[station] = df[station].clip(lower=0)
df = df[[station]]
df = df.reset_index()  # assumes the index is unnamed, so it becomes the 'index' column
df['date'] = df['index'].dt.date
df = df.rename(columns={station: 'precipitation'})
df = df.drop(columns='index')
df = df.dropna(subset=['precipitation']).reset_index(drop=True)
df['date'] = pd.to_datetime(df['date'])
df['year_month'] = df['date'].dt.to_period('M')
df_monthly = df.groupby('year_month')['precipitation'].sum().reset_index()
# to_timestamp() stamps every month at its first day
df_monthly['year_month'] = df_monthly['year_month'].dt.to_timestamp()
df_prophet = df_monthly.rename(columns={'year_month': 'ds', 'precipitation': 'y'})

# --- Quick data exploration ---
# plt.figure(figsize=(10, 5))
# plt.plot(df_prophet['ds'], df_prophet['y'], marker='o')
# plt.title('Monthly Precipitation Over Time')
# plt.xlabel('Date')
# plt.ylabel('Precipitation (mm)')
# plt.grid(True)
# plt.show()

# --- Train-test split for evaluation (chronological 80/20) ---
split_idx = int(len(df_prophet) * 0.8)
df_train = df_prophet.iloc[:split_idx]
df_test = df_prophet.iloc[split_idx:]

# --- Initialize NeuralProphet model ---
m = NeuralProphet(
    yearly_seasonality=True,
    weekly_seasonality=False,
    daily_seasonality=False,
    #seasonality_mode='multiplicative',
    n_changepoints=20,
    changepoints_range=0.9,
    trend_reg=1,
    quantiles=[0.1, 0.9]
)

m = m.add_seasonality(name='monthly', period=30.5, fourier_order=5)

# --- Fit model on training set ---
# 'MS' (month start) matches the first-of-month timestamps built above;
# 'M' means month END and would put the generated future dates off the data grid.
metrics = m.fit(df_train, freq='MS')

# --- Forecast on train + test periods ---
future = m.make_future_dataframe(df_train, periods=len(df_test), n_historic_predictions=True)
forecast = m.predict(future)

# # --- Plot forecast including quantiles ---
# plt.figure(figsize=(12, 6))
# plt.plot(forecast['ds'], forecast['yhat1'], label='Prediction (Median)', color='blue')
# if 'yhat1_lower' in forecast.columns and 'yhat1_upper' in forecast.columns:
#     plt.fill_between(forecast['ds'], forecast['yhat1_lower'], forecast['yhat1_upper'], color='blue', alpha=0.3, label='80% Prediction Interval')
# plt.scatter(df_prophet['ds'], df_prophet['y'], color='black', s=10, label='Actual Data')
# plt.title('Monthly Precipitation Forecast with Uncertainty Interval')
# plt.xlabel('Date')
# plt.ylabel('Precipitation (mm)')
# plt.legend()
# plt.grid(True)
# plt.show()

# --- Plot model components (trend, seasonality, etc.) ---
fig_components = m.plot_components(forecast)

# --- Plot model parameters ---
fig_parameters = m.plot_parameters()

# --- Evaluation on test set ---
# The last len(df_test) forecast rows are the held-out months (matched by position)
forecast_test = forecast.iloc[-len(df_test):]
mae = mean_absolute_error(df_test['y'].values, forecast_test['yhat1'].values)
r2 = r2_score(df_test['y'].values, forecast_test['yhat1'].values)
print(f"Test MAE: {mae:.2f}")
print(f"Test R²: {r2:.2f}")

# # --- Plot actual vs predicted for test period ---
# plt.figure(figsize=(10, 5))
# plt.plot(df_test['ds'], df_test['y'], label='Actual', marker='o')
# plt.plot(df_test['ds'], forecast_test['yhat1'], label='Predicted', marker='x')
# if 'yhat1_lower' in forecast_test.columns and 'yhat1_upper' in forecast_test.columns:
#     plt.fill_between(df_test['ds'], forecast_test['yhat1_lower'], forecast_test['yhat1_upper'], color='blue', alpha=0.2, label='80% Interval')
# plt.title('Actual vs Predicted Precipitation (Test Set with Uncertainty)')
# plt.xlabel('Date')
# plt.ylabel('Precipitation (mm)')
# plt.legend()
# plt.grid(True)
# plt.show()

# --- Print final few metrics ---
print(metrics.tail())

# --- Forecast next 48 months into the future ---
# n_historic_predictions=True also returns the fitted values for the history
future_48 = m.make_future_dataframe(df_prophet, periods=48, n_historic_predictions=True)
forecast_48 = m.predict(future_48)
# Precipitation cannot be negative, so floor the forecast and its bounds at 0
forecast_48['yhat1'] = forecast_48['yhat1'].clip(lower=0)
if 'yhat1_lower' in forecast_48.columns:
    forecast_48['yhat1_lower'] = forecast_48['yhat1_lower'].clip(lower=0)
    forecast_48['yhat1_upper'] = forecast_48['yhat1_upper'].clip(lower=0)


# --- Plot historical + forecasted precipitation together ---
plt.figure(figsize=(14, 6))

# Plot historical observed data
plt.scatter(df_prophet['ds'], df_prophet['y'], color='black', s=10, label='Actual Data')

# Plot fitted values for the history followed by the 48 forecast months
plt.plot(forecast_48['ds'], forecast_48['yhat1'], label='Forecast (Median)', color='blue')

# Plot uncertainty if available
if 'yhat1_lower' in forecast_48.columns and 'yhat1_upper' in forecast_48.columns:
    plt.fill_between(forecast_48['ds'], forecast_48['yhat1_lower'], forecast_48['yhat1_upper'], color='blue', alpha=0.3, label='80% Prediction Interval')

plt.title('Historical and 48-Month Future Forecast of Precipitation')
plt.xlabel('Date')
plt.ylabel('Precipitation (mm)')
plt.legend()
plt.grid(True)
plt.show()



The limitations of this approach become immediately clear. While NeuralProphet provides a solid, smoothed forecast based on historical data, our objective requires a model that can respond to extreme precipitation events, such as those that could lead to flooding. NeuralProphet, and time series modelling in general, are not well suited to predict these outliers. These events are both rare and seemingly random, and time series models tend to optimize for overall accuracy, not the extremes.

To make a time series forecasting model useful for our purpose, we would essentially have to extend it with a term that forces the model to predict rare events. However, such a term would not produce realistic predictions unless it were driven by external variables. One option would be a multivariate forecasting model incorporating variables such as temperature, humidity, and wind. However, the additional data collection would have increased the complexity and workload of our project considerably, and was therefore deemed unrealistic.

As a result of this process, we switched methods and opted for 'Survival Analysis' to predict the probability of rare events occurring within a given timeframe. This method is specifically designed to predict such rare events, and the models can be built using the precipitation data available to us.

### Survival Analysis ###

In [None]:
class SurvivalModel:
    """Thin wrapper around a lifelines ``ExponentialFitter`` for one soil type.

    The wrapper keeps the fitted estimator together with a little metadata
    (``soil_type``, ``station``, ``units``) and exposes helpers to train,
    query, plot, save and load the model.
    """

    def __init__(self, soil_type='clay'):
        self.model = ExponentialFitter()
        self.soil_type = soil_type
        self.station = None  # Filled in by the caller when the model belongs to a station
        self.is_fitted = False
        self.units = 'hours'  # Unit the training durations are expressed in

    def train(self, df, duration_column='duration', event_column='observed'):
        """
        Train the survival model on dry spell durations.

        Parameters:
        -----------
        df : pandas.DataFrame
            Training data holding the duration and event columns
        duration_column : str
            Name of the column with the observed durations
        event_column : str
            Name of the column flagging whether the event was observed

        Returns:
        --------
        self : SurvivalModel
            Returns self for method chaining

        Raises:
        -------
        ValueError
            If ``df`` is None or has no rows
        """
        if df is None or len(df) == 0:
            raise ValueError("Training data is empty")

        self.model.fit(durations=df[duration_column], event_observed=df[event_column])
        self.is_fitted = True
        return self

    def predict_proba(self, durations):
        """
        Predict probability of rain occurring by the given duration.

        Parameters:
        -----------
        durations : int, float, array-like
            Number of time units to predict probability for

        Returns:
        --------
        array-like : Probability of rain occurring by the specified duration
            (one minus whatever ``self.model.predict`` returns)

        Raises:
        -------
        RuntimeError
            If the model has not been trained or loaded yet
        """
        if not self.is_fitted:
            raise RuntimeError("Model must be trained before making predictions")

        # Wrap any scalar (including NumPy scalar types such as np.int64,
        # which are not instances of the built-in int) into a one-item list.
        if np.ndim(durations) == 0:
            durations = [durations]

        # NOTE(review): this assumes ``predict`` yields the survival
        # probability for the fitter in use - confirm against the lifelines
        # documentation for ExponentialFitter.
        survival_probs = self.model.predict(durations)

        # Return probability of rain (1 - survival probability)
        return 1 - survival_probs

    def predict(self, year):
        """
        Predict the probability of rain occurring before a given year (int)

        The horizon is first translated into a duration in ``self.units``
        using the heuristic scaling below, then passed to ``predict_proba``;
        for horizons above one year the result is additionally dampened.

        Parameters:
        -----------
        year : int or float
            Number of years into the future

        Returns:
        --------
        float : Probability of rain occurring by the specified year

        Raises:
        -------
        RuntimeError
            If the model has not been trained or loaded yet
        """
        if not self.is_fitted:
            raise RuntimeError("Model must be trained before making predictions")

        # Largest point on the fitted model's timeline
        max_observed = self.model.timeline.max()

        # Translate the horizon into a duration expressed in self.units.
        # NOTE(review): for year > 1, log(year) > 0, so with a non-negative
        # max_observed the min() below always evaluates to max_observed; the
        # "logarithmic scale" therefore has no effect. Kept as-is to preserve
        # current outputs - confirm the intended formula.
        if self.units == 'hours':
            if year <= 0.1:  # Very short periods: scale down strongly
                duration = year * 365 * 24 * 0.1
            elif year <= 1:  # Up to a year: scale down a bit
                duration = year * 365 * 24 * 0.2
            else:  # Longer periods
                duration = min(max_observed * (1 + np.log(year)), max_observed)
        elif self.units == 'days':
            if year <= 1:
                duration = year * 365 * 0.5
            else:
                duration = min(max_observed * (1 + np.log(year)), max_observed)
        else:  # years or other units
            duration = min(year, max_observed)

        # predict_proba may hand back a scalar or a one-element container
        # depending on the estimator; normalise to a plain float either way.
        raw_prob = float(np.asarray(self.predict_proba(duration)).ravel()[0])

        # Dampen multi-year predictions so they approach but never reach 100%
        if year > 1:
            prob = raw_prob * (1 - 0.1 / year)
        else:
            prob = raw_prob

        return prob

    def plot(self):
        """
        Plot the fitted model and save the figure as
        ``"{soil_type}_survival_function.png"`` in the working directory.

        Raises:
        -------
        RuntimeError
            If the model has not been trained or loaded yet
        """
        if not self.is_fitted:
            raise RuntimeError("Model must be trained before plotting")

        # Draw on an explicitly created axes so the curve, labels and saved
        # file are guaranteed to belong to the same figure.
        fig, ax = plt.subplots(figsize=(10, 6))
        # NOTE(review): the curve drawn is the cumulative density while the
        # title, y-label and file name speak of the survival function -
        # confirm which of the two is intended.
        self.model.plot_cumulative_density(ax=ax)
        ax.set_title(f"Survival Function for {self.soil_type}")
        # Use the configured unit instead of assuming hours
        ax.set_xlabel(f"Duration ({self.units})")
        ax.set_ylabel("Survival Probability")
        ax.grid()
        fig.savefig(f"{self.soil_type}_survival_function.png")

    def save(self, path):
        """
        Save the fitted model and its metadata to disk as a pickle.

        Parameters:
        -----------
        path : str
            File path where the model should be saved

        Returns:
        --------
        self : SurvivalModel
            Returns self for method chaining

        Raises:
        -------
        RuntimeError
            If the model has not been trained or loaded yet
        """
        if not self.is_fitted:
            raise RuntimeError("Model must be trained before saving")

        # Create the target directory if needed. A bare file name has an
        # empty dirname, and os.makedirs('') would raise FileNotFoundError.
        target_dir = os.path.dirname(path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)

        # Save model and metadata
        model_data = {
            'model': self.model,
            'soil_type': self.soil_type,
            'station': self.station,
            'units': self.units,
            'is_fitted': self.is_fitted
        }
        pd.to_pickle(model_data, path)
        return self

    def load(self, path):
        """
        Load a model previously written by ``save`` from disk.

        Parameters:
        -----------
        path : str
            File path to the saved model

        Returns:
        --------
        self : SurvivalModel
            Returns self with loaded model
        """
        # SECURITY: unpickling executes arbitrary code embedded in the file;
        # only load model files that come from a trusted source.
        model_data = pd.read_pickle(path)

        # Restore model attributes
        self.model = model_data['model']
        self.soil_type = model_data['soil_type']
        self.station = model_data['station']
        self.is_fitted = model_data['is_fitted']
        self.units = model_data['units']

        return self

class FloodModel:
    """Collection of ``SurvivalModel`` instances, one per station and soil type.

    Models live in ``self.models`` under keys of the form
    ``"{station}_{soil_type}"``. The collection can be written to disk either
    as a single file or as one metadata file plus one file per station; the
    latter layout can be loaded lazily, reading a station file only when that
    station is first needed.
    """

    def __init__(self):
        self.models = {}
        self.is_fitted = False
        self.units = 'hours'
        self.soil_types = ["DG - Meltwater gravel", "DS - Meltwater sand"]
        self.stations = []
        self.available_soil_types = []
        # Lazy-loading bookkeeping: station -> file path, and the stations
        # whose file has already been read into self.models.
        self._station_paths = {}
        self._loaded_stations = set()

    def _loaded_station_ids(self):
        """Return the set of stations already read from disk, creating it if absent."""
        # Instances built without __init__ may lack the attribute.
        if not hasattr(self, '_loaded_stations'):
            self._loaded_stations = set()
        return self._loaded_stations

    @staticmethod
    def _station_of(model_key, model):
        """Return the station a model belongs to, as used for grouping on save."""
        # Prefer the station recorded on the model when it agrees with the
        # key; this keeps station ids that contain '_' in one group.
        station = getattr(model, 'station', None)
        if station is not None and model_key.startswith(f"{station}_"):
            return str(station)
        return model_key.split('_')[0]

    def _load_pending_station(self, station):
        """Read a lazily-referenced station file into ``self.models`` if still pending."""
        pending = getattr(self, '_station_paths', None)
        if not pending:
            return
        already_loaded = self._loaded_station_ids()
        # Station ids coming from a GeoDataFrame may be non-string values
        # while the stored path keys are strings, so try both spellings.
        for candidate in (station, str(station)):
            if candidate in pending and candidate not in already_loaded:
                self.get_station_models(candidate)
                return

    def add_station(self, station, survival_df, soiltypes):
        """
        Add station data to the flood model and train survival models for each soil type.

        For every soil type the columns ``"{station}_{soil_type}_duration"``
        and ``"{station}_{soil_type}_observed"`` are looked up in
        ``survival_df``; combinations with missing columns or no non-null
        rows are reported and skipped.

        Parameters:
        -----------
        station : str
            Station identifier
        survival_df : pandas.DataFrame
            DataFrame containing survival data for the station
        soiltypes : list
            List of soil types to train models for this station

        Returns:
        --------
        self : FloodModel
            Returns self for method chaining
        """
        if station not in self.stations:
            self.stations.append(station)

        for soil_type in soiltypes:
            # Column names follow the "{station}_{soil_type}_..." pattern
            duration_column = f"{station}_{soil_type}_duration"
            event_column = f"{station}_{soil_type}_observed"

            columns_present = (
                duration_column in survival_df.columns
                and event_column in survival_df.columns
            )
            if not columns_present:
                print(f"Missing columns for station {station}, soil type {soil_type}")
                continue

            # Drop rows where either value is missing
            valid_data = survival_df[[duration_column, event_column]].dropna()
            if len(valid_data) == 0:
                print(f"No valid data for station {station}, soil type {soil_type}")
                continue

            # Train one model for this station-soil combination
            model = SurvivalModel(soil_type=soil_type)
            model.station = station
            model.train(
                valid_data.rename(columns={
                    duration_column: 'duration',
                    event_column: 'observed'
                }),
                'duration',
                'observed'
            )
            self.models[f"{station}_{soil_type}"] = model

            if soil_type not in self.available_soil_types:
                self.available_soil_types.append(soil_type)

            print(f"Trained model for station {station}, soil type {soil_type} with {len(valid_data)} observations")

        # Mark as fitted if we have any models
        if self.models:
            self.is_fitted = True

        return self

    def save(self, path, split_by_station=True):
        """
        Save the FloodModel to disk.

        Parameters:
        -----------
        path : str
            File path where the model (or, for a split save, its metadata)
            should be saved
        split_by_station : bool
            If True, save each station's models in a separate file inside a
            ``"<basename>_stations"`` directory next to ``path``

        Returns:
        --------
        self : FloodModel
            Returns self for method chaining

        Raises:
        -------
        RuntimeError
            If the model has not been trained or loaded yet
        """
        if not self.is_fitted:
            raise RuntimeError("Model must be trained before saving")

        # Create the target directory if needed. A bare file name has an
        # empty dirname, and os.makedirs('') would raise FileNotFoundError.
        base_dir = os.path.dirname(path)
        if base_dir:
            os.makedirs(base_dir, exist_ok=True)

        if not (split_by_station and len(self.models) > 0):
            # Traditional single-file save
            print(f"Starting to save {len(self.models)} models to {path}...")
            model_data = {
                'models': self.models,
                'is_fitted': self.is_fitted,
                'units': self.units,
                'soil_types': self.soil_types,
                'stations': self.stations,
                'available_soil_types': self.available_soil_types
            }
            joblib.dump(model_data, path, compress=3)
            print(f"Successfully saved model to {path}")
            return self

        # Split storage: one file per station next to the metadata file
        base_name = os.path.splitext(os.path.basename(path))[0]
        stations_dir = os.path.join(base_dir, f"{base_name}_stations")
        os.makedirs(stations_dir, exist_ok=True)

        print(f"Starting split save: {len(self.stations)} stations with {len(self.models)} models...")

        # Group models by station
        station_models = {}
        for model_key, model in self.models.items():
            station = self._station_of(model_key, model)
            station_models.setdefault(station, {})[model_key] = model

        # Metadata that references the per-station files
        meta_model = {
            'is_fitted': True,
            'units': self.units,
            'soil_types': self.soil_types,
            'stations': self.stations,
            'available_soil_types': self.available_soil_types,
            'station_paths': {}  # station -> path relative to the metadata file
        }

        for saved_files, (station, models) in enumerate(station_models.items()):
            station_path = os.path.join(stations_dir, f"station_{station}.joblib")

            # Display progress every 10 stations
            if saved_files % 10 == 0:
                print(f"Saving station {saved_files}/{len(station_models)}: {station} with {len(models)} models...")

            joblib.dump(models, station_path, compress=3)

            # Relative paths keep the saved model relocatable
            meta_model['station_paths'][station] = os.path.relpath(
                station_path, base_dir or os.curdir
            )

        saved_files = len(station_models)
        print(f"Saving metadata to {path}...")
        joblib.dump(meta_model, path, compress=3)
        print(f"Successfully saved {saved_files} station files and metadata")

        return self

    def load(self, path, lazy_load=True):
        """
        Load a saved FloodModel from disk.

        Parameters:
        -----------
        path : str
            File path to the saved model
        lazy_load : bool
            If True and model was saved with split_by_station=True,
            only load station models when requested

        Returns:
        --------
        self : FloodModel
            Returns self with loaded models
        """
        print(f"Loading model from {path}...")

        # SECURITY: joblib files are pickles; only load files from a trusted source.
        model_data = joblib.load(path)

        is_split = isinstance(model_data, dict) and 'station_paths' in model_data

        # Metadata shared by both on-disk layouts
        self.is_fitted = model_data.get('is_fitted', False)
        self.units = model_data.get('units', 'hours')
        self.soil_types = model_data.get('soil_types', [])
        self.stations = model_data.get('stations', [])
        self.available_soil_types = model_data.get('available_soil_types', [])

        # Forget lazy-loading state left over from any earlier load
        self._station_paths = {}
        self._loaded_stations = set()

        if not is_split:
            # Traditional single-file model
            self.models = model_data.get('models', {})
        else:
            self.models = {}
            # Station paths are stored relative to the metadata file
            base_dir = os.path.dirname(path)
            station_paths = model_data['station_paths']

            if lazy_load:
                print(f"Lazy-loading enabled: Referenced {len(station_paths)} stations")
                # Remember where each station lives; files are read on demand
                self._station_paths = {
                    station: os.path.join(base_dir, rel_path)
                    for station, rel_path in station_paths.items()
                }
            else:
                total_stations = len(station_paths)
                print(f"Loading all {total_stations} station models...")

                for i, (station, rel_path) in enumerate(station_paths.items()):
                    station_path = os.path.join(base_dir, rel_path)
                    if i % 10 == 0:
                        print(f"Loading station {i+1}/{total_stations}: {station}...")

                    # Best effort: a broken station file must not abort the load
                    try:
                        self.models.update(joblib.load(station_path))
                    except Exception as e:
                        print(f"Error loading station {station}: {e}")

        print(f"Model loaded with {len(self.stations)} stations")
        return self

    def get_station_models(self, station):
        """
        Get all models for a specific station.
        Will load from disk if using lazy loading and the station's file has
        not been read yet; later calls are answered from memory.

        Parameters:
        -----------
        station : str
            Station identifier

        Returns:
        --------
        dict : Dictionary of models for the station (empty if loading failed)
        """
        pending = getattr(self, '_station_paths', None) or {}
        already_loaded = self._loaded_station_ids()

        if station in pending and station not in already_loaded:
            station_path = pending[station]
            print(f"Loading station {station} models from {station_path}...")

            try:
                station_models = joblib.load(station_path)
                self.models.update(station_models)
            except Exception as e:
                print(f"Error loading station {station}: {e}")
                return {}

            # Remember the station so its file is not read again
            already_loaded.add(station)
            return dict(station_models)

        # Not lazy loading, or already loaded: filter the in-memory models
        prefix = f"{station}_"
        return {k: v for k, v in self.models.items() if k.startswith(prefix)}

    def load_station(self, station, stations_dir):
        """
        Load models for a specific station from the stations directory.

        Parameters:
        -----------
        station : str
            Station identifier
        stations_dir : str
            Directory containing station model files

        Returns:
        --------
        dict : Dictionary of loaded models for this station (empty if no
            candidate file could be loaded)
        """
        # Candidate file names, tried in this order
        candidates = [
            os.path.join(stations_dir, f"{station}.joblib"),
            os.path.join(stations_dir, f"station_{station}.joblib")
        ]

        for file_path in candidates:
            if not os.path.exists(file_path):
                continue
            try:
                print(f"Loading station models from {file_path}")
                station_models = joblib.load(file_path)

                loaded_models = {}
                if isinstance(station_models, dict):
                    loaded_models.update(station_models)
                    self.models.update(station_models)
                    # Avoid a second disk read through the lazy-loading path
                    self._loaded_station_ids().add(station)

                return loaded_models
            except Exception as e:
                print(f"Error loading station file {file_path}: {e}")

        print(f"No valid model file found for station {station} in {stations_dir}")
        return {}

    def predict_proba(self, geodata, station_coverage, year):
        """
        Attach per-geometry predictions for ``year`` to a copy of ``geodata``.

        For every row, the stations whose coverage geometry intersects the
        row's geometry are looked up, the model keyed
        ``"{station}_{soil code}"`` is queried for each of them, and the
        predictions are averaged. Station files that were only referenced by
        a lazy ``load`` are read from disk the first time they are needed.

        Parameters:
        -----------
        geodata : geopandas.GeoDataFrame
            Geometries with a soil type column (``'sediment'``, or the first
            column whose name contains 'soil' or 'type')
        station_coverage : geopandas.GeoDataFrame
            Coverage geometries with a station id column (``'station_id'``,
            or the first column whose name contains 'station' and 'id')
        year : int or float
            Horizon passed to each model's ``predict``

        Returns:
        --------
        geopandas.GeoDataFrame : Copy of ``geodata`` with
            ``predictions_{year}`` (value * 100) and
            ``predictions_{year}_raw`` columns. The unmodified ``geodata`` is
            returned if a required column cannot be found.

        Raises:
        -------
        RuntimeError
            If the model has not been trained or loaded yet
        """
        if not self.is_fitted:
            raise RuntimeError("Model must be trained before making predictions")

        result_geodata = geodata.copy()
        column_name = f'predictions_{year}'

        # Locate the soil type column
        soil_type_col = 'sediment'
        if soil_type_col not in result_geodata.columns:
            possible_cols = [col for col in result_geodata.columns
                             if 'soil' in col.lower() or 'type' in col.lower()]
            if possible_cols:
                soil_type_col = possible_cols[0]
            else:
                print("Could not find soil type column in geodata")
                return geodata

        # Initialize prediction column
        result_geodata[column_name] = None

        # Ensure both GeoDataFrames have the same CRS before intersecting
        if result_geodata.crs != station_coverage.crs:
            print(f"Converting station_coverage from {station_coverage.crs} to {result_geodata.crs}")
            station_coverage = station_coverage.to_crs(result_geodata.crs)

        # Locate the station id column
        station_id_col = 'station_id'
        if station_id_col not in station_coverage.columns:
            possible_cols = [col for col in station_coverage.columns
                             if 'station' in col.lower() and 'id' in col.lower()]
            if possible_cols:
                station_id_col = possible_cols[0]
            else:
                print("Could not find station ID column in station_coverage")
                return geodata

        # Stations for which a lazy disk load was already attempted in this
        # call, so a missing or broken file is not retried for every row.
        lazy_attempted = set()

        for idx, row in result_geodata.iterrows():
            soil_type = row[soil_type_col]

            # Coverage areas this geometry intersects with
            intersecting_stations = station_coverage[station_coverage.intersects(row.geometry)]

            predictions = []
            if not intersecting_stations.empty:
                # Keep only the leading token of a compound soil description
                if isinstance(soil_type, str) and ' ' in soil_type:
                    simple_type = soil_type.split(' ')[0]
                else:
                    simple_type = soil_type

                for _, station_row in intersecting_stations.iterrows():
                    station = station_row[station_id_col]
                    model_key = f"{station}_{simple_type}"

                    # After a lazy load self.models starts out empty; pull
                    # the station's file in before concluding it has no model.
                    if model_key not in self.models and station not in lazy_attempted:
                        lazy_attempted.add(station)
                        self._load_pending_station(station)

                    model = self.models.get(model_key)
                    if model is None:
                        continue
                    try:
                        predictions.append(model.predict(year))
                    except Exception as e:
                        print(f"Error predicting for {model_key}: {e}")

            if predictions:
                # Average over all stations that produced a prediction
                value = sum(predictions) / len(predictions)
            else:
                # No intersecting station, or no model for this soil type.
                # NOTE(review): the 0.0 cap makes this 0.0 for any
                # year >= -4, so the 0.2 + 0.05 * year term never shows.
                # Kept to preserve current outputs - confirm whether a cap
                # of 1.0 was intended.
                value = min(0.2 + 0.05 * year, 0.0)
            result_geodata.at[idx, column_name] = value

        # Store raw probability values before percentage conversion
        result_geodata[f'{column_name}_raw'] = result_geodata[column_name].copy()

        # Convert to percentage for visualization
        result_geodata[column_name] = result_geodata[column_name] * 100

        return result_geodata

### Execution of the model ###

In [None]:
class FloodModel:
    def __init__(self):
        self.models = {}
        self.is_fitted = False
        self.units = 'hours'
        self.soil_types = ["DG - Meltwater gravel", "DS - Meltwater sand"]
        self.stations = []
        self.available_soil_types = []
    
    def add_station(self, station, survival_df, soiltypes):
        """
        Add station data to the flood model and train survival models for each soil type.
        
        Parameters:
        -----------
        station : str
            Station identifier
        survival_df : pandas.DataFrame
            DataFrame containing survival data for the station
        soiltypes : list
            List of soil types to train models for this station
            
        Returns:
        --------
        self : FloodModel
            Returns self for method chaining
        """
        if station not in self.stations:
            self.stations.append(station)
            
        # Create models for each soil type in this station
        for soil_type in soiltypes:
            # Create column names based on pattern in the dataframe
            duration_column = f"{station}_{soil_type}_duration"
            event_column = f"{station}_{soil_type}_observed"
            
            # Check if needed columns exist
            if duration_column in survival_df.columns and event_column in survival_df.columns:
                # Filter out any missing values
                valid_data = survival_df[[duration_column, event_column]].dropna()
                
                if len(valid_data) > 0:
                    # Create a model for this station-soil combination
                    model_key = f"{station}_{soil_type}"
                    
                    # Create and train the model
                    model = SurvivalModel(soil_type=soil_type)
                    model.station = station
                    model.train(
                        valid_data.rename(columns={
                            duration_column: 'duration',
                            event_column: 'observed'
                        }),
                        'duration', 
                        'observed'
                    )
                    
                    # Add to our models dictionary
                    self.models[model_key] = model
                    
                    # Add to available soil types if not already there
                    if soil_type not in self.available_soil_types:
                        self.available_soil_types.append(soil_type)
                    
                    print(f"Trained model for station {station}, soil type {soil_type} with {len(valid_data)} observations")
                else:
                    print(f"No valid data for station {station}, soil type {soil_type}")
            else:
                print(f"Missing columns for station {station}, soil type {soil_type}")
        
        # Mark as fitted if we have any models
        if self.models:
            self.is_fitted = True
            
        return self
    
    def save(self, path, split_by_station=True):
        """
        Save the FloodModel to disk.
        
        Parameters:
        -----------
        path : str
            File path where the model should be saved
        split_by_station : bool
            If True, save each station's models in a separate file
            
        Returns:
        --------
        self : FloodModel
            Returns self for method chaining
        """
        if not self.is_fitted:
            raise RuntimeError("Model must be trained before saving")
        
        # Create directory if it doesn't exist
        os.makedirs(os.path.dirname(path), exist_ok=True)
        
        # Split storage by station
        if split_by_station and len(self.models) > 0:
            # Extract base directory and filename without extension
            base_dir = os.path.dirname(path)
            base_name = os.path.splitext(os.path.basename(path))[0]
            
            # Create stations directory
            stations_dir = os.path.join(base_dir, f"{base_name}_stations")
            os.makedirs(stations_dir, exist_ok=True)
            
            print(f"Starting split save: {len(self.stations)} stations with {len(self.models)} models...")
            
            # Group models by station
            station_models = {}
            for model_key, model in self.models.items():
                station = model_key.split('_')[0]
                if station not in station_models:
                    station_models[station] = {}
                station_models[station][model_key] = model
            
            # Create a metadata model that references station files
            meta_model = {
                'is_fitted': True,
                'units': self.units,
                'soil_types': self.soil_types,
                'stations': self.stations,
                'available_soil_types': self.available_soil_types,
                'station_paths': {}  # Will store paths to station model files
            }
            
            # Save each station separately
            saved_files = 0
            for station, models in station_models.items():
                station_path = os.path.join(stations_dir, f"station_{station}.joblib")
                
                # Display progress every 10 stations
                if saved_files % 10 == 0:
                    print(f"Saving station {saved_files}/{len(station_models)}: {station} with {len(models)} models...")
                    
                # Save station models with joblib
                joblib.dump(models, station_path, compress=3)
                
                # Store the relative path in metadata
                meta_model['station_paths'][station] = os.path.relpath(station_path, base_dir)
                saved_files += 1
                
            # Save the metadata file
            print(f"Saving metadata to {path}...")
            joblib.dump(meta_model, path, compress=3)
            print(f"Successfully saved {saved_files} station files and metadata")
            
        else:
            # Traditional single-file save
            print(f"Starting to save {len(self.models)} models to {path}...")
            model_data = {
                'models': self.models,
                'is_fitted': self.is_fitted,
                'units': self.units,
                'soil_types': self.soil_types,
                'stations': self.stations,
                'available_soil_types': self.available_soil_types
            }
            joblib.dump(model_data, path, compress=3)
            print(f"Successfully saved model to {path}")
            
        return self

    def load(self, path, lazy_load=True):
        """
        Load a saved FloodModel from disk.
        
        Parameters:
        -----------
        path : str
            File path to the saved model
        lazy_load : bool
            If True and model was saved with split_by_station=True, 
            only load station models when requested
            
        Returns:
        --------
        self : FloodModel
            Returns self with loaded models
        """
        print(f"Loading model from {path}...")
        
        # Try to load model data
        model_data = joblib.load(path)
        
        # Check if this is a split model (metadata file)
        if isinstance(model_data, dict) and 'station_paths' in model_data:
            # This is a split model - load metadata
            self.is_fitted = model_data.get('is_fitted', False)
            self.units = model_data.get('units', 'hours')
            self.soil_types = model_data.get('soil_types', [])
            self.stations = model_data.get('stations', [])
            self.available_soil_types = model_data.get('available_soil_types', [])
            
            # Get base directory for relative paths
            base_dir = os.path.dirname(path)
            
            if lazy_load:
                # Create a proxy function for each station that will load data when needed
                self.models = {}
                print(f"Lazy-loading enabled: Referenced {len(model_data['station_paths'])} stations")
                
                # Store the station paths for later loading
                self._station_paths = {
                    station: os.path.join(base_dir, rel_path) 
                    for station, rel_path in model_data['station_paths'].items()
                }
            else:
                # Load all station models immediately
                self.models = {}
                total_stations = len(model_data['station_paths'])
                print(f"Loading all {total_stations} station models...")
                
                for i, (station, rel_path) in enumerate(model_data['station_paths'].items()):
                    station_path = os.path.join(base_dir, rel_path)
                    if i % 10 == 0:
                        print(f"Loading station {i+1}/{total_stations}: {station}...")
                    
                    try:
                        # Load the station models
                        station_models = joblib.load(station_path)
                        # Add to the main models dictionary
                        self.models.update(station_models)
                    except Exception as e:
                        print(f"Error loading station {station}: {e}")
        else:
            # Traditional single-file model
            self.models = model_data.get('models', {})
            self.is_fitted = model_data.get('is_fitted', False)
            self.units = model_data.get('units', 'hours')
            self.soil_types = model_data.get('soil_types', [])
            self.stations = model_data.get('stations', [])
            self.available_soil_types = model_data.get('available_soil_types', [])
        
        print(f"Model loaded with {len(self.stations)} stations")
        return self

    def get_station_models(self, station):
        """
        Get all models for a specific station.

        Models already held in ``self.models`` (keys starting with
        ``"<station>_"``) are returned directly. Only when none are in memory
        and the station has an entry in ``self._station_paths`` (lazy loading)
        is the station file read from disk and merged into ``self.models``,
        so repeated calls do not hit the disk again.
        
        Parameters:
        -----------
        station : str
            Station identifier
            
        Returns:
        --------
        dict : Dictionary of models for the station (model key -> model).
            Empty if the station has no models or its file failed to load.
        """
        prefix = f"{station}_"
        in_memory = {k: v for k, v in self.models.items() if k.startswith(prefix)}
        if in_memory:
            # Already loaded (eagerly or by an earlier lazy load): no disk access needed
            return in_memory

        # _station_paths only exists when the model was loaded with lazy loading
        station_path = getattr(self, '_station_paths', {}).get(station)
        if station_path is None:
            return in_memory

        print(f"Loading station {station} models from {station_path}...")
        try:
            # Load the station models
            station_models = joblib.load(station_path)
            # Add to the main models dictionary so later calls are served from memory
            self.models.update(station_models)
            # Return a copy so callers cannot mutate the loaded mapping
            return dict(station_models)
        except Exception as e:
            # Best effort: a broken station file must not abort the caller
            print(f"Error loading station {station}: {e}")
            return {}

    def load_station(self, station, stations_dir):
        """
        Load the models of one station from a directory of station files.

        Two file names are probed in order, ``<station>.joblib`` and
        ``station_<station>.joblib``. The first existing file that loads
        without error decides the result.

        Parameters:
        -----------
        station : str
            Station identifier
        stations_dir : str
            Directory containing station model files

        Returns:
        --------
        dict : Models loaded for this station (also merged into ``self.models``);
            empty when no candidate file could be loaded.
        """
        candidate_names = (f"{station}.joblib", f"station_{station}.joblib")
        collected = {}

        for candidate in (os.path.join(stations_dir, name) for name in candidate_names):
            # Skip file name patterns that are not present on disk
            if not os.path.exists(candidate):
                continue

            try:
                print(f"Loading station models from {candidate}")
                payload = joblib.load(candidate)

                # Only dictionary payloads contribute models
                if isinstance(payload, dict):
                    for key, fitted in payload.items():
                        self.models[key] = fitted
                        collected[key] = fitted

                # The first file that loads cleanly ends the search
                return collected
            except Exception as exc:
                print(f"Error loading station file {candidate}: {exc}")

        print(f"No valid model file found for station {station} in {stations_dir}")
        return {}
    
    def predict_proba(self, geodata, station_coverage, year):
        """
        Attach predictions for a given time horizon to every geometry.

        For each row of ``geodata`` the station coverage areas intersecting the
        row's geometry are looked up. Every ``"<station>_<soil type>"`` model
        found in ``self.models`` is evaluated with ``model.predict(year)`` and
        the results are averaged. Rows with no intersecting coverage area, or
        with no matching model, receive a heuristic default value instead.

        Parameters:
        -----------
        geodata : GeoDataFrame
            Geometries to score. Needs a ``sediment`` column, or failing that
            a column whose name contains ``soil`` or ``type``.
        station_coverage : GeoDataFrame
            Coverage area per station. Needs a ``station_id`` column, or
            failing that a column whose name contains ``station`` and ``id``.
            It is reprojected to the CRS of ``geodata`` when they differ.
        year : int or float
            Time horizon handed to each model's ``predict``.

        Returns:
        --------
        GeoDataFrame : Copy of ``geodata`` with two extra columns:
            ``predictions_<year>`` (value multiplied by 100 for visualization)
            and ``predictions_<year>_raw`` (unscaled value). If the soil type
            or station ID column cannot be found, the original ``geodata`` is
            returned unchanged.

        Raises:
        -------
        RuntimeError
            If the model has not been fitted.
        """
        if not self.is_fitted:
            raise RuntimeError("Model must be trained before making predictions")

        result_geodata = geodata.copy()
        column_name = f'predictions_{year}'

        # Locate the soil type column, falling back to a name-based guess
        soil_type_col = 'sediment'
        if soil_type_col not in result_geodata.columns:
            possible_cols = [col for col in result_geodata.columns
                             if 'soil' in col.lower() or 'type' in col.lower()]
            if not possible_cols:
                print("Could not find soil type column in geodata")
                return geodata
            soil_type_col = possible_cols[0]

        # Initialize prediction column
        result_geodata[column_name] = None

        # Intersection tests are only meaningful when both frames share a CRS
        if result_geodata.crs != station_coverage.crs:
            print(f"Converting station_coverage from {station_coverage.crs} to {result_geodata.crs}")
            station_coverage = station_coverage.to_crs(result_geodata.crs)

        # Locate the station ID column, falling back to a name-based guess
        station_id_col = 'station_id'
        if station_id_col not in station_coverage.columns:
            possible_cols = [col for col in station_coverage.columns
                             if 'station' in col.lower() and 'id' in col.lower()]
            if not possible_cols:
                print("Could not find station ID column in station_coverage")
                return geodata
            station_id_col = possible_cols[0]

        # Heuristic fallback for geometries without coverage or without a
        # matching model. It depends only on `year`, so compute it once.
        # Clamped to the valid probability range [0, 1]; the previous
        # `min(..., 0.0)` collapsed it to 0 for every non-negative year.
        default_value = min(max(0.2 + 0.05 * year, 0.0), 1.0)

        # Process each row in geodata
        for idx, row in result_geodata.iterrows():
            soil_type = row[soil_type_col]

            # Find which station coverage areas this geometry intersects with
            intersecting_stations = station_coverage[station_coverage.intersects(row.geometry)]

            predictions = []
            if not intersecting_stations.empty:
                # Keep only the first word of a compound soil type description
                if isinstance(soil_type, str) and ' ' in soil_type:
                    simple_type = soil_type.split(' ')[0]
                else:
                    simple_type = soil_type

                # Collect predictions from every intersecting station that has a model
                for station in intersecting_stations[station_id_col]:
                    model_key = f"{station}_{simple_type}"
                    if model_key not in self.models:
                        continue
                    try:
                        predictions.append(self.models[model_key].predict(year))
                    except Exception as e:
                        # One failing model must not abort the whole map
                        print(f"Error predicting for {model_key}: {e}")

            # Average the station predictions, or fall back to the default
            if predictions:
                result_geodata.at[idx, column_name] = sum(predictions) / len(predictions)
            else:
                result_geodata.at[idx, column_name] = default_value

        # Store raw probability values before percentage conversion
        result_geodata[f'{column_name}_raw'] = result_geodata[column_name].copy()

        # Convert to percentage for visualization
        result_geodata[column_name] = result_geodata[column_name] * 100

        return result_geodata

In [None]:
import aba_flooding.model as md
import pandas as pd
# import geopandas as gpd
# import geo_utils as gu
import matplotlib.pyplot as plt
import os
import time
import cProfile
import pstats
from io import StringIO
from aba_flooding.preprocess import load_saved_data
from lifelines import KaplanMeierFitter
import multiprocessing
from functools import partial

def _train_station_models(survival_df, station, soil_types):
    """Fit one survival model per soil type that has usable duration/event data."""
    station_models = {}

    # Sorted so that the model insertion order is reproducible between runs
    for soil_type in sorted(soil_types):
        duration_column = f"{station}_{soil_type}_duration"
        event_column = f"{station}_{soil_type}_observed"

        # Both columns are required to fit a model for this soil type
        if duration_column not in survival_df.columns or event_column not in survival_df.columns:
            continue

        valid_data = survival_df[[duration_column, event_column]].dropna()
        if valid_data.empty:
            continue

        # Create and train the model on generically named columns
        model = md.SurvivalModel(soil_type=soil_type)
        model.station = station
        model.train(
            valid_data.rename(columns={
                duration_column: 'duration',
                event_column: 'observed'
            }),
            'duration',
            'observed'
        )

        station_models[f"{station}_{soil_type}"] = model

    return station_models


def process_station_file(file, processed_data_path, profile=False):
    """
    Process a single station file - can be run in parallel

    Loads the station's survival data and trains one survival model per soil
    type found in the column names (``<station>_<soil type>_...``; ``WOG``
    columns are ignored).

    Parameters:
    -----------
    file : str
        File name of the form ``survival_data_<station>.parquet``
    processed_data_path : str
        Directory that contains ``file``
    profile : bool
        If True, run the training under cProfile and print the 20 most
        expensive functions by cumulative time

    Returns:
    --------
    tuple: (station ID, survival models dict, timing info dict)
        For an empty data file the models entry is None and the timing dict
        has status 'skipped'. If any exception occurs, the tuple is
        ``(file, None, {'status': 'error', 'error': <message>})``.
    """
    try:
        station_start_time = time.time()

        # Extract station ID from filename
        station = file.replace("survival_data_", "").replace(".parquet", "")

        # Load the data for the station
        file_path = os.path.join(processed_data_path, file)

        load_start_time = time.time()
        survival_df = load_saved_data(file_path)
        load_time = time.time() - load_start_time

        if survival_df is None or survival_df.empty:
            print(f"Skipping empty data file for station {station}")
            # Carry the same numeric keys as a successful run so that report
            # code can index every station entry uniformly
            return station, None, {
                'station': station,
                'status': 'skipped',
                'reason': 'empty data',
                'total_time': time.time() - station_start_time,
                'load_time': load_time,
                'train_time': 0.0,
                'soil_types': 0,
                'models_trained': 0
            }

        # Identify soil types in this dataframe
        soil_types = set()
        for column in survival_df.columns:
            parts = column.split('_')
            # Station-soil columns look like "<station>_<soil type>_<suffix>"; skip WOG
            if len(parts) >= 3 and parts[0] == station and parts[1] != "WOG":
                soil_types.add(parts[1])

        profiler = None
        if profile:
            print(f"\nProfiling station {station}...")
            profiler = cProfile.Profile()
            profiler.enable()

        try:
            station_models = _train_station_models(survival_df, station, soil_types)
        finally:
            # Always stop the profiler, even if training fails
            if profiler is not None:
                profiler.disable()

        if profiler is not None:
            # Print profile results
            s = StringIO()
            ps = pstats.Stats(profiler, stream=s).sort_stats('cumulative')
            ps.print_stats(20)  # Top 20 functions by time
            print(s.getvalue())

        station_time = time.time() - station_start_time
        # Everything after loading is reported as training time
        train_time = station_time - load_time

        # Prepare timing info
        timing = {
            'station': station,
            'total_time': station_time,
            'load_time': load_time,
            'train_time': train_time,
            'soil_types': len(soil_types),
            'models_trained': len(station_models)
        }

        print(f"Station {station}: {station_time:.2f}s (Load: {load_time:.2f}s, Train: {train_time:.2f}s)")

        return station, station_models, timing

    except Exception as e:
        # Best effort: one broken station must not abort the whole training run
        print(f"Error processing station file {file}: {e}")
        import traceback
        traceback.print_exc()
        return file, None, {'status': 'error', 'error': str(e)}

def _default_worker_count():
    """Pick a conservative worker count: physical cores minus two, capped at 8."""
    workers = None
    try:
        import psutil
        physical_cores = psutil.cpu_count(logical=False)
        # psutil returns None when the physical core count cannot be determined
        if physical_cores:
            # Leave two cores free for system processes
            workers = max(1, physical_cores - 2)
    except (ImportError, AttributeError):
        # psutil is optional; fall through to the logical-CPU estimate below
        workers = None

    if workers is None:
        # Conservative default: about 60% of the logical processors
        workers = max(1, int(multiprocessing.cpu_count() * 0.6))

    # Cap at 8 workers to prevent overloading
    return min(8, workers)


def _merge_station_result(flood_model, timing_info, station, models, timing):
    """Fold one station's result into the flood model and the timing record."""
    if models:
        # Add all models to the main flood model
        flood_model.models.update(models)

        # Update station list if not already there
        if station not in flood_model.stations:
            flood_model.stations.append(station)

        # Update soil types (model keys are "<station>_<soil type>")
        for model_key in models:
            soil_type = model_key.split('_')[1]
            if soil_type not in flood_model.available_soil_types:
                flood_model.available_soil_types.append(soil_type)

    # Error results carry no 'station' key and are deliberately not recorded
    if isinstance(timing, dict) and 'station' in timing:
        timing_info['stations'][station] = timing


def train_all_models(output_path="models/flood_model.pkl", profile=False, parallel=True, max_workers=None):
    """
    Train survival models for all stations and soil types from processed parquet files.

    Station files are read from ``data/processed/`` below the current working
    directory; every ``survival_data_*.parquet`` file is treated as one station.

    Parameters:
    -----------
    output_path : str
        Currently unused: this function does not save the model, so the
        caller has to save the returned model itself. Kept for backward
        compatibility.
    profile : bool
        Whether to run detailed profiling for each station. Profiling forces
        sequential processing.
    parallel : bool
        Whether to use parallel processing (default: True). Only applied when
        there is more than one station file and profiling is off.
    max_workers : int or None
        Maximum number of parallel workers. None selects physical cores
        minus 2 (or about 60% of logical CPUs when psutil is unavailable),
        capped at 8.

    Returns:
    --------
    FloodModel : Trained flood model, or None if the data directory is missing
    dict : Timing information with keys 'total' (seconds) and 'stations'
    """
    start_time = time.time()
    timing_info = {'total': 0, 'stations': {}}

    # Find all files in data/processed/
    processed_data_path = os.path.join(os.getcwd(), "data/processed/")

    # listdir needs a real directory, not just any existing path
    if not os.path.isdir(processed_data_path):
        print(f"Directory {processed_data_path} does not exist")
        return None, timing_info

    # Filter for station parquet files; sorted for a reproducible station order
    station_files = sorted(
        f for f in os.listdir(processed_data_path)
        if f.startswith("survival_data_") and f.endswith(".parquet")
    )
    print(f"Found {len(station_files)} station data files")

    # Create a new FloodModel
    flood_model = md.FloodModel()

    # Use parallel processing if enabled and more than one file (never while profiling)
    if parallel and len(station_files) > 1 and not profile:
        if max_workers is None:
            max_workers = _default_worker_count()

        print(f"Using parallel processing with {max_workers} workers")

        process_func = partial(process_station_file,
                               processed_data_path=processed_data_path,
                               profile=profile)
        with multiprocessing.Pool(processes=max_workers) as pool:
            results = pool.map(process_func, station_files)
    else:
        # Sequential: a generator keeps processing and merging interleaved
        results = (process_station_file(file, processed_data_path, profile)
                   for file in station_files)

    for station, models, timing in results:
        _merge_station_result(flood_model, timing_info, station, models, timing)

    # The model only counts as fitted when at least one model was trained
    if flood_model.models:
        flood_model.is_fitted = True
    else:
        print("No models were successfully trained")

    timing_info['total'] = time.time() - start_time
    return flood_model, timing_info


def print_timing_report(timing_info):
    """
    Print a formatted timing report from timing information.

    Parameters:
    -----------
    timing_info : dict
        Expected keys: 'total' (seconds), 'stations' (station -> timing dict)
        and optionally 'save_time' (seconds). Station entries without a
        'total_time' value (for example skipped stations) are counted but
        left out of the per-station statistics.
    """
    if not timing_info:
        print("No timing information available")
        return

    stations = timing_info.get('stations', {})

    print("\n" + "="*60)
    print("TRAINING PERFORMANCE REPORT")
    print("="*60)
    print(f"Total training time: {timing_info.get('total', 0):.2f} seconds")
    print(f"Total stations: {len(stations)}")

    # Only entries that actually carry a duration can take part in the statistics
    timed = {station: info for station, info in stations.items()
             if isinstance(info, dict) and 'total_time' in info}
    if len(timed) < len(stations):
        print(f"Stations without timing data: {len(stations) - len(timed)}")

    if timed:
        # Overall stats
        station_times = [info['total_time'] for info in timed.values()]
        avg_time = sum(station_times) / len(station_times)

        print(f"\nAverage time per station: {avg_time:.2f}s")
        print(f"Fastest station: {min(station_times):.2f}s")
        print(f"Slowest station: {max(station_times):.2f}s")

        # Top 5 slowest stations
        print("\nTop 5 slowest stations:")
        sorted_stations = sorted(timed.items(),
                                 key=lambda x: x[1]['total_time'],
                                 reverse=True)

        for i, (station, info) in enumerate(sorted_stations[:5], 1):
            print(f"{i}. Station {station}: {info['total_time']:.2f}s - {info.get('soil_types', 0)} soil types")

    # Save time does not depend on per-station data, so report it regardless
    if 'save_time' in timing_info:
        print(f"\nModel save time: {timing_info['save_time']:.2f}s")

    print("="*60)


# Example usage script: train all station models, save them, print a timing report
if __name__ == "__main__":
    # Make sure the output directory exists
    os.makedirs("models", exist_ok=True)

    model_path = "models/flood_model.joblib"

    # Enable detailed profiling if needed
    enable_profiling = False

    # Enable parallel processing
    enable_parallel = False

    # Set number of workers (None = let train_all_models choose)
    num_workers = None

    # Train models for all stations
    print("Training models for all stations...")
    flood_model, timing_info = train_all_models(
        model_path,
        profile=enable_profiling,
        parallel=enable_parallel,
        max_workers=num_workers
    )

    # train_all_models returns None when the data directory is missing; an
    # empty model is not worth saving either
    if flood_model is None or not flood_model.models:
        print("No trained models available; skipping save")
    else:
        # Save using split storage
        print("Saving model with split storage...")
        save_start = time.time()
        flood_model.save(model_path, split_by_station=True)
        save_time = time.time() - save_start
        print(f"Model saved in {save_time:.2f} seconds")
        # add save time to timing info
        timing_info['save_time'] = save_time

    # Print the timing report
    print_timing_report(timing_info)

## Model Test ##

In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np

from aba_flooding.train import process_station_file
from aba_flooding.model import FloodModel


def inspect_model(train = False):
    """
    Inspect the flood model for station 05005 and print a summary.

    Parameters:
    -----------
    train : bool
        If True, train models for station 05005 from
        ``data/processed/survival_data_05005.parquet`` instead of loading the
        saved model from ``models/flood_model.joblib``.

    Any exception raised while loading, training or inspecting is printed
    together with its traceback instead of being propagated.
    """
    model_path = os.path.join("models", "flood_model.joblib")
    station_id = "05005"

    try:
        # Create a FloodModel instance first, then fill it by training or loading
        model = FloodModel()
        if train:
            print("Training the model...")
            station, station_models, timing = process_station_file(
                f"survival_data_{station_id}.parquet",
                os.path.join("data", "processed"),
                False)

            # Only try to add models if they were successfully created
            if station_models:
                model.models.update(station_models)
                if station not in model.stations:
                    model.stations.append(station)

                # Update available soil types (model keys are "<station>_<soil type>")
                for model_key in station_models:
                    soil_type = model_key.split('_')[1]
                    if soil_type not in model.available_soil_types:
                        model.available_soil_types.append(soil_type)

                # The model now holds trained survival models
                model.is_fitted = True
            else:
                print(f"No models were created for station {station_id}")
        else:
            print(f"Loading model from {model_path}...")
            if not os.path.exists(model_path):
                print(f"Model file not found: {model_path}")
                return
            model.load(path=model_path)

        # Inspect specific station
        inspect_station(model, station_id)

        print("\n== Model Summary ==")
        print(f"Number of stations: {len(model.stations)}")
        print(f"Total models: {len(model.models)}")
        print(f"Number of available soil types: {len(model.available_soil_types)}")
        print(f"Available soil types: {model.available_soil_types}")

    except Exception as e:
        # Diagnostic helper: report the problem rather than crash the session
        print(f"Error loading model: {e}")
        import traceback
        traceback.print_exc()

def _print_survival_model_report(key, survival_model, station_id):
    """Print attributes and sample predictions for one survival model."""
    # Extract just the soil type part (remove station prefix if present)
    prefix = f"{station_id}_"
    soil_type = key.replace(prefix, "") if key.startswith(prefix) else key
    print(f"\nSoil type: {soil_type}")

    # Public attributes, printed for debugging
    public_attrs = [name for name in dir(survival_model) if not name.startswith('_')]
    print(f"  - Model attributes: {public_attrs}")

    # A missing is_fitted attribute counts as "not fitted"
    fitted_flag = getattr(survival_model, 'is_fitted', False)
    print(f"  - Is fitted: {fitted_flag}")

    # Look for a wrapped estimator object
    if hasattr(survival_model, 'model') and survival_model.model is not None:
        print("  - Has model object: Yes")
        if hasattr(survival_model.model, 'median_survival_time_'):
            print(f"  - Median survival time: {survival_model.model.median_survival_time_}")
    else:
        print("  - Has model object: No")

    # Sample horizons; day values are converted to years before predicting
    horizons = ((10, 'days'), (150, 'days'), (1, 'years'), (5, 'years'), (10, 'years'))
    print("  - Prediction results:")
    for amount, unit in horizons:
        try:
            if not hasattr(survival_model, 'predict_proba'):
                print(f"    - At {amount} {unit}: predict_proba method not available")
                continue

            horizon_years = amount / 365.25 if unit == 'days' else amount
            surv_prob = survival_model.predict_proba(horizon_years)

            # Flood probability is reported as 1 - survival probability
            if isinstance(surv_prob, (int, float)):
                flood_prob = (1 - surv_prob) * 100
                print(f"    - At {amount} {unit}: {flood_prob:.2f}% flood probability")
            else:
                print(f"    - At {amount} {unit}: No valid prediction")
        except Exception as err:
            print(f"    - At {amount} {unit}: Error - {str(err)}")


def inspect_station(model, station_id):
    """
    Print a detailed report about the models of one station.

    When the model offers ``get_station_models`` and it returns models, their
    survival curves are plotted and each model is reported. Otherwise the
    ``models/stations`` directory is checked for station files.

    Parameters:
    -----------
    model : FloodModel
        The loaded FloodModel instance
    station_id : str
        The station ID to inspect
    """
    print(f"\n==== INSPECTING STATION {station_id} ====")

    # Nothing to inspect for an unknown station
    if station_id not in model.stations:
        print(f"Station {station_id} not found in model")
        return

    # For debugging, show the public surface of the model object
    visible_attrs = [name for name in dir(model) if not name.startswith('_')]
    print(f"Available model attributes: {visible_attrs}")

    # Preferred route: ask the model for the station's models
    if hasattr(model, 'get_station_models'):
        print(f"Using get_station_models to load models for station {station_id}...")
        try:
            loaded = model.get_station_models(station_id)
            if not loaded:
                print(f"get_station_models returned empty for station {station_id}")
            else:
                print(f"Successfully loaded {len(loaded)} models for station {station_id}")

                # Plot survival curves for this station
                plot_survival_curves(loaded, station_id)

                # Report on each model
                for key, survival_model in loaded.items():
                    _print_survival_model_report(key, survival_model, station_id)

                return  # Done once get_station_models has delivered models
        except Exception as err:
            print(f"Error using get_station_models: {err}")
            import traceback
            traceback.print_exc()

    # Fallback: look for station model files on disk
    print("\nChecking for station model files...")
    model_dir = os.path.join("models", "stations")
    station_file = os.path.join(model_dir, f"{station_id}.joblib")

    if not os.path.exists(station_file):
        print(f"No station file found at {station_file}")
    else:
        print(f"Found station file at {station_file}")

    # List the available station files (first ten, sorted by name)
    if not os.path.exists(model_dir):
        print(f"Directory {model_dir} does not exist")
    else:
        joblib_files = [name for name in os.listdir(model_dir) if name.endswith('.joblib')]
        if not joblib_files:
            print("No station files found in models/stations directory")
        else:
            print(f"\nAvailable station files ({len(joblib_files)} total):")
            for name in sorted(joblib_files)[:10]:
                print(f"  - {name}")
            if len(joblib_files) > 10:
                print(f"  - ... and {len(joblib_files) - 10} more")

    print("\n==== STATION INSPECTION COMPLETE ====")

def plot_survival_curves(station_models, station_id):
    """
    Plot flood-probability curves for the models of a specific station.

    For each time horizon (1, 5 and 10 years) one figure is drawn with one
    curve per usable model, showing ``1 - predict_proba(t)`` against time in
    days, and saved as a PNG under ``outputs/plots/inspect_model``. A model is
    usable when it has ``predict_proba`` and a truthy ``is_fitted``; a model
    that raises while being evaluated is reported and skipped.

    Parameters:
    -----------
    station_models : dict
        Dictionary of soil type -> survival model. Keys may carry a leading
        ``"<station_id>_"`` prefix, which is stripped for the legend label.
    station_id : str
        The station ID

    Returns:
    --------
    None
    """
    print("\n== Creating Survival Curve Plots ==")

    if not station_models:
        print("No models available to plot")
        return

    # Create output directory if it doesn't exist
    plot_dir = os.path.join("outputs", "plots", "inspect_model")
    os.makedirs(plot_dir, exist_ok=True)

    # Different time scales to plot ("max" is in years)
    time_ranges = [
        {"max": 1, "label": "1 Year", "filename": "1year"},
        {"max": 5, "label": "5 Years", "filename": "5years"},
        {"max": 10, "label": "10 Years", "filename": "10years"}
    ]

    key_prefix = f"{station_id}_"

    # For each time range, create a separate plot
    for time_range in time_ranges:
        plt.figure(figsize=(10, 6))
        try:
            # 100 points from 0.01 to max years (avoid 0); the x-axis shows days
            t = np.linspace(0.01, time_range["max"], 100)
            t_days = t * 365.25

            soil_types_plotted = 0

            for key, survival_model in station_models.items():
                # Strip only a *leading* station prefix; str.replace would also
                # remove later occurrences inside the soil-type name.
                soil_type = key[len(key_prefix):] if key.startswith(key_prefix) else key

                try:
                    if hasattr(survival_model, 'predict_proba') and getattr(survival_model, 'is_fitted', False):
                        # Survival probability at each time point, converted to flood probability
                        flood_probs = [1 - survival_model.predict_proba(time_point) for time_point in t]

                        plt.plot(t_days, flood_probs, label=f"Soil: {soil_type}", linewidth=2)
                        soil_types_plotted += 1
                except Exception as e:
                    print(f"  Error plotting model for soil type {soil_type}: {e}")

            if soil_types_plotted > 0:
                plt.xlabel('Time (days)')
                plt.ylabel('Flood Probability')
                plt.title(f'Flood Probability Curves for Station {station_id} ({time_range["label"]})')
                plt.grid(True, linestyle='--', alpha=0.7)
                plt.legend()

                # Save the plot
                filename = f"station_{station_id}_flood_prob_{time_range['filename']}.png"
                filepath = os.path.join(plot_dir, filename)
                plt.savefig(filepath)
                print(f"  Saved plot: {filepath}")
            else:
                print(f"  No valid models to plot for {time_range['label']} time range")
        finally:
            # Always release the figure, even if saving or plotting failed
            plt.close()

    print("== Plotting complete ==")

import pandas as pd
from lifelines import KaplanMeierFitter, WeibullFitter, ExponentialFitter, LogNormalFitter
from sksurv.nonparametric import kaplan_meier_estimator

if __name__ == "__main__":
    # Exploratory survival-analysis diagnostics for a single station: fits
    # Kaplan-Meier and parametric models to the processed survival data and
    # writes the diagnostic plots to outputs/plots/inspect_model.

    #inspect_model(True)

    # Station under inspection; every column name and input file is derived from it
    station = "05109"
    tte_col = f"{station}_HI_TTE"
    duration_col = f"{station}_HI_duration"
    observed_col = f"{station}_HI_observed"

    df = pd.read_parquet(f"data/processed/survival_data_{station}.parquet")

    print(df.columns)

    # Ensure the directory exists before saving any plot
    output_dir = os.path.join('outputs', 'plots', 'inspect_model')
    os.makedirs(output_dir, exist_ok=True)

    km = KaplanMeierFitter()
    km.fit(durations=df[tte_col], event_observed=df[observed_col])

    # kaplan_meier_estimator is given a boolean event indicator
    df[observed_col] = df[observed_col].astype(bool)

    km_time, survival_prob, conf_int = kaplan_meier_estimator(
        df[observed_col], df[duration_col], conf_type="log-log")

    # Draw on a fresh figure rather than on whatever figure happens to be current
    plt.figure()
    plt.step(km_time, survival_prob, where="post")
    plt.fill_between(km_time, conf_int[0], conf_int[1], alpha=0.25, step="post")
    plt.ylim(0, 1)
    plt.ylabel(r"est. probability of survival $\hat{S}(t)$")
    plt.xlabel("time $t$")
    plt.savefig(os.path.join(output_dir, "km_plot.png"))

    print(df[observed_col].value_counts())
    print(df[duration_col].describe())
    event_rows = df[df[observed_col] == 1]
    print(f"\nFound {len(event_rows)} events")
    if len(event_rows) > 0:
        print("\nFirst 5 events:")
        print(event_rows.head())
    else:
        print("No events found! All observations are censored.")

    # Cumulative density of the Kaplan-Meier fit on the TTE values
    plt.figure(figsize=(10, 6))
    km.plot_cumulative_density()
    plt.grid(True)
    plt.title("Cumulative density")
    plt.savefig(os.path.join(output_dir, 'cumulative_density.png'))

    # Check for issues in the duration data
    plt.figure(figsize=(10, 6))
    plt.hist(df[duration_col], bins=50)
    plt.title("Distribution of Duration Values")
    plt.savefig(os.path.join(output_dir, 'duration_hist.png'))

    plt.figure()
    plt.plot(df[f'{station}_WOG_HI'])
    plt.savefig(os.path.join(output_dir, "ss21.png"))

    df2 = pd.read_parquet("data/raw/precipitation_imputed_data.parquet")
    df2 = df2.clip(lower=0, upper=60)
    print(df2[station].isnull().sum())
    print(len(df2[station]))
    print(len(df))
    #inspect_model()

    # DIAGNOSTIC SECTION
    print("\n=== DIAGNOSTIC INFORMATION ===")
    event_rate = df[observed_col].mean()
    print(f"Event rate: {event_rate:.4f} ({event_rate*100:.2f}%)")

    # SOLUTION 1: Try plotting with CONSISTENT variables
    plt.figure(figsize=(10, 6))
    km_tte = KaplanMeierFitter()
    km_tte.fit(durations=df[tte_col], event_observed=df[observed_col])
    km_tte.plot_cumulative_density()
    plt.grid(True)
    plt.title("Cumulative Incidence (using TTE values)")
    plt.savefig(os.path.join(output_dir, 'cumulative_density_tte.png'))

    # SOLUTION 2: Try duration with events correctly marked
    plt.figure(figsize=(10, 6))
    km_dur = KaplanMeierFitter()
    km_dur.fit(durations=df[duration_col], event_observed=df[observed_col])
    km_dur.plot_cumulative_density()
    plt.grid(True)
    plt.title("Cumulative Incidence (using duration values)")
    plt.savefig(os.path.join(output_dir, 'cumulative_density_duration.png'))

    # SOLUTION 4: Check for time window issues (rolling event rate over 1000 rows)
    event_rate_by_time = df[observed_col].rolling(window=1000).mean()
    plt.figure(figsize=(10, 6))
    plt.plot(event_rate_by_time)
    plt.title("Event Rate Over Time (Moving Average)")
    plt.savefig(os.path.join(output_dir, 'event_rate_time.png'))

    # Create sksurv-compatible structured array
    y = np.zeros(len(df), dtype=[('event', bool), ('time', float)])
    y['event'] = df[observed_col].values
    y['time'] = df[duration_col].values

    print("\nEvent time analysis:")
    event_durations = df[df[observed_col] == 1][duration_col].describe()
    max_duration = df[duration_col].max()
    events_at_max = int(((df[observed_col] == 1) & (df[duration_col] == max_duration)).sum())
    print(f"Event durations: {event_durations}")
    print(f"Max duration overall: {max_duration}")
    print(f"Events at max duration: {events_at_max}")

    # Parametric fits: (label, fitter class, names of the fitted parameters to report)
    parametric_fitters = [
        ("Weibull", WeibullFitter, ("lambda_", "rho_")),
        ("Exponential", ExponentialFitter, ("lambda_",)),
        ("LogNormal", LogNormalFitter, ("mu_", "sigma_")),
    ]
    for label, fitter_cls, param_names in parametric_fitters:
        fitter = fitter_cls()
        fitter.fit(df[duration_col], df[observed_col])
        plt.figure(figsize=(10, 6))
        fitter.plot_cumulative_density()
        plt.grid(True)
        plt.title(f"Cumulative Incidence ({label})")
        plt.savefig(os.path.join(output_dir, f"cumulative_density_{label.lower()}.png"))
        fitted_params = ", ".join(str(getattr(fitter, name)) for name in param_names)
        print(f"{label} parameters: {fitted_params}")
        print(f"{label} median survival time: {fitter.median_survival_time_}")
        print(f"{label} AIC: {fitter.AIC_}")
        print(f"{label} BIC: {fitter.BIC_}")


## Running the model ##

In [None]:
import bokeh
import os
import geopandas as gpd
import time
import json  # Add for debugging GeoJSON structure

from aba_flooding.geo_utils import load_terrain_data, gdf_to_geojson, wgs84_to_web_mercator, load_geojson, load_gpkg, load_terrain_data, gdf_to_geojson
from aba_flooding.model import FloodModel

from bokeh.models import ColorBar

from bokeh.palettes import Viridis256, Category10
from bokeh.plotting import figure
from bokeh.io import output_file, show
from bokeh.models import GeoJSONDataSource, HoverTool, CheckboxGroup, CustomJS
from bokeh.models import WMTSTileSource, Column, Slider
from bokeh.transform import linear_cmap
from bokeh.layouts import column


# Directory that is searched for the trained model file.
MODEL_PATH = "models/"
# File name of the main FloodModel file expected inside MODEL_PATH.
MODEL_NAME = "flood_model.joblib"

def load_models(model_path):
    """Load a trained FloodModel, falling back to per-station files when needed.

    The main file at ``model_path`` is loaded first. If it lists stations but
    holds no models, a sibling directory (``flood_model_stations`` or
    ``<model file stem>_stations``) is searched for ``station_<id>.joblib``
    files, and each one found is loaded through ``model.load_station``.
    Failures for an individual station are reported and skipped.

    Returns the loaded model. Any other failure prints the error with a
    traceback and terminates the process with exit status 1.
    """
    try:
        print(f"Loading model from {model_path}...")
        # load() is an instance method, so an empty FloodModel is created first
        model = FloodModel()
        model.load(path=model_path)

        stations_without_models = len(model.models) == 0 and len(model.stations) > 0
        if stations_without_models:
            print(f"Main model has stations but no models. Looking for split station files...")

            # Station files live in a directory next to the model file
            parent_dir = os.path.dirname(model_path)
            stations_dir = os.path.join(parent_dir, "flood_model_stations")

            # Fall back to the "<stem>_stations" layout used in train.py
            if not os.path.exists(stations_dir):
                stem = os.path.splitext(os.path.basename(model_path))[0]
                stations_dir = os.path.join(parent_dir, f"{stem}_stations")

            if not os.path.exists(stations_dir):
                print(f"No stations directory found at {stations_dir}")
            else:
                print(f"Found stations directory: {stations_dir}")
                loaded_count = 0

                for station in model.stations:
                    # Candidate file names, in order of preference
                    station_files = [
                        os.path.join(stations_dir, f"station_{station}.joblib")  # Pattern from train.py
                    ]
                    station_file = next(
                        (candidate for candidate in station_files if os.path.exists(candidate)),
                        None,
                    )

                    if not station_file:
                        print(f"No file found for station {station}, tried patterns: {station_files}")
                        continue

                    try:
                        print(f"Loading station model from: {station_file}")
                        station_models = model.load_station(station, os.path.dirname(station_file))
                        if station_models:
                            loaded_count += len(station_models)
                    except Exception as e:
                        print(f"ERROR loading station {station}: {e}")

                print(f"Loaded {loaded_count} models for {len(model.stations)} stations")

        print(f"Model loaded successfully with {len(model.models)} models across {len(model.stations)} stations")
        return model
    except Exception as e:
        print(f"ERROR loading model: {e}")
        import traceback
        traceback.print_exc()
        exit(1)

def repair_geometries(gdf):
    """Repair invalid geometries in a GeoDataFrame.

    Invalid geometries are replaced by ``geom.buffer(0)``; valid and missing
    (``None``) geometries are left untouched. The geometry column of ``gdf``
    is updated in place and the same object is returned.

    Parameters
    ----------
    gdf : geopandas.GeoDataFrame
        Frame whose active geometry column should be repaired.

    Returns
    -------
    geopandas.GeoDataFrame
        The input frame, with repaired geometries where a repair was needed.
    """
    def _count_invalid(geoms):
        # Missing geometries are reported as not valid, but there is nothing
        # to repair for them, so they are excluded from the count.
        return int((geoms.notna() & ~geoms.is_valid).sum())

    invalid_count = _count_invalid(gdf.geometry)
    if invalid_count > 0:
        print(f"Fixing {invalid_count} invalid geometries")
        # buffer(0) is a common trick to fix many geometry issues
        gdf.geometry = gdf.geometry.apply(
            lambda geom: geom.buffer(0) if geom is not None and not geom.is_valid else geom
        )
        still_invalid = _count_invalid(gdf.geometry)
        if still_invalid > 0:
            print(f"Warning: {still_invalid} geometries still invalid after repair")
    return gdf

def _project_and_simplify_sediment(sediment_data, target_crs="EPSG:3857", tolerance=50):
    """Return the sediment polygons in ``target_crs`` with simplified geometries.

    Data without a CRS is assumed to be WGS84. Geometries are repaired with
    ``repair_geometries`` before reprojection. Simplification runs *after*
    reprojection because ``simplify`` interprets its tolerance in the units of
    the layer's CRS: on WGS84 input a tolerance of 50 would mean 50 degrees,
    whereas in Web Mercator it is in metres.
    """
    if sediment_data.crs is None:
        print("Warning: Sediment data has no CRS, assuming WGS84")
        sediment_data.crs = "EPSG:4326"
        needs_projection = True
    elif str(sediment_data.crs).upper() != target_crs:
        print(f"Converting sediment data from {sediment_data.crs} to {target_crs} Web Mercator")
        needs_projection = True
    else:
        print("Sediment data already in Web Mercator projection")
        needs_projection = False

    if needs_projection:
        # Repair invalid geometries before reprojecting
        sediment_mercator = repair_geometries(sediment_data).to_crs(target_crs)
    else:
        # Copy so the simplification below does not modify the caller's frame
        sediment_mercator = sediment_data.copy()

    # Simplify to reduce the GeoJSON payload size
    print("Simplifying geometries...")
    sediment_mercator.geometry = sediment_mercator.geometry.simplify(tolerance=tolerance)
    return sediment_mercator


def _first_geojson_feature(geojson_text):
    """Return the first feature of a GeoJSON string, or None if it has no features."""
    features = json.loads(geojson_text).get("features") or []
    return features[0] if features else None


def _add_flood_layer(p, sediment_mercator, stations_gdf, year_slider, last_year):
    """Add the flood-risk layer, its colour bar, hover tool and slider callbacks to ``p``.

    Predictions are computed with the trained FloodModel for the year offsets
    ``0..last_year``; the layer is coloured by the ``predictions_<year>``
    field selected through ``year_slider``.

    Returns the patches renderer, or None when no trained model file or no
    station coverage is available.
    """
    model_file = os.path.join(MODEL_PATH, MODEL_NAME)
    if not os.path.isfile(model_file):
        print("No model found, GO train some by running train.py")
        return None
    if stations_gdf is None:
        print("No station coverage available, skipping flood predictions")
        return None

    flood_model = load_models(model_file)

    # Precompute predictions for every year the slider can select
    print(f"Starting predictions on {len(sediment_mercator)} polygons...")
    sediment_with_predictions = sediment_mercator.copy()
    for year in range(0, last_year + 1):
        start_time = time.time()
        sediment_with_predictions = flood_model.predict_proba(
            sediment_with_predictions, station_coverage=stations_gdf, year=year)
        prediction_time = time.time() - start_time
        print(f"Predicted Year {year} in {prediction_time:.2f} seconds")

    # Convert to GeoJSON data source
    print("Converting predictions to GeoJSON...")
    flood_geojson = gdf_to_geojson(sediment_with_predictions)

    # Debug: Check the size of the predictions GeoJSON payload
    flood_geojson_size_mb = len(flood_geojson) / (1024 * 1024)
    print(f"Predictions GeoJSON payload size: {flood_geojson_size_mb:.2f} MB")

    # Debug: Check if GeoJSON has the expected prediction fields
    flood_sample = _first_geojson_feature(flood_geojson)
    if flood_sample:
        print(f"Prediction fields: {[k for k in flood_sample['properties'].keys() if 'prediction' in k]}")

    flood_source = GeoJSONDataSource(geojson=flood_geojson)
    print("Create Flood Layer")
    # Create color mapper with initial field name for year 0
    color_mapper = linear_cmap(
        field_name='predictions_0',
        palette=Viridis256,
        low=0,
        high=100  # Predictions are percentages (0-100)
    )

    # Add flood risk layer
    flood_layer = p.patches(
        'xs', 'ys',
        source=flood_source,
        fill_color=color_mapper,
        fill_alpha=0.7,
        line_color=None,
        legend_label="Flood Risk"
    )

    # Configure color bar
    color_bar = ColorBar(
        color_mapper=color_mapper['transform'],
        location=(0, 0),
        title="Flood Risk (%) - Year 0",
        ticker=bokeh.models.BasicTicker(desired_num_ticks=5),
        formatter=bokeh.models.PrintfTickFormatter(format="%d%%")
    )
    p.add_layout(color_bar, 'right')

    # Slider callback: point the layer's fill colour at the selected year's field
    year_callback = CustomJS(
        args=dict(
            flood_layer=flood_layer,
            flood_source=flood_source,
            slider=year_slider,
            color_bar=color_bar,
            mapper=color_mapper
        ),
        code="""
            try {
                // Ensure we're in a browser context with a DOM
                if (typeof document === 'undefined' || document.body === null) {
                    console.warn('DOM not ready yet, skipping callback');
                    return;
                }

                // Get current year from slider
                const year = Math.round(slider.value);
                console.log('Changing visualization to year:', year);

                // Create the field name for this year's predictions
                const field_name = 'predictions_' + year;

                // Need to update the mapper's field name
                mapper.field = field_name;

                // Update the layer's glyph
                flood_layer.glyph.fill_color = {
                    ...flood_layer.glyph.fill_color,
                    field: field_name
                };

                // Update the color bar title
                color_bar.title = 'Flood Risk (%) - Year ' + year;

                // Force a data source change to trigger redraw
                flood_source.change.emit();
            } catch(e) {
                console.error('Error in year slider callback:', e);
            }
        """
    )
    year_slider.js_on_change('value', year_callback)

    # Single hover tool whose third tooltip row is rewritten by the slider
    hover = HoverTool(
        tooltips=[
            ("Soil Type", "@sediment"),
            ("Elevation", "@elevation{0,0.0}"),
            ("Current Prediction (Year 0)", "@predictions_0{0.0}%")  # Updated by the slider
        ],
        renderers=[flood_layer]  # Explicitly attach to flood layer
    )

    hover_callback = CustomJS(
        args=dict(hover=hover, slider=year_slider),
        code="""
            try {
                // Get current year from slider
                const year = Math.round(slider.value);
                const field = 'predictions_' + year;

                // Update the hover tooltip with the current year
                hover.tooltips[2][0] = "Current Prediction (Year " + year + ")";
                hover.tooltips[2][1] = "@" + field + "{0.0}%";

                // Force the hover tool to update
                hover.change.emit();
                console.log("Updated hover tooltip for year: " + year);
            } catch(e) {
                console.error('Error in hover callback:', e);
            }
        """
    )

    # Add the hover callback to the slider's change event
    year_slider.js_on_change('value', hover_callback)

    # Add the hover tool to the plot
    p.add_tools(hover)

    return flood_layer


def init_map():
    """Initialize a Bokeh map with terrain data, sediment layers, and flood risk predictions using FloodModel.

    Builds, in order: the base tile map, the precipitation coverage polygons
    with their station points, the sediment polygons, and the flood-risk
    layer driven by a "Years into future" slider. Each optional layer is
    skipped (with a printed message) if its data cannot be prepared, and a
    checkbox group toggles the layers that were created.

    Returns
    -------
    bokeh layout
        A column holding the figure above the slider and checkbox controls.
    """
    # Last year offset that is predicted. The slider range is derived from
    # this value so it can never select a year without a prediction column;
    # raise it to extend both the predictions and the slider.
    last_prediction_year = 1

    # Load sediment data
    print("Loading sediment data...")
    try:
        # Full Denmark dataset; "Sediment.geojson" is the alternative input
        sediment_data = load_terrain_data("Sediment_wgs84.geojson")

        print(f"Loaded sediment data with CRS: {sediment_data.crs}")
        print(f"Sediment data size before simplification: {len(sediment_data)} polygons")

        sediment_mercator = _project_and_simplify_sediment(sediment_data)

        print(f"Sediment data size after simplification: {len(sediment_mercator)} polygons")
        print(f"Sediment data bounds: {sediment_mercator.total_bounds}")

        has_sediment_data = True
    except Exception as e:
        print(f"Could not load sediment data: {e}")
        import traceback
        traceback.print_exc()
        has_sediment_data = False
        exit(1)

    # Prepare map figure (bounds are Web Mercator coordinates)
    denmark_bounds_x = (670000, 1500000)
    denmark_bounds_y = (7000000, 8170000)
    p = figure(title="Flood Risk Map",
               x_axis_type="mercator", y_axis_type="mercator",
               x_range=denmark_bounds_x, y_range=denmark_bounds_y,
               tools="pan,wheel_zoom,box_zoom,reset,save",
               width=1200, height=900)

    # Add base map tiles
    cartodb_positron = WMTSTileSource(
        url='https://tiles.basemaps.cartocdn.com/light_all/{z}/{x}/{y}.png',
        attribution='© OpenStreetMap contributors, © CartoDB'
    )
    p.add_tile(cartodb_positron)

    # Create precipitation coverage layer
    precipitation_layer = None
    station_layer = None
    # Stays None when the coverage data cannot be loaded; the flood layer needs it
    stations_gdf = None

    # Load coverage data
    try:
        print("Loading precipitation coverage data...")
        coverage_data = load_geojson("precipitation_coverage.geojson")

        if coverage_data is not None and not coverage_data.empty:
            # Ensure data is in Web Mercator projection
            if coverage_data.crs != "EPSG:3857":
                coverage_mercator = coverage_data.to_crs(epsg=3857)
            else:
                coverage_mercator = coverage_data

            # Convert to GeoJSON for Bokeh
            coverage_geojson = gdf_to_geojson(coverage_mercator)
            coverage_source = GeoJSONDataSource(geojson=coverage_geojson)

            # Add precipitation coverage polygons
            precipitation_layer = p.patches(
                'xs', 'ys',
                source=coverage_source,
                fill_color='blue',
                fill_alpha=0.2,
                line_color='blue',
                line_width=1,
                legend_label="Precipitation Coverage"
            )

            # Create a hover tool for precipitation areas
            precip_hover = HoverTool(
                tooltips=[
                    ("Station ID", "@station_id"),
                    ("Avg Precipitation", "@avg_precipitation{0.0} mm")
                ],
                renderers=[precipitation_layer]
            )
            p.add_tools(precip_hover)

            # Extract station points (centroids of coverage areas) for visualization
            stations_gdf = gpd.GeoDataFrame(
                coverage_mercator.copy(),
                geometry=coverage_mercator.geometry.centroid,
                crs=coverage_mercator.crs
            )

            # Convert station points to GeoJSON
            stations_geojson = gdf_to_geojson(stations_gdf)
            stations_source = GeoJSONDataSource(geojson=stations_geojson)

            # Add station points
            station_layer = p.circle(
                'x', 'y',
                source=stations_source,
                size=8,
                color='blue',
                fill_alpha=1.0,
                line_color='white',
                line_width=1
            )

            print(f"Successfully loaded precipitation coverage with {len(coverage_mercator)} areas")
        else:
            print("No precipitation coverage data found or it's empty")
    except Exception as e:
        print(f"Failed to load precipitation coverage: {e}")
        import traceback
        traceback.print_exc()

    # Add sediment layer if available
    sediment_layer = None
    if has_sediment_data:
        try:
            print("Converting sediment data to GeoJSON...")
            sediment_geojson = gdf_to_geojson(sediment_mercator)

            # Debug: Check the size of the GeoJSON payload
            geojson_size_mb = len(sediment_geojson) / (1024 * 1024)
            print(f"GeoJSON payload size: {geojson_size_mb:.2f} MB")

            if geojson_size_mb > 50:
                # A payload this large could be written to a file and served by URL instead
                print("Warning: GeoJSON payload is very large. Consider using file-based GeoJSON.")

            # Debug: Check if GeoJSON has the expected fields
            sample_feature = _first_geojson_feature(sediment_geojson)
            if sample_feature:
                print(f"GeoJSON feature properties: {list(sample_feature['properties'].keys())}")

            sediment_source = GeoJSONDataSource(geojson=sediment_geojson)
            sediment_layer = p.patches('xs', 'ys', source=sediment_source,
                                    fill_color='brown', fill_alpha=0.4,
                                    line_color='black', line_width=0.2,
                                    legend_label="Sediment")
        except Exception as e:
            print(f"ERROR creating sediment layer: {e}")
            import traceback
            traceback.print_exc()

    # Prepare FloodModel predictions
    flood_layer = None
    year_slider = Slider(start=0, end=last_prediction_year, value=0, step=1, title="Years into future")
    if has_sediment_data:
        try:
            flood_layer = _add_flood_layer(
                p, sediment_mercator, stations_gdf, year_slider, last_prediction_year)
        except Exception as e:
            print(f"ERROR setting up flood predictions: {e}")
            import traceback
            traceback.print_exc()

    # Layer visibility controls
    layer_names = []
    active_layers = []

    if sediment_layer:
        layer_names.append("Sediment")
        active_layers.append(len(layer_names) - 1)

    if precipitation_layer:
        layer_names.append("Precipitation Coverage")
        active_layers.append(len(layer_names) - 1)

    if flood_layer:
        layer_names.append("Flood Risk")
        active_layers.append(len(layer_names) - 1)

    checkbox = CheckboxGroup(labels=layer_names, active=active_layers)
    print(f"Initial active layers: {active_layers}")

    # JavaScript callback for layer visibility with error handling.
    # Only layers that exist are passed in; the JS below is assembled in the
    # same order as layer_names so the running index matches the checkboxes.
    js_args = {'checkbox': checkbox}

    if sediment_layer:
        js_args['sediment_layer'] = sediment_layer

    if precipitation_layer:
        js_args['precipitation_layer'] = precipitation_layer

    if station_layer:
        js_args['station_layer'] = station_layer

    if flood_layer:
        js_args['flood_layer'] = flood_layer

    checkbox_code = """
        try {
            let i = 0;
    """

    if sediment_layer:
        checkbox_code += """
            sediment_layer.visible = checkbox.active.includes(i);
            i++;
        """

    if precipitation_layer:
        checkbox_code += """
            precipitation_layer.visible = checkbox.active.includes(i);
            if (typeof station_layer !== 'undefined')
                station_layer.visible = checkbox.active.includes(i);
            i++;
        """

    if flood_layer:
        checkbox_code += """
            if (typeof flood_layer !== 'undefined')
                flood_layer.visible = checkbox.active.includes(i);
        """

    checkbox_code += """
        } catch(e) {
            console.error('Error in checkbox callback:', e);
        }
    """

    checkbox_callback = CustomJS(args=js_args, code=checkbox_code)
    checkbox.js_on_change('active', checkbox_callback)

    # Assemble layout
    controls = column(year_slider, checkbox)

    layout = column(p, controls)
    p.legend.location = "top_left"
    p.legend.click_policy = "hide"
    p.legend.title = "Layers"

    print("From the River to the Sea, Palestine will be Free!")
    return layout

if __name__ == "__main__":
    # Build the map layout, then write it to an HTML file and open it in the browser
    map_layout = init_map()
    output_file("terrain_map.html")
    show(map_layout)