diff --git a/roborock/exceptions.py b/roborock/exceptions.py index b3e4bd41..0861fb85 100644 --- a/roborock/exceptions.py +++ b/roborock/exceptions.py @@ -77,3 +77,7 @@ class RoborockTooManyRequest(RoborockException): class RoborockRateLimit(RoborockException): """Class for our rate limits exceptions.""" + + +class RoborockNoResponseFromBaseURL(RoborockException): + """We could not find an url that had a record of the given account.""" diff --git a/roborock/web_api.py b/roborock/web_api.py index 0bd231b9..f5dbd52e 100644 --- a/roborock/web_api.py +++ b/roborock/web_api.py @@ -8,6 +8,7 @@ import secrets import string import time +from dataclasses import dataclass import aiohttp from aiohttp import ContentTypeError, FormData @@ -22,14 +23,28 @@ RoborockInvalidEmail, RoborockInvalidUserAgreement, RoborockMissingParameters, + RoborockNoResponseFromBaseURL, RoborockNoUserAgreement, RoborockRateLimit, RoborockTooFrequentCodeRequests, - RoborockTooManyRequest, - RoborockUrlException, ) _LOGGER = logging.getLogger(__name__) +BASE_URLS = [ + "https://usiot.roborock.com", + "https://euiot.roborock.com", + "https://cniot.roborock.com", + "https://ruiot.roborock.com", +] + + +@dataclass +class IotLoginInfo: + """Information about the login to the iot server.""" + + base_url: str + country_code: str + country: str class RoborockApiClient: @@ -49,41 +64,64 @@ class RoborockApiClient: _login_limiter = Limiter(_LOGIN_RATES) _home_data_limiter = Limiter(_HOME_DATA_RATES) - def __init__(self, username: str, base_url=None, session: aiohttp.ClientSession | None = None) -> None: + def __init__( + self, username: str, base_url: str | None = None, session: aiohttp.ClientSession | None = None + ) -> None: """Sample API Client.""" self._username = username - self._default_url = "https://euiot.roborock.com" - self.base_url = base_url + self._base_url = base_url self._device_identifier = secrets.token_urlsafe(16) self.session = session - - async def _get_base_url(self) -> str: - if not self.base_url: - url_request = PreparedRequest(self._default_url, self.session) - response = await url_request.request( - "post", - "/api/v1/getUrlByEmail", - params={"email": self._username, "needtwostepauth": "false"}, - ) - if response is None: - raise RoborockUrlException("get url by email returned None") - response_code = response.get("code") - if response_code != 200: - _LOGGER.info("Get base url failed for %s with the following context: %s", self._username, response) - if response_code == 2003: - raise RoborockInvalidEmail("Your email was incorrectly formatted.") - elif response_code == 1001: - raise RoborockMissingParameters( - "You are missing parameters for this request, are you sure you entered your username?" + self._iot_login_info: IotLoginInfo | None = None + + async def _get_iot_login_info(self) -> IotLoginInfo: + if self._iot_login_info is None: + valid_urls = BASE_URLS if self._base_url is None else [self._base_url] + for iot_url in valid_urls: + url_request = PreparedRequest(iot_url, self.session) + response = await url_request.request( + "post", + "/api/v1/getUrlByEmail", + params={"email": self._username, "needtwostepauth": "false"}, + ) + if response is None: + continue + response_code = response.get("code") + if response_code != 200: + if response_code == 2003: + raise RoborockInvalidEmail("Your email was incorrectly formatted.") + elif response_code == 1001: + raise RoborockMissingParameters( + "You are missing parameters for this request, are you sure you entered your username?" + ) + else: + raise RoborockException(f"{response.get('msg')} - response code: {response_code}") + if response["data"]["countrycode"] is not None: + self._iot_login_info = IotLoginInfo( + base_url=response["data"]["url"], + country=response["data"]["country"], + country_code=response["data"]["countrycode"], ) - elif response_code == 9002: - raise RoborockTooManyRequest("Please temporarily disable making requests and try again later.") - raise RoborockUrlException(f"error code: {response_code} msg: {response.get('error')}") - response_data = response.get("data") - if response_data is None: - raise RoborockUrlException("response does not have 'data'") - self.base_url = response_data.get("url") - return self.base_url + return self._iot_login_info + raise RoborockNoResponseFromBaseURL( + "No account was found for any base url we tried. Either your email is incorrect or we do not have a" + " record of the roborock server your device is on." + ) + return self._iot_login_info + + @property + async def base_url(self): + if self._base_url is not None: + return self._base_url + return (await self._get_iot_login_info()).base_url + + @property + async def country(self): + return (await self._get_iot_login_info()).country + + @property + async def country_code(self): + return (await self._get_iot_login_info()).country_code def _get_header_client_id(self): md5 = hashlib.md5() @@ -167,7 +205,7 @@ async def request_code(self) -> None: except BucketFullException as ex: _LOGGER.info(ex.meta_info) raise RoborockRateLimit("Reached maximum requests for login. Please try again later.") from ex - base_url = await self._get_base_url() + base_url = await self.base_url header_clientid = self._get_header_client_id() code_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid}) @@ -198,7 +236,7 @@ async def request_code_v4(self) -> None: except BucketFullException as ex: _LOGGER.info(ex.meta_info) raise RoborockRateLimit("Reached maximum requests for login. Please try again later.") from ex - base_url = await self._get_base_url() + base_url = await self.base_url header_clientid = self._get_header_client_id() code_request = PreparedRequest( base_url, @@ -229,7 +267,7 @@ async def request_code_v4(self) -> None: async def _sign_key_v3(self, s: str) -> str: """Sign a randomly generated string.""" - base_url = await self._get_base_url() + base_url = await self.base_url header_clientid = self._get_header_client_id() code_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid}) @@ -249,14 +287,20 @@ async def _sign_key_v3(self, s: str) -> str: return code_response["data"]["k"] - async def code_login_v4(self, code: int | str, country: str, country_code: int) -> UserData: + async def code_login_v4( + self, code: int | str, country: str | None = None, country_code: int | None = None + ) -> UserData: """ Login via code authentication. :param code: The code from the email. :param country: The two-character representation of the country, i.e. "US" :param country_code: the country phone number code i.e. 1 for US. """ - base_url = await self._get_base_url() + base_url = await self.base_url + if country is None: + country = await self.country + if country_code is None: + country_code = await self.country_code header_clientid = self._get_header_client_id() x_mercy_ks = "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(16)) x_mercy_k = await self._sign_key_v3(x_mercy_ks) @@ -304,7 +348,7 @@ async def pass_login(self, password: str) -> UserData: except BucketFullException as ex: _LOGGER.info(ex.meta_info) raise RoborockRateLimit("Reached maximum requests for login. Please try again later.") from ex - base_url = await self._get_base_url() + base_url = await self.base_url header_clientid = self._get_header_client_id() login_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid}) @@ -343,7 +387,7 @@ async def pass_login_v3(self, password: str) -> UserData: raise NotImplementedError("Pass_login_v3 has not yet been implemented") async def code_login(self, code: int | str) -> UserData: - base_url = await self._get_base_url() + base_url = await self.base_url header_clientid = self._get_header_client_id() login_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid}) @@ -376,7 +420,7 @@ async def code_login(self, code: int | str) -> UserData: return UserData.from_dict(user_data) async def _get_home_id(self, user_data: UserData): - base_url = await self._get_base_url() + base_url = await self.base_url header_clientid = self._get_header_client_id() home_id_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid}) home_id_response = await home_id_request.request( @@ -547,7 +591,7 @@ async def execute_scene(self, user_data: UserData, scene_id: int) -> None: async def get_products(self, user_data: UserData) -> ProductResponse: """Gets all products and their schemas, good for determining status codes and model numbers.""" - base_url = await self._get_base_url() + base_url = await self.base_url header_clientid = self._get_header_client_id() product_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid}) product_response = await product_request.request( @@ -565,7 +609,7 @@ async def get_products(self, user_data: UserData) -> ProductResponse: raise RoborockException("product result was an unexpected type") async def download_code(self, user_data: UserData, product_id: int): - base_url = await self._get_base_url() + base_url = await self.base_url header_clientid = self._get_header_client_id() product_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid}) request = {"apilevel": 99999, "productids": [product_id], "type": 2} @@ -578,7 +622,7 @@ async def download_code(self, user_data: UserData, product_id: int): return response["data"][0]["url"] async def download_category_code(self, user_data: UserData): - base_url = await self._get_base_url() + base_url = await self.base_url header_clientid = self._get_header_client_id() product_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid}) response = await product_request.request( diff --git a/tests/conftest.py b/tests/conftest.py index a38d429f..6ca1e775 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -173,7 +173,7 @@ def mock_rest() -> aioresponses: with aioresponses() as mocked: # Match the base URL and allow any query params mocked.post( - re.compile(r"https://euiot\.roborock\.com/api/v1/getUrlByEmail.*"), + re.compile(r"https://.*iot\.roborock\.com/api/v1/getUrlByEmail.*"), status=200, payload={ "code": 200, diff --git a/tests/mock_data.py b/tests/mock_data.py index 98cd816e..e779d780 100644 --- a/tests/mock_data.py +++ b/tests/mock_data.py @@ -766,7 +766,7 @@ BASE_URL_REQUEST = { "code": 200, "msg": "success", - "data": {"url": "https://sample.com"}, + "data": {"url": "https://sample.com", "countrycode": 1, "country": "US"}, } GET_CODE_RESPONSE = {"code": 200, "msg": "success", "data": None} diff --git a/tests/test_api.py b/tests/test_api.py index a4771a02..3d8ea47a 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -49,24 +49,28 @@ async def test_get_base_url_no_url(): rc = RoborockApiClient("sample@gmail.com") with patch("roborock.web_api.PreparedRequest.request") as mock_request: mock_request.return_value = BASE_URL_REQUEST - await rc._get_base_url() - assert rc.base_url == "https://sample.com" + await rc._get_iot_login_info() + assert await rc.base_url == "https://sample.com" async def test_request_code(): rc = RoborockApiClient("sample@gmail.com") - with patch("roborock.web_api.RoborockApiClient._get_base_url"), patch( - "roborock.web_api.RoborockApiClient._get_header_client_id" - ), patch("roborock.web_api.PreparedRequest.request") as mock_request: + with ( + patch("roborock.web_api.RoborockApiClient._get_iot_login_info"), + patch("roborock.web_api.RoborockApiClient._get_header_client_id"), + patch("roborock.web_api.PreparedRequest.request") as mock_request, + ): mock_request.return_value = GET_CODE_RESPONSE await rc.request_code() async def test_get_home_data(): rc = RoborockApiClient("sample@gmail.com") - with patch("roborock.web_api.RoborockApiClient._get_base_url"), patch( - "roborock.web_api.RoborockApiClient._get_header_client_id" - ), patch("roborock.web_api.PreparedRequest.request") as mock_prepared_request: + with ( + patch("roborock.web_api.RoborockApiClient._get_iot_login_info"), + patch("roborock.web_api.RoborockApiClient._get_header_client_id"), + patch("roborock.web_api.PreparedRequest.request") as mock_prepared_request, + ): mock_prepared_request.side_effect = [ {"code": 200, "msg": "success", "data": {"rrHomeId": 1}}, {"code": 200, "success": True, "result": HOME_DATA_RAW}, @@ -117,10 +121,11 @@ async def test_get_prop(): home_data = HomeData.from_dict(HOME_DATA_RAW) device_info = DeviceData(device=home_data.devices[0], model=home_data.products[0].model) rmc = RoborockMqttClientV1(UserData.from_dict(USER_DATA), device_info) - with patch("roborock.version_1_apis.roborock_mqtt_client_v1.RoborockMqttClientV1.get_status") as get_status, patch( - "roborock.version_1_apis.roborock_client_v1.RoborockClientV1.send_command" - ), patch("roborock.version_1_apis.roborock_client_v1.AttributeCache.async_value"), patch( - "roborock.version_1_apis.roborock_mqtt_client_v1.RoborockMqttClientV1.get_dust_collection_mode" + with ( + patch("roborock.version_1_apis.roborock_mqtt_client_v1.RoborockMqttClientV1.get_status") as get_status, + patch("roborock.version_1_apis.roborock_client_v1.RoborockClientV1.send_command"), + patch("roborock.version_1_apis.roborock_client_v1.AttributeCache.async_value"), + patch("roborock.version_1_apis.roborock_mqtt_client_v1.RoborockMqttClientV1.get_dust_collection_mode"), ): status = S7MaxVStatus.from_dict(STATUS) status.dock_type = RoborockDockTypeCode.auto_empty_dock_pure @@ -194,8 +199,9 @@ async def test_disconnect_failure(connected_mqtt_client: RoborockMqttClientV1) - assert connected_mqtt_client.is_connected() # Make the MQTT client returns with an error when disconnecting - with patch("roborock.cloud_api.mqtt.Client.disconnect", return_value=mqtt.MQTT_ERR_PROTOCOL), pytest.raises( - RoborockException, match="Failed to disconnect" + with ( + patch("roborock.cloud_api.mqtt.Client.disconnect", return_value=mqtt.MQTT_ERR_PROTOCOL), + pytest.raises(RoborockException, match="Failed to disconnect"), ): await connected_mqtt_client.async_disconnect() @@ -231,8 +237,9 @@ async def test_subscribe_failure( response_queue.put(mqtt_packet.gen_connack(rc=0, flags=2)) - with patch("roborock.cloud_api.mqtt.Client.subscribe", return_value=(mqtt.MQTT_ERR_NO_CONN, None)), pytest.raises( - RoborockException, match="Failed to subscribe" + with ( + patch("roborock.cloud_api.mqtt.Client.subscribe", return_value=(mqtt.MQTT_ERR_NO_CONN, None)), + pytest.raises(RoborockException, match="Failed to subscribe"), ): await mqtt_client.async_connect() @@ -298,8 +305,9 @@ async def test_publish_failure( msg = mqtt.MQTTMessageInfo(0) msg.rc = mqtt.MQTT_ERR_PROTOCOL - with patch("roborock.cloud_api.mqtt.Client.publish", return_value=msg), pytest.raises( - RoborockException, match="Failed to publish" + with ( + patch("roborock.cloud_api.mqtt.Client.publish", return_value=msg), + pytest.raises(RoborockException, match="Failed to publish"), ): await connected_mqtt_client.get_room_mapping() @@ -308,7 +316,8 @@ async def test_future_timeout( connected_mqtt_client: RoborockMqttClientV1, ) -> None: """Test a timeout raised while waiting for an RPC response.""" - with patch("roborock.roborock_future.async_timeout.timeout", side_effect=asyncio.TimeoutError), pytest.raises( - RoborockTimeout, match="Timeout after" + with ( + patch("roborock.roborock_future.async_timeout.timeout", side_effect=asyncio.TimeoutError), + pytest.raises(RoborockTimeout, match="Timeout after"), ): await connected_mqtt_client.get_room_mapping() diff --git a/tests/test_web_api.py b/tests/test_web_api.py index 1a11f200..d71a585e 100644 --- a/tests/test_web_api.py +++ b/tests/test_web_api.py @@ -1,7 +1,10 @@ +import re + import aiohttp +from aioresponses.compat import normalize_url from roborock import HomeData, HomeDataScene, UserData -from roborock.web_api import RoborockApiClient +from roborock.web_api import IotLoginInfo, RoborockApiClient from tests.mock_data import HOME_DATA_RAW, USER_DATA @@ -71,3 +74,99 @@ async def test_code_login_v4_flow(mock_rest) -> None: await api.request_code_v4() ud = await api.code_login_v4(4123, "US", 1) assert ud == UserData.from_dict(USER_DATA) + + +async def test_url_cycling(mock_rest) -> None: + """Test that we cycle through the URLs correctly.""" + # Clear mock rest so that we can override the patches. + mock_rest.clear() + # 1. Mock US URL to return valid status but None for countrycode + + mock_rest.post( + re.compile("https://usiot.roborock.com/api/v1/getUrlByEmail.*"), + status=200, + payload={ + "code": 200, + "data": {"url": "https://usiot.roborock.com", "country": None, "countrycode": None}, + "msg": "Success", + }, + ) + + # 2. Mock EU URL to return valid status but None for countrycode + mock_rest.post( + re.compile("https://euiot.roborock.com/api/v1/getUrlByEmail.*"), + status=200, + payload={ + "code": 200, + "data": {"url": "https://euiot.roborock.com", "country": None, "countrycode": None}, + "msg": "Success", + }, + ) + + # 3. Mock CN URL to return the correct, valid data + mock_rest.post( + re.compile("https://cniot.roborock.com/api/v1/getUrlByEmail.*"), + status=200, + payload={ + "code": 200, + "data": {"url": "https://cniot.roborock.com", "country": "CN", "countrycode": "86"}, + "msg": "Success", + }, + ) + + # The RU URL should not be called, but we can mock it just in case + # to catch unexpected behavior. + mock_rest.post(re.compile("https://ruiot.roborock.com/api/v1/getUrlByEmail.*"), status=500) + + client = RoborockApiClient("test@example.com") + result = await client._get_iot_login_info() + + assert result is not None + assert isinstance(result, IotLoginInfo) + assert result.base_url == "https://cniot.roborock.com" + assert result.country == "CN" + assert result.country_code == "86" + + assert client._iot_login_info == result + # Check that all three urls were called. We have to do this kind of weirdly as aioresponses seems to have a bug. + assert ( + len( + mock_rest.requests[ + ( + "post", + normalize_url( + "https://usiot.roborock.com/api/v1/getUrlByEmail?email=test%2540example.com&needtwostepauth=false" + ), + ) + ] + ) + == 1 + ) + assert ( + len( + mock_rest.requests[ + ( + "post", + normalize_url( + "https://euiot.roborock.com/api/v1/getUrlByEmail?email=test%2540example.com&needtwostepauth=false" + ), + ) + ] + ) + == 1 + ) + assert ( + len( + mock_rest.requests[ + ( + "post", + normalize_url( + "https://cniot.roborock.com/api/v1/getUrlByEmail?email=test%2540example.com&needtwostepauth=false" + ), + ) + ] + ) + == 1 + ) + # Make sure we just have the three we tested for above. + assert len(mock_rest.requests) == 3