In [None]:
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

# Make the project packages importable BEFORE importing them: the original cell ran
# sys.path.insert after the project imports, so on a fresh kernel they could not be found.
# NOTE(review): assumes the kernel's working directory is the notebook's folder and the
# project root is its parent — confirm if the notebook is moved.
project_root = Path("..").resolve()
if str(project_root) not in sys.path:  # keep the cell idempotent on re-run
    sys.path.insert(0, str(project_root))

from agents.Q_agent import Qlearning_agent
from agents.SARSA_agent import SARSAAgent
from environnement.cournot_env import CournotEnv
from experiments.config import ENV_CONFIG, TRAINING_CONFIG, EVAL_CONFIG
from experiments.evaluate import evaluate
from experiments.run_baselines import run_baselines
from experiments.train import train

In [None]:
# Reference profits from the project's baseline runs (run_baselines()).
baseline_rewards = run_baselines()
# `.mean()` with no axis averages over every element (assuming baseline_rewards is a
# NumPy array — TODO confirm against run_baselines).
# NOTE(review): baseline_avg_profit is not used by any later cell visible here; it was
# presumably meant as a reference line on the learning-curve plot. Check that its
# averaging matches the per-firm, per-episode curves before plotting it there.
baseline_avg_profit = baseline_rewards.mean()

In [None]:
# Train the SARSA agents, then evaluate the learned policies.
# NOTE(review): SARSA is given n_actions=1 while the Q-learning cell passes only q_max —
# confirm against SARSAAgent's constructor that this asymmetry is intended.
episode_rewards_sarsa, agents_sarsa = train(
    agent_class=SARSAAgent,
    agent_kwargs=dict(n_actions=1, q_max=ENV_CONFIG["q_max"]),
)
eval_rewards_sarsa = evaluate(agents_sarsa)

In [None]:
# Train the Q-learning agents with the same environment bound (q_max), then evaluate them.
episode_rewards_q, agents_q = train(
    agent_class=Qlearning_agent,
    agent_kwargs=dict(q_max=ENV_CONFIG["q_max"]),
)
eval_rewards_q = evaluate(agents_q)

In [None]:
# Average each episode's rewards across axis 1 (presumably the firms, i.e. the episode
# rewards are shaped (n_episodes, n_firms) — TODO confirm against experiments.train).
avg_sarsa = np.mean(episode_rewards_sarsa, axis=1)
avg_q = np.mean(episode_rewards_q, axis=1)

# Overlay both learning curves on a single explicit Axes.
fig, ax = plt.subplots()
ax.plot(avg_sarsa, label="SARSA")
ax.plot(avg_q, label="Q-learning")
ax.set(
    xlabel="Episode",
    ylabel="Average profit per firm",
    title="Learning curves comparison",
)
ax.legend()
ax.grid(True)
plt.show()

In [None]:
def rollout_episode(env, agents, greedy=True):
    """
    Roll out one episode and record states, actions, rewards and prices.

    Parameters
    ----------
    env :
        Environment exposing ``reset()`` -> state and
        ``step(actions)`` -> ``(next_state, rewards, done, info)``, where
        ``info["price"]`` is recorded for every step.
    agents : sequence
        One agent per firm. Each must implement ``select_action(state)`` returning
        an indexable whose first element is used as that agent's action.
    greedy : bool, default True
        If True, every agent that has an ``epsilon`` attribute runs with
        ``epsilon = 0.0`` for the duration of the episode. The original values are
        restored afterwards, even if the rollout raises.

    Returns
    -------
    states, actions, rewards, prices : np.ndarray
        Per-step records stacked along the first axis (time step).
        ``states[t]`` is the state observed before acting at step ``t``.
    """
    # One entry per agent whose epsilon was changed; None marks agents without epsilon.
    old_eps = []

    states = []
    actions_hist = []
    rewards_hist = []
    prices = []

    try:
        if greedy:
            for ag in agents:
                if hasattr(ag, "epsilon"):
                    old_eps.append(ag.epsilon)
                    ag.epsilon = 0.0
                else:
                    old_eps.append(None)

        state = env.reset()
        done = False

        while not done:
            # Defensive copies so the recorded history is detached from objects
            # the env/agents may keep reusing.
            states.append(state.copy())

            actions = np.array([ag.select_action(state)[0] for ag in agents])
            actions_hist.append(actions.copy())

            next_state, rewards, done, info = env.step(actions)
            rewards_hist.append(rewards.copy())
            prices.append(info["price"])

            state = next_state
    finally:
        # Always restore exploration. Otherwise an exception mid-rollout would
        # silently leave the agents greedy for any later training/evaluation.
        # zip() stops at len(old_eps), so only agents actually modified are touched.
        for ag, eps in zip(agents, old_eps):
            if eps is not None:
                ag.epsilon = eps

    return (np.array(states), np.array(actions_hist), np.array(rewards_hist), np.array(prices))


In [None]:
# Compare one greedy rollout of the selected agents against the Nash benchmark.
# NOTE(review): these imports belong in the top imports cell so all notebook
# dependencies are declared in one place.
from analysis.metrics import nash_quantities, nash_price, distance_to_nash
from analysis.plotting import (
    plot_quantities_vs_nash,
    plot_price_vs_nash,
    plot_distance_to_nash,
    plot_distance_to_nash_rolling
)

# Fresh environment built from ENV_CONFIG for this analysis.
env = CournotEnv(**ENV_CONFIG)

# Choose which trained agents you want to analyze:
# NOTE: the distance-distribution cell further down also reads this `agents` variable.
agents = agents_sarsa   # or agents_q

# Single episode with epsilon forced to 0 (greedy=True) inside rollout_episode.
states, actions_hist, rewards_hist, prices = rollout_episode(env, agents, greedy=True)

# Nash benchmark computed from the same ENV_CONFIG that built `env`.
q_nash = nash_quantities(ENV_CONFIG)
p_nash = nash_price(ENV_CONFIG)

# Distance to Nash based on actions (quantities actually chosen at time t)
dist = distance_to_nash(actions_hist, q_nash, metric="l2")

print("Nash quantities:", q_nash)
print("Nash price:", p_nash)
print("Mean L2 distance to Nash:", dist.mean())

plot_quantities_vs_nash(actions_hist, q_nash, title="Quantities (chosen) vs Nash")
plot_price_vs_nash(prices, p_nash, title="Price vs Nash")
plot_distance_to_nash(dist, title="Distance to Nash (per step)")
# window=10: smoothing window passed to the rolling plot (presumably in steps — confirm).
plot_distance_to_nash_rolling(dist, window=10, title="Distance to Nash")


In [None]:
def eval_distance_distribution(agents, n_eval_episodes=50, metric="l2", env_config=None, greedy=True):
    """
    Per-episode mean distance between the chosen quantities and the Nash quantities.

    Parameters
    ----------
    agents : sequence
        Trained agents, passed unchanged to ``rollout_episode``.
    n_eval_episodes : int, default 50
        Number of evaluation episodes to roll out.
    metric : str, default "l2"
        Distance metric forwarded to ``distance_to_nash``.
    env_config : dict, optional
        Environment parameters; defaults to ``ENV_CONFIG``. The same dict builds the
        environment and the Nash benchmark, so the two cannot drift apart.
    greedy : bool, default True
        Forwarded to ``rollout_episode`` (True disables epsilon-exploration).

    Returns
    -------
    np.ndarray
        One value per episode: that episode's per-step distances averaged over time.
        Empty when ``n_eval_episodes`` is 0.

    NOTE(review): with greedy agents, if ``env.reset()`` and the agents are deterministic
    every episode is identical and the distribution collapses to one value — confirm
    the environment is stochastic.
    """
    config = ENV_CONFIG if env_config is None else env_config
    env = CournotEnv(**config)
    q_nash = nash_quantities(config)

    episode_means = []
    for _ in range(n_eval_episodes):
        _, actions_hist, _, _ = rollout_episode(env, agents, greedy=greedy)
        dist = distance_to_nash(actions_hist, q_nash, metric=metric)
        episode_means.append(dist.mean())

    return np.array(episode_means)

# Distribution of per-episode Nash distances for the agents selected in the analysis cell.
dists = eval_distance_distribution(agents, n_eval_episodes=50)

fig, ax = plt.subplots()
ax.hist(dists, bins=15)
ax.set(
    xlabel="Mean L2 distance to Nash (episode)",
    ylabel="Count",
    title="Distribution of Nash distances across evaluation episodes",
)
ax.grid(True)
plt.show()

# Summary statistics (NumPy default std, i.e. population std with ddof=0).
print("Average mean distance:", dists.mean())
print("Std:", dists.std())
