In [None]:
import os
import random
from datetime import datetime, timedelta

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.ticker import FuncFormatter
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize

# Run from one level above the kernel's starting directory so that `antimpt` and the
# relative `./data/...` paths resolve. The starting directory is remembered on the
# first run, so re-running this cell does not climb one more level each time.
NOTEBOOK_DIR = globals().get("NOTEBOOK_DIR", os.getcwd())
os.chdir(os.path.dirname(NOTEBOOK_DIR))
from antimpt import AntiMPT, TradingFloorEnv
from antimpt.utils.dataset import get_equity

In [None]:
%matplotlib inline

random.seed(42)
np.random.seed(42)

#### Simulate Trades

In [None]:
agent = AntiMPT("PPO", artefact=["PPO", "09", "30_000"])
equities = [get_equity(2021, 2024, "WSM")]

# Closing prices of the traded equity, with the former index turned into a
# "date" column of ISO (YYYY-MM-DD) strings.
equity = (
    equities[0]
    .loc[:, ["close"]]
    .reset_index(names="date")
    .assign(date=lambda frame: frame["date"].dt.strftime("%Y-%m-%d"))
)

In [None]:
# Walk the agent through the whole price history. Each pass builds a fresh
# environment starting at `t0` with the current `funds`, steps it until the
# episode ends, then resumes from where that episode stopped, carrying the final
# portfolio value over as the next episode's starting funds.
t0 = 0
tn = len(equities[0]) - 1
funds = 10_000
histories = [{
    "date": equity["date"].iat[0],
    "action_taken": "HOLD",
    "portfolio": funds,
}]

while t0 < tn:
    vec_env = make_vec_env(
        TradingFloorEnv,
        n_envs=1,
        env_kwargs={
            "equities": equities,
            "funds": funds,
            "t0": t0,
            "is_train": False,
            "render_mode": "human",
        },
        seed=42,
    )

    # NOTE(review): a freshly constructed VecNormalize starts with empty running
    # statistics and keeps updating them (its `training` flag defaults to True), so
    # observations are not normalised the way they were during training. If the
    # artefact ships saved VecNormalize stats, load them with `VecNormalize.load`
    # and set `training=False` instead — TODO confirm.
    norm_vec_env = VecNormalize(vec_env, norm_obs=True, norm_reward=True)

    try:
        obs = norm_vec_env.reset()
        norm_vec_env.render()
        done = False

        while not done:
            action, _ = agent.predict(obs)
            obs, _, dones, infos = norm_vec_env.step(action)
            done = bool(dones[0])
            step_info = infos[0]
            # SB3 VecEnvs reset automatically at episode end, so the state after the
            # final step is already a new episode; it is not rendered.
            if not done:
                norm_vec_env.render()

            histories.append(step_info)
    finally:
        # Release the env (and any render window) before building the next one.
        norm_vec_env.close()

    # `timestep` is added to `t0`, i.e. it is how far this episode advanced. A
    # zero-length episode would otherwise make this loop spin forever.
    if step_info["timestep"] <= 0:
        raise RuntimeError(f"Episode starting at t0={t0} made no progress")
    t0 += step_info["timestep"]
    # The next episode starts with whatever the portfolio was worth at the end of this one.
    funds = step_info["portfolio"]

In [None]:
results = pd.DataFrame(histories)

# Display the steps whose info dict flagged "TimeLimit.truncated". Row 0 is the
# hand-made seed entry, which carries no such key, so it is left out.
after_seed = results.iloc[1:]
after_seed.loc[after_seed["TimeLimit.truncated"]]

In [None]:
# Keep only the date, the action taken and the portfolio value of each step.
results = results.loc[:, ["date", "action_taken", "portfolio"]]

#### Benchmark Market Data

In [None]:
# Risk-free benchmark: start with 10_000 and roll it into a new 13-week T-bill
# (^IRX quote) roughly every 13 weeks, compounding the quarterly return.
bills = pd.read_parquet("./data/raw/bills/^IRX.parquet")
bills.columns = bills.columns.str.lower()
bills.index = bills.index.strftime("%Y-%m-%d")

start, end = equity["date"].iloc[[0, -1]].values
# Roll only on days that have a bill quote AND are trading days of the equity:
# a day without a quote has no rate to look up (KeyError), and a day missing
# from `equity` would be dropped by the left merge on "date".
roll_days = pd.Index(sorted(set(equity["date"]) & set(bills.index)))

dates, rates = [], []
i = roll_days.searchsorted(start)
while i < len(roll_days) and roll_days[i] < end:
    day = roll_days[i]
    dates.append(day)
    # `close` is treated as a percentage (hence /100); presumably an annualised
    # rate, of which a 13-week bill earns about a quarter (hence /4).
    rates.append(1 + bills.at[day, "close"] / 100 / 4)
    # Next roll: the first usable day at least 13 weeks after this one.
    next_day = datetime.strptime(day, "%Y-%m-%d") + timedelta(weeks=13)
    i = roll_days.searchsorted(next_day.strftime("%Y-%m-%d"))

# Value *entering* each roll date: 10_000 on the first, then compounded. The
# trailing factor (the still-open bill's return) is dropped to keep lengths equal.
risk_free = pd.DataFrame({
    "date": dates,
    "risk_free": np.cumprod([10_000] + rates)[:-1],
})

In [None]:
# Buy-and-hold benchmark: invest 10_000 in the XHB ETF on the equity's first
# trading day and hold that position throughout.
stake = 10_000
ref_idx = pd.read_parquet("./data/raw/indices/XHB.parquet")
ref_idx.columns = ref_idx.columns.str.lower()
ref_idx = ref_idx.assign(date=ref_idx.index.strftime("%Y-%m-%d")).reset_index(drop=True)

start = equity["date"].iat[0]
start_quotes = ref_idx.loc[ref_idx["date"] == start, "close"]
if start_quotes.empty:
    raise KeyError(f"XHB has no quote on the start date {start}")
start_close = start_quotes.iat[0]

# Whole shares only; the uninvested remainder is kept as cash so the benchmark
# is worth exactly `stake` on the start date.
position = stake // start_close
cash = stake - position * start_close
ref_idx["index"] = ref_idx["close"] * position + cash

ref_idx = ref_idx[["date", "index"]]

#### Visualise Performance

In [None]:
# One row per equity trading day, with the agent's action/portfolio and both
# benchmarks joined on. Days without an action count as HOLD. The portfolio and
# T-bill values carry forward until their next recorded change.
# NOTE(review): if `results` holds several rows for one date, the left merge
# duplicates that day — TODO confirm the env never repeats a date.
df = (
    equity
    .merge(results, on="date", how="left")
    .merge(risk_free, on="date", how="left")
    .merge(ref_idx, on="date", how="left")
    .assign(
        date=lambda frame: pd.to_datetime(frame["date"]),
        action_taken=lambda frame: frame["action_taken"].fillna("HOLD"),
        portfolio=lambda frame: frame["portfolio"].ffill(),
        risk_free=lambda frame: frame["risk_free"].ffill(),
    )
)

In [None]:
date = df["date"]
action = df["action_taken"]
price = df["close"]
portfolio = df["portfolio"]
index = df["index"]
risk_free = df["risk_free"]

action_long = price.where(action == "LONG", None)
action_short = price.where(action.isin(["SHORT", "CLOSE"]), None)

In [None]:
%matplotlib inline

In [None]:
def thousands_formatter(x, pos):
    return f"{x/1_000:,.0f}"


fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))

# trading actions
ax1.plot(date, price, label="Price", color="#63c5da")
ax1.scatter(
    date, action_long, label="LONG", color="#2e8b57",
    marker="^", alpha=0.5,
)
ax1.scatter(
    date, action_short, label="SHORT", color="#ba110c",
    marker="v", alpha=0.5,
)

ax1.set_title("Trading Actions")
ax1.set_xlabel("Date")
ax1.set_ylabel("USD (in thousands)")

ax1.legend()

# overall returns
ax2.plot(
    date, portfolio, label="AntiMPT",
    color=["#ba110c", "#2e8b57"][bool(portfolio.iat[-1] > index.iat[-1])],
)
ax2.plot(date, index, label="S&P Homebuilders ETF", color="#b8b4b4")
ax2.plot(date, risk_free, label="13-Week Treasury Bill", color="#f0b58d")

ax2.set_title("Return over Time")
ax2.set_xlabel("Date")
ax2.set_ylabel("USD (in thousands)")

ax2.yaxis.set_major_formatter(FuncFormatter(thousands_formatter))

ax2.legend()

plt.tight_layout()
plt.show()