In [6]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from env.bertrand_env import BertrandPricingEnv

In [7]:
class QLearningAgent:
    """
    Greedy, stateless Q-learning agent over a discrete price grid.

    The agent stores one value estimate W(p) per price index in ``self.Q`` and
    always plays the arg-max. Three update modes are available:

      - 'async': only the played action is updated, using the realized reward.
      - 'sync_perfect': every action is updated with its exact counterfactual
        profit against the competitor's observed price.
      - 'sync_downward': the played action is updated as in 'async'; every
        other action is updated only when the downward-demand assumption
        bounds its profit (higher prices cannot beat the realized profit,
        lower prices sell at least the realized quantity).

    Demand rule used for counterfactuals: the strictly cheaper firm sells 1,
    a tie sells 0.5 each, the dearer firm sells 0, and any price above
    ``price_cap`` sells 0.

    Parameters
    ----------
    n_actions : number of discrete prices.
    alpha : learning rate used in W <- alpha * target + (1 - alpha) * W.
    init_low, init_high : W(p) is initialized i.i.d. U[init_low, init_high]
        using NumPy's global random state.
    update_type : one of 'async', 'sync_perfect', 'sync_downward'.
    cost : marginal cost used to turn a quantity into a profit.
    prices : price grid indexed by action; required for the two sync modes,
        optional for 'async'.
    price_cap : highest price at which demand is still positive.

    Raises
    ------
    ValueError
        If ``prices`` is given and its length differs from ``n_actions``.
    """

    _UPDATE_TYPES = ('async', 'sync_perfect', 'sync_downward')

    def __init__(self,
                 n_actions: int,
                 alpha: float = 0.1,
                 init_low: float = 10.0,
                 init_high: float = 20.0,
                 update_type: str = 'async',
                 cost: float = 2.0,
                 prices: np.ndarray = None,
                 price_cap: float = 10.0):
        # A grid whose length disagrees with the Q-vector can only lead to
        # index errors or silently partial updates later on; fail early.
        if prices is not None and len(prices) != n_actions:
            raise ValueError(
                f"prices has {len(prices)} entries but n_actions is {n_actions}"
            )
        self.n_actions = n_actions
        self.alpha = alpha
        # initialize W(p) ~ U[init_low, init_high]
        self.Q = np.random.uniform(init_low, init_high, size=n_actions)
        self.update_type = update_type
        self.cost = cost
        self.prices = prices  # needed for synchronous updates
        self.price_cap = price_cap

    def select_action(self) -> int:
        """Return the index of the highest current estimate (first one on ties)."""
        # greedy policy (no explicit exploration beyond the optimistic init)
        return int(np.argmax(self.Q))

    def _demand(self, own_price, rival_price):
        """Quantity sold at ``own_price`` (scalar or array) against ``rival_price``."""
        own = np.asarray(own_price, dtype=float)
        share = np.where(own < rival_price, 1.0,
                         np.where(own == rival_price, 0.5, 0.0))
        # Nothing is sold above the price cap, whatever the rival charges.
        return np.where(own <= self.price_cap, share, 0.0)

    def update(self,
               action: int,
               reward: float,
               competitor_action: int):
        """
        Update the value estimates after one period.

        Parameters
        ----------
        action : index of the price this agent played.
        reward : realized profit for ``action``.
        competitor_action : index of the price the competitor played
            (ignored in 'async' mode).

        Raises
        ------
        ValueError
            If ``update_type`` is not a known mode, or if a sync mode is used
            without a price grid.
        """
        mode = self.update_type
        if mode not in self._UPDATE_TYPES:
            raise ValueError(f"Unknown update_type: {self.update_type}")

        if mode == 'async':
            # only update played price; the price grid is not needed here
            self.Q[action] = self.alpha * reward + (1 - self.alpha) * self.Q[action]
            return

        if self.prices is None:
            raise ValueError(f"update_type '{mode}' requires a price grid")

        prices = np.asarray(self.prices, dtype=float)
        p_j = prices[competitor_action]

        if mode == 'sync_perfect':
            # exact counterfactual profit of every price against p_j
            profits = (prices - self.cost) * self._demand(prices, p_j)
            self.Q[:] = self.alpha * profits + (1 - self.alpha) * self.Q
            return

        # mode == 'sync_downward'
        p_i = prices[action]
        q = float(self._demand(p_i, p_j))       # realized quantity at chosen price
        base_profit = (p_i - self.cost) * q

        # first update the chosen action as in async
        self.Q[action] = self.alpha * reward + (1 - self.alpha) * self.Q[action]

        # p > chosen: W(p) can only fall towards base_profit
        pull_down = (prices > p_i) & (self.Q > base_profit)
        # p < chosen: demand is at least q, so (p - c) * q is a lower bound;
        # only move W(p) when that bound exceeds the current estimate
        lower_bound = (prices - self.cost) * q
        pull_up = (prices < p_i) & (lower_bound > self.Q)

        # The played price satisfies neither strict inequality, so it is
        # never touched a second time.
        touched = pull_down | pull_up
        target = np.where(pull_down, base_profit, lower_bound)
        self.Q[touched] = (self.alpha * target[touched]
                           + (1 - self.alpha) * self.Q[touched])

In [None]:
def simulate(env, agent1, agent2, periods: int) -> np.ndarray:
    """
    Play `periods` rounds between two agents and record firm 1's price path.

    Each round both agents pick an action, the pair is passed to `env.step`,
    and each agent is updated with its own action, its own reward and the
    other agent's action. The returned array has length `periods` and holds
    `env.prices[a1]` for every round.
    """
    firm1_path = np.zeros(periods)
    for period in range(periods):
        choice1 = agent1.select_action()
        choice2 = agent2.select_action()

        # env.step must return a 4-tuple whose second item is the reward pair
        _first, payoffs, _third, _fourth = env.step((choice1, choice2))
        payoff1, payoff2 = payoffs

        agent1.update(choice1, payoff1, choice2)
        agent2.update(choice2, payoff2, choice1)

        firm1_path[period] = env.prices[choice1]
    return firm1_path


def run_all(protocol: str,
            Sims: int = 100,
            periods: int = 5000,
            alpha: float = 0.1,
            seed: int = None,
            grid_size: int = 100,
            price_min: float = 0.1,
            price_max: float = 10.0,
            cost: float = 2.0):
    """
    Run `Sims` independent two-agent simulations and summarize firm 1's prices.

    Parameters
    ----------
    protocol : update mode handed to both QLearningAgent instances.
    Sims : number of independent simulations (fresh agents each time).
    periods : number of rounds per simulation.
    alpha : learning rate for both agents.
    seed : if given, NumPy's global random state is seeded once before the
        first simulation so the random Q initializations (and therefore the
        whole run) are reproducible. ``None`` leaves the state untouched.
    grid_size, price_min, price_max, cost : environment settings, passed to
        BertrandPricingEnv and used to build the agents' price grid.

    Returns
    -------
    Tuple ``(q_min, q25, q50, q75, q_max)``: per-period 0th, 25th, 50th, 75th
    and 100th percentiles of firm 1's price across simulations, each an array
    of length `periods`.
    """
    if seed is not None:
        # Agents draw their initial Q-values from NumPy's global generator.
        np.random.seed(seed)

    # NOTE(review): this grid is assumed to match env.prices -- confirm
    # against BertrandPricingEnv.
    # NOTE: QLearningAgent's counterfactual demand is zero above a price of
    # 10, so keep price_max <= 10 unless the agent is adjusted accordingly.
    prices = np.linspace(price_min, price_max, grid_size)
    env = BertrandPricingEnv(price_min, price_max, grid_size, cost)

    all_hist = np.zeros((Sims, periods))
    for sim in range(Sims):
        # new agents each sim, both with W(p) ~ U[10, 20]
        ag1, ag2 = (
            QLearningAgent(n_actions=grid_size, alpha=alpha,
                           init_low=10, init_high=20,
                           update_type=protocol, cost=cost, prices=prices)
            for _ in range(2)
        )
        all_hist[sim] = simulate(env, ag1, ag2, periods)

    # percentiles across simulations, separately for every period
    q_min, q25, q50, q75, q_max = np.percentile(all_hist, [0, 25, 50, 75, 100], axis=0)
    return q_min, q25, q50, q75, q_max


def plot_results():
    """
    Simulate every update protocol and plot the median price path of each.

    For each protocol this prints a progress line, calls `run_all` with the
    protocol's horizon, and draws the per-period median on one shared figure,
    which is then labelled and shown.
    """
    # horizon (number of periods) to simulate for each protocol, in plot order
    horizon_by_protocol = {
        'async': 5000,
        'sync_perfect': 5000,
        'sync_downward': 5000,
    }

    plt.figure(figsize=(10, 6))
    for name, horizon in horizon_by_protocol.items():
        print(f"Running protocol: {name}")
        # only the median of the five returned percentile paths is plotted
        _, _, median_path, _, _ = run_all(name, periods=horizon)
        period_axis = np.arange(len(median_path))
        plt.plot(period_axis, median_path, label=f"Median ({name})")

    plt.xlabel('Period')
    plt.ylabel('Price')
    plt.title('Price Convergence under Different Update Protocols')
    plt.legend()
    plt.tight_layout()
    plt.show()

In [None]:
# Run the simulations for all three update protocols and plot the median
# price path of each. With the defaults this is 100 simulations of 5000
# periods per protocol, so the cell may take a while to finish.
plot_results()