Skip to content

Commit

Permalink
fix missing test STC samples, see issue#3,#4,#5
Browse files Browse the repository at this point in the history
  • Loading branch information
LiUzHiAn committed Nov 10, 2021
1 parent f909f42 commit b378d4f
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions pre_process/extract_samples.py
Expand Up @@ -8,7 +8,14 @@
def samples_extraction(dataset_root, dataset_name, mode, all_bboxes, save_dir):
num_predicted_frame = 1
# save samples in chunked file
num_samples_each_chunk = 100000
if dataset_name == "ped2":
num_samples_each_chunk = 100000
elif dataset_name == "avenue":
num_samples_each_chunk = 200000 if mode == "test" else 20000
elif dataset_name == "shanghaitech":
num_samples_each_chunk = 300000 if mode == "test" else 100000
else:
raise NotImplementedError("dataset name should be one of ped2,avenue or shanghaitech!")

# frames dataset
dataset = get_dataset(
Expand Down Expand Up @@ -68,7 +75,7 @@ def samples_extraction(dataset_root, dataset_name, mode, all_bboxes, save_dir):
chunked_samples["motion"] = np.array(chunked_samples["motion"])
chunked_samples["bbox"] = np.array(chunked_samples["bbox"])
chunked_samples["pred_frame"] = np.array(chunked_samples["pred_frame"])
joblib.dump(chunked_samples, os.path.join(save_dir, "chunked_samples_%d.pkl" % chunk_id))
joblib.dump(chunked_samples, os.path.join(save_dir, "chunked_samples_%02d.pkl" % chunk_id))
print("Chunk %d file saved!" % chunk_id)

chunk_id += 1
Expand All @@ -91,7 +98,7 @@ def samples_extraction(dataset_root, dataset_name, mode, all_bboxes, save_dir):

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--proj_root", type=str, default="/home/liuzhian/hdd4T/code/hfvad", help='project root path')
parser.add_argument("--proj_root", type=str, default="/home/liuzhian/hdd4T/code/hf2vad", help='project root path')
parser.add_argument("--dataset_name", type=str, default="ped2", help='dataset name')
parser.add_argument("--mode", type=str, default="train", help='train or test data')

Expand Down

0 comments on commit b378d4f

Please sign in to comment.