<a href="https://colab.research.google.com/github/AlexKitipov/VFX-0251-R.ipynb/blob/main/VFX_0251_R.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
# Сравнение на AI Платформи

#Ето преглед на някои от водещите облачни платформи за изкуствен интелект:

#🧠 **1. Google Cloud AI Platform**
#- Предлага AutoML, TensorFlow, PyTorch и JAX среди.
#- Можете да качите свои данни и да обучите модели директно в облака.
#- Има визуален интерфейс за настройка на хиперпараметри и мониторинг.
#- 🔗 [Официален сайт на Google Cloud AI](https://cloud.google.com/ai-platform)

#☁️ **2. Microsoft Azure AI**
#- Поддържа обучение на модели с ML Studio, AutoML и OpenAI API.
#- Има drag-and-drop интерфейс за начинаещи.
#- Може да се интегрира със Streamlit, Power BI и други.
#- 🔗 [Преглед на Azure AI и други платформи](https://azure.microsoft.com/en-us/overview/ai-platform/)

#🧪 **3. Amazon SageMaker**
#- Позволява обучение, тестване и деплой на модели в една среда.
#- Има готови Jupyter notebook шаблони.
#- Подходящо за NLP, CV и таблични данни.

#🧩 **4. IBM Watson Studio**
#- Силен фокус върху корпоративни приложения.
#- Може да обучава модели с AutoAI и визуални инструменти.
#- Поддържа Python, R и Scala.

In [7]:
# 📘 ValkyrieFX Core Engine — Стартов бележник
# Автор: Александар
# Версия: 1.0
# Дата: 2025-09-04
# Описание:
# Този бележник съдържа модулното ядро на ValkyrieFX — платформа за форекс търговия с помощта на RL агенти.
# Включва: обработка на данни, технически индикатори, симулационна среда, обучение на агент, Streamlit интерфейс.
# Подготвен за разширение с GUI, API и стратегии.

# ✅ Инсталиране на необходими библиотеки
!pip install pandas numpy yfinance ta gymnasium stable-baselines3 streamlit plotly --quiet


%%writefile agent_utils.py
import os
from stable_baselines3 import PPO, A2C, DQN
from stable_baselines3.common.vec_env import DummyVecEnv

# Function to create the RL agent
def create_agent(agent_name, env, agent_params=None):
    """
    Creates an instance of a Stable-Baselines3 RL agent.

    Each supported algorithm has a small set of default hyper-parameters;
    any key present in ``agent_params`` overrides the matching default.
    ``verbose`` is treated as one of those defaults (0), so callers may
    override it through ``agent_params`` as well.

    Args:
        agent_name (str): Name of the RL agent algorithm ('PPO', 'DQN', 'A2C').
        env (stable_baselines3.common.vec_env.VecEnv): The vectorized training environment.
        agent_params (dict, optional): Dictionary of agent-specific parameters. Defaults to None.

    Returns:
        stable_baselines3.common.base.BaseAlgorithm or None: The created agent instance,
        or None if the agent name is unknown or the constructor raises.
    """
    # Default hyper-parameters per algorithm (caller-supplied values win).
    default_params = {
        "PPO": {
            "learning_rate": 1e-4,
            "n_steps": 2048,
            "batch_size": 64,
            "n_epochs": 10,
            "gamma": 0.99,
            "gae_lambda": 0.95,
            "clip_range": 0.2,
        },
        "DQN": {
            "learning_rate": 1e-4,
            "buffer_size": 10000,
            "learning_starts": 100,
            "batch_size": 32,
            "gamma": 0.99,
            "train_freq": 1,
            "gradient_steps": 1,
        },
        "A2C": {
            "learning_rate": 7e-4,
            "n_steps": 5,
            "gamma": 0.99,
            "gae_lambda": 0.95,
            "vf_coef": 0.25,
            "ent_coef": 0.01,
        },
    }

    # Reject anything that is not one of the known algorithm names up front.
    if not isinstance(agent_name, str) or agent_name not in default_params:
        print(f"❌ Непознат агент: {agent_name}")
        return None

    try:
        # The class lookup stays inside the try so that any construction
        # problem is reported through the same error path.
        algorithm_classes = {"PPO": PPO, "DQN": DQN, "A2C": A2C}
        overrides = agent_params if agent_params else {}
        # "verbose" is merged like any other default; passing it as a fixed
        # keyword would clash with a caller-supplied "verbose" entry.
        final_params = {"verbose": 0, **default_params[agent_name], **overrides}
        model = algorithm_classes[agent_name]("MlpPolicy", env, **final_params)
        if agent_name == "DQN":
            print(f"Using DQN with MlpPolicy for {agent_name}. Ensure observation space is compatible with DQN's MlpPolicy or consider a different policy.")
    except Exception as e:
        # The exception text is already part of the message; print it once.
        print(f"🚫 Грешка при създаване на агента {agent_name}: {e}")
        return None

    return model

# Function to train the RL agent
def train_agent(agent, total_timesteps, progress_callback=None):
    """
    Trains the provided RL agent.

    Args:
        agent (stable_baselines3.common.base.BaseAlgorithm): The agent instance to train.
        total_timesteps (int): The total number of timesteps for training.
        progress_callback (optional): Forwarded unchanged as the ``callback``
            argument of ``agent.learn``. Defaults to None.
            NOTE(review): Stable-Baselines3 documents ``callback`` as a
            ``BaseCallback`` (or a list of them / a ``(locals, globals)``
            callable), so a plain ``(current_step, total_steps)`` function is
            presumably not usable here without an adapter - confirm with callers.

    Returns:
        stable_baselines3.common.base.BaseAlgorithm or None: The same agent
        object after training, or None if no agent was given or training raised.
    """
    if agent is None:
        print("🚫 Грешка: Агентът не е наличен за обучение.")
        return None

    try:
        print(f"\n🧠 Стартиране на обучението за {type(agent).__name__} агент за {total_timesteps} стъпки...")
        # Start the training process.
        agent.learn(total_timesteps=total_timesteps, callback=progress_callback)
    except Exception as e:
        # The exception text is already embedded in the message; print it once.
        print(f"🚫 Възникна грешка по време на обучението: {e}")
        return None

    print("✅ Обучението приключи.")
    return agent

# Function to save the trained agent
def save_agent(agent, path):
    """
    Saves the trained agent model.

    String paths are normalised to end in ``.zip`` before saving, using the
    same rule ``load_agent`` applies when it looks the file up (append
    ``.zip`` unless the path already ends with it, case-insensitively). The
    location printed on success is therefore the path handed to ``agent.save``.
    Non-string paths are passed through untouched.

    Args:
        agent (stable_baselines3.common.base.BaseAlgorithm): The agent instance to save.
        path (str): The file path to save the model to, with or without the .zip extension.

    Returns:
        None. Failures are reported on stdout rather than raised.
    """
    if agent is None:
        print("🚫 Няма обучен агент за запазване.")
        return

    # Avoid reporting (or producing) "model.zip.zip" when the caller already
    # supplied the extension.
    target = path
    if isinstance(path, str) and not path.lower().endswith(".zip"):
        target = path + ".zip"

    try:
        agent.save(target)
        print(f"✅ Агентът е запазен успешно в: {target}")
    except Exception as e:
        # The exception text is already embedded in the message; print it once.
        print(f"🚫 Възникна грешка при запазването на агента: {e}")

# Function to load a saved agent
def load_agent(path, env, agent_name="PPO"):
    """
    Loads a trained agent model.

    Args:
        path (str): The file path to load the model from, with or without the .zip extension.
        env (stable_baselines3.common.vec_env.VecEnv): The environment the agent was trained on or is compatible with.
        agent_name (str, optional): Algorithm the model was created with
            ('PPO', 'A2C' or 'DQN'). Selects which class's ``load`` is called.
            Defaults to 'PPO', which matches the previous hard-coded behaviour.

    Returns:
        stable_baselines3.common.base.BaseAlgorithm or None: The loaded agent instance,
        or None if the file is missing, the agent name is unknown, or loading raises.
    """
    full_path = path if path.lower().endswith('.zip') else path + ".zip"
    if not os.path.exists(full_path):
        print(f"⚠️ Запазен модел не е намерен на: {full_path}")
        return None

    try:
        # The algorithm type is not read back from the archive here, so the
        # caller's agent_name decides which class performs the load.
        algorithm_classes = {"PPO": PPO, "A2C": A2C, "DQN": DQN}
        algorithm_cls = algorithm_classes.get(agent_name)
        if algorithm_cls is None:
            print(f"❌ Непознат агент: {agent_name}")
            return None

        print(f"⚙️ Зареждане на агент от: {full_path}")
        loaded_agent = algorithm_cls.load(full_path, env=env)

        print("✅ Агентът е зареден успешно.")
        return loaded_agent
    except Exception as e:
        # The exception text is already embedded in the message; print it once.
        print(f"🚫 Грешка при зареждането на агента от {full_path}: {e}")
        return None

# Note: the algorithm type is not inferred from the saved file. Callers that
# trained an A2C or DQN agent must pass agent_name="A2C" / "DQN" to load_agent;
# omitting it loads the file as a PPO model.

In [8]:
!wget https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64 -O cloudflared
!chmod +x cloudflared


--2025-10-07 03:55:56--  https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64
Resolving github.com (github.com)... 20.27.177.113
Connecting to github.com (github.com)|20.27.177.113|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://github.com/cloudflare/cloudflared/releases/download/2025.9.1/cloudflared-linux-amd64 [following]
--2025-10-07 03:55:56--  https://github.com/cloudflare/cloudflared/releases/download/2025.9.1/cloudflared-linux-amd64
Reusing existing connection to github.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://release-assets.githubusercontent.com/github-production-release-asset/106867604/e30ab3bb-4e6a-464a-8db5-d5cabe6a2f8d?sp=r&sv=2018-11-09&sr=b&spr=https&se=2025-10-07T04%3A46%3A10Z&rscd=attachment%3B+filename%3Dcloudflared-linux-amd64&rsct=application%2Foctet-stream&skoid=96c2d410-5711-43a1-aedd-ab1947aa7ab0&sktid=398a6654-997b-47e9-b12b-9515b896b4de&skt=2025-10-07

In [9]:
# Mount Google Drive at /content/drive so files under it can be read and written
# from this notebook. `google.colab` is presumably only importable inside a
# Colab runtime - confirm before running this cell elsewhere.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Task
Implement a checkpointing mechanism for the agent training process to save models at regular intervals and automatically resume training from the latest checkpoint upon restart. The checkpoints should be saved in the "/content/drive/MyDrive/" directory.

In [10]:
# Make sure a local 'env' folder is present, reporting whether it had to be created.
import os

if os.path.exists('env'):
    print("Directory 'env' already exists.")
else:
    os.makedirs('env')
    print("Created directory: env")

Created directory: env


In [11]:
%%writefile env/forex_env.py
import gymnasium as gym
from gymnasium import spaces
import pandas as pd
import numpy as np # Added import for numpy
import ta # Import ta for technical indicators
from collections import deque # For tracking open positions
import time # For timestamp in trades log

class ForexTradingEnv(gym.Env):
    """
    Custom Gymnasium Environment for Forex Trading.
    The environment interacts with a pandas DataFrame containing historical price data
    and technical indicators.

    Observation Space:
    - Concatenation of:
        - Lookback window of 'close' prices (normalized).
        - Lookback window of selected technical indicators (normalized).
        - Current account balance (normalized).
        - Current open position size (normalized).
        - Current average entry price (normalized).

    Action Space: Discrete(3)
    - 0: Hold (do nothing)
    - 1: Buy (enter a long position or increase existing long position)
    - 2: Sell (exit the open long position; short positions are not currently supported)

    Reward:
    - Change in portfolio value (cash + unrealized PnL), potentially with bonuses/penalties.

    Termination:
    - When max drawdown limit is reached.
    - When the end of the data is reached.
    """
    metadata = {'render_modes': ['human']} # Define render modes if needed


    def __init__(self, df, initial_amount=100000, lookback_window=20,
                 buy_cost_pct=0.001, sell_cost_pct=0.001, max_drawdown_limit_pct=0.10,
                 position_size_pct=0.1, stop_loss_pct=0.02, take_profit_pct=0.04,
                 trailing_sl_pct=0.005, lot_model='percent_of_capital',
                 tp_reward_bonus=0.01, sl_penalty=0.01, render_mode=None):
        """
        Set up the environment around a price/indicator DataFrame.

        Args:
            df (pd.DataFrame): Historical data; it is copied, so the caller's
                frame is not modified. Every column whose name is not in the
                OHLCV/bookkeeping exclusion list below is treated as an
                indicator and becomes part of the observation.
            initial_amount (float): Starting cash; also the initial net worth.
            lookback_window (int): Number of rows included in each observation.
            buy_cost_pct (float): Stored as a float; presumably the fee
                fraction on buys - it is applied outside this method.
            sell_cost_pct (float): Fee fraction charged on the value of a sale.
            max_drawdown_limit_pct (float): Drawdown limit used for termination.
            position_size_pct (float): Sizing for the 'percent_of_capital' lot model.
            stop_loss_pct (float): Stop-loss distance as a fraction of the entry price.
            take_profit_pct (float): Take-profit distance as a fraction of the entry price.
            trailing_sl_pct (float): Trailing stop-loss setting.
            lot_model (str): 'percent_of_capital' or 'volatility'. 'volatility'
                falls back to 'percent_of_capital' when df has no 'atr' column.
            tp_reward_bonus (float): Reward bonus for hitting take-profit.
            sl_penalty (float): Penalty for hitting stop-loss.
            render_mode (str, optional): Stored on the instance as
                ``self.render_mode`` (the Gymnasium convention).
        """
        super().__init__()

        # Previously accepted but silently dropped; keep it on the instance.
        self.render_mode = render_mode

        # --- Environment Parameters ---
        self.df = df.copy()  # Work on a copy so the caller's DataFrame stays untouched
        self.initial_amount = float(initial_amount)
        self.lookback_window = int(lookback_window)
        self.buy_cost_pct = float(buy_cost_pct)
        self.sell_cost_pct = float(sell_cost_pct)
        self.max_drawdown_limit_pct = float(max_drawdown_limit_pct)
        self.position_size_pct = float(position_size_pct)  # For percent_of_capital model
        self.stop_loss_pct = float(stop_loss_pct)
        self.take_profit_pct = float(take_profit_pct)
        self.trailing_sl_pct = float(trailing_sl_pct)  # For Trailing SL
        self.lot_model = lot_model  # 'percent_of_capital' or 'volatility'
        self.tp_reward_bonus = float(tp_reward_bonus)  # Reward bonus for hitting TP
        self.sl_penalty = float(sl_penalty)  # Penalty for hitting SL

        # --- Internal State ---
        self.current_step = self.lookback_window  # Start after the lookback window
        self.account_balance = self.initial_amount
        self.net_worth = self.initial_amount  # Cash + unrealized PnL
        self.max_net_worth = self.initial_amount  # Track for drawdown
        self.open_position_units = 0  # Units of base currency (e.g., EUR in EUR/USD)
        self.average_entry_price = 0  # Average price of the current open position
        self.trailing_stop_loss_price = 0  # For Trailing SL

        # --- Observation Space Definition ---
        # Every column that is not a standard OHLCV or bookkeeping column is
        # treated as a technical indicator (DataFrame column order is kept).
        exclude_cols = ['open', 'high', 'low', 'close', 'volume', 'date', 'original_index', 'sequential_index']
        self.indicator_cols = [col for col in self.df.columns if col not in exclude_cols]

        # Per timestep: 1 close price + one value per indicator column.
        # Plus 3 trailing scalars: balance, open units, average entry price.
        num_features_per_step = 1 + len(self.indicator_cols)
        self.observation_dim = self.lookback_window * num_features_per_step + 3

        # Unbounded Box: values vary widely and scaling is done when the
        # observation is built.
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(self.observation_dim,), dtype=np.float32)

        # --- Action Space Definition ---
        # 0: Hold, 1: Buy, 2: Sell (Exit long position)
        self.action_space = spaces.Discrete(3)

        # --- Backtesting/Logging ---
        self.trades = []  # List to store trade details
        self.portfolio_history = []  # List to store portfolio value at each step

        # --- Volatility Model Specifics ---
        # The 'volatility' lot model needs an 'atr' column to be present in df.
        if self.lot_model == 'volatility' and 'atr' not in self.df.columns:
            print("⚠️ Warning: 'atr' column not found in DataFrame but lot_model is 'volatility'. Falling back to 'percent_of_capital'.")
            self.lot_model = 'percent_of_capital'


    def _get_observation(self):
        """
        Generates the current observation for the agent.
        Includes lookback window of prices and indicators, and current state variables.
        Normalizes observations.
        """
        # Ensure we have enough data for the lookback window
        if self.current_step < self.lookback_window:
            # Not enough data for a full observation, return zeros or handle appropriately
            # Returning zeros might confuse the agent, but allows simulation to continue.
            # A better approach for real applications might be to pad or return None.
            # Given current_step starts at lookback_window, this check is mostly a safeguard.
            print(f"⚠️ Warning: Not enough data for lookback window at step {self.current_step}. Returning zero observation.")
            return np.zeros(self.observation_dim, dtype=np.float32)


        # Get the relevant historical data slice
        history_slice = self.df.iloc[self.current_step - self.lookback_window + 1 : self.current_step + 1]

        # Extract close prices and indicators for the lookback window
        close_prices = history_slice['close'].values
        indicators_data = history_slice[self.indicator_cols].values # Shape: (lookback_window, num_indicators)

        # Flatten and concatenate price and indicator data
        # Ensure all values are finite before flattening
        close_prices = np.nan_to_num(close_prices, nan=0.0, posinf=1e10, neginf=-1e10) # Handle potential non-finite values
        indicators_data = np.nan_to_num(indicators_data, nan=0.0, posinf=1e10, neginf=-1e10) # Handle potential non-finite values

        # Concatenate close prices and indicator data for each step in the lookback window
        # The structure should be [price_t-LW+1, ind1_t-LW+1, ind2_t-LW+1, ..., price_t, ind1_t, ind2_t, ...]
        # Need to stack price and indicators vertically for each timestep first
        price_and_indicators_stacked = np.hstack((close_prices.reshape(-1, 1), indicators_data)) # Shape: (lookback_window, 1 + num_indicators)
        price_indicator_features = price_and_indicators_stacked.flatten() # Flatten to 1D array


        # Get current state variables
        current_balance = self.account_balance
        current_position_units = self.open_position_units
        current_entry_price = self.average_entry_price

        # Combine all features
        raw_observation = np.concatenate([price_indicator_features,
                                          [current_balance, current_position_units, current_entry_price]])


        # --- Normalization (Simple example: Min-Max or Z-score could be better) ---
        # A simple normalization might involve scaling by initial amount for cash/position,
        # and by recent price ranges for prices/indicators.
        # For simplicity here, we'll use a basic scaling or leave as is, assuming the agent's network can handle it.
        # Proper normalization is crucial for RL performance.
        # Let's implement a basic normalization: scale prices/indicators by the last close price,
        # and scale balance/position by initial amount.

        last_close = close_prices[-1] if len(close_prices) > 0 and close_prices[-1] != 0 else 1.0 # Avoid division by zero
        if last_close == 0:
             # If last close is 0, scaling by it is impossible/meaningless.
             # Replace with a small value or handle as a special case.
             # For now, if last_close is 0, use 1.0 for scaling to avoid errors, but this might indicate data issues.
             last_close = 1.0


        # Scale prices and indicators by the last close price
        # Ensure price_indicator_features is not empty and has correct shape for reshaping
        if price_indicator_features.size > 0:
             # The price_indicator_features are already flattened in the desired order.
             # We need to iterate through the flattened array and scale price/indicator values
             # corresponding to each timestep in the lookback window by the *last* close price.
             # This is a simplification; ideally, you might scale each value by the price at its own timestep,
             # or use a more robust normalization like Z-score or scaling by recent volatility.
             # For now, scaling by the last close price as originally intended in the comment:
             normalized_price_indicator_features = price_indicator_features / last_close # Scale all features by the last close price


        else:
             normalized_price_indicator_features = price_indicator_features # If no features, keep empty


        # Scale balance and position by initial amount
        # Avoid division by zero if initial_amount is 0 (should not happen based on min_value in UI)
        initial_amount_safe = self.initial_amount if self.initial_amount > 0 else 1.0
        normalized_balance = current_balance / initial_amount_safe
        normalized_position_units = current_position_units / initial_amount_safe # Scaling units by initial amount might not be intuitive, maybe by max possible units? Let's keep it simple for now.
        normalized_entry_price = current_entry_price / last_close if last_close > 0 else 0.0 # Scale entry price by last close


        # Combine all normalized features
        normalized_observation = np.concatenate([normalized_price_indicator_features,
                                                 [normalized_balance, normalized_position_units, normalized_entry_price]])


        # Ensure the final observation has the correct shape
        if normalized_observation.shape[0] != self.observation_dim:
             # This is the error condition reported by the user.
             # Let's add more detailed logging here to help diagnose the mismatch.
             print(f"🚫 Error: Mismatch in observation dimension at step {self.current_step}.")
             print(f"  Expected dimension: {self.observation_dim}")
             print(f"  Actual dimension: {normalized_observation.shape[0]}")
             print(f"  Lookback Window: {self.lookback_window}")
             print(f"  Number of Indicator Columns: {len(self.indicator_cols)}")
             num_features_per_step_check = 1 + len(self.indicator_cols) # Re-calculate for logging
             print(f"  Number of features per step (1 + num_indicators): {num_features_per_step_check}")
             print(f"  Calculated expected dimension (LW * features_per_step + 3): {self.lookback_window} * {num_features_per_step_check} + 3 = {self.lookback_window * num_features_per_step_check + 3}")
             print(f"  Shape of price_indicator_features: {price_indicator_features.shape}")
             print(f"  Shape of state variables ([balance, units, entry_price]): {(3,)}")
             print(f"  Shape of concatenated raw_observation: {raw_observation.shape}")
             print(f"  Shape of normalized_observation: {normalized_observation.shape}")


             # Returning zero observation on error might hide the problem.
             # Let's raise an error or return None to make the issue clearer during debugging/training.
             # For now, returning zero observation as before, but with improved logging.
             return np.zeros(self.observation_dim, dtype=np.float32)


        # Ensure all values in the final observation are finite
        if not np.isfinite(normalized_observation).all():
             print(f"⚠️ Warning: Non-finite values detected in observation at step {self.current_step}. Replacing with zeros.")
             normalized_observation = np.nan_to_num(normalized_observation, nan=0.0, posinf=1e10, neginf=-1e10)


        return normalized_observation.astype(np.float32) # Ensure dtype is float32


    def reset(self, seed=None, options=None):
        """
        Put the environment back into its starting state for a new episode.

        Cash, net worth and peak net worth return to ``initial_amount``, any
        open position is cleared, the step pointer moves to
        ``lookback_window`` and both logs are emptied. A single 'Reset' entry
        is then written to ``portfolio_history``.

        Args:
            seed (int, optional): Passed on to the parent class's ``reset``.
            options (dict, optional): Accepted for API compatibility; not used here.

        Returns:
            tuple: ``(observation, info)``. If the first observation is
            malformed or contains non-finite values, an all-zero observation
            is returned and ``info`` carries a 'message' key saying so.
        """
        super().reset(seed=seed)  # Gymnasium compatibility

        # Start of the simulation: first index with a full lookback window behind it.
        self.current_step = self.lookback_window

        # Account back to its starting condition, no open position.
        starting_cash = self.initial_amount
        self.account_balance = starting_cash
        self.net_worth = starting_cash
        self.max_net_worth = starting_cash
        self.open_position_units = 0
        self.average_entry_price = 0
        self.trailing_stop_loss_price = 0

        # Empty the logs, then record the state that precedes the first step.
        self.trades = []
        self.portfolio_history = []
        initial_record = {
            'date': self._get_current_date(),
            'portfolio_value': self.net_worth,
            'account_balance': self.account_balance,
            'open_position_units': self.open_position_units,
            'average_entry_price': self.average_entry_price,
            'current_price': self._get_current_price(),
            'step': self.current_step,
            'action': 'Reset',
            'reward': 0,
            'trade_executed': False,
            'trade_type': None,
            'trade_pnl': 0,
        }
        self.portfolio_history.append(initial_record)

        observation = self._get_observation()

        # Condensed snapshot of the state before the first step().
        info = dict(
            account_balance=self.account_balance,
            net_worth=self.net_worth,
            open_position_units=self.open_position_units,
            average_entry_price=self.average_entry_price,
            current_price=self._get_current_price(),
            date=self._get_current_date(),
            step=self.current_step,
        )

        observation_ok = (
            observation is not None
            and isinstance(observation, np.ndarray)
            and observation.shape[0] == self.observation_dim
            and np.isfinite(observation).all()
        )
        if not observation_ok:
            # Fall back to a zero vector and an info dict that flags the problem.
            print("🚫 Error during reset: Invalid initial observation. Returning zero observation.")
            observation = np.zeros(self.observation_dim, dtype=np.float32)
            info = dict(
                account_balance=self.initial_amount,
                net_worth=self.initial_amount,
                open_position_units=0,
                average_entry_price=0,
                current_price=np.nan,
                date=self._get_current_date(),
                step=self.current_step,
                message='Invalid initial observation',
            )

        return observation, info


    def step(self, action):
        """
        Performs one step in the environment based on the chosen action.

        Args:
            action (int): The action chosen by the agent (0: Hold, 1: Buy, 2: Sell).

        Returns:
            tuple: (observation, reward, done, truncated, info)
            - observation (np.ndarray): The new observation.
            - reward (float): The reward received in this step.
            - done (bool): Whether the episode has ended (e.g., max drawdown reached, end of data).
            - truncated (bool): Whether the episode was truncated (e.g., time limit reached - not used here).
            - info (dict): Additional information.
        """
        # Ensure action is a valid integer
        if not isinstance(action, (int, np.integer)) or action not in [0, 1, 2]:
            print(f"⚠️ Warning: Invalid action received: {action}. Treating as Hold (0).")
            action = 0 # Default to Hold for invalid actions

        # Store previous net worth for reward calculation
        previous_net_worth = self.net_worth
        previous_account_balance = self.account_balance # Also track balance change if needed for reward


        # Get current market data (data point BEFORE the action is applied)
        # The action taken at step 't' uses observation from step 't', and affects the state at step 't+1'
        # The environment's current_step is the index of the data point being *processed*.
        # The observation generated by _get_observation() at step 't' uses data up to index 't'.
        # The step() function updates the state based on data at index 'self.current_step'.
        # So, current_price, current_date, current_atr should be from self.df.iloc[self.current_step]

        current_price_at_step = self._get_current_price() # Price at the data point BEFORE incrementing current_step
        current_date_at_step = self._get_current_date() # Date at the data point BEFORE incrementing current_step
        current_atr_at_step = self._get_current_atr() # ATR at the data point BEFORE incrementing current_step


        # Ensure current_price is valid
        if pd.isna(current_price_at_step) or current_price_at_step <= 0:
             print(f"⚠️ Warning: Invalid price ({current_price_at_step}) at step {self.current_step}. Cannot process action {action}. Ending episode.")
             # If price is invalid, we cannot trade. End the episode.
             # Increment current_step first before returning
             self.current_step += 1
             # Get observation for the *next* state (which might be invalid)
             observation = self._get_observation()
             reward = -self.initial_amount * 0.05 # Large penalty for invalid state
             done = True
             truncated = False
             info = {'account_balance': self.account_balance,
                     'net_worth': self.net_worth, # Log net_worth *before* invalid step? Or after? Let's log current state.
                     'open_position_units': self.open_position_units,
                     'average_entry_price': self.average_entry_price,
                     'current_price': current_price_at_step, # Log the invalid price
                     'date': current_date_at_step, # Log the date
                     'step': self.current_step - 1, # Log the step number where invalidity was detected
                     'action': self._action_to_text(action),
                     'reward': reward,
                     'message': f"Episode ended due to invalid price ({current_price_at_step})",
                     'portfolio_value': self.net_worth # Ensure portfolio_value is present
                    }
             return observation, reward, done, truncated, info


        # --- Process Action ---
        trade_executed = False # Flag to check if a trade happened in this step
        trade_type = None # 'buy', 'sell', 'sl', 'tp', 'trailing_sl'
        trade_price = None
        trade_units = 0
        trade_cost = 0
        trade_pnl = 0 # Profit/Loss for closed positions

        # Check for Stop Loss, Take Profit, Trailing Stop Loss first (applies if position is open)
        # These are evaluated based on the current_price_at_step
        if self.open_position_units > 0:
            # Calculate unrealized PnL *at this step's price*
            unrealized_pnl = (current_price_at_step - self.average_entry_price) * self.open_position_units

            # Check Take Profit
            if self.take_profit_pct > 0 and current_price_at_step >= self.average_entry_price * (1 + self.take_profit_pct):
                trade_type = 'tp'
                trade_price = self.average_entry_price * (1 + self.take_profit_pct) # Execute at TP price
                trade_units = self.open_position_units
                trade_cost = trade_units * trade_price * self.sell_cost_pct
                trade_pnl = unrealized_pnl - trade_cost # PnL includes cost
                self.account_balance += (trade_units * trade_price) - trade_cost # Add cash from selling
                self.open_position_units = 0
                self.average_entry_price = 0
                self.trailing_stop_loss_price = 0 # Reset Trailing SL
                trade_executed = True
                # print(f"TP Hit at step {self.current_step} ({current_date_at_step}): Price={trade_price:.5f}, PnL={trade_pnl:.2f}") # Debug print


            # Check Trailing Stop Loss (only if TP not hit)
            elif self.trailing_sl_pct > 0 and self.trailing_stop_loss_price > 0 and current_price_at_step <= self.trailing_stop_loss_price:
                 trade_type = 'trailing_sl'
                 trade_price = self.trailing_stop_loss_price # Execute at Trailing SL price
                 # Ensure execution price is not higher than current price if market gaps down
                 trade_price = min(trade_price, current_price_at_step) # Ensure realistic execution price on gap down

                 trade_units = self.open_position_units
                 trade_cost = trade_units * trade_price * self.sell_cost_pct
                 unrealized_pnl_at_sl = (trade_price - self.average_entry_price) * trade_units # PnL based on SL price
                 trade_pnl = unrealized_pnl_at_sl - trade_cost # PnL includes cost
                 self.account_balance += (trade_units * trade_price) - trade_cost # Add cash from selling
                 self.open_position_units = 0
                 self.average_entry_price = 0
                 self.trailing_stop_loss_price = 0 # Reset Trailing SL
                 trade_executed = True
                 # print(f"Trailing SL Hit at step {self.current_step} ({current_date_at_step}): Price={trade_price:.5f}, PnL={trade_pnl:.2f}") # Debug print

            # Check Stop Loss (only if TP and Trailing SL not hit)
            elif self.stop_loss_pct > 0 and current_price_at_step <= self.average_entry_price * (1 - self.stop_loss_pct):
                trade_type = 'sl'
                trade_price = self.average_entry_price * (1 - self.stop_loss_pct) # Execute at SL price
                 # Ensure execution price is not higher than current price if market gaps down
                trade_price = min(trade_price, current_price_at_step) # Ensure realistic execution price on gap down

                trade_units = self.open_position_units
                trade_cost = trade_units * trade_price * self.sell_cost_pct
                unrealized_pnl_at_sl = (trade_price - self.average_entry_price) * trade_units # PnL based on SL price
                trade_pnl = unrealized_pnl_at_sl - trade_cost # PnL includes cost
                self.account_balance += (trade_units * trade_price) - trade_cost # Add cash from selling
                self.open_position_units = 0
                self.average_entry_price = 0
                self.trailing_stop_loss_price = 0 # Reset Trailing SL
                trade_executed = True
                # print(f"SL Hit at step {self.current_step} ({current_date_at_step}): Price={trade_price:.5f}, PnL={trade_pnl:.2f}") # Debug print


            # If position is open and no automatic exit, update Trailing SL
            if self.open_position_units > 0 and not trade_executed and self.trailing_sl_pct > 0:
                 # Update Trailing SL: moves up if price makes a new high
                 current_trailing_sl = current_price_at_step * (1 - self.trailing_sl_pct)
                 if current_trailing_sl > self.trailing_stop_loss_price:
                      self.trailing_stop_loss_price = current_trailing_sl


        # If no automatic exit occurred, process the agent's action
        if not trade_executed:
            if action == 1: # Buy
                # Determine position size based on lot_model
                if self.lot_model == 'percent_of_capital':
                     # Buy using a percentage of current account balance
                     amount_to_spend = self.account_balance * self.position_size_pct
                     # Calculate how many units can be bought with the available amount
                     # Ensure current_price is not zero to avoid division by zero
                     if current_price_at_step > 0:
                          units_to_buy = amount_to_spend / current_price_at_step
                     else:
                          units_to_buy = 0
                          print(f"⚠️ Warning: Cannot calculate units to buy due to zero price at step {self.current_step}.")


                elif self.lot_model == 'volatility':
                     # Buy based on Volatility (e.g., Risk per trade / ATR)
                     # Assuming 'atr' column is available and valid
                     # Let's use a simplified ATR-based sizing: Units = (Account Balance * Risk %) / (ATR * Multiplier * Price per unit)
                     # Assuming Position Size Pct is interpreted as Risk % of Account Balance
                     atr_multiplier = 2 # Example multiplier for stop distance
                     if current_atr_at_step is not None and current_atr_at_step > 0 and self.position_size_pct > 0:
                           # Stop Loss Distance in price based on ATR
                           stop_distance_price = current_atr_at_step * atr_multiplier
                           if stop_distance_price > 0 and current_price_at_step > 0: # Also ensure price is valid
                                # Units = Risk Amount / Stop Loss Distance in Price
                                risk_amount = self.account_balance * self.position_size_pct
                                units_to_buy = risk_amount / stop_distance_price
                           else:
                                units_to_buy = 0 # Avoid division by zero or invalid calculation
                                print(f"⚠️ Warning: Cannot calculate units to buy (Volatility Model) due to invalid stop distance ({stop_distance_price}) or price ({current_price_at_step}) at step {self.current_step}.")
                     else:
                          units_to_buy = 0
                          print(f"⚠️ Warning: Cannot calculate units to buy (Volatility Model) due to invalid ATR ({current_atr_at_step}) or position_size_pct ({self.position_size_pct}) at step {self.current_step}.")


                else: # Fallback to percent_of_capital if lot_model is unknown or invalid
                    amount_to_spend = self.account_balance * self.position_size_pct
                    if current_price_at_step > 0:
                         units_to_buy = amount_to_spend / current_price_at_step
                    else:
                         units_to_buy = 0
                         print(f"⚠️ Warning: Cannot calculate units to buy (Default Model) due to zero price at step {self.current_step}.")


                # Ensure units_to_buy is non-negative and finite
                units_to_buy = max(0, np.nan_to_num(units_to_buy, nan=0.0, posinf=0.0, neginf=0.0))


                # Execute Buy if units_to_buy > 0 and enough balance
                cost = units_to_buy * current_price_at_step * (1 + self.buy_cost_pct) # Total cost includes fee
                if units_to_buy > 0 and self.account_balance >= cost:
                    # Calculate new average entry price
                    total_cost_of_existing_position = self.open_position_units * self.average_entry_price
                    new_total_units = self.open_position_units + units_to_buy
                    if new_total_units > 0:
                         # Weighted average entry price
                         self.average_entry_price = (total_cost_of_existing_position + (units_to_buy * current_price_at_step)) / new_total_units
                    else:
                         self.average_entry_price = 0 # Should not happen if units_to_buy > 0

                    self.account_balance -= cost
                    self.open_position_units += units_to_buy
                    trade_executed = True
                    trade_type = 'buy'
                    trade_price = current_price_at_step
                    trade_units = units_to_buy
                    trade_cost = cost # Store total cost including fee
                    trade_pnl = 0 # PnL is 0 at entry

                    # Set initial Trailing SL if enabled and this is a new position (or adding to existing)
                    if self.trailing_sl_pct > 0:
                         # When buying more, the trailing stop should be based on the *new* entry price
                         # or potentially adjusted based on the new total position.
                         # A simple approach: recalculate from the new average entry price and current price.
                         # Or, if adding to a winning position, maintain or raise the TSL.
                         # Let's use a simple method: update TSL based on current price and trailing_sl_pct IF it results in a higher TSL.
                         current_tsl_candidate = current_price_at_step * (1 - self.trailing_sl_pct)
                         if current_tsl_candidate > self.trailing_stop_loss_price:
                              self.trailing_stop_loss_price = current_tsl_candidate
                         # If it's a brand new position, set the TSL
                         if self.open_position_units == units_to_buy:
                              self.trailing_stop_loss_price = current_price_at_step * (1 - self.trailing_sl_pct)


                    # print(f"Buy executed at step {self.current_step} ({current_date_at_step}): Price={trade_price:.5f}, Units={trade_units:.2f}, Cost={cost:.2f}") # Debug print

                # else:
                    # print(f"Buy signal at step {self.current_step} ({current_date_at_step}), but no trade executed (units_to_buy={units_to_buy:.2f}, cost={cost:.2f}, balance={self.account_balance:.2f}).") # Debug print


            elif action == 2: # Sell (Exit long position)
                if self.open_position_units > 0:
                    # Sell the entire open position
                    amount_from_selling = self.open_position_units * current_price_at_step
                    cost = amount_from_selling * self.sell_cost_pct
                    amount_received = amount_from_selling - cost
                    unrealized_pnl = (current_price_at_step - self.average_entry_price) * self.open_position_units
                    trade_pnl = unrealized_pnl - cost # PnL includes cost

                    self.account_balance += amount_received
                    trade_units = self.open_position_units # Units being sold
                    self.open_position_units = 0
                    self.average_entry_price = 0
                    self.trailing_stop_loss_price = 0 # Reset Trailing SL
                    trade_executed = True
                    trade_type = 'sell' # Agent initiated sell
                    trade_price = current_price_at_step
                    trade_cost = cost # Store selling cost
                    # trade_pnl already calculated

                    # print(f"Sell executed at step {self.current_step} ({current_date_at_step}): Price={trade_price:.5f}, Units={trade_units:.2f}, PnL={trade_pnl:.2f}") # Debug print

                # else:
                    # print(f"Sell signal at step {self.current_step} ({current_date_at_step}), but no position to sell.") # Debug print


            # Action 0: Hold (do nothing, SL/TP/TSL check already handled)
            elif action == 0:
                 # print(f"Hold at step {self.current_step} ({current_date_at_step}).") # Debug print
                 pass # No action taken, SL/TP/TSL already checked before this block


        # --- Update Net Worth ---
        # Net worth is cash + value of open position (unrealized PnL) based on the price *at this step*
        unrealized_pnl_current = (current_price_at_step - self.average_entry_price) * self.open_position_units
        self.net_worth = self.account_balance + unrealized_pnl_current

        # Update max net worth for drawdown calculation
        self.max_net_worth = max(self.max_net_worth, self.net_worth)

        # --- Log Trade (if executed) ---
        # Log the trade using the details gathered during action processing for this step
        if trade_executed:
            self.trades.append({
                'date': current_date_at_step, # Date of the step when trade occurred
                'step': self.current_step, # Step number when trade occurred
                'type': 'entry' if trade_type == 'buy' else 'exit',
                'action': trade_type, # 'buy', 'sell', 'sl', 'tp', 'trailing_sl'
                'price': trade_price, # Price at which trade was executed
                'units': trade_units, # Units bought/sold
                'cost': trade_cost, # Transaction cost
                'pnl': trade_pnl if trade_type != 'buy' else 0, # PnL is for exit trades
                'account_balance': self.account_balance, # Balance AFTER trade
                'net_worth': self.net_worth # Net worth AFTER trade
            })


        # --- Calculate Reward ---
        # Reward is the change in net worth from the *previous* step to the *current* step
        # The net_worth calculated just above is the net worth *at the end of* the current step.
        # Previous net worth was stored at the beginning of step().
        reward = (self.net_worth - previous_net_worth) / self.initial_amount # Scale by initial capital

        # Add reward shaping bonuses/penalties based on the trade type that *executed* in this step
        if trade_executed:
             if trade_type == 'tp':
                  reward += self.tp_reward_bonus # Add bonus for hitting TP
                  # print(f"TP Bonus added: +{self.tp_reward_bonus * self.initial_amount:.2f}") # Debug print
             elif trade_type == 'sl' or trade_type == 'trailing_sl':
                  reward -= self.sl_penalty # Apply penalty for hitting SL/TSL
                  # print(f"SL/TSL Penalty applied: -{self.sl_penalty * self.initial_amount:.2f}") # Debug print


        # --- Check for Termination ---
        done = False
        truncated = False # Not using truncation for time limits in this env

        # Check max drawdown
        if self.max_net_worth > 0: # Avoid division by zero
             current_drawdown = (self.max_net_worth - self.net_worth) / self.max_net_worth
             if current_drawdown >= self.max_drawdown_limit_pct:
                 print(f"🚫 Max Drawdown Limit ({self.max_drawdown_limit_pct:.2%}) reached at step {self.current_step} ({current_date_at_step}). Net Worth: {self.net_worth:.2f}, Max Net Worth: {self.max_net_worth:.2f}")
                 # Optional: Add a penalty for hitting max drawdown - already included in reward calculation if net_worth drops significantly
                 # reward -= self.initial_amount * 0.10 # Decide if this penalty is separate or part of net_worth change
                 done = True


        # Increment current_step *after* processing the current step's data and calculating rewards/termination
        self.current_step += 1

        # Check if end of data is reached for the *next* step
        if self.current_step >= len(self.df):
            done = True
            # print(f"✅ End of data reached. Next step ({self.current_step}) is beyond data length ({len(self.df)}).")


        # --- Log Portfolio History ---
        # Log portfolio value *after* processing the step and updating net_worth,
        # but conceptually corresponds to the state *at the end of* the step just processed.
        # Use the date of the *processed* step.
        self.portfolio_history.append({'date': current_date_at_step, # Use date of the step just processed
                                        'portfolio_value': self.net_worth,
                                        'account_balance': self.account_balance, # Added for info
                                        'open_position_units': self.open_position_units, # Added for info
                                        'average_entry_price': self.average_entry_price, # Added for info
                                        'current_price': current_price_at_step, # Added for info
                                        'step': self.current_step -1 # Log the step number just processed
                                       })


        # --- Prepare next observation and info ---
        # Get observation for the *next* step (self.current_step is now the index for the next observation)
        observation = self._get_observation() # _get_observation uses self.current_step


        # Ensure observation is valid before returning
        if observation is None or not isinstance(observation, np.ndarray) or observation.shape[0] != self.observation_dim or not np.isfinite(observation).all():
             print(f"🚫 Error getting observation for next step ({self.current_step}): Invalid observation. Setting done=True.")
             observation = np.zeros(self.observation_dim, dtype=np.float32) # Return zero observation on error
             done = True # End episode if next observation is invalid


        # The info dictionary returned by step should reflect the state *after* the action and updates for the current step.
        # It should also contain information needed by the dashboard for logging/metrics.
        # Let's build the info dictionary here based on the state *after* the current step is processed.
        info = {'account_balance': self.account_balance,
                'net_worth': self.net_worth,
                'open_position_units': self.open_position_units,
                'average_entry_price': self.average_entry_price,
                'current_price': current_price_at_step, # Price of the step just processed
                'date': current_date_at_step, # Date of the step just processed
                'step': self.current_step - 1, # Step number just processed
                'action': self._action_to_text(action), # Log the action taken in this step
                'reward': reward, # Log the reward for this step
                'trade_executed': trade_executed, # Indicate if a trade happened
                'trade_type': trade_type, # Type of trade if executed
                'trade_pnl': trade_pnl, # PnL if position was closed
                'portfolio_value': self.net_worth # Explicitly include portfolio_value for the dashboard
               }


        return observation, reward, done, truncated, info


    def render(self, mode='human'):
        """
        No-op render hook.

        Nothing is drawn or printed by the environment itself for any value
        of ``mode`` (including the default 'human'); visualisation is left
        to the dashboard UI (e.g. a Plotly chart).

        Args:
            mode (str): Requested render mode. Accepted but has no effect.

        Returns:
            None
        """
        return None


    def close(self):
        """
        Release environment resources.

        This method holds no external resources, so it does nothing.

        Returns:
            None
        """
        return None


    # --- Helper methods ---
    def _get_current_price(self):
        """Gets the 'close' price for the current step."""
        # Ensure current_step is within bounds and 'close' column exists
        if self.current_step < len(self.df) and 'close' in self.df.columns:
             return float(self.df.iloc[self.current_step]['close']) # Ensure float type
        return np.nan # Return NaN if out of bounds or column missing

    def _get_current_date(self):
        """Gets the 'date' for the current step."""
        # Ensure current_step is within bounds and 'date' column exists or index is datetime
        if self.current_step < len(self.df):
             if 'date' in self.df.columns:
                  return self.df.iloc[self.current_step]['date']
             elif isinstance(self.df.index, pd.DatetimeIndex):
                  return self.df.index[self.current_step]
        return None # Return None if out of bounds or date not found

    def _get_current_atr(self):
        """Gets the 'atr' value for the current step, if available."""
        if self.current_step < len(self.df) and 'atr' in self.df.columns:
             # Ensure ATR is a valid number before returning
             atr_value = self.df.iloc[self.current_step]['atr']
             if pd.notna(atr_value) and atr_value is not None and np.isfinite(atr_value):
                 return float(atr_value) # Ensure float type
             return None # Return None if ATR is NaN, None, or Inf
        return None # Return None if out of bounds or 'atr' column missing


    def _action_to_text(self, action):
        """Converts action integer to descriptive text."""
        if action == 0: return 'Hold'
        if action == 1: return 'Buy'
        if action == 2: return 'Sell'
        return 'Unknown'


# Function to calculate performance metrics
def calculate_metrics(portfolio_values, trades_log, initial_amount):
    """
    Calculates various trading performance metrics.

    Portfolio-level metrics (return, CAGR, Sharpe, Sortino, max drawdown) are
    derived from ``portfolio_values``; trade-level metrics are derived from the
    'exit' entries of ``trades_log``. Any metric that ends up non-finite
    (NaN/inf) is reported as 0.

    Args:
        portfolio_values (pd.Series): Series of portfolio total values over time
            (indexed by date). Anything that is not a Series with at least two
            points yields zeroed portfolio metrics.
        trades_log (list of dict or None): List of executed trades. Only entries
            whose 'type' is 'exit' are counted; ``None`` is treated as empty.
        initial_amount (float): The initial capital. Non-positive values are
            replaced by 1.0 to keep the ratios defined.

    Returns:
        dict: A dictionary of calculated metrics.
    """
    # Trade statistics do not depend on the portfolio curve, so they are
    # computed once and shared by both the degenerate and the normal path.
    trade_metrics = _summarize_trades(trades_log)

    # Only a pandas Series with at least 2 data points can be analysed
    if not isinstance(portfolio_values, pd.Series) or len(portfolio_values) < 2:
        print("⚠️ Warning: Invalid or insufficient portfolio_values data for metrics calculation. Expected pandas Series with >= 2 points.")
        metrics = {
            'Total Return (%)': 0,
            'CAGR (%)': 0,
            'Sharpe Ratio': 0,
            'Sortino Ratio': 0,
            'Max Drawdown (%)': 0,
        }
        metrics.update(trade_metrics)
        return _zero_non_finite(metrics)

    metrics = {}

    # Ensure initial amount is positive to avoid division by zero
    initial_amount_safe = initial_amount if initial_amount > 0 else 1.0

    # Replace non-finite portfolio values so the calculations below stay defined
    finite_mask = np.isfinite(portfolio_values)
    if not finite_mask.all():
        print("⚠️ Warning: Non-finite values found in portfolio_values. Metrics might be inaccurate.")
        # Bounds must come from the finite values only: max()/min() of the raw
        # series would themselves be +/-inf whenever an infinity is present.
        finite_values = portfolio_values[finite_mask]
        if finite_values.empty:
            upper_bound = lower_bound = initial_amount_safe
        else:
            upper_bound = float(finite_values.max())
            lower_bound = float(finite_values.min())
        cleaned = np.nan_to_num(portfolio_values.to_numpy(dtype=float),
                                nan=initial_amount_safe, posinf=upper_bound, neginf=lower_bound)
        portfolio_series = pd.Series(cleaned, index=portfolio_values.index)  # Keep the original index
    else:
        portfolio_series = portfolio_values.copy()  # Work on a copy

    total_return = (portfolio_series.iloc[-1] - initial_amount_safe) / initial_amount_safe * 100
    metrics['Total Return (%)'] = round(total_return, 2)

    # CAGR (Compound Annual Growth Rate) - requires a DatetimeIndex
    if isinstance(portfolio_series.index, pd.DatetimeIndex) and len(portfolio_series.index) > 1:
        duration_years = (portfolio_series.index[-1] - portfolio_series.index[0]).days / 365.25
        if duration_years > 0 and initial_amount_safe > 0:
            if portfolio_series.iloc[-1] > 0:
                cagr = ((portfolio_series.iloc[-1] / initial_amount_safe) ** (1 / duration_years) - 1) * 100
                metrics['CAGR (%)'] = round(cagr, 2)
            else:
                metrics['CAGR (%)'] = -100  # Represents total loss
        else:
            metrics['CAGR (%)'] = 0  # Zero duration: growth rate is undefined
    else:
        metrics['CAGR (%)'] = 0  # Cannot calculate without dates

    # Period returns for the volatility/risk metrics
    if isinstance(portfolio_series.index, pd.DatetimeIndex):
        try:
            # Put the series on a daily grid (forward-filled) before taking returns.
            # NOTE(review): this forward-fills rather than taking each day's last
            # value - confirm that is the intended daily aggregation.
            daily_portfolio_values = portfolio_series.resample('D').ffill().dropna()
            if len(daily_portfolio_values) > 1:
                returns = daily_portfolio_values.pct_change().dropna()
            else:
                # Too few daily points: fall back to the original frequency
                returns = portfolio_series.pct_change().dropna()
                print("⚠️ Warning: Daily resampling failed or resulted in insufficient data. Using original frequency returns for Sharpe/Sortino.")
        except Exception as e:
            # Best-effort: any resampling problem falls back to raw period returns
            print(f"⚠️ Warning: Error during daily resampling: {e}. Using original frequency returns for Sharpe/Sortino.")
            returns = portfolio_series.pct_change().dropna()
    else:
        print("⚠️ Warning: Portfolio values do not have DatetimeIndex. Using raw period returns for Sharpe/Sortino.")
        returns = portfolio_series.pct_change().dropna()

    # Annualization factor; stays 1 when the frequency is unknown
    annualization_factor = 1
    if isinstance(returns.index, pd.DatetimeIndex):
        try:
            inferred_freq = pd.infer_freq(returns.index)
        except (ValueError, TypeError):
            # infer_freq refuses indexes with fewer than 3 timestamps; use the
            # frequency already attached to the index, if there is one.
            inferred_freq = returns.index.freqstr
        # Upper-case so both the legacy 'H' and the newer 'h' hourly alias match
        freq_code = inferred_freq.upper() if isinstance(inferred_freq, str) else None
        if freq_code in ('D', 'B'):  # Daily or Business Daily
            annualization_factor = np.sqrt(252)  # Trading days
        elif freq_code == 'H':  # Hourly
            annualization_factor = np.sqrt(252 * 24)  # Trading hours (approx)

    # Sharpe Ratio (risk-free rate assumed to be 0)
    if not returns.empty and returns.std() != 0:
        sharpe_ratio = returns.mean() / returns.std() * annualization_factor
        metrics['Sharpe Ratio'] = round(sharpe_ratio, 2)
    else:
        metrics['Sharpe Ratio'] = 0

    # Sortino Ratio (uses the deviation of negative returns only)
    downside_returns = returns[returns < 0]
    if not downside_returns.empty:
        downside_std = downside_returns.std()
        if downside_std != 0:
            sortino_ratio = (returns.mean() / downside_std) * annualization_factor
            metrics['Sortino Ratio'] = round(sortino_ratio, 2)
        else:
            metrics['Sortino Ratio'] = 0
    else:
        metrics['Sortino Ratio'] = 0  # No negative returns

    # Max Drawdown: largest relative drop from the running peak (reported as <= 0)
    peak = portfolio_series.expanding(min_periods=1).max()
    drawdown = (portfolio_series - peak) / peak
    max_drawdown = drawdown.min() * 100 if not drawdown.empty else 0
    metrics['Max Drawdown (%)'] = round(max_drawdown, 2)

    # Trade Analysis
    metrics.update(trade_metrics)

    return _zero_non_finite(metrics)


def _summarize_trades(trades_log):
    """Builds the trade-level metrics from the 'exit' entries of a trade log."""
    exit_trades = [t for t in (trades_log or []) if t.get('type') == 'exit']  # .get avoids KeyError
    total_trades = len(exit_trades)
    summary = {
        'Total Trades': total_trades,
        'Winning Trades': 0,
        'Losing Trades': 0,
        'Win Rate (%)': 0,
        'Average PnL per Trade': 0,
        'Average Winning Trade': 0,
        'Average Losing Trade': 0,
        'Profit Factor': 0,
    }
    if not exit_trades:
        return summary

    # Accept numpy scalars too: np.int64/np.float32 are not int/float subclasses
    numeric_types = (int, float, np.integer, np.floating)
    trade_pnls = [t.get('pnl', 0) for t in exit_trades if isinstance(t.get('pnl', 0), numeric_types)]
    winning_trades = [pnl for pnl in trade_pnls if pnl > 0]
    losing_trades = [pnl for pnl in trade_pnls if pnl < 0]

    summary['Winning Trades'] = len(winning_trades)
    summary['Losing Trades'] = len(losing_trades)
    summary['Win Rate (%)'] = round((len(winning_trades) / total_trades) * 100, 2)
    summary['Average PnL per Trade'] = round(sum(trade_pnls) / total_trades, 2)
    if winning_trades:
        summary['Average Winning Trade'] = round(sum(winning_trades) / len(winning_trades), 2)
    if losing_trades:
        # Kept as a negative average
        summary['Average Losing Trade'] = round(sum(losing_trades) / len(losing_trades), 2)

    total_winnings = sum(winning_trades)
    total_losses = abs(sum(losing_trades))  # Absolute value for the denominator
    if total_losses > 0:
        summary['Profit Factor'] = round(total_winnings / total_losses, 2)
    else:
        # No losses: cap at 100 when there were winnings, otherwise 0
        summary['Profit Factor'] = 100 if total_winnings > 0 else 0
    return summary


def _zero_non_finite(metrics):
    """Replaces every NaN/inf metric value with 0 (in place) and returns the dict."""
    for key, value in metrics.items():
        if not np.isfinite(value):
            metrics[key] = 0
    return metrics

# NOTE: A standalone "run a single experiment" function (for main.py or testing)
# used to live here. It was removed because the Streamlit dashboard does not use it
# and implements its own backtesting loop; keeping it would only cause confusion.

Writing env/forex_env.py


In [12]:
!pip install -q ta

In [13]:
%%writefile data_utils.py
import pandas as pd
import yfinance as yf
import ta
import numpy as np
import re # Import regex for cleaning column names

def download_forex_data(ticker="EURUSD=X", start_date="2015-01-01", end_date="2023-12-31", timeframe="1d"):
    """
    Downloads historical Forex data using yfinance and performs initial cleaning.

    Rows containing any missing value are dropped. A download that raises, or
    that yields no rows at all (before or after cleaning), is treated as a
    failure.

    Args:
        ticker (str): Ticker symbol for the Forex pair (default: "EURUSD=X").
        start_date (str): Start date for data download (YYYY-MM-DD).
        end_date (str): End date for data download (YYYY-MM-DD).
        timeframe (str): Data interval (e.g., "1d", "1h", "15m", "5m", "1m"). Default is "1d".

    Returns:
        pd.DataFrame or None: Downloaded and initially cleaned DataFrame, or None if download fails.
    """
    print(f"📊 Изтегляне на данни за {ticker} в таймфрейм {timeframe}...")
    try:
        # auto_adjust=True is passed so the returned prices are already adjusted
        data = yf.download(ticker, start=start_date, end=end_date, interval=timeframe, auto_adjust=True)
        # Drop rows with any missing values immediately after download
        if data is not None:
            data = data.dropna()
    except Exception as e:
        # Boundary with a network library: report the failure and signal it with None
        print(f"🚫 Грешка при изтегляне на данни за {ticker} ({timeframe}): {e}")
        return None

    # An empty result is a failed download, not a success: returning it would
    # break the documented "None on failure" contract for callers.
    if data is None or data.empty:
        print(f"🚫 Не са получени данни за {ticker} ({timeframe}).")
        return None

    print("✅ Данните са изтеглени и почистени от липсващи стойности.")
    return data

# === Helper Functions for Indicators/Patterns ===

# === Classic Pivot Points ===
def _with_nan_pivots(frame):
    """Return a copy of *frame* that has the five pivot columns, all NaN."""
    result = frame.copy()
    for name in ('pivot_P', 'pivot_R1', 'pivot_S1', 'pivot_R2', 'pivot_S2'):
        result[name] = np.nan
    return result


def calculate_classic_pivots(df):
    """
    Calculate Classic Pivot Points (P, R1, S1, R2, S2) from daily high/low/close.

    The rows are aggregated to calendar days using the DataFrame's own
    DatetimeIndex (whatever its name), the pivot levels of each day are shifted
    forward by one available day so that a row only ever sees levels derived
    from the previous day, and the levels are then forward-filled onto every
    row of the input.

    Args:
        df (pd.DataFrame): Price data with a DatetimeIndex and lowercase
            'open', 'high', 'low', 'close' columns.

    Returns:
        pd.DataFrame: A new DataFrame, sorted by index, with 'pivot_P',
        'pivot_R1', 'pivot_S1', 'pivot_R2' and 'pivot_S2' added. When the
        pivots cannot be calculated (non-datetime index, missing columns, no
        daily data, unexpected error) a copy of the input with those five
        columns set to NaN is returned instead. The input is never modified.
    """
    if not isinstance(df.index, pd.DatetimeIndex):
        print("⚠️ Warning: DataFrame index is not DatetimeIndex. Cannot calculate daily pivots.")
        return _with_nan_pivots(df)

    required_ohlc_daily = ['open', 'high', 'low', 'close']
    if not all(col in df.columns for col in required_ohlc_daily):
        print(f"🚫 Error aggregating daily data for pivots: Missing required OHLC columns. Found: {list(df.columns)}")
        return _with_nan_pivots(df)

    try:
        df_sorted = df.sort_index()

        # Rows without a timestamp cannot be assigned to a day.
        dated_rows = df_sorted[df_sorted.index.notna()]
        # dropna() removes calendar days that contain no bars (e.g. weekends).
        daily = dated_rows.resample('1D').agg({'high': 'max', 'low': 'min', 'close': 'last'}).dropna()

        if daily.empty:
            print("⚠️ Warning: Daily aggregated data is empty after dropna. Cannot calculate pivots.")
            return _with_nan_pivots(df)

        pivot = (daily['high'] + daily['low'] + daily['close']) / 3
        day_range = daily['high'] - daily['low']
        pivots_daily = pd.DataFrame({
            'pivot_P': pivot,
            'pivot_R1': 2 * pivot - daily['low'],
            'pivot_S1': 2 * pivot - daily['high'],
            'pivot_R2': pivot + day_range,
            'pivot_S2': pivot - day_range,
        })

        # Shift by one day so that today's pivots are based on yesterday's data.
        pivots_daily = pivots_daily.shift(1)

        # Daily levels are stamped at midnight; the union with the (possibly
        # intraday) input index followed by ffill() copies each day's levels
        # onto every bar of that day.
        combined_index = df_sorted.index.union(pivots_daily.index)
        pivots_reindexed = pivots_daily.reindex(combined_index).ffill()
        # A unique right-hand index guarantees the left join cannot multiply
        # input rows that share a timestamp.
        pivots_reindexed = pivots_reindexed[~pivots_reindexed.index.duplicated()]

        merged_df = df_sorted.join(pivots_reindexed, how='left')

        print("✅ Класически пивот пойнти изчислени.")
        return merged_df

    except Exception as e:
        print(f"🚫 Грешка при изчисляване на Класически пивот пойнти: {e}")
        return _with_nan_pivots(df)


# === Candlestick Patterns ===
# Patterns are detected with vectorised pandas comparisons on the OHLC columns.
# This function adds one binary column (0 or 1) per detected pattern.
def add_candlestick_patterns(df):
    """
    Add binary 'pattern_*' columns for common candlestick patterns.

    Detection is done with plain pandas comparisons between the current bar
    and the one or two bars before it, so it does not rely on any candlestick
    helper from the 'ta' package (which only ships indicator modules).

    Pattern rules used here (thresholds are the local constants below):
      * engulfing: opposite-coloured bar whose body covers the previous body.
      * hammer / hanging man: small body, long lower shadow, almost no upper
        shadow; 'hammer' after falling closes, 'hanging man' after rising ones.
      * morning / evening star: long bar, small-bodied bar, then an opposite
        bar closing beyond the midpoint of the first body.
      * three white soldiers / three black crows: three same-coloured bars,
        each opening inside the previous body and closing further out.
      * doji: body no larger than a tenth of the bar's range.
      * piercing line / dark cloud cover: opposite bar that opens beyond the
        previous close and closes past the previous body's midpoint without
        engulfing it.

    Args:
        df (pd.DataFrame): Price data with lowercase 'open', 'high', 'low',
            'close' columns.

    Returns:
        pd.DataFrame: A copy of the input with integer 0/1 'pattern_*' columns
        added. If the OHLC columns are missing or cannot be read as numbers,
        the copy is returned without any pattern column.
    """
    processed_data = df.copy()

    required_ohlc = ['open', 'high', 'low', 'close']
    if not all(col in processed_data.columns for col in required_ohlc):
        print(f"🚫 Error adding candlestick patterns: Missing required OHLC columns. Found: {list(processed_data.columns)}")
        return processed_data

    print("ℹ️ Добавяне на японски свещни патерни...")

    doji_body_ratio = 0.1      # max body / range for a doji
    star_body_ratio = 0.5      # max star body / first-bar body
    shadow_to_body = 2.0       # min lower shadow / body for a hammer shape
    small_shadow_ratio = 0.1   # max upper shadow / range for a hammer shape
    hammer_body_ratio = 1 / 3  # max body / range for a hammer shape
    trend_lookback = 3         # bars used to judge the trend before a bar

    try:
        op = processed_data['open'].astype(float)
        hi = processed_data['high'].astype(float)
        lo = processed_data['low'].astype(float)
        cl = processed_data['close'].astype(float)

        body = (cl - op).abs()
        bar_range = hi - lo
        body_top = op.where(op >= cl, cl)
        body_bottom = op.where(op <= cl, cl)
        upper_shadow = hi - body_top
        lower_shadow = body_bottom - lo
        bull = cl > op
        bear = cl < op

        # Values of the previous bar (suffix 1) and the bar before it (suffix 2).
        op1, cl1, body1 = op.shift(1), cl.shift(1), body.shift(1)
        op2, cl2, body2 = op.shift(2), cl.shift(2), body.shift(2)
        top1, bottom1 = body_top.shift(1), body_bottom.shift(1)
        # fill_value=False keeps these boolean; bars without history never match.
        bull1, bull2 = bull.shift(1, fill_value=False), bull.shift(2, fill_value=False)
        bear1, bear2 = bear.shift(1, fill_value=False), bear.shift(2, fill_value=False)
        mid1 = (op1 + cl1) / 2
        mid2 = (op2 + cl2) / 2

        # Trend leading into the current bar; comparisons with NaN are False.
        prior_down = cl1 < cl.shift(1 + trend_lookback)
        prior_up = cl1 > cl.shift(1 + trend_lookback)

        hammer_shape = (
            (bar_range > 0)
            & (body <= bar_range * hammer_body_ratio)
            & (lower_shadow >= shadow_to_body * body)
            & (upper_shadow <= small_shadow_ratio * bar_range)
        )
        small_star = body1 < star_body_ratio * body2

        masks = {
            # Bullish patterns
            'pattern_bullish_engulfing': bear1 & bull & (op <= cl1) & (cl >= op1) & (body > body1),
            'pattern_hammer': hammer_shape & prior_down,
            'pattern_morning_star': bear2 & small_star & (top1 < mid2) & bull & (cl > mid2),
            'pattern_three_white_soldiers': (
                bull & bull1 & bull2
                & (cl > cl1) & (cl1 > cl2)
                & (op >= op1) & (op <= cl1)
                & (op1 >= op2) & (op1 <= cl2)
            ),
            'pattern_doji': (bar_range > 0) & (body <= doji_body_ratio * bar_range),
            'pattern_piercing_line': bear1 & bull & (op < cl1) & (cl > mid1) & (cl < op1),
            # Bearish patterns
            'pattern_bearish_engulfing': bull1 & bear & (op >= cl1) & (cl <= op1) & (body > body1),
            'pattern_hanging_man': hammer_shape & prior_up,
            'pattern_evening_star': bull2 & small_star & (bottom1 > mid2) & bear & (cl < mid2),
            'pattern_three_black_crows': (
                bear & bear1 & bear2
                & (cl < cl1) & (cl1 < cl2)
                & (op <= op1) & (op >= cl1)
                & (op1 <= op2) & (op1 >= cl2)
            ),
            'pattern_dark_cloud_cover': bull1 & bear & (op > cl1) & (cl < mid1) & (cl > op1),
        }
    except (TypeError, ValueError) as e:
        # Raised before any column is written, so no partial results remain.
        print(f"🚫 Грешка при добавяне на японски свещни патерни: {e}")
        return processed_data

    for name, mask in masks.items():
        processed_data[name] = mask.astype(int)

    print("✅ Японски свещни патерни добавени.")
    return processed_data


def _normalize_price_columns(frame):
    """Flatten, clean, lowercase and translate the column names of *frame*.

    The frame is modified in place and also returned.
    """
    # yfinance often returns MultiIndex columns such as ('Close', 'EURUSD=X').
    if isinstance(frame.columns, pd.MultiIndex):
        print("ℹ️ Detected MultiIndex columns. Attempting to flatten.")
        first_level_names = {col[0] for col in frame.columns}
        flat_names = []
        for col in frame.columns:
            # Keep the first level ('Open', 'High', ...); fall back to the
            # joined tuple when that level is empty.
            name = col[0] if col[0] else '_'.join(str(part) for part in col)
            # 'Adj Close' stands in for 'close' only when there is no real
            # 'Close' column; otherwise two columns would both become 'close'.
            if name == 'Adj Close' and 'Close' not in first_level_names:
                name = 'close'
            flat_names.append(name)
        frame.columns = flat_names
        print(f"✅ Columns flattened: {list(frame.columns)}")

    # Ensure names are stripped strings, then remove emojis/special characters.
    frame.columns = [str(col).strip() for col in frame.columns]
    frame.columns = [re.sub(r'[^\w\s]', '', col) for col in frame.columns]
    print(f"✅ Special characters/emojis removed from column names: {list(frame.columns)}")

    frame.columns = frame.columns.str.lower()
    print(f"✅ Column names converted to lowercase: {list(frame.columns)}")

    # Bulgarian column names accepted as aliases of the English ones.
    translation_map = {
        'отвори': 'open',
        'връх': 'high',
        'дъно': 'low',
        'затвори': 'close',
        'обем': 'volume',
        'време': 'date'
    }
    columns_to_rename = {col: translation_map[col] for col in translation_map if col in frame.columns}
    if columns_to_rename:
        frame.rename(columns=columns_to_rename, inplace=True)
        print(f"✅ Bulgarian column names translated: {columns_to_rename}")
        print(f"✅ Columns after translation: {list(frame.columns)}")
    else:
        print("ℹ️ No Bulgarian columns found to translate.")

    return frame


def _prepare_date_index(frame):
    """Turn the available date information into a sorted DatetimeIndex named 'date'.

    Returns the prepared frame (which keeps a non-datetime index when no date
    information exists), or None when every row had an unparseable date.
    """
    has_date_col = 'date' in frame.columns

    # A DatetimeIndex without a 'date' column becomes the 'date' column.
    if isinstance(frame.index, pd.DatetimeIndex) and not has_date_col:
        print("ℹ️ Original index is DatetimeIndex and no 'date' column exists. Resetting index and renaming.")
        # reset_index() names the new column after the index, or 'index' when
        # the index has no name.
        index_col_name = frame.index.name if frame.index.name is not None else 'index'
        frame = frame.reset_index()
        if index_col_name in frame.columns:
            frame.rename(columns={index_col_name: 'date'}, inplace=True)
            has_date_col = True
            print(f"✅ Renamed index column '{index_col_name}' to 'date'.")
        else:
            print(f"⚠️ Warning: DatetimeIndex found but could not rename index column '{index_col_name}' to 'date' after reset.")

    # Fallback: without any date information keep the old index as a column.
    if (not has_date_col
            and not isinstance(frame.index, pd.DatetimeIndex)
            and 'original_index' not in frame.columns
            and 'level_0' not in frame.columns):
        print("ℹ️ No 'date' column and index is not DatetimeIndex. Creating 'original_index' column.")
        frame = frame.reset_index()
        if 'index' in frame.columns:
            frame.rename(columns={'index': 'original_index'}, inplace=True)
            print("ℹ️ Could not identify or convert original index to datetime. Proceeding with 'original_index' column.")
        elif 'level_0' in frame.columns:  # name used when 'index' is already taken
            frame.rename(columns={'level_0': 'original_index'}, inplace=True)
            print("ℹ️ Could not identify or convert original index to datetime. Proceeding with 'original_index' column.")

    if 'date' in frame.columns:
        try:
            frame['date'] = pd.to_datetime(frame['date'], errors='coerce')
            frame.dropna(subset=['date'], inplace=True)  # rows whose date failed to parse
            if frame.empty:
                print("🚫 All rows dropped after converting 'date' column to datetime and dropping NaT.")
                return None
            # The datetime index is needed for the daily pivot aggregation.
            frame = frame.set_index('date').sort_index()
            print("✅ 'date' column converted to datetime and set as index.")
        except Exception as e:
            print(f"🚫 Error converting 'date' column to datetime: {e}. Keeping original index/columns.")
            if 'date' in frame.columns and not isinstance(frame.index, pd.DatetimeIndex):
                # Only undo an index that looks like a default reset_index() product.
                if frame.index.name == 'level_0' or frame.index.name == 'index':
                    frame = frame.reset_index(drop=False)
                print("ℹ️ Date conversion failed, proceeding without DatetimeIndex.")

    return frame


def _add_standard_indicators(frame, atr_window):
    """Add SMA, RSI, MACD, Bollinger, EMA, CCI, ADX, Stochastic and ATR columns in place.

    Every indicator is computed independently: when one raises, its output
    columns are filled with NaN and the remaining indicators still run.
    """
    close, high, low = frame['close'], frame['high'], frame['low']
    window_length_sma = 20

    def sma():
        indicator = ta.trend.SMAIndicator(close=close, window=window_length_sma)
        return {'sma': indicator.sma_indicator()}

    def rsi():
        return {'rsi': ta.momentum.RSIIndicator(close=close).rsi()}

    def macd():
        indicator = ta.trend.MACD(close=close)
        return {'macd': indicator.macd(),
                'macd_signal': indicator.macd_signal(),
                'macd_diff': indicator.macd_diff()}

    def bollinger():
        bands = ta.volatility.BollingerBands(close=close)
        return {'bb_upper': bands.bollinger_hband(),
                'bb_lower': bands.bollinger_lband(),
                'bb_mavg': bands.bollinger_mavg()}

    def ema():
        return {'ema_12': ta.trend.EMAIndicator(close=close, window=12).ema_indicator(),
                'ema_26': ta.trend.EMAIndicator(close=close, window=26).ema_indicator()}

    def cci():
        indicator = ta.trend.CCIIndicator(high=high, low=low, close=close, window=20)
        return {'cci': indicator.cci()}

    def adx():
        indicator = ta.trend.ADXIndicator(high=high, low=low, close=close, window=14)
        return {'adx': indicator.adx(),
                'adx_pos': indicator.adx_pos(),
                'adx_neg': indicator.adx_neg()}

    def stochastic():
        indicator = ta.momentum.StochasticOscillator(high=high, low=low, close=close)
        return {'stoch_k': indicator.stoch(), 'stoch_d': indicator.stoch_signal()}

    def atr():
        # ATR feeds the volatility-based lot model, hence the configurable window.
        indicator = ta.volatility.AverageTrueRange(high=high, low=low, close=close, window=atr_window)
        return {'atr': indicator.average_true_range()}

    # (start message, success message, failure prefix, output columns, compute)
    steps = [
        (f"ℹ️ Изчисляване на SMA (window={window_length_sma})...", "✅ SMA изчислен.",
         "🚫 Грешка при изчисляване на SMA", ('sma',), sma),
        ("ℹ️ Изчисляване на RSI...", "✅ RSI изчислен.",
         "🚫 Грешка при изчисляване на RSI", ('rsi',), rsi),
        ("ℹ️ Изчисляване на MACD...", "✅ MACD изчислен.",
         "🚫 Грешка при изчисляване на MACD", ('macd', 'macd_signal', 'macd_diff'), macd),
        ("ℹ️ Изчисляване на Bollinger Bands...", "✅ Bollinger Bands изчислени.",
         "🚫 Грешка при изчисляване на Bollinger Bands", ('bb_upper', 'bb_lower', 'bb_mavg'), bollinger),
        ("ℹ️ Изчисляване на EMA...", "✅ EMA изчислени.",
         "🚫 Грешка при изчисляване на EMA", ('ema_12', 'ema_26'), ema),
        ("ℹ️ Изчисляване на CCI...", "✅ CCI изчислен.",
         "🚫 Грешка при изчисляване на CCI", ('cci',), cci),
        ("ℹ️ Изчисляване на ADX...", "✅ ADX изчислен.",
         "🚫 Could not calculate ADX indicator", ('adx', 'adx_pos', 'adx_neg'), adx),
        ("ℹ️ Изчисляване на Stochastic Oscillator...", "✅ Stochastic Oscillator изчислен.",
         "🚫 Грешка при изчисляване на Stochastic Oscillator", ('stoch_k', 'stoch_d'), stochastic),
        (f"ℹ️ Изчисляване на ATR (window={atr_window})...", "✅ ATR изчислен.",
         "🚫 Грешка при изчисляване на ATR", ('atr',), atr),
    ]

    for start_msg, done_msg, fail_prefix, output_cols, compute in steps:
        try:
            print(start_msg)
            for col, values in compute().items():
                frame[col] = values
            print(done_msg)
        except Exception as e:
            print(f"{fail_prefix}: {e}")
            for col in output_cols:
                frame[col] = np.nan


def add_technical_indicators(df, atr_window=14):
    """
    Add technical indicators, Classic Pivot Points and candlestick patterns.

    Column names are normalised first (MultiIndex flattened, special
    characters removed, lowercased, Bulgarian aliases translated), so the
    input may use 'Open'/'open'/'отвори' style names interchangeably.

    Args:
        df (pd.DataFrame): Input DataFrame with raw price data. After
            normalisation it must provide 'open', 'high', 'low', 'close' and
            'volume'.
        atr_window (int): Window size for ATR calculation. Default is 14.

    Returns:
        pd.DataFrame or None: A new DataFrame with the indicator, pivot and
        pattern columns added, rows containing NaN in any checked column
        removed, the date (when available) as a regular 'date' column, and
        columns ordered as base price columns followed by the remaining
        columns alphabetically. None is returned when required columns are
        missing, when no row survives, or when an unexpected error occurs.
    """
    processed_data = df.copy()

    print(f"ℹ️ Raw columns upon entering add_technical_indicators: {list(processed_data.columns)}")
    print(f"ℹ️ Raw column dtypes upon entering add_technical_indicators: {processed_data.dtypes}")
    print(f"ℹ️ Number of rows upon entering add_technical_indicators: {len(processed_data)}")

    processed_data = _normalize_price_columns(processed_data)

    required_cols = ['open', 'high', 'low', 'close', 'volume']
    missing_cols = [col for col in required_cols if col not in processed_data.columns]
    if missing_cols:
        print(f"🚫 Error in preprocessing: Missing required price columns after cleaning and translation. Missing: {missing_cols}, Found: {list(processed_data.columns)}")
        return None

    # The all-NaN case must be tested first: sum() skips NaN, so an all-NaN
    # column also sums to 0 and would be reported as "only zeros".
    volume = processed_data['volume']
    if volume.isnull().all():
        print("Warning: 'volume' column contains only NaN values.")
    elif volume.sum() == 0:
        print("Warning: 'volume' column contains only zeros. Technical indicators relying on volume might be affected.")

    processed_data = _prepare_date_index(processed_data)
    if processed_data is None:
        return None

    print(f"ℹ️ Данни преди изчисляване на индикатори. Първи 5 реда:")
    print(processed_data.head())
    print(f"ℹ️ Колони в данните: {list(processed_data.columns)}")
    print(f"ℹ️ Типове данни на колони: {processed_data.dtypes}")
    print(f"ℹ️ Брой редове преди изчисляване на индикатори: {len(processed_data)}")

    pivot_cols = ['pivot_P', 'pivot_R1', 'pivot_S1', 'pivot_R2', 'pivot_S2']

    try:
        _add_standard_indicators(processed_data, atr_window)

        # --- Classic Pivot Points (need a DatetimeIndex for daily aggregation) ---
        if isinstance(processed_data.index, pd.DatetimeIndex):
            print("ℹ️ Изчисляване на Класически пивот пойнти (изисква DatetimeIndex)...")
            processed_data = calculate_classic_pivots(processed_data)
        else:
            print("⚠️ Skipping Classic Pivot Point calculation as DataFrame index is not DatetimeIndex after processing.")
            # NOTE(review): these all-NaN columns are part of the dropna()
            # subset below, so on this path every row is removed and the
            # function returns None. Confirm whether that is intended.
            for col in pivot_cols:
                processed_data[col] = np.nan

        # --- Candlestick patterns (the callee announces its own start) ---
        if not processed_data.empty and all(col in processed_data.columns for col in ['open', 'high', 'low', 'close']):
            processed_data = add_candlestick_patterns(processed_data)
        else:
            print("⚠️ Skipping Candlestick Pattern calculation as DataFrame is empty or missing OHLC columns.")

        missing_pct = processed_data.isnull().mean() * 100
        print(f"ℹ️ Процент на липсващи стойности по колона след добавяне на индикатори и патерни:\n{missing_pct[missing_pct > 0]}")

        # Rows are dropped when any of these columns is NaN (typically the
        # warm-up rows of the rolling indicators). Pattern columns are not
        # listed because they are not part of this check.
        cols_to_check_for_nan = ['open', 'high', 'low', 'close', 'volume',
                                 'sma', 'rsi', 'macd', 'macd_signal', 'macd_diff',
                                 'bb_upper', 'bb_lower', 'bb_mavg',
                                 'ema_12', 'ema_26', 'cci', 'adx', 'adx_pos', 'adx_neg',
                                 'stoch_k', 'stoch_d', 'atr'] + pivot_cols
        cols_to_check_for_nan_present = [col for col in cols_to_check_for_nan if col in processed_data.columns]

        initial_rows = len(processed_data)
        if cols_to_check_for_nan_present:
            print(f"ℹ️ Проверка за NaNs в колони: {cols_to_check_for_nan_present}")
            processed_data = processed_data.dropna(subset=cols_to_check_for_nan_present)
            rows_after_dropna = len(processed_data)
            if initial_rows > rows_after_dropna:
                print(f"ℹ️ Dropped {initial_rows - rows_after_dropna} rows with NaN values after adding indicators and patterns.")
            else:
                print("ℹ️ No rows with NaNs found in checked columns after adding indicators and patterns.")
        else:
            print("ℹ️ No columns specified to check for NaNs after adding indicators and patterns.")

        if processed_data.empty:
            print("⚠️ Warning: All rows dropped after adding indicators/patterns and removing NaNs.")
            return None

        # Turn the date index back into a regular column for the environment.
        if isinstance(processed_data.index, pd.DatetimeIndex):
            print("ℹ️ Resetting DatetimeIndex to a regular column after dropping NaNs.")
            processed_data = processed_data.reset_index(drop=False)
            processed_data.rename(columns={'index': 'date'}, inplace=True)  # unnamed index case
            print(f"✅ Index reset. Columns now: {list(processed_data.columns)}")

        # Guarantee some row identifier when neither date nor original index exists.
        if 'date' not in processed_data.columns and 'original_index' not in processed_data.columns:
            print("ℹ️ No 'date' or 'original_index' column found. Adding 'sequential_index'.")
            processed_data['sequential_index'] = processed_data.index
            print("ℹ️ Re-added 'sequential_index' column as no 'date' or 'original_index' column is present.")

        processed_data.columns = processed_data.columns.astype(str)
        print("✅ Column names ensured to be strings.")

        # --- Reorder: base price columns, other columns alphabetically, row identifiers last ---
        base_cols = ['date', 'open', 'high', 'low', 'close', 'volume']
        identifier_cols = ['original_index', 'sequential_index']
        present_base_cols = [col for col in base_cols if col in processed_data.columns]
        indicator_cols = [col for col in processed_data.columns
                          if col not in present_base_cols and col not in identifier_cols]
        ordered_cols = present_base_cols + sorted(indicator_cols)
        ordered_cols += [col for col in identifier_cols if col in processed_data.columns]

        processed_data = processed_data[ordered_cols]
        print("✅ Columns reordered.")

        print(f"ℹ️ Данни след изчисляване на индикатори, пивоти и патерни и премахване на NaNs. Първи 5 реда:")
        print(processed_data.head())
        print(f"ℹ️ Колони в данните: {list(processed_data.columns)}")
        print(f"ℹ️ Типове данни на колони: {processed_data.dtypes}")
        print(f"ℹ️ Брой редове след обработка: {len(processed_data)}")

        return processed_data

    except Exception as e:
        print(f"🚫 An unexpected error occurred during technical indicator/pattern calculation: {e}")
        return None

def split_data(df, split_ratio=0.8):
    """
    Cut a DataFrame into a leading training part and a trailing testing part.

    Row order is kept (no shuffling). Both parts are copies whose index is
    reset to start at 0.

    Args:
        df (pd.DataFrame): The data to divide.
        split_ratio (float): Fraction of the rows assigned to the training part.

    Returns:
        tuple: (train_df, test_df). (None, None) is returned when df is None,
        when it is empty, or when slicing raises.
    """
    if df is None or df.empty:
        print("🚫 Cannot split data: Input DataFrame is not available or is empty.")
        return None, None

    try:
        # Position of the first row that belongs to the testing part.
        cut = int(len(df) * split_ratio)
        leading_rows, trailing_rows = df.iloc[:cut], df.iloc[cut:]
        train_data = leading_rows.copy().reset_index(drop=True)
        test_data = trailing_rows.copy().reset_index(drop=True)
        print(f"\n➡️ Data split into Training ({len(train_data)} rows) and Testing ({len(test_data)} rows) sets using {split_ratio*100:.0f}/{ (1-split_ratio)*100:.0f} percentage split.")
    except Exception as e:
        print(f"🚫 An error occurred during data splitting: {e}")
        return None, None

    return train_data, test_data

Writing data_utils.py


In [14]:
import pandas as pd
import numpy as np

# === 1. Зареждане на данни ===
def load_data(file_path):
    """
    Read an OHLCV CSV and return it indexed by its 'timestamp' column.

    Only the open/high/low/close/volume columns are kept, and any row with a
    missing value in one of them is dropped.
    """
    wanted_columns = ['open', 'high', 'low', 'close', 'volume']
    frame = pd.read_csv(file_path, parse_dates=['timestamp']).set_index('timestamp')
    return frame.loc[:, wanted_columns].dropna()

# === 2. Classic Pivot Points ===
def calculate_pivots(df):
    """
    Build one row of pivot levels per calendar day.

    The candles are resampled to '1D' bars (max high, min low, last close) and
    the levels are derived from each bar:
        P  = (high + low + close) / 3
        R1 = 2 * P - low        S1 = 2 * P - high
        R2 = P + (high - low)   S2 = P - (high - low)

    Returns:
        pd.DataFrame: columns P, R1, S1, R2, S2 on the daily index.
    """
    daily = df.resample('1D').agg({'high': 'max', 'low': 'min', 'close': 'last'})
    day_high, day_low, day_close = daily['high'], daily['low'], daily['close']

    pivot = (day_high + day_low + day_close) / 3
    day_range = day_high - day_low

    return pd.DataFrame({
        'P': pivot,
        'R1': 2 * pivot - day_low,
        'S1': 2 * pivot - day_high,
        'R2': pivot + day_range,
        'S2': pivot - day_range,
    })

# === 3. Японски свещни патерни ===
def detect_patterns(df):
    """
    Scan OHLC candles for a fixed set of candlestick patterns.

    Every bar from the third one onward is evaluated together with the two
    bars before it (bar 1 = two back, bar 2 = previous, bar 3 = current).

    The first five patterns are mutually exclusive and checked in this order:
    Bullish Engulfing, Bearish Engulfing, Hammer, Hanging Man, Doji. The
    remaining ones (Morning Star, Evening Star, Piercing Line, Dark Cloud
    Cover, Three White Soldiers) are checked independently, so one bar can
    produce several signals.

    Args:
        df (pd.DataFrame): Candles with 'open', 'high', 'low', 'close' columns.

    Returns:
        list: (index_label, pattern_name) tuples in bar order. Empty when
        there are fewer than three bars.
    """
    # Extract the columns once; per-row df.iloc lookups inside the loop are
    # very slow on long intraday series.
    opens = df['open'].to_numpy()
    highs = df['high'].to_numpy()
    lows = df['low'].to_numpy()
    closes = df['close'].to_numpy()
    labels = df.index

    signals = []

    for i in range(2, len(df)):
        o1, c1 = opens[i - 2], closes[i - 2]
        o2, h2, l2, c2 = opens[i - 1], highs[i - 1], lows[i - 1], closes[i - 1]
        o3, h3, l3, c3 = opens[i], highs[i], lows[i], closes[i]

        # Body size must be unsigned: with a signed difference the
        # "small body" test is trivially true for candles of one direction,
        # which labelled full-body candles as Hammer / Hanging Man.
        body3 = abs(c3 - o3)
        range3 = h3 - l3
        # range3 > 3 * body3 implies range3 > 0, so the divisions below are safe.
        small_body3 = range3 > 3 * body3

        # Bullish Engulfing
        if c2 < o2 and c3 > o3 and c3 > o2 and o3 < c2:
            signals.append((labels[i], 'Bullish Engulfing'))

        # Bearish Engulfing
        elif c2 > o2 and c3 < o3 and c3 < o2 and o3 > c2:
            signals.append((labels[i], 'Bearish Engulfing'))

        # Hammer: small body, close in the upper part of the range
        elif small_body3 and (c3 - l3) / range3 > 0.6:
            signals.append((labels[i], 'Hammer'))

        # Hanging Man (as defined here): small body, close in the lower part of the range
        elif small_body3 and (h3 - c3) / range3 > 0.6:
            signals.append((labels[i], 'Hanging Man'))

        # Doji
        elif body3 < 0.05 * range3:
            signals.append((labels[i], 'Doji'))

        # Body of the middle bar is tiny relative to its range (star candle).
        star2 = abs(c2 - o2) < 0.1 * (h2 - l2)

        # Morning Star
        if c1 < o1 and star2 and c3 > o3 and c3 > (c1 + o1) / 2:
            signals.append((labels[i], 'Morning Star'))

        # Evening Star
        if c1 > o1 and star2 and c3 < o3 and c3 < (c1 + o1) / 2:
            signals.append((labels[i], 'Evening Star'))

        # Piercing Line
        if c2 < o2 and c3 > o3 and c3 > (o2 + c2) / 2 and o3 < c2:
            signals.append((labels[i], 'Piercing Line'))

        # Dark Cloud Cover
        if c2 > o2 and c3 < o3 and c3 < (o2 + c2) / 2 and o3 > c2:
            signals.append((labels[i], 'Dark Cloud Cover'))

        # Three White Soldiers: three consecutive bullish bars
        if c1 > o1 and c2 > o2 and c3 > o3:
            signals.append((labels[i], 'Three White Soldiers'))

    return signals

# === 4. Обединение и логване ===
def run_module(file_path):
    """
    Load candles from a CSV, then print the latest pivot levels and the ten
    most recent candlestick-pattern hits.
    """
    candles = load_data(file_path)
    # Both analyses are computed before anything is printed.
    pivot_table = calculate_pivots(candles)
    hits = detect_patterns(candles)

    print("📊 Pivot Points:")
    print(pivot_table.tail())

    print("\n🕯️ Detected Patterns:")
    recent_hits = hits[-10:]
    for time, pattern in recent_hits:
        print(f"{time} → {pattern}")

# === Пример за стартиране ===
# run_module("EURUSD5.csv")


In [15]:
%%writefile agent_utils.py
import streamlit as st
import pandas as pd
import numpy as np
import os
from stable_baselines3 import PPO, A2C, DQN
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback
import gymnasium as gym
import glob

# Define a simple callback to update the progress bar in Streamlit
class ProgressCallback(BaseCallback):
    """
    Training callback that mirrors progress into Streamlit widgets.

    Args:
        progress_bar: Object with a ``progress(float)`` method (e.g. the
            result of ``st.progress``), or None to skip bar updates.
        status_text: Object with a ``text(str)`` method (e.g. ``st.empty()``),
            or None to skip status updates.
        verbose (int): Passed through to ``BaseCallback``.
    """
    def __init__(self, progress_bar=None, status_text=None, verbose=0):
        super().__init__(verbose)
        self.progress_bar = progress_bar
        self.status_text = status_text
        # Episode-length bookkeeping; only total_timesteps_in_episode is
        # filled in by this class (see _on_training_start).
        self.total_timesteps_in_episode = 0
        self.current_timesteps_in_episode = 0
        self._prev_timesteps = 0

    def _on_training_start(self) -> None:
        """Estimate the episode length as len(env.df) - env.lookback_window."""
        env = self.training_env
        if not hasattr(env, 'get_attr'):
            return
        try:
            df = env.get_attr('df', indices=0)[0]
            lookback_window = env.get_attr('lookback_window', indices=0)[0]
        except (AttributeError, IndexError):
            # The wrapped environment does not expose these attributes.
            # The estimate is informational only, so training must not fail.
            return
        if isinstance(df, pd.DataFrame) and hasattr(lookback_window, '__int__'):
            self.total_timesteps_in_episode = max(0, len(df) - lookback_window)

    def _on_step(self) -> bool:
        """Refresh the widgets; always returns True so training continues."""
        total = self.locals.get('total_timesteps')

        if self.progress_bar is not None and total is not None and total > 0:
            # Clamp to 1.0: num_timesteps can overshoot the requested total.
            self.progress_bar.progress(min(1.0, self.num_timesteps / total))

        if self.status_text is not None:
            if total is not None:
                self.status_text.text(f"Обучение в прогрес: {self.num_timesteps}/{total} стъпки")
            else:
                self.status_text.text(f"Обучение в прогрес: {self.num_timesteps} стъпки")

        return True


# Function to create an agent instance
# @st.cache_resource
def create_agent(agent_type, env, agent_params=None):
    """
    Build a Stable-Baselines3 agent ("MlpPolicy") of the requested kind.

    Per-algorithm default hyperparameters are merged with agent_params;
    values in agent_params win.

    Args:
        agent_type (str): 'PPO', 'A2C' or 'DQN'.
        env (gym.Env): Environment the agent is bound to (should be a VecEnv).
        agent_params (dict, optional): Hyperparameter overrides,
            {param1: value1, ...}. Defaults to None.

    Returns:
        The created agent, or None when the type is unknown, when DQN is
        requested for a non-Discrete action space, or when construction raises.
    """
    st.write(f"⚙️ Опит за създаване на агент тип: {agent_type}")
    overrides = agent_params if agent_params is not None else {}

    try:
        # Step 1: pick the algorithm class and its default hyperparameters.
        if agent_type == "PPO":
            algo_cls = PPO
            defaults = dict(
                learning_rate=1e-4,
                n_steps=2048,
                batch_size=64,
                n_epochs=10,
                gamma=0.99,
                gae_lambda=0.95,
                clip_range=0.2,
                verbose=0,
            )
            created_note = "✅ PPO агент създаден."
        elif agent_type == "DQN":
            algo_cls = DQN
            defaults = dict(
                learning_rate=1e-4,
                buffer_size=10000,
                learning_starts=100,
                batch_size=32,
                gamma=0.99,
                train_freq=1,
                gradient_steps=1,
                verbose=0,
            )
            created_note = "✅ DQN агент създаден."
        elif agent_type == "A2C":
            algo_cls = A2C
            defaults = dict(
                learning_rate=7e-4,
                n_steps=5,
                gamma=0.99,
                gae_lambda=0.95,
                vf_coef=0.25,
                ent_coef=0.01,
                verbose=0,
            )
            created_note = "✅ A2C агент създаден."
        else:
            st.error(f"❌ Непознат агент: {agent_type}")
            return None

        # Step 2: caller overrides take precedence over the defaults.
        merged_params = {**defaults, **overrides}

        # Step 3: DQN only works with a Discrete action space.
        if agent_type == "DQN" and not isinstance(env.action_space, gym.spaces.Discrete):
            st.error(f"🚫 DQN requires a Discrete action space, but the environment has {type(env.action_space)}. Cannot create DQN agent.")
            return None

        model = algo_cls("MlpPolicy", env, **merged_params)
        st.write(created_note)
    except Exception as e:
        st.error(f"🚫 Грешка при създаване на агента {agent_type}: {e}")
        st.exception(e)
        return None

    return model

# Function to train the agent
# @st.cache_resource
def train_agent(agent, total_timesteps=10000, progress_bar=None, status_text=None, save_dir="/content/drive/MyDrive/Colab_Models", checkpoint_dir="/content/drive/MyDrive/Colab_Checkpoints", save_freq=5000):
    """
    Trains the provided Stable-Baselines3 agent with checkpointing and saves
    the final model as ``<save_dir>/valkyrie_<AgentClass>_model``.

    Args:
        agent (stable_baselines3.common.base.BaseAlgorithm): The agent instance to train.
        total_timesteps (int): The total number of timesteps to train for.
        progress_bar (streamlit.delta_generator.DeltaGenerator, optional): Streamlit progress bar object. Defaults to None.
        status_text (streamlit.delta_generator.DeltaGenerator, optional): Streamlit text object for status updates. Defaults to None.
        save_dir (str): Directory to save the final model. Defaults to "/content/drive/MyDrive/Colab_Models".
        checkpoint_dir (str): Directory to save training checkpoints. Defaults to "/content/drive/MyDrive/Colab_Checkpoints".
        save_freq (int): Frequency (in timesteps) of saving checkpoints. Defaults to 5000.

    Returns:
        stable_baselines3.common.base.BaseAlgorithm: The trained agent instance, or None if
        the agent is invalid or training fails. A failed final save does not turn a
        successful training run into None; it is reported as a warning instead.
    """
    # Validate first, so the "starting training" banner is never shown for
    # an agent that cannot be trained.
    if agent is None:
        st.error("🚫 train_agent: Получен е невалиден (None) агент за обучение.")
        return None
    if not hasattr(agent, 'env') or agent.env is None:
        st.error("🚫 train_agent: Агентът не е свързан с валидна среда.")
        return None

    agent_type = type(agent).__name__
    model_name = f"valkyrie_{agent_type}_model"

    st.write(f"🧠 Стартиране на обучение за {agent_type} агент за {total_timesteps} стъпки...")

    try:
        progress_callback_instance = ProgressCallback(progress_bar=progress_bar, status_text=status_text)

        # The checkpoint directory must exist before CheckpointCallback writes to it.
        if not os.path.exists(checkpoint_dir):
            st.warning(f"Директорията за чекпойнти не съществува: {checkpoint_dir}. Опит за създаване...")
            os.makedirs(checkpoint_dir, exist_ok=True)
            st.write(f"Създадена директория за чекпойнти: {checkpoint_dir}")

        checkpoint_callback_instance = CheckpointCallback(
            save_freq=save_freq,
            save_path=checkpoint_dir,
            name_prefix=f"{agent_type}_checkpoint",
            save_replay_buffer=True,  # Only relevant for off-policy agents such as DQN
            save_vecnormalize=True,  # Only relevant when the env is wrapped in VecNormalize
        )

        st.write("🚀 Започва обучение...")
        agent.learn(total_timesteps=total_timesteps, callback=[progress_callback_instance, checkpoint_callback_instance])
        st.write("✅ Обучението на агента приключи.")

        final_save_path = os.path.join(save_dir, model_name)
        save_agent(agent, final_save_path)

        # save_agent reports its own errors and does not raise, so confirm the
        # archive is actually on disk before claiming it was saved.
        if os.path.exists(f"{final_save_path}.zip"):
            st.success(f"✅ Агентът е обучен успешно и запазен като: {final_save_path}.zip")
        else:
            st.warning(f"⚠️ Агентът е обучен, но файлът на модела не беше намерен: {final_save_path}.zip")
        return agent

    except Exception as e:
        st.error(f"🚫 Грешка при обучението: {e}")
        st.exception(e)
        return None


# Function to save the agent
def save_agent(agent, path):
    """
    Saves the trained Stable-Baselines3 agent to a file.

    Errors are reported through Streamlit and not re-raised; the function
    always returns None.

    Args:
        agent (stable_baselines3.common.base.BaseAlgorithm): The agent instance to save.
        path (str): Target path including the filename, without the .zip suffix.
            A bare filename (no directory part) saves into the current directory.
    """
    if agent is None:
        st.warning("⚠️ Няма агент за запазване.")
        return

    try:
        save_dir = os.path.dirname(path)
        # dirname() is '' for a bare filename; os.makedirs('') would raise,
        # so only create the directory when there is one to create.
        if save_dir and not os.path.exists(save_dir):
            st.warning(f"Директорията за запазване не съществува: {save_dir}. Опит за създаване...")
            os.makedirs(save_dir, exist_ok=True)
            st.write(f"Създадена директория за запазване: {save_dir}")

        agent.save(path)
        st.write(f"✅ Агентът е запазен успешно в: {path}.zip")
    except Exception as e:
        st.error(f"🚫 Грешка при запазване на агента в {path}.zip: {e}")
        st.exception(e)


# Function to load the agent by type, looking for checkpoints first
# @st.cache_resource
def load_agent_by_type(path, env, agent_type, checkpoint_dir="/content/drive/MyDrive/Colab_Checkpoints"):
    """
    Load a Stable-Baselines3 agent of the given type.

    Lookup order:
      1. The checkpoint in checkpoint_dir named ``<agent_type>_checkpoint_*.zip``
         with the greatest ``os.path.getctime`` value.
      2. The model file at ``path`` (".zip" is appended when missing). This is
         used when no checkpoint exists or loading the checkpoint raised.

    Args:
        path (str): Path to the saved agent file, with or without ".zip".
        env (gym.Env): Environment compatible with the agent (should be a VecEnv).
        agent_type (str): 'PPO', 'A2C' or 'DQN'.
        checkpoint_dir (str): Directory holding training checkpoints.

    Returns:
        The loaded agent, or None when nothing could be loaded, the type is
        unknown, or DQN is requested for a non-Discrete action space.
    """
    st.write(f"⚙️ Опит за зареждане на агент тип {agent_type}...")
    model = None

    # --- Stage 1: look for a checkpoint ---
    latest_checkpoint = None
    if checkpoint_dir and os.path.exists(checkpoint_dir):
        pattern = os.path.join(checkpoint_dir, f"{agent_type}_checkpoint_*.zip")
        matches = glob.glob(pattern)
        if matches:
            # "Latest" is decided by file ctime, not by the step count in the name.
            latest_checkpoint = max(matches, key=os.path.getctime)

    if latest_checkpoint is not None:
        st.info(f"✅ Намерен последен чекпойнт: {latest_checkpoint}. Опит за зареждане оттук.")
        try:
            if agent_type == "PPO":
                algo_cls = PPO
            elif agent_type == "DQN":
                if not isinstance(env.action_space, gym.spaces.Discrete):
                    st.error(f"🚫 DQN requires a Discrete action space, but the environment has {type(env.action_space)}. Cannot load DQN agent from checkpoint.")
                    return None
                algo_cls = DQN
            elif agent_type == "A2C":
                algo_cls = A2C
            else:
                st.error(f"❌ Непознат агент тип за зареждане от чекпойнт: {agent_type}")
                return None

            model = algo_cls.load(latest_checkpoint, env=env)
            st.success(f"✅ Агент тип {agent_type} успешно зареден от чекпойнт.")
            return model
        except Exception as e:
            # Report and fall through to the initial model file.
            st.error(f"🚫 Грешка при зареждане на агента от чекпойнт {latest_checkpoint}: {e}")
            st.exception(e)

    # --- Stage 2: fall back to the initial model file ---
    st.info(f"⚠️ Няма намерени чекпойнти в {checkpoint_dir} или зареждането се провали. Опит за зареждане от първоначалния път: {path}")
    try:
        # Bind full_path before anything can raise: the handler below prints it.
        full_path = path
        has_zip_suffix = full_path.lower().endswith('.zip')
        if not has_zip_suffix:
            full_path = f"{path}.zip"

        if not os.path.exists(full_path):
            st.warning(f"⚠️ Не е намерен файл на агент за зареждане на първоначалния път: {full_path}")
            return None

        if agent_type == "PPO":
            algo_cls = PPO
            loaded_note = "✅ PPO агент зареден от първоначалния път."
        elif agent_type == "DQN":
            if not isinstance(env.action_space, gym.spaces.Discrete):
                st.error(f"🚫 DQN requires a Discrete action space, but the environment has {type(env.action_space)}. Cannot load DQN agent from initial path.")
                return None
            algo_cls = DQN
            loaded_note = "✅ DQN агент зареден от първоначалния път."
        elif agent_type == "A2C":
            algo_cls = A2C
            loaded_note = "✅ A2C агент зареден от първоначалния път."
        else:
            st.error(f"❌ Непознат агент тип за зареждане от първоначалния път: {agent_type}")
            return None

        model = algo_cls.load(full_path, env=env)
        st.write(loaded_note)
        st.success(f"✅ Агент тип {agent_type} успешно зареден от първоначалния път.")
        return model

    except Exception as e:
        st.error(f"🚫 Грешка при зареждане на агента от първоначалния път {full_path}: {e}")
        st.exception(e)
        return None

Writing agent_utils.py


In [16]:
%%writefile forex_dashboard.py
import streamlit as st
import pandas as pd
import numpy as np
import os
import plotly.graph_objects as go # Import Plotly for interactive charts
from plotly.subplots import make_subplots # For multi-panel charts
import json # Import json for saving/loading config

# Импортиране на модулите на ядрото
from data_utils import download_forex_data, add_technical_indicators, split_data
# Trading environment and metrics helper (calculate_metrics is assumed to live in env/forex_env.py)
from env.forex_env import ForexTradingEnv, calculate_metrics # Assuming calculate_metrics is in forex_env.py
from agent_utils import create_agent, train_agent, save_agent, load_agent_by_type
from stable_baselines3.common.vec_env import DummyVecEnv
import glob # Import glob to find latest checkpoint
import yfinance as yf  # 📥 Импортиране на Yahoo Finance API


# --- Streamlit page configuration ---
st.set_page_config(layout="wide", page_title="ValkyrieFX Dashboard", page_icon="📊")

st.title("ValkyrieFX Trading Dashboard")

# --- Session-state initialisation ---
# Each key is seeded only when it is absent, so values set on earlier reruns survive.
_SESSION_DEFAULTS = (
    ('processed_data', None),
    ('raw_uploaded_data', None),
    ('train_data', None),
    ('test_data', None),
    ('trained_agent', None),
    ('trained_agent_name', None),
    ('loaded_agent', None),
    ('loaded_agent_name', None),
    ('backtesting_results', None),
    ('performance_metrics', None),
    ('trades_log', None),
    ('env_config', {}),
    ('agent_config', {}),
    ('last_checkpoint_path', None),  # path of the last loaded checkpoint
    ('data_timeframe', "1d"),        # currently selected timeframe
)
for _state_key, _state_default in _SESSION_DEFAULTS:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default


# --- Sidebar configuration ---
# Widgets in this section are drawn on every script run and select where the
# market data comes from.
st.sidebar.header("⚙️ Глобални Настройки")

# Data settings live in the sidebar.
st.sidebar.subheader("📈 Настройки на данни")

# Data source: download from Yahoo Finance, or use an uploaded CSV file.
data_source = st.sidebar.radio(
    "Избери източник на данни",
    ("Изтегли от Yahoo Finance", "Качи от CSV"),
    key='data_source_radio_sidebar'
)

# The uploader is shown for both sources; its value is only consulted in the CSV branch.
uploaded_file = st.sidebar.file_uploader("📂 Качи данни от CSV файл", type=["csv"], key='upload_csv_file_sidebar')

raw_data_to_process = None # Placeholder for raw data; not assigned anywhere else in this section

# Global timeframe selector, stored in session state under 'data_timeframe'.
# The condition lists both radio options, so it holds for either choice above.
# The Yahoo Finance download reads the timeframe back from session state.
if data_source in ["Изтегли от Yahoo Finance", "Качи от CSV"]:
     st.session_state['data_timeframe'] = st.sidebar.selectbox("⏱️ Таймфрейм", ["1d", "1h", "30m", "15m", "5m", "1m"], key="global_timeframe_selector")


# --- Data acquisition and processing ---
# The train/test split slider is created on every run. A widget created inside
# an `if st.sidebar.button(...)` branch exists only on the rerun triggered by
# the click, so the user could never move it and the split always used the
# default value.
split_ratio = st.sidebar.slider("Съотношение тренировъчни/тестови данни (%)", min_value=50, max_value=90, value=80, step=5, key='data_split_ratio_slider_sidebar') / 100.0


def _reset_processed_state():
    """Clear the processed/train/test frames kept in session state."""
    for _frame_key in ('processed_data', 'train_data', 'test_data'):
        st.session_state[_frame_key] = None


def _process_and_store(raw_frame, train_fraction, success_message):
    """
    Add indicators to raw_frame, split it and store the results in session state.

    Shared by the Yahoo Finance and CSV branches. Returns True on success;
    on failure the stored frames are cleared, an error is shown and False is returned.
    """
    with st.spinner("Добавяне на технически индикатори..."):
        # The ATR window widget is declared further down the sidebar, so its
        # value is read from session state (14 until the widget exists).
        processed = add_technical_indicators(raw_frame, atr_window=st.session_state.get('env_atr_window_sidebar', 14))

    if processed is None or processed.empty:
        _reset_processed_state()
        st.sidebar.error("🚫 Неуспешно добавяне на индикатори или данните са празни след обработка.")
        return False

    with st.spinner("Разделяне на данни..."):
        train_part, test_part = split_data(processed, train_fraction)

    st.session_state['processed_data'] = processed
    st.session_state['train_data'] = train_part
    st.session_state['test_data'] = test_part
    st.sidebar.success(success_message)
    return True


# Yahoo Finance controls are shown only when that source is selected.
if data_source == "Изтегли от Yahoo Finance":
    data_ticker = st.sidebar.text_input("Тикер (напр. EURUSD=X)", value="EURUSD=X", key='data_ticker_input_sidebar')
    data_start_date = st.sidebar.date_input("Начална дата", value=pd.to_datetime("2015-01-01"), key='data_start_date_input_sidebar')
    data_end_date = st.sidebar.date_input("Крайна дата", value=pd.to_datetime("2023-12-31"), key='data_end_date_input_sidebar')

    if st.sidebar.button("⬇️ Изтегли и Обработи Данни", key='download_and_process_button_sidebar'):
        with st.spinner(f"Изтегляне и обработка на данни за {data_ticker} ({st.session_state['data_timeframe']})..."):
            # Dates are passed to the downloader as YYYY-MM-DD strings.
            start_date_str = data_start_date.strftime("%Y-%m-%d")
            end_date_str = data_end_date.strftime("%Y-%m-%d")

            # The timeframe comes from the global selector.
            raw_data = download_forex_data(data_ticker, start_date_str, end_date_str, st.session_state['data_timeframe'])
            if raw_data is not None and not raw_data.empty:
                _process_and_store(raw_data, split_ratio, "✅ Данните са изтеглени, обработени и разделени успешно.")
            else:
                _reset_processed_state()
                st.sidebar.error("🚫 Неуспешно изтегляне на данни. Моля, проверете тикера, датите или интернет връзката.")

elif data_source == "Качи от CSV":
    # The process button is offered only once a file has been uploaded.
    if uploaded_file is not None and st.sidebar.button("🛠️ Обработи Качени Данни", key='process_uploaded_button_sidebar'):
        try:
            # Rewind first: the upload buffer may already have been consumed
            # by an earlier click, which would make read_csv see an empty file.
            uploaded_file.seek(0)
            raw_data = pd.read_csv(uploaded_file)
        except Exception as e:
            st.session_state['raw_uploaded_data'] = None
            st.sidebar.error(f"🚫 Грешка при четене на CSV файла: {e}")
            raw_data = None
        else:
            st.session_state['raw_uploaded_data'] = raw_data.copy()
            st.sidebar.success(f"✅ Файлът `{uploaded_file.name}` е зареден и готов за обработка.")

            if not raw_data.empty:
                _process_and_store(raw_data, split_ratio, "✅ Данните са обработени и разделени успешно.")
            else:
                _reset_processed_state()
                st.sidebar.error("🚫 Качените данни са празни след четене.")

# Sidebar summary of the data currently held in session state.
# Nothing is shown until a processed frame exists.
_has_processed_data = st.session_state['processed_data'] is not None
if _has_processed_data:
    st.sidebar.subheader("Статус на данните")
    st.sidebar.write(f"Обработени данни: {len(st.session_state['processed_data'])} реда")
    if not (st.session_state['train_data'] is None):
        st.sidebar.write(f"Данни за обучение: {len(st.session_state['train_data'])} реда")
    if not (st.session_state['test_data'] is None):
        st.sidebar.write(f"Тестови данни: {len(st.session_state['test_data'])} реда")


# --- Environment parameters (sidebar) ---
# Each percentage widget below is entered in percent and divided by 100, so
# the resulting variable holds a fraction (e.g. 2.0 % -> 0.02).
st.sidebar.header("⚙️ Настройки на средата")
# Starting capital in dollars.
initial_amount = st.sidebar.number_input(
    "Начален капитал ($)", min_value=1000.0, max_value=10000000.0, value=100000.0, step=1000.0, format="%.2f", key='env_initial_amount_sidebar'
)
# Lookback window size (an integer between 1 and 200).
lookback_window = st.sidebar.number_input(
    "Lookback Window", min_value=1, max_value=200, value=20, step=1, key='env_lookback_window_sidebar'
)

# Risk and position settings.
st.sidebar.markdown("##### Настройки за риск и позиция")
stop_loss_pct = st.sidebar.number_input(
    "Стоп-лос (%)", min_value=0.0, max_value=100.0, value=2.0, step=0.1, format="%.2f", key='env_stop_loss_pct_sidebar'
) / 100.0
take_profit_pct = st.sidebar.number_input(
    "Тейк-профит (%)", min_value=0.0, max_value=100.0, value=4.0, step=0.1, format="%.2f", key='env_take_profit_pct_sidebar'
) / 100.0
# Maximum tolerated portfolio drawdown.
max_drawdown_limit_pct = st.sidebar.number_input(
    "Макс. допустим спад на портфейла (%)", min_value=0.01, max_value=100.0, value=10.0, step=0.1, format="%.2f", key='env_max_drawdown_limit_pct_sidebar'
) / 100.0

# Position sizing.
st.sidebar.markdown("##### Управление на размера на позицията")
lot_model = st.sidebar.selectbox("Модел за размер на позицията", ["percent_of_capital", "volatility"], key='env_lot_model_select_sidebar')
position_size_pct = st.sidebar.number_input(
    "Размер на позицията (% от капитала)", min_value=0.01, max_value=100.0, value=10.0, step=0.1, format="%.2f", key='env_position_size_pct_sidebar'
) / 100.0
# ATR window. The widget is always created (so the value always exists) but is
# greyed out unless the "volatility" sizing model is selected. The
# data-processing step reads it from session state through the widget key
# 'env_atr_window_sidebar'.
atr_window = st.sidebar.number_input(
    "ATR Window (за Volatility Model)", min_value=1, max_value=50, value=14, step=1, key='env_atr_window_sidebar', disabled=(lot_model != "volatility")
)


# Transaction costs for buying and selling.
st.sidebar.markdown("##### Настройки за такси")
buy_cost_pct = st.sidebar.number_input(
    "Такса при покупка (%)", min_value=0.0, max_value=1.0, value=0.1, step=0.001, format="%.3f", key='env_buy_cost_pct_sidebar'
) / 100.0
sell_cost_pct = st.sidebar.number_input(
    "Такса при продажба (%)", min_value=0.0, max_value=1.0, value=0.1, step=0.001, format="%.3f", key='env_sell_cost_pct_sidebar'
) / 100.0

# Reward function.
st.sidebar.markdown("##### Наградна функция")
# Reward-shaping controls: take-profit bonus and stop-loss penalty, both
# labelled as a percentage of the starting capital.
tp_reward_bonus_pct = st.sidebar.number_input(
    "TP Бонус (% от начален капитал)", min_value=0.0, max_value=100.0, value=1.0, step=0.1, format="%.2f", key='env_tp_bonus_sidebar'
) / 100.0
sl_penalty_pct = st.sidebar.number_input(
    "SL Наказание (% от начален капитал)", min_value=0.0, max_value=100.0, value=1.0, step=0.1, format="%.2f", key='env_sl_penalty_sidebar'
) / 100.0
# Trailing stop-loss control.
trailing_sl_pct = st.sidebar.number_input(
    "Трейлинг Стоп Лос (%)", min_value=0.0, max_value=10.0, value=0.5, step=0.05, format="%.2f", key='env_trailing_sl_pct_sidebar'
) / 100.0

# --- Agent parameters (sidebar) ---
# Algorithm choice and training length apply to every agent type; the PPO
# hyperparameter widgets below are only rendered (and the ppo_* names only
# defined) while "PPO" is the selected algorithm.
st.sidebar.header("🧠 Настройки на агент")
agent_type = st.sidebar.selectbox("Избери RL алгоритъм", ["PPO", "A2C", "DQN"], key='agent_algo_select_sidebar')
total_timesteps = st.sidebar.number_input("Общо стъпки за обучение", min_value=1000, max_value=1000000, value=50000, step=1000, key='agent_total_timesteps_input_sidebar')

if agent_type == "PPO":
    st.sidebar.subheader("PPO Параметри")
    # An explicit step is required here: st.number_input defaults to a step of
    # 0.01 for floats, which is as large as the whole allowed range and makes
    # the +/- buttons unusable for a learning rate.
    ppo_lr = st.sidebar.number_input("Learning Rate", min_value=1e-7, max_value=1e-2, value=1e-4, step=1e-5, format="%.7f", key='ppo_lr_sidebar')
    ppo_gamma = st.sidebar.number_input("Gamma", min_value=0.0, max_value=1.0, value=0.99, step=0.01, format="%.2f", key='ppo_gamma_sidebar')
    # Lower bound is 2: with a single environment, Stable-Baselines3 PPO asserts
    # that n_steps * n_envs > 1 and batch_size > 1 when advantage normalisation
    # is enabled (its default), so a value of 1 would fail at agent creation.
    ppo_n_steps = st.sidebar.number_input("N Steps", min_value=2, max_value=8192, value=2048, step=1, key='ppo_n_steps_sidebar')
    ppo_batch_size = st.sidebar.number_input("Batch Size", min_value=2, max_value=512, value=64, step=1, key='ppo_batch_size_sidebar')
    ppo_n_epochs = st.sidebar.number_input("N Epochs", min_value=1, max_value=20, value=10, step=1, key='ppo_n_epochs_sidebar')
    ppo_clip_range = st.sidebar.number_input("Clip Range", min_value=0.0, max_value=1.0, value=0.2, step=0.01, format="%.2f", key='ppo_clip_range_sidebar')
    ppo_gae_lambda = st.sidebar.number_input("GAE Lambda", min_value=0.0, max_value=1.0, value=0.95, step=0.01, format="%.2f", key='ppo_gae_lambda_sidebar')

# TODO: Add similar sections for A2C and DQN parameters if needed


# Checkpointing parameters (sidebar): where periodic checkpoints are written
# and how many training steps lie between two checkpoints.
st.sidebar.subheader("💾 Настройки на чекпойнтинг")
checkpoint_dir = st.sidebar.text_input("Директория за Чекпойнтинг", value="/content/drive/MyDrive/Colab_Checkpoints", key='checkpoint_dir_input_sidebar')
checkpoint_freq = st.sidebar.number_input("Честота на Чекпойнтинг (стъпки)", min_value=1000, value=5000, step=1000, key='checkpoint_freq_input_sidebar')

# Saved-model directory (sidebar).
st.sidebar.subheader("📦 Управление на Модели")
save_model_dir = st.sidebar.text_input("Директория за Запазени Агенти", value="/content/drive/MyDrive/Colab_Models", key='save_model_dir_input_sidebar')

# --- Check Google Drive access (global check) ---
# drive_access_available gates every save/load action in the tabs below.
if not os.path.exists("/content/drive/MyDrive"):
    st.error("🚫 НЯМА ДОСТЪП ДО GOOGLE DRIVE! Моля, свържете Google Drive, за да използвате функциите за запазване/зареждане.")
    drive_access_available = False
else:
    drive_access_available = True

    # Make sure the base save and checkpoint directories exist. Both paths are
    # free-text user input, so an empty or otherwise invalid value makes
    # os.makedirs raise OSError; report it and disable save/load instead of
    # crashing the whole app on every rerun.
    try:
        os.makedirs(save_model_dir, exist_ok=True)
        os.makedirs(checkpoint_dir, exist_ok=True)
    except OSError as dir_error:
        st.error(f"🚫 Неуспешно създаване на директориите за модели/чекпойнти: {dir_error}")
        drive_access_available = False


# --- Main content area: one tab per workflow (new training, resume, backtest) ---
_tab_titles = [
    "Създай и Обучи Нов Агент",
    "Продължи Обучение от Чекпойнт",
    "Бектест и Анализ",
]
tab1, tab2, tab3 = st.tabs(_tab_titles)

with tab1:
    # Tab 1: build a brand-new agent from the current sidebar settings,
    # train it, and write the used configuration next to the saved model.
    st.header("Създай и Обучи Нов Агент")
    st.markdown("""
        В този раздел можете да създадете изцяло нов агент с текущите настройки на средата и агента
        и да стартирате ново обучение. **Внимание:** Това ще игнорира всички съществуващи чекпойнти
        и предишно обучение за този тип агент.
    """)

    new_training_requested = st.button("🚀 Създай и Обучи Нов Агент", key='create_and_train_new_agent_button')
    if new_training_requested:
        if not drive_access_available:
            st.error("🚫 Необходим е достъп до Google Drive за запазване на модели и чекпойнти.")
        elif st.session_state.get("train_data") is None or st.session_state["train_data"].empty:
            st.warning("⚠️ Моля, заредете и обработете данни за обучение първо.")
        else:
            st.info(f"⏳ Създаване и стартиране на ново обучение за {agent_type} агент...")

            # Environment settings gathered from the sidebar. atr_window is
            # deliberately absent: per the original note it is consumed in
            # data_utils and not passed to the environment constructor.
            current_env_params = dict(
                initial_amount=initial_amount,
                lookback_window=lookback_window,
                buy_cost_pct=buy_cost_pct,
                sell_cost_pct=sell_cost_pct,
                max_drawdown_limit_pct=max_drawdown_limit_pct,
                position_size_pct=position_size_pct,
                stop_loss_pct=stop_loss_pct,
                take_profit_pct=take_profit_pct,
                trailing_sl_pct=trailing_sl_pct,
                lot_model=lot_model,
                tp_reward_bonus=tp_reward_bonus_pct,
                sl_penalty=sl_penalty_pct,
            )
            st.session_state['env_config'] = current_env_params

            # Agent hyperparameters: only PPO has sidebar controls so far,
            # every other algorithm gets an empty parameter dict.
            # TODO: collect A2C / DQN parameters once controls exist for them.
            if agent_type == "PPO":
                current_agent_params = dict(
                    learning_rate=ppo_lr,
                    gamma=ppo_gamma,
                    n_steps=ppo_n_steps,
                    batch_size=ppo_batch_size,
                    n_epochs=ppo_n_epochs,
                    clip_range=ppo_clip_range,
                    gae_lambda=ppo_gae_lambda,
                    verbose=0,
                )
            else:
                current_agent_params = {}
            st.session_state['agent_config'] = current_agent_params

            try:
                # Build the training environment from a copy of the train data.
                train_data_for_env = st.session_state.get('train_data')
                if train_data_for_env is None or train_data_for_env.empty:
                    st.error("🚫 Вътрешна грешка: Липсват данни за обучение след първоначалната проверка.")
                    st.stop()

                train_env_instance = ForexTradingEnv(df=train_data_for_env.copy(), **current_env_params)
                vec_train_env = DummyVecEnv([lambda: train_env_instance])

                # Agent creation has its own handler so a failure is reported
                # separately and halts the script run.
                agent = None
                try:
                    st.write("Опит за създаване на агента...")
                    agent = create_agent(agent_type, vec_train_env, agent_params=current_agent_params)

                    if agent is None:
                        st.error("🚫 Неуспешно създаване на агента: create_agent върна None.")
                        st.stop()
                    st.write("✅ Агентът е успешно създаден.")
                except Exception as creation_error:
                    st.error(f"🚫 Неуспешно създаване на агента: {creation_error}")
                    st.exception(creation_error)
                    agent = None
                    st.stop()

                if agent:
                    # Train with live progress widgets; both are removed again
                    # once train_agent returns.
                    progress_bar = st.progress(0)
                    status_text = st.empty()

                    trained_agent = train_agent(
                        agent,
                        total_timesteps=total_timesteps,
                        progress_bar=progress_bar,
                        status_text=status_text,
                        save_dir=save_model_dir,
                        checkpoint_dir=checkpoint_dir,
                        save_freq=checkpoint_freq,
                    )

                    progress_bar.empty()
                    status_text.empty()

                    if not trained_agent:
                        st.error("❌ Новото обучение на агента беше неуспешно.")
                        st.session_state['trained_agent'] = None
                    else:
                        st.session_state['trained_agent'] = trained_agent
                        st.session_state['trained_agent_name'] = agent_type

                        # Persist both configurations as JSON in the model directory.
                        agent_save_name = f"valkyrie_{agent_type.lower()}_model"
                        config_save_path_base = os.path.join(save_model_dir, agent_save_name)
                        for config_kind, config_payload in (
                            ("env", current_env_params),
                            ("agent", current_agent_params),
                        ):
                            with open(f"{config_save_path_base}_{config_kind}_config.json", 'w') as f:
                                json.dump(config_payload, f, indent=4)

                        st.success(f"🎉 Новото обучение на {agent_type} приключи успешно!")
                        st.info(f"Конфигурациите са запазени като `{agent_save_name}_env_config.json` и `{agent_save_name}_agent_config.json` в `{save_model_dir}`.")

            except Exception as e:
                # Anything raised while building the environment, training or
                # writing the config files ends up here.
                st.error(f"🚫 Възникна грешка преди или по време на създаване и обучение: {e}")
                st.exception(e)


with tab2:
    # Tab 2: list the checkpoints of the selected agent type and resume
    # training from whatever load_agent_by_type returns.
    st.header("Продължи Обучение от Чекпойнт")
    st.markdown("""
        В този раздел можете да заредите последния наличен чекпойнт за избрания тип агент
        и да продължите обучението оттам.
    """)

    # Overview of the checkpoint files that exist for the selected algorithm.
    st.subheader(f"Налични чекпойнти за {agent_type}:")
    if not drive_access_available:
        st.warning("🚫 Необходим е достъп до Google Drive за показване на чекпойнти.")
    else:
        checkpoint_pattern = os.path.join(checkpoint_dir, f"{agent_type}_checkpoint_*.zip")
        checkpoint_files = sorted(glob.glob(checkpoint_pattern))
        if not checkpoint_files:
            st.info(f"⚠️ Няма намерени чекпойнти за {agent_type} в `{checkpoint_dir}`.")
        else:
            st.write(f"Намерени {len(checkpoint_files)} чекпойнта в `{checkpoint_dir}`:")
            for cf in checkpoint_files:
                st.write(f"- {os.path.basename(cf)}")

    resume_requested = st.button("🔄 Продължи Обучение от Последен Чекпойнт", key='continue_training_button')
    if resume_requested:
        if not drive_access_available:
            st.error("🚫 Необходим е достъп до Google Drive за зареждане на чекпойнти.")
        elif st.session_state.get("train_data") is None or st.session_state["train_data"].empty:
            st.warning("⚠️ Моля, заредете и обработете данни за обучение първо.")
        else:
            st.info(f"⏳ Опит за зареждане на последен чекпойнт за {agent_type} и продължаване на обучението...")

            # The agent is loaded against a training environment built from
            # the same sidebar settings a fresh training run would use.
            current_env_params = dict(
                initial_amount=initial_amount,
                lookback_window=lookback_window,
                buy_cost_pct=buy_cost_pct,
                sell_cost_pct=sell_cost_pct,
                max_drawdown_limit_pct=max_drawdown_limit_pct,
                position_size_pct=position_size_pct,
                stop_loss_pct=stop_loss_pct,
                take_profit_pct=take_profit_pct,
                trailing_sl_pct=trailing_sl_pct,
                lot_model=lot_model,
                tp_reward_bonus=tp_reward_bonus_pct,
                sl_penalty=sl_penalty_pct,
            )
            st.session_state['env_config'] = current_env_params

            try:
                train_data_for_env = st.session_state.get('train_data')
                if train_data_for_env is None or train_data_for_env.empty:
                    st.error("🚫 Вътрешна грешка: Липсват данни за обучение след първоначалната проверка.")
                    st.stop()

                train_env_instance = ForexTradingEnv(df=train_data_for_env.copy(), **current_env_params)
                vec_train_env = DummyVecEnv([lambda: train_env_instance])

                # Per the original author's notes, load_agent_by_type looks in
                # checkpoint_dir first and treats `path` only as a fallback.
                loaded_agent = load_agent_by_type(
                    path=save_model_dir,
                    env=vec_train_env,
                    agent_type=agent_type,
                    checkpoint_dir=checkpoint_dir,
                )

                if loaded_agent is None:
                    st.warning(f"⚠️ Не бяха намерени чекпойнти за {agent_type} в `{checkpoint_dir}` или зареждането се провали. Моля, създайте нов агент в първия таб.")
                    st.session_state['trained_agent'] = None
                else:
                    # The resumed agent becomes the "trained" agent; any agent
                    # loaded manually for backtesting is dropped.
                    st.session_state['trained_agent'] = loaded_agent
                    st.session_state['trained_agent_name'] = agent_type
                    st.session_state['loaded_agent'] = None
                    st.session_state['loaded_agent_name'] = None
                    st.success(f"✅ {agent_type} агентът е успешно зареден от последния чекпойнт. Обучението ще продължи.")

                    # Continue training with live progress widgets.
                    progress_bar = st.progress(0)
                    status_text = st.empty()

                    trained_agent = train_agent(
                        st.session_state['trained_agent'],
                        total_timesteps=total_timesteps,
                        progress_bar=progress_bar,
                        status_text=status_text,
                        save_dir=save_model_dir,
                        checkpoint_dir=checkpoint_dir,
                        save_freq=checkpoint_freq,
                    )

                    progress_bar.empty()
                    status_text.empty()

                    if not trained_agent:
                        st.error("❌ Продължаването на обучението беше неуспешно.")
                        st.session_state['trained_agent'] = None
                    else:
                        st.session_state['trained_agent'] = trained_agent
                        st.success(f"🎉 Обучението на {agent_type} приключи успешно (продължено от чекпойнт)!")

            except Exception as e:
                st.error(f"🚫 Възникна грешка при зареждане от чекпойнт или продължаване на обучението: {e}")
                st.exception(e)
                # Leave no half-initialised agent behind.
                st.session_state['trained_agent'] = None


with tab3:
    st.header("Бектест и Анализ")
    st.markdown("""
        В този раздел можете да стартирате бектест с обучен или ръчно зареден агент
        и да анализирате резултатите.
    """)

    # --- Agent selection for backtesting ---
    # Two sources are supported: the agent trained in this session, or a
    # saved model loaded from a .zip file in save_model_dir.
    st.subheader("Избор на Агент за Бектест")

    agent_for_backtesting = None
    agent_name_for_backtesting = "Няма"

    # The checkbox is pre-ticked and enabled only while a trained agent exists.
    has_trained_agent = st.session_state.get('trained_agent') is not None
    use_trained_agent = st.checkbox(
        "Използвай текущо обучен агент",
        value=has_trained_agent,
        key='use_trained_agent_checkbox',
        disabled=not has_trained_agent,
    )

    if use_trained_agent and st.session_state.get('trained_agent') is not None:
        agent_for_backtesting = st.session_state['trained_agent']
        agent_name_for_backtesting = f"Текущо обучен ({st.session_state.get('trained_agent_name', 'Неизвестен')})"
        st.info(f"Избран агент за бектест: **{agent_name_for_backtesting}**")

    else:
        st.info("Няма избран текущо обучен агент. Можете да заредите агент от файл по-долу.")
        # NOTE: st.session_state['loaded_agent'] must NOT be reset here. This
        # branch runs on every script rerun, so resetting it wiped a freshly
        # loaded agent as soon as the user pressed any other button (e.g. the
        # backtest button), making file-loaded agents impossible to backtest.

        # Option to load an agent from a specific saved-model file.
        st.subheader("Зареди Агент от Файл за Бектест")
        available_saved_models = []
        if drive_access_available and os.path.exists(save_model_dir):
            model_files = glob.glob(os.path.join(save_model_dir, "*.zip"))
            # Keep real model archives only: skip '._*' files and checkpoints.
            available_saved_models = [
                os.path.basename(model_path)
                for model_path in model_files
                if not os.path.basename(model_path).startswith('._')
                and '_checkpoint_' not in os.path.basename(model_path)
            ]

        selected_model_file = st.selectbox("Избери запазен агент (.zip файл)", ["-- Избери файл --"] + available_saved_models, key='select_saved_model_file')

        if selected_model_file != "-- Избери файл --":
            # Guess the algorithm from the file name (PPO unless the name
            # mentions dqn / a2c); the user can correct the guess below.
            agent_type_options = ["PPO", "A2C", "DQN"]
            guessed_agent_type = "PPO"
            if "dqn" in selected_model_file.lower():
                guessed_agent_type = "DQN"
            elif "a2c" in selected_model_file.lower():
                guessed_agent_type = "A2C"
            loaded_agent_type_override = st.selectbox(
                f"Потвърди/Коригирай тип на агента за `{selected_model_file}`:",
                agent_type_options,
                index=agent_type_options.index(guessed_agent_type),
                key='loaded_agent_type_override',
            )

            if st.button(f"Зареди `{selected_model_file}`", key='load_agent_for_backtest_button'):
                if not drive_access_available:
                    st.error("🚫 Необходим е достъп до Google Drive за зареждане на агенти.")
                elif st.session_state.get("test_data") is None or st.session_state["test_data"].empty:
                    st.warning("⚠️ Необходими са тестови данни за създаване на среда, съвместима с агента.")
                else:
                    st.info(f"⚙️ Зареждане на агент тип {loaded_agent_type_override} от: `{selected_model_file}`...")
                    full_model_path = os.path.join(save_model_dir, selected_model_file)

                    # Temporary environment for loading, built from the test
                    # data and the current sidebar settings.
                    current_env_params_for_load = {
                        'initial_amount': initial_amount,
                        'lookback_window': lookback_window,
                        'buy_cost_pct': buy_cost_pct,
                        'sell_cost_pct': sell_cost_pct,
                        'max_drawdown_limit_pct': max_drawdown_limit_pct,
                        'position_size_pct': position_size_pct,
                        'stop_loss_pct': stop_loss_pct,
                        'take_profit_pct': take_profit_pct,
                        'trailing_sl_pct': trailing_sl_pct,
                        'lot_model': lot_model,
                        'tp_reward_bonus': tp_reward_bonus_pct,
                        'sl_penalty': sl_penalty_pct,
                    }

                    try:
                        test_data_for_load_env = st.session_state.get('test_data')
                        if test_data_for_load_env is None or test_data_for_load_env.empty:
                            st.error("🚫 Вътрешна грешка: Липсват тестови данни за създаване на среда за зареждане.")
                            st.stop()

                        temp_load_env_instance = ForexTradingEnv(df=test_data_for_load_env.copy(), **current_env_params_for_load)
                        temp_vec_env = DummyVecEnv([lambda: temp_load_env_instance])

                        # checkpoint_dir=None: load exactly the chosen file
                        # instead of searching for checkpoints.
                        loaded_agent_obj = load_agent_by_type(
                            path=full_model_path,
                            env=temp_vec_env,
                            agent_type=loaded_agent_type_override,
                            checkpoint_dir=None,
                        )

                        if loaded_agent_obj:
                            # The loaded agent replaces any in-session trained agent.
                            st.session_state['loaded_agent'] = loaded_agent_obj
                            st.session_state['loaded_agent_name'] = loaded_agent_type_override
                            st.session_state['trained_agent'] = None
                            st.session_state['trained_agent_name'] = None
                            st.success(f"✅ Агент тип {loaded_agent_type_override} е зареден успешно за бектест.")
                            agent_for_backtesting = st.session_state['loaded_agent']
                            agent_name_for_backtesting = f"Зареден ({st.session_state.get('loaded_agent_name', 'Неизвестен')})"
                            st.info(f"Избран агент за бектест: **{agent_name_for_backtesting}**")

                        else:
                            st.error("🚫 Неуспешно зареждане на агента от файл.")
                            st.session_state['loaded_agent'] = None
                            st.session_state['loaded_agent_name'] = None

                    except Exception as e:
                        st.error(f"🚫 Грешка при зареждане на агента от файл: {e}")
                        st.exception(e)
                        st.session_state['loaded_agent'] = None
                        st.session_state['loaded_agent_name'] = None

    # Final resolution after any load attempt: the trained agent wins when its
    # checkbox is ticked, otherwise a previously loaded agent is used.
    if use_trained_agent and st.session_state.get('trained_agent') is not None:
        agent_for_backtesting = st.session_state['trained_agent']
        agent_name_for_backtesting = f"Текущо обучен ({st.session_state.get('trained_agent_name', 'Неизвестен')})"
    elif st.session_state.get('loaded_agent') is not None:
        agent_for_backtesting = st.session_state['loaded_agent']
        agent_name_for_backtesting = f"Зареден ({st.session_state.get('loaded_agent_name', 'Неизвестен')})"
    else:
        agent_for_backtesting = None
        agent_name_for_backtesting = "Няма"

    st.write(f"Агент за бектестинг: **{agent_name_for_backtesting}**")


    # --- Start backtesting button ---
    # Runs the selected agent deterministically over the test data in a manual
    # step loop, collects the per-step info dicts and computes the metrics.
    if st.button("🔬 Стартирай Бектестинг", key='start_backtesting_button_tab3'):
        # .get() so a session without a 'test_data' entry yields the warning
        # below instead of a KeyError.
        test_data_for_backtest = st.session_state.get('test_data')
        if agent_for_backtesting is None:
            st.warning("⚠️ Няма наличен обучен или зареден агент за бектестинг.")
        elif test_data_for_backtest is None or test_data_for_backtest.empty:
            st.warning("⚠️ Липсват тестови данни за бектестинг. Моля, заредете и обработете данни първо.")
        else:
            st.info("⏳ Стартиране на бектестинг...")

            # Backtest environment settings: the same sidebar values that are
            # used for training.
            current_env_params_for_backtest = {
                'initial_amount': initial_amount,
                'lookback_window': lookback_window,
                'buy_cost_pct': buy_cost_pct,
                'sell_cost_pct': sell_cost_pct,
                'max_drawdown_limit_pct': max_drawdown_limit_pct,
                'position_size_pct': position_size_pct,
                'stop_loss_pct': stop_loss_pct,
                'take_profit_pct': take_profit_pct,
                'trailing_sl_pct': trailing_sl_pct,
                'lot_model': lot_model,
                'tp_reward_bonus': tp_reward_bonus_pct,
                'sl_penalty': sl_penalty_pct,
            }

            try:
                test_env_instance = ForexTradingEnv(df=test_data_for_backtest.copy(), **current_env_params_for_backtest)

                # --- Manual simulation loop ---
                st.write("🔄 Започва симулация на бектестинг...")
                obs, info = test_env_instance.reset()
                done = False
                truncated = False
                backtesting_results = []
                step_count = 0

                # Steps available = rows minus the start offset (lookback, or
                # the ATR window when the 'volatility' lot model needs more).
                # NOTE(review): 'atr_window' is never a key of the dict above,
                # so this always resolves to the fallback of 14 - confirm it
                # matches the window used when the data was processed.
                env_atr_window = current_env_params_for_backtest.get('atr_window', 14)
                start_offset = max(test_env_instance.lookback_window, env_atr_window if test_env_instance.lot_model == 'volatility' else 0)
                total_steps_backtest = len(test_env_instance.df) - start_offset

                progress_bar_backtest = st.progress(0)
                status_text_backtest = st.empty()
                # One placeholder, overwritten every step, carries the per-step
                # diagnostics. Writing several new elements per step floods the
                # page with tens of thousands of entries on a normal dataset.
                step_debug_backtest = st.empty()

                st.write(f"Общ брой възможни стъпки за симулация: {total_steps_backtest}")

                # Keys every step record must contain to be kept for analysis.
                required_info_keys = ['portfolio_value', 'date']

                while not done and not truncated and test_env_instance.current_step < len(test_env_instance.df):
                    try:
                        # Validate the observation BEFORE it is handed to the
                        # agent, so a malformed observation is reported here
                        # rather than as an error raised inside predict().
                        if obs is None or not isinstance(obs, np.ndarray) or obs.shape != test_env_instance.observation_space.shape:
                            st.error(f"🚫 Error at step {step_count}: Invalid observation received from environment. Shape: {obs.shape if isinstance(obs, np.ndarray) else 'N/A'}, Expected: {test_env_instance.observation_space.shape}")
                            break

                        # deterministic=True: evaluation, no exploration noise.
                        action, _ = agent_for_backtesting.predict(obs, deterministic=True)
                        action_scalar = action.item() if isinstance(action, np.ndarray) else action

                        try:
                            obs, reward, done, truncated, info = test_env_instance.step(action_scalar)
                        except Exception as e:
                            st.error(f"🚫 Грешка по време на стъпка в средата (env.step) на стъпка {step_count}: {e}")
                            st.exception(e)
                            done = True
                            break

                        step_debug_backtest.text(
                            f"Стъпка: {test_env_instance.current_step} | действие: {action_scalar} | "
                            f"reward: {reward}, done: {done}, truncated: {truncated}"
                        )

                        # Store a copy: the environment's own info dict is left
                        # untouched and records cannot alias each other if the
                        # environment reuses one dict object across steps.
                        step_record = dict(info)
                        step_record['action'] = action_scalar
                        step_record['reward'] = reward

                        missing_keys = [key for key in required_info_keys if key not in step_record]
                        if not missing_keys:
                            backtesting_results.append(step_record)
                        else:
                            st.warning(f"⚠️ Стъпка {step_count}: Info dictionary липсват ключове ({missing_keys}). Пропускане на запис за тази стъпка.")

                        step_count += 1
                        if total_steps_backtest > 0:
                            # Clamp so the bar never exceeds 100 %.
                            progress_value = min(1.0, step_count / total_steps_backtest)
                            progress_bar_backtest.progress(progress_value)
                            status_text_backtest.text(f"Бектестинг в прогрес: {step_count}/{total_steps_backtest} стъпки")
                        else:
                            status_text_backtest.text("Бектестинг: Няма стъпки за симулация.")
                            break

                    except Exception as e:
                        st.error(f"🚫 Възникна грешка по време на стъпка в цикъла на бектест (извън env.step) на стъпка {step_count}: {e}")
                        st.exception(e)
                        done = True  # ends the while loop

                progress_bar_backtest.empty()
                status_text_backtest.empty()
                step_debug_backtest.empty()
                st.write("✅ Бектестинг симулацията приключи.")

                # Store the results in session state and compute the metrics.
                if backtesting_results:
                    st.write(f"Събрани {len(backtesting_results)} резултата от бектестинга.")
                    results_df = pd.DataFrame(backtesting_results)
                    st.session_state['backtesting_results'] = results_df
                    # The trade log is taken from the environment instance.
                    st.session_state['trades_log'] = test_env_instance.trades

                    if not results_df.empty and 'portfolio_value' in results_df.columns:
                        st.write("Изчисляване на метрики за представяне...")
                        performance_metrics = calculate_metrics(
                            results_df['portfolio_value'],
                            st.session_state['trades_log'],
                            current_env_params_for_backtest['initial_amount']
                        )
                        st.session_state['performance_metrics'] = performance_metrics
                        st.success("✅ Бектестингът приключи. Метриките са изчислени.")

                    elif not results_df.empty:
                        st.warning("⚠️ Резултати от бектестинга са събрани, но липсва колоната 'portfolio_value'. Не могат да се изчислят метрики.")
                        st.session_state['performance_metrics'] = None
                    else:
                        st.warning("⚠️ Не бяха събрани достатъчно данни с 'portfolio_value' за изчисляване на метрики.")
                        st.session_state['performance_metrics'] = None
                        st.session_state['trades_log'] = None

                else:
                    st.warning("⚠️ Не бяха събрани резултати от бектестинга.")
                    st.session_state['backtesting_results'] = None
                    st.session_state['performance_metrics'] = None
                    st.session_state['trades_log'] = None

            except Exception as e:
                st.error(f"🚫 Грешка при създаване на среда или изпълнение на бектестинг: {e}")
                st.exception(e)
                st.session_state['backtesting_results'] = None
                st.session_state['performance_metrics'] = None
                st.session_state['trades_log'] = None


    # --- Results display (within Tab 3) ---
    # Shows whatever the last backtest left in session state. The values are
    # read with .get() so a session in which these keys were never set shows
    # nothing instead of raising KeyError.
    st.subheader("Резултати от Бектест")

    # Performance metrics table (one row per metric).
    metrics_to_show = st.session_state.get('performance_metrics')
    if metrics_to_show is not None:
        st.markdown("##### Метрики за представяне")
        metrics_df = pd.DataFrame.from_dict(metrics_to_show, orient='index', columns=['Value'])
        st.dataframe(metrics_df)

    # Trade log: table, trade count and a CSV download.
    trades_to_show = st.session_state.get('trades_log')
    if trades_to_show:
        st.markdown("##### Дневник на сделките")
        trades_df = pd.DataFrame(trades_to_show)
        st.dataframe(trades_df)
        st.info(f"Общ брой регистрирани сделки: {len(trades_to_show)}")
        csv_trades = trades_df.to_csv(index=False).encode('utf-8')
        st.download_button(
            label="Изтегли дневник на сделките като CSV",
            data=csv_trades,
            file_name=f'trades_log_{agent_name_for_backtesting.replace(" ", "_").lower()}.csv',
            mime='text/csv',
            key='download_trades_log'
        )

    elif trades_to_show is not None:
        # A log exists but is empty: the backtest ran without any trade.
        st.info("ℹ️ По време на бектестинга не са регистрирани сделки.")
    elif st.session_state.get('backtesting_results') is not None:
        # Fixed a mixed-script typo in this message (Latin "te" inside the
        # Cyrillic word).
        st.info("ℹ️ Бектестингът приключи, но дневникът на сделките не е наличен.")


    # --- Data and trades visualization (Plotly chart, within Tab 3) ---
    # Candlestick chart of the processed data with the selected price-scale
    # indicators overlaid, the remaining selected indicators stacked in their
    # own subplots, and the backtest trades marked on the price chart.
    st.subheader("Визуализация на данни и сделки")

    if st.session_state['processed_data'] is not None and not st.session_state['processed_data'].empty:
        # Plot from a copy of processed_data (it carries the indicator columns).
        plot_df = st.session_state['processed_data'].copy()

        # The x-axis must be datetime: use the 'date' column when present,
        # otherwise require the index to already be a DatetimeIndex.
        if 'date' in plot_df.columns:
            plot_df['date'] = pd.to_datetime(plot_df['date'], errors='coerce')
            plot_df.dropna(subset=['date'], inplace=True)  # drop rows whose date could not be parsed
            if not plot_df.empty:
                plot_df = plot_df.set_index('date')
            else:
                st.warning("Невалидни дати в данните след обработка. Не може да се построи графика.")
                plot_df = None  # nothing to plot; handled by the check below
        elif not isinstance(plot_df.index, pd.DatetimeIndex):
            st.warning("Невалиден формат на индекса. Нужен е DatetimeIndex или 'date' колона за графиката.")
            plot_df = None

        if plot_df is not None and not plot_df.empty:

            # --- Controls for selecting indicators for visualization ---
            # Everything that is not a price/bookkeeping column is offered as an indicator.
            non_indicator_columns = ['open', 'high', 'low', 'close', 'volume', 'original_index', 'sequential_index', 'date']
            available_indicators = [col for col in plot_df.columns if col not in non_indicator_columns]
            default_indicators = ['sma', 'rsi', 'macd', 'bb_upper', 'bb_lower', 'bb_mavg', 'atr']
            selected_indicators = st.multiselect(
                "Избери индикатори за показване",
                available_indicators,
                default=[ind for ind in default_indicators if ind in available_indicators],
                key='selected_indicators_for_plot_tab3'
            )

            # --- Decide where each selected indicator is drawn ---
            # Price-scale indicators share the candlestick chart; the others get
            # their own subplot, with all MACD components grouped into one.
            indicators_on_price_chart = ['sma', 'ema_12', 'ema_26', 'bb_upper', 'bb_lower', 'bb_mavg']
            indicators_in_subcharts = ['rsi', 'macd', 'macd_signal', 'macd_diff', 'cci', 'adx', 'adx_pos', 'adx_neg', 'ao', 'stoch_k', 'stoch_d', 'atr']
            macd_components = ['macd', 'macd_signal', 'macd_diff']

            selected_on_price = [ind for ind in selected_indicators if ind in indicators_on_price_chart and ind in plot_df.columns]
            selected_in_subcharts = [ind for ind in selected_indicators if ind in indicators_in_subcharts and ind in plot_df.columns]

            # Subplot rows follow the selection order; the MACD group takes the
            # position of the first selected MACD component. Titles and row numbers
            # come from this single sequence so every title sits above the chart it
            # names (previously 'MACD' was always listed last, mislabelling the rows
            # whenever another indicator was selected after a MACD component).
            subchart_titles = []
            subchart_rows = {}  # indicator name -> subplot row (row 1 is the price chart)
            for ind in selected_in_subcharts:
                title = 'MACD' if ind in macd_components else ind.upper()
                if title not in subchart_titles:
                    subchart_titles.append(title)
                subchart_rows[ind] = 2 + subchart_titles.index(title)

            num_subcharts = 1 + len(subchart_titles)

            # The price chart takes 60% of the height; the rest is split evenly.
            base_height = 0.6
            subchart_height = (1.0 - base_height) / max(1, (num_subcharts - 1)) if num_subcharts > 1 else 0
            row_heights = [base_height] + [subchart_height] * (num_subcharts - 1)

            fig = make_subplots(
                rows=num_subcharts,
                cols=1,
                shared_xaxes=True,  # all rows share the date axis
                vertical_spacing=0.05,
                row_heights=row_heights,
                subplot_titles=['Цена'] + subchart_titles
            )

            # --- Price (candlestick) on the main chart ---
            fig.add_trace(go.Candlestick(
                x=plot_df.index,
                open=plot_df['open'],
                high=plot_df['high'],
                low=plot_df['low'],
                close=plot_df['close'],
                name='Цена',
                increasing_line_color='green',
                decreasing_line_color='red'
            ), row=1, col=1)

            # --- Selected price-scale indicators on the main chart ---
            for indicator_name in selected_on_price:
                fig.add_trace(go.Scatter(
                    x=plot_df.index,
                    y=plot_df[indicator_name],
                    mode='lines',
                    name=indicator_name.upper(),
                    line=dict(width=1)  # thin line so the candles stay readable
                ), row=1, col=1)

            # --- Selected indicators in their own subplots ---
            macd_group_drawn = False
            for indicator_name in selected_in_subcharts:
                row_index = subchart_rows[indicator_name]

                if indicator_name not in macd_components:
                    fig.add_trace(go.Scatter(
                        x=plot_df.index,
                        y=plot_df[indicator_name],
                        mode='lines',
                        name=indicator_name.upper(),
                        line=dict(width=1)
                    ), row=row_index, col=1)
                    continue

                # All selected MACD components are drawn together, once.
                if macd_group_drawn:
                    continue
                macd_group_drawn = True

                if 'macd' in selected_in_subcharts:
                    fig.add_trace(go.Scatter(x=plot_df.index, y=plot_df['macd'], mode='lines', name='MACD Line', line=dict(color='blue', width=1)), row=row_index, col=1)
                if 'macd_signal' in selected_in_subcharts:
                    fig.add_trace(go.Scatter(x=plot_df.index, y=plot_df['macd_signal'], mode='lines', name='MACD Signal', line=dict(color='red', width=1)), row=row_index, col=1)
                if 'macd_diff' in selected_in_subcharts:
                    # Histogram bars: green for non-negative values, red otherwise.
                    colors = ['green' if val >= 0 else 'red' for val in plot_df['macd_diff']]
                    fig.add_trace(go.Bar(x=plot_df.index, y=plot_df['macd_diff'], name='MACD Hist', marker_color=colors, opacity=0.7), row=row_index, col=1)

            # --- Markers for trades from the backtest ---
            if st.session_state['trades_log'] is not None and st.session_state['trades_log']:
                trades_df = pd.DataFrame(st.session_state['trades_log'])
                # Markers need a 'date' column to be positioned on the chart's date axis.
                if 'date' in trades_df.columns:
                    trades_df['date'] = pd.to_datetime(trades_df['date'], errors='coerce')
                    trades_df.dropna(subset=['date'], inplace=True)
                    trades_df = trades_df.set_index('date')

                    # Keep only trades that fall inside the plotted date range.
                    trades_df = trades_df[(trades_df.index >= plot_df.index.min()) & (trades_df.index <= plot_df.index.max())]

                    # Entries: rows whose action is 'buy'.
                    buy_trades = trades_df[trades_df['action'] == 'buy']
                    if not buy_trades.empty:
                        fig.add_trace(go.Scatter(
                            x=buy_trades.index,
                            y=buy_trades['price'],  # execution price
                            mode='markers',
                            marker=dict(color='green', size=10, symbol='triangle-up'),
                            name='Buy Signal',
                            hoverinfo='text',
                            text=[f"Buy<br>Price: {p:.5f}<br>Units: {u:.2f}<br>Step: {s}" for p, u, s in zip(buy_trades['price'], buy_trades['units'], buy_trades['step'])]
                        ), row=1, col=1)

                    # Exits: rows whose type is 'exit'; the hover text shows the row's action.
                    sell_trades = trades_df[trades_df['type'] == 'exit']
                    if not sell_trades.empty:
                        fig.add_trace(go.Scatter(
                            x=sell_trades.index,
                            y=sell_trades['price'],  # execution price
                            mode='markers',
                            marker=dict(color='red', size=10, symbol='triangle-down'),
                            name='Sell/Exit Signal',
                            hoverinfo='text',
                            text=[f"{act}<br>Price: {p:.5f}<br>PnL: {pnl:.2f}<br>Step: {s}" for act, p, pnl, s in zip(sell_trades['action'], sell_trades['price'], sell_trades['pnl'], sell_trades['step'])]
                        ), row=1, col=1)

            # --- Chart layout ---
            fig.update_layout(
                xaxis_title='Дата',
                xaxis_rangeslider_visible=False,  # hide the candlestick range slider
                hovermode='x unified',  # one hover box per date for all traces
                height=800
            )

            # Only the bottom row shows x-axis tick labels.
            for row in range(1, num_subcharts + 1):
                fig.update_xaxes(showticklabels=(row == num_subcharts), row=row, col=1)

            st.plotly_chart(fig, use_container_width=True)

        else:
            st.info("Заредете и обработете данни, за да видите графиката.")
    else:
        st.info("Заредете и обработете данни, за да видите графиката.")


# --- Sidebar: how to run the Streamlit dashboard locally ---
# Everything created inside the block is placed in the sidebar.
with st.sidebar:
    st.markdown("---")
    st.subheader("Как да стартирате дашборда")
    st.markdown("""
1. Запазете кода като `forex_dashboard.py`.
2. Уверете се, че файловете `data_utils.py`, `forex_env_utils.py`, `agent_utils.py` са в същата директория.
3. Отворете терминал в тази директория.
4. Инсталирайте Streamlit, yfinance, ta, plotly:
   `pip install streamlit yfinance ta plotly`
5. Стартирайте дашборда:
   `streamlit run forex_dashboard.py`
6. Отворете показания URL адрес в браузъра.
""")

Writing forex_dashboard.py


In [17]:
# 🧹 1. Спиране на всички предишни процеси на Streamlit
!pkill -f streamlit || echo "Няма активни процеси на Streamlit"

# 🚀 2. Стартиране на Streamlit таблото във фонов режим
import os
import threading
import time
import subprocess
import re

def run_streamlit(script="forex_dashboard.py", port=8501):
    """Run the Streamlit dashboard and block until the server process exits.

    Meant to be used as a thread target so the calling cell can continue.

    Args:
        script: Path of the Streamlit app to serve.
        port: Port for the server. The default, 8501, is the port the
            Cloudflare tunnel in this cell forwards to.

    Returns:
        The exit status of the Streamlit process, or None if the process
        could not be started.
    """
    # An argument list (no shell) avoids quoting problems in script paths.
    command = ["streamlit", "run", str(script), "--server.port", str(port)]
    try:
        completed = subprocess.run(command, check=False)
    except OSError as exc:
        # Typically: the streamlit executable is not installed / not on PATH.
        print(f"🚫 Неуспешно стартиране на Streamlit: {exc}")
        return None
    return completed.returncode

# Start Streamlit in a separate thread; run_streamlit blocks until the server exits.
threading.Thread(target=run_streamlit).start()

# ⏳ 3. Give the Streamlit server time to start before opening the tunnel
time.sleep(5)

# 🌐 4. Start the Cloudflare Tunnel
print("🔄 Стартиране на Cloudflare Tunnel...")
# The cloudflared executable is expected in the current directory.
try:
    process = subprocess.Popen(
        ["./cloudflared", "tunnel", "--url", "http://localhost:8501"],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )
except OSError as exc:
    # Missing or non-executable binary: report it instead of raising.
    process = None
    print(f"🚫 Неуспешно стартиране на cloudflared: {exc}")


def _drain_tunnel_output(stream):
    """Discard the tunnel's remaining log lines so its output pipe never fills up."""
    for _ in iter(stream.readline, b""):
        pass


# 📡 5. Extract the public URL from the tunnel's log output
print("Извличане на публичен URL...")
tunnel_url = None
# Match a single host label only, so the pattern cannot stretch across
# several URLs that happen to share a log line.
# NOTE(review): the "api." host is excluded on the assumption that cloudflared
# error messages can mention its api.trycloudflare.com endpoint - confirm.
tunnel_url_pattern = re.compile(r"https://(?!api\.)[A-Za-z0-9-]+\.trycloudflare\.com")

if process is not None:
    # readline() blocks until a line is available and returns b"" at EOF,
    # so no extra sleep is needed between reads.
    for raw_line in iter(process.stdout.readline, b""):
        line = raw_line.decode("utf-8", errors="replace")  # never fail on odd bytes
        print(line.strip())
        match = tunnel_url_pattern.search(line)
        if match:
            tunnel_url = match.group(0)
            print(f"\n🔗 Публичен URL към Streamlit таблото: {tunnel_url}")
            break

if tunnel_url is None:
    print("🚫 Неуспешно извличане на публичен URL. Моля, проверете изхода за грешки.")
else:
    # The tunnel keeps logging after the URL is found; keep reading in the
    # background so it cannot block on a full pipe.
    threading.Thread(target=_drain_tunnel_output, args=(process.stdout,), daemon=True).start()

^C
🔄 Стартиране на Cloudflare Tunnel...
Извличане на публичен URL...
2025-10-07T03:56:16Z INF Thank you for trying Cloudflare Tunnel. Doing so, without a Cloudflare account, is a quick way to experiment and try it out. However, be aware that these account-less Tunnels have no uptime guarantee, are subject to the Cloudflare Online Services Terms of Use (https://www.cloudflare.com/website-terms/), and Cloudflare reserves the right to investigate your use of Tunnels for violations of such terms. If you intend to use Tunnels in production you should use a pre-created named tunnel by following: https://developers.cloudflare.com/cloudflare-one/connections/connect-apps
2025-10-07T03:56:16Z INF Requesting new quick Tunnel on trycloudflare.com...
2025-10-07T03:56:19Z INF +--------------------------------------------------------------------------------------------+
2025-10-07T03:56:19Z INF |  Your quick Tunnel has been created! Visit it at (it may take some time to be reachable):  |
2025-10-07T0

In [18]:
!pip install -q google-generativeai


In [19]:
import google.generativeai as genai

# Module-level handle for the "gemini-pro" model from google.generativeai.
model = genai.GenerativeModel("gemini-pro")

In [20]:
import google.generativeai as genai

# Module-level "gemini-pro" model handle used by explain_agent_action below
# (same construction as in the previous cell).
model = genai.GenerativeModel("gemini-pro")

def explain_agent_action(obs, action, current_price, account_balance, open_position_units):
    """Ask the generative model why the trading agent may have taken an action.

    The arguments are interpolated as-is into a Bulgarian-language prompt that
    is sent to the module-level ``model``.

    Args:
        obs: Observation the agent acted on.
        action: Action the agent took.
        current_price: Price at the time of the action.
        account_balance: Account balance at the time of the action.
        open_position_units: Size of the currently open position.

    Returns:
        The ``text`` attribute of the model's response.
    """
    explanation_request = f"""
    Трейдинг агент направи действие: {action}.
    Наблюдение: {obs}
    Текуща цена: {current_price}
    Баланс по сметката: {account_balance}
    Размер на отворена позиция: {open_position_units}

    Обясни защо може да е взето това решение, като използваш трейдърска логика, базирана на предоставената информация.
    """
    return model.generate_content(explanation_request).text

In [21]:
%%writefile main.py
# Example entry point for running training or backtesting outside of Streamlit

import pandas as pd
import numpy as np
import os

# Import functions and classes from your local modules
from data_utils import download_forex_data, add_technical_indicators, split_data
from env.forex_env import ForexTradingEnv, calculate_metrics # Corrected import path
from agent_utils import create_agent, train_agent, load_agent_by_type
from stable_baselines3.common.env_util import make_vec_env

# --- Configuration ---
# Parameters for a training run; they can be defined here or loaded from a file.

# Data download settings, passed to download_forex_data() in main().
DATA_TICKER = "EURUSD=X"
DATA_START_DATE = "2022-01-01"
DATA_END_DATE = "2023-01-01"
DATA_TIMEFRAME = "1d"

# Environment settings: stored under 'lookback_window' / 'initial_amount'
# in the env_params passed to ForexTradingEnv.
LOOKBACK_WINDOW = 20
INITIAL_AMOUNT = 100000.0

# Ratio passed to split_data() when splitting into train and test data.
TRAIN_SPLIT_RATIO = 0.8

# Agent type passed to load_agent_by_type() / create_agent().
AGENT_TYPE = "PPO" # Or "A2C", "DQN"

# Training length and persistence settings, passed to train_agent().
TOTAL_TIMESTEPS = 50000
SAVE_DIR = "/content/drive/MyDrive/Colab_Models"
CHECKPOINT_DIR = "/content/drive/MyDrive/Colab_Checkpoints"
CHECKPOINT_FREQ = 5000


def main():
    """Download data, build the training environment and train the configured agent.

    Steps: download the price data and add technical indicators, split it into
    train/test parts, wrap a ForexTradingEnv over the training part, resume
    from an existing checkpoint or saved model when one loads (otherwise create
    a fresh agent) and train it with periodic checkpoints.

    Returns:
        None. Progress and failures are reported on stdout; the function
        returns early when a step fails.
    """
    print("Starting main script...")

    # 1. Download and process data
    print("Downloading and processing data...")
    raw_data = download_forex_data(DATA_TICKER, DATA_START_DATE, DATA_END_DATE, DATA_TIMEFRAME)

    if raw_data is None or raw_data.empty:
        print("Failed to download or process raw data. Exiting.")
        return

    processed_data = add_technical_indicators(raw_data, atr_window=14)  # ATR window of 14

    if processed_data is None or processed_data.empty:
        print("Failed to add technical indicators or processed data is empty. Exiting.")
        return

    train_data, test_data = split_data(processed_data, TRAIN_SPLIT_RATIO)

    if train_data is None or train_data.empty:
        print("Failed to split data or train data is empty. Exiting.")
        return

    # Only the training part is required here; tolerate a missing test part
    # instead of failing on `.shape`.
    test_shape = test_data.shape if test_data is not None else None
    print(f"Data loaded and split. Train data shape: {train_data.shape}, Test data shape: {test_shape}")

    # 2. Create and train agent
    print(f"Creating and training {AGENT_TYPE} agent...")

    # Environment parameters (adjust as needed)
    env_params = {
        'initial_amount': INITIAL_AMOUNT,
        'lookback_window': LOOKBACK_WINDOW,
        'buy_cost_pct': 0.001,
        'sell_cost_pct': 0.001,
        'max_drawdown_limit_pct': 0.10,
        'position_size_pct': 0.1,  # example value
        'stop_loss_pct': 0.02,
        'take_profit_pct': 0.04,
        'trailing_sl_pct': 0.005,
        'lot_model': 'percent_of_capital',  # example lot model
        'tp_reward_bonus': 0.01,
        'sl_penalty': 0.01
    }
    # The environment gets its own copy of the training data.
    train_env_instance = ForexTradingEnv(df=train_data.copy(), **env_params)
    vec_train_env = make_vec_env(lambda: train_env_instance, n_envs=1)

    # Agent parameters (example PPO values; customize per AGENT_TYPE)
    agent_params = {
        "learning_rate": 1e-4,
        "n_steps": 2048,
        "batch_size": 64,
        "n_epochs": 10,
        "gamma": 0.99,
        "gae_lambda": 0.95,
        "clip_range": 0.2,
        "verbose": 1  # print training output to the console
    }

    # Try loading an existing agent/checkpoint first.
    print(f"Attempting to load existing agent or latest checkpoint for {AGENT_TYPE}...")
    # train_agent (agent_utils) saves the final model as
    # "valkyrie_<AgentClassName>_model", e.g. "valkyrie_PPO_model". Use the
    # same casing so this fallback path finds it on a case-sensitive filesystem.
    loaded_agent = load_agent_by_type(
        path=os.path.join(SAVE_DIR, f"valkyrie_{AGENT_TYPE}_model"),
        env=vec_train_env,
        agent_type=AGENT_TYPE,
        checkpoint_dir=CHECKPOINT_DIR  # load_agent_by_type checks checkpoint_dir first
    )

    if loaded_agent is not None:
        print(f"Successfully loaded agent/checkpoint for {AGENT_TYPE}. Continuing training.")
        agent_to_train = loaded_agent
    else:
        print(f"No existing agent/checkpoint found for {AGENT_TYPE}. Creating a new agent.")
        agent_to_train = create_agent(AGENT_TYPE, vec_train_env, agent_params=agent_params)

    if agent_to_train is None:
        print(f"Failed to create or load agent for {AGENT_TYPE}. Exiting training phase.")
        return

    # Train the agent
    print("Starting agent training...")
    trained_agent = train_agent(
        agent_to_train,
        total_timesteps=TOTAL_TIMESTEPS,
        save_dir=SAVE_DIR,
        checkpoint_dir=CHECKPOINT_DIR,
        save_freq=CHECKPOINT_FREQ
    )

    if trained_agent is not None:
        print("Agent training finished.")
        # Backtesting on test_data with trained_agent can be added here.
    else:
        print("Agent training failed.")


if __name__ == "__main__":
    main()

Writing main.py


In [22]:
%%writefile agent_utils.py
import streamlit as st
import pandas as pd
import numpy as np
import os
from stable_baselines3 import PPO, A2C, DQN
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback
import gymnasium as gym
import glob

# Define a simple callback to update the progress bar in Streamlit
class ProgressCallback(BaseCallback):
    """
    Stable-Baselines3 callback that mirrors training progress into Streamlit.

    On every step it updates the optional progress bar with
    ``num_timesteps / total_timesteps`` (capped at 1.0) and writes the current
    step count into the optional status text element.

    Args:
        progress_bar: Object with a ``progress(value)`` method, or None.
        status_text: Object with a ``text(message)`` method, or None.
        verbose: Verbosity level forwarded to ``BaseCallback``.
    """
    def __init__(self, progress_bar=None, status_text=None, verbose=0):
        super().__init__(verbose)
        self.progress_bar = progress_bar
        self.status_text = status_text
        # Episode length derived from the environment in _on_training_start.
        self.total_timesteps_in_episode = 0
        self.current_timesteps_in_episode = 0
        self._prev_timesteps = 0

    def _on_training_start(self) -> None:
        """Derive the episode length (rows minus lookback window) when the env exposes it."""
        env = self.training_env
        if not hasattr(env, 'get_attr'):
            return
        try:
            df = env.get_attr('df', indices=0)[0]
            lookback_window = env.get_attr('lookback_window', indices=0)[0]
        except (AttributeError, IndexError):
            # The environment does not expose these attributes. The episode
            # length is optional information, so training must not fail here.
            return
        if isinstance(df, pd.DataFrame) and hasattr(lookback_window, '__int__'):
            self.total_timesteps_in_episode = max(0, len(df) - lookback_window)

    def _on_step(self) -> bool:
        """Refresh the Streamlit widgets; always returns True so training continues."""
        total_timesteps = self.locals.get('total_timesteps')

        if self.progress_bar is not None and total_timesteps is not None and total_timesteps > 0:
            self.progress_bar.progress(min(1.0, self.num_timesteps / total_timesteps))

        if self.status_text is not None:
            if total_timesteps is not None:
                self.status_text.text(f"Обучение в прогрес: {self.num_timesteps}/{total_timesteps} стъпки")
            else:
                self.status_text.text(f"Обучение в прогрес: {self.num_timesteps} стъпки")

        return True


# Function to create an agent instance
# @st.cache_resource
def create_agent(agent_type, env, agent_params=None):
    """
    Build a Stable-Baselines3 agent of the requested type with an MLP policy.

    Per-algorithm default hyperparameters are merged with ``agent_params``;
    values in ``agent_params`` win. Progress and errors are reported through
    Streamlit.

    Args:
        agent_type (str): The type of agent to create ('PPO', 'A2C', 'DQN').
        env (gym.Env): The environment to train the agent on (should be a VecEnv).
        agent_params (dict, optional): Hyperparameter overrides, {param: value}.
                                       Defaults to None (use the defaults only).

    Returns:
        stable_baselines3.common.base.BaseAlgorithm: The created agent instance,
        or None if the type is unknown, DQN is requested for a non-discrete
        action space, or construction raises.
    """
    st.write(f"⚙️ Опит за създаване на агент тип: {agent_type}")

    # Algorithm class, default hyperparameters and success message per agent type.
    agent_specs = {
        "PPO": (
            PPO,
            {
                "learning_rate": 1e-4,
                "n_steps": 2048,
                "batch_size": 64,
                "n_epochs": 10,
                "gamma": 0.99,
                "gae_lambda": 0.95,
                "clip_range": 0.2,
                "verbose": 0,
            },
            "✅ PPO агент създаден.",
        ),
        "DQN": (
            DQN,
            {
                "learning_rate": 1e-4,
                "buffer_size": 10000,
                "learning_starts": 100,
                "batch_size": 32,
                "gamma": 0.99,
                "train_freq": 1,
                "gradient_steps": 1,
                "verbose": 0,
            },
            "✅ DQN агент създаден.",
        ),
        "A2C": (
            A2C,
            {
                "learning_rate": 7e-4,
                "n_steps": 5,
                "gamma": 0.99,
                "gae_lambda": 0.95,
                "vf_coef": 0.25,
                "ent_coef": 0.01,
                "verbose": 0,
            },
            "✅ A2C агент създаден.",
        ),
    }

    try:
        # Equality scan (rather than a dict lookup) so any agent_type value is
        # compared exactly as a chain of `==` tests would compare it.
        spec = next((entry for name, entry in agent_specs.items() if agent_type == name), None)
        if spec is None:
            st.error(f"❌ Непознат агент: {agent_type}")
            return None

        algorithm_cls, defaults, created_message = spec
        hyperparams = {**defaults, **(agent_params if agent_params is not None else {})}

        # DQN only supports discrete action spaces.
        if algorithm_cls is DQN and not isinstance(env.action_space, gym.spaces.Discrete):
            st.error(f"🚫 DQN requires a Discrete action space, but the environment has {type(env.action_space)}. Cannot create DQN agent.")
            return None

        created_agent = algorithm_cls("MlpPolicy", env, **hyperparams)
        st.write(created_message)
        return created_agent

    except Exception as e:
        st.error(f"🚫 Грешка при създаване на агента {agent_type}: {e}")
        st.exception(e)
        return None

# Function to train the agent
# @st.cache_resource
def train_agent(agent, total_timesteps=10000, progress_bar=None, status_text=None, save_dir="/content/drive/MyDrive/Colab_Models", checkpoint_dir="/content/drive/MyDrive/Colab_Checkpoints", save_freq=5000):
    """
    Trains the provided Stable-Baselines3 agent with checkpointing.

    The final model is saved as ``valkyrie_<AgentClassName>_model`` inside
    ``save_dir``. Progress, warnings and errors are reported through Streamlit.

    Args:
        agent (stable_baselines3.common.base.BaseAlgorithm): The agent instance to train.
        total_timesteps (int): The total number of timesteps to train for.
        progress_bar (streamlit.delta_generator.DeltaGenerator, optional): Streamlit progress bar object. Defaults to None.
        status_text (streamlit.delta_generator.DeltaGenerator, optional): Streamlit text object for status updates. Defaults to None.
        save_dir (str): Directory to save the final model. Defaults to "/content/drive/MyDrive/Colab_Models".
        checkpoint_dir (str): Directory to save training checkpoints. Defaults to "/content/drive/MyDrive/Colab_Checkpoints".
        save_freq (int): Frequency (in timesteps) of saving checkpoints. Defaults to 5000.

    Returns:
        stable_baselines3.common.base.BaseAlgorithm: The trained agent instance, or None if the
        agent is invalid or training raises. A failed final save does not discard the trained agent.
    """
    # Validate before deriving the type name; otherwise a None agent would be
    # announced as a "NoneType" agent starting training.
    if agent is None:
        st.error("🚫 train_agent: Получен е невалиден (None) агент за обучение.")
        return None
    if getattr(agent, 'env', None) is None:
        st.error("🚫 train_agent: Агентът не е свързан с валидна среда.")
        return None

    agent_type = type(agent).__name__
    model_name = f"valkyrie_{agent_type}_model"

    st.write(f"🧠 Стартиране на обучение за {agent_type} агент за {total_timesteps} стъпки...")

    try:
        progress_callback_instance = ProgressCallback(progress_bar=progress_bar, status_text=status_text)

        # The checkpoint directory must exist before checkpoints are written.
        if not os.path.exists(checkpoint_dir):
            st.warning(f"Директорията за чекпойнти не съществува: {checkpoint_dir}. Опит за създаване...")
            os.makedirs(checkpoint_dir, exist_ok=True)
            st.write(f"Създадена директория за чекпойнти: {checkpoint_dir}")

        checkpoint_callback_instance = CheckpointCallback(
            save_freq=save_freq,
            save_path=checkpoint_dir,
            name_prefix=f"{agent_type}_checkpoint",
            save_replay_buffer=True,  # save the replay buffer (used by DQN)
            save_vecnormalize=True  # save VecNormalize statistics for VecEnvs
        )

        callbacks = [progress_callback_instance, checkpoint_callback_instance]

        st.write("🚀 Започва обучение...")
        agent.learn(total_timesteps=total_timesteps, callback=callbacks)
        st.write("✅ Обучението на агента приключи.")

        # Save the final model under the name derived from the agent class.
        final_save_path = os.path.join(save_dir, model_name)
        save_agent(agent, final_save_path)

        # save_agent reports its own errors without raising, so confirm the
        # archive exists before claiming the model was saved.
        if os.path.isfile(f"{final_save_path}.zip"):
            st.success(f"✅ Агентът е обучен успешно и запазен като: {final_save_path}.zip")
        else:
            st.warning(f"⚠️ Агентът е обучен, но не е намерен запазен файл: {final_save_path}.zip")
        return agent

    except Exception as e:
        st.error(f"🚫 Грешка при обучението: {e}")
        st.exception(e)
        return None


# Function to save the agent
def save_agent(agent, path):
    """
    Saves the trained Stable-Baselines3 agent to a file.

    The parent directory of ``path`` is created when it does not exist.
    Failures are reported through Streamlit and are not raised to the caller.

    Args:
        agent (stable_baselines3.common.base.BaseAlgorithm): The agent instance to save.
        path (str): The path (including filename, without .zip) to save the agent to.
                    A bare filename saves into the current working directory.
    """
    if agent is None:
        st.warning("⚠️ Няма агент за запазване.")
        return

    try:
        save_dir = os.path.dirname(path)
        # A bare filename has an empty directory part; os.makedirs("") would
        # raise, so only create a directory when there is one to create.
        if save_dir and not os.path.exists(save_dir):
            st.warning(f"Директорията за запазване не съществува: {save_dir}. Опит за създаване...")
            os.makedirs(save_dir, exist_ok=True)
            st.write(f"Създадена директория за запазване: {save_dir}")

        agent.save(path)
        st.write(f"✅ Агентът е запазен успешно в: {path}.zip")
    except Exception as e:
        st.error(f"🚫 Грешка при запазване на агента в {path}.zip: {e}")
        st.exception(e)


# Function to load the agent by type, looking for checkpoints first
# @st.cache_resource
def load_agent_by_type(path, env, agent_type, checkpoint_dir="/content/drive/MyDrive/Colab_Checkpoints"):
    """
    Load an agent of the requested type, preferring the newest training checkpoint.

    The function first looks in ``checkpoint_dir`` for files named
    ``<agent_type>_checkpoint_*.zip`` and tries to load the one with the
    highest step count. If no checkpoint exists, or loading it raises, the
    model stored at ``path`` is loaded instead. All progress and errors are
    reported through Streamlit; no exception is propagated to the caller.

    Args:
        path (str): Path of the saved agent file, with or without the
            ``.zip`` suffix. Used as the fallback when no checkpoint can be loaded.
        env: The environment to attach to the loaded agent. For DQN its
            ``action_space`` must be a ``gym.spaces.Discrete``.
        agent_type (str): The type of agent to load ('PPO', 'A2C', 'DQN').
        checkpoint_dir (str): Directory where training checkpoints are saved.
            Pass a falsy value to skip the checkpoint lookup.

    Returns:
        The loaded agent instance, or None when the agent type is unknown,
        the environment is incompatible with DQN, the fallback file does not
        exist, or loading from the fallback path fails.
    """
    st.write(f"⚙️ Опит за зареждане на агент тип {agent_type}...")

    # Try loading from the latest checkpoint first
    latest_checkpoint = _find_latest_checkpoint(checkpoint_dir, agent_type)
    if latest_checkpoint is not None:
        st.info(f"✅ Намерен последен чекпойнт: {latest_checkpoint}. Опит за зареждане оттук.")
        try:
            model = _load_agent_file(agent_type, latest_checkpoint, env, "checkpoint", "чекпойнт")
            if model is None:
                # Unknown type or incompatible env: already reported, and the
                # fallback file would hit the same problem.
                return None
            st.success(f"✅ Агент тип {agent_type} успешно зареден от чекпойнт.")
            return model
        except Exception as e:
            st.error(f"🚫 Грешка при зареждане на агента от чекпойнт {latest_checkpoint}: {e}")
            st.exception(e)
            # Deliberately fall through to the initial model below.

    # If no checkpoints found or checkpoint loading failed, try loading the initial model
    st.info(f"⚠️ Няма намерени чекпойнти в {checkpoint_dir} или зареждането се провали. Опит за зареждане от първоначалния път: {path}")
    full_path = path  # bound before the try so the error message below can always use it
    try:
        # Append .zip only when the caller did not already include it
        if not full_path.lower().endswith('.zip'):
            full_path = f"{path}.zip"

        if not os.path.exists(full_path):
            st.warning(f"⚠️ Не е намерен файл на агент за зареждане на първоначалния път: {full_path}")
            return None

        model = _load_agent_file(agent_type, full_path, env, "initial path", "първоначалния път")
        if model is None:
            return None
        st.write(f"✅ {agent_type} агент зареден от първоначалния път.")
        st.success(f"✅ Агент тип {agent_type} успешно зареден от първоначалния път.")
        return model

    except Exception as e:
        st.error(f"🚫 Грешка при зареждането на агента от първоначалния път {full_path}: {e}")
        st.exception(e)
        return None


def _find_latest_checkpoint(checkpoint_dir, agent_type):
    """Return the newest ``<agent_type>_checkpoint_*.zip`` in ``checkpoint_dir``, or None if there is none."""
    import re  # local import: the file-level import block is not visible from this section

    if not checkpoint_dir or not os.path.exists(checkpoint_dir):
        return None

    # Escape the directory so characters such as "[" or "*" in its name are
    # matched literally instead of being treated as glob wildcards.
    pattern = os.path.join(glob.escape(checkpoint_dir), f"{agent_type}_checkpoint_*.zip")
    candidates = glob.glob(pattern)
    if not candidates:
        return None

    step_suffix = re.compile(r"_(\d+)_steps\.zip$")

    def sort_key(checkpoint_file):
        # Rank by the step count embedded in the filename; creation/change
        # time is platform-dependent and is reset when files are copied or
        # synced, so it only serves as tie-breaker and as the fallback for
        # names that carry no step count.
        found = step_suffix.search(os.path.basename(checkpoint_file))
        steps = int(found.group(1)) if found else -1
        return steps, os.path.getctime(checkpoint_file)

    return max(candidates, key=sort_key)


def _load_agent_file(agent_type, model_file, env, source_en, source_bg):
    """Load ``model_file`` with the class matching ``agent_type``; report and return None if the type or env is unsuitable."""
    agent_classes = {"PPO": PPO, "DQN": DQN, "A2C": A2C}
    agent_cls = agent_classes.get(agent_type)
    if agent_cls is None:
        st.error(f"❌ Непознат агент тип за зареждане от {source_bg}: {agent_type}")
        return None

    if agent_type == "DQN" and not isinstance(env.action_space, gym.spaces.Discrete):
        st.error(f"🚫 DQN requires a Discrete action space, but the environment has {type(env.action_space)}. Cannot load DQN agent from {source_en}.")
        return None

    return agent_cls.load(model_file, env=env)

Overwriting agent_utils.py
