In [1]:
from __future__ import print_function
import argparse
import inspect
import os
import pickle
import random
import shutil
import sys
import time
import math
from collections import OrderedDict
import traceback
from sklearn.metrics import confusion_matrix
import csv
import numpy as np
import glob
import torch.nn.functional as F
import torch
import torch.nn as nn
import torch.optim as optim
import yaml
from tensorboardX import SummaryWriter
from tqdm import tqdm
from feeders.feeder_ntu import Feeder
from torch.optim.lr_scheduler import ReduceLROnPlateau

import torch_geometric.transforms as T
from torch_geometric.datasets import ZINC
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_add_pool
import inspect
from typing import Any, Dict, Optional

import torch.nn.functional as F
from torch import Tensor
from torch.nn import Dropout, Linear, Sequential

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import reset
from torch_geometric.nn.resolver import (
    activation_resolver,
    normalization_resolver,
)
from torch_geometric.typing import Adj
from torch_geometric.utils import to_dense_batch

from mamba_ssm import Mamba
from torch_geometric.utils import degree, sort_edge_index

import wandb

In [2]:
def init_seed(seed):
    """Seed every random number generator the notebook relies on.

    Python's ``random``, NumPy and PyTorch (CPU plus all CUDA devices) are
    seeded with ``seed``; cuDNN is then put into deterministic,
    non-benchmarking mode.

    Args:
        seed: Integer seed handed unchanged to each generator.
    """
    for seed_fn in (random.seed,
                    np.random.seed,
                    torch.manual_seed,
                    torch.cuda.manual_seed_all):
        seed_fn(seed)

    # Author's note: `torch.backends.cudnn.enabled = True` was left disabled;
    # the deterministic setting is reported to slow training down.
    torch.backends.cudnn.deterministic = True

    # Author's note: on CUDA 11 / cuDNN 8 the default algorithm was very slow,
    # whereas on CUDA 10 the default worked well.
    torch.backends.cudnn.benchmark = False

In [3]:
def load_data(phase='train', data='S'):
    """Build the DataLoader for one split of the NTU60 skeleton file.

    Args:
        phase: ``'train'`` selects the training split; any other value selects
            the test split.
        data: Suffix inserted into the file name ``data/ntu/NTU60_C{data}.npz``.

    Returns:
        A ``torch.utils.data.DataLoader`` with ``batch_size=32`` over a
        ``Feeder`` dataset. The training loader reshuffles its samples on
        every pass; the test loader keeps the dataset order.
    """
    is_train = phase == 'train'
    dataset = Feeder(
        f'data/ntu/NTU60_C{data}.npz',
        split='train' if is_train else 'test',
        # Crop-ratio interval handed to the Feeder; wider for training.
        p_interval=[0.5, 1] if is_train else [0.95],
        window_size=64,
    )
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=32,
        # Fix: training previously iterated the file in the same fixed order
        # every epoch because DataLoader defaults to shuffle=False.
        shuffle=is_train,
        # num_workers is left at its default, as in the original cell.
        worker_init_fn=init_seed)

In [4]:
class SpatialGraphCov(nn.Module):
    """Spatial graph convolution over a fixed stack of adjacency matrices.

    A 1x1 convolution produces one feature map per adjacency subset; each map
    is multiplied by its adjacency matrix and the results are summed.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output feature map.
        A: Array-like of shape ``(K, V, V)`` (indexed as ``A.shape[0]`` subsets
            and contracted over the joint axis in ``forward``).
        **kwargs: Accepted and ignored.
    """

    def __init__(self, in_channels, out_channels, A, **kwargs):
        super(SpatialGraphCov, self).__init__()
        self.num_subset = A.shape[0]

        self.gcn = nn.Conv2d(in_channels, out_channels * A.shape[0], kernel_size=1)

        # Registered as a NON-persistent buffer: it follows .to()/.cuda() with
        # the module, yet stays out of state_dict so existing checkpoints
        # (which never contained "A") still load with strict key matching.
        self.register_buffer(
            'A', torch.tensor(A, dtype=torch.float32), persistent=False)

    def forward(self, x, edge_index):
        """Apply the graph convolution.

        Args:
            x: Tensor unpacked as ``(n, c, t, v)``.
            edge_index: Unused; present so the layer can be called like a
                message-passing operator.

        Returns:
            Tensor of shape ``(n, out_channels, t, v)``.
        """
        n, c, t, v = x.size()
        # Fallback for a module that was never moved to the input's device;
        # done on a local so forward no longer mutates module state.
        adjacency = self.A if self.A.device == x.device else self.A.to(x.device)
        # Feature update: one group of output channels per adjacency subset.
        x = self.gcn(x).view(n, self.num_subset, -1, t, v)
        # Aggregation over neighbours, summed over the K subsets.
        x = torch.einsum('nkctv,kvw->nctw', x, adjacency)
        return x

class GPSConv(torch.nn.Module):
    r"""GPS-style block: a local graph operator plus a global sequence mixer.

    The outputs of a local branch (``conv``) and a global branch
    (multi-head attention or Mamba) are summed and passed through a
    1x1-convolution MLP. All three normalisations are ``BatchNorm2d`` and the
    MLP is built from ``Conv2d``, and the Mamba path unpacks its input as
    ``(N, C, T, V)``.

    Args:
        channels: Feature width kept by every sub-module.
        conv: Local operator called as ``conv(x, edge_index, **kwargs)``, or
            ``None`` to disable the local branch.
        heads: Number of heads for ``att_type='transformer'``.
        dropout: Dropout probability used on both branches and in the MLP.
        attn_dropout: Dropout inside ``MultiheadAttention``.
        act: Accepted for signature compatibility; not used by this class.
        att_type: ``'transformer'`` or ``'mamba'``.
        order_by_degree: Reorder ``x`` by ``(batch, degree)`` before the Mamba
            mixer. Only allowed together with ``shuffle_ind == 0``.
        shuffle_ind: ``0`` runs Mamba once; ``k > 0`` averages ``k`` runs over
            random within-batch permutations.
        d_state: Forwarded to ``Mamba``.
        d_conv: Forwarded to ``Mamba``.
        act_kwargs: Accepted for signature compatibility; not used.
        norm: Accepted for signature compatibility; ``BatchNorm2d`` is
            hard-wired below.
        norm_kwargs: Accepted for signature compatibility; not used.
    """

    def __init__(
        self,
        channels: int,
        conv: Optional[MessagePassing],
        heads: int = 1,
        dropout: float = 0.0,
        attn_dropout: float = 0.0,
        act: str = 'relu',
        att_type: str = 'transformer',
        order_by_degree: bool = False,
        shuffle_ind: int = 0,
        d_state: int = 16,
        d_conv: int = 4,
        act_kwargs: Optional[Dict[str, Any]] = None,
        norm: Optional[str] = 'batch_norm',
        norm_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__()

        self.channels = channels
        self.conv = conv
        self.heads = heads
        self.dropout = dropout
        self.att_type = att_type
        self.shuffle_ind = shuffle_ind
        self.order_by_degree = order_by_degree

        assert (self.order_by_degree==True and self.shuffle_ind==0) or (self.order_by_degree==False), f'order_by_degree={self.order_by_degree} and shuffle_ind={self.shuffle_ind}'

        if self.att_type == 'transformer':
            self.attn = torch.nn.MultiheadAttention(
                channels,
                heads,
                dropout=attn_dropout,
                batch_first=True,
            )
        if self.att_type == 'mamba':
            self.self_attn = Mamba(
                d_model=channels,
                d_state=d_state,
                d_conv=d_conv,
                expand=1
            )

        self.mlp = Sequential(
            nn.Conv2d(channels, channels * 2, kernel_size=1),
            Dropout(dropout),
            nn.Conv2d(channels * 2, channels, kernel_size=1),
            Dropout(dropout),
        )

        norm_kwargs = norm_kwargs or {}
        # The resolver-based construction is disabled in favour of fixed
        # BatchNorm2d layers:
        #   self.norm1 = normalization_resolver(norm, channels, **norm_kwargs)
        #   self.norm2 = normalization_resolver(norm, channels, **norm_kwargs)
        #   self.norm3 = normalization_resolver(norm, channels, **norm_kwargs)
        self.norm1 = nn.BatchNorm2d(channels)
        self.norm2 = nn.BatchNorm2d(channels)
        self.norm3 = nn.BatchNorm2d(channels)

        # True only if the norm layer's forward accepts a `batch` argument.
        self.norm_with_batch = False
        if self.norm1 is not None:
            signature = inspect.signature(self.norm1.forward)
            self.norm_with_batch = 'batch' in signature.parameters

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        if self.conv is not None:
            # Fix: `conv` need not define reset_parameters itself (the
            # SpatialGraphCov used in this notebook does not); `reset` is the
            # helper already used for the MLP below.
            reset(self.conv)
        # Fix: `attn` only exists for att_type='transformer'; the unconditional
        # access raised AttributeError for 'mamba'.
        attn = getattr(self, 'attn', None)
        if attn is not None:
            attn._reset_parameters()
        # NOTE(review): the Mamba mixer (`self_attn`) is deliberately left
        # untouched here - confirm whether mamba_ssm offers a re-init hook.
        reset(self.mlp)
        for norm_layer in (self.norm1, self.norm2, self.norm3):
            if norm_layer is not None:
                norm_layer.reset_parameters()

    def forward(
        self,
        x: Tensor,
        edge_index: Adj,
        batch: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tensor:
        r"""Runs the forward pass of the module.

        Args:
            x: Input features; the Mamba path unpacks them as (N, C, T, V).
            edge_index: Passed to ``conv`` and used for the degree ordering.
            batch: Batch vector used by the ordering / dense-batch helpers.
            **kwargs: Forwarded to ``conv``.

        Returns:
            Tensor produced by summing both branches and applying the MLP.
        """
        hs = []
        if self.conv is not None:  # Local MPNN.
            h = self.conv(x, edge_index, **kwargs)
            h = F.dropout(h, p=self.dropout, training=self.training)
            h = h + x
            if self.norm1 is not None:
                if self.norm_with_batch:
                    h = self.norm1(h, batch=batch)
                else:
                    h = self.norm1(h)
            hs.append(h)

        ### Global attention transformer-style model.
        if self.att_type == 'transformer':
            h, mask = to_dense_batch(x, batch)
            h, _ = self.attn(h, h, h, key_padding_mask=~mask, need_weights=False)
            h = h[mask]

        if self.att_type == 'mamba':

            if self.order_by_degree:
                # NOTE(review): num_nodes is x.shape[0]; if edge_index holds
                # indices >= x.shape[0] (e.g. a small final batch with a
                # per-skeleton edge list) `degree` presumably fails - confirm.
                deg = degree(edge_index[0], x.shape[0], dtype=torch.long)
                order_tensor = torch.stack([batch, deg], 1).T
                _, x = sort_edge_index(order_tensor, edge_attr=x)

            if self.shuffle_ind == 0:
                # The mixer consumes the local-branch output `h` when a local
                # operator exists (unchanged behaviour, so trained weights
                # stay valid). Fix: with conv=None `h` was never assigned and
                # this raised; fall back to `x` in that case.
                # NOTE(review): feeding the local output (rather than x) into
                # the global branch makes the two branches sequential - confirm
                # this is intended.
                seq_in = h if self.conv is not None else x
                N, C, T, V = seq_in.size()
                # One sequence of V tokens per (sample, frame) pair.
                h = seq_in.permute(0, 2, 3, 1).contiguous().view(N*T, V, C)
                h = self.self_attn(h)
                h = h.contiguous().view(-1, T, V, C).permute(0, 3, 1, 2)
            else:
                # NOTE(review): `permute_within_batch` is neither defined nor
                # imported in the visible part of this notebook - verify it
                # exists before using shuffle_ind > 0.
                mamba_arr = []
                for _ in range(self.shuffle_ind):
                    h_ind_perm = permute_within_batch(x, batch)
                    h_i, mask = to_dense_batch(x[h_ind_perm], batch)
                    h_i = self.self_attn(h_i)[mask][h_ind_perm]
                    mamba_arr.append(h_i)
                h = sum(mamba_arr) / self.shuffle_ind
        ###

        h = F.dropout(h, p=self.dropout, training=self.training)
        h = h + x  # Residual connection.
        if self.norm2 is not None:
            if self.norm_with_batch:
                h = self.norm2(h, batch=batch)
            else:
                h = self.norm2(h)
        hs.append(h)

        out = sum(hs)  # Combine local and global outputs.

        out = out + self.mlp(out)
        if self.norm3 is not None:
            if self.norm_with_batch:
                out = self.norm3(out, batch=batch)
            else:
                out = self.norm3(out)

        return out

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.channels}, '
                f'conv={self.conv}, heads={self.heads})')

In [5]:
def conv_init(conv):
    """Initialise a convolution: Kaiming-normal (fan_out) weight, zero bias.

    Args:
        conv: Module exposing ``weight`` and ``bias``; ``bias`` may be ``None``
            (layer built with ``bias=False``), in which case it is skipped.
    """
    nn.init.kaiming_normal_(conv.weight, mode='fan_out')
    # Guard added so bias-free layers no longer crash, matching the checks in
    # weights_init / init_weights.
    if conv.bias is not None:
        nn.init.constant_(conv.bias, 0)


def bn_init(bn, scale):
    """Fill a batch-norm layer's weight with ``scale`` and its bias with 0.

    Args:
        bn: Module exposing ``weight`` and ``bias`` tensors.
        scale: Constant written into every element of ``bn.weight``.
    """
    for tensor, value in ((bn.weight, scale), (bn.bias, 0)):
        nn.init.constant_(tensor, value)

def weights_init(m):
    """Initialise a module according to its class name (for ``Module.apply``).

    * Class name contains ``'Conv'``: Kaiming-normal (fan_out) weight when a
      ``weight`` attribute exists, and a zeroed bias when ``bias`` is a tensor.
    * Otherwise, class name contains ``'BatchNorm'``: weight drawn from
      N(1.0, 0.02) and bias set to 0, each only when present and not ``None``.
    * Any other module is left untouched.
    """
    kind = m.__class__.__name__

    if 'Conv' in kind:
        if hasattr(m, 'weight'):
            nn.init.kaiming_normal_(m.weight, mode='fan_out')
        bias = getattr(m, 'bias', None)
        if isinstance(bias, torch.Tensor):
            nn.init.constant_(bias, 0)
        return

    if 'BatchNorm' in kind:
        if getattr(m, 'weight', None) is not None:
            m.weight.data.normal_(1.0, 0.02)
        if getattr(m, 'bias', None) is not None:
            m.bias.data.fill_(0)

class TemporalConv(nn.Module):
    """Convolution along the first spatial axis followed by BatchNorm2d.

    The kernel is ``(kernel_size, 1)``, so each output mixes ``kernel_size``
    (dilated) steps along the first spatial axis of a single position on the
    second axis.

    Args:
        in_channels: Input channels.
        out_channels: Output channels.
        kernel_size: Kernel extent along the first spatial axis.
        stride: Stride along the first spatial axis.
        dilation: Dilation along the first spatial axis.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1):
        super(TemporalConv, self).__init__()
        # Extent covered by the dilated kernel; padding half of (extent - 1)
        # on each side keeps the length unchanged at stride 1 for odd extents.
        dilated_extent = kernel_size + (kernel_size - 1) * (dilation - 1)
        first_axis_pad = (dilated_extent - 1) // 2

        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(kernel_size, 1),
            padding=(first_axis_pad, 0),
            stride=(stride, 1),
            dilation=(dilation, 1),
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Return ``bn(conv(x))``."""
        return self.bn(self.conv(x))

class MultiScale_TemporalConv(nn.Module):
    """Parallel temporal branches whose outputs are concatenated on channels.

    There is one dilated-convolution branch per entry of ``dilations``, plus a
    max-pool branch and a plain 1x1 branch. Each branch emits
    ``out_channels // num_branches`` channels.

    Args:
        in_channels: Input channels.
        out_channels: Output channels; must be divisible by the branch count.
        kernel_size: One temporal kernel size for all dilated branches, or a
            list with one size per dilation.
        stride: Temporal stride applied in every branch.
        dilations: Dilation of each dilated branch.
        residual: If false, nothing is added to the concatenated output.
        residual_kernel_size: Kernel size of the projecting residual path.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=5,
                 stride=1,
                 dilations=[1, 2],
                 residual=False,
                 residual_kernel_size=1):

        super().__init__()
        assert out_channels % (len(dilations) + 2) == 0, '# out channels should be multiples of # branches'

        # "+ 2" accounts for the max-pool branch and the 1x1 branch.
        self.num_branches = len(dilations) + 2
        branch_channels = out_channels // self.num_branches

        if type(kernel_size) != list:
            kernel_size = [kernel_size] * len(dilations)
        else:
            assert len(kernel_size) == len(dilations)

        # Dilated temporal branches: 1x1 reduction -> BN -> dilated TemporalConv.
        # (A ReLU after the BN is disabled in the original design.)
        self.branches = nn.ModuleList()
        for ks, dilation in zip(kernel_size, dilations):
            dilated_branch = nn.Sequential(
                nn.Conv2d(in_channels, branch_channels, kernel_size=1, padding=0),
                nn.BatchNorm2d(branch_channels),
                TemporalConv(
                    branch_channels,
                    branch_channels,
                    kernel_size=ks,
                    stride=stride,
                    dilation=dilation,
                ),
            )
            self.branches.append(dilated_branch)

        # Max-pool branch.
        pooled_branch = nn.Sequential(
            nn.Conv2d(in_channels, branch_channels, kernel_size=1, padding=0),
            nn.BatchNorm2d(branch_channels),
            nn.MaxPool2d(kernel_size=(3, 1), stride=(stride, 1), padding=(1, 0)),
            nn.BatchNorm2d(branch_channels),  # author's open question: why add another BN here?
        )
        self.branches.append(pooled_branch)

        # Plain 1x1 branch; the stride is applied by the convolution itself.
        pointwise_branch = nn.Sequential(
            nn.Conv2d(in_channels, branch_channels, kernel_size=1, padding=0, stride=(stride, 1)),
            nn.BatchNorm2d(branch_channels),
        )
        self.branches.append(pointwise_branch)

        # Residual path: none, identity, or a projecting TemporalConv.
        if residual:
            if (in_channels == out_channels) and (stride == 1):
                self.residual = lambda x: x
            else:
                self.residual = TemporalConv(in_channels, out_channels, kernel_size=residual_kernel_size, stride=stride)
        else:
            self.residual = lambda x: 0

        self.apply(weights_init)

    def forward(self, x):
        """Run every branch on ``x`` (N, C, T, V) and add the residual term."""
        res = self.residual(x)
        out = torch.cat([branch(x) for branch in self.branches], dim=1)
        out += res
        return out

class unit_tcn(nn.Module):
    """Single temporal convolution with batch normalisation.

    Args:
        in_channels: Input channels.
        out_channels: Output channels.
        kernel_size: Kernel extent along the first spatial axis.
        stride: Stride along the first spatial axis.
    """

    def __init__(self, in_channels, out_channels, kernel_size=5, stride=1):
        super(unit_tcn, self).__init__()
        half_kernel = int((kernel_size - 1) / 2)
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(kernel_size, 1),
            padding=(half_kernel, 0),
            stride=(stride, 1),
            groups=1,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        # Kaiming weight / zero bias for the conv, unit scale for the BN.
        conv_init(self.conv)
        bn_init(self.bn, 1)

    def forward(self, x):
        """Return ``bn(conv(x))`` without changing the axis order."""
        projected = self.conv(x)
        return self.bn(projected)


In [6]:
class GraphEnt(nn.Module):
    """Spatial block: group-pooled joint embedding followed by a GPSConv.

    The constructor builds a fixed 25-joint skeleton graph (self loops, inward
    and outward edges) and wraps a ``SpatialGraphCov`` over it in a Mamba-based
    ``GPSConv``.

    Args:
        dim_in: Channels of the incoming features.
        dim: Channels produced by the block.
        num_points: Accepted but not used; the joint count is fixed at 25.
    """

    def __init__(self, dim_in, dim, num_points=25):
        super().__init__()
        self.dim = dim
        self.dim_in = dim_in

        # Skeleton connectivity as pairs of 1-based joint indices.
        one_based_pairs = [(1, 2), (2, 21), (3, 21), (4, 3), (5, 21), (6, 5), (7, 6),
                           (8, 7), (9, 21), (10, 9), (11, 10), (12, 11), (13, 1),
                           (14, 13), (15, 14), (16, 15), (17, 1), (18, 17), (19, 18),
                           (20, 19), (22, 8), (23, 8), (24, 12), (25, 12)]
        inward = [(src - 1, dst - 1) for src, dst in one_based_pairs]
        outward = [(dst, src) for src, dst in inward]
        self_link = [(joint, joint) for joint in range(25)]

        self.neighbor = inward + outward
        self.A = self.get_spatial_graph(25, self_link, inward, outward)

        local_operator = SpatialGraphCov(dim, dim, self.A)
        self.conv = GPSConv(dim, local_operator, heads=4, attn_dropout=0.5,
                            att_type='mamba',
                            shuffle_ind=0,
                            order_by_degree=True,
                            d_state=16, d_conv=4, norm='batch_norm')
        # Created (and therefore part of state_dict) but not called in forward.
        self.norm = nn.LayerNorm(dim)
        if dim != dim_in:
            # 1x1 projection lifting the input to `dim` channels.
            self.t = unit_tcn(dim_in, dim, kernel_size=1)
        # Group id (0-4) assigned to each of the 25 joints.
        self.joint_label = [0, 4, 2, 2, 2, 2, 1, 1, 2, 2, 1, 1, 2, 3, 3, 3, 2, 3, 3, 3, 1, 0, 1, 0, 1]
        self.pe_proj = nn.Conv2d(dim_in, dim, 1, bias=False)

    def forward(self, x, dims):
        """Add the group embedding to ``x`` and apply the GPSConv.

        Args:
            x: Features whose last axis indexes the 25 joints; assumed to be
                laid out as (N*M, C, T, V) - TODO confirm against the caller.
            dims: 5-tuple unpacked as ``(N, C, T, V, M)``.

        Returns:
            Contiguous output tensor of the GPSConv.
        """
        N, _, _, _, M = dims

        # Average the joints of each group, project, then hand every joint the
        # embedding of the group it belongs to.
        membership = F.one_hot(torch.tensor(self.joint_label)).float().to(x.device)
        group_mean = x @ (membership / membership.sum(dim=0, keepdim=True))
        group_first = self.pe_proj(group_mean).permute(3, 0, 1, 2)
        e = group_first[self.joint_label].permute(1, 2, 3, 0).contiguous()

        # Lift the channel count when the block changes width.
        if self.dim_in != self.dim:
            x = self.t(x)

        NM, C, T, V = x.size()
        batch = self.create_batch_array(N * M, 1, 1, x.device)
        edge_index = self.convert_neighbor_to_edge_index(self.neighbor, x.device)

        fused = x + e
        return self.conv(fused, edge_index, batch).contiguous()

    def convert_neighbor_to_edge_index(self, neighbor, device):
        """Turn a list of ``(a, b)`` pairs into a ``(2, E)`` int64 tensor."""
        return torch.tensor(neighbor, dtype=torch.int64, device=device).t()

    def repeat_tensor_with_increment(self, tensor, batch_size, V):
        """Concatenate ``tensor + V*i`` for ``i = 0 .. batch_size-1`` on dim 1.

        Returns ``tensor`` itself when ``batch_size`` is 1 or less.
        """
        shifted = [tensor + V * i for i in range(1, batch_size)]
        if not shifted:
            return tensor
        return torch.cat([tensor] + shifted, dim=1)

    def create_batch_array(self, NM, T, V, device):
        """Return ``[0, .., NM-1]`` with each index repeated ``V * T`` times."""
        sample_ids = torch.arange(NM, dtype=torch.int64, device=device)
        return sample_ids.repeat_interleave(V * T)

    def edge2mat(self, link, num_node):
        """Dense matrix with ``A[j, i] = 1`` for every pair ``(i, j)`` in ``link``."""
        A = np.zeros((num_node, num_node))
        for src, dst in link:
            A[dst, src] = 1
        return A

    def normalize_digraph(self, A):
        """Divide each column of ``A`` by its sum; all-zero columns stay zero."""
        column_sums = np.sum(A, 0)
        _, w = A.shape
        inverse_diag = np.zeros((w, w))
        for idx, total in enumerate(column_sums):
            if total > 0:
                inverse_diag[idx, idx] = total ** (-1)
        return np.dot(A, inverse_diag)

    def get_spatial_graph(self, num_node, self_link, inward, outward):
        """Stack identity, normalised inward and normalised outward matrices."""
        identity = self.edge2mat(self_link, num_node)
        inward_norm = self.normalize_digraph(self.edge2mat(inward, num_node))
        outward_norm = self.normalize_digraph(self.edge2mat(outward, num_node))
        return np.stack((identity, inward_norm, outward_norm))

In [7]:
class GraphTCN(nn.Module):
    """One network stage: GraphEnt, multi-scale temporal conv, residual, BN, tanh.

    Args:
        dim_in: Input channels.
        dim: Output channels.
        stride: Temporal stride of the temporal convolution and of the
            projecting residual path.
    """

    def __init__(self, dim_in, dim, stride=1):
        super().__init__()
        self.dim_in = dim_in
        self.dim = dim
        self.conv = GraphEnt(dim_in, dim)
        # Author's note: residual=True inside the temporal unit ended up with
        # worse performance, hence residual=False.
        self.tcn = MultiScale_TemporalConv(
            dim, dim,
            kernel_size=5,
            stride=stride,
            dilations=[1,2],
            residual=False,
        )
        self.act = nn.Tanh()
        self.norm = nn.BatchNorm2d(dim)
        # Instantiated but not applied in forward.
        self.dropout = nn.Dropout(0.2)

        # The skip path must be projected whenever the shape changes.
        shape_changes = dim_in != dim or stride != 1
        if shape_changes:
            self.residual = unit_tcn(dim_in, dim, kernel_size=1, stride=stride)
        else:
            self.residual = lambda x: x

    def forward(self, x, dims):
        """Return ``tanh(bn(tcn(conv(x, dims)) + residual(x)))``."""
        main_path = self.tcn(self.conv(x, dims))
        skip_path = self.residual(x)
        return self.act(self.norm(main_path + skip_path))

In [8]:
class GraphModel(nn.Module):
    """Five stacked GraphTCN stages followed by a linear classifier.

    Args:
        dim_in: Channels per joint in the input (the ``C`` axis).
        dim: Width of every stage.
        num_class: Number of output logits. Defaults to 60, the value that was
            previously hard-coded.
        num_person: Number of bodies per sample (the ``M`` axis). Defaults to
            2, the value that was previously hard-coded.
    """

    def __init__(self, dim_in, dim, num_class=60, num_person=2):
        super().__init__()
        self.l1 = GraphTCN(dim_in, dim)
        self.l2 = GraphTCN(dim, dim)
        self.l3 = GraphTCN(dim, dim)
        self.l4 = GraphTCN(dim, dim)
        self.l5 = GraphTCN(dim, dim)
        # A wider variant (l3 with dim*2 and stride=2, l4/l5 at dim*2) was
        # tried and is disabled.
        self.fc1 = nn.Linear(dim, num_class)
        # `mlp` wraps the same Linear instance, so both "fc1.*" and "mlp.0.*"
        # appear in state_dict; kept as-is so saved checkpoints still load.
        self.mlp = nn.Sequential(
            self.fc1
        )
        # Joint count is fixed by the 25-joint skeleton hard-coded in GraphEnt.
        num_point = 25
        # forward() flattens (M, V, C) into the BatchNorm1d channel axis, so the
        # width is M * C * V (2 * 3 * 25 = 150 for the defaults with dim_in=3).
        self.data_bn = nn.BatchNorm1d(num_person * dim_in * num_point)
        nn.init.normal_(self.fc1.weight, 0, math.sqrt(2. / num_class))
        bn_init(self.data_bn, 1)

    def forward(self, x):
        """Classify a batch of skeleton sequences.

        Args:
            x: Tensor unpacked as ``(N, C, T, V, M)``.

        Returns:
            Logits of shape ``(N, num_class)``.
        """
        N, C, T, V, M = x.size()
        dims = x.size()
        # Normalise every (person, joint, channel) series over the batch.
        x = x.permute(0, 4, 3, 1, 2).contiguous().view(N, M * V * C, T)
        x = self.data_bn(x)
        # Back to one feature map per body: (N*M, C, T, V).
        x = x.view(N, M, V, C, T).contiguous().view(N * M, V, C, T).permute(0, 2, 3, 1).contiguous()
        x = self.l1(x, dims)
        x = self.l2(x, dims)
        x = self.l3(x, dims)
        x = self.l4(x, dims)
        x = self.l5(x, dims)
        _, C, T, V = x.size()
        # (N, M, C, T*V): average over time/joints, then over bodies.
        x = x.view(N, M, C, -1)
        x = x.mean(3).mean(1)
        x = self.mlp(x)
        return x

In [9]:
def train(epoch):
    """Run one optimisation pass over the training split.

    Uses the notebook globals ``model``, ``device``, ``lossC`` and
    ``optimizer``; a new loader is created through ``load_data('train')`` on
    every call.

    Args:
        epoch: Current epoch number; accepted but not used in the body.

    Returns:
        Tuple ``(mean loss, mean batch accuracy in percent)``; both are
        NaN-ignoring means over the batches.
    """
    model.train()
    batch_losses, batch_accs = [], []
    progress = tqdm(load_data('train'), ncols=80)
    for data, label, index in progress:
        # Casting and device transfer need no autograd bookkeeping.
        with torch.no_grad():
            data = data.float().to(device)
            label = label.long().to(device)

        out = model(data)
        loss = lossC(out, target=label)

        # Standard update; gradient clipping is intentionally not applied.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.data.item())
        predict_label = torch.max(out.data, 1)[1]
        hit_rate = torch.mean((predict_label == label.data).float())
        batch_accs.append(hit_rate.data.item())

    mean_loss = np.nanmean(batch_losses)
    mean_acc = np.nanmean(batch_accs)*100
    return mean_loss, mean_acc

In [10]:
def evalt(wrong_file = None, result_file = None, model_name = 'Model', save=True, loader=load_data('test')):
    """Evaluate the global ``model`` and optionally checkpoint the best weights.

    Uses the notebook globals ``model``, ``device``, ``lossC`` and
    ``best_epoch_acc``.

    Args:
        wrong_file: If given, path that receives one ``index,pred,true`` line
            per misclassified sample.
        result_file: If given, path that receives one ``pred,true`` line per
            sample.
        model_name: Stem of the checkpoint written to ``mainruns/``.
        save: When true and top-1 accuracy beats ``best_epoch_acc``, save the
            state dict and update ``best_epoch_acc``.
        loader: Data loader to evaluate on. The default is evaluated ONCE, when
            this ``def`` statement runs, so the same test loader is reused by
            every call that does not pass one.

    Returns:
        Tuple ``(top5_acc, mean_loss, top1_acc)`` with accuracies in percent.
    """
    from contextlib import ExitStack

    global best_epoch_acc

    model.eval()
    test_loader = loader
    loss_value = []
    score_frag = []

    # Fix: the report files used to be opened and never closed, so buffered
    # lines could stay unwritten; ExitStack closes whichever were opened, even
    # if evaluation raises.
    with ExitStack() as stack:
        f_w = stack.enter_context(open(wrong_file, 'w')) if wrong_file is not None else None
        f_r = stack.enter_context(open(result_file, 'w')) if result_file is not None else None

        process = tqdm(test_loader, ncols=80)
        for batch_idx, (data, label, index) in enumerate(process):
            with torch.no_grad():
                data = data.float().to(device)
                label = label.long().to(device)
                out = model(data)
                loss = lossC(out, target=label)
                score_frag.append(out.data.cpu().numpy())
                loss_value.append(loss.data.item())
                # Computed once (the original ran the same torch.max twice).
                _, predict_label = torch.max(out.data, 1)
            if f_w is not None or f_r is not None:
                predict = list(predict_label.cpu().numpy())
                true = list(label.data.cpu().numpy())
                for i, x in enumerate(predict):
                    if f_r is not None:
                        f_r.write(str(x) + ',' + str(true[i]) + '\n')
                    if x != true[i] and f_w is not None:
                        # NOTE(review): if `index` is a tensor this writes
                        # "tensor(k)" rather than "k" - confirm desired format.
                        f_w.write(str(index[i]) + ',' + str(x) + ',' + str(true[i]) + '\n')

    score = np.concatenate(score_frag)
    loss = np.nanmean(loss_value)
    top5_acc = test_loader.dataset.top_k(score, 5) * 100
    top1_acc = test_loader.dataset.top_k(score, 1) * 100
    if save and top1_acc > best_epoch_acc:
        state_dict = model.state_dict()
        # Make sure the target directory exists before the first save.
        os.makedirs('mainruns', exist_ok=True)
        torch.save(state_dict, f'mainruns/{model_name}.pt')
        best_epoch_acc = top1_acc

    print('\tTop{}: {:.2f}%'.format(
            1, top1_acc))
    print('\tTop{}: {:.2f}%'.format(
            5, top5_acc))
    return top5_acc, loss, top1_acc

In [11]:
def init_weights(m):
    """Kaiming-normal (fan_out) init for ``Conv2d`` / ``Linear`` modules.

    The weight is re-drawn and the bias, when the layer has one, is zeroed.
    Modules of any other type are left untouched.
    """
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return
    nn.init.kaiming_normal_(m.weight, mode='fan_out')
    if m.bias is None:
        return
    nn.init.constant_(m.bias, 0)

def adjust_learning_rate(epoch):
    """Set the optimizer's learning rate for ``epoch`` and return it.

    Uses the notebook globals ``base_lr``, ``warm_epochs``, ``lr_decay_rate``,
    ``step``, ``sched``, ``optimizer`` and, for ``sched == 'plat'``,
    ``scheduler`` and ``tst`` (the previous evaluation result, whose element 1
    is the test loss). The global ``lr`` is updated to the returned value.

    * ``epoch <= warm_epochs``: linear warm-up ``base_lr * epoch / warm_epochs``.
    * ``sched == 'plat'``: step the plateau scheduler with the last test loss.
    * ``sched == 'custom'``: ``base_lr * lr_decay_rate ** k`` where ``k`` is
      the number of milestones in ``step`` that ``epoch`` has reached.

    Args:
        epoch: Epoch number, starting at 1.

    Returns:
        The learning rate in effect after the adjustment.
    """
    global lr, warm_epochs, lr_decay_rate, step
    if epoch <= warm_epochs:
        lr = base_lr * epoch / warm_epochs
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    else:
        if sched == 'plat':
            scheduler.step(tst[1])
            # Fix: pick up the rate the scheduler actually set; previously the
            # function kept returning the last warm-up value.
            lr = optimizer.param_groups[0]['lr']
        elif sched == 'custom':
            lr = base_lr * (lr_decay_rate ** np.sum(epoch >= np.array(step)))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
    return lr

In [None]:
# Seed everything before any weights are created.
init_seed(2)

torch.cuda.empty_cache()
# Debug switches (disabled):
#torch.autograd.set_detect_anomaly(True)
#import os
#os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# ---- run configuration ----
phase = 'train'        # 'train' or 'test'
optimz = 'SGD'         # 'SGD', 'Adam', 'NAdam' or 'AdamW'
warm_epochs = 5
epochs = 140
out_channels = 192
weight_decay = 1e-4
if optimz in ('Adam', 'AdamW', 'NAdam'):
    base_lr = 0.0025
    lr = base_lr
    lr_decay_rate = 0.2
    step = [10, 25, 40, 60]
    sched = 'plat'
elif optimz == 'SGD':
    base_lr = 0.01
    lr = base_lr
    lr_decay_rate = 0.2
    step = [60, 100]
    sched = 'custom'
else:
    # Fail here with a clear message instead of a NameError further down.
    raise ValueError(f'Unsupported optimizer: {optimz!r}')

use_wandb = False

# ---- device and model ----
# GPU 2 is preferred, as before; hosts with fewer GPUs fall back to GPU 0
# instead of failing on an invalid device ordinal.
preferred_gpu = 2
if torch.cuda.is_available():
    gpu_index = preferred_gpu if torch.cuda.device_count() > preferred_gpu else 0
    device = torch.device(f'cuda:{gpu_index}')
else:
    device = torch.device('cpu')
model = GraphModel(3, out_channels).to(device)
# model.apply(init_weights)
# Multi-GPU variant (disabled):
# devices = [2, 3]
# device = devices[0]
# model = GraphModel(3, out_channels)
# model = nn.DataParallel(model,
#                         device_ids=devices,
#                         output_device=device)
# model.to(device)
# model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

# ---- optimizer ----
if optimz == 'SGD':
    optimizer = optim.SGD(
                model.parameters(),
                momentum=0.9,
                lr=base_lr,
                nesterov = True,
                weight_decay = weight_decay)
elif optimz == 'Adam':
    # Note: unlike the other optimizers, Adam is built without weight decay.
    optimizer = optim.Adam(model.parameters(), lr=base_lr)
elif optimz == 'NAdam':
    optimizer = optim.NAdam(model.parameters(), lr=base_lr, weight_decay=weight_decay)
elif optimz == 'AdamW':
    optimizer = optim.AdamW(model.parameters(), lr=base_lr, weight_decay=weight_decay)
def count_parameters(model):
    """Return the total element count of the parameters that require grad."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
# Disabled helper kept for reference: prints the size of every trainable tensor.
# for name, parameter in model.named_parameters():
#     if not parameter.requires_grad:
#         continue
#     params = parameter.numel()
#     print(name, params)
lossC = nn.CrossEntropyLoss().to(device)
# scaler = torch.amp.GradScaler('cuda', enabled=use_amp)
if sched == 'plat':
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=lr_decay_rate, patience=5)
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size = 20, gamma=0.7)
if phase == 'train':
    if use_wandb:
        wandb.login()
        wandb.init(
            # set the wandb project where this run will be logged
            project="spanchsan-hong-kong-polytechnic-university",

            # track hyperparameters and run metadata
            config={
            "learning_rate": lr,
            "architecture": "Mamba",
            "dataset": "NTU RGB+D 60",
            "epochs": epochs,
            "out_channels": out_channels,
            # Fix: load_data builds its loaders with batch_size=32; the run
            # metadata used to record 64. Keep this in sync with load_data.
            "batch_size": 32,
            "weight_decay": weight_decay,
            "lr_decay_rate": lr_decay_rate,
            "optimizer": optimz,
            "scheduler": sched
            }
        )
    print("Parameters:", count_parameters(model))
    print(model)
    # Best top-1 accuracy so far; read and updated by evalt when save=True.
    best_epoch_acc = 0
    for epoch in range(1, epochs+1):
        adjust_learning_rate(epoch)
        trn = train(epoch)
        # tst = (top5_acc, test_loss, top1_acc)
        if use_wandb:
            # Checkpoints are only written for tracked runs.
            tst = evalt(model_name = f'Model{out_channels}_{wandb.run.name}')
            wandb.log({"train_loss": trn[0], "train_acc": trn[1], "test_acc": tst[0], "test_acc_1st": tst[2], "test_loss": tst[1]})
        else:
            tst = evalt(save = False)
        current_lr = optimizer.param_groups[0]['lr']
        print(f'Epoch: {epoch:02d}, Loss: {trn[0]:.4f}, Train Acc: {trn[1]:.4f}, \
        Accuracy: {tst[2]:.4f}, Test_Loss: {tst[1]:.4f}, Current_Lr: {current_lr:.6f}')
elif phase == 'test':
    print("Parameters:", count_parameters(model))
    weights_path = 'mainruns/Model192_winter-water-141.pt'
    model.load_state_dict(torch.load(weights_path, weights_only=True))
    # Per-sample reports are written next to the checkpoint.
    wf = weights_path.replace('.pt', '_wrong.txt')
    rf = weights_path.replace('.pt', '_right.txt')
    loader = load_data(phase='test', data='S')
    evalt(wrong_file=wf, result_file=rf, save=False, loader=loader)

Parameters: 2400936
GraphModel(
  (l1): GraphTCN(
    (conv): GraphEnt(
      (conv): GPSConv(192, conv=SpatialGraphCov(
        (gcn): Conv2d(192, 576, kernel_size=(1, 1), stride=(1, 1))
      ), heads=4)
      (norm): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
      (t): unit_tcn(
        (conv): Conv2d(3, 192, kernel_size=(1, 1), stride=(1, 1))
        (bn): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (pe_proj): Conv2d(3, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
    )
    (tcn): MultiScale_TemporalConv(
      (branches): ModuleList(
        (0): Sequential(
          (0): Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1))
          (1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): TemporalConv(
            (conv): Conv2d(48, 48, kernel_size=(5, 1), stride=(1, 1), padding=(2, 0))
            (bn): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_runnin

 42%|████████████████▉                       | 530/1253 [02:41<03:38,  3.31it/s]

In [None]:
        # NOTE(review): scratch cell - nothing below runs as code. It holds two
        # commented-out attribute definitions and a string literal containing an
        # alternative hyper-edge embedding that first permutes x with
        # (0, 3, 1, 2) and finishes with self.norm. It appears to be a
        # channel-last variant of the embedding computed in GraphEnt.forward -
        # confirm before deleting.
        #self.joint_label = [0, 4, 2, 2, 2, 2, 1, 1, 2, 2, 1, 1, 2, 3, 3, 3, 2, 3, 3, 3, 1, 0, 1, 0, 1]
        #self.pe_proj = nn.Conv2d(dim_in, dim, 1, bias=False)
'''
        # HyperEdge
        x0 = x.permute(0, 3, 1, 2) 
        label = F.one_hot(torch.tensor(self.joint_label)).float().to(x.device)
        z = x0 @ (label / label.sum(dim=0, keepdim=True))
        z = self.pe_proj(z).permute(3, 0, 1, 2)
        e = z[self.joint_label].permute(1, 3, 0, 2).contiguous()
        # Sum
        ot = x+e
        ot = self.norm(ot)'''