Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 73 additions & 13 deletions backend/models/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,50 +20,110 @@ def __init__(self, db):
self.encryption_key = self._get_or_create_encryption_key()

def _get_or_create_encryption_key(self) -> bytes:
"""Get or create encryption key for TOTP secrets. Prioritizes environment variables."""
"""Get or create encryption key for TOTP secrets. Prioritizes environment variables, then Redis, then local file."""
import os
import logging
from cryptography.fernet import Fernet

logger = logging.getLogger(__name__)

# 1. Try Environment Variables
# Prioritize environment variable for Azure and production stability
env_key = os.environ.get('ENCRYPTION_KEY')
# Check multiple potential variable names
env_keys = ['TOTP_ENCRYPTION_KEY', 'TOTP_SECRET', 'ENCRYPTION_KEY']
env_key = None

for var_name in env_keys:
val = os.environ.get(var_name)
if val:
env_key = val
logger.info(f"Found encryption key in environment variable: {var_name}")
break

if env_key:
try:
# Ensure it's a valid Fernet key
from cryptography.fernet import Fernet
Fernet(env_key.encode())
logger.info("Using encryption key from environment variable")
return env_key.encode()
except Exception as e:
logger.error(f"Invalid ENCRYPTION_KEY in environment: {str(e)}")
logger.error(f"Invalid encryption key in environment variable: {str(e)}")
# Fall through to other methods if env key is invalid

# 2. Try Redis (for persistence between deploys without env vars)
redis_client = None
try:
from flask import current_app
if current_app:
redis_client = current_app.config.get('REDIS_CLIENT')
if redis_client:
try:
stored_key = redis_client.get('totp_encryption_key')
if stored_key:
if isinstance(stored_key, bytes):
stored_key = stored_key.decode('utf-8')

Fernet(stored_key.encode())
logger.info("Using encryption key from Redis")
return stored_key.encode()
except Exception as e:
logger.error(f"Invalid encryption key in Redis: {e}")
except Exception as e:
logger.warning(f"Failed to check Redis for encryption key: {e}")

# 3. Try Local File (Legacy/Fallback)
# Use absolute path to ensure key is found regardless of working directory
# Key file should be in the backend root directory (parent of models directory)
current_dir = os.path.dirname(os.path.abspath(__file__))
backend_dir = os.path.dirname(current_dir)
key_file = os.path.join(backend_dir, 'totp_encryption.key')

final_key = None

if os.path.exists(key_file):
with open(key_file, 'rb') as f:
key = f.read()
# Validate key to ensure it's not corrupt
try:
from cryptography.fernet import Fernet
Fernet(key)
return key
final_key = key
logger.info(f"Using encryption key from local file: {key_file}")
except Exception as e:
logger.error(f"Invalid encryption key in {key_file}: {e}")
return key
else:
from cryptography.fernet import Fernet
key = Fernet.generate_key()

# 4. Generate New Key if nothing found
if not final_key:
final_key = Fernet.generate_key()
logger.info("Generated NEW encryption key")

# Save to file
try:
with open(key_file, 'wb') as f:
f.write(key)
logger.info(f"Generated new encryption key at {key_file}")
f.write(final_key)
logger.info(f"Saved new encryption key to {key_file}")
except Exception as e:
logger.error(f"Failed to write encryption key to {key_file}: {e}")
return key

# 5. Persist to Redis if available (to prevent future loss)
if final_key and redis_client:
try:
# Store indefinitely (no TTL) or with very long TTL
redis_client.set('totp_encryption_key', final_key.decode('utf-8'))
logger.info("Persisted encryption key to Redis for future deployments")
except Exception as e:
logger.error(f"Failed to persist encryption key to Redis: {e}")

# 6. Log the key for the user (as requested)
try:
key_str = final_key.decode('utf-8')
logger.warning("="*60)
logger.warning("TOTP ENCRYPTION KEY (SAVE THIS TO AZURE ENV VAR 'TOTP_ENCRYPTION_KEY'):")
logger.warning(f"{key_str}")
logger.warning("="*60)
except:
pass

return final_key

def _encrypt_totp_secret(self, secret: str) -> str:
"""Encrypt TOTP secret before storing."""
Expand Down