# RCDCC Research Notebook

## Risk-Conditioned Dirichlet Concentration Control for PPO Portfolio RL

This notebook documents the Risk-Conditioned Dirichlet Concentration Control (RCDCC) framework, a method for stabilizing PPO-based portfolio optimization agents.

## 1) Motivation and Problem

### 1.1 The PPO-Dirichlet Challenge
In current PPO-Dirichlet setups, we frequently observe periodic **KL spikes** and early-stop events. These indicate that the policy is over-updating: abrupt shifts in allocation push the approximate KL divergence past the trust-region target (`target_kl`).

### 1.2 Coupling of Preference and Confidence
The fundamental issue is the coupling of:
1. **Allocation preference**: Which assets to favor (the direction of the vector).
2. **Allocation confidence**: How concentrated or uncertain to be (the intensity of the vector).

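In a standard parameterization, each $\alpha_i$ is produced directly from its own logit through a positivity-enforcing activation, so the total concentration $\sum_i \alpha_i$ is only an implicit by-product of the per-asset outputs. Unless the network simultaneously lowers the other outputs, raising one asset's $\alpha_i$ to express a stronger preference also raises $\sum_i \alpha_i$, which makes the whole distribution more concentrated (more confident) at the same time. This coupling makes it hard to manage update stability, turnover, and risk simultaneously in non-stationary regimes.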
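The coupling can be seen numerically. The sketch below assumes a hypothetical "standard" head where $\alpha_i = \text{softplus}(z_i) + 10^{-3}$ (softplus is a stand-in; any monotone positive activation behaves the same way) and shifts preference toward asset 0 by raising only its logit.

```python
import numpy as np


def softplus(x):
    # Numerically stable softplus: log(1 + exp(x)).
    return np.logaddexp(0.0, x)


# Hypothetical standard head: each alpha_i comes straight from its own logit.
z = np.array([0.5, 0.5, 0.5, 0.5])
alpha_before = softplus(z) + 1e-3

# Express a stronger preference for asset 0 by raising only its logit.
z_shifted = z.copy()
z_shifted[0] += 2.0
alpha_after = softplus(z_shifted) + 1e-3

# Preference moved toward asset 0 ...
print("mean weight of asset 0:", alpha_before[0] / alpha_before.sum(), "->", alpha_after[0] / alpha_after.sum())
# ... but the total concentration (confidence) rose as a side effect.
print("total concentration:  ", alpha_before.sum(), "->", alpha_after.sum())
```

In the factorized form introduced next, the total $\sum_i \alpha_i$ is pinned to $K\alpha_{\text{floor}} + C_t$, so moving the preference vector cannot change it.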

### 1.3 Static vs. Dynamic Control
While static controls (like `alpha_activation`, `alpha_cap`, and `epsilon`) provide a baseline level of stability, they do not adapt to market context. RCDCC addresses this by introducing closed-loop, adaptive concentration governance.

## 2) Core Idea & Factorization

RCDCC explicitly decouples **Direction** from **Confidence** using a two-part Dirichlet parameterization.

### 2.1 The Parameterization
Introduce two components:
- A **simplex mean vector** $p_t$, representing the portfolio preference.
- A **scalar concentration budget** $C_t$, representing the confidence level.

The final Dirichlet parameters $\alpha_{t,i}$ are constructed as:

$$\alpha_{t,i} = \alpha_{\text{floor}} + C_t \cdot p_{t,i}, \quad \sum_i p_{t,i} = 1, \quad p_{t,i} \ge 0$$
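A minimal NumPy sketch of this construction for a single time step. The function name `compute_alpha` and its input checks are illustrative assumptions, not the codebase's `_compute_alpha`.

```python
import numpy as np


def compute_alpha(p, concentration, alpha_floor=1e-3):
    """alpha_i = alpha_floor + C * p_i for one simplex vector p (hypothetical helper)."""
    p = np.asarray(p, dtype=float)
    if np.any(p < 0) or not np.isclose(p.sum(), 1.0):
        raise ValueError("p must lie on the probability simplex")
    if concentration <= 0:
        raise ValueError("concentration budget C must be positive")
    return alpha_floor + concentration * p


p = np.array([0.5, 0.3, 0.2])
alpha = compute_alpha(p, concentration=20.0)
print(alpha, alpha.sum())  # sum(alpha) = K * alpha_floor + C
```

Because $\sum_i p_{t,i} = 1$, the total concentration is always $K\alpha_{\text{floor}} + C_t$, independent of which assets the preference head favors.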

### 2.2 Mechanism of Action
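By modulating $C_t$, we scale the spread of the sampled portfolio weights $w_t \sim \text{Dir}(\alpha_t)$ while leaving their expected value approximately unchanged. With $K$ assets and $\alpha_0 = K\alpha_{\text{floor}} + C_t$, the Dirichlet mean is $\mathbb{E}[w_{t,i}] = (\alpha_{\text{floor}} + C_t p_{t,i}) / \alpha_0$, which equals $p_{t,i}$ exactly when $\alpha_{\text{floor}} = 0$ and stays close to it whenever $K\alpha_{\text{floor}} \ll C_t$. The variance $\mathrm{Var}[w_{t,i}] = \mathbb{E}[w_{t,i}]\,(1 - \mathbb{E}[w_{t,i}]) / (\alpha_0 + 1)$ shrinks as $C_t$ grows. Under high market uncertainty or training instability, shrinking $C_t$ therefore broadens the Dirichlet distribution (more exploration and uncertainty), which dampens the impact of abrupt policy shifts.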
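The mean and variance formulas above can be checked directly. This sketch uses the closed-form Dirichlet moments with an illustrative preference vector and $\alpha_{\text{floor}} = 10^{-3}$ (the suggested config value); the budgets 2, 20, and 60 span the suggested `concentration_min`/`concentration_max` range.

```python
import numpy as np


def dirichlet_moments(alpha):
    """Closed-form mean and per-component variance of Dir(alpha)."""
    alpha = np.asarray(alpha, dtype=float)
    a0 = alpha.sum()
    mean = alpha / a0
    var = mean * (1.0 - mean) / (a0 + 1.0)
    return mean, var


p = np.array([0.5, 0.3, 0.2])
alpha_floor = 1e-3
for C in (2.0, 20.0, 60.0):
    mean, var = dirichlet_moments(alpha_floor + C * p)
    # The mean stays near p while the standard deviation falls as C grows.
    print(f"C={C:5.1f}  max|mean - p|={np.abs(mean - p).max():.2e}  std={np.sqrt(var).round(3)}")
```

The small residual gap between the mean and $p$ comes entirely from $\alpha_{\text{floor}}$ and shrinks as $C_t$ increases.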

## 3) Network Parameterization

The actor produces two distinct heads from its backbone (e.g., TCN, Attention, or Fusion layer):

### 3.1 Preference Head
$$p_t = \text{softmax}(z_t / \tau_t)$$
- $z_t$: logits from actor backbone
- $\tau_t$: temperature (can be scheduled or fixed)

### 3.2 Raw Concentration Head
$$\hat{C}_t = \text{softplus}(u_t) + \epsilon_C$$
- $u_t$: scalar output
- Ensures positivity of the raw budget.

Then apply controller scaling:
$$C_t = \text{clip}(\hat{C}_t \cdot g_t,\; C_{\min},\; C_{\max})$$
where $g_t$ is the adaptive gain driven by the closed-loop controller.
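A forward pass through both heads, written as a NumPy sketch rather than the TensorFlow actor. The default values mirror the suggested config (`epsilon_c`, `concentration_min`, `concentration_max`, `temperature.base`); the logits `z`, scalar `u`, and gain `0.6` are arbitrary illustrative inputs.

```python
import numpy as np


def preference_head(z, tau):
    # p_t = softmax(z_t / tau_t); subtracting the max keeps exp() from overflowing.
    s = np.asarray(z, dtype=float) / tau
    s = s - s.max()
    e = np.exp(s)
    return e / e.sum()


def concentration_head(u, eps_c=1e-6):
    # C_hat_t = softplus(u_t) + eps_C, strictly positive.
    return float(np.logaddexp(0.0, u)) + eps_c


def controlled_concentration(c_hat, gain, c_min=2.0, c_max=60.0):
    # C_t = clip(C_hat_t * g_t, C_min, C_max)
    return float(np.clip(c_hat * gain, c_min, c_max))


z = np.array([1.0, 0.2, -0.5])
p = preference_head(z, tau=1.2)
c_hat = concentration_head(u=3.0)           # about 3.05
C = controlled_concentration(c_hat, 0.6)    # 3.05 * 0.6 is about 1.83, clipped up to C_min = 2.0
alpha = 1e-3 + C * p
print(p, C, alpha)
```

Note that $C_{\min}$ is active here: the controller asked for less confidence than the floor allows, so the clip, not the gain, sets $C_t$.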

## 4) Closed-Loop Controller

We track Exponential Moving Averages (EMA) of update diagnostics:
- $\bar{k}_t$: Approximate KL divergence
- $\bar{to}_t$: Portfolio turnover
- $\bar{dd}_t$: Drawdown magnitude
- $\bar{s}_t$: KL early-stop incidence rate

### 4.1 Error Terms
$$e_k = \frac{\bar{k}_t}{k^*} - 1, \quad e_{to} = \frac{\bar{to}_t}{to^*} - 1, \quad e_{dd} = \frac{\bar{dd}_t}{dd^*} - 1$$

### 4.2 Gain Update (Log-Space)
$$\log g_{t+1} = \text{clip}(\log g_t - \eta_k e_k - \eta_{to} e_{to} - \eta_{dd} e_{dd} - \eta_s \bar{s}_t, \; \log g_{\min}, \log g_{\max})$$
$$g_{t+1} = \exp(\log g_{t+1})$$

If diagnostics are poor (high KL, turnover, or drawdown), $g_t$ falls, shrinking concentration and increasing policy entropy. If healthy, $g_t$ recovers to allow stronger conviction.
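The error terms and the clipped log-space step can be written as one pure function. This is a sketch under the assumption that the EMA values are already computed; the target and $\eta$ values are copied from the suggested config, and `gain_update` is a hypothetical name.

```python
import numpy as np

TARGETS = {"kl": 0.020, "turnover": 0.35, "drawdown": 0.18}
ETA = {"kl": 0.40, "turnover": 0.20, "drawdown": 0.15, "early_stop": 0.10}


def gain_update(g, kl_bar, to_bar, dd_bar, stop_bar, targets, eta, g_min=0.4, g_max=2.0):
    """One controller step: relative errors, log-space update, clip to [g_min, g_max]."""
    e_k = kl_bar / targets["kl"] - 1.0
    e_to = to_bar / targets["turnover"] - 1.0
    e_dd = dd_bar / targets["drawdown"] - 1.0
    log_g = (np.log(g)
             - eta["kl"] * e_k
             - eta["turnover"] * e_to
             - eta["drawdown"] * e_dd
             - eta["early_stop"] * stop_bar)
    log_g = np.clip(log_g, np.log(g_min), np.log(g_max))
    return float(np.exp(log_g))


# KL running at twice its target (e_k = 1) pulls the gain down by a factor exp(-0.4).
print(gain_update(1.0, kl_bar=0.04, to_bar=0.35, dd_bar=0.18, stop_bar=0.0, targets=TARGETS, eta=ETA))
```

Working in log space makes each step multiplicative, so the same error produces the same relative change in $g_t$ whether the gain is currently small or large.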

## 5) Regime-Aware Extension

We can further refine the gain using a regime multiplier $\rho_t$ based on market state:
$$g_t \leftarrow g_t \cdot \rho_t$$
- **Calm Regime**: $\rho_t \in [1.00, 1.10]$ (allow higher conviction)
- **Stress Regime**: $\rho_t \in [0.75, 0.95]$ (force conservatism / broader exploration)

This makes the model's confidence adaptive to both its own optimization state and the external environment.
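The notebook does not specify how the regime is detected. The sketch below assumes a hypothetical rule (trailing realized volatility above a threshold means stress) and uses the suggested `stress_gain_mult`/`calm_gain_mult` values; re-clipping $g_t \cdot \rho_t$ to $[g_{\min}, g_{\max}]$ is also an added assumption.

```python
import numpy as np


def regime_multiplier(returns, vol_threshold, stress_mult=0.85, calm_mult=1.00, window=20):
    # Hypothetical regime rule: stress if trailing realized volatility exceeds a threshold.
    recent = np.asarray(returns[-window:], dtype=float)
    realized_vol = recent.std(ddof=1)
    return stress_mult if realized_vol > vol_threshold else calm_mult


def apply_regime(g, rho, g_min=0.4, g_max=2.0):
    # g_t <- g_t * rho_t; re-clipping to the gain bounds is an assumption, not from the source.
    return float(np.clip(g * rho, g_min, g_max))


quiet = np.tile([0.001, -0.001], 50)    # synthetic low-volatility daily returns
turbulent = np.tile([0.05, -0.05], 50)  # synthetic high-volatility daily returns
for name, rets in (("quiet", quiet), ("turbulent", turbulent)):
    rho = regime_multiplier(rets, vol_threshold=0.02)
    print(name, rho, apply_regime(1.0, rho))
```

Any regime classifier (volatility, drawdown state, an HMM) could replace `regime_multiplier`, as long as it returns a multiplier inside the ranges listed above.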

## 6) PPO Coupling and Synergy

RCDCC is designed to work *with* PPO's existing trust-region mechanisms, not replace them:
1. **Upstream Mitigation**: RCDCC reduces the frequency and severity of KL overshoots by modulating concentration *before* actions are sampled.
2. **Complementary Control**: PPO still maintains `target_kl`, `policy_clip`, and early-stop triggers.
3. **Resulting Synergy**: Smoother learning curves, fewer update truncations, and lower churn.
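For reference, the PPO-side signals that feed the controller ($\bar{k}_t$ and $\bar{s}_t$) typically come from a sampled approximate-KL estimate and an early-stop check. This sketch uses the non-negative estimator $\mathbb{E}[(r - 1) - \log r]$ with $r = \pi_{\text{new}} / \pi_{\text{old}}$; the `1.5 * target_kl` threshold is a common PPO convention and is assumed here, not taken from this codebase.

```python
import numpy as np


def approx_kl(logp_old, logp_new):
    # Non-negative sample estimator of KL(pi_old || pi_new): mean of (r - 1) - log r.
    log_ratio = np.asarray(logp_new, dtype=float) - np.asarray(logp_old, dtype=float)
    return float(np.mean(np.expm1(log_ratio) - log_ratio))


def should_early_stop(kl, target_kl=0.02, margin=1.5):
    # The 1.5x margin is an assumed convention; adjust to match the agent's trigger.
    return kl > margin * target_kl


logp_old = np.log(np.array([0.2, 0.5, 0.3]))
logp_new = logp_old + np.array([0.10, -0.20, 0.05])
kl = approx_kl(logp_old, logp_new)
print(kl, should_early_stop(kl))  # kl feeds the KL EMA; the boolean feeds the early-stop EMA
```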

## 7) Practical Integration (Your Codebase)

### 7.1 `src/agents/actor_critic_tf.py`
- Modify `DirichletActor` to add the dual heads.
- Update `_compute_alpha` to use the factorization logic.

### 7.2 `src/agents/ppo_agent_tf.py`
- Implement the EMA trackers for diagnostics.
- Implement the $g_t$ controller update loop within the `update` method.
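One way to package the EMA trackers and the gain update for `ppo_agent_tf.py` is a small stateful object that reads the suggested `rcdcc_params` keys. The class name, the `update` signature, seeding each EMA with its first observation, and taking `abs(drawdown)` (in case drawdowns are logged as negative numbers) are all assumptions for illustration.

```python
import math


class RCDCCController:
    """Hypothetical stateful controller; keys mirror the suggested rcdcc_params schema."""

    def __init__(self, params):
        self.beta = params["ema_beta"]
        self.targets = params["targets"]
        self.eta = params["controller_eta"]
        self.log_g_min = math.log(params["gain_min"])
        self.log_g_max = math.log(params["gain_max"])
        self.log_g = math.log(params["gain_init"])
        self.ema = {}  # diagnostic name -> current EMA value

    def _ema(self, name, x):
        # The first observation seeds the EMA, which avoids a zero-initialization bias.
        prev = self.ema.get(name)
        self.ema[name] = x if prev is None else self.beta * prev + (1.0 - self.beta) * x
        return self.ema[name]

    def update(self, kl, turnover, drawdown, early_stopped):
        """Fold one PPO update's diagnostics into the EMAs and step the log-gain."""
        k = self._ema("kl", kl)
        to = self._ema("turnover", turnover)
        dd = self._ema("drawdown", abs(drawdown))
        s = self._ema("early_stop", 1.0 if early_stopped else 0.0)
        step = (self.eta["kl"] * (k / self.targets["kl"] - 1.0)
                + self.eta["turnover"] * (to / self.targets["turnover"] - 1.0)
                + self.eta["drawdown"] * (dd / self.targets["drawdown"] - 1.0)
                + self.eta["early_stop"] * s)
        self.log_g = min(max(self.log_g - step, self.log_g_min), self.log_g_max)
        return self.gain

    @property
    def gain(self):
        return math.exp(self.log_g)


PARAMS = {
    "ema_beta": 0.90, "gain_init": 1.0, "gain_min": 0.4, "gain_max": 2.0,
    "targets": {"kl": 0.020, "turnover": 0.35, "drawdown": 0.18},
    "controller_eta": {"kl": 0.40, "turnover": 0.20, "drawdown": 0.15, "early_stop": 0.10},
}
ctrl = RCDCCController(PARAMS)
print(ctrl.update(kl=0.06, turnover=0.35, drawdown=0.18, early_stopped=True))
```

The returned gain would then multiply $\hat{C}_t$ in the actor on the next rollout.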

### 7.3 `src/config.py`
- Add `dirichlet_controller_params` to the agent configuration (see suggested schema below).

In [None]:
# 7.4 Suggested config schema (copy into config override cell)
rcdcc_params = {
    'enabled': True,
    'alpha_floor': 1e-3,
    'epsilon_c': 1e-6,
    'concentration_min': 2.0,
    'concentration_max': 60.0,
    'gain_init': 1.0,
    'gain_min': 0.4,
    'gain_max': 2.0,
    'ema_beta': 0.90,
    'targets': {
        'kl': 0.020,
        'turnover': 0.35,
        'drawdown': 0.18,
    },
    'controller_eta': {
        'kl': 0.40,
        'turnover': 0.20,
        'drawdown': 0.15,
        'early_stop': 0.10,
    },
    'regime_gate': {
        'enabled': True,
        'stress_gain_mult': 0.85,
        'calm_gain_mult': 1.00,
    },
    'temperature': {
        'base': 1.20,
        'min': 0.90,
        'max': 1.60,
    },
}

print('RCDCC template ready')
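Before launching runs, it may help to sanity-check the schema. The hypothetical helper below checks orderings that the equations rely on (positive targets, since they are divisors in the error terms; ordered bounds) and the regime-multiplier ranges stated in Section 5. Running it on `rcdcc_params` as defined above should return an empty list.

```python
def validate_rcdcc_params(params):
    """Return a list of problems in an rcdcc_params-style dict (empty if none)."""
    problems = []
    if params["alpha_floor"] <= 0 or params["epsilon_c"] <= 0:
        problems.append("alpha_floor and epsilon_c must be positive")
    if not 0 < params["concentration_min"] < params["concentration_max"]:
        problems.append("need 0 < concentration_min < concentration_max")
    if not 0 < params["gain_min"] <= params["gain_init"] <= params["gain_max"]:
        problems.append("need 0 < gain_min <= gain_init <= gain_max")
    if not 0.0 <= params["ema_beta"] < 1.0:
        problems.append("ema_beta must lie in [0, 1)")
    for name, target in params["targets"].items():
        if target <= 0:
            problems.append(f"targets['{name}'] must be positive (it is a divisor)")
    for name, eta in params["controller_eta"].items():
        if eta < 0:
            problems.append(f"controller_eta['{name}'] must be non-negative")
    temp = params["temperature"]
    if not 0 < temp["min"] <= temp["base"] <= temp["max"]:
        problems.append("need 0 < temperature min <= base <= max")
    gate = params["regime_gate"]
    if gate["enabled"]:
        # Ranges taken from Section 5 (stress in [0.75, 0.95], calm in [1.00, 1.10]).
        if not 0.75 <= gate["stress_gain_mult"] <= 0.95:
            problems.append("stress_gain_mult outside [0.75, 0.95]")
        if not 1.00 <= gate["calm_gain_mult"] <= 1.10:
            problems.append("calm_gain_mult outside [1.00, 1.10]")
    return problems
```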

## 8) Expected Behavioral Effects

- **Lower KL Overshoot Rate**: Fewer abrupt policy shifts from overconfident alpha spikes.
- **Lower Turnover Volatility**: Allocations become smoother, especially in high-stress regimes.
- **Improved Train Stability**: Fewer KL early-stop interruptions, so more of each update's planned epochs complete and collected data is used more efficiently.
- **Better Robustness**: Out-of-sample (OOS) gains from adaptive confidence rather than static heuristics.

## 9) Ablation Plan

1. **Baseline**: Static Dirichlet (static `alpha_activation` (ELU), `alpha_cap`, and `epsilon` controls only).
2. **Factorized-only**: Mean-concentration factorization only (no controller).
3. **Factorized + KL Controller**: Partial controller using only KL error.
4. **Factorized + KL+Turnover Controller**: Multi-objective control.
5. **Full RCDCC**: KL + Turnover + Drawdown + Regime-Aware multipliers.
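The five variants can be derived from a single `rcdcc_params`-style dict so that only the intended knobs differ between runs. This mapping is a hypothetical sketch: "no controller" is expressed by zeroing every $\eta$ (which freezes $g_t$ at `gain_init`), and assigning the early-stop term only to the full variant is an assumption. The labels for the first and last variants match the `EXPERIMENTS` registry.

```python
import copy


def make_ablation_configs(base):
    """Derive the five ablation variants from an rcdcc_params-style dict (hypothetical mapping)."""
    def variant(enabled, active_etas, regime):
        cfg = copy.deepcopy(base)  # never mutate the shared base config
        cfg["enabled"] = enabled
        cfg["controller_eta"] = {
            k: (v if k in active_etas else 0.0) for k, v in base["controller_eta"].items()
        }
        cfg["regime_gate"]["enabled"] = regime
        return cfg

    return {
        "baseline_static_dirichlet": variant(False, set(), False),
        "factorized_only": variant(True, set(), False),
        "factorized_kl": variant(True, {"kl"}, False),
        "factorized_kl_turnover": variant(True, {"kl", "turnover"}, False),
        "rcdcc_full": variant(True, {"kl", "turnover", "drawdown", "early_stop"}, True),
    }
```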

In [None]:
# Experiment registry for ablations
from pathlib import Path

EXPERIMENTS = [
    {
        'label': 'baseline_static_dirichlet',
        'method': 'Static Dirichlet + static PPO-KL',
        'group': 'baseline',
        'logs_dir': Path('./results/logs'),  # Update to this run's local results directory
        'run_tag': None,  # Set a run-specific tag: with None, runs sharing logs_dir load the same latest files
    },
    {
        'label': 'rcdcc_full',
        'method': 'RCDCC full',
        'group': 'proposed',
        'logs_dir': Path('./results/logs'),
        'run_tag': None,
    },
]

print('Configured runs:', len(EXPERIMENTS))

## 10) Novelty and Contribution

- **Adaptive Concentration Governance**: Most portfolio RL uses fixed or scheduled Dirichlet controls. RCDCC introduces a **closed-loop feedback system**.
- **Stateful Confidence**: Explicitly treats confidence as a stateful, risk-aware process driven by optimization diagnostics.
- **Explicit Decoupling**: Provides a cleaner mathematical separation between policy intent (mean) and policy conviction (concentration).

## 11) Risks and Mitigations

- **Controller Oscillation**: Aggressive $\eta$ values could cause $g_t$ to oscillate. 
  - *Mitigation*: use log-space updates and EMA smoothing.
- **Policy Flattening**: The controller might over-flatten the policy to satisfy constraints.
  - *Mitigation*: set reasonable $g_{\min}$ and $C_{\min}$ floors.
- **Delayed Feedback**: Reward/drawdown signals are delayed relative to the gradient update.
  - *Mitigation*: use EMA to provide a stable, long-horizon view of risk diagnostics.
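The oscillation risk can be illustrated on a toy closed loop. The sketch assumes a deliberately simple plant in which the smoothed KL responds instantly and linearly to the gain (equilibrium $g^* = 1$), with KL as the only active error term. Linearizing the log-space update around $g^*$ gives a per-step multiplier of $1 - \eta$, so the toy loop settles monotonically for $\eta < 1$, alternates around $g^*$ for $1 < \eta < 2$, and is held only by the clip bounds beyond that. This shows that the log-space form alone does not prevent oscillation; the $\eta$ values themselves must stay moderate.

```python
import math


def simulate_gain(eta, steps=12, g0=2.0, g_min=0.4, g_max=2.0):
    """Toy loop: KL is proportional to the gain; returns the gain after each step."""
    target_kl, kl_per_unit_gain = 0.02, 0.02  # toy plant with equilibrium gain g* = 1
    log_g, path = math.log(g0), []
    for _ in range(steps):
        kl = kl_per_unit_gain * math.exp(log_g)
        log_g -= eta * (kl / target_kl - 1.0)
        log_g = min(max(log_g, math.log(g_min)), math.log(g_max))
        path.append(math.exp(log_g))
    return path


def sign_changes(path, g_star=1.0):
    """Count how often the gain crosses the equilibrium value."""
    above = [g > g_star for g in path]
    return sum(a != b for a, b in zip(above, above[1:]))


for eta in (0.4, 1.8):
    print(f"eta={eta}: crossings={sign_changes(simulate_gain(eta))}")
```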

In [None]:
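# Load helpers: locate and read the latest logged artifacts for a run
import json
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

plt.style.use('seaborn-v0_8-whitegrid')  # this style name requires matplotlib >= 3.6


def pick_latest(logs_dir: Path, pattern: str, run_tag=None):
    """Return the most recently modified file in logs_dir matching pattern, or None.

    If run_tag is given, prefer the newest file whose name contains it. If no
    tagged file exists, warn and fall back to the newest file overall.
    """
    files = sorted(logs_dir.glob(pattern), key=lambda p: p.stat().st_mtime)
    if run_tag:
        tagged = [p for p in files if run_tag in p.name]
        if tagged:
            return tagged[-1]
        warnings.warn(f"no '{pattern}' file tagged '{run_tag}' in {logs_dir}; using latest untagged file")
    return files[-1] if files else None


def load_artifacts(logs_dir: Path, run_tag=None):
    """Load episode, step-diagnostic, summary, and metadata artifacts for one run.

    Any artifact that cannot be found is returned as None instead of raising.
    """
    ep_p = pick_latest(logs_dir, '*episodes*.csv', run_tag)
    st_p = pick_latest(logs_dir, '*step_diagnostics*.csv', run_tag)
    sm_p = pick_latest(logs_dir, '*summary*.csv', run_tag)
    md_p = pick_latest(logs_dir, '*_metadata.json', run_tag)

    return {
        'episodes_path': ep_p,
        'steps_path': st_p,
        'summary_path': sm_p,
        'metadata_path': md_p,
        'episodes': pd.read_csv(ep_p) if ep_p else None,
        'steps': pd.read_csv(st_p) if st_p else None,
        'summary': pd.read_csv(sm_p) if sm_p else None,
        'metadata': json.loads(md_p.read_text(encoding='utf-8')) if md_p else None,
    }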

In [None]:
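def col(df, names):
    """Return the first candidate column name present in df, or None."""
    for n in names:
        if df is not None and n in df.columns:
            return n
    return None


def summarize_run(art):
    """Reduce a load_artifacts() dict to scalar training diagnostics and episode metrics.

    Column names differ across logger versions, so each metric is looked up
    from a list of aliases; metrics whose columns are absent are omitted.
    """
    out = {}
    ep = art['episodes']
    st = art['steps']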
    if st is not None and len(st) > 0:
        kl = col(st, ['approx_kl', 'kl'])
        es = col(st, ['early_stop_kl_triggered', 'early_stop'])
        cf = col(st, ['clip_fraction'])
        to = col(st, ['turnover'])

        if kl:
            out['kl_mean'] = float(st[kl].mean())
            out['kl_p95'] = float(st[kl].quantile(0.95))
        if es:
            out['early_stop_rate'] = float((st[es] > 0).mean())
        if cf:
            out['clip_fraction_mean'] = float(st[cf].mean())
        if to:
            out['turnover_mean'] = float(st[to].mean())

    if ep is not None and len(ep) > 0:
        shp = col(ep, ['Sharpe', 'sharpe', 'sharpe_ratio'])
        ret = col(ep, ['Return', 'total_return'])
        mdd = col(ep, ['Max_Drawdown', 'max_drawdown'])

        if shp:
            out['episode_sharpe_mean'] = float(ep[shp].mean())
            out['episode_sharpe_p75'] = float(ep[shp].quantile(0.75))
        if ret:
            out['episode_return_mean'] = float(ep[ret].mean())
        if mdd:
            out['episode_mdd_mean'] = float(ep[mdd].mean())

    return out
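To compare ablation runs, the per-run summaries can be stacked into one table with relative changes against the baseline. This is a hypothetical helper; in the notebook the input would come from something like `{e['label']: summarize_run(load_artifacts(e['logs_dir'], e['run_tag'])) for e in EXPERIMENTS}`.

```python
import pandas as pd


def compare_to_baseline(summaries, baseline_label):
    """Stack summarize_run-style dicts and compute relative change vs. the baseline run."""
    table = pd.DataFrame(summaries).T  # rows: run labels, columns: metric names
    base = table.loc[baseline_label]
    # A zero baseline metric would make the ratio undefined, so it becomes NaN instead of inf.
    rel = (table - base) / base.abs().replace(0, float("nan"))
    return table, rel
```

The sign of a relative change must be read per metric: a negative change is an improvement for `kl_mean`, `early_stop_rate`, or `turnover_mean`, but a regression for `episode_sharpe_mean`.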