In [None]:
class BatchedPrecisionAttentionBlock(nn.Module):
  """Batched precision-weighted attention block with a causal mask.

  Unnormalized attention is A_ij = 1 / (1 + nu * ||q_i - k_j||^2). It is masked
  so that step i only attends to steps j <= i, then normalized over j.

  NOTE(review): the tensor layout used below is inferred from how this block
  indexes its tensors (``Q_ij.unsqueeze(1)``, the causal-mask ``view`` and the
  ``batched_complex_matmul`` output projection), not from a spec. It assumes
  complex values are stored as stacked real/imag parts on axis 1, i.e. ``X``
  is (batch, 2, Npts, d_e, 1) -- TODO confirm against ``init_complex_matrix``
  and ``batched_complex_matmul``.
  """

  def __init__(self, args):
    """
    Initializes the batched precision-weighted attention block.

    Parameters:
        args: Model/system configuration. Fields read here: ``d_e`` (must be
            even, since eigenvalues come in complex conjugate pairs), ``d_k``
            (query/key dimension), ``d_v`` (value dimension), ``Npts``
            (sequence length used for the causal mask) and ``device``.

    Attributes created:
        W_q, W_k, W_v, W_r, W_p (nn.Parameter): Learnable complex weight
            matrices (query, key, value, residual, prediction output).
        nu (int): Measurement weighting in attention.
        causal_mask (torch.Tensor): Lower-triangular mask of shape
            (1, Npts, Npts, 1, 1).
    """
    super().__init__()

    # d_e must be divisible by 2, since eigenvalues come in complex conjugate pairs
    assert args.d_e % 2 == 0, "args.d_e must be even"

    self.W_q = nn.Parameter(init_complex_matrix(args.d_k, args.d_e))  # Query weight matrix
    self.W_k = nn.Parameter(init_complex_matrix(args.d_k, args.d_e))  # Key weight matrix
    self.W_v = nn.Parameter(init_complex_matrix(args.d_v, args.d_e))  # Value weight matrix
    # Residual weight matrix. Not used by forward() (there is no residual
    # connection yet); kept so existing checkpoints/state_dicts still load.
    self.W_r = nn.Parameter(init_complex_matrix(args.d_v, args.d_e))
    self.W_p = nn.Parameter(init_complex_matrix(args.d_e, args.d_v))  # Prediction output weight matrix

    self.args = args
    # Measurement weighting in attention; just set to 1 for now, since this can
    # be absorbed into the weight matrices.
    self.nu = 1

    # Causal attention mask: entry (i, j) is 1 iff j <= i. Registered as a
    # non-persistent buffer so Module.to()/.cuda() move it together with the
    # parameters, while the state_dict keys stay unchanged.
    causal_mask = torch.tril(torch.ones(args.Npts, args.Npts)).view(1, args.Npts, args.Npts, 1, 1)
    self.register_buffer("causal_mask", causal_mask.to(args.device), persistent=False)

  def forward(self, X, t_measure_all):
    """
    Forward pass through the precision-weighted attention block.

    Parameters:
        X (torch.Tensor): Input data. Assumed layout (batch, 2, Npts, d_e, 1),
            with real/imag parts stacked on axis 1 -- TODO confirm.
        t_measure_all (torch.Tensor): Time differences vector, for each
            trajectory in batch. Currently unused.

    Returns:
        out (torch.Tensor): W_p applied to the attention-weighted values
            (same assumed layout as X).
        Q_ij (torch.Tensor): Normalized causal attention weights of shape
            (batch, Npts, Npts, 1, 1); for each i they sum to 1 over j <= i.
    """
    # Project with the module's own parameters. Use the same complex matmul
    # helper as the output projection, since every weight matrix is built by
    # init_complex_matrix.
    X_q = batched_complex_matmul(self.W_q, X)
    X_k = batched_complex_matmul(self.W_k, X)
    X_v = batched_complex_matmul(self.W_v, X)

    # Pairwise query-key residuals r_ij = q_i - k_j: (batch, 2, i, j, d_k, 1).
    R_qk_ij = X_q.unsqueeze(3) - X_k.unsqueeze(2)

    # Squared magnitude of each component (real^2 + imag^2, summed over the
    # assumed real/imag axis 1), then summed over features -> (batch, i, j, 1, 1).
    mahalanobis_distance = torch.sum(R_qk_ij**2, dim=1)
    denom = 1 + self.nu * torch.sum(mahalanobis_distance, dim=3, keepdim=True)
    A_ij = 1 / denom

    A_ij = A_ij * self.causal_mask  # Step i only attends to steps j <= i

    # Normalize attention over j. The diagonal j == i is never masked and
    # A_ij > 0 there, so every row sum is strictly positive.
    S_ij = torch.sum(A_ij, dim=2, keepdim=True)
    Q_ij = A_ij / S_ij

    # Attention-weighted sum of values over j: (batch, 2, i, d_v, 1).
    est_v = torch.sum(Q_ij.unsqueeze(1) * X_v.unsqueeze(2), dim=3)

    est_latent = est_v  # No residual connection (W_r is unused)

    # Multiply by output matrix to get estimate
    out = batched_complex_matmul(self.W_p, est_latent)

    return out, Q_ij