Skip to content

Commit

Permalink
more evaluator fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jyp committed Sep 15, 2021
1 parent ef3e8ea commit 97d45f5
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions typedflow_rts.py
Expand Up @@ -238,16 +238,19 @@ def evaluate (model_static, model_fn, xs, result="y_"):
total_len = len(xs[k0]) # total length
else:
total_len = 1
zeros = dict((k,tf.zeros(phs[k]["shape"], dtype=phs[k]["dtype"])) for k in phs.keys())
zeros = dict((k,tf.zeros(phs[k]["shape"][1:], # remove the batch size
dtype=phs[k]["dtype"])) for k in phs.keys())
results = []
def run():
for i in range(0, bs*(-(-total_len//bs)), bs):
chunks = dict((k,tf.zeros(phs[k]["shape"], dtype=phs[k]["dtype"])) for k in phs)
for k in xs:
print(".",end="")
chunks = dict()
for k in phs:
chunks[k] = xs[k][i:i+bs]
if i + bs > total_len:
# dealing with an incomplete last chunk
origLen = total_len - i
for k in xs:
for k in chunks:
chunks[k] = list(chunks[k]) + [zeros[k]] * (bs - origLen) # pad the last chunk
else:
origLen = bs
Expand Down

0 comments on commit 97d45f5

Please sign in to comment.