# In [None]:
# Risk-Sensitive Portfolio Management with Reinforcement Learning
# ====================================================
#
# This notebook demonstrates:
# 1. Policy Gradient (REINFORCE) implementation for portfolio optimization
# 2. Actor-Critic (A2C) implementation for improved sample efficiency
# 3. Risk-sensitive reward functions incorporating CVaR (Conditional Value-at-Risk)
# 4. Performance comparison and stress testing during crisis periods
#
# Author: Your Name
# Date: October 21, 2025

# ## 1. Import Libraries and Setup Environment

import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Dirichlet
import gymnasium as gym
from gymnasium import spaces
from tqdm.notebook import tqdm
import yfinance as yf
from datetime import datetime, timedelta
import warnings

# Silence library warnings globally to keep notebook output readable.
warnings.filterwarnings('ignore')

# Run on the GPU when PyTorch can see one; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Set random seeds for reproducibility
def set_seed(seed=42):
    """
    Seed the random number generators used in this notebook.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and,
    when available, every CUDA device). This does not force deterministic
    cuDNN kernels, so GPU runs may still differ slightly.

    Args:
        seed (int): Seed applied to every generator.
    """
    import random  # only needed here; keeps the top-level import block unchanged

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # manual_seed_all covers the current device as well as any others.
        torch.cuda.manual_seed_all(seed)
    
# Seed all generators once, before any data or model is created.
set_seed(42)

# Output folders: cached market data and trained model checkpoints.
for _output_dir in ('data', 'saved_models'):
    os.makedirs(_output_dir, exist_ok=True)

# ## 2. Data Processing and Financial Metrics

# ### 2.1 Data Utilities

def download_stock_data(tickers, start_date, end_date, save_path=None):
    """
    Download historical adjusted close prices for a list of stocks.

    Args:
        tickers (list or str): Ticker symbols. A single ticker string is also accepted.
        start_date (str): Start date in 'YYYY-MM-DD' format.
        end_date (str): End date in 'YYYY-MM-DD' format.
        save_path (str, optional): CSV path to save the data. If None, data won't be saved.

    Returns:
        pd.DataFrame: Adjusted close prices, one column per ticker. Gaps are
        forward-filled and then back-filled.
    """
    # A bare string would otherwise be indexed character by character below.
    if isinstance(tickers, str):
        tickers = [tickers]

    # Recent yfinance releases default to auto_adjust=True, which removes the
    # 'Adj Close' column. Request it explicitly so the lookup below works.
    data = yf.download(tickers, start=start_date, end=end_date, auto_adjust=False)['Adj Close']

    # Some yfinance versions return a Series for a single ticker. Series.to_frame
    # relabels it. pd.DataFrame(series, columns=[...]) instead selects by the
    # series name ('Adj Close') and can yield an all-NaN column.
    if isinstance(data, pd.Series):
        data = data.to_frame(name=tickers[0])

    # fillna(method=...) is deprecated in pandas 2.x; ffill/bfill are the replacements.
    data = data.ffill().bfill()

    if save_path:
        directory = os.path.dirname(save_path)
        if directory:  # os.makedirs('') raises when save_path has no directory part
            os.makedirs(directory, exist_ok=True)
        data.to_csv(save_path)

    return data

def calculate_returns(prices, log_returns=False):
    """
    Convert a price table into per-period returns.

    Args:
        prices (pd.DataFrame): Price data, one column per asset.
        log_returns (bool): If True, return log(p_t / p_{t-1}); otherwise the
            simple return p_t / p_{t-1} - 1.

    Returns:
        pd.DataFrame: Returns with every row that contains NaN removed. This
        always includes the first row, which has no previous price.
    """
    # Gross growth factor relative to the previous row.
    growth = prices / prices.shift(1)
    returns = np.log(growth) if log_returns else growth - 1
    return returns.dropna()

def split_data(data, train_ratio=0.6, val_ratio=0.2, test_ratio=0.2):
    """
    Split data chronologically into train, validation, and test sets.

    Rows are not shuffled: the first ``train_ratio`` of rows go to training,
    the next ``val_ratio`` to validation and the remainder to testing.

    Args:
        data (pd.DataFrame): DataFrame to split.
        train_ratio (float): Ratio of data to use for training.
        val_ratio (float): Ratio of data to use for validation.
        test_ratio (float): Ratio of data to use for testing.

    Returns:
        tuple: (train_data, val_data, test_data)

    Raises:
        ValueError: If the ratios do not sum to 1 (within floating-point tolerance).
    """
    # Compare with a tolerance: 0.7 + 0.2 + 0.1 evaluates to 0.9999999999999999,
    # which an exact `== 1.0` check would wrongly reject. Raise instead of
    # asserting, because asserts are stripped under `python -O`.
    total = train_ratio + val_ratio + test_ratio
    if not np.isclose(total, 1.0):
        raise ValueError(f"Ratios must sum to 1, got {total}")

    n = len(data)
    train_end = int(n * train_ratio)
    val_end = int(n * (train_ratio + val_ratio))

    return data.iloc[:train_end], data.iloc[train_end:val_end], data.iloc[val_end:]

# ### 2.2 Financial Risk Metrics

def calculate_cvar(returns, alpha=0.05):
    """
    Estimate Conditional Value-at-Risk (expected shortfall) from a return sample.

    The estimate is the mean of the worst ``ceil(alpha * n)`` returns, using at
    least one observation. It is expressed in return units, so losses show up
    as negative numbers.

    Args:
        returns (np.array): Array of returns.
        alpha (float): Tail probability (e.g. 0.05 for 95% CVaR).

    Returns:
        float: Mean of the left tail (NaN for an empty input).
    """
    ordered = np.sort(returns)
    # Size of the left tail; never fewer than one observation.
    tail_size = max(int(np.ceil(alpha * len(ordered))), 1)
    return np.mean(ordered[:tail_size])

def calculate_sharpe_ratio(returns, risk_free_rate=0.0, periods_per_year=252):
    """
    Compute the annualized Sharpe ratio of a return series.

    The mean excess return per period (``mean(returns) - risk_free_rate``) is
    divided by the population standard deviation (``np.std``, ddof=0) and then
    scaled by ``sqrt(periods_per_year)``. ``risk_free_rate`` is subtracted from
    per-period returns directly, so it must be a per-period rate.

    Args:
        returns (np.array): Array of returns.
        risk_free_rate (float): Per-period risk-free rate.
        periods_per_year (int): Number of periods per year (e.g., 252 for daily returns).

    Returns:
        float: Annualized Sharpe ratio, or 0 when the returns have zero volatility.
    """
    average = np.mean(returns)
    volatility = np.std(returns)

    # Guard clause: a constant series has no risk and gets a ratio of 0.
    if volatility == 0:
        return 0

    per_period_ratio = (average - risk_free_rate) / volatility
    return per_period_ratio * np.sqrt(periods_per_year)

def calculate_max_drawdown(portfolio_values):
    """
    Calculate the maximum drawdown of a portfolio.

    The drawdown at each point is the fall from the running peak, expressed as
    a fraction of that peak. The maximum drawdown is the largest such value.

    Args:
        portfolio_values (array-like): Portfolio values in chronological order.

    Returns:
        float: Maximum drawdown. For positive values it lies in [0, 1].
        Returns 0.0 for an empty input.
    """
    values = np.asarray(portfolio_values, dtype=float)
    # Empty input used to fail with IndexError on values[0].
    if values.size == 0:
        return 0.0

    # The running maximum replaces the explicit peak-tracking loop.
    running_peak = np.maximum.accumulate(values)
    drawdowns = (running_peak - values) / running_peak
    return float(max(drawdowns.max(), 0.0))

def calculate_turnover(old_weights, new_weights):
    """
    Compute one-way portfolio turnover between two weight vectors.

    This is half the L1 distance between the vectors. When both vectors sum to
    the same total, it equals the fraction of the portfolio bought (or,
    equivalently, sold) during the rebalance.

    Args:
        old_weights (np.array): Previous portfolio weights.
        new_weights (np.array): New portfolio weights.

    Returns:
        float: Portfolio turnover.
    """
    traded = np.abs(new_weights - old_weights)
    return traded.sum() / 2.0

# ### 2.3 Data Loading and Exploration

# Define the assets we want to include in our portfolio
tickers = ['SPY', 'QQQ', 'GLD', 'TLT', 'VNQ', 'BND', 'VWO']
start_date = '2010-01-01'
end_date = '2023-01-01'

# Data file path (acts as a download cache)
data_file = os.path.join('data', 'stock_data.csv')

# Load cached prices if present; otherwise download them
if os.path.exists(data_file):
    print(f"Loading data from {data_file}")
    price_data = pd.read_csv(data_file, index_col=0, parse_dates=True)
else:
    print(f"Downloading data for {tickers}")
    price_data = download_stock_data(tickers, start_date, end_date, save_path=data_file)

# The cache is keyed only by file name. If `tickers` changed since it was
# written, fail loudly instead of silently using the wrong assets.
missing_tickers = [t for t in tickers if t not in price_data.columns]
if missing_tickers:
    raise ValueError(
        f"{data_file} lacks {missing_tickers}; delete it to re-download the data."
    )

# A bare expression is only rendered when it is the last line of a notebook
# cell, so print explicitly.
print("Price data:")
print(price_data.head())

# Calculate simple daily returns
returns_data = calculate_returns(price_data, log_returns=False)

print("Returns data:")
print(returns_data.head())

# Growth of one unit invested in each asset (every series rescaled to start at 1.0)
plt.figure(figsize=(14, 7))
for ticker in tickers:
    series = price_data[ticker]
    plt.plot(price_data.index, series / series.iloc[0], label=ticker)
plt.title('Price Evolution (Normalized)')
plt.xlabel('Date')
plt.ylabel('Normalized Price')
plt.legend()
plt.grid(True)
plt.show()

# Pairwise correlation of asset returns, annotated to two decimals
plt.figure(figsize=(10, 8))
correlation = returns_data.corr()
sns.heatmap(correlation, annot=True, cmap='coolwarm', fmt='.2f')
plt.title('Correlation Matrix of Asset Returns')
plt.tight_layout()
plt.show()

# Per-asset risk metrics (annualized with 252 trading days; volatility uses np.std, ddof=0)
risk_metrics = pd.DataFrame(
    {
        'Annual Return': [np.mean(returns_data[ticker]) * 252 for ticker in tickers],
        'Annual Volatility': [np.std(returns_data[ticker]) * np.sqrt(252) for ticker in tickers],
        'Sharpe Ratio': [calculate_sharpe_ratio(returns_data[ticker].values) for ticker in tickers],
        'CVaR (5%)': [calculate_cvar(returns_data[ticker].values) for ticker in tickers],
    },
    index=tickers,
)

# A bare `risk_metrics` expression mid-cell is never rendered; print it explicitly.
print("Risk metrics for individual assets:")
print(risk_metrics)

# Chronological split into train, validation, and test sets
train_data, val_data, test_data = split_data(returns_data)

print(f"Data shapes - Train: {train_data.shape}, Val: {val_data.shape}, Test: {test_data.shape}")

# ## 3. Portfolio Environment Construction

class PortfolioEnv(gym.Env):
    """
    A reinforcement learning environment for long-only portfolio management.

    Timing (as implemented): at internal index ``t`` the observation holds the
    asset returns of rows ``t - window_size .. t - 1``. ``step`` first advances
    to ``t + 1`` and then applies the chosen weights to row ``t + 1``. The agent
    therefore never sees the return it is about to earn (no look-ahead), but it
    also decides without the most recent row, and row ``window_size`` is never
    traded.
    """

    # Reward modes understood by _calculate_reward.
    REWARD_MODES = ('return', 'sharpe', 'risk_adjusted')

    def __init__(self, returns, features=None, window_size=10, transaction_cost=0.001,
                 risk_aversion=1.0, initial_amount=1.0, reward_mode='risk_adjusted',
                 risk_window=20, cvar_alpha=0.05):
        """
        Initialize the environment.

        Args:
            returns (pd.DataFrame): DataFrame with asset returns (rows = periods, columns = assets).
            features (pd.DataFrame, optional): DataFrame with additional per-period features.
            window_size (int): Number of past return rows in each observation.
            transaction_cost (float): Cost per unit of traded weight (sum of |w_new - w_old|).
            risk_aversion (float): Multiplier on the CVaR penalty in 'risk_adjusted' mode.
            initial_amount (float): Initial portfolio value.
            reward_mode (str): 'return', 'sharpe', or 'risk_adjusted'.
            risk_window (int): Trailing periods used for the CVaR / Sharpe reward terms.
            cvar_alpha (float): Tail probability for the reward's CVaR (0.05 = worst 5%).

        Raises:
            ValueError: On an unknown reward_mode, too few rows for one step, or
                features whose length differs from returns.
        """
        super().__init__()

        if reward_mode not in self.REWARD_MODES:
            raise ValueError(f"reward_mode must be one of {self.REWARD_MODES}, got {reward_mode!r}")
        # reset() observes window_size rows, and the first step trades row window_size + 1.
        if len(returns) < window_size + 2:
            raise ValueError(
                f"Need at least window_size + 2 = {window_size + 2} rows of returns, got {len(returns)}"
            )

        self.returns = returns.values
        self.dates = returns.index
        self.asset_names = returns.columns
        self.num_assets = returns.shape[1]
        self.window_size = window_size
        self.transaction_cost = transaction_cost
        self.risk_aversion = risk_aversion
        self.initial_amount = initial_amount
        self.reward_mode = reward_mode
        self.risk_window = risk_window
        self.cvar_alpha = cvar_alpha

        # Set additional features if provided
        if features is not None:
            if features.shape[0] != returns.shape[0]:
                raise ValueError("Features and returns must have same number of time steps")
            self.features = features.values
            self.feature_dim = features.shape[1]
        else:
            self.features = None
            self.feature_dim = 0

        # Action space: long-only portfolio weights in [0, 1] (step() renormalizes them to sum to 1)
        self.action_space = spaces.Box(low=0, high=1, shape=(self.num_assets,), dtype=np.float32)

        # Observation space: flattened window of past returns, plus current features if any
        obs_dim = self.window_size * self.num_assets
        if self.features is not None:
            obs_dim += self.feature_dim

        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(obs_dim,), dtype=np.float32
        )

        # Initialize state
        self.reset()

    def _get_observation(self):
        """
        Construct the observation for the current step.

        Returns:
            np.array: float32 vector of the last ``window_size`` return rows (flattened),
            followed by the feature row at the current step when features are set.
        """
        returns_window = self.returns[self.current_step - self.window_size:self.current_step]
        obs = returns_window.flatten()

        if self.features is not None:
            obs = np.concatenate([obs, self.features[self.current_step]])

        # Match the declared observation_space dtype.
        return obs.astype(np.float32)

    def _transaction_cost(self, previous_weights, new_weights):
        """Return the rebalancing cost: transaction_cost * sum(|new - previous|)."""
        return self.transaction_cost * np.sum(np.abs(new_weights - previous_weights))

    def _calculate_reward(self, portfolio_return, previous_weights, new_weights):
        """
        Calculate the reward based on portfolio return and risk.

        Args:
            portfolio_return (float): Gross portfolio return for the period.
            previous_weights (np.array): Previous portfolio weights.
            new_weights (np.array): New portfolio weights.

        Returns:
            float: The reward.
                - 'return': return net of transaction costs.
                - 'risk_adjusted': net return - risk_aversion * max(-CVaR, 0).
                - 'sharpe': mean / std of the hypothetical trailing-window returns.
                Before enough steps have elapsed, every mode returns the net return.
        """
        net_return = portfolio_return - self._transaction_cost(previous_weights, new_weights)

        if self.reward_mode == 'return':
            return net_return

        # Risk terms need a trailing window; fall back to the net return until it exists.
        if self.current_step < self.window_size + self.risk_window:
            return net_return

        # Returns the new weights would have earned over the previous risk_window rows.
        window = self.returns[self.current_step - self.risk_window:self.current_step]
        hist_portfolio_returns = window @ new_weights

        if self.reward_mode == 'risk_adjusted':
            # Mean of the worst ceil(alpha * n) outcomes, the same estimator as calculate_cvar.
            # The former int(alpha * n) + 1 averaged 2 of 20 returns at alpha=0.05 (a 10% tail).
            tail_size = max(int(np.ceil(self.cvar_alpha * len(hist_portfolio_returns))), 1)
            cvar = np.mean(np.sort(hist_portfolio_returns)[:tail_size])
            # Penalize only tail losses: a tail made of gains (cvar > 0) must not
            # reduce the reward, as abs(cvar) did.
            return net_return - self.risk_aversion * max(-cvar, 0.0)

        # 'sharpe': Sharpe-like ratio of the hypothetical window returns.
        # Note: it ignores transaction costs and the realized period return.
        mean_return = np.mean(hist_portfolio_returns)
        std_return = np.std(hist_portfolio_returns)
        return mean_return / (std_return + 1e-6)  # avoid division by zero

    def step(self, action):
        """
        Rebalance to ``action`` and advance one period.

        Args:
            action (np.array): Desired portfolio weights. Values are clipped to
                [0, 1] and renormalized to sum to 1. An all-zero action keeps the
                current weights.

        Returns:
            tuple: (observation, reward, terminated, truncated, info). ``info``
            holds the portfolio value and return net of transaction costs, the
            gross return, the cost charged, the new weights, and the date of the
            traded row.

        Raises:
            RuntimeError: If called after the episode terminated without reset().
        """
        if self.current_step >= len(self.returns) - 1:
            raise RuntimeError("Episode has terminated; call reset() before step().")

        # Project the action onto the simplex.
        action = np.clip(np.asarray(action, dtype=np.float64), 0, 1)
        action_sum = np.sum(action)
        if action_sum > 0:
            action = action / action_sum
        else:
            # All-zero weights cannot be normalized; hold the current allocation.
            action = self.weights.copy()

        prev_weights = self.weights.copy()

        # Move to the next time step
        self.current_step += 1
        terminated = self.current_step >= len(self.returns) - 1
        truncated = False

        self.weights = action

        # current_step <= len(returns) - 1 here, so the final row is traded too.
        gross_return = float(np.dot(self.returns[self.current_step], self.weights))
        cost = self._transaction_cost(prev_weights, self.weights)
        net_return = gross_return - cost

        # Charge trading costs to wealth so reported performance matches the reward.
        self.portfolio_value *= (1 + net_return)
        self.portfolio_returns.append(net_return)

        reward = float(self._calculate_reward(gross_return, prev_weights, self.weights))
        observation = self._get_observation()

        info = {
            'portfolio_value': self.portfolio_value,
            'portfolio_return': net_return,
            'gross_return': gross_return,
            'transaction_cost': cost,
            'weights': self.weights.copy(),
            # Positional [] works for both pandas Index and ndarray; Index has no .iloc.
            'date': self.dates[self.current_step],
        }

        return observation, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        """
        Reset the environment to the initial state.

        Args:
            seed (int, optional): Forwarded to gymnasium's Env.reset.
            options (dict, optional): Accepted for gymnasium API compatibility; unused.

        Returns:
            tuple: (initial observation, empty info dict).
        """
        super().reset(seed=seed)

        self.current_step = self.window_size
        self.portfolio_value = self.initial_amount
        self.portfolio_returns = []

        # Start from an equally weighted portfolio
        self.weights = np.ones(self.num_assets) / self.num_assets

        observation = self._get_observation()
        return observation, {}

    def get_portfolio_history(self):
        """
        Rebuild the portfolio value path from the recorded net returns.

        Returns:
            tuple: (dates, portfolio_values). portfolio_values[0] is
            initial_amount, dated at row window_size; portfolio_values[k] is the
            value after the k-th step.
        """
        portfolio_values = [self.initial_amount]
        for r in self.portfolio_returns:
            portfolio_values.append(portfolio_values[-1] * (1 + r))

        dates = self.dates[self.window_size:self.window_size + len(portfolio_values)]

        return dates, portfolio_values

# Environment hyper-parameters shared by every data split
window_size = 20           # trailing periods of returns in each observation
transaction_cost = 0.001   # cost per unit of traded weight (10 bps)
risk_aversion = 1.0        # multiplier on the CVaR penalty in the reward
reward_mode = 'risk_adjusted'

env_kwargs = dict(
    window_size=window_size,
    transaction_cost=transaction_cost,
    risk_aversion=risk_aversion,
    reward_mode=reward_mode,
)

# One environment per split: training, validation, and held-out test data
train_env = PortfolioEnv(returns=train_data, **env_kwargs)
val_env = PortfolioEnv(returns=val_data, **env_kwargs)
test_env = PortfolioEnv(returns=test_data, **env_kwargs)

# Network input/output sizes follow from the (1-D) environment spaces
input_dim, = train_env.observation_space.shape
output_dim, = train_env.action_space.shape

print(f"State dimension: {input_dim}, Action dimension: {output_dim}")

# ## 4. REINFORCE Implementation

class PolicyNetwork(nn.Module):
    """
    Two-hidden-layer MLP policy for REINFORCE.

    Maps a state vector to one Dirichlet concentration parameter per asset.
    Each output is a Softplus activation plus a floor of 0.1, so it is strictly
    positive.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=64):
        super().__init__()
        # The attribute name `network` prefixes the state_dict keys; keep it so
        # saved checkpoints still load.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
            nn.Softplus(),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        concentration = self.network(x)
        # Dirichlet needs strictly positive concentrations; shift them away from zero.
        return concentration + 0.1

class ValueNetwork(nn.Module):
    """
    State-value baseline for REINFORCE.

    A two-hidden-layer MLP that maps a state vector to a single scalar estimate V(s).
    """

    def __init__(self, input_dim, hidden_dim=64):
        super().__init__()
        # `network` prefixes the state_dict keys; keep the name for checkpoint compatibility.
        hidden_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        ]
        self.network = nn.Sequential(*hidden_layers, nn.Linear(hidden_dim, 1))

    def forward(self, x):
        state_value = self.network(x)
        return state_value


class REINFORCEAgent:
    """
    REINFORCE agent with a learned value baseline for portfolio optimization.

    Portfolio weights are sampled from a Dirichlet distribution whose
    concentration parameters come from ``PolicyNetwork``. ``ValueNetwork``
    provides the baseline. Each ``update`` call consumes the stored episode.
    """

    def __init__(self, input_dim, output_dim, lr_policy=0.001, lr_value=0.001, gamma=0.99,
                 device='cpu', entropy_coef=0.01):
        """
        Initialize the REINFORCE agent.

        Args:
            input_dim (int): Dimension of input features.
            output_dim (int): Dimension of output (number of assets).
            lr_policy (float): Learning rate for policy network.
            lr_value (float): Learning rate for value network.
            gamma (float): Discount factor.
            device (str): Device to use for tensor operations.
            entropy_coef (float): Weight of the entropy bonus in the policy loss.
        """
        self.device = device
        self.gamma = gamma
        self.entropy_coef = entropy_coef

        # Initialize networks
        self.policy_net = PolicyNetwork(input_dim, output_dim).to(device)
        self.value_net = ValueNetwork(input_dim).to(device)

        # Initialize optimizers
        self.policy_optimizer = optim.Adam(self.policy_net.parameters(), lr=lr_policy)
        self.value_optimizer = optim.Adam(self.value_net.parameters(), lr=lr_value)

        # For storing episode history
        self.reset_episode()

    def reset_episode(self):
        """Clear the stored episode history."""
        self.log_probs = []
        self.values = []
        self.rewards = []
        self.entropies = []
        self.actions = []

    def select_action(self, state, evaluate=False):
        """
        Select an action using the policy network.

        Args:
            state (np.array): Current state.
            evaluate (bool): If True, return the Dirichlet mean (alpha / sum(alpha))
                deterministically and record nothing for training.

        Returns:
            np.array: Selected action (portfolio weights).
        """
        state = torch.as_tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)

        if evaluate:
            # No autograd graph is needed for a deterministic evaluation step.
            with torch.no_grad():
                alpha = self.policy_net(state).squeeze()
            return (alpha / alpha.sum()).cpu().numpy()

        # Sample from the Dirichlet defined by the predicted concentrations
        alpha = self.policy_net(state).squeeze()
        m = Dirichlet(alpha)
        action = m.sample()

        # Keep graph-connected quantities for the update
        self.log_probs.append(m.log_prob(action))
        self.entropies.append(m.entropy())
        self.values.append(self.value_net(state))
        self.actions.append(action)

        return action.detach().cpu().numpy()

    def update(self, last_value=0):
        """
        Run one REINFORCE update from the stored episode, then clear it.

        Discounted returns G_t = r_t + gamma * G_{t+1} are bootstrapped from
        ``last_value``. The policy minimizes
        sum_t[-log pi(a_t|s_t) * (G_t - V(s_t)) - entropy_coef * H_t], and the
        value network regresses V(s_t) onto G_t.

        Args:
            last_value (float): Value estimate for the final state.

        Returns:
            float: Policy loss (0.0 if no steps were stored).
            float: Value loss (0.0 if no steps were stored).
        """
        if not self.rewards:
            # Nothing to learn from; a plain int loss would fail on .backward().
            self.reset_episode()
            return 0.0, 0.0

        # Discounted returns, accumulated backwards from the bootstrap value
        discounted = []
        running = float(last_value)
        for reward in reversed(self.rewards):
            running = float(reward) + self.gamma * running
            discounted.append(running)
        discounted.reverse()

        # Explicit float32: environment rewards are np.float64, which would make a
        # float64 target that mse_loss cannot pair with float32 predictions in backward.
        returns = torch.tensor(discounted, dtype=torch.float32, device=self.device).unsqueeze(1)
        values = torch.cat(self.values)  # (T, 1)
        # The baseline is a constant in the policy gradient.
        advantages = (returns - values.detach()).squeeze(1)

        # Policy loss is summed over the episode, with entropy regularization
        log_probs = torch.stack(self.log_probs)
        entropies = torch.stack(self.entropies)
        policy_loss = (-log_probs * advantages - self.entropy_coef * entropies).sum()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 0.5)  # Gradient clipping
        self.policy_optimizer.step()

        # Update value network
        value_loss = F.mse_loss(values, returns)

        self.value_optimizer.zero_grad()
        value_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.value_net.parameters(), 0.5)  # Gradient clipping
        self.value_optimizer.step()

        # Reset episode history
        self.reset_episode()

        return policy_loss.item(), value_loss.item()

    def store_reward(self, reward):
        """
        Store a reward from the environment.

        Args:
            reward (float): Reward from the environment.
        """
        self.rewards.append(reward)

    def save(self, path):
        """
        Save the agent's networks and optimizer states.

        Args:
            path (str): Path to save the models.
        """
        torch.save({
            'policy_net': self.policy_net.state_dict(),
            'value_net': self.value_net.state_dict(),
            'policy_optimizer': self.policy_optimizer.state_dict(),
            'value_optimizer': self.value_optimizer.state_dict(),
        }, path)

    def load(self, path):
        """
        Load the agent's networks and optimizer states.

        Args:
            path (str): Path to load the models from.
        """
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine (and vice versa).
        checkpoint = torch.load(path, map_location=self.device)
        self.policy_net.load_state_dict(checkpoint['policy_net'])
        self.value_net.load_state_dict(checkpoint['value_net'])
        self.policy_optimizer.load_state_dict(checkpoint['policy_optimizer'])
        self.value_optimizer.load_state_dict(checkpoint['value_optimizer'])

# ## 5. Actor-Critic (A2C) Implementation

class ActorCritic(nn.Module):
    """
    Actor-critic network with a shared feature trunk.

    The actor head outputs Dirichlet concentrations over portfolio weights
    (Softplus + 0.1, so strictly positive). The critic head outputs a scalar
    state value.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=64):
        super().__init__()

        def hidden_block(in_features):
            # Linear projection into the hidden width followed by ReLU.
            return [nn.Linear(in_features, hidden_dim), nn.ReLU()]

        # Shared trunk; the attribute names define the state_dict keys, so they stay fixed.
        self.feature_extractor = nn.Sequential(*hidden_block(input_dim), *hidden_block(hidden_dim))
        # Actor head: positive concentration parameters
        self.actor = nn.Sequential(*hidden_block(hidden_dim), nn.Linear(hidden_dim, output_dim), nn.Softplus())
        # Critic head: scalar value estimate
        self.critic = nn.Sequential(*hidden_block(hidden_dim), nn.Linear(hidden_dim, 1))

    def forward(self, x):
        shared = self.feature_extractor(x)
        concentration = self.actor(shared) + 0.1  # keep concentrations away from zero
        state_value = self.critic(shared)
        return concentration, state_value


class A2CAgent:
    """
    Advantage Actor-Critic (A2C) agent for portfolio optimization.

    Uses a shared ``ActorCritic`` network: Dirichlet-distributed portfolio
    weights from the actor head, and a state-value baseline from the critic
    head. Advantages are computed with Generalized Advantage Estimation (GAE).
    """

    def __init__(self, input_dim, output_dim, lr=0.001, gamma=0.99, entropy_coef=0.01, value_coef=0.5,
                 device='cpu', gae_lambda=0.95):
        """
        Initialize the A2C agent.

        Args:
            input_dim (int): Dimension of input features.
            output_dim (int): Dimension of output (number of assets).
            lr (float): Learning rate.
            gamma (float): Discount factor.
            entropy_coef (float): Entropy regularization coefficient.
            value_coef (float): Value loss coefficient.
            device (str): Device to use for tensor operations.
            gae_lambda (float): GAE lambda (bias/variance trade-off of the advantage).
        """
        self.device = device
        self.gamma = gamma
        self.entropy_coef = entropy_coef
        self.value_coef = value_coef
        self.gae_lambda = gae_lambda

        # Initialize network and optimizer
        self.network = ActorCritic(input_dim, output_dim).to(device)
        self.optimizer = optim.Adam(self.network.parameters(), lr=lr)

        # For storing episode history
        self.reset_episode()

    def reset_episode(self):
        """Clear the stored episode history."""
        self.values = []
        self.log_probs = []
        self.rewards = []
        self.entropies = []
        self.actions = []

    def select_action(self, state, evaluate=False):
        """
        Select an action using the actor-critic network.

        Args:
            state (np.array): Current state.
            evaluate (bool): If True, return the Dirichlet mean (alpha / sum(alpha))
                deterministically and record nothing for training.

        Returns:
            np.array: Selected action (portfolio weights).
        """
        state = torch.as_tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)

        if evaluate:
            # No autograd graph is needed for a deterministic evaluation step.
            with torch.no_grad():
                alpha, _ = self.network(state)
                alpha = alpha.squeeze()
            return (alpha / alpha.sum()).cpu().numpy()

        alpha, value = self.network(state)
        alpha = alpha.squeeze()

        # Sample from the Dirichlet defined by the predicted concentrations
        m = Dirichlet(alpha)
        action = m.sample()

        # Keep graph-connected quantities for the update
        self.log_probs.append(m.log_prob(action))
        self.values.append(value)
        self.entropies.append(m.entropy())
        self.actions.append(action)

        return action.detach().cpu().numpy()

    def update(self, next_value=0):
        """
        Update the actor-critic network from the stored rollout, then clear it.

        Args:
            next_value (float): Value estimate used to bootstrap after the last
                stored step (0 for a terminal state).

        Returns:
            float: Total loss.
            float: Actor loss.
            float: Critic loss.
            float: Entropy loss.
            All four are 0.0 if no steps were stored.
        """
        if not self.rewards:
            # Nothing to learn from; a plain int loss would fail on .backward().
            self.reset_episode()
            return 0.0, 0.0, 0.0, 0.0

        values = torch.cat(self.values)  # (T, 1), still attached for the critic loss
        # Plain floats for the GAE recursion; next_value closes the sequence.
        value_estimates = values.detach().squeeze(1).tolist() + [float(next_value)]

        # GAE: A_t = delta_t + gamma * lambda * A_{t+1}
        advantages = []
        gae = 0.0
        for t in reversed(range(len(self.rewards))):
            delta = float(self.rewards[t]) + self.gamma * value_estimates[t + 1] - value_estimates[t]
            gae = delta + self.gamma * self.gae_lambda * gae
            advantages.append(gae)
        advantages.reverse()

        # Explicit float32: np.float64 rewards would otherwise give float64 targets
        # that mse_loss cannot pair with float32 predictions in backward.
        advantages = torch.tensor(advantages, dtype=torch.float32, device=self.device)
        # Critic target (lambda-return), computed before advantage normalization
        returns = (advantages + values.detach().squeeze(1)).unsqueeze(1)

        # Normalize advantages. With a single step, std() (unbiased) is NaN and
        # would turn the whole update, and then the weights, into NaN.
        if advantages.numel() > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Actor (policy) loss, averaged over steps
        actor_loss = -(torch.stack(self.log_probs) * advantages).mean()

        # Critic (value) loss
        critic_loss = F.mse_loss(values, returns)

        # Entropy loss (negative entropy: minimizing it encourages exploration)
        entropy_loss = -torch.stack(self.entropies).mean()

        total_loss = actor_loss + self.value_coef * critic_loss + self.entropy_coef * entropy_loss

        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.network.parameters(), 0.5)  # Gradient clipping
        self.optimizer.step()

        # Reset episode history
        self.reset_episode()

        return (
            total_loss.item(),
            actor_loss.item(),
            critic_loss.item(),
            entropy_loss.item()
        )

    def store_reward(self, reward):
        """
        Store a reward from the environment.

        Args:
            reward (float): Reward from the environment.
        """
        self.rewards.append(reward)

    def save(self, path):
        """
        Save the agent's network and optimizer state.

        Args:
            path (str): Path to save the model.
        """
        torch.save({
            'network': self.network.state_dict(),
            'optimizer': self.optimizer.state_dict(),
        }, path)

    def load(self, path):
        """
        Load the agent's network and optimizer state.

        Args:
            path (str): Path to load the model from.
        """
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine (and vice versa).
        checkpoint = torch.load(path, map_location=self.device)
        self.network.load_state_dict(checkpoint['network'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])

# ## 6. Training and Evaluation Framework

def evaluate_agent(env, agent, num_episodes=5):
    """
    Average the undiscounted episode reward of ``agent`` over several episodes.

    Each episode starts from ``env.reset()``. Actions come from
    ``agent.select_action(state, evaluate=True)``, and the episode ends as soon
    as ``env.step`` reports either termination or truncation.

    Args:
        env: The environment (Gymnasium-style ``reset``/``step`` API).
        agent: The agent to evaluate.
        num_episodes (int): Number of episodes to run.

    Returns:
        float: Sum of per-episode rewards divided by ``num_episodes``.
    """
    episode_totals = []

    for _ in range(num_episodes):
        observation, _ = env.reset()
        running_total = 0
        finished = False

        while not finished:
            chosen = agent.select_action(observation, evaluate=True)
            observation, step_reward, terminated, truncated, _ = env.step(chosen)
            running_total += step_reward
            finished = terminated or truncated

        episode_totals.append(running_total)

    return sum(episode_totals) / num_episodes

def train_reinforce(env, agent, num_episodes, eval_interval=20, eval_episodes=5, save_path=None):
    """
    Train a REINFORCE agent with one ``agent.update()`` call per episode.

    Every ``eval_interval`` episodes the agent is scored with ``evaluate_agent``
    on the same ``env``. When ``save_path`` is given, the checkpoint is written
    whenever that score beats every earlier evaluation (the first evaluation
    always saves). A plot of training and evaluation rewards is shown at the end.

    Args:
        env: The environment.
        agent: The REINFORCE agent.
        num_episodes (int): Number of episodes to train for.
        eval_interval (int): Interval for evaluation during training.
        eval_episodes (int): Number of episodes for evaluation.
        save_path (str, optional): Path to save the best agent. Its parent
            directory is created if needed; a bare filename is saved into the
            current working directory.

    Returns:
        tuple: (agent, training_rewards, evaluation_rewards)
    """
    training_rewards = []
    evaluation_rewards = []

    for episode in tqdm(range(num_episodes)):
        # Reset the environment and agent episode history
        state, _ = env.reset()
        agent.reset_episode()
        done = False
        truncated = False
        episode_reward = 0

        # Roll out one full episode, storing every reward for the update
        while not done and not truncated:
            action = agent.select_action(state)
            next_state, reward, done, truncated, _ = env.step(action)
            agent.store_reward(reward)
            episode_reward += reward
            state = next_state

        # Update once per episode, after all of its rewards have been stored
        agent.update()
        training_rewards.append(episode_reward)

        # Evaluate the agent periodically
        if (episode + 1) % eval_interval == 0:
            eval_reward = evaluate_agent(env, agent, eval_episodes)
            evaluation_rewards.append(eval_reward)

            print(f"\nEpisode {episode+1}/{num_episodes}")
            print(f"Training reward: {episode_reward:.4f}")
            print(f"Evaluation reward: {eval_reward:.4f}")

            # Save the best agent seen so far
            if save_path and (len(evaluation_rewards) == 1 or eval_reward > max(evaluation_rewards[:-1])):
                save_dir = os.path.dirname(save_path)
                if save_dir:  # os.makedirs('') raises for a bare filename
                    os.makedirs(save_dir, exist_ok=True)
                agent.save(save_path)
                print(f"Saved model to {save_path}")

    # Plot training progress. Evaluation k (0-based) runs after episode index
    # (k + 1) * eval_interval - 1, which is what the arange below enumerates.
    plt.figure(figsize=(12, 5))
    plt.plot(training_rewards, label='Training Rewards')
    plt.plot(np.arange(eval_interval-1, num_episodes, eval_interval), evaluation_rewards, label='Evaluation Rewards')
    plt.xlabel('Episode')
    plt.ylabel('Reward')
    plt.title('REINFORCE Training Progress')
    plt.legend()
    plt.grid(True)
    plt.show()

    return agent, training_rewards, evaluation_rewards

def train_a2c(env, agent, num_episodes, eval_interval=20, eval_episodes=5, save_path=None):
    """
    Train an A2C agent with one ``agent.update()`` call per episode.

    ``agent.update()`` must return ``(total_loss, actor_loss, critic_loss,
    entropy_loss)``. The three component losses from the most recent update
    are printed at each evaluation.

    Every ``eval_interval`` episodes the agent is scored with ``evaluate_agent``
    on the same ``env``. When ``save_path`` is given, the checkpoint is written
    whenever that score beats every earlier evaluation (the first evaluation
    always saves). A plot of training and evaluation rewards is shown at the end.

    Args:
        env: The environment.
        agent: The A2C agent.
        num_episodes (int): Number of episodes to train for.
        eval_interval (int): Interval for evaluation during training.
        eval_episodes (int): Number of episodes for evaluation.
        save_path (str, optional): Path to save the best agent. Its parent
            directory is created if needed; a bare filename is saved into the
            current working directory.

    Returns:
        tuple: (agent, training_rewards, evaluation_rewards)
    """
    training_rewards = []
    evaluation_rewards = []

    for episode in tqdm(range(num_episodes)):
        # Reset the environment and agent episode history
        state, _ = env.reset()
        agent.reset_episode()
        done = False
        truncated = False
        episode_reward = 0

        # Roll out one full episode, storing every reward for the update
        while not done and not truncated:
            action = agent.select_action(state)
            next_state, reward, done, truncated, _ = env.step(action)
            agent.store_reward(reward)
            episode_reward += reward
            state = next_state

        # Update once per episode, after all of its rewards have been stored
        total_loss, actor_loss, critic_loss, entropy_loss = agent.update()
        training_rewards.append(episode_reward)

        # Evaluate the agent periodically
        if (episode + 1) % eval_interval == 0:
            eval_reward = evaluate_agent(env, agent, eval_episodes)
            evaluation_rewards.append(eval_reward)

            print(f"\nEpisode {episode+1}/{num_episodes}")
            print(f"Training reward: {episode_reward:.4f}")
            print(f"Evaluation reward: {eval_reward:.4f}")
            print(f"Actor loss: {actor_loss:.4f}, Critic loss: {critic_loss:.4f}, Entropy loss: {entropy_loss:.4f}")

            # Save the best agent seen so far
            if save_path and (len(evaluation_rewards) == 1 or eval_reward > max(evaluation_rewards[:-1])):
                save_dir = os.path.dirname(save_path)
                if save_dir:  # os.makedirs('') raises for a bare filename
                    os.makedirs(save_dir, exist_ok=True)
                agent.save(save_path)
                print(f"Saved model to {save_path}")

    # Plot training progress. Evaluation k (0-based) runs after episode index
    # (k + 1) * eval_interval - 1, which is what the arange below enumerates.
    plt.figure(figsize=(12, 5))
    plt.plot(training_rewards, label='Training Rewards')
    plt.plot(np.arange(eval_interval-1, num_episodes, eval_interval), evaluation_rewards, label='Evaluation Rewards')
    plt.xlabel('Episode')
    plt.ylabel('Reward')
    plt.title('A2C Training Progress')
    plt.legend()
    plt.grid(True)
    plt.show()

    return agent, training_rewards, evaluation_rewards

# ## 7. Train REINFORCE Agent

# Checkpoint file that train_reinforce overwrites with the best-evaluating policy
reinforce_path = os.path.join('saved_models', 'reinforce_agent.pth')

# REINFORCE agent: separate learning rates for its policy and value parts
reinforce_agent = REINFORCEAgent(
    input_dim=input_dim,
    output_dim=output_dim,
    gamma=0.99,
    lr_policy=0.001,
    lr_value=0.001,
    device=device,
)

# Training is opt-in: uncomment the call below to train and checkpoint the agent
# reinforce_agent, reinforce_train_rewards, reinforce_eval_rewards = train_reinforce(
#     env=train_env,
#     agent=reinforce_agent,
#     num_episodes=200,
#     eval_interval=20,
#     save_path=reinforce_path
# )

# ## 8. Train A2C Agent

# Checkpoint file that train_a2c overwrites with the best-evaluating policy
a2c_path = os.path.join('saved_models', 'a2c_agent.pth')

# A2C agent: one shared learning rate plus entropy and value loss weights
a2c_agent = A2CAgent(
    input_dim=input_dim,
    output_dim=output_dim,
    gamma=0.99,
    lr=0.001,
    value_coef=0.5,
    entropy_coef=0.01,
    device=device,
)

# Training is opt-in: uncomment the call below to train and checkpoint the agent
# a2c_agent, a2c_train_rewards, a2c_eval_rewards = train_a2c(
#     env=train_env,
#     agent=a2c_agent,
#     num_episodes=200,
#     eval_interval=20,
#     save_path=a2c_path
# )

# ## 9. Performance Evaluation and Analysis

def evaluate_portfolio(env, agent, returns_df, plot=True, title=None):
    """
    Run one evaluation-mode episode of ``agent`` in ``env`` and report metrics.

    Actions come from ``agent.select_action(obs, evaluate=True)``. The
    environment reward is ignored. Metrics are computed from ``env.portfolio_value``
    after ``reset()`` and from the ``'portfolio_value'``, ``'weights'`` and
    ``'date'`` entries of the ``info`` dict returned by each ``env.step``.

    Args:
        env: The portfolio environment. Must expose ``portfolio_value`` and,
            when plotting, ``asset_names``.
        agent: The trained agent.
        returns_df (pd.DataFrame): DataFrame with returns data. Not used by the
            computation; kept for interface compatibility.
        plot (bool): Whether to plot the performance and print the metrics.
        title (str): Title for the plot.

    Returns:
        dict: Performance metrics with keys 'total_return', 'annualized_return',
        'sharpe_ratio', 'cvar', 'volatility', 'max_drawdown' and 'avg_turnover'.
    """
    obs, _ = env.reset()
    done = False
    truncated = False

    # portfolio_values starts with the pre-trade value and gains one entry per
    # step, so it is always one element longer than weights_history and dates.
    portfolio_values = [env.portfolio_value]
    weights_history = []
    dates = []

    while not done and not truncated:
        action = agent.select_action(obs, evaluate=True)
        obs, _, done, truncated, info = env.step(action)

        portfolio_values.append(info['portfolio_value'])
        weights_history.append(info['weights'])
        dates.append(info['date'])

    portfolio_values = np.array(portfolio_values)
    weights_history = np.array(weights_history)

    # Simple per-step returns of the portfolio value
    portfolio_returns = np.diff(portfolio_values) / portfolio_values[:-1]

    total_return = (portfolio_values[-1] / portfolio_values[0]) - 1
    # Annualize assuming 252 steps (trading days) per year
    annualized_return = (1 + total_return) ** (252 / len(portfolio_returns)) - 1
    sharpe = calculate_sharpe_ratio(portfolio_returns)
    cvar = calculate_cvar(portfolio_returns)
    volatility = np.std(portfolio_returns) * np.sqrt(252)
    max_drawdown = calculate_max_drawdown(portfolio_values)

    # Average one-way turnover: half the L1 change in weights between consecutive steps
    if len(weights_history) > 1:
        step_turnover = np.abs(np.diff(weights_history, axis=0)).sum(axis=1) / 2.0
        avg_turnover = np.mean(step_turnover)
    else:
        avg_turnover = 0

    metrics = {
        'total_return': total_return,
        'annualized_return': annualized_return,
        'sharpe_ratio': sharpe,
        'cvar': cvar,
        'volatility': volatility,
        'max_drawdown': max_drawdown,
        'avg_turnover': avg_turnover
    }

    if plot:
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), gridspec_kw={'height_ratios': [2, 1]})

        # Pair each date with the value reached after that step; the initial
        # value has no date, and plotting it would make x and y lengths differ.
        ax1.plot(dates, portfolio_values[1:], label='Portfolio Value')
        ax1.set_title(title if title else 'Portfolio Performance')
        ax1.set_ylabel('Portfolio Value')
        ax1.legend()
        ax1.grid(True)

        # Plot asset weights over time
        if len(weights_history) > 0:
            df_weights = pd.DataFrame(weights_history, columns=env.asset_names, index=dates)
            df_weights.plot(kind='area', stacked=True, ax=ax2)
            ax2.set_title('Asset Allocation Over Time')
            ax2.set_ylabel('Weight')
            ax2.set_xlabel('Date')
            ax2.grid(True)

        plt.tight_layout()
        plt.show()

        print("\nPerformance Metrics:")
        print(f"Total Return: {total_return:.4f} ({annualized_return:.4f} annualized)")
        print(f"Sharpe Ratio: {sharpe:.4f}")
        print(f"CVaR (5%): {cvar:.4f}")
        print(f"Volatility (annualized): {volatility:.4f}")
        print(f"Maximum Drawdown: {max_drawdown:.4f}")
        print(f"Average Turnover: {avg_turnover:.4f}")

    return metrics

def compare_agents(envs, agents, agent_names, returns_df, plot=True):
    """
    Compare agents, each evaluated on its own paired environment.

    ``envs``, ``agents`` and ``agent_names`` are consumed in lockstep via
    ``zip``, so extra items in a longer list are ignored. Metrics come from
    ``evaluate_portfolio``. A second evaluation-mode rollout per agent records
    its portfolio value curve against the dates reported by its environment.

    Args:
        envs (list): Portfolio environments, one per agent.
        agents (list): List of trained agents.
        agent_names (list): List of agent names (used as plot labels and the
            'agent' column).
        returns_df (pd.DataFrame): DataFrame with returns data, forwarded to
            ``evaluate_portfolio``.
        plot (bool): Whether to plot the comparison.

    Returns:
        pd.DataFrame: One row of performance metrics per agent.
    """
    all_metrics = []
    value_curves = {}  # agent name -> (dates, post-step portfolio values)

    for env, agent, name in zip(envs, agents, agent_names):
        metrics = evaluate_portfolio(env, agent, returns_df, plot=False)
        metrics['agent'] = name
        all_metrics.append(metrics)

        # Record dates for every agent (not just the first), and pair each date
        # with the value reached after that step so x and y lengths match.
        obs, _ = env.reset()
        done = False
        truncated = False
        dates = []
        portfolio_values = []

        while not done and not truncated:
            action = agent.select_action(obs, evaluate=True)
            obs, _, done, truncated, info = env.step(action)
            dates.append(info['date'])
            portfolio_values.append(info['portfolio_value'])

        value_curves[name] = (dates, portfolio_values)

    metrics_df = pd.DataFrame(all_metrics)

    # With no agents there is no 'agent' column to index or anything to draw
    if plot and all_metrics:
        plt.figure(figsize=(12, 6))

        for name, (dates, values) in value_curves.items():
            plt.plot(dates, values, label=name)

        plt.title('Portfolio Performance Comparison')
        plt.xlabel('Date')
        plt.ylabel('Portfolio Value')
        plt.legend()
        plt.grid(True)
        plt.show()

        # Plot metrics comparison
        metrics_for_plot = ['annualized_return', 'sharpe_ratio', 'cvar', 'volatility', 'max_drawdown', 'avg_turnover']
        metrics_df_for_plot = metrics_df.set_index('agent')[metrics_for_plot]

        plt.figure(figsize=(14, 7))
        sns.heatmap(metrics_df_for_plot, annot=True, cmap='coolwarm', fmt='.4f', cbar=True)
        plt.title('Performance Metrics Comparison')
        plt.tight_layout()
        plt.show()

    return metrics_df

# Load agents if they were trained and saved previously
def _load_saved_agent(make_agent, path, label):
    """
    Build a fresh agent and restore it from a checkpoint, if one is usable.

    Args:
        make_agent (callable): Zero-argument factory returning a new agent.
        path (str): Checkpoint path written by the agent's ``save`` method.
        label (str): Human-readable agent name used in status messages.

    Returns:
        The loaded agent, or None when the checkpoint is missing or unreadable.
    """
    if not os.path.exists(path):
        print(f"No saved {label} model found. Please train the model first.")
        return None
    # Construct outside the try, so configuration errors (e.g. an undefined
    # input_dim) surface instead of being misreported as a missing model.
    candidate = make_agent()
    try:
        candidate.load(path)
    except Exception as err:  # notebook top level: report the reason and continue
        print(f"Could not load {label} model from {path}: {err}")
        return None
    print(f"Loaded {label} agent from saved model")
    return candidate


# Replace the agents defined earlier only when a checkpoint actually loads.
# A failed load must not swap a configured (or just-trained) agent for an
# untrained one.
_loaded_agent = _load_saved_agent(
    lambda: REINFORCEAgent(input_dim, output_dim, device=device), reinforce_path, "REINFORCE"
)
if _loaded_agent is not None:
    reinforce_agent = _loaded_agent

_loaded_agent = _load_saved_agent(
    lambda: A2CAgent(input_dim, output_dim, device=device), a2c_path, "A2C"
)
if _loaded_agent is not None:
    a2c_agent = _loaded_agent

# ## 10. Stress Testing on Crisis Periods

def get_crisis_periods():
    """
    Return the predefined market-stress windows used for stress testing.

    Each entry maps a period name to a dict with ``'start'`` and ``'end'``
    dates as 'YYYY-MM-DD' strings. A new dict is built on every call, so
    callers may mutate the result freely.

    Returns:
        dict: Dictionary with crisis periods.
    """
    return {
        'financial_crisis_2008': {'start': '2008-01-01', 'end': '2009-06-30'},
        'covid_crash_2020': {'start': '2020-02-01', 'end': '2020-05-31'},
    }

def stress_test_performance(env, agent, crisis_periods, returns_df):
    """
    Evaluate agent performance during crisis periods.

    For each period, the rows of ``returns_df`` whose index falls within
    [start, end] (inclusive) are selected. A new ``PortfolioEnv`` is built
    from those rows, reusing ``env``'s window size, transaction cost, risk
    aversion, initial amount and reward mode, and the agent is evaluated with
    ``evaluate_portfolio`` (which also plots).

    Args:
        env: The portfolio environment whose settings are copied.
        agent: The trained agent.
        crisis_periods (dict): Maps period name -> {'start': str, 'end': str}.
        returns_df (pd.DataFrame): Returns data indexed by date.

    Returns:
        dict: Performance metrics per period. Periods with no data are omitted.
    """
    crisis_metrics = {}

    # Label slicing on a sorted index includes every row inside the bounds,
    # even when the bounds themselves are not index entries. The predefined
    # periods start on non-trading days (2008-01-01 is New Year's Day,
    # 2020-02-01 is a Saturday), so an exact-membership check would skip them.
    sorted_returns = returns_df.sort_index()

    for period_name, period_dates in crisis_periods.items():
        start_date = pd.Timestamp(period_dates['start'])
        end_date = pd.Timestamp(period_dates['end'])

        crisis_returns = sorted_returns.loc[start_date:end_date]

        if crisis_returns.empty:
            print(f"No data available for period: {period_name}")
            continue

        # Create a new environment with crisis period data
        crisis_env = PortfolioEnv(
            returns=crisis_returns,
            window_size=env.window_size,
            transaction_cost=env.transaction_cost,
            risk_aversion=env.risk_aversion,
            initial_amount=env.initial_amount,
            reward_mode=env.reward_mode
        )

        metrics = evaluate_portfolio(
            crisis_env, agent, crisis_returns,
            plot=True,
            title=f'Portfolio Performance During {period_name}'
        )

        crisis_metrics[period_name] = metrics

    return crisis_metrics

# Perform stress testing if agents are loaded
# Note: This requires downloading additional data for the crisis periods
# Uncomment to run stress testing

# Download data for crisis periods
# crisis_periods = get_crisis_periods()
# tickers = ['SPY', 'QQQ', 'GLD', 'TLT', 'VNQ', 'BND', 'VWO']

# crisis_start = min([pd.Timestamp(period['start']) for period in crisis_periods.values()])
# crisis_end = max([pd.Timestamp(period['end']) for period in crisis_periods.values()])
# crisis_data = download_stock_data(tickers, crisis_start.strftime('%Y-%m-%d'), crisis_end.strftime('%Y-%m-%d'))
# crisis_returns = calculate_returns(crisis_data)

# if 'reinforce_agent' in locals():
#     print("\nStress testing REINFORCE agent")
#     reinforce_crisis_metrics = stress_test_performance(test_env, reinforce_agent, crisis_periods, crisis_returns)

# if 'a2c_agent' in locals():
#     print("\nStress testing A2C agent")
#     a2c_crisis_metrics = stress_test_performance(test_env, a2c_agent, crisis_periods, crisis_returns)

# ## 11. Sensitivity Analysis to Risk Aversion Parameter

def risk_aversion_sensitivity(env_template, agent, returns_df, risk_aversions=(0.0, 0.5, 1.0, 2.0, 5.0)):
    """
    Evaluate one agent in environments that differ only in ``risk_aversion``.

    For each value a ``PortfolioEnv`` is built from ``returns_df`` with
    ``reward_mode='risk_adjusted'``, copying the other settings from
    ``env_template``. The agent is evaluated with ``evaluate_portfolio``, and
    four metrics are plotted against the risk-aversion value.

    NOTE(review): the agent is not retrained per value, and
    ``evaluate_portfolio`` discards the environment reward. The metrics can
    therefore only change if PortfolioEnv's observations depend on
    ``risk_aversion``. To study how risk aversion shapes the learned policy,
    train one agent per value. Confirm which analysis is intended.

    Args:
        env_template: Template environment to clone with different risk aversion values.
        agent: The agent to evaluate.
        returns_df (pd.DataFrame): DataFrame with returns data.
        risk_aversions (iterable of float): Risk aversion values to test. A tuple
            default avoids the shared-mutable-default pitfall; any iterable works.

    Returns:
        pd.DataFrame: Performance metrics plus a 'risk_aversion' column, one row per value.
    """
    metrics_list = []

    for risk_aversion in risk_aversions:
        env = PortfolioEnv(
            returns=returns_df,
            window_size=env_template.window_size,
            transaction_cost=env_template.transaction_cost,
            risk_aversion=risk_aversion,
            initial_amount=env_template.initial_amount,
            reward_mode='risk_adjusted'  # Always use risk-adjusted reward for this analysis
        )

        metrics = evaluate_portfolio(
            env, agent, returns_df,
            plot=False
        )

        metrics['risk_aversion'] = risk_aversion
        metrics_list.append(metrics)

    metrics_df = pd.DataFrame(metrics_list)

    # Nothing to plot (and no 'risk_aversion' column) when no values were given
    if metrics_df.empty:
        return metrics_df

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    axes = axes.flatten()

    metrics_to_plot = ['annualized_return', 'sharpe_ratio', 'volatility', 'cvar']
    titles = ['Annualized Return', 'Sharpe Ratio', 'Volatility', 'CVaR']

    for ax, metric, title in zip(axes, metrics_to_plot, titles):
        ax.plot(metrics_df['risk_aversion'], metrics_df[metric], marker='o')
        ax.set_title(title)
        ax.set_xlabel('Risk Aversion Parameter')
        ax.grid(True)

    plt.tight_layout()
    plt.show()

    return metrics_df

# Perform risk aversion sensitivity analysis if an agent is loaded
# Uncomment to run sensitivity analysis

# if 'reinforce_agent' in locals():
#     print("\nRisk Aversion Sensitivity Analysis for REINFORCE Agent")
#     reinforce_sensitivity = risk_aversion_sensitivity(test_env, reinforce_agent, test_data)

# if 'a2c_agent' in locals():
#     print("\nRisk Aversion Sensitivity Analysis for A2C Agent")
#     a2c_sensitivity = risk_aversion_sensitivity(test_env, a2c_agent, test_data)

# ## 12. Conclusion

# Summary of the portfolio management project using reinforcement learning:
# 
# 1. We implemented REINFORCE and A2C algorithms for portfolio optimization
# 2. We incorporated risk-sensitivity through CVaR penalty in the reward function
# 3. We evaluated performance on historical data and during crisis periods
# 4. We analyzed the sensitivity to the risk aversion parameter
# 
# Key findings:
# - Actor-Critic methods typically converge faster than REINFORCE
# - Higher risk aversion leads to more conservative portfolios with lower volatility
# - The agents can adapt to different market regimes, including crisis periods
# - Transaction costs significantly impact overall performance
# 
# Future improvements:
# - Implement more advanced policy gradient methods (PPO, SAC)
# - Add more sophisticated features (technical indicators, sentiment, etc.)
# - Incorporate transaction cost optimization directly into the learning objective
# - Test with different risk measures beyond CVaR