In [None]:
# Colab only: mount Google Drive so the checkpoint path used below resolves.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
!pip install torchinfo

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchinfo
  Downloading torchinfo-1.7.2-py3-none-any.whl (22 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.7.2


In [None]:
import os
import json
import random
import numpy as np
import pandas as pd
import pickle
import math

import torchinfo
from itertools import product
import torch 
from torch import nn 
from torch import optim
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm
import itertools
import random
import copy
import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
import cv2
import json
from sklearn.model_selection import train_test_split
from functools import partial
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor

%matplotlib inline
%config InlineBackend.figure_format='retina'

sns.set(style='whitegrid', palette='muted', font_scale=1.2)

HAPPY_COLORS_PALETTE = ["#01BEFE",
                        "#FFDD00",
                        "#FF7D00",
                        "#FF006D",
                        "#ADFF02",
                        "#8F00FF"]

sns.set_palette(sns.color_palette(HAPPY_COLORS_PALETTE))
rcParams['figure.figsize'] = 12, 8

In [None]:
import torch 
import torch.nn as nn
from torch.nn import functional as F
import numpy as np




class BasicBlock(nn.Module):
    """
    Residual block: two Conv2d layers plus a Conv2d + BatchNorm2d shortcut,
    with swish activations.

        forward(x) = swish( swish(conv2(swish(conv1(x)))) + shortcut(x) )
        shortcut   = BatchNorm2d(Conv2d(x))

    Note: bn1 and bn2 are constructed, so their parameters are part of the
    state_dict and checkpoints saved from this class still load. However,
    forward() does NOT apply them; only the shortcut branch is
    batch-normalised. Applying them now would change the output of
    already-trained weights.

    Args:
        in_channel: number of input channels
        out_channel: number of output channels
        k: (default = 1) kernel size of conv1 and of the shortcut conv.
            conv2 is always 1x1. Padding is 0 and stride 1, so a kernel
            larger than 1 shrinks the spatial dimensions by k - 1.
    """
    def __init__(self, in_channel, out_channel, k=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channel,
            out_channel,
            kernel_size=k,
            padding=(0, 0),
            stride=(1, 1))
        # Not applied in forward(); kept for state_dict compatibility.
        self.bn1 = nn.BatchNorm2d(out_channel)

        self.conv2 = nn.Conv2d(
            out_channel,
            out_channel,
            kernel_size=1,
            padding=(0, 0),
            stride=(1, 1))
        # Not applied in forward(); kept for state_dict compatibility.
        self.bn2 = nn.BatchNorm2d(out_channel)

        # The projection shortcut is always built (even when
        # in_channel == out_channel) so both branches see the same kernel k
        # and produce matching spatial sizes.
        self.shortcut = nn.Sequential()
        self.shortcut.add_module(
            'conv',
            nn.Conv2d(
                in_channel,
                out_channel,
                kernel_size=k,
                padding=(0,0),
                stride=(1,1)))
        self.shortcut.add_module('bn', nn.BatchNorm2d(out_channel))

    def swish(self,x):
        """
        Swish activation, x * sigmoid(x). The authors report it worked better
        than ReLU and LeakyReLU for spatio-temporal encoding/decoding.
        Args:
            x: tensor of any shape
        Return:
            tensor of the same shape with swish applied element-wise
        """
        return x*torch.sigmoid(x)

    def forward(self, x):
        y = self.swish(self.conv1(x))
        y = self.swish(self.conv2(y))
        y = y + self.shortcut(x)
        y = self.swish(y)
        return y


class BasicBlockTranspose(nn.Module):
    """
    Residual block built from ConvTranspose2d layers, with swish activations.

        forward(x) = swish( swish(bn2(conv2(swish(bn1(conv1(x)))))) + shortcut(x) )
        shortcut   = BatchNorm2d(ConvTranspose2d(x))

    conv1 and the shortcut use kernel k; conv2 is 1x1. Stride is 1 and
    padding 0, so the spatial dimensions grow by k - 1 (see get_h_out /
    get_w_out).

    Args:
        in_channel: number of input channels
        out_channel: number of output channels
        k: (default = (1, 1)) kernel size, either an int or an (h, w) pair;
            stored as a pair in self.k.
    """
    def __init__(self, in_channel, out_channel, k=(1,1)):
        super(BasicBlockTranspose, self).__init__()
        self.stride = (1, 1)
        self.padding = (0, 0)
        # ConvTranspose2d accepts an int kernel size, but get_h_out/get_w_out
        # index self.k, which raised TypeError for an int. Normalise to a pair.
        self.k = (k, k) if isinstance(k, int) else tuple(k)
        self.conv1 = nn.ConvTranspose2d(
            in_channel,
            out_channel,
            kernel_size=self.k,
            padding=self.padding,
            stride=self.stride)
        self.bn1 = nn.BatchNorm2d(out_channel)

        self.conv2 = nn.ConvTranspose2d(
            out_channel,
            out_channel,
            kernel_size=1,
            padding=self.padding,
            stride=self.stride)
        self.bn2 = nn.BatchNorm2d(out_channel)

        # Always a projection shortcut so its output size matches conv1's.
        self.shortcut = nn.Sequential()
        self.shortcut.add_module(
            'conv',
            nn.ConvTranspose2d(
                in_channel,
                out_channel,
                kernel_size=self.k,
                padding=self.padding,
                stride=self.stride))
        self.shortcut.add_module('bn', nn.BatchNorm2d(out_channel))

    def get_h_out(self,h_in):
        """Output height for input height h_in (ConvTranspose2d size formula, dilation 1, output_padding 0)."""
        return (h_in - 1)*self.stride[0]-2*self.padding[0]+(self.k[0]-1)+1

    def get_w_out(self,w_in):
        """Output width for input width w_in (ConvTranspose2d size formula, dilation 1, output_padding 0)."""
        return (w_in - 1)*self.stride[1]-2*self.padding[1]+(self.k[1]-1)+1

    def swish(self,x):
        """
        Swish activation, x * sigmoid(x). The authors report it worked better
        than ReLU and LeakyReLU for spatio-temporal encoding/decoding.
        Args:
            x: tensor of any shape
        Return:
            tensor of the same shape with swish applied element-wise
        """
        return x*torch.sigmoid(x)

    def forward(self, x):
        y = self.swish(self.bn1(self.conv1(x)))
        y = self.swish(self.bn2(self.conv2(y)))
        y = y + self.shortcut(x)
        y = self.swish(y)
        return y



class Self_Attn_Seq(nn.Module):
    """
    Multi-head causal self-attention over the time axis with a gated residual.

    Query, key and value are all projections of the same input. Position t
    can only attend to positions <= t (upper-triangular mask). The attended
    result is projected back to in_dim, scaled by the learnable scalar
    `gamma` and added to the input. Because gamma is initialised to 0, the
    layer starts out as the identity.

    Args:
        in_dim: feature size of the input (last axis)
        n_head: (default = 3) number of heads; each head has
            in_dim // n_head features

    Note: layer_norm is created (and is part of the state_dict) but unused.
    """

    def __init__(self,in_dim, n_head=3):
        super(Self_Attn_Seq,self).__init__()
        self.n_head = n_head
        self.hidden_size_attention = in_dim // n_head
        projected = n_head * self.hidden_size_attention

        self.w_q = nn.Linear(in_dim, projected)
        self.w_k = nn.Linear(in_dim, projected)
        self.w_v = nn.Linear(in_dim, projected)
        init_std = np.sqrt(2.0 / (in_dim + self.hidden_size_attention))
        for proj in (self.w_q, self.w_k, self.w_v):
            nn.init.normal_(proj.weight, mean=0, std=init_std)
        self.temperature = np.power(self.hidden_size_attention, 0.5)

        self.softmax = nn.Softmax(dim=2)
        self.linear2 = nn.Linear(projected, in_dim)
        self.layer_norm = nn.LayerNorm(in_dim)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, q):
        """
        Args:
            q: tensor (batch, steps, in_dim); also serves as key, value and residual
        Return:
            output: tensor (batch, steps, in_dim)
            attn_avg: tensor (batch, steps, steps), attention weights averaged over heads
        """
        heads, head_dim = self.n_head, self.hidden_size_attention
        batch, steps, _ = q.size()

        def to_heads(projection):
            # (batch, steps, heads*head_dim) -> (heads*batch, steps, head_dim), head-major
            return (projection(q)
                    .view(batch, steps, heads, head_dim)
                    .permute(2, 0, 1, 3)
                    .contiguous()
                    .view(-1, steps, head_dim))

        query = to_heads(self.w_q)
        key = to_heads(self.w_k)
        value = to_heads(self.w_v)

        # True strictly above the diagonal: future steps are hidden from each query.
        future = torch.triu(
            torch.ones((steps, steps), device=q.device, dtype=torch.uint8), diagonal=1)
        future = future.unsqueeze(0).expand(batch, -1, -1).gt(0).repeat(heads, 1, 1)

        scores = torch.bmm(query, key.transpose(1, 2)) / self.temperature
        weights = self.softmax(scores.masked_fill(future, -np.inf))

        context = torch.bmm(weights, value).view(heads, batch, steps, head_dim)
        context = context.permute(1, 2, 0, 3).contiguous().view(batch, steps, -1)
        output = self.gamma * self.linear2(context) + q

        attn_avg = weights.view(heads, batch, steps, steps).mean(0)
        return output, attn_avg





class DSAGGenerator(nn.Module):
    """
    DSAG decoder: maps a flat embedding back to a (batch, seq_len, input_size)
    sequence.

    The embedding is reshaped to (batch, temporal_decoder_filters[0],
    latent_dim_inner) and decoded in two stages. Each stage does:
    self-attention over time -> temporal BasicBlocks (time axis used as
    channels) -> a BasicBlockTranspose with kernel (3, 1) that lengthens the
    time axis by 2 -> a per-frame BasicBlock -> a Linear over the feature axis.

    Args:
        seq_len: number of output frames
        input_size: number of features per output frame
        embedding_size: length of the input embedding; must be divisible by
            temporal_decoder_filters[0], and the quotient by feat_size[0]
        temporal_decoder_filters: time-axis channel sizes of the temporal
            blocks; seq_len - 2 is appended internally. The caller's list
            is not modified.
        feat_size: width of the trailing feature axis in stage 1 / stage 2
        internal_attention: [hidden width of stage 2]; must be divisible by
            feat_size[1]
    """
    def __init__(self,
                 seq_len,
                 input_size,
                 embedding_size:int,
                 temporal_decoder_filters=[4,8,14,16],
                 feat_size = [2,4],
                 internal_attention = [168]
                 ):
        super(DSAGGenerator, self).__init__()
        self.embedding_size = embedding_size

        # Build a fresh list instead of .append()-ing to the argument. The
        # default list is shared by every call, so appending grew it on each
        # construction (a second instance got a stale seq_len - 2 at index 4)
        # and also mutated any list the caller passed in.
        self.temporal_decoder_filters = list(temporal_decoder_filters) + [seq_len - 2]

        # Feature width per time step after reshaping the embedding.
        self.latent_dim_inner = self.embedding_size//self.temporal_decoder_filters[0]

        self.input_size = input_size
        # Copied for the same reason: never alias the shared default lists.
        self.internal_attention = list(internal_attention)
        self.feat_sizes = list(feat_size)
        self.seq_len = seq_len

        # Transposed blocks: kernel (3, 1) lengthens the time axis by 2.
        self.decode_s1 = BasicBlockTranspose(self.latent_dim_inner//self.feat_sizes[0], self.latent_dim_inner//self.feat_sizes[0], k=(3,1))
        self.decode_s2 = BasicBlockTranspose(self.internal_attention[0]//self.feat_sizes[1], self.internal_attention[0]//self.feat_sizes[1], k=(3,1))

        # Per-frame (pose) blocks. conv3, conv4 and decode_t4 are not used by
        # forward() but are kept so existing checkpoints still load.
        self.conv1 = BasicBlock(1,1)
        self.conv2 = BasicBlock(1,1)
        self.conv3 = BasicBlock(1,1)
        self.conv4 = BasicBlock(1,1)
        # Temporal blocks: the time axis is the channel axis here.
        self.decode_t = BasicBlock(self.temporal_decoder_filters[0],self.temporal_decoder_filters[1])
        self.decode_t1 = BasicBlock(self.temporal_decoder_filters[1],self.temporal_decoder_filters[2])
        self.decode_t2 = BasicBlock(self.decode_s1.get_h_out(self.temporal_decoder_filters[2]),
                                    self.temporal_decoder_filters[3])
        self.decode_t3 = BasicBlock(self.temporal_decoder_filters[3],self.temporal_decoder_filters[4])
        self.decode_t4 = BasicBlock(self.decode_s2.get_h_out(self.temporal_decoder_filters[4]),
                                    self.seq_len)

        # Self-attention layers and feature-axis projections.
        self.decoder_attn1 = Self_Attn_Seq(self.latent_dim_inner)
        self.decoder_attn2 = Self_Attn_Seq(self.internal_attention[0])
        self.decoder = nn.Linear(self.latent_dim_inner, self.internal_attention[0])
        self.decoder1 = nn.Linear(self.internal_attention[0],self.input_size)

    def forward(self, X):
        """
        Decode a batch of embeddings into sequences.

        Args:
            X: tensor (batch_size, embedding_size); it is reshaped to
               (batch_size, temporal_decoder_filters[0], latent_dim_inner)
        Return:
            x: tensor (batch_size, seq_len, input_size)
        """
        N = X.shape[0]
        X = X.reshape((N,self.temporal_decoder_filters[0],-1))
        N,T,J = X.shape
        x, _ = self.decoder_attn1(X)

        # ---- stage 1: temporal decoding ----
        x = x.reshape((N,T,J//self.feat_sizes[0],self.feat_sizes[0]))
        x = self.decode_t(x)
        x = self.decode_t1(x)

        # Put time on the height axis so decode_s1 lengthens it by 2.
        x = x.transpose(2,1)
        x = self.decode_s1(x)
        x = x.transpose(2,1)

        # ---- stage 1: per-frame (pose) decoding ----
        t_stage1 = self.decode_s1.get_h_out(self.temporal_decoder_filters[2])
        x = x.reshape((N*t_stage1,1,J//self.feat_sizes[0],self.feat_sizes[0]))
        x = self.conv1(x)
        x = x.reshape((N,t_stage1, -1))

        x = self.decoder(x)
        x, _ = self.decoder_attn2(x)

        # ---- stage 2: temporal decoding ----
        N,T,J = x.shape
        x = x.reshape((N,T,J//self.feat_sizes[1], self.feat_sizes[1]))
        x = self.decode_t2(x)
        x = self.decode_t3(x)

        # decode_s2 grows time from seq_len - 2 to seq_len.
        x = x.transpose(2,1)
        x = self.decode_s2(x)
        x = x.transpose(2,1)

        # ---- stage 2: per-frame (pose) decoding ----
        x = x.reshape((N*self.seq_len,1,J//self.feat_sizes[1],self.feat_sizes[1]))
        x = self.conv2(x)
        x = x.reshape((N,self.seq_len, -1))
        x = self.decoder1(x)

        return x

class norm_data(nn.Module):
    """
    BatchNorm1d over the flattened (channel, joint) axis of a 4-D tensor.

    An input of shape (bs, c, num_joints, step) is viewed as
    (bs, c * num_joints, step), normalised with one BatchNorm1d channel per
    (c, joint) pair, and reshaped back. c * num_joints must equal dim * joints.

    Args:
        dim: (default = 3) channels per joint
        joints: (default = 20) number of joints
    """

    def __init__(self, dim=3, joints=20):
        super(norm_data, self).__init__()
        self.bn = nn.BatchNorm1d(dim * joints)

    def forward(self, x):
        batch, _, n_joints, n_steps = x.size()
        normed = self.bn(x.view(batch, -1, n_steps))
        return normed.view(batch, -1, n_joints, n_steps).contiguous()

class embed(nn.Module):
    """
    Two pointwise convolutions with ReLU, optionally preceded by norm_data.

    Layout of self.cnn:
        norm=True : [norm_data, cnn1x1(dim, 64), ReLU, cnn1x1(64, hidden_dim), ReLU]
        norm=False: [cnn1x1(dim, 64), ReLU, cnn1x1(64, hidden_dim), ReLU]

    Args:
        dim: (default = 3) input channels
        joint: (default = 20) joints, used only by norm_data
        hidden_dim: (default = 128) output channels
        norm: (default = True) prepend norm_data
        bias: (default = False) bias for the 1x1 convolutions
    """

    def __init__(self, dim=3, joint=20, hidden_dim=128, norm=True, bias=False):
        super(embed, self).__init__()
        # norm_data is built first and sits at index 0, so the construction
        # order and the Sequential indices (state_dict keys) match the
        # two-branch layout.
        layers = [norm_data(dim, joint)] if norm else []
        layers += [
            cnn1x1(dim, 64, bias=bias),
            nn.ReLU(),
            cnn1x1(64, hidden_dim, bias=bias),
            nn.ReLU(),
        ]
        self.cnn = nn.Sequential(*layers)

    def forward(self, x):
        return self.cnn(x)

class cnn1x1(nn.Module):
    """
    Pointwise (1x1) Conv2d wrapper, stored as self.cnn.

    Args:
        dim1: (default = 3) input channels
        dim2: (default = 3) output channels
        bias: (default = True) whether the convolution has a bias
    """

    def __init__(self, dim1 = 3, dim2 =3, bias = True):
        super(cnn1x1, self).__init__()
        self.cnn = nn.Conv2d(dim1, dim2, kernel_size=1, bias=bias)

    def forward(self, x):
        return self.cnn(x)

class local(nn.Module):
    """
    Frame-level module.

    Max-pools dim 2 to size 1 (in SGN this is the joint axis), then applies
    Conv2d (1x3 along dim 3, same padding) -> BN -> ReLU -> Dropout2d(0.2)
    -> Conv2d 1x1 -> BN -> ReLU.

    Args:
        dim1: (default = 3) input channels
        dim2: (default = 3) output channels
        bias: (default = False) bias for both convolutions
    """

    def __init__(self, dim1 = 3, dim2 = 3, bias = False):
        super(local, self).__init__()
        self.maxpool = nn.AdaptiveMaxPool2d((1, None))
        self.cnn1 = nn.Conv2d(dim1, dim1, kernel_size=(1, 3), padding=(0, 1), bias=bias)
        self.bn1 = nn.BatchNorm2d(dim1)
        self.relu = nn.ReLU()
        self.cnn2 = nn.Conv2d(dim1, dim2, kernel_size=1, bias=bias)
        self.bn2 = nn.BatchNorm2d(dim2)
        self.dropout = nn.Dropout2d(0.2)

    def forward(self, x1):
        pooled = self.maxpool(x1)
        hidden = self.dropout(self.relu(self.bn1(self.cnn1(pooled))))
        return self.relu(self.bn2(self.cnn2(hidden)))

class gcn_spa(nn.Module):
    """
    Graph convolution over the joint axis with a given adjacency.

    x1 is permuted so the adjacency g left-multiplies along dim 2 of x1
    (the joint axis in SGN). The result goes through the 1x1 conv `w`, is
    added to a 1x1-conv skip `w1(x1)`, and then passes through BN and ReLU.

    Args:
        in_feature: input channels
        out_feature: output channels
        bias: (default = False) bias of the skip conv w1 (w never has a bias)
    """

    def __init__(self, in_feature, out_feature, bias = False):
        super(gcn_spa, self).__init__()
        self.bn = nn.BatchNorm2d(out_feature)
        self.relu = nn.ReLU()
        self.w = cnn1x1(in_feature, out_feature, bias=False)
        self.w1 = cnn1x1(in_feature, out_feature, bias=bias)

    def forward(self, x1, g):
        # assumes x1 is (bs, C, joints, frames) and g is (bs, frames, joints, joints), as produced in SGN
        channels_last = x1.permute(0, 3, 2, 1).contiguous()
        aggregated = g.matmul(channels_last).permute(0, 3, 2, 1).contiguous()
        return self.relu(self.bn(self.w(aggregated) + self.w1(x1)))

class compute_g_spa(nn.Module):
    """
    Data-dependent joint adjacency.

    Two 1x1 convolutions embed x1. Their per-frame dot products over
    channels give a joints x joints score matrix, which is softmax-normalised
    along the last axis.

    Args:
        dim1: (default = 192) input channels
        dim2: (default = 192) embedding channels
        bias: (default = False) bias for both 1x1 convolutions
    """

    def __init__(self, dim1 = 64 *3, dim2 = 64*3, bias = False):
        super(compute_g_spa, self).__init__()
        self.dim1 = dim1
        self.dim2 = dim2
        self.g1 = cnn1x1(self.dim1, self.dim2, bias=bias)
        self.g2 = cnn1x1(self.dim1, self.dim2, bias=bias)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x1):
        # assumes x1 is (bs, dim1, joints, frames) -- TODO confirm for non-SGN callers
        rows = self.g1(x1).permute(0, 3, 2, 1).contiguous()  # (bs, frames, joints, dim2)
        cols = self.g2(x1).permute(0, 3, 1, 2).contiguous()  # (bs, frames, dim2, joints)
        return self.softmax(rows.matmul(cols))
    

class SGN(nn.Module):
    """
    SGN skeleton encoder: maps a batch of joint sequences to flat embeddings.

    Pipeline (see forward):
    - Joint positions and frame-to-frame differences are embedded and summed.
    - The sum is concatenated with an embedding of one-hot joint indices.
    - Three graph convolutions run, sharing one adjacency computed from the
      features.
    - An embedding of one-hot frame indices is added.
    - The frame-level `local` module is applied, and the result is pooled
      and flattened.

    Args:
        num_joint: joints per frame
        seg: frames per sequence
        hidden_size: (default = 128) joint-level feature width (self.dim1)
        bs: (default = 32) batch size. forward() reshapes its input to
            exactly `bs` samples, and the one-hot index tensors are pre-built
            for that batch size.
        is_3d: (default = True) 3 coordinates per joint if True, else 2
        train: accepted for backward compatibility only; both modes now size
            the index tensors from `bs`
        bias: (default = True) bias for the 1x1 convolutions
        device: (default = 'cpu') device on which the one-hot index tensors
            are created
    """
    def __init__(self, num_joint, seg, hidden_size=128, bs=32, is_3d=True, train=True, bias=True, device='cpu'):
        super(SGN, self).__init__()

        self.dim1 = hidden_size
        self.dim_unit = hidden_size // 4
        self.seg = seg
        self.num_joint = num_joint
        self.bs = bs
        self.spatial_dim = 3 if is_3d else 2

        # train=False used to build these for a hard-coded batch of 32 * 5.
        # forward() always uses self.bs, so any train=False model with
        # bs != 160 failed in torch.cat. Both modes now use bs.
        # NOTE(review): these are plain attributes, not buffers, so
        # model.to(device) does not move them -- pass `device` explicitly.
        self.spa = self.one_hot(bs, num_joint, self.seg).permute(0, 3, 2, 1).to(device)  # (bs, J, J, seg)
        self.tem = self.one_hot(bs, self.seg, num_joint).permute(0, 3, 1, 2).to(device)  # (bs, seg, J, seg)

        self.tem_embed = embed(self.seg, joint=self.num_joint, hidden_dim=self.dim_unit*4, norm=False, bias=bias)
        self.spa_embed = embed(num_joint, joint=self.num_joint, hidden_dim=self.dim_unit, norm=False, bias=bias)
        self.joint_embed = embed(self.spatial_dim, joint=self.num_joint, hidden_dim=self.dim_unit, norm=True, bias=bias)
        self.dif_embed = embed(self.spatial_dim, joint=self.num_joint, hidden_dim=self.dim_unit, norm=True, bias=bias)
        self.maxpool = nn.AdaptiveMaxPool2d([1, 1])
        self.cnn = local(self.dim1, self.dim1 * 2, bias=bias)
        self.compute_g1 = compute_g_spa(self.dim1 // 2, self.dim1, bias=bias)
        self.gcn1 = gcn_spa(self.dim1 // 2, self.dim1 // 2, bias=bias)
        self.gcn2 = gcn_spa(self.dim1 // 2, self.dim1, bias=bias)
        self.gcn3 = gcn_spa(self.dim1, self.dim1, bias=bias)

        self.embed_maxpool = nn.AdaptiveMaxPool2d([self.dim1, 2])

        # He-style normal init for every Conv2d.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))

        # Zero the adjacency branch so each GCN layer starts from its skip conv w1.
        nn.init.constant_(self.gcn1.w.cnn.weight, 0)
        nn.init.constant_(self.gcn2.w.cnn.weight, 0)
        nn.init.constant_(self.gcn3.w.cnn.weight, 0)

    def forward(self, input):
        """
        Args:
            input: tensor holding bs * seg * num_joint * spatial_dim elements,
                e.g. (bs, seg, num_joint * spatial_dim)
        Return:
            tensor (bs, 2 * dim1) embedding
        """
        # Dynamic representation: positions plus frame-to-frame differences.
        input = input.view((self.bs, self.seg, self.num_joint, self.spatial_dim))
        input = input.permute(0, 3, 2, 1).contiguous()  # (bs, spatial_dim, num_joint, seg)
        dif = input[:, :, :, 1:] - input[:, :, :, 0:-1]
        # Zero difference for the first frame keeps `seg` frames.
        dif = torch.cat([dif.new_zeros(self.bs, dif.size(1), self.num_joint, 1), dif], dim=-1)
        pos = self.joint_embed(input)
        tem1 = self.tem_embed(self.tem)
        spa1 = self.spa_embed(self.spa)
        dif = self.dif_embed(dif)
        dy = pos + dif
        # Joint-level module
        input = torch.cat([dy, spa1], 1)
        g = self.compute_g1(input)
        input = self.gcn1(input, g)
        input = self.gcn2(input, g)
        input = self.gcn3(input, g)
        # Frame-level module
        input = input + tem1
        input = self.cnn(input)  # (bs, 2 * dim1, 1, seg)
        # Squeeze only the pooled joint axis. A bare squeeze() also removed a
        # batch dimension of 1, which made AdaptiveMaxPool2d reject the 2-D result.
        output_feat = input.squeeze(2)
        # On this 3-D tensor AdaptiveMaxPool2d treats dim 0 as channels,
        # pooling (2 * dim1, seg) down to (dim1, 2) per sample.
        output_feat = self.embed_maxpool(output_feat)
        output_feat = torch.flatten(output_feat, 1)

        return output_feat

    def one_hot(self, bs, spa, tem):
        """
        Return a float32 tensor of shape (bs, tem, spa, spa) whose last two
        axes hold the spa x spa identity, i.e. one one-hot row per index.
        """
        y_onehot = torch.eye(spa, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
        return y_onehot.repeat(bs, tem, 1, 1)

class SGNClassifier(nn.Module):
    """
    Linear classification head: embedding -> class logits.

    Args:
        num_classes: number of output logits
        embedding_size: length of the input embedding
        *args, **kwargs: forwarded to nn.Module.__init__
    """

    def __init__(self,num_classes,embedding_size, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.num_classes = num_classes
        self.embedding_size = embedding_size
        self.fc = nn.Linear(embedding_size, num_classes)

    def forward(self, input):
        return self.fc(input)
    


class EncDecModel(nn.Module):
    """
    Wraps an encoder, a decoder and a classifier that share one embedding.

    Args:
        encoder: module mapping the input to an embedding
        decoder: module reconstructing from the embedding
        classifier: module mapping the embedding to logits
    Forward returns:
        (decoder_out, embedding, classifier_out)
    """

    def __init__(self,encoder,decoder,classifier):
        super(EncDecModel, self).__init__()
        # Registration order (encoder, decoder, classifier) fixes state_dict key order.
        self.encoder, self.decoder, self.classifier = encoder, decoder, classifier

    def forward(self,x):
        embedding = self.encoder(x)
        # The classifier runs before the decoder; hooks and summaries see this order.
        logits = self.classifier(embedding)
        reconstruction = self.decoder(embedding)
        return reconstruction, embedding, logits
        

In [None]:
# SGN reshapes every input to exactly this many samples, so every encoder
# call in this notebook must use this batch size.
batch_size=128

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Trained SGN + DSAG + classifier checkpoint on the mounted Google Drive.
CHECKPOINT_PATH = "/content/drive/MyDrive/22_FYP42 - Zero-shot Explainable HAR/Devin/SkeletonAE/model_saves/temp_NTURGB120_skeleton_SGN_DSAG_classifier_1024_emb1d/20__epoch50_emb1024_xy.pt"


def get_config(file_loc,device):
    """
    Load a training checkpoint and split it into its three parts.

    Args:
        file_loc: path to a file saved with torch.save
        device: passed to torch.load as map_location
    Returns:
        (model_state_dict, model_config, config): the values stored under the
        "model_state_dict", "model_config" and "config" keys
    """
    # torch.load unpickles the file -- only load checkpoints you trust.
    checkpoint = torch.load(file_loc,map_location=device)
    return checkpoint["model_state_dict"], checkpoint["model_config"], checkpoint["config"]


model_params, model_config, config = get_config(CHECKPOINT_PATH, device)

In [None]:
# Rebuild the architecture from the checkpoint's config so the saved
# state_dict can be loaded into it.
model_cfg = config["model"]

# Output size of the classifier head; must match the checkpoint's fc layer.
NUM_CLASSES = 82

# SGN reshapes every input to exactly `bs` samples, so this encoder only
# accepts batches of `batch_size`.
encoder = SGN(
    num_joint=model_cfg["num_joint"],
    seg=model_cfg["seq_len"],
    hidden_size=model_cfg["encoder_hidden_size"],
    bs=batch_size,
    is_3d=model_cfg["is_3d"],
    device=device,
    train=True,
).to(device)

classifier = SGNClassifier(
    num_classes=NUM_CLASSES,
    embedding_size=model_cfg["embedding_size"],
).to(device)

decoder = DSAGGenerator(
    seq_len=model_cfg["seq_len"],
    input_size=model_cfg["input_size"],
    embedding_size=model_cfg["embedding_size"],
).to(device)

# NOTE: despite the name, this wraps the SGN encoder, DSAG decoder and linear
# classifier (no LSTM). The name is kept because later cells refer to it.
bilstm_model = EncDecModel(
    encoder=encoder,
    decoder=decoder,
    classifier=classifier,
).to(device)

In [None]:
load_result = bilstm_model.load_state_dict(model_params)
# Inference only from here on. eval() makes BatchNorm use the loaded running
# statistics instead of overwriting them with stats from the random probe
# batches fed through the model, and it disables Dropout2d.
bilstm_model.eval()
load_result

<All keys matched successfully>

In [None]:
# Handles to the loaded sub-networks (same objects, not copies).
encoder, decoder = bilstm_model.encoder, bilstm_model.decoder

In [None]:
# Shape check: a random batch through the encoder. Dims come from the config
# rather than literals, and no autograd graph is built.
with torch.no_grad():
    probe_embedding = encoder(
        torch.rand(batch_size, config["model"]["seq_len"], config["model"]["input_size"], device=device)
    )
probe_embedding.size()

torch.Size([128, 1024])

In [None]:
# Shape check: random embeddings through the decoder. Dims come from the
# config rather than literals, and no autograd graph is built.
with torch.no_grad():
    probe_reconstruction = decoder(
        torch.rand(batch_size, config["model"]["embedding_size"], device=device)
    )
probe_reconstruction.size()

torch.Size([128, 60, 24])

In [None]:
torchinfo.summary(
    bilstm_model,
    input_size=(batch_size, config["model"]["seq_len"], config["model"]["input_size"]),
    col_names=("input_size", "output_size", "num_params", "kernel_size", "mult_adds"),
)

  action_fn=lambda data: sys.getsizeof(data.storage()),
  return super().__sizeof__() + self.nbytes()


Layer (type:depth-idx)                        Input Shape               Output Shape              Param #                   Kernel Shape              Mult-Adds
EncDecModel                                   [128, 60, 24]             [128, 60, 24]             --                        --                        --
├─SGN: 1-1                                    [128, 60, 24]             [128, 1024]               --                        --                        --
│    └─embed: 2-1                             [128, 2, 12, 60]          [128, 128, 12, 60]        --                        --                        --
│    │    └─Sequential: 3-1                   [128, 2, 12, 60]          [128, 128, 12, 60]        8,560                     --                        784,472,064
│    └─embed: 2-2                             [128, 60, 12, 60]         [128, 512, 12, 60]        --                        --                        --
│    │    └─Sequential: 3-2                   [128, 60, 12, 60]   