In [1]:
from random import randint
import numpy as np

# Build environment and agent

In [31]:
class State(object):
    def __init__(self, representation, actions, persistent=True) -> None:
        self.representation = representation
        self.actions = actions
        self.transitions = {'null': ([self], np.ones(1))} if persistent else {}

    def __repr__(self):
        return f'State {self.representation}'


class MonsterTaskEnv(object):
    def __init__(self, states, agent, init_state_ind=0) -> None:
        self.states = states
        self.state = states[init_state_ind]
        self.agent = agent
        self.agent.env = self

    def transition(self, action):
        new_states, probas = self.state.transitions[action]
        new_state = np.random.choice(new_states, size=1, p=probas)[0]
        self.state = new_state
        return new_state


class Agent(object):
    def __init__(self, actions, env=None) -> None:
        self.env = env
        self.actions = actions
        self.policy = None

    def generate_responses(self, state):
        # Responses
        return list(set(state.actions).intersection(self.actions))

    def generate_predictions(self, state, responses):
        # Predictions should be a probability distribution over next states
        predictions = {}
        for response in responses:
            predictions[response] = (state.transitions[response])
        return None

    def choose_response(self, responses, predictions):
        # For now, ignore predictions and choose randomly
        if len(responses) > 1:
            return responses[randint(0, len(responses) - 1)]
        else:
            return responses[0]

    def learn(self, *args):
        # Here, the agent should update its models or response tendencies.
        pass

    def step(self, auto):
        if not auto:
            usr_inp = input('Press enter to continue')
            if usr_inp == 'exit':
                return False
        state = self.env.state
        responses = self.generate_responses(state)
        predictions = self.generate_predictions(state, responses)
        choice = self.choose_response(responses, predictions)
        new_state = self.env.transition(choice)
        self.learn(state, responses, predictions, choice, new_state)
        return True

    def go(self, timesteps, auto=True):
        for t in range(timesteps):
            print(f'Step {t}')
            if not self.step(auto):
                break

    def give_control(self, max_timesteps):
        for i in range(max_timesteps):
            state = self.env.state
            responses = self.generate_responses(state)
            
            response = '...'
            if state.representation[0] == 1:
                print(f'Choose monster family {responses}')
            elif state.representation[0] == 2:
                print(f'What does this monster like to eat? {responses}\nFamily {state.representation[1]}, {state.representation[2:4]}')
            elif state.representation[0] == 3:
                if state.representation[-1] == 1:
                    print('Correct!')
                if state.representation[-1] == 2:
                    print('Incorrect!')
                response = 'null'
            while response not in responses:
                response = input()
                if response == 'exit':
                    break
            self.env.transition(response)

            

# Define states

## Explanation

There are three kinds of states in the environment, each characterized by a unique, yet semantically meaningful representation, and a set of actions available. Each state is represented by a 5-element vector where each element encodes some qualitative information about the state.

For example, the choice state, represented as a vector `[1, 0, 0, 0, 0]`. The `1` in the first position indicates that this is the choice stage and 0 values in other positions indicate that the features characterizing other states are absent in this state. There are 4 actions available in this state: `a1`, `a2`, `a3`, and `a4`. These actions correspond to family choices and each of them leads to a different set of states.

For example, action `a1` leads to one of the states that have `2` in the first position, and `1` in the second position, e.g., `[2, 1, 3, 5, 0]`. The leading `2` encodes the fact that the state is a guess state (not a choice state), and values in the 3rd and 4th positions encode the features of a particular stimulus (monster) presented to the agent. Each of the guess states have two actions whereby all guess states presenting monsters from the same family have the same pair of actions (the same food choices). 

The last example state `[2, 1, 3, 5, 0]` has actions `f11` and `f12` corresponding to two food choices. Each action, depending on the rule, leads to the third type of state, the feedback state, which can be either `[3, 0, 0, 0, 1]` for a positive feedback or `[3, 0, 0, 0, 1]` for the negative feedback. As you can see, the states differ semantically, but there is nothing inherently positive or negative about any particular state. The feedback states both have a null action `null` that inadvertently leads to the choice state. The `null` action is present in all other states, but in these cases, it results in an "autotransition" so that the state does not change.

To simplify things, we can imagine that instead of the feedback states being evaluated by the agent, the environment emits a reward for the agent to process (how it is usually done in RL).

## Code

In [32]:
def xnor(a, b):
    return (a and b) or (not a and not b)


states = []

states += [
    State(representation=np.array([1, 0, 0, 0, 0]), actions=['a1', 'a2', 'a3', 'a4', 'null']),
    State(representation=np.array([3, 0, 0, 0, 1]), actions=['null'], persistent=False),
    State(representation=np.array([3, 0, 0, 0, 2]), actions=['null'], persistent=False)
]

for i in range(1, 7):
    states.append(
        State(representation=np.array([2, 1, i, 3, 0]), actions=['f11', 'f12', 'null']),
    )

for i in range(2, 5):
    actions = [f'f{i}1', f'f{i}2', 'null']
    for j in range(1, 7):
        for k in range(1, 7):
            states.append(
                State(representation=np.array([2, i, j, k, 0]), actions=actions),
            )

# Transitions from choice state
for act in states[0].actions[:-1]:
    to_states = [state for state in states[3:] if state.representation[1] == int(act[-1])]
    states[0].transitions[act] = (to_states, np.ones(len(to_states))/len(to_states))

# Transitions from feedback state
states[1].transitions['null'] = ([states[0]], np.ones(1))
states[2].transitions['null'] = ([states[0]], np.ones(1))

# Transitions from guess states
for state in states[3:]:
    # 1d1 rule
    if state.representation[1] == 1:
        rule = {'f11': states[1], 'f12': states[2]} if state.representation[2] < 4 else {'f11': states[2], 'f12': states[1]}
        state.transitions['f11'] = ([rule['f11']], np.ones(1))
        state.transitions['f12'] = ([rule['f12']], np.ones(1))
    # 2d1 rule
    elif state.representation[1] == 2:
        rule = {'f21': states[1], 'f22': states[2]} if state.representation[2] < 4 else {'f21': states[2], 'f22': states[1]}
        state.transitions['f21'] = ([rule['f21']], np.ones(1))
        state.transitions['f22'] = ([rule['f22']], np.ones(1))
    # 2d2 rule
    elif state.representation[1] == 3:
        rule = {'f31': states[1], 'f32': states[2]} if xnor(state.representation[2] < 4, state.representation[3] < 4) else {'f31': states[2], 'f32': states[1]}
        state.transitions['f31'] = ([rule['f31']], np.ones(1))
        state.transitions['f32'] = ([rule['f32']], np.ones(1))
    elif state.representation[1] == 4:
        state.transitions['f41'] = ([states[1], states[2]], np.ones(2)/2)
        state.transitions['f42'] = ([states[1], states[2]], np.ones(2)/2)


# Simulation

In [35]:
agent = Agent(actions=['null', 'a1','a2','a3','a4','f11','f12','f21','f22','f31','f32','f41','f42'])
env = MonsterTaskEnv(states=states, agent=agent)
# agent.go(50)
agent.give_control(20)

Choose monster family ['a2', 'null', 'a1', 'a4', 'a3']
What does this monster like to eat? ['null', 'f12', 'f11']
Family 1, [6 3]
Incorrect!
Choose monster family ['a2', 'null', 'a1', 'a4', 'a3']
What does this monster like to eat? ['null', 'f12', 'f11']
Family 1, [3 3]
Correct!
Choose monster family ['a2', 'null', 'a1', 'a4', 'a3']
What does this monster like to eat? ['null', 'f12', 'f11']
Family 1, [2 3]
Correct!
Choose monster family ['a2', 'null', 'a1', 'a4', 'a3']


KeyError: 'exit'