Consolidate DeepNPTSEstimator #2496

Merged: 7 commits, Jan 20, 2023
Changes from 6 commits
4 changes: 2 additions & 2 deletions src/gluonts/torch/model/deep_npts/__init__.py
@@ -14,7 +14,7 @@
from ._estimator import DeepNPTSEstimator
from ._network import (
DeepNPTSNetwork,
DeepNPTSMultiStepPredictor,
DeepNPTSMultiStepNetwork,
DeepNPTSNetworkDiscrete,
DeepNPTSNetworkSmooth,
)
@@ -23,7 +23,7 @@
__all__ = [
"DeepNPTSEstimator",
"DeepNPTSNetwork",
"DeepNPTSMultiStepPredictor",
"DeepNPTSMultiStepNetwork",
"DeepNPTSNetworkDiscrete",
"DeepNPTSNetworkSmooth",
]
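Since this file only renames the class, downstream imports are the one thing callers must update. A minimal sketch of the change from a caller's side (hypothetical user code, not part of this PR):

# Old import, removed by this PR:
# from gluonts.torch.model.deep_npts import DeepNPTSMultiStepPredictor
# New import after this PR:
from gluonts.torch.model.deep_npts import DeepNPTSMultiStepNetwork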
87 changes: 42 additions & 45 deletions src/gluonts/torch/model/deep_npts/_estimator.py
@@ -12,7 +12,7 @@
# permissions and limitations under the License.

from copy import deepcopy
from typing import List, Optional, Callable, Union
from typing import Dict, List, Optional, Callable, Union
from functools import partial

import torch
@@ -44,14 +44,14 @@
from ._network import (
DeepNPTSNetwork,
DeepNPTSNetworkDiscrete,
DeepNPTSMultiStepPredictor,
DeepNPTSMultiStepNetwork,
)
from .scaling import (
min_max_scaling,
standard_normal_scaling,
)

LOSS_SCALING_MAP = {
LOSS_SCALING_MAP: Dict[str, Callable] = {
"min_max_scaling": partial(min_max_scaling, dim=1, keepdim=False),
"standard_normal_scaling": partial(
standard_normal_scaling, dim=1, keepdim=False
@@ -129,8 +129,14 @@ def __init__(
cardinality: Optional[List[int]] = None,
embedding_dimension: Optional[List[int]] = None,
input_scaling: Optional[Union[Callable, str]] = None,
dropout_rate: Optional[float] = None,
dropout_rate: float = 0.0,
network_type: DeepNPTSNetwork = DeepNPTSNetworkDiscrete,
epochs: int = 100,
lr: float = 1e-5,
batch_size: int = 32,
num_batches_per_epoch: int = 100,
cache_data: bool = False,
loss_scaling: Optional[Union[str, Callable]] = None,
):
assert (cardinality is not None) == use_feat_static_cat, (
"You should set `cardinality` if and only if"
@@ -203,6 +209,17 @@
self.batch_norm = batch_norm
self.network_type = network_type

self.epochs = epochs
self.lr = lr
self.batch_size = batch_size
self.num_batches_per_epoch = num_batches_per_epoch
self.cache_data = cache_data
self.loss_scaling: Optional[Callable] = (
LOSS_SCALING_MAP[loss_scaling]
if isinstance(loss_scaling, str)
else loss_scaling
)

def input_transform(self) -> Transformation:
# Note: Any change here should be reflected in the
# `self.num_time_features` field as well.
@@ -307,30 +324,19 @@ def training_data_loader(

def train_model(
self,
train_dataset: Dataset,
epochs: int,
lr: float = 1e-5,
batch_size: int = 32,
num_batches_per_epoch: int = 100,
training_data: Dataset,
cache_data: bool = False,
loss_scaling: Optional[Union[Callable, str]] = None,
) -> DeepNPTSNetwork:
loss_scaling = (
LOSS_SCALING_MAP[loss_scaling]
if isinstance(loss_scaling, str)
else loss_scaling
)

transformed_dataset = TransformedDataset(
train_dataset, self.input_transform()
training_data, self.input_transform()
)

data_loader = self.training_data_loader(
transformed_dataset
if not cache_data
else Cached(transformed_dataset),
batch_size=batch_size,
num_batches_per_epoch=num_batches_per_epoch,
batch_size=self.batch_size,
num_batches_per_epoch=self.num_batches_per_epoch,
)

net = self.network_type(
@@ -344,21 +350,22 @@ def train_model(
batch_norm=self.batch_norm,
)

optimizer = torch.optim.Adam(net.parameters(), lr=lr)
optimizer = torch.optim.Adam(net.parameters(), lr=self.lr)

best_loss = float("inf")
for epoch_num in range(epochs):
for epoch_num in range(self.epochs):
sum_epoch_loss = 0.0
for batch_no, batch in enumerate(data_loader, start=1):
x = {k: batch[k] for k in self.features_fields}
y = batch[self.target_field]

predicted_distribution = net(**x)
scale = (
loss_scaling(x[self.past_target_field])[1]
if loss_scaling
else 1
)

if self.loss_scaling is not None:
scale = self.loss_scaling(x[self.past_target_field])[1]
else:
scale = 1.0

loss = (-predicted_distribution.log_prob(y) / scale).mean()

optimizer.zero_grad()
@@ -373,49 +380,39 @@

print(
f"Loss for epoch {epoch_num}: "
f"{sum_epoch_loss / num_batches_per_epoch}"
f"{sum_epoch_loss / self.num_batches_per_epoch}"
)

print(f"Best loss: {best_loss / num_batches_per_epoch}")
print(f"Best loss: {best_loss / self.num_batches_per_epoch}")

return best_net

def get_predictor(
self, net: torch.nn.Module, batch_size: int, device=torch.device("cpu")
self, net: torch.nn.Module, device=torch.device("cpu")
) -> PyTorchPredictor:
pred_net_multi_step = DeepNPTSMultiStepPredictor(
pred_net_multi_step = DeepNPTSMultiStepNetwork(
net=net, prediction_length=self.prediction_length
)

return PyTorchPredictor(
prediction_net=pred_net_multi_step,
prediction_length=self.prediction_length,
input_names=self.features_fields + self.prediction_features_field,
batch_size=batch_size,
batch_size=self.batch_size,
input_transform=self.input_transform()
+ self.instance_splitter(TestSplitSampler(), is_train=False),
device=device,
)

def train(
self,
train_dataset: Dataset,
validation_dataset: Optional[Dataset] = None,
epochs: int = 100,
lr: float = 1e-5,
batch_size: int = 32,
num_batches_per_epoch: int = 100,
training_data: Dataset,
validation_data: Optional[Dataset] = None,
cache_data: bool = False,
loss_scaling: Optional[Callable] = None,
) -> PyTorchPredictor:
pred_net = self.train_model(
train_dataset=train_dataset,
epochs=epochs,
lr=lr,
batch_size=batch_size,
num_batches_per_epoch=num_batches_per_epoch,
training_data=training_data,
cache_data=cache_data,
loss_scaling=loss_scaling,
)

return self.get_predictor(pred_net, batch_size)
return self.get_predictor(pred_net)
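The net effect of these estimator changes: the training hyperparameters (`epochs`, `lr`, `batch_size`, `num_batches_per_epoch`, `cache_data`, `loss_scaling`) move from `train`/`train_model` arguments into the constructor, and `train` now receives only the data. A usage sketch under the new API; the dataset loading and the `freq`/`prediction_length`/`context_length` parameters are assumptions based on the estimator's existing signature, which this diff does not show:

from gluonts.dataset.repository.datasets import get_dataset
from gluonts.torch.model.deep_npts import DeepNPTSEstimator

dataset = get_dataset("electricity")  # any GluonTS dataset works here

# Hyperparameters are now constructor arguments.
estimator = DeepNPTSEstimator(
    freq=dataset.metadata.freq,
    prediction_length=dataset.metadata.prediction_length,
    context_length=168,
    epochs=50,
    lr=1e-4,
    batch_size=32,
    num_batches_per_epoch=100,
    loss_scaling="min_max_scaling",  # string resolved via LOSS_SCALING_MAP
)

# `train` now only takes data (plus `cache_data`).
predictor = estimator.train(training_data=dataset.train)
forecasts = list(predictor.predict(dataset.test))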
68 changes: 42 additions & 26 deletions src/gluonts/torch/model/deep_npts/_network.py
@@ -117,7 +117,7 @@ def __init__(
num_time_features: int,
batch_norm: bool = False,
input_scaling: Optional[Union[Callable, str]] = None,
dropout_rate: Optional[float] = None,
dropout_rate: float = 0.0,
):
super().__init__()

@@ -148,7 +148,7 @@ def __init__(
modules += [nn.Linear(in_features, out_features), nn.ReLU()]
if self.batch_norm:
modules.append(nn.BatchNorm1d(out_features))
if self.dropout_rate:
if self.dropout_rate > 0:
modules.append(nn.Dropout(self.dropout_rate))

self.model = nn.Sequential(*modules)
@@ -209,27 +209,31 @@ def forward(


class DeepNPTSNetworkDiscrete(DeepNPTSNetwork):
"""Extends `DeepNTPSNetwork` by implementing the output layer which
"""
Extends `DeepNPTSNetwork` by implementing the output layer which
converts the outputs from the base network into probabilities of length
`context_length`. These probabilities together with the past values in the
context window constitute the one-step-ahead forecast distribution.
Specifically, the forecast is always one of the values observed in the
context window with the corresponding predicted probability.

Parameters ---------- *args Arguments to ``DeepNPTSNetwork``.
use_softmax Flag indicating whether to use softmax or normalization for
converting the outputs of the base network to probabilities. kwargs
Keyword arguments to ``DeepNPTSNetwork``.
Parameters
----------
*args
Arguments to ``DeepNPTSNetwork``.
use_softmax
Flag indicating whether to use softmax or normalization for
converting the outputs of the base network to probabilities.
kwargs
Keyword arguments to ``DeepNPTSNetwork``.
"""

@validated()
def __init__(self, *args, use_softmax: bool = False, **kwargs):
super().__init__(*args, **kwargs)
self.use_softmax = use_softmax
modules: List[nn.Module] = (
[]
if self.dropout_rate is None
else [nn.Dropout(self.dropout_rate)]
[nn.Dropout(self.dropout_rate)] if self.dropout_rate > 0 else []
)
modules.append(
nn.Linear(self.num_hidden_nodes[-1], self.context_length)
@@ -264,7 +268,8 @@ def forward(


class DeepNPTSNetworkSmooth(DeepNPTSNetwork):
"""Extends `DeepNTPSNetwork` by implementing the output layer which
"""
Extends `DeepNPTSNetwork` by implementing the output layer which
converts the outputs from the base network into a smoothed mixture
distribution. The components of the mixture are Gaussians centered around
the observations in the context window. The mixing probabilities as well as
@@ -279,9 +284,7 @@ class DeepNPTSNetworkSmooth(DeepNPTSNetwork):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
modules = (
[]
if self.dropout_rate is None
else [nn.Dropout(self.dropout_rate)]
[nn.Dropout(self.dropout_rate)] if self.dropout_rate > 0 else []
)
modules += [
nn.Linear(self.num_hidden_nodes[-1], self.context_length + 1),
@@ -315,9 +318,11 @@ def forward(
)


class DeepNPTSMultiStepPredictor(nn.Module):
"""Implements multi-step prediction given a trained `DeepNPTSNewtork` model
that outputs one-step-ahead forecast distribution."""
class DeepNPTSMultiStepNetwork(nn.Module):
"""
Implements multi-step prediction given a trained `DeepNPTSNetwork` model
that outputs one-step-ahead forecast distribution.
"""

@validated()
def __init__(
@@ -342,15 +347,26 @@ def forward(
):
"""Generates samples from the forecast distribution.

Parameters ---------- feat_static_cat Shape (-1, num_features).
feat_static_real Shape (-1, num_features). past_target Shape
(-1, context_length). past_observed_values Shape (-1,
context_length). past_time_feat Shape (-1, context_length,
self.num_time_features). future_time_feat Shape (-1,
prediction_length, self.num_time_features). Returns -------
torch.Tensor Tensor containing samples from the predicted
distribution. Shape is (-1, self.num_parallel_samples,
self.prediction_length).
Parameters
----------
feat_static_cat
Shape (-1, num_features).
feat_static_real
Shape (-1, num_features).
past_target
Shape (-1, context_length).
past_observed_values
Shape (-1, context_length).
past_time_feat
Shape (-1, context_length, self.num_time_features).
future_time_feat
Shape (-1, prediction_length, self.num_time_features).

Returns
-------
torch.Tensor
Tensor containing samples from the predicted distribution.
Shape is (-1, self.num_parallel_samples, self.prediction_length).
"""
# Blow up the initial `x` by the number of parallel samples required.
# (batch_size * num_parallel_samples, context_length)
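The rest of the forward pass is truncated above, but the docstring implies the mechanics: the one-step network is applied autoregressively, with each sampled value appended to the context before the next step. A rough sketch of that loop using the docstring's parameter names (illustrative only; the actual implementation's tensor bookkeeping may differ):

# Illustrative unrolling loop (not the actual implementation).
# `net` is the trained one-step DeepNPTSNetwork; inputs follow the
# shapes documented above, already expanded by num_parallel_samples.
future_samples = []
for t in range(prediction_length):
    distr = net(
        feat_static_cat=feat_static_cat,
        feat_static_real=feat_static_real,
        past_target=past_target,
        past_observed_values=past_observed_values,
        past_time_feat=past_time_feat,
    )
    next_sample = distr.sample()  # shape: (batch * num_parallel_samples,)
    future_samples.append(next_sample)

    # Slide the context window forward by one time step.
    past_target = torch.cat(
        [past_target[:, 1:], next_sample.unsqueeze(-1)], dim=1
    )
    past_observed_values = torch.cat(
        [past_observed_values[:, 1:], torch.ones_like(next_sample).unsqueeze(-1)],
        dim=1,
    )
    past_time_feat = torch.cat(
        [past_time_feat[:, 1:], future_time_feat[:, t : t + 1, :]], dim=1
    )

# Stack per-step draws: (batch * num_parallel_samples, prediction_length),
# then reshape to (batch, num_parallel_samples, prediction_length).
samples = torch.stack(future_samples, dim=-1)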