/
kth_dataset.py
125 lines (94 loc) · 4.24 KB
/
kth_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import argparse
import glob
import itertools
import os
import random
import cv2
import numpy as np
import tensorflow as tf
from video_prediction.datasets.base_dataset import VarLenFeatureVideoDataset
class KTHVideoDataset(VarLenFeatureVideoDataset):
    """KTH action-recognition videos read from variable-length TFRecords.

    Registers the ``images`` state under the ``images/encoded`` feature with a
    per-frame shape of (64, 64, 3), and overrides a few default hparams.
    """

    def __init__(self, *args, **kwargs):
        super(KTHVideoDataset, self).__init__(*args, **kwargs)
        # Feature key and frame shape for the 'images' state; presumably
        # consumed by the base class when parsing records — confirm there.
        self.state_like_names_and_shapes['images'] = 'images/encoded', (64, 64, 3)

    def get_default_hparams_dict(self):
        """Return the base-class default hparams with KTH-specific overrides.

        Returns:
            A new dict: every base default, with ``context_frames``,
            ``sequence_length``, ``force_time_shift`` and ``use_state``
            replaced (or added) by the values below.
        """
        base_hparams = super(KTHVideoDataset, self).get_default_hparams_dict()
        merged = dict(base_hparams.items())
        merged.update(
            context_frames=10,
            sequence_length=20,
            force_time_shift=True,
            use_state=False,
        )
        return merged

    @property
    def jpeg_encoding(self):
        """Always ``False``: frames in these records are not JPEG-encoded."""
        return False

    def num_examples_per_epoch(self):
        """Return the number of record files.

        The conversion script in this module writes exactly one sequence per
        ``.tfrecords`` file, so the file count equals the example count.
        """
        return len(self.filenames)
def _bytes_feature(value):
    """Wrap one bytes value as a ``tf.train.Feature`` holding a BytesList."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
def _bytes_list_feature(values):
    """Wrap a sequence of bytes values as one ``tf.train.Feature``."""
    bytes_list = tf.train.BytesList(value=values)
    return tf.train.Feature(bytes_list=bytes_list)
def _int64_feature(value):
    """Wrap one integer as a ``tf.train.Feature`` holding an Int64List."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
def _person_id(fname):
    """Return the integer subject ID parsed from a ``person<NN>_...`` file name.

    Only the basename is inspected, so the result does not depend on the
    platform's path separator or on directory names that contain "person".

    Raises:
        ValueError: if the basename does not start with ``person<digits>``.
    """
    prefix = os.path.basename(fname).split('_', 1)[0]
    digits = prefix[len('person'):]
    if not prefix.startswith('person') or not digits.isdigit():
        raise ValueError('cannot parse person ID from %r' % (fname,))
    return int(digits)


def partition_data(input_dir, seed=None):
    """Split the ``.avi`` videos under ``input_dir`` into train/val/test by person.

    Videos of persons 1-16 form the training pool and all other persons form
    the test set. The training pool is shuffled and split so that
    ``int(0.95 * len(pool))`` files go to train and the rest to val.

    Args:
        input_dir: directory whose immediate subdirectories contain the
            ``.avi`` files (e.g. one subdirectory per action).
        seed: optional seed for the train/val shuffle. ``None`` (the default)
            uses the global ``random`` module state, as before.

    Returns:
        A tuple ``(train_fnames, val_fnames, test_fnames)`` of path lists.

    Raises:
        ValueError: if a video's file name does not start with ``person<NN>_``.
    """
    # Sort so the split does not depend on the filesystem's listing order.
    files = sorted(glob.glob(os.path.join(input_dir, '*', '*.avi')))
    file_ids = [(fname, _person_id(fname)) for fname in files]
    train_fnames = [fname for fname, pid in file_ids if pid <= 16]
    test_fnames = [fname for fname, pid in file_ids if pid > 16]

    shuffle = random.shuffle if seed is None else random.Random(seed).shuffle
    shuffle(train_fnames)
    pivot = int(0.95 * len(train_fnames))
    return train_fnames[:pivot], train_fnames[pivot:], test_fnames
def read_video(fname):
    """Decode every frame of a video file with OpenCV.

    Args:
        fname: path to the video file.

    Returns:
        List of frames in decode order, exactly as returned by
        ``cv2.VideoCapture.read``. The caller in this module treats them as
        BGR images — confirm for other uses.

    Raises:
        FileNotFoundError: if ``fname`` is not an existing regular file.
        OSError: if OpenCV cannot open the file (e.g. unsupported codec).
            Without this check the result would silently be an empty list.
    """
    if not os.path.isfile(fname):
        raise FileNotFoundError('video file not found: %s' % fname)
    vidcap = cv2.VideoCapture(fname)
    try:
        if not vidcap.isOpened():
            raise OSError('OpenCV could not open video: %s' % fname)
        frames = []
        success, image = vidcap.read()
        while success:
            frames.append(image)
            success, image = vidcap.read()
        return frames
    finally:
        # Release the decoder handle even if decoding raises.
        vidcap.release()
def save_tf_record(output_fname, sequences, preprocess_image):
    """Write video sequences to one TFRecord file, one ``tf.train.Example`` each.

    Each example stores ``sequence_length``, ``height``, ``width`` and
    ``channels`` as int64 features, plus the per-frame outputs of
    ``preprocess_image`` as the bytes list ``images/encoded``. ``height``,
    ``width`` and ``channels`` are taken from the first frame as passed in,
    i.e. before ``preprocess_image`` is applied.

    Args:
        output_fname: path of the ``.tfrecords`` file to write.
        sequences: iterable of sequences. Each sequence is a non-empty list
            of frames whose ``.shape`` unpacks as ``(height, width, channels)``.
            The iterable is materialized into a list for validation.
        preprocess_image: callable mapping one frame to the bytes to store.

    Raises:
        ValueError: if any sequence is empty. This is raised before the
            output file is opened, so no partial file is left behind.
    """
    sequences = list(sequences)
    for index, sequence in enumerate(sequences):
        if len(sequence) == 0:
            raise ValueError('sequence %d for %s has no frames' % (index, output_fname))

    print('saving sequences to %s' % output_fname)
    with tf.python_io.TFRecordWriter(output_fname) as writer:
        for sequence in sequences:
            num_frames = len(sequence)
            height, width, channels = sequence[0].shape
            encoded_sequence = [preprocess_image(image) for image in sequence]
            features = tf.train.Features(feature={
                'sequence_length': _int64_feature(num_frames),
                'height': _int64_feature(height),
                'width': _int64_feature(width),
                'channels': _int64_feature(channels),
                'images/encoded': _bytes_list_feature(encoded_sequence),
            })
            example = tf.train.Example(features=features)
            writer.write(example.SerializeToString())
def read_videos_and_save_tf_records(output_dir, fnames):
    """Convert each video in ``fnames`` into its own TFRecord file in ``output_dir``.

    Each video is decoded, every frame is converted BGR -> RGB, 20 columns
    are cropped from each side, and the result is resized to 64x64. The
    output file is named after the video's basename with a ``.tfrecords``
    extension.

    Frames are preprocessed *before* being handed to ``save_tf_record``.
    That function records height/width/channels from the frames it receives,
    so the stored metadata now describes the 64x64 payload rather than the
    raw video size.
    """
    def preprocess_image(image):
        """Return the RGB, side-cropped, 64x64-resized version of one frame."""
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # NOTE(review): the 20-column side crop presumably squares KTH's
        # frames before resizing — confirm against the source resolution.
        return cv2.resize(image[:, 20:-20], (64, 64), interpolation=cv2.INTER_LINEAR)

    for fname in fnames:
        stem = os.path.splitext(os.path.basename(fname))[0]
        output_fname = os.path.join(output_dir, stem + '.tfrecords')
        sequence = [preprocess_image(frame) for frame in read_video(fname)]
        # tobytes() yields the same bytes as the deprecated ndarray.tostring().
        save_tf_record(output_fname, [sequence], lambda image: image.tobytes())
def main():
    """Command-line entry point: convert a KTH video directory to TFRecords.

    Positional arguments are ``input_dir`` (the directory holding the
    per-action subdirectories) and ``output_dir``. Videos are split with
    ``partition_data`` and written under ``output_dir/train``,
    ``output_dir/val`` and ``output_dir/test``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("input_dir", type=str, help="directory containing the directories "
                                                    "boxing, handclapping, handwaving, "
                                                    "jogging, running, walking")
    parser.add_argument("output_dir", type=str)
    args = parser.parse_args()

    partition_names = ['train', 'val', 'test']
    partitions = partition_data(args.input_dir)
    for partition_name, fnames in zip(partition_names, partitions):
        partition_dir = os.path.join(args.output_dir, partition_name)
        # exist_ok avoids the race between an exists() check and makedirs().
        os.makedirs(partition_dir, exist_ok=True)
        read_videos_and_save_tf_records(partition_dir, fnames)
# Run the conversion only when executed as a script, not when imported.
if __name__ == '__main__':
    main()