In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [4]:
class noflayer(nn.Module):
    """
    SEA-GWNN lifting layer, vectorized across batch and time.

    The layer works internally on tensors laid out as (B, T, N, C); ``forward``
    converts from/to the caller's layout. The adjacency matrix is not owned by
    the layer: it is passed to ``forward`` on every call, so it may be a fixed
    tensor or a learnable one living elsewhere.

    Args:
        nnode: number of graph nodes (stored; not otherwise used by this class).
        in_features: channel dimension C of the input. Doubled when
            ``variant`` is True for sizing the attention vector.
            NOTE(review): ``forward`` never widens the input to 2*C channels,
            so ``variant=True`` looks incompatible with the attention slices
            as written -- confirm the intended usage.
        out_features: stored only; the layer keeps the channel count unchanged.
        hop: number of propagation steps. With ``hop == 0`` the input is
            returned unchanged.
        alpha: fixed mixing weight in the lifting update. If ``None``, learnable
            sigmoid gates taken from ``self.temp[0..2]`` are used instead, which
            requires ``hop >= 2`` (``self.temp`` has ``hop + 1`` entries).
        residual: stored only.
        variant: see ``in_features``.
        leaky_alpha: negative slope of the LeakyReLU applied to attention logits.
        alp: decay factor used to initialise ``self.cheb``.
        symmetrize: if True, the adjacency is replaced by ``0.5 * (A + A^T)``.
        nonneg: if True, the adjacency is squashed with a sigmoid first.
    """
    def __init__(self, nnode, in_features, out_features, hop, alpha,
                 residual=False, variant=False, leaky_alpha=0.2, alp=0.9,
                 symmetrize=True, nonneg=True):
        super().__init__()
        self.variant = variant
        self.nnode = nnode
        self.alpha_ = alpha
        self.hop = hop
        self.alp = alp
        self.leaky_alpha = leaky_alpha
        self.symmetrize = symmetrize
        self.nonneg = nonneg

        self.in_features = 2*in_features if variant else in_features
        self.out_features = out_features
        self.residual = residual

        # ----- Attention vector -----
        # First half scores the "source" node, second half the "target" node.
        self.a = nn.Parameter(torch.empty(size=(2*self.in_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)
        self.leakyrelu = nn.LeakyReLU(self.leaky_alpha)

        # ----- Lifting parameters -----
        # Both vectors get their real initial values in reset_parameters();
        # they are passed through a sigmoid before use in the lifting step.
        self.temp = nn.Parameter(torch.zeros(self.hop+1))
        self.cheb = nn.Parameter(torch.zeros(self.hop+1, dtype=torch.float32))

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise ``temp`` to zeros and ``cheb`` to alp*(1-alp)**k."""
        with torch.no_grad():
            self.temp.zero_()
            for k in range(self.hop+1):
                self.cheb[k] = self.alp*(1-self.alp)**k
            # NOTE(review): the last entry uses a fixed exponent of 2 regardless
            # of hop -- confirm whether (1-alp)**hop was intended.
            self.cheb[-1] = (1-self.alp)**2

    # --------- process adjacency ---------
    def process_adj(self, A):
        """Apply the optional sigmoid and symmetrisation to adjacency ``A``."""
        if self.nonneg:
            A = torch.sigmoid(A)   # constrain to (0,1)
        if self.symmetrize:
            A = 0.5 * (A + A.transpose(-2, -1))    # enforce symmetry
        return A

    # ---------- Attention ----------
    def attention(self, x_BTNC, A):
        """
        Compute masked attention over node pairs.

        x_BTNC: (B,T,N,C)
        A: (N,N) adjacency (can be fixed or learnable, passed at forward)
        returns: U (B,T,N,N) row-softmax attention, P = 0.5 * U, and the
                 processed adjacency (N,N) produced by ``process_adj``.
        """
        N = x_BTNC.shape[2]
        a1 = self.a[:self.in_features, :]   # (C,1)
        a2 = self.a[self.in_features:, :]   # (C,1)

        feat_1 = torch.matmul(x_BTNC, a1)                # (B,T,N,1)
        feat_2 = torch.matmul(x_BTNC, a2)                # (B,T,N,1)
        e = feat_1 + feat_2.transpose(-2, -1)            # (B,T,N,N) via broadcasting
        e = self.leakyrelu(e)

        A = self.process_adj(A)
        # Only pairs with a strictly positive processed weight take part in the
        # softmax; the mask broadcasts over batch and time. reshape (not view)
        # so a non-contiguous adjacency is accepted too.
        mask = (A > 0).reshape(1, 1, N, N)
        # Use the most negative finite value rather than -inf so that a fully
        # masked row yields a uniform distribution instead of NaNs.
        neg_inf = torch.finfo(e.dtype).min
        e_masked = torch.where(mask, e, e.new_full((), neg_inf))

        U = torch.softmax(e_masked, dim=-1)  # (B,T,N,N)
        P = 0.5 * U
        return U, P, A

    # ---------- Lifting ----------
    def forward_lifting_bases(self, x_BTNC, P_BTNN, U_BTNN, A):
        """
        Run ``hop`` propagation/lifting steps.

        x_BTNC: (B,T,N,C) input features
        P_BTNN, U_BTNN: (B,T,N,N) operators returned by ``attention``
        A: (N,N) processed adjacency
        returns: (B,T,N,C); equal to ``x_BTNC`` when ``hop == 0``.
        """
        N = x_BTNC.shape[2]
        learn_gates = self.alpha_ is None
        if learn_gates:
            coe = torch.sigmoid(self.temp)   # (hop+1,)
        cheb_coe = torch.sigmoid(self.cheb)

        AdjP = A.reshape(1, 1, N, N) * P_BTNN   # (B,T,N,N)
        rowsum = AdjP.sum(-1)                   # (B,T,N)

        update = x_BTNC
        # Starting the running output at the input makes the first step
        # (which blends with x) the same formula as every later step.
        feat_prime = x_BTNC

        for step in range(self.hop):
            update = torch.einsum("btij,btjc->btic", U_BTNN, update)

            if learn_gates:
                feat_even_bar = coe[0]*x_BTNC + update
                fuse_gate, carry_gate = coe[1], coe[2]
            else:
                feat_even_bar = update
                fuse_gate = carry_gate = self.alpha_

            # The row-sum scaling is attenuated cumulatively from the second step on.
            if step >= 1:
                rowsum = cheb_coe[step-1] * rowsum

            feat_odd_bar = update - feat_even_bar * rowsum.unsqueeze(-1)

            feat_fuse = fuse_gate*feat_even_bar + (1-fuse_gate)*feat_odd_bar
            feat_prime = carry_gate*feat_prime + (1-carry_gate)*feat_fuse

        return feat_prime   # (B,T,N,C)

    # ---------- Forward ----------
    def forward(self, input, A, _format='BCNT'):
        """
        input: (B,C,N,T) when _format == 'BCNT' (default),
               (B,N,T,C) when _format == 'BNTC'
        A: (N,N) adjacency (fixed or learnable)
        _format: layout of ``input``; the output uses the same layout
        returns: tensor with the same shape and layout as ``input``
        raises: ValueError if ``_format`` is not one of the supported layouts
        """
        # Both permutations are their own inverse, so the same tuple converts
        # to the internal (B,T,N,C) layout and back.
        if _format == 'BCNT':
            perm = (0, 3, 2, 1)
        elif _format == 'BNTC':
            perm = (0, 2, 1, 3)
        else:
            raise ValueError(
                f"Unsupported _format {_format!r}; expected 'BCNT' or 'BNTC'."
            )

        x_BTNC = input.permute(*perm).contiguous()

        U, P, A_proc = self.attention(x_BTNC, A)
        out_BTNC = self.forward_lifting_bases(x_BTNC, P, U, A_proc)

        return out_BTNC.permute(*perm).contiguous()

In [5]:
# Seed the global RNG so every random tensor and parameter initialisation created
# from this cell onwards is reproducible after Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

N = 1212  # number of nodes
# Random stand-in for a learnable adjacency; requires_grad=True so that the
# backward pass further down populates A.grad.
A = torch.randn(N, N, requires_grad=True)

In [6]:
# Demo dimensions: batch size, time steps, channels (N comes from the cell above).
B, T, C = 16, 12, 40

# Fixed-alpha lifting layer whose feature width matches the demo channel count.
layer = noflayer(nnode=N, in_features=C, out_features=C, hop=2, alpha=0.5)

# Random inputs in (B, N, T, C) layout; both track gradients.
x = torch.randn((B, N, T, C)).requires_grad_()
h0 = torch.randn((B, N, T, C)).requires_grad_()

In [7]:
class GraphLearner(nn.Module):
    """
    Builds a per-sample adjacency from node features and a static road graph.

    Node features are averaged over time, linearly projected, mixed with their
    ``A_road``-propagated version, and compared pairwise through a learnable
    bilinear form. Each row of the result is a softmax, so rows sum to 1.
    """

    def __init__(self, in_dim, embed_dim):
        super().__init__()
        self.fc = nn.Linear(in_dim, embed_dim)
        # Bilinear similarity matrix, Xavier-initialised.
        self.attn = nn.Parameter(torch.empty(embed_dim, embed_dim))
        nn.init.xavier_uniform_(self.attn)

    def forward(self, X, A_road):
        """
        X: (B, T, N, C)   traffic states
        A_road: (N, N)    static road adjacency (used as given, no normalisation)
        return: (B, N, N) adaptive adjacency per batch
        """
        # Unpacking documents the layout and rejects inputs that are not 4-D.
        batch, steps, nodes, chans = X.shape

        # Average over the time axis, then project: (B, N, C) -> (B, N, embed_dim).
        node_feat = self.fc(X.mean(dim=1))

        # Add the neighbourhood-propagated embedding; A_road broadcasts over batch.
        embed = node_feat + A_road @ node_feat

        # Bilinear pairwise scores embed * attn * embed^T -> (B, N, N).
        scores = embed @ self.attn @ embed.transpose(1, 2)

        # Negative scores are clipped to zero before the row-wise softmax.
        return torch.softmax(torch.relu(scores), dim=-1)

In [12]:
# Learn an adaptive adjacency from random (B, T, N, C) traffic states, using the
# random matrix A as the static road graph.
gl = GraphLearner(in_dim=40, embed_dim=40)
_x = torch.randn((B, T, N, C)).requires_grad_()
a_adapt = gl(_x, A)

In [13]:
# Sanity check: expect (B, N, N), i.e. one adaptive adjacency per batch sample.
a_adapt.shape

torch.Size([16, 1212, 1212])

In [14]:
# Gradient smoke test: run the lifting layer on the (B, N, T, C) input and
# backpropagate a scalar so .grad gets populated on both x and A.
out = layer(x, A, _format='BNTC')  # forward pass; adjacency is supplied per call
loss = out.sum()                       # reduce to a scalar so backward() needs no argument
loss.backward()                        # accumulates gradients into x.grad and A.grad

In [16]:
# Both tensors were created with requires_grad=True and feed the loss above,
# so neither gradient should be None.
print("Adjacency grad:", A.grad)      # should not be None
print("x grad:", x.grad)              # check input grad as well
# h0 is never passed to the layer in the forward pass, so its gradient is not inspected.

Adjacency grad: tensor([[-0.0049, -0.0092, -0.0036,  ..., -0.0087, -0.0082, -0.0081],
        [-0.0091, -0.0084, -0.0090,  ..., -0.0102, -0.0090, -0.0084],
        [-0.0100, -0.0092, -0.0069,  ..., -0.0037, -0.0090, -0.0078],
        ...,
        [-0.0089, -0.0097, -0.0085,  ..., -0.0103, -0.0061, -0.0092],
        [-0.0070, -0.0080, -0.0083,  ..., -0.0085, -0.0084, -0.0068],
        [-0.0082, -0.0081, -0.0085,  ..., -0.0072, -0.0087, -0.0073]])
x grad: tensor([[[[ 5.1203e-01,  6.1127e-01,  5.5338e-01,  ...,  4.9528e-01,
            6.8351e-01,  5.4603e-01],
          [ 6.6421e-01,  4.5335e-01,  5.5989e-01,  ...,  7.6277e-01,
            3.5063e-01,  6.0671e-01],
          [ 2.1107e+00,  1.2880e+00,  1.6410e+00,  ...,  2.7350e+00,
            1.0806e+00,  1.9425e+00],
          ...,
          [ 6.5839e-01,  4.2865e-01,  5.4012e-01,  ...,  7.8343e-01,
            3.3098e-01,  5.9987e-01],
          [ 3.7596e-01,  3.4178e-01,  3.5605e-01,  ...,  4.0342e-01,
            3.3440e-01,  3.693