---
title: Implement hyper-dgstgcn
subtitle: DGSTGCN에 hypergraph 구조 도입
description: blank
categories: HAR/experiments
author: YeEun Hong
date: 2023-05-03
---

`pyskl-hyper/configs/dgstgcn/ntu60_xsub_3dkp/j_hyper.py`

In [None]:
# https://deephypergraph.readthedocs.io/en/latest/api/dhg.html#dhg.Hypergraph
# class dhg.Hypergraph(num_v, e_list=None, e_weight=None, v_weight=None, merge_op='mean', device=torch.device('cpu'))
# [Source](https://deephypergraph.readthedocs.io/en/latest/_modules/dhg/structure/hypergraphs/hypergraph.html#Hypergraph) 살펴보면 BaseHypergraph 상속해서 클래스 생성함을 알 수 있음

# tutorial
# https://deephypergraph.readthedocs.io/en/0.9.3/tutorial/structure.html#build-hypergraph

# Main ref
# https://deephypergraph.readthedocs.io/en/latest/tutorial/structure.html?highlight=masking#prometed-from-low-order-structures

import torch
import dhg  # Graph, Hypergraph
from dhg.models import GCN
from dhg.random import set_seed

# Hand-written hyperedges over the 25 skeleton joints. Joint indices are
# 1-based here; hyper2graph() converts them to dhg's 0-based vertex ids.
# Every hyperedge is listed exactly once (30 in total). dhg would merge a
# repeated hyperedge anyway (merge_op='mean'), so duplicates only add noise.
# NOTE(review): joint 23 does not appear in any hyperedge, so it is an
# isolated vertex of the hypergraph -- confirm this is intended.
e_list = [(3, 4, 21), (3, 4, 5), (3, 4, 9),
          (5, 9, 21), (3, 6, 21), (3, 10, 21), (2, 6, 21), (2, 10, 21),
          (9, 10, 11), (5, 6, 7), (6, 10, 21),
          (2, 11, 12), (2, 7, 8), (7, 11, 21),
          (10, 12, 24), (9, 11, 25), (6, 8, 22), (5, 7, 22), (2, 10, 24), (2, 6, 22),
          (1, 2, 17), (1, 2, 13),
          (1, 17, 18), (1, 13, 14), (14, 17, 18), (13, 14, 18),
          (18, 19, 20), (14, 15, 16), (16, 19, 20), (15, 16, 20)]

def hyper2graph(e_list=e_list, num_v=25):
    """Build the skeleton hypergraph and its star-expanded graph.

    Args:
        e_list (list[tuple[int, ...]]): Hyperedges, each a tuple of 1-based
            joint indices in ``1..num_v``. Defaults to the module-level table.
        num_v (int): Number of joints, i.e. hypergraph vertices. Defaults to 25.

    Returns:
        tuple: ``(hg, g, v_mask)``:

        - ``hg``: the ``dhg.Hypergraph`` over the joints.
        - ``g``: the ``dhg.Graph`` produced by ``dhg.Graph.from_hypergraph_star``.
          It has one node per vertex plus one per hyperedge. The dense adjacency
          is ``g.A.to_dense()``.
        - ``v_mask``: the mask returned alongside ``g`` that marks which nodes
          of ``g`` are original vertices.

    Raises:
        ValueError: If a hyperedge references an index outside ``1..num_v``.
    """
    # dhg expects 0-based vertex ids, while the table is written 1-based.
    # Check the range first so that an accidental 0 cannot silently become -1.
    zero_based = []
    for edge in e_list:
        if not all(1 <= v <= num_v for v in edge):
            raise ValueError(
                f'Hyperedge {edge} has a joint index outside 1..{num_v}.')
        zero_based.append(tuple(v - 1 for v in edge))

    hg = dhg.Hypergraph(num_v=num_v, e_list=zero_based)

    # Star expansion: every hyperedge becomes an extra node linked to its
    # member vertices. dhg.Graph.from_hypergraph_clique(hg) is a vertex-only
    # alternative, but it returns no vertex mask.
    g, v_mask = dhg.Graph.from_hypergraph_star(hg)
    return hg, g, v_mask

`pyskl-hyper/pyskl/models/gcns/dgstgcn.py`

In [None]:
# Run Code
# bash tools/dist_train.sh configs/dgstgcn/ntu60_xsub_3dkp/j_hyper.py 3 --validate --test-last --test-best


import copy as cp
import torch
import torch.nn as nn
from mmcv.runner import load_checkpoint

from ...utils import Graph, cache_checkpoint
from ..builder import BACKBONES
from .utils import dggcn, dgmstcn, unit_tcn

# Small positive epsilon. It is added before int() truncation when per-stage
# channel counts are computed, so float round-off (e.g. 63.99999 instead of 64)
# cannot lose a channel.
EPS = 1e-4

# 예시:
# modules = [DGBlock(in_channels, base_channels, A.clone(), 1, residual=False, **lw_kwargs[0])]

class DGBlock(nn.Module):
    """One stage of the backbone: ``dggcn`` -> ``dgmstcn``, plus a residual path.

    Keyword arguments prefixed with ``gcn_`` or ``tcn_`` are forwarded, with the
    prefix stripped, to the graph or temporal convolution respectively. The
    shared keys ``act``, ``norm`` and ``g1x1`` are forwarded to both.
    """

    def __init__(self, in_channels, out_channels, A, stride=1, residual=True, **kwargs):
        super().__init__()
        # Fan each shared option out to both sub-modules. A shared value
        # overwrites any explicitly prefixed value for the same option.
        for shared_key in ('act', 'norm', 'g1x1'):
            if shared_key not in kwargs:
                continue
            shared_value = kwargs.pop(shared_key)
            kwargs['tcn_' + shared_key] = shared_value
            kwargs['gcn_' + shared_key] = shared_value

        gcn_options, tcn_options, unexpected = {}, {}, {}
        for key, value in kwargs.items():
            prefix = key[:4]
            if prefix == 'gcn_':
                gcn_options[key[4:]] = value
            elif prefix == 'tcn_':
                tcn_options[key[4:]] = value
            # Anything not shaped like "?cn_*" is treated as an unknown option.
            if key[1:4] != 'cn_':
                unexpected[key] = value
        assert len(unexpected) == 0

        # Construction order (gcn, tcn, then residual) is kept stable because it
        # determines parameter initialisation order.
        self.gcn = dggcn(in_channels, out_channels, A, **gcn_options)
        self.tcn = dgmstcn(out_channels, out_channels, stride=stride, **tcn_options)

        self.relu = nn.ReLU()

        shape_preserved = in_channels == out_channels and stride == 1
        if residual and not shape_preserved:
            # A projection is needed when channels or temporal length change.
            self.residual = unit_tcn(in_channels, out_channels, kernel_size=1, stride=stride)
        elif residual:
            self.residual = lambda x: x
        else:
            self.residual = lambda x: 0

    def forward(self, x, A=None):
        """Apply gcn then tcn, add the residual branch, and finish with ReLU."""
        shortcut = self.residual(x)
        out = self.tcn(self.gcn(x, A))
        return self.relu(out + shortcut)


@BACKBONES.register_module()
class DGSTGCN(nn.Module):
    """DG-STGCN backbone running on the star expansion of a skeleton hypergraph.

    ``hyper2graph()`` builds a hypergraph over the joints and star-expands it.
    Every hyperedge becomes an extra node connected to its member joints, so the
    graph convolutions operate on ``num_joints + num_hyperedges`` nodes while the
    input only carries the joints.

    ``forward`` proceeds in three steps:

    1. Normalise the joint features with ``data_bn``.
    2. Append one node per hyperedge, initialised to the mean of its member joints.
    3. Run the DGBlock stack on the expanded node set.

    Args:
        graph_cfg (dict): Accepted for config compatibility but ignored; the
            graph comes from ``hyper2graph()``.
        in_channels (int): Input channels per joint.
        base_channels (int): Channels of the first stage.
        ch_ratio (int | float): Channel multiplier applied at each inflate stage.
        num_stages (int): Number of stages, counting the stem block.
        inflate_stages (list[int]): Stages (2..num_stages) that multiply channels.
        down_stages (list[int]): Stages (2..num_stages) that use temporal stride 2.
        data_bn_type (str): Selects the input normalisation:

            - ``'MVC'``: BatchNorm over person x joint x channel.
            - ``'VC'``: BatchNorm over joint x channel.
            - anything else: no normalisation.
        num_person (int): Persons per sample; only used by ``'MVC'``.
        pretrained (str | None): Checkpoint loaded by ``init_weights``.
        **kwargs: Per-block options for ``DGBlock``. A tuple of length
            ``num_stages`` supplies one value per stage.
    """

    def __init__(self,
                 graph_cfg,
                 in_channels=3,
                 base_channels=64,
                 ch_ratio=2,
                 num_stages=10,
                 inflate_stages=[5, 8],
                 down_stages=[5, 8],
                 data_bn_type='VC',
                 num_person=2,
                 pretrained=None,
                 **kwargs):
        super().__init__()

        # The graph is built from the hand-written hypergraph instead of
        # Graph(**graph_cfg).
        from pyskl.models.gcns.utils._hyper import hyper2graph
        self.hg, self.g, self.v_mask = hyper2graph()
        # Example for the default table:
        #   Hypergraph(num_v=25, num_e=30) -> Graph(num_v=55, num_e=90),
        #   v_mask = 25 x True (joints) followed by 30 x False (hyperedge nodes).

        # Dense adjacency of the star-expanded graph, shape (V_total, V_total).
        # to_dense() already yields a fresh tensor, so no torch.tensor() re-wrap
        # is needed.
        A = self.g.A.to_dense().float()
        joint_mask = torch.as_tensor(self.v_mask, dtype=torch.bool).reshape(-1)
        if joint_mask.numel() != A.size(-1):
            raise ValueError(
                f'v_mask has {joint_mask.numel()} entries but the adjacency '
                f'has {A.size(-1)} nodes.')
        self.num_joints = int(joint_mask.sum())

        # Row e holds the weights that average the member joints of hyperedge
        # node e. They come from the hyperedge-to-joint block of A, row-normalised.
        membership = A[~joint_mask][:, joint_mask]
        membership = membership / membership.sum(dim=1, keepdim=True).clamp(min=EPS)
        # Non-persistent buffers follow .to(device) without changing checkpoint keys.
        self.register_buffer('joint_mask', joint_mask, persistent=False)
        self.register_buffer('hyperedge_pool', membership, persistent=False)

        # NOTE(review): upstream pyskl builds A as (K, V, V) via Graph(...).A, and
        # dggcn appears to take the number of subsets from A.size(0). Add a
        # single-subset axis to match -- confirm against utils/gcn.py.
        if A.dim() == 2:
            A = A.unsqueeze(0)

        self.data_bn_type = data_bn_type
        self.kwargs = kwargs

        # Input normalisation sees only the raw joints, so it is sized by the
        # joint count (25), not the expanded node count (55). Sizing it by 55
        # caused "running_mean should contain 75 elements not 165".
        if data_bn_type == 'MVC':
            self.data_bn = nn.BatchNorm1d(num_person * in_channels * self.num_joints)
        elif data_bn_type == 'VC':
            self.data_bn = nn.BatchNorm1d(in_channels * self.num_joints)
        else:
            self.data_bn = nn.Identity()

        lw_kwargs = [cp.deepcopy(kwargs) for _ in range(num_stages)]
        for k, v in kwargs.items():
            if isinstance(v, tuple) and len(v) == num_stages:
                for i in range(num_stages):
                    lw_kwargs[i][k] = v[i]
        lw_kwargs[0].pop('tcn_dropout', None)
        lw_kwargs[0].pop('g1x1', None)
        lw_kwargs[0].pop('gcn_g1x1', None)

        self.in_channels = in_channels
        self.base_channels = base_channels
        self.ch_ratio = ch_ratio
        self.inflate_stages = inflate_stages
        self.down_stages = down_stages
        modules = []
        if self.in_channels != self.base_channels:
            modules = [DGBlock(in_channels, base_channels, A.clone(), 1, residual=False, **lw_kwargs[0])]

        inflate_times = 0
        for i in range(2, num_stages + 1):
            stride = 1 + (i in down_stages)
            in_channels = base_channels
            if i in inflate_stages:
                inflate_times += 1
            # EPS guards the int() truncation against float round-off.
            out_channels = int(self.base_channels * self.ch_ratio ** inflate_times + EPS)
            base_channels = out_channels
            modules.append(DGBlock(in_channels, out_channels, A.clone(), stride, **lw_kwargs[i - 1]))

        if self.in_channels == self.base_channels:
            num_stages -= 1

        self.num_stages = num_stages
        self.gcn = nn.ModuleList(modules)
        self.pretrained = pretrained

    def init_weights(self):
        if isinstance(self.pretrained, str):
            self.pretrained = cache_checkpoint(self.pretrained)
            load_checkpoint(self, self.pretrained, strict=False)

    def _add_hyperedge_nodes(self, x):
        """Lift joint features (N*M, C, T, V) onto the star-expanded node set.

        Joint nodes keep their features, and each hyperedge node receives the
        mean of its member joints. The result has shape (N*M, C, T, V_total).
        """
        hyper = torch.einsum('nctv,ev->ncte', x, self.hyperedge_pool.to(x.dtype))
        out = x.new_zeros(x.shape[:-1] + (self.joint_mask.numel(),))
        out[..., self.joint_mask] = x
        out[..., ~self.joint_mask] = hyper
        return out

    def forward(self, x):
        """Run the backbone.

        Args:
            x (Tensor): Input of shape (N, M, T, V, C). V must equal the number
                of hypergraph joints.

        Returns:
            Tensor: Features of shape (N, M, C', T', V_total), where V_total
            counts joints plus hyperedge nodes.
        """
        N, M, T, V, C = x.size()
        if V != self.num_joints:
            raise ValueError(f'Expected {self.num_joints} joints in the input, got {V}.')

        # (N, M, T, V, C) -> (N, M, V, C, T). contiguous() is required before
        # view(), which only reinterprets the existing memory layout.
        x = x.permute(0, 1, 3, 4, 2).contiguous()
        if self.data_bn_type == 'MVC':
            x = self.data_bn(x.view(N, M * V * C, T))
        else:
            x = self.data_bn(x.view(N * M, V * C, T))
        x = x.view(N, M, V, C, T).permute(0, 1, 3, 4, 2).contiguous().view(N * M, C, T, V)

        # The graph convolutions expect every node of the expanded graph.
        x = self._add_hyperedge_nodes(x)

        for i in range(self.num_stages):
            x = self.gcn[i](x)

        x = x.reshape((N, M) + x.shape[1:])
        return x
