diff --git a/client/lib/idds/client/client.py b/client/lib/idds/client/client.py index d9b91af1..04294a3d 100644 --- a/client/lib/idds/client/client.py +++ b/client/lib/idds/client/client.py @@ -20,6 +20,7 @@ from idds.common import exceptions from idds.common.utils import get_proxy_path from idds.client.requestclient import RequestClient +from idds.client.transformclient import TransformClient from idds.client.catalogclient import CatalogClient from idds.client.cacherclient import CacherClient from idds.client.hpoclient import HPOClient @@ -32,7 +33,7 @@ warnings.filterwarnings("ignore") -class Client(RequestClient, CatalogClient, CacherClient, HPOClient, LogsClient, MessageClient, PingClient, AuthClient): +class Client(RequestClient, TransformClient, CatalogClient, CacherClient, HPOClient, LogsClient, MessageClient, PingClient, AuthClient): """Main client class for IDDS rest callings.""" diff --git a/client/lib/idds/client/clientmanager.py b/client/lib/idds/client/clientmanager.py index e8152ef9..ca8f13b6 100644 --- a/client/lib/idds/client/clientmanager.py +++ b/client/lib/idds/client/clientmanager.py @@ -39,7 +39,8 @@ from idds.common import exceptions from idds.common.config import (get_local_cfg_file, get_local_config_root, get_local_config_value, get_main_config_file) -from idds.common.constants import RequestType, RequestStatus +from idds.common.constants import (WorkflowType, RequestType, RequestStatus, + TransformType, TransformStatus) # from idds.common.utils import get_rest_host, exception_handler from idds.common.utils import exception_handler @@ -183,7 +184,7 @@ def get_local_configuration(self): self.x509_proxy = self.get_config_value(config, None, 'x509_proxy', current=self.x509_proxy, default='/tmp/x509up_u%d' % os.geteuid()) - if not self.x509_proxy or not os.path.exists(self.x509_proxy): + if 'X509_USER_PROXY' in os.environ or not self.x509_proxy or not os.path.exists(self.x509_proxy): proxy = get_proxy_path() if proxy: self.x509_proxy = proxy @@ -431,16 +432,31 @@ def submit(self, workflow, username=None, userdn=None, use_dataset_name=False): """ self.setup_client() + scope = 'workflow' + request_type = RequestType.Workflow + transform_tag = 'workflow' + priority = 0 + try: + if workflow.type in [WorkflowType.iWorkflow]: + scope = 'iworkflow' + request_type = RequestType.iWorkflow + transform_tag = workflow.get_work_tag() + priority = workflow.priority + if priority is None: + priority = 0 + except Exception: + pass + props = { - 'scope': 'workflow', + 'scope': scope, 'name': workflow.name, 'requester': 'panda', - 'request_type': RequestType.Workflow, + 'request_type': request_type, 'username': username if username else workflow.username, 'userdn': userdn if userdn else workflow.userdn, - 'transform_tag': 'workflow', + 'transform_tag': transform_tag, 'status': RequestStatus.New, - 'priority': 0, + 'priority': priority, 'site': workflow.get_site(), 'lifetime': workflow.lifetime, 'workload_id': workflow.get_workload_id(), @@ -453,7 +469,8 @@ def submit(self, workflow, username=None, userdn=None, use_dataset_name=False): props['userdn'] = self.client.original_user_dn if self.auth_type == 'x509_proxy': - workflow.add_proxy() + if hasattr(workflow, 'add_proxy'): + workflow.add_proxy() if use_dataset_name or not workflow.name: primary_init_work = workflow.get_primary_initial_collection() @@ -469,6 +486,56 @@ def submit(self, workflow, username=None, userdn=None, use_dataset_name=False): request_id = self.client.add_request(**props) return request_id + @exception_handler + def 
submit_work(self, request_id, work, use_dataset_name=False): + """ + Submit a work to the iDDS server as a transform of an existing request. + + :param request_id: The id of the request the work belongs to. + :param work: The work to be submitted. + """ + self.setup_client() + + transform_type = TransformType.Workflow + transform_tag = 'work' + priority = 0 + workload_id = None + try: + if work.type in [WorkflowType.iWork]: + transform_type = TransformType.iWork + transform_tag = work.get_work_tag() + workload_id = work.workload_id + priority = work.priority + if priority is None: + priority = 0 + elif work.type in [WorkflowType.iWorkflow]: + transform_type = TransformType.iWorkflow + transform_tag = work.get_work_tag() + workload_id = work.workload_id + priority = work.priority + if priority is None: + priority = 0 + except Exception: + pass + + props = { + 'workload_id': workload_id, + 'transform_type': transform_type, + 'transform_tag': transform_tag, + 'priority': priority, + 'retries': 0, + 'parent_transform_id': None, + 'previous_transform_id': None, + # 'site': work.site, + 'name': work.name, + 'token': work.token, + 'status': TransformStatus.New, + 'transform_metadata': {'version': release_version, 'work': work} + } + + # print(props) + transform_id = self.client.add_transform(request_id, **props) + return transform_id + @exception_handler def submit_build(self, workflow, username=None, userdn=None, use_dataset_name=True): """ @@ -628,6 +695,32 @@ def get_status(self, request_id=None, workload_id=None, with_detail=False, with_ # print(ret) return str(ret) + @exception_handler + def get_transforms(self, request_id=None, workload_id=None): + """ + Get transforms. + + :param workload_id: the workload id. + :param request_id: the request id. + """ + self.setup_client() + + tfs = self.client.get_transforms(request_id=request_id, workload_id=workload_id) + return tfs + + @exception_handler + def get_transform(self, request_id=None, transform_id=None): + """ + Get a transform. + + :param transform_id: the transform id. + :param request_id: the request id. + """ + self.setup_client() + + tf = self.client.get_transform(request_id=request_id, transform_id=transform_id) + return tf + @exception_handler def download_logs(self, request_id=None, workload_id=None, dest_dir='./', filename=None): """ diff --git a/client/lib/idds/client/transformclient.py b/client/lib/idds/client/transformclient.py new file mode 100644 index 00000000..c271b59f --- /dev/null +++ b/client/lib/idds/client/transformclient.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0OA +# +# Authors: +# - Wen Guan, , 2024 + + +""" +Transform Rest client to access IDDS system. +""" + +import os + +from idds.client.base import BaseRestClient +# from idds.common.constants import RequestType, RequestStatus + + +class TransformClient(BaseRestClient): + + """Transform Rest client""" + + TRANSFORM_BASEURL = 'transform' + + def __init__(self, host=None, auth=None, timeout=None): + """ + Constructor of the TransformClient. + + :param host: the address of the IDDS server. + :param client_proxy: the client certificate proxy. + :param timeout: timeout in seconds. + """ + super(TransformClient, self).__init__(host=host, auth=auth, timeout=timeout) + + def add_transform(self, request_id, **kwargs): + """ + Add transform to the Head service. + + :param kwargs: attributes of the transform.
+ + :raise exceptions if it's not registered successfully. + """ + path = self.TRANSFORM_BASEURL + # url = self.build_url(self.host, path=path + '/') + url = self.build_url(self.host, path=os.path.join(path, str(request_id))) + + data = kwargs + + # if 'request_type' in data and data['request_type'] and isinstance(data['request_type'], RequestType): + # data['request_type'] = data['request_type'].value + # if 'status' in data and data['status'] and isinstance(data['status'], RequestStatus): + # data['status'] = data['status'].value + + r = self.get_request_response(url, type='POST', data=data) + return r['transform_id'] + + def get_transforms(self, request_id=None): + """ + Get transforms from the Head service. + + :param request_id: the request id. + + :raise exceptions if they are not retrieved successfully. + """ + path = self.TRANSFORM_BASEURL + url = self.build_url(self.host, path=os.path.join(path, str(request_id))) + tfs = self.get_request_response(url, type='GET') + return tfs + + def get_transform(self, request_id=None, transform_id=None): + """ + Get a transform from the Head service. + + :param request_id: the request id. + :param transform_id: the transform id. + + :raise exceptions if it's not retrieved successfully. + """ + path = self.TRANSFORM_BASEURL + url = self.build_url(self.host, path=os.path.join(path, str(request_id), str(transform_id))) + tf = self.get_request_response(url, type='GET') + return tf diff --git a/common/lib/idds/common/authentication.py b/common/lib/idds/common/authentication.py index 2d2e83cb..336ba8fc 100644 --- a/common/lib/idds/common/authentication.py +++ b/common/lib/idds/common/authentication.py @@ -11,7 +11,6 @@ import datetime import base64 import json -import jwt import os import re import requests @@ -31,12 +30,7 @@ from urllib.parse import urlencode raw_input = input -# from cryptography import x509 -from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers -from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives import serialization - -# from idds.common import exceptions +from idds.common import exceptions from idds.common.constants import HTTP_STATUS_CODE @@ -163,7 +157,7 @@ def get_http_content(self, url, no_verify=False): r = requests.get(url, allow_redirects=True, verify=should_verify(no_verify, self.get_ssl_verify())) return r.content except Exception as error: - return False, 'Failed to get http content for %s: %s' (str(url), str(error)) + return False, 'Failed to get http content for %s: %s' % (str(url), str(error)) def get_endpoint_config(self, auth_config): content = self.get_http_content(auth_config['oidc_config_url'], no_verify=auth_config['no_verify']) @@ -285,58 +279,10 @@ def refresh_id_token(self, vo, refresh_token): return False, 'Failed to refresh oidc token: ' + str(error) def get_public_key(self, token, jwks_uri, no_verify=False): - headers = jwt.get_unverified_header(token) - if headers is None or 'kid' not in headers: - raise jwt.exceptions.InvalidTokenError('cannot extract kid from headers') - kid = headers['kid'] - - jwks = self.get_cache_value(jwks_uri) - if not jwks: - jwks_content = self.get_http_content(jwks_uri, no_verify=no_verify) - jwks = json.loads(jwks_content) - self.set_cache_value(jwks_uri, jwks) - - jwk = None - for j in jwks.get('keys', []): - if j.get('kid') == kid: - jwk = j - if jwk is None: - raise jwt.exceptions.InvalidTokenError('JWK not found for kid={0}: {1}'.format(kid, str(jwks))) - - public_num = RSAPublicNumbers(n=decode_value(jwk['n']),
e=decode_value(jwk['e'])) - public_key = public_num.public_key(default_backend()) - pem = public_key.public_bytes(encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo) - return pem + raise exceptions.NotImplementedException("Method get_public_key is not implemented.") def verify_id_token(self, vo, token): - try: - auth_config, endpoint_config = self.get_auth_endpoint_config(vo) - - # check audience - decoded_token = jwt.decode(token, verify=False, options={"verify_signature": False}) - audience = decoded_token['aud'] - if audience not in [auth_config['audience'], auth_config['client_id']]: - # discovery_endpoint = auth_config['oidc_config_url'] - return False, "The audience %s of the token doesn't match vo configuration(client_id: %s)." % (audience, auth_config['client_id']), None - - public_key = self.get_public_key(token, endpoint_config['jwks_uri'], no_verify=auth_config['no_verify']) - # decode token only with RS256 - if 'iss' in decoded_token and decoded_token['iss'] and decoded_token['iss'] != endpoint_config['issuer'] and endpoint_config['issuer'].startswith(decoded_token['iss']): - # iss is missing the last '/' in access tokens - issuer = decoded_token['iss'] - else: - issuer = endpoint_config['issuer'] - - decoded = jwt.decode(token, public_key, verify=True, algorithms='RS256', - audience=audience, issuer=issuer) - decoded['vo'] = vo - if 'name' in decoded: - username = decoded['name'] - else: - username = None - return True, decoded, username - except Exception as error: - return False, 'Failed to verify oidc token: ' + str(error), None + raise exceptions.NotImplementedException("Method verify_id_token is not implemented.") def setup_oidc_client_token(self, issuer, client_id, client_secret, scope, audience): try: diff --git a/common/lib/idds/common/constants.py b/common/lib/idds/common/constants.py index 25e6650f..1d527d19 100644 --- a/common/lib/idds/common/constants.py +++ b/common/lib/idds/common/constants.py @@ -97,6 +97,12 @@ def from_dict(d): return d +class WorkflowType(IDDSEnum): + Workflow = 0 + iWorkflow = 1 + iWork = 2 + + class WorkStatus(IDDSEnum): New = 0 Ready = 1 @@ -189,6 +195,7 @@ class RequestType(IDDSEnum): ActiveLearning = 3 HyperParameterOpt = 4 Derivation = 5 + iWorkflow = 6 Other = 99 @@ -202,6 +209,8 @@ class TransformType(IDDSEnum): Processing = 6 Actuating = 7 Data = 8 + iWorkflow = 9 + iWork = 10 Other = 99 @@ -313,6 +322,21 @@ class GranularityType(IDDSEnum): Event = 1 +class ProcessingType(IDDSEnum): + Workflow = 0 + EventStreaming = 1 + StageIn = 2 + ActiveLearning = 3 + HyperParameterOpt = 4 + Derivation = 5 + Processing = 6 + Actuating = 7 + Data = 8 + iWorkflow = 9 + iWork = 10 + Other = 99 + + class ProcessingStatus(IDDSEnum): New = 0 Submitting = 1 @@ -344,6 +368,7 @@ class ProcessingStatus(IDDSEnum): Terminating = 27 ToTrigger = 28 Triggering = 29 + Synchronizing = 30 class ProcessingLocking(IDDSEnum): @@ -524,3 +549,59 @@ def get_work_status_from_transform_processing_status(status): return WorkStatus.Terminating else: return WorkStatus.Transforming + + +def get_transform_status_from_processing_status(status): + map = {ProcessingStatus.New: TransformStatus.Transforming, # when Processing is created, set it to transforming + ProcessingStatus.Submitting: TransformStatus.Transforming, + ProcessingStatus.Submitted: TransformStatus.Transforming, + ProcessingStatus.Running: TransformStatus.Transforming, + ProcessingStatus.Finished: TransformStatus.Finished, + ProcessingStatus.Failed: TransformStatus.Failed, + 
ProcessingStatus.Lost: TransformStatus.Failed, + ProcessingStatus.Cancel: TransformStatus.Cancelled, + ProcessingStatus.FinishedOnStep: TransformStatus.Finished, + ProcessingStatus.FinishedOnExec: TransformStatus.Finished, + ProcessingStatus.FinishedTerm: TransformStatus.Finished, + ProcessingStatus.SubFinished: TransformStatus.SubFinished, + ProcessingStatus.ToCancel: TransformStatus.ToCancel, + ProcessingStatus.Cancelling: TransformStatus.Cancelling, + ProcessingStatus.Cancelled: TransformStatus.Cancelled, + ProcessingStatus.ToSuspend: TransformStatus.ToSuspend, + ProcessingStatus.Suspending: TransformStatus.Suspending, + ProcessingStatus.Suspended: TransformStatus.Suspended, + ProcessingStatus.ToResume: TransformStatus.ToResume, + ProcessingStatus.Resuming: TransformStatus.Resuming, + ProcessingStatus.ToExpire: TransformStatus.ToExpire, + ProcessingStatus.Expiring: TransformStatus.Expiring, + ProcessingStatus.Expired: TransformStatus.Expired, + ProcessingStatus.TimeOut: TransformStatus.Failed, + ProcessingStatus.ToFinish: TransformStatus.ToFinish, + ProcessingStatus.ToForceFinish: TransformStatus.ToForceFinish, + ProcessingStatus.Broken: TransformStatus.Failed, + ProcessingStatus.Terminating: TransformStatus.Terminating, + ProcessingStatus.ToTrigger: TransformStatus.Transforming, + ProcessingStatus.Triggering: TransformStatus.Transforming, + ProcessingStatus.Synchronizing: TransformStatus.Transforming + } + if status in map: + return map[status] + return TransformStatus.Transforming + + +def get_processing_type_from_transform_type(tf_type): + map = {TransformType.Workflow: ProcessingType.Workflow, + TransformType.EventStreaming: ProcessingType.EventStreaming, + TransformType.StageIn: ProcessingType.StageIn, + TransformType.ActiveLearning: ProcessingType.ActiveLearning, + TransformType.HyperParameterOpt: ProcessingType.HyperParameterOpt, + TransformType.Derivation: ProcessingType.Derivation, + TransformType.Processing: ProcessingType.Processing, + TransformType.Actuating: ProcessingType.Actuating, + TransformType.Data: ProcessingType.Data, + TransformType.iWorkflow: ProcessingType.iWorkflow, + TransformType.iWork: ProcessingType.iWork, + TransformType.Other: ProcessingType.Other} + if tf_type in map: + return map[tf_type] + return ProcessingType.Other diff --git a/common/lib/idds/common/dict_class.py b/common/lib/idds/common/dict_class.py index 58f9ed94..27da70ca 100644 --- a/common/lib/idds/common/dict_class.py +++ b/common/lib/idds/common/dict_class.py @@ -59,7 +59,7 @@ def to_dict(self): # print(value) # if not key.startswith('__') and not key.startswith('_'): if not key.startswith('__'): - if key == 'logger': + if key in ['logger']: new_value = None else: new_value = self.to_dict_l(value) @@ -119,7 +119,7 @@ def from_dict(d): last_items = {} for key, value in d['attributes'].items(): # print(key) - if key == 'logger': + if key in ['logger']: continue elif key == "_metadata": last_items[key] = value diff --git a/common/lib/idds/common/imports.py b/common/lib/idds/common/imports.py new file mode 100644 index 00000000..4229f4e3 --- /dev/null +++ b/common/lib/idds/common/imports.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# You may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0OA +# +# Authors: +# - Wen Guan, , 2019 - 2024 + + +import importlib +import inspect +import os +import sys +# import traceback + +from contextlib import contextmanager +from typing import Any, Callable + + +@contextmanager +def add_cwd_path(): + """Context adding the current working directory to sys.path.""" + try: + cwd = os.getcwd() + except FileNotFoundError: + cwd = None + if not cwd: + yield + elif cwd in sys.path: + yield + else: + sys.path.insert(0, cwd) + try: + yield cwd + finally: + try: + sys.path.remove(cwd) + except ValueError: # pragma: no cover + pass + + +def get_func_name(func: Callable, base_dir=None) -> str: + """ + Return function name from a function. + """ + filename = inspect.getfile(func) + module = inspect.getmodule(func) + module_name = module.__name__ + if base_dir is None: + filename = os.path.basename(filename) + else: + filename = os.path.relpath(filename, base_dir) + if 'site-packages' in filename: + filename = filename.split('site-packages')[-1] + if filename.startswith('/'): + filename = filename[1:] + return filename + ":" + module_name + ":" + func.__name__ + + +def import_func(name: str) -> Callable[..., Any]: + """Returns a function from a dotted path name. Example: `path.to.module:func`. + + When the attribute we look for is a staticmethod, module name in its + dotted path is not the last-before-end word + + E.g.: filename:module_a.module_b:ClassA.my_static_method + + Thus we remove the bits from the end of the name until we can import it + + Args: + name (str): The name (reference) to the path. + + Raises: + ValueError: If no module is found or invalid attribute name. + + Returns: + Any: An attribute (normally a Callable) + """ + with add_cwd_path(): + filename, module_name_bits, attribute_bits = name.split(':') + # module_name_bits, attribute_bits = name_bits[:-1], [name_bits[-1]] + if module_name_bits == '__main__': + module_name_bits = filename.replace('.py', '').replace('.pyc', '') + module_name_bits = module_name_bits.replace('/', '') + module_name_bits = module_name_bits.split('.') + attribute_bits = attribute_bits.split('.') + module = None + while len(module_name_bits): + try: + module_name = '.'.join(module_name_bits) + module = importlib.import_module(module_name) + break + except ImportError: + attribute_bits.insert(0, module_name_bits.pop()) + + if module is None: + # maybe it's a builtin + try: + return __builtins__[name] + except KeyError: + raise ValueError('Invalid attribute name: %s' % name) + + attribute_name = '.'.join(attribute_bits) + if hasattr(module, attribute_name): + return getattr(module, attribute_name) + # staticmethods + attribute_name = attribute_bits.pop() + attribute_owner_name = '.'.join(attribute_bits) + try: + attribute_owner = getattr(module, attribute_owner_name) + except: # noqa + raise ValueError('Invalid attribute name: %s' % attribute_name) + + if not hasattr(attribute_owner, attribute_name): + raise ValueError('Invalid attribute name: %s' % name) + return getattr(attribute_owner, attribute_name) + + +def import_attribute(name: str) -> Callable[..., Any]: + """Returns an attribute from a dotted path name. Example: `path.to.func`.
+ + When the attribute we look for is a staticmethod, module name in its + dotted path is not the last-before-end word + + E.g.: package_a.package_b.module_a.ClassA.my_static_method + + Thus we remove the bits from the end of the name until we can import it + + Args: + name (str): The name (reference) to the path. + + Raises: + ValueError: If no module is found or invalid attribute name. + + Returns: + Any: An attribute (normally a Callable) + """ + name_bits = name.split('.') + module_name_bits, attribute_bits = name_bits[:-1], [name_bits[-1]] + module = None + while len(module_name_bits): + try: + module_name = '.'.join(module_name_bits) + module = importlib.import_module(module_name) + break + except ImportError: + attribute_bits.insert(0, module_name_bits.pop()) + + if module is None: + # maybe it's a builtin + try: + return __builtins__[name] + except KeyError: + raise ValueError('Invalid attribute name: %s' % name) + + attribute_name = '.'.join(attribute_bits) + if hasattr(module, attribute_name): + return getattr(module, attribute_name) + # staticmethods + attribute_name = attribute_bits.pop() + attribute_owner_name = '.'.join(attribute_bits) + try: + attribute_owner = getattr(module, attribute_owner_name) + except: # noqa + raise ValueError('Invalid attribute name: %s' % attribute_name) + + if not hasattr(attribute_owner, attribute_name): + raise ValueError('Invalid attribute name: %s' % name) + return getattr(attribute_owner, attribute_name) diff --git a/common/lib/idds/common/utils.py b/common/lib/idds/common/utils.py index 34af0160..e6d8d9f7 100644 --- a/common/lib/idds/common/utils.py +++ b/common/lib/idds/common/utils.py @@ -13,6 +13,7 @@ import errno import datetime import importlib +import hashlib import logging import json import os @@ -63,11 +64,15 @@ def setup_logging(name, stream=None, loglevel=None): else: loglevel = logging.INFO - if os.environ.get('IDDS_LOG_LEVEL', None): - idds_log_level = os.environ.get('IDDS_LOG_LEVEL', None) - idds_log_level = idds_log_level.upper() - if idds_log_level in ["DEBUG", "CRITICAL", "ERROR", "WARNING", "INFO"]: - loglevel = getattr(logging, idds_log_level) + if os.environ.get('IDDS_LOG_LEVEL', None): + idds_log_level = os.environ.get('IDDS_LOG_LEVEL', None) + idds_log_level = idds_log_level.upper() + if idds_log_level in ["DEBUG", "CRITICAL", "ERROR", "WARNING", "INFO"]: + loglevel = getattr(logging, idds_log_level) + + if type(loglevel) in [str]: + loglevel = loglevel.upper() + loglevel = getattr(logging, loglevel) if stream is None: if config_has_section('common') and config_has_option('common', 'logdir'): @@ -733,7 +738,7 @@ def group_list(input_list, key): return update_groups -def import_fun(name: str) -> Callable[..., Any]: +def import_func(name: str) -> Callable[..., Any]: """Returns a function from a dotted path name. Example: `path.to.module:func`. 
When the attribute we look for is a staticmethod, module name in its @@ -873,6 +878,26 @@ def create_archive_file(work_dir, archive_filename, files): with tarfile.open(archive_filename, "w:gz", dereference=True) as tar: for local_file in files: - # base_name = os.path.basename(local_file) - tar.add(local_file, arcname=os.path.basename(local_file)) + if os.path.isfile(local_file): + # base_name = os.path.basename(local_file) + tar.add(local_file, arcname=os.path.basename(local_file)) + elif os.path.isdir(local_file): + for root, dirs, fs in os.walk(local_file): + for f in fs: + file_path = os.path.join(root, f) + tar.add(file_path, arcname=os.path.relpath(file_path, local_file)) return archive_filename + + +class SecureString(object): + def __init__(self, value): + self._value = value + + def __str__(self): + return '****' + + +def get_unique_id_for_dict(dict_): + ret = hashlib.sha1(json.dumps(dict_, sort_keys=True).encode()).hexdigest() + # logging.debug("get_unique_id_for_dict, type: %s: %s, ret: %s" % (type(dict_), dict_, ret)) + return ret diff --git a/common/tools/env/environment.yml b/common/tools/env/environment.yml index c991fe97..fd136187 100644 --- a/common/tools/env/environment.yml +++ b/common/tools/env/environment.yml @@ -3,8 +3,6 @@ dependencies: - python==3.6 - pip - pip: - - cryptography - - pyjwt # Pyjwt - packaging - requests - - dogpile.cache \ No newline at end of file + - dogpile.cache diff --git a/main/etc/sql/oracle_update.sql b/main/etc/sql/oracle_update.sql index 0da5cb11..4e32ed30 100644 --- a/main/etc/sql/oracle_update.sql +++ b/main/etc/sql/oracle_update.sql @@ -468,3 +468,8 @@ CREATE TABLE meta_info CONSTRAINT METAINFO_NAME_UQ UNIQUE (name) ); +--- 20240219 +alter table TRANSFORMS add (parent_transform_id NUMBER(12)); +alter table TRANSFORMS add (previous_transform_id NUMBER(12)); +alter table TRANSFORMS add (current_processing_id NUMBER(12)); +alter table PROCESSINGS add (processing_type NUMBER(2)); diff --git a/main/lib/idds/agents/carrier/finisher.py b/main/lib/idds/agents/carrier/finisher.py index ee25f0e0..5ab39859 100644 --- a/main/lib/idds/agents/carrier/finisher.py +++ b/main/lib/idds/agents/carrier/finisher.py @@ -6,21 +6,29 @@ # http://www.apache.org/licenses/LICENSE-2.0OA # # Authors: -# - Wen Guan, , 2019 - 2023 +# - Wen Guan, , 2019 - 2024 import time import traceback -from idds.common.constants import (Sections, ReturnCode, ProcessingStatus, ProcessingLocking) +from idds.common import exceptions +from idds.common.constants import (Sections, ReturnCode, ProcessingType, + ProcessingStatus, ProcessingLocking) from idds.common.utils import setup_logging, truncate_string +from idds.core import processings as core_processings +from idds.agents.common.baseagent import BaseAgent from idds.agents.common.eventbus.event import (EventType, UpdateProcessingEvent, + SyncProcessingEvent, + TerminatedProcessingEvent, UpdateTransformEvent) + from .utils import (handle_abort_processing, handle_resume_processing, # is_process_terminated, sync_processing) +from .iutils import sync_iprocessing from .poller import Poller setup_logging(__name__) @@ -63,6 +71,57 @@ def show_queue_size(self): q_str = "number of processings: %s, max number of processings: %s" % (self.number_workers, self.max_number_workers) self.logger.debug(q_str) + def get_finishing_processings(self): + """ + Get finishing processing + """ + try: + if not self.is_ok_to_run_more_processings(): + return [] + + self.show_queue_size() + + if BaseAgent.min_request_id is None: + return [] + + processing_status 
= [ProcessingStatus.Terminating, ProcessingStatus.Synchronizing] + # next_poll_at = datetime.datetime.utcnow() + datetime.timedelta(seconds=self.poll_period) + processings = core_processings.get_processings_by_status(status=processing_status, + locking=True, update_poll=True, + not_lock=True, + # only_return_id=True, + min_request_id=BaseAgent.min_request_id, + bulk_size=self.retrieve_bulk_size) + + # self.logger.debug("Main thread get %s [submitting + submitted + running] processings to process" % (len(processings))) + if processings: + self.logger.info("Main thread get terminating/synchronizing processings to process: %s" % (str(processings))) + + events, pr_ids = [], [] + for pr in processings: + pr_id = pr['processing_id'] + pr_ids.append(pr_id) + pr_status = pr['status'] + if pr_status in [ProcessingStatus.Terminating]: + self.logger.info("TerminatedProcessingEvent(processing_id: %s)" % pr_id) + event = TerminatedProcessingEvent(publisher_id=self.id, processing_id=pr_id) + events.append(event) + elif pr_status in [ProcessingStatus.Synchronizing]: + self.logger.info("SyncProcessingEvent(processing_id: %s)" % pr_id) + event = SyncProcessingEvent(publisher_id=self.id, processing_id=pr_id) + events.append(event) + self.event_bus.send_bulk(events) + + return pr_ids + except exceptions.DatabaseException as ex: + if 'ORA-00060' in str(ex): + self.logger.warn("(cx_Oracle.DatabaseError) ORA-00060: deadlock detected while waiting for resource") + else: + # raise ex + self.logger.error(ex) + self.logger.error(traceback.format_exc()) + return [] + def handle_sync_processing(self, processing, log_prefix=""): """ process terminated processing @@ -90,6 +149,33 @@ def handle_sync_processing(self, processing, log_prefix=""): return ret return None + def handle_sync_iprocessing(self, processing, log_prefix=""): + """ + process terminated processing + """ + try: + processing, update_collections, messages = sync_iprocessing(processing, self.agent_attributes, logger=self.logger, log_prefix=log_prefix) + + update_processing = {'processing_id': processing['processing_id'], + 'parameters': {'status': processing['status'], + 'locking': ProcessingLocking.Idle}} + ret = {'update_processing': update_processing, + 'update_collections': update_collections, + 'messages': messages} + return ret + except Exception as ex: + self.logger.error(ex) + self.logger.error(traceback.format_exc()) + error = {'sync_err': {'msg': truncate_string('%s' % (ex), length=200)}} + update_processing = {'processing_id': processing['processing_id'], + 'parameters': {'status': ProcessingStatus.Running, + 'locking': ProcessingLocking.Idle, + 'errors': processing['errors'] if processing['errors'] else {}}} + update_processing['parameters']['errors'].update(error) + ret = {'update_processing': update_processing} + return ret + return None + def process_sync_processing(self, event): self.number_workers += 1 pro_ret = ReturnCode.Ok.value @@ -104,7 +190,10 @@ def process_sync_processing(self, event): log_pre = self.get_log_prefix(pr) self.logger.info(log_pre + "process_sync_processing") - ret = self.handle_sync_processing(pr, log_prefix=log_pre) + if pr['processing_type'] and pr['processing_type'] in [ProcessingType.iWorkflow, ProcessingType.iWork]: + ret = self.handle_sync_iprocessing(pr, log_prefix=log_pre) + else: + ret = self.handle_sync_processing(pr, log_prefix=log_pre) ret_copy = {} for ret_key in ret: if ret_key != 'messages': @@ -152,6 +241,34 @@ def handle_terminated_processing(self, processing, log_prefix=""): return ret return None + 
def handle_terminated_iprocessing(self, processing, log_prefix=""): + """ + process terminated processing + """ + try: + processing, update_collections, messages = sync_iprocessing(processing, self.agent_attributes, terminate=True, logger=self.logger, log_prefix=log_prefix) + + update_processing = {'processing_id': processing['processing_id'], + 'parameters': {'status': processing['status'], + 'locking': ProcessingLocking.Idle}} + ret = {'update_processing': update_processing, + 'update_collections': update_collections, + 'messages': messages} + + return ret + except Exception as ex: + self.logger.error(ex) + self.logger.error(traceback.format_exc()) + error = {'term_err': {'msg': truncate_string('%s' % (ex), length=200)}} + update_processing = {'processing_id': processing['processing_id'], + 'parameters': {'status': ProcessingStatus.Running, + 'locking': ProcessingLocking.Idle, + 'errors': processing['errors'] if processing['errors'] else {}}} + update_processing['parameters']['errors'].update(error) + ret = {'update_processing': update_processing} + return ret + return None + def process_terminated_processing(self, event): self.number_workers += 1 pro_ret = ReturnCode.Ok.value @@ -169,7 +286,10 @@ def process_terminated_processing(self, event): log_pre = self.get_log_prefix(pr) self.logger.info(log_pre + "process_terminated_processing") - ret = self.handle_terminated_processing(pr, log_prefix=log_pre) + if pr['processing_type'] and pr['processing_type'] in [ProcessingType.iWorkflow, ProcessingType.iWork]: + ret = self.handle_terminated_iprocessing(pr, log_prefix=log_pre) + else: + ret = self.handle_terminated_processing(pr, log_prefix=log_pre) ret_copy = {} for ret_key in ret: if ret_key != 'messages': @@ -181,12 +301,15 @@ def process_terminated_processing(self, event): event = UpdateTransformEvent(publisher_id=self.id, transform_id=pr['transform_id']) self.event_bus.send(event) - if pr['status'] not in [ProcessingStatus.Finished, ProcessingStatus.Failed, ProcessingStatus.SubFinished]: - # some files are missing, poll it. - self.logger.info(log_pre + "UpdateProcessingEvent(processing_id: %s)" % pr['processing_id']) - event = UpdateProcessingEvent(publisher_id=self.id, processing_id=pr['processing_id'], counter=original_event._counter + 1) - event.set_terminating() - self.event_bus.send(event) + if pr['processing_type'] and pr['processing_type'] in [ProcessingType.iWorkflow, ProcessingType.iWork]: + pass + else: + if pr['status'] not in [ProcessingStatus.Finished, ProcessingStatus.Failed, ProcessingStatus.SubFinished, ProcessingStatus.Broken]: + # some files are missing, poll it. 
+ self.logger.info(log_pre + "UpdateProcessingEvent(processing_id: %s)" % pr['processing_id']) + event = UpdateProcessingEvent(publisher_id=self.id, processing_id=pr['processing_id'], counter=original_event._counter + 1) + event.set_terminating() + self.event_bus.send(event) except Exception as ex: self.logger.error(ex) self.logger.error(traceback.format_exc()) @@ -372,6 +495,9 @@ def run(self): self.init_event_function_map() + task = self.create_task(task_func=self.get_finishing_processings, task_output_queue=None, task_args=tuple(), task_kwargs={}, delay_time=10, priority=1) + self.add_task(task) + self.execute() except KeyboardInterrupt: self.stop() diff --git a/main/lib/idds/agents/carrier/iutils.py b/main/lib/idds/agents/carrier/iutils.py new file mode 100644 index 00000000..1daa935a --- /dev/null +++ b/main/lib/idds/agents/carrier/iutils.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0OA +# +# Authors: +# - Wen Guan, , 2024 + +import datetime +import logging + +from idds.common.constants import ProcessingStatus, CollectionStatus +from idds.common.utils import setup_logging +from idds.core import catalog as core_catalog + +setup_logging(__name__) + + +def get_logger(logger=None): + if logger: + return logger + logger = logging.getLogger(__name__) + return logger + + +def is_process_terminated(processing_status): + if processing_status in [ProcessingStatus.Finished, ProcessingStatus.Failed, + ProcessingStatus.SubFinished, ProcessingStatus.Cancelled, + ProcessingStatus.Suspended, ProcessingStatus.Expired, + ProcessingStatus.Broken, ProcessingStatus.FinishedOnStep, + ProcessingStatus.FinishedOnExec, ProcessingStatus.FinishedTerm]: + return True + return False + + +def handle_new_iprocessing(processing, agent_attributes, plugin=None, func_site_to_cloud=None, max_updates_per_round=2000, executors=None, logger=None, log_prefix=''): + logger = get_logger(logger) + + work = processing['processing_metadata']['work'] + # transform_id = processing['transform_id'] + + try: + workload_id, errors = plugin.submit(work, logger=logger, log_prefix=log_prefix) + logger.info(log_prefix + "submit work (workload_id: %s, errors: %s)" % (workload_id, errors)) + except Exception as ex: + err_msg = "submit work failed with exception: %s" % (ex) + logger.error(log_prefix + err_msg) + raise Exception(err_msg) + + processing['workload_id'] = workload_id + processing['submitted_at'] = datetime.datetime.utcnow() + + # return True, processing, update_collections, new_contents, new_input_dependency_contents, ret_msgs, errors + return True, processing, [], [], [], [], errors + + +def handle_update_iprocessing(processing, agent_attributes, plugin=None, max_updates_per_round=2000, use_bulk_update_mappings=True, executors=None, logger=None, log_prefix=''): + logger = get_logger(logger) + + # work = processing['processing_metadata']['work'] + + # request_id = processing['request_id'] + # transform_id = processing['transform_id'] + workload_id = processing['workload_id'] + + try: + status = plugin.poll(workload_id, logger=logger, log_prefix=log_prefix) + logger.info(log_prefix + "poll work (status: %s, workload_id: %s)" % (status, workload_id)) + except Exception as ex: + err_msg = "poll work failed with exception: %s" % (ex) + logger.error(log_prefix + err_msg) + raise Exception(err_msg) + + return 
status, [], [], [], [], [], [], [] + + +def sync_iprocessing(processing, agent_attributes, terminate=False, abort=False, logger=None, log_prefix=""): + # logger = get_logger() + + # request_id = processing['request_id'] + # transform_id = processing['transform_id'] + # workload_id = processing['workload_id'] + + # work = processing['processing_metadata']['work'] + + u_colls = [] + if processing['substatus'] in [ProcessingStatus.Finished, ProcessingStatus.Failed, ProcessingStatus.SubFinished, ProcessingStatus.Broken]: + collections = core_catalog.get_collections(transform_id=processing['transform_id']) + if collections: + for coll in collections: + u_coll = {'coll_id': coll['coll_id'], + 'status': CollectionStatus.Closed} + u_colls.append(u_coll) + + processing['status'] = processing['substatus'] + + return processing, u_colls, None diff --git a/main/lib/idds/agents/carrier/plugins/__init__.py b/main/lib/idds/agents/carrier/plugins/__init__.py new file mode 100644 index 00000000..62879db4 --- /dev/null +++ b/main/lib/idds/agents/carrier/plugins/__init__.py @@ -0,0 +1,9 @@ +#!/usr/bin/env python +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0OA +# +# Authors: +# - Wen Guan, , 2024 diff --git a/main/lib/idds/agents/carrier/plugins/base.py b/main/lib/idds/agents/carrier/plugins/base.py new file mode 100644 index 00000000..cbd9ef16 --- /dev/null +++ b/main/lib/idds/agents/carrier/plugins/base.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0OA +# +# Authors: +# - Wen Guan, , 2024 + +import json + +from idds.common.constants import WorkflowType +from idds.common.utils import encode_base64 + + +class BaseSubmitter(object): + + def __init__(self, *args, **kwargs): + pass + + def get_task_params(self, work): + if work.type in [WorkflowType.iWork]: + task_name = work.name + "_" + str(work.request_id) + "_" + str(work.transform_id) + elif work.type in [WorkflowType.iWorkflow]: + task_name = work.name + "_" + str(work.request_id) + else: + task_name = work.name + + in_files = [] + group_parameters = work.group_parameters + for p in group_parameters: + p = json.dumps(p) + p = encode_base64(p) + in_files.append(p) + + task_param_map = {} + task_param_map['vo'] = work.vo if work.vo else 'wlcg' + if work.queue and len(work.queue) > 0: + task_param_map['site'] = work.queue + if work.site and len(work.site) > 0: + task_param_map['PandaSite'] = work.site + if work.cloud and len(work.cloud) > 0: + task_param_map['cloud'] = work.cloud + + task_param_map['workingGroup'] = work.working_group + + task_param_map['nFilesPerJob'] = 1 + if in_files: + # if has_dependencies: + # task_param_map['inputPreStaging'] = True + task_param_map['nFiles'] = len(in_files) + task_param_map['noInput'] = True + task_param_map['pfnList'] = in_files + else: + # task_param_map['inputPreStaging'] = True + in_files = [json.dumps('pseudo_file')] + task_param_map['nFiles'] = len(in_files) + task_param_map['noInput'] = True + task_param_map['pfnList'] = in_files + + task_param_map['taskName'] = task_name + task_param_map['userName'] = work.username if work.username else 'iDDS' + task_param_map['taskPriority'] = work.priority + 
task_param_map['architecture'] = '' + task_param_map['transUses'] = '' + task_param_map['transHome'] = None + + # executable = work.executable + executable = work.get_runner() + # task_param_map['transPath'] = 'https://storage.googleapis.com/drp-us-central1-containers/bash-c-enc' + # task_param_map['encJobParams'] = True + # task_param_map['transPath'] = 'https://wguan-wisc.web.cern.ch/wguan-wisc/run_workflow_wrapper' + task_param_map['transPath'] = 'http://pandaserver-doma.cern.ch:25080/trf/user/run_workflow_wrapper' + task_param_map['processingType'] = None + task_param_map['prodSourceLabel'] = 'managed' # managed, test, ptest + + task_param_map['noWaitParent'] = True + task_param_map['taskType'] = 'iDDS' + task_param_map['coreCount'] = work.core_count + task_param_map['skipScout'] = True + task_param_map['ramCount'] = work.total_memory / work.core_count if work.core_count else work.total_memory + # task_param_map['ramUnit'] = 'MB' + task_param_map['ramUnit'] = 'MBPerCoreFixed' + + # task_param_map['inputPreStaging'] = True + task_param_map['prestagingRuleID'] = 123 + task_param_map['nChunksToWait'] = 1 + task_param_map['maxCpuCount'] = work.core_count + task_param_map['maxWalltime'] = work.max_walltime + task_param_map['maxFailure'] = work.max_attempt if work.max_attempt else 5 + task_param_map['maxAttempt'] = work.max_attempt if work.max_attempt else 5 + if task_param_map['maxAttempt'] < work.max_attempt: + task_param_map['maxAttempt'] = work.max_attempt + if task_param_map['maxFailure'] < work.max_attempt: + task_param_map['maxFailure'] = work.max_attempt + task_param_map['log'] = {"dataset": "PandaJob_iworkflow/", # "PandaJob_#{pandaid}/" + "destination": "local", + "param_type": "log", + "token": "local", + "type": "template", + "value": "log.tgz"} + task_param_map['jobParameters'] = [ + {'type': 'constant', + 'value': executable, # noqa: E501 + }, + ] + + task_param_map['reqID'] = work.request_id + + return task_param_map + + def submit(self, *args, **kwargs): + pass + + +class BasePoller(object): + + def __init__(self, *args, **kwargs): + pass + + def poll(self, *args, **kwargs): + pass + + +class BaseSubmitterPoller(BaseSubmitter): + + def __init__(self, *args, **kwargs): + super(BaseSubmitterPoller, self).__init__(*args, **kwargs) + + def poll(self, *args, **kwargs): + pass diff --git a/main/lib/idds/agents/carrier/plugins/panda.py b/main/lib/idds/agents/carrier/plugins/panda.py new file mode 100644 index 00000000..c2110056 --- /dev/null +++ b/main/lib/idds/agents/carrier/plugins/panda.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# You may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0OA +# +# Authors: +# - Wen Guan, , 2024 + + +try: + import ConfigParser +except ImportError: + import configparser as ConfigParser + +import os +import traceback + +from idds.common.constants import ProcessingStatus +from .base import BaseSubmitterPoller + + +class PandaSubmitterPoller(BaseSubmitterPoller): + + def __init__(self, *args, **kwargs): + super(PandaSubmitterPoller, self).__init__() + self.load_panda_urls() + + def load_panda_config(self): + panda_config = ConfigParser.ConfigParser() + if os.environ.get('IDDS_PANDA_CONFIG', None): + configfile = os.environ['IDDS_PANDA_CONFIG'] + if panda_config.read(configfile) == [configfile]: + return panda_config + + configfiles = ['%s/etc/panda/panda.cfg' % os.environ.get('IDDS_HOME', ''), + '/etc/panda/panda.cfg', '/opt/idds/etc/panda/panda.cfg', + '%s/etc/panda/panda.cfg' % os.environ.get('VIRTUAL_ENV', '')] + for configfile in configfiles: + if panda_config.read(configfile) == [configfile]: + return panda_config + return panda_config + + def load_panda_urls(self): + panda_config = self.load_panda_config() + # self.logger.debug("panda config: %s" % panda_config) + self.panda_url = None + self.panda_url_ssl = None + self.panda_monitor = None + self.panda_auth = None + self.panda_auth_vo = None + self.panda_config_root = None + self.pandacache_url = None + self.panda_verify_host = None + + if panda_config.has_section('panda'): + if 'PANDA_MONITOR_URL' not in os.environ and panda_config.has_option('panda', 'panda_monitor_url'): + self.panda_monitor = panda_config.get('panda', 'panda_monitor_url') + os.environ['PANDA_MONITOR_URL'] = self.panda_monitor + # self.logger.debug("Panda monitor url: %s" % str(self.panda_monitor)) + if 'PANDA_URL' not in os.environ and panda_config.has_option('panda', 'panda_url'): + self.panda_url = panda_config.get('panda', 'panda_url') + os.environ['PANDA_URL'] = self.panda_url + # self.logger.debug("Panda url: %s" % str(self.panda_url)) + if 'PANDACACHE_URL' not in os.environ and panda_config.has_option('panda', 'pandacache_url'): + self.pandacache_url = panda_config.get('panda', 'pandacache_url') + os.environ['PANDACACHE_URL'] = self.pandacache_url + # self.logger.debug("Pandacache url: %s" % str(self.pandacache_url)) + if 'PANDA_VERIFY_HOST' not in os.environ and panda_config.has_option('panda', 'panda_verify_host'): + self.panda_verify_host = panda_config.get('panda', 'panda_verify_host') + os.environ['PANDA_VERIFY_HOST'] = self.panda_verify_host + # self.logger.debug("Panda verify host: %s" % str(self.panda_verify_host)) + if 'PANDA_URL_SSL' not in os.environ and panda_config.has_option('panda', 'panda_url_ssl'): + self.panda_url_ssl = panda_config.get('panda', 'panda_url_ssl') + os.environ['PANDA_URL_SSL'] = self.panda_url_ssl + # self.logger.debug("Panda url ssl: %s" % str(self.panda_url_ssl)) + if 'PANDA_AUTH' not in os.environ and panda_config.has_option('panda', 'panda_auth'): + self.panda_auth = panda_config.get('panda', 'panda_auth') + os.environ['PANDA_AUTH'] = self.panda_auth + if 'PANDA_AUTH_VO' not in os.environ and panda_config.has_option('panda', 'panda_auth_vo'): + self.panda_auth_vo = panda_config.get('panda', 'panda_auth_vo') + os.environ['PANDA_AUTH_VO'] = self.panda_auth_vo + if 'PANDA_CONFIG_ROOT' not in os.environ and panda_config.has_option('panda', 'panda_config_root'): + self.panda_config_root = panda_config.get('panda', 'panda_config_root') + os.environ['PANDA_CONFIG_ROOT'] = self.panda_config_root + 
+ def submit(self, work, logger=None, log_prefix=''): + from pandaclient import Client + + task_params = self.get_task_params(work) + try: + return_code = Client.insertTaskParams(task_params, verbose=True) + if return_code[0] == 0 and return_code[1][0] is True: + try: + task_id = int(return_code[1][1]) + return task_id, None + except Exception as ex: + if logger: + logger.warn(log_prefix + "task id is not returned: (%s) is not a task id: %s" % (return_code[1][1], str(ex))) + if return_code[1][1] and 'jediTaskID=' in return_code[1][1]: + parts = return_code[1][1].split(" ") + for part in parts: + if 'jediTaskID=' in part: + task_id = int(part.split("=")[1]) + return task_id, None + else: + raise Exception(return_code) + else: + if logger: + logger.warn(log_prefix + "submit_panda_task, return_code: %s" % str(return_code)) + raise Exception(return_code) + except Exception as ex: + if logger: + logger.error(log_prefix + str(ex)) + logger.error(traceback.format_exc()) + raise ex + + def get_processing_status(self, task_status): + if task_status in ['registered', 'defined', 'assigning']: + processing_status = ProcessingStatus.Submitting + elif task_status in ['ready', 'scouting', 'scouted', 'prepared', 'topreprocess', 'preprocessing']: + processing_status = ProcessingStatus.Submitting + elif task_status in ['pending']: + processing_status = ProcessingStatus.Submitted + elif task_status in ['running', 'toretry', 'toincexec', 'throttled']: + processing_status = ProcessingStatus.Running + elif task_status in ['done']: + processing_status = ProcessingStatus.Finished + elif task_status in ['finished', 'paused']: + # finished, finishing, waiting it to be done + processing_status = ProcessingStatus.SubFinished + elif task_status in ['failed', 'exhausted']: + # aborting, tobroken + processing_status = ProcessingStatus.Failed + elif task_status in ['aborted']: + # aborting, tobroken + processing_status = ProcessingStatus.Cancelled + elif task_status in ['broken']: + processing_status = ProcessingStatus.Broken + else: + # finished, finishing, aborting, topreprocess, preprocessing, tobroken + # toretry, toincexec, rerefine, paused, throttled, passed + processing_status = ProcessingStatus.Submitted + return processing_status + + def poll(self, workload_id, logger=None, log_prefix=''): + from pandaclient import Client + + try: + status, task_status = Client.getTaskStatus(workload_id) + if status == 0: + return self.get_processing_status(task_status) + else: + msg = "Failed to poll task %s: status: %s, task_status: %s" % (workload_id, status, task_status) + raise Exception(msg) + except Exception as ex: + if logger: + logger.error(log_prefix + str(ex)) + logger.error(traceback.format_exc()) + raise ex diff --git a/main/lib/idds/agents/carrier/poller.py b/main/lib/idds/agents/carrier/poller.py index a2ab47e2..0b9e8a46 100644 --- a/main/lib/idds/agents/carrier/poller.py +++ b/main/lib/idds/agents/carrier/poller.py @@ -6,7 +6,7 @@ # http://www.apache.org/licenses/LICENSE-2.0OA # # Authors: -# - Wen Guan, , 2019 - 2023 +# - Wen Guan, , 2019 - 2024 import datetime import random @@ -14,7 +14,8 @@ import traceback from idds.common import exceptions -from idds.common.constants import Sections, ReturnCode, ProcessingStatus, ProcessingLocking +from idds.common.constants import (Sections, ReturnCode, ProcessingType, + ProcessingStatus, ProcessingLocking) from idds.common.utils import setup_logging, truncate_string, json_dumps from idds.core import processings as core_processings from idds.agents.common.baseagent import
BaseAgent @@ -25,6 +26,7 @@ TerminatedProcessingEvent) from .utils import handle_update_processing, is_process_terminated, is_process_finished +from .iutils import handle_update_iprocessing setup_logging(__name__) @@ -147,8 +149,7 @@ def get_running_processings(self): ProcessingStatus.ToSuspend, ProcessingStatus.Suspending, ProcessingStatus.ToResume, ProcessingStatus.Resuming, ProcessingStatus.ToExpire, ProcessingStatus.Expiring, - ProcessingStatus.ToFinish, ProcessingStatus.ToForceFinish, - ProcessingStatus.Terminating] + ProcessingStatus.ToFinish, ProcessingStatus.ToForceFinish] # next_poll_at = datetime.datetime.utcnow() + datetime.timedelta(seconds=self.poll_period) processings = core_processings.get_processings_by_status(status=processing_status, locking=True, update_poll=True, @@ -199,8 +200,12 @@ def get_work_tag_attribute(self, work_tag, attribute): return work_tag_attribute_value def load_poll_period(self, processing, parameters, new=False): - proc = processing['processing_metadata']['processing'] - work = proc.work + if 'processing' in processing['processing_metadata']: + proc = processing['processing_metadata']['processing'] + work = proc.work + else: + work = processing['processing_metadata']['work'] + work_tag = work.get_work_tag() work_tag_new_poll_period = self.get_work_tag_attribute(work_tag, "new_poll_period") @@ -333,6 +338,9 @@ def handle_update_processing(self, processing): work.reactivate_processing(processing, log_prefix=log_prefix) process_status = ProcessingStatus.Running new_process_status = ProcessingStatus.Running + else: + if (update_contents or new_contents or new_contents_ext or update_contents_ext or ret_msgs): + new_process_status = ProcessingStatus.Synchronizing update_processing = {'processing_id': processing['processing_id'], 'parameters': {'status': new_process_status, @@ -413,6 +421,113 @@ def handle_update_processing(self, processing): 'update_contents': []} return ret + def handle_update_iprocessing(self, processing): + try: + log_prefix = self.get_log_prefix(processing) + + executors, plugin = None, None + if processing['processing_type']: + plugin_name = processing['processing_type'].name.lower() + '_poller' + plugin = self.get_plugin(plugin_name) + else: + raise exceptions.ProcessSubmitFailed('No corresponding submitter plugins for %s' % processing['processing_type']) + + ret_handle_update_processing = handle_update_iprocessing(processing, + self.agent_attributes, + plugin=plugin, + max_updates_per_round=self.max_updates_per_round, + executors=executors, + logger=self.logger, + log_prefix=log_prefix) + + process_status, new_contents, new_input_dependency_contents, ret_msgs, update_contents, parameters, new_contents_ext, update_contents_ext = ret_handle_update_processing + + new_process_status = process_status + if is_process_terminated(process_status): + new_process_status = ProcessingStatus.Terminating + if is_process_finished(process_status): + new_process_status = ProcessingStatus.Terminating + else: + new_process_status = ProcessingStatus.Terminating + + update_processing = {'processing_id': processing['processing_id'], + 'parameters': {'status': new_process_status, + 'substatus': process_status, + 'locking': ProcessingLocking.Idle}} + + update_processing['parameters'] = self.load_poll_period(processing, update_processing['parameters']) + + if 'submitted_at' in processing['processing_metadata']: + if not processing['submitted_at'] or processing['submitted_at'] < processing['processing_metadata']['submitted_at']: + parameters['submitted_at'] = 
processing['processing_metadata']['submitted_at'] + + if 'workload_id' in processing['processing_metadata']: + parameters['workload_id'] = processing['processing_metadata']['workload_id'] + + # update_processing['parameters']['expired_at'] = work.get_expired_at(processing) + update_processing['parameters']['processing_metadata'] = processing['processing_metadata'] + + if parameters: + # special parameters such as 'output_metadata' + for p in parameters: + update_processing['parameters'][p] = parameters[p] + + ret = {'update_processing': update_processing, + 'update_contents': update_contents, + 'new_contents': new_contents, + 'new_input_dependency_contents': new_input_dependency_contents, + 'messages': ret_msgs, + 'new_contents_ext': new_contents_ext, + 'update_contents_ext': update_contents_ext, + 'processing_status': new_process_status} + except exceptions.ProcessFormatNotSupported as ex: + self.logger.error(ex) + self.logger.error(traceback.format_exc()) + + retries = processing['update_retries'] + 1 + if not processing['max_update_retries'] or retries < processing['max_update_retries']: + proc_status = ProcessingStatus.Running + else: + proc_status = ProcessingStatus.Failed + error = {'update_err': {'msg': truncate_string('%s' % (ex), length=200)}} + + # increase poll period + update_poll_period = int(processing['update_poll_period'].total_seconds() * self.poll_period_increase_rate) + if update_poll_period > self.max_update_poll_period: + update_poll_period = self.max_update_poll_period + + update_processing = {'processing_id': processing['processing_id'], + 'parameters': {'status': proc_status, + 'locking': ProcessingLocking.Idle, + 'update_retries': retries, + 'update_poll_period': update_poll_period, + 'errors': processing['errors'] if processing['errors'] else {}}} + update_processing['parameters']['errors'].update(error) + + ret = {'update_processing': update_processing, + 'update_contents': []} + except Exception as ex: + self.logger.error(ex) + self.logger.error(traceback.format_exc()) + + retries = processing['update_retries'] + 1 + if not processing['max_update_retries'] or retries < processing['max_update_retries']: + proc_status = ProcessingStatus.Running + else: + proc_status = ProcessingStatus.Failed + error = {'update_err': {'msg': truncate_string('%s' % (ex), length=200)}} + update_processing = {'processing_id': processing['processing_id'], + 'parameters': {'status': proc_status, + 'locking': ProcessingLocking.Idle, + 'update_retries': retries, + 'errors': processing['errors'] if processing['errors'] else {}}} + update_processing['parameters']['errors'].update(error) + update_processing['parameters'] = self.load_poll_period(processing, update_processing['parameters']) + + ret = {'update_processing': update_processing, + 'update_contents': []} + return ret + def process_update_processing(self, event): self.number_workers += 1 pro_ret = ReturnCode.Ok.value @@ -429,7 +544,10 @@ def process_update_processing(self, event): log_pre = self.get_log_prefix(pr) self.logger.info(log_pre + "process_update_processing") - ret = self.handle_update_processing(pr) + if pr['processing_type'] and pr['processing_type'] in [ProcessingType.iWorkflow, ProcessingType.iWork]: + ret = self.handle_update_iprocessing(pr) + else: + ret = self.handle_update_processing(pr) # self.logger.info(log_pre + "process_update_processing result: %s" % str(ret)) self.update_processing(ret, pr) @@ -452,11 +570,7 @@ def process_update_processing(self, event): event.set_terminating() self.event_bus.send(event) 
else: - if (('update_contents' in ret and ret['update_contents']) - or ('new_contents' in ret and ret['new_contents']) # noqa W503 - or ('new_contents_ext' in ret and ret['new_contents_ext']) # noqa W503 - or ('update_contents_ext' in ret and ret['update_contents_ext']) # noqa W503 - or ('messages' in ret and ret['messages'])): # noqa E129 + if 'processing_status' in ret and ret['processing_status'] == ProcessingStatus.Synchronizing: self.logger.info(log_pre + "SyncProcessingEvent(processing_id: %s)" % pr['processing_id']) event = SyncProcessingEvent(publisher_id=self.id, processing_id=pr['processing_id'], counter=original_event._counter) diff --git a/main/lib/idds/agents/carrier/submitter.py b/main/lib/idds/agents/carrier/submitter.py index 83fde5cf..782f3b09 100644 --- a/main/lib/idds/agents/carrier/submitter.py +++ b/main/lib/idds/agents/carrier/submitter.py @@ -6,12 +6,12 @@ # http://www.apache.org/licenses/LICENSE-2.0OA # # Authors: -# - Wen Guan, , 2019 - 2022 +# - Wen Guan, , 2019 - 2024 import traceback from idds.common import exceptions -from idds.common.constants import ProcessingStatus, ProcessingLocking +from idds.common.constants import ProcessingType, ProcessingStatus, ProcessingLocking from idds.common.utils import setup_logging, truncate_string from idds.core import processings as core_processings from idds.agents.common.baseagent import BaseAgent @@ -21,6 +21,7 @@ UpdateTransformEvent) from .utils import handle_new_processing +from .iutils import handle_new_iprocessing from .poller import Poller setup_logging(__name__) @@ -178,6 +179,80 @@ def handle_new_processing(self, processing): 'update_contents': []} return ret + def handle_new_iprocessing(self, processing): + try: + log_prefix = self.get_log_prefix(processing) + + # transform_id = processing['transform_id'] + # transform = core_transforms.get_transform(transform_id=transform_id) + # work = transform['transform_metadata']['work'] + executors, plugin = None, None + if processing['processing_type']: + plugin_name = processing['processing_type'].name.lower() + '_submitter' + plugin = self.get_plugin(plugin_name) + else: + raise exceptions.ProcessSubmitFailed('No corresponding submitter plugins for %s' % processing['processing_type']) + ret_new_processing = handle_new_iprocessing(processing, + self.agent_attributes, + plugin=plugin, + func_site_to_cloud=self.get_site_to_cloud, + max_updates_per_round=self.max_updates_per_round, + executors=executors, + logger=self.logger, + log_prefix=log_prefix) + status, processing, update_colls, new_contents, new_input_dependency_contents, msgs, errors = ret_new_processing + + if not status: + raise exceptions.ProcessSubmitFailed(str(errors)) + + parameters = {'status': ProcessingStatus.Submitting, + 'substatus': ProcessingStatus.Submitting, + 'locking': ProcessingLocking.Idle, + 'processing_metadata': processing['processing_metadata']} + parameters = self.load_poll_period(processing, parameters, new=True) + + if 'submitted_at' in processing: + parameters['submitted_at'] = processing['submitted_at'] + + if 'workload_id' in processing: + parameters['workload_id'] = processing['workload_id'] + + update_processing = {'processing_id': processing['processing_id'], + 'parameters': parameters} + ret = {'update_processing': update_processing, + 'update_collections': update_colls, + 'update_contents': [], + 'new_contents': new_contents, + 'new_input_dependency_contents': new_input_dependency_contents, + 'messages': msgs, + } + except Exception as ex: + self.logger.error(ex) + 
self.logger.error(traceback.format_exc()) + + retries = processing['new_retries'] + 1 + if not processing['max_new_retries'] or retries < processing['max_new_retries']: + pr_status = processing['status'] + else: + pr_status = ProcessingStatus.Failed + # increase poll period + new_poll_period = int(processing['new_poll_period'].total_seconds() * self.poll_period_increase_rate) + if new_poll_period > self.max_new_poll_period: + new_poll_period = self.max_new_poll_period + + error = {'submit_err': {'msg': truncate_string('%s' % str(ex), length=200)}} + parameters = {'status': pr_status, + 'new_poll_period': new_poll_period, + 'errors': processing['errors'] if processing['errors'] else {}, + 'new_retries': retries} + parameters['errors'].update(error) + + update_processing = {'processing_id': processing['processing_id'], + 'parameters': parameters} + ret = {'update_processing': update_processing, + 'update_contents': []} + return ret + def process_new_processing(self, event): self.number_workers += 1 try: @@ -190,7 +265,10 @@ def process_new_processing(self, event): else: log_pre = self.get_log_prefix(pr) self.logger.info(log_pre + "process_new_processing") - ret = self.handle_new_processing(pr) + if pr['processing_type'] and pr['processing_type'] in [ProcessingType.iWorkflow, ProcessingType.iWork]: + ret = self.handle_new_iprocessing(pr) + else: + ret = self.handle_new_processing(pr) # self.logger.info(log_pre + "process_new_processing result: %s" % str(ret)) self.update_processing(ret, pr) diff --git a/main/lib/idds/agents/clerk/clerk.py b/main/lib/idds/agents/clerk/clerk.py index 2a135c2b..213cbf7c 100644 --- a/main/lib/idds/agents/clerk/clerk.py +++ b/main/lib/idds/agents/clerk/clerk.py @@ -15,11 +15,12 @@ from idds.common import exceptions from idds.common.constants import (Sections, ReturnCode, - RequestStatus, RequestLocking, + RequestType, RequestStatus, RequestLocking, + TransformType, WorkflowType, TransformStatus, ProcessingStatus, ContentStatus, ContentRelationType, CommandType, CommandStatus, CommandLocking) -from idds.common.utils import setup_logging, truncate_string +from idds.common.utils import setup_logging, truncate_string, str_to_date from idds.core import (requests as core_requests, transforms as core_transforms, processings as core_processings, @@ -402,11 +403,14 @@ def get_work_tag_attribute(self, work_tag, attribute): work_tag_attribute_value = int(getattr(self, work_tag_attribute)) return work_tag_attribute_value - def generate_transform(self, req, work, build=False): - if build: - wf = req['request_metadata']['build_workflow'] + def generate_transform(self, req, work, build=False, iworkflow=False): + if iworkflow: + wf = None else: - wf = req['request_metadata']['workflow'] + if build: + wf = req['request_metadata']['build_workflow'] + else: + wf = req['request_metadata']['workflow'] work.set_request_id(req['request_id']) work.username = req['username'] @@ -430,13 +434,25 @@ def generate_transform(self, req, work, build=False): else: max_update_retries = self.max_update_retries + transform_type = TransformType.Workflow + try: + work_type = work.get_work_type() + if work_type in [WorkflowType.iWorkflow]: + transform_type = TransformType.iWorkflow + elif work_type in [WorkflowType.iWork]: + transform_type = TransformType.iWork + except Exception: + pass + new_transform = {'request_id': req['request_id'], 'workload_id': req['workload_id'], - 'transform_type': work.get_work_type(), + 'transform_type': transform_type, 'transform_tag': work.get_work_tag(), 'priority': 
req['priority'], 'status': TransformStatus.New, 'retries': 0, + 'parent_transform_id': None, + 'previous_transform_id': None, 'name': work.get_work_name(), 'new_poll_period': self.new_poll_period, 'update_poll_period': self.update_poll_period, @@ -713,6 +729,57 @@ def handle_new_request(self, req): self.logger.warn(log_pre + "Handle new request error result: %s" % str(ret_req)) return ret_req + def handle_new_irequest(self, req): + try: + log_pre = self.get_log_prefix(req) + self.logger.info(log_pre + "Handle new irequest") + to_throttle = self.whether_to_throttle(req) + if to_throttle: + ret_req = {'request_id': req['request_id'], + 'parameters': {'status': RequestStatus.Throttling, + 'locking': RequestLocking.Idle}} + ret_req['parameters'] = self.load_poll_period(req, ret_req['parameters'], throttling=True) + self.logger.info(log_pre + "Throttle new irequest result: %s" % str(ret_req)) + else: + workflow = req['request_metadata']['workflow'] + + transforms = [] + transform = self.generate_transform(req, workflow) + transforms.append(transform) + self.logger.debug(log_pre + "Processing request(%s): new transforms: %s" % (req['request_id'], + str(transforms))) + ret_req = {'request_id': req['request_id'], + 'parameters': {'status': RequestStatus.Transforming, + 'locking': RequestLocking.Idle}, + 'new_transforms': transforms} + ret_req['parameters'] = self.load_poll_period(req, ret_req['parameters']) + self.logger.info(log_pre + "Handle new irequest result: %s" % str(ret_req)) + except Exception as ex: + self.logger.error(ex) + self.logger.error(traceback.format_exc()) + retries = req['new_retries'] + 1 + if not req['max_new_retries'] or retries < req['max_new_retries']: + req_status = req['status'] + else: + req_status = RequestStatus.Failed + + # increase poll period + new_poll_period = int(req['new_poll_period'].total_seconds() * self.poll_period_increase_rate) + if new_poll_period > self.max_new_poll_period: + new_poll_period = self.max_new_poll_period + + error = {'submit_err': {'msg': truncate_string('%s' % (ex), length=200)}} + + ret_req = {'request_id': req['request_id'], + 'parameters': {'status': req_status, + 'locking': RequestLocking.Idle, + 'new_retries': retries, + 'new_poll_period': new_poll_period, + 'errors': req['errors'] if req['errors'] else {}}} + ret_req['parameters']['errors'].update(error) + self.logger.warn(log_pre + "Handle new irequest error result: %s" % str(ret_req)) + return ret_req + def has_to_build_work(self, req): try: if req['status'] in [RequestStatus.New] and 'build_workflow' in req['request_metadata']: @@ -856,6 +923,8 @@ def process_new_request(self, event): log_pre = self.get_log_prefix(req) if self.has_to_build_work(req): ret = self.handle_build_request(req) + elif req['request_type'] in [RequestType.iWorkflow]: + ret = self.handle_new_irequest(req) else: ret = self.handle_new_request(req) new_tf_ids, update_tf_ids = self.update_request(ret) @@ -1082,7 +1151,102 @@ def handle_update_request(self, req, event): 'errors': req['errors'] if req['errors'] else {}}} ret_req['parameters']['errors'].update(error) log_pre = self.get_log_prefix(req) - self.logger.warn(log_pre + "Handle new request exception result: %s" % str(ret_req)) + self.logger.warn(log_pre + "Handle update request exception result: %s" % str(ret_req)) + return ret_req + + def is_to_expire(self, expired_at=None, pending_time=None, request_id=None): + if expired_at: + if type(expired_at) in [str]: + expired_at = str_to_date(expired_at) + if expired_at < datetime.datetime.utcnow(): + 
self.logger.info("Request(%s) expired_at(%s) is smaller than utc now(%s), expiring" % (request_id, + expired_at, + datetime.datetime.utcnow())) + return True + return False + + def handle_update_irequest_real(self, req, event): + """ + process running request + """ + log_pre = self.get_log_prefix(req) + self.logger.info(log_pre + " handle_update_irequest: request_id: %s" % req['request_id']) + + tfs = core_transforms.get_transforms(request_id=req['request_id']) + total_tfs, finished_tfs, subfinished_tfs, failed_tfs = 0, 0, 0, 0 + for tf in tfs: + total_tfs += 1 + if tf['status'] in [TransformStatus.Finished, TransformStatus.Built]: + finished_tfs += 1 + elif tf['status'] in [TransformStatus.SubFinished]: + subfinished_tfs += 1 + elif tf['status'] in [TransformStatus.Failed, TransformStatus.Cancelled, + TransformStatus.Suspended, TransformStatus.Expired]: + failed_tfs += 1 + + req_status = RequestStatus.Transforming + if total_tfs == finished_tfs: + req_status = RequestStatus.Finished + elif total_tfs == finished_tfs + subfinished_tfs + failed_tfs: + if finished_tfs + subfinished_tfs > 0: + req_status = RequestStatus.SubFinished + else: + req_status = RequestStatus.Failed + + log_msg = log_pre + "ireqeust %s status: %s" % (req['request_id'], req_status) + log_msg = log_msg + "(transforms: total %s, finished: %s, subfinished: %s, failed %s)" % (total_tfs, finished_tfs, subfinished_tfs, failed_tfs) + self.logger.debug(log_msg) + + if req_status not in [RequestStatus.Finished, RequestStatus.SubFinished, RequestStatus.Failed]: + if self.is_to_expire(req['expired_at'], self.pending_time, request_id=req['request_id']): + event_content = {'request_id': req['request_id'], + 'cmd_type': CommandType.ExpireRequest, + 'cmd_content': {}} + self.logger.debug(log_pre + "ExpireRequestEvent(request_id: %s)" % req['request_id']) + event = ExpireRequestEvent(publisher_id=self.id, request_id=req['request_id'], content=event_content) + self.event_bus.send(event) + + parameters = {'status': req_status, + 'locking': RequestLocking.Idle, + 'request_metadata': req['request_metadata'] + } + parameters = self.load_poll_period(req, parameters) + + ret = {'request_id': req['request_id'], + 'parameters': parameters} + self.logger.info(log_pre + "Handle update irequest result: %s" % str(ret)) + return ret + + def handle_update_irequest(self, req, event): + """ + process running irequest + """ + try: + ret_req = self.handle_update_irequest_real(req, event) + except Exception as ex: + self.logger.error(ex) + self.logger.error(traceback.format_exc()) + retries = req['update_retries'] + 1 + if not req['max_update_retries'] or retries < req['max_update_retries']: + req_status = req['status'] + else: + req_status = RequestStatus.Failed + error = {'update_err': {'msg': truncate_string('%s' % (ex), length=200)}} + + # increase poll period + update_poll_period = int(req['update_poll_period'].total_seconds() * self.poll_period_increase_rate) + if update_poll_period > self.max_update_poll_period: + update_poll_period = self.max_update_poll_period + + ret_req = {'request_id': req['request_id'], + 'parameters': {'status': req_status, + 'locking': RequestLocking.Idle, + 'update_retries': retries, + 'update_poll_period': update_poll_period, + 'errors': req['errors'] if req['errors'] else {}}} + ret_req['parameters']['errors'].update(error) + log_pre = self.get_log_prefix(req) + self.logger.warn(log_pre + "Handle update irequest exception result: %s" % str(ret_req)) return ret_req def process_update_request(self, event): @@ -1105,7 
+1269,10 @@ def process_update_request(self, event): pro_ret = ReturnCode.Locked.value else: log_pre = self.get_log_prefix(req) - ret = self.handle_update_request(req, event=event) + if req['request_type'] in [RequestType.iWorkflow]: + ret = self.handle_update_irequest(req, event=event) + else: + ret = self.handle_update_request(req, event=event) new_tf_ids, update_tf_ids = self.update_request(ret) for tf_id in new_tf_ids: self.logger.info(log_pre + "NewTransformEvent(transform_id: %s)" % tf_id) diff --git a/main/lib/idds/agents/common/baseagent.py b/main/lib/idds/agents/common/baseagent.py index 578b9721..9759ab03 100644 --- a/main/lib/idds/agents/common/baseagent.py +++ b/main/lib/idds/agents/common/baseagent.py @@ -15,6 +15,7 @@ import threading import uuid +from idds.common import exceptions from idds.common.constants import Sections from idds.common.constants import (MessageType, MessageTypeStr, MessageStatus, MessageSource, @@ -133,6 +134,7 @@ def load_plugin_sequence(self): def load_plugins(self): self.plugins = load_plugins(self.config_section, logger=self.logger) + self.logger.info("plugins: %s" % str(self.plugins)) """ for plugin_name in self.plugin_sequence: if plugin_name not in self.plugins: @@ -142,6 +144,11 @@ def load_plugins(self): raise AgentPluginError("Plugin %s is defined but it is not defined in plugin_sequence" % plugin_name) """ + def get_plugin(self, plugin_name): + if plugin_name in self.plugins and self.plugins[plugin_name]: + return self.plugins[plugin_name] + raise exceptions.AgentPluginError("No corresponding plugin configured for %s" % plugin_name) + def get_num_hang_active_workers(self): return self.num_hang_workers, self.num_active_workers diff --git a/main/lib/idds/agents/common/eventbus/msgeventbusbackend.py b/main/lib/idds/agents/common/eventbus/msgeventbusbackend.py index b9104e67..ccb4848c 100644 --- a/main/lib/idds/agents/common/eventbus/msgeventbusbackend.py +++ b/main/lib/idds/agents/common/eventbus/msgeventbusbackend.py @@ -417,12 +417,14 @@ def send(self, event): return ret except (zmq.error.ZMQError, zmq.Again) as error: - self.logger.critical("Caught an exception: %s\n%s" % (str(error), traceback.format_exc())) + if not self.graceful_stop.is_set(): + self.logger.critical("Caught an exception: %s\n%s" % (str(error), traceback.format_exc())) self.manager_socket.close() self.cache_events.append(event) self.num_failures += 1 except Exception as error: - self.logger.critical("Caught an exception: %s\n%s" % (str(error), traceback.format_exc())) + if not self.graceful_stop.is_set(): + self.logger.critical("Caught an exception: %s\n%s" % (str(error), traceback.format_exc())) self.manager_socket.close() self.cache_events.append(event) self.num_failures += 1 @@ -458,13 +460,15 @@ def send_bulk(self, events): return ret except (zmq.error.ZMQError, zmq.Again) as error: - self.logger.critical("Caught an exception: %s\n%s" % (str(error), traceback.format_exc())) + if not self.graceful_stop.is_set(): + self.logger.critical("Caught an exception: %s\n%s" % (str(error), traceback.format_exc())) self.manager_socket.close() for event in events: self.cache_events.append(event) self.num_failures += 1 except Exception as error: - self.logger.critical("Caught an exception: %s\n%s" % (str(error), traceback.format_exc())) + if not self.graceful_stop.is_set(): + self.logger.critical("Caught an exception: %s\n%s" % (str(error), traceback.format_exc())) self.manager_socket.close() for event in events: self.cache_events.append(event) @@ -500,11 +504,13 @@ def get(self, 
event_type, num_events=1, wait=0): return ret except (zmq.error.ZMQError, zmq.Again) as error: - self.logger.critical("Caught an exception: %s\n%s" % (str(error), traceback.format_exc())) + if not self.graceful_stop.is_set(): + self.logger.critical("Caught an exception: %s\n%s" % (str(error), traceback.format_exc())) self.manager_socket.close() self.num_failures += 1 except Exception as error: - self.logger.critical("Caught an exception: %s\n%s" % (str(error), traceback.format_exc())) + if not self.graceful_stop.is_set(): + self.logger.critical("Caught an exception: %s\n%s" % (str(error), traceback.format_exc())) self.manager_socket.close() self.num_failures += 1 return [] diff --git a/main/lib/idds/agents/transformer/transformer.py b/main/lib/idds/agents/transformer/transformer.py index ab0d1863..40a02eed 100644 --- a/main/lib/idds/agents/transformer/transformer.py +++ b/main/lib/idds/agents/transformer/transformer.py @@ -6,7 +6,7 @@ # http://www.apache.org/licenses/LICENSE-2.0OA # # Authors: -# - Wen Guan, , 2019 - 2023 +# - Wen Guan, , 2019 - 2024 import copy import datetime @@ -15,9 +15,13 @@ import traceback from idds.common import exceptions -from idds.common.constants import (Sections, ReturnCode, +from idds.common.constants import (Sections, ReturnCode, TransformType, TransformStatus, TransformLocking, - CommandType, ProcessingStatus) + CollectionType, CollectionStatus, + CollectionRelationType, + CommandType, ProcessingStatus, WorkflowType, + get_processing_type_from_transform_type, + get_transform_status_from_processing_status) from idds.common.utils import setup_logging, truncate_string from idds.core import (transforms as core_transforms, processings as core_processings) @@ -217,6 +221,7 @@ def generate_processing_model(self, transform): # new_processing_model['expired_at'] = work.get_expired_at(None) new_processing_model['expired_at'] = transform['expired_at'] + new_processing_model['processing_type'] = get_processing_type_from_transform_type(transform['transform_type']) new_processing_model['new_poll_period'] = transform['new_poll_period'] new_processing_model['update_poll_period'] = transform['update_poll_period'] new_processing_model['max_new_retries'] = transform['max_new_retries'] @@ -312,6 +317,105 @@ def handle_new_transform(self, transform): self.logger.info(log_pre + "handle_new_transform exception result: %s" % str(ret)) return ret + def generate_collection(self, transform, collection, relation_type=CollectionRelationType.Input): + coll = {'transform_id': transform['transform_id'], + 'request_id': transform['request_id'], + 'workload_id': transform['workload_id'], + 'coll_type': CollectionType.Dataset, + 'scope': collection['scope'], + 'name': collection['name'][:254], + 'relation_type': relation_type, + 'bytes': 0, + 'total_files': 0, + 'new_files': 0, + 'processed_files': 0, + 'processing_files': 0, + 'coll_metadata': None, + 'status': CollectionStatus.Open, + 'expired_at': transform['expired_at']} + return coll + + def handle_new_itransform_real(self, transform): + """ + Process new transform + """ + log_pre = self.get_log_prefix(transform) + self.logger.info(log_pre + "handle_new_itransform: transform_id: %s" % transform['transform_id']) + + work = transform['transform_metadata']['work'] + if work.type in [WorkflowType.iWork]: + work.transform_id = transform['transform_id'] + + # create processing + new_processing_model = self.generate_processing_model(transform) + new_processing_model['processing_metadata'] = {'work': work} + + transform_parameters = {'status': 
TransformStatus.Transforming, + 'locking': TransformLocking.Idle, + 'workload_id': transform['workload_id']} + + transform_parameters = self.load_poll_period(transform, transform_parameters) + + if new_processing_model is not None: + if 'new_poll_period' in transform_parameters: + new_processing_model['new_poll_period'] = transform_parameters['new_poll_period'] + if 'update_poll_period' in transform_parameters: + new_processing_model['update_poll_period'] = transform_parameters['update_poll_period'] + if 'max_new_retries' in transform_parameters: + new_processing_model['max_new_retries'] = transform_parameters['max_new_retries'] + if 'max_update_retries' in transform_parameters: + new_processing_model['max_update_retries'] = transform_parameters['max_update_retries'] + + func_name = work.get_func_name() + func_name = func_name.split(':')[-1] + input_coll = {'scope': 'pseudo_dataset', 'name': 'pseudo_input_%s' % func_name} + output_coll = {'scope': 'pseudo_dataset', 'name': 'pseudo_output_%s' % func_name} + + input_collection = self.generate_collection(transform, input_coll, relation_type=CollectionRelationType.Input) + output_collection = self.generate_collection(transform, output_coll, relation_type=CollectionRelationType.Output) + + ret = {'transform': transform, + 'transform_parameters': transform_parameters, + 'new_processing': new_processing_model, + 'input_collections': [input_collection], + 'output_collections': [output_collection] + } + return ret + + def handle_new_itransform(self, transform): + """ + Process new transform + """ + try: + log_pre = self.get_log_prefix(transform) + ret = self.handle_new_itransform_real(transform) + self.logger.info(log_pre + "handle_new_itransform result: %s" % str(ret)) + except Exception as ex: + self.logger.error(ex) + self.logger.error(traceback.format_exc()) + retries = transform['new_retries'] + 1 + if not transform['max_new_retries'] or retries < transform['max_new_retries']: + tf_status = transform['status'] + else: + tf_status = TransformStatus.Failed + + # increase poll period + new_poll_period = int(transform['new_poll_period'].total_seconds() * self.poll_period_increase_rate) + if new_poll_period > self.max_new_poll_period: + new_poll_period = self.max_new_poll_period + + error = {'submit_err': {'msg': truncate_string('%s' % (ex), length=200)}} + + transform_parameters = {'status': tf_status, + 'new_retries': retries, + 'new_poll_period': new_poll_period, + 'errors': transform['errors'] if transform['errors'] else {}, + 'locking': TransformLocking.Idle} + transform_parameters['errors'].update(error) + ret = {'transform': transform, 'transform_parameters': transform_parameters} + self.logger.info(log_pre + "handle_new_itransform exception result: %s" % str(ret)) + return ret + def update_transform(self, ret): new_pr_ids, update_pr_ids = [], [] try: @@ -396,7 +500,10 @@ def process_new_transform(self, event): else: log_pre = self.get_log_prefix(tf) self.logger.info(log_pre + "process_new_transform") - ret = self.handle_new_transform(tf) + if tf['transform_type'] in [TransformType.iWorkflow, TransformType.iWork]: + ret = self.handle_new_itransform(tf) + else: + ret = self.handle_new_transform(tf) self.logger.info(log_pre + "process_new_transform result: %s" % str(ret)) new_pr_ids, update_pr_ids = self.update_transform(ret) @@ -540,6 +647,90 @@ def handle_update_transform(self, transform, event): self.logger.warn(log_pre + "handle_update_transform exception result: %s" % str(ret)) return ret, False, None + def 
handle_update_itransform_real(self, transform, event): + """ + process running transforms + """ + log_pre = self.get_log_prefix(transform) + + self.logger.info(log_pre + "handle_update_itransform: transform_id: %s" % transform['transform_id']) + + # work = transform['transform_metadata']['work'] + + prs = core_processings.get_processings(transform_id=transform['transform_id']) + pr = None + for pr in prs: + if pr['processing_id'] == transform['current_processing_id']: + transform['workload_id'] = pr['workload_id'] + break + + errors = None + if pr: + transform['status'] = get_transform_status_from_processing_status(pr['status']) + log_msg = log_pre + "transform id: %s, transform status: %s" % (transform['transform_id'], transform['status']) + log_msg = log_msg + ", processing id: %s, processing status: %s" % (pr['processing_id'], pr['status']) + self.logger.info(log_msg) + else: + transform['status'] = TransformStatus.Failed + log_msg = log_pre + "transform id: %s, transform status: %s" % (transform['transform_id'], transform['status']) + log_msg = log_msg + ", no attached processings." + self.logger.error(log_msg) + errors = {'submit_err': 'no attached processings'} + + is_terminated = False + if transform['status'] in [TransformStatus.Finished, TransformStatus.Failed, TransformStatus.Cancelled, + TransformStatus.SubFinished, TransformStatus.Suspended, TransformStatus.Expired]: + is_terminated = True + + transform_parameters = {'status': transform['status'], + 'locking': TransformLocking.Idle, + 'workload_id': transform['workload_id']} + transform_parameters = self.load_poll_period(transform, transform_parameters) + if errors: + transform_parameters['errors'] = errors + + ret = {'transform': transform, + 'transform_parameters': transform_parameters} + return ret, is_terminated, None + + def handle_update_itransform(self, transform, event): + """ + Process running transform + """ + try: + log_pre = self.get_log_prefix(transform) + + self.logger.info(log_pre + "handle_update_itransform: %s" % transform) + ret, is_terminated, ret_processing_id = self.handle_update_itransform_real(transform, event) + self.logger.info(log_pre + "handle_update_itransform result: %s" % str(ret)) + return ret, is_terminated, ret_processing_id + except Exception as ex: + self.logger.error(ex) + self.logger.error(traceback.format_exc()) + + retries = transform['update_retries'] + 1 + if not transform['max_update_retries'] or retries < transform['max_update_retries']: + tf_status = transform['status'] + else: + tf_status = TransformStatus.Failed + error = {'submit_err': {'msg': truncate_string('%s' % (ex), length=200)}} + + # increase poll period + update_poll_period = int(transform['update_poll_period'].total_seconds() * self.poll_period_increase_rate) + if update_poll_period > self.max_update_poll_period: + update_poll_period = self.max_update_poll_period + + transform_parameters = {'status': tf_status, + 'update_retries': retries, + 'update_poll_period': update_poll_period, + 'errors': transform['errors'] if transform['errors'] else {}, + 'locking': TransformLocking.Idle} + transform_parameters['errors'].update(error) + + ret = {'transform': transform, 'transform_parameters': transform_parameters} + self.logger.warn(log_pre + "handle_update_itransform exception result: %s" % str(ret)) + return ret, False, None + def process_update_transform(self, event): self.number_workers += 1 pro_ret = ReturnCode.Ok.value @@ -559,7 +750,10 @@ def process_update_transform(self, event): else: log_pre = self.get_log_prefix(tf) - 
ret, is_terminated, ret_processing_id = self.handle_update_transform(tf, event) + if tf['transform_type'] in [TransformType.iWorkflow, TransformType.iWork]: + ret, is_terminated, ret_processing_id = self.handle_update_itransform(tf, event) + else: + ret, is_terminated, ret_processing_id = self.handle_update_transform(tf, event) new_pr_ids, update_pr_ids = self.update_transform(ret) if is_terminated or (event._content and 'event' in event._content and event._content['event'] == 'submitted'): diff --git a/main/lib/idds/core/authentication.py b/main/lib/idds/core/authentication.py new file mode 100644 index 00000000..018ecd9d --- /dev/null +++ b/main/lib/idds/core/authentication.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0OA +# +# Authors: +# - Wen Guan, , 2024 + +import base64 +import json +import jwt + +# from cryptography import x509 +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization + +from idds.common import authentication + + +def decode_value(val): + if isinstance(val, str): + val = val.encode() + decoded = base64.urlsafe_b64decode(val + b'==') + return int.from_bytes(decoded, 'big') + + +class OIDCAuthentication(authentication.OIDCAuthentication): + def __init__(self, timeout=None): + super(OIDCAuthentication, self).__init__(timeout=timeout) + + def get_public_key(self, token, jwks_uri, no_verify=False): + headers = jwt.get_unverified_header(token) + if headers is None or 'kid' not in headers: + raise jwt.exceptions.InvalidTokenError('cannot extract kid from headers') + kid = headers['kid'] + + jwks = self.get_cache_value(jwks_uri) + if not jwks: + jwks_content = self.get_http_content(jwks_uri, no_verify=no_verify) + jwks = json.loads(jwks_content) + self.set_cache_value(jwks_uri, jwks) + + jwk = None + for j in jwks.get('keys', []): + if j.get('kid') == kid: + jwk = j + if jwk is None: + raise jwt.exceptions.InvalidTokenError('JWK not found for kid={0}: {1}'.format(kid, str(jwks))) + + public_num = RSAPublicNumbers(n=decode_value(jwk['n']), e=decode_value(jwk['e'])) + public_key = public_num.public_key(default_backend()) + pem = public_key.public_bytes(encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo) + return pem + + def verify_id_token(self, vo, token): + try: + auth_config, endpoint_config = self.get_auth_endpoint_config(vo) + + # check audience + decoded_token = jwt.decode(token, verify=False, options={"verify_signature": False}) + audience = decoded_token['aud'] + if audience not in [auth_config['audience'], auth_config['client_id']]: + # discovery_endpoint = auth_config['oidc_config_url'] + return False, "The audience %s of the token doesn't match vo configuration(client_id: %s)." 
% (audience, auth_config['client_id']), None + + public_key = self.get_public_key(token, endpoint_config['jwks_uri'], no_verify=auth_config['no_verify']) + # decode token only with RS256 + if 'iss' in decoded_token and decoded_token['iss'] and decoded_token['iss'] != endpoint_config['issuer'] and endpoint_config['issuer'].startswith(decoded_token['iss']): + # iss is missing the last '/' in access tokens + issuer = decoded_token['iss'] + else: + issuer = endpoint_config['issuer'] + + decoded = jwt.decode(token, public_key, verify=True, algorithms='RS256', + audience=audience, issuer=issuer) + decoded['vo'] = vo + if 'name' in decoded: + username = decoded['name'] + else: + username = None + return True, decoded, username + except Exception as error: + return False, 'Failed to verify oidc token: ' + str(error), None + + +class OIDCAuthenticationUtils(authentication.OIDCAuthenticationUtils): + def __init__(self): + super(OIDCAuthenticationUtils, self).__init__() + + +class X509Authentication(authentication.X509Authentication): + def __init__(self, timeout=None): + super(X509Authentication, self).__init__(timeout=timeout) + + +def get_user_name_from_dn1(dn): + return authentication.get_user_name_from_dn1(dn) + + +def get_user_name_from_dn2(dn): + return authentication.get_user_name_from_dn2(dn) + + +def get_user_name_from_dn(dn): + dn = get_user_name_from_dn1(dn) + dn = get_user_name_from_dn2(dn) + return dn + + +def authenticate_x509(vo, dn, client_cert): + return authentication.authenticate_x509(vo, dn, client_cert) + + +def authenticate_oidc(vo, token): + oidc_auth = OIDCAuthentication() + status, data, username = oidc_auth.verify_id_token(vo, token) + if status: + return status, data, username + else: + return status, data, username + + +def authenticate_is_super_user(username, dn=None): + return authentication.authenticate_is_super_user(username=username, dn=dn) diff --git a/main/lib/idds/core/processings.py b/main/lib/idds/core/processings.py index 5c35c886..80075a1b 100644 --- a/main/lib/idds/core/processings.py +++ b/main/lib/idds/core/processings.py @@ -16,7 +16,7 @@ import datetime from idds.orm.base.session import read_session, transactional_session -from idds.common.constants import ProcessingLocking, ProcessingStatus, GranularityType, ContentRelationType +from idds.common.constants import ProcessingLocking, ProcessingStatus, ProcessingType, GranularityType, ContentRelationType from idds.common.utils import get_list_chunks from idds.orm import (processings as orm_processings, collections as orm_collections, @@ -29,6 +29,7 @@ def add_processing(request_id, workload_id, transform_id, status, submitter=None, substatus=ProcessingStatus.New, granularity=None, granularity_type=GranularityType.File, + processing_type=ProcessingType.Workflow, new_poll_period=1, update_poll_period=10, new_retries=0, update_retries=0, max_new_retries=3, max_update_retries=0, expired_at=None, processing_metadata=None, session=None): @@ -58,6 +59,7 @@ def add_processing(request_id, workload_id, transform_id, status, submitter=None new_retries=new_retries, update_retries=update_retries, max_new_retries=max_new_retries, max_update_retries=max_update_retries, + processing_type=processing_type, expired_at=expired_at, processing_metadata=processing_metadata, session=session) diff --git a/main/lib/idds/core/requests.py b/main/lib/idds/core/requests.py index d9b08aa8..d13c3199 100644 --- a/main/lib/idds/core/requests.py +++ b/main/lib/idds/core/requests.py @@ -245,7 +245,7 @@ def generate_collection(transform, 
collection, relation_type=CollectionRelationT if collection.status is None: collection.status = CollectionStatus.Open - coll = {'transform_id': transform['request_id'], + coll = {'transform_id': transform['transform_id'], 'request_id': transform['request_id'], 'workload_id': transform['workload_id'], 'coll_type': coll_type, @@ -267,6 +267,9 @@ def generate_collection(transform, collection, relation_type=CollectionRelationT def generate_collections(transform): work = transform['transform_metadata']['work'] + if not hasattr(work, 'get_input_collections'): + return [] + input_collections = work.get_input_collections() output_collections = work.get_output_collections() log_collections = work.get_log_collections() @@ -313,9 +316,13 @@ def update_request_with_transforms(request_id, parameters, # work = tf['transform_metadata']['work'] # original_work.set_work_id(tf_id, transforming=True) # original_work.set_status(WorkStatus.New) - work.set_work_id(tf_id, transforming=True) - work.set_status(WorkStatus.New) - workflow.refresh_works() + if hasattr(work, 'set_work_id'): + work.set_work_id(tf_id, transforming=True) + if hasattr(work, 'set_status'): + work.set_status(WorkStatus.New) + if workflow is not None: + if hasattr(workflow, 'refresh_works'): + workflow.refresh_works() collections = generate_collections(tf) for coll in collections: @@ -327,7 +334,8 @@ def update_request_with_transforms(request_id, parameters, collection.coll_id = coll_id # update transform to record the coll_id - work.refresh_work() + if hasattr(work, 'refresh_works'): + work.refresh_work() orm_transforms.update_transform(transform_id=tf_id, parameters={'transform_metadata': tf['transform_metadata']}, session=session) diff --git a/main/lib/idds/core/transforms.py b/main/lib/idds/core/transforms.py index 5a029c69..f6643b40 100644 --- a/main/lib/idds/core/transforms.py +++ b/main/lib/idds/core/transforms.py @@ -33,6 +33,7 @@ def add_transform(request_id, workload_id, transform_type, transform_tag=None, p status=TransformStatus.New, substatus=TransformStatus.New, locking=TransformLocking.Idle, new_poll_period=1, update_poll_period=10, retries=0, expired_at=None, transform_metadata=None, new_retries=0, update_retries=0, max_new_retries=3, max_update_retries=0, + parent_transform_id=None, previous_transform_id=None, current_processing_id=None, workprogress_id=None, session=None): """ Add a transform. @@ -62,13 +63,16 @@ def add_transform(request_id, workload_id, transform_type, transform_tag=None, p new_retries=new_retries, update_retries=update_retries, max_new_retries=max_new_retries, max_update_retries=max_update_retries, + parent_transform_id=parent_transform_id, + previous_transform_id=previous_transform_id, + current_processing_id=current_processing_id, expired_at=expired_at, transform_metadata=transform_metadata, workprogress_id=workprogress_id, session=session) return transform_id @read_session -def get_transform(transform_id, to_json=False, session=None): +def get_transform(transform_id, request_id=None, to_json=False, session=None): """ Get transform or raise a NoObject exception. @@ -80,7 +84,7 @@ def get_transform(transform_id, to_json=False, session=None): :returns: Transform. 
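# Duck-typing sketch for the hasattr() guards added above: iWorkflow/iWork
# objects do not implement the full classic Work/Workflow API, so optional
# hooks such as refresh_work() are only called when they exist. The classes
# below are stand-ins.
class ClassicWork:
    def refresh_work(self):
        return 'refreshed'


class LightWork:
    # an iWork-style object without the classic hooks
    pass


for work in (ClassicWork(), LightWork()):
    if hasattr(work, 'refresh_work'):
        work.refresh_work()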
""" - return orm_transforms.get_transform(transform_id=transform_id, to_json=to_json, session=session) + return orm_transforms.get_transform(transform_id=transform_id, request_id=request_id, to_json=to_json, session=session) @transactional_session @@ -275,25 +279,34 @@ def add_transform_outputs(transform, transform_parameters, input_collections=Non if input_collections: for coll in input_collections: - collection = coll['collection'] - del coll['collection'] + collection = None + if 'collection' in coll: + collection = coll['collection'] + del coll['collection'] coll_id = orm_collections.add_collection(**coll, session=session) - # work.set_collection_id(coll, coll_id) - collection.coll_id = coll_id + if collection: + # work.set_collection_id(coll, coll_id) + collection.coll_id = coll_id if output_collections: for coll in output_collections: - collection = coll['collection'] - del coll['collection'] + collection = None + if 'collection' in coll: + collection = coll['collection'] + del coll['collection'] coll_id = orm_collections.add_collection(**coll, session=session) - # work.set_collection_id(coll, coll_id) - collection.coll_id = coll_id + if collection: + # work.set_collection_id(coll, coll_id) + collection.coll_id = coll_id if log_collections: for coll in log_collections: - collection = coll['collection'] - del coll['collection'] + collection = None + if 'collection' in coll: + collection = coll['collection'] + del coll['collection'] coll_id = orm_collections.add_collection(**coll, session=session) - # work.set_collection_id(coll, coll_id) - collection.coll_id = coll_id + if collection: + # work.set_collection_id(coll, coll_id) + collection.coll_id = coll_id if update_input_collections: update_input_colls = [coll.collection for coll in update_input_collections] @@ -315,6 +328,7 @@ def add_transform_outputs(transform, transform_parameters, input_collections=Non # print(new_processing) processing_id = orm_processings.add_processing(**new_processing, session=session) new_pr_ids.append(processing_id) + transform_parameters['current_processing_id'] = processing_id if update_processing: for proc_id in update_processing: orm_processings.update_processing(processing_id=proc_id, parameters=update_processing[proc_id], session=session) @@ -342,8 +356,10 @@ def add_transform_outputs(transform, transform_parameters, input_collections=Non if transform: if processing_id: # work.set_processing_id(new_processing, processing_id) - work.set_processing_id(new_processing['processing_metadata']['processing'], processing_id) - work.refresh_work() + if hasattr(work, 'set_processing_id'): + work.set_processing_id(new_processing['processing_metadata']['processing'], processing_id) + if hasattr(work, 'refresh_work'): + work.refresh_work() orm_transforms.update_transform(transform_id=transform['transform_id'], parameters=transform_parameters, session=session) diff --git a/main/lib/idds/orm/base/alembic/script.py.mako b/main/lib/idds/orm/base/alembic/script.py.mako index 7436eb8b..379a4d48 100644 --- a/main/lib/idds/orm/base/alembic/script.py.mako +++ b/main/lib/idds/orm/base/alembic/script.py.mako @@ -6,7 +6,6 @@ # http://www.apache.org/licenses/LICENSE-2.0OA # # Authors: -# - Wen Guan, , 2023 # - Wen Guan, , 2024 """${message} diff --git a/main/lib/idds/orm/base/alembic/versions/cc9f730e54c5_add_parent_transform_id_and_processing_.py b/main/lib/idds/orm/base/alembic/versions/cc9f730e54c5_add_parent_transform_id_and_processing_.py new file mode 100644 index 00000000..ca7c7e10 --- /dev/null +++ 
b/main/lib/idds/orm/base/alembic/versions/cc9f730e54c5_add_parent_transform_id_and_processing_.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0OA +# +# Authors: +# - Wen Guan, , 2024 + +"""add parent_transform_id and processing_type + +Revision ID: cc9f730e54c5 +Revises: 354f8e5a5879 +Create Date: 2024-03-01 14:58:12.471189+00:00 + +""" +from alembic import op +from alembic import context +import sqlalchemy as sa + +from idds.common.constants import ProcessingType +from idds.orm.base.types import EnumWithValue + +# revision identifiers, used by Alembic. +revision = 'cc9f730e54c5' +down_revision = '354f8e5a5879' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + if context.get_context().dialect.name in ['oracle', 'mysql', 'postgresql']: + schema = context.get_context().version_table_schema if context.get_context().version_table_schema else '' + + op.add_column('transforms', sa.Column('parent_transform_id', sa.BigInteger()), schema=schema) + op.add_column('transforms', sa.Column('previous_transform_id', sa.BigInteger()), schema=schema) + op.add_column('transforms', sa.Column('current_processing_id', sa.BigInteger()), schema=schema) + + op.add_column('processings', sa.Column('processing_type', EnumWithValue(ProcessingType), server_default='0', nullable=False), schema=schema) + + +def downgrade() -> None: + if context.get_context().dialect.name in ['oracle', 'mysql', 'postgresql']: + schema = context.get_context().version_table_schema if context.get_context().version_table_schema else '' + + op.drop_column('transforms', 'parent_transform_id', schema=schema) + op.drop_column('transforms', 'previous_transform_id', schema=schema) + op.drop_column('transforms', 'current_processing_id', schema=schema) + + op.drop_column('processings', 'processing_type', schema=schema) diff --git a/main/lib/idds/orm/base/models.py b/main/lib/idds/orm/base/models.py index cbdd8ff2..5699ddae 100644 --- a/main/lib/idds/orm/base/models.py +++ b/main/lib/idds/orm/base/models.py @@ -26,7 +26,7 @@ from idds.common.constants import (RequestType, RequestStatus, RequestLocking, WorkprogressStatus, WorkprogressLocking, TransformType, TransformStatus, TransformLocking, - ProcessingStatus, ProcessingLocking, + ProcessingType, ProcessingStatus, ProcessingLocking, CollectionStatus, CollectionLocking, CollectionType, CollectionRelationType, ContentType, ContentRelationType, ContentStatus, ContentFetchStatus, ContentLocking, GranularityType, @@ -297,6 +297,9 @@ class Transform(BASE, ModelBase): oldstatus = Column(EnumWithValue(TransformStatus), default=0) locking = Column(EnumWithValue(TransformLocking), nullable=False) retries = Column(Integer(), default=0) + parent_transform_id = Column(BigInteger()) + previous_transform_id = Column(BigInteger()) + current_processing_id = Column(BigInteger()) created_at = Column("created_at", DateTime, default=datetime.datetime.utcnow, nullable=False) updated_at = Column("updated_at", DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False) next_poll_at = Column("next_poll_at", DateTime, default=datetime.datetime.utcnow) @@ -395,6 +398,7 @@ class Processing(BASE, ModelBase): transform_id = Column(BigInteger().with_variant(Integer, "sqlite"), nullable=False) request_id = Column(BigInteger().with_variant(Integer, "sqlite"), 
nullable=False) workload_id = Column(Integer()) + processing_type = Column(EnumWithValue(ProcessingType), nullable=False) status = Column(EnumWithValue(ProcessingStatus), nullable=False) substatus = Column(EnumWithValue(ProcessingStatus), default=0) oldstatus = Column(EnumWithValue(ProcessingStatus), default=0) diff --git a/main/lib/idds/orm/base/session.py b/main/lib/idds/orm/base/session.py index 925d698e..aa768416 100644 --- a/main/lib/idds/orm/base/session.py +++ b/main/lib/idds/orm/base/session.py @@ -232,7 +232,7 @@ def retry_if_db_connection_error(exception): conn_err_codes = ('server closed the connection unexpectedly', 'closed the connection',) for err_code in conn_err_codes: - if exception.args[0].find(err_code) != -1: + if str(exception.args[0]).find(err_code) != -1: return True return False diff --git a/main/lib/idds/orm/processings.py b/main/lib/idds/orm/processings.py index f77e340a..f3b9736b 100644 --- a/main/lib/idds/orm/processings.py +++ b/main/lib/idds/orm/processings.py @@ -21,14 +21,14 @@ from sqlalchemy.sql.expression import asc from idds.common import exceptions -from idds.common.constants import ProcessingStatus, ProcessingLocking, GranularityType +from idds.common.constants import ProcessingType, ProcessingStatus, ProcessingLocking, GranularityType from idds.orm.base.session import read_session, transactional_session from idds.orm.base import models def create_processing(request_id, workload_id, transform_id, status=ProcessingStatus.New, locking=ProcessingLocking.Idle, submitter=None, granularity=None, granularity_type=GranularityType.File, expired_at=None, processing_metadata=None, - new_poll_period=1, update_poll_period=10, + new_poll_period=1, update_poll_period=10, processing_type=ProcessingType.Workflow, new_retries=0, update_retries=0, max_new_retries=3, max_update_retries=0, substatus=ProcessingStatus.New, output_metadata=None): """ @@ -52,6 +52,7 @@ def create_processing(request_id, workload_id, transform_id, status=ProcessingSt submitter=submitter, granularity=granularity, granularity_type=granularity_type, expired_at=expired_at, processing_metadata=processing_metadata, new_retries=new_retries, update_retries=update_retries, + processing_type=processing_type, max_new_retries=max_new_retries, max_update_retries=max_update_retries, output_metadata=output_metadata) @@ -69,6 +70,7 @@ def add_processing(request_id, workload_id, transform_id, status=ProcessingStatu locking=ProcessingLocking.Idle, submitter=None, substatus=ProcessingStatus.New, granularity=None, granularity_type=GranularityType.File, expired_at=None, processing_metadata=None, new_poll_period=1, update_poll_period=10, + processing_type=ProcessingType.Workflow, new_retries=0, update_retries=0, max_new_retries=3, max_update_retries=0, output_metadata=None, session=None): """ @@ -95,7 +97,7 @@ def add_processing(request_id, workload_id, transform_id, status=ProcessingStatu status=status, substatus=substatus, locking=locking, submitter=submitter, granularity=granularity, granularity_type=granularity_type, expired_at=expired_at, new_poll_period=new_poll_period, - update_poll_period=update_poll_period, + update_poll_period=update_poll_period, processing_type=processing_type, new_retries=new_retries, update_retries=update_retries, max_new_retries=max_new_retries, max_update_retries=max_update_retries, processing_metadata=processing_metadata, output_metadata=output_metadata) diff --git a/main/lib/idds/orm/requests.py b/main/lib/idds/orm/requests.py index e64a7251..c9a999e9 100644 --- 
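# Sketch of the retry_if_db_connection_error() fix above: exception.args[0] is
# not guaranteed to be a string (database drivers may put an exception object
# there), so it is stringified before searching for the connection-error
# markers.
def is_connection_error(exc):
    conn_err_codes = ('server closed the connection unexpectedly',
                      'closed the connection',)
    first_arg = exc.args[0] if exc.args else ''
    return any(str(first_arg).find(code) != -1 for code in conn_err_codes)


assert is_connection_error(Exception(RuntimeError('server closed the connection unexpectedly')))
assert not is_connection_error(Exception('syntax error'))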
a/main/lib/idds/orm/requests.py +++ b/main/lib/idds/orm/requests.py @@ -883,7 +883,8 @@ def update_request(request_id, parameters, update_request_metadata=False, sessio workflow = parameters['request_metadata']['workflow'] if workflow is not None: - workflow.refresh_works() + if hasattr(workflow, 'refresh_works'): + workflow.refresh_works() if 'processing_metadata' not in parameters or not parameters['processing_metadata']: parameters['processing_metadata'] = {} parameters['processing_metadata']['workflow_data'] = workflow.metadata diff --git a/main/lib/idds/orm/transforms.py b/main/lib/idds/orm/transforms.py index ca06e017..f2d6e367 100644 --- a/main/lib/idds/orm/transforms.py +++ b/main/lib/idds/orm/transforms.py @@ -31,6 +31,7 @@ def create_transform(request_id, workload_id, transform_type, transform_tag=None substatus=TransformStatus.New, locking=TransformLocking.Idle, new_poll_period=1, update_poll_period=10, new_retries=0, update_retries=0, max_new_retries=3, max_update_retries=0, + parent_transform_id=None, previous_transform_id=None, current_processing_id=None, retries=0, expired_at=None, transform_metadata=None): """ Create a transform. @@ -54,6 +55,9 @@ def create_transform(request_id, workload_id, transform_type, transform_tag=None retries=retries, expired_at=expired_at, new_retries=new_retries, update_retries=update_retries, max_new_retries=max_new_retries, max_update_retries=max_update_retries, + parent_transform_id=parent_transform_id, + previous_transform_id=previous_transform_id, + current_processing_id=current_processing_id, transform_metadata=transform_metadata) if new_poll_period: new_poll_period = datetime.timedelta(seconds=new_poll_period) @@ -69,6 +73,7 @@ def add_transform(request_id, workload_id, transform_type, transform_tag=None, p status=TransformStatus.New, substatus=TransformStatus.New, locking=TransformLocking.Idle, new_poll_period=1, update_poll_period=10, retries=0, expired_at=None, new_retries=0, update_retries=0, max_new_retries=3, max_update_retries=0, + parent_transform_id=None, previous_transform_id=None, current_processing_id=None, transform_metadata=None, workprogress_id=None, session=None): """ Add a transform. @@ -98,6 +103,9 @@ def add_transform(request_id, workload_id, transform_type, transform_tag=None, p update_poll_period=update_poll_period, new_retries=new_retries, update_retries=update_retries, max_new_retries=max_new_retries, max_update_retries=max_update_retries, + parent_transform_id=parent_transform_id, + previous_transform_id=previous_transform_id, + current_processing_id=current_processing_id, transform_metadata=transform_metadata) new_transform.save(session=session) transform_id = new_transform.transform_id @@ -152,7 +160,7 @@ def add_wp2transform(workprogress_id, transform_id, session=None): @read_session -def get_transform(transform_id, to_json=False, session=None): +def get_transform(transform_id, request_id=None, to_json=False, session=None): """ Get transform or raise a NoObject exception. 
@@ -167,6 +175,8 @@ def get_transform(transform_id, to_json=False, session=None): try: query = session.query(models.Transform)\ .filter(models.Transform.transform_id == transform_id) + if request_id: + query = query.filter(models.Transform.request_id == request_id) ret = query.first() if not ret: return None @@ -448,7 +458,8 @@ def update_transform(transform_id, parameters, session=None): if 'transform_metadata' in parameters and 'work' in parameters['transform_metadata']: work = parameters['transform_metadata']['work'] if work is not None: - work.refresh_work() + if hasattr(work, 'refresh_work'): + work.refresh_work() if 'running_metadata' not in parameters: parameters['running_metadata'] = {} parameters['running_metadata']['work_data'] = work.metadata diff --git a/main/lib/idds/rest/v1/app.py b/main/lib/idds/rest/v1/app.py index dcd3c2f0..0b97f582 100644 --- a/main/lib/idds/rest/v1/app.py +++ b/main/lib/idds/rest/v1/app.py @@ -18,12 +18,14 @@ from flask import Flask, Response from idds.common import exceptions -from idds.common.authentication import authenticate_x509, authenticate_oidc, authenticate_is_super_user +# from idds.common.authentication import authenticate_x509, authenticate_oidc, authenticate_is_super_user from idds.common.config import (config_has_section, config_has_option, config_get) from idds.common.constants import HTTP_STATUS_CODE from idds.common.utils import get_rest_debug +from idds.core.authentication import authenticate_x509, authenticate_oidc, authenticate_is_super_user # from idds.common.utils import get_rest_url_prefix from idds.rest.v1 import requests +from idds.rest.v1 import transforms from idds.rest.v1 import catalog from idds.rest.v1 import cacher from idds.rest.v1 import hyperparameteropt @@ -60,6 +62,7 @@ def log_response(status, headers, *args): def get_normal_blueprints(): bps = [] bps.append(requests.get_blueprint()) + bps.append(transforms.get_blueprint()) bps.append(catalog.get_blueprint()) bps.append(cacher.get_blueprint()) bps.append(hyperparameteropt.get_blueprint()) diff --git a/main/lib/idds/rest/v1/auth.py b/main/lib/idds/rest/v1/auth.py index 098b5401..3d5f4fa2 100644 --- a/main/lib/idds/rest/v1/auth.py +++ b/main/lib/idds/rest/v1/auth.py @@ -15,7 +15,7 @@ from idds.common import exceptions from idds.common.constants import HTTP_STATUS_CODE -from idds.common.authentication import OIDCAuthentication +from idds.core.authentication import OIDCAuthentication from idds.rest.v1.controller import IDDSController diff --git a/main/lib/idds/rest/v1/transforms.py b/main/lib/idds/rest/v1/transforms.py new file mode 100644 index 00000000..210d2983 --- /dev/null +++ b/main/lib/idds/rest/v1/transforms.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# You may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0OA +# +# Authors: +# - Wen Guan, , 2024 + + +from traceback import format_exc + +from flask import Blueprint + +from idds.common import exceptions +from idds.common.authentication import authenticate_is_super_user +from idds.common.constants import HTTP_STATUS_CODE +from idds.common.constants import RequestStatus, RequestType +from idds.common.utils import json_loads +from idds.core.requests import get_request +from idds.core.transforms import add_transform, get_transform, get_transforms +from idds.rest.v1.controller import IDDSController + + +class Transform(IDDSController): + """ Create a Transform. """ + + def post(self, request_id): + """ Create Transform. + HTTP Success: + 200 OK + HTTP Error: + 400 Bad request + 500 Internal Error + """ + try: + parameters = self.get_request().data and json_loads(self.get_request().data) + if 'status' not in parameters: + parameters['status'] = RequestStatus.New + if 'priority' not in parameters or not parameters['priority']: + parameters['priority'] = 0 + if 'token' not in parameters or not parameters['token']: + return self.generate_http_response(HTTP_STATUS_CODE.BadRequest, exc_cls=exceptions.BadRequest.__name__, exc_msg='Token is required') + token = parameters['token'] + del parameters['token'] + except ValueError: + return self.generate_http_response(HTTP_STATUS_CODE.BadRequest, exc_cls=exceptions.BadRequest.__name__, exc_msg='Cannot decode json parameter dictionary') + + try: + if not request_id or request_id in [None, 'None', 'none', 'NULL', 'null']: + raise exceptions.IDDSException("Request (request_id: %s) is required" % request_id) + + request_id = int(request_id) + req = get_request(request_id=request_id) + if not req: + raise exceptions.IDDSException("Request %s is not found" % request_id) + + if req['request_type'] != RequestType.iWorkflow: + raise exceptions.IDDSException("Request type %s doesn't support this operations" % req['request_type']) + + workflow = req['request_metadata']['workflow'] + if workflow.token != token: + raise exceptions.IDDSException("Token %s is not correct for request %s" % (token, request_id)) + + username = self.get_username() + if req['username'] and req['username'] != username and not authenticate_is_super_user(username): + raise exceptions.AuthenticationNoPermission("User %s has no permission to update request %s" % (username, request_id)) + + parameters['request_id'] = request_id + transform_id = add_transform(**parameters) + except exceptions.DuplicatedObject as error: + return self.generate_http_response(HTTP_STATUS_CODE.Conflict, exc_cls=error.__class__.__name__, exc_msg=error) + except exceptions.IDDSException as error: + return self.generate_http_response(HTTP_STATUS_CODE.InternalError, exc_cls=error.__class__.__name__, exc_msg=error) + except Exception as error: + print(error) + print(format_exc()) + return self.generate_http_response(HTTP_STATUS_CODE.InternalError, exc_cls=exceptions.CoreException.__name__, exc_msg=error) + + return self.generate_http_response(HTTP_STATUS_CODE.OK, data={'transform_id': transform_id}) + + def get(self, request_id, transform_id=None): + """ Get transforms with given id. + HTTP Success: + 200 OK + HTTP Error: + 404 Not Found + 500 InternalError + :returns: dictionary of an request. 
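# Sketch of the permission checks in Transform.post() above: the caller must
# present the token stored with the iWorkflow request, and only the owning
# user (or a super user) may add transforms to it. The dictionaries below are
# stand-ins for the ORM request record.
def can_add_transform(req, token, username, is_super_user=False):
    if req['request_type'] != 'iWorkflow':
        raise ValueError("Request type %s doesn't support this operation" % req['request_type'])
    if req['workflow_token'] != token:
        raise PermissionError('Token is not correct for request %s' % req['request_id'])
    if req['username'] and req['username'] != username and not is_super_user:
        raise PermissionError('User %s has no permission to update request %s' % (username, req['request_id']))
    return True


assert can_add_transform({'request_id': 1, 'request_type': 'iWorkflow',
                          'workflow_token': 't0k3n', 'username': 'alice'},
                         token='t0k3n', username='alice')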
+        """
+
+        try:
+            if not request_id or request_id in [None, 'None', 'none', 'NULL', 'null']:
+                raise exceptions.IDDSException("Request (request_id: %s) is required" % request_id)
+            if not transform_id or transform_id in [None, 'None', 'none', 'NULL', 'null']:
+                transform_id = None
+
+            if not transform_id:
+                tfs = get_transforms(request_id=request_id)
+            else:
+                tfs = get_transform(request_id=request_id, transform_id=transform_id)
+        except exceptions.IDDSException as error:
+            return self.generate_http_response(HTTP_STATUS_CODE.InternalError, exc_cls=error.__class__.__name__, exc_msg=error)
+        except Exception as error:
+            print(error)
+            print(format_exc())
+            return self.generate_http_response(HTTP_STATUS_CODE.InternalError, exc_cls=exceptions.CoreException.__name__, exc_msg=error)
+
+        return self.generate_http_response(HTTP_STATUS_CODE.OK, data=tfs)
+
+
+"""----------------------
+   Web service url maps
+----------------------"""
+
+
+def get_blueprint():
+    bp = Blueprint('transform', __name__)
+
+    transform_view = Transform.as_view('transform')
+    bp.add_url_rule('/transform/<request_id>', view_func=transform_view, methods=['post', ])
+    bp.add_url_rule('/transform/<request_id>', view_func=transform_view, methods=['get', ])
+    bp.add_url_rule('/transform/<request_id>/<transform_id>', view_func=transform_view, methods=['get', ])
+
+    return bp
diff --git a/main/lib/idds/tests/auth_test_script.py b/main/lib/idds/tests/auth_test_script.py
index c33d4c0a..92e82b99 100644
--- a/main/lib/idds/tests/auth_test_script.py
+++ b/main/lib/idds/tests/auth_test_script.py
@@ -6,7 +6,7 @@
 # http://www.apache.org/licenses/LICENSE-2.0OA
 #
 # Authors:
-# - Wen Guan, , 2021 - 2022
+# - Wen Guan, , 2021 - 2024
 
 """
 
@@ -26,7 +26,7 @@ import unittest2 as unittest  # noqa F401
 # from nose.tools import assert_equal
 from idds.common.utils import setup_logging
-from idds.common.authentication import OIDCAuthentication
+from idds.core.authentication import OIDCAuthentication
 
 setup_logging(__name__)
diff --git a/main/lib/idds/tests/core_tests.py b/main/lib/idds/tests/core_tests.py
index c8e4770d..75f7e217 100644
--- a/main/lib/idds/tests/core_tests.py
+++ b/main/lib/idds/tests/core_tests.py
@@ -177,6 +177,7 @@ def print_workflow_template(workflow, layers=0):
 reqs = get_requests(request_id=3244, with_request=True, with_detail=False, with_metadata=True)
 reqs = get_requests(request_id=6082, with_request=True, with_detail=False, with_metadata=True)
 # reqs = get_requests(request_id=589913, with_request=True, with_detail=False, with_metadata=True)
+reqs = get_requests(request_id=617073, with_request=True, with_detail=False, with_metadata=True)
 for req in reqs:
     # print(req['request_id'])
     # print(req)
diff --git a/main/lib/idds/tests/core_tests_atlas.py b/main/lib/idds/tests/core_tests_atlas.py
new file mode 100644
index 00000000..15108bbd
--- /dev/null
+++ b/main/lib/idds/tests/core_tests_atlas.py
@@ -0,0 +1,157 @@
+import sys  # noqa F401
+import datetime  # noqa F401
+
+from idds.common.utils import json_dumps, setup_logging  # noqa F401
+from idds.common.constants import ContentStatus, ContentType, ContentRelationType, ContentLocking  # noqa F401
+from idds.core.requests import get_requests  # noqa F401
+from idds.core.messages import retrieve_messages  # noqa F401
+from idds.core.transforms import get_transforms, get_transform  # noqa F401
+from idds.core.workprogress import get_workprogresses  # noqa F401
+from idds.core.processings import get_processings  # noqa F401
+from idds.core import transforms as core_transforms  # noqa F401
+from idds.orm.contents import get_contents  # noqa F401
+from idds.core.transforms import release_inputs_by_collection, release_inputs_by_collection_old # noqa F401 +from idds.workflowv2.workflow import Workflow # noqa F401 +from idds.workflowv2.work import Work # noqa F401 + + +setup_logging(__name__) + + +def show_works(req): + workflow = req['processing_metadata']['workflow'] + print(workflow.independent_works) + print(len(workflow.independent_works)) + print(workflow.works_template.keys()) + print(len(workflow.works_template.keys())) + print(workflow.work_sequence.keys()) + print(len(workflow.work_sequence.keys())) + print(workflow.works.keys()) + print(len(workflow.works.keys())) + + work_ids = [] + for i_id in workflow.works: + work = workflow.works[i_id] + print(i_id) + print(work.work_name) + print(work.task_name) + print(work.work_id) + work_ids.append(work.work_id) + print(work_ids) + + +def print_workflow(workflow, layers=0): + prefix = " " * layers * 4 + for run in workflow.runs: + print(prefix + "run: " + str(run) + ", has_loop_condition: " + str(workflow.runs[run].has_loop_condition())) + # if workflow.runs[run].has_loop_condition(): + # print(prefix + " Loop condition: %s" % json_dumps(workflow.runs[run].loop_condition, sort_keys=True, indent=4)) + for work_id in workflow.runs[run].works: + print(prefix + " " + str(work_id) + " " + str(type(workflow.runs[run].works[work_id]))) + if type(workflow.runs[run].works[work_id]) in [Workflow]: + print(prefix + " parent_num_run: " + workflow.runs[run].works[work_id].parent_num_run + ", num_run: " + str(workflow.runs[run].works[work_id].num_run)) + print_workflow(workflow.runs[run].works[work_id], layers=layers + 1) + # print(prefix + " is_terminated: " + str(workflow.runs[run].works[work_id].is_terminated())) + # print(prefix + " is_finished: " + str(workflow.runs[run].works[work_id].is_finished())) + # elif type(workflow.runs[run].works[work_id]) in [Work]: + else: + work = workflow.runs[run].works[work_id] + tf = get_transform(transform_id=work.get_work_id()) + if tf: + transform_work = tf['transform_metadata']['work'] + # print(json_dumps(transform_work, sort_keys=True, indent=4)) + work.sync_work_data(status=tf['status'], substatus=tf['substatus'], work=transform_work, workload_id=tf['workload_id']) + + print(prefix + " or: " + str(work.or_custom_conditions) + " and: " + str(work.and_custom_conditions)) + print(prefix + " output: " + str(work.output_data)) + print(prefix + " " + workflow.runs[run].works[work_id].task_name + ", num_run: " + str(workflow.runs[run].works[work_id].num_run)) + print(prefix + " workload_id: " + str(work.workload_id)) + print(prefix + " is_terminated: " + str(workflow.runs[run].works[work_id].is_terminated())) + print(prefix + " is_finished: " + str(workflow.runs[run].works[work_id].is_finished())) + if workflow.runs[run].has_loop_condition(): + print(prefix + " Loop condition status: %s" % workflow.runs[run].get_loop_condition_status()) + print(prefix + " Loop condition: %s" % json_dumps(workflow.runs[run].loop_condition, sort_keys=True, indent=4)) + + +def print_workflow_template(workflow, layers=0): + prefix = " " * layers * 4 + print(prefix + str(workflow.template.internal_id) + ", has_loop_condition: " + str(workflow.template.has_loop_condition())) + for work_id in workflow.template.works: + print(prefix + " " + str(work_id) + " " + str(type(workflow.template.works[work_id]))) + if type(workflow.template.works[work_id]) in [Workflow]: + print(prefix + " parent_num_run: " + str(workflow.template.works[work_id].parent_num_run) + ", num_run: " + 
str(workflow.template.works[work_id].num_run)) + print_workflow_template(workflow.template.works[work_id], layers=layers + 1) + else: + print(prefix + " " + workflow.template.works[work_id].task_name + ", num_run: " + str(workflow.template.works[work_id].num_run)) + + +# 283511, 283517 +# reqs = get_requests(request_id=599, with_detail=True, with_metadata=True) +# reqs = get_requests(request_id=283511, with_request=True, with_detail=False, with_metadata=True) +# reqs = get_requests(request_id=298163, with_request=True, with_detail=False, with_metadata=True) +# reqs = get_requests(request_id=298557, with_request=True, with_detail=False, with_metadata=True) +# reqs = get_requests(request_id=299111, with_request=True, with_detail=False, with_metadata=True) +# reqs = get_requests(request_id=299235, with_request=True, with_detail=False, with_metadata=True) +# reqs = get_requests(request_id=965, with_request=True, with_detail=False, with_metadata=True) +# reqs = get_requests(request_id=350695, with_request=True, with_detail=False, with_metadata=True) + +# reqs = get_requests(request_id=370028, with_request=True, with_detail=False, with_metadata=True) +# reqs = get_requests(request_id=370400, with_request=True, with_detail=False, with_metadata=True) +# reqs = get_requests(request_id=371204, with_request=True, with_detail=False, with_metadata=True) +# reqs = get_requests(request_id=372678, with_request=True, with_detail=False, with_metadata=True) +# reqs = get_requests(request_id=373602, with_request=True, with_detail=False, with_metadata=True) +# reqs = get_requests(request_id=376086, with_request=True, with_detail=False, with_metadata=True) +# reqs = get_requests(request_id=380474, with_request=True, with_detail=False, with_metadata=True) +# reqs = get_requests(request_id=381520, with_request=True, with_detail=False, with_metadata=True) +# reqs = get_requests(request_id=28182323, with_request=True, with_detail=False, with_metadata=True) +# reqs = get_requests(request_id=385554, with_request=True, with_detail=False, with_metadata=True) +reqs = get_requests(request_id=545851, with_request=True, with_detail=False, with_metadata=True) +for req in reqs: + # print(req['request_id']) + # print(req) + # print(rets) + print(json_dumps(req, sort_keys=True, indent=4)) + # show_works(req) + pass + if 'build_workflow' in req['request_metadata']: + workflow = req['request_metadata']['build_workflow'] + # workflow.get_new_works() + print(workflow.runs.keys()) + # print(workflow.runs["1"]) + print(json_dumps(workflow.runs["1"], sort_keys=True, indent=4)) + elif 'workflow' in req['request_metadata']: + workflow = req['request_metadata']['workflow'] + # workflow.get_new_works() + print(workflow.runs.keys()) + # print(workflow.runs["1"]) + # print(json_dumps(workflow.runs["1"], sort_keys=True, indent=4)) + + # print(workflow.runs["1"].works.keys()) + # print(workflow.runs["1"].has_loop_condition()) + # print(workflow.runs["1"].works["7aa1ec08"]) + # print(json_dumps(workflow.runs["1"].works["048a1811"], indent=4)) + # print(workflow.runs["1"].works["7aa1ec08"].runs.keys()) + # print(workflow.runs["1"].works["7aa1ec08"].runs["1"].has_loop_condition()) + # print(workflow.runs["1"].works["7aa1ec08"].runs["1"].works.keys()) + + # print(json_dumps(workflow.runs["1"].works["7aa1ec08"].runs["1"], indent=4)) + if hasattr(workflow, 'get_relation_map'): + # print(json_dumps(workflow.get_relation_map(), sort_keys=True, indent=4)) + pass + + print("workflow") + print_workflow(workflow) + new_works = workflow.get_new_works() + 
print('new_works:' + str(new_works)) + all_works = workflow.get_all_works() + print('all_works:' + str(all_works)) + for work in all_works: + print("work %s signature: %s" % (work.get_work_id(), work.signature)) + + # print("workflow template") + print_workflow_template(workflow) + + # workflow.sync_works() + + print("workflow template") + print(json_dumps(workflow.template, sort_keys=True, indent=4)) diff --git a/main/lib/idds/tests/panda_test.py b/main/lib/idds/tests/panda_test.py index a7788ece..9f0a52e3 100644 --- a/main/lib/idds/tests/panda_test.py +++ b/main/lib/idds/tests/panda_test.py @@ -7,11 +7,11 @@ os.environ['PANDA_URL_SSL'] = 'https://pandaserver-doma.cern.ch:25443/server/panda' os.environ['PANDA_BEHIND_REAL_LB'] = "1" -os.environ['PANDA_URL'] = 'http://rubin-panda-server-dev.slac.stanford.edu:80/server/panda' -os.environ['PANDA_URL_SSL'] = 'https://rubin-panda-server-dev.slac.stanford.edu:8443/server/panda' +# os.environ['PANDA_URL'] = 'http://rubin-panda-server-dev.slac.stanford.edu:80/server/panda' +# os.environ['PANDA_URL_SSL'] = 'https://rubin-panda-server-dev.slac.stanford.edu:8443/server/panda' -os.environ['PANDA_URL'] = 'https://usdf-panda-server.slac.stanford.edu:8443/server/panda' -os.environ['PANDA_URL_SSL'] = 'https://usdf-panda-server.slac.stanford.edu:8443/server/panda' +# os.environ['PANDA_URL'] = 'https://usdf-panda-server.slac.stanford.edu:8443/server/panda' +# os.environ['PANDA_URL_SSL'] = 'https://usdf-panda-server.slac.stanford.edu:8443/server/panda' from pandaclient import Client # noqa E402 @@ -51,6 +51,9 @@ task_ids = [688, 8686, 8695, 8696] task_ids = [i for i in range(8958, 9634)] task_ids = [i for i in range(8752, 8958)] +task_ids = [168645, 168638] +task_ids = [168747, 168761, 168763] +task_ids = [168770] for task_id in task_ids: print("Killing %s" % task_id) ret = Client.killTask(task_id, verbose=True) diff --git a/main/lib/idds/tests/test_auth.py b/main/lib/idds/tests/test_auth.py index e609c93f..df036a35 100644 --- a/main/lib/idds/tests/test_auth.py +++ b/main/lib/idds/tests/test_auth.py @@ -6,7 +6,7 @@ # http://www.apache.org/licenses/LICENSE-2.0OA # # Authors: -# - Wen Guan, , 2021 - 2022 +# - Wen Guan, , 2021 - 2024 """ @@ -26,7 +26,7 @@ import unittest2 as unittest # from nose.tools import assert_equal from idds.common.utils import setup_logging -from idds.common.authentication import OIDCAuthentication +from idds.core.authentication import OIDCAuthentication setup_logging(__name__) diff --git a/main/lib/idds/tests/test_domapanda_mem.py b/main/lib/idds/tests/test_domapanda_mem.py new file mode 100644 index 00000000..254a581a --- /dev/null +++ b/main/lib/idds/tests/test_domapanda_mem.py @@ -0,0 +1,290 @@ +#!/usr/bin/env python +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0OA +# +# Authors: +# - Sergey Padolski, , 2021 +# - Wen Guan, , 2021 + + +""" +Test client. 
+""" + +import sys +import string +import random +import time + +# import traceback + +# from rucio.client.client import Client as Rucio_Client +# from rucio.common.exception import CannotAuthenticate + +# from idds.client.client import Client +from idds.client.clientmanager import ClientManager +# from idds.common.constants import RequestType, RequestStatus +from idds.common.utils import get_rest_host +# from idds.tests.common import get_example_real_tape_stagein_request +# from idds.tests.common import get_example_prodsys2_tape_stagein_request + +# from idds.workflowv2.work import Work, Parameter, WorkStatus +# from idds.workflowv2.workflow import Condition, Workflow +from idds.workflowv2.workflow import Workflow +# from idds.atlas.workflowv2.atlasstageinwork import ATLASStageinWork +from idds.doma.workflowv2.domapandawork import DomaPanDAWork + + +if len(sys.argv) > 1 and sys.argv[1] == "in2p3": + site = 'in2p3' + panda_site = "CC-IN2P3" + task_cloud = 'EU' + # task_queue = 'CC-IN2P3_TEST' + task_queue = 'CC-IN2P3_Rubin' + task_queue1 = 'CC-IN2P3_Rubin_Medium' + task_queue2 = 'CC-IN2P3_Rubin_Himem' + task_queue3 = 'CC-IN2P3_Rubin_Extra_Himem' + task_queue4 = 'CC-IN2P3_Rubin_Merge' +elif len(sys.argv) > 1 and sys.argv[1] == "lancs": + site = 'lancs' + panda_site = "LANCS" + task_cloud = 'EU' + # task_queue = 'LANCS_TEST' + task_queue = 'LANCS_Rubin' + task_queue1 = 'LANCS_Rubin_Medium' + task_queue2 = 'LANCS_Rubin_Himem' + task_queue3 = 'LANCS_Rubin_Extra_Himem' + task_queue3 = 'LANCS_Rubin_Himem' + task_queue4 = 'LANCS_Rubin_Merge' +else: + site = 'slac' + panda_site = "SLAC" + # task_cloud = 'LSST' + task_cloud = 'US' + + task_queue = 'DOMA_LSST_GOOGLE_TEST' + # task_queue = 'DOMA_LSST_GOOGLE_MERGE' + # task_queue = 'SLAC_TEST' + # task_queue = 'DOMA_LSST_SLAC_TEST' + task_queue = 'SLAC_Rubin' + task_queue1 = 'SLAC_Rubin_Medium' + task_queue2 = 'SLAC_Rubin_Himem' + task_queue3 = 'SLAC_Rubin_Extra_Himem' + task_queue4 = 'SLAC_Rubin_Merge' + # task_queue = 'SLAC_Rubin_Extra_Himem_32Cores' + # task_queue = 'SLAC_Rubin_Merge' + # task_queue = 'SLAC_TEST' + # task_queue4 = task_queue3 = task_queue2 = task_queue1 = task_queue + +# task_cloud = None + + +def randStr(chars=string.ascii_lowercase + string.digits, N=10): + return ''.join(random.choice(chars) for _ in range(N)) + + +class PanDATask(object): + name = None + step = None + dependencies = [] + + +def setup_workflow(): + + taskN1 = PanDATask() + taskN1.step = "step1" + taskN1.name = site + "_" + taskN1.step + "_" + randStr() + taskN1.dependencies = [ + {"name": "00000" + str(k), + "order_id": k, + "dependencies": [], + "submitted": False} for k in range(6) + ] + + taskN2 = PanDATask() + taskN2.step = "step2" + taskN2.name = site + "_" + taskN2.step + "_" + randStr() + taskN2.dependencies = [ + { + "name": "000010", + "order_id": 0, + "dependencies": [{"task": taskN1.name, "inputname": "000001", "available": False}, + {"task": taskN1.name, "inputname": "000002", "available": False}], + "submitted": False + }, + { + "name": "000011", + "order_id": 1, + "dependencies": [{"task": taskN1.name, "inputname": "000001", "available": False}, + {"task": taskN1.name, "inputname": "000002", "available": False}], + "submitted": False + }, + { + "name": "000012", + "order_id": 2, + "dependencies": [{"task": taskN1.name, "inputname": "000001", "available": False}, + {"task": taskN1.name, "inputname": "000002", "available": False}], + "submitted": False + } + ] + + taskN3 = PanDATask() + taskN3.step = "step3" + taskN3.name = site + "_" + taskN3.step + 
"_" + randStr() + taskN3.dependencies = [ + { + "name": "000020", + "order_id": 0, + "dependencies": [], + "submitted": False + }, + { + "name": "000021", + "order_id": 1, + "dependencies": [{"task": taskN2.name, "inputname": "000010", "available": False}, + {"task": taskN2.name, "inputname": "000011", "available": False}], + "submitted": False + }, + { + "name": "000022", + "order_id": 2, + "dependencies": [{"task": taskN2.name, "inputname": "000011", "available": False}, + {"task": taskN2.name, "inputname": "000012", "available": False}], + "submitted": False + }, + { + "name": "000023", + "order_id": 3, + "dependencies": [], + "submitted": False + }, + { + "name": "000024", + "order_id": 4, + "dependencies": [{"task": taskN3.name, "inputname": "000021", "available": False}, + {"task": taskN3.name, "inputname": "000023", "available": False}], + "submitted": False + }, + ] + + taskN4 = PanDATask() + taskN4.step = "step4" + taskN4.name = site + "_" + taskN4.step + "_" + randStr() + taskN4.dependencies = [ + {"name": "00004" + str(k), + "order_id": k, + "dependencies": [], + "submitted": False} for k in range(6) + ] + + taskN5 = PanDATask() + taskN5.step = "step5" + taskN5.name = site + "_" + taskN5.step + "_" + randStr() + taskN5.dependencies = [ + {"name": "00005" + str(k), + "order_id": k, + "dependencies": [], + "submitted": False} for k in range(6) + ] + + work1 = DomaPanDAWork(executable='echo', + primary_input_collection={'scope': 'pseudo_dataset', 'name': 'pseudo_input_collection#1'}, + output_collections=[{'scope': 'pseudo_dataset', 'name': 'pseudo_output_collection#1'}], + log_collections=[], dependency_map=taskN1.dependencies, + task_name=taskN1.name, task_queue=None, task_site=panda_site, task_rss=3000, + encode_command_line=True, + task_priority=981, + prodSourceLabel='managed', + task_log={"dataset": "PandaJob_#{pandaid}/", + "destination": "local", + "param_type": "log", + "token": "local", + "type": "template", + "value": "log.tgz"}, + task_cloud=task_cloud) + work2 = DomaPanDAWork(executable='echo', + primary_input_collection={'scope': 'pseudo_dataset', 'name': 'pseudo_input_collection#2'}, + output_collections=[{'scope': 'pseudo_dataset', 'name': 'pseudo_output_collection#2'}], + log_collections=[], dependency_map=taskN2.dependencies, + task_name=taskN2.name, task_queue=None, task_site=panda_site, task_rss=7000, + encode_command_line=True, + task_priority=881, + prodSourceLabel='managed', + task_log={"dataset": "PandaJob_#{pandaid}/", + "destination": "local", + "param_type": "log", + "token": "local", + "type": "template", + "value": "log.tgz"}, + task_cloud=task_cloud) + work3 = DomaPanDAWork(executable='echo', + primary_input_collection={'scope': 'pseudo_dataset', 'name': 'pseudo_input_collection#3'}, + output_collections=[{'scope': 'pseudo_dataset', 'name': 'pseudo_output_collection#3'}], + log_collections=[], dependency_map=taskN3.dependencies, + task_name=taskN3.name, task_queue=None, task_site=panda_site, task_rss=14000, + encode_command_line=True, + task_priority=781, + prodSourceLabel='managed', + task_log={"dataset": "PandaJob_#{pandaid}/", + "destination": "local", + "param_type": "log", + "token": "local", + "type": "template", + "value": "log.tgz"}, + task_cloud=task_cloud) + + work4 = DomaPanDAWork(executable='echo', + primary_input_collection={'scope': 'pseudo_dataset', 'name': 'pseudo_input_collection#1'}, + output_collections=[{'scope': 'pseudo_dataset', 'name': 'pseudo_output_collection#1'}], + log_collections=[], dependency_map=taskN4.dependencies, + 
task_name=taskN4.name, task_queue=None, task_site=panda_site, task_rss=20000, + encode_command_line=True, + task_priority=981, + prodSourceLabel='managed', + task_log={"dataset": "PandaJob_#{pandaid}/", + "destination": "local", + "param_type": "log", + "token": "local", + "type": "template", + "value": "log.tgz"}, + task_cloud=task_cloud) + + work5 = DomaPanDAWork(executable='echo', + primary_input_collection={'scope': 'pseudo_dataset', 'name': 'pseudo_input_collection#1'}, + output_collections=[{'scope': 'pseudo_dataset', 'name': 'pseudo_output_collection#1'}], + log_collections=[], dependency_map=taskN5.dependencies, + task_name=taskN5.name, task_queue=None, task_site=panda_site, task_rss=33000, + encode_command_line=True, + task_priority=981, + prodSourceLabel='managed', + task_log={"dataset": "PandaJob_#{pandaid}/", + "destination": "local", + "param_type": "log", + "token": "local", + "type": "template", + "value": "log.tgz"}, + task_cloud=task_cloud) + + pending_time = 12 + # pending_time = None + workflow = Workflow(pending_time=pending_time) + workflow.add_work(work1) + workflow.add_work(work2) + workflow.add_work(work3) + workflow.add_work(work4) + workflow.add_work(work5) + workflow.name = site + "_" + 'test_workflow.idds.%s.test' % time.time() + return workflow + + +if __name__ == '__main__': + host = get_rest_host() + workflow = setup_workflow() + + wm = ClientManager(host=host) + # wm.set_original_user(user_name="wguandev") + request_id = wm.submit(workflow, use_dataset_name=False) + print(request_id) diff --git a/main/lib/idds/tests/test_get_source_code.py b/main/lib/idds/tests/test_get_source_code.py new file mode 100644 index 00000000..d2de7890 --- /dev/null +++ b/main/lib/idds/tests/test_get_source_code.py @@ -0,0 +1,19 @@ +import dill +import inspect + + +def foo(arg1, arg2): + # do something with args + a = arg1 + arg2 + return a + + +source_foo = inspect.getsource(foo) # foo is normal function +print(source_foo) + + +# source_max = inspect.getsource(max) # max is a built-in function +# print(source_max) + +print(inspect.signature(foo)) +print(dill.dumps(foo)) diff --git a/main/lib/idds/tests/test_iworkflow/command b/main/lib/idds/tests/test_iworkflow/command new file mode 100644 index 00000000..28620e70 --- /dev/null +++ b/main/lib/idds/tests/test_iworkflow/command @@ -0,0 +1,2 @@ +python main/lib/idds/tests/test_iworkflow/test_iworkflow.py + diff --git a/main/lib/idds/tests/test_iworkflow/optimize.py b/main/lib/idds/tests/test_iworkflow/optimize.py new file mode 100644 index 00000000..21b51923 --- /dev/null +++ b/main/lib/idds/tests/test_iworkflow/optimize.py @@ -0,0 +1,437 @@ + +import json +import hashlib +import os +import sys +import time +import traceback + +import numpy as np # noqa F401 + +from sklearn.metrics import roc_curve, auc, confusion_matrix, brier_score_loss, mean_squared_error, log_loss, roc_auc_score # noqa F401 +from sklearn import metrics # noqa F401 +from sklearn.preprocessing import label_binarize # noqa F401 +from sklearn.neighbors import KNeighborsClassifier # noqa F401 + +from sklearn.model_selection import cross_val_score # noqa F401 +from sklearn.model_selection import KFold # noqa F401 +from sklearn.preprocessing import LabelEncoder # noqa F401 +from sklearn import model_selection # noqa F401 +from sklearn.metrics import roc_curve, auc, confusion_matrix # noqa F401 +from sklearn.model_selection import train_test_split # noqa F401 + +from sklearn.preprocessing import StandardScaler # noqa F401 +from sklearn.model_selection import 
StratifiedKFold # noqa F401 + +from tabulate import tabulate # noqa F401 + +import xgboost as xgb + +from bayes_opt import BayesianOptimization, UtilityFunction + + +def load_data(workdir='/opt/bdt_0409/ttHyyML_had_single', analysis_type='had'): + currentDir = os.path.dirname(os.path.realpath(__file__)) + + print("CurrentDir: %s" % currentDir) + print("WorkDir: %s" % workdir) + workdir = os.path.abspath(workdir) + print("Absolute WorkDir: %s" % workdir) + os.chdir(workdir) + sys.path.insert(0, workdir) + sys.argv = ['test'] + + if analysis_type == 'had': + from load_data_real_auto import load_data_real_hadronic + data, label = load_data_real_hadronic() + elif analysis_type == 'lep': + from load_data_real_auto import load_data_real_leptonic + data, label = load_data_real_leptonic() + sys.path.remove(workdir) + os.chdir(currentDir) + return data, label + + +def get_param(params, name, default): + if params and name in params: + return params[name] + return default + + +def getAUC(y_test, score, y_val_weight=None): + # fpr, tpr, _ = roc_curve(y_test, score, sample_weight=y_val_weight) + # roc_auc = auc(fpr, tpr, True) + print(y_test.shape) + print(score.shape) + if y_val_weight: + print(y_val_weight.shape) + return roc_auc_score(y_test, score, sample_weight=y_val_weight) + + +def getBrierScore(y_test, score): + return 1 - brier_score_loss(y_test, score) + + +def evaluateBrierScore(y_pred, data): + label = data.get_label() + return 'brierLoss', 1 - brier_score_loss(y_pred, label) + + +def getRMSE(y_test, score): + return mean_squared_error(y_test, score) ** 0.5 + + +def getLogLoss(y_test, score): + return log_loss(y_test, score) + + +def xgb_callback_save_model(model_name, period=1000): + def callback(env): + try: + bst, i, _ = env.model, env.iteration, env.end_iteration + if (i % period == 0): + bst.save_model(model_name) + except Exception: + print(traceback.format_exc()) + return callback + + +def train_bdt(input_x, input_y, params=None, retMethod=None, hist=True, saveModel=False, input_weight=None): + + train, val = input_x + y_train_cat, y_val_cat = input_y + if input_weight: + y_train_weight, y_val_weight = input_weight + else: + y_train_weight = None + y_val_weight = None + + train = train.reshape((train.shape[0], -1)) + val = val.reshape((val.shape[0], -1)) + + dTrain = xgb.DMatrix(train, label=y_train_cat, weight=y_train_weight) + dVal = xgb.DMatrix(val, label=y_val_cat, weight=y_val_weight) + + # train model + print('Train model.') + + # param = {'max_depth':10, 'eta':0.1, 'min_child_weight': 60, 'silent':1, 'objective':'binary:logistic', 'eval_metric': ['logloss', 'auc' ]} + # param = {'max_depth':10, 'eta':0.1, 'min_child_weight': 1, 'silent':1, 'objective':'rank:pairwise', 'eval_metric': ['auc','logloss']} + # def_params = {'max_depth':10, 'eta':0.005, 'min_child_weight': 15, 'silent':1, 'objective':'binary:logistic', 'eval_metric': ['auc','logloss']} + # def_params = {'colsample_bytree': 0.7, 'silent': 0, 'eval_metric': ['auc', 'logloss'], 'scale_pos_weight': 1.4, 'max_delta_step': 0, 'nthread': 8, 'min_child_weight': 160, 'subsample': 0.8, 'eta': 0.04, 'objective': 'binary:logistic', 'alpha': 0.1, 'lambda': 10, 'seed': 10, 'max_depth': 10, 'gamma': 0.03, 'booster': 'gbtree'} # noqa E501 + # def_params = {'colsample_bytree': 0.7, 'silent': 0, 'eval_metric': ['auc', 'logloss'], 'scale_pos_weight': 1.4, 'max_delta_step': 0, 'nthread': 8, 'min_child_weight': 160, 'subsample': 0.8, 'eta': 0.04, 'objective': 'binary:logistic', 'alpha': 0.1, 'lambda': 10, 'seed': 10, 'max_depth': 10, 
u'gamma': 0.5, 'booster': 'gbtree'} # noqa E501 + + # def_params = {'eval_metric': ['logloss', 'auc'], 'scale_pos_weight': 5.1067081406104631, 'max_delta_step': 4.6914331907848759, 'seed': 10, 'alpha': 0.1, 'booster': 'gbtree', 'colsample_bytree': 0.64067554676687111, 'nthread': 4, 'min_child_weight': 58, 'subsample': 0.76111573761360196, 'eta': 0.1966696564443787, 'objective': 'binary:logistic', 'max_depth': 10, 'gamma': 0.74055129530012553} # noqa E501 + + def_params = {} + if not params: + params = {} + if 'num_boost_round' not in params: + params['num_boost_round'] = 100000 + if 'objective' not in params: + params['objective'] = 'binary:logistic' + + for key in def_params: + if key not in params: + params[key] = def_params[key] + + if 'silent' not in params: + params['silent'] = 0 + + if hist: + params['tree_method'] = 'hist' + params['booster'] = 'gbtree' + params['grow_policy'] = 'lossguide' + params['nthread'] = 4 + params['booster'] = 'gbtree' + + start = time.time() + evallist = [(dTrain, 'train'), (dVal, 'eval')] + evals_result = {} + + try: + save_model_callback = xgb_callback_save_model(params['model'] + "temp" if params and 'model' in params else 'models/default_bdt_temp.h5') + # with early stop + if not saveModel: + # bst = xgb.train(params, dTrain, params['num_boost_round'], evals=evallist, early_stopping_rounds=10, evals_result=evals_result, verbose_eval=False, callbacks=[save_model_callback]) + bst = xgb.train(params, dTrain, params['num_boost_round'], evals=evallist, early_stopping_rounds=10, evals_result=evals_result, verbose_eval=True) + else: + bst = xgb.train(params, dTrain, params['num_boost_round'], evals=evallist, early_stopping_rounds=10, evals_result=evals_result, callbacks=[save_model_callback]) + # bst = xgb.train(params, dTrain, params['num_boost_round'], evals=evallist, early_stopping_rounds=10, evals_result=evals_result, feval=evaluateBrierScore, callbacks=[save_model_callback]) + # bst = xgb.train(params, dTrain, params['num_boost_round'], evals=evallist, evals_result=evals_result, callbacks=[save_model_callback]) + except KeyboardInterrupt: + print('Finishing on SIGINT.') + print("CPU Training time: %s" % (time.time() - start)) + + # test model + print('Test model.') + + score = bst.predict(dVal) + rmse = None + logloss = None + if 'num_class' in params and params['num_class'] == 3: + y_val_cat_binary = label_binarize(y_val_cat, classes=[0, 1, 2]) + aucValue = getAUC(y_val_cat_binary[:, 0], score[:, 0]) + aucValue1 = getAUC(y_val_cat_binary[:, 1], score[:, 1]) + aucValue2 = getAUC(y_val_cat_binary[:, 2], score[:, 2]) + + print("AUC: %s, %s, %s" % (aucValue, aucValue1, aucValue2)) + + bslValue = getBrierScore(y_val_cat_binary[:, 0], score[:, 0]) + bslValue1 = getBrierScore(y_val_cat_binary[:, 1], score[:, 1]) + bslValue2 = getBrierScore(y_val_cat_binary[:, 2], score[:, 2]) + + print("BrierScoreLoss: %s, %s, %s" % (bslValue, bslValue1, bslValue2)) + + rmse = getRMSE(y_val_cat_binary[:, 0], score[:, 0]) + logloss = getLogLoss(y_val_cat_binary[:, 0], score[:, 0]) + else: + aucValue = getAUC(y_val_cat, score) + bslValue = getBrierScore(y_val_cat, score) + rmse = None + logloss = None + auc = None # noqa F811 + if 'auc' in evals_result['eval']: + auc = evals_result['eval']['auc'][-1] + if 'rmse' in evals_result['eval']: + rmse = evals_result['eval']['rmse'][-1] + if 'logloss' in evals_result['eval']: + logloss = evals_result['eval']['logloss'][-1] + print("params: %s #, Val AUC: %s, BrierScoreLoss: %s, xgboost rmse: %s, xgboost logloss: %s, xgboost auc: %s" % 
(params, aucValue, bslValue, rmse, logloss, auc)) + rmse = getRMSE(y_val_cat, score) + logloss = getLogLoss(y_val_cat, score) + print("params: %s #, Val AUC: %s, BrierScoreLoss: %s, Val rmse: %s, Val logloss: %s" % (params, aucValue, bslValue, rmse, logloss)) + + # bst.save_model(params['model'] if params and 'model' in params else 'models/default_bdt.h5') + + print(bst.get_fscore()) + # print(bst.get_score()) + + try: + pass + # from matplotlib import pyplot + # print("Plot importance") + # xgb.plot_importance(bst) + # pyplot.savefig('plots/' + params['name'] if 'name' in params else 'default' + '_feature_importance.png') + # pyplot.savefig('plots/' + params['name'] if 'name' in params else 'default' + '_feature_importance.eps') + except Exception: + print(traceback.format_exc()) + + try: + history = {'loss': evals_result['train']['logloss'], 'val_loss': evals_result['eval']['logloss'], + 'acc': evals_result['train']['auc'], 'val_acc': evals_result['eval']['auc']} + except Exception: + print(traceback.format_exc()) + history = {} + + if retMethod: + if retMethod == 'auc': + return aucValue + if retMethod == 'rmse': + return rmse + if retMethod == 'brier': + return bslValue + if retMethod == 'logloss': + return logloss + return score, history + + +def evaluate_bdt(input_x, input_y, opt_params, retMethod=None, hist=True, saveModel=False, input_weight=None, **kwargs): + params = kwargs + if not params: + params = {} + if params and 'max_depth' in params: + params['max_depth'] = int(params['max_depth']) + if params and 'num_boost_round' in params: + params['num_boost_round'] = int(params['num_boost_round']) + if params and 'seed' in params: + params['seed'] = int(params['seed']) + if params and 'max_bin' in params: + params['max_bin'] = int(params['max_bin']) + # params[''] = int(params['']) + + if opt_params: + for opt in opt_params: + if opt not in params: + params[opt] = opt_params[opt] + + if retMethod and retMethod == 'auc': + params['eval_metric'] = ['rmse', 'logloss', 'auc'] + elif retMethod and retMethod == 'logloss': + params['eval_metric'] = ['rmse', 'auc', 'logloss'] + elif retMethod and retMethod == 'rmse': + params['eval_metric'] = ['logloss', 'auc', 'rmse'] + elif retMethod: + params['eval_metric'] = [retMethod] + + print(params) + auc = train_bdt(input_x, input_y, params=params, retMethod=retMethod, hist=hist, saveModel=saveModel, input_weight=input_weight) # noqa F811 + print("params: %s, ret: %s" % (params, auc)) + return auc + + +def optimize_bdt(input_x, input_y, opt_params, opt_method='auc', opt_ranges=None, hist=True, input_weight=None): + eval_params = { + 'colsample_bytree': (0.1, 1), + 'scale_pos_weight': (0, 10), + 'max_delta_step': (0, 10), + 'seed': (1, 50), + 'min_child_weight': (0, 100), + 'subsample': (0.1, 1), + 'eta': (0, 0.1), + 'alpha': (0, 1), + # 'lambda': (0, 100), + 'max_depth': (0, 50), + 'gamma': (0, 1), + # 'num_boost_round': (100000, 1000000) + } + + explore_params1 = { # noqa F841 + # 'eta': [0.001, 0.002, 0.003, 0.004, 0.005, 0.006, 0.007, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08], # noqa E126 + 'eta': [0.001, 0.004, 0.007, 0.03, 0.05, 0.07], + # 'scale_pos_weight': [1, 2, 3, 5, 7, 8], + 'colsample_bytree': [0.67, 0.79, 0.76, 0.5, 0.4, 0.85], + 'scale_pos_weight': [4.9, 2.8, 2.1, 4.7, 1.7, 4], + 'max_delta_step': [5, 2.2, 1.67, 2.53, 8.8, 8], + 'seed': [10, 20, 30, 40, 50, 60], + 'min_child_weight': [50, 74, 53, 14, 45, 30], + 'subsample': [0.78, 0.82, 0.6, 0.87, 0.5, 0.7], + 'alpha': [0.1, 0.2, 0.3, 0.05, 0.15, 0.25], + # 'lambda': [50], + 'max_depth': 
[6, 7, 10, 14, 19, 28], + 'gamma': [0.47, 0.19, 0.33, 0.43, 0.5, 0.76], + # 'num_boost_round': [1000000] + } + + explore_params = { # noqa F841 + 'eta': [0.004, 0.03], # noqa E126 + # 'scale_pos_weight': [3, 7], + 'colsample_bytree': [0.67, 0.4], + 'scale_pos_weight': [4.9, 1.7], + 'max_delta_step': [2.2, 8], + 'seed': [10, 50], + 'min_child_weight': [50, 30], + 'subsample': [0.78, 0.5], + 'alpha': [0.2, 0.25], + # 'lambda': [50], + 'max_depth': [10, 28], + 'gamma': [0.47, 0.76], + # 'num_boost_round': [1000000] + } + + print("Eval: %s" % eval_params) + optFunc = lambda **z: evaluate_bdt(input_x, input_y, opt_params, opt_method, hist=hist, saveModel=False, input_weight=input_weight, **z) # noqa F731 + bayesopt = BayesianOptimization(optFunc, eval_params) + # bayesopt.explore(explore_params) + bayesopt.maximize(init_points=3, n_iter=5) + # bayesopt.maximize(init_points=30, n_iter=200) + # bayesopt.maximize(init_points=2, n_iter=2) + # print(bayesopt.res) + p = bayesopt.max + print("Best params: %s" % p) + + +def optimize(): + data, label = load_data() + train, val = data + y_train_cat, y_val_cat = label + + opt_method = 'auc' + opt_ranges = {"subsample": [0.10131415926000001, 1], + "eta": [0.0100131415926, 0.03], + "colsample_bytree": [0.10131415926000001, 1], + "gamma": [0.00131415926, 1], + "alpha": [0.00131415926, 1], + "max_delta_step": [0.00131415926, 10], + "max_depth": [5.00131415926, 50], + "min_child_weight": [0.00131415926, 100]} + params = {'num_boost_round': 10} + + optimize_bdt(input_x=[train, val], input_y=[y_train_cat, y_val_cat], opt_params=params, + opt_ranges=opt_ranges, opt_method=opt_method, hist=True, input_weight=None) + + +def get_unique_id_for_dict(dict_): + ret = hashlib.sha1(json.dumps(dict_, sort_keys=True).encode()).hexdigest() + return ret + + +def optimize_bdt1(input_x, input_y, opt_params, opt_method='auc', opt_ranges=None, hist=True, input_weight=None): + eval_params = { + 'colsample_bytree': (0.1, 1), + 'scale_pos_weight': (0, 10), + 'max_delta_step': (0, 10), + 'seed': (1, 50), + 'min_child_weight': (0, 100), + 'subsample': (0.1, 1), + 'eta': (0, 0.1), + 'alpha': (0, 1), + # 'lambda': (0, 100), + 'max_depth': (0, 50), + 'gamma': (0, 1), + # 'num_boost_round': (100000, 1000000), + } + + print("Eval: %s" % eval_params) + optFunc = lambda **z: evaluate_bdt(input_x, input_y, opt_params, opt_method, hist=hist, saveModel=False, input_weight=input_weight, **z) # noqa F731 + bayesopt = BayesianOptimization(optFunc, eval_params) + util = UtilityFunction(kind='ucb', + kappa=2.576, + xi=0.0, + kappa_decay=1, + kappa_decay_delay=0) + + n_iterations, n_points_per_iteration = 3, 5 + for i in range(n_iterations): + points = {} + for j in range(n_points_per_iteration): + x_probe = bayesopt.suggest(util) + u_id = get_unique_id_for_dict(x_probe) + print('x_probe (%s): %s' % (u_id, x_probe)) + points[u_id] = {'kwargs': x_probe} + ret = evaluate_bdt(input_x, input_y, opt_params, opt_method, hist=hist, saveModel=False, input_weight=input_weight, **x_probe) + print('ret :%s' % ret) + points[u_id]['ret'] = ret + bayesopt.register(x_probe, ret) + + # bayesopt.explore(explore_params) + # bayesopt.maximize(init_points=3, n_iter=5) + # bayesopt.maximize(init_points=30, n_iter=200) + # bayesopt.maximize(init_points=2, n_iter=2) + print(bayesopt.res) + p = bayesopt.max + print("Best params: %s" % p) + + +def get_bayesian_optimizer_and_util(func, opt_params): + bayesopt = BayesianOptimization(func, opt_params) + util = UtilityFunction(kind='ucb', + kappa=2.576, + xi=0.0, + 
kappa_decay=1, + kappa_decay_delay=0) + return bayesopt, util + + +def optimize1(): + data, label = load_data() + train, val = data + y_train_cat, y_val_cat = label + + opt_method = 'auc' + opt_ranges = {"subsample": [0.10131415926000001, 1], + "eta": [0.0100131415926, 0.03], + "colsample_bytree": [0.10131415926000001, 1], + "gamma": [0.00131415926, 1], + "alpha": [0.00131415926, 1], + "max_delta_step": [0.00131415926, 10], + "max_depth": [5.00131415926, 50], + "min_child_weight": [0.00131415926, 100]} + params = {'num_boost_round': 10} + + optimize_bdt1(input_x=[train, val], input_y=[y_train_cat, y_val_cat], opt_params=params, + opt_ranges=opt_ranges, opt_method=opt_method, hist=True, input_weight=None) + + +if __name__ == '__main__': + optimize1() diff --git a/main/lib/idds/tests/test_iworkflow/optimize_iworkflow.py b/main/lib/idds/tests/test_iworkflow/optimize_iworkflow.py new file mode 100644 index 00000000..52e99437 --- /dev/null +++ b/main/lib/idds/tests/test_iworkflow/optimize_iworkflow.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0OA +# +# Authors: +# - Wen Guan, , 2024 + + +""" +Test workflow. +""" + +import inspect # noqa F401 +import logging +import os # noqa F401 +import shutil # noqa F401 +import sys # noqa F401 + +# from nose.tools import assert_equal +# from idds.common.imports import get_func_name +from idds.common.utils import setup_logging, get_unique_id_for_dict + +from idds.iworkflow.workflow import Workflow # workflow +from idds.iworkflow.work import work + + +setup_logging(__name__) + + +def get_initial_parameter(): + opt_params = { + 'colsample_bytree': (0.1, 1), + 'scale_pos_weight': (0, 10), + 'max_delta_step': (0, 10), + 'seed': (1, 50), + 'min_child_weight': (0, 100), + 'subsample': (0.1, 1), + 'eta': (0, 0.1), + 'alpha': (0, 1), + # 'lambda': (0, 100), + 'max_depth': (0, 50), + 'gamma': (0, 1), + # 'num_boost_round': (100000, 1000000) + } + return opt_params + + +@work(map_results=True) +def optimize_work(opt_params, retMethod=None, hist=True, saveModel=False, input_weight=None, **kwargs): + from optimize import evaluate_bdt, load_data + + data, label = load_data() + train, val = data + y_train_cat, y_val_cat = label + input_x = [train, val] + input_y = [y_train_cat, y_val_cat] + + ret = evaluate_bdt(input_x=input_x, input_y=input_y, opt_params=opt_params, retMethod=retMethod, hist=hist, + saveModel=saveModel, input_weight=input_weight, **kwargs) + return ret + + +def optimize_workflow(): + from optimize import evaluate_bdt, get_bayesian_optimizer_and_util + + opt_method = 'auc' + params = {'num_boost_round': 1000} + + opt_params = get_initial_parameter() + + print("To optimize with parameters: %s" % opt_params) + + optFunc = lambda **z: evaluate_bdt(input_x=None, input_y=None, + opt_params=params, opt_method=opt_method, + hist=True, saveModel=False, input_weight=None, **z) # noqa F731 + bayesopt, util = get_bayesian_optimizer_and_util(optFunc, opt_params) + + n_iterations, n_points_per_iteration = 10, 20 + for i in range(n_iterations): + print("Iteration %s" % i) + points = {} + group_kwargs = [] + for j in range(n_points_per_iteration): + x_probe = bayesopt.suggest(util) + u_id = get_unique_id_for_dict(x_probe) + print('x_probe (%s): %s' % (u_id, x_probe)) + points[u_id] = {'kwargs': x_probe} + group_kwargs.append(x_probe) + + results = 
optimize_work(opt_params=params, opt_method=opt_method, hist=True, saveModel=False, input_weight=None, + retMethod=opt_method, group_kwargs=group_kwargs) + print("points: %s" % str(points)) + + for u_id in points: + points[u_id]['ret'] = results.get_result(name=None, args=points[u_id]['kwargs']) + print('ret :%s, kwargs: %s' % (points[u_id]['ret'], points[u_id]['kwargs'])) + bayesopt.register(points[u_id]['kwargs'], points[u_id]['ret']) + + print(bayesopt.res) + p = bayesopt.max + print("Best params: %s" % p) + return p + + +if __name__ == '__main__': + logging.info("start") + os.chdir(os.path.dirname(os.path.realpath(__file__))) + # wf = Workflow(func=test_workflow, service='idds', distributed=False) + # init_env = 'singularity exec /eos/user/w/wguan/idds_ml/singularity/idds_ml_al9.simg ' + init_env = 'singularity exec /afs/cern.ch/user/w/wguan/workdisk/iDDS/test/eic/idds_ml_al9.simg ' + wf = Workflow(func=optimize_workflow, service='idds', init_env=init_env) + + # wf.queue = 'BNL_OSG_2' + wf.queue = 'FUNCX_TEST' + wf.cloud = 'US' + + logging.info("prepare workflow") + wf.prepare() + logging.info("prepared workflow") + + wf.submit() diff --git a/main/lib/idds/tests/test_iworkflow/test_asyncresults.py b/main/lib/idds/tests/test_iworkflow/test_asyncresults.py new file mode 100644 index 00000000..e662d69e --- /dev/null +++ b/main/lib/idds/tests/test_iworkflow/test_asyncresults.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0OA +# +# Authors: +# - Wen Guan, , 2024 + + +import inspect # noqa F401 +import logging # noqa F401 +import os # noqa F401 +import shutil # noqa F401 +import sys # noqa F401 + +# from nose.tools import assert_equal +from idds.common.utils import setup_logging + +from idds.iworkflow.asyncresult import AsyncResult +from idds.iworkflow.workflow import Workflow # workflow + +setup_logging(__name__) +# setup_logging(__name__, loglevel='debug') +# logging.getLogger('stomp').setLevel(logging.DEBUG) + + +def test_workflow(): + print("test workflow starts") + print('idds') + print("test workflow ends") + + +if __name__ == '__main__': + + wf = Workflow(func=test_workflow, service='idds') + + workflow_context = wf._context + + logging.info("Test AsyncResult") + a_ret = AsyncResult(workflow_context, wait_num=1, timeout=30) + a_ret.subscribe() + + async_ret = AsyncResult(workflow_context, internal_id=a_ret.internal_id) + test_result = "AsyncResult test (request_id: %s)" % (workflow_context.request_id) + logging.info("AsyncResult publish: %s" % test_result) + async_ret.publish(test_result) + + ret_q = a_ret.wait_result() + logging.info("AsyncResult results: %s" % str(ret_q)) diff --git a/main/lib/idds/tests/test_iworkflow/test_imports.py b/main/lib/idds/tests/test_iworkflow/test_imports.py new file mode 100644 index 00000000..9e4b5e23 --- /dev/null +++ b/main/lib/idds/tests/test_iworkflow/test_imports.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0OA +# +# Authors: +# - Wen Guan, , 2024 + + +""" +Test workflow. 
+""" + +# import datetime +# import inspect +# import sys + +from idds.common.imports import import_func, get_func_name + + +func_name = "test_iworkflow1.py:__main__:test_workflow1" +func = import_func(func_name) +print(func) +func() + +# sys.exit(1) + +for func_name in ["test_iworkflow1.py:__main__:test_workflow2", "test_iworkflow1.py:__main__:test_workflow1"]: + func = import_func(func_name) + print(func) + func() + + func_name = get_func_name(func) + print(func_name) + + func = import_func(func_name) + print(func) + func() diff --git a/main/lib/idds/tests/test_iworkflow/test_iworkflow.py b/main/lib/idds/tests/test_iworkflow/test_iworkflow.py new file mode 100644 index 00000000..4482d873 --- /dev/null +++ b/main/lib/idds/tests/test_iworkflow/test_iworkflow.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0OA +# +# Authors: +# - Wen Guan, , 2024 + + +""" +Test workflow. +""" + +import inspect # noqa F401 +import logging +import os # noqa F401 +import shutil # noqa F401 +import sys # noqa F401 + +# from nose.tools import assert_equal +from idds.common.utils import setup_logging, run_process, json_dumps, json_loads, create_archive_file + +from idds.iworkflow.workflow import Workflow # workflow +from idds.iworkflow.work import work + + +setup_logging(__name__) + + +@work +def test_func(name): + print('test_func starts') + print(name) + print('test_func ends') + return 'test result: %s' % name + + +def test_func1(name): + print('test_func1 starts') + print(name) + print('test_func1 ends') + return 'test result: %s' % name + + +def test_workflow(): + print("test workflow starts") + ret = test_func(name='idds') + print(ret) + print("test workflow ends") + + +@work +def get_params(): + list_params = [i for i in range(10)] + return list_params + + +def test_workflow_mulitple_work(): + print("test workflow multiple work starts") + list_params = get_params() + + ret = test_func(list_params) + print(ret) + print("test workflow multiple work ends") + + +def submit_workflow(wf): + req_id = wf.submit() + print("req id: %s" % req_id) + + +def run_workflow_wrapper(wf): + cmd = wf.get_runner() + logging.info(f'To run workflow: {cmd}') + + exit_code = run_process(cmd, wait=True) + logging.info(f'Run workflow finished with exit code: {exit_code}') + return exit_code + + +def run_workflow_remote_wrapper(wf): + cmd = wf.get_runner() + logging.info('To run workflow: %s' % cmd) + + work_dir = '/tmp/idds' + shutil.rmtree(work_dir) + os.makedirs(work_dir) + os.chdir(work_dir) + logging.info("current dir: %s" % os.getcwd()) + + # print(dir(wf)) + # print(inspect.getmodule(wf)) + # print(inspect.getfile(wf)) + setup = wf.setup_source_files() + logging.info("setup: %s" % setup) + + exc_cmd = 'cd %s' % work_dir + exc_cmd += "; wget https://wguan-wisc.web.cern.ch/wguan-wisc/run_workflow_wrapper" + exc_cmd += "; chmod +x run_workflow_wrapper; bash run_workflow_wrapper %s" % cmd + logging.info("exc_cmd: %s" % exc_cmd) + exit_code = run_process(exc_cmd, wait=True) + logging.info(f'Run workflow finished with exit code: {exit_code}') + return exit_code + + +def test_create_archive_file(wf): + archive_name = wf._context.get_archive_name() + source_dir = wf._context._source_dir + logging.info("archive_name :%s, source dir: %s" % (archive_name, source_dir)) + archive_file = create_archive_file('/tmp', 
archive_name, [source_dir]) + logging.info("created archive file: %s" % archive_file) + + +if __name__ == '__main__': + logging.info("start") + os.chdir(os.path.dirname(os.path.realpath(__file__))) + # wf = Workflow(func=test_workflow, service='idds', distributed=False) + wf = Workflow(func=test_workflow, service='idds') + + # wf.queue = 'BNL_OSG_2' + wf.queue = 'FUNCX_TEST' + wf.cloud = 'US' + + wf_json = json_dumps(wf) + # print(wf_json) + wf_1 = json_loads(wf_json) + + # test_create_archive_file(wf) + + # sys.exit(0) + + logging.info("prepare workflow") + wf.prepare() + logging.info("prepared workflow") + + wf.submit() + + # logging.info("run_workflow_wrapper") + # run_workflow_wrapper(wf) + + # logging.info("run_workflow_remote_wrapper") + # run_workflow_remote_wrapper(wf) diff --git a/main/lib/idds/tests/test_iworkflow/test_iworkflow1.py b/main/lib/idds/tests/test_iworkflow/test_iworkflow1.py new file mode 100644 index 00000000..b4abb1cf --- /dev/null +++ b/main/lib/idds/tests/test_iworkflow/test_iworkflow1.py @@ -0,0 +1,136 @@ +#!/usr/bin/env python +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0OA +# +# Authors: +# - Wen Guan, , 2024 + + +""" +Test workflow. +""" + +import datetime +import inspect + +# import unittest2 as unittest +# from nose.tools import assert_equal +from idds.common.utils import setup_logging + + +from idds.iworkflow.work import work +from idds.iworkflow.workflow import workflow + +# from idds.iworkflow.utils import perform_workflow, run_workflow + +from idds.common.utils import json_dumps, run_process, encode_base64 + + +setup_logging(__name__) + + +@work +def test_func(name): + print('test_func starts') + print(name) + print('test_func ends') + return 'test result: %s' % name + + +def test_workflow(): + print("test workflow starts") + ret = test_func(name='idds') + print(ret) + print("test workflow ends") + + +@work +def get_params(): + list_params = [i for i in range(10)] + return list_params + + +def test_workflow_mulitple_work(): + print("test workflow multiple work starts") + list_params = get_params() + + ret = test_func(list_params) + print(ret) + print("test workflow multiple work ends") + + +def submit_workflow(workflow): + req_id = workflow.submit() + print("req id: %s" % req_id) + + +def perform_workflow_wrapper(workflow): + # ret = perform_workflow(workflow) + # print(ret) + setup = workflow.setup() + cmd = setup + ";" + return cmd + + +def run_workflow_wrapper(workflow): + setup = workflow.setup() + + cmd = setup + "; perform_workflow --workflow " + encode_base64(json_dumps(workflow)) + print(f'To run workflow: {cmd}') + exit_code = run_process(cmd) + print(f'Run workflow finished with exit code: {exit_code}') + return exit_code + + +@workflow +def test_workflow1(): + print("test workflow starts") + # ret = test_func(name='idds') + # print(ret) + print("test workflow ends") + + +def test_workflow2(): + print("test workflow2 starts") + # ret = test_func(name='idds') + # print(ret) + print("test workflow2 ends") + + +if __name__ == '__main__': + print("datetime.datetime object") + f = datetime.datetime(2019, 10, 10, 10, 10) + print(dir(f)) + print(inspect.getmodule(f)) + # print(inspect.getfile(f)) + # print(inspect.getmodule(f).__name__) + # print(inspect.getmodule(f).__file__) + # print(inspect.signature(f)) + + print("datetime.datetime function") + f = datetime.datetime + 
print(dir(f)) + print(f.__module__) + print(f.__name__) + print(inspect.getmodule(f)) + print(inspect.getfile(f)) + print(inspect.getmodule(f).__name__) + print(inspect.getmodule(f).__file__) + # print(inspect.signature(f)) + + print("test_workflow function") + f = test_workflow + print(dir(f)) + print(f.__module__) + print(f.__name__) + print(inspect.getmodule(f)) + print(inspect.getfile(f)) + print(inspect.getmodule(f).__name__) + print(inspect.getmodule(f).__file__) + print(inspect.signature(f)) + + print("test_workflow1") + test_workflow1() diff --git a/main/lib/idds/tests/test_iworkflow/test_iworkflow_mul.py b/main/lib/idds/tests/test_iworkflow/test_iworkflow_mul.py new file mode 100644 index 00000000..122728dc --- /dev/null +++ b/main/lib/idds/tests/test_iworkflow/test_iworkflow_mul.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0OA +# +# Authors: +# - Wen Guan, , 2024 + + +""" +Test workflow. +""" + +import inspect # noqa F401 +import logging +import os # noqa F401 +import shutil # noqa F401 +import sys # noqa F401 + +# from nose.tools import assert_equal +from idds.common.utils import setup_logging, run_process, json_dumps, json_loads, create_archive_file + +from idds.iworkflow.workflow import Workflow # workflow +from idds.iworkflow.work import work + + +setup_logging(__name__) + + +@work +def test_func(name, name1=None): + print('test_func starts') + print(name) + print(name1) + print('test_func ends') + return 'test result: %s, %s' % (name, name1) + + +@work(map_results=True) +def test_func1(name, name1=None): + print('test_func1 starts') + print(name) + print(name1) + print('test_func1 ends') + return 'test result1: %s, %s' % (name, name1) + + +def test_workflow(): + print("test workflow starts") + group_kwargs = [{'name': 1, 'name1': 2}, + {'name': 3, 'name1': 4}] + ret = test_func(name='idds', group_kwargs=group_kwargs) + print(ret) + + ret = test_func1(name='idds', group_kwargs=group_kwargs) + print(ret) + print("test workflow ends") + + +@work +def get_params(): + list_params = [i for i in range(10)] + return list_params + + +def test_workflow_mulitple_work(): + print("test workflow multiple work starts") + list_params = get_params() + + ret = test_func(list_params) + print(ret) + print("test workflow multiple work ends") + + +def submit_workflow(wf): + req_id = wf.submit() + print("req id: %s" % req_id) + + +def run_workflow_wrapper(wf): + cmd = wf.get_runner() + logging.info(f'To run workflow: {cmd}') + + exit_code = run_process(cmd, wait=True) + logging.info(f'Run workflow finished with exit code: {exit_code}') + return exit_code + + +def run_workflow_remote_wrapper(wf): + cmd = wf.get_runner() + logging.info('To run workflow: %s' % cmd) + + work_dir = '/tmp/idds' + shutil.rmtree(work_dir) + os.makedirs(work_dir) + os.chdir(work_dir) + logging.info("current dir: %s" % os.getcwd()) + + # print(dir(wf)) + # print(inspect.getmodule(wf)) + # print(inspect.getfile(wf)) + setup = wf.setup_source_files() + logging.info("setup: %s" % setup) + + exc_cmd = 'cd %s' % work_dir + exc_cmd += "; wget https://wguan-wisc.web.cern.ch/wguan-wisc/run_workflow_wrapper" + exc_cmd += "; chmod +x run_workflow_wrapper; bash run_workflow_wrapper %s" % cmd + logging.info("exc_cmd: %s" % exc_cmd) + exit_code = run_process(exc_cmd, wait=True) + logging.info(f'Run workflow finished 
with exit code: {exit_code}') + return exit_code + + +def test_create_archive_file(wf): + archive_name = wf._context.get_archive_name() + source_dir = wf._context._source_dir + logging.info("archive_name :%s, source dir: %s" % (archive_name, source_dir)) + archive_file = create_archive_file('/tmp', archive_name, [source_dir]) + logging.info("created archive file: %s" % archive_file) + + +if __name__ == '__main__': + logging.info("start") + os.chdir(os.path.dirname(os.path.realpath(__file__))) + # wf = Workflow(func=test_workflow, service='idds', distributed=False) + wf = Workflow(func=test_workflow, service='idds') + + # wf.queue = 'BNL_OSG_2' + wf.queue = 'FUNCX_TEST' + wf.cloud = 'US' + + wf_json = json_dumps(wf) + # print(wf_json) + wf_1 = json_loads(wf_json) + + # test_create_archive_file(wf) + + # sys.exit(0) + + logging.info("prepare workflow") + wf.prepare() + logging.info("prepared workflow") + + wf.submit() + + # logging.info("run_workflow_wrapper") + # run_workflow_wrapper(wf) + + # logging.info("run_workflow_remote_wrapper") + # run_workflow_remote_wrapper(wf) diff --git a/main/lib/idds/tests/test_iworkflow/test_plot.py b/main/lib/idds/tests/test_iworkflow/test_plot.py new file mode 100644 index 00000000..c5560c62 --- /dev/null +++ b/main/lib/idds/tests/test_iworkflow/test_plot.py @@ -0,0 +1,70 @@ +import matplotlib.pyplot as plt + + +def get_param_result_map(points_list, results_list): + ret = {} + for points, results in zip(points_list, results_list): + for key in points: + if key in results: + ret[key] = {'kwargs': points[key]['kwargs'], 'ret': results[key]} + else: + print("key %s not in results") + return ret + + +def draw_plots(ret, xkey): + # print(ret) + x_y_map = {} + for key in ret: + y = ret[key]['ret'] + kwargs = ret[key]['kwargs'] + # print(kwargs) + x = kwargs[xkey] + x_y_map[x] = y + + # print(x_y_map) + + x_list = list(x_y_map.keys()) + x_list.sort() + y_list = [x_y_map[x] for x in x_list] + + # plot the data + fig = plt.figure() + ax = fig.add_subplot(1, 1, 1) + ax.plot(x_list, y_list, color='tab:blue') + + # set the limits + ax.set_xlim([0, 0.1]) + ax.set_ylim([0, 1]) + + ax.set_xlabel(xkey) + ax.set_ylabel('AUC') + + ax.set_title('AUC with %s' % xkey) + + # display the plot + plt.show() + + +points1 = {'f0715a5cac2032cfc5e3cc9f3e3357813a565cab': {'kwargs': {'alpha': 0.7370499824147523, 'colsample_bytree': 0.7052001671810584, 'eta': 0.061539107095042206, 'gamma': 0.3768972446409322, 'max_delta_step': 5.559055959338833, 'max_depth': 34.455202948905296, 'min_child_weight': 27.251073306284823, 'scale_pos_weight': 8.254507265058038, 'seed': 5.731980219696928, 'subsample': 0.376002012129681}}, '8a191dd9c4b4625ee8e4b30cac582d932ec68011': {'kwargs': {'alpha': 0.5796381619950467, 'colsample_bytree': 0.3603926073873194, 'eta': 0.08047465769383089, 'gamma': 0.9008199093397357, 'max_delta_step': 1.920273836975528, 'max_depth': 4.779348283851704, 'min_child_weight': 82.8789096714284, 'scale_pos_weight': 9.928966624328437, 'seed': 32.229443478406495, 'subsample': 0.315629175902784}}, '53777c8f07b2e83bbb0890adb9f0f56adc1b6c4d': {'kwargs': {'alpha': 0.5996223120415098, 'colsample_bytree': 0.23516105244338611, 'eta': 0.0033226331556943213, 'gamma': 0.9922203935242493, 'max_delta_step': 1.7465652061444947, 'max_depth': 15.723780695793998, 'min_child_weight': 88.42520948613331, 'scale_pos_weight': 4.174890056932079, 'seed': 1.4249320110647643, 'subsample': 0.27191037032614307}}, 'd808f47c5e63ef8f82497ae6aad05ffba20193db': {'kwargs': {'alpha': 0.6939188950911951, 
'colsample_bytree': 0.7111335846873444, 'eta': 0.007635640601150418, 'gamma': 0.12176011951112908, 'max_delta_step': 8.862138412738078, 'max_depth': 48.139115205423444, 'min_child_weight': 59.58231509839007, 'scale_pos_weight': 0.8539533058791215, 'seed': 14.732825128475884, 'subsample': 0.2217087973782852}}, '312d603aa80fe5054fac7749449cff5801c7f7b8': {'kwargs': {'alpha': 0.48096280852811757, 'colsample_bytree': 0.35348412204311486, 'eta': 0.026102697075971827, 'gamma': 0.3260982955502004, 'max_delta_step': 0.9610973179531157, 'max_depth': 20.470106656255936, 'min_child_weight': 72.30625266803096, 'scale_pos_weight': 1.9498819699168735, 'seed': 47.44404774296532, 'subsample': 0.46716798126713954}}, '053d6eaf336584152287c19cb6d80dd7e988dc5f': {'kwargs': {'alpha': 0.054033996041948695, 'colsample_bytree': 0.9782353872575908, 'eta': 0.01859000904795547, 'gamma': 0.004300658642705013, 'max_delta_step': 7.359231033895142, 'max_depth': 4.910291573708891, 'min_child_weight': 57.125856279726904, 'scale_pos_weight': 9.245448333326683, 'seed': 40.42671404631153, 'subsample': 0.5592057464441867}}, '47fcec81d02c935e577d68d645a851411646c850': {'kwargs': {'alpha': 0.29069710338203425, 'colsample_bytree': 0.6682782283104473, 'eta': 0.01682694789712872, 'gamma': 0.3047694836865329, 'max_delta_step': 5.176571289757744, 'max_depth': 12.21515248071105, 'min_child_weight': 17.607898836542056, 'scale_pos_weight': 0.10758844277118595, 'seed': 12.233262459745259, 'subsample': 0.1504837565060477}}, '0debbccbbad1898b0ace911abc7b194a46d133ae': {'kwargs': {'alpha': 0.2755720157659465, 'colsample_bytree': 0.8103778255898365, 'eta': 0.07465707041461515, 'gamma': 0.9840186804101336, 'max_delta_step': 9.848043855766889, 'max_depth': 44.82707132597007, 'min_child_weight': 16.746069592164893, 'scale_pos_weight': 8.467612139929722, 'seed': 23.119090106253118, 'subsample': 0.2558395325351913}}, 'c84803cf3a931663af3a86797483676e881b891e': {'kwargs': {'alpha': 0.10002005750895016, 'colsample_bytree': 0.2952004213836158, 'eta': 0.058566469003877464, 'gamma': 0.6553658771012805, 'max_delta_step': 4.357798491613893, 'max_depth': 23.77837224633501, 'min_child_weight': 9.554246397836664, 'scale_pos_weight': 7.569666848173825, 'seed': 17.273448881690502, 'subsample': 0.7481747369549834}}, 'b5c10d0cc3c56f06e4284b34101eabf7ac9cc1a0': {'kwargs': {'alpha': 0.7581625345381989, 'colsample_bytree': 0.6288180373263546, 'eta': 0.0753961020307722, 'gamma': 0.5293168324266648, 'max_delta_step': 3.8720508345941065, 'max_depth': 11.155505790856868, 'min_child_weight': 63.711879483371746, 'scale_pos_weight': 9.730873199926606, 'seed': 30.389636733887784, 'subsample': 0.6623138628962174}}, '25c787b0f29835df85766b7f6f37c58a8afb722a': {'kwargs': {'alpha': 0.3227375050238718, 'colsample_bytree': 0.7806423105910397, 'eta': 0.06892746201974502, 'gamma': 0.1246605468633577, 'max_delta_step': 1.564983697804062, 'max_depth': 22.590472089219443, 'min_child_weight': 33.832950599220055, 'scale_pos_weight': 9.293021509583099, 'seed': 28.077731789771665, 'subsample': 0.9649082999834623}}, 'a4c62a005e50b30563c8136b65b7ac4ab35ef248': {'kwargs': {'alpha': 0.5410005589757522, 'colsample_bytree': 0.1591859943945069, 'eta': 0.0827309462667341, 'gamma': 0.5968317340758771, 'max_delta_step': 8.2359689740328, 'max_depth': 27.395473217885137, 'min_child_weight': 84.62647470946146, 'scale_pos_weight': 4.902995991585466, 'seed': 20.41216463813878, 'subsample': 0.19914362945322717}}, '207f893919165287721a352ad8985f6bbaf2587e': {'kwargs': {'alpha': 0.48815370426766747, 
'colsample_bytree': 0.6313417505984559, 'eta': 0.03625857372140906, 'gamma': 0.5941348348536, 'max_delta_step': 8.30415838532181, 'max_depth': 39.90128227863364, 'min_child_weight': 20.906948897302247, 'scale_pos_weight': 9.178432257293121, 'seed': 38.91404909117018, 'subsample': 0.4817191405100111}}, '9a0defa184465033ca1d7dd97c845af05d60af62': {'kwargs': {'alpha': 0.27937613955144236, 'colsample_bytree': 0.34957037690598636, 'eta': 0.06668600700091035, 'gamma': 0.20330959648954072, 'max_delta_step': 4.305258190092057, 'max_depth': 15.094946280049882, 'min_child_weight': 25.082827785015503, 'scale_pos_weight': 5.419478346554111, 'seed': 17.518303380357107, 'subsample': 0.5167279400928934}}, '509ce5e826fbea8815baa534657940c910cfc55f': {'kwargs': {'alpha': 0.0717960988827191, 'colsample_bytree': 0.9926020256373844, 'eta': 0.014105256518578269, 'gamma': 0.7410022267038714, 'max_delta_step': 1.1627709318345791, 'max_depth': 28.891784313398833, 'min_child_weight': 30.40921594272542, 'scale_pos_weight': 3.2135244845279765, 'seed': 18.067389079678517, 'subsample': 0.9024437144084017}}, 'faadeebe4f5f480b0b433538591e6c26ef72b6ae': {'kwargs': {'alpha': 0.4422868803573755, 'colsample_bytree': 0.9491471435222574, 'eta': 0.06663549994152372, 'gamma': 0.5276256553916654, 'max_delta_step': 5.094908198006554, 'max_depth': 46.14858637049848, 'min_child_weight': 23.601495233420767, 'scale_pos_weight': 7.2124962492250235, 'seed': 48.01599680319991, 'subsample': 0.9699822533939738}}, '8350a99b70489adeefeda23e2219f07fd4be87d5': {'kwargs': {'alpha': 0.1653502027975029, 'colsample_bytree': 0.807084495324573, 'eta': 0.02126694837679881, 'gamma': 0.1475901533664299, 'max_delta_step': 6.239721969537473, 'max_depth': 14.044864339684693, 'min_child_weight': 70.51447946940795, 'scale_pos_weight': 8.92584040580361, 'seed': 35.16151295640776, 'subsample': 0.419909640995761}}, 'ad09c5fc5b4c983923ee1067bb93c6aa73bfe04f': {'kwargs': {'alpha': 0.9697869355335356, 'colsample_bytree': 0.6210305954871612, 'eta': 0.05328069143710448, 'gamma': 0.9129411876267608, 'max_delta_step': 0.7700718122531136, 'max_depth': 32.61481152259293, 'min_child_weight': 68.09955158441387, 'scale_pos_weight': 9.284108348629527, 'seed': 26.312279555846022, 'subsample': 0.9147227373822518}}, 'a8aae82962012c58efc4201a76375c94be729294': {'kwargs': {'alpha': 0.5954531194968008, 'colsample_bytree': 0.9529414821398111, 'eta': 0.018380590700565004, 'gamma': 0.7748281465026012, 'max_delta_step': 1.3073360370678921, 'max_depth': 37.624755306746685, 'min_child_weight': 63.87154457082049, 'scale_pos_weight': 4.031115553919237, 'seed': 41.4365373661436, 'subsample': 0.48050134500529507}}, 'cea3a26032b72a654ee143e328f09fcd3bc5d6ff': {'kwargs': {'alpha': 0.5040644966135356, 'colsample_bytree': 0.32922216176752694, 'eta': 0.06107475218204287, 'gamma': 0.5491973347269214, 'max_delta_step': 5.299663024545489, 'max_depth': 43.91575293473623, 'min_child_weight': 15.00083630214336, 'scale_pos_weight': 8.560480553359616, 'seed': 13.431798160907203, 'subsample': 0.5343724921934745}}} # noqa E501 + +points2 = {'be14c556f8a841b40db33a74c478eb3f4cb23755': {'kwargs': {'alpha': 0.4580153943125277, 'colsample_bytree': 0.1, 'eta': 0.1, 'gamma': 0.3045870705362288, 'max_delta_step': 10.0, 'max_depth': 40.1474578788322, 'min_child_weight': 96.97116605726112, 'scale_pos_weight': 10.0, 'seed': 47.2831737268388, 'subsample': 0.1}}, 'df0ce2ee19c418ffc8e830a91510f069d5a875b1': {'kwargs': {'alpha': 0.04215955459538778, 'colsample_bytree': 0.29086541978112146, 'eta': 
0.05866151696155121, 'gamma': 0.8311794942302044, 'max_delta_step': 1.60720335305918, 'max_depth': 19.213935683299923, 'min_child_weight': 72.4247572192607, 'scale_pos_weight': 2.1627615057750713, 'seed': 46.30461310936564, 'subsample': 0.9452809888159761}}, '197ac6f87dffedd43678ba8527eb56becd309c79': {'kwargs': {'alpha': 0.5364680195881149, 'colsample_bytree': 0.45189088748933604, 'eta': 0.012371305875746585, 'gamma': 0.9064360497565197, 'max_delta_step': 1.4126470367805288, 'max_depth': 38.486995676085996, 'min_child_weight': 62.464195059355134, 'scale_pos_weight': 1.9583422410872753, 'seed': 39.67358138575, 'subsample': 0.4310335654442331}}, '7170eed1bd5f4a099b4c544cee79ed1283d9ff37': {'kwargs': {'alpha': 0.2654385662062296, 'colsample_bytree': 0.1, 'eta': 0.1, 'gamma': 1.0, 'max_delta_step': 10.0, 'max_depth': 40.74300186240512, 'min_child_weight': 96.8398377569323, 'scale_pos_weight': 10.0, 'seed': 46.807896117574124, 'subsample': 0.1}}, '8abb9e4eb40f5a694f5b618c608302b882b498e6': {'kwargs': {'alpha': 1.0, 'colsample_bytree': 0.1, 'eta': 0.1, 'gamma': 1.0, 'max_delta_step': 10.0, 'max_depth': 40.850853201498104, 'min_child_weight': 97.0549046383368, 'scale_pos_weight': 10.0, 'seed': 47.25385945582012, 'subsample': 0.9210320936625356}}, 'a811bda1e0b5c7cff610fa70fb85521ea0c869a5': {'kwargs': {'alpha': 0.6726491408436711, 'colsample_bytree': 0.9269069267585942, 'eta': 0.06833338001009504, 'gamma': 0.30622643752425194, 'max_delta_step': 7.3139395959433084, 'max_depth': 38.107722806641064, 'min_child_weight': 21.020434049795245, 'scale_pos_weight': 8.625353218323646, 'seed': 41.23483252115022, 'subsample': 0.9271536994585713}}, '79536f8106c0b53a718c3f9019b78a64a8aa4ccc': {'kwargs': {'alpha': 1.0, 'colsample_bytree': 0.1, 'eta': 0.1, 'gamma': 0.6426685109205953, 'max_delta_step': 10.0, 'max_depth': 41.810419215310425, 'min_child_weight': 96.37283611262006, 'scale_pos_weight': 10.0, 'seed': 46.25045893436094, 'subsample': 0.1}}, 'cd87458fe8bc94d8dfed999174faac1972617c58': {'kwargs': {'alpha': 0.3646361026289764, 'colsample_bytree': 0.5304744918363348, 'eta': 0.011874812292597915, 'gamma': 0.4459527826112035, 'max_delta_step': 1.8149394340185943, 'max_depth': 20.995979979503147, 'min_child_weight': 71.56134824779393, 'scale_pos_weight': 1.4992809317315503, 'seed': 45.15888883082988, 'subsample': 0.1540915641713254}}, 'a20174ec210131e97137710284a3fdb4047e4b0a': {'kwargs': {'alpha': 0.7234865078104759, 'colsample_bytree': 0.2556252549399418, 'eta': 0.041324249574747986, 'gamma': 0.42396037128920105, 'max_delta_step': 1.5349116722277856, 'max_depth': 28.720096898384245, 'min_child_weight': 30.259144124422598, 'scale_pos_weight': 4.910845925218909, 'seed': 18.85620687924454, 'subsample': 0.8305267316952343}}, '93790e7e12ac7c1ce44903c825981ef82ea9a674': {'kwargs': {'alpha': 0.028260010405939395, 'colsample_bytree': 0.9902894226408323, 'eta': 0.0477759322455881, 'gamma': 0.1493077456736832, 'max_delta_step': 1.3681665841543689, 'max_depth': 38.96334182187581, 'min_child_weight': 62.97524482974987, 'scale_pos_weight': 5.7359104666825145, 'seed': 41.76572530322487, 'subsample': 0.6560279991568655}}, 'e194cdbd7208e1fd8b2a874515739ffc26681bea': {'kwargs': {'alpha': 1.0, 'colsample_bytree': 0.1, 'eta': 0.1, 'gamma': 0.7367438875465829, 'max_delta_step': 10.0, 'max_depth': 41.161156638200104, 'min_child_weight': 96.58333136689605, 'scale_pos_weight': 10.0, 'seed': 47.51635068885047, 'subsample': 0.1}}, '4cde5b8e54c55a7420a9ca03df7ea33c91527745': {'kwargs': {'alpha': 0.8412613710127163, 
'colsample_bytree': 0.6317122838792428, 'eta': 0.0023314148623097774, 'gamma': 0.45397306658996595, 'max_delta_step': 0.8277097324783322, 'max_depth': 29.64834888950255, 'min_child_weight': 29.40298565101134, 'scale_pos_weight': 1.6418341544823978, 'seed': 17.546786397998623, 'subsample': 0.10107797057345437}}, 'db014e29b17436fa766ace84c2c4591e8533865d': {'kwargs': {'alpha': 1.0, 'colsample_bytree': 0.1, 'eta': 0.1, 'gamma': 0.9494125025648193, 'max_delta_step': 10.0, 'max_depth': 40.431636883912695, 'min_child_weight': 98.00228095586948, 'scale_pos_weight': 10.0, 'seed': 47.81606400706705, 'subsample': 0.1}}, 'e39ad808f6980dba7d320b2d1e02854a653563f2': {'kwargs': {'alpha': 0.6226353183543241, 'colsample_bytree': 0.9633858983987651, 'eta': 0.025507650364184987, 'gamma': 0.2526722336657782, 'max_delta_step': 1.1115816204056783, 'max_depth': 20.28987252695812, 'min_child_weight': 70.93580123132998, 'scale_pos_weight': 1.6865368744793607, 'seed': 47.4112687993577, 'subsample': 0.47621247638017916}}, '93a2d3e39bed31fc7df4aab484c8f1402833f29e': {'kwargs': {'alpha': 0.4654579770875542, 'colsample_bytree': 0.8217708017908828, 'eta': 0.02541732515276022, 'gamma': 0.7139940787731547, 'max_delta_step': 2.9004034766349163, 'max_depth': 9.651050115755544, 'min_child_weight': 64.96257180377857, 'scale_pos_weight': 9.748574606801421, 'seed': 29.883339473112517, 'subsample': 0.5097000232806876}}, '53ab4cbd8b772ce31be09fa2bfb156e30d4a91d5': {'kwargs': {'alpha': 0.24269021594049067, 'colsample_bytree': 0.8344480786258023, 'eta': 0.04079745399825966, 'gamma': 0.003773313638273046, 'max_delta_step': 0.8188512349914712, 'max_depth': 26.48844983712253, 'min_child_weight': 30.40442688574331, 'scale_pos_weight': 1.6429352431303523, 'seed': 16.838427821607635, 'subsample': 0.2810149933254557}}, '04cbb5b4ebcd9790c86197badfcf5fe97d67be78': {'kwargs': {'alpha': 1.0, 'colsample_bytree': 0.1, 'eta': 0.1, 'gamma': 1.0, 'max_delta_step': 10.0, 'max_depth': 41.26269634587534, 'min_child_weight': 96.44318057112507, 'scale_pos_weight': 10.0, 'seed': 46.354154492638365, 'subsample': 0.1}}, '6e2c4d6e09570df33f815067d20ee53e5db5ad24': {'kwargs': {'alpha': 0.6933458135606636, 'colsample_bytree': 0.1, 'eta': 0.1, 'gamma': 0.3892709873027238, 'max_delta_step': 10.0, 'max_depth': 42.98588175186529, 'min_child_weight': 96.6647021087748, 'scale_pos_weight': 7.796687756769591, 'seed': 47.11218295852123, 'subsample': 0.1}}, '56fc5113ba134afe98fee106db67bd2a64156e00': {'kwargs': {'alpha': 0.8479768497476806, 'colsample_bytree': 0.8111384034507712, 'eta': 0.0005347138616303759, 'gamma': 0.008119648332206353, 'max_delta_step': 9.804490997203043, 'max_depth': 40.258030915405335, 'min_child_weight': 19.78532778420805, 'scale_pos_weight': 9.858434166546141, 'seed': 37.51085532619348, 'subsample': 0.898888119554901}}, '43fd803c635e8b8a6bc6d1d561d2de2179f68363': {'kwargs': {'alpha': 1.0, 'colsample_bytree': 0.1, 'eta': 0.1, 'gamma': 0.9530669314561849, 'max_delta_step': 10.0, 'max_depth': 40.40496213650497, 'min_child_weight': 98.01332801885265, 'scale_pos_weight': 10.0, 'seed': 47.7993906601342, 'subsample': 0.1}}} # noqa E501 + +points3 = {'fafe03cd6b4a8cd9b9867447317e069cc63b540f': {'kwargs': {'alpha': 0.7027101973022741, 'colsample_bytree': 0.3969726123308247, 'eta': 0.08804899477662154, 'gamma': 0.7332885869178056, 'max_delta_step': 1.8927487961368616, 'max_depth': 36.55438115222037, 'min_child_weight': 63.232517532502676, 'scale_pos_weight': 3.96685994356639, 'seed': 42.41559299421199, 'subsample': 0.6362395159118978}}, 
'a0c475873c1e3bd84cd7a1b119142eafc1a1bd25': {'kwargs': {'alpha': 0.5943310092891746, 'colsample_bytree': 0.21945087421840515, 'eta': 0.06907806427410738, 'gamma': 0.6601073560483439, 'max_delta_step': 0.9344235563770931, 'max_depth': 38.3309892123419, 'min_child_weight': 63.81068842474452, 'scale_pos_weight': 2.23891566439122, 'seed': 41.97994334838864, 'subsample': 0.36545190311446785}}, '8f85065dcfe1b34af5f12edc593be90b0df1a034': {'kwargs': {'alpha': 0.07017523858568098, 'colsample_bytree': 0.25245430296074767, 'eta': 0.07372738651245907, 'gamma': 0.4248833432428326, 'max_delta_step': 7.882712105372102, 'max_depth': 39.46168797937724, 'min_child_weight': 23.007131093319046, 'scale_pos_weight': 9.162678541265091, 'seed': 41.45050611040056, 'subsample': 0.26802697905939493}}, 'b01bb823a4ed6d6b55a0570861c7402a8e69a952': {'kwargs': {'alpha': 0.2778321144179313, 'colsample_bytree': 0.24783129396130135, 'eta': 0.03912534818797658, 'gamma': 0.4663767659308974, 'max_delta_step': 0.5969513938896576, 'max_depth': 20.116687479189903, 'min_child_weight': 72.65557882251787, 'scale_pos_weight': 3.506923818134493, 'seed': 46.62543265063431, 'subsample': 0.8192148073532302}}, '867222ddfe76c2e4939c1a5d07dcdd8af3f61263': {'kwargs': {'alpha': 0.28252777374080884, 'colsample_bytree': 0.8400576756601256, 'eta': 0.04327203262030846, 'gamma': 0.8624792015142486, 'max_delta_step': 2.5038134011126356, 'max_depth': 24.87300616817943, 'min_child_weight': 30.633048799756867, 'scale_pos_weight': 1.5793817027389545, 'seed': 16.292432861827407, 'subsample': 0.8607089372051691}}, 'c3f545d1ac41263ac374e72daa586f9f0ba488d5': {'kwargs': {'alpha': 0.9614060003583368, 'colsample_bytree': 0.5470816113462734, 'eta': 0.021622141248675775, 'gamma': 0.8768751749518214, 'max_delta_step': 3.5447970634971515, 'max_depth': 20.325632117761916, 'min_child_weight': 72.35420840242924, 'scale_pos_weight': 1.8340941559179336, 'seed': 47.016284327493636, 'subsample': 0.3722865968933107}}, '3951b831115c88df886c2785e62191990ecbbf79': {'kwargs': {'alpha': 0.7185573738874034, 'colsample_bytree': 0.587776160190954, 'eta': 0.01824607769713198, 'gamma': 0.5980782554862835, 'max_delta_step': 5.241386352818127, 'max_depth': 14.866031458707162, 'min_child_weight': 25.61654255058955, 'scale_pos_weight': 5.009978453726078, 'seed': 16.47855957741612, 'subsample': 0.9467042719983234}}, '8ce5f1b484a9f96f7587de7548bf7e70c3315ba4': {'kwargs': {'alpha': 0.4252815329760634, 'colsample_bytree': 0.6050223806476221, 'eta': 0.06279134027079947, 'gamma': 0.7372940595627078, 'max_delta_step': 0.43167579746224205, 'max_depth': 37.72059597659628, 'min_child_weight': 62.67432855341192, 'scale_pos_weight': 4.917341562415212, 'seed': 41.69041245162549, 'subsample': 0.40295378190317477}}, '14419a3779d1ff8fcf6873558993ebbedc89d987': {'kwargs': {'alpha': 0.5660924216310128, 'colsample_bytree': 0.5288758554890975, 'eta': 0.09152795934206077, 'gamma': 0.8567557168223149, 'max_delta_step': 6.549446679928241, 'max_depth': 14.983435119406053, 'min_child_weight': 71.85532948956917, 'scale_pos_weight': 9.695867277789734, 'seed': 34.98380568191385, 'subsample': 0.5537510395389743}}, 'a2842976f367d7190590f9ff984fb189e39bf77c': {'kwargs': {'alpha': 0.9515503282889703, 'colsample_bytree': 0.14846546461213583, 'eta': 0.049796246903521274, 'gamma': 0.43034129450534, 'max_delta_step': 4.903202451873084, 'max_depth': 22.888382020684478, 'min_child_weight': 10.276935516154795, 'scale_pos_weight': 8.145439797744409, 'seed': 18.147311251474655, 'subsample': 0.6367058712157992}}, 
'82c1f9fff1847eeebd5790de973a25414ed03c4e': {'kwargs': {'alpha': 0.3583571832282799, 'colsample_bytree': 0.7888902217417609, 'eta': 0.09206654626929545, 'gamma': 0.185718219320687, 'max_delta_step': 0.8332918151235735, 'max_depth': 37.31113537069529, 'min_child_weight': 65.18148397663374, 'scale_pos_weight': 3.9426434129898036, 'seed': 42.17671974123183, 'subsample': 0.487027713543176}}, '4964c55ff939b41150ea433d34ba54c46b73bc6a': {'kwargs': {'alpha': 0.23406035151804216, 'colsample_bytree': 0.579042534809806, 'eta': 0.028600368999834338, 'gamma': 0.23818222147295043, 'max_delta_step': 0.8490976659073024, 'max_depth': 19.211553258551916, 'min_child_weight': 70.89679426305557, 'scale_pos_weight': 0.41080827258102803, 'seed': 47.50122115129428, 'subsample': 0.9416281815903255}}, '0faee152d960ab76c2c9d9d8f9a54456c3101ec3': {'kwargs': {'alpha': 0.5599072754380644, 'colsample_bytree': 0.2165004538719889, 'eta': 0.04162758668598478, 'gamma': 0.4356332097414869, 'max_delta_step': 1.0941546100733701, 'max_depth': 38.01570722270386, 'min_child_weight': 62.452288988763314, 'scale_pos_weight': 4.346240319712552, 'seed': 41.39389069656801, 'subsample': 0.463114251878039}}, 'cd7137c0bf4d3f827c7421f542dbbc64916541a8': {'kwargs': {'alpha': 0.8869573684575884, 'colsample_bytree': 0.6902984861982361, 'eta': 0.03080627863752672, 'gamma': 0.9245409023620079, 'max_delta_step': 5.892836209964155, 'max_depth': 12.896642419419463, 'min_child_weight': 17.93589479730323, 'scale_pos_weight': 0.647945136780913, 'seed': 13.443694351976626, 'subsample': 0.5701860704990126}}, '213e70ecf8287e91c235e86144d9eb9b134c6d53': {'kwargs': {'alpha': 0.8284209401027416, 'colsample_bytree': 0.817074151477903, 'eta': 0.07200443927829558, 'gamma': 0.5114740131616691, 'max_delta_step': 0.26033962409835576, 'max_depth': 36.752851129556504, 'min_child_weight': 64.28107275799505, 'scale_pos_weight': 4.448210925228673, 'seed': 43.01065716874698, 'subsample': 0.32210006283513226}}, 'c87cc0b69045040ed69eaa8123edc6ccb2ab9e4d': {'kwargs': {'alpha': 0.1633970438356287, 'colsample_bytree': 0.1519336707210463, 'eta': 0.033360719369632655, 'gamma': 0.42081849258348936, 'max_delta_step': 3.418449947145543, 'max_depth': 11.147224425644653, 'min_child_weight': 64.83620593285208, 'scale_pos_weight': 9.390045456970249, 'seed': 29.95364916707937, 'subsample': 0.10606655548871441}}, '1c6d03bd8b06eb173e2967c4272f3a67148e78bd': {'kwargs': {'alpha': 0.04208119478518213, 'colsample_bytree': 0.6570249548171683, 'eta': 0.04458158794488585, 'gamma': 0.16696235169779328, 'max_delta_step': 2.3252023655167853, 'max_depth': 9.14278697997893, 'min_child_weight': 65.52474029656636, 'scale_pos_weight': 8.080576249652733, 'seed': 29.632865507785233, 'subsample': 0.4752579477073231}}, 'e9772925ef0e76e4253a6558b4f6a21e9dd65240': {'kwargs': {'alpha': 0.5821223589718156, 'colsample_bytree': 0.21054124739571017, 'eta': 0.014024016305762978, 'gamma': 0.0691875275212791, 'max_delta_step': 1.8485411019875952, 'max_depth': 22.46872238443997, 'min_child_weight': 71.13399159458304, 'scale_pos_weight': 2.2801930139645554, 'seed': 46.14912747503271, 'subsample': 0.38832342636151695}}, '7204d46c1d79b806849dd776856560443488cd62': {'kwargs': {'alpha': 0.48918192575218433, 'colsample_bytree': 0.8802771638062074, 'eta': 0.024379240601836907, 'gamma': 0.5435958667485749, 'max_delta_step': 1.5250511650679621, 'max_depth': 18.943633514783965, 'min_child_weight': 70.66155353203935, 'scale_pos_weight': 2.313949295550981, 'seed': 46.37771705734808, 'subsample': 0.7585227251888592}}, 
'fd13af0cbfed477c2aca976f87a4ffe5cba01f6c': {'kwargs': {'alpha': 0.9170671839476585, 'colsample_bytree': 0.7048639567433913, 'eta': 0.09641342058779678, 'gamma': 0.567108807844195, 'max_delta_step': 0.49189167285642377, 'max_depth': 36.928677638727045, 'min_child_weight': 63.46554744335657, 'scale_pos_weight': 5.09560090663176, 'seed': 39.75902566415049, 'subsample': 0.6584570383288748}}} # noqa E501 + +results1 = {'a4c62a005e50b30563c8136b65b7ac4ab35ef248': 0.9553849421389543, 'd808f47c5e63ef8f82497ae6aad05ffba20193db': 0.9557551651587148, '25c787b0f29835df85766b7f6f37c58a8afb722a': 0.957533009248183, 'b5c10d0cc3c56f06e4284b34101eabf7ac9cc1a0': 0.9578631340829472, '0debbccbbad1898b0ace911abc7b194a46d133ae': 0.9538168806809304, '053d6eaf336584152287c19cb6d80dd7e988dc5f': 0.9560617184487816, '53777c8f07b2e83bbb0890adb9f0f56adc1b6c4d': 0.9290285374565601, '312d603aa80fe5054fac7749449cff5801c7f7b8': 0.9594610405946045, '47fcec81d02c935e577d68d645a851411646c850': 0.9565828380935457, '8a191dd9c4b4625ee8e4b30cac582d932ec68011': 0.9575967322650636, 'c84803cf3a931663af3a86797483676e881b891e': 0.956408341729566, '509ce5e826fbea8815baa534657940c910cfc55f': 0.9587363779941975, 'cea3a26032b72a654ee143e328f09fcd3bc5d6ff': 0.9557460386707255, '8350a99b70489adeefeda23e2219f07fd4be87d5': 0.9588596656538051, '207f893919165287721a352ad8985f6bbaf2587e': 0.9577829647794619, 'f0715a5cac2032cfc5e3cc9f3e3357813a565cab': 0.9563017535411852, 'ad09c5fc5b4c983923ee1067bb93c6aa73bfe04f': 0.9580936589632294, 'faadeebe4f5f480b0b433538591e6c26ef72b6ae': 0.9572100370199619, '9a0defa184465033ca1d7dd97c845af05d60af62': 0.9572507853650232, 'a8aae82962012c58efc4201a76375c94be729294': 0.959368931578009} # noqa E501 + +results2 = {'8abb9e4eb40f5a694f5b618c608302b882b498e6': 0.9504149959571236, '43fd803c635e8b8a6bc6d1d561d2de2179f68363': 0.9501685618912445, '197ac6f87dffedd43678ba8527eb56becd309c79': 0.9587137960686796, '93790e7e12ac7c1ce44903c825981ef82ea9a674': 0.9585775993766628, '4cde5b8e54c55a7420a9ca03df7ea33c91527745': 0.9193196721132836, 'a20174ec210131e97137710284a3fdb4047e4b0a': 0.9565113100545148, '56fc5113ba134afe98fee106db67bd2a64156e00': 0.9433124026809275, 'df0ce2ee19c418ffc8e830a91510f069d5a875b1': 0.9583433931694213, '6e2c4d6e09570df33f815067d20ee53e5db5ad24': 0.9505786533126307, '93a2d3e39bed31fc7df4aab484c8f1402833f29e': 0.9592043711597578, 'e194cdbd7208e1fd8b2a874515739ffc26681bea': 0.9477700833492515, '04cbb5b4ebcd9790c86197badfcf5fe97d67be78': 0.9077305509106499, 'a811bda1e0b5c7cff610fa70fb85521ea0c869a5': 0.9573711086781767, 'be14c556f8a841b40db33a74c478eb3f4cb23755': 0.9501653108089769, '53ab4cbd8b772ce31be09fa2bfb156e30d4a91d5': 0.9583137810005292, '7170eed1bd5f4a099b4c544cee79ed1283d9ff37': 0.9077188438003386, 'e39ad808f6980dba7d320b2d1e02854a653563f2': 0.9600402060563064, '79536f8106c0b53a718c3f9019b78a64a8aa4ccc': 0.9076761633013631, 'db014e29b17436fa766ace84c2c4591e8533865d': 0.950061683019851, 'cd87458fe8bc94d8dfed999174faac1972617c58': 0.9573071068198805} # noqa E501 + +results3 = {'0faee152d960ab76c2c9d9d8f9a54456c3101ec3': 0.9567203015545008, 'cd7137c0bf4d3f827c7421f542dbbc64916541a8': 0.9597165061337541, '867222ddfe76c2e4939c1a5d07dcdd8af3f61263': 0.9584214994975399, 'c87cc0b69045040ed69eaa8123edc6ccb2ab9e4d': 0.9225358555029012, '1c6d03bd8b06eb173e2967c4272f3a67148e78bd': 0.9586664071251745, '8f85065dcfe1b34af5f12edc593be90b0df1a034': 0.95246185209079, 'b01bb823a4ed6d6b55a0570861c7402a8e69a952': 0.9582324898466975, 'fd13af0cbfed477c2aca976f87a4ffe5cba01f6c': 0.9575296696358804, 
'a0c475873c1e3bd84cd7a1b119142eafc1a1bd25': 0.957667651871562, '4964c55ff939b41150ea433d34ba54c46b73bc6a': 0.960055699094351, '14419a3779d1ff8fcf6873558993ebbedc89d987': 0.957283816090198, 'a2842976f367d7190590f9ff984fb189e39bf77c': 0.9531734601763102, '8ce5f1b484a9f96f7587de7548bf7e70c3315ba4': 0.9584228116592668, '7204d46c1d79b806849dd776856560443488cd62': 0.9599626593846182, 'c3f545d1ac41263ac374e72daa586f9f0ba488d5': 0.9599830012747, 'e9772925ef0e76e4253a6558b4f6a21e9dd65240': 0.9254054556476099, '3951b831115c88df886c2785e62191990ecbbf79': 0.9584402145777451, '82c1f9fff1847eeebd5790de973a25414ed03c4e': 0.957751489250696, 'fafe03cd6b4a8cd9b9867447317e069cc63b540f': 0.9573437915235644, '213e70ecf8287e91c235e86144d9eb9b134c6d53': 0.9581928371313033} # noqa E501 + + +if __name__ == '__main__': + points_list = [points1, points2, points3] + results_list = [results1, results2, results3] + ret = get_param_result_map(points_list, results_list) + + xkey = 'eta' + draw_plots(ret, xkey) + xkey = 'alpha' + draw_plots(ret, xkey) diff --git a/main/lib/idds/tests/test_wrapper.py b/main/lib/idds/tests/test_wrapper.py new file mode 100644 index 00000000..f59f39c6 --- /dev/null +++ b/main/lib/idds/tests/test_wrapper.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0OA +# +# Authors: +# - Wen Guan, , 2024 + + +""" +Test workflow. +""" + +import functools +import time + + +def slow_down(func): + """Sleep 1 second before calling the function""" + @functools.wraps(func) + def wrapper_slow_down(*args, **kwargs): + time.sleep(1) + return func(*args, **kwargs) + return wrapper_slow_down + + +@slow_down +def countdown(from_number): + if from_number < 1: + print("Liftoff!") + else: + print(from_number) + countdown(from_number - 1) + + +def repeat(num_times): + def decorator_repeat(func): + @functools.wraps(func) + def wrapper_repeat(*args, **kwargs): + for _ in range(num_times): + value = func(*args, **kwargs) + return value + return wrapper_repeat + return decorator_repeat + + +@repeat(num_times=4) +def greet(name): + print(f"Hello {name}") + + +def repeat1(_func=None, *, num_times=2): + print(_func) + + def decorator_repeat(func): + @functools.wraps(func) + def wrapper_repeat(*args, **kwargs): + for _ in range(num_times): + value = func(*args, **kwargs) + return value + return wrapper_repeat + + if _func is None: + return decorator_repeat + else: + return decorator_repeat(_func) + + +@repeat1 +def greet1(name): + print(f"Hello {name}") + + +@repeat1(num_times=4) +def greet2(name): + print(f"Hello {name}") + + +if __name__ == "__main__": + # countdown(10) + greet1('kitty') + greet2('kitty2') + + greet2('kitty3') diff --git a/main/tools/container/singularity/commands b/main/tools/container/singularity/commands index 75d9f95c..46abe6b9 100644 --- a/main/tools/container/singularity/commands +++ b/main/tools/container/singularity/commands @@ -1,6 +1,20 @@ + +# cp -r /eos/user/w/wguan/idds_ml/singularity/* /opt/singularity + cd /opt/singularity/ singularity build --sandbox idds_ng idds_nevergrad.def singularity exec idds_ng /bin/hostname singularity exec idds_ng python /opt/hyperparameteropt_nevergrad.py singularity build idds_nevergrad.simg idds_ng/ singularity exec idds_nevergrad.simg python /opt/hyperparameteropt_nevergrad.py + + + +cd /opt/singularity/ +singularity build --sandbox idds_ml 
idds_ml.def +singularity exec --writable --bind /source:/dest idds_ml /dest/myexec +singularity shell --writable --bind /source:/dest idds_ml + +singularity exec --bind ml:/ml idds_ml python ml/optimize.py + +singularity build idds_ml.simg idds_ml diff --git a/main/tools/container/singularity/idds_ml.def b/main/tools/container/singularity/idds_ml.def new file mode 100644 index 00000000..b0156185 --- /dev/null +++ b/main/tools/container/singularity/idds_ml.def @@ -0,0 +1,50 @@ +Bootstrap: docker +From: centos:7 + +%files + hyperparameteropt_nevergrad.py /opt + bdt_0409 /opt + +%post + yum update -q -y\ + && yum install -q -y wget make git gcc openssl-devel bzip2-devel libffi-devel \ + && cd /usr/src \ + && wget https://www.python.org/ftp/python/3.7.4/Python-3.7.4.tgz \ + && tar xzf Python-3.7.4.tgz \ + && cd Python-3.7.4 \ + && ./configure --enable-optimizations \ + && make altinstall \ + && rm -rf /usr/src/Python-3.7.4.tgz \ + && yum clean all \ + && echo "set python3.7 as default" \ + && alternatives --install /usr/bin/python python /usr/bin/python2 50 \ + && alternatives --install /usr/bin/python python /usr/local/bin/python3.7 70 \ + && alternatives --set python /usr/local/bin/python3.7 \ + && echo "symlink pip" \ + && ln -s /usr/local/bin/pip3.7 /usr/bin/pip \ + && pip install --no-cache-dir --upgrade pip pipenv setuptools wheel \ + && ln -s /usr/local/bin/pipenv /usr/bin/pipenv \ + && echo "set the locale" \ + && localedef --quiet -c -i en_US -f UTF-8 en_US.UTF-8 + + pip install --upgrade pip + pip install nevergrad + pip install theano keras h5py matplotlib tabulate + pip install bayesian-optimization + pip install xgboost + pip install lightgbm + + +%environment + # export LC_ALL=C + # export PATH=/usr/games:$PATH + export LANG=en_US.UTF-8 + export LANGUAGE=en_US:en + export LC_ALL=en_US.UTF-8 + +%labels + Maintainer iDDS_HPO_Nevergrad(wen.guan@cern.ch) + Version v1.0 + +%runscript + echo "iDDS Nevergrad hyper parameter optimization plugin" diff --git a/main/tools/container/singularity/idds_ml_al9.def b/main/tools/container/singularity/idds_ml_al9.def new file mode 100644 index 00000000..df672843 --- /dev/null +++ b/main/tools/container/singularity/idds_ml_al9.def @@ -0,0 +1,31 @@ +Bootstrap: docker +From: centos:7 + +%files + hyperparameteropt_nevergrad.py /opt + bdt_0409 /opt + +%post + yum update -q -y\ + && yum install -q -y wget make git gcc openssl-devel bzip2-devel libffi-devel which pip + + ln -s /usr/bin/python3 /usr/bin/python + + pip install --upgrade pip + pip install nevergrad + pip install theano keras h5py matplotlib tabulate + pip install bayesian-optimization + pip install xgboost + pip install lightgbm + + +%environment + # export LC_ALL=C + # export PATH=/usr/games:$PATH + +%labels + Maintainer iDDS_HPO_Nevergrad(wen.guan@cern.ch) + Version v1.0 + +%runscript + echo "iDDS ML hyper parameter optimization plugin" diff --git a/main/tools/container/singularity/ml/optimize.py b/main/tools/container/singularity/ml/optimize.py new file mode 100644 index 00000000..1286f35e --- /dev/null +++ b/main/tools/container/singularity/ml/optimize.py @@ -0,0 +1,425 @@ + +import json +import hashlib +import os +import sys +import time +import traceback + +import numpy as np # noqa F401 + +from sklearn.metrics import roc_curve, auc, confusion_matrix, brier_score_loss, mean_squared_error, log_loss, roc_auc_score # noqa F401 +from sklearn import metrics # noqa F401 +from sklearn.preprocessing import label_binarize # noqa F401 +from sklearn.neighbors import KNeighborsClassifier # noqa 
F401 + +from sklearn.model_selection import cross_val_score # noqa F401 +from sklearn.model_selection import KFold # noqa F401 +from sklearn.preprocessing import LabelEncoder # noqa F401 +from sklearn import model_selection # noqa F401 +from sklearn.metrics import roc_curve, auc, confusion_matrix # noqa F401 +from sklearn.model_selection import train_test_split # noqa F401 + +from sklearn.preprocessing import StandardScaler # noqa F401 +from sklearn.model_selection import StratifiedKFold # noqa F401 + +from tabulate import tabulate # noqa F401 + +import xgboost as xgb + +from bayes_opt import BayesianOptimization, UtilityFunction + + +def load_data(workdir='/opt/bdt_0409/ttHyyML_had_single', analysis_type='had'): + currentDir = os.path.dirname(os.path.realpath(__file__)) + + print("CurrentDir: %s" % currentDir) + print("WorkDir: %s" % workdir) + workdir = os.path.abspath(workdir) + print("Absolute WorkDir: %s" % workdir) + os.chdir(workdir) + sys.path.insert(0, workdir) + sys.argv = ['test'] + + if analysis_type == 'had': + from load_data_real_auto import load_data_real_hadronic + data, label = load_data_real_hadronic() + elif analysis_type == 'lep': + from load_data_real_auto import load_data_real_leptonic + data, label = load_data_real_leptonic() + sys.path.remove(workdir) + os.chdir(currentDir) + return data, label + + +def get_param(params, name, default): + if params and name in params: + return params[name] + return default + + +def getAUC(y_test, score, y_val_weight=None): + # fpr, tpr, _ = roc_curve(y_test, score, sample_weight=y_val_weight) + # roc_auc = auc(fpr, tpr, True) + print(y_test.shape) + print(score.shape) + if y_val_weight: + print(y_val_weight.shape) + return roc_auc_score(y_test, score, sample_weight=y_val_weight) + + +def getBrierScore(y_test, score): + return 1-brier_score_loss(y_test, score) + + +def evaluateBrierScore(y_pred, data): + label = data.get_label() + return 'brierLoss', 1-brier_score_loss(y_pred, label) + + +def getRMSE(y_test, score): + return mean_squared_error(y_test, score) ** 0.5 + + +def getLogLoss(y_test, score): + return log_loss(y_test, score) + + +def xgb_callback_save_model(model_name, period=1000): + def callback(env): + try: + bst, i, _ = env.model, env.iteration, env.end_iteration + if (i % period == 0): + bst.save_model(model_name) + except Exception: + print(traceback.format_exc()) + return callback + + +def train_bdt(input_x, input_y, params=None, retMethod=None, hist=True, saveModel=False, input_weight=None): + + train, val = input_x + y_train_cat, y_val_cat = input_y + if input_weight: + y_train_weight, y_val_weight = input_weight + else: + y_train_weight = None + y_val_weight = None + + train = train.reshape((train.shape[0], -1)) + val = val.reshape((val.shape[0], -1)) + + dTrain = xgb.DMatrix(train, label=y_train_cat, weight=y_train_weight) + dVal = xgb.DMatrix(val, label=y_val_cat, weight=y_val_weight) + + # train model + print('Train model.') + + # param = {'max_depth':10, 'eta':0.1, 'min_child_weight': 60, 'silent':1, 'objective':'binary:logistic', 'eval_metric': ['logloss', 'auc' ]} + # param = {'max_depth':10, 'eta':0.1, 'min_child_weight': 1, 'silent':1, 'objective':'rank:pairwise', 'eval_metric': ['auc','logloss']} + # def_params = {'max_depth':10, 'eta':0.005, 'min_child_weight': 15, 'silent':1, 'objective':'binary:logistic', 'eval_metric': ['auc','logloss']} + # def_params = {'colsample_bytree': 0.7, 'silent': 0, 'eval_metric': ['auc', 'logloss'], 'scale_pos_weight': 1.4, 'max_delta_step': 0, 'nthread': 8, 
'min_child_weight': 160, 'subsample': 0.8, 'eta': 0.04, 'objective': 'binary:logistic', 'alpha': 0.1, 'lambda': 10, 'seed': 10, 'max_depth': 10, 'gamma': 0.03, 'booster': 'gbtree'} + # def_params = {'colsample_bytree': 0.7, 'silent': 0, 'eval_metric': ['auc', 'logloss'], 'scale_pos_weight': 1.4, 'max_delta_step': 0, 'nthread': 8, 'min_child_weight': 160, 'subsample': 0.8, 'eta': 0.04, 'objective': 'binary:logistic', 'alpha': 0.1, 'lambda': 10, 'seed': 10, 'max_depth': 10, u'gamma': 0.5, 'booster': 'gbtree'} + # def_params = {'eval_metric': ['logloss', 'auc'], 'scale_pos_weight': 5.1067081406104631, 'max_delta_step': 4.6914331907848759, 'seed': 10, 'alpha': 0.1, 'booster': 'gbtree', 'colsample_bytree': 0.64067554676687111, 'nthread': 4, 'min_child_weight': 58, 'subsample': 0.76111573761360196, 'eta': 0.1966696564443787, 'objective': 'binary:logistic', 'max_depth': 10, 'gamma': 0.74055129530012553} + def_params = {} + if not params: + params = {} + if 'num_boost_round' not in params: + params['num_boost_round'] = 100000 + if 'objective' not in params: + params['objective'] = 'binary:logistic' + + for key in def_params: + if key not in params: + params[key] = def_params[key] + + if 'silent' not in params: + params['silent'] = 0 + + if hist: + params['tree_method'] = 'hist' + params['booster'] = 'gbtree' + params['grow_policy'] = 'lossguide' + params['nthread'] = 4 + params['booster'] = 'gbtree' + + start = time.time() + evallist = [(dTrain, 'train'), (dVal, 'eval')] + evals_result = {} + + try: + save_model_callback = xgb_callback_save_model(params['model']+"temp" if params and 'model' in params else 'models/default_bdt_temp.h5') + # with early stop + if not saveModel: + # bst = xgb.train(params, dTrain, params['num_boost_round'], evals=evallist, early_stopping_rounds=10, evals_result=evals_result, verbose_eval=False, callbacks=[save_model_callback]) + bst = xgb.train(params, dTrain, params['num_boost_round'], evals=evallist, early_stopping_rounds=10, evals_result=evals_result, verbose_eval=True) + else: + bst = xgb.train(params, dTrain, params['num_boost_round'], evals=evallist, early_stopping_rounds=10, evals_result=evals_result, callbacks=[save_model_callback]) + # bst = xgb.train(params, dTrain, params['num_boost_round'], evals=evallist, early_stopping_rounds=10, evals_result=evals_result, feval=evaluateBrierScore, callbacks=[save_model_callback]) + # bst = xgb.train(params, dTrain, params['num_boost_round'], evals=evallist, evals_result=evals_result, callbacks=[save_model_callback]) + except KeyboardInterrupt: + print('Finishing on SIGINT.') + print("CPU Training time: %s" % (time.time() - start)) + + # test model + print('Test model.') + + score = bst.predict(dVal) + rmse = None + logloss = None + if 'num_class' in params and params['num_class'] == 3: + y_val_cat_binary = label_binarize(y_val_cat, classes=[0, 1, 2]) + aucValue = getAUC(y_val_cat_binary[:, 0], score[:, 0]) + aucValue1 = getAUC(y_val_cat_binary[:, 1], score[:, 1]) + aucValue2 = getAUC(y_val_cat_binary[:, 2], score[:, 2]) + + print("AUC: %s, %s, %s" % (aucValue, aucValue1, aucValue2)) + + bslValue = getBrierScore(y_val_cat_binary[:, 0], score[:, 0]) + bslValue1 = getBrierScore(y_val_cat_binary[:, 1], score[:, 1]) + bslValue2 = getBrierScore(y_val_cat_binary[:, 2], score[:, 2]) + + print("BrierScoreLoss: %s, %s, %s" % (bslValue, bslValue1, bslValue2)) + + rmse = getRMSE(y_val_cat_binary[:, 0], score[:, 0]) + logloss = getLogLoss(y_val_cat_binary[:, 0], score[:, 0]) + else: + aucValue = getAUC(y_val_cat, score) + bslValue = 
getBrierScore(y_val_cat, score) + rmse = None + logloss = None + auc = None # noqa F811 + if 'auc' in evals_result['eval']: + auc = evals_result['eval']['auc'][-1] + if 'rmse' in evals_result['eval']: + rmse = evals_result['eval']['rmse'][-1] + if 'logloss' in evals_result['eval']: + logloss = evals_result['eval']['logloss'][-1] + print("params: %s #, Val AUC: %s, BrierScoreLoss: %s, xgboost rmse: %s, xgboost logloss: %s, xgboost auc: %s" % (params, aucValue, bslValue, rmse, logloss, auc)) + rmse = getRMSE(y_val_cat, score) + logloss = getLogLoss(y_val_cat, score) + print("params: %s #, Val AUC: %s, BrierScoreLoss: %s, Val rmse: %s, Val logloss: %s" % (params, aucValue, bslValue, rmse, logloss)) + + # bst.save_model(params['model'] if params and 'model' in params else 'models/default_bdt.h5') + + print(bst.get_fscore()) + # print(bst.get_score()) + + try: + pass + # from matplotlib import pyplot + # print("Plot importance") + # xgb.plot_importance(bst) + # pyplot.savefig('plots/' + params['name'] if 'name' in params else 'default' + '_feature_importance.png') + # pyplot.savefig('plots/' + params['name'] if 'name' in params else 'default' + '_feature_importance.eps') + except Exception: + print(traceback.format_exc()) + + try: + history = {'loss': evals_result['train']['logloss'], 'val_loss': evals_result['eval']['logloss'], + 'acc': evals_result['train']['auc'], 'val_acc': evals_result['eval']['auc']} + except Exception: + print(traceback.format_exc()) + history = {} + + if retMethod: + if retMethod == 'auc': + return aucValue + if retMethod == 'rmse': + return rmse + if retMethod == 'brier': + return bslValue + if retMethod == 'logloss': + return logloss + return score, history + + +def evaluate_bdt(input_x, input_y, opt_params, retMethod=None, hist=True, saveModel=False, input_weight=None, **kwargs): + params = kwargs + if not params: + params = {} + if params and 'max_depth' in params: + params['max_depth'] = int(params['max_depth']) + if params and 'num_boost_round' in params: + params['num_boost_round'] = int(params['num_boost_round']) + if params and 'seed' in params: + params['seed'] = int(params['seed']) + if params and 'max_bin' in params: + params['max_bin'] = int(params['max_bin']) + # params[''] = int(params['']) + + if opt_params: + for opt in opt_params: + if opt not in params: + params[opt] = opt_params[opt] + + if retMethod and retMethod == 'auc': + params['eval_metric'] = ['rmse', 'logloss', 'auc'] + elif retMethod and retMethod == 'logloss': + params['eval_metric'] = ['rmse', 'auc', 'logloss'] + elif retMethod and retMethod == 'rmse': + params['eval_metric'] = ['logloss', 'auc', 'rmse'] + elif retMethod: + params['eval_metric'] = [retMethod] + + print(params) + auc = train_bdt(input_x, input_y, params=params, retMethod=retMethod, hist=hist, saveModel=saveModel, input_weight=input_weight) # noqa F811 + print("params: %s, ret: %s" % (params, auc)) + return auc + + +def optimize_bdt(input_x, input_y, opt_params, opt_method='auc', opt_ranges=None, hist=True, input_weight=None): + eval_params = { + 'colsample_bytree': (0.1, 1), + 'scale_pos_weight': (0, 10), + 'max_delta_step': (0, 10), + 'seed': (1, 50), + 'min_child_weight': (0, 100), + 'subsample': (0.1, 1), + 'eta': (0, 0.1), + 'alpha': (0, 1), + # 'lambda': (0, 100), + 'max_depth': (0, 50), + 'gamma': (0, 1), + # 'num_boost_round': (100000, 1000000), + } + + explore_params1 = { # noqa F841 + # 'eta': [0.001, 0.002, 0.003, 0.004, 0.005, 0.006, 0.007, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08], + 'eta': [0.001, 0.004, 0.007, 0.03, 
0.05, 0.07], + # 'scale_pos_weight': [1, 2, 3, 5, 7, 8], + 'colsample_bytree': [0.67, 0.79, 0.76, 0.5, 0.4, 0.85], + 'scale_pos_weight': [4.9, 2.8, 2.1, 4.7, 1.7, 4], + 'max_delta_step': [5, 2.2, 1.67, 2.53, 8.8, 8], + 'seed': [10, 20, 30, 40, 50, 60], + 'min_child_weight': [50, 74, 53, 14, 45, 30], + 'subsample': [0.78, 0.82, 0.6, 0.87, 0.5, 0.7], + 'alpha': [0.1, 0.2, 0.3, 0.05, 0.15, 0.25], + # 'lambda': [50], + 'max_depth': [6, 7, 10, 14, 19, 28], + 'gamma': [0.47, 0.19, 0.33, 0.43, 0.5, 0.76], + # 'num_boost_round': [1000000] + } + + explore_params = { # noqa F841 + 'eta': [0.004, 0.03], + # 'scale_pos_weight': [3, 7], + 'colsample_bytree': [0.67, 0.4], + 'scale_pos_weight': [4.9, 1.7], + 'max_delta_step': [2.2, 8], + 'seed': [10, 50], + 'min_child_weight': [50, 30], + 'subsample': [0.78, 0.5], + 'alpha': [0.2, 0.25], + # 'lambda': [50], + 'max_depth': [10, 28], + 'gamma': [0.47, 0.76], + # 'num_boost_round': [1000000] + } + + print("Eval: %s" % eval_params) + optFunc = lambda **z: evaluate_bdt(input_x, input_y, opt_params, opt_method, hist=hist, saveModel=False, input_weight=input_weight, **z) # noqa F731 + bayesopt = BayesianOptimization(optFunc, eval_params) + # bayesopt.explore(explore_params) + bayesopt.maximize(init_points=3, n_iter=5) + # bayesopt.maximize(init_points=30, n_iter=200) + # bayesopt.maximize(init_points=2, n_iter=2) + # print(bayesopt.res) + p = bayesopt.max + print("Best params: %s" % p) + + +def optimize(): + data, label = load_data() + train, val = data + y_train_cat, y_val_cat = label + + opt_method = 'auc' + opt_ranges = {"subsample": [0.10131415926000001, 1], + "eta": [0.0100131415926, 0.03], + "colsample_bytree": [0.10131415926000001, 1], + "gamma": [0.00131415926, 1], + "alpha": [0.00131415926, 1], + "max_delta_step": [0.00131415926, 10], + "max_depth": [5.00131415926, 50], + "min_child_weight": [0.00131415926, 100]} + params = {'num_boost_round': 10} + + optimize_bdt(input_x=[train, val], input_y=[y_train_cat, y_val_cat], opt_params=params, + opt_ranges=opt_ranges, opt_method=opt_method, hist=True, input_weight=None) + + +def get_unique_id_for_dict(dict_): + ret = hashlib.sha1(json.dumps(dict_, sort_keys=True).encode()).hexdigest() + return ret + + +def optimize_bdt1(input_x, input_y, opt_params, opt_method='auc', opt_ranges=None, hist=True, input_weight=None): + eval_params = { + 'colsample_bytree': (0.1, 1), + 'scale_pos_weight': (0, 10), + 'max_delta_step': (0, 10), + 'seed': (1, 50), + 'min_child_weight': (0, 100), + 'subsample': (0.1, 1), + 'eta': (0, 0.1), + 'alpha': (0, 1), + # 'lambda': (0, 100), + 'max_depth': (0, 50), + 'gamma': (0, 1), + # 'num_boost_round': (100000, 1000000), + } + + print("Eval: %s" % eval_params) + optFunc = lambda **z: evaluate_bdt(input_x, input_y, opt_params, opt_method, hist=hist, saveModel=False, input_weight=input_weight, **z) # noqa F731 + bayesopt = BayesianOptimization(optFunc, eval_params) + util = UtilityFunction(kind='ucb', + kappa=2.576, + xi=0.0, + kappa_decay=1, + kappa_decay_delay=0) + + n_iterations, n_points_per_iteration = 3, 5 + for i in range(n_iterations): + points = {} + for j in range(n_points_per_iteration): + x_probe = bayesopt.suggest(util) + u_id = get_unique_id_for_dict(x_probe) + print('x_probe (%s): %s' % (u_id, x_probe)) + points[u_id] = {'kwargs': x_probe} + ret = evaluate_bdt(input_x, input_y, opt_params, retMethod=opt_method, hist=hist, saveModel=False, input_weight=input_weight, **x_probe) + print('ret :%s' % ret) + points[u_id]['ret'] = ret + bayesopt.register(x_probe, ret) + + # 
bayesopt.explore(explore_params) + # bayesopt.maximize(init_points=3, n_iter=5) + # bayesopt.maximize(init_points=30, n_iter=200) + # bayesopt.maximize(init_points=2, n_iter=2) + print(bayesopt.res) + p = bayesopt.max + print("Best params: %s" % p) + + +def optimize1(): + data, label = load_data() + train, val = data + y_train_cat, y_val_cat = label + + opt_method = 'auc' + opt_ranges = {"subsample": [0.10131415926000001, 1], + "eta": [0.0100131415926, 0.03], + "colsample_bytree": [0.10131415926000001, 1], + "gamma": [0.00131415926, 1], + "alpha": [0.00131415926, 1], + "max_delta_step": [0.00131415926, 10], + "max_depth": [5.00131415926, 50], + "min_child_weight": [0.00131415926, 100]} + params = {'num_boost_round': 10} + + optimize_bdt1(input_x=[train, val], input_y=[y_train_cat, y_val_cat], opt_params=params, + opt_ranges=opt_ranges, opt_method=opt_method, hist=True, input_weight=None) + + +if __name__ == '__main__': + optimize1() diff --git a/main/tools/env/install_idds.sh b/main/tools/env/install_idds.sh index eab6057b..1033399a 100644 --- a/main/tools/env/install_idds.sh +++ b/main/tools/env/install_idds.sh @@ -9,3 +9,13 @@ # - Wen Guan, , 2019 python setup.py install --old-and-unmanageable --force + +bash workflow/tools/make/make.sh + +echo cp workflow/bin/run_workflow_wrapper ~/www/wiscgroup/ +cp workflow/bin/run_workflow_wrapper ~/www/wiscgroup/ + +echo scp workflow/bin/run_workflow_wrapper root@ai-idds-04:/data/iddssv1/srv/var/trf/user/ +scp workflow/bin/run_workflow_wrapper root@ai-idds-04:/data/iddssv1/srv/var/trf/user/ + +rm -fr workflow/bin/run_workflow_wrapper diff --git a/main/tools/env/setup_dev.sh b/main/tools/env/setup_dev.sh index 6ef6db8b..3fc2f5a1 100644 --- a/main/tools/env/setup_dev.sh +++ b/main/tools/env/setup_dev.sh @@ -20,7 +20,8 @@ echo 'Root dir: ' $RootDir export IDDS_HOME=$RootDir export ALEMBIC_CONFIG=${IDDS_HOME}/etc/idds/alembic.ini -source /afs/cern.ch/user/w/wguan/workdisk/conda/setup.sh +# source /afs/cern.ch/user/w/wguan/workdisk/conda/setup.sh +source /afs/cern.ch/user/w/wguan/workdisk/conda/setup_mini.sh conda activate $CondaDir #export PYTHONPATH=${IDDS_HOME}/lib:$PYTHONPATH diff --git a/main/tools/env/setup_panda.sh b/main/tools/env/setup_panda.sh index 37ef38ca..2f8c9a9d 100644 --- a/main/tools/env/setup_panda.sh +++ b/main/tools/env/setup_panda.sh @@ -69,4 +69,22 @@ else export PANDA_SYS=/afs/cern.ch/user/w/wguan/workdisk/iDDS/.conda/iDDS/ # export PANDA_CONFIG_ROOT=/afs/cern.ch/user/w/wguan/workdisk/iDDS/main/etc/panda/ export PANDA_CONFIG_ROOT=~/.panda/ + + # export IDDS_HOST=https://aipanda015.cern.ch:443/idds + + # dev + # export IDDS_HOST=https://aipanda104.cern.ch:443/idds + + # doma + export IDDS_HOST=https://aipanda105.cern.ch:443/idds + + export IDDS_BROKERS=atlas-test-mb.cern.ch:61013 + export IDDS_BROKER_DESTINATION=/topic/doma.idds + export IDDS_BROKER_USERNAME=domaidds + export IDDS_BROKER_PASSWORD=1d25yeft6krJ1HFH + export IDDS_BROKER_TIMEOUT=360 + + PANDA_QUEUE=BNL_OSG_2 + PANDA_WORKING_GROUP=EIC + PANDA_VO=wlcg fi diff --git a/monitor/data/conf.js b/monitor/data/conf.js index 190646a6..bc520561 100644 --- a/monitor/data/conf.js +++ b/monitor/data/conf.js @@ -1,9 +1,9 @@ var appConfig = { - 'iddsAPI_request': "https://lxplus811.cern.ch:443/idds/monitor_request/null/null", - 'iddsAPI_transform': "https://lxplus811.cern.ch:443/idds/monitor_transform/null/null", - 'iddsAPI_processing': "https://lxplus811.cern.ch:443/idds/monitor_processing/null/null", - 'iddsAPI_request_detail': 
"https://lxplus811.cern.ch:443/idds/monitor/null/null/true/false/false", - 'iddsAPI_transform_detail': "https://lxplus811.cern.ch:443/idds/monitor/null/null/false/true/false", - 'iddsAPI_processing_detail': "https://lxplus811.cern.ch:443/idds/monitor/null/null/false/false/true" + 'iddsAPI_request': "https://lxplus927.cern.ch:443/idds/monitor_request/null/null", + 'iddsAPI_transform': "https://lxplus927.cern.ch:443/idds/monitor_transform/null/null", + 'iddsAPI_processing': "https://lxplus927.cern.ch:443/idds/monitor_processing/null/null", + 'iddsAPI_request_detail': "https://lxplus927.cern.ch:443/idds/monitor/null/null/true/false/false", + 'iddsAPI_transform_detail': "https://lxplus927.cern.ch:443/idds/monitor/null/null/false/true/false", + 'iddsAPI_processing_detail': "https://lxplus927.cern.ch:443/idds/monitor/null/null/false/false/true" } diff --git a/workflow/bin/run_workflow b/workflow/bin/run_workflow new file mode 100644 index 00000000..08e8538f --- /dev/null +++ b/workflow/bin/run_workflow @@ -0,0 +1,189 @@ +#!/usr/bin/env python3 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0OA +# +# Authors: +# - Wen Guan, , 2024 + + +""" +Run workflow. +""" + +from __future__ import print_function + +import argparse +import argcomplete +# import json +import logging +import os +import sys +import time +import traceback + +from idds.common.utils import json_dumps, json_loads, setup_logging, decode_base64 +# from idds.common.utils import merge_dict +from idds.iworkflow.version import release_version +from idds.iworkflow.workflow import Workflow +from idds.iworkflow.work import Work + + +setup_logging(__name__, stream=sys.stdout) + + +def get_context_args(context, original_args, update_args): + func_name, args, kwargs, group_kwargs = None, None, None, None + if original_args: + original_args = json_loads(original_args) + func_name, args, kwargs, group_kwargs = original_args + if update_args: + if update_args == "${IN/L}": + logging.info("update_args == original ${IN/L}, is not set") + else: + try: + update_args = json_loads(update_args) + update_kwargs = update_args + if update_kwargs and isinstance(update_kwargs, dict): + # kwargs = merge_dict(kwargs, update_kwargs) + kwargs.update(update_kwargs) + except Exception as ex: + logging.error("Failed to update kwargs: %s" % ex) + return context, func_name, args, kwargs, group_kwargs, update_args + + +def run_workflow(context, original_args, update_args): + context, func_name, args, kwargs, group_kwargs, update_args = get_context_args(context, original_args, update_args) + logging.info("context: %s" % context) + logging.info("func_name: %s" % func_name) + logging.info("args: %s" % str(args)) + logging.info("kwargs: %s" % kwargs) + logging.info("group_kwargs: %s" % group_kwargs) + + context.initialize() + context.setup_source_files() + + workflow = Workflow(func=func_name, args=args, kwargs=kwargs, group_kwargs=group_kwargs, update_kwargs=update_args, context=context) + logging.info("workflow: %s" % workflow) + with workflow: + ret = workflow.run() + logging.info("run workflow result: %s" % str(ret)) + return 0 + + +def run_work(context, original_args, update_args): + context, func_name, args, kwargs, group_kwargs, update_args = get_context_args(context, original_args, update_args) + logging.info("context: %s" % context) + logging.info("func_name: %s" % func_name) + 
logging.info("args: %s" % str(args)) + logging.info("kwargs: %s" % kwargs) + logging.info("group_kwargs: %s" % group_kwargs) + + context.initialize() + context.setup_source_files() + + work = Work(func=func_name, args=args, kwargs=kwargs, group_kwargs=group_kwargs, update_kwargs=update_args, context=context) + logging.info("work: %s" % work) + ret = work.run() + logging.info("run work result: %s" % str(ret)) + return 0 + + +def run_iworkflow(args): + if args.context: + context = decode_base64(args.context) + context = json_loads(context) + # logging.info(context) + # context = str(binascii.unhexlify(args.context).decode()) + else: + context = None + if args.original_args: + original_args = decode_base64(args.original_args) + # logging.info(original_args) + # orginal_args = str(binascii.unhexlify(args.original_args).decode()) + else: + original_args = None + if args.update_args: + # logging.info(args.update_args) + # update_args = str(binascii.unhexlify(args.update_args).decode()) + update_args = decode_base64(args.update_args) + logging.info(update_args) + else: + update_args = None + + if args.type == 'workflow': + logging.info("run workflow") + password = context.broker_password + context.broker_password = '***' + logging.info("context: %s" % json_dumps(context)) + context.broker_password = password + logging.info("original_args: %s" % original_args) + logging.info("update_args: %s" % update_args) + exit_code = run_workflow(context, original_args, update_args) + logging.info("exit code: %s" % exit_code) + else: + logging.info("run work") + password = context.broker_password + context.broker_password = '***' + logging.info("context: %s" % json_dumps(context)) + context.broker_password = password + logging.info("original_args: %s" % original_args) + logging.info("update_args: %s" % update_args) + exit_code = run_work(context, original_args, update_args) + logging.info("exit code: %s" % exit_code) + return exit_code + + +def custom_action(): + class CustomAction(argparse.Action): + def __init__(self, option_strings, dest, default=False, required=False, help=None): + super(CustomAction, self).__init__(option_strings=option_strings, + dest=dest, const=True, default=default, + required=required, help=help) + + def __call__(self, parser, namespace, values=None, option_string=None): + print(values) + # setattr(namespace, self.dest, values) + return CustomAction + + +def get_parser(): + """ + Return the argparse parser. + """ + oparser = argparse.ArgumentParser(prog=os.path.basename(sys.argv[0]), add_help=True) + + # common items + oparser.add_argument('--version', action='version', version='%(prog)s ' + release_version) + oparser.add_argument('--verbose', '-v', default=False, action='store_true', help="Print more verbose output.") + oparser.add_argument('--type', dest='type', action='store', choices=['workflow', 'work'], default='workflow', help='The type in [workflow, work]. 
Default is workflow.') + oparser.add_argument('--context', dest='context', help="The context.") + oparser.add_argument('--original_args', dest='original_args', help="The original arguments.") + oparser.add_argument('--update_args', dest='update_args', nargs='?', const=None, help="The updated arguments.") + return oparser + + +if __name__ == '__main__': + arguments = sys.argv[1:] + + oparser = get_parser() + argcomplete.autocomplete(oparser) + + args = oparser.parse_args(arguments) + + try: + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + start_time = time.time() + + exit_code = run_iworkflow(args) + end_time = time.time() + if args.verbose: + print("Completed in %-0.4f sec." % (end_time - start_time)) + sys.exit(exit_code) + except Exception as error: + logging.error("Strange error: {0}".format(error)) + logging.error(traceback.format_exc()) + sys.exit(-1) diff --git a/workflow/lib/idds/iworkflow/__init__.py b/workflow/lib/idds/iworkflow/__init__.py new file mode 100644 index 00000000..6693d76b --- /dev/null +++ b/workflow/lib/idds/iworkflow/__init__.py @@ -0,0 +1,9 @@ +#!/usr/bin/env python +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0OA +# +# Authors: +# - Wen Guan, , 2023 diff --git a/workflow/lib/idds/iworkflow/asyncresult.py b/workflow/lib/idds/iworkflow/asyncresult.py new file mode 100644 index 00000000..3e92b09f --- /dev/null +++ b/workflow/lib/idds/iworkflow/asyncresult.py @@ -0,0 +1,454 @@ +#!/usr/bin/env python +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# You may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Authors:
+# - Wen Guan, , 2023 - 2024
+
+import logging
+import random
+import socket
+import stomp
+import threading
+import time
+import traceback
+
+from queue import Queue
+
+from idds.common.constants import WorkflowType
+from idds.common.utils import json_dumps, json_loads, setup_logging, get_unique_id_for_dict
+from .base import Base
+
+
+setup_logging(__name__)
+logging.getLogger("stomp").setLevel(logging.CRITICAL)
+
+
+class MessagingListener(stomp.ConnectionListener):
+    '''
+    Messaging Listener
+    '''
+    def __init__(self, broker, output_queue, logger=None):
+        '''
+        __init__
+        '''
+        self.name = "MessagingListener"
+        self.__broker = broker
+        self.__output_queue = output_queue
+        # self.logger = logging.getLogger(self.__class__.__name__)
+        if logger:
+            self.logger = logger
+        else:
+            self.logger = logging.getLogger(self.__class__.__name__)
+
+    def on_error(self, frame):
+        '''
+        Error handler
+        '''
+        self.logger.error('[broker] [%s]: %s', self.__broker, frame.body)
+
+    def on_message(self, frame):
+        self.logger.debug('[broker] [%s]: headers: %s, body: %s', self.__broker, frame.headers, frame.body)
+        self.__output_queue.put(json_loads(frame.body))
+
+
+class MapResult(object):
+    def __init__(self):
+        self._name_results = {}
+        self._results = {}
+
+    def __str__(self):
+        return str(self._name_results)
+
+    def add_result(self, name=None, args=None, key=None, result=None):
+        name_key = key
+        if name_key is None:
+            key = get_unique_id_for_dict(args)
+            name_key = '%s:%s' % (name, key)
+        else:
+            # name_key = key
+            # name = ':'.join(name_key.split(":")[:-1])
+            key = name_key.split(":")[-1]
+
+        self._name_results[name_key] = result
+        self._results[key] = result
+
+    def has_result(self, name=None, args=None, key=None):
+        name_key = key
+        if name_key is not None:
+            if name_key in self._name_results:
+                return True
+            return False
+        else:
+            key = get_unique_id_for_dict(args)
+            name_key = '%s:%s' % (name, key)
+
+            if name is not None:
+                if name_key in self._name_results:
+                    return True
+                return False
+            else:
+                if key in self._results:
+                    return True
+                return False
+
+    def get_result(self, name=None, args=None, key=None):
+        logging.debug("get_result: key %s, name: %s, args: %s" % (key, name, args))
+        logging.debug("get_result: results: %s, name_results: %s" % (self._results, self._name_results))
+
+        name_key = key
+        if name_key is not None:
+            ret = self._name_results.get(name_key, None)
+        else:
+            key = get_unique_id_for_dict(args)
+
+            if name is not None:
+                name_key = '%s:%s' % (name, key)
+                ret = self._name_results.get(name_key, None)
+            else:
+                ret = self._results.get(key, None)
+        return ret
+
+    def get_all_results(self):
+        return self._results
+
+
+class AsyncResult(Base):
+
+    def __init__(self, work_context, name=None, wait_num=1, wait_keys=[], group_kwargs=[], run_group_kwarg=None, map_results=False,
+                 wait_percent=1, internal_id=None, timeout=None):
+        """
+        Init an AsyncResult.
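+
+        :param work_context: the WorkContext or WorkflowContext carrying the broker settings (brokers, username, password, destination, timeout).
+        :param name: name of the function whose results are collected; used to build per-call result keys.
+        :param wait_num: number of results to wait for when neither wait_keys nor group_kwargs are given.
+        :param wait_keys: explicit result keys to wait for; when set, it overrides wait_num.
+        :param group_kwargs: list of keyword-argument groups; one expected result key is derived per group.
+        :param run_group_kwarg: the argument group of the current run, used to derive the key when publishing.
+        :param map_results: if True, results are returned as a MapResult keyed per argument group instead of a plain list.
+        :param wait_percent: fraction of the expected results required before wait_results returns.
+        :param internal_id: reuse an existing internal id so that publisher and subscriber address the same result stream.
+        :param timeout: overall timeout in seconds for wait_results.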
+ """ + super(AsyncResult, self).__init__() + if internal_id: + self.internal_id = internal_id + self._work_context = work_context + + self._name = name + self._queue = Queue() + + self._connections = [] + self._graceful_stop = None + self._subscribe_thread = None + + self._results = [] + self._bad_results = [] + self._results_percentage = 0 + self._map_results = map_results + + self._wait_num = wait_num + if not self._wait_num: + self._wait_num = 1 + self._wait_keys = set(wait_keys) + self._group_kwargs = group_kwargs + self._run_group_kwarg = run_group_kwarg + + self._wait_percent = wait_percent + self._num_wrong_keys = 0 + + self._timeout = timeout + + @property + def logger(self): + return logging.getLogger(self.__class__.__name__) + + @logger.setter + def logger(self, value): + pass + + @property + def wait_keys(self): + if len(self._wait_keys) > 0: + self._wait_num = len(self._wait_keys) + return self._wait_keys + if self._group_kwargs: + for kwargs in self._group_kwargs: + k = get_unique_id_for_dict(kwargs) + k = "%s:%s" % (self._name, k) + self.logger.info("args (%s) to key: %s" % (str(kwargs), k)) + self._wait_keys.add(k) + self._wait_num = len(self._wait_keys) + return self._wait_keys + + @wait_keys.setter + def wait_keys(self, value): + self._wait_keys = set(value) + + @property + def results(self): + while not self._queue.empty(): + ret = self._queue.get() + try: + internal_id = ret['internal_id'] + if internal_id == self.internal_id: + self._results.append(ret) + else: + self._bad_results.append(ret) + except Exception as ex: + self.logger.error("Received bad result: %s: %s" % (str(ret), str(ex))) + if self._bad_results: + self.logger.error("Received bad results: %s" % str(self._bad_results)) + + self.logger.debug("_results: %s, bad_results: %s" % (str(self._results), str(self._bad_results))) + self.logger.debug("wait_keys: %s, wait_num: %s" % (str(self.wait_keys), self._wait_num)) + rets_dict = {} + for result in self._results: + key = result['key'] + ret = result['ret'] + rets_dict[key] = ret + + if self._map_results: + rets = {} + if len(self.wait_keys) > 0: + for k in self.wait_keys: + if k in rets_dict: + rets[k] = rets_dict[k] + self._results_percentage = len(list(rets.keys())) * 1.0 / len(self.wait_keys) + else: + rets = rets_dict + self._results_percentage = len(list(rets.keys())) * 1.0 / self._wait_num + + ret_map = MapResult() + for k in rets: + ret_map.add_result(key=k, result=rets[k]) + return ret_map + else: + rets = [] + if len(self.wait_keys) > 0: + for k in self.wait_keys: + if k in rets_dict: + rets.append(rets_dict[k]) + self._results_percentage = len(rets) * 1.0 / len(self.wait_keys) + else: + rets = [rets_dict[k] for k in rets_dict] + self._results_percentage = len(rets) * 1.0 / self._wait_num + + if self._wait_num == 1: + if rets: + return rets[0] + else: + return None + return rets + + @results.setter + def results(self, value): + raise Exception("Not allowed to set results.") + if type(value) not in [list, tuple]: + raise Exception("Results must be list or tuple, currently it is %s" % type(value)) + self._results = value + + def disconnect(self): + for con in self._connections: + try: + if con.is_connected(): + con.disconnect() + except Exception: + pass + + def connect_to_messaging_broker(self): + workflow_context = self._work_context + brokers = workflow_context.brokers + + brokers = brokers.split(",") + broker = random.sample(brokers, k=1)[0] + + self.logger.info("Got broker: %s" % (broker)) + + timeout = workflow_context.broker_timeout + 
self.disconnect() + + broker, port = broker.split(":") + conn = stomp.Connection12(host_and_ports=[(broker, port)], + keepalive=True, + heartbeats=(30000, 30000), # half minute = num / 1000 + timeout=timeout) + conn.connect(workflow_context.broker_username, workflow_context.broker_password, wait=True) + self._connections = [conn] + return conn + + def subscribe_to_messaging_brokers(self): + workflow_context = self._work_context + brokers = workflow_context.brokers + conns = [] + + broker_addresses = [] + for b in brokers.split(","): + try: + b, port = b.split(":") + + addrinfos = socket.getaddrinfo(b, 0, socket.AF_INET, 0, socket.IPPROTO_TCP) + for addrinfo in addrinfos: + b_addr = addrinfo[4][0] + broker_addresses.append((b_addr, port)) + except socket.gaierror as error: + self.logger.error('Cannot resolve hostname %s: %s' % (b, str(error))) + self._graceful_stop.set() + + self.logger.info("Resolved broker addresses: %s" % (broker_addresses)) + + timeout = workflow_context.broker_timeout + + self.disconnect() + + listener = MessagingListener(brokers, self._queue, logger=self.logger) + conns = [] + for broker, port in broker_addresses: + conn = stomp.Connection12(host_and_ports=[(broker, port)], + keepalive=True, + heartbeats=(30000, 30000), # half minute = num / 1000 + timeout=timeout) + conn.set_listener("messag-subscriber", listener) + conn.connect(workflow_context.broker_username, workflow_context.broker_password, wait=True) + if workflow_context.type == WorkflowType.iWorkflow: + subscribe_id = 'idds-workflow_%s' % self.internal_id + # subscribe_selector = {'selector': "type = 'iworkflow' AND request_id = %s" % workflow_context.request_id} + # subscribe_selector = {'selector': "type = 'iworkflow' AND internal_id = '%s'" % self.internal_id} + subscribe_selector = {'selector': "internal_id = '%s'" % self.internal_id} + elif workflow_context.type == WorkflowType.iWork: + subscribe_id = 'idds-work_%s' % self.internal_id + # subscribe_selector = {'selector': "type = 'iwork' AND request_id = %s AND transform_id = %s " % (workflow_context.request_id, + # workflow_context.transform_id)} + # subscribe_selector = {'selector': "type = 'iwork' AND internal_id = '%s'" % self.internal_id} + subscribe_selector = {'selector': "internal_id = '%s'" % self.internal_id} + else: + subscribe_id = 'idds-workflow_%s' % self.internal_id + subscribe_selector = None + # subscribe_selector = None + # conn.subscribe(destination=workflow_context.broker_destination, id=subscribe_id, + # ack='auto', conf=subscribe_selector) + conn.subscribe(destination=workflow_context.broker_destination, id=subscribe_id, + ack='auto', headers=subscribe_selector) + self.logger.info("subscribe to %s:%s with selector: %s" % (broker, port, subscribe_selector)) + conns.append(conn) + self._connections = conns + return conns + + def publish(self, ret, key=None): + conn = self.connect_to_messaging_broker() + workflow_context = self._work_context + if key is None: + if self._run_group_kwarg: + key = get_unique_id_for_dict(self._run_group_kwarg) + key = "%s:%s" % (self._name, key) + self.logger.info("publish args (%s) to key: %s" % (str(self._run_group_kwarg), key)) + + if workflow_context.type == WorkflowType.iWorkflow: + headers = {'persistent': 'true', + 'type': 'iworkflow', + 'internal_id': str(self.internal_id), + 'request_id': workflow_context.request_id} + body = json_dumps({'ret': ret, 'key': key, 'internal_id': self.internal_id}) + conn.send(body=body, + destination=workflow_context.broker_destination, + id='idds-iworkflow_%s' % 
self.internal_id, + ack='auto', + headers=headers + ) + self.logger.info("publish header: %s, body: %s" % (str(headers), str(body))) + elif workflow_context.type == WorkflowType.iWork: + headers = {'persistent': 'true', + 'type': 'iwork', + 'internal_id': str(self.internal_id), + 'request_id': workflow_context.request_id, + 'transform_id': workflow_context.transform_id} + body = json_dumps({'ret': ret, 'key': key, 'internal_id': self.internal_id}) + conn.send(body=body, + destination=workflow_context.broker_destination, + id='idds-iwork_%s' % self.internal_id, + ack='auto', + headers=headers + ) + self.logger.info("publish header: %s, body: %s" % (str(headers), str(body))) + self.disconnect() + + def run_subscriber(self): + try: + self.logger.info("run subscriber") + self.subscribe_to_messaging_brokers() + while not self._graceful_stop.is_set(): + has_failed_conns = False + for conn in self._connections: + if not conn.is_connected(): + has_failed_conns = True + if has_failed_conns: + self.subscribe_to_messaging_brokers() + time.sleep(1) + except Exception as ex: + self.logger.error("run subscriber failed with error: %s" % str(ex)) + self.logger.error(traceback.format_exc()) + + def get_results(self): + rets = self.results + self.logger.debug('results: %s' % str(rets)) + return rets + + def get_results_percentage(self): + return self._results_percentage + + def subscribe(self): + self._graceful_stop = threading.Event() + thread = threading.Thread(target=self.run_subscriber, name="RunSubscriber") + thread.start() + time.sleep(1) + self._subscribed = True + + def stop(self): + if self._graceful_stop: + self._graceful_stop.set() + self.disconnect() + + def __del__(self): + self.stop() + + def wait_results(self, timeout=None, force_return_results=False): + if not self._subscribed: + self.subscribe() + + get_results = False + time_log = time.time() + time_start = time.time() + if timeout is None: + self.logger.info("waiting for results") + try: + while not get_results and not self._graceful_stop.is_set(): + self.get_results() + percent = self.get_results_percentage() + if time.time() - time_log > 600: # 10 minutes + self.logger.info("waiting for results: %s (number of wrong keys: %s)" % (percent, self._num_wrong_keys)) + time_log = time.time() + time.sleep(1) + if percent >= self._wait_percent: + get_results = True + if self._timeout is not None and self._timeout > 0 and time.time() - time_start > self._timeout: + # global timeout + self.logger.info("Waiting result timeout(%s seconds)" % self._timeout) + get_results = True + if timeout is not None and timeout > 0 and time.time() - time_start > timeout: + # local timeout + break + + percent = self.get_results_percentage() + if timeout is None or time.time() - time_start > 600: + self.logger.info("Got results: %s (number of wrong keys: %s)" % (percent, self._num_wrong_keys)) + except Exception as ex: + self.logger.error("Wait_results got some exception: %s" % str(ex)) + self.logger.error(traceback.format_exc()) + self._graceful_stop.set() + + if get_results or self._graceful_stop.is_set() or percent >= self._wait_percent or force_return_results: + # stop the subscriber + self._graceful_stop.set() + self.logger.info("Got results: %s (number of wrong keys: %s)" % (percent, self._num_wrong_keys)) + + results = self.results + return results + return None + + def wait_result(self, timeout=None, force_return_results=False): + self.wait_results(timeout=timeout, force_return_results=force_return_results) + results = self.results + return results diff 
--git a/workflow/lib/idds/iworkflow/base.py b/workflow/lib/idds/iworkflow/base.py new file mode 100644 index 00000000..b6e6da28 --- /dev/null +++ b/workflow/lib/idds/iworkflow/base.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0OA +# +# Authors: +# - Wen Guan, , 2024 + +import logging +import inspect +import os +import traceback +import uuid + +from typing import Any, Dict, List, Optional, Tuple, Union + +from idds.common.dict_class import DictMetadata, DictBase +from idds.common.imports import import_func, get_func_name + + +class IDDSDict(dict): + def __setitem__(self, key, value): + if key == 'test': + pass + else: + super().__setitem__(key, value) + + +class IDDSMetadata(DictMetadata): + def __init__(self): + super(IDDSMetadata, self).__init__() + + +class Base(DictBase): + def __init__(self): + super(Base, self).__init__() + self._internal_id = str(uuid.uuid4())[:8] + self._template_id = self._internal_id + self._sequence_id = 0 + + @property + def internal_id(self): + return self._internal_id + + @internal_id.setter + def internal_id(self, value): + self._internal_id = value + + def get_func_name_and_args(self, + func, + args: Union[List[Any], Optional[Tuple]] = None, + kwargs: Optional[Dict[str, Any]] = None, + group_kwargs: Optional[list[Dict[str, Any]]] = None): + + if args is None: + args = () + if kwargs is None: + kwargs = {} + if group_kwargs is None: + group_kwargs = [] + if not isinstance(args, (tuple, list)): + raise TypeError('{0!r} is not a valid args list'.format(args)) + if not isinstance(kwargs, dict): + raise TypeError('{0!r} is not a valid kwargs dict'.format(kwargs)) + if not isinstance(group_kwargs, list): + raise TypeError('{0!r} is not a valid group_kwargs list'.format(group_kwargs)) + + func_call, func_name = None, None + if isinstance(func, str): + func_name = func + elif inspect.ismethod(func) or inspect.isfunction(func) or inspect.isbuiltin(func): + # func_name = '{0}.{1}'.format(func.__module__, func.__qualname__) + func_name = get_func_name(func) + func_call = func + else: + # raise TypeError('Expected a callable or a string, but got: {0}'.format(func)) + func_name = func + return func_call, (func_name, args, kwargs, group_kwargs) + + @property + def logger(self): + return logging.getLogger(self.__class__.__name__) + + @logger.setter + def logger(self, value): + pass + + def get_internal_id(self): + return self._internal_id + + def get_template_work_id(self): + return self._template_id + + def get_sequence_id(self): + return self._sequence_id + + def get_input_collections(self): + return [] + + def get_output_collections(self): + return [] + + def get_log_collections(self): + return [] + + def prepare(self): + """ + Prepare the workflow: upload the source files to server. + + :returns id: The workflow id. + :raise Exception when failing to prepare the workflow. + """ + + def submit(self): + """ + Submit the workflow to the iDDS server. + + :returns id: The workflow id. + :raise Exception when failing to submit the workflow. + """ + self.prepare() + return None + + def setup(self): + """ + :returns command: `str` to setup the workflow. + """ + return None + + def load(self, func_name): + """ + Load the function from the source files. 
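+
+        :param func_name: name of the function to import (e.g. as produced by get_func_name).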
+
+        :raise Exception
+        """
+        os.environ['IDDS_IWORKFLOW_LOAD'] = 'true'
+        func = import_func(func_name)
+        del os.environ['IDDS_IWORKFLOW_LOAD']
+
+        return func
+
+    def run_func(self, func, args, kwargs):
+        """
+        Run the function.
+
+        :raise Exception.
+        """
+        try:
+            return func(*args, **kwargs)
+        except Exception as ex:
+            logging.error("Failed to run the function: %s" % str(ex))
+            logging.debug(traceback.format_exc())
+
+
+class Context(DictBase):
+    def __init__(self):
+        super(Context, self).__init__()
+        self._internal_id = str(uuid.uuid4())[:8]
+
+    @property
+    def internal_id(self):
+        return self._internal_id
+
+    @internal_id.setter
+    def internal_id(self, value):
+        self._internal_id = value
+
+    def prepare(self):
+        """
+        Prepare the workflow.
+        """
+        return None
+
+    def setup(self):
+        """
+        :returns command: `str` to setup the workflow.
+        """
+        return None
diff --git a/workflow/lib/idds/iworkflow/utils.py b/workflow/lib/idds/iworkflow/utils.py
new file mode 100644
index 00000000..c5e7e26c
--- /dev/null
+++ b/workflow/lib/idds/iworkflow/utils.py
@@ -0,0 +1,38 @@
+#!/usr/bin/env python
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# You may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Authors:
+# - Wen Guan, , 2024
+
+# from idds.common.utils import encode_base64
+
+
+def show_relation_map(relation_map, level=0):
+    # a workflow with a list of works.
+    if level == 0:
+        prefix = ""
+    else:
+        prefix = " " * level * 4
+
+    for item in relation_map:
+        if type(item) in [dict]:
+            # it's a Work
+            print("%s%s" % (prefix, item['work']['workload_id']))
+            if 'next_works' in item:
+                # print("%s%s next_works:" % (prefix, item['work']['workload_id']))
+                next_works = item['next_works']
+                # it's a list.
+                show_relation_map(next_works, level=level + 1)
+        elif type(item) in [list]:
+            # it's a subworkflow with a list of works; recurse into the sublist itself.
+            print("%ssubworkflow:" % (prefix))
+            show_relation_map(item, level=level + 1)
+
+
+def perform_workflow(workflow):
+    workflow.load()
+    workflow.run()
diff --git a/workflow/lib/idds/iworkflow/version.py b/workflow/lib/idds/iworkflow/version.py
new file mode 100644
index 00000000..317bd49c
--- /dev/null
+++ b/workflow/lib/idds/iworkflow/version.py
@@ -0,0 +1,12 @@
+#!/usr/bin/env python
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# You may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Authors:
+# - Wen Guan, , 2023
+
+
+release_version = "2.0.0"
diff --git a/workflow/lib/idds/iworkflow/work.py b/workflow/lib/idds/iworkflow/work.py
new file mode 100644
index 00000000..38e343ca
--- /dev/null
+++ b/workflow/lib/idds/iworkflow/work.py
@@ -0,0 +1,902 @@
+#!/usr/bin/env python
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# You may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0OA +# +# Authors: +# - Wen Guan, , 2023 - 2024 + +import datetime +import functools +import logging +import os +import time +import traceback + +from idds.common import exceptions +from idds.common.constants import WorkflowType, TransformStatus +from idds.common.imports import get_func_name +from idds.common.utils import setup_logging, json_dumps, encode_base64 +from .asyncresult import AsyncResult, MapResult +from .base import Base, Context +from .workflow import WorkflowCanvas + +setup_logging(__name__) + + +class WorkContext(Context): + + def __init__(self, name=None, workflow_context=None, source_dir=None): + super(WorkContext, self).__init__() + self._workflow_context = workflow_context + self._transform_id = None + self._processing_id = None + self._type = WorkflowType.iWork + + self._name = name + self._site = None + self._queue = None + + self._priority = 500 + self._core_count = 1 + self._total_memory = 1000 # MB + self._max_walltime = 7 * 24 * 3600 + self._max_attempt = 5 + + self._map_results = False + + def get_service(self): + return self._workflow_context.service + + @property + def logger(self): + return logging.getLogger(self.__class__.__name__) + + @property + def distributed(self): + return self._workflow_context.distributed + + @distributed.setter + def distributed(self, value): + self._workflow_context.distributed = value + + @property + def service(self): + return self._workflow_context.service + + @service.setter + def service(self, value): + self._workflow_context.service = value + + @property + def vo(self): + return self._workflow_context.vo + + @vo.setter + def vo(self, value): + self._workflow_context.vo = value + + @property + def site(self): + if self._site: + return self._site + return self._workflow_context.site + + @site.setter + def site(self, value): + self._site = value + + @property + def queue(self): + if self._queue: + return self._queue + return self._workflow_context.queue + + @queue.setter + def queue(self, value): + self._queue = value + + @property + def cloud(self): + return self._workflow_context.cloud + + @cloud.setter + def cloud(self, value): + self._workflow_context.cloud = value + + @property + def working_group(self): + return self._workflow_context.working_group + + @working_group.setter + def working_group(self, value): + self._workflow_context.working_group = value + + @property + def priority(self): + if self._priority: + return self._priority + return self._workflow_context.priority + + @priority.setter + def priority(self, value): + self._priority = value + + @property + def core_count(self): + if self._core_count: + return self._core_count + return self._workflow_context.core_count + + @core_count.setter + def core_count(self, value): + self._core_count = value + + @property + def total_memory(self): + if self._total_memory: + return self._total_memory + return self._workflow_context.total_memory + + @total_memory.setter + def total_memory(self, value): + self._total_memory = value + + @property + def max_walltime(self): + if self._max_walltime: + return self._max_walltime + return self._workflow_context.max_walltime + + @max_walltime.setter + def max_walltime(self, value): + self._max_walltime = value + + @property + def max_attempt(self): + if self._max_attempt: + return self._max_attempt + return self._workflow_context.max_attempt + + @max_attempt.setter + def max_attempt(self, value): + self._max_attempt = value + + @property + def 
username(self): + return self._workflow_context.username + + @username.setter + def username(self, value): + self._workflow_context.username = value + + @property + def userdn(self): + return self._workflow_context.userdn + + @userdn.setter + def userdn(self, value): + self._workflow_context.userdn = value + + @property + def type(self): + return self._type + + @type.setter + def type(self, value): + self._type = value + + @property + def lifetime(self): + return self._workflow_context.lifetime + + @lifetime.setter + def lifetime(self, value): + self._workflow_context.lifetime = value + + @property + def request_id(self): + return self._workflow_context.request_id + + @request_id.setter + def request_id(self, value): + self._workflow_context.request_id = value + + @property + def workload_id(self): + return self._workload_id + + @workload_id.setter + def workload_id(self, value): + self._workload_id = value + + @property + def transform_id(self): + return self._transform_id + + @transform_id.setter + def transform_id(self, value): + self._transform_id = int(value) + + @property + def processing_id(self): + return self._processing_id + + @processing_id.setter + def processing_id(self, value): + self._processing_id = value + + @property + def brokers(self): + return self._workflow_context.brokers + + @brokers.setter + def brokers(self, value): + self._workflow_context.brokers = value + + @property + def broker_timeout(self): + return self._workflow_context.broker_timeout + + @broker_timeout.setter + def broker_timeout(self, value): + self._workflow_context.broker_timeout = value + + @property + def broker_username(self): + return self._workflow_context.broker_username + + @broker_username.setter + def broker_username(self, value): + self._workflow_context.broker_username = value + + @property + def broker_password(self): + return self._workflow_context.broker_password + + @broker_password.setter + def broker_password(self, value): + self._workflow_context.broker_password = value + + @property + def broker_destination(self): + return self._workflow_context.broker_destination + + @broker_destination.setter + def broker_destination(self, value): + self._workflow_context.broker_destination = value + + @property + def token(self): + return self._workflow_context.token + + @token.setter + def token(self, value): + self._workflow_context.token = value + + @property + def map_results(self): + return self._map_results + + @map_results.setter + def map_results(self, value): + self._map_results = value + + def get_idds_server(self): + return self._workflow_context.get_idds_server() + + def initialize(self): + return self._workflow_context.initialize() + + def setup_source_files(self): + """ + Setup source files. + """ + return self._workflow_context.setup_source_files() + + def setup(self): + """ + :returns command: `str` to setup the workflow. + """ + return self._workflow_context.setup() + + +class Work(Base): + + def __init__(self, func=None, workflow_context=None, context=None, args=None, kwargs=None, group_kwargs=None, + update_kwargs=None, map_results=False, source_dir=None, is_unique_func_name=False): + """ + Init a workflow. 
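+
+        :param func: the function (or function name) this work executes.
+        :param workflow_context: context of the enclosing workflow, used to build a new WorkContext when no context is given.
+        :param context: an existing WorkContext to reuse.
+        :param args: positional arguments passed to func.
+        :param kwargs: keyword arguments passed to func.
+        :param group_kwargs: list of keyword-argument groups; func is invoked once per group.
+        :param update_kwargs: keyword arguments merged into kwargs at run time (distributed execution).
+        :param map_results: if True, results are collected into a MapResult keyed per argument group.
+        :param source_dir: directory containing the source files of func.
+        :param is_unique_func_name: if False, a timestamp is appended to the work name to make it unique.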
+ """ + super(Work, self).__init__() + self.prepared = False + + # self._func = func + self._func, self._func_name_and_args = self.get_func_name_and_args(func, args, kwargs, group_kwargs) + + self._update_kwargs = update_kwargs + + self._name = self._func_name_and_args[0] + if self._name: + self._name = self._name.replace('__main__:', '').replace('.py', '').replace(':', '.') + if not is_unique_func_name: + if self._name: + self._name = self._name + "_" + datetime.datetime.utcnow().strftime("%Y_%m_%d_%H_%M_%S") + + if context: + self._context = context + else: + self._context = WorkContext(name=self._name, workflow_context=workflow_context) + + self.map_results = map_results + self._results = None + + @property + def logger(self): + return logging.getLogger(self.__class__.__name__) + + @property + def internal_id(self): + return self._context.internal_id + + @internal_id.setter + def internal_id(self, value): + self._context.internal_id = value + + @property + def service(self): + return self._context.service + + @service.setter + def service(self, value): + self._context.service = value + + @property + def name(self): + return self._name + + @name.setter + def name(self, value): + self._name = value + + @property + def request_id(self): + return self._context.request_id + + @request_id.setter + def request_id(self, value): + self._context.request_id = value + + @property + def transform_id(self): + return self._context.transform_id + + @transform_id.setter + def transform_id(self, value): + self._context.transform_id = int(value) + + @property + def processing_id(self): + return self._context.processing_id + + @processing_id.setter + def processing_id(self, value): + self._context.processing_id = value + + @property + def vo(self): + return self._context.vo + + @vo.setter + def vo(self, value): + self._context.vo = value + + @property + def queue(self): + return self._context.queue + + @queue.setter + def queue(self, value): + self._context.queue = value + + @property + def site(self): + return self._context.site + + @site.setter + def site(self, value): + self._context.site = value + + def get_site(self): + return self.site + + @property + def cloud(self): + return self._context.cloud + + @cloud.setter + def cloud(self, value): + self._context.cloud = value + + @property + def working_group(self): + return self._context.working_group + + @working_group.setter + def working_group(self, value): + self._context.working_group = value + + @property + def priority(self): + return self._context.priority + + @priority.setter + def priority(self, value): + self._context.priority = value + + @property + def core_count(self): + return self._context.core_count + + @core_count.setter + def core_count(self, value): + self._context.core_count = value + + @property + def total_memory(self): + return self._context.total_memory + + @total_memory.setter + def total_memory(self, value): + self._context.total_memory = value + + @property + def max_walltime(self): + return self._context.max_walltime + + @max_walltime.setter + def max_walltime(self, value): + self._context.max_walltime = value + + @property + def max_attempt(self): + return self._context.max_attempt + + @max_attempt.setter + def max_attempt(self, value): + self._context.max_attempt = value + + @property + def username(self): + return self._context.username + + @username.setter + def username(self, value): + self._context.username = value + + @property + def userdn(self): + return self._context.userdn + + @userdn.setter + def userdn(self, 
value): + self._context.userdn = value + + @property + def type(self): + return self._context.type + + @type.setter + def type(self, value): + self._context.type = value + + @property + def map_results(self): + return self._context.map_results + + @map_results.setter + def map_results(self, value): + self._context.map_results = value + + @property + def lifetime(self): + return self._context.lifetime + + @lifetime.setter + def lifetime(self, value): + self._context.lifetime = value + + @property + def workload_id(self): + return self._context.workload_id + + @workload_id.setter + def workload_id(self, value): + self._context.workload_id = value + + def get_workload_id(self): + return self.workload_id + + @property + def token(self): + return self._context.token + + @token.setter + def token(self, value): + self._context.token = value + + @property + def group_parameters(self): + return self._func_name_and_args[3] + + @group_parameters.setter + def group_parameters(self, value): + raise Exception("Not allwed to update group parameters") + + def get_work_tag(self): + return 'iWork' + + def get_work_type(self): + return WorkflowType.iWork + + def get_work_name(self): + return self._name + + def to_dict(self): + func = self._func + self._func = None + obj = super(Work, self).to_dict() + self._func = func + return obj + + def submit_to_idds_server(self): + """ + Submit the workflow to the iDDS server. + + :returns id: The workflow id. + :raise Exception when failing to submit the workflow. + """ + # iDDS ClientManager + from idds.client.clientmanager import ClientManager + client = ClientManager(host=self._context.get_idds_server()) + request_id = self._context.request_id + transform_id = client.submit_work(request_id, self, use_dataset_name=False) + logging.info("Submitted into iDDS with transform id=%s", str(transform_id)) + return transform_id + + def submit_to_panda_server(self): + """ + Submit the workflow to the iDDS server through PanDA service. + + :returns id: The workflow id. + :raise Exception when failing to submit the workflow. + """ + import idds.common.utils as idds_utils + import pandaclient.idds_api as idds_api + idds_server = self._context.get_idds_server() + request_id = self._context.request_id + client = idds_api.get_api(idds_utils.json_dumps, + idds_host=idds_server, + compress=True, + manager=True) + transform_id = client.submit_work(request_id, self, use_dataset_name=False) + logging.info("Submitted work into PanDA-iDDS with transform id=%s", str(transform_id)) + return transform_id + + def submit(self): + """ + Submit the workflow to the iDDS server. + + :returns id: The workflow id. + :raise Exception when failing to submit the workflow. 
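+
+        A minimal call sketch (illustrative only; my_func and wf_context are placeholders):
+
+            w = Work(func=my_func, workflow_context=wf_context, args=(1, 2))
+            transform_id = w.submit()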
+ """ + if self._context.get_service() == 'panda': + tf_id = self.submit_to_panda_server() + else: + tf_id = self.submit_to_idds_server() + + try: + self._context.transform_id = int(tf_id) + return tf_id + except Exception as ex: + logging.info("Transform id (%s) is not integer, there should be some submission errors: %s" % (tf_id, str(ex))) + + return None + + def get_status_from_panda_server(self): + import idds.common.utils as idds_utils + import pandaclient.idds_api as idds_api + + idds_server = self._context.get_idds_server() + client = idds_api.get_api(idds_utils.json_dumps, + idds_host=idds_server, + compress=True, + manager=True) + + request_id = self._context.request_id + transform_id = self._context.transform_id + if not transform_id: + log_msg = "No transform id defined (request_id: %s, transform_id: %s)", (request_id, transform_id) + logging.error(log_msg) + return exceptions.IDDSException(log_msg) + + tf = client.get_transform(request_id=request_id, transform_id=transform_id) + if not tf: + logging.info("Get transform (request_id: %s, transform_id: %s) from iDDS: %s" % (request_id, transform_id, tf)) + return None + + logging.info("Get transform status (request_id: %s, transform_id: %s) from iDDS: %s" % (request_id, transform_id, tf['status'])) + + return tf['status'] + + def get_status_from_idds_server(self): + from idds.client.clientmanager import ClientManager + client = ClientManager(host=self._context.get_idds_server()) + + request_id = self._context.request_id + transform_id = self._context.transform_id + if not transform_id: + log_msg = "No transform id defined (request_id: %s, transform_id: %s)" % (request_id, transform_id) + logging.error(log_msg) + return exceptions.IDDSException(log_msg) + + tf = client.get_transform(request_id=request_id, transform_id=transform_id) + if not tf: + logging.info("Get transform (request_id: %s, transform_id: %s) from iDDS: %s" % (request_id, transform_id, tf)) + return None + + logging.info("Get transform status (request_id: %s, transform_id: %s) from iDDS: %s" % (request_id, transform_id, tf['status'])) + + return tf['status'] + + def get_status(self): + try: + if self._context.get_service() == 'panda': + return self.get_status_from_panda_server() + return self.get_status_from_idds_server() + except Exception as ex: + logging.info("Failed to get transform status: %s" % str(ex)) + + def get_terminated_status(self): + return [None, TransformStatus.Finished, TransformStatus.SubFinished, + TransformStatus.Failed, TransformStatus.Cancelled, + TransformStatus.Suspended, TransformStatus.Expired] + + def get_func_name(self): + func_name = self._func_name_and_args[0] + return func_name + + def get_group_kwargs(self): + group_kwargs = self._func_name_and_args[3] + return group_kwargs + + def wait_results(self): + try: + terminated_status = self.get_terminated_status() + + group_kwargs = self.get_group_kwargs() + if group_kwargs: + async_ret = AsyncResult(self._context, name=self.get_func_name(), group_kwargs=group_kwargs, + map_results=self.map_results, internal_id=self.internal_id) + else: + async_ret = AsyncResult(self._context, name=self.get_func_name(), wait_num=1, internal_id=self.internal_id) + + async_ret.subscribe() + + status = self.get_status() + time_last_check_status = time.time() + logging.info("waiting for results") + while status not in terminated_status: + # time.sleep(10) + ret = async_ret.wait_results(timeout=10) + if ret: + logging.info("Recevied result: %s" % ret) + break + if time.time() - time_last_check_status > 600: # 10 
minutes + status = self.get_status() + time_last_check_status = time.time() + + async_ret.stop() + self._results = async_ret.wait_results(force_return_results=True) + return self._results + except Exception as ex: + logging.error("wait_results got some errors: %s" % str(ex)) + async_ret.stop() + return ex + + def get_results(self): + return self._results + + def setup(self): + """ + :returns command: `str` to setup the workflow. + """ + return self._context.setup() + + def load(self, func_name): + """ + Load the function from the source files. + + :raise Exception + """ + os.environ['IDDS_IWORKFLOW_LOAD_WORK'] = 'true' + func = super(Work, self).load(func_name) + del os.environ['IDDS_IWORKFLOW_LOAD_WORK'] + + return func + + def pre_run(self): + # test AsyncResult + workflow_context = self._context + if workflow_context.distributed: + logging.info("Test AsyncResult") + a_ret = AsyncResult(workflow_context, wait_num=1, timeout=30) + a_ret.subscribe() + + async_ret = AsyncResult(workflow_context, internal_id=a_ret.internal_id) + test_result = "AsyncResult test (request_id: %s, transform_id: %s)" % (workflow_context.request_id, workflow_context.transform_id) + logging.info("AsyncResult publish: %s" % test_result) + async_ret.publish(test_result) + + ret_q = a_ret.wait_result(force_return_results=True) + logging.info("AsyncResult results: %s" % str(ret_q)) + if ret_q and ret_q == test_result: + logging.info("AsyncResult test succeeded") + return True + else: + logging.info("AsyncResult test failed (published: %s, received: %s)" % (test_result, ret_q)) + return False + return True + + def run(self): + """ + Run the work. + """ + self.pre_run() + + func_name, args, kwargs, group_kwargs = self._func_name_and_args + if self._func is None: + func = self.load(func_name) + self._func = func + + if self._context.distributed: + rets = None + kwargs_copy = kwargs.copy() + if self._update_kwargs and type(self._update_kwargs) in [dict]: + kwargs_copy.update(self._update_kwargs) + + rets = self.run_func(self._func, args, kwargs_copy) + + request_id = self._context.request_id + transform_id = self._context.transform_id + logging.info("publishing AsyncResult to (request_id: %s, transform_id: %s): %s" % (request_id, transform_id, rets)) + async_ret = AsyncResult(self._context, name=self.get_func_name(), internal_id=self.internal_id, run_group_kwarg=self._update_kwargs) + async_ret.publish(rets) + + if not self.map_results: + self._results = rets + else: + self._results = MapResult() + self._results.add_result(name=self.get_func_name(), args=self._update_kwargs, result=rets) + return self._results + else: + if not group_kwargs: + rets = self.run_func(self._func, args, kwargs) + if not self.map_results: + self._results = rets + else: + self._results = MapResult() + self._results.add_result(name=self.get_func_name(), args=self._update_kwargs, result=rets) + return self._results + else: + if not self.map_results: + self._results = [] + for group_kwarg in group_kwargs: + kwargs_copy = kwargs.copy() + kwargs_copy.update(group_kwarg) + rets = self.run_func(self._func, args, kwargs_copy) + self._results.append(rets) + else: + self._results = MapResult() + for group_kwarg in group_kwargs: + kwargs_copy = kwargs.copy() + kwargs_copy.update(group_kwarg) + rets = self.run_func(self._func, args, kwargs_copy) + self._results.add_result(name=self.get_func_name(), args=group_kwarg, result=rets) + return self._results + + def get_run_command(self): + cmd = "run_workflow --type work " + cmd += "--context %s --original_args %s 
" % (encode_base64(json_dumps(self._context)), + encode_base64(json_dumps(self._func_name_and_args))) + cmd += "--update_args ${IN/L}" + return cmd + + def get_runner(self): + setup = self.setup() + cmd = "" + run_command = self.get_run_command() + + if setup: + cmd = ' --setup "' + setup + '" ' + if cmd: + cmd = cmd + " " + run_command + else: + cmd = run_command + return cmd + + +def run_work_distributed(w): + try: + tf_id = w.submit() + if tf_id: + logging.info("wait for results") + rets = w.wait_results() + logging.info("Got results: %s" % rets) + return rets + else: + logging.error("Failed to distribute work: %s" % w.name) + return None + except Exception as ex: + logging.error("Failed to run the work distributedly: %s" % ex) + logging.error(traceback.format_exc()) + return None + + +# foo = work(arg)(foo) +def work(func=None, *, map_results=False, lazy=False): + if func is None: + return functools.partial(work, map_results=map_results, lazy=lazy) + + if 'IDDS_IWORKFLOW_LOAD_WORK' in os.environ: + return func + + @functools.wraps(func) + def wrapper(*args, **kwargs): + try: + f = kwargs.pop('workflow', None) or WorkflowCanvas.get_current_workflow() + workflow_context = f._context + group_kwargs = kwargs.pop('group_kwargs', []) + logging.debug("workflow context: %s" % workflow_context) + + logging.debug("work decorator: func: %s, map_results: %s" % (func, map_results)) + if workflow_context: + logging.debug("setup work") + w = Work(workflow_context=workflow_context, func=func, args=args, kwargs=kwargs, group_kwargs=group_kwargs, map_results=map_results) + # if distributed: + if workflow_context.distributed: + ret = run_work_distributed(w) + return ret + + return w.run() + else: + logging.info("workflow context is not defined, run function locally") + if not group_kwargs: + return func(*args, **kwargs) + + if not kwargs: + kwargs = {} + if not map_results: + rets = [] + for group_kwarg in group_kwargs: + kwargs_copy = kwargs.copy() + kwargs_copy.update(group_kwarg) + ret = func(*args, **kwargs_copy) + rets.append(ret) + return rets + else: + rets = MapResult() + for group_kwarg in group_kwargs: + kwargs_copy = kwargs.copy() + kwargs_copy.update(group_kwarg) + ret = func(*args, **kwargs_copy) + rets.add_result(name=get_func_name(func), args=group_kwarg, result=ret) + return rets + except Exception as ex: + logging.error("Failed to run workflow %s: %s" % (func, ex)) + raise ex + except: + raise + return wrapper diff --git a/workflow/lib/idds/iworkflow/workflow.py b/workflow/lib/idds/iworkflow/workflow.py new file mode 100644 index 00000000..7f6d2358 --- /dev/null +++ b/workflow/lib/idds/iworkflow/workflow.py @@ -0,0 +1,1048 @@ +#!/usr/bin/env python +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# You may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0OA +# +# Authors: +# - Wen Guan, , 2023 - 2024 + +import collections +import datetime +import functools +import logging +import inspect +import os +import tarfile +import uuid + +# from types import ModuleType + +# from idds.common import exceptions +from idds.common.constants import WorkflowType +from idds.common.utils import setup_logging, create_archive_file, json_dumps, encode_base64 +from .asyncresult import AsyncResult +from .base import Base, Context + + +setup_logging(__name__) + + +class WorkflowCanvas(object): + + _managed_workflows: collections.deque[Base] = collections.deque() + + @classmethod + def push_managed_workflow(cls, workflow: Base): + cls._managed_workflows.appendleft(workflow) + + @classmethod + def pop_managed_workflow(cls): + workflow = cls._managed_workflows.popleft() + return workflow + + @classmethod + def get_current_workflow(cls): + try: + return cls._managed_workflows[0] + except IndexError: + return None + + +class WorkflowContext(Context): + def __init__(self, name=None, service='panda', source_dir=None, distributed=True, init_env=None): + super(WorkflowContext, self).__init__() + self._service = service # panda, idds, sharefs + self._request_id = None + self._type = WorkflowType.iWorkflow + + # self.idds_host = None + # self.idds_async_host = None + self._idds_env = {} + self._panda_env = {} + + self._name = name + self._source_dir = source_dir + self.remote_source_file = None + + self._vo = None + + self._queue = None + self._site = None + self._cloud = None + + self._working_group = None + + self._priority = 500 + self._core_count = 1 + self._total_memory = 1000 # MB + self._max_walltime = 7 * 24 * 3600 + self._max_attempt = 5 + + self._username = None + self._userdn = None + self._type = WorkflowType.iWorkflow + self._lifetime = 7 * 24 * 3600 + self._workload_id = None + self._request_id = None + + self.distributed = distributed + + self._broker_initialized = False + self._brokers = None + self._broker_timeout = 180 + self._broker_username = None + self._broker_password = None + self._broker_destination = None + + self.init_brokers() + + self._token = str(uuid.uuid4()) + + self._panda_initialized = False + self.init_panda() + self._idds_initialized = False + self.init_idds() + + self._init_env = init_env + + @property + def logger(self): + return logging.getLogger(self.__class__.__name__) + + @property + def distributed(self): + return self._distributed + + @distributed.setter + def distributed(self, value): + self._distributed = value + + @property + def service(self): + return self._service + + @service.setter + def service(self, value): + self._service = value + + @property + def init_env(self): + return self._init_env + + @init_env.setter + def init_env(self, value): + self._init_env = value + + @property + def vo(self): + return self._vo + + @vo.setter + def vo(self, value): + self._vo = value + + @property + def queue(self): + return self._queue + + @queue.setter + def queue(self, value): + self._queue = value + + @property + def site(self): + return self._site + + @site.setter + def site(self, value): + self._site = value + + @property + def cloud(self): + return self._cloud + + @cloud.setter + def cloud(self, value): + self._cloud = value + + @property + def working_group(self): + return self._working_group + + @working_group.setter + def working_group(self, value): + self._working_group = value + + @property + def priority(self): + return self._priority + + 
@priority.setter + def priority(self, value): + self._priority = value + + @property + def core_count(self): + return self._core_count + + @core_count.setter + def core_count(self, value): + self._core_count = value + + @property + def total_memory(self): + return self._total_memory + + @total_memory.setter + def total_memory(self, value): + self._total_memory = value + + @property + def max_walltime(self): + return self._max_walltime + + @max_walltime.setter + def max_walltime(self, value): + self._max_walltime = value + + @property + def max_attempt(self): + return self._max_attempt + + @max_attempt.setter + def max_attempt(self, value): + self._max_attempt = value + + @property + def username(self): + return self._username + + @username.setter + def username(self, value): + self._username = value + + @property + def userdn(self): + return self._userdn + + @userdn.setter + def userdn(self, value): + self._userdn = value + + @property + def type(self): + return self._type + + @type.setter + def type(self, value): + self._type = value + + @property + def lifetime(self): + return self._lifetime + + @lifetime.setter + def lifetime(self, value): + self._lifetime = value + + @property + def request_id(self): + return self._request_id + + @request_id.setter + def request_id(self, value): + self._request_id = int(value) + + @property + def workload_id(self): + return self._workload_id + + @workload_id.setter + def workload_id(self, value): + self._workload_id = value + + @property + def brokers(self): + return self._brokers + + @brokers.setter + def brokers(self, value): + self._brokers = value + + @property + def broker_timeout(self): + return self._broker_timeout + + @broker_timeout.setter + def broker_timeout(self, value): + self._broker_timeout = value + + @property + def broker_username(self): + return self._broker_username + + @broker_username.setter + def broker_username(self, value): + self._broker_username = value + + @property + def broker_password(self): + if self._broker_password: + return self._broker_password + return None + + @broker_password.setter + def broker_password(self, value): + self._broker_password = value + + @property + def broker_destination(self): + return self._broker_destination + + @broker_destination.setter + def broker_destination(self, value): + self._broker_destination = value + + @property + def token(self): + return self._token + + @token.setter + def token(self, value): + self._token = value + + def init_brokers(self): + if not self._broker_initialized: + brokers = os.environ.get("IDDS_BROKERS", None) + broker_destination = os.environ.get("IDDS_BROKER_DESTINATION", None) + broker_timeout = os.environ.get("IDDS_BROKER_TIMEOUT", 180) + broker_username = os.environ.get("IDDS_BROKER_USERNAME", None) + broker_password = os.environ.get("IDDS_BROKER_PASSWORD", None) + if brokers and broker_destination and broker_username and broker_password: + self._brokers = brokers + self._broker_timeout = int(broker_timeout) + self._broker_username = broker_username + self._broker_password = broker_password + self._broker_destination = broker_destination + + self._broker_initialized = True + + def init_idds(self): + if not self._idds_initialized: + self._idds_initialized = True + self._idds_env = self.get_idds_env() + + def init_panda(self): + if not self._panda_initialized: + self._panda_initialized = True + if not self.site: + self.site = os.environ.get("PANDA_SITE", None) + if not self.queue: + self.queue = os.environ.get("PANDA_QUEUE", None) + if not self.cloud: + 
self.cloud = os.environ.get("PANDA_CLOUD", None) + if not self.vo: + self.vo = os.environ.get("PANDA_VO", None) + if not self.working_group: + self.working_group = os.environ.get("PANDA_WORKING_GROUP", None) + + def initialize(self): + # env_list = ['IDDS_HOST', 'IDDS_AUTH_TYPE', 'IDDS_VO', 'IDDS_AUTH_NO_VERIFY', + # 'OIDC_AUTH_ID_TOKEN', 'OIDC_AUTH_VO'] + env_list = ['IDDS_HOST', 'IDDS_AUTH_NO_VERIFY'] + for env in env_list: + if env not in os.environ and env in self._idds_env: + os.environ[env] = self._idds_env[env] + + # env_list = ['PANDA_CONFIG_ROOT', 'PANDA_URL_SSL', 'PANDA_URL', 'PANDACACHE_URL', 'PANDAMON_URL', + # 'PANDA_AUTH', 'PANDA_VERIFY_HOST', 'PANDA_AUTH_VO', 'PANDA_BEHIND_REAL_LB'] + env_list = ['PANDA_URL_SSL', 'PANDA_URL', 'PANDACACHE_URL', 'PANDAMON_URL', + 'PANDA_VERIFY_HOST', 'PANDA_BEHIND_REAL_LB'] + for env in env_list: + if env not in os.environ and env in self._panda_env: + os.environ[env] = self._panda_env[env] + if 'PANDA_CONFIG_ROOT' not in os.environ: + os.environ['PANDA_CONFIG_ROOT'] = os.getcwd() + + def setup(self): + """ + :returns command: `str` to setup the workflow. + """ + if self.service == 'panda': + set_up = self.setup_panda() + elif self.service == 'idds': + set_up = self.setup_idds() + elif self.service == 'sharefs': + set_up = self.setup_sharefs() + else: + set_up = self.setup_sharefs() + + init_env = self.init_env + ret = None + if set_up: + ret = set_up + if init_env: + if ret: + ret = ret + "; " + init_env + else: + ret = init_env + return ret + + def setup_source_files(self): + """ + Setup source files. + """ + if self.service == 'panda': + return self.setup_panda_source_files() + elif self.service == 'idds': + return self.setup_idds_source_files() + elif self.service == 'sharefs': + return self.setup_sharefs_source_files() + return self.setup_sharefs_source_files() + + def download_source_files_from_panda(self, filename): + """Download and extract the tarball from pandacache""" + archive_basename = os.path.basename(filename) + target_dir = os.getcwd() + full_output_filename = os.path.join(target_dir, archive_basename) + logging.info("Downloading %s to %s" % (filename, full_output_filename)) + + if filename.startswith("https:"): + panda_cache_url = os.path.dirname(os.path.dirname(filename)) + os.environ["PANDACACHE_URL"] = panda_cache_url + elif "PANDACACHE_URL" not in os.environ and "PANDA_URL_SSL" in os.environ: + os.environ["PANDACACHE_URL"] = os.environ["PANDA_URL_SSL"] + logging.info("PANDACACHE_URL: %s" % os.environ.get("PANDACACHE_URL", None)) + + from pandaclient import Client + + attempt = 0 + max_attempts = 3 + done = False + while attempt < max_attempts and not done: + attempt += 1 + status, output = Client.getFile(archive_basename, output_path=full_output_filename) + if status == 0: + done = True + logging.info(f"Download archive file from pandacache status: {status}, output: {output}") + if status != 0: + raise RuntimeError("Failed to download archive file from pandacache") + with tarfile.open(full_output_filename, "r:gz") as f: + f.extractall(target_dir) + logging.info(f"Extract {full_output_filename} to {target_dir}") + os.remove(full_output_filename) + logging.info("Remove %s" % full_output_filename) + + def setup_panda(self): + """ + Download source files from the panda cache and return the setup env. + + :returns command: `str` to setup the workflow. + """ + # setup = 'source setup.sh' + # return setup + return None + + def setup_idds(self): + """ + Download source files from the idds cache and return the setup env. 
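+        Currently a placeholder: nothing is downloaded and no setup command is returned.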
+ + :returns command: `str` to setup the workflow. + """ + return None + + def setup_sharefs(self): + """ + Download source files from the share file system or use the codes from the share file system. + Return the setup env. + + :returns command: `str` to setup the workflow. + """ + return None + + def setup_panda_source_files(self): + """ + Download source files from the panda cache and return the setup env. + + :returns command: `str` to setup the workflow. + """ + if self.remote_source_file: + self.download_source_files_from_panda(self.remote_source_file) + return None + + def setup_idds_source_files(self): + """ + Download source files from the idds cache and return the setup env. + + :returns command: `str` to setup the workflow. + """ + if self.remote_source_file: + self.download_source_files_from_panda(self.remote_source_file) + + return None + + def setup_sharefs_source_files(self): + """ + Download source files from the share file system or use the codes from the share file system. + Return the setup env. + + :returns command: `str` to setup the workflow. + """ + return None + + def get_panda_env(self): + env_list = ['PANDA_CONFIG_ROOT', 'PANDA_URL_SSL', 'PANDA_URL', 'PANDACACHE_URL', 'PANDAMON_URL', + 'PANDA_AUTH', 'PANDA_VERIFY_HOST', 'PANDA_AUTH_VO', 'PANDA_BEHIND_REAL_LB'] + ret_envs = {} + for env in env_list: + if env in os.environ: + ret_envs[env] = os.environ[env] + return ret_envs + + def get_archive_name(self): + name = self._name.split(":")[-1] + # name = name + "_" + datetime.datetime.utcnow().strftime("%Y_%m_%d_%H_%M_%S") + archive_name = "%s.tar.gz" % name + return archive_name + + def upload_source_files_to_panda(self): + if not self._source_dir: + return None + + archive_name = self.get_archive_name() + archive_file = create_archive_file('/tmp', archive_name, [self._source_dir]) + logging.info("created archive file: %s" % archive_file) + from pandaclient import Client + + attempt = 0 + max_attempts = 3 + done = False + while attempt < max_attempts and not done: + attempt += 1 + status, out = Client.putFile(archive_file, True) + if status == 0: + done = True + logging.info(f"copy_files_to_pandacache: status: {status}, out: {out}") + if out.startswith("NewFileName:"): + # found the same input sandbox to reuse + archive_file = out.split(":")[-1] + elif out != "True": + logging.error(out) + return None + + filename = os.path.basename(archive_file) + cache_path = os.path.join(os.environ["PANDACACHE_URL"], "cache") + filename = os.path.join(cache_path, filename) + return filename + + def prepare_with_panda(self): + """ + Upload the source files to the panda server. + + :raise Exception when failed. + """ + logging.info("preparing workflow with PanDA") + self._panda_env = self.get_panda_env() + remote_file_name = self.upload_source_files_to_panda() + self.remote_source_file = remote_file_name + logging.info("remote source file: %s" % self.remote_source_file) + logging.info("prepared workflow with PanDA") + + def get_idds_env(self): + env_list = ['IDDS_HOST', 'IDDS_AUTH_TYPE', 'IDDS_VO', 'IDDS_AUTH_NO_VERIFY', + 'OIDC_AUTH_ID_TOKEN', 'OIDC_AUTH_VO', 'IDDS_CONFIG'] + ret_envs = {} + for env in env_list: + if env in os.environ: + ret_envs[env] = os.environ[env] + return ret_envs + + def get_idds_server(self): + if 'IDDS_HOST' in self._idds_env: + return self._idds_env['IDDS_HOST'] + if os.environ.get('IDDS_HOST', None): + return os.environ.get('IDDS_HOST', None) + return None + + def prepare_with_idds(self): + """ + Upload the source files to the idds server. 
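+        Currently unused: prepare() falls back to the PanDA cache upload for the 'idds' service.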
+ + :raise Exception when failed. + """ + # idds_env = self.get_idds_env() + pass + + def prepare_with_sharefs(self): + """ + Upload the source files to the share file system + Or directly use the source files on the share file system.. + + :raise Exception when failed. + """ + pass + + def prepare(self): + """ + Prepare the workflow. + """ + if self.service == 'panda': + return self.prepare_with_panda() + elif self.service == 'idds': + # return self.prepare_with_idds() + return self.prepare_with_panda() + elif self.service == 'sharefs': + return self.prepare_with_sharefs() + return self.prepare_with_sharefs() + + +class Workflow(Base): + + def __init__(self, func=None, service='panda', context=None, source_dir=None, distributed=True, + args=None, kwargs={}, group_kwargs=[], update_kwargs=None, init_env=None, is_unique_func_name=False): + """ + Init a workflow. + """ + super(Workflow, self).__init__() + self.prepared = False + + # self._func = func + self._func, self._func_name_and_args = self.get_func_name_and_args(func, args, kwargs, group_kwargs) + self._update_kwargs = update_kwargs + + self._name = self._func_name_and_args[0] + if self._name: + self._name = self._name.replace('__main__:', '').replace('.py', '').replace(':', '.') + if not is_unique_func_name: + if self._name: + self._name = self._name + "_" + datetime.datetime.utcnow().strftime("%Y_%m_%d_%H_%M_%S") + source_dir = self.get_source_dir(self._func, source_dir) + if context is not None: + self._context = context + else: + self._context = WorkflowContext(name=self._name, service=service, source_dir=source_dir, distributed=distributed, init_env=init_env) + + @property + def service(self): + return self._context.service + + @property + def internal_id(self): + return self._context.internal_id + + @internal_id.setter + def internal_id(self, value): + self._context.internal_id = value + + @service.setter + def service(self, value): + self._context.service = value + + @property + def name(self): + return self._name + + @name.setter + def name(self, value): + self._name = value + + @property + def request_id(self): + return self._context.request_id + + @request_id.setter + def request_id(self, value): + self._context.request_id = value + + def set_request_id(self, request_id): + self.request_id = request_id + + @property + def vo(self): + return self._context.vo + + @vo.setter + def vo(self, value): + self._context.vo = value + + @property + def site(self): + return self._context.site + + @site.setter + def site(self, value): + self._context.site = value + + @property + def queue(self): + return self._context.queue + + @queue.setter + def queue(self, value): + self._context.queue = value + + def get_site(self): + return self.site + + @property + def cloud(self): + return self._context.cloud + + @cloud.setter + def cloud(self, value): + self._context.cloud = value + + @property + def working_group(self): + return self._context.working_group + + @working_group.setter + def working_group(self, value): + self._context.working_group = value + + @property + def priority(self): + return self._context.priority + + @priority.setter + def priority(self, value): + self._context.priority = value + + @property + def core_count(self): + return self._context.core_count + + @core_count.setter + def core_count(self, value): + self._context.core_count = value + + @property + def total_memory(self): + return self._context.total_memory + + @total_memory.setter + def total_memory(self, value): + self._context.total_memory = value + + @property + def 
max_walltime(self): + return self._context.max_walltime + + @max_walltime.setter + def max_walltime(self, value): + self._context.max_walltime = value + + @property + def max_attempt(self): + return self._context.max_attempt + + @max_attempt.setter + def max_attempt(self, value): + self._context.max_attempt = value + + @property + def username(self): + return self._context.username + + @username.setter + def username(self, value): + self._context.username = value + + @property + def userdn(self): + return self._context.userdn + + @userdn.setter + def userdn(self, value): + self._context.userdn = value + + @property + def type(self): + return self._context.type + + @type.setter + def type(self, value): + self._context.type = value + + @property + def lifetime(self): + return self._context.lifetime + + @lifetime.setter + def lifetime(self, value): + self._context.lifetime = value + + @property + def workload_id(self): + return self._context.workload_id + + @workload_id.setter + def workload_id(self, value): + self._context.workload_id = value + + def get_workload_id(self): + return self.workload_id + + @property + def token(self): + return self._context.token + + @token.setter + def token(self, value): + self._context.token = value + + def get_work_tag(self): + return 'iWorkflow' + + def get_work_type(self): + return WorkflowType.iWorkflow + + def get_work_name(self): + return self._name + + @property + def group_parameters(self): + return self._func_name_and_args[3] + + @group_parameters.setter + def group_parameters(self, value): + raise Exception("Not allowed to update group parameters") + + def to_dict(self): + func = self._func + self._func = None + obj = super(Workflow, self).to_dict() + self._func = func + return obj + + def get_source_dir(self, func, source_dir): + if source_dir: + return source_dir + if func: + if inspect.isbuiltin(func): + return None + source_file = inspect.getsourcefile(func) + if not source_file: + return None + file_path = os.path.abspath(source_file) + return os.path.dirname(file_path) + return None + + def prepare(self): + """ + Prepare the workflow: for example, uploading the source code to the cache server. + :returns command: `str` to setup the workflow. + """ + if not self.prepared: + self._context.prepare() + self.prepared = True + + def submit_to_idds_server(self): + """ + Submit the workflow to the iDDS server. + + :returns id: the workflow id. + :raise Exception when failing to submit the workflow. + """ + # iDDS ClientManager + from idds.client.clientmanager import ClientManager + + client = ClientManager(host=self._context.get_idds_server()) + request_id = client.submit(self, use_dataset_name=False) + + logging.info("Submitted into iDDS with request id=%s", str(request_id)) + return request_id + + def submit_to_panda_server(self): + """ + Submit the workflow to the iDDS server through the PanDA service. + + :returns id: the workflow id. + :raise Exception when failing to submit the workflow. + """ + import idds.common.utils as idds_utils + import pandaclient.idds_api as idds_api + + idds_server = self._context.get_idds_server() + client = idds_api.get_api(idds_utils.json_dumps, + idds_host=idds_server, + compress=True, + manager=True) + request_id = client.submit(self, username=None, use_dataset_name=False) + + logging.info("Submitted into PanDA-iDDS with request id=%s", str(request_id)) + return request_id + + def submit(self): + """ + Submit the workflow to the iDDS server. + + :returns id: the workflow id.
+ :raise Exception when failing to submit the workflow. + """ + self.prepare() + if self.service == 'panda': + request_id = self.submit_to_panda_server() + else: + request_id = self.submit_to_idds_server() + + try: + self._context.request_id = int(request_id) + return request_id + except Exception as ex: + logging.info("Request id (%s) is not integer, there should be some submission errors: %s" % (request_id, str(ex))) + + return None + + def setup(self): + """ + :returns command: `str` to setup the workflow. + """ + return self._context.setup() + + def setup_source_files(self): + """ + Setup location of source files + """ + return self._context.setup_source_files() + + def load(self, func_name): + """ + Load the function from the source files. + + :raise Exception + """ + os.environ['IDDS_IWORKFLOW_LOAD_WORKFLOW'] = 'true' + func = super(Workflow, self).load(func_name) + del os.environ['IDDS_IWORKFLOW_LOAD_WORKFLOW'] + + return func + + def pre_run(self): + # test AsyncResult + workflow_context = self._context + if workflow_context.distributed: + logging.info("Test AsyncResult") + a_ret = AsyncResult(workflow_context, wait_num=1, timeout=30) + a_ret.subscribe() + + async_ret = AsyncResult(workflow_context, internal_id=a_ret.internal_id) + test_result = "AsyncResult test (request_id: %s)" % workflow_context.request_id + logging.info("AsyncResult publish: %s" % test_result) + async_ret.publish(test_result) + + ret_q = a_ret.wait_result(force_return_results=True) + logging.info("AsyncResult results: %s" % str(ret_q)) + if ret_q: + if ret_q == test_result: + logging.info("AsyncResult test succeeded") + return True + else: + logging.info("AsyncResult test failed (published: %s, received: %s)" % (test_result, ret_q)) + return False + else: + logging.info("Not received results") + return False + return True + + def run(self): + """ + Run the workflow. 
+ """ + # with self: + if True: + self.pre_run() + + func_name, args, kwargs, group_kwargs = self._func_name_and_args + if self._func is None: + func = self.load(func_name) + self._func = func + ret = self.run_func(self._func, args, kwargs) + + return ret + + # Context Manager ----------------------------------------------- + def __enter__(self): + WorkflowCanvas.push_managed_workflow(self) + return self + + def __exit__(self, _type, _value, _tb): + WorkflowCanvas.pop_managed_workflow() + + # /Context Manager ---------------------------------------------- + + def get_run_command(self): + cmd = "run_workflow --type workflow " + cmd += "--context %s --original_args %s " % (encode_base64(json_dumps(self._context)), + encode_base64(json_dumps(self._func_name_and_args))) + cmd += "--update_args ${IN/L}" + return cmd + + def get_runner(self): + setup = self.setup() + cmd = "" + run_command = self.get_run_command() + + if setup: + cmd = ' --setup "' + setup + '" ' + if cmd: + cmd = cmd + " " + run_command + else: + cmd = run_command + return cmd + + def get_func_name(self): + func_name = self._func_name_and_args[0] + return func_name + + +# foo = workflow(arg)(foo) +def workflow(func=None, *, lazy=False, service='panda', source_dir=None, primary=False, distributed=True): + if func is None: + return functools.partial(workflow, lazy=lazy) + + if 'IDDS_IWORKFLOW_LOAD_WORKFLOW' in os.environ: + return func + + @functools.wraps(func) + def wrapper(*args, **kwargs): + try: + f = Workflow(func, service=service, source_dir=source_dir, distributed=distributed) + + if lazy: + return f + + return f.run() + except Exception as ex: + logging.error("Failed to run workflow %s: %s" % (func, ex)) + raise ex + except: + raise + return wrapper + + +def workflow_old(func=None, *, lazy=False, service='panda', source_dir=None, primary=False, distributed=True): + + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + try: + f = Workflow(func, service=service, source_dir=source_dir, distributed=distributed) + + if lazy: + return f + + return f.run() + except Exception as ex: + logging.error("Failed to run workflow %s: %s" % (func, ex)) + raise ex + except: + raise + return wrapper + + if func is None: + return decorator + else: + return decorator(func) diff --git a/workflow/lib/idds/workflow/work.py b/workflow/lib/idds/workflow/work.py index a15f6b10..3258f5b8 100644 --- a/workflow/lib/idds/workflow/work.py +++ b/workflow/lib/idds/workflow/work.py @@ -1425,6 +1425,14 @@ def is_subfinished(self): return True return False + def is_processed(self, synchronize=True): + """ + *** Function called by Transformer agent. + """ + if self.status in [WorkStatus.Finished, WorkStatus.SubFinished] and self.substatus not in [WorkStatus.ToCancel, WorkStatus.ToSuspend, WorkStatus.ToResume]: + return True + return False + def is_failed(self): """ *** Function called by Transformer agent. diff --git a/workflow/lib/idds/workflowv2/work.py b/workflow/lib/idds/workflowv2/work.py index e3caa610..7737343f 100644 --- a/workflow/lib/idds/workflowv2/work.py +++ b/workflow/lib/idds/workflowv2/work.py @@ -1469,6 +1469,14 @@ def is_subfinished(self, synchronize=True): return True return False + def is_processed(self, synchronize=True): + """ + *** Function called by Transformer agent. 
+ """ + if self.status in [WorkStatus.Finished, WorkStatus.SubFinished] and self.substatus not in [WorkStatus.ToCancel, WorkStatus.ToSuspend, WorkStatus.ToResume]: + return True + return False + def is_failed(self, synchronize=True): """ *** Function called by Transformer agent. diff --git a/workflow/tools/make/environment.yaml b/workflow/tools/make/environment.yaml new file mode 100644 index 00000000..7c5ce80f --- /dev/null +++ b/workflow/tools/make/environment.yaml @@ -0,0 +1,20 @@ +name: idds +dependencies: +- python==3.6 +- pip +- pip: + - argcomplete + - requests + - tabulate + - urllib3==1.26.18 + - setuptools_rust + - packaging + - anytree + - networkx + - stomp.py + - panda-client + - cffi + - charset_normalizer + - idna + - pycparser + - websocket diff --git a/workflow/tools/make/make.sh b/workflow/tools/make/make.sh new file mode 100644 index 00000000..5948a439 --- /dev/null +++ b/workflow/tools/make/make.sh @@ -0,0 +1,79 @@ +#!/bin/bash + +CurrentDir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +ToolsDir="$( dirname "$CurrentDir" )" +WorkflowDir="$( dirname "$ToolsDir" )" +RootDir="$( dirname "$WorkflowDir" )" + +EXECNAME=${WorkflowDir}/bin/run_workflow_wrapper +rm -fr $EXECNAME + +workdir=/tmp/idds +tmpzip=/tmp/idds/tmp.zip +rm -fr $workdir +mkdir -p $workdir + +echo "setup virtualenv $workdir" +# source /afs/cern.ch/user/w/wguan/workdisk/conda/setup_mini.sh +# conda deactivate +# echo conda env create --prefix=$workdir -f ${WorkflowDir}/tools/make/environment.yaml +# conda env create --prefix=$workdir -f ${WorkflowDir}/tools/make/environment.yaml +# conda activate $workdir + +python3 -m venv $workdir +source $workdir/bin/activate + +echo "install panda client" +pip install panda-client +pip install tabulate requests urllib3==1.26.18 argcomplete packaging anytree networkx stomp.py + +echo "install idds-common" +python ${RootDir}/common/setup.py clean --all +python ${RootDir}/common/setup.py install --old-and-unmanageable --force + +echo "install idds-client" +python ${RootDir}/client/setup.py clean --all +python ${RootDir}/client/setup.py install --old-and-unmanageable --force + +echo "install idds-workflow" +python ${RootDir}/workflow/setup.py clean --all +python ${RootDir}/workflow/setup.py install --old-and-unmanageable --force + +python_lib_path=`python -c 'from sysconfig import get_path; print(get_path("purelib"))'` +echo $python_lib_path + +cur_dir=$PWD + +# cd ${python_lib_path} +# # for libname in idds pandaclient pandatools tabulate pyjwt requests urllib3 argcomplete cryptography packaging anytree networkx; do +# for libname in idds pandaclient pandatools tabulate jwt requests urllib3 argcomplete cryptography packaging stomp; do +# echo zip -r $tmpzip $libname +# zip -r $tmpzip $libname +# done +# cd - + +cd $workdir +mkdir lib_py +# for libname in idds pandaclient pandatools tabulate pyjwt requests urllib3 argcomplete cryptography packaging anytree networkx; do +# for libname in idds pandaclient pandatools tabulate jwt requests urllib3 argcomplete cryptography packaging stomp cffi charset_normalizer docopt.py idna pycparser six.py websocket _cffi_backend*; do +for libname in idds pandaclient pandatools tabulate requests urllib3 argcomplete stomp websocket charset_normalizer idna certifi; do + echo cp -fr ${python_lib_path}/$libname lib_py + cp -fr ${python_lib_path}/$libname lib_py +done +echo zip -r $tmpzip lib_py +zip -r $tmpzip lib_py +cd - + +cd $workdir +echo zip -r $tmpzip etc +zip -r $tmpzip etc +cd - + +cd ${WorkflowDir} +echo zip -r $tmpzip bin +zip -r 
$tmpzip bin + +cd - + +cat ${WorkflowDir}/tools/make/zipheader $tmpzip > $EXECNAME +chmod +x $EXECNAME diff --git a/workflow/tools/make/make_old.sh b/workflow/tools/make/make_old.sh new file mode 100644 index 00000000..dc5a819c --- /dev/null +++ b/workflow/tools/make/make_old.sh @@ -0,0 +1,72 @@ +#!/bin/bash + +CurrentDir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +ToolsDir="$( dirname "$CurrentDir" )" +WorkflowDir="$( dirname "$ToolsDir" )" +RootDir="$( dirname "$WorkflowDir" )" + +EXECNAME=${WorkflowDir}/bin/run_workflow_wrapper +rm -fr $EXECNAME + +workdir=/tmp/idds +tmpzip=/tmp/idds/tmp.zip +rm -fr $workdir +mkdir -p $workdir + +echo "setup virtualenv $workdir" +python3 -m venv $workdir +source $workdir/bin/activate + +echo "install panda client" +pip install panda-client +pip install tabulate pyjwt requests urllib3==1.26.18 argcomplete setuptools_rust cryptography packaging anytree networkx stomp.py + +echo "install idds-common" +python ${RootDir}/common/setup.py clean --all +python ${RootDir}/common/setup.py install --old-and-unmanageable --force + +echo "install idds-client" +python ${RootDir}/client/setup.py clean --all +python ${RootDir}/client/setup.py install --old-and-unmanageable --force + +echo "install idds-workflow" +python ${RootDir}/workflow/setup.py clean --all +python ${RootDir}/workflow/setup.py install --old-and-unmanageable --force + +python_lib_path=`python -c 'from sysconfig import get_path; print(get_path("purelib"))'` +echo $python_lib_path + +cur_dir=$PWD + +# cd ${python_lib_path} +# # for libname in idds pandaclient pandatools tabulate pyjwt requests urllib3 argcomplete cryptography packaging anytree networkx; do +# for libname in idds pandaclient pandatools tabulate jwt requests urllib3 argcomplete cryptography packaging stomp; do +# echo zip -r $tmpzip $libname +# zip -r $tmpzip $libname +# done +# cd - + +cd $workdir +mkdir lib_py +# for libname in idds pandaclient pandatools tabulate pyjwt requests urllib3 argcomplete cryptography packaging anytree networkx; do +for libname in idds pandaclient pandatools tabulate jwt requests urllib3 argcomplete cryptography packaging stomp cffi charset_normalizer docopt.py idna pycparser six.py websocket; do + echo cp -fr ${python_lib_path}/$libname lib_py + cp -fr ${python_lib_path}/$libname lib_py +done +echo zip -r $tmpzip lib_py +zip -r $tmpzip lib_py +cd - + +cd $workdir +echo zip -r $tmpzip etc +zip -r $tmpzip etc +cd - + +cd ${WorkflowDir} +echo zip -r $tmpzip bin +zip -r $tmpzip bin + +cd - + +cat ${WorkflowDir}/tools/make/zipheader $tmpzip > $EXECNAME +chmod +x $EXECNAME diff --git a/workflow/tools/make/zipheader b/workflow/tools/make/zipheader new file mode 100644 index 00000000..086dd2cb --- /dev/null +++ b/workflow/tools/make/zipheader @@ -0,0 +1,103 @@ +#!/bin/bash + +which unzip > /dev/null +if [ $? -ne 0 ]; then +echo "ERROR: unzip is missing" +exit 111 +fi + +unzip -o $0 > /dev/null 2>&1 +# PYNAME=`echo $0 | sed -e "s/\(-.*$\)/.py/"` + +current_dir=$PWD +export PATH=${current_dir}:${current_dir}/bin:$PATH + +chmod +x ${current_dir}/bin/* +ln -fs ${current_dir}/bin/* ${current_dir}/ + +export PYTHONPATH=${current_dir}/lib_py:$PYTHONPATH +export IDDS_CONFIG=${current_dir}/etc/idds/idds.cfg.client.template + +if [[ ! -z "${PANDA_AUTH_DIR}" ]] && [[ ! 
-z "${PANDA_AUTH_ORIGIN}" ]]; then + export PANDA_AUTH_ID_TOKEN=$(cat $PANDA_AUTH_DIR); + export PANDA_AUTH_VO=$PANDA_AUTH_ORIGIN; + export IDDS_OIDC_TOKEN=$(cat $PANDA_AUTH_DIR); + export IDDS_VO=$PANDA_AUTH_ORIGIN; + export PANDA_AUTH=oidc; +else + unset PANDA_AUTH; + export IDDS_AUTH_TYPE=x509_proxy; + if [ -f "$X509_USER_PROXY" ]; then + cp $X509_USER_PROXY ${current_dir}/x509_proxy + fi +fi; + +export PANDA_CONFIG_ROOT=$(pwd); +export PANDA_VERIFY_HOST=off; +export PANDA_BEHIND_REAL_LB=true; + +myargs="$@" +setup="" + +POSITIONAL=() +while [[ $# -gt 0 ]]; do + key="$1" + case $key in + --setup) + setup="$2" + shift + shift + ;; + *) + POSITIONAL+=("$1") # save it in an array for later + shift + ;; + esac +done + +set -- "${POSITIONAL[@]}" # restore positional parameters + +echo $setup + +run_args=$@ +echo $run_args + +cmdfile="run_workflow.sh" +cat <<- EOF > ./$cmdfile +#!/bin/bash + +current_dir=\$PWD +export PATH=\${current_dir}:\${current_dir}/tmp_bin:\${current_dir}/bin:\$PATH +export PYTHONPATH=\${current_dir}:\${current_dir}/lib_py:\$PYTHONPATH + +if ! command -v python &> /dev/null +then + echo "no python, alias python3 to python" + alias python=python3 +fi + +if [ -f \${current_dir}/x509_proxy ]; then + export X509_USER_PROXY=\${current_dir}/x509_proxy +fi + +$run_args + +EOF + +chmod +x ./$cmdfile + +# exec python "$@" +# python "$@" +# exec "$@" + +$setup ./$cmdfile +ret=$? + +echo pwd +pwd; ls + +echo rm -fr ${current_dir}/lib_py ${current_dir}/etc ${current_dir}/bin ${current_dir}/tmp_bin ${current_dir}/run_workflow_wrapper ${current_dir}/__pycache__ +rm -fr ${current_dir}/lib_py ${current_dir}/etc ${current_dir}/bin ${current_dir}/tmp_bin ${current_dir}/run_workflow_wrapper ${current_dir}/__pycache__ + +echo "return code: " $ret +exit $ret
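
For orientation, a minimal usage sketch of the `workflow` decorator and `Workflow.submit()` added by this patch. The import path `idds.iworkflow.workflow` and the function name `my_task` are assumptions for illustration (the file path of that hunk is not visible here); the behaviour follows the code as written: with `lazy=True` the decorated call returns the `Workflow` object instead of running it, and `submit()` first calls `prepare()` (uploading the source directory to pandacache) and then submits the request to PanDA/iDDS. It assumes the idds client/workflow packages from this patch are installed and the PanDA/iDDS environment (PANDA_URL_SSL, IDDS_HOST, authentication) is already configured.

    # Hypothetical usage sketch; not part of the patch itself.
    from idds.iworkflow.workflow import workflow  # assumed module path

    @workflow(lazy=True)
    def my_task():
        # user payload; executed remotely by the run_workflow wrapper
        print("hello from the workflow payload")
        return 0

    if __name__ == '__main__':
        wf = my_task()            # lazy=True: returns the Workflow object
        request_id = wf.submit()  # prepare() uploads sources, then submits to PanDA/iDDS
        print("submitted request:", request_id)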