# Pricing simulation with Q-learning agents in Julia



In [1]:
include("sim_types.jl")
include("utils_simulation.jl")
include("agent.jl")
include("market_env.jl")

calc_reward

In [31]:
"""
    train_agent(seed::Int64, p::SimParams)

Run one Q-learning training session for `p.n_firms` agents and return the final state
together with the learned Q-matrices.

Every agent owns one `n_states × n_actions` Q-matrix that is initialised with uniform
random numbers drawn from a `MersenneTwister` seeded with `seed`, so a run is reproducible
for a given `(seed, p)` pair. In each period every agent picks an action via `get_action`,
the resulting prices determine the rewards (`calc_reward`) and the next state, and each
Q-matrix is updated with `update`.

Training stops after `p.max_t` periods, or earlier once the greedy actions in the visited
state have been left unchanged by the update for more than `p.rounds_convergence`
consecutive periods.

# Arguments
- `seed::Int64`: seed of the local random number generator.
- `p::SimParams`: simulation parameters; this function reads `n_firms`, `max_t`, `beta`
  and `rounds_convergence` directly and forwards `p` to the helper functions.

# Returns
A tuple `(state, all_q)`: the value of `state` when the loop ended and the vector of
Q-matrices (one per firm). When the loop ends through the convergence `break`, `state` is
the state in which the last update took place (it is not advanced to the next state).
"""
function train_agent(seed::Int64, p::SimParams)
    # Local RNG so that runs with different seeds are independent and reproducible
    local_RNG = MersenneTwister(seed)

    # Derive the state/action space from the parameters (the price grid itself is not
    # needed here, only the price states and the dimensions)
    _, all_price_states, n_states, n_actions = create_vars(p)

    # Mapping from a tuple of prices to the integer state index used to address the
    # Q-matrices (the inverse mapping is not needed during training)
    _, price_state_to_int = create_transition_dict(all_price_states, n_states)

    # One randomly initialised Q-matrix per firm
    all_q = [rand(local_RNG, Float64, n_states, n_actions) for _ in 1:p.n_firms]

    # Arbitrary start state
    state = 1
    # Number of consecutive periods in which the greedy actions did not change
    convergence_count = 0

    for t in 1:p.max_t
        # Exploration rate for this period; beta comes from the parameters, not from a
        # global variable, so the function only depends on its arguments
        epsilon = calc_epsilon(t, p.beta)

        current_actions = tuple(
            [get_action(q, state, epsilon, n_actions, local_RNG) for q in all_q]...
        )
        current_prices = collect(index_to_price(current_actions))  # collect to get an array
        current_profits = [calc_reward(price, current_prices, p) for price in current_prices]
        # Cast to a tuple again because the dict is keyed by tuples
        next_state = price_state_to_int[tuple(current_prices...)]

        # Greedy actions in the current state before the update
        best_actions_before_update = [get_best_action(q, state) for q in all_q]

        # Update every agent's Q-matrix with its own action and reward
        all_q = [
            update(q, state, action, reward, next_state, p) for
            (q, action, reward) in zip(all_q, current_actions, current_profits)
        ]

        # Greedy actions after the update; only the best action for the current state
        # could have changed
        best_actions_after_update = [get_best_action(q, state) for q in all_q]

        # Check convergence
        if best_actions_before_update == best_actions_after_update
            convergence_count += 1
            if convergence_count > p.rounds_convergence
                break
            end
        else
            convergence_count = 0
        end

        # The state of the next period is determined by the prices of this period
        state = next_state
    end
    return state, all_q
end

train_agent (generic function with 1 method)

In [40]:
# Parameter values for this run. They are kept as individual globals and then bundled
# into a `SimParams` instance that is handed to `train_agent`.
n_firms = 2
min_price = 0
max_price = 5
reservation_price = 4
m_consumer = 60

k_memory = 1
beta = 4.0e-7                      # forwarded to `calc_epsilon` during training
discount_rate = 0.95
alpha = 0.04
max_t = 1_000_000_000              # upper bound on the number of training periods
rounds_convergence = 200_000       # stable periods required before training stops
optimality_threshold = 1.0e-11

current_params = SimParams(;
    n_firms = n_firms,
    min_price = min_price,
    max_price = max_price,
    reservation_price = reservation_price,
    m_consumer = m_consumer,
    k_memory = k_memory,
    beta = beta,
    discount_rate = discount_rate,
    alpha = alpha,
    max_t = max_t,
    rounds_convergence = rounds_convergence,
    optimality_threshold = optimality_threshold,
)

SimParams(2, 0, 5, 4, 60, 1, 4.0e-7, 0.95, 0.04, 1000000000, 200000, 1.0e-11)

In [41]:
# Train with seed 6 and time the run; the result is destructured into the final state
# and the vector of Q-matrices.
@time (convergence_state, all_q) = train_agent(6, current_params)


 35.766818 seconds (363.64 M allocations: 26.935 GiB, 9.27% gc time)


(29, [[1648.8671342423718 2084.4622462697043 … 1731.3639410891965 1732.9709092358225; 1919.2879404227992 1923.8881296448887 … 2039.802943719999 1903.442959982311; … ; 2005.3421751364417 2068.9280167693105 … 2035.1069465762134 1967.652637083865; 1719.95837213905 2139.3958028004345 … 1865.750459160043 1753.2403060634565], [1656.935474509337 1670.895865518859 … 1639.39630813921 1687.3168718941777; 1885.1706516508577 1927.5296435923344 … 1840.7675616280355 1904.7007348700406; … ; 2079.0483671885772 2236.778922045518 … 2126.014764540704 2080.1342618362905; 1705.3580438008694 1694.3219286948456 … 2155.2594196632112 1682.1541203531033]])