In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# JAX and related libraries
import jax
import jax.numpy as jnp

# Sklearn libraries
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, brier_score_loss, mean_squared_error, mean_absolute_error, max_error

# --- 1. QDA Core Functions ---

def fit_qda(X_train, y_train):
    """
    Estimate the per-class QDA parameters from labelled training data.

    Args:
        X_train: Array-like of shape (n_samples, n_features); rows are samples.
        y_train: Array-like of shape (n_samples,) holding each row's class label.

    Returns:
        Tuple ``(priors, means, covariances)`` of dicts keyed by class label:
        the class frequency, the per-feature mean vector and the biased
        (divide-by-N) covariance matrix of the samples in that class.
    """
    print("Fitting QDA parameters from training data...")
    features = jnp.array(X_train)
    labels = jnp.array(y_train)
    n_total = features.shape[0]

    priors = {}
    means = {}
    covariances = {}
    for label in np.unique(np.array(labels)):
        # Rows belonging to the current class only
        members = features[labels == label]
        priors[label] = members.shape[0] / n_total
        means[label] = jnp.mean(members, axis=0)
        covariances[label] = jnp.cov(members, rowvar=False, bias=True)

    print("Fitting complete.")
    return priors, means, covariances

@jax.jit
def predict_qda(params, X):
    """
    Predicts class labels and probabilities for new data using QDA parameters.

    Args:
        params: Tuple ``(priors, means, covariances)`` of dicts keyed by class
            label, in the layout produced by ``fit_qda``.
        X: Array of shape (n_samples, n_features).

    Returns:
        Tuple ``(predictions, prob_second_class)``. ``predictions`` holds, per
        sample, the position of the best-scoring class in sorted-label order
        (identical to the label itself when the labels are 0 and 1).
        ``prob_second_class`` is the softmax-normalised posterior of the class
        at position 1, i.e. label 1 in the binary 0/1 setting.
    """
    priors, means, covariances = params
    scores = []
    # Ensure class order is 0, 1
    for k in sorted(priors.keys()):
        cov_k = covariances[k]
        inv_cov_k = jnp.linalg.inv(cov_k)
        # slogdet returns log|det| directly; log(det(...)) underflows to -inf
        # (or overflows) for small/large covariances, especially in float32.
        _, log_det_cov_k = jnp.linalg.slogdet(cov_k)

        # Broadcasting subtracts the class mean from every row of X.
        diff = X - means[k]
        # Row-wise Mahalanobis term: diff_i^T  Sigma^-1  diff_i
        quadratic_term = jnp.sum((diff @ inv_cov_k) * diff, axis=1)

        # Log of (Gaussian density * prior), up to a constant shared by all classes
        score_k = -0.5 * quadratic_term - 0.5 * log_det_cov_k + jnp.log(priors[k])
        scores.append(score_k)

    # Shape (n_classes, n_samples)
    scores_stacked = jnp.stack(scores, axis=0)
    predictions = jnp.argmax(scores_stacked, axis=0)
    probabilities = jax.nn.softmax(scores_stacked, axis=0)

    return predictions, probabilities[1, :]

# --- 2. Evaluation and Plotting Function ---
def evaluate_and_plot_classification(params, y_true, y_pred, y_prob, X_coords, dataset_name, save_dir):
    """
    Evaluates the QDA model and saves three plots, including all required metrics.

    Args:
        params: QDA parameter tuple, forwarded to ``predict_qda`` to draw the
            decision boundary.
        y_true: Array of ground-truth labels.
        y_pred: Array of predicted labels, aligned with ``y_true``.
        y_prob: Array of predicted probabilities for label 1.
        X_coords: Table with ``'longitude'`` and ``'latitude'`` columns that
            supports boolean-mask row selection (a pandas DataFrame in ``main``).
        dataset_name: Split name used in log lines, plot titles and file names.
        save_dir: Output directory; combined with file names via ``/``.

    Side effects:
        Prints the metrics and writes three PNG files into ``save_dir``.
    """
    print(f"--- Evaluating on {dataset_name} Set ---")

    # Classification metrics
    cm = confusion_matrix(y_true, y_pred)
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred)
    recall = recall_score(y_true, y_pred)
    brier = brier_score_loss(y_true, y_prob)

    # Regression-style metrics computed on the hard 0/1 predictions
    mse = mean_squared_error(y_true, y_pred)
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(y_true, y_pred)
    max_err = max_error(y_true, y_pred)

    print("Confusion Matrix:")
    print(cm)
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"Confidence MSE (Brier Score): {brier:.4f}")
    print(f"MSE: {mse:.4f}")
    print(f"RMSE: {rmse:.4f}")
    print(f"MAE: {mae:.4f}")
    # Previously computed but never reported.
    print(f"Max Error: {max_err:.4f}")

    # --- Plotting and Saving ---
    # Plot 1: Prediction Correctness Distribution
    plt.figure(figsize=(10, 8))
    title1 = f'QDA_Prediction_Correctness_{dataset_name}_Set'
    correct_predictions = (y_true == y_pred)
    hits = X_coords[correct_predictions]
    misses = X_coords[~correct_predictions]
    plt.scatter(hits['longitude'], hits['latitude'],
                c='green', label='Correct', alpha=0.6, s=10)
    plt.scatter(misses['longitude'], misses['latitude'],
                c='red', label='Incorrect', alpha=0.6, s=10)
    plt.title(f'QDA Prediction Correctness ({dataset_name} Set)')
    plt.xlabel('Longitude')
    plt.ylabel('Latitude')
    plt.legend()
    plt.grid(True)
    save_path1 = save_dir / f"{title1}.png"
    plt.savefig(save_path1, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Plot saved to: {save_path1}")

    # Plot 2: Prediction Confidence Map
    plt.figure(figsize=(10, 8))
    title2 = f'QDA_Prediction_Confidence_{dataset_name}_Set'
    scatter = plt.scatter(X_coords['longitude'], X_coords['latitude'], c=y_prob,
                          cmap='coolwarm', vmin=0, vmax=1, s=10)
    plt.colorbar(scatter, label='Probability of being Valid (Label=1)')
    plt.title(f'QDA Prediction Confidence ({dataset_name} Set)')
    plt.xlabel('Longitude')
    plt.ylabel('Latitude')
    plt.grid(True)
    save_path2 = save_dir / f"{title2}.png"
    plt.savefig(save_path2, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Plot saved to: {save_path2}")

    # Plot 3: Decision Boundary
    plt.figure(figsize=(10, 8))
    title3 = f'QDA_Decision_Boundary_{dataset_name}_Set'
    h = .02  # Mesh step size
    x_min, x_max = X_coords['longitude'].min() - 0.1, X_coords['longitude'].max() + 0.1
    y_min, y_max = X_coords['latitude'].min() - 0.1, X_coords['latitude'].max() + 0.1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))

    # Classify every mesh node; columns are (longitude, latitude), the same
    # feature order the model was fitted with in main().
    grid_points = jnp.c_[xx.ravel(), yy.ravel()]
    Z, _ = predict_qda(params, grid_points)
    Z = np.array(Z).reshape(xx.shape)

    plt.contourf(xx, yy, Z, cmap=plt.cm.coolwarm, alpha=0.8)
    plt.scatter(X_coords['longitude'], X_coords['latitude'], c=y_true,
                cmap=plt.cm.coolwarm, s=10, edgecolors='k', alpha=0.7)

    plt.title(f'QDA Decision Boundary ({dataset_name} Set)')
    plt.xlabel('Longitude')
    plt.ylabel('Latitude')
    plt.grid(True)
    save_path3 = save_dir / f"{title3}.png"
    plt.savefig(save_path3, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Decision Boundary plot saved to: {save_path3}\n")


# --- 3. Main Execution Function ---
def main():
    """
    Fit QDA on ``classification_data.csv`` and evaluate it on every split.

    The CSV is read from (and the plots are written to) the directory of this
    script. The data is split 70% / 20% / 10% into training, validation and
    test sets, stratified by label.
    """
    try:
        script_dir = Path(__file__).parent.resolve()
    except NameError:
        # __file__ is not defined when this runs as a notebook cell or in an
        # interactive session; use the current working directory instead.
        script_dir = Path.cwd()
    data_file_path = script_dir / 'classification_data.csv'
    data = pd.read_csv(data_file_path)

    X = data[['longitude', 'latitude']]
    y = data['label']

    # 30% held out, then one third of that (10% overall) becomes the test set.
    X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y)
    X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=(1/3), random_state=42, stratify=y_temp)

    # 1. Fit the QDA model
    params = fit_qda(X_train.values, y_train.values)

    # 2. Make predictions for every split
    splits = [
        ('Training', X_train, y_train),
        ('Validation', X_val, y_val),
        ('Test', X_test, y_test),
    ]
    predictions = [predict_qda(params, jnp.array(X_split.values)) for _, X_split, _ in splits]

    # 3. Evaluate and plot (params is needed to draw the decision boundary)
    for (split_name, X_split, y_split), (y_pred, y_prob) in zip(splits, predictions):
        evaluate_and_plot_classification(
            params, np.array(y_split), np.array(y_pred), np.array(y_prob),
            X_split, split_name, script_dir,
        )

# Run the QDA pipeline only when this code is executed directly, not when imported.
if __name__ == '__main__':
    main()

# Question 1
## Comparison of 4 Methods on test set:
| Method / Loss Function | Accuracy | Precision | Recall | Brier Score | Epochs to Stop |
| :--- | :--- | :--- | :--- | :--- | :--- |
| QDA | 0.8271 | 0.8207 | 0.7714 | 0.1301 | - |
| Euclidean Distance | 0.9764 | 0.9559 | 0.9914 | 0.0162 | 192 |
| Cosine Similarity | **0.9776** | **0.9611** | 0.9886 | 0.0163 | 205 |
| Cross-Entropy | 0.9764 | 0.9534 | **0.9943** | **0.0149** | **169** |
## Figures of QDA

<table>
  <tr>
    <td><img src="Figures/QDA_Prediction_Correctness_Training_Set.png" width="400"></td>
    <td><img src="Figures/QDA_Prediction_Correctness_Validation_Set.png" width="400"></td>
    <td><img src="Figures/QDA_Prediction_Correctness_Test_Set.png" width="400"></td>
  </tr>
  <tr>
    <td><img src="Figures/QDA_Prediction_Confidence_Training_Set.png" width="400"></td>
    <td><img src="Figures/QDA_Prediction_Confidence_Validation_Set.png" width="400"></td>
    <td><img src="Figures/QDA_Prediction_Confidence_Test_Set.png" width="400"></td>
  </tr>
  <tr>
    <td><img src="Figures/QDA_Decision_Boundary_Training_Set.png" width="400"></td>
    <td><img src="Figures/QDA_Decision_Boundary_Validation_Set.png" width="400"></td>
    <td><img src="Figures/QDA_Decision_Boundary_Test_Set.png" width="400"></td>
  </tr>
</table>


# Question 2

In [None]:
# ==============================================================================
#
#                       Combined Machine Learning Script
#
# This script combines three separate processes into a single workflow:
# 1. Data Labeling:      Processes the raw XML data into CSV files.
# 2. Classification:     Trains and evaluates a Neural Network for classification.
# 3. Regression:         Trains and evaluates a CNN for regression (inpainting).
#
# ==============================================================================

# --- 0. Combined Imports ---
import csv
import time
import xml.etree.ElementTree as ET
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.metrics import (
    accuracy_score,
    brier_score_loss,
    confusion_matrix,
    max_error,
    mean_absolute_error,
    mean_squared_error,
    precision_score,
    recall_score,
)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import (
    Activation,
    BatchNormalization,
    Conv2D,
    Dense,
    Dropout,
    Input,
)
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.utils import to_categorical

# ==============================================================================
#
#                       PART 1: DATA LABELING
#                       (from Data labeling.py)
#
# ==============================================================================

def _write_rows_csv(csv_path, fieldnames, rows):
    """Write ``rows`` (a list of dicts) to ``csv_path`` as UTF-8 CSV with a header line."""
    with open(csv_path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def run_data_labeling():
    """
    Reads the raw weather XML data, converts it into classification and
    regression datasets, and exports them into two separate CSV files.

    The XML file ``O-A0038-003.xml`` is looked up next to this script (or in
    the current working directory when ``__file__`` is unavailable, e.g. in a
    notebook). ``classification_data.csv`` gets one row per grid cell with
    ``label`` 1 for a real reading and 0 for the -999.0 sentinel;
    ``regression_data.csv`` gets only the cells with a real reading.

    Returns:
        bool: True when both CSV files were written, False when reading,
        parsing or writing failed (the reason is printed).
    """
    print(f"\n{'='*80}\nPART 1: Running Data Labeling...\n{'='*80}")

    # Grid geometry, defined once and shared by parsing and coordinate generation.
    start_lon, start_lat = 120.00, 21.88
    lon_res, lat_res = 0.03, 0.03
    lon_points, lat_points = 67, 120
    invalid_value = -999.0  # Sentinel marking a cell without a reading

    # Resolved outside the parsing try-block so that an undefined __file__ is
    # not misreported as an XML parsing error.
    try:
        script_dir = Path(__file__).parent.resolve()
    except NameError:
        script_dir = Path.cwd()
    xml_file = script_dir / 'O-A0038-003.xml'

    # --- 1. Read and parse the source XML data ---
    try:
        tree = ET.parse(xml_file)
        root = tree.getroot()
        namespace = {'cwa': 'urn:cwa:gov:tw:cwacommon:0.1'}
        content_str = root.find('.//cwa:Content', namespace).text
        # Comma-separated floats spread over many lines; empty fields are skipped.
        all_floats = [
            float(val)
            for line in content_str.strip().split('\n')
            for val in line.split(',')
            if val.strip()
        ]
        temp_grid = np.array(all_floats).reshape(lat_points, lon_points)
    except FileNotFoundError:
        print(f"Error: File not found at '{xml_file}'. Please ensure it is in the same directory.")
        return False
    except Exception as e:
        # Deliberate catch-all: any parsing problem aborts the pipeline with a message.
        print(f"An error occurred while reading or parsing the XML file: {e}")
        return False

    # --- 2. Create the latitude and longitude coordinate grid ---
    longitudes = start_lon + np.arange(lon_points) * lon_res
    latitudes = start_lat + np.arange(lat_points) * lat_res

    # --- 3. Generate the classification and regression datasets ---
    classification_data = []
    regression_data = []
    for i, lat in enumerate(latitudes):
        for j, lon in enumerate(longitudes):
            temp_value = temp_grid[i, j]
            is_valid = temp_value != invalid_value
            classification_data.append({'longitude': lon, 'latitude': lat, 'label': 1 if is_valid else 0})
            if is_valid:
                regression_data.append({'longitude': lon, 'latitude': lat, 'value': temp_value})
    print("Data conversion complete.")
    print(f"Total entries in classification dataset: {len(classification_data)}.")
    print(f"Total entries in regression dataset: {len(regression_data)}.")

    # --- 4. Write the data into two separate CSV files ---
    classification_csv_file = script_dir / "classification_data.csv"
    try:
        _write_rows_csv(classification_csv_file, ['longitude', 'latitude', 'label'], classification_data)
        print(f"\nSuccessfully wrote classification data to: '{classification_csv_file}'")
    except Exception as e:
        print(f"An error occurred while writing the classification CSV: {e}")
        return False

    regression_csv_file = script_dir / "regression_data.csv"
    try:
        _write_rows_csv(regression_csv_file, ['longitude', 'latitude', 'value'], regression_data)
        print(f"Successfully wrote regression data to: '{regression_csv_file}'")
    except Exception as e:
        print(f"An error occurred while writing the regression CSV: {e}")
        return False

    return True

# ==============================================================================
#
#                       PART 2: CLASSIFICATION
#                     (from train_classifier(nn).py)
#
# ==============================================================================

def evaluate_and_plot_nn_cls(model, scaler, X_scaled, y_ohe, dataset_name, save_dir):
    """
    Score the classifier on one data split and save two diagnostic scatter maps.

    Args:
        model: Fitted model whose ``predict`` returns per-class probabilities
            (column 1 is used as the probability of label 1).
        scaler: Fitted scaler; its ``inverse_transform`` restores the original
            coordinates for plotting.
        X_scaled: Scaled feature matrix; after un-scaling, column 0 is plotted
            as longitude and column 1 as latitude.
        y_ohe: One-hot encoded ground-truth labels.
        dataset_name: Split name used in log lines, plot titles and file names.
        save_dir: Output directory; combined with file names via ``/``.
    """
    print(f"--- Evaluating on {dataset_name} Set ---")

    # Probabilities -> hard labels; one-hot targets -> integer labels
    class_probs = model.predict(X_scaled)
    predicted = np.argmax(class_probs, axis=1)
    actual = np.argmax(y_ohe, axis=1)
    prob_valid = class_probs[:, 1]

    # Compute every metric before printing anything
    matrix = confusion_matrix(actual, predicted)
    acc = accuracy_score(actual, predicted)
    prec = precision_score(actual, predicted)
    rec = recall_score(actual, predicted)
    brier_value = brier_score_loss(actual, prob_valid)

    print("Confusion Matrix:")
    print(matrix)
    print(f"Accuracy: {acc:.4f}")
    print(f"Precision: {prec:.4f}")
    print(f"Recall: {rec:.4f}")
    print(f"Confidence MSE (Brier Score): {brier_value:.4f}\n")

    # Back to geographic coordinates for plotting
    coords = scaler.inverse_transform(X_scaled)
    lon = coords[:, 0]
    lat = coords[:, 1]

    # Map 1: where the model is right (green) and wrong (red)
    is_correct = (actual == predicted)
    is_wrong = ~is_correct
    plt.figure(figsize=(10, 8))
    plt.scatter(lon[is_correct], lat[is_correct],
                c='green', label='Correct', alpha=0.6, s=10)
    plt.scatter(lon[is_wrong], lat[is_wrong],
                c='red', label='Incorrect', alpha=0.6, s=10)
    plt.title(f'NN Classification Prediction Correctness ({dataset_name} Set)')
    plt.xlabel('Longitude')
    plt.ylabel('Latitude')
    plt.legend()
    plt.grid(True)
    correctness_path = save_dir / f"{f'NN_Cls_Prediction_Correctness_{dataset_name}_Set'}.png"
    plt.savefig(correctness_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Plot saved to: {correctness_path}")

    # Map 2: predicted probability of label 1 at every point
    plt.figure(figsize=(10, 8))
    points = plt.scatter(lon, lat, c=prob_valid,
                         cmap='coolwarm', vmin=0, vmax=1, s=10)
    plt.colorbar(points, label='Probability of being Valid (Label=1)')
    plt.title(f'NN Classification Prediction Confidence ({dataset_name} Set)')
    plt.xlabel('Longitude')
    plt.ylabel('Latitude')
    plt.grid(True)
    confidence_path = save_dir / f"{f'NN_Cls_Prediction_Confidence_{dataset_name}_Set'}.png"
    plt.savefig(confidence_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Plot saved to: {confidence_path}\n")

def run_classification_nn():
    """
    Train a small dense network on ``classification_data.csv`` and evaluate it.

    The CSV is read from (and plots are written to) the directory of this
    script, falling back to the current working directory when ``__file__``
    is unavailable (e.g. in a notebook). The data is split 70% / 20% / 10%
    into training, validation and test sets, features are standardised with
    statistics from the training set only, and training stops early based on
    the validation loss. Returns ``None``; aborts with a message if the CSV
    is missing.
    """
    import os
    import random

    print(f"\n{'='*80}\nPART 2: Running Classification Neural Network...\n{'='*80}")

    # Set seed for reproducibility
    seed_value = 42
    os.environ['PYTHONHASHSEED'] = str(seed_value)
    random.seed(seed_value)
    np.random.seed(seed_value)
    tf.random.set_seed(seed_value)

    try:
        script_dir = Path(__file__).parent.resolve()
    except NameError:
        # __file__ is not defined in notebook cells / interactive sessions.
        script_dir = Path.cwd()
    data_file_path = script_dir / 'classification_data.csv'

    try:
        data = pd.read_csv(data_file_path)
    except FileNotFoundError:
        print(f"Error: '{data_file_path}' not found. Aborting Part 2.")
        return

    X = data[['longitude', 'latitude']]
    y = data['label']

    # Softmax output + categorical cross-entropy need one-hot targets.
    y_ohe = to_categorical(y, num_classes=2)

    # 30% held out, then one third of that (10% overall) becomes the test set.
    X_train, X_temp, y_train_ohe, y_temp_ohe = train_test_split(X, y_ohe, test_size=0.3, random_state=42, stratify=y_ohe)
    X_val, X_test, y_val_ohe, y_test_ohe = train_test_split(X_temp, y_temp_ohe, test_size=(1/3), random_state=42, stratify=y_temp_ohe)

    # Fit the scaler on the training split only, then reuse it for the others.
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_val_scaled = scaler.transform(X_val)
    X_test_scaled = scaler.transform(X_test)

    model = Sequential([
        Dense(32, activation='relu', input_shape=(2,)),
        Dense(32, activation='relu'),
        Dense(2, activation='softmax')
    ])

    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    model.summary()

    early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)

    print("\nTraining the Classification Neural Network model...")
    model.fit(X_train_scaled, y_train_ohe,
              epochs=1000,
              batch_size=32,
              validation_data=(X_val_scaled, y_val_ohe),
              callbacks=[early_stopping],
              verbose=2)
    print("Model training complete.\n")

    for split_name, X_split, y_split in (
        ('Training', X_train_scaled, y_train_ohe),
        ('Validation', X_val_scaled, y_val_ohe),
        ('Test', X_test_scaled, y_test_ohe),
    ):
        evaluate_and_plot_nn_cls(model, scaler, X_split, y_split, split_name, script_dir)

# ==============================================================================
#
#                       PART 3: REGRESSION (INPAINTING)
#                   (from train_regression(inpainting).py)
#
# ==============================================================================

def create_inpainting_dataset_v2(data):
    """
    Build the single-image dataset used by the inpainting CNN (leak-free version).

    Args:
        data: DataFrame with ``'longitude'``, ``'latitude'`` and ``'value'`` columns.

    Returns:
        Tuple ``(X_grid, y_grid_for_loss, y_true_grid_nan, train_weight_mask,
        val_weight_mask, (train_df, val_df, test_df))``:
        - X_grid: (1, 120, 67, 3) input with channels
          [training-cell mask, scaled longitude, scaled latitude];
        - y_grid_for_loss: (1, 120, 67, 1) target grid with NaN replaced by 0;
        - y_true_grid_nan: (120, 67) target grid, NaN where ``data`` has no row;
        - train_weight_mask / val_weight_mask: (1, 120, 67) grids holding 1.0 at
          the cells of the training / validation rows and 0.0 elsewhere;
        - the 70% / 20% / 10% row split of ``data``.
    """
    print("Creating corrected inpainting dataset for regression...")
    start_lon, start_lat = 120.00, 21.88
    lon_res, lat_res = 0.03, 0.03
    lon_points, lat_points = 67, 120
    grid_shape = (lat_points, lon_points)

    def cell_of(record):
        # Nearest (row, column) grid index of a record's coordinates
        col = int(round((record['longitude'] - start_lon) / lon_res))
        row_idx = int(round((record['latitude'] - start_lat) / lat_res))
        return row_idx, col

    def inside(row_idx, col):
        return 0 <= row_idx < lat_points and 0 <= col < lon_points

    def membership_mask(frame):
        # 1.0 at every grid cell that has a row in `frame`
        mask = np.zeros(grid_shape)
        for _, record in frame.iterrows():
            row_idx, col = cell_of(record)
            if inside(row_idx, col):
                mask[row_idx, col] = 1.0
        return mask

    # Standardised coordinate channels covering the whole grid
    lon_axis = start_lon + np.arange(lon_points) * lon_res
    lat_axis = start_lat + np.arange(lat_points) * lat_res
    lon_grid, lat_grid = np.meshgrid(lon_axis, lat_axis)
    flat_coords = np.stack([lon_grid.flatten(), lat_grid.flatten()], axis=1)
    scaler = StandardScaler()
    scaler.fit(flat_coords)
    flat_scaled = scaler.transform(flat_coords)
    scaled_lon_grid = flat_scaled[:, 0].reshape(grid_shape)
    scaled_lat_grid = flat_scaled[:, 1].reshape(grid_shape)

    # Row-level split: 70% train, 20% validation, 10% test
    train_indices, temp_indices = train_test_split(data.index, test_size=0.3, random_state=42)
    val_indices, test_indices = train_test_split(temp_indices, test_size=(1/3), random_state=42)
    train_df = data.loc[train_indices]
    val_df = data.loc[val_indices]
    test_df = data.loc[test_indices]

    # Ground-truth grid; cells without a row stay NaN
    y_true_grid_nan = np.full(grid_shape, np.nan)
    for _, record in data.iterrows():
        row_idx, col = cell_of(record)
        if inside(row_idx, col):
            y_true_grid_nan[row_idx, col] = record['value']

    # Network input: only the locations of the training cells, never their values
    train_points_mask_channel = membership_mask(train_df)
    X_grid = np.stack([train_points_mask_channel, scaled_lon_grid, scaled_lat_grid], axis=-1)

    # Per-cell loss weights for training and validation
    train_weight_mask = membership_mask(train_df)
    val_weight_mask = membership_mask(val_df)

    # Add the batch axis (and a channel axis for the target)
    X_grid = X_grid[np.newaxis, ...]
    y_grid_for_loss = np.nan_to_num(y_true_grid_nan)[np.newaxis, ..., np.newaxis]
    train_weight_mask = train_weight_mask[np.newaxis, ...]
    val_weight_mask = val_weight_mask[np.newaxis, ...]

    print("Inpainting dataset created.")
    return X_grid, y_grid_for_loss, y_true_grid_nan, train_weight_mask, val_weight_mask, (train_df, val_df, test_df)

def evaluate_and_plot_inpainting(y_true_grid_nan, y_pred_grid, df, dataset_name, save_dir):
    """
    Evaluates the inpainting model and plots the results.

    Args:
        y_true_grid_nan: 2-D (latitude x longitude) grid of true values, NaN
            where no value is known.
        y_pred_grid: 2-D grid of predictions, same shape as ``y_true_grid_nan``.
        df: DataFrame with ``'longitude'`` and ``'latitude'`` columns selecting
            the cells the metrics are computed on.
        dataset_name: Split name used in log lines and in the first plot.
        save_dir: Output directory; combined with file names via ``/``.

    Side effects:
        Prints MSE / RMSE / MAE / max error and writes three PNG files. The
        second and third file names do not contain ``dataset_name``, so every
        call overwrites them.
    """
    print(f"--- Evaluating on {dataset_name} Set ---")

    start_lon, start_lat, lon_res, lat_res = 120.00, 21.88, 0.03, 0.03
    n_lat, n_lon = y_true_grid_nan.shape

    # Collect (truth, prediction) pairs at the grid cells referenced by df.
    true_values, pred_values = [], []
    for _, row in df.iterrows():
        j = int(round((row['longitude'] - start_lon) / lon_res))
        i = int(round((row['latitude'] - start_lat) / lat_res))
        if 0 <= i < n_lat and 0 <= j < n_lon:
            true_values.append(y_true_grid_nan[i, j])
            pred_values.append(y_pred_grid[i, j])

    mse = mean_squared_error(true_values, pred_values)
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(true_values, pred_values)
    max_err = max_error(true_values, pred_values)

    print(f"Temperature MSE: {mse:.4f}")
    print(f"Temperature RMSE: {rmse:.4f} (°C)")
    print(f"Temperature MAE: {mae:.4f} (°C)")
    print(f"Max Temperature Error: {max_err:.4f} (°C)\n")

    # Shared colour scale and geographic extent; the extent follows the grid
    # shape instead of repeating the 67 x 120 literals.
    vmin, vmax = np.nanmin(y_true_grid_nan), np.nanmax(y_true_grid_nan)
    extent = [start_lon, start_lon + n_lon * lon_res, start_lat, start_lat + n_lat * lat_res]

    # Plot 1: ground-truth grid
    plt.figure(figsize=(8, 10))
    title1 = f'Inpainting_Actual_Temperature_{dataset_name}'
    plt.imshow(y_true_grid_nan, cmap='viridis', vmin=vmin, vmax=vmax, origin='lower', extent=extent, interpolation='nearest')
    plt.colorbar(label='Actual Temperature (°C)')
    plt.title(f'Actual Temperature ({dataset_name} points)')
    plt.xlabel('Longitude')
    plt.ylabel('Latitude')
    plt.grid(True)
    save_path1 = save_dir / f"{title1}.png"
    plt.savefig(save_path1, dpi=300)
    plt.close()
    print(f"Plot saved to: {save_path1}")

    # Plot 2: full predicted grid
    plt.figure(figsize=(8, 10))
    title2 = 'Inpainting_Predicted_Temperature_Full'
    plt.imshow(y_pred_grid, cmap='viridis', vmin=vmin, vmax=vmax, origin='lower', extent=extent, interpolation='nearest')
    plt.colorbar(label='Predicted Temperature (°C)')
    plt.title('Full Predicted Temperature Map')
    plt.xlabel('Longitude')
    plt.ylabel('Latitude')
    plt.grid(True)
    save_path2 = save_dir / f"{title2}.png"
    plt.savefig(save_path2, dpi=300)
    plt.close()
    print(f"Plot saved to: {save_path2}")

    # Plot 3: signed error (NaN wherever the truth is unknown), symmetric colour range
    plt.figure(figsize=(8, 10))
    title3 = 'Inpainting_Temperature_Error_Full'
    errors = y_pred_grid - y_true_grid_nan
    error_max_abs = np.nanmax(np.abs(errors))
    plt.imshow(errors, cmap='coolwarm', vmin=-error_max_abs, vmax=error_max_abs, origin='lower', extent=extent, interpolation='nearest')
    plt.colorbar(label='Prediction Error (°C)')
    plt.title('Full Prediction Error Map')
    plt.xlabel('Longitude')
    plt.ylabel('Latitude')
    plt.grid(True)
    save_path3 = save_dir / f"{title3}.png"
    plt.savefig(save_path3, dpi=300)
    plt.close()
    print(f"Plot saved to: {save_path3}\n")

def run_regression_inpainting():
    """
    Train the inpainting CNN on ``regression_data.csv`` and evaluate it.

    The CSV is read from (and plots are written to) the directory of this
    script, falling back to the current working directory when ``__file__``
    is unavailable (e.g. in a notebook). The whole grid is a single training
    sample; ``train_mask`` and ``val_mask`` are passed as per-cell sample
    weights for the training and validation loss respectively. Returns
    ``None``; aborts with a message if the CSV is missing.
    """
    import os
    import random

    print(f"\n{'='*80}\nPART 3: Running Regression CNN (Inpainting)...\n{'='*80}")

    # Set seed for reproducibility
    seed_value = 42
    os.environ['PYTHONHASHSEED'] = str(seed_value)
    random.seed(seed_value)
    np.random.seed(seed_value)
    tf.random.set_seed(seed_value)

    try:
        script_dir = Path(__file__).parent.resolve()
    except NameError:
        # __file__ is not defined in notebook cells / interactive sessions.
        script_dir = Path.cwd()
    data_file_path = script_dir / 'regression_data.csv'

    try:
        data = pd.read_csv(data_file_path)
    except FileNotFoundError:
        print(f"Error: '{data_file_path}' not found. Aborting Part 3.")
        return

    X_grid, y_grid_for_loss, y_true_grid_nan, train_mask, val_mask, (train_df, val_df, test_df) = create_inpainting_dataset_v2(data)

    # Fully convolutional network: 'same' padding keeps the grid size, and the
    # final 1x1 convolution emits one value per cell.
    input_shape = X_grid.shape[1:]
    inputs = Input(shape=input_shape)
    x = Conv2D(64, (5, 5), padding='same')(inputs)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(128, (3, 3), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Dropout(0.25)(x)
    x = Conv2D(64, (3, 3), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    outputs = Conv2D(1, (1, 1), padding='same', activation='linear')(x)
    model = Model(inputs, outputs)

    huber_loss = tf.keras.losses.Huber(delta=4.0)
    model.compile(optimizer='adam', loss=huber_loss)
    model.summary()

    early_stopping = EarlyStopping(monitor='val_loss', patience=200, restore_best_weights=True)

    print("\nTraining the Corrected Inpainting CNN model...")
    # Input and target are the same grid for training and validation; only the
    # weight masks differ.
    model.fit(X_grid, y_grid_for_loss,
              epochs=2000,
              batch_size=1,
              sample_weight=train_mask,
              validation_data=(X_grid, y_grid_for_loss, val_mask),
              callbacks=[early_stopping],
              verbose=2)
    print("Model training complete.\n")

    # Drop the batch and channel axes to get the 2-D prediction grid.
    y_pred_grid_full = model.predict(X_grid)[0, :, :, 0]

    for split_name, split_df in (('Training', train_df), ('Validation', val_df), ('Test', test_df)):
        evaluate_and_plot_inpainting(y_true_grid_nan, y_pred_grid_full, split_df, split_name, script_dir)


# ==============================================================================
#
#                       MAIN EXECUTION BLOCK
#
# ==============================================================================

# Entry point: run the three parts in order. run_data_labeling() returns True
# on success and False on failure, which gates the two training steps.
if __name__ == '__main__':
    # Step 1: Process raw data and create CSVs
    success = run_data_labeling()
    
    # Step 2 & 3: Run models only if data processing was successful
    if success:
        run_classification_nn()
        run_regression_inpainting()
    else:
        print("\nData labeling failed. Halting execution of model training.")

<hr style="border-style: dashed; border-color: white; border-width: 0.8px;">

## 1. Imports and Configuration

This section loads the necessary libraries for data processing, traditional machine learning, and the JAX/Flax framework. A global random seed is set to ensure reproducibility of the data split and model initialization.

In [None]:
import xml.etree.ElementTree as ET
from pathlib import Path
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# JAX/Flax Libraries for the core NN implementation
import jax
import jax.numpy as jnp
from jax import random
import flax.linen as nn
from flax.training import train_state
import optax

# Scikit-learn Utilities for data handling and metrics
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    confusion_matrix, accuracy_score, precision_score,
    recall_score, f1_score, brier_score_loss
)

# Benchmark Classifier Imports (Will be used in Section 3.2)
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import SVC
import xgboost as xgb
import lightgbm as lgb
from catboost import CatBoostClassifier
from mpl_toolkits.axes_grid1.inset_locator import mark_inset, inset_axes

# Finding files in Google Colab
from google.colab import drive

# Set global random seed
# `random` here is jax.random (imported above), not the standard-library module:
# NumPy's global RNG is seeded directly, while `key` is a JAX PRNG key built
# from the same seed.
seed = 42
np.random.seed(seed)
key = random.PRNGKey(seed)

## 2. Data Acquisition and Transformation

The raw grid data from `O-A0038-003.xml` is parsed, and the coordinates are generated. This section focuses on creating the **Classification Dataset** `(Longitude, Latitude, label)` by identifying valid (`label=1`) versus invalid (`label=0`) temperature readings.

### XML Parsing and Grid Creation

In [None]:
# --- Data Characteristics ---
# Mount Google Drive so the XML file can be read from the Colab runtime.
drive.mount('/content/drive')
xml_file = '/content/drive/MyDrive/Meteorological Data/O-A0038-003.xml'
# Grid layout: 120 latitude rows x 67 longitude columns, starting at
# (start_lon, start_lat) with the same step in both directions.
lat_points = 120
lon_points = 67
start_lon = 120.00
start_lat = 21.88
resolution = 0.03
# Sentinel value marking a grid cell without a valid temperature reading.
INVALID_VALUE = -999.0

# Read and parse the XML data
tree = ET.parse(xml_file)
root = tree.getroot()
namespace = {'cwa': 'urn:cwa:gov:tw:cwacommon:0.1'}
# The grid is stored as one block of comma-separated text inside <Content>.
content_str = root.find('.//cwa:Content', namespace).text

# Parse the temperature grid into a NumPy array (120 rows x 67 columns)
lines = content_str.strip().split('\n')
all_floats = []
for line in lines:
    # Skip blank lines
    if not line.strip():
        continue
    # Empty fields (e.g. from a trailing comma) are ignored
    floats_in_line = [float(val) for val in line.split(',') if val.strip()]
    all_floats.extend(floats_in_line)

temp_grid = np.array(all_floats).reshape(lat_points, lon_points)

print(f"Temperature grid shape: {temp_grid.shape}")
# Report the value range over valid cells only
valid_temps = temp_grid[temp_grid != INVALID_VALUE]
print(f"Valid temperature range: {valid_temps.min():.2f}°C to {valid_temps.max():.2f}°C")

### Classification Dataset Creation

The core classification task is to predict data validity. The dataset is structured as: `(Longitude,Latitude,label)`, where `label=0` for `Value=−999.0` and `label=1` otherwise.

In [None]:
# Generate coordinates
# One coordinate per grid column (longitude) and per grid row (latitude).
longitudes = start_lon + np.arange(lon_points) * resolution
latitudes = start_lat + np.arange(lat_points) * resolution

# One record per grid cell, appended in row-major order (all longitudes of a
# latitude row before moving to the next row).
classification_data = []

for i in range(lat_points):
    for j in range(lon_points):
        lon = longitudes[j]
        lat = latitudes[i]
        temp_value = temp_grid[i, j]

        # Classification Dataset Rule: label=0 if invalid, label=1 if valid
        label = 1 if temp_value != INVALID_VALUE else 0
        classification_data.append({'longitude': lon, 'latitude': lat, 'label': label})

data_cls = pd.DataFrame(classification_data)

# Quick sanity summary: total size, class balance and the first rows
print("\n--- Classification Data Summary ---")
print(f"Total entries: {len(data_cls)}")
print(f"Valid data points (label=1): {(data_cls['label'] == 1).sum()}")
print(f"Invalid data points (label=0): {(data_cls['label'] == 0).sum()}")
print(data_cls.head())

```
--- Classification Data Summary ---
Total entries: 8040
Valid data points (label=1): 3495
Invalid data points (label=0): 4545
   longitude  latitude  label
0     120.00     21.88      0
1     120.03     21.88      0
2     120.06     21.88      0
3     120.09     21.88      0
4     120.12     21.88      0
```

### Data Splitting and Scaling

The data is split into Train (70%), Validation (20%), and Test (10%) sets with **stratified sampling** to maintain the balance of valid/invalid data points across the sets. Features (Longitude, Latitude) are then standardized for optimal model performance.

In [None]:
# Extract features (X) and labels (y)
X_cls = data_cls[['longitude', 'latitude']].values
y_cls = data_cls['label'].values

# Split data: 70% train, 20% validation, 10% test (stratified)
# First hold out 30%, then split that remainder 2:1 into validation and test.
X_train_cls, X_temp_cls, y_train_cls, y_temp_cls = train_test_split(
    X_cls, y_cls, test_size=0.3, random_state=seed, stratify=y_cls
)
X_val_cls, X_test_cls, y_val_cls, y_test_cls = train_test_split(
    X_temp_cls, y_temp_cls, test_size=(1/3), random_state=seed, stratify=y_temp_cls
)

print(f"\nTraining set size: {len(X_train_cls)}")
print(f"Validation set size: {len(X_val_cls)}")
print(f"Test set size: {len(X_test_cls)}")

# Standardize features (fit only on training data)
# Validation and test sets reuse the training statistics.
scaler_cls = StandardScaler()
X_train_scaled_cls = scaler_cls.fit_transform(X_train_cls)
X_val_scaled_cls = scaler_cls.transform(X_val_cls)
X_test_scaled_cls = scaler_cls.transform(X_test_cls)

# Convert to JAX arrays for neural network training
# Features are the scaled versions; labels are passed through unchanged.
X_train_jax = jnp.array(X_train_scaled_cls)
y_train_jax = jnp.array(y_train_cls)
X_val_jax = jnp.array(X_val_scaled_cls)
y_val_jax = jnp.array(y_val_cls)
X_test_jax = jnp.array(X_test_scaled_cls)
y_test_jax = jnp.array(y_test_cls)

```
Training set size: 5628
Validation set size: 1608
Test set size: 804
```

## 3. Classification Model Training


### Benchmark Model Training

Before implementing the JAX Neural Network, a variety of traditional and ensemble classifiers are trained to establish a performance benchmark.

In [None]:
# Utility functions for training and evaluation
# Shared results store for the benchmark classifiers — presumably passed as
# `results_dict` to evaluate_classifier(); confirm against the cells that call it.
all_results_cls = {}
def fit_qda(X_train, y_train):
    """
    Estimate the per-class QDA parameters from labelled training data.

    Returns a ``(priors, means, covariances)`` tuple of dicts keyed by class
    label: the class frequency, the per-feature mean, and the biased
    (divide-by-N) covariance matrix of the rows belonging to that class.
    """
    features, labels = jnp.array(X_train), jnp.array(y_train)
    n_samples = features.shape[0]

    # Rows of each class, in ascending label order (np.unique sorts its output)
    rows_by_class = {
        label: features[labels == label]
        for label in np.unique(np.array(labels))
    }

    priors = {label: rows.shape[0] / n_samples for label, rows in rows_by_class.items()}
    means = {label: jnp.mean(rows, axis=0) for label, rows in rows_by_class.items()}
    covariances = {
        label: jnp.cov(rows, rowvar=False, bias=True)
        for label, rows in rows_by_class.items()
    }
    return priors, means, covariances

@jax.jit
def predict_qda(params, X):
    """
    Score every row of ``X`` against each class Gaussian and return the
    winning class index together with the posterior probabilities.

    Args:
        params: ``(priors, means, covariances)`` tuple of dicts keyed by class
            label, as produced by ``fit_qda``.
        X: 2-D array with one sample per row.

    Returns:
        ``(predictions, probabilities)`` where ``predictions[i]`` is the
        position of the best-scoring class in ``sorted(priors.keys())`` (this
        equals the label itself only when labels are 0..K-1) and
        ``probabilities`` has shape ``(n_classes, n_samples)`` with rows in
        the same sorted-class order.
    """
    priors, means, covariances = params
    scores = []
    # Iterate in sorted label order so the row order of the output is fixed
    for k in sorted(priors.keys()):
        mu_k, cov_k = means[k], covariances[k]
        inv_cov_k = jnp.linalg.inv(cov_k)
        # slogdet returns log|det| directly; log(det(cov)) underflows to -inf
        # (or overflows) in float32 once the determinant leaves the
        # representable range, which turns the scores into inf and the
        # softmax into NaN.
        _, log_det_cov_k = jnp.linalg.slogdet(cov_k)
        # Broadcasting subtracts the class mean from every row
        diff = X - mu_k
        # Squared Mahalanobis distance of each row to the class mean
        quadratic_term = jnp.sum((diff @ inv_cov_k) * diff, axis=1)
        score_k = -0.5 * quadratic_term - 0.5 * log_det_cov_k + jnp.log(priors[k])
        scores.append(score_k)
    scores_stacked = jnp.stack(scores, axis=0)
    predictions = jnp.argmax(scores_stacked, axis=0)
    # The scores are log-posteriors up to a shared constant, so a softmax over
    # the class axis yields normalised posterior probabilities
    probabilities = jax.nn.softmax(scores_stacked, axis=0)
    return predictions, probabilities

# Scikit-learn-style wrapper around the JAX QDA functions

class JAX_QDA_Wrapper:
    """
    Minimal scikit-learn-style adapter around ``fit_qda`` / ``predict_qda``.

    Exposes ``fit``, ``predict`` and ``predict_proba`` so the JAX QDA model can
    be passed to code written for scikit-learn classifiers.

    Attributes set by ``fit``:
        params_: the ``(priors, means, covariances)`` tuple from ``fit_qda``.
        classes_: class labels in ascending order; this is the order of the
            score rows produced by ``predict_qda`` and therefore of the
            columns returned by ``predict_proba``.
    """

    def __init__(self):
        self.params_ = None
        self.classes_ = None

    @staticmethod
    def _as_array(data):
        """Unwrap objects exposing ``.values`` (e.g. pandas) to their array."""
        return data.values if hasattr(data, 'values') else data

    def _require_fitted(self):
        """Fail with a clear message instead of an opaque unpacking error."""
        if self.params_ is None:
            raise RuntimeError("JAX_QDA_Wrapper is not fitted yet; call fit() first.")

    def fit(self, X, y):
        """Estimate the QDA parameters from ``X`` / ``y`` and return ``self``."""
        self.params_ = fit_qda(self._as_array(X), self._as_array(y))
        # predict_qda orders its score rows by sorted(priors.keys()); keep the
        # same ordering so argmax indices can be mapped back to real labels
        self.classes_ = np.array(sorted(self.params_[0].keys()))
        return self

    def predict(self, X):
        """Return the predicted class label for every row of ``X``."""
        self._require_fitted()
        class_indices, _ = predict_qda(self.params_, jnp.array(self._as_array(X)))
        # predict_qda yields argmax positions, not labels; translate them so
        # label sets other than 0..K-1 are reported correctly
        return self.classes_[np.asarray(class_indices)]

    def predict_proba(self, X):
        """Return class probabilities with shape ``(n_samples, n_classes)``."""
        self._require_fitted()
        _, probabilities = predict_qda(self.params_, jnp.array(self._as_array(X)))
        # predict_qda returns (n_classes, n_samples); transpose to the
        # scikit-learn layout with one row per sample
        return np.array(probabilities).T

def evaluate_classifier(model, model_name, X_train, y_train, X_test, y_test, results_dict):
    """
    Fit ``model`` on the training split, score it on the test split, print a
    short report and store the metrics in ``results_dict[model_name]``.

    Only the ``fit`` call is timed. The Brier score is computed from the
    class-1 column of ``predict_proba`` and is NaN for models without that
    method. Returns the fitted ``model``.
    """
    rule = '=' * 60
    print(f"\n{rule}")
    print(f"Training {model_name}")
    print(f"{rule}")

    fit_started = time.time()
    model.fit(X_train, y_train)
    train_time = time.time() - fit_started

    y_test_pred = model.predict(X_test)

    # Default to NaN and overwrite only when probabilities are available
    test_brier = np.nan
    if hasattr(model, "predict_proba"):
        y_test_prob = model.predict_proba(X_test)[:, 1]
        test_brier = brier_score_loss(y_test, y_test_prob)

    # Label-based metrics; zero_division=0 reports 0 instead of warning when
    # a metric's denominator is empty
    test_acc = accuracy_score(y_test, y_test_pred)
    test_prec, test_rec, test_f1 = (
        metric(y_test, y_test_pred, zero_division=0)
        for metric in (precision_score, recall_score, f1_score)
    )

    print(f"Training Time: {train_time:.4f}s")
    print(f"Test Metrics: Acc={test_acc:.4f}, Prec={test_prec:.4f}, Rec={test_rec:.4f}, F1={test_f1:.4f}, Brier={test_brier:.4f}")

    results_dict[model_name] = dict(
        train_time=train_time,
        test_acc=test_acc,
        test_prec=test_prec,
        test_rec=test_rec,
        test_f1=test_f1,
        test_brier=test_brier,
    )
    return model

banner = "=" * 60
print("Starting Classification Benchmark Models Training...")
print(banner)

# Benchmark classifiers, keyed by the display name used in the results table
models = {
    "QDA (JAX)": JAX_QDA_Wrapper(),
    "K-Nearest Neighbors": KNeighborsClassifier(n_neighbors=5),
    "Decision Tree": DecisionTreeClassifier(random_state=seed, max_depth=10),
    "Random Forest": RandomForestClassifier(n_estimators=100, random_state=seed, max_depth=10, n_jobs=-1),
    "Naive Bayes": GaussianNB(),
    "Support Vector Machine": SVC(random_state=seed, probability=True),
    "XGBoost": xgb.XGBClassifier(n_estimators=100, max_depth=6, learning_rate=0.1, random_state=seed, eval_metric='logloss'),
    "LightGBM": lgb.LGBMClassifier(n_estimators=100, max_depth=6, learning_rate=0.1, random_state=seed, verbose=-1),
    "CatBoost": CatBoostClassifier(iterations=100, depth=6, learning_rate=0.1, random_state=seed, verbose=False)
}

# Fit each model on the scaled training split and score it on the scaled
# test split; metrics accumulate in all_results_cls
for name, model in models.items():
    evaluate_classifier(
        model=model,
        model_name=name,
        X_train=X_train_scaled_cls,
        y_train=y_train_cls,
        X_test=X_test_scaled_cls,
        y_test=y_test_cls,
        results_dict=all_results_cls,
    )

print("\n" + banner)
print("Classification Benchmark Training Complete!")
print(banner)

```
Starting Classification Benchmark Models Training...
============================================================

============================================================
Training QDA (JAX)
============================================================
Training Time: 0.9934s
Test Metrics: Acc=0.8271, Prec=0.8207, Rec=0.7714, F1=0.7953, Brier=0.1301

============================================================
Training K-Nearest Neighbors
============================================================
Training Time: 0.0031s
Test Metrics: Acc=0.9776, Prec=0.9798, Rec=0.9686, F1=0.9741, Brier=0.0147

============================================================
Training Decision Tree
============================================================
Training Time: 0.0053s
Test Metrics: Acc=0.9614, Prec=0.9518, Rec=0.9600, F1=0.9559, Brier=0.0339

============================================================
Training Random Forest
============================================================
Training Time: 0.3540s
Test Metrics: Acc=0.9776, Prec=0.9770, Rec=0.9714, F1=0.9742, Brier=0.0195

============================================================
Training Naive Bayes
============================================================
Training Time: 0.0019s
Test Metrics: Acc=0.8072, Prec=0.8854, Rec=0.6400, F1=0.7430, Brier=0.1701

============================================================
Training Support Vector Machine
============================================================
Training Time: 0.9255s
Test Metrics: Acc=0.9577, Prec=0.9675, Rec=0.9343, F1=0.9506, Brier=0.0289

============================================================
Training XGBoost
============================================================
Training Time: 0.0589s
Test Metrics: Acc=0.9764, Prec=0.9825, Rec=0.9629, F1=0.9726, Brier=0.0180

============================================================
Training LightGBM
============================================================
Training Time: 0.0636s
Test Metrics: Acc=0.9776, Prec=0.9826, Rec=0.9657, F1=0.9741, Brier=0.0177

============================================================
Training CatBoost
============================================================
Training Time: 0.1294s
Test Metrics: Acc=0.9677, Prec=0.9709, Rec=0.9543, F1=0.9625, Brier=0.0265

============================================================
Classification Benchmark Training Complete!
============================================================
```

### Neural Network (Flax/JAX) Training

The core neural network model uses the **JAX** framework for high-performance computing, with **Flax** providing a convenient API for defining the architecture and **Optax** for the optimization routine.

#### Model Architecture and Setup

In [None]:
# Define the simple Feedforward Neural Network (FNN)
class Classifier(nn.Module):
    """Feedforward classifier: two ReLU hidden layers followed by a linear logit layer."""
    hidden_dim: int = 32   # width of each hidden Dense layer
    num_classes: int = 2   # size of the output (logit) layer

    @nn.compact
    def __call__(self, x):
        hidden = x
        # Two identical Dense -> ReLU stages; layers are created in the same
        # order as before, so the automatically assigned names are unchanged
        for _ in range(2):
            hidden = nn.relu(nn.Dense(self.hidden_dim)(hidden))
        # Raw logits: no softmax is applied here
        return nn.Dense(self.num_classes)(hidden)

# Build the network and initialise its weights from a fresh PRNG subkey,
# tracing the model with a dummy batch of one two-feature sample
model_cls = Classifier(hidden_dim=32, num_classes=2)
key, init_key = random.split(key)
dummy_batch = jnp.ones((1, 2))
params_cls = model_cls.init(init_key, dummy_batch)

# Total number of scalar weights across every leaf of the parameter tree
param_count = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params_cls))

print(f"Neural Network Architecture: Input(2) -> Dense({model_cls.hidden_dim}) -> ReLU -> Dense({model_cls.hidden_dim}) -> ReLU -> Dense({model_cls.num_classes})")
print(f"Total Parameters: {param_count}")

# Training state bundles the apply function, the parameters and the Adam optimizer
learning_rate = 0.001
tx = optax.adam(learning_rate)
state_cls = train_state.TrainState.create(
    apply_fn=model_cls.apply, params=params_cls['params'], tx=tx
)

#### JAX/Flax Core Functions

JAX's functional programming paradigm makes it possible to write the training, evaluation, and prediction steps as pure, *JIT-compiled functions*, which drastically improves performance.

In [None]:
def cross_entropy_loss(logits, labels):
    """
    Compute the mean categorical cross-entropy between logits and integer labels.

    Args:
        logits: array whose last axis holds one unnormalised score per class.
        labels: integer class indices, one per row of ``logits``.

    Returns:
        Scalar mean negative log-likelihood of the true classes.
    """
    # Take the class count from the logits instead of hard-coding 2, so the
    # loss stays correct when the classifier is built with another num_classes
    num_classes = jnp.shape(logits)[-1]
    y_one_hot = jax.nn.one_hot(labels, num_classes=num_classes)
    log_probs = jax.nn.log_softmax(logits)
    # The one-hot mask keeps only the log-probability of the true class
    return -jnp.mean(jnp.sum(y_one_hot * log_probs, axis=-1))

def compute_metrics(logits, labels):
    """Return a dict with the cross-entropy ``loss`` and the ``accuracy`` of ``logits`` against ``labels``."""
    return {
        'loss': cross_entropy_loss(logits, labels),
        # Fraction of rows whose highest-scoring class equals the label
        'accuracy': jnp.mean(jnp.argmax(logits, axis=-1) == labels),
    }

@jax.jit
def train_step(state, batch_x, batch_y):
    """Run one optimisation step on a mini-batch and return ``(new_state, loss)``."""
    def batch_loss(weights):
        # Forward pass with the candidate weights, then cross-entropy on the batch
        return cross_entropy_loss(state.apply_fn({'params': weights}, batch_x), batch_y)

    # Loss value and its gradient w.r.t. the current parameters in one pass
    loss_value, gradients = jax.value_and_grad(batch_loss)(state.params)
    updated_state = state.apply_gradients(grads=gradients)
    return updated_state, loss_value

@jax.jit
def eval_step(params, batch_x, batch_y):
    """Forward ``batch_x`` through ``model_cls`` and return its loss/accuracy metrics."""
    return compute_metrics(model_cls.apply({'params': params}, batch_x), batch_y)

@jax.jit
def predict_fn(params, X):
    """Apply ``model_cls`` to ``X`` with the given parameters and return the logits."""
    logits = model_cls.apply({'params': params}, X)
    return logits

#### Training Loop with Early Stopping

In [None]:
# Training hyperparameters
num_epochs, batch_size, patience = 1000, 32, 10

# Early-stopping bookkeeping: best validation loss seen so far, the
# parameters that produced it, and the number of epochs without improvement
patience_counter = 0
best_val_loss = float('inf')
best_params = state_cls.params

# Per-epoch history, used later for the training curves
train_losses, val_losses = [], []
train_accs, val_accs = [], []

print("\nStarting Neural Network Training...\n")
nn_start_time = time.time()

num_train = len(X_train_jax)
for epoch in range(num_epochs):
    # Fresh random ordering of the training rows for this epoch
    key, subkey = random.split(key)
    perm = random.permutation(subkey, num_train)
    X_train_shuffled, y_train_shuffled = X_train_jax[perm], y_train_jax[perm]

    # One optimiser step per mini-batch (the last batch may be shorter)
    for i in range(0, num_train, batch_size):
        batch_rows = slice(i, i + batch_size)
        x_batch, y_batch = X_train_shuffled[batch_rows], y_train_shuffled[batch_rows]
        state_cls, _ = train_step(state_cls, x_batch, y_batch)

    # Full-split metrics after the epoch's updates
    train_metrics = eval_step(state_cls.params, X_train_jax, y_train_jax)
    val_metrics = eval_step(state_cls.params, X_val_jax, y_val_jax)
    train_loss, val_loss = float(train_metrics['loss']), float(val_metrics['loss'])

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accs.append(float(train_metrics['accuracy']))
    val_accs.append(float(val_metrics['accuracy']))

    # Early stopping: snapshot the parameters on a strict improvement,
    # otherwise count one more stale epoch
    improved = val_loss < best_val_loss
    if improved:
        best_val_loss, best_params = val_loss, state_cls.params
    patience_counter = 0 if improved else patience_counter + 1

    if (epoch + 1) % 50 == 0:
        print(f"Epoch {epoch+1}/{num_epochs}: Train Loss={train_loss:.4f}, Train Acc={train_accs[-1]:.4f}, Val Loss={val_loss:.4f}, Val Acc={val_accs[-1]:.4f}")

    if patience_counter >= patience:
        print(f"\nEarly stopping at epoch {epoch+1}")
        break

nn_train_time = time.time() - nn_start_time
# Keep the snapshot with the lowest validation loss, not the last epoch's weights
params_nn_cls = best_params
print(f"\nNeural Network Training Complete! Time: {nn_train_time:.4f}s, Epochs: {len(train_losses)}")

```
Starting Neural Network Training...


Early stopping at epoch 40

Neural Network Training Complete! Time: 12.0536s, Epochs: 40
```

#### Neural Network Test Evaluation

The parameters that achieved the lowest validation loss during training (retained by the early-stopping logic) are used to evaluate performance on the unseen test set, and the results are added to the comparison dictionary.

In [None]:
# Score the test set with the selected parameters: hard labels come from the
# argmax of the logits, probabilities from their softmax
logits_test = predict_fn(params_nn_cls, X_test_jax)
y_test_pred_nn = np.array(jnp.argmax(logits_test, axis=-1))
y_test_prob_nn = np.array(jax.nn.softmax(logits_test))

nn_test_acc = accuracy_score(y_test_cls, y_test_pred_nn)
nn_test_prec, nn_test_rec, nn_test_f1 = (
    metric(y_test_cls, y_test_pred_nn, zero_division=0)
    for metric in (precision_score, recall_score, f1_score)
)
# Brier score uses the probability assigned to class 1
nn_test_brier = brier_score_loss(y_test_cls, y_test_prob_nn[:, 1])

# Register the network next to the benchmark models, using the same keys
all_results_cls['Neural Network (Flax)'] = dict(
    train_time=nn_train_time,
    test_acc=nn_test_acc,
    test_prec=nn_test_prec,
    test_rec=nn_test_rec,
    test_f1=nn_test_f1,
    test_brier=nn_test_brier,
)

# Detailed report
cm = confusion_matrix(y_test_cls, y_test_pred_nn)

print(f"\n{'='*60}")
print("NEURAL NETWORK DETAILED EVALUATION (Classification)")
print(f"{'='*60}\n")
print(f"Confusion Matrix:\n{cm}\n")
print(f"Test Metrics:")
print(f"  Accuracy:  {nn_test_acc:.4f}")
print(f"  Precision: {nn_test_prec:.4f}")
print(f"  Recall:    {nn_test_rec:.4f}")
print(f"  F1 Score:  {nn_test_f1:.4f}")
print(f"  Brier Score: {nn_test_brier:.4f}")

# Training history: loss on the left panel, accuracy on the right
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
history_panels = (
    (ax1, train_losses, val_losses, 'Loss'),
    (ax2, train_accs, val_accs, 'Accuracy'),
)
for panel, train_curve, val_curve, metric_name in history_panels:
    panel.plot(train_curve, label=f'Train {metric_name}')
    panel.plot(val_curve, label=f'Validation {metric_name}')
    panel.set_xlabel('Epoch')
    panel.set_ylabel(metric_name)
    panel.set_title(f'Neural Network: Training and Validation {metric_name}')
    panel.legend()
    panel.grid(True)

plt.tight_layout()
plt.savefig('classification_nn_training_history.png', dpi=300, bbox_inches='tight')
plt.show()
print("\nTraining history saved to 'classification_nn_training_history.png'")

```
============================================================
NEURAL NETWORK DETAILED EVALUATION (Classification)
============================================================

Confusion Matrix:
[[447   7]
 [ 11 339]]

Test Metrics:
  Accuracy:  0.9776
  Precision: 0.9798
  Recall:    0.9686
  F1 Score:  0.9741
  Brier Score: 0.0189
```
<table>
  <tr>
    <td><img src="Training and Validation Loss and Accuracy.png" width="1200"></td>
  </tr>
  
</table>

## 4. Classification Results Comparison and Visualization

The performance of all models is compiled and visualized to compare key metrics: **Accuracy**, **F1 Score**, **Training Time**, and the **Precision-Recall** trade-off.

In [None]:
# One row per model, one column per metric, rounded for display
results_df_cls = pd.DataFrame(all_results_cls).T.round(4)

wide_rule = "=" * 90
print("\n" + wide_rule)
print("ALL CLASSIFICATION MODELS COMPARISON (Test Set)")
print(wide_rule)
print(results_df_cls.to_string())

# Rankings used for the highlights below and for the bar charts:
# accuracy and F1 best-first, training time fastest-first
results_sorted_acc = results_df_cls.sort_values('test_acc', ascending=False)
results_sorted_f1 = results_df_cls.sort_values('test_f1', ascending=False)
results_sorted_time = results_df_cls.sort_values('train_time')

print("\n" + wide_rule)
print(f"Best Model by Accuracy: {results_sorted_acc.index[0]} (Accuracy: {results_sorted_acc['test_acc'].iloc[0]:.4f})")
print(f"Best Model by F1 Score: {results_sorted_f1.index[0]} (F1: {results_sorted_f1['test_f1'].iloc[0]:.4f})")
print(f"Fastest Training: {results_sorted_time.index[0]} (Time: {results_sorted_time['train_time'].iloc[0]:.4f}s)")
print(wide_rule)

# --- Visualization ---
fig, axes = plt.subplots(2, 2, figsize=(20, 14))
# Row-major unpacking: accuracy, F1, training time, precision-recall
ax1, ax2, ax3, ax4 = axes.ravel()

# Keyword sets shared by the axis labels, panel titles and bar value labels
label_font = dict(fontsize=12, fontweight='bold')
title_font = dict(fontsize=14, fontweight='bold', pad=15)
value_font = dict(va='center', fontsize=10, fontweight='bold')

# 1. Test Accuracy Comparison
# NOTE(review): rebinding `models` here replaces the dict of fitted
# estimators created in the benchmark training cell.
models = results_sorted_acc.index
test_acc = results_sorted_acc['test_acc']
colors = plt.cm.viridis(np.linspace(0, 1, len(models)))
bars1 = ax1.barh(models, test_acc, color=colors, height=0.6)
ax1.set_xlabel('Test Accuracy', **label_font)
ax1.set_title('Classification: Test Accuracy by Model', **title_font)
ax1.set_xlim([0.75, 1.0])
for i, val in enumerate(test_acc):
    # Value label just right of the bar end, on the bar's row
    ax1.text(val + 0.005, i, f'{val:.4f}', **value_font)

# 2. Test F1 Score Comparison
models_f1 = results_sorted_f1.index
test_f1 = results_sorted_f1['test_f1']
colors_f1 = plt.cm.plasma(np.linspace(0, 1, len(models_f1)))
bars2 = ax2.barh(models_f1, test_f1, color=colors_f1, height=0.6)
ax2.set_xlabel('Test F1 Score', **label_font)
ax2.set_title('Classification: Test F1 Score by Model', **title_font)
ax2.set_xlim([0.7, 1.0])
for i, val in enumerate(test_f1):
    ax2.text(val + 0.005, i, f'{val:.4f}', **value_font)

# 3. Training Time Comparison
models_time = results_sorted_time.index
train_time_vals = results_sorted_time['train_time']
colors_time = plt.cm.cool(np.linspace(0, 1, len(models_time)))
bars3 = ax3.barh(models_time, train_time_vals, color=colors_time, height=0.6)
ax3.set_xlabel('Training Time (seconds)', **label_font)
ax3.set_title('Classification: Training Time by Model', **title_font)
for i, val in enumerate(train_time_vals):
    ax3.text(val + 0.02, i, f'{val:.3f}s', **value_font)

# Common cosmetics for the three bar panels
for bar_ax in (ax1, ax2, ax3):
    bar_ax.grid(axis='x', alpha=0.3, linestyle='--')
    bar_ax.tick_params(axis='both', labelsize=10)

# 4. Precision-Recall Trade-off (one marker per model, coloured by F1)
test_prec = results_df_cls['test_prec']
test_rec = results_df_cls['test_rec']
scatter = ax4.scatter(
    test_rec, test_prec, c=results_df_cls['test_f1'],
    s=300, cmap='viridis', alpha=0.8, edgecolors='black', linewidth=2,
)
label_box = dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.7, edgecolor='gray')
for i, model in enumerate(results_df_cls.index):
    # Alternate small offsets so neighbouring labels overlap less
    offset_x = 0.002 if i % 2 else -0.002
    offset_y = -0.002 if i % 3 else 0.002
    label_xy = (test_rec.iloc[i] + offset_x, test_prec.iloc[i] + offset_y)
    ax4.annotate(model, label_xy, fontsize=9, ha='center', va='bottom',
                 fontweight='bold', bbox=label_box)
ax4.set_xlabel('Test Recall', **label_font)
ax4.set_ylabel('Test Precision', **label_font)
ax4.set_title('Classification: Precision-Recall Trade-off (color=F1 Score)', **title_font)
ax4.grid(True, alpha=0.3, linestyle='--')
cbar = plt.colorbar(scatter, ax=ax4)
cbar.set_label('F1 Score', fontsize=11, fontweight='bold')
cbar.ax.tick_params(labelsize=10)
ax4.set_xlim([0.6, 1.01])
ax4.set_ylim([0.8, 1.01])
ax4.tick_params(axis='both', labelsize=10)

# Zoomed inset inside the precision-recall panel, magnifying the crowded
# high-recall / high-precision corner
ax_inset = inset_axes(ax4, width="45%", height="45%", loc='lower left',
                      bbox_to_anchor=(0.2, 0.35, 1, 1),
                      bbox_transform=ax4.transAxes)

# Draw the markers exactly once. Issuing this scatter call twice stacks two
# alpha=0.8 layers, so the inset colours no longer match the main panel
# and its colorbar.
ax_inset.scatter(test_rec, test_prec, c=results_df_cls['test_f1'],
                 s=300, cmap='viridis', alpha=0.8, edgecolors='black', linewidth=2)

# Hand-tuned (dx, dy) label offsets in data units; models without an entry
# are not labelled in the inset
offsets = {
    'K-Nearest Neighbors': (0.00, -0.001),
    'Decision Tree':       (0.0, 0.004),
    'Random Forest':       (-0.0005, -0.0036),
    'Support Vector Machine': (0.008, -0.004),
    'XGBoost':             (-0.0025, 0.003),
    'LightGBM':            (0, 0.004),
    'CatBoost':            (0.00, 0.004),
    'Neural Network (Flax)': (0.0, 0.001)
}

for i, model in enumerate(results_df_cls.index):
    if model not in offsets:
        continue
    offset_x, offset_y = offsets[model]
    ax_inset.annotate(model,
                      (test_rec.iloc[i] + offset_x, test_prec.iloc[i] + offset_y),
                      fontsize=8, ha='center', va='center')

# Data window shown by the inset
x1, x2, y1, y2 = 0.93, 0.975, 0.94, 0.99
ax_inset.set_xlim(x1, x2)
ax_inset.set_ylim(y1, y2)
ax_inset.grid(True, alpha=0.5, linestyle=':')
ax_inset.tick_params(axis='both', labelsize=8)

# Outline the zoomed region on ax4 and connect it to the inset
mark_inset(ax4, ax_inset, loc1=2, loc2=1, fc="none", ec="0.5")

plt.tight_layout(pad=3.0)
plt.savefig('classification_all_models_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

print("\nComparison visualization saved to 'classification_all_models_comparison.png'")

```
==========================================================================================
ALL CLASSIFICATION MODELS COMPARISON (Test Set)
==========================================================================================
                        train_time  test_acc  test_prec  test_rec  test_f1  test_brier
QDA (JAX)                   0.9934    0.8271     0.8207    0.7714   0.7953      0.1301
K-Nearest Neighbors         0.0031    0.9776     0.9798    0.9686   0.9741      0.0147
Decision Tree               0.0053    0.9614     0.9518    0.9600   0.9559      0.0339
Random Forest               0.3540    0.9776     0.9770    0.9714   0.9742      0.0195
Naive Bayes                 0.0019    0.8072     0.8854    0.6400   0.7430      0.1701
Support Vector Machine      0.9255    0.9577     0.9675    0.9343   0.9506      0.0289
XGBoost                     0.0589    0.9764     0.9825    0.9629   0.9726      0.0180
LightGBM                    0.0636    0.9776     0.9826    0.9657   0.9741      0.0177
CatBoost                    0.1294    0.9677     0.9709    0.9543   0.9625      0.0265
Neural Network (Flax)      12.0536    0.9776     0.9798    0.9686   0.9741      0.0189

==========================================================================================
Best Model by Accuracy: K-Nearest Neighbors (Accuracy: 0.9776)
Best Model by F1 Score: Random Forest (F1: 0.9742)
Fastest Training: Naive Bayes (Time: 0.0019s)
==========================================================================================
```

<table>
  <tr>
    <td><img src="Comparison.png" width="1200"></td>
  </tr>
  
</table>