In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.distributions import Categorical

import gym

In [2]:
from util import DotDict
from memory import Memory

In [176]:
# Env wraps a gym environment, keeps the current state, and records per-episode rewards in the shared stat object.

class Env:
    """Wrapper around a gym environment that tracks state and episode rewards.

    The wrapper keeps the most recent observation as a ``torch.FloatTensor``
    and adds every step's reward to the last entry of ``stat.rewards``. When
    an episode finishes it prints ``stat.frame`` together with that episode's
    total, appends a fresh ``0`` to ``stat.rewards`` and resets the
    underlying environment.

    Args:
        name: environment id handed to ``gym.make``.
        stat: object exposing a ``rewards`` list and a ``frame`` attribute.
    """

    def __init__(self, name, stat):
        self._name = name
        self._stat = stat
        self._env = gym.make(name)
        initial_observation = self._env.reset()
        self._state = torch.FloatTensor(initial_observation)

    @property
    def name(self):
        """Environment id this wrapper was created with."""
        return self._name

    @property
    def state(self):
        """Observation the next action should be chosen from."""
        return self._state

    def step(self, action):
        """Apply ``action`` to the environment and update the bookkeeping.

        Args:
            action: object whose ``.item()`` yields the value passed to the
                underlying environment's ``step``.

        Returns:
            Tuple ``(next_state, reward, done, info)`` where the first three
            entries are tensors built from the environment's return values
            and ``info`` is passed through unchanged. ``next_state`` is the
            observation produced by this step even when the episode ended;
            ``self.state`` is what holds the post-reset observation.
        """
        observation, reward, done, info = self._env.step(action.item())

        episode_totals = self._stat.rewards
        episode_totals[-1] += reward

        if done:
            # Report the finished episode, open a new accumulator and restart.
            print(self._stat.frame, episode_totals[-1])
            episode_totals.append(0)
            self._state = torch.FloatTensor(self._env.reset())
        else:
            self._state = torch.FloatTensor(observation)

        return (
            torch.FloatTensor(observation),
            torch.tensor(reward),
            torch.tensor(done),
            info,
        )
    

In [169]:
class Net(nn.Module):
    """Fully connected policy network producing a probability distribution.

    Four linear layers with ReLU between them; the last layer's output is
    passed through a softmax over the final dimension.

    Args:
        n_inputs: size of the input feature vector (default 4).
        n_hidden: width of each of the three hidden layers (default 20).
        n_outputs: number of output probabilities (default 2).

    The defaults reproduce the sizes that were previously hard-coded, so
    ``Net()`` builds the same architecture as before.
    """

    def __init__(self, n_inputs=4, n_hidden=20, n_outputs=2):
        super(Net, self).__init__()
        # Attribute names and creation order are kept so existing state dicts
        # and seeded initialisation stay compatible.
        self.linear1 = nn.Linear(n_inputs, n_hidden)
        self.linear2 = nn.Linear(n_hidden, n_hidden)
        self.linear3 = nn.Linear(n_hidden, n_hidden)
        self.linear4 = nn.Linear(n_hidden, n_outputs)

    def forward(self, x):
        """Return softmax probabilities over the last dimension of the output."""
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = F.relu(self.linear3(x))
        return F.softmax(self.linear4(x), dim=-1)

In [168]:
class Base:
    """Common scaffolding for agents: configuration, statistics and run loop.

    Args:
        config: mapping of settings. The keys read by this class are
            ``env`` (environment name, replaced by an ``Env`` instance),
            ``nn``, ``optim``, ``scheduler`` and ``step_length``.

    Subclasses provide ``step()`` (one environment interaction) and
    ``learn()`` (one update from the collected interactions).
    """

    # "Solved" criterion used by run(): mean reward of the most recent
    # SOLVED_WINDOW finished episodes must reach SOLVED_THRESHOLD. These are
    # the values that were previously hard-coded (the CartPole-v0 criterion).
    SOLVED_WINDOW = 100
    SOLVED_THRESHOLD = 195

    def __init__(self, config):
        self._config = DotDict(**config)
        self._stat = DotDict(**{
            'frame': 0,
            'episode': 0,
            'rewards': [0]
        })
        # The config carries the environment *name*; swap it for a live Env
        # that shares this agent's stat object.
        env_name = self._config.env
        self._config.env = Env(env_name, self._stat)

    @property
    def env(self):
        """The ``Env`` wrapper built in ``__init__``."""
        return self._config.env

    @property
    def nn(self):
        """The ``nn`` entry of the config."""
        return self._config.nn

    @property
    def optim(self):
        """The ``optim`` entry of the config."""
        return self._config.optim

    @property
    def scheduler(self):
        """The ``scheduler`` entry of the config."""
        return self._config.scheduler

    @property
    def stat(self):
        """Statistics container (``frame``, ``episode``, ``rewards``)."""
        return self._stat

    @property
    def frame(self):
        """Number of environment steps counted by ``run``."""
        return self._stat.frame

    @property
    def episode(self):
        """Value of the ``episode`` statistic (not updated by this class)."""
        return self._stat.episode

    def step(self):
        """Perform one environment interaction; must be overridden."""
        raise NotImplementedError

    def learn(self):
        """Update the model from collected interactions; must be overridden."""
        raise NotImplementedError

    def _is_solved(self):
        """Tell whether the recent finished episodes meet the solved criterion."""
        # Env appends a new 0 whenever an episode ends, so the last entry is
        # the running total of the episode still in progress: leave it out.
        finished = self._stat.rewards[:-1]
        if len(finished) < self.SOLVED_WINDOW:
            return False
        recent = finished[-self.SOLVED_WINDOW:]
        return sum(recent) / len(recent) >= self.SOLVED_THRESHOLD

    def run(self, n_frames, from_frame=0):
        """Alternate batches of ``step()`` calls with ``learn()`` calls.

        Args:
            n_frames: frame count at which to stop. Frames are processed in
                batches of ``step_length``, so the final count can exceed
                this by less than one batch.
            from_frame: value the frame counter is set to before starting.

        The loop ends early, printing ``Solved``, once the mean reward of the
        last ``SOLVED_WINDOW`` finished episodes reaches ``SOLVED_THRESHOLD``.
        """
        self._stat.frame = from_frame

        # Strict comparison: "<=" ran one extra batch whenever the counter
        # landed exactly on n_frames.
        while self.frame < n_frames:
            for _ in range(self._config.step_length):
                self.step()
                self._stat.frame += 1
            self.learn()

            if self._is_solved():
                print('Solved')
                break

## REINFORCE

In [210]:
class Reinforce(Base):
    """REINFORCE-style policy-gradient agent built on ``Base``.

    Each ``step`` samples an action from the policy network's output and
    stores the action's log-probability, the reward and the done flag.
    ``learn`` turns the stored rewards into discounted returns and takes one
    optimizer step on ``-(log_prob * normalised_return).sum()``.

    Args:
        config: same mapping as ``Base``; this class additionally reads
            ``step_length`` (memory capacity) and ``gamma`` (discount).
    """

    def __init__(self, config):
        super(Reinforce, self).__init__(config)
        self._memory = Memory(
            fields=('log_prob', 'reward', 'done'),
            cap=config['step_length'])

    def step(self):
        """Sample one action from the policy, apply it and record the outcome."""
        state = self.env.state
        policy = self.nn(state)
        probs = Categorical(policy)
        action = probs.sample()
        log_prob = probs.log_prob(action)
        # Next state and info are not needed: Env keeps the current state.
        _, reward, done, _ = self.env.step(action)
        self._memory.append([log_prob, reward, done.float()])

    def learn(self):
        """Run one policy-gradient update from the flushed memory."""
        log_probs, rewards, dones = self._memory.flush()
        gamma = self._config.gamma

        # Walk backwards so each return is reward + gamma * (next return);
        # a done flag of 1 cuts the sum at an episode boundary.
        gain = 0
        exp_rewards = []
        for i in reversed(range(rewards.size(0))):
            reward, done = rewards[i], dones[i]
            gain = reward + gamma * gain * (1 - done)
            exp_rewards.append(gain)

        exp_rewards.reverse()
        exp_rewards = torch.stack(exp_rewards)

        # Standardise the returns. With fewer than two samples std() is NaN,
        # which would turn the loss and then the weights into NaN, so the raw
        # return is used in that case.
        if exp_rewards.numel() > 1:
            eps = torch.finfo(exp_rewards.dtype).eps
            exp_rewards = (exp_rewards - exp_rewards.mean()) / (exp_rewards.std() + eps)

        loss = -(log_probs * exp_rewards).sum()
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()

In [171]:
# Build the network and optimizer before the config dict so the dict captures
# these exact objects. Previously they were assigned after being referenced,
# which fails with NameError on a fresh kernel (or silently reuses stale
# objects from an earlier run of the cell).
net = Net()
optimizer = optim.Adam(net.parameters())

config = {
    'env': 'CartPole-v0',
    'nn': net,
    'optim': optimizer,
    'scheduler': None,
    'step_length': 20,  # environment steps collected between updates
    'gamma': 0.9        # discount factor used when computing returns
}

In [211]:
# Create the REINFORCE agent from the config dict defined above.
agent = Reinforce(config)

In [213]:
# Train for roughly 20000 frames (Base.run stops early if its 'Solved' check passes).
# Each finished episode prints "<frame> <episode reward>" from Env.step.
agent.run(20000)

16 63.0
142 126.0
243 101.0
331 88.0
407 76.0
476 69.0
629 153.0
741 112.0
875 134.0
960 85.0
1029 69.0
1191 162.0
1268 77.0
1334 66.0
1381 47.0
1466 85.0
1537 71.0
1640 103.0
1722 82.0
1893 171.0
1934 41.0
1981 47.0
2021 40.0
2079 58.0
2119 40.0
2156 37.0
2232 76.0
2261 29.0
2323 62.0
2348 25.0
2390 42.0
2403 13.0
2429 26.0
2467 38.0
2547 80.0
2579 32.0
2643 64.0
2714 71.0
2854 140.0
2896 42.0
3002 106.0
3126 124.0
3177 51.0
3239 62.0
3292 53.0
3350 58.0
3441 91.0
3575 134.0
3609 34.0
3654 45.0
3711 57.0
3763 52.0
3809 46.0
3841 32.0
3887 46.0
3945 58.0
3989 44.0
4019 30.0
4076 57.0
4131 55.0
4187 56.0
4246 59.0
4329 83.0
4365 36.0
4390 25.0
4442 52.0
4501 59.0
4565 64.0
4645 80.0
4687 42.0
4730 43.0
4791 61.0
4872 81.0
4940 68.0
4991 51.0
5031 40.0
5068 37.0
5103 35.0
5180 77.0
5206 26.0
5288 82.0
5386 98.0
5425 39.0
5444 19.0
5505 61.0
5559 54.0
5663 104.0
5705 42.0
5776 71.0
5804 28.0
5858 54.0
5891 33.0
5950 59.0
5973 23.0
6012 39.0
6049 37.0
6079 30.0
6124 45.0
6140 16.0
6199 59.

## Actor Critic

In [None]:
class ActorCriticTailNet(nn.Module):
    