Skip to content
Branch: master
Find file Copy path
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
43 lines (32 sloc) 1.43 KB
from typing import TYPE_CHECKING
from allennlp.common.params import Params
from allennlp.models.model import Model
from import Callback, handle_event
from import Events
from import MovingAverage
from import CallbackTrainer
class UpdateMovingAverage(Callback):
    """
    Callback that updates a moving average of the model's trainable parameters
    at the end of every batch.

    Parameters
    ----------
    moving_average : ``MovingAverage``
        The MovingAverage object to update.
    """

    def __init__(self, moving_average: MovingAverage) -> None:
        # The averaging state lives in ``moving_average``; this callback only drives it.
        self.moving_average = moving_average

    @handle_event(Events.BATCH_END, priority=-1000)
    def apply_moving_average(self, trainer: "CallbackTrainer") -> None:
        """
        Update the moving average after a batch finishes.

        Registered as a handler for ``Events.BATCH_END`` with ``priority=-1000``.

        Parameters
        ----------
        trainer : ``CallbackTrainer``
            The trainer that fired the event.
        """
        # NOTE(review): this method's body was missing from the scraped source.
        # It is reconstructed assuming AllenNLP's ``MovingAverage.apply(num_updates)``
        # API and the ``CallbackTrainer.batch_num_total`` counter -- confirm
        # against the upstream file.
        self.moving_average.apply(trainer.batch_num_total)

    @classmethod
    def from_params(cls, params: Params, model: Model) -> "UpdateMovingAverage":  # type: ignore
        """
        Construct the callback from configuration.

        Pops the ``"moving_average"`` sub-config from ``params`` and builds a
        ``MovingAverage`` over the ``[name, parameter]`` pairs of ``model``.

        Parameters
        ----------
        params : ``Params``
            Configuration containing a ``"moving_average"`` entry (popped here).
        model : ``Model``
            The model whose parameters will be averaged.

        Returns
        -------
        An instance of ``cls`` wrapping the newly built ``MovingAverage``.
        """
        moving_average_params = params.pop("moving_average")
        # Only parameters with requires_grad=True are tracked; the rest are skipped.
        model_parameters = [
            [name, param] for name, param in model.named_parameters() if param.requires_grad
        ]
        moving_average = MovingAverage.from_params(
            params=moving_average_params, parameters=model_parameters
        )
        # ``cls`` rather than the hard-coded class name so subclasses build themselves.
        return cls(moving_average)
You can’t perform that action at this time.