Skip to content

Commit

Permalink
Update inference.py
Browse files Browse the repository at this point in the history
  • Loading branch information
tdktrang authored May 21, 2024
1 parent 3f07706 commit 879d294
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions PyTorch/Forecasting/TFT/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ def inference(args, config, model, data_loader, scalers, cat_encodings):
if args.joint_visualization or args.save_predictions:
ids = torch.from_numpy(ids.squeeze())
#ids = torch.cat([x['id'][0] for x in data_loader.dataset])
unscaled_predictions = torch.tensor(unscaled_predictions)
unscaled_targets = torch.tensor(unscaled_targets)
joint_graphs = torch.cat([unscaled_targets, unscaled_predictions], dim=2)
graphs = {i:joint_graphs[ids == i, :, :] for i in set(ids.tolist())}
for key, g in graphs.items(): #timeseries id, joint targets and predictions
Expand Down

0 comments on commit 879d294

Please sign in to comment.