Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 9 additions & 102 deletions agents/matmaster_agent/base_agents/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
CURRENT_ENV,
FRONTEND_STATE_KEY,
LOCAL_EXECUTOR,
MATERIALS_ACCESS_KEY,
MATERIALS_PROJECT_ID,
OPENAPI_HOST,
Transfer2Agent,
)
Expand Down Expand Up @@ -178,14 +180,6 @@ def _get_projectId(ctx: Union[InvocationContext, ToolContext]):
)


@check_None_wrapper
def _get_machineType(ctx: Union[InvocationContext, ToolContext]):
session_state = get_session_state(ctx)
return session_state[FRONTEND_STATE_KEY]['biz'].get('machineType') or os.getenv(
'MACHINE_TYPE', 'c32_m64_cpu'
)


def _inject_ak(ctx: Union[InvocationContext, ToolContext], executor, storage):
access_key = _get_ak(ctx)
if executor is not None:
Expand Down Expand Up @@ -219,8 +213,7 @@ def _inject_projectId(ctx: Union[InvocationContext, ToolContext], executor, stor


def _inject_username(ctx: Union[InvocationContext, ToolContext], executor):
access_key = _get_ak(ctx)
username = ak_to_username(access_key=access_key)
username = ak_to_username(access_key=MATERIALS_ACCESS_KEY)
if username:
if executor is not None:
if executor['type'] == 'dispatcher': # BohriumExecutor
Expand All @@ -238,8 +231,7 @@ def _inject_username(ctx: Union[InvocationContext, ToolContext], executor):


def _inject_ticket(ctx: Union[InvocationContext, ToolContext], executor):
access_key = _get_ak(ctx)
ticket = ak_to_ticket(access_key=access_key)
ticket = ak_to_ticket(access_key=MATERIALS_ACCESS_KEY)
if ticket:
if executor is not None:
if executor['type'] == 'dispatcher': # BohriumExecutor
Expand Down Expand Up @@ -270,65 +262,6 @@ def _inject_current_env(executor):
return executor


def _inject_machine_type(ctx: Union[InvocationContext, ToolContext], executor):
machine_type = _get_machineType(ctx)
session_state = get_session_state(ctx)
logger.info(
f"biz = {session_state[FRONTEND_STATE_KEY]['biz']}; "
f"machineType = {machine_type}"
)
if executor is not None:
if executor['type'] == 'dispatcher': # BohriumExecutor
current_machine_type = executor['machine']['remote_profile']['machine_type']
if not current_machine_type:
executor['machine']['remote_profile']['machine_type'] = str(
machine_type
)
logger.info(f"After inject_machine_type, executor = {executor}")

return executor


def inject_ak_projectId(func: BeforeToolCallback) -> BeforeToolCallback:
@wraps(func)
async def wrapper(
tool: BaseTool, args: dict, tool_context: ToolContext
) -> Optional[dict]:
# 两步操作:
# 1. 调用被装饰的 before_tool_callback;
# 2. 如果调用的 before_tool_callback 有返回值,以这个为准
if (before_tool_result := await func(tool, args, tool_context)) is not None:
return before_tool_result

# 如果 tool 为 Transfer2Agent,不做 ak 和 project_id 设置/校验
if tool.name == Transfer2Agent:
return None

# 如果 tool 不是 CalculationMCPTool,不应该调用这个 callback
if not isinstance(tool, CalculationMCPTool):
raise TypeError(
'Not CalculationMCPTool type, current tool does not have <storage>'
)

# 获取 access_key
access_key, tool.executor, tool.storage = _inject_ak(
tool_context, tool.executor, tool.storage
)

# 获取 project_id
try:
project_id, tool.executor, tool.storage = _inject_projectId(
tool_context, tool.executor, tool.storage
)
except ValueError as e:
raise ValueError('ProjectId is invalid') from e

tool_context.state['ak'] = access_key
tool_context.state['project_id'] = project_id

return wrapper


def inject_username_ticket(func: BeforeToolCallback) -> BeforeToolCallback:
@wraps(func)
async def wrapper(
Expand Down Expand Up @@ -383,14 +316,14 @@ async def wrapper(
job_create_url = f"{OPENAPI_HOST}/openapi/v1/job/create"
user_project_list_url = f"{OPENAPI_HOST}/openapi/v1/open/user/project/list"
payload = {
'projectId': int(tool_context.state['project_id']),
'projectId': MATERIALS_PROJECT_ID,
'name': 'check_job_create',
}
params = {'accessKey': tool_context.state['ak']}
params = {'accessKey': MATERIALS_ACCESS_KEY}

logger.info(
f"[check_job_create] project_id = {tool_context.state['project_id']}, "
f"ak = {tool_context.state['ak']}"
f"[check_job_create] project_id = {MATERIALS_PROJECT_ID}, "
f"ak = {MATERIALS_ACCESS_KEY}"
)

async with aiohttp.ClientSession() as session:
Expand All @@ -402,7 +335,7 @@ async def wrapper(
project_name = [
item['project_name']
for item in res['data']['items']
if item['project_id'] == int(tool_context.state['project_id'])
if item['project_id'] == MATERIALS_PROJECT_ID
][0]

async with aiohttp.ClientSession() as session:
Expand All @@ -420,32 +353,6 @@ async def wrapper(
return wrapper


def inject_machineType(func: BeforeToolCallback) -> BeforeToolCallback:
@wraps(func)
async def wrapper(
tool: BaseTool, args: dict, tool_context: ToolContext
) -> Optional[dict]:
# 两步操作:
# 1. 调用被装饰的 before_tool_callback;
# 2. 如果调用的 before_tool_callback 有返回值,以这个为准
if (before_tool_result := await func(tool, args, tool_context)) is not None:
return before_tool_result

# 如果 tool 为 Transfer2Agent,不做 ak 和 project_id 设置/校验
if tool.name == Transfer2Agent:
return None

# 如果 tool 不是 CalculationMCPTool,不应该调用这个 callback
if not isinstance(tool, CalculationMCPTool):
raise TypeError(
'Not CalculationMCPTool type, current tool does not have <storage>'
)

_inject_machine_type(tool_context, tool.executor)

return wrapper


# 总应该在最后
def catch_before_tool_callback_error(func: BeforeToolCallback) -> BeforeToolCallback:
@wraps(func)
Expand Down
32 changes: 18 additions & 14 deletions agents/matmaster_agent/base_agents/job_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@
default_after_model_callback,
default_after_tool_callback,
default_before_tool_callback,
inject_ak_projectId,
inject_current_env,
inject_machineType,
inject_username_ticket,
remove_function_call,
tgz_oss_to_oss_list,
Expand All @@ -37,6 +35,8 @@
LOADING_START,
LOADING_STATE_KEY,
LOADING_TITLE,
MATERIALS_ACCESS_KEY,
MATERIALS_PROJECT_ID,
TMP_FRONTEND_STATE_KEY,
ModelRole,
OpenAPIJobAPI,
Expand Down Expand Up @@ -160,14 +160,8 @@ def __init__(

# Todo: support List[before_tool_callback]
before_tool_callback = catch_before_tool_callback_error(
inject_machineType(
check_job_create(
inject_current_env(
inject_username_ticket(
inject_ak_projectId(before_tool_callback)
)
)
)
check_job_create(
inject_current_env(inject_username_ticket(before_tool_callback))
)
)
after_tool_callback = check_before_tool_callback_effect(
Expand Down Expand Up @@ -410,6 +404,9 @@ async def _run_async_impl(
):
raw_result = part.function_response.response['result']
results = json.loads(raw_result.content[0].text)
logger.info(
f"[SubmitCoreCalculationMCPLlmAgent] results = {results}"
)
origin_job_id = results['job_id']
job_name = part.function_response.name
job_status = results['status']
Expand Down Expand Up @@ -554,11 +551,15 @@ async def _run_async_impl(
try:
await self.tools[0].get_tools()
if not ctx.session.state['dflow']:
access_key, Executor, BohriumStorge = _inject_ak(
ctx, get_BohriumExecutor(), get_BohriumStorage()
access_key, Executor, BohriumStorge = (
MATERIALS_ACCESS_KEY,
get_BohriumExecutor(),
get_BohriumStorage(),
)
project_id, Executor, BohriumStorge = _inject_projectId(
ctx, Executor, BohriumStorge
project_id, Executor, BohriumStorge = (
MATERIALS_PROJECT_ID,
Executor,
BohriumStorge,
)
else:
access_key, Executor, BohriumStorge = _inject_ak(
Expand Down Expand Up @@ -850,6 +851,9 @@ async def _run_async_impl(
params_check_completed_json: dict = json.loads(
response.choices[0].message.content
)
logger.info(
f"[BaseAsyncJobAgent] params_check_completed_json = {params_check_completed_json}"
)
params_check_completed = params_check_completed_json['flag']
params_check_reason = params_check_completed_json['reason']
params_check_msg = params_check_completed_json['analyzed_messages']
Expand Down
6 changes: 2 additions & 4 deletions agents/matmaster_agent/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
from google.genai import types
from google.genai.types import FunctionCall, Part

from agents.matmaster_agent.base_agents.callback import _get_ak
from agents.matmaster_agent.constant import FRONTEND_STATE_KEY
from agents.matmaster_agent.constant import FRONTEND_STATE_KEY, MATERIALS_ACCESS_KEY
from agents.matmaster_agent.locales import i18n
from agents.matmaster_agent.model import UserContent
from agents.matmaster_agent.prompt import get_user_content_lang
Expand Down Expand Up @@ -112,7 +111,6 @@ async def matmaster_check_job_status(
jobs_dict
): # 确认当前有在运行中的任务
running_job_ids = get_running_jobs_detail(jobs_dict) # 从 state 里面拿
access_key = _get_ak(callback_context) # 从 state 或环境变量里面拿
if callback_context.state['target_language'] in [
'Chinese',
'zh-CN',
Expand All @@ -130,7 +128,7 @@ async def matmaster_check_job_status(
'[matmaster_check_job_status] new LlmResponse, prepare call API'
)
job_status = await get_job_status(
job_query_url, access_key=access_key
job_query_url, access_key=MATERIALS_ACCESS_KEY
) # 查询任务的最新状态
callback_context.state['new_query_job_status'][
'origin_job_id'
Expand Down
64 changes: 34 additions & 30 deletions agents/matmaster_agent/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,38 @@
# DB
DBUrl = os.getenv('SESSION_API_URL')

OPENAPI_HOST = ''
DFLOW_HOST = ''
DFLOW_K8S_API_SERVER = ''
BOHRIUM_API_URL = ''

CURRENT_ENV = os.getenv('OPIK_PROJECT_NAME', 'prod')
if CURRENT_ENV == 'test':
OPENAPI_HOST = 'https://openapi.test.dp.tech'
DFLOW_HOST = 'https://lbg-workflow-mlops.test.dp.tech'
DFLOW_K8S_API_SERVER = 'https://lbg-workflow-mlops.test.dp.tech'
BOHRIUM_API_URL = 'https://bohrium-api.test.dp.tech'
elif CURRENT_ENV == 'uat':
OPENAPI_HOST = 'https://openapi.uat.dp.tech'
BOHRIUM_API_URL = 'https://bohrium-api.uat.dp.tech'
elif CURRENT_ENV == 'prod':
OPENAPI_HOST = 'https://openapi.dp.tech'
DFLOW_HOST = 'https://workflows.deepmodeling.com'
DFLOW_K8S_API_SERVER = 'https://workflows.deepmodeling.com'
BOHRIUM_API_URL = 'https://bohrium-api.dp.tech'

OpenAPIJobAPI = f"{OPENAPI_HOST}/openapi/v1/sandbox/job"

MATERIALS_ACCESS_KEY = str(os.getenv('MATERIALS_ACCESS_KEY'))
MATERIALS_PROJECT_ID = int(os.getenv('MATERIALS_PROJECT_ID'))

# Bohrium Constant
BohriumStorge = {
'type': 'https',
'plugin': {
'type': 'bohrium',
'access_key': '',
'project_id': -1,
'access_key': MATERIALS_ACCESS_KEY,
'project_id': MATERIALS_PROJECT_ID,
'app_key': 'agent',
},
}
Expand All @@ -34,38 +59,17 @@
'batch_type': 'OpenAPI',
'context_type': 'OpenAPI',
'remote_profile': {
'access_key': '',
'project_id': -1,
'access_key': MATERIALS_ACCESS_KEY,
'project_id': MATERIALS_PROJECT_ID,
'app_key': 'agent',
'image_address': 'registry.dp.tech/dptech/dp/native/prod-19853/dpa-mcp:0.0.0',
'image_address': '',
'platform': 'ali',
'machine_type': '',
'machine_type': 'c2_m8_cpu',
},
},
'resources': {'envs': {}},
'resources': {'envs': {'BOHRIUM_PROJECT_ID': MATERIALS_PROJECT_ID}},
}

OPENAPI_HOST = ''
DFLOW_HOST = ''
DFLOW_K8S_API_SERVER = ''
BOHRIUM_API_URL = ''

CURRENT_ENV = os.getenv('OPIK_PROJECT_NAME', 'prod')
if CURRENT_ENV == 'test':
OPENAPI_HOST = 'https://openapi.test.dp.tech'
DFLOW_HOST = 'https://lbg-workflow-mlops.test.dp.tech'
DFLOW_K8S_API_SERVER = 'https://lbg-workflow-mlops.test.dp.tech'
BOHRIUM_API_URL = 'https://bohrium-api.test.dp.tech'
elif CURRENT_ENV == 'uat':
OPENAPI_HOST = 'https://openapi.uat.dp.tech'
BOHRIUM_API_URL = 'https://bohrium-api.uat.dp.tech'
elif CURRENT_ENV == 'prod':
OPENAPI_HOST = 'https://openapi.dp.tech'
DFLOW_HOST = 'https://workflows.deepmodeling.com'
DFLOW_K8S_API_SERVER = 'https://workflows.deepmodeling.com'
BOHRIUM_API_URL = 'https://bohrium-api.dp.tech'
OpenAPIJobAPI = f"{OPENAPI_HOST}/openapi/v1/job"

DFlowExecutor = {
'type': 'local',
'dflow': True,
Expand All @@ -74,8 +78,8 @@
'DFLOW_K8S_API_SERVER': DFLOW_K8S_API_SERVER,
'DFLOW_S3_REPO_KEY': 'oss-bohrium',
'DFLOW_S3_STORAGE_CLIENT': 'dflow.plugins.bohrium.TiefblueClient',
'BOHRIUM_ACCESS_KEY': '',
'BOHRIUM_PROJECT_ID': '',
'BOHRIUM_ACCESS_KEY': MATERIALS_ACCESS_KEY,
'BOHRIUM_PROJECT_ID': str(MATERIALS_PROJECT_ID),
'BOHRIUM_APP_KEY': 'agent',
},
}
Expand Down
2 changes: 1 addition & 1 deletion agents/matmaster_agent/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class JobResult(BaseModel):

class BohrJobInfo(BaseModel):
origin_job_id: str
job_id: int
job_id: Union[int, str]
job_query_url: str
job_detail_url: str
job_status: JobStatus
Expand Down
2 changes: 1 addition & 1 deletion agents/matmaster_agent/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,7 @@ def gen_params_check_completed_agent_instruction():
{{
"flag": <boolean>,
"reason": <string>, // *A concise explanation of the reasoning behind the judgment, covering both positive and negative evidence found in the context messages. Return empty string only if there is absolutely no relevant content to analyze.*
"analyzed_message": List[<string>] // *Quote the key messages that were analyzed to make this determination.*
"analyzed_messages": List[<string>] // *Quote the key messages that were analyzed to make this determination.*
}}

Return `flag: true` ONLY IF ALL of the following conditions are met:
Expand Down
2 changes: 1 addition & 1 deletion agents/matmaster_agent/structure_generate_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
] = 'registry.dp.tech/dptech/dp/native/prod-788025/structure-generate-agent:small'
StructureGenerateBohriumExecutor['machine']['remote_profile'][
'machine_type'
] = 'c8_m31_1 * NVIDIA T4'
] = 'c8_m32_1 * NVIDIA 4090'

sse_params = SseServerParams(url=StructureGenerateServerUrl)

Expand Down
Loading
Loading