In [5]:
from dojo import Dojo
from tf_agents.environments import suite_gym, tf_py_environment, py_environment
from tf_agents.networks.q_network import QNetwork
from tf_agents.specs import array_spec
from tf_agents.trajectories import time_step as ts
import numpy as np

In [7]:
class CardGameEnv(py_environment.PyEnvironment):
  """Single-player card game resembling a simplified blackjack.

  The running total starts at 0. At every step the agent picks one of two
  actions:

    * 0 -- draw a card whose value (1..10) is added to the total;
    * 1 -- stop and end the round.

  The episode terminates when the agent stops or when the total reaches 21 or
  more. The terminal reward is ``total - 21`` (so at most 0) when the total
  does not exceed 21, and -21 when it does. Every non-terminal step yields a
  reward of 0.

  Observations are int32 arrays of shape ``(1,)`` holding the current total.
  """

  def __init__(self):
    # Run the base-class initialiser so that its internal state is set up
    # before the environment is used.
    super().__init__()
    self._action_spec = array_spec.BoundedArraySpec(
        shape=(), dtype=np.int32, minimum=0, maximum=1, name='action')
    self._observation_spec = array_spec.BoundedArraySpec(
        shape=(1,), dtype=np.int32, minimum=0, name='observation')
    self._state = 0
    self._episode_ended = False

  def action_spec(self):
    """Returns the spec of the scalar action (0 = draw, 1 = stop)."""
    return self._action_spec

  def observation_spec(self):
    """Returns the spec of the shape-(1,) observation (the running total)."""
    return self._observation_spec

  def _reset(self):
    """Starts a new round with a total of 0 and returns the first time step."""
    self._state = 0
    self._episode_ended = False
    return ts.restart(np.array([self._state], dtype=np.int32))

  def _step(self, action):
    """Applies `action` and returns the resulting time step.

    Args:
      action: 0 to draw another card, 1 to stop.

    Returns:
      A transition time step while the round continues, a termination time
      step carrying the final reward when it ends, or the first time step of
      a fresh round if the previous step already ended the episode.

    Raises:
      ValueError: If `action` is neither 0 nor 1.
    """
    if self._episode_ended:
      # The last action ended the episode. Ignore the current action and start
      # a new episode.
      return self.reset()

    if action == 1:
      # The agent chose to stop.
      self._episode_ended = True
    elif action == 0:
      # Upper bound of randint is exclusive, so cards are worth 1..10.
      new_card = np.random.randint(1, 11)
      self._state += new_card
    else:
      raise ValueError('`action` should be 0 or 1.')

    if self._episode_ended or self._state >= 21:
      # Mark the episode as finished in *both* cases. Previously reaching or
      # exceeding 21 by drawing returned a terminal step without setting the
      # flag, so the following call kept adding cards to the finished round
      # instead of starting a new one.
      self._episode_ended = True
      reward = self._state - 21 if self._state <= 21 else -21
      return ts.termination(np.array([self._state], dtype=np.int32), reward)

    return ts.transition(
        np.array([self._state], dtype=np.int32), reward=0.0, discount=1.0)

In [8]:
# Wrap the pure-Python card game in a TF environment so that time steps are
# exchanged as batched tensors. The Python environment keeps its own name
# instead of being overwritten by the wrapper.
# (A Gym task could be used instead, e.g. suite_gym.load('CartPole-v1').)
py_env = CardGameEnv()
env = tf_py_environment.TFPyEnvironment(py_env)
env.reset()

TimeStep(
{'discount': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>,
 'observation': <tf.Tensor: shape=(1, 1), dtype=int32, numpy=array([[0]], dtype=int32)>,
 'reward': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.], dtype=float32)>,
 'step_type': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([0], dtype=int32)>})

In [9]:
# Build the Q-value network from the environment's observation/action specs.
q_net = QNetwork(
    input_tensor_spec=env.observation_spec(),
    action_spec=env.action_spec(),
)

In [10]:
# Train the Q-network on the card-game environment. The printed log reports
# loss and return every 200 steps.
# NOTE(review): the argument is presumably the number of training steps --
# confirm against Dojo.train.
train_budget = 10000
dojo = Dojo(q_net, env)
dojo.train(train_budget)

step = 200: loss = 24.986818313598633 return = -12.5
step = 400: loss = 4.448462963104248 return = -11.100000381469727
step = 600: loss = 4.064504623413086 return = -6.300000190734863
step = 800: loss = 3.6110057830810547 return = -6.699999809265137
step = 1000: loss = 3.1352672576904297 return = -5.099999904632568
step = 1200: loss = 3.2211601734161377 return = -6.300000190734863
step = 1400: loss = 3.2564537525177 return = -5.0
step = 1600: loss = 2.854937791824341 return = -6.0
step = 1800: loss = 2.9084932804107666 return = -4.800000190734863
step = 2000: loss = 3.150632858276367 return = -4.699999809265137
step = 2200: loss = 3.1324493885040283 return = -6.400000095367432
step = 2400: loss = 2.9978902339935303 return = -5.099999904632568
step = 2600: loss = 3.132115125656128 return = -4.599999904632568
step = 2800: loss = 3.04097843170166 return = -4.0
step = 3000: loss = 3.404898166656494 return = -5.900000095367432
step = 3200: loss = 2.818737506866455 return = -6.09999990463256