In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributions as dist
import torch.functional as F
from models import *
from scipy.optimize import minimize
from torch.distributions import Exponential

In [None]:
def sample_gpd_tail(psi_params, size=None):
    """Sample from the Generalized Pareto Distribution (GPD) used for the tail.

    Args:
        psi_params (tuple): GPD parameters ordered ``(shape, scale)`` -- the same
            order returned by ``psi_mle`` / ``update_target``.
        size (int, optional): Number of samples to draw. Defaults to the
            notebook-level ``MTail`` constant, looked up at call time.

    Returns:
        torch.Tensor: ``size`` samples drawn from the GPD.
    """
    if size is None:
        # Resolved lazily: MTail is defined in a later cell, so using it as a
        # default argument raised NameError when this cell ran on a fresh kernel.
        size = MTail
    # Unpack in the (shape, scale) order produced by psi_mle; the previous
    # (scale, shape) unpacking silently swapped the two parameters.
    shape, scale = psi_params
    # NOTE(review): confirm the installed torch provides dist.GeneralizedPareto
    # and whether its constructor also requires a `loc` argument.
    return dist.GeneralizedPareto(scale=scale, concentration=shape).sample((size,))

In [None]:
def qr_loss(quantiles, targets, N, MBody, MTail):
    """L1 loss between predicted and target quantiles, split into body and tail (step 12).

    Args:
        quantiles (torch.Tensor): Predicted quantiles; sliced along dim 0.
        targets (torch.Tensor): Target quantiles; sliced along dim 0.
        N (int): Number of quantiles. Accepted for interface compatibility; unused.
        MBody (int): Number of leading entries forming the body part.
        MTail (int): Number of trailing entries forming the tail part.

    Returns:
        torch.Tensor: Scalar ``mean|body error| + mean|tail error|``. An empty
        part (``MBody == 0`` or ``MTail == 0``) contributes 0 instead of NaN.

    NOTE(review): despite the name this is an unweighted mean absolute error --
    there is no per-quantile tau weighting as in ``quantile_huber_loss``.
    """
    def _last_rows(x, k):
        # Explicit start index: x[-0:] is the WHOLE tensor, so x[-MTail:]
        # would silently use every element when MTail == 0.
        return x[max(x.shape[0] - k, 0):]

    def _mean_or_zero(err):
        # torch.mean of an empty tensor is NaN; an empty slice contributes nothing.
        return err.mean() if err.numel() > 0 else err.new_zeros(())

    body_err = torch.abs(quantiles[:MBody] - targets[:MBody])
    tail_err = torch.abs(_last_rows(quantiles, MTail) - _last_rows(targets, MTail))
    return _mean_or_zero(body_err) + _mean_or_zero(tail_err)



In [None]:
def psi_mle(quantiles):
    """
    Estimate GPD parameters (shape, scale) by Maximum Likelihood Estimation (MLE).

    Minimises the GPD negative log-likelihood
        n*log(scale) + (1 + 1/shape) * sum(log(1 + shape*x/scale))
    (with the shape -> 0 exponential limit n*log(scale) + sum(x)/scale).

    Args:
    quantiles (torch.Tensor or array-like): Selected quantiles (below
        u(s_double_prime, a_double_prime)). Flattened before fitting.

    Returns:
    psi (tuple): Estimated GPD parameters ``(shape, scale)`` as Python floats.

    Raises:
    ValueError: If ``quantiles`` is empty (the likelihood is undefined).

    NOTE(review): the GPD density is defined for x >= 0 (exceedances over a
    threshold); raw quantile values may need shifting by the threshold first.
    """
    if torch.is_tensor(quantiles):
        quantiles = quantiles.detach().cpu().numpy()
    data = np.asarray(quantiles, dtype=np.float64).ravel()
    if data.size == 0:
        raise ValueError("psi_mle needs at least one quantile to fit the GPD")
    n = data.size

    # Initial guess for (shape, scale)
    initial_params = [0.1, 1.0]

    def gpd_neg_log_likelihood(params):
        shape, scale = params
        if scale <= 0:
            return np.inf  # scale must be positive
        if abs(shape) < 1e-8:
            # shape -> 0 limit of the GPD is the exponential distribution;
            # avoids dividing by zero in (1 + 1/shape).
            return n * np.log(scale) + np.sum(data) / scale
        z = 1.0 + shape * data / scale
        if np.any(z <= 0):
            return np.inf  # parameters put some data outside the GPD support
        # Positive NLL: minimising this maximises the likelihood.
        return n * np.log(scale) + (1.0 + 1.0 / shape) * np.sum(np.log(z))

    # Nelder-Mead is derivative-free and tolerates the +inf returned outside the
    # feasible region; L-BFGS-B would push inf into finite-difference gradients.
    result = minimize(gpd_neg_log_likelihood, initial_params, method='Nelder-Mead')

    shape_mle, scale_mle = result.x
    return (float(shape_mle), float(scale_mle))

In [None]:
N = 32  # number of quantiles (formula from step 1)
MTail = 10  # number of tail samples (step 3)
beta = 0.2  # body proportion
alpha = 1e-3  # learning rate (step 3)
state_dim = 8  # example state dimension
action_dim = 4  # example action dimension
SEED = 0  # global torch seed so weight init and the example data are reproducible

torch.manual_seed(SEED)

# NOTE(review): assumes Critic/Actor (from `models`) are torch.nn.Module
# subclasses -- their .parameters() feed Adam below.
critic = Critic(state_dim, action_dim, N)
target_critic = Critic(state_dim, action_dim, N)
# Start the target critic as an exact copy of the online critic; otherwise the
# two networks begin from independent random initialisations.
target_critic.load_state_dict(critic.state_dict())
actor = Actor(state_dim, action_dim)
optimizer_critic = optim.Adam(critic.parameters(), lr=alpha)
optimizer_actor = optim.Adam(actor.parameters(), lr=alpha)

In [None]:
# Actor update step (algorithm line 14)
def update_actor(s, a):
    """Take one optimizer step on the actor (algorithm line 14).

    The objective is a placeholder for a VaR/CVaR risk measure: the mean critic
    output at ``(s, actor(s))`` is maximised by minimising its negation.
    ``a`` is accepted for interface compatibility and is not used.
    """
    proposed_action = actor(s)
    risk_objective = critic(s, proposed_action).mean()
    actor_loss = risk_objective.neg()

    optimizer_actor.zero_grad()
    actor_loss.backward()
    optimizer_actor.step()



In [None]:
# Update target distribution parameters (lines 15-19 in algorithm)
def update_target(s, a, s_double_prime, a_double_prime, N, threshold_fn=None):
    """Refit the GPD tail parameters of the target distribution (lines 15-19).

    Args:
        s, a: Current state/action. Unused; kept for interface compatibility.
        s_double_prime, a_double_prime: State/action whose critic quantiles are fitted.
        N (int): Number of quantiles. Unused (the line-17 tau grid is not needed
            for the fit); kept for interface compatibility.
        threshold_fn (callable, optional): ``threshold_fn(s, a)`` returning the
            threshold u. Defaults to the notebook-level ``u``, looked up at call time.

    Returns:
        tuple: GPD parameters ``(shape, scale)`` as returned by ``psi_mle``.
    """
    if threshold_fn is None:
        threshold_fn = u  # global threshold function defined in the training-loop cell

    # psi_mle fits on detached data, so no autograd graph needs to be built here.
    with torch.no_grad():
        # Sample new tail quantiles (line 16)
        quantiles = critic(s_double_prime, a_double_prime)
        threshold = threshold_fn(s_double_prime, a_double_prime)
        selected_quantiles = quantiles[quantiles < threshold]

    # NOTE(review): peaks-over-threshold GPD fits usually use exceedances
    # (e.g. threshold - selected for a lower tail) rather than raw values --
    # confirm against the algorithm before changing what is passed below.
    # Update GPD parameters using MLE (line 18)
    return psi_mle(selected_quantiles)

In [None]:
def quantile_huber_loss(input, target, tau):
    """Element-wise quantile (pinball) loss rho_tau(target - input).

    Positive errors are weighted by ``tau`` and non-positive errors by
    ``tau - 1``, so every entry is non-negative for 0 <= tau <= 1.
    Despite the name, no Huber smoothing is applied.

    Returns:
        torch.Tensor: Unreduced loss with the broadcast shape of the inputs.
    """
    error = target - input
    over = error > 0
    return torch.where(over, error * tau, error * (tau - 1))


In [None]:
def update_critic(critic_network, body_samples, tail_samples,
                  target_body, target_tail, taus,
                  M_body, M_tail, optimizer):
    """One quantile-regression step on the critic (algorithm lines 11-35, Eq. 20).

    Args:
        critic_network: Callable ``(state, action) -> quantiles``; its output is
            indexed as ``[:, tau_idx]``, so it must have a batch dim first.
        body_samples, tail_samples (dict): Batches with ``'state'`` and ``'action'``.
        target_body, target_tail (torch.Tensor): Targets for the body/tail batches.
        taus: Sequence of quantile levels tau_n; ``len(taus)`` is N.
        M_body, M_tail: Sample counts used as the 1/M weights of Eq. 20.
        optimizer: Optimizer over the critic's parameters.

    Returns:
        float: The QR loss value, for logging.

    Raises:
        ValueError: If ``taus`` is empty (there would be no loss to backpropagate).
    """
    # Line 11: Number of quantiles (N)
    N = len(taus)
    if N == 0:
        raise ValueError("taus must contain at least one quantile level")

    # Lines 19/21: the critic's output does not depend on tau, so run each
    # forward pass once and slice out the tau_n column inside the loop.
    pred_body_all = critic_network(body_samples['state'], body_samples['action'])
    pred_tail_all = critic_network(tail_samples['state'], tail_samples['action'])

    # Line 13: accumulated QR loss (named to avoid shadowing the qr_loss function)
    total_loss = 0.0

    # Line 15: loop over each quantile tau_n
    for tau_idx in range(N):
        tau = taus[tau_idx]

        # Lines 23/27: rho_tau(z - theta_w^{tau_n}(s, a)), summed over samples
        loss_body = quantile_huber_loss(pred_body_all[:, tau_idx], target_body, tau).sum()
        loss_tail = quantile_huber_loss(pred_tail_all[:, tau_idx], target_tail, tau).sum()

        # Line 31 / Eq. 20: the 1/M_body and 1/M_tail weights turn the sums into
        # sample averages. (Previously .mean() AND 1/M were both applied, which
        # normalised by M twice.)
        total_loss = total_loss + loss_body / M_body + loss_tail / M_tail

    # Line 35: backpropagate and update the critic's parameters
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    return total_loss.item()

In [None]:
# Example loop for a single episode, following the time loop (line 3 in the algorithm).
# States, rewards and critic targets are random/placeholder values.


# Threshold selection (line 4). Defined once (not per iteration); update_target
# reads this global by default.
def u(s, a):
    """Placeholder threshold: the (1 - beta)-quantile of the critic's output at (s, a)."""
    return torch.quantile(critic(s, a), 1 - beta)


# Quantile levels tau_n, same formula as update_target (line 17).
taus = torch.arange(0, N) / N

for t in range(1000):  # Time loop (line 3)
    s = torch.randn(state_dim)  # Example state
    a = actor(s)  # Action from policy
    r = torch.randn(1)  # Example reward
    s_prime = torch.randn(state_dim)  # Example next state

    # Update critic (line 12). update_critic indexes the critic output as
    # [:, tau_idx], so pass a batch of one transition for both body and tail.
    # The action is detached so the critic loss does not backprop into the actor.
    # NOTE(review): assumes Critic accepts (batch, dim) inputs and returns (batch, N).
    # TODO: replace the placeholder targets (the raw reward r) with the
    # distributional Bellman targets for the body and GPD tail samples.
    batch = {'state': s.unsqueeze(0), 'action': a.detach().unsqueeze(0)}
    update_critic(critic, batch, batch, r, r, taus, 1, 1, optimizer_critic)

    # Update actor (line 14)
    update_actor(s, a)

    # Update target distribution (lines 15-19)
    psi_params = update_target(s, a, s_prime, actor(s_prime), N)