In [1]:
! pip install ta

Collecting ta
  Downloading ta-0.11.0.tar.gz (25 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: ta
  Building wheel for ta (setup.py) ... [?25l[?25hdone
  Created wheel for ta: filename=ta-0.11.0-py3-none-any.whl size=29412 sha256=2009e6356c5c0b3b474df57ef920fa1a07b013873a48eb08d7d671ed9f5cf409
  Stored in directory: /root/.cache/pip/wheels/a1/d7/29/7781cc5eb9a3659d032d7d15bdd0f49d07d2b24fec29f44bc4
Successfully built ta
Installing collected packages: ta
Successfully installed ta-0.11.0


In [None]:
import yfinance as yf
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
from collections import deque, namedtuple
import matplotlib.pyplot as plt
import ta # Technical Analysis library
from sklearn.preprocessing import StandardScaler # For state normalization
import copy # For deep copying target network
from datetime import timedelta

# --- Configuration ---
TICKER = 'BTC-USD'       # Stock ticker symbol (e.g., 'AAPL', 'BTC-USD')
#TICKER = 'BTC-USD' # Uncomment to test crypto

# Determine if the ticker is crypto for market hours adjustment
IS_CRYPTO = "-USD" in TICKER.upper() or "-USDT" in TICKER.upper()

DATA_PERIOD_TOTAL = '7d' # Total data to download (yfinance 1-min free limit: 7d for stocks, sometimes more for crypto but 7d is safe)
TRAIN_DAYS = 5           # Number of days from downloaded data for training (must be < total days in DATA_PERIOD_TOTAL for a test set)
INTERVAL = '1m'          # Data interval
N_SHORT_LAGS = 5         # Number of short-term (1-min) lagged features

# Adjust minutes in a day based on asset type
MINUTES_IN_TRADING_DAY_STOCK = 390 # Approx. 6.5 hours * 60 minutes (e.g., 9:30 AM - 4:00 PM)
MINUTES_IN_TRADING_DAY_CRYPTO = 24 * 60 # Crypto trades 24/7
MINUTES_IN_DAY_EFFECTIVE = MINUTES_IN_TRADING_DAY_CRYPTO if IS_CRYPTO else MINUTES_IN_TRADING_DAY_STOCK


INITIAL_BALANCE = 10000
TRANSACTION_COST_PCT = 0.001 # 0.1%

# --- DQN Hyperparameters ---
BUFFER_SIZE = int(1e5)
BATCH_SIZE = 64
GAMMA = 0.99
LR = 5e-4
TAU = 1e-3
UPDATE_EVERY = 4
TARGET_UPDATE_EVERY = 100 # Soft update target network periodically

# --- Exploration ---
EPS_START = 1.0
EPS_END = 0.01
EPS_DECAY = 0.995 # Adjust based on number of episodes / steps per episode

# --- Training ---
NUM_EPISODES = 500      # Number of training episodes
# MAX_T will be determined by the length of the training data for each episode

# --- Device ---
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print(f"Ticker: {TICKER}, Crypto: {IS_CRYPTO}, Effective Minutes in Day for Lag: {MINUTES_IN_DAY_EFFECTIVE}")

# --- Data Loading and Feature Engineering ---
def load_and_prepare_data(ticker, period, interval, n_short_lags, train_days_count, minutes_in_day_for_lag):
    """Loads 1-min data, calculates features, adds time features and lags, splits, and scales."""
    print(f"Loading data for {ticker} ({period}, {interval})...")
    stock = yf.Ticker(ticker)
    data = stock.history(period=period, interval=interval, auto_adjust=True)

    if data.empty:
        data_source_note = "Yfinance might have limitations for 1-min data for this crypto/period combination (max 7 days usually for free tier)." if IS_CRYPTO else ""
        raise ValueError(f"No data loaded for {ticker}. Check ticker, period '{period}', interval '{interval}'. {data_source_note}")
    if not isinstance(data.index, pd.DatetimeIndex):
        raise ValueError("Data index is not a DatetimeIndex. Required for time features.")

    print(f"Original data shape: {data.shape}")
    data.dropna(subset=['Close', 'Open', 'High', 'Low', 'Volume'], how='any', inplace=True) # Drop rows if essential price/volume data is missing
    print(f"Data shape after initial essential column dropna: {data.shape}")

    if data.empty:
        raise ValueError("No data left after initial essential column NaN drop.")

    # --- Time Features ---
    print("Adding time features...")
    data['day_of_week'] = data.index.dayofweek / 6.0
    data['hour'] = data.index.hour / 23.0
    data['minute'] = data.index.minute / 59.0

    # --- Technical Indicators ---
    print("Calculating technical indicators...")
    data['RSI'] = ta.momentum.RSIIndicator(close=data['Close'], window=14).rsi()
    macd = ta.trend.MACD(close=data['Close'], window_slow=26, window_fast=12, window_sign=9)
    data['MACD'] = macd.macd()
    data['MACD_signal'] = macd.macd_signal()
    data['MACD_diff'] = macd.macd_diff()
    bb = ta.volatility.BollingerBands(close=data['Close'], window=20, window_dev=2)
    data['BB_high_ind'] = bb.bollinger_hband_indicator().astype(float)
    data['BB_low_ind'] = bb.bollinger_lband_indicator().astype(float)
    data['ATR'] = ta.volatility.AverageTrueRange(high=data['High'], low=data['Low'], close=data['Close'], window=14).average_true_range()

    base_feature_cols = ['Close', 'Volume', 'RSI', 'MACD', 'MACD_signal', 'MACD_diff',
                           'BB_high_ind', 'BB_low_ind', 'ATR',
                           'day_of_week', 'hour', 'minute']

    for col in base_feature_cols:
        if col not in data.columns:
            raise ValueError(f"Column '{col}' not found in data after indicator calculation.")
        if not pd.api.types.is_numeric_dtype(data[col]):
            print(f"Warning: Column '{col}' is not numeric. Attempting conversion.")
            data[col] = pd.to_numeric(data[col], errors='coerce')

    data.dropna(inplace=True) # Drop NaNs from TIs
    print(f"Data shape after TIs and NaN drop: {data.shape}")
    if data.empty:
        raise ValueError("No data left after calculating TIs and dropping NaNs. Check TI windows or data quality.")

    base_features_df = data[base_feature_cols].copy()

    print(f"Creating lagged features (short, 15m, 1h, 1d-approx based on {minutes_in_day_for_lag} min/day)...")
    all_features_list = [base_features_df]

    for lag in range(1, n_short_lags + 1):
        shifted = base_features_df.shift(lag)
        shifted.columns = [f'{col}_lag{lag}m' for col in base_features_df.columns]
        all_features_list.append(shifted)

    lags_to_add = {
        '15m': 15,
        '1h': 60,
        '1d': minutes_in_day_for_lag
    }
    for name, lag_val in lags_to_add.items():
        if lag_val <= 0: continue # Skip non-positive lags
        if lag_val < len(base_features_df):
            shifted = base_features_df.shift(lag_val)
            shifted.columns = [f'{col}_lag{name}' for col in base_features_df.columns]
            all_features_list.append(shifted)
        else:
             print(f"Warning: Lag value {lag_val} for '{name}' is too large for current data length {len(base_features_df)}. Skipping this lag.")

    full_features = pd.concat(all_features_list, axis=1)
    original_len_before_lag_dropna = len(full_features)
    full_features.dropna(inplace=True)
    data = data[data.index.isin(full_features.index)]
    print(f"Data shape after feature engineering and final NaN drop: {full_features.shape}")
    print(f"Dropped {original_len_before_lag_dropna - len(full_features)} rows due to NaNs from TIs/lags.")

    if full_features.empty or data.empty:
        raise ValueError("No data left after feature engineering and NaN drop. Adjust lag settings or increase data period.")

    print(f"Splitting data: {train_days_count} days for training, rest for testing...")
    unique_dates = sorted(data.index.normalize().unique())

    if len(unique_dates) <= train_days_count:
        raise ValueError(f"Not enough unique days in data ({len(unique_dates)}) for {train_days_count} train days and a separate test set. Need at least {train_days_count + 1} unique days.")

    split_date_marker = unique_dates[train_days_count - 1]
    train_mask = data.index.normalize() <= split_date_marker

    first_test_date = unique_dates[train_days_count]
    test_mask = data.index.normalize() >= first_test_date

    train_stock_data = data[train_mask]
    train_features_unscaled = full_features[train_mask]
    test_stock_data = data[test_mask]
    test_features_unscaled = full_features[test_mask]

    if train_stock_data.empty or train_features_unscaled.empty:
        raise ValueError("Training data is empty after split. Check date splitting logic or data range.")
    if test_stock_data.empty or test_features_unscaled.empty:
        print("Warning: Test data is empty after split. This might happen if TRAIN_DAYS covers almost all available data or due to data gaps on test days.")

    print(f"Training data shape: {train_stock_data.shape}, Training features shape: {train_features_unscaled.shape}")
    print(f"Test data shape: {test_stock_data.shape}, Test features shape: {test_features_unscaled.shape}")

    print("Scaling features (fitting on training data only)...")
    scaler = StandardScaler()
    scaled_train_features_np = scaler.fit_transform(train_features_unscaled)
    scaled_train_features = pd.DataFrame(scaled_train_features_np, index=train_features_unscaled.index, columns=train_features_unscaled.columns)

    scaled_test_features = pd.DataFrame()
    if not test_features_unscaled.empty:
        scaled_test_features_np = scaler.transform(test_features_unscaled)
        scaled_test_features = pd.DataFrame(scaled_test_features_np, index=test_features_unscaled.index, columns=test_features_unscaled.columns)
    else:
        print("No test data to scale.")

    return (train_stock_data, scaled_train_features,
            test_stock_data, scaled_test_features, scaler)

# --- Trading Environment ---
class TradingEnv:
    def __init__(self, stock_data_df, feature_data_df, initial_balance=10000, transaction_cost_pct=0.001):
        self.stock_data = stock_data_df.copy().reset_index(drop=True)
        self.feature_data = feature_data_df.copy().reset_index(drop=True)

        if not self.stock_data.index.equals(self.feature_data.index):
             # This can happen if one df is empty and the other is not before reset_index
            if self.stock_data.empty and self.feature_data.empty:
                print("Warning: Both stock_data and feature_data are empty in TradingEnv init.")
            elif self.stock_data.empty:
                 raise ValueError("Stock data is empty while feature data is not.")
            elif self.feature_data.empty:
                 raise ValueError("Feature data is empty while stock data is not.")
            else:
                # If both are non-empty but indices don't match after reset, it's an issue
                raise ValueError(f"Stock data and feature data indices do not match after reset. Stock len: {len(self.stock_data)}, Feature len: {len(self.feature_data)}")


        self.initial_balance = initial_balance
        self.transaction_cost_pct = transaction_cost_pct

        if self.feature_data.empty:
            print("Warning: Feature data is empty, environment will have 0 steps.")
            self.n_steps = 0
            self.state_dim = 0 # Or a predefined dimension if known, but safer to error out or handle
        else:
            self.n_steps = len(self.feature_data)
            self.state_dim = self.feature_data.shape[1]

        self.action_space_n = 3 # 0: Hold, 1: Buy, 2: Sell
        self._reset()

    def _reset(self):
        self.current_step = 0
        self.balance = self.initial_balance
        self.shares_held = 0
        self.net_worth = self.initial_balance
        self.trade_history = []
        self.total_reward = 0
        # self.daily_net_worths = [self.initial_balance] # Removed, use history in backtest for plotting
        return self._get_state()

    def _get_state(self):
        if self.n_steps == 0: # Handle case where environment has no data
             return np.zeros(self.state_dim if hasattr(self, 'state_dim') and self.state_dim > 0 else 1) # Return dummy state
        if self.current_step < self.n_steps:
            return self.feature_data.iloc[self.current_step].values
        else:
            return np.zeros(self.state_dim)

    def _get_current_price(self):
        if self.n_steps == 0 or self.current_step >= len(self.stock_data):
            return 0 # Or raise error, handle gracefully
        return self.stock_data['Close'].iloc[self.current_step]

    def step(self, action):
        if self.n_steps == 0 or self.current_step >= self.n_steps -1 : # Cannot act on the last step as no next state price for reward
             # If state_dim is 0 due to no features, get_state will handle it
             return self._get_state(), 0, True, {}

        prev_net_worth = self.net_worth
        current_price = self._get_current_price()
        cost = 0
        trade_executed = False
        action_type = 'HOLD'

        if action == 1: # Buy
            shares_can_buy = 0
            if current_price > 0: # Avoid division by zero
                shares_can_buy = self.balance / (current_price * (1 + self.transaction_cost_pct))

            # Ensure we buy a positive amount of shares and have balance
            if shares_can_buy > 1e-8 and self.balance > 0: # 1e-8 is a small threshold to avoid dust trades
                shares_to_buy = shares_can_buy # Buy all possible (fractional)

                actual_cost_of_shares = shares_to_buy * current_price
                transaction_fee = actual_cost_of_shares * self.transaction_cost_pct

                if self.balance >= actual_cost_of_shares + transaction_fee: # Final check
                    self.shares_held += shares_to_buy
                    self.balance -= (actual_cost_of_shares + transaction_fee)
                    cost = transaction_fee
                    self.trade_history.append({'step': self.current_step, 'type': 'BUY',
                                               'price': current_price, 'shares': shares_to_buy, 'cost': cost})
                    trade_executed = True
                    action_type = 'BUY'
                else: # Should not happen if shares_can_buy logic is correct, but as a safeguard
                    action_type = 'HOLD (Buy failed - insufficient funds for tx cost)'
            else:
                action_type = 'HOLD (Buy cond. not met)'


        elif action == 2: # Sell
            if self.shares_held > 1e-8: # Sell if holding a meaningful amount
                sell_value = self.shares_held * current_price
                transaction_fee = sell_value * self.transaction_cost_pct

                self.balance += (sell_value - transaction_fee)
                sold_shares = self.shares_held
                self.shares_held = 0 # Sell all
                cost = transaction_fee
                self.trade_history.append({'step': self.current_step, 'type': 'SELL',
                                           'price': current_price, 'shares': sold_shares, 'cost': cost})
                trade_executed = True
                action_type = 'SELL'
            else:
                action_type = 'HOLD (Sell cond. not met)'


        self.current_step += 1

        next_price_for_eval = current_price # Default if at the end
        if self.current_step < self.n_steps: # If there's a next step
            next_price_for_eval = self.stock_data['Close'].iloc[self.current_step]

        self.net_worth = self.balance + self.shares_held * next_price_for_eval
        reward = self.net_worth - prev_net_worth
        done = self.current_step >= self.n_steps -1 or self.net_worth <= 0

        next_state = self._get_state()
        self.total_reward += reward
        # self.daily_net_worths.append(self.net_worth) # Removed

        info = {
            'step': self.current_step, # This is already the *next* step index
            'balance': self.balance,
            'shares_held': self.shares_held,
            'net_worth': self.net_worth,
            'trade_executed': trade_executed,
            'action_taken_code': action, # Original action code
            'action_type_str': action_type, # String representation of what happened
            'cost': cost,
            'current_price_of_trade': current_price if trade_executed else None
        }
        return next_state, reward, done, info

    def render(self, mode='human', **kwargs):
        # Step in info is already advanced, so use current_step-1 for "current action" context
        print(f"Step: {self.current_step-1}/{self.n_steps}, Net Worth: {self.net_worth:.2f}, Shares: {self.shares_held:.4f}, Balance: {self.balance:.2f}, Total Episode Reward: {self.total_reward:.2f}")

# --- Q-Network (Unchanged) ---
class QNetwork(nn.Module):
    def __init__(self, state_size, action_size, seed, fc1_units=128, fc2_units=128, fc3_units=64):
        super(QNetwork, self).__init__()
        self.seed = torch.manual_seed(seed)
        self.fc1 = nn.Linear(state_size, fc1_units)
        self.fc2 = nn.Linear(fc1_units, fc2_units)
        self.fc3 = nn.Linear(fc2_units, fc3_units)
        self.fc4 = nn.Linear(fc3_units, action_size)

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return self.fc4(x)

# --- Replay Buffer (Unchanged) ---
Experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
class ReplayBuffer:
    def __init__(self, action_size, buffer_size, batch_size, seed):
        self.action_size = action_size
        self.memory = deque(maxlen=buffer_size)
        self.batch_size = batch_size
        self.seed = random.seed(seed)

    def add(self, state, action, reward, next_state, done):
        e = Experience(state, action, reward, next_state, done)
        self.memory.append(e)

    def sample(self):
        experiences = random.sample(self.memory, k=self.batch_size)
        states = torch.from_numpy(np.vstack([e.state for e in experiences if e is not None])).float().to(device)
        actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).long().to(device)
        rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(device)
        next_states = torch.from_numpy(np.vstack([e.next_state for e in experiences if e is not None])).float().to(device)
        dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(device)
        return (states, actions, rewards, next_states, dones)

    def __len__(self):
        return len(self.memory)

# --- DQN Agent (Modified to return loss and Q-values) ---
class DQNAgent():
    def __init__(self, state_size, action_size, seed):
        self.state_size = state_size
        self.action_size = action_size
        self.seed = random.seed(seed)

        if state_size == 0: # Handle case of no state from env
            print("Warning: DQNAgent initialized with state_size 0. Network will not be functional.")
            # Create dummy networks to avoid crashing, but they won't work
            self.qnetwork_local = nn.Linear(1,action_size).to(device) # Dummy
            self.qnetwork_target = nn.Linear(1,action_size).to(device) # Dummy
            self.optimizer = None
        else:
            self.qnetwork_local = QNetwork(state_size, action_size, seed).to(device)
            self.qnetwork_target = QNetwork(state_size, action_size, seed).to(device)
            self.qnetwork_target.load_state_dict(self.qnetwork_local.state_dict())
            self.qnetwork_target.eval()
            self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)

        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
        self.t_step = 0
        self.target_update_step = 0


    def step(self, state, action, reward, next_state, done):
        if self.state_size == 0: return None # Agent not functional

        self.memory.add(state, action, reward, next_state, done)
        self.t_step = (self.t_step + 1) % UPDATE_EVERY
        current_loss = None
        if self.t_step == 0:
            if len(self.memory) > BATCH_SIZE:
                experiences = self.memory.sample()
                current_loss = self.learn(experiences, GAMMA)

        self.target_update_step = (self.target_update_step + 1) % TARGET_UPDATE_EVERY
        if self.target_update_step == 0:
             self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU)
        return current_loss

    def act(self, state, eps=0.):
        if self.state_size == 0: return random.choice(np.arange(self.action_size)), 0 # Random action if not functional

        state_t = torch.from_numpy(state).float().unsqueeze(0).to(device)
        self.qnetwork_local.eval()
        with torch.no_grad():
            action_values = self.qnetwork_local(state_t)
        self.qnetwork_local.train()

        if random.random() > eps:
            action = np.argmax(action_values.cpu().data.numpy())
        else:
            action = random.choice(np.arange(self.action_size))
        return action, action_values.cpu().data.numpy().squeeze().mean() # Squeeze for single action value array

    def learn(self, experiences, gamma):
        if self.optimizer is None: return 0.0 # Cannot learn

        states, actions, rewards, next_states, dones = experiences

        Q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(1)
        Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))
        Q_expected = self.qnetwork_local(states).gather(1, actions)

        loss = F.mse_loss(Q_expected, Q_targets)
        self.optimizer.zero_grad()
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(self.qnetwork_local.parameters(), 1) # Optional: Gradient Clipping
        self.optimizer.step()
        return loss.item()

    def soft_update(self, local_model, target_model, tau):
        if self.state_size == 0: return # Agent not functional
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data)

# --- Training Function (Modified for more metrics) ---
def train_dqn(agent, env, n_episodes=2000, eps_start=1.0, eps_end=0.01, eps_decay=0.995):
    if env.n_steps == 0:
        print("Environment has no steps. Skipping training.")
        return { 'rewards': [], 'net_worths': [], 'losses': [], 'avg_q_values': [], 'epsilons': [], 'num_trades': []}

    episode_rewards = []
    avg_rewards_deque = deque(maxlen=100) # For printing average over last 100 episodes
    episode_net_worths = []
    episode_losses = []
    episode_avg_q_values = []
    episode_epsilons = []
    episode_num_trades = []

    eps = eps_start
    max_t_per_episode = env.n_steps - 1 # Max steps per episode is length of data for that episode minus 1

    print(f"\nStarting Training for {n_episodes} episodes... Max steps per episode: {max_t_per_episode}")
    if max_t_per_episode <=0:
        print("Max steps per episode is <=0. Training cannot proceed.")
        return { 'rewards': [], 'net_worths': [], 'losses': [], 'avg_q_values': [], 'epsilons': [], 'num_trades': []}

    for i_episode in range(1, n_episodes + 1):
        state = env._reset()
        current_episode_reward = 0
        episode_loss_sum = 0
        episode_q_sum = 0
        learn_steps_count = 0
        q_value_steps_count = 0
        num_trades_this_episode = 0

        for t in range(max_t_per_episode): # Iterate up to n_steps-2 (0 to n_steps-2)
            action, q_value = agent.act(state, eps)
            next_state, reward, done, info = env.step(action)
            loss = agent.step(state, action, reward, next_state, done)

            state = next_state
            current_episode_reward += reward

            if info.get('trade_executed', False):
                num_trades_this_episode +=1

            if loss is not None:
                episode_loss_sum += loss
                learn_steps_count += 1
            if q_value is not None: # q_value is now scalar (mean)
                episode_q_sum += q_value
                q_value_steps_count +=1

            # Optional: Detailed print for debugging first few steps of first crypto episode
            # if IS_CRYPTO and i_episode == 1 and t < 5:
            #    print(f"  Crypto Debug Ep1, t={t}: action={action}, Q_mean={q_value:.2f}, reward={reward:.2f}, bal={env.balance:.2f}, shares={env.shares_held:.4f}, net_w={env.net_worth:.2f}, action_str='{info['action_type_str']}'")

            if done:
                break

        avg_rewards_deque.append(current_episode_reward)
        episode_rewards.append(current_episode_reward)
        episode_net_worths.append(env.net_worth)
        episode_epsilons.append(eps)
        episode_num_trades.append(num_trades_this_episode)


        if learn_steps_count > 0:
            episode_losses.append(episode_loss_sum / learn_steps_count)
        else:
            episode_losses.append(None)

        if q_value_steps_count > 0:
            episode_avg_q_values.append(episode_q_sum / q_value_steps_count)
        else:
            episode_avg_q_values.append(None)

        eps = max(eps_end, eps_decay * eps)

        print(f'\rEpisode {i_episode}/{n_episodes}\tAvg Reward (100): {np.mean(avg_rewards_deque):.2f}\tReward: {current_episode_reward:.2f}\tNet Worth: {env.net_worth:.2f}\tEps: {eps:.4f}\tTrades: {num_trades_this_episode}', end="")
        if i_episode % 20 == 0 or i_episode == n_episodes :
            print(f'\rEpisode {i_episode}/{n_episodes}\tAvg Reward (100): {np.mean(avg_rewards_deque):.2f}\tReward: {current_episode_reward:.2f}\tNet Worth: {env.net_worth:.2f}\tEps: {eps:.4f}\tTrades: {num_trades_this_episode}')

    print("\nTraining finished.")
    return {
        'rewards': episode_rewards, 'net_worths': episode_net_worths,
        'losses': episode_losses, 'avg_q_values': episode_avg_q_values,
        'epsilons': episode_epsilons, 'num_trades': episode_num_trades
    }

# --- Backtesting Function (Modified for more plots) ---
def backtest(agent, env, stock_data_original_index):
    if env.n_steps == 0:
        print("Environment has no steps. Skipping backtesting.")
        return env.initial_balance, 0, []

    print("\nStarting Backtesting...")
    state = env._reset()
    done = False

    # Store history for plotting and analysis
    portfolio_values = [env.initial_balance] # Start with initial balance
    all_shares_held = [0]
    all_rewards = []
    all_actions_str = [] # Store string representation of actions

    max_backtest_steps = env.n_steps -1
    if max_backtest_steps <= 0:
        print("Not enough data for backtesting (max_backtest_steps <=0).")
        return env.initial_balance, 0, []

    for t_step in range(max_backtest_steps): # Iterate up to n_steps-2
        action, _ = agent.act(state, eps=0.0) # No exploration in backtest
        next_state, reward, done, info = env.step(action)
        state = next_state

        portfolio_values.append(info['net_worth'])
        all_shares_held.append(info['shares_held'])
        all_rewards.append(reward)
        all_actions_str.append(info['action_type_str'])

        if info.get('trade_executed', False):
             print(f"Step: {info['step']-1}, Action: {info['action_type_str']}, Price: {info['current_price_of_trade']:.2f}, Shares: {info['shares_held']:.4f} (after trade), Net Worth: {info['net_worth']:.2f}, Cost: {info['cost']:.2f}")
        elif (info['step']-1) % (MINUTES_IN_DAY_EFFECTIVE // 4) == 0: # Print Hold status occasionally based on effective day length
             print(f"Step: {info['step']-1}, Action: {info['action_type_str']}, Net Worth: {info['net_worth']:.2f}")

        if done:
            break

    trade_details = env.trade_history # Get trade log from environment
    print("Backtesting Finished.")

    final_net_worth = portfolio_values[-1]
    total_return_pct = (final_net_worth - env.initial_balance) / env.initial_balance * 100 if env.initial_balance else 0

    buy_hold_final_worth = env.initial_balance
    buy_hold_return_pct = 0.0
    if not env.stock_data.empty and len(env.stock_data) > 1:
        buy_hold_start_price = env.stock_data['Close'].iloc[0]
        buy_hold_end_price = env.stock_data['Close'].iloc[env.n_steps-1 if env.n_steps > 0 else 0] # Use last available price
        if buy_hold_start_price > 0:
            buy_hold_shares = env.initial_balance / buy_hold_start_price
            buy_hold_final_worth = buy_hold_shares * buy_hold_end_price
            buy_hold_return_pct = (buy_hold_final_worth - env.initial_balance) / env.initial_balance * 100 if env.initial_balance else 0

    print(f"\n--- Backtest Results ({TICKER}) ---")
    print(f"Initial Balance: ${env.initial_balance:.2f}")
    print(f"Final Net Worth (Agent): ${final_net_worth:.2f}")
    print(f"Total Return (Agent): {total_return_pct:.2f}%")
    print(f"Number of Trades (Agent): {len(trade_details)}")
    print(f"\n--- Buy and Hold Benchmark ---")
    print(f"Final Net Worth (Buy & Hold): ${buy_hold_final_worth:.2f}")
    print(f"Total Return (Buy & Hold): {buy_hold_return_pct:.2f}%")

    # Plotting
    # The length of portfolio_values is num_steps + 1 (includes initial balance)
    # The length of stock_data_original_index is num_env_data_points
    # env.n_steps is len(feature_data) which is also len(stock_data) in env
    # We have max_backtest_steps = env.n_steps - 1. Loop runs this many times.
    # portfolio_values will have 1 (initial) + max_backtest_steps = env.n_steps items.
    # So, use stock_data_original_index[:env.n_steps]

    plot_index = stock_data_original_index[:len(portfolio_values)-1] # Index for items that correspond to a step/action
    # This plot_index should match the length of rewards, shares_held (excluding initial), etc.

    plt.figure(figsize=(18, 16))

    plt.subplot(4, 1, 1)
    # portfolio_values[0] is initial balance. portfolio_values[1:] are after each step.
    plt.plot(plot_index, portfolio_values[1:], label='Agent Portfolio Value', color='blue')
    if not env.stock_data.empty and len(env.stock_data) > 1 and buy_hold_start_price > 0:
        # Align Buy & Hold with the actual period agent traded on
        buy_hold_plot_prices = env.stock_data['Close'].iloc[:len(plot_index)].values
        buy_hold_values = (env.initial_balance / buy_hold_start_price) * buy_hold_plot_prices
        plt.plot(plot_index, buy_hold_values, label=f'Buy & Hold ({TICKER})', color='grey', linestyle='--')
    plt.title(f'{TICKER} Agent Performance vs Buy & Hold (Backtest)')
    plt.ylabel('Portfolio Value ($)')
    plt.legend()
    plt.grid(True)

    plt.subplot(4, 1, 2)
    prices_to_plot = []
    if not env.stock_data.empty:
        prices_to_plot = env.stock_data['Close'].iloc[:len(plot_index)].values
        plt.plot(plot_index, prices_to_plot, label='Stock Price', color='black', alpha=0.7)

        buy_trade_details = [td for td in trade_details if td['type'] == 'BUY']
        sell_trade_details = [td for td in trade_details if td['type'] == 'SELL']

        if buy_trade_details:
            buy_steps_indices = [td['step'] for td in buy_trade_details if td['step'] < len(plot_index)]
            buy_prices_at_steps = [td['price'] for td in buy_trade_details if td['step'] < len(plot_index)]
            if buy_steps_indices: # Ensure valid indices exist
                 plt.scatter(plot_index[buy_steps_indices], buy_prices_at_steps, marker='^', color='green', label='Buy Signal', s=100, alpha=0.9, zorder=5)

        if sell_trade_details:
            sell_steps_indices = [td['step'] for td in sell_trade_details if td['step'] < len(plot_index)]
            sell_prices_at_steps = [td['price'] for td in sell_trade_details if td['step'] < len(plot_index)]
            if sell_steps_indices: # Ensure valid indices exist
                plt.scatter(plot_index[sell_steps_indices], sell_prices_at_steps, marker='v', color='red', label='Sell Signal', s=100, alpha=0.9, zorder=5)

    plt.ylabel('Stock Price ($)')
    plt.legend()
    plt.grid(True)

    plt.subplot(4, 1, 3)
    # all_shares_held[0] is initial. all_shares_held[1:] are after each step.
    plt.plot(plot_index, all_shares_held[1:], label='Shares Held', color='purple')
    plt.ylabel('Number of Shares')
    plt.legend()
    plt.grid(True)

    ax1 = plt.subplot(4, 1, 4)
    # all_rewards has length max_backtest_steps.
    if all_rewards:
        cumulative_rewards = np.cumsum(all_rewards)
        color = 'tab:green'
        ax1.set_xlabel('Time (Datetime)')
        ax1.set_ylabel('Cumulative Reward', color=color)
        ax1.plot(plot_index[:len(cumulative_rewards)], cumulative_rewards, color=color, label='Cumulative Reward')
        ax1.tick_params(axis='y', labelcolor=color)
        ax1.legend(loc='upper left')
        ax1.grid(True, axis='y')

        ax2 = ax1.twinx()
        portfolio_values_for_drawdown = np.array(portfolio_values[1:]) # Net worth after each step
        if len(portfolio_values_for_drawdown) > 0:
            peak = np.maximum.accumulate(portfolio_values_for_drawdown)
            drawdown = (portfolio_values_for_drawdown - peak) / peak * 100 # Percentage
            color = 'tab:red'
            ax2.set_ylabel('Drawdown (%)', color=color)
            ax2.plot(plot_index[:len(drawdown)], drawdown, color=color, alpha=0.6, label='Drawdown')
            ax2.tick_params(axis='y', labelcolor=color)
            ax2.fill_between(plot_index[:len(drawdown)], drawdown, 0, alpha=0.2, color=color)
            ax2.legend(loc='upper right')

    plt.title('Agent Financial Metrics')
    plt.tight_layout()
    plt.show()

    return final_net_worth, total_return_pct, trade_details


# --- Plot Training Metrics ---
def plot_training_metrics(metrics, ticker_symbol):
    # Filter out metrics that are empty or None before counting
    valid_metric_keys = [k for k, v in metrics.items() if v and (isinstance(v, list) and len(v) > 0)]
    num_metrics_to_plot = len(valid_metric_keys)

    if num_metrics_to_plot == 0:
        print("No valid training metrics to plot.")
        return

    fig, axs = plt.subplots(num_metrics_to_plot, 1, figsize=(12, num_metrics_to_plot * 3.5), sharex=True)
    if num_metrics_to_plot == 1: axs = [axs] # Ensure axs is always a list/array

    plot_idx = 0

    def plot_metric(data, label, color, ylabel, rolling_window=100, rolling_label_suffix="Avg"):
        nonlocal plot_idx
        if data and any(d is not None for d in data): # Check for non-None items
            valid_data_points = [(i,val) for i,val in enumerate(data) if val is not None]
            if not valid_data_points: return # Skip if no valid points

            indices, values = zip(*valid_data_points)
            axs[plot_idx].plot(indices, values, label=label, color=color, alpha=0.7)

            # Calculate rolling mean only on valid (non-None) numeric data
            series = pd.Series(values, index=indices)
            if len(series) >= rolling_window:
                 axs[plot_idx].plot(series.rolling(window=rolling_window).mean(), label=f'Rolling {rolling_label_suffix} ({rolling_window} eps)', color=plt.cm.Oranges(0.6))

            axs[plot_idx].set_ylabel(ylabel)
            axs[plot_idx].legend()
            axs[plot_idx].grid(True)
            if plot_idx == 0: axs[plot_idx].set_title(f'Training Progress for {ticker_symbol}')
            plot_idx +=1

    plot_metric(metrics.get('rewards'), 'Episode Reward', 'blue', 'Total Reward', rolling_label_suffix="Reward")
    plot_metric(metrics.get('net_worths'), 'Episode End Net Worth', 'green', 'Net Worth ($)', rolling_label_suffix="Net Worth")
    plot_metric(metrics.get('losses'), 'Avg Episode Loss', 'red', 'MSE Loss', rolling_label_suffix="Loss")
    plot_metric(metrics.get('avg_q_values'), 'Avg Q-Value', 'purple', 'Avg Q-Value', rolling_label_suffix="Q-Value")
    plot_metric(metrics.get('epsilons'), 'Epsilon', 'cyan', 'Epsilon Value', rolling_window=1, rolling_label_suffix="Epsilon") # No rolling for epsilon typically
    plot_metric(metrics.get('num_trades'), 'Number of Trades', 'brown', 'Trades per Episode', rolling_label_suffix="Trades")

    if plot_idx > 0 : # If any plot was made
        axs[plot_idx-1].set_xlabel('Episode #')
    plt.tight_layout()
    plt.show()


# --- Main Execution ---
if __name__ == '__main__':
    try:
        # 1. Load and Prepare Data
        train_stock_df, train_features_scaled_df, \
        test_stock_df, test_features_scaled_df, \
        feature_scaler = load_and_prepare_data(
            TICKER, DATA_PERIOD_TOTAL, INTERVAL, N_SHORT_LAGS, TRAIN_DAYS, MINUTES_IN_DAY_EFFECTIVE
        )

        if train_features_scaled_df.empty:
            raise ValueError("Training features are empty after loading and preparation. Cannot proceed.")

        # 2. Create Training Environment
        train_env = TradingEnv(train_stock_df, train_features_scaled_df,
                                INITIAL_BALANCE, TRANSACTION_COST_PCT)

        # 3. Create Agent
        agent = DQNAgent(state_size=train_env.state_dim, action_size=train_env.action_space_n, seed=0)

        if train_env.state_dim == 0: # If env could not be properly initialized
            raise ValueError("Environment state dimension is 0. Agent cannot be trained. Check data loading and feature engineering.")

        # 4. Train the Agent
        training_metrics = train_dqn(
            agent, train_env,
            n_episodes=NUM_EPISODES,
            eps_start=EPS_START,
            eps_end=EPS_END,
            eps_decay=EPS_DECAY
        )

        # Plot training progress
        plot_training_metrics(training_metrics, TICKER)

        # Optional: Save the trained model
        model_save_path = f'{TICKER.replace("-","_")}_dqn_model_intraday.pth' # Sanitize filename
        torch.save(agent.qnetwork_local.state_dict(), model_save_path)
        print(f"Trained model saved to {model_save_path}")

        # 5. Backtest the Trained Agent
        if not test_features_scaled_df.empty and not test_stock_df.empty:
            backtest_env = TradingEnv(test_stock_df, test_features_scaled_df,
                                      INITIAL_BALANCE, TRANSACTION_COST_PCT)
            if backtest_env.state_dim > 0:
                final_worth, total_return, trade_history = backtest(agent, backtest_env, test_stock_df.index)
            else:
                print("\nSkipping backtesting as test environment state dimension is 0.")
        else:
            print("\nSkipping backtesting as test data is empty.")
            print("This can happen if TRAIN_DAYS covers all available data or if specified test period had no valid data after processing.")

    except ValueError as ve:
        print(f"ValueError: {ve}")
        import traceback
        traceback.print_exc()
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        import traceback
        traceback.print_exc()

Using device: cpu
Ticker: BTC-USD, Crypto: True, Effective Minutes in Day for Lag: 1440
Loading data for BTC-USD (7d, 1m)...
Original data shape: (7630, 7)
Data shape after initial essential column dropna: (7630, 7)
Adding time features...
Calculating technical indicators...
Data shape after TIs and NaN drop: (7597, 17)
Creating lagged features (short, 15m, 1h, 1d-approx based on 1440 min/day)...
Data shape after feature engineering and final NaN drop: (6157, 108)
Dropped 1440 rows due to NaNs from TIs/lags.
Splitting data: 5 days for training, rest for testing...
Training data shape: (5955, 17), Training features shape: (5955, 108)
Test data shape: (202, 17), Test features shape: (202, 108)
Scaling features (fitting on training data only)...

Starting Training for 500 episodes... Max steps per episode: 5954
Episode 20/500	Avg Reward (100): -4196.28	Reward: -3790.63	Net Worth: 6209.37	Eps: 0.9046	Trades: 466
Episode 40/500	Avg Reward (100): -4121.36	Reward: -3945.09	Net Worth: 6054.91	