In [1]:
import numpy as np

In [2]:
bin = lambda x : ''.join(reversed( [str((x >> i) & 1) for i in range(jobs)] ) )
def calc_cost_per_state(problem_table, state):
    in_jobs = bin(state)
    in_jobs=in_jobs[::-1]
    cost = 0
    for idx, job in enumerate(in_jobs):
        if int(job):
            cost += problem_table[1, idx]
    return cost

In [3]:
def calc_p(problem_table, policy):
    P = np.zeros([states, states])
    for row in range(states):
        for col in range(states):
            curr_state = row
            next_state = col
            if curr_state > 0:
                possible_next_state = curr_state - 2 ** (int(policy[curr_state]) - 1)
            if curr_state == 0:
                if next_state == 0:
                    P[row,col] = 1
                else:
                    P[row,col] = 0
                continue
            if next_state == possible_next_state:
                P[row,col] = problem_table[0, int(policy[curr_state]) - 1]
            elif next_state == curr_state:
                P[row,col] = 1 - problem_table[0, int(policy[curr_state]) - 1]
            else:
                P[row,col] = 0
    return P
            

In [4]:
def get_v(p, r):
    return np.linalg.inv(np.identity(len(r))-0.9*p) * r

def get_v_iteration(p, r):
    v = np.ones([len(r)])
    v_new = np.zeros([len(r)])
    while not (v == v_new).all():
        v = v_new.copy()
        v_new = r + p @ v
    return v

In [5]:
def get_policy_c(table, state):
    in_jobs = bin(state)
    in_jobs=in_jobs[::-1]
    max_c = -1
    max_idx = -1
    for idx, job in enumerate(in_jobs):
        if int(job):
            cost = problem_table[1, idx]
            if cost > max_c:
                max_c = cost
                max_idx = idx
    return max_idx+1

In [6]:
def is_action_valid(state, action):
    in_jobs = bin(state)
    in_jobs=in_jobs[::-1]
    return bool(int(in_jobs[action-1]))

def get_temp_p(state, action):
    p = np.zeros([2 ** jobs])
    possible_next_state = state - 2 ** (action - 1)
    p[possible_next_state] = problem_table[0, action - 1]
    p[state] = 1 - problem_table[0, action - 1]
    return p

def get_new_policy(r, v):
    num_states = 32
    policy = np.zeros([num_states])
    for state in range(num_states):
        min_action_val = 10000000
        min_action = -1
        for a in range(1,jobs+1):
            if is_action_valid(state=state, action=a):
                val = r[state]+get_temp_p(state=state, action=a) @ v
                if val < min_action_val:
                    min_action_val = val
                    min_action = a
        policy[state] = min_action
    return policy
        

def policy_iteration(r, table):
    num_states = 2 ** len(table[0])
    new_policy = np.zeros([num_states])
    for state in range(num_states):
        new_policy[state] = get_policy_c(table=table, state=state)
    policy = np.zeros([num_states])
    while not (new_policy == policy).all():
        print("*"*20)
        print(new_policy)
        policy = new_policy.copy()
        p = calc_p(problem_table=table, policy=policy)
        v = get_v_iteration(p=p, r=r)
        print(v)
        new_policy = get_new_policy(r=r, v=v)
    return policy
        

In [7]:
jobs = 5
states = 2 ** jobs
cost = np.zeros([states])
problem_table = np.array([
    [0.6, 0.5, 0.3, 0.7, 0.1],
    [1,     4,   6,   2,   9]
])

for state in range(states):
    cost[state] = calc_cost_per_state(problem_table=problem_table, state=state)
    
policy = np.array([
    0, 1, 2, 1, 3, 1, 3, 1,  # 
    4, 1, 4, 1, 4, 1, 4, 1,  # 
    5, 1, 5, 1, 5, 1, 5, 1,  #
    5, 1, 5, 1, 5, 1, 5, 1   # 
])
p = calc_p(problem_table, policy)


In [8]:
policy_iteration(r=cost, table=problem_table)

********************
[0. 1. 2. 2. 3. 3. 3. 3. 4. 4. 2. 2. 3. 3. 3. 3. 5. 5. 5. 5. 5. 5. 5. 5.
 5. 5. 5. 5. 5. 5. 5. 5.]
[  0.           1.66666667   8.          11.66666667  20.
  25.          41.33333333  48.33333333   2.85714286   5.95238095
  14.85714286  19.95238095  29.52380952  35.95238095  54.85714286
  63.28571429  90.         101.66666667 138.         151.66666667
 170.         185.         231.33333333 248.33333333 112.85714286
 125.95238095 164.85714286 179.95238095 199.52380952 215.95238095
 264.85714286 283.28571429]
********************
[-1.  1.  2.  2.  3.  3.  2.  2.  4.  4.  2.  2.  3.  3.  2.  2.  5.  5.
  2.  2.  3.  3.  2.  2.  4.  4.  2.  2.  3.  3.  2.  2.]
[  0.           1.66666667   8.          11.66666667  20.
  25.          40.          47.           2.85714286   5.95238095
  14.85714286  19.95238095  29.52380952  35.95238095  53.52380952
  61.95238095  90.         101.66666667 116.         129.66666667
 140.         155.         178.         195.         105

array([-1.,  1.,  2.,  2.,  3.,  3.,  2.,  2.,  4.,  4.,  2.,  2.,  3.,
        3.,  2.,  2.,  5.,  5.,  2.,  2.,  3.,  3.,  2.,  2.,  4.,  4.,
        2.,  2.,  3.,  3.,  2.,  2.])