Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Zebras: Add update to TimeSeries. #2855

Open
wants to merge 12 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions src/gluonts/zebras/_time_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
rows_to_columns,
select,
join_items,
replace,
)

from ._base import Pad, TimeBase
Expand Down Expand Up @@ -168,6 +169,112 @@ def pad(self, value, left: int = 0, right: int = 0) -> TimeFrame:
_pad=Pad(pad_left, pad_right),
)

def update(self, other: TimeFrame, default=np.nan) -> TimeFrame:
"""Create a new ``TimeFrame`` which includes values of both input
frames.

The new frame spans both input frames and inserts default values if
there is a gap between the two frames. If both frames overlap, the
second overwrites the values of the first frame.

Static columns and metadata is also updated, and the second frames
value take precedence.

Updating a frame with itself is effectively a noop, similar to how
``dict.update`` on the same dict will return an identical result.

Update requires that both frame have defined indices, since otherwise
its not possible to know how the values relate to each other.

Note: ``update`` will reset the padding.
"""

if self.index is None or other.index is None:
raise ValueError("Both time frames need to have an index.")

if self.index.freq != other.index.freq:
raise ValueError("frequency mismatch on index.")

# ensure tdims match
for name, left, right in join_items(self.tdims, other.tdims, "inner"):
if left != right:
raise ValueError(
f"tdims of {name} don't match {left} != {right}"
)

# ensure column shapes match
for name, left, right in join_items(
self.columns, other.columns, "inner"
):
tdim = self.tdims[name]

if replace(left.shape, tdim, 0) != replace(right.shape, tdim, 0):
raise ValueError(f"Incompatible shapes of columns {name}")

start = min(self.index.start, other.index.start)
end = max(self.index.end, other.index.end)

# create a new index that spans the new range
index = Periods(
np.arange(
start.to_numpy(),
# arange is exclusive, thus we need to add `1`
(end + 1).to_numpy(),
start.freq.step,
),
start.freq,
)
# get position of self and other relative to new index
# (one of them will be zero)
self_idx0 = index.index_of(self.index.start)
other_idx0 = index.index_of(other.index.start)

tdims = {**self.tdims, **other.tdims}
# create new columns, by first filling them with default values and
# then writing the values of self and other to them
columns = {}
for name, self_col, other_col in join_items(
self.columns, other.columns, "outer"
):
tdim = tdims[name]

values = np.full(
replace(
cast(
TimeFrame,
maybe.unwrap(maybe.or_(self_col, other_col)),
).shape,
tdim,
len(index),
),
default,
)
view = AxisView(values, tdim)

if self_col is not None:
view[self_idx0 : self_idx0 + len(self)] = self_col
if other_col is not None:
view[other_idx0 : other_idx0 + len(other)] = other_col

columns[name] = values

static = {**self.static, **other.static}

if self.metadata is not None and other.metadata is not None:
metadata: Optional[dict] = {**self.metadata, **other.metadata}
else:
metadata = maybe.or_(self.metadata, other.metadata)

return _replace(
self,
columns=columns,
static=static,
index=index,
length=len(index),
metadata=metadata,
_pad=Pad(),
)

def astype(self, type, columns=None) -> TimeFrame:
if columns is None:
columns = self.columns
Expand Down
80 changes: 79 additions & 1 deletion src/gluonts/zebras/_time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from toolz import first

from gluonts import maybe
from gluonts.itertools import pluck_attr
from gluonts.itertools import pluck_attr, replace

from ._base import Pad, TimeBase
from ._period import period, Period, Periods
Expand Down Expand Up @@ -91,6 +91,84 @@ def pad(self, value, left: int = 0, right: int = 0) -> TimeSeries:
_pad=self._pad.extend(left, right),
)

def update(self, other: TimeSeries, default=np.nan) -> TimeSeries:
"""Create a new ``TimeSeries`` which includes values of both input
series.

The new series spans both input series and inserts default values if
there is a gap between the two series. If both series overlap, the
second overwrites the values of the first series.

Name and metadata is also updated, and the second series
value take precedence.

Updating a series with itself is effectively a noop, similar to how
``dict.update`` on the same dict will return an identical result.

Update requires that both series have defined indices, since otherwise
its not possible to know how the values relate to each other.

Note: ``update`` will reset the padding.
"""

if self.index is None or other.index is None:
raise ValueError("Both time frames need to have an index.")

if self.index.freq != other.index.freq:
raise ValueError("frequency mismatch on index.")

# ensure tdims match
if self.tdim != other.tdim:
raise ValueError("tdims mismatch.")

if replace(np.shape(self), self.tdim, 0) != replace(
np.shape(other), other.tdim, 0
):
raise ValueError("Incompatible shapes.")

start = min(self.index.start, other.index.start)
end = max(self.index.end, other.index.end)

index = Periods(
np.arange(
start.to_numpy(),
# arange is exclusive, thus we need to add `1`
(end + 1).to_numpy(),
start.freq.step,
),
start.freq,
)

values = np.full(
replace(np.shape(self), self.tdim, len(index)),
default,
)
view = AxisView(values, self.tdim)

idx = index.index_of(self.index.start)
view[idx : idx + len(self)] = self.values

idx = index.index_of(other.index.start)
view[idx : idx + len(other)] = other.values

if self.metadata is not None and other.metadata is not None:
metadata: Optional[dict] = {**self.metadata, **other.metadata}
else:
metadata = maybe.or_(self.metadata, other.metadata)

# TODO: Pad -- does it even make sense?

name = maybe.or_(other.name, self.name)

return _replace(
self,
values=values,
index=index,
metadata=metadata,
name=name,
_pad=Pad(),
)

@staticmethod
def _batch(xs: List[TimeSeries]) -> BatchTimeSeries:
for series in xs:
Expand Down
6 changes: 6 additions & 0 deletions src/gluonts/zebras/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ def __getitem__(self, index):

return self.data[tuple(slices)]

def __setitem__(self, index, value):
slices = [slice(None)] * self.data.ndim
slices[self.axis] = index

self.data[tuple(slices)] = value

def __len__(self):
return self.data.shape[self.axis]

Expand Down
43 changes: 43 additions & 0 deletions test/zebras/test_timeframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,46 @@ def test_rename():

tf2 = tf.rename_static(x="static")
assert tf2.static.keys() == {"x"}


def test_update_timeframe():
assert tf.update(tf).eq_to(tf)

left = tf[:3]
right = tf[-3:]

tf2 = left.update(right)
assert len(tf) == len(tf2)
assert tf.index == tf2.index

gap = tf2[3:-3]
assert np.isnan(gap["target"]).all()

xxx = zb.time_frame(
{"xxx": np.full(3, 99)},
index=right.index,
)
tf3 = left.update(xxx)

assert len(tf) == len(tf3)
assert tf.index == tf3.index
tf3.columns.keys() == tf.columns.keys() | {"xxx"}


def test_update_timeseries():
ts = tf["target"]
assert np.all(ts.update(ts) == ts)

left = ts[:3]
right = ts[-3:]

ts2 = left.update(right)
assert len(ts) == len(ts2)
assert ts.index == ts2.index

gap = ts2[3:-3]
assert np.isnan(gap).all()

ts3 = left.update(zb.time_series(np.full(3, 99), index=right.index))
assert len(ts) == len(ts3)
assert ts.index == ts3.index
Loading