Skip to content

Commit

Permalink
Fixup and tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
jaheba committed May 26, 2023
1 parent 672f80c commit 7bb35c8
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 3 deletions.
8 changes: 6 additions & 2 deletions src/gluonts/zebras/_time_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
rows_to_columns,
select,
join_items,
replace,
)

from ._base import Pad, TimeBase
Expand Down Expand Up @@ -228,16 +229,19 @@ def update(self, other: TimeFrame, default=np.nan) -> TimeFrame:
self_idx0 = index.index_of(self.index.start)
other_idx0 = index.index_of(other.index.start)

tdims = {**self.tdims, **other.tdims}
# create new columns, by first filling them with default values and
# then writing the values of self and other to them
columns = {}
for name, self_col, other_col in join_items(
self.columns, other.columns, "outer"
):
tdim = self.tdims[name]
tdim = tdims[name]

values = np.full(
replace(self_col.shape, tdim, len(index)),
replace(
maybe.or_(self_col, other_col).shape, tdim, len(index)
),
default,
)
view = AxisView(values, tdim)
Expand Down
2 changes: 1 addition & 1 deletion src/gluonts/zebras/_time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from toolz import first

from gluonts import maybe
from gluonts.itertools import pluck_attr, replace
from gluonts.itertools import pluck_attr

from ._base import Pad, TimeBase
from ._period import period, Period, Periods
Expand Down
20 changes: 20 additions & 0 deletions test/zebras/test_timeframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,23 @@ def test_rename():

def test_update():
assert tf.update(tf).eq_to(tf)

left = tf[:3]
right = tf[-3:]

tf2 = left.update(right)
assert len(tf) == len(tf2)
assert tf.index == tf2.index

gap = tf2[3:-3]
assert np.isnan(gap["target"]).all()

xxx = zb.time_frame(
{"xxx": np.full(3, 99)},
index=right.index,
)
tf3 = left.update(xxx)

assert len(tf) == len(tf3)
assert tf.index == tf3.index
tf3.columns.keys() == tf.columns.keys() | {"xxx"}

0 comments on commit 7bb35c8

Please sign in to comment.