Skip to content

Commit

Permalink
Fixing index error after close files and clear _log_destination (#323)
Browse files Browse the repository at this point in the history
* A bug fix for empty log destinations list.

* Add test for the case where `_open_files` is True.

* Stupid linting problems.

* Update the test to show the bug and use the simplified fix.
  • Loading branch information
hunterhector committed Sep 17, 2020
1 parent cb62781 commit a9e79f8
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
14 changes: 14 additions & 0 deletions tests/run/executor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import unittest
from pathlib import Path
import os
import sys
from typing import List, Dict, Tuple

import torch
Expand Down Expand Up @@ -93,6 +94,7 @@ def setUp(self) -> None:

self.checkpoint_dir = tempfile.mkdtemp()
self.tbx_logging_dir = tempfile.mkdtemp()
self.log_output_dir = tempfile.mkdtemp()

def tearDown(self) -> None:
shutil.rmtree(self.checkpoint_dir)
Expand All @@ -101,6 +103,7 @@ def tearDown(self) -> None:
def test_train_loop(self):
optimizer = torch.optim.Adam(self.model.parameters())
output_path = os.path.join(self.checkpoint_dir, "output_{split}.txt")
log_path = Path(self.log_output_dir) / "log.txt"
executor = Executor(
model=self.model,
train_data=self.datasets["train"],
Expand All @@ -127,6 +130,7 @@ def test_train_loop(self):
action_on_plateau=[action.early_stop(patience=2),
action.reset_params(),
action.scale_lr(0.8)],
log_destination=[sys.stdout, log_path],
log_every=cond.iteration(20),
show_live_progress=True,
)
Expand All @@ -140,10 +144,13 @@ def test_train_loop(self):
with open(path, "r") as f:
self.assertEqual(len(f.read().split(",")), len(dataset))

self.assertTrue(os.path.exists(log_path))

executor.save()
executor.load()

def test_tbx_logging(self):
log_path = Path(self.log_output_dir) / "log.txt"
executor = Executor(
model=self.model,
train_data=self.datasets["train"],
Expand All @@ -169,6 +176,7 @@ def test_tbx_logging(self):
action_on_plateau=[action.early_stop(patience=2),
action.reset_params(),
action.scale_lr(0.8)],
log_destination=[sys.stdout, log_path],
log_every=cond.iteration(20),
show_live_progress=True,
)
Expand All @@ -178,6 +186,12 @@ def test_tbx_logging(self):
self.assertTrue(path.exists())
self.assertEqual(len(list(os.walk(path))), 1)

# At this point, `executor._files_opened` is True, which will run the
# main steps in the `_open_files()` function, it can cause IndexError
# prior to this bug fix.
executor.test()
self.assertTrue(os.path.exists(log_path))


if __name__ == "__main__":
test = ExecutorTest()
Expand Down
13 changes: 9 additions & 4 deletions texar/torch/run/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
import re
import sys
import time
from collections import OrderedDict, defaultdict # pylint: disable=unused-import
from collections import (OrderedDict, # pylint: disable=unused-import
defaultdict)
from datetime import datetime
from pathlib import Path
from typing import (
Expand Down Expand Up @@ -1346,8 +1347,9 @@ def train(self):
def _try_get_data_size(executor: 'Executor'):
assert executor.train_data is not None
try:
# pylint: disable=protected-access
size = len(executor.train_data)
executor._train_tracker.set_size(size) # pylint: disable=protected-access
executor._train_tracker.set_size(size)
except TypeError:
pass
executor.remove_action()
Expand Down Expand Up @@ -1467,7 +1469,8 @@ def log_fn(executor: 'Executor'):
_register(points, log_fn)

def _flush_log_hook(executor: 'Executor'):
executor._write_log("", skip_non_tty=True) # pylint: disable=protected-access
# pylint: disable=protected-access
executor._write_log("", skip_non_tty=True)

def _register_status_fn(update_event: Event, log_fn: LogFn):
def status_fn(executor: 'Executor'):
Expand Down Expand Up @@ -1740,13 +1743,16 @@ def _open_files(self) -> bool:
return False

self._opened_files = []

for idx, dest in enumerate(self.log_destination):
if isinstance(dest, (str, Path)):
# Append to the logs to prevent accidentally overwriting
# previous logs.
file = open(dest, "a")
self._opened_files.append(file)
self._log_destination[idx] = file
else:
self._log_destination[idx] = dest

if self._tbx_logging_dir is not None:
try:
Expand All @@ -1770,7 +1776,6 @@ def _close_files(self):
for file in self._opened_files:
file.close()
self._opened_files = []
self._log_destination = []

if self.summary_writer is not None:
self.summary_writer.close()
Expand Down

0 comments on commit a9e79f8

Please sign in to comment.