# In [None]:
import os
# Keep using keras-2 (tf-keras) rather than keras-3 (keras).
os.environ['TF_USE_LEGACY_KERAS'] = '1'

import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from IPython.display import display

from tf_agents.bandits.agents import lin_ucb_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import tf_py_environment
from tf_agents.bandits.metrics import tf_metrics as tf_bandit_metrics
from tf_agents.policies import policy_saver
from utils import preprocess_data, create_wide_format_data, SYMBOLS
from environment import CryptoTradingEnvironment


# --- Configuration -----------------------------------------------------------
# Filesystem locations: gzipped OHLCV input and the export directory for the
# trained policy.
DATA_FILEPATH: str = 'data/ohlcv.csv.gz'
POLICY_SAVE_PATH: str = 'policy'

# Hyperparameters.
ALPHA: float = 1.0              # LinUCB exploration strength, passed to the agent as `alpha`
CONTEXT_LENGTH: int = 10        # subtracted from the step budget when sizing the driver run
NUM_TRAINING_STEPS: int = 1000  # leading rows of the preprocessed data used to size training

# --- Main Training Script ---
print("--- Starting Bandit Training Script ---")

# 1. Load Data
# Read the gzipped OHLCV CSV, parsing 'timestamp' as datetimes. The column is
# not made the index here; presumably preprocess_data handles any indexing or
# sorting it needs — confirm in utils.
df = pd.read_csv(DATA_FILEPATH, compression='gzip', parse_dates=['timestamp'])
all_data = preprocess_data(df)
# Pivot into the wide-format frames the environment consumes: one holding the
# observation features (only 'rsi' here) and, presumably, one of per-symbol
# prices. NOTE(review): exact column layout is defined by
# utils.create_wide_format_data — verify.
observation_df, prices_df = create_wide_format_data(
    all_data, 
    symbols=SYMBOLS, 
    features=['rsi']
)
# Leading slice of the preprocessed (pre-pivot) data. Only its length is used
# later, to size the driver run; the environment below is built from the full
# observation_df / prices_df, not from this slice.
# NOTE(review): if all_data has one row per (timestamp, symbol) — which the
# pivot above suggests — this is NUM_TRAINING_STEPS rows, not
# NUM_TRAINING_STEPS timesteps. Confirm.
training_data = all_data.iloc[:NUM_TRAINING_STEPS]

# 2. Setup Environment
# Wrap the Python environment in a TFPyEnvironment so the TF-Agents agent,
# driver and metrics can interact with it.
tf_env = tf_py_environment.TFPyEnvironment(
    CryptoTradingEnvironment(
    observation_df=observation_df,
    prices_df=prices_df, symbols=SYMBOLS)
)

# 3. Setup Agent
# LinUCB contextual-bandit agent. Its specs come from the wrapped environment so
# observation/action shapes match; ALPHA is the exploration parameter.
agent = lin_ucb_agent.LinearUCBAgent(
    time_step_spec=tf_env.time_step_spec(),
    action_spec=tf_env.action_spec(),
    alpha=ALPHA,
    dtype=tf.float32
)

# 4. Setup Metrics and Oracle
def optimal_reward_oracle(observation: np.ndarray, py_env=None) -> np.float32:
    """Return the best reward any action could earn at the environment's current step.

    This is a perfect-foresight baseline for regret: it peeks one row ahead in
    the environment's price table. For every action it computes the relative
    price move a BUY (price rise) or SELL (price fall) would capture; HOLD is
    worth 0. Actions decode as ``crypto_index, trade_type = divmod(action, 3)``,
    with trade types 0=BUY, 1=HOLD, 2=SELL.

    Args:
        observation: Unused; required only because the regret metric calls the
            baseline function with the trajectory's observation.
        py_env: Python environment to read state from. Defaults to the first
            wrapped environment of the module-level ``tf_env``. Pass one
            explicitly to evaluate the oracle against another environment
            (e.g. in tests).

    Returns:
        The maximum achievable reward as ``np.float32``, or ``np.float32(0.0)``
        when the current step is the last row (there is no next price).

    NOTE(review): tf_agents drivers invoke observers after ``env.step()``, so
    ``py_env.current_step`` may already point one step past the observation
    that produced the trajectory's reward. Confirm against
    CryptoTradingEnvironment.step before trusting the regret numbers.
    """
    if py_env is None:
        py_env = tf_env.pyenv.envs[0]

    price_data = py_env._price_data
    current_step = py_env.current_step

    # Cannot look one step ahead from the final row. Return float32 here too so
    # callers always receive the same dtype.
    if current_step >= len(price_data) - 1:
        return np.float32(0.0)

    # Fetch both rows once instead of re-indexing the frame for every action.
    current_prices = price_data.iloc[current_step]
    next_prices = price_data.iloc[current_step + 1]

    num_actions = py_env.action_spec().maximum + 1
    all_possible_rewards = []
    for action in range(num_actions):
        crypto_index, trade_type_idx = divmod(action, 3)  # 0: BUY, 1: HOLD, 2: SELL
        if trade_type_idx == 1:
            all_possible_rewards.append(0.0)  # HOLD neither earns nor loses
            continue

        symbol_to_trade = py_env.symbols[crypto_index]
        current_price = current_prices[symbol_to_trade]
        price_move = (next_prices[symbol_to_trade] - current_price) / current_price
        # BUY profits from a rise, SELL from an equal-sized fall.
        all_possible_rewards.append(price_move if trade_type_idx == 0 else -price_move)

    return np.float32(np.max(all_possible_rewards))

# Regret metric: compares each observed reward against the perfect-foresight
# baseline returned by optimal_reward_oracle.
# NOTE(review): in the tf_agents releases I'm aware of, RegretMetric.result()
# holds the (batch-mean) regret of the most recent step only, not a running
# total. Verify against the installed version before reading it as cumulative.
regret_metric = tf_bandit_metrics.RegretMetric(optimal_reward_oracle)

class ShowProgress:
    """Driver observer that prints an in-place progress counter.

    Counts only trajectories for which ``is_boundary()`` is false (in tf_agents
    a boundary trajectory is the transition between episodes). It prints once
    every ``interval`` counted steps, overwriting the same console line with a
    carriage return.

    Args:
        total: Expected number of steps; used only for display.
        interval: Print once every this many counted steps.
    """

    def __init__(self, total, interval=50):
        self.counter = 0
        self.total = total
        self.interval = interval

    def __call__(self, trajectory):
        # Skip boundary trajectories entirely. Otherwise they would fall through
        # to the print check and report "0/total", or repeat the previous count.
        if trajectory.is_boundary():
            return
        self.counter += 1
        if self.counter % self.interval == 0:
            print("\r{}/{} Reward: {}".format(self.counter, self.total, trajectory.reward), end="")

# 5. Setup Driver
def train_step(trajectory):
    """Driver observer: train the agent on each trajectory that is not LAST.

    A length-1 axis is inserted at position 1 of every tensor before the batch
    is passed to ``agent.train``. This presumably turns a [batch, ...]
    transition into the [batch, time, ...] layout the agent expects; confirm
    against the agent's training spec.
    """
    if trajectory.is_last():
        return
    with_time_axis = tf.nest.map_structure(
        lambda tensor: tf.expand_dims(tensor, axis=1),
        trajectory,
    )
    agent.train(with_time_axis)
        
# Number of driver steps: the training slice minus CONTEXT_LENGTH and a further
# margin of 5. NOTE(review): the margin presumably keeps the run clear of the
# environment's context window and the end of the data — confirm against
# CryptoTradingEnvironment.
num_steps = len(training_data) - CONTEXT_LENGTH - 5
if num_steps <= 0:
    # Without this check the driver would silently run zero steps, and an
    # untrained policy would then be saved.
    raise ValueError(
        f"Not enough data to train: {len(training_data)} rows leave {num_steps} "
        f"steps after reserving CONTEXT_LENGTH={CONTEXT_LENGTH} + 5."
    )

# regret_metric.result() reflects only the most recent step, so record it after
# every step and sum at the end to get a true cumulative regret.
step_regrets = []


def record_step_regret(trajectory):
    """Driver observer: store the regret regret_metric just computed for this step."""
    step_regrets.append(float(regret_metric.result().numpy()))


driver = dynamic_step_driver.DynamicStepDriver(
    env=tf_env,
    policy=agent.policy,
    num_steps=num_steps,
    # Observers run in list order; record_step_regret must follow regret_metric
    # so it reads the value computed for the same trajectory.
    observers=[train_step, regret_metric, record_step_regret, ShowProgress(num_steps)]
)

# 6. Run Training
print(f"\nStarting training for {num_steps} steps...")
driver.run()
print()  # ShowProgress prints with end=""; finish its line first.
print("Training finished.")

# 7. Save Policy
print(f"\nSaving the trained policy to: {POLICY_SAVE_PATH}")
saver = policy_saver.PolicySaver(agent.policy)
saver.save(POLICY_SAVE_PATH)
print("Policy saved successfully.")

# 8. Report Results
cumulative_regret = float(np.sum(step_regrets))
final_step_regret = float(regret_metric.result().numpy())
print(f"\nCumulative Regret vs. Perfect Foresight Oracle: {cumulative_regret:.4f}")
print(f"Regret on final step: {final_step_regret:.4f}")

# In [None]:
import matplotlib.pyplot as plt
# plt.axhline(y=0.0, color='r', linestyle='-')
# plt.plot(rewards)
# plt.ylabel('Rewards')
# plt.xlabel('Number of Iterations')