In [1]:
T = 100
window_size = 64
thres = 64
cropDir = '/home/k100/Code/motionclassifiation_trainData_process/RGB_lastFrame'
dataPath = './testdata/NTU120_CV_RGB_2024_11_13_10_22.npz'
device = 'cuda:0'

In [2]:
import torch
import sys
import shutil
import inspect
from collections import OrderedDict
import yaml
import time
from feeders import tools
import cv2
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt


In [3]:
data = np.load(dataPath)
x_test = data['x_test']
y_test = data['y_test']
crops = data['x_test_crop'] 

In [4]:
model_args = {'num_classes': 120, 'num_people': 2, 'num_points': 17,
              'kernel_size': 7, 'num_heads': 32, 'attn_drop': 0.5,
              'head_drop': 0.0, 'rel': True, 'drop_path': 0.2, 
              'type_1_size': [8, 2], 'type_2_size': [8, 17],
              'type_3_size': [8, 2], 'type_4_size': [8, 17],
              'mlp_ratio': 4.0, 'index_t': True}

work_dir = './work_dir/ntu/cs/SkateFormerRGB_j/'
# ckptPath = './work_dir/ntu/cs/SkateFormerRGB_j/runs-last_model_Epoch47_acc79.pt'
ckptPath = './work_dir/ntu/cs/SkateFormerRGB_j/runs-last_model.pt'

modelType = 'model.SkateFormer.SkateFormerRGB_'

In [5]:
def cropPreprocess(crop):
    cropH,cropW,_ = crop.shape
    assert cropH==256 and cropW==128
    crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB).astype(float)/255
    return torch.from_numpy(np.transpose(crop, (2, 0, 1)))#.unsqueeze(0)

def import_class(import_str):
    mod_str, _sep, class_str = import_str.rpartition('.')
    __import__(mod_str)
    try:
        return getattr(sys.modules[mod_str], class_str)
    except AttributeError:
        raise ImportError('Class %s cannot be found (%s)' % (class_str, traceback.format_exception(*sys.exc_info())))

def load_model(work_dir,ckptPath,device,modelType='model.SkateFormer.SkateFormer_'):
    print('[modelType] : ', modelType)
    Model = import_class(modelType)
    shutil.copy2(inspect.getfile(shutil), work_dir)
    model = Model(**model_args)
    weights = torch.load(ckptPath)
    weights = OrderedDict([[k.split('module.')[-1], v.to(device)] for k, v in weights.items()])
    keys = list(weights.keys())
    model.load_state_dict(weights)
    model.to(device)
    return model

In [28]:
model = load_model(work_dir,ckptPath,device,modelType=modelType)

[modelType] :  model.SkateFormer.SkateFormerRGB_


In [29]:
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params}")

Total parameters: 3749551


In [33]:
BatchDatanum = 32
Epochs = int(np.ceil(x_test.shape[0]/BatchDatanum))
print(Epochs)
# Epochs = 100

649


In [34]:
prss = []
gtss = []
scoress = []
total_true,total_false = 0,0
for Epoch in tqdm(range(Epochs)):
    dataIdxes = np.arange(Epoch*BatchDatanum,(Epoch+1)*BatchDatanum,1)
    Inputdata = x_test.copy()[dataIdxes]
    Inputlabel = y_test.copy()[dataIdxes]
    cropsBatch = crops.copy()[dataIdxes]
    Inputcrops = torch.stack([cropPreprocess(cv2.imread(f'{cropDir}/{cropPath}')) for cropPath in cropsBatch],dim=0)

    inputdata = Inputdata.copy().reshape((BatchDatanum, T, 2, 17, 3)).transpose(0, 4, 1, 3, 2)
    valid_frame_nums = np.sum(np.sum(np.sum(inputdata, axis=-1), axis=-1) != 0, axis=-1)[...,0]
    p_interval = [0.95]

    datas = np.empty((len(inputdata),3,64,17,2))#,dtype=float
    t_indexs = np.empty((len(inputdata),64))#,dtype=float
    for i in range(len(inputdata)):
        data,t_index = tools.valid_crop_uniform(inputdata[i], valid_frame_nums[i], p_interval, window_size, thres)
        datas[i] = data
        t_indexs[i] = t_index
    datas = torch.from_numpy(datas).float().to(device)
    t_indexs = torch.from_numpy(t_indexs).float().to(device)
    Inputcrops = Inputcrops.float().to(device)

    with torch.no_grad():
        preds = model(datas,t_indexs,Inputcrops)
    preds = torch.nn.functional.softmax(preds, dim=0)
    scores = torch.max(preds,dim=1).values.cpu()
    prs = torch.argmax(preds,dim=1).cpu()
    gts = torch.argmax(torch.from_numpy(Inputlabel), dim=1)
    true,false = sum(gts==prs).item(),sum(gts!=prs).item()
    total_true += true
    total_false += false
    prss+=prs.tolist()
    gtss+=gts.tolist()
    scoress+= scores.tolist()

100%|████████████████████████████████████████████████████████████████████████████████████████████▊| 648/649 [06:20<00:00,  1.70it/s]


IndexError: index 20742 is out of bounds for axis 0 with size 20742

In [None]:
print(f'acc : {total_true/(total_false+total_true)*100} %')

In [None]:
ground_true = np.array(gtss)
predict = np.array(prss)
confidence = np.array(scoress)

# 定义不同的阈值范围
thresholds = np.linspace(0, 1, 50)  # 从0到1的50个不同阈值

# 存储准确率和数据数量
accuracies = []
data_counts = []
thresholds_have = []
for threshold in thresholds:
    mask = confidence >= threshold
    accuracy = np.sum(ground_true[mask] == predict[mask]) / np.sum(mask) if np.sum(mask) > 0 else 0
    data_count = np.sum(mask)  # 符合阈值的数据数量
    if data_count == 0:
        continue
    accuracies.append(accuracy)
    data_counts.append(data_count)
    thresholds_have.append(threshold)
# 创建图形
fig, ax1 = plt.subplots()

# 绘制左Y轴的准确率曲线
ax1.set_xlabel('Confidence Threshold')
ax1.set_ylabel('Accuracy')#, color='tab:blue'
ax1.plot(thresholds_have, accuracies, label='Accuracy')#, color='tab:blue'
ax1.tick_params(axis='y')#, labelcolor='tab:blue'

# 创建右Y轴的数据数量曲线
# ax2 = ax1.twinx()
# ax2.set_ylabel('Data Count', color='tab:red')
# ax2.plot(thresholds_have, data_counts, color='tab:red', label='Data Count')
# ax2.tick_params(axis='y', labelcolor='tab:red')

# 显示图例和图形
fig.tight_layout()
# plt.title('Validation Accuracy')
plt.show()