Based on paper: https://d-nb.info/1248317343/34

## Data Scrapping

### 1. Output Gap Data extraction

First I got the quarterly GDP for the period [Office for National Statistics]:
https://www.ons.gov.uk/economy/grossdomesticproductgdp/timeseries/ybha/qna

I got the yearly output gap [Office for Budget Responsibility (OBR)]
https://obr.uk/public-finances-databank-2024-25/

Using the quarterly estimates developed [OBR: Output gap measurement: judgement and uncertainty] I replicated the shape of the quarterly output gaps in %.
https://obr.uk/docs/dlm_uploads/WorkingPaperNo5.pdf

In [8]:
import pandas as pd
import datetime as dt

# Use the raw URL from the GitHub repository
xlsx_url = "https://raw.githubusercontent.com/guri99uy/ST449_Project/52611de9d475e711c4c917c4d5ca137427404612/outputgap.xlsx"


# Load the Excel file
df_outputgap = pd.read_excel(xlsx_url, engine='openpyxl')  # Ensure you specify the 'openpyxl' engine for .xlsx files

# Define a function to parse QQYYYY
def parse_qqyyyy(qqyyyy):
    # Extract the quarter and year
    quarter = int(qqyyyy[1])
    year = int(qqyyyy[2:])
    
    # Map the quarter to the first month of that quarter
    quarter_start_month = {1: 1, 2: 4, 3: 7, 4: 10}
    month = quarter_start_month[quarter]
    
    # Create a datetime object for the first day of the quarter
    return dt.datetime(year, month, 1)

# Apply the function to the first column 'QQYYYY' to convert it to datetime
df_outputgap['QQYYYY'] = df_outputgap['QQYYYY'].apply(parse_qqyyyy)
# Rename a single column, e.g., 'OldName' to 'NewName'
df_outputgap.rename(columns={'QQYYYY': 'Date'}, inplace=True)

#Get Date in Quarters
df_outputgap['Date'] = pd.to_datetime(df_outputgap['Date'])
df_outputgap['Quarter'] = df_outputgap['Date'].dt.to_period('Q')
df_outputgap = df_outputgap.drop(columns=['Date'])

df_outputgap['GDP_Pot (m£)'] = df_outputgap['GDP_Pot (m£)'].round(0).astype(int)
df_outputgap['Output_gap (%)'] = df_outputgap['Output_gap (%)'].round(2)

# Display the first few rows of the transformed DataFrame
print(df_outputgap.head())
print(df_outputgap.tail())

   GDP_Real (m£)  GDP_Pot (m£)  Output_gap (%) Quarter
0         127119        130233            2.45  1987Q3
1         129815        133288            2.68  1987Q4
2         133283        137215            2.95  1988Q1
3         136630        141576            3.62  1988Q2
4         140801        145602            3.41  1988Q3
    GDP_Real (m£)  GDP_Pot (m£)  Output_gap (%) Quarter
77         372900        372629           -0.07  2006Q4
78         376958        378202            0.33  2007Q1
79         386144        387920            0.46  2007Q2
80         389291        392366            0.79  2007Q3
81         392244        396777            1.16  2007Q4


### 2. Interest Rate
Got .xlsx file from [Bank of Engalnd]
https://www.bankofengland.co.uk/boeapps/database/Bank-Rate.asp


In [9]:
import pandas as pd
import datetime as dt

# Raw URL of the Excel file
url = "https://raw.githubusercontent.com/guri99uy/ST449_Project/7715079b32be2ea0b9e2e77a3f7b81244f85720f/Bank_Rate.xlsx"
df_interest_rate = pd.read_excel(url, engine='openpyxl')


# Rename columns for easier access (optional)
df_interest_rate.columns = ['Date', 'Interest_rate']

# Convert the 'Date_Changed' column to datetime format
def parse_date(date_str):
    # Handle the format '07 Nov 24' as 'DD MMM YY'
    return dt.datetime.strptime(date_str, '%d %b %y')

df_interest_rate['Date'] = df_interest_rate['Date'].apply(parse_date)

# Check if 'Rate' column is string type, and process accordingly
if df_interest_rate['Interest_rate'].dtype == 'object':
    # Clean the 'Rate' column (replace commas with dots and convert to float)
    df_interest_rate['Interest_rate'] = df_interest_rate['Rate'].str.replace(',', '.').astype(float)
else:
    # Ensure the 'Rate' column is numeric
    df_interest_rate['Interest_rate'] = pd.to_numeric(df_interest_rate['Interest_rate'], errors='coerce')

# Display the processed DataFrame
print("\nEvery Interest rate by Bank of England:")
print(df_interest_rate.head())




Every Interest rate by Bank of England:
        Date  Interest_rate
0 2024-11-07           4.75
1 2024-08-01           5.00
2 2023-08-03           5.25
3 2023-06-22           5.00
4 2023-05-11           4.50


Lets process the data to: 
1. Get the quarter average
2. Assign missing quarters with the last value

In [10]:
import pandas as pd

# Assuming df_interest_rate is the DataFrame with 'Date_Changed' and 'Rate'
# Ensure 'Date_Changed' is a datetime column
df_interest_rate['Date'] = pd.to_datetime(df_interest_rate['Date'])

# Create a column for the quarter and year as strings for grouping
df_interest_rate['Quarter'] = df_interest_rate['Date'].dt.to_period('Q')

# Group by the 'Quarter' column and calculate the average interest rate
quarterly_avg_rate = (
    df_interest_rate.groupby('Quarter', as_index=False)['Interest_rate']
    .mean()
    .rename(columns={'Interest_rate': 'Avg_Interest_Rate'})
)

full_quarters = pd.period_range('1975Q1', '2007Q4', freq='Q')
quarterly_avg_rate['Quarter'] = pd.PeriodIndex(quarterly_avg_rate['Quarter'], freq='Q')
quarterly_avg_rate = quarterly_avg_rate.set_index('Quarter').reindex(full_quarters)

# Fill missing values with the value from the previous quarter
quarterly_avg_rate['Avg_Interest_Rate'] = quarterly_avg_rate['Avg_Interest_Rate'].ffill()
quarterly_avg_rate.reset_index(inplace=True)
quarterly_avg_rate.rename(columns={'index': 'Quarter'}, inplace=True)

# Filter 1997 - 2007
Quarterly_interest_rates = quarterly_avg_rate[
    (quarterly_avg_rate['Quarter'] >= '1987Q3') & (quarterly_avg_rate['Quarter'] <= '2007Q4')
]
Quarterly_interest_rates.reset_index(inplace=True)
Quarterly_interest_rates = Quarterly_interest_rates.drop(columns=['index'])

# Display
print(Quarterly_interest_rates.head())



  Quarter  Avg_Interest_Rate
0  1987Q3              9.880
1  1987Q4              8.880
2  1988Q1              8.630
3  1988Q2              8.080
4  1988Q3             10.755


### 3. Inflation
Source?
Relevant comments:


In [11]:
import pandas as pd

# GitHub raw URL for inflation
url = "https://raw.githubusercontent.com/guri99uy/ST449_Project/c87d1b581f0af98f2a813a9c6134160303e74883/inf_Data.csv"
inflation = pd.read_csv(url)

# Rename columns
inf_data = inflation.rename(columns={"Implied GDP deflator at market prices: SA Index": "GDP Deflator"})
inf_data.rename(columns={"Title": "Quarter"}, inplace=True)

# Change Quarter
inf_data["Quarter"] = inf_data["Quarter"].str.replace(r"(\d{4})\sQ(\d)", r"\1Q\2", regex=True)

print(inf_data.head())


  Quarter  GDP Deflator
0  1987Q3       35.8724
1  1987Q4       36.2206
2  1988Q1       36.5950
3  1988Q2       37.3205
4  1988Q3       37.9849


### 4. Merge relevant data
1. Output Gap
2. Interest rate
3. Inflation
   

In [12]:
# Convert 'Quarter' column in all datasets to period type
Quarterly_interest_rates['Quarter'] = pd.PeriodIndex(Quarterly_interest_rates['Quarter'], freq='Q')
df_outputgap['Quarter'] = pd.PeriodIndex(df_outputgap['Quarter'], freq='Q')
inf_data['Quarter'] = pd.PeriodIndex(inf_data['Quarter'], freq='Q')

# Merge the datasets
merged_df = pd.merge(Quarterly_interest_rates, df_outputgap, on='Quarter', how='inner')  # Inner join
merged_df = pd.merge(merged_df, inf_data, on='Quarter', how='inner')  # Inner join

# Display the merged DataFrame
print(merged_df.head())


print(merged_df.tail())

  Quarter  Avg_Interest_Rate  GDP_Real (m£)  GDP_Pot (m£)  Output_gap (%)  \
0  1987Q3              9.880         127119        130233            2.45   
1  1987Q4              8.880         129815        133288            2.68   
2  1988Q1              8.630         133283        137215            2.95   
3  1988Q2              8.080         136630        141576            3.62   
4  1988Q3             10.755         140801        145602            3.41   

   GDP Deflator  
0       35.8724  
1       36.2206  
2       36.5950  
3       37.3205  
4       37.9849  
   Quarter  Avg_Interest_Rate  GDP_Real (m£)  GDP_Pot (m£)  Output_gap (%)  \
75  2006Q2               4.50         367042        366712           -0.09   
76  2006Q3               4.75         370883        370824           -0.02   
77  2006Q4               5.00         372900        372629           -0.07   
78  2007Q1               5.25         376958        378202            0.33   
79  2007Q2               5.50         3

## Model

In [31]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque, namedtuple
import random
import matplotlib.pyplot as plt
from typing import Tuple, Dict, Any, List

### 1. Environment

In [28]:
class DataBasedEconomyEnv:
    def __init__(
        self, 
        df: pd.DataFrame,
        date_col: str = 'Quarter',
        interest_col: str = 'Avg_Interest_Rate',
        output_gap_col: str = 'Output gap (%)',
        inflation_col: str = 'GDP Deflator',
        lookback_periods: int = 2, 
        validation_split: float = 0.15
    ):
        """Initialize environment with economic dataframe."""
        # Store column names
        self.cols = {
            'date': date_col,
            'interest_rate': interest_col,
            'output_gap': output_gap_col,
            'inflation': inflation_col
        }
        
        # Validate dataframe
        required_cols = list(self.cols.values())
        missing_cols = [col for col in required_cols if col not in df.columns]
        if missing_cols:
            raise ValueError(f"Missing required columns: {missing_cols}")
        
        # Ensure date is datetime and sort
        self.data = df.copy()
        
        # Convert Period to timestamp if necessary
        if isinstance(self.data[self.cols['date']].dtype, pd.PeriodDtype):
            self.data[self.cols['date']] = self.data[self.cols['date']].dt.to_timestamp()
        else:
            self.data[self.cols['date']] = pd.to_datetime(self.data[self.cols['date']])

        self.data = self.data.sort_values(self.cols['date']).reset_index(drop=True)
        
        # Split data
        split_idx = int(len(self.data) * (1 - validation_split))
        self.train_data = self.data.iloc[:split_idx].reset_index(drop=True)
        self.val_data = self.data.iloc[split_idx:].reset_index(drop=True)
        
        self.lookback_periods = lookback_periods
        self.is_validation = False
        self.active_data = self.train_data
        
        # Initialize episode state
        self.current_idx = lookback_periods
        self.max_idx = len(self.active_data) - 1
        
        # Policy targets from paper
        self.inflation_target = 2.0
        self.output_gap_target = 0.0
        
        # Calculate normalization statistics from training data only
        self.compute_normalization_stats()

    def compute_normalization_stats(self) -> None:
        """Compute normalization statistics from training data."""
        self.data_stats = {
            'inflation_mean': self.train_data[self.cols['inflation']].mean(),
            'inflation_std': self.train_data[self.cols['inflation']].std(),
            'output_gap_mean': self.train_data[self.cols['output_gap']].mean(),
            'output_gap_std': self.train_data[self.cols['output_gap']].std(),
            'interest_rate_mean': self.train_data[self.cols['interest_rate']].mean(),
            'interest_rate_std': self.train_data[self.cols['interest_rate']].std()
        }
    
    def normalize(self, value: float, variable: str) -> float:
        """Normalize a value using stored statistics."""
        return (value - self.data_stats[f'{variable}_mean']) / self.data_stats[f'{variable}_std']
    
    def denormalize(self, value: float, variable: str) -> float:
        """Denormalize a value using stored statistics."""
        return value * self.data_stats[f'{variable}_std'] + self.data_stats[f'{variable}_mean']
    
    def switch_to_validation(self) -> None:
        """Switch to validation dataset."""
        self.is_validation = True
        self.active_data = self.val_data
        self.current_idx = self.lookback_periods
        self.max_idx = len(self.active_data) - 1
    
    def switch_to_training(self) -> None:
        """Switch to training dataset."""
        self.is_validation = False
        self.active_data = self.train_data
        self.current_idx = self.lookback_periods
        self.max_idx = len(self.active_data) - 1
    
    def get_state(self) -> np.ndarray:
        """Get current state including lookback periods."""
        start_idx = self.current_idx - self.lookback_periods
        end_idx = self.current_idx + 1
        
        state_data = {
            'inflation': self.active_data[self.cols['inflation']].iloc[start_idx:end_idx].values,
            'output_gap': self.active_data[self.cols['output_gap']].iloc[start_idx:end_idx].values,
            'interest_rate': self.active_data[self.cols['interest_rate']].iloc[start_idx:end_idx-1].values
        }
        
        # Create normalized state vector
        normalized_state = []
        
        # Add current and lagged inflation and output gap
        for var in ['inflation', 'output_gap']:
            normalized_state.extend([self.normalize(x, var) for x in state_data[var]])
            
        # Add lagged interest rates
        normalized_state.extend([self.normalize(x, 'interest_rate') for x in state_data['interest_rate']])
        
        return np.array(normalized_state)
    
    def compute_reward(self, inflation: float, output_gap: float) -> float:
        """
        Compute reward based on paper's specification:
        rt = -ωπ(πt+1 - π*)² - ωy(yt+1)²
        """
        omega_pi = omega_y = 0.5
        
        inflation_loss = -omega_pi * (inflation - self.inflation_target)**2
        output_gap_loss = -omega_y * output_gap**2
        reward = inflation_loss + output_gap_loss
        
        # Additional penalty for large deviations
        if abs(inflation - self.inflation_target) > 2:
            reward *= 0.1
        if abs(output_gap) > 2:
            reward *= 0.1
            
        return reward
    
    def reset(self) -> np.ndarray:
        """Reset environment to start of current dataset."""
        self.current_idx = self.lookback_periods
        return self.get_state()
    
    def step(self, action: float) -> Tuple[np.ndarray, float, bool, Dict[str, Any]]:
        """Execute one step in environment."""
        if self.current_idx >= self.max_idx:
            return self.get_state(), 0, True, {}
        
        actual_action = self.denormalize(action, 'interest_rate')
        self.current_idx += 1
        next_state = self.get_state()
        
        current_inflation = self.active_data[self.cols['inflation']].iloc[self.current_idx]
        current_output_gap = self.active_data[self.cols['output_gap']].iloc[self.current_idx]
        
        reward = self.compute_reward(current_inflation, current_output_gap)
        done = self.current_idx >= self.max_idx
        
        info = {
            'date': self.active_data[self.cols['date']].iloc[self.current_idx],
            'actual_inflation': current_inflation,
            'actual_output_gap': current_output_gap,
            'actual_interest_rate': actual_action,
            'inflation_target': self.inflation_target,
            'output_gap_target': self.output_gap_target
        }
        
        return next_state, reward, done, info


# Create environment
env = DataBasedEconomyEnv(
    df=merged_df,
    date_col='Quarter',
    interest_col='Avg_Interest_Rate',
    output_gap_col='Output_gap (%)',
    inflation_col='GDP Deflator'
)

In [29]:
env.get_state()

array([-2.24641309, -2.19601752, -2.14182999,  1.49115918,  1.64195272,
        1.81897123,  0.71984371,  0.41531225])

In [30]:
env.step(0.5)

(array([-2.19601752, -2.14182999, -2.03682717,  1.64195272,  1.81897123,
         2.25823938,  0.41531225,  0.33917938]),
 -6.303210601250001,
 False,
 {'date': Timestamp('1988-04-01 00:00:00'),
  'actual_inflation': 37.3205,
  'actual_output_gap': 3.62,
  'actual_interest_rate': 9.158091968033347,
  'inflation_target': 2.0,
  'output_gap_target': 0.0})

### 2. Agents

In [26]:
# Experience tuple structure
Experience = namedtuple('Experience', ['state', 'action', 'reward', 'next_state'])

class Actor(nn.Module):
    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 64):
        super(Actor, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.ReLU()  # ReLU for ZLB constraint (i >= 0)
        )
        
        # Initialize weights using paper's approach
        self.apply(self._init_weights)
        
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                module.bias.data.zero_()
                
    def forward(self, state: torch.Tensor) -> torch.Tensor:
        return self.network(state)

class Critic(nn.Module):
    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 64):
        super(Critic, self).__init__()
        
        # Observation path
        self.obs_path = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # Action path
        self.action_path = nn.Sequential(
            nn.Linear(action_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # Common path
        self.common_path = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1)
        )
        
        # Initialize weights
        self.apply(self._init_weights)
        
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                module.bias.data.zero_()
                
    def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        obs_features = self.obs_path(state)
        action_features = self.action_path(action)
        combined = torch.cat([obs_features, action_features], dim=1)
        return self.common_path(combined)

class OUNoise:
    """Ornstein-Uhlenbeck process noise generator"""
    def __init__(self, size: int, mu: float = 0., theta: float = 0.15, sigma: float = 1.):
        self.mu = mu * np.ones(size)
        self.theta = theta
        self.sigma = sigma
        self.state = None
        self.reset()
        
    def reset(self):
        self.state = np.copy(self.mu)
        
    def sample(self) -> np.ndarray:
        x = self.state
        dx = self.theta * (self.mu - x) + self.sigma * np.random.randn(len(x))
        self.state = x + dx
        return self.state

class DDPGAgent:
    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        hidden_dim: int = 64,
        buffer_size: int = 10000,
        batch_size: int = 64,
        gamma: float = 0.99,
        tau: float = 0.001,
        actor_lr: float = 0.0001,
        critic_lr: float = 0.0001
    ):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        # Networks
        self.actor = Actor(state_dim, action_dim, hidden_dim).to(self.device)
        self.actor_target = Actor(state_dim, action_dim, hidden_dim).to(self.device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        
        self.critic = Critic(state_dim, action_dim, hidden_dim).to(self.device)
        self.critic_target = Critic(state_dim, action_dim, hidden_dim).to(self.device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        
        # Optimizers
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=critic_lr)
        
        # Experience replay
        self.buffer = deque(maxlen=buffer_size)
        self.batch_size = batch_size
        
        # Parameters
        self.gamma = gamma
        self.tau = tau
        
        # Training metrics
        self.critic_losses = []
        self.actor_losses = []
        
    def select_action(self, state: np.ndarray, noise: np.ndarray = None) -> np.ndarray:
        """Select action with optional exploration noise"""
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        with torch.no_grad():
            action = self.actor(state).cpu().numpy()[0]
            
        if noise is not None:
            action += noise
            
        return np.clip(action, 0, None)  # Apply ZLB constraint
    
    def store_experience(self, state: np.ndarray, action: np.ndarray, 
                        reward: float, next_state: np.ndarray) -> None:
        """Store experience in replay buffer"""
        self.buffer.append(Experience(state, action, reward, next_state))
    
    def train(self) -> Tuple[float, float]:
        """Train the agent using a minibatch from replay buffer"""
        if len(self.buffer) < self.batch_size:
            return 0.0, 0.0
        
        # Sample minibatch
        batch = random.sample(self.buffer, self.batch_size)
        state_batch = torch.FloatTensor([exp.state for exp in batch]).to(self.device)
        action_batch = torch.FloatTensor([exp.action for exp in batch]).to(self.device)
        reward_batch = torch.FloatTensor([exp.reward for exp in batch]).to(self.device)
        next_state_batch = torch.FloatTensor([exp.next_state for exp in batch]).to(self.device)
        
        # Update critic (Bellman equation)
        next_actions = self.actor_target(next_state_batch)
        target_q = reward_batch + self.gamma * self.critic_target(next_state_batch, next_actions)
        current_q = self.critic(state_batch, action_batch)
        critic_loss = nn.MSELoss()(current_q, target_q)
        
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
        
        # Update actor using policy gradient
        actor_loss = -self.critic(state_batch, self.actor(state_batch)).mean()
        
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        
        # Update target networks
        self._update_target_network(self.actor_target, self.actor)
        self._update_target_network(self.critic_target, self.critic)
        
        # Store losses
        self.critic_losses.append(critic_loss.item())
        self.actor_losses.append(actor_loss.item())
        
        return critic_loss.item(), actor_loss.item()
    
    def _update_target_network(self, target: nn.Module, source: nn.Module) -> None:
        """Soft update target network parameters"""
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(
                target_param.data * (1.0 - self.tau) + param.data * self.tau
            )
    
    def get_training_metrics(self) -> Tuple[List[float], List[float]]:
        """Return training metrics"""
        return self.critic_losses, self.actor_losses

### 3. Training

In [27]:
import numpy as np
from collections import deque, namedtuple
import random
import torch
from typing import List, Tuple, Dict

# Experience tuple structure
Experience = namedtuple('Experience', ['state', 'action', 'reward', 'next_state'])

class ReplayBuffer:
    def __init__(self, capacity: int = 10000):
        self.buffer = deque(maxlen=capacity)
        
    def push(self, state: np.ndarray, action: np.ndarray, 
             reward: float, next_state: np.ndarray) -> None:
        """Add experience to buffer"""
        self.buffer.append(Experience(state, action, reward, next_state))
        
    def sample(self, batch_size: int) -> Tuple:
        """Sample a batch of experiences"""
        batch = random.sample(self.buffer, batch_size)
        state = torch.FloatTensor([exp.state for exp in batch])
        action = torch.FloatTensor([exp.action for exp in batch])
        reward = torch.FloatTensor([exp.reward for exp in batch])
        next_state = torch.FloatTensor([exp.next_state for exp in batch])
        return state, action, reward, next_state
    
    def __len__(self) -> int:
        return len(self.buffer)

class OUNoise:
    """Ornstein-Uhlenbeck process for exploration"""
    def __init__(self, size: int, mu: float = 0., theta: float = 0.15, sigma: float = 1.):
        self.mu = mu * np.ones(size)
        self.theta = theta
        self.sigma = sigma
        self.state = None
        self.reset()
        
    def reset(self):
        self.state = np.copy(self.mu)
        
    def sample(self) -> np.ndarray:
        x = self.state
        dx = self.theta * (self.mu - x) + self.sigma * np.random.randn(len(x))
        self.state = x + dx
        return self.state

class EpisodeManager:
    """Manages episode execution and data collection"""
    def __init__(self, env, agent):
        self.env = env
        self.agent = agent
        self.episode_rewards = []
        self.policy_states = {}
        
    def run_episode(self, training: bool = True, noise_sigma: float = 1.0) -> Dict:
        state = self.env.reset()
        episode_reward = 0
        transitions = []
        noise = OUNoise(1, sigma=noise_sigma) if training else None
        
        while True:
            action = self.agent.select_action(state)
            if noise:
                action += noise.sample()
                
            next_state, reward, done, info = self.env.step(action)
            
            if training:
                self.agent.buffer.push(state, action, reward, next_state)
                self.agent.train()
                
            transitions.append({
                'state': state,
                'action': action,
                'reward': reward,
                'next_state': next_state,
                'info': info
            })
            
            episode_reward += reward
            state = next_state
            
            if done:
                break
                
        self.episode_rewards.append(episode_reward)
        
        return {
            'episode_reward': episode_reward,
            'transitions': transitions,
            'final_info': info
        }
    
    def get_metrics(self) -> Dict:
        """Return training metrics"""
        return {
            'episode_rewards': self.episode_rewards,
            'avg_reward': np.mean(self.episode_rewards),
            'std_reward': np.std(self.episode_rewards)
        }

class ValidationManager:
    """Manages validation process"""
    def __init__(self, env, agent):
        self.env = env
        self.agent = agent
        self.best_reward = float('-inf')
        self.best_policy_state = None
        
    def validate(self, num_episodes: int = 5) -> Dict:
        self.env.switch_to_validation()
        episode_manager = EpisodeManager(self.env, self.agent)
        
        validation_rewards = []
        for _ in range(num_episodes):
            episode_info = episode_manager.run_episode(training=False)
            validation_rewards.append(episode_info['episode_reward'])
            
        avg_reward = np.mean(validation_rewards)
        
        # Save best policy
        if avg_reward > self.best_reward:
            self.best_reward = avg_reward
            self.best_policy_state = {
                'actor': self.agent.actor.state_dict(),
                'critic': self.agent.critic.state_dict(),
                'reward': avg_reward
            }
            
        self.env.switch_to_training()
        
        return {
            'avg_reward': avg_reward,
            'std_reward': np.std(validation_rewards),
            'best_reward': self.best_reward
        }
    
    def restore_best_policy(self) -> None:
        """Restore best performing policy"""
        if self.best_policy_state is not None:
            self.agent.actor.load_state_dict(self.best_policy_state['actor'])
            self.agent.critic.load_state_dict(self.best_policy_state['critic'])
            print(f"Restored best policy with validation reward: {self.best_policy_state['reward']:.2f}")