In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [17]:
def masked_row_softmax(logits: torch.Tensor,
                       mask: torch.Tensor,
                       temperature: float = 1.0,
                       fallback_self_loop: bool = True) -> torch.Tensor:
    """
    Perform row-wise softmax only over entries where mask==1; others get prob 0.
    If a row has no allowed entries and fallback_self_loop=True, we set the diagonal
    mask to 1 for that row before softmax (so it becomes a self-loop).
    Shapes: logits, mask -> (B,N,N) or (1,N,N) for mask (will broadcast to B).
    """
    if mask is None:
        # standard row-softmax over all columns
        return F.softmax(logits / temperature, dim=-1)

    # Broadcast mask to (B,N,N)
    m = mask.bool()
    if m.dim() == 3 and logits.dim() == 3 and m.size(0) == 1:
        m = m.expand(logits.size(0), -1, -1)

    B, N, _ = logits.shape
    # Optionally ensure at least one allowed entry per row (to avoid all -inf → NaN)
    if fallback_self_loop:
        row_has_any = m.any(dim=-1)  # (B,N)
        if (~row_has_any).any():
            eye = torch.eye(N, device=logits.device, dtype=m.dtype).bool().unsqueeze(0)
            m = torch.where(~row_has_any.unsqueeze(-1) & eye, True, m)

    # Set disallowed positions to -inf so softmax gives ~0 there
    neg_inf = torch.finfo(logits.dtype).min
    masked_logits = torch.where(m, logits / temperature, neg_inf)

    # Numerically stable softmax: subtract row-max over allowed entries
    row_max = masked_logits.max(dim=-1, keepdim=True).values  # (B,N,1)
    exps = torch.exp(masked_logits - row_max) * m  # zero where disallowed
    denom = exps.sum(dim=-1, keepdim=True).clamp_min(1e-12)
    probs = exps / denom
    return probs  # (B,N,N), rows sum to 1 over allowed entries only


class TemporalNodeEncoder(nn.Module):
    """
    Encodes each node's time series independently with a single-layer GRU and
    projects its final summary to ``out_dim``.

    Input ``x_BTNC`` is (B, T, N, C). Every node becomes one length-T sequence in a
    batch of B*N sequences, so all nodes share the same GRU weights.
    """
    def __init__(self, in_dim: int, hidden: int = 64, out_dim: int = 64, bidirectional=False):
        super().__init__()
        self.gru = nn.GRU(input_size=in_dim, hidden_size=hidden,
                          batch_first=True, bidirectional=bidirectional)
        self.proj = nn.Linear(hidden*(2 if bidirectional else 1), out_dim)

    def forward(self, x_BTNC: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x_BTNC: (B, T, N, C) node features over time.
        Returns:
            (B, N, out_dim) node embeddings.
        """
        B, T, N, C = x_BTNC.shape
        x = x_BTNC.permute(0, 2, 1, 3).reshape(B*N, T, C)   # (B*N,T,C)
        # Use the final hidden state h_n rather than output[:, -1]. For a
        # bidirectional GRU, output[:, -1] holds the backward direction after it
        # has read only the last time step. h_n[-1] holds the backward state
        # after reading the whole sequence. For a unidirectional GRU both agree.
        _, h_n = self.gru(x)                                # (num_dirs, B*N, hidden); num_layers=1
        if self.gru.bidirectional:
            last = torch.cat([h_n[-2], h_n[-1]], dim=-1)    # (B*N, 2*hidden): fwd ‖ bwd
        else:
            last = h_n[-1]                                  # (B*N, hidden)
        return self.proj(last).reshape(B, N, -1)            # (B,N,out_dim)


class EdgeScorer(nn.Module):
    def __init__(self, node_dim: int, edge_feat_dim: int = 0, hidden: int = 128):
        super().__init__()
        in_d = 4*node_dim + edge_feat_dim      # [hi, hj, hi-hj, hi*hj, (+ e_ij)]
        self.mlp = nn.Sequential(
            nn.Linear(in_d, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, 1)
        )

    def forward(self, H_BNH: torch.Tensor, Efeat_BNNE: torch.Tensor | None = None) -> torch.Tensor:
        B, N, H = H_BNH.shape
        hi = H_BNH[:, :, None, :].expand(B, N, N, H)
        hj = H_BNH[:, None, :, :].expand(B, N, N, H)
        feats = [hi, hj, hi - hj, hi * hj]
        if Efeat_BNNE is not None:
            feats.append(Efeat_BNNE)
        Z = torch.cat(feats, dim=-1)
        return self.mlp(Z).squeeze(-1)  # (B,N,N)


class PropagationGraphLearner(nn.Module):
    """
    Learns a soft, directed propagation graph G (B,N,N) with masked row-softmax.
    Self-loops are allowed if the mask diagonal is 1.
    """
    def __init__(self, in_dim: int, node_dim: int = 64, edge_hidden: int = 128,
                 scorer_edge_feat_dim: int = 0, temperature: float = 1.0):
        super().__init__()
        self.encoder = TemporalNodeEncoder(in_dim, hidden=node_dim, out_dim=node_dim)
        self.scorer  = EdgeScorer(node_dim, edge_feat_dim=scorer_edge_feat_dim, hidden=edge_hidden)
        self.temperature = temperature

    def forward(self, x_BTNC: torch.Tensor,
                mask_M: torch.Tensor,                  # (1 or B, N, N) with 1=allowed (neighbors)
                edge_feats: torch.Tensor | None = None,
                fallback_self_loop: bool = True):
        H = self.encoder(x_BTNC)                        # (B,N,H)
        logits = self.scorer(H, edge_feats)             # (B,N,N), includes self logits
        soft = masked_row_softmax(logits, mask_M, self.temperature, fallback_self_loop)
        return soft  # (B,N,N)

In [18]:
# Seed the global torch RNG so that the random inputs, the random mask and the
# model's weight initialisation below are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

B, T, N, C = 4, 12, 300, 8          # batch, time steps, nodes, channels per node
x = torch.randn(B, T, N, C)         # (B, T, N, C) synthetic node time series

In [19]:
# Physical mask M (endpoint proximity, turn restrictions …) -- a random placeholder
# here. Shape (1,N,N) so that it broadcasts across the batch. The diagonal is
# cleared to forbid self loops. Note that fallback_self_loop in masked_row_softmax
# still re-enables the diagonal for any row left with no allowed neighbour.
M = torch.randint(0, 2, (1, N, N)).bool() & ~torch.eye(N, dtype=torch.bool)

In [21]:
# Graph learner: 64-dim node embeddings feed a 128-unit edge MLP. A temperature
# below 1 sharpens the masked row-softmax.
learner = PropagationGraphLearner(C, node_dim=64, edge_hidden=128,
                                  scorer_edge_feat_dim=0,  # set to E.shape[-1] if you pass edge features
                                  temperature=0.7)

In [22]:
soft_P = learner(x, mask_M=M, edge_feats=None)   # (B, N, N)

# Sanity checks: the result must be finite, row-stochastic, and zero outside the
# mask. Rows with no allowed neighbour may have received a fallback self-loop, so
# they are excluded from the outside-mask check.
assert soft_P.shape == (B, N, N)
assert torch.isfinite(soft_P).all()
assert torch.allclose(soft_P.sum(dim=-1), torch.ones(B, N))
assert (soft_P[(~M & M.any(dim=-1, keepdim=True)).expand_as(soft_P)] == 0).all()

In [23]:
# Inspect the physical mask; torch abbreviates large tensors with "..." in the repr.
M

tensor([[[False,  True,  True,  ...,  True, False,  True],
         [ True, False,  True,  ...,  True, False, False],
         [False, False, False,  ...,  True, False,  True],
         ...,
         [ True,  True, False,  ..., False,  True,  True],
         [ True, False,  True,  ...,  True, False,  True],
         [False, False,  True,  ..., False, False, False]]])

In [24]:
# Inspect the learned propagation weights (B, N, N). Entries where M is False are 0,
# barring fallback self-loops on rows with no allowed neighbour.
soft_P

tensor([[[0.0000, 0.0073, 0.0072,  ..., 0.0073, 0.0000, 0.0073],
         [0.0070, 0.0000, 0.0070,  ..., 0.0071, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0064, 0.0000, 0.0065],
         ...,
         [0.0067, 0.0069, 0.0000,  ..., 0.0000, 0.0066, 0.0069],
         [0.0066, 0.0000, 0.0066,  ..., 0.0066, 0.0000, 0.0066],
         [0.0000, 0.0000, 0.0067,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0074, 0.0074,  ..., 0.0073, 0.0000, 0.0073],
         [0.0071, 0.0000, 0.0072,  ..., 0.0070, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0064, 0.0000, 0.0065],
         ...,
         [0.0068, 0.0069, 0.0000,  ..., 0.0000, 0.0068, 0.0069],
         [0.0067, 0.0000, 0.0067,  ..., 0.0066, 0.0000, 0.0067],
         [0.0000, 0.0000, 0.0068,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0072, 0.0073,  ..., 0.0072, 0.0000, 0.0073],
         [0.0070, 0.0000, 0.0071,  ..., 0.0070, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0064, 0.0000, 0.