Skip to content

Commit

Permalink
Update with mypy typing (#427)
Browse files Browse the repository at this point in the history
  • Loading branch information
allenporter committed May 12, 2024
1 parent 9b97a53 commit 951f2c6
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 62 deletions.
16 changes: 8 additions & 8 deletions pyrainbird/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def _device_busy_retry() -> JitterRetry:
return JitterRetry(
attempts=_retry_attempts(),
start_timeout=_retry_delay(),
statuses=[HTTPStatus.SERVICE_UNAVAILABLE.value],
statuses=set([HTTPStatus.SERVICE_UNAVAILABLE.value]),
retry_all_server_errors=False,
)

Expand All @@ -114,10 +114,10 @@ def __init__(
self._password = password
self._coder = encryption.PayloadCoder(password, _LOGGER)

def with_retry_options(self, retry_options: RetryOptions) -> "AsyncRainbirdClient":
def with_retry_options(self, retry_options: RetryOptions) -> "AsyncRainbirdClient": # type: ignore[valid-type]
"""Create a new AsyncRainbirdClient with retry options."""
return AsyncRainbirdClient(
RetryClient(client_session=self._websession, retry_options=retry_options),
RetryClient(client_session=self._websession, retry_options=retry_options), # type: ignore[arg-type]
self._host,
self._password,
)
Expand Down Expand Up @@ -147,7 +147,7 @@ async def request(
"Error communicating with Rain Bird device"
) from err
content = await resp.read()
return self._coder.decode_command(content)
return self._coder.decode_command(content) # type: ignore


def CreateController(
Expand All @@ -165,7 +165,7 @@ class AsyncRainbirdController:
def __init__(
self,
local_client: AsyncRainbirdClient,
cloud_client: AsyncRainbirdClient = None,
cloud_client: AsyncRainbirdClient | None = None,
) -> None:
"""Initialize AsyncRainbirdController."""
self._local_client = local_client
Expand Down Expand Up @@ -418,15 +418,15 @@ async def get_schedule(self) -> Schedule:
commands.append("%04x" % (0x80 | zone_page))
_LOGGER.debug("Sending schedule commands: %s", commands)
# Run command serially to avoid overwhelming the controller
schedule_data = {
schedule_data: dict[str, Any] = {
"controllerInfo": {},
"programInfo": [],
"programStartInfo": [],
"durations": [],
}
for command in commands:
result = await self._process_command(
None, "RetrieveScheduleRequest", int(command, 16) # Disable validation
None, "RetrieveScheduleRequest", int(command, 16) # type: ignore
)
if not isinstance(result, dict):
continue
Expand Down Expand Up @@ -509,4 +509,4 @@ async def _cacheable_command(
return result
result = await self._process_command(funct, command, *args)
self._cache[key] = result
return result
return result # type: ignore
152 changes: 106 additions & 46 deletions pyrainbird/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class States:
"""Rainbird controller response containing a bitmask string e.g. active zones."""

count: int
mask: str
mask: int
states: tuple

def __init__(self, mask: str) -> None:
Expand Down Expand Up @@ -195,20 +195,42 @@ class WaterBudget:
class WifiParams(DataClassDictMixin):
"""Wifi parameters for the device."""

mac_address: Optional[str] = field(metadata=field_options(alias="macAddress"), default=None)
mac_address: Optional[str] = field(
metadata=field_options(alias="macAddress"), default=None
)
"""The mac address for the device, also referred to as the stick id."""

local_ip_address: Optional[str] = field(metadata=field_options(alias="localIpAddress"), default=None)
local_netmask: Optional[str] = field(metadata=field_options(alias="localNetmask"), default=None)
local_gateway: Optional[str] = field(metadata=field_options(alias="localGateway"), default=None)
local_ip_address: Optional[str] = field(
metadata=field_options(alias="localIpAddress"), default=None
)
local_netmask: Optional[str] = field(
metadata=field_options(alias="localNetmask"), default=None
)
local_gateway: Optional[str] = field(
metadata=field_options(alias="localGateway"), default=None
)
rssi: Optional[int] = None
wifi_ssid: Optional[str] = field(metadata=field_options(alias="wifiSsid"), default=None)
wifi_password: Optional[str] = field(metadata=field_options(alias="wifiPassword"), default=None)
wifi_security: Optional[str] = field(metadata=field_options(alias="wifiSecurity"), default=None)
ap_timeout_no_lan: Optional[int] = field(metadata=field_options(alias="apTimeoutNoLan"), default=None)
ap_timeout_idle: Optional[int] = field(metadata=field_options(alias="apTimeoutIdle"), default=None)
ap_security: Optional[str] = field(metadata=field_options(alias="apSecurity"), default=None)
sick_version: Optional[str] = field(metadata=field_options(alias="stickVersion"), default=None)
wifi_ssid: Optional[str] = field(
metadata=field_options(alias="wifiSsid"), default=None
)
wifi_password: Optional[str] = field(
metadata=field_options(alias="wifiPassword"), default=None
)
wifi_security: Optional[str] = field(
metadata=field_options(alias="wifiSecurity"), default=None
)
ap_timeout_no_lan: Optional[int] = field(
metadata=field_options(alias="apTimeoutNoLan"), default=None
)
ap_timeout_idle: Optional[int] = field(
metadata=field_options(alias="apTimeoutIdle"), default=None
)
ap_security: Optional[str] = field(
metadata=field_options(alias="apSecurity"), default=None
)
sick_version: Optional[str] = field(
metadata=field_options(alias="stickVersion"), default=None
)


class SoilType(IntEnum):
Expand All @@ -227,9 +249,15 @@ class ProgramInfo(DataClassDictMixin):
The values are repeated once for each program.
"""

soil_types: list[SoilType] = field(default_factory=list, metadata=field_options(alias="SoilTypes"))
flow_rates: list[int] = field(default_factory=list, metadata=field_options(alias="FlowRates"))
flow_units: list[int] = field(default_factory=list, metadata=field_options(alias="FlowUnits"))
soil_types: list[SoilType] = field(
default_factory=list, metadata=field_options(alias="SoilTypes")
)
flow_rates: list[int] = field(
default_factory=list, metadata=field_options(alias="FlowRates")
)
flow_units: list[int] = field(
default_factory=list, metadata=field_options(alias="FlowUnits")
)

@classmethod
def __pre_deserialize__(cls, values: dict[Any, Any]) -> dict[Any, Any]:
Expand All @@ -253,9 +281,15 @@ class Settings(DataClassDictMixin):
"""Country location of the device."""

# Program information
soil_types: list[SoilType] = field(default_factory=list, metadata=field_options(alias="SoilTypes"))
flow_rates: list[int] = field(default_factory=list, metadata=field_options(alias="FlowRates"))
flow_units: list[int] = field(default_factory=list, metadata=field_options(alias="FlowUnits"))
soil_types: list[SoilType] = field(
default_factory=list, metadata=field_options(alias="SoilTypes")
)
flow_rates: list[int] = field(
default_factory=list, metadata=field_options(alias="FlowRates")
)
flow_units: list[int] = field(
default_factory=list, metadata=field_options(alias="FlowUnits")
)

@classmethod
def __pre_deserialize__(cls, values: dict[Any, Any]) -> dict[Any, Any]:
Expand Down Expand Up @@ -294,7 +328,7 @@ def __init__(self, status: Optional[str], settings: Optional[Settings]) -> None:
@property
def status(self) -> str:
"""Return device status."""
return self._status
return self._status or "unknown"

@property
def settings(self) -> Optional[Settings]:
Expand All @@ -316,7 +350,9 @@ class Controller(DataClassDictMixin):
available_stations: list[int] = field(
metadata=field_options(alias="availableStations"), default_factory=list
)
custom_name: Optional[str] = field(metadata=field_options(alias="customName"), default=None)
custom_name: Optional[str] = field(
metadata=field_options(alias="customName"), default=None
)
custom_program_names: dict[str, str] = field(
metadata=field_options(alias="customProgramNames"), default_factory=dict
)
Expand Down Expand Up @@ -345,18 +381,30 @@ class Weather(DataClassDictMixin):
city: Optional[str] = None
forecast: list[Forecast] = field(default_factory=list)
location: Optional[str] = None
time_zone_id: Optional[str] = field(metadata=field_options(alias="timeZoneId"), default=None)
time_zone_raw_offset: Optional[str] = field(metadata=field_options(alias="timeZoneRawOffset"), default=None)
time_zone_id: Optional[str] = field(
metadata=field_options(alias="timeZoneId"), default=None
)
time_zone_raw_offset: Optional[str] = field(
metadata=field_options(alias="timeZoneRawOffset"), default=None
)


@dataclass
class WeatherAndStatus(DataClassDictMixin):
"""Weather and status from the cloud API."""

stick_id: Optional[str] = field(metadata=field_options(alias="StickId"), default=None)
controller: Optional[Controller] = field(metadata=field_options(alias="Controller"), default=None)
forecasted_rain: Optional[dict[str, Any]] = field(metadata=field_options(alias="ForecastedRain"), default=None)
weather: Optional[Weather] = field(metadata=field_options(alias="Weather"), default=None)
stick_id: Optional[str] = field(
metadata=field_options(alias="StickId"), default=None
)
controller: Optional[Controller] = field(
metadata=field_options(alias="Controller"), default=None
)
forecasted_rain: Optional[dict[str, Any]] = field(
metadata=field_options(alias="ForecastedRain"), default=None
)
weather: Optional[Weather] = field(
metadata=field_options(alias="Weather"), default=None
)


@dataclass
Expand Down Expand Up @@ -398,6 +446,7 @@ def deserialize(self, values: dict[str, Any]) -> datetime.datetime:
int(values["second"]),
)


@dataclass
class ControllerState(DataClassDictMixin):
"""Details about the controller state."""
Expand All @@ -417,13 +466,14 @@ class ControllerState(DataClassDictMixin):
# TODO: Likely need to make this a mask w/ States
active_station: int = field(metadata=field_options(alias="activeStation"))

device_time: datetime.datetime = field(metadata=field_options(serialization_strategy=DeviceTime()))
device_time: datetime.datetime = field(
metadata=field_options(serialization_strategy=DeviceTime())
)

@classmethod
def __pre_deserialize__(cls, d: dict[Any, Any]) -> dict[Any, Any]:
d["device_time"] = {
k: d[k]
for k in ("year", "month", "day", "hour", "minute", "second")
k: d[k] for k in ("year", "month", "day", "hour", "minute", "second")
}
return d

Expand Down Expand Up @@ -459,7 +509,7 @@ def name(self) -> str:
@classmethod
def __pre_deserialize__(cls, values: dict[Any, Any]) -> dict[Any, Any]:
if duration := values.get("duration"):
values["duration"] = duration * 60 #datetime.timedelta(minutes=duration)
values["duration"] = duration * 60 # datetime.timedelta(minutes=duration)
return values


Expand All @@ -479,15 +529,13 @@ def deserialize(self, starts: list[int]) -> list[datetime.time]:
return result




class DayOfWeekSerializationStrategy(SerializationStrategy):
"""Validate different ways the device time parameter is handled."""

def serialize(self, value: Any) -> str:
raise ValueError("Serialization not implemented")

def deserialize(self, mask: int) -> list[DayOfWeek]:
def deserialize(self, mask: int) -> set[DayOfWeek]:
"""Deserialize the device time fields."""
_LOGGER.debug("DayOfWeekSerializationStrategy=%s", mask)
result: set[DayOfWeek] = set()
Expand All @@ -512,7 +560,13 @@ class Program(DataClassDictMixin):
frequency: ProgramFrequency
"""Determines how often the program runs."""

days_of_week: set[DayOfWeek] = field(metadata=field_options(alias="daysOfWeekMask", serialization_strategy=DayOfWeekSerializationStrategy()), default_factory=set)
days_of_week: set[DayOfWeek] = field(
metadata=field_options(
alias="daysOfWeekMask",
serialization_strategy=DayOfWeekSerializationStrategy(),
),
default_factory=set,
)
"""For a CUSTOM program determines the days of the week."""

period: Optional[int] = None
Expand All @@ -521,13 +575,18 @@ class Program(DataClassDictMixin):
synchro: Optional[int] = None
"""Days from today before starting the first day of the program."""

starts: list[datetime.time] = field(default_factory=list, metadata=field_options(serialization_strategy=TimeSerializationStrategy()))
starts: list[datetime.time] = field(
default_factory=list,
metadata=field_options(serialization_strategy=TimeSerializationStrategy()),
)
"""Time of day the program starts."""

durations: list[ZoneDuration] = field(default_factory=list)
"""Durations for run times for each zone."""

controller_info: Optional[ControllerInfo] = field(metadata=field_options(alias="controllerInfo"), default=None)
controller_info: Optional[ControllerInfo] = field(
metadata=field_options(alias="controllerInfo"), default=None
)
"""Information about the controller as input into the programs."""

@property
Expand All @@ -541,7 +600,7 @@ def timeline(self) -> ProgramTimeline:
"""Return a timeline of events for the program."""
return self.timeline_tz(datetime.datetime.now().tzinfo)

def timeline_tz(self, tzinfo: datetime.tzinfo) -> ProgramTimeline:
def timeline_tz(self, tzinfo: datetime.tzinfo | None) -> ProgramTimeline:
"""Return a timeline of events for the program."""
iters: list[Iterable[SortableItem[Timespan, ProgramEvent]]] = []
now = datetime.datetime.now(tzinfo)
Expand All @@ -553,9 +612,9 @@ def timeline_tz(self, tzinfo: datetime.tzinfo) -> ProgramTimeline:
self.frequency,
dtstart,
self.duration,
self.synchro,
self.synchro or 0,
self.days_of_week,
self.period,
self.period or 0,
delay_days=self.delay_days,
),
)
Expand All @@ -575,9 +634,9 @@ def zone_timeline(self) -> ProgramTimeline:
self.frequency,
dtstart,
zone_duration.duration,
self.synchro,
self.synchro or 0,
self.days_of_week,
self.period,
self.period or 0,
delay_days=self.delay_days,
)
)
Expand All @@ -604,12 +663,13 @@ def __post_init__(self):
self.period = None



@dataclass
class Schedule(DataClassDictMixin):
"""Details about program schedules."""

controller_info: Optional[ControllerInfo] = field(metadata=field_options(alias="controllerInfo"))
controller_info: Optional[ControllerInfo] = field(
metadata=field_options(alias="controllerInfo")
)
"""Information about the controller used in the schedule."""

programs: list[Program] = field(metadata=field_options(alias="programInfo"))
Expand All @@ -620,7 +680,7 @@ def timeline(self) -> ProgramTimeline:
"""Return a timeline of all programs."""
return self.timeline_tz(datetime.datetime.now().tzinfo)

def timeline_tz(self, tzinfo: datetime.tzinfo) -> ProgramTimeline:
def timeline_tz(self, tzinfo: datetime.tzinfo | None) -> ProgramTimeline:
"""Return a timeline of all programs."""
iters: list[Iterable[SortableItem[Timespan, ProgramEvent]]] = []
now = datetime.datetime.now(tzinfo)
Expand All @@ -633,9 +693,9 @@ def timeline_tz(self, tzinfo: datetime.tzinfo) -> ProgramTimeline:
program.frequency,
dtstart,
program.duration,
program.synchro,
program.synchro or 0,
program.days_of_week,
program.period,
program.period or 0,
delay_days=self.delay_days,
)
)
Expand Down
4 changes: 2 additions & 2 deletions pyrainbird/encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def encode_command(self, method: str, params: dict[str, Any]) -> str:
return send_data
return encrypt(send_data, self._password)

def decode_command(self, content: bytes) -> str:
def decode_command(self, content: bytes) -> str | dict[str, Any]:
"""Decode a response payload."""
if self._password is not None:
decrypted_data = (
Expand All @@ -112,7 +112,7 @@ def decode_command(self, content: bytes) -> str:
.rstrip()
)
content = decrypted_data
self._logger.debug("Response: %s" % content)
self._logger.debug("Response: %r" % content)
response = json.loads(content)
if error := response.get("error"):
msg = ["Error from controller"]
Expand Down
Loading

0 comments on commit 951f2c6

Please sign in to comment.