In [1]:
import numpy as np
import torch
import torch.nn as nn
from torchsummary import summary

from moduleZoo.graphs import MultiHeadSelfGraphAttentionLinear

In [2]:
class TargetModel(nn.Module):
    """Adapter that feeds a batched tensor to a graph-attention layer.

    ``forward`` receives a ``(batch, nodes, features)`` tensor, merges the
    first two axes into a single node axis, calls the wrapped
    ``MultiHeadSelfGraphAttentionLinear`` together with the per-graph node
    counts, and folds the result back into one slice per batch sample.
    """

    def __init__(self, db: bool = False) -> None:
        """Build the wrapped attention layer.

        Args:
            db: forwarded to the wrapped module as ``dynamic_batching``.
        """
        super().__init__()
        self.module = MultiHeadSelfGraphAttentionLinear(128, 256, n_heads=2, residual=True, dynamic_batching=db)
        # Node count of every graph inside ONE batch sample: 128 single-node
        # graphs followed by 64 two-node graphs (256 nodes in total).
        self.n_nodes = [1]*(256//2) + [2]*(256//4)

    def enable_dynamic_batching(self) -> None:
        """Set the wrapped module's ``db`` attribute to ``True``.

        NOTE(review): presumably ``db`` is the flag the wrapped module reads
        for dynamic batching -- confirm against moduleZoo.
        """
        self.module.db = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped module on every sample of ``x`` in one call.

        Args:
            x: tensor whose first axis is the batch and whose second axis
                enumerates the nodes of one sample.

        Returns:
            The module output with its leading axis split back into
            ``(batch, rows_per_sample, ...)``.
        """
        # The batch size is taken from the input instead of being fixed at 2,
        # so the node-count list always matches the number of stacked samples.
        batch_size = x.shape[0]
        # (batch, nodes, ...) -> (batch * nodes, ...)
        flat = x.flatten(0, 1)
        out = self.module(flat, self.n_nodes * batch_size)
        # NOTE(review): assumes the module's leading axis is divisible by the
        # batch size (one row per input node) -- confirm against moduleZoo.
        return out.reshape(batch_size, -1, *out.shape[1:])

In [3]:
# Baseline wrapper: dynamic batching switched off.
module1 = TargetModel(db=False)

In [4]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA.
module1 = module1.to('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
# Seed so the random probe batch is identical on every Restart & Run All.
torch.manual_seed(0)
# Use the GPU when one is present; otherwise stay on the CPU so this cell
# does not crash on CUDA-less machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Two samples, 256 nodes each, 128 features per node.
data = torch.rand((2, 256, 128), device=device)
# Per-graph node counts for one sample (128 one-node graphs followed by
# 64 two-node graphs), repeated for the two samples in `data`.
n_nodes = np.array(([1]*(256//2) + [2]*(256//4)) * 2)

In [6]:
# Forward pass of the wrapper on the random (2, 256, 128) batch.
result1 = module1(data)

In [7]:
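# NOTE: disabled check -- `result2` is not defined in the cells shown here
# (presumably the output of a second model); define it before re-enabling.
# torch.all((result1.detach().cpu() - result2.detach().cpu()).abs() < 1e-6)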

In [8]:
# `summary` prints the table itself; the trailing `;` stops Jupyter from also
# displaying the returned object, which otherwise shows the table twice.
summary(module1, [256, 128]);

Layer (type:depth-idx)                        Output Shape              Param #
├─MultiHeadSelfGraphAttentionLinear: 1-1      [-1, 2, 256]              --
|    └─Linear: 2-1                            [-1, 512]                 66,048
|    └─Linear: 2-2                            [-1, 512]                 66,048
|    └─Linear: 2-3                            [-1, 512]                 66,048
|    └─Softmax: 2-4                           [-1, 256, 256]            --
|    └─Linear: 2-5                            [-1, 1, 256]              32,768
|    └─Softmax: 2-6                           [-1, 256, 256]            --
|    └─Linear: 2-7                            [-1, 1, 256]              (recursive)
|    └─Softmax: 2-8                           [-1, 256, 256]            --
|    └─Linear: 2-9                            [-1, 1, 256]              (recursive)
|    └─Softmax: 2-10                          [-1, 256, 256]            --
|    └─Linear: 2-11                           [-1, 1, 256]   

Layer (type:depth-idx)                        Output Shape              Param #
├─MultiHeadSelfGraphAttentionLinear: 1-1      [-1, 2, 256]              --
|    └─Linear: 2-1                            [-1, 512]                 66,048
|    └─Linear: 2-2                            [-1, 512]                 66,048
|    └─Linear: 2-3                            [-1, 512]                 66,048
|    └─Softmax: 2-4                           [-1, 256, 256]            --
|    └─Linear: 2-5                            [-1, 1, 256]              32,768
|    └─Softmax: 2-6                           [-1, 256, 256]            --
|    └─Linear: 2-7                            [-1, 1, 256]              (recursive)
|    └─Softmax: 2-8                           [-1, 256, 256]            --
|    └─Linear: 2-9                            [-1, 1, 256]              (recursive)
|    └─Softmax: 2-10                          [-1, 256, 256]            --
|    └─Linear: 2-11                           [-1, 1, 256]   