In [1]:
import gym
import collections
from tensorboardX import SummaryWriter
import torch
import os
import numpy as np
import torch.nn as nn
import torch.optim as optim

In [17]:
# --- Experiment configuration -------------------------------------------------
# Gym environment id handed to gym.make().
ENV_NAME = "FrozenLake-v1"

# Width of the hidden layer used when the networks are built.
n_neurons = 64

# Optimizer settings: ALPHA is passed as the Adam learning rate; GAMMA is the
# factor applied to the bootstrapped next-state value in the TD target.
ALPHA = 0.1
GAMMA = 0.9

In [3]:
class ObservationWrapper(gym.ObservationWrapper):
    """Turn an index observation into a one-hot ``torch`` vector.

    The vector length is taken from ``env.observation_space.n`` of the wrapped
    environment at construction time.
    """

    def __init__(self, env):
        super().__init__(env)
        # Length of the one-hot vector produced by observation().
        self.obs_length = env.observation_space.n

    def observation(self, obs):
        """Return a zero tensor of length ``obs_length`` with a 1 at ``obs``."""
        one_hot = torch.zeros(self.obs_length)
        one_hot[obs] = 1
        return one_hot

In [4]:
class net(nn.Module):
    """Two-layer perceptron whose output is a softmax distribution.

    Architecture: ``Linear(in_dim, n_neurons) -> ReLU -> Linear(n_neurons,
    out_dim) -> Softmax``.

    Args:
        in_dim: size of the input feature vector.
        n_neurons: width of the hidden layer.
        out_dim: number of output units.

    Note:
        Because the output passes through a softmax, ``out_dim=1`` always
        yields the constant ``1.0``; use a plain linear head when a scalar
        regression output is needed.
    """

    def __init__(self, in_dim, n_neurons, out_dim):
        super().__init__()

        self.fc1 = nn.Linear(in_dim, n_neurons)
        self.fc2 = nn.Linear(n_neurons, out_dim)
        self.act = nn.ReLU()
        # Normalize over the last (feature) axis. For a 1-D input this is the
        # same as dim=0, but a batched (N, in_dim) input now gets one
        # distribution per row instead of being normalized across the batch.
        self.soft = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return softmax probabilities over the last axis of ``x``'s logits."""
        hidden = self.act(self.fc1(x))
        return self.soft(self.fc2(hidden))

In [None]:
class full_net(nn.Module):
    """Container module holding two sub-networks (work in progress).

    Args:
        player: sub-module stored as ``self.player`` (presumably the policy
            network -- TODO confirm).
        q_value: sub-module stored as ``self.q_value`` (presumably the
            action-value network -- TODO confirm).
    """

    def __init__(self, player, q_value):
        # nn.Module must be initialised before sub-modules can be assigned,
        # otherwise the attribute assignments below raise.
        super().__init__()
        self.player = player
        self.q_value = q_value

    def forward(self, x):
        """Not implemented yet; the combination of the two networks is undecided.

        Raises:
            NotImplementedError: always.
        """
        # TODO: define how `player` and `q_value` are combined.
        raise NotImplementedError("full_net.forward is not implemented yet")

In [15]:
class Agent_NN():
    """Agent holding a policy network and a Q-value network for ``ENV_NAME``.

    Attributes:
        env: one-hot wrapped training environment.
        net / net_optimizer: policy network and its Adam optimizer. Nothing in
            this class trains the policy network yet.
        values / values_optimizer: Q-value network and its Adam optimizer. Its
            input is the one-hot state concatenated with ``action / n_actions``.
        state: current observation of the training environment.
        n_actions: size of the environment's action space.
    """

    def __init__(self):
        self.env = ObservationWrapper(gym.make(ENV_NAME))
        n_states = self.env.observation_space.n
        self.n_actions = self.env.action_space.n

        self.net = net(n_states, n_neurons, self.n_actions)
        self.net_optimizer = optim.Adam(self.net.parameters(), lr=ALPHA)
        self.state = self.env.reset()

        # The Q-value head must be linear: a softmax over a single output unit
        # is constantly 1.0 and therefore cannot be trained.
        self.values = nn.Sequential(
            nn.Linear(n_states + 1, n_neurons),
            nn.ReLU(),
            nn.Linear(n_neurons, 1),
        )
        self.values_optimizer = optim.Adam(self.values.parameters(), lr=ALPHA)

    def play_episode(self):
        """Take ONE greedy step in the training environment.

        Despite its name this advances the environment by a single transition.
        ``self.state`` is moved to the next observation, or to a fresh
        ``reset()`` observation when the episode has ended.

        Returns:
            Tuple ``(state, action, reward, next_state, next_action)`` where
            ``state`` is the observation the action was chosen in and
            ``next_action`` is the greedy action in ``next_state``.
        """
        state = self.state
        # Action selection is pure inference; no autograd graph is needed.
        with torch.no_grad():
            action = torch.argmax(self.net(state))
            next_state, reward, is_done, _ = self.env.step(action.item())
            next_action = torch.argmax(self.net(next_state))

        self.state = self.env.reset() if is_done else next_state

        return state, action, reward, next_state, next_action

    def calculate_loss(self, preds, targets):
        """Element-wise squared error between ``preds`` and ``targets``."""
        return torch.square(preds - targets)

    def update_params(self, s, a, r, n_s, n_a):
        """Run one optimizer step of the Q-value network on a transition.

        Args:
            s: state tensor the action was taken in.
            a: action as a 0-d integer tensor.
            r: reward received for the transition.
            n_s: next state (currently unused, see the TODO below).
            n_a: greedy action in ``n_s`` (currently unused).

        Returns:
            The detached squared-error loss tensor.
        """
        # Encode the action as a single feature scaled by the action count.
        action_feature = a.detach().reshape(1) / self.n_actions
        q_value = self.values(torch.cat((s, action_feature)))

        # TODO: the bootstrapped term GAMMA * Q(n_s, n_a) is intentionally
        # disabled for now, so the target is the immediate reward only.
        target_q_value = r

        loss = self.calculate_loss(q_value, target_q_value)

        self.values_optimizer.zero_grad()
        loss.backward()
        self.values_optimizer.step()

        # Detach so callers accumulating/logging the loss do not keep the graph.
        return loss.detach()

    def play_test(self, env):
        """Play one full greedy episode in ``env`` and return its total reward."""
        total_reward = 0.0
        state = env.reset()
        with torch.no_grad():
            while True:
                action = torch.argmax(self.net(state))
                state, reward, is_done, _ = env.step(action.item())
                total_reward += reward
                if is_done:
                    break
        return total_reward

In [16]:
# --- Training / evaluation loop -----------------------------------------------
TEST_EPISODES = 20      # greedy evaluation episodes averaged per iteration
# Upper bound so the cell terminates under Restart & Run All even if the agent
# never reaches the target reward.
MAX_ITERATIONS = 10000

test_env = ObservationWrapper(gym.make(ENV_NAME))
agent = Agent_NN()
writer = SummaryWriter(comment="my-q-learning")
iter_no = 0
best_reward = 0.0

try:
    while iter_no < MAX_ITERATIONS:
        iter_no += 1

        # One environment step followed by one Q-network update.
        s, a, r, next_s, next_a = agent.play_episode()
        # float() so the logged value is a plain number, not a tensor.
        loss = float(agent.update_params(s, a, r, next_s, next_a))

        # Average greedy return over TEST_EPISODES evaluation episodes.
        reward = 0.0
        for _ in range(TEST_EPISODES):
            reward += agent.play_test(test_env)
        reward /= TEST_EPISODES

        writer.add_scalar("reward", reward, iter_no)
        writer.add_scalar("loss", loss, iter_no)

        if reward > best_reward:
            print("Best reward updated %.3f -> %.3f" % (best_reward, reward))
            best_reward = reward
        if reward > 0.80:
            print("Solved in %d iterations!" % iter_no)
            break
    else:
        print("Not solved after %d iterations (best reward %.3f)" % (iter_no, best_reward))
finally:
    # Flush and close the event file even on KeyboardInterrupt or an error.
    writer.close()

KeyboardInterrupt: 