In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from environments import DS_DA, CS_DA_deterministic
from collections import deque, namedtuple
import random
from torch import optim
import matplotlib.pyplot as plt
import numpy as np
from torch.distributions import Categorical


In [2]:
# Prefer the GPU when torch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
class Policy(nn.Module):
    """Small MLP policy mapping an observation vector to a categorical action distribution.

    Architecture: n_obs -> 16 -> 64 -> n_actions, ReLU between layers, softmax
    over the action logits.

    Args:
        n_obs: length of the observation vector fed to ``layer1``.
        n_actions: number of discrete actions (size of the output layer).
        temperature: divisor applied to the logits before the softmax. The
            original notebook hard-coded 100, which flattens the distribution
            towards uniform (more exploration); that stays the default.
    """

    def __init__(self, n_obs, n_actions, temperature=100.0):
        super().__init__()
        self.layer1 = nn.Linear(n_obs, 16)
        self.layer2 = nn.Linear(16, 64)
        self.layer3 = nn.Linear(64, n_actions)
        # Plain attribute (not a buffer), so state_dict() is unchanged.
        self.temperature = temperature

    def forward(self, x):
        """Return action probabilities for ``x``.

        The softmax is taken explicitly over the last dimension (the action
        logits). Without ``dim`` PyTorch guesses the axis, warns, and picks the
        wrong one for 3-D input.
        """
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return F.softmax(self.layer3(x) / self.temperature, dim=-1)

    def select_action(self, state, verbose=False):
        """Sample an action for a single observation.

        Args:
            state: numpy array observation (converted with ``torch.from_numpy``).
            verbose: if True, print the action probabilities. The original
                printed them unconditionally on every call.

        Returns:
            ``(action, log_prob)``: the sampled action as a Python int, and its
            log-probability as a 1-element tensor. The tensor stays attached to
            the autograd graph so it can be used in the policy-gradient loss.
        """
        # Send the state to wherever the parameters live instead of the global
        # `device`. Otherwise a CPU model on a CUDA machine raises a
        # device-mismatch error.
        param_device = next(self.parameters()).device
        state = torch.from_numpy(state).float().unsqueeze(0).to(param_device)
        probs = self.forward(state).cpu()
        m = Categorical(probs)

        action = m.sample()
        if verbose:
            print(m.probs)
        return action.item(), m.log_prob(action)

In [4]:
# --- Hyper-parameters -------------------------------------------------------
# NOTE(review): ALPHA and K are not referenced by any visible cell, and the
# training loop sets its own episode count (1000) instead of N_EPISODES —
# confirm which values are actually intended.
N_EPISODES = 100
ALPHA = 0.1
GAMMA = 0.95  # discount factor
K = 10

# --- One environment step, as recorded in a trajectory ----------------------
# Fields: state, action, next state, reward, step index, log pi(a|s).
Transition = namedtuple(
    'Transition',
    ["s", "a", "s_prime", "r", "t", "log_prob_a"],
)

In [5]:
class ReplayBUffer:
    """Fixed-capacity FIFO store of ``Transition`` records.

    When ``capacity`` items are already held, each new push evicts the oldest
    one (``deque(maxlen=...)`` semantics).
    """

    def __init__(self, capacity):
        # A bounded deque drops items from the left once it is full.
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Build a ``Transition`` from the positional fields and store it."""
        transition = Transition(*args)
        self.memory.append(transition)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored items, chosen uniformly at random.

        ``random.sample`` raises ValueError if fewer than ``batch_size`` items
        are stored.
        """
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)


def discount_rewards(rewards, gamma=0.95):
    """Compute the discounted reward-to-go for every step of an episode.

    ``G_t = r_t + gamma * r_{t+1} + gamma**2 * r_{t+2} + ...``
    (equivalently ``G_t = r_t + gamma * G_{t+1}``).

    The previous version scaled each reward by ``gamma**i`` using its absolute
    index and then took a reversed cumsum. That yields ``gamma**t * G_t``, so
    late steps were exponentially under-weighted. It also returned a
    negative-stride view that ``torch.tensor`` cannot consume without a copy.

    Args:
        rewards: sequence of per-step rewards, in time order.
        gamma: discount factor.

    Returns:
        A C-contiguous float ndarray of the same length as ``rewards``.
        It is empty for an empty episode.
    """
    returns = np.zeros(len(rewards), dtype=float)
    running = 0.0
    # Walk backwards so each step reuses the already-discounted tail.
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns
# rb = ReplayBUffer(10_000)

In [None]:
# Seed every RNG the loop touches so a Restart & Run All reproduces the run.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Put the parameters on the same device that select_action moves states to.
pi = Policy(4, 3).to(device)
optimizer = optim.Adam(pi.parameters(), lr=1e-2)

# NOTE(review): the config cell defines N_EPISODES = 100, but this loop has
# always run 1000 episodes — reconcile the two.
N_TRAIN_EPISODES = 1000
LOG_EVERY = 100  # print a summary every this many episodes
eps = np.finfo(np.float32).eps.item()  # guards the std() division below
episode_rewards = []  # undiscounted total reward per episode, for plotting

for episode in range(N_TRAIN_EPISODES):
    env = CS_DA_deterministic(lim=5)
    done = False
    s = env.state

    # Keep the WHOLE episode. The old deque(maxlen=10) silently dropped early
    # steps, while the separate log_probs list kept them. As a result,
    # log-probs were paired with the wrong returns on episodes longer than
    # 10 steps.
    trajectory = []
    t = 0
    while not done:
        a, log_prob_a = pi.select_action(s)
        s_prime, r, done = env.step(a)
        trajectory.append(Transition(s=s, a=a, s_prime=s_prime, r=r, t=t, log_prob_a=log_prob_a))
        s = s_prime
        t += 1

    rewards = [tr.r for tr in trajectory]
    log_probs = [tr.log_prob_a for tr in trajectory]

    # np.array(...) copies into a fresh contiguous array. This matters if
    # discount_rewards hands back a negative-stride view.
    discounted_rewards = np.array(discount_rewards(rewards, gamma=GAMMA), dtype=np.float32)
    returns = torch.tensor(discounted_rewards)
    # Standardise only when there is a spread to measure. The unbiased std()
    # of a single element is NaN, which would poison every parameter.
    if len(returns) > 1:
        returns = (returns - returns.mean()) / (returns.std() + eps)

    policy_loss = torch.cat([-log_prob * disc_return
                             for log_prob, disc_return in zip(log_probs, returns)]).sum()

    optimizer.zero_grad()
    policy_loss.backward()
    optimizer.step()

    episode_rewards.append(float(np.sum(rewards)))
    if (episode + 1) % LOG_EVERY == 0:
        recent = episode_rewards[-LOG_EVERY:]
        print(f"episode {episode + 1}: mean reward over last {len(recent)} = {np.mean(recent):.3f}")