In [1]:
import sys
if not '..' in sys.path:
    sys.path.append('..')
from simplegrid.cow import Action
import numpy as np
import random

from unittest.mock import MagicMock
from shared.experiment_settings import ExperimentSettings
from simplegrid.deep_cow import DeepCow
from simplegrid.dqn_agent import DQNAgent
from simplegrid.world import World as World, MapFeature

In [2]:
settings = ExperimentSettings('')
settings.world_size = 5
settings.start_num_creatures = 0
settings.layers = [12]
DeepCow.agent = None
deepcow = DeepCow(x=2, y=2, energy=100, settings=settings)
world = World(settings, MagicMock())

In [3]:
def training_record(world, cow, grass_fraction=0.25, water_fraction=0.10):
    world.reset(MagicMock(), grass_fraction=grass_fraction, water_fraction=water_fraction)
    cow.x = 2
    cow.y = 2
    world.add_new_creature(cow)
    observation = world.get_observation(cow)
    action = cow.step(observation)
    new_creature, reward, done = world.process_action(cow, action)
    if done:
        next_state = None
    else:
        next_observation = world.get_observation(cow)
        next_state = cow.to_internal_state(next_observation)
    return cow.state, cow.action_idx, reward, next_state

training_record(world, deepcow)

(array([1., 0., 1., 0., 0., 0., 1., 1., 0., 1., 0., 1., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.]),
 3,
 0.8666666666666667,
 array([1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0.,
        0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
        0., 0.]))

In [12]:
def run_scenario(scenario, world, cow):
    environment = MapFeature.text_scene_to_environment(scenario)
    world.cells = environment
    cow.x = 2
    cow.y = 2
    world.add_new_creature(cow)
    observation = world.get_observation(cow)
    state = cow.to_internal_state(observation)
    act_values = cow.agent.predict(state)
    action_index = np.argmax(act_values[0])
    return [(Action(action_index + 1), act_values[0][action_index])
           for action_index in np.argsort(act_values[0])]

run_scenario('.....\n'
             '.....\n'
             '.....\n'
             '.....\n'
             '..~..\n',
             world, deepcow
            )

[(DOWN, -0.08910842),
 (LEFT, -0.088489294),
 (UP, -0.07259474),
 (RIGHT, -0.07130667)]

In [6]:
def score(world, deepcow):
    score = 0
    for i in range(1000):
        state, action, reward, next_state = training_record(world, deepcow)
        if next_state is None:
            score -= 10
        else:
            score += reward
    return score / 1000

score(world, deepcow)

-1.8710000000000429

In [7]:
records = [training_record(world, deepcow) for _ in range(10000)]
records[0]

(array([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0]),
 3,
 -0.13333333333333333,
 array([1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))

In [8]:
agent = deepcow.agent
for rec in records:
    agent.remember(*rec)

In [9]:
for epoch in range(10):
    for __ in range(len(records) // agent.batch_size):
        loss = agent.replay()
    print(epoch, loss)

0 0.2340843577325965
1 0.1237001986359246
2 0.08668299736067031
3 0.08420823316167418
4 0.07991760802008988
5 0.07443107463380633
6 0.08026132208760828
7 0.06388902496546506
8 0.07735907312016935
9 0.09507555242938301


In [10]:
score(world, deepcow)

0.613666666666672

In [13]:
run_scenario('.....\n'
             '.....\n'
             '.....\n'
             '..~..\n'
             '.....\n',
             world, deepcow
            )

[(DOWN, -0.22786957),
 (UP, -0.08288756),
 (RIGHT, -0.07165544),
 (LEFT, -0.068953216)]

In [14]:
run_scenario('.....\n'
             '.....\n'
             '.....\n'
             '.....\n'
             '..#..\n',
             world, deepcow
            )

[(LEFT, -0.09282464),
 (RIGHT, -0.07800578),
 (UP, -0.03955635),
 (DOWN, 0.030865952)]

In [18]:
run_scenario('#....\n'
             '.....\n'
             '.....\n'
             '.....\n'
             '.....\n',
             world, deepcow
            )

[(RIGHT, -0.09073467),
 (LEFT, -0.087608606),
 (DOWN, -0.06990122),
 (UP, -0.051943235)]