In [1]:
# --- Imports ---
import os
from pathlib import Path
import pandas as pd
import torch
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecMonitor
from env import TradingEnv
from extractor import CNNPolicy 

# --- Load and Filter Meta ---
# Absolute path of the metadata CSV; its parent directory is what the
# environments receive as ``root_dir``.
meta_path = Path("../dataset/meta.csv").resolve()
meta_root = meta_path.parent.resolve()

# Read the metadata, keep only the 5-minute image/label columns alongside
# id / timestamp / close, and shorten the 5-minute column names.
meta_df = (
    pd.read_csv(meta_path, parse_dates=["timestamp"])
    .loc[:, ["id", "timestamp", "close", "5m_img", "5m_lbl"]]
    .rename(columns={"5m_img": "img", "5m_lbl": "lbl"})
)

# --- Train/Test split ---
# Positional split: the leading 80% of rows train, the rest test.
# NOTE(review): rows are not sorted here, so this is a chronological split
# only if meta.csv is already ordered by timestamp — confirm.
split_ratio = 0.8
split_idx = int(split_ratio * len(meta_df))
train_df, test_df = (
    part.reset_index(drop=True)
    for part in (meta_df.iloc[:split_idx], meta_df.iloc[split_idx:])
)

print(f"[INFO] Training on {len(train_df)} samples, testing on {len(test_df)}")

# --- Env Builders ---
def _env_factory(get_df):
    """Build a zero-argument ``TradingEnv`` constructor for one data split.

    Both the training and the test environment use the same configuration and
    differ only in the DataFrame they are given, so the shared settings live
    here once.

    Args:
        get_df: Zero-argument callable returning the metadata DataFrame for
            the split. It is invoked only when the environment is actually
            constructed, so the module-level frame is looked up at that time.

    Returns:
        A callable taking no arguments that returns a new ``TradingEnv``
        (the form ``DummyVecEnv`` is given elsewhere in this notebook).
    """
    def _init():
        return TradingEnv(
            meta_df=get_df(),
            root_dir=meta_root,
            image_size=(128, 128),
            starting_balance=100_000,
            leverage=1,
            risk_per_trade=1,
            # Base log path; the evaluation code changes it per episode via
            # ``set_log_suffix``.
            log_path="trading_log.csv",
        )
    return _init


def make_train_env():
    """Return a constructor for a ``TradingEnv`` over ``train_df``."""
    return _env_factory(lambda: train_df)


def make_test_env():
    """Return a constructor for a ``TradingEnv`` over ``test_df``."""
    return _env_factory(lambda: test_df)

# --- Create Vec Envs ---
# Single-environment vectorised wrappers. Only the training env is wrapped in
# VecMonitor.
train_env = VecMonitor(DummyVecEnv([make_train_env()]))
test_env = DummyVecEnv([make_test_env()])

# --- Model path ---
# Base name: "<base>.zip" is the rolling "latest" model used for resuming,
# "<base>-<N>k.zip" are the per-chunk checkpoints.
base_model_path = "ppo-cnnpolicy-5m"

# --- Load or Initialize PPO ---
# NOTE(review): the recorded run of this cell printed "Using cpu device" even
# though "cuda" is requested here — presumably SB3 falls back to CPU when CUDA
# is unavailable; confirm if GPU training is expected.
if os.path.exists(base_model_path + ".zip"):
    print(f"[📂 LOADING EXISTING MODEL from {base_model_path}]")
    model = PPO.load(base_model_path, env=train_env, device="cuda")
else:
    print("[🆕 INITIALIZING NEW MODEL]")
    model = PPO(
        policy=CNNPolicy,
        env=train_env,
        verbose=1,
        n_steps=256,
        batch_size=64,
        learning_rate=1e-4,
        tensorboard_log="./logs",
        device="cuda"
    )

# --- Training Loop ---
total_target_steps = 5_000_000
chunk_size = 100_000

# Only train what is still missing: a freshly created model starts at 0
# timesteps, a resumed one carries the count it was saved with.
remaining_steps = max(0, total_target_steps - model.num_timesteps)
chunks = -(-remaining_steps // chunk_size)  # ceiling division

for i in range(chunks):
    print(f"\n[🔁 TRAINING CHUNK {i+1}/{chunks}]")

    # reset_num_timesteps=False keeps the global step counter running across
    # chunks instead of restarting it on every learn() call.
    model.learn(total_timesteps=chunk_size, reset_num_timesteps=False)

    # Name the checkpoint after the model's real step count so that a resumed
    # run continues the numbering instead of overwriting earlier checkpoints.
    steps_so_far = model.num_timesteps
    chunk_model_path = f"{base_model_path}-{steps_so_far//1000}k"
    model.save(chunk_model_path)
    # Also refresh the "latest" model: this is the file the load branch above
    # looks for, so an interrupted run can pick up from the last full chunk.
    model.save(base_model_path)
    print(f"[✅ Saved checkpoint: {chunk_model_path}.zip]")

# --- Evaluation Loop ---
# Runs the trained policy deterministically on the test env for as many steps
# as there are test rows and reports per-episode reward totals.
print("\n[🚀 TESTING STARTED]")

# Give the very first test episode its own log as well; otherwise it would use
# the base log path that the training env is configured with too.
# NOTE(review): assumes set_log_suffix may be called before the first reset —
# confirm against TradingEnv.
test_env.envs[0].set_log_suffix("test_ep0")

# VecEnv.reset() takes no arguments; seeding goes through VecEnv.seed(), which
# is applied on the next reset.
test_env.seed(42)
obs = test_env.reset()

total_rewards = []
episode_reward = 0.0
episode_count = 0

for _step in range(len(test_df)):
    action, _state = model.predict(obs, deterministic=True)
    obs, reward, done, info = test_env.step(action)
    # step() returns per-env arrays; there is a single env, hence index 0.
    episode_reward += float(reward[0])

    if done[0]:
        total_rewards.append(episode_reward)
        print(f"Episode {episode_count} reward: {episode_reward:.2f}")

        episode_count += 1
        episode_reward = 0.0

        # ✅ Label the log of the episode that is about to start with its own
        # index, then reset so the episode begins under that suffix.
        # NOTE(review): the vec env has already auto-reset at this point; the
        # explicit reset is kept in case the suffix only takes effect on
        # reset — confirm against TradingEnv.
        test_env.envs[0].set_log_suffix(f"test_ep{episode_count}")
        obs = test_env.reset()

print("\n[📊 TESTING COMPLETE]")
if total_rewards:
    print(f"Average Reward: {sum(total_rewards)/len(total_rewards):.2f}")
    print(f"Max Reward: {max(total_rewards):.2f}")
    print(f"Min Reward: {min(total_rewards):.2f}")
else:
    print("No episodes completed during testing.")


INFO:env:[ENV INIT] 171 rows loaded.
INFO:env:[ENV INIT] 43 rows loaded.


[INFO] Training on 171 samples, testing on 43
[🆕 INITIALIZING NEW MODEL]
Using cpu device

[🔁 TRAINING CHUNK 1/50]
Logging to ./logs/PPO_0
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 171      |
|    ep_rew_mean     | -3e+03   |
| time/              |          |
|    fps             | 67       |
|    iterations      | 1        |
|    time_elapsed    | 3        |
|    total_timesteps | 256      |
---------------------------------


KeyboardInterrupt: 