diff --git a/src/chronos/chronos2/dataset.py b/src/chronos/chronos2/dataset.py index 8d6fd3f4..cb755716 100644 --- a/src/chronos/chronos2/dataset.py +++ b/src/chronos/chronos2/dataset.py @@ -477,6 +477,7 @@ def _construct_slice(self, task_idx: int) -> tuple[torch.Tensor, torch.Tensor | task_n_covariates, task_n_future_covariates, ) = self.tasks[task_idx] + task_past_tensor, task_future_tensor = task_past_tensor.clone(), task_future_tensor.clone() task_n_past_only_covariates = task_n_covariates - task_n_future_covariates full_length = task_past_tensor.shape[-1] @@ -502,7 +503,9 @@ def _construct_slice(self, task_idx: int) -> tuple[torch.Tensor, torch.Tensor | # the task_context_tensor by slicing the appropriate indices which we do below if self.mode in [DatasetMode.TRAIN, DatasetMode.VALIDATION]: # the first task_n_targets elements in task_context_tensor are the targets - task_future_target = task_past_tensor[:, slice_idx : slice_idx + self.prediction_length] + task_future_target = task_past_tensor[:, slice_idx : slice_idx + self.prediction_length].clone() + # mask out all rows corresponding to covariates + task_future_target[task_n_targets:] = torch.nan if task_n_future_covariates > 0: # the last task_n_future_covariates elements in task_context_tensor are the known covariates