In [1]:
import numpy as np

In [2]:
# Model file parsed by this cell (relative to the notebook's directory).
POMDP_PATH = './Example Models/tiger.95.POMDP'

# Entity list addressed by each ':'-separated field of a T/O/R directive, in file order.
DIRECTIVE_FIELDS = {
    'T': ('actions', 'states', 'states'),                  # T: <action> : <start-state> : <end-state>
    'O': ('actions', 'states', 'observations'),            # O: <action> : <end-state> : <observation>
    'R': ('actions', 'states', 'states', 'observations'),  # R: <action> : <start-state> : <end-state> : <observation>
}
DIRECTIVE_TABLES = {'T': 'transition_table', 'O': 'observation_table', 'R': 'reward_table'}
# Fewest fields a directive may give; a bare "R: <action>" has no defined value layout.
DIRECTIVE_MIN_FIELDS = {'T': 1, 'O': 1, 'R': 2}
# Name prefix for anonymous entities declared by count (e.g. "states: 3" -> s0, s1, s2).
ENTITY_PREFIX = {'states': 's', 'actions': 'a', 'observations': 'o'}
ENTITY_COUNT_KEY = {'states': 'state_count', 'actions': 'action_count', 'observations': 'observation_count'}


def _parse_entities(tokens: list, prefix: str) -> list:
    """Return the entity names declared by a states/actions/observations header.

    A single integer N declares N anonymous entities named f'{prefix}0', f'{prefix}1', ...;
    any other token list is taken as the entity names themselves.
    """
    if len(tokens) == 1 and tokens[0].isdecimal():
        return [f'{prefix}{i}' for i in range(int(tokens[0]))]
    return list(tokens)


def _resolve_ids(token: str, names: list) -> list:
    """Return the indices into `names` selected by one directive field.

    '*' selects every entry, a decimal token is a 0-based index, anything else
    is looked up by name.

    Raises:
        ValueError: if `token` is not a declared name.
    """
    if token == '*':
        return list(range(len(names)))
    if token.isdecimal():
        return [int(token)]
    if token not in names:
        raise ValueError(f'Unknown name {token!r}; expected one of {names}')
    return [names.index(token)]


def _ensure_tables(params: dict) -> None:
    """Allocate NaN-filled T/O/R tables once states, actions and observations are all known.

    NaN marks entries the file never set. Axis layout:
        transition_table  [s, a, s']
        observation_table [s', a, o]
        reward_table      [s, a, s', o]
    """
    if 'transition_table' in params or not all(k in params for k in ENTITY_COUNT_KEY):
        return
    n_s, n_a, n_o = params['state_count'], params['action_count'], params['observation_count']
    params['transition_table'] = np.full((n_s, n_a, n_s), np.nan)
    params['observation_table'] = np.full((n_s, n_a, n_o), np.nan)
    params['reward_table'] = np.full((n_s, n_a, n_s, n_o), np.nan)


def _new_pending(label: str, shape: tuple, keywords: tuple, commit) -> dict:
    """Create a record for a block of values still being read.

    `shape` is the block's shape (() for a single entry), `keywords` the
    shorthand words accepted instead of numbers, and `commit` receives the
    finished block as a numpy array.
    """
    return {'label': label, 'shape': shape, 'keywords': keywords, 'values': [], 'commit': commit}


def _feed(pending: dict, tokens: list) -> bool:
    """Consume value tokens for `pending`; return True once its block has been committed.

    Numbers may be spread over several lines. An allowed keyword ('uniform',
    'identity') must appear alone, before any number.

    Raises:
        ValueError: on surplus values or tokens that are not numbers.
    """
    shape = pending['shape']
    if tokens and not pending['values'] and tokens[0] in pending['keywords']:
        if len(tokens) > 1:
            raise ValueError(f"{pending['label']!r}: unexpected values after {tokens[0]!r}")
        if tokens[0] == 'identity':
            block = np.eye(shape[0])  # only offered for square T: <action> matrices
        else:  # 'uniform': every row is a uniform distribution over the last axis
            block = np.full(shape, 1.0 / shape[-1])
        pending['commit'](block)
        return True

    pending['values'].extend(float(token) for token in tokens)
    size = int(np.prod(shape))  # np.prod(()) == 1, so a single entry needs one value
    if len(pending['values']) > size:
        raise ValueError(f"{pending['label']!r}: expected {size} values, got {len(pending['values'])}")
    if len(pending['values']) < size:
        return False
    pending['commit'](np.array(pending['values']).reshape(shape))
    return True


def _start_pending(params: dict, line: str, tokens: list) -> tuple:
    """Set up reading of the initial belief from a 'start' line.

    Returns (pending, inline_tokens), where inline_tokens are the values given
    on the 'start' line itself.
    """
    if 'state_count' not in params:
        raise ValueError(f"{line!r}: 'start' must come after the 'states' header")
    if tokens[1:2] in (['include'], ['exclude']):
        raise NotImplementedError(f"{line!r}: 'start include/exclude' is not supported")

    def commit(belief):
        params['init_belief'] = belief

    return _new_pending(line, (params['state_count'],), ('uniform',), commit), tokens[1:]


def _directive_pending(params: dict, line: str, tokens: list) -> tuple:
    """Set up the value block announced by a T/O/R directive line.

    The number of ':' on the line is the number of fields given. The remaining
    trailing fields form the block of values that follows (a scalar, a row or
    a matrix).

    Returns (pending, inline_tokens), where inline_tokens are the values that
    follow the fields on the same line.
    """
    letter = tokens[0]
    table_key = DIRECTIVE_TABLES[letter]
    if table_key not in params:
        raise ValueError(f'{line!r}: T/O/R directives must come after the states, actions and observations headers')
    field_kinds = DIRECTIVE_FIELDS[letter]
    n_fields = line.count(':')
    field_tokens = tokens[1:1 + n_fields]
    if not DIRECTIVE_MIN_FIELDS[letter] <= n_fields <= len(field_kinds) or len(field_tokens) != n_fields:
        raise ValueError(f'{line!r}: malformed {letter} directive')

    ids = [_resolve_ids(token, params[kind]) for token, kind in zip(field_tokens, field_kinds)]
    shape = tuple(params[ENTITY_COUNT_KEY[kind]] for kind in field_kinds[n_fields:])
    if letter == 'R' or not shape:
        keywords = ()
    elif letter == 'T' and n_fields == 1:
        keywords = ('uniform', 'identity')
    else:
        keywords = ('uniform',)

    # Tables store the first two directive fields swapped ([s, a, ...]); swapaxes
    # returns a writable view laid out in directive order ([a, s, ...]).
    table_view = np.swapaxes(params[table_key], 0, 1)
    target = np.ix_(*ids, *(np.arange(n) for n in shape))

    def commit(block):
        table_view[target] = block  # broadcasts the block over every selected field index

    return _new_pending(line, shape, keywords, commit), tokens[1 + n_fields:]


def load_pomdp(path: str) -> dict:
    """Parse a POMDP model file in Cassandra's .POMDP text format.

    Returns a dict with 'gamma', 'values', 'states'/'state_count',
    'actions'/'action_count', 'observations'/'observation_count',
    'transition_table' [s, a, s'], 'observation_table' [s', a, o],
    'reward_table' [s, a, s', o] and 'init_belief' [s]. Table entries the
    file never sets stay NaN. A missing 'start' defaults to a uniform belief.

    Raises:
        ValueError: on malformed or unrecognised lines, or if the file ends
            while a block of values is incomplete.
    """
    params = {}
    pending = None  # value block still waiting for numbers from later lines
    with open(path) as file:
        for raw_line in file:
            line = raw_line.split('#', 1)[0].strip()  # drop comments, surrounding whitespace and line endings
            if not line:
                continue
            if pending is not None:
                if _feed(pending, line.split()):
                    pending = None
                continue

            tokens = line.replace(':', ' ').split()
            keyword = tokens[0]
            if keyword == 'discount':
                params['gamma'] = float(tokens[1])
            elif keyword == 'values':
                # NOTE: 'cost' models are stored as-is; reward_table is not negated.
                params['values'] = tokens[1]
            elif keyword in ENTITY_COUNT_KEY:
                names = _parse_entities(tokens[1:], ENTITY_PREFIX[keyword])
                params[keyword] = names
                params[ENTITY_COUNT_KEY[keyword]] = len(names)
                _ensure_tables(params)
            elif keyword in ('start', *DIRECTIVE_FIELDS):
                opener = _start_pending if keyword == 'start' else _directive_pending
                pending, inline = opener(params, line, tokens)
                if _feed(pending, inline):
                    pending = None  # the whole block was given on this line
            else:
                raise ValueError(f'Unrecognised line: {line!r}')

    if pending is not None:
        raise ValueError(f"File ended while reading values for {pending['label']!r}")
    if 'init_belief' not in params and params.get('state_count'):
        params['init_belief'] = np.full(params['state_count'], 1.0 / params['state_count'])
    return params


loaded_params = load_pomdp(POMDP_PATH)

In [3]:
# Sanity-check the parsed model before it is handed to the solver. The probability
# tables must be fully specified (an unset NaN entry makes the sum NaN) and
# row-stochastic. Axis layouts: transition [s, a, s'], observation [s', a, o].
assert np.allclose(loaded_params['transition_table'].sum(axis=-1), 1.0), 'each T[s, a, :] must sum to 1'
assert np.allclose(loaded_params['observation_table'].sum(axis=-1), 1.0), 'each O[s_p, a, :] must sum to 1'
assert np.isclose(loaded_params['init_belief'].sum(), 1.0), 'initial belief must sum to 1'
loaded_params

{'gamma': 0.95,
 'values': 'reward',
 'states': ['tiger-left', 'tiger-right'],
 'state_count': 2,
 'actions': ['listen', 'open-left', 'open-right'],
 'action_count': 3,
 'observations': ['tiger-left', 'tiger-right'],
 'observation_count': 2,
 'transition_table': array([[[1. , 0. ],
         [0.5, 0.5],
         [0.5, 0.5]],
 
        [[0. , 1. ],
         [0.5, 0.5],
         [0.5, 0.5]]]),
 'observation_table': array([[[0.85, 0.15],
         [0.5 , 0.5 ],
         [0.5 , 0.5 ]],
 
        [[0.15, 0.85],
         [0.5 , 0.5 ],
         [0.5 , 0.5 ]]]),
 'reward_table': array([[[[  -1.,   -1.],
          [  -1.,   -1.]],
 
         [[-100., -100.],
          [-100., -100.]],
 
         [[  10.,   10.],
          [  10.,   10.]]],
 
 
        [[[  -1.,   -1.],
          [  -1.,   -1.]],
 
         [[  10.,   10.],
          [  10.,   10.]],
 
         [[-100., -100.],
          [-100., -100.]]]]),
 'init_belief': array([0.5, 0.5])}

In [5]:
import sys
sys.path.append('..')
from src.pomdp import Model, PBVI_Solver

In [6]:
# Build the model from the parsed parameters. Arguments are positional, so this
# key order must match Model's signature in src.pomdp (not visible here).
# NOTE(review): 'init_belief' is parsed but not passed on; confirm Model does not need it.
loaded_model = Model(*(loaded_params[key] for key in (
    'states',
    'actions',
    'observations',
    'transition_table',
    'reward_table',
    'observation_table',
)))

# The solver is constructed from the discount factor alone.
loaded_solver = PBVI_Solver(loaded_params['gamma'])

In [8]:
# Inspect how the model stores its actions. The recorded output shows [0, 1, 2],
# not the action names passed to Model.
loaded_model.actions

[0, 1, 2]