Skip to content

Commit

Permalink
Docs: fix custom pytorch model tutorial (#3188)
Browse files Browse the repository at this point in the history
*Issue #, if available:* docs build was failing
https://github.com/awslabs/gluonts/actions/runs/9318549582/job/25651104403

*Description of changes:*


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.


**Please tag this pr with at least one of these labels to make our
release process faster:** BREAKING, new feature, bug fix, other change,
dev setup
  • Loading branch information
lostella committed May 31, 2024
1 parent e3562e4 commit f0f2266
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 16 deletions.
20 changes: 10 additions & 10 deletions docs/tutorials/advanced_topics/howto_pytorch_lightning.md.template
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,10 @@ class FeedForwardNetwork(nn.Module):
torch.nn.init.zeros_(lin.bias)
return lin

def forward(self, context):
scale = self.scaling(context)
scaled_context = context / scale
nn_out = self.nn(scaled_context)
def forward(self, past_target):
scale = self.scaling(past_target)
scaled_past_target = past_target / scale
nn_out = self.nn(scaled_past_target)
nn_out_reshaped = nn_out.reshape(-1, self.prediction_length, self.hidden_dimensions[-1])
distr_args = self.args_proj(nn_out_reshaped)
return distr_args, torch.zeros_like(scale), scale
Expand Down Expand Up @@ -143,15 +143,15 @@ class LightningFeedForwardNetwork(FeedForwardNetwork, pl.LightningModule):
super().__init__(*args, **kwargs)

def training_step(self, batch, batch_idx):
context = batch["past_target"]
target = batch["future_target"]
past_target = batch["past_target"]
future_target = batch["future_target"]

assert context.shape[-1] == self.context_length
assert target.shape[-1] == self.prediction_length
assert past_target.shape[-1] == self.context_length
assert future_target.shape[-1] == self.prediction_length

distr_args, loc, scale = self(context)
distr_args, loc, scale = self(past_target)
distr = self.distr_output.distribution(distr_args, loc, scale)
loss = -distr.log_prob(target)
loss = -distr.log_prob(future_target)

return loss.mean()

Expand Down
15 changes: 9 additions & 6 deletions src/gluonts/model/forecast_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,15 @@ def make_distribution_forecast(distr, *args, **kwargs) -> Forecast:


def make_predictions(prediction_net, inputs: dict):
# MXNet predictors only support positional arguments
class_name = prediction_net.__class__.__module__
if class_name.startswith("gluonts.mx") or class_name.startswith("mxnet"):
return prediction_net(*inputs.values())
else:
return prediction_net(**inputs)
try:
# Feed inputs as positional arguments for MXNet block predictors
import mxnet as mx

if isinstance(prediction_net, mx.gluon.Block):
return prediction_net(*inputs.values())
except ImportError:
pass
return prediction_net(**inputs)


class ForecastGenerator:
Expand Down

0 comments on commit f0f2266

Please sign in to comment.