In [None]:
import os
import csv
import gymnasium as gym
from stable_baselines3 import DQN
import requests
import json
import numpy as np

#OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY", "API_KEY")

def short_specific_prompt(state_before, state_after, action_name, env_description=None):
    """Build the concise failure-analysis prompt for a Frozen Lake episode.

    Args:
        state_before: Text map of the state before the last action (in this
            script, produced by ``get_state_description``).
        state_after: Text map of the final state, same format.
        action_name: Name of the last action taken, e.g. ``"LEFT"``.
        env_description: Ignored. Accepted so this function can be called with
            the same arguments as ``long_detailed_prompt``.

    Returns:
        The prompt string. The map legend is spelled out inline rather than
        relying on an environment description.
    """
    # Keep the legend ASCII-only: the apostrophe previously arrived as
    # mis-decoded bytes and was sent verbatim to the LLMs.
    return (
        "Provide a short and precise analysis of why the agent failed in the Frozen Lake environment.\n"
        "Map details: S = Start, F = Frozen, H = Hole, G = Goal, X = Agent's position\n\n"
        f"Second to last state:\n{state_before}\n\n"
        f"Last action: {action_name}\n\n"
        f"Last state:\n{state_after}\n\n"
        "Please include specific factors that contributed to this failure."
    )

def long_detailed_prompt(state_before, state_after, action_name, env_description):
    """Build the detailed failure-analysis prompt for a Frozen Lake episode.

    Args:
        state_before: Text map of the state before the last action (in this
            script, produced by ``get_state_description``).
        state_after: Text map of the final state, same format.
        action_name: Name of the last action taken, e.g. ``"LEFT"``.
        env_description: Free-text environment description embedded verbatim
            (this script passes the environment's class docstring).

    Returns:
        The prompt string.
    """
    # The "X" agent marker is added by get_state_description, not by the
    # environment. env_description is not guaranteed to explain it, so the
    # legend is stated explicitly, as in short_specific_prompt.
    return (
        "Provide a detailed analysis of the agent's failure in the Frozen Lake environment.\n"
        f"Environment details:\n{env_description}\n\n"
        "Map legend: S = Start, F = Frozen, H = Hole, G = Goal, X = Agent's position\n\n"
        f"Second to last state:\n{state_before}\n\n"
        f"Last action: {action_name}\n\n"
        f"Last state:\n{state_after}\n\n"
        "Please include specific factors that contributed to this failure."
    )

def generate_explanation(model_id, prompt, api_key=None):
    """Send ``prompt`` to an OpenRouter chat model and return its reply text.

    Args:
        model_id: OpenRouter model identifier, e.g. ``"openai/gpt-5-chat"``.
        prompt: Content of the single user message.
        api_key: OpenRouter API key. When omitted, a module-level
            ``OPENROUTER_API_KEY`` is used if one is defined; otherwise the
            ``OPENROUTER_API_KEY`` environment variable is used.

    Returns:
        The stripped content of the first completion choice. This function
        never raises for missing credentials, transport errors, non-200
        statuses or malformed responses. Each of those is returned as a string
        starting with ``"Error:"``, so a long batch run keeps going and the
        failure is recorded alongside the other results.
    """
    if api_key is None:
        # Look the key up lazily: a module-level OPENROUTER_API_KEY may not be
        # defined, and referencing it directly would raise NameError.
        api_key = globals().get("OPENROUTER_API_KEY") or os.environ.get("OPENROUTER_API_KEY")
    if not api_key:
        return "Error: OPENROUTER_API_KEY is not set"

    url = "https://openrouter.ai/api/v1/chat/completions"
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    data = {
        "model": model_id,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.7,
        "top_p": 0.9,
    }
    try:
        response = requests.post(url, headers=headers, json=data, timeout=30)
    except requests.exceptions.RequestException as e:
        return f"Error: Request failed: {e}"
    if response.status_code != 200:
        return f"Error: HTTP {response.status_code} - {response.text[:200]}"
    # A 200 body can still be non-JSON or lack "choices"; report it instead of
    # letting the exception abort the caller's loop.
    try:
        content = response.json()["choices"][0]["message"]["content"]
    except (ValueError, KeyError, IndexError, TypeError) as e:
        return f"Error: Unexpected response format: {e!r} - {response.text[:200]}"
    if content is None:
        return "Error: Empty completion content"
    return content.strip()

def get_state_description(env, s):
    """Render the map as text, with the agent's cell replaced by ``"X"``.

    Tiles in ``env.unwrapped.desc`` that are ``bytes`` are decoded as UTF-8.
    ``s`` is a flat state index, split into (row, column) using
    ``env.unwrapped.ncol``. Rows are joined with newlines.
    """
    base = env.unwrapped
    rows = []
    for raw_row in base.desc.tolist():
        cells = []
        for cell in raw_row:
            if isinstance(cell, bytes):
                cell = cell.decode("utf-8")
            cells.append(cell)
        rows.append(cells)
    row_idx, col_idx = divmod(int(s), base.ncol)
    rows[row_idx][col_idx] = "X"
    return "\n".join(map("".join, rows))

# Discrete action index -> direction name used in prompts, references and the CSV.
ACTION_MAP = dict(enumerate(("LEFT", "DOWN", "RIGHT", "UP")))

def generate_reference(env, state, action):
    """Return a template sentence describing how an episode ended.

    The tile at flat index ``state`` (decoded from bytes if necessary) picks
    the sentence: ``"H"`` gives the fell-into-a-hole text, ``"G"`` the
    reached-the-goal text, and any other tile the did-not-reach-the-goal text.
    ``action`` is converted to a name through ``ACTION_MAP``.

    NOTE(review): the environment here is created with ``is_slippery=True``.
    If that makes the chosen action differ from the direction actually moved,
    "after moving {action}" may overstate what happened. Confirm before using
    these sentences as ground truth.
    """
    base = env.unwrapped
    r, c = divmod(int(state), base.ncol)
    cell = base.desc[r][c]
    if isinstance(cell, bytes):
        cell = cell.decode("utf-8")
    direction = ACTION_MAP[int(action)]
    outcomes = {
        "H": f"The agent fell into a hole after moving {direction} from the state shown.",
        "G": f"The agent reached the goal after moving {direction} from the state shown.",
    }
    fallback = f"The agent moved {direction} from the state shown, but did not reach the goal."
    return outcomes.get(cell, fallback)

env = gym.make("FrozenLake-v1", is_slippery=True)
# Class docstring of the underlying environment; embedded in the LongDetailed prompt.
env_description = env.unwrapped.__doc__

# Train the policy whose failed episodes are explained below.
model = DQN("MlpPolicy", env, verbose=1, device="cpu")
model.learn(total_timesteps=100_000)

# (OpenRouter model id, label written to the CSV "LLM" column)
models = [
    ("openai/gpt-5-chat", "GPT-5"),
    ("meta-llama/llama-3.2-11b-vision-instruct", "Llama-3"),
    ("deepseek/deepseek-r1-0528", "DeepSeek"),
]

episodes = 100
results = []

# Roll out the trained policy and ask every LLM, with both prompt styles,
# to explain each episode that ends without reaching the goal.
for ep in range(episodes):
    obs, _ = env.reset()
    current_state = int(obs)
    # Each entry: (state, action, next_state, reward, terminated, truncated).
    transitions = []
    done = False
    while not done:
        action_raw, _ = model.predict(current_state, deterministic=False)
        # predict() may return a NumPy array rather than a Python int. .item()
        # extracts the single element (and raises if there is more than one),
        # which avoids calling int() on an array directly.
        action = int(np.asarray(action_raw).item())
        new_obs, reward, terminated, truncated, _ = env.step(action)
        new_state = int(new_obs)
        transitions.append((current_state, action, new_state, float(reward), bool(terminated), bool(truncated)))
        current_state = new_state
        done = terminated or truncated

    # The while loop always runs at least once and only exits after a
    # terminated/truncated step, so a zero final reward alone marks a failure.
    if transitions[-1][3] != 0:
        continue

    # The last transition holds both the state before the final action and the
    # state it led to, so episodes that fail on their very first step are
    # explained too.
    before_state_idx, last_action, last_state_idx = transitions[-1][:3]
    state_before_text = get_state_description(env, before_state_idx)
    state_after_text = get_state_description(env, last_state_idx)
    action_name = ACTION_MAP[last_action]
    reference_text = generate_reference(env, last_state_idx, last_action)

    prompts = (
        ("ShortSpecific", short_specific_prompt(state_before_text, state_after_text, action_name)),
        ("LongDetailed", long_detailed_prompt(state_before_text, state_after_text, action_name, env_description)),
    )
    for prompt_name, prompt in prompts:
        for model_id, model_name in models:
            explanation = generate_explanation(model_id, prompt)
            results.append([
                ep + 1, model_name, prompt_name,
                prompt, reference_text,
                state_before_text, state_after_text, action_name,
                explanation,
            ])

# Write one row per (failed episode, LLM, prompt type).
# Explicit UTF-8: LLM explanations may contain non-ASCII characters that the
# platform's default encoding cannot represent, which would raise
# UnicodeEncodeError after all the API calls have already been made.
with open("FL_LLM_Explanations.csv", "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow([
        "Episode", "LLM", "PromptType",
        "Prompt", "Reference",
        "SecondToLastStateMap", "LastStateMap", "LastAction",
        "Explanation",
    ])
    writer.writerows(results)