Skip to content

Commit

Permalink
Update inference.py
Browse files Browse the repository at this point in the history
Related to DeepLearningExamples/PyTorch/Forecasting/TFT/
(e.g. GNMT/PyTorch or FasterTransformer/All)

Describe the bug
When I run the "inference.py" the error happen because "unscaled_predictions" was numpy.ndarray. Therefore, we need to add the code to process the unscaled_predictions to tensor

To Reproduce
Steps to reproduce the behavior:

python inference.py \
--checkpoint /results/TFT_electricity_bs8x1024_lr1e-3/seed_1/checkpoint.pt \
--data /data/processed/electricity_bin/test.csv \
--tgt_scalers /data/processed/electricity_bin/tgt_scalers.bin \
--cat_encodings /data/processed/electricity_bin/cat_encodings.bin \
--visualize \
--save_predictions
Expected behavior
'numpy.ndarray' object has no attribute 'new_full'

Environment

GPUs in the system: NVIDIA GeForce RTX 3090
CUDA driver version 520.61.05
  • Loading branch information
tdktrang committed May 18, 2024
1 parent 729963d commit 3f07706
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion PyTorch/Forecasting/TFT/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ def predict(args, config, model, data_loader, scalers, cat_encodings, extend_tar

def visualize_v2(args, config, model, data_loader, scalers, cat_encodings):
unscaled_predictions, unscaled_targets, ids, _ = predict(args, config, model, data_loader, scalers, cat_encodings, extend_targets=True)

unscaled_predictions = torch.tensor(unscaled_predictions)
unscaled_targets = torch.tensor(unscaled_targets)
num_horizons = config.example_length - config.encoder_length + 1
pad = unscaled_predictions.new_full((unscaled_targets.shape[0], unscaled_targets.shape[1] - unscaled_predictions.shape[1], unscaled_predictions.shape[2]), fill_value=float('nan'))
pad[:,-1,:] = unscaled_targets[:,-num_horizons,:]
Expand Down

0 comments on commit 3f07706

Please sign in to comment.