Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Infrastructure for logging to wandb #758

Merged
merged 4 commits into from Dec 10, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
21 changes: 21 additions & 0 deletions doc/run_experiment.rst
Expand Up @@ -1400,6 +1400,27 @@ Whether to save the predictions from the individual estimators underlying a
:ref:`predictions <predictions>` must be set.
Defaults to ``False``.

wandb_credentials *(Optional)*
""""""""""""""""""""""""""""""
To enable logging metrics and artifacts to `Weights & Biases <https://wandb.ai/>`__, specify
a dictionary as follows:

.. code-block:: python

{'wandb_entity': 'your_entity_name', 'wandb_project': 'your_project_name'}


``wandb_entity`` can be a user name or the name of a team or organization.
``wandb_project`` is the name of the project to which this experiment will be logged.
If a project by this name does not already exist, it will be created.

.. important::
   1. Both `wandb_entity` and `wandb_project` must be specified. If either of them is missing, logging to W&B will not be enabled.
2. Before using Weights & Biases for the first time, users should log in and provide their API key as described in
`W&B Quickstart guidelines <https://docs.wandb.ai/quickstart#2-log-in-to-wb>`__.
   3. Note that when using W&B logging, the skll run may take significantly longer due to the network traffic being
      sent to W&B.


.. _run_experiment:

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Expand Up @@ -12,3 +12,4 @@ seaborn
sphinx_rtd_theme==1.2.0
tabulate
typing_extensions
wandb
18 changes: 18 additions & 0 deletions skll/config/__init__.py
Expand Up @@ -97,6 +97,7 @@ def __init__(self) -> None:
"train_file": "",
"use_folds_file_for_grid_search": "True",
"save_votes": "False",
"wandb_credentials": "{}",
}

correct_section_mapping = {
Expand Down Expand Up @@ -144,6 +145,7 @@ def __init__(self) -> None:
"train_file": "Input",
"use_folds_file_for_grid_search": "Tuning",
"save_votes": "Output",
"wandb_credentials": "Output",
}

# make sure that the defaults dictionary and the
Expand Down Expand Up @@ -324,6 +326,7 @@ def parse_config_file(
List[Number],
List[str],
bool,
Optional[Dict[str, str]],
]:
"""
Parse a SKLL experiment configuration file with the given path.
Expand Down Expand Up @@ -508,6 +511,9 @@ def parse_config_file(
save_votes : bool
Whether to save the individual predictions from voting learners.

wandb_credentials : Dict[str,str]
A dictionary holding W&B entity and project name, if logging to W&B is enabled.

Raises
------
IOError
Expand Down Expand Up @@ -857,6 +863,17 @@ def parse_config_file(
# learner estimators?
save_votes = config.getboolean("Output", "save_votes")

# get wandb credentials if specified. Both entity and project must
# be specified to enable W&B logging
wandb_credentials = yaml.load(fix_json(config.get("Output", "wandb_credentials")))
if wandb_credentials:
if "wandb_entity" not in wandb_credentials or "wandb_project" not in wandb_credentials:
logger.warning(
"Logging to W&B is not enabled! "
"Please specify both wandb_entity and wandb_project"
tamarl08 marked this conversation as resolved.
Show resolved Hide resolved
)
wandb_credentials = {}

#####################
# 4. Tuning section #
#####################
Expand Down Expand Up @@ -1070,6 +1087,7 @@ def parse_config_file(
learning_curve_train_sizes,
output_metrics,
save_votes,
wandb_credentials,
)


Expand Down
6 changes: 6 additions & 0 deletions skll/experiments/__init__.py
Expand Up @@ -36,6 +36,7 @@
from skll.utils.logging import close_and_remove_logger_handlers, get_skll_logger
from skll.version import __version__

from ..utils.wandb import WandbLogger
from .input import load_featureset
from .output import (
_print_fancy_output,
Expand Down Expand Up @@ -716,12 +717,17 @@ def run_configuration(
learning_curve_train_sizes,
output_metrics,
save_votes,
wandb_credentials,
) = parse_config_file(config_file, log_level=log_level)

# get the main experiment logger that will already have been
# created by the configuration parser so we don't need anything
# except the name `experiment`.
logger = get_skll_logger("experiment")
wandb_logger = WandbLogger(wandb_credentials)
wandb_logger.log_configuration(
{"experiment_name": experiment_name, "task": task, "learners": learners}
)

# Check if we have gridmap
if not local and not _HAVE_GRIDMAP:
Expand Down
1 change: 1 addition & 0 deletions skll/utils/testing.py
Expand Up @@ -291,6 +291,7 @@ def fill_in_config_options(
"save_cv_folds",
"save_cv_models",
"save_votes",
"wandb_credentials",
],
}

Expand Down
40 changes: 40 additions & 0 deletions skll/utils/wandb.py
@@ -0,0 +1,40 @@
"""
Utility classes and functions for logging to Weights & Biases.

:author: Tamar Lavee (tlavee@ets.org)
"""
from typing import Any, Dict, Optional

import wandb


class WandbLogger:
    """Interface for Weights and Biases logging."""

    def __init__(self, wandb_credentials: Optional[Dict[str, str]]):
        """
        Start a W&B run if credentials were provided.

        Parameters
        ----------
        wandb_credentials : Optional[Dict[str, str]]
            A dictionary containing the W&B entity and project names that will be
            used to initialize the wandb run. If ``None`` (or empty), logging to
            W&B will not be performed and ``self.wandb_run`` stays ``None``.
        """
        if not wandb_credentials:
            # no credentials -> W&B logging disabled; keep a sentinel so that
            # the other methods can cheaply no-op
            self.wandb_run = None
            return
        self.wandb_run = wandb.init(
            project=wandb_credentials["wandb_project"],
            entity=wandb_credentials["wandb_entity"],
        )

    def log_configuration(self, conf_dict: Dict[str, Any]):
        """
        Log a configuration dictionary to W&B if logging to W&B is enabled.

        Parameters
        ----------
        conf_dict : Dict[str, Any]
            A dictionary mapping configuration field names to their values.
        """
        if self.wandb_run:
            self.wandb_run.config.update(conf_dict)
94 changes: 89 additions & 5 deletions tests/test_input.py
Expand Up @@ -837,7 +837,7 @@ def test_config_parsing_no_grid_objectives_needed_for_learning_curve(self):

configuration = parse_config_file(config_path)
do_grid_search, grid_objectives = configuration[14:16]
output_metrics = configuration[-2]
output_metrics = configuration[45]

self.assertEqual(do_grid_search, False)
self.assertEqual(grid_objectives, [])
Expand Down Expand Up @@ -1471,7 +1471,7 @@ def test_learning_curve_metrics_and_no_objectives(self):
config_template_path, values_to_fill_dict, "learning_curve_metrics_and_no_objectives"
)
configuration = parse_config_file(config_path)
output_metrics = configuration[-2]
output_metrics = configuration[45]

self.assertEqual(output_metrics, ["accuracy", "unweighted_kappa"])

Expand All @@ -1495,7 +1495,7 @@ def test_learning_curve_metrics(self):

configuration = parse_config_file(config_path)
grid_objectives = configuration[15]
output_metrics = configuration[-2]
output_metrics = configuration[45]

self.assertEqual(output_metrics, ["accuracy"])
self.assertEqual(grid_objectives, [])
Expand Down Expand Up @@ -1737,7 +1737,7 @@ def test_config_parsing_default_save_votes_value(self):
)

configuration = parse_config_file(config_path)
save_votes = configuration[-1]
save_votes = configuration[46]

self.assertEqual(save_votes, False)

Expand All @@ -1764,7 +1764,7 @@ def test_config_parsing_set_save_votes_value(self):
)

configuration = parse_config_file(config_path)
save_votes = configuration[-1]
save_votes = configuration[46]

self.assertEqual(save_votes, True)

Expand All @@ -1790,3 +1790,87 @@ def test_config_parsing_use_log_instead_of_logs(self):

with self.assertRaises(KeyError):
parse_config_file(config_path)

def test_config_parsing_default_wandb_credentials(self):
    """Test that config parsing works as expected for default `wandb_credentials` values."""
    values_to_fill_dict = {
        "experiment_name": "config_parsing",
        "task": "evaluate",
        "train_directory": train_dir,
        "test_directory": test_dir,
        "featuresets": "[['f1', 'f2', 'f3']]",
        "fixed_parameters": '[{"estimator_names": '
        '["SVC", "LogisticRegression", "MultinomialNB"]}]',
        "learners": "['VotingClassifier']",
        "objectives": "['accuracy']",
        "logs": output_dir,
        "results": output_dir,
    }

    config_template_path = config_dir / "test_config_parsing.template.cfg"

    # use a label unique to this test so its generated config file cannot
    # clash with the `save_votes` tests (the previous label was a
    # copy-pasted "default_value_save_votes")
    config_path = fill_in_config_options(
        config_template_path, values_to_fill_dict, "default_value_wandb_credentials"
    )

    # `wandb_credentials` is the 48th value returned by `parse_config_file`
    configuration = parse_config_file(config_path)
    wandb_credentials = configuration[47]

    # no credentials specified -> W&B logging disabled, empty dict
    self.assertEqual({}, wandb_credentials)

def test_config_parsing_set_wandb_values(self):
    """Test that config parsing works as expected for given `wandb_credentials`."""
    values_to_fill_dict = {
        "experiment_name": "config_parsing",
        "task": "evaluate",
        "train_directory": train_dir,
        "test_directory": test_dir,
        "featuresets": "[['f1', 'f2', 'f3']]",
        "fixed_parameters": '[{"estimator_names": '
        '["SVC", "LogisticRegression", "MultinomialNB"]}]',
        "learners": "['VotingClassifier']",
        "objectives": "['accuracy']",
        "logs": output_dir,
        "results": output_dir,
        "wandb_credentials": '{"wandb_entity": "wandb_entity",'
        ' "wandb_project": "wandb_project"}',
    }

    config_template_path = config_dir / "test_config_parsing.template.cfg"

    # use a label unique to this test so its generated config file cannot
    # clash with other tests (the previous label was a copy-pasted
    # "default_value_save_votes")
    config_path = fill_in_config_options(
        config_template_path, values_to_fill_dict, "set_wandb_credentials"
    )

    # `wandb_credentials` is the 48th value returned by `parse_config_file`
    configuration = parse_config_file(config_path)
    wandb_credentials = configuration[47]

    # both keys were supplied, so they should be passed through unchanged
    self.assertEqual(wandb_credentials["wandb_entity"], "wandb_entity")
    self.assertEqual(wandb_credentials["wandb_project"], "wandb_project")

def test_config_parsing_set_wandb_missing_value(self):
    """Test that config parsing works as expected when `wandb_credentials` is missing a value."""
    values_to_fill_dict = {
        "experiment_name": "config_parsing",
        "task": "evaluate",
        "train_directory": train_dir,
        "test_directory": test_dir,
        "featuresets": "[['f1', 'f2', 'f3']]",
        "fixed_parameters": '[{"estimator_names": ["SVC", "LogisticRegression", "MultinomialNB"]}]',
        "learners": "['VotingClassifier']",
        "objectives": "['accuracy']",
        "logs": output_dir,
        "results": output_dir,
        # deliberately omit "wandb_entity" to trigger the fallback
        "wandb_credentials": '{"wandb_project": "wandb_project"}',
    }

    config_template_path = config_dir / "test_config_parsing.template.cfg"

    # use a label unique to this test so its generated config file cannot
    # clash with other tests (the previous label was a copy-pasted
    # "default_value_save_votes")
    config_path = fill_in_config_options(
        config_template_path, values_to_fill_dict, "missing_wandb_credentials"
    )

    # `wandb_credentials` is the 48th value returned by `parse_config_file`
    configuration = parse_config_file(config_path)
    wandb_credentials = configuration[47]

    # with only one of the two required keys, W&B logging is disabled
    self.assertEqual({}, wandb_credentials)
39 changes: 39 additions & 0 deletions tests/test_wandb.py
@@ -0,0 +1,39 @@
"""
Tests for wandb logging utility class.

:author: Tamar Lavee (tlavee@ets.org)
"""

import unittest
from unittest.mock import Mock, patch

from skll.utils.wandb import WandbLogger


class TestWandb(unittest.TestCase):
    """Test cases for wandb interface class."""

    def test_init_wandb_enabled(self):
        """Test initialization with wandb credentials specified."""
        mock_wandb_run = Mock()
        with patch("skll.utils.wandb.wandb.init", return_value=mock_wandb_run) as mock_wandb_init:
            WandbLogger({"wandb_entity": "wandb_entity", "wandb_project": "wandb_project"})
            mock_wandb_init.assert_called_with(project="wandb_project", entity="wandb_entity")

    def test_init_wandb_disabled(self):
        """Test initialization with no wandb credentials."""
        # no return_value mock needed since `wandb.init` must never be called
        with patch("skll.utils.wandb.wandb.init") as mock_wandb_init:
            WandbLogger({})
            mock_wandb_init.assert_not_called()

    def test_update_config(self):
        """Test that `log_configuration` updates the active run's config."""
        mock_wandb_run = Mock()
        with patch("skll.utils.wandb.wandb.init", return_value=mock_wandb_run) as mock_wandb_init:
            wandb_logger = WandbLogger(
                {"wandb_entity": "wandb_entity", "wandb_project": "wandb_project"}
            )
            wandb_logger.log_configuration({"task": "train"})
            mock_wandb_init.assert_called_with(project="wandb_project", entity="wandb_entity")
            mock_wandb_run.config.update.assert_called_with({"task": "train"})