Skip to content

Commit

Permalink
update reader for data
Browse files Browse the repository at this point in the history
  • Loading branch information
angeladai committed May 25, 2018
1 parent f042918 commit 0805326
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 45 deletions.
55 changes: 25 additions & 30 deletions src/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def ReadSceneBlocksLevel(data_filepattern,
train_samples,
dim_block,
height_block,
stored_dim_block_hi,
stored_height_block_hi,
stored_dim_block,
stored_height_block,
is_base_level,
hierarchy_level,
num_quant_levels,
Expand All @@ -33,8 +33,8 @@ def ReadSceneBlocksLevel(data_filepattern,
train_samples: Train on previous model predictions.
dim_block: x/z dimension of train block.
height_block: y dimension of train block.
stored_dim_block_hi: Stored data x/z dimension (high-resolution).
stored_height_block_hi: Stored data y dimension (high-resolution).
stored_dim_block: Stored data x/z dimension (high-resolution).
stored_height_block: Stored data y dimension (high-resolution).
is_base_level: Whether there are no previous hierarchy levels.
hierarchy_level: hierarchy level (1 is finest).
num_quant_levels: Number of quantization bins (if used).
Expand All @@ -58,8 +58,8 @@ def ReadSceneBlocksLevel(data_filepattern,
[batch_size, dim_target//2, height_target//2, dim_target//2], dtype
tf.uint8.
"""
assert (stored_dim_block_hi >= dim_block and
stored_height_block_hi >= height_block)
assert (stored_dim_block >= dim_block and
stored_height_block >= height_block)

read_target_lo = not is_base_level
_, examples, samples_lo, samples_sem_lo = _ReadBlockExample(
Expand All @@ -69,20 +69,20 @@ def ReadSceneBlocksLevel(data_filepattern,
is_base_level=is_base_level,
shuffle=shuffle,
num_epochs=num_epochs,
stored_dim_block_hi=stored_dim_block_hi,
stored_height_block_hi=stored_height_block_hi)
stored_dim_block=stored_dim_block,
stored_height_block=stored_height_block)

# jitter height (must be even)
jitter = (np.random.random_integers(
low=0, high=_HEIGHT_JITTER[hierarchy_level - 1]) // 2) * 2

# extract relevant portion of data block as per input/target dim
#key_input = _RESOLUTIONS[hierarchy_level - 1] + '_' + _INPUT_FEATURE
#key_target = _RESOLUTIONS[hierarchy_level - 1] + '_' + _TARGET_FEATURE
#key_target_sem = _RESOLUTIONS[hierarchy_level - 1] + '_' + _TARGET_SEM_FEATURE
key_input = 'input_sdf'
key_target = 'target_df'
key_target_sem = 'target_sem'
key_input = _RESOLUTIONS[hierarchy_level - 1] + '_' + _INPUT_FEATURE
key_target = _RESOLUTIONS[hierarchy_level - 1] + '_' + _TARGET_FEATURE
key_target_sem = _RESOLUTIONS[hierarchy_level - 1] + '_' + _TARGET_SEM_FEATURE
#key_input = 'input_sdf'
#key_target = 'target_df'
#key_target_sem = 'target_sem'

input_sdf_blocks = examples[key_input]
input_sdf_blocks = preprocessor.extract_block(
Expand All @@ -108,8 +108,6 @@ def ReadSceneBlocksLevel(data_filepattern,
target_lo_blocks = preprocessor.preprocess_target_sdf(
target_lo_blocks, num_quant_levels, constants.TRUNCATION, quantize)

stored_dim_block = stored_dim_block_hi >> (hierarchy_level - 1)
stored_height_block = stored_height_block_hi >> (hierarchy_level - 1)
target_sem_blocks = tf.decode_raw(examples[key_target_sem], tf.uint8)
target_sem_blocks = tf.reshape(
target_sem_blocks,
Expand Down Expand Up @@ -167,17 +165,17 @@ def ReadSceneBlocksLevel(data_filepattern,


def _ReadBlockExample(data_filepattern, train_samples, hierarchy_level,
is_base_level, stored_dim_block_hi,
stored_height_block_hi, shuffle, num_epochs):
is_base_level, stored_dim_block,
stored_height_block, shuffle, num_epochs):
"""Deserializes train data.
Args:
data_filepattern: A list of data file patterns.
train_samples: Train on previous model predictions.
hierarchy_level: hierarchy level (1 is finest).
is_base_level: Whether there are no previous hierarchy levels.
stored_dim_block_hi: Stored data x/z dimension (high-resolution).
stored_height_block_hi: Stored data y dimension (high-resolution).
stored_dim_block: Stored data x/z dimension (high-resolution).
stored_height_block: Stored data y dimension (high-resolution).
shuffle: Whether to shuffle.
num_epochs: Number of data epochs.
Returns:
Expand All @@ -201,9 +199,6 @@ def _ReadBlockExample(data_filepattern, train_samples, hierarchy_level,
reader = tf.TFRecordReader()
key, serialized_example = reader.read(filename_queue)

stored_dim_block = stored_dim_block_hi >> (hierarchy_level - 1)
stored_height_block = stored_height_block_hi >> (hierarchy_level - 1)

samples_lo = None
samples_sem_lo = None
if train_samples and not is_base_level:
Expand All @@ -224,12 +219,12 @@ def _ReadBlockExample(data_filepattern, train_samples, hierarchy_level,
serialized_example = example['data']

# Parse sequence example.
#key_input = _RESOLUTIONS[hierarchy_level - 1] + '_' + _INPUT_FEATURE
#key_target = _RESOLUTIONS[hierarchy_level - 1] + '_' + _TARGET_FEATURE
#key_target_sem = _RESOLUTIONS[hierarchy_level - 1] + '_' + _TARGET_SEM_FEATURE
key_input = 'input_sdf'
key_target = 'target_df'
key_target_sem = 'target_sem'
key_input = _RESOLUTIONS[hierarchy_level - 1] + '_' + _INPUT_FEATURE
key_target = _RESOLUTIONS[hierarchy_level - 1] + '_' + _TARGET_FEATURE
key_target_sem = _RESOLUTIONS[hierarchy_level - 1] + '_' + _TARGET_SEM_FEATURE
#key_input = 'input_sdf'
#key_target = 'target_df'
#key_target_sem = 'target_sem'

sequence_features_spec = {
key_input:
Expand All @@ -255,4 +250,4 @@ def _ReadBlockExample(data_filepattern, train_samples, hierarchy_level,
tf.string)

example = tf.parse_single_example(serialized_example, sequence_features_spec)
return key, example, samples_lo, samples_sem_lo
return key, example, samples_lo, samples_sem_lo
42 changes: 32 additions & 10 deletions src/run_train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,50 @@ GPU=0
BATCH_SIZE=8
BASE_DIR='./train'
# Fill in training data filepattern here.
DATA=''
DATA='data/vox19_dim32/train*.tfrecords' # data for 19cm level
#DATA='data/vox5-9-19_dim32/train_*.tfrecords' # data for 9cm and 5cm levels
NUMBER_OF_STEPS=100000

STORED_BLOCK_DIM_HI=64
STORED_BLOCK_HEIGHT_HI=64

# coarse level
IS_BASE_LEVEL=1
HIERARCHY_LEVEL=1
HIERARCHY_LEVEL=3
STORED_BLOCK_DIM=32
STORED_BLOCK_HEIGHT=16
BLOCK_DIM=32
BLOCK_HEIGHT=64
BLOCK_HEIGHT=16
TRAIN_SAMPLES=0
VERSION=003

## mid level
#IS_BASE_LEVEL=0
#HIERARCHY_LEVEL=2
#STORED_BLOCK_DIM=32
#STORED_BLOCK_HEIGHT=32
#BLOCK_DIM=32
#BLOCK_HEIGHT=32
#TRAIN_SAMPLES=1
#VERSION=002

## hi level
#IS_BASE_LEVEL=0
#HIERARCHY_LEVEL=1
#STORED_BLOCK_DIM=64
#STORED_BLOCK_HEIGHT=64
#BLOCK_DIM=32
#BLOCK_HEIGHT=64
#TRAIN_SAMPLES=1
#VERSION=001

PREDICT_SEMANTICS=0 # set to 1 to predict semantics
WEIGHT_SEM=0.5

VERSION=000

python train.py \
--gpu="${GPU}" \
--train_dir=${BASE_DIR}/train_v${VERSION} \
--batch_size="${BATCH_SIZE}" \
--data_filepattern="${DATA}" \
--stored_dim_block_hi="${STORED_BLOCK_DIM_HI}" \
--stored_height_block_hi="${STORED_BLOCK_HEIGHT_HI}" \
--stored_dim_block="${STORED_BLOCK_DIM}" \
--stored_height_block="${STORED_BLOCK_HEIGHT}" \
--dim_block="${BLOCK_DIM}" \
--height_block="${BLOCK_HEIGHT}" \
--hierarchy_level="${HIERARCHY_LEVEL}" \
Expand Down
10 changes: 5 additions & 5 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
flags.DEFINE_bool('predict_semantics', False,
'Also predict semantic labels per-voxel.')
flags.DEFINE_integer('num_quant_levels', 256, 'Number of quantization bins.')
flags.DEFINE_integer('stored_dim_block_hi', 64,
flags.DEFINE_integer('stored_dim_block', 64,
'Stored data block x/z dim, high-resolution.')
flags.DEFINE_integer('stored_height_block_hi', 64,
flags.DEFINE_integer('stored_height_block', 64,
'Stored data block y dim, high-resolution.')
flags.DEFINE_float('weight_semantic', 0.5, 'Weight for semantic loss.')

Expand Down Expand Up @@ -67,8 +67,8 @@ def _train():
FLAGS.train_samples,
FLAGS.dim_block,
FLAGS.height_block,
FLAGS.stored_dim_block_hi,
FLAGS.stored_height_block_hi,
FLAGS.stored_dim_block,
FLAGS.stored_height_block,
FLAGS.is_base_level,
FLAGS.hierarchy_level,
FLAGS.num_quant_levels,
Expand Down Expand Up @@ -230,4 +230,4 @@ def main(_):


if __name__ == '__main__':
tf.app.run()
tf.app.run()

0 comments on commit 0805326

Please sign in to comment.