In [1]:
import numpy as np
from skimage.transform import resize
import gym
import gym.spaces
import gym_oculoenv

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.autograd import Variable
from torch.distributions import Categorical
import torch.optim as optim

from esncell import ESNCell
import utils

import matplotlib.pyplot as plt
%matplotlib inline

from itertools import count
from collections import namedtuple

In [2]:
# Dataset params
batch_size = 1
n_samples = 40
sample_length = 1000

# Size of one observation feature vector and number of discrete actions
n_action = 9
n_obs = 2048
oh_mat = np.identity(n_action)  # row i is the one-hot encoding of action i

# ESN properties
w_sparsity = 0.1
n_hidden = 200
input_dim = n_action + n_obs  # reservoir input: observation features + one-hot action

n_iterations = 50

# Per-step record kept for the policy update: the log-probability of the
# sampled action and the critic's value estimate for that step.
SavedAction = namedtuple("SavedAction", ("log_prob", "value"))

In [3]:
class ActorCritic(nn.Module):
    """Actor-critic network with a separate hidden layer per branch.

    Both branches read the same input. The actor branch produces a
    probability distribution over ``n_action`` actions; the critic branch
    produces one scalar value per input row. The instance also owns two
    plain lists, ``saved_actions`` and ``rewards``, used as per-step buffers.

    Args:
        n_action: number of outputs of the actor head.
        input_dim: size of the input feature vector.
        hidden_dim: width of each branch's hidden layer.
    """

    def __init__(self, n_action, input_dim, hidden_dim=2000):
        super().__init__()
        # Layers are declared in this order on purpose: changing it would
        # change which random numbers initialise which layer under a fixed seed.
        self.affine_a = nn.Linear(in_features=input_dim, out_features=hidden_dim)  # actor hidden layer
        self.affine_v = nn.Linear(in_features=input_dim, out_features=hidden_dim)  # critic hidden layer
        self.action_head = nn.Linear(in_features=hidden_dim, out_features=n_action)
        self.value_head = nn.Linear(in_features=hidden_dim, out_features=1)
        # Buffers start empty.
        self.saved_actions = []
        self.rewards = []

    def forward(self, x):
        """Return ``(action_probabilities, state_values)`` for input ``x``."""
        policy_logits = self.action_head(F.relu(self.affine_a(x)))
        state_values = self.value_head(F.relu(self.affine_v(x)))
        action_probs = F.softmax(policy_logits, dim=-1)
        return action_probs, state_values

In [4]:
# Pretrained ResNet-50 with its last child module dropped, used purely as an
# image feature extractor (its flattened output feeds the reservoir).
resnet = models.resnet50(pretrained=True)
resnet = nn.Sequential(*list(resnet.children())[:-1])
# The extractor is never optimised (only ac_model's parameters are handed to
# Adam below), so switch it to inference behaviour:
#  - eval() makes BatchNorm use its stored running statistics instead of the
#    statistics of the single frame in the batch, and stops it from updating
#    those running statistics on every forward pass;
#  - freezing the weights stops autograd from recording a graph through the
#    whole backbone for every frame.
resnet.eval()
for param in resnet.parameters():
    param.requires_grad_(False)

# Echo-state reservoir fed with the image features concatenated with the
# one-hot previous action (input_dim = n_obs + n_action).
resevior = ESNCell(input_dim, n_hidden, batch_size, spectral_radius=0.5, input_scaling=0.3, w_sparsity=w_sparsity)
# Actor-critic head whose input is the reservoir state (n_hidden features).
ac_model = ActorCritic(n_action, n_hidden)

optimizer = optim.Adam(ac_model.parameters(), lr=1e-3)
# float32 machine epsilon; added to the std in finish_episode so the return
# normalisation never divides by exactly zero.
eps = np.finfo(np.float32).eps.item()

In [5]:
def select_action(state):
    """Sample an action from the current policy for one state.

    The state is converted to a float tensor and passed through ``ac_model``.
    An action index is drawn from the categorical distribution defined by the
    returned probabilities. The log-probability of that action and the
    critic's value are appended to ``ac_model.saved_actions``.

    Args:
        state: numpy array holding the state features.

    Returns:
        The sampled action index as a Python int.
    """
    state_tensor = torch.from_numpy(state).float().requires_grad_(requires_grad=True)
    action_probs, value_estimate = ac_model(state_tensor)
    policy = Categorical(action_probs)
    chosen = policy.sample()
    record = SavedAction(log_prob=policy.log_prob(chosen), value=value_estimate)
    ac_model.saved_actions.append(record)
    return chosen.item()


def finish_episode(gamma=0.99):
    """Run one actor-critic update on the buffered steps, then clear the buffers.

    Reads ``ac_model.rewards`` and ``ac_model.saved_actions`` and computes the
    discounted return of every step. The policy loss is ``-log_prob *
    (return - value)`` and the value loss is the smooth-L1 distance between
    the critic's value and the return. One ``optimizer`` step is taken on the
    summed loss. The loss and the sum of raw rewards are printed, and both
    buffers are emptied in place.

    Args:
        gamma: discount factor applied when accumulating returns.

    Returns:
        None.
    """
    # Nothing buffered means nothing to learn from (and torch.stack would
    # raise on an empty list), so leave the model untouched.
    if not ac_model.rewards:
        return

    # Discounted return per step, accumulated backwards from the last reward.
    # Appending and reversing once avoids the quadratic cost of insert(0, ...).
    returns = []
    running_return = 0.0
    for step_reward in reversed(ac_model.rewards):
        running_return = step_reward + gamma * running_return
        returns.append(running_return)
    returns.reverse()
    # Pin the dtype so the targets match the float32 critic output whatever
    # numeric type the environment uses for rewards.
    returns = torch.tensor(returns, dtype=torch.float32)
    # Standardise only when there are at least two returns: the sample std of
    # a single value is NaN, which would turn the loss and then every model
    # weight into NaN.
    if returns.numel() > 1:
        returns = (returns - returns.mean()) / (returns.std() + eps)

    policy_losses = []
    value_losses = []
    for (log_prob, value), step_return in zip(ac_model.saved_actions, returns):
        # value.item() detaches the baseline so the policy term does not
        # back-propagate into the critic.
        advantage = step_return - value.item()
        policy_losses.append(-log_prob * advantage)
        value_losses.append(F.smooth_l1_loss(value.view(1), step_return.view(1)))

    optimizer.zero_grad()
    loss = torch.stack(policy_losses).sum() + torch.stack(value_losses).sum()
    loss.backward()
    optimizer.step()
    print(loss.item(), np.sum(ac_model.rewards))
    del ac_model.rewards[:]
    del ac_model.saved_actions[:]

In [6]:
# Build the environment by its gym id (presumably registered by the
# gym_oculoenv import at the top of the notebook - confirm).
env = gym.make("RedCursor-v0")
# Initial observation; the training loop below resets the environment again
# at the start of every episode.
obs = env.reset()

NoneType 0.05817764173314431 9 False False


In [None]:
# Training loop: ResNet features + previous action -> reservoir state ->
# actor-critic policy, with a policy/value update every `update_every` steps.
running_reward = 10
max_episodes = 10000
max_steps = 200     # Don't infinite loop while learning
update_every = 20   # number of environment steps between two updates

# The reservoir state is created once and carried over from one episode to
# the next (it is not re-initialised on env.reset()).
hidden_x = resevior.init_hidden(batch_size)
for i_episode in range(max_episodes):
    obs = env.reset()
    # One-hot encoding of the previous action; all zeros before the first step.
    action_arr = np.zeros((1, n_action))
    for t in range(max_steps):
        # Divide by 255 (assumes 8-bit pixel values - confirm against the env),
        # resize to the 224x224 ResNet input and move channels first
        # (assumes the env returns height x width x channel frames).
        frame = resize(obs / 255., (224, 224))
        frame = torch.tensor(np.expand_dims(frame.transpose((2, 0, 1)), axis=0).astype("float32"))
        prev_action = torch.tensor(action_arr.astype("float32"))
        # Neither the backbone nor the reservoir is trained here: the state is
        # detached before it reaches the policy. Running them without autograd
        # keeps the recurrent hidden_x from dragging an ever-growing graph
        # through every past frame.
        with torch.no_grad():
            features = resnet(frame).view(1, -1)
            hidden_x = resevior(torch.cat([features, prev_action], dim=-1), hidden_x=hidden_x)
        state = hidden_x.detach().numpy().reshape(1, -1)
        action = select_action(state)
        action_arr = np.expand_dims(oh_mat[action], axis=0)
        # Step the environment exactly once per selected action. A second call
        # would advance the simulation twice and silently drop one reward.
        obs, reward, done, _ = env.step(action)
        ac_model.rewards.append(reward)
        if done or (t + 1) % update_every == 0:
            # NOTE(review): this average is refreshed on every update, not once
            # per episode, so it mixes intermediate step counts.
            running_reward = running_reward * 0.99 + t * 0.01
            finish_episode()
            if done:
                break
    # Reached exactly once per episode, whether it ended via `done` or by
    # exhausting max_steps.
    print('Episode {}\tLast length: {:5d}\tAverage length: {:.2f}'.format(i_episode, t, running_reward))

5.901492595672607 15.71751820531188
75.90277099609375 11.676625603246116
4.729850769042969 12.26033076200861
9.254830360412598 9.256434719774237
16.24832534790039 7.577581708653683
-0.8216276168823242 7.198015772708286
9.305319786071777 5.679752028926697
12.164582252502441 5.679752028926697
13.040145874023438 5.679752028926697
11.194570541381836 5.679752028926697
Episode 0	Last length:   199	Average length: 19.62
9.128890037536621 14.014790373931564
9.151163101196289 13.193009343192527
10.161458015441895 13.193009343192527
10.600664138793945 13.193009343192527
10.031092643737793