diff --git a/pdl-live/src/pdl_ast.d.ts b/pdl-live/src/pdl_ast.d.ts index a7af232b4..27f826492 100644 --- a/pdl-live/src/pdl_ast.d.ts +++ b/pdl-live/src/pdl_ast.d.ts @@ -275,7 +275,6 @@ export type Fallback = export type Role = string | null; export type Path = string[]; export type File = string; -export type HasError = boolean; export type Kind = "empty"; /** * Name of the variable used to store the result of the execution of the block. @@ -351,7 +350,6 @@ export type Fallback1 = * */ export type Role1 = string | null; -export type HasError1 = boolean; export type Kind1 = "error"; export type Msg = string; export type Program1 = @@ -478,7 +476,6 @@ export type Fallback2 = * */ export type Role2 = string | null; -export type HasError2 = boolean; export type Kind2 = "include"; /** * Name of the file to include. @@ -584,7 +581,6 @@ export type Fallback3 = * */ export type Role3 = string | null; -export type HasError3 = boolean; export type Kind3 = "read"; /** * Name of the file to read. If `None`, read the standard input. @@ -674,7 +670,6 @@ export type Fallback4 = * Role of associated to the message. */ export type Role4 = string | null; -export type HasError4 = boolean; export type Kind4 = "message"; /** * Content of the message. @@ -803,7 +798,6 @@ export type Fallback5 = * */ export type Role5 = string | null; -export type HasError5 = boolean; export type Kind5 = "object"; export type Object = | { @@ -932,7 +926,6 @@ export type Fallback6 = * */ export type Role6 = string | null; -export type HasError6 = boolean; export type Kind6 = "array"; export type Array = | number @@ -1058,7 +1051,6 @@ export type Fallback7 = * */ export type Role7 = string | null; -export type HasError7 = boolean; export type Kind7 = "sequence"; export type Sequence = | number @@ -1184,7 +1176,6 @@ export type Fallback8 = * */ export type Role8 = string | null; -export type HasError8 = boolean; export type Kind8 = "document"; /** * Body of the document. @@ -1314,7 +1305,6 @@ export type Fallback9 = * */ export type Role9 = string | null; -export type HasError9 = boolean; export type Kind9 = "for"; /** * Body of the loop. @@ -1503,7 +1493,6 @@ export type Fallback10 = * */ export type Role10 = string | null; -export type HasError10 = boolean; export type Kind10 = "repeat_until"; /** * Body of the loop. @@ -1692,7 +1681,6 @@ export type Fallback11 = * */ export type Role11 = string | null; -export type HasError11 = boolean; export type Kind11 = "repeat"; /** * Body of the loop. @@ -1886,7 +1874,6 @@ export type Fallback12 = * */ export type Role12 = string | null; -export type HasError12 = boolean; export type Kind12 = "if"; /** * Branch to exectute if the condition is true. @@ -2072,7 +2059,6 @@ export type Fallback13 = * */ export type Role13 = string | null; -export type HasError13 = boolean; export type Kind13 = "data"; /** * Do not evaluate expressions inside strings. @@ -2152,7 +2138,6 @@ export type Fallback14 = * */ export type Role14 = string | null; -export type HasError14 = boolean; export type Kind14 = "get"; /** * Name of the variable to access. @@ -2232,7 +2217,6 @@ export type Fallback15 = * */ export type Role15 = string | null; -export type HasError15 = boolean; export type Kind15 = "api"; export type Api = string; /** @@ -2367,7 +2351,6 @@ export type Fallback16 = * */ export type Role16 = string | null; -export type HasError16 = boolean; export type Kind16 = "code"; /** * Programming language of the code. @@ -2502,7 +2485,6 @@ export type Fallback17 = * */ export type Role17 = string | null; -export type HasError17 = boolean; export type Kind17 = "model"; export type Model = string; export type Input1 = @@ -2713,7 +2695,6 @@ export type Fallback18 = * */ export type Role18 = string | null; -export type HasError18 = boolean; export type Kind18 = "model"; export type Model1 = string; export type Input2 = @@ -2920,7 +2901,6 @@ export type Fallback19 = * */ export type Role19 = string | null; -export type HasError19 = boolean; export type Kind19 = "call"; /** * Function to call. @@ -3052,7 +3032,6 @@ export type Fallback20 = * */ export type Role20 = string | null; -export type HasError20 = boolean; export type Kind20 = "function"; /** * Functions parameters with their types. @@ -3207,7 +3186,6 @@ export interface FunctionBlock { role?: Role20; result?: unknown; location?: LocationType | null; - has_error?: HasError20; kind?: Kind20; function: Function; return: Return; @@ -3290,7 +3268,6 @@ export interface CallBlock { role?: Role19; result?: unknown; location?: LocationType | null; - has_error?: HasError19; kind?: Kind19; call: Call; args?: Args; @@ -3373,7 +3350,6 @@ export interface LitellmModelBlock { role?: Role18; result?: unknown; location?: LocationType | null; - has_error?: HasError18; kind?: Kind18; model: Model1; input?: Input2; @@ -3455,7 +3431,6 @@ export interface BamModelBlock { role?: Role17; result?: unknown; location?: LocationType | null; - has_error?: HasError17; kind?: Kind17; model: Model; input?: Input1; @@ -3544,7 +3519,6 @@ export interface CodeBlock { role?: Role16; result?: unknown; location?: LocationType | null; - has_error?: HasError16; kind?: Kind16; lan: Lan; code: Code; @@ -3626,7 +3600,6 @@ export interface ApiBlock { role?: Role15; result?: unknown; location?: LocationType | null; - has_error?: HasError15; kind?: Kind15; api: Api; url: Url; @@ -3709,7 +3682,6 @@ export interface GetBlock { role?: Role14; result?: unknown; location?: LocationType | null; - has_error?: HasError14; kind?: Kind14; get: Get; } @@ -3790,7 +3762,6 @@ export interface DataBlock { role?: Role13; result?: unknown; location?: LocationType | null; - has_error?: HasError13; kind?: Kind13; data: Data; raw?: Raw; @@ -3872,7 +3843,6 @@ export interface IfBlock { role?: Role12; result?: unknown; location?: LocationType | null; - has_error?: HasError12; kind?: Kind12; if: If; then: Then; @@ -3956,7 +3926,6 @@ export interface RepeatBlock { role?: Role11; result?: unknown; location?: LocationType | null; - has_error?: HasError11; kind?: Kind11; repeat: Repeat2; num_iterations: NumIterations; @@ -4040,7 +4009,6 @@ export interface RepeatUntilBlock { role?: Role10; result?: unknown; location?: LocationType | null; - has_error?: HasError10; kind?: Kind10; repeat: Repeat1; until: Until; @@ -4124,7 +4092,6 @@ export interface ForBlock { role?: Role9; result?: unknown; location?: LocationType | null; - has_error?: HasError9; kind?: Kind9; for: For; repeat: Repeat; @@ -4208,7 +4175,6 @@ export interface DocumentBlock { role?: Role8; result?: unknown; location?: LocationType | null; - has_error?: HasError8; kind?: Kind8; document: Document; } @@ -4289,7 +4255,6 @@ export interface SequenceBlock { role?: Role7; result?: unknown; location?: LocationType | null; - has_error?: HasError7; kind?: Kind7; sequence: Sequence; } @@ -4370,7 +4335,6 @@ export interface ArrayBlock { role?: Role6; result?: unknown; location?: LocationType | null; - has_error?: HasError6; kind?: Kind6; array: Array; } @@ -4451,7 +4415,6 @@ export interface ObjectBlock { role?: Role5; result?: unknown; location?: LocationType | null; - has_error?: HasError5; kind?: Kind5; object: Object; } @@ -4532,7 +4495,6 @@ export interface MessageBlock { role: Role4; result?: unknown; location?: LocationType | null; - has_error?: HasError4; kind?: Kind4; content: Content; } @@ -4613,7 +4575,6 @@ export interface ReadBlock { role?: Role3; result?: unknown; location?: LocationType | null; - has_error?: HasError3; kind?: Kind3; read: Read; message?: Message; @@ -4696,7 +4657,6 @@ export interface IncludeBlock { role?: Role2; result?: unknown; location?: LocationType | null; - has_error?: HasError2; kind?: Kind2; include: Include; trace?: Trace; @@ -4775,7 +4735,6 @@ export interface ErrorBlock { role?: Role1; result?: unknown; location?: LocationType | null; - has_error?: HasError1; kind?: Kind1; msg: Msg; program: Program1; @@ -4857,7 +4816,6 @@ export interface EmptyBlock { role?: Role; result?: unknown; location?: LocationType | null; - has_error?: HasError; kind?: Kind; } /** diff --git a/pdl-schema.json b/pdl-schema.json index 4c033d6bc..9e7c2a727 100644 --- a/pdl-schema.json +++ b/pdl-schema.json @@ -426,11 +426,6 @@ ], "default": null }, - "has_error": { - "default": false, - "title": "Has Error", - "type": "boolean" - }, "kind": { "const": "api", "default": "api", @@ -1041,11 +1036,6 @@ ], "default": null }, - "has_error": { - "default": false, - "title": "Has Error", - "type": "boolean" - }, "kind": { "const": "array", "default": "array", @@ -1643,11 +1633,6 @@ ], "default": null }, - "has_error": { - "default": false, - "title": "Has Error", - "type": "boolean" - }, "kind": { "const": "model", "default": "model", @@ -2622,11 +2607,6 @@ ], "default": null }, - "has_error": { - "default": false, - "title": "Has Error", - "type": "boolean" - }, "kind": { "const": "call", "default": "call", @@ -3240,11 +3220,6 @@ ], "default": null }, - "has_error": { - "default": false, - "title": "Has Error", - "type": "boolean" - }, "kind": { "const": "code", "default": "code", @@ -3862,11 +3837,6 @@ ], "default": null }, - "has_error": { - "default": false, - "title": "Has Error", - "type": "boolean" - }, "kind": { "const": "data", "default": "data", @@ -4327,11 +4297,6 @@ ], "default": null }, - "has_error": { - "default": false, - "title": "Has Error", - "type": "boolean" - }, "kind": { "const": "document", "default": "document", @@ -4931,11 +4896,6 @@ ], "default": null }, - "has_error": { - "default": false, - "title": "Has Error", - "type": "boolean" - }, "kind": { "const": "empty", "default": "empty", @@ -5374,11 +5334,6 @@ ], "default": null }, - "has_error": { - "default": false, - "title": "Has Error", - "type": "boolean" - }, "kind": { "const": "error", "default": "error", @@ -5982,11 +5937,6 @@ ], "default": null }, - "has_error": { - "default": false, - "title": "Has Error", - "type": "boolean" - }, "kind": { "const": "for", "default": "for", @@ -6764,11 +6714,6 @@ ], "default": null }, - "has_error": { - "default": false, - "title": "Has Error", - "type": "boolean" - }, "kind": { "const": "function", "default": "function", @@ -7393,11 +7338,6 @@ ], "default": null }, - "has_error": { - "default": false, - "title": "Has Error", - "type": "boolean" - }, "kind": { "const": "get", "default": "get", @@ -7845,11 +7785,6 @@ ], "default": null }, - "has_error": { - "default": false, - "title": "Has Error", - "type": "boolean" - }, "kind": { "const": "if", "default": "if", @@ -8627,11 +8562,6 @@ ], "default": null }, - "has_error": { - "default": false, - "title": "Has Error", - "type": "boolean" - }, "kind": { "const": "include", "default": "include", @@ -9202,11 +9132,6 @@ ], "default": null }, - "has_error": { - "default": false, - "title": "Has Error", - "type": "boolean" - }, "kind": { "const": "model", "default": "model", @@ -10345,11 +10270,6 @@ ], "default": null }, - "has_error": { - "default": false, - "title": "Has Error", - "type": "boolean" - }, "kind": { "const": "message", "default": "message", @@ -11221,11 +11141,6 @@ ], "default": null }, - "has_error": { - "default": false, - "title": "Has Error", - "type": "boolean" - }, "kind": { "const": "object", "default": "object", @@ -12356,11 +12271,6 @@ ], "default": null }, - "has_error": { - "default": false, - "title": "Has Error", - "type": "boolean" - }, "kind": { "const": "read", "default": "read", @@ -12884,11 +12794,6 @@ ], "default": null }, - "has_error": { - "default": false, - "title": "Has Error", - "type": "boolean" - }, "kind": { "const": "repeat", "default": "repeat", @@ -13666,11 +13571,6 @@ ], "default": null }, - "has_error": { - "default": false, - "title": "Has Error", - "type": "boolean" - }, "kind": { "const": "repeat_until", "default": "repeat_until", @@ -14447,11 +14347,6 @@ ], "default": null }, - "has_error": { - "default": false, - "title": "Has Error", - "type": "boolean" - }, "kind": { "const": "sequence", "default": "sequence", diff --git a/pdl/pdl_ast.py b/pdl/pdl_ast.py index e998ad9df..6142a2d08 100644 --- a/pdl/pdl_ast.py +++ b/pdl/pdl_ast.py @@ -126,7 +126,6 @@ class Block(BaseModel): # Fields for internal use result: Optional[Any] = None location: Optional[LocationType] = None - has_error: bool = False class FunctionBlock(Block): @@ -525,8 +524,9 @@ class PdlBlocks(RootModel): class PDLException(Exception): - def __init__(self, msg): - self.msg = msg + def __init__(self, message): + super().__init__(message) + self.message = message MAX_NEW_TOKENS = 1024 diff --git a/pdl/pdl_dumper.py b/pdl/pdl_dumper.py index 91495e8ad..b8c50abd9 100644 --- a/pdl/pdl_dumper.py +++ b/pdl/pdl_dumper.py @@ -193,8 +193,6 @@ def block_to_dict(block: pdl_ast.BlockType) -> int | float | str | dict[str, Any d["parser"] = parser_to_dict(block.parser) if block.location is not None: d["location"] = location_to_dict(block.location) - if block.has_error: - d["has_error"] = block.has_error if block.fallback is not None: d["fallback"] = blocks_to_dict(block.fallback) return d diff --git a/pdl/pdl_interpreter.py b/pdl/pdl_interpreter.py index da34ae22a..26a57f7c9 100644 --- a/pdl/pdl_interpreter.py +++ b/pdl/pdl_interpreter.py @@ -5,12 +5,18 @@ # from itertools import batched from pathlib import Path -from typing import Any, Generator, Optional, Sequence +from typing import Any, Generator, Optional, Sequence, TypeVar import litellm import requests import yaml -from jinja2 import Environment, StrictUndefined, Template, UndefinedError +from jinja2 import ( + Environment, + StrictUndefined, + Template, + TemplateSyntaxError, + UndefinedError, +) from jinja2.runtime import Undefined from pydantic import BaseModel @@ -30,7 +36,6 @@ DocumentBlock, EmptyBlock, ErrorBlock, - ExpressionType, ForBlock, FunctionBlock, GetBlock, @@ -58,7 +63,6 @@ SequenceBlock, empty_block_location, ) -from .pdl_ast_utils import iter_block_children, iter_blocks from .pdl_dumper import block_to_dict from .pdl_llms import BamModel, LitellmModel from .pdl_location_utils import append, get_loc_string @@ -76,10 +80,44 @@ from .pdl_utils import messages_concat, messages_to_str, stringify -class PDLRuntimeParserError(PDLException): +class PDLRuntimeError(PDLException): + def __init__( + self, + message: str, + loc: Optional[LocationType] = None, + trace: Optional[BlockType] = None, + fallback: Optional[Any] = None, + ): + super().__init__(message) + self.loc = loc + self.trace = trace + self.fallback = fallback + self.message = message + + +class PDLRuntimeExpressionError(PDLRuntimeError): + pass + + +class PDLRuntimeParserError(PDLRuntimeError): pass +class PDLRuntimeStepBlocksError(PDLException): + def __init__( + self, + message: str, + blocks: list[BlockType], + loc: Optional[LocationType] = None, + fallback: Optional[Any] = None, + ): + super().__init__(message) + self.loc = loc + self.blocks = blocks + self.fallback = fallback + self.message = message + + empty_scope: ScopeType = {"context": []} @@ -140,8 +178,16 @@ def generate( log_fp.write(line) if trace_file: write_trace(trace_file, trace) - except PDLParseError as e: - print("\n".join(e.msg)) + except PDLParseError as exc: + print("\n".join(exc.message), file=sys.stderr) + except PDLRuntimeError as exc: + if exc.loc is None: + message = exc.message + else: + message = get_loc_string(exc.loc) + exc.message + print(message, file=sys.stderr) + if trace_file and exc.trace is not None: + write_trace(trace_file, exc.trace) def write_trace( @@ -174,6 +220,9 @@ def process_prog( Returns: Return the final result, the background messages, the final variable mapping, and the execution trace. + + Raises: + PDLRuntimeError: If the program raises an error. """ scope: ScopeType = empty_scope | initial_scope doc_generator = step_block(state, scope, block=prog.root, loc=loc) @@ -219,14 +268,16 @@ def step_block( background: Messages trace: BlockType if not isinstance(block, Block): - result, errors = process_expr(scope, block, loc) - if len(errors) != 0: - trace = handle_error(loc, None, errors, block) - result = block - background = [{"role": state.role, "content": str(block)}] - else: - background = [{"role": state.role, "content": stringify(result)}] - trace = stringify(result) + try: + result = process_expr(scope, block, loc) + except PDLRuntimeExpressionError as exc: + raise PDLRuntimeError( + exc.message, + loc=exc.loc or loc, + trace=ErrorBlock(msg=exc.message, location=loc, program=block), + ) from exc + background = [{"role": state.role, "content": stringify(result)}] + trace = stringify(result) if state.yield_background: yield YieldBackgroundMessage(background) if state.yield_result: @@ -250,30 +301,56 @@ def step_advanced_block( state = state.with_role(block.role) if len(block.defs) > 0: scope, defs_trace = yield from step_defs(state, scope, block.defs, loc) - else: - defs_trace = block.defs + block = block.model_copy(update={"defs": defs_trace}) state = state.with_yield_result( state.yield_result and ContributeTarget.RESULT in block.contribute ) state = state.with_yield_background( state.yield_background and ContributeTarget.CONTEXT in block.contribute ) - result, background, scope, trace = yield from step_block_body( - state, scope, block, loc - ) - trace = trace.model_copy(update={"defs": defs_trace, "result": result}) + try: + result, background, scope, trace = yield from step_block_body( + state, scope, block, loc + ) + except PDLRuntimeError as exc: + if block.fallback is None: + raise exc from exc + ( + result, + background, + scope, + trace, + ) = yield from step_blocks_of( + block, + "fallback", + IterationType.SEQUENCE, + state, + scope, + loc=loc, + ) + trace = trace.model_copy(update={"result": result}) if block.parser is not None: try: result = parse_result(block.parser, result) - except PDLRuntimeParserError as e: - trace = handle_error(loc, e.msg, [], trace) + except PDLRuntimeParserError as exc: + raise PDLRuntimeError( + exc.message, + loc=exc.loc or loc, + trace=ErrorBlock(msg=exc.message, program=trace, fallback=result), + ) from exc if block.assign is not None: var = block.assign scope = scope | {var: result} if block.spec is not None and not isinstance(block, FunctionBlock): errors = type_check_spec(result, block.spec, block.location) if len(errors) > 0: - trace = handle_error(loc, "Type errors during spec checking", errors, trace) + message = "Type errors during spec checking:\n" + "\n".join(errors) + raise PDLRuntimeError( + message, + loc=loc, + trace=ErrorBlock(msg=message, program=trace), + fallback=result, + ) if ContributeTarget.RESULT not in block.contribute: result = "" if ContributeTarget.CONTEXT not in block.contribute: @@ -306,18 +383,17 @@ def step_block_body( if state.yield_background: yield YieldBackgroundMessage(background) case GetBlock(get=var): - result = get_var(var, scope) - if result is None: - background = [] - trace = handle_error( - append(loc, "get"), - f"Variable is undefined: {var}", - [], - block.model_copy(), - ) - else: - background = [{"role": state.role, "content": stringify(result)}] - trace = block.model_copy() + block.location = append(loc, "get") + try: + result = get_var(var, scope, block.location) + except PDLRuntimeExpressionError as exc: + raise PDLRuntimeError( + exc.message, + loc=exc.loc or loc, + trace=ErrorBlock(msg=exc.message, location=loc, program=block), + ) from exc + background = [{"role": state.role, "content": stringify(result)}] + trace = block.model_copy() if state.yield_result: yield YieldResultMessage(result) if state.yield_background: @@ -326,19 +402,10 @@ def step_block_body( block.location = append(loc, "data") if block.raw: result = v - background = [{"role": state.role, "content": stringify(result)}] trace = block.model_copy() else: - result, errors = process_expr(scope, v, append(loc, "data")) - if len(errors) != 0: - result = None - background = [] - trace = handle_error( - append(loc, "data"), None, errors, block.model_copy() - ) - else: - background = [{"role": state.role, "content": stringify(result)}] - trace = block.model_copy() + result, trace = process_expr_of(block, "data", scope, loc) + background = [{"role": state.role, "content": stringify(result)}] if state.yield_result: yield YieldResultMessage(result) if state.yield_background: @@ -352,112 +419,101 @@ def step_block_body( if state.yield_background: yield YieldBackgroundMessage(background) case DocumentBlock(): - result, background, scope, document = yield from step_blocks( + result, background, scope, trace = yield from step_blocks_of( + block, + "document", IterationType.DOCUMENT, state, scope, - block.document, - append(loc, "document"), + loc, ) - trace = block.model_copy(update={"document": document}) case SequenceBlock(): - result, background, scope, sequence = yield from step_blocks( + result, background, scope, trace = yield from step_blocks_of( + block, + "sequence", IterationType.SEQUENCE, state, scope, - block.sequence, - append(loc, "sequence"), + loc, ) - trace = block.model_copy(update={"sequence": sequence}) case ArrayBlock(): - result, background, scope, array = yield from step_blocks( + result, background, scope, trace = yield from step_blocks_of( + block, + "array", IterationType.ARRAY, state, scope, - block.array, - append(loc, "array"), + loc, ) - trace = block.model_copy(update={"array": array}) case ObjectBlock(): iteration_state = state.with_yield_result(False) if isinstance(block.object, dict): - values, background, scope, values_trace = yield from step_blocks( - IterationType.ARRAY, - iteration_state, - scope, - list(block.object.values()), - append(loc, "object"), - ) + try: + values, background, scope, values_trace = yield from step_blocks( + IterationType.ARRAY, + iteration_state, + scope, + list(block.object.values()), + append(loc, "object"), + ) + except PDLRuntimeStepBlocksError as exc: + obj = dict(zip(block.object.keys(), exc.blocks)) + trace = block.model_copy(update={"object": obj}) + raise PDLRuntimeError( + exc.message, + loc=exc.loc or loc, + trace=trace, + ) from exc assert isinstance(values, list) assert isinstance(values_trace, list) result = dict(zip(block.object.keys(), values)) object_trace = dict(zip(block.object.keys(), values_trace)) + trace = block.model_copy(update={"object": object_trace}) elif isinstance(block.object, list): - ( - results, - background, - scope, - object_trace, - ) = yield from step_blocks( # type: ignore + results, background, scope, trace = yield from step_blocks_of( + block, + "object", IterationType.ARRAY, - iteration_state, + state, scope, - block.object, - append(loc, "object"), + loc, ) result = {} for d in results: result = result | d else: assert False - trace = block.model_copy(update={"object": object_trace}) if state.yield_result and not iteration_state.yield_result: yield YieldResultMessage(result) case MessageBlock(): - content, background, scope, content_trace = yield from step_blocks( + content, background, scope, trace = yield from step_blocks_of( + block, + "content", IterationType.SEQUENCE, state, scope, - block.content, - append(loc, "content"), + loc, ) - result = {"role": state.role, "content": content_trace} - trace = block.model_copy(update={"content": content}) + result = {"role": state.role, "content": content} case IfBlock(): - b, errors = process_condition(scope, block.condition, append(loc, "if")) - if len(errors) != 0: - result = None - background = [] - trace = handle_error( - append(loc, "if"), None, errors, block.model_copy() + b = process_condition_of(block, "condition", scope, loc, "if") + if b: + result, background, scope, trace = yield from step_blocks_of( + block, "then", IterationType.SEQUENCE, state, scope, loc + ) + elif block.elses is not None: + result, background, scope, trace = yield from step_blocks_of( + block, "elses", IterationType.SEQUENCE, state, scope, loc, "else" ) else: - if b: - thenloc = append(loc, "then") - result, background, scope, then_trace = yield from step_blocks( - IterationType.SEQUENCE, state, scope, block.then, thenloc - ) - trace = block.model_copy( - update={ - "if_result": b, - "then": then_trace, - } - ) - elif block.elses is not None: - elseloc = append(loc, "else") - result, background, scope, else_trace = yield from step_blocks( - IterationType.SEQUENCE, state, scope, block.elses, elseloc - ) - trace = block.model_copy( - update={ - "if_result": b, - "elses": else_trace, - } - ) - else: - result = "" - background = [] - trace = block.model_copy(update={"if_result": b}) + result = "" + background = [] + trace = block + trace = trace.model_copy( + update={ + "if_result": b, + } + ) case RepeatBlock(num_iterations=n): results = [] background = [] @@ -466,26 +522,35 @@ def step_block_body( iteration_state = state.with_yield_result( state.yield_result and block.iteration_type == IterationType.DOCUMENT ) - for _ in range(n): - repeatloc = append(loc, "repeat") - scope = scope | {"context": messages_concat(context_init, background)} - ( - iteration_result, - iteration_background, - scope, - body_trace, - ) = yield from step_blocks( - IterationType.SEQUENCE, - iteration_state, - scope, - block.repeat, - repeatloc, - ) - results.append(iteration_result) - background = messages_concat(background, iteration_background) - iterations_trace.append(body_trace) - if contains_error(body_trace): - break + repeat_loc = append(loc, "repeat") + try: + for _ in range(n): + scope = scope | { + "context": messages_concat(context_init, background) + } + ( + iteration_result, + iteration_background, + scope, + body_trace, + ) = yield from step_blocks( + IterationType.SEQUENCE, + iteration_state, + scope, + block.repeat, + repeat_loc, + ) + results.append(iteration_result) + background = messages_concat(background, iteration_background) + iterations_trace.append(body_trace) + except PDLRuntimeStepBlocksError as exc: + iterations_trace.append(exc.blocks) + trace = block.model_copy(update={"trace": iterations_trace}) + raise PDLRuntimeError( + exc.message, + loc=exc.loc or repeat_loc, + trace=trace, + ) from exc result = combine_results(block.iteration_type, results) if state.yield_result and not iteration_state.yield_result: yield YieldResultMessage(result) @@ -495,44 +560,41 @@ def step_block_body( background = [] iter_trace: list[BlocksType] = [] context_init = scope_init["context"] - items: dict[str, Any] = {} + items, block = process_expr_of(block, "fors", scope, loc, "for") lengths = [] - for k, v in block.fors.items(): - klist: list[Any] = [] - kloc = append(append(block.location, "for"), k) - klist, errors = process_expr(scope, v, kloc) - if len(errors) != 0: - trace = handle_error(kloc, None, errors, block.model_copy()) - if not isinstance(klist, list): - trace = handle_error( - kloc, - "Values inside the For block must be lists", - [], - block.model_copy(), + for idx, lst in items.items(): + if not isinstance(lst, list): + msg = "Values inside the For block must be lists." + lst_loc = append( + append(block.location or empty_block_location, "for"), idx ) - klist = [] - items = items | {k: klist} - lengths.append(len(klist)) + raise PDLRuntimeError( + message=msg, + loc=lst_loc, + trace=ErrorBlock(msg=msg, location=lst_loc, program=block), + fallback=[], + ) + lengths.append(len(lst)) if len(set(lengths)) != 1: # Not all the lists are of the same length - result = [] - trace = handle_error( - append(block.location, "for"), - "Lists inside the For block must be of the same length", - [], - block.model_copy(), - ) - else: - iteration_state = state.with_yield_result( - state.yield_result - and block.iteration_type == IterationType.DOCUMENT + msg = "Lists inside the For block must be of the same length." + for_loc = append(block.location or empty_block_location, "for") + raise PDLRuntimeError( + msg, + loc=for_loc, + trace=ErrorBlock(msg=msg, location=for_loc, program=block), + fallback=[], ) + iteration_state = state.with_yield_result( + state.yield_result and block.iteration_type == IterationType.DOCUMENT + ) + repeat_loc = append(loc, "repeat") + try: for i in range(lengths[0]): scope = scope | { "context": messages_concat(context_init, background) } for k in items.keys(): scope = scope | {k: items[k][i]} - newloc = append(loc, "repeat") ( iteration_result, iteration_background, @@ -543,18 +605,24 @@ def step_block_body( iteration_state, scope, block.repeat, - newloc, + repeat_loc, ) background = messages_concat(background, iteration_background) results.append(iteration_result) iter_trace.append(body_trace) - if contains_error(body_trace): - break - result = combine_results(block.iteration_type, results) - if state.yield_result and not iteration_state.yield_result: - yield YieldResultMessage(result) + except PDLRuntimeStepBlocksError as exc: + iter_trace.append(exc.blocks) trace = block.model_copy(update={"trace": iter_trace}) - case RepeatUntilBlock(until=cond): + raise PDLRuntimeError( + exc.message, + loc=exc.loc or repeat_loc, + trace=trace, + ) from exc + result = combine_results(block.iteration_type, results) + if state.yield_result and not iteration_state.yield_result: + yield YieldResultMessage(result) + trace = block.model_copy(update={"trace": iter_trace}) + case RepeatUntilBlock(): results = [] stop = False background = [] @@ -563,33 +631,36 @@ def step_block_body( iteration_state = state.with_yield_result( state.yield_result and block.iteration_type == IterationType.DOCUMENT ) - while not stop: - scope = scope | {"context": messages_concat(context_init, background)} - repeatloc = append(loc, "repeat") - ( - iteration_result, - iteration_background, - scope, - body_trace, - ) = yield from step_blocks( - IterationType.SEQUENCE, - iteration_state, - scope, - block.repeat, - repeatloc, - ) - results.append(iteration_result) - background = messages_concat(background, iteration_background) - iterations_trace.append(body_trace) - if contains_error(body_trace): - break - stop, errors = process_condition(scope, cond, append(loc, "until")) - if len(errors) != 0: - trace = handle_error( - append(loc, "until"), None, errors, block.model_copy() + repeat_loc = append(loc, "repeat") + try: + while not stop: + scope = scope | { + "context": messages_concat(context_init, background) + } + ( + iteration_result, + iteration_background, + scope, + body_trace, + ) = yield from step_blocks( + IterationType.SEQUENCE, + iteration_state, + scope, + block.repeat, + repeat_loc, ) - iterations_trace.append(trace) - break + results.append(iteration_result) + background = messages_concat(background, iteration_background) + iterations_trace.append(body_trace) + stop = process_condition_of(block, "until", scope, loc) + except PDLRuntimeStepBlocksError as exc: + iterations_trace.append(exc.blocks) + trace = block.model_copy(update={"trace": iterations_trace}) + raise PDLRuntimeError( + exc.message, + loc=exc.loc or repeat_loc, + trace=trace, + ) from exc result = combine_results(block.iteration_type, results) if state.yield_result and not iteration_state.yield_result: yield YieldResultMessage(result) @@ -625,24 +696,6 @@ def step_block_body( case _: assert False, f"Internal error: unsupported type ({type(block)})" - if isinstance(trace, ErrorBlock) or children_contain_error(trace): - if block.fallback is None: - trace.has_error = True - else: - ( - result, - fallback_background, - scope, - fallback_trace, - ) = yield from step_blocks( - IterationType.SEQUENCE, - state, - scope, - blocks=block.fallback, - loc=append(loc, "fallback"), - ) - background = messages_concat(background, fallback_background) - trace.fallback = fallback_trace return result, background, scope, trace @@ -666,6 +719,41 @@ def step_defs( return scope, defs_trace +BlockTypeTVarStepBlocksOf = TypeVar( + "BlockTypeTVarStepBlocksOf", bound=AdvancedBlockType +) + + +def step_blocks_of( # pylint: disable=too-many-arguments + block: BlockTypeTVarStepBlocksOf, + field: str, + iteration_type: IterationType, + state: InterpreterState, + scope: ScopeType, + loc: LocationType, + field_alias: Optional[str] = None, +) -> Generator[ + YieldMessage, Any, tuple[Any, Messages, ScopeType, BlockTypeTVarStepBlocksOf] +]: + try: + result, background, scope, blocks = yield from step_blocks( + iteration_type, + state, + scope, + getattr(block, field), + append(loc, field_alias or field), + ) + except PDLRuntimeStepBlocksError as exc: + trace = block.model_copy(update={field: exc.blocks}) + raise PDLRuntimeError( + exc.message, + loc=exc.loc or loc, + trace=trace, + ) from exc + trace = block.model_copy(update={field: blocks}) + return result, background, scope, trace + + def step_blocks( iteration_type: IterationType, state: InterpreterState, @@ -681,20 +769,30 @@ def step_blocks( iteration_state = state.with_yield_result( state.yield_result and iteration_type != IterationType.ARRAY ) + new_loc = None background = [] trace = [] context_init = scope["context"] - for i, block in enumerate(blocks): - scope = scope | {"context": messages_concat(context_init, background)} - newloc = append(loc, "[" + str(i) + "]") - if iteration_type == IterationType.SEQUENCE and state.yield_result: - iteration_state = state.with_yield_result(i + 1 == len(blocks)) - iteration_result, iteration_background, scope, t = yield from step_block( - iteration_state, scope, block, newloc - ) - results.append(iteration_result) - background = messages_concat(background, iteration_background) - trace.append(t) # type: ignore + try: + for i, block in enumerate(blocks): + scope = scope | {"context": messages_concat(context_init, background)} + new_loc = append(loc, "[" + str(i) + "]") + if iteration_type == IterationType.SEQUENCE and state.yield_result: + iteration_state = state.with_yield_result(i + 1 == len(blocks)) + ( + iteration_result, + iteration_background, + scope, + t, + ) = yield from step_block(iteration_state, scope, block, new_loc) + results.append(iteration_result) + background = messages_concat(background, iteration_background) + trace.append(t) # type: ignore + except PDLRuntimeError as exc: + trace.append(exc.trace) # type: ignore + raise PDLRuntimeStepBlocksError( + message=exc.message, blocks=trace, loc=exc.loc or new_loc + ) from exc else: iteration_state = state.with_yield_result( state.yield_result and iteration_type != IterationType.ARRAY @@ -726,9 +824,53 @@ def combine_results(iteration_type: IterationType, results: list[Any]): return result -def process_expr( - scope: ScopeType, expr: Any, loc: LocationType -) -> tuple[Any, list[str]]: +BlockTypeTVarProcessExprOf = TypeVar( + "BlockTypeTVarProcessExprOf", bound=AdvancedBlockType +) + + +def process_expr_of( + block: BlockTypeTVarProcessExprOf, + field: str, + scope: ScopeType, + loc: LocationType, + field_alias: Optional[str] = None, +) -> tuple[Any, BlockTypeTVarProcessExprOf]: + expr = getattr(block, field) + loc = append(loc, field_alias or field) + try: + result = process_expr(scope, expr, loc) + except PDLRuntimeExpressionError as exc: + raise PDLRuntimeError( + exc.message, + loc=exc.loc or loc, + trace=ErrorBlock(msg=exc.message, location=loc, program=block), + ) from exc + trace = block.model_copy(update={field: result}) + return result, trace + + +def process_condition_of( + block: AdvancedBlockType, + field: str, + scope: ScopeType, + loc: LocationType, + field_alias: Optional[str] = None, +) -> bool: + expr = getattr(block, field) + loc = append(loc, field_alias or field) + try: + result = process_expr(scope, expr, loc) + except PDLRuntimeExpressionError as exc: + raise PDLRuntimeError( + exc.message, + loc=exc.loc or loc, + trace=ErrorBlock(msg=exc.message, location=loc, program=block), + ) from exc + return result + + +def process_expr(scope: ScopeType, expr: Any, loc: LocationType) -> Any: if isinstance(expr, str): try: if expr.startswith("{{") and expr.endswith("}}") and "}}" not in expr[:-2]: @@ -753,37 +895,29 @@ def process_expr( undefined=StrictUndefined, ) s = template.render(scope) - except UndefinedError as e: - msg = f"{get_loc_string(loc)}{e}" - return (None, [msg]) - - return (s, []) + except UndefinedError as exc: + raise PDLRuntimeExpressionError( + f"Error during the evaluation of {expr}: {exc}", loc + ) from exc + except TemplateSyntaxError as exc: + raise PDLRuntimeExpressionError( + f"Syntax error in {expr}: {exc}", loc + ) from exc + + return s if isinstance(expr, list): - errors = [] result = [] for index, x in enumerate(expr): - res, errs = process_expr(scope, x, append(loc, "[" + str(index) + "]")) - if len(errs) != 0: - errors += errs + res = process_expr(scope, x, append(loc, "[" + str(index) + "]")) result.append(res) - return (result, errors) # type: ignore + return result if isinstance(expr, dict): - errors = [] result_dict: dict[str, Any] = {} for k, x in expr.items(): - r, errs = process_expr(scope, x, append(loc, k)) - if len(errs) != 0: - errors += errs + r = process_expr(scope, x, append(loc, k)) result_dict[k] = r - return (result_dict, errors) # type: ignore - return (expr, []) - - -def process_condition( - scope: ScopeType, cond: ExpressionType, loc: LocationType -) -> tuple[bool, list[str]]: - b, errors = process_expr(scope, cond, loc) - return b, errors + return result_dict + return expr def step_call_model( @@ -798,20 +932,41 @@ def step_call_model( Any, Messages, ScopeType, - BamModelBlock | LitellmModelBlock | ErrorBlock, + BamModelBlock | LitellmModelBlock, ], ]: # evaluate model name - model, errors = process_expr(scope, block.model, append(loc, "model")) + _, concrete_block = process_expr_of(block, "model", scope, loc) + # evaluate model params + match concrete_block: + case BamModelBlock(): + if isinstance(concrete_block.parameters, BamTextGenerationParameters): + concrete_block = concrete_block.model_copy( + update={"parameters": concrete_block.parameters.model_dump()} + ) + _, concrete_block = process_expr_of( + concrete_block, "parameters", scope, loc + ) + case LitellmModelBlock(): + if isinstance(concrete_block.parameters, LitellmParameters): + concrete_block = concrete_block.model_copy( + update={"parameters": concrete_block.parameters.model_dump()} + ) + _, concrete_block = process_expr_of( + concrete_block, "parameters", scope, loc + ) + case _: + assert False # evaluate input model_input: Messages - if block.input is not None: # If not implicit, then input must be a block - model_input_result, _, _, input_trace = yield from step_blocks( + if concrete_block.input is not None: # If not implicit, then input must be a block + model_input_result, _, _, input_trace = yield from step_blocks_of( + concrete_block, + "input", IterationType.SEQUENCE, state.with_yield_result(False).with_yield_background(False), scope, - block.input, - append(loc, "input"), + loc, ) if isinstance(model_input_result, str): model_input = [{"role": None, "content": model_input_result}] @@ -820,46 +975,11 @@ def step_call_model( else: model_input = scope["context"] input_trace = None - # evaluate model params - match block: - case BamModelBlock(): - if isinstance(block.parameters, BamTextGenerationParameters): - params_expr = block.parameters.model_dump() - else: - params_expr = block.parameters - params, param_errors = process_expr(scope, params_expr, loc) - errors += param_errors - concrete_block = block.model_copy( - update={ - "model": model, - "input": input_trace, - "parameters": params, - } - ) - case LitellmModelBlock(): - if isinstance(block.parameters, LitellmParameters): - params_expr = litellm_parameters_to_dict(block.parameters) - else: - params_expr = block.parameters - params, param_errors = process_expr(scope, params_expr, loc) - errors += param_errors - concrete_block = block.model_copy( - update={ - "model": model, - "input": input_trace, - "parameters": params, - } - ) - case _: - assert False - if len(errors) != 0: - trace = handle_error( - loc, - None, - errors, - block.model_copy(update={"trace": concrete_block}), - ) - return "", [], scope, trace + concrete_block = concrete_block.model_copy( + update={ + "input": input_trace, + } + ) # Execute model call try: litellm_params = {} @@ -881,14 +1001,13 @@ def get_transformed_inputs(kwargs): append_log(state, "Model Output", result) trace = block.model_copy(update={"result": result, "trace": concrete_block}) return result, background, scope, trace - except Exception as e: - trace = handle_error( - loc, - f"Model error: {e}", - [], - block.model_copy(update={"trace": concrete_block}), - ) - return None, [], scope, trace + except Exception as exc: + message = f"Error during model call: {repr(exc)}" + raise PDLRuntimeError( + message, + loc=loc, + trace=ErrorBlock(msg=message, location=loc, program=concrete_block), + ) from exc def generate_client_response( # pylint: disable=too-many-arguments @@ -1027,16 +1146,15 @@ def generate_client_response_batching( # pylint: disable=too-many-arguments def step_call_api( state: InterpreterState, scope: ScopeType, block: ApiBlock, loc: LocationType -) -> Generator[ - YieldMessage, Any, tuple[Any, Messages, ScopeType, ApiBlock | ErrorBlock] -]: +) -> Generator[YieldMessage, Any, tuple[Any, Messages, ScopeType, ApiBlock]]: background: Messages - input_value, _, _, input_trace = yield from step_blocks( + input_value, _, _, block = yield from step_blocks_of( + block, + "input", IterationType.SEQUENCE, state.with_yield_result(False).with_yield_background(False), scope, - block.input, - append(loc, "input"), + loc, ) input_str = block.url + stringify(input_value) try: @@ -1045,60 +1163,50 @@ def step_call_api( result = response.json() background = [{"role": state.role, "content": stringify(result)}] append_log(state, "API Output", background) - trace = block.model_copy(update={"input": input_trace}) - except Exception as e: - trace = handle_error( - loc, - f"API error: {e}", - [], - block.model_copy(update={"input": input_trace}), - ) - result = None - background = [] + trace = block.model_copy(update={"result": result}) + except Exception as exc: + message = f"API error: {repr(exc)}" + raise PDLRuntimeError( + message, + loc=loc, + trace=ErrorBlock(msg=message, program=block), + ) from exc return result, background, scope, trace def step_call_code( state: InterpreterState, scope: ScopeType, block: CodeBlock, loc: LocationType -) -> Generator[ - YieldMessage, Any, tuple[Any, Messages, ScopeType, CodeBlock | ErrorBlock] -]: +) -> Generator[YieldMessage, Any, tuple[Any, Messages, ScopeType, CodeBlock]]: background: Messages - code_s, _, _, code_trace = yield from step_blocks( + code_s, _, _, block = yield from step_blocks_of( + block, + "code", IterationType.SEQUENCE, state.with_yield_result(False).with_yield_background(False), scope, - block.code, - append(loc, "code"), + loc, ) append_log(state, "Code Input", code_s) - try: - match block.lan: - case "python": + match block.lan: + case "python": + try: result = call_python(code_s, scope) background = [{"role": state.role, "content": str(result)}] - case _: - trace = handle_error( - append(loc, "lan"), - f"Unsupported language: {block.lan}", - [], - block.model_copy(), - ) - result = None - background = [] - return result, background, scope, trace - except Exception as e: - trace = handle_error( - loc, - f"Code error: {e}", - [], - block.model_copy(update={"code": code_s}), - ) - result = None - background = [] - + except Exception as exc: + raise PDLRuntimeError( + f"Code error: {repr(exc)}", + loc=loc, + trace=block.model_copy(update={"code": code_s}), + ) from exc + case _: + message = f"Unsupported language: {block.lan}" + raise PDLRuntimeError( + message, + loc=loc, + trace=block.model_copy(), + ) append_log(state, "Code Output", result) - trace = block.model_copy(update={"result": result, "code": code_trace}) + trace = block.model_copy(update={"result": result}) return result, background, scope, trace @@ -1114,66 +1222,62 @@ def call_python(code: str, scope: dict) -> Any: def step_call( state: InterpreterState, scope: ScopeType, block: CallBlock, loc: LocationType -) -> Generator[ - YieldMessage, Any, tuple[Any, Messages, ScopeType, CallBlock | ErrorBlock] -]: +) -> Generator[YieldMessage, Any, tuple[Any, Messages, ScopeType, CallBlock]]: result = None background: Messages = [] - args, errors = process_expr(scope, block.args, append(loc, "args")) - if len(errors) != 0: - trace = handle_error(append(loc, "args"), None, errors, block.model_copy()) - closure_expr, errors = process_expr(scope, block.call, append(loc, "call")) - if len(errors) != 0: - trace = handle_error(append(loc, "call"), None, errors, block.model_copy()) - closure = get_var(closure_expr, scope) - if closure is None: - trace = handle_error( - append(loc, "call"), - f"Function is undefined: {block.call}", - [], - block.model_copy(), + args, block = process_expr_of(block, "args", scope, loc) + closure_expr, block = process_expr_of(block, "call", scope, loc) + try: + closure = get_var(closure_expr, scope, append(loc, "call")) + except PDLRuntimeExpressionError as exc: + raise PDLRuntimeError( + exc.message, + loc=exc.loc or loc, + trace=ErrorBlock(msg=exc.message, location=loc, program=block), + ) from exc + args_loc = append(loc, "args") + type_errors = type_check_args(args, closure.function, args_loc) + if len(type_errors) > 0: + raise PDLRuntimeError( + f"Type errors during function call to {closure_expr}:\n" + + "\n".join(type_errors), + loc=args_loc, + trace=block.model_copy(), ) - else: - argsloc = append(loc, "args") - type_errors = type_check_args(args, closure.function, argsloc) - if len(type_errors) > 0: - trace = handle_error( - argsloc, - f"Type errors during function call to {closure_expr}", - type_errors, - block.model_copy(), - ) - else: - f_body = closure.returns - f_scope = closure.scope | {"context": scope["context"]} | args - funloc = LocationType( - file=closure.location.file, - path=closure.location.path + ["return"], - table=loc.table, - ) - result, background, _, f_trace = yield from step_blocks( - IterationType.SEQUENCE, state, f_scope, f_body, funloc + f_body = closure.returns + f_scope = closure.scope | {"context": scope["context"]} | args + fun_loc = LocationType( + file=closure.location.file, + path=closure.location.path + ["return"], + table=loc.table, + ) + try: + result, background, _, f_trace = yield from step_blocks( + IterationType.SEQUENCE, state, f_scope, f_body, fun_loc + ) + except PDLRuntimeError as exc: + raise PDLRuntimeError( + exc.message, + loc=exc.loc or fun_loc, + trace=block.model_copy(update={"trace": exc.trace}), + ) from exc + trace = block.model_copy(update={"trace": f_trace}) + if closure.spec is not None: + errors = type_check_spec(result, closure.spec, fun_loc) + if len(errors) > 0: + raise PDLRuntimeError( + f"Type errors in result of function call to {closure_expr}:\n" + + "\n".join(errors), + loc=loc, + trace=trace, ) - trace = block.model_copy(update={"trace": f_trace}) - if closure.spec is not None: - errors = type_check_spec(result, closure.spec, funloc) - if len(errors) > 0: - trace = handle_error( - loc, - f"Type errors in result of function call to {closure_expr}", - errors, - trace, - ) return result, background, scope, trace def process_input( state: InterpreterState, scope: ScopeType, block: ReadBlock, loc: LocationType -) -> tuple[str, Messages, ScopeType, ReadBlock | ErrorBlock]: - read, errors = process_expr(scope, block.read, append(loc, "read")) - if len(errors) != 0: - trace = handle_error(loc, None, errors, block.model_copy()) - return "", [], scope, trace +) -> tuple[str, Messages, ScopeType, ReadBlock]: + read, block = process_expr_of(block, "read", scope, loc) if read is not None: file = state.cwd / read with open(file, encoding="utf-8") as f: @@ -1211,9 +1315,7 @@ def step_include( scope: ScopeType, block: IncludeBlock, loc: LocationType, -) -> Generator[ - YieldMessage, Any, tuple[Any, Messages, ScopeType, IncludeBlock | ErrorBlock] -]: +) -> Generator[YieldMessage, Any, tuple[Any, Messages, ScopeType, IncludeBlock]]: file = state.cwd / block.include try: prog, newloc = parse_file(file) @@ -1222,14 +1324,13 @@ def step_include( ) include_trace = block.model_copy(update={"trace": trace}) return result, background, scope, include_trace - except PDLParseError as e: - trace = handle_error( - append(loc, "include"), - f"Attempting to include invalid yaml: {str(file)}", - e.msg, - block.model_copy(), - ) - return None, [], scope, trace + except PDLParseError as exc: + message = f"Attempting to include invalid yaml: {str(file)}\n{exc.message}" + raise PDLRuntimeError( + message, + loc=loc, + trace=ErrorBlock(msg=message, program=block.model_copy()), + ) from exc def parse_result(parser: ParserType, text: str) -> Optional[dict[str, Any] | list[Any]]: @@ -1240,14 +1341,14 @@ def parse_result(parser: ParserType, text: str) -> Optional[dict[str, Any] | lis result = json.loads(text) except Exception as exc: raise PDLRuntimeParserError( - "Attempted to parse ill-formed JSON" + f"Attempted to parse ill-formed JSON: {repr(exc)}" ) from exc case "yaml": try: result = yaml.safe_load(text) except Exception as exc: raise PDLRuntimeParserError( - "Attempted to parse ill-formed YAML" + f"Attempted to parse ill-formed YAML: {repr(exc)}" ) from exc case PdlParser(): assert False, "TODO" @@ -1262,7 +1363,11 @@ def parse_result(parser: ParserType, text: str) -> Optional[dict[str, Any] | lis matcher = re.fullmatch case _: assert False - m = matcher(regex, text, flags=re.M) + try: + m = matcher(regex, text, flags=re.M) + except Exception as exc: + msg = f"Fail to parse with regex {regex}: {repr(exc)}" + raise PDLRuntimeParserError(msg) from exc if m is None: return None if parser.spec is None: @@ -1292,56 +1397,10 @@ def parse_result(parser: ParserType, text: str) -> Optional[dict[str, Any] | lis return result -def get_var(var: str, scope: ScopeType) -> Any: - try: - segs = var.split(".") - res = scope[segs[0]] - - for v in segs[1:]: - res = res[v] - except Exception: - return None - return res +def get_var(var: str, scope: ScopeType, loc: LocationType) -> Any: + return process_expr(scope, "{{ " + var + " }}", loc) def append_log(state: InterpreterState, title, somestring): state.log.append("********** " + title + " **********\n") state.log.append(str(somestring) + "\n") - - -def handle_error( - loc: LocationType, - top_message: Optional[str], - errors: list[str], - subtrace: BlocksType, -) -> ErrorBlock: - msg = "" - if top_message is not None: - msg += f"{get_loc_string(loc)}{top_message}\n" - msg += "\n".join(errors) - print("\n" + msg, file=sys.stderr) - return ErrorBlock(msg=msg, program=subtrace) - - -def _raise_on_error(block: BlockType): - if not isinstance(block, Block) or block.fallback is not None: - return - if isinstance(block, ErrorBlock): - raise StopIteration - iter_block_children(_raise_on_error, block) - - -def children_contain_error(block: AdvancedBlockType) -> bool: - try: - iter_block_children(_raise_on_error, block) - return False - except StopIteration: - return True - - -def contains_error(blocks: BlocksType) -> bool: - try: - iter_blocks(_raise_on_error, blocks) - return False - except StopIteration: - return True diff --git a/tests/data/line/hello26.pdl b/tests/data/line/hello26.pdl index 98bca3efa..4d6e6c24b 100644 --- a/tests/data/line/hello26.pdl +++ b/tests/data/line/hello26.pdl @@ -10,7 +10,7 @@ document: input: document: - for: - question: "{{ questions2 }}" + question: "Hello" answer: "{{ data.answers }}" repeat: document: diff --git a/tests/data/line/hello27.pdl b/tests/data/line/hello27.pdl new file mode 100644 index 000000000..2465e54ae --- /dev/null +++ b/tests/data/line/hello27.pdl @@ -0,0 +1,33 @@ +description: Hello world to call into a model +document: +- read: hello16_data.json + parser: json + def: data + contribute: [] + spec: { "questions": ["str"], "answers": ["obj"] } +- model: watsonx/ibm/granite-34b-code-instruct + def: model_output + input: + document: + - for: + question: "{{ [ ] }}" + answer: "{{ data.answers }}" + repeat: + document: + - "{{ question }}" + - "{{ answer }}" + as: document + - 'Question: Write a JSON object with 2 fields "bob" and "carol" set to "20" and "30" respectively.' + parameters: + decoding_method: greedy + stop: + - '}' + include_stop_sequence: true + mock_response: + | + Here is the code: + ```json + { + "bob": "20", + "carol": "30" + } \ No newline at end of file diff --git a/tests/test_array.py b/tests/test_array.py index d88668022..2e0230a86 100644 --- a/tests/test_array.py +++ b/tests/test_array.py @@ -33,7 +33,6 @@ def test_for_data(): { "def": "I", "document": [{"lan": "python", "code": "result = 0"}], - "contribute": [], }, { "repeat": [ diff --git a/tests/test_errors.py b/tests/test_errors.py index 131679405..c85529c58 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -45,7 +45,7 @@ def test_error1(): "Hello,", { "model": "watsonx/ibm/granite-20b-code-instruct", - "parameters": { + "parameterss": { "decoding_method": "greedy", "stop_sequences": ["!"], "include_stop_sequence": False, @@ -65,57 +65,59 @@ def test_error2(): ) -error3 = { - "description": "Hello world with a variable to call into a model", - "document": [ - "Hello,", - { - "model": "watsonx/ibm/granite-20b-code-instruct", - "parameters": { - "decoding_methods": "greedy", - "stop_sequences": ["!"], - "include_stop_sequence": False, - }, - }, - "!\n", - ], -} - - -def test_error3(): - error( - error3, - [ - ":0 - Field not allowed: decoding_methods", - ], - ) - - -error4 = { - "description": "Hello world with a variable to call into a model", - "document": [ - "Hello,", - { - "model": "watsonx/ibm/granite-20b-code-instruct", - "parameters": { - "decoding_methods": "greedy", - "stop_sequencess": ["!"], - "include_stop_sequence": False, - }, - }, - "!\n", - ], -} - - -def test_error4(): - error( - error4, - [ - ":0 - Field not allowed: decoding_methods", - ":0 - Field not allowed: stop_sequencess", - ], - ) +# error3 = { +# "description": "Hello world with a variable to call into a model", +# "document": [ +# "Hello,", +# { +# "model": "ibm/granite-20b-code-instruct", +# "platform": "bam", +# "parameters": { +# "decoding_methods": "greedy", +# "stop_sequences": ["!"], +# "include_stop_sequence": False, +# }, +# }, +# "!\n", +# ], +# } + + +# def test_error3(): +# error( +# error3, +# [ +# ":0 - Field not allowed: decoding_methods", +# ], +# ) + + +# error4 = { +# "description": "Hello world with a variable to call into a model", +# "document": [ +# "Hello,", +# { +# "model": "ibm/granite-20b-code-instruct", +# "platform": "bam", +# "parameters": { +# "decoding_methods": "greedy", +# "stop_sequencess": ["!"], +# "include_stop_sequence": False, +# }, +# }, +# "!\n", +# ], +# } + + +# def test_error4(): +# error( +# error4, +# [ +# ":0 - Field not allowed: decoding_methods", +# ":0 - Field not allowed: stop_sequencess", +# ], +# ) error5 = { diff --git a/tests/test_expr.py b/tests/test_expr.py index c48486ceb..b2d4c830c 100644 --- a/tests/test_expr.py +++ b/tests/test_expr.py @@ -1,7 +1,9 @@ +import pytest + from pdl.pdl_ast import Program from pdl.pdl_interpreter import ( InterpreterState, - contains_error, + PDLRuntimeError, empty_scope, process_prog, ) @@ -64,9 +66,8 @@ def test_false(): def test_undefined_var(): state = InterpreterState() data = Program.model_validate(undefined_var_data) - document, _, _, trace = process_prog(state, empty_scope, data) - assert contains_error(trace) - assert document == "Hello {{ X }}" + with pytest.raises(PDLRuntimeError): + process_prog(state, empty_scope, data) autoescape_data = {"document": "<|system|>"} diff --git a/tests/test_for.py b/tests/test_for.py index b9ee29b26..448bde26e 100644 --- a/tests/test_for.py +++ b/tests/test_for.py @@ -1,7 +1,9 @@ +import pytest + from pdl.pdl_ast import Program from pdl.pdl_interpreter import ( InterpreterState, - contains_error, + PDLRuntimeError, empty_scope, process_prog, ) @@ -81,8 +83,8 @@ def test_for_data2(): def test_for_data3(): state = InterpreterState() data = Program.model_validate(for_data3) - _, _, _, trace = process_prog(state, empty_scope, data) - assert contains_error(trace) + with pytest.raises(PDLRuntimeError): + process_prog(state, empty_scope, data) for_data4 = { diff --git a/tests/test_hello.py b/tests/test_hello.py index e4d832012..fcc4e756c 100644 --- a/tests/test_hello.py +++ b/tests/test_hello.py @@ -1,7 +1,9 @@ -from pdl.pdl_ast import Program, RepeatBlock +import pytest + +from pdl.pdl_ast import Program from pdl.pdl_interpreter import ( InterpreterState, - contains_error, + PDLRuntimeError, empty_scope, process_prog, ) @@ -118,13 +120,5 @@ def test_repeat_nested3(): def test_repeat_error(): state = InterpreterState() data = Program.model_validate(repeat_data_error) - _, _, _, trace = process_prog(state, empty_scope, data) - errors = 0 - print(trace) - if isinstance(trace, RepeatBlock): - traces = trace.trace or [] - for document in traces: - if contains_error(document): - errors += 1 - - assert errors == 1 + with pytest.raises(PDLRuntimeError): + process_prog(state, empty_scope, data) diff --git a/tests/test_line_table.py b/tests/test_line_table.py index 0006b631a..852133582 100644 --- a/tests/test_line_table.py +++ b/tests/test_line_table.py @@ -43,8 +43,7 @@ def test_line1(capsys): "file": "tests/data/line/hello3.pdl", "errors": [ "", - "Hello, World!", - "tests/data/line/hello3.pdl:6 - Type errors during spec checking", + "tests/data/line/hello3.pdl:6 - Type errors during spec checking:", "tests/data/line/hello3.pdl:6 - World! should be of type ", ], } @@ -100,8 +99,7 @@ def test_line8(capsys): "file": "tests/data/line/hello9.pdl", "errors": [ "", - "hello", - "tests/data/line/hello9.pdl:3 - Type errors during spec checking", + "tests/data/line/hello9.pdl:3 - Type errors during spec checking:", "tests/data/line/hello9.pdl:3 - hello should be of type ", ], } @@ -141,8 +139,7 @@ def test_line11(capsys): "file": "tests/data/line/hello12.pdl", "errors": [ "", - "Hello! How are you?", - "tests/data/line/hello12.pdl:9 - Type errors during spec checking", + "tests/data/line/hello12.pdl:9 - Type errors during spec checking:", "tests/data/line/hello12.pdl:9 - How are you? should be of type ", ], } @@ -156,9 +153,7 @@ def test_line12(capsys): "file": "tests/data/line/hello13.pdl", "errors": [ "", - "0", - "1", - "tests/data/line/hello13.pdl:9 - Type errors during spec checking", + "tests/data/line/hello13.pdl:9 - Type errors during spec checking:", "tests/data/line/hello13.pdl:9 - 1 should be of type ", ], } @@ -172,8 +167,7 @@ def test_line13(capsys): "file": "tests/data/line/hello14.pdl", "errors": [ "", - "Hello, World!Bonjour le monde!", - "tests/data/line/hello14.pdl:24 - Type errors in result of function call to translate", + "tests/data/line/hello14.pdl:24 - Type errors in result of function call to translate:", "tests/data/line/hello14.pdl:16 - Bonjour le monde! should be of type ", ], } @@ -187,10 +181,7 @@ def test_line14(capsys): "file": "tests/data/line/hello15.pdl", "errors": [ "", - "Hello World!", - "tests/data/line/hello15.pdl:6 - Variable is undefined: boolean", - "tests/data/line/hello15.pdl:7 - 'something' is undefined", - "{{ something }}", + "tests/data/line/hello15.pdl:6 - Error during the evaluation of {{ boolean }}: 'boolean' is undefined", ], } @@ -203,8 +194,7 @@ def test_line15(capsys): "file": "tests/data/line/hello16.pdl", "errors": [ "", - '{"bob": 20, "carol": 30}', - "tests/data/line/hello16.pdl:8 - Type errors during spec checking", + "tests/data/line/hello16.pdl:8 - Type errors during spec checking:", "tests/data/line/hello16.pdl:8 - 30 should be of type ", ], } @@ -218,7 +208,7 @@ def test_line16(capsys): "file": "tests/data/line/hello17.pdl", "errors": [ "", - "tests/data/line/hello17.pdl:3 - Type errors during spec checking", + "tests/data/line/hello17.pdl:3 - Type errors during spec checking:", "tests/data/line/hello17.pdl:3 - hello should be of type ", ], } @@ -230,7 +220,10 @@ def test_line17(capsys): line18 = { "file": "tests/data/line/hello18.pdl", - "errors": ["", "0", "1", "tests/data/line/hello18.pdl:13 - 'J' is undefined"], + "errors": [ + "", + "tests/data/line/hello18.pdl:13 - Error during the evaluation of {{ J == 5 }}: 'J' is undefined", + ], } @@ -242,10 +235,9 @@ def test_line18(capsys): "file": "tests/data/line/hello19.pdl", "errors": [ "", - "Hello,", - "tests/data/line/hello19.pdl:6 - 'models' is undefined", - "tests/data/line/hello19.pdl:6 - Type errors during spec checking", - "tests/data/line/hello19.pdl:6 - should be of type ", + "tests/data/line/hello19.pdl:6 - Error during the evaluation of {{ models }}: 'models' is undefined", + # "tests/data/line/hello19.pdl:6 - Type errors during spec checking:", + # "tests/data/line/hello19.pdl:6 - should be of type ", ], } @@ -258,8 +250,7 @@ def test_line19(capsys): "file": "tests/data/line/hello20.pdl", "errors": [ "", - "tests/data/line/hello20.pdl:3 - 'NAME' is undefined", - "Who is{{ NAME }}?", + "tests/data/line/hello20.pdl:3 - Error during the evaluation of Who is{{ NAME }}?: 'NAME' is undefined", ], } @@ -270,7 +261,10 @@ def test_line20(capsys): line21 = { "file": "tests/data/line/hello21.pdl", - "errors": ["", "tests/data/line/hello21.pdl:3 - 'QUESTION' is undefined", "null"], + "errors": [ + "", + "tests/data/line/hello21.pdl:3 - Error during the evaluation of {{ QUESTION }}: 'QUESTION' is undefined", + ], } @@ -282,8 +276,7 @@ def test_line21(capsys): "file": "tests/data/line/hello22.pdl", "errors": [ "", - "tests/data/line/hello22.pdl:4 - 'I' is undefined", - "{{ I }}", + "tests/data/line/hello22.pdl:4 - Error during the evaluation of {{ I }}: 'I' is undefined", ], } @@ -296,8 +289,7 @@ def test_line22(capsys): "file": "tests/data/line/hello23.pdl", "errors": [ "", - "tests/data/line/hello23.pdl:5 - 'I' is undefined", - "{{ I }}", + "tests/data/line/hello23.pdl:5 - Error during the evaluation of {{ I }}: 'I' is undefined", ], } @@ -310,14 +302,7 @@ def test_line23(capsys): "file": "tests/data/line/hello24.pdl", "errors": [ "", - "Hello, World!null", - "tests/data/line/hello24.pdl:24 - 'GEN1' is undefined", - "tests/data/line/hello24.pdl:25 - 'GEN2' is undefined", - "tests/data/line/hello24.pdl:23 - Type errors during function call to translate", - "tests/data/line/hello24.pdl:21 - None should be of type ", - "tests/data/line/hello24.pdl:25 - None should be of type ", - "tests/data/line/hello24.pdl:21 - Type errors during spec checking", - "tests/data/line/hello24.pdl:24 - None should be of type ", + "tests/data/line/hello24.pdl:24 - Error during the evaluation of Hello,{{ GEN1 }}: 'GEN1' is undefined", ], } @@ -345,15 +330,7 @@ def test_line24(capsys): "file": "tests/data/line/hello26.pdl", "errors": [ "", - "tests/data/line/hello26.pdl:13 - 'questions2' is undefined", - "tests/data/line/hello26.pdl:13 - Values inside the For block must be lists", - "tests/data/line/hello26.pdl:12 - Lists inside the For block must be of the same length", - "Here is the code:", - "```json", - "{", - ' "bob": "20",', - ' "carol": "30"', - "}", + "tests/data/line/hello26.pdl:13 - Values inside the For block must be lists.", ], } @@ -362,12 +339,24 @@ def test_line26(capsys): do_test(line26, capsys) +line27 = { + "file": "tests/data/line/hello27.pdl", + "errors": [ + "", + "tests/data/line/hello27.pdl:12 - Lists inside the For block must be of the same length.", + ], +} + + +def test_line27(capsys): + do_test(line27, capsys) + + line28 = { "file": "tests/data/line/hello28.pdl", "errors": [ - "Hello! {{ QUESTION1 }}", - "tests/data/line/hello28.pdl:9 - 'QUESTION1' is undefined", "", + "tests/data/line/hello28.pdl:9 - Error during the evaluation of {{ QUESTION1 }}: 'QUESTION1' is undefined", ], } @@ -379,12 +368,8 @@ def test_line28(capsys): line29 = { "file": "tests/data/line/hello29.pdl", "errors": [ - "Hello! null", - "tests/data/line/hello29.pdl:10 - 'QUESTION1' is undefined", - "tests/data/line/hello29.pdl:11 - 'QUESTION2' is undefined", - "tests/data/line/hello29.pdl:13 - 'QUESTION3' is undefined", - "tests/data/line/hello29.pdl:15 - 'QUESTION4' is undefined", "", + "tests/data/line/hello29.pdl:10 - Error during the evaluation of {{ QUESTION1 }}: 'QUESTION1' is undefined", ], } @@ -397,8 +382,7 @@ def test_line29(capsys): "file": "tests/data/line/hello30.pdl", "errors": [ "", - "tests/data/line/hello30.pdl:7 - Values inside the For block must be lists", - "[]", + "tests/data/line/hello30.pdl:7 - Values inside the For block must be lists.", ], } diff --git a/tests/test_parser.py b/tests/test_parser.py index 46fbbd327..0b8945ac7 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -1,7 +1,9 @@ +import pytest + from pdl.pdl_ast import Program from pdl.pdl_interpreter import ( InterpreterState, - contains_error, + PDLRuntimeError, empty_scope, process_prog, ) @@ -32,8 +34,7 @@ def test_model_parser(): state = InterpreterState() data = Program.model_validate(model_parser) - result, _, _, trace = process_prog(state, empty_scope, data) - assert not contains_error(trace) + result, _, _, _ = process_prog(state, empty_scope, data) assert result == {"bob": 20, "carol": 30} @@ -53,8 +54,8 @@ def test_model_parser(): def test_model_parser1(): state = InterpreterState() data = Program.model_validate(model_parser1) - _, _, _, trace = process_prog(state, empty_scope, data) - assert contains_error(trace) + with pytest.raises(PDLRuntimeError): + process_prog(state, empty_scope, data) get_parser = {"get": "x", "parser": "json", "def": "y", "contribute": []} @@ -64,8 +65,7 @@ def test_get_parser(): state = InterpreterState() data = Program.model_validate(get_parser) scope = {"x": '{"a": "foo", "b": "bar"}'} - result, _, scope, trace = process_prog(state, scope, data) - assert not contains_error(trace) + result, _, scope, _ = process_prog(state, scope, data) assert result == "" assert scope["x"] == '{"a": "foo", "b": "bar"}' assert scope["y"] == {"a": "foo", "b": "bar"} @@ -87,8 +87,7 @@ def test_get_parser(): def test_code_parser(): state = InterpreterState() data = Program.model_validate(code_parser) - result, _, _, trace = process_prog(state, empty_scope, data) - assert not contains_error(trace) + result, _, _, _ = process_prog(state, empty_scope, data) assert result == {"a": "b", "c": "d"} @@ -101,6 +100,5 @@ def test_code_parser(): def test_code_parser1(): state = InterpreterState() data = Program.model_validate(code_parser1) - result, _, _, trace = process_prog(state, empty_scope, data) - assert not contains_error(trace) + result, _, _, _ = process_prog(state, empty_scope, data) assert result == "{'a': 'b', 'c': 'd'}" diff --git a/tests/test_runtime_errors.py b/tests/test_runtime_errors.py new file mode 100644 index 000000000..4989219f9 --- /dev/null +++ b/tests/test_runtime_errors.py @@ -0,0 +1,137 @@ +import pytest + +from pdl.pdl import exec_str +from pdl.pdl_interpreter import PDLRuntimeError + + +def test_jinja_undefined(): + prog_str = """ +"{{ x }}" +""" + with pytest.raises(PDLRuntimeError) as exc: + exec_str(prog_str) + assert ( + str(exc.value.message) + == "Error during the evaluation of {{ x }}: 'x' is undefined" + ) + + +def test_jinja_access(): + prog_str = """ +"{{ {}['x'] }}" +""" + with pytest.raises(PDLRuntimeError) as exc: + exec_str(prog_str) + assert ( + str(exc.value.message) + == "Error during the evaluation of {{ {}['x'] }}: 'dict object' has no attribute 'x'" + ) + + +def test_jinja_syntax(): + prog_str = """ +"{{ {}[ }}" +""" + with pytest.raises(PDLRuntimeError) as exc: + exec_str(prog_str) + assert ( + str(exc.value.message) + == "Syntax error in {{ {}[ }}: unexpected 'end of template'" + ) + + +def test_parser_json(): + prog_str = """ +document: "{ x: 1 + 1 }" +parser: json +""" + with pytest.raises(PDLRuntimeError) as exc: + exec_str(prog_str) + assert ( + str(exc.value.message) + == "Attempted to parse ill-formed JSON: JSONDecodeError('Expecting property name enclosed in double quotes: line 1 column 3 (char 2)')" + ) + + +def test_parser_regex(): + prog_str = """ +document: "Hello" +parser: + regex: "(" +""" + with pytest.raises(PDLRuntimeError) as exc: + exec_str(prog_str) + assert ( + str(exc.value.message) + == "Fail to parse with regex (: error('missing ), unterminated subpattern at position 0')" + ) + + +def test_type_result(): + prog_str = """ +document: "Hello" +spec: int +""" + with pytest.raises(PDLRuntimeError) as exc: + exec_str(prog_str) + assert ( + str(exc.value.message) + == "Type errors during spec checking:\n:0 - Hello should be of type " + ) + + +def test_get(): + prog_str = """ +document: +- "Hello" +- get: x +""" + with pytest.raises(PDLRuntimeError) as exc: + exec_str(prog_str) + assert ( + str(exc.value.message) + == "Error during the evaluation of {{ x }}: 'x' is undefined" + ) + + +def test_call_undefined(): + prog_str = """ +call: "f" +""" + with pytest.raises(PDLRuntimeError) as exc: + exec_str(prog_str) + assert ( + str(exc.value.message) + == "Error during the evaluation of {{ f }}: 'f' is undefined" + ) + + +def test_call_bad_name(): + prog_str = """ +call: "{{ ( f }}" +""" + with pytest.raises(PDLRuntimeError) as exc: + exec_str(prog_str) + assert ( + str(exc.value.message) + == "Syntax error in {{ ( f }}: unexpected end of template, expected ')'." + ) + + +def test_call_bad_args(): + prog_str = """ +defs: + f: + function: + x: int + return: Hello +call: "f" +args: + x: "{{ (x }}" +""" + with pytest.raises(PDLRuntimeError) as exc: + exec_str(prog_str) + assert ( + str(exc.value.message) + == "Syntax error in {{ (x }}: unexpected end of template, expected ')'." + ) diff --git a/tests/test_type_checking.py b/tests/test_type_checking.py index ace3e5769..ddfeb9948 100644 --- a/tests/test_type_checking.py +++ b/tests/test_type_checking.py @@ -1,9 +1,10 @@ +import pytest import yaml from pdl.pdl_ast import Program from pdl.pdl_interpreter import ( InterpreterState, - contains_error, + PDLRuntimeError, empty_scope, process_prog, ) @@ -168,8 +169,7 @@ def test_pdltype_to_jsonschema(): def test_function_call(): state = InterpreterState() data = Program.model_validate(function_call) - document, _, _, trace = process_prog(state, empty_scope, data) - assert not contains_error(trace) + document, _, _, _ = process_prog(state, empty_scope, data) assert document == "Hello Bob!" @@ -190,8 +190,7 @@ def test_function_call(): def test_function_call1(): state = InterpreterState() data = Program.model_validate(function_call1) - document, _, _, trace = process_prog(state, empty_scope, data) - assert not contains_error(trace) + document, _, _, _ = process_prog(state, empty_scope, data) assert document == "Hello Bob!" @@ -212,8 +211,7 @@ def test_function_call1(): def test_function_call2(): state = InterpreterState() data = Program.model_validate(function_call2) - document, _, _, trace = process_prog(state, empty_scope, data) - assert not contains_error(trace) + document, _, _, _ = process_prog(state, empty_scope, data) assert document == "Hello 42!" @@ -234,8 +232,7 @@ def test_function_call2(): def test_function_call3(): state = InterpreterState() data = Program.model_validate(function_call3) - document, _, _, trace = process_prog(state, empty_scope, data) - assert not contains_error(trace) + document, _, _, _ = process_prog(state, empty_scope, data) assert document == 'Hello ["Bob", "Carrol"]!' @@ -256,8 +253,7 @@ def test_function_call3(): def test_function_call4(): state = InterpreterState() data = Program.model_validate(function_call4) - document, _, _, trace = process_prog(state, empty_scope, data) - assert not contains_error(trace) + document, _, _, _ = process_prog(state, empty_scope, data) assert document == 'Hello {"bob": "caroll"}!' @@ -278,8 +274,7 @@ def test_function_call4(): def test_function_call5(): state = InterpreterState() data = Program.model_validate(function_call5) - document, _, _, trace = process_prog(state, empty_scope, data) - assert not contains_error(trace) + document, _, _, _ = process_prog(state, empty_scope, data) assert document == "Hello true!" @@ -300,8 +295,7 @@ def test_function_call5(): def test_function_call6(): state = InterpreterState() data = Program.model_validate(function_call6) - document, _, _, trace = process_prog(state, empty_scope, data) - assert not contains_error(trace) + document, _, _, _ = process_prog(state, empty_scope, data) assert document == "Hello 6.6!" @@ -322,8 +316,7 @@ def test_function_call6(): def test_function_call7(): state = InterpreterState() data = Program.model_validate(function_call7) - document, _, _, trace = process_prog(state, empty_scope, data) - assert not contains_error(trace) + document, _, _, _ = process_prog(state, empty_scope, data) assert document == "Hello 6.6!" @@ -344,8 +337,8 @@ def test_function_call7(): def test_function_call8(): state = InterpreterState() data = Program.model_validate(function_call8) - _, _, _, trace = process_prog(state, empty_scope, data) - assert contains_error(trace) + with pytest.raises(PDLRuntimeError): + process_prog(state, empty_scope, data) function_call9 = { @@ -365,8 +358,7 @@ def test_function_call8(): def test_function_call9(): state = InterpreterState() data = Program.model_validate(function_call9) - document, _, _, trace = process_prog(state, empty_scope, data) - assert not contains_error(trace) + document, _, _, _ = process_prog(state, empty_scope, data) assert document == "Hello 6.6 street!" @@ -387,8 +379,8 @@ def test_function_call9(): def test_function_call10(): state = InterpreterState() data = Program.model_validate(function_call10) - _, _, _, trace = process_prog(state, empty_scope, data) - assert contains_error(trace) + with pytest.raises(PDLRuntimeError): + process_prog(state, empty_scope, data) function_call11 = { @@ -408,8 +400,8 @@ def test_function_call10(): def test_function_call11(): state = InterpreterState() data = Program.model_validate(function_call11) - _, _, _, trace = process_prog(state, empty_scope, data) - assert contains_error(trace) + with pytest.raises(PDLRuntimeError): + process_prog(state, empty_scope, data) function_call12 = { @@ -429,8 +421,8 @@ def test_function_call11(): def test_function_call12(): state = InterpreterState() data = Program.model_validate(function_call12) - _, _, _, trace = process_prog(state, empty_scope, data) - assert contains_error(trace) + with pytest.raises(PDLRuntimeError): + process_prog(state, empty_scope, data) function_call13 = { @@ -450,8 +442,8 @@ def test_function_call12(): def test_function_call13(): state = InterpreterState() data = Program.model_validate(function_call13) - _, _, _, trace = process_prog(state, empty_scope, data) - assert contains_error(trace) + with pytest.raises(PDLRuntimeError): + process_prog(state, empty_scope, data) function_call14 = { @@ -471,8 +463,8 @@ def test_function_call13(): def test_function_call14(): state = InterpreterState() data = Program.model_validate(function_call14) - _, _, _, trace = process_prog(state, empty_scope, data) - assert contains_error(trace) + with pytest.raises(PDLRuntimeError): + process_prog(state, empty_scope, data) function_call15 = { @@ -493,8 +485,7 @@ def test_function_call14(): def test_function_call15(): state = InterpreterState() data = Program.model_validate(function_call15) - document, _, _, trace = process_prog(state, empty_scope, data) - assert not contains_error(trace) + document, _, _, _ = process_prog(state, empty_scope, data) assert document == "Hello 6.6 street!" @@ -516,8 +507,8 @@ def test_function_call15(): def test_function_call16(): state = InterpreterState() data = Program.model_validate(function_call16) - _, _, _, trace = process_prog(state, empty_scope, data) - assert contains_error(trace) + with pytest.raises(PDLRuntimeError): + process_prog(state, empty_scope, data) function_call17 = { @@ -538,8 +529,7 @@ def test_function_call16(): def test_function_call17(): state = InterpreterState() data = Program.model_validate(function_call17) - document, _, _, trace = process_prog(state, empty_scope, data) - assert not contains_error(trace) + document, _, _, _ = process_prog(state, empty_scope, data) assert document == "[1, 2, 3]" @@ -561,8 +551,8 @@ def test_function_call17(): def test_function_call18(): state = InterpreterState() data = Program.model_validate(function_call18) - _, _, _, trace = process_prog(state, empty_scope, data) - assert contains_error(trace) + with pytest.raises(PDLRuntimeError): + process_prog(state, empty_scope, data) hello = { @@ -575,8 +565,7 @@ def test_function_call18(): def test_hello(): state = InterpreterState() data = Program.model_validate(hello) - document, _, _, trace = process_prog(state, empty_scope, data) - assert not contains_error(trace) + document, _, _, _ = process_prog(state, empty_scope, data) assert document == "Hello, world!" @@ -590,8 +579,7 @@ def test_hello(): def test_hello1(): state = InterpreterState() data = Program.model_validate(hello1) - result, _, _, trace = process_prog(state, empty_scope, data) - assert not contains_error(trace) + result, _, _, _ = process_prog(state, empty_scope, data) assert result == {"a": "Hello", "b": "World"} @@ -605,5 +593,5 @@ def test_hello1(): def test_hello2(): state = InterpreterState() data = Program.model_validate(hello2) - _, _, _, trace = process_prog(state, empty_scope, data) - assert contains_error(trace) + with pytest.raises(PDLRuntimeError): + process_prog(state, empty_scope, data) diff --git a/tests/test_var.py b/tests/test_var.py index bb24c1ed8..9f5fb5a3b 100644 --- a/tests/test_var.py +++ b/tests/test_var.py @@ -1,7 +1,9 @@ +import pytest + from pdl.pdl_ast import Program from pdl.pdl_interpreter import ( InterpreterState, - contains_error, + PDLRuntimeError, empty_scope, process_prog, ) @@ -137,8 +139,8 @@ def test_code_var(): def test_missing_var(): state = InterpreterState() data = Program.model_validate(missing_var) - _, _, _, trace = process_prog(state, empty_scope, data) - assert contains_error(trace) + with pytest.raises(PDLRuntimeError): + process_prog(state, empty_scope, data) missing_call = { @@ -150,5 +152,5 @@ def test_missing_var(): def test_missing_call(): state = InterpreterState() data = Program.model_validate(missing_call) - _, _, _, trace = process_prog(state, empty_scope, data) - assert contains_error(trace) + with pytest.raises(PDLRuntimeError): + process_prog(state, empty_scope, data)