Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more flexible sampler types through Range #2758

Draft
wants to merge 1 commit into
base: dev
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 80 additions & 1 deletion src/gluonts/transform/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,93 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import Tuple
from dataclasses import dataclass
from typing import Tuple, Optional, Union

import numpy as np
import pandas as pd
from pydantic import BaseModel

from gluonts.dataset.stat import ScaleHistogram


def clip(value, low, high):
"""
Clip ``value`` between ``low`` and ``high``, included.
"""
return max(low, min(high, value))


@dataclass
class Range:
start: Optional[Union[int, pd.Period]] = None
stop: Optional[Union[int, pd.Period]] = None
step: int = 1

def _start_as_int(self, start: pd.Period, length: int) -> int:
if self.start is None:
return 0
if isinstance(self.start, pd.Period):
return int((self.start - start) / start.freq)
if self.start < 0:
return length + self.start
return self.start

def _stop_as_int(self, start: pd.Period, length: int) -> int:
if self.stop is None:
return length
if isinstance(self.stop, pd.Period):
return int((self.stop - start) / start.freq)
if self.stop < 0:
return length + self.stop
return self.stop

def get(self, start: pd.Period, length: int) -> range:
return range(
clip(self._start_as_int(start, length), 0, length),
clip(self._stop_as_int(start, length), 0, length),
self.step,
)


@dataclass
class Sampler:
range_: Range

def sample(self, rge: range) -> list:
raise NotImplementedError()

def __call__(self, start: pd.Period, length: int) -> list:
return self.sample(self.range_.get(start, length))


@dataclass
class SampleAll(Sampler):
def sample(self, rge: range) -> list:
return list(rge)


@dataclass
class SampleOnAverage(Sampler):
average_num_samples: float = 1.0

def __post_init__(self):
self.average_length = 0
self.count = 0

def sample(self, rge: range) -> list:
if len(rge) == 0:
return []

self.average_length = (self.count * self.average_length + len(rge)) / (
self.count + 1
)
self.count += 1
p = self.average_num_samples / self.average_length
(indices,) = np.where(np.random.random_sample(len(rge)) < p)
return (min(rge) + indices).tolist()


class InstanceSampler(BaseModel):
"""
An InstanceSampler is called with the time series ``ts``, and returns a set
Expand Down