In [1]:
import numpy as np
import pandas as pd
from pandas import DataFrame
from tqdm import tqdm
from collections import defaultdict
import pickle

In [2]:
# Load the training trajectories. The file handle is closed deterministically via `with`.
# NOTE: pickle.load can execute arbitrary code -- only load files from a trusted source.
with open('trajDr_tr.pkl', 'rb') as f:
    traj_tr = pickle.load(f)

# MDP size: nS regular states and nA actions. The terminal death/survival outcomes are
# encoded downstream as the extra state indices nS and nS + 1 (750 and 751).
nS, nA = 750, 25

In [3]:
# Peek at one training trajectory: a list of per-step transition dicts with keys
# 's', 'a', 'r', 's_', 'a_', 'done'. In the sample shown, the final step has
# done=True, s_=None, a_=None and carries the only non-zero reward (-100).
traj_tr[4]

[{'s': 482, 'a': 0, 'r': 0, 's_': 513, 'a_': 0, 'done': False},
 {'s': 513, 'a': 0, 'r': 0, 's_': 513, 'a_': 0, 'done': False},
 {'s': 513, 'a': 0, 'r': 0, 's_': 513, 'a_': 5, 'done': False},
 {'s': 513, 'a': 5, 'r': 0, 's_': 513, 'a_': 5, 'done': False},
 {'s': 513, 'a': 5, 'r': 0, 's_': 513, 'a_': 5, 'done': False},
 {'s': 513, 'a': 5, 'r': 0, 's_': 513, 'a_': 5, 'done': False},
 {'s': 513, 'a': 5, 'r': 0, 's_': 513, 'a_': 5, 'done': False},
 {'s': 513, 'a': 5, 'r': 0, 's_': 582, 'a_': 0, 'done': False},
 {'s': 582, 'a': 0, 'r': 0, 's_': 475, 'a_': 0, 'done': False},
 {'s': 475, 'a': 0, 'r': 0, 's_': 513, 'a_': 0, 'done': False},
 {'s': 513, 'a': 0, 'r': 0, 's_': 582, 'a_': 0, 'done': False},
 {'s': 582, 'a': 0, 'r': 0, 's_': 451, 'a_': 0, 'done': False},
 {'s': 451, 'a': 0, 'r': -100, 's_': None, 'a_': None, 'done': True}]

In [5]:
# Count empirical (s, a) -> s' transitions; the counts are normalised into
# probabilities in the next cell.
# Terminal steps have s_ = None in the data, so they are routed to two absorbing
# pseudo-states appended after the regular states, chosen by the terminal reward.
death_state, survival_state = nS, nS + 1   # 750 and 751 for nS = 750
terminal_state_for_reward = {-100: death_state, 100: survival_state}

trans_counts = defaultdict(lambda: defaultdict(int))
for trajectory in tqdm(traj_tr):
    for transition in trajectory:
        s = transition['s']
        a = transition['a']
        r = transition['r']
        s_ = transition['s_']
        if s_ is None:
            # Only the two known terminal rewards can be mapped to a pseudo-state.
            if r not in terminal_state_for_reward:
                raise NotImplementedError(
                    f"terminal transition from state {s} has unexpected reward {r!r}; "
                    f"expected one of {sorted(terminal_state_for_reward)}"
                )
            s_ = terminal_state_for_reward[r]

        trans_counts[s, a][s_] += 1

100%|██████████| 14657/14657 [00:00<00:00, 30239.62it/s]


In [6]:
# Normalise the transition counts into a probabilistic MDP model for value iteration.
# P[s][a] is a list of (prob, next_state, reward, done) tuples -- the same layout as the
# `P` table of Gym's toy-text environments.
# Rewards are attached only to the two absorbing pseudo-states: reaching death
# (index nS) gives -100, reaching survival (index nS + 1) gives +100, and every other
# transition gives 0 (intermediate rewards are assumed to be zero).
death_state, survival_state = nS, nS + 1
terminal_rewards = {death_state: -100, survival_state: 100}

P = defaultdict(lambda: defaultdict(list))
for (s, a), next_state_counts in tqdm(trans_counts.items()):
    total = sum(next_state_counts.values())
    for s_, count in next_state_counts.items():
        prob = count / total   # true division in Python 3; no float() casts needed
        reward = terminal_rewards.get(s_, 0)
        done = s_ in terminal_rewards
        P[s][a].append((prob, s_, reward, done))

100%|██████████| 10521/10521 [00:00<00:00, 26335.44it/s]


In [7]:
# Persist the normalised MDP model P (not the raw counts) for the later value-iteration stage.
# The outer defaultdict is converted to a plain dict because its lambda default factory
# cannot be pickled; the inner defaultdict(list) values pickle fine as-is.
# (`pickle` is imported in the first cell.)
with open('MDP_P.p', 'wb') as f:
    pickle.dump(dict(P), f)

In [8]:
# Inspect the outcome table for state 0: {action: [(prob, next_state, reward, done), ...]}.
# Use .get rather than P[0]: indexing a defaultdict inserts an empty entry when the key
# is missing, which would silently mutate the model from a display-only cell.
P.get(0)

defaultdict(list,
            {0: [(0.03875968992248062, 678, 0, False),
              (0.24031007751937986, 0, 0, False),
              (0.09302325581395349, 751, 100, True),
              (0.007751937984496124, 606, 0, False),
              (0.007751937984496124, 590, 0, False),
              (0.007751937984496124, 302, 0, False),
              (0.007751937984496124, 692, 0, False),
              (0.007751937984496124, 548, 0, False),
              (0.015503875968992248, 557, 0, False),
              (0.015503875968992248, 435, 0, False),
              (0.015503875968992248, 26, 0, False),
              (0.007751937984496124, 312, 0, False),
              (0.023255813953488372, 708, 0, False),
              (0.015503875968992248, 677, 0, False),
              (0.007751937984496124, 468, 0, False),
              (0.007751937984496124, 349, 0, False),
              (0.023255813953488372, 246, 0, False),
              (0.007751937984496124, 676, 0, False),
              (0.0077519379844