In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from utils.tgcn import ConvTemporalGraphical
from utils.graph import Graph
from torchinfo import summary

class Model(nn.Module):
    r"""Spatial temporal graph convolutional networks.

    Args:
        in_channels (int): Number of channels in the input data
        num_class (int): Number of classes for the classification task
        graph_args (dict): The arguments for building the graph
        edge_importance_weighting (bool): If ``True``, adds a learnable
            importance weighting to the edges of the graph
        **kwargs (optional): Other parameters for graph convolution units

    Shape:
        - Input: :math:`(N, in_channels, T_{in}, V_{in}, M_{in})`
        - Output: :math:`(N, num_class)` where
            :math:`N` is a batch size,
            :math:`T_{in}` is a length of input sequence,
            :math:`V_{in}` is the number of graph nodes,
            :math:`M_{in}` is the number of instance in a frame.
    """

    def __init__(self, in_channels, num_class, graph_args,
                 edge_importance_weighting, **kwargs):
        super().__init__()

        # load graph
        self.graph = Graph(**graph_args)
        A = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False)
        self.register_buffer('A', A)

        # build networks
        spatial_kernel_size = A.size(0)
        temporal_kernel_size = 9
        kernel_size = (temporal_kernel_size, spatial_kernel_size)
        self.data_bn = nn.BatchNorm1d(in_channels * A.size(1))
        kwargs0 = {k: v for k, v in kwargs.items() if k != 'dropout'}
        self.st_gcn_networks = nn.ModuleList((
            st_gcn(in_channels, 64, kernel_size, 1, residual=False, **kwargs0),
            st_gcn(64, 64, kernel_size, 1, **kwargs),
            st_gcn(64, 64, kernel_size, 1, **kwargs),
            st_gcn(64, 64, kernel_size, 1, **kwargs),
            st_gcn(64, 128, kernel_size, 2, **kwargs),
            st_gcn(128, 128, kernel_size, 1, **kwargs),
            st_gcn(128, 128, kernel_size, 1, **kwargs),
            st_gcn(128, 256, kernel_size, 2, **kwargs),
            st_gcn(256, 256, kernel_size, 1, **kwargs),
            st_gcn(256, 256, kernel_size, 1, **kwargs),
        ))

        # initialize parameters for edge importance weighting
        if edge_importance_weighting:
            self.edge_importance = nn.ParameterList([
                nn.Parameter(torch.ones(self.A.size()))
                for i in self.st_gcn_networks
            ])
        else:
            self.edge_importance = [1] * len(self.st_gcn_networks)

        # fcn for prediction
        self.fcn = nn.Conv2d(256, num_class, kernel_size=1)

    def forward(self, x):

        # data normalization
        N, C, T, V, M = x.size()
        x = x.permute(0, 4, 3, 1, 2).contiguous()
        x = x.view(N * M, V * C, T)
        x = self.data_bn(x)
        x = x.view(N, M, V, C, T)
        x = x.permute(0, 1, 3, 4, 2).contiguous()
        x = x.view(N * M, C, T, V)

        # forwad
        for gcn, importance in zip(self.st_gcn_networks, self.edge_importance):
            x, _ = gcn(x, self.A * importance)

        # global pooling
        x = F.avg_pool2d(x, x.size()[2:])
        x = x.view(N, M, -1, 1, 1).mean(dim=1)

        # prediction
        x = self.fcn(x)
        x = x.view(x.size(0), -1)

        return x

    def extract_feature(self, x):

        # data normalization
        N, C, T, V, M = x.size()
        x = x.permute(0, 4, 3, 1, 2).contiguous()
        x = x.view(N * M, V * C, T)
        x = self.data_bn(x)
        x = x.view(N, M, V, C, T)
        x = x.permute(0, 1, 3, 4, 2).contiguous()
        x = x.view(N * M, C, T, V)

        # forwad
        for gcn, importance in zip(self.st_gcn_networks, self.edge_importance):
            x, _ = gcn(x, self.A * importance)

        _, c, t, v = x.size()
        feature = x.view(N, M, c, t, v).permute(0, 2, 3, 4, 1)

        # prediction
        x = self.fcn(x)
        output = x.view(N, M, -1, t, v).permute(0, 2, 3, 4, 1)

        return output, feature

class st_gcn(nn.Module):
    r"""Applies a spatial temporal graph convolution over an input graph sequence.

    Args:
        in_channels (int): Number of channels in the input sequence data
        out_channels (int): Number of channels produced by the convolution
        kernel_size (tuple): Size of the temporal convolving kernel and graph convolving kernel
        stride (int, optional): Stride of the temporal convolution. Default: 1
        dropout (int, optional): Dropout rate of the final output. Default: 0
        residual (bool, optional): If ``True``, applies a residual mechanism. Default: ``True``

    Shape:
        - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format
        - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format
        - Output[0]: Outpu graph sequence in :math:`(N, out_channels, T_{out}, V)` format
        - Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format

        where
            :math:`N` is a batch size,
            :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`,
            :math:`T_{in}/T_{out}` is a length of input/output sequence,
            :math:`V` is the number of graph nodes.

    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 dropout=0,
                 residual=True):
        super().__init__()

        assert len(kernel_size) == 2
        assert kernel_size[0] % 2 == 1
        padding = ((kernel_size[0] - 1) // 2, 0)

        self.gcn = ConvTemporalGraphical(in_channels, out_channels,
                                         kernel_size[1])

        self.tcn = nn.Sequential(
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                out_channels,
                out_channels,
                (kernel_size[0], 1),
                (stride, 1),
                padding,
            ),
            nn.BatchNorm2d(out_channels),
            nn.Dropout(dropout, inplace=True),
        )

        if not residual:
            self.residual = lambda x: 0

        elif (in_channels == out_channels) and (stride == 1):
            self.residual = lambda x: x

        else:
            self.residual = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=(stride, 1)),
                nn.BatchNorm2d(out_channels),
            )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, A):

        res = self.residual(x)
        x, A = self.gcn(x, A)
        x = self.tcn(x) + res

        return self.relu(x), A
    

In [81]:
# in_channels (int): Number of channels in the input data
#     num_class (int): Number of classes for the classification task
#     graph_args (dict): The arguments for building the graph
#     edge_importance_weighting (bool): If ``True``, adds a learnable
#         importance weighting to the edges of the graph
#     **kwargs (optional): Other parameters for graph convolution units
a = Model(8, 20, graph_args = {"layout": "openpose"}, edge_importance_weighting=False)
summary(a, input_size=(1, 8, 80, 18, 1), col_names=["input_size", "output_size", "num_params"])

Layer (type:depth-idx)                        Input Shape               Output Shape              Param #
Model                                         [1, 8, 80, 18, 1]         [1, 20]                   --
├─BatchNorm1d: 1-1                            [1, 144, 80]              [1, 144, 80]              288
├─ModuleList: 1-2                             --                        --                        --
│    └─st_gcn: 2-1                            [1, 8, 80, 18]            [1, 64, 80, 18]           --
│    │    └─ConvTemporalGraphical: 3-1        [1, 8, 80, 18]            [1, 64, 80, 18]           576
│    │    └─Sequential: 3-2                   [1, 64, 80, 18]           [1, 64, 80, 18]           37,184
│    │    └─ReLU: 3-3                         [1, 64, 80, 18]           [1, 64, 80, 18]           --
│    └─st_gcn: 2-2                            [1, 64, 80, 18]           [1, 64, 80, 18]           --
│    │    └─ConvTemporalGraphical: 3-4        [1, 64, 80, 18]           [1, 64, 

In [3]:
import torch
import torch.nn as nn

In [44]:
class AttentionLayer(nn.Module):
    def __init__(self, in_features : int, out_features : int,
                 n_heads : int, is_concat : bool = True, 
                 dropout : float = 0.6, 
                 leaky_relu_negative_slope : float = 0.2):
        
        super().__init__()
        self.is_concat = is_concat
        self.n_heads = n_heads
        
        if is_concat:
            assert out_features % n_heads == 0
            self.n_hidden = out_features // n_heads
        else:
            self.n_hidden = out_features

        self.linear = nn.Linear(in_features, self.n_hidden * n_heads, bias = False)
        self.attn = nn.Linear(self.n_hidden * 2, 1, bias = False)
        self.activation = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)
        self.softmax = nn.Softmax(dim = 3)
        self.dropout = nn.Dropout(dropout)

    def forward(self, h:torch.Tensor, adj_mat: torch.Tensor):
        # h shape (n_nodes, in_features)
        # adj_mat shape (n_nodes, n_nodes, n_heads)
        # h moi (batch_size, time, n_nodes, in_features)
        # adj_mat (n_nodes, n_nodes, n_heads)
        batch_size, time, n_nodes, in_features = h.shape
        n_nodes = h.shape[2]
        # (batch_size, time, n_nodes, in_features) -> (batch_size, time, n_nodes, n_heads, n_hidden)
        g = self.linear(h).view(batch_size, time, n_nodes, self.n_heads, self.n_hidden)
        # print(g.shape)
        # (batch_size, time, n_nodes, n_heads, n_hidden) -> (batch_size, time, n_nodes * n_nodes, n_heads, n_hidden)
        g_repeat = g.repeat(1, 1, n_nodes, 1, 1)
        # print(g_repeat.shape)
        # (batch_size, time, n_nodes, n_heads, n_hidden) -> (batch_size, time, n_nodes * n_nodes, n_heads, n_hidden)
        g_repeat_interleave = g.repeat_interleave(n_nodes, dim = 2)
        # print(g_repeat_interleave.shape)
        # (batch_size, time, n_nodes, n_heads, n_hidden) -> (batch_size, time, n_nodes * n_nodes, n_heads, 2 * n_hidden)
        g_concat = torch.cat([g_repeat_interleave, g_repeat], dim = -1)
        # (batch_size, time, n_nodes * n_nodes, n_heads, 2 * n_hidden) -> (batch_size, time, n_nodes, n_nodes, n_heads, 2 * n_hidden)
        g_concat = g_concat.view(batch_size, time, n_nodes, n_nodes, self.n_heads, 2*self.n_hidden)
        
        # (batch_size, time, n_nodes, n_nodes, n_heads, 2 * n_hidden) -> (batch_size, time, n_nodes, n_nodes, n_heads, 1) 
        e = self.attn(g_concat)
        # (batch_size, time, n_nodes, n_nodes, n_heads, 1) -> (batch_size, time, n_nodes, n_nodes, n_heads, 1) 
        e = self.activation(e)
        # (batch_size, time, n_nodes, n_nodes, n_heads, 1) -> (batch_size, time, n_nodes, n_nodes, n_heads) 
        e = e.squeeze(-1)
        assert adj_mat.shape[0] == 1 or adj_mat.shape[0] == n_nodes
        assert adj_mat.shape[1] == 1 or adj_mat.shape[1] == n_nodes
        assert adj_mat.shape[2] == 1 or adj_mat.shape[2] == self.n_heads
        
        adj_mat = adj_mat.unsqueeze(0).unsqueeze(0)
        e = e.masked_fill(adj_mat == 0, float('-inf'))
        # (batch_size, time, n_nodes, n_nodes, n_heads) -> (batch_size, time, n_nodes, n_nodes, n_heads)
        a = self.softmax(e)
        # (batch_size, time, n_nodes, n_nodes, n_heads) -> (batch_size, time, n_nodes, n_nodes, n_heads)
        # ??? 
        a = self.dropout(a)

        # (batch_size, time, n_nodes, n_heads, n_hidden) * (batch_size, time, n_nodes, n_nodes, n_heads)
        # -> (batch_size, time, n_nodes, n_hidden, n_heads)
        attn_res = torch.einsum('abijh, abjhf->abihf', a, g)
        
        if self.is_concat:
            return attn_res.reshape(batch_size, time, n_nodes, self.n_heads * self.n_hidden)
        else:
            return attn_res.mean(dim = -1)

In [5]:
# shape tensor goc (N, C, T, V, M)
# (batch_size, time, n_nodes, n_heads, n_hidden)
x = torch.randn((1, 80, 25, 2, 3))
g_repeat = x.repeat(1, 1, 25, 1, 1)
print(g_repeat.shape)
g_repeat_interleave = x.repeat_interleave(25, dim = 2)
print(g_repeat_interleave.shape)
g_concat = torch.cat([g_repeat_interleave, g_repeat], dim = -1)
print(g_concat.shape)
g_concat = g_concat.view(1, 80, 25, 25, 2, 6)
print(g_concat.shape)
activation = nn.LeakyReLU(negative_slope=0.1)
y = activation(g_concat)
y.shape

torch.Size([1, 80, 625, 2, 3])
torch.Size([1, 80, 625, 2, 3])
torch.Size([1, 80, 625, 2, 6])
torch.Size([1, 80, 25, 25, 2, 6])


torch.Size([1, 80, 25, 25, 2, 6])

In [14]:
from torchinfo import summary
n_heads = 4
layer = AttentionLayer(in_features = 2, out_features = 8, n_heads = n_heads, is_concat = True)
# h moi (batch_size, time, n_nodes, in_features)
summary(layer, input_size = [(1, 80, 25, 2), (25, 25, n_heads)], col_names=["input_size", "output_size", "num_params"])

torch.Size([1, 80, 25, 4, 2])
torch.Size([1, 80, 625, 4, 2])
torch.Size([1, 80, 625, 4, 2])


Layer (type:depth-idx)                   Input Shape               Output Shape              Param #
AttentionLayer                           [1, 80, 25, 2]            [1, 80, 25, 8]            --
├─Linear: 1-1                            [1, 80, 25, 2]            [1, 80, 25, 8]            16
├─Linear: 1-2                            [1, 80, 25, 25, 4, 4]     [1, 80, 25, 25, 4, 1]     4
├─LeakyReLU: 1-3                         [1, 80, 25, 25, 4, 1]     [1, 80, 25, 25, 4, 1]     --
├─Softmax: 1-4                           [1, 80, 25, 25, 4]        [1, 80, 25, 25, 4]        --
├─Dropout: 1-5                           [1, 80, 25, 25, 4]        [1, 80, 25, 25, 4]        --
Total params: 20
Trainable params: 20
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0.00
Input size (MB): 0.03
Forward/backward pass size (MB): 1.73
Params size (MB): 0.00
Estimated Total Size (MB): 1.75

In [18]:
device = "cuda" if torch.cuda.is_available() else "cpu"
layer = AttentionLayer(in_features = 2, out_features = 8, n_heads = n_heads, is_concat = True)
torch_input1 = torch.randn(1, 80, 25, 2)
torch_input2 = torch.randn(25, 25, n_heads)
torch.onnx.export(layer, (torch_input1, torch_input2), "attentionLayer.onnx", verbose=True)


torch.Size([1, 80, 25, 4, 2])
torch.Size([1, 80, 625, 4, 2])
torch.Size([1, 80, 625, 4, 2])


  assert adj_mat.shape[0] == 1 or adj_mat.shape[0] == n_nodes
  assert adj_mat.shape[1] == 1 or adj_mat.shape[1] == n_nodes
  assert adj_mat.shape[2] == 1 or adj_mat.shape[2] == self.n_heads


verbose: False, log level: Level.ERROR



In [63]:
class Model(nn.Module):
    r"""Spatial temporal graph convolutional networks.

    Args:
        in_channels (int): Number of channels in the input data
        num_class (int): Number of classes for the classification task
        graph_args (dict): The arguments for building the graph
        edge_importance_weighting (bool): If ``True``, adds a learnable
            importance weighting to the edges of the graph
        **kwargs (optional): Other parameters for graph convolution units

    Shape:
        - Input: :math:`(N, in_channels, T_{in}, V_{in}, M_{in})`
        - Output: :math:`(N, num_class)` where
            :math:`N` is a batch size,
            :math:`T_{in}` is a length of input sequence,
            :math:`V_{in}` is the number of graph nodes,
            :math:`M_{in}` is the number of instance in a frame.
    """

    def __init__(self, in_channels, num_class, graph_args,
                 edge_importance_weighting, **kwargs):
        super().__init__()

        # load graph
        self.graph = Graph(**graph_args)
        A = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False)
        self.register_buffer('A', A)

        # build networks
        spatial_kernel_size = A.size(0)
        temporal_kernel_size = 9
        kernel_size = (temporal_kernel_size, spatial_kernel_size)
        self.data_bn = nn.BatchNorm1d(in_channels * A.size(1))
        kwargs0 = {k: v for k, v in kwargs.items() if k != 'dropout'}
        self.n_heads = 4
        self.attention_layer = AttentionLayer(in_channels, in_channels * 4, 4, is_concat=True)
        self.st_gcn_networks = nn.ModuleList((
            st_gcn(in_channels*4, 64, kernel_size, 1, residual=False, **kwargs0),
            st_gcn(64, 64, kernel_size, 1, **kwargs),
            st_gcn(64, 64, kernel_size, 1, **kwargs),
            st_gcn(64, 64, kernel_size, 1, **kwargs),
            st_gcn(64, 128, kernel_size, 2, **kwargs),
            st_gcn(128, 128, kernel_size, 1, **kwargs),
            st_gcn(128, 128, kernel_size, 1, **kwargs),
            st_gcn(128, 256, kernel_size, 2, **kwargs),
            st_gcn(256, 256, kernel_size, 1, **kwargs),
            st_gcn(256, 256, kernel_size, 1, **kwargs),
        ))

        # initialize parameters for edge importance weighting
        if edge_importance_weighting:
            self.edge_importance = nn.ParameterList([
                nn.Parameter(torch.ones(self.A.size()))
                for i in self.st_gcn_networks
            ])
        else:
            self.edge_importance = [1] * len(self.st_gcn_networks)

        # fcn for prediction
        self.fcn = nn.Conv2d(256, num_class, kernel_size=1)

    def forward(self, x):

        # data normalization
        N, C, T, V, M = x.size()
        x = x.permute(0, 4, 3, 1, 2).contiguous()
        x = x.view(N * M, V * C, T)
        x = self.data_bn(x)
        #           0, 1, 2, 3, 4
        x = x.view(N, M, V, C, T)
        x = x.permute(0, 1, 4, 2, 3).contiguous()
        x = x.view(N * M, T, V, C)

        print(x.shape)
        print(self.A.shape)
        adj_mat = self.A.detach().squeeze(dim = 0)
        print(adj_mat.shape)
        adj_mat = adj_mat.unsqueeze(dim = 2).repeat(1, 1, self.n_heads)
        print(adj_mat.shape)
        x = self.attention_layer(x, adj_mat)
        print(x.shape)
        # 0, 1, 2, 3, 4
        # [1, 80, 18, 8]
        print(f"N{N}, M{M}, T{T}, V{V}, C{C}")
        # Note C*4
        x = x.view(N, M, T, V, C*4)
        x = x.permute(0, 1, 4, 2, 3).contiguous()
        x = x.view(N * M, C*4, T, V)
        print(x.shape)
        # forwad
        for gcn, importance in zip(self.st_gcn_networks, self.edge_importance):
            x, _ = gcn(x, self.A * importance)

        # global pooling
        x = F.avg_pool2d(x, x.size()[2:])
        x = x.view(N, M, -1, 1, 1).mean(dim=1)

        # prediction
        x = self.fcn(x)
        x = x.view(x.size(0), -1)

        return x

    def extract_feature(self, x):

        # data normalization
        N, C, T, V, M = x.size()
        x = x.permute(0, 4, 3, 1, 2).contiguous()
        x = x.view(N * M, V * C, T)
        x = self.data_bn(x)
        x = x.view(N, M, V, C, T)
        x = x.permute(0, 1, 3, 4, 2).contiguous()
        x = x.view(N * M, C, T, V)

        # forwad
        for gcn, importance in zip(self.st_gcn_networks, self.edge_importance):
            x, _ = gcn(x, self.A * importance)

        _, c, t, v = x.size()
        feature = x.view(N, M, c, t, v).permute(0, 2, 3, 4, 1)

        # prediction
        x = self.fcn(x)
        output = x.view(N, M, -1, t, v).permute(0, 2, 3, 4, 1)

        return output, feature

In [65]:
# in_channels (int): Number of channels in the input data
#     num_class (int): Number of classes for the classification task
#     graph_args (dict): The arguments for building the graph
#     edge_importance_weighting (bool): If ``True``, adds a learnable
#         importance weighting to the edges of the graph
#     **kwargs (optional): Other parameters for graph convolution units
a = Model(2, 20, graph_args = {"layout": "openpose"}, edge_importance_weighting=False)
x = torch.randn(1, 2, 80, 18, 1)
y = a(x)
summary(a, input_size=(1, 2, 80, 18, 1), col_names=["input_size", "output_size", "num_params"])

torch.Size([1, 80, 18, 2])
torch.Size([1, 18, 18])
torch.Size([18, 18])
torch.Size([18, 18, 4])
torch.Size([1, 80, 18, 8])
N1, M1, T80, V18, C2
torch.Size([1, 8, 80, 18])
torch.Size([1, 80, 18, 2])
torch.Size([1, 18, 18])
torch.Size([18, 18])
torch.Size([18, 18, 4])
torch.Size([1, 80, 18, 8])
N1, M1, T80, V18, C2
torch.Size([1, 8, 80, 18])


Layer (type:depth-idx)                        Input Shape               Output Shape              Param #
Model                                         [1, 2, 80, 18, 1]         [1, 20]                   --
├─BatchNorm1d: 1-1                            [1, 36, 80]               [1, 36, 80]               72
├─AttentionLayer: 1-2                         [1, 80, 18, 2]            [1, 80, 18, 8]            --
│    └─Linear: 2-1                            [1, 80, 18, 2]            [1, 80, 18, 8]            16
│    └─Linear: 2-2                            [1, 80, 18, 18, 4, 4]     [1, 80, 18, 18, 4, 1]     4
│    └─LeakyReLU: 2-3                         [1, 80, 18, 18, 4, 1]     [1, 80, 18, 18, 4, 1]     --
│    └─Softmax: 2-4                           [1, 80, 18, 18, 4]        [1, 80, 18, 18, 4]        --
│    └─Dropout: 2-5                           [1, 80, 18, 18, 4]        [1, 80, 18, 18, 4]        --
├─ModuleList: 1-3                             --                        --             

In [49]:
x = torch.rand(1, 80, 18, 8)
y = x.view(1, 1, 80, 18, 8)