Skip to content

Commit

Permalink
Added basic file logger #1803
Browse files Browse the repository at this point in the history
  • Loading branch information
xmotli02 committed Jul 27, 2020
1 parent 3f2c102 commit 35c235d
Show file tree
Hide file tree
Showing 4 changed files with 273 additions and 0 deletions.
6 changes: 6 additions & 0 deletions docs/source/loggers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -138,4 +138,10 @@ Test-tube
^^^^^^^^^

.. autoclass:: pytorch_lightning.loggers.test_tube.TestTubeLogger
:noindex:

FileLogger
^^^^^^^^^^

.. autoclass:: pytorch_lightning.loggers.file_logger.FileLogger
:noindex:
2 changes: 2 additions & 0 deletions pytorch_lightning/loggers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from pytorch_lightning.loggers.base import LightningLoggerBase, LoggerCollection
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.loggers.file_logger import FileLogger


__all__ = [
'LightningLoggerBase',
Expand Down
180 changes: 180 additions & 0 deletions pytorch_lightning/loggers/file_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
"""
File logger
-----------
"""
import io
import os
import csv
import torch

from argparse import Namespace
from typing import Optional, Dict, Any, Union


from pytorch_lightning import _logger as log
from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.utilities.distributed import rank_zero_only


class ExperimentWriter(object):
NAME_HPARAMS_FILE = 'hparams.yaml'
NAME_METRICS_FILE = 'metrics.csv'

def __init__(self, log_dir):
self.hparams = {}
self.metrics = []
self.metrics_keys = ["step"]

self.log_dir = log_dir
if not os.path.exists(log_dir):
os.makedirs(log_dir)

self.metrics_file_path = os.path.join(self.log_dir, self.NAME_METRICS_FILE)

def log_hparams(self, params):
self.hparams.update(params)

def log_metrics(self, metrics_dict, step=None):
def _handle_value(value):
if isinstance(value, torch.Tensor):
return value.item()
return value

if step is None:
step = len(self.metrics)

new_row = dict.fromkeys(self.metrics_keys)
new_row['step'] = step
for k, v in metrics_dict.items():
if k not in self.metrics_keys:
self.metrics_keys.append(k)
new_row[k] = _handle_value(v)
self.metrics.append(new_row)

def save(self):
hparams_file = os.path.join(self.log_dir, self.NAME_HPARAMS_FILE)
save_hparams_to_yaml(hparams_file, self.hparams)

if self.metrics:
with io.open(self.metrics_file_path, 'w', newline='') as f:
self.writer = csv.DictWriter(f, fieldnames=self.metrics_keys)
self.writer.writeheader()
self.writer.writerows(self.metrics)


class FileLogger(LightningLoggerBase):
r"""
Log to local file system in yaml and CSV format. Logs are saved to
``os.path.join(save_dir, name, version)``.
Example:
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.loggers import FileLogger
>>> logger = FileLogger("logs", name="my_exp_name")
>>> trainer = Trainer(logger=logger)
Args:
save_dir: Save directory
name: Experiment name. Defaults to ``'default'``.
version: Experiment version. If version is not specified the logger inspects the save
directory for existing versions, then automatically assigns the next available version.
"""

def __init__(self,
save_dir: str,
name: Optional[str] = "default",
version: Optional[Union[int, str]] = None):

super().__init__()
self._save_dir = save_dir
self._name = name or ''
self._version = version
self._experiment = None

@property
def root_dir(self) -> str:
"""
Parent directory for all checkpoint subdirectories.
If the experiment name parameter is ``None`` or the empty string, no experiment subdirectory is used
and the checkpoint will be saved in "save_dir/version_dir"
"""
if self.name is None or len(self.name) == 0:
return self._save_dir
return os.path.join(self._save_dir, self.name)

@property
def log_dir(self) -> str:
"""
The log directory for this run. By default, it is named
``'version_${self.version}'`` but it can be overridden by passing a string value
for the constructor's version parameter instead of ``None`` or an int.
"""
# create a pseudo standard path ala test-tube
version = self.version if isinstance(self.version, str) else f"version_{self.version}"
log_dir = os.path.join(self.root_dir, version)
return log_dir

@property
def experiment(self) -> ExperimentWriter:
r"""
Actual ExperimentWriter object. To use ExperimentWriter features in your
:class:`~pytorch_lightning.core.lightning.LightningModule` do the following.
Example::
self.logger.experiment.some_experiment_writer_function()
"""
if self._experiment is not None:
return self._experiment

os.makedirs(self.root_dir, exist_ok=True)
self._experiment = ExperimentWriter(log_dir=self.log_dir)
return self._experiment

@rank_zero_only
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
params = self._convert_params(params)
self.experiment.log_hparams(params)

@rank_zero_only
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
self.experiment.log_metrics(metrics, step)

@rank_zero_only
def save(self) -> None:
super().save()
self.experiment.save()

@rank_zero_only
def finalize(self, status: str) -> None:
self.save()

@property
def name(self) -> str:
return self._name

@property
def version(self) -> int:
if self._version is None:
self._version = self._get_next_version()
return self._version

def _get_next_version(self):
root_dir = os.path.join(self._save_dir, self.name)

if not os.path.isdir(root_dir):
log.warning('Missing logger folder: %s', root_dir)
return 0

existing_versions = []
for d in os.listdir(root_dir):
if os.path.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
existing_versions.append(int(d.split("_")[1]))

if len(existing_versions) == 0:
return 0

return max(existing_versions) + 1
85 changes: 85 additions & 0 deletions tests/loggers/test_file_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from argparse import Namespace

import pytest
import torch
import os

from pytorch_lightning.loggers import FileLogger


def test_file_logger_automatic_versioning(tmpdir):
"""Verify that automatic versioning works"""

root_dir = tmpdir.mkdir("exp")
root_dir.mkdir("version_0")
root_dir.mkdir("version_1")

logger = FileLogger(save_dir=tmpdir, name="exp")

assert logger.version == 2


def test_file_logger_manual_versioning(tmpdir):
"""Verify that manual versioning works"""

root_dir = tmpdir.mkdir("exp")
root_dir.mkdir("version_0")
root_dir.mkdir("version_1")
root_dir.mkdir("version_2")

logger = FileLogger(save_dir=tmpdir, name="exp", version=1)

assert logger.version == 1


def test_file_logger_named_version(tmpdir):
"""Verify that manual versioning works for string versions, e.g. '2020-02-05-162402' """

exp_name = "exp"
tmpdir.mkdir(exp_name)
expected_version = "2020-02-05-162402"

logger = FileLogger(save_dir=tmpdir, name=exp_name, version=expected_version)
logger.log_hyperparams({"a": 1, "b": 2})
logger.save()
assert logger.version == expected_version
assert os.listdir(tmpdir / exp_name) == [expected_version]
assert os.listdir(tmpdir / exp_name / expected_version)


@pytest.mark.parametrize("name", ['', None])
def test_file_logger_no_name(tmpdir, name):
"""Verify that None or empty name works"""
logger = FileLogger(save_dir=tmpdir, name=name)
logger.save()
assert logger.root_dir == tmpdir
assert os.listdir(tmpdir / 'version_0')


@pytest.mark.parametrize("step_idx", [10, None])
def test_file_logger_log_metrics(tmpdir, step_idx):
logger = FileLogger(tmpdir)
metrics = {
"float": 0.3,
"int": 1,
"FloatTensor": torch.tensor(0.1),
"IntTensor": torch.tensor(1)
}
logger.log_metrics(metrics, step_idx)
logger.save()


def test_file_logger_log_hyperparams(tmpdir):
logger = FileLogger(tmpdir)
hparams = {
"float": 0.3,
"int": 1,
"string": "abc",
"bool": True,
"dict": {'a': {'b': 'c'}},
"list": [1, 2, 3],
"namespace": Namespace(foo=Namespace(bar='buzz')),
"layer": torch.nn.BatchNorm1d
}
logger.log_hyperparams(hparams)
logger.save()

0 comments on commit 35c235d

Please sign in to comment.