Skip to content

Commit

Permalink
Fixed generate code to work with new delta generation code
Browse files (browse the repository at this point in the history)
  • Loading branch information
Quangmire committed Nov 9, 2021
1 parent 7374eb4 commit b15d63b
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 20 deletions.
34 changes: 19 additions & 15 deletions voyager/data_loader.py
Expand Up @@ -124,12 +124,24 @@ def _apply_delta(self, addr, page, offset):
else:
return addr - dist

def _apply_delta_to_idx(self, idx, page, offset):
    """Resolve a delta-style (page, offset) pair against an earlier access.

    Row ``idx`` of ``self.data`` supplies two lookup keys into
    ``self.pc_data``: column 1 selects the row and column 4, minus one,
    selects the entry within that row. The value found there is treated as
    another row index of ``self.data``; it is converted to an address with
    ``_idx_to_addr`` and that address is handed to ``_apply_delta`` together
    with ``page`` and ``offset``.

    NOTE(review): presumably column 1 is a per-PC id and column 4 a position
    within that PC's access list, making the looked-up entry the previous
    access by the same PC -- confirm against where ``pc_data`` is built.
    NOTE(review): when column 4 is 0 the entry index is -1; whether that can
    occur, and what ``pc_data`` returns for it, is not visible here.

    Returns:
        Whatever ``self._apply_delta`` returns for the resolved base address.
    """
    pc_row = self.data[idx, 1]
    entry = self.data[idx, 4] - 1
    base_addr = self._idx_to_addr(self.pc_data[pc_row, entry])
    return self._apply_delta(base_addr, page, offset)

# Converts a row index of self.data into an address of the form
# (page << self.config.offset_bits) + offset.
# The page comes from self.reverse_page_mapping keyed by column 2 of the row.
# When that entry is a string, the address is rebuilt from self.orig instead
# (elements 0 and 1 of the row); otherwise column 3 of self.data is added as
# the offset.
# NOTE(review): this span is a rendered diff hunk with indentation and +/-
# markers stripped. It appears to interleave the pre-commit body with the
# post-commit body, so it is not one runnable function as shown -- confirm
# against the checked-in voyager/data_loader.py.
def _idx_to_addr(self, data_idx):
# NOTE(review): the next three lines look like the removed (pre-commit) lookup.
page = self.reverse_page_mapping[self.data[data_idx][2]]
if isinstance(page, str):
return (self.orig[data_idx][0] << self.config.offset_bits) + self.orig[data_idx][1]
# Tensor-backed data: values are pulled out with .numpy() before being used
# as mapping keys or as indices into self.orig.
if isinstance(self.data, tf.Tensor):
page = self.reverse_page_mapping[self.data[data_idx, 2].numpy()]
if isinstance(page, str):
return (self.orig[data_idx.numpy()][0] << self.config.offset_bits) + self.orig[data_idx.numpy()][1]
else:
return (page << self.config.offset_bits) + self.data[data_idx, 3].numpy()
else:
# NOTE(review): the next line looks like the removed (pre-commit) return.
return (page << self.config.offset_bits) + self.data[data_idx][3]
# Non-tensor data: same lookup using plain indexing, no .numpy() calls.
page = self.reverse_page_mapping[self.data[data_idx][2]]
if isinstance(page, str):
return (self.orig[data_idx][0] << self.config.offset_bits) + self.orig[data_idx][1]
else:
return (page << self.config.offset_bits) + self.data[data_idx][3]

@timefunction('Generating multi-label data')
def _generate_multi_label(self):
Expand Down Expand Up @@ -381,7 +393,7 @@ def mapper(idx):
y_page = hist[-1:, 2]
y_offset = hist[-1:, 3]

return inst_id, tf.concat(hists, axis=-1), y_page, y_offset
return idx, inst_id, tf.concat(hists, axis=-1), y_page, y_offset

# Closure for generating a reproducible random sequence
epoch_size = self.config.steps_per_epoch * self.config.batch_size
Expand Down Expand Up @@ -455,20 +467,12 @@ def random(x):
return train_ds, valid_ds, test_ds

# Unmaps the page and offset
def unmap(self, x, page, offset, sequence_length):
def unmap(self, idx, x, page, offset, sequence_length):
unmapped_page = self.reverse_page_mapping[page]

# DELTA LOCALIZED
if isinstance(unmapped_page, str):
prev_page = x[2 * sequence_length - 1]
prev_offset = x[-1]
unmapped_prev_page = self.reverse_page_mapping[prev_page]
prev_addr = (unmapped_prev_page << self.config.offset_bits) + prev_offset
delta = int(unmapped_page[1:])
if unmapped_page[0] == '+':
ret_addr = prev_addr + delta
else:
ret_addr = prev_addr - delta
ret_addr = self._apply_delta_to_idx(idx, page, offset)
else:
ret_addr = (unmapped_page << self.config.offset_bits) + offset

Expand Down
10 changes: 5 additions & 5 deletions voyager/model_wrappers.py
Expand Up @@ -182,7 +182,7 @@ def train(self, train_ds=None, valid_ds=None, callbacks=None):
logs = {}

# Main training loop
for _, x, y_page, y_out in train_ds:
for _, _, x, y_page, y_out in train_ds:
epoch_ended = False
self.step += 1

Expand Down Expand Up @@ -285,7 +285,7 @@ def evaluate(self, datasets=None, callbacks=None, training=False):

# Validation loop
for ds in datasets:
for step, (_, x, y_page, y_out) in enumerate(ds):
for step, (_, _, x, y_page, y_out) in enumerate(ds):
self.callbacks.on_test_batch_begin(step)
logs = self.evaluate_step(x, (y_page, y_out))
self.callbacks.on_test_batch_end(step, logs)
Expand Down Expand Up @@ -320,7 +320,7 @@ def generate(self, datasets=None, prefetch_file=None, callbacks=None):
self.reset_metrics()
self.callbacks.on_test_begin()
for ds in datasets:
for step, (batch_inst_ids, x, y_page, y_out) in enumerate(ds):
for step, (idx, batch_inst_ids, x, y_page, y_out) in enumerate(ds):
self.callbacks.on_test_batch_begin(step)
logits, logs = self.generate_step(x, (y_page, y_out))

Expand All @@ -338,11 +338,11 @@ def generate(self, datasets=None, prefetch_file=None, callbacks=None):
pred_offsets = tf.argmax(offset_logits, -1).numpy().tolist()

# Unmap addresses
for xi, inst_id, pred_page, pred_offset in zip(x.numpy().tolist(), batch_inst_ids.numpy().tolist(), pred_pages, pred_offsets):
for idxi, xi, inst_id, pred_page, pred_offset in zip(idx.numpy().tolist(), x.numpy().tolist(), batch_inst_ids.numpy().tolist(), pred_pages, pred_offsets):
# OOV
if pred_page == 0:
continue
addresses.append(self.benchmark.unmap(xi, pred_page, pred_offset, self.config.sequence_length))
addresses.append(self.benchmark.unmap(idxi, xi, pred_page, pred_offset, self.config.sequence_length))
inst_ids.append(inst_id)

self.callbacks.on_test_batch_end(step, logs)
Expand Down

0 comments on commit b15d63b

Please sign in to comment.