In [10]:
import torch
from torch import nn
import numpy as np
import random
from collections import deque
from itertools import count
import torch.nn.functional as F
from mdp import TradeExecutionEnv, DiscreteTradeSizeWrapper, RelativeTradeSizeWrapper
from tensorboardX import SummaryWriter
from torch.distributions import Categorical


SEED = 42
HORIZON = 5 * 12 * 2
UNITS_TO_SELL = 64
BATCH_SIZE = 32
EPOCHS = 300
TASK_BATCH_SIZE = 12
EPSILON = 0.1
GAMMA = 0.99

env = TradeExecutionEnv()

trade_sizes = {
  i: i*2 for i in range(33)
}
env = DiscreteTradeSizeWrapper(env, trade_sizes)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class CategoricalPolicy(nn.Module):
    def __init__(self, num_states, num_actions, hidden_dim=32) -> None:
        super().__init__()
        #self.rnn = nn.LSTM(num_states, hidden_dim, 1, batch_first=True)
        self.fc1 = nn.Linear(num_states, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_actions)
        
    def forward(self, x):
        #x, _ = self.rnn(x)
        x = F.relu(self.fc1(x[:, -1, :]))
        x = F.softmax(self.fc2(x), dim=1)
        return x

    def select_action(self, state):
        probs = self.forward(state)
        m = Categorical(probs)
        action = m.sample()
        return action.item(), m.log_prob(action)

In [11]:
def parse_state(state):
    data = torch.FloatTensor(np.stack([
        state["low"].to_numpy(),
        state["high"].to_numpy(),
        state["close"].to_numpy(),
        state["open"].to_numpy(),
        state["volume"].to_numpy(),
    ])).T
    return torch.concat([
        data,
        torch.repeat_interleave(torch.FloatTensor([[state["units_sold"]]]), 6, 0),
        torch.repeat_interleave(torch.FloatTensor([[state["cost_basis"]]]), 6, 0),
        torch.repeat_interleave(torch.FloatTensor([[state["steps_left"]]]), 6, 0),
    ], dim=1).unsqueeze(0)

In [12]:
def sample_rollouts(policy, env, num_rollouts, seed, test=False):
    trajs = []
    for _ in range(num_rollouts):
        tau = []
        state = env.reset(UNITS_TO_SELL, HORIZON, seed=seed, test=test)
        state = parse_state(state)
        done = False
        while not done:
            with torch.no_grad():
                action, _ = policy.select_action(state)
            next_state, reward, done, _, _ = env.step(action)
            next_state = parse_state(next_state)
            tau.append((state, action, reward))
            state = next_state
        states, actions, rewards = zip(*tau)
        states = torch.cat(states)
        actions = torch.tensor([actions])
        rewards = torch.tensor([rewards])
        trajs.append((states, actions, rewards))
    return trajs

In [13]:
def R_tau(r, gamma):
    gammas = torch.tensor([gamma**i for i in range(len(r))])
    return torch.sum(gammas * r)

def avg_batch_rewards(trajs, gamma):
    return torch.mean(torch.stack([R_tau(r, gamma) for _, _, r in trajs]))

def grad_log_pi(policy, states, actions):
    logits = policy(states)
    log_probs = torch.log(logits)
    log_probs = log_probs.gather(1,actions).T.flatten()
    g = [torch.autograd.grad(log_p, policy.parameters(), retain_graph=True) for log_p in log_probs]
    return g

In [14]:
def grad_log_tau(policy, tau):
  g = grad_log_pi(policy, tau[0], tau[1])
  stacked_g = [torch.stack(g_).view(len(tau[0]), -1) for g_ in zip(*g)]
  return torch.cat(stacked_g, axis=1).sum(axis=0)

def grad_U_tau(policy, tau, gamma):
  return grad_log_tau(policy, tau) * R_tau(tau[2], gamma)

def flatten_params(policy):
    return torch.cat([p.view(-1) for p in policy.parameters()])

In [15]:
policy = CategoricalPolicy(len(env.observation_space), len(trade_sizes)).to(device)

train_rewards = []
for e in range(EPOCHS):
  meta_g = torch.zeros_like(flatten_params(policy))
  task_rewards = []
  for i in range(TASK_BATCH_SIZE):
    trajs = sample_rollouts(policy, env, BATCH_SIZE, seed=SEED+i)
    u = torch.stack([grad_U_tau(policy, tau, GAMMA) for tau in trajs]).mean(axis=0)
    g = u * torch.sqrt(2 * EPSILON / (u @ u))
    n = 0
    new_policy = CategoricalPolicy(len(env.observation_space), len(trade_sizes)).to(device)
    new_policy.load_state_dict(policy.state_dict().copy())
    for i, p in enumerate(new_policy.parameters()):
      num_elements = p.numel()
      p.data += g[n:n+num_elements].view(p.shape)
      n += num_elements
    new_task_trajs = sample_rollouts(new_policy, env, BATCH_SIZE, seed=SEED+i)
    u = torch.stack([grad_U_tau(new_policy, tau, GAMMA) for tau in new_task_trajs]).mean(axis=0)
    meta_g += u * torch.sqrt(2 * EPSILON / (u @ u))
    policy.zero_grad()
    task_rewards.append(avg_batch_rewards(trajs, GAMMA))

  n = 0
  meta_g /= TASK_BATCH_SIZE
  for i, p in enumerate(policy.parameters()):
      num_elements = p.numel()
      p.data += g[n:n+num_elements].view(p.shape)
      n += num_elements
  train_rewards.append(np.mean(task_rewards))
  print(f"iter {e}: {np.mean(task_rewards)}")

iter 0: -8.880998514508157
iter 1: -9.226531727473407
iter 2: -9.13046766528586
iter 3: -9.261844773919554
iter 4: -9.188043071530375
iter 5: -9.254966896620909
iter 6: -9.331928828489351
iter 7: -9.050284972335803
iter 8: -9.18257610640031
iter 9: -9.06741719637562
iter 10: -9.405331208400106
iter 11: -9.193426984781903
iter 12: -9.267656134341593
iter 13: -9.237012758763013
iter 14: -8.959855212670048
iter 15: -9.432641179237512
iter 16: -9.382913497791344
iter 17: -9.182923686169902
iter 18: -9.387646518691456
iter 19: -9.335296275257376
iter 20: -9.368544426075651
iter 21: -8.935014013022231
iter 22: -8.755760646021107
iter 23: -8.922913066498591
iter 24: -8.913571418431859
iter 25: -9.147146695871465
iter 26: -8.482045236630514
iter 27: -9.408519642861437
iter 28: -7.7577874050396005
iter 29: -8.882150724565264
iter 30: -7.800485136861173
iter 31: -9.256521008525278
iter 32: -8.092935946438725
iter 33: -9.022752977957106
iter 34: -7.8680123955167005
iter 35: -9.205310295471458
ite

KeyboardInterrupt: 

In [23]:
EVAL_EPOCHS = 50
K = 2

same_mdp_eval_rewards = []
for e in range(EVAL_EPOCHS):
    policy.load_state_dict(original_params)
    k_shot(policy, K, seed=SEED)
    state = env.reset(UNITS_TO_SELL, HORIZON, seed=SEED)
    state = parse_state(state)
    done = False
    episode_rewards = 0
    while not done:
        with torch.no_grad():
            action = torch.argmax(policy(state)).item()
        next_state, reward, done, _, _ = env.step(action)
        next_state = parse_state(next_state)
        state = next_state
        episode_rewards += reward
    same_mdp_eval_rewards.append(episode_rewards)
print(f"Avg same mdp eval reward: {np.mean(same_mdp_eval_rewards)}")

TypeError: sample_rollouts() got an unexpected keyword argument 'test'

In [17]:
train_mdp_eval_rewards = []
for e in range(EVAL_EPOCHS):
    policy.load_state_dict(original_params)
    k_shot(policy, K, seed=(BATCH_SIZE * EPOCHS)+e)
    state = env.reset(UNITS_TO_SELL, HORIZON, seed=(BATCH_SIZE * EPOCHS)+e)
    state = parse_state(state)
    done = False
    episode_rewards = 0
    while not done:
        with torch.no_grad():
            action = torch.argmax(policy(state)).item()
        next_state, reward, done, _, _ = env.step(action)
        next_state = parse_state(next_state)
        state = next_state
        episode_rewards += reward
    train_mdp_eval_rewards.append(episode_rewards)
print(f"Avg train mdp eval reward: {np.mean(train_mdp_eval_rewards)}")

Avg train mdp eval reward: -10.0


In [18]:
test_mdp_eval_rewards = []
for e in range(EVAL_EPOCHS):
    policy.load_state_dict(original_params)
    k_shot(env, policy, K, seed=SEED+e, test=True)
    state = env.reset(UNITS_TO_SELL, HORIZON, seed=SEED+e, test=True)
    state = parse_state(state)
    done = False
    episode_rewards = 0
    while not done:
        with torch.no_grad():
            action = torch.argmax(policy(state)).item()
        next_state, reward, done, _, _ = env.step(action)
        next_state = parse_state(next_state)
        state = next_state
        episode_rewards += reward
    test_mdp_eval_rewards.append(episode_rewards)
print(f"Avg test mdp eval reward: {np.mean(test_mdp_eval_rewards)}")

Avg test mdp eval reward: -10.0


In [19]:
import json

EXP_NAME = "maml_restricted_policy_update_fine_tune_mdp_train.json"
with open(f"./results/{EXP_NAME}", "w+") as f:
    json.dump({
        "train_rewards": train_rewards,
        "same_mdp_eval_rewards": same_mdp_eval_rewards,
        "train_mdp_eval_rewards": train_mdp_eval_rewards,
        "test_mdp_eval_rewards": test_mdp_eval_rewards
    }, f)

In [20]:
original_params = policy.state_dict().copy()

In [7]:
tau = []
state = env.reset(UNITS_TO_SELL, HORIZON, seed=SEED-4)
state = parse_state(state)
done = False
while not done:
    with torch.no_grad():
        probs = policy(state)
        action = torch.argmax(probs).item()
    next_state, reward, done, _, _ = env.step(action)
    next_state = parse_state(next_state)
    tau.append((state, action, reward))
    state = next_state
states, actions, rewards = zip(*tau)
states = torch.cat(states)
actions = torch.tensor([actions])
rewards = torch.tensor([rewards])
print(rewards)

tensor([[ 0.0000, -0.1004]], dtype=torch.float64)


In [21]:
def k_shot(policy, n_steps, seed, test=False):
  for e in range(n_steps):
    trajs = sample_rollouts(policy, env, 2, seed=seed, test=test)
    u = torch.stack([grad_U_tau(policy, tau, GAMMA) for tau in trajs]).mean(axis=0)
    g = u * torch.sqrt(2 * EPSILON / (u @ u))
    n = 0
    for i, p in enumerate(policy.parameters()):
      num_elements = p.numel()
      p.data += g[n:n+num_elements].view(p.shape)
      n += num_elements
    policy.zero_grad()

In [24]:
tau = []
state = env.reset(UNITS_TO_SELL, HORIZON, seed=SEED-4)
state = parse_state(state)
done = False
while not done:
    with torch.no_grad():
        probs = policy(state)
        action = torch.argmax(probs).item()
    next_state, reward, done, _, _ = env.step(action)
    next_state = parse_state(next_state)
    tau.append((state, action, reward))
    state = next_state
states, actions, rewards = zip(*tau)
states = torch.cat(states)
actions = torch.tensor([actions])
rewards = torch.tensor([rewards])
print(rewards)

tensor([[ 0.0000, -0.1004]], dtype=torch.float64)
