In [1]:
# Standard library
import os
# Set before `transformers` is imported further down in this cell.
# NOTE(review): presumably this silences the HF tokenizers fork/parallelism warning -- confirm.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
# Third party
import torch
import lightning as L
# Seed torch's global RNG so parameter initialisation below is reproducible.
torch.manual_seed(0)
# NOTE(review): BertModel / AutoModel are not referenced in the cells visible here.
from transformers import BertModel, AutoModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from typing import Tuple, Literal

class BasisLinear(L.LightningModule):
    """
    A family of linear projections that share a small set of basis matrices.

    Every projection (one per entry of the ``n_projections`` grid) is a learned
    linear combination of the same ``n_bases`` bias-free ``Linear`` layers.
    """

    def __init__(self,
                 in_features: Tuple[int],
                 out_features: int,
                 n_bases: int,
                 n_projections: Tuple[int]):
        """
        Defines a collection of N linear transformations
        that are all parameterized by a shared set of B basis matrices

        Args:
            in_features:
                Expected trailing shape of the input. The last entry is the
                feature size fed to each basis; any preceding entries are
                compared (right-aligned) against ``n_projections``.
            out_features:
                Output size of every basis / projection.
            n_bases:
                Number of shared basis matrices (B).
            n_projections:
                Shape of the grid of projections (its product is N).
        """
        super().__init__()

        self.__in_features = tuple(in_features)
        self.__out_features = out_features
        self.__n_bases = n_bases
        self.__n_projections = tuple(n_projections)

        # Trailing input dims (excluding the feature dim) that already line up with
        # trailing projection dims. Scanning right-to-left and stopping at the first
        # mismatch keeps the matched dims contiguous, which forward()'s reshaping
        # relies on (it slices them off the end of x.shape).
        matched_dims = []
        for (in_dim, proj_dim) in zip(reversed(self.__in_features[:-1]), reversed(self.__n_projections)):
            if in_dim != proj_dim:
                break
            matched_dims.insert(0, proj_dim)
        self.__matched_dims = matched_dims

        # We could do one layer with stacked weight matrices, but that will also affect the fan-in values for the Xavier initializer...
        # ModuleList (rather than a plain list) registers the bases as sub-modules, so their
        # weights show up in .parameters() / state_dict() and follow .to(device).
        self.__bases = torch.nn.ModuleList(
            torch.nn.Linear(self.__in_features[-1], self.__out_features, bias=False)
            for _ in range(n_bases)
        )

        # (*n_projections, n_bases): mixing weight of every basis for every projection
        coefficient_data = torch.empty(*self.__n_projections, self.__n_bases)
        torch.nn.init.xavier_normal_(coefficient_data) # FIXME: Better initializer to use for this?
        self.__coefficients = torch.nn.parameter.Parameter(coefficient_data)

    def forward(self, x: torch.Tensor):
        """
        Apply every projection to ``x``.

        Args:
            x:
                Tensor of shape (..., *in_features)

        Returns:
            `torch.Tensor`:
                Tensor of dimensions (*leading_dims, *n_projections, out_features),
                where ``leading_dims`` are the dims of ``x`` that precede the matched
                projection dims and the feature dim.
        """
        assert tuple(x.shape[-len(self.__in_features):]) == self.__in_features

        # TODO: Could do this with a single linear layer--more efficient
        # (*x.shape[:-1], num_bases, out_features)
        basis_vecs = torch.stack([basis(x) for basis in self.__bases], dim=-2)

        n_matched = len(self.__matched_dims)
        leading_dims = x.shape[:-n_matched - 1]
        n_unmatched = len(self.__n_projections) - n_matched

        basis_shape = (
            *leading_dims,

            # Broadcast over unmatched projection dims
            *(1 for _ in range(n_unmatched)),
            *self.__matched_dims,

            self.__n_bases,
            self.__out_features
        )
        basis_vecs = basis_vecs.view(*basis_shape)

        coeff_shape = (
            # Broadcast over the leading dims of x
            *(1 for _ in leading_dims),

            *self.__n_projections,

            self.__n_bases,
            1                # Broadcast over out_features
        )
        coeff_view = self.__coefficients.view(*coeff_shape)

        # Weight each basis output by its coefficient, then sum out the basis dim.
        multiplied = coeff_view * basis_vecs

        # (*leading_dims, *n_projections, out_features)
        summed = torch.sum(multiplied, dim=-2)
        return summed


In [3]:
def flatten_until(tensor: torch.Tensor, end_dim=-1):
    """
    Collapse sparse dims ``0..end_dim`` (inclusive) of a sparse COO tensor into one.

    The merged index is the row-major linear index over the collapsed dims, so the
    result matches reshaping the dense equivalent to
    ``(prod(shape[:end_dim + 1]), *shape[end_dim + 1:])``.

    Args:
        tensor:
            Sparse COO tensor. It is coalesced first; the input is never modified.
        end_dim:
            Last sparse dim to merge. Negative values count back from the last
            sparse dim. Must resolve to a value greater than 0.

    Returns:
        `torch.Tensor`:
            A new sparse COO tensor with the leading dims merged.
    """
    # .indices() is only defined for coalesced tensors
    tensor = tensor.coalesce()
    orig_indices = tensor.indices()

    if end_dim < 0:
        # Python-style negative indexing: -1 is the last sparse dim
        end_dim = len(orig_indices) + end_dim
    assert end_dim > 0

    new_dim_size = tensor.shape[end_dim]
    # Clone: orig_indices[end_dim] is a view into the tensor's own index storage,
    # and the accumulation below happens in place.
    cum_indices = orig_indices[end_dim].clone()
    for dim in range(end_dim - 1, -1, -1):
        # Stride of `dim` is the product of the sizes of the dims after it (up to end_dim)
        cum_indices += orig_indices[dim] * new_dim_size
        new_dim_size *= tensor.shape[dim]

    new_indices = torch.concatenate([
        cum_indices.unsqueeze(0),
        orig_indices[end_dim + 1:]
    ])
    sparse_tense = torch.sparse_coo_tensor(
        indices=new_indices,
        values=tensor.values(),
        size=(new_dim_size, *tensor.shape[end_dim + 1:])
    )
    return sparse_tense

In [4]:

class RGATAttentionHead(torch.nn.Module):
    """
    Multi-head relational graph attention over a sparse, typed adjacency tensor.

    Node states are projected once per (relation, head) with basis-shared weights,
    attention logits are computed per edge as a query/key dot product, normalised
    with a sparse softmax, and the attended values are summed into each edge's
    first ("source") node.
    """

    def __init__(self,
                 in_features: int,
                 attention_units: int,
                 out_features: int,
                 n_heads: int,
                 n_relations: int,
                 n_bases: int,
                 attention_mode: Literal['argat', 'wirgat'] = 'wirgat'):
        """
        Args:
            in_features:
                Size of the incoming node feature vectors (F from RGAT).
            attention_units:
                Size of the query and key vectors (D from RGAT).
            out_features:
                Total output feature size (F'); split evenly across heads, so it
                must be divisible by ``n_heads``.
            n_heads:
                Number of attention heads (K).
            n_relations:
                Number of relation types in the adjacency tensor.
            n_bases:
                Number of shared basis matrices used by each ``BasisLinear``.
            attention_mode:
                ``'wirgat'`` normalises attention separately for each relation
                (softmax over neighbour nodes only); ``'argat'`` normalises over
                all (neighbour node, relation) pairs jointly.
        """
        super().__init__()
        assert out_features % n_heads == 0
        assert attention_mode in {'argat', 'wirgat'}

        self.__in_features = in_features                          # F from RGAT
        self.__attention_units = attention_units                  # D from RGAT
        self.__chunked_features = out_features // n_heads         # F'/K

        self.__n_bases = n_bases
        self.__n_relations = n_relations
        self.__n_heads = n_heads
        self.__attention_mode = attention_mode

        # Value projection: one (in_features -> chunked_features) map per (relation, head)
        self.__projection = BasisLinear(
            in_features=(self.__in_features,),
            out_features=self.__chunked_features,
            n_bases=self.__n_bases,
            n_projections=[self.__n_relations, self.__n_heads]
        )

        # The query and key projections share the same bases
        # (a single projection of width 2 * attention_units that forward() splits in half)
        self.__qk_proj = BasisLinear(
            in_features=(self.__n_relations, self.__n_heads, self.__chunked_features),
            out_features=self.__attention_units * 2,
            n_bases=self.__n_bases,
            n_projections=[self.__n_relations, self.__n_heads]
        )

    def forward(self, node_states, edges):
        """
        Args:
            node_states:
                Dense tensor of shape (batch, nodes, features)
            edges:
                Sparse COO tensor of shape (batch, nodes, nodes, relations).
                Only its indices are used; every stored entry is treated as an edge.

        Returns:
            `torch.Tensor`:
                Tensor of shape (batch, nodes, out_features)
        """
        (batch_size, n_nodes) = node_states.shape[:-1]

        # (batch, nodes, relations, heads, chunked_features)
        V = self.__projection(node_states)

        # (batch, nodes, relations, heads, 2 * attention_units)
        QK = self.__qk_proj(V)
        # 2 x (batch, nodes, relations, heads, attention_units)
        Q, K = torch.split(QK, self.__attention_units, dim=-1)

        # .indices() needs a coalesced tensor. Coalescing also sorts the edges
        # lexicographically, which keeps softmax(...).values() below aligned with
        # edge_indices.
        edges = edges.coalesce()
        # (4, total_edges): rows are (batch, source node, dest node, relation)
        edge_indices = edges.indices()
        n_edges = edge_indices.shape[1]

        # (total_edges, num_heads, attention_units)
        Q_prime = Q[edge_indices[0], edge_indices[1], edge_indices[3], :, :]
        K_prime = K[edge_indices[0], edge_indices[2], edge_indices[3], :, :]
        # (total_edges, num_heads)
        logits = (Q_prime * K_prime).sum(dim=-1)

        if self.__attention_mode == 'wirgat':
            sparse_logits = torch.sparse_coo_tensor(
                indices=edge_indices, values=logits,
                size=(batch_size, n_nodes, n_nodes, self.__n_relations, self.__n_heads)
            )
            # Compute a separate probability distribution for each relation
            # (dim=-3 is the dest-node axis)
            sparse_attention = torch.sparse.softmax(sparse_logits, dim=-3)
            # To validate the distributions while debugging, check that
            # torch.sum(sparse_attention.to_dense(), -3) is within ~1e-6 of 1 wherever edges exist.
        else:
            # Fold (dest node, relation) into one axis so a single softmax spans both
            mask_indices = torch.stack([
                edge_indices[0],
                edge_indices[1],
                edge_indices[2] * self.__n_relations + edge_indices[3],
            ], dim=0)
            # NOTE(review): requires_grad=True looks redundant when `logits` is already part of
            # the autograd graph -- confirm, then drop it for parity with the 'wirgat' branch.
            sparse_logits = torch.sparse_coo_tensor(
                indices=mask_indices, values=logits,
                size=(batch_size, n_nodes, n_nodes * self.__n_relations, self.__n_heads),
                requires_grad=True
            )
            sparse_attention = torch.sparse.softmax(sparse_logits, dim=-2)
            # To validate the distributions while debugging, check that
            # torch.sum(sparse_attention.to_dense(), -2) is within ~1e-6 of 1 wherever edges exist.
        # (total_edges, num_heads)
        edge_attentions = sparse_attention.values()

        # (batch, dest_nodes, relations, heads, chunk) --> (edges, heads, chunks)
        V_prime = V[edge_indices[0], edge_indices[2], edge_indices[3]]
        # (edges, heads, chunks)
        elwise_prod = torch.unsqueeze(edge_attentions, -1) * V_prime
        # (edges, out_features)
        elwise_prod = torch.flatten(elwise_prod, start_dim=1)

        # (batch_size*source_nodes, edges): one-hot map from each edge to its flattened
        # (batch, source node) row. Built on the inputs' device and dtype so the sparse
        # matmul below also works off-CPU and with non-default float types.
        edge_mask = torch.sparse_coo_tensor(
            indices=torch.stack([
                edge_indices[0] * n_nodes + edge_indices[1],
                torch.arange(n_edges, device=edge_indices.device),
            ]),
            values=torch.ones(n_edges, dtype=elwise_prod.dtype, device=elwise_prod.device),
            size=(batch_size * n_nodes, n_edges),
            check_invariants=True
        )

        # (batch_size * source_nodes, out_features): sum the attended values of each node's edges
        flat_node_states = torch.sparse.mm(edge_mask, elwise_prod)

        # (batch_size, source_nodes, out_features)
        node_states = flat_node_states.view(batch_size, n_nodes, -1)
        return node_states


In [9]:
# Re-seed the global RNG so the layer's weight initialisation is reproducible.
# NOTE(review): the hyper-parameters passed below are assigned in a cell that appears
# later in this notebook (In [6]); on a fresh kernel that cell has to run first.
torch.manual_seed(0)
model = RGATAttentionHead(
    in_features=in_features,
    attention_units=attention_units,
    out_features=out_features,
    n_heads=n_heads,
    n_relations=n_relations,
    n_bases=n_bases,
    attention_mode='argat',
)

In [10]:
# Run the layer once on the random graph batch and display the output shape.
# NOTE(review): random_features / random_adj are created in cells that appear later
# in this notebook (In [7] / In [8]); on a fresh kernel those have to run first.
output = model(random_features, random_adj)
output.shape


torch.Size([5, 10, 264])

In [139]:
# Reduce the layer output to a scalar and backpropagate through the layer.
# NOTE(review): the recorded output of this cell prints "Logits grad: ...", but nothing in
# the visible code prints it -- presumably a debugging hook that has since been removed.
loss = output.sum()
loss.backward()

Logits grad: tensor([[-1.0378e+00, -6.7509e-01, -2.8978e-01, -6.1378e-01, -2.2774e-01,
          2.9163e-01],
        [-1.7269e+00,  6.4099e-02, -1.3452e+00, -1.7433e-01, -2.3266e+00,
         -5.7234e-01],
        [ 3.0749e+00, -2.2655e+00,  2.4762e-01, -1.6088e+00, -9.8352e-01,
         -1.7201e-01],
        ...,
        [ 6.3667e-01,  6.4911e-01,  7.7401e-04, -3.0992e-01,  1.5062e-01,
         -8.9753e-01],
        [ 2.5477e-01,  6.1890e-01,  1.5265e-01,  8.2389e-02, -5.5311e-01,
          1.5517e-01],
        [-1.2211e-01,  2.1029e-01, -1.0515e-02,  2.1513e-02, -3.0218e-01,
         -4.7483e-01]])


In [6]:
# Shape of the random graph batch
batch_size: int = 5
max_nodes: int = 10

# Layer sizes (out_features must be divisible by n_heads)
in_features: int = 123
attention_units: int = 53
out_features: int = 264

# Attention / basis-decomposition settings
n_heads: int = 6
n_relations: int = 7
n_bases: int = 3

In [7]:
# Dedicated generator so the features do not depend on the global RNG state
gen = torch.Generator()
gen.manual_seed(1)
# Standard-normal draws, shifted by -0.5 and then scaled by 5
random_features = torch.randn((batch_size, max_nodes, in_features), generator=gen)
random_features = 5 * (random_features - .5)
random_features.shape

torch.Size([5, 10, 123])

In [8]:
# Dedicated generator so the adjacency does not depend on the global RNG state
gen = torch.Generator()
gen.manual_seed(2)
# Random 0/1 entries of shape (batch, nodes, nodes, relations); only the non-zero
# entries are kept once the tensor is converted to sparse COO.
random_adj = torch.randint(
    low=0,
    high=2,
    size=(batch_size, max_nodes, max_nodes, n_relations),
    generator=gen,
)
random_adj = random_adj.to_sparse()
random_adj

tensor(indices=tensor([[0, 0, 0,  ..., 4, 4, 4],
                       [0, 0, 0,  ..., 9, 9, 9],
                       [0, 0, 0,  ..., 9, 9, 9],
                       [1, 2, 5,  ..., 3, 4, 6]]),
       values=tensor([1, 1, 1,  ..., 1, 1, 1]),
       size=(5, 10, 10, 7), nnz=1779, layout=torch.sparse_coo)