In [1]:
import os
import cv2
import time
import numpy as np
import random
from tqdm import tqdm
import gym
import copy
import matplotlib.pyplot as plt
import argparse
from collections import namedtuple

import torch
import torch.tensor as Tensor
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import wandb

%matplotlib inline

In [2]:
class MLP(nn.Module):
    """Small fully connected network mapping an observation to one value per action.

    Architecture: ``obs_dim -> 64 -> 32 -> n_actions`` with ReLU after the
    first two layers and a linear output layer.

    Args:
        obs_dim: Length of the flat observation vector. When ``None`` it is
            read from ``env.observation_space.shape[0]`` of the notebook-global
            ``env``.
        n_actions: Number of discrete actions (size of the output layer). When
            ``None`` it is read from ``env.action_space.n``.
    """

    def __init__(self, obs_dim=None, n_actions=None):
        super().__init__()
        # Only touch the global `env` for sizes the caller did not supply, so
        # the network can be built (and tested) without an environment.
        if obs_dim is None:
            obs_dim = env.observation_space.shape[0]
        if n_actions is None:
            n_actions = env.action_space.n
        self.fc1 = nn.Linear(obs_dim, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, n_actions)

    def forward(self, x):
        """Return the raw (un-activated) output of the last layer for ``x``."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

In [3]:
class ReplayBuffer():
    """Fixed-capacity store of transitions with uniform random sampling.

    Transitions are appended until ``size`` entries are held; after that each
    new transition overwrites the oldest one (ring buffer).

    Attributes:
        size: Maximum number of transitions kept.
        buffer: List holding the stored ``Transition`` namedtuples.
        index: Slot that the next overwrite will use once the buffer is full.
        transition: ``Transition`` namedtuple type with fields
            ``state, action, reward, next_state, done``.

    Raises:
        ValueError: If ``size`` is not a positive number.
    """

    def __init__(self, size):
        if size <= 0:
            raise ValueError("size must be a positive integer")
        self.size = size
        self.buffer = []
        self.index = 0
        self.transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state', 'done'))

    def _as_transition(self, trans):
        """Convert a 5-item sequence into a ``Transition`` namedtuple."""
        return self.transition(trans[0], trans[1], trans[2], trans[3], trans[4])

    def fill_buffer(self):
        """Store ``self.size`` transitions gathered with random actions.

        Uses the notebook-global ``env``. Transitions go through ``store`` so
        the capacity limit and ``index`` stay consistent even when the buffer
        already holds data.
        """
        obs = env.reset()
        for _ in tqdm(range(0, self.size)):
            action = env.action_space.sample()
            new_obs, reward, done, _info = env.step(action)
            self.store((obs, action, reward, new_obs, done))
            # Start a fresh episode once the current one has terminated.
            obs = env.reset() if done else new_obs

    def store_filled(self, trans):
        """Overwrite the oldest slot of a full buffer with ``trans``.

        Raises:
            IndexError: If the buffer is not full yet (use ``store`` instead).
        """
        self.buffer[self.index] = self._as_transition(trans)
        self.index = (self.index + 1) % self.size

    def store(self, trans):
        """Add ``trans``; append while below capacity, otherwise overwrite the oldest."""
        if len(self.buffer) < self.size:
            self.buffer.append(self._as_transition(trans))
            # Wraps to 0 exactly when the buffer becomes full, which is the
            # position of the oldest entry.
            self.index = len(self.buffer) % self.size
        else:
            self.store_filled(trans)

    def sample(self, batch=32):
        """Return ``batch`` distinct stored transitions chosen uniformly at random.

        Raises:
            ValueError: If fewer than ``batch`` transitions are stored
                (raised by ``random.sample``).
        """
        return random.sample(self.buffer, k=batch)

In [4]:
ENV_NAME = 'CartPole-v0'
BUFFER_CAPACITY = 10000

# Shared environment and replay buffer used by the cells below.
env = gym.make(ENV_NAME)
buffer = ReplayBuffer(BUFFER_CAPACITY)
# Optional warm start with random-action transitions (left disabled):
# buffer.fill_buffer()

In [5]:
learning_rate = 1e-4

# `value` is the network being trained; `target` starts as an independent
# deep copy of it and is not touched by the optimizer.
value = MLP()
target = copy.deepcopy(value)

optimizer = optim.Adam(params=value.parameters(), lr=learning_rate)
loss_fn = nn.MSELoss()

In [6]:
def get_action(obs):
    """Choose an action for ``obs`` with an epsilon-greedy rule.

    With probability ``epsilon`` (notebook-global) a random action is drawn
    from ``env.action_space``; otherwise the action with the largest output
    of ``get_current_value(obs)`` is chosen.

    Args:
        obs: Observation as a numpy array, as accepted by ``get_current_value``.

    Returns:
        The sampled action when exploring, otherwise a plain Python ``int``
        index so it can be passed straight to ``env.step``.
    """
    if np.random.rand() < epsilon:
        return env.action_space.sample()
    # Action selection does not need gradients.
    with torch.no_grad():
        q_values = get_current_value(obs)
    return int(torch.argmax(q_values).item())
    
def get_target_value(obs):
    """Evaluate the ``target`` network on one observation.

    ``obs`` (numpy array) is converted to a float tensor with a leading batch
    axis of size 1; the result is detached from the autograd graph.
    """
    batched_obs = torch.from_numpy(obs).float().unsqueeze(0)
    output = target.forward(batched_obs)
    return output.detach()

def get_current_value(obs):
    """Evaluate the ``value`` network on one observation.

    ``obs`` (numpy array) is converted to a float tensor with a leading batch
    axis of size 1. The result stays attached to the autograd graph.
    """
    batched_obs = torch.from_numpy(obs).float().unsqueeze(0)
    return value.forward(batched_obs)

In [7]:
# NOTE: despite its name, EPISODES counts environment steps -- the loop below
# performs exactly one `env.step` per iteration.
EPISODES = 10000
BATCH_SIZE = 32
# How often (in steps) the target network is refreshed from the value network.
# Tunable; without any sync the bootstrap targets come from a frozen random net.
TARGET_SYNC_EVERY = 500
# NOTE(review): epsilon is never decayed in this cell, so actions stay fully
# random for the whole run.
epsilon = 1
gamma = 0.9
rewards = []  # one entry per finished episode: its total reward

obs = env.reset()
episode_return = 0.0
for global_step in tqdm(range(0, EPISODES)):
    # --- interact with the environment -------------------------------------
    action = get_action(obs)
    new_obs, reward, done, _ = env.step(action)
    buffer.store((obs, action, reward, new_obs, done))
    episode_return += reward
    if done:
        rewards.append(episode_return)
        episode_return = 0.0
        obs = env.reset()
    else:
        obs = new_obs

    # --- one gradient step on a sampled minibatch --------------------------
    if len(buffer.buffer) >= BATCH_SIZE:
        minibatch = buffer.sample(BATCH_SIZE)
        states = torch.from_numpy(np.stack([t.state for t in minibatch])).float()
        next_states = torch.from_numpy(np.stack([t.next_state for t in minibatch])).float()
        actions = torch.tensor([int(t.action) for t in minibatch], dtype=torch.int64)
        batch_rewards = torch.tensor([float(t.reward) for t in minibatch], dtype=torch.float32)
        # 0 for terminal transitions so no value is bootstrapped past the end.
        not_done = torch.tensor([0.0 if t.done else 1.0 for t in minibatch], dtype=torch.float32)

        # TD target: r + gamma * max_a' Q_target(s', a'); no gradient flows here.
        with torch.no_grad():
            next_qs = batch_rewards + gamma * not_done * target(next_states).max(dim=1)[0]
        # Q-value of the action that was actually taken in each transition.
        current_qs = value(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        optimizer.zero_grad()
        loss = loss_fn(current_qs, next_qs)
        loss.backward()
        optimizer.step()

    # --- periodically copy the trained weights into the target network -----
    if (global_step + 1) % TARGET_SYNC_EVERY == 0:
        target.load_state_dict(value.state_dict())

100%|██████████| 10000/10000 [01:12<00:00, 137.98it/s]


In [None]:
# Sanity check: number of TD targets built for the most recent minibatch.
len(next_qs)

In [None]:
# Inspect one predicted Q-value from the most recent minibatch. Each entry of
# `current_qs` is already a scalar (0-dim) tensor, so indexing it again with
# `[0]` would raise IndexError.
current_qs[1]