Skip to content

Commit

Permalink
Merge branch 'develop' of ssh://phabricator.mitk.org:2222/source/trix…
Browse files Browse the repository at this point in the history
…i into develop
  • Loading branch information
dzimmerer committed May 8, 2020
2 parents aab305c + 1ebe1bc commit 607db99
Show file tree
Hide file tree
Showing 10 changed files with 374 additions and 23 deletions.
19 changes: 17 additions & 2 deletions Readme.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[![DOI](https://zenodo.org/badge/134823632.svg)](https://zenodo.org/badge/latestdoi/134823632)
[![PyPI version](https://badge.fury.io/py/trixi.svg)](https://badge.fury.io/py/trixi)
[![Build Status](https://img.shields.io/travis/MIC-DKFZ/trixi.svg)](https://travis-ci.org/MIC-DKFZ/trixi)
[![Documentation Status](https://readthedocs.org/projects/trixi/badge/?version=latest)](https://trixi.readthedocs.io/en/latest/?badge=latest)
[![Documentation Status](https://readthedocs.org/projects/trixi/badge/?version=develop)](https://trixi.readthedocs.io/en/develop/?badge=develop)
[![Downloads](https://pepy.tech/badge/trixi)](https://pepy.tech/project/trixi)
[![GitHub](https://img.shields.io/pypi/l/trixi.svg)](https://github.com/MIC-DKFZ/trixi/blob/master/LICENSE)
<p align="center">
Expand All @@ -13,8 +13,9 @@ Finally get some structure into your machine learning experiments.

* [Features](#features)
* [Installation](#installation)
* [Documentation](#documentation) ([trixi.rtfd.io](https://trixi.readthedocs.io/en/latest/))
* [Documentation](#documentation) ([trixi.rtfd.io](https://trixi.readthedocs.io/en/develop/))
* [Examples](#examples)
* [How to Cite](#how-to-cite)

# Contribute

Expand Down Expand Up @@ -151,3 +152,17 @@ Examples can be found here for:
* [Experiment Infrastructure](https://github.com/MIC-DKFZ/trixi/blob/master/examples/pytorch_experiment.ipynb) (with a
simple MNIST Experiment example and resuming and comparison of different hyperparameters)
* [U-Net Example](https://github.com/MIC-DKFZ/basic_unet_example)

# How to Cite

If you use **trixi** in your project, we'd appreciate a citation, for example like this

@misc{trixi2017,
author = {Zimmerer, David and Petersen, Jens and Köhler, Gregor and Wasserthal, Jakob and Adler, Tim and Wirkert, Sebastian and Ross, Tobias},
title = {trixi - Training and Retrospective Insight eXperiment Infrastructure},
year = {2017},
publisher = {GitHub},
journal = {GitHub Repository},
howpublished = {\url{https://github.com/MIC-DKFZ/trixi}},
doi = {10.5281/zenodo.1345136}
}
4 changes: 2 additions & 2 deletions doc/_api/trixi.logger.rst
Original file line number Diff line number Diff line change
Expand Up @@ -121,15 +121,15 @@ tensorboard
:undoc-members:
:show-inheritance:

:hidden:`TensorboardXLogger`
:hidden:`TensorboardLogger`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. automodule:: trixi.logger.tensorboard.tensorboardxlogger
:members:
:undoc-members:
:show-inheritance:

:hidden:`PytorchTensorboardXLogger`
:hidden:`PytorchTensorboardLogger`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. automodule:: trixi.logger.tensorboard.pytorchtensorboardxlogger
Expand Down
317 changes: 317 additions & 0 deletions examples/tensorboard_logger_example.ipynb

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions requirements_full.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
-r requirements.txt
torch>=1.1.0
torch>=1.3.0
torchvision>=0.2.1
python-telegram-bot>=10.1.0
umap-learn>=0.3.6
scikit-learn==0.20.2
scikit-learn>=0.20.2
slackclient>=1.3.1
tb-nightly==1.14.0a20190523
tensorboard==2.1.1
4 changes: 2 additions & 2 deletions test/test_pytorchtensorboardxlogger.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
from scipy import misc

from trixi.logger.experiment.experimentlogger import ExperimentLogger
from trixi.logger.tensorboard.pytorchtensorboardxlogger import PytorchTensorboardXLogger
from trixi.logger.tensorboard.pytorchtensorboardlogger import PytorchTensorboardLogger
from trixi.util.config import Config


class TestPytorchTensorboardXLogger(unittest.TestCase):

def setUp(self):
self.test_dir = tempfile.gettempdir()
self.logger = PytorchTensorboardXLogger(self.test_dir)
self.logger = PytorchTensorboardLogger(self.test_dir)

def tearDown(self):
self.logger.close()
Expand Down
11 changes: 7 additions & 4 deletions trixi/experiment/pytorchexperiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,25 @@
from trixi.experiment.experiment import Experiment
from trixi.logger import CombinedLogger, PytorchExperimentLogger, PytorchVisdomLogger

from trixi.logger.tensorboard.pytorchtensorboardxlogger import PytorchTensorboardXLogger
from trixi.logger.tensorboard.pytorchtensorboardlogger import PytorchTensorboardLogger
from trixi.util import Config, ResultElement, ResultLogDict, SourcePacker, name_and_iter_to_filename
from trixi.util.config import update_from_sys_argv
from trixi.util.pytorchutils import set_seed
from trixi.util.util import is_picklable

logger_lookup_dict = dict(
visdom=PytorchVisdomLogger,
tensorboard=PytorchTensorboardXLogger,
tensorboard=PytorchTensorboardLogger,
)

try:
from trixi.logger import TelegramMessageLogger
from trixi.logger.message.slackmessagelogger import SlackMessageLogger

logger_lookup_dict["slack"] = SlackMessageLogger
except:
pass

try:
from trixi.logger import TelegramMessageLogger
logger_lookup_dict["telegram"] = TelegramMessageLogger
except:
pass
Expand Down
4 changes: 2 additions & 2 deletions trixi/logger/tensorboard/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from trixi.logger.tensorboard.tensorboardxlogger import TensorboardXLogger
from trixi.logger.tensorboard.tensorboardlogger import TensorboardLogger
try:
from trixi.logger.tensorboard.pytorchtensorboardxlogger import PytorchTensorboardXLogger
from trixi.logger.tensorboard.pytorchtensorboardlogger import PytorchTensorboardLogger
except ImportError as e:
import warnings
warnings.warn(ImportWarning("Could not import Pytorch related modules:\n%s"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import torch
from trixi.logger.tensorboard.tensorboardxlogger import TensorboardXLogger
from trixi.logger.tensorboard.tensorboardlogger import TensorboardLogger


class PytorchTensorboardXLogger(TensorboardXLogger):
class PytorchTensorboardLogger(TensorboardLogger):
"""Abstract interface for visual logger."""

def process_params(self, f, *args, **kwargs):
Expand All @@ -21,7 +21,7 @@ def process_params(self, f, *args, **kwargs):
return f(self, *args, **kwargs)

def __init__(self, *args, **kwargs):
super(PytorchTensorboardXLogger, self).__init__(*args, **kwargs)
super(PytorchTensorboardLogger, self).__init__(*args, **kwargs)

def plot_model_structure(self, model, input_size):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
from trixi.util.util import np_make_grid


class TensorboardXLogger(NumpySeabornPlotLogger):
"""Logger that uses tensorboardX to log to Tensorboard."""
class TensorboardLogger(NumpySeabornPlotLogger):
"""Logger that uses tensorboard to log to Tensorboard."""

def __init__(self, target_dir, *args, **kwargs):

super(TensorboardXLogger, self).__init__(*args, **kwargs)
super(TensorboardLogger, self).__init__(*args, **kwargs)

os.makedirs(target_dir, exist_ok=True)

Expand Down Expand Up @@ -280,5 +280,21 @@ def show_pr_curve(self, tensor, labels, name="pr-curve", counter=None, *args, **

self.writer.add_pr_curve(tag=name, labels=labels, predictions=tensor, global_step=self.val_dict["{}-pr-curve".format(name)])

def show_hparams(self, hparam_dict=None, metric_dict=None, counter=None, *args, **kwargs):
"""
Args:
hparam_dict: Each key-value pair in the dictionary is the name of the hyper parameter and it’s corresponding value.
metric_dict: Each key-value pair in the dictionary is the name of the metric and it’s corresponding value.
Note that the key used here should be unique in the tensorboard record. Otherwise the value you added by
show_value will be displayed in the hparam plugin. In most cases, this is unwanted.
"""
if counter is not None:
self.val_dict["{}-hparams"] = counter
else:
self.val_dict["{}-hparams"] += 1

self.writer.add_hparams(hparam_dict=hparam_dict, metric_dict=metric_dict)

def close(self):
self.writer.close()
4 changes: 2 additions & 2 deletions trixi/util/pytorchexperimentstub.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@

from trixi.logger import PytorchExperimentLogger, PytorchVisdomLogger, TelegramMessageLogger
from trixi.logger.message.slackmessagelogger import SlackMessageLogger
from trixi.logger.tensorboard import PytorchTensorboardXLogger
from trixi.logger.tensorboard import PytorchTensorboardLogger
from trixi.util import ResultElement, ResultLogDict, Config

logger_lookup_dict = dict(
visdom=PytorchVisdomLogger,
tensorboard=PytorchTensorboardXLogger,
tensorboard=PytorchTensorboardLogger,
telegram=TelegramMessageLogger,
slack=SlackMessageLogger,
)
Expand Down

0 comments on commit 607db99

Please sign in to comment.