In [1]:
from collections import deque
from dataclasses import dataclass
import numpy as np
import random

import gym

import torch.optim as optim
import torch
import torch.nn as nn

from numpy.random import default_rng
# Shared NumPy random generator; Agent.update_transition_model uses it to draw
# minibatch indices. No seed is passed here.
rng = default_rng()


@dataclass(frozen=True)
class Transition:
    """Immutable, field-less base type for environment transitions."""

@dataclass(frozen=True)
class CartpoleTransition(Transition):
    """Immutable record of a single environment step.

    The field names match the keys of the transition dicts built in
    ``CartpoleSimulator.run_trial``.
    """

    # Observation before the action was taken.
    s: np.ndarray
    # Observation returned by the environment after the action.
    s_prime: np.ndarray
    # Index of the action that was taken.
    a: int
    # Reward recorded for the step.
    r: float
    # Whether the step ended the episode.
    is_done: bool
        
class ReplayMemory:
    """Bounded store of transitions; the newest entry sits at the left end."""

    def __init__(self,buffer_size):
        """Create an empty buffer holding at most ``buffer_size`` items."""
        self.buffer = deque(maxlen=buffer_size)

    def append(self,transition):
        """Push ``transition`` onto the left end of the buffer.

        Once the buffer is full, the deque drops an item from the right end
        to make room.
        """
        self.buffer.appendleft(transition)

    def sample(self,batch_size=1):
        """Return ``batch_size`` stored items drawn at random without replacement."""
        return random.sample(self.buffer, k=batch_size)

In [2]:
class StateTransitionModel(nn.Module):
    """Single-hidden-layer MLP mapping a state vector to one value per action.

    The output layer keeps the attribute name ``fc3`` (an ``fc2`` layer used
    to sit between the two and was disabled) so the ``state_dict`` keys stay
    ``fc1.*`` / ``fc3.*``.
    """

    def __init__(self,feature_cnt,action_cnt,hidden_layer_cnt=64):
        """Build the network.

        Args:
            feature_cnt: Size of the input state vector.
            action_cnt: Number of outputs (one per action).
            hidden_layer_cnt: Width of the hidden layer (default 64).
        """
        super().__init__()

        self.fc1 = nn.Linear(feature_cnt, hidden_layer_cnt)
        # fc3 consumes fc1's output directly, so its input width must be the
        # same value as fc1's output width rather than a separate constant.
        self.fc3 = nn.Linear(hidden_layer_cnt, action_cnt)

    def forward(self,x):
        """Return the per-action outputs for ``x`` (last dim = ``feature_cnt``)."""
        hidden = nn.functional.relu(self.fc1(x))
        return self.fc3(hidden)

In [3]:
class Agent:
    """Epsilon-greedy Q-learning agent with two networks and a replay buffer.

    Naming note: ``target_transition_network`` is the network the optimizer
    trains, while ``base_transition_network`` is the slowly blended copy that
    is used both to choose actions and to build the regression targets.
    """

    def __init__(self,**args):
        """Configure the agent from keyword arguments.

        Required keys:
            action_cnt: Number of discrete actions.
            feature_cnt: Length of the state vector.

        Optional keys (default):
            smoothing (0.5), lr (0.01), alpha (1.0), gamma (1.0),
            batch_size (20), memory_size (200), epsilon (1.0),
            epsilon_decay (0.99), epsilon_min (0.1).

        Raises:
            KeyError: If ``action_cnt`` or ``feature_cnt`` is missing.
        """
        print(args)

        self.action_cnt = args['action_cnt'] #env.action_space.n
        self.feature_cnt = args['feature_cnt'] #env.observation_space.shape[0]

        self.criterion = nn.MSELoss()
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # Weight kept from the base network when blending in trained weights.
        self.smoothing = args.get('smoothing', 0.5)

        self.target_transition_network = StateTransitionModel(self.feature_cnt,self.action_cnt).to(self.device)
        self.base_transition_network = StateTransitionModel(self.feature_cnt,self.action_cnt).to(self.device)
        self.lr = args.get('lr', 0.01)
        # Only the "target" network is optimized directly.
        self.optimizer = optim.Adam(self.target_transition_network.parameters(),lr=self.lr)

        # alpha is stored but not read anywhere in this class.
        self.alpha = args.get('alpha', 1.0)
        self.gamma = args.get('gamma', 1.0)
        self.batch_size = args.get('batch_size', 20)
        self.memory_size = args.get('memory_size', 200)

        self.memory = deque([],maxlen=self.memory_size)

        self.epsilon = args.get('epsilon', 1.0)
        self.epsilon_decay = args.get('epsilon_decay', 0.99)
        self.epsilon_min = args.get('epsilon_min', 0.1)

    def _update_base_network(self):
        """Soft-update the base network toward the trained network.

        Each base parameter becomes
        ``smoothing * base + (1 - smoothing) * target``, which damps large
        swings in the freshly trained weights. The blend is written in place
        so it does not depend on the parameter names forming a complete
        ``state_dict``.
        """
        keep = self.smoothing
        pairs = zip(self.base_transition_network.parameters(),
                    self.target_transition_network.parameters())
        with torch.no_grad():
            for base_param, target_param in pairs:
                base_param.copy_(keep*base_param + (1-keep)*target_param)

    def _update_epsilon(self):
        """Multiply epsilon by ``epsilon_decay``, never going below ``epsilon_min``."""
        if self.epsilon > self.epsilon_min:
            self.epsilon = max(self.epsilon_min, self.epsilon*self.epsilon_decay)

    def get_action(self,s,use_epsilon_decay=True):
        """Choose an action for state ``s``.

        Args:
            s: State as a NumPy array of length ``feature_cnt``.
            use_epsilon_decay: When True, a uniformly random action is taken
                with probability ``epsilon``; when False the choice is always
                greedy.

        Returns:
            The chosen action index as a Python int.
        """
        if np.random.random() < self.epsilon and use_epsilon_decay:
            return random.randint(0, self.action_cnt-1)

        state = torch.from_numpy(s.astype(np.float32)).to(self.device)
        # Action selection never needs gradients.
        with torch.no_grad():
            action_values = self.base_transition_network(state)
        return torch.argmax(action_values).item()

    def append_replay(self,transition):
        """Store a transition dict (keys 's', 's_prime', 'a', 'r', 'is_done')."""
        self.memory.appendleft(transition)

    def update_transition_model(self):
        """Run 10 minibatch updates, then decay epsilon and soft-update the base network.

        Does nothing while the replay memory holds fewer than ``batch_size``
        transitions.
        """
        if len(self.memory) < self.batch_size:
            return

        batch_cnt = 10
        for _ in range(batch_cnt):
            minibatch_ix = rng.choice(len(self.memory), size=self.batch_size, replace=False)
            # Fetch each sampled transition once; indexing into the middle of
            # a deque is not constant time.
            batch = [self.memory[ix] for ix in minibatch_ix]

            state = torch.from_numpy(np.array([t['s'] for t in batch]).astype(np.float32)).to(self.device)
            next_state = torch.from_numpy(np.array([t['s_prime'] for t in batch]).astype(np.float32)).to(self.device)
            reward = torch.tensor([float(t['r']) for t in batch], dtype=torch.float32, device=self.device)
            action_ix = torch.tensor([int(t['a']) for t in batch], dtype=torch.long, device=self.device)
            # 1.0 where the episode continued, 0.0 where it ended.
            has_future_states = torch.tensor([0.0 if t['is_done'] else 1.0 for t in batch],
                                             dtype=torch.float32, device=self.device)

            # Targets are constants for the regression: build them without
            # autograd so backward() does not run through the base network or
            # leave stale gradients on its parameters.
            with torch.no_grad():
                target = self.base_transition_network(state)
                best_future = torch.max(self.base_transition_network(next_state), dim=1).values
                discounted_reward = best_future*has_future_states*self.gamma
                rows = torch.arange(len(batch), device=self.device)
                # Only the taken action's entry is replaced; the other
                # entries keep the base network's own outputs.
                target[rows, action_ix] = reward + discounted_reward

            # Backprop Target Network
            self.optimizer.zero_grad()
            prediction = self.target_transition_network(state)
            loss = self.criterion(prediction, target)
            loss.backward()
            self.optimizer.step()

        self._update_epsilon()
        self._update_base_network()


In [4]:
class Simulator:
    """Marker base class for environment simulators; defines no behaviour."""

class CartpoleSimulator(Simulator):
    """Drives an ``Agent`` through episodes of gym's ``CartPole-v1``.

    The code uses the gym API in which ``env.reset()`` returns the
    observation and ``env.step()`` returns a 4-tuple.
    """

    def __init__(self,agent:Agent,**args):
        """Create the environment and attach ``agent``; extra kwargs are ignored."""
        self.env = gym.make("CartPole-v1")
        self.agent = agent
        # Created here so run_trial()/test() work without a prior train().
        self.rewards = []

    def terminate(self):
        """Close the underlying gym environment."""
        self.env.close()

    def run_trial(self,use_epsilon=True,render=False):
        """Play one episode and store every transition in the agent's memory.

        Args:
            use_epsilon: Forwarded to ``Agent.get_action``; False means the
                agent always acts greedily.
            render: When True, draw the environment before every step.

        Returns:
            The number of steps taken, which is also appended to
            ``self.rewards``.
        """
        curr_state = self.env.reset()
        cum_reward = 0
        done = False

        while cum_reward<2000 and not done:
            if render:
                self.env.render()

            action = self.agent.get_action(curr_state,use_epsilon)
            s_prime, r, done, _ = self.env.step(action)

            # An episode that ends before the 500th step gets a large
            # negative reward on its final transition.
            if done and cum_reward<499:
                r = -100

            transition = {'s':curr_state,'s_prime':s_prime,'a':action,'r':r,'is_done':done}
            self.agent.append_replay(transition)
            cum_reward += 1
            curr_state = s_prime

        self.rewards.append(cum_reward)
        return cum_reward

    def train(self,run_cnt=1000,trial_cnt=10,log_every=10):
        """Alternate between collecting episodes and updating the agent.

        Args:
            run_cnt: Number of experiments (collect + update rounds).
            trial_cnt: Episodes played before each agent update.
            log_every: Print a progress line every this many experiments;
                the reward shown is the mean episode length since the
                previous line.
        """
        self.rewards = []

        for experiment_ix in range(run_cnt):
            for _ in range(trial_cnt):
                self.run_trial()
            self.agent.update_transition_model()

            if experiment_ix%log_every==0:
                print(f'Experiment {experiment_ix}: reward={np.mean(self.rewards)}, epsilon={self.agent.epsilon}, memory={len(self.agent.memory)}')
                self.rewards=[]

    def test(self,episode_cnt=1):
        """Render ``episode_cnt`` greedy episodes, then close the environment.

        The environment is closed even if rendering is interrupted, so the
        simulator cannot be reused afterwards.
        """
        try:
            for _ in range(episode_cnt):
                self.run_trial(use_epsilon=False,render=True)
        finally:
            self.terminate()


In [5]:
# Hyper-parameters for the CartPole agent; Agent.__init__ prints them back.
params = {
    'action_cnt': 2,
    'feature_cnt': 4,
    'alpha': 1.0,
    'gamma': 0.99999,
    'memory_size': 10000,
    'batch_size': 64,
    'lr': 0.001,
    'epsilon': 0.90,
    'epsilon_decay': 0.99,
    'epsilon_min': 0.20,
}
agent = Agent(**params)

{'action_cnt': 2, 'feature_cnt': 4, 'alpha': 1.0, 'gamma': 0.99999, 'memory_size': 10000, 'batch_size': 64, 'lr': 0.001, 'epsilon': 0.9, 'epsilon_decay': 0.99, 'epsilon_min': 0.2}


In [7]:
# Wrap the agent in a CartPole simulator and run the training schedule;
# train() prints a progress line every 10 experiments.
sim = CartpoleSimulator(agent, **params)
sim.train()
#sim.test()

Experiment 0: reward=18.2, epsilon=0.891, memory=182
Experiment 10: reward=19.29, epsilon=0.8058044288328449, memory=2111
Experiment 20: reward=20.78, epsilon=0.7287550813991327, memory=4189
Experiment 30: reward=22.11, epsilon=0.6590730326889578, memory=6400
Experiment 40: reward=42.82, epsilon=0.5960538368855853, memory=10000
Experiment 50: reward=40.97, epsilon=0.5390604058195451, memory=10000
Experiment 60: reward=20.87, epsilon=0.48751656837016827, memory=10000
Experiment 70: reward=19.72, epsilon=0.4409012457037845, memory=10000
Experiment 80: reward=38.48, epsilon=0.39874318346355536, memory=10000
Experiment 90: reward=41.31, epsilon=0.3606161876563865, memory=10000
Experiment 100: reward=21.68, epsilon=0.3261348160744472, memory=10000
Experiment 110: reward=21.29, epsilon=0.2949504816940234, memory=10000
Experiment 120: reward=32.36, epsilon=0.26674792865928726, memory=10000
Experiment 130: reward=83.98, epsilon=0.24124204522518672, memory=10000
Experiment 140: reward=74.47, ep

In [16]:
#sim.test()

In [18]:
# Roll out one rendered episode with exploration disabled.
curr_state = sim.env.reset()
done = False

try:
    while not done:
        sim.env.render()
        action = sim.agent.get_action(curr_state, False)
        curr_state, r, done, _ = sim.env.step(action)
finally:
    # Close the environment even if the rollout is interrupted.
    sim.terminate()