diff --git a/keepercommander/__init__.py b/keepercommander/__init__.py index 50f1aa370..9aae1a614 100644 --- a/keepercommander/__init__.py +++ b/keepercommander/__init__.py @@ -10,4 +10,4 @@ # Contact: commander@keepersecurity.com # -__version__ = '17.2.10' +__version__ = '17.2.11' diff --git a/keepercommander/cli.py b/keepercommander/cli.py index 64593fd26..104d3df90 100644 --- a/keepercommander/cli.py +++ b/keepercommander/cli.py @@ -129,7 +129,10 @@ def clean_description(desc): ] domain_subcommands = [ ('domain list (dl)', 'List all reserved domains for the enterprise'), - ('domain reserve (dr)', 'Reserve, delete, or generate token for a domain'), + ('domain reserve (dr)', 'Reserve, delete or generate token for a domain'), + ('domain alias list (dal)', 'List domain aliases for the enterprise'), + ('domain alias create (dac)', 'Create a domain alias for the enterprise'), + ('domain alias delete (dad)', 'Delete a domain alias for the enterprise'), ] for category in get_category_order(): diff --git a/keepercommander/commands/credential_provision.py b/keepercommander/commands/credential_provision.py index 9b5b5a9a5..a68a47e75 100644 --- a/keepercommander/commands/credential_provision.py +++ b/keepercommander/commands/credential_provision.py @@ -70,6 +70,19 @@ from ..commands import email_commands from ..email_service import EmailSender, build_onboarding_email from keepercommander.commands.pam.user_facade import PamUserRecordFacade +from keepercommander.commands.register import ShareRecordCommand +from keepercommander.commands.pam.pam_dto import ( + GatewayAction, + GatewayActionRmCreateUser, GatewayActionRmCreateUserInputs, + GatewayActionRmAddUserToGroup, GatewayActionRmAddUserToGroupInputs, + GatewayActionRmDeleteUser, GatewayActionRmDeleteUserInputs, +) +from keepercommander.commands.pam.router_helper import ( + router_send_action_to_gateway, + router_get_connected_gateways, + get_response_payload, +) +from keepercommander.proto import pam_pb2 from keepercommander.commands.pam.config_facades import PamConfigurationRecordFacade # ============================================================================= @@ -97,6 +110,12 @@ help='Base64-encoded YAML configuration content (for API/Service Mode usage)' ) +credential_provision_parser.add_argument( + '-c', '--pam-config', + dest='pam_config', + help='PAM Configuration record UID (determines which Gateway to use)' +) + credential_provision_parser.add_argument( '--dry-run', dest='dry_run', @@ -126,6 +145,46 @@ def register_command_info(aliases, command_info): aliases['cp'] = 'credential-provision' command_info['credential-provision'] = 'Automate PAM User credential provisioning' +# ============================================================================= +# Username Template Engine (KC-1035) +# ============================================================================= + +def resolve_username_template(template, user_data): + """Resolve a username template using user data fields. + + Supported variables: + {first_name} - Full first name (e.g., "Felipe") + {last_name} - Full last name (e.g., "Dias") + {first_initial} - First character of first name (e.g., "f") + {last_initial} - First character of last name (e.g., "d") + {email_prefix} - Part before @ in personal_email (e.g., "fdias") + + Args: + template: String with {variable} placeholders (e.g., "{first_initial}{last_name}.adm") + user_data: Dict with keys: first_name, last_name, personal_email + + Returns: + Resolved string, lowercased (e.g., "fdias.adm") + """ + first_name = user_data.get('first_name', '') + last_name = user_data.get('last_name', '') + email = user_data.get('personal_email', '') + + replacements = { + 'first_name': first_name, + 'last_name': last_name, + 'first_initial': first_name[0] if first_name else '', + 'last_initial': last_name[0] if last_name else '', + 'email_prefix': email.split('@')[0] if '@' in email else email, + } + + result = template + for key, value in replacements.items(): + result = result.replace('{' + key + '}', value) + + return result.lower() + + # ============================================================================= # Main Command Class # ============================================================================= @@ -137,6 +196,11 @@ def __init__(self): self.pam_user_uid = None self.dag_link_created = False self.folder_created = None + # KC-1035: AD user creation tracking for rollback + self.ad_user_created = False + self.ad_username = None + self.ad_config_uid = None + self.ad_gateway_uid = None class CredentialProvisionCommand(Command): """ @@ -179,6 +243,7 @@ def execute(self, params: KeeperParams, **kwargs): config_path = kwargs.get('config') config_base64 = kwargs.get('config_base64') + pam_config_arg = kwargs.get('pam_config') dry_run = kwargs.get('dry_run', False) output_format = kwargs.get('output', 'text') @@ -198,6 +263,13 @@ def execute(self, params: KeeperParams, **kwargs): 'credential-provision', 'Either --config or --config-base64 is required' ) + # Set pam_config_uid from CLI arg -c, or fall back to YAML value + if 'account' not in config: + config['account'] = {} + if pam_config_arg: + config['account']['pam_config_uid'] = pam_config_arg + # If neither -c nor YAML pam_config_uid provided, validation will catch it + validation_errors = self._validate_config(params, config) if validation_errors: @@ -242,6 +314,29 @@ def execute(self, params: KeeperParams, **kwargs): logging.error('Make sure you have access to this PAM Configuration') raise CommandError('credential-provision', error_msg) + # Resolve username template if provided (KC-1035) + # Must happen before system-specific validation and dry-run + if config['account'].get('username_template') and not config['account'].get('username'): + resolved = resolve_username_template( + config['account']['username_template'], + config['user'] + ) + if not resolved or not resolved.strip() or not resolved.strip('.'): + raise CommandError( + 'credential-provision', + f'Username template resolved to invalid value: "{resolved}"\n' + f'Template: {config["account"]["username_template"]}\n' + f'Check that user.first_name and user.last_name are not empty.' + ) + config['account']['username'] = resolved + if output_format == 'text': + logging.info(f'Resolved username: {resolved}') + + # Also resolve distinguished_name if it contains template variables + dn = config['account'].get('distinguished_name', '') + if '{username}' in dn: + config['account']['distinguished_name'] = dn.replace('{username}', resolved) + # Validate system-specific fields based on PAM type system_errors = self._validate_system_specific_fields(config, pam_config_record, params) @@ -270,8 +365,11 @@ def execute(self, params: KeeperParams, **kwargs): # Execute provisioning state = ProvisioningState() + has_delivery = 'delivery' in config + has_email = 'email' in config try: + # Check for duplicates if self._check_duplicate(config, params): error_msg = f'Duplicate PAM User already exists for username: {config["account"]["username"]}' @@ -282,18 +380,39 @@ def execute(self, params: KeeperParams, **kwargs): logging.error(error_msg) raise CommandError('credential-provision', error_msg) - # Generate password and create PAM User - password = self._generate_password(config['pam']['rotation']['password_complexity']) + # Generate password + password = self._generate_password(config['rotation']['password_complexity']) + + # Create AD user via Gateway if AD-specific fields are present (KC-1035) + ad_groups = config['account'].get('ad_groups', []) + has_ad_config = config['account'].get('distinguished_name') or ad_groups + if has_ad_config: + self._create_ad_user_via_gateway(config, password, params, state) + if output_format == 'text': + logging.info(f'✅ AD user created: {config["account"]["username"]}') + + # Add to AD groups + if ad_groups: + self._add_ad_user_to_groups_via_gateway( + config, params, state.ad_gateway_uid + ) + if output_format == 'text': + logging.info(f'✅ Added to AD groups: {", ".join(ad_groups)}') + + # Create PAM User record pam_user_uid = self._create_pam_user(config, password, params) state.pam_user_uid = pam_user_uid + if output_format == 'text': + logging.info(f'✅ PAM User record created: {pam_user_uid}') + # Link to PAM Configuration and configure rotation self._create_dag_link(pam_user_uid, config['account']['pam_config_uid'], params) state.dag_link_created = True self._configure_rotation(pam_user_uid, config, params) if output_format == 'text': - logging.info('✅ PAM User created and linked') + logging.info('✅ Rotation configured') # Perform immediate rotation if configured rotation_success = self._rotate_immediately(pam_user_uid, config, params) @@ -301,28 +420,49 @@ def execute(self, params: KeeperParams, **kwargs): if output_format == 'text': logging.info('✅ Password rotation submitted') - # Generate share URL for PAM User (shares source of truth, not a copy) - share_url = self._generate_share_url(pam_user_uid, config, params) + # Delivery: direct share or one-time share URL + email + share_url = None + share_success = False + email_success = False + + # Direct share (if delivery section present) + if has_delivery: + share_success = self._share_directly(pam_user_uid, config, params) + if output_format == 'text': + share_to = config['delivery']['share_to'] + if share_success: + logging.info(f'✅ Record shared to {share_to}') + else: + logging.warning(f'⚠️ Direct share to {share_to} failed — share manually') + + # Email delivery (if email section present) + if has_email: + share_url = self._generate_share_url(pam_user_uid, config, params) + if output_format == 'text': + logging.info('✅ Share URL generated for PAM User') + email_success = self._send_email(config, share_url, params) + if output_format == 'text': + logging.info('✅ Email with one-time share sent') + + if not has_delivery and not has_email: + if output_format == 'text': + logging.info('✅ Record created (no delivery configured)') - if output_format == 'text': - logging.info('✅ Share URL generated for PAM User') - - # Send welcome email - email_success = self._send_email(config, share_url, params) - - if output_format == 'text': - logging.info('✅ Email with one-time share sent') - else: + if output_format == 'json': result = { 'success': True, 'pam_user_uid': pam_user_uid, - 'share_url': share_url, 'username': config['account']['username'], 'employee_name': f"{config['user']['first_name']} {config['user']['last_name']}", 'rotation_status': 'synced' if rotation_success else 'scheduled', - 'email_status': 'sent' if email_success else 'failed', 'message': 'Credential provisioning complete' } + if has_delivery: + result['share_status'] = 'shared' if share_success else 'failed' + result['shared_to'] = config['delivery']['share_to'] + if has_email: + result['share_url'] = share_url + result['email_status'] = 'sent' if email_success else 'failed' print(json.dumps(result, indent=2)) except CommandError as e: @@ -478,7 +618,7 @@ def _validate_config(self, params: KeeperParams, config: Dict[str, Any]) -> List errors = [] # Validate required top-level sections - required_sections = ['user', 'account', 'pam', 'email'] + required_sections = ['user', 'account', 'rotation'] for section in required_sections: if section not in config: errors.append(f'Missing required section: {section}') @@ -490,8 +630,15 @@ def _validate_config(self, params: KeeperParams, config: Dict[str, Any]) -> List # Validate each section errors.extend(self._validate_user_section(config.get('user', {}))) errors.extend(self._validate_account_section(config.get('account', {}))) - errors.extend(self._validate_pam_section(config.get('pam', {}))) - errors.extend(self._validate_email_section(params, config.get('email', {}))) + errors.extend(self._validate_rotation_section(config.get('rotation', {}))) + + # Validate delivery section if present (vault sharing) + if 'delivery' in config: + errors.extend(self._validate_delivery_section(config['delivery'])) + + # Validate email section if present (email delivery) + if 'email' in config: + errors.extend(self._validate_email_section(params, config.get('email', {}))) # Validate optional vault section if 'vault' in config: @@ -501,6 +648,26 @@ def _validate_config(self, params: KeeperParams, config: Dict[str, Any]) -> List if 'managed_company' in config: errors.extend(self._validate_mc_context(params, config['managed_company'])) + # Validate: transfer_ownership/remove_from_service_vault are incompatible with rotation + has_rotation = bool(config.get('rotation', {}).get('schedule')) + transfer = config.get('delivery', {}).get('transfer_ownership', False) + remove = config.get('delivery', {}).get('remove_from_service_vault', False) + if has_rotation and (transfer or remove): + if transfer: + errors.append( + 'delivery.transfer_ownership is incompatible with rotation.\n' + ' Transferring ownership moves the record out of the Gateway\'s control,\n' + ' which prevents password rotation. Remove transfer_ownership or remove\n' + ' the rotation schedule.' + ) + if remove: + errors.append( + 'delivery.remove_from_service_vault is incompatible with rotation.\n' + ' Removing the record from the service vault removes Gateway access,\n' + ' which prevents password rotation. Remove remove_from_service_vault\n' + ' or remove the rotation schedule.' + ) + return errors def _validate_user_section(self, user: Dict[str, Any]) -> List[str]: @@ -534,9 +701,9 @@ def _validate_account_section(self, account: Dict[str, Any]) -> List[str]: """Validate account section (target system credentials).""" errors = [] - # Required fields - if not account.get('username'): - errors.append('account.username is required') + # Either username or username_template is required + if not account.get('username') and not account.get('username_template'): + errors.append('account.username or account.username_template is required') if not account.get('pam_config_uid'): errors.append('account.pam_config_uid is required') @@ -552,35 +719,28 @@ def _validate_account_section(self, account: Dict[str, Any]) -> List[str]: return errors - def _validate_pam_section(self, pam: Dict[str, Any]) -> List[str]: - """Validate PAM section (rotation configuration).""" + def _validate_rotation_section(self, rotation: Dict[str, Any]) -> List[str]: + """Validate rotation section (schedule and password complexity).""" errors = [] - # Validate rotation subsection - rotation = pam.get('rotation', {}) - if not rotation: - errors.append('pam.rotation section is required') - return errors - - # Required rotation fields if not rotation.get('schedule'): - errors.append('pam.rotation.schedule is required (CRON format)') + errors.append('rotation.schedule is required (CRON format)') else: schedule = rotation['schedule'] if validate_cron_expression and not validate_cron_expression(schedule, for_rotation=True)[0]: errors.append( - f'pam.rotation.schedule has invalid CRON format: {schedule}\n' + f'rotation.schedule has invalid CRON format: {schedule}\n' f' Expected 6 fields: seconds minute hour day month day-of-week\n' f' Example: "0 0 3 * * ?" (Daily at 3:00:00 AM)' ) if not rotation.get('password_complexity'): - errors.append('pam.rotation.password_complexity is required') + errors.append('rotation.password_complexity is required') else: complexity = rotation['password_complexity'] if not self._is_valid_complexity(complexity): errors.append( - f'pam.rotation.password_complexity has invalid format: {complexity}\n' + f'rotation.password_complexity has invalid format: {complexity}\n' f' Expected: "length,upper,lower,digit,special"\n' f' Example: "32,5,5,5,5"' ) @@ -616,6 +776,25 @@ def _validate_email_section(self, params: KeeperParams, email: Dict[str, Any]) - return errors + def _validate_delivery_section(self, delivery: Dict[str, Any]) -> List[str]: + """Validate delivery section (KC-1035: direct vault sharing).""" + errors = [] + + share_to = delivery.get('share_to') + if not share_to: + errors.append('delivery.share_to is required') + elif not utils.is_email(share_to): + errors.append(f'delivery.share_to must be a valid email: {share_to}') + + # Validate optional permissions + permissions = delivery.get('permissions', {}) + if permissions: + for key in ('can_edit', 'can_share'): + if key in permissions and not isinstance(permissions[key], bool): + errors.append(f'delivery.permissions.{key} must be a boolean') + + return errors + def _validate_vault_section(self, params: KeeperParams, vault_config: Dict[str, Any]) -> List[str]: """Validate vault section (folder paths).""" errors = [] @@ -704,7 +883,7 @@ def _dry_run_report(self, params: KeeperParams, config: Dict[str, Any], output_f """ user = config.get('user', {}) account = config.get('account', {}) - pam = config.get('pam', {}) + rotation = config.get('rotation', {}) email_config = config.get('email', {}) vault_config = config.get('vault', {}) @@ -726,7 +905,7 @@ def _dry_run_report(self, params: KeeperParams, config: Dict[str, Any], output_f 'Generate secure password (complexity requirements applied)', f'Create PAM User: {redacted_username}', f'Link PAM User to PAM Config: {account.get("pam_config_uid")}', - f'Configure rotation: {pam.get("rotation", {}).get("schedule")}', + f'Configure rotation: {rotation.get("schedule")}', 'Submit immediate rotation', f'Generate share URL for PAM User (expiry: {email_config.get("share_url_expiry", "7d")})', f'Send email to: {redacted_email}' @@ -756,7 +935,7 @@ def _dry_run_report(self, params: KeeperParams, config: Dict[str, Any], output_f print(f' Folder: {vault_config.get("folder", default_folder)}') print(f' 4. Link to PAM Config: {account.get("pam_config_uid")[:20]}...') print(f' 5. Configure rotation') - print(f' Schedule: {pam.get("rotation", {}).get("schedule")}') + print(f' Schedule: {rotation.get("schedule")}') print(f' 6. Submit immediate rotation') print(f' 7. Generate one-time share URL for PAM User') print(f' Expiry: {email_config.get("share_url_expiry", "7d")}') @@ -1025,7 +1204,7 @@ def _create_pam_user( facade.managed = True # Set title (custom or auto-generated) - pam_title = config.get('pam', {}).get('pam_user_title') + pam_title = config.get('rotation', {}).get('pam_user_title') if pam_title: pam_user.title = pam_title else: @@ -1061,6 +1240,24 @@ def _create_pam_user( 'Distinguished Name' )) + # KC-1035: Owner metadata for deprovision support + delivery = config.get('delivery', {}) + if delivery.get('method') == 'direct_share' and delivery.get('share_to'): + custom_fields.append(vault.TypedField.new_field( + 'text', + delivery['share_to'], + 'Owner Email' + )) + + # KC-1035: AD groups metadata + ad_groups = config['account'].get('ad_groups', []) + if ad_groups: + custom_fields.append(vault.TypedField.new_field( + 'text', + ', '.join(ad_groups), + 'AD Groups' + )) + if custom_fields: pam_user.custom = custom_fields @@ -1301,7 +1498,7 @@ def _configure_rotation( CommandError: If rotation configuration fails """ - rotation_config = config['pam']['rotation'] + rotation_config = config['rotation'] pam_config_uid = config['account']['pam_config_uid'] # Check if rotation commands are available (Python 3.8+) @@ -1387,6 +1584,356 @@ def _rotate_immediately( logging.warning(f' Password will sync on next scheduled rotation') return False # Graceful degradation + # ========================================================================= + # AD User Creation via Gateway (KC-1035) + # ========================================================================= + + def _get_gateway_uid_for_config(self, pam_config_uid: str, params: KeeperParams) -> Optional[str]: + """Find the connected Gateway UID for a PAM Configuration. + + Looks up the gateway associated with the specific PAM Config, then + verifies it is online. Falls back to the controllerUid field on the + PAM Configuration record if the API call is unavailable. + """ + from ..commands.pam.config_helper import configuration_controller_get + + gateway_uid = None + + # First: try reading controllerUid from the PAM Configuration record + try: + pam_config_record = vault.KeeperRecord.load(params, pam_config_uid) + if pam_config_record: + field = pam_config_record.get_typed_field('pamResources') + value = field.get_default_value(dict) + if value: + gateway_uid = value.get('controllerUid', '') or '' + except Exception as e: + logging.debug(f'Failed to read gateway UID from PAM Config record: {e}') + + # Fallback: ask the server for the controller associated with this config + if not gateway_uid: + try: + config_uid_bytes = utils.base64_url_decode(pam_config_uid) + controller = configuration_controller_get(params, config_uid_bytes) + if controller and controller.controllerUid: + gateway_uid = utils.base64_url_encode(controller.controllerUid) + except Exception as e: + logging.debug(f'Failed to get controller from API: {e}') + + if not gateway_uid: + return None + + # Verify the gateway is actually online + try: + online = router_get_connected_gateways(params) + if online: + connected_uids = [utils.base64_url_encode(c.controllerUid) for c in online.controllers] + if gateway_uid not in connected_uids: + logging.warning(f'Gateway {gateway_uid} is associated with PAM Config but not online') + return None + except Exception as e: + logging.debug(f'Failed to verify gateway online status: {e}') + + return gateway_uid + + def _create_ad_user_via_gateway( + self, + config: Dict[str, Any], + password: str, + params: KeeperParams, + state: 'ProvisioningState' + ) -> bool: + """ + Create AD user via Gateway's rm-create-user action (KC-1035). + + Sends the create-user action to the PAM Gateway which calls + ActiveDirectory.create_user() via LDAP. + + Args: + config: YAML config with account section + password: Generated password for the new AD user + params: KeeperParams session + state: ProvisioningState for rollback tracking + + Returns: + True if AD user created successfully + + Raises: + CommandError: If AD user creation fails (critical failure) + """ + username = config['account']['username'] + pam_config_uid = config['account']['pam_config_uid'] + + gateway_uid = self._get_gateway_uid_for_config(pam_config_uid, params) + if not gateway_uid: + raise CommandError('credential-provision', 'No connected Gateway found for PAM Configuration') + + # Store for rollback + state.ad_config_uid = pam_config_uid + state.ad_gateway_uid = gateway_uid + state.ad_username = username + + # Send the full DN as the user field if available — the Gateway's + # ActiveDirectory.create_user() detects DN format and uses it to place + # the user in the correct OU (see active_directory.py lines 2118-2132) + dn = config['account'].get('distinguished_name', '') + user_value = dn if dn else username + + # Encrypt user and password with PAM Config record key — the Gateway + # decrypts them via rm.decrypt_content() which does base64_decode → AES-GCM decrypt + record_key = params.record_cache[pam_config_uid]['record_key_unencrypted'] + encrypted_user = base64.b64encode(crypto.encrypt_aes_v2(user_value.encode(), record_key)).decode() + encrypted_password = base64.b64encode(crypto.encrypt_aes_v2(password.encode(), record_key)).decode() + + # Build and send rm-create-user action + resource_uid = config['account'].get('directory_uid') + action_inputs = GatewayActionRmCreateUserInputs( + configuration_uid=pam_config_uid, + user=encrypted_user, + password=encrypted_password, + resource_uid=resource_uid, + ) + conversation_id = GatewayAction.generate_conversation_id() + + try: + router_response = router_send_action_to_gateway( + params=params, + gateway_action=GatewayActionRmCreateUser( + inputs=action_inputs, + conversation_id=conversation_id, + gateway_destination=gateway_uid + ), + message_type=pam_pb2.CMT_GENERAL, + is_streaming=False, + destination_gateway_uid_str=gateway_uid + ) + except Exception as e: + raise CommandError('credential-provision', f'Gateway communication failed: {e}') + + if router_response is None: + raise CommandError('credential-provision', 'No response from Gateway for AD user creation') + + try: + payload = get_response_payload(router_response) + except Exception as e: + raise CommandError('credential-provision', f'Failed to parse Gateway response: {e}') + + # Gateway response structure: payload = {data: {success, error, ...}, is_ok, ...} + data = payload.get('data', {}) if payload else {} + if not data or not data.get('success', False): + error = data.get('error', 'Unknown error') if data else 'Empty response' + raise CommandError('credential-provision', f'AD user creation failed: {error}') + + state.ad_user_created = True + return True + + def _add_ad_user_to_groups_via_gateway( + self, + config: Dict[str, Any], + params: KeeperParams, + gateway_uid: str + ) -> None: + """ + Add AD user to groups via Gateway's rm-add-user-to-group action (KC-1035). + + Args: + config: YAML config with account.ad_groups list + params: KeeperParams session + gateway_uid: Connected Gateway UID + """ + username = config['account']['username'] + pam_config_uid = config['account']['pam_config_uid'] + resource_uid = config['account'].get('directory_uid') + groups = config['account'].get('ad_groups', []) + + # Encrypt username with PAM Config record key + record_key = params.record_cache[pam_config_uid]['record_key_unencrypted'] + encrypted_user = base64.b64encode(crypto.encrypt_aes_v2(username.encode(), record_key)).decode() + + for group in groups: + action_inputs = GatewayActionRmAddUserToGroupInputs( + configuration_uid=pam_config_uid, + user=encrypted_user, + group_id=group, + resource_uid=resource_uid, + ) + conversation_id = GatewayAction.generate_conversation_id() + + try: + router_response = router_send_action_to_gateway( + params=params, + gateway_action=GatewayActionRmAddUserToGroup( + inputs=action_inputs, + conversation_id=conversation_id, + gateway_destination=gateway_uid + ), + message_type=pam_pb2.CMT_GENERAL, + is_streaming=False, + destination_gateway_uid_str=gateway_uid + ) + + payload = get_response_payload(router_response) if router_response else None + data = payload.get('data', {}) if payload else {} + if data and not data.get('success', False): + error = data.get('error', 'Unknown error') + logging.warning(f' Failed to add {username} to group {group}: {error}') + else: + logging.info(f' Added {username} to AD group: {group}') + + except Exception as e: + logging.warning(f' Failed to add {username} to group {group}: {e}') + + def _delete_ad_user_via_gateway(self, state: 'ProvisioningState', params: KeeperParams) -> None: + """Delete AD user via Gateway for rollback (KC-1035).""" + if not state.ad_user_created or not state.ad_username: + return + + try: + action_inputs = GatewayActionRmDeleteUserInputs( + configuration_uid=state.ad_config_uid, + user=state.ad_username, + ) + conversation_id = GatewayAction.generate_conversation_id() + + router_send_action_to_gateway( + params=params, + gateway_action=GatewayActionRmDeleteUser( + inputs=action_inputs, + conversation_id=conversation_id, + gateway_destination=state.ad_gateway_uid + ), + message_type=pam_pb2.CMT_GENERAL, + is_streaming=False, + destination_gateway_uid_str=state.ad_gateway_uid + ) + logging.info(f'Rolled back AD user: {state.ad_username}') + except Exception as e: + logging.error(f'Failed to rollback AD user {state.ad_username}: {e}') + logging.error('Manual cleanup may be required in Active Directory') + + def _share_directly( + self, + pam_user_uid: str, + config: Dict[str, Any], + params: KeeperParams + ) -> bool: + """ + Share PAM User record directly to a user's vault (KC-1035). + + Uses ShareRecordCommand to grant vault-to-vault access. The recipient + sees the record in their vault with the current password, including + updates after rotation. + + Args: + pam_user_uid: UID of PAM User record to share + config: YAML configuration with delivery section + params: KeeperParams session + + Returns: + True if shared successfully, False on failure (non-critical) + """ + delivery = config.get('delivery', {}) + share_to = delivery.get('share_to') + permissions = delivery.get('permissions', {}) + can_edit = permissions.get('can_edit', False) + can_share = permissions.get('can_share', False) + + transfer_ownership = delivery.get('transfer_ownership', False) + + try: + # Sync vault to ensure the new record is available + api.sync_down(params) + + # Step 1: Share the record to the target user's vault + share_kwargs = { + 'record': pam_user_uid, + 'email': [share_to], + 'action': 'grant', + 'can_edit': can_edit, + 'can_share': can_share, + } + rq = ShareRecordCommand.prep_request(params, share_kwargs) + if rq is None: + logging.warning(f'Share invitation sent to {share_to} — share will complete when accepted') + return True + ShareRecordCommand.send_requests(params, [rq]) + + # Step 2: Transfer ownership if configured + if transfer_ownership: + api.sync_down(params) + transfer_kwargs = { + 'record': pam_user_uid, + 'email': [share_to], + 'action': 'owner', + } + rq = ShareRecordCommand.prep_request(params, transfer_kwargs) + if rq: + ShareRecordCommand.send_requests(params, [rq]) + logging.info(f'Ownership transferred to {share_to}') + + # Step 3: Remove record from service vault if configured + # Uses pre_delete + delete (two-step) to unlink from folder without deleting for new owner + remove_after = delivery.get('remove_from_service_vault', False) + if remove_after: + try: + api.sync_down(params) + # Find which folders contain this record and unlink from all of them + unlink_objects = [] + for folder_uid, record_uids in params.subfolder_record_cache.items(): + if pam_user_uid in record_uids: + folder = params.folder_cache.get(folder_uid) if folder_uid else params.root_folder + del_obj = { + 'delete_resolution': 'unlink', + 'object_uid': pam_user_uid, + 'object_type': 'record', + } + if hasattr(folder, 'type'): + if folder.type in {'user_folder', 'root'}: + del_obj['from_type'] = 'user_folder' + if folder_uid: + del_obj['from_uid'] = folder_uid + else: + del_obj['from_type'] = 'shared_folder_folder' + del_obj['from_uid'] = folder_uid + else: + del_obj['from_type'] = 'user_folder' + if folder_uid: + del_obj['from_uid'] = folder_uid + unlink_objects.append(del_obj) + + if unlink_objects: + # Step 1: pre_delete — get the deletion token + rq = { + 'command': 'pre_delete', + 'objects': unlink_objects + } + rs = api.communicate(params, rq) + if rs.get('result') == 'success': + pdr = rs.get('pre_delete_response', {}) + if 'pre_delete_token' in pdr: + # Step 2: delete — execute with the token + rq2 = { + 'command': 'delete', + 'pre_delete_token': pdr['pre_delete_token'] + } + api.communicate(params, rq2) + logging.info(f'Record removed from service vault') + else: + logging.warning(f'No pre_delete_token in response') + else: + logging.warning(f'pre_delete failed: {rs}') + else: + logging.info(f'Record not found in service vault folders') + except Exception as rev_e: + logging.warning(f'Failed to remove record from service vault: {rev_e}') + + return True + + except Exception as e: + logging.warning(f'Direct share to {share_to} failed: {e}') + logging.warning('The PAM User record was created successfully. Share manually if needed.') + return False + def _generate_share_url( self, pam_user_uid: str, @@ -1603,6 +2150,9 @@ def _rollback(self, state: ProvisioningState, params: KeeperParams) -> None: logging.warning('Rolling back provisioning changes') + # Rollback in LIFO order (reverse of creation order: AD first, then PAM User) + + # 1. Delete PAM User record first (created after AD user) if state.pam_user_uid: try: api.delete_record(params, state.pam_user_uid) @@ -1610,6 +2160,14 @@ def _rollback(self, state: ProvisioningState, params: KeeperParams) -> None: rollback_errors.append(f'PAM User: {e}') logging.error(f'Rollback failed for PAM User: {e}') + # 2. Delete AD user last (created before PAM User) + if state.ad_user_created: + try: + self._delete_ad_user_via_gateway(state, params) + except Exception as e: + rollback_errors.append(f'AD User ({state.ad_username}): {e}') + logging.error(f'Rollback failed for AD User: {e}') + if rollback_errors: logging.error('Rollback completed with errors - manual cleanup may be required') for error in rollback_errors: @@ -1781,10 +2339,13 @@ def _validate_system_specific_fields( # Unknown PAM Type # =================================================================== else: - logging.warning('') - logging.warning(f'⚠️ Unknown or unsupported PAM system type: "{pam_type}"') - logging.warning(' Using generic validation only') - logging.warning(' Supported types: Active Directory, Azure AD, AWS IAM') - logging.warning('') + # Skip warning when resource_uid is provided — the resource record + # determines AD behavior, not the config type + if not config['account'].get('directory_uid'): + logging.warning('') + logging.warning(f'⚠️ Unknown or unsupported PAM system type: "{pam_type}"') + logging.warning(' Using generic validation only') + logging.warning(' Supported types: Active Directory, Azure AD, AWS IAM') + logging.warning('') return errors diff --git a/keepercommander/commands/domain_management/__init__.py b/keepercommander/commands/domain_management/__init__.py new file mode 100644 index 000000000..dcafbaeea --- /dev/null +++ b/keepercommander/commands/domain_management/__init__.py @@ -0,0 +1,106 @@ +# _ __ +# | |/ /___ ___ _ __ ___ _ _ ® +# | ' `` to the appropriate command.""" + + def __init__(self): + super().__init__() + self.list_cmd = GetDomainAliasCommand() + self.create_cmd = CreateDomainAliasCommand() + self.delete_cmd = DeleteDomainAliasCommand() + + def get_parser(self): + return domain_alias_parser + + def execute(self, params, **kwargs): + alias_subcommand = kwargs.get('alias_subcommand') + + if not alias_subcommand: + self.get_parser().print_help() + return + + if alias_subcommand == 'list': + return self.list_cmd.execute(params, **kwargs) + elif alias_subcommand == 'create': + return self.create_cmd.execute(params, **kwargs) + elif alias_subcommand == 'delete': + return self.delete_cmd.execute(params, **kwargs) + else: + output_format = kwargs.get('format', 'text') + DomainManagementHelper.handle_invalid_subcommand( + f'alias {alias_subcommand}', output_format, + ) + return None + + +class GetDomainAliasCommand(EnterpriseCommand): + """List all domain aliases for the enterprise.""" + + def get_parser(self): + return domain_alias_list_parser + + def execute(self, params, **kwargs): + output_format = kwargs.get('format') or 'text' + try: + rs = api.communicate_rest( + params, None, API_ENDPOINTS['get_domain_alias'], + rs_type=enterprise_pb2.DomainAliasResponse, + ) + + if not rs.domainAlias: + logging.info('No domain aliases found for this enterprise.') + return + + return DomainManagementHelper.render_alias_response(rs, output_format, kwargs) + + except KeeperApiError as e: + DomainManagementHelper.handle_alias_api_error(e, output_format, 'retrieving') + + +class CreateDomainAliasCommand(EnterpriseCommand): + """Create one or more domain aliases for a domain owned by the enterprise.""" + + def get_parser(self): + return domain_alias_create_parser + + def execute(self, params, **kwargs): + domain = kwargs.get('domain', '') + aliases = kwargs.get('alias', []) + output_format = kwargs.get('format', 'text') + + if not domain: + DomainManagementHelper.output_error('Domain name is required.', output_format) + return + if not aliases: + DomainManagementHelper.output_error('At least one alias is required.', output_format) + return + + is_valid, domain, error_msg = DomainManagementHelper.validate_domain(domain) + if not is_valid: + DomainManagementHelper.output_error(error_msg, output_format) + return + + normalized_aliases = DomainManagementHelper.validate_aliases(domain, aliases, output_format) + if normalized_aliases is None: + return + + try: + rq = DomainManagementHelper.build_alias_request(domain, normalized_aliases) + rs = api.communicate_rest( + params, rq, API_ENDPOINTS['create_domain_alias'], + rs_type=enterprise_pb2.DomainAliasResponse, + ) + return DomainManagementHelper.render_alias_response( + rs, output_format, kwargs, + status_messages=DomainManagementHelper.CREATE_ALIAS_STATUS_MESSAGES, + action='create', + ) + + except KeeperApiError as e: + DomainManagementHelper.handle_alias_api_error(e, output_format, 'creating') + + +class DeleteDomainAliasCommand(EnterpriseCommand): + """Delete one or more domain aliases for a domain owned by the enterprise.""" + + def get_parser(self): + return domain_alias_delete_parser + + def execute(self, params, **kwargs): + domain = kwargs.get('domain', '') + aliases = kwargs.get('alias', []) + output_format = kwargs.get('format', 'text') + force = kwargs.get('force', False) + + if not domain: + DomainManagementHelper.output_error('Domain name is required.', output_format) + return + if not aliases: + DomainManagementHelper.output_error('At least one alias is required.', output_format) + return + + is_valid, domain, error_msg = DomainManagementHelper.validate_domain(domain) + if not is_valid: + DomainManagementHelper.output_error(error_msg, output_format) + return + + normalized_aliases = DomainManagementHelper.validate_aliases(domain, aliases, output_format) + if normalized_aliases is None: + return + + existing_aliases = self._get_existing_aliases(params) + not_found = [a for a in normalized_aliases if (domain, a) not in existing_aliases] + if not_found: + for alias in not_found: + DomainManagementHelper.output_error( + f"Domain alias '{alias}' for domain '{domain}' does not exist.", output_format, + ) + return + + if not force: + alias_list_str = ', '.join(normalized_aliases) + try: + confirm = input( + f'Are you sure you want to delete alias(es) [{alias_list_str}] for domain "{domain}"? (y/N): ' + ) + except (KeyboardInterrupt, EOFError): + logging.info('Delete cancelled.') + return + if confirm.strip().lower() not in ('y', 'yes'): + logging.info('Delete cancelled.') + return + + try: + rq = DomainManagementHelper.build_alias_request(domain, normalized_aliases) + rs = api.communicate_rest( + params, rq, API_ENDPOINTS['delete_domain_alias'], + rs_type=enterprise_pb2.DomainAliasResponse, + ) + return DomainManagementHelper.render_alias_response( + rs, output_format, kwargs, + status_messages=DomainManagementHelper.DELETE_ALIAS_STATUS_MESSAGES, + action='delete', + ) + + except KeeperApiError as e: + DomainManagementHelper.handle_alias_api_error(e, output_format, 'deleting') + + @staticmethod + def _get_existing_aliases(params): + """Fetch current domain aliases and return as a set of (domain, alias) tuples.""" + try: + rs = api.communicate_rest( + params, None, API_ENDPOINTS['get_domain_alias'], + rs_type=enterprise_pb2.DomainAliasResponse, + ) + return {(da.domain, da.alias) for da in rs.domainAlias} if rs.domainAlias else set() + except KeeperApiError: + return set() diff --git a/keepercommander/commands/domain_management/constants.py b/keepercommander/commands/domain_management/constants.py new file mode 100644 index 000000000..87a39bb7e --- /dev/null +++ b/keepercommander/commands/domain_management/constants.py @@ -0,0 +1,59 @@ +# _ __ +# | |/ /___ ___ _ __ ___ _ _ ® +# | ' DomainManagementHelper.MAX_DOMAIN_LENGTH: + return ( + False, domain, + f'Invalid domain name: must be between 1 and {DomainManagementHelper.MAX_DOMAIN_LENGTH} characters', + ) + + if not re.match(DomainManagementHelper.DOMAIN_PATTERN, domain): + return False, domain, 'Invalid domain format: domain must contain only letters, numbers, hyphens, and dots' + + if '.' not in domain: + return False, domain, 'Invalid domain: must contain at least one dot (e.g., example.com)' + + labels = domain.split('.') + for label in labels: + if len(label) > DomainManagementHelper.MAX_LABEL_LENGTH: + return ( + False, domain, + f'Invalid domain: label "{label}" exceeds {DomainManagementHelper.MAX_LABEL_LENGTH} characters', + ) + + if len(labels[-1]) < DomainManagementHelper.MIN_TLD_LENGTH: + return ( + False, domain, + f'Invalid domain: TLD must be at least {DomainManagementHelper.MIN_TLD_LENGTH} characters', + ) + + return True, domain, None + + @staticmethod + def get_error_message(error_code, domain, action): + """Get user-friendly error message using class constant dictionary.""" + if error_code == 'invalid_token' and action == 'add': + return ( + f'Failed to verify domain "{domain}". Please ensure you have added the TXT ' + f'record with the correct token to your DNS settings and try again.' + ) + + if error_code in ('exists', 'domain_exists'): + if action == 'token': + return f'Domain "{domain}" already exists in the enterprise. Use action "delete" to remove it first.' + elif action == 'add': + return f'Domain "{domain}" already exists in the enterprise. It may have already been added successfully.' + + if error_code in ('not_exists', 'domain_not_found', 'doesnt_exist') and action == 'delete': + return f'Domain "{domain}" does not exist. Use action "token" to start the domain reservation process.' + + message_template = DomainManagementHelper.ERROR_MESSAGES.get(error_code) + if message_template: + if '{domain}' in message_template: + return message_template.format(domain=domain) + return message_template + + return f'Unable to {action} domain "{domain}". Please try again or contact support if the issue persists.' + + + @staticmethod + def validate_aliases(domain, aliases, output_format): + """Validate a list of alias names and check none equals the domain. + + Returns a list of normalized alias strings on success, or None if any + validation fails (error is already reported via output_error). + """ + normalized = [] + for alias_name in aliases: + valid, normalized_alias, err = DomainManagementHelper.validate_domain(alias_name) + if not valid: + DomainManagementHelper.output_error(f'{err}', output_format) + return None + if normalized_alias == domain: + DomainManagementHelper.output_error( + f'Alias cannot be the same as the domain ("{domain}").', output_format + ) + return None + normalized.append(normalized_alias) + return normalized + + @staticmethod + def build_alias_request(domain, normalized_aliases): + """Build a DomainAliasRequest protobuf from a validated domain and alias list.""" + rq = enterprise_pb2.DomainAliasRequest() + for alias in normalized_aliases: + da = enterprise_pb2.DomainAlias() + da.domain = domain + da.alias = alias + rq.domainAlias.append(da) + return rq + + @staticmethod + def render_alias_response(rs, output_format, kwargs, status_messages=None, action=None): + """Render a DomainAliasResponse as text or JSON. + + When status_messages is None the response is treated as a listing + (domain + alias only). When provided, a simple message is shown per + alias for text format, or a status object for JSON format. + """ + if status_messages is not None: + if output_format == 'json': + results = [ + { + 'domain': da.domain, + 'alias': da.alias, + 'status': da.status, + 'status_message': status_messages.get(da.status, f'Unknown status: {da.status}'), + } + for da in rs.domainAlias + ] + print(json.dumps(results, indent=2)) + else: + for da in rs.domainAlias: + status_msg = status_messages.get(da.status, f'Unknown ({da.status})') + if da.status == 0: + if action == 'create': + logging.info(f"Created domain alias '{da.alias}' for domain '{da.domain}'") + elif action == 'delete': + logging.info(f"Deleted domain alias '{da.alias}' for domain '{da.domain}'") + else: + logging.info(f"Domain alias '{da.alias}' for domain '{da.domain}': {status_msg}") + else: + if action == 'delete': + logging.error(f"Failed to delete domain alias '{da.alias}' for domain '{da.domain}': {status_msg}") + elif action == 'create': + logging.error(f"Failed to create domain alias '{da.alias}' for domain '{da.domain}': {status_msg}") + else: + logging.error(f"Domain alias '{da.alias}' for domain '{da.domain}': {status_msg}") + else: + headers = ['Domain', 'Alias'] + table = [[da.domain, da.alias] for da in rs.domainAlias] + return dump_report_data(table, headers, fmt=output_format, filename=kwargs.get('output')) + + @staticmethod + def handle_alias_api_error(e, output_format, operation): + """Shared KeeperApiError handler for alias commands.""" + error_code = DomainManagementHelper.get_error_code(e) + + if DomainManagementHelper.is_feature_unavailable(error_code): + result = DomainManagementHelper.handle_unavailable_feature(output_format) + if result: + print(result) + return + + if error_code == 'access_denied': + DomainManagementHelper.output_error( + DomainManagementHelper.ALIAS_ACCESS_DENIED_MSG, output_format + ) + return + + logging.error(f'Error {operation} domain aliases: {e}') + raise diff --git a/keepercommander/commands/domain_management/parsers.py b/keepercommander/commands/domain_management/parsers.py new file mode 100644 index 000000000..57b5d73b9 --- /dev/null +++ b/keepercommander/commands/domain_management/parsers.py @@ -0,0 +1,139 @@ +# _ __ +# | |/ /___ ___ _ __ ___ _ _ ® +# | ' DomainManagementHelper.MAX_DOMAIN_LENGTH: - return False, domain, f'Invalid domain name: must be between 1 and {DomainManagementHelper.MAX_DOMAIN_LENGTH} characters' - - import re - if not re.match(DomainManagementHelper.DOMAIN_PATTERN, domain): - return False, domain, 'Invalid domain format: domain must contain only letters, numbers, hyphens, and dots' - - if '.' not in domain: - return False, domain, 'Invalid domain: must contain at least one dot (e.g., example.com)' - - # Check each label - labels = domain.split('.') - for label in labels: - if len(label) > DomainManagementHelper.MAX_LABEL_LENGTH: - return False, domain, f'Invalid domain: label "{label}" exceeds {DomainManagementHelper.MAX_LABEL_LENGTH} characters' - - # Check TLD length - if len(labels[-1]) < DomainManagementHelper.MIN_TLD_LENGTH: - return False, domain, f'Invalid domain: TLD must be at least {DomainManagementHelper.MIN_TLD_LENGTH} characters' - - return True, domain, None - - @staticmethod - def get_error_message(error_code, domain, action): - """Get user-friendly error message using class constant dictionary.""" - - if error_code == 'invalid_token' and action == 'add': - return f'Failed to verify domain "{domain}". Please ensure you have added the TXT record with the correct token to your DNS settings and try again.' - - if error_code in ('exists', 'domain_exists'): - if action == 'token': - return f'Domain "{domain}" already exists in the enterprise. Use action "delete" to remove it first.' - elif action == 'add': - return f'Domain "{domain}" already exists in the enterprise. It may have already been added successfully.' - - if error_code in ('not_exists', 'domain_not_found', 'doesnt_exist') and action == 'delete': - return f'Domain "{domain}" does not exist. Use action "token" to start the domain reservation process.' - - message_template = DomainManagementHelper.ERROR_MESSAGES.get(error_code) - - if message_template: - if '{domain}' in message_template: - return message_template.format(domain=domain) - return message_template - - return f'Unable to {action} domain "{domain}". Please try again or contact support if the issue persists.' - - @staticmethod - def handle_invalid_subcommand(subcommand, output_format='text'): - """ - Handle invalid subcommand error with user-friendly message. - - Args: - subcommand: The invalid subcommand provided by the user - output_format: Output format (text or json) - """ - error_message = ( - f"Invalid subcommand: '{subcommand}'. " - f"Use 'domain --help' for more information." - ) - - if output_format == 'json': - error_output = { - 'error': error_message, - } - print(json.dumps(error_output, indent=2)) - else: - logging.error(error_message) - - -class DomainCommand(EnterpriseCommand): - def __init__(self): - super().__init__() - self.list_cmd = ListDomainsCommand() - self.reserve_cmd = ReserveDomainCommand() - - def get_parser(self): - return domain_parser - - def execute_args(self, params, args, **kwargs): - import shlex - from .base import ParseError, expand_cmd_args, normalize_output_param - - try: - d = {} - d.update(kwargs) - self.extra_parameters = '' - parser = self._get_parser_safe() - envvars = params.environment_variables - args = '' if args is None else args - - if parser: - args = expand_cmd_args(args, envvars) - args = normalize_output_param(args) - opts = parser.parse_args(shlex.split(args)) - d.update(opts.__dict__) - - return self.execute(params, **d) - - except ParseError as e: - error_str = str(e) - if 'invalid choice' in error_str: - import re - match = re.search(r"invalid choice: '([^']+)'", error_str) - if match: - invalid_cmd = match.group(1) - output_format = kwargs.get('format', 'text') - DomainManagementHelper.handle_invalid_subcommand(invalid_cmd, output_format) - return None - logging.error(error_str) - return None - - def execute(self, params, **kwargs): - subcommand = kwargs.get('subcommand') - - if not subcommand: - self.get_parser().print_help() - return - - if subcommand in ('list'): - return self.list_cmd.execute(params, **kwargs) - elif subcommand in ('reserve'): - return self.reserve_cmd.execute(params, **kwargs) - else: - output_format = kwargs.get('format', 'text') - DomainManagementHelper.handle_invalid_subcommand(subcommand, output_format) - return None - - -class ListDomainsCommand(EnterpriseCommand): - def get_parser(self): - return domain_list_parser - - def execute(self, params, **kwargs): - try: - rs = api.communicate_rest( - params, - None, - 'enterprise/list_domains', - rs_type=enterprise_pb2.ListDomainsResponse - ) - - fmt = kwargs.get('format', '') - - if not rs.domain: - logging.info('No reserved domains found for this enterprise.') - return - - if fmt == 'json': - domains_list = list(rs.domain) - print(json.dumps(domains_list, indent=2)) - else: - headers = ['Domain Name'] - table = [[domain] for domain in rs.domain] - return dump_report_data(table, headers, fmt=fmt, filename=kwargs.get('output')) - - except KeeperApiError as e: - error_code = e.result_code if hasattr(e, 'result_code') else 'Unknown' - - if DomainManagementHelper.is_feature_unavailable(error_code): - result = DomainManagementHelper.handle_unavailable_feature(kwargs.get('format') or 'text') - if result: - print(result) - return - - logging.error(f'Error listing domains: {e}') - raise - - -class ReserveDomainCommand(EnterpriseCommand): - - ACTION_MAP = { - 'token': enterprise_pb2.DOMAIN_TOKEN, - 'add': enterprise_pb2.DOMAIN_ADD, - 'delete': enterprise_pb2.DOMAIN_DELETE - } - - def get_parser(self): - return domain_reserve_parser - - def execute(self, params, **kwargs): - action = kwargs.get('action') - domain = kwargs.get('domain') - output_format = kwargs.get('format', 'text') - force = kwargs.get('force', False) - - if not self._validate_inputs(action, domain, output_format): - return - - is_valid, domain, error_msg = DomainManagementHelper.validate_domain(domain) - if not is_valid: - DomainManagementHelper.output_error(error_msg, output_format, domain=domain or '', status='failed') - return - - try: - result = self._execute_action(params, action, domain, output_format, force=force) - if result: - return result - - except KeeperApiError as e: - return self._handle_api_error(e, domain, action, output_format) - - except Exception as e: - error_msg = f'Unexpected error: {str(e)}' - DomainManagementHelper.output_error(error_msg, output_format, domain=domain, action=action) - logging.debug(f'Exception details: {e}', exc_info=True) - - def _validate_inputs(self, action, domain, output_format): - """Validate action and domain inputs.""" - if not action: - DomainManagementHelper.output_error('Action is required', output_format, status='failed') - return False - - if action not in self.ACTION_MAP: - DomainManagementHelper.output_error( - f'Invalid action: {action}. Must be one of: {", ".join(self.ACTION_MAP.keys())}', - output_format, - status='failed' - ) - return False - - if not domain: - DomainManagementHelper.output_error('Domain is required', output_format, status='failed') - return False - - return True - - def _execute_action(self, params, action, domain, output_format, force=False): - """Execute the specified domain action after validation.""" - if not action or not domain: - DomainManagementHelper.output_error('Action and domain are required', output_format, status='failed') - return - - rq = self._create_request(action, domain) - - if action == 'token': - return self._handle_token_action(params, rq, domain, output_format) - elif action == 'add': - return self._handle_add_action(params, rq, domain, output_format) - elif action == 'delete': - return self._handle_delete_action(params, rq, domain, output_format, force=force) - - def _create_request(self, action, domain): - rq = enterprise_pb2.ReserveDomainRequest() - rq.reserveDomainAction = self.ACTION_MAP[action] - rq.domain = domain - - return rq - - def _handle_token_action(self, params, rq, domain, output_format): - rs = api.communicate_rest( - params, - rq, - 'enterprise/reserve_domain', - rs_type=enterprise_pb2.ReserveDomainResponse - ) - - if not rs or not hasattr(rs, 'token') or not rs.token: - DomainManagementHelper.output_error( - 'Failed to generate token: empty response from server', - output_format, - domain=domain, - ) - return - - if output_format == 'json': - return json.dumps({'token': rs.token, 'domain': domain}, indent=2) - - self._display_token_instructions(domain, rs.token) - - def _handle_add_action(self, params, rq, domain, output_format): - api.communicate_rest(params, rq, 'enterprise/reserve_domain') - - if output_format == 'json': - return json.dumps({ - 'message': 'Domain successfully added to enterprise', - 'domain': domain, - 'action': 'add', - }, indent=2) - - logging.info(f'Domain "{domain}" has been reserved for the enterprise') - self._refresh_enterprise_data(params, 'added') - - def _handle_delete_action(self, params, rq, domain, output_format, force=False): - """Handle domain deletion with optional confirmation.""" - if not force and output_format != 'json': - domain_exists = self._check_domain_exists(params, domain) - if not domain_exists: - pass - else: - confirm = input(f'\n{bcolors.WARNING}Are you sure you want to delete domain "{domain}"? (y/n): {bcolors.ENDC}') - if confirm.lower() not in ['yes', 'y']: - logging.info('Domain deletion cancelled') - return - - api.communicate_rest(params, rq, 'enterprise/reserve_domain') - - if output_format == 'json': - return json.dumps({ - 'message': 'Domain removed from enterprise', - 'domain': domain, - 'action': 'delete', - }, indent=2) - - logging.info(f'Domain "{domain}" has been removed from the enterprise') - self._refresh_enterprise_data(params, 'removed') - - def _check_domain_exists(self, params, domain): - """Check if a domain exists in the enterprise.""" - rs = api.communicate_rest( - params, - None, - 'enterprise/list_domains', - rs_type=enterprise_pb2.ListDomainsResponse - ) - return domain in rs.domain if rs.domain else False - - def _display_token_instructions(self, domain, token): - logging.info(f'\n{bcolors.OKGREEN}Token generated successfully!{bcolors.ENDC}\n') - logging.info(f'Domain: {bcolors.BOLD}{domain}{bcolors.ENDC}') - logging.info(f'Token: {bcolors.BOLD}{token}{bcolors.ENDC}\n') - logging.info('Next steps:') - logging.info('1. Log into your domain registrar or DNS provider') - logging.info(f'2. Add a TXT record for domain "{domain}" with value:') - logging.info(f' {bcolors.WARNING}{token}{bcolors.ENDC}') - logging.info('3. Wait for DNS propagation (may take a few minutes)') - logging.info(f'4. Run: domain reserve --action add --domain {domain}') - - def _refresh_enterprise_data(self, params, action_past_tense): - try: - api.query_enterprise(params) - except Exception as refresh_error: - logging.warning(f'Successfully {action_past_tense} domain but failed to refresh enterprise data: {refresh_error}') - - def _handle_api_error(self, error, domain, action, output_format): - error_code = error.result_code if hasattr(error, 'result_code') else 'Unknown' - - if DomainManagementHelper.is_feature_unavailable(error_code): - result = DomainManagementHelper.handle_unavailable_feature(output_format) - if result: - return result - return - - error_msg = DomainManagementHelper.get_error_message(error_code, domain, action) - - if output_format == 'json': - return json.dumps({ - 'message': error_msg, - 'domain': domain, - 'action': action, - }, indent=2) - - logging.error(error_msg) diff --git a/keepercommander/commands/folder.py b/keepercommander/commands/folder.py index 9324ad146..f857b9f43 100644 --- a/keepercommander/commands/folder.py +++ b/keepercommander/commands/folder.py @@ -581,6 +581,10 @@ def execute(self, params, **kwargs): params.environment_variables[LAST_FOLDER_UID] = folder_uid if request['folder_type'] == 'shared_folder': params.environment_variables[LAST_SHARED_FOLDER_UID] = folder_uid + parent_path = get_folder_path(params, base_folder.uid) if base_folder.uid else '' + path = f'{parent_path}{name}' + response_data = {'folder_uid': folder_uid, 'name': name, 'path': path} + logging.info(json.dumps(response_data)) return folder_uid diff --git a/keepercommander/commands/msp.py b/keepercommander/commands/msp.py index 950d6bf0a..b267a2b1d 100644 --- a/keepercommander/commands/msp.py +++ b/keepercommander/commands/msp.py @@ -985,6 +985,12 @@ def execute(self, params, **kwargs): seats = 2147483647 name = kwargs['name'] + managed_companies = params.enterprise.get('managed_companies', []) + existing_mc = get_mc_by_name_or_id(managed_companies, name) + if existing_mc: + logging.warning('Managed company \'%s\' already exists: Skipping', name) + return + tree_key = utils.generate_aes_key() rq = { 'command': 'enterprise_registration_by_msp', diff --git a/keepercommander/commands/pam/pam_dto.py b/keepercommander/commands/pam/pam_dto.py index b1fc1934f..b23eb4c46 100644 --- a/keepercommander/commands/pam/pam_dto.py +++ b/keepercommander/commands/pam/pam_dto.py @@ -228,3 +228,84 @@ def __init__(self, inputs: dict, conversation_id=None, message_id=None): def toJSON(self): return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4) + + +# REMOTE MANAGEMENT ACTIONS (KC-1035) + +class GatewayActionRmCreateUserInputs: + + def __init__(self, configuration_uid, user, password=None, resource_uid=None, meta=None, connect_info=None): + self.configurationUid = configuration_uid + self.user = user + if password is not None: + self.password = password + if resource_uid is not None: + self.resourceUid = resource_uid + if meta is not None: + self.meta = meta + if connect_info is not None: + self.connect_info = connect_info + + def toJSON(self): + return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4) + + +class GatewayActionRmCreateUser(GatewayAction): + + def __init__(self, inputs, conversation_id=None, gateway_destination=None): + super().__init__('rm-create-user', inputs=inputs, conversation_id=conversation_id, + gateway_destination=gateway_destination, is_scheduled=False) + + def toJSON(self): + return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4) + + +class GatewayActionRmAddUserToGroupInputs: + + def __init__(self, configuration_uid, user, group_id, resource_uid=None, connect_info=None): + self.configurationUid = configuration_uid + self.user = user + self.groupId = group_id + if resource_uid is not None: + self.resourceUid = resource_uid + if connect_info is not None: + self.connect_info = connect_info + + def toJSON(self): + return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4) + + +class GatewayActionRmAddUserToGroup(GatewayAction): + + def __init__(self, inputs, conversation_id=None, gateway_destination=None): + super().__init__('rm-add-user-to-group', inputs=inputs, conversation_id=conversation_id, + gateway_destination=gateway_destination, is_scheduled=False) + + def toJSON(self): + return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4) + + +class GatewayActionRmDeleteUserInputs: + + def __init__(self, configuration_uid, user, resource_uid=None, meta=None, connect_info=None): + self.configurationUid = configuration_uid + self.user = user + if resource_uid is not None: + self.resourceUid = resource_uid + if meta is not None: + self.meta = meta + if connect_info is not None: + self.connect_info = connect_info + + def toJSON(self): + return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4) + + +class GatewayActionRmDeleteUser(GatewayAction): + + def __init__(self, inputs, conversation_id=None, gateway_destination=None): + super().__init__('rm-delete-user', inputs=inputs, conversation_id=conversation_id, + gateway_destination=gateway_destination, is_scheduled=False) + + def toJSON(self): + return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4) diff --git a/keepercommander/commands/pam_launch/guac_cli/__init__.py b/keepercommander/commands/pam_launch/guac_cli/__init__.py index 34e31b416..316062e49 100644 --- a/keepercommander/commands/pam_launch/guac_cli/__init__.py +++ b/keepercommander/commands/pam_launch/guac_cli/__init__.py @@ -14,12 +14,33 @@ This module provides Guacamole protocol handling for CLI sessions. -Components: -- instructions: Instruction handlers with routing to new guacamole module -- stdin_handler: Reads stdin and sends via pipe/blob/end pattern (for SSH/TTY) -- decoder: Parses Guacamole protocol instructions (legacy) -- renderer: Renders terminal output via ANSI/curses -- input: Maps stdin keystrokes to X11 keysyms (for graphical protocols) +Input modes +----------- +Key-event mode (default) + InputHandler maps every keystroke to a Guacamole ``key`` instruction + (press + release), matching Web Vault behaviour (Guacamole.Keyboard → + sendKeyEvent). This is the default for ``pam launch``. + +Pipe / stdin mode (--stdin) + StdinHandler reads raw stdin bytes and sends them via the pipe/blob/end + STDIN stream, matching kcm-cli behaviour. Selected with ``--stdin``. + +Shared behaviour (both modes) + Paste chords (Ctrl+V, Shift+Insert; Windows: also Ctrl+Shift+V) read the + OS clipboard via pyperclip and send it using the Vault-equivalent Guacamole + clipboard stream protocol (``clipboard`` + ``blob`` + ``end``). + + Ctrl+C double-tap: first press forwards the interrupt to the remote; + a second press within 400 ms tears down the local session. + +Components +---------- +- input: InputHandler — key-event mode stdin reader +- stdin_handler: StdinHandler — pipe/byte mode stdin reader +- session_input: CtrlCCoordinator, PasteOrchestrator — shared helpers +- instructions: Instruction handlers with routing to the guacamole module +- decoder: Guacamole protocol parser (legacy) +- renderer: Terminal output renderer via ANSI/curses """ from .instructions import create_instruction_router, get_default_handlers @@ -27,6 +48,7 @@ from .decoder import GuacamoleDecoder, GuacInstruction, GuacOp, X11Keysym from .renderer import TerminalRenderer from .input import InputHandler +from .session_input import CtrlCCoordinator, PasteOrchestrator __all__ = [ 'create_instruction_router', @@ -38,5 +60,7 @@ 'X11Keysym', 'TerminalRenderer', 'InputHandler', + 'CtrlCCoordinator', + 'PasteOrchestrator', ] diff --git a/keepercommander/commands/pam_launch/guac_cli/decoder.py b/keepercommander/commands/pam_launch/guac_cli/decoder.py index d467849bb..817b788a8 100644 --- a/keepercommander/commands/pam_launch/guac_cli/decoder.py +++ b/keepercommander/commands/pam_launch/guac_cli/decoder.py @@ -302,6 +302,7 @@ class X11Keysym: RETURN = 0xFF0D ESCAPE = 0xFF1B DELETE = 0xFFFF + INSERT = 0xFF63 # Cursor movement HOME = 0xFF50 @@ -338,13 +339,29 @@ class X11Keysym: ALT_L = 0xFFE9 ALT_R = 0xFFEA - # ASCII printable range (0x20-0x7E) maps directly - # For example: 'A' = 0x41, 'a' = 0x61, '0' = 0x30 + # ASCII / Latin-1 (U+0000–U+00FF): keysym equals code point (legacy X11 Latin-1 keysyms). + # Unicode outside that range: Guacamole / X11 use U+01000000 | codepoint (see Guacamole + # Keyboard.js keysym_from_unicode; X11 keysym Unicode extension). + _GUAC_UNICODE_KEYSYM_OFFSET = 0x01000000 + + @staticmethod + def keysym_from_unicode_codepoint(codepoint: int) -> int: + """ + Map a Unicode scalar to the X11 keysym value sent on the Guacamole ``key`` instruction. + + Code points U+0000-U+00FF use the direct keysym (ASCII + ISO 8859-1). U+0100 and + above (Cyrillic, Greek, CJK, emoji, etc.) use ``0x01000000 | codepoint``. + """ + if codepoint < 0 or codepoint > 0x10FFFF: + return 0 + if codepoint <= 0xFF: + return codepoint + return X11Keysym._GUAC_UNICODE_KEYSYM_OFFSET | codepoint @staticmethod def from_char(ch: str) -> int: - """Convert a single character to X11 keysym""" + """Convert a single character to X11 keysym (same rules as keysym_from_unicode_codepoint).""" if len(ch) == 1: - return ord(ch) + return X11Keysym.keysym_from_unicode_codepoint(ord(ch)) return 0 diff --git a/keepercommander/commands/pam_launch/guac_cli/input.py b/keepercommander/commands/pam_launch/guac_cli/input.py index 38a8dc78c..693077806 100644 --- a/keepercommander/commands/pam_launch/guac_cli/input.py +++ b/keepercommander/commands/pam_launch/guac_cli/input.py @@ -10,167 +10,227 @@ # """ -Input handler for Guacamole CLI mode. +Input handler for Guacamole CLI key-event mode. -Maps stdin keystrokes to Guacamole key instructions using X11 keysyms. -Handles control keys, special keys, and character input. +Maps stdin keystrokes to Guacamole `key` instructions (X11 keysyms) — +the default input path for `pam launch`, matching Web Vault behaviour where +Guacamole.Keyboard forwards every keystroke as sendKeyEvent(). + +Paste and Ctrl+C double-tap are handled via shared helpers from session_input: + • Ctrl+V / Shift+Insert → PasteOrchestrator → Vault clipboard stream + • Ctrl+C (single) → CtrlCCoordinator → remote interrupt via send_key + • Ctrl+C (double, 400ms) → CtrlCCoordinator → local session exit + +Windows extended keys (arrows, F-keys, Home/End …) are handled through a +ReadConsoleInput-based reader that emits standard VT100 escape sequences so +the existing _escape_to_keysym mapping works unchanged. Shift+Insert and +Ctrl+Shift+V are detected with modifier state and mapped to paste. """ from __future__ import annotations + +import collections import sys import logging import threading from typing import Optional, Callable + from .decoder import X11Keysym +from .session_input import CtrlCCoordinator, PasteOrchestrator +from .win_console_input import ( + win_stdin_disable_ctrl_c_process_input, + win_stdin_restore_console_mode, +) + +# Paste-chord sentinels (InputHandler internal) +# Ctrl+V (Unix raw + Windows uChar): 0x16 +_PASTE_BYTE = '\x16' +# Windows ReadConsoleInput distinguishes these from Ctrl+V: +_CHORD_CTRL_SHIFT_V = '\x17' +_CHORD_SHIFT_INSERT = '\x18' +_CHORD_CTRL_INSERT = '\x19' class InputHandler: """ - Handles stdin input and converts it to Guacamole key events. + Handles stdin input and converts every keystroke to a Guacamole `key` + instruction (press + release) via key_callback. - Reads from stdin in raw mode (non-buffered, non-echoing) and maps - keys to X11 keysyms for transmission via Guacamole protocol. + Paste chords are routed through PasteOrchestrator (local OS → Guacamole + clipboard) unless ``disable_paste`` is set; then they are sent as key + events so the remote uses its own clipboard. + Ctrl+C is routed through CtrlCCoordinator for double-tap exit logic. """ - def __init__(self, key_callback: Callable[[int, bool], None]): + def __init__( + self, + key_callback: Callable[[int, bool], None], + ctrl_c_coordinator: Optional[CtrlCCoordinator] = None, + paste_orchestrator: Optional[PasteOrchestrator] = None, + *, + disable_paste: bool = False, + ): """ - Initialize the input handler. - Args: - key_callback: Callback function(keysym, pressed) to send key events + key_callback: function(keysym: int, pressed: bool) — sends key events. + ctrl_c_coordinator: Shared double-tap Ctrl+C coordinator. When None, + Ctrl+C is forwarded as keysym 3 (ETX) like any other control char. + paste_orchestrator: Shared paste handler. When None, Ctrl+V and + Shift+Insert are forwarded as key events unchanged. + disable_paste: When True (PAM disablePaste), paste chords send Guacamole + key events (Ctrl+V, etc.) so the remote uses its own clipboard, not + the local OS clipboard stream. """ self.key_callback = key_callback + self.ctrl_c_coordinator = ctrl_c_coordinator + self.paste_orchestrator = paste_orchestrator + self.disable_paste = disable_paste self.running = False self.thread = None self.raw_mode_active = False - # Platform-specific stdin handler self.stdin_reader = self._get_stdin_reader() def _get_stdin_reader(self): - """Get platform-specific stdin reader""" if sys.platform == 'win32': return WindowsStdinReader() - else: - return UnixStdinReader() + return UnixStdinReader() def start(self): - """Start reading input in a background thread""" + """Start reading input in a background thread.""" if self.running: return - self.running = True self.stdin_reader.set_raw_mode() self.raw_mode_active = True - self.thread = threading.Thread(target=self._input_loop, daemon=True) self.thread.start() - logging.debug("Input handler started") + logging.debug('InputHandler started (key-event mode)') def stop(self): - """Stop reading input and restore terminal""" + """Stop reading input and restore terminal.""" self.running = False if self.raw_mode_active: self.stdin_reader.restore() self.raw_mode_active = False if self.thread: self.thread.join(timeout=1.0) - logging.debug("Input handler stopped") + logging.debug('InputHandler stopped') def _input_loop(self): - """Main input reading loop""" while self.running: try: ch = self.stdin_reader.read_char() if ch: self._process_input(ch) - except Exception as e: - logging.error(f"Error in input loop: {e}") + except Exception as exc: + logging.error(f'Error in input loop: {exc}') break + # Input processing + def _process_input(self, ch: str): """ - Process a character from stdin and generate key events. - - Args: - ch: Character or escape sequence from stdin + Process a single character (or the first character of a buffered + sequence) from stdin and emit the appropriate key event(s). """ - # Handle escape sequences for special keys - if ch == '\x1b': # ESC - # Try to read escape sequence + if not ch: + return + + code = ord(ch) if len(ch) == 1 else -1 + + # ESC / ANSI escape sequences + if ch == '\x1b': seq = self._read_escape_sequence() if seq: + # Shift+Insert → ESC[2~ on Unix (Windows uses _CHORD_SHIFT_INSERT). + if seq == '[2~': + if self.disable_paste: + self._send_shift_insert_chord() + return + if self.paste_orchestrator: + self.paste_orchestrator.paste() + return keysym = self._escape_to_keysym(seq) if keysym: self._send_key(keysym) else: - # Just ESC key self._send_key(X11Keysym.ESCAPE) + return - # Handle control characters - elif ord(ch) < 32: + # Ctrl+C double-tap + if code == 0x03 and self.ctrl_c_coordinator: + self.ctrl_c_coordinator.handle() + return + + # Paste chords: local clipboard stream vs key events (disablePaste) + if ch in ( + _PASTE_BYTE, + _CHORD_CTRL_SHIFT_V, + _CHORD_SHIFT_INSERT, + _CHORD_CTRL_INSERT, + ): + if self.disable_paste: + if ch == _PASTE_BYTE: + self._send_ctrl_v_chord() + elif ch == _CHORD_CTRL_SHIFT_V: + self._send_ctrl_shift_v_chord() + elif ch == _CHORD_SHIFT_INSERT: + self._send_shift_insert_chord() + else: + self._send_ctrl_insert_chord() + return + if self.paste_orchestrator: + self.paste_orchestrator.paste() + return + + # Other control characters + if 0 < code < 32: keysym = self._control_char_to_keysym(ch) if keysym: self._send_key(keysym) + return - # Handle DEL (127) - elif ord(ch) == 127: + # DEL + if code == 127: self._send_key(X11Keysym.BACKSPACE) + return - # Handle regular printable characters - else: - # For printable ASCII, keysym is just the character code - keysym = ord(ch) - self._send_key(keysym) + # Printable / Unicode + if code > 0: + self._send_key(X11Keysym.keysym_from_unicode_codepoint(code)) def _read_escape_sequence(self) -> Optional[str]: - """ - Read an ANSI escape sequence from stdin. - - Returns: - Escape sequence string (without ESC prefix) or None - """ - seq = "" - for _ in range(5): # Read up to 5 characters + """Read an ANSI escape sequence from stdin after the leading ESC.""" + seq = '' + for _ in range(8): ch = self.stdin_reader.read_char(timeout=0.05) - if ch: - seq += ch - # Common sequences end with a letter - if ch.isalpha() or ch == '~': - break - else: + if not ch: + break + seq += ch + if ch.isalpha() or ch == '~': break - return seq if seq else None def _escape_to_keysym(self, seq: str) -> Optional[int]: - """ - Map an ANSI escape sequence to X11 keysym. - - Args: - seq: Escape sequence (without ESC prefix) - - Returns: - X11 keysym or None - """ - # Common escape sequences + """Map an ANSI escape sequence (without leading ESC) to an X11 keysym.""" mappings = { - '[A': X11Keysym.UP, - '[B': X11Keysym.DOWN, - '[C': X11Keysym.RIGHT, - '[D': X11Keysym.LEFT, - '[H': X11Keysym.HOME, - '[F': X11Keysym.END, - '[1~': X11Keysym.HOME, - '[2~': 0xFFFF, # Insert - '[3~': X11Keysym.DELETE, - '[4~': X11Keysym.END, - '[5~': X11Keysym.PAGE_UP, - '[6~': X11Keysym.PAGE_DOWN, - 'OP': X11Keysym.F1, - 'OQ': X11Keysym.F2, - 'OR': X11Keysym.F3, - 'OS': X11Keysym.F4, + '[A': X11Keysym.UP, + '[B': X11Keysym.DOWN, + '[C': X11Keysym.RIGHT, + '[D': X11Keysym.LEFT, + '[H': X11Keysym.HOME, + '[F': X11Keysym.END, + '[1~': X11Keysym.HOME, + '[2~': X11Keysym.INSERT, + '[3~': X11Keysym.DELETE, + '[4~': X11Keysym.END, + '[5~': X11Keysym.PAGE_UP, + '[6~': X11Keysym.PAGE_DOWN, + 'OP': X11Keysym.F1, + 'OQ': X11Keysym.F2, + 'OR': X11Keysym.F3, + 'OS': X11Keysym.F4, '[15~': X11Keysym.F5, '[17~': X11Keysym.F6, '[18~': X11Keysym.F7, @@ -180,148 +240,333 @@ def _escape_to_keysym(self, seq: str) -> Optional[int]: '[23~': X11Keysym.F11, '[24~': X11Keysym.F12, } - + # Strip modifier suffix (e.g. [1;5A → [A) + if seq.startswith('[1;') and len(seq) >= 4: + final = seq[-1] + base = {'A': '[A', 'B': '[B', 'C': '[C', 'D': '[D'}.get(final) + if base: + return mappings.get(base) return mappings.get(seq) def _control_char_to_keysym(self, ch: str) -> Optional[int]: - """ - Map control character to X11 keysym. + """Map a control character (code < 32) to an X11 keysym.""" + code = ord(ch) + if code == 8: return X11Keysym.BACKSPACE + if code == 9: return X11Keysym.TAB + if code == 10: return X11Keysym.RETURN + if code == 13: return X11Keysym.RETURN + if code == 27: return X11Keysym.ESCAPE + # Ctrl+A … Ctrl+Z and other control codes: send as the raw code value. + # guacd maps ETX (3), EOT (4), etc. correctly for SSH/terminal use. + return code - Args: - ch: Control character + def _send_key(self, keysym: int): + """Emit a key press followed by a key release.""" + self.key_callback(keysym, True) + self.key_callback(keysym, False) - Returns: - X11 keysym or None - """ - code = ord(ch) + def _send_modifier_chord(self, modifiers: list[int], main_keysym: int) -> None: + """Press modifiers, press+release main key, release modifiers (remote TTY paste).""" + for m in modifiers: + self.key_callback(m, True) + self.key_callback(main_keysym, True) + self.key_callback(main_keysym, False) + for m in reversed(modifiers): + self.key_callback(m, False) - # Common control characters - if code == 8: # Backspace (Ctrl+H) - return X11Keysym.BACKSPACE - elif code == 9: # Tab - return X11Keysym.TAB - elif code == 10: # Line feed (Enter on Unix) - return X11Keysym.RETURN - elif code == 13: # Carriage return (Enter on Windows) - return X11Keysym.RETURN - elif code == 27: # ESC - return X11Keysym.ESCAPE - else: - # Ctrl+letter combinations (Ctrl+A = 1, Ctrl+B = 2, etc.) - # Send as lowercase letter with Ctrl modifier - # For simplicity, just send the control character as-is - # Guacamole can interpret it - return code + def _send_ctrl_v_chord(self) -> None: + self._send_modifier_chord([X11Keysym.CONTROL_L], ord('v')) - def _send_key(self, keysym: int): - """ - Send a key press and release event. + def _send_ctrl_shift_v_chord(self) -> None: + self._send_modifier_chord([X11Keysym.CONTROL_L, X11Keysym.SHIFT_L], ord('v')) - Args: - keysym: X11 keysym value - """ - # Send key press - self.key_callback(keysym, True) + def _send_shift_insert_chord(self) -> None: + self._send_modifier_chord([X11Keysym.SHIFT_L], X11Keysym.INSERT) + + def _send_ctrl_insert_chord(self) -> None: + self._send_modifier_chord([X11Keysym.CONTROL_L], X11Keysym.INSERT) - # Send key release - self.key_callback(keysym, False) +# Unix/macOS stdin reader class UnixStdinReader: - """Unix/Linux stdin reader with raw mode support""" + """Unix/macOS stdin reader with raw mode via termios.""" def __init__(self): self.old_settings = None def set_raw_mode(self): - """Set terminal to raw mode (non-buffered, non-echoing)""" try: - import termios - import tty + import termios, tty, time + sys.stdout.flush() + sys.stderr.flush() self.old_settings = termios.tcgetattr(sys.stdin.fileno()) tty.setraw(sys.stdin.fileno()) - except Exception as e: - logging.warning(f"Failed to set raw mode: {e}") + time.sleep(0.01) + sys.stdout.flush() + sys.stderr.flush() + except Exception as exc: + logging.warning(f'Failed to set raw mode: {exc}') def restore(self): - """Restore terminal to normal mode""" if self.old_settings: try: import termios - termios.tcsetattr(sys.stdin.fileno(), termios.TCSADRAIN, self.old_settings) - except Exception as e: - logging.warning(f"Failed to restore terminal: {e}") + termios.tcsetattr( + sys.stdin.fileno(), termios.TCSADRAIN, self.old_settings + ) + except Exception as exc: + logging.warning(f'Failed to restore terminal: {exc}') self.old_settings = None def read_char(self, timeout: Optional[float] = None) -> Optional[str]: - """ - Read a single character from stdin. - - Args: - timeout: Read timeout in seconds (None = blocking) - - Returns: - Character or None if timeout - """ - if timeout: + if timeout is not None: import select ready, _, _ = select.select([sys.stdin], [], [], timeout) if not ready: return None - try: return sys.stdin.read(1) - except: + except Exception: return None +# Windows stdin reader (ReadConsoleInput-based) + +# VT100 / xterm escape sequences for Windows VK codes. +# These are queued as individual chars so the existing _escape_to_keysym +# mapping in InputHandler works unchanged. +_VK_TO_ESC_SEQ: dict = { + 0x26: '\x1b[A', # VK_UP + 0x28: '\x1b[B', # VK_DOWN + 0x27: '\x1b[C', # VK_RIGHT + 0x25: '\x1b[D', # VK_LEFT + 0x24: '\x1b[H', # VK_HOME + 0x23: '\x1b[F', # VK_END + 0x2D: '\x1b[2~', # VK_INSERT + 0x2E: '\x1b[3~', # VK_DELETE + 0x21: '\x1b[5~', # VK_PRIOR (Page Up) + 0x22: '\x1b[6~', # VK_NEXT (Page Down) + 0x70: '\x1bOP', # VK_F1 + 0x71: '\x1bOQ', # VK_F2 + 0x72: '\x1bOR', # VK_F3 + 0x73: '\x1bOS', # VK_F4 + 0x74: '\x1b[15~', # VK_F5 + 0x75: '\x1b[17~', # VK_F6 + 0x76: '\x1b[18~', # VK_F7 + 0x77: '\x1b[19~', # VK_F8 + 0x78: '\x1b[20~', # VK_F9 + 0x79: '\x1b[21~', # VK_F10 + 0x7A: '\x1b[23~', # VK_F11 + 0x7B: '\x1b[24~', # VK_F12 +} + +_VK_INSERT = 0x2D +_VK_V = 0x56 +_SHIFT_PRESSED = 0x0010 +_LEFT_CTRL = 0x0008 +_RIGHT_CTRL = 0x0004 +_CTRL_PRESSED = _LEFT_CTRL | _RIGHT_CTRL +_KEY_EVENT = 0x0001 +_STD_INPUT_HANDLE = -10 + + class WindowsStdinReader: - """Windows stdin reader with raw mode support""" + """ + Windows console reader using ReadConsoleInputW for full modifier awareness. + + Paste chords are translated to sentinels so _process_input can route to + PasteOrchestrator or, when PAM disablePaste, to key chords (Ctrl+V, etc.). + + Navigation / function keys are translated to VT100 escape sequences and + queued one character at a time; InputHandler._read_escape_sequence drains + the queue transparently. + """ def __init__(self): - self.old_mode = None + self._queue: collections.deque = collections.deque() + self._hstdin = None + self._input_record_type = None + self._ready = False + self._init_win32() + + def _init_win32(self): + """Set up ctypes structures for ReadConsoleInputW.""" + try: + import ctypes + from ctypes import wintypes + + class _KeyEventRecord(ctypes.Structure): + _fields_ = [ + ('bKeyDown', wintypes.BOOL), + ('wRepeatCount', wintypes.WORD), + ('wVirtualKeyCode', wintypes.WORD), + ('wVirtualScanCode', wintypes.WORD), + ('uChar', wintypes.WCHAR), + ('dwControlKeyState', wintypes.DWORD), + ] + + class _EventUnion(ctypes.Union): + _fields_ = [ + ('KeyEvent', _KeyEventRecord), + ('_pad', ctypes.c_byte * 20), + ] + + class _InputRecord(ctypes.Structure): + _fields_ = [ + ('EventType', wintypes.WORD), + ('Event', _EventUnion), + ] + + self._InputRecord = _InputRecord + self._wintypes = wintypes + self._ctypes = ctypes + kernel32 = ctypes.windll.kernel32 + self._hstdin = kernel32.GetStdHandle(_STD_INPUT_HANDLE) + self._ReadConsoleInputW = kernel32.ReadConsoleInputW + self._WaitForSingleObject = kernel32.WaitForSingleObject + self._ready = True + except Exception as exc: + logging.warning(f'WindowsStdinReader: Win32 init failed, falling back to msvcrt: {exc}') + self._ready = False def set_raw_mode(self): - """Set console to raw mode on Windows""" - try: - import msvcrt - # Windows console is already non-buffered for getch - pass - except: - pass + import time + sys.stdout.flush() + sys.stderr.flush() + time.sleep(0.01) + sys.stdout.flush() + sys.stderr.flush() + # Ctrl+C as input (not SIGINT) so CtrlCCoordinator can handle double-tap. + self._win_saved_console_mode = win_stdin_disable_ctrl_c_process_input() def restore(self): - """Restore console mode""" - pass + win_stdin_restore_console_mode(self._win_saved_console_mode) + self._win_saved_console_mode = None def read_char(self, timeout: Optional[float] = None) -> Optional[str]: - """ - Read a single character from stdin on Windows. + # Drain queued chars first (from previously decoded escape sequences). + if self._queue: + return self._queue.popleft() + + if self._ready: + return self._read_via_console_input(timeout) + return self._read_via_msvcrt(timeout) + + def _read_via_console_input(self, timeout: Optional[float]) -> Optional[str]: + """Read one logical key event, emitting VT100 sequences for nav keys.""" + ctypes = self._ctypes + wintypes = self._wintypes + + while True: + if timeout is not None: + wait_ms = int(timeout * 1000) + result = self._WaitForSingleObject(self._hstdin, wait_ms) + if result != 0: # WAIT_OBJECT_0 = 0 + return None + + record = self._InputRecord() + n_read = wintypes.DWORD(0) + ok = self._ReadConsoleInputW( + self._hstdin, + ctypes.byref(record), + 1, + ctypes.byref(n_read), + ) + if not ok or n_read.value == 0: + return None - Args: - timeout: Read timeout in seconds (None = blocking) + if record.EventType != _KEY_EVENT: + continue - Returns: - Character or None if timeout - """ + key = record.Event.KeyEvent + if not key.bKeyDown: + continue # ignore key-up events + + vk = key.wVirtualKeyCode + ctrl = key.dwControlKeyState & _CTRL_PRESSED + shift = key.dwControlKeyState & _SHIFT_PRESSED + + # Ctrl+Shift+V + if vk == _VK_V and ctrl and shift: + return _CHORD_CTRL_SHIFT_V + + # Ctrl+V (plain) — some consoles omit uChar; match before uChar path + if vk == _VK_V and ctrl and not shift: + return _PASTE_BYTE + + # Shift+Insert + if vk == _VK_INSERT and shift and not ctrl: + return _CHORD_SHIFT_INSERT + + # Ctrl+Insert (some terminals) + if vk == _VK_INSERT and ctrl and not shift: + return _CHORD_CTRL_INSERT + + # Navigation / function keys → VT100 ESC sequence (queued) + if vk in _VK_TO_ESC_SEQ: + seq = _VK_TO_ESC_SEQ[vk] + for c in seq[1:]: + self._queue.append(c) + return seq[0] # '\x1b' — rest drained by _read_escape_sequence + + # Regular character from uChar (includes Ctrl+letter control codes) + ch = key.uChar + if ch and ord(ch) > 0: + return ch + + # Modifier-only or unhandled VK — loop for next event + + def _read_via_msvcrt(self, timeout: Optional[float]) -> Optional[str]: + """Fallback when ReadConsoleInput init failed.""" try: - import msvcrt - - if timeout: - # Poll for input with timeout - import time - start = time.time() - while time.time() - start < timeout: - if msvcrt.kbhit(): - ch = msvcrt.getch() - return ch.decode('utf-8', errors='ignore') - time.sleep(0.01) - return None - else: - # Blocking read - ch = msvcrt.getch() - return ch.decode('utf-8', errors='ignore') - except Exception as e: - logging.error(f"Error reading from stdin on Windows: {e}") + import msvcrt, time + start = time.time() + while True: + if timeout is not None and (time.time() - start) >= timeout: + return None + if msvcrt.kbhit(): + ch = msvcrt.getch() + # Extended key prefix — read second byte immediately + if ch in (b'\xe0', b'\x00'): + scan = msvcrt.getch() + seq = self._win_scan_to_esc(scan[0] if scan else 0) + if seq: + for c in seq[1:]: + self._queue.append(c) + return seq[0] + return None + return ch.decode('utf-8', errors='replace') + time.sleep(0.01) + except Exception as exc: + logging.error(f'msvcrt read error: {exc}') return None + @staticmethod + def _win_scan_to_esc(scan: int) -> Optional[str]: + """Map a Windows extended-key scan code to a VT100 escape sequence.""" + table = { + 0x48: '\x1b[A', # Up + 0x50: '\x1b[B', # Down + 0x4D: '\x1b[C', # Right + 0x4B: '\x1b[D', # Left + 0x47: '\x1b[H', # Home + 0x4F: '\x1b[F', # End + 0x52: '\x1b[2~', # Insert + 0x53: '\x1b[3~', # Delete + 0x49: '\x1b[5~', # Page Up + 0x51: '\x1b[6~', # Page Down + 0x3B: '\x1bOP', # F1 + 0x3C: '\x1bOQ', # F2 + 0x3D: '\x1bOR', # F3 + 0x3E: '\x1bOS', # F4 + 0x3F: '\x1b[15~', # F5 + 0x40: '\x1b[17~', # F6 + 0x41: '\x1b[18~', # F7 + 0x42: '\x1b[19~', # F8 + 0x43: '\x1b[20~', # F9 + 0x44: '\x1b[21~', # F10 + 0x85: '\x1b[23~', # F11 + 0x86: '\x1b[24~', # F12 + } + return table.get(scan) diff --git a/keepercommander/commands/pam_launch/guac_cli/instructions.py b/keepercommander/commands/pam_launch/guac_cli/instructions.py index 6c6c7c1f9..b93310b51 100644 --- a/keepercommander/commands/pam_launch/guac_cli/instructions.py +++ b/keepercommander/commands/pam_launch/guac_cli/instructions.py @@ -33,7 +33,32 @@ import base64 import logging import sys -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, cast + + +def is_stdout_pipe_stream_name(name: str) -> bool: + """True if Guacamole named pipe is the terminal STDOUT stream (case/whitespace tolerant).""" + if not name: + return False + return str(name).strip().casefold() == 'stdout' + + +def is_stdin_pipe_stream_name(name: str) -> bool: + """True if this named pipe is the client→server STDIN stream (do not treat as terminal output).""" + if not name: + return False + return str(name).strip().casefold() == 'stdin' + + +def _pipe_looks_like_terminal_stdout(mimetype: str, name: str) -> bool: + """ + Heuristic when guacr/gateway uses a non-STDOUT pipe name for TTY bytes (e.g. PAM clipboard flags). + Require text/* and exclude STDIN. Only used when no stdout stream is tracked yet. + """ + if is_stdin_pipe_stream_name(name): + return False + mt = (mimetype or '').strip().lower() + return mt == 'text/plain' or mt.startswith('text/') # Handler type: receives list of string arguments @@ -498,10 +523,29 @@ def router(opcode: str, args: List[str]) -> None: # Handle pipe - track STDOUT stream if opcode == 'pipe' and len(args) >= 3: stream_index, mimetype, name = args[0], args[1], args[2] - if name == 'STDOUT': + _note = getattr(stdout_stream_tracker, 'note_guac_pipe_instruction', None) + if callable(_note): + _note() + use_as_stdout = is_stdout_pipe_stream_name(name) + if ( + not use_as_stdout + and stdout_stream_tracker.stdout_stream_index == -1 + and _pipe_looks_like_terminal_stdout(mimetype, name) + ): + use_as_stdout = True + logging.debug( + 'CLI: using pipe name=%r mimetype=%r as terminal STDOUT (fallback)', + name, + mimetype, + ) + + if use_as_stdout: stdout_stream_tracker.stdout_stream_index = int(stream_index) send_ack_callback(stream_index, 'OK', '0') - logging.debug(f"STDOUT pipe opened on stream {stream_index}") + evt = getattr(stdout_stream_tracker, 'stdout_pipe_opened', None) + if evt is not None and hasattr(evt, 'set'): + evt.set() + logging.debug('Terminal output pipe on stream %s (name=%r)', stream_index, name) # Still call original handler for diagnostics handler = handlers.get(opcode) if handler: @@ -528,6 +572,11 @@ def router(opcode: str, args: List[str]) -> None: except Exception as e: logging.error(f"Error decoding STDOUT blob: {e}") return + # Inbound Guacamole clipboard stream (server → client) + clip_blob = getattr(stdout_stream_tracker, 'handle_remote_clipboard_blob', None) + if clip_blob is not None: + if cast(Callable[[str, str], bool], clip_blob)(args[0], args[1]): + return # Non-STDOUT blob falls through to default handler # Handle end - clear STDOUT tracking @@ -544,6 +593,10 @@ def router(opcode: str, args: List[str]) -> None: except Exception as e: logging.error(f"Error in end handler: {e}") return + clip_end = getattr(stdout_stream_tracker, 'handle_remote_clipboard_end', None) + if clip_end is not None: + if cast(Callable[[str], bool], clip_end)(args[0]): + return # Default routing handler = handlers.get(opcode) diff --git a/keepercommander/commands/pam_launch/guac_cli/session_input.py b/keepercommander/commands/pam_launch/guac_cli/session_input.py new file mode 100644 index 000000000..39dcbeab5 --- /dev/null +++ b/keepercommander/commands/pam_launch/guac_cli/session_input.py @@ -0,0 +1,137 @@ +# _ __ +# | |/ /___ ___ _ __ ___ _ _ ® +# | ' ,text/plain; + blob,,; + end,; + Never falls back to send_stdin for paste. If disablePaste is set the + chord is silently ignored (early warning is printed at session start). +""" + +from __future__ import annotations + +import logging +import time +from typing import Callable, Optional + +# Fixed Ctrl+C double-tap window (plan: 400 ms, within the 300–500 ms band). +CTRL_C_WINDOW: float = 0.4 + + +class CtrlCCoordinator: + """ + Double-tap Ctrl+C coordinator shared by InputHandler and StdinHandler. + + Args: + remote_interrupt_fn: Called on the *first* tap (or any tap outside the + window) to forward the interrupt to the remote session. + • Key mode : send_key(keysym=3, pressed) x 2 (press + release) + • Pipe mode : send_stdin(b'\\x03') + local_exit_fn: Called on the *second* tap inside the window to end the + local pam-launch session (sets shutdown_requested=True). + """ + + def __init__( + self, + remote_interrupt_fn: Callable[[], None], + local_exit_fn: Callable[[], None], + ) -> None: + self._remote_interrupt = remote_interrupt_fn + self._local_exit = local_exit_fn + self._last_ctrl_c: Optional[float] = None + + def handle(self) -> None: + """Call whenever Ctrl+C (byte 0x03) is detected in the input stream.""" + now = time.monotonic() + if ( + self._last_ctrl_c is not None + and (now - self._last_ctrl_c) <= CTRL_C_WINDOW + ): + # Second tap inside window → local exit + self._last_ctrl_c = None + print('\r\nExiting session...', flush=True) + self._local_exit() + else: + # First tap (or outside window) → remote interrupt only + self._last_ctrl_c = now + self._remote_interrupt() + + +class PasteOrchestrator: + """ + OS-clipboard → remote Guacamole clipboard stream. + + GuacamoleClipboard.setRemoteClipboard: + client.createClipboardStream(mimetype) → clipboard instruction + writer.sendText(data) → blob instruction + writer.sendEnd() → end instruction + + Args: + send_clipboard_fn: Callable(text: str) that formats and sends the + three-instruction clipboard stream to the gateway. Should be + GuacamoleHandler.send_clipboard_stream. + disable_paste: When True the chord is a silent no-op (warning already + printed at session start by launch.py execute()). + """ + + def __init__( + self, + send_clipboard_fn: Callable[[str], None], + disable_paste: bool = False, + ) -> None: + self._send_clipboard = send_clipboard_fn + self._disable_paste = disable_paste + + def paste(self) -> None: + """Trigger a clipboard paste to the remote session.""" + if self._disable_paste: + return + + try: + import pyperclip # type: ignore[import] + text = pyperclip.paste() + except ImportError: + msg = ( + 'Paste unavailable: pyperclip is not installed. ' + 'Run: pip install pyperclip' + ) + logging.warning(msg) + print(f'\r\n{msg}', flush=True) + return + except Exception as exc: + msg = f'Could not read clipboard: {exc}' + logging.warning(msg) + print(f'\r\n{msg}', flush=True) + return + + if not text: + return + + try: + self._send_clipboard(text) + logging.debug('Paste: %d chars sent via Guacamole clipboard stream', len(text)) + except Exception as exc: + msg = f'Failed to send clipboard to remote: {exc}' + logging.warning(msg) + print(f'\r\n{msg}', flush=True) diff --git a/keepercommander/commands/pam_launch/guac_cli/stdin_handler.py b/keepercommander/commands/pam_launch/guac_cli/stdin_handler.py index 53c6bdb63..60e51942a 100644 --- a/keepercommander/commands/pam_launch/guac_cli/stdin_handler.py +++ b/keepercommander/commands/pam_launch/guac_cli/stdin_handler.py @@ -34,32 +34,47 @@ import threading from typing import Callable, Optional +from .session_input import CtrlCCoordinator, PasteOrchestrator +from .win_console_input import ( + win_stdin_disable_ctrl_c_process_input, + win_stdin_restore_console_mode, +) + class StdinHandler: """ - Handles stdin input for plaintext SSH/TTY sessions. + Handles stdin input for plaintext SSH/TTY sessions (--stdin / pipe mode). + + Reads raw stdin in non-buffered mode and sends typed bytes via + stdin_callback (pipe/blob/end pattern, matching kcm-cli). - Reads raw stdin in non-buffered mode and sends data via callback. - Uses pipe/blob/end pattern matching kcm-cli implementation. + Escape sequences (arrow keys, function keys) are converted to X11 key + events via key_callback when provided. - Enhanced to detect escape sequences (arrow keys, function keys) and - send them as X11 key events instead of raw bytes. + Paste chords (Ctrl+V byte 0x16, Shift+Insert ESC[2~) and Ctrl+C double-tap + are handled via the shared CtrlCCoordinator / PasteOrchestrator helpers so + behaviour is identical to key-event mode (InputHandler). """ - def __init__(self, stdin_callback: Callable[[bytes], None], - key_callback: Optional[Callable[[int, bool], None]] = None): + def __init__( + self, + stdin_callback: Callable[[bytes], None], + key_callback: Optional[Callable[[int, bool], None]] = None, + ctrl_c_coordinator: Optional[CtrlCCoordinator] = None, + paste_orchestrator: Optional[PasteOrchestrator] = None, + ): """ - Initialize the stdin handler. - Args: - stdin_callback: Callback function(data: bytes) to send stdin data. - Should call GuacamoleHandler.send_stdin() - key_callback: Optional callback function(keysym: int, pressed: bool) - to send key events. Should call GuacamoleHandler.send_key() - If provided, escape sequences will be converted to key events. + stdin_callback: Sends typed bytes via GuacamoleHandler.send_stdin(). + key_callback: Sends key events via GuacamoleHandler.send_key(). + Required for escape-sequence → keysym conversion. + ctrl_c_coordinator: Shared double-tap Ctrl+C handler. + paste_orchestrator: Shared paste handler (clipboard → Guac stream). """ self.stdin_callback = stdin_callback self.key_callback = key_callback + self.ctrl_c_coordinator = ctrl_c_coordinator + self.paste_orchestrator = paste_orchestrator self.running = False self.thread: Optional[threading.Thread] = None self.raw_mode_active = False @@ -170,9 +185,13 @@ def _process_input(self, data: bytes): self._escape_buffer += bytes([byte]) keysym = self._detect_escape_sequence() if keysym is not None: - # Found a complete escape sequence - send as key event + # Found a complete escape sequence. + # INSERT (ESC[2~) is a paste chord in both modes. logging.debug(f"Detected escape sequence: {self._escape_buffer.hex()} -> keysym 0x{keysym:04X}") - self._send_key(keysym) + if keysym == 0xFF63 and self.paste_orchestrator: # INSERT → paste + self.paste_orchestrator.paste() + else: + self._send_key(keysym) self._escape_buffer = b'' i += 1 continue @@ -231,15 +250,28 @@ def _process_input(self, data: bytes): # End of data, wait for next read break - # Regular character - send as stdin - # But first check if it's a control character that might be part of an escape sequence - if byte < 32 and byte != 0x1B: # Control char but not ESC - # Send control characters as-is (they might be Ctrl+key combinations) + # Regular character - send as stdin. + # Normalize line endings: send \n for Enter so remote sees one newline (avoids double + # newlines when terminal sends \r or \r\n and remote echoes + app sends newline). + if byte == 0x0D: # \r (Enter on some terminals) + self.stdin_callback(b'\n') + if i + 1 < len(data) and data[i + 1] == 0x0A: # skip trailing \n in \r\n + i += 1 + elif byte == 0x03: # Ctrl+C — double-tap coordinator + if self.ctrl_c_coordinator: + self.ctrl_c_coordinator.handle() + else: + self.stdin_callback(bytes([byte])) + elif byte == 0x16: # Ctrl+V — paste chord + if self.paste_orchestrator: + self.paste_orchestrator.paste() + else: + self.stdin_callback(bytes([byte])) + elif byte < 32 and byte != 0x1B: # Other control chars as-is self.stdin_callback(bytes([byte])) elif byte >= 32: # Printable character self.stdin_callback(bytes([byte])) else: - # Shouldn't reach here, but send anyway self.stdin_callback(bytes([byte])) i += 1 @@ -564,7 +596,7 @@ class _WindowsStdinReader: """Windows stdin reader using msvcrt for console input.""" def __init__(self): - self.old_mode = None + self._win_saved_console_mode: Optional[int] = None def set_raw_mode(self): """Set console to raw mode on Windows.""" @@ -575,24 +607,19 @@ def set_raw_mode(self): sys.stdout.flush() sys.stderr.flush() - # Windows console is already suitable for getch-style reading - # No explicit raw mode needed for msvcrt, but we still flush and delay - # to prevent visual glitches when entering CLI mode - # Small delay to allow console to process any pending output - # This helps prevent visual glitches where lines appear to be deleted time.sleep(0.01) # 10ms delay - # Flush again after the delay sys.stdout.flush() sys.stderr.flush() + # Ctrl+C as input (not SIGINT) so CtrlCCoordinator can handle double-tap. + self._win_saved_console_mode = win_stdin_disable_ctrl_c_process_input() except Exception as e: logging.warning(f"Failed to set raw mode on Windows: {e}") def restore(self): - """Restore console mode.""" - # Nothing to restore for basic msvcrt usage - pass + win_stdin_restore_console_mode(self._win_saved_console_mode) + self._win_saved_console_mode = None def read(self, timeout: Optional[float] = None) -> Optional[bytes]: """ diff --git a/keepercommander/commands/pam_launch/guac_cli/win_console_input.py b/keepercommander/commands/pam_launch/guac_cli/win_console_input.py new file mode 100644 index 000000000..d2c04d05c --- /dev/null +++ b/keepercommander/commands/pam_launch/guac_cli/win_console_input.py @@ -0,0 +1,76 @@ +# _ __ +# | |/ /___ ___ _ __ ___ _ _ ® +# | ' Optional[int]: + """ + Clear ENABLE_PROCESSED_INPUT on the stdin console handle so Ctrl+C is read + as character 0x03 (ReadConsoleInput / msvcrt) instead of raising SIGINT. + + Returns the previous mode for win_stdin_restore_console_mode, or None if not + Windows, not a console, or the API failed. + """ + if sys.platform != 'win32': + return None + try: + import ctypes + from ctypes import wintypes + + kernel32 = ctypes.windll.kernel32 + h = kernel32.GetStdHandle(_STD_INPUT_HANDLE) + mode = wintypes.DWORD() + if not kernel32.GetConsoleMode(h, ctypes.byref(mode)): + return None + old = int(mode.value) + new = old & ~_ENABLE_PROCESSED_INPUT + if new == old: + return old + if not kernel32.SetConsoleMode(h, new): + logging.debug('SetConsoleMode(clear ENABLE_PROCESSED_INPUT) failed') + return None + return old + except Exception as exc: + logging.debug('win_stdin_disable_ctrl_c_process_input: %s', exc) + return None + + +def win_stdin_restore_console_mode(old_mode: Optional[int]) -> None: + """Restore stdin console mode from win_stdin_disable_ctrl_c_process_input.""" + if old_mode is None or sys.platform != 'win32': + return + try: + import ctypes + from ctypes import wintypes + + kernel32 = ctypes.windll.kernel32 + h = kernel32.GetStdHandle(_STD_INPUT_HANDLE) + if not kernel32.SetConsoleMode(h, old_mode): + logging.debug('SetConsoleMode(restore) failed') + except Exception as exc: + logging.debug('win_stdin_restore_console_mode: %s', exc) diff --git a/keepercommander/commands/pam_launch/launch.py b/keepercommander/commands/pam_launch/launch.py index 3646be303..8a6be003e 100644 --- a/keepercommander/commands/pam_launch/launch.py +++ b/keepercommander/commands/pam_launch/launch.py @@ -5,42 +5,256 @@ # |_| # # Keeper Commander -# Copyright 2024 Keeper Security Inc. +# Copyright 2026 Keeper Security Inc. # Contact: ops@keepersecurity.com # from __future__ import annotations import argparse +import os +import ipaddress import logging import re import shutil import signal -import sys import time -from typing import TYPE_CHECKING, Dict, Any, Optional +from typing import TYPE_CHECKING, Dict, Any, Optional, Tuple from keeper_secrets_manager_core.utils import url_safe_str_to_bytes -from .terminal_connection import launch_terminal_connection +from .terminal_connection import ( + _build_connect_as_payload, + _retrieve_gateway_public_key, + _get_launch_credential_uid, + launch_terminal_connection, + detect_protocol, + ALL_TERMINAL, + CONNECT_AS_MIN_VERSION, + _version_at_least, + _pam_settings_connection_port, +) from .terminal_size import get_terminal_size_pixels, is_interactive_tty from .guac_cli.stdin_handler import StdinHandler +from .guac_cli.input import InputHandler +from .guac_cli.session_input import CtrlCCoordinator, PasteOrchestrator from ..base import Command from ..tunnel.port_forward.tunnel_helpers import ( get_gateway_uid_from_record, get_config_uid_from_record, + get_tunnel_session, unregister_tunnel_session, unregister_conversation_key, ) +from .rust_log_filter import ( + enter_pam_launch_terminal_rust_logging, + exit_pam_launch_terminal_rust_logging, +) from ..pam.gateway_helper import get_all_gateways from ..pam.router_helper import router_get_connected_gateways +from ..ssh_agent import try_extract_private_key from ... import api, vault from ...subfolder import try_resolve_path from ...error import CommandError +from ...utils import value_to_boolean if TYPE_CHECKING: from ...params import KeeperParams +def _pam_connection_clipboard_bool(v: Any) -> bool: + """True if a PAM connection clipboard flag is enabled; coerces JSON/string booleans.""" + if isinstance(v, bool): + return v + if v is None: + return False + b = value_to_boolean(v) + return b is True + + +def _pam_connection_font_size_int(raw: Any) -> Optional[int]: + """Parse pamSettings.connection.fontSize to int, or None if unset or not parseable as an integer size.""" + if raw is None: + return None + if isinstance(raw, bool): + return None + if isinstance(raw, int): + return raw + if isinstance(raw, float): + return int(raw) if raw.is_integer() else None + if isinstance(raw, str): + s = raw.strip() + if not s: + return None + try: + return int(s) + except ValueError: + return None + try: + return int(raw) + except (TypeError, ValueError): + return None + + +def _parse_host_port(value: str) -> Tuple[str, int]: + """ + Parse a 'host:port' or '[ipv6]:port' string into (host, port). + + Supported formats: + - IPv4 / hostname: 192.168.1.1:22 or server.example.com:3306 + - IPv6: [::1]:22 or [2001:db8::1]:443 + + Raises: + CommandError: if the format is invalid or the port is out of range. + """ + value = value.strip() + if value.startswith('['): + end_bracket = value.find(']') + if end_bracket == -1: + raise CommandError('pam launch', + f'Invalid host format {value!r}. Expected [ipv6]:port (e.g. [::1]:22).') + host = value[1:end_bracket] + rest = value[end_bracket + 1:] + if not rest.startswith(':'): + raise CommandError('pam launch', + f'Invalid host format {value!r}. Expected [ipv6]:port (e.g. [::1]:22).') + port_str = rest[1:] + elif ':' in value: + last_colon = value.rfind(':') + host = value[:last_colon] + port_str = value[last_colon + 1:] + else: + raise CommandError('pam launch', + f'Invalid host format {value!r}. Expected host:port (e.g. 192.168.1.1:22 or server.example.com:3306).') + try: + port = int(port_str) + except ValueError: + raise CommandError('pam launch', + f'Invalid port {port_str!r} in {value!r}. Port must be an integer 1-65535.') + _validate_host_port(host, port) + return host, port + + +def _validate_host_port(host: str, port: int) -> None: + """ + Validate host (non-empty, valid IPv4/IPv6 or hostname) and port (1-65535). + Raises CommandError if invalid. + """ + if not host: + raise CommandError('pam launch', 'Host cannot be empty.') + if not (1 <= port <= 65535): + raise CommandError('pam launch', f'Port {port} is out of range (valid range: 1-65535).') + # Attempt strict IP validation; if it raises ValueError the host is treated as a hostname + # (any non-empty hostname string is accepted — the gateway does the DNS resolution). + try: + ipaddress.ip_address(host) + except ValueError: + pass # Not an IP literal — treat as hostname, basic non-empty check above is sufficient + + +def _iter_record_fields(record: Any): + """Yield every TypedField from both record.fields and record.custom.""" + for field in list(getattr(record, 'fields', None) or []) + list(getattr(record, 'custom', None) or []): + yield field + + +def _get_host_port_from_record(record: Any) -> Tuple[Optional[str], Optional[int]]: + """ + Extract (hostName, port) from a record's pamHostname or host typed fields. + + Requires a non-empty hostName on exactly one such field. Port comes from + pamSettings.connection.port when the record is pamMachine/pamDirectory/pamDatabase + and that port is set (overrides the field's port); otherwise from the field's port. + + Raises CommandError if more than one qualifying host field is found (ambiguous). + + Returns: + Tuple of (host, port) where either may be None if none found. + """ + if not record: + return None, None + + pam_override_port = _pam_settings_connection_port(record) + candidates: list = [] + for field in _iter_record_fields(record): + if getattr(field, 'type', None) not in ('pamHostname', 'host'): + continue + value = field.get_default_value(dict) if hasattr(field, 'get_default_value') else {} + if not isinstance(value, dict): + continue + host = (value.get('hostName') or '').strip() + if not host: + continue + port_raw = pam_override_port if pam_override_port is not None else value.get('port') + if not port_raw: + continue + try: + p = int(port_raw) + except (ValueError, TypeError): + continue + if 1 <= p <= 65535: + candidates.append((host, p)) + + if len(candidates) > 1: + raise CommandError('pam launch', + f'Record has {len(candidates)} non-empty host/pamHostname fields with valid host and port ' + '(expected exactly one). Clear the extra field before launching.') + if not candidates: + return None, None + return candidates[0] + + +def _record_has_credentials(record: Any, params: Optional['KeeperParams'] = None) -> bool: + """ + Return True if the record has exactly one non-empty login field and at least one of: + - exactly one non-empty password field (fields[] and custom[]), or + - a usable SSH private key (same discovery as the launch path: keyPair, notes, custom fields, + attachments), when ``params`` is given so attachments can be resolved. + + Raises CommandError if multiple non-empty login or password fields are found (ambiguous). + """ + if not record: + return False + + def _count_nonempty(field_type: str) -> int: + count = 0 + for field in _iter_record_fields(record): + if getattr(field, 'type', None) == field_type: + val = field.get_default_value(str) if hasattr(field, 'get_default_value') else '' + if val: + count += 1 + return count + + login_count = _count_nonempty('login') + if login_count > 1: + raise CommandError('pam launch', + f'Record has {login_count} non-empty login fields (expected exactly one). ' + 'Clear the extra login field before launching.') + if login_count == 0: + return False + + password_count = _count_nonempty('password') + if password_count > 1: + raise CommandError('pam launch', + f'Record has {password_count} non-empty password fields (expected exactly one). ' + 'Clear the extra password field before launching.') + if password_count == 1: + return True + + # No password: SSH (and similar) may authenticate with a private key only. + if password_count == 0 and params is not None: + key_result = try_extract_private_key(params, record) + if key_result and key_result[0]: + return True + + return False + + +def _record_has_host_port(record: Any) -> bool: + """Return True if the record has exactly one non-empty host/pamHostname field with valid host and port.""" + host, port = _get_host_port_from_record(record) + return bool(host) and port is not None + + class PAMLaunchCommand(Command): """PAM Launch command to launch a connection to a PAM resource""" @@ -53,12 +267,18 @@ class PAMLaunchCommand(Command): parser.add_argument('--no-trickle-ice', '-nti', required=False, dest='no_trickle_ice', action='store_true', help='Disable trickle ICE for WebRTC connections. By default, trickle ICE is enabled ' 'for real-time candidate exchange.') - # parser.add_argument('--user', '-u', required=False, dest='launch_credential_uid', type=str, - # help='UID of pamUser record to use as launch credentials when allowSupplyUser is enabled. ' - # 'Fails if allowSupplyUser is not enabled or the specified record is not found.') - # parser.add_argument('--host', '-H', required=False, dest='custom_host', type=str, - # help='Hostname or IP address to connect to when allowSupplyHost is enabled. ' - # 'Fails if allowSupplyHost is not enabled.') + parser.add_argument('--credential', '-cr', required=False, dest='launch_credential', type=str, + help='Record (UID, path, or title) for launch credentials') + parser.add_argument('--host', '-H', required=False, dest='custom_host', type=str, + help='Host and port in format host:port (e.g. -H=192.168.1.1:22 or -H=[::1]:22 for IPv6). ' + 'Requires allowSupplyHost. Mutually exclusive with --host-record.') + parser.add_argument('--host-record', '-hr', required=False, dest='host_record', type=str, + help='Record (UID, path, or title) with a host or pamHostname field containing hostName and port. ' + 'Requires allowSupplyHost. Mutually exclusive with --host.') + parser.add_argument('--stdin', required=False, dest='use_stdin', action='store_true', + help='Send typed input via stdin pipe bytes (pipe/blob/end, kcm-cli style) instead of ' + 'the default Guacamole key-event mode. Paste and Ctrl+C double-tap behave the ' + 'same in both modes.') def get_parser(self): return PAMLaunchCommand.parser @@ -95,9 +315,6 @@ def find_record(self, params: KeeperParams, record_token: str) -> Optional[str]: Returns: Record UID if found, None otherwise - - Raises: - CommandError: If multiple records match """ if not record_token: return None @@ -108,13 +325,8 @@ def find_record(self, params: KeeperParams, record_token: str) -> Optional[str]: uid_pattern = re.compile(r'^[A-Za-z0-9_-]{22}$') if uid_pattern.match(record_token): if record_token in params.record_cache: - # Validate it's a PAM record type - if self._is_valid_pam_record(params, record_token): - logging.debug(f"Found record by UID: {record_token}") - return record_token - else: - logging.debug(f"Record {record_token} found but is not a valid PAM record type") - return None + logging.debug(f"Found record by UID: {record_token}") + return record_token # Step 2: Try path lookup record_uid = self._find_by_path(params, record_token) @@ -132,15 +344,12 @@ def _find_by_path(self, params: KeeperParams, path: str) -> Optional[str]: """ Find record by path resolution. - Args: - params: KeeperParams instance - path: Path to the record + If exactly one record matches (any type), returns its UID. If two or more + match, filters to PAM types only: returns the single PAM UID if one, + else logs error (no PAM types vs multiple PAM matches) and returns None. Returns: Record UID if found, None otherwise - - Raises: - CommandError: If multiple records match """ rs = try_resolve_path(params, path) if rs is None: @@ -154,21 +363,33 @@ def _find_by_path(self, params: KeeperParams, path: str) -> Optional[str]: if folder_uid not in params.subfolder_record_cache: return None - # Find all records in the folder with matching title (only valid PAM types) - matched_uids = [] + # All records in folder with matching title (any type) + all_matched = [] for uid in params.subfolder_record_cache[folder_uid]: r = api.get_record(params, uid) if r and r.title and r.title.lower() == name.lower(): - # Only include valid PAM record types - if self._is_valid_pam_record(params, uid): - matched_uids.append(uid) - - if len(matched_uids) > 1: - raise CommandError('pam launch', f'Multiple valid PAM records found with path "{path}". Please use a unique identifier.') - - if matched_uids: - logging.debug(f"Found record by path: {path} -> {matched_uids[0]}") - return matched_uids[0] + all_matched.append(uid) + + if len(all_matched) == 1: + logging.debug(f"Found record by path: {path} -> {all_matched[0]}") + return all_matched[0] + + if len(all_matched) >= 2: + pam_matched = [uid for uid in all_matched if self._is_valid_pam_record(params, uid)] + if len(pam_matched) == 1: + logging.debug(f"Found record by path: {path} -> {pam_matched[0]} (1 PAM among {len(all_matched)} matches)") + return pam_matched[0] + if len(pam_matched) == 0: + logging.error( + 'pam launch: path "%s" matches %d record(s) but none are PAM types (pamMachine, pamDirectory, pamDatabase). Use UID or a path that resolves to a single PAM record.', + path, len(all_matched), + ) + return None + logging.error( + 'pam launch: path "%s" matches %d PAM records. Please use a unique identifier (UID or full path).', + path, len(pam_matched), + ) + return None return None @@ -176,30 +397,39 @@ def _find_by_title(self, params: KeeperParams, title: str) -> Optional[str]: """ Find record by exact title match. - Args: - params: KeeperParams instance - title: Title to match + If exactly one record matches (any type), returns its UID. If two or more + match, filters to PAM types only: returns the single PAM UID if one, + else logs error (no PAM types vs multiple PAM matches) and returns None. Returns: Record UID if found, None otherwise - - Raises: - CommandError: If multiple records match """ - matched_uids = [] + all_matched = [] for record_uid in params.record_cache: record = vault.KeeperRecord.load(params, record_uid) if record and record.title and record.title.lower() == title.lower(): - # Only include valid PAM record types - if self._is_valid_pam_record(params, record_uid): - matched_uids.append(record_uid) - - if len(matched_uids) > 1: - raise CommandError('pam launch', f'Multiple valid PAM records found with title "{title}". Please use a unique identifier (UID or full path).') - - if matched_uids: - logging.debug(f"Found record by title: {title} -> {matched_uids[0]}") - return matched_uids[0] + all_matched.append(record_uid) + + if len(all_matched) == 1: + logging.debug(f"Found record by title: {title} -> {all_matched[0]}") + return all_matched[0] + + if len(all_matched) >= 2: + pam_matched = [uid for uid in all_matched if self._is_valid_pam_record(params, uid)] + if len(pam_matched) == 1: + logging.debug(f"Found record by title: {title} -> {pam_matched[0]} (1 PAM among {len(all_matched)} matches)") + return pam_matched[0] + if len(pam_matched) == 0: + logging.error( + 'pam launch: title "%s" matches %d record(s) but none are PAM types (pamMachine, pamDirectory, pamDatabase). Use UID or full path.', + title, len(all_matched), + ) + return None + logging.error( + 'pam launch: title "%s" matches %d PAM records. Please use a unique identifier (UID or full path).', + title, len(pam_matched), + ) + return None return None @@ -270,6 +500,12 @@ def execute(self, params: KeeperParams, **kwargs): root_logger.setLevel(logging.ERROR) try: + # TODO: Add JIT - note that allowSupplyHost overrides all other supply modes. + # When a PAM record has allowSupplyHost, allowSupplyUser, and JIT settings all enabled, + # the Web Vault (and this CLI) treat allowSupplyHost as the active mode and ignore the + # other two. Any validation logic below must reflect this precedence: if allowSupplyHost + # is True, treat the record as "host+credential supply" mode regardless of the other flags. + record_token = kwargs.get('record') if not record_token: @@ -284,94 +520,263 @@ def execute(self, params: KeeperParams, **kwargs): logging.debug(f"Found record: {record_uid}") - # Validate --user and --host parameters against allowSupply flags - # Note: cmdline options override record data when provided - # launch_credential_uid = kwargs.get('launch_credential_uid') - # custom_host = kwargs.get('custom_host') - - # Load record to check allowSupply flags and existing values - # record = vault.KeeperRecord.load(params, record_uid) - # if not isinstance(record, vault.TypedRecord): - # raise CommandError('pam launch', f'Record {record_uid} is not a TypedRecord') - - # pam_settings_field = record.get_typed_field('pamSettings') - # allow_supply_user = False - # allow_supply_host = False - # user_records_on_record = [] - # hostname_on_record = None - - # Get hostname from record - # hostname_field = record.get_typed_field('pamHostname') - # if hostname_field: - # host_value = hostname_field.get_default_value(dict) - # if host_value: - # hostname_on_record = host_value.get('hostName') - - # if pam_settings_field: - # pam_settings_value = pam_settings_field.get_default_value(dict) - # if pam_settings_value: - # # allowSupplyHost is at top level of pamSettings value - # allow_supply_host = pam_settings_value.get('allowSupplyHost', False) - # # allowSupplyUser is inside connection - # connection = pam_settings_value.get('connection', {}) - # if isinstance(connection, dict): - # allow_supply_user = connection.get('allowSupplyUser', False) - # user_records_on_record = connection.get('userRecords', []) - - # Validation based on allowSupply flags - # if allow_supply_host and allow_supply_user: - # # Both flags true: --user is required (no fallback to userRecords) - # if not launch_credential_uid: - # raise CommandError('pam launch', - # f'Both allowSupplyUser and allowSupplyHost are enabled. ' - # f'You must provide --user to specify launch credentials.') - # # --host required if no hostname on record - # if not custom_host and not hostname_on_record: - # raise CommandError('pam launch', - # f'Both allowSupplyUser and allowSupplyHost are enabled and no hostname on record. ' - # f'You must provide --host to specify the target host.') - - # elif allow_supply_user and not allow_supply_host: - # # Only allowSupplyUser: use --user if provided, else userRecords, else error - # if not launch_credential_uid and not user_records_on_record: - # raise CommandError('pam launch', - # f'allowSupplyUser is enabled but no credentials available. ' - # f'Use --user to specify a pamUser record or configure userRecords on the record.') - - # elif allow_supply_host and not allow_supply_user: - # # Only allowSupplyHost: --host required if no hostname on record - # if not custom_host and not hostname_on_record: - # raise CommandError('pam launch', - # f'allowSupplyHost is enabled but no hostname available. ' - # f'Use --host to specify the target host or configure hostname on the record.') - - # Validate --user parameter if provided - # if launch_credential_uid: - # if not allow_supply_user: - # raise CommandError('pam launch', - # f'--user parameter requires allowSupplyUser to be enabled on the record. ' - # f'allowSupplyUser is currently disabled for record {record_uid}.') - - # # Validate the launch credential record exists and is a pamUser - # cred_record = vault.KeeperRecord.load(params, launch_credential_uid) - # if not cred_record: - # raise CommandError('pam launch', - # f'Launch credential record not found: {launch_credential_uid}') - # if not isinstance(cred_record, vault.TypedRecord) or cred_record.record_type != 'pamUser': - # raise CommandError('pam launch', - # f'Launch credential record {launch_credential_uid} must be a pamUser record. ' - # f'Found: {cred_record.record_type if isinstance(cred_record, vault.TypedRecord) else "non-typed"}') - - # logging.debug(f"Using custom launch credential: {launch_credential_uid}") - - # Validate --host parameter if provided - # if custom_host: - # if not allow_supply_host: - # raise CommandError('pam launch', - # f'--host parameter requires allowSupplyHost to be enabled on the record. ' - # f'allowSupplyHost is currently disabled for record {record_uid}.') - - # logging.debug(f"Using custom host: {custom_host}") + record = vault.KeeperRecord.load(params, record_uid) + if not isinstance(record, vault.TypedRecord): + raise CommandError('pam launch', f'Record {record_uid} is not a TypedRecord') + + if not self._is_valid_pam_record(params, record_uid): + record_type = getattr(record, 'record_type', type(record).__name__) + raise CommandError('pam launch',f'Record {record_uid} of type "{record_type}" is not a machine record type (pamMachine, pamDirectory, pamDatabase)') + + # Only terminal protocols are supported (SSH, Telnet, Kubernetes, databases). + protocol = detect_protocol(params, record_uid) + if protocol not in ALL_TERMINAL: + logging.error( + "pam launch only supports terminal protocols (ssh, telnet, kubernetes, mysql, postgresql, sql-server). " + "Protocol %r is not supported; use Web Vault for RDP/VNC/RBI etc.", + protocol, + ) + return + + # Get DAG-linked credential UID early (needed for comparison and validation) + dag_linked_uid = _get_launch_credential_uid(params, record_uid) + if not dag_linked_uid: + # Fallback: first entry in pamSettings.connection.userRecords + _psf = record.get_typed_field('pamSettings') + if _psf: + _psv = _psf.get_default_value(dict) + if _psv: + _conn = _psv.get('connection', {}) + if isinstance(_conn, dict): + _ur = _conn.get('userRecords', []) + if _ur: + dag_linked_uid = _ur[0] + + # Read allowSupply flags and other connection settings from pamSettings + pam_settings_field = record.get_typed_field('pamSettings') + allow_supply_user = False + allow_supply_host = False + pam_connection_font_size: Any = None + if pam_settings_field: + pam_settings_value = pam_settings_field.get_default_value(dict) + if pam_settings_value: + allow_supply_host = pam_settings_value.get('allowSupplyHost', False) + connection = pam_settings_value.get('connection', {}) + if isinstance(connection, dict): + allow_supply_user = connection.get('allowSupplyUser', False) + pam_connection_font_size = connection.get('fontSize') + if _pam_connection_clipboard_bool(connection.get('readOnly')): + raise CommandError( + 'pam launch', + f'Record {record_uid} has connection.readOnly:true. That setting is only ' + 'meaningful when joining an existing session; starting a new session from ' + 'Commander is not a read-only join flow. Set readOnly:false on the record ' + 'for CLI launch, or use Web Vault when joining a session read-only.', + ) + if connection.get('disableCopy'): + raise CommandError( + 'pam launch', + f'Record {record_uid} has disableCopy:true and KCM blocks terminal STDOUT ' + 'in that configuration, so Commander cannot run a CLI session. ' + 'Use Web Vault or another client, or set disableCopy:false on the record ' + 'if CLI terminal access is required.', + ) + if _pam_connection_clipboard_bool(connection.get('disablePaste')): + raise CommandError( + 'pam launch', + f'Record {record_uid} has disablePaste:true and a terminal session cannot ' + 'reliably distinguish pasted text from normal keystrokes, so Commander ' + 'cannot run a CLI session. Use Web Vault or another client, or set ' + 'disablePaste:false on the record if CLI terminal access is required.', + ) + # If we want to ignore disablePaste - uncomment below (still need --stdin check) + # disablePaste is incompatible with --stdin (pipe mode) + # if connection.get('disablePaste') and kwargs.get('use_stdin'): + # raise CommandError( + # 'pam launch', + # f'Record {record_uid} has disablePaste:true and KCM does not expose the ' + # 'STDIN pipe that --stdin requires. Run pam launch without --stdin to use ' + # 'default key-event driven input mode, or set disablePaste:false on the record.', + # ) + # Do not remove: + # There's no reliable way to detect pasted text from typed input alone + # so either disable this command or ignore disablePaste i.e. allow paste. + # - Paste is impossible to detect/block in CLI mode - all characters are received/read as console input. + # - Capturing raw key events is platform dependent and usually require admin access. + # - Python modules like keyboard/pyinput capture global keys (not per window/terminal) + # and often need admin/root privileges to work (driver level event capture). + # - Burst detection is unreliable and eating typed characters during auto-repeat/fast typing. + + # Get record host/port for fallback validation + hostname_on_record, port_on_record = _get_host_port_from_record(record) + + # --- Resolve --credential option --- + launch_credential = kwargs.get('launch_credential') + launch_credential_uid = None + if launch_credential: + # Reject early — before record resolution — when neither supply flag permits it. + # (With host options the flag requirement is checked later; here we only gate the + # case where -cr alone requires at least one supply flag to be meaningful.) + if not allow_supply_user and not allow_supply_host: + raise CommandError('pam launch', + '--credential requires allowSupplyUser or allowSupplyHost to be enabled on the record.') + launch_credential_uid = self.find_record(params, launch_credential) + if not launch_credential_uid: + raise CommandError('pam launch', f'Credential record not found: {launch_credential}') + + # --- Parse --host / --host-record (mutually exclusive) --- + raw_custom_host = kwargs.get('custom_host') + host_record_token = kwargs.get('host_record') + custom_host = None + custom_port = None + + # All -H/-hr checks happen BEFORE any record resolution to give the right error first. + + # -H and -hr are mutually exclusive (conflicting options prevent execution). + if raw_custom_host and host_record_token: + raise CommandError('pam launch', + 'Cannot use both --host and --host-record. Use one to specify the target host.') + + # Options conflict: -H/-hr require -cr (Web Vault: host and credentials supplied together). + if (raw_custom_host or host_record_token) and not launch_credential: + raise CommandError('pam launch', + '--host / --host-record requires --credential (-cr) to also be provided. ' + 'When allowSupplyHost is enabled, credentials and host must be supplied together.') + + # allowSupplyHost must be enabled to use -H/-hr at all. + if (raw_custom_host or host_record_token) and not allow_supply_host: + raise CommandError('pam launch', + '--host / --host-record requires allowSupplyHost to be enabled on the record. ' + '(Web Vault: Record > Allow shared users to select their own host and credential)') + + if raw_custom_host: + custom_host, custom_port = _parse_host_port(raw_custom_host) + kwargs['custom_host'] = custom_host + kwargs['custom_port'] = custom_port + logging.debug(f"Parsed --host: {custom_host}:{custom_port}") + + if host_record_token: + host_record_uid = self.find_record(params, host_record_token) + if not host_record_uid: + raise CommandError('pam launch', f'Host record not found: {host_record_token}') + host_record = vault.KeeperRecord.load(params, host_record_uid) + if not host_record: + raise CommandError('pam launch', f'Could not load host record: {host_record_uid}') + custom_host, custom_port = _get_host_port_from_record(host_record) + if not custom_host: + raise CommandError('pam launch', + f'Record {host_record_token} has no hostname. ' + 'It must have a host or pamHostname field with hostName.') + if custom_port is None: + raise CommandError('pam launch', + f'Record {host_record_token} has no valid port (1-65535). ' + 'It must have a host or pamHostname field with a port.') + kwargs['custom_host'] = custom_host + kwargs['custom_port'] = custom_port + logging.debug(f"Using host from record {host_record_uid}: {custom_host}:{custom_port}") + + has_cli_host = custom_host is not None + has_cli_cred = launch_credential_uid is not None + + # --credential record with no host options that matches DAG-linked -> treat as no --credential + if has_cli_cred and not has_cli_host and launch_credential_uid == dag_linked_uid: + logging.warning( + '--credential %s matches linked Launch Credential; treating as if no --credential provided', + launch_credential, + ) + launch_credential_uid = None + has_cli_cred = False + + # --host / --host-record require allowSupplyHost + if has_cli_host and not allow_supply_host: + raise CommandError('pam launch', + '--host / --host-record requires allowSupplyHost to be enabled on the record. ' + '(Web Vault: Record > Allow shared users to select their own host and credential)') + + if has_cli_cred: + # with host options -> allowSupplyHost; without -> allowSupplyUser or allowSupplyHost + if has_cli_host: + if not allow_supply_host: + raise CommandError('pam launch', + '--credential with --host/--host-record requires allowSupplyHost to be enabled.') + else: + if not allow_supply_user and not allow_supply_host: + raise CommandError('pam launch', + '--credential requires allowSupplyUser or allowSupplyHost to be enabled on the record.') + + # Strictly validate --credential record has login and password + cred_record = vault.KeeperRecord.load(params, launch_credential_uid) + if not cred_record: + raise CommandError('pam launch', f'Credential record not found: {launch_credential_uid}') + if not _record_has_credentials(cred_record, params): + raise CommandError('pam launch', + f'Credential record {launch_credential_uid} must have non-empty login and ' + 'password, or login with an SSH private key.') + + if allow_supply_host: + # allowSupplyHost mode: host comes from -H/-hr (CLI) or from the --credential record. + if has_cli_host: + # -H/-hr provided: CLI host wins. Warn if --credential also has a host. + if _record_has_host_port(cred_record): + _cr_host, _ = _get_host_port_from_record(cred_record) + logging.warning( + '--host / --host-record (%s:%s) overrides host %r from --credential record %s; ' + 'the credential record host will be ignored.', + custom_host, custom_port, _cr_host, launch_credential_uid, + ) + else: + # no -H/-hr -> --credential record must supply host:port. + if not _record_has_host_port(cred_record): + raise CommandError('pam launch', + f'Credential record {launch_credential_uid} must have a non-empty host and port ' + 'when allowSupplyHost is enabled and no --host or --host-record is provided.') + cred_host, cred_port = _get_host_port_from_record(cred_record) + custom_host = cred_host + custom_port = cred_port + kwargs['custom_host'] = custom_host + kwargs['custom_port'] = custom_port + logging.debug(f"Using host from --credential record: {custom_host}:{custom_port}") + + else: + # allowSupplyUser mode: only login + password come from --credential. + # Any host/pamHostname on the --credential record is intentionally ignored; + # host and port always come from the PAM machine/connection record. + if _record_has_host_port(cred_record): + _cr_host, _ = _get_host_port_from_record(cred_record) + logging.warning( + 'allowSupplyUser mode: host %r in --credential record %s is ignored; ' + 'host and port will come from the PAM machine record.', + _cr_host, launch_credential_uid, + ) + + kwargs['launch_credential_uid'] = launch_credential_uid + logging.debug(f"Using --credential: {launch_credential_uid}") + + else: + # No --credential: validate that the record itself provides what's needed + if not has_cli_host: + # No CLI host -> must come from the PAM launch record + if not hostname_on_record: + if allow_supply_host: + raise CommandError('pam launch', + 'allowSupplyHost is enabled but no hostname on record. ' + 'Use --host, --host-record, or --credential with a host:port to specify.') + else: + raise CommandError('pam launch', + f'No hostname configured for record {record_uid}.') + + # No CLI options at all -> validate DAG-linked credential has login + password or SSH key + if dag_linked_uid: + dag_cred_record = vault.KeeperRecord.load(params, dag_linked_uid) + if dag_cred_record and not _record_has_credentials(dag_cred_record, params): + raise CommandError('pam launch', + f'Linked credential record {dag_linked_uid} has no usable auth ' + '(need login and password, or login and SSH private key). ' + 'Configure valid credentials or use --credential to override.') + elif not allow_supply_user and not allow_supply_host: + raise CommandError('pam launch', + f'No credentials configured for record {record_uid}. ' + 'Configure a linked credential or enable allowSupplyUser/allowSupplyHost.') # Find the gateway for this record gateway_info = self.find_gateway(params, record_uid) @@ -389,26 +794,68 @@ def execute(self, params: KeeperParams, **kwargs): connected_gateway_uids = [x.controllerUid for x in connected_gateways.controllers] gateway_uid_bytes = url_safe_str_to_bytes(gateway_info['gateway_uid']) if gateway_uid_bytes not in connected_gateway_uids: - logging.warning( + # Root logger is ERROR when not DEBUG; use logging.error so this is visible. + logging.error( 'Gateway "%s" (%s) seems offline - trying to connect anyway.', - gateway_info['gateway_name'], gateway_info['gateway_uid'] + gateway_info['gateway_name'], + gateway_info['gateway_uid'], ) else: - logging.debug(f"✓ Gateway is online and connected") + logging.debug("✓ Gateway is online and connected") else: - logging.warning('Gateway seems offline - trying to connect anyway.') + logging.error('Gateway seems offline - trying to connect anyway.') except Exception as e: logging.debug('Could not verify gateway status: %s. Continuing...', e) + if pam_connection_font_size is not None and str(pam_connection_font_size).strip() != '': + fs_int = _pam_connection_font_size_int(pam_connection_font_size) + if fs_int != 12: + fs_disp = fs_disp = str(fs_int) if fs_int is not None else str(pam_connection_font_size).strip() + logging.warning( + 'Record %s sets connection.fontSize=%s (guacd default is 12); session recordings ' + 'may look different from this Commander terminal session.', + record_uid, + fs_disp, + ) + print( + f'Warning: This record sets fontSize={fs_disp}; session recordings may look ' + 'different from this Commander terminal session.', + ) + # Launch terminal connection result = launch_terminal_connection(params, record_uid, gateway_info, **kwargs) if result.get('success'): - logging.debug(f"Terminal connection launched successfully") + logging.debug("Terminal connection launched successfully") logging.debug(f"Protocol: {result.get('protocol')}") + # Warn early: clipboard policy + gateway risk before the terminal session starts. + _clip = result.get('settings', {}).get('clipboard', {}) + if _clip.get('disablePaste') or _clip.get('disableCopy'): + print('Warning: This record disables clipboard copy and/or paste in PAM settings.') + logging.debug( + '\nWarning: This record disables clipboard copy and/or paste in PAM. ' + 'Commander enforces that in the client; guacd also receives disable-paste/' + 'disable-copy when the record requires it. Some gateways never emit a ' + 'Guacamole terminal `pipe` in that configuration — `pam launch` may show no ' + 'shell output. Workarounds: use Web Vault for this machine, temporarily allow ' + 'clipboard on the record for CLI (enable-pipe is still requested in the offer).\n' + ) + if _clip.get('disablePaste'): + logging.debug( + 'disablePaste: local clipboard paste is off; paste chords send key events ' + '(remote session clipboard / TTY paste).\n' + ) + if _clip.get('disableCopy'): + logging.debug('Copy is disabled (disableCopy): remote text will not be placed on the local clipboard.\n') + # Always start interactive CLI session - self._start_cli_session(result, params) + # Pass launch_credential_uid to know if ConnectAs payload is needed + self._start_cli_session( + result, params, + kwargs.get('launch_credential_uid'), + use_stdin=kwargs.get('use_stdin', False), + ) else: error_msg = result.get('error', 'Unknown error') raise CommandError('pam launch', f'Failed to launch connection: {error_msg}') @@ -416,7 +863,13 @@ def execute(self, params: KeeperParams, **kwargs): # Restore original root logger level root_logger.setLevel(original_level) - def _start_cli_session(self, tunnel_result: Dict[str, Any], params: KeeperParams): + def _start_cli_session( + self, + tunnel_result: Dict[str, Any], + params: KeeperParams, + launch_credential_uid: Optional[str] = None, + use_stdin: bool = False, + ): """ Start CLI session using PythonHandler protocol mode. @@ -434,10 +887,38 @@ def _start_cli_session(self, tunnel_result: Dict[str, Any], params: KeeperParams 4. Python responds with 'connect', 'size', 'audio', 'image' 5. guacd sends 'ready', terminal session begins + Input modes + ----------- + Default (key-event): + InputHandler maps every keystroke to a Guacamole `key` instruction + (press + release), matching Web Vault behaviour. + --stdin (pipe mode): + StdinHandler sends typed bytes via the pipe/blob/end STDIN stream, + matching kcm-cli behaviour. + Both modes share the same paste (Ctrl+V / Shift+Insert → Vault clipboard + stream) and Ctrl+C double-tap (400 ms → local exit) logic. + Args: tunnel_result: Result from launch_terminal_connection params: KeeperParams instance + launch_credential_uid: Optional UID resolved from CLI --credential; + triggers ConnectAs payload when set. + use_stdin: When True use StdinHandler (pipe/byte mode) instead of + the default InputHandler (key-event mode). """ + import sys as _sys + + # Non-interactive stdin guard: key-event mode requires a real TTY. + # --stdin (pipe mode) is fine with redirected stdin, but key mode is not — + # tty.setraw() will raise and character-at-a-time mapping makes no sense + # for piped/scripted input. + if not use_stdin and not _sys.stdin.isatty(): + raise CommandError( + 'pam launch', + 'Interactive (key-event) mode requires a TTY. ' + 'stdin is not a terminal — scripted/piped input is not supported. ' + 'If you need to drive pam launch non-interactively use --stdin.', + ) shutdown_requested = False def signal_handler_fn(signum, frame): @@ -447,7 +928,9 @@ def signal_handler_fn(signum, frame): original_handler = signal.signal(signal.SIGINT, signal_handler_fn) + rust_log_token = None try: + rust_log_token = enter_pam_launch_terminal_rust_logging() tube_id = tunnel_result['tunnel'].get('tube_id') if not tube_id: raise CommandError('pam launch', 'No tube ID in tunnel result') @@ -499,21 +982,75 @@ def signal_handler_fn(signum, frame): if not connected: raise CommandError('pam launch', "WebRTC connection not established within timeout") - # Wait a brief moment for DataChannel to be ready after connection state becomes "connected" - # The connection state can be "connected" before the DataChannel is actually ready to send data - time.sleep(0.2) + # Wait for DataChannel to be ready and Gateway to wire the session. + # connection state "connected" can precede DataChannel readiness; Gateway also needs + # time to associate the WebRTC connection with the channel and prepare guacd. + # Configurable via PAM_OPEN_CONNECTION_DELAY (default 0.2s; use 2.0 if handshake never starts). + open_conn_delay = float(os.environ.get('PAM_OPEN_CONNECTION_DELAY', '0.2')) + time.sleep(open_conn_delay) # Send OpenConnection to Gateway to initiate guacd session # This is critical - without it, Gateway doesn't start guacd and no Guacamole traffic flows # Retry with exponential backoff if DataChannel isn't ready yet logging.debug(f"Sending OpenConnection to Gateway (conn_no=1, conversation_id={conversation_id})") + + # Build ConnectAs payload when cliUserOverride is set — this covers both: + # (a) explicit -cr that differs from DAG-linked, and + # (b) implicit userRecords[0] fallback (no DAG link, allowSupply* enabled, no -cr given). + # In case (b) launch_credential_uid is None; use userRecordUid from settings instead. + connect_as_payload = None + gateway_uid = tunnel_result['tunnel'].get('gateway_uid') + _tunnel_settings = tunnel_result.get('settings', {}) + cli_user_override = _tunnel_settings.get('cliUserOverride', False) + effective_credential_uid = launch_credential_uid or ( + _tunnel_settings.get('userRecordUid') if cli_user_override else None + ) + + # Remote keeper-pam-webrtc-rs version: from tunnel (non-streaming) or session (streaming) + remote_webrtc_version = tunnel_result['tunnel'].get('remote_webrtc_version') + if remote_webrtc_version is None: + sess = get_tunnel_session(tube_id) + remote_webrtc_version = getattr(sess, 'remote_webrtc_version', None) if sess else None + + connect_as_supported = _version_at_least(remote_webrtc_version, CONNECT_AS_MIN_VERSION) + + if cli_user_override and effective_credential_uid and gateway_uid: + # When using userRecords[0] fallback, include explanation in CommandError if ConnectAs fails + connect_as_fallback_msg = '' + if launch_credential_uid is None: + connect_as_fallback_msg = ( + f'Using credential from userRecords[0] ({effective_credential_uid}) as ConnectAs fallback because ' + 'no launch credential on record; ConnectAs is enabled but no --credential was given. ' + ) + if not connect_as_supported: + raise CommandError( + 'pam launch', + connect_as_fallback_msg + + f'ConnectAs (--credential) requires Gateway with keeper-pam-webrtc-rs >= {CONNECT_AS_MIN_VERSION}. ' + f'Remote version: {remote_webrtc_version or "unknown"}. ' + 'Please upgrade the Gateway to use --credential.' + ) + logging.debug(f"Building ConnectAs payload for credential: {effective_credential_uid}") + gateway_public_key = _retrieve_gateway_public_key(params, gateway_uid) + if gateway_public_key: + connect_as_payload = _build_connect_as_payload(params, effective_credential_uid, gateway_public_key) + if connect_as_payload: + logging.debug(f"ConnectAs payload built: {len(connect_as_payload)} bytes") + else: + logging.warning("Failed to build ConnectAs payload - credentials may not be passed to gateway") + else: + logging.warning("Could not retrieve gateway public key - credentials may not be passed to gateway") + max_retries = 5 retry_delay = 0.1 last_error = None for attempt in range(max_retries): try: - tube_registry.open_handler_connection(conversation_id, 1) + # Pass ConnectAs payload when user supplied credentials via -cr (matches vault behavior) + tube_registry.open_handler_connection( + conversation_id, 1, connect_as_payload + ) logging.debug("✓ OpenConnection sent successfully") break except Exception as e: @@ -549,30 +1086,96 @@ def signal_handler_fn(signum, frame): guac_ready_timeout = 10.0 # Reduced from 30s - sync triggers readiness quickly - if python_handler.wait_for_ready(guac_ready_timeout): + guac_ready_result = python_handler.wait_for_ready(guac_ready_timeout) + if guac_ready_result: logging.debug("* Guacamole connection ready!") - logging.debug("Terminal session active. Press Ctrl+C to exit.") + logging.debug( + 'Terminal session active. Ctrl+C → remote interrupt; double Ctrl+C (<400 ms) to exit.', + ) else: logging.warning(f"Guacamole did not report ready within {guac_ready_timeout}s") logging.warning("Terminal may still work if data is flowing.") # Check for STDOUT pipe support (feature detection) # This warns the user if CLI pipe mode is not supported by the gateway - python_handler.check_stdout_pipe_support(timeout=10.0) - - # Create stdin handler for pipe/blob/end input pattern - # StdinHandler reads raw stdin and sends via send_stdin (base64-encoded) - # This matches kcm-cli's implementation for plaintext SSH/TTY streams - stdin_handler = StdinHandler( - stdin_callback=lambda data: python_handler.send_stdin(data), - key_callback=lambda keysym, pressed: python_handler.send_key(keysym, pressed) + _clipboard_pol = tunnel_result.get('settings', {}).get('clipboard', {}) + _pam_clipboard = ( + _pam_connection_clipboard_bool(_clipboard_pol.get('disablePaste')) + or _pam_connection_clipboard_bool(_clipboard_pol.get('disableCopy')) ) + python_handler.check_stdout_pipe_support( + timeout=45.0 if _pam_clipboard else 10.0, + pam_clipboard_record_policy=_pam_clipboard, + ) + + # Shared input coordinators: + # Both InputHandler (key mode) and StdinHandler (--stdin mode) use + # the same CtrlCCoordinator and PasteOrchestrator so paste and + # Ctrl+C double-tap behave identically in both modes. - # Main event loop with stdin input + def _request_exit(): + nonlocal shutdown_requested + shutdown_requested = True + + disable_paste = _pam_connection_clipboard_bool(_clipboard_pol.get('disablePaste')) + + if use_stdin: + # Pipe mode: remote interrupt = raw Ctrl+C byte via send_stdin + ctrl_c_coord = CtrlCCoordinator( + remote_interrupt_fn=lambda: python_handler.send_stdin(b'\x03'), + local_exit_fn=_request_exit, + ) + else: + # Key-event mode: remote interrupt = send_key(ETX) press+release + def _remote_key_ctrl_c() -> None: + python_handler.send_key(3, True) + python_handler.send_key(3, False) + + ctrl_c_coord = CtrlCCoordinator( + remote_interrupt_fn=_remote_key_ctrl_c, + local_exit_fn=_request_exit, + ) + + # No PasteOrchestrator when PAM disables paste — only Guacamole key chords (no pyperclip). + paste_orch: Optional[PasteOrchestrator] = None + if not disable_paste: + paste_orch = PasteOrchestrator( + send_clipboard_fn=python_handler.send_clipboard_stream, + disable_paste=False, + ) + + # Build the appropriate input handler: + if use_stdin: + input_handler = StdinHandler( + stdin_callback=lambda data: python_handler.send_stdin(data), + key_callback=lambda keysym, pressed: python_handler.send_key(keysym, pressed), + ctrl_c_coordinator=ctrl_c_coord, + paste_orchestrator=paste_orch, + ) + logging.debug('Input mode: --stdin (pipe/blob/end, StdinHandler)') + else: + input_handler = InputHandler( + key_callback=lambda keysym, pressed: python_handler.send_key(keysym, pressed), + ctrl_c_coordinator=ctrl_c_coord, + paste_orchestrator=paste_orch, + disable_paste=disable_paste, + ) + logging.debug('Input mode: key-event (InputHandler, default)') + + # Main event loop with input handler try: - # Start stdin handler (runs in background thread) - stdin_handler.start() - logging.debug("STDIN handler started") # (pipe/blob/end mode) + # Start input handler (runs in background thread) + input_handler.start() + # Ctrl+C → byte 0x03 in the input thread (CtrlCCoordinator). Windows clears + # ENABLE_PROCESSED_INPUT in the stdin readers so the key is delivered; SIG_IGN had + # blocked ReadConsoleInput/msvcrt from seeing Ctrl+C on Win11. + mode_label = '--stdin (pipe)' if use_stdin else 'key-event (default)' + logging.debug( + 'Input handler started [mode=%s]. ' + 'Ctrl+C → remote interrupt; press Ctrl+C twice quickly (<400 ms) to exit. ' + 'Paste chords → Guacamole clipboard stream, or key events when disablePaste.', + mode_label, + ) # --- Terminal resize tracking --- # Resize polling is skipped entirely in non-interactive (piped) @@ -669,12 +1272,12 @@ def signal_handler_fn(signum, frame): logging.debug("\n\nExiting CLI terminal mode...") finally: - # Stop stdin handler first (restores terminal) - logging.debug("Stopping stdin handler...") + # Stop input handler first (restores terminal) + logging.debug("Stopping input handler...") try: - stdin_handler.stop() + input_handler.stop() except Exception as e: - logging.debug(f"Error stopping stdin handler: {e}") + logging.debug(f"Error stopping input handler: {e}") # Cleanup - check if connection is already closed to avoid deadlock logging.debug("Stopping Python handler...") @@ -708,8 +1311,11 @@ def signal_handler_fn(signum, frame): logging.info("CLI session ended - cleanup complete") + except CommandError: + raise except Exception as e: logging.error(f"Error in PythonHandler CLI session: {e}") raise CommandError('pam launch', f'Failed to start CLI session: {e}') finally: + exit_pam_launch_terminal_rust_logging(rust_log_token) signal.signal(signal.SIGINT, original_handler) diff --git a/keepercommander/commands/pam_launch/python_handler.py b/keepercommander/commands/pam_launch/python_handler.py index 3294c7b76..63694418a 100644 --- a/keepercommander/commands/pam_launch/python_handler.py +++ b/keepercommander/commands/pam_launch/python_handler.py @@ -58,12 +58,11 @@ from __future__ import annotations import base64 import logging -import sys import threading from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Any from .guacamole import Parser, to_instruction -from .guac_cli.instructions import create_instruction_router +from .guac_cli.instructions import create_instruction_router, is_stdout_pipe_stream_name if TYPE_CHECKING: pass @@ -113,6 +112,7 @@ def __init__( - audio_mimetypes: List of supported audio types (optional) - image_mimetypes: List of supported image types (optional) - guacd_params: Additional guacd parameters dict (optional) + - clipboard: Optional {disableCopy, disablePaste} from PAM (optional) on_ready: Optional callback when Guacamole connection is ready on_disconnect: Optional callback when connection is closed (receives reason) """ @@ -126,6 +126,10 @@ def __init__( self.connection_settings = connection_settings or {} self.handshake_sent = False # Track if we've responded to 'args' + _clip = self.connection_settings.get('clipboard') or {} + # Block remote → local OS clipboard (PAM disableCopy); still ack Guacamole blobs. + self._clipboard_disable_copy = bool(_clip.get('disableCopy')) + # Guacamole protocol parser (using new guacamole module) self.parser = Parser() @@ -133,6 +137,11 @@ def __init__( # Server sends pipe with name "STDOUT", then blobs with base64 terminal output self.stdout_stream_index: int = -1 + # Inbound clipboard stream (guacd → client): clipboard,, then blob/end + self.remote_clipboard_stream_index: int = -1 + self._remote_clipboard_mimetype: Optional[str] = None + self._remote_clipboard_acc = bytearray() + # Feature detection for CLI pipe mode # STDOUT pipe: if the server supports plaintext SSH/TTY mode, it sends a STDOUT pipe # STDIN pipe: when we try to send input, the server should ack successfully @@ -152,6 +161,7 @@ def __init__( 'error': self._on_error, 'ack': self._on_ack, # Custom ack handler for STDIN failure detection 'pipe': self._on_pipe, # Custom pipe handler for STDOUT detection + 'clipboard': self._on_remote_clipboard_instruction, }, send_ack_callback=self._send_ack, stdout_stream_tracker=self, @@ -182,6 +192,17 @@ def __init__( self.messages_sent = 0 self.bytes_sent = 0 + # Clipboard stream counter — starts at 200 to avoid collision with + # image streams (1–99) and named pipe streams (100–101). + self._clipboard_stream_index: int = 200 + + # Count `pipe` opcodes from guacd (diagnostics when STDOUT never binds). + self._guac_pipe_instruction_count: int = 0 + + def note_guac_pipe_instruction(self) -> None: + """Incremented by the instruction router for each well-formed pipe from guacd.""" + self._guac_pipe_instruction_count += 1 + def start(self): """Start the handler.""" if self.running: @@ -374,12 +395,6 @@ def _send_handshake_response(self, args_list: List[str]): # Get guacd parameters (hostname, port, username, password, etc.) guacd_params = settings.get('guacd_params', {}) - # Debug: Log what credentials we have - logging.debug(f"DEBUG: guacd_params keys: {list(guacd_params.keys())}") - logging.debug(f"DEBUG: guacd_params['username']: {'(set)' if guacd_params.get('username') else '(empty)'}") - logging.debug(f"DEBUG: guacd_params['password']: {'(set)' if guacd_params.get('password') else '(empty)'}") - logging.debug(f"DEBUG: guacd_params['private-key']: {'(set)' if guacd_params.get('private-key') else '(empty)'}") - # Build connect args: first arg is version (from guacd), rest are param values connect_args = [] @@ -410,13 +425,6 @@ def _send_handshake_response(self, args_list: List[str]): connect_instruction = self._format_instruction('connect', *connect_args) self._send_to_gateway(connect_instruction) logging.debug(f"Sent 'connect' with {len(connect_args)} args") - # Debug: Show which args were sent (without revealing secrets) - if args_list: - for i, param_name in enumerate(args_list[1:], start=1): - value = connect_args[i] if i < len(connect_args) else "(missing)" - is_secret = param_name.lower() in ['password', 'passphrase', 'private-key'] - display_value = '(set)' if (is_secret and value) else ('(empty)' if is_secret else value[:20] if isinstance(value, str) else value) - logging.debug(f"DEBUG: connect arg '{param_name}' = {display_value}") # Send size instruction size_instruction = self._format_instruction('size', width, height, dpi) @@ -537,8 +545,8 @@ def _on_pipe(self, args: List[str]) -> None: When the server supports plaintext SSH/TTY mode, it sends a pipe with name "STDOUT". If this pipe never opens, the feature is not supported by the gateway/connection. - Note: The instruction router handles STDOUT ack and blob decode before calling this. - This handler just sets the event to signal that STDOUT pipe was opened. + Note: The instruction router handles STDOUT ack, blob decode, and stdout_pipe_opened + when the pipe is treated as terminal output (STDOUT name or text/* fallback). Args: args: [stream_index, mimetype, name] @@ -546,12 +554,6 @@ def _on_pipe(self, args: List[str]) -> None: if len(args) >= 3: stream_index, mimetype, name = args[0], args[1], args[2] logging.debug(f"[PIPE] stream={stream_index}, type={mimetype}, name={name}") - - if name == 'STDOUT': - # Signal that STDOUT pipe was opened - CLI pipe mode is supported - # Note: stream_index and ack are already handled by instruction router - self.stdout_pipe_opened.set() - logging.debug(f"STDOUT pipe opened on stream {stream_index} - CLI pipe mode supported") else: logging.debug(f"[PIPE] {args}") @@ -582,6 +584,90 @@ def _on_ack(self, args: List[str]) -> None: else: logging.debug(f"[ACK] {args}") + def _on_remote_clipboard_instruction(self, args: List[str]) -> None: + """Open inbound clipboard stream from guacd; ack and prepare for blob/end.""" + if len(args) < 2: + logging.debug(f"[CLIPBOARD] unexpected args: {args}") + return + stream, mimetype = args[0], args[1] + try: + idx = int(stream) + except ValueError: + logging.debug(f"[CLIPBOARD] bad stream index: {stream!r}") + return + self.remote_clipboard_stream_index = idx + self._remote_clipboard_mimetype = mimetype + self._remote_clipboard_acc.clear() + try: + self._send_ack(stream, 'OK', '0') + except Exception as e: + logging.error(f"Clipboard stream ack failed: {e}") + logging.debug(f"[CLIPBOARD] inbound stream={stream}, type={mimetype}") + + def handle_remote_clipboard_blob(self, stream_index: str, b64_data: str) -> bool: + """ + Called by the instruction router for blob instructions. + Returns True if this blob belonged to the active inbound clipboard stream. + """ + if self.remote_clipboard_stream_index < 0: + return False + try: + if int(stream_index) != self.remote_clipboard_stream_index: + return False + except ValueError: + return False + try: + if not self._clipboard_disable_copy: + self._remote_clipboard_acc.extend(base64.b64decode(b64_data)) + self._send_ack(stream_index, 'OK', '0') + except Exception as e: + logging.error(f"Inbound clipboard blob error: {e}") + try: + self._send_ack(stream_index, 'OK', '0') + except Exception: + pass + return True + + def handle_remote_clipboard_end(self, stream_index: str) -> bool: + """ + Called by the instruction router for end instructions. + Returns True if this ended the active inbound clipboard stream. + """ + if self.remote_clipboard_stream_index < 0: + return False + try: + if int(stream_index) != self.remote_clipboard_stream_index: + return False + except ValueError: + return False + try: + if ( + not self._clipboard_disable_copy + and self._remote_clipboard_acc + and (self._remote_clipboard_mimetype or '').lower().startswith('text/') + ): + text = self._remote_clipboard_acc.decode('utf-8', errors='replace') + try: + import pyperclip # type: ignore[import] + + pyperclip.copy(text) + except ImportError: + logging.warning( + 'Remote clipboard data received but pyperclip is not installed; ' + 'cannot copy to the local clipboard. Run: pip install pyperclip' + ) + except Exception as e: + logging.warning(f'Could not copy remote clipboard to local OS: {e}') + else: + logging.debug( + 'Remote clipboard copied to local OS (%d chars)', len(text) + ) + finally: + self.remote_clipboard_stream_index = -1 + self._remote_clipboard_mimetype = None + self._remote_clipboard_acc.clear() + return True + def _format_instruction(self, *elements) -> bytes: """Format elements into a Guacamole instruction.""" # Use the new guacamole module's to_instruction function @@ -686,7 +772,12 @@ def send_stdin(self, data: bytes): except Exception as e: logging.error(f"Error sending stdin: {e}") - def check_stdout_pipe_support(self, timeout: float = 10.0) -> bool: + def check_stdout_pipe_support( + self, + timeout: float = 10.0, + *, + pam_clipboard_record_policy: bool = False, + ) -> bool: """ Check if STDOUT pipe is supported with a timeout. @@ -696,6 +787,8 @@ def check_stdout_pipe_support(self, timeout: float = 10.0) -> bool: Args: timeout: Seconds to wait for STDOUT pipe (default 10.0) + pam_clipboard_record_policy: PAM record disables copy/paste; gateway may still + force guacd clipboard off after the Commander handshake (no STDOUT pipe). Returns: True if STDOUT pipe opened, False if timeout expired @@ -708,11 +801,25 @@ def check_stdout_pipe_support(self, timeout: float = 10.0) -> bool: f"STDOUT pipe did not open within {timeout}s. " f"CLI pipe mode may not be supported by this gateway/connection." ) + n_pipe = self._guac_pipe_instruction_count print( "\nNo STDOUT stream has been received since the connection was opened. " "This may indicate the gateway/guacd does not support CLI mode. " "You can continue waiting, or press Ctrl+C to cancel." ) + if pam_clipboard_record_policy: + if n_pipe == 0: + logging.error( + "This record disables clipboard copy or paste in PAM. Some KCM builds may " + "omit the terminal pipe entirely. Commander requests enable-pipe in the offer " + "- if pipes still never appear use Web Vault or temporarily allow " + "clipboard on the record for CLI sesions.\n" + ) + else: + print( + f"\nThis record disables PAM clipboard; guacd sent {n_pipe} pipe instruction(s) " + "but none were accepted as terminal STDOUT (unexpected).\n" + ) return False def is_stdin_supported(self) -> bool: @@ -798,23 +905,49 @@ def send_size(self, width: int, height: int, dpi: int = 96): except Exception as e: logging.error(f"Error sending size: {e}") - def send_clipboard(self, text: str): + def send_clipboard_stream(self, text: str) -> None: """ - Send clipboard data to guacd. + Send clipboard text using the Web Vault-equivalent stream protocol. - Only sends if session is active (running and data flowing). + Mirrors GuacamoleClipboard.setRemoteClipboard: + createClipboardStream(mimetype) → clipboard instruction + StringWriter.sendText(data) → blob instruction (base64) + StringWriter.sendEnd() → end instruction + + Wire format: + clipboard,,text/plain; + blob,,; + end,; + + Never uses send_stdin for clipboard data. Args: - text: Clipboard text + text: Clipboard text to send """ if not self.running or not self.data_flowing.is_set(): return try: - instruction = self._format_instruction('clipboard', 'text/plain', text) - self._send_to_gateway(instruction) - except Exception as e: - logging.error(f"Error sending clipboard: {e}") + stream_id = str(self._clipboard_stream_index) + self._clipboard_stream_index += 1 + + data_b64 = base64.b64encode(text.encode('utf-8')).decode('ascii') + + self._send_to_gateway( + self._format_instruction('clipboard', stream_id, 'text/plain') + ) + self._send_to_gateway( + self._format_instruction('blob', stream_id, data_b64) + ) + self._send_to_gateway( + self._format_instruction('end', stream_id) + ) + + logging.debug( + 'Clipboard stream sent: %d chars, stream_id=%s', len(text), stream_id + ) + except Exception as exc: + logging.error(f"Error sending clipboard stream: {exc}") def wait_for_ready(self, timeout: float = 10.0) -> bool: """ diff --git a/keepercommander/commands/pam_launch/rust_log_filter.py b/keepercommander/commands/pam_launch/rust_log_filter.py new file mode 100644 index 000000000..b91ef795d --- /dev/null +++ b/keepercommander/commands/pam_launch/rust_log_filter.py @@ -0,0 +1,143 @@ +""" +Rust/webrtc log filtering for pam launch terminal session only. + +Downgrades Rust/webrtc/turn log messages to DEBUG so they only appear when --debug is on, +and only while the pam launch CLI terminal session is active. +""" + +import logging + + +def _rust_webrtc_logger_name(name: str) -> bool: + """True if logger name is from Rust/webrtc/turn so we treat its messages as DEBUG-only.""" + if not name: + return False + # Normalize so we match both '.' and '::' (Rust may use either when passed to Python) + n = (name or '').replace('::', '.') + return ( + n.startswith('keeper_pam_webrtc_rs') + or n.startswith('webrtc') + or n.startswith('turn') + or n.startswith('stun') + or 'relay_conn' in n # turn crate submodule + ) + + +class _RustWebrtcToDebugFilter(logging.Filter): + """ + Filter for Rust/webrtc/turn log records. + When not in debug mode: suppress entirely (return False) so no handler can emit them. + When in debug mode: allow (return True); downgrading to DEBUG is redundant but harmless. + """ + + def filter(self, record: logging.LogRecord) -> bool: + if not _rust_webrtc_logger_name(record.name): + return True + # Only show these messages when debug is enabled (root or effective level is DEBUG) + if logging.getLogger().getEffectiveLevel() <= logging.DEBUG: + return True + return False # suppress when not in debug + + +class _RustAwareLogger(logging.Logger): + """ + Logger that forces Rust/webrtc/turn loggers to have no handlers and propagate to root, + and applies the downgrade filter at the logger so messages are DEBUG-only. + Used so loggers created *after* enter_ (e.g. by turn crate on first use) are still suppressed. + """ + + def __init__(self, name, level=logging.NOTSET): + super().__init__(name, level) + if _rust_webrtc_logger_name(name): + self.setLevel(logging.DEBUG) + self.propagate = True + self.handlers.clear() + self.addFilter(_RustWebrtcToDebugFilter()) + + +_WEBRTC_CRATE_NAMES = [ + 'webrtc', 'webrtc_ice', 'webrtc_mdns', 'webrtc_dtls', + 'webrtc_sctp', 'turn', 'stun', 'webrtc_ice.agent.agent_internal', + 'webrtc_ice.agent.agent_gather', 'webrtc_ice.mdns', + 'webrtc_mdns.conn', 'webrtc.peer_connection', 'turn.client', + 'turn.client.relay_conn', # turn crate submodule that emits "fail to refresh permissions..." +] + + +def enter_pam_launch_terminal_rust_logging(): + """ + Apply Rust/webrtc log filtering only during pam launch terminal session. + Downgrades Rust/webrtc/turn messages to DEBUG so they only show with --debug. + Returns a token to pass to exit_pam_launch_terminal_rust_logging() on exit. + """ + root = logging.getLogger() + flt = _RustWebrtcToDebugFilter() + root.addFilter(flt) + + # Use custom Logger class so any Rust/webrtc logger created later (e.g. turn crate) + # gets no handlers and propagates to root, where our filter downgrades to DEBUG. + _original_logger_class = logging.getLoggerClass() + logging.setLoggerClass(_RustAwareLogger) + + saved = [] + downgrade_filter = _RustWebrtcToDebugFilter() + for name in list(logging.Logger.manager.loggerDict.keys()): + if not isinstance(name, str) or not _rust_webrtc_logger_name(name): + continue + log = logging.getLogger(name) + # Only save if it's a real Logger with state we can restore (not our custom class yet) + if not isinstance(log, _RustAwareLogger): + saved.append((name, log.level, log.propagate, list(log.handlers))) + log.setLevel(logging.DEBUG) + log.propagate = True + log.handlers.clear() + if downgrade_filter not in log.filters: + log.addFilter(downgrade_filter) + for crate_name in _WEBRTC_CRATE_NAMES: + log = logging.getLogger(crate_name) + if not isinstance(log, _RustAwareLogger): + saved.append((crate_name, log.level, log.propagate, list(log.handlers))) + log.setLevel(logging.DEBUG) + log.propagate = True + log.handlers.clear() + if downgrade_filter not in log.filters: + log.addFilter(downgrade_filter) + + return (flt, saved, _original_logger_class) + + +def exit_pam_launch_terminal_rust_logging(token): + """Restore Rust/webrtc logger state after pam launch terminal session. Pass token from enter_pam_launch_terminal_rust_logging().""" + if not token: + return + flt, saved = token[0], token[1] + original_logger_class = token[2] if len(token) > 2 else logging.Logger + logging.setLoggerClass(original_logger_class) + root = logging.getLogger() + root.removeFilter(flt) + # Remove downgrade filter from all Rust/webrtc loggers (we may have added the shared + # filter to existing loggers, and _RustAwareLogger instances have their own filter) + for name in list(logging.Logger.manager.loggerDict.keys()): + if not isinstance(name, str) or not _rust_webrtc_logger_name(name): + continue + log = logging.getLogger(name) + for f in list(log.filters): + if isinstance(f, _RustWebrtcToDebugFilter): + try: + log.removeFilter(f) + except ValueError: + pass + for crate_name in _WEBRTC_CRATE_NAMES: + log = logging.getLogger(crate_name) + for f in list(log.filters): + if isinstance(f, _RustWebrtcToDebugFilter): + try: + log.removeFilter(f) + except ValueError: + pass + for name, level, propagate, handlers in saved: + log = logging.getLogger(name) + log.setLevel(level) + log.propagate = propagate + for h in handlers: + log.addHandler(h) diff --git a/keepercommander/commands/pam_launch/terminal_connection.py b/keepercommander/commands/pam_launch/terminal_connection.py index 5e182b612..952c1d96a 100644 --- a/keepercommander/commands/pam_launch/terminal_connection.py +++ b/keepercommander/commands/pam_launch/terminal_connection.py @@ -5,7 +5,7 @@ # |_| # # Keeper Commander -# Copyright 2024 Keeper Security Inc. +# Copyright 2026 Keeper Security Inc. # Contact: ops@keepersecurity.com # @@ -19,7 +19,6 @@ from __future__ import annotations import logging import os -import sys import base64 import json import secrets @@ -32,11 +31,14 @@ from cryptography.hazmat.primitives.ciphers.aead import AESGCM from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.kdf.hkdf import HKDF +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives import serialization from keeper_secrets_manager_core.utils import bytes_to_base64, base64_to_bytes, url_safe_str_to_bytes, string_to_bytes, bytes_to_string from ...error import CommandError -from ... import vault +from ... import vault, api from ...keeper_dag import EdgeType +from ...proto.APIRequest_pb2 import GetKsmPublicKeysRequest, GetKsmPublicKeysResponse from ..ssh_agent import try_extract_private_key from ..tunnel.port_forward.tunnel_helpers import ( get_or_create_tube_registry, @@ -49,10 +51,11 @@ TunnelSignalHandler, tunnel_encrypt, tunnel_decrypt, - get_tunnel_session, get_keeper_tokens, MAIN_NONCE_LENGTH, SYMMETRIC_KEY_LENGTH, + parse_keeper_webrtc_version_from_sdp, + set_remote_description_and_parse_version, ) from ..tunnel.port_forward.TunnelGraph import TunnelDAG from ..pam.pam_dto import GatewayAction, GatewayActionWebRTCSession @@ -69,43 +72,36 @@ if TYPE_CHECKING: from ...params import KeeperParams +from ..pam_import.base import ConnectionProtocol + +# Protocol sets and defaults (ConnectionProtocol from pam_import.base) +GRAPHICAL = {ConnectionProtocol.RDP.value, ConnectionProtocol.VNC.value} # not supported by CLI +ALL_TERMINAL = { + ConnectionProtocol.SSH.value, + ConnectionProtocol.TELNET.value, + ConnectionProtocol.KUBERNETES.value, + ConnectionProtocol.MYSQL.value, + ConnectionProtocol.POSTGRESQL.value, + ConnectionProtocol.SQLSERVER.value, +} +DATABASE = { + ConnectionProtocol.MYSQL.value, + ConnectionProtocol.POSTGRESQL.value, + ConnectionProtocol.SQLSERVER.value, +} -# Protocol type constants -class ProtocolType: - """Terminal protocol types supported by PAM Launch""" - SSH = 'ssh' - TELNET = 'telnet' - KUBERNETES = 'kubernetes' - MYSQL = 'mysql' - POSTGRESQL = 'postgresql' - SQLSERVER = 'sqlserver' - - # All supported terminal protocols - ALL_TERMINAL = {SSH, TELNET, KUBERNETES, MYSQL, POSTGRESQL, SQLSERVER} - - # Database protocols - DATABASE = {MYSQL, POSTGRESQL, SQLSERVER} - - # Machine protocols - MACHINE = {SSH, TELNET} - - -# Default ports for protocols DEFAULT_PORTS = { - ProtocolType.SSH: 22, - ProtocolType.TELNET: 23, - ProtocolType.KUBERNETES: 443, - ProtocolType.MYSQL: 3306, - ProtocolType.POSTGRESQL: 5432, - ProtocolType.SQLSERVER: 1433, + ConnectionProtocol.SSH.value: 22, + ConnectionProtocol.TELNET.value: 23, + ConnectionProtocol.KUBERNETES.value: 443, + ConnectionProtocol.MYSQL.value: 3306, + ConnectionProtocol.POSTGRESQL.value: 5432, + ConnectionProtocol.SQLSERVER.value: 1433, } from .terminal_size import ( DEFAULT_TERMINAL_COLUMNS, DEFAULT_TERMINAL_ROWS, - DEFAULT_CELL_WIDTH_PX, - DEFAULT_CELL_HEIGHT_PX, - DEFAULT_SCREEN_DPI, _build_screen_info, get_terminal_size_pixels, ) @@ -120,6 +116,41 @@ class ProtocolType: MAX_MESSAGE_SIZE_LINE = "a=max-message-size:1073741823" +# Minimum keeper-pam-webrtc-rs version that supports ConnectAs payload in OpenConnection. +# Older Gateways (Rust module < this) do not parse connect_as_payload; omit it when not supported. +CONNECT_AS_MIN_VERSION = "2.1.6" + + +def _version_at_least(version: Optional[str], min_version: str) -> bool: + """ + Compare semantic versions. Returns True if version >= min_version. + + Args: + version: Parsed version (e.g. "2.1.4") or None (treated as unknown/old). + min_version: Minimum required version (e.g. "2.1.0"). + + Returns: + True if version is known and >= min_version; False if unknown or older. + """ + if not version: + return False + + def parse(v: str) -> tuple: + parts = [] + for p in v.split(".")[:3]: # major.minor.patch + try: + parts.append(int(p)) + except ValueError: + parts.append(0) + while len(parts) < 3: + parts.append(0) + return tuple(parts[:3]) + + try: + return parse(version) >= parse(min_version) + except Exception: + return False + def _ensure_max_message_size_attribute(sdp_offer: Optional[str]) -> Optional[str]: """ @@ -214,87 +245,110 @@ def _notify_gateway_connection_close(params, router_token, terminated=True): def detect_protocol(params: KeeperParams, record_uid: str) -> Optional[str]: """ - Detect the terminal protocol from a PAM record. + Detect the connection protocol from a PAM record. + + All machine types (pamMachine, pamDirectory, pamDatabase) allow any connection + type (ssh, telnet, rdp, vnc, kubernetes, mysql, etc.). Extraction follows: + first connection.protocol; for pamDatabase only, if still undetermined then + connection.databaseType, then infer from port. Args: params: KeeperParams instance record_uid: Record UID Returns: - Protocol string (ssh, telnet, kubernetes, mysql, postgresql, sqlserver) or None - - Raises: - CommandError: If record type is not supported or protocol cannot be determined + Protocol string (ex. ssh, telnet, rdp, mysql, etc.) or None if + not present/undetermined. If connection.protocol is set to a value that + matches a ConnectionProtocol enum, returns that canonical value; + otherwise returns the raw string (lowercased). """ record = vault.KeeperRecord.load(params, record_uid) if not isinstance(record, vault.TypedRecord): - raise CommandError('pam launch', f'Record {record_uid} is not a TypedRecord') + return None record_type = record.record_type + if record_type not in ('pamMachine', 'pamDirectory', 'pamDatabase'): + return None + + # Map lowercase protocol string to canonical ConnectionProtocol.value + _protocol_values = {p.value.lower(): p.value for p in ConnectionProtocol} + + pam_settings = record.get_typed_field('pamSettings') + if not pam_settings: + return None + + settings_value = pam_settings.get_default_value(dict) + if not settings_value: + return None + + connection = settings_value.get('connection') or {} + if not isinstance(connection, dict): + return None + + # 1) Try connection.protocol (same for all record types) + protocol_field = (connection.get('protocol') or '').strip() + if protocol_field: + protocol_lower = protocol_field.lower() + return _protocol_values.get(protocol_lower, protocol_lower) + + # 2) For pamDatabase only: connection.databaseType, then infer from port + if record_type == 'pamDatabase': + db_type = (connection.get('databaseType') or '').lower() + if 'mysql' in db_type: + return ConnectionProtocol.MYSQL.value + if 'postgres' in db_type or 'postgresql' in db_type: + return ConnectionProtocol.POSTGRESQL.value + if 'sql server' in db_type or 'sqlserver' in db_type or 'mssql' in db_type: + return ConnectionProtocol.SQLSERVER.value - # pamMachine -> SSH or Telnet - if record_type == 'pamMachine': - # Check if telnet is explicitly configured - # Look for telnet-specific fields or settings - pam_settings = record.get_typed_field('pamSettings') - if pam_settings: - settings_value = pam_settings.get_default_value(dict) - if settings_value: - connection = settings_value.get('connection', {}) - if isinstance(connection, dict): - # Check for telnet protocol indicator - protocol_field = connection.get('protocol') - if protocol_field and 'telnet' in str(protocol_field).lower(): - return ProtocolType.TELNET - - # Default to SSH for pamMachine - return ProtocolType.SSH - - # pamDirectory -> Kubernetes - elif record_type == 'pamDirectory': - return ProtocolType.KUBERNETES - - # pamDatabase -> MySQL, PostgreSQL, or SQL Server - elif record_type == 'pamDatabase': - # Inspect the database type field - pam_settings = record.get_typed_field('pamSettings') - if pam_settings: - settings_value = pam_settings.get_default_value(dict) - if settings_value: - connection = settings_value.get('connection', {}) - if isinstance(connection, dict): - db_type = connection.get('databaseType', '').lower() - - if 'mysql' in db_type: - return ProtocolType.MYSQL - elif 'postgres' in db_type or 'postgresql' in db_type: - return ProtocolType.POSTGRESQL - elif 'sql server' in db_type or 'sqlserver' in db_type or 'mssql' in db_type: - return ProtocolType.SQLSERVER - - # Try to infer from port if database type not specified hostname_field = record.get_typed_field('pamHostname') if hostname_field: host_value = hostname_field.get_default_value(dict) - if host_value: - port = host_value.get('port') - if port: - port_int = int(port) if isinstance(port, str) else port - if port_int == 3306: - return ProtocolType.MYSQL - elif port_int == 5432: - return ProtocolType.POSTGRESQL - elif port_int == 1433: - return ProtocolType.SQLSERVER - - # Default to MySQL if we can't determine - logging.warning(f"Could not determine database type for record {record_uid}, defaulting to MySQL") - return ProtocolType.MYSQL + if host_value and host_value.get('port') is not None: + try: + port_int = int(host_value['port']) + except (TypeError, ValueError): + port_int = None + if port_int == 3306: + return ConnectionProtocol.MYSQL.value + if port_int == 5432: + return ConnectionProtocol.POSTGRESQL.value + if port_int == 1433: + return ConnectionProtocol.SQLSERVER.value + + return None + + +_PAM_TYPES_WITH_CONNECTION_PORT = ['pamMachine', 'pamDatabase', 'pamDirectory'] - else: - raise CommandError('pam launch', - f'Record type "{record_type}" is not supported for terminal connections. ' - f'Supported types: pamMachine, pamDirectory, pamDatabase') + +def _pam_settings_connection_port(record: Any) -> Optional[int]: + """ + For PAM machine record types only, return a valid pamSettings.connection.port if set. + """ + if getattr(record, 'record_type', None) not in _PAM_TYPES_WITH_CONNECTION_PORT: + return None + if not hasattr(record, 'get_typed_field'): + return None + psf = record.get_typed_field('pamSettings') + if not psf or not hasattr(psf, 'get_default_value'): + return None + pam_val = psf.get_default_value(dict) + if not isinstance(pam_val, dict): + return None + connection = pam_val.get('connection') + if not isinstance(connection, dict): + return None + conn_port = connection.get('port') + if conn_port is None or conn_port == '': + return None + try: + p = int(conn_port) + except (ValueError, TypeError): + return None + if 1 <= p <= 65535: + return p + return None def extract_terminal_settings( @@ -303,6 +357,7 @@ def extract_terminal_settings( protocol: str, launch_credential_uid: Optional[str] = None, custom_host: Optional[str] = None, + custom_port: Optional[int] = None, ) -> Dict[str, Any]: """ Extract terminal connection settings from a PAM record. @@ -311,8 +366,9 @@ def extract_terminal_settings( params: KeeperParams instance record_uid: Record UID protocol: Protocol type (from detect_protocol) - launch_credential_uid: Optional override for userRecordUid (from --user CLI param) - custom_host: Optional override for hostname (from --host CLI param) + launch_credential_uid: Optional override for userRecordUid (from --credential CLI param) + custom_host: Optional override for hostname (from --host/--host-record/--credential CLI param) + custom_port: Optional override for port (from --host/--host-record/--credential CLI param) Returns: Dictionary containing terminal settings: @@ -345,31 +401,49 @@ def extract_terminal_settings( 'userRecordUid': None, } - # Extract hostname and port - hostname_field = record.get_typed_field('pamHostname') - if not hostname_field: - raise CommandError('pam launch', f'No hostname configured for record {record_uid}') - - host_value = hostname_field.get_default_value(dict) - if not host_value: - raise CommandError('pam launch', f'Invalid hostname configuration for record {record_uid}') - - settings['hostname'] = host_value.get('hostName') - - # Override hostname if custom_host provided (requires allowSupplyHost - validated in launch.py) + # Extract hostname and port from record - enforce single non-empty host/pamHostname field. + # Host requires non-empty hostName; port is pamSettings.connection.port (PAM types only) + # when set, else the field's port — same precedence as launch._get_host_port_from_record. + _pam_override_port = _pam_settings_connection_port(record) + _host_candidates = [] + for _f in list(getattr(record, 'fields', None) or []) + list(getattr(record, 'custom', None) or []): + if getattr(_f, 'type', None) in ('pamHostname', 'host'): + _hv = _f.get_default_value(dict) if hasattr(_f, 'get_default_value') else {} + _hn = ((_hv.get('hostName') or '').strip()) if isinstance(_hv, dict) else '' + if not _hn: + continue + _pr = _pam_override_port if _pam_override_port is not None else ( + _hv.get('port') if isinstance(_hv, dict) else None + ) + if not _pr: + continue + try: + _pp = int(_pr) + if 1 <= _pp <= 65535: + _host_candidates.append((_hn, _pp, _hv)) + except (ValueError, TypeError): + pass + if len(_host_candidates) > 1: + raise CommandError('pam launch', + f'Record {record_uid} has {len(_host_candidates)} non-empty host/pamHostname fields ' + '(expected exactly one). Clear the extra field before launching.') + _record_host, _record_port_val, _host_value = _host_candidates[0] if _host_candidates else (None, None, {}) + + settings['hostname'] = _record_host + + # CLI --host overrides record hostname (allowSupplyHost validated in launch.py) if custom_host: settings['hostname'] = custom_host logging.debug(f"Using custom host override: {custom_host}") - # Validate hostname is present (either from record or CLI override) - # Note: allowSupplyHost check happens later after pamSettings are parsed - - # Get port (use default if not specified) - port_value = host_value.get('port') - if port_value: - settings['port'] = int(port_value) if isinstance(port_value, str) else port_value - else: - settings['port'] = DEFAULT_PORTS.get(protocol, 22) + # Port precedence: CLI (custom_port) > record (pamSettings.connection.port overrides host field + # on PAM types, else field port) > pamSettings.connection.port when record port still unset > + # protocol DEFAULT. pamSettings fallback runs in the pamSettings block below. + if custom_port is not None: + settings['port'] = custom_port + elif _record_port_val is not None: + settings['port'] = _record_port_val + # else: remains None until pamSettings fallback or DEFAULT below # Extract PAM settings pam_settings_field = record.get_typed_field('pamSettings') @@ -403,32 +477,76 @@ def extract_terminal_settings( if dag_launch_uid: settings['userRecordUid'] = dag_launch_uid logging.debug(f"Using launch credential from DAG: {settings['userRecordUid']}") - else: - # Fallback to userRecords from pamSettings if DAG lookup fails + elif not launch_credential_uid: + # No DAG-linked credential and no -cr given. + # If allowSupply* is enabled, use pamSettings.connection.userRecords[0] as + # implicit credential and warn so the user can be explicit via -cr. user_records = connection.get('userRecords', []) if user_records and len(user_records) > 0: - settings['userRecordUid'] = user_records[0] - logging.debug(f"Using userRecordUid from pamSettings: {settings['userRecordUid']}") + fallback_uid = user_records[0] + settings['userRecordUid'] = fallback_uid + allow_supply_host_flag = pam_settings_value.get('allowSupplyHost', False) + allow_supply_user_flag = connection.get('allowSupplyUser', False) + if allow_supply_host_flag or allow_supply_user_flag: + logging.warning( + 'Record %s: allowSupply* is enabled but no DAG-linked launch credential ' + 'was found; using pamSettings.connection.userRecords[0] (%s) as credential. ' + 'Pass --credential (-cr %s) to be explicit.', + record_uid, fallback_uid, fallback_uid, + ) + settings['_fallbackCredential'] = True + else: + logging.debug(f"Using userRecordUid from pamSettings: {fallback_uid}") + + # pamSettings.connection.port when CLI and host-derived port are still absent + if settings['port'] is None: + conn_port = connection.get('port') + if conn_port: + try: + settings['port'] = int(conn_port) + except (ValueError, TypeError): + pass # Protocol-specific settings - if protocol == ProtocolType.SSH: + if protocol == ConnectionProtocol.SSH.value: settings['protocol_specific'] = _extract_ssh_settings(connection) - elif protocol == ProtocolType.TELNET: + elif protocol == ConnectionProtocol.TELNET.value: settings['protocol_specific'] = _extract_telnet_settings(connection) - elif protocol == ProtocolType.KUBERNETES: + elif protocol == ConnectionProtocol.KUBERNETES.value: settings['protocol_specific'] = _extract_kubernetes_settings(connection) - elif protocol in ProtocolType.DATABASE: + elif protocol in DATABASE: settings['protocol_specific'] = _extract_database_settings(connection, protocol) # allowSupplyHost is at top level of pamSettings value, not inside connection settings['allowSupplyHost'] = pam_settings_value.get('allowSupplyHost', False) - # CLI overrides always take precedence (applied after pamSettings extraction) - # These are validated in launch.py before being passed here - logging.debug(f"DEBUG extract_terminal_settings: launch_credential_uid={launch_credential_uid}, current userRecordUid={settings.get('userRecordUid')}") + # Final port fallback to protocol default + if settings['port'] is None: + settings['port'] = DEFAULT_PORTS.get(protocol, 22) + + # CLI overrides: check if --credential provides a DIFFERENT user than DAG-linked. + # Always query the DAG directly - settings['userRecordUid'] may have been set from the + # userRecords[0] fallback (not DAG-linked) and must not be used for this comparison. + dag_linked_uid = _get_launch_credential_uid(params, record_uid) + if launch_credential_uid: - settings['userRecordUid'] = launch_credential_uid - logging.debug(f"Using launch credential from CLI override: {settings['userRecordUid']}") + if launch_credential_uid == dag_linked_uid: + # CLI --credential matches DAG-linked credential - treat as if no --credential was provided + # so gateway uses normal 'linked' flow + logging.debug(f"CLI --credential matches DAG-linked credential {dag_linked_uid} - using normal 'linked' flow") + settings['cliUserOverride'] = False + else: + # CLI --credential provides a different user - this is a real override + settings['userRecordUid'] = launch_credential_uid + settings['cliUserOverride'] = True + logging.debug(f"CLI --credential overrides DAG credential: {launch_credential_uid} (was {dag_linked_uid})") + elif settings.pop('_fallbackCredential', False): + # userRecords[0] fallback with allowSupply* - treat as implicit -cr: + # gateway may not have it in DAG, so use userSupplied + ConnectAs payload + settings['cliUserOverride'] = True + logging.debug(f"Implicit credential from userRecords[0] fallback: {settings.get('userRecordUid')} - treating as userSupplied") + else: + settings['cliUserOverride'] = False # Final validation: hostname must be present for connection to succeed # Note: userRecordUid is optional - if not present, _build_guacamole_connection_settings() @@ -483,11 +601,11 @@ def _extract_database_settings(connection: Dict[str, Any], protocol: str) -> Dic } # Add protocol-specific database settings - if protocol == ProtocolType.MYSQL: + if protocol == ConnectionProtocol.MYSQL.value: settings['useSSL'] = connection.get('useSSL', False) - elif protocol == ProtocolType.POSTGRESQL: + elif protocol == ConnectionProtocol.POSTGRESQL.value: settings['useSSL'] = connection.get('useSSL', False) - elif protocol == ProtocolType.SQLSERVER: + elif protocol == ConnectionProtocol.SQLSERVER.value: settings['useSSL'] = connection.get('useSSL', True) # SQL Server typically uses SSL by default return settings @@ -525,24 +643,26 @@ def create_connection_context(params: KeeperParams, 'terminal': settings['terminal'], 'recording': settings['recording'], 'connectAs': connect_as, - 'conversationType': _get_conversation_type(protocol), + 'conversationType': str(protocol).lower(), # Credential supply flags 'allowSupplyUser': settings.get('allowSupplyUser', False), 'allowSupplyHost': settings.get('allowSupplyHost', False), # Linked pamUser record UID for credential extraction 'userRecordUid': settings.get('userRecordUid'), + # True only when --credential was provided via CLI and differs from the DAG-linked record. + # Required by the offer-building path to distinguish "flag enabled but nothing supplied" + # from "flag enabled and user actually provided credentials". + 'cliUserOverride': settings.get('cliUserOverride', False), } - logging.debug(f"DEBUG create_connection_context: userRecordUid={context.get('userRecordUid')}") - # Add protocol-specific settings - if protocol == ProtocolType.SSH: + if protocol == ConnectionProtocol.SSH.value: context['ssh'] = settings['protocol_specific'] - elif protocol == ProtocolType.TELNET: + elif protocol == ConnectionProtocol.TELNET.value: context['telnet'] = settings['protocol_specific'] - elif protocol == ProtocolType.KUBERNETES: + elif protocol == ConnectionProtocol.KUBERNETES.value: context['kubernetes'] = settings['protocol_specific'] - elif protocol in ProtocolType.DATABASE: + elif protocol in DATABASE: context['database'] = settings['protocol_specific'] context['database']['type'] = protocol @@ -578,48 +698,25 @@ def _get_launch_credential_uid(params: 'KeeperParams', record_uid: str) -> Optio logging.debug(f"Record vertex not found in DAG for {record_uid}") return None - # Find all ACL links where Head is recordUID - # Look for the credential marked as is_launch_credential=True + # Find the credential explicitly marked as is_launch_credential=True in DAG launch_credential = None - admin_credential = None - all_linked = [] for user_vertex in record_vertex.has_vertices(EdgeType.ACL): acl_edge = user_vertex.get_edge(record_vertex, EdgeType.ACL) if acl_edge: try: content = acl_edge.content_as_dict or {} - is_admin = content.get('is_admin', False) - is_launch = content.get('is_launch_credential', None) - - all_linked.append(user_vertex.uid) - - if is_launch and launch_credential is None: + if content.get('is_launch_credential', False) and launch_credential is None: launch_credential = user_vertex.uid logging.debug(f"Found launch credential via DAG: {launch_credential}") - - if is_admin and admin_credential is None: - admin_credential = user_vertex.uid - logging.debug(f"Found admin credential via DAG: {admin_credential}") - except Exception as e: logging.debug(f"Error parsing ACL edge content: {e}") - # Prefer launch credential, fall back to first linked if no specific launch credential if launch_credential: logging.debug(f"Using launch credential from DAG: {launch_credential}") return launch_credential - elif all_linked: - # If no explicit launch credential but we have linked users, - # prefer non-admin credential - for uid in all_linked: - if uid != admin_credential: - logging.debug(f"Using non-admin linked credential: {uid}") - return uid - # Fall back to first linked - logging.debug(f"Using first linked credential: {all_linked[0]}") - return all_linked[0] + logging.debug(f"No explicit launch credential (is_launch_credential=True) in DAG for {record_uid}") return None except Exception as e: @@ -627,18 +724,206 @@ def _get_launch_credential_uid(params: 'KeeperParams', record_uid: str) -> Optio return None -def _get_conversation_type(protocol: str) -> str: - """Map protocol to Guacamole conversation type""" - # Map our protocol names to Guacamole conversation types - mapping = { - ProtocolType.SSH: 'ssh', - ProtocolType.TELNET: 'telnet', - ProtocolType.KUBERNETES: 'kubernetes', - ProtocolType.MYSQL: 'mysql', - ProtocolType.POSTGRESQL: 'postgresql', - ProtocolType.SQLSERVER: 'sql-server', - } - return mapping.get(protocol, protocol) +# ECIES info string for ConnectAs payload encryption +# Must match the gateway's expected value +CONNECT_AS_ECIES_INFO = b'KEEPER_CONNECT_AS_ECIES_SECP256R1_HKDF_SHA256' + + +def _ecies_encrypt_with_hkdf( + plaintext: bytes, + recipient_public_key: bytes, + info: bytes = CONNECT_AS_ECIES_INFO +) -> bytes: + """ + Encrypt data using ECIES with HKDF key derivation. + + This implements ECIES (Elliptic Curve Integrated Encryption Scheme) using: + - SECP256R1 (P-256) curve for ECDH key exchange + - HKDF-SHA256 for key derivation with the provided info string + - AES-256-GCM for symmetric encryption + + Args: + plaintext: Data to encrypt + recipient_public_key: 65-byte uncompressed public key of recipient + info: HKDF info/context string (default: CONNECT_AS_ECIES_INFO) + + Returns: + Encrypted payload: [ephemeral_pubkey (65)] + [nonce (12)] + [ciphertext + auth_tag] + """ + # Generate ephemeral key pair + ephemeral_private_key = ec.generate_private_key(ec.SECP256R1(), default_backend()) + ephemeral_public_key = ephemeral_private_key.public_key() + + # Serialize ephemeral public key (65 bytes uncompressed) + ephemeral_public_key_bytes = ephemeral_public_key.public_bytes( + encoding=serialization.Encoding.X962, + format=serialization.PublicFormat.UncompressedPoint + ) + + # Load recipient's public key from bytes + recipient_key = ec.EllipticCurvePublicKey.from_encoded_point( + ec.SECP256R1(), + recipient_public_key + ) + + # Perform ECDH to get shared secret + shared_secret = ephemeral_private_key.exchange(ec.ECDH(), recipient_key) + + # Derive encryption key using HKDF + hkdf = HKDF( + algorithm=hashes.SHA256(), + length=32, # AES-256 key + salt=None, + info=info, + backend=default_backend() + ) + encryption_key = hkdf.derive(shared_secret) + + # Generate random nonce for AES-GCM + nonce = os.urandom(12) + + # Encrypt with AES-256-GCM + aesgcm = AESGCM(encryption_key) + ciphertext = aesgcm.encrypt(nonce, plaintext, None) + + # Return: [ephemeral_pubkey (65)] + [nonce (12)] + [ciphertext + auth_tag] + return ephemeral_public_key_bytes + nonce + ciphertext + + +def _build_connect_as_payload( + params: 'KeeperParams', + user_record_uid: str, + gateway_public_key: bytes +) -> Optional[bytes]: + """ + Build encrypted ConnectAs payload for credential passing to gateway. + + The ConnectAs payload contains user credentials from a pamUser record, + encrypted using ECIES with HKDF. This allows the gateway to receive + credentials via the OpenConnection message instead of looking them up in DAG. + + Args: + params: KeeperParams instance + user_record_uid: UID of the pamUser record containing credentials + gateway_public_key: 65-byte public key of the gateway for ECIES encryption + + Returns: + Encrypted payload in format expected by keeper-pam-webrtc-rs Gateway: + [ephemeral_pubkey (65)] + [nonce (12)] + [ciphertext + auth_tag] = 185 bytes. + Returns None if credentials cannot be extracted or encryption fails. + """ + if not user_record_uid or not gateway_public_key: + return None + + try: + # Extract credentials from pamUser record + creds = _extract_user_record_credentials(params, user_record_uid) + + # Build ConnectAs user data structure (matches webvault's ConnectAsUser) + connect_as_user = {} + if creds.get('username'): + connect_as_user['username'] = creds['username'] + if creds.get('password'): + connect_as_user['password'] = creds['password'] + if creds.get('private_key'): + connect_as_user['private_key'] = creds['private_key'] + if creds.get('passphrase'): + connect_as_user['passphrase'] = creds['passphrase'] + + # The payload structure matches webvault: {"user": {...}} + payload_dict = {'user': connect_as_user} + payload_json = json.dumps(payload_dict).encode('utf-8') + + # keeper-pam-webrtc-rs protocol.rs expects: + # [encrypted_data_len: 4 bytes] + [PK(65)] + [Nonce(12)] + [Encrypted(encrypted_data_len)] + # Encrypted = ciphertext + auth_tag(16). Ciphertext len = plaintext len. + # So plaintext must be >= 92 bytes to produce 108-byte encrypted portion. + # Use space padding (not null) so decrypted JSON parses correctly. + min_plaintext_len = 92 + if len(payload_json) < min_plaintext_len: + payload_json = payload_json + b' ' * (min_plaintext_len - len(payload_json)) + + logging.debug(f"ConnectAs payload: username={'(set)' if connect_as_user.get('username') else '(empty)'}, " + f"password={'(set)' if connect_as_user.get('password') else '(empty)'}, " + f"private_key={'(set)' if connect_as_user.get('private_key') else '(empty)'}") + + # Encrypt with ECIES+HKDF + ecies_encrypted = _ecies_encrypt_with_hkdf(payload_json, gateway_public_key) + + # protocol.rs reads: connect_as_payload_len = get_u32(), then + # required_crypto_block_len = 65 + 12 + connect_as_payload_len + # The length is of the ENCRYPTED portion only (ciphertext+auth_tag) = 108 + encrypted_data_len = len(ecies_encrypted) - 65 - 12 # ciphertext + auth_tag + length_bytes = encrypted_data_len.to_bytes(4, byteorder='big') + connect_as_payload = length_bytes + ecies_encrypted + + logging.debug(f"Built ConnectAs payload: total_len={len(connect_as_payload)}, encrypted_data_len={encrypted_data_len}") + + return connect_as_payload + + except Exception as e: + logging.error(f"Failed to build ConnectAs payload: {e}") + return None + + +def _retrieve_gateway_public_key( + params: 'KeeperParams', + gateway_uid: str +) -> Optional[bytes]: + """ + Retrieve the public key for a gateway. + + This function calls the vault/get_ksm_public_keys API to retrieve the + gateway's public key needed for ECIES encryption of ConnectAs payloads. + + Args: + params: KeeperParams instance + gateway_uid: UID of the gateway + + Returns: + 65-byte uncompressed public key, or None if not found + """ + try: + gateway_uid_bytes = url_safe_str_to_bytes(gateway_uid) + get_ksm_pubkeys_rq = GetKsmPublicKeysRequest() + get_ksm_pubkeys_rq.controllerUids.append(gateway_uid_bytes) + get_ksm_pubkeys_rs = api.communicate_rest( + params, get_ksm_pubkeys_rq, 'vault/get_ksm_public_keys', + rs_type=GetKsmPublicKeysResponse + ) + + if len(get_ksm_pubkeys_rs.keyResponses) == 0: + logging.warning(f"No public key found for gateway {gateway_uid}") + return None + + gateway_public_key_bytes = get_ksm_pubkeys_rs.keyResponses[0].publicKey + logging.debug(f"Retrieved gateway public key: {len(gateway_public_key_bytes)} bytes") + return gateway_public_key_bytes + + except Exception as e: + logging.error(f"Error retrieving gateway public key: {e}") + return None + + +def _get_single_str_field(record: Any, field_type: str) -> str: + """ + Return the value of the single non-empty typed field matching field_type. + + Enforces exactly one non-empty field across both record.fields[] and record.custom[]. + Raises CommandError if multiple non-empty fields of that type are found. + Returns '' if none are found. + """ + nonempty_values = [] + for field in list(getattr(record, 'fields', None) or []) + list(getattr(record, 'custom', None) or []): + if getattr(field, 'type', None) == field_type: + val = field.get_default_value(str) if hasattr(field, 'get_default_value') else '' + if val: + nonempty_values.append(val) + if len(nonempty_values) > 1: + raise CommandError('pam launch', + f'Record has {len(nonempty_values)} non-empty {field_type!r} fields ' + '(expected exactly one). Clear the extra field before launching.') + return nonempty_values[0] if nonempty_values else '' def _extract_user_record_credentials( @@ -678,15 +963,11 @@ def _extract_user_record_credentials( logging.warning(f"User record {user_record_uid} is not a TypedRecord") return result - # Extract username from login field - login_field = user_record.get_typed_field('login') - if login_field: - result['username'] = login_field.get_default_value(str) or '' + # Extract username - enforce single non-empty login field across fields[] + custom[] + result['username'] = _get_single_str_field(user_record, 'login') - # Extract password - password_field = user_record.get_typed_field('password') - if password_field: - result['password'] = password_field.get_default_value(str) or '' + # Extract password - enforce single non-empty password field across fields[] + custom[] + result['password'] = _get_single_str_field(user_record, 'password') # Extract private key using try_extract_private_key() # This function checks: keyPair field, notes, custom fields (text, multiline, secret, note), and attachments @@ -750,12 +1031,13 @@ def _build_guacamole_connection_settings( private_key = None passphrase = None - logging.debug(f"DEBUG _build_guacamole_connection_settings: credential_type={credential_type}, user_record_uid={user_record_uid}") - # Determine how to get credentials based on credential_type - if credential_type == 'userSupplied': - # User-supplied credentials: leave empty, user will provide via guacamole prompt - logging.debug("Using userSupplied credential type - leaving credentials empty") + # Note: Even for 'userSupplied', if we have user_record_uid (from CLI --credential), extract credentials + # because guacd_params go directly to guacd via our connect instruction + if credential_type == 'userSupplied' and not user_record_uid: + # True user-supplied: no credentials provided at all + # Note: user may not be able to provide via guacamole prompt since STDIN/STDOUT not open yet + logging.debug("Using userSupplied credential type with no pamUser - leaving credentials empty") elif user_record_uid: # Extract credentials from linked pamUser record user_creds = _extract_user_record_credentials(params, user_record_uid) @@ -767,21 +1049,17 @@ def _build_guacamole_connection_settings( else: # Fallback: Get credentials from the pamMachine record directly # (backward compatibility for records without linked pamUser) + # Enforces single non-empty login/password field across fields[] + custom[]. record = vault.KeeperRecord.load(params, record_uid) if isinstance(record, vault.TypedRecord): - login_field = record.get_typed_field('login') - if login_field: - username = login_field.get_default_value(str) or '' - - password_field = record.get_typed_field('password') - if password_field: - password = password_field.get_default_value(str) or '' + username = _get_single_str_field(record, 'login') + password = _get_single_str_field(record, 'password') logging.debug("Using credentials from pamMachine record (no linked pamUser)") # Build guacd parameters dictionary # These map to guacd's expected parameter names # The 'protocol' field is required for guacd to know which backend to use - guacd_protocol = _get_conversation_type(protocol) # Convert to guacd protocol name (e.g., ssh, telnet) + guacd_protocol = str(protocol).lower() guacd_params = { 'protocol': guacd_protocol, # Required: tells guacd which protocol handler to use 'hostname': settings.get('hostname', ''), @@ -790,12 +1068,10 @@ def _build_guacamole_connection_settings( 'password': password, } - logging.debug(f"DEBUG guacd_params built: username={'(set)' if username else '(empty)'}, password={'(set)' if password else '(empty)'}") - # Add private key for SSH protocol if available # SSH authentication precedence: guacd/SSH tries private key first, then password # Both can be present simultaneously - this matches gateway behavior - if protocol == ProtocolType.SSH and private_key: + if protocol == ConnectionProtocol.SSH.value and private_key: guacd_params['private-key'] = private_key if passphrase: guacd_params['passphrase'] = passphrase @@ -804,7 +1080,7 @@ def _build_guacamole_connection_settings( # Add protocol-specific parameters protocol_specific = settings.get('protocol_specific', {}) - if protocol == ProtocolType.SSH: + if protocol == ConnectionProtocol.SSH.value: # SSH-specific params if protocol_specific.get('publicHostKey'): guacd_params['host-key'] = protocol_specific['publicHostKey'] @@ -814,14 +1090,14 @@ def _build_guacamole_connection_settings( if protocol_specific.get('sftpEnabled'): guacd_params['enable-sftp'] = 'true' - elif protocol == ProtocolType.TELNET: + elif protocol == ConnectionProtocol.TELNET.value: # Telnet-specific params if protocol_specific.get('usernameRegex'): guacd_params['username-regex'] = protocol_specific['usernameRegex'] if protocol_specific.get('passwordRegex'): guacd_params['password-regex'] = protocol_specific['passwordRegex'] - elif protocol == ProtocolType.KUBERNETES: + elif protocol == ConnectionProtocol.KUBERNETES.value: # Kubernetes-specific params if protocol_specific.get('namespace'): guacd_params['namespace'] = protocol_specific['namespace'] @@ -838,11 +1114,14 @@ def _build_guacamole_connection_settings( if protocol_specific.get('ignoreServerCertificate'): guacd_params['ignore-cert'] = 'true' - elif protocol in ProtocolType.DATABASE: + elif protocol in DATABASE: # Database-specific params if protocol_specific.get('defaultDatabase'): guacd_params['database'] = protocol_specific['defaultDatabase'] + # CLI mode: named pipe for terminal STDOUT (guacr terminal handlers; not graphical RDP/VNC) + guacd_params['enable-pipe'] = 'true' + # Terminal display settings terminal_settings = settings.get('terminal', {}) if terminal_settings.get('colorScheme'): @@ -850,12 +1129,12 @@ def _build_guacamole_connection_settings( if terminal_settings.get('fontSize'): guacd_params['font-size'] = terminal_settings['fontSize'] - # Clipboard settings - clipboard_settings = settings.get('clipboard', {}) - if clipboard_settings.get('disableCopy'): - guacd_params['disable-copy'] = 'true' - if clipboard_settings.get('disablePaste'): + # PAM clipboard → guacd: only pass disable-* when the record sets them (guacd "true" = on). + _pam_clip = settings.get('clipboard') or {} + if _pam_clip.get('disablePaste'): guacd_params['disable-paste'] = 'true' + if _pam_clip.get('disableCopy'): + guacd_params['disable-copy'] = 'true' # Build final connection settings connection_settings = { @@ -870,6 +1149,8 @@ def _build_guacamole_connection_settings( 'audio_mimetypes': [], # No audio for terminal 'video_mimetypes': [], # No video for terminal 'image_mimetypes': ['image/png', 'image/jpeg', 'image/webp'], + # PAM clipboard policy (also in guacd_params as disable-* only when record disables) + 'clipboard': dict(settings.get('clipboard') or {}), } logging.debug(f"Built Guacamole connection settings for {protocol}: " @@ -1087,13 +1368,11 @@ def _open_terminal_webrtc_tunnel(params: KeeperParams, allow_supply_user = context.get('allowSupplyUser', False) user_record_uid = context.get('userRecordUid') - logging.debug(f"DEBUG credential determination: allow_supply_host={allow_supply_host}, allow_supply_user={allow_supply_user}, user_record_uid={user_record_uid}") - # credential_type is None when using pamMachine credentials directly (backward compatible) # Priority: if user_record_uid is provided (from CLI or record), use 'linked' to send those credentials credential_type = None if user_record_uid: - # Linked user present (from CLI --user or record) - use linked credentials + # Linked user present (from CLI --credential or record) - use linked credentials credential_type = 'linked' logging.debug(f"Using 'linked' credential type with userRecordUid: {user_record_uid}") elif allow_supply_host or allow_supply_user: @@ -1126,14 +1405,8 @@ def _open_terminal_webrtc_tunnel(params: KeeperParams, ) logging.debug(f"Created PythonHandler for conversation {conversation_id}") - logging.debug(f"DEBUG: handler_callback is {'SET' if handler_callback else 'None'}, type={type(handler_callback)}") - logging.debug(f"DEBUG: python_handler is {'SET' if python_handler else 'None'}") - logging.debug(f"DEBUG: connection_settings has {len(connection_settings)} keys: {list(connection_settings.keys())}") # Create the tube to get the WebRTC offer - logging.debug(f"DEBUG: Calling create_tube with handler_callback={'SET' if handler_callback else 'None'}") - logging.debug(f"DEBUG: Calling create_tube with handler_callback={'SET' if handler_callback else 'None'}") - logging.debug(f"DEBUG: webrtc_settings['conversationType'] = {webrtc_settings.get('conversationType')}") offer = tube_registry.create_tube( conversation_id=conversation_id, settings=webrtc_settings, @@ -1249,6 +1522,23 @@ def _open_terminal_webrtc_tunnel(params: KeeperParams, else: offer_payload = offer_sdp + # Gateway may configure guacd from this map before Python's `connect`. + offer_guacd_params: Dict[str, Any] = {'enable-pipe': 'true'} + _offer_clip = settings.get('clipboard') or {} + if use_python_handler: + _cs_gp = connection_settings.get('guacd_params') or {} + for _k, _v in _cs_gp.items(): + if _k in ('disable-paste', 'disable-copy'): + continue + offer_guacd_params[_k] = _v + if _offer_clip.get('disablePaste'): + offer_guacd_params['disable-paste'] = 'true' + if _offer_clip.get('disableCopy'): + offer_guacd_params['disable-copy'] = 'true' + + _offer_disable_copy = bool(_offer_clip.get('disableCopy')) + _offer_disable_paste = bool(_offer_clip.get('disablePaste')) + offer_data = { "offer": offer_payload, "audio": ["audio/L8", "audio/L16"], # Supported audio codecs @@ -1259,7 +1549,18 @@ def _open_terminal_webrtc_tunnel(params: KeeperParams, "host": { "hostName": settings['hostname'], "port": settings['port'] - } + }, + # enable-pipe + optional disable-paste/disable-copy from PAM (see offer_guacd_params) + "guacd_params": offer_guacd_params, + "terminalSettings": { + "disableCopy": _offer_disable_copy, + "disablePaste": _offer_disable_paste, + }, + # Alternate shape (PAM record uses connection.clipboard) + "clipboard": { + "disableCopy": _offer_disable_copy, + "disablePaste": _offer_disable_paste, + }, # these are not sent by webvault during open connection for terminal connections # "protocol": protocol, # "terminalSettings": { @@ -1272,13 +1573,9 @@ def _open_terminal_webrtc_tunnel(params: KeeperParams, # if 'protocol_specific' in settings and settings['protocol_specific']: # offer_data["protocolSettings"] = settings['protocol_specific'] - # Log what we're sending in the initial offer logging.debug(f"Sending initial offer with connection parameters: {json.dumps(offer_data, indent=2)}") - - string_data = json.dumps(offer_data) - logging.debug(f"payload.inputs.data JSON before encryption: {string_data}") - bytes_data = string_to_bytes(string_data) - encrypted_data = tunnel_encrypt(symmetric_key, bytes_data) + data_bytes = string_to_bytes(json.dumps(offer_data)) + encrypted_data = tunnel_encrypt(symmetric_key, data_bytes) # Get userRecordUid and credential flags from context (extracted in extract_terminal_settings) user_record_uid = context.get('userRecordUid') @@ -1286,19 +1583,26 @@ def _open_terminal_webrtc_tunnel(params: KeeperParams, allow_supply_user = context.get('allowSupplyUser', False) # Determine credential type for gateway inputs - # IMPORTANT: Priority must match the guacd credentials logic above: - # 1. If user_record_uid is set (from CLI or record), use 'linked' - credentials come from that record - # 2. If allowSupply* but no user_record_uid, use 'userSupplied' - user will type at prompt - # 3. Otherwise, use pamMachine credentials directly (no credentialType) + # Gateway credential types: + # - 'linked': Look up credential in DAG (for records with DAG-linked pamUser) + # - 'userSupplied': Skip DAG lookup, credentials from ConnectAs (-cr) or user prompt + # - None: Use pamMachine credentials directly + # Priority: prefer 'linked' when DAG has credentials (even if allowSupply* is enabled). + # Use 'userSupplied' only when no linked credential but allowSupply* enabled. credential_type_for_gateway = None - if user_record_uid: - # Credentials will come from linked pamUser record (via python_handler) + cli_user_override = context.get('cliUserOverride', False) + if cli_user_override: + # User explicitly supplied a different credential via -cr. + # The -cr record is NOT DAG-linked to this machine so 'linked' would fail; + # credentials arrive via the ConnectAs payload (built in launch.py after tunnel opens). + # NOTE: -H/-hr are not accepted without -cr (legacy, to match Web Vault behaviour), + # so cli_user_override=True is the only reliable signal that the user supplied something. + credential_type_for_gateway = 'userSupplied' + logging.debug("CLI credential override active - using 'userSupplied' for gateway") + elif user_record_uid: + # DAG-linked pamUser (no CLI override) - gateway looks up credentials via DAG credential_type_for_gateway = 'linked' logging.debug(f"Using 'linked' credential type for gateway with userRecordUid: {user_record_uid}") - elif allow_supply_host or allow_supply_user: - # No credentials provided, user must type at prompt - credential_type_for_gateway = 'userSupplied' - logging.debug("No credentials provided, allowSupply enabled - using 'userSupplied' for gateway") else: logging.debug(f"No linked pamUser for record {record_uid} - using pamMachine credentials directly") @@ -1388,6 +1692,8 @@ def _open_terminal_webrtc_tunnel(params: KeeperParams, "screen_info": screen_info, "python_handler": python_handler, # PythonHandler for simplified guac protocol "use_python_handler": use_python_handler, + "user_record_uid": user_record_uid, # For ConnectAs payload + "gateway_uid": gateway_uid, # For ConnectAs payload } else: # Non-streaming path: Handle response immediately @@ -1407,6 +1713,9 @@ def _open_terminal_webrtc_tunnel(params: KeeperParams, logging.debug(f"{bcolors.OKGREEN}Offer sent to gateway (non-streaming mode){bcolors.ENDC}") logging.debug(f"Router response: {router_response}") + # Must be defined before return below; only refined inside `if router_response`. + remote_webrtc_version = None + # Handle immediate response if router_response and router_response.get('response'): response_dict = router_response['response'] @@ -1448,37 +1757,36 @@ def _open_terminal_webrtc_tunnel(params: KeeperParams, data_text = bytes_to_string(decrypted_data).replace("'", '"') logging.debug(f"Successfully decrypted data for {conversation_id_original}, length: {len(data_text)}") - # Parse JSON + # Parse JSON; fallback to raw SDP if decrypted data is plain SDP + answer_sdp = None + data_json = None try: data_json = json.loads(data_text) - - # Ensure data_json is a dictionary if isinstance(data_json, dict): logging.debug(f"🔓 Decrypted payload type: {data_json.get('type', 'unknown')}, keys: {list(data_json.keys())}") - - # Handle SDP answer - if "answer" in data_json: - answer_sdp = data_json.get('answer') - if answer_sdp: - logging.debug(f"Found SDP answer in non-streaming response, sending to Rust for conversation: {conversation_id_original}") - tube_registry.set_remote_description(commander_tube_id, answer_sdp, is_answer=True) - - if hasattr(tunnel_session, "gateway_ready_event") and tunnel_session.gateway_ready_event is not None: - tunnel_session.gateway_ready_event.set() - logging.debug(f"{bcolors.OKBLUE}Connection state: {bcolors.ENDC}SDP answer received, connecting...") - - # Send any buffered local ICE candidates now that we have the answer - if tunnel_session.buffered_ice_candidates: - logging.debug(f"Sending {len(tunnel_session.buffered_ice_candidates)} buffered ICE candidates after answer") - for candidate in tunnel_session.buffered_ice_candidates: - signal_handler._send_ice_candidate_immediately(candidate, commander_tube_id) - tunnel_session.buffered_ice_candidates.clear() - elif "offer" in data_json or (data_json.get("type") == "offer"): - # Gateway is sending us an ICE restart offer (unlikely in non-streaming mode) - logging.warning(f"Received ICE restart offer in non-streaming mode - this is unexpected") - except json.JSONDecodeError as e: - logging.error(f"Failed to parse decrypted data as JSON: {e}") - logging.debug(f"Data text: {data_text[:200]}...") + answer_sdp = data_json.get('answer') or data_json.get('sdp') + except (json.JSONDecodeError, TypeError): + if data_text.strip().startswith('v=') and 'm=' in data_text: + answer_sdp = data_text.strip() + logging.debug("Decrypted data appears to be raw SDP (not JSON), using as answer") + + if answer_sdp: + logging.debug(f"Found SDP answer in non-streaming response, sending to Rust for conversation: {conversation_id_original}") + remote_webrtc_version = set_remote_description_and_parse_version( + tube_registry, commander_tube_id, answer_sdp, is_answer=True + ) + + if hasattr(tunnel_session, "gateway_ready_event") and tunnel_session.gateway_ready_event is not None: + tunnel_session.gateway_ready_event.set() + logging.debug(f"{bcolors.OKBLUE}Connection state: {bcolors.ENDC}SDP answer received, connecting...") + + if tunnel_session.buffered_ice_candidates: + logging.debug(f"Sending {len(tunnel_session.buffered_ice_candidates)} buffered ICE candidates after answer") + for candidate in tunnel_session.buffered_ice_candidates: + signal_handler._send_ice_candidate_immediately(candidate, commander_tube_id) + tunnel_session.buffered_ice_candidates.clear() + elif isinstance(data_json, dict) and ("offer" in data_json or data_json.get("type") == "offer"): + logging.warning(f"Received ICE restart offer in non-streaming mode - this is unexpected") else: logging.warning(f"Decryption returned None for conversation {conversation_id_original}") except Exception as e: @@ -1506,9 +1814,20 @@ def _open_terminal_webrtc_tunnel(params: KeeperParams, "screen_info": screen_info, "python_handler": python_handler, # PythonHandler for simplified guac protocol "use_python_handler": use_python_handler, + "user_record_uid": user_record_uid, # For ConnectAs payload + "gateway_uid": gateway_uid, # For ConnectAs payload + "remote_webrtc_version": remote_webrtc_version, # From SDP for ConnectAs capability } except Exception as e: + # Stop dedicated WebSocket before Rust/tube cleanup so we do not process a late + # channel_closed after the CLI has already returned (avoids stray ERROR after prompt). + try: + if tunnel_session.websocket_stop_event and tunnel_session.websocket_thread: + tunnel_session.websocket_stop_event.set() + tunnel_session.websocket_thread.join(timeout=3.0) + except Exception: + logging.debug("Stopping WebSocket after HTTP offer failure", exc_info=True) signal_handler.cleanup() unregister_tunnel_session(commander_tube_id) unregister_conversation_key(conversation_id) @@ -1559,8 +1878,12 @@ def launch_terminal_connection(params: KeeperParams, try: # Step 1: Detect protocol protocol = detect_protocol(params, record_uid) - if not protocol: - raise CommandError('pam launch', f'Could not detect protocol for record {record_uid}') + if not protocol or protocol not in ALL_TERMINAL: + raise CommandError( + 'pam launch', + f'Protocol {protocol!r} is not supported for record {record_uid}. ' + 'Only terminal protocols (ssh, telnet, kubernetes, mysql, postgresql, sql-server) are supported.' + ) logging.debug(f"Detected protocol: {protocol}") @@ -1571,6 +1894,7 @@ def launch_terminal_connection(params: KeeperParams, protocol, launch_credential_uid=kwargs.get('launch_credential_uid'), custom_host=kwargs.get('custom_host'), + custom_port=kwargs.get('custom_port'), ) logging.debug(f"Extracted settings: hostname={settings['hostname']}, port={settings['port']}") @@ -1622,5 +1946,3 @@ def launch_terminal_connection(params: KeeperParams, except Exception as e: logging.error(f"Error launching terminal connection: {e}") raise CommandError('pam launch', f'Failed to launch terminal connection: {e}') - - diff --git a/keepercommander/commands/pedm/pedm_admin.py b/keepercommander/commands/pedm/pedm_admin.py index 2725dbd45..6365cbedf 100644 --- a/keepercommander/commands/pedm/pedm_admin.py +++ b/keepercommander/commands/pedm/pedm_admin.py @@ -681,6 +681,8 @@ def execute(self, context: KeeperParams, **kwargs): if isinstance(status, admin_types.EntityStatus) and not status.success: raise base.CommandError(f'Failed to update policy "{status.entity_uid}": {status.message}') + utils.get_logger().info('Successfully updated deployment: %s', deployment.name or deployment.deployment_uid) + class PedmDeploymentDeleteCommand(base.ArgparseCommand): def __init__(self): @@ -745,9 +747,12 @@ def execute(self, context: KeeperParams, **kwargs) -> Optional[str]: token = f'{host}:{deployment.deployment_uid}:{utils.base64_url_encode(deployment.private_key)}' filename = kwargs.get('file') if filename: + if os.path.isdir(filename): + raise base.CommandError(f'"{filename}" is a directory. Please provide a full file path, e.g. "{os.path.join(filename, "deployment-token.txt")}"') with open(filename, 'wt') as f: f.write(token) - return None + utils.get_logger().info('Deployment token saved to: %s', os.path.abspath(filename)) + return None if not kwargs.get('verbose'): return token @@ -882,11 +887,24 @@ def execute(self, context: KeeperParams, **kwargs) -> Any: if len(agent_uid_list) == 0: return - statuses = plugin.modify_agents( remove_agents=agent_uid_list) + force = kwargs.get('force') is True + if not force: + answer = prompt_utils.user_choice(f'Do you want to delete {len(agent_uid_list)} agent(s)?', 'yN') + if answer.lower() not in {'y', 'yes'}: + return + + statuses = plugin.modify_agents(remove_agents=agent_uid_list) + deleted_count = 0 if isinstance(statuses.remove, list): for status in statuses.remove: - if isinstance(status, admin_types.EntityStatus) and not status.success: - utils.get_logger().warning(f'Failed to remove agent "{status.entity_uid}": {status.message}') + if isinstance(status, admin_types.EntityStatus): + if status.success: + deleted_count += 1 + utils.get_logger().info('Agent "%s" deleted successfully.', status.entity_uid) + else: + utils.get_logger().warning(f'Failed to remove agent "{status.entity_uid}": {status.message}') + if deleted_count > 0: + utils.get_logger().info('%d agent(s) deleted successfully.', deleted_count) class PedmAgentEditCommand(base.ArgparseCommand): @@ -905,9 +923,8 @@ def execute(self, context: KeeperParams, **kwargs) -> Any: deployment_uid = kwargs.get('deployment') if deployment_uid: - deployment = plugin.deployments.get_entity(deployment_uid) - if not deployment: - raise base.CommandError(f'Deployment "{deployment_uid}" does not exist') + deployment = PedmUtils.resolve_single_deployment(plugin, deployment_uid) + deployment_uid = deployment.deployment_uid else: deployment_uid = None @@ -939,8 +956,11 @@ def execute(self, context: KeeperParams, **kwargs) -> Any: statuses = plugin.modify_agents(update_agents=update_agents) if isinstance(statuses.update, list): for status in statuses.update: - if isinstance(status, admin_types.EntityStatus) and not status.success: - utils.get_logger().warning(f'Failed to update agent "{status.entity_uid}": {status.message}') + if isinstance(status, admin_types.EntityStatus): + if status.success: + utils.get_logger().info(f'Agent "{status.entity_uid}" updated successfully.') + else: + utils.get_logger().warning(f'Failed to update agent "{status.entity_uid}": {status.message}') class PedmAgentListCommand(base.ArgparseCommand): @@ -1236,12 +1256,6 @@ def get_policy_controls(policy_type_name: str, **kwargs) -> Optional[List[str]]: def get_policy_filter(plugin: admin_plugin.PedmPlugin, **kwargs) -> Dict[str, Any]: policy_filter: Dict[str, Any] = {} for f in PedmPolicyMixin.ALL_FILTERS: - arg_name = f'{f.lower()}_filter' - p_filter: Any = kwargs.get(arg_name) - if not p_filter: continue - if isinstance(p_filter, str): - p_filter = [p_filter] - if f == 'USER': filter_name = 'UserCheck' elif f == 'MACHINE': @@ -1256,21 +1270,31 @@ def get_policy_filter(plugin: admin_plugin.PedmPlugin, **kwargs) -> Dict[str, An filter_name = 'DayCheck' else: continue - if '*' in p_filter: - policy_filter[filter_name] = ['*'] + + arg_name = f'{f.lower()}_filter' + p_filter: Any = kwargs.get(arg_name) + if p_filter: + if isinstance(p_filter, str): + p_filter = [p_filter] + if '*' in p_filter: + policy_filter[filter_name] = ['*'] + else: + if f == 'USER': + policy_filter[filter_name] = PedmPolicyAddCommand.resolve_collections(plugin, [3, 6, 103], p_filter) + elif f == 'MACHINE': + policy_filter[filter_name] = PedmPolicyAddCommand.resolve_collections(plugin, [1, 101], p_filter) + elif f == 'APP': + policy_filter[filter_name] = PedmPolicyAddCommand.resolve_collections(plugin, [2, 102], p_filter) + elif f == 'DATE': + policy_filter[filter_name] = PedmPolicyAddCommand.resolve_dates(p_filter) + elif f == 'TIME': + policy_filter[filter_name] = PedmPolicyAddCommand.resolve_times(p_filter) + elif f == 'DAY': + policy_filter[filter_name] = PedmPolicyAddCommand.resolve_days(p_filter) else: - if f == 'USER': - policy_filter[filter_name] = PedmPolicyAddCommand.resolve_collections(plugin, [3, 6, 103], p_filter) - elif f == 'MACHINE': - policy_filter[filter_name] = PedmPolicyAddCommand.resolve_collections(plugin, [1, 101], p_filter) - elif f == 'APP': - policy_filter[filter_name] = PedmPolicyAddCommand.resolve_collections(plugin, [2, 102], p_filter) - elif f == 'DATE': - policy_filter[filter_name] = PedmPolicyAddCommand.resolve_dates(p_filter) - elif f == 'TIME': - policy_filter[filter_name] = PedmPolicyAddCommand.resolve_times(p_filter) - elif f == 'DAY': - policy_filter[filter_name] = PedmPolicyAddCommand.resolve_days(p_filter) + if filter_name not in policy_filter: + policy_filter[filter_name] = [] + risk_level = kwargs.get('risk_level') if isinstance(risk_level, int): if risk_level < 0 or risk_level > 100: @@ -2157,10 +2181,19 @@ def execute(self, context: KeeperParams, **kwargs) -> Any: approval.expire_in ) + fmt = kwargs.get('format') + justification = approval.justification + if fmt != 'json' and isinstance(justification, str): + try: + parsed = json.loads(justification) + if isinstance(parsed, dict): + justification = parsed.get('text', justification) + except (json.JSONDecodeError, ValueError): + pass + row = [approval.approval_uid, approval_type, approval_status, approval.agent_uid, approval.account_info, - approval.application_info, approval.justification, approval.expire_in, approval.created] + approval.application_info, justification, approval.expire_in, approval.created] - fmt = kwargs.get('format') if fmt == 'json': table = [row] else: @@ -2186,6 +2219,7 @@ def execute(self, context: KeeperParams, **kwargs) -> Any: approval_type = approval_type.lower() else: approval_type = None + fmt = kwargs.get('format') table: List[List[Any]] = [] headers = ['approval_uid', 'approval_type', 'status', 'agent_uid', 'account_info', 'application_info', 'justification', 'expire_in', 'created'] for approval in plugin.approvals.get_all_entities(): @@ -2201,12 +2235,19 @@ def execute(self, context: KeeperParams, **kwargs) -> Any: account_info = [y[:30] for y in (f'{k}={v}' for k, v in approval.account_info.items())] application_info = [y[:30] for y in (f'{k}={v}' for k, v in approval.application_info.items())] + justification = approval.justification + if fmt != 'json' and isinstance(justification, str): + try: + parsed = json.loads(justification) + if isinstance(parsed, dict): + justification = parsed.get('text', justification) + except (json.JSONDecodeError, ValueError): + pass table.append([approval.approval_uid, pedm_shared.approval_type_to_name(approval.approval_type), - status, approval.agent_uid, account_info, application_info, approval.justification, + status, approval.agent_uid, account_info, application_info, justification, approval.expire_in, approval.created]) table.sort(key=lambda x: x[8], reverse=True) - fmt = kwargs.get('format') if fmt != 'json': headers = [report_utils.field_to_title(x) for x in headers] return report_utils.dump_report_data(table, headers, fmt=fmt, filename=kwargs.get('output')) diff --git a/keepercommander/commands/pedm/pedm_aram.py b/keepercommander/commands/pedm/pedm_aram.py index b89648514..898d758f8 100644 --- a/keepercommander/commands/pedm/pedm_aram.py +++ b/keepercommander/commands/pedm/pedm_aram.py @@ -25,6 +25,7 @@ def __init__(self): self.register_command_new(PedmColumnReportCommand(), 'column', 'c') self.register_command_new(PedmEventReportCommand(), 'event', 'e') self.register_command_new(PedmEventSummaryReportCommand(), 'summary', 's') + self.register_command_new(PedmValueReportCommand(), 'value', 'v') @dataclass @@ -715,4 +716,51 @@ def execute(self, context: KeeperParams, **kwargs) -> Any: agent_uid = utils.base64_url_encode(rs.agentUid[i]) if i < len(rs.agentUid) else '' rows.append([policy_uid, agent_uid]) - return report_utils.dump_report_data(rows, headers, fmt=kwargs.get('format'), filename=kwargs.get('output')) \ No newline at end of file + return report_utils.dump_report_data(rows, headers, fmt=kwargs.get('format'), filename=kwargs.get('output')) + + +class PedmValueReportCommand(base.ArgparseCommand): + def __init__(self): + parser = argparse.ArgumentParser(prog='report value', description='Look up audit event values by UID', + parents=[base.report_output_parser]) + parser.add_argument('uid', nargs='+', help='Value UID') + super().__init__(parser) + + def execute(self, context: KeeperParams, **kwargs) -> Any: + enterprise = context.enterprise + tree_key = enterprise['unencrypted_tree_key'] + encrypted_ec_private_key = utils.base64_url_decode(enterprise['keys']['ecc_encrypted_private_key']) + ec_private_key = crypto.load_ec_private_key(crypto.decrypt_aes_v2(encrypted_ec_private_key, tree_key)) + + uids = kwargs.get('uid') + if not isinstance(uids, list): + uids = [str(uids)] + + coll_rq = pedm_pb2.AuditCollectionRequest() + coll_rq.valueUid.extend([utils.base64_url_decode(x) for x in uids]) + coll_rs = api.execute_router(context, + 'pedm/get_audit_collections', coll_rq, rs_type=pedm_pb2.AuditCollectionResponse) + assert coll_rs is not None + + results: List[Dict[str, Any]] = [] + for v in coll_rs.values: + value_uid = utils.base64_url_encode(v.valueUid) + field_name = v.collectionName + try: + decrypted = crypto.decrypt_ec(v.encryptedData, ec_private_key).decode('utf-8') + except Exception: + decrypted = '' + results.append({'uid': value_uid, 'field': field_name, 'value': decrypted}) + + if kwargs.get('format') == 'json': + for r in results: + try: + r['value'] = json.loads(r['value']) + except (json.JSONDecodeError, TypeError): + pass + return json.dumps(results, indent=2) + + rows = [[r['uid'], r['field'], r['value']] for r in results] + headers = [report_utils.field_to_title(x) for x in ('uid', 'field', 'value')] + return report_utils.dump_report_data(rows, headers, fmt=kwargs.get('format'), filename=kwargs.get('output'), + row_number=True) \ No newline at end of file diff --git a/keepercommander/commands/record.py b/keepercommander/commands/record.py index 87d863f85..829a5f5fa 100644 --- a/keepercommander/commands/record.py +++ b/keepercommander/commands/record.py @@ -277,9 +277,11 @@ def execute(self, params, **kwargs): admins = api.get_share_admins_for_shared_folder(params, uid) sf = api.get_shared_folder(params, uid) if fmt == 'json': + path = get_folder_path(params, sf.shared_folder_uid, delimiter=os.sep) if sf.shared_folder_uid else '' sfo = { "shared_folder_uid": sf.shared_folder_uid, "name": sf.name, + "path": path, "manage_users": sf.default_manage_users, "manage_records": sf.default_manage_records, "can_edit": sf.default_can_edit, @@ -291,17 +293,25 @@ def execute(self, params, **kwargs): 'can_edit': r['can_edit'], 'can_share': r['can_share'] } for r in sf.records] + def _format_expiration(expiration_value): + if expiration_value is None or expiration_value <= 0: + return 'never' + return datetime.datetime.fromtimestamp(expiration_value // 1000).isoformat() if sf.users: sfo['users'] = [{ 'username': u['username'], + 'user_id': u.get('account_uid'), 'manage_records': u['manage_records'], - 'manage_users': u['manage_users'] + 'manage_users': u['manage_users'], + 'expiration': _format_expiration(u.get('expiration')) } for u in sf.users] if sf.teams: sfo['teams'] = [{ 'name': t['name'], + 'team_uid': t.get('team_uid'), 'manage_records': t['manage_records'], - 'manage_users': t['manage_users'] + 'manage_users': t['manage_users'], + 'expiration': _format_expiration(t.get('expiration')) } for t in sf.teams] if admins: diff --git a/keepercommander/commands/record_edit.py b/keepercommander/commands/record_edit.py index 87676deff..bb94eb341 100644 --- a/keepercommander/commands/record_edit.py +++ b/keepercommander/commands/record_edit.py @@ -848,9 +848,9 @@ def execute(self, params, **kwargs): record_fields.append(parsed_field) if record_type in ('legacy', 'general'): - # raise CommandError('record-add', 'Legacy record type is not supported anymore.') - record = vault.PasswordRecord() - self.assign_legacy_fields(record, record_fields) + raise CommandError('record-add', 'Legacy record type is not supported.') + # record = vault.PasswordRecord() + # self.assign_legacy_fields(record, record_fields) else: rt_fields = self.get_record_type_fields(params, record_type) if not rt_fields: @@ -1244,7 +1244,8 @@ def execute(self, params, **kwargs): record_fields.append(parsed_field) if isinstance(record, vault.PasswordRecord): - self.assign_legacy_fields(record, record_fields) + raise CommandError('record-update', 'Legacy record type is not supported. Convert the record to login record type.') + # self.assign_legacy_fields(record, record_fields) elif isinstance(record, vault.TypedRecord): record_type = kwargs.get('record_type') if record_type: diff --git a/keepercommander/commands/register.py b/keepercommander/commands/register.py index 1c1f5ed7b..7cfc9fefd 100644 --- a/keepercommander/commands/register.py +++ b/keepercommander/commands/register.py @@ -1240,6 +1240,8 @@ def get_sf_shares(): aram_enabled = True if shared_records: headers = ['record_owner', 'record_uid', 'record_title', 'shared_with', 'folder_path'] + if include_share_date: + headers.append('share_date') table = [] for uid, shared_record in shared_records.items(): share_events = include_share_date and aram_enabled and self.get_record_share_activities(params, uid) @@ -1270,7 +1272,10 @@ def get_sf_shares(): share_info = '\n'.join(share_info) - table.append([shared_record.owner, shared_record.uid, shared_record.name, share_info, folder_paths]) + row = [shared_record.owner, shared_record.uid, shared_record.name, share_info, folder_paths] + if include_share_date: + row.append(self.get_record_share_date(share_events)) + table.append(row) if output_format != 'json': headers = [field_to_title(x) for x in headers] return dump_report_data( @@ -1400,6 +1405,18 @@ def get_date_for_share_folder_record(share_activity_list: list, shared_folder_ui return '\t(shared on {0})'.format(date_formatted) + @staticmethod + def get_record_share_date(share_events): + if not share_events: + return '' + share_types = ('share', 'record_share_outside_user', 'folder_add_record') + relevant = [e for e in share_events if e.get('audit_event_type') in share_types] + if not relevant: + return '' + earliest = min(relevant, key=lambda e: e['created']) + created = earliest['created'] + return time.strftime('%Y-%m-%d %H:%M:%S UTC', time.gmtime(created // 1000)) + class RecordPermissionCommand(Command): def get_parser(self): diff --git a/keepercommander/commands/scim.py b/keepercommander/commands/scim.py index 9fd6a9bfd..4a379fab7 100644 --- a/keepercommander/commands/scim.py +++ b/keepercommander/commands/scim.py @@ -152,12 +152,22 @@ def execute(self, params, node=None, **kwargs): api.communicate(params, rq) api.query_enterprise(params) + scim_url = get_scim_url(params, matched_node['node_id']) + node_id = matched_node['node_id'] logging.info('') logging.info('SCIM ID: %d', rq['scim_id']) - logging.info('SCIM URL: %s', get_scim_url(params, matched_node['node_id'])) + logging.info('SCIM URL: %s', scim_url) logging.info('Provisioning Token: %s', token) logging.info('') - return token + return { + 'scim_id': rq['scim_id'], + 'scim_url': scim_url, + 'provisioning_token': token, + 'node_name': self.get_node_path(params, node_id), + 'node_id': node_id, + 'prefix': prefix or '', + 'unique_groups': kwargs.get('unique_groups', '') == 'on', + } def find_scim(param, name): # type: (KeeperParams, any) -> dict @@ -243,12 +253,23 @@ def execute(self, params, target=None, **kwargs): api.communicate(params, rq) api.query_enterprise(params) + node_id = scim['node_id'] + scim_url = get_scim_url(params, node_id) + updated = find_scim(params, str(scim['scim_id'])) logging.info('') logging.info('SCIM ID: %d', scim['scim_id']) - logging.info('SCIM URL: %s', get_scim_url(params, scim['node_id'])) + logging.info('SCIM URL: %s', scim_url) logging.info('Provisioning Token: %s', token) logging.info('') - return token + return { + 'scim_id': scim['scim_id'], + 'scim_url': scim_url, + 'provisioning_token': token, + 'node_name': self.get_node_path(params, node_id), + 'node_id': node_id, + 'prefix': updated.get('role_prefix') or '', + 'unique_groups': updated.get('unique_groups', False), + } class ScimDeleteCommand(EnterpriseCommand): diff --git a/keepercommander/commands/security_audit.py b/keepercommander/commands/security_audit.py index 7b0547ab6..225f9a4a1 100644 --- a/keepercommander/commands/security_audit.py +++ b/keepercommander/commands/security_audit.py @@ -228,7 +228,12 @@ def get_node_id(name_or_id): nodes = kwargs.get('node') or [] node_ids = [get_node_id(n) for n in nodes] - node_ids = [n for n in node_ids if n] + if nodes: + invalid_nodes = [n for n, nid in zip(nodes, node_ids) if not nid] + if invalid_nodes: + logging.error('Node(s) not found: %s', ', '.join(invalid_nodes)) + return + node_ids = [n for n in node_ids if n] score_type = kwargs.get('score_type', 'default') save_report = kwargs.get('save') or attempt_fix show_updated = save_report or kwargs.get('show_updated') diff --git a/keepercommander/commands/ssh_agent.py b/keepercommander/commands/ssh_agent.py index 74fdfb009..d7d53691d 100644 --- a/keepercommander/commands/ssh_agent.py +++ b/keepercommander/commands/ssh_agent.py @@ -179,13 +179,59 @@ def is_private_key_name(name): # type: (str) -> bool return False KEY_SIZE_MIN = 119 # Smallest possible size for ed25519 private key in PKCS#8 format +# PEM bodies for large RSA keys (8192+) exceed 4K; keep a generous cap for sanity. KEY_SIZE_MAX = 4000 +PEM_TEXT_MAX = 256 * 1024 + + +def _normalize_typed_field_label(label): + # type: (Any) -> str + if not label or not isinstance(label, str): + return '' + return ''.join(c.lower() for c in label if c.isalnum()) + + +# Vault/PAM pamUser often stores PEM in a secret field labeled privatePEMKey (or similar). +_PEM_SECRET_FIELD_LABELS = frozenset({ + 'privatepemkey', + 'sshprivatekey', + 'sshkeypem', +}) + + +def _coerce_str_field_value(value): + # type: (Any) -> Optional[str] + if value is None: + return None + if isinstance(value, bytes): + try: + return value.decode('utf-8') + except Exception: + return None + if isinstance(value, str): + return value + return None + + def is_valid_key_value(value): return isinstance(value, str) and KEY_SIZE_MIN <= len(value) < KEY_SIZE_MAX + +def _is_plausible_pem_private_key_blob(text): + # type: (Optional[str]) -> bool + text = _coerce_str_field_value(text) + if not text: + return False + text = text.strip() + if len(text) < KEY_SIZE_MIN or len(text) > PEM_TEXT_MAX: + return False + header, _, _ = text.partition('\n') + return bool(is_private_key(header)) + + def is_valid_key_file(file): try: - return KEY_SIZE_MIN <= file.size < KEY_SIZE_MAX + return KEY_SIZE_MIN <= file.size < PEM_TEXT_MAX except: return False @@ -211,31 +257,41 @@ def try_extract_private_key(params, record_or_uid): if key_pair: private_key = key_pair.get('privateKey') + # Explicit PEM secret fields (pamUser template: type secret, label privatePEMKey, etc.) + if not private_key and isinstance(record, vault.TypedRecord): + for fld in itertools.chain(record.fields, record.custom): + if _normalize_typed_field_label(getattr(fld, 'label', None)) not in _PEM_SECRET_FIELD_LABELS: + continue + candidate = _coerce_str_field_value(fld.get_default_value()) + if _is_plausible_pem_private_key_blob(candidate): + private_key = candidate.strip() + break + # check notes field if not private_key: if isinstance(record, (vault.PasswordRecord, vault.TypedRecord)): - if is_valid_key_value(record.notes): - header, _, _ = record.notes.partition('\n') - if is_private_key(header): - private_key = record.notes + notes = getattr(record, 'notes', None) + if _is_plausible_pem_private_key_blob(notes): + private_key = notes.strip() - # check custom fields + # check typed fields / custom (text, multiline, secret, note) if not private_key: if isinstance(record, vault.TypedRecord): - try_values = (x.get_default_value() for x in itertools.chain(record.fields, record.custom) if x.type in ('text', 'multiline', 'secret', 'note')) - for value in (x for x in try_values if x): - if is_valid_key_value(value): - header, _, _ = value.partition('\n') - if is_private_key(header): - private_key = value - break + for x in itertools.chain(record.fields, record.custom): + if x.type not in ('text', 'multiline', 'secret', 'note'): + continue + candidate = _coerce_str_field_value(x.get_default_value()) + if _is_plausible_pem_private_key_blob(candidate): + private_key = candidate.strip() + break elif isinstance(record, vault.PasswordRecord): - for value in (x.value for x in record.custom if x.value): - if is_valid_key_value(value): - header, _, _ = value.partition('\n') - if is_private_key(header): - private_key = value - break + for cf in record.custom: + if not cf.value: + continue + candidate = _coerce_str_field_value(cf.value[0] if isinstance(cf.value, list) and cf.value else cf.value) + if _is_plausible_pem_private_key_blob(candidate): + private_key = candidate.strip() + break # check for a single attachment if not private_key: diff --git a/keepercommander/commands/tunnel/port_forward/TunnelGraph.py b/keepercommander/commands/tunnel/port_forward/TunnelGraph.py index 8f86d83b4..2facf8705 100644 --- a/keepercommander/commands/tunnel/port_forward/TunnelGraph.py +++ b/keepercommander/commands/tunnel/port_forward/TunnelGraph.py @@ -458,6 +458,39 @@ def check_if_resource_has_launch_credential(self, resource_uid): return user_vertex.uid return False + def clear_launch_credential_for_resource(self, resource_uid, exclude_user_uid=None): + """Remove is_launch_credential from all users on a resource except exclude_user_uid.""" + resource_vertex = self.linking_dag.get_vertex(resource_uid) + if resource_vertex is None: + return + dirty = False + for user_vertex in resource_vertex.has_vertices(EdgeType.ACL): + if exclude_user_uid and user_vertex.uid == exclude_user_uid: + continue + acl_edge = user_vertex.get_edge(resource_vertex, EdgeType.ACL) + if not acl_edge: + continue + edge_content = acl_edge.content_as_dict + if edge_content and edge_content.get('is_launch_credential'): + edge_content = dict(edge_content) + edge_content.pop('is_launch_credential') + user_vertex.belongs_to(resource_vertex, EdgeType.ACL, content=edge_content) + dirty = True + if dirty: + self.linking_dag.save() + + def upgrade_resource_meta_to_v1(self, resource_uid): + """Ensure resource vertex meta has version >= 1 so vault reads ACL launch credentials.""" + resource_vertex = self.linking_dag.get_vertex(resource_uid) + if resource_vertex is None: + return + content = get_vertex_content(resource_vertex) + if content and content.get('version', 0) >= RESOURCE_META_VERSION_V1: + return + upgraded = ensure_resource_meta_v1(content) + resource_vertex.add_data(content=upgraded, path='meta', needs_encryption=False) + self.linking_dag.save() + def check_if_resource_allowed(self, resource_uid, setting): resource_vertex = self.linking_dag.get_vertex(resource_uid) content = get_vertex_content(resource_vertex) diff --git a/keepercommander/commands/tunnel/port_forward/tunnel_helpers.py b/keepercommander/commands/tunnel/port_forward/tunnel_helpers.py index 0e67930e7..7fc378f45 100644 --- a/keepercommander/commands/tunnel/port_forward/tunnel_helpers.py +++ b/keepercommander/commands/tunnel/port_forward/tunnel_helpers.py @@ -2,7 +2,9 @@ import enum import json import logging +import re import os +import threading import secrets import socket import string @@ -52,6 +54,54 @@ WEBSOCKETS_VERSION = None print("websockets library not available - install with: pip install websockets", file=sys.stderr) +# Regex for SDP attribute: a=keeper-webrtc:X.Y.Z (injected by keeper-pam-webrtc-rs) +_KEEPER_WEBRTC_VERSION_RE = re.compile(r"a=keeper-webrtc:(\S+)", re.IGNORECASE) + + +def parse_keeper_webrtc_version_from_sdp(sdp): + """ + Parse keeper-pam-webrtc-rs version from SDP attribute a=keeper-webrtc:X.Y.Z. + + The attribute is injected by the Rust module in both offer and answer. + Handles SDP that may be base64-encoded. + + Args: + sdp: SDP string (plain or base64-encoded). + + Returns: + Version string (e.g. "2.1.4") or None if not found. + """ + if not sdp or not isinstance(sdp, str): + return None + text = sdp + if "\n" not in sdp and "\r" not in sdp and len(sdp) > 20: + try: + decoded = base64.b64decode(sdp, validate=True) + text = decoded.decode("utf-8", errors="replace") + except Exception: + pass + m = _KEEPER_WEBRTC_VERSION_RE.search(text) + return m.group(1) if m else None + + +def set_remote_description_and_parse_version(tube_registry, tube_id, sdp, is_answer): + """ + Call tube_registry.set_remote_description and parse/store remote keeper-pam-webrtc + version when is_answer=True. Ensures version is always parsed regardless of which + code path delivered the SDP (WebSocket, HTTP, different JSON keys). + Returns the parsed version or None. + """ + tube_registry.set_remote_description(tube_id, sdp, is_answer=is_answer) + remote_ver = None + if is_answer: + remote_ver = parse_keeper_webrtc_version_from_sdp(sdp) + session = get_tunnel_session(tube_id) + if session and remote_ver: + session.remote_webrtc_version = remote_ver + logging.debug("Remote keeper-pam-webrtc version from SDP: %s", remote_ver) + return remote_ver + + # Constants NONCE_LENGTH = 12 MAIN_NONCE_LENGTH = 16 @@ -64,7 +114,6 @@ # ICE candidate buffering - store until SDP answer is received # Global conversation key management for multiple concurrent tunnels -import threading _CONVERSATION_KEYS_LOCK = threading.Lock() _GLOBAL_CONVERSATION_KEYS = {} # conversationId -> symmetric_key mapping @@ -334,7 +383,6 @@ def _configure_rust_logger_levels(current_is_debug: bool, log_level: int): # CRITICAL: Ensure root logger has a handler # pyo3_log sends Rust logs to Python loggers, but if loggers have no handlers, # messages are lost even if propagate=True - import sys if not root_logger.handlers: # Add a console handler if none exists console_handler = logging.StreamHandler(sys.stderr) @@ -655,8 +703,10 @@ def get_keeper_tokens(params): def get_config_uid_from_record(params, vault, record_uid): record = vault.KeeperRecord.load(params, record_uid) - if not isinstance(record, vault.TypedRecord): + if record is None: raise CommandError('', f"{bcolors.FAIL}Record {record_uid} not found.{bcolors.ENDC}") + if not isinstance(record, vault.TypedRecord): + raise CommandError('', f"{bcolors.FAIL}Record {record_uid} is not v3/typed record.{bcolors.ENDC}") record_type = record.record_type if record_type not in "pamMachine pamDatabase pamDirectory pamRemoteBrowser".split(): raise CommandError('', f"{bcolors.FAIL}This record's type is not supported for tunnels. " @@ -1135,55 +1185,45 @@ def route_message_to_rust(response_item, tube_registry): except (json.JSONDecodeError, TypeError): pass # Not a simple JSON string, continue with normal processing - data_json = json.loads(data_text) - - # Ensure data_json is a dictionary before processing - if not isinstance(data_json, dict): - logging.debug(f"Data is not a dictionary (got {type(data_json).__name__}), treating as acknowledgment: {data_json}") - return - - # Log what type of data we received - logging.debug(f"🔓 Decrypted payload type: {data_json.get('type', 'unknown')}, keys: {list(data_json.keys())}") - - if "answer" in data_json: - answer_sdp = data_json.get('answer') - - if answer_sdp: - logging.debug(f"Found SDP answer, sending to Rust for conversation: {conversation_id}") - # Send decrypted SDP answer to Rust - - # Try to find tube ID - gateway may have converted URL-safe base64 to standard - tube_id = tube_registry.tube_id_from_connection_id(conversation_id) - if not tube_id: - # Try URL-safe version (convert + to -, / to _, remove =) - url_safe_conversation_id = conversation_id.replace('+', '-').replace('/', '_').rstrip('=') - tube_id = tube_registry.tube_id_from_connection_id(url_safe_conversation_id) - if tube_id: - logging.debug(f"Found tube using URL-safe conversion: {url_safe_conversation_id}") - - if not tube_id: - logging.error(f"No tube ID found for conversation: {conversation_id} (also tried URL-safe version)") - return + try: + data_json = json.loads(data_text) + except (json.JSONDecodeError, TypeError): + data_json = None + + # Fallback: decrypted data may be raw SDP + answer_sdp = None + if isinstance(data_json, dict): + logging.debug(f"🔓 Decrypted payload type: {data_json.get('type', 'unknown')}, keys: {list(data_json.keys())}") + answer_sdp = data_json.get('answer') or data_json.get('sdp') + elif data_text.strip().startswith('v=') and 'm=' in data_text: + answer_sdp = data_text.strip() + logging.debug("Decrypted data appears to be raw SDP (not JSON), using as answer") + + if answer_sdp: + logging.debug(f"Found SDP answer, sending to Rust for conversation: {conversation_id}") + # Try to find tube ID - gateway may have converted URL-safe base64 to standard + tube_id = tube_registry.tube_id_from_connection_id(conversation_id) + if not tube_id: + url_safe_conversation_id = conversation_id.replace('+', '-').replace('/', '_').rstrip('=') + tube_id = tube_registry.tube_id_from_connection_id(url_safe_conversation_id) + if tube_id: + logging.debug(f"Found tube using URL-safe conversion: {url_safe_conversation_id}") - tube_registry.set_remote_description(tube_id, answer_sdp, is_answer=True) + if not tube_id: + logging.error(f"No tube ID found for conversation: {conversation_id} (also tried URL-safe version)") + else: + set_remote_description_and_parse_version(tube_registry, tube_id, answer_sdp, is_answer=True) logging.debug("Connection state: SDP answer received, connecting...") - # Send any buffered local ICE candidates now that we have the answer session = get_tunnel_session(tube_id) - if session: - # Send any buffered local ICE candidates now that we have the answer - if session.buffered_ice_candidates: - logging.debug(f"Sending {len(session.buffered_ice_candidates)} buffered ICE candidates after answer") - # Need to get the signal handler to send candidates - # Since we're in the routing function, we need to find the handler - # is stored in the session for this purpose - if hasattr(session, 'signal_handler') and session.signal_handler: - for candidate in session.buffered_ice_candidates: - session.signal_handler._send_ice_candidate_immediately(candidate, tube_id) - session.buffered_ice_candidates.clear() - else: - logging.warning(f"No signal handler found for tube {tube_id} to send buffered candidates") - elif "offer" in data_json or (data_json.get("type") == "offer"): + if session and session.buffered_ice_candidates: + if hasattr(session, 'signal_handler') and session.signal_handler: + for candidate in session.buffered_ice_candidates: + session.signal_handler._send_ice_candidate_immediately(candidate, tube_id) + session.buffered_ice_candidates.clear() + else: + logging.warning(f"No signal handler found for tube {tube_id} to send buffered candidates") + elif isinstance(data_json, dict) and ("offer" in data_json or data_json.get("type") == "offer"): # Gateway is sending us an ICE restart offer offer_sdp = data_json.get('sdp') or data_json.get('offer') @@ -1493,17 +1533,17 @@ def signal_from_rust(self, response: dict): # Detailed logging for specific states if new_state == 'disconnected': - logging.warning(f"Connection disconnected for tube {tube_id} - ICE restart may be attempted by Rust") + logging.debug(f"Connection disconnected for tube {tube_id} - ICE restart may be attempted by Rust") elif new_state == 'failed': - logging.error(f"Connection failed for tube {tube_id} - ICE restart may be attempted by Rust") + logging.debug(f"Connection failed for tube {tube_id} - ICE restart may be attempted by Rust") elif new_state == 'connected': logging.debug( f"Connection established/restored for tube {tube_id} " f"(conversation_id={conversation_id_from_signal or self.conversation_id})" ) - logging.debug(f"Connection state: connected") + logging.debug("Connection state: connected") # CRITICAL: Mark connection as connected - IMMEDIATELY stop sending ICE candidates self.connection_connected = True @@ -1610,9 +1650,9 @@ def signal_from_rust(self, response: dict): logging.debug(f"Stopping dedicated WebSocket for tunnel {tube_id}") session.websocket_stop_event.set() # Signal WebSocket to close # Give it a moment to close gracefully - session.websocket_thread.join(timeout=2.0) + session.websocket_thread.join(timeout=5.0) if session.websocket_thread.is_alive(): - logging.warning(f"Dedicated WebSocket for tunnel {tube_id} did not close in time") + logging.debug(f"Dedicated WebSocket for tunnel {tube_id} did not close in time") else: logging.debug(f"Dedicated WebSocket closed for tunnel {tube_id}") @@ -1637,9 +1677,9 @@ def signal_from_rust(self, response: dict): logging.debug(f"Stopping dedicated WebSocket for failed tunnel {tube_id}") session.websocket_stop_event.set() # Signal WebSocket to close # Give it a moment to close gracefully - session.websocket_thread.join(timeout=2.0) + session.websocket_thread.join(timeout=5.0) if session.websocket_thread.is_alive(): - logging.warning(f"Dedicated WebSocket for tunnel {tube_id} did not close in time") + logging.debug(f"Dedicated WebSocket for tunnel {tube_id} did not close in time") else: logging.debug(f"Dedicated WebSocket closed for failed tunnel {tube_id}") @@ -2376,13 +2416,14 @@ def start_rust_tunnel(params, record_uid, gateway_uid, host, port, decrypted_answer = tunnel_decrypt(symmetric_key, encrypted_answer) answer_data = json.loads(decrypted_answer) - if 'answer' in answer_data: - answer_sdp = answer_data['answer'] - logging.debug(f"Non-trickle ICE: Received SDP answer via HTTP, setting in Rust") - tube_registry.set_remote_description(commander_tube_id, answer_sdp, is_answer=True) - logging.debug("Non-trickle ICE: SDP answer set successfully") + if 'answer' in answer_data or 'sdp' in answer_data: + answer_sdp = answer_data.get('answer') or answer_data.get('sdp') + if answer_sdp: + logging.debug("Non-trickle ICE: Received SDP answer via HTTP, setting in Rust") + set_remote_description_and_parse_version(tube_registry, commander_tube_id, answer_sdp, is_answer=True) + logging.debug("Non-trickle ICE: SDP answer set successfully") else: - logging.error(f"Non-trickle ICE: No 'answer' field in decrypted data: {answer_data}") + logging.error(f"Non-trickle ICE: No 'answer' or 'sdp' field in decrypted data: {answer_data}") else: logging.error(f"Non-trickle ICE: No 'data' field in payload JSON: {payload_json}") else: diff --git a/keepercommander/commands/tunnel_and_connections.py b/keepercommander/commands/tunnel_and_connections.py index 48cd8b2dd..e664ba7ee 100644 --- a/keepercommander/commands/tunnel_and_connections.py +++ b/keepercommander/commands/tunnel_and_connections.py @@ -10,10 +10,17 @@ # import argparse +import datetime +import http.client import logging import os +import socket +import ssl +import struct import sys -from keeper_secrets_manager_core.utils import bytes_to_base64, base64_to_bytes +import time +from typing import List, Optional, Tuple +from keeper_secrets_manager_core.utils import bytes_to_base64, base64_to_bytes, url_safe_str_to_bytes from .base import Command, GroupCommand, dump_report_data, RecordMixin from .tunnel.port_forward.TunnelGraph import TunnelDAG @@ -650,12 +657,15 @@ def execute(self, params, **kwargs): class PAMTunnelDiagnoseCommand(Command): - pam_cmd_parser = argparse.ArgumentParser(prog='pam tunnel diagnose', - description='Diagnose network connectivity to krelay server. ' - 'Tests DNS resolution, TCP/UDP connectivity, AWS infrastructure, ' - 'and WebRTC peer connection setup for IT troubleshooting.') - pam_cmd_parser.add_argument('record', type=str, action='store', - help='The Record UID of the PAM resource record to test connectivity for') + # ── parser ──────────────────────────────────────────────────────────────── + pam_cmd_parser = argparse.ArgumentParser( + prog='pam tunnel diagnose', + description='Diagnose network connectivity for KeeperPAM. ' + 'When run without a record the command tests connectivity using only the ' + 'logged-in session (krelay server, HTTPS API, WebSocket, UDP port range). ' + 'When a record is supplied the full WebRTC peer connection test is also run.') + pam_cmd_parser.add_argument('record', type=str, nargs='?', default=None, + help='Optional: Record UID of a PAM resource for the full WebRTC peer connection test') pam_cmd_parser.add_argument('--timeout', '-t', required=False, dest='timeout', action='store', type=int, default=30, help='Test timeout in seconds (default: 30)') @@ -665,134 +675,657 @@ class PAMTunnelDiagnoseCommand(Command): choices=['table', 'json'], default='table', help='Output format: table (human-readable) or json (machine-readable)') pam_cmd_parser.add_argument('--test', required=False, dest='test_filter', action='store', - help='Comma-separated list of specific tests to run. Available: ' + help='Comma-separated list of specific WebRTC tests to run. Available: ' 'dns_resolution,aws_connectivity,tcp_connectivity,udp_binding,' 'ice_configuration,webrtc_peer_connection') def get_parser(self): return PAMTunnelDiagnoseCommand.pam_cmd_parser + # ── ANSI helpers ────────────────────────────────────────────────────────── + @staticmethod + def _use_color() -> bool: + return sys.stdout.isatty() and os.environ.get('NO_COLOR') is None + + @staticmethod + def _c(code: str, text: str) -> str: + return f'\033[{code}m{text}\033[0m' if PAMTunnelDiagnoseCommand._use_color() else text + + @classmethod + def _green(cls, t: str) -> str: return cls._c('92', t) + @classmethod + def _bright(cls, t: str) -> str: return cls._c('1;92', t) + @classmethod + def _dim(cls, t: str) -> str: return cls._c('2;32', t) + @classmethod + def _red(cls, t: str) -> str: return cls._c('1;91', t) + @classmethod + def _check(cls) -> str: return cls._bright('\u2713') + @classmethod + def _cross(cls) -> str: return cls._red('\u2717') + @classmethod + def _bullet(cls) -> str: return cls._bright('\u25ba') + @classmethod + def _sep(cls, w: int = 76) -> str: return cls._dim('\u2500' * w) + @classmethod + def _dsep(cls, w: int = 78) -> str: return cls._dim('\u2550' * w) + + # ── output helpers ──────────────────────────────────────────────────────── + _W = 80 # output width + + @classmethod + def _print_header(cls): + title = 'KeeperPAM \u00b7 Gateway Network Readiness Tester' + inner = cls._W - 2 + pad_l = (inner - len(title)) // 2 + pad_r = inner - len(title) - pad_l + print(cls._bright('\u2554' + '\u2550' * inner + '\u2557')) + print(cls._bright('\u2551' + ' ' * pad_l + title + ' ' * pad_r + '\u2551')) + print(cls._bright('\u255a' + '\u2550' * inner + '\u255d')) + + @classmethod + def _print_result(cls, name: str, passed: bool, detail: str, ms: int, indent: int = 4): + icon = cls._check() if passed else cls._cross() + ms_str = cls._dim(f' {ms}ms') + body = f'{cls._green(name)} \u00b7 {cls._green(detail)}' if detail else cls._green(name) + print(f'{" " * indent}{icon} {body}{ms_str}') + + # ── STUN ────────────────────────────────────────────────────────────────── + _MAGIC_COOKIE = 0x2112A442 + _STUN_PORT = 3478 + _UDP_SAMPLE_PORTS = [49152, 50000, 52000, 55000, 58000, 61000, 63000, 65535] + + @classmethod + def _stun_request(cls, msg_type: int = 0x0001) -> bytes: + return struct.pack('!HHI12s', msg_type, 0, cls._MAGIC_COOKIE, os.urandom(12)) + + @classmethod + def _recv_stun(cls, sock: socket.socket, timeout: float = 5.0) -> bytes: + buf = b'' + deadline = time.monotonic() + timeout + try: + if sock.type == socket.SOCK_DGRAM: + sock.settimeout(max(0.1, deadline - time.monotonic())) + buf, _ = sock.recvfrom(2048) + else: + while len(buf) < 20: + remaining = deadline - time.monotonic() + if remaining <= 0: + break + sock.settimeout(remaining) + chunk = sock.recv(2048) + if not chunk: + break + buf += chunk + except (socket.timeout, OSError): + pass + return buf + + @classmethod + def _parse_stun(cls, data: bytes) -> dict: + out: dict = {} + if len(data) < 20: + return out + msg_type, _msg_len, magic = struct.unpack('!HHI', data[:8]) + if magic != cls._MAGIC_COOKIE: + return out + msg_class = ((msg_type >> 7) & 0x2) | ((msg_type >> 4) & 0x1) + out['is_success'] = msg_class == 2 + out['is_error'] = msg_class == 3 + offset = 20 + while offset + 4 <= len(data): + attr_type, attr_len = struct.unpack('!HH', data[offset:offset + 4]) + offset += 4 + attr = data[offset:offset + attr_len] + if attr_type == 0x0020 and len(attr) >= 8 and attr[1] == 0x01: + xip = struct.unpack('!I', attr[4:8])[0] ^ cls._MAGIC_COOKIE + out['ext_ip'] = socket.inet_ntoa(struct.pack('!I', xip)) + elif attr_type == 0x0001 and len(attr) >= 8 and attr[1] == 0x01: + out.setdefault('ext_ip', socket.inet_ntoa(attr[4:8])) + elif attr_type == 0x0009 and len(attr) >= 4: + out['error_code'] = (attr[2] & 0x07) * 100 + attr[3] + offset += (attr_len + 3) & ~3 + return out + + # ── individual Python-side tests ────────────────────────────────────────── + @classmethod + def _test_https(cls, hostname: str, port: int = 443) -> Tuple[bool, str, int]: + """Returns (passed, detail, ms).""" + t0 = time.monotonic() + conn = None + try: + ctx = ssl.create_default_context() + conn = http.client.HTTPSConnection(hostname, port=port, context=ctx, timeout=10) + conn.request('GET', '/', headers={'User-Agent': 'keeper-pam-diagnose/1.0'}) + resp = conn.getresponse() + ms = int((time.monotonic() - t0) * 1000) + return 100 <= resp.status < 600, f'HTTP {resp.status} (reachable)', ms + except Exception as exc: + return False, str(exc)[:60], int((time.monotonic() - t0) * 1000) + finally: + if conn: + try: conn.close() + except Exception: pass + + @classmethod + def _test_websocket(cls, hostname: str, port: int = 443) -> Tuple[bool, str, int]: + """HTTP Upgrade probe — any 4xx means the server is reachable.""" + t0 = time.monotonic() + conn = None + try: + ctx = ssl.create_default_context() + conn = http.client.HTTPSConnection(hostname, port=port, context=ctx, timeout=10) + conn.request('GET', '/', headers={ + 'Upgrade': 'websocket', + 'Connection': 'Upgrade', + 'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==', + 'Sec-WebSocket-Version': '13', + 'User-Agent': 'keeper-pam-diagnose/1.0', + }) + resp = conn.getresponse() + ms = int((time.monotonic() - t0) * 1000) + return 100 <= resp.status < 600, f'HTTP {resp.status}', ms + except Exception as exc: + return False, str(exc)[:60], int((time.monotonic() - t0) * 1000) + finally: + if conn: + try: conn.close() + except Exception: pass + + @classmethod + def _test_tcp_stun(cls, hostname: str) -> Tuple[bool, str, int, Optional[str]]: + """Returns (passed, detail, ms, ext_ip).""" + t0 = time.monotonic() + sock = None + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(10) + sock.connect((hostname, cls._STUN_PORT)) + sock.sendall(cls._stun_request(0x0001)) + parsed = cls._parse_stun(cls._recv_stun(sock)) + ms = int((time.monotonic() - t0) * 1000) + ext_ip = parsed.get('ext_ip') + detail = f'external IP {ext_ip}' if ext_ip else 'TCP connected' + return True, detail, ms, ext_ip + except Exception as exc: + return False, str(exc)[:60], int((time.monotonic() - t0) * 1000), None + finally: + if sock: + try: sock.close() + except Exception: pass + + @classmethod + def _test_udp_stun(cls, hostname: str) -> Tuple[bool, str, int, Optional[str]]: + """Returns (passed, detail, ms, ext_ip).""" + t0 = time.monotonic() + sock = None + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.settimeout(5) + sock.sendto(cls._stun_request(0x0001), (hostname, cls._STUN_PORT)) + parsed = cls._parse_stun(cls._recv_stun(sock)) + ms = int((time.monotonic() - t0) * 1000) + ext_ip = parsed.get('ext_ip') + if ext_ip: + return True, f'external IP {ext_ip}', ms, ext_ip + return False, 'no STUN response', ms, None + except Exception as exc: + return False, str(exc)[:60], int((time.monotonic() - t0) * 1000), None + finally: + if sock: + try: sock.close() + except Exception: pass + + @classmethod + def _test_turn(cls, hostname: str) -> Tuple[bool, str, int]: + """Send unauthenticated TURN Allocate; expect 401 = reachable.""" + t0 = time.monotonic() + sock = None + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(10) + sock.connect((hostname, cls._STUN_PORT)) + sock.sendall(cls._stun_request(0x0003)) # Allocate + parsed = cls._parse_stun(cls._recv_stun(sock)) + ms = int((time.monotonic() - t0) * 1000) + if parsed.get('is_error') and parsed.get('error_code') == 401: + detail = 'reachable \u00b7 auth required' + elif parsed.get('is_success') or parsed.get('is_error'): + detail = 'reachable' + else: + detail = 'TCP connected' + return True, detail, ms + except Exception as exc: + return False, str(exc)[:60], int((time.monotonic() - t0) * 1000) + finally: + if sock: + try: sock.close() + except Exception: pass + + @classmethod + def _test_udp_port(cls, ip: str, port: int, timeout: float = 1.5) -> Tuple[bool, int]: + """Returns (reachable, ms). Timeout = no ICMP unreachable = assumed open.""" + import errno as _errno + t0 = time.monotonic() + sock = None + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.settimeout(timeout) + sock.sendto(cls._stun_request(0x0001), (ip, port)) + try: + sock.recvfrom(512) + except socket.timeout: + pass # no response = assume open + except OSError as oserr: + ms = int((time.monotonic() - t0) * 1000) + if oserr.errno in (_errno.ECONNREFUSED, _errno.ENETUNREACH, _errno.EHOSTUNREACH): + return False, ms + return True, int((time.monotonic() - t0) * 1000) + except Exception: + return False, int((time.monotonic() - t0) * 1000) + finally: + if sock: + try: sock.close() + except Exception: pass + + # ── execute ─────────────────────────────────────────────────────────────── def execute(self, params, **kwargs): record_name = kwargs.get('record') timeout = kwargs.get('timeout', 30) verbose = kwargs.get('verbose', False) output_format = kwargs.get('format', 'table') test_filter = kwargs.get('test_filter') - - if not record_name: - raise CommandError('pam tunnel diagnose', '"record" parameter is required.') - # Check for Rust WebRTC library availability - # Logger initialization is handled by get_or_create_tube_registry() - tube_registry = get_or_create_tube_registry(params) - if not tube_registry: - print(f"{bcolors.FAIL}This command requires the Rust WebRTC library (keeper_pam_webrtc_rs).{bcolors.ENDC}") - print(f"{bcolors.OKBLUE}Please ensure the keeper_pam_webrtc_rs module is installed and available.{bcolors.ENDC}") - return 1 + server = params.server # e.g. "keepersecurity.com" + krelay_server = os.environ.get('KRELAY_URL') or f'krelay.{server}' + connect_host = f'connect.{server}' + + # ── header ──────────────────────────────────────────────────────────── + self._print_header() + print() + now = datetime.datetime.utcnow() + region_label = 'US' if server == 'keepersecurity.com' else server.split('.')[0].upper() + print(self._green(f' Region {region_label} \u00b7 {server}')) + print(self._green(f' Date {now.strftime("%Y-%m-%d %H:%M")} UTC')) + if record_name: + print(self._green(f' Record {record_name}')) + print() + + t_overall_start = time.monotonic() + all_passed: List[bool] = [] + blocked_names: List[str] = [] + public_ip: Optional[str] = None + + def _record(name: str, passed: bool, detail: str, ms: int): + all_passed.append(passed) + if not passed: + blocked_names.append(name) + self._print_result(name, passed, detail, ms) + + # ── section 1: DNS & cloud connectivity ─────────────────────────────── + print(f'{self._bullet()} {self._bright("DNS & Cloud Connectivity")}') + print(f' {self._sep()}') + + # DNS + t0 = time.monotonic() + try: + infos = socket.getaddrinfo(server, None, socket.AF_INET) + ips = list(dict.fromkeys(a[4][0] for a in infos)) + ms = int((time.monotonic() - t0) * 1000) + extra = f'(+{len(ips) - 1} addr)' if len(ips) > 1 else '' + _record(f'DNS {server}', True, f'\u2192 {ips[0]} {extra}'.strip(), ms) + except Exception as exc: + _record(f'DNS {server}', False, str(exc)[:60], int((time.monotonic() - t0) * 1000)) - # Resolve and validate the record - api.sync_down(params) - record = RecordMixin.resolve_single_record(params, record_name) - if not record: - print(f"{bcolors.FAIL}Record '{record_name}' not found.{bcolors.ENDC}") - return 1 - if not isinstance(record, vault.TypedRecord): - print(f"{bcolors.FAIL}Record '{record_name}' cannot be used for tunneling.{bcolors.ENDC}") - return 1 + passed, detail, ms = self._test_https(server) + _record(f'HTTPS {server}:443', passed, detail, ms) - record_uid = record.record_uid - record_type = record.record_type - if record_type not in ("pamMachine pamDatabase pamDirectory pamNetworkConfiguration pamAwsConfiguration " - "pamRemoteBrowser pamAzureConfiguration").split(): - print(f"{bcolors.FAIL}Record type '{record_type}' is not supported for tunneling.{bcolors.ENDC}") - print(f"Supported types: pamMachine, pamDatabase, pamDirectory, pamRemoteBrowser, " - f"pamNetworkConfiguration, pamAwsConfiguration, pamAzureConfiguration") - return 1 + passed, detail, ms = self._test_websocket(connect_host) + _record(f'WebSocket {connect_host}:443', passed, detail, ms) + + print() + + # ── section 2: STUN / TURN ──────────────────────────────────────────── + print(f'{self._bullet()} {self._bright("STUN / TURN")} \u00b7 {self._green(krelay_server)}') + print(f' {self._sep()}') + + passed, detail, ms, ext_ip = self._test_tcp_stun(krelay_server) + _record(f'TCP STUN {krelay_server}:{self._STUN_PORT}', passed, detail, ms) + if ext_ip: + public_ip = ext_ip + + passed, detail, ms, ext_ip = self._test_udp_stun(krelay_server) + _record(f'UDP STUN {krelay_server}:{self._STUN_PORT}', passed, detail, ms) + if ext_ip and not public_ip: + public_ip = ext_ip - # Get the krelay server from the PAM configuration + passed, detail, ms = self._test_turn(krelay_server) + _record(f'TURN relay {krelay_server}:{self._STUN_PORT}', passed, detail, ms) + + print() + + # ── section 3: WebRTC media ports ───────────────────────────────────── try: - encrypted_session_token, encrypted_transmission_key, transmission_key = get_keeper_tokens(params) - pam_config_uid = get_config_uid(params, encrypted_session_token, encrypted_transmission_key, record_uid) - - if not pam_config_uid: - print(f"{bcolors.FAIL}No PAM Configuration found for record '{record_name}'.{bcolors.ENDC}") - print(f"Please configure the record with: {bcolors.OKBLUE}pam tunnel edit {record_uid} --config [ConfigUID]{bcolors.ENDC}") - return 1 + krelay_ip = socket.gethostbyname(krelay_server) + except Exception: + krelay_ip = krelay_server + + udp_range_label = "UDP 49152\u201365535" + print(f'{self._bullet()} {self._bright("WebRTC Media Ports")} \u00b7 {self._green(udp_range_label)}') + print(f' {self._sep()}') + + udp_timeout = min(float(timeout), 1.5) + port_results: List[Tuple[int, bool, int]] = [] + for port in self._UDP_SAMPLE_PORTS: + ok, ms = self._test_udp_port(krelay_ip, port, timeout=udp_timeout) + port_results.append((port, ok, ms)) + all_passed.append(ok) + if not ok: + blocked_names.append(f'UDP:{port}') + + row = ' ' + for port, ok, _ in port_results: + row += f'{self._check() if ok else self._cross()} {self._green(str(port))} ' + print(row.rstrip()) + + passed_ports = sum(1 for _, ok, _ in port_results if ok) + print(f' {self._check()} {self._green(str(passed_ports))}/{len(port_results)} sampled ports reachable') + print() + + # ── section 4: WebRTC connectivity (Rust library) ───────────────────── + tube_registry = get_or_create_tube_registry(params) + rust_results = None - # The krelay server hostname is constructed from the params.server - krelay_server = f"krelay.{params.server}" - - except Exception as e: - print(f"{bcolors.FAIL}Failed to get PAM configuration: {e}{bcolors.ENDC}") - return 1 - - # Build test settings - settings = { - "use_turn": True, - "turn_only": False - } - - # Parse test filter if provided - if test_filter: - allowed_tests = {'dns_resolution', 'aws_connectivity', 'tcp_connectivity', + if tube_registry: + print(f'{self._bullet()} {self._bright("WebRTC Connectivity")} \u00b7 {self._green("STUN/TURN/ICE/Peer")}') + print(f' {self._sep()}') + + # Resolve optional record for pam_config_uid + if record_name: + try: + api.sync_down(params) + record = RecordMixin.resolve_single_record(params, record_name) + if record and isinstance(record, vault.TypedRecord): + encrypted_session_token, encrypted_transmission_key, _ = get_keeper_tokens(params) + pam_config_uid = get_config_uid(params, encrypted_session_token, + encrypted_transmission_key, record.record_uid) + if not pam_config_uid: + print(f' {self._cross()} {self._red(f"No PAM config found for record {record_name}")}') + except Exception as exc: + logging.debug(f'Record lookup failed: {exc}', exc_info=True) + + # Get TURN credentials + turn_username = turn_password = None + try: + from .pam.router_helper import router_get_relay_access_creds + creds = router_get_relay_access_creds(params, expire_sec=60000000) + turn_username = creds.username + turn_password = creds.password + except Exception as exc: + logging.debug(f'Could not get TURN credentials: {exc}', exc_info=True) + + settings = {'use_turn': True, 'turn_only': False} + if test_filter: + allowed = {'dns_resolution', 'aws_connectivity', 'tcp_connectivity', 'udp_binding', 'ice_configuration', 'webrtc_peer_connection'} - requested_tests = set(test.strip() for test in test_filter.split(',')) - invalid_tests = requested_tests - allowed_tests - if invalid_tests: - print(f"{bcolors.FAIL}Invalid test names: {', '.join(invalid_tests)}{bcolors.ENDC}") - print(f"Available tests: {', '.join(sorted(allowed_tests))}") - return 1 - settings["test_filter"] = list(requested_tests) - - print(f"{bcolors.OKBLUE}Starting network connectivity diagnosis for krelay server: {krelay_server}{bcolors.ENDC}") - print(f"Record: {record.title} ({record_uid})") - print(f"Timeout: {timeout}s") - print("") - - # Get TURN credentials for the connectivity test - try: - webrtc_settings = create_rust_webrtc_settings( - params, host="127.0.0.1", port=0, - target_host="test", target_port=22, - socks=False, nonce=os.urandom(32) - ) - turn_username = webrtc_settings.get("turn_username") - turn_password = webrtc_settings.get("turn_password") - except Exception as e: - print(f"{bcolors.WARNING}Could not get TURN credentials: {e}{bcolors.ENDC}") - turn_username = None - turn_password = None - - # Run the connectivity test - try: - results = tube_registry.test_webrtc_connectivity( - krelay_server=krelay_server, - settings=settings, - timeout_seconds=timeout, - username=turn_username, - password=turn_password - ) - - if output_format == 'json': - import json - print(json.dumps(results, indent=2)) - return 0 - else: - # Use the built-in formatter for human-readable output - formatted_output = tube_registry.format_connectivity_results(results, detailed=verbose) - print(formatted_output) - - # Return appropriate exit code - overall_result = results.get('overall_result', {}) - if overall_result.get('success', False): + requested = {t.strip() for t in test_filter.split(',')} + invalid = requested - allowed + if invalid: + print(f"{bcolors.FAIL}Invalid test names: {', '.join(invalid)}{bcolors.ENDC}") + return 1 + settings['test_filter'] = list(requested) + + try: + rust_results = tube_registry.test_webrtc_connectivity( + krelay_server=krelay_server, + settings=settings, + timeout_seconds=timeout, + username=turn_username, + password=turn_password, + ) + if output_format == 'json': + import json + print(json.dumps(rust_results, indent=2)) return 0 + + # Fold Rust test results into the unified pass/fail accounting + for test in rust_results.get('test_results', []): + name = test.get('test_name', '?') + ok = test.get('success', False) + ms = int(test.get('duration_ms', 0)) + msg = test.get('message', '') + _record(name.replace('_', ' ').title(), ok, msg, ms) + + except Exception as exc: + print(f' {self._cross()} {self._red(f"WebRTC test failed: {exc}")}') + all_passed.append(False) + blocked_names.append('webrtc') + logging.debug('WebRTC test error', exc_info=True) + + print() + else: + logging.debug('keeper_pam_webrtc_rs not available; skipping WebRTC section') + + # ── section 5: PAM configuration graph (record-specific) ──────────── + if record_name: + print(f'{self._bullet()} {self._bright("PAM Configuration")} \u00b7 {self._green(record_name)}') + print(f' {self._sep()}') + + try: + # Ensure vault is synced so the config record is in the local cache + api.sync_down(params) + record_obj = RecordMixin.resolve_single_record(params, record_name) + record_uid = record_obj.record_uid if record_obj else record_name + + _supported_types = ('pamMachine', 'pamDatabase', 'pamDirectory', 'pamRemoteBrowser') + rec_type_early = record_obj.record_type if record_obj and isinstance(record_obj, vault.TypedRecord) else None + if rec_type_early and rec_type_early not in _supported_types: + print(f' {self._cross()} {self._red("Record type")} {self._green(rec_type_early)} ' + f'{self._red("is not a PAM resource — skipping configuration checks")}') + print(f' {self._dim("Supported types: " + ", ".join(_supported_types))}') + print() + # Skip to Technical Details + raise StopIteration + + # 1. Config linked — find the PAM config that owns this record + enc_session_token, enc_transmission_key, _tx_key = get_keeper_tokens(params) + config_uid = get_config_uid(params, enc_session_token, enc_transmission_key, record_uid) + if config_uid: + _record('Config linked', True, config_uid, 0) else: - return 1 - - except Exception as e: - print(f"{bcolors.FAIL}Network connectivity test failed: {e}{bcolors.ENDC}") - logging.debug(f"Full error details: {e}", exc_info=True) - return 1 + _record('Config linked', False, 'record not found in any PAM config graph', 0) + + if config_uid: + # 2. DAG loaded — fresh tokens required; TunnelDAG's Connection + # needs its own key pair separate from the get_config_uid call, + # and transmission_key must be passed so it can decrypt the response + enc_st2, enc_tk2, tx_key2 = get_keeper_tokens(params) + tdag = TunnelDAG(params, enc_st2, enc_tk2, config_uid, is_config=True, + transmission_key=tx_key2) + dag_ok = tdag.linking_dag.has_graph + vertex_count = len(tdag.linking_dag._vertices) if dag_ok else 0 + _record('DAG loaded', dag_ok, + '{} vertices'.format(vertex_count) if dag_ok else 'graph empty — config may be unconfigured', + 0) + + if dag_ok: + # 3. Resource linked — LINK edge from config → resource present + linked = tdag.resource_belongs_to_config(record_uid) + _record('Resource linked', linked, + 'LINK edge present' if linked else 'resource not linked to config', 0) + + rec_type = record_obj.record_type if record_obj else '' + is_rbi = rec_type == 'pamRemoteBrowser' + + # 4. Config-level settings + con_config = tdag.check_tunneling_enabled_config(enable_connections=True) + _record('Connections at config', con_config, + 'connections enabled' if con_config else 'connections disabled at config', 0) + + if not is_rbi: + tun_config = tdag.check_tunneling_enabled_config(enable_tunneling=True) + _record('Tunneling at config', tun_config, + 'portForwards enabled' if tun_config else 'portForwards disabled at config', 0) + + # 5. Resource-level settings + con_resource = tdag.check_if_resource_allowed(record_uid, 'connections') + _record('Connections at resource', con_resource, + 'connections enabled' if con_resource else 'connections disabled at resource', 0) + + if not is_rbi: + tun_resource = tdag.check_if_resource_allowed(record_uid, 'portForwards') + _record('Tunneling at resource', tun_resource, + 'portForwards enabled' if tun_resource else 'portForwards disabled at resource', 0) + + # verbose: dump allowedSettings for config and resource + if verbose: + from .tunnel.port_forward.TunnelGraph import get_vertex_content + _setting_keys = [ + ('connections', 'Connections'), + ('portForwards', 'Port Forwards'), + ('rotation', 'Rotation'), + ('sessionRecording', 'Session Recording'), + ('typescriptRecording', 'Typescript Recording'), + ('remoteBrowserIsolation', 'Remote Browser Isolation'), + ] + config_vertex = tdag.linking_dag.get_vertex(tdag.record.record_uid) + resource_vertex = tdag.linking_dag.get_vertex(record_uid) + cfg_content = get_vertex_content(config_vertex) or {} + res_content = get_vertex_content(resource_vertex) or {} + cfg_settings = cfg_content.get('allowedSettings', {}) + res_settings = res_content.get('allowedSettings', {}) + + _yes = self._bright('on ') + _no = self._dim('off') + _def = self._dim('---') + + def _fmt_bool(d, key): + v = d.get(key) + if v is True: return _yes + if v is False: return _no + return _def + + print() + print(f' {self._dim("DAG allowedSettings"):<28}' + f'{self._dim("Config"):<12}{self._dim("Resource")}') + print(f' {self._dim("-" * 52)}') + for key, label in _setting_keys: + print(f' {self._green(label):<28}' + f'{_fmt_bool(cfg_settings, key):<12}' + f'{_fmt_bool(res_settings, key)}') + + # typed field on the vault record — field name differs by record type + def _val(v): + if v is None: return self._dim('---') + if v is True: return self._bright('true') + if v is False: return self._dim('false') + return self._green(str(v)) + + print() + if is_rbi: + rbs_field = record_obj.get_typed_field('pamRemoteBrowserSettings') if record_obj else None + rbs = {} + if rbs_field and rbs_field.value: + rbs = rbs_field.value[0] if isinstance(rbs_field.value[0], dict) else {} + cn = rbs.get('connection', {}) or {} + print(f' {self._dim("Record pamRemoteBrowserSettings")}') + print(f' {self._dim("-" * 52)}') + print(f' {self._green("connection.protocol"):<36}{_val(cn.get("protocol"))}') + print(f' {self._green("connection.httpCredentialsUid"):<36}{_val(cn.get("httpCredentialsUid") or None)}') + print(f' {self._green("connection.recordingIncludeKeys"):<36}{_val(cn.get("recordingIncludeKeys"))}') + else: + pam_settings_field = record_obj.get_typed_field('pamSettings') if record_obj else None + ps = {} + if pam_settings_field and pam_settings_field.value: + ps = pam_settings_field.value[0] if isinstance(pam_settings_field.value[0], dict) else {} + pf = ps.get('portForward', {}) or {} + cn = ps.get('connection', {}) or {} + print(f' {self._dim("Record pamSettings")}') + print(f' {self._dim("-" * 52)}') + print(f' {self._green("portForward.port"):<36}{_val(pf.get("port"))}') + print(f' {self._green("connection.port"):<36}{_val(cn.get("port"))}') + print(f' {self._green("connection.protocol"):<36}{_val(cn.get("protocol"))}') + print(f' {self._green("connection.allowKeeperDBProxy"):<36}{_val(cn.get("allowKeeperDBProxy"))}') + print(f' {self._green("connection.recordingIncludeKeys"):<36}{_val(cn.get("recordingIncludeKeys"))}') + print(f' {self._green("allowSupplyHost"):<36}{_val(ps.get("allowSupplyHost"))}') + if ps.get('configUid'): + print(f' {self._green("configUid"):<36}{_val(ps.get("configUid"))}') + + # 6. Gateway registered — a controller UID is associated with this config + gateway_uid = get_gateway_uid_from_record(params, vault, record_uid) + if gateway_uid: + _record('Gateway registered', True, gateway_uid, 0) + else: + _record('Gateway registered', False, 'no gateway registered for this config', 0) + + # 7. Gateway online — that gateway is currently connected to krouter + if gateway_uid: + try: + from .pam.router_helper import router_get_connected_gateways + online_controllers = router_get_connected_gateways(params) + if online_controllers: + gw_bytes = url_safe_str_to_bytes(gateway_uid) + connected_uids = [c.controllerUid for c in online_controllers.controllers] + gw_online = gw_bytes in connected_uids + _record('Gateway online', gw_online, + 'connected to krouter' if gw_online else 'gateway offline or unreachable', 0) + else: + _record('Gateway online', False, 'could not retrieve connected gateways', 0) + except Exception as exc: + _record('Gateway online', False, str(exc)[:60], 0) + + except StopIteration: + pass # unsupported record type — already printed, skip gracefully + except Exception as exc: + print(f' {self._cross()} {self._red("PAM graph check failed: " + str(exc)[:70])}') + all_passed.append(False) + blocked_names.append('pam-graph') + logging.debug('PAM graph check error', exc_info=True) + + print() + + # ── section 6: technical details ────────────────────────────────────── + print(f'{self._bullet()} {self._bright("Technical Details")}') + print(f' {self._sep()}') + + try: + fqdn = socket.getfqdn() + local_ip = socket.gethostbyname(socket.gethostname()) + except Exception: + fqdn = socket.gethostname() + local_ip = '?' + + passed_total = sum(1 for v in all_passed if v) + total_checks = len(all_passed) + duration_s = time.monotonic() - t_overall_start + blocked_str = 'none \u2013 all paths open' if not blocked_names else ', '.join(blocked_names) + + col = 10 + print(f' {self._dim("Machine"):<{col}}{self._green(fqdn)} \u00b7 {self._green(local_ip)}') + if public_ip: + print(f' {self._dim("Public IP"):<{col}}{self._green(public_ip)} {self._dim("via STUN")}') + print(f' {self._dim("Duration"):<{col}}{self._green(f"{duration_s:.1f}s")} \u00b7 ' + f'{self._green(f"{passed_total}/{total_checks} checks")}') + print(f' {self._dim("Blocked"):<{col}}{self._green(blocked_str)}') + + print() + print(f' {self._dsep()}') + print() + + if passed_total == total_checks: + summary = "GATEWAY READY \u00b7 {} / {} checks passed".format(passed_total, total_checks) + print(f' {self._check()} {self._bright(summary)}') + else: + summary = "GATEWAY NOT READY \u00b7 {} / {} checks passed".format(passed_total, total_checks) + print(f' {self._cross()} {self._red(summary)}') + for name in blocked_names: + print(f' {self._red(name)}') + + print() + print(f' {self._dsep()}') + print() + + return 0 if passed_total == total_checks else 1 class PAMConnectionEditCommand(Command): @@ -1046,7 +1579,9 @@ def execute(self, params, **kwargs): f'{bcolors.FAIL}Launch user record must be a pamUser record type.{bcolors.ENDC}') launch_uid = launch_rec.record_uid if record_type in ("pamDatabase", "pamDirectory", "pamMachine"): + tdag.clear_launch_credential_for_resource(record_uid, exclude_user_uid=launch_uid) tdag.link_user_to_resource(launch_uid, record_uid, is_launch_credential=True, belongs_to=True) + tdag.upgrade_resource_meta_to_v1(record_uid) # Print out PAM Settings if not kwargs.get("silent", False): tdag.print_tunneling_config(record_uid, record.get_typed_field('pamSettings'), config_uid) diff --git a/keepercommander/constants.py b/keepercommander/constants.py index 2d2842b5e..dd4bc0c1e 100644 --- a/keepercommander/constants.py +++ b/keepercommander/constants.py @@ -264,6 +264,18 @@ def enforcement_list(): # type: () -> List[Tuple[str, str, str]] ENFORCEMENTS = {e[0].lower(): e[2].lower() for e in [*_ENFORCEMENTS, *_COMPOUND_ENFORCEMENTS]} +TWO_FACTOR_DURATION_MAP = {'0': 'login', '12': '12_hours', '24': '24_hours', '30': '30_days', '9999': 'forever'} + + +def format_two_factor_duration(raw_value): # type: (str) -> str + """Convert stored cumulative value like '0,12,24,30' to the effective setting like '30_days'.""" + if not isinstance(raw_value, str): + raw_value = str(raw_value) if raw_value is not None else '' + tokens = [x.strip() for x in raw_value.split(',')] + last = tokens[-1] if tokens else '' + return TWO_FACTOR_DURATION_MAP.get(last, last) + + week_days = ('SUNDAY', 'MONDAY', 'TUESDAY', 'WEDNESDAY', 'THURSDAY', 'FRIDAY', 'SATURDAY') occurrences = ('FIRST', 'SECOND', 'THIRD', 'FOURTH', 'LAST') months = ('JANUARY', 'FEBRUARY', 'MARCH', 'APRIL', 'MAY', 'JUNE', 'JULY', 'AUGUST', 'SEPTEMBER', 'OCTOBER', diff --git a/keepercommander/discovery_common/__version__.py b/keepercommander/discovery_common/__version__.py index bc50bee68..1bdaf4702 100644 --- a/keepercommander/discovery_common/__version__.py +++ b/keepercommander/discovery_common/__version__.py @@ -1 +1 @@ -__version__ = '1.1.4' +__version__ = '1.1.10' diff --git a/keepercommander/discovery_common/infrastructure.py b/keepercommander/discovery_common/infrastructure.py index 61147bf33..ee7ad7fed 100644 --- a/keepercommander/discovery_common/infrastructure.py +++ b/keepercommander/discovery_common/infrastructure.py @@ -4,11 +4,12 @@ from ..keeper_dag import DAG, EdgeType from ..keeper_dag.exceptions import DAGVertexException from ..keeper_dag.crypto import urlsafe_str_to_bytes -from ..keeper_dag.types import PamGraphId, PamEndpoints +from ..keeper_dag.types import PamGraphId +from .types import DiscoveryObject import os import importlib import time -from typing import Any, Optional, TYPE_CHECKING +from typing import Any, Optional, Dict, List, TYPE_CHECKING if TYPE_CHECKING: from ..keeper_dag.vertex import DAGVertex @@ -59,6 +60,8 @@ def __init__(self, record: Any, logger: Optional[Any] = None, history_level: int self.conn = get_connection(logger=logger, **kwargs) + self._cache: Optional[Dict] = None + @property def dag(self) -> DAG: if self._dag is None: @@ -123,6 +126,12 @@ def close(self): Clean up resources held by this Infrastructure instance. Releases the DAG instance and connection to prevent memory leaks. """ + if self._cache: + for v in self._cache.values(): + v["vertex"] = None + v["content"] = None + self._cache.clear() + if self._dag is not None: self._dag = None self.conn = None @@ -150,6 +159,86 @@ def save(self, delta_graph: Optional[bool] = None): self._dag.save(delta_graph=delta_graph) self.logger.debug(f"infrastructure took {time.time()-ts} secs to save") + def cache_objects(self): + + self.logger.debug(f"building id to infrastructure cache") + + self._cache = {} + + def _cache(v: DAGVertex, parent_content: Optional[DiscoveryObject] = None): + c = DiscoveryObject.get_discovery_object(v) + key = c.object_type_value.lower() + c.id.lower() + self._cache[key] = { + "key": key, + "uid": v.uid, + "parent_uid": parent_content.uid if parent_content else None, + "vertex": v, + "content": c, + "was_found": False, + "could_login": True, + "is_new": False, + "md5": c.md5 + } + + for next_v in v.has_vertices(): + _cache(next_v, c) + + if self.has_discovery_data: + ts = time.time() + _cache(self.get_configuration, None) + self.logger.info(f" infrastructure cache build time: {time.time()-ts} seconds") + else: + self.logger.info(f" no infrastructure data to cache") + + def get_cache_info(self, object_type_value: str, object_id: str) -> Dict: + return self._cache.get(object_type_value.lower() + object_id.lower()) + + def get_cache_info_by_key(self, key: str) -> Dict: + return self._cache.get(key.lower()) + + def get_missing_cache_list(self, uid: Optional[str] = None) -> List[str]: + not_found_list = [] + for k, v in self._cache.items(): + if not v["is_new"] and not v["was_found"]: + if uid is None or uid == v["uid"] or uid == v["parent_uid"]: + not_found_list.append(k) + return not_found_list + + def add_info_to_cache(self, vertex: DAGVertex, content: DiscoveryObject, parent_vertex: Optional[DAGVertex] = None): + if self._cache is None: + self._cache = {} + + key = content.object_type_value.lower() + content.id.lower() + self._cache[key] = { + "key": key, + "uid": vertex.uid, + "parent_uid": parent_vertex.uid if parent_vertex else None, + "vertex": vertex, + "content": content, + "was_found": True, + "could_login": True, + "is_new": True, + "md5": content.md5 + } + + def update_cache_info(self, info: Dict): + key = info["key"] + self._cache[key] = info + + def find_content(self, query: Dict, ignore_case: bool = False) -> Optional[DAGVertex]: + """ + Find the vertex that matches the query. + + Will only find one. + If it does not match, return None + If matches on more, return None + """ + + vertices = self.dag.search_content(query=query, ignore_case=ignore_case) + if len(vertices) != 1: + return None + return vertices[0] + def to_dot(self, graph_format: str = "svg", show_hex_uid: bool = False, show_version: bool = True, show_only_active_vertices: bool = False, show_only_active_edges: bool = False, sync_point: int = None, graph_type: str = "dot"): diff --git a/keepercommander/discovery_common/jobs.py b/keepercommander/discovery_common/jobs.py index ff7198f06..098179da0 100644 --- a/keepercommander/discovery_common/jobs.py +++ b/keepercommander/discovery_common/jobs.py @@ -2,7 +2,7 @@ from .utils import get_connection, make_agent from .types import JobContent, JobItem, Settings, DiscoveryDelta from ..keeper_dag import DAG, EdgeType -from ..keeper_dag.types import PamGraphId, PamEndpoints +from ..keeper_dag.types import PamGraphId import logging import os import base64 @@ -320,26 +320,12 @@ def get_job(self, job_id) -> Optional[JobItem]: # Get the job item from the job vertex DATA edge. # Replace the one from the job history if we have it. try: - job = job_vertex.content_as_object(JobItem) + found_job = job_vertex.content_as_object(JobItem) + if found_job is not None: + job = found_job except Exception as err: self.logger.debug(f"could not find job item on job vertex, use job histry entry: {err}") - # If the job delta is None, check to see if it chunked as vertices. - delta_lookup = {} - vertices = job_vertex.has_vertices() - self.logger.debug(f"found {len(vertices)} delta vertices") - for vertex in vertices: - edge = vertex.get_edge(job_vertex, edge_type=EdgeType.KEY) - delta_lookup[int(edge.path)] = vertex - - json_value = "" - # Sort numerically increasing and then append their content. - # This will re-assemble the JSON - for key in sorted(delta_lookup): - json_value += delta_lookup[key].content_as_str - if json_value != "": - self.logger.debug(f"delta content length is {len(json_value)}") - job.delta = DiscoveryDelta.model_validate_json(json_value) else: self.logger.debug("could not find job vertex") diff --git a/keepercommander/discovery_common/process.py b/keepercommander/discovery_common/process.py index 2e835112a..1a3eaf8f6 100644 --- a/keepercommander/discovery_common/process.py +++ b/keepercommander/discovery_common/process.py @@ -1019,7 +1019,7 @@ def _process_level(self, if admin_uid is not None: self.logger.debug(" found directory user admin, connect to resource") - # self.record_link.belongs_to(admin_uid, add_content.record_uid, acl=acl) + self.record_link.belongs_to(admin_uid, add_content.record_uid, acl=acl) should_prompt_for_admin = False else: self.logger.debug(" did not find the directory user for the admin, " @@ -1562,7 +1562,4 @@ def run(self, self.infra.save(delta_graph=False) self.logger.debug("# ####################################################################################") - # Update the user service mapping - self.user_service.run(infra=self.infra) - return bulk_process_results diff --git a/keepercommander/discovery_common/rm_types.py b/keepercommander/discovery_common/rm_types.py index 3f6d00b3c..38057ba18 100644 --- a/keepercommander/discovery_common/rm_types.py +++ b/keepercommander/discovery_common/rm_types.py @@ -137,6 +137,30 @@ class RmAzureGroupAddMeta(RmMetaBase): group_types: List[str] = [] +class RmGcpUserAddMeta(RmMetaBase): + account_enabled: Optional[bool] = True + display_name: Optional[str] = None + password_reset_required: Optional[bool] = False + password_reset_required_with_mfa: Optional[bool] = False + groups: List[str] = [] + + +class RmGcpGroupAddMeta(RmMetaBase): + group_types: List[str] = [] + + +class RmOktaUserAddMeta(RmMetaBase): + account_enabled: Optional[bool] = True + display_name: Optional[str] = None + password_reset_required: Optional[bool] = False + password_reset_required_with_mfa: Optional[bool] = False + groups: List[str] = [] + + +class RmOktaGroupAddMeta(RmMetaBase): + group_types: List[str] = [] + + class RmDomainUserAddMeta(RmMetaBase): roles: List[str] = [] groups: List[str] = [] @@ -253,6 +277,10 @@ class RmMongoDbRoleAddMeta(RmMetaBase): # MACHINE +class RmUserDeleteBaseMeta(RmMetaBase): + remove_home_dir: Optional[bool] = True + + class RmLinuxGroupAddMeta(RmMetaBase): gid: Optional[int] = None system_group: Optional[bool] = False @@ -291,8 +319,7 @@ class RmLinuxUserAddMeta(RmMachineUserAddMeta): non_system_dir_mode: Optional[str] = None -class RmLinuxUserDeleteMeta(RmMetaBase): - remove_home_dir: Optional[bool] = True +class RmLinuxUserDeleteMeta(RmUserDeleteBaseMeta): remove_user_group: Optional[bool] = True @@ -308,6 +335,10 @@ class RmWindowsUserAddMeta(RmMachineUserAddMeta): groups: List[str] = [] +class RmWindowsUserDeleteMeta(RmUserDeleteBaseMeta): + pass + + class RmMacOsUserAddMeta(RmMachineUserAddMeta): display_name: Optional[str] = None uid: Optional[str] = None @@ -325,6 +356,10 @@ class RmMacOsRoleAddMeta(RmMetaBase): record_name: Optional[str] = None +class RmMacOsUserDeleteMeta(RmUserDeleteBaseMeta): + pass + + # DIRECTORY diff --git a/keepercommander/discovery_common/types.py b/keepercommander/discovery_common/types.py index 710353f44..d4d17d953 100644 --- a/keepercommander/discovery_common/types.py +++ b/keepercommander/discovery_common/types.py @@ -5,6 +5,7 @@ import datetime import base64 import json +import hashlib from keeper_secrets_manager_core.crypto import CryptoUtils from typing import Any, Union, Optional, List, TYPE_CHECKING @@ -524,6 +525,18 @@ class DiscoveryObject(BaseModel): # Specific information for a record type. item: Union[DiscoveryConfiguration, DiscoveryUser, DiscoveryMachine, DiscoveryDatabase, DiscoveryDirectory] + @property + def md5(self) -> str: + data = self.model_dump() + + # Don't include these in the MD5 + data.pop("missing_since_ts", None) + data.pop("access_user", None) + + m = hashlib.md5() + m.update(json.dumps(data).encode('utf-8')) + return m.hexdigest() + @property def record_exists(self): return self.record_uid is not None @@ -603,29 +616,98 @@ class NormalizedRecord(BaseModel): title: str fields: List[RecordField] = [] note: Optional[str] = None + record_exists: bool = True + + def _field(self, + field_type: Optional[str] = None, + label: Optional[str] = None) -> Optional[RecordField]: + if field_type is None and label is None: + raise ValueError("either field_type or label needs to be set to find field in NormalizedRecord.") - def _field(self, field_type, label) -> Optional[RecordField]: for field in self.fields: - value = field.value - if value is None or len(value) == 0: - continue - if field.label == field_type and value[0].lower() == label.lower(): + if field_type is not None and field_type == field.type: + return field + if label is not None and label == field.label: return field return None - def find_user(self, user): + def find_field(self, + field_type: Optional[str] = None, + label: Optional[str] = None) -> Optional[RecordField]: + + return self._field(field_type=field_type, label=label) + + def get_value(self, + field_type: Optional[str] = None, + label: Optional[str] = None) -> Optional[Any]: + + field = self.find_field(field_type=field_type, label=label) + if field is None or field.value is None or len(field.value) == 0: + return None + return field.value[0] + + def get_user(self) -> Optional[str]: + field = self._field(field_type="login") + if field is None: + return None + value = field.value + if isinstance(value, list): + if len(value) == 0: + return None + value = value[0] + return value + + def get_dn(self) -> Optional[str]: + field = self._field(label="distinguishedName") + if field is None: + return None + value = field.value + if isinstance(value, list): + if len(value) == 0: + return None + value = value[0] + return value + + def has_user(self, user) -> bool: from .utils import split_user_and_domain - res = self._field("login", user) - if res is None: - user, _ = split_user_and_domain(user) - res = self._field("login", user) + user, _ = split_user_and_domain(user) + + field = self._field(field_type="login") + if field is None: + return False + + value = field.value + if isinstance(value, list): + if len(value) == 0: + return False + value = value[0] + elif isinstance(value, str): + value = value.lower() + + if user.lower() == value: + return True + + return False + + def has_dn(self, user) -> bool: + field = self._field(label="distinguishedName") + if field is None: + return False - return res + value = field.value + if isinstance(value, list): + if len(value) == 0: + return False + value = value[0] + elif isinstance(value, str): + value = value.lower() - def find_dn(self, user): - return self._field("distinguishedName", user) + if user.lower() == value: + return True + + return False class PromptResult(BaseModel): diff --git a/keepercommander/discovery_common/user_service.py b/keepercommander/discovery_common/user_service.py index 455526ad8..cf58675a2 100644 --- a/keepercommander/discovery_common/user_service.py +++ b/keepercommander/discovery_common/user_service.py @@ -1,21 +1,22 @@ from __future__ import annotations import logging -from .constants import USER_SERVICE_GRAPH_ID, PAM_MACHINE, PAM_USER, PAM_DIRECTORY, DOMAIN_USER_CONFIGS -from .utils import get_connection, user_in_lookup, user_check_list, make_agent -from .types import DiscoveryObject, ServiceAcl, FactsNameUser +import os + +from .constants import PAM_MACHINE, PAM_USER, PAM_DIRECTORY, DOMAIN_USER_CONFIGS +from .utils import get_connection, make_agent, split_user_and_domain, value_to_boolean +from .types import DiscoveryObject, ServiceAcl, NormalizedRecord from .infrastructure import Infrastructure +from .record_link import RecordLink from ..keeper_dag import DAG, EdgeType -from ..keeper_dag.types import PamEndpoints, PamGraphId +from ..keeper_dag.types import PamGraphId import importlib -from typing import Any, Optional, List, TYPE_CHECKING +from typing import Any, Optional, List, Callable, Dict, TYPE_CHECKING if TYPE_CHECKING: from ..keeper_dag.vertex import DAGVertex from ..keeper_dag.edge import DAGEdge -# TODO: Refactor this code; we can make this smaller since method basically do the same functions, just different -# attributes. class UserService: def __init__(self, record: Any, logger: Optional[Any] = None, history_level: int = 0, @@ -23,6 +24,10 @@ def __init__(self, record: Any, logger: Optional[Any] = None, history_level: int save_batch_count: int = 200, agent: Optional[str] = None, **kwargs): + # Keep these for other graphs + self._params = kwargs.get("params") + self._ksm = kwargs.get("ksm") + self.conn = get_connection(**kwargs) # This will either be a KSM Record, or Commander KeeperRecord @@ -44,19 +49,38 @@ def __init__(self, record: Any, logger: Optional[Any] = None, history_level: int self.auto_save = False self.last_sync_point = -1 + self.directory_user_cache: Optional[Dict[str, Dict]] = None + + # Mapping that use to keep track of what relationship have been update. + self.cleanup_mapping = {} + + self.insecure_debug = value_to_boolean(os.environ.get("INSECURE_DEBUG", False)) + self.log_finer_level = 0 + try: + self.log_finer_level = int(os.environ.get("KEEPER_GATEWAY_SERVICE_LOG_FINER_LEVEL", 0)) + except (Exception,): + pass + + def debug(self, msg, level: int = 0, secret: bool = False): + if self.log_finer_level >= level: + if secret: + if self.insecure_debug: + self.logger.debug(msg) + else: + self.logger.debug(msg) + @property def dag(self) -> DAG: if self._dag is None: self._dag = DAG(conn=self.conn, record=self.record, - # endpoint=PamEndpoints.SERVICE_LINKS, graph_id=PamGraphId.SERVICE_LINKS, auto_save=False, logger=self.logger, history_level=self.history_level, debug_level=self.debug_level, - name="Discovery Service/Tasks", + name="Discovery Services", fail_on_corrupt=self.fail_on_corrupt, log_prefix=self.log_prefix, save_batch_count=self.save_batch_count, @@ -64,6 +88,9 @@ def dag(self) -> DAG: self._dag.load(sync_point=0) + # If an empty graph, call root get create a vertex. + _ = self._dag.get_root + return self._dag def close(self): @@ -71,9 +98,11 @@ def close(self): Clean up resources held by this UserService instance. Releases the DAG instance and connection to prevent memory leaks. """ - if self._dag is not None: - self._dag = None + + self._dag = None self.conn = None + self._params = None + self._ksm = None def __enter__(self): """Context manager entry.""" @@ -112,8 +141,11 @@ def get_record_uid(discovery_vertex: DAGVertex) -> str: return content.record_uid raise Exception(f"The discovery vertex {discovery_vertex.uid} data does not have a populated record UID.") - def belongs_to(self, resource_uid: str, user_uid: str, acl: Optional[ServiceAcl] = None, - resource_name: Optional[str] = None, user_name: Optional[str] = None): + def belongs_to(self, + resource_uid: str, + user_uid: str, acl: Optional[ServiceAcl] = None, + resource_name: Optional[str] = None, + user_name: Optional[str] = None): """ Link vault records using record UIDs. @@ -121,24 +153,32 @@ def belongs_to(self, resource_uid: str, user_uid: str, acl: Optional[ServiceAcl] If a link already exists, no additional link will be created. """ + if resource_uid is None: + self.debug("resource_uid is blank, do not connect") + return + if user_uid is None: + self.debug("user_uid is blank, do not connect") + return + # Get thr record vertices. # If a vertex does not exist, then add the vertex using the record UID resource_vertex = self.dag.get_vertex(resource_uid) if resource_vertex is None: - self.logger.debug(f"adding resource vertex for record UID {resource_uid} ({resource_name})") + self.debug(f"adding resource vertex for record UID {resource_uid} ({resource_name})") resource_vertex = self.dag.add_vertex(uid=resource_uid, name=resource_name) user_vertex = self.dag.get_vertex(user_uid) if user_vertex is None: - self.logger.debug(f"adding user vertex for record UID {user_uid} ({user_name})") + self.debug(f"adding user vertex for record UID {user_uid} ({user_name})") user_vertex = self.dag.add_vertex(uid=user_uid, name=user_name) - self.logger.debug(f"user {user_vertex.uid} controls services on {resource_vertex.uid}") + self.debug(f"user {user_vertex.uid} controls services on {resource_vertex.uid}") edge_type = EdgeType.LINK if acl is not None: edge_type = EdgeType.ACL + self.debug(f"Connect {user_vertex.uid} to {resource_vertex.uid}") user_vertex.belongs_to(resource_vertex, edge_type=edge_type, content=acl) def disconnect_from(self, resource_uid: str, user_uid: str): @@ -156,11 +196,16 @@ def get_acl(self, resource_uid, user_uid) -> Optional[ServiceAcl]: resource_vertex = self.dag.get_vertex(resource_uid) user_vertex = self.dag.get_vertex(user_uid) if resource_vertex is None or user_vertex is None: - self.logger.debug(f"there is no acl between {resource_uid} and {user_uid}") + if resource_vertex is None: + self.debug("The resource vertex does not exists get; return default ACL") + if user_vertex is None: + self.debug("The user vertex does not exists get; return default ACL") return ServiceAcl() acl_edge = user_vertex.get_edge(resource_vertex, edge_type=EdgeType.ACL) # type: DAGEdge if acl_edge is None: + self.debug(f"ACL does not exists between resource {resource_uid} and user {user_vertex} doesn't " + "exist; return None") return None return acl_edge.content_as_object(ServiceAcl) @@ -206,10 +251,10 @@ def delete(vertex: DAGVertex): def save(self): if self.dag.has_graph: - self.logger.debug("saving the service user.") + self.debug("saving the service user.") self.dag.save(delta_graph=False) else: - self.logger.debug("the service user graph does not contain any data, was not saved.") + self.debug("the service user graph does not contain any data, was not saved.") def to_dot(self, graph_format: str = "svg", show_version: bool = True, show_only_active_vertices: bool = True, show_only_active_edges: bool = True, graph_type: str = "dot"): @@ -230,7 +275,7 @@ def to_dot(self, graph_format: str = "svg", show_version: bool = True, show_only else: dot.attr(layout=graph_type) - self.logger.debug(f"have {len(self.dag.all_vertices)} vertices") + self.debug(f"have {len(self.dag.all_vertices)} vertices") for v in self.dag.all_vertices: if show_only_active_vertices is True and v.active is False: continue @@ -301,378 +346,579 @@ def to_dot(self, graph_format: str = "svg", show_version: bool = True, show_only return dot - def _get_directory_user_vertices(self, configuration_vertex: DAGVertex, domain_name: str) -> List[DAGVertex]: + def _init_cleanup_user_mapping(self): + """ - Find the directory in the graph and return of list of user vertices. + Create of mapping of existing user services to see what was updated. + + This is the basically graph in dictionary format with the update flag set to False. """ - domain_name = domain_name.lower() - - user_vertices: List[DAGVertex] = [] - - # Check the configuration; it might provide domains. - # Need to only include the user vertices. - # If we find it here, we don't need to check for directories; so return with the list. - config_content = DiscoveryObject.get_discovery_object(configuration_vertex) - if config_content.record_type in DOMAIN_USER_CONFIGS: - config_domains = config_content.item.info.get("domains", []) - self.logger.debug(f" the provider provides domains: {config_domains}") - for config_domain in config_domains: - if config_domain.lower() == domain_name: - self.logger.debug(f" matched for {domain_name}") - for vertex in configuration_vertex.has_vertices(): - content = DiscoveryObject.get_discovery_object(vertex) - if content.record_type == PAM_USER: - user_vertices.append(vertex) - self.logger.debug(f" found {len(user_vertices)} users for {domain_name}") - return user_vertices - - self.logger.debug(" checking pam directories for users") - - # If the configuration did not have domain users, or there were do users, check the PAM Directories. - for resource_vertex in configuration_vertex.has_vertices(): - content = DiscoveryObject.get_discovery_object(resource_vertex) - if content.record_type != PAM_DIRECTORY: - continue - if content.name.lower() == domain_name: - user_vertices = resource_vertex.has_vertices() - self.logger.debug(f" found {len(user_vertices)} users for {domain_name}") - break + self.cleanup_mapping = {} + for user_service_machine in self.dag.get_root.has_vertices(): + if user_service_machine.uid not in self.cleanup_mapping: + self.cleanup_mapping[user_service_machine.uid] = {} + for user_service_user in user_service_machine.has_vertices(): + self.cleanup_mapping[user_service_machine.uid][user_service_user.uid] = False - return user_vertices + def _user_is_used(self, machine_record_uid: str, user_record_uid: str): - def _get_user_vertices(self, - infra_resource_content: DiscoveryObject, - infra_resource_vertex: DAGVertex) -> List[DAGVertex]: + """ + Flag the user exists for a machine. + """ - self.logger.debug(f" getting users for {infra_resource_content.name}") + if machine_record_uid in self.cleanup_mapping and user_record_uid in self.cleanup_mapping[machine_record_uid]: + self.cleanup_mapping[machine_record_uid][user_record_uid] = True - # If this machine joined to a directory. - # Since this a Windows machine, we can have only one joined directory; take the first one. - domain_name = None - if len(infra_resource_content.item.facts.directories) > 0: - domain_name = infra_resource_content.item.facts.directories[0].domain - self.logger.debug(f" joined to {domain_name}") - - # Get a list of local users. - # If the machine is joined to a domain, get a list of users from that domain. - user_vertices = infra_resource_vertex.has_vertices() - self.logger.debug(f" found {len(user_vertices)} local users") - if domain_name is not None: - user_vertices += self._get_directory_user_vertices( - configuration_vertex=infra_resource_vertex.belongs_to_vertices()[0], - domain_name=domain_name - ) - - self.logger.debug(f" found {len(user_vertices)} total users") - - return user_vertices - - def _connect_service_users(self, - infra_resource_content: DiscoveryObject, - infra_resource_vertex: DAGVertex, - services: List[FactsNameUser]): - - self.logger.debug(f"processing services for {infra_resource_content.description} ({infra_resource_vertex.uid})") - - # We don't care about the name of the service, we just need a list users. - lookup = {} - for service in services: - lookup[service.user.lower()] = True - - infra_user_vertices = self._get_user_vertices(infra_resource_content=infra_resource_content, - infra_resource_vertex=infra_resource_vertex) - - for infra_user_vertex in infra_user_vertices: - infra_user_content = DiscoveryObject.get_discovery_object(infra_user_vertex) - if infra_user_content.record_uid is None: + def _cleanup_users(self): + + """ + Disconnect all users from machines that are not used. + """ + + self.debug("cleaning up unused user service relationships") + + did_something = False + for machine_record_uid in self.cleanup_mapping: + for user_record_uid in self.cleanup_mapping[machine_record_uid]: + if not self.cleanup_mapping[machine_record_uid][user_record_uid]: + self.debug(f" * disconnect user {user_record_uid} from machine {machine_record_uid}") + did_something = True + self.disconnect_from(machine_record_uid, user_record_uid) + if not did_something: + self.debug(f" nothing to cleanup") + + @staticmethod + def _get_local_users_from_record(record_lookup_func: Callable, + rl_machine_vertex: DAGVertex) -> Dict[str, str]: + + # Get the local users + user_records: Dict[str, str] = {} + + for rl_user_vertex in rl_machine_vertex.has_vertices(): + record = record_lookup_func(rl_user_vertex.uid, allow_sm=False) # type: NormalizedRecord + if record and record.record_type == PAM_USER: + user = record.get_user() + if user is not None: + user, domain = split_user_and_domain(user.lower()) + if domain is not None: + user += "@" + domain + user_records[user] = record.record_uid + + return user_records + + @staticmethod + def _get_local_users_from_infra(record_lookup_func: Callable, + infra_machine_vertex: DAGVertex) -> Dict[str, str]: + + user_records: Dict[str, str] = {} + for infra_user_vertex in infra_machine_vertex.has_vertices(): + user_content = DiscoveryObject.get_discovery_object(infra_user_vertex) + if user_content.record_type != PAM_USER or user_content.record_uid is None: continue - if user_in_lookup( - lookup=lookup, - user=infra_user_content.item.user, - name=infra_user_content.name, - source=infra_user_content.item.source): - self.logger.debug(f" * found user for service: {infra_user_content.item.user}") - acl = self.get_acl(infra_resource_content.record_uid, infra_user_content.record_uid) - if acl is None: - acl = ServiceAcl() - acl.is_service = True - self.belongs_to( - resource_uid=infra_resource_content.record_uid, - resource_name=infra_resource_content.uid, - user_uid=infra_user_content.record_uid, - user_name=infra_user_content.uid, - acl=acl) - - def _connect_task_users(self, - infra_resource_content: DiscoveryObject, - infra_resource_vertex: DAGVertex, - tasks: List[FactsNameUser]): - - self.logger.debug(f"processing tasks for {infra_resource_content.description} ({infra_resource_vertex.uid})") - - # We don't care about the name of the tasks, we just need a list users. - lookup = {} - for task in tasks: - lookup[task.user.lower()] = True - - infra_user_vertices = self._get_user_vertices(infra_resource_content=infra_resource_content, - infra_resource_vertex=infra_resource_vertex) - - for infra_user_vertex in infra_user_vertices: - infra_user_content = DiscoveryObject.get_discovery_object(infra_user_vertex) - if infra_user_content.record_uid is None: + if record_lookup_func(user_content.record_uid, allow_sm=False): + user, domain = split_user_and_domain(user_content.item.user.lower()) + if domain is not None: + user += "@" + domain + user_records[user] = user_content.record_uid + + return user_records + + def _get_directory_users_from_conf_record(self, + record_linking: RecordLink, + domain_name: str, + record_lookup_func: Callable) -> Dict[str, str]: + + user_records: Dict[str, str] = {} + + # check if a PAM configuration that support having users (Azure, Domain Controller) + # We need to get the normalized record of the configuration record. + configuration_record = record_lookup_func( + self.conn.get_record_uid(self.record), allow_sm=False) # type: NormalizedRecord + if configuration_record.record_type in DOMAIN_USER_CONFIGS: + # The Domain Controller record will have the domain; Azure record will not. + config_domain_name = configuration_record.get_value(label="pamdomainid") + + # If the domain name is not set, or it is, and we match the one that machine is joined to. + if config_domain_name is None or config_domain_name.lower() == domain_name: + config_vertex = record_linking.dag.get_vertex(configuration_record.record_uid) + for child_vertex in config_vertex.has_vertices(): + user_record = record_lookup_func(child_vertex.uid, allow_sm=False) # type: NormalizedRecord + if not user_record: + # self.debug(f" * record uid {child_vertex.uid} not found") + continue + if user_record.record_type != PAM_USER: + # self.debug(f" * record uid {child_vertex.uid} is not PAM User") + continue + user, domain = split_user_and_domain(user_record.get_user().lower()) + if domain is None: + domain = domain_name + user += "@" + domain + user_records[user] = user_record.record_uid + else: + self.debug(f" domain name {config_domain_name} does not match {domain_name}") + else: + self.debug(" configuration type does not allow AD users") + + return user_records + + def _get_directory_users_from_conf_infra(self, + infra: Infrastructure, + domain_name: str, + record_lookup_func: Callable) -> Dict[str, str]: + + user_records: Dict[str, str] = {} + + config_vertex = infra.get_configuration + config_context = DiscoveryObject.get_discovery_object(config_vertex) + if config_context.record_type in DOMAIN_USER_CONFIGS: + for config_domain_name in config_context.item.info.get("domains", []): + if config_domain_name != domain_name: + self.debug(f" domain name {config_domain_name} does not match {domain_name}") + continue + for child_vertex in config_vertex.has_vertices(): + child_context = DiscoveryObject.get_discovery_object(child_vertex) + if child_context.record_type == PAM_USER and record_lookup_func(child_context.record_uid, + allow_sm=False): + user, domain = split_user_and_domain(child_context.item.user.lower()) + if domain is None: + domain = domain_name + user += "@" + domain + user_records[user] = child_context.record_uid + + return user_records + + def _get_directory_users_from_records(self, + record_linking: RecordLink, + domain_name: str, + record_lookup_func: Callable) -> Dict[str, str]: + + user_records: Dict[str, str] = {} + + # From the record linking graph, check each record connected to the configuration to see if it is a + # PAM directory record. + for rl_resource_vertex in record_linking.dag.get_root.has_vertices(): + directory_record = record_lookup_func(rl_resource_vertex.uid, allow_sm=False) # type: NormalizedRecord + if directory_record and directory_record.record_type == PAM_DIRECTORY: + record_domain_name = directory_record.get_value(label="domainName") + if record_domain_name is None: + self.logger.warning(f" record uid {rl_resource_vertex.uid} is a directory, but the " + "Domain Name is not set.") + continue + if record_domain_name.lower() == domain_name: + self.debug(f" record uid {rl_resource_vertex.uid} matches the domain name") + for rl_user_vertex in rl_resource_vertex.has_vertices(): + user_record = record_lookup_func(rl_user_vertex.uid, allow_sm=False) # type: NormalizedRecord + if user_record is None or user_record.record_type != PAM_USER: + continue + + # Get the directory users, format the username to be user@domain + user = user_record.get_user() + if user is not None: + user, domain = split_user_and_domain(user.lower()) + if domain is None: + domain = domain_name + user += "@" + domain + user_records[user] = user_record.record_uid + else: + self.debug(f" ! record uid {rl_user_vertex.uid} has a blank user") + + return user_records + + @staticmethod + def _get_directory_users_from_infra(infra_machine_vertex: DAGVertex, + domain_name: str, + record_lookup_func: Callable) -> Dict[str, str]: + + user_records: Dict[str, str] = {} + + configuration_vertex = infra_machine_vertex.belongs_to_vertices()[0] + for resource_vertex in configuration_vertex.has_vertices(): + if not resource_vertex.has_data: continue - if user_in_lookup( - lookup=lookup, - user=infra_user_content.item.user, - name=infra_user_content.name, - source=infra_user_content.item.source): - self.logger.debug(f" * found user for task: {infra_user_content.item.user}") - acl = self.get_acl(infra_resource_content.record_uid, infra_user_content.record_uid) - if acl is None: - acl = ServiceAcl() - acl.is_task = True - self.belongs_to( - resource_uid=infra_resource_content.record_uid, - resource_name=infra_resource_content.uid, - user_uid=infra_user_content.record_uid, - user_name=infra_user_content.uid, - acl=acl) - - def _connect_iis_pool_users(self, - infra_resource_content: DiscoveryObject, - infra_resource_vertex: DAGVertex, - iis_pools: List[FactsNameUser]): - - self.logger.debug(f"processing iis pools for " - f"{infra_resource_content.description} ({infra_resource_vertex.uid})") - - # We don't care about the name of the tasks, we just need a list users. - lookup = {} - for iis_pool in iis_pools: - lookup[iis_pool.user.lower()] = True - - infra_user_vertices = self._get_user_vertices(infra_resource_content=infra_resource_content, - infra_resource_vertex=infra_resource_vertex) - - for infra_user_vertex in infra_user_vertices: - infra_user_content = DiscoveryObject.get_discovery_object(infra_user_vertex) - if infra_user_content.record_uid is None: + resource_content = DiscoveryObject.get_discovery_object(resource_vertex) + if resource_content.record_type != PAM_DIRECTORY or resource_content.name.lower() != domain_name: continue - if user_in_lookup( - lookup=lookup, - user=infra_user_content.item.user, - name=infra_user_content.name, - source=infra_user_content.item.source): - self.logger.debug(f" * found user for iis pool: {infra_user_content.item.user}") - acl = self.get_acl(infra_resource_content.record_uid, infra_user_content.record_uid) - if acl is None: - acl = ServiceAcl() - acl.is_iis_pool = True - self.belongs_to( - resource_uid=infra_resource_content.record_uid, - resource_name=infra_resource_content.uid, - user_uid=infra_user_content.record_uid, - user_name=infra_user_content.uid, - acl=acl) - - def _validate_users(self, - infra_resource_content: DiscoveryObject, - infra_resource_vertex: DAGVertex): + for user_vertex in resource_vertex.has_vertices(): + if not user_vertex.has_data: + continue + user_content = DiscoveryObject.get_discovery_object(user_vertex) + if user_content.record_type != PAM_USER and user_content.record_uid is None: + continue + if record_lookup_func(user_content.record_uid, allow_sm=False): + + # Format the username to be user@domain + user, domain = split_user_and_domain(user_content.item.user.lower()) + if domain is None: + domain = domain_name + user += "@" + domain + user_records[user] = user_content.record_uid + return user_records + + def _get_users(self, + infra: Infrastructure, + infra_machine_content: DiscoveryObject, + infra_machine_vertex: DAGVertex, + record_linking: RecordLink, + record_lookup_func: Callable) -> Dict[str, str]: """ - This method will check to see if a resource's users' ACL edges are still valid. + Get local and directory users for machine. + + The return values will be a dictionary of record_uid to username. - This check will check both local and directory users. + It will first check the records linking graph. Then check the infrastructure graph. """ - self.logger.debug(f"validate existing user service edges to see if still valid to " - f"{infra_resource_content.name}") + self.debug(f" getting users for {infra_machine_content.name}, {infra_machine_content.record_uid}") - service_lookup = {} - for service in infra_resource_content.item.facts.services: - service_lookup[service.user.lower()] = True + # Get the domain name that the machine it joined to. + # Only accept the first one; we are Windows, only allow one domain. + domain_name = None + for directory in infra_machine_content.item.facts.directories: + if directory.domain is not None: + domain_name = directory.domain.lower() + self.debug(f" machine is joined to {domain_name}") + break + + # Keep separate dictionaries since we are going to cache the directory users by domain name. + # { "user": "record uid", ... } + local_user_records: Dict[str, str] = {} + directory_user_records: Dict[str, str] = {} + + using_directory_user_cache = False + if domain_name: + # Once we get directory users for a domain name, they will not change. + # Cache them so we don't have to get them again. + if self.directory_user_cache is not None: + directory_user_records = self.directory_user_cache.get(domain_name) + self.debug(f" using directory user cache for {domain_name}, " + f"{len(directory_user_records)} users") + using_directory_user_cache = True + + ########################### + + # Find the users using the record linking graph. + self.debug(f" getting users from record linking", level=1) + record_link_vertex = record_linking.dag.get_vertex(infra_machine_content.record_uid) + if record_link_vertex is None: + self.debug(" record uid {machine_record_uid} does not exist in the Vault.", level=1) + else: - task_lookup = {} - for task in infra_resource_content.item.facts.tasks: - task_lookup[task.user.lower()] = True + # Get the local users from records + self.debug(" getting local users from records", level=1) + user_records = self._get_local_users_from_record(rl_machine_vertex=record_link_vertex, + record_lookup_func=record_lookup_func) + self.debug(f" * found {len(user_records)} local users from records", level=1) + local_user_records = {**local_user_records, **user_records} - iis_pool_lookup = {} - for iss_pool in infra_resource_content.item.facts.iis_pools: - iis_pool_lookup[iss_pool.user.lower()] = True + if not using_directory_user_cache and domain_name is not None: - # Get the user service resource vertex. - # If it does not exist, then we cannot validate users. - user_service_resource_vertex = self.dag.get_vertex(infra_resource_content.record_uid) - if user_service_resource_vertex is None: - return + self.debug(" getting directory users from the configuration record", level=1) + user_records = self._get_directory_users_from_conf_record(record_linking=record_linking, + domain_name=domain_name, + record_lookup_func=record_lookup_func) - infra_dag = infra_resource_vertex.dag + self.debug(f" * found {len(user_records)} directory users records from " + "the configuration record", level=1) + directory_user_records = {**directory_user_records, **user_records} - # The users from the service graph will contain local and directory users. - for user_service_user_vertex in user_service_resource_vertex.has_vertices(): - acl_edge = user_service_user_vertex.get_edge( - user_service_resource_vertex, edge_type=EdgeType.ACL) # type: DAGEdge - if acl_edge is None: - self.logger.info(f"User record {user_service_user_vertex.uid} does not have an ACL edge to " - f"{user_service_resource_vertex.uid} for user services.") - continue + self.debug(" getting directory users from directory records", level=1) + user_records = self._get_directory_users_from_records(record_linking=record_linking, + domain_name=domain_name, + record_lookup_func=record_lookup_func) + self.debug(f" * found {len(user_records)} directory users from records for {domain_name}", + level=1) + + directory_user_records = {**directory_user_records, **user_records} + + #################### + + # Find the users via infrastructure graph + + self.debug(f" getting users from infrastructure", level=1) + self.debug(" getting local users from infrastructure", level=1) + user_records = self._get_local_users_from_infra(infra_machine_vertex=infra_machine_vertex, + record_lookup_func=record_lookup_func) + self.debug(f" * found {len(user_records)} local users from graph", level=1) + local_user_records = {**user_records, **local_user_records} + + if not using_directory_user_cache and domain_name is not None: + + self.debug(" getting directory users from configuration infrastructure", level=1) + user_records = self._get_directory_users_from_conf_infra(infra=infra, + domain_name=domain_name, + record_lookup_func=record_lookup_func) + self.debug(f" * found {len(user_records)} directory users from configuration for {domain_name}", + level=1) + directory_user_records = {**user_records, **directory_user_records} + + # ------------- + + self.debug(" getting directory users from directory infrastructure", level=1) + user_records = self._get_directory_users_from_infra(infra_machine_vertex=infra_machine_vertex, + domain_name=domain_name, + record_lookup_func=record_lookup_func) + self.debug(f" * found {len(user_records)} directory users from graph for {domain_name}", level=1) + directory_user_records = {**user_records, **directory_user_records} + + # If we were not using the directory cache, cache them. + if domain_name is not None and not using_directory_user_cache: + if self.directory_user_cache is None: + self.directory_user_cache = {} + self.directory_user_cache[domain_name] = directory_user_records + + all_record = {**directory_user_records, **local_user_records} + + self.debug(f" total union of users count {len(all_record.keys())}") + + return all_record + + def _connect_users_to_services(self, + infra: Infrastructure, + infra_machine_content: DiscoveryObject, + infra_machine_vertex: DAGVertex, + record_linking: RecordLink, + record_lookup_func: Callable, + strict: bool = False): + + domain_name = None + for directory in infra_machine_content.item.facts.directories: + if directory.domain is not None: + domain_name = directory.domain.lower() + break + + # Add mapping from user to machine, that control services. + for service_type in ["service", "task", "iis_pool"]: + self.debug("-" * 40) + self.debug(f"processing {service_type}s for {infra_machine_content.name} " + f"({infra_machine_vertex.uid})") + + # We don't care about the name of the service, we just need a list users. + service_users = [] + for service_user in getattr(infra_machine_content.item.facts, f"{service_type}s"): + self.debug(f" * {service_type}: {service_user.name} ({service_user.user})", secret=True) + user = service_user.user.lower() + if not strict: + user, domain = split_user_and_domain(user) + service_users.append(user) + if domain is not None and domain != ".": + service_users.append(user + "@" + domain) + service_users.append(user + "@" + domain.split(".")[0]) + if domain_name is not None: + service_users.append(user + "@" + domain_name) + service_users.append(user + "@" + domain_name.split(".")[0]) - found_service_acl = False - found_task_acl = False - found_iis_pool_acl = False - changed = False - - acl = acl_edge.content_as_object(ServiceAcl) - - # This will check the entire infrastructure graph for the user with the record UID. - # This could be a local or directory users. - user = infra_dag.search_content({"record_type": PAM_USER, "record_uid": user_service_user_vertex.uid}) - infra_user_content = None - found_user = len(user) > 0 - if found_user: - infra_user_vertex = user[0] - if infra_user_vertex.active is False: - found_user = False else: - infra_user_content = DiscoveryObject.get_discovery_object(infra_user_vertex) + service_users.append(user) + + service_users = list(set(service_users)) - if not found_user: - self.disconnect_from(user_service_resource_vertex.uid, user_service_user_vertex.uid) + if len(service_users) == 0: + self.debug(f" no users control {service_type}s, skipping.") continue - check_list = user_check_list( - user=infra_user_content.item.user, - name=infra_user_content.name, - source=infra_user_content.item.source - ) - - if acl.is_service: - for check_user in check_list: - if check_user in service_lookup: - found_service_acl = True - break - if not found_service_acl: - acl.is_service = False - changed = True - - if acl.is_task: - for check_user in check_list: - if check_user in task_lookup: - found_task_acl = True - break - if not found_task_acl: - acl.is_task = False - changed = True - - if acl.is_iis_pool: - for check_user in check_list: - if check_user in iis_pool_lookup: - found_iis_pool_acl = True - break - if not found_iis_pool_acl: - acl.is_iis_pool = False - changed = True - - if (found_service_acl is True or found_task_acl is True or found_iis_pool_acl is True) or changed is True: - self.logger.debug(f"user {user_service_user_vertex.uid}(US) to " - f"{user_service_resource_vertex.uid} updated") - self.belongs_to(user_service_resource_vertex.uid, user_service_user_vertex.uid, acl) - elif found_service_acl is False and found_task_acl is False and found_iis_pool_acl is False: - self.logger.debug(f"user {user_service_user_vertex.uid}(US) to " - f"{user_service_resource_vertex.uid} disconnected") - self.disconnect_from(user_service_resource_vertex.uid, user_service_user_vertex.uid) - - self.logger.debug(f"DONE validate existing user") - - def run(self, infra: Optional[Infrastructure] = None, **kwargs): + users = self._get_users(infra=infra, + infra_machine_content=infra_machine_content, + infra_machine_vertex=infra_machine_vertex, + record_linking=record_linking, + record_lookup_func=record_lookup_func) + + if self.log_finer_level >= 2 and self.insecure_debug: + for k, v in users.items(): + self.debug(f"> {k} = {v}") + + self.debug(f"users to check: {service_users}", secret=True) + for service_user in service_users: + self.debug(f" * {service_user}", secret=True) + if service_user in users: + record_uid = users[service_user] + self.debug(f" found user {service_user} for {service_type}", secret=True) + acl = self.get_acl(infra_machine_content.record_uid, record_uid) + if acl is None: + acl = ServiceAcl() + acl_attr = "is_" + service_type + + # Flag the user was found; don't disconnect + self._user_is_used(machine_record_uid=infra_machine_content.record_uid, + user_record_uid=record_uid) + + # Only update if the attribute is currently False; reduce edges. + if getattr(acl, acl_attr) is False: + setattr(acl, acl_attr, True) + self.belongs_to(resource_uid=infra_machine_content.record_uid, + user_uid=record_uid, + acl=acl) + + def _get_resource_info(self, + record_uid: str, + infra: Infrastructure, + record_lookup_func: Callable, + record_types: Optional[List[str]] = None) -> Optional[NormalizedRecord]: + """ - Map users to services/tasks on machines. + Find a resource, or user, in the Vault or in the Infrastructure graph. + + This will return a NormalizedRecord record. + This doesn't mean the - IMPORTANT: To avoid memory leaks, pass an existing Infrastructure instance - instead of letting this method create a new one. Example: - user_service.run(infra=process.infra) """ - self.logger.debug("") - self.logger.debug("##########################################################################################") - self.logger.debug("# MAP USER TO MACHINE FOR SERVICE/TASKS") - self.logger.debug("") - - # If an instance of Infrastructure is not passed in. - # NOTE: Creating a new Infrastructure instance here can cause memory leaks. - # Prefer passing an existing instance via the infra parameter. - _cleanup_infra_on_exit = False - if infra is None: - self.logger.warning("Creating new Infrastructure instance - consider passing existing instance to avoid memory leaks") - - # Get ksm from the connection. - # However, this might be a local connection, so check first. - # Local connections don't need ksm. - if hasattr(self.conn, "ksm"): - kwargs["ksm"] = getattr(self.conn, "ksm") - - # Get the entire infrastructure graph; sync point = 0 - infra = Infrastructure(record=self.record, **kwargs) - infra.load() - _cleanup_infra_on_exit = True - - # Work ourselves to the configuration vertex. - infra_root_vertex = infra.get_root - infra_config_vertex = infra_root_vertex.has_vertices()[0] - - # For the user service, the root vertex is the equivalent to the infrastructure configuration vertex. - user_service_config_vertex = self.dag.get_root - - # Find all the resources that are machines. - for infra_resource_vertex in infra_config_vertex.has_vertices(): - if infra_resource_vertex.active is False or infra_resource_vertex.has_data is False: - continue - infra_resource_content = DiscoveryObject.get_discovery_object(infra_resource_vertex) - if infra_resource_content.record_type == PAM_MACHINE: + # Check the record first; return a NormalizedRecord + record = record_lookup_func(record_uid, allow_sm=False) # type: NormalizedRecord + if record is not None: + self.debug(f" resource is {record.title}") + if record_types is not None and record.record_type not in record_types: + self.debug(f" not correct record type: {record.record_type}") + return None + return record + else: + self.debug(" not in Vault") + + infra_vertices = infra.dag.search_content({"record_uid": record_uid}) + if not len(infra_vertices): + self.debug(" not in infrastructure graph") + return None - self.logger.debug(f"checking {infra_resource_content.name}") + for vertex in infra_vertices: + if vertex.active: + content = DiscoveryObject.get_discovery_object(vertex) + record = NormalizedRecord( + record_uid=record_uid, + record_type=content.record_type, + title=content.title, + record_exists=False + ) + for field in content.fields: + record.fields.append(field) + + return record + + return None + + def run_user(self): + pass + + def run_full(self, + record_lookup_func: Callable, + infra: Optional[Infrastructure] = None, + record_linking: Optional[RecordLink] = None, + **kwargs): + """ + Map users to services on machines. - # Check the user on the resource if they still are part of a service or task. - self._validate_users(infra_resource_content, infra_resource_vertex) + This is driven by the record linking graph. + + :param infra: Instance of Infrastructure graph. + :param record_linking: Instance of the Record Linking graph. + :param record_lookup_func: A function that will return a record by record id. Returns a normalize record. + """ + + self.debug("") + self.debug("##########################################################################################") + self.debug("# MAP USER TO MACHINE FOR SERVICES") + self.debug("") + + # Load fresh + + created_infra = False + created_record_linking = False + + try: + + # Make of map of the current user to machine relationship. + self._init_cleanup_user_mapping() + + if not infra: + infra = Infrastructure(record=self.record, logger=self.logger, ksm=self._ksm, params=self._params) + infra.load(sync_point=0) + created_infra = True + + if not record_linking: + record_linking = RecordLink(record=self.record, logger=self.logger, ksm=self._ksm, params=self._params) + created_record_linking = True + + # The PAM Configuration record is the root vertex of the PAM/record linking graph. + rl_configuration_vertex = record_linking.dag.get_root + + # At this level the vertex will either be a resource or a cloud user. + for rl_resource_vertex in rl_configuration_vertex.has_vertices(): + + self.debug(f"checking record {rl_resource_vertex.uid}") + + # This will get machine from the records or from infrastructure graph. + # The results is a NormalizedRecord. + machine_record = self._get_resource_info(record_uid=rl_resource_vertex.uid, + infra=infra, + record_lookup_func=record_lookup_func, + record_types=[PAM_MACHINE]) + + if machine_record is None: + self.debug(" could not find record") + continue + + if machine_record.record_type != PAM_MACHINE: + self.debug(" record is not PAM Machine") + continue + + self.debug(f" checking machine {machine_record.title}") + + # Since the facts hold information about services, get those from the infrastructure graph. + infra_machine_vertex = infra.find_content({"record_uid": machine_record.record_uid}) + if not infra_machine_vertex: + self.debug(" could not find machine in the infrastructure graph, skipping") + continue + if not infra_machine_vertex.has_data: + self.debug(" machine has no data yet, skipping") + continue + + infra_machine_content = DiscoveryObject.get_discovery_object(infra_machine_vertex) + + # The `services` are currently on Windows machine, skip any machine that is not running Windows. + if infra_machine_content.item.os != "windows": + self.debug(" machine is not Windows, skipping") + continue # Do we have services, tasks, iis_pools that are run as a user with a password? - if infra_resource_content.item.facts.has_service_items is True: - - # If the resource does not exist in the user service graph, add a vertex and link it to the - # user service root/configuration vertex. - user_service_resource_vertex = self.dag.get_vertex(infra_resource_content.record_uid) - if user_service_resource_vertex is None: - user_service_resource_vertex = self.dag.add_vertex(uid=infra_resource_content.record_uid, - name=infra_resource_content.description) - if not user_service_config_vertex.has(user_service_resource_vertex): - user_service_resource_vertex.belongs_to_root(EdgeType.LINK) - - # Do we have services that are run as a user with a password? - if infra_resource_content.item.facts.has_services is True: - self._connect_service_users( - infra_resource_content, - infra_resource_vertex, - infra_resource_content.item.facts.services) - - # Do we have tasks that are run as a user with a password? - if infra_resource_content.item.facts.has_tasks is True: - self._connect_task_users( - infra_resource_content, - infra_resource_vertex, - infra_resource_content.item.facts.tasks) - - # Do we have tasks that are run as a user with a password? - if infra_resource_content.item.facts.has_iis_pools is True: - self._connect_iis_pool_users( - infra_resource_content, - infra_resource_vertex, - infra_resource_content.item.facts.iis_pools) - - self.save() - - # Clean up the Infrastructure instance if we created it - if _cleanup_infra_on_exit and infra is not None: - self.logger.debug("cleaning up Infrastructure instance created in run()") - infra.close() + if not infra_machine_content.item.facts.has_service_items: + self.debug(" machine has no user controlled services, skipping") + continue + + user_service_machine_vertex = self.dag.get_vertex(infra_machine_content.record_uid) + + # If the resource does not exist in the user service graph, add a vertex and link it to the + # user service root/configuration vertex. + if user_service_machine_vertex is None: + user_service_machine_vertex = self.dag.add_vertex(uid=infra_machine_content.record_uid, + name=infra_machine_content.name) + + # If the UserService resource vertex is not connect to root, connect it. + if not self.dag.get_root.has(user_service_machine_vertex): + user_service_machine_vertex.belongs_to_root(EdgeType.LINK) + + self.debug("-" * 40) + self._connect_users_to_services( + infra=infra, + infra_machine_content=infra_machine_content, + infra_machine_vertex=infra_machine_vertex, + record_linking=record_linking, + record_lookup_func=record_lookup_func) + self.debug("-" * 40) + + # Disconnect any users not used. + # TODO - Handle this better. + # If a machine is off, or we cannot connect, we might disconnect users. + # This needs more testing. + # self._cleanup_users() + + self.save() + + except Exception as err: + self.logger.error(f"could not map users to services: {err}") + raise err + + finally: + if created_infra: + infra.close() + if created_record_linking: + record_linking.close() diff --git a/keepercommander/discovery_common/utils.py b/keepercommander/discovery_common/utils.py index 0a5c7db92..5d8b06d0d 100644 --- a/keepercommander/discovery_common/utils.py +++ b/keepercommander/discovery_common/utils.py @@ -4,7 +4,7 @@ from .types import DiscoveryObject from ..keeper_dag.vertex import DAGVertex from .__version__ import __version__ -from typing import List, Optional, Tuple, TYPE_CHECKING +from typing import Optional, Tuple, TYPE_CHECKING if TYPE_CHECKING: from ..keeper_dag.dag import DAG @@ -65,62 +65,28 @@ def get_connection(**kwargs): def split_user_and_domain(user: str) -> Tuple[Optional[str], Optional[str]]: + """ + If the username is a UPN, email, netbios\\username, break it apart into user and domain/netbios. + """ + if user is None: return None, None domain = None - if "\\" in user: - user_parts = user.split("\\", maxsplit=1) + if "@" in user: + user_parts = user.split("@", maxsplit=1) user = user_parts[0] + if "\\" in user: + _, user = user.split("\\") domain = user_parts[1] - elif "@" in user: - user_parts = user.split("@") - domain = user_parts.pop() - user = "@".join(user_parts) + elif "\\" in user: + user_parts = user.split("\\", maxsplit=1) + user = user_parts[1].replace("\\", "") + domain = user_parts[0] return user, domain - -def user_check_list(user: str, name: Optional[str] = None, source: Optional[str] = None) -> List[str]: - user, domain = split_user_and_domain(user) - user = user.lower() - - # TODO: Add boolean for tasks to include `local users` patterns. - # It appears that for task lists, directory users do not have domains. - # A problem could arise where the customer uses a local user and directory with the same name. - check_list = [user, f".\\{user}"] - if name is not None: - name = name.lower() - check_list += [name, f".\\{name}"] - if source is not None: - source = source.lower() - check_list.append(f"{source[:15]}\\{user}") - check_list.append(f"{user}@{source}") - netbios_parts = source.split(".") - if len(netbios_parts) > 1: - check_list.append(f"{netbios_parts[0][:15]}\\{user}") - check_list.append(f"{user}@{netbios_parts[0]}") - if domain is not None: - domain = domain.lower() - check_list.append(f"{domain[:15]}\\{user}") - check_list.append(f"{user}@{domain}") - domain_parts = domain.split(".") - if len(domain_parts) > 1: - check_list.append(f"{domain_parts[0][:15]}\\{user}") - check_list.append(f"{user}@{domain_parts[0]}") - - return list(set(check_list)) - - -def user_in_lookup(user: str, lookup: dict, name: Optional[str] = None, source: Optional[str] = None) -> bool: - - for check_user in user_check_list(user, name, source): - if check_user in lookup: - return True - return False - - def find_user_vertex(graph: DAG, user: str, domain: Optional[str] = None) -> Optional[DAGVertex]: user_vertices = graph.search_content({"record_type": PAM_USER}) diff --git a/keepercommander/importer/thycotic/thycotic.py b/keepercommander/importer/thycotic/thycotic.py index 77c387c83..ccd5b9029 100644 --- a/keepercommander/importer/thycotic/thycotic.py +++ b/keepercommander/importer/thycotic/thycotic.py @@ -321,12 +321,12 @@ def do_import(self, filename, **kwargs): filter_folder = kwargs.get('filter_folder') if filter_folder: if filter_folder == 'Personal Folders': - folder_ids = [1] + matched_folder_ids = [1] else: - folder_ids = [x['id'] for x in folders.values() - if x['folderName'] == x['folderPath'] and x['folderName'].lower() == filter_folder.lower()] - if len(folder_ids) == 0: + matched_folder_ids = [x['id'] for x in folders.values() if x['folderName'].lower() == filter_folder.lower()] + if len(matched_folder_ids) == 0: logging.warning('Folder \"%s\" not found', filter_folder) + folder_ids = list(matched_folder_ids) pos = 0 while pos < len(folder_ids): folder_id = folder_ids[pos] @@ -335,13 +335,8 @@ def do_import(self, filename, **kwargs): folder_ids = set(folder_ids) folders = {i: x for i, x in folders.items() if i in folder_ids} - if filter_folder: - if filter_folder == 'Personal Folders': - root_folder_ids = [1] - else: - root_folder_ids = [x['id'] for x in folders.values() if x['folderName'] == x['folderPath']] secrets_ids = [] - for folder_id in root_folder_ids: + for folder_id in matched_folder_ids: query = f'/v1/secrets/lookup?filter.folderId={folder_id}&filter.includeSubFolders=true' secrets_ids.extend([x['id'] for x in auth.thycotic_search(query)]) else: diff --git a/keepercommander/keeper_dag/__version__.py b/keepercommander/keeper_dag/__version__.py index 394531931..a98f9837e 100644 --- a/keepercommander/keeper_dag/__version__.py +++ b/keepercommander/keeper_dag/__version__.py @@ -1 +1 @@ -__version__ = '1.1.6' # pragma: no cover +__version__ = '1.1.9' # pragma: no cover diff --git a/keepercommander/keeper_dag/connection/local.py b/keepercommander/keeper_dag/connection/local.py index 96c008ff6..e8f2e79dc 100644 --- a/keepercommander/keeper_dag/connection/local.py +++ b/keepercommander/keeper_dag/connection/local.py @@ -86,6 +86,7 @@ def get_key_bytes(record: object) -> bytes: def clear_database(self): try: + print(f"remove DAG file as {self.db_file}") os.unlink(self.db_file) except (Exception,): pass @@ -195,6 +196,8 @@ def _find_stream_id(self, payload: DataPayload): # First check if we can route with existing edges in the database. stream_id = None + if not os.path.exists(self.db_file): + raise Exception(f"Cannot find local DAG as {self.db_file}") with closing(sqlite3.connect(self.db_file)) as connection: with closing(connection.cursor()) as cursor: diff --git a/keepercommander/keeper_dag/dag.py b/keepercommander/keeper_dag/dag.py index a2d6ddfed..81a75339c 100644 --- a/keepercommander/keeper_dag/dag.py +++ b/keepercommander/keeper_dag/dag.py @@ -93,8 +93,7 @@ def __init__(self, if logger is None: logger = logging.getLogger() self.logger = logger - if debug_level is None: - debug_level = int(os.environ.get("GS_DEBUG_LEVEL", os.environ.get("DAG_DEBUG_LEVEL", 0))) + self.debug_level = int(os.environ.get("GS_DEBUG_LEVEL", os.environ.get("DAG_DEBUG_LEVEL", debug_level))) # Prevent duplicate edges to be added. # The goal is to prevent unneeded edges. @@ -106,7 +105,6 @@ def __init__(self, raise Exception("Cannot run dedup_edge and auto_save at the same time. The dedup_edge feature only works " "in bulk saves.") - self.debug_level = debug_level self.log_prefix = log_prefix if save_batch_count is None or save_batch_count <= 0: @@ -1105,8 +1103,8 @@ def _add_data(vertex): self.debug(f"{data.ref.value} -> {data.parentRef.value} ({data.type})") self.debug("##############################################") - self.debug(f"total list has {len(data_list)} items", level=0) - self.debug(f"batch {self.save_batch_count} edges", level=0) + self.debug(f"total list has {len(data_list)} items", level=1) + self.debug(f"batch {self.save_batch_count} edges", level=1) batch_num = 0 while len(data_list) > 0: @@ -1126,7 +1124,7 @@ def _add_data(vertex): if len(batch_list) == 0: break - self.debug(f"adding {len(batch_list)} edges, batch {batch_num}", level=0) + self.debug(f"adding {len(batch_list)} edges, batch {batch_num}", level=1) payload = self.write_struct_obj.payload( origin_ref=self.write_struct_obj.origin_ref( diff --git a/keepercommander/keeper_dag/edge.py b/keepercommander/keeper_dag/edge.py index eae576121..ccc0faaa1 100644 --- a/keepercommander/keeper_dag/edge.py +++ b/keepercommander/keeper_dag/edge.py @@ -3,14 +3,15 @@ from .types import EdgeType from .exceptions import DAGContentException import json -from typing import Optional, Union, Any, TYPE_CHECKING +from typing import Optional, Union, Any, TYPE_CHECKING, TypeVar, Type if TYPE_CHECKING: # pragma: no cover from .vertex import DAGVertex Content = Union[str, bytes, dict] QueryValue = Union[list, dict, str, float, int, bool] - import pydantic - from pydantic import BaseModel + + +T = TypeVar('T') class DAGEdge: @@ -159,8 +160,7 @@ def content_as_str(self) -> Optional[str]: pass return content - def content_as_object(self, - meta_class: pydantic._internal._model_construction.ModelMetaclass) -> Optional[BaseModel]: + def content_as_object(self, meta_class: Type[T]) -> Optional[T]: """ Get the content as a pydantic based object. diff --git a/keepercommander/keeper_dag/vertex.py b/keepercommander/keeper_dag/vertex.py index 989bdf2c2..48c45d0c8 100644 --- a/keepercommander/keeper_dag/vertex.py +++ b/keepercommander/keeper_dag/vertex.py @@ -3,14 +3,15 @@ from .types import EdgeType, RefType from .crypto import generate_random_bytes, generate_uid_str, urlsafe_str_to_bytes from .exceptions import DAGDeletionException, DAGIllegalEdgeException, DAGVertexException, DAGKeyException -from typing import Optional, Union, List, Any, Tuple, TYPE_CHECKING +from typing import Optional, Union, List, Any, Tuple, TYPE_CHECKING, TypeVar, Type if TYPE_CHECKING: from .dag import DAG Content = Union[str, bytes, dict] QueryValue = Union[list, dict, str, float, int, bool] - import pydantic - from pydantic import BaseModel + + +T = TypeVar('T') class DAGVertex: @@ -489,8 +490,7 @@ def content_as_str(self) -> Optional[str]: return None return data_edge.content_as_str - def content_as_object(self, - meta_class: pydantic._internal._model_construction.ModelMetaclass) -> Optional[BaseModel]: + def content_as_object(self, meta_class: Type[T]) -> Optional[T]: """ Get the content as a pydantic based object. @@ -798,7 +798,7 @@ def _delete(vertex, prior_vertex): self.debug(f" * vertex is root, cannot delete root", level=2) return - self.debug(f"> checking vertex {vertex.uid}") + self.debug(f"> checking vertex {vertex.uid}", level=1) # Should we ignore a vertex? # If deleting an edge, we want to ignore the vertex that owns the edge. @@ -821,7 +821,7 @@ def _delete(vertex, prior_vertex): if e.edge_type != EdgeType.DATA and (prior_vertex is None or e.head_uid == prior_vertex.uid): e.delete() if vertex.belongs_to_a_vertex is False: - self.debug(f" * inactive vertex {vertex.uid}") + self.debug(f" * inactive vertex {vertex.uid}", level=1) vertex.active = False self.debug(f"DELETING vertex {self.uid}", level=3) diff --git a/keepercommander/rest_api.py b/keepercommander/rest_api.py index 43af276ba..5f9314dc9 100644 --- a/keepercommander/rest_api.py +++ b/keepercommander/rest_api.py @@ -9,6 +9,7 @@ # Contact: ops@keepersecurity.com # +import re import requests import os import json @@ -142,6 +143,8 @@ def execute_rest(context, endpoint, payload, timeout=None): context.server_key_id = 7 run_request = True + throttle_retries = 0 + max_throttle_retries = 3 while run_request: run_request = False @@ -251,8 +254,25 @@ def execute_rest(context, endpoint, payload, timeout=None): continue elif rs.status_code == 403: if failure.get('error') == 'throttled' and not context.fail_on_throttle: - logging.debug('Throttled, retrying in 10 seconds') - time.sleep(10) + throttle_retries += 1 + if throttle_retries > max_throttle_retries: + raise KeeperApiError(failure.get('error'), failure.get('message')) + # Parse server's suggested wait time, default to 60s + wait_seconds = 60 + message = failure.get('message', '') + wait_match = re.search(r'(\d+)\s*(second|minute)', message, re.IGNORECASE) + if wait_match: + wait_val = int(wait_match.group(1)) + if 'minute' in wait_match.group(2).lower(): + wait_seconds = wait_val * 60 + else: + wait_seconds = wait_val + # Cap server suggestion at 5 minutes, then take the larger of suggestion vs backoff + wait_seconds = min(wait_seconds, 300) + backoff = max(wait_seconds, 30 * (2 ** (throttle_retries - 1))) + logging.warning('Throttled (attempt %d/%d), retrying in %d seconds', + throttle_retries, max_throttle_retries, backoff) + time.sleep(backoff) run_request = True continue elif rs.status_code in (400, 500) and context.qrc_key_id is not None: diff --git a/keepercommander/security_audit.py b/keepercommander/security_audit.py index b5c2f92f6..3c44f74d1 100644 --- a/keepercommander/security_audit.py +++ b/keepercommander/security_audit.py @@ -115,14 +115,16 @@ def needs_security_audit(params, record): # type: (KeeperParams, KeeperRecord) saved_score_data = params.security_score_data.get(record.record_uid, {}) saved_sec_data = params.breach_watch_security_data.get(record.record_uid, {}) score_data = saved_score_data.get('data', {}) + score_revision = saved_score_data.get('revision') + sec_revision = saved_sec_data.get('revision') current_password = _get_pass(record) - if current_password != score_data.get('password') or None: + if current_password != score_data.get('password'): return True scores = dict(new=get_security_score(record) or 0, old=score_data.get('score', 0)) score_changed_on_passkey = any(x >= 100 for x in scores.values()) and any(x < 100 for x in scores.values()) creds_removed = bool(scores.get('old') and not scores.get('new')) - needs_alignment = bool(scores.get('new')) and not saved_sec_data + needs_alignment = current_password is not None and score_revision != sec_revision return score_changed_on_passkey or creds_removed or needs_alignment def update_security_audit_data(params, records): # type: (KeeperParams, List[KeeperRecord]) -> int diff --git a/keepercommander/service/util/parse_keeper_response.py b/keepercommander/service/util/parse_keeper_response.py index 7f8f004f1..99a80fd9c 100644 --- a/keepercommander/service/util/parse_keeper_response.py +++ b/keepercommander/service/util/parse_keeper_response.py @@ -118,6 +118,13 @@ def parse_response(command: str, response: Any, log_output: str = None) -> Dict[ Returns: Dict[str, Any]: Structured JSON response """ + if isinstance(response, dict) and 'status' not in response: + base_cmd = ' '.join(command.split()[:2]) if len(command.split()) >= 2 else command.split()[0] + return { + "status": "success", + "command": base_cmd, + "data": response, + } # Preprocess response once response_str, is_from_log = KeeperResponseParser._preprocess_response(response, log_output) @@ -486,27 +493,41 @@ def _parse_this_device_command(response: str) -> Dict[str, Any]: @staticmethod def _parse_mkdir_command(response: str) -> Dict[str, Any]: - """Parse 'mkdir' command output to extract folder UID.""" + """Parse 'mkdir' command output to extract folder UID, path, and name.""" response_str = response.strip() - - # Success case - try to extract UID + lines = [ln.strip() for ln in response_str.split('\n') if ln.strip()] + result = { "status": "success", "command": "mkdir", "data": None } - - if re.match(r'^[a-zA-Z0-9_-]+$', response_str): + + for line in lines: + try: + data = json.loads(line) + if isinstance(data, dict) and 'folder_uid' in data: + result["data"] = { + "folder_uid": data["folder_uid"], + "path": data.get("path"), + "name": data.get("name") + } + return result + except (json.JSONDecodeError, TypeError): + pass + + last_line = lines[-1] if lines else response_str + if re.match(r'^[a-zA-Z0-9_-]+$', last_line): result["data"] = { - "folder_uid": response_str + "folder_uid": last_line } else: - uid_match = re.search(r'folder_uid=([a-zA-Z0-9_-]+)', response_str) + uid_match = re.search(r'folder_uid=([a-zA-Z0-9_-]+)', last_line) if uid_match: result["data"] = { "folder_uid": uid_match.group(1) } - + return result @staticmethod diff --git a/requirements.txt b/requirements.txt index 5897db9ff..156d2f51f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,7 +12,7 @@ requests>=2.31.0 cryptography>=39.0.1 protobuf>=4.23.0 keeper-secrets-manager-core>=16.6.0 -keeper_pam_webrtc_rs>=2.0.1; python_version>='3.8' +keeper_pam_webrtc_rs>=2.1.6; python_version>='3.8' pydantic>=2.6.4; python_version>='3.8' flask; python_version>='3.8' pyngrok>=7.5.0 diff --git a/setup.cfg b/setup.cfg index b05719c26..732d48d9b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -43,7 +43,7 @@ install_requires = requests>=2.31.0 tabulate websockets - keeper_pam_webrtc_rs>=2.0.1; python_version>='3.8' + keeper_pam_webrtc_rs>=2.1.6; python_version>='3.8' pydantic>=2.6.4; python_version>='3.8' fpdf2>=2.8.3 cbor2; sys_platform == "darwin" and python_version>='3.10' diff --git a/tests/test_credential_provision.py b/tests/test_credential_provision.py new file mode 100644 index 000000000..280efa564 --- /dev/null +++ b/tests/test_credential_provision.py @@ -0,0 +1,602 @@ +""" +KC-1035: Atlassian Onboarding Project — Test Suite + +Test levels: + Level 1 - Unit tests (no external deps, pure logic) + Level 2 - Mocked integration tests (mock Gateway + vault APIs) + Level 3 - E2E tests (real Gateway + AD + vault, requires config) + +Run specific levels: + pytest tests/test_credential_provision_kc1035.py -m unit + pytest tests/test_credential_provision_kc1035.py -m integration + pytest tests/test_credential_provision_kc1035.py -m e2e +""" + +import json +import os +import pytest +from unittest import TestCase, mock +from unittest.mock import MagicMock, patch, PropertyMock + +from keepercommander.commands.credential_provision import ( + CredentialProvisionCommand, + ProvisioningState, + resolve_username_template, +) +from keepercommander.commands.pam.pam_dto import ( + GatewayAction, + GatewayActionRmCreateUser, + GatewayActionRmCreateUserInputs, + GatewayActionRmAddUserToGroup, + GatewayActionRmAddUserToGroupInputs, + GatewayActionRmDeleteUser, + GatewayActionRmDeleteUserInputs, +) + + +# ============================================================================= +# Level 1: Unit Tests — Pure logic, no external dependencies +# ============================================================================= + +@pytest.mark.unit +class TestUsernameTemplate(TestCase): + """Test the username template engine.""" + + def test_basic_template(self): + user = {'first_name': 'Felipe', 'last_name': 'Dias', 'personal_email': 'fdias@atlassian.com'} + result = resolve_username_template('{first_initial}{last_name}.adm', user) + self.assertEqual(result, 'fdias.adm') + + def test_first_name(self): + user = {'first_name': 'Felipe', 'last_name': 'Dias', 'personal_email': 'fdias@atlassian.com'} + self.assertEqual(resolve_username_template('{first_name}', user), 'felipe') + + def test_last_name(self): + user = {'first_name': 'Felipe', 'last_name': 'Dias', 'personal_email': 'fdias@atlassian.com'} + self.assertEqual(resolve_username_template('{last_name}', user), 'dias') + + def test_initials(self): + user = {'first_name': 'Felipe', 'last_name': 'Dias', 'personal_email': 'fdias@atlassian.com'} + self.assertEqual(resolve_username_template('{first_initial}', user), 'f') + self.assertEqual(resolve_username_template('{last_initial}', user), 'd') + + def test_email_prefix(self): + user = {'first_name': 'Felipe', 'last_name': 'Dias', 'personal_email': 'fdias@atlassian.com'} + self.assertEqual(resolve_username_template('{email_prefix}', user), 'fdias') + + def test_output_is_lowercase(self): + user = {'first_name': 'FELIPE', 'last_name': 'DIAS', 'personal_email': 'FDIAS@ATLASSIAN.COM'} + self.assertEqual(resolve_username_template('{first_initial}{last_name}.adm', user), 'fdias.adm') + + def test_dn_template(self): + user = {'first_name': 'Felipe', 'last_name': 'Dias', 'personal_email': 'fdias@atlassian.com'} + result = resolve_username_template( + 'CN={first_initial}{last_name}.adm,OU=DomainAdmins,DC=atlassian,DC=com', user + ) + self.assertEqual(result, 'cn=fdias.adm,ou=domainadmins,dc=atlassian,dc=com') + + def test_hyphenated_name(self): + user = {'first_name': 'Mary-Jane', 'last_name': "O'Brien", 'personal_email': 'mj@test.com'} + result = resolve_username_template('{first_initial}{last_name}.adm', user) + self.assertEqual(result, "mo'brien.adm") + + def test_empty_first_name(self): + user = {'first_name': '', 'last_name': 'Dias', 'personal_email': 'fdias@atlassian.com'} + result = resolve_username_template('{first_initial}{last_name}.adm', user) + self.assertEqual(result, 'dias.adm') + + def test_missing_fields(self): + user = {} + result = resolve_username_template('{first_initial}{last_name}', user) + self.assertEqual(result, '') + + def test_no_template_variables(self): + user = {'first_name': 'Felipe', 'last_name': 'Dias', 'personal_email': 'fdias@atlassian.com'} + result = resolve_username_template('static-username', user) + self.assertEqual(result, 'static-username') + + def test_email_without_at(self): + user = {'first_name': 'Felipe', 'last_name': 'Dias', 'personal_email': 'fdias'} + result = resolve_username_template('{email_prefix}', user) + self.assertEqual(result, 'fdias') + + +@pytest.mark.unit +class TestGatewayDTOs(TestCase): + """Test Gateway action DTO serialization.""" + + def test_rm_create_user_json(self): + inputs = GatewayActionRmCreateUserInputs( + configuration_uid='config-123', + user='CN=fdias.adm,OU=DomainAdmins,DC=test,DC=com', + password='SecureP@ss1', + ) + action = GatewayActionRmCreateUser(inputs=inputs, gateway_destination='gw-789') + data = json.loads(action.toJSON()) + + self.assertEqual(data['action'], 'rm-create-user') + self.assertFalse(data['is_scheduled']) + self.assertEqual(data['inputs']['configurationUid'], 'config-123') + self.assertEqual(data['inputs']['user'], 'CN=fdias.adm,OU=DomainAdmins,DC=test,DC=com') + self.assertEqual(data['inputs']['password'], 'SecureP@ss1') + + def test_rm_create_user_optional_fields(self): + """Optional fields should not be present when not provided.""" + inputs = GatewayActionRmCreateUserInputs( + configuration_uid='config-123', + user='fdias.adm', + ) + data = json.loads(GatewayActionRmCreateUser(inputs=inputs).toJSON()) + self.assertNotIn('password', data['inputs']) + self.assertNotIn('resourceUid', data['inputs']) + self.assertNotIn('meta', data['inputs']) + + def test_rm_add_user_to_group_json(self): + inputs = GatewayActionRmAddUserToGroupInputs( + configuration_uid='config-123', + user='fdias.adm', + group_id='Domain Admins', + ) + action = GatewayActionRmAddUserToGroup(inputs=inputs, gateway_destination='gw-789') + data = json.loads(action.toJSON()) + + self.assertEqual(data['action'], 'rm-add-user-to-group') + self.assertEqual(data['inputs']['groupId'], 'Domain Admins') + + def test_rm_delete_user_json(self): + inputs = GatewayActionRmDeleteUserInputs( + configuration_uid='config-123', + user='fdias.adm', + ) + action = GatewayActionRmDeleteUser(inputs=inputs) + data = json.loads(action.toJSON()) + + self.assertEqual(data['action'], 'rm-delete-user') + self.assertEqual(data['inputs']['user'], 'fdias.adm') + + def test_conversation_id_generation(self): + cid = GatewayAction.generate_conversation_id() + self.assertIsInstance(cid, str) + self.assertGreater(len(cid), 10) + + +@pytest.mark.unit +class TestProvisioningState(TestCase): + """Test ProvisioningState tracks AD creation for rollback.""" + + def test_initial_state(self): + state = ProvisioningState() + self.assertIsNone(state.pam_user_uid) + self.assertFalse(state.ad_user_created) + self.assertIsNone(state.ad_username) + self.assertIsNone(state.ad_config_uid) + self.assertIsNone(state.ad_gateway_uid) + + def test_ad_state_tracking(self): + state = ProvisioningState() + state.ad_user_created = True + state.ad_username = 'fdias.adm' + state.ad_config_uid = 'config-123' + state.ad_gateway_uid = 'gw-456' + self.assertTrue(state.ad_user_created) + self.assertEqual(state.ad_username, 'fdias.adm') + + +@pytest.mark.unit +class TestValidation(TestCase): + """Test YAML config validation changes.""" + + def setUp(self): + self.cmd = CredentialProvisionCommand() + self.params = MagicMock() + + def test_delivery_valid(self): + delivery = {'share_to': 'fdias@atlassian.com'} + errors = self.cmd._validate_delivery_section(delivery) + self.assertEqual(errors, []) + + def test_delivery_missing_share_to(self): + delivery = {} + errors = self.cmd._validate_delivery_section(delivery) + self.assertTrue(any('share_to' in e for e in errors)) + + def test_delivery_invalid_email(self): + delivery = {'share_to': 'not-an-email'} + errors = self.cmd._validate_delivery_section(delivery) + self.assertTrue(any('valid email' in e for e in errors)) + + def test_account_username_template_accepted(self): + account = {'username_template': '{first_initial}{last_name}.adm', 'pam_config_uid': 'xxx'} + errors = self.cmd._validate_account_section(account) + username_errors = [e for e in errors if 'username' in e.lower()] + self.assertEqual(username_errors, []) + + def test_account_neither_username_nor_template(self): + account = {'pam_config_uid': 'xxx'} + errors = self.cmd._validate_account_section(account) + self.assertTrue(any('username' in e for e in errors)) + + def test_email_section_optional_with_delivery(self): + """Email section should not be required when delivery section is present.""" + config = { + 'user': {'first_name': 'Felipe', 'last_name': 'Dias', 'personal_email': 'fdias@test.com'}, + 'account': {'username': 'fdias.adm', 'pam_config_uid': 'xxx'}, + 'rotation': {'schedule': '0 0 3 * * ?', 'password_complexity': '24,4,4,4,4'}, + 'delivery': {'share_to': 'fdias@test.com'}, + } + errors = self.cmd._validate_config(self.params, config) + email_errors = [e for e in errors if e.startswith('email.')] + self.assertEqual(email_errors, []) + + def test_no_delivery_no_email_valid(self): + """No delivery and no email should be valid — record created but not shared.""" + config = { + 'user': {'first_name': 'Felipe', 'last_name': 'Dias', 'personal_email': 'fdias@test.com'}, + 'account': {'username': 'fdias.adm', 'pam_config_uid': 'xxx'}, + 'rotation': {'schedule': '0 0 3 * * ?', 'password_complexity': '24,4,4,4,4'}, + } + errors = self.cmd._validate_config(self.params, config) + delivery_errors = [e for e in errors if 'delivery' in e.lower() or 'email' in e.lower()] + self.assertEqual(delivery_errors, []) + + +# ============================================================================= +# Level 2: Mocked Integration Tests — Mock Gateway + vault APIs +# ============================================================================= + +@pytest.mark.integration +class TestDirectShare(TestCase): + """Test direct share delivery with mocked ShareRecordCommand.""" + + def setUp(self): + self.cmd = CredentialProvisionCommand() + + @patch('keepercommander.commands.credential_provision.ShareRecordCommand') + @patch('keepercommander.commands.credential_provision.api') + def test_share_directly_success(self, mock_api, mock_share_cmd): + mock_rq = MagicMock() + mock_share_cmd.prep_request.return_value = mock_rq + + params = MagicMock() + config = { + 'delivery': { + 'method': 'direct_share', + 'share_to': 'fdias@atlassian.com', + 'permissions': {'can_edit': False, 'can_share': False}, + } + } + + result = self.cmd._share_directly('pam-uid-123', config, params) + + self.assertTrue(result) + mock_api.sync_down.assert_called_once_with(params) + mock_share_cmd.prep_request.assert_called_once() + mock_share_cmd.send_requests.assert_called_once() + + # Verify the kwargs passed to prep_request + call_kwargs = mock_share_cmd.prep_request.call_args[0][1] + self.assertEqual(call_kwargs['record'], 'pam-uid-123') + self.assertEqual(call_kwargs['email'], ['fdias@atlassian.com']) + self.assertEqual(call_kwargs['action'], 'grant') + self.assertFalse(call_kwargs['can_edit']) + self.assertFalse(call_kwargs['can_share']) + + @patch('keepercommander.commands.credential_provision.ShareRecordCommand') + @patch('keepercommander.commands.credential_provision.api') + def test_share_directly_invitation_sent(self, mock_api, mock_share_cmd): + """When prep_request returns None, invitation was sent (vault not yet accepted).""" + mock_share_cmd.prep_request.return_value = None + + params = MagicMock() + config = { + 'delivery': { + 'method': 'direct_share', + 'share_to': 'newuser@atlassian.com', + } + } + + result = self.cmd._share_directly('pam-uid-123', config, params) + self.assertTrue(result) + mock_share_cmd.send_requests.assert_not_called() + + @patch('keepercommander.commands.credential_provision.ShareRecordCommand') + @patch('keepercommander.commands.credential_provision.api') + def test_share_directly_failure_non_fatal(self, mock_api, mock_share_cmd): + """Share failure should return False, not raise.""" + mock_share_cmd.prep_request.side_effect = Exception('Public key not found') + + params = MagicMock() + config = {'delivery': {'method': 'direct_share', 'share_to': 'bad@test.com'}} + + result = self.cmd._share_directly('pam-uid-123', config, params) + self.assertFalse(result) + + +@pytest.mark.integration +class TestADCreationViaGateway(TestCase): + """Test AD user creation with mocked Gateway communication.""" + + def setUp(self): + self.cmd = CredentialProvisionCommand() + # Create a fake 32-byte AES key for encryption tests + self.fake_record_key = os.urandom(32) + + def _mock_params_with_record_cache(self, config_uid='config-123'): + params = MagicMock() + params.record_cache = {config_uid: {'record_key_unencrypted': self.fake_record_key}} + return params + + @patch('keepercommander.commands.credential_provision.get_response_payload') + @patch('keepercommander.commands.credential_provision.router_send_action_to_gateway') + def test_create_ad_user_success(self, mock_router_send, mock_get_payload): + mock_router_send.return_value = {'response': 'ok'} + mock_get_payload.return_value = {'data': {'success': True, 'configurationUid': 'config-123'}} + + params = self._mock_params_with_record_cache() + state = ProvisioningState() + config = { + 'account': { + 'username': 'fdias.adm', + 'pam_config_uid': 'config-123', + 'distinguished_name': 'CN=fdias.adm,OU=DomainAdmins,DC=test,DC=com', + } + } + + with patch.object(self.cmd, '_get_gateway_uid_for_config', return_value='gw-456'): + result = self.cmd._create_ad_user_via_gateway(config, 'P@ssw0rd', params, state) + + self.assertTrue(result) + self.assertTrue(state.ad_user_created) + self.assertEqual(state.ad_username, 'fdias.adm') + self.assertEqual(state.ad_config_uid, 'config-123') + self.assertEqual(state.ad_gateway_uid, 'gw-456') + + @patch('keepercommander.commands.credential_provision.get_response_payload') + @patch('keepercommander.commands.credential_provision.router_send_action_to_gateway') + def test_create_ad_user_already_exists(self, mock_router_send, mock_get_payload): + mock_router_send.return_value = {'response': 'ok'} + mock_get_payload.return_value = {'data': {'success': False, 'error': 'User already exists'}} + + params = self._mock_params_with_record_cache() + state = ProvisioningState() + config = { + 'account': { + 'username': 'fdias.adm', + 'pam_config_uid': 'config-123', + 'distinguished_name': 'CN=fdias.adm,OU=DomainAdmins,DC=test,DC=com', + } + } + + with patch.object(self.cmd, '_get_gateway_uid_for_config', return_value='gw-456'): + from keepercommander.error import CommandError + with self.assertRaises(CommandError) as ctx: + self.cmd._create_ad_user_via_gateway(config, 'P@ssw0rd', params, state) + + self.assertIn('User already exists', str(ctx.exception)) + self.assertFalse(state.ad_user_created) + + @patch('keepercommander.commands.credential_provision.router_send_action_to_gateway') + def test_create_ad_user_gateway_offline(self, mock_router_send): + params = MagicMock() + state = ProvisioningState() + config = { + 'account': { + 'username': 'fdias.adm', + 'pam_config_uid': 'config-123', + } + } + + with patch.object(self.cmd, '_get_gateway_uid_for_config', return_value=None): + from keepercommander.error import CommandError + with self.assertRaises(CommandError) as ctx: + self.cmd._create_ad_user_via_gateway(config, 'P@ssw0rd', params, state) + + self.assertIn('No connected Gateway', str(ctx.exception)) + + +@pytest.mark.integration +class TestRollback(TestCase): + """Test rollback handles AD + PAM User cleanup in LIFO order.""" + + def setUp(self): + self.cmd = CredentialProvisionCommand() + + @patch('keepercommander.commands.credential_provision.api') + def test_rollback_pam_user_only(self, mock_api): + """Rollback with only PAM User created (no AD user).""" + state = ProvisioningState() + state.pam_user_uid = 'pam-123' + + params = MagicMock() + self.cmd._rollback(state, params) + + mock_api.delete_record.assert_called_once_with(params, 'pam-123') + + @patch('keepercommander.commands.credential_provision.api') + def test_rollback_ad_and_pam_user(self, mock_api): + """Rollback with both AD user and PAM User created — LIFO order.""" + state = ProvisioningState() + state.pam_user_uid = 'pam-123' + state.ad_user_created = True + state.ad_username = 'fdias.adm' + state.ad_config_uid = 'config-123' + state.ad_gateway_uid = 'gw-456' + + params = MagicMock() + + with patch.object(self.cmd, '_delete_ad_user_via_gateway') as mock_ad_delete: + self.cmd._rollback(state, params) + + # PAM User deleted first (LIFO) + mock_api.delete_record.assert_called_once_with(params, 'pam-123') + # AD user deleted second + mock_ad_delete.assert_called_once_with(state, params) + + @patch('keepercommander.commands.credential_provision.api') + def test_rollback_ad_only(self, mock_api): + """Rollback with AD user created but PAM User creation failed.""" + state = ProvisioningState() + state.ad_user_created = True + state.ad_username = 'fdias.adm' + state.ad_config_uid = 'config-123' + state.ad_gateway_uid = 'gw-456' + + params = MagicMock() + + with patch.object(self.cmd, '_delete_ad_user_via_gateway') as mock_ad_delete: + self.cmd._rollback(state, params) + + mock_api.delete_record.assert_not_called() + mock_ad_delete.assert_called_once() + + @patch('keepercommander.commands.credential_provision.api') + def test_rollback_nothing_created(self, mock_api): + """Rollback with nothing created — should not fail.""" + state = ProvisioningState() + params = MagicMock() + + with patch.object(self.cmd, '_delete_ad_user_via_gateway') as mock_ad_delete: + self.cmd._rollback(state, params) + + mock_api.delete_record.assert_not_called() + mock_ad_delete.assert_not_called() + + +# ============================================================================= +# Level 3: E2E Tests — Real Gateway + AD + Vault +# Requires: vault.json config, running Gateway, AD access, Okta/SCIM +# ============================================================================= + +@pytest.mark.e2e +@pytest.mark.skip(reason="Requires real environment — run manually with: pytest -m e2e --no-header -v") +class TestE2EProvisioningFlow(TestCase): + """ + End-to-end test against real infrastructure. + + Prerequisites: + - Commander config at tests/vault.json (service vault credentials) + - PAM Gateway running and connected + - AD accessible from Gateway + - Target user vault exists (SCIM provisioned) + + Setup: + 1. Create tests/e2e_config.json with: + { + "pam_config_uid": "", + "target_user_email": "", + "ad_base_dn": "OU=TestUsers,DC=yourdomain,DC=com", + "ad_groups": ["TestGroup1"] + } + 2. Run: pytest tests/test_credential_provision_kc1035.py -m e2e -v + """ + + params = None + e2e_config = None + created_resources = [] + + @classmethod + def setUpClass(cls): + import os + from data_config import read_config_file + from keepercommander.params import KeeperParams + from keepercommander import api + + cls.params = KeeperParams() + read_config_file(cls.params, 'vault.json') + api.login(cls.params) + + config_path = os.path.join(os.path.dirname(__file__), 'e2e_config.json') + with open(config_path, 'r') as f: + cls.e2e_config = json.load(f) + + @classmethod + def tearDownClass(cls): + from keepercommander import cli + # Cleanup created resources + for resource in cls.created_resources: + try: + if resource['type'] == 'record': + from keepercommander import api + api.delete_record(cls.params, resource['uid']) + except Exception as e: + print(f"Cleanup failed for {resource}: {e}") + cli.do_command(cls.params, 'logout') + + def test_01_direct_share_with_ad_creation(self): + """Full flow: create AD user → PAM User → rotation → direct share.""" + import base64 + import yaml + + config_yaml = { + 'user': { + 'first_name': 'Test', + 'last_name': 'User', + 'personal_email': self.e2e_config['target_user_email'], + }, + 'account': { + 'username_template': '{first_initial}{last_name}.adm.test', + 'pam_config_uid': self.e2e_config['pam_config_uid'], + 'distinguished_name': f'CN={{username}},{self.e2e_config["ad_base_dn"]}', + 'ad_groups': self.e2e_config.get('ad_groups', []), + }, + 'vault': { + 'folder': 'KC-1035-E2E-Test', + }, + 'pam': { + 'rotation': { + 'schedule': '0 0 3 * * ?', + 'password_complexity': '24,4,4,4,4', + }, + }, + 'delivery': { + 'method': 'direct_share', + 'share_to': self.e2e_config['target_user_email'], + 'permissions': {'can_edit': False, 'can_share': False}, + }, + } + + yaml_str = yaml.dump(config_yaml) + b64_config = base64.b64encode(yaml_str.encode()).decode() + + cmd = CredentialProvisionCommand() + cmd.execute(self.params, config_base64=b64_config, output='json') + + def test_02_dry_run(self): + """Dry run should validate without creating anything.""" + import base64 + import yaml + + config_yaml = { + 'user': { + 'first_name': 'DryRun', + 'last_name': 'Test', + 'personal_email': self.e2e_config['target_user_email'], + }, + 'account': { + 'username_template': '{first_initial}{last_name}.adm.dryrun', + 'pam_config_uid': self.e2e_config['pam_config_uid'], + 'distinguished_name': f'CN={{username}},{self.e2e_config["ad_base_dn"]}', + }, + 'pam': { + 'rotation': { + 'schedule': '0 0 3 * * ?', + 'password_complexity': '24,4,4,4,4', + }, + }, + 'delivery': { + 'method': 'direct_share', + 'share_to': self.e2e_config['target_user_email'], + }, + } + + yaml_str = yaml.dump(config_yaml) + b64_config = base64.b64encode(yaml_str.encode()).decode() + + cmd = CredentialProvisionCommand() + # Should not raise + cmd.execute(self.params, config_base64=b64_config, dry_run=True, output='json') + + def test_03_duplicate_detection(self): + """Running the same config twice should fail on duplicate check.""" + # This test depends on test_01 having run first + pass # Implement after confirming test_01 works diff --git a/tests/test_security_audit_refresh.py b/tests/test_security_audit_refresh.py new file mode 100644 index 000000000..40e9a7418 --- /dev/null +++ b/tests/test_security_audit_refresh.py @@ -0,0 +1,237 @@ +import os +import json +from collections import Counter +from unittest import TestCase + +import pytest + +from data_config import read_config_file +from keepercommander import api, cli, security_audit +from keepercommander.commands.security_audit import SecurityAuditReportCommand, SecurityAuditSyncCommand +from keepercommander.error import CommandError +from keepercommander.params import KeeperParams +from keepercommander.utils import is_pw_fair, is_pw_strong, is_pw_weak +from keepercommander import vault + + +@pytest.mark.integration +class TestSecurityAuditRefresh(TestCase): + params = None # type: KeeperParams + + @classmethod + def setUpClass(cls): + cls.params = KeeperParams() + read_config_file(cls.params, os.environ.get('KEEPER_CONFIG', '../config.json')) + api.login(cls.params) + api.query_enterprise(cls.params) + api.sync_down(cls.params, record_types=True) + + @classmethod + def tearDownClass(cls): + try: + if cls.params: + cli.do_command(cls.params, 'delete-all --force') + api.sync_down(cls.params, record_types=True) + except Exception: + pass + + def setUp(self): + api.sync_down(self.params, record_types=True) + cli.do_command(self.params, 'delete-all --force') + api.sync_down(self.params, record_types=True) + + def add_legacy_record(self, title, password, extra_fields=''): + command = ( + f'record-add --title="{title}" --record-type=legacy ' + f'login=security.audit@example.com password={password} url=https://example.com' + ) + if extra_fields: + command = f'{command} {extra_fields}' + record_uid = cli.do_command(self.params, command) + api.sync_down(self.params, record_types=True) + return record_uid + + def add_typed_login_record(self, title, password): + command = ( + f'record-add --title="{title}" --record-type=login ' + f'login=security.audit@example.com password={password} url=https://example.com' + ) + try: + record_uid = cli.do_command(self.params, command) + except CommandError as err: + if 'Record type "login" cannot be found.' in str(err): + self.skipTest('Typed login record type is not available in this integration environment') + raise + api.sync_down(self.params, record_types=True) + return record_uid + + def update_password(self, record_uid, password): + cli.do_command(self.params, f'record-update --record={record_uid} password={password}') + api.sync_down(self.params, record_types=True) + + def rotate_password(self, record_uid): + cli.do_command(self.params, f'rotate -- {record_uid}') + api.sync_down(self.params, record_types=True) + + def hard_clear_current_user_security_data(self): + SecurityAuditSyncCommand().execute( + self.params, + email=[self.params.user], + hard=True, + force=True, + ) + api.sync_down(self.params, record_types=True) + + def current_user_report_row(self): + report = json.loads(SecurityAuditReportCommand().execute(self.params, save=True, format='json')) + return next((x for x in report if x.get('email') == self.params.user), None) + + def current_user_debug_row(self): + report = json.loads(SecurityAuditReportCommand().execute(self.params, debug=True, format='json')) + return next((x for x in report if x.get('vault_owner') == self.params.user), None) + + def get_score_payload(self, record_uid): + return (self.params.security_score_data.get(record_uid) or {}).get('data', {}) + + def assert_record_security_state(self, record_uid, password, score, has_security_data): + score_data = self.get_score_payload(record_uid) + self.assertEqual(score_data.get('password'), password) + self.assertEqual(score_data.get('score'), score) + + security_data = self.params.breach_watch_security_data.get(record_uid) + if has_security_data: + self.assertIsNotNone(security_data) + else: + self.assertIsNone(security_data) + + def assert_record_revisions_aligned(self, record_uid): + score_revision = (self.params.security_score_data.get(record_uid) or {}).get('revision') + security_revision = (self.params.breach_watch_security_data.get(record_uid) or {}).get('revision') + self.assertEqual(score_revision, security_revision) + + def assert_record_has_no_password_score_data(self, record_uid): + self.assertEqual(self.get_score_payload(record_uid), {}) + + def expected_summary(self, record_uids): + summary = { + 'weak': 0, + 'fair': 0, + 'medium': 0, + 'strong': 0, + 'reused': 0, + 'unique': 0, + 'securityScore': 25, + } + password_counts = Counter() + total = 0 + for record_uid in record_uids: + score_data = self.get_score_payload(record_uid) + password = score_data.get('password') + score = score_data.get('score') + if password is None or score is None: + continue + total += 1 + password_counts[password] += 1 + if is_pw_strong(score): + summary['strong'] += 1 + elif is_pw_fair(score): + summary['fair'] += 1 + elif is_pw_weak(score): + summary['weak'] += 1 + else: + summary['medium'] += 1 + + summary['reused'] = sum(count for count in password_counts.values() if count > 1) + summary['unique'] = total - summary['reused'] + if total > 0: + strong_ratio = summary['strong'] / total + unique_ratio = summary['unique'] / total + summary['securityScore'] = int(100 * round((strong_ratio + unique_ratio + 1) / 4, 2)) + return summary + + def assert_debug_pending(self): + debug_row = self.current_user_debug_row() + self.assertIsNotNone(debug_row) + raw_old = debug_row.get('old_incremental_data') or [] + raw_curr = debug_row.get('current_incremental_data') or [] + self.assertTrue(any(item is not None for item in raw_old + raw_curr)) + + def assert_admin_summary_matches_records(self, record_uids, expect_debug_pending=True): + if expect_debug_pending: + self.assert_debug_pending() + + row = self.current_user_report_row() + self.assertIsNotNone(row) + expected = self.expected_summary(record_uids) + for key, value in expected.items(): + self.assertEqual(row.get(key), value, msg=f'{key} mismatch: {row}') + self.assertIsNone(self.current_user_debug_row()) + + def test_summary_alignment_for_add_update_reuse_and_password_removal(self): + record_uid_1 = self.add_legacy_record('Security audit lifecycle-1', 'aa') + self.assert_record_security_state(record_uid_1, 'aa', 0, True) + self.assert_record_revisions_aligned(record_uid_1) + self.assert_admin_summary_matches_records([record_uid_1]) + + self.update_password(record_uid_1, 'weak-password') + self.assert_record_security_state(record_uid_1, 'weak-password', 41, True) + self.assert_record_revisions_aligned(record_uid_1) + self.assert_admin_summary_matches_records([record_uid_1]) + + self.update_password(record_uid_1, 'A1!bcdefgh') + self.assert_record_security_state(record_uid_1, 'A1!bcdefgh', 61, True) + self.assert_record_revisions_aligned(record_uid_1) + self.assert_admin_summary_matches_records([record_uid_1]) + + self.update_password(record_uid_1, 'StrongPass123!') + self.assert_record_security_state(record_uid_1, 'StrongPass123!', 100, True) + self.assert_record_revisions_aligned(record_uid_1) + self.assert_admin_summary_matches_records([record_uid_1]) + + record_uid_2 = self.add_legacy_record('Security audit lifecycle-2', 'StrongPass123!') + self.assert_record_security_state(record_uid_2, 'StrongPass123!', 100, True) + self.assert_record_revisions_aligned(record_uid_2) + self.assert_admin_summary_matches_records([record_uid_1, record_uid_2]) + + self.update_password(record_uid_1, '') + self.assert_record_has_no_password_score_data(record_uid_1) + self.assert_admin_summary_matches_records([record_uid_1, record_uid_2]) + + def test_rotation_and_hard_clear_repair_align_admin_summary(self): + record_uid = self.add_legacy_record('Security audit rotate/repair', 'aa', extra_fields='cmdr:plugin=noop') + self.assert_record_security_state(record_uid, 'aa', 0, True) + self.assert_record_revisions_aligned(record_uid) + self.assert_admin_summary_matches_records([record_uid]) + + self.rotate_password(record_uid) + rotated_score_data = self.get_score_payload(record_uid) + self.assertIsInstance(rotated_score_data.get('password'), str) + self.assertTrue(rotated_score_data.get('password')) + self.assertIn('score', rotated_score_data) + self.assertIsNotNone(self.params.breach_watch_security_data.get(record_uid)) + self.assert_record_revisions_aligned(record_uid) + self.assert_admin_summary_matches_records([record_uid]) + + self.hard_clear_current_user_security_data() + self.assertIsNotNone(self.get_score_payload(record_uid)) + self.assertIsNone(self.params.breach_watch_security_data.get(record_uid)) + + record = vault.KeeperRecord.load(self.params, record_uid) + self.assertTrue(security_audit.needs_security_audit(self.params, record)) + + cli.do_command(self.params, f'sync-security-data {record_uid} --quiet') + api.sync_down(self.params, record_types=True) + self.assertIsNotNone(self.params.breach_watch_security_data.get(record_uid)) + self.assert_record_revisions_aligned(record_uid) + self.assert_admin_summary_matches_records([record_uid]) + + def test_typed_login_add_and_update_align_admin_summary(self): + record_uid = self.add_typed_login_record('Security audit typed login', 'aa') + self.assert_record_security_state(record_uid, 'aa', 0, True) + self.assert_record_revisions_aligned(record_uid) + self.assert_admin_summary_matches_records([record_uid]) + + self.update_password(record_uid, 'StrongPass123!') + self.assert_record_security_state(record_uid, 'StrongPass123!', 100, True) + self.assert_record_revisions_aligned(record_uid) + self.assert_admin_summary_matches_records([record_uid]) diff --git a/unit-tests/test_command_record.py b/unit-tests/test_command_record.py index 54283501c..18d213b8a 100644 --- a/unit-tests/test_command_record.py +++ b/unit-tests/test_command_record.py @@ -237,7 +237,15 @@ def test_append_notes_command(self): params = get_synced_params() cmd = record_edit.RecordAppendNotesCommand() - record_uid = next(iter(params.subfolder_record_cache[''])) + # Fixture mixes legacy (Record 1) and typed login (Record 2) in root; append-notes + # uses RecordUpdateCommand, which rejects legacy PasswordRecord. + record_uid = None + for uid in params.subfolder_record_cache['']: + rec = vault.KeeperRecord.load(params, uid) + if isinstance(rec, vault.TypedRecord): + record_uid = uid + break + self.assertIsNotNone(record_uid) with mock.patch('keepercommander.record_management.update_record'): cmd.execute(params, notes='notes', record=record_uid) diff --git a/unit-tests/test_credential_provision.py b/unit-tests/test_credential_provision.py index a40e40c74..fe958e2e4 100644 --- a/unit-tests/test_credential_provision.py +++ b/unit-tests/test_credential_provision.py @@ -460,54 +460,44 @@ def test_reject_initial_password_field(self): class TestPAMSectionValidation(unittest.TestCase): - """Test PAM section validation.""" + """Test rotation section validation.""" def setUp(self): """Set up test fixtures.""" self.cmd = CredentialProvisionCommand() def test_valid_pam_section(self): - """Test that valid PAM section passes validation.""" - pam = { - 'rotation': { - 'rotate_immediately': True, - 'schedule': '0 0 3 * * ?', - 'password_complexity': '32,5,5,5,5', - }, - 'pam_user_title': 'PAM: John Doe', - 'login_record_title': 'John Doe Login', + """Test that valid rotation section passes validation.""" + rotation = { + 'schedule': '0 0 3 * * ?', + 'password_complexity': '32,5,5,5,5', } - errors = self.cmd._validate_pam_section(pam) - self.assertEqual(len(errors), 0, 'Valid PAM section should have no errors') + errors = self.cmd._validate_rotation_section(rotation) + self.assertEqual(len(errors), 0, 'Valid rotation section should have no errors') def test_missing_rotation_section(self): - """Test detection of missing rotation section.""" - pam = {} - errors = self.cmd._validate_pam_section(pam) - self.assertGreater(len(errors), 0, 'Should detect missing rotation') - self.assertIn('rotation', ' '.join(errors)) + """Test detection of missing rotation fields.""" + rotation = {} + errors = self.cmd._validate_rotation_section(rotation) + self.assertGreater(len(errors), 0, 'Should detect missing rotation fields') def test_invalid_cron_schedule(self): """Test detection of invalid CRON schedule.""" - pam = { - 'rotation': { - 'schedule': 'invalid cron', # Invalid - 'password_complexity': '32,5,5,5,5', - } + rotation = { + 'schedule': 'invalid cron', # Invalid + 'password_complexity': '32,5,5,5,5', } - errors = self.cmd._validate_pam_section(pam) + errors = self.cmd._validate_rotation_section(rotation) self.assertGreater(len(errors), 0, 'Should detect invalid CRON') self.assertIn('CRON', ' '.join(errors)) def test_invalid_complexity_format(self): """Test detection of invalid complexity format.""" - pam = { - 'rotation': { - 'schedule': '0 0 3 * * ?', - 'password_complexity': 'invalid', # Invalid - } + rotation = { + 'schedule': '0 0 3 * * ?', + 'password_complexity': 'invalid', # Invalid } - errors = self.cmd._validate_pam_section(pam) + errors = self.cmd._validate_rotation_section(rotation) self.assertGreater(len(errors), 0, 'Should detect invalid complexity') self.assertIn('complexity', ' '.join(errors).lower()) @@ -579,11 +569,9 @@ def test_valid_complete_config(self): 'username': 'john.doe', 'pam_config_uid': 'test-uid', }, - 'pam': { - 'rotation': { - 'schedule': '0 0 3 * * ?', - 'password_complexity': '32,5,5,5,5', - } + 'rotation': { + 'schedule': '0 0 3 * * ?', + 'password_complexity': '32,5,5,5,5', }, 'email': { 'config_name': 'Test Config', @@ -597,14 +585,13 @@ def test_missing_required_sections(self): """Test detection of missing required sections.""" config = { 'user': {'first_name': 'John'}, - # Missing account, pam, email sections + # Missing account, rotation sections } errors = self.cmd._validate_config(self.mock_params, config) self.assertGreater(len(errors), 0, 'Should detect missing sections') error_text = ' '.join(errors) self.assertIn('account', error_text) - self.assertIn('pam', error_text) - self.assertIn('email', error_text) + self.assertIn('rotation', error_text) def test_multiple_validation_errors(self): """Test that multiple errors are collected (not fail-fast).""" @@ -616,19 +603,14 @@ def test_multiple_validation_errors(self): 'account': { # Missing username and pam_config_uid }, - 'pam': { - 'rotation': { - 'schedule': 'invalid', # Invalid CRON - 'password_complexity': 'invalid', # Invalid format - } + 'rotation': { + 'schedule': 'invalid', # Invalid CRON + 'password_complexity': 'invalid', # Invalid format }, - 'email': { - # Missing config_name - } } errors = self.cmd._validate_config(self.mock_params, config) # Should have multiple errors from different sections - self.assertGreater(len(errors), 5, 'Should collect multiple errors') + self.assertGreater(len(errors), 3, 'Should collect multiple errors') class TestPasswordGeneration(unittest.TestCase): diff --git a/unit-tests/test_security_audit.py b/unit-tests/test_security_audit.py new file mode 100644 index 000000000..a6d1b6bd4 --- /dev/null +++ b/unit-tests/test_security_audit.py @@ -0,0 +1,85 @@ +from types import SimpleNamespace +from unittest import TestCase, mock + +from keepercommander import security_audit + + +class TestSecurityAudit(TestCase): + def setUp(self): + self.record = SimpleNamespace(record_uid='record_uid') + self.params = SimpleNamespace( + enterprise_ec_key=b'enterprise-key', + security_score_data={}, + breach_watch_security_data={}, + ) + + def test_needs_security_audit_updates_missing_security_data_for_weak_password(self): + self.params.security_score_data = { + self.record.record_uid: { + 'data': {'password': 'weak-password', 'score': 0}, + 'revision': 7, + } + } + + with mock.patch('keepercommander.security_audit._get_pass', return_value='weak-password'), \ + mock.patch('keepercommander.security_audit.get_security_score', return_value=0): + self.assertTrue(security_audit.needs_security_audit(self.params, self.record)) + + def test_needs_security_audit_updates_missing_security_data_for_nonzero_score(self): + self.params.security_score_data = { + self.record.record_uid: { + 'data': {'password': 'StrongPass123!', 'score': 100}, + 'revision': 9, + } + } + + with mock.patch('keepercommander.security_audit._get_pass', return_value='StrongPass123!'), \ + mock.patch('keepercommander.security_audit.get_security_score', return_value=100): + self.assertTrue(security_audit.needs_security_audit(self.params, self.record)) + + def test_needs_security_audit_skips_when_no_password_and_no_security_data(self): + with mock.patch('keepercommander.security_audit._get_pass', return_value=None), \ + mock.patch('keepercommander.security_audit.get_security_score', return_value=None): + self.assertFalse(security_audit.needs_security_audit(self.params, self.record)) + + def test_needs_security_audit_skips_when_security_data_already_exists_for_weak_password(self): + self.params.security_score_data = { + self.record.record_uid: { + 'data': {'password': 'weak-password', 'score': 0}, + 'revision': 7, + } + } + self.params.breach_watch_security_data = { + self.record.record_uid: {'revision': 7} + } + + with mock.patch('keepercommander.security_audit._get_pass', return_value='weak-password'), \ + mock.patch('keepercommander.security_audit.get_security_score', return_value=0): + self.assertFalse(security_audit.needs_security_audit(self.params, self.record)) + + def test_needs_security_audit_updates_when_security_data_revision_is_stale(self): + self.params.security_score_data = { + self.record.record_uid: { + 'data': {'password': 'StrongPass123!', 'score': 100}, + 'revision': 11, + } + } + self.params.breach_watch_security_data = { + self.record.record_uid: {'revision': 7} + } + + with mock.patch('keepercommander.security_audit._get_pass', return_value='StrongPass123!'), \ + mock.patch('keepercommander.security_audit.get_security_score', return_value=100): + self.assertTrue(security_audit.needs_security_audit(self.params, self.record)) + + def test_needs_security_audit_updates_when_password_is_removed(self): + self.params.security_score_data = { + self.record.record_uid: { + 'data': {'password': 'StrongPass123!', 'score': 100}, + 'revision': 11, + } + } + + with mock.patch('keepercommander.security_audit._get_pass', return_value=None), \ + mock.patch('keepercommander.security_audit.get_security_score', return_value=None): + self.assertTrue(security_audit.needs_security_audit(self.params, self.record)) diff --git a/unit-tests/test_throttle_retry.py b/unit-tests/test_throttle_retry.py new file mode 100644 index 000000000..b45091640 --- /dev/null +++ b/unit-tests/test_throttle_retry.py @@ -0,0 +1,194 @@ +"""Tests for execute_rest() throttle retry logic in rest_api.py. + +Verifies: +- Normal (non-throttled) requests are unaffected +- Throttled requests retry up to 3 times with exponential backoff +- KeeperApiError raised after max retries +- --fail-on-throttle skips retries entirely +- Server's "try again in X" message is parsed correctly (seconds + minutes) +- Server wait capped at 300s +- Backoff takes the larger of server hint vs exponential schedule +""" + +import json +import os +import sys +import unittest +from unittest.mock import patch, MagicMock + +# Add parent dir so imports work from unit-tests/ +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) + +from keepercommander.rest_api import execute_rest +from keepercommander.error import KeeperApiError +from keepercommander.params import RestApiContext +from keepercommander.proto import APIRequest_pb2 as proto + + +def make_context(fail_on_throttle=False): + """Create a real RestApiContext with QRC disabled to simplify mocking.""" + ctx = RestApiContext(server='https://keepersecurity.com', locale='en_US') + ctx.transmission_key = os.urandom(32) + ctx.server_key_id = 7 + ctx.fail_on_throttle = fail_on_throttle + ctx.disable_qrc() # Skip QRC key negotiation + return ctx + + +def make_throttle_response(message="Please try again in 1 minute"): + """Build a fake HTTP 403 throttle response.""" + body = {"error": "throttled", "message": message} + resp = MagicMock() + resp.status_code = 403 + resp.headers = {'Content-Type': 'application/json'} + resp.json.return_value = body + resp.content = json.dumps(body).encode() + return resp + + +def make_success_response(): + """Build a fake HTTP 200 response with empty body.""" + resp = MagicMock() + resp.status_code = 200 + resp.headers = {'Content-Type': 'application/octet-stream'} + resp.content = b'' + return resp + + +def make_payload(): + """Create a minimal ApiRequestPayload for execute_rest.""" + return proto.ApiRequestPayload() + + +class TestThrottleRetry(unittest.TestCase): + """Tests for the bounded retry with exponential backoff on 403 throttle.""" + + @patch('keepercommander.rest_api.time.sleep') + @patch('keepercommander.rest_api.requests.post') + def test_normal_request_unaffected(self, mock_post, mock_sleep): + """Non-throttled 200 response should pass through with no retries.""" + mock_post.return_value = make_success_response() + + execute_rest(make_context(), 'test/endpoint', make_payload()) + + self.assertEqual(mock_post.call_count, 1) + mock_sleep.assert_not_called() + + @patch('keepercommander.rest_api.time.sleep') + @patch('keepercommander.rest_api.requests.post') + def test_retries_then_succeeds(self, mock_post, mock_sleep): + """Throttle twice, succeed on 3rd attempt.""" + mock_post.side_effect = [ + make_throttle_response("try again in 30 seconds"), + make_throttle_response("try again in 30 seconds"), + make_success_response(), + ] + + execute_rest(make_context(), 'test/endpoint', make_payload()) + + self.assertEqual(mock_post.call_count, 3) + self.assertEqual(mock_sleep.call_count, 2) + # 1st: max(30, 30*2^0=30) = 30 + # 2nd: max(30, 30*2^1=60) = 60 + calls = [c[0][0] for c in mock_sleep.call_args_list] + self.assertEqual(calls, [30, 60]) + + @patch('keepercommander.rest_api.time.sleep') + @patch('keepercommander.rest_api.requests.post') + def test_raises_after_max_retries(self, mock_post, mock_sleep): + """Always throttled — should raise KeeperApiError after 3 retries.""" + mock_post.return_value = make_throttle_response("try again in 1 minute") + + with self.assertRaises(KeeperApiError): + execute_rest(make_context(), 'test/endpoint', make_payload()) + + # 1 initial + 3 retries = 4 posts, error raised when retry 4 > max 3 + self.assertEqual(mock_post.call_count, 4) + self.assertEqual(mock_sleep.call_count, 3) + + @patch('keepercommander.rest_api.time.sleep') + @patch('keepercommander.rest_api.requests.post') + def test_fail_on_throttle_skips_retry(self, mock_post, mock_sleep): + """--fail-on-throttle should return error immediately with no retries.""" + mock_post.return_value = make_throttle_response() + + result = execute_rest(make_context(fail_on_throttle=True), 'test/endpoint', make_payload()) + + # When fail_on_throttle=True, the throttle block is skipped and the + # failure dict is returned directly (no retry, no exception) + self.assertEqual(result.get('error'), 'throttled') + self.assertEqual(mock_post.call_count, 1) + mock_sleep.assert_not_called() + + @patch('keepercommander.rest_api.time.sleep') + @patch('keepercommander.rest_api.requests.post') + def test_parses_seconds_hint(self, mock_post, mock_sleep): + """Server says 'try again in 45 seconds' — wait should be 45s.""" + mock_post.side_effect = [ + make_throttle_response("try again in 45 seconds"), + make_success_response(), + ] + + execute_rest(make_context(), 'test/endpoint', make_payload()) + + # max(45, 30*2^0=30) = 45 + mock_sleep.assert_called_once_with(45) + + @patch('keepercommander.rest_api.time.sleep') + @patch('keepercommander.rest_api.requests.post') + def test_parses_minutes_hint(self, mock_post, mock_sleep): + """Server says 'try again in 2 minutes' — wait should be 120s.""" + mock_post.side_effect = [ + make_throttle_response("try again in 2 minutes"), + make_success_response(), + ] + + execute_rest(make_context(), 'test/endpoint', make_payload()) + + # max(120, 30*2^0=30) = 120 + mock_sleep.assert_called_once_with(120) + + @patch('keepercommander.rest_api.time.sleep') + @patch('keepercommander.rest_api.requests.post') + def test_caps_server_wait_at_300s(self, mock_post, mock_sleep): + """Server says 'try again in 49 minutes' — capped to 300s.""" + mock_post.side_effect = [ + make_throttle_response("try again in 49 minutes"), + make_success_response(), + ] + + execute_rest(make_context(), 'test/endpoint', make_payload()) + + # min(2940, 300)=300; max(300, 30*2^0=30) = 300 + mock_sleep.assert_called_once_with(300) + + @patch('keepercommander.rest_api.time.sleep') + @patch('keepercommander.rest_api.requests.post') + def test_exponential_backoff_progression(self, mock_post, mock_sleep): + """Verify backoff doubles: 30s, 60s, 120s when server hint is small.""" + mock_post.return_value = make_throttle_response("try again in 10 seconds") + + with self.assertRaises(KeeperApiError): + execute_rest(make_context(), 'test/endpoint', make_payload()) + + # Server says 10s, but backoff wins: max(10, 30*2^0)=30, max(10, 30*2^1)=60, max(10, 30*2^2)=120 + calls = [c[0][0] for c in mock_sleep.call_args_list] + self.assertEqual(calls, [30, 60, 120]) + + @patch('keepercommander.rest_api.time.sleep') + @patch('keepercommander.rest_api.requests.post') + def test_no_message_defaults_to_60s(self, mock_post, mock_sleep): + """Missing 'try again' text defaults to 60s server hint.""" + mock_post.side_effect = [ + make_throttle_response("Rate limit exceeded"), # no "try again in X" + make_success_response(), + ] + + execute_rest(make_context(), 'test/endpoint', make_payload()) + + # Default 60s; max(60, 30*2^0=30) = 60 + mock_sleep.assert_called_once_with(60) + + +if __name__ == '__main__': + unittest.main()