In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import numpy as np
import pickle
from model import SGN
from data import NTUDataLoaders, AverageMeter
from sklearn.metrics import f1_score, precision_score, recall_score
import h5py
from types import SimpleNamespace

In [10]:
batch_size = 32

with open('ntu/SGN/X_full.pkl', 'rb') as f:
    X = pickle.load(f)

def reshape_skeleton(skeleton):
    """
    Reshape the skeleton data from shape [300, 150] to [20, 75].
    """
    skeleton = skeleton[:20, :75]
    return skeleton

def predict_sgn(model, skeleton):
    skeleton = torch.tensor(skeleton).cuda()
    model.cuda()
    out = model.eval_single(skeleton)
    out = out.view((-1, skeleton.size(0)//skeleton.size(0), out.size(1)))
    out = out.mean(1)
    out = out.cpu().detach().numpy()
    out = np.argmax(out, axis=1)
    return out


args = SimpleNamespace(batch_size=batch_size, train=0)

sgn_ar = SGN(60, None, 20, args, 0).cuda()
sgn_priv = SGN(40, None, 20, args, 0).cuda()
sgn_ar.load_state_dict(torch.load('pretrained/action_60_sgnpt.pt')['state_dict'], strict=False)
sgn_priv.load_state_dict(torch.load('pretrained/privacy_60_sgnpt.pt')['state_dict'], strict=False)

  sgn_ar.load_state_dict(torch.load('pretrained/action_60_sgnpt.pt')['state_dict'], strict=False)
  sgn_priv.load_state_dict(torch.load('pretrained/privacy_60_sgnpt.pt')['state_dict'], strict=False)


_IncompatibleKeys(missing_keys=[], unexpected_keys=['tem_embed.cnn.0.cnn.bias', 'tem_embed.cnn.2.cnn.bias', 'spa_embed.cnn.0.cnn.bias', 'spa_embed.cnn.2.cnn.bias', 'joint_embed.cnn.1.cnn.bias', 'joint_embed.cnn.3.cnn.bias', 'dif_embed.cnn.1.cnn.bias', 'dif_embed.cnn.3.cnn.bias', 'cnn.cnn1.bias', 'cnn.cnn2.bias', 'compute_g1.g1.cnn.bias', 'compute_g1.g2.cnn.bias', 'gcn1.w1.cnn.bias', 'gcn2.w1.cnn.bias', 'gcn3.w1.cnn.bias'])

In [11]:
import lime
import lime.lime_tabular
import numpy as np

feature_names = ["frame_{}_joint_{}_coord_{}".format(frame, joint, coord)
                   for frame in range(20) for joint in range(25) for coord in ['x', 'y', 'z']]

# Initialize LIME explainer
ar_explainer = lime.lime_tabular.LimeTabularExplainer(
    training_data=np.array([reshape_skeleton(v).flatten() for v in X.values()]),  # Flattened training data
    feature_names=feature_names,
    mode='classification',
    class_names=['Action_{}'.format(i) for i in range(60)],  # Replace with actual class names if available
    training_labels=np.array([int(k[19:22]) - 1 for k in X.keys()]),  # Assuming action label is embedded in file name
)

ri_explainer = lime.lime_tabular.LimeTabularExplainer(
    training_data=np.array([reshape_skeleton(v).flatten() for v in X.values()]),  # Flattened training data
    feature_names=feature_names,
    mode='classification',
    class_names=['Action_{}'.format(i) for i in range(40)],  # Replace with actual class names if available
    training_labels=np.array([int(k[9:12]) - 1 for k in X.keys()]),  # Assuming action label is embedded in file name
)

AttributeError: 'numpy.ndarray' object has no attribute 'values'

In [None]:
import random
instance_limiter = 1

# Sample a subset of instances for testing
sampled_instances = random.sample(list(X.keys()), instance_limiter)

# Storage for joint-wise importance scores
joint_importances_ar = {joint: [] for joint in range(25)}  # For Action Recognition
joint_importances_ri = {joint: [] for joint in range(25)}  # For Re-identification

# Prediction functions for both models
def predict_fn_ar(input_skeleton):
    input_skeleton = input_skeleton.reshape(-1, 20, 75)
    input_tensor = torch.tensor(input_skeleton, dtype=torch.float32).cuda()
    with torch.no_grad():
        output = sgn_ar.eval_single(input_tensor)
        output = torch.softmax(output, 1).cpu().numpy()
    return output

def predict_fn_ri(input_skeleton):
    input_skeleton = input_skeleton.reshape(-1, 20, 75)
    input_tensor = torch.tensor(input_skeleton, dtype=torch.float32).cuda()
    with torch.no_grad():
        output = sgn_priv.eval_single(input_tensor)
        output = torch.softmax(output, 1).cpu().numpy()
    return output

# Iterate through the data and collect explanations
for file_name in sampled_instances:
    print('Gathering explanation for: ', file_name)
    A = int(file_name[19:22]) - 1
    P = int(file_name[9:12]) - 1
    skeleton = X[file_name]

    reshaped_skeleton = reshape_skeleton(skeleton)
    flattened_skeleton = reshaped_skeleton.flatten()

    # Get LIME explanation for Action Recognition
    explanation_ar = ar_explainer.explain_instance(flattened_skeleton, predict_fn_ar, num_features=20*25*3, labels=[A])

    # Get LIME explanation for Re-identification
    explanation_ri = ri_explainer.explain_instance(flattened_skeleton, predict_fn_ri, num_features=20*25*3, labels=[P])

    # Collect importance for each joint from both models
    for feature_index, importance_value in explanation_ar.local_exp[A]:  # Assuming class 0
        feature_name = ar_explainer.feature_names[feature_index]
        # Extract joint number from feature name (e.g., "frame_0_joint_7_coord_y")
        joint_num = int(feature_name.split('_')[3])
        joint_importances_ar[joint_num].append(importance_value)

    for feature_index, importance_value in explanation_ri.local_exp[P]:  # Assuming class 0
        feature_name = ri_explainer.feature_names[feature_index]
        joint_num = int(feature_name.split('_')[3])
        joint_importances_ri[joint_num].append(importance_value)

# Average the importance scores across sequences for each joint
average_importances_ar = {joint: np.mean(importance) for joint, importance in joint_importances_ar.items()}
average_importances_ri = {joint: np.mean(importance) for joint, importance in joint_importances_ri.items()}

# Compare AR vs. RI to find joints with the specific patterns
positive_ar_negative_ri = []
negative_ar_positive_ri = []

for joint in range(25):
    if average_importances_ar[joint] > 0 and average_importances_ri[joint] < 0:
        positive_ar_negative_ri.append(joint)
    if average_importances_ar[joint] < 0 and average_importances_ri[joint] > 0:
        negative_ar_positive_ri.append(joint)

# Output the results
print('Below are zero indexed (so add 1 when looking at the joint number):')
print("Joints with positive importance for AR but negative for RI:", positive_ar_negative_ri)
print("Joints with negative importance for AR but positive for RI:", negative_ar_positive_ri)