From 7bb35c8f347ca319def34174f5ee9fbe284964ed Mon Sep 17 00:00:00 2001 From: Jasper Zschiegner Date: Fri, 26 May 2023 23:03:31 +0200 Subject: [PATCH] Fixup and tests. --- src/gluonts/zebras/_time_frame.py | 8 ++++++-- src/gluonts/zebras/_time_series.py | 2 +- test/zebras/test_timeframe.py | 20 ++++++++++++++++++++ 3 files changed, 27 insertions(+), 3 deletions(-) diff --git a/src/gluonts/zebras/_time_frame.py b/src/gluonts/zebras/_time_frame.py index 202b4ce4d5..a5953f1e64 100644 --- a/src/gluonts/zebras/_time_frame.py +++ b/src/gluonts/zebras/_time_frame.py @@ -27,6 +27,7 @@ rows_to_columns, select, join_items, + replace, ) from ._base import Pad, TimeBase @@ -228,16 +229,19 @@ def update(self, other: TimeFrame, default=np.nan) -> TimeFrame: self_idx0 = index.index_of(self.index.start) other_idx0 = index.index_of(other.index.start) + tdims = {**self.tdims, **other.tdims} # create new columns, by first filling them with default values and # then writing the values of self and other to them columns = {} for name, self_col, other_col in join_items( self.columns, other.columns, "outer" ): - tdim = self.tdims[name] + tdim = tdims[name] values = np.full( - replace(self_col.shape, tdim, len(index)), + replace( + maybe.or_(self_col, other_col).shape, tdim, len(index) + ), default, ) view = AxisView(values, tdim) diff --git a/src/gluonts/zebras/_time_series.py b/src/gluonts/zebras/_time_series.py index f0dae36dfd..b35b9e981d 100644 --- a/src/gluonts/zebras/_time_series.py +++ b/src/gluonts/zebras/_time_series.py @@ -21,7 +21,7 @@ from toolz import first from gluonts import maybe -from gluonts.itertools import pluck_attr, replace +from gluonts.itertools import pluck_attr from ._base import Pad, TimeBase from ._period import period, Period, Periods diff --git a/test/zebras/test_timeframe.py b/test/zebras/test_timeframe.py index 1b307b4dcd..74f3943710 100644 --- a/test/zebras/test_timeframe.py +++ b/test/zebras/test_timeframe.py @@ -196,3 +196,23 @@ def test_rename(): def test_update(): assert tf.update(tf).eq_to(tf) + + left = tf[:3] + right = tf[-3:] + + tf2 = left.update(right) + assert len(tf) == len(tf2) + assert tf.index == tf2.index + + gap = tf2[3:-3] + assert np.isnan(gap["target"]).all() + + xxx = zb.time_frame( + {"xxx": np.full(3, 99)}, + index=right.index, + ) + tf3 = left.update(xxx) + + assert len(tf) == len(tf3) + assert tf.index == tf3.index + tf3.columns.keys() == tf.columns.keys() | {"xxx"}