In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import numpy as np
import pickle
from model import SGN
from data import NTUDataLoaders, AverageMeter
from sklearn.metrics import f1_score, precision_score, recall_score
import h5py
from types import SimpleNamespace
from captum.attr import IntegratedGradients
import random

In [None]:
batch_size = 32

with open('data/ntu120/X_full.pkl', 'rb') as f:
    X = pickle.load(f)

# Remove actions above 120
to_rem = []
for file in X.keys():
    if int(str(file).split('A')[1][:3]) > 60:
        to_rem.append(file)
for file in to_rem:
    del X[file]

def reshape_skeleton(skeleton):
    """
    Reshape the skeleton data from shape [300, 150] to [20, 75].
    """
    skeleton = skeleton[:20, :75]
    return skeleton


args = SimpleNamespace(batch_size=batch_size, train=0)

sgn_ar = SGN(60, None, 20, args, 0).cuda()
sgn_priv = SGN(40, None, 20, args, 0).cuda()
sgn_ar.load_state_dict(torch.load('results/NTUar/SGN/1_best.pth')['state_dict'], strict=False)
sgn_priv.load_state_dict(torch.load('results/NTUri/SGN/1_best.pth')['state_dict'], strict=False)
sgn_ar.eval()
sgn_priv.eval()


In [13]:
import time
instance_limiter = 1000  # Run on 10 samples

# Sample a subset of instances for testing
sampled_instances = random.sample(list(X.keys()), instance_limiter)

# Storage for joint-wise importance scores
joint_importances_ar = {joint: [] for joint in range(25)}  # For Action Recognition
joint_importances_ri = {joint: [] for joint in range(25)}  # For Re-identification

# Storage for times
times_ar = []
times_ri = []

# Define models for Captum
def model_ar(input_tensor):
    output = sgn_ar.eval_single(input_tensor)
    return output

def model_ri(input_tensor):
    output = sgn_priv.eval_single(input_tensor)
    return output

# Initialize Integrated Gradients
ig_ar = IntegratedGradients(model_ar)
ig_ri = IntegratedGradients(model_ri)

# Iterate through the data and collect explanations
for file_name in sampled_instances:
    print('Gathering explanation for: ', file_name)
    A = int(file_name[19:22]) - 1
    P = int(file_name[9:12]) - 1
    skeleton = X[file_name]

    reshaped_skeleton = reshape_skeleton(skeleton)
    input_tensor = torch.tensor(reshaped_skeleton, dtype=torch.float32).unsqueeze(0).cuda()  # Shape: (1, 20, 75)

    # Compute attributions for Action Recognition
    start_time_ar = time.time()
    attributions_ar = ig_ar.attribute(input_tensor, target=A)
    end_time_ar = time.time()
    times_ar.append(end_time_ar - start_time_ar)

    # Compute attributions for Re-identification
    start_time_ri = time.time()
    attributions_ri = ig_ri.attribute(input_tensor, target=P)
    end_time_ri = time.time()
    times_ri.append(end_time_ri - start_time_ri)

    # Process attributions to collect joint-wise importance scores
    # attributions_ar and attributions_ri have shape (1, 20, 75)

    # For each joint
    for joint in range(25):
        joint_indices = [joint * 3 + c for c in range(3)]  # Indices for x, y, z of the joint
        # Get attributions for these indices across all frames
        joint_attributions_ar = attributions_ar[0, :, joint_indices]  # Shape: (20, 3)
        joint_attributions_ri = attributions_ri[0, :, joint_indices]
        # Sum over frames and coords
        joint_importance_ar = joint_attributions_ar.abs().sum().item()
        joint_importance_ri = joint_attributions_ri.abs().sum().item()
        # Append to the lists
        joint_importances_ar[joint].append(joint_importance_ar)
        joint_importances_ri[joint].append(joint_importance_ri)

# Average the importance scores across sequences for each joint
average_importances_ar = {joint: np.mean(importance) for joint, importance in joint_importances_ar.items()}
average_importances_ri = {joint: np.mean(importance) for joint, importance in joint_importances_ri.items()}

# Display the average importances
print('Average importances for Action Recognition:')
for joint, importance in average_importances_ar.items():
    print('Joint {}: {}'.format(joint, importance))

print('Average importances for Re-identification:')
for joint, importance in average_importances_ri.items():
    print('Joint {}: {}'.format(joint, importance))

# Calculate average times
average_time_ar = np.mean(times_ar)
average_time_ri = np.mean(times_ri)
total_average_time = (sum(times_ar) + sum(times_ri)) / (2 * len(times_ar))

print('Average explanation time for Action Recognition: {:.4f} seconds'.format(average_time_ar))
print('Average explanation time for Re-identification: {:.4f} seconds'.format(average_time_ri))
print('Total average explanation time per model: {:.4f} seconds'.format(total_average_time))
print('Total run time: {:.4f} seconds'.format(sum(times_ar) + sum(times_ri)))

Gathering explanation for:  b'S007C003P017R001A046'
Gathering explanation for:  b'S014C003P019R001A015'
Gathering explanation for:  b'S011C002P028R002A019'
Gathering explanation for:  b'S015C002P037R001A027'
Gathering explanation for:  b'S008C003P029R001A011'
Gathering explanation for:  b'S006C003P015R001A050'
Gathering explanation for:  b'S007C002P026R001A058'
Gathering explanation for:  b'S005C002P016R002A024'
Gathering explanation for:  b'S010C001P015R001A041'
Gathering explanation for:  b'S002C002P013R001A009'
Gathering explanation for:  b'S017C001P008R001A057'
Gathering explanation for:  b'S008C003P031R001A051'
Gathering explanation for:  b'S002C003P010R001A040'
Gathering explanation for:  b'S008C002P036R001A008'
Gathering explanation for:  b'S011C001P002R002A060'
Gathering explanation for:  b'S005C001P018R001A051'
Gathering explanation for:  b'S004C003P020R002A031'
Gathering explanation for:  b'S013C003P027R001A047'
Gathering explanation for:  b'S008C002P007R001A003'
Gathering ex