# Notebook 05 â€“ PPO Agent Training (Revised)
## Multimodal FinRL Trading System for EGX

This notebook trains a PPO agent using the multimodal dataset.
The implementation strictly follows the FinRL API to avoid runtime errors.


In [88]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [89]:
!pip install -q finrl stable-baselines3 gymnasium pandas numpy

In [90]:
import os
import pandas as pd
import numpy as np

!pip install alpaca-trade-api
!pip install exchange-calendars

from finrl.meta.env_stock_trading.env_stocktrading import StockTradingEnv
from stable_baselines3 import PPO



In [91]:
BASE_DIR = '/content/drive/MyDrive/finrl-egx-multimodal'
DATA_PATH = os.path.join(BASE_DIR, 'data', 'multimodal_finrl_data.csv')
RESULTS_DIR = os.path.join(BASE_DIR, 'results')
os.makedirs(RESULTS_DIR, exist_ok=True)

In [92]:
df = pd.read_csv(DATA_PATH)
df['Date'] = pd.to_datetime(df['Date'])
df.rename(columns={'Close': 'close', 'Date': 'date', 'Adj Close': 'adjcp'}, inplace=True)
df['close'] = pd.to_numeric(df['close'], errors='coerce') # Ensure 'close' column is numeric
df = df.sort_values(['date', 'tic']).reset_index(drop=True) # Ensure 'date' is a column and a default integer index
df.head()

  return datetime.utcnow().replace(tzinfo=utc)


Unnamed: 0,date,adjcp,close,High,Low,Open,Volume,daily_return,rsi,sma_20,sma_50,volatility,sentiment_news,sentiment_social,tic
0,2020-03-15,1.230391,1.9,2.059999942779541,1.850000023841858,2.180000066757202,2335434,-0.12844,13.470857,1.802846,2.059285,0.042764,0.0,-0.314831,AMOC
1,2020-03-15,28.729263,31.472473,34.14461898803711,31.447547912597656,34.937286376953125,17661326,-0.100841,15.340648,36.029904,37.541819,0.031654,0.0,-0.145865,COMI
2,2020-03-15,6.516446,7.81,8.380000114440918,7.769999980926514,8.630000114440918,2740164,-0.072446,23.45679,7.719194,8.512932,0.033478,0.0,-0.09956,SWDY
3,2020-03-16,1.133255,1.75,1.899999976158142,1.7100000381469729,1.899999976158142,2396915,-0.078947,12.021511,1.754926,2.034807,0.044315,0.0,-0.696901,AMOC
4,2020-03-16,26.80427,29.363672,31.25810432434082,28.34167861938477,31.47247314453125,5148437,-0.067005,12.860626,35.425101,37.32156,0.033519,0.078292,0.078292,COMI


  return datetime.utcnow().replace(tzinfo=utc)


In [93]:
TECH_INDICATORS = [
    'rsi', 'sma_20', 'sma_50', 'volatility',
    'sentiment_news', 'sentiment_social'
]

In [94]:
train_df = df[df['date'] < '2024-01-01'] # Use 'date' column for slicing
test_df = df[df['date'] >= '2024-01-01'] # Use 'date' column for slicing
print(train_df.shape, test_df.shape)

(2772, 15) (1452, 15)


In [95]:
import numpy as np
import pandas as pd # Explicitly import pandas

class CustomStockTradingEnv(StockTradingEnv):
    def __init__(self, df, stock_dim, hmax, initial_amount, num_stock_shares, buy_cost_pct, sell_cost_pct, reward_scaling, state_space, action_space, tech_indicator_list, turbulence_threshold=0, risk_indicator_col='vix', make_plots=False, print_verbosity=10, day=0, initial=True, previous_state=[], model_name='', mode='train', iteration=''):
        # Store unique sorted dates from the DataFrame BEFORE calling super().__init__
        self._trading_days = sorted(df['date'].unique())
        self.max_day = len(self._trading_days) - 1

        # Initialize balance and initial_cash BEFORE super().__init__
        self.balance = float(initial_amount) # Ensure balance is float
        self.initial_cash = float(initial_amount)

        # Initialize turbulence_make_day as it's used in the custom step method
        self.turbulence_make_day = 1 # Default to 1, meaning turbulence is checked every day

        # Call super().__init__. It will call self._initiate_state() which we have overridden.
        super().__init__(df, stock_dim, hmax, initial_amount, num_stock_shares, buy_cost_pct, sell_cost_pct, reward_scaling, state_space, action_space, tech_indicator_list, turbulence_threshold, risk_indicator_col, make_plots, print_verbosity, day, initial, previous_state, model_name, mode, iteration)

    def _get_data_for_current_day(self):
        # Helper method to get the correct DataFrame slice for the current day index
        if self.day > self.max_day: # This should be handled by terminal condition but as a safeguard
            return pd.DataFrame() # Return empty DataFrame if past valid days
        current_date = self._trading_days[self.day]
        data_for_day = self.df[self.df['date'] == current_date]
        if data_for_day.empty:
            raise ValueError(f"No data found for date: {current_date} at day index: {self.day}. Check your data or day indexing logic.")
        return data_for_day

    def _initiate_state(self):
        # Ensure self.data is always a DataFrame for the current day (self.day=0 initially)
        self.data = self._get_data_for_current_day()

        state = (
            [self.balance]
            + self.data.close.values.tolist() # current stock prices
            + self.num_stock_shares # current stock holdings
        )
        for indicator in self.tech_indicator_list:
            state += self.data[indicator].values.tolist()

        return state

    def _update_state(self):
        # Ensure self.data is always a DataFrame for the current day
        # This method is called after day increment in step, so self.data should reflect the new day.
        self.data = self._get_data_for_current_day()

        state = (
            [self.balance]
            + self.data.close.values.tolist()
            + self.num_stock_shares
        )
        for indicator in self.tech_indicator_list:
            state += self.data[indicator].values.tolist()
        return state

    def reset(self, seed=None, options=None):
        # Completely override reset to ensure proper initialization
        if seed is not None:
            # This is typically handled by the gymnasium wrapper
            pass

        self.day = 0
        self.data = self._get_data_for_current_day() # Ensure self.data is correct DataFrame for day 0

        self.initial = True
        self.num_stock_shares = [0] * self.stock_dim
        self.balance = self.initial_cash # Reset balance to initial
        self.cost = 0
        self.transactions = []
        self.rewards = []
        initial_asset_value = self.balance + sum(np.array(self.num_stock_shares) * np.array(self.data.close.values))
        self.asset_memory = [initial_asset_value]
        self.roi_memory = [0]
        self.state = self._initiate_state() # Calls our overridden _initiate_state

        self.account_information = {
            "cash": self.balance,
            "asset_value": sum(np.array(self.num_stock_shares) * np.array(self.data.close.values)),
            "number_of_stocks": sum(self.num_stock_shares),
            "total_assets": initial_asset_value
        }

        # Return state and info dictionary as per gymnasium API
        return self.state, self.account_information

    # Override the step method to control data fetching and turbulence calculation
    def step(self, actions):
        # print("CustomStockTradingEnv step called") # Debug statement
        self.terminal = self.day >= self.max_day

        if self.terminal:
            # Episode termination logic, return gymnasium-compatible format
            df_total_value = pd.DataFrame(self.asset_memory)
            tot_reward = self.asset_memory[-1] - self.asset_memory[0]
            df_total_value.columns = ["account_value"]
            df_total_value.set_index(df_total_value.columns[0], inplace=True)
            df_total_value.index.name = None
            if self.make_plots:
                self.plot(df_total_value)
            # Reward and state should be from the last valid step
            return self.state, self.rewards[-1] if self.rewards else 0.0, True, False, {"e": self.episode}

        else:
            # Get current day's data before processing actions
            self.data = self._get_data_for_current_day() # Ensure self.data is DataFrame for action processing

            begin_total_asset = self.balance + sum(self.num_stock_shares * self.data.close.values)

            # Process actions (actions here are expected to be from the agent, typically scaled by hmax)
            actions = np.array(actions) # Ensure actions are numpy array

            # Convert actions to desired share changes (FinRL's default assumes actions are normalized [-1, 1])
            # If actions are already shares, remove scaling by hmax
            # The problem context did not specify action space scaling, so I will assume actions are normalized and need hmax scaling.
            scaled_actions = actions * self.hmax # Scale actions by hmax
            scaled_actions = scaled_actions.astype(int) # Convert to integer shares

            argsort_actions = np.argsort(scaled_actions)
            sell_index = argsort_actions[:np.where(scaled_actions < 0)[0].shape[0]]
            buy_index = argsort_actions[::-1][:np.where(scaled_actions > 0)[0].shape[0]]

            # Execute sell orders
            for index in sell_index:
                if self.num_stock_shares[index] > 0:
                    sell_num_shares = min(abs(scaled_actions[index]), self.num_stock_shares[index])
                    self.balance += (
                        self.data.close.values[index] * sell_num_shares * (1 - self.sell_cost_pct)
                    )
                    self.num_stock_shares[index] -= sell_num_shares
                    self.cost += (
                        self.data.close.values[index] * sell_num_shares * self.sell_cost_pct
                    )
                    self.transactions.append(
                        (self.day, self.data.tic.values[index], "sell", self.data.close.values[index], sell_num_shares)
                    )

            # Execute buy orders
            for index in buy_index:
                # Ensure enough balance before buying
                price_with_cost = self.data.close.values[index] * (1 + self.buy_cost_pct)
                buy_num_shares = min(abs(scaled_actions[index]), int(self.balance // price_with_cost))
                if buy_num_shares > 0:
                    self.balance -= price_with_cost * buy_num_shares
                    self.num_stock_shares[index] += buy_num_shares
                    self.cost += (
                        self.data.close.values[index] * buy_num_shares * self.buy_cost_pct
                    )
                    self.transactions.append(
                        (self.day, self.data.tic.values[index], "buy", self.data.close.values[index], buy_num_shares)
                    )

            # Increment day to move to the next trading day
            self.day += 1

            # After incrementing day, get data for the new current day to update asset values and state
            if self.day <= self.max_day:
                self.data = self._get_data_for_current_day() # This updates self.data for the new day
            else:
                # If we just hit terminal state, self.data should still be the last day's data for calculations
                pass

            end_total_asset = self.balance + sum(np.array(self.num_stock_shares) * np.array(self.data.close.values))

            self.reward = end_total_asset - begin_total_asset
            self.rewards.append(self.reward)

            self.asset_memory.append(end_total_asset)
            self.roi_memory.append(
                ((end_total_asset - self.asset_memory[0]) / self.asset_memory[0]) * 100
            )

            # Recalculate state for the new day (using _update_state which ensures self.data is correct)
            self.state = self._update_state()

            # Turbulence calculation for the NEW day
            if self.turbulence_threshold is not None:
                if self.day % self.turbulence_make_day == 0 and self.day <= self.max_day:
                    # self.data is already a DataFrame for the current (new) day
                    # self.data[self.risk_indicator_col] will be a Series, .values[0] will be a float
                    # The turbulence value needs to be a float, not a DataFrame or Series
                    turbulence_value = self.data[self.risk_indicator_col].values
                    if len(turbulence_value) > 0:
                        self.turbulence = float(turbulence_value[0])
                    else:
                        self.turbulence = 0.0 # Default if no turbulence data for current day
                else:
                    self.turbulence = 0.0 # If not a turbulence day or past max day
            else:
                self.turbulence = 0.0

            self.terminal = self.day >= self.max_day # Update terminal status again

            self.info = {
                "day": self.day,
                "asset_memory": self.asset_memory[-1],
                "cash": self.balance,
                "num_stock_shares": self.num_stock_shares,
                "turbulence": self.turbulence,
                "current_date": self._trading_days[self.day] if self.day <= self.max_day else "Terminal"
            }

            return self.state, self.reward, self.terminal, False, self.info

In [96]:
stock_dim = len(train_df.tic.unique())
state_space = 1 + 2 * stock_dim + len(TECH_INDICATORS) * stock_dim
action_space = stock_dim

env_train = CustomStockTradingEnv(
    df=train_df,
    stock_dim=stock_dim,
    hmax=100,
    initial_amount=1_000_000,
    num_stock_shares=[0] * stock_dim,
    buy_cost_pct=0.001,
    sell_cost_pct=0.001,
    reward_scaling=1e-4,
    state_space=state_space,
    action_space=action_space,
    tech_indicator_list=TECH_INDICATORS,
    day=0,
    risk_indicator_col='volatility' # Ensure risk_indicator_col is passed here as well
)

In [97]:
model = PPO(
    'MlpPolicy',
    env_train,
    learning_rate=3e-4,
    n_steps=2048,
    batch_size=64,
    gamma=0.99,
    verbose=1
)
print("PPO model re-initialized with CustomStockTradingEnv.")

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
PPO model re-initialized with CustomStockTradingEnv.


In [98]:
model.learn(total_timesteps=100_000)
print("PPO agent training complete.")

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 923      |
|    ep_rew_mean     | 1.39e+05 |
| time/              |          |
|    fps             | 173      |
|    iterations      | 1        |
|    time_elapsed    | 11       |
|    total_timesteps | 2048     |
---------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 923          |
|    ep_rew_mean          | 9.62e+04     |
| time/                   |              |
|    fps                  | 166          |
|    iterations           | 2            |
|    time_elapsed         | 24           |
|    total_timesteps      | 4096         |
| train/                  |              |
|    approx_kl            | 0.0042424607 |
|    clip_fraction        | 0.0271       |
|    clip_range           | 0.2          |
|    entropy_loss         | -4.25        |
|    explained_variance   | 0            |
|    learning_r

# Task
The PPO agent has been trained. I'll proceed to save the trained model.

## Save Trained Model

### Subtask:
Save the trained PPO model to the designated RESULTS_DIR for future use.


**Reasoning**:
The training of the PPO agent is complete, and now I need to save the trained model to the designated RESULTS_DIR using `model.save(MODEL_PATH)` as instructed. Afterward, I will print a confirmation message.



In [100]:
MODEL_PATH = os.path.join(RESULTS_DIR, 'ppo_multimodal_model')
model.save(MODEL_PATH)
print('Saved to', MODEL_PATH)

Saved to /content/drive/MyDrive/finrl-egx-multimodal/results/ppo_multimodal_model


## Final Task

### Subtask:
Confirm the successful training and saving of the PPO agent.


## Summary:

### Data Analysis Key Findings

*   The PPO model was successfully saved to the path `/content/drive/MyDrive/finrl-egx-multimodal/results/ppo_multimodal_model`.
*   A confirmation message indicated the successful completion of the save operation.

### Insights or Next Steps

*   The trained PPO agent is now persistently stored and ready for subsequent evaluation or deployment phases.
