Skip to content

Commit

Permalink
refactor(base_streams): refactor base_streams.py by applying black
Browse files Browse the repository at this point in the history
  • Loading branch information
BjoernLudwigPTB committed Jun 29, 2021
1 parent 8c62972 commit 34ff2d0
Showing 1 changed file with 56 additions and 45 deletions.
101 changes: 56 additions & 45 deletions agentMET4FOF/streams/base_streams.py
Expand Up @@ -5,7 +5,8 @@
from time_series_metadata.scheme import MetaData
import warnings

class DataStreamMET4FOF():

class DataStreamMET4FOF:
"""
Abstract class for creating datastreams.
Expand Down Expand Up @@ -79,10 +80,10 @@ def __init__(self):
self._current_sample_quantities: Union[List, DataFrame, np.ndarray]
self._current_sample_target: Union[List, DataFrame, np.ndarray]
self._current_sample_time: Union[List, DataFrame, np.ndarray]
self._sample_idx: int = 0 #current sample index
self._n_samples: int = 0 #total number of samples
self._sample_idx: int = 0 # current sample index
self._n_samples: int = 0 # total number of samples
self._data_source_type: str = "function"
self._generator_function : Callable
self._generator_function: Callable
self._generator_parameters: Dict = {}
self.sfreq: int = 1
self._metadata: MetaData
Expand All @@ -105,7 +106,10 @@ def randomize_data(self):
np.random.shuffle(random_index)
self._quantities = self._quantities[random_index]

if type(self._target).__name__ == "ndarray" or type(self._target).__name__ == "list":
if (
type(self._target).__name__ == "ndarray"
or type(self._target).__name__ == "list"
):
self._target = self._target[random_index]
elif type(self._target).__name__ == "DataFrame":
self._target = self._target.iloc[random_index]
Expand All @@ -119,13 +123,13 @@ def sample_idx(self):
return self._sample_idx

def set_metadata(
self,
device_id: str,
time_name: str,
time_unit: str,
quantity_names: Union[str, Tuple[str, ...]],
quantity_units: Union[str, Tuple[str, ...]],
misc: Optional[Any] = None
self,
device_id: str,
time_name: str,
time_unit: str,
quantity_names: Union[str, Tuple[str, ...]],
quantity_units: Union[str, Tuple[str, ...]],
misc: Optional[Any] = None,
):
"""Set the quantities metadata as a ``MetaData`` object
Expand Down Expand Up @@ -154,7 +158,7 @@ def set_metadata(
time_unit=time_unit,
quantity_names=quantity_names,
quantity_units=quantity_units,
misc=misc
misc=misc,
)

def _default_generator_function(self, time):
Expand All @@ -164,11 +168,11 @@ def _default_generator_function(self, time):
----------
time : Union[List, DataFrame, np.ndarray]
"""
value = np.sin(2*np.pi*self.F*time)
value = np.sin(2 * np.pi * self.F * time)
return value

def set_generator_function(
self, generator_function: Callable = None, sfreq: int = None, **kwargs: Any
self, generator_function: Callable = None, sfreq: int = None, **kwargs: Any
):
"""
Sets the data source to a generator function. By default, this function resorts
Expand All @@ -191,14 +195,14 @@ def set_generator_function(
The generator function call for every sample will be supplied with the
``**generator_parameters``.
"""
#save the kwargs into generator_parameters
# save the kwargs into generator_parameters
self._generator_parameters = kwargs

if sfreq is not None:
self.sfreq = sfreq
self._set_data_source_type("function")

#resort to default wave generator if one is not supplied
# resort to default wave generator if one is not supplied
if generator_function is None:
warnings.warn(
"No uncertainty generator function specified. Setting to default ("
Expand All @@ -214,21 +218,20 @@ def _next_sample_generator(self, batch_size: int = 1) -> Dict[str, np.ndarray]:
"""
Internal method for generating a batch of samples from the generator function.
"""
time: np.ndarray = np.arange(self._sample_idx, self._sample_idx + batch_size,
1)/self.sfreq
time: np.ndarray = (
np.arange(self._sample_idx, self._sample_idx + batch_size, 1) / self.sfreq
)
self._sample_idx += batch_size

value: np.ndarray = self._generator_function(
time, **self._generator_parameters
)
value: np.ndarray = self._generator_function(time, **self._generator_parameters)

return {'quantities': value, 'time': time}
return {"quantities": value, "time": time}

def set_data_source(
self,
quantities: Union[List, DataFrame, np.ndarray]=None,
target: Optional[Union[List, DataFrame, np.ndarray]]=None,
time: Optional[Union[List, DataFrame, np.ndarray]]=None
self,
quantities: Union[List, DataFrame, np.ndarray] = None,
target: Optional[Union[List, DataFrame, np.ndarray]] = None,
time: Optional[Union[List, DataFrame, np.ndarray]] = None,
):
"""
This sets the data source by providing up to three iterables: ``quantities`` ,
Expand Down Expand Up @@ -269,10 +272,10 @@ def set_data_source(
self._target = target
self._time = time

#infer number of samples
# infer number of samples
if type(self._quantities).__name__ == "list":
self._n_samples = len(self._quantities)
elif type(self._quantities).__name__ == "DataFrame": #dataframe or numpy
elif type(self._quantities).__name__ == "DataFrame": # dataframe or numpy
self._quantities = self._quantities.to_numpy()
self._n_samples = self._quantities.shape[0]
elif type(self._quantities).__name__ == "ndarray":
Expand Down Expand Up @@ -312,9 +315,9 @@ def next_sample(self, batch_size: int = 1):
'target':current_sample_target}``
"""

if self._data_source_type == 'function':
if self._data_source_type == "function":
return self._next_sample_generator(batch_size)
elif self._data_source_type == 'dataset':
elif self._data_source_type == "dataset":
return self._next_sample_data_source(batch_size)

def _next_sample_data_source(
Expand All @@ -340,26 +343,35 @@ def _next_sample_data_source(
self._sample_idx += batch_size

try:
self._current_sample_quantities = self._quantities[self._sample_idx - batch_size:self._sample_idx]
self._current_sample_quantities = self._quantities[
self._sample_idx - batch_size : self._sample_idx
]

#if target is available
# if target is available
if self._target is not None:
self._current_sample_target = self._target[self._sample_idx - batch_size:self._sample_idx]
self._current_sample_target = self._target[
self._sample_idx - batch_size : self._sample_idx
]
else:
self._current_sample_target = None

#if time is available
# if time is available
if self._time is not None:
self._current_sample_time = self._time[self._sample_idx - batch_size
:self._sample_idx]
self._current_sample_time = self._time[
self._sample_idx - batch_size : self._sample_idx
]
else:
self._current_sample_time = None
except IndexError:
self._current_sample_quantities = None
self._current_sample_target = None
self._current_sample_time = None

return {'time':self._current_sample_time, 'quantities': self._current_sample_quantities, 'target': self._current_sample_target}
return {
"time": self._current_sample_time,
"quantities": self._current_sample_quantities,
"target": self._current_sample_target,
}

def reset(self):
self._sample_idx = 0
Expand All @@ -380,13 +392,12 @@ def extract_x_y(message):
Handle data structures of dictionary to extract features & target
"""
if type(message['data']) == tuple:
x = message['data'][0]
y = message['data'][1]
elif type(message['data']) == dict:
x = message['data']['x']
y = message['data']['y']
if type(message["data"]) == tuple:
x = message["data"][0]
y = message["data"][1]
elif type(message["data"]) == dict:
x = message["data"]["x"]
y = message["data"]["y"]
else:
return 1
return x, y

0 comments on commit 34ff2d0

Please sign in to comment.