Skip to content

Commit

Permalink
[FEAT] ISQF (#1019)
Browse files Browse the repository at this point in the history
* isqf

* fix_docstrings

---------

Co-authored-by: Cristian Challu <cristiani.challu@gmail.com>
  • Loading branch information
elephaint and cchallu committed Jun 13, 2024
1 parent e1dc723 commit 90b6fd0
Show file tree
Hide file tree
Showing 10 changed files with 1,815 additions and 106 deletions.
6 changes: 5 additions & 1 deletion nbs/common.base_multivariate.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,11 @@
" # Model Predictions\n",
" output = self(windows_batch)\n",
" if self.loss.is_distribution_output:\n",
" _, y_loc, y_scale = self._inv_normalization(y_hat=output[0],\n",
" _, y_loc, y_scale = self._inv_normalization(y_hat=torch.empty(size=(insample_y.shape[0], \n",
" self.h, \n",
" self.n_series),\n",
" dtype=output[0].dtype,\n",
" device=output[0].device),\n",
" temporal_cols=batch['temporal_cols'],\n",
" y_idx=y_idx)\n",
" distr_args = self.loss.scale_decouple(output=output, loc=y_loc, scale=y_scale)\n",
Expand Down
9 changes: 5 additions & 4 deletions nbs/common.base_recurrent.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
"import torch\n",
"import torch.nn as nn\n",
"import pytorch_lightning as pl\n",
"import neuralforecast.losses.pytorch as losses\n",
"\n",
"from neuralforecast.common._base_model import BaseModel\n",
"from neuralforecast.common._scalers import TemporalNorm\n",
Expand Down Expand Up @@ -138,10 +139,10 @@
" self.inference_input_size = inference_input_size\n",
" self.padder = nn.ConstantPad1d(padding=(0, self.h), value=0)\n",
"\n",
"\n",
" if str(type(self.loss)) == \"<class 'neuralforecast.losses.pytorch.DistributionLoss'>\" and\\\n",
" self.loss.distribution=='Bernoulli':\n",
" raise Exception('Temporal Classification not yet available for Recurrent-based models')\n",
" unsupported_distributions = ['Bernoulli', 'ISQF']\n",
" if isinstance(self.loss, losses.DistributionLoss) and\\\n",
" self.loss.distribution in unsupported_distributions:\n",
" raise Exception(f'Distribution {self.loss.distribution} not available for Recurrent-based models. Please choose another distribution.')\n",
"\n",
" # Valid batch_size\n",
" self.batch_size = batch_size\n",
Expand Down
24 changes: 3 additions & 21 deletions nbs/common.base_windows.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -420,12 +420,6 @@
" insample_y, insample_mask, outsample_y, outsample_mask, \\\n",
" hist_exog, futr_exog, stat_exog = self._parse_windows(batch, windows)\n",
"\n",
" # Implicit Quantile Loss\n",
" # if isinstance(self.loss, losses.IQLoss):\n",
" # self.loss.training_update_quantile(batch_size = (insample_y.shape[0], 1), \n",
" # device = insample_y.device)\n",
" # stat_exog = self._update_stat_exog_iqloss(self.loss.q, stat_exog)\n",
"\n",
" windows_batch = dict(insample_y=insample_y, # [Ws, L]\n",
" insample_mask=insample_mask, # [Ws, L]\n",
" futr_exog=futr_exog, # [Ws, L + h, F]\n",
Expand Down Expand Up @@ -514,12 +508,6 @@
" insample_y, insample_mask, _, outsample_mask, \\\n",
" hist_exog, futr_exog, stat_exog = self._parse_windows(batch, windows)\n",
"\n",
" # Implicit Quantile Loss\n",
" # if isinstance(self.valid_loss, losses.IQLoss):\n",
" # self.valid_loss.training_update_quantile(batch_size = (insample_y.shape[0], 1), \n",
" # device = insample_y.device)\n",
" # stat_exog = self._update_stat_exog_iqloss(self.valid_loss.q, stat_exog)\n",
"\n",
" windows_batch = dict(insample_y=insample_y, # [Ws, L]\n",
" insample_mask=insample_mask, # [Ws, L]\n",
" futr_exog=futr_exog, # [Ws, L + h, F]\n",
Expand Down Expand Up @@ -578,14 +566,6 @@
" insample_y, insample_mask, _, _, \\\n",
" hist_exog, futr_exog, stat_exog = self._parse_windows(batch, windows)\n",
"\n",
" # Implicit Quantile Loss\n",
" # if isinstance(self.loss, losses.IQLoss):\n",
" # quantiles = torch.full(size=(insample_y.shape[0], 1), \n",
" # fill_value=self.quantile,\n",
" # device=insample_y.device,\n",
" # dtype=insample_y.dtype) \n",
" # stat_exog = self._update_stat_exog_iqloss(quantiles, stat_exog)\n",
"\n",
" windows_batch = dict(insample_y=insample_y, # [Ws, L]\n",
" insample_mask=insample_mask, # [Ws, L]\n",
" futr_exog=futr_exog, # [Ws, L + h, F]\n",
Expand All @@ -596,7 +576,9 @@
" output_batch = self(windows_batch)\n",
" # Inverse normalization and sampling\n",
" if self.loss.is_distribution_output:\n",
" _, y_loc, y_scale = self._inv_normalization(y_hat=output_batch[0],\n",
" _, y_loc, y_scale = self._inv_normalization(y_hat=torch.empty(size=(insample_y.shape[0], self.h),\n",
" dtype=output_batch[0].dtype,\n",
" device=output_batch[0].device),\n",
" temporal_cols=batch['temporal_cols'],\n",
" y_idx=y_idx)\n",
" distr_args = self.loss.scale_decouple(output=output_batch, loc=y_loc, scale=y_scale)\n",
Expand Down
Loading

0 comments on commit 90b6fd0

Please sign in to comment.