# TRADING ENV 

**Goal:** Use `data/prices_returns.csv` to simulate trading with portfolio weights.

## What to build (minimum)
1. **Load data**
   - Read `data/prices_returns.csv`.
   - Pivot to a matrix `R` of shape `[T, A]` (rows=dates, cols=assets).
   - Keep `dates` and `assets` lists.

2. **Environment class**
   - `__init__(R, dates, assets, window=20, cost_bps=10.0, include_cash=True)`
   - `reset(seed=0)` → set `t=window`, start with equal weights, and return the observation `R[t-window:t]` (the `window` rows before day `t`), shape `(window, A)`.
   - `step(action)` →
     - Softmax the action → weights that sum to 1 (long-only).
     - Compute reward = `w_t · R[t] − (cost_bps/1e4)*0.5*sum(|w_t − w_{t-1}|)`: today's portfolio return minus a cost charged on one-way turnover.
     - Advance one day, return `(next_obs, reward, done, info)`.

3. **Testing helpers**
   - `equal_weight_policy(n_assets)` and `random_policy(n_assets, seed)`.
   - `rollout(env, policy_fn, out_csv="data/returns_EQW.csv")` to write daily portfolio returns.
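A minimal sketch of the loader and environment described above, assuming the CSV is in long format with columns `date`, `asset`, and `ret` (the real column names of `data/prices_returns.csv` are not given here). `load_returns_long` and `TradingEnvSketch` are hypothetical names; the sketch omits `include_cash`, and it replaces NaN returns with 0.0 in observations only (the first row of returns is NaN because there is no prior price). The reward uses `R[t]`, the row just after the observation window, so the policy never sees the return it is paid on.

```python
import numpy as np
import pandas as pd


def load_returns_long(df: pd.DataFrame):
    """Pivot long rows (assumed columns: date, asset, ret) into R of shape [T, A]."""
    wide = df.pivot(index="date", columns="asset", values="ret").sort_index()
    return wide.to_numpy(dtype=float), list(wide.index), list(wide.columns)


def softmax(x: np.ndarray) -> np.ndarray:
    z = np.exp(x - np.max(x))  # shift by the max for numerical stability
    return z / z.sum()


class TradingEnvSketch:
    """Long-only env: obs = last `window` return rows, reward = net daily return."""

    def __init__(self, R, dates, assets, window=20, cost_bps=10.0):
        self.R = np.asarray(R, dtype=float)
        self.dates, self.assets = list(dates), list(assets)
        self.window, self.cost_bps = window, cost_bps
        self.T, self.A = self.R.shape
        # Assumption: NaN returns (e.g. the first row) appear as 0.0 in observations
        self._obs_R = np.nan_to_num(self.R, nan=0.0)
        self.reset()

    def _obs(self):
        # The `window` rows before day t, shape (window, A)
        return self._obs_R[self.t - self.window:self.t]

    def reset(self, seed=0):
        # `seed` is accepted to match the spec; this deterministic sketch ignores it
        self.t = self.window
        self.w = np.full(self.A, 1.0 / self.A)  # start at equal weights
        return self._obs()

    def step(self, action):
        w_new = softmax(np.asarray(action, dtype=float))  # long-only, sums to 1
        turnover = 0.5 * np.abs(w_new - self.w).sum()     # one-way turnover
        cost = (self.cost_bps / 1e4) * turnover
        reward = float(w_new @ self.R[self.t] - cost)     # paid on day t's returns
        info = {"date": self.dates[self.t], "weights": w_new, "cost": cost}
        self.w, self.t = w_new, self.t + 1
        done = self.t >= self.T
        # At the end, _obs() still returns the final full window, so the shape holds
        return self._obs(), reward, done, info
```

With the sample matrix printed in the test output below, `window=2`, `cost_bps=10`, and the all-ones action, this sketch yields rewards of -0.0024, 0.0038, and 0.001375, the same values as the equal-weight CSV. The first step pays no cost because softmax of a constant vector is exactly the equal-weight starting point.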

## Checks
- Observation shape is always `(window, A)`.
- After softmax: weights sum to 1 and are non-negative.
- No NaNs in observations or rewards.
- The equal-weight rollout produces one finite return per step, dated `dates[window:]` in order with no gaps or repeats.
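The checks above can be written as small validators. This is a hypothetical helper set (`check_observation`, `check_weights`, and `check_return_series` are not part of the package), and it reads "dates contiguous" as: the output dates equal `dates[window:]`, in order, with none missing or repeated. Note that `check_return_series` would reject a CSV read back with an extra `Unnamed: 0` index column.

```python
import numpy as np
import pandas as pd


def check_observation(obs, window, n_assets):
    obs = np.asarray(obs, dtype=float)
    if obs.shape != (window, n_assets):
        raise ValueError(f"observation shape {obs.shape}, expected {(window, n_assets)}")
    if np.isnan(obs).any():
        raise ValueError("NaN in observation")


def check_weights(w, tol=1e-9):
    w = np.asarray(w, dtype=float)
    if (w < 0).any():
        raise ValueError("negative weight: portfolio is not long-only")
    if abs(w.sum() - 1.0) > tol:
        raise ValueError(f"weights sum to {w.sum()}, not 1")


def check_return_series(df, expected_dates):
    """Columns are exactly date, ret; returns are finite; dates match expected_dates in order."""
    if list(df.columns) != ["date", "ret"]:
        raise ValueError(f"unexpected columns {list(df.columns)}")
    if not np.isfinite(df["ret"].to_numpy(dtype=float)).all():
        raise ValueError("NaN or infinite return")
    got = list(pd.to_datetime(df["date"]))
    want = list(pd.to_datetime(pd.Series(expected_dates)))
    if got != want:
        raise ValueError("dates are missing, repeated, or out of order")
```

Raising `ValueError` rather than using bare `assert` keeps the checks active even when Python runs with `-O`.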

## Done when
- `TradingEnv` runs from start to finish with the equal-weight policy.
- `data/returns_EQW.csv` is created with columns `date, ret`.
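A sketch of the testing helpers and `rollout`, assuming the environment exposes `A` (the test cell uses `env.A`) and that `step` returns the trading date as `info["date"]` (an assumption). `random_policy` here draws from a legacy `np.random.RandomState`, which reproduces the raw seed-0 action printed in the sample output; the package's actual implementation may differ. The sample CSVs carry an extra `Unnamed: 0` column, which is what pandas produces when `to_csv` writes the index and the file is read back with `read_csv`; passing `index=False` leaves exactly `date, ret`.

```python
import numpy as np
import pandas as pd


def equal_weight_policy(n_assets):
    # Constant logits; softmax maps them to 1/n_assets each
    return np.ones(n_assets)


def random_policy(n_assets, seed):
    # Legacy generator seeded per call: the same seed always yields the same logits
    return np.random.RandomState(seed).standard_normal(n_assets)


def rollout(env, policy_fn, out_csv="data/returns_EQW.csv"):
    """Run one episode and write one row per step with columns date, ret."""
    env.reset()
    rows, done = [], False
    while not done:
        _, reward, done, info = env.step(policy_fn(env.A))
        rows.append({"date": info["date"], "ret": reward})  # assumes info["date"]
    df = pd.DataFrame(rows, columns=["date", "ret"])
    df.to_csv(out_csv, index=False)  # index=False: no "Unnamed: 0" column on re-read
    return df
```

Because `random_policy` is deterministic in `seed`, wrapping it as `lambda n: random_policy(n, seed=42)` gives the same action at every step; a varying seed (for example, a step counter) is needed for a fresh draw each day.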


```python
# ============================
# FULL TEST OF TRADING ENV
# ============================

import numpy as np
import pandas as pd
from IPython.display import display  # built in under Jupyter; explicit import for scripts

from trading_env.load_data import load_returns
from trading_env.trading_env import (
    TradingEnv,
    equal_weight_policy,
    random_policy,
    rollout,
)

print("=== Loading Test Data ===")
R, dates, assets = load_returns("test_data/sample_prices_returns.csv")
print("R shape:", R.shape)
print("Assets:", assets)
print("Dates:", dates)
print("\nR Matrix:\n", R)

# ----------------------------
# Create environment (window=2)
# ----------------------------
print("\n=== Creating Environment ===")
env = TradingEnv(R, dates, assets, window=2, cost_bps=10)

# ----------------------------
# Test equal-weight policy
# ----------------------------
print("\n=== Testing equal_weight_policy ===")
eq_action = equal_weight_policy(env.A)
print("Raw equal-weight action:", eq_action)
# Softmax of a constant vector is uniform: 1/A per asset
print("Softmax result:", np.exp(eq_action) / np.sum(np.exp(eq_action)))

# ----------------------------
# Test random policy
# ----------------------------
print("\n=== Testing random_policy ===")
rand_action = random_policy(env.A, seed=0)
print("Raw random action:", rand_action)
print("Softmax result:", np.exp(rand_action) / np.sum(np.exp(rand_action)))

# ----------------------------
# Test rollout with equal weight
# ----------------------------
print("\n=== Running rollout (equal weight) ===")
rollout(env, lambda n: equal_weight_policy(n), out_csv="test_data/returns_EQW.csv")

df_eqw = pd.read_csv("test_data/returns_EQW.csv")
print("\nEqual-weight returns CSV:")
display(df_eqw)

# ----------------------------
# Test rollout with random policy
# ----------------------------
print("\n=== Running rollout (random) ===")
# The seed is fixed inside the lambda, so every step gets the same action:
# this rollout holds one fixed random allocation rather than redrawing daily.
rollout(env, lambda n: random_policy(n, seed=42), out_csv="test_data/returns_RANDOM.csv")

df_rand = pd.read_csv("test_data/returns_RANDOM.csv")
print("\nRandom policy returns CSV:")
display(df_rand)

print("\n=== ALL TESTS COMPLETE ===")
```

```text
=== Loading Test Data ===
R shape: (5, 4)
Assets: ['GLD', 'QQQ', 'SPY', 'TLT']
Dates: [Timestamp('2020-01-01 00:00:00'), Timestamp('2020-01-02 00:00:00'), Timestamp('2020-01-03 00:00:00'), Timestamp('2020-01-04 00:00:00'), Timestamp('2020-01-05 00:00:00')]

R Matrix:
 [[    nan     nan     nan     nan]
 [ 0.0014  0.0076  0.0036  0.0029]
 [-0.0007 -0.0033 -0.002  -0.0036]
 [ 0.0027  0.0062  0.0049  0.0014]
 [-0.0014 -0.0009  0.0028  0.005 ]]

=== Creating Environment ===

=== Testing equal_weight_policy ===
Raw equal-weight action: [1. 1. 1. 1.]
Softmax result: [0.25 0.25 0.25 0.25]

=== Testing random_policy ===
Raw random action: [1.76405235 0.40015721 0.97873798 2.2408932 ]
Softmax result: [0.30096764 0.07694629 0.13723412 0.48485195]

=== Running rollout (equal weight) ===
Rollout complete. Saved to test_data/returns_EQW.csv

Equal-weight returns CSV:
```

| Unnamed: 0 | date | ret |
| --- | --- | --- |
| 0 | 2020-01-03 00:00:00 | -0.0024 |
| 1 | 2020-01-04 00:00:00 | 0.0038 |
| 2 | 2020-01-05 00:00:00 | 0.001375 |

```text
=== Running rollout (random) ===
Rollout complete. Saved to test_data/returns_RANDOM.csv

Random policy returns CSV:
```

| Unnamed: 0 | date | ret |
| --- | --- | --- |
| 0 | 2020-01-03 00:00:00 | -0.002962 |
| 1 | 2020-01-04 00:00:00 | 0.002843 |
| 2 | 2020-01-05 00:00:00 | 0.002796 |

```text
=== ALL TESTS COMPLETE ===
```
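The random-policy returns can be reproduced by hand, which also shows where the transaction cost enters. This worked check assumes `random_policy(n, seed)` draws `np.random.RandomState(seed).standard_normal(n)`, which is consistent with the raw seed-0 action printed above. Only the first step pays a cost: the seed is fixed, so after moving away from the equal-weight start the weights never change.

```python
import numpy as np


def softmax(x):
    z = np.exp(x - x.max())  # max-shift for numerical stability
    return z / z.sum()


# Rows of R for 2020-01-03 .. 2020-01-05, copied from the printed R matrix
R_days = np.array([
    [-0.0007, -0.0033, -0.0020, -0.0036],
    [0.0027, 0.0062, 0.0049, 0.0014],
    [-0.0014, -0.0009, 0.0028, 0.0050],
])
cost_rate = 10.0 / 1e4  # cost_bps=10

# Assumption: random_policy(n, seed) == RandomState(seed).standard_normal(n)
w = softmax(np.random.RandomState(42).standard_normal(4))
w_prev = np.full(4, 0.25)  # reset() starts at equal weights

rewards = []
for r in R_days:
    cost = cost_rate * 0.5 * np.abs(w - w_prev).sum()  # one-way turnover cost
    rewards.append(float(w @ r - cost))
    w_prev = w  # same seed on every call, so the weights never change

print(np.round(rewards, 6))
```

The printed values round to -0.002962, 0.002843, and 0.002796, matching the random-policy returns above. The first-day cost is about 0.000259, i.e. a one-way turnover of roughly 0.259 charged at 10 bps.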
