This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit dd3476f: simplify callback trainer (#3029)
* simplify

* simplify

* moving average as callback

* rename callbacks()

* simplify

* update comments
joelgrus committed Jul 9, 2019
1 parent 0663e0b commit dd3476f
Showing 16 changed files with 538 additions and 373 deletions.
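
For orientation: the changes below all revolve around AllenNLP's event-driven callback API, in which a Callback subclass hooks plain methods onto trainer events with the handle_event decorator, exactly as the GAN test callback further down does. A minimal illustrative sketch of that pattern (the registration name and counter are hypothetical, not code from this commit):

    from allennlp.training.callbacks import Callback, Events, handle_event

    @Callback.register("count-batches")  # hypothetical name, for illustration only
    class CountBatches(Callback):
        def __init__(self) -> None:
            self.batches_seen = 0

        @handle_event(Events.EPOCH_START)
        def reset(self, trainer) -> None:
            # Start each epoch with a fresh count.
            self.batches_seen = 0

        @handle_event(Events.BATCH_END)
        def increment(self, trainer) -> None:
            # Called once per batch; the trainer instance is passed to every handler.
            self.batches_seen += 1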
@@ -34,8 +34,7 @@
     "optimizer": {"type": "sgd", "lr": 0.01, "momentum": 0.9},
     "num_epochs": 2,
     "callbacks": [
-        "generate_training_batches",
-        {"type": "train_supervised", "grad_norm": 1.0},
+        {"type": "gradient_norm_and_clip", "grad_norm": 1.0},
         "checkpoint",
         {"type": "track_metrics", "patience": 500},
         "validate",
242 changes: 178 additions & 64 deletions allennlp/tests/training/callback_trainer_test.py

Large diffs are not rendered by default.

194 changes: 109 additions & 85 deletions allennlp/tests/training/gan_callback_trainer_test.py
@@ -11,7 +11,7 @@
 """
 # pylint: disable=bad-continuation,redefined-outer-name

-from typing import Iterable, List, Iterator
+from typing import Iterable, List, Iterator, Union, Optional
 import tempfile

 import torch
@@ -20,11 +20,15 @@
from allennlp.common.params import Params
from allennlp.common.testing import ModelTestCase
from allennlp.data import Instance
from allennlp.data.iterators import DataIterator
from allennlp.data.dataset_readers.dataset_reader import DatasetReader, _LazyInstances
from allennlp.data.fields import ArrayField, MetadataField
from allennlp.models import Model
from allennlp.training import util as training_util
from allennlp.training.callback_trainer import CallbackTrainer
from allennlp.training.callbacks import Callback, Events, handle_event
from allennlp.training.optimizers import Optimizer
from allennlp.training.trainer_base import TrainerBase

from allennlp.tests.training.gan_trainer_test import InputSampler

@@ -127,90 +131,23 @@ def read(self, file_path: str) -> Iterable[Instance]:
return _LazyInstances(self._one_epoch)


@Callback.register("train-gan")
@Callback.register("track-gan-metrics")
class TrainGan(Callback):
def __init__(self) -> None:
self.generator_loss = 0.0
self.discriminator_real_loss = 0.0
self.discriminator_fake_loss = 0.0
self.fake_mean = 0.0
self.fake_stdev = 0.0
self.count = 0
self.loss = None

@handle_event(Events.EPOCH_START)
def reset_loss(self, _trainer):
self.generator_loss = 0.0
self.discriminator_real_loss = 0.0
self.discriminator_fake_loss = 0.0
self.fake_mean = 0.0
self.fake_stdev = 0.0
self.count = 0

@handle_event(Events.BATCH_START)
def zero_grad(self, trainer):
# pylint: disable=no-self-use
trainer.optimizer.zero_grad()

@handle_event(Events.FORWARD)
def compute_loss(self, trainer):
batch, = trainer.batch_group
array = batch["array"]

# We should not have mixed batches:
if len(set(batch["stage"])) != 1:
raise ValueError("mixed batch")

stage = batch["stage"][0]
trainer.optimizer.stage = stage

if stage == "discriminator_real":
# Generate real data and expect the discriminator to predict 1.
output = trainer.model.discriminator(array, torch.ones(1))
self.loss = output["loss"]
self.discriminator_real_loss += self.loss.sum().item()
elif stage == "discriminator_fake":
# Generate fake data and expect the discriminator to predict 0.
fake_data = trainer.model.generator(array)
output = trainer.model.discriminator(fake_data["output"], torch.zeros(1))
self.loss = output["loss"]
self.discriminator_fake_loss += self.loss.sum().item()
elif stage == "generator":
# Generate fake data and try to fool the discriminator.
generated = trainer.model.generator(array, trainer.model.discriminator)
fake_data = generated["output"]
self.loss = generated["loss"]
self.generator_loss += self.loss.sum().item()

self.fake_mean += fake_data.mean()
self.fake_stdev += fake_data.std()
self.count += 1

@handle_event(Events.BACKWARD)
def backpropagate_errors(self, trainer):
self.loss.backward()
trainer.train_loss += self.loss.item()

@handle_event(Events.BACKWARD, priority=1000)
def optimize(self, trainer):
# pylint: disable=no-self-use
trainer.optimizer.step()

@handle_event(Events.BATCH_END)
def compute_metrics(self, trainer):
def compute_metrics(self, trainer: 'GanCallbackTrainer'):
# pylint: disable=no-self-use
trainer.train_metrics = {
"dfl": self.discriminator_fake_loss,
"drl": self.discriminator_real_loss,
"gl": self.discriminator_real_loss,
"mean": self.fake_mean / max(self.count, 1),
"stdev": self.fake_stdev / max(self.count, 1)
"dfl": trainer.discriminator_fake_loss,
"drl": trainer.discriminator_real_loss,
"gl": trainer.discriminator_real_loss,
"mean": trainer.fake_mean / max(trainer.count, 1),
"stdev": trainer.fake_stdev / max(trainer.count, 1)
}


def config(sample_size: int = 500,
batches_per_epoch: int = 40,
num_epochs: int = 50,
learning_rate: float = 0.05) -> Params:
num_epochs: int = 50) -> Params:
return Params({
"dataset_reader": {
"type": "gan-callback",
@@ -243,28 +180,116 @@ def config(sample_size: int = 500,
}
},
"trainer": {
"type": "callback",
"shuffle": False,
"type": "gan-callback",
"optimizer": {
"type": "gan",
"generator_optimizer": {
"type": "sgd",
"lr": learning_rate
"lr": 0.05
},
"discriminator_optimizer": {
"type": "sgd",
"lr": learning_rate
"lr": 0.05
}
},
"num_epochs": num_epochs,
"callbacks": [
"generate_training_batches",
"train-gan",
"track_metrics"
"track-gan-metrics",
{"type": "gradient_norm_and_clip", "grad_norm": 1.0}
]
}
})

@TrainerBase.register('gan-callback')
class GanCallbackTrainer(CallbackTrainer):
def __init__(self,
model: Gan,
train_dataset: Iterable[Instance],
iterator: DataIterator,
optimizer: GanOptimizer,
num_epochs: int = 20,
shuffle: bool = False,
serialization_dir: Optional[str] = None,
cuda_device: Union[int, List] = -1,
callbacks: List[Callback] = None) -> None:
super().__init__(model,
train_dataset,
iterator,
optimizer,
num_epochs,
shuffle,
serialization_dir,
cuda_device,
callbacks)
# Need to track our own metrics as well
self._reset_counters()

def _reset_counters(self) -> None:
self.generator_loss = 0.0
self.discriminator_real_loss = 0.0
self.discriminator_fake_loss = 0.0
self.fake_mean = 0.0
self.fake_stdev = 0.0
self.count = 0

def train_one_batch_group(self, batch_group):
# Each batch_group should have only one batch
batch, = batch_group
array = batch["array"]

# We should not have mixed batches:
if len(set(batch["stage"])) != 1:
raise ValueError("mixed batch")

stage = batch["stage"][0]
self.optimizer.stage = stage
self.optimizer.zero_grad()

if stage == "discriminator_real":
# Generate real data and expect the discriminator to predict 1.
output = self.model.discriminator(array, torch.ones(1))
loss = output["loss"]
self.discriminator_real_loss += loss.sum().item()
elif stage == "discriminator_fake":
# Generate fake data and expect the discriminator to predict 0.
fake_data = self.model.generator(array)
output = self.model.discriminator(fake_data["output"], torch.zeros(1))
loss = output["loss"]
self.discriminator_fake_loss += loss.sum().item()
elif stage == "generator":
# Generate fake data and try to fool the discriminator.
generated = self.model.generator(array, self.model.discriminator)
fake_data = generated["output"]
loss = generated["loss"]
self.generator_loss += loss.sum().item()

self.fake_mean += fake_data.mean()
self.fake_stdev += fake_data.std()
self.count += 1

self.train_loss += loss.sum().item()
loss.backward()

count = max(self.count, 1)
self.train_metrics = {
"gl": self.generator_loss / count,
"dfl": self.discriminator_fake_loss / count,
"drl": self.discriminator_real_loss / count,
"mean": self.fake_mean / count,
"stdev": self.fake_stdev / count
}

self.optimizer.step()

return training_util.description_from_metrics(self.train_metrics)

def train_one_epoch(self) -> None:
# Reset epoch counters
self._reset_counters()

# Will call `self.train_one_batch_group`
super().train_one_epoch()


class GanCallbackTrainerTest(ModelTestCase):
def test_gan_can_train(self):
@@ -278,7 +303,6 @@ def test_gan_can_train(self):
# python -m allennlp.tests.training.gan_callback_trainer_test
#
# pylint: disable=invalid-name
from allennlp.training.trainer_base import TrainerBase
serialization_dir = tempfile.mkdtemp()

params = config()
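The body of test_gan_can_train is outside the rendered hunks; the lines above only show its setup (a temporary serialization directory, the config() Params, and the TrainerBase import moving to the top of the module). A sketch of how such a test would plausibly drive the registered "gan-callback" trainer, assuming TrainerBase.from_params accepts the full experiment Params plus a serialization directory and that train() returns a metrics dict:

    import tempfile

    from allennlp.training.trainer_base import TrainerBase

    serialization_dir = tempfile.mkdtemp()
    params = config()  # the Params defined earlier in this test module

    # from_params reads params["trainer"]["type"] == "gan-callback" and so
    # constructs the GanCallbackTrainer registered above (assumed call shape).
    trainer = TrainerBase.from_params(params, serialization_dir)
    metrics = trainer.train()
    print(metrics)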
