From eae617ca596b0f5885f5e7a96e36cbd768d66a81 Mon Sep 17 00:00:00 2001 From: Risto0211 <2533895673@qq.com> Date: Fri, 24 Oct 2025 11:34:02 -0400 Subject: [PATCH] timesfm_2p5 code and examples update --- example/timesfm_2p5.ipynb | 86 ++++-- leaderboard/monash_moirai.csv | 4 +- src/samay/dataset.py | 29 +- src/samay/model.py | 162 +++--------- .../models/timesfm/timesfm/v2/configs.py | 1 + .../timesfm/timesfm/v2/timesfm_2p5_base.py | 1 + .../timesfm/timesfm/v2/timesfm_2p5_torch.py | 152 +++++++---- .../models/timesfm/timesfm/v2/transformer.py | 248 ++++++++++-------- src/samay/utils.py | 2 +- 9 files changed, 373 insertions(+), 312 deletions(-) diff --git a/example/timesfm_2p5.ipynb b/example/timesfm_2p5.ipynb index 5ec5d73..d2f46b0 100644 --- a/example/timesfm_2p5.ipynb +++ b/example/timesfm_2p5.ipynb @@ -29,28 +29,15 @@ "text": [ "/nethome/sli999/anaconda3/envs/torch/lib/python3.11/site-packages/gluonts/json.py:102: UserWarning: Using `json`-module for json-handling. Consider installing one of `orjson`, `ujson` to speed up serialization and deserialization.\n", " warnings.warn(\n", - "INFO:p-1279300:t-140075744466752:timesfm_2p5_torch.py:load_checkpoint:Downloading checkpoint from Hugging Face repo google/timesfm-2.5-200m-pytorch\n" + "INFO:p-3382524:t-140205328676672:timesfm_2p5_torch.py:_from_pretrained:Downloading checkpoint from Hugging Face repo google/timesfm-2.5-200m-pytorch\n", + "INFO:p-3382524:t-140205328676672:timesfm_2p5_torch.py:_from_pretrained:Loading checkpoint from: /nethome/sli999/.cache/huggingface/hub/models--google--timesfm-2.5-200m-pytorch/snapshots/1d952420fba87f3c6dee4f240de0f1a0fbc790e3/model.safetensors\n" ] }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "6caf1486003a41b288f303eb6468a123", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Fetching 4 files: 0%| | 0/4 [00:00" ] @@ -156,6 +143,63 @@ "source": [ "visualize(trues=trues, preds=preds, history=histories, context_len=512)" ] + }, + { + "cell_type": "markdown", + "id": "1201ebd6", + "metadata": {}, + "source": [ + "## Finetune the model " + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "04f2776b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0, Loss: 2.4298\n", + "Epoch 1, Loss: 1.5070\n", + "Epoch 2, Loss: 1.1590\n", + "Epoch 3, Loss: 0.9918\n", + "Epoch 4, Loss: 0.8892\n" + ] + } + ], + "source": [ + "model.finetune(train_dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "ce07cf52", + "metadata": {}, + "source": [ + "## Evaluate the model after finetune " + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "ba16d47f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'mse': 11.999966, 'mae': 1.789054, 'mase': 0.23354335, 'mape': 184.51486, 'rmse': 3.4640965, 'nrmse': 0.0398107419294543, 'smape': 0.45212102, 'msis': 0.056255825, 'nd': 0.33803181907580215, 'mwsq': 2.8931375801731507, 'crps': 0.7806806351359066}\n" + ] + } + ], + "source": [ + "metric, trues, preds, histories = model.evaluate(val_dataset)\n", + "print(metric)" + ] } ], "metadata": { diff --git a/leaderboard/monash_moirai.csv b/leaderboard/monash_moirai.csv index ed9d351..69592da 100644 --- a/leaderboard/monash_moirai.csv +++ b/leaderboard/monash_moirai.csv @@ -27,4 +27,6 @@ kaggle_web_traffic_weekly (W-SUN),104.91,43.93m,677892827.0147512,1691.418769624 temperature_rain 
(D),106.72,193.2m,731.8126906485586,5.994333053214229,0.4581582907670293,0.249426455236986,27.05203671904499,0.0036437877732579,0.7951260719877393,0.0656396716800804,0.694280646566257,131.2446799187715,2.7762765877611284 solar_4_seconds (4s),181.73,265.02m,3.56522408907416e-10,1.5385409490183596e-05,1.5385409490183597,1.5385409490183597,1.888180099745297e-05,1.8881800997452969,1.0573307614461165,1.5385409490183597,1.5385409490183597,0.0174618302899513,0.0159100078192055 wind_4_seconds (4s),184.15,295.86m,3.108434090470448,1.6458506820258083,12.77600689413882,0.082432241288121,1.763075180039254,0.5342635870919336,0.0786722471393581,0.6037483766474667,0.0811029831305508,5.187621701262427,1.674276021405027 -kaggle_web_traffic (D),639.07,1270.6m,2452766448416.5938,20043.06505466295,0.29456241553874846,2.4090639022504874,1566131.044458475,0.0005082301009184855,0.5116392842774276,7.789624846643337,0.45522290224502393,2109413088817.9578,14272.143354708753 +kaggle_web_traffic (D),639.07,1270.6m,2452766448416.5938,20043.06505466295,0.2945624155387484,2.4090639022504874,1566131.044458475,0.0005082301009184,0.5116392842774276,7.789624846643337,0.4552229022450239,2109413088817.9573,14272.143354708753 +wind_farms_minutely (min),835.99,3349.23m,16.206971220492438,0.7136967852770562,0.01016127294468,0.0144428278254881,4.0257882731823385,0.0061153134085223,0.0171983034046891,0.0055462004428248,0.0083765532898652,12.73310756721593,0.5549177864411421 +weather (D),1180.74,2492.3m,23.087079294192066,0.43981675735280407,0.11220441599089351,0.09457813572814078,4.804901590479462,0.00866736117753027,0.08763642437944673,0.023134484880568202,0.10417793350460773,2.903852911979907,0.18787836223668095 diff --git a/src/samay/dataset.py b/src/samay/dataset.py index 6136af7..9e2b0c9 100644 --- a/src/samay/dataset.py +++ b/src/samay/dataset.py @@ -2056,6 +2056,7 @@ def __init__( stride=10, context_len=512, horizon_len=96, + normalize=True, **kwargs, ): super().__init__( @@ -2071,6 +2072,7 @@ def __init__( self.stride = stride self.boundaries = boundaries + self.normalize = normalize self.pad = False self._read_data() @@ -2093,15 +2095,15 @@ def _read_data(self): if self.boundaries[2] == 0: self.boundaries[2] = int(len(self.df) - 1) - scaler = StandardScaler() + self.scaler = StandardScaler() if self.boundaries == [-1, -1, -1]: # use all data for training self.boundaries = [0, 0, len(self.df) - 1] - scaler = scaler.fit(self.df) + self.scaler = self.scaler.fit(self.df) else: # fit the scaler on the training data - scaler = scaler.fit(self.df[slice(0, self.boundaries[0]), :]) + self.scaler = self.scaler.fit(self.df[slice(0, self.boundaries[0]), :]) if self.mode == "train": @@ -2110,7 +2112,7 @@ def _read_data(self): elif self.mode == "test": self.data = self.df[slice(self.boundaries[1], self.boundaries[2]), :] - self.data = scaler.transform(self.data) + self.data = self.scaler.transform(self.data) self.length_timeseries = self.data.shape[0] self.required_len = self.context_len + self.horizon_len @@ -2150,7 +2152,7 @@ def __getitem__(self, index): return input_seq, forecast_seq elif self.task_name == "finetune": - pred_end = seq_end + 1 + pred_end = seq_end + self.horizon_len if pred_end > self.length_timeseries: pred_end = self.length_timeseries seq_end = pred_end - 1 @@ -2160,8 +2162,8 @@ def __getitem__(self, index): seq_start:seq_end, channel_idx ] # shape: (context_len, ) forecast_seq = self.data[seq_end:pred_end, channel_idx] - loss_mask = np.ones(input_seq.shape[0]) - return input_seq, forecast_seq, loss_mask + # 
loss_mask = np.ones(input_seq.shape[0]) + return input_seq, forecast_seq def __len__(self): if self.length_timeseries < self.context_len + self.horizon_len: @@ -2173,4 +2175,15 @@ def get_data_loader(self): return DataLoader(self, shuffle=True, batch_size=self.batchsize) else: return DataLoader(self, shuffle=False, batch_size=self.batchsize) - # shape: (batch_size, n_channels, seq_len) \ No newline at end of file + # shape: (batch_size, n_channels, seq_len) + + def _denormalize_data(self, data: np.ndarray): + data = np.asarray(data) + if self.normalize: + data = data[:, : self.n_channels, :] + data_flatten = np.transpose(data, (0, 2, 1)).reshape(-1, self.n_channels) + return self.scaler.inverse_transform(data_flatten).reshape( + data.shape[0], data.shape[1], data.shape[2] + ) + else: + return data \ No newline at end of file diff --git a/src/samay/model.py b/src/samay/model.py index 6e7ace7..9619fb2 100644 --- a/src/samay/model.py +++ b/src/samay/model.py @@ -266,7 +266,10 @@ def evaluate(self, dataset, **kwargs): trues = dataset._denormalize_data(trues) preds = dataset._denormalize_data(preds) histories = dataset._denormalize_data(histories) - quantiles = dataset._denormalize_data(q for q in quantiles) + new_quantiles = [] + for i in range(quantiles.shape[0]): + new_quantiles.append(dataset._denormalize_data(quantiles[i])) + quantiles = np.array(new_quantiles) mse = MSE(trues, preds) mae = MAE(trues, preds) @@ -2594,14 +2597,15 @@ class TimesFM_2p5_Model(Basemodel): def __init__(self, config=None, repo=None, **kwargs): super().__init__(config=config, repo=repo) - self.model = TimesFM_2p5_200M_torch(device=self.device) + if repo: - self.model.load_checkpoint(hf_repo_id=repo) + self.model = TimesFM_2p5_200M_torch.from_pretrained(repo, device=self.device) else: - self.model.load_checkpoint() + self.model = TimesFM_2p5_200M_torch(device=self.device) self.config = ForecastConfig(**config) self.model.compile(self.config) + self.quantiles = self.model.model.config.quantiles def evaluate(self, dataset, **kwargs): """Evaluate the model on the given dataset. 
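The `new_quantiles` loop added to `Basemodel.evaluate` above replaces a generator expression that `_denormalize_data` could not consume: the scaler is fit per channel, so each quantile level has to be inverse-transformed as its own (batch, channels, horizon) array. A minimal equivalent sketch, assuming `quantiles` is a NumPy array with the quantile level on axis 0:

    import numpy as np

    # Inverse-transform each quantile level independently, then restack;
    # equivalent to the explicit loop in Basemodel.evaluate.
    quantiles = np.stack(
        [dataset._denormalize_data(q) for q in quantiles], axis=0
    )
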
@@ -2614,6 +2618,7 @@ def evaluate(self, dataset, **kwargs): """ # self.model.to(self.device) self.model.model.eval() + self.model.model.to(self.device) inference_loader = dataset.get_data_loader() horizon = dataset.horizon_len @@ -2633,6 +2638,7 @@ def evaluate(self, dataset, **kwargs): input_seq, mask_seq, ) + quantile_forecast = quantile_forecast[..., 1:].transpose(2, 0, 1) trues.append(target_seq.cpu().numpy()) preds.append(point_forecast) @@ -2641,9 +2647,17 @@ def evaluate(self, dataset, **kwargs): trues = np.concatenate(trues, axis=0).reshape(-1, dataset.n_channels, dataset.horizon_len) preds = np.concatenate(preds, axis=0).reshape(-1, dataset.n_channels, dataset.horizon_len) - q_preds = np.concatenate(q_preds, axis=0).reshape(q_preds[-1].shape[-1], -1, dataset.n_channels, dataset.horizon_len) + q_preds = np.concatenate(q_preds, axis=1).reshape(q_preds[-1].shape[0], -1, dataset.n_channels, dataset.horizon_len) histories = np.concatenate(histories, axis=0).reshape(-1, dataset.n_channels, dataset.context_len) + trues = dataset._denormalize_data(trues) + preds = dataset._denormalize_data(preds) + new_q_preds = np.zeros_like(q_preds) + for i in range(q_preds.shape[0]): + new_q_preds[i] = dataset._denormalize_data(q_preds[i]) + q_preds = new_q_preds + histories = dataset._denormalize_data(histories) + # Calculate metrics mse = MSE(trues, preds) mae = MAE(trues, preds) @@ -2654,8 +2668,8 @@ def evaluate(self, dataset, **kwargs): smape = SMAPE(trues, preds) msis = MSIS(trues, preds) nd = ND(trues, preds) - mwsq = MWSQ(trues, preds, q_preds) - crps = CRPS(trues, preds, q_preds) + mwsq = MWSQ(trues, q_preds, quantiles=self.quantiles) + crps = CRPS(trues, q_preds, quantiles=self.quantiles) cleanup_dataloader(inference_loader) return { @@ -2698,130 +2712,38 @@ def finetune(self, dataset, **kwargs): epoch = 5 if "epoch" not in kwargs else kwargs["epoch"] optimizer = torch.optim.AdamW(self.model.model.parameters(), lr=lr) - self.model.model.to(self.device) + # self.model.model.to(self.device) self.model.model.train() dataloader = dataset.get_data_loader() + horizon = dataset.horizon_len + # torch.autograd.set_detect_anomaly(True) for ep in range(epoch): total_loss = 0 - for i, batch in enumerate(dataloader): + for batch in dataloader: input_seq, target_seq = batch batch_size = input_seq.shape[0] - input_seq = input_seq.float().to(self.device) + input_seq = input_seq.float() target_seq = target_seq.float().to(self.device) - mask_seq = torch.zeros_like(input_seq).to(self.device) + mask_seq = torch.zeros_like(input_seq) optimizer.zero_grad() - # perform auto-regressive forward pass - num_decode_steps = (dataset.horizon_len - 1) // self.model.model.o - num_input_patches = dataset.context_len // self.model.model.p - decode_cache_size = num_input_patches + num_decode_steps * self.model.model.m - - # Prefill - patched_inputs = torch.reshape(input_seq, (batch_size, -1, self.p)) - patched_masks = torch.reshape(mask_seq, (batch_size, -1, self.p)) - - # running stats - n = torch.zeros(batch_size, device=input_seq.device) - mu = torch.zeros(batch_size, device=input_seq.device) - sigma = torch.zeros(batch_size, device=input_seq.device) - patch_mu = [] - patch_sigma = [] - for i in range(num_input_patches): - (n, mu, sigma), _ = util.update_running_stats( - n, mu, sigma, patched_inputs[:, i], patched_masks[:, i] - ) - patch_mu.append(mu) - patch_sigma.append(sigma) - last_n, last_mu, last_sigma = n, mu, sigma - context_mu = torch.stack(patch_mu, dim=1) - context_sigma = torch.stack(patch_sigma, dim=1) - 
decode_caches = [ - util.DecodeCache( - next_index=torch.zeros( - batch_size, dtype=torch.int32, device=input_seq.device - ), - num_masked=torch.zeros( - batch_size, dtype=torch.int32, device=input_seq.device - ), - key=torch.zeros( - batch_size, - decode_cache_size, - self.model.model.h, - self.model.model.hd, - device=input_seq.device, - ), - value=torch.zeros( - batch_size, - decode_cache_size, - self.model.model.h, - self.model.model.hd, - device=input_seq.device, - ), - ) - for _ in range(self.model.model.x) - ] - - normed_inputs = util.revin( - patched_inputs, context_mu, context_sigma, reverse=False - ) - normed_inputs = torch.where(patched_masks, 0.0, normed_inputs) - (_, _, normed_outputs, normed_quantile_spread), decode_caches = self.model.model( - normed_inputs, patched_masks, decode_caches - ) - renormed_outputs = torch.reshape( - util.revin(normed_outputs, context_mu, context_sigma, reverse=True), - (batch_size, -1, self.model.model.o, self.model.model.q), + point_forecast, quantile_forecast = self.model.compiled_decode( + horizon, + input_seq, + mask_seq, + train=True ) - renormed_quantile_spread = torch.reshape( - util.revin( - normed_quantile_spread, context_mu, context_sigma, reverse=True - ), - (batch_size, -1, self.model.model.os, self.model.model.q), - )[:, -1, ...] - - # Autogressive decode - ar_outputs = [] - last_renormed_output = renormed_outputs[:, -1, :, self.model.model.aridx] - - for _ in range(num_decode_steps): - new_patched_input = torch.reshape( - last_renormed_output, (batch_size, self.model.model.m, self.model.model.p) - ) - new_mask = torch.zeros_like(new_patched_input, dtype=torch.bool) - - n, mu, sigma = last_n, last_mu, last_sigma - new_mus, new_sigmas = [], [] - for i in range(self.model.model.m): - (n, mu, sigma), _ = util.update_running_stats( - n, mu, sigma, new_patched_input[:, i], new_mask[:, i] - ) - new_mus.append(mu) - new_sigmas.append(sigma) - last_n, last_mu, last_sigma = n, mu, sigma - new_mu = torch.stack(new_mus, dim=1) - new_sigma = torch.stack(new_sigmas, dim=1) - - new_normed_input = util.revin( - new_patched_input, new_mu, new_sigma, reverse=False - ) - (_, _, new_normed_output, _), decode_caches = self.model.model( - new_normed_input, new_mask, decode_caches - ) - - new_renormed_output = torch.reshape( - util.revin(new_normed_output, new_mu, new_sigma, reverse=True), - (batch_size, self.model.model.m, self.model.model.o, self.model.model.q), - ) - ar_outputs.append(new_renormed_output[:, -1, ...]) - last_renormed_output = new_renormed_output[:, -1, :, self.model.model.aridx] - - if num_decode_steps > 0: - ar_renormed_outputs = torch.stack(ar_outputs, dim=1) - else: - ar_renormed_outputs = None - - forecast_seq = ar_renormed_outputs if ar_renormed_outputs is not None else renormed_outputs[:, :, :dataset.horizon_len, :] + # quantile_forecast = quantile_forecast[..., 1:] + # point_forecast, quantile_forecast = torch.Tensor(point_forecast), torch.Tensor(quantile_forecast) + + loss = nn.functional.mse_loss(point_forecast, target_seq) + quantiles = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] + for i, quantile in enumerate(quantiles): + last_patch_quantile = quantile_forecast[:, :, i + 1] + loss += torch.mean( + quantile_loss(last_patch_quantile, target_seq.squeeze(-1), + quantile)) loss.backward() optimizer.step() diff --git a/src/samay/models/timesfm/timesfm/v2/configs.py b/src/samay/models/timesfm/timesfm/v2/configs.py index 7f2e7b6..65d7271 100644 --- a/src/samay/models/timesfm/timesfm/v2/configs.py +++ 
b/src/samay/models/timesfm/timesfm/v2/configs.py @@ -94,6 +94,7 @@ class TransformerConfig: use_bias: bool use_rotary_position_embeddings: bool ff_activation: Literal["relu", "swish", "none"] + fuse_qkv: bool @dataclasses.dataclass(frozen=True) diff --git a/src/samay/models/timesfm/timesfm/v2/timesfm_2p5_base.py b/src/samay/models/timesfm/timesfm/v2/timesfm_2p5_base.py index e8fa265..167b59a 100644 --- a/src/samay/models/timesfm/timesfm/v2/timesfm_2p5_base.py +++ b/src/samay/models/timesfm/timesfm/v2/timesfm_2p5_base.py @@ -107,6 +107,7 @@ class TimesFM_2p5_200M_Definition: use_bias=False, use_rotary_position_embeddings=True, ff_activation="swish", + fuse_qkv=True, ), ) output_projection_point: ResidualBlockConfig = ResidualBlockConfig( diff --git a/src/samay/models/timesfm/timesfm/v2/timesfm_2p5_torch.py b/src/samay/models/timesfm/timesfm/v2/timesfm_2p5_torch.py index 13a70a7..9752dcd 100644 --- a/src/samay/models/timesfm/timesfm/v2/timesfm_2p5_torch.py +++ b/src/samay/models/timesfm/timesfm/v2/timesfm_2p5_torch.py @@ -17,11 +17,12 @@ import logging import math import os -from typing import Sequence +from typing import Dict, Optional, Sequence, Union import huggingface_hub import numpy as np -from safetensors.torch import load_file +from safetensors.torch import load_file, save_file +from pathlib import Path import torch from torch import nn @@ -80,11 +81,17 @@ def __init__(self, device: torch.device = torch.device("cpu")): else: self.device_count = 1 - def load_checkpoint(self, path: str): + def load_checkpoint(self, path: str, **kwargs): """Loads a PyTorch TimesFM model from a checkpoint.""" tensors = load_file(path) - self.load_state_dict(tensors) + self.load_state_dict(tensors, strict=True) self.to(self.device) + torch_compile = True + if "torch_compile" in kwargs: + torch_compile = kwargs["torch_compile"] + if torch_compile: + print("Compiling model...") + self = torch.compile(self) def forward( self, @@ -115,13 +122,13 @@ def forward( output_quantile_spread, ), new_decode_caches - def decode(self, horizon: int, inputs, masks): + def decode(self, horizon: int, inputs, masks, train=False): """Decodes the time series.""" inputs = inputs.to(self.device) masks = masks.to(self.device) - with torch.no_grad(): + with torch.set_grad_enabled(train): batch_size, context = inputs.shape[0], inputs.shape[1] num_decode_steps = (horizon - 1) // self.o num_input_patches = context // self.p @@ -274,7 +281,7 @@ def forecast_naive( return outputs -class TimesFM_2p5_200M_torch(timesfm_2p5_base.TimesFM_2p5): +class TimesFM_2p5_200M_torch(timesfm_2p5_base.TimesFM_2p5, huggingface_hub.ModelHubMixin): """PyTorch implementation of TimesFM 2.5 with 200M parameters.""" def __init__(self, device: torch.device = torch.device("cpu")): @@ -283,33 +290,65 @@ def __init__(self, device: torch.device = torch.device("cpu")): self.model = TimesFM_2p5_200M_torch_module(device=device) assert isinstance(self.model, nn.Module) # For type checker. - def load_checkpoint( - self, - *, - path: str | None = None, - hf_repo_id: str | None = "google/timesfm-2.5-200m-pytorch", - ) -> None: - """Loads a PyTorch safetensors TimesFM model. - - Args: - path: Path to a local checkpoint. If not provided, will try to download - from the default Hugging Face repo. - hf_repo_id: If provided, will download from the specified Hugging Face - repo instead. 
+ @classmethod + def _from_pretrained( + cls, + *, + model_id: str = "google/timesfm-2.5-200m-pytorch", + revision: Optional[str], + cache_dir: Optional[Union[str, Path]], + force_download: bool, + proxies: Optional[Dict], + resume_download: Optional[bool], + local_files_only: bool, + token: Optional[str], + **model_kwargs, + ): """ - if path: - pass - elif hf_repo_id: - logging.info( - "Downloading checkpoint from Hugging Face repo %s", hf_repo_id - ) - path = os.path.join( - huggingface_hub.snapshot_download(hf_repo_id), "model.safetensors" - ) - logging.info("Loading checkpoint from: %s", path) + Loads a PyTorch safetensors TimesFM model from a local path or the Hugging + Face Hub. This method is the backend for the `from_pretrained` class + method provided by `ModelHubMixin`. + """ + # Create an instance of the model wrapper class. + instance = cls(**model_kwargs) + + # Determine the path to the model weights. + model_file_path = "" + if os.path.isdir(model_id): + logging.info("Loading checkpoint from local directory: %s", model_id) + model_file_path = os.path.join(model_id, "model.safetensors") + if not os.path.exists(model_file_path): + raise FileNotFoundError(f"model.safetensors not found in directory {model_id}") else: - raise ValueError("Either path or hf_repo_id must be provided.") - self.model.load_checkpoint(path) + logging.info("Downloading checkpoint from Hugging Face repo %s", model_id) + model_file_path = huggingface_hub.hf_hub_download( + repo_id=model_id, + filename="model.safetensors", + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + ) + + logging.info("Loading checkpoint from: %s", model_file_path) + # Load the weights into the model. + instance.model.load_checkpoint(model_file_path, **model_kwargs) + return instance + + def _save_pretrained(self, save_directory: Union[str, Path]): + """ + Saves the model's state dictionary to a safetensors file. This method + is called by the `save_pretrained` method from `ModelHubMixin`. + """ + if not os.path.exists(save_directory): + os.makedirs(save_directory) + + weights_path = os.path.join(save_directory, "model.safetensors") + save_file(self.model.state_dict(), weights_path) + def compile(self, forecast_config: configs.ForecastConfig, **kwargs) -> None: """Attempts to compile the model for fast decoding. @@ -361,7 +400,7 @@ def compile(self, forecast_config: configs.ForecastConfig, **kwargs) -> None: ) self.forecast_config = fc - def _compiled_decode(horizon, inputs, masks): + def _compiled_decode(horizon, inputs, masks, train=False): if horizon > fc.max_horizon: raise ValueError( "Horizon must be less than the max horizon." 
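With the `ModelHubMixin` wiring above, checkpoint I/O round-trips through the standard `huggingface_hub` entry points. A usage sketch, assuming the import path of this repo's src layout and an illustrative local directory name:

    import torch
    from samay.models.timesfm.timesfm.v2.timesfm_2p5_torch import (
        TimesFM_2p5_200M_torch,
    )

    # from_pretrained resolves model.safetensors (Hub repo or local
    # directory) and hands it to _from_pretrained above; `device` is
    # forwarded to the constructor through **model_kwargs.
    tfm = TimesFM_2p5_200M_torch.from_pretrained(
        "google/timesfm-2.5-200m-pytorch", device=torch.device("cpu")
    )

    # save_pretrained calls _save_pretrained above, writing
    # model.safetensors into the target directory.
    tfm.save_pretrained("./timesfm-2p5-checkpoint")
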
@@ -385,7 +424,7 @@ def _compiled_decode(horizon, inputs, masks): mu, sigma = None, None pf_outputs, quantile_spreads, ar_outputs = self.model.decode( - forecast_config.max_horizon, inputs, masks + forecast_config.max_horizon, inputs, masks, train=train ) to_cat = [pf_outputs[:, -1, ...]] if ar_outputs is not None: @@ -398,7 +437,7 @@ def _compiled_decode(horizon, inputs, masks): if fc.force_flip_invariance: flipped_pf_outputs, flipped_quantile_spreads, flipped_ar_outputs = ( - self.model.decode(forecast_config.max_horizon, -inputs, masks) + self.model.decode(forecast_config.max_horizon, -inputs, masks, train=train) ) flipped_quantile_spreads = flip_quantile_fn(flipped_quantile_spreads) flipped_pf_outputs = flip_quantile_fn(flipped_pf_outputs) @@ -413,13 +452,21 @@ def _compiled_decode(horizon, inputs, masks): full_forecast = (full_forecast - flipped_full_forecast) / 2 if fc.use_continuous_quantile_head: - for quantile_index in [1, 2, 3, 4, 6, 7, 8, 9]: - full_forecast[:, :, quantile_index] = ( - quantile_spreads[:, : fc.max_horizon, quantile_index] - - quantile_spreads[:, : fc.max_horizon, 5] - + full_forecast[:, : fc.max_horizon, 5] - ) - full_forecast = full_forecast[:, :horizon, :] + # for quantile_index in [1, 2, 3, 4, 6, 7, 8, 9]: + # full_forecast[:, :, quantile_index] = ( + # quantile_spreads[:, : fc.max_horizon, quantile_index] + # - quantile_spreads[:, : fc.max_horizon, 5] + # + full_forecast[:, : fc.max_horizon, 5] + # ) + idx = torch.tensor([1,2,3,4,6,7,8,9], device=full_forecast.device) + base = full_forecast[:, :fc.max_horizon, 5:6] # (B, H, 1) + repl = (quantile_spreads[:, :fc.max_horizon, idx] # (B, H, |idx|) + - quantile_spreads[:, :fc.max_horizon, 5:6] + base) + + ff = full_forecast.clone() + ff[:, :fc.max_horizon, idx] = repl + full_forecast = ff + full_forecast = full_forecast[:, :horizon, :].contiguous() if fc.return_backcast: full_backcast = pf_outputs[:, :-1, : self.model.p, :].reshape( @@ -428,18 +475,20 @@ def _compiled_decode(horizon, inputs, masks): full_forecast = torch.cat([full_backcast, full_forecast], dim=1) if fc.fix_quantile_crossing: + ff = full_forecast.clone() for i in [4, 3, 2, 1]: - full_forecast[:, :, i] = torch.where( - full_forecast[:, :, i] < full_forecast[:, :, i + 1], - full_forecast[:, :, i], - full_forecast[:, :, i + 1], + ff[:, :, i] = torch.where( + ff[:, :, i] < ff[:, :, i + 1], + ff[:, :, i], + ff[:, :, i + 1], ) for i in [6, 7, 8, 9]: - full_forecast[:, :, i] = torch.where( - full_forecast[:, :, i] > full_forecast[:, :, i - 1], - full_forecast[:, :, i], - full_forecast[:, :, i - 1], + ff[:, :, i] = torch.where( + ff[:, :, i] > ff[:, :, i - 1], + ff[:, :, i], + ff[:, :, i - 1], ) + full_forecast = ff if fc.normalize_inputs: full_forecast = revin(full_forecast, mu, sigma, reverse=True) @@ -451,7 +500,8 @@ def _compiled_decode(horizon, inputs, masks): full_forecast, ) - full_forecast = full_forecast.detach().cpu().numpy() + if not train: + full_forecast = full_forecast.detach().cpu().numpy() return full_forecast[..., 5], full_forecast self.compiled_decode = _compiled_decode \ No newline at end of file diff --git a/src/samay/models/timesfm/timesfm/v2/transformer.py b/src/samay/models/timesfm/timesfm/v2/transformer.py index e2ab608..d3c3b01 100644 --- a/src/samay/models/timesfm/timesfm/v2/transformer.py +++ b/src/samay/models/timesfm/timesfm/v2/transformer.py @@ -18,12 +18,11 @@ from typing import Callable import torch -from torch import nn import torch.nn.functional as F +from torch import nn from . import configs -from . 
import normalization -from . import util +from . import normalization, util LayerNorm = nn.LayerNorm RMSNorm = normalization.RMSNorm @@ -31,26 +30,26 @@ def make_attn_mask( - query_length: int, - num_all_masked_kv: torch.Tensor, - query_index_offset: torch.Tensor | None = None, - kv_length: int = 0, + query_length: int, + num_all_masked_kv: torch.Tensor, + query_index_offset: torch.Tensor | None = None, + kv_length: int = 0, ) -> torch.Tensor: """Makes attention mask.""" if kv_length == 0: kv_length = query_length q_index = torch.arange(query_length, device=num_all_masked_kv.device)[ - None, None, :, None + None, None, :, None ] if query_index_offset is not None: q_index = q_index + query_index_offset[:, None, None, None] kv_index = torch.arange(kv_length, device=num_all_masked_kv.device)[ - None, None, None, : + None, None, None, : ] return torch.logical_and( - q_index >= kv_index, - kv_index >= num_all_masked_kv[:, None, None, None], + q_index >= kv_index, + kv_index >= num_all_masked_kv[:, None, None, None], ) @@ -58,10 +57,10 @@ class RotaryPositionalEmbedding(nn.Module): """Rotary positional embedding.""" def __init__( - self, - embedding_dims: int, - min_timescale: float = 1.0, - max_timescale: float = 10000.0, + self, + embedding_dims: int, + min_timescale: float = 1.0, + max_timescale: float = 10000.0, ): super().__init__() self.embedding_dims = embedding_dims @@ -69,31 +68,30 @@ def __init__( self.max_timescale = max_timescale def forward( - self, - inputs: torch.Tensor, - position: torch.Tensor | None = None, + self, + inputs: torch.Tensor, + position: torch.Tensor | None = None, ): """Generates a JTensor of sinusoids with different frequencies.""" if self.embedding_dims != inputs.shape[-1]: raise ValueError( - "The embedding dims of the rotary position embedding" - "must match the hidden dimension of the inputs." + "The embedding dims of the rotary position embedding" + "must match the hidden dimension of the inputs." ) half_embedding_dim = self.embedding_dims // 2 fraction = ( - 2 - * torch.arange(0, half_embedding_dim, device=inputs.device) - / self.embedding_dims + 2 + * torch.arange(0, half_embedding_dim, device=inputs.device) + / self.embedding_dims ) timescale = ( - self.min_timescale - * (self.max_timescale / self.min_timescale) ** fraction + self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction ).to(inputs.device) if position is None: seq_length = inputs.shape[1] - position = torch.arange( - seq_length, dtype=torch.float32, device=inputs.device - )[None, :] + position = torch.arange(seq_length, dtype=torch.float32, device=inputs.device)[ + None, : + ] if len(inputs.shape) == 4: position = position[..., None, None] @@ -114,16 +112,16 @@ def forward( def _dot_product_attention( - query, - key, - value, - mask=None, + query, + key, + value, + mask=None, ): """Computes dot-product attention given query, key, and value.""" attn_weights = torch.einsum("...qhd,...khd->...hqk", query, key) if mask is not None: attn_weights = torch.where( - mask, attn_weights, -torch.finfo(attn_weights.dtype).max / 2 + mask, attn_weights, -torch.finfo(attn_weights.dtype).max / 2 ) attn_weights = F.softmax(attn_weights, dim=-1) @@ -131,6 +129,28 @@ def _dot_product_attention( return torch.einsum("...hqk,...khd->...qhd", attn_weights, value) +def _torch_dot_product_attention(query, key, value, mask=None): + """ + Performs the exact same (unscaled) attention as the above function, + but using the fast and fused F.scaled_dot_product_attention kernel. + """ + + # 1. 
Permute inputs from (B, L, H, D) to the expected (B, H, L, D) + query = query.permute(0, 2, 1, 3).contiguous() + key = key.permute(0, 2, 1, 3).contiguous() + value = value.permute(0, 2, 1, 3).contiguous() + + # 2. Call the fused attention kernel + # - Pass the mask to `attn_mask`. + # - Set `scale=1.0` to disable the default 1/sqrt(d_k) scaling. + output = F.scaled_dot_product_attention(query, key, value, attn_mask=mask, scale=1.0) + + # 3. Permute the output back to the original (B, L, H, D) layout + output = output.permute(0, 2, 1, 3) + + return output + + class PerDimScale(nn.Module): """Per-dimension scaling.""" @@ -141,7 +161,7 @@ def __init__(self, num_dims: int): def forward(self, x: torch.Tensor) -> torch.Tensor: scale_factor = ( - 1.442695041 / math.sqrt(self.num_dims) * F.softplus(self.per_dim_scale) + 1.442695041 / math.sqrt(self.num_dims) * F.softplus(self.per_dim_scale) ) return x * scale_factor @@ -150,15 +170,16 @@ class MultiHeadAttention(nn.Module): """Multi-head attention.""" def __init__( - self, - num_heads: int, - in_features: int, - *, - use_per_dim_scale: bool = True, - use_rotary_position_embeddings: bool = True, - use_bias: bool = False, - attention_fn: Callable[..., torch.Tensor] = _dot_product_attention, - qk_norm: str = "rms", + self, + num_heads: int, + in_features: int, + *, + use_per_dim_scale: bool = True, + use_rotary_position_embeddings: bool = True, + use_bias: bool = False, + attention_fn: Callable[..., torch.Tensor] = _torch_dot_product_attention, + qk_norm: str = "rms", + fuse_qkv: bool = False, ): super().__init__() self.num_heads = num_heads @@ -167,16 +188,20 @@ def __init__( self.use_bias = use_bias self.attention_fn = attention_fn self.qk_norm = qk_norm + self.fuse_qkv = fuse_qkv if self.in_features % self.num_heads != 0: raise ValueError( - f"Memory dimension ({self.in_features}) must be divisible by " - f"'num_heads' heads ({self.num_heads})." + f"Memory dimension ({self.in_features}) must be divisible by " + f"'num_heads' heads ({self.num_heads})." 
) - self.query = nn.Linear(self.in_features, self.in_features, bias=use_bias) - self.key = nn.Linear(self.in_features, self.in_features, bias=use_bias) - self.value = nn.Linear(self.in_features, self.in_features, bias=use_bias) + if self.fuse_qkv: + self.qkv_proj = nn.Linear(self.in_features, 3 * self.in_features, bias=use_bias) + else: + self.query = nn.Linear(self.in_features, self.in_features, bias=use_bias) + self.key = nn.Linear(self.in_features, self.in_features, bias=use_bias) + self.value = nn.Linear(self.in_features, self.in_features, bias=use_bias) self.out = nn.Linear(self.in_features, self.in_features, bias=use_bias) if self.qk_norm == "rms": @@ -189,7 +214,7 @@ def __init__( self.use_rotary_position_embeddings = use_rotary_position_embeddings if self.use_rotary_position_embeddings: self.rotary_position_embedding = RotaryPositionalEmbedding( - embedding_dims=self.head_dim, + embedding_dims=self.head_dim, ) self.use_per_dim_scale = use_per_dim_scale @@ -197,41 +222,41 @@ def __init__( self.per_dim_scale = PerDimScale(num_dims=self.head_dim) def forward( - self, - inputs_q: torch.Tensor, - *, - decode_cache: DecodeCache | None = None, - patch_mask: torch.Tensor | None = None, + self, + inputs_q: torch.Tensor, + *, + decode_cache: DecodeCache | None = None, + patch_mask: torch.Tensor | None = None, ) -> tuple[torch.Tensor, DecodeCache | None]: b, n_patches, _ = inputs_q.shape if patch_mask is None: - patch_mask = torch.zeros( - b, n_patches, dtype=torch.bool, device=inputs_q.device - ) - - query = self.query(inputs_q).view( - b, n_patches, self.num_heads, self.head_dim - ) - key = self.key(inputs_q).view(b, n_patches, self.num_heads, self.head_dim) - value = self.value(inputs_q).view( - b, n_patches, self.num_heads, self.head_dim - ) + patch_mask = torch.zeros(b, n_patches, dtype=torch.bool, device=inputs_q.device) + + if self.fuse_qkv: + qkv = self.qkv_proj(inputs_q) + query, key, value = torch.chunk(qkv, 3, dim=-1) + query = query.view(b, n_patches, self.num_heads, self.head_dim) + key = key.view(b, n_patches, self.num_heads, self.head_dim) + value = value.view(b, n_patches, self.num_heads, self.head_dim) + else: + query = self.query(inputs_q).view(b, n_patches, self.num_heads, self.head_dim) + key = self.key(inputs_q).view(b, n_patches, self.num_heads, self.head_dim) + value = self.value(inputs_q).view(b, n_patches, self.num_heads, self.head_dim) if decode_cache is None: num_masked = torch.sum(patch_mask.to(torch.int32), dim=-1) next_index = torch.zeros_like(num_masked, dtype=torch.int32) else: num_masked = ( - torch.sum(patch_mask.to(torch.int32), dim=-1) - + decode_cache.num_masked + torch.sum(patch_mask.to(torch.int32), dim=-1) + decode_cache.num_masked ) next_index = decode_cache.next_index.clone() if self.use_rotary_position_embeddings: position = ( - torch.arange(n_patches, device=inputs_q.device)[None, :] - + next_index[:, None] - - num_masked[:, None] + torch.arange(n_patches, device=inputs_q.device)[None, :] + + next_index[:, None] + - num_masked[:, None] ) query = self.rotary_position_embedding(query, position) key = self.rotary_position_embedding(key, position) @@ -244,32 +269,36 @@ def forward( if decode_cache is not None: _, decode_cache_size, _, _ = decode_cache.value.shape - for i in range(b): - start = decode_cache.next_index[i] - end = start + n_patches - decode_cache.key[i, start:end] = key[i].clone() - decode_cache.value[i, start:end] = value[i].clone() - key = decode_cache.key.clone() - value = decode_cache.value.clone() + + start = 
decode_cache.next_index[0] + end = start + n_patches + + # Perform a single, vectorized slice assignment for the entire batch. + # This is vastly more efficient than a Python for-loop. + + decode_cache.key[:, start:end] = key + decode_cache.value[:, start:end] = value + + key = decode_cache.key + value = decode_cache.value decode_cache.next_index += n_patches decode_cache.num_masked = num_masked attn_mask = make_attn_mask( - query_length=n_patches, - num_all_masked_kv=num_masked, - query_index_offset=next_index, - kv_length=decode_cache_size, + query_length=n_patches, + num_all_masked_kv=num_masked, + query_index_offset=next_index, + kv_length=decode_cache_size, ) else: - attn_mask = make_attn_mask( - query_length=n_patches, num_all_masked_kv=num_masked - ) + attn_mask = make_attn_mask(query_length=n_patches, num_all_masked_kv=num_masked) x = self.attention_fn( - query, - key, - value, - mask=attn_mask, + query, + key, + value, + mask=attn_mask, ) + x = x.reshape(b, n_patches, self.in_features) out = self.out(x) return out, decode_cache @@ -289,11 +318,12 @@ def __init__(self, config: configs.TransformerConfig): raise ValueError(f"Layer norm: {config.attention_norm} not supported.") self.attn = MultiHeadAttention( - num_heads=config.num_heads, - in_features=config.model_dims, - use_per_dim_scale=True, - use_rotary_position_embeddings=config.use_rotary_position_embeddings, - qk_norm=config.qk_norm, + num_heads=config.num_heads, + in_features=config.model_dims, + use_per_dim_scale=True, + use_rotary_position_embeddings=config.use_rotary_position_embeddings, + qk_norm=config.qk_norm, + fuse_qkv=config.fuse_qkv, ) if config.feedforward_norm == "rms": @@ -303,14 +333,14 @@ def __init__(self, config: configs.TransformerConfig): raise ValueError(f"Layer norm: {config.feedforward_norm} not supported.") self.ff0 = nn.Linear( - in_features=config.model_dims, - out_features=config.hidden_dims, - bias=config.use_bias, + in_features=config.model_dims, + out_features=config.hidden_dims, + bias=config.use_bias, ) self.ff1 = nn.Linear( - in_features=config.hidden_dims, - out_features=config.model_dims, - bias=config.use_bias, + in_features=config.hidden_dims, + out_features=config.model_dims, + bias=config.use_bias, ) if config.ff_activation == "relu": self.activation = nn.ReLU() @@ -322,21 +352,19 @@ def __init__(self, config: configs.TransformerConfig): raise ValueError(f"Activation: {config.ff_activation} not supported.") def forward( - self, - input_embeddings: torch.Tensor, - patch_mask: torch.Tensor, - decode_cache: DecodeCache | None = None, + self, + input_embeddings: torch.Tensor, + patch_mask: torch.Tensor, + decode_cache: DecodeCache | None = None, ) -> tuple[torch.Tensor, DecodeCache | None]: attn_output, decode_cache = self.attn( - inputs_q=self.pre_attn_ln(input_embeddings), - decode_cache=decode_cache, - patch_mask=patch_mask, + inputs_q=self.pre_attn_ln(input_embeddings), + decode_cache=decode_cache, + patch_mask=patch_mask, ) attn_output = self.post_attn_ln(attn_output) + input_embeddings output_embeddings = ( - self.post_ff_ln( - self.ff1(self.activation(self.ff0(self.pre_ff_ln(attn_output)))) - ) - + attn_output + self.post_ff_ln(self.ff1(self.activation(self.ff0(self.pre_ff_ln(attn_output))))) + + attn_output ) return output_embeddings, decode_cache \ No newline at end of file diff --git a/src/samay/utils.py b/src/samay/utils.py index 55af6e0..b9885ae 100644 --- a/src/samay/utils.py +++ b/src/samay/utils.py @@ -490,7 +490,7 @@ def f1_score(predict, actual): return f1 -def 
quantile_loss(self, pred, actual, quantile): +def quantile_loss(pred, actual, quantile): """Calculates quantile loss.""" dev = actual - pred loss_first = dev * quantile
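Dropping the stray `self` makes `quantile_loss` callable as the module-level helper used by the new `finetune` loop in `model.py`. For reference, a self-contained sketch of the pinball loss being computed here, assuming torch tensors (the remainder of the `utils.py` function lies beyond this hunk):

    import torch

    def pinball_loss(pred, actual, quantile):
        # Element-wise pinball loss: max(q * dev, (q - 1) * dev),
        # where dev = actual - pred.
        dev = actual - pred
        return torch.maximum(quantile * dev, (quantile - 1.0) * dev)

    # In TimesFM_2p5_Model.finetune, the mean is accumulated once per
    # quantile head, e.g. for the 0.2 head (index 2 of quantile_forecast):
    # loss += torch.mean(pinball_loss(quantile_forecast[:, :, 2], target_seq, 0.2))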