diff --git a/backend/models/user.py b/backend/models/user.py index b34b777..f9cc157 100644 --- a/backend/models/user.py +++ b/backend/models/user.py @@ -20,50 +20,110 @@ def __init__(self, db): self.encryption_key = self._get_or_create_encryption_key() def _get_or_create_encryption_key(self) -> bytes: - """Get or create encryption key for TOTP secrets. Prioritizes environment variables.""" + """Get or create encryption key for TOTP secrets. Prioritizes environment variables, then Redis, then local file.""" import os import logging + from cryptography.fernet import Fernet + logger = logging.getLogger(__name__) + # 1. Try Environment Variables # Prioritize environment variable for Azure and production stability - env_key = os.environ.get('ENCRYPTION_KEY') + # Check multiple potential variable names + env_keys = ['TOTP_ENCRYPTION_KEY', 'TOTP_SECRET', 'ENCRYPTION_KEY'] + env_key = None + + for var_name in env_keys: + val = os.environ.get(var_name) + if val: + env_key = val + logger.info(f"Found encryption key in environment variable: {var_name}") + break + if env_key: try: # Ensure it's a valid Fernet key - from cryptography.fernet import Fernet Fernet(env_key.encode()) logger.info("Using encryption key from environment variable") return env_key.encode() except Exception as e: - logger.error(f"Invalid ENCRYPTION_KEY in environment: {str(e)}") + logger.error(f"Invalid encryption key in environment variable: {str(e)}") + # Fall through to other methods if env key is invalid + + # 2. Try Redis (for persistence between deploys without env vars) + redis_client = None + try: + from flask import current_app + if current_app: + redis_client = current_app.config.get('REDIS_CLIENT') + if redis_client: + try: + stored_key = redis_client.get('totp_encryption_key') + if stored_key: + if isinstance(stored_key, bytes): + stored_key = stored_key.decode('utf-8') + + Fernet(stored_key.encode()) + logger.info("Using encryption key from Redis") + return stored_key.encode() + except Exception as e: + logger.error(f"Invalid encryption key in Redis: {e}") + except Exception as e: + logger.warning(f"Failed to check Redis for encryption key: {e}") + # 3. Try Local File (Legacy/Fallback) # Use absolute path to ensure key is found regardless of working directory # Key file should be in the backend root directory (parent of models directory) current_dir = os.path.dirname(os.path.abspath(__file__)) backend_dir = os.path.dirname(current_dir) key_file = os.path.join(backend_dir, 'totp_encryption.key') + final_key = None + if os.path.exists(key_file): with open(key_file, 'rb') as f: key = f.read() # Validate key to ensure it's not corrupt try: - from cryptography.fernet import Fernet Fernet(key) - return key + final_key = key + logger.info(f"Using encryption key from local file: {key_file}") except Exception as e: logger.error(f"Invalid encryption key in {key_file}: {e}") - return key - else: - from cryptography.fernet import Fernet - key = Fernet.generate_key() + + # 4. Generate New Key if nothing found + if not final_key: + final_key = Fernet.generate_key() + logger.info("Generated NEW encryption key") + + # Save to file try: with open(key_file, 'wb') as f: - f.write(key) - logger.info(f"Generated new encryption key at {key_file}") + f.write(final_key) + logger.info(f"Saved new encryption key to {key_file}") except Exception as e: logger.error(f"Failed to write encryption key to {key_file}: {e}") - return key + + # 5. Persist to Redis if available (to prevent future loss) + if final_key and redis_client: + try: + # Store indefinitely (no TTL) or with very long TTL + redis_client.set('totp_encryption_key', final_key.decode('utf-8')) + logger.info("Persisted encryption key to Redis for future deployments") + except Exception as e: + logger.error(f"Failed to persist encryption key to Redis: {e}") + + # 6. Log the key for the user (as requested) + try: + key_str = final_key.decode('utf-8') + logger.warning("="*60) + logger.warning("TOTP ENCRYPTION KEY (SAVE THIS TO AZURE ENV VAR 'TOTP_ENCRYPTION_KEY'):") + logger.warning(f"{key_str}") + logger.warning("="*60) + except: + pass + + return final_key def _encrypt_totp_secret(self, secret: str) -> str: """Encrypt TOTP secret before storing."""