In [1]:
%load_ext autoreload
%autoreload 2

In [5]:
import gym
import torch
import torch.nn as nn
import numpy as np
from collections import deque
import random
from itertools import count
import torch.nn.functional as F
from tensorboardX import SummaryWriter
import mario_env

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class QNetwork(nn.Module):
    def __init__(self):
        super(QNetwork, self).__init__()

        self.fc1 = nn.Linear(187, 64*4)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.fc_value = nn.Linear(64*4, 256*4)
        self.fc_adv = nn.Linear(64*4, 256*4)

        self.value = nn.Linear(256*4, 1)
        self.adv = nn.Linear(1024, 17)

    def forward(self, state):
        state = torch.flatten(state)
        y = self.relu(self.fc1(state))
        
        value = self.relu(self.fc_value(y))
        adv = self.relu(self.fc_adv(y))

        value = self.relu(self.value(value))
        adv = self.relu(self.adv(adv))

        Q = self.sigmoid(value + adv)
        return Q

    def select_action(self, state):
        state = torch.flatten(state)
        with torch.no_grad():
            Q = self.forward(state)
            action_index = Q
        return action_index


class Memory(object):
    def __init__(self, memory_size: int) -> None:
        self.memory_size = memory_size
        self.buffer = deque(maxlen=self.memory_size)

    def add(self, experience) -> None:
        self.buffer.append(experience)

    def size(self):
        return len(self.buffer)

    def sample(self, batch_size: int, continuous: bool = True):
        if batch_size > len(self.buffer):
            batch_size = len(self.buffer)
        if continuous:
            rand = random.randint(0, len(self.buffer) - batch_size)
            return [self.buffer[i] for i in range(rand, rand + batch_size)]
        else:
            indexes = np.random.choice(np.arange(len(self.buffer)), size=batch_size, replace=False)
            return [self.buffer[i] for i in indexes]

    def clear(self):
        self.buffer.clear()


env = gym.make('MarioEnv-v0')
n_state = env.observation_space.shape[0]
n_action = 17

onlineQNetwork = QNetwork().to(device)
targetQNetwork = QNetwork().to(device)
targetQNetwork.load_state_dict(onlineQNetwork.state_dict())

optimizer = torch.optim.Adam(onlineQNetwork.parameters(), lr=1e-2)

GAMMA = 0.99
EXPLORE = 20000
INITIAL_EPSILON = 0.1
FINAL_EPSILON = 0.0001
REPLAY_MEMORY = 50000
BATCH = 16

UPDATE_STEPS = 4

memory_replay = Memory(REPLAY_MEMORY)

epsilon = INITIAL_EPSILON
learn_steps = 0
writer = SummaryWriter('logs/ddqn')
begin_learn = False

episode_reward = 0

# onlineQNetwork.load_state_dict(torch.load('ddqn-policy.para'))
for epoch in count():

    state = env.reset()
    episode_reward = 0
    for time_steps in range(200):
        p = random.random()
        if p < epsilon:
            action = [random.randint(0, 1) for _ in range(17)] 
        else:
            tensor_state = torch.FloatTensor(state).unsqueeze(0).to(device)
            action = onlineQNetwork.select_action(tensor_state)
        next_state, reward, done, _ = env.step(action)
        episode_reward += reward
        memory_replay.add((state, next_state, action, reward, done))
        if memory_replay.size() > 50:
            if begin_learn is False:
                print('learn begin!')
                begin_learn = True
            learn_steps += 1
            if learn_steps % UPDATE_STEPS == 0:
                targetQNetwork.load_state_dict(onlineQNetwork.state_dict())
            batch = memory_replay.sample(50, False)
            
            batch_state, batch_next_state, batch_action, batch_reward, batch_done = zip(*batch)
            batch_state = torch.FloatTensor(batch_state).to(device)
            batch_next_state = torch.FloatTensor(batch_next_state).to(device)
            batch_action = torch.FloatTensor(batch_action).to(device)
            batch_reward = torch.FloatTensor(batch_reward).to(device)
            batch_done = torch.FloatTensor(batch_done).to(device)

            with torch.no_grad():
                onlineQ_next = onlineQNetwork(batch_next_state)
                targetQ_next = targetQNetwork(batch_next_state)
                online_max_action = torch.argmax(onlineQ_next, dim=1, keepdim=True)
                y = batch_reward + (1 - batch_done) * GAMMA * targetQ_next.gather(1, online_max_action.long())

            loss = F.mse_loss(onlineQNetwork(batch_state).gather(1, batch_action.long()), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            writer.add_scalar('loss', loss.item(), global_step=learn_steps)

            if epsilon > FINAL_EPSILON:
                epsilon -= (INITIAL_EPSILON - FINAL_EPSILON) / EXPLORE

                
        if done:
            break
        state = next_state

    writer.add_scalar('episode reward', episode_reward, global_step=epoch)
    if epoch % 10 == 0:
        torch.save(onlineQNetwork.state_dict(), 'ddqn-policy.para')
        print('Ep {}\tMoving average score: {:.2f}\t'.format(epoch, episode_reward))




called reset
[Errno 17] File exists
/Users/jackboynton/Library/Application Support/Dolphin/Pipes/p3
buffering p3 input fifo...
[Errno 17] File exists
/Users/jackboynton/Library/Application Support/Dolphin/Pipes/p4
buffering p3 input fifo...
sent reload state
48956.31777907968
25711.72894864715
7140.736884751532
Ep 0	Moving average score: 81808.78	
called reset
sent reload state
509182.62880525267
992687.0591096869
1290364.1046565024
1741770.0070329497
2106250.790523807
2639335.8070632736
3164795.1749279206
2443734.753748022
called reset
sent reload state
-3213319.8987795184
-904023.7390911662
-3697939.8809545822
-1567760.0745115918
-2053360.7990113455
-2618540.5417952957
-4368544.63153979
-1845012.8023758696
called reset
sent reload state
-1212505.7796918638
-820582.9906583775
-273181.22757419985
67604.61391768038
252693.88993090484
173639.50570949673
53578.358562049
-321535.8186914494
called reset
sent reload state
-3001711.534293363
-3536902.4685508753
-3881757.0224492145
called rese

ValueError: only one element tensors can be converted to Python scalars