Skip to content

Commit

Permalink
update min_size in TimeSeriesDataset.append (#1033)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez committed Jul 1, 2024
1 parent ab918ba commit b534ddf
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 19 deletions.
27 changes: 14 additions & 13 deletions nbs/tsdataset.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,9 @@
" len_temporal, col_temporal = self.temporal.shape\n",
" len_futr = futr_dataset.temporal.shape[0]\n",
" new_temporal = torch.empty(size=(len_temporal + len_futr, col_temporal))\n",
" new_sizes = np.diff(self.indptr) + np.diff(futr_dataset.indptr)\n",
" new_indptr = np.append(0, new_sizes.cumsum()).astype(np.int32)\n",
" new_indptr = self.indptr + futr_dataset.indptr\n",
" new_sizes = np.diff(new_indptr)\n",
" new_min_size = np.min(new_sizes)\n",
" new_max_size = np.max(new_sizes)\n",
"\n",
" for i in range(self.n_groups):\n",
Expand All @@ -260,17 +261,17 @@
" new_temporal[new_indptr[i] + curr_size : new_indptr[i + 1]] = futr_dataset.temporal[futr_slice]\n",
" \n",
" # Define new dataset\n",
" updated_dataset = TimeSeriesDataset(temporal=new_temporal,\n",
" temporal_cols=self.temporal_cols.copy(),\n",
" indptr=new_indptr,\n",
" max_size=new_max_size,\n",
" min_size=self.min_size,\n",
" static=self.static,\n",
" y_idx=self.y_idx,\n",
" static_cols=self.static_cols,\n",
" sorted=self.sorted)\n",
"\n",
" return updated_dataset\n",
" return TimeSeriesDataset(\n",
" temporal=new_temporal,\n",
" temporal_cols=self.temporal_cols.copy(),\n",
" indptr=new_indptr,\n",
" max_size=new_max_size,\n",
" min_size=new_min_size,\n",
" static=self.static,\n",
" y_idx=self.y_idx,\n",
" static_cols=self.static_cols,\n",
" sorted=self.sorted\n",
" )\n",
"\n",
" @staticmethod\n",
" def update_dataset(dataset, futr_df, id_col='unique_id', time_col='ds', target_col='y'):\n",
Expand Down
11 changes: 5 additions & 6 deletions neuralforecast/tsdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,9 @@ def append(self, futr_dataset: "TimeSeriesDataset") -> "TimeSeriesDataset":
len_temporal, col_temporal = self.temporal.shape
len_futr = futr_dataset.temporal.shape[0]
new_temporal = torch.empty(size=(len_temporal + len_futr, col_temporal))
new_sizes = np.diff(self.indptr) + np.diff(futr_dataset.indptr)
new_indptr = np.append(0, new_sizes.cumsum()).astype(np.int32)
new_indptr = self.indptr + futr_dataset.indptr
new_sizes = np.diff(new_indptr)
new_min_size = np.min(new_sizes)
new_max_size = np.max(new_sizes)

for i in range(self.n_groups):
Expand All @@ -207,20 +208,18 @@ def append(self, futr_dataset: "TimeSeriesDataset") -> "TimeSeriesDataset":
)

# Define new dataset
updated_dataset = TimeSeriesDataset(
return TimeSeriesDataset(
temporal=new_temporal,
temporal_cols=self.temporal_cols.copy(),
indptr=new_indptr,
max_size=new_max_size,
min_size=self.min_size,
min_size=new_min_size,
static=self.static,
y_idx=self.y_idx,
static_cols=self.static_cols,
sorted=self.sorted,
)

return updated_dataset

@staticmethod
def update_dataset(
dataset, futr_df, id_col="unique_id", time_col="ds", target_col="y"
Expand Down

0 comments on commit b534ddf

Please sign in to comment.