In [1]:
from matplotlib import pyplot as plt
import numpy as np
import gymnasium as gym
from edugym.agents import QLearningAgent, SarsaAgent
from gymnasium.envs.registration import register
import plotly.graph_objects as go

# Register the Tamagotchi environment only when its id is not already known.
# Re-running this cell (or importing a package that has registered the same
# id) would otherwise re-register it and emit gymnasium's
# "Overriding environment ... already in registry" warning.
if "edugym/Tamagotchi-v0" not in gym.registry:
    register(
        id="edugym/Tamagotchi-v0",
        entry_point="edugym.envs:TamagotchiEnv",
    )

  logger.warn(f"Overriding environment {new_spec.id} already in registry.")


In [2]:
# Agent hyperparameters shared by every experiment in this notebook.
learning_rate = 0.1  # passed as `learning_rate` to the QLearningAgent / SarsaAgent constructors
gamma = 1.0  # passed as `gamma` to the agent constructors
epsilon = 0.1  # passed to `Agent.train`; presumably the exploration rate -- TODO confirm

# Size of each experiment.
n_timesteps = 3001  # passed to `Agent.train` for every run
n_repetitions = 10  # number of runs averaged into each learning curve

# Swept settings: one averaged learning curve is produced per value.
temperatures = [0.1, 0.5, 1.0, 2.0, 4.0]  # each value is passed as `tau` to gym.make
message_lengths = [1, 2, 3, 4]  # each value is passed as `max_msg_length` to gym.make

## Q-Learning Agent

In [3]:
# Q-learning sweep over the environment temperature `tau`.
# For every temperature, `n_repetitions` agents are trained from scratch and
# their evaluation returns are averaged element-wise into one learning curve.
# `time_steps` (the x-axis returned by `train`) is reused by the plotting cells.
learning_curves_qlearning = []

for temperature in temperatures:
    results = []
    for rep in range(n_repetitions):
        env = gym.make("edugym/Tamagotchi-v0", tau=temperature)
        # Use a separate instance for evaluation: `train` receives both
        # environments, and sharing one object would let evaluation rollouts
        # (presumably run inside `train` -- confirm) reset or advance the
        # episode the agent is currently training on.
        eval_env = gym.make("edugym/Tamagotchi-v0", tau=temperature)
        try:
            agent = QLearningAgent(
                # `n_states` is a custom attribute of the underlying env; read it
                # from `unwrapped` because attribute forwarding through
                # gymnasium wrappers is deprecated.
                env.unwrapped.n_states,
                env.action_space.n,
                gamma=gamma,
                learning_rate=learning_rate,
            )
            time_steps, returns = agent.train(env, eval_env, epsilon, n_timesteps)
        finally:
            # Release both environments even if training raises.
            env.close()
            eval_env.close()
        results.append(returns)
    # Average over repetitions (axis 0) to get one curve per temperature.
    average_learning_curve = np.mean(np.array(results), axis=0)
    learning_curves_qlearning.append(average_learning_curve)
    print("Completed temperature: {}".format(temperature))

  logger.warn(
  logger.deprecation(
  logger.warn(f"{pre} was expecting a numpy array, actual type: {type(obs)}")
  logger.warn(f"{pre} is not within the observation space.")


Completed temperature: 0.1
Completed temperature: 0.5
Completed temperature: 1.0
Completed temperature: 2.0
Completed temperature: 4.0


In [4]:
# Draw one averaged Q-learning curve per temperature on a shared figure.
fig = go.Figure()
for idx, tau in enumerate(temperatures):
    trace = go.Scatter(
        x=time_steps,
        y=learning_curves_qlearning[idx],
        name="Tau: {}".format(tau),
    )
    fig.add_trace(trace)

# Figure styling: centred title, axis labels, legend pinned to the top-left.
layout_options = {
    "title": "Q-Learning",
    "title_x": 0.5,
    "xaxis_title": "Timesteps",
    "yaxis_title": "Average Return",
    "legend": {"yanchor": "top", "y": 0.99, "xanchor": "left", "x": 0.01},
    "font": {"family": "serif", "size": 12},
    "width": 900,
    "height": 500,
}
fig.update_layout(**layout_options)
fig.show()

## SARSA Agent

In [5]:
# SARSA sweep over the environment temperature `tau`.
# For every temperature, `n_repetitions` agents are trained from scratch and
# their evaluation returns are averaged element-wise into one learning curve.
# `time_steps` (the x-axis returned by `train`) is reused by the plotting cells.
learning_curves_sarsa = []

for temperature in temperatures:
    results = []
    for rep in range(n_repetitions):
        env = gym.make("edugym/Tamagotchi-v0", tau=temperature)
        # Use a separate instance for evaluation: `train` receives both
        # environments, and sharing one object would let evaluation rollouts
        # (presumably run inside `train` -- confirm) reset or advance the
        # episode the agent is currently training on.
        eval_env = gym.make("edugym/Tamagotchi-v0", tau=temperature)
        try:
            agent = SarsaAgent(
                # `n_states` is a custom attribute of the underlying env; read it
                # from `unwrapped` because attribute forwarding through
                # gymnasium wrappers is deprecated.
                env.unwrapped.n_states,
                env.action_space.n,
                gamma=gamma,
                learning_rate=learning_rate,
            )
            time_steps, returns = agent.train(env, eval_env, epsilon, n_timesteps)
        finally:
            # Release both environments even if training raises.
            env.close()
            eval_env.close()
        results.append(returns)
    # Average over repetitions (axis 0) to get one curve per temperature.
    average_learning_curve = np.mean(np.array(results), axis=0)
    learning_curves_sarsa.append(average_learning_curve)
    print("Completed temperature: {}".format(temperature))


[33mWARN: The result returned by `env.reset()` was not a tuple of the form `(obs, info)`, where `obs` is a observation and `info` is a dictionary containing additional information. Actual type: `<class 'numpy.int64'>`[0m


[33mWARN: Core environment is written in old step API which returns one bool instead of two. It is recommended to rewrite the environment with new step API. [0m


[33mWARN: The obs returned by the `step()` method was expecting a numpy array, actual type: <class 'numpy.int64'>[0m


[33mWARN: The obs returned by the `step()` method is not within the observation space.[0m



Completed temperature: 0.1
Completed temperature: 0.5
Completed temperature: 1.0
Completed temperature: 2.0
Completed temperature: 4.0


In [6]:
# Draw one averaged SARSA curve per temperature on a shared figure.
fig = go.Figure()
for idx, tau in enumerate(temperatures):
    trace = go.Scatter(
        x=time_steps,
        y=learning_curves_sarsa[idx],
        name="Tau: {}".format(tau),
    )
    fig.add_trace(trace)

# Figure styling: centred title, axis labels, legend pinned to the top-left.
layout_options = {
    "title": "SARSA",
    "title_x": 0.5,
    "xaxis_title": "Timesteps",
    "yaxis_title": "Average Return",
    "legend": {"yanchor": "top", "y": 0.99, "xanchor": "left", "x": 0.01},
    "font": {"family": "serif", "size": 12},
    "width": 900,
    "height": 500,
}
fig.update_layout(**layout_options)
fig.show()

In [8]:
# Compare Q-learning and SARSA at the two extreme temperatures only
# (first and last entries of `temperatures`).
fig = go.Figure()

# Map each legend label to its list of per-temperature learning curves.
# A dedicated name is used (rather than `results` / `type`) so the builtin
# `type` is not shadowed in the notebook's global namespace.
curves_by_agent = {
    "Qlearning": learning_curves_qlearning,
    "Sarsa": learning_curves_sarsa,
}
extreme_indices = (0, len(temperatures) - 1)

for agent_name, curves in curves_by_agent.items():
    for index in extreme_indices:
        name = "{} (tau: {})".format(agent_name, temperatures[index])
        fig.add_trace(go.Scatter(x=time_steps, y=curves[index], name=name))

# Customize layout
fig.update_layout(
    title="Qlearning vs. Sarsa",
    title_x=0.5,
    xaxis_title="Timesteps",
    yaxis_title="Average Return",
    legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
    font=dict(family="serif", size=12),
    width=900,
    height=500,
)
fig.show()

## Communication message length

In [9]:
# Q-learning sweep over the maximum message length at a fixed temperature.
# For every message length, `n_repetitions` agents are trained from scratch and
# their evaluation returns are averaged element-wise into one learning curve.
# `time_steps` (the x-axis returned by `train`) is reused by the plotting cell.
learning_curves_messages = []

for msg_length in message_lengths:
    results = []
    for rep in range(n_repetitions):
        env = gym.make("edugym/Tamagotchi-v0", tau=.1, max_msg_length=msg_length)
        # Use a separate instance for evaluation: `train` receives both
        # environments, and sharing one object would let evaluation rollouts
        # (presumably run inside `train` -- confirm) reset or advance the
        # episode the agent is currently training on.
        eval_env = gym.make("edugym/Tamagotchi-v0", tau=.1, max_msg_length=msg_length)
        try:
            agent = QLearningAgent(
                # `n_states` is a custom attribute of the underlying env; read it
                # from `unwrapped` because attribute forwarding through
                # gymnasium wrappers is deprecated.
                env.unwrapped.n_states,
                env.action_space.n,
                gamma=gamma,
                learning_rate=learning_rate,
            )
            time_steps, returns = agent.train(env, eval_env, epsilon, n_timesteps)
        finally:
            # Release both environments even if training raises.
            env.close()
            eval_env.close()
        results.append(returns)
    # Average over repetitions (axis 0) to get one curve per message length.
    average_learning_curve = np.mean(np.array(results), axis=0)
    learning_curves_messages.append(average_learning_curve)
    print("Completed message length: {}".format(msg_length))

Completed message length: 1
Completed message length: 2
Completed message length: 3
Completed message length: 4


In [10]:
# Draw one averaged learning curve per maximum message length.
fig = go.Figure()
for idx, length in enumerate(message_lengths):
    trace = go.Scatter(
        x=time_steps,
        y=learning_curves_messages[idx],
        name="length: {}".format(length),
    )
    fig.add_trace(trace)

# Figure styling: centred title, axis labels, legend pinned to the top-left.
layout_options = {
    "title": "Message Length",
    "title_x": 0.5,
    "xaxis_title": "Timesteps",
    "yaxis_title": "Average Return",
    "legend": {"yanchor": "top", "y": 0.99, "xanchor": "left", "x": 0.01},
    "font": {"family": "serif", "size": 12},
    "width": 900,
    "height": 500,
}
fig.update_layout(**layout_options)
fig.show()

## Tests (Remove once done)

In [None]:
# Smoke test: drive the environment with random actions for `n_steps` steps,
# printing its internal state, the sampled action and the resulting reward.
n_steps = 100
env = gym.make("edugym/Tamagotchi-v0", tau=0.01, max_msg_length=2)
# Custom attributes (internal_vars, weights, ...) are read from the unwrapped
# environment because attribute forwarding through gymnasium wrappers is
# deprecated.
tamagotchi = env.unwrapped
# Seed the action sampler as well so the random rollout is repeatable.
env.action_space.seed(42)

rewards = []
try:
    observation = env.reset(seed=42)
    for i in range(n_steps):
        print(f"Step: {i}")
        action = env.action_space.sample()

        # action = tamagotchi.required_action
        print(
            tamagotchi.internal_vars,
            tamagotchi.weights,
            tamagotchi.happiness,
            tamagotchi.required_action,
            action,
        )

        # step() is unpacked as a 4-tuple here (single termination flag).
        observation, reward, terminated, info = env.step(action)
        print(f"Observation: {observation}, Reward: {reward}")
        # env.render()
        print('-'*100)

        if terminated:
            print('Tamagotchi died.. ')
            observation = env.reset()
        rewards.append(reward)
finally:
    # Close the environment even if a step raises.
    env.close()

# plt.plot(range(n_steps), rewards)