# COSMOS Weather Nowcasting: PySpark Data Prep & Keras Training

This notebook demonstrates a workflow combining PySpark for data preparation and Keras/TensorFlow for training a ConvLSTM model for weather nowcasting based on INSAT-3DR data.

**Workflow:**
1.  **Phase 1 (PySpark):** Read raw HDF5 files, process time sequences, apply normalization and cropping, and save the prepared data to Parquet format (using chunking to manage memory).
2.  **Phase 2 (Keras):** Load the processed data from Parquet, define the ConvLSTM model, train the model (with checkpointing/resuming), and evaluate its performance with detailed metrics and visualizations.

## Cell 1: Imports, Setup, and Configuration

In [22]:
# ==============================================================
# Cell 1: Imports, Setup, and Configuration
# ==============================================================
print("--- Cell 1: Initializing ---")
import os
import glob
import h5py
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.callbacks import ModelCheckpoint
from sklearn.metrics import (
    confusion_matrix, precision_score, recall_score,
    f1_score, accuracy_score, mean_squared_error, mean_absolute_error
)
from tqdm.notebook import tqdm # Use tqdm.notebook for better Jupyter integration
# import ctypes # No longer needed for DLL fix
import re # For checkpoint resuming
import gc # For garbage collection (optional memory management)
import sys # To get python executable path
from IPython.display import display # For better dataframe display
import math # For chunk calculation
import shutil # For removing old parquet dir

# --- PySpark Imports (Needed for Phase 1) ---
try:
    from pyspark.sql import SparkSession
    from pyspark.sql.types import (
        StructType, StructField, StringType, ArrayType,
        FloatType, DoubleType
    )
    pyspark_available = True
    print("PySpark libraries imported successfully.")
except ImportError:
    pyspark_available = False
    print("\033[91mWarning: PySpark libraries not found. Phase 1 will be skipped.\033[0m")
    print("\033[91mEnsure PySpark is installed in your environment (e.g., spark301_py39).\033[0m")

# --- Configuration ---
# Data Paths
# !!! IMPORTANT: Update these paths to match your system !!!
RAW_DATA_DIR           = r"C:\college\CV\COSMOS\6C_full" # <--- POINT TO YOUR *FULL* HDF5 DATASET
PROCESSED_PARQUET_DIR  = r"C:\college\CV\COSMOS\processed_cosmos_data_pyspark.parquet" # <--- Location for Spark output
MODEL_SAVE_PATH        = r"C:\college\CV\COSMOS\multitask_nowcast_pyspark_trained.h5" # <--- Save path for Keras model
CHECKPOINT_DIR         = "checkpoints_pyspark_trained" # <--- Checkpoint dir for Keras training

# Model/Data Parameters
SEQ_LEN       = 4
PATCH_SIZE    = 32
BATCH_SIZE    = 16        # Keras training batch size
EPOCHS        = 20        # Keras training epochs
THRESHOLD     = 265.0     # Cloud detection threshold
CV_THRESHOLD  = 260.0     # Convection detection threshold
FOG_THRESHOLD = 270.0     # Fog detection threshold

# --- Hardware & Library Setup ---
print("\n--- Hardware & Library Setup ---")
# Mixed Precision is disabled
print("Info: Mixed precision explicitly disabled.")

# DLL Fix Block Removed (assuming `conda install zlib` handled it)

try: # GPU Setup
    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
        # Attempt to enable memory growth for GPUs to avoid allocating all memory at once
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print(f"Using {len(gpus)} GPU(s) with memory growth enabled.")
    else:
        print("Info: No GPU detected by TensorFlow. Using CPU.")
except Exception as e:
    print(f"Warning: Error during GPU setup: {e}")

# Print Key Versions
print(f"Using TensorFlow version: {tf.__version__}")
print(f"Using NumPy version: {np.__version__}")


print("\n--- Environment Variable Checks ---")
# Set PYSPARK_PYTHON (helps Spark find the right Python for workers)
os.environ['PYSPARK_PYTHON'] = sys.executable
os.environ['PYSPARK_DRIVER_PYTHON'] = sys.executable
print(f"PYSPARK_PYTHON set to: {sys.executable}")

# Check for HADOOP_HOME (Crucial for Windows file operations)
hadoop_home = os.environ.get('HADOOP_HOME')
if hadoop_home:
    print(f"HADOOP_HOME environment variable found: {hadoop_home}")
    winutils_path = os.path.join(hadoop_home, 'bin', 'winutils.exe')
    if os.path.exists(winutils_path):
        print(f"winutils.exe found at: {winutils_path}")
    else:
        print(f"\033[91mWarning: winutils.exe not found at expected location: {winutils_path}\033[0m") # Red Warning
        print("\033[91mSpark file write operations might fail.\033[0m")
else:
    print("\033[91mWarning: HADOOP_HOME environment variable is not set.\033[0m") # Red Warning
    print("\033[91mSpark file write operations will likely fail on Windows.\033[0m")


print("\n--- Cell 1: Setup Complete ---")

--- Cell 1: Initializing ---
PySpark libraries imported successfully.

--- Hardware & Library Setup ---
Info: Mixed precision explicitly disabled.
Using 1 GPU(s) with memory growth enabled.
Using TensorFlow version: 2.8.0
Using NumPy version: 1.21.6

--- Environment Variable Checks ---
PYSPARK_PYTHON set to: c:\Users\dhanu\.conda\envs\w\python.exe
HADOOP_HOME environment variable found: C:\hadoop
winutils.exe found at: C:\hadoop\bin\winutils.exe

--- Cell 1: Setup Complete ---


## Phase 1: PySpark Data Preparation

This phase reads the raw HDF5 files, processes them into sequences, performs normalization and cropping, and saves the result as Parquet files using chunking to manage memory.


### Cell 2: Define HDF5 Processing Function

In [23]:
# ==============================================================
# Cell 2: Define HDF5 Processing Function
# ==============================================================
print("\n--- Cell 2: Defining HDF5 Processing Function ---")

# Import necessary libraries for the function (already imported in Cell 1 but good practice)
import os
import h5py
import numpy as np

# Make sure config variables from Cell 1 are available
if 'SEQ_LEN' not in locals(): SEQ_LEN = 4 # Default if running independently
if 'PATCH_SIZE' not in locals(): PATCH_SIZE = 32
if 'THRESHOLD' not in locals(): THRESHOLD = 265.0
if 'CV_THRESHOLD' not in locals(): CV_THRESHOLD = 260.0
if 'FOG_THRESHOLD' not in locals(): FOG_THRESHOLD = 270.0

# --- Data Loading/Cropping Function Definition ---
# This function processes one sequence of files
def load_and_crop_sequence(fp_seq):
    """
    Loads, processes, and crops data for one sequence.
    Returns a dictionary or None on error/skip conditions.
    """
    if len(fp_seq) != SEQ_LEN + 1:
        # print(f"Debug: Incorrect sequence length {len(fp_seq)}. Skipping.")
        return None
    try:
        frames = []
        for fp in fp_seq[:SEQ_LEN]:
            # Use 'try-with-resources' style for file handling
            with h5py.File(fp, 'r') as f:
                # Extract counts
                cnt1, cnt2 = f['IMG_TIR1'][0][...], f['IMG_TIR2'][0][...]
                cnt_wv, cnt_mir = f['IMG_WV'][0][...], f['IMG_MIR'][0][...]
                cnt_vis = f['IMG_VIS'][0][...]
                # Extract LUTs
                lut1, lut2 = f['IMG_TIR1_TEMP'][:], f['IMG_TIR2_TEMP'][:] 
                lut_wv, lut_mir = f['IMG_WV_TEMP'][:], f['IMG_MIR_TEMP'][:] 
                lut_vis = f['IMG_VIS_ALBEDO'][:] 
            # Apply LUTs
            bt1 = lut1[cnt1]; bt2 = lut2[cnt2]
            wv = lut_wv[cnt_wv]; mir = lut_mir[cnt_mir]
            vis = lut_vis[cnt_vis]
            # Stack channels & normalize
            frames.append(np.stack([bt1, bt2, wv, mir, vis], axis=-1) / 300.0)
        X = np.stack(frames, axis=0).astype(np.float32)
        
        # Load target frame data
        with h5py.File(fp_seq[-1], 'r') as f:
            cnt1_t, cnt2_t = f['IMG_TIR1'][0][...], f['IMG_TIR2'][0][...]
            cnt_wv_t, cnt_mir_t = f['IMG_WV'][0][...], f['IMG_MIR'][0][...]
            lut1_t, lut2_t = f['IMG_TIR1_TEMP'][:], f['IMG_TIR2_TEMP'][:]
            lut_wv_t, lut_mir_t = f['IMG_WV_TEMP'][:], f['IMG_MIR_TEMP'][:]
        bt1_t = lut1_t[cnt1_t]; bt2_t = lut2_t[cnt2_t]
        wv_t = lut_wv_t[cnt_wv_t]; mir_t = lut_mir_t[cnt_mir_t]
        
        # Calculate targets
        last_mean = bt1_t.mean() / 300.0
        first_mean = X[0, ..., 0].mean()
        temp_trend = np.array([last_mean - first_mean], dtype=np.float32)
        
        y = {
            'cloud': (bt1_t < THRESHOLD).astype(np.float32)[..., None],
            'convective': (bt1_t < CV_THRESHOLD).astype(np.float32)[..., None],
            'fog': (mir_t < FOG_THRESHOLD).astype(np.float32)[..., None],
            'moisture': (wv_t / 300.0).astype(np.float32)[..., None],
            'thermo_contrast': ((bt2_t - bt1_t) / 100.0).astype(np.float32)[..., None],
            'temp_trend': temp_trend
        }
        
        # Apply Random Crop
        H, W = X.shape[1], X.shape[2]
        if H < PATCH_SIZE or W < PATCH_SIZE: 
            # print(f"Debug: Skipping sequence {os.path.basename(fp_seq[0])} due to small dimensions ({H}x{W})")
            return None # Skip if too small
        i = np.random.randint(0, H - PATCH_SIZE + 1)
        j = np.random.randint(0, W - PATCH_SIZE + 1)
        Xc = X[:, i:i + PATCH_SIZE, j:j + PATCH_SIZE, :]
        yc = {}
        for k, v in y.items():
            yc[k] = v[i:i + PATCH_SIZE, j:j + PATCH_SIZE] if v.ndim == 3 else v
        
        # Return processed data as dictionary with lists
        # Keys here will become column names in Parquet
        return {
            'sequence_id': os.path.basename(fp_seq[0]),
            'input_features': Xc.tolist(),
            'target_cloud': yc['cloud'].tolist(),
            'target_convective': yc['convective'].tolist(),
            'target_fog': yc['fog'].tolist(),
            'target_moisture': yc['moisture'].tolist(),
            'target_thermo_contrast': yc['thermo_contrast'].tolist(),
            'target_temp_trend': float(yc['temp_trend'][0])
        }
    # Catch specific errors if possible, otherwise generic Exception
    except FileNotFoundError:
        print(f"Warning: File not found in sequence starting with {fp_seq[0]}. Skipping.")
        return None
    except KeyError as e:
         print(f"Warning: Missing dataset key '{e}' in sequence starting with {fp_seq[0]}. Skipping.")
         return None
    except Exception as e:
        # Log other errors without stopping the whole process if possible
        print(f"Warning: Error processing sequence starting with {fp_seq[0]}: {e}")
        # import traceback # Uncomment for full traceback during debug
        # traceback.print_exc()
        return None

print("Function 'load_and_crop_sequence' defined.")
print("\n--- Cell 2: Complete ---")


--- Cell 2: Defining HDF5 Processing Function ---
Function 'load_and_crop_sequence' defined.

--- Cell 2: Complete ---


### Cell 3: Spark Session & Process/Write in Chunks

In [24]:
# ==============================================================
# Cell 3: Spark Session & Process/Write in Chunks
# ==============================================================
print("\n--- Cell 3: Initializing Spark & Processing/Writing Data in Chunks ---")

# Ensure necessary imports/variables are available
import os 
import gc
import math
import shutil
from tqdm.notebook import tqdm
import glob # Needed if sequences list needs recreation

# Check if PySpark is available and if the previous cell produced data/function
if 'pyspark_available' not in locals() or not pyspark_available:
    print("Skipping Spark operations as PySpark is not available.")
elif 'load_and_crop_sequence' not in locals():
    print("Error: 'load_and_crop_sequence' function not defined. Please run Cell 2 first.")
    raise NameError("Missing 'load_and_crop_sequence' function")
elif 'sequences' not in locals() or not sequences:
    # Attempt to recreate sequences if Cell 2 wasn't run but Cell 1 was
    if 'RAW_DATA_DIR' in locals() and 'SEQ_LEN' in locals():
         print("Recreating sequences list...")
         all_files = sorted(glob.glob(os.path.join(RAW_DATA_DIR, "*.h5")))
         if not all_files: raise FileNotFoundError(f"Error: No .h5 files found in {RAW_DATA_DIR}.")
         sequences = [all_files[i:i+SEQ_LEN+1] for i in range(len(all_files)-SEQ_LEN)]
         print(f"Generated {len(sequences)} sequences.")
    else:
         print("Error: 'sequences' list not found or is empty. Please run Cell 1 & 2 first.")
         raise NameError("Missing 'sequences' list.")
else:
    # --- Spark Session Initialization ---
    print("Initializing Spark Session...")
    spark = None # Initialize spark variable
    try:
        # Import Spark types needed for schema definition
        from pyspark.sql import SparkSession
        from pyspark.sql.types import (
             StructType, StructField, StringType, ArrayType, DoubleType
        )
        
        # Use reduced memory (e.g., 2g or 4g based on system)
        spark_memory = "4g" # Adjust this based on your system RAM
        print(f"Attempting to start Spark with driver memory: {spark_memory}")
        spark = SparkSession.builder \
            .appName("COSMOS_Chunk_Write_Notebook") \
            .config("spark.driver.memory", spark_memory) \
            .config("spark.sql.execution.arrow.pyspark.enabled", "true") \
            .config("spark.ui.showConsoleProgress", "false") \
            .config("spark.driver.host", "127.0.0.1") \
            .config("spark.driver.bindAddress", "127.0.0.1") \
            .getOrCreate()
        print(f"Spark Session Initialized (Version: {spark.version}).")
    except Exception as e:
        print(f"Error initializing Spark Session: {e}")
        print("Please ensure Java and PySpark are correctly installed and configured.")
        spark = None
        
    if spark: # Proceed only if Spark session was created successfully
        write_success = True # Assume success initially
        processed_count = 0
        
        # --- Define Spark Schema (same as before) ---
        print("Defining Spark DataFrame schema...")
        feature_array_type = ArrayType(ArrayType(ArrayType(ArrayType(DoubleType()))))
        target_image_array_type = ArrayType(ArrayType(ArrayType(DoubleType())))
        schema = StructType([
            StructField("sequence_id", StringType(), True),
            StructField("input_features", feature_array_type, True),
            StructField("target_cloud", target_image_array_type, True),
            StructField("target_convective", target_image_array_type, True),
            StructField("target_fog", target_image_array_type, True),
            StructField("target_moisture", target_image_array_type, True),
            StructField("target_thermo_contrast", target_image_array_type, True),
            StructField("target_temp_trend", DoubleType(), True)
        ])
        
        # --- Process and Write in Chunks ---
        chunk_size = 30 # Process 30 sequences at a time (adjust if needed based on RAM)
        num_sequences = len(sequences)
        num_chunks = math.ceil(num_sequences / chunk_size)
        
        # Ensure HADOOP_HOME check for Windows
        if os.name == 'nt' and not os.environ.get('HADOOP_HOME'):
             print("\n\033[91mError: HADOOP_HOME is not set. Cannot write Parquet on Windows.\033[0m")
             write_success = False # Prevent loop from running
        else:
             print(f"\nStarting chunk processing: {num_sequences} sequences in ~{num_chunks} chunks of size {chunk_size}.")
             # Delete existing Parquet directory if it exists before starting append loop
             if os.path.exists(PROCESSED_PARQUET_DIR):
                  print(f"Removing existing Parquet directory: {PROCESSED_PARQUET_DIR}")
                  try:
                       shutil.rmtree(PROCESSED_PARQUET_DIR)
                       print("Existing directory removed.")
                  except OSError as e:
                       print(f"\033[91mError removing directory {PROCESSED_PARQUET_DIR}: {e}\033[0m")
                       write_success = False # Abort if cleanup fails
                       
                       
        # Loop through sequences in chunks
        if write_success: # Proceed only if HADOOP_HOME is set and cleanup worked
            for i in tqdm(range(num_chunks), desc="Processing Chunks"):
                start_idx = i * chunk_size
                end_idx = min((i + 1) * chunk_size, num_sequences)
                current_chunk_sequences = sequences[start_idx:end_idx]
                
                # print(f"\nProcessing chunk {i+1}/{num_chunks} (sequences {start_idx} to {end_idx-1})...") # Less verbose
                
                # Process current chunk into a list
                chunk_data_list = []
                for seq in current_chunk_sequences: # Use simple loop for clarity
                    processed_data = load_and_crop_sequence(seq)
                    if processed_data:
                        chunk_data_list.append(processed_data)
                        
                if not chunk_data_list:
                    print(f"Warning: No data processed in chunk {i+1}. Skipping write.")
                    continue # Move to the next chunk
                    
                # Create Spark DataFrame for the chunk
                df_chunk = None # Initialize
                try:
                    # print(f"Creating Spark DataFrame for chunk {i+1} ({len(chunk_data_list)} rows)...") # Less verbose
                    df_chunk = spark.createDataFrame(chunk_data_list, schema=schema)
                    
                    # Write chunk DataFrame using "append" mode
                    # print(f"Appending chunk {i+1} to Parquet: {PROCESSED_PARQUET_DIR}") # Less verbose
                    df_chunk.write.mode("append").parquet(PROCESSED_PARQUET_DIR)
                    processed_count += len(chunk_data_list)
                    # print(f"Chunk {i+1} written successfully.") # Less verbose
                    
                except Exception as e:
                    print(f"\033[91mError processing or writing chunk {i+1}: {e}\033[0m")
                    write_success = False
                    break # Stop processing if a chunk fails
                finally:
                     # Clear memory for this chunk
                     del chunk_data_list
                     if df_chunk is not None: del df_chunk
                     gc.collect()
                     
        # --- Final Summary & Cleanup ---
        if write_success and processed_count > 0:
             print(f"\nSuccessfully processed and wrote {processed_count} sequences to {PROCESSED_PARQUET_DIR}.")
        elif not write_success:
             print("\n\033[91mError occurred during chunk processing/writing. Parquet data might be incomplete or corrupted.\033[0m")
        else:
             print("\nWarning: No data was successfully processed or written (check processing function and input files).")
             
        print("Stopping Spark Session...")
        if spark: spark.stop()
        print("Spark Session stopped.")
        
        # Raise error if writing failed overall
        if not write_success:
             raise RuntimeError("Failed to write Parquet data completely due to errors in chunk processing.")
             
    # End of 'if spark:' block
    else:
        print("Spark Session could not be initialized. Skipping DataFrame creation and Parquet writing.")

print("\n--- Cell 3: Spark Operations Complete (or Skipped) ---")


--- Cell 3: Initializing Spark & Processing/Writing Data in Chunks ---
Initializing Spark Session...
Attempting to start Spark with driver memory: 4g
Spark Session Initialized (Version: 3.0.1).
Defining Spark DataFrame schema...

Starting chunk processing: 896 sequences in ~30 chunks of size 30.


Processing Chunks:   0%|          | 0/30 [00:00<?, ?it/s]


Successfully processed and wrote 896 sequences to C:\college\CV\COSMOS\processed_cosmos_data_pyspark.parquet.
Stopping Spark Session...
Spark Session stopped.

--- Cell 3: Spark Operations Complete (or Skipped) ---


## Phase 2: Keras Model Training & Evaluation

This phase loads the processed data from the Parquet files saved in Phase 1, defines the Keras model, trains it, and evaluates the results.

### Cell 4: Load Processed Data & Prepare for Keras

In [25]:
# ==============================================================
# Cell 4: Load Processed Data & Prepare for Keras
# ==============================================================
print("\n--- Cell 4: Starting Keras Data Loading & Preparation ---")

# Ensure necessary imports are available
import os
import pandas as pd
import numpy as np
import gc
from IPython.display import display

# --- Load Processed Data from Parquet ---
# Uses PROCESSED_PARQUET_DIR defined in Cell 1
if 'PROCESSED_PARQUET_DIR' not in locals():
     PROCESSED_PARQUET_DIR = r"C:\college\CV\COSMOS\processed_cosmos_data_pyspark.parquet" # Fallback path
     print(f"Warning: PROCESSED_PARQUET_DIR not found, using default: {PROCESSED_PARQUET_DIR}")
# Ensure SEQ_LEN, PATCH_SIZE are defined
if 'SEQ_LEN' not in locals(): SEQ_LEN = 4
if 'PATCH_SIZE' not in locals(): PATCH_SIZE = 32


print(f"Loading pre-processed data from Parquet: {PROCESSED_PARQUET_DIR}")
if not os.path.exists(PROCESSED_PARQUET_DIR):
     raise FileNotFoundError(f"Parquet directory not found: {PROCESSED_PARQUET_DIR}. "
                           "Ensure Phase 1 (PySpark - Cells 2 & 3) completed successfully and check the path.")
try:
    # Use Pandas to read the Parquet directory (handles multiple files)
    processed_df = pd.read_parquet(PROCESSED_PARQUET_DIR)
    print(f"Loaded {len(processed_df)} sequences from Parquet into Pandas DataFrame.")
    if processed_df.empty:
         raise ValueError("Loaded DataFrame is empty! Check Parquet files generated in Cell 3.")
    print("Columns found:", processed_df.columns.tolist())
    print("DataFrame Info:")
    processed_df.info() 

    # --- Optional Debugging: Inspect Raw Loaded Data ---
    print("\n--- Debugging: Inspecting first 3 rows as loaded from Parquet ---")
    pd.set_option('display.max_colwidth', 150) 
    display(processed_df.head(3))
    print("--- End Debugging Inspection ---")

except Exception as e:
    print(f"Error reading Parquet data with Pandas: {e}")
    raise 


# --- Convert Pandas DataFrame Columns to NumPy Arrays ---
# Keras model.fit typically expects NumPy arrays
print("\nConverting loaded data to NumPy arrays...")
try:
    # --- Conversion Attempt ---
    print("Attempting conversion for 'input_features'...")
    # Convert each element (list) in the Series to a NumPy array, then stack them
    # Add fallback for safety in case some rows failed processing earlier
    X_data = np.stack(processed_df['input_features'].apply(lambda x: np.array(x, dtype=np.float32) if isinstance(x, list) and np.array(x).shape == (SEQ_LEN, PATCH_SIZE, PATCH_SIZE, 5) else np.zeros((SEQ_LEN, PATCH_SIZE, PATCH_SIZE, 5), dtype=np.float32)).values) 
    print(f"Successfully converted 'input_features'. Shape: {X_data.shape}") # Should be (N, 4, 32, 32, 5)

    # --- Corrected Key Mapping for y_data --- 
    # Convert target variables into a dictionary of NumPy arrays
    # Use keys that MATCH the model's output layer names
    y_data = {}
    target_mapping = {
        'cloud': 'target_cloud', 
        'convective': 'target_convective', 
        'fog': 'target_fog',
        'moisture': 'target_moisture', 
        'thermo_contrast': 'target_thermo_contrast',
        'temp_trend': 'target_temp_trend'
    }
    
    expected_target_shape = (PATCH_SIZE, PATCH_SIZE, 1)

    for model_key, df_col in target_mapping.items():
         if df_col in processed_df.columns:
              print(f"Attempting conversion for '{model_key}' (from column '{df_col}')...")
              if model_key == 'temp_trend': # Handle scalar target 
                   # Ensure it's float32 and correct shape
                   y_data[model_key] = np.array(processed_df[df_col].tolist(), dtype=np.float32).reshape(-1, 1)
              else: # Handle image-like targets
                   # Apply conversion and stacking, with shape check
                   y_data[model_key] = np.stack(processed_df[df_col].apply(lambda x: np.array(x, dtype=np.float32) if isinstance(x, list) and np.array(x).shape == expected_target_shape else np.zeros(expected_target_shape, dtype=np.float32)).values)
              print(f"Successfully converted '{model_key}'. Shape: {y_data[model_key].shape}") # Should be (N, 32, 32, 1) or (N, 1)
         else:
              print(f"Warning: Source column '{df_col}' for target '{model_key}' not found in DataFrame.")
              
    # Verify all expected keys are in y_data
    expected_keys = set(target_mapping.keys())
    actual_keys = set(y_data.keys())
    if expected_keys != actual_keys:
         print(f"Warning: Mismatch in created target keys. Expected {expected_keys}, got {actual_keys}")
    else:
         print("\nSuccessfully converted all required columns to NumPy arrays with correct keys.")


    # Optional: Free memory by deleting the Pandas DataFrame
    del processed_df
    gc.collect()
    print("Pandas DataFrame cleared from memory.")

except KeyError as e:
     print(f"Error: Column '{e}' not found in the loaded Parquet data.")
     raise
except ValueError as e:
     # This error might now occur during np.stack if shapes are inconsistent between rows
     print(f"ValueError during conversion (np.stack): {e}")
     print("This likely indicates inconsistent shapes in the nested lists between different rows.")
     raise
except Exception as e:
    print(f"An unexpected error occurred converting Parquet data to NumPy: {e}")
    import traceback; traceback.print_exc()
    raise

# --- Train/Validation Split ---
# Simulates the split logic from the original Cell 3
print("\nSplitting data into training and validation sets (90/10 split)...")
split_fraction = 0.9
# Ensure X_data was successfully created before splitting
if 'X_data' not in locals():
     raise NameError("X_data was not created successfully in the conversion step.")

split_index = int(split_fraction * len(X_data))

# Input features split
X_train, X_val = X_data[:split_index], X_data[split_index:]

# Target dictionary split
y_train, y_val = {}, {}
# Ensure y_data was successfully created
if 'y_data' not in locals():
    raise NameError("y_data dictionary was not created successfully in the conversion step.")

for key in y_data: # Iterate through keys in the created y_data dictionary
    y_train[key] = y_data[key][:split_index]
    y_val[key] = y_data[key][split_index:]

print(f"Training set size: {len(X_train)} samples")
print(f"Validation set size: {len(X_val)} samples")

# --- FINAL VERIFICATION PRINT ---
print("\n--- Final Shapes Before Exiting Cell ---")
print(f"X_train shape: {X_train.shape}, dtype: {X_train.dtype}")
print(f"y_train keys: {list(y_train.keys())}") # Keys should NOT have 'target_' prefix
if 'cloud' in y_train: print(f"y_train['cloud'] shape: {y_train['cloud'].shape}, dtype: {y_train['cloud'].dtype}")
print(f"X_val shape: {X_val.shape}, dtype: {X_val.dtype}")
print(f"y_val keys: {list(y_val.keys())}") # Keys should NOT have 'target_' prefix
if 'cloud' in y_val: print(f"y_val['cloud'] shape: {y_val['cloud'].shape}, dtype: {y_val['cloud'].dtype}")
print("-" * 30)


# Optional: Free memory from the original full arrays
# Check if variables exist before deleting
if 'X_data' in locals(): del X_data
if 'y_data' in locals(): del y_data
gc.collect()
print("Full dataset arrays cleared from memory.")

print("\n--- Cell 4: Data Loading and Preparation Complete ---")


--- Cell 4: Starting Keras Data Loading & Preparation ---
Loading pre-processed data from Parquet: C:\college\CV\COSMOS\processed_cosmos_data_pyspark.parquet
Loaded 896 sequences from Parquet into Pandas DataFrame.
Columns found: ['sequence_id', 'input_features', 'target_cloud', 'target_convective', 'target_fog', 'target_moisture', 'target_thermo_contrast', 'target_temp_trend']
DataFrame Info:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 896 entries, 0 to 895
Data columns (total 8 columns):
 #   Column                  Non-Null Count  Dtype  
---  ------                  --------------  -----  
 0   sequence_id             896 non-null    object 
 1   input_features          896 non-null    object 
 2   target_cloud            896 non-null    object 
 3   target_convective       896 non-null    object 
 4   target_fog              896 non-null    object 
 5   target_moisture         896 non-null    object 
 6   target_thermo_contrast  896 non-null    object 
 7   target_temp_tren

Unnamed: 0,sequence_id,input_features,target_cloud,target_convective,target_fog,target_moisture,target_thermo_contrast,target_temp_trend
0,3RIMG_17MAR2025_2145_L1C_SGP_V01R00.h5,"[[[[0.86269468 0.84661466 0.77116781 0.88693094 0.01808104], [0.86269468 0.85231578 0.76826608 0.89674836 0.01903268], [0.86914611 0.86101615 0.76...","[[[1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0...","[[[1.0], [1.0], [1.0], [1.0], [1.0], [0.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0...","[[[1.0], [1.0], [0.0], [0.0], [0.0], [0.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0...","[[[0.7734650373458862], [0.774837076663971], [0.774837076663971], [0.7720751762390137], [0.7677935361862183], [0.7618001699447632], [0.75866925716...","[[[0.0008056640508584678], [0.005075531080365181], [0.009221648797392845], [0.0016748047200962901], [-0.014902038499712944], [-0.040740966796875],...",0.000539
1,3RIMG_15MAR2025_0045_L1C_SGP_V01R00.h5,"[[[[0.98947299 0.98293984 0.84187162 0.99704498 0.02379085], [0.98947299 0.98293984 0.84187162 0.99752104 0.02616993], [0.98905891 0.98246956 0.84...","[[[0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0...","[[[0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0...","[[[0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0...","[[[0.8426675200462341], [0.8426675200462341], [0.8426675200462341], [0.8426675200462341], [0.8426675200462341], [0.8426675200462341], [0.843406856...","[[[-0.015034484677016735], [-0.019137268885970116], [-0.01655273512005806], [-0.01655273512005806], [-0.017922667786478996], [-0.01913726888597011...",0.005021
2,3RIMG_29MAR2025_1345_L1C_SGP_V01R00.h5,"[[[[0.89207685 0.88802153 0.79686397 0.91115689 0.01522614], [0.89207685 0.88668936 0.79686397 0.90988117 0.01712941], [0.89207685 0.88534993 0.79...","[[[1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0...","[[[1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0...","[[[1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [0.0], [0.0], [0.0], [0.0], [1.0], [1.0], [1.0], [0.0], [0.0], [0.0], [0.0], [1.0...","[[[0.7765102386474609], [0.7765102386474609], [0.7778137922286987], [0.7791017889976501], [0.7803747057914734], [0.7828768491744995], [0.782876849...","[[[-0.03438857942819595], [-0.038124848157167435], [-0.02991577237844467], [-0.022195739671587944], [-0.02911636419594288], [-0.0331297293305397],...",-0.002559


--- End Debugging Inspection ---

Converting loaded data to NumPy arrays...
Attempting conversion for 'input_features'...
Successfully converted 'input_features'. Shape: (896, 4, 32, 32, 5)
Attempting conversion for 'cloud' (from column 'target_cloud')...
Successfully converted 'cloud'. Shape: (896, 32, 32, 1)
Attempting conversion for 'convective' (from column 'target_convective')...
Successfully converted 'convective'. Shape: (896, 32, 32, 1)
Attempting conversion for 'fog' (from column 'target_fog')...
Successfully converted 'fog'. Shape: (896, 32, 32, 1)
Attempting conversion for 'moisture' (from column 'target_moisture')...
Successfully converted 'moisture'. Shape: (896, 32, 32, 1)
Attempting conversion for 'thermo_contrast' (from column 'target_thermo_contrast')...
Successfully converted 'thermo_contrast'. Shape: (896, 32, 32, 1)
Attempting conversion for 'temp_trend' (from column 'target_temp_trend')...
Successfully converted 'temp_trend'. Shape: (896, 1)

Successfully converted

### Cell 5: Keras Model Definition & Compilation

In [26]:
# ==============================================================
# Cell 5: Keras Model Definition & Compilation
# ==============================================================
print("\n--- Cell 5: Defining and Compiling Keras Model ---")

# Ensure necessary imports are available
import tensorflow as tf
from tensorflow.keras import layers, Model
import numpy as np # Needed if redefining model after checkpoint load failure

# Ensure config variables from Cell 1 are available
if 'SEQ_LEN' not in locals(): SEQ_LEN = 4
if 'PATCH_SIZE' not in locals(): PATCH_SIZE = 32
# Define losses
losses = {
    'cloud': 'binary_crossentropy', 'convective': 'binary_crossentropy', 'fog': 'binary_crossentropy',
    'moisture': 'mse', 'thermo_contrast': 'mse', 'temp_trend': 'mse'
}
# Define loss weights
loss_weights = {'cloud': 1.0, 'convective': 1.0, 'fog': 1.0, 'moisture': 0.5, 'thermo_contrast': 0.5, 'temp_trend': 0.1}
# Use Original Metrics dictionary (only 3 keys) - this matches your original notebook
metrics = {'cloud':'accuracy','convective':'accuracy','fog':'accuracy'}


# --- Model Definition (Matches original) ---
print("Defining Keras Multi-Output ConvLSTM model architecture...")
# Define input layer shape based on prepared data
# Shape is (TimeSteps, Height, Width, Channels)
inp = layers.Input(shape=(SEQ_LEN, PATCH_SIZE, PATCH_SIZE, 5), name="input_features") # Correct shape (4, 32, 32, 5)

# ConvLSTM layers
x = layers.ConvLSTM2D(32, (3, 3), padding='same', return_sequences=True, activation='relu', name='convlstm_1')(inp)
x = layers.BatchNormalization(name='batchnorm_1')(x)
x = layers.ConvLSTM2D(16, (3, 3), padding='same', return_sequences=False, activation='relu', name='convlstm_2')(x)
x = layers.BatchNormalization(name='batchnorm_2')(x) # Output shape: (batch, H, W, 16)

# Output heads for different prediction tasks
# Ensure keys here match the keys in losses, loss_weights, and y_data dictionary
heads = {
    'cloud': layers.Conv2D(1, (1, 1), activation='sigmoid', padding='same', name='cloud')(x),
    'convective': layers.Conv2D(1, (1, 1), activation='sigmoid', padding='same', name='convective')(x),
    'fog': layers.Conv2D(1, (1, 1), activation='sigmoid', padding='same', name='fog')(x),
    'moisture': layers.Conv2D(1, (1, 1), activation='linear', padding='same', name='moisture')(x),
    'thermo_contrast': layers.Conv2D(1, (1, 1), activation='linear', padding='same', name='thermo_contrast')(x),
}
# Regression task for a single scalar value (temperature trend)
temp_avg = layers.GlobalAveragePooling2D(name='global_avg_pool')(x) # Average spatial features -> (batch, 16)
heads['temp_trend'] = layers.Dense(1, activation='linear', name='temp_trend')(temp_avg) # Dense layer for scalar output

# Create the multi-output Keras Model
model = Model(inputs=inp, outputs=heads, name='multitask_nowcast_notebook') # Use consistent name
print("Model defined.")

# --- Model Compilation ---
print("Compiling Keras model (with original metrics)...")

# Compile the model using the original metrics definition
model.compile(
    optimizer='adam',
    loss=losses,
    loss_weights=loss_weights,
    metrics=metrics # Use the metrics dict defined above (only 3 keys)
)
print("Model compiled.")

# --- Display Model Summary ---
print("\nModel Summary:")
model.summary() # Print layer information and parameter counts

# --- Print TensorFlow Version ---
print(f"\nUsing TensorFlow version: {tf.__version__}")


print("--- Cell 5: Model Definition and Compilation Complete ---")


--- Cell 5: Defining and Compiling Keras Model ---
Defining Keras Multi-Output ConvLSTM model architecture...
Model defined.
Compiling Keras model (with original metrics)...
Model compiled.

Model Summary:
Model: "multitask_nowcast_notebook"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_features (InputLayer)    [(None, 4, 32, 32,   0           []                               
                                5)]                                                               
                                                                                                  
 convlstm_1 (ConvLSTM2D)        (None, 4, 32, 32, 3  42752       ['input_features[0][0]']         
                                2)                                                                
                                                                

### Cell 6: Keras Model Training

In [27]:
# ==============================================================
# Cell 6: Keras Model Training
# ==============================================================
print("\n--- Cell 6: Starting Keras Model Training ---")

# Ensure necessary imports are available if kernel restarted
import os
import glob
import re
import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint
import matplotlib.pyplot as plt
import gc # Ensure gc is imported if used

# Ensure config variables from Cell 1 are available
# Define defaults if running cell independently after kernel restart (adjust paths!)
if 'CHECKPOINT_DIR' not in locals(): CHECKPOINT_DIR = "checkpoints_pyspark_trained"
if 'EPOCHS' not in locals(): EPOCHS = 20
if 'BATCH_SIZE' not in locals(): BATCH_SIZE = 16
if 'MODEL_SAVE_PATH' not in locals(): MODEL_SAVE_PATH = r"C:\college\CV\COSMOS\multitask_nowcast_pyspark_trained.h5"
# Ensure model from Cell 5 is available
if 'model' not in locals(): raise NameError("Keras model not defined. Please run Cell 5 first.")
# Ensure data from Cell 4 is available
if 'X_train' not in locals() or 'y_train' not in locals() or 'X_val' not in locals() or 'y_val' not in locals():
     raise NameError("Training/Validation data not found. Please run the data preparation cell (Cell 4) first.")
# Get loss/metrics definitions from Cell 5 if available (needed for re-compile on load)
if 'losses' not in locals():
     losses = {
         'cloud': 'binary_crossentropy', 'convective': 'binary_crossentropy', 'fog': 'binary_crossentropy',
         'moisture': 'mse', 'thermo_contrast': 'mse', 'temp_trend': 'mse'
     }
if 'loss_weights' not in locals():
     loss_weights = {'cloud': 1.0, 'convective': 1.0, 'fog': 1.0, 'moisture': 0.5, 'thermo_contrast': 0.5, 'temp_trend': 0.1}
if 'metrics' not in locals(): # Use the original metrics definition
     metrics = {'cloud':'accuracy','convective':'accuracy','fog':'accuracy'}


# --- Setup Checkpointing ---
print(f"Setting up checkpoints in directory: {CHECKPOINT_DIR}")
os.makedirs(CHECKPOINT_DIR, exist_ok=True) # Ensure checkpoint directory exists

# Callback to save the model after each epoch
checkpoint_callback = ModelCheckpoint(
    filepath=os.path.join(CHECKPOINT_DIR, "model_epoch_{epoch:02d}.h5"),
    save_weights_only=False, # Save the full model
    save_freq='epoch',       # Save at the end of each epoch
    verbose=1                # Print message when saving
)

# --- Resume Logic ---
initial_epoch = 0 # Default starting epoch
ckpt_files = glob.glob(os.path.join(CHECKPOINT_DIR, "*.h5"))

if ckpt_files:
    # Sort checkpoints by epoch number
    def get_epoch_num(fpath):
        match = re.search(r"epoch_(\d+)", os.path.basename(fpath))
        return int(match.group(1)) if match else -1
        
    ckpt_files.sort(key=get_epoch_num)
    latest_ckpt = ckpt_files[-1]
    latest_epoch_num = get_epoch_num(latest_ckpt)
    
    if latest_epoch_num != -1:
        print(f"\nAttempting to load model from latest checkpoint: {latest_ckpt}")
        try:
            # Load weights into the existing 'model' structure defined in Cell 5
            print("Loading weights into existing model structure...")
            model.load_weights(latest_ckpt) # Load only weights
            initial_epoch = latest_epoch_num # Keras initial_epoch starts AT this epoch number
            print(f"Successfully loaded weights. Resuming training from epoch {initial_epoch}")
        except Exception as e:
            print(f"Warning: Error loading checkpoint weights {latest_ckpt}: {e}.")
            print("Training from scratch.")
            initial_epoch = 0
    else:
        print(f"Warning: Could not determine epoch number from latest checkpoint file '{latest_ckpt}'. Training from scratch.")
else:
    print("No checkpoints found. Training model from scratch.")


# --- Start or Resume Training ---
if initial_epoch < EPOCHS:
    print(f"\n--- Verifying Data Before Fit (Epoch {initial_epoch} to {EPOCHS}) ---")
    # Basic checks on data passed to fit
    print(f"X_train shape: {X_train.shape}, dtype: {X_train.dtype}") 
    print(f"y_train keys: {list(y_train.keys())}") # Should NOT have 'target_' prefix
    if 'cloud' in y_train:
        print(f"y_train['cloud'] shape: {y_train['cloud'].shape}, dtype: {y_train['cloud'].dtype}")
    print(f"X_val shape: {X_val.shape}, dtype: {X_val.dtype}") 
    print(f"y_val keys: {list(y_val.keys())}") # Should NOT have 'target_' prefix
    if 'cloud' in y_val:
        print(f"y_val['cloud'] shape: {y_val['cloud'].shape}, dtype: {y_val['cloud'].dtype}")
    print("-" * 30)
    
    # --- Check for NaN/Inf values ---
    print("Checking for NaN/Inf values in target data...")
    abort_training = False
    for dataset_name, dataset_dict in [("y_train", y_train), ("y_val", y_val)]:
        for key, arr in dataset_dict.items():
            if arr is not None:
                if np.isnan(arr).any():
                    print(f"  Error: NaN values found in {dataset_name}['{key}']")
                    abort_training = True
                if np.isinf(arr).any():
                    print(f"  Error: Infinite values found in {dataset_name}['{key}']")
                    abort_training = True
            else:
                print(f"  Error: {dataset_name}['{key}'] is None.")
                abort_training = True
    if abort_training:
        raise ValueError("NaN or Infinite values found in target data. Cannot proceed.")
    else:
        print("No NaN or Infinite values found in target data.")
    print("-" * 30)
    
    # --- Check model output names match y_train keys ---
    print("Verifying model output names match target keys...")
    model_output_names = model.output_names
    y_train_keys = list(y_train.keys())
    print(f"  Model output names: {sorted(model_output_names)}")
    print(f"  y_train keys:       {sorted(y_train_keys)}")
    if sorted(model_output_names) != sorted(y_train_keys):
        # This check should now PASS
        raise ValueError("Mismatch between model output layer names and keys in the y_train dictionary!")
    else:
        print("Model output names match target keys.")
    print("-" * 30)
    
    
    print(f"\nStarting model.fit from epoch {initial_epoch}...")
    history = None # Initialize history
    try:
        # *** Pass NumPy arrays directly to model.fit ***
        history = model.fit(
            X_train, y_train,                 # Training data and labels (y_train is a dict)
            validation_data=(X_val, y_val),   # Validation data and labels (y_val is a dict)
            batch_size=BATCH_SIZE,            # Number of samples per gradient update
            epochs=EPOCHS,                    # Total number of epochs to train for
            callbacks=[checkpoint_callback],  # List of callbacks (e.g., for saving)
            initial_epoch=initial_epoch,      # Starting epoch (for resuming)
            verbose=1                         # Show progress bar (1) or epoch summary (2)
        )
        print("Model training finished.")
    except Exception as e:
        print(f"\nError occurred during model.fit: {e}")
        import traceback
        traceback.print_exc() # Print full traceback if fit fails
        raise # Re-raise the error to stop execution


    # --- Save the Final Model (only if training completed) ---
    if history is not None: # Check if training actually ran
        print(f"\nSaving final trained model to: {MODEL_SAVE_PATH}")
        try:
            model.save(MODEL_SAVE_PATH)
            print("Final model saved successfully.")
        except Exception as e:
            print(f"Error saving final model: {e}")
            
        # --- Plot Training History (Loss and Accuracy) ---
        print("\nPlotting training history...")
        history_dict = history.history
        
        # Determine available metrics from history
        loss_keys = [k for k in history_dict if k.startswith('loss') or k.startswith('val_loss')]
        # Accuracy keys should exist because we compiled with them in Cell 5
        acc_keys = [k for k in history_dict if k.startswith('accuracy') or k.startswith('val_') and 'accuracy' in k]
        
        # Calculate the actual range of epochs trained in this run
        num_epochs_trained = len(history_dict.get('loss', [])) # Use .get for safety
        if num_epochs_trained > 0:
             actual_epochs_range = range(initial_epoch, initial_epoch + num_epochs_trained)
             
             plt.style.use('seaborn-v0_8-darkgrid') # Use a visually appealing style
             
             # Plot Total Loss
             if 'loss' in history_dict and 'val_loss' in history_dict:
                  # Check if accuracy keys exist to decide subplot layout
                  fig_cols = 2 if acc_keys else 1
                  plt.figure(figsize=(7 * fig_cols, 6)) # Adjust figure size
                  plt.subplot(1, fig_cols, 1)
                  plt.plot(actual_epochs_range, history_dict['loss'], 'o-', label='Training Loss', linewidth=2)
                  plt.plot(actual_epochs_range, history_dict['val_loss'], 's--', label='Validation Loss', linewidth=2)
                  plt.title('Total Training and Validation Loss', fontsize=14)
                  plt.xlabel('Epoch', fontsize=12)
                  plt.ylabel('Loss', fontsize=12)
                  plt.legend(fontsize=10)
                  plt.grid(True, linestyle=':')
             
             # Plot Accuracy (only if keys exist)
             train_acc_key = next((k for k in history_dict if k.endswith('accuracy') and not k.startswith('val_')), None)
             val_acc_key = next((k for k in history_dict if k.startswith('val_') and k.endswith('accuracy')), None)
             
             if train_acc_key and val_acc_key:
                  # Add subplot for accuracy if loss was plotted
                  if fig_cols == 2:
                       plt.subplot(1, 2, 2)
                  else: # Otherwise create a new figure
                       plt.figure(figsize=(7, 6))
                       
                  plt.plot(actual_epochs_range, history_dict[train_acc_key], 'o-', label=f'Training Accuracy ({train_acc_key})', linewidth=2)
                  plt.plot(actual_epochs_range, history_dict[val_acc_key], 's--', label=f'Validation Accuracy ({val_acc_key})', linewidth=2)
                  plt.title('Training and Validation Accuracy', fontsize=14)
                  plt.xlabel('Epoch', fontsize=12)
                  plt.ylabel('Accuracy', fontsize=12)
                  plt.legend(fontsize=10)
                  plt.grid(True, linestyle=':')
             
             plt.tight_layout() # Adjust layout
             plt.show() # Display the plot(s)
        else:
             print("No training history found to plot (training might not have run).")

else:
    print(f"\nTraining already completed up to epoch {initial_epoch}. Skipping training.")
    # If training was skipped, you might want to load the history object if saved previously
    # history = ... # load history if needed for plotting

print("\n--- Cell 6: Model Training Complete ---")


--- Cell 6: Starting Keras Model Training ---
Setting up checkpoints in directory: checkpoints_pyspark_trained

Attempting to load model from latest checkpoint: checkpoints_pyspark_trained\model_epoch_20.h5
Loading weights into existing model structure...
Successfully loaded weights. Resuming training from epoch 20

Training already completed up to epoch 20. Skipping training.

--- Cell 6: Model Training Complete ---


### Cell 7: Keras Model Evaluation

In [28]:
# ==============================================================
# Cell 7: Keras Model Evaluation
# ==============================================================
print("\n--- Cell 7: Evaluating Final Model ---")

# Ensure necessary imports are available
import os
import re # Needed for checkpoint loading fallback
import glob # Needed for checkpoint loading fallback
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.metrics import (
    confusion_matrix, precision_score, recall_score,
    f1_score, accuracy_score, mean_squared_error, mean_absolute_error
)
from IPython.display import display
from tqdm.notebook import tqdm # Use notebook version
import gc

# Ensure config variables from Cell 1 are available
if 'MODEL_SAVE_PATH' not in locals(): MODEL_SAVE_PATH = r"C:\college\CV\COSMOS\multitask_nowcast_pyspark_trained.h5" # Match save path
if 'BATCH_SIZE' not in locals(): BATCH_SIZE = 16
# Ensure validation data from Cell 4 is available
if 'X_val' not in locals() or 'y_val' not in locals():
     raise NameError("Validation data (X_val, y_val) not found. Please run the data preparation cell (Cell 4) first.")
if 'CHECKPOINT_DIR' not in locals(): CHECKPOINT_DIR = "checkpoints_pyspark_trained" # Match checkpoint dir


# --- Load the Final Saved Model --- 
# Ensure evaluation is done on the definitive final model saved after training
print(f"Loading final model from: {MODEL_SAVE_PATH}")
if not os.path.exists(MODEL_SAVE_PATH):
    # Try loading from the last checkpoint if final model doesn't exist
    print(f"Warning: Final model not found at {MODEL_SAVE_PATH}. Trying latest checkpoint...")
    ckpt_files = sorted(glob.glob(os.path.join(CHECKPOINT_DIR, "*.h5")),
                        key=lambda f: int(re.search(r"epoch_(\d+)", os.path.basename(f)).group(1) if re.search(r"epoch_(\d+)", os.path.basename(f)) else -1))
    if not ckpt_files:
        raise FileNotFoundError(f"No final model or checkpoints found. Please ensure training completed.")
    latest_ckpt = ckpt_files[-1]
    print(f"Loading model from latest checkpoint: {latest_ckpt}")
    MODEL_TO_LOAD = latest_ckpt
else:
    MODEL_TO_LOAD = MODEL_SAVE_PATH

try:
    # Load the model. compile=True is needed for model.evaluate with compiled metrics.
    eval_model = tf.keras.models.load_model(MODEL_TO_LOAD, compile=True) 
    print("Model loaded successfully for evaluation.")
except Exception as e:
    print(f"Error loading model: {e}")
    raise

# --- Option 1: Use model.evaluate() --- 
# Provides a quick summary based on compiled metrics and losses
print("\nRunning model.evaluate on validation data...")
try:
    # Pass validation data directly as NumPy arrays
    results = eval_model.evaluate(X_val, y_val, batch_size=BATCH_SIZE, verbose=1, return_dict=True)
    print("\nModel Evaluation Results (from model.evaluate):")
    # Format the results nicely
    for name, value in results.items():
        print(f"  {name}: {value:.4f}")
except Exception as e:
    print(f"Error during model.evaluate: {e}")

# --- Option 2: Manual Detailed Evaluation --- 
# Calculates metrics individually, allowing more control and custom metrics
print("\nCalculating detailed metrics manually...")
try:
    # Get predictions on the validation set
    print(f"Generating predictions for {len(X_val)} validation samples...")
    # Use predict on NumPy array
    y_pred_val_dict = eval_model.predict(X_val, batch_size=BATCH_SIZE, verbose=1)
    # Ensure predictions are in a dictionary if model has multiple named outputs
    if not isinstance(y_pred_val_dict, dict):
         # If predict returns a list, try to map it to output names
         if isinstance(y_pred_val_dict, list) and len(y_pred_val_dict) == len(eval_model.output_names):
              y_pred_val_dict = dict(zip(eval_model.output_names, y_pred_val_dict))
         else:
              # If single output model, wrap it in a dict
              if len(eval_model.output_names) == 1:
                   y_pred_val_dict = {eval_model.output_names[0]: y_pred_val_dict}
              else:
                   raise TypeError(f"Model prediction output type is {type(y_pred_val_dict)}, expected dict or list matching output names.")
                   
    print("Predictions generated.")
    
    # --- Calculate Segmentation Metrics ---
    seg_keys = ["cloud", "convective", "fog"]
    conf_matrices = {}
    seg_metrics_list = []
    print("\n--- Segmentation Metrics ---")
    plt.style.use('seaborn-v0_8-whitegrid') # Style for plots
    
    for k in seg_keys:
        if k not in y_val or k not in y_pred_val_dict:
             print(f"  Skipping metric calculation for '{k}': Key not found in y_val or predictions.")
             continue
        print(f"  Calculating for: {k}")
        # Flatten true labels and predictions for comparison
        y_true_flat = y_val[k].flatten().astype(np.uint8)
        # Apply threshold (0.5) to sigmoid output for binary prediction
        y_pred_flat = (y_pred_val_dict[k].flatten() > 0.5).astype(np.uint8) 
        
        # Confusion Matrix
        cm = confusion_matrix(y_true_flat, y_pred_flat, labels=[0, 1]) # Explicitly use labels 0 and 1
        conf_matrices[k] = cm
        
        # Handle potential division by zero if a class is missing or never predicted
        try:
            TN, FP, FN, TP = cm.ravel()
        except ValueError: # Handle case where cm might not be 2x2
             print(f"    Warning: Confusion matrix for '{k}' is not 2x2 ({cm.shape}). Metrics might be affected.")
             TN, FP, FN, TP = 0, 0, 0, 0
             if np.sum(y_true_flat == 0) == len(y_true_flat): TN = np.sum((y_true_flat==0)&(y_pred_flat==0)); FP = np.sum((y_true_flat==0)&(y_pred_flat==1))
             elif np.sum(y_true_flat == 1) == len(y_true_flat): TP = np.sum((y_true_flat==1)&(y_pred_flat==1)); FN = np.sum((y_true_flat==1)&(y_pred_flat==0))
                 
        # Calculate standard metrics
        acc = accuracy_score(y_true_flat, y_pred_flat)
        # Use zero_division=0 to return 0 instead of nan/error if denominator is zero
        prec = precision_score(y_true_flat, y_pred_flat, zero_division=0)
        rec = recall_score(y_true_flat, y_pred_flat, zero_division=0)
        f1 = f1_score(y_true_flat, y_pred_flat, zero_division=0)
        
        seg_metrics_list.append(dict(Task=k, Acc=acc, Prec=prec, Rec=rec, F1=f1, TN=TN, FP=FP, FN=FN, TP=TP))
    
    # Display segmentation metrics in a formatted table
    if seg_metrics_list:
        df_seg_metrics = pd.DataFrame(seg_metrics_list).set_index("Task")
        # Format floats to 4 decimal places for readability
        pd.options.display.float_format = '{:.4f}'.format 
        print("\nSegmentation Metrics Summary:")
        display(df_seg_metrics[['Acc', 'Prec', 'Rec', 'F1']]) # Use display for better notebook formatting
        # print(df_seg_metrics) # Uncomment to see TN, FP, FN, TP counts
    else:
        print("No segmentation metrics calculated.")
        
    # --- Calculate Regression Metrics ---
    reg_keys = ["moisture", "thermo_contrast", "temp_trend"]
    reg_metrics_list = []
    print("\n--- Regression Metrics ---")
    for k in reg_keys:
        if k not in y_val or k not in y_pred_val_dict:
             print(f"  Skipping metric calculation for '{k}': Key not found in y_val or predictions.")
             continue
        print(f"  Calculating for: {k}")
        y_true_flat = y_val[k].flatten()
        y_pred_flat = y_pred_val_dict[k].flatten()
        
        mse = mean_squared_error(y_true_flat, y_pred_flat)
        mae = mean_absolute_error(y_true_flat, y_pred_flat)
        reg_metrics_list.append(dict(Task=k, MSE=mse, MAE=mae))
    
    # Display regression metrics in a formatted table
    if reg_metrics_list:
        df_reg_metrics = pd.DataFrame(reg_metrics_list).set_index("Task")
        print("\nRegression Metrics Summary:")
        display(df_reg_metrics)
    else:
         print("No regression metrics calculated.")
         
    # --- Plot Confusion Matrices (Visuals) ---
    print("\n--- Confusion Matrices (Visualizations) ---")
    # Set a consistent color map and style
    cmap = plt.cm.Blues 
    plt.style.use('seaborn-v0_8-darkgrid') # Use darkgrid for better contrast
    
    for k, cm in conf_matrices.items():
        fig, ax = plt.subplots(figsize=(6, 5)) # Slightly larger figure for clarity
        im = ax.imshow(cm, interpolation='nearest', cmap=cmap) 
        ax.figure.colorbar(im, ax=ax, shrink=0.8) # Add colorbar
        
        # Configure axes, labels, title
        ax.set(xticks=np.arange(cm.shape[1]),
               yticks=np.arange(cm.shape[0]),
               xticklabels=["Predicted 0", "Predicted 1"], 
               yticklabels=["True 0", "True 1"],
               title=f'{k.capitalize()} Confusion Matrix',
               ylabel='True label',
               xlabel='Predicted label')
        ax.title.set_fontsize(14)
        ax.xaxis.label.set_fontsize(12)
        ax.yaxis.label.set_fontsize(12)

        # Add text annotations within each cell
        fmt = 'd' # Format as integer
        thresh = cm.max() / 2. # Threshold for text color (white/black)
        # Add text annotations for each cell value
        for i in range(cm.shape[0]):
            for j in range(cm.shape[1]):
                ax.text(j, i, format(cm[i, j], fmt),
                        fontsize=11,
                        ha="center", va="center",
                        # Set text color based on background
                        color="white" if cm[i, j] > thresh else "black")
                        
        fig.tight_layout() # Adjust layout
        plt.show() # Display the plot for each matrix

except NameError as e:
    print(f"Error during manual evaluation: {e}")
    print("Please ensure Cell 4 (Data Prep) and Cell 5/6 (Training/Prediction) executed correctly.")
except Exception as e:
    print(f"An unexpected error occurred during manual evaluation: {e}")
    import traceback
    traceback.print_exc()

print("\n--- Cell 7: Model Evaluation Complete ---")
print("\n--- Full Notebook Workflow Finished ---")


--- Cell 7: Evaluating Final Model ---
Loading final model from: C:\college\CV\COSMOS\multitask_nowcast_pyspark_trained.h5
Model loaded successfully for evaluation.

Running model.evaluate on validation data...

Model Evaluation Results (from model.evaluate):
  loss: 0.0149
  cloud_loss: 0.0069
  convective_loss: 0.0032
  fog_loss: 0.0048
  moisture_loss: 0.0000
  temp_trend_loss: 0.0000
  thermo_contrast_loss: 0.0000
  cloud_accuracy: 1.0000
  convective_accuracy: 1.0000
  fog_accuracy: 1.0000

Calculating detailed metrics manually...
Generating predictions for 90 validation samples...
Predictions generated.

--- Segmentation Metrics ---
An unexpected error occurred during manual evaluation: 'seaborn-v0_8-whitegrid' not found in the style library and input is not a valid URL or path; see `style.available` for list of available styles

--- Cell 7: Model Evaluation Complete ---

--- Full Notebook Workflow Finished ---


Traceback (most recent call last):
  File "c:\Users\dhanu\.conda\envs\w\lib\site-packages\matplotlib\style\core.py", line 127, in use
    rc = rc_params_from_file(style, use_default_template=False)
  File "c:\Users\dhanu\.conda\envs\w\lib\site-packages\matplotlib\__init__.py", line 854, in rc_params_from_file
    config_from_file = _rc_params_in_file(fname, fail_on_error=fail_on_error)
  File "c:\Users\dhanu\.conda\envs\w\lib\site-packages\matplotlib\__init__.py", line 780, in _rc_params_in_file
    with _open_file_or_url(fname) as fd:
  File "c:\Users\dhanu\.conda\envs\w\lib\contextlib.py", line 119, in __enter__
    return next(self.gen)
  File "c:\Users\dhanu\.conda\envs\w\lib\site-packages\matplotlib\__init__.py", line 757, in _open_file_or_url
    with open(fname, encoding=encoding) as f:
FileNotFoundError: [Errno 2] No such file or directory: 'seaborn-v0_8-whitegrid'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "