In [None]:
from __future__ import print_function
import argparse
import inspect
import os
import pickle
import random
import shutil
import sys
import time
from collections import OrderedDict
import traceback
from sklearn.metrics import confusion_matrix
import csv
import numpy as np
import glob
import torch.nn.functional as F
import torch
import torch.nn as nn
import torch.optim as optim
import yaml
from tensorboardX import SummaryWriter
from tqdm import tqdm
from feeders.feeder_ntu import Feeder
from torch.optim.lr_scheduler import ReduceLROnPlateau

import torch_geometric.transforms as T
from torch_geometric.datasets import ZINC
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINEConv, global_add_pool
import inspect
from typing import Any, Dict, Optional

import torch.nn.functional as F
from torch import Tensor
from torch.nn import Dropout, Linear, Sequential

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import reset
from torch_geometric.nn.resolver import (
    activation_resolver,
    normalization_resolver,
)
from torch_geometric.typing import Adj
from torch_geometric.utils import to_dense_batch

from mamba_ssm import Mamba
from torch_geometric.utils import degree, sort_edge_index

In [None]:
def init_seed(seed):
    """Seed every random number generator used in this notebook.

    Python's ``random`` module, NumPy's global generator and PyTorch (the
    default generator plus all CUDA devices) are seeded with the same value,
    and cuDNN is switched to deterministic, non-benchmarking mode.

    Args:
        seed: Integer seed handed to each generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    # Author's note: `cudnn.enabled = True` was tried and left disabled here;
    # training was reported as too slow with the deterministic setting on.
    cudnn.deterministic = True

    # Author's note: on CUDA 11 / cuDNN 8 the default algorithm is very slow,
    # whereas on CUDA 10 the default works well, hence no benchmarking.
    cudnn.benchmark = False

In [None]:
def load_data(self=None, phase=None):
    """Build the NTU60 cross-subject ``DataLoader`` for one split.

    The function was originally declared with a leading ``self`` parameter
    although it is a plain function that callers invoke as
    ``load_data('train')`` / ``load_data('test')``.  With that signature the
    split name was bound to ``self`` and ``phase`` silently stayed at its
    default, so *every* call returned the training split.  The signature is
    kept call-compatible, but a lone string argument is now interpreted as
    the requested phase.

    Args:
        self: Legacy first positional slot.  When ``phase`` is not given and
            this is a string, it is used as the phase.
        phase: ``'train'`` selects the training split; any other value selects
            the test split.  Defaults to ``'train'`` when nothing usable is
            supplied.

    Returns:
        A ``torch.utils.data.DataLoader`` over ``Feeder`` with batch size 16
        and no worker processes.
    """
    if phase is None:
        # Single-argument calls put the split name in the first slot.
        phase = self if isinstance(self, str) else 'train'

    split = 'train' if phase == 'train' else 'test'
    return torch.utils.data.DataLoader(
        dataset=Feeder('data/ntu/NTU60_CS.npz', split=split, p_interval=[0.5, 1], window_size=32),
        batch_size=16,
        num_workers=0,
        # Only invoked inside worker processes, i.e. never while num_workers=0.
        worker_init_fn=init_seed)

In [None]:
class GPSConv(torch.nn.Module):
    r"""GPS-style layer combining a local message-passing branch with a
    global sequence branch (multi-head attention or Mamba), followed by an MLP.

    The local branch (``conv``) and the global branch each get dropout, a
    residual connection to the input and their own normalisation; their sum
    is passed through a two-layer MLP with another residual and normalisation.

    Args:
        channels: Feature size of the input and output node embeddings.
        conv: Local message-passing layer, or ``None`` to skip the local branch.
        heads: Number of attention heads (transformer branch only).
        dropout: Dropout probability applied after both branches and in the MLP.
        attn_dropout: Dropout inside ``MultiheadAttention`` (transformer only).
        act: Name of the MLP activation, resolved by ``activation_resolver``.
        att_type: ``'transformer'`` or ``'mamba'``.
        order_by_degree: Mamba only. Feed nodes to Mamba ordered by
            ``(graph, degree)``; requires ``shuffle_ind == 0``.
        shuffle_ind: Mamba only. If non-zero, average Mamba outputs over this
            many random within-graph node orderings.
        d_state: ``d_state`` forwarded to ``Mamba``.
        d_conv: ``d_conv`` forwarded to ``Mamba``.
        act_kwargs: Extra keyword arguments for the activation.
        norm: Name of the normalisation layer, or ``None`` for no normalisation.
        norm_kwargs: Extra keyword arguments for the normalisation layers.

    Raises:
        ValueError: If ``att_type`` is neither ``'transformer'`` nor ``'mamba'``.
        AssertionError: If ``order_by_degree`` is combined with ``shuffle_ind != 0``.
    """

    def __init__(
        self,
        channels: int,
        conv: Optional[MessagePassing],
        heads: int = 1,
        dropout: float = 0.0,
        attn_dropout: float = 0.0,
        act: str = 'relu',
        att_type: str = 'transformer',
        order_by_degree: bool = False,
        shuffle_ind: int = 0,
        d_state: int = 16,
        d_conv: int = 4,
        act_kwargs: Optional[Dict[str, Any]] = None,
        norm: Optional[str] = 'batch_norm',
        norm_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__()

        # Fail fast: without a recognised global branch, forward() would either
        # crash or silently reuse the local branch's activations.
        if att_type not in ('transformer', 'mamba'):
            raise ValueError(
                f"Unsupported att_type '{att_type}'; expected 'transformer' or 'mamba'")

        self.channels = channels
        self.conv = conv
        self.heads = heads
        self.dropout = dropout
        self.att_type = att_type
        self.shuffle_ind = shuffle_ind
        self.order_by_degree = order_by_degree

        assert (not self.order_by_degree) or self.shuffle_ind == 0, f'order_by_degree={self.order_by_degree} and shuffle_ind={self.shuffle_ind}'

        # Attribute names (`attn` vs `self_attn`) are kept as-is so existing
        # state_dict keys remain valid.
        if self.att_type == 'transformer':
            self.attn = torch.nn.MultiheadAttention(
                channels,
                heads,
                dropout=attn_dropout,
                batch_first=True,
            )
        else:
            self.self_attn = Mamba(
                d_model=channels,
                d_state=d_state,
                d_conv=d_conv,
                expand=1
            )

        self.mlp = Sequential(
            Linear(channels, channels * 2),
            activation_resolver(act, **(act_kwargs or {})),
            Dropout(dropout),
            Linear(channels * 2, channels),
            Dropout(dropout),
        )

        norm_kwargs = norm_kwargs or {}
        self.norm1 = normalization_resolver(norm, channels, **norm_kwargs)
        self.norm2 = normalization_resolver(norm, channels, **norm_kwargs)
        self.norm3 = normalization_resolver(norm, channels, **norm_kwargs)

        # Some normalisation layers take the batch vector as a keyword
        # argument; detect that once from the forward signature.
        self.norm_with_batch = False
        if self.norm1 is not None:
            signature = inspect.signature(self.norm1.forward)
            self.norm_with_batch = 'batch' in signature.parameters

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        if self.conv is not None:
            self.conv.reset_parameters()
        if self.att_type == 'transformer':
            self.attn._reset_parameters()
        else:
            # `self.attn` does not exist in Mamba mode.
            # NOTE(review): the Mamba block is only reset when it exposes a
            # reset_parameters() method; otherwise its construction-time
            # initialisation is left untouched. Confirm against the installed
            # mamba_ssm version.
            mamba_reset = getattr(self.self_attn, 'reset_parameters', None)
            if callable(mamba_reset):
                mamba_reset()
        reset(self.mlp)
        for norm in (self.norm1, self.norm2, self.norm3):
            if norm is not None:
                norm.reset_parameters()

    def forward(
        self,
        x: Tensor,
        edge_index: Adj,
        batch: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tensor:
        r"""Runs the forward pass of the module.

        Args:
            x: Node features, one row per node.
            edge_index: Graph connectivity passed to ``conv`` and used for
                node degrees when ``order_by_degree`` is set.
            batch: Graph index of every node; ``None`` treats all nodes as a
                single graph.
            **kwargs: Extra arguments forwarded to ``conv``.

        Returns:
            Updated node features in the same node order as ``x``.
        """
        hs = []
        if self.conv is not None:  # Local MPNN.
            h = self.conv(x, edge_index, **kwargs)
            h = F.dropout(h, p=self.dropout, training=self.training)
            h = h + x
            h = self._apply_norm(self.norm1, h, batch)
            hs.append(h)

        # Global branch.
        if self.att_type == 'transformer':
            h, mask = to_dense_batch(x, batch)
            h, _ = self.attn(h, h, h, key_padding_mask=~mask, need_weights=False)
            h = h[mask]
        else:
            h = self._mamba_global(x, edge_index, batch)

        h = F.dropout(h, p=self.dropout, training=self.training)
        h = h + x  # Residual connection.
        h = self._apply_norm(self.norm2, h, batch)
        hs.append(h)

        out = sum(hs)  # Combine local and global outputs.

        out = out + self.mlp(out)
        out = self._apply_norm(self.norm3, out, batch)

        return out

    def _apply_norm(self, norm, h: Tensor, batch: Optional[Tensor]) -> Tensor:
        """Apply ``norm`` to ``h``, passing ``batch`` when the layer accepts it."""
        if norm is None:
            return h
        if self.norm_with_batch:
            return norm(h, batch=batch)
        return norm(h)

    def _mamba_global(self, x: Tensor, edge_index: Adj,
                      batch: Optional[Tensor]) -> Tensor:
        """Run the Mamba branch and return its output in the node order of ``x``."""
        if batch is None:
            node_batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        else:
            node_batch = batch

        if self.shuffle_ind != 0:
            # Average over several random within-graph orderings.
            outputs = []
            for _ in range(self.shuffle_ind):
                perm = self._permute_within_batch(node_batch)
                outputs.append(self._mamba_on_permutation(x, node_batch, perm))
            return sum(outputs) / self.shuffle_ind

        if self.order_by_degree:
            # Order nodes by (graph, degree) using two stable sorts: the
            # second sort keeps the degree order inside each graph.
            deg = degree(edge_index[0], x.shape[0]).to(torch.long)
            _, by_degree = torch.sort(deg, stable=True)
            _, by_graph = torch.sort(node_batch[by_degree], stable=True)
            perm = by_degree[by_graph]
            return self._mamba_on_permutation(x, node_batch, perm)

        h, mask = to_dense_batch(x, batch)
        return self.self_attn(h)[mask]

    def _mamba_on_permutation(self, x: Tensor, node_batch: Tensor,
                              perm: Tensor) -> Tensor:
        """Apply Mamba to nodes visited in ``perm`` order, then undo the permutation.

        ``perm`` must group nodes by graph in ascending graph index so that
        ``node_batch[perm]`` is a valid (sorted) batch vector.
        """
        dense, mask = to_dense_batch(x[perm], node_batch[perm])
        permuted_out = self.self_attn(dense)[mask]
        # Row i of `permuted_out` belongs to node perm[i]; indexing with the
        # inverse permutation puts every row back at its node's position.
        inverse = torch.argsort(perm)
        return permuted_out[inverse]

    @staticmethod
    def _permute_within_batch(node_batch: Tensor) -> Tensor:
        """Return node indices grouped by graph and randomly ordered inside each graph."""
        shuffled = torch.randperm(node_batch.numel(), device=node_batch.device)
        # A stable sort by graph index keeps the random order within a graph.
        _, by_graph = torch.sort(node_batch[shuffled], stable=True)
        return shuffled[by_graph]

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.channels}, '
                f'conv={self.conv}, heads={self.heads})')

In [None]:
class TemporalConv(nn.Module):
    """Convolution along the first spatial axis followed by batch normalisation.

    The kernel has shape ``(kernel_size, 1)``, so each output position mixes
    ``kernel_size`` (dilated) steps along dim 2 of the input while leaving
    dim 3 untouched.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel extent along dim 2.
        stride: Stride along dim 2.
        dilation: Dilation along dim 2.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1):
        super().__init__()
        # Extent covered by the dilated kernel. Padding by half of it (minus
        # the centre tap) keeps the length of dim 2 unchanged for stride 1
        # whenever dilation * (kernel_size - 1) is even.
        effective_kernel = dilation * (kernel_size - 1) + 1
        time_pad = (effective_kernel - 1) // 2

        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(kernel_size, 1),
            padding=(time_pad, 0),
            stride=(stride, 1),
            dilation=(dilation, 1),
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Apply the convolution and then batch normalisation to ``x``."""
        return self.bn(self.conv(x))

class MultiScale_TemporalConv(nn.Module):
    """Multi-branch temporal convolution.

    One branch is built per dilation (1x1 conv -> BN -> ReLU -> dilated
    ``TemporalConv``), plus a max-pooling branch and a plain 1x1 branch.
    Branch outputs are concatenated on the channel axis, so ``out_channels``
    must be divisible by ``len(dilations) + 2``.

    Args:
        in_channels: Number of input channels.
        out_channels: Total number of output channels across all branches.
        kernel_size: Temporal kernel size shared by all dilated branches, or a
            list/tuple with one kernel size per dilation.
        stride: Temporal stride applied in every branch.
        dilations: Dilation of each dilated branch.
        residual: Whether to add a residual connection.
        residual_kernel_size: Kernel size of the residual ``TemporalConv`` used
            when the residual cannot be the identity.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 dilations=(1, 2, 3, 4),
                 residual=False,
                 residual_kernel_size=1):

        super().__init__()
        # An immutable default avoids sharing one list object between
        # instances; any sequence of dilations is accepted.
        dilations = list(dilations)
        assert out_channels % (len(dilations) + 2) == 0, '# out channels should be multiples of # branches'

        # + 2 for the additional max-pool branch and 1x1 branch.
        self.num_branches = len(dilations) + 2
        branch_channels = out_channels // self.num_branches

        # Accept tuples as well as lists for per-branch kernel sizes.
        if isinstance(kernel_size, (list, tuple)):
            assert len(kernel_size) == len(dilations)
            kernel_sizes = list(kernel_size)
        else:
            kernel_sizes = [kernel_size] * len(dilations)

        # Dilated temporal convolution branches, one per dilation.
        self.branches = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    branch_channels,
                    kernel_size=1,
                    padding=0
                ),
                nn.BatchNorm2d(branch_channels),
                nn.ReLU(inplace=True),
                TemporalConv(
                    branch_channels,
                    branch_channels,
                    kernel_size=ks,
                    stride=stride,
                    dilation=dilation
                ),
            )
            for ks, dilation in zip(kernel_sizes, dilations)
        ])

        # Max-pooling branch.
        self.branches.append(nn.Sequential(
            nn.Conv2d(in_channels, branch_channels, kernel_size=1, padding=0),
            nn.BatchNorm2d(branch_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(3, 1), stride=(stride, 1), padding=(1, 0)),
            nn.BatchNorm2d(branch_channels)  # Author's open question: why is a second BN needed here?
        ))

        # Plain 1x1 branch.
        self.branches.append(nn.Sequential(
            nn.Conv2d(in_channels, branch_channels, kernel_size=1, padding=0, stride=(stride, 1)),
            nn.BatchNorm2d(branch_channels)
        ))

        # Residual connection: none, identity, or a strided TemporalConv.
        if not residual:
            self.residual = lambda x: 0
        elif (in_channels == out_channels) and (stride == 1):
            self.residual = lambda x: x
        else:
            self.residual = TemporalConv(in_channels, out_channels, kernel_size=residual_kernel_size, stride=stride)

    def forward(self, x):
        """Apply all branches and concatenate their outputs.

        The input is channel-last: its last axis is moved to position 1 before
        the convolutions (giving the ``(N, C, T, V)`` layout named in the
        original comment) and moved back to the end before returning.
        """
        x = x.permute(0, 3, 1, 2)
        res = self.residual(x)
        branch_outs = [tempconv(x) for tempconv in self.branches]

        out = torch.cat(branch_outs, dim=1)
        out += res
        out = out.permute(0, 2, 3, 1)
        return out

In [None]:
class GraphEnt(nn.Module):
    """Per-joint feature projection with skeleton-graph scaffolding.

    The constructor builds a ``GPSConv`` (Mamba variant), an edge embedding
    and the 25-joint skeleton connectivity, but ``forward`` currently applies
    only the linear projection ``lin``; the graph convolution call was
    commented out in the original notebook.

    Args:
        dim_in: Size of the incoming per-joint feature vector.
        dim: Size of the projected per-joint feature vector.
        num_points: Sizes the edge-embedding table (``num_points * 2`` rows).
            The skeleton connectivity itself is fixed at 25 joints.
    """

    def __init__(self, dim_in, dim, num_points=25):
        super().__init__()
        nn1 = Sequential(
                nn.Linear(dim_in, dim_in),
                nn.ReLU(),
                nn.Linear(dim_in, dim_in),
            )
        self.conv = GPSConv(dim_in, GINEConv(nn1), heads=4, attn_dropout=0.5,
                               att_type='mamba',
                               shuffle_ind=0,
                               order_by_degree=True,
                               d_state=16, d_conv=4)
        self.lin = nn.Linear(dim_in, dim)
        self.edge_emb = nn.Embedding(num_points*2, dim_in)
        self_link = [(i, i) for i in range(25)]
        # 1-based (child, parent) joint pairs; converted to 0-based below.
        inward_ori_index = [(1, 2), (2, 21), (3, 21), (4, 3), (5, 21), (6, 5), (7, 6),
                    (8, 7), (9, 21), (10, 9), (11, 10), (12, 11), (13, 1),
                    (14, 13), (15, 14), (16, 15), (17, 1), (18, 17), (19, 18),
                    (20, 19), (22, 8), (23, 8), (24, 12), (25, 12)]
        inward = [(i - 1, j - 1) for (i, j) in inward_ori_index]
        outward = [(j, i) for (i, j) in inward]
        self.neighbor = inward + outward
        # Registered as non-persistent buffers so they follow the module
        # through `.to(device)` without adding keys to the state_dict. This
        # removes the dependency on a notebook-level `device` global.
        self.register_buffer(
            'edge_index',
            self.convert_neighbor_to_edge_index(self.neighbor, None),
            persistent=False)
        self.register_buffer(
            'adj_matr',
            torch.from_numpy(self.get_spatial_graph(25, self_link, inward, outward)),
            persistent=False)

    def forward(self, x, dims):
        """Project the last axis of ``x`` from ``dim_in`` to ``dim``.

        Args:
            x: Features whose last axis has size ``dim_in``; the caller passes
                them as ``(N*M, T, V, C)``.
            dims: Original ``(N, C, T, V, M)`` input size. Accepted for
                interface compatibility; not needed by the projection.

        Returns:
            ``x`` with its last axis projected by ``self.lin``.
        """
        # TODO: the GPSConv layer (`self.conv`) is not applied yet. Enabling
        # it requires flattening x to (num_nodes, C) and supplying an edge
        # index, edge embeddings and a batch vector (see create_batch_array)
        # that match the flattened node layout.
        return self.lin(x)

    def normalize_digraph(self, A):
        """Column-normalise ``A``: divide each column by its sum, leaving all-zero columns untouched."""
        Dl = np.sum(A, 0)
        h, w = A.shape
        Dn = np.zeros((w, w))
        for i in range(w):
            if Dl[i] > 0:
                Dn[i, i] = Dl[i] ** (-1)
        AD = np.dot(A, Dn)
        return AD

    def edge2mat(self, link, num_node):
        """Return a ``num_node x num_node`` matrix with ``A[j, i] = 1`` for every ``(i, j)`` in ``link``."""
        A = np.zeros((num_node, num_node))
        for i, j in link:
            A[j, i] = 1
        return A

    def get_spatial_graph(self, num_node, self_link, inward, outward):
        """Stack the self-link matrix with the normalised inward and outward matrices into a ``(3, V, V)`` array."""
        I = self.edge2mat(self_link, num_node)
        In = self.normalize_digraph(self.edge2mat(inward, num_node))
        Out = self.normalize_digraph(self.edge2mat(outward, num_node))
        A = np.stack((I, In, Out))
        return A

    def convert_neighbor_to_edge_index(self, neighbor, device):
        """Convert a list of ``(src, dst)`` pairs into a ``(2, num_edges)`` int64 tensor on ``device``."""
        indices = torch.tensor(neighbor, dtype=torch.int64, device=device).t()
        return indices

    def create_batch_array(self, N, M, T, V, device):
        """Return a length ``N*M*T*V`` vector assigning each of ``N*M`` sequences' ``T*V`` nodes its sequence index."""
        num_indices = N * M
        batch = torch.arange(num_indices, dtype=torch.int64, device=device).repeat_interleave(V*T)
        return batch

In [None]:
class GraphModel(nn.Module):
    """Skeleton classifier: stacked ``GraphEnt`` blocks, global pooling, 60-way linear head.

    Args:
        dim_in: Number of input channels per joint.
        dim: Hidden feature size used after the first block.
    """

    def __init__(self, dim_in, dim):
        super().__init__()
        self.graph_conv_first = GraphEnt(dim_in, dim)
        self.graph_conv = GraphEnt(dim, dim)
        # Author's note: residual=True has worse performance in the end.
        self.tcn = MultiScale_TemporalConv(
            dim, dim,
            kernel_size=5,
            stride=1,
            dilations=[1, 2],
            residual=False,
        )
        self.act = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(0.3)
        self.lin = nn.Linear(dim_in, dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 60),
        )

    def forward(self, x):
        """Classify a batch of skeleton sequences.

        Args:
            x: Tensor of size ``(N, C, T, V, M)``.

        Returns:
            Tensor of size ``(N, 60)`` with one score per class.
        """
        dims = x.size()
        N, C, T, V, M = dims
        # Fold the M axis into the batch axis and move channels last:
        # (N, C, T, V, M) -> (N*M, T, V, C).
        x = x.permute(0, 4, 2, 3, 1).contiguous().view(N * M, T, V, C)

        # One input block, then the same shared block for the remaining layers.
        num_layers = 5
        x = self.graph_conv_first(x, dims)
        for _ in range(num_layers - 1):
            x = self.graph_conv(x, dims)

        # (N*M, T, V, C') -> (N*M, C', T, V) -> (N, M, C', T*V)
        x = x.permute(0, 3, 1, 2)
        hidden = x.size(1)
        x = x.view(N, M, hidden, -1)
        # Average over the flattened T*V axis, then over M.
        x = x.mean(3).mean(1)
        return self.mlp(x)

In [None]:
def train():
    """Run one training epoch over the loader returned by ``load_data('train')``.

    Uses the notebook-level ``model``, ``optimizer``, ``lossC``, ``scaler``,
    ``use_amp`` and ``device``.

    Returns:
        Tuple ``(mean_loss, mean_accuracy_percent)`` where both values are
        NaN-ignoring means of the per-batch statistics.
    """
    model.train()
    batch_losses, batch_accs = [], []
    progress = tqdm(load_data('train'), ncols=80)
    for data, label, _ in progress:
        with torch.no_grad():
            data = data.float().to(device)
            label = label.long().to(device)

        # Forward pass and loss under (optional) mixed precision.
        with torch.amp.autocast('cuda', enabled=use_amp):
            out = model(data)
            loss = lossC(out, target=label)

        # Scaled backward pass and optimizer step.
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_losses.append(loss.data.item())
        _, predict_label = torch.max(out.data, 1)
        correct = (predict_label == label.data).float()
        batch_accs.append(correct.mean().data.item())

    mean_loss = np.nanmean(batch_losses)
    mean_acc = np.nanmean(batch_accs) * 100
    return mean_loss, mean_acc

In [None]:
def test():
    """Evaluate the model on the loader returned by ``load_data('test')``.

    Uses the notebook-level ``model``, ``lossC`` and ``device``.

    Returns:
        Tuple ``(accuracy, mean_loss)`` where ``accuracy`` is the dataset's
        ``top_k(score, 1)`` result over all collected scores and ``mean_loss``
        is the mean of the per-batch losses.
    """
    model.eval()
    test_loader = load_data('test')
    batch_losses = []
    batch_scores = []
    for data, label, _ in tqdm(test_loader, ncols=80):
        with torch.no_grad():
            data = data.float().to(device)
            label = label.long().to(device)
            out = model(data)
            batch_loss = lossC(out, target=label)
            batch_scores.append(out.data.cpu().numpy())
            batch_losses.append(batch_loss.data.item())

    mean_loss = np.mean(batch_losses)
    all_scores = np.concatenate(batch_scores)
    accuracy = test_loader.dataset.top_k(all_scores, 1)
    return accuracy, mean_loss

In [None]:
init_seed(2)

# Prefer GPU index 2, but fall back to GPU 0 when the machine has fewer than
# three GPUs (a hard-coded 'cuda:2' fails there with an invalid device
# ordinal), and to the CPU when CUDA is unavailable.
PREFERRED_GPU = 2
if torch.cuda.is_available():
    gpu_index = PREFERRED_GPU if torch.cuda.device_count() > PREFERRED_GPU else 0
    device = torch.device(f'cuda:{gpu_index}')
else:
    device = torch.device('cpu')

# 3 input channels per joint, 24 hidden channels.
model = GraphModel(3, 24).to(device)
# NOTE: a multi-GPU variant was sketched here previously: wrapping
# GraphModel(3, 1024) in nn.DataParallel with device_ids=[3, 2] and
# output_device=3.

optimizer = optim.SGD(
                model.parameters(),
                lr=0.001)
def count_parameters(model):
    """Return the total number of elements in ``model``'s trainable parameters.

    Parameters with ``requires_grad`` set to ``False`` are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
# Report the model size, then create the loss and mixed-precision state that
# train() and test() read from the notebook namespace.
num_params = count_parameters(model)
print("Parameters:", num_params)
lossC = nn.CrossEntropyLoss().to(device)
use_amp = True
scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

# Train for a fixed number of epochs, evaluating after each one.
NUM_EPOCHS = 10
for epoch in range(1, NUM_EPOCHS + 1):
    loss = train()  # (mean loss, mean accuracy in percent)
    train_loss, train_acc = loss
    print("Loss:", train_loss, "Train Acc:", train_acc)
    print("AFTER:")
    tst = test()  # (accuracy, mean loss)
    print(f'Epoch: {epoch:02d}, Loss: {loss[0]:.4f}, Train Acc: {loss[1]:.4f}, Accuracy: {tst[0]}, Test_Loss: {tst[1]}')