In [30]:
import random ,argparse, sys
parser = argparse.ArgumentParser()
import numpy as np

In [31]:
class MDPValueIteration:
    """Value-iteration solver for an MDP described in a whitespace-separated text file.

    The first two non-blank lines must carry the number of states and the
    number of actions as their second token.  Every line is then dispatched
    on its first token:

    * ``end s1 s2 ...``          -- state indices appended to ``end_states``;
      their values are pinned to 0 during iteration (``-1`` is skipped,
      presumably meaning "no terminal state").
    * ``transition s a s' r p``  -- reward ``r`` and probability ``p`` of
      reaching ``s'`` from ``s`` under action ``a``.
    * ``mdptype <name>``         -- stored verbatim in ``self.type``.
    * ``discount <gamma>``       -- discount factor (defaults to 1.0).

    Lines starting with any other token are ignored.
    """

    def __init__(self, file_path, epsilon = 1e-9):
        """Load the MDP from ``file_path`` and set up the iteration buffers.

        Args:
            file_path: path of the MDP description file.
            epsilon: convergence threshold on the per-state value change
                between two consecutive sweeps.
        """
        with open(file_path, 'r') as file:
            lines = [line.split() for line in file]
        self.extract_mdp(lines)
        self.epsilon = epsilon
        self.values_old = np.zeros(self.num_of_states)
        self.values_new = np.ones(self.num_of_states)
        # Actions are indices, so keep them integral from the start.
        self.actions_old = np.zeros(self.num_of_states, dtype=int)
        self.actions_new = np.zeros(self.num_of_states, dtype=int)

    def extract_mdp(self, lines):
        """Populate the MDP attributes from tokenised file lines.

        Args:
            lines: list of token lists, one per line of the input file.
                Empty token lists (blank lines) are skipped.
        """
        # A blank line tokenises to [] and would otherwise raise IndexError
        # when its first token is inspected.
        lines = [tokens for tokens in lines if tokens]

        self.num_of_states = int(lines[0][1])
        self.num_of_actions = int(lines[1][1])
        self.end_states = []
        self.type = "None"
        self.gamma = 1.0

        # Both tables are indexed as [state, action, next_state].
        shape = (self.num_of_states, self.num_of_actions, self.num_of_states)
        self.transitions = np.zeros(shape)
        self.reward = np.zeros(shape)

        for tokens in lines:
            keyword = tokens[0]
            if keyword == "end":
                self.end_states.extend(int(tok) for tok in tokens[1:])
            elif keyword == "transition":
                state, act, nxt = (int(tok) for tok in tokens[1:4])
                # File order is "... reward probability".
                self.reward[state, act, nxt] = float(tokens[4])
                self.transitions[state, act, nxt] = float(tokens[5])
            elif keyword == "mdptype":
                self.type = tokens[1]
            elif keyword == "discount":
                self.gamma = float(tokens[1])

    def value_iterate(self):
        """Run value iteration until no state value changes by ``epsilon`` or more.

        Prints the number of sweeps that updated the value estimate.  There is
        no iteration cap: the loop only ends once the convergence test passes.

        Returns:
            tuple ``(value, action)``: the converged state values and, for
            each state, the index of the action with the highest Q-value.
        """
        # Terminal states keep value 0; -1 entries are skipped.
        terminal = np.array([s for s in self.end_states if s != -1], dtype=int)

        i = 0
        while True:
            # Q[s, a] = sum over s' of T[s, a, s'] * (R[s, a, s'] + gamma * V[s'])
            Q_value = np.sum(
                self.transitions * (self.reward + self.gamma * self.values_old),
                axis=2,
            )
            self.actions_new = np.argmax(Q_value, axis=1)
            self.values_new = np.max(Q_value, axis=1)
            self.values_new[terminal] = 0

            if np.all(np.abs(self.values_new - self.values_old) < self.epsilon):
                break

            self.values_old = self.values_new.copy()
            self.actions_old = self.actions_new.copy()
            i += 1
        print(i)

        # Return the result of the final sweep.  The *_old buffers lag one
        # sweep behind and still hold the all-zero initial guess if the loop
        # exits on its first pass.
        return self.values_new, self.actions_new


# Solve the MDP stored at the path below and print one "<value> <action>"
# row per state, in state-index order.
mdp = MDPValueIteration("./data/episodic-mdp-10-5.txt")
value, action = mdp.value_iterate()

for state_value, best_action in zip(value, action):
    # print() joins the str() of each argument with a single space.
    print(state_value, best_action)

42355
0.0 0
530.2198198940359 3
530.513671332707 4
504.79864465483024 2
472.94806296920024 1
0.0 0
526.9529899925243 2
518.4643115613788 2
354.45767841147875 4
529.2921428030874 0
