Skip to content

Commit

Permalink
Use python logging in training.
Browse files Browse the repository at this point in the history
This way, we get the training logs in the experiment_root too!
TODO: Maybe also do that in embed and eval?
  • Loading branch information
lucasb-eyer committed Nov 27, 2017
1 parent 0e30b89 commit 250eb17
Show file tree
Hide file tree
Showing 2 changed files with 225 additions and 17 deletions.
199 changes: 199 additions & 0 deletions common.py
@@ -1,6 +1,7 @@
""" A bunch of general utilities shared by train/embed/eval """

from argparse import ArgumentTypeError
import logging
import os

import numpy as np
Expand Down Expand Up @@ -154,3 +155,201 @@ def fid_to_image(fid, pid, image_root, image_size):
image_resized = tf.image.resize_images(image_decoded, image_size)

return image_resized, fid, pid


def get_logging_dict(name):
return {
'version': 1,
'disable_existing_loggers': False,
'formatters': {
'standard': {
'format': '%(asctime)s [%(levelname)s] %(name)s: %(message)s'
},
},
'handlers': {
'stderr': {
'level': 'INFO',
'formatter': 'standard',
'class': 'common.ColorStreamHandler',
'stream': 'ext://sys.stderr',
},
'logfile': {
'level': 'DEBUG',
'formatter': 'standard',
'class': 'logging.FileHandler',
'filename': name + '.log',
'mode': 'a',
}
},
'loggers': {
'': {
'handlers': ['stderr', 'logfile'],
'level': 'DEBUG',
'propagate': True
},

# extra ones to shut up.
'tensorflow': {
'handlers': ['stderr', 'logfile'],
'level': 'INFO',
},
}
}


# Source for the remainder: https://gist.github.com/mooware/a1ed40987b6cc9ab9c65
# Fixed some things mentioned in the comments there.

# colored stream handler for python logging framework (use the ColorStreamHandler class).
#
# based on:
# http://stackoverflow.com/questions/384076/how-can-i-color-python-logging-output/1336640#1336640

# how to use:
# i used a dict-based logging configuration, not sure what else would work.
#
# import logging, logging.config, colorstreamhandler
#
# _LOGCONFIG = {
# "version": 1,
# "disable_existing_loggers": False,
#
# "handlers": {
# "console": {
# "class": "colorstreamhandler.ColorStreamHandler",
# "stream": "ext://sys.stderr",
# "level": "INFO"
# }
# },
#
# "root": {
# "level": "INFO",
# "handlers": ["console"]
# }
# }
#
# logging.config.dictConfig(_LOGCONFIG)
# mylogger = logging.getLogger("mylogger")
# mylogger.warning("foobar")

# Copyright (c) 2014 Markus Pointner
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

class _AnsiColorStreamHandler(logging.StreamHandler):
DEFAULT = '\x1b[0m'
RED = '\x1b[31m'
GREEN = '\x1b[32m'
YELLOW = '\x1b[33m'
CYAN = '\x1b[36m'

CRITICAL = RED
ERROR = RED
WARNING = YELLOW
INFO = DEFAULT # GREEN
DEBUG = CYAN

@classmethod
def _get_color(cls, level):
if level >= logging.CRITICAL: return cls.CRITICAL
elif level >= logging.ERROR: return cls.ERROR
elif level >= logging.WARNING: return cls.WARNING
elif level >= logging.INFO: return cls.INFO
elif level >= logging.DEBUG: return cls.DEBUG
else: return cls.DEFAULT

def __init__(self, stream=None):
logging.StreamHandler.__init__(self, stream)

def format(self, record):
text = logging.StreamHandler.format(self, record)
color = self._get_color(record.levelno)
return (color + text + self.DEFAULT) if self.is_tty() else text

def is_tty(self):
isatty = getattr(self.stream, 'isatty', None)
return isatty and isatty()


class _WinColorStreamHandler(logging.StreamHandler):
# wincon.h
FOREGROUND_BLACK = 0x0000
FOREGROUND_BLUE = 0x0001
FOREGROUND_GREEN = 0x0002
FOREGROUND_CYAN = 0x0003
FOREGROUND_RED = 0x0004
FOREGROUND_MAGENTA = 0x0005
FOREGROUND_YELLOW = 0x0006
FOREGROUND_GREY = 0x0007
FOREGROUND_INTENSITY = 0x0008 # foreground color is intensified.
FOREGROUND_WHITE = FOREGROUND_BLUE | FOREGROUND_GREEN | FOREGROUND_RED

BACKGROUND_BLACK = 0x0000
BACKGROUND_BLUE = 0x0010
BACKGROUND_GREEN = 0x0020
BACKGROUND_CYAN = 0x0030
BACKGROUND_RED = 0x0040
BACKGROUND_MAGENTA = 0x0050
BACKGROUND_YELLOW = 0x0060
BACKGROUND_GREY = 0x0070
BACKGROUND_INTENSITY = 0x0080 # background color is intensified.

DEFAULT = FOREGROUND_WHITE
CRITICAL = BACKGROUND_YELLOW | FOREGROUND_RED | FOREGROUND_INTENSITY | BACKGROUND_INTENSITY
ERROR = FOREGROUND_RED | FOREGROUND_INTENSITY
WARNING = FOREGROUND_YELLOW | FOREGROUND_INTENSITY
INFO = FOREGROUND_GREEN
DEBUG = FOREGROUND_CYAN

@classmethod
def _get_color(cls, level):
if level >= logging.CRITICAL: return cls.CRITICAL
elif level >= logging.ERROR: return cls.ERROR
elif level >= logging.WARNING: return cls.WARNING
elif level >= logging.INFO: return cls.INFO
elif level >= logging.DEBUG: return cls.DEBUG
else: return cls.DEFAULT

def _set_color(self, code):
import ctypes
ctypes.windll.kernel32.SetConsoleTextAttribute(self._outhdl, code)

def __init__(self, stream=None):
logging.StreamHandler.__init__(self, stream)
# get file handle for the stream
import ctypes, ctypes.util
# for some reason find_msvcrt() sometimes doesn't find msvcrt.dll on my system?
crtname = ctypes.util.find_msvcrt()
if not crtname:
crtname = ctypes.util.find_library("msvcrt")
crtlib = ctypes.cdll.LoadLibrary(crtname)
self._outhdl = crtlib._get_osfhandle(self.stream.fileno())

def emit(self, record):
color = self._get_color(record.levelno)
self._set_color(color)
logging.StreamHandler.emit(self, record)
self._set_color(self.FOREGROUND_WHITE)

# select ColorStreamHandler based on platform
import platform
if platform.system() == 'Windows':
ColorStreamHandler = _WinColorStreamHandler
else:
ColorStreamHandler = _AnsiColorStreamHandler
43 changes: 26 additions & 17 deletions train.py
Expand Up @@ -2,6 +2,7 @@
from argparse import ArgumentParser
from datetime import timedelta
from importlib import import_module
import logging.config
import os
from signal import SIGINT, SIGTERM
import sys
Expand Down Expand Up @@ -195,9 +196,9 @@ def main():
# If the experiment directory exists already, we bail in fear.
if os.path.exists(args.experiment_root):
if os.listdir(args.experiment_root):
print('The directory {} already exists and is not empty. If '
'you want to resume training, append --resume to your '
'call.'.format(args.experiment_root))
print('The directory {} already exists and is not empty.'
' If you want to resume training, append --resume to'
' your call.'.format(args.experiment_root))
exit(1)
else:
os.makedirs(args.experiment_root)
Expand All @@ -207,19 +208,23 @@ def main():
with open(args_file, 'w') as f:
json.dump(vars(args), f, ensure_ascii=False, indent=2, sort_keys=True)

log_file = os.path.join(args.experiment_root, "train")
logging.config.dictConfig(common.get_logging_dict(log_file))
log = logging.getLogger('train')

# Also show all parameter values at the start, for ease of reading logs.
print('Training using the following parameters:')
log.info('Training using the following parameters:')
for key, value in sorted(vars(args).items()):
print('{}: {}'.format(key, value))
log.info('{}: {}'.format(key, value))

# Check them here, so they are not required when --resume-ing.
if not args.train_set:
parser.print_help()
print("You did not specify the `train_set` argument!")
log.error("You did not specify the `train_set` argument!")
sys.exit(1)
if not args.image_root:
parser.print_help()
print("You did not specify the required `image_root` argument!")
log.error("You did not specify the required `image_root` argument!")
sys.exit(1)

# Load the data from the CSV file.
Expand Down Expand Up @@ -351,7 +356,7 @@ def main():
if args.resume:
# In case we're resuming, simply load the full checkpoint to init.
last_checkpoint = tf.train.latest_checkpoint(args.experiment_root)
print('Restoring from checkpoint: {}'.format(last_checkpoint))
log.info('Restoring from checkpoint: {}'.format(last_checkpoint))
checkpoint_saver.restore(sess, last_checkpoint)
else:
# But if we're starting from scratch, we may need to load some
Expand All @@ -370,7 +375,7 @@ def main():
summary_writer = tf.summary.FileWriter(args.experiment_root, sess.graph)

start_step = sess.run(global_step)
print('Starting training from iteration {}.'.format(start_step))
log.info('Starting training from iteration {}.'.format(start_step))

# Finally, here comes the main-loop. This `Uninterrupt` is a handy
# utility such that an iteration still finishes on Ctrl+C and we can
Expand All @@ -397,13 +402,17 @@ def main():

# Do a huge print out of the current progress.
seconds_todo = (args.train_iterations - step) * elapsed_time
print('iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, '
'batch-p@{}: {:.2%}, ETA: {} ({:.2f}s/it)'.format(
step, float(np.min(b_loss)), float(np.mean(b_loss)),
float(np.max(b_loss)),
args.batch_k-1, float(b_prec_at_k),
timedelta(seconds=int(seconds_todo)), elapsed_time),
flush=True)
log.info('iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, '
'batch-p@{}: {:.2%}, ETA: {} ({:.2f}s/it)'.format(
step,
float(np.min(b_loss)),
float(np.mean(b_loss)),
float(np.max(b_loss)),
args.batch_k-1, float(b_prec_at_k),
timedelta(seconds=int(seconds_todo)),
elapsed_time))
sys.stdout.flush()
sys.stderr.flush()

# Save a checkpoint of training every so often.
if (args.checkpoint_frequency > 0 and
Expand All @@ -413,7 +422,7 @@ def main():

# Stop the main-loop at the end of the step, if requested.
if u.interrupted:
print("Interrupted on request!")
log.info("Interrupted on request!")
break

# Store one final checkpoint. This might be redundant, but it is crucial
Expand Down

0 comments on commit 250eb17

Please sign in to comment.