diff --git a/src/chronos/chronos2/model.py b/src/chronos/chronos2/model.py
index 9b72f61a..0397be2a 100644
--- a/src/chronos/chronos2/model.py
+++ b/src/chronos/chronos2/model.py
@@ -547,6 +547,74 @@ def _compute_loss(
 
         return loss
 
+    def encode(
+        self,
+        context: torch.Tensor,
+        context_mask: torch.Tensor | None = None,
+        group_ids: torch.Tensor | None = None,
+        future_covariates: torch.Tensor | None = None,
+        future_covariates_mask: torch.Tensor | None = None,
+        num_output_patches: int = 1,
+        future_target: torch.Tensor | None = None,
+        future_target_mask: torch.Tensor | None = None,
+        output_attentions: bool = False,
+    ):
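+        # Shared encoder pass used by forward() and by the pipeline's embed() method.
+        # Returns the encoder outputs together with (loc, scale), the patched future
+        # covariates mask, and the number of context patches.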
+        self._validate_input(
+            context=context,
+            context_mask=context_mask,
+            future_covariates=future_covariates,
+            future_covariates_mask=future_covariates_mask,
+            group_ids=group_ids,
+            num_output_patches=num_output_patches,
+            future_target=future_target,
+            future_target_mask=future_target_mask,
+        )
+
+        batch_size = context.shape[0]
+        patched_context, attention_mask, loc_scale = self._prepare_patched_context(
+            context=context, context_mask=context_mask
+        )
+        num_context_patches = attention_mask.shape[-1]
+
+        # get input embeddings of shape (batch, num_context_patches, d_model)
+        input_embeds: torch.Tensor = self.input_patch_embedding(patched_context)
+        # append [REG] special token embedding, if needed
+        if self.chronos_config.use_reg_token:
+            reg_input_ids = torch.full((batch_size, 1), self.config.reg_token_id, device=input_embeds.device)
+            reg_embeds = self.shared(reg_input_ids)
+            input_embeds = torch.cat([input_embeds, reg_embeds], dim=-2)
+            attention_mask = torch.cat(
+                [attention_mask.to(self.dtype), torch.ones_like(reg_input_ids).to(self.dtype)], dim=-1
+            )
+
+        patched_future, patched_future_covariates_mask = self._prepare_patched_future(
+            future_covariates=future_covariates,
+            future_covariates_mask=future_covariates_mask,
+            loc_scale=loc_scale,
+            num_output_patches=num_output_patches,
+            batch_size=batch_size,
+        )
+        future_attention_mask = torch.ones(batch_size, num_output_patches, dtype=self.dtype, device=self.device)
+
+        # get future embeddings of shape (batch, num_output_patches, d_model)
+        future_embeds: torch.Tensor = self.input_patch_embedding(patched_future)
+
+        # concatenate context and future embeddings and masks
+        input_embeds = torch.cat([input_embeds, future_embeds], dim=-2)
+        attention_mask = torch.cat([attention_mask, future_attention_mask], dim=-1)
+
+        if group_ids is None:
+            # by default, each time series is treated independently, i.e., no mixing across the batch
+            group_ids = torch.arange(batch_size, dtype=torch.long, device=self.device)
+
+        encoder_outputs: Chronos2EncoderOutput = self.encoder(
+            attention_mask=attention_mask,
+            inputs_embeds=input_embeds,
+            group_ids=group_ids,
+            output_attentions=output_attentions,
+        )
+        return encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches
+
     def forward(
         self,
         context: torch.Tensor,
@@ -625,63 +693,19 @@ def forward(
             - enc_time_self_attn_weights: Time self attention weights, if output_attentions=True
             - enc_group_self_attn_weights: Group self attention weights, if output_attentions=True
         """
-
-        self._validate_input(
+        batch_size = context.shape[0]
+        encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches = self.encode(
             context=context,
             context_mask=context_mask,
+            group_ids=group_ids,
             future_covariates=future_covariates,
             future_covariates_mask=future_covariates_mask,
-            group_ids=group_ids,
             num_output_patches=num_output_patches,
             future_target=future_target,
             future_target_mask=future_target_mask,
-        )
-
-        batch_size = context.shape[0]
-        patched_context, attention_mask, loc_scale = self._prepare_patched_context(
-            context=context, context_mask=context_mask
-        )
-        num_context_patches = attention_mask.shape[-1]
-
-        # get input embeddings of shape (batch, num_context_patches, d_model)
-        input_embeds: torch.Tensor = self.input_patch_embedding(patched_context)
-        # append [REG] special token embedding, if needed
-        if self.chronos_config.use_reg_token:
-            reg_input_ids = torch.full((batch_size, 1), self.config.reg_token_id, device=input_embeds.device)
-            reg_embeds = self.shared(reg_input_ids)
-            input_embeds = torch.cat([input_embeds, reg_embeds], dim=-2)
-            attention_mask = torch.cat(
-                [attention_mask.to(self.dtype), torch.ones_like(reg_input_ids).to(self.dtype)], dim=-1
-            )
-
-        patched_future, patched_future_covariates_mask = self._prepare_patched_future(
-            future_covariates=future_covariates,
-            future_covariates_mask=future_covariates_mask,
-            loc_scale=loc_scale,
-            num_output_patches=num_output_patches,
-            batch_size=batch_size,
-        )
-        future_attention_mask = torch.ones(batch_size, num_output_patches, dtype=self.dtype, device=self.device)
-
-        # get future embeddings of shape (batch, num_output_patches, d_model)
-        future_embeds: torch.Tensor = self.input_patch_embedding(patched_future)
-
-        # concatenate context and future embeddings and masks
-        input_embeds = torch.cat([input_embeds, future_embeds], dim=-2)
-        attention_mask = torch.cat([attention_mask, future_attention_mask], dim=-1)
-
-        if group_ids is None:
-            # by default, each time series is treated independently, i.e., no mixing across the batch
-            group_ids = torch.arange(batch_size, dtype=torch.long, device=self.device)
-
-        encoder_outputs: Chronos2EncoderOutput = self.encoder(
-            attention_mask=attention_mask,
-            inputs_embeds=input_embeds,
-            group_ids=group_ids,
             output_attentions=output_attentions,
         )
         hidden_states: torch.Tensor = encoder_outputs[0]
-        assert hidden_states.shape == (batch_size, num_context_patches + 1 + num_output_patches, self.model_dim)
 
         # slice the last num_output_patches hidden states to be input into the output_patch_embedding
diff --git a/src/chronos/chronos2/pipeline.py b/src/chronos/chronos2/pipeline.py
index ce398b05..cadc544a 100644
--- a/src/chronos/chronos2/pipeline.py
+++ b/src/chronos/chronos2/pipeline.py
@@ -988,6 +988,81 @@ def predict_fev(
 
         return predictions_per_window, inference_time_s
 
+    @torch.no_grad()
+    def embed(
+        self, inputs: TensorOrArray | Sequence[TensorOrArray], batch_size: int = 256, context_length: int | None = None
+    ) -> tuple[list[torch.Tensor], list[tuple[torch.Tensor, torch.Tensor]]]:
+        """
+        Get encoder embeddings for the given time series.
+
+        Parameters
+        ----------
+        inputs
+            The time series to get embeddings for, can be one of:
+            - A 3-dimensional `torch.Tensor` or `np.ndarray` of shape (batch, n_variates, history_length). When `n_variates > 1`, information
+              will be shared among the different variates of each time series in the batch.
+            - A list of `torch.Tensor` or `np.ndarray` where each element can either be 1-dimensional of shape (history_length,)
+              or 2-dimensional of shape (n_variates, history_length). The history_lengths may be different across elements; left-padding
+              will be applied, if needed.
+        batch_size
+            The batch size used for generating embeddings. Note that the batch size here means the total number of time series that are input into the model.
+            If your data has multiple variates, the effective number of time series tasks in a batch will be lower than this value, by default 256
+        context_length
+            The maximum context length used during inference, by default set to the model's default context length
+
+        Returns
+        -------
+        embeddings
+            A list of `torch.Tensor` where each element has shape (n_variates, num_patches + 2, d_model) and the number of elements is equal to the number
+            of target time series (univariate or multivariate) in the `inputs`. The extra +2 is due to embeddings of the [REG] token and a masked output patch token.
+        loc_scale
+            A list of tuples with the mean and standard deviation of each time series.
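+
+        Examples
+        --------
+        Illustrative sketch; ``pipeline`` stands for a loaded pipeline instance (e.g. obtained via ``from_pretrained``)
+        and the shapes follow the description above:
+
+        >>> series = torch.rand(4, 3, 37)  # (batch, n_variates, history_length)
+        >>> embeddings, loc_scales = pipeline.embed(series)
+        >>> # each element of `embeddings` has shape (n_variates, num_patches + 2, d_model)
+        >>> # each element of `loc_scales` is a (loc, scale) pair of tensors with shape (n_variates, 1)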
+        """
+        if context_length is None:
+            context_length = self.model_context_length
+
+        if context_length > self.model_context_length:
+            warnings.warn(
+                f"The specified context_length {context_length} is greater than the model's default context length {self.model_context_length}. "
+                f"Resetting context_length to {self.model_context_length}."
+            )
+            context_length = self.model_context_length
+
+        test_dataset = Chronos2Dataset.convert_inputs(
+            inputs=inputs,
+            context_length=context_length,
+            prediction_length=0,
+            batch_size=batch_size,
+            output_patch_size=self.model_output_patch_size,
+            mode=DatasetMode.TEST,
+        )
+        test_loader = DataLoader(
+            test_dataset, batch_size=None, num_workers=1, pin_memory=True, shuffle=False, drop_last=False
+        )
+        all_embeds: list[torch.Tensor] = []
+        all_loc_scales: list[tuple[torch.Tensor, torch.Tensor]] = []
+        for batch in test_loader:
+            assert batch["future_target"] is None
+            batch_context = batch["context"]
+            batch_group_ids = batch["group_ids"]
+            batch_target_idx_ranges = batch["target_idx_ranges"]
+
+            encoder_outputs, (locs, scales), *_ = self.model.encode(
+                context=batch_context.to(device=self.model.device, dtype=torch.float32),
+                group_ids=batch_group_ids.to(self.model.device),
+            )
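+            # `target_idx_ranges` marks which rows of the batch belong to each input series
+            # (univariate or multivariate), so slicing with it regroups the per-variate
+            # embeddings, locs and scales into one entry per input series.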
+            batch_embeds = [encoder_outputs[0][start:end].cpu() for (start, end) in batch_target_idx_ranges]
+            batch_loc_scales = list(
+                zip(
+                    [locs[start:end].cpu() for (start, end) in batch_target_idx_ranges],
+                    [scales[start:end].cpu() for (start, end) in batch_target_idx_ranges],
+                )
+            )
+            all_embeds.extend(batch_embeds)
+            all_loc_scales.extend(batch_loc_scales)
+
+        return all_embeds, all_loc_scales
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
         """
diff --git a/test/test_chronos2.py b/test/test_chronos2.py
index bc550a1d..a13bd998 100644
--- a/test/test_chronos2.py
+++ b/test/test_chronos2.py
@@ -340,6 +340,35 @@ def test_pipeline_predict_can_handle_different_model_and_input_dtypes(dtype: tor
         validate_tensor(quantiles_item, (3, expected_num_quantiles, 7), dtype=torch.float32)
 
 
+@pytest.mark.parametrize(
+    "inputs, expected_output_shapes",
+    [
+        # NOTE: d_model for the dummy model is 6
+        # Homogeneous univariate task
+        (torch.rand(4, 1, 16), [(1, 3, 6)] * 4),
+        # Homogeneous multivariate task
+        (torch.rand(4, 3, 37), [(3, 5, 6)] * 4),
+        # Heterogeneous tasks with different history lengths
+        (
+            [torch.rand(100), torch.rand(2, 150), torch.rand(120)],
+            [(1, 12, 6), (2, 12, 6), (1, 12, 6)],
+        ),
+    ],
+)
+def test_when_input_is_valid_then_pipeline_can_embed(pipeline, inputs, expected_output_shapes):
+    embeds, loc_scales = pipeline.embed(inputs)
+
+    assert (
+        isinstance(embeds, list)
+        and len(embeds) == len(expected_output_shapes)
+        and len(loc_scales) == len(expected_output_shapes)
+    )
+    for embed, loc_scale, expected_shape in zip(embeds, loc_scales, expected_output_shapes):
+        validate_tensor(embed, expected_shape, dtype=torch.float32)
+        validate_tensor(loc_scale[0], (expected_shape[0], 1), dtype=torch.float32)
+        validate_tensor(loc_scale[1], (expected_shape[0], 1), dtype=torch.float32)
+
+
 @pytest.mark.parametrize(
     "task_kwargs",
     [