# Active Causal Discovery Example (RL)

This notebook demonstrates: 
- How to generate DAGs and linear Gaussian mechanisms, and how to train hierarchical (hybrid) policies with the Soft Actor-Critic (SAC) algorithm, where Q-values are enumerated over the discrete actions (Q-enumeration). 

- Plotting code and experiments undertaken to investigate the effect of stochasticity on robustness to prior and misspecification shift.

### 1. Import Dependencies

In [None]:
# Main dependencies
import math
import numpy as np
import networkx as nx

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributions import Normal, TransformedDistribution, Categorical, Gamma
from torch.distributions.transforms import TanhTransform
from torch.nn.utils import clip_grad_norm_

import avici
from avici.metrics import shd, classification_metrics, threshold_metrics
avici_model = avici.load_pretrained(download="scm-v0")

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# For plotting:
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import colors
from matplotlib.colors import Normalize
from matplotlib.ticker import MaxNLocator

### 2. Generate DAGs and Mechanisms (linear Gaussian)

- `prior()`: Generates an Erdős–Rényi-style random DAG, with an acyclicity constraint and a topological backbone (a chain through all nodes) to ensure weak connectivity.

- `LinearGaussianSCM`: Given an adjacency matrix A, samples edge weights (with standard deviation scaled by one over the square root of the number of parents) and per-node noise scales (as the reciprocal of Gamma samples, i.e. inverse-Gamma distributed). Values are sampled by traversing the DAG in topological order, optionally under an intervention, with optional leading batch dimensions (...).

In [None]:
def prior(d: int, p: float, device=torch.device('cpu')):
    '''Samples a weakly connected Erdős–Rényi-style DAG on d nodes.

    The graph is first built in "topological-position" space: a strictly
    upper-triangular Bernoulli(p) matrix (so it is acyclic by construction)
    plus a forced chain between consecutive positions (so it is weakly
    connected). The positions are then relabelled with a random permutation.

    Returns:
        A:     [d, d] bool adjacency matrix; A[u, v] is True iff u -> v.
        order: [d] permutation; order[i] is the node at topological position i.
    '''
    # Draw the node ordering first and the edge coins second (RNG order matters
    # for seeded reproducibility).
    order = torch.argsort(torch.rand(d, device=device))
    coin_flips = torch.rand(d, d, device=device) < p

    # Keep only edges position i -> position j with i < j.
    topo_adj = torch.triu(coin_flips, diagonal=1)

    # Force the chain 0 -> 1 -> ... -> d-1 on the first super-diagonal.
    if d > 1:
        topo_adj.diagonal(offset=1).fill_(True)

    # position_of[u] = i  <=>  order[i] = u  (argsort of a permutation is its inverse)
    position_of = torch.argsort(order)

    # Relabel rows and columns: A[u, v] = topo_adj[position_of[u], position_of[v]]
    A = topo_adj[position_of][:, position_of]

    return A, order

def adj_to_digraph(A: torch.Tensor) -> nx.DiGraph:
    '''Converts an adjacency matrix to a NetworkX DiGraph (useful for checking properties and plotting).

    Args:
        A: [d, d] adjacency matrix; any non-zero entry A[i, j] becomes an edge i -> j.

    Returns:
        A DiGraph with nodes 0, ..., d-1 and one edge per non-zero entry of A.
    '''
    A_cpu = A.to('cpu', dtype=torch.bool)

    G = nx.DiGraph()
    # Register every node up front: nodes without incident edges would otherwise
    # be missing from the graph, which makes connectivity checks blind to isolated nodes.
    G.add_nodes_from(range(A_cpu.shape[-1]))

    src, dst = A_cpu.nonzero(as_tuple=True)
    G.add_edges_from(zip(src.tolist(), dst.tolist()))
    return G

def plot_dag(A: Tensor) -> None:
    '''Draws the graph encoded by adjacency matrix A with a fixed-seed spring layout.'''
    graph = adj_to_digraph(A)

    # Styling of nodes, labels and arrows
    draw_options = {
        "with_labels": True,
        "node_color": "lightblue",
        "node_size": 300,
        "arrowsize": 15,
        "arrowstyle": '->',
        "font_size": 10,
    }

    plt.figure(figsize=(5, 4))
    node_positions = nx.spring_layout(graph, seed=42)   # seeded so the picture is reproducible
    nx.draw(graph, node_positions, **draw_options)
    plt.title("Erdős–Rényi-style DAG")
    plt.show()

# Check topology: sample one graph and verify it is a weakly connected DAG
d = 5  
m = 1.5   # Avg. Num. Parents/Node
# The backbone in prior() always contributes d-1 edges; each of the remaining
# (d-1)(d-2)/2 upper-triangular slots is an edge with probability p.
# Solving  (d-1) + p*(d-1)(d-2)/2 = d*m  for p gives the expression below.
p = (2 * (d * m - (d - 1))) / ((d - 1) * (d - 2))    # Accounts for the backbone

A, order = prior(d=d, p=p, device=device)
# "*1" turns the boolean matrix into 0/1 integers for readable printing
print("Adjacency Matrix (A):\n", A.cpu().numpy()*1)
print("\nOrder:", order.cpu().numpy())

# Structural sanity checks through NetworkX
G = adj_to_digraph(A)
print("\nis DAG?", nx.is_directed_acyclic_graph(G))
print("is weakly connected?", nx.is_weakly_connected(G))

plot_dag(A)

We can also check the effective number of incoming edges (parents) per node, to make sure that our expression for `p` in terms of `m` is correct:

In [None]:
# Empirical check that the edge probability p yields m parents per node on average
d = 5
m = 1.5
p = (2 * (d * m - (d - 1))) / ((d - 1) * (d - 2))  

# Total number of edges of each sampled graph
bin_A = []
for _ in range(1000):
    A, order = prior(d=d, p=p, device=device)
    A_sum = A.sum().cpu().item()
    bin_A.append(A_sum)

# Every edge contributes exactly one parent, so (total edges) / d is the
# mean number of parents per node of a single graph.
parents_per_node = np.asarray(bin_A) / d

print(f"p: {p:.3f}")
print(f"m: {m}")
print(f"mean(Num. of Parents): {parents_per_node.mean():.2f}")
# std(X / d) = std(X) / d; the statistic is normalised by d (not d**2) so that it
# is in the same units as the mean printed above.
print(f"std(Num. of Parents): {parents_per_node.std():.2f}")

In [None]:
class LinearGaussianSCM:
    '''Given an adjacency matrix A, samples edge weights normalized by number of parents and noise scale from InverseGamma.
        Values are sampled by following the topological order of the DAG (post-intervention), with optional leading batch dimensions.'''
    
    def __init__(self, 
                 A: Tensor,
                 W_base: Tensor | None = None,
                 b: Tensor | None = None,
                 sigma: Tensor | None = None,
                 device=torch.device('cpu')):
        self.device = device
        self.d = A.shape[-1]    
        A = A.to(device)
        
        if W_base is None:
            # Normalize weights by number of parents to keep noise variance stable
            in_degree = A.sum(dim=-2, keepdim=True).clamp_min(1.0)     
            coeff_std = (1.0 / torch.sqrt(in_degree)).expand_as(A)    
            coefficients = torch.normal(mean=0.0, std=coeff_std)
            self.W_base = coefficients
        else:
            self.W_base = W_base.to(device)
        
        if b is None:
            self.b = torch.zeros(self.d, device=device)  
        else:
            self.b = b.to(device)
        
        if sigma is None:
            gamma_samples = Gamma(10.0, 1.0).sample((self.d,)).to(device)
            self.sigma = 1.0 / gamma_samples
        else:
            self.sigma = sigma.to(device)

    def rsample(self,
                A: Tensor,
                order: Tensor,
                interv_node: Tensor | None = None,
                interv_val: Tensor | None = None,
                batch_shape: tuple = ()):
        A = A.to(self.device)
        order = order.to(self.device)
                
        W_eff = self.W_base * A.to(self.W_base.dtype)
        b     = self.b[(None,)*len(batch_shape)].expand(*batch_shape, self.d)
        sigma = self.sigma[(None,)*len(batch_shape)].expand(*batch_shape, self.d)

        if interv_node is not None:
            interv_node = interv_node.to(self.device).bool()
        if interv_val is not None:
            interv_val = interv_val.to(self.device).float()

        eps = torch.randn(*batch_shape, self.d, device=self.device) * sigma
        x = torch.zeros(*batch_shape, self.d, device=self.device)
        
        for k in order:
            parents_k = W_eff[:, k]                 # [d]
            lin_k = (x * parents_k).sum(dim=-1)     # [*B]

            if interv_node is not None and interv_val is not None:
                if batch_shape == ():
                    is_intervened = interv_node[k]
                    x[k] = torch.where(is_intervened, interv_val[k], b[k] + lin_k + eps[k])
                else:
                    is_intervened = interv_node[..., k]     # [*B]
                    x[..., k] = torch.where(is_intervened, interv_val[..., k], b[..., k] + lin_k + eps[..., k])
            else:
                x[..., k] = b[..., k] + lin_k + eps[..., k]
        
        return x

In [None]:
# Test: sample from a random SCM with and without an intervention
d = 5
p = 0.583

A, order = prior(d=d, p=p, device=device)
# "*1" turns the boolean matrix into 0/1 integers for readable printing
print(f"A: \n{A.cpu().numpy()*1}")
print(f"\nOrder: {order.cpu().numpy()}")

# Weights, biases and noise scales are drawn at construction time
scm = LinearGaussianSCM(A, device=device)

# Before intervention
x_original = scm.rsample(A, order, batch_shape=(1,))    # [1, d]

# The mean is taken over the batch axis (dim=-2), which has size 1 here;
# nodes are printed with 1-based labels.
print("\nAvg. values before intervention:")
for i in range(d):
    print(f"   Node {i+1}: {x_original.cpu().mean(dim=-2)[i]:.3f}")

# After intervention (shared across batch)
# Mask and value vectors of length d: the first node (index 0) is set to 1.0
interv_node = torch.tensor([1, 0, 0, 0, 0])
interv_val = torch.tensor([1.0, 0, 0, 0, 0])       
x_intervened = scm.rsample(A, order, interv_node, interv_val, batch_shape=(1,))
print("\nAvg. values after intervention (node 1):")
for i in range(d):
    print(f"   Node {i+1}: {x_intervened.cpu().mean(dim=-2)[i]:.3f}")

### 3. Neural Networks (Policies + Critic)

- `HistoryEncoder`: uses a time mask because replayed transitions have different history lengths.
  Maps inputs to per-step latents, then applies mean pooling across valid time steps.
  Returns zeros if the history is empty.

- `TanhGaussianPolicy`: two heads sharing a common history encoding.
  The discrete head outputs logits over nodes; these parameterize a Categorical distribution (softmax over the logits) from which the intervention node is sampled.
  The continuous head is conditioned on the one-hot node choice and has separate subheads for mean and log standard deviation.

- `RandomPolicy`: intervenes on a node with uniform probability, and uniformly samples an intervention value from (-1, 1). 

- `ObservationPolicy`: Simply returns all zeros for the intervention node and value masks. This is more of a convenience class to decrease the amount of logic in the training loop.

- `Critic`: has its own history encoder module, and concatenates the encoding with the one-hot node and value choice
  as input to fully connected layers to output a scalar Q-value.

In [None]:
class HistoryEncoder(nn.Module):
    def __init__(self, d: int, hidden_dim: int = 256, output_dim: int = 64, device=torch.device('cpu')):
        super().__init__()
        self.d = d
        self.output_dim = output_dim
        self.device = device
        self.enc_fc1 = nn.Linear(3*d, hidden_dim).to(device)
        self.enc_fc2 = nn.Linear(hidden_dim, output_dim).to(device)
        self.layer_norm = nn.LayerNorm(hidden_dim).to(device)

    def forward(self,
                hist_node: Tensor,               # [*B, t, d]
                hist_val: Tensor,                # [*B, t, d]
                hist_out: Tensor,                # [*B, t, d]
                time_mask: Tensor | None = None  # [*B, t]
                ):
        *batch_shape, t, _ = hist_node.shape

        if t == 0:
            return torch.zeros(*batch_shape, self.output_dim, device=self.device)

        x = torch.cat([hist_node, hist_val, hist_out], dim=-1)       # [*B, t, 3*d]
        h = self.enc_fc1(x)                                          # [*B, t, hidden_dim]
        h = self.layer_norm(h)
        h = F.relu(h)
        out = self.enc_fc2(h)                                        # [*B, t, output_dim]

        if time_mask is None:
            enc = out.mean(dim=-2)                              
        else:
            tm = time_mask.float().unsqueeze(-1)                     # [*B, t, 1]
            enc = (out * tm).sum(dim=-2) / tm.sum(dim=-2).clamp_min(1.0)

        return enc

class TanhGaussianPolicy(nn.Module):
    """Hybrid policy: a categorical choice of node plus a tanh-squashed Gaussian value.

    A shared history encoding feeds three MLP heads:
      - a discrete head producing logits over the d nodes;
      - a mean head and a log-std head for the intervention value, both
        conditioned on the one-hot encoding of the chosen node.
    The log-std is clamped to [log(min_std), log(max_std)].
    """

    def __init__(
        self,
        d: int,
        enc_hidden_dim: int = 256,
        enc_output_dim: int = 64,
        hidden_dim: int = 256,
        min_std: float = 0.01,
        max_std: float = 3.0,
        init_mean: float = 0.0,
        init_std: float = 1.0,
        device: torch.device = torch.device("cpu"),
    ):
        super().__init__()
        self.d = d
        self.device = device
        self.history_encoder = HistoryEncoder(d, enc_hidden_dim, enc_output_dim, device)

        # Clamp range of the predicted log standard deviation
        self.min_log_std = math.log(min_std)
        self.max_log_std = math.log(max_std)

        # Input width of the value heads: history encoding + one-hot node
        cond_dim = enc_output_dim + d

        # Discrete head
        self.disc_fc1 = nn.Linear(enc_output_dim, hidden_dim)
        self.disc_fc_mid = nn.Linear(hidden_dim, hidden_dim)
        self.disc_fc2 = nn.Linear(hidden_dim, d)
        self.disc_ln1 = nn.LayerNorm(hidden_dim)

        # Mean head of the intervention value
        self.mean_fc1 = nn.Linear(cond_dim, hidden_dim)
        self.mean_fc_mid = nn.Linear(hidden_dim, hidden_dim)
        self.mean_fc2 = nn.Linear(hidden_dim, 1)
        self.mean_ln1 = nn.LayerNorm(hidden_dim)

        # Log-std head of the intervention value
        self.log_std_fc1 = nn.Linear(cond_dim, hidden_dim)
        self.log_std_fc_mid = nn.Linear(hidden_dim, hidden_dim)
        self.log_std_fc2 = nn.Linear(hidden_dim, 1)
        self.log_std_ln1 = nn.LayerNorm(hidden_dim)

        # Xavier initialisation: ReLU gain for the hidden layers, unit gain for the outputs
        hidden_layers = (self.disc_fc1, self.disc_fc_mid,
                         self.mean_fc1, self.mean_fc_mid,
                         self.log_std_fc1, self.log_std_fc_mid)
        output_layers = (self.disc_fc2, self.mean_fc2, self.log_std_fc2)
        for group, gain in ((hidden_layers, nn.init.calculate_gain("relu")), (output_layers, 1.0)):
            for layer in group:
                nn.init.xavier_uniform_(layer.weight, gain=gain)
                nn.init.zeros_(layer.bias)

        # Output biases start at the requested mean and log(std)
        self.mean_fc2.bias.data.add_(init_mean)
        self.log_std_fc2.bias.data.add_(math.log(max(init_std, 1e-8)))

        # Squashing transform, shared by sampling and the deterministic action
        self._tanh = TanhTransform(cache_size=1)
        self.to(device)

    @staticmethod
    def _run_head(inputs: Tensor, fc_in: nn.Module, norm: nn.Module, fc_mid: nn.Module, fc_out: nn.Module):
        """Applies fc_in -> LayerNorm -> ReLU -> fc_mid -> ReLU -> fc_out."""
        hidden = F.relu(norm(fc_in(inputs)))
        hidden = F.relu(fc_mid(hidden))
        return fc_out(hidden)

    def _encode(self, hist_node: Tensor, hist_val: Tensor, hist_out: Tensor, time_mask: Tensor | None):
        """History encoding, shape [*B, enc_output_dim]."""
        return self.history_encoder(hist_node, hist_val, hist_out, time_mask)

    def _disc_logits(self, enc: Tensor):
        """Logits over the d candidate nodes, shape [*B, d]."""
        return self._run_head(enc, self.disc_fc1, self.disc_ln1, self.disc_fc_mid, self.disc_fc2)

    def _value_params(self, enc: Tensor, one_hot: Tensor):
        """Mean and std (each [*B]) of the pre-squash Gaussian for the node given by one_hot."""
        cond = torch.cat([enc, one_hot], dim=-1)           # [*B, enc_output_dim+d]

        mean = self._run_head(cond, self.mean_fc1, self.mean_ln1,
                              self.mean_fc_mid, self.mean_fc2).squeeze(-1)

        raw_log_std = self._run_head(cond, self.log_std_fc1, self.log_std_ln1,
                                     self.log_std_fc_mid, self.log_std_fc2).squeeze(-1)
        std = torch.clamp(raw_log_std, self.min_log_std, self.max_log_std).exp()
        return mean, std

    def _sample_cont(self, enc: Tensor, node_idx: Tensor):
        """Samples a squashed value for the given nodes.

        Returns the one-hot node mask [*B, d], the value placed at the chosen
        node [*B, d], and the log-density of the sampled value [*B].
        """
        interv_node = F.one_hot(node_idx.long(), num_classes=self.d).float()    # [*B, d]
        mean, std = self._value_params(enc, interv_node)

        squashed = TransformedDistribution(Normal(mean, std), [self._tanh])
        value = squashed.rsample()                        # reparameterised sample in (-1, 1)
        log_prob = squashed.log_prob(value)

        interv_val = interv_node * value.unsqueeze(-1)    # [*B, d]
        return interv_node, interv_val, log_prob

    def sample_action(self, hist_node: Tensor, hist_val: Tensor, hist_out: Tensor, time_mask: Tensor | None = None):
        """Samples a node, then a value for that node.

        Returns (interv_node, interv_val, logp_idx, logp_val, info), where info
        holds the logits, the history encoding and the sampled node index.
        """
        enc = self._encode(hist_node, hist_val, hist_out, time_mask)
        logits = self._disc_logits(enc)

        node_dist = Categorical(logits=logits)
        idx = node_dist.sample()                          # [*B]
        logp_idx = node_dist.log_prob(idx)                # [*B]

        interv_node, interv_val, logp_val = self._sample_cont(enc, idx)
        return interv_node, interv_val, logp_idx, logp_val, {"logits": logits, "enc": enc, "idx": idx}

    @torch.no_grad()
    def mean_action(self, hist_node: Tensor, hist_val: Tensor, hist_out: Tensor, time_mask: Tensor | None = None):
        """Deterministic action: arg-max node and tanh of the Gaussian mean.

        Returns the same 5-tuple layout as sample_action, with None in place of
        the two log-probabilities.
        """
        enc = self._encode(hist_node, hist_val, hist_out, time_mask)      # [*B, enc_output_dim]
        logits = self._disc_logits(enc)                                   # [*B, d]
        idx = logits.argmax(dim=-1)                                       # [*B]
        interv_node = F.one_hot(idx.long(), num_classes=self.d).float()   # [*B, d]

        mean, _ = self._value_params(enc, interv_node)                    # [*B]
        interv_val = interv_node * self._tanh(mean).unsqueeze(-1)         # [*B, d]

        return interv_node, interv_val, None, None, {"logits": logits, "enc": enc, "idx": idx}

In [None]:
class RandomPolicy(nn.Module):
    """Picks a node uniformly at random and a value uniformly from [-1 + eps, 1 - eps]."""

    def __init__(self, d: int, eps: float = 1e-3,  device: torch.device = torch.device('cpu')):
        super().__init__()
        self.d = d
        self.eps = float(eps)
        self.device = device

    def forward(self):
        """Returns (interv_node, interv_val), both of shape [d]."""
        # Node first, value second: the order of the two random draws is fixed
        node = torch.randint(self.d, (1,), device=self.device).squeeze(0)
        low, high = -1.0 + self.eps, 1.0 - self.eps
        value = torch.empty((), device=self.device).uniform_(low, high)

        # One-hot mask of the chosen node; the value is placed at that node only
        interv_node = (torch.arange(self.d, device=self.device) == node).float()
        return interv_node, interv_node * value

class ObservationPolicy(nn.Module):
    """Purely observational "policy": never intervenes on any node."""

    def __init__(self, d: int, device: torch.device = torch.device('cpu')):
        super().__init__()
        self.d = d
        self.device = device

    def forward(self):
        """Returns two independent all-zero [d] tensors: (interv_node, interv_val)."""
        interv_node, interv_val = (torch.zeros(self.d, device=self.device) for _ in range(2))
        return interv_node, interv_val

In [None]:
class Critic(nn.Module):
    """Q-network: maps (history, intervention mask, intervention values) to a scalar.

    The history is summarised by the critic's own HistoryEncoder; the encoding is
    concatenated with the action and passed through a three-layer MLP.
    """

    def __init__(self,
                 d: int,
                 enc_hidden_dim: int = 256,
                 enc_output_dim: int = 64,
                 hidden_dim: int = 256,
                 device=torch.device('cpu')):
        super().__init__()
        self.d = d
        self.device = device
        self.history_encoder = HistoryEncoder(d, enc_hidden_dim, enc_output_dim, device)

        # MLP over [encoding, interv_node, interv_val]
        self.q_fc1 = nn.Linear(enc_output_dim + 2*d, hidden_dim).to(device)
        self.q_fc2 = nn.Linear(hidden_dim, hidden_dim).to(device)
        self.q_fc3 = nn.Linear(hidden_dim, 1).to(device)
        self.layer_norm1 = nn.LayerNorm(hidden_dim).to(device)

        # Xavier initialisation: ReLU gain for hidden layers, unit gain for the output
        relu_gain = nn.init.calculate_gain("relu")
        for layer, gain in ((self.q_fc1, relu_gain), (self.q_fc2, relu_gain), (self.q_fc3, 1.0)):
            nn.init.xavier_uniform_(layer.weight, gain=gain)
            nn.init.zeros_(layer.bias)

        self.to(device)

    def forward(self,
                hist_node: Tensor,        # [*B, t, d]
                hist_val: Tensor,         # [*B, t, d]
                hist_out: Tensor,         # [*B, t, d]
                interv_node: Tensor,      # [*B, d]
                interv_val: Tensor,       # [*B, d]
                time_mask: Tensor | None = None):
        """Returns the Q-value estimate, shape [*B]."""
        state = self.history_encoder(hist_node, hist_val, hist_out, time_mask)
        features = torch.cat([state, interv_node, interv_val], dim=-1)     # [*B, enc_output_dim + 2*d]

        hidden = F.relu(self.layer_norm1(self.q_fc1(features)))
        hidden = F.relu(self.q_fc2(hidden))

        q_value = self.q_fc3(hidden)      # [*B, 1]
        return q_value.squeeze(-1)        # [*B]

### 4. Soft Actor-Critic (ReplayBuffer + SAC + Reward Function)

- `ReplayBuffer`: stores transitions $(s_t, a_t, r_t, s_{t+1}, done_t) = (h_t, (i_t, v_t), r_t, h_{t+1}, done_t)$, and constructs the time mask used in the history encoder.

- `SAC`: Soft Actor-Critic with discrete node choice and continuous value (hybrid actions).
    - Actor (policy) and twin critics with target networks.
    - Entropy regularization with separate coefficients for discrete and continuous actions (optionally tuned online).
    - Critic update: minimizes error against target critics, averaging over discrete actions and sampling continuous values from the policy.
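    - Policy update: maximizes the expected Q-value plus the entropy bonuses (i.e. Q minus the $\alpha$-weighted log-probabilities of the discrete and continuous actions); optionally updates the entropy coefficients; applies gradient clipping.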

- `AviciDeltaReward`: A convenience class that keeps track of the previous and current expected number of correct adjacency matrix entries (the difference between the two defines the reward). Requires an `avici_model` to be loaded in memory.

In [None]:
class ReplayBuffer:
    """
    Stores transitions:
      - hist_node, hist_val, hist_out: [T, d]                    (0-padded; first t rows valid)
      - time_mask: [T]                                           (0-padded; first t rows valid)
      - interv_node, interv_val: [d]
      - reward: scalar
      - next_hist_node, next_hist_val, next_hist_out: [T, d]     (0-padded; first t+1 rows valid)
      - next_time_mask: [T]
      - done: scalar bool (=True at last step of trajectory)
    """
    def __init__(self, d: int, T: int, num_obs: int, cap: int, device=torch.device('cpu')):
        self.d = d
        self.T = T
        self.max_len = num_obs + T
        self.cap = cap
        self.device = device

        # Pre-allocate storage on GPU
        self.hist_node      = torch.zeros(cap, num_obs + T, d, device=device)
        self.hist_val       = torch.zeros(cap, num_obs + T, d, device=device)
        self.hist_out       = torch.zeros(cap, num_obs + T, d, device=device)
        self.time_mask      = torch.zeros(cap, num_obs + T, device=device)
        self.interv_node    = torch.zeros(cap, d, device=device, dtype=torch.bool)
        self.interv_val     = torch.zeros(cap, d, device=device)
        self.reward         = torch.zeros(cap, device=device)
        self.done           = torch.zeros(cap, device=device)
        self.next_hist_node = torch.zeros(cap, num_obs + T, d, device=device)
        self.next_hist_val  = torch.zeros(cap, num_obs + T, d, device=device)
        self.next_hist_out  = torch.zeros(cap, num_obs + T, d, device=device)
        self.next_time_mask = torch.zeros(cap, num_obs + T, device=device)

        self.idx = 0
        self.size = 0
    def _advance(self):
        self.idx = (self.idx + 1) % self.cap
        self.size = min(self.size + 1, self.cap)

    @torch.no_grad()
    def add_transition(self,
                       hist_node: Tensor,          # [t, d]
                       hist_val: Tensor,           # [t, d]
                       hist_out: Tensor,           # [t, d]
                       interv_node: Tensor,        # [d]
                       interv_val: Tensor,         # [d]
                       reward: float | Tensor,   
                       next_hist_node: Tensor,     # [t+1, d]
                       next_hist_val: Tensor,      # [t+1, d]
                       next_hist_out: Tensor,      # [t+1, d]
                       done: bool | Tensor):
        t = hist_node.shape[-2]
        i = self.idx

        # Move to device
        hist_node = hist_node.to(self.device)
        hist_val = hist_val.to(self.device)
        hist_out = hist_out.to(self.device)
        interv_node = interv_node.to(self.device)
        interv_val = interv_val.to(self.device)
        next_hist_node = next_hist_node.to(self.device)
        next_hist_val = next_hist_val.to(self.device)
        next_hist_out = next_hist_out.to(self.device)

        # Current history
        self.hist_node[i].zero_()
        self.hist_val[i].zero_()
        self.hist_out[i].zero_()
        if t > 0:
            self.hist_node[i, :t] = hist_node
            self.hist_val[i, :t]  = hist_val
            self.hist_out[i, :t]   = hist_out
        self.time_mask[i].zero_()
        if t > 0:
            self.time_mask[i, :t] = 1.0

        # Action
        self.interv_node[i] = interv_node
        self.interv_val[i]  = interv_val

        # Reward / done
        self.reward[i] = torch.as_tensor(reward, device=self.device)
        self.done[i] = torch.as_tensor(done, device=self.device)

        # Next history
        self.next_hist_node[i].zero_()
        self.next_hist_val[i].zero_()
        self.next_hist_out[i].zero_()
        self.next_hist_node[i, :t+1] = next_hist_node
        self.next_hist_val[i, :t+1] = next_hist_val
        self.next_hist_out[i, :t+1] = next_hist_out
        self.next_time_mask[i].zero_()
        self.next_time_mask[i, :t+1] = 1.0

        self._advance()

    def sample(self, batch_shape: tuple = ()):
        """Sample uniform random minibatch of transitions."""
        idx = torch.randint(0, self.size, batch_shape, device=self.device)

        batch = {
            "hist_node":     self.hist_node[idx],
            "hist_val":      self.hist_val[idx],
            "hist_out":       self.hist_out[idx],
            "time_mask":      self.time_mask[idx],
            "interv_node":   self.interv_node[idx],
            "interv_val":    self.interv_val[idx],
            "reward":         self.reward[idx],
            "done":           self.done[idx],
            "next_hist_node": self.next_hist_node[idx],
            "next_hist_val":  self.next_hist_val[idx],
            "next_hist_out":   self.next_hist_out[idx],
            "next_time_mask":  self.next_time_mask[idx],
        }
        return batch

In [None]:
class SAC:
    def __init__(self, 
                 d: int, 
                 policy: nn.Module, 
                 critic1: nn.Module, 
                 critic2: nn.Module,
                 lr_actor: float = 1e-4, 
                 lr_critic: float = 3e-4,
                 lr_alpha: float = 3e-4,
                 gamma: float = 0.99, 
                 tau: float = 0.005,
                 alpha_d: float = 0.2,
                 alpha_c: float = 0.2,
                 tune_entropy: bool = False,
                 target_entropy_disc: float | None = None,
                 target_entropy_cont: float | None = None,
                 clip_norm: float | None = None,
                 device=torch.device('cpu')):
        
        self.d = d
        self.policy = policy
        self.critic1 = critic1
        self.critic2 = critic2
        self.gamma = gamma
        self.tau = tau
        self.clip_norm = clip_norm
        self.device = device

        self.target_critic1 = Critic(self.d, device=device)
        self.target_critic2 = Critic(self.d, device=device)
        self.target_critic1.load_state_dict(self.critic1.state_dict())
        self.target_critic2.load_state_dict(self.critic2.state_dict())

        self.policy_optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr_actor)
        self.critic1_optimizer = torch.optim.Adam(self.critic1.parameters(), lr=lr_critic)
        self.critic2_optimizer = torch.optim.Adam(self.critic2.parameters(), lr=lr_critic)

        self.tune_entropy = tune_entropy
        if self.tune_entropy:
            if target_entropy_disc is None:
                target_entropy_disc = 0.25*math.log(self.d)
            if target_entropy_cont is None:
                target_entropy_cont = -1.0

            self.target_entropy_disc = float(target_entropy_disc)
            self.target_entropy_cont = float(target_entropy_cont)

            self.log_alpha_d = nn.Parameter(torch.log(torch.tensor(alpha_d, device=device).clamp_min(1e-8)), requires_grad=True)
            self.log_alpha_c = nn.Parameter(torch.log(torch.tensor(alpha_c, device=device).clamp_min(1e-8)), requires_grad=True)
            self.alpha_d_optimizer = torch.optim.Adam([self.log_alpha_d], lr=lr_alpha)
            self.alpha_c_optimizer = torch.optim.Adam([self.log_alpha_c], lr=lr_alpha)

            self.alpha_d = None
            self.alpha_c = None
        else:
            self.alpha_d = torch.tensor(alpha_d, device=device)
            self.alpha_c = torch.tensor(alpha_c, device=device)
            self.log_alpha_d = None
            self.log_alpha_c = None
            self.alpha_d_optimizer = None
            self.alpha_c_optimizer = None

    def _grad_norm(self, params):
        grads = [p.grad.detach().reshape(-1) for p in params if p.grad is not None]
        if not grads:
            return 0.0
        return torch.cat(grads).norm(2).item()

    def update_critics(self, batch: dict):
        hist_node, hist_val = batch["hist_node"], batch["hist_val"]
        hist_out, time_mask = batch["hist_out"], batch["time_mask"]
        interv_node, interv_val = batch["interv_node"], batch["interv_val"]
        reward, done = batch["reward"], batch["done"]
        next_hist_node, next_hist_val = batch["next_hist_node"], batch["next_hist_val"]
        next_hist_out, next_time_mask = batch["next_hist_out"], batch["next_time_mask"]

        with torch.no_grad():
            next_enc = self.policy._encode(next_hist_node, 
                                           next_hist_val, 
                                           next_hist_out, 
                                           next_time_mask)
            next_logits = self.policy._disc_logits(next_enc)
            next_probs_disc = F.softmax(next_logits, dim=-1)
            next_log_probs_disc = F.log_softmax(next_logits, dim=-1)

            B = hist_node.size(0)   
            expected_q_next = torch.zeros(B, device=self.device)
            exp_logp_cont_next = torch.zeros(B, device=self.device)

            for idx in range(self.d):
                idx_batch = torch.full((B,), idx, dtype=torch.long, device=self.device)

                # --- Inner expectation over continuous actions ---
                next_interv_node, next_interv_val, logp_val_i = self.policy._sample_cont(next_enc, idx_batch)

                tq1 = self.target_critic1(next_hist_node, next_hist_val, next_hist_out,
                                          next_interv_node, next_interv_val, next_time_mask)
                tq2 = self.target_critic2(next_hist_node, next_hist_val, next_hist_out,
                                          next_interv_node, next_interv_val, next_time_mask)
                tq_min = torch.min(tq1, tq2)
                
                expected_q_next += next_probs_disc[:, idx] * tq_min
                exp_logp_cont_next += next_probs_disc[:, idx] * logp_val_i

            exp_logp_disc_next = (next_probs_disc * next_log_probs_disc).sum(dim=-1)

            if self.tune_entropy:
                alpha_d_t = self.log_alpha_d.exp()
                alpha_c_t = self.log_alpha_c.exp()
            else:
                alpha_d_t = self.alpha_d
                alpha_c_t = self.alpha_c

            v_next = expected_q_next - alpha_d_t * exp_logp_disc_next - alpha_c_t * exp_logp_cont_next
            target_q = reward + (1.0 - done.float()) * self.gamma * v_next

        cq1 = self.critic1(hist_node, hist_val, hist_out, interv_node, interv_val, time_mask)
        cq2 = self.critic2(hist_node, hist_val, hist_out, interv_node, interv_val, time_mask)

        loss1 = F.smooth_l1_loss(cq1, target_q)
        loss2 = F.smooth_l1_loss(cq2, target_q)

        # Clip gradients
        self.critic1_optimizer.zero_grad()
        loss1.backward()
        if self.clip_norm is not None:
            clip_grad_norm_(self.critic1.parameters(), self.clip_norm)
        c1_grad_norm = self._grad_norm(self.critic1.parameters())
        self.critic1_optimizer.step()

        self.critic2_optimizer.zero_grad()
        loss2.backward()
        if self.clip_norm is not None:
            clip_grad_norm_(self.critic2.parameters(), self.clip_norm)
        c2_grad_norm = self._grad_norm(self.critic2.parameters())
        self.critic2_optimizer.step()

        return loss1.item(), loss2.item(), c1_grad_norm, c2_grad_norm

    def update_policy(self, batch: dict):
        hist_node, hist_val = batch["hist_node"], batch["hist_val"]
        hist_out, time_mask = batch["hist_out"], batch["time_mask"]

        B = hist_node.shape[0]
        
        enc = self.policy._encode(hist_node, hist_val, hist_out, time_mask)
        logits = self.policy._disc_logits(enc)
        disc_probs = F.softmax(logits, dim=-1)
        log_probs_disc = F.log_softmax(logits, dim=-1)
        avg_logp_disc = torch.zeros(B, device=self.device)
        avg_logp_cont = torch.zeros(B, device=self.device)
        outer_obj = torch.zeros(B, device=self.device)

        if self.tune_entropy:
            alpha_d_t = self.log_alpha_d.exp()
            alpha_c_t = self.log_alpha_c.exp()
        else:
            alpha_d_t = self.alpha_d
            alpha_c_t = self.alpha_c

        for idx in range(self.d):
            idx_batch = torch.full((B,), idx, dtype=torch.long, device=self.device)

            # --- Inner expectation over continuous actions ---
            interv_node_i, interv_val_i, logp_val_i = self.policy._sample_cont(enc, idx_batch)
            q1 = self.critic1(hist_node, hist_val, hist_out, interv_node_i, interv_val_i, time_mask)
            q2 = self.critic2(hist_node, hist_val, hist_out, interv_node_i, interv_val_i, time_mask)
            q_min = torch.min(q1, q2)

            inner_obj = q_min - (alpha_c_t.detach() * logp_val_i)

            # --- Outer expectation over discrete actions ---
            outer_obj += disc_probs[:, idx] * (inner_obj - alpha_d_t.detach() * log_probs_disc[:, idx])

            # --- average log probs. for entropy tuning ---
            avg_logp_disc += disc_probs[:, idx] * log_probs_disc[:, idx]
            avg_logp_cont += disc_probs[:, idx] * logp_val_i

        policy_loss = -outer_obj.mean()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        if self.clip_norm is not None:
            clip_grad_norm_(self.policy.parameters(), self.clip_norm)
        policy_grad_norm = self._grad_norm(self.policy.parameters())
        self.policy_optimizer.step()
        
        if self.tune_entropy:
            alpha_d_loss = -(self.log_alpha_d * (avg_logp_disc.detach() + self.target_entropy_disc)).mean()
            alpha_c_loss = -(self.log_alpha_c * (avg_logp_cont.detach() + self.target_entropy_cont)).mean()

            self.alpha_d_optimizer.zero_grad()
            alpha_d_loss.backward()
            self.alpha_d_optimizer.step()

            self.alpha_c_optimizer.zero_grad()
            alpha_c_loss.backward()
            self.alpha_c_optimizer.step()

            alpha_d_value = float(self.log_alpha_d.exp().item())
            alpha_c_value = float(self.log_alpha_c.exp().item())
        else:
            alpha_d_value = float(self.alpha_d.item())
            alpha_c_value = float(self.alpha_c.item())
        
        avg_logp_disc = float(avg_logp_disc.mean().item())
        avg_logp_cont = float(avg_logp_cont.mean().item())
        return policy_loss.item(), avg_logp_disc, avg_logp_cont, alpha_d_value, alpha_c_value, policy_grad_norm

    def update_target_networks(self) -> None:
        with torch.no_grad():
            for tp, p in zip(self.target_critic1.parameters(), self.critic1.parameters()):
                tp.copy_(self.tau * p + (1 - self.tau) * tp)
            for tp, p in zip(self.target_critic2.parameters(), self.critic2.parameters()):
                tp.copy_(self.tau * p + (1 - self.tau) * tp)

    def train_step(self, batch: dict):
        c1, c2, c1_gn, c2_gn = self.update_critics(batch)
        pl, avg_logp_disc, avg_logp_cont, alpha_d_val, alpha_c_val, pol_gn = self.update_policy(batch)
        self.update_target_networks()
        
        return {
            "critic1_loss": c1,
            "critic2_loss": c2,
            "policy_loss": pl,
            "avg_log_prob_disc": avg_logp_disc,
            "avg_log_prob_cont": avg_logp_cont,
            "alpha_d": alpha_d_val,
            "alpha_c": alpha_c_val,
            "critic1_gn": c1_gn,
            "critic2_gn": c2_gn,
            "policy_gn": pol_gn,
        }

In [None]:
class AviciDeltaReward:
    """Telescoping reward built from AVICI's predicted edge probabilities.

    Every call scores the history passed in with the AVICI model and returns
    the change in that score since the previous call. The first call after
    construction or ``reset()`` only records the baseline score and returns 0.0.
    """

    def __init__(self, avici_model):
        self.avici_model = avici_model
        # Score from the previous call; None until a baseline has been recorded.
        self.prev_g = None

    def reset(self):
        """Forget the stored score so the next call records a fresh baseline."""
        self.prev_g = None

    def __call__(self,
                 A_true: Tensor,      # [d, d]
                 hist_nodes: Tensor,  # [t, d]
                 hist_out: Tensor     # [t, d]
                 ) -> float:
        """Return the score improvement since the previous call as a Python float.

        Args:
            A_true: Ground-truth adjacency matrix.
            hist_nodes: Intervention masks of the history so far.
            hist_out: Outcomes of the history so far.

        Returns:
            0.0 on the first call after a reset; otherwise the current score
            minus the previous one.
        """
        x_np = hist_out.detach().cpu().numpy()
        interv_np = hist_nodes.detach().cpu().numpy()
        A_true_np = A_true.detach().cpu().numpy()

        g_prob = self.avici_model(x=x_np, interv=interv_np)

        # Expected number of correct adjacency-matrix entries (CAASL-style reward):
        # for a binary A_true, each entry contributes g_prob where A_true is 1
        # and (1 - g_prob) where it is 0.
        g_t_mat = A_true_np * g_prob + (1.0 - A_true_np) * (1.0 - g_prob)
        # Cast to a built-in float so every call returns the same type as the
        # baseline call's 0.0 instead of leaking a NumPy scalar.
        g_t = float(g_t_mat.sum())

        # Telescoping: reward is the difference between consecutive scores.
        prev_g, self.prev_g = self.prev_g, g_t
        if prev_g is None:
            return 0.0
        return g_t - prev_g

### 5. SAC Training Loop

- `train()`:

    - For each episode: sample a random DAG and SCM, build an observational prefix of length num_obs, and initialize the AVICI reward with the observations.

    - For $t$ in $0,...,T-1$: sample an action from `RandomPolicy` during replay-buffer warmup, otherwise from `policy.sample_action`; then generate outcomes from the SCM, append them to the history, compute the reward, and store the transition $(h_t, (i_t, v_t), r_t, h_{t+1}, done_t)$ in the replay buffer.
    
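    - After each step: if replay-buffer warmup is complete and the buffer holds at least `batch_size` transitions, run `updates_per_step` SAC updates (critics, policy, and target networks, plus the entropy coefficients when `tune_entropy` is set) on minibatches sampled from the replay buffer.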



In [None]:
def train(
    d: int = 5, 
    p: float | None = 0.583,
    T: int = 8,
    batch_size: int = 128,
    num_episodes: int = 5000,
    num_obs: int = 20,
    print_every: int = 50,
    device: torch.device = torch.device("cpu"),
    avici_model=None, 
    tune_entropy: bool = True,
    target_entropy_disc: float = 0.5*math.log(5), 
    target_entropy_cont: float = -1.0,
    alpha_d: float = 0.5,
    alpha_c: float = 0.5,
    updates_per_step: int = 1,
    warmup_episodes: int = 500
):
    """Train a hybrid-action SAC agent to choose interventions, rewarded by AVICI.

    Each episode samples a DAG and SCM (`prior`, `LinearGaussianSCM`), draws
    `num_obs` observational samples, and then performs `T` intervention steps.
    Every step is stored in the replay buffer. Once warmup is over and the
    buffer holds at least `batch_size` transitions, `updates_per_step` calls to
    `sac.train_step` follow each environment step.

    Args:
        d: Number of nodes; forwarded to `prior`, the policies, critics and buffer.
        p: Second argument of `prior` (presumably the edge probability; confirm
            against `prior`).
        T: Number of intervention steps per episode.
        batch_size: Minibatch size sampled from the replay buffer.
        num_episodes: Number of training episodes; the replay buffer capacity
            is `T * num_episodes`.
        num_obs: Length of the observational prefix sampled at episode start.
        print_every: A progress row is printed for episode 1 and every
            `print_every`-th episode.
        device: Device for the networks, the SCM and the history buffers.
        avici_model: Model used by `AviciDeltaReward`; must not be None.
        tune_entropy: Forwarded to `SAC`.
        target_entropy_disc: Forwarded to `SAC`.
        target_entropy_cont: Forwarded to `SAC`.
        alpha_d: Forwarded to `SAC`.
        alpha_c: Forwarded to `SAC`.
        updates_per_step: Number of `sac.train_step` calls after each step.
        warmup_episodes: Episodes numbered below this value (episodes count
            from 1) act with `RandomPolicy` and trigger no updates.

    Returns:
        Tuple `(training_stats, sac, replay_buffer, policy)`. `training_stats`
        maps metric names to per-episode lists; metrics with no samples in an
        episode (e.g. losses during warmup) are recorded as 0.0. The `eval_*`
        lists are created but not filled by this function.
    """
    assert avici_model is not None, "avici_model must be provided"

    def _mean(values):
        # Average of a possibly empty list; an empty list yields 0.0.
        return sum(values) / max(1, len(values))

    # Init
    policy = TanhGaussianPolicy(d=d, device=device)
    rand_policy = RandomPolicy(d=d, device=device)
    critic1 = Critic(d=d, device=device)
    critic2 = Critic(d=d, device=device)
    replay_buffer = ReplayBuffer(d=d, T=T, num_obs=num_obs, cap=T * num_episodes, device=device)

    sac = SAC(
        d=d,
        policy=policy,
        critic1=critic1,
        critic2=critic2,
        lr_actor=1e-4,
        lr_critic=3e-4,
        lr_alpha=2e-4,      
        gamma=1.0,
        tau=0.005,
        alpha_d=alpha_d,
        alpha_c=alpha_c,
        tune_entropy=tune_entropy,
        target_entropy_disc=target_entropy_disc,
        target_entropy_cont=target_entropy_cont,              
        clip_norm=5.0,
        device=device,
    )
    training_stats = {
        "critic1_loss": [], "critic2_loss": [], "policy_loss": [],
        "eval_critic1_loss": [], "eval_critic2_loss": [], "eval_policy_loss": [],
        "critic1_gn": [], "critic2_gn": [], "policy_gn": [],
        "episode_reward": [], "avg_log_prob_disc": [], "avg_log_prob_cont": [],
        "final_alpha_d": [], "final_alpha_c": [],
    }

    print(
        "Episode   Train Reward   Policy Loss   Critic1 Loss   Critic2 Loss   "
        "LogP_disc   LogP_cont    Alpha_d   Alpha_c   Policy GN   Critic1 GN   Critic2 GN"
    )
    print("-" * 150)

    # Preallocate histories on `device`: [num_obs + T, d]. They are reused by every episode.
    hist_node = torch.zeros(num_obs + T, d, device=device)
    hist_val  = torch.zeros(num_obs + T, d, device=device)
    hist_out  = torch.zeros(num_obs + T, d, device=device)

    # AVICI reward (CPU)
    reward_fn = AviciDeltaReward(avici_model)

    for episode in range(1, num_episodes + 1):
        A, order = prior(d, p, device=device)
        scm = LinearGaussianSCM(A, device=device)

        # Observational prefix (no interventions)
        if num_obs > 0:
            obs_node = torch.zeros(num_obs, d, device=device)
            obs_val  = torch.zeros(num_obs, d, device=device)
            obs_out  = scm.rsample(A, order, interv_node=None, interv_val=None, batch_shape=(num_obs,))
        else:
            obs_node = torch.zeros(0, d, device=device)
            obs_val  = torch.zeros(0, d, device=device)
            obs_out  = torch.zeros(0, d, device=device)

        # Prime AVICI with obs: the first call after reset() only records the baseline score.
        reward_fn.reset()
        _ = reward_fn(A.cpu(), obs_node.cpu(), obs_out.cpu())

        # Reset histories and copy prefix
        hist_node.zero_()
        hist_val.zero_()
        hist_out.zero_()
        hist_node[:num_obs] = obs_node
        hist_val[:num_obs]  = obs_val
        hist_out[:num_obs]  = obs_out

        episode_losses = {"critic1": [], "critic2": [], "policy": []}
        episode_grad_norms = {"critic1": [], "critic2": [], "policy": []}
        episode_reward = 0.0
        episode_logp_disc = []
        episode_logp_cont = []
        episode_alpha_d = []
        episode_alpha_c = []

        for t in range(T):
            # Acting needs no gradients: learning happens on replayed minibatches in
            # sac.train_step. Sampling under no_grad keeps any autograd history the
            # policy might attach to its outputs out of the reused history buffers.
            with torch.no_grad():
                if episode < warmup_episodes:
                    interv_node, interv_val = rand_policy()
                    logp_idx = logp_val = None
                else:
                    interv_node, interv_val, logp_idx, logp_val, _ = policy.sample_action(
                        hist_node[: num_obs + t],
                        hist_val[: num_obs + t],
                        hist_out[: num_obs + t],
                    )
                # Environment step
                out = scm.rsample(A, order, interv_node=interv_node.bool(), interv_val=interv_val)

            # Append to histories
            hist_node[num_obs + t] = interv_node
            hist_val[num_obs + t]  = interv_val
            hist_out[num_obs + t]  = out

            # Calculate reward (CPU)
            r_t = reward_fn(A.cpu(), hist_node[: num_obs + t + 1].cpu(), hist_out[: num_obs + t + 1].cpu())
            # Accumulate as a built-in float so the logged episode reward is not a NumPy scalar.
            episode_reward += float(r_t)

            if logp_idx is not None:
                episode_logp_disc.append(float(logp_idx.item()))
            if logp_val is not None:
                episode_logp_cont.append(float(logp_val.item()))

            # Store transition
            replay_buffer.add_transition(
                hist_node[: num_obs + t],
                hist_val[: num_obs + t],
                hist_out[: num_obs + t],
                interv_node,
                interv_val,
                r_t,
                hist_node[: num_obs + t + 1],
                hist_val[: num_obs + t + 1],
                hist_out[: num_obs + t + 1],
                t == T - 1,
            )

            # Updates (skip until burn-in complete)
            if (episode >= warmup_episodes) and (replay_buffer.size >= batch_size):
                for _ in range(updates_per_step):
                    batch = replay_buffer.sample(batch_shape=(batch_size,))
                    stats = sac.train_step(batch)
                    episode_losses["critic1"].append(stats["critic1_loss"])
                    episode_losses["critic2"].append(stats["critic2_loss"])
                    episode_losses["policy"].append(stats["policy_loss"])
                    episode_grad_norms["critic1"].append(stats["critic1_gn"])
                    episode_grad_norms["critic2"].append(stats["critic2_gn"])
                    episode_grad_norms["policy"].append(stats["policy_gn"])
                    episode_alpha_d.append(stats["alpha_d"])
                    episode_alpha_c.append(stats["alpha_c"])

        # Aggregate stats
        training_stats["episode_reward"].append(episode_reward)
        training_stats["avg_log_prob_disc"].append(_mean(episode_logp_disc))
        training_stats["avg_log_prob_cont"].append(_mean(episode_logp_cont))
        training_stats["critic1_loss"].append(_mean(episode_losses["critic1"]))
        training_stats["critic2_loss"].append(_mean(episode_losses["critic2"]))
        training_stats["policy_loss"].append(_mean(episode_losses["policy"]))
        training_stats["critic1_gn"].append(_mean(episode_grad_norms["critic1"]))
        training_stats["critic2_gn"].append(_mean(episode_grad_norms["critic2"]))
        training_stats["policy_gn"].append(_mean(episode_grad_norms["policy"]))

        # Entropy coefficients as they stand at the end of the episode
        if sac.tune_entropy and getattr(sac, "log_alpha_d", None) is not None:
            training_stats["final_alpha_d"].append(float(sac.log_alpha_d.exp().item()))
            training_stats["final_alpha_c"].append(float(sac.log_alpha_c.exp().item()))
        else:
            training_stats["final_alpha_d"].append(float(sac.alpha_d.item()))
            training_stats["final_alpha_c"].append(float(sac.alpha_c.item()))

        # Print row
        if episode % print_every == 0 or episode == 1:
            print(
                f"{episode:7d}  {training_stats['episode_reward'][-1]:12.4f}  "
                f"{training_stats['policy_loss'][-1]:11.4f}  {training_stats['critic1_loss'][-1]:12.4f}  "
                f"{training_stats['critic2_loss'][-1]:12.4f}      {training_stats['avg_log_prob_disc'][-1]:9.4f}   "
                f"{training_stats['avg_log_prob_cont'][-1]:9.4f}    "
                f"{training_stats['final_alpha_d'][-1]:7.4f}   {training_stats['final_alpha_c'][-1]:7.4f}  "
                f"{training_stats['policy_gn'][-1]:9.4f}  {training_stats['critic1_gn'][-1]:11.4f} {training_stats['critic2_gn'][-1]:11.4f}"
            )

    return training_stats, sac, replay_buffer, policy

### 6. Example Training Run

In [None]:
# Train on the `device` selected in the import cell (CUDA when available, otherwise
# CPU) instead of a hard-coded CUDA device, so this cell also runs without a GPU.
training_stats, sac, replay_buffer, trained_policy = train(
    d=5,
    T=8,
    p=0.583,
    batch_size=128,
    num_episodes=5000,    
    num_obs=20,
    print_every=25,
    device=device,
    avici_model=avici_model,
    tune_entropy=True,
    target_entropy_disc=0.25*math.log(5),  
    target_entropy_cont=-1.0,             
    alpha_d=0.5,
    alpha_c=0.5,  
    updates_per_step=1,
    warmup_episodes=500)

In [None]:
# --- Plot training statistics on a 3x3 grid ---

fig, axes = plt.subplots(3, 3, figsize=(18, 16))
fig.suptitle('SAC Training Metrics', fontsize=16)
x = np.arange(1, len(training_stats['episode_reward']) + 1)

# One entry per panel: (row, col, title, y-label, series), where each series is
# (training_stats key, line format or color, legend label or None).
panel_specs = [
    # Row 0: reward, log-probabilities, entropy coefficients
    (0, 0, 'Training Episode Reward', 'Reward', [('episode_reward', 'b-', None)]),
    (0, 1, 'Average Log Probability', 'Log Prob', [('avg_log_prob_disc', 'purple', 'disc'),
                                                   ('avg_log_prob_cont', 'teal', 'cont')]),
    (0, 2, 'Entropy Coefficients', 'Value', [('final_alpha_d', 'm-', 'alpha_d'),
                                             ('final_alpha_c', 'c-', 'alpha_c')]),
    # Row 1: losses
    (1, 0, 'Policy Loss', 'Loss', [('policy_loss', 'r-', None)]),
    (1, 1, 'Critic 1 Loss', 'Loss', [('critic1_loss', 'r-', None)]),
    (1, 2, 'Critic 2 Loss', 'Loss', [('critic2_loss', 'r-', None)]),
    # Row 2: grad norms
    (2, 0, 'Policy Grad Norm', 'L2 Norm', [('policy_gn', 'k-', None)]),
    (2, 1, 'Critic 1 Grad Norm', 'L2 Norm', [('critic1_gn', 'k-', None)]),
    (2, 2, 'Critic 2 Grad Norm', 'L2 Norm', [('critic2_gn', 'k-', None)]),
]

for panel_row, panel_col, panel_title, panel_ylabel, panel_series in panel_specs:
    panel_ax = axes[panel_row, panel_col]
    panel_has_legend = False
    for stat_key, line_fmt, legend_label in panel_series:
        if legend_label is None:
            panel_ax.plot(x, training_stats[stat_key], line_fmt, alpha=0.7)
        else:
            panel_ax.plot(x, training_stats[stat_key], line_fmt, alpha=0.7, label=legend_label)
            panel_has_legend = True
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('Episodes')
    panel_ax.set_ylabel(panel_ylabel)
    panel_ax.grid(True, alpha=0.3)
    if panel_has_legend:
        panel_ax.legend()

plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()


### 7. Visualize Baseline Policy Performance

In [None]:
def plot_interventions(d: int,
                       T: int,
                       A: Tensor,
                       interv_nodes: Tensor,
                       interv_values: Tensor,
                       scm,
                       policy_title: str = "Trained") -> None:
    """Plot an intervention trajectory as a grid of DAG panels and show it.

    The figure has ``T + 1`` panels laid out in at most 4 columns. The first
    panel draws the DAG with edge weights from ``scm.W_base`` and, in red, the
    noise scales from ``scm.sigma``. Each following panel shows one time step:
    the intervened node (if any) is colored by its intervention value clipped
    to [-1, 1], and the edges pointing into it are drawn dashed.

    Args:
        d: Number of nodes.
        T: Number of time steps in the trajectory.
        A: Adjacency matrix; its nonzero entries define the labelled edges.
        interv_nodes: Intervention masks, asserted to have shape ``(T, d)``.
        interv_values: Intervention values, asserted to have shape ``(T, d)``.
        scm: Object providing ``W_base`` (indexed ``[i, j]``) and ``sigma``
            (indexed ``[i]``).
        policy_title: Policy name used in the figure title.

    Raises:
        AssertionError: If ``interv_nodes`` or ``interv_values`` is not ``(T, d)``.
        ValueError: If the graph's nodes are labelled neither 0..d-1 nor 1..d.
    """
    G = adj_to_digraph(A)

    # Grid layout: one panel for the DAG plus one per time step, at most 4 columns.
    total_panels = T + 1
    max_cols = 4
    ncols = min(total_panels, max_cols)
    nrows = math.ceil(total_panels / ncols)

    fig_width  = 4 * ncols
    fig_height = 4 * nrows
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(fig_width, fig_height),
        squeeze=False, constrained_layout=True, dpi=300
    )
    axes = axes.flatten()

    interv_nodes_cpu  = interv_nodes.detach().cpu()
    interv_values_cpu = interv_values.detach().cpu()

    assert interv_nodes_cpu.shape == (T, d), f"interv_nodes shape {tuple(interv_nodes_cpu.shape)} != (T, d)"
    assert interv_values_cpu.shape == (T, d), f"interv_values shape {tuple(interv_values_cpu.shape)} != (T, d)"

    # Map adjacency indices to graph node ids; displayed labels are 1-based in both cases.
    nodes_G = set(G.nodes())
    if nodes_G == set(range(d)):
        idx_to_node = list(range(d))
        node_labels = {i: str(i + 1) for i in range(d)}
    elif nodes_G == set(range(1, d + 1)):
        idx_to_node = list(range(1, d + 1))
        node_labels = {i: str(i) for i in range(1, d + 1)}
    else:
        raise ValueError("Graph nodes must be indexed as 0..d-1 or 1..d.")

    nodelist = [idx_to_node[i] for i in range(d)]
    neutral_gray = "#d9d9d9"

    # Colormap: reversed "managua", lightened toward white.
    # NOTE(review): "managua" is only available in Matplotlib versions that ship it — confirm the minimum version.
    base_cmap = mpl.colormaps["managua"].reversed()
    def lighten(c, factor=0.3):
        # Blend each RGB channel toward 1.0 by `factor`; alpha is left unchanged.
        r, g, b, a = c
        return (r + (1 - r) * factor, g + (1 - g) * factor, b + (1 - b) * factor, a)
    colors_list = [lighten(base_cmap(i), factor=0.35) for i in np.linspace(0, 1, 256)]
    cmap = colors.LinearSegmentedColormap.from_list("light_managua_rev", colors_list)
    norm = Normalize(vmin=-1.0, vmax=1.0)

    node_size = 600
    # One layout shared by every panel so nodes keep their positions across time steps.
    pos = nx.kamada_kawai_layout(G) 

    # Edge indices from A
    rows, cols = torch.nonzero(A.detach().cpu(), as_tuple=True)
    edge_indices = list(zip(rows.tolist(), cols.tolist()))

    # First panel: DAG structure with edge weights and noise scales
    ax = axes[0]
    ax.set_facecolor("white")
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(0.6)
        spine.set_edgecolor("#cccccc")
    ax.margins(0.08)

    # Draw nodes in neutral gray
    node_colors = [neutral_gray] * d
    nx.draw_networkx_nodes(
        G, pos, ax=ax, nodelist=nodelist,
        node_color=node_colors, node_size=node_size,
        edgecolors="black", linewidths=0.5)
    nx.draw_networkx_labels(G, pos, ax=ax, labels=node_labels, font_size=14)

    # Draw all edges as solid
    nx.draw_networkx_edges(
        G, pos, ax=ax, edgelist=list(G.edges()),
        arrows=True, arrowstyle="-|>", arrowsize=17,
        width=1, min_target_margin=12.5, connectionstyle="arc3,rad=0.05")

    # Draw edge weights (two decimals), read from scm.W_base at each nonzero entry of A
    edge_labels = {(idx_to_node[i], idx_to_node[j]): f"{float(scm.W_base[i, j]):.2f}"
                    for (i, j) in edge_indices}
    nx.draw_networkx_edge_labels(
        G, pos, ax=ax, edge_labels=edge_labels,
        font_size=11, font_color="black", label_pos=0.35, alpha=1.0)

    # Add noise scales in red slightly above each node
    for i, node_id in enumerate(nodelist):
        x, y = pos[node_id]
        sigma_val = float(scm.sigma[i])
        ax.text(x, y + 0.2, f"{sigma_val:.2f}",
                fontsize=11, color="red",
                ha='center', va='center')

    ax.set_title("DAG Structure", fontsize=18)

    nodeid_to_idx = {nodelist[i]: i for i in range(d)}
    base_colors = [neutral_gray] * d

    # Remaining panels: one per time step
    for t in range(T):
        ax = axes[t + 1]
        row = interv_nodes_cpu[t]
        # A step counts as an intervention if any mask entry is positive.
        has_intervention = bool(torch.any(row > 0))

        if not has_intervention:
            target_node = None
            colors_t = base_colors
        else:
            # The intervened node is the argmax of the mask row.
            intervened_idx = int(torch.argmax(row).item())
            target_node = idx_to_node[intervened_idx]
            # Copy so the shared gray base list is not modified.
            colors_t = base_colors[:]
            # Clip to the colormap's normalization range [-1, 1].
            val = float(np.clip(float(interv_values_cpu[t, intervened_idx]), -1.0, 1.0))
            k = nodeid_to_idx[target_node]
            colors_t[k] = cmap(norm(val))

        # Draw nodes and labels
        nx.draw_networkx_nodes(
            G, pos, ax=ax, nodelist=nodelist,
            node_color=colors_t, 
            node_size=node_size, edgecolors="black", linewidths=0.5
        )
        nx.draw_networkx_labels(G, pos, ax=ax, labels=node_labels, font_size=14)

        # Edges: those pointing into the intervened node are dashed, all others solid
        if target_node is None:
            edges_dashed = []
            edges_solid  = list(G.edges())
        else:
            edges_dashed = [(u, v) for (u, v) in G.edges() if v == target_node]
            edges_solid  = [(u, v) for (u, v) in G.edges() if v != target_node]

        if edges_solid:
            nx.draw_networkx_edges(
                G, pos, ax=ax, edgelist=edges_solid,
                arrows=True, arrowstyle="-|>", arrowsize=17,
                width=1, min_target_margin=12.5, connectionstyle="arc3,rad=0.05",
            )
        if edges_dashed:
            nx.draw_networkx_edges(
                G, pos, ax=ax, edgelist=edges_dashed,
                arrows=True, arrowstyle="-|>", arrowsize=17,
                width=1.2, style=(0, (3, 3)), min_target_margin=12.5, connectionstyle="arc3,rad=0.05"
            )

        ax.set_title(rf"$t={t+1}$", fontsize=18)

    # Hide grid cells left over after the last time-step panel.
    for ax in axes[total_panels:]:
        ax.set_visible(False)

    # Shared horizontal colorbar for the intervention values
    sm = mpl.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    cbar = fig.colorbar(
        sm,
        ax=axes[:total_panels],
        location="bottom",
        orientation="horizontal",
        anchor=(1.0, 0.0),
        shrink=0.25,
        pad=0.03,
        aspect=20,
    )
    cbar.ax.tick_params(labelsize=14)
    cbar.set_label("Intervention value", fontsize=18, labelpad=7)

    title_str = rf"{policy_title} Policy Trajectory ($T={T}$)"
    fig.suptitle(title_str, fontsize=22, y=1.05)

    plt.show()

In [None]:
def plot_adj_heatmap(pred_list, true, titles=None, cmap="viridis", dpi=300):
    """Plot predicted edge-probability matrices next to the true adjacency matrix.

    Each entry of `pred_list` gets its own heatmap panel, left to right; the
    final panel shows `true`. All panels share one color scale from 0 to the
    largest value found in any matrix, and one colorbar on the right.

    Args:
        pred_list: Predicted matrices (torch tensors or array-likes).
        true: Ground-truth adjacency matrix (torch tensor or array-like).
        titles: Optional panel titles for the predictions; defaults to
            "Pred 1", "Pred 2", ... The last panel is always titled "True".
        cmap: Matplotlib colormap name.
        dpi: Figure resolution.

    Returns:
        Tuple `(fig, axes)` where `axes` is a 1-D array with one Axes per panel.
    """

    def _to_numpy(mat):
        # Tensors are detached and moved to host memory; anything else goes through np.asarray.
        return mat.detach().cpu().numpy() if isinstance(mat, torch.Tensor) else np.asarray(mat)

    def _draw_panel(ax, data, title, show_yticks):
        # Draw one heatmap with 1-based tick labels and thin black spines.
        ax.imshow(data, cmap=cmap, vmin=vmin, vmax=vmax,
                  interpolation="nearest", aspect="equal")
        ax.set_title(title, fontsize=9)

        ax.set_xlim(-0.5, n_nodes - 0.5)
        ax.set_ylim(n_nodes - 0.5, -0.5)  # row 0 at the top

        if show_yticks:
            ax.set_yticks(tick_pos)
            ax.set_yticklabels(tick_lbl, fontsize=8)
        else:
            ax.set_yticks([])

        ax.set_xticks(tick_pos)
        ax.set_xticklabels(tick_lbl, fontsize=8)

        for spine in ax.spines.values():
            spine.set_linewidth(0.6)
            spine.set_edgecolor("black")

    # Convert inputs to numpy
    preds = [_to_numpy(pred) for pred in pred_list]
    true_np = _to_numpy(true)

    # Take the matrix size from the ground truth itself rather than from a
    # notebook-level variable, so the function works for any graph size.
    n_nodes = true_np.shape[0]

    # Shared color scale across all panels
    vmax = float(max([pred.max() for pred in preds] + [true_np.max()]))
    vmin = 0.0

    if titles is None:
        titles = [f"Pred {i+1}" for i in range(len(preds))]
    titles = list(titles) + ["True"]

    num_panels = len(preds) + 1
    cell_size = 0.35
    fig_w = num_panels * n_nodes * cell_size + 1.2
    fig_h = n_nodes * cell_size + 0.8

    # squeeze=False always yields an array of Axes, even for a single panel.
    fig, axes = plt.subplots(
        1, num_panels,
        figsize=(fig_w, fig_h), dpi=dpi,
        constrained_layout=True, squeeze=False)
    axes = axes.ravel()

    tick_pos = np.arange(n_nodes)
    tick_lbl = [str(i+1) for i in range(n_nodes)]

    # Predictions: only the left-most panel carries y tick labels.
    for j, (ax, data, title) in enumerate(zip(axes[:-1], preds, titles[:-1])):
        _draw_panel(ax, data, title, show_yticks=(j == 0))

    # Plot truth in the last panel; it carries y ticks only when it is the sole panel.
    _draw_panel(axes[-1], true_np, titles[-1], show_yticks=(num_panels == 1))

    # Shared colorbar on the right
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=vmin, vmax=vmax))
    sm.set_array([])
    cbar = fig.colorbar(
        sm,
        ax=axes.ravel().tolist(),
        location="right",
        orientation="vertical",
        shrink=0.7,
        pad=0.02,
        aspect=15
    )
    cbar.set_label("Edge Probability", fontsize=9, labelpad=5)
    cbar.ax.tick_params(labelsize=8)

    fig.text(0.5, 0.04, "Child Node", ha="center", va="top", fontsize=9)
    fig.text(-0.005, 0.5, "Parent Node", ha="right", va="center", fontsize=9, rotation="vertical")

    fig.suptitle("Final Edge Predictions", fontsize=10, y=1.0)

    return fig, axes


In [None]:
def run_policy(policy,
                 scm,          
                 A: Tensor,          
                 order: Tensor,  
                 d: int,
                 T: int,
                 device: torch.device,
                 obs_nodes: Tensor | None = None,  
                 obs_vals:  Tensor | None = None,  
                 obs_out:   Tensor | None = None,
                 action_mode: str = "sample"):  # "sample" or "mean"
      """For a given DAG, SCM, policy, and observational prefix, roll out a trajectory 
      of interventions and outcomes under the policy"""

      interv_nodes = torch.zeros(T, d, device=device)
      interv_vals  = torch.zeros(T, d, device=device)
      interv_out   = torch.zeros(T, d, device=device)

      was_training = getattr(policy, "training", None)
      if hasattr(policy, "eval"): policy.eval()

      for t in range(T):
          if obs_out is not None:
              hist_nodes = torch.cat([obs_nodes, interv_nodes[:t]], dim=0)
              hist_vals  = torch.cat([obs_vals,  interv_vals[:t]],  dim=0)
              hist_out   = torch.cat([obs_out,   interv_out[:t]],   dim=0)
          else:
              hist_nodes = interv_nodes[:t]
              hist_vals  = interv_vals[:t]
              hist_out   = interv_out[:t]

          hist_nodes = hist_nodes.to(device)
          hist_vals = hist_vals.to(device)
          hist_out = hist_out.to(device)

          if action_mode == "mean" and hasattr(policy, "mean_action"):
              interv_node_t, interv_val_t, _, _, _ = policy.mean_action(hist_nodes, hist_vals, hist_out)
          else:
              # default to sample
              if hasattr(policy, "sample_action"):
                  interv_node_t, interv_val_t, _, _, _ = policy.sample_action(hist_nodes, hist_vals, hist_out)
              else:
                  interv_node_t, interv_val_t = policy()

          out_t = scm.rsample(A, order, interv_node=interv_node_t, interv_val=interv_val_t)

          interv_nodes[t] = interv_node_t
          interv_vals[t]  = interv_val_t
          interv_out[t]   = out_t

      # restore training flag
      if was_training is not None and hasattr(policy, "train"):
          policy.train(was_training)

      return interv_nodes, interv_vals, interv_out


def _avici_metrics(interv_nodes, interv_outs, A_np, avici_model):
    """Score one trajectory with the AVICI model and AVICI's own metrics.

    Args:
        interv_nodes: Intervention masks of the trajectory (tensor or array-like).
        interv_outs: Outcomes of the trajectory (tensor or array-like).
        A_np: Ground-truth adjacency matrix as a NumPy array.
        avici_model: Callable invoked as `avici_model(x=..., interv=...)` that
            returns edge probabilities.

    Returns:
        Tuple `(g_prob, shd, f1, auroc)`: the predicted edge probabilities, the
        structural Hamming distance and F1 of the 0.5-thresholded prediction,
        and the AUROC of the probabilities.
    """

    def _as_numpy(x):
        # Tensors are detached and moved to host memory first; non-tensor
        # inputs (e.g. an existing ndarray) are converted with np.asarray.
        return x.detach().cpu().numpy() if torch.is_tensor(x) else np.asarray(x)

    nodes_np = _as_numpy(interv_nodes)
    outs_np  = _as_numpy(interv_outs)

    g_prob   = avici_model(x=outs_np, interv=nodes_np)
    # Hard graph prediction: keep edges with probability above 0.5.
    pred_adj = (g_prob > 0.5).astype(int)

    shd_val   = shd(A_np, pred_adj)
    f1_val    = classification_metrics(A_np, pred_adj)['f1']
    auroc_val = threshold_metrics(A_np, g_prob)['auroc']
    return g_prob, shd_val, f1_val, auroc_val

def eval_policy(policy,
                avici_model,  
                d: int,
                T: int,
                device = torch.device("cpu"),
                p: float | None = 0.583,
                num_obs: int = 50,
                visualize_interv: bool = True):
    """Compare `policy` with random and observation-only baselines on one sampled SCM.

    A DAG and SCM are sampled, an observational prefix of `num_obs` samples is
    drawn, and each of the three policies is rolled out for `T` steps with that
    prefix as initial history. Each rollout (without the prefix) is scored with
    AVICI; the metrics are printed in the order trained, random, observation,
    the edge-probability heatmaps are plotted, and, if `visualize_interv` is
    set, so are the three trajectories.

    Returns:
        Dict with the sampled "A" and "order", the prefix under "obs", and one
        sub-dict per policy ("trained", "random", "obs_only") holding its
        rollout tensors and metrics.
    """
    policy = policy.to(device)

    # Sample the graph and its mechanisms
    A, order = prior(d, p, device=device)
    scm = LinearGaussianSCM(A, device=device)

    # Observational prefix: all-zero masks/values plus SCM samples (empty when num_obs <= 0)
    prefix_len = num_obs if num_obs > 0 else 0
    obs_nodes = torch.zeros(prefix_len, d, device=device)
    obs_vals = torch.zeros(prefix_len, d, device=device)
    if prefix_len:
        obs_out = scm.rsample(A, order, interv_node=None, interv_val=None, batch_shape=(num_obs,))
    else:
        obs_out = torch.zeros(0, d, device=device)

    # Baseline policies
    rand_policy = RandomPolicy(d=d, device=device).to(device)
    obs_policy = ObservationPolicy(d=d, device=device).to(device)

    A_np = A.detach().cpu().numpy()

    def rollout(actor):
        # Roll `actor` out on the shared SCM, starting from the shared prefix.
        return run_policy(actor, scm=scm, A=A, order=order, d=d, T=T, device=device,
                          obs_nodes=obs_nodes, obs_vals=obs_vals, obs_out=obs_out)

    def score(nodes, outs):
        # AVICI metrics of one rollout against the true graph.
        return _avici_metrics(nodes, outs, A_np, avici_model)

    def summary(nodes, vals, outs, g_prob, shd_val, f1_val, auroc_val):
        # Per-policy entry of the returned dict.
        return {"nodes": nodes, "vals": vals, "outs": outs,
                "g_prob": g_prob, "shd": shd_val, "f1": f1_val, "auroc": auroc_val}

    # Rollouts: trained, random, observation-only
    trained_nodes, trained_vals, trained_outs = rollout(policy)
    random_nodes, random_vals, random_outs = rollout(rand_policy)
    obs_only_nodes, obs_only_vals, obs_only_outs = rollout(obs_policy)

    # AVICI metrics, in the same order
    g_prob_tr, shd_tr, f1_tr, auroc_tr = score(trained_nodes, trained_outs)
    g_prob_rand, shd_rand, f1_rand, auroc_rand = score(random_nodes, random_outs)
    g_prob_obs, shd_obs, f1_obs, auroc_obs = score(obs_only_nodes, obs_only_outs)

    print(f"  SHD: {shd_tr},  F1: {f1_tr:.4f},  AUROC: {auroc_tr:.4f},  Interventions: {int(trained_nodes.sum().item())}")
    print(f"  SHD: {shd_rand},  F1: {f1_rand:.4f},  AUROC: {auroc_rand:.4f},  Interventions: {int(random_nodes.sum().item())}")
    print(f"  SHD: {shd_obs},  F1: {f1_obs:.4f},  AUROC: {auroc_obs:.4f},  Interventions: {int(obs_only_nodes.sum().item())}")

    plot_adj_heatmap([g_prob_tr, g_prob_rand, g_prob_obs], true=A_np, titles=["Trained", "Random", "Observation"])

    # Optional trajectory plots, one figure per policy
    if visualize_interv:
        trajectories = (
            (trained_nodes, trained_vals, "Trained"),
            (random_nodes, random_vals, "Random"),
            (obs_only_nodes, obs_only_vals, "Observation"),
        )
        for traj_nodes, traj_vals, label in trajectories:
            plot_interventions(d, T, A, traj_nodes, traj_vals, scm, policy_title=label)

    return {
        "A": A,
        "order": order,
        "obs": (obs_nodes, obs_vals, obs_out),
        "trained": summary(trained_nodes, trained_vals, trained_outs,
                           g_prob_tr, shd_tr, f1_tr, auroc_tr),
        "random": summary(random_nodes, random_vals, random_outs,
                          g_prob_rand, shd_rand, f1_rand, auroc_rand),
        "obs_only": summary(obs_only_nodes, obs_only_vals, obs_only_outs,
                            g_prob_obs, shd_obs, f1_obs, auroc_obs),
    }

In [None]:
# Evaluate the trained policy against the random and observation-only baselines on one
# freshly sampled DAG/SCM. `device` is not passed, so eval_policy falls back to its
# default (CPU) and moves the policy there.
# NOTE(review): T=7 here differs from the T=8 used for training — confirm this is intended.
eval_dict = eval_policy(trained_policy, avici_model, d=5, T=7, num_obs=20, visualize_interv=True)

Next, we compare the sample efficiency of the trained, random, and observational policies.

In [None]:
# Helper functions for metrics
def compute_logp(A_np, g_prob, eps=1e-8):
    """Log-probability of the true adjacency matrix under the predicted edge probabilities.

    Probabilities are clipped to [eps, 1 - eps] before taking logs; each entry
    contributes A * log(p) + (1 - A) * log(1 - p), and the entries are summed.
    """
    probs = np.clip(g_prob, eps, 1.0 - eps)
    edge_term = A_np * np.log(probs)
    no_edge_term = (1 - A_np) * np.log(1 - probs)
    return (edge_term + no_edge_term).sum()

def compute_caasl(A_np, g_prob):
    """Expected number of correct adjacency entries.

    Each entry contributes A * p + (1 - A) * (1 - p); the entries are summed.
    """
    hit_on_edges = A_np * g_prob
    hit_on_non_edges = (1 - A_np) * (1 - g_prob)
    return (hit_on_edges + hit_on_non_edges).sum()

In [None]:
# Monte Carlo comparison of sample efficiency: for each of n_mc sampled SCMs, roll out
# the observation-only, random and trained policies for N steps and score every
# trajectory prefix of length n = 1..N with AVICI.

# params
d = 5          # number of nodes
p = 0.583      # second argument of prior()
N = 30         # rollout horizon (steps per policy)
n_mc = 200     # number of Monte Carlo repetitions (sampled SCMs)
eps = 1e-8     # clipping constant passed to compute_logp
num_obs = 20   # observational prefix length given to the trained policy

obs_policy = ObservationPolicy(d=d, device=device)
rand_policy = RandomPolicy(d=d, device=device)
trained_policy = trained_policy.to(device)

# Initialize storage: one row per Monte Carlo repetition, one column per prefix length n
shd_obs_all = np.zeros((n_mc, N))
shd_rand_all = np.zeros((n_mc, N))
shd_trained_all = np.zeros((n_mc, N))
auc_obs_all = np.zeros((n_mc, N))
auc_rand_all = np.zeros((n_mc, N))
auc_trained_all = np.zeros((n_mc, N))
logp_obs_all = np.zeros((n_mc, N))
logp_rand_all = np.zeros((n_mc, N))
logp_trained_all = np.zeros((n_mc, N))
caasl_obs_all = np.zeros((n_mc, N))
caasl_rand_all = np.zeros((n_mc, N))
caasl_trained_all = np.zeros((n_mc, N))

for mc in range(n_mc):
    # Sample DAG/mechanism
    A, order = prior(d=d, p=p, device=device)
    A_np = A.detach().cpu().numpy().astype(int)
    scm = LinearGaussianSCM(A, device=device)

    # Observational prefix (all-zero masks/values); only the trained policy receives it
    obs_data_nodes = torch.zeros(num_obs, d, device=device)
    obs_data_vals = torch.zeros(num_obs, d, device=device)
    obs_data_outs = scm.rsample(A, order, batch_shape=(num_obs,))

    # Run policies for full horizon N; the two baselines start from an empty history
    obs_nodes, obs_vals, obs_outs = run_policy(obs_policy, scm, A, order, d, N, device)
    rand_nodes, rand_vals, rand_outs = run_policy(rand_policy, scm, A, order, d, N, device)
    trained_nodes, trained_vals, trained_outs = run_policy(trained_policy, scm, A, order, d, N, device,
                                                           obs_nodes=obs_data_nodes, obs_vals=obs_data_vals, obs_out=obs_data_outs)
    # Evaluate over n = 1,...,N: AVICI sees only the first n rollout samples of each policy
    for n in range(1, N + 1):
        # Observational policy (no prefix)
        outs_np = obs_outs[:n].detach().cpu().numpy()
        nodes_np = obs_nodes[:n].detach().cpu().numpy()

        g_prob = avici_model(x=outs_np, interv=nodes_np)
        pred_adj = (g_prob > 0.5).astype(int)  # hard prediction at threshold 0.5

        shd_obs_all[mc, n-1] = shd(A_np, pred_adj)
        auc_obs_all[mc, n-1] = threshold_metrics(A_np, g_prob)['auroc']
        logp_obs_all[mc, n-1] = compute_logp(A_np, g_prob, eps)
        caasl_obs_all[mc, n-1] = compute_caasl(A_np, g_prob)

        # Random policy (no prefix)
        outs_np_r = rand_outs[:n].detach().cpu().numpy()
        nodes_np_r = rand_nodes[:n].detach().cpu().numpy()

        g_prob_r = avici_model(x=outs_np_r, interv=nodes_np_r)
        pred_adj_r = (g_prob_r > 0.5).astype(int)

        shd_rand_all[mc, n-1] = shd(A_np, pred_adj_r)
        auc_rand_all[mc, n-1] = threshold_metrics(A_np, g_prob_r)['auroc']
        logp_rand_all[mc, n-1] = compute_logp(A_np, g_prob_r, eps)
        caasl_rand_all[mc, n-1] = compute_caasl(A_np, g_prob_r)

        # Trained policy (exclude observational prefix from evaluation)
        outs_np_t = trained_outs[:n].detach().cpu().numpy()
        nodes_np_t = trained_nodes[:n].detach().cpu().numpy()

        g_prob_t = avici_model(x=outs_np_t, interv=nodes_np_t)
        pred_adj_t = (g_prob_t > 0.5).astype(int)

        shd_trained_all[mc, n-1] = shd(A_np, pred_adj_t)
        auc_trained_all[mc, n-1] = threshold_metrics(A_np, g_prob_t)['auroc']
        logp_trained_all[mc, n-1] = compute_logp(A_np, g_prob_t, eps)
        caasl_trained_all[mc, n-1] = compute_caasl(A_np, g_prob_t)

# --- mean + SE across MC ---
def mean_se(arr):
    '''Return the mean and standard error of `arr` taken along axis 0.

    The standard error is the (population, ddof=0) standard deviation along axis 0
    divided by the square root of the number of rows; the row count is floored at 1
    so the divisor is never zero.
    '''
    n_rows = max(1, arr.shape[0])
    center = arr.mean(axis=0)
    spread = arr.std(axis=0)
    return center, spread / np.sqrt(n_rows)

# Summarise every metric across Monte Carlo runs (one value per sample size n).
# Each line unpacks the (mean, se) pairs in (observational, random, trained) order.
(mean_shd_obs, se_shd_obs), (mean_shd_rand, se_shd_rand), (mean_shd_trained, se_shd_trained) = [
    mean_se(runs) for runs in (shd_obs_all, shd_rand_all, shd_trained_all)]
(mean_auc_obs, se_auc_obs), (mean_auc_rand, se_auc_rand), (mean_auc_trained, se_auc_trained) = [
    mean_se(runs) for runs in (auc_obs_all, auc_rand_all, auc_trained_all)]
(mean_logp_obs, se_logp_obs), (mean_logp_rand, se_logp_rand), (mean_logp_trained, se_logp_trained) = [
    mean_se(runs) for runs in (logp_obs_all, logp_rand_all, logp_trained_all)]
(mean_caasl_obs, se_caasl_obs), (mean_caasl_rand, se_caasl_rand), (mean_caasl_trained, se_caasl_trained) = [
    mean_se(runs) for runs in (caasl_obs_all, caasl_rand_all, caasl_trained_all)]

In [None]:
# --- Plot results ---

# x-axis: number of samples n = 1, ..., N used for each evaluation
ns = np.arange(1, N + 1)

# One row per panel: (title, obs mean, obs se, random mean, random se, trained mean, trained se)
metrics = [
    ("SHD", mean_shd_obs, se_shd_obs, mean_shd_rand, se_shd_rand, mean_shd_trained, se_shd_trained),
    ("AUROC", mean_auc_obs, se_auc_obs, mean_auc_rand, se_auc_rand, mean_auc_trained, se_auc_trained),
    ("Log Posterior Probability", mean_logp_obs, se_logp_obs, mean_logp_rand, se_logp_rand, mean_logp_trained, se_logp_trained),
    ("Expected Number of Correct Entries", mean_caasl_obs, se_caasl_obs, mean_caasl_rand, se_caasl_rand, mean_caasl_trained, se_caasl_trained)]

fig, axes = plt.subplots(2, 2, figsize=(12, 8))
axes = axes.flatten()

for i, (ax, (title, mean_o, se_o, mean_r, se_r, mean_t, se_t)) in enumerate(zip(axes, metrics)):
    # Draw each policy as a mean curve with a +/- 1 SE band, in a fixed order
    curves = (
        ("Observation", "black", mean_o, se_o),
        ("Random", "blue", mean_r, se_r),
        ("Trained", "red", mean_t, se_t))
    for curve_label, curve_color, center, half_width in curves:
        ax.plot(ns, center, label=curve_label, color=curve_color)
        ax.fill_between(ns, center - half_width, center + half_width, color=curve_color, alpha=0.2)

    ax.set_xlabel("Total Time Periods (T)", fontsize=12)
    ax.set_ylabel(None)
    ax.set_title(title)
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))

    # A single legend (first panel) is enough: all panels share the same curves
    if i == 0:
        ax.legend()

plt.suptitle(rf"Graph Identifiability Across Sample Sizes ($d={d}$)", fontsize=14)
plt.tight_layout()
plt.show()

### 8. Effects of Entropy Target

First, we define a helper function `eval_policy_params()` that runs a Monte Carlo evaluation of a list of trained policies (alongside the observational and random baselines) under the given environment parameters `w`, `sigma`, and `p`.

In [None]:
def eval_policy_params(
    d: int,
    p: float,
    w: float | None = None,
    sigma: float | None = None,
    n_mc: int = 50,
    T: int = 8,
    num_obs: int = 20,
    trained_policies: list = None,
    avici_model=None,
    device: torch.device = torch.device('cpu'),
    action_mode: str = "sample",  # "sample" or "mean"
    eps: float = 1e-8):

    # Initialize
    obs_policy = ObservationPolicy(d=d, device=device)
    rand_policy = RandomPolicy(d=d, device=device)
    trained_policies = [policy.to(device) for policy in trained_policies]

    metrics_names = ['shd', 'auroc', 'logp', 'caasl']
    baseline_policies = ['obs', 'rand']
    trained_policy_names = [f'trained_{i}' for i in range(len(trained_policies))]
    all_policy_names = baseline_policies + trained_policy_names

    results = {}
    for policy in all_policy_names:
        results[policy] = {}
        for metric in metrics_names:
            results[policy][metric] = np.zeros(n_mc)  

    for mc in range(n_mc):
        if mc % 20 == 0:
            print(f"mc: {mc}")
            
        # Sample DAG
        A, order = prior(d=d, p=p, device=device)
        A_np = A.detach().cpu().numpy().astype(int)

        # Create SCM with specified parameters
        W_base = None
        if w is not None:
            # Use custom edge weight scaling
            in_degree = A.sum(dim=-2, keepdim=True).clamp_min(1.0)
            coeff_std = (w / torch.sqrt(in_degree)).expand_as(A)
            W_base = torch.normal(mean=0.0, std=coeff_std)

        sigma_vec = None
        if sigma is not None:
            sigma_vec = torch.full((d,), sigma, device=device)
        
        scm = LinearGaussianSCM(A, W_base=W_base, sigma=sigma_vec, device=device)

        # Generate observational prefix
        obs_data_nodes = torch.zeros(num_obs, d, device=device)
        obs_data_vals = torch.zeros(num_obs, d, device=device)
        obs_data_outs = scm.rsample(A, order, batch_shape=(num_obs,))

        # Run baseline policies
        obs_nodes, obs_vals, obs_outs = run_policy(obs_policy, scm, A, order, d, T, device)
        rand_nodes, rand_vals, rand_outs = run_policy(rand_policy, scm, A, order, d, T, device)

        # Run all trained policies (using specified action_mode)
        trained_results = []
        for trained_policy in trained_policies:
            trained_nodes, trained_vals, trained_outs = run_policy(trained_policy, scm, A, order, d, T, device,
                                                                    obs_nodes=obs_data_nodes, obs_vals=obs_data_vals, obs_out=obs_data_outs,
                                                                    action_mode=action_mode)
            trained_results.append((trained_nodes, trained_vals, trained_outs))

        # Save the policy trajectories
        policy_data = {'obs': (obs_outs, obs_nodes),  'rand': (rand_outs, rand_nodes)}
        for i, (trained_nodes, trained_vals, trained_outs) in enumerate(trained_results):
            policy_data[f'trained_{i}'] = (trained_outs, trained_nodes)

        for policy_name, (outs, nodes) in policy_data.items():
            outs_np = outs.detach().cpu().numpy()
            nodes_np = nodes.detach().cpu().numpy()

            # Calculate performance metrics
            g_prob = avici_model(x=outs_np, interv=nodes_np)
            pred_adj = (g_prob > 0.5).astype(int)

            results[policy_name]['shd'][mc] = shd(A_np, pred_adj)
            results[policy_name]['auroc'][mc] = threshold_metrics(A_np, g_prob)['auroc']
            results[policy_name]['logp'][mc] = compute_logp(A_np, g_prob, eps)
            results[policy_name]['caasl'][mc] = compute_caasl(A_np, g_prob)

    def mean_se(arr):
        return arr.mean(), arr.std() / np.sqrt(max(1, len(arr)))

    # Store results
    stats = {}
    for policy_name in all_policy_names:
        stats[policy_name] = {}
        for metric in metrics_names:
            mean_val, se_val = mean_se(results[policy_name][metric])
            stats[policy_name][metric] = {'mean': mean_val,
                                          'se': se_val,
                                          'raw': results[policy_name][metric]}

    return {'stats': stats, 'policy_names': all_policy_names}

The target entropy experiment is broken down into three stages: training, evaluation, and plotting.

In [None]:

# ---------- Training Stage ---------- #

# (discrete, continuous) entropy targets. The discrete entry is multiplied by log(5)
# before being passed to `train` (5 matches d=5 below; presumably it is a fraction of
# the maximum discrete entropy -- confirm against `train`).
entropy_pairs = [
    (0.01, -2.0),
    (0.25, -1.0),
    (0.75, 0.0),
    (0.98, 0.68)]

# Maps "model_{i}_disc{..}_cont{..}" -> training artefacts for that entropy pair
trained_models = {}

for i, (target_entropy_disc, target_entropy_cont) in enumerate(entropy_pairs):
    print(f"\nTraining Policy {i+1}/{len(entropy_pairs)}")
    print(f"Target Entropy: Discrete {target_entropy_disc:.3f}, Continuous {target_entropy_cont:.3f}")

    # Re-train the policy for each entropy pair
    training_stats, sac, replay_buffer, trained_policy = train(
        d=5,
        T=8,
        p=0.583,
        batch_size=128,
        num_episodes=4000,
        num_obs=20,
        print_every=100,
        # Use the notebook-level device (CUDA if available, else CPU) instead of
        # hard-coding "cuda", so the cell also runs on CPU-only machines
        device=device,
        avici_model=avici_model,
        tune_entropy=True,
        target_entropy_disc=target_entropy_disc * math.log(5),
        target_entropy_cont=target_entropy_cont,
        alpha_d=0.5,
        alpha_c=0.5,
        updates_per_step=1,
        warmup_episodes=500)

    # Store results (the key uses the unscaled discrete target, formatted to 2 decimals)
    key = f"model_{i}_disc{target_entropy_disc:.2f}_cont{target_entropy_cont:.2f}"
    trained_models[key] = {
        'training_stats': training_stats,
        'sac': sac,
        'replay_buffer': replay_buffer,
        'trained_policy': trained_policy,
        'target_entropy_disc': target_entropy_disc,
        'target_entropy_cont': target_entropy_cont,
        'model_index': i}


In [None]:
# ---------- Evaluation Stage ---------- #

# Number of Monte Carlo replications per evaluation (increase for more stable results)
n_mc = 300

# Extract the trained policies, in the order they were stored
trained_policies_list = [entry['trained_policy'] for entry in trained_models.values()]

# Evaluate all trained policies twice under identical settings: first with the SAMPLE
# action mode, then with the MEAN action mode
results_sample, results_mean = [
    eval_policy_params(
        d=5,
        p=0.583,
        w=None,
        sigma=None,
        n_mc=n_mc,
        T=8,
        num_obs=20,
        trained_policies=trained_policies_list,
        avici_model=avici_model,
        device=device,
        action_mode=mode)
    for mode in ("sample", "mean")]

In [None]:
# ---------- Plotting Stage ---------- #

metrics = ['shd', 'auroc', 'logp', 'caasl']

metric_titles = {
    'shd': 'SHD',
    'auroc': 'AUROC',
    'logp': 'Log Posterior Probability',
    'caasl': 'Expected Number of Correct Entries'}

# Tick labels show the targets as passed to training (discrete entry scaled by log 5)
scaled_pairs = [(x * math.log(5), y) for x, y in entropy_pairs]
entropy_labels = [f"({disc:.2f}, {cont:.2f})" for disc, cont in scaled_pairs]
x_positions = np.arange(len(entropy_pairs))

# One trained policy per entropy pair, named as in `eval_policy_params`
trained_keys = [f'trained_{i}' for i in range(len(entropy_pairs))]
sample_stats = results_sample['stats']
mean_stats = results_mean['stats']

# Baselines are read from the SAMPLE-mode run
obs_results = {
    metric: {'mean': sample_stats['obs'][metric]['mean'],
             'se': sample_stats['obs'][metric]['se']}
    for metric in metrics}
rand_results = {
    metric: {'mean': sample_stats['rand'][metric]['mean'],
             'se': sample_stats['rand'][metric]['se']}
    for metric in metrics}

# Trained policies: one array entry per entropy pair, for each action mode
trained_sample_results = {
    metric: {'means': np.array([sample_stats[key][metric]['mean'] for key in trained_keys]),
             'ses': np.array([sample_stats[key][metric]['se'] for key in trained_keys])}
    for metric in metrics}
trained_mean_results = {
    metric: {'means': np.array([mean_stats[key][metric]['mean'] for key in trained_keys]),
             'ses': np.array([mean_stats[key][metric]['se'] for key in trained_keys])}
    for metric in metrics}

# Plot
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
axes = axes.flatten()

for i, (ax, metric) in enumerate(zip(axes, metrics)):
    n_ticks = len(x_positions)

    # Baselines have a single value per metric: draw them as flat lines with an SE band
    flat_series = (
        (obs_results, 'Observational', "black"),
        (rand_results, 'Random', "blue"))
    for baseline, series_label, series_color in flat_series:
        center = baseline[metric]['mean']
        half_width = baseline[metric]['se']
        ax.plot(x_positions, [center] * n_ticks,
                label=series_label, color=series_color, linewidth=2)
        ax.fill_between(x_positions,
                        [center - half_width] * n_ticks,
                        [center + half_width] * n_ticks,
                        color=series_color, alpha=0.2)

    # Trained policies: one point per entropy pair, SAMPLE mode first, then MEAN mode
    trained_series = (
        (trained_sample_results, 'Trained (Sample)', "purple", 'o'),
        (trained_mean_results, 'Trained (Mean)', "red", 's'))
    for trained, series_label, series_color, series_marker in trained_series:
        centers = trained[metric]['means']
        half_widths = trained[metric]['ses']
        ax.plot(x_positions, centers,
                label=series_label, color=series_color, linewidth=2, marker=series_marker)
        ax.fill_between(x_positions,
                        centers - half_widths,
                        centers + half_widths,
                        color=series_color, alpha=0.2)

    # Formatting
    ax.set_xlabel('Target Entropy (Discrete, Continuous)', fontsize=12, labelpad=10)
    ax.set_title(metric_titles[metric], fontsize=14)
    ax.set_xticks(x_positions)
    ax.set_xticklabels(entropy_labels, fontsize=12, ha='center')
    ax.tick_params(axis='y', labelsize=12)
    ax.grid(True, alpha=0.3)

    if i == 0:
        ax.legend(loc="upper center", bbox_to_anchor=(0.19, 0.8))

plt.suptitle('Policy Performance vs. Target Entropy', fontsize=16)
plt.tight_layout()
plt.show()

### 9. Simulator Shift

In [None]:
# Check the list of trained models: the cells below look up policies in
# `trained_models` by these exact key strings
trained_models.keys()

In [None]:
# Edge probabilities to evaluate at, and the value used during training
p_values = np.linspace(0.1, 0.9, 5)
baseline_p = 0.583

d = 5
n_mc = 300

# Two of the policies trained above, looked up by their `trained_models` keys
trained_policy1 = trained_models["model_0_disc0.01_cont-2.00"]["trained_policy"]
trained_policy2 = trained_models["model_2_disc0.75_cont0.00"]["trained_policy"]

# Each trained policy is evaluated separately with its own action mode
actual_policies = ['obs', 'rand', 'trained_0_mean', 'trained_1_sample']
metrics = ['shd', 'auroc', 'logp', 'caasl']

# all_results[policy][metric] collects one mean / SE per p value
all_results = {
    policy: {metric: {'means': [], 'ses': []} for metric in metrics}
    for policy in actual_policies}

# Loop over p values
for i, p in enumerate(p_values):
    print(f"Evaluating p = {p:.3f} ({i+1}/{len(p_values)})")

    # trained_policy1 with action_mode="mean"
    results_mean = eval_policy_params(
        d=d, p=p, w=None, sigma=None, n_mc=n_mc, T=8, num_obs=20,
        trained_policies=[trained_policy1],
        avici_model=avici_model, device=device, action_mode="mean")

    # trained_policy2 with action_mode="sample"
    results_sample = eval_policy_params(
        d=d, p=p, w=None, sigma=None, n_mc=n_mc, T=8, num_obs=20,
        trained_policies=[trained_policy2],
        avici_model=avici_model, device=device, action_mode="sample")

    # Where each stored policy comes from. The baselines are read from the mean-mode
    # run (the sample-mode run re-evaluates them on its own independently sampled SCMs).
    # Both runs received a single trained policy, so it is named 'trained_0' in each.
    stats_by_policy = {
        'obs': results_mean['stats']['obs'],
        'rand': results_mean['stats']['rand'],
        'trained_0_mean': results_mean['stats']['trained_0'],
        'trained_1_sample': results_sample['stats']['trained_0']}

    for policy, per_metric in stats_by_policy.items():
        for metric in metrics:
            all_results[policy][metric]['means'].append(per_metric[metric]['mean'])
            all_results[policy][metric]['ses'].append(per_metric[metric]['se'])

In [None]:
# Convert to numpy so the SE bands can be computed with array arithmetic
for policy in actual_policies:
    for metric in metrics:
        all_results[policy][metric]['means'] = np.array(all_results[policy][metric]['means'])
        all_results[policy][metric]['ses'] = np.array(all_results[policy][metric]['ses'])

# Only these metrics get a panel
selected_metrics = ['shd', 'auroc', 'caasl']

# Plotting
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Named `policy_colors` rather than `colors` so that the `matplotlib.colors` module
# imported at the top of the notebook is not shadowed by this dict
policy_colors = {
    'obs': 'black',
    'rand': 'blue',
    'trained_0_mean': 'red',
    'trained_1_sample': 'purple'
}

labels = {
    'obs': 'Observational',
    'rand': 'Random',
    'trained_0_mean': 'Deterministic',
    'trained_1_sample': 'Stochastic'
}

metric_titles = {
    'shd': 'SHD',
    'auroc': 'AUROC',
    'logp': 'Log Posterior Probability',
    'caasl': 'Expected Number of Correct Entries'
}

for i, (ax, metric) in enumerate(zip(axes, selected_metrics)):
    # Mark the edge probability used during training
    ax.axvline(x=baseline_p, color='gray', linestyle='--', alpha=0.7, linewidth=2,
               label='Baseline')

    for policy in actual_policies:
        # The observational baseline is deliberately left out of these panels
        if policy == 'obs':
            continue

        means = all_results[policy][metric]['means']
        ses = all_results[policy][metric]['ses']

        ax.plot(p_values, means, label=labels[policy], color=policy_colors[policy], linewidth=2)
        ax.fill_between(p_values, means - ses, means + ses, color=policy_colors[policy], alpha=0.2)

    if i == 0:
        ax.legend()

    ax.set_xlabel('Edge Probability (p)', fontsize=12)
    ax.set_title(metric_titles[metric], fontsize=14)
    ax.grid(True, alpha=0.3)
    ax.tick_params(axis='both', labelsize=12)

plt.suptitle(f'Policy Performance vs. Edge Probability (d={d})', fontsize=16)
plt.tight_layout()
plt.show()

In [None]:
# --- Sigma Shift --- #

# Noise scales to evaluate at; `baseline_sigma` is only used as a reference line when plotting
sigma_values = np.linspace(0.001, 1.0, 5)
baseline_sigma = 0.11

d = 5
n_mc = 200

trained_policy1 = trained_models["model_0_disc0.01_cont-2.00"]["trained_policy"]  # for mean
trained_policy2 = trained_models["model_2_disc0.75_cont0.00"]["trained_policy"]  # for sample

metrics = ['shd', 'auroc', 'logp', 'caasl']
policy_names = ['obs', 'rand', 'trained_sample', 'trained_mean']

# all_results[policy][metric] collects one mean / SE per sigma value
all_results = {
    policy: {metric: {'means': [], 'ses': []} for metric in metrics}
    for policy in policy_names}

# Loop over sigma values
for i, sigma in enumerate(sigma_values):
    # Settings shared by both evaluations at this noise level
    shared_kwargs = dict(
        d=d,
        p=0.583,
        w=None,
        sigma=sigma,
        n_mc=n_mc,
        T=8,
        num_obs=20,
        avici_model=avici_model,
        device=device)

    # trained_policy1 in MEAN mode, then trained_policy2 in SAMPLE mode
    results_mean = eval_policy_params(
        trained_policies=[trained_policy1], action_mode="mean", **shared_kwargs)
    results_sample = eval_policy_params(
        trained_policies=[trained_policy2], action_mode="sample", **shared_kwargs)

    # Where each stored policy comes from. The baselines are read from the SAMPLE-mode
    # run; both runs received a single trained policy, named 'trained_0' in each.
    stats_by_policy = {
        'obs': results_sample['stats']['obs'],
        'rand': results_sample['stats']['rand'],
        'trained_sample': results_sample['stats']['trained_0'],
        'trained_mean': results_mean['stats']['trained_0']}

    for policy, per_metric in stats_by_policy.items():
        for metric in metrics:
            all_results[policy][metric]['means'].append(per_metric[metric]['mean'])
            all_results[policy][metric]['ses'].append(per_metric[metric]['se'])

In [None]:
# Convert to numpy so the SE bands can be computed with array arithmetic
for policy in policy_names:
    for metric in metrics:
        all_results[policy][metric]['means'] = np.array(all_results[policy][metric]['means'])
        all_results[policy][metric]['ses'] = np.array(all_results[policy][metric]['ses'])

# Only these metrics get a panel
selected_metrics = ['shd', 'auroc', 'caasl']

# Plot
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Named `policy_colors` rather than `colors` so that the `matplotlib.colors` module
# imported at the top of the notebook is not shadowed by this dict
policy_colors = {
    'obs': 'black',
    'rand': 'blue',
    'trained_sample': 'purple',
    'trained_mean': 'red'
}

labels = {
    'obs': 'Observational',
    'rand': 'Random',
    'trained_sample': 'Trained (Sample)',
    'trained_mean': 'Trained (Mean)'
}

metric_titles = {
    'shd': 'SHD',
    'auroc': 'AUROC',
    'caasl': 'Expected Number of Correct Entries'
}

for i, (ax, metric) in enumerate(zip(axes, selected_metrics)):
    # Reference line at `baseline_sigma`
    ax.axvline(x=baseline_sigma, color='gray', linestyle='--', alpha=0.7, linewidth=2,
               label='Baseline (Mean)')

    for policy in policy_names:
        # The observational baseline is deliberately left out of these panels
        if policy == 'obs':
            continue

        means = all_results[policy][metric]['means']
        ses = all_results[policy][metric]['ses']

        ax.plot(sigma_values, means, label=labels[policy], color=policy_colors[policy], linewidth=2)
        ax.fill_between(sigma_values, means - ses, means + ses, color=policy_colors[policy], alpha=0.2)

    ax.set_xlabel('Noise Standard Deviation (σ)', fontsize=12)
    ax.set_title(metric_titles[metric], fontsize=14)
    ax.grid(True, alpha=0.3)
    ax.tick_params(axis='both', labelsize=12)

    if i == 0:
        ax.legend()

plt.suptitle('Policy Performance vs. Noise Standard Deviation', fontsize=16)
plt.tight_layout()
plt.show()