From 9fd32782e1c0b770bc5c3a8583ac2071cae50710 Mon Sep 17 00:00:00 2001 From: Abdul Fatir Date: Sat, 15 Nov 2025 21:17:15 +0000 Subject: [PATCH 1/3] Mask past-only covariates during loss computation --- src/chronos/chronos2/dataset.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/chronos/chronos2/dataset.py b/src/chronos/chronos2/dataset.py index 8d6fd3f4..c4258f41 100644 --- a/src/chronos/chronos2/dataset.py +++ b/src/chronos/chronos2/dataset.py @@ -503,6 +503,8 @@ def _construct_slice(self, task_idx: int) -> tuple[torch.Tensor, torch.Tensor | if self.mode in [DatasetMode.TRAIN, DatasetMode.VALIDATION]: # the first task_n_targets elements in task_context_tensor are the targets task_future_target = task_past_tensor[:, slice_idx : slice_idx + self.prediction_length] + # mask out all rows corresponding to covariates + task_future_target[task_n_targets:] = torch.nan if task_n_future_covariates > 0: # the last task_n_future_covariates elements in task_context_tensor are the known covariates From 6fd9682eed72187d9e229b832e40d61943a095a8 Mon Sep 17 00:00:00 2001 From: Abdul Fatir Date: Mon, 17 Nov 2025 13:50:04 +0000 Subject: [PATCH 2/3] Clone task_future_target --- src/chronos/chronos2/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/chronos/chronos2/dataset.py b/src/chronos/chronos2/dataset.py index c4258f41..850aad71 100644 --- a/src/chronos/chronos2/dataset.py +++ b/src/chronos/chronos2/dataset.py @@ -502,7 +502,7 @@ def _construct_slice(self, task_idx: int) -> tuple[torch.Tensor, torch.Tensor | # the task_context_tensor by slicing the appropriate indices which we do below if self.mode in [DatasetMode.TRAIN, DatasetMode.VALIDATION]: # the first task_n_targets elements in task_context_tensor are the targets - task_future_target = task_past_tensor[:, slice_idx : slice_idx + self.prediction_length] + task_future_target = task_past_tensor[:, slice_idx : slice_idx + self.prediction_length].clone() # mask out all rows corresponding to covariates task_future_target[task_n_targets:] = torch.nan From b8032f1207a986332fdedbf9535fac8f6dbe802d Mon Sep 17 00:00:00 2001 From: Abdul Fatir Date: Mon, 17 Nov 2025 16:17:39 +0000 Subject: [PATCH 3/3] Clone tensors when construcing slices --- src/chronos/chronos2/dataset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/chronos/chronos2/dataset.py b/src/chronos/chronos2/dataset.py index 850aad71..cb755716 100644 --- a/src/chronos/chronos2/dataset.py +++ b/src/chronos/chronos2/dataset.py @@ -477,6 +477,7 @@ def _construct_slice(self, task_idx: int) -> tuple[torch.Tensor, torch.Tensor | task_n_covariates, task_n_future_covariates, ) = self.tasks[task_idx] + task_past_tensor, task_future_tensor = task_past_tensor.clone(), task_future_tensor.clone() task_n_past_only_covariates = task_n_covariates - task_n_future_covariates full_length = task_past_tensor.shape[-1]