In [1]:
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import time
import pickle
import cv2
from matplotlib import style

In [2]:
style.use("ggplot")

In [3]:
SIZE = 10
EPISODES = 25000
MOVE_PENALTY = 1
ENEMY_PENALTY = 300     #hitting enemy
FOOD_REWARD = 25
DECAY_RATE = 0.998
LEARNING_RATE = 0.1
DISCOUNT = 0.95

epsilon = 0.9

q_table = None

PLAYER_N = 1
FOOD_N = 2
ENEMY_N = 3

d = {
    1: (255, 175, 0),
    2: (0, 255, 0),
    3: (0, 0, 255)
}                  # color of player_n, food_n, enemy_n 

In [4]:
class Blob:
    def __init__(self):
        # blob's location
        self.x = np.random.randint(0, SIZE)
        self.y = np.random.randint(0, SIZE)
    
    def __str__(self):
        return f"{self.x}, {self.y}"
    
    def __sub__(self, other):
        return (self.x - other.x, self.y - other.y)
    
    def __eq__(self, other):
        return (self.x == other.x and self.y == other.y)
    
    def action(self, choice):
        if choice == 0:
            self.move(x=1, y=1)
        elif choice == 1:
            self.move(x=-1, y=-1)
        elif choice == 2:
            self.move(x=-1, y=1)
        elif choice == 3:
            self.move(x=1, y=-1)

    def move(self, x=False, y=False):
        if not x:
            self.x = np.random.randint(-1, 2)
        else:
            self.x += x

        if not y:
            self.y = np.random.randint(-1, 2)
        else:
            self.y += y

        self.x = max(0, min(self.x, SIZE-1))
        self.y = max(0, min(self.y, SIZE-1))

In [5]:
# state space => ((x1, y1),(x2, y2)) 1st coords -> delta of player and food. 2nd coords delta of player and enemy

start_q_table = '<filename>'

if q_table is None:
    q_table = {}
    for x1 in range(-SIZE+1, SIZE):
        for y1 in range(-SIZE+1, SIZE):
            for x2 in range(-SIZE+1, SIZE):
                for y2 in range(-SIZE+1, SIZE):
                    q_table[((x1,y1),(x2, y2))] = [np.random.uniform(-3,0) for _ in range(4)]

else:
    with open(start_q_table, 'rb') as f:
        q_table = pickle.load(f)

In [6]:
q_table[(-9,-9), (-9,-9)]   

[-0.8812321119037074,
 -2.558076760866161,
 -2.2477099604133346,
 -1.4304378213234172]

In [7]:
episode_rewards = []

for episode in range(EPISODES):
    player = Blob()
    food = Blob()
    enemy = Blob()

    if episode % 3000 == 0:
        print(f'on {episode}, {epsilon}')
        print(f"Episode mean {np.mean(episode_rewards[-3000:])}")
        show = True
    else:
        show = False

    rewards_per_episode = 0
    for i in range(200):
        rewards = 0
        state = (player-food, player-enemy)

        if np.random.random() > epsilon:
            action = np.argmax(q_table[state])
        else:
            action = np.random.randint(0,4)

        player.action(action)

        # TODO Train enemy to catch player

        if player == enemy:
            rewards = -ENEMY_PENALTY
        elif player == food:
            rewards = FOOD_REWARD
        else:
            rewards = -MOVE_PENALTY

        new_state = (player-food, player-enemy)

        q_table[state][action] = (1-LEARNING_RATE)*q_table[state][action] + LEARNING_RATE*(rewards + DISCOUNT*np.max(q_table[new_state]))

        if show:
            env = np.zeros((SIZE,SIZE, 3), dtype=np.uint8)
            env[food.y][food.x] = d[FOOD_N]
            env[player.y][player.x] = d[PLAYER_N]
            env[enemy.y][enemy.x] = d[ENEMY_N]

            img = Image.fromarray(env, "RGB")
            img = img.resize((300,300))
            cv2.imshow("",np.array(img))
            if rewards == FOOD_REWARD or rewards == -ENEMY_PENALTY:
                if cv2.waitKey(500) & 0xFF == ord('q'):
                    break
            else:
                if cv2.waitKey(100) & 0xFF == ord('q'):
                    break
        
        rewards_per_episode += rewards
        if rewards == FOOD_REWARD or rewards == -ENEMY_PENALTY:
            break
        
    episode_rewards.append(rewards_per_episode)
    epsilon *= DECAY_RATE

cv2.destroyAllWindows()

on 0, 0.9
Episode mean nan


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


on 3000, 0.0022175140060036905
Episode mean -156.94
on 6000, 5.463742629802828e-06
Episode mean -93.77566666666667
on 9000, 1.3462139785319096e-08
Episode mean -69.18766666666667
on 12000, 3.316942613858276e-11
Episode mean -54.458666666666666
on 15000, 8.172629670379144e-14
Episode mean -41.25033333333333
on 18000, 2.0136578622163367e-16
Episode mean -36.461
on 21000, 4.96146056973797e-19
Episode mean -31.386333333333333
on 24000, 1.222456478180995e-21
Episode mean -24.607333333333333


In [9]:
with open('q-table.txt', 'wb') as f:
    pickle.dump(q_table, f)