In [1]:
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
'''
File: /workspace/skeleton/misc/result_analysis.py
Project: /workspace/skeleton/misc
Created Date: Monday December 18th 2023
Author: Kaixu Chen
-----
Comment:

Have a good code time :)
-----
Last Modified: Monday December 18th 2023 5:39:35 am
Modified By: the developer formerly known as Kaixu Chen at <chenkaixusan@gmail.com>
-----
Copyright (c) 2023 The University of Tsukuba
-----
HISTORY:
Date      	By	Comments
----------	---	---------------------------------------------------------
'''
import os, hydra, warnings
warnings.filterwarnings("ignore")

import logging, time, sys, json, yaml, csv, shutil, copy
sys.path.append("/workspace/skeleton/")
sys.path.append("/workspace/skeleton/project")
from pathlib import Path
import torch 

import torchmetrics
from torchmetrics.classification import (
    MulticlassAccuracy,
    MulticlassPrecision,
    MulticlassRecall,
    MulticlassF1Score,
    MulticlassConfusionMatrix,
    MulticlassAUROC,
)

from project.dataloader.data_loader import WalkDataModule
from project.train import GaitCycleLightningModule

from pytorch_lightning import Trainer, seed_everything

def get_inference(test_data, model, device='cuda:0'):
    '''Run ``model`` over every batch of ``test_data`` and collect softmax predictions.

    Args:
        test_data: iterable of batches. Each batch is a mapping with a ``'video'``
            tensor (expected layout b, c, t, h, w) and a ``'label'`` tensor.
            NOTE(review): the original comment described labels as (b, class_num);
            they may equally be (b,) class indices -- confirm against the datamodule.
        model: torch module that is already placed on ``device``.
        device: device the video inputs are moved to. Defaults to ``'cuda:0'``
            (the previously hard-coded value), so existing callers are unaffected.

    Returns:
        (pred, label): ``pred`` holds per-sample softmax probabilities over dim 1.
        ``label`` holds the matching labels. Both are concatenated over all
        batches along the first (sample) dimension and rebuilt with
        ``torch.tensor`` on the CPU.
    '''
    # Inference mode only needs to be set once, not on every batch.
    model.eval()

    total_pred_list = []
    total_label_list = []

    with torch.no_grad():
        for batch in test_data:
            video = batch['video'].detach().to(device)  # b, c, t, h, w

            preds = model(video)

            # Drop a trailing singleton dim, e.g. (b, class_num, 1) -> (b, class_num).
            # A preds of exactly torch.Size([1]) is left unsqueezed (original behaviour).
            # NOTE(review): a 1-D preds makes softmax(dim=1) raise, as it did before.
            if preds.dim() != 1 or preds.size(0) != 1:
                preds = preds.squeeze(dim=-1)
            preds_softmax = torch.softmax(preds, dim=1)

            # extend() flattens the batch dimension into per-sample rows.
            total_pred_list.extend(preds_softmax.tolist())
            # Labels are only read back as Python values, so no device transfer is needed.
            total_label_list.extend(batch['label'].detach().tolist())

    pred = torch.tensor(total_pred_list)
    label = torch.tensor(total_label_list)

    return pred, label


In [5]:
class Config:
    '''Plain-Python stand-in for the experiment settings used by this analysis.

    Every setting is a real class attribute (``name: type = value``). The previous
    version wrote most fields as ``name: value``. In a class body that is only a
    type annotation and creates no attribute, so reading e.g.
    ``Config.model_class_num`` raised AttributeError.

    NOTE(review): ``log_path`` keeps Hydra/OmegaConf ``${...}`` interpolation
    syntax. Here it is a literal string and is not resolved by this class.
    '''
    model: str = "resnet"
    model_class_num: int = 3  # the number of classes the model predicts
    model_depth: int = 50  # choices=[50, 101, 152], the depth of the used model

    # Training config
    max_epochs: int = 50  # number of training epochs
    batch_size: int = 1  # batch size for the dataloader

    # used for val
    clip_duration: float = 0.5  # clip duration for the video
    # Number of frames sampled from the clip duration; to define one gait cycle
    # the whole set of frames is needed.
    uniform_temporal_subsample_num: int = 16

    gpu_num: int = 0  # choices=[0, 1], the GPU index to train on

    # Transfer learning
    # Original comment read "if use the transformer learning"; the name suggests a
    # transfer-learning flag -- confirm against the training code.
    transfor_learning: bool = True

    log_path: str = "/workspace/skeleton/logs/${model.model}/${now:%Y-%m-%d}/${now:%H-%M-%S}/${train.gait_cycle}_${train.uniform_temporal_subsample_num}_${model.model_depth}"

    fast_dev_run: bool = False  # whether to use fast_dev_run
    fold: int = 3  # the fold number of the cross validation

    gait_cycle: int = 0  # [0, 1]


# Checkpoint to evaluate.
ckpt_path = "/workspace/skeleton/logs/resnet/2023-12-18/08-50-55/8_50/0/version_0/checkpoints/1-1.16-0.5810.ckpt"
fold = ckpt_path.split('/')[-4]  # '0' for the path above
model_name = ckpt_path.split('/')[4]  # 'resnet' for the path above
gpu_num = 0
batch_size = 64

device = f'cuda:{gpu_num}'

# load_from_checkpoint is a classmethod, so it is called on the class itself.
# The previous code first built GaitCycleLightningModule(hparams), which had two problems:
# - `hparams` is never defined in this notebook, so it raised NameError.
# - Lightning 2.x raises TypeError when load_from_checkpoint is called on an instance.
# NOTE(review): this assumes the module saved its hyperparameters into the checkpoint
# (e.g. via save_hyperparameters()). If not, pass them as keyword overrides here.
classification_module = GaitCycleLightningModule.load_from_checkpoint(
    ckpt_path, map_location=device
).to(device)

