Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions roborock/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,7 @@ class RoborockTooManyRequest(RoborockException):

class RoborockRateLimit(RoborockException):
"""Class for our rate limits exceptions."""


class RoborockNoResponseFromBaseURL(RoborockException):
"""We could not find an url that had a record of the given account."""
130 changes: 87 additions & 43 deletions roborock/web_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import secrets
import string
import time
from dataclasses import dataclass

import aiohttp
from aiohttp import ContentTypeError, FormData
Expand All @@ -22,14 +23,28 @@
RoborockInvalidEmail,
RoborockInvalidUserAgreement,
RoborockMissingParameters,
RoborockNoResponseFromBaseURL,
RoborockNoUserAgreement,
RoborockRateLimit,
RoborockTooFrequentCodeRequests,
RoborockTooManyRequest,
RoborockUrlException,
)

_LOGGER = logging.getLogger(__name__)
BASE_URLS = [
"https://usiot.roborock.com",
"https://euiot.roborock.com",
"https://cniot.roborock.com",
"https://ruiot.roborock.com",
]


@dataclass
class IotLoginInfo:
"""Information about the login to the iot server."""

base_url: str
country_code: str
country: str


class RoborockApiClient:
Expand All @@ -49,41 +64,64 @@ class RoborockApiClient:
_login_limiter = Limiter(_LOGIN_RATES)
_home_data_limiter = Limiter(_HOME_DATA_RATES)

def __init__(self, username: str, base_url=None, session: aiohttp.ClientSession | None = None) -> None:
def __init__(
self, username: str, base_url: str | None = None, session: aiohttp.ClientSession | None = None
) -> None:
"""Sample API Client."""
self._username = username
self._default_url = "https://euiot.roborock.com"
self.base_url = base_url
self._base_url = base_url
self._device_identifier = secrets.token_urlsafe(16)
self.session = session

async def _get_base_url(self) -> str:
if not self.base_url:
url_request = PreparedRequest(self._default_url, self.session)
response = await url_request.request(
"post",
"/api/v1/getUrlByEmail",
params={"email": self._username, "needtwostepauth": "false"},
)
if response is None:
raise RoborockUrlException("get url by email returned None")
response_code = response.get("code")
if response_code != 200:
_LOGGER.info("Get base url failed for %s with the following context: %s", self._username, response)
if response_code == 2003:
raise RoborockInvalidEmail("Your email was incorrectly formatted.")
elif response_code == 1001:
raise RoborockMissingParameters(
"You are missing parameters for this request, are you sure you entered your username?"
self._iot_login_info: IotLoginInfo | None = None

async def _get_iot_login_info(self) -> IotLoginInfo:
if self._iot_login_info is None:
valid_urls = BASE_URLS if self._base_url is None else [self._base_url]
for iot_url in valid_urls:
url_request = PreparedRequest(iot_url, self.session)
response = await url_request.request(
"post",
"/api/v1/getUrlByEmail",
params={"email": self._username, "needtwostepauth": "false"},
)
if response is None:
continue
response_code = response.get("code")
if response_code != 200:
if response_code == 2003:
raise RoborockInvalidEmail("Your email was incorrectly formatted.")
elif response_code == 1001:
raise RoborockMissingParameters(
"You are missing parameters for this request, are you sure you entered your username?"
)
else:
raise RoborockException(f"{response.get('msg')} - response code: {response_code}")
if response["data"]["countrycode"] is not None:
self._iot_login_info = IotLoginInfo(
base_url=response["data"]["url"],
country=response["data"]["country"],
country_code=response["data"]["countrycode"],
)
elif response_code == 9002:
raise RoborockTooManyRequest("Please temporarily disable making requests and try again later.")
raise RoborockUrlException(f"error code: {response_code} msg: {response.get('error')}")
response_data = response.get("data")
if response_data is None:
raise RoborockUrlException("response does not have 'data'")
self.base_url = response_data.get("url")
return self.base_url
return self._iot_login_info
raise RoborockNoResponseFromBaseURL(
"No account was found for any base url we tried. Either your email is incorrect or we do not have a"
" record of the roborock server your device is on."
)
return self._iot_login_info

@property
async def base_url(self):
if self._base_url is not None:
return self._base_url
return (await self._get_iot_login_info()).base_url

@property
async def country(self):
return (await self._get_iot_login_info()).country

@property
async def country_code(self):
return (await self._get_iot_login_info()).country_code

def _get_header_client_id(self):
md5 = hashlib.md5()
Expand Down Expand Up @@ -167,7 +205,7 @@ async def request_code(self) -> None:
except BucketFullException as ex:
_LOGGER.info(ex.meta_info)
raise RoborockRateLimit("Reached maximum requests for login. Please try again later.") from ex
base_url = await self._get_base_url()
base_url = await self.base_url
header_clientid = self._get_header_client_id()
code_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid})

Expand Down Expand Up @@ -198,7 +236,7 @@ async def request_code_v4(self) -> None:
except BucketFullException as ex:
_LOGGER.info(ex.meta_info)
raise RoborockRateLimit("Reached maximum requests for login. Please try again later.") from ex
base_url = await self._get_base_url()
base_url = await self.base_url
header_clientid = self._get_header_client_id()
code_request = PreparedRequest(
base_url,
Expand Down Expand Up @@ -229,7 +267,7 @@ async def request_code_v4(self) -> None:

async def _sign_key_v3(self, s: str) -> str:
"""Sign a randomly generated string."""
base_url = await self._get_base_url()
base_url = await self.base_url
header_clientid = self._get_header_client_id()
code_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid})

Expand All @@ -249,14 +287,20 @@ async def _sign_key_v3(self, s: str) -> str:

return code_response["data"]["k"]

async def code_login_v4(self, code: int | str, country: str, country_code: int) -> UserData:
async def code_login_v4(
self, code: int | str, country: str | None = None, country_code: int | None = None
) -> UserData:
"""
Login via code authentication.
:param code: The code from the email.
:param country: The two-character representation of the country, i.e. "US"
:param country_code: the country phone number code i.e. 1 for US.
"""
base_url = await self._get_base_url()
base_url = await self.base_url
if country is None:
country = await self.country
if country_code is None:
country_code = await self.country_code
header_clientid = self._get_header_client_id()
x_mercy_ks = "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(16))
x_mercy_k = await self._sign_key_v3(x_mercy_ks)
Expand Down Expand Up @@ -304,7 +348,7 @@ async def pass_login(self, password: str) -> UserData:
except BucketFullException as ex:
_LOGGER.info(ex.meta_info)
raise RoborockRateLimit("Reached maximum requests for login. Please try again later.") from ex
base_url = await self._get_base_url()
base_url = await self.base_url
header_clientid = self._get_header_client_id()

login_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid})
Expand Down Expand Up @@ -343,7 +387,7 @@ async def pass_login_v3(self, password: str) -> UserData:
raise NotImplementedError("Pass_login_v3 has not yet been implemented")

async def code_login(self, code: int | str) -> UserData:
base_url = await self._get_base_url()
base_url = await self.base_url
header_clientid = self._get_header_client_id()

login_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid})
Expand Down Expand Up @@ -376,7 +420,7 @@ async def code_login(self, code: int | str) -> UserData:
return UserData.from_dict(user_data)

async def _get_home_id(self, user_data: UserData):
base_url = await self._get_base_url()
base_url = await self.base_url
header_clientid = self._get_header_client_id()
home_id_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid})
home_id_response = await home_id_request.request(
Expand Down Expand Up @@ -547,7 +591,7 @@ async def execute_scene(self, user_data: UserData, scene_id: int) -> None:

async def get_products(self, user_data: UserData) -> ProductResponse:
"""Gets all products and their schemas, good for determining status codes and model numbers."""
base_url = await self._get_base_url()
base_url = await self.base_url
header_clientid = self._get_header_client_id()
product_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid})
product_response = await product_request.request(
Expand All @@ -565,7 +609,7 @@ async def get_products(self, user_data: UserData) -> ProductResponse:
raise RoborockException("product result was an unexpected type")

async def download_code(self, user_data: UserData, product_id: int):
base_url = await self._get_base_url()
base_url = await self.base_url
header_clientid = self._get_header_client_id()
product_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid})
request = {"apilevel": 99999, "productids": [product_id], "type": 2}
Expand All @@ -578,7 +622,7 @@ async def download_code(self, user_data: UserData, product_id: int):
return response["data"][0]["url"]

async def download_category_code(self, user_data: UserData):
base_url = await self._get_base_url()
base_url = await self.base_url
header_clientid = self._get_header_client_id()
product_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid})
response = await product_request.request(
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def mock_rest() -> aioresponses:
with aioresponses() as mocked:
# Match the base URL and allow any query params
mocked.post(
re.compile(r"https://euiot\.roborock\.com/api/v1/getUrlByEmail.*"),
re.compile(r"https://.*iot\.roborock\.com/api/v1/getUrlByEmail.*"),
status=200,
payload={
"code": 200,
Expand Down
2 changes: 1 addition & 1 deletion tests/mock_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,7 @@
BASE_URL_REQUEST = {
"code": 200,
"msg": "success",
"data": {"url": "https://sample.com"},
"data": {"url": "https://sample.com", "countrycode": 1, "country": "US"},
}

GET_CODE_RESPONSE = {"code": 200, "msg": "success", "data": None}
Expand Down
49 changes: 29 additions & 20 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,24 +49,28 @@ async def test_get_base_url_no_url():
rc = RoborockApiClient("sample@gmail.com")
with patch("roborock.web_api.PreparedRequest.request") as mock_request:
mock_request.return_value = BASE_URL_REQUEST
await rc._get_base_url()
assert rc.base_url == "https://sample.com"
await rc._get_iot_login_info()
assert await rc.base_url == "https://sample.com"


async def test_request_code():
rc = RoborockApiClient("sample@gmail.com")
with patch("roborock.web_api.RoborockApiClient._get_base_url"), patch(
"roborock.web_api.RoborockApiClient._get_header_client_id"
), patch("roborock.web_api.PreparedRequest.request") as mock_request:
with (
patch("roborock.web_api.RoborockApiClient._get_iot_login_info"),
patch("roborock.web_api.RoborockApiClient._get_header_client_id"),
patch("roborock.web_api.PreparedRequest.request") as mock_request,
):
mock_request.return_value = GET_CODE_RESPONSE
await rc.request_code()


async def test_get_home_data():
rc = RoborockApiClient("sample@gmail.com")
with patch("roborock.web_api.RoborockApiClient._get_base_url"), patch(
"roborock.web_api.RoborockApiClient._get_header_client_id"
), patch("roborock.web_api.PreparedRequest.request") as mock_prepared_request:
with (
patch("roborock.web_api.RoborockApiClient._get_iot_login_info"),
patch("roborock.web_api.RoborockApiClient._get_header_client_id"),
patch("roborock.web_api.PreparedRequest.request") as mock_prepared_request,
):
mock_prepared_request.side_effect = [
{"code": 200, "msg": "success", "data": {"rrHomeId": 1}},
{"code": 200, "success": True, "result": HOME_DATA_RAW},
Expand Down Expand Up @@ -117,10 +121,11 @@ async def test_get_prop():
home_data = HomeData.from_dict(HOME_DATA_RAW)
device_info = DeviceData(device=home_data.devices[0], model=home_data.products[0].model)
rmc = RoborockMqttClientV1(UserData.from_dict(USER_DATA), device_info)
with patch("roborock.version_1_apis.roborock_mqtt_client_v1.RoborockMqttClientV1.get_status") as get_status, patch(
"roborock.version_1_apis.roborock_client_v1.RoborockClientV1.send_command"
), patch("roborock.version_1_apis.roborock_client_v1.AttributeCache.async_value"), patch(
"roborock.version_1_apis.roborock_mqtt_client_v1.RoborockMqttClientV1.get_dust_collection_mode"
with (
patch("roborock.version_1_apis.roborock_mqtt_client_v1.RoborockMqttClientV1.get_status") as get_status,
patch("roborock.version_1_apis.roborock_client_v1.RoborockClientV1.send_command"),
patch("roborock.version_1_apis.roborock_client_v1.AttributeCache.async_value"),
patch("roborock.version_1_apis.roborock_mqtt_client_v1.RoborockMqttClientV1.get_dust_collection_mode"),
):
status = S7MaxVStatus.from_dict(STATUS)
status.dock_type = RoborockDockTypeCode.auto_empty_dock_pure
Expand Down Expand Up @@ -194,8 +199,9 @@ async def test_disconnect_failure(connected_mqtt_client: RoborockMqttClientV1) -
assert connected_mqtt_client.is_connected()

# Make the MQTT client returns with an error when disconnecting
with patch("roborock.cloud_api.mqtt.Client.disconnect", return_value=mqtt.MQTT_ERR_PROTOCOL), pytest.raises(
RoborockException, match="Failed to disconnect"
with (
patch("roborock.cloud_api.mqtt.Client.disconnect", return_value=mqtt.MQTT_ERR_PROTOCOL),
pytest.raises(RoborockException, match="Failed to disconnect"),
):
await connected_mqtt_client.async_disconnect()

Expand Down Expand Up @@ -231,8 +237,9 @@ async def test_subscribe_failure(

response_queue.put(mqtt_packet.gen_connack(rc=0, flags=2))

with patch("roborock.cloud_api.mqtt.Client.subscribe", return_value=(mqtt.MQTT_ERR_NO_CONN, None)), pytest.raises(
RoborockException, match="Failed to subscribe"
with (
patch("roborock.cloud_api.mqtt.Client.subscribe", return_value=(mqtt.MQTT_ERR_NO_CONN, None)),
pytest.raises(RoborockException, match="Failed to subscribe"),
):
await mqtt_client.async_connect()

Expand Down Expand Up @@ -298,8 +305,9 @@ async def test_publish_failure(

msg = mqtt.MQTTMessageInfo(0)
msg.rc = mqtt.MQTT_ERR_PROTOCOL
with patch("roborock.cloud_api.mqtt.Client.publish", return_value=msg), pytest.raises(
RoborockException, match="Failed to publish"
with (
patch("roborock.cloud_api.mqtt.Client.publish", return_value=msg),
pytest.raises(RoborockException, match="Failed to publish"),
):
await connected_mqtt_client.get_room_mapping()

Expand All @@ -308,7 +316,8 @@ async def test_future_timeout(
connected_mqtt_client: RoborockMqttClientV1,
) -> None:
"""Test a timeout raised while waiting for an RPC response."""
with patch("roborock.roborock_future.async_timeout.timeout", side_effect=asyncio.TimeoutError), pytest.raises(
RoborockTimeout, match="Timeout after"
with (
patch("roborock.roborock_future.async_timeout.timeout", side_effect=asyncio.TimeoutError),
pytest.raises(RoborockTimeout, match="Timeout after"),
):
await connected_mqtt_client.get_room_mapping()
Loading