In [8]:
import torch
from torch import nn
import numpy as np
import random
from collections import deque
from itertools import count
import torch.nn.functional as F
from mdp import TradeExecutionEnv, DiscreteTradeSizeWrapper, RelativeTradeSizeWrapper
from tensorboardX import SummaryWriter
from torch.distributions import Categorical

env = TradeExecutionEnv()

SEED = 42
HORIZON = 5 * 12 * 2
UNITS_TO_SELL = 64
BATCH_SIZE = 16
EPOCHS = 300
EPSILON = 0.01
GAMMA = 0.99
ALPHA = 0.9

env = TradeExecutionEnv()

trade_sizes = {
  i: i*2 for i in range(33)
}
env = DiscreteTradeSizeWrapper(env, trade_sizes)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class CategoricalPolicy(nn.Module):
    def __init__(self, num_states, num_actions, hidden_dim=32) -> None:
        super().__init__()
        #self.rnn = nn.LSTM(num_states, hidden_dim, 1, batch_first=True)
        self.fc1 = nn.Linear(num_states, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_actions)
        
    def forward(self, x):
        #x, _ = self.rnn(x)
        x = F.relu(self.fc1(x[:, -1, :]))
        x = F.softmax(self.fc2(x), dim=1)
        return x

    def select_action(self, state):
        probs = self.forward(state)
        m = Categorical(probs)
        action = m.sample()
        return action.item(), m.log_prob(action)

In [9]:
def parse_state(state):
    data = torch.FloatTensor(np.stack([
        state["low"].to_numpy(),
        state["high"].to_numpy(),
        state["close"].to_numpy(),
        state["open"].to_numpy(),
        state["volume"].to_numpy(),
    ])).T
    return torch.concat([
        data,
        torch.repeat_interleave(torch.FloatTensor([[state["units_sold"]]]), 6, 0),
        torch.repeat_interleave(torch.FloatTensor([[state["cost_basis"]]]), 6, 0),
        torch.repeat_interleave(torch.FloatTensor([[state["steps_left"]]]), 6, 0),
    ], dim=1).unsqueeze(0)

In [23]:
def sample_rollouts(policy, env, num_rollouts,seed=SEED):
    trajs = []
    for _ in range(num_rollouts):
        tau = []
        state = env.reset(UNITS_TO_SELL, HORIZON, seed=(seed * num_rollouts) + e)
        state = parse_state(state)
        done = False
        while not done:
            with torch.no_grad():
                action, _ = policy.select_action(state)
            next_state, reward, done, _, _ = env.step(action)
            next_state = parse_state(next_state)
            tau.append((state, action, reward))
            state = next_state
        states, actions, rewards = zip(*tau)
        states = torch.cat(states)
        actions = torch.tensor([actions])
        rewards = torch.tensor([rewards])
        trajs.append((states, actions, rewards))
    return trajs

In [24]:
def R_tau(r):
    gammas = torch.tensor([GAMMA**i for i in range(len(r))])
    return torch.sum(gammas * r)

def avg_batch_rewards(trajs):
    return torch.mean(torch.stack([R_tau(r) for _, _, r in trajs]))

def grad_log_pi(policy, states, actions):
    logits = policy(states)
    log_probs = torch.log(logits)
    log_probs = log_probs.gather(1,actions).T.flatten()
    g = [torch.autograd.grad(log_p, policy.parameters(), retain_graph=True) for log_p in log_probs]
    return g

In [25]:
def grad_log_tau(policy, tau):
  g = grad_log_pi(policy, tau[0], tau[1])
  stacked_g = [torch.stack(g_).view(len(tau[0]), -1) for g_ in zip(*g)]
  return torch.cat(stacked_g, axis=1).sum(axis=0)

def grad_U_tau(policy, tau):
  return grad_log_tau(policy, tau) * R_tau(tau[2])


In [26]:
def surrogate_objective(policy, new_policy, trajs):
    means = []
    for tau in trajs:
      states, actions, rewards = tau
      prob_ratios = new_policy(states).gather(1, actions).T.flatten() / policy(states).gather(1, actions).T.flatten()
      discounted_rewards = torch.FloatTensor([torch.sum(torch.stack([rewards.T[j] * GAMMA ** (j-i) for j in range(i,len(rewards.T))])) for i in range(len(rewards.T))])
      means.append(torch.mean(prob_ratios * discounted_rewards))
    return torch.mean(torch.stack(means))

def surrogate_constraint(policy, new_policy, trajs):
    means = []
    for tau in trajs:
        states, _, rewards = tau
        discounted_rewards = GAMMA ** torch.arange(len(rewards.T))
        policy_probs = policy(states)
        #print(f"Policy probs shape: {policy_probs.shape}")
        new_policy_probs = new_policy(states)
        kl_divs = torch.sum(policy_probs * torch.log(policy_probs / new_policy_probs), axis=1)
        #kl_divs = torch.sum(torch.kl_div(torch.log(policy_probs), new_policy_probs), axis=1)
        means.append(torch.mean(kl_divs * discounted_rewards))
    return torch.mean(torch.stack(means))

def flatten_params(policy):
    return torch.cat([p.view(-1) for p in policy.parameters()])

def linesearch(policy, new_policy, trajs):
    f_theta = surrogate_objective(policy, policy, trajs)
    while surrogate_constraint(policy, new_policy, trajs) > EPSILON or surrogate_objective(policy, new_policy, trajs) <= f_theta:
        theta = flatten_params(policy)
        theta_new = flatten_params(new_policy)
        theta_new = theta + ALPHA * (theta_new - theta)
        n = 0
        for _, p in enumerate(new_policy.parameters()):
            num_elements = p.numel()
            p.data = theta_new[n:n+num_elements].view(p.shape)
            n += num_elements
    return new_policy

In [27]:
policy = CategoricalPolicy(len(env.observation_space), len(trade_sizes)).to(device)

train_rewards = []
for e in range(EPOCHS):
  trajs = sample_rollouts(policy, env, BATCH_SIZE,seed=e)
  g_u_tau = [grad_U_tau(policy, tau) for tau in trajs]
  Fish = torch.stack([gut.unsqueeze(1) @ gut.unsqueeze(0) for gut in g_u_tau]).mean(axis=0)
  g_u = torch.stack(g_u_tau).mean(axis=0)
  u = torch.linalg.pinv(Fish) @ g_u.unsqueeze(0).T
  g = u * torch.sqrt(2 * EPSILON / (g_u @ u))
  #print(f"G: {g}")
  n = 0
  new_policy = CategoricalPolicy(len(env.observation_space), len(trade_sizes)).to(device)
  new_policy.load_state_dict(policy.state_dict().copy())
  for i, p in enumerate(new_policy.parameters()):
    num_elements = p.numel()
    p.data += g[n:n+num_elements].view(p.shape)
    n += num_elements
  policy = linesearch(policy, new_policy, trajs)
  policy.zero_grad()
  train_rewards.append(avg_batch_rewards(trajs))
  if e % 10 == 0:
    print(f"iter {e}: {avg_batch_rewards(trajs)}")

iter 0: -10.0
iter 10: -9.393252309956434
iter 20: -8.86914494946309
iter 30: -9.437237867071891
iter 40: -9.399950029675107
iter 50: -10.0
iter 60: -8.767210509025471
iter 70: -10.0
iter 80: -8.291639638446076
iter 90: -9.414105959690371
iter 100: -9.379112390318575
iter 110: -10.0
iter 120: -8.815065507185732
iter 130: -9.429459304404334
iter 140: -9.387194880073512
iter 150: -9.434178384303358
iter 160: -9.434908836327125
iter 170: -9.398376568917149
iter 180: -10.0
iter 190: -9.42385530719504
iter 200: -8.871562426470588
iter 210: -9.383873889711387
iter 220: -10.0
iter 230: -8.301330695312501
iter 240: -10.0
iter 250: -8.775769866890482
iter 260: -9.381359355316867
iter 270: -8.806608376338605
iter 280: -9.41252913751836
iter 290: -8.873009903055152


In [28]:
EVAL_EPOCHS = 50

same_mdp_eval_rewards = []
for e in range(EVAL_EPOCHS):
    state = env.reset(UNITS_TO_SELL, HORIZON, seed=SEED)
    state = parse_state(state)
    done = False
    episode_rewards = 0
    while not done:
        with torch.no_grad():
            action = torch.argmax(policy(state)).item()
        next_state, reward, done, _, _ = env.step(action)
        next_state = parse_state(next_state)
        state = next_state
        episode_rewards += reward
    same_mdp_eval_rewards.append(episode_rewards)
print(f"Avg same mdp eval reward: {np.mean(same_mdp_eval_rewards)}")

Avg same mdp eval reward: -10.0


In [29]:
train_mdp_eval_rewards = []
for e in range(EVAL_EPOCHS):
    state = env.reset(UNITS_TO_SELL, HORIZON, seed=(BATCH_SIZE * EPOCHS)+e)
    state = parse_state(state)
    done = False
    episode_rewards = 0
    while not done:
        with torch.no_grad():
            action = torch.argmax(policy(state)).item()
        next_state, reward, done, _, _ = env.step(action)
        next_state = parse_state(next_state)
        state = next_state
        episode_rewards += reward
    train_mdp_eval_rewards.append(episode_rewards)
print(f"Avg train mdp eval reward: {np.mean(train_mdp_eval_rewards)}")

Avg train mdp eval reward: -10.0


In [30]:
test_mdp_eval_rewards = []
for e in range(EVAL_EPOCHS):
    state = env.reset(UNITS_TO_SELL, HORIZON, seed=SEED+e, test=True)
    state = parse_state(state)
    done = False
    episode_rewards = 0
    while not done:
        with torch.no_grad():
            action = torch.argmax(policy(state)).item()
        next_state, reward, done, _, _ = env.step(action)
        next_state = parse_state(next_state)
        state = next_state
        episode_rewards += reward
    test_mdp_eval_rewards.append(episode_rewards)
print(f"Avg test mdp eval reward: {np.mean(test_mdp_eval_rewards)}")

Avg test mdp eval reward: -10.0


In [33]:
import json

EXP_NAME = "trpo_multi_mdp_train.json"
with open(f"./results/{EXP_NAME}", "w+") as f:
    json.dump({
        "train_rewards": train_rewards,
        "same_mdp_eval_rewards": same_mdp_eval_rewards,
        "train_mdp_eval_rewards": train_mdp_eval_rewards,
        "test_mdp_eval_rewards": test_mdp_eval_rewards
    }, f)

In [32]:
train_rewards = [n.item() for n in train_rewards]