Skip to content

Commit

Permalink
Merge pull request #284 from HSF/dev
Browse files Browse the repository at this point in the history
new iworkflow(function-as-a-task) development
  • Loading branch information
wguanicedew committed Mar 12, 2024
2 parents 95c717e + e029d95 commit 9720992
Show file tree
Hide file tree
Showing 74 changed files with 7,403 additions and 172 deletions.
3 changes: 2 additions & 1 deletion client/lib/idds/client/client.py
Expand Up @@ -20,6 +20,7 @@
from idds.common import exceptions
from idds.common.utils import get_proxy_path
from idds.client.requestclient import RequestClient
from idds.client.transformclient import TransformClient
from idds.client.catalogclient import CatalogClient
from idds.client.cacherclient import CacherClient
from idds.client.hpoclient import HPOClient
Expand All @@ -32,7 +33,7 @@
warnings.filterwarnings("ignore")


class Client(RequestClient, CatalogClient, CacherClient, HPOClient, LogsClient, MessageClient, PingClient, AuthClient):
class Client(RequestClient, TransformClient, CatalogClient, CacherClient, HPOClient, LogsClient, MessageClient, PingClient, AuthClient):

"""Main client class for IDDS rest callings."""

Expand Down
107 changes: 100 additions & 7 deletions client/lib/idds/client/clientmanager.py
Expand Up @@ -39,7 +39,8 @@
from idds.common import exceptions
from idds.common.config import (get_local_cfg_file, get_local_config_root,
get_local_config_value, get_main_config_file)
from idds.common.constants import RequestType, RequestStatus
from idds.common.constants import (WorkflowType, RequestType, RequestStatus,
TransformType, TransformStatus)
# from idds.common.utils import get_rest_host, exception_handler
from idds.common.utils import exception_handler

Expand Down Expand Up @@ -183,7 +184,7 @@ def get_local_configuration(self):

self.x509_proxy = self.get_config_value(config, None, 'x509_proxy', current=self.x509_proxy,
default='/tmp/x509up_u%d' % os.geteuid())
if not self.x509_proxy or not os.path.exists(self.x509_proxy):
if 'X509_USER_PROXY' in os.environ or not self.x509_proxy or not os.path.exists(self.x509_proxy):
proxy = get_proxy_path()
if proxy:
self.x509_proxy = proxy
Expand Down Expand Up @@ -431,16 +432,31 @@ def submit(self, workflow, username=None, userdn=None, use_dataset_name=False):
"""
self.setup_client()

scope = 'workflow'
request_type = RequestType.Workflow
transform_tag = 'workflow'
priority = 0
try:
if workflow.type in [WorkflowType.iWorkflow]:
scope = 'iworkflow'
request_type = RequestType.iWorkflow
transform_tag = workflow.get_work_tag()
priority = workflow.priority
if priority is None:
priority = 0
except Exception:
pass

props = {
'scope': 'workflow',
'scope': scope,
'name': workflow.name,
'requester': 'panda',
'request_type': RequestType.Workflow,
'request_type': request_type,
'username': username if username else workflow.username,
'userdn': userdn if userdn else workflow.userdn,
'transform_tag': 'workflow',
'transform_tag': transform_tag,
'status': RequestStatus.New,
'priority': 0,
'priority': priority,
'site': workflow.get_site(),
'lifetime': workflow.lifetime,
'workload_id': workflow.get_workload_id(),
Expand All @@ -453,7 +469,8 @@ def submit(self, workflow, username=None, userdn=None, use_dataset_name=False):
props['userdn'] = self.client.original_user_dn

if self.auth_type == 'x509_proxy':
workflow.add_proxy()
if hasattr(workflow, 'add_proxy'):
workflow.add_proxy()

if use_dataset_name or not workflow.name:
primary_init_work = workflow.get_primary_initial_collection()
Expand All @@ -469,6 +486,56 @@ def submit(self, workflow, username=None, userdn=None, use_dataset_name=False):
request_id = self.client.add_request(**props)
return request_id

@exception_handler
def submit_work(self, request_id, work, use_dataset_name=False):
"""
Submit the workflow as a request to iDDS server.
:param workflow: The workflow to be submitted.
"""
self.setup_client()

transform_type = TransformType.Workflow
transform_tag = 'work'
priority = 0
workload_id = None
try:
if work.type in [WorkflowType.iWork]:
transform_type = TransformType.iWork
transform_tag = work.get_work_tag()
workload_id = work.workload_id
priority = work.priority
if priority is None:
priority = 0
elif work.type in [WorkflowType.iWorkflow]:
transform_type = TransformType.iWorkflow
transform_tag = work.get_work_tag()
workload_id = work.workload_id
priority = work.priority
if priority is None:
priority = 0
except Exception:
pass

props = {
'workload_id': workload_id,
'transform_type': transform_type,
'transform_tag': transform_tag,
'priority': work.priority,
'retries': 0,
'parent_transform_id': None,
'previous_transform_id': None,
# 'site': work.site,
'name': work.name,
'token': work.token,
'status': TransformStatus.New,
'transform_metadata': {'version': release_version, 'work': work}
}

# print(props)
transform_id = self.client.add_transform(request_id, **props)
return transform_id

@exception_handler
def submit_build(self, workflow, username=None, userdn=None, use_dataset_name=True):
"""
Expand Down Expand Up @@ -628,6 +695,32 @@ def get_status(self, request_id=None, workload_id=None, with_detail=False, with_
# print(ret)
return str(ret)

@exception_handler
def get_transforms(self, request_id=None, workload_id=None):
"""
Get transforms.
:param workload_id: the workload id.
:param request_id: the request.
"""
self.setup_client()

tfs = self.client.get_transforms(request_id=request_id, workload_id=workload_id)
return tfs

@exception_handler
def get_transform(self, request_id=None, transform_id=None):
"""
Get transforms.
:param transform_id: the transform id.
:param request_id: the request.
"""
self.setup_client()

tf = self.client.get_transform(request_id=request_id, transform_id=transform_id)
return tf

@exception_handler
def download_logs(self, request_id=None, workload_id=None, dest_dir='./', filename=None):
"""
Expand Down
85 changes: 85 additions & 0 deletions client/lib/idds/client/transformclient.py
@@ -0,0 +1,85 @@
#!/usr/bin/env python
#
# Licensed under the Apache License, Version 2.0 (the "License");
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0OA
#
# Authors:
# - Wen Guan, <wen.guan@cern.ch>, 2024


"""
Request Rest client to access IDDS system.
"""

import os

from idds.client.base import BaseRestClient
# from idds.common.constants import RequestType, RequestStatus


class TransformClient(BaseRestClient):

"""Transform Rest client"""

TRANSFORM_BASEURL = 'transform'

def __init__(self, host=None, auth=None, timeout=None):
"""
Constructor of the BaseRestClient.
:param host: the address of the IDDS server.
:param client_proxy: the client certificate proxy.
:param timeout: timeout in seconds.
"""
super(TransformClient, self).__init__(host=host, auth=auth, timeout=timeout)

def add_transform(self, request_id, **kwargs):
"""
Add transform to the Head service.
:param kwargs: attributes of the request.
:raise exceptions if it's not registerred successfully.
"""
path = self.TRANSFORM_BASEURL
# url = self.build_url(self.host, path=path + '/')
url = self.build_url(self.host, path=os.path.join(path, str(request_id)))

data = kwargs

# if 'request_type' in data and data['request_type'] and isinstance(data['request_type'], RequestType):
# data['request_type'] = data['request_type'].value
# if 'status' in data and data['status'] and isinstance(data['status'], RequestStatus):
# data['status'] = data['status'].value

r = self.get_request_response(url, type='POST', data=data)
return r['transform_id']

def get_transforms(self, request_id=None):
"""
Get transforms from the Head service.
:param request_id: the request id.
:raise exceptions if it's not got successfully.
"""
path = self.TRANSFORM_BASEURL
url = self.build_url(self.host, path=os.path.join(path, str(request_id)))
tfs = self.get_request_response(url, type='GET')
return tfs

def get_transform(self, request_id=None, transform_id=None):
"""
Get transforms from the Head service.
:param request_id: the request id.
:param transform_id: the transform id.
:raise exceptions if it's not got successfully.
"""
path = self.TRANSFORM_BASEURL
url = self.build_url(self.host, path=os.path.join(path, str(request_id), str(transform_id)))
tf = self.get_request_response(url, type='GET')
return tf
62 changes: 4 additions & 58 deletions common/lib/idds/common/authentication.py
Expand Up @@ -11,7 +11,6 @@
import datetime
import base64
import json
import jwt
import os
import re
import requests
Expand All @@ -31,12 +30,7 @@
from urllib.parse import urlencode
raw_input = input

# from cryptography import x509
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization

# from idds.common import exceptions
from idds.common import exceptions
from idds.common.constants import HTTP_STATUS_CODE


Expand Down Expand Up @@ -163,7 +157,7 @@ def get_http_content(self, url, no_verify=False):
r = requests.get(url, allow_redirects=True, verify=should_verify(no_verify, self.get_ssl_verify()))
return r.content
except Exception as error:
return False, 'Failed to get http content for %s: %s' (str(url), str(error))
return False, 'Failed to get http content for %s: %s' % (str(url), str(error))

def get_endpoint_config(self, auth_config):
content = self.get_http_content(auth_config['oidc_config_url'], no_verify=auth_config['no_verify'])
Expand Down Expand Up @@ -285,58 +279,10 @@ def refresh_id_token(self, vo, refresh_token):
return False, 'Failed to refresh oidc token: ' + str(error)

def get_public_key(self, token, jwks_uri, no_verify=False):
headers = jwt.get_unverified_header(token)
if headers is None or 'kid' not in headers:
raise jwt.exceptions.InvalidTokenError('cannot extract kid from headers')
kid = headers['kid']

jwks = self.get_cache_value(jwks_uri)
if not jwks:
jwks_content = self.get_http_content(jwks_uri, no_verify=no_verify)
jwks = json.loads(jwks_content)
self.set_cache_value(jwks_uri, jwks)

jwk = None
for j in jwks.get('keys', []):
if j.get('kid') == kid:
jwk = j
if jwk is None:
raise jwt.exceptions.InvalidTokenError('JWK not found for kid={0}: {1}'.format(kid, str(jwks)))

public_num = RSAPublicNumbers(n=decode_value(jwk['n']), e=decode_value(jwk['e']))
public_key = public_num.public_key(default_backend())
pem = public_key.public_bytes(encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo)
return pem
raise exceptions.NotImplementedException("Method get_public_key is not implemented.")

def verify_id_token(self, vo, token):
try:
auth_config, endpoint_config = self.get_auth_endpoint_config(vo)

# check audience
decoded_token = jwt.decode(token, verify=False, options={"verify_signature": False})
audience = decoded_token['aud']
if audience not in [auth_config['audience'], auth_config['client_id']]:
# discovery_endpoint = auth_config['oidc_config_url']
return False, "The audience %s of the token doesn't match vo configuration(client_id: %s)." % (audience, auth_config['client_id']), None

public_key = self.get_public_key(token, endpoint_config['jwks_uri'], no_verify=auth_config['no_verify'])
# decode token only with RS256
if 'iss' in decoded_token and decoded_token['iss'] and decoded_token['iss'] != endpoint_config['issuer'] and endpoint_config['issuer'].startswith(decoded_token['iss']):
# iss is missing the last '/' in access tokens
issuer = decoded_token['iss']
else:
issuer = endpoint_config['issuer']

decoded = jwt.decode(token, public_key, verify=True, algorithms='RS256',
audience=audience, issuer=issuer)
decoded['vo'] = vo
if 'name' in decoded:
username = decoded['name']
else:
username = None
return True, decoded, username
except Exception as error:
return False, 'Failed to verify oidc token: ' + str(error), None
raise exceptions.NotImplementedException("Method verify_id_token is not implemented.")

def setup_oidc_client_token(self, issuer, client_id, client_secret, scope, audience):
try:
Expand Down

0 comments on commit 9720992

Please sign in to comment.