# A Conditional Normalising Flow Tutorial
---
This notebook extends the previous Higgs pT regression example by learning a **full probability distribution** for the Higgs $p_{T}$ using a **RealNVP normalising flow**. 

As opposed to predicting a single point estimate, we train via **maximum likelihood** so that we can sample or evaluate a density at arbitrary $p_{T}$ values, conditioned on the jet features.

In [None]:
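# Reuses the setup from the previous regression notebook for data prep, etc.
# (N_JETS, N_FEATURES, train_loader_reg, test_loader_reg and target_scaler are assumed to come from there.)
# TODO: find way to nicely link back?
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim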

## RealNVP Building Blocks
We define a simple 1D `RealNVP` transform, made **conditional** on the jet features (treated as our `context`).  

For multiple coupling layers, we stack them to build a more flexible flow.
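
For reference, the density defined by one such conditional affine layer follows from the change-of-variables formula. In the sketch below, $z$ is the base variable, $y$ the (scaled) $p_T$ target and $c$ the context; these symbols are introduced here only for illustration.

$$y = z\, e^{s(c)} + t(c), \qquad z = \big(y - t(c)\big)\, e^{-s(c)}, \qquad z \sim \mathcal{N}(0, 1)$$

$$\log p(y \mid c) = \log \mathcal{N}(z;\, 0, 1) + \log\left|\frac{\partial z}{\partial y}\right| = -\tfrac{1}{2} z^{2} - \tfrac{1}{2}\log(2\pi) - s(c)$$

The last term is the log-determinant of the inverse map, which is why the inverse pass contributes $-s$ per layer.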

In [None]:
# RealNVP Coupling Layer for 1D target
# condition on D-dim context [batch, context_dim=flattened_jet_features]

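# In 1D, the coupling transform is basically: 
#   y = x * exp(s(context)) + t(context)

# Invertibility is guaranteed because exp(s) > 0, so no extra trick is needed here.
# In higher dimensions RealNVP alternates which half of the input is transformed;
# in 1D there is nothing to split. A sign flip between layers would be the closest
# analogue, and it leaves log|det J| unchanged because |-1| = 1.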

class ConditionalRealNVPCoupling(nn.Module):
    def __init__(self, context_dim, hidden_dim=64) -> None:
        super().__init__()
        
        # Simple MLP that outputs scale and shift as functions of the `context`.
        
        # For the 1D target, scale & shift are just scalars per event.
        
        # the scale network
        self.net_s = nn.Sequential(
            nn.Linear(context_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            # outputs scale
            nn.Linear(hidden_dim, 1)
        )
        
        # the shift network
        self.net_t = nn.Sequential(
            nn.Linear(context_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            # outputs translation
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x, context, reverse=False):
        """
        Args:
            x: shape [batch, 1], the target pT dimension.
            context: shape [batch, context_dim], the jet features.
            reverse: if True, invert the flow.
        Returns:
            y: shape [batch, 1], transformed variable.
            log_abs_det: the log of absolute determinant of the Jacobian.
        """
        s = self.net_s(context)  # [batch, 1] -> scale
        t = self.net_t(context)  # [batch, 1] -> shift
        
        if not reverse:
            # forward transform: 
            # y = x * exp(s) + t
            y = x * torch.exp(s) + t
            
            # log det = sum over dims, but here dims=1 => just s
            log_abs_det = s.squeeze(dim=-1)  # shape [batch]
        else:
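            # inverse transform: 
            # x = (y - t) * exp(-s)
            # (in reverse mode the input `x` holds the data-space value,
            #  and the returned `y` holds the base-space value)
            y = (x - t) * torch.exp(-s)
            
            # log det of the inverse is -s (dims=1, so there is nothing to sum over)
            log_abs_det = -s.squeeze(dim=-1)  # shape [batch]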
        
        # return the transformed variable and the log determinant of the Jacobian
        return y, log_abs_det
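
A quick way to convince yourself that the layer is invertible is a round trip with plain numbers. The sketch below is not the class above: it uses plain Python floats, and the values of `s` and `t` are hypothetical stand-ins for what the scale and shift networks would output for one event.

```python
import math

def affine_forward(x, s, t):
    """Forward map y = x * exp(s) + t; returns (y, log|dy/dx|)."""
    return x * math.exp(s) + t, s

def affine_inverse(y, s, t):
    """Inverse map x = (y - t) * exp(-s); returns (x, log|dx/dy|)."""
    return (y - t) * math.exp(-s), -s

# Hypothetical per-event outputs of the scale and shift networks.
s, t = 0.7, -1.3
x = 0.42

y, logdet_fwd = affine_forward(x, s, t)
x_back, logdet_inv = affine_inverse(y, s, t)

print(f"x = {x}, y = {y:.4f}, recovered x = {x_back:.4f}")
print(f"log-dets: forward {logdet_fwd}, inverse {logdet_inv}")
```

The recovered value matches the input up to floating-point error, and the two log-determinants cancel, as they should for a map followed by its inverse.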


## Stacking Multiple Coupling Layers
We can combine multiple coupling layers (with sign flips or permutations) for a richer transformation. 

Below is a minimal container that holds **N** coupling layers and composes them. 

In 1D there is nothing to permute. One option is to flip the sign between layers, which leaves $|\det J|$ unchanged; the container below keeps things simple and applies the layers back to back without flips.

In [None]:
class RealNVPFlow(nn.Module):
    def __init__(self, context_dim, n_coupling_layers=4, hidden_dim=64) -> None:
        super().__init__()
        self.layers = nn.ModuleList([
            ConditionalRealNVPCoupling(context_dim, hidden_dim=hidden_dim)
            for _ in range(n_coupling_layers)
        ])

    def forward(self, x, context):
        """ Forward pass: transforms base -> data space.
            We'll treat x as base-samples (z).
        """
        logdet_sum = 0.0
        y = x
        for i, layer in enumerate(self.layers):
            y, logdet = layer(y, context, reverse=False)
            logdet_sum += logdet
        return y, logdet_sum

    def inverse(self, y, context):
        """ Inverse pass: transforms data -> base space.
            This is used for log-likelihood computation.
        """
        logdet_sum = 0.0
        x = y
        # inverse in reverse order:
        for i, layer in reversed(list(enumerate(self.layers))):
            x, logdet = layer(x, context, reverse=True)
            logdet_sum += logdet
        return x, logdet_sum

    def log_prob(self, y, context):
        """Compute log p(y|context).  
        
        We transform y -> x in base, and add log p_x(x) + log|det J|.
        Base distribution is standard Normal(0,1) in 1D.
        """
        # inverse transform y -> x
        x, logdet = self.inverse(y, context)  
        # base log prob ~ -0.5 * x^2 - 0.5*log(2pi)
        log_p_x = -0.5*(x**2) - 0.5*np.log(2*np.pi)
        log_p_x = log_p_x.squeeze(dim=-1)  # shape [batch]
        return log_p_x + logdet

    def sample(self, context, n_samples=1):
        """Draw samples from p(y|context).  In 1D, x ~ N(0,1), then forward.
        """
        # context shape: [batch, context_dim]. We'll sample for each item.
        batch_size = context.shape[0]
        # random normal samples for base space, created on the same device as the context
        z = torch.randn(batch_size * n_samples, 1, device=context.device, dtype=context.dtype)
        
        # repeat context to match that shape
        # or do it in a loop if you prefer.  We'll tile the context.
        repeated_context = context.repeat_interleave(n_samples, dim=0)
        
        # forward transform z -> y
        y, _ = self.forward(z, repeated_context)
        
        # return the samples
        return y
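
One property of this particular flow is worth checking by hand. Because `s` and `t` depend only on the context, every layer is affine in the target, and a stack of affine maps is again affine. For a fixed context, the density returned by `log_prob` should therefore coincide with a Gaussian whose mean is the total shift and whose width is the total scale. The sketch below verifies this in plain Python; the `(s, t)` pairs are hypothetical per-layer network outputs for a single event, and the helper names are made up for this illustration.

```python
import math

def flow_log_prob(y, layers):
    """log p(y): undo the affine layers in reverse order, then add the base log-density."""
    logdet = 0.0
    for s, t in reversed(layers):
        y = (y - t) * math.exp(-s)
        logdet += -s
    return -0.5 * y * y - 0.5 * math.log(2 * math.pi) + logdet

def collapse(layers):
    """Fold the stack into a single affine map y = scale * z + shift."""
    scale, shift = 1.0, 0.0
    for s, t in layers:
        scale, shift = scale * math.exp(s), shift * math.exp(s) + t
    return scale, shift

def gaussian_log_pdf(y, mu, sigma):
    return -0.5 * ((y - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)

# Hypothetical (s, t) outputs of four coupling layers for one event.
layers = [(0.3, 1.0), (-0.5, 0.2), (0.1, -2.0), (0.4, 0.5)]
scale, shift = collapse(layers)

y = 0.8
lp_flow = flow_log_prob(y, layers)
lp_gauss = gaussian_log_pdf(y, mu=shift, sigma=scale)
print(f"flow: {lp_flow:.6f}  gaussian: {lp_gauss:.6f}")
```

The two numbers agree, so for each event this model predicts a Gaussian in the scaled target with a context-dependent mean and width. Extra layers change how that mean and width are parameterised, not the shape of the density.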


## Training the Flow
Instead of the MSE loss used in the previous regression tutorial, we use the **negative log-likelihood** loss:
$$\mathcal{L} = -\frac{1}{N} \sum_{i=1}^N \log p(\text{Higgs}\ p_T^{(i)} \mid \text{jet features}^{(i)})$$
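
To see what this loss rewards, it helps to evaluate it on toy numbers. The sketch below uses a plain Gaussian as a stand-in for the predicted density and toy targets drawn from a known distribution; none of the names or values come from the Higgs dataset. It compares the mean NLL for a matched width, an overconfident (too narrow) width and an underconfident (too wide) width.

```python
import math
import random

def gaussian_nll(ys, mu, sigma):
    """Mean negative log-likelihood of ys under Normal(mu, sigma)."""
    total = 0.0
    for y in ys:
        total += 0.5 * ((y - mu) / sigma) ** 2 + math.log(sigma) + 0.5 * math.log(2 * math.pi)
    return total / len(ys)

random.seed(0)
ys = [random.gauss(1.0, 0.5) for _ in range(5000)]  # toy "targets"

nll_matched = gaussian_nll(ys, mu=1.0, sigma=0.5)
nll_too_narrow = gaussian_nll(ys, mu=1.0, sigma=0.1)
nll_too_wide = gaussian_nll(ys, mu=1.0, sigma=5.0)

print(f"matched: {nll_matched:.3f}")
print(f"too narrow: {nll_too_narrow:.3f}")
print(f"too wide: {nll_too_wide:.3f}")
```

All three predictions share the same mean, so an MSE loss on the mean could not tell them apart. The NLL is lowest for the matched width, which is what pushes the flow to learn the spread as well as the centre.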

Below is a typical loop: you pass 
`(y, context)` → compute `flow.log_prob(y, context)` → maximise that log-prob (or minimise the negative).

In [None]:
# flow config

# you may recognise this as the same input dim for the previous regression tutorial
context_dim = (N_JETS * N_FEATURES) 

# build the flow
flow = RealNVPFlow(context_dim=context_dim, n_coupling_layers=4, hidden_dim=64)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
flow.to(device)

optimiser = optim.AdamW(flow.parameters(), lr=1e-3)
n_epochs = 10  # or more

train_losses = []
for epoch in range(n_epochs):
    flow.train()
    total_loss = 0.0
    # Same loader as before, but now we interpret `batch['target']` as y and `batch['jets']` as context.
    for batch in train_loader_reg:
        context = batch['jets'].to(device)  # shape [B, N_JETS, N_FEATURES]
        context = context.reshape(context.size(0), -1)  # flatten
        y = batch['target'].to(device).unsqueeze(-1)  # shape [B, 1]

        optimiser.zero_grad()
        
        # log_prob has shape [batch], so we take mean of negative for loss.
        log_p = flow.log_prob(y, context)
        loss = -log_p.mean()
        loss.backward()
        optimiser.step()
        total_loss += loss.item() * len(y)

    avg_loss = total_loss / len(train_loader_reg.dataset)
    train_losses.append(avg_loss)
    
    # TODO: add a progress bar and improve logging here!
    print(f"Epoch [{epoch+1}/{n_epochs}]  NLL: {avg_loss:.4f}")

## Inference & Sampling
With a trained flow, you can now:

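1. **Compute** the conditional density $p(\hat p_T \mid X)$ for any $\hat p_T$. 
    - This is a probability density over the Higgs pT given the jet features, so it can exceed 1; probabilities come from integrating it over a pT range.
  
2. **Sample** from the distribution to see the variety of plausible pT outcomes.
    - Repeated draws for one event trace out the predicted Higgs pT distribution for those jet features.

3. **Evaluate** the log-likelihood of the test set.
    - The average of $\log p(p_T \mid X)$ over held-out events measures how well the predicted densities match the true values.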
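
A note on the first point: the flow returns a density, and a probability only appears once that density is integrated over a range of pT. The sketch below illustrates the difference with a Gaussian stand-in for one event's conditional density. The values of `mu`, `sigma` and the interval `[a, b]` are hypothetical and live in scaled-target units.

```python
import math
import random

def normal_cdf(y, mu, sigma):
    """CDF of Normal(mu, sigma) via the error function."""
    return 0.5 * (1.0 + math.erf((y - mu) / (sigma * math.sqrt(2.0))))

mu, sigma = 0.2, 0.6  # hypothetical conditional mean and width for one event
a, b = 0.0, 1.0       # interval of interest

# Probability of landing in [a, b]: the integral of the density over the interval.
prob_exact = normal_cdf(b, mu, sigma) - normal_cdf(a, mu, sigma)

# The same probability estimated from samples, as one would do with flow.sample.
random.seed(1)
draws = [random.gauss(mu, sigma) for _ in range(20000)]
prob_mc = sum(a < y < b for y in draws) / len(draws)

# A narrow density can exceed 1 at its peak and still integrate to 1.
peak_density = 1.0 / (0.1 * math.sqrt(2 * math.pi))

print(f"P(a < y < b): exact {prob_exact:.3f}, from samples {prob_mc:.3f}")
print(f"peak density for sigma = 0.1: {peak_density:.2f}")
```

With a trained flow, the sample-based estimate is usually the easier route: draw many samples for one event and count the fraction that falls inside the range.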


In [None]:
# Evaluate log p on test set + produce samples
flow.eval()
all_logp = []
all_pt_truth = []
with torch.no_grad():
    for batch in test_loader_reg:
        context = batch['jets'].to(device)
        context = context.reshape(context.size(0), -1)
        y = batch['target'].to(device).unsqueeze(-1)
        log_p = flow.log_prob(y, context)
        all_logp.append(log_p.cpu().numpy())
        all_pt_truth.append(y.cpu().numpy())
all_logp = np.concatenate(all_logp, axis=0)
all_pt_truth = np.concatenate(all_pt_truth, axis=0)

print(f"Average log probability on test set: {all_logp.mean():.4f}")

# As an example, we can generate samples for 1 batch of test context.
sample_batch = next(iter(test_loader_reg))
context_sample = sample_batch['jets'][:10].to(device)
context_sample = context_sample.reshape(context_sample.size(0), -1)

# Now, we can draw samples from the flow!
samples = flow.sample(context_sample, n_samples=100)  # shape [10*100, 1]

print("Samples shape:", samples.shape)
print("Example pT samples:", samples[:10].detach().cpu().squeeze().numpy())

# Cool right!
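
The samples come back as one flat column, so it is worth spelling out the ordering. `repeat_interleave` repeats each context row `n_samples` times in a row, which means the first `n_samples` entries belong to event 0, the next `n_samples` to event 1, and so on. The sketch below shows how to turn such an event-major list into per-event means and widths. It uses plain Python lists and toy numbers in place of the flow output, and `per_event_summary` is a hypothetical helper name.

```python
import random
import statistics

def per_event_summary(flat_samples, n_events, n_samples):
    """Split event-major flat samples into per-event (mean, std) pairs."""
    assert len(flat_samples) == n_events * n_samples
    summaries = []
    for i in range(n_events):
        chunk = flat_samples[i * n_samples:(i + 1) * n_samples]
        summaries.append((statistics.fmean(chunk), statistics.stdev(chunk)))
    return summaries

# Toy stand-in for `samples`: 3 events, 100 draws each, in event-major order.
random.seed(0)
true_means = [-1.0, 0.0, 2.0]
flat = [random.gauss(m, 0.3) for m in true_means for _ in range(100)]

summaries = per_event_summary(flat, n_events=3, n_samples=100)
for i, (mean, std) in enumerate(summaries):
    print(f"event {i}: mean {mean:.2f}, std {std:.2f}")
```

With the tensor from the cell above, `samples.reshape(10, 100)` should give the same `[event, draw]` layout, after which `.mean(dim=1)` and `.std(dim=1)` yield the per-event summaries.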


In [None]:
# Suppose 'flow' is trained. Sample from it for a fixed number of test batches.
from itertools import islice

# for batch in test_loader_reg:
#     all_truth.append(batch["target"].cpu().numpy())
# truth_np = np.concatenate(all_truth, axis=0)

n_batches = 10
# islice walks a single iterator, so we get n_batches different batches
# (calling next(iter(loader)) repeatedly would restart the loader each time).
batches = list(islice(test_loader_reg, n_batches))
context = torch.cat([batch["jets"].to(device) for batch in batches]).reshape(-1, context_dim)  # [n_events, context_dim]
true_scaled = torch.cat([batch["target"].to(device).unsqueeze(-1) for batch in batches])  # [n_events, 1] scaled pT

# -----> SAMPLE from the flow
n_samples_per_event = 1000  # draws per event
with torch.no_grad():  # no gradients needed; this also allows .numpy() below
    samples_scaled = flow.sample(context, n_samples=n_samples_per_event)
    # shape = [n_events*n_samples_per_event, 1]

# -----> Inverse-scale both "samples_scaled" and "true_scaled" so they're back in physical units
# .cpu() is needed if we are on a GPU
samples_np = samples_scaled.cpu().numpy()
true_np    = true_scaled.cpu().numpy()

# Invert with 'target_scaler' (the one used for pT, not the feature scaler!)
samples_unscaled = target_scaler.inverse_transform(samples_np)  # shape [n_events*n_samples, 1]
truth_unscaled   = target_scaler.inverse_transform(true_np)     # shape [n_events, 1]

# convert to GeV (assumes the unscaled target is in MeV)
samples_GeV = samples_unscaled / 1000.0
truth_GeV   = truth_unscaled   / 1000.0

In [None]:
plt.figure()
plt.hist(samples_GeV.flatten(), bins=50, alpha=0.5, density=True, label="Flow Samples (GeV)")
plt.hist(truth_GeV.flatten(),   bins=50, alpha=0.5, density=True, label="Truth pT (GeV)")
plt.xlabel("Higgs pT [GeV]")
plt.ylabel("Normalized Counts")
plt.legend()
plt.show()

In [None]:
# Does the flow capture the spread of pT?

print(f"Truth Variance: {np.var(truth_GeV):.2f}, Sample Variance: {np.var(samples_GeV):.2f}")

# plot a KDE of the samples and of the truth
sns.kdeplot(samples_GeV.flatten(), label="Flow Samples KDE", fill=True)
sns.kdeplot(truth_GeV.flatten(), label="Truth pT KDE", fill=True)
plt.legend()
plt.show()

# does the flow predict reasonable quantiles?
percentiles = [10, 25, 50, 75, 90]
truth_quantiles = np.percentile(truth_GeV, percentiles)
sample_quantiles = np.percentile(samples_GeV, percentiles)

for p, tq, sq in zip(percentiles, truth_quantiles, sample_quantiles):
    print(f"{p}th percentile - Truth: {tq:.2f}, Samples: {sq:.2f}")
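
The quantile comparison above pools all events, so it tests the overall pT spectrum and not the per-event predictions. A complementary check is coverage: for each event, build a central interval from that event's own samples and count how often the true value falls inside. For a well-calibrated model, a 68% interval should contain the truth about 68% of the time. The sketch below demonstrates the idea on toy Gaussian data in plain Python. The function names and all numbers are illustrative and not taken from the notebook.

```python
import random

def central_interval(samples, level=0.68):
    """Empirical central interval of one event's samples."""
    ordered = sorted(samples)
    lo_idx = int(round((0.5 - level / 2) * (len(ordered) - 1)))
    hi_idx = int(round((0.5 + level / 2) * (len(ordered) - 1)))
    return ordered[lo_idx], ordered[hi_idx]

def coverage(truths, samples_per_event, level=0.68):
    """Fraction of events whose truth lies inside that event's central interval."""
    hits = 0
    for truth, samples in zip(truths, samples_per_event):
        lo, hi = central_interval(samples, level)
        hits += lo <= truth <= hi
    return hits / len(truths)

random.seed(0)
n_events, n_samples = 2000, 200
means = [random.uniform(-1, 1) for _ in range(n_events)]
truths = [random.gauss(m, 0.5) for m in means]

# Two toy "models": one with the right width, one that is overconfident.
well_calibrated = [[random.gauss(m, 0.5) for _ in range(n_samples)] for m in means]
overconfident = [[random.gauss(m, 0.2) for _ in range(n_samples)] for m in means]

cov_good = coverage(truths, well_calibrated)
cov_bad = coverage(truths, overconfident)
print(f"68% interval coverage: calibrated {cov_good:.2f}, overconfident {cov_bad:.2f}")
```

To apply this to the flow, the per-event sample lists would come from reshaping `samples_GeV` to `[n_events, n_samples_per_event]`, relying on the event-major ordering produced by `repeat_interleave` in `flow.sample`.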

## Some Things to Consider
The flow samples may be overfitting to high-density regions. The flow learns by maximising the log-likelihood, so if some pT ranges contain far more training points than others, it may over-prioritise those ranges, which can lead to sharp peaks and gaps in less-represented regions. The flow is also parametric: it can only represent shapes within its transformation family, and it needs enough examples in a region to constrain them.

Next tutorial idea: Neural Spline Flows, which allow non-linear monotonic transformations. Because a monotonic transformation is invertible, the flow can still be evaluated in both directions, which also gives access to the quantiles. The RealNVP coupling layers used here are affine (a scale plus a shift) in the target, which limits the shapes the conditional density can take.