Skip to content

Commit

Permalink
Add alias feature. KC-257
Browse files Browse the repository at this point in the history
  • Loading branch information
sk-keeper committed Jan 19, 2023
1 parent 4cb8e8a commit bd5457e
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 39 deletions.
134 changes: 95 additions & 39 deletions keepercommander/commands/enterprise.py
Expand Up @@ -36,20 +36,14 @@
from .aram import ActionReportCommand
from .base import user_choice, suppress_exit, raise_parse_exception, dump_report_data, Command
from .enterprise_common import EnterpriseCommand
from .enterprise_push import EnterprisePushCommand, enterprise_push_parser
from .scim import ScimCommand
from .transfer_account import EnterpriseTransferUserCommand, transfer_user_parser
from .enterprise_push import EnterprisePushCommand, enterprise_push_parser
from .. import api, rest_api, crypto, utils, constants
from ..display import bcolors
from ..error import CommandError
from ..error import CommandError, KeeperApiError
from ..params import KeeperParams
from ..proto import record_pb2 as record_proto
from ..proto.APIRequest_pb2 import (UserDataKeyRequest, UserDataKeyResponse, SecurityReportRequest,
SecurityReportResponse, SecurityReportSaveRequest, SecurityReport,
SecurityReportIncrementalData)
from ..proto.enterprise_pb2 import (EnterpriseUserIds, ApproveUserDeviceRequest, ApproveUserDevicesRequest,
ApproveUserDevicesResponse, EnterpriseUserDataKeys, SetRestrictVisibilityRequest,
GetSharingAdminsRequest, GetSharingAdminsResponse)
from ..proto import record_pb2, APIRequest_pb2, enterprise_pb2


def register_commands(commands):
Expand Down Expand Up @@ -164,6 +158,8 @@ def register_command_info(aliases, command_info):
enterprise_user_parser.add_argument('-hsf', '--hide-shared-folders', dest='hide_shared_folders', action='store',
choices=['on', 'off'], help='User does not see shared folders. --add-team only')
enterprise_user_parser.add_argument('--remove-team', dest='remove_team', action='append', help='team name or team UID')
enterprise_user_parser.add_argument('--add-alias', dest='add_alias', action='store', metavar="EMAIL", help='new email alias for a user')
enterprise_user_parser.add_argument('--delete-alias', dest='delete_alias', action='store', metavar="EMAIL", help='delete email alias')
enterprise_user_parser.add_argument('email', type=str, nargs='+', help='User Email or ID. Can be repeated.')
enterprise_user_parser.error = raise_parse_exception
enterprise_user_parser.exit = suppress_exit
Expand Down Expand Up @@ -894,7 +890,7 @@ def execute(self, params, **kwargs):
node_id = mn['node_id']
data = mn['data']
displayname = data['displayname']
request = SetRestrictVisibilityRequest()
request = enterprise_pb2.SetRestrictVisibilityRequest()
request.nodeId = node_id
try:
api.communicate_rest(params, request, 'enterprise/set_restrict_visibility')
Expand Down Expand Up @@ -1200,14 +1196,19 @@ def execute(self, params, **kwargs):
user_lookup = {}
if 'users' in params.enterprise:
for u in params.enterprise['users']:

user_lookup[str(u['enterprise_user_id'])] = u

if 'username' in u:
user_lookup[u['username'].lower()] = u
else:
logging.debug('All users: %s', params.enterprise['users'])
logging.debug('WARNING: username is missing from the user id=%s, obj=%s', u['enterprise_user_id'], u)
if 'user_aliases' in params.enterprise:
for alias in params.enterprise['user_aliases']:
username = alias['username'].lower()
if username not in user_lookup:
user_id = str(alias['enterprise_user_id'])
if user_id in user_lookup:
user_lookup[username] = user_lookup[user_id]

emails = kwargs['email']
if emails:
Expand Down Expand Up @@ -1304,7 +1305,50 @@ def execute(self, params, **kwargs):
}
request_batch.append(rq)
else:
if kwargs.get('lock') or kwargs.get('unlock'):
if kwargs.get('add_alias'):
new_alias = kwargs['add_alias'].lower()
if len(matched_users) == 1:
user = matched_users[0]
enterprise_user_id = user['enterprise_user_id']
aliases = {x['username'].lower() for x in params.enterprise.get('user_aliases', []) if x['enterprise_user_id'] == enterprise_user_id}
existing_alias = new_alias in aliases
if existing_alias:
endpoint = 'enterprise/enterprise_user_set_primary_alias'
rq = APIRequest_pb2.EnterpriseUserAliasRequest()
rq.enterpriseUserId = enterprise_user_id
rq.alias = new_alias
else:
endpoint = 'enterprise/enterprise_user_add_alias'
rq = APIRequest_pb2.EnterpriseUserAddAliasRequest()
rq.enterpriseUserId = enterprise_user_id
rq.alias = new_alias
rq.primary = True
try:
api.communicate_rest(params, rq, endpoint)
logging.info('Added alias \"%s\" for user \"%s\"', new_alias, user['username'])
api.query_enterprise(params)
except KeeperApiError as kae:
logging.warning('Failed to add alias for user \"%s\": %s', user['username'], kae.message)
else:
logging.warning('Alias can be added to a single user only: Skipping')
return
elif kwargs.get('delete_alias'):
alias = kwargs['delete_alias']
if len(matched_users) == 1:
user = matched_users[0]
rq = APIRequest_pb2.EnterpriseUserAddAliasRequest()
rq.enterpriseUserId = user['enterprise_user_id']
rq.alias = alias
try:
api.communicate_rest(params, rq, 'enterprise/enterprise_user_delete_alias')
logging.info('Alias \"%s\" deleted from user \"%s\"', alias, user['username'])
api.query_enterprise(params)
except KeeperApiError as kae:
logging.warning('Failed to delete alias \"%s\" from user \"%s\": %s', alias, user['username'], kae.message)
else:
logging.warning('Alias can be deleted from a single user only: Skipping')
return
elif kwargs.get('lock') or kwargs.get('unlock'):
for user in matched_users:
if user['status'] == 'active':
to_lock = kwargs.get('lock')
Expand Down Expand Up @@ -1590,7 +1634,7 @@ def execute(self, params, **kwargs):
api.query_enterprise(params)

if disable_2fa_users:
uids = EnterpriseUserIds()
uids = enterprise_pb2.EnterpriseUserIds()
for user in disable_2fa_users:
uids.enterpriseUserId.append(user['enterprise_user_id'])
api.communicate_rest(params, uids, 'enterprise/disable_two_fa')
Expand All @@ -1605,8 +1649,10 @@ def execute(self, params, **kwargs):
print('\n')

def display_user(self, params, user, is_verbose=False):
print('{0:>16s}: {1}'.format('User ID', user['enterprise_user_id']))
print('{0:>16s}: {1}'.format('Email', user['username'] if 'username' in user else '[empty]'))
enterprise_user_id = user['enterprise_user_id']
username = user['username'] if 'username' in user else '[empty]'
print('{0:>16s}: {1}'.format('User ID', enterprise_user_id))
print('{0:>16s}: {1}'.format('Email', username))
print('{0:>16s}: {1}'.format('Display Name', user['data'].get('displayname') or ''))
node_id = user['node_id']
print('{0:>16s}: {1:<24s}{2}'.format(
Expand All @@ -1622,6 +1668,13 @@ def display_user(self, params, user, is_verbose=False):
if acct_transfer_status:
print('{0:>16s}: {1}'.format('Transfer Status', acct_transfer_status))

if 'user_aliases' in params.enterprise:
aliases = [x['username'] for x in params.enterprise['user_aliases'] if x['enterprise_user_id'] == enterprise_user_id and x['username'] != username]
if len(aliases) > 0:
aliases.sort()
for i in range(len(aliases)):
print('{0:>16s}: {1}'.format('Email Alias' if i == 0 else '', aliases[i]))

if 'role_users' in params.enterprise:
role_ids = [x['role_id'] for x in params.enterprise['role_users'] if x['enterprise_user_id'] == user['enterprise_user_id']]
if len(role_ids) > 0:
Expand Down Expand Up @@ -1673,9 +1726,9 @@ def get_share_administrators(params, user): # type: (KeeperParams, dict) -> Op
try:
if isinstance(user, dict):
if 'share_admins' not in user:
rq = GetSharingAdminsRequest()
rq = enterprise_pb2.GetSharingAdminsRequest()
rq.username = user['username']
rs = api.communicate_rest(params, rq, 'enterprise/get_sharing_admins', rs_type=GetSharingAdminsResponse)
rs = api.communicate_rest(params, rq, 'enterprise/get_sharing_admins', rs_type=enterprise_pb2.GetSharingAdminsResponse)
user['share_admins'] = [x.email for x in rs.userProfileExts
if x.isShareAdminForRequestedObject or x.isMSPMCAdmin]
return [x for x in user['share_admins']]
Expand Down Expand Up @@ -1884,11 +1937,11 @@ def execute(self, params, **kwargs):
}
types = [x.strip().lower() for x in enforcement_value.split(',')]

rq = record_proto.RecordTypesRequest()
rq = record_pb2.RecordTypesRequest()
rq.standard = True
rq.user = True
rq.enterprise = True
record_types_rs = api.communicate_rest(params, rq, 'vault/get_record_types', rs_type=record_proto.RecordTypesResponse)
record_types_rs = api.communicate_rest(params, rq, 'vault/get_record_types', rs_type=record_pb2.RecordTypesResponse)
lookup = {}
for rti in record_types_rs.recordTypes:
try:
Expand All @@ -1901,18 +1954,18 @@ def execute(self, params, **kwargs):
for rt in types:
if rt in lookup:
rti = lookup[rt]
if rti[1] == record_proto.RT_STANDARD:
if rti[1] == record_pb2.RT_STANDARD:
record_types['std'].append(rti[0])
elif rti[1] == record_proto.RT_ENTERPRISE:
elif rti[1] == record_pb2.RT_ENTERPRISE:
record_types['ent'].append(rti[0])
else:
if rt == 'all':
record_types['std'].clear()
record_types['ent'].clear()
for rti in lookup.values():
if rti[1] == record_proto.RT_STANDARD:
if rti[1] == record_pb2.RT_STANDARD:
record_types['std'].append(rti[0])
elif rti[1] == record_proto.RT_ENTERPRISE:
elif rti[1] == record_pb2.RT_ENTERPRISE:
record_types['ent'].append(rti[0])
break
else:
Expand Down Expand Up @@ -2872,8 +2925,9 @@ def execute(self, params, **kwargs):
save_report = kwargs.get('save')
show_updated = save_report or kwargs.get('show_updated')
updated_security_reports = []
rq = SecurityReportRequest()
security_report_data_rs = api.communicate_rest(params, rq, 'enterprise/get_security_report_data', rs_type=SecurityReportResponse)
rq = APIRequest_pb2.SecurityReportRequest()
security_report_data_rs = api.communicate_rest(
params, rq, 'enterprise/get_security_report_data', rs_type=APIRequest_pb2.SecurityReportResponse)
rsa_key = self.get_enterprise_private_rsa_key(params, security_report_data_rs.enterprisePrivateKey)
rows = []
for sr in security_report_data_rs.securityReport:
Expand Down Expand Up @@ -2911,7 +2965,7 @@ def execute(self, params, **kwargs):
data = self.get_updated_security_report_row(sr, rsa_key, data)

if save_report:
updated_sr = SecurityReport()
updated_sr = APIRequest_pb2.SecurityReport()
updated_sr.revision = security_report_data_rs.asOfRevision
updated_sr.enterpriseUserId = sr.enterpriseUserId
report = json.dumps(data).encode('utf-8')
Expand Down Expand Up @@ -2960,9 +3014,9 @@ def execute(self, params, **kwargs):
return dump_report_data(table, field_descriptions, fmt=fmt, filename=kwargs.get('output'))

def get_updated_security_report_row(self, sr, rsa_key, last_saved_data):
# type: (SecurityReport, RSAPrivateKey, Dict[str, int]) -> Dict[str, int]
# type: (APIRequest_pb2.SecurityReport, RSAPrivateKey, Dict[str, int]) -> Dict[str, int]
def apply_incremental_data(old_report_data, incremental_dataset, key):
# type: (Dict[str, int], List[SecurityReportIncrementalData], RSAPrivateKey) -> Dict[str, int]
# type: (Dict[str, int], List[APIRequest_pb2.SecurityReportIncrementalData], RSAPrivateKey) -> Dict[str, int]
def decrypt_security_data(sec_data, k): # type: (bytes, RSAPrivateKey) -> Dict[str, int] or None
if sec_data:
decrypted = None
Expand All @@ -2974,15 +3028,15 @@ def decrypt_security_data(sec_data, k): # type: (bytes, RSAPrivateKey) -> Dict[s
return None

def decrypt_incremental_data(inc_data):
# type: (SecurityReportIncrementalData) -> Dict[str, Dict[str, int] or None]
# type: (APIRequest_pb2.SecurityReportIncrementalData) -> Dict[str, Dict[str, int] or None]
decrypted = {
'old': decrypt_security_data(inc_data.oldSecurityData, key),
'curr': decrypt_security_data(inc_data.currentSecurityData, key)
}
return decrypted

def decrypt_incremental_dataset(inc_dataset):
# type: (List[SecurityReportIncrementalData]) -> List[Dict[str, Dict[str, int] or None]]
# type: (List[APIRequest_pb2.SecurityReportIncrementalData]) -> List[Dict[str, Dict[str, int] or None]]
return [decrypt_incremental_data(x) for x in inc_dataset]

def is_reset_needed(inc_datas):
Expand Down Expand Up @@ -3045,7 +3099,7 @@ def update(u_sec_data, old_sec_d, diff):
return result

def save_updated_security_reports(self, params, reports):
save_rq = SecurityReportSaveRequest()
save_rq = APIRequest_pb2.SecurityReportSaveRequest()
for r in reports:
save_rq.securityReport.append(r)
api.communicate_rest(params, save_rq, 'enterprise/save_summary_security_report')
Expand Down Expand Up @@ -3433,7 +3487,7 @@ def execute(self, params, **kwargs):
return

if kwargs.get('approve') or kwargs.get('deny'):
approve_rq = ApproveUserDevicesRequest()
approve_rq = enterprise_pb2.ApproveUserDevicesRequest()
data_keys = {}
curve = ec.SECP256R1()
if kwargs.get('approve'):
Expand All @@ -3454,9 +3508,10 @@ def execute(self, params, **kwargs):
logging.debug(e)

if ecc_private_key:
data_key_rq = UserDataKeyRequest()
data_key_rq = APIRequest_pb2.UserDataKeyRequest()
data_key_rq.enterpriseUserId.extend(user_ids)
data_key_rs = api.communicate_rest(params, data_key_rq, 'enterprise/get_enterprise_user_data_key', rs_type=EnterpriseUserDataKeys)
data_key_rs = api.communicate_rest(
params, data_key_rq, 'enterprise/get_enterprise_user_data_key', rs_type=enterprise_pb2.EnterpriseUserDataKeys)
for key in data_key_rs.keys:
enc_data_key = key.userEncryptedDataKey
if enc_data_key:
Expand All @@ -3475,9 +3530,10 @@ def execute(self, params, **kwargs):
user_ids = set([x['enterprise_user_id'] for x in matching_devices.values()])
user_ids.difference_update(data_keys.keys())
if len(user_ids) > 0:
data_key_rq = UserDataKeyRequest()
data_key_rq = APIRequest_pb2.UserDataKeyRequest()
data_key_rq.enterpriseUserId.extend(user_ids)
data_key_rs = api.communicate_rest(params, data_key_rq, 'enterprise/get_user_data_key_shared_to_enterprise', rs_type=UserDataKeyResponse)
data_key_rs = api.communicate_rest(
params, data_key_rq, 'enterprise/get_user_data_key_shared_to_enterprise', rs_type=APIRequest_pb2.UserDataKeyResponse)
if data_key_rs.noEncryptedDataKey:
user_ids = set(data_key_rs.noEncryptedDataKey)
usernames = [x['username'] for x in params.enterprise['users'] if x['enterprise_user_id'] in user_ids]
Expand All @@ -3503,7 +3559,7 @@ def execute(self, params, **kwargs):

for device in matching_devices.values():
ent_user_id = device['enterprise_user_id']
device_rq = ApproveUserDeviceRequest()
device_rq = enterprise_pb2.ApproveUserDeviceRequest()
device_rq.enterpriseUserId = ent_user_id
device_rq.encryptedDeviceToken = utils.base64_url_decode(device['encrypted_device_token'])
device_rq.denyApproval = True if kwargs.get('deny') else False
Expand Down Expand Up @@ -3534,7 +3590,7 @@ def execute(self, params, **kwargs):
if len(approve_rq.deviceRequests) == 0:
return

rs = api.communicate_rest(params, approve_rq, 'enterprise/approve_user_devices', rs_type=ApproveUserDevicesResponse)
rs = api.communicate_rest(params, approve_rq, 'enterprise/approve_user_devices', rs_type=enterprise_pb2.ApproveUserDevicesResponse)
api.query_enterprise(params)
else:
print('')
Expand Down

0 comments on commit bd5457e

Please sign in to comment.