In [None]:
import torch
import json
import os
import numpy as np
import yaml
from collections import OrderedDict
from st_gcn.net.st_gcn import Model as Model

def load_skeleton_data(json_path, num_person_in=1, num_person_out=1,
                       max_frame=150, num_joint=18):
    """Load one skeleton-sequence JSON file into a fixed-size array.

    The file must contain a top-level ``"data"`` list of frames. Each frame
    has a ``"frame_index"`` and a ``"skeleton"`` list. Each skeleton entry
    has a flat ``"pose"`` list of coordinate pairs and a per-joint
    ``"score"`` list.

    Args:
        json_path: Path to the JSON file.
        num_person_in: Maximum number of skeletons read from each frame.
        num_person_out: Number of person slots in the output array; skeletons
            beyond this many are ignored.
        max_frame: Number of frame slots in the output. Frames whose
            ``frame_index`` is ``>= max_frame`` are skipped.
        num_joint: Joints per skeleton; every pose must contain exactly this
            many coordinate pairs.

    Returns:
        float64 ndarray of shape ``(3, max_frame, num_joint, num_person_out)``.
        Channels 0 and 1 hold the two pose coordinates shifted by -0.5.
        Channel 2 holds the per-joint score. Coordinates are set to 0
        wherever the score is 0, which includes never-filled frame/person
        slots.
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    skeleton_data = np.zeros((3, max_frame, num_joint, num_person_out))

    for frame in data['data']:
        frame_index = frame['frame_index']
        # Skip instead of break: frames are not guaranteed to be ordered by
        # frame_index, so a later entry may still fall inside the window.
        if frame_index >= max_frame:
            continue

        for m, skeleton in enumerate(frame['skeleton'][:num_person_in]):
            if m >= num_person_out:
                break
            pose = np.array(skeleton['pose']).reshape(-1, 2)
            score = np.array(skeleton['score'])

            skeleton_data[0, frame_index, :, m] = pose[:, 0]
            skeleton_data[1, frame_index, :, m] = pose[:, 1]
            skeleton_data[2, frame_index, :, m] = score

    # Shift coordinates by -0.5 (presumably centring [0, 1]-normalised
    # coordinates - TODO confirm against the data producer).
    skeleton_data[0:2] = skeleton_data[0:2] - 0.5
    # Joints with zero score (undetected, or empty slots) get zero coordinates.
    undetected = skeleton_data[2] == 0
    skeleton_data[0][undetected] = 0
    skeleton_data[1][undetected] = 0

    return skeleton_data

def predict_action(model, skeleton_data):
    """Run ``model`` on a single sample and return its class probabilities.

    Args:
        model: A torch module. Its output is assumed to be per-class scores
            of shape ``(batch, num_class)``, since the softmax is taken over
            dim 1.
        skeleton_data: Array-like holding one sample without a batch axis;
            a batch axis of size 1 is prepended here.

    Returns:
        1-D numpy array of softmax probabilities for the sample.
    """
    # Use the model's device. Fall back to CPU for a parameterless model,
    # where next() on an empty iterator would raise StopIteration.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # as_tensor converts and places on the device in one step, replacing
    # the legacy torch.FloatTensor constructor plus a second copy.
    data = torch.as_tensor(skeleton_data, dtype=torch.float32, device=device).unsqueeze(0)

    with torch.no_grad():
        output = model(data)
        probabilities = torch.nn.functional.softmax(output, dim=1)
    return probabilities.cpu().numpy()[0]

config_path = "../config/st_gcn/kinetics-skeleton-from-rawdata/test.yaml"

# Test-time configuration: model hyper-parameters, device, ...
with open(config_path, 'r', encoding='utf-8') as f:
    arg = yaml.safe_load(f)

model_args = arg['model_args']
model = Model(**model_args)

weights_path = "../model/epoch180_model.pt"
# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
# machine; the model is moved to the selected device further below.
weights = torch.load(weights_path, map_location="cpu")

# Keys saved through nn.DataParallel start with "module.". Strip only that
# leading prefix; str.replace would also rewrite any "module." occurring
# inside a key.
new_weights = OrderedDict()
for k, v in weights.items():
    name = k[len("module."):] if k.startswith("module.") else k
    new_weights[name] = v

# strict=False tolerates key mismatches. Report them, because parameters
# absent from the checkpoint keep their initial values and silently make
# the predictions meaningless.
model_keys = set(model.state_dict().keys())
loaded_keys = set(new_weights.keys())
missing_keys = sorted(model_keys - loaded_keys)
unexpected_keys = sorted(loaded_keys - model_keys)
if missing_keys:
    print(f"Warning: parameters missing from checkpoint: {missing_keys}")
if unexpected_keys:
    print(f"Warning: checkpoint entries not used by the model: {unexpected_keys}")

model.load_state_dict(new_weights, strict=False)
model.eval()

# The config's `device` may be a single GPU index or a list of indices;
# a list would produce an invalid "cuda:[...]" string, so use its first entry.
device_cfg = arg.get('device', 0)
if isinstance(device_cfg, (list, tuple)):
    device_cfg = device_cfg[0] if device_cfg else 0
device = torch.device(f"cuda:{device_cfg}" if torch.cuda.is_available() else "cpu")
model = model.to(device)

class_labels = ['walk', 'stand', 'sit', 'armExe', 'lieDown', 'fall']

json_dir = "../data/test"
num_person = model_args.get('num_person', 1)

# sorted(): os.listdir order is arbitrary; keep the output reproducible.
for json_file in sorted(os.listdir(json_dir)):
    if not json_file.endswith('.json'):
        continue
    json_path = os.path.join(json_dir, json_file)

    skeleton_data = load_skeleton_data(json_path, num_person_in=num_person, num_person_out=num_person)

    prediction = predict_action(model, skeleton_data)
    predicted_class = np.argmax(prediction)

    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    # Ground-truth class index is read from the 'label_index' field (-1 if absent).
    actual_label_index = data.get('label_index', -1)
    actual_label = class_labels[actual_label_index] if 0 <= actual_label_index < len(class_labels) else 'Unknown'
    # Guard against a model whose number of classes exceeds class_labels.
    predicted_label = class_labels[predicted_class] if predicted_class < len(class_labels) else 'Unknown'

    print(f"File: {json_file}")
    print(f"Actual label: {actual_label} (index: {actual_label_index})")
    print(f"Prediction probabilities: {prediction}")
    print(f"Predicted class index: {predicted_class}")
    print(f"Predicted action: {predicted_label}")
    print("--------------------")

print("All predictions completed.")