In [10]:
# Automatically reload edited imported modules (such as MDP_chess) before executing each cell.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


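# Value Iteration

Solve each endgame's state tree with value iteration, derive a policy from the resulting state values, and save both with `np.save`.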

In [11]:
def get_deterministic_policy(states):
    """Build a deterministic policy that always takes the first listed action.

    Args:
        states: mapping of state -> mapping of action -> action data.

    Returns:
        dict mapping every state to the first key of its action mapping.
        A state whose action mapping is empty raises IndexError.
    """
    return {state: [*actions.keys()][0] for state, actions in states.items()}

In [12]:
from MDP_chess import policy_improve_shortest_path
import numpy as np

In [13]:
def value_iteration(states_actions, theta=1e-10, winning_reward=1e3, verbose=True):
    """Solve a two-player state tree by in-place (negamax-style) value iteration.

    For every action, the successor's value is negated before being compared.
    Successors that are not keys of ``states_actions`` are scored from the
    action's ``status`` alone (``-status * winning_reward``). For successors
    that are keys, ``np.sign(V[next])`` is added to the negated value. When
    ``status`` is 0, this moves the result one step toward zero per ply.
    Presumably this is a depth penalty so that faster wins score higher.
    Values are updated in place during a sweep, so states visited later in a
    sweep already see values updated earlier in that same sweep.

    Args:
        states_actions: mapping state -> mapping action -> dict with keys
            'next_state' (successor state key) and 'status' (numeric outcome
            multiplier). A dict or a read-only ``shelve`` shelf both work.
        theta: stop once the largest absolute change in a sweep is <= theta.
        winning_reward: scale factor applied to each action's 'status'.
        verbose: if True, print "sweep max_change mean_change" after each sweep.

    Returns:
        (V, delta): dict of state -> value, and the largest change in the
        final sweep.

    Raises:
        ValueError: if some state has an empty action mapping (``max`` of an
            empty sequence).
    """
    V = {state: 0 for state in states_actions}
    n_states = len(states_actions)
    delta = theta + 1  # forces at least one sweep
    sweep = 0
    while theta < delta:
        total_change = 0
        delta = 0
        for state, actions in states_actions.items():
            expected_rewards = []
            for outcome in actions.values():
                next_node = outcome['next_state']
                reward = outcome['status'] * winning_reward
                if next_node in V:
                    # Negate the opponent's value; the sign term shrinks |value| per ply.
                    expected_rewards.append(-(reward + V[next_node]) - np.sign(-V[next_node]))
                else:
                    # Successor is outside the tree: score it from the move outcome only.
                    expected_rewards.append(-reward)
            V_updated = max(expected_rewards)
            change = np.abs(V_updated - V[state])
            total_change = total_change + change
            delta = max(delta, change)
            V[state] = V_updated
        sweep += 1
        if verbose:
            # Guard the mean against an empty state set (previously ZeroDivisionError).
            print(sweep, delta, total_change / n_states if n_states else 0.0)
    return V, delta

# Rook endgame

In [35]:
import shelve
# Read-only shelf for the rook endgame. Each key is a state; each value maps
# actions to dicts carrying the 'next_state' and 'status' fields read by value_iteration.
# NOTE(review): the shelf is never explicitly closed; consider states.close()
# once V and pi for this endgame have been saved.
filename='./data/rook_states_tree/states'
states = shelve.open(filename, flag='r')
# Number of states in the tree.
len(states)

398780

In [30]:
# Iterate until the largest per-sweep change is <= theta; each printed line is:
# sweep number, largest change, mean change per state.
%time V, delta = value_iteration(states)

1 1000.0 4.17970810973469
2 999.0 7.591551732784994
3 998.0 12.473110486985306
4 997.0 8.636448668438739
5 996.0 7.598066603139576
6 995.0 13.587827875018807
7 994.0 24.948267215005767
8 993.0 33.33779026029389
9 992.0 44.633253924469635
10 991.0 50.75919554641657
11 990.0 58.45572496113145
12 988.0 67.39033301569788
13 986.0 76.31656803249912
14 985.0 89.0842394302623
15 984.0 99.40709413711821
16 982.0 100.57818847484829
17 981.0 93.3030292391795
18 980.0 71.0916244545865
19 978.0 40.815284116555496
20 976.0 16.711555243492654
21 974.0 4.1677942725312205
22 973.0 0.39699332965544915
23 971.0 0.010687597171372687
24 2.0 1.0030593309594262e-05
25 0 0.0
CPU times: user 6min 30s, sys: 32.3 s, total: 7min 2s
Wall time: 7min 2s


In [33]:
# Derive a policy from V with MDP_chess.policy_improve_shortest_path (defined outside
# this notebook); the printed move/value lines come from that helper.
%time pi = policy_improve_shortest_path(V, states)

h6h8
['h6h8', 'h6h7', 'h6g6', 'h6f6', 'h6h5', 'h6h4', 'h6h3', 'h6h2', 'h6h1', 'e6f6', 'e6d6', 'e6f5', 'e6e5', 'e6d5']
[1000.0, 990.0, 990.0, 990.0, 996.0, 996.0, 996.0, 996.0, 996.0, 984.0, 986.0, 982.0, 982.0, 982.0]
CPU times: user 18 s, sys: 1.55 s, total: 19.6 s
Wall time: 19.6 s


In [34]:
# V is a dict (pi presumably is too); np.save pickles such objects into a 0-d object
# array and appends '.npy'. Reload with np.load(path, allow_pickle=True).item().
np.save('PI_rook_endgame_value_iter', pi)
np.save('V_rook_endgame_value_iter', V)

# Pawn endgame

In [36]:
import shelve
# Read-only shelf for the pawn endgame, same layout as the rook shelf.
# Rebinding `states` drops this notebook's reference to the previous shelf,
# but no shelf is explicitly closed.
filename='./data/pawn_states_tree/states'
states = shelve.open(filename, flag='r')
# Number of states in the tree.
len(states)

807232

In [37]:
# Same value iteration as for the rook endgame, on the pawn state tree.
%time V, delta = value_iteration(states)

1 1000.0 5.232176127804646
2 999.0 11.528944590898279
3 998.0 17.415773903908665
4 997.0 27.15806856021565
5 996.0 40.91075552009831
6 995.0 61.28433213747721
7 994.0 79.10555701458813
8 993.0 90.92947256798541
9 992.0 88.46317911083803
10 991.0 76.24496551177357
11 990.0 59.524479951240785
12 989.0 54.16530563704115
13 988.0 48.9796613612939
14 987.0 50.10218747522398
15 986.0 51.804043942757474
16 985.0 49.15365718901134
17 984.0 40.06836200745263
18 982.0 27.24851219971458
19 980.0 17.232945423372712
20 978.0 7.941199060493142
21 977.0 2.4471415900261633
22 975.0 0.7893145960516927
23 974.0 0.5600917703956235
24 973.0 0.6646837588202648
25 972.0 0.7156430864980575
26 970.0 0.6043714818044874
27 968.0 0.6549926167446286
28 967.0 0.6395051236819155
29 966.0 0.47148651193213353
30 964.0 0.32417817925949416
31 962.0 0.3485813248235947
32 961.0 0.4560807797510505
33 960.0 0.35704605367478
34 959.0 0.1172587310711171
35 958.0 0.07100560929200032
36 0 0.0
CPU times: user 21min 8s, sys: 1mi

In [38]:
# NOTE(review): the output recorded below is identical to the rook endgame's
# ('h6h8' with the same move list and values), which looks wrong for the pawn tree.
# Confirm that policy_improve_shortest_path uses the V/states passed in rather than
# a fixed debug position, or that this output is not stale.
%time pi = policy_improve_shortest_path(V, states)

h6h8
['h6h8', 'h6h7', 'h6g6', 'h6f6', 'h6h5', 'h6h4', 'h6h3', 'h6h2', 'h6h1', 'e6f6', 'e6d6', 'e6f5', 'e6e5', 'e6d5']
[1000.0, 990.0, 990.0, 990.0, 996.0, 996.0, 996.0, 996.0, 996.0, 984.0, 986.0, 982.0, 982.0, 982.0]
CPU times: user 40.1 s, sys: 3.13 s, total: 43.2 s
Wall time: 43.3 s


In [39]:
# Saved as pickled 0-d object arrays; reload with np.load(path, allow_pickle=True).item().
np.save('PI_pawn_endgame_value_iter', pi)
np.save('V_pawn_endgame_value_iter', V)

# Two bishops endgame

In [40]:
import shelve
# Read-only shelf for the two-bishops endgame, same layout as the rook shelf.
filename='./data/two_bishops_states_tree/states'
states = shelve.open(filename, flag='r')
# Number of states in the tree (about 15x the rook tree).
len(states)

5923016

In [41]:
# Expensive: the recorded run took about 2h23m wall time
# (CPU times: user 2h 8min 13s, sys: 9min 48s, total: 2h 18min 2s).
%time V, delta = value_iteration(states)

1 1000.0 1.1035902992664548
2 999.0 0.27589643519450224
3 998.0 1.36986325885326
4 997.0 2.8006522352801344
5 996.0 3.657171785455248
6 995.0 5.337648252174231
7 994.0 7.457905229362878
8 993.0 10.74546058967256
9 992.0 15.409639616033454
10 990.0 21.735496746927577
11 989.0 31.077902541543025
12 988.0 41.917350721321704
13 986.0 53.3536792742076
14 985.0 66.9991401340128
15 984.0 80.96305817846854
16 983.0 92.40151841561799
17 982.0 101.91697760060077
18 981.0 107.28973212295898
19 980.0 96.9107824459701
20 977.0 68.18849248423439
21 976.0 36.75164156233919
22 975.0 16.420765704499193
23 973.0 5.154635746383262
24 972.0 0.7542333162699544
25 969.0 0.025971228171593662
26 967.0 0.0007204100073341014
27 2.0 1.9922282836987104e-05
28 2.0 3.376658107963916e-07
29 0 0.0
CPU times: user 2h 8min 13s, sys: 9min 48s, total: 2h 18min 2s
Wall time: 2h 22min 59s


In [42]:
# The recorded run took about 6.5 min wall time
# (CPU times: user 5min 3s, sys: 23.2 s, total: 5min 26s).
%time pi = policy_improve_shortest_path(V, states)

CPU times: user 5min 3s, sys: 23.2 s, total: 5min 26s
Wall time: 6min 28s


In [43]:
# Saved as pickled 0-d object arrays; reload with np.load(path, allow_pickle=True).item().
np.save('PI_two_bishops_endgame_value_iter', pi)
np.save('V_two_bishops_endgame_value_iter', V)

# Knight and bishop endgame

In [None]:
import shelve
# Read-only shelf for the knight-and-bishop endgame.
# NOTE(review): this path lacks the '_states_tree' suffix used by the other endgames
# ('./data/<name>_states_tree/states'); confirm the directory name before running.
filename='./data/bishop_knight_tree/states'
states = shelve.open(filename, flag='r')
# Number of states in the tree.
len(states)

In [None]:
# Not yet run in this notebook. For scale, the recorded runs took about 7 min for
# ~400k rook states and about 2h23m for ~5.9M two-bishops states.
%time V, delta = value_iteration(states)

In [None]:
# Derive the knight-and-bishop policy from V (not yet run in this notebook).
%time pi = policy_improve_shortest_path(V, states)

In [None]:
# Saved as pickled 0-d object arrays; reload with np.load(path, allow_pickle=True).item().
# NOTE(review): output names use 'knight_bishop' while the data directory uses
# 'bishop_knight'; keep whichever name downstream loaders expect.
np.save('PI_knight_bishop_endgame_value_iter', pi)
np.save('V_knight_bishop_endgame_value_iter', V)