In [None]:
import os
import math
import time
import random
from collections import deque

import pandas as pd
import numpy as np
import gymnasium as gym
from gymnasium import spaces
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [None]:
# Compute device shared by the model and all tensors built in this file:
# CUDA when available, otherwise CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

#########################################
# ------------- Environment -------------
#########################################
class EarningsEventEnv(gym.Env):
    '''Trading environment for a single earnings event (one episode per event).

    The episode walks the rows of ``event_data`` in ``timestamp`` order, one
    row per step, trading at each row's ``close``. The agent holds at most one
    share.

    Actions: 0 = HOLD, 1 = BUY (enter a 1-share long if flat and cash covers
    price plus cost), 2 = SELL (exit the long). Any other combination (BUY
    while long, SELL while flat, unaffordable BUY) is treated as HOLD.

    Reward is 0.0 on every step except the last. The last step's reward is
    ``(cash - initial_cash) / initial_cash``, computed after any open position
    has been auto-closed at the final close.

    NOTE: this env keeps the legacy 4-tuple API used by the training code in
    this file. ``reset()`` returns only the observation unless
    ``return_info=True``. ``step()`` returns ``(obs, reward, done, info)``,
    not gymnasium's 5-tuple.
    '''

    # Columns every event frame must provide.
    REQUIRED_COLUMNS = ('ticker_event', 'earnings_date', 'event_id', 'momentum',
                        'volatility', 'pre_close', 'close', 'timestamp')

    def __init__(self, event_data, transaction_cost=0.0005, initial_cash=10000):
        """
        Args:
            event_data: DataFrame of rows for ONE event; must contain REQUIRED_COLUMNS.
            transaction_cost: proportional cost charged on each buy and sell.
            initial_cash: starting cash; also the denominator of the PnL reward.

        Raises:
            ValueError: if a required column is missing or ``event_data`` is empty.
        """
        super().__init__()

        # Validate before touching any column. Sorting on a missing 'timestamp'
        # would otherwise surface as a bare KeyError instead of this ValueError.
        for c in self.REQUIRED_COLUMNS:
            if c not in event_data.columns:
                raise ValueError(f"Missing column in event_data: {c}")
        if len(event_data) == 0:
            raise ValueError("event_data is empty")

        self.event_data = event_data.sort_values('timestamp').reset_index(drop=True).copy()
        self.transaction_cost = transaction_cost
        self.initial_cash = initial_cash

        # Per-event metadata, taken from the earliest (timestamp-sorted) row so
        # it comes from the same frame whose 'close' column drives the episode.
        ed = self.event_data
        self.ticker = ed['ticker_event'].iloc[0]
        self.earnings_date = ed['earnings_date'].iloc[0]
        self.event_id = ed['event_id'].iloc[0]
        self.momentum = float(ed['momentum'].iloc[0])
        self.volatility = float(ed['volatility'].iloc[0])
        self.pre_close = float(ed['pre_close'].iloc[0])

        # Define spaces
        self.action_space = spaces.Discrete(3)  # 0=HOLD, 1=BUY (enter long), 2=SELL (exit long)
        # Observation: [momentum, volatility, pre_close, price, position, cash, window_pnl].
        # NOTE(review): price and cash are unscaled (cash starts at initial_cash),
        # which may saturate an unnormalised network; consider normalising upstream.
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(7,), dtype=np.float32)

        # Initialize state
        self.current_step = 0
        self.position = 0  # 0 or 1 (long)
        self.cash = initial_cash
        self.trades = []

    def reset(self, seed=None, return_info=False):
        """Start a new episode and return the initial observation.

        Returns ``(obs, {})`` when ``return_info`` is true, otherwise just ``obs``.
        """
        # Seed self.np_random per the gymnasium convention. The env's own
        # dynamics are deterministic, so this only affects callers using np_random.
        super().reset(seed=seed)
        self.current_step = 0
        self.position = 0
        self.cash = self.initial_cash
        self.trades = []
        obs = self._get_observation()
        if return_info:
            return obs, {}
        return obs

    def _current_price(self):
        """Close at the current step, clamped to the last row once the episode has ended."""
        idx = min(self.current_step, len(self.event_data) - 1)
        return float(self.event_data['close'].iloc[idx])

    def _get_observation(self):
        """Build the 7-dim float32 observation for the current step."""
        current_price = self._current_price()
        portfolio_value = self.cash + (self.position * current_price)
        window_pnl = (portfolio_value - self.initial_cash) / self.initial_cash

        return np.array([
            self.momentum,
            self.volatility,
            self.pre_close,
            current_price,
            float(self.position),
            float(self.cash),
            float(window_pnl),
        ], dtype=np.float32)

    def step(self, action):
        """Apply ``action`` at the current row's close and advance one row.

        Returns:
            (observation, reward, done, info). ``reward`` is non-zero only on the
            terminal step. Calling ``step`` after the episode ended returns the
            terminal observation, 0.0 and ``done=True``.
        """
        if self.current_step >= len(self.event_data):
            # Already done; return terminal observation
            return self._get_observation(), 0.0, True, {}

        current_price = self._current_price()
        reward = 0.0

        if action == 1 and self.position == 0:
            cost = current_price * (1.0 + self.transaction_cost)
            if self.cash >= cost:
                self.position = 1
                self.cash -= cost
                self.trades.append({'action': 'buy', 'price': current_price, 'step': self.current_step})
        elif action == 2 and self.position == 1:
            proceeds = current_price * (1.0 - self.transaction_cost)
            self.cash += proceeds
            self.position = 0
            self.trades.append({'action': 'sell', 'price': current_price, 'step': self.current_step})
        # else: hold or illegal action (ignored)

        self.current_step += 1
        done = self.current_step >= len(self.event_data)

        # If done and still long -> auto close at final price with costs
        if done and self.position == 1:
            final_price = float(self.event_data['close'].iloc[-1])
            self.cash += final_price * (1.0 - self.transaction_cost)
            self.position = 0
            self.trades.append({'action': 'auto_sell', 'price': final_price, 'step': self.current_step - 1})

        if done:
            reward = (self.cash - self.initial_cash) / self.initial_cash

        return self._get_observation(), float(reward), bool(done), {}

    def get_window_pnl(self):
        """Return (cash + position * final close - initial_cash) / initial_cash."""
        portfolio = self.cash + (self.position * float(self.event_data['close'].iloc[-1]))
        return (portfolio - self.initial_cash) / self.initial_cash


In [None]:
####################################################################
# -------------------- Policy (Actor-Critic) -----------------------
####################################################################
class ActorCriticNet(nn.Module):
    """Shared-trunk actor-critic MLP.

    A two-layer ReLU trunk feeds two linear heads. ``actor`` emits
    unnormalised action logits; ``critic`` emits a scalar state value.

    Args:
        obs_dim: length of the observation vector.
        n_actions: number of discrete actions.
        hidden: width of both hidden layers.
    """

    def __init__(self, obs_dim=7, n_actions=3, hidden=128):
        super().__init__()
        # Attribute names and the Sequential ordering define the state_dict keys
        # (net.0.*, net.2.*, actor.*, critic.*). Keep them stable so saved
        # checkpoints still load.
        trunk_layers = [
            nn.Linear(obs_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
        ]
        self.net = nn.Sequential(*trunk_layers)
        self.actor = nn.Linear(hidden, n_actions)
        self.critic = nn.Linear(hidden, 1)

    def forward(self, x):
        """Return ``(logits, value)``; the critic's trailing size-1 dim is squeezed away."""
        features = self.net(x)
        action_logits = self.actor(features)
        state_value = self.critic(features).squeeze(-1)
        return action_logits, state_value

####################################################################
# ---------------------- Agent utilities ---------------------------
####################################################################
def select_action(model, state, greedy=False):
    """Choose an action from ``model``'s policy for a single observation.

    Args:
        model: module whose forward maps a ``[1, obs_dim]`` batch to
            ``(logits [1, n_actions], value [1])``, e.g. ActorCriticNet.
        state: one observation (array-like of length obs_dim).
        greedy: if True, take the arg-max action instead of sampling.
            Defaults to False (stochastic, as before).

    Returns:
        (action, log_prob, value, entropy). ``action`` is a Python int; the
        other three are 0-d tensors that keep their autograd graph unless the
        call is made under ``torch.no_grad()``.
    """
    # Place the input on the model's own device. Fall back to the module-level
    # DEVICE only for a parameter-less model.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else DEVICE

    state_t = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)  # [1, obs]
    logits, value = model(state_t)  # logits: [1, n_actions], value: [1]

    # Build the distribution from logits (normalised with log-softmax) rather
    # than softmax probabilities. This is more numerically stable when logits
    # saturate, because torch would otherwise re-derive logits from clamped probs.
    dist = torch.distributions.Categorical(logits=logits)
    action = logits.argmax(dim=-1) if greedy else dist.sample()
    logp = dist.log_prob(action)
    entropy = dist.entropy()
    return int(action.item()), logp.squeeze(0), value.squeeze(0), entropy.squeeze(0)

In [None]:
####################################################################
# -------------------- Training Loop -------------------------------
####################################################################
def _discounted_returns(rewards, gamma):
    """Return ``[G_0, ..., G_{T-1}]`` with ``G_t = r_t + gamma * G_{t+1}`` and ``G_T = 0``."""
    returns = []
    running = 0.0
    for r in reversed(rewards):
        running = r + gamma * running
        returns.append(running)
    returns.reverse()  # built back-to-front; O(n) instead of repeated insert(0, ...)
    return returns


def _rollout(model, env):
    """Play one full episode in ``env`` with the stochastic policy.

    Returns stacked 1-D tensors ``(log_probs, values, entropies)`` (one entry
    per step, with gradients) and the plain list of per-step rewards.
    """
    state = env.reset()
    log_probs, values, entropies, rewards = [], [], [], []
    done = False
    while not done:
        action, logp, value, entropy = select_action(model, state)
        state, reward, done, _ = env.step(action)
        log_probs.append(logp)
        values.append(value)
        entropies.append(entropy)
        rewards.append(reward)  # mostly 0 until final step where reward=window pnl
    return torch.stack(log_probs), torch.stack(values), torch.stack(entropies), rewards


def _evaluate(model, test_envs, max_events=128):
    """Mean window PnL over up to ``max_events`` test events (sampled policy, no grad).

    A random subset is drawn when there are more than ``max_events`` events.
    Returns 0.0 when ``test_envs`` is empty. Leaves ``model`` in eval mode.
    """
    model.eval()
    test_ids = list(test_envs.keys())
    eval_sample = random.sample(test_ids, max_events) if len(test_ids) > max_events else test_ids
    pnls = []
    with torch.no_grad():
        for eid in eval_sample:
            env = test_envs[eid]
            s = env.reset()
            done = False
            while not done:
                # Sample (not greedy) to keep the policy's stochasticity.
                a, _, _, _ = select_action(model, s)
                s, _, done, _ = env.step(a)
            pnls.append(env.get_window_pnl())
    return float(np.mean(pnls)) if pnls else 0.0


def train(train_envs, test_envs=None, epochs=50, events_per_epoch=64, lr=3e-4, gamma=1.0,
          entropy_coef=0.01, value_coef=0.5, max_grad_norm=0.5, save_path='pg_agent.pth'):
    """Train an ActorCriticNet with advantage actor-critic (REINFORCE + learned baseline).

    Each epoch samples ``events_per_epoch`` training events with replacement,
    rolls out one full episode per event and takes one Adam step per episode.

    Args:
        train_envs: dict[event_id] -> EarningsEventEnv used for training.
        test_envs: optional dict of held-out envs. When non-empty, the model is
            evaluated every epoch and the best-eval weights are saved to ``save_path``.
        epochs, events_per_epoch, lr: optimisation schedule.
        gamma: discount factor for returns.
        entropy_coef, value_coef: weights of the entropy bonus and critic MSE.
        max_grad_norm: gradient-norm clipping threshold.
        save_path: checkpoint path.

    Returns:
        (model, history): the final-epoch model and a dict with per-epoch
        'train_return' and 'eval_return' lists.

    Raises:
        ValueError: if ``train_envs`` is empty.
    """
    model = ActorCriticNet(obs_dim=7, n_actions=3).to(DEVICE)
    opt = optim.Adam(model.parameters(), lr=lr)

    event_ids = list(train_envs.keys())
    if not event_ids:
        raise ValueError("No training events provided")

    best_eval = -float('inf')
    history = {'train_return': [], 'eval_return': []}

    for epoch in range(1, epochs + 1):
        model.train()
        epoch_losses = []
        epoch_returns = []
        start_time = time.time()

        # sample events for this epoch (with replacement)
        chosen = [random.choice(event_ids) for _ in range(events_per_epoch)]

        for eid in chosen:
            env = train_envs[eid]
            log_probs, values, entropies, rewards = _rollout(model, env)
            returns = torch.tensor(_discounted_returns(rewards, gamma), dtype=torch.float32, device=DEVICE)

            # The critic's estimate is the baseline; detach it so the actor loss
            # does not push gradients into the critic.
            advantages = returns - values.detach()
            actor_loss = -(log_probs * advantages).mean()
            critic_loss = F.mse_loss(values, returns)
            entropy_loss = -entropies.mean()  # minimising this maximises entropy
            loss = actor_loss + value_coef * critic_loss + entropy_coef * entropy_loss

            opt.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            opt.step()

            epoch_losses.append(loss.item())
            # Episode PnL, measured the same way as in evaluation.
            epoch_returns.append(env.get_window_pnl())

        avg_loss = np.mean(epoch_losses)
        avg_return = np.mean(epoch_returns)
        history['train_return'].append(avg_return)

        eval_return = None
        if test_envs:
            eval_return = _evaluate(model, test_envs)
            history['eval_return'].append(eval_return)
            if eval_return > best_eval:
                best_eval = eval_return
                torch.save(model.state_dict(), save_path)
        elif epoch % 10 == 0:
            # No eval set: save periodically
            torch.save(model.state_dict(), save_path)

        elapsed = time.time() - start_time
        print(f"Epoch {epoch:3d} | loss {avg_loss:.4f} | train_return {avg_return:.4f} | eval_return {eval_return if eval_return is not None else 'N/A'} | time {elapsed:.1f}s")

    # When a best-eval checkpoint was written, save_path already holds it.
    # Writing the final weights there would silently discard the best model.
    if best_eval == -float('inf'):
        torch.save(model.state_dict(), save_path)
        print("Training complete. Model saved to", save_path)
    else:
        print(f"Training complete. Best eval checkpoint (eval_return {best_eval:.6f}) saved to", save_path)
    return model, history


In [None]:
####################################################################
# ----------------------- Utility functions ------------------------
####################################################################
def build_envs_from_csv(csv_path, transaction_cost=0.0005, initial_cash=10000, split_col='split'):
    """Load an earnings-event CSV and build one EarningsEventEnv per event.

    The CSV must contain the columns EarningsEventEnv requires (per Appendix 1),
    including ``event_id``. ``timestamp`` and ``earnings_date`` are parsed as
    datetimes.

    Each event is assigned to train or test by the first non-null value of
    ``split_col`` among its rows ('train' or 'test'). Events with any other
    label are skipped. If ``split_col`` is absent, a reproducible ~80/20 random
    split (seed 0) is generated and stored in ``df[split_col]``.

    Returns:
        (df, train_envs, test_envs): the parsed DataFrame and two dicts mapping
        event_id -> EarningsEventEnv, ordered by event_id.

    Raises:
        ValueError: if the CSV has no ``event_id`` column.
    """
    df = pd.read_csv(csv_path)
    df['timestamp'] = pd.to_datetime(df['timestamp'])
    df['earnings_date'] = pd.to_datetime(df['earnings_date'])

    # must have event_id column
    if 'event_id' not in df.columns:
        raise ValueError("CSV must include 'event_id' column to group by events.")
    if split_col not in df.columns:
        # One uniform draw per event, in order of first appearance, so the
        # split is reproducible across runs.
        rng = np.random.RandomState(0)
        mask = {eid: ('train' if rng.rand() < 0.8 else 'test') for eid in df['event_id'].unique()}
        df[split_col] = df['event_id'].map(mask)

    # Honour split_col (previously 'split' was hard-coded here).
    event_split = df.groupby('event_id')[split_col].first()

    train_envs = {}
    test_envs = {}
    # A single groupby pass replaces one full-frame boolean filter per event.
    for event_id, event_data in df.groupby('event_id', sort=True):
        label = event_split.get(event_id)
        if label == 'train':
            target = train_envs
        elif label == 'test':
            target = test_envs
        else:
            continue
        target[event_id] = EarningsEventEnv(
            event_data.copy().reset_index(drop=True),
            transaction_cost=transaction_cost,
            initial_cash=initial_cash,
        )

    return df, train_envs, test_envs

In [None]:
####################################################################
# --------------------------- main() -------------------------------
####################################################################
def main():
    """Run the pipeline: load events, train the agent, then show a short test-set demo.

    The CSV location comes from the ``CSV_PATH`` environment variable (default
    ``earnings_events_data.csv``); all other settings are the constants below.
    """
    # ---- configuration ----
    csv_path = os.environ.get('CSV_PATH', 'earnings_events_data.csv')
    transaction_cost = 0.0005
    initial_cash = 10000
    epochs = 80
    events_per_epoch = 128
    learning_rate = 3e-4
    checkpoint_path = 'earnings_pg_agent.pth'

    print("Loading CSV:", csv_path)
    df, train_envs, test_envs = build_envs_from_csv(
        csv_path, transaction_cost=transaction_cost, initial_cash=initial_cash)
    print(f"Loaded {len(df):,} rows. Train events: {len(train_envs)}, Test events: {len(test_envs)}")

    model, _history = train(
        train_envs,
        test_envs=test_envs,
        epochs=epochs,
        events_per_epoch=events_per_epoch,
        lr=learning_rate,
        save_path=checkpoint_path,
    )

    def run_episode(env):
        """Play one episode in ``env`` with the sampled (stochastic) policy."""
        obs = env.reset()
        finished = False
        while not finished:
            chosen_action = select_action(model, obs)[0]
            obs, _reward, finished, _info = env.step(chosen_action)

    # Short qualitative demo on the first 10 test events.
    if test_envs:
        print("\nSample evaluation on 10 test events (stochastic policy):")
        model.eval()
        with torch.no_grad():
            for event_id in list(test_envs)[:10]:
                env = test_envs[event_id]
                run_episode(env)
                print(f"Event {event_id} | PnL {env.get_window_pnl():.4f} | trades {env.trades}")
    print("Done.")

# Script entry point: run the full load/train/demo pipeline only when this
# module is executed directly, not when it is imported.
if __name__ == "__main__":
    main()

Loading CSV: earnings_events_data.csv
Loaded 14,815 rows. Train events: 110, Test events: 28
Epoch   1 | loss 2549.6211 | train_return 0.0005 | eval_return -1.9240948214311335e-06 | time 3.1s
Epoch   2 | loss 445.5890 | train_return 0.0006 | eval_return -1.9240948214311335e-06 | time 3.1s
Epoch   3 | loss 224.0462 | train_return 0.0005 | eval_return -1.9240948214311335e-06 | time 3.0s
Epoch   4 | loss 449.8268 | train_return -0.0009 | eval_return -1.9240948214311335e-06 | time 3.0s
Epoch   5 | loss 200.6727 | train_return 0.0015 | eval_return -1.9240948214311335e-06 | time 3.0s
Epoch   6 | loss 170.8739 | train_return 0.0005 | eval_return -1.9240948214311335e-06 | time 3.0s
Epoch   7 | loss 140.4495 | train_return 0.0008 | eval_return -1.9240948214311335e-06 | time 3.1s
Epoch   8 | loss 328.2689 | train_return 0.0014 | eval_return -1.9240948214311335e-06 | time 3.0s
Epoch   9 | loss 278.8358 | train_return 0.0007 | eval_return -1.9240948214311335e-06 | time 3.1s
Epoch  10 | loss 179.97