In [4]:
import random
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
from tqdm.notebook import tqdm
from tensorboardX import SummaryWriter

import os

# Anomaly detection is a debugging aid: it makes autograd keep forward-pass
# tracebacks and check backward results for NaNs, which slows training down.
# It stays enabled by default (the notebook's original behaviour); export
# TORCH_DETECT_ANOMALY=0 before starting the kernel to turn it off for long runs.
torch.autograd.set_detect_anomaly(os.environ.get('TORCH_DETECT_ANOMALY', '1') != '0')

# Directory containing `log/` and `ckpts/`; later cells use paths relative to it.
# The original absolute path remains the default so existing setups keep working;
# set LUNARLANDER_WORKDIR to run the notebook from another checkout or machine.
work_dir = os.environ.get('LUNARLANDER_WORKDIR', '/home/sast/fastmoe/pjj/LunarLander')
os.chdir(work_dir)
os.listdir()

['log', 'main.ipynb', '.ipynb_checkpoints', 'ckpts']

In [5]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda


In [6]:
import gym
import random
import numpy as np

from gym import envs

# Single shared environment instance; `random_play` and the training loop both
# read this module-level `env`.
# NOTE(review): the rest of the notebook unpacks `env.step(...)` into four values
# and uses the return value of `env.reset()` directly as the state, so it assumes
# the classic gym step/reset API — confirm the installed gym version matches.
env = gym.make('LunarLander-v2')

In [7]:
def random_play():
    """Play one episode on the global `env` using randomly sampled actions.

    Actions come from `env.action_space.sample()`.  The reward of every step is
    printed on a single line, each followed by ', '.  Nothing is returned.
    """
    env.reset()
    while True:
        sampled_action = env.action_space.sample()
        _obs, step_reward, finished, _info = env.step(sampled_action)
        print(step_reward, end=', ')
        if finished:
            break

# Smoke test: run one episode with random actions and print the per-step rewards.
random_play()

-2.1996377530852826, -1.2190360744776, -0.5245707082581987, 0.7000712750587297, -2.6062535786834347, -1.0229027563111857, -1.4824049776621564, -2.939953131173836, -0.7239307312343317, -3.013622261728726, -3.087152221770394, -3.13140120860035, -1.233369192447724, -3.2972080960410026, -1.4601189048129573, -2.207291101309636, -2.2185334149798166, -0.9462362551734589, -0.9536187630502855, 0.12942814781554263, -0.7937724776190567, -0.7802081004454056, -0.23920378667073122, -1.7470737572213466, -1.4906333134184138, -1.4983004503601478, -2.4040297416203473, -1.6324714169287518, -0.891971870791781, -1.647245086072246, -1.6453519224468494, -1.6414145706783643, -0.4677050494495336, -0.3196938135328924, -2.517644158573914, -1.4228919537009688, -2.4122204974366626, -1.5822578843897759, -2.8181230974840603, -0.8311187236660931, -2.8439119716915955, -1.8090141523564398, -0.790916191043084, -0.8089407467623733, -2.4859624747395217, -0.7230769650081175, -0.7624650515002418, -0.8528194096769426, -1.425

In [14]:
class QNetwork(nn.Module):
    """Fully connected network mapping a state vector to one value per action.

    The layout is ``Linear -> SELU`` for every hidden width, followed by a final
    ``Linear`` with ``n_actions`` outputs.  The defaults reproduce the layout that
    used to be hard-coded (8 -> 128 -> 32 -> 4), so ``QNetwork()`` builds the same
    architecture, with the same ``state_dict`` keys, as before.

    Args:
        state_dim: number of input features per state.
        hidden_dims: iterable with the width of each hidden layer, in order.
            May be empty, in which case the network is a single linear layer.
        n_actions: number of outputs (one Q-value per discrete action).
    """

    def __init__(self, state_dim=8, hidden_dims=(128, 32), n_actions=4):
        super().__init__()

        layers = []
        in_features = state_dim
        for width in hidden_dims:
            layers.append(nn.Linear(in_features, width))
            layers.append(nn.SELU())
            in_features = width
        layers.append(nn.Linear(in_features, n_actions))

        # Keep the attribute name `linears` and the layer order: checkpoints are
        # keyed as `linears.<index>.weight` / `linears.<index>.bias`.
        self.linears = nn.Sequential(*layers)

    def forward(self, state):
        """Return the network output for `state` (last dim must be `state_dim`)."""
        return self.linears(state)

In [9]:
class Agent():
    """Epsilon-greedy DQN agent holding an estimate network and a target network.

    Args:
        e_network: the network being trained ("estimate" network).
        t_network: the target network used to build bootstrap targets.
        optimizer: optimizer that steps `e_network`'s parameters.
        loss_fn: callable invoked as `loss_fn(prediction, target)`; must return
            a scalar tensor supporting `.backward()`.
        greedy_epsilon: probability of taking the *greedy* action in `act`;
            with probability `1 - greedy_epsilon` a random action is taken.
        discount: discount factor applied to the bootstrap term.  When left as
            `None`, the notebook-level global `decay` is read at update time,
            which is what the original implementation always did.
        target_sync_interval: copy `e_network` into `t_network` after every
            this-many calls to `update`.  The default of 1 keeps the original
            behaviour (sync after every update).

    Raises:
        ValueError: if `target_sync_interval` is smaller than 1.
    """

    def __init__(self, e_network, t_network, optimizer, loss_fn, greedy_epsilon,
                 discount=None, target_sync_interval=1):
        if target_sync_interval < 1:
            raise ValueError('target_sync_interval must be >= 1')
        self.e_network = e_network
        self.t_network = t_network
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.greedy_epsilon = greedy_epsilon
        self.discount = discount
        self.target_sync_interval = target_sync_interval
        self._num_updates = 0

    def _device(self):
        """Return the device holding the estimate network's parameters.

        Inputs are moved here instead of to a notebook-level `device` global, so
        the agent works regardless of what the surrounding cells define.
        """
        return next(self.e_network.parameters()).device

    def act(self, state, net = 'e'):
        """Choose an action for a single state.

        Args:
            state: array-like state accepted by `torch.as_tensor`.
            net: 'e' to query the estimate network, 't' for the target network.

        Returns:
            Tuple `(action, q_value_of_action, q_values)` where `action` is a
            Python int and the other two are tensors produced by the network.

        Raises:
            ValueError: if `net` is neither 'e' nor 't'.
        """
        if net == 'e':
            network = self.e_network
        elif net == 't':
            network = self.t_network
        else:
            # Previously this fell through and failed later with an unbound local.
            raise ValueError(f"net must be 'e' or 't', got {net!r}")

        state = torch.as_tensor(state, dtype=torch.float32, device=self._device())
        q_values = network(state)
        if random.random() < self.greedy_epsilon:
            action = torch.argmax(q_values).item()
        else:
            action = random.randint(0, q_values.shape[-1] - 1)
        return action, q_values[action], q_values

    def update(self, sampled_items):
        """Run one optimisation step on a batch of transitions.

        Args:
            sampled_items: non-empty sequence of
                `(state, action, reward, next_state, done)` tuples.

        Returns:
            The loss of this step as a detached scalar tensor.

        Raises:
            ValueError: if `sampled_items` is empty.
        """
        if len(sampled_items) == 0:
            raise ValueError('sampled_items must contain at least one transition')

        states, actions, rewards, next_states, dones = zip(*sampled_items)
        dev = self._device()
        # Build each batch through a single numpy array so every network is
        # evaluated once per batch instead of once per sample.
        states = torch.as_tensor(np.asarray(states, dtype=np.float32), device=dev)
        next_states = torch.as_tensor(np.asarray(next_states, dtype=np.float32), device=dev)
        actions = torch.as_tensor(np.asarray(actions, dtype=np.int64), device=dev)
        rewards = torch.as_tensor(np.asarray(rewards, dtype=np.float32), device=dev)
        # Truthiness instead of `is True`, so non-`bool` flags (e.g. numpy.bool_)
        # are recognised as terminal too.
        terminal = torch.as_tensor([bool(flag) for flag in dones], dtype=torch.bool, device=dev)

        # Fall back to the notebook-level `decay` only when no discount was given.
        gamma = decay if self.discount is None else self.discount

        with torch.no_grad():
            # The bootstrap value is the max target-network Q-value of the next
            # state.  Going through `act` here would, with probability
            # 1 - greedy_epsilon, bootstrap from a random action instead.
            next_q = self.t_network(next_states).max(dim=1).values
            y = torch.where(terminal, rewards, rewards + gamma * next_q)

        # Q-value the estimate network assigns to the action actually taken.
        q_values = self.e_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        self.optimizer.zero_grad()
        loss = self.loss_fn(q_values, y)  # PyTorch convention: (input, target)
        loss.backward()
        self.optimizer.step()

        self._num_updates += 1
        if self._num_updates % self.target_sync_interval == 0:
            self.t_network.load_state_dict(self.e_network.state_dict())

        return loss.detach()

    def save(self, path):
        """Write both networks' and the optimizer's state dicts to `path`."""
        agent_dict = {
            'e_network': self.e_network.state_dict(),
            't_network': self.t_network.state_dict(),
            'optimizer': self.optimizer.state_dict(),
        }
        torch.save(agent_dict, path)

    def load(self, path):
        """Restore the state written by `save` from `path`.

        Tensors are mapped onto the estimate network's current device, so a
        checkpoint written on a GPU can be restored on a CPU-only machine.
        """
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        ckpt = torch.load(path, map_location=self._device())
        self.e_network.load_state_dict(ckpt['e_network'])
        self.t_network.load_state_dict(ckpt['t_network'])
        self.optimizer.load_state_dict(ckpt['optimizer'])


In [15]:
from collections import deque

writer = SummaryWriter('log')

estimate_network = QNetwork().to(device)
target_network = QNetwork().to(device)
target_network.load_state_dict(estimate_network.state_dict())
loss_fn = nn.MSELoss()

optimizer = optim.Adam(estimate_network.parameters())
# The last argument is the probability of acting greedily (see `Agent.act`),
# i.e. 10% of the actions are random exploration.
agent = Agent(estimate_network, target_network, optimizer, loss_fn, 0.9)

agent.e_network.train()
agent.t_network.train()

# --- hyper-parameters -------------------------------------------------------
tot_batch = 10000          # number of outer iterations (progress-bar ticks)
max_step  = 1000           # hard cap on steps per episode
sample_interval = 1        # run an update every `sample_interval` steps
episode_per_batch = 5      # episodes played per outer iteration
sample_batch = 64          # transitions drawn per update
decay = 0.99               # discount factor

# Replay buffer of (state, action, reward, next_state, is_done) tuples.
# A bounded deque drops the oldest transition on append, instead of copying
# the whole list on every environment step as `buffer[-buffer_size:]` did.
buffer_size = 256
buffer = deque(maxlen=buffer_size)

# Checkpoints are written below; make sure the target directory exists.
os.makedirs('ckpts', exist_ok=True)

prg_bar = tqdm(range(tot_batch))

loss_index = 0

final_rewards = []

try:
    for batch in prg_bar:

        rewards = []

        for episode in range(episode_per_batch):

            state = env.reset()

            for step in range(max_step):
                action, _, _ = agent.act(state, 'e')
                next_state, reward, done, _ = env.step(action)

                rewards.append(reward)
                buffer.append((state, action, reward, next_state, done))

                # Truthiness rather than `is True`: a non-`bool` flag such as
                # numpy.bool_ would never satisfy an identity check.
                if done:
                    final_rewards.append(reward)
                    writer.add_scalar('final_reward', reward, len(final_rewards))
                    break

                state = next_state

                if step % sample_interval == 0 and len(buffer) > sample_batch:
                    sampled_items = random.sample(buffer, sample_batch)
                    loss = agent.update(sampled_items)
                    loss_index += 1
                    # Log a plain float rather than a tensor.
                    writer.add_scalar('loss', loss.item(), loss_index)

            # end of for step
        # end of for episode

        avg_reward = sum(rewards) / len(rewards)
        # No episode may have terminated yet (every one hit `max_step`);
        # report NaN instead of dividing by zero.
        if final_rewards:
            avg_final_reward = sum(final_rewards) / len(final_rewards)
        else:
            avg_final_reward = float('nan')
        prg_bar.set_description(f'Avg_reward: {avg_reward: .1f}, '
                                f'Avg_final_reward: {avg_final_reward: .1f}')

        writer.add_scalar('avg_reward', avg_reward, batch)

        if batch % 20 == 0:
            agent.save(f'./ckpts/{batch}.pkl')
finally:
    # Flush pending events even when training is interrupted from the notebook.
    writer.close()

  0%|          | 0/10000 [00:00<?, ?it/s]

KeyboardInterrupt: 

['log', 'main.ipynb', '.ipynb_checkpoints', 'ckpts']