In [32]:
import torch
from torch import nn
import numpy as np
import random
from collections import deque
from itertools import count
import torch.nn.functional as F
from mdp import TradeExecutionEnv, DiscreteTradeSizeWrapper
from tensorboardX import SummaryWriter
from torch.distributions import Categorical


SEED = 42
#HORIZON = 5 * 12 * 8
HORIZON = 100
UNITS_TO_SELL = 24
BATCH_SIZE = 32
EPOCHS = 100
EPSILON = 0.001
GAMMA = 0.99

env = TradeExecutionEnv()

trade_sizes = {
  0: 0,
  1: 1,
  2: 2,
  3: 4,
  4: 8,
  #5: 16,
  #6: 32,
  #7: 64,
  #8: 128,
  #9: 250
}
env = DiscreteTradeSizeWrapper(env, trade_sizes)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class CategoricalPolicy(nn.Module):
    def __init__(self, num_states, num_actions, hidden_dim=32) -> None:
        super().__init__()
        self.fc1 = nn.Linear(num_states, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_actions)
        
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.softmax(self.fc2(x), dim=1)
        return x

    def select_action(self, state):
        probs = self.forward(state)
        m = Categorical(probs)
        action = m.sample()
        return action.item(), m.log_prob(action)

In [33]:
def parse_state(state):
    data = torch.FloatTensor(np.stack([
        state["low"].to_numpy(),
        state["high"].to_numpy(),
        state["close"].to_numpy(),
        state["open"].to_numpy(),
        state["volume"].to_numpy(),
    ])).T
    return torch.concat([
        data,
        torch.repeat_interleave(torch.FloatTensor([[state["units_sold"]]]), 6, 0),
        torch.repeat_interleave(torch.FloatTensor([[state["cost_basis"]]]), 6, 0),
        torch.repeat_interleave(torch.FloatTensor([[state["steps_left"]]]), 6, 0),
    ], dim=1)[-1,:].unsqueeze(0)

In [34]:
def sample_rollouts(policy, env, num_rollouts):
    trajs = []
    for _ in range(num_rollouts):
        tau = []
        state = env.reset(UNITS_TO_SELL, HORIZON, seed=SEED)
        state = parse_state(state)
        done = False
        while not done:
            with torch.no_grad():
                action, _ = policy.select_action(state)
            next_state, reward, done, _, _ = env.step(action)
            next_state = parse_state(next_state)
            tau.append((state, action, reward))
            state = next_state
        states, actions, rewards = zip(*tau)
        states = torch.cat(states)
        actions = torch.tensor([actions])
        rewards = torch.tensor([rewards])
        trajs.append((states, actions, rewards))
    return trajs

In [35]:
def R_tau(r, gamma):
    gammas = torch.tensor([gamma**i for i in range(len(r))])
    return torch.sum(gammas * r)

def avg_batch_rewards(trajs, gamma):
    return torch.mean(torch.stack([R_tau(r, gamma) for _, _, r in trajs]))

def grad_log_pi(policy, states, actions):
    logits = policy(states)
    log_probs = torch.log(logits)
    log_probs = log_probs.gather(1,actions).T.flatten()
    g = [torch.autograd.grad(log_p, policy.parameters(), retain_graph=True) for log_p in log_probs]
    return g

In [36]:
def grad_log_tau(policy, tau):
  g = grad_log_pi(policy, tau[0], tau[1])
  stacked_g = [torch.stack(g_).view(len(tau[0]), -1) for g_ in zip(*g)]
  return torch.cat(stacked_g, axis=1).sum(axis=0)

def grad_U_tau(policy, tau, gamma):
  return grad_log_tau(policy, tau) * R_tau(tau[2], gamma)


In [37]:
policy = CategoricalPolicy(8, 5).to(device)

for e in range(EPOCHS):
  trajs = sample_rollouts(policy, env, BATCH_SIZE)
  u = torch.stack([grad_U_tau(policy, tau, GAMMA) for tau in trajs]).mean(axis=0)
  #print(f"U = {u}")
  #print(f"U @ U = {u @ u}")
  g = u * torch.sqrt(2 * EPSILON / (u @ u))
  #print(f"G: {g}")
  n = 0
  for i, p in enumerate(policy.parameters()):
    num_elements = p.numel()
    p.data += g[n:n+num_elements].view(p.shape)
    n += num_elements
  policy.zero_grad()
  print(f"iter {e}: {avg_batch_rewards(trajs, GAMMA)}")

iter 0: -6.044705696718247
iter 1: -7.855036541826589
iter 2: -7.266424895113815
iter 3: -7.859795072709465
iter 4: -7.260136173475123
iter 5: -6.3448365685156665
iter 6: -6.941219994004985
iter 7: -6.940241104750553
iter 8: -7.8512344232771465
iter 9: -5.726122043189353
iter 10: -6.9301121563146255
iter 11: -5.109279480143882
iter 12: -6.631501406429125
iter 13: -7.548902714738112
iter 14: -6.319173220024987
iter 15: -6.933041302710257
iter 16: -6.026817134331535
iter 17: -4.499993044411918
iter 18: -4.182078627727573
iter 19: -5.713329484414986
iter 20: -6.321930144243808
iter 21: -7.236390018237644
iter 22: -8.150292220119711
iter 23: -5.38583711227118
iter 24: -4.171136444879478
iter 25: -5.389825265924445
iter 26: -6.32344085072514
iter 27: -3.8435543330916007
iter 28: -4.769265510668705
iter 29: -5.082269309270839
iter 30: -4.770002300466574
iter 31: -4.461360121992577
iter 32: -2.0005168892285
iter 33: -4.769121414626076
iter 34: -4.774995418246892
iter 35: -4.1555383075521775
i