
Implemented model iteration averaging to reduce model variance #901

Merged
merged 16 commits into from Jul 17, 2020
3 changes: 2 additions & 1 deletion src/gluonts/mx/trainer/__init__.py
@@ -14,9 +14,10 @@
# Relative imports
from . import learning_rate_scheduler as lrs
from . import model_averaging
from . import model_iteration_averaging
from ._base import Trainer

__all__ = ["lrs", "Trainer", "model_averaging"]
__all__ = ["lrs", "Trainer", "model_averaging", "model_iteration_averaging"]

# fix Sphinx issues, see https://bit.ly/2K2eptM
for item in __all__:
66 changes: 61 additions & 5 deletions src/gluonts/mx/trainer/_base.py
@@ -40,6 +40,14 @@
save_epoch_info,
)

# iteration averaging
from .model_iteration_averaging import (
IterationAveragingStrategy,
NTA_V1,
NTA_V2,
Alpha_Suffix,
)

logger = logging.getLogger("gluonts").getChild("trainer")


@@ -113,7 +121,9 @@ def __init__(
weight_decay: float = 1e-8,
init: Union[str, mx.initializer.Initializer] = "xavier",
hybridize: bool = True,
avg_strategy: AveragingStrategy = SelectNBestMean(num_models=1),
avg_strategy: Union[
AveragingStrategy, IterationAveragingStrategy
] = SelectNBestMean(num_models=1),
) -> None:

assert (
@@ -228,6 +238,12 @@ def loop(

epoch_loss = mx.metric.Loss()

# use averaged model for validation
if not is_training and isinstance(
self.avg_strategy, IterationAveragingStrategy
):
self.avg_strategy.load_averaged_model(net)

with tqdm(batch_iter) as it:
for batch_no, data_entry in enumerate(it, start=1):
if self.halt:
@@ -251,6 +267,13 @@
loss.backward()
trainer.step(batch_size)

# iteration averaging in training
if isinstance(
self.avg_strategy,
IterationAveragingStrategy,
):
self.avg_strategy.apply(net)

epoch_loss.update(None, preds=loss)
lv = loss_value(epoch_loss)

@@ -289,6 +312,13 @@ def loop(
("" if is_training else "validation_") + "epoch_loss",
lv,
)

if not is_training and isinstance(
self.avg_strategy, IterationAveragingStrategy
):
# bring back the cached model
self.avg_strategy.load_cached_model(net)

return epoch_loss

for epoch_no in range(self.epochs):
@@ -307,6 +337,25 @@ def loop(
epoch_no, validation_iter, is_training=False
)

# update average trigger
if isinstance(
self.avg_strategy, IterationAveragingStrategy
):
if isinstance(self.avg_strategy, Alpha_Suffix):
# alpha suffix
self.avg_strategy.update_average_trigger(
epoch_no + 1
)
elif isinstance(self.avg_strategy, (NTA_V1, NTA_V2)):
# NTA
self.avg_strategy.update_average_trigger(
loss_value(epoch_loss)
)
else:
raise NotImplementedError
# once triggered, update the average immediately
self.avg_strategy.apply(net)

should_continue = lr_scheduler.step(loss_value(epoch_loss))
if not should_continue:
logger.info("Stopping training")
@@ -344,10 +393,17 @@ def loop(
best_epoch_info["params_path"], self.ctx
)

logging.info("Computing averaged parameters.")
averaged_params_path = self.avg_strategy.apply(gluonts_temp)
if isinstance(self.avg_strategy, AveragingStrategy):
logging.info("Computing averaged parameters.")
averaged_params_path = self.avg_strategy.apply(
gluonts_temp
)

logging.info("Loading averaged parameters.")
net.load_parameters(averaged_params_path, self.ctx)

logging.info("Loading averaged parameters.")
net.load_parameters(averaged_params_path, self.ctx)
if isinstance(self.avg_strategy, IterationAveragingStrategy):
logging.info("Loading averaged parameters.")
self.avg_strategy.load_averaged_model(net)

logger.info("End model training")
272 changes: 272 additions & 0 deletions src/gluonts/mx/trainer/model_iteration_averaging.py
@@ -0,0 +1,272 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

# Standard library imports
from typing import Any, Dict, Optional, List

import mxnet.gluon.nn as nn

# First-party imports
from gluonts.core.component import validated


class IterationAveragingStrategy:
@validated()
def __init__(self):
r"""
Parameters
----------
averaged_model
Dict that maintains the averaged model parameters.
cached_model
Temporarily save the current model, so that the averaged model can be used for validation.
average_counter
The number of models accumulated in the average.
averaging_started
Indicate whether the model averaging has started.
"""

self.averaged_model = None
self.cached_model = None
self.average_counter = 0
self.averaging_started = False

def update_average_trigger(self, average_trigger: Any):
r"""
Parameters
----------
average_trigger
The criteria to trigger averaging.

Returns
-------
"""
# implement a naive strategy: use average_trigger as a boolean
self.averaging_started = average_trigger
# raise NotImplementedError()

def apply(self, model: nn.HybridBlock) -> Optional[Dict]:
r"""
Parameters
----------
model
The model of the current iteration.

Returns
-------
The averaged model, None if the averaging hasn't started.
"""

if self.averaging_started:
self.update_average(model)

return self.averaged_model

def update_average(self, model: nn.HybridBlock):
r"""
Parameters
----------
model
The model to update the average.
"""
self.average_counter += 1
if self.averaged_model is None:
self.averaged_model = {
k: v.list_data()[0].copy()
for k, v in model.collect_params().items()
}
else:
alpha = 1.0 / self.average_counter
# moving average
for name, param_avg in self.averaged_model.items():
param_avg[:] += alpha * (
model.collect_params()[name].list_data()[0] - param_avg
)
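
The in-place update above with alpha = 1 / average_counter is the standard incremental mean, so after k updates averaged_model holds the arithmetic mean of the k iterates. A small standalone sketch of the same recurrence (hypothetical values, plain NDArrays instead of network parameters):

import mxnet as mx

iterates = [mx.nd.array([1.0]), mx.nd.array([2.0]), mx.nd.array([6.0])]
avg, count = None, 0
for x in iterates:
    count += 1
    if avg is None:
        avg = x.copy()
    else:
        # same recurrence as update_average: avg += (x - avg) / count
        avg[:] += (1.0 / count) * (x - avg)
print(avg.asnumpy())  # [3.], the arithmetic mean of 1, 2 and 6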

def load_averaged_model(self, model: nn.HybridBlock):
r"""
When validating or evaluating the averaged model partway through training,
call load_averaged_model first to overwrite the current model with the averaged one,
run the evaluation, and then call load_cached_model to restore the current model.

Parameters
----------
model
The model that the averaged model is loaded to.
"""
if self.averaged_model is not None:
# cache the current model
if self.cached_model is None:
self.cached_model = {
k: v.list_data()[0].copy()
for k, v in model.collect_params().items()
}
else:
for name, param_cached in self.cached_model.items():
param_cached[:] = model.collect_params()[name].list_data()[
0
]
# load the averaged model
for name, param_avg in self.averaged_model.items():
model.collect_params()[name].set_data(param_avg)

def load_cached_model(self, model: nn.HybridBlock):
r"""
Parameters
----------
model
The model that the cached model is loaded to.
"""
if self.cached_model is not None:
# load the cached model
for name, param_cached in self.cached_model.items():
model.collect_params()[name].set_data(param_cached)
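
A short sketch of the swap protocol described in load_averaged_model above, as the Trainer performs it around validation (strategy, net, and the evaluate helper are hypothetical placeholders):

strategy.load_averaged_model(net)  # evaluate with the averaged parameters
val_loss = evaluate(net)           # hypothetical validation routine
strategy.load_cached_model(net)    # restore the current parameters before training resumes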


class NTA_V1(IterationAveragingStrategy):
val_logs: List[Any]

@validated()
def __init__(self, n: int = 5, maximize: bool = False):
r"""
Depending on the choice of metric, the user may want to minimize or maximize it.
Set maximize = True to maximize the validation metric, otherwise it is minimized.

Parameters
----------
n
The non-monotone interval.
maximize
Whether to maximize or minimize the validation metric.
val_logs
Historical validation metrics.
"""

super().__init__()

self.n = n
self.maximize = maximize
self.val_logs = []

def update_average_trigger(self, average_trigger: Any):
r"""
Parameters
----------
average_trigger
The criteria to trigger averaging, evaluation metrics in this case.

Returns
-------
"""

# implement NTA (salesforce)
# this is the implementation from the ICLR (and Salesforce GitHub) version, which differs from the arXiv (and GluonNLP) version
if not self.averaging_started and self.n > 0:
if self.maximize:
if len(self.val_logs) > self.n and average_trigger < max(
self.val_logs[: -self.n]
):
self.averaging_started = True
else:
if len(self.val_logs) > self.n and average_trigger > min(
self.val_logs[: -self.n]
):
self.averaging_started = True
self.val_logs.append(average_trigger)


class NTA_V2(IterationAveragingStrategy):
val_logs: List[Any]

@validated()
def __init__(self, n: int = 5, maximize: bool = False):
r"""
Parameters
----------
n
The non-monotone interval
maximize
Whether to maximize or minimize the validation metric
val_logs
Historical validation metrics
"""

super().__init__()

self.n = n
self.maximize = maximize
self.val_logs = []

def update_average_trigger(self, average_trigger: Any):
r"""
Parameters
----------
average_trigger
The criteria to trigger averaging, evaluation metrics in this case.

Returns
-------
"""

# implement NTA (gluonnlp)
Contributor comment: Same for this comment: move it to the top of the class with proper citations and links. If we merge the two classes, both versions should be mentioned.

if not self.averaging_started and self.n > 0:
if self.maximize:
# in gluonnlp awd-lstm, "len(self.val_logs) > self.n" is used, but I think it should be ">=" instead
if len(self.val_logs) >= self.n and average_trigger < max(
self.val_logs[-self.n :]
):
self.averaging_started = True
else:
if len(self.val_logs) >= self.n and average_trigger > min(
self.val_logs[-self.n :]
):
self.averaging_started = True
self.val_logs.append(average_trigger)
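
To make the two windows concrete, a hedged sketch with a hypothetical sequence of validation losses (minimization, n = 2): NTA_V1 compares the new value against the minimum of all but the last n logged values, while NTA_V2 compares it against the minimum of the last n logged values.

from gluonts.mx.trainer.model_iteration_averaging import NTA_V1, NTA_V2

losses = [10.0, 9.0, 8.0, 7.0, 9.0]  # hypothetical validation losses, one per epoch
v1, v2 = NTA_V1(n=2), NTA_V2(n=2)
for loss in losses:
    v1.update_average_trigger(loss)
    v2.update_average_trigger(loss)
print(v1.averaging_started)  # False: 9.0 is not worse than min(val_logs[:-2]) == 9.0
print(v2.averaging_started)  # True: 9.0 is worse than min(val_logs[-2:]) == 7.0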


class Alpha_Suffix(IterationAveragingStrategy):
@validated()
def __init__(self, epochs: int, alpha: float = 0.75):
r"""
Apply iteration averaging over the last (epochs * alpha) epochs.

Parameters
----------
epochs
The total number of epochs.
alpha
Proportion of averaging.
alpha_suffix
The epoch where iteration averaging starts.
"""

super().__init__()

assert alpha >= 0 and alpha <= 1
Contributor comment: I think assert 0 <= alpha <= 1 will do the job as well.

self.alpha_suffix = epochs * (1.0 - alpha)

def update_average_trigger(self, average_trigger: Any):
r"""
Parameters
----------
average_trigger
The current number of epoch.

Returns
-------
"""

if not self.averaging_started:
if average_trigger >= self.alpha_suffix:
self.averaging_started = True
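
As a quick check of the trigger arithmetic (hypothetical values): with epochs = 100 and alpha = 0.75, alpha_suffix is 100 * (1 - 0.75) = 25, so averaging starts once the Trainer passes in an epoch number of 25 or higher, i.e. roughly the last 75% of training.

from gluonts.mx.trainer.model_iteration_averaging import Alpha_Suffix

strategy = Alpha_Suffix(epochs=100, alpha=0.75)
print(strategy.alpha_suffix)        # 25.0
strategy.update_average_trigger(24)
print(strategy.averaging_started)   # False
strategy.update_average_trigger(25)
print(strategy.averaging_started)   # True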