In [None]:
import os
# Keep using keras-2 (tf-keras) rather than keras-3 (keras).
os.environ['TF_USE_LEGACY_KERAS'] = '1'

import numpy as np
import pandas as pd
import tensorflow as tf

# Report the GPUs TensorFlow can see (an empty list means no GPU is visible).
gpus = tf.config.list_physical_devices(device_type='GPU')
print("gpus: {}".format(gpus))

from tf_agents.bandits.agents import lin_ucb_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import py_environment
from tf_agents.environments import tf_py_environment
from tf_agents.specs import array_spec, tensor_spec
from tf_agents.trajectories import time_step as ts
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.bandits.metrics import tf_metrics as tf_bandit_metrics
from tf_agents.policies import policy_saver

from utils import load_and_prepare_data, SYMBOLS

# --- Environment Class ---
class CryptoTradingEnvironment(py_environment.PyEnvironment):
    """Contextual-bandit trading environment over a table of per-symbol data.

    Actions: action ``2 * i`` buys symbol ``i`` and action ``2 * i + 1`` sells
    it. The reward is the simple return of that position from the
    ``<symbol>_close`` value at the current row to the one at the next row
    (negated for a sell).

    Observation: the flattened window of the ``context_len`` most recent rows
    of features, oldest first. Each row contributes ``<symbol>_close_return``
    and ``<symbol>_volume_return`` for every symbol, in ``symbols`` order.

    Args:
        data: DataFrame holding, for every symbol, the columns
            ``<symbol>_close``, ``<symbol>_close_return`` and
            ``<symbol>_volume_return``. Rows are accessed positionally.
        symbols: Sequence of symbol names; fixes action and feature order.
        context_len: Number of past rows included in each observation.

    Raises:
        ValueError: if ``context_len < 1`` or ``data`` has fewer than
            ``context_len + 2`` rows (one episode needs at least one tradable
            row after the initial window plus its next row).
        KeyError: if a required column is missing from ``data``.
    """

    def __init__(self, data, symbols, context_len=10):
        super().__init__()
        if context_len < 1:
            raise ValueError(f"context_len must be >= 1, got {context_len}")
        if len(data) < context_len + 2:
            raise ValueError(
                f"data has {len(data)} rows; at least {context_len + 2} are "
                f"needed for context_len={context_len}"
            )
        self._data = data
        self._context_len = context_len
        self._symbols = symbols
        self._num_cryptos = len(self._symbols)
        self._minimal_obs_size = self._num_cryptos * 2
        observation_size = self._minimal_obs_size * self._context_len
        self._action_spec = array_spec.BoundedArraySpec(
            shape=(), dtype=np.int32, minimum=0, maximum=self._num_cryptos * 2 - 1, name='action'
        )
        self._observation_spec = array_spec.ArraySpec(
            shape=(observation_size,), dtype=np.float32, name='context'
        )
        # Extract the per-row feature matrix and the close prices once, instead
        # of materialising a pandas row for every value on every step.
        feature_columns = [
            f'{symbol}_{field}'
            for symbol in self._symbols
            for field in ('close_return', 'volume_return')
        ]
        self._features = self._data[feature_columns].to_numpy(dtype=np.float32)
        close_columns = [f'{symbol}_close' for symbol in self._symbols]
        self._closes = self._data[close_columns].to_numpy(dtype=np.float64)
        self._current_step_index = 0
        self._episode_ended = False

    def action_spec(self):
        return self._action_spec

    def observation_spec(self):
        return self._observation_spec

    def _get_minimal_observation(self, index):
        """Return row ``index``'s features: close_return, volume_return per symbol."""
        return self._features[index].copy()

    def _observe(self):
        """Return the flattened window of the last ``context_len`` feature rows.

        The window ends just before ``_current_step_index``. After ``_reset``
        it covers rows ``[0, context_len)``, and each ``_step`` slides it
        forward by one row. Before the first reset it is all zeros.
        """
        end = self._current_step_index
        if end < self._context_len:
            # Not reset yet, so nothing has been observed.
            return np.zeros(self._observation_spec.shape, dtype=np.float32)
        window = self._features[end - self._context_len:end]
        # Copy so callers cannot mutate the cached feature matrix.
        return window.reshape(-1).copy()

    def _reset(self):
        """Start a new episode with the observation window over rows [0, context_len)."""
        self._episode_ended = False
        self._current_step_index = self._context_len
        return ts.restart(self._observe())

    def _step(self, action):
        """Trade ``action`` on the current row and slide the window forward one row.

        Returns a ``termination`` time step once the index reaches
        ``len(data) - 2``, otherwise a ``transition``. Stepping an already
        ended episode resets it instead.
        """
        if self._episode_ended:
            return self.reset()
        crypto_index, side = divmod(int(action), 2)
        is_buy_action = side == 0
        row = self._current_step_index
        current_price = self._closes[row, crypto_index]
        next_price = self._closes[row + 1, crypto_index]
        if is_buy_action:
            reward = (next_price - current_price) / current_price
        else:
            reward = (current_price - next_price) / current_price
        # The row just traded on becomes the newest row of the observation window.
        self._current_step_index += 1
        if self._current_step_index >= len(self._data) - 2:
            self._episode_ended = True
        observation = self._observe()
        # Cast to float32 to match PyEnvironment's default (float32) reward spec.
        reward = np.float32(reward)
        if self._episode_ended:
            return ts.termination(observation, reward)
        return ts.transition(observation, reward)

# --- Configuration & Main Script ---
CRYPTO_NAMES = SYMBOLS
DATA_FILEPATH = 'data/ohlcv.csv.gz'
CONTEXT_LENGTH = 10
NUM_TRAINING_STEPS = 1000
POLICY_SAVE_PATH = 'policy' # Directory to save the policy


data = load_and_prepare_data(DATA_FILEPATH, CRYPTO_NAMES)

# Cap training at the number of rows available after the initial context
# window, keeping a 5-row margin. Fail fast when there is no room at all:
# otherwise the driver would silently train for zero steps.
max_training_steps = len(data) - CONTEXT_LENGTH - 5
if max_training_steps <= 0:
    raise ValueError(
        f"Not enough data: {len(data)} rows loaded, but more than "
        f"{CONTEXT_LENGTH + 5} are needed for CONTEXT_LENGTH={CONTEXT_LENGTH}."
    )
if NUM_TRAINING_STEPS > max_training_steps:
    NUM_TRAINING_STEPS = max_training_steps
    print(f"\nWarning: Training steps reduced to {NUM_TRAINING_STEPS} to fit available data.")

tf_env = tf_py_environment.TFPyEnvironment(
    CryptoTradingEnvironment(data, symbols=CRYPTO_NAMES, context_len=CONTEXT_LENGTH)
)

# LinUCB contextual bandit over the environment's discrete actions.
# `alpha` multiplies the confidence-interval (exploration) term.
agent = lin_ucb_agent.LinearUCBAgent(
    action_spec=tf_env.action_spec(),
    time_step_spec=tf_env.time_step_spec(),
    dtype=tf.float32,
    alpha=1.0,
)

def train_step(trajectory):
    """Driver observer: train the agent on every non-terminal trajectory.

    A trajectory whose step is LAST is the action taken on a terminal time
    step; the environment answers it with a reset, so it is skipped.
    LinearUCBAgent.train expects a time axis, so one of length 1 is inserted
    at axis 1 of every tensor.
    """
    if trajectory.is_last():
        return
    with_time_axis = tf.nest.map_structure(
        lambda tensor: tf.expand_dims(tensor, axis=1), trajectory
    )
    agent.train(with_time_axis)

def optimal_reward_oracle(observation):
    """
    Return the best reward any action could have earned on the step just taken.

    Used as the baseline of RegretMetric. The metric runs as a driver
    observer, i.e. *after* `tf_env.step`, so the environment's
    `_current_step_index` already points one row past the row the agent
    traded on. The traded row is therefore `_current_step_index - 1`.

    NOTE: The 'observation' is unused, but required by the metric's API.
    We rely on the environment's internal state.

    Returns:
        np.float32 best achievable reward. Returns 0.0 when there is no real
        trade to compare against: the trajectory straddles an episode reset
        (whose reward is the restart's 0), or there is no next price.
    """
    # Get the python environment to access its internal state
    py_env = tf_env.pyenv.envs[0]
    current_step = py_env._current_step_index

    # The index sits exactly at the context length only right after a reset,
    # i.e. when the trajectory being scored is the terminal -> restart boundary.
    if current_step == py_env._context_len:
        return np.float32(0.0)

    traded_step = current_step - 1
    prices = py_env._data
    if traded_step < 0 or traded_step + 1 >= len(prices):
        return np.float32(0.0)  # No future data available

    # Fetch the two rows once rather than once per action.
    current_row = prices.iloc[traded_step]
    next_row = prices.iloc[traded_step + 1]

    all_possible_rewards = []
    num_actions = py_env.action_spec().maximum + 1
    for action in range(num_actions):
        crypto_index, side = divmod(action, 2)
        column_name = f'{py_env._symbols[crypto_index]}_close'
        current_price = current_row[column_name]
        next_price = next_row[column_name]
        # Same reward definition as the environment: long on even, short on odd.
        if side == 0:
            reward = (next_price - current_price) / current_price
        else:
            reward = (current_price - next_price) / current_price
        all_possible_rewards.append(reward)

    return np.max(all_possible_rewards).astype(np.float32)

# Use our oracle with the RegretMetric. RegretMetric.result() only holds the
# regret of the most recent trajectory (each call overwrites it), so the total
# over training is accumulated separately in `regret_total`.
regret_metric = tf_bandit_metrics.RegretMetric(optimal_reward_oracle)
regret_total = tf.Variable(0.0, dtype=tf.float32, trainable=False, name='regret_total')


def observe_regret(trajectory):
    """Driver observer: update `regret_metric` and add its value to `regret_total`.

    Trajectories whose step is LAST straddle an episode reset (the same ones
    `train_step` skips), so they are left out of the running total.
    """
    regret_metric(trajectory)
    if not trajectory.is_last():
        regret_total.assign_add(regret_metric.result())


def get_action_name(action):
    """Return a readable label such as 'BUY BTC' for a discrete action index."""
    crypto_index = action // 2
    action_type = "BUY" if action % 2 == 0 else "SELL"
    return f"{action_type} {CRYPTO_NAMES[crypto_index]}"


driver = dynamic_step_driver.DynamicStepDriver(
    env=tf_env,
    policy=agent.policy,
    num_steps=NUM_TRAINING_STEPS,
    observers=[train_step, observe_regret]
)

print(f"\nStarting training for {NUM_TRAINING_STEPS} steps...")
driver.run()
print("Training finished.")

# --- SAVING THE POLICY AFTER TRAINING ---
print(f"\nSaving the trained policy to: {POLICY_SAVE_PATH}")

# 1. Create a PolicySaver instance for our agent's policy
saver = policy_saver.PolicySaver(agent.policy)

# 2. Save the policy to the specified directory
saver.save(POLICY_SAVE_PATH)

print("Policy saved successfully.")

# Sum of per-step regret over all non-boundary training steps.
cumulative_regret = regret_total.numpy()
print(f"cumulative_regret: {cumulative_regret}")
print(f"\nCumulative Regret vs. Perfect Foresight Oracle: {cumulative_regret:.4f}")
print("This measures the total profit the agent missed compared to a perfect model.")

print("\n--- Evaluation Loop ---")
# NOTE: evaluation resets the environment, so it replays the rows right after
# the initial context window. These are the same rows used for training, so
# this measures in-sample performance.
time_step = tf_env.reset()
cumulative_reward = 0
num_eval_steps = 100
rewards = []
for _ in range(num_eval_steps):
    if time_step.is_last():
        print("Evaluation data ended. Resetting.")
        time_step = tf_env.reset()

    action_step = agent.policy.action(time_step)
    # Keep a plain scalar copy for inspection (e.g. get_action_name), but hand
    # the batched action tensor to the batched TF environment.
    action = action_step.action.numpy()[0]
    time_step = tf_env.step(action_step.action)
    reward = time_step.reward.numpy()[0]
    cumulative_reward += reward
    rewards.append(reward)

print(f"\nFinal cumulative reward over {num_eval_steps} evaluation steps: {cumulative_reward:.6f}")

In [None]:
import matplotlib.pyplot as plt

# Per-step evaluation rewards against a red zero line: points above it are
# profitable trades, points below it are losing trades.
plt.axhline(0.0, color='r', linestyle='-')
plt.plot(rewards)
plt.xlabel('Number of Iterations')
plt.ylabel('Rewards')