In [None]:
from __future__ import print_function
import inspect
import os
import pickle
import random
import shutil
import sys
import time
import math
from collections import OrderedDict
import traceback
from sklearn.metrics import confusion_matrix
import csv
import numpy as np
import glob
import torch.nn.functional as F
import torch
import torch.nn as nn
import torch.optim as optim
import yaml
from tqdm import tqdm
from feeders.feeder_ntu import Feeder
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR, LinearLR
from ptflops import get_model_complexity_info
import scipy.sparse as sp

import torch.nn.functional as F
from torch import Tensor
from torch.nn import Dropout, Linear, Sequential

from mamba_ssm import Mamba2

import wandb

In [None]:
def init_seed(seed):
    """Seed every RNG used by the notebook and pin cuDNN to deterministic behaviour.

    Args:
        seed: integer applied to PyTorch's CUDA (all devices) and CPU generators,
            NumPy's global RNG and Python's ``random`` module.
    """
    # Framework generators first, then the library-level ones.
    torch.cuda.manual_seed_all(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Ask cuDNN for deterministic algorithms and turn off benchmark mode
    # (algorithm auto-selection by timing). torch.backends.cudnn.enabled is
    # intentionally left at its default.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [None]:
def conv_init(conv):
    """Initialise a convolution: Kaiming-normal weights (fan_out mode), zero bias."""
    weight, bias = conv.weight, conv.bias
    nn.init.kaiming_normal_(weight, mode='fan_out')
    nn.init.constant_(bias, 0)


def bn_init(bn, scale):
    """Set a batch-norm layer's affine weight to ``scale`` and its bias to zero."""
    for param, value in ((bn.weight, scale), (bn.bias, 0)):
        nn.init.constant_(param, value)

def weights_init(m):
    """Initialise one module in place, dispatching on its class name (for ``Module.apply``).

    * class name contains 'Conv': Kaiming-normal (fan_out) weight, zero tensor bias;
    * class name contains 'BatchNorm': weight drawn from N(1, 0.02), zero bias;
    * anything else is left untouched.

    The attribute checks let container modules whose names contain 'Conv' but
    that own no ``weight``/``bias`` pass through without error.
    """
    name = type(m).__name__

    if 'Conv' in name:
        if hasattr(m, 'weight'):
            nn.init.kaiming_normal_(m.weight, mode='fan_out')
        conv_bias = getattr(m, 'bias', None)
        if conv_bias is not None and isinstance(conv_bias, torch.Tensor):
            nn.init.constant_(conv_bias, 0)
        return

    if 'BatchNorm' in name:
        bn_weight = getattr(m, 'weight', None)
        if bn_weight is not None:
            bn_weight.data.normal_(1.0, 0.02)
        bn_bias = getattr(m, 'bias', None)
        if bn_bias is not None:
            bn_bias.data.fill_(0)

class TemporalConv(nn.Module):
    """Dilated convolution along the time axis only, followed by BatchNorm2d.

    Expects (N, C, T, V) inputs, as produced by the callers in this notebook.
    The kernel is (kernel_size, 1), so every vertex is convolved independently
    over time. Padding is chosen so that, with stride 1 and an odd kernel, the
    temporal length is preserved.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1):
        super(TemporalConv, self).__init__()
        # A dilated kernel spans dilation*(k-1)+1 frames; pad half of the
        # extra frames on each side.
        effective_span = kernel_size + (kernel_size - 1) * (dilation - 1)
        pad = (effective_span - 1) // 2
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(kernel_size, 1),
            padding=(pad, 0),
            stride=(stride, 1),
            dilation=(dilation, 1),
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        return self.bn(self.conv(x))

class MultiScale_TemporalConv(nn.Module):
    """Multi-branch temporal convolution block over (N, C, T, V) inputs.

    ``out_channels`` is split evenly over ``len(dilations) + 2`` branches:
      * one per dilation: 1x1 conv -> BN -> ReLU -> TemporalConv(kernel, dilation);
      * a max-pool branch: 1x1 conv -> BN -> ReLU -> 3x1 max-pool -> BN;
      * a strided 1x1 conv -> BN branch.
    Branch outputs are concatenated on the channel axis and a residual is added.

    Args:
        in_channels, out_channels: channel counts; ``out_channels`` must be a
            multiple of the number of branches.
        kernel_size: temporal kernel, one int for every dilated branch or a list
            with one entry per dilation.
        stride: temporal stride used by every branch.
        dilations: one dilated branch is built per entry.
        residual: False -> nothing is added; True -> identity when shapes match,
            otherwise a strided ``TemporalConv`` projection.
        residual_kernel_size: kernel of that projection.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=5,
                 stride=1,
                 dilations=[1, 2],
                 residual=False,
                 residual_kernel_size=1):

        super().__init__()
        # Dilated branches + one max-pool branch + one plain 1x1 branch.
        self.num_branches = len(dilations) + 2
        assert out_channels % self.num_branches == 0, '# out channels should be multiples of # branches'
        branch_channels = out_channels // self.num_branches

        if type(kernel_size) is list:
            assert len(kernel_size) == len(dilations)
            kernel_sizes = kernel_size
        else:
            kernel_sizes = [kernel_size] * len(dilations)

        def dilated_branch(ks, dilation):
            # Channel reduction, then a dilated temporal convolution.
            return nn.Sequential(
                nn.Conv2d(in_channels, branch_channels, kernel_size=1, padding=0),
                nn.BatchNorm2d(branch_channels),
                nn.ReLU(inplace=True),
                TemporalConv(branch_channels, branch_channels,
                             kernel_size=ks, stride=stride, dilation=dilation),
            )

        # Submodules are created in the same order as before, which fixes both the
        # state_dict keys (branches.0 .. branches.N) and the RNG draws of default init.
        branch_list = [dilated_branch(ks, d) for ks, d in zip(kernel_sizes, dilations)]
        # Max-pool branch; a second BatchNorm is applied to the pooled output.
        branch_list.append(nn.Sequential(
            nn.Conv2d(in_channels, branch_channels, kernel_size=1, padding=0),
            nn.BatchNorm2d(branch_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(3, 1), stride=(stride, 1), padding=(1, 0)),
            nn.BatchNorm2d(branch_channels),
        ))
        # Plain strided 1x1 branch.
        branch_list.append(nn.Sequential(
            nn.Conv2d(in_channels, branch_channels, kernel_size=1, padding=0, stride=(stride, 1)),
            nn.BatchNorm2d(branch_channels),
        ))
        self.branches = nn.ModuleList(branch_list)

        if not residual:
            self.residual = lambda x: 0
        elif in_channels == out_channels and stride == 1:
            self.residual = lambda x: x
        else:
            self.residual = TemporalConv(in_channels, out_channels,
                                         kernel_size=residual_kernel_size, stride=stride)
        self.apply(weights_init)

    def forward(self, x):
        # x: (N, C, T, V)
        shortcut = self.residual(x)
        out = torch.cat([branch(x) for branch in self.branches], dim=1)
        out += shortcut
        return out

class unit_tcn(nn.Module):
    """Temporal convolution with kernel (kernel_size, 1) and temporal stride, then BN.

    In this notebook it is used with ``kernel_size=1`` as a channel projection /
    strided shortcut. Conv weights are Kaiming-initialised with zero bias; the
    BatchNorm starts with weight 1 and bias 0.
    """

    def __init__(self, in_channels, out_channels, kernel_size=5, stride=1):
        super(unit_tcn, self).__init__()
        half_width = int((kernel_size - 1) / 2)
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(kernel_size, 1),
            padding=(half_width, 0),
            stride=(stride, 1),
            groups=1,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        conv_init(self.conv)
        bn_init(self.bn, 1)

    def forward(self, x):
        y = self.conv(x)
        return self.bn(y)


In [None]:
def gdc(A: sp.csr_matrix, alpha: float, eps: float):
    """Graph Diffusion Convolution with a personalized-PageRank (PPR) kernel.

    Adds self-loops to ``A``, builds the symmetric transition matrix
    ``T_sym = D^-1/2 (A + I) D^-1/2``, computes the PPR diffusion
    ``S = alpha * (I - (1 - alpha) * T_sym)^-1``, drops entries below ``eps``
    and column-normalises the result.

    Args:
        A: (N, N) sparse adjacency matrix (non-negative weights assumed).
        alpha: teleport probability of the PPR kernel, in (0, 1].
        eps: sparsification threshold; entries of S smaller than this are dropped.

    Returns:
        (N, N) ``scipy.sparse.csr_matrix`` whose non-empty columns sum to 1, so
        callers can use ``.toarray()``. Columns that are empty after thresholding
        stay all-zero instead of becoming NaN.
    """
    # Explicit import: `scipy.sparse.linalg` is a subpackage that
    # `import scipy.sparse` is not guaranteed to load.
    from scipy.sparse.linalg import inv as sparse_inv

    N = A.shape[0]

    # Self-loops
    A_loop = sp.eye(N, format='csr') + A

    # Symmetric transition matrix
    D_loop_vec = np.asarray(A_loop.sum(0)).ravel()
    D_loop_invsqrt = sp.diags(1 / np.sqrt(D_loop_vec))
    T_sym = D_loop_invsqrt @ A_loop @ D_loop_invsqrt

    # PPR-based diffusion; the sparse inverse wants CSC input.
    S = alpha * sparse_inv((sp.eye(N) - (1 - alpha) * T_sym).tocsc())

    # Sparsify using threshold epsilon
    S_tilde = sp.csr_matrix(S.multiply(S >= eps))

    # Column-normalise by right-multiplying with a diagonal: this keeps the
    # result sparse (sparse / dense-vector would return a dense np.matrix) and
    # lets zero-sum columns stay zero instead of dividing by zero.
    D_tilde_vec = np.asarray(S_tilde.sum(0)).ravel().astype(float)
    inv_D_tilde = np.divide(1.0, D_tilde_vec,
                            out=np.zeros_like(D_tilde_vec),
                            where=D_tilde_vec != 0)
    T_S = S_tilde @ sp.diags(inv_D_tilde)

    return sp.csr_matrix(T_S)

def edge2mat(link, num_node):
    """Dense (num_node, num_node) 0/1 adjacency from an iterable of ``(i, j)`` edges.

    Edge ``(i, j)`` sets ``A[j, i] = 1`` (row = second element, column = first).
    """
    pairs = list(link)
    adjacency = np.zeros((num_node, num_node))
    if pairs:
        rows = [dst for _, dst in pairs]
        cols = [src for src, _ in pairs]
        adjacency[rows, cols] = 1
    return adjacency

def normalize_digraph( A):
    """Normalise every column of a 2-D ``A`` by its sum: ``A @ diag(1 / colsum)``.

    Columns whose sum is not strictly positive become all-zero rather than
    being divided.
    """
    _, width = A.shape
    col_sums = np.sum(A, 0)
    scale = np.zeros((width, width))
    positive_cols = np.flatnonzero(col_sums > 0)
    scale[positive_cols, positive_cols] = col_sums[positive_cols] ** (-1)
    return np.dot(A, scale)

def get_spatial_graph( num_node, self_link, inward, outward):
    """Stack the three spatial partitions into a (3, num_node, num_node) array.

    Order: self-links (not normalised), inward edges (column-normalised),
    outward edges (column-normalised).
    """
    identity = edge2mat(self_link, num_node)
    directed = [normalize_digraph(edge2mat(edges, num_node)) for edges in (inward, outward)]
    return np.stack([identity] + directed)

def construct_incidence_matrix(A):
    """Vertex-edge incidence matrix and edge-adjacency (line-graph) matrix of ``A``.

    Only the strict upper triangle of ``A`` is scanned, so ``A`` is treated as
    undirected and self-loops are ignored. Edges are numbered row-major over
    ``(i, j)`` with ``i < j``.

    Args:
        A: (V, V) array-like adjacency; any non-zero entry marks an edge.

    Returns:
        I: float tensor (V, E); ``I[v, e] = 1`` iff vertex ``v`` is an endpoint of edge ``e``.
        B: float tensor (E, E); ``B[a, b] = 1`` iff edges ``a`` and ``b`` share at
           least one endpoint (hence the diagonal is all ones).
    """
    v = A.shape[0]
    edge_list = [(i, j) for i in range(v) for j in range(i + 1, v) if A[i, j] != 0]
    num_edges = len(edge_list)

    I = torch.zeros((v, num_edges))
    for e, (i, j) in enumerate(edge_list):
        I[i, e] = 1
        I[j, e] = 1

    # (I^T I)[a, b] counts the endpoints edges a and b have in common, so it is
    # positive exactly when they share a vertex. Using a distinct name for the
    # edge count keeps every edge (including the last) in B.
    B = (I.t() @ I > 0).to(I.dtype)
    return I, B

In [None]:
class ParalleGraphCov(nn.Module):
    """Multi-head graph mixing with learnable per-head vertex matrices.

    For each of 3 subsets, the (N, C, T, V) input is split into ``num_heads`` (8)
    channel groups. Each group's vertex axis is multiplied by its own learnable
    (V, V) matrix (``fc1``, initialised to identity). A grouped 1x1 convolution
    (``fc2``) then maps the result to ``out_channels``, and the 3 subset outputs
    are summed. Because of the grouped convolution, both channel counts must be
    divisible by ``num_heads``.

    ``A`` is kept as ``self.A`` and moved to the input's device on every call,
    but it does not take part in the computation.
    """

    def __init__(self, in_channels, out_channels, A, **kwargs):
        super(ParalleGraphCov, self).__init__()
        self.A = torch.tensor(A, dtype=torch.float32, requires_grad=False)
        self.num_heads = 8
        num_subsets = 3
        self.fc2 = nn.ModuleList(
            nn.Conv2d(in_channels, out_channels, 1, groups=self.num_heads)
            for _ in range(num_subsets)
        )
        num_vertices = A.shape[-1]
        # (subsets, heads, V, V), every slice starting as the identity.
        identity_stack = torch.eye(num_vertices).repeat(num_subsets, self.num_heads, 1, 1)
        self.fc1 = nn.Parameter(identity_stack, requires_grad=True)

    def forward(self, x):
        self.A = self.A.to(x.device)
        N, C, T, V = x.size()
        # Split channels into heads once; the same view serves every subset.
        heads = x.view(N, self.num_heads, -1, T, V)
        total = None
        for vertex_mix, projection in zip(self.fc1, self.fc2):
            mixed = torch.einsum("nhctv, hvw->nhctw", (heads, vertex_mix)).contiguous().view(N, -1, T, V)
            projected = projection(mixed)
            total = projected if total is None else projected + total
        return total

class STGC(nn.Module):
    """Spatial graph convolution over K fixed adjacency subsets.

    A 1x1 convolution expands the channels to ``out_channels * K``
    (K = ``A.shape[0]``). The result is split into K groups, and each group is
    aggregated along the vertex axis with its adjacency ``A[k]``; the K results
    are summed.

    Args:
        A: (K, V, V) array of fixed adjacency matrices. Stored as a plain tensor
           attribute, so it is neither trained nor saved in the state dict.
    """

    def __init__(self, in_channels, out_channels, A, **kwargs):
        super(STGC, self).__init__()
        self.num_subset = A.shape[0]
        self.gcn = nn.Conv2d(in_channels, out_channels * self.num_subset, kernel_size=1)
        self.A = torch.as_tensor(A, dtype=torch.float32).clone().detach()

    def forward(self, x):
        self.A = self.A.to(x.device)
        N, _, T, V = x.size()
        # (N, C_in, T, V) -> (N, K, C_out, T, V)
        per_subset = self.gcn(x).view(N, self.num_subset, -1, T, V)
        # Sum over k of X_k @ A_k along the vertex axis.
        return torch.einsum('nkctv,kvw->nctw', per_subset, self.A)
        

class GCNMamba(torch.nn.Module):
    """Spatial graph stage (the Mamba2 sequence model this was named after is disabled).

    Input (N*M, C, T, V) is regrouped so that the M persons of each sample are
    stacked along the time axis, giving (N, C, M*T, V). It then goes through an
    STGC graph convolution and BatchNorm and is split back to (N*M, C_out, T, V).

    Args:
        dim_in: input channels of the graph convolution.
        dim: output channels. With ``line_repr=True`` the input must already have
            ``dim`` channels, because the learnable edge mixing is per channel.
        A: for ``line_repr=False`` a (K, V, V) stack of adjacency subsets. For
            ``line_repr=True`` a single (E, E) edge-adjacency matrix, whose values
            are also used as integer row indices into ``learnableA``.
        line_repr: whether the input lives on the line graph (edges as vertices).
        d_state, d_conv: stored and shown by ``__repr__`` only.
    """

    def __init__(
        self,
        dim_in,
        dim,
        A,
        line_repr,
        d_state: int = 16,
        d_conv: int = 4,
    ):
        super().__init__()
        self.dim = dim
        self.d_state = d_state
        self.d_conv = d_conv
        self.line_repr = line_repr
        self.norm1 = nn.BatchNorm2d(dim)
        if line_repr:
            # The single edge-adjacency matrix is wrapped as one subset.
            self.conv = STGC(dim_in, dim, np.array([A]))
            self.learnableA = nn.Parameter(torch.zeros((A.shape[1], dim)))
            # as_tensor avoids the copy-construct warning when A is already a tensor.
            self.A = torch.as_tensor(A).clone().detach().long()
        else:
            self.conv = STGC(dim_in, dim, A)

    def forward(self, x, dims) -> Tensor:
        """x: (N*M, C, T, V); dims: original input size (N, C, T, V, M)."""
        _, C, T, V = x.size()
        N, _, _, _, M = dims
        if self.line_repr:
            # (dim, E, E): per-channel mixing matrix looked up from learnableA.
            learn_repr = self.learnableA[self.A].permute(2, 0, 1)
            x = torch.einsum("bctv, cvw -> bctw", x, learn_repr)
        # (N*M, C, T, V) -> (N, M, C, T, V) -> (N, C, M, T, V) -> (N, C, M*T, V).
        # A bare reshape to (N, C, T*M, V) would reinterpret memory and mix the
        # features of different persons inside one "channel".
        x = x.reshape(N, M, C, T, V).permute(0, 2, 1, 3, 4).reshape(N, C, M * T, V)
        x = self.conv(x)
        x = self.norm1(x)
        # Exact inverse of the regrouping, using the conv's output channel count.
        C_out = x.size(1)
        x = x.reshape(N, C_out, M, T, V).permute(0, 2, 1, 3, 4).reshape(N * M, C_out, T, V)
        return x

    def construct_rpe_hops(self, h1):
        """Shortest-hop matrix and a zero-initialised relative positional embedding.

        Args:
            h1: (V, V) 0/1 numpy one-hop matrix. It is expected to include
                self-loops, so that repeated products give "reachable within i hops".

        Returns:
            hops: (V, V) long tensor; ``hops[i, j]`` is the hop distance between
                i and j (0 on the diagonal and for pairs never reached).
            rpe: ``nn.Parameter`` of zeros with shape (hops.max() + 1, self.dim).
        """
        num_nodes = h1.shape[0]
        h = [None] * max(num_nodes, 2)
        h[0] = np.eye(num_nodes)
        h[1] = h1
        hops = 0 * h[0]
        # h[i]: 0/1 reachability within i hops.
        for i in range(2, num_nodes):
            h[i] = h[i - 1] @ h1.transpose(0, 1)
            h[i][h[i] != 0] = 1

        # Walking downwards, h[i] - h[i-1] isolates pairs first reached at exactly i hops.
        for i in range(num_nodes - 1, 0, -1):
            if np.any(h[i] - h[i - 1]):
                h[i] = h[i] - h[i - 1]
                hops += i * h[i]

        hops = torch.tensor(hops).long()
        rpe = nn.Parameter(torch.zeros((int(hops.max()) + 1, self.dim)))
        return hops, rpe

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.dim}, '
                f'd_state={self.d_state}, d_conv={self.d_conv})')

In [None]:
class GraphEnt(nn.Module):
    """Channel projection, then the graph stage, then ReLU.

    Pipeline: ``unit_tcn`` (1x1, dim_in -> dim) -> ``GCNMamba`` (dim -> dim) -> ReLU.
    Tensors are laid out as (N*M, C, T, V); ``dims`` is the original 5-D input
    size (N, C, T, V, M), which ``GCNMamba`` needs to regroup persons.
    """

    def __init__(self, dim_in, dim, A, line_repr):
        super().__init__()
        # Keep this construction order: it determines the RNG draws used for
        # default parameter init and the state_dict key order.
        self.conv = GCNMamba(dim, dim, A, line_repr, d_state=64, d_conv=4)
        self.act = nn.ReLU(inplace=True)
        self.dim_up = unit_tcn(dim_in, dim, kernel_size=1)

    def forward(self, x, dims):
        projected = self.dim_up(x)
        mixed = self.conv(projected, dims).contiguous()
        return self.act(mixed)

In [None]:
class GraphTCN(nn.Module):
    """One spatio-temporal layer: graph stage -> multi-scale TCN, plus a shortcut, then ReLU.

    The shortcut is the identity when the channel count and stride are
    unchanged; otherwise it is a strided 1x1 ``unit_tcn`` projection.
    """

    def __init__(self, dim_in, dim, A, stride=1, line_repr = False):
        super().__init__()
        self.dim_in, self.dim, self.line_repr = dim_in, dim, line_repr
        self.conv = GraphEnt(dim_in, dim, A, line_repr)
        # residual=True inside the TCN performed worse in the end, so it is off.
        self.tcn = MultiScale_TemporalConv(dim, dim, kernel_size=5, stride=stride,
                                           dilations=[1, 2], residual=False)
        self.act = nn.ReLU(inplace=True)
        needs_projection = dim_in != dim or stride != 1
        if needs_projection:
            self.residual = unit_tcn(dim_in, dim, kernel_size=1, stride=stride)
        else:
            self.residual = lambda x: x

    def forward(self, x, dims):
        out = self.tcn(self.conv(x, dims))
        out = out + self.residual(x)
        return self.act(out)

In [None]:
class GraphModel(nn.Module):
    """Two-stream skeleton classifier on the 25-joint NTU graph and its line graph.

    The joint stream (``l1``..``l10``) runs on the skeleton graph. The edge stream
    (``r1``..``r10``) runs on the line graph (edges as vertices), built from the
    incidence matrix ``I_edge``. After every edge layer, its output is mapped back
    to joints with ``I_edge`` and added into the joint stream. The result is
    averaged over time, joints and persons and fed to a linear classifier.

    Args:
        dim_in: input channels per joint (3 in the run script).
        dim: base width; layers widen to 2*dim and 4*dim with temporal stride 2.
        num_class: number of output classes (60 for NTU RGB+D 60).
        num_person: persons per sample M, used to size the input BatchNorm.
    """

    def __init__(self, dim_in, dim, num_class=60, num_person=2):
        super().__init__()
        num_point = 25

        # line_repr and A
        self_link = [(i, i) for i in range(num_point)]
        inward_ori_index = [(1, 2), (2, 21), (3, 21), (4, 3), (5, 21), (6, 5), (7, 6),
                    (8, 7), (9, 21), (10, 9), (11, 10), (12, 11), (13, 1),
                    (14, 13), (15, 14), (16, 15), (17, 1), (18, 17), (19, 18),
                    (20, 19), (22, 23), (23, 8), (24, 25), (25, 12)]
        inward = [(i - 1, j - 1) for (i, j) in inward_ori_index]
        outward = [(j, i) for (i, j) in inward]
        self.neighbor = inward + outward
        self.A = get_spatial_graph(num_point, self_link, inward, outward)
        # Binary one-hop adjacency (self-loops included) of the skeleton.
        h1 = self.A.sum(0)
        h1[h1 != 0] = 1
        # Optional diffusion variant: h1 = gdc(sp.csr_matrix(h1), 0.15, 5e-2).toarray()
        self.I_edge, B = construct_incidence_matrix(h1)

        # Joint stream. Construction order is kept so that state_dict keys and
        # default-init RNG draws match existing checkpoints.
        self.l1 = GraphTCN(dim_in, dim, self.A)
        self.l2 = GraphTCN(dim, dim, self.A)
        self.l3 = GraphTCN(dim, dim, self.A)
        self.l4 = GraphTCN(dim, dim, self.A)
        self.l5 = GraphTCN(dim, dim*2, self.A, stride=2)
        self.l6 = GraphTCN(dim*2, dim*2, self.A)
        self.l7 = GraphTCN(dim*2, dim*2, self.A)
        self.l8 = GraphTCN(dim*2, dim*4, self.A, stride=2)
        self.l9 = GraphTCN(dim*4, dim*4, self.A)
        self.l10 = GraphTCN(dim*4, dim*4, self.A)

        # Edge (line-graph) stream
        self.r1 = GraphTCN(dim_in, dim, B, line_repr=True)
        self.r2 = GraphTCN(dim, dim, B, line_repr=True)
        self.r3 = GraphTCN(dim, dim, B, line_repr=True)
        self.r4 = GraphTCN(dim, dim, B, line_repr=True)
        self.r5 = GraphTCN(dim, dim, B, line_repr=True)
        self.r6 = GraphTCN(dim, dim*2, B, stride = 2, line_repr=True)
        self.r7 = GraphTCN(dim*2, dim*2, B, line_repr=True)
        self.r8 = GraphTCN(dim*2, dim*2, B, line_repr=True)
        self.r9 = GraphTCN(dim*2, dim*4, B, stride=2, line_repr=True)
        self.r10 = GraphTCN(dim*4, dim*4, B, line_repr=True)

        self.fc1 = nn.Linear(dim*4, num_class)
        # fc1 is registered both as `fc1` and `mlp.0`; both keys exist in saved checkpoints.
        self.mlp = nn.Sequential(
            self.fc1
        )
        # Normalises every (person, joint, channel) input series jointly.
        self.data_bn = nn.BatchNorm1d(num_person * dim_in * num_point)
        nn.init.normal_(self.fc1.weight, 0, math.sqrt(2. / num_class))
        bn_init(self.data_bn, 1)

    def forward(self, x):
        """x: (N, C, T, V, M) skeleton batch; returns (N, num_class) logits."""
        N, C, T, V, M = x.size()
        dims = x.size()
        x = x.permute(0, 4, 3, 1, 2).contiguous().view(N, M * V * C, T)
        x = self.data_bn(x)
        # -> (N*M, C, T, V)
        x = x.view(N, M, V, C, T).contiguous().view(N * M, V, C, T).permute(0, 2, 3, 1).contiguous()

        # Joint features -> edge features (sum of each edge's two endpoints).
        self.I_edge = self.I_edge.to(x.device)
        y = torch.einsum("bctv, ve -> bcte", x, self.I_edge)

        # Layer 1: both streams start from the raw input.
        y = self.r1(y, dims)
        x = self.l1(x, dims) + self.convert_line_to_x(y, self.I_edge)
        # Layers 2..10: advance the edge stream, fuse it into the joints, then run
        # the joint layer on the fused features.
        for k in range(2, 11):
            y = getattr(self, f'r{k}')(y, dims)
            x = getattr(self, f'l{k}')(x + self.convert_line_to_x(y, self.I_edge), dims)

        # (N*M, C, T, V) -> (N, M, C, T*V) -> mean over T*V, then over persons.
        _, C, T, V = x.size()
        x = x.view(N, M, C, -1)
        x = x.mean(3).mean(1)
        return self.mlp(x)

    def convert_line_to_x(self, y, I):
        """Map edge features (.., E) back to joints (.., V) via the incidence matrix I (V, E)."""
        x = torch.einsum("bcte, ve -> bctv", y, I)
        return x

In [None]:
def train(epoch, loader):
    """Run one training epoch over ``loader`` with the notebook-global model.

    Uses the module-level ``model``, ``optimizer``, ``lossC`` and ``device``;
    ``epoch`` is not used.

    Returns:
        (mean batch loss, mean batch accuracy in percent), both NaN-ignoring.
    """
    model.train()
    batch_losses, batch_accs = [], []
    for data, label, _ in tqdm(loader, ncols=80):
        inputs = data.float().to(device)
        targets = label.long().to(device)

        logits = model(inputs)
        loss = lossC(logits, target=targets)

        optimizer.zero_grad()
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) is disabled.
        optimizer.step()

        batch_losses.append(loss.data.item())
        _, predicted = torch.max(logits.data, 1)
        batch_accs.append((predicted == targets.data).float().mean().data.item())

    return np.nanmean(batch_losses), np.nanmean(batch_accs) * 100

In [None]:
def evalt(loader, wrong_file = None, result_file = None, model_name = 'Model', save=True, save_pickle=False,
          pickle_file='work_test/ntu/xview_bonevel_best_acc.pkl'):
    """Evaluate the notebook-global ``model`` on ``loader`` and report top-1/top-5 accuracy.

    Uses the module-level ``model``, ``lossC`` and ``device``, and, when ``save``
    is set, reads/updates the global ``best_epoch_acc``.

    Args:
        loader: DataLoader whose dataset provides ``top_k(score, k)`` (and
            ``sample_name`` when ``save_pickle``).
        wrong_file: optional path; each misclassified sample is written as
            ``index,predicted,true``.
        result_file: optional path; every sample is written as ``predicted,true``.
        model_name: shown in the log and used as checkpoint name ``mainruns/{model_name}.pt``.
        save: save ``model.state_dict()`` when top-1 beats ``best_epoch_acc``.
        save_pickle: pickle a ``{sample_name: score}`` dict to ``pickle_file``.
        pickle_file: destination used by ``save_pickle``.

    Returns:
        (top-5 accuracy %, mean loss, top-1 accuracy %).
    """
    from contextlib import ExitStack

    global best_epoch_acc
    model.eval()
    loss_value = []
    score_frag = []
    show_topk = [1, 5]

    # ExitStack closes whichever report files were opened, even on error.
    with ExitStack() as stack:
        f_w = stack.enter_context(open(wrong_file, 'w')) if wrong_file is not None else None
        f_r = stack.enter_context(open(result_file, 'w')) if result_file is not None else None
        for data, label, index in tqdm(loader, ncols=80):
            with torch.no_grad():
                # `.to(device)` also works when device is the CPU, unlike `.cuda(device)`.
                data = data.float().to(device)
                label = label.long().to(device)
                output = model(data)

                loss = lossC(output, label)

                score_frag.append(output.data.cpu().numpy())
                loss_value.append(loss.data.item())

                _, predict_label = torch.max(output.data, 1)

            if f_w is not None or f_r is not None:
                predict = predict_label.cpu().tolist()
                true = label.data.cpu().tolist()
                for i, (pred, target) in enumerate(zip(predict, true)):
                    if f_r is not None:
                        f_r.write(str(pred) + ',' + str(target) + '\n')
                    if pred != target and f_w is not None:
                        # Collated indices arrive as tensors; write the number, not "tensor(...)".
                        sample = index[i]
                        if torch.is_tensor(sample):
                            sample = sample.item()
                        f_w.write(str(sample) + ',' + str(pred) + ',' + str(target) + '\n')

    score = np.concatenate(score_frag)
    loss = np.mean(loss_value)
    topk = {k: loader.dataset.top_k(score, k) for k in show_topk}
    print('Accuracy: ', topk[1], ' model: ', model_name)
    print('\tMean {} loss of {} batches: {}.'.format(
        "test", len(loader), loss))
    for k in show_topk:
        print('\tTop{}: {:.2f}%'.format(
            k, 100 * topk[k]))

    top5_acc = topk[5] * 100
    top1_acc = topk[1] * 100
    if save and top1_acc > best_epoch_acc:
        torch.save(model.state_dict(), f'mainruns/{model_name}.pt')
        best_epoch_acc = top1_acc
    if save_pickle:
        score_dict = dict(zip(loader.dataset.sample_name, score))
        with open(pickle_file, 'wb') as f:
            pickle.dump(score_dict, f)
    return top5_acc, loss, top1_acc

In [None]:
def _seed_dataloader_worker(worker_id):
    """DataLoader ``worker_init_fn``: seed NumPy/``random`` from the worker's torch seed.

    PyTorch already gives every worker its own torch seed (``torch.initial_seed()``),
    derived from the main process's RNG. Reusing it keeps augmentations different
    across workers and epochs while staying reproducible from ``init_seed``.
    """
    worker_seed = torch.initial_seed() % 2 ** 32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def load_data(data, bone, vel, batch_size=32, num_workers=0):
    """Build the NTU60 train/test DataLoaders from ``data/ntu/NTU60_C{data}.npz``.

    Args:
        data: suffix of the data file name (e.g. 'S' in the run script).
        bone, vel: forwarded to ``Feeder`` to select the input modality.
        batch_size: train batch size; the test loader uses twice this.
        num_workers: DataLoader worker processes (0 = load in the main process).

    Returns:
        dict with 'train' (shuffled, drop_last, pinned memory) and 'test' loaders.
    """
    path = f'data/ntu/NTU60_C{data}.npz'
    print(path, end=' ')
    print("bone:", bone, "vel:", vel)
    data_loader = {}
    # train
    data_loader['train'] = torch.utils.data.DataLoader(
        dataset=Feeder(path, split='train', p_interval =[0.5, 1], window_size=64, bone=bone, vel=vel),
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
        pin_memory=True,
        drop_last=True,
        # DataLoader calls worker_init_fn(worker_id); passing init_seed here would
        # reseed every worker to its id each epoch and repeat the same augmentations.
        worker_init_fn=_seed_dataloader_worker)
    # test
    data_loader['test'] = torch.utils.data.DataLoader(
        dataset=Feeder(path, split='test', p_interval =[0.95], window_size=64, bone=bone, vel=vel),
        batch_size=batch_size*2,
        num_workers=num_workers,
        shuffle=False,
        worker_init_fn=_seed_dataloader_worker)
    return data_loader

In [None]:
def init_weights(m):
    """Kaiming-normal (fan_out) weights and zero bias for ``nn.Conv2d``/``nn.Linear``.

    Any other module type is left unchanged; intended for ``model.apply``.
    """
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return
    nn.init.kaiming_normal_(m.weight, mode='fan_out')
    if m.bias is not None:
        nn.init.constant_(m.bias, 0)

def adjust_learning_rate(epoch):
    """Apply the per-epoch learning-rate policy and return the resulting lr.

    Epochs ``1..warm_epochs`` use a linear warm-up towards ``base_lr``. After that,
    the policy named by the global ``sched`` is applied:
      * ``'plat'``: ``scheduler.step`` on the last test loss (global ``tst[1]``);
      * ``'cos'`` / ``'lin'``: ``scheduler.step()``;
      * ``'custom'``: ``base_lr * lr_decay_rate ** (number of milestones in step reached)``.

    Reads the notebook globals ``optimizer``, ``scheduler``, ``base_lr``,
    ``warm_epochs``, ``lr_decay_rate``, ``step``, ``sched`` and ``tst``, and
    updates the global ``lr``.

    Returns:
        The learning rate of the first optimizer param group after the update.
    """
    global lr, warm_epochs, lr_decay_rate, step
    if epoch <= warm_epochs:
        lr = base_lr * epoch / warm_epochs
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    elif sched == 'plat':
        scheduler.step(tst[1])
    elif sched in ('cos', 'lin'):
        scheduler.step()
    elif sched == 'custom':
        lr = base_lr * (lr_decay_rate ** np.sum(epoch >= np.array(step)))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    # The scheduler branches change the optimizer directly; read the lr back so
    # the return value (and global `lr`) is never a stale warm-up value.
    lr = optimizer.param_groups[0]['lr']
    return lr

In [None]:
# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
init_seed(2)

torch.cuda.empty_cache()

phase = 'train'      # 'train' | 'test' | 'flops'
optimz = 'SGD'       # 'SGD' | 'Adam' | 'AdamW' | 'NAdam'
warm_epochs = 5
epochs = 140
batch_size = 64
out_channels = 48
dataset = 'S'        # suffix of data/ntu/NTU60_C{dataset}.npz
bone, vel = False, False

use_wandb = False
gpu_index = 2        # preferred CUDA device; falls back below if it does not exist

# Optimizer-specific hyper-parameters
if optimz in ('Adam', 'AdamW', 'NAdam'):
    base_lr = 0.001
    lr = base_lr
    weight_decay = 0.1
    lr_decay_rate = 0.03
    sched = 'cos'
    lin_fin_epoch = 70
elif optimz == 'SGD':
    base_lr = 0.025
    lr = base_lr
    weight_decay = 0.0006
    lr_decay_rate = 0.1
    step = [110, 120]
    sched = 'custom'

if use_wandb:
    wandb.login()
    wandb.init(
        # set the wandb project where this run will be logged
        project="Sanzhar_URIS",
        # track hyperparameters and run metadata
        config={
            "learning_rate": lr,
            "architecture": "Mamba",
            # Previously a duplicate "dataset" key silently overwrote this value.
            "dataset_name": "NTU RGB+D 60",
            "epochs": epochs,
            "out_channels": out_channels,
            "batch_size": batch_size,
            "weight_decay": weight_decay,
            "lr_decay_rate": lr_decay_rate,
            "optimizer": optimz,
            "scheduler": sched,
            "dataset": dataset,
            "bone": bone,
            "vel": vel,
        },
    )

if torch.cuda.is_available():
    # Use the preferred GPU when it exists, otherwise the default GPU, instead of
    # failing later with an invalid device ordinal.
    device = torch.device(f'cuda:{gpu_index}' if torch.cuda.device_count() > gpu_index else 'cuda')
else:
    device = torch.device('cpu')
model = GraphModel(3, out_channels).to(device)
# model.apply(init_weights)
# Multi-GPU variant (disabled):
# devices = [2, 3]
# device = devices[0]
# model = GraphModel(3, out_channels)
# model = nn.DataParallel(model, device_ids=devices, output_device=device)
# model.to(device)

# Optimizer
if optimz == 'SGD':
    optimizer = optim.SGD(
                model.parameters(),
                momentum=0.9,
                lr=base_lr,
                nesterov = True,
                weight_decay = weight_decay)
elif optimz == 'Adam':
    # NOTE(review): weight_decay is not passed to Adam, unlike NAdam/AdamW — confirm intended.
    optimizer = optim.Adam(model.parameters(), lr=base_lr)
elif optimz == 'NAdam':
    optimizer = optim.NAdam(model.parameters(), lr=base_lr, weight_decay=weight_decay)
elif optimz == 'AdamW':
    optimizer = optim.AdamW(model.parameters(), lr=base_lr, weight_decay=weight_decay)

def count_parameters(model):
    """Number of trainable parameters of ``model``."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Loss
lossC = nn.CrossEntropyLoss().to(device)

# Scheduler
if sched == 'plat':
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=lr_decay_rate, patience=5)
elif sched == 'cos':
    scheduler = CosineAnnealingLR(optimizer, epochs, base_lr*lr_decay_rate)
elif sched == 'lin':
    scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=lr_decay_rate, total_iters=lin_fin_epoch)

# loader
data_loader = load_data(dataset, bone, vel, batch_size)

if phase == 'train':
    print("Parameters:", count_parameters(model))
    print(model)
    best_epoch_acc = 0
    for epoch in range(1, epochs+1):
        adjust_learning_rate(epoch)
        trn = train(epoch, loader=data_loader['train'])
        if use_wandb:
            tst = evalt(loader=data_loader['test'], model_name = f'Model{out_channels}_{wandb.run.name}')
        else:
            tst = evalt(loader=data_loader['test'], save = False)
        if use_wandb:
            # tst = (top-5 %, test loss, top-1 %): "test_acc" logs top-5.
            wandb.log({"train_loss": trn[0], "train_acc": trn[1], "test_acc": tst[0], "test_acc_1st": tst[2], "test_loss": tst[1]})
        current_lr = optimizer.param_groups[0]['lr']
        print(f'Epoch: {epoch:02d}, Loss: {trn[0]:.4f}, Train Acc: {trn[1]:.4f}, '
              f'Accuracy: {tst[2]:.4f}, Test_Loss: {tst[1]:.4f}, Current_Lr: {current_lr:.6f}')
elif phase == 'test':
    print("Parameters:", count_parameters(model))
    # NOTE(review): the checkpoint name suggests out_channels=96; set it to match.
    weights_path = 'mainruns/Model96_light-bee-139.pt'
    model.load_state_dict(torch.load(weights_path, map_location=device, weights_only=True))
    wf = weights_path.replace('.pt', '_wrong.txt')
    rf = weights_path.replace('.pt', '_right.txt')
    evalt(loader=data_loader['test'], wrong_file=wf, result_file=rf, save=False, save_pickle=False)
elif phase == 'flops':
    print("Parameters:", count_parameters(model))
    # NOTE(review): the checkpoint name suggests out_channels=160; set it to match.
    weights_path = 'mainruns/Model160_rich-terrain-231.pt'
    model.load_state_dict(torch.load(weights_path, map_location=device, weights_only=True))
    # Input resolution excludes the batch dim: (C, T, V, M)
    Flops, params = get_model_complexity_info(model,  tuple([3, 64, 25, 2]), as_strings=False, print_per_layer_stat=False, verbose=False)
    print(f"Model stats: GFlops: {Flops*1e-9}, and (M) params: {params*1e-6}")

In [None]:
# Archived, non-executed snippets kept as bare string literals (they have no
# runtime effect). They come from an earlier Mamba2-based GCNMamba.forward and
# reference attributes such as self.self_attn, self.rpe, self.hops,
# self.joint_label, self.pe_proj and self.norm3, which the current
# GCNMamba.__init__ does not create.
'''h2_forward = x.permute(0, 2, 3, 1).contiguous()
        h2_backward = torch.flip(h2_forward, [2]).view(N, M * T * V, C)
        h2 = torch.cat((h2_forward.view(N, M * T * V, C), h2_backward), dim=1)
        with torch.cuda.device(x.device):
            h2 = self.self_attn(h2)
        h2_forward, h2_backward = h2.chunk(2, dim=1)
        h2_backward = torch.flip(h2_backward.reshape(N * M, T, V, C), [2])
        h2 = h2_forward.reshape(N * M, T, V, C) + h2_backward
        h2 = h2.contiguous().permute(0, 3, 1, 2)'''
'''
        # linear transform
        h1 = self.conv(x).permute(0, 2, 3, 1)

        # k-hops
        pos_emb = self.rpe[self.hops].permute(2, 0, 1)
        pos_emb = torch.einsum('nctv, cvw -> nctw', x, pos_emb).permute(0, 2, 3, 1)
    
        # hyperedge   
        label = F.one_hot(torch.tensor(self.joint_label)).float().to(x.device)
        z = x @ (label / label.sum(dim=0, keepdim=True))
        z = self.pe_proj(z).permute(3, 0, 1, 2)
        # n=1, c=2, t=3, v=0
        e = z[self.joint_label].permute(1, 3, 0, 2).contiguous()
        
        h_ov = torch.cat((pos_emb, h1, e), dim=1)
        h_ov = h_ov.reshape(N, -1, C)
        with torch.cuda.device(x.device):
            h_ov = self.self_attn(h_ov)
        h_ov1, h_ov2, h_ov3 = h_ov.reshape(N*M, 3*T, V, C).chunk(3, dim=1)
        h_ov = (h_ov1 + h_ov2 + h_ov3).permute(0, 3, 1, 2)
        out = self.norm3(h_ov + x).contiguous()
        '''