Skip to content

Commit

Permalink
Update app.py
Browse files Browse the repository at this point in the history
  • Loading branch information
antoinebou12 committed Mar 2, 2024
1 parent 6f38fb1 commit cc42d72
Showing 1 changed file with 74 additions and 63 deletions.
137 changes: 74 additions & 63 deletions api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ def __init__(self, email, password, public_key=CONF_PUBLIC_KEY, user_id=None, re
self.refresh = refresh
self.session = None
self.is_polling_active = False
self.session_key_expiry = datetime.datetime.now()
self.session_key_expiry = datetime.datetime.now(datetime.timezone.utc)
self.auth_in_progress = False
self.session_lock = asyncio.Lock()
self.active_requests = 0

def set_user_id(self, user_id):
"""
Expand Down Expand Up @@ -112,10 +115,9 @@ def prepare_data(self, data):
return data

async def open_session(self):
if self.session is None or self.session.closed:
self.session_key = None
self.session_key_expiry = datetime.datetime.now()
self.session = aiohttp.ClientSession()
async with self.session_lock:
if self.session is None or self.session.closed:
self.session = aiohttp.ClientSession()

async def _request(self, method: str, url: str, **kwargs) -> Union[Dict, List]:
"""
Expand All @@ -133,17 +135,16 @@ async def _request(self, method: str, url: str, **kwargs) -> Union[Dict, List]:
APIError: Custom exception for API-related errors.
"""
try:
async with self.session_lock:
self.active_requests += 1
# Initialize session if it does not exist
await self.open_session()

# Authenticate if session_key is missing, except for the auth URL itself
if self.session_key is None and method != "POST" and url != API_AUTH_URL:
if not self.auth_in_progress and (self.session_key is None and method != "POST" and url != API_AUTH_URL):
_LOGGER.warning(
"No session key found. Attempting to authenticate.")
await self.auth()

# check if the session key is valid
await self.ensure_valid_session()
await self.ensure_valid_session()

# Prepare the data for the API request
kwargs = self.prepare_data(kwargs)
Expand All @@ -161,79 +162,91 @@ async def _request(self, method: str, url: str, **kwargs) -> Union[Dict, List]:
return parsed_response

except (aiohttp.ClientResponseError, aiohttp.ClientConnectionError) as e:
_LOGGER.error(f"Client error occurred in _request: {e}")
_LOGGER.error(f"Aiohttp client error in _request: {e}")
raise APIError(f"API request failed {method} {url}") from e
except Exception as e:
_LOGGER.error(f"Unexpected error occurred in _request: {e}")
raise APIError(f"API request failed {method} {url}") from e
finally:
# Decrement active requests counter
async with self.session_lock:
self.active_requests -= 1

@staticmethod
def encrypt_password(public_key_str, password):
try:
# Ensure the public key is imported correctly
rsa_key = RSA.importKey(public_key_str)
# Create a cipher object using PKCS#1 v1.5
cipher = PKCS1_v1_5.new(rsa_key)
return b64encode(cipher.encrypt(bytes(password, "utf-8")))
return b64encode(cipher.encrypt(password.encode("utf-8")))
except Exception as e:
_LOGGER.error(f"Encryption error occurred in encrypt_password: {e}")
_LOGGER.error(f"Encryption error: {e}")
raise

async def auth(self):
if not self.email or not self.password:
await self.close()
raise AuthenticationError("Email and password must be provided")
if self.auth_in_progress:
return
self.auth_in_progress = True
try:
if not self.email or not self.password:
await self.close()
raise AuthenticationError("Email and password must be provided")

# Check if public_key is None
if self.public_key is None:
_LOGGER.error("Public key is None.")
await self.close()
raise AuthenticationError("Public key is None.")
# Check if public_key is None
if self.public_key is None:
_LOGGER.error("Public key is None.")
await self.close()
raise AuthenticationError("Public key is None.")

try:
encrypted_password = self.encrypt_password(self.public_key, self.password)
except Exception as e:
_LOGGER.error(f"An error occurred while encrypting the password: {e}")
await self.close()
return

data = {"secure_flag": "1", "email": self.email,
"password": encrypted_password}
parsed = await self._request("POST", API_AUTH_URL, json=data)
data = {"secure_flag": "1", "email": self.email,
"password": encrypted_password}

# Check if parsed object is None
if parsed is None:
_LOGGER.error("Parsed object is None.")
await self.close()
raise AuthenticationError("Received NoneType object.")
parsed = await self._request("POST", API_AUTH_URL, json=data)

# Check for 'terminal_user_session_key'
if "terminal_user_session_key" not in parsed:
_LOGGER.error(
"'terminal_user_session_key' not found in parsed object.")
await self.close()
raise AuthenticationError("'terminal_user_session_key' missing.")
# Check if parsed object is None
if parsed is None:
_LOGGER.error("Parsed object is None.")
await self.close()
raise AuthenticationError("Received NoneType object.")

# Check for 'terminal_user_session_key'
if "terminal_user_session_key" not in parsed:
_LOGGER.error(
"'terminal_user_session_key' not found in parsed object.")
await self.close()
raise AuthenticationError("'terminal_user_session_key' missing.")

# If everything is fine, set the session_key
self.session_key = parsed["terminal_user_session_key"]
self.session_key_expiry = datetime.datetime.now() + datetime.timedelta(minutes=10)
return parsed
# If everything is fine, set the session_key
self.session_key = parsed["terminal_user_session_key"]
self.session_key_expiry = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(minutes=10)
self.auth_in_progress = False
return parsed
except Exception as e:
_LOGGER.error(f"Authentication failed: {e}")
await self.close()
finally:
self.auth_in_progress = False

async def ensure_valid_session(self):
"""Ensure the session is valid and authenticated."""
if not self.is_session_valid():
if not self.is_session_valid() and not self.auth_in_progress:
_LOGGER.warning("Session key expired or missing. Re-authenticating.")
await self.auth()

def is_session_valid(self):
"""Check if the session key is valid."""
return True
return self.session_key and datetime.datetime.now(datetime.timezone.utc) < self.session_key_expiry

async def validate_credentials(self):
"""
Validate the current credentials by attempting to authenticate.
Returns True if authentication succeeds, False otherwise.
"""
try:
await self.auth()
await self.ensure_valid_session()
return True
except Exception as e:
_LOGGER.error(f"Validation failed: {e}")
Expand All @@ -244,6 +257,7 @@ async def get_scale_users(self):
Fetch the list of users associated with the scale.
"""
try:
await self.ensure_valid_session()
url = f"{API_SCALE_USERS_URL}?locale=en&terminal_user_session_key={self.session_key}"
parsed = await self._request("GET", url)

Expand Down Expand Up @@ -274,6 +288,7 @@ async def get_measurements(self) -> Optional[List[Dict]]:
Exception: For any other unexpected errors.
"""
try:
await self.ensure_valid_session()
ago_timestamp = int(time.mktime(datetime.date(1998, 1, 1).timetuple()))
url = f"{API_MEASUREMENTS_URL}?user_id={self.user_id}&last_at={ago_timestamp}&locale=en&app_id=Renpho&terminal_user_session_key={self.session_key}"
parsed = await self._request("GET", url)
Expand Down Expand Up @@ -337,6 +352,8 @@ async def get_specific_metric(
if user_id:
self.set_user_id(user_id)

await self.ensure_valid_session()

if metric_type == METRIC_TYPE_WEIGHT:
last_measurement = await self.get_weight()
if last_measurement and self.weight is not None:
Expand Down Expand Up @@ -377,16 +394,9 @@ async def get_info(self):
"""
Wrapper method to authenticate, fetch users, and get measurements.
"""
scale_users_task = self.get_scale_users()
measurements_task = self.get_measurements()

# Execute tasks concurrently
results = await asyncio.gather(scale_users_task, measurements_task, return_exceptions=True)

# Process and handle possible exceptions for each task
scale_users, measurements = results

return measurements
await self.ensure_valid_session()
await self.get_scale_users()
return await self.get_measurements()

async def start_polling(self, refresh=0):
"""
Expand Down Expand Up @@ -640,12 +650,13 @@ async def close(self):
"""
Shutdown the executor when you are done using the RenphoWeight instance.
"""
self.stop_polling() # Stop the polling
if self.session and not self.session.closed:
self.session_key = None
self.session_key_expiry = datetime.datetime.now()
await self.session.close() # Close the session

async with self.session_lock:
self.stop_polling() # Stop the polling
if self.session and not self.session.closed and self.active_requests == 0:
self.session_key = None
self.session_key_expiry = datetime.datetime.now(datetime.timezone.utc)
await self.session.close() # Close the session
self.session = None

class AuthenticationError(Exception):
pass
Expand Down

0 comments on commit cc42d72

Please sign in to comment.