In [53]:
from itertools import product
import numpy as np
from tqdm import tqdm
from matplotlib import pyplot as plt
from matplotlib import colors as mcolors
%matplotlib inline

In [54]:
# Two-player game settings: each player picks an integer action from the
# inclusive range [lo, hi].
lo = 0
hi = 1

def r1(a1, a2, bonus=2):
    """Player 1's stage payoff when player 1 plays ``a1`` and player 2 plays ``a2``.

    Both players are paid the lower of the two actions. The player with the
    strictly lower action additionally gains ``bonus``, and the other player
    loses it. Ties get no adjustment. Player 2's payoff is ``r1(a2, a1)``.

    Parameters
    ----------
    a1, a2 : int
        Actions chosen by player 1 and player 2.
    bonus : int or float, default 2
        Reward for undercutting (and penalty for being undercut). The default
        reproduces the game used throughout this notebook.

    Returns
    -------
    Numeric payoff for player 1 (a NumPy scalar, because of ``np.sign``).
    """
    return min(a1, a2) + bonus * np.sign(a2 - a1)

np.random.seed(618)  # global NumPy seed; the policy initialisation later draws from np.random
k_max = 1            # longest history (in rounds) a state can remember

# Print the payoff table: one row per player-1 action, one column per
# player-2 action, each entry is (player 1 payoff, player 2 payoff).
for a1 in range(lo, hi + 1):
    row = [(r1(a1, a2), r1(a2, a1)) for a2 in range(lo, hi + 1)]
    print(row)

[(0, 0), (2, -2)]
[(-2, 2), (1, 1)]


In [55]:
# Strategy settings. A state is a flattened joint-action history
# [a1, a2, a1', a2', ...] covering the last k rounds, enumerated for every
# memory length k = 0 .. k_max (shortest histories first).
actions = range(lo, hi + 1)
states = []
for k in range(k_max + 1):
    states.extend(list(history) for history in product(actions, repeat=2 * k))
print(states)

[[], [0, 0], [0, 1], [1, 0], [1, 1]]


In [56]:
# Dense model of the repeated game.
#   R[s, i1, i2]     : player 1's stage reward r1(a1, a2); identical for every state s.
#   T[s, s', i1, i2] : 1 if joint actions (a1, a2) move history s to history s', else 0.
# i1 / i2 are indices into `actions`; a1 / a2 are the action values themselves.
num_actions = len(actions)
num_states = len(states)
R = np.zeros(shape=(num_states, num_actions, num_actions))
T = np.zeros(shape=(num_states, num_states, num_actions, num_actions))
print(R.shape, T.shape)

# Stage payoffs do not depend on the state: build the (a1, a2) table once and
# broadcast it over the state axis.
payoff1 = np.array([[r1(a1, a2) for a2 in actions] for a1 in actions])
R[:] = payoff1

# Look the successor history up by value rather than computing a positional index.
# This stays correct for any action values (e.g. lo != 0) and any ordering of `states`.
state_index = {tuple(S): n for n, S in enumerate(states)}
memory = 2 * k_max  # maximum history length (two actions per round)
for i, Si in enumerate(states):
    for idx1, a1 in enumerate(actions):
        for idx2, a2 in enumerate(actions):
            # Append the newest round and keep only the last `memory` entries. Short
            # histories therefore grow until full, and the oldest round is dropped only
            # once k_max rounds are stored. (Always slicing Si[2:] would reset short
            # histories whenever k_max > 1.)
            history = Si + [a1, a2]
            Sj = history[max(len(history) - memory, 0):]
            T[i, state_index[tuple(Sj)], idx1, idx2] = 1

# Sanity check: every (state, a1, a2) has exactly one successor state.
assert np.all(T.sum(axis=1) == 1)

(5, 2, 2) (5, 5, 2, 2)


In [67]:
# Initialisation. Each player gets a random policy with one probability
# distribution over actions per state (every row sums to 1). Each player also
# gets random starting values. The refinement loop recomputes V1/V2 by exact
# policy evaluation before reading them.
draws = [np.random.rand(num_states, num_actions) for _ in range(2)]  # p1's draw first, then p2's
p1, p2 = (draw / draw.sum(axis=1, keepdims=True) for draw in draws)

V1, V2 = np.random.rand(num_states), np.random.rand(num_states)

print(p1.shape, V1.shape)

# Refinement-mapping iteration. Each round does three things:
#   1. marginalise the opponent's current policy out of R and T;
#   2. evaluate each player's own policy exactly (linear Bellman solve);
#   3. move probability mass towards actions whose Q-value exceeds the state value.
gamma = 0.99  # discount factor
it = 100      # number of refinement rounds
I = np.eye(num_states)  # loop-invariant identity for the Bellman solve

for _ in tqdm(range(it)):
    # Player 1: average out player 2's action, which sits on the last axis of R and T.
    R1SA = np.sum(R * p2[:, np.newaxis, :], axis=2)                                   # (S, A1)
    T1SAS = np.sum(T * p2[:, np.newaxis, np.newaxis, :], axis=3).transpose(0, 2, 1)   # (S, A1, S')

    # Player 2. Rewards: player 2's payoff is r1(a2, a1) (see the payoff table), so
    # averaging R's last axis over p1 gives player 2's expected reward per own action.
    # Transitions: T always stores player 1's action on axis 2 and player 2's on
    # axis 3 (states are [a1, a2, ...]). Player 1's policy must therefore weight axis 2.
    R2SA = np.sum(R * p1[:, np.newaxis, :], axis=2)                                   # (S, A2)
    T2SAS = np.sum(T * p1[:, np.newaxis, :, np.newaxis], axis=2).transpose(0, 2, 1)   # (S, A2, S')

    # Exact policy evaluation: solve (I - gamma * P_pi) V = r_pi instead of forming an inverse.
    R1S = np.sum(R1SA * p1, axis=1)
    T1SS = np.sum(T1SAS * p1[:, :, np.newaxis], axis=1)
    V1 = np.linalg.solve(I - gamma * T1SS, R1S)

    R2S = np.sum(R2SA * p2, axis=1)
    T2SS = np.sum(T2SAS * p2[:, :, np.newaxis], axis=1)
    V2 = np.linalg.solve(I - gamma * T2SS, R2S)

    # Both players follow the same joint state process, so their state-to-state
    # transition matrices must agree.
    assert np.allclose(T1SS, T2SS), "players disagree on the joint state dynamics"

    # One-step lookahead, then refine: boost actions that beat the current value.
    Q1 = R1SA + gamma * T1SAS @ V1
    ref1 = np.maximum(Q1 - V1[:, np.newaxis], 0)
    p1 = np.expm1(p1 + ref1)  # exp(p1 + ref1) - 1, computed accurately for small arguments
    p1 = p1 / p1.sum(axis=1, keepdims=True)

    Q2 = R2SA + gamma * T2SAS @ V2
    ref2 = np.maximum(Q2 - V2[:, np.newaxis], 0)
    p2 = np.expm1(p2 + ref2)
    p2 = p2 / p2.sum(axis=1, keepdims=True)

print(p1.round(3), V1.round(3))
print(p2.round(3), V2.round(3))

(5, 2) (5,)


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3023.25it/s]

[[1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]] [0. 0. 0. 0. 0.]
[[1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]] [0. 0. 0. 0. 0.]





In [58]:
# V1_hist = dict()
# k_list = [1, 2, 3]
# for k in k_list:
#     V1_hist[k] = []

In [59]:
# for k in k_list:
#     for _ in range(5):
#         V1_hist[k].append([])
    
        
        
#             V1_hist[k][-1].append(V1.mean())

In [60]:
# fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(4, 6), dpi=250)
# for k in k_list:
#     mean = np.mean(V1_hist[k], axis=0)
#     std = np.std(V1_hist[k], axis=0)
#     ax.plot(range(it), mean, label=f'memory={k}', color=list(mcolors.TABLEAU_COLORS)[k], lw=1)
#     # ax.fill_between(range(it), mean - std, mean + std, color=list(mcolors.TABLEAU_COLORS)[k], alpha=0.2, ec=None)
#     ax.hlines(40, 0, it - 1, color='b', linestyle='dotted')
#     # ax.set_xticks(range(it))
# ax.legend()