Create TrainConfig

Now let's define a TrainConfig, which contains the training hyperparameters: the training stages, the loss function and the optimizer.

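In this tutorial we use the predefined stages TrainStage and ValidationStage. TrainStage iterates over a DataProducer and trains the model in train() mode. ValidationStage iterates over its DataProducer in the same way, but with the model in eval() mode, so it evaluates the model instead of training it.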
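The difference between the two stages can be illustrated with a stdlib-only sketch. This is not piepline's implementation: the classes LinearModel, SketchTrainStage and SketchValidationStage are hypothetical, a plain list of (x, y) pairs stands in for a DataProducer, and a toy model y = w * x with a train()/eval() flag stands in for a torch.nn.Module. The training stage updates the weight on every batch; the validation stage only computes the loss.

```python
# Hypothetical, stdlib-only sketch of the train/validation stage pattern.
# Not piepline's code: all class names below are made up for illustration.

class LinearModel:
    """Toy model y = w * x with a train/eval flag, mimicking torch.nn.Module."""

    def __init__(self, w=0.0):
        self.w = w
        self.training = True

    def train(self):
        self.training = True

    def eval(self):
        self.training = False

    def __call__(self, x):
        return self.w * x


class SketchTrainStage:
    def __init__(self, batches, lr=0.1):
        self.batches = batches  # stands in for a DataProducer
        self.lr = lr

    def run(self, model):
        model.train()  # training mode, then update the weight on every sample
        total = 0.0
        for x, y in self.batches:
            err = model(x) - y
            total += err * err
            # gradient of the squared error (w*x - y)^2 with respect to w is 2*err*x
            model.w -= self.lr * 2 * err * x
        return total / len(self.batches)


class SketchValidationStage:
    def __init__(self, batches):
        self.batches = batches  # stands in for a DataProducer

    def run(self, model):
        model.eval()  # evaluation mode: compute the loss, never update w
        total = 0.0
        for x, y in self.batches:
            err = model(x) - y
            total += err * err
        return total / len(self.batches)


model = LinearModel()
train_data = [(1.0, 2.0), (2.0, 4.0)]  # samples of y = 2 * x
val_data = [(3.0, 6.0)]
stages = [SketchTrainStage(train_data), SketchValidationStage(val_data)]
for epoch in range(20):
    losses = [stage.run(model) for stage in stages]  # [train loss, validation loss]
```

Running the stages in list order each epoch mirrors how train_stages is built below: first train, then validate on held-out data.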

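import torch

from piepline import TrainConfig, TrainStage, ValidationStage

# `model`, `train_dataset` and `validation_dataset` (DataProducer objects)
# are assumed to be defined in the previous steps of this tutorial

# define train stages: train on train_dataset, evaluate on validation_dataset
train_stages = [TrainStage(train_dataset), ValidationStage(validation_dataset)]

# NLLLoss expects log-probabilities, e.g. the output of log_softmax
loss = torch.nn.NLLLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.5)

# define TrainConfig
train_config = TrainConfig(train_stages, loss, optimizer)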
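To make the loss and optimizer choices concrete, here is a stdlib-only sketch of what torch.nn.NLLLoss computes with its default settings (no class weights, mean reduction) and of one step of torch.optim.SGD with momentum, assuming the defaults dampening=0, nesterov=False and no weight decay, where the update is v ← momentum·v + g, p ← p − lr·v. The helper names nll_loss, log_softmax and sgd_momentum_step are hypothetical and only illustrate the math; in the tutorial, PyTorch does this work.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of one row of logits."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]


def nll_loss(log_probs, targets):
    """Mean negative log-likelihood, like torch.nn.NLLLoss() with defaults.
    `log_probs` holds one row of log-probabilities per sample."""
    picked = [row[t] for row, t in zip(log_probs, targets)]
    return -sum(picked) / len(picked)


def sgd_momentum_step(params, grads, velocity, lr=1e-4, momentum=0.5):
    """One SGD-with-momentum step (dampening=0, nesterov=False, no weight decay).
    Updates `params` and `velocity` in place."""
    for i, g in enumerate(grads):
        velocity[i] = momentum * velocity[i] + g
        params[i] -= lr * velocity[i]


# Two samples with three classes; the targets are class 0 and class 2
log_probs = [log_softmax([2.0, 0.5, -1.0]), log_softmax([0.1, 0.2, 3.0])]
loss_value = nll_loss(log_probs, [0, 2])

# The same constant gradient applied twice: momentum makes the second step larger
params, velocity = [1.0], [0.0]
sgd_momentum_step(params, [2.0], velocity)  # v = 2.0, p = 1.0 - 1e-4 * 2.0
sgd_momentum_step(params, [2.0], velocity)  # v = 0.5 * 2.0 + 2.0 = 3.0
```

With momentum=0.5 a persistent gradient direction accumulates in the velocity, so repeated steps in the same direction grow, while the small lr=1e-4 keeps each individual step conservative.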