In [1]:
import sys
if not '..' in sys.path:
    sys.path.append('..')
from simplegrid.cow import Action
import numpy as np
import random

from unittest.mock import MagicMock
from shared.experiment_settings import ExperimentSettings
from simplegrid.deep_cow import DeepCow
from simplegrid.dqn_agent import DQNAgent
from simplegrid.world import World as World, MapFeature

In [2]:
# Seed the global RNGs that are already imported so repeated runs are more
# comparable. NOTE(review): the DQN backend may keep its own random state
# (e.g. for weight initialisation) that needs separate seeding - confirm.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

# Small experiment configuration: world size 5 and no creatures at start-up.
settings = ExperimentSettings('')
settings.world_size = 5
settings.start_num_creatures = 0
settings.layers = [12]  # presumably the hidden-layer sizes of the network - TODO confirm

# Clear the class-level agent before constructing the cow.
# NOTE(review): presumably this makes DeepCow build a fresh agent - confirm.
DeepCow.agent = None
deepcow = DeepCow(x=2, y=2, energy=100, settings=settings)
world = World(settings, MagicMock())

In [3]:
def training_record(world, cow, grass_fraction=0.25, water_fraction=0.10):
    """Produce one ``(state, action_idx, reward, next_state)`` tuple.

    The world is reset with the requested grass/water fractions, the cow is
    put at x=2, y=2 and registered with the world, and exactly one action is
    taken and processed.

    Returns:
        ``cow.state`` and ``cow.action_idx`` as read after the step
        (presumably set by ``cow.step`` - TODO confirm), the reward reported
        by ``world.process_action``, and the cow's internal state for the
        following observation, or ``None`` when the step was reported as done.
    """
    world.reset(MagicMock(), water_fraction=water_fraction, grass_fraction=grass_fraction)
    cow.x, cow.y = 2, 2
    world.add_new_creature(cow)

    chosen_action = cow.step(world.get_observation(cow))
    _new_creature, reward, done = world.process_action(cow, chosen_action)

    # A follow-up observation is only requested while the episode is not done.
    next_state = None if done else cow.to_internal_state(world.get_observation(cow))
    return cow.state, cow.action_idx, reward, next_state

# Generate a single record and display it as a quick sanity check of the tuple layout.
training_record(world, deepcow)

(array([1., 0., 1., 0., 0., 0., 1., 1., 0., 1., 0., 1., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.]),
 3,
 0.8666666666666667,
 array([1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0.,
        0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
        0., 0.]))

In [12]:
def run_scenario(scenario, world, cow):
    """Rank the agent's actions for a hand-written text scene.

    Args:
        scenario: text description of the map, converted with
            ``MapFeature.text_scene_to_environment``.
        world: world whose ``cells`` are replaced by the converted scene.
        cow: creature placed at x=2, y=2; its ``agent`` supplies the values.

    Returns:
        A list of ``(Action, predicted value)`` pairs ordered from the lowest
        to the highest predicted value, so the preferred action comes last.
    """
    world.cells = MapFeature.text_scene_to_environment(scenario)
    cow.x = 2
    cow.y = 2
    # NOTE(review): unlike training_record, the world is not reset here, so the
    # cow is registered again on every call - confirm this is harmless.
    world.add_new_creature(cow)
    state = cow.to_internal_state(world.get_observation(cow))
    action_values = cow.agent.predict(state)[0]
    # Action is looked up with index + 1; presumably its values are 1-based.
    return [(Action(index + 1), action_values[index])
            for index in np.argsort(action_values)]

# Scene: a single '~' cell in the last row, middle column; every other cell is '.'.
# run_scenario places the cow at x=2, y=2 and returns the actions ordered by predicted value.
run_scenario('.....\n'
             '.....\n'
             '.....\n'
             '.....\n'
             '..~..\n',
             world, deepcow
            )

[(DOWN, -0.08910842),
 (LEFT, -0.088489294),
 (UP, -0.07259474),
 (RIGHT, -0.07130667)]

In [6]:
def score(world, deepcow, episodes=1000, death_penalty=10):
    """Average single-step outcome of ``deepcow`` over freshly generated records.

    Each episode calls ``training_record`` once. A record whose next state is
    ``None`` (the step was reported as done) costs ``death_penalty``;
    otherwise the record's reward is added.

    Args:
        world: world passed through to ``training_record``.
        deepcow: cow passed through to ``training_record``.
        episodes: number of single-step records to average over.
        death_penalty: amount subtracted for every record that ended the episode.

    Returns:
        The accumulated total divided by ``episodes``.

    Raises:
        ValueError: if ``episodes`` is not positive.
    """
    if episodes <= 0:
        raise ValueError('episodes must be a positive integer')
    total = 0
    for _ in range(episodes):
        _state, _action, reward, next_state = training_record(world, deepcow)
        if next_state is None:
            total -= death_penalty
        else:
            total += reward
    return total / episodes

# Score the cow at this point in the notebook; in cell order this precedes the replay training.
score(world, deepcow)

-1.8710000000000429

In [22]:
# Collect 200000 single-step records, each generated from a freshly reset world,
# then display the first one to check the tuple layout.
records = [training_record(world, deepcow) for _ in range(200000)]
records[0]

(array([0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
 0,
 0.8666666666666667,
 array([0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))

In [23]:
# Hand every collected record to the cow's agent; each record is unpacked into
# remember(state, action_idx, reward, next_state).
agent = deepcow.agent
for rec in records:
    agent.remember(*rec)

In [24]:
# Train for 10 epochs; each epoch runs as many replay calls as there are
# whole batches in `records`.
batches_per_epoch = len(records) // agent.batch_size
for epoch in range(10):
    # Defined up front so the print below works even with zero batches.
    loss = None
    # A named throwaway variable is used instead of `__`, which IPython
    # reserves for its output history.
    for _batch in range(batches_per_epoch):
        loss = agent.replay()
    # Only the loss of the final replay call in the epoch is reported.
    print(epoch, loss)

0 0.1081507563086537
1 0.04180942660896109
2 0.03738071905293813
3 0.05667696514477333
4 0.06424152853433043
5 0.07367683884222062
6 0.03925972802098843
7 0.07053790832093607
8 0.04638871860612806
9 0.055363988130314595


In [10]:
# Re-score the cow; in cell order this follows the replay training.
# NOTE(review): the recorded execution counts are out of order, so the saved
# output may not come from the trained agent - re-run all cells to confirm.
score(world, deepcow)

0.613666666666672

In [25]:
# Scene: a single '~' cell in the fourth row, middle column; every other cell is '.'.
run_scenario('.....\n'
             '.....\n'
             '.....\n'
             '..~..\n'
             '.....\n',
             world, deepcow
            )

[(DOWN, -0.09832768),
 (RIGHT, -0.07484319),
 (UP, -0.06951808),
 (LEFT, 0.0024407804)]

In [26]:
# Scene: a single '#' cell in the last row, middle column; every other cell is '.'.
run_scenario('.....\n'
             '.....\n'
             '.....\n'
             '.....\n'
             '..#..\n',
             world, deepcow
            )

[(UP, -0.08742268),
 (RIGHT, -0.03684604),
 (LEFT, -0.01104486),
 (DOWN, 0.11615394)]

In [27]:
# Scene: a single '#' cell at the start of the first row; every other cell is '.'.
run_scenario('#....\n'
             '.....\n'
             '.....\n'
             '.....\n'
             '.....\n',
             world, deepcow
            )

[(UP, -0.07630407),
 (RIGHT, -0.044064693),
 (DOWN, -0.0111687705),
 (LEFT, 0.015386969)]

In [29]:
# Scene: the middle row starts with '#' followed by '~'; every other cell is '.'.
run_scenario('.....\n'
             '.....\n'
             '#~...\n'
             '.....\n'
             '.....\n',
             world, deepcow
            )

[(LEFT, -0.16135077),
 (UP, -0.06531254),
 (RIGHT, -0.03798972),
 (DOWN, 0.05389011)]

In [30]:
# Scene: the second row starts with '#' followed by '~'; every other cell is '.'.
run_scenario('.....\n'
             '#~...\n'
             '.....\n'
             '.....\n'
             '.....\n',
             world, deepcow
            )

[(RIGHT, -0.056711726),
 (UP, -0.05226101),
 (DOWN, -0.04222539),
 (LEFT, 0.04077822)]

In [31]:
# Scene: the second row has '#' in its second column and '~' in its middle column;
# every other cell is '.'.
run_scenario('.....\n'
             '.#~..\n'
             '.....\n'
             '.....\n'
             '.....\n',
             world, deepcow
            )

[(RIGHT, -0.055134237),
 (DOWN, -0.04744616),
 (UP, 0.08114942),
 (LEFT, 0.12565738)]