# In [2]:
from torch.utils.data import Dataset
import random
import torch
import torch.nn as nn
from torch import Tensor
import continual as co
import pytorch_lightning as pl
from torch.autograd import Variable
from torch.nn import functional as F
from abc import abstractmethod
from typing_extensions import Self
import math
from typing import  Tuple, Optional, Any, Union
from functools import partial
from continual import TensorPlaceholder
from continual.module import CallMode
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
import torch.optim as optim
from tqdm import tqdm
import time
import os
import torchmetrics

# Cell output (tqdm notebook warning): from .autonotebook import tqdm as notebook_tqdm


# In [3]:


# Seed the CPU and CUDA torch generators with the same value so runs are
# repeatable (the CUDA call is silently ignored when CUDA is unavailable).
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(42)


def edge2mat(link, num_node):
    """Build a dense ``(num_node, num_node)`` 0/1 matrix from an edge list.

    Each pair ``(i, j)`` in *link* sets entry ``[j, i]`` to 1 (note the
    transposed index order).
    """
    adjacency = np.zeros((num_node, num_node))
    for src, dst in link:
        adjacency[dst, src] = 1
    return adjacency


def normalize_digraph(A):
    """Scale every column of *A* by the inverse of its sum.

    Columns that sum to zero are left as zeros. Returns ``A @ D^-1`` where
    ``D`` is the diagonal matrix of column sums.
    """
    col_sums = np.sum(A, 0)
    n_cols = A.shape[1]
    inv_degree = np.zeros((n_cols, n_cols))
    for k, degree in enumerate(col_sums):
        if degree > 0:
            inv_degree[k, k] = degree ** (-1)
    return A @ inv_degree


def get_spatial_graph(num_node, self_link, inward, outward):
    """Stack the self-link, inward and outward adjacency matrices.

    Returns an array of shape ``(3, num_node, num_node)``: the raw self-link
    matrix followed by the column-normalized inward and outward matrices.
    """
    partitions = [
        edge2mat(self_link, num_node),
        normalize_digraph(edge2mat(inward, num_node)),
        normalize_digraph(edge2mat(outward, num_node)),
    ]
    return np.stack(partitions)


class Graph():
    """Hand-skeleton graph with a partitioned, column-normalized adjacency.

    On construction the joint edge list for *layout* is built, pairwise hop
    distances (capped at ``max_hop``) are stored in ``self.hop_dis``, and the
    adjacency partitions selected by *strategy* are stored in ``self.A`` with
    shape ``(num_partitions, num_node, num_node)``.

    Args:
        layout: ``'DHG14/28'`` (22 joints), ``'SHREC21'`` (20 joints) or
            ``'FPHA'`` (21 joints).
        strategy: ``'uniform'``, ``'distance'`` or ``'spatial'``.
        max_hop: largest hop distance treated as connected.
        dilation: step between the hop distances that are kept.

    Raises:
        ValueError: for an unknown layout or strategy.
    """

    def __init__(self,
                 layout='DHG14/28',
                 strategy='uniform',
                 max_hop=2,
                 dilation=1):
        self.max_hop = max_hop
        self.dilation = dilation

        self.get_edge(layout)
        self.hop_dis = self.get_hop_distance(
            self.num_node, self.edge, max_hop=max_hop)
        self.get_adjacency(strategy)

    def __str__(self):
        # __str__ must return a str; returning the ndarray itself made
        # str(graph) / print(graph) raise TypeError.
        return str(self.A)

    def get_edge(self, layout):
        """Set ``num_node``, ``edge`` (self-links + neighbor links) and ``center``."""
        if layout == 'DHG14/28':
            self.num_node = 22
            neighbor_link = [
                (0, 1), (0, 2), (1, 0), (1, 6), (1, 10), (1, 14), (1, 18),
                (2, 0), (2, 3), (3, 2), (3, 4), (4, 3), (4, 5), (5, 4),
                (6, 1), (6, 7), (7, 6), (7, 8), (8, 7), (8, 9), (9, 8),
                (10, 1), (10, 11), (11, 10), (11, 12), (12, 11), (12, 13),
                (13, 12), (14, 1), (14, 15), (15, 14), (15, 16), (16, 15),
                (16, 17), (17, 16), (18, 1), (18, 19), (19, 18), (19, 20),
                (20, 19), (20, 21), (21, 20),
            ]
            self.center = 1
        elif layout == "SHREC21":
            self.num_node = 20
            neighbor_link = [
                (0, 1), (1, 2), (2, 3),
                (0, 4), (4, 5), (5, 6), (6, 7),
                (0, 8), (8, 9), (9, 10), (10, 11),
                (0, 12), (12, 13), (13, 14), (14, 15),
                (0, 16), (16, 17), (17, 18), (18, 19),
            ]
            self.center = 0
        elif layout == "FPHA":
            self.num_node = 21
            neighbor_link = [
                (0, 1), (0, 2), (0, 3), (0, 4), (0, 5),
                (1, 0), (1, 6), (2, 0), (2, 7), (3, 0), (3, 8),
                (4, 0), (4, 9), (5, 0), (5, 10),
                (6, 1), (6, 11), (7, 2), (7, 12), (8, 3), (8, 13),
                (9, 4), (9, 14), (10, 5), (10, 15),
                (11, 6), (11, 16), (12, 7), (12, 17), (13, 8), (13, 18),
                (14, 9), (14, 19), (15, 10), (15, 20),
                (16, 11), (17, 12), (18, 13), (19, 14), (20, 15),
            ]
            self.center = 0
        else:
            raise ValueError("Do Not Exist This Layout.")
        self_link = [(i, i) for i in range(self.num_node)]
        self.edge = self_link + neighbor_link

    def get_adjacency(self, strategy):
        """Build ``self.A`` from ``self.hop_dis`` for the given *strategy*.

        - ``'uniform'``: a single normalized adjacency, shape ``(1, V, V)``.
        - ``'distance'``: one slice per kept hop distance.
        - ``'spatial'``: each hop > 0 is split in two by comparing both
          endpoints' hop distance to ``self.center``; hop 0 gives one slice.
        """
        valid_hop = range(0, self.max_hop + 1, self.dilation)
        adjacency = np.zeros((self.num_node, self.num_node))
        for hop in valid_hop:
            adjacency[self.hop_dis == hop] = 1
        normalize_adjacency = self.normalize_digraph(adjacency)

        if strategy == 'uniform':
            A = np.zeros((1, self.num_node, self.num_node))
            A[0] = normalize_adjacency
            self.A = A
        elif strategy == 'distance':
            A = np.zeros((len(valid_hop), self.num_node, self.num_node))
            for i, hop in enumerate(valid_hop):
                A[i][self.hop_dis == hop] = normalize_adjacency[self.hop_dis == hop]
            self.A = A
        elif strategy == 'spatial':
            A = []
            for hop in valid_hop:
                a_root = np.zeros((self.num_node, self.num_node))
                a_close = np.zeros((self.num_node, self.num_node))
                a_further = np.zeros((self.num_node, self.num_node))
                for i in range(self.num_node):
                    for j in range(self.num_node):
                        if self.hop_dis[j, i] == hop:
                            if self.hop_dis[j, self.center] == self.hop_dis[i, self.center]:
                                a_root[j, i] = normalize_adjacency[j, i]
                            elif self.hop_dis[j, self.center] > self.hop_dis[i, self.center]:
                                a_close[j, i] = normalize_adjacency[j, i]
                            else:
                                a_further[j, i] = normalize_adjacency[j, i]
                if hop == 0:
                    A.append(a_root)
                else:
                    A.append(a_root + a_close)
                    A.append(a_further)
            self.A = np.stack(A)
        else:
            raise ValueError("Do Not Exist This Strategy")

    def get_hop_distance(self, num_node, edge, max_hop=1):
        """Return the ``(num_node, num_node)`` hop-distance matrix.

        Edges are treated as undirected; pairs farther apart than *max_hop*
        are ``inf``.
        """
        A = np.zeros((num_node, num_node))
        for i, j in edge:
            A[j, i] = 1
            A[i, j] = 1

        hop_dis = np.zeros((num_node, num_node)) + np.inf
        transfer_mat = [np.linalg.matrix_power(A, d) for d in range(max_hop + 1)]
        arrive_mat = (np.stack(transfer_mat) > 0)
        # Iterate from the largest hop down so the smallest reaching distance
        # is the one that remains.
        for d in range(max_hop, -1, -1):
            hop_dis[arrive_mat[d]] = d
        return hop_dis

    def normalize_digraph(self, A):
        """Return ``A @ D^-1`` with ``D`` the column sums (zero columns stay zero)."""
        Dl = np.sum(A, 0)
        num_node = A.shape[0]
        Dn = np.zeros((num_node, num_node))
        for i in range(num_node):
            if Dl[i] > 0:
                Dn[i, i] = Dl[i]**(-1)
        AD = np.dot(A, Dn)
        return AD

    def normalize_undigraph(self, A):
        """Return ``D^-1/2 @ A @ D^-1/2`` with ``D`` the column sums."""
        Dl = np.sum(A, 0)
        num_node = A.shape[0]
        Dn = np.zeros((num_node, num_node))
        for i in range(num_node):
            if Dl[i] > 0:
                Dn[i, i] = Dl[i]**(-0.5)
        DAD = np.dot(np.dot(Dn, A), Dn)
        return DAD


# Joints read per skeleton frame by the sequence parser, and the frame
# budget used when padding or sampling full (unsegmented) sequences.
num_joint, max_frame = 20, 2500


class Feeder_SHREC21(Dataset):
    """
    Feeder for skeleton-based gesture recognition on the SHREC'21 skeleton dataset.

    Each item corresponds to one annotated sequence file. With
    ``is_segmented=False`` an item is ``(labeled_sequence, frames, labels_per_frame)``;
    with ``is_segmented=True`` it is ``(gestures, ng_sequences, windows_per_label)``
    (see ``_get_segmented``).

    Arguments:
        data_path: root directory containing the ``<set_name>_set/`` folders
            with annotation text files and ``sequences/<seq_idx>.txt`` files.
        set_name: split name, e.g. ``"training"`` or ``"test"``.
        window_size: number of frames per sliding-window sample.
        aug_by_sw: if True, also collect strided sub-windows of each gesture.
        is_segmented: return per-gesture segments instead of full sequences.
    """

    def __init__(
            self,
            data_path="SHREC21",
            set_name="training",
            window_size=10,
            aug_by_sw=False,
            is_segmented=False
    ):
        self.data_path = data_path
        self.set_name = set_name
        # Index 0 ("") is the "no gesture" label for unannotated frames.
        self.classes = ["",
                        "RIGHT",
                        "KNOB",
                        "CROSS",
                        "THREE",
                        "V",
                        "ONE",
                        "FOUR",
                        "GRAB",
                        "DENY",
                        "MENU",
                        "CIRCLE",
                        "TAP",
                        "PINCH",
                        "LEFT",
                        "TWO",
                        "OK",
                        "EXPAND",
                        ]
        self.class_to_idx = {class_l: idx for idx,
                             class_l in enumerate(self.classes)}
        self.window_size = window_size
        self.aug_by_sw = aug_by_sw
        self.is_segmented = is_segmented
        self.load_data()

    def load_data(self):
        """Read the split's annotation file into ``self.dataset``.

        Each line becomes ``(seq_idx, [(start, end, label), ...])``: the first
        ``;``-separated field is the sequence id and the following fields are
        read as ``(label, start, end)`` triples, kept as strings.
        """
        if self.set_name == "test":
            annotations = f'{self.data_path}/{self.set_name}_set/annotations_revised.txt'
        else:
            annotations = f'{self.data_path}/{self.set_name}_set/annotations_revised_{self.set_name}.txt'
        self.dataset = []
        with open(annotations, mode="r") as f:
            for line in f:
                fields = line.split(';')
                seq_idx = fields[0]
                # The last field is whatever follows the final ';' and is ignored.
                gestures = fields[1:-1]
                gesture_infos = [
                    (gestures[i + 1], gestures[i + 2], gestures[i])
                    for i in range(0, len(gestures) // 3 * 3, 3)
                ]
                self.dataset.append((seq_idx, gesture_infos))

    def __len__(self):
        return len(self.dataset)

    def __iter__(self):
        """Yield every item in index order.

        The previous ``return self`` broke iteration because the class has no
        ``__next__``.
        """
        for index in range(len(self)):
            yield self[index]

    @staticmethod
    def _parse_seq_data(src_file):
        """Parse a sequence file into an array of joint positions.

        Values on each line are ``;``-separated and consumed alternately as
        3-value positions and 4-value quaternions. Only the ``num_joint``
        positions of each frame are kept, giving (for well-formed lines) an
        array of shape ``(num_frames, num_joint, 3)``.
        """
        video = []
        mode = "pos"  # carried across lines, exactly as in the file layout
        for line in src_file:
            data = line.split("\n")[0].split(";")[:-1]
            frame = []
            point = []
            for data_ele in data:
                if len(data_ele) == 0:
                    continue
                point.append(float(data_ele))
                if len(point) == 3 and mode == "pos":
                    frame.append(point)
                    point = []
                    mode = "quat"
                elif len(point) == 4 and mode == "quat":
                    frame.append(point)
                    point = []
                    mode = "pos"
            if len(frame) > 0:
                # Even entries are positions, odd entries quaternions.
                video.append([frame[i * 2] for i in range(num_joint)])
        return np.array(video)

    def _sample_window(self, data_num, stride):
        """Return ``window_size`` sorted frame indices taken every *stride* frames.

        Indices start at frame 0; any that fall at or beyond *data_num* are
        dropped and the list is topped up with distinct random indices.
        Callers must ensure ``data_num >= window_size``, otherwise the top-up
        loop cannot finish.
        """
        idx_list = sorted({i * stride for i in range(self.window_size)
                           if i * stride < data_num})
        while len(idx_list) < self.window_size:
            idx = random.randint(0, data_num - 1)
            if idx not in idx_list:
                idx_list.append(idx)
        idx_list.sort()
        return idx_list

    @staticmethod
    def _label_frames(sequence, gesture_infos):
        """Pair each frame with its gesture label (``""`` when unannotated).

        Gesture bounds are inclusive; when gestures overlap, the one listed
        first keeps the frame.
        """
        labels = [""] * len(sequence)
        for gesture_start, gesture_end, gesture_label in gesture_infos:
            first = max(int(gesture_start), 0)
            last = min(int(gesture_end), len(sequence) - 1)
            for idx in range(first, last + 1):
                if labels[idx] == "":
                    labels[idx] = gesture_label
        return [(np.array(f), label) for f, label in zip(sequence, labels)]

    def _read_labeled_sequence(self, seq_idx, gesture_infos):
        """Load ``sequences/<seq_idx>.txt`` and label every frame."""
        path = f'{self.data_path}/{self.set_name}_set/sequences/{seq_idx}.txt'
        with open(path, mode="r") as seq_f:
            sequence = self._parse_seq_data(seq_f)
        return self._label_frames(sequence, gesture_infos)

    def _get_segmented(self, seq_idx, gesture_infos):
        """Split a sequence into gesture segments and no-gesture runs.

        Returns:
            gestures: ``(frames, per_frame_labels)`` for each annotated gesture,
                sliced as ``[start:end]`` (end exclusive).
            ng_sequences: ``(frames, 0)`` for each maximal run of label-0 frames.
            windows_per_label: dict from class index to a list of
                ``(window_frames, label)`` samples (filled only when
                ``aug_by_sw`` is set).
        """
        labeled_sequence = self._read_labeled_sequence(seq_idx, gesture_infos)
        frames = [f for f, _ in labeled_sequence]
        labels_per_frame = [self.class_to_idx[lab] for _, lab in labeled_sequence]

        gestures = []
        windows_per_label = {i: [] for i in range(len(self.classes))}
        for gesture_start, gesture_end, gesture_label in gesture_infos:
            gesture_start = int(gesture_start)
            gesture_end = int(gesture_end)
            g_frames = frames[gesture_start:gesture_end]
            gestures.append((g_frames, labels_per_frame[gesture_start:gesture_end]))
            label = self.class_to_idx[gesture_label]
            if self.aug_by_sw:
                n_frames = len(g_frames)
                for stride in range(1, self.window_size):
                    # Only strides that fit window_size frames in the gesture.
                    if n_frames // stride >= self.window_size:
                        window = [g_frames[idx]
                                  for idx in self._sample_window(n_frames, stride)]
                        windows_per_label[label].append((window, label))

        ng_sequences = []
        ng_seq = []
        n_frames = len(frames)
        for i in range(n_frames - 1):
            if labels_per_frame[i] != 0:
                continue
            ng_seq.append(frames[i])
            if labels_per_frame[i + 1] != 0:
                # The no-gesture run ends right before a gesture starts.
                ng_sequences.append((ng_seq, 0))
                ng_seq = []
            elif i == n_frames - 2:
                # The run reaches the end of the sequence: keep the last frame.
                ng_seq.append(frames[i + 1])
                ng_sequences.append((ng_seq, 0))
                ng_seq = []

        return gestures, ng_sequences, windows_per_label

    def _get_full_sequences(self, seq_idx, gesture_infos):
        """Return ``(labeled_sequence, frames_array, labels_per_frame)``."""
        labeled_sequence = self._read_labeled_sequence(seq_idx, gesture_infos)
        frames = np.array([f for f, _ in labeled_sequence])
        labels_per_frame = [self.class_to_idx[lab] for _, lab in labeled_sequence]
        return labeled_sequence, frames, labels_per_frame

    def __getitem__(self, index):
        seq_idx, gesture_infos = self.dataset[index]
        if self.is_segmented:
            return self._get_segmented(seq_idx, gesture_infos)
        return self._get_full_sequences(seq_idx, gesture_infos)


def get_window_label(label, num_classes=18):
    """Return the most frequent class index in *label* (majority vote).

    Ties go to the smallest class index because ``argmax`` returns the first
    maximum; an empty *label* yields 0.
    """
    votes = torch.zeros((num_classes))
    for frame_label in label:
        votes[frame_label] += 1
    return votes.argmax(dim=-1).item()


def gendata(
        data_path,
        set_name,
        max_frame,
        window_size=20,
        aug_by_sw=False,
        is_segmented=False
):
    """Load a SHREC'21 split through ``Feeder_SHREC21`` and flatten it.

    Args:
        data_path, set_name, window_size, aug_by_sw, is_segmented: forwarded
            to ``Feeder_SHREC21``.
        max_frame: not used here; kept so existing callers still work.

    Returns:
        If *is_segmented*: ``(data, ng_sequences_data, windows_sub_sequences_data)``.
        ``data`` is a list of ``(frames_array, label)`` per gesture, where
        ``label`` is the majority per-frame class. ``ng_sequences_data`` lists
        all no-gesture runs. ``windows_sub_sequences_data`` maps each class
        index to its sliding-window samples.
        Otherwise: a list of ``(frames, per_frame_labels)`` per sequence.
    """
    feeder = Feeder_SHREC21(
        data_path=data_path,
        set_name=set_name,
        window_size=window_size,
        aug_by_sw=aug_by_sw,
        is_segmented=is_segmented
    )

    if not is_segmented:
        data = []
        for i in tqdm(range(len(feeder))):
            _, frames, labels = feeder[i]
            data.append((frames, labels))
        return data

    data = []
    ng_sequences_data = []
    windows_sub_sequences_data = {i: [] for i in range(len(feeder.classes))}
    for i in tqdm(range(len(feeder))):
        gestures, ng_sequences, windows_per_label = feeder[i]
        ng_sequences_data.extend(ng_sequences)
        added_labels = set()
        for g_frames, g_labels in gestures:
            label = get_window_label(g_labels)
            # windows_per_label[label] already covers every gesture of this
            # class in the sequence; adding it once per gesture duplicated it.
            if label not in added_labels:
                windows_sub_sequences_data[label].extend(windows_per_label[label])
                added_labels.add(label)
            data.append((np.array(g_frames), label))

    return data, ng_sequences_data, windows_sub_sequences_data


class GraphDataset(Dataset):
    """SHREC'21 skeleton dataset producing ``(skeleton, label, index)`` items.

    Data is loaded eagerly in ``__init__`` through ``gendata``. With
    ``is_segmented=True`` each item is one gesture or no-gesture run, sampled
    or upsampled to ``window_size`` frames, with a single class index. With
    ``is_segmented=False`` each item is a whole sequence with per-frame labels.
    """

    def __init__(
        self,
        data_path,
        set_name,
        window_size,
        use_data_aug=False,
        normalize=True,
        scaleInvariance=False,
        translationInvariance=False,
        isPadding=False,
        useSequenceFragments=False,
        useRandomMoving=False,
        useMirroring=False,
        useTimeInterpolation=False,
        useNoise=False,
        useScaleAug=False,
        useTranslationAug=False,
        use_aug_by_sw=False,
        nb_sub_sequences=10,
        sample_classes=False,
        is_segmented=False,
        number_of_samples_per_class=0
    ):
        """Initialise a Graph dataset and load its data.

        Args:
            data_path: dataset root passed to ``gendata``.
            set_name: split name, e.g. ``"training"`` or ``"test"``.
            window_size: frames per segmented sample.
            use_data_aug: run ``preprocessSkeleton`` + ``data_aug`` on every
                segmented sample; the ``use*`` flags select augmentations.
            normalize, scaleInvariance, translationInvariance: options of
                ``preprocessSkeleton``.
            isPadding: zero-pad short unsegmented sequences to ``max_frame``.
            use_aug_by_sw: request sliding-window sub-sequences from ``gendata``.
            nb_sub_sequences: sliding-window samples kept per class.
            sample_classes: cap the number of samples per class.
            is_segmented: per-gesture samples instead of full sequences.
            number_of_samples_per_class: per-class cap used by the sampling
                helpers.
        """
        self.data_path = data_path
        self.set_name = set_name
        self.use_data_aug = use_data_aug
        self.window_size = window_size
        self.compoent_num = 20  # joints touched by the scale/shift/noise augmentations
        self.normalize = normalize
        self.scaleInvariance = scaleInvariance
        self.translationInvariance = translationInvariance
        self.isPadding = isPadding
        self.useSequenceFragments = useSequenceFragments
        self.useRandomMoving = useRandomMoving
        self.useMirroring = useMirroring
        self.useTimeInterpolation = useTimeInterpolation
        self.useNoise = useNoise
        self.useScaleAug = useScaleAug
        self.useTranslationAug = useTranslationAug
        self.use_aug_by_sw = use_aug_by_sw
        self.number_of_samples_per_class = number_of_samples_per_class
        self.is_segmented = is_segmented
        self.nb_sub_sequences = nb_sub_sequences
        self.sample_classes_ = sample_classes
        self.classes = ["No gesture",
                        "RIGHT",
                        "KNOB",
                        "CROSS",
                        "THREE",
                        "V",
                        "ONE",
                        "FOUR",
                        "GRAB",
                        "DENY",
                        "MENU",
                        "CIRCLE",
                        "TAP",
                        "PINCH",
                        "LEFT",
                        "TWO",
                        "OK",
                        "EXPAND",
                        ]
        self.load_data()

    def load_data(self):
        """Populate ``self.data`` (plus sampling/augmentation when segmented)."""
        if not self.is_segmented:
            self.data = gendata(
                self.data_path,
                self.set_name,
                max_frame,
                self.window_size,
                self.use_aug_by_sw,
                self.is_segmented
            )
            return

        self.data, self.ng_sequences_data, self.gesture_sub_sequences_data = gendata(
            self.data_path,
            self.set_name,
            max_frame,
            self.window_size,
            self.use_aug_by_sw,
            self.is_segmented
        )
        self.sample_no_gesture_class()
        print("Number of gestures per class in the original "+self.set_name+" set :")
        self.print_classes_information()
        print(self.set_name)
        # Drop segments that contain no frames.
        self.data = [el for el in self.data if np.array(el[0]).shape[0] > 0]
        if self.sample_classes_:
            self.sample_classes(self.nb_sub_sequences)
        if self.use_data_aug:
            print("Augmenting data ....")
            augmented_data = []
            for seq, label in self.data:
                skeleton = self.preprocessSkeleton(
                    torch.from_numpy(np.array(seq)).float())
                augmented_data.extend((s, label) for s in self.data_aug(skeleton))
            self.data = augmented_data
        if self.use_aug_by_sw or self.use_data_aug:
            print("Number of gestures per class in the " +
                  self.set_name+" set after augmentation:")
            self.print_classes_information()

    def print_classes_information(self):
        """Print how many samples of each class are in ``self.data``."""
        data_dict = {i: 0 for i in range(len(self.classes))}
        for _, label in self.data:
            data_dict[label] += 1
        for class_label, count in data_dict.items():
            print("Class", self.classes[class_label],
                  "has", count, "samples")

    def sample_no_gesture_class(self):
        """Append a deterministic (seed 4) shuffle-and-take of no-gesture runs to ``self.data``."""
        random.Random(4).shuffle(self.ng_sequences_data)
        print(len(self.ng_sequences_data))
        n_samples = self.number_of_samples_per_class * 2 + \
            (self.nb_sub_sequences if self.use_aug_by_sw else 0)
        self.data = [*self.data, *self.ng_sequences_data[:n_samples]]

    def sample_classes(self, nb_sub_sequences):
        """Keep the first N samples per class (3N for class 0).

        When ``use_aug_by_sw`` is set, up to *nb_sub_sequences* sliding-window
        samples per class are added as well.
        """
        data_dict = {i: [] for i in range(len(self.classes))}
        for seq, label in self.data:
            data_dict[label].append((seq, label))

        data = []
        for k, class_samples in data_dict.items():
            limit = (self.number_of_samples_per_class * 3 if k == 0
                     else self.number_of_samples_per_class)
            data.extend(class_samples[:limit])
            if self.use_aug_by_sw:
                data.extend(self.gesture_sub_sequences_data[k][:nb_sub_sequences])
        self.data = data

    def __len__(self):
        return len(self.data)

    def __iter__(self):
        """Yield every item in index order.

        The previous ``return self`` broke iteration because the class has no
        ``__next__``.
        """
        for index in range(len(self)):
            yield self[index]

    def preprocessSkeleton(self, skeleton):
        """Apply the enabled normalisations to a skeleton tensor, in order.

        The steps are ``F.normalize`` (default dim 1), scale invariance, then
        translation invariance. The input tensor is never modified in place.
        """
        if self.normalize:
            skeleton = F.normalize(skeleton)
        if self.scaleInvariance:
            # Divide by the distance between joints 0 and 1 of the first frame.
            distance = torch.sqrt(torch.sum((skeleton[0, 1] - skeleton[0, 0]) ** 2, dim=-1))
            skeleton = skeleton * (1 / distance)
        if self.translationInvariance:
            # Subtract joint 1 of the first frame from every joint of every frame.
            skeleton = (skeleton - skeleton[0][1]).float()
        return skeleton

    def __getitem__(self, index):
        """Return ``(skeleton, label, index)`` for item *index*.

        Unsegmented items shorter than ``max_frame`` are zero-padded to
        ``max_frame`` (labels padded with 0) when ``isPadding``; otherwise they
        are interpolated to ``window_size`` frames. Longer ones are sampled down
        to ``max_frame`` frames together with their per-frame labels. Segmented
        items are sampled or interpolated to ``window_size`` frames.
        """
        data_numpy, label = self.data[index]
        skeleton = np.array(data_numpy)
        data_num = skeleton.shape[0]

        if not self.is_segmented:
            if data_num < max_frame:
                if self.isPadding:
                    skeleton = self.auto_padding(skeleton, max_frame)
                    label = [*label, *[0] * (max_frame - len(label))]
                else:
                    skeleton = self.upsample(skeleton, self.window_size)
            else:
                idx_list = self.sample_frames(data_num, max_frame)
                skeleton = torch.from_numpy(np.array([skeleton[idx] for idx in idx_list]))
                # Keep per-frame labels aligned with the sampled frames.
                label = [label[idx] for idx in idx_list]
            return skeleton, label, index

        if data_num >= self.window_size:
            idx_list = self.sample_frames(data_num, self.window_size)
            skeleton = torch.from_numpy(np.array([skeleton[idx] for idx in idx_list]))
        else:
            skeleton = self.upsample(skeleton, self.window_size)
        return skeleton, label, index

    def data_aug(self, skeleton):
        """Return ``[skeleton, *augmented_variants]`` as numpy arrays.

        Every enabled augmentation works on its own copy of the input, so the
        first element is the unmodified skeleton and variants do not compound.
        """

        def scale(skeleton):
            # One random factor for all frames and the first compoent_num joints.
            skeleton = np.array(skeleton)
            ratio = 0.2
            factor = np.random.uniform(1 - ratio, 1 + ratio)
            skeleton[:, :self.compoent_num] *= factor
            return skeleton

        def shift(skeleton):
            # One random 3-value offset added to every frame and joint.
            skeleton = np.array(skeleton)
            low = -0.1
            offset = np.random.uniform(low, -low, 3)
            skeleton[:, :self.compoent_num] += offset
            return skeleton

        def noise(skeleton):
            # Add a constant random offset to 4 joints across all frames.
            skeleton = np.array(skeleton)
            low = -0.1
            all_joint = list(range(self.compoent_num))
            # NOTE(review): the fixed seed picks the same 4 joints on every
            # call; only the offsets vary — confirm this is intended.
            random.Random(4).shuffle(all_joint)
            for j_id in all_joint[0:4]:
                skeleton[:, j_id] += np.random.uniform(low, -low, 3)
            return skeleton

        def time_interpolate(skeleton):
            # Move every frame a random fraction r towards the next frame,
            # then repeat the last frame up to window_size.
            skeleton = np.array(skeleton)
            r = np.random.uniform(0, 1)
            result = [skeleton[i - 1] + (skeleton[i] - skeleton[i - 1]) * r
                      for i in range(1, skeleton.shape[0])]
            if not result:
                # A single frame has no successor to interpolate towards.
                result.append(skeleton[0].copy())
            while len(result) < self.window_size:
                result.append(result[-1])
            return np.array(result)

        def random_sequence_fragments(sample):
            # Up to 5 random crops of window_size consecutive frames; none when
            # the sequence is not longer than window_size.
            n_frames = sample.shape[0]
            if n_frames <= self.window_size:
                return []
            fragments = []
            for _ in range(5):
                start = random.randint(0, n_frames - self.window_size)
                fragments.append(np.array(sample[start:start + self.window_size]))
            return fragments

        def mirroring(data_numpy):
            # Reflect channel 0 about the midpoint of its min and max.
            data_numpy = np.array(data_numpy)
            x = data_numpy[:, :, 0]
            data_numpy[:, :, 0] = np.max(x) + np.min(x) - x
            return data_numpy

        def random_moving(data_numpy,
                          angle_candidate=(-10., -5., 0., 5., 10.),
                          scale_candidate=(0.9, 1.0, 1.1),
                          transform_candidate=(-0.2, -0.1, 0.0, 0.1, 0.2),
                          move_time_candidate=(1,)):
            # Random rotation/scale/translation of channels 0-1, interpolated
            # over time; channel 2 is copied unchanged. Input and output: T, V, C.
            data_numpy = np.transpose(data_numpy, (2, 0, 1))
            new_data_numpy = np.zeros(data_numpy.shape)
            C, T, V = data_numpy.shape
            move_time = random.choice(move_time_candidate)

            node = np.arange(0, T, T * 1.0 / move_time).round().astype(int)
            node = np.append(node, T)
            num_node = len(node)

            A = np.random.choice(angle_candidate, num_node)
            S = np.random.choice(scale_candidate, num_node)
            T_x = np.random.choice(transform_candidate, num_node)
            T_y = np.random.choice(transform_candidate, num_node)

            a = np.zeros(T)
            s = np.zeros(T)
            t_x = np.zeros(T)
            t_y = np.zeros(T)

            for i in range(num_node - 1):
                a[node[i]:node[i + 1]] = np.linspace(
                    A[i], A[i + 1], node[i + 1] - node[i]) * np.pi / 180
                s[node[i]:node[i + 1]] = np.linspace(S[i], S[i + 1],
                                                     node[i + 1] - node[i])
                t_x[node[i]:node[i + 1]] = np.linspace(T_x[i], T_x[i + 1],
                                                       node[i + 1] - node[i])
                t_y[node[i]:node[i + 1]] = np.linspace(T_y[i], T_y[i + 1],
                                                       node[i + 1] - node[i])

            theta = np.array([[np.cos(a) * s, -np.sin(a) * s],
                              [np.sin(a) * s, np.cos(a) * s]])

            for i_frame in range(T):
                xy = data_numpy[0:2, i_frame, :]
                new_xy = np.dot(theta[:, :, i_frame], xy.reshape(2, -1))

                new_xy[0] += t_x[i_frame]
                new_xy[1] += t_y[i_frame]

                new_data_numpy[0:2, i_frame, :] = new_xy.reshape(2, V)

            new_data_numpy[2, :, :] = data_numpy[2, :, :]

            return np.transpose(new_data_numpy, (1, 2, 0))

        skeleton = np.array(skeleton)
        skeletons = [skeleton]
        if self.useTimeInterpolation:
            skeletons.append(time_interpolate(skeleton))
        if self.useNoise:
            skeletons.append(noise(skeleton))
        if self.useScaleAug:
            skeletons.append(scale(skeleton))
        if self.useTranslationAug:
            skeletons.append(shift(skeleton))
        if self.useSequenceFragments:
            skeletons.extend(random_sequence_fragments(skeleton))
        if self.useRandomMoving:
            skeletons.append(random_moving(skeleton))
        if self.useMirroring:
            skeletons.append(mirroring(skeleton))
        return skeletons

    def auto_padding(self, data_numpy, size, random_pad=False):
        """Zero-pad a ``(T, V, C)`` array along time to *size* frames.

        The data starts at frame 0, or at a random offset when *random_pad*.
        Arrays already at least *size* long are returned unchanged.
        """
        T, V, C = data_numpy.shape
        if T >= size:
            return data_numpy
        begin = random.randint(0, size - T) if random_pad else 0
        data_numpy_paded = np.zeros((size, V, C))
        data_numpy_paded[begin:begin + T, :, :] = data_numpy
        return data_numpy_paded

    def upsample(self, skeleton, max_frames):
        """Resize the first axis of a numpy skeleton to *max_frames* frames.

        Uses trilinear interpolation over the last three axes, keeping the
        last two sizes; assumes *skeleton* is ``(T, V, C)``. Returns a tensor.
        """
        tensor = torch.unsqueeze(torch.unsqueeze(
            torch.from_numpy(skeleton), dim=0), dim=0)
        out = F.interpolate(
            tensor, size=[max_frames, tensor.shape[-2], tensor.shape[-1]], mode='trilinear')
        return torch.squeeze(torch.squeeze(out, dim=0), dim=0)

    def sample_frames(self, data_num, sample_size):
        """Return *sample_size* sorted, roughly evenly spaced frame indices.

        The first and last frames are always included; collisions from
        rounding are filled with distinct random indices. Requires
        ``data_num >= sample_size``.
        """
        each_num = (data_num - 1) / (sample_size - 1)
        idx_list = [0, data_num - 1]
        for i in range(sample_size):
            index = round(each_num * i)
            if index not in idx_list and index < data_num:
                idx_list.append(index)
        idx_list.sort()

        while len(idx_list) < sample_size:
            idx = random.randint(0, data_num - 1)
            if idx not in idx_list:
                idx_list.append(idx)
        idx_list.sort()
        return idx_list


def load_data_sets(window_size=10, batch_size=32, workers=4, is_segmented=False,
                   data_path="./data/SHREC21"):
    """Build the SHREC'21 data loaders and the skeleton graph adjacency.

    Args:
        window_size: forwarded to both the training and test GraphDataset.
        batch_size: batch size of the training and validation loaders
            (the test loader always uses a batch size of 1).
        workers: number of DataLoader worker processes.
        is_segmented: forwarded to both GraphDataset instances.
        data_path: dataset root directory passed to GraphDataset.

    Returns:
        ``(train_loader, val_loader, test_loader, adjacency)`` where
        ``adjacency`` is ``torch.from_numpy(Graph(...).A)``. The validation and
        test loaders iterate the same "test" split (shuffle off).
    """
    # Training set: only random moving and noise augmentations are enabled.
    train_ds = GraphDataset(data_path, "training", window_size=window_size,
                            use_data_aug=False,
                            normalize=False,
                            scaleInvariance=False,
                            translationInvariance=False,
                            useRandomMoving=True,
                            isPadding=False,
                            useSequenceFragments=False,
                            useMirroring=False,
                            useTimeInterpolation=False,
                            useNoise=True,
                            useScaleAug=False,
                            useTranslationAug=False,
                            use_aug_by_sw=False,
                            sample_classes=False,
                            number_of_samples_per_class=23,
                            is_segmented=is_segmented
                            )
    test_ds = GraphDataset(data_path, "test",
                           window_size=window_size,
                           use_data_aug=False,
                           normalize=False,
                           scaleInvariance=False,
                           translationInvariance=False,
                           isPadding=True,
                           number_of_samples_per_class=14,
                           use_aug_by_sw=False,
                           sample_classes=False,
                           is_segmented=is_segmented)
    graph = Graph(layout="SHREC21", strategy="distance")
    print("train data num: ", len(train_ds))
    print("test data num: ", len(test_ds))
    train_loader = torch.utils.data.DataLoader(
        train_ds,
        batch_size=batch_size, shuffle=True,
        num_workers=workers, pin_memory=False)

    # Validation and test share the test split; only the batch size differs.
    val_loader = torch.utils.data.DataLoader(
        test_ds,
        batch_size=batch_size, shuffle=False,
        num_workers=workers, pin_memory=False)
    test_loader = torch.utils.data.DataLoader(
        test_ds,
        batch_size=1, shuffle=False,
        num_workers=workers, pin_memory=False)

    return train_loader, val_loader, test_loader, torch.from_numpy(graph.A)


In [4]:

# Result of a continual step: either a real tensor or a continual TensorPlaceholder.
MaybeTensor = Union[Tensor, TensorPlaceholder]

# Attention memory carried between continual steps: (Q_mem, K_T_mem, V_mem).
State = Tuple[Tensor, Tensor, Tensor]



class FeedForward(nn.Module, co.CoModule):
    """Position-wise MLP: Linear(dim_input -> dim_feedforward) -> Mish -> Linear(back to dim_input).

    Both linear layers are created in float32 and moved to CUDA at construction.
    """

    def __init__(self, dim_input: int = 128, dim_feedforward: int = 512):
        super().__init__()
        self.call_mode = CallMode.FORWARD_STEPS
        expand = nn.Linear(dim_input, dim_feedforward, dtype=torch.float).cuda()
        activation = nn.Mish()
        project = nn.Linear(dim_feedforward, dim_input, dtype=torch.float).cuda()
        self.out = nn.Sequential(expand, activation, project)

    def clean_state(self):
        """Stateless module: there is nothing to clear."""

    def forward(self, x: Tensor) -> Tensor:
        return self.out(x)
class Residual(nn.Module, co.CoModule):
    """Post-norm residual wrapper: ``LayerNorm(inputs[0] + Dropout(sublayer(*inputs)))``."""

    def __init__(self, sublayer: nn.Module, dimension: int, dropout: float = 0.1):
        super().__init__()
        self.call_mode = CallMode.FORWARD_STEPS
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(dimension, dtype=torch.float).cuda()
        self.dropout = nn.Dropout(dropout)

    def clean_state(self):
        """Delegate state clearing to the wrapped sublayer."""
        self.sublayer.clean_state()

    def forward_steps(self, x: Tensor) -> Tensor:
        """Step-wise execution is the same computation as ``forward``."""
        return self.forward(x)

    def forward(self, *tensors: Tensor) -> Tensor:
        # The first tensor is treated as the "query" and forms the skip branch,
        # matching the calling convention of the attention sublayer.
        skip = tensors[0]
        branch = self.dropout(self.sublayer(*tensors))
        return self.norm(skip + branch)

def _scaled_dot_product_attention_default_state(
    batch_size: int,
    sequence_len: int,
    num_nodes : int,
    embed_dim_k: int,
    embed_dim_v: int,
    query_index=-1,
    init_fn=torch.zeros,
    dtype=None,
    device=None,
):
    init_fn = partial(init_fn, dtype=dtype, device=device)
    B = batch_size
    V=num_nodes
    N = sequence_len
    Nq = sequence_len
    Q_mem = init_fn((B, V, Nq, embed_dim_k)).float()
    K_T_mem = init_fn((B, V, embed_dim_k, N)).float()
    V_mem = init_fn((B, V, N, embed_dim_v)).float()
    return (Q_mem, K_T_mem, V_mem)

def _clone_state(state):
    return [s.clone() for s in state]

def _scaled_dot_product_attention_step(
    prev_state: State,
    q_step: Tensor,  # step input (B, V, E)
    k_step: Tensor,  # step input (B, V, E)
    v_step: Tensor,  # step input (B, V, E)
    T: int,
    dropout_p: float = 0.0,
) -> Tuple[Tensor, State]:
    """
    Compute one step of Continual Single-output Scaled Dot-Product Attention.

    The key and value memories are rolled one slot to the left and the new
    key/value step is written into the last slot. The query used to produce the
    output is the first (oldest) slot of the query memory, or the current scaled
    query if the query memory has no time slots. The scaled current query is then
    pushed into the query memory.

    Args:
        prev_state: ``(Q_mem, K_T_mem, V_mem)`` carried over from the previous step.
        q_step, k_step, v_step: query, key and value tensors for a single step.
            See the Shape section for details.
        T: number of time slots in the key/value memories. Used to reshape them
            for ``torch.bmm``.
        dropout_p: dropout probability. If greater than 0.0, dropout is applied
            to the attention weights.

    Shape:
        - q_step: :math:`(B, V, E)` where B is batch size, V is the number of vertices and E is embedding dimension.
        - k_step: :math:`(B, V, E)` where B is batch size, V is the number of vertices and E is embedding dimension.
        - v_step: :math:`(B, V, E)` where B is batch size, V is the number of vertices and E is embedding dimension.

        - Output: attention values of shape :math:`(B, V, 1, E)`. The new state is
          ``(Q_new, K_T_new, V_new)``. ``K_T_new`` and ``V_new`` are detached and
          moved to CPU.

    Notes:
        - The value memory is reshaped using the *query* embedding size E, so this
          assumes the key and value embedding sizes are equal.
        - ``q_sel`` and the key/value memories are moved with ``.cuda()``, so a
          CUDA device is required.
    """
    # if attn_mask is not None:
    #     logger.warning("attn_mask is not supported yet and will be skipped")
    # if dropout_p != 0.0:
    #     logger.warning("dropout_p is not supported yet and will be skipped")
    
    (
        Q_mem,  # (B, V, Nq, E)
        K_T_mem,  # (B, V, E, Ns)
        V_mem,  # (B, V, Ns, E)
    ) = prev_state
    # print(Q_mem)
    B, V, E = q_step.shape
    # Scale the query once here; the scaled query is also what gets stored in Q_mem.
    q_step = q_step / math.sqrt(E)
    # Output query: first (oldest) slot of the query memory, shape (B, V, 1, E).
    q_sel = (Q_mem[:B,:, 0] if Q_mem.shape[2] > 0 else q_step).unsqueeze(2).cuda()
    # Update states: shift the key/value memories one slot to the left along time
    # and write the new key/value step into the last slot.
    K_T_new = torch.roll(K_T_mem, shifts=-1, dims=(3,))
    K_T_new[:B, :, :, -1] = k_step
    V_new = torch.roll(V_mem, shifts=-1, dims=(2,))
    V_new[:B, :, -1] = v_step
    
    # Batch and vertex axes are flattened together: (B*V, 1, E) x (B*V, E, T) -> (B*V, 1, T)
    attn = torch.bmm(q_sel.reshape(-1,1,E), K_T_new[:q_sel.shape[0]].reshape(-1,E,T).cuda())
    K_T_new=K_T_new.detach().cpu()
    attn_sm = F.softmax(attn, dim=-1)
    
    if dropout_p > 0.0:
        # NOTE(review): F.dropout defaults to training=True, so this dropout is
        # also applied when the module is in eval mode -- confirm this is intended.
        attn_sm = F.dropout(attn_sm, p=dropout_p)
    
    # (B*V, 1, T) x (B*V, T, E) -> (B*V, 1, E) -> (B, V, 1, E)
    output = torch.bmm(attn_sm, V_new[:B].reshape(-1,T,E).cuda()).reshape(B,V,-1,E)
    
    # Push the scaled current query into the query memory (FIFO along time).
    if Q_mem.shape[2] > 0:
        Q_new = torch.roll(Q_mem, shifts=-1, dims=(2,))
        Q_new[:B, :, -1] = q_step.detach().cpu()
    else:
        Q_new = Q_mem
    new_states = (Q_new, K_T_new, V_new.detach().cpu())
    
    return output, new_states


class AttentionHead(nn.Module, co.CoModule):
    """Single attention head over a skeleton sequence shaped ``(N, T, V, C)``.

    Queries, keys and values come from three ``co.Conv2d`` projections. The
    regular ``forward`` path attends over time independently for every node.
    The continual ``forward_steps`` path processes one frame at a time with
    ``_scaled_dot_product_attention_step``, keeping a rolling memory in the
    module state ``(Q_mem, K_T_mem, V_mem, stride_index)``.
    """

    def __init__(self, is_continual : bool, dim_in: int, dim_v: int, dim_k: int, kernel_size: int = 1 , stride :int =1, dropout :float=.1):
        super().__init__()
        self.call_mode = CallMode.FORWARD_STEPS if is_continual else CallMode.FORWARD
        self.embed_dim_second=False
        self.batch_first=True
        self.d_k=dim_k
        self.d_v=dim_v
        self.dropout=dropout
        padding = (int((kernel_size - 1) / 2), 0)

        def make_projection(dim_out):
            # Same convolution layout for Q, K and V; only the output width differs.
            return co.Conv2d(
                dim_in,
                dim_out,
                kernel_size=(kernel_size, 1),
                padding=padding,
                stride=(stride, 1), dtype=torch.float).cuda()

        self.q_conv = make_projection(dim_k)
        self.k_conv = make_projection(dim_k)
        self.v_conv = make_projection(dim_v)

    def get_state(self) -> Optional[State]:
        """Get model state

        Returns:
            Optional[State]: ``(Q_mem, K_T_mem, V_mem, stride_index)`` if the model
            has been initialised, otherwise None.
        """
        if (
            getattr(self, "Q_mem", None) is not None
            and getattr(self, "K_T_mem", None) is not None
            and getattr(self, "V_mem", None) is not None
            and getattr(self, "stride_index", None) is not None
        ):
            return (
                self.Q_mem,
                self.K_T_mem,
                self.V_mem,
                self.stride_index,
            )

    def set_state(self, state: State):
        """Set model state

        Args:
            state: ``(Q_mem, K_T_mem, V_mem, stride_index)`` tuple to store as the
                new internal state.
        """
        (
            self.Q_mem,
            self.K_T_mem,
            self.V_mem,
            self.stride_index,
        ) = state

    def clean_state(self):
        """Clean model state"""
        for name in ("Q_mem", "K_T_mem", "V_mem", "stride_index"):
            if hasattr(self, name):
                delattr(self, name)

    @staticmethod
    def _backup_state(state):
        """Return a copy of ``state`` (tensors cloned, stride index kept), or None."""
        if state is None:
            return None
        *tensors, stride_index = state
        return (*_clone_state(tensors), stride_index)

    def _forward_step(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        T: int,
        prev_state: State = None,
    ) -> Tuple[MaybeTensor, State]:
        """Forward computation for a single step with state initialisation

        Args:
            query, key, value: step inputs of shape ``(B, V, E)``.
            T: number of time slots kept in the attention memory.
            prev_state: previous state, or None to start from zero memories.

        Returns:
            Tuple[MaybeTensor, State]: Step output and new state.
        """
        B, V, E = query.shape
        if prev_state is None:
            prev_state = (
                *_scaled_dot_product_attention_default_state(B, T, V, self.d_k, self.d_v),
                -T,
            )

        o, new_state = _scaled_dot_product_attention_step(
            prev_state[:-1],
            query,
            key,
            value,
            T,
            self.dropout,
        )
        # stride_index counts up from -T towards 0 while the memory warms up.
        stride_index = prev_state[-1]
        if stride_index < 0:
            stride_index += 1

        new_state = (*new_state, stride_index)

        return o, new_state

    def forward_step(
        self,
        T: int,
        query: Tensor,
        key: Tensor = None,
        value: Tensor = None,
        update_state=True,
    ) -> MaybeTensor:
        """Process a single step.

        Args:
            T: number of time slots kept in the attention memory.
            query, key, value: step inputs of shape ``(B, V, E)``. ``key`` and
                ``value`` default to ``query``.
            update_state: if False, the internal state is left as it was before
                the call.

        Returns:
            Attention output of the step (batch and vertex axes swapped when
            ``batch_first`` is set).
        """
        if key is None:
            key = query
        if value is None:
            value = query

        tmp_state = self.get_state()
        # Only back up when we must restore; None when there was no state yet.
        backup_state = None if update_state else self._backup_state(tmp_state)

        o, tmp_state = self._forward_step(query, key, value, T, tmp_state)
        if self.batch_first and not isinstance(o, TensorPlaceholder):
            o = o.transpose(1, 0)

        if update_state:
            self.set_state(tmp_state)
        elif backup_state is not None:
            self.set_state(backup_state)

        return o

    def forward_steps(
        self,
        x : Tensor,
        update_state=True,
    ) -> MaybeTensor:
        """Forward computation for multiple steps with state initialisation

        Args:
            x (Tensor): input of shape ``(N, T, V, C)``.
            update_state (bool): Whether internal state should be updated during this operation.

        Returns:
            Tensor: stepwise outputs stacked to ``(N, T, V, E)``.
        """
        _, T, _, _ = x.shape  # also validates that x is 4-D
        query, key, value = self.projection(x)

        if self.embed_dim_second:
            # N E V T -> N T V E
            query, key, value = [t.permute(0, 3, 2, 1) for t in (query, key, value)]

        if self.batch_first:
            # N T V E -> T N V E
            query, key, value = [t.transpose(1, 0) for t in (query, key, value)]

        tmp_state = self.get_state()
        backup_state = None if update_state else self._backup_state(tmp_state)

        T = query.shape[0]
        assert T == key.shape[0]
        assert T == value.shape[0]
        outs = []
        o = None

        for t in range(T):
            o, tmp_state = self._forward_step(query[t], key[t], value[t], T, tmp_state)
            if isinstance(o, Tensor):
                if self.batch_first:
                    o = o.transpose(0, 1)
                outs.append(o)

        if update_state:
            self.set_state(tmp_state)
        elif backup_state is not None:
            self.set_state(backup_state)

        if len(outs) == 0:
            return o

        # list of (V, N, 1, E) -> (V, N, T, 1, E) -> (V, N, T, E) -> (N, T, V, E)
        return torch.stack(outs, dim=2).squeeze(3).permute(1, 2, 0, 3)

    def attention(self, Q, K, V):
        """Scaled dot-product attention over the second-to-last axis of Q/K/V."""
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)
        # Normalise over the key axis (last dim), not an implicit default axis.
        weights = F.softmax(scores, dim=-1)
        return weights @ V

    def projection(self, x: Tensor):
        """Project ``(N, T, V, C)`` input to Q, K, V, each shaped ``(N, T, V, d)``."""
        x = x.permute(0, 3, 2, 1)
        Q = self.q_conv(x).permute(0, 3, 2, 1)
        K = self.k_conv(x).permute(0, 3, 2, 1)
        V = self.v_conv(x).permute(0, 3, 2, 1)
        return Q, K, V

    def forward(self, x: Tensor) -> Tensor:
        """Non-continual attention over time for every node.

        Args:
            x: input of shape ``(N, T, V, C)``.

        Returns:
            Tensor of shape ``(N, T, V, d_v)``.
        """
        query, key, value = self.projection(x)  # (N, T, V, d)
        # Put time next to the feature axis so every node attends over its own sequence.
        query, key, value = [t.transpose(1, 2) for t in (query, key, value)]  # (N, V, T, d)
        out = self.attention(query, key, value)  # (N, V, T, d_v)
        return out.transpose(1, 2).contiguous()  # (N, T, V, d_v)

class MultiHeadAttention(co.CoModule,nn.Module):
    """Run ``num_heads`` AttentionHeads in parallel, concatenate and project back to ``dim_in``.

    ``dim_q`` is accepted for interface compatibility but is not used.
    """

    def __init__(self, is_continual: bool, num_heads: int, dim_in: int,dim_k,dim_q,dim_v,dropout):
        super().__init__()
        self.call_mode = CallMode.FORWARD_STEPS if is_continual else CallMode.FORWARD
        self.heads = nn.ModuleList(
            [AttentionHead(is_continual,dim_in, dim_v, dim_k,dropout=dropout) for _ in range(num_heads)]
        )
        # Each head emits dim_v features, so the concatenation is num_heads * dim_v wide.
        self.linear = nn.Linear(num_heads * dim_v, dim_in,dtype=torch.float).cuda()

    def clean_state(self):
        """Clear the continual state of every head."""
        for h in self.heads:
            h.clean_state()

    def forward_steps(self, x: Tensor, pad_end=False, update_state=True) -> Tensor:
        """Continual path: step every head over ``x`` and fuse the results.

        ``update_state`` is forwarded to each head; ``pad_end`` is unused.
        """
        head_outputs = [h.forward_steps(x, update_state=update_state) for h in self.heads]
        return self.linear(torch.cat(head_outputs, dim=-1))

    def forward(self, x) -> Tensor:
        """Non-continual path: run every head on ``x`` and fuse the results."""
        return self.linear(
            torch.cat([h(x) for h in self.heads], dim=-1)
        )

class TransformerGraphEncoderLayer(nn.Module, co.CoModule):
    """One encoder layer: LayerNorm -> residual multi-head attention -> residual feed-forward.

    Args:
        is_continual: select the continual (step-wise) attention path.
        dim_model: feature size of the input/output.
        num_heads: number of attention heads.
        dim_feedforward: hidden size of the feed-forward block.
        dropout: dropout used in attention and both residual branches.
        dim_head: per-head query/key/value size (default 32, independent of
            ``dim_model // num_heads``).
    """

    def __init__(
        self,
        is_continual:bool=False,
        dim_model: int = 128,
        num_heads: int = 8,
        dim_feedforward: int = 512,
        dropout: float = 0.1,
        dim_head: int = 32,
    ):
        super().__init__()
        self.call_mode = CallMode.FORWARD_STEPS if is_continual else CallMode.FORWARD
        self.attention = Residual(
            MultiHeadAttention(is_continual, num_heads, dim_model, dim_head, dim_head, dim_head, dropout),
            dimension=dim_model,
            dropout=dropout,
        )
        self.feed_forward = Residual(
            FeedForward(dim_model, dim_feedforward),
            dimension=dim_model,
            dropout=dropout,
        )
        self.norm = nn.LayerNorm(dim_model,dtype=torch.float).cuda()

    def clean_state(self):
        """Clear the continual state of the attention sublayer."""
        self.attention.clean_state()

    def forward(self, src: Tensor) -> Tensor:
        src = self.attention(self.norm(src))
        return self.feed_forward(src)

class PositionalEncoder(nn.Module, co.CoModule):
    """Add a fixed sinusoidal time encoding (identical for every node), then LayerNorm.

    Args:
        d_model: feature size of the input.
        max_seq_len: longest supported number of frames.
        num_nodes: number of graph nodes the table is repeated over (default 20).
    """

    def __init__(self, d_model, max_seq_len = 200, num_nodes: int = 20):
        super().__init__()
        self.d_model = d_model

        position = torch.arange(max_seq_len, dtype=torch.float64).unsqueeze(1)  # (L, 1)
        even_idx = torch.arange(0, d_model, 2, dtype=torch.float64)  # 0, 2, 4, ...
        odd_idx = torch.arange(1, d_model, 2, dtype=torch.float64)   # 1, 3, 5, ...

        table = torch.zeros(max_seq_len, d_model, dtype=torch.float64)
        # NOTE(review): every channel c uses the exponent 2*c/d_model, so the cos
        # channel i+1 does not share the frequency of sin channel i as in
        # Vaswani et al.; kept unchanged so the encoding table is unchanged.
        table[:, 0::2] = torch.sin(position / 10000 ** (2 * even_idx / d_model))
        table[:, 1::2] = torch.cos(position / 10000 ** (2 * odd_idx / d_model))

        # Same encoding for every node: (1, L, num_nodes, d_model)
        pe = table.float().unsqueeze(1).repeat(1, num_nodes, 1).unsqueeze(0)
        self.norm = nn.LayerNorm(d_model, dtype=torch.float).cuda()
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``LayerNorm(x + pe[:, :T])`` for ``x`` of shape ``(N, T, V, d_model)``."""
        seq_len = x.size(1)
        # The table is a constant buffer; align it with x's device rather than forcing CUDA.
        pe = self.pe[:, :seq_len].to(x.device)
        return self.norm(x + pe)

class TransformerGraphEncoder(nn.Module, co.CoModule):
    """Positional encoding followed by ``num_layers`` TransformerGraphEncoderLayer blocks."""

    def __init__(
        self,
        is_continual: bool=False,
        num_layers: int = 6,
        dim_model: int = 128,
        num_heads: int = 8,
        dim_feedforward: int = 512,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.call_mode = CallMode.FORWARD_STEPS if is_continual else CallMode.FORWARD
        self.layers = nn.ModuleList(
            [
                TransformerGraphEncoderLayer(is_continual, dim_model, num_heads, dim_feedforward, dropout)
                for _ in range(num_layers)
            ]
        )
        self.positional_encoder = PositionalEncoder(dim_model)

    def clean_state(self):
        """Clear the continual state of every layer."""
        for layer in self.layers:
            layer.clean_state()

    def forward_steps(self, x: Tensor, pad_end=False, update_state=True) -> Tensor:
        """Continual entry point; the layers themselves handle step-wise execution."""
        return self.forward(x)

    def forward(self, x: Tensor) -> Tensor:
        # Out-of-place on purpose: do not mutate the caller's tensor.
        # NOTE(review): the result is x + LayerNorm(x + pe); confirm the extra
        # skip of x is intended rather than x = positional_encoder(x).
        x = x + self.positional_encoder(x)
        for layer in self.layers:
            x = layer(x)
        return x

In [5]:

def conv_init(conv):
    """Kaiming-normal (fan_out) init of ``conv.weight``; zero the bias when present.

    Convolutions built with ``bias=False`` have ``conv.bias is None``, which
    ``nn.init.constant_`` cannot handle, so the bias step is skipped for them.
    """
    nn.init.kaiming_normal_(conv.weight, mode='fan_out')
    if conv.bias is not None:
        nn.init.constant_(conv.bias, 0)


class unit_gcn(nn.Module):
    """Graph convolution over a stack of adjacency partitions, then BatchNorm and Mish.

    Input and output layout is ``(N, T, V, C)``; the computation runs in
    ``(N, C, T, V)``. For each partition ``a`` of ``A`` the features are mixed
    along the node axis (``x @ a``) and passed through a dedicated Conv2d; the
    results are summed.

    Args:
        in_channels: input feature size.
        out_channels: output feature size.
        A: adjacency tensor of shape ``(num_A, V, V)``; only its size is used here.
        use_local_bn: per-node BatchNorm1d instead of a shared BatchNorm2d.
        kernel_size: temporal kernel size of each Conv2d.
        stride: temporal stride of each Conv2d.
        mask_learning: learn a multiplicative mask (initialised to ones) over ``A``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 A,
                 use_local_bn=False,
                 kernel_size=1,
                 stride=1,
                 mask_learning=False):
        super(unit_gcn, self).__init__()

        # number of nodes (used by the local BatchNorm)
        self.V = 20
        self.in_channels = in_channels
        self.out_channels = out_channels
        # if true, use a learnable mask to reweight the adjacency matrix
        self.mask_learning = mask_learning
        # number of adjacency matrices (partitions)
        self.num_A = A.size(0)
        # if true, each node has its own BatchNorm parameters
        self.use_local_bn = use_local_bn

        self.conv_list = nn.ModuleList([
            nn.Conv2d(
                self.in_channels,
                self.out_channels,
                kernel_size=(kernel_size, 1),
                padding=(int((kernel_size - 1) / 2), 0),
                stride=(stride, 1), dtype=torch.float).cuda() for i in range(self.num_A)
        ])

        if mask_learning:
            # Allocate on the GPU *before* wrapping: calling .cuda() on an
            # nn.Parameter returns a plain Tensor that is not registered with the
            # module, so it would never be optimised.
            # NOTE(review): a blanket xavier init over all >1-D parameters would
            # also overwrite this mask; keep it at ones if that is intended.
            self.mask = nn.Parameter(torch.ones(A.size(), device="cuda"))
        if use_local_bn:
            self.bn = nn.BatchNorm1d(self.out_channels * self.V).cuda()
        else:
            self.bn = nn.BatchNorm2d(self.out_channels, dtype=torch.float).cuda()

        self.act = nn.Mish()

        for conv in self.conv_list:
            conv_init(conv)

    def forward(self, x, A):
        x = x.permute(0, 3, 1, 2)  # (N, T, V, C) -> (N, C, T, V)

        N, C, T, V = x.size()
        # Works for CPU and GPU inputs (x.get_device() is -1 on CPU).
        A = A.to(x.device)

        if self.mask_learning:
            A = A * self.mask.to(x.device)

        # graph convolution: sum over partitions of conv_i(x @ A_i)
        for i, a in enumerate(A):
            xa = x.reshape(-1, V).mm(a).reshape(N, C, T, V)
            if i == 0:
                y = self.conv_list[i](xa)
            else:
                y = y + self.conv_list[i](xa)

        if self.use_local_bn:
            y = y.permute(0, 1, 3, 2).contiguous().view(
                N, self.out_channels * V, T)
            y = self.bn(y)
            y = y.view(N, self.out_channels, V, T).permute(0, 1, 3, 2)
        else:
            y = self.bn(y)

        y = self.act(y)

        # Back to (N, T, V, C); cloned so callers own an independent tensor.
        return y.clone().permute(0, 2, 3, 1)


class SGCN(nn.Module):
    """Spatial feature extractor: seven ``unit_gcn`` layers sharing one adjacency matrix.

    Channel widths: features_in -> 64 (x4 layers) -> features_out (x3 layers).
    """

    def __init__(self, features_in, features_out, A) -> None:
        super().__init__()
        # (in_channels, out_channels, third value); the third value is not used.
        default_backbone = [
            (features_in, 64, 1),
            (64, 64, 1),
            (64, 64, 1),
            (64, 64, 1),
            (64, features_out, 2),
            (features_out, features_out, 1),
            (features_out, features_out, 1),
        ]
        layers = []
        for dim_in, dim_out, _unused in default_backbone:
            layers.append(unit_gcn(dim_in, dim_out, A, mask_learning=True))
        self.conv_layers = nn.ModuleList(layers)

    def forward(self, x: Tensor, adjacency_matrix: Tensor) -> torch.Tensor:
        out = x
        for layer in self.conv_layers:
            out = layer(out, adjacency_matrix)
        return out


In [6]:


# model definition



 
class STrGCN(pl.LightningModule):
    """Spatial-temporal graph transformer classifier (PyTorch Lightning module).

    Pipeline: ``SGCN`` spatial features -> ``TransformerGraphEncoder`` ->
    average pooling over nodes and frames -> MLP classifier.

    Args:
        adjacency_matrix: adjacency tensor for the SGCN layers (cast to float).
        optimizer_params: ``(learning_rate, betas, epsilon, weight_decay)`` for RAdam.
        labels: class names used as confusion-matrix tick labels.
        num_classes: number of output classes.
        d_model: feature size of the SGCN output and the encoder.
        n_heads: attention heads per encoder layer.
        nEncoderlayers: number of encoder layers.
        dropout: encoder dropout rate.
    """

    # Weight of an optional L1 penalty on all parameters, added to the loss
    # returned by training_step when > 0. The default of 0 trains on plain
    # cross-entropy.
    l1_lambda: float = 0.0

    def __init__(self, adjacency_matrix,optimizer_params, labels, num_classes : int=18, d_model: int=512, n_heads: int=8,
                 nEncoderlayers: int=6, dropout: float = 0.1):
        super(STrGCN, self).__init__()
        self.labels=labels
        features_in=3
        # Running confusion matrix accumulated over test steps.
        self.cnf_matrix= torch.zeros(num_classes, num_classes).cuda()
        self.Learning_Rate, self.betas, self.epsilon, self.weight_decay=optimizer_params
        self.num_classes=num_classes
        self.adjacency_matrix=adjacency_matrix.float()
        self.is_continual=True
        self.train_acc = torchmetrics.Accuracy()
        self.valid_acc = torchmetrics.Accuracy()
        self.test_acc = torchmetrics.Accuracy()
        self.val_f1_score=torchmetrics.F1Score(num_classes)
        self.train_f1_score=torchmetrics.F1Score(num_classes)
        self.test_f1_score=torchmetrics.F1Score(num_classes)
        self.val_jaccard=torchmetrics.JaccardIndex(num_classes)
        self.train_jaccard=torchmetrics.JaccardIndex(num_classes)
        self.test_jaccard=torchmetrics.JaccardIndex(num_classes)
        self.confusion_matrix=torchmetrics.ConfusionMatrix(num_classes)
        self.gcn=SGCN(features_in,d_model,self.adjacency_matrix)

        self.encoder=TransformerGraphEncoder(is_continual=self.is_continual,dropout=dropout,num_heads=n_heads,dim_model=d_model, num_layers=nEncoderlayers)

        self.out = nn.Sequential(
            nn.Linear(d_model, d_model,dtype=torch.float).cuda(),
            nn.Mish(),
            nn.LayerNorm(d_model,dtype=torch.float).cuda(),
            nn.Linear(d_model,num_classes,dtype=torch.float).cuda()
        )

        self.d_model = d_model
        self.init_parameters()

    def init_parameters(self):
        """Xavier-uniform init of every parameter with more than one dimension.

        Parameters named ``...mask`` are skipped: they are multiplicative
        adjacency masks that are meant to start at ones.
        """
        for name, p in self.named_parameters():
            if p.dim() > 1 and not name.endswith("mask"):
                nn.init.xavier_uniform_(p)

    def get_fp_rate(self, score, labels):
        """Sum over classes of the false-positive rate FP / (FP + TN).

        ``score`` holds predicted class indices and ``labels`` the targets.
        Columns of the torchmetrics confusion matrix are predictions, so FP of a
        class is its column sum minus the diagonal. Classes with FP + TN == 0
        contribute 0.
        """
        cnf_matrix = self.confusion_matrix(score.detach(), labels.detach())
        TP = torch.diag(cnf_matrix)
        FP = (cnf_matrix.sum(dim=0) - TP).float()
        FN = cnf_matrix.sum(dim=1) - TP
        TN = (cnf_matrix.sum() - (FP + FN + TP)).float()
        FPR = FP / (FP + TN)
        return torch.sum(torch.nan_to_num(FPR), dim=-1)

    def forward(self, x):
        """Classify a batch shaped ``(N, T, V, 3)``; returns logits ``(N, num_classes)``."""
        x = x.type(torch.float).cuda()

        # spatial features from SGCN
        x = self.gcn(x, self.adjacency_matrix)

        # temporal features from the transformer encoder
        x = self.encoder(x)

        # Global average pooling over nodes, then over frames
        N, T, V, C = x.shape
        x = x.permute(0, 3, 1, 2)
        x = F.avg_pool2d(x, kernel_size=(1, V)).view(N, C, T)
        x = F.avg_pool1d(x, kernel_size=T).view(N, C)

        return self.out(x)

    def plot_confusion_matrix(self, filename, eps=1e-5):
        """Save the row-normalised test confusion matrix as an EPS heatmap."""
        import seaborn as sn
        confusion_matrix_sum_vec = torch.sum(self.cnf_matrix, dim=1) + eps
        confusion_matrix_percentage = self.cnf_matrix / confusion_matrix_sum_vec.view(-1, 1)

        fig = plt.figure(figsize=(18, 16))
        sn.heatmap(confusion_matrix_percentage.cpu().numpy(), annot=True, cmap="coolwarm",
                   xticklabels=self.labels, yticklabels=self.labels)
        plt.savefig(filename, format="eps")
        # Release the figure; repeated test runs would otherwise accumulate figures.
        plt.close(fig)

    def _unpack_batch(self, batch):
        """Return ``(x, y)`` with float inputs and int64 CUDA targets."""
        return batch[0].float(), batch[1].long().cuda()

    def training_step(self, batch, batch_nb):
        x, y = self._unpack_batch(batch)
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)

        total_loss = loss
        if self.l1_lambda > 0:
            l1_norm = sum(p.abs().sum() for p in self.parameters())
            total_loss = loss + self.l1_lambda * l1_norm

        self.train_acc(y_hat, y)
        self.train_f1_score(y_hat, y)

        self.log('train_loss', loss, on_epoch=True, on_step=True)
        self.log('train_acc', self.train_acc.compute(), prog_bar=True, on_step=True, on_epoch=True)

        return total_loss

    def validation_step(self, batch, batch_nb):
        x, y = self._unpack_batch(batch)
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.valid_acc(y_hat, y)
        self.val_f1_score(y_hat, y)

        self.log('val_loss', loss, prog_bar=True, on_epoch=True, on_step=True)
        self.log('val_accuracy', self.valid_acc.compute(), prog_bar=True, on_step=True, on_epoch=True)

    def training_epoch_end(self, outputs):
        self.train_acc.reset()

    def validation_epoch_end(self, outputs):
        self.valid_acc.reset()

    def test_step(self, batch, batch_nb):
        x, targets = self._unpack_batch(batch)
        y_hat = self(x)
        _, preds = torch.max(y_hat, 1)
        self.test_acc(y_hat, targets)
        self.test_f1_score(y_hat, targets)

        loss = F.cross_entropy(y_hat, targets)
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_accuracy', self.test_acc.compute(), prog_bar=True)

        self.cnf_matrix += self.confusion_matrix(preds, targets)

    def on_test_end(self):
        time_now = datetime.today().strftime('%Y-%m-%d_%H_%M_%S')
        out_dir = "./Confusion_matrices"
        # savefig does not create missing directories.
        os.makedirs(out_dir, exist_ok=True)
        self.plot_confusion_matrix(f"{out_dir}/Confusion_matrix_{time_now}.eps")

    def configure_optimizers(self):
        """RAdam with the configured lr/betas/eps/weight decay, plus ReduceLROnPlateau on val_loss."""
        opt = torch.optim.RAdam(
            filter(lambda p: p.requires_grad, self.parameters()),
            lr=self.Learning_Rate,
            betas=self.betas,
            eps=self.epsilon,
            weight_decay=self.weight_decay,
        )
        reduce_lr_on_plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(
            opt,
            mode='min',
            factor=.5,
            patience=2,
            min_lr=1e-4,
            verbose=True
        )

        return {"optimizer": opt, "lr_scheduler": reduce_lr_on_plateau, "monitor": "val_loss"}

In [7]:
# Class names indexed by label id (index 0 is an empty string).
labels = [
    "",
    "RIGHT",
    "KNOB",
    "CROSS",
    "THREE",
    "V",
    "ONE",
    "FOUR",
    "GRAB",
    "DENY",
    "MENU",
    "CIRCLE",
    "TAP",
    "PINCH",
    "LEFT",
    "TWO",
    "OK",
    "EXPAND",
]

# Dataset locations
DATASETS_PATH = "datasets/"
DS_NAME = "shrec21"
DS_PATH = DATASETS_PATH + "shrec21/"

# Data pipeline
batch_size = 32
workers = 4
window_size = 10
input_shape = (window_size, 20, 3)
device = torch.device('cuda')

# Model
num_classes = 18
d_model = 128
n_heads = 8
dropout_rate = .3
stride = 1

# Optimisation: (learning rate, betas, epsilon, weight decay)
lr = 1e-3
betas = (.9, .98)
epsilon = 1e-9
weight_decay = 5e-4
optimizer_params = (lr, betas, epsilon, weight_decay)
Max_Epochs = 500
Early_Stopping = 25
def compute_energy(x):
    """Return the per-sample "motion energy" of a window of frames.

    For every position ``v`` and every pair of consecutive frames ``t-1, t``
    the relative change of channels 0..2 (``cur / prev - 1``) is taken; the
    Euclidean norm of those three relative changes is then summed over all
    frame pairs and all positions.

    Args:
        x: 4-D tensor indexed as ``x[n, t, v, c]`` with at least 3 channels;
            only channels 0..2 are used. Presumably (batch, frames, joints,
            xyz) — TODO confirm against the data loader.

    Returns:
        Tensor of shape ``(N,)``. A window with fewer than two frames has
        zero energy.

    Raises:
        ValueError: if ``x`` is not 4-D or has fewer than 3 channels.

    Note:
        The ratio divides by the previous value, so a coordinate that is
        exactly 0 produces inf/nan in the result.
    """
    if x.dim() != 4:
        raise ValueError(f"compute_energy expects a 4-D tensor, got shape {tuple(x.shape)}")
    if x.shape[-1] < 3:
        raise ValueError(f"compute_energy needs at least 3 channels, got {x.shape[-1]}")
    coords = x[..., :3]
    # relative change between consecutive frames: (N, T-1, V, 3)
    rel_change = coords[:, 1:] / coords[:, :-1] - 1
    # Euclidean norm over the three channels: (N, T-1, V)
    step_energy = torch.sqrt((rel_change ** 2).sum(dim=-1))
    # total over frame pairs and positions: (N,)
    return step_energy.sum(dim=(1, 2))


def init_data_loader():
    """Load the data splits in non-segmented mode.

    Thin wrapper around ``load_data_sets(is_segmented=False)``.

    Returns:
        tuple: ``(train_loader, val_loader, test_loader, graph)``, unpacked
        from (and therefore checked to be exactly the four items returned by)
        ``load_data_sets``.
    """
    loaders_and_graph = load_data_sets(is_segmented=False)
    train, val, test, adjacency = loaders_and_graph
    return train, val, test, adjacency


def init_model(graph, optimizer_params, labels, num_classes, dropout_rate=.1,
               d_model=128, n_heads=8):
    """Build a freshly initialised ``STrGCN`` classifier.

    Args:
        graph: adjacency / graph object passed as STrGCN's first argument.
        optimizer_params: optimizer settings forwarded unchanged to STrGCN.
        labels: class names forwarded unchanged to STrGCN.
        num_classes: number of output classes.
        dropout_rate: dropout probability (STrGCN ``dropout``).
        d_model: model width; defaults to the previously hard-coded 128.
        n_heads: number of attention heads; defaults to the previously
            hard-coded 8.

    Returns:
        The constructed ``STrGCN`` instance.
    """
    return STrGCN(
        graph,
        optimizer_params,
        labels,
        d_model=d_model,
        n_heads=n_heads,
        num_classes=num_classes,
        dropout=dropout_rate,
    )





def get_acc(score, labels):
    """Return the fraction of rows whose arg-max score matches the label.

    Args:
        score: tensor of per-class scores; the arg-max is taken along axis 1.
        labels: tensor of true class ids.

    Returns:
        numpy float: correct predictions divided by ``labels.size``.
    """
    truth = labels.cpu().data.numpy()
    predicted = score.cpu().data.numpy().argmax(axis=1)
    n_correct = (predicted == truth).sum()
    return n_correct / float(truth.size)

def get_fp_rate(score, labels, n_classes=None):
    """Return the per-class false-positive rates, summed over all classes.

    The confusion matrix uses the usual layout ``cnf[i, j]`` = number of
    samples of true class ``i`` predicted as class ``j`` (the same layout as
    ``torchmetrics.ConfusionMatrix``). Hence, for class ``k``:
    FP = column sum - diagonal, FN = row sum - diagonal,
    TN = total - (FP + FN + TP), and FPR = FP / (FP + TN).
    Classes where FPR is 0/0 contribute 0.

    The matrix is built directly with ``torch.bincount`` so the result does
    not depend on the torchmetrics constructor signature.

    Args:
        score: predicted class ids with the same shape as ``labels``, or
            per-class scores with one extra trailing dimension (arg-maxed).
        labels: ground-truth class ids.
        n_classes: number of classes; defaults to the module-level
            ``num_classes``.

    Returns:
        0-d float tensor holding the SUM (not the mean) of per-class FPRs.
    """
    if n_classes is None:
        n_classes = num_classes
    preds = score.detach().cpu()
    if preds.dim() == labels.dim() + 1:
        preds = preds.argmax(dim=-1)
    preds = preds.long().reshape(-1)
    target = labels.detach().cpu().long().reshape(-1)

    # cnf[i, j]: samples of true class i predicted as class j
    cnf = torch.bincount(
        target * n_classes + preds, minlength=n_classes * n_classes
    ).reshape(n_classes, n_classes)

    tp = torch.diag(cnf).float()
    fp = cnf.sum(dim=0).float() - tp  # predicted as k but truly another class
    fn = cnf.sum(dim=1).float() - tp  # truly k but predicted as another class
    tn = cnf.sum().float() - (fp + fn + tp)

    fpr = fp / (fp + tn)
    return torch.sum(torch.nan_to_num(fpr), dim=-1)

def get_window_label(label, n_classes=None):
    """Return the majority class id over all labels in a window.

    Every element of ``label`` is one vote. In the test loop this is
    ``y[t:t+window_size]`` — presumably (frames, batch); TODO confirm.

    Args:
        label: integer tensor of non-negative class ids (any shape).
        n_classes: number of classes; defaults to the module-level
            ``num_classes``.

    Returns:
        Long tensor of shape ``(1,)`` with the most frequent class id; on a
        tie the smallest id wins.

    Raises:
        ValueError: if a label id is >= ``n_classes``.
    """
    if n_classes is None:
        n_classes = num_classes
    votes = label.reshape(-1).long()
    # bincount counts every occurrence, including repeats within one row
    counts = torch.bincount(votes, minlength=n_classes)
    if counts.numel() > n_classes:
        raise ValueError(f"label id {counts.numel() - 1} out of range for {n_classes} classes")
    return counts.argmax().reshape(1)


# Make cuDNN deterministic (slower, but runs are reproducible).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Folder for saving trained models; change this path to wherever you want to
# keep your checkpoints. makedirs also creates missing parent folders.
model_fold = "./models/costr_gcn/online_model_checkpoints"
try:
    os.makedirs(model_fold, exist_ok=True)
except OSError as err:
    # best effort: the folder is not needed for evaluation, but say why it failed
    print(f"warning: could not create {model_fold}: {err}")

train_loader, val_loader, test_loader, graph = init_data_loader()

# ......... initial model
print("\n loading model.............")
model = STrGCN.load_from_checkpoint(
    checkpoint_path="./models/STRGCN-SHREC17_2022-08-29_17_59_22/best_model-128-8-v1.ckpt",
    adjacency_matrix=graph,
    optimizer_params=optimizer_params,
    labels=labels,
    d_model=d_model,
    n_heads=n_heads,
    num_classes=num_classes,
    dropout=dropout_rate,
)

# ......... loss
criterion = torch.nn.CrossEntropyLoss()

# Evaluation metrics.
# NOTE(review): these constructors pass only num_classes; newer torchmetrics
# releases also require a `task=` argument — confirm the pinned version.
f1_score = torchmetrics.F1Score(num_classes=num_classes)
jaccard = torchmetrics.JaccardIndex(num_classes=num_classes)
avg_precision = torchmetrics.AveragePrecision(num_classes=num_classes)

# Threshold on the energy derivative used by the window selector in the test loop.
eps = 1e-1

100%|████████████████████████████████████████████████████████████████████████████████| 108/108 [00:22<00:00,  4.78it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 72/72 [00:12<00:00,  5.73it/s]


train data num:  108
test data num:  72

 loading model.............




In [None]:
start_time = time.time()
# *********** evaluation ***********
# Online evaluation on the test streams: slide a window of `window_size`
# frames over each stream and only classify windows at a peak of the energy
# computed by `compute_energy` (energy rising before t, falling after t).
# NOTE(review): `if d_wi < eps and ...` needs scalar energies, i.e. it assumes
# the test loader yields one stream per batch — confirm.
print("*" * 10, "Testing", "*" * 10)
model.eval()  # disable dropout so evaluation is deterministic
with torch.no_grad():
    val_loss = 0
    val_f1_epoch = 0
    val_jaccard_epoch = 0
    val_fp_rate_epoch = 0
    val_avg_precision_epoch = 0
    n_scored_batches = 0  # batches in which at least one window was classified
    for i, batch in enumerate(test_loader):
        print("batch=", i)
        x, y, index = batch
        y = torch.stack(y)
        N, T, V, C = x.shape

        score_list = None
        label_list = None
        # Every position t needs full windows at t-2s .. t+2s, so the look-ahead
        # window (t + 2*stride) must still fit entirely inside the stream;
        # otherwise its energy would be summed over fewer frames.
        first_t = 2 * stride
        last_t = T - window_size - 2 * stride
        for t in tqdm(range(first_t, last_t + 1, stride), leave=False):
            # energies of the windows starting at t-2s, t-s, t, t+s, t+2s
            w_1, w_2, w_3, w_4, w_5 = [
                compute_energy(x[:, t + k * stride: t + k * stride + window_size])
                for k in (-2, -1, 0, 1, 2)
            ]
            # central differences of the energy curve (each spans 2*stride)
            d_wi = (w_4 - w_2) / (2 * stride)
            d_wi_m_1 = (w_3 - w_1) / (2 * stride)
            d_wi_p_1 = (w_5 - w_3) / (2 * stride)
            # NOTE(review): one-sided test (d_wi < eps), not |d_wi| < eps — confirm intent
            if d_wi < eps and d_wi_m_1 > 0 and d_wi_p_1 < 0:
                window = x[:, t: t + window_size].clone()
                score = model(window)
                label = get_window_label(y[t: t + window_size])
                if score_list is None:
                    score_list = score
                    label_list = label
                else:
                    score_list = torch.cat((score_list, score), 0)
                    label_list = torch.cat((label_list, label), 0)

        if score_list is None:
            # no window was selected in this stream, so there is nothing to score
            print("batch=", i, "no window selected, skipping")
            continue
        n_scored_batches += 1

        scores = score_list.detach().cpu()
        targets = label_list.detach().cpu()
        preds = scores.argmax(dim=-1)  # same as argmax of softmax(scores)
        loss = criterion(scores, targets)
        val_f1_step = f1_score(preds, targets)
        val_jaccard_step = jaccard(preds, targets)
        val_fp_rate_step = get_fp_rate(preds, targets)
        val_avg_precision_step = avg_precision(scores, targets)
        val_f1_epoch += val_f1_step
        val_jaccard_epoch += val_jaccard_step
        val_fp_rate_epoch += val_fp_rate_step
        val_avg_precision_epoch += val_avg_precision_step
        val_loss += loss
        print("*** SHREC 21, "
              "val_loss_step: %.6f, "
              "val_F1_step: %.6f, "
              "val_jaccard_step: %.6f, "
              "val_fp_rate_step: %.6f, "
              "val_avg_precision_step: %.6f ***"
              % (loss, val_f1_step, val_jaccard_step, val_fp_rate_step, val_avg_precision_step))

    # average over the batches that actually produced predictions
    n_div = float(max(n_scored_batches, 1))
    val_loss = val_loss / n_div
    val_f1 = float(val_f1_epoch) / n_div
    val_jaccard = val_jaccard_epoch / n_div
    val_fp_rate = val_fp_rate_epoch / n_div
    val_avg_precision = val_avg_precision_epoch / n_div
    print("*** SHREC 21, "
          "val_loss: %.6f, "
          "val_F1: %.6f, "
          "val_jaccard: %.6f, "
          "val_fp_rate: %.6f, "
          "val_avg_precision: %.6f ***"
          % (val_loss, val_f1, val_jaccard, val_fp_rate, val_avg_precision))




********** Testing **********
here

