In [1]:
import glob
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import scipy.io
from scipy.signal import medfilt
from tqdm import tqdm

### Functions

In [2]:
def read_annotation(annotation_path):
    """Read a split annotation file and return its lines.

    Parameters
    ----------
    annotation_path : str
        Path to a text file with one entry per line.

    Returns
    -------
    numpy.ndarray
        1-D array of the file's lines with line terminators stripped.
    """
    # The ``with`` block closes the file on exit; no explicit close() is needed.
    with open(annotation_path) as f:
        lines = f.read().splitlines()
    # np.hstack keeps the historical return type (a 1-D NumPy array).
    return np.hstack([lines])

def generate_train_test_list(lists):
    """Collect train/test clip names from a set of split annotation files.

    Each non-blank line is expected to hold ``<file_name> <flag>``. Flag
    ``'1'`` marks a training clip and ``'2'`` marks a test clip. Any other
    flag is ignored.

    Parameters
    ----------
    lists : iterable of str
        Paths to split annotation files.

    Returns
    -------
    (list of str, list of str)
        Train and test clip names. Everything from the first '.' of each file
        name onward is removed (e.g. 'clip.avi' -> 'clip').
    """
    train_list = []
    test_list = []
    for annotation_path in lists:
        for line in read_annotation(annotation_path):
            # split() tolerates trailing/repeated whitespace; blank lines
            # (e.g. a trailing newline) are skipped instead of failing to unpack.
            fields = line.split()
            if not fields:
                continue
            file_name, flag = fields
            clip_name = file_name.split('.')[0]
            if flag == '1':
                train_list.append(clip_name)
            elif flag == '2':
                test_list.append(clip_name)
    return train_list, test_list

# Joint indices (into the joint axis of ``pos_img``) kept by read_pose, in output
# order. Index 1 is omitted. Presumably this maps JHMDB's 15-joint order onto the
# first 14 OpenPose/COCO keypoints. TODO: confirm against the JHMDB joint definition.
OPENPOSE_LAYOUT = (2, 0,
                   3, 7, 11, 4, 8, 12,
                   5, 9, 13, 6, 10, 14)


def read_pose(path, layout=OPENPOSE_LAYOUT):
    """Load joint positions from a ``joint_positions.mat`` file.

    Parameters
    ----------
    path : str
        Path to a MATLAB ``.mat`` file containing a ``pos_img`` array.
    layout : sequence of int, optional
        Joint indices to keep, in output order. Defaults to OPENPOSE_LAYOUT.

    Returns
    -------
    numpy.ndarray
        ``pos_img`` rounded to 3 decimals, with axes 0 and 2 swapped and axis 1
        reindexed by ``layout``. Assuming ``pos_img`` is (2, joints, frames)
        (TODO: confirm), the result is (frames, len(layout), 2).
    """
    mat = scipy.io.loadmat(path)
    poses = np.round(mat['pos_img'], 3).swapaxes(0, 2)
    # list(...) guarantees fancy indexing along axis 1 even if a tuple is passed.
    return poses[:, list(layout), :]

def generate_pose_label(pose_list, train_list, test_list, selected_action=None):
    """Load poses for each clip directory and assign them to train or test.

    Parameters
    ----------
    pose_list : iterable of str
        Clip directories laid out as ``.../<action>/<clip_name>``. Each is
        expected to contain ``joint_positions.mat``.
    train_list, test_list : iterable of str
        Clip names belonging to the training / test split.
    selected_action : collection of str, optional
        If truthy, only clips whose action label is in it are considered.

    Returns
    -------
    train, test : dict
        ``{'pose': [...], 'label': [...]}`` with parallel lists of pose arrays
        (from ``read_pose``) and action labels. A clip listed in both splits
        goes to ``train``.
    summary : dict
        Number of selected clip directories per action label. Clips are
        counted whether or not they appear in either split.
    """
    # Sets give O(1) membership tests instead of scanning a list per clip.
    train_names = set(train_list)
    test_names = set(test_list)
    train = {'pose': [], 'label': []}
    test = {'pose': [], 'label': []}
    summary = {}
    for clip_dir in pose_list:
        # os.path also handles the platform separator (e.g. '\\' on Windows).
        label = os.path.basename(os.path.dirname(clip_dir))
        if selected_action and label not in selected_action:
            continue
        summary[label] = summary.get(label, 0) + 1

        clip_name = os.path.basename(clip_dir)
        if clip_name in train_names:
            target = train
        elif clip_name in test_names:
            target = test
        else:
            # Avoid loading poses for clips that belong to neither split.
            continue
        pose = read_pose(os.path.join(clip_dir, 'joint_positions.mat'))
        target['label'].append(label)
        target['pose'].append(pose)
    return train, test, summary

### Config settings

In [3]:
class Config:
    """Filesystem locations used by this notebook.

    Parameters
    ----------
    data_dir : str
        Root of the raw JHMDB data. Split files and pose directories are
        looked up underneath it.
    save_dir : str
        Directory the processed pickles are written to.

    Include a trailing '/' when overriding, to stay safe with
    string-concatenated paths.
    """

    def __init__(self, data_dir='./JHMDB_original/', save_dir='./JHMDB_processed/'):
        self.data_dir = data_dir
        self.save_dir = save_dir


C = Config()
# Create the output directory up front; writing the pickles fails if it is missing.
os.makedirs(C.save_dir, exist_ok=True)
# Action labels (parent-directory names of the pose clips) to keep.
selected_actions = ['sit', 'catch', 'walk']

### JHMDB provides 3 different train/test splits of the ground-truth pose data

In [4]:
# glob's result order depends on the filesystem; sort so the order of clips
# (and therefore of the saved pickles) is reproducible across machines.
GT_split_lists = sorted(glob.glob(os.path.join(C.data_dir, 'splits', '*.txt')))
# One directory per clip: joint_positions/<action>/<clip_name>
GT_pose_list = sorted(glob.glob(os.path.join(C.data_dir, 'joint_positions', '*', '*')))

In [5]:
# Bucket split annotation files by the last '_'-separated token of their base
# name: a name ending in '_split1' goes to GT_lists_1, and so on. Other names
# are ignored.
GT_lists_1 = []
GT_lists_2 = []
GT_lists_3 = []
split_buckets = {'split1': GT_lists_1, 'split2': GT_lists_2, 'split3': GT_lists_3}
for split_file in GT_split_lists:
    # Parse the tag once; os.path.basename also handles Windows separators.
    split_tag = os.path.basename(split_file).split('.')[0].split('_')[-1]
    bucket = split_buckets.get(split_tag)
    if bucket is not None:
        bucket.append(split_file)
       

## Processing each train/test split

In [6]:
GT_train_list_1, GT_test_list_1 = generate_train_test_list(GT_lists_1)
GT_train_1, GT_test_1, summary = generate_pose_label(
    GT_pose_list, GT_train_list_1, GT_test_list_1, selected_actions)
# sorted() gives a stable print order; set iteration order of str varies per process.
print(sorted(set(GT_train_1['label'])))
print(summary)
# `with` closes (and flushes) each file deterministically instead of relying on GC.
with open(os.path.join(C.save_dir, "GT_train_1.pkl"), "wb") as f:
    pickle.dump(GT_train_1, f)
with open(os.path.join(C.save_dir, "GT_test_1.pkl"), "wb") as f:
    pickle.dump(GT_test_1, f)


['walk', 'sit', 'catch']
{'walk': 41, 'sit': 39, 'catch': 48}


In [7]:
GT_train_list_2, GT_test_list_2 = generate_train_test_list(GT_lists_2)

GT_train_2, GT_test_2, summary = generate_pose_label(
    GT_pose_list, GT_train_list_2, GT_test_list_2, selected_actions)
print(summary)
# `with` closes (and flushes) each file deterministically instead of relying on GC.
with open(os.path.join(C.save_dir, "GT_train_2.pkl"), "wb") as f:
    pickle.dump(GT_train_2, f)
with open(os.path.join(C.save_dir, "GT_test_2.pkl"), "wb") as f:
    pickle.dump(GT_test_2, f)


{'walk': 41, 'sit': 39, 'catch': 48}


In [8]:
GT_train_list_3, GT_test_list_3 = generate_train_test_list(GT_lists_3)

GT_train_3, GT_test_3, summary = generate_pose_label(
    GT_pose_list, GT_train_list_3, GT_test_list_3, selected_actions)
print(summary)
# `with` closes (and flushes) each file deterministically instead of relying on GC.
with open(os.path.join(C.save_dir, "GT_train_3.pkl"), "wb") as f:
    pickle.dump(GT_train_3, f)
with open(os.path.join(C.save_dir, "GT_test_3.pkl"), "wb") as f:
    pickle.dump(GT_test_3, f)


{'walk': 41, 'sit': 39, 'catch': 48}
