In [None]:
from math import isclose

from frozendict import frozendict
from random import choice, choices


class State(frozendict):
    """Immutable, hashable state record with a compact printable label.

    A state holding a 'title' key is labelled by that title; any other state
    is labelled from its 'level' and 'tired' keys, e.g. ``L3 (T)`` or ``L3 (R)``.
    """

    def __new__(cls, *args, **kwargs):
        # Delegate construction to frozendict so the result is an instance of this subclass.
        return super().__new__(cls, *args, **kwargs)

    def __str__(self):
        if 'title' not in self:
            level = self['level']
            marker = 'T' if self['tired'] else 'R'
            return f"L{level} ({marker})"

        return self['title']

    def __repr__(self):
        # Containers print their keys with repr(), so reuse the short label there too.
        return self.__str__()


class MDP:
    """Tabular Markov decision process with a small reset/step interface.

    Constructor arguments (all stored as attributes of the same name):
        states: indexable collection of states; ``states[0]`` is the default start.
        statesPlus: stored as given; not read by this class.
        actions: mapping ``state -> collection of actions`` available in that state.
        transitions: mapping ``state -> action -> {nextState: probability}``.
            States without an entry are treated as terminal. Defaults to empty.
        rewards: mapping ``state -> action -> {nextState: reward}``; missing
            entries count as 0. Defaults to empty.
        gamma, eps: stored as given; not read by this class.
        T: number of steps after which ``step`` reports truncation.
        costOfLiving: constant added to every reward.
        startingState: initial state; ``states[0]`` when left as ``None``.

    Raises:
        AssertionError: if the outcome probabilities of any (state, action) pair
            do not sum to 1 (within 1e-4) or any single one lies outside [0, 1].
    """

    def __init__(self, states, statesPlus, actions, transitions=None, rewards=None, gamma=0.9, eps=1e-6, T=100, costOfLiving=0,
                 startingState=None):
        self.states = states
        self.statesPlus = statesPlus
        self.actions = actions
        # Omitted tables become empty dicts so validation and lookups never touch None.
        self.transitions = {} if transitions is None else transitions
        self.rewards = {} if rewards is None else rewards
        self.gamma = gamma
        self.eps = eps
        self.T = T
        self.costOfLiving = costOfLiving

        # Compare against None: a falsy but valid state (e.g. 0) must not be replaced.
        self.startingState = states[0] if startingState is None else startingState
        self.currentState = self.startingState
        self.stepsTaken = 0

        self.checkProbabilities()

    def checkProbabilities(self):
        """Validate every outcome distribution in ``self.transitions``.

        Raises AssertionError explicitly (rather than via ``assert``) so the
        check still runs when Python is started with ``-O``.
        """
        for s, a in self.transitions.items():
            for action, outcomes in a.items():
                # The outcomes of one action must form a distribution (sum effectively 1).
                if not isclose(sum(outcomes.values()), 1, abs_tol=1e-4):
                    raise AssertionError(f"Probability not 1: {s} + {action} -> {outcomes.values()}")

                for o, chance in outcomes.items():
                    if not 0 <= chance <= 1:
                        raise AssertionError(f"Invalid probability: {s} + {action} -> {o} ({chance})")

    def reset(self):
        """Return to the starting state, zero the step counter and return that state."""
        self.currentState = self.startingState
        self.stepsTaken = 0

        return self.currentState

    @property
    def stateSpace(self) -> int:
        """Number of entries in ``self.states``."""
        return len(self.states)

    @property
    def allActions(self) -> set:
        """Set of every distinct action listed for any state."""
        return {a for A in self.actions.values() for a in A}

    @property
    def actionSpace(self) -> int:
        """Number of distinct actions (duplicates across states are counted once)."""
        return len(self.allActions)

    def getActions(self) -> list:
        """Return the actions listed for the current state."""
        return self.actions[self.currentState]

    def isTerminal(self, state):
        """Return True when ``state`` has no entry in the transition table."""
        return state not in self.transitions

    def getReward(self, state, action, newState):
        """Return the reward for ``state --action--> newState`` plus the cost of living.

        Combinations missing from the reward table contribute 0.
        """
        return self.rewards.get(state, {}).get(action, {}).get(newState, 0) + self.costOfLiving

    def step(self, action, message=False):
        """Take ``action`` in the current state and sample the next state.

        Returns a tuple ``(newState, reward, terminated, truncated, info)``.
        When the action is not listed for the current state nothing changes and
        ``(currentState, 0, False, False, "Invalid")`` is returned instead.
        With ``message=True`` the sampled transition is printed.
        """
        if action not in self.actions[self.currentState]:
            print('\tInvalid action:', action)
            return (self.currentState,
                    0,
                    False,
                    False,
                    "Invalid")

        possibleOutcomes = self.transitions[self.currentState][action]  # {s: chance, s: chance, etc.}
        # Draw one successor, weighted by the outcome probabilities.
        newState = choices(list(possibleOutcomes.keys()), list(possibleOutcomes.values()))[0]
        reward = self.getReward(self.currentState, action, newState)

        if message:
            print(f"\t{self.currentState} -> {newState} via {action} (reward of {reward:.2f})")

        self.stepsTaken += 1
        self.currentState = newState

        return (newState,
                reward,
                # 'Sinks' (states with no further transitions) are treated as terminal.
                self.isTerminal(newState),
                # Truncated once the step budget T has been used up.
                self.stepsTaken >= self.T,
                None)  # No extra info is reported for a valid step.


def Test(mdp: MDP, how: str, maxLength: int) -> None:
    """Play one episode on ``mdp`` and print every transition.

    ``how`` is either the literal 'random', in which case a random available
    action is drawn each turn, or the name of the action to repeat every turn.
    The episode runs for at most ``maxLength`` turns, stops early when a step
    reports termination, and returns silently when a step reports 'Invalid'.
    """
    pickRandomly = how == 'random'
    current = mdp.reset()  # start from the MDP's initial state
    runningTotal = 0

    print(f"Testing MDP with action '{how}' for {maxLength} turns.")

    turn = 0
    while turn < maxLength:
        turn += 1

        if pickRandomly:
            chosen = choice(mdp.getActions())
        else:
            chosen = how

        # The truncation flag is reported by the MDP but not used here.
        following, gained, isDone, _truncated, info = mdp.step(chosen)

        if info == 'Invalid':
            return

        runningTotal += gained
        print(f"\t{current} -> {following} via {chosen} (reward of {gained:.2f}, total of {runningTotal:.2f})")

        if isDone:
            print('\tDone!')
            break

        current = following


def ActionsFromTransitions(t: dict) -> dict:
    """Map each state of a transition table to the list of its action names.

    Input shape:  state -> {action1: outcomes, action2: outcomes}
    Output shape: state -> [action1, action2]
    """
    actionsPerState = {}

    for state, outcomesByAction in t.items():
        actionsPerState[state] = [action for action in outcomesByAction]

    return actionsPerState


def StatesFromTransitions(t: dict) -> list:
    """Return the keys (states) of a transition table as a list, in table order."""
    return [*t]

In [32]:
from numpy.random import choice


def ToTired(s: State) -> State:
    """Return the tired state at the same level as ``s``."""
    level = s['level']
    return State(level=level, tired=True)


def ToRested(s: State) -> State:
    """Return the rested (not tired) state at the same level as ``s``."""
    level = s['level']
    return State(level=level, tired=False)


def ToTrained(s: State) -> State:
    """Return the state one level above ``s``, marked as tired."""
    nextLevel = s['level'] + 1
    return State(level=nextLevel, tired=True)


def WinChance(s: State, maxLevel: int, invert: bool = False) -> float:
    """Return the chance of winning an attack from state ``s``.

    The base chance is ``0.2 * level - 0.1`` and is halved when the state is
    tired. With ``invert=True`` the complementary chance (losing) is returned.

    ``maxLevel`` is accepted but not read by the current formula; the
    alternative that scales with it was
    ``0.1 + (0.9 / (maxLevel - 1)) * (level - 1)``.
    """
    chance = 0.2 * s['level'] - 0.1

    if s['tired']:
        chance = chance / 2

    if invert:
        return 1 - chance

    return chance


def GenerateMDP(maxLevel: int, attackedChance: float, costOfLiving: float = 0, startingState: State = None) -> MDP:
    """Build the levelling MDP with levels 1..maxLevel, each either tired or rested.

    Available actions per state:
        attack - from any state: ends in Won or Died according to WinChance.
        rest   - tired states only: become rested, or die with ``attackedChance``.
        defend - rested states only: stay rested, or become tired with ``attackedChance``.
        train  - rested states below ``maxLevel``: move to the next level (tired),
                 or die with ``attackedChance``.

    Reaching Won via attack is rewarded with 1; reaching Died is rewarded with -1.
    ``startingState`` falls back to the first generated state when falsy.
    """
    won = State(title="Won")
    died = State(title="Died")
    safeChance = 1 - attackedChance

    # Tired comes before rested within each level; this order drives printing and tie-breaking.
    states = []
    for level in range(1, maxLevel + 1):
        for tired in (True, False):
            states.append(State(level=level, tired=tired))

    statesPlus = states + [won, died]

    transitions = {}
    for state in states:
        options = {'attack': {won: WinChance(state, maxLevel),
                              died: WinChance(state, maxLevel, True)}}

        if state['tired']:
            options['rest'] = {ToRested(state): safeChance,
                               died: attackedChance}
        else:
            options['defend'] = {state: safeChance,
                                 ToTired(state): attackedChance}

            if state['level'] < maxLevel:
                options['train'] = {ToTrained(state): safeChance,
                                    died: attackedChance}

        transitions[state] = options

    # Every state gets its own reward table with the same contents.
    rewards = {}
    for state in states:
        rewards[state] = {'attack': {won: 1, died: -1},
                          'defend': {died: -1},
                          'train': {died: -1},
                          'rest': {died: -1}}

    return MDP(states,
               statesPlus,
               ActionsFromTransitions(transitions),
               transitions=transitions,
               rewards=rewards,
               costOfLiving=costOfLiving,
               startingState=startingState or states[0])

In [33]:
from copy import deepcopy
from pprint import pprint

# Discount factor that Gt applies to the value of the successor state.
GAMMA = 0.95


def Gt(state: State, action: str, nextState: State, mdp: MDP, vTable: dict) -> float:
    """Return the one-step return: immediate reward plus discounted value of ``nextState``.

    NOTE(review): the discount comes from the module-level GAMMA, not from
    ``mdp.gamma`` - confirm that this is intended.
    """
    immediateReward = mdp.getReward(state, action, nextState)
    discountedFuture = GAMMA * vTable[nextState]

    return immediateReward + discountedFuture


# Bellman equation for V
def CalculateV(policy: dict, state: State, mdp: MDP, vTable: dict) -> float:
    """Bellman expectation for V: policy-weighted sum of the action values of ``state``.

    Terminal states have value 0. ``policy[state]`` maps each action to the
    probability of taking it.
    """
    if mdp.isTerminal(state):
        return 0

    actionProbabilities = policy[state]
    weightedValues = [probability * CalculateQ(state, action, mdp, vTable)
                      for action, probability in actionProbabilities.items()]

    return sum(weightedValues)


# Bellman equation for Q
def CalculateQ(state: State, action: str, mdp: MDP, vTable: dict) -> float:
    """Bellman expectation for Q: probability-weighted one-step return over all outcomes."""
    outcomes = mdp.transitions[state][action]
    weightedReturns = [probability * Gt(state, action, successor, mdp, vTable)
                       for successor, probability in outcomes.items()]

    return sum(weightedReturns)


def UpdateQ(state: State, policy: dict, mdp: MDP, vTable: dict) -> dict:
    """Return ``{action: action value}`` for every action the policy lists for ``state``."""
    valuesByAction = {}

    for action in policy[state]:
        valuesByAction[action] = CalculateQ(state, action, mdp, vTable)

    return valuesByAction


def ArgMax(d: dict):  # np.argmax was causing issues.
    """Return the first key (in iteration order) whose value equals the maximum.

    Returns None for an empty dict. The maximum is computed once up front
    instead of once per key, so the lookup is linear in the size of ``d``.
    """
    if not d:
        return None

    best = max(d.values())

    # Ties go to the earliest key, which keeps the greedy action choice deterministic.
    return next((k for k, v in d.items() if v == best), None)


def UpdatePolicy(policy: dict, actionValues: dict) -> dict:
    """Return the greedy policy with respect to ``actionValues``.

    For every state in ``policy`` the action with the highest value (first one
    on ties, as decided by ArgMax) gets probability 1 and all others 0.
    ``policy`` itself is left untouched; the result is built from scratch, so
    no deep copy of the old policy is needed.
    """
    updatedPolicy = {}

    for state, actionProbabilities in policy.items():
        optimalAction = ArgMax(actionValues[state])

        updatedPolicy[state] = {a: int(a == optimalAction) for a in actionProbabilities}

    return updatedPolicy


def UpdateV(vTable: dict, policy: dict, mdp: MDP, sweeps: int) -> dict:
    """Run ``sweeps`` synchronous policy-evaluation sweeps and return the new value table.

    Each sweep computes every state's value from the table of the previous
    sweep. A sweep builds a brand-new dict and never writes to the previous
    one, so the previous table can be read directly without copying it.
    The dict passed in by the caller is not modified.
    """
    for _ in range(sweeps):
        previous = vTable

        vTable = {s: CalculateV(policy, s, mdp, previous) for s in previous}

    return vTable


def _PrintValueTables(mdp: MDP, vTable: dict, actionValues: dict) -> None:
    # Print the state values and action values of mdp.states, rounded to 3 decimals.
    print('V:')
    pprint({s: round(vTable[s], 3) for s in mdp.states})
    print('\nQ:')
    pprint({s: {a: round(v, 3) for a, v in actionValues[s].items()} for s in mdp.states})


def GPI(policy: dict, mdp: MDP, maxIterations: int = 10, vSweeps: int = 100) -> dict:
    """Generalized policy iteration: alternate policy evaluation and greedy improvement.

    Starting from ``policy`` and an all-zero value table, each iteration runs
    ``vSweeps`` evaluation sweeps, derives the action values, and makes the
    policy greedy with respect to them. Stops as soon as an improvement step
    leaves the policy unchanged ('CONVERGED') or after ``maxIterations``
    iterations ('N RUNS'). The final value tables are printed in both cases.

    Returns:
        The policy found, as ``{state: {action: probability}}``.
    """
    vTable = {s: 0 for s in mdp.statesPlus}
    actionValues = None

    for _ in range(maxIterations):
        # Evaluation
        vTable = UpdateV(vTable, policy, mdp, sweeps=vSweeps)
        actionValues = {s: UpdateQ(s, policy, mdp, vTable) for s in mdp.states}

        # Improvement
        newPolicy = UpdatePolicy(policy, actionValues)

        if newPolicy == policy:  # No more changes, hence the optimal policy has been found.
            print('CONVERGED')
            _PrintValueTables(mdp, vTable, actionValues)

            return policy

        policy = newPolicy

    if actionValues is None:
        # The loop never ran (maxIterations <= 0): report action values of the initial table.
        actionValues = {s: UpdateQ(s, policy, mdp, vTable) for s in mdp.states}

    print('N RUNS')
    _PrintValueTables(mdp, vTable, actionValues)
    return policy


def GenerateRandomPolicy(mdp: MDP) -> dict:
    """Return the uniform policy: in each state every available action is equally likely.

    The available actions are taken from the MDP's transition table.
    """
    uniformPolicy = {}

    for state, available in ActionsFromTransitions(mdp.transitions).items():
        uniformPolicy[state] = {option: 1 / len(available) for option in available}

    return uniformPolicy


In [34]:
# Build the five-level MDP with no chance of being attacked and no cost of living,
# starting rested at level 1.
mdp = GenerateMDP(maxLevel=5,
                  attackedChance=0,
                  costOfLiving=0,
                  startingState=State(level=1, tired=False))

# First pass

In [35]:
# Display the uniform random policy for this MDP.
GenerateRandomPolicy(mdp)

{L1 (T): {'attack': 0.5, 'rest': 0.5},
 L1 (R): {'attack': 0.3333333333333333,
  'defend': 0.3333333333333333,
  'train': 0.3333333333333333},
 L2 (T): {'attack': 0.5, 'rest': 0.5},
 L2 (R): {'attack': 0.3333333333333333,
  'defend': 0.3333333333333333,
  'train': 0.3333333333333333},
 L3 (T): {'attack': 0.5, 'rest': 0.5},
 L3 (R): {'attack': 0.3333333333333333,
  'defend': 0.3333333333333333,
  'train': 0.3333333333333333},
 L4 (T): {'attack': 0.5, 'rest': 0.5},
 L4 (R): {'attack': 0.3333333333333333,
  'defend': 0.3333333333333333,
  'train': 0.3333333333333333},
 L5 (T): {'attack': 0.5, 'rest': 0.5},
 L5 (R): {'attack': 0.5, 'defend': 0.5}}

In [36]:
# First pass: a single evaluation sweep and a single improvement step from the random policy.
policy = GPI(GenerateRandomPolicy(mdp), mdp, maxIterations=1, vSweeps=1)

N RUNS
V:
{L1 (R): -0.267,
 L1 (T): -0.45,
 L2 (R): -0.133,
 L2 (T): -0.35,
 L4 (R): 0.133,
 L5 (T): -0.05,
 L5 (R): 0.4,
 L4 (T): -0.15,
 L3 (R): 0.0,
 L3 (T): -0.25}

Q:
{L1 (R): {'attack': -0.8, 'defend': -0.253, 'train': -0.332},
 L1 (T): {'attack': -0.9, 'rest': -0.253},
 L2 (R): {'attack': -0.4, 'defend': -0.127, 'train': -0.237},
 L2 (T): {'attack': -0.7, 'rest': -0.127},
 L4 (R): {'attack': 0.4, 'defend': 0.127, 'train': -0.048},
 L5 (T): {'attack': -0.1, 'rest': 0.38},
 L5 (R): {'attack': 0.8, 'defend': 0.38},
 L4 (T): {'attack': -0.3, 'rest': 0.127},
 L3 (R): {'attack': 0.0, 'defend': 0.0, 'train': -0.142},
 L3 (T): {'attack': -0.5, 'rest': 0.0}}


### $v_{\pi}(s)=\sum_{a} \pi(a|s) \sum_{r, s'} p(r, s' | s, a) [r+ \gamma v_{\pi}(s')]$
Since all values are initially set to 0, the formula can for now be simplified to:

### $v_{\pi}(s)=\sum_{a} \pi(a|s) \sum_{r, s'} p(r, s' | s, a) [r]$

| State  | Initial Value | Options                                             | Calculation                                                       | New Value |
|--------|---------------|-----------------------------------------------------|-------------------------------------------------------------------|-----------|
| L1 (T) | 0             | Attack (Died, Won), Rest (L1 (R))                   | $\frac{1}{2}(0.95(-1) + 0.05(1)) + \frac{1}{2}(0)$                | -0.45     |
| L1 (R) | 0             | Attack (Died, Won), Train (L2 (T)), Defend (L1 (R)) | $\frac{1}{3}(0.9(-1) + 0.1(1)) + \frac{1}{3}(0) + \frac{1}{3}(0)$ | -0.267    |
| L2 (T) | 0             | Attack (Died, Won), Rest (L2 (R))                   | $\frac{1}{2}(0.85(-1) + 0.15(1)) + \frac{1}{2}(0)$                | -0.35     |
| L2 (R) | 0             | Attack (Died, Won), Train (L3 (T)), Defend (L2 (R)) | $\frac{1}{3}(0.7(-1) + 0.3(1)) + \frac{1}{3}(0) + \frac{1}{3}(0)$ | -0.133    |
| L3 (T) | 0             | Attack (Died, Won), Rest (L3 (R))                   | $\frac{1}{2}(0.75(-1) + 0.25(1)) + \frac{1}{2}(0)$                | -0.25     |
| L3 (R) | 0             | Attack (Died, Won), Train (L4 (T)), Defend (L3 (R)) | $\frac{1}{3}(0.5(-1) + 0.5(1)) + \frac{1}{3}(0) + \frac{1}{3}(0)$ | 0         |
| L4 (T) | 0             | Attack (Died, Won), Rest (L4 (R))                   | $\frac{1}{2}(0.65(-1) + 0.35(1)) + \frac{1}{2}(0)$                | -0.15     |
| L4 (R) | 0             | Attack (Died, Won), Train (L5 (T)), Defend (L4 (R)) | $\frac{1}{3}(0.3(-1) + 0.7(1)) + \frac{1}{3}(0) + \frac{1}{3}(0)$ | 0.133     |
| L5 (T) | 0             | Attack (Died, Won), Rest (L5 (R))                   | $\frac{1}{2}(0.55(-1) + 0.45(1)) + \frac{1}{2}(0)$                | -0.05     |
| L5 (R) | 0             | Attack (Died, Won), Defend (L5 (R))                 | $\frac{1}{2}(0.1(-1) + 0.9(1)) + \frac{1}{2}(0)$                  | 0.4       |

### $q_{\pi}(s, a) = \sum_{s', r} p(s', r | s, a) [r+\gamma \sum_{a'} \pi (a'|s') q_{\pi}(s', a')]$


For the non-attack actions the reward is 0, so the action value is the discounted value of the next state, $\gamma v_{\pi}(s')$ with $\gamma = 0.95$, using the state values from the table above.

| State  | Action | Outcomes  | Calculation        | New Value |
|--------|--------|-----------|--------------------|-----------|
| L1 (T) | Attack | Died, Won | $0.95(-1)+0.05(1)$ | -0.9      |
| L1 (T) | Rest   | L1 (R)    | $0.95(-0.267)$     | -0.253    |
| L1 (R) | Attack | Died, Won | $0.9(-1)+0.1(1)$   | -0.8      |
| L1 (R) | Train  | L2 (T)    | $0.95(-0.35)$      | -0.332    |
| L1 (R) | Defend | L1 (R)    | $0.95(-0.267)$     | -0.253    |
| L2 (T) | Attack | Died, Won | $0.85(-1)+0.15(1)$ | -0.7      |
| L2 (T) | Rest   | L2 (R)    | $0.95(-0.133)$     | -0.127    |
| L2 (R) | Attack | Died, Won | $0.7(-1)+0.3(1)$   | -0.4      |
| L2 (R) | Train  | L3 (T)    | $0.95(-0.25)$      | -0.237    |
| L2 (R) | Defend | L2 (R)    | $0.95(-0.133)$     | -0.127    |
| L3 (T) | Attack | Died, Won | $0.75(-1)+0.25(1)$ | -0.5      |
| L3 (T) | Rest   | L3 (R)    | $0.95(0)$          | 0         |
| L3 (R) | Attack | Died, Won | $0.5(-1)+0.5(1)$   | 0         |
| L3 (R) | Train  | L4 (T)    | $0.95(-0.15)$      | -0.142    |
| L3 (R) | Defend | L3 (R)    | $0.95(0)$          | 0         |
| L4 (T) | Attack | Died, Won | $0.65(-1)+0.35(1)$ | -0.3      |
| L4 (T) | Rest   | L4 (R)    | $0.95(0.133)$      | 0.127     |
| L4 (R) | Attack | Died, Won | $0.3(-1)+0.7(1)$   | 0.4       |
| L4 (R) | Train  | L5 (T)    | $0.95(-0.05)$      | -0.048    |
| L4 (R) | Defend | L4 (R)    | $0.95(0.133)$      | 0.127     |
| L5 (T) | Attack | Died, Won | $0.55(-1)+0.45(1)$ | -0.1      |
| L5 (T) | Rest   | L5 (R)    | $0.95(0.4)$        | 0.38      |
| L5 (R) | Attack | Died, Won | $0.1(-1)+0.9(1)$   | 0.8       |
| L5 (R) | Defend | L5 (R)    | $0.95(0.4)$        | 0.38      |


# Last pass

Once the policy has converged, it is no longer expected to change. Following that logic, running one more iteration of GPI with the final policy, state values and action values should leave the policy unchanged.

In [37]:
# Same discount factor as before, set again so this cell can be run on its own.
GAMMA = 0.95

# Last pass: iterate from the random policy until the policy stops changing (at most 1000 rounds).
GPI(GenerateRandomPolicy(mdp), mdp, maxIterations=1000, vSweeps=100)

CONVERGED
V:
{L1 (R): 0.531,
 L1 (T): 0.504,
 L2 (R): 0.588,
 L2 (T): 0.559,
 L4 (R): 0.722,
 L5 (T): 0.76,
 L5 (R): 0.8,
 L4 (T): 0.686,
 L3 (R): 0.652,
 L3 (T): 0.619}

Q:
{L1 (R): {'attack': -0.8, 'defend': 0.504, 'train': 0.531},
 L1 (T): {'attack': -0.9, 'rest': 0.504},
 L2 (R): {'attack': -0.4, 'defend': 0.559, 'train': 0.588},
 L2 (T): {'attack': -0.7, 'rest': 0.559},
 L4 (R): {'attack': 0.4, 'defend': 0.686, 'train': 0.722},
 L5 (T): {'attack': -0.1, 'rest': 0.76},
 L5 (R): {'attack': 0.8, 'defend': 0.76},
 L4 (T): {'attack': -0.3, 'rest': 0.686},
 L3 (R): {'attack': 0.0, 'defend': 0.619, 'train': 0.652},
 L3 (T): {'attack': -0.5, 'rest': 0.619}}


{L1 (T): {'attack': 0, 'rest': 1},
 L1 (R): {'attack': 0, 'defend': 0, 'train': 1},
 L2 (T): {'attack': 0, 'rest': 1},
 L2 (R): {'attack': 0, 'defend': 0, 'train': 1},
 L3 (T): {'attack': 0, 'rest': 1},
 L3 (R): {'attack': 0, 'defend': 0, 'train': 1},
 L4 (T): {'attack': 0, 'rest': 1},
 L4 (R): {'attack': 0, 'defend': 0, 'train': 1},
 L5 (T): {'attack': 0, 'rest': 1},
 L5 (R): {'attack': 1, 'defend': 0}}

### $v_{\pi}(s)=\sum_{a} \pi(a|s) \sum_{r, s'} p(r, s' | s, a) [r+ \gamma v_{\pi}(s')]$
Seeing as the policy is greedy, we only need to check the value of the only action taken at state $s$.
### $v_{\pi}(s)=\sum_{r, s'} p(r, s' | s, a) [r+ \gamma v_{\pi}(s')] \text{ where } a=\pi(s)$

| State  | Old Value | $\pi(s)$           | Calculation          | New Value |
|--------|-----------|--------------------|----------------------|-----------|
| L1 (T) | 0.504     | Rest (L1 (R))      | $0+0.95(0.531)$      | 0.504     |
| L1 (R) | 0.531     | Train (L2 (T))     | $0+0.95(0.559)$      | 0.531     |
| L2 (T) | 0.559     | Rest (L2 (R))      | $0+0.95(0.588)$      | 0.559     |
| L2 (R) | 0.588     | Train (L3 (T))     | $0+0.95(0.619)$      | 0.588     |
| L3 (T) | 0.619     | Rest (L3 (R))      | $0+0.95(0.652)$      | 0.619     |
| L3 (R) | 0.652     | Train (L4 (T))     | $0+0.95(0.686)$      | 0.652     |
| L4 (T) | 0.686     | Rest (L4 (R))      | $0+0.95(0.722)$      | 0.686     |
| L4 (R) | 0.722     | Train (L5 (T))     | $0+0.95(0.76)$       | 0.722     |
| L5 (T) | 0.76      | Rest (L5 (R))      | $0+0.95(0.8)$        | 0.76      |
| L5 (R) | 0.8       | Attack (Died, Won) | $(0.1(-1) + 0.9(1))$ | 0.8       |

Because the state values did not meaningfully change (changes, where present, are smaller than 0.001 and the ordering of the state values is the same), we can be sure that the action values won't change either; ergo the policy will stay the same. All this means that the code converged the policy properly.