# In [None]:
from Env.MariNav import *
import networkx as nx
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sb3_contrib import MaskablePPO
import h3
from utils import *

import os
import csv
import time
from collections.abc import Mapping
import json
from collections import OrderedDict

# ========= User Config =========
# --- Input data ---
# Visit graph (GEXF, loaded with networkx below) and the wind map CSV.
GRAPH_PATH = "../wind_and_graph_2024/GULF_VISITS_cargo_tanker_2024_merged.gexf"
WIND_MAP_PATH = "../wind_and_graph_2024/2024_august_wind_data.csv"

# --- Checkpoints to evaluate (one per-model CSV is written for each) ---
MODEL_PATHS = [
    "model_step_10000000_4.zip",
    "model_step_10000000_42.zip",
    "model_step_10000000_31.zip",
]

# --- Output locations ---
CSV_DIR = "eval_csvs"  # per-model CSVs will be stored here
MERGED_CSV_OUT = os.path.join(
    CSV_DIR,
    "evaluation_info_logs_all_models.csv",
)

# --- Environment / rollout settings ---
# H3 cell pairs handed to MariNav as `pairs` (presumably (start, goal) — confirm in MariNav).
PAIR_LIST = [("861ab6847ffffff", "860e4daafffffff")]

EPISODES = 5          # episodes evaluated per checkpoint
MAX_STEPS = 1000      # hard cap on steps per episode
H3_RES = 6            # forwarded to MariNav as h3_resolution
WIND_THRESHOLD = 22   # forwarded to MariNav as wind_threshold
RENDER_MODE = "human" # forwarded to MariNav as render_mode

# ========= Helpers =========
def flatten_dict(d, parent_key="", sep="."):
    """Flatten nested mappings into a single-level dict using joined keys.

    Nested keys are joined to their parent with ``sep`` (dot notation by
    default), e.g. ``{"a": {"b": 1}}`` -> ``{"a.b": 1}``. Any
    ``collections.abc.Mapping`` value (not only ``dict``) is descended into.
    Keys are converted to ``str``; a nested empty mapping contributes no keys.
    If two paths produce the same flat key, the later value wins.

    Args:
        d: Mapping to flatten; ``None`` (or any falsy value) is treated as empty.
        parent_key: Prefix prepended to every key produced at the top level.
        sep: Separator placed between a parent key and its child key.

    Returns:
        A new flat ``dict`` in the mapping's iteration order.
    """
    flat = {}

    def _walk(mapping, prefix):
        # Accumulate into one dict instead of building and copying a dict per level.
        for k, v in (mapping or {}).items():
            key = f"{prefix}{sep}{k}" if prefix else str(k)
            if isinstance(v, Mapping):
                _walk(v, key)
            else:
                flat[key] = v

    _walk(d, parent_key)
    return flat

def model_tag_from_path(path):
    """Return a short tag from the model filename (e.g., 'model_step_10000000_42').

    The tag is the file's base name with its final extension removed, so
    ``"runs/model_step_10000000_42.zip"`` -> ``"model_step_10000000_42"``.
    In this script it is used as the ``model_tag`` column value, in the
    per-model CSV filename, and as the key of the summary dict.
    """
    # os.path.splitext returns (root, ext); only the root is the tag.
    return os.path.splitext(os.path.basename(path))[0]

# ========= Load graph and wind map =========
G_visits = nx.read_gexf(GRAPH_PATH).to_undirected()

# Prefer the project's wind-map loader (presumably supplied by one of the star
# imports). Look it up explicitly instead of wrapping the call in
# `except NameError`: that would also swallow a NameError raised *inside* the
# loader (a real bug) and silently fall back to the raw CSV.
_wind_loader = globals().get("load_full_wind_map")
if callable(_wind_loader):
    full_wind_map = _wind_loader(WIND_MAP_PATH)  # Provided by project
else:
    # Fallback: raw DataFrame if helper not in scope (works if MariNav accepts it)
    full_wind_map = pd.read_csv(WIND_MAP_PATH)

# ========= Evaluation routine for a single model =========
def _normalize_action(action):
    """Return a CSV-friendly form of a policy action.

    Single-element actions become ``int`` when their dtype is integer or bool
    and ``float`` otherwise (so a fractional action is not truncated);
    multi-element actions are serialized as a JSON list.
    """
    arr = np.asarray(action)
    if arr.size != 1:
        return json.dumps(arr.tolist())
    # .item() avoids NumPy's deprecated int()/float() on arrays with ndim > 0.
    value = arr.item()
    if np.issubdtype(arr.dtype, np.integer) or np.issubdtype(arr.dtype, np.bool_):
        return int(value)
    return float(value)


def _cell_lat_lon(cell):
    """Return ``(lat, lon)`` for an H3 cell, or ``(None, None)`` if unavailable.

    Best-effort: a falsy cell or any error from ``h3.cell_to_latlng`` yields
    ``(None, None)`` instead of aborting the evaluation run.
    """
    if not cell:
        return None, None
    try:
        lat, lon = h3.cell_to_latlng(cell)
    except Exception:  # deliberate best-effort: step logging must not stop evaluation
        return None, None
    return lat, lon


def _float_or_default(value, default=-1.0):
    """Convert ``value`` to float for console display; ``default`` if None or unconvertible."""
    if value is None:
        return default
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        return default


def _write_rows_csv(rows, csv_path):
    """Write step records to ``csv_path`` and return the number of rows written.

    Columns are the union of all row keys in first-seen order; keys missing
    from a row are left empty in that row.
    """
    # dict.fromkeys keeps insertion order and gives O(1) de-duplication.
    columns = list(dict.fromkeys(key for row in rows for key in row))
    df = pd.DataFrame(rows, columns=columns)
    df.to_csv(csv_path, index=False)
    return len(df)


def _run_episode(eval_env, model, model_tag, episode):
    """Roll out one episode of ``model`` in ``eval_env``; return ``(episode_return, step_rows)``.

    The episode stops after MAX_STEPS steps or as soon as the env reports
    ``done`` or ``truncated``. Each step yields one row: fixed base fields
    followed by the flattened ``info`` dict (info keys overwrite base fields
    that share a name).
    """
    obs, _ = eval_env.reset()
    start_h3 = getattr(eval_env, "start_h3", None)
    goal_h3 = getattr(eval_env, "goal_h3", None)
    print(f"\n🚀 Starting Episode {episode} from H3 {start_h3} to {goal_h3} | model={model_tag}")

    rows = []
    ep_ret = 0.0
    for step in range(MAX_STEPS):
        # Predict using action masks from env (required for MaskablePPO)
        action_masks = eval_env.action_masks()
        # NOTE(review): deterministic=False samples from the policy distribution,
        # so returns vary between runs; use True for greedy evaluation.
        action, _ = model.predict(obs, deterministic=False, action_masks=action_masks)
        obs, reward, done, truncated, info = eval_env.step(action)
        reward = float(reward)
        ep_ret += reward

        current_h3 = getattr(eval_env, "current_h3", None)
        lat, lon = _cell_lat_lon(current_h3)
        rec = {
            "timestamp": time.time(),
            "model_tag": model_tag,
            "episode": episode,
            "step": step,
            "start_h3": start_h3,
            "goal_h3": goal_h3,
            "current_h3": current_h3,
            "action": _normalize_action(action),
            "reward": reward,
            "done": bool(done),
            "truncated": bool(truncated),
            "lat": lat,
            "lon": lon,
        }

        # Flatten info and merge
        info_flat = flatten_dict(info if isinstance(info, dict) else {})
        rows.append({**rec, **info_flat})

        # Optional console print for checkpoints
        if step % 1000 == 0 or done:
            speed_val = _float_or_default(info_flat.get("speed", -1.0))
            wind_val = _float_or_default(info_flat.get("wind_direction", -1.0))
            print(
                f"🧭 Step {step} | H3: {current_h3} | "
                f"Speed: {speed_val:.2f} knots | Wind Dir: {wind_val:.2f} | Reward: {reward:.2f}"
            )
            print(f"info: {info}")

        if truncated or done:
            break

    return ep_ret, rows


def evaluate_one_model(model_path):
    """
    Run EPISODES for one model checkpoint, write a per-model CSV, and return per-episode returns.

    A fresh ``MariNav`` environment is built from the module-level config
    (PAIR_LIST, G_visits, full_wind_map, H3_RES, WIND_THRESHOLD, RENDER_MODE),
    the checkpoint is loaded with ``MaskablePPO.load`` bound to that env, and
    every step of every episode is logged to
    ``CSV_DIR/evaluation_info_logs_<model_tag>.csv``. The environment is closed
    once the rollouts finish, even if loading or a rollout raises.

    Args:
        model_path: Path to a MaskablePPO checkpoint file.

    Returns:
        Tuple ``(episode_returns, per_model_csv)``: the summed reward of each
        episode (list of float, in episode order) and the path of the CSV written.

    Raises:
        RuntimeError: If constructing ``MariNav`` raises ``NameError``
            (e.g. the class is not imported).
    """
    # Create fresh environment for each model
    try:
        eval_env = MariNav(
            pairs=PAIR_LIST,
            graph=G_visits,
            wind_map=full_wind_map,
            h3_resolution=H3_RES,
            wind_threshold=WIND_THRESHOLD,
            render_mode=RENDER_MODE,
        )
    except NameError as e:
        raise RuntimeError(
            "MariNav class is not available in current environment. Import it before running."
        ) from e

    model_tag = model_tag_from_path(model_path)
    rows = []
    episode_returns = []
    try:
        # Load model with this env
        model = MaskablePPO.load(model_path, env=eval_env)

        print(f"\n===== Evaluating {model_tag} =====")
        for episode in range(1, EPISODES + 1):
            ep_ret, ep_rows = _run_episode(eval_env, model, model_tag, episode)
            rows.extend(ep_rows)
            episode_returns.append(ep_ret)
            print(f"✅ Episode {episode} return: {ep_ret:.6f} | model={model_tag}")
    finally:
        # Release the env (it was created with a render mode) even on failure;
        # guarded in case MariNav does not define close().
        close = getattr(eval_env, "close", None)
        if callable(close):
            close()

    # ========= Write per-model CSV =========
    os.makedirs(CSV_DIR, exist_ok=True)
    per_model_csv = os.path.join(CSV_DIR, f"evaluation_info_logs_{model_tag}.csv")
    n_rows = _write_rows_csv(rows, per_model_csv)
    print(f"Wrote {n_rows} rows to {per_model_csv}")

    return episode_returns, per_model_csv

# ========= Run all models, merge CSVs, and print averaged final return =========
# all_episode_returns: model tag -> per-episode returns; all_csv_paths: CSVs in evaluation order.
all_episode_returns = {}
all_csv_paths = []

for mp in MODEL_PATHS:
    ep_returns, csv_path = evaluate_one_model(mp)
    all_csv_paths.append(csv_path)
    all_episode_returns[model_tag_from_path(mp)] = ep_returns

# Averaged final episode return across models (models with no episodes are skipped).
final_returns = [returns[-1] for returns in all_episode_returns.values() if returns]
if final_returns:
    avg_final_return = float(np.mean(final_returns))
else:
    avg_final_return = float("nan")

print("\n================ Summary ================")
for tag, rets in all_episode_returns.items():
    rounded = np.round(np.asarray(rets), 6).tolist()
    print(f"{tag}: episode returns = {rounded}")
print(f"\nAveraged final episode return across {len(final_returns)} models: {avg_final_return:.6f}")

# Optional: merge per-model CSVs into one.
# Best-effort: any failure (missing/empty file, write error) is reported and skipped.
try:
    merged_frames = []
    merged_frames.extend(pd.read_csv(path_) for path_ in all_csv_paths)
    if merged_frames:
        merged_df = pd.concat(merged_frames, ignore_index=True)
        merged_df.to_csv(MERGED_CSV_OUT, index=False)
        print(f"Wrote merged CSV with {len(merged_df)} rows to {MERGED_CSV_OUT}")
except Exception as e:
    print(f"Skipping merged CSV due to error: {e}")
