In [None]:
import os
# Keep using keras-2 (tf-keras) rather than keras-3 (keras).
os.environ['TF_USE_LEGACY_KERAS'] = '1'

import numpy as np
import pandas as pd
import tensorflow as tf
import time

# We only need a few imports for deployment
from tf_agents.trajectories import time_step as ts
from tf_agents.specs import array_spec, tensor_spec
from tf_agents.replay_buffers import tf_uniform_replay_buffer


from utils import load_and_prepare_data, SYMBOLS

# Directory handed to tf.saved_model.load() to restore the trained policy.
POLICY_SAVE_PATH = 'policy'

# Symbol names, in the order used both for the per-row observation features
# and for decoding discrete actions (two actions per symbol).
CRYPTO_NAMES = SYMBOLS

# Number of most recent observations stacked into a single policy input.
CONTEXT_LENGTH = 10

# Leading rows of the dataset treated as training data; the simulated live
# stream starts right after them.
NUM_TRAINING_STEPS = 1_000


# --- Helper functions to simulate the live environment ---
def get_action_name(action, names=None):
    """Return a human-readable label such as ``"BUY BTC"`` for a discrete action.

    Actions come in pairs per symbol: action ``2 * k`` buys ``names[k]`` and
    action ``2 * k + 1`` sells it.

    Args:
        action: Integer action index (a Python int or NumPy integer).
        names: Sequence of symbol names indexed by ``action // 2``.
            Defaults to ``CRYPTO_NAMES``.

    Returns:
        The label ``"<BUY|SELL> <symbol>"``.

    Raises:
        ValueError: If ``action`` is outside ``[0, 2 * len(names))``.
    """
    if names is None:
        names = CRYPTO_NAMES
    num_actions = 2 * len(names)
    # Reject negatives explicitly: Python's negative indexing would otherwise
    # silently map them onto symbols counted from the end of the list.
    if not 0 <= action < num_actions:
        raise ValueError(
            f"action {action} is out of range for {len(names)} symbols "
            f"(expected 0..{num_actions - 1})"
        )
    crypto_index, is_sell = divmod(action, 2)
    action_type = "SELL" if is_sell else "BUY"
    return f"{action_type} {names[crypto_index]}"

def get_live_minimal_observation(row):
    """Flatten one row of live market data into the per-step feature vector.

    For each symbol in ``CRYPTO_NAMES`` (in order), the row's close-return and
    volume-return columns are emitted back to back. The result is a float32
    array of length ``2 * len(CRYPTO_NAMES)``.
    """
    per_symbol = (
        (row[f'{symbol}_close_return'], row[f'{symbol}_volume_return'])
        for symbol in CRYPTO_NAMES
    )
    flat_values = [value for pair in per_symbol for value in pair]
    return np.array(flat_values, dtype=np.float32)

def create_context_from_buffer(buffer, context_length=None):
    """Build the flattened policy input from the newest frames in ``buffer``.

    Reads every frame currently stored in the buffer in a single deterministic
    pass, keeps the last ``context_length`` of them, and flattens them into one
    vector with a leading batch dimension of 1.

    Args:
        buffer: Replay buffer exposing ``num_frames()`` and
            ``as_dataset(single_deterministic_pass=True)``, whose dataset
            yields ``(data, info)`` tuples.
        context_length: Number of most recent frames to include. Defaults to
            ``CONTEXT_LENGTH``.

    Returns:
        A tensor of shape ``[1, K]``, where K is ``context_length`` times the
        number of elements per frame. This assumes each dataset element is
        one frame; TODO confirm for buffers with batch_size > 1.

    Raises:
        ValueError: If ``context_length`` is not positive, or if the buffer
            holds fewer than ``context_length`` frames.
    """
    if context_length is None:
        context_length = CONTEXT_LENGTH
    if context_length <= 0:
        # `items[-0:]` would select *every* frame, so reject 0 explicitly.
        raise ValueError(f"context_length must be positive, got {context_length}")

    num_items = int(buffer.num_frames())
    # Fail clearly here instead of letting batch(0) error inside tf.data, or
    # handing the policy a context that is too short.
    if num_items < context_length:
        raise ValueError(
            f"buffer holds {num_items} frame(s) but {context_length} are "
            "needed to build a context"
        )

    dataset = buffer.as_dataset(single_deterministic_pass=True)
    # Batch the whole pass into a single element so every frame arrives at once.
    batched_items = next(iter(dataset.batch(num_items)))

    # Keep only the data tensor (index 0) of the (data, info) tuple.
    all_items = batched_items[0]

    context = tf.reshape(all_items[-context_length:], [-1])
    return tf.expand_dims(context, 0)

# --- Main Deployment Logic ---
print(f"Loading trained policy from {POLICY_SAVE_PATH}...")
# 1. Load the policy from the saved directory
loaded_policy = tf.saved_model.load(POLICY_SAVE_PATH)

# --- Simulate a live data feed ---
# In a real system, this data would come from a WebSocket or REST API call.
# We'll just load our data file again for simulation.
# Let's pretend the first part of the data was for training, and the next part is live.
all_data = load_and_prepare_data('data/ohlcv.csv.gz', CRYPTO_NAMES)
live_data_stream = all_data.iloc[NUM_TRAINING_STEPS:]

print(f"\n--- Starting Live Inference Simulation ({len(live_data_stream)} steps) ---")

# Setup a buffer to maintain the state, just like in the environment
num_features = len(CRYPTO_NAMES) * 2
data_spec = tensor_spec.TensorSpec([num_features], dtype=tf.float32, name='minimal_observation')
live_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=data_spec, batch_size=1, max_length=CONTEXT_LENGTH + 5
)

# Pre-fill the buffer with initial data
for i in range(CONTEXT_LENGTH):
    minimal_obs = get_live_minimal_observation(live_data_stream.iloc[i])
    live_buffer.add_batch(tf.expand_dims(minimal_obs, 0))

# The main loop of the trading bot
for i in range(CONTEXT_LENGTH, len(live_data_stream)):
    # 1. Get current context from our state buffer
    current_context = create_context_from_buffer(live_buffer)
    
    # 2. Construct a TimeStep object for the policy
    # For a continuous stream, we use a transition. The reward can be a dummy value.
    time_step = ts.transition(observation=current_context, reward=tf.constant([0.0]))
    
    # 3. GET THE ACTION FROM THE POLICY
    action_step = loaded_policy.action(time_step)
    action = action_step.action.numpy()[0]
    
    print(f"Step {i}: Context ready. Policy chose action: {get_action_name(action)}")
    # In a real bot: Execute trade via API based on `action`
    
    # 4. Update the buffer with the latest data point to prepare for the next step
    latest_minimal_obs = get_live_minimal_observation(live_data_stream.iloc[i])
    live_buffer.add_batch(tf.expand_dims(latest_minimal_obs, 0))
    
    time.sleep(0.1) # Simulate waiting for the next data candle

print("\nLive simulation finished.")