In [125]:
from typing import SupportsIndex, Optional, Union

import clingo
import clingo.ast
from clingo.ast import ProgramBuilder

In [126]:
# Sample logic program used by the cells below: three rules (head1, head2, head3)
# whose body predicates (body1..body4) are not defined anywhere in this text.
program = """

head1(X,Y,Z) :- body1(X), body2(Y), body3(Z).

head2(X) :- body1(X), body2(X-1).

head3(X) :- body4(X).

"""

In [127]:
def get_parsed_program(program: str):
    """Parse ``program`` and return the statements reported by the parser.

    Args:
        program: Source text handed to ``clingo.ast.parse_string``.

    Returns:
        A new list holding every statement passed to the parser callback, in
        the order the callback received them.
    """
    statements = []
    # The bound ``append`` method serves directly as the statement callback.
    clingo.ast.parse_string(program, statements.append)
    return statements

In [128]:
# Parse the sample program; the bare `nodes` expression on the last line makes
# the notebook display the raw repr of the parsed statements.
nodes = get_parsed_program(program)
nodes

[ast.Program(Location(begin=Position(filename='<string>', line=1, column=1), end=Position(filename='<string>', line=1, column=1)), 'base', []),
 ast.Rule(Location(begin=Position(filename='<string>', line=3, column=1), end=Position(filename='<string>', line=3, column=46)), ast.Literal(Location(begin=Position(filename='<string>', line=3, column=1), end=Position(filename='<string>', line=3, column=13)), 0, ast.SymbolicAtom(ast.Function(Location(begin=Position(filename='<string>', line=3, column=1), end=Position(filename='<string>', line=3, column=13)), 'head1', [ast.Variable(Location(begin=Position(filename='<string>', line=3, column=7), end=Position(filename='<string>', line=3, column=8)), 'X'), ast.Variable(Location(begin=Position(filename='<string>', line=3, column=9), end=Position(filename='<string>', line=3, column=10)), 'Y'), ast.Variable(Location(begin=Position(filename='<string>', line=3, column=11), end=Position(filename='<string>', line=3, column=12)), 'Z')], 0))), [ast.Literal(

In [129]:
# Placeholder source location for AST nodes that are built by hand rather than
# parsed: filename '<string>', line 0, column 0, with begin and end being the
# same position.
pos = clingo.ast.Position('<string>', 0, 0)
loc = clingo.ast.Location(pos, pos)

In [130]:
def _bound_to_guard(bound, location):
    """Turn a user-supplied aggregate bound into an ``AggregateGuard`` node.

    A ``str`` is read as a variable name and an ``int`` as a number; both are
    compared with ``<=``.  Any other value (``None``, or an already built guard
    node) is returned unchanged.
    """
    if isinstance(bound, str):
        term = clingo.ast.Variable(location, bound)
    elif isinstance(bound, int):
        term = clingo.ast.SymbolicTerm(location, clingo.Number(bound))
    else:
        return bound
    return clingo.ast.AggregateGuard(clingo.ast.ComparisonOperator.LessEqual, term)


def _head_as_aggregate(rule: clingo.ast.AST, left_guard: Union[str, int, None] = None, right_guard: Union[str, int, None] = None) -> clingo.ast.AST:
    """Return a copy of ``rule`` whose head is a ``{ ... }`` aggregate.

    Args:
        rule: Rule node to rewrite.  Its body is reused unchanged.
        left_guard: Lower bound, printed as ``bound <= { ... }``.  A ``str``
            becomes a variable, an ``int`` a number, ``None`` means no guard.
        right_guard: Upper bound, printed as ``{ ... } <= bound``; same
            conversion as ``left_guard``.

    Returns:
        A new ``Rule`` node; ``rule`` itself is not modified.

    Raises:
        ValueError: If the head is neither a literal, an aggregate nor a
            disjunction.

    Note:
        Guards already present on an aggregate head are *not* kept: the new
        head carries exactly the guards passed to this call.
    """
    head = rule.head
    head_type = head.ast_type
    # New nodes inherit the location of the head they are derived from, so the
    # function does not rely on a dummy location defined in another cell and
    # diagnostics keep pointing at the original source text.
    location = head.location

    if head_type is clingo.ast.ASTType.Literal:
        # A plain literal becomes the single, unconditional aggregate element.
        elements = [clingo.ast.ConditionalLiteral(location, head, [])]
    elif head_type in (clingo.ast.ASTType.Aggregate, clingo.ast.ASTType.Disjunction):
        # These heads already hold conditional literals that can be reused.
        elements = head.elements
    else:
        # Other head kinds (e.g. #count-style head aggregates, theory atoms)
        # store different element nodes; reusing them would build a malformed
        # aggregate.
        raise ValueError(f"cannot rewrite a head of type {head_type} as an aggregate")

    head_aggregate = clingo.ast.Aggregate(
        location,
        _bound_to_guard(left_guard, location),
        elements,
        _bound_to_guard(right_guard, location),
    )
    return clingo.ast.Rule(rule.location, head_aggregate, rule.body)

In [131]:
# nodes[0] is the `base` program statement, so nodes[1] is the first rule
# (the head1 rule) of the sample program.
rule = nodes[1]
print(rule)

head1(X,Y,Z) :- body1(X); body2(Y); body3(Z).


In [132]:
# No guards given: the head literal is wrapped in `{ ... }` without bounds.
new_rule_1 = _head_as_aggregate(rule)
print(new_rule_1)

{ head1(X,Y,Z) } :- body1(X); body2(Y); body3(Z).


In [133]:
# Integer passed positionally as left_guard: adds the lower bound `2 <=`.
new_rule_2 = _head_as_aggregate(new_rule_1, 2)
print(new_rule_2)

2 <= { head1(X,Y,Z) } :- body1(X); body2(Y); body3(Z).


In [134]:
# Only right_guard is passed, so the result has the upper bound `<= 4` and the
# `2 <=` lower bound of new_rule_2 is not carried over.
new_rule_3 = _head_as_aggregate(new_rule_2, right_guard=4)
print(new_rule_3)

{ head1(X,Y,Z) } <= 4 :- body1(X); body2(Y); body3(Z).


In [135]:
# Both integer bounds at once: `2 <= { ... } <= 4`.
new_rule_4 = _head_as_aggregate(new_rule_3, left_guard=2, right_guard=4)
print(new_rule_4)

2 <= { head1(X,Y,Z) } <= 4 :- body1(X); body2(Y); body3(Z).


In [136]:
# String bounds are turned into variables: `X <= { ... } <= Y`.
new_rule_5 = _head_as_aggregate(new_rule_4, left_guard='X', right_guard='Y')
print(new_rule_5)

X <= { head1(X,Y,Z) } <= Y :- body1(X); body2(Y); body3(Z).


In [137]:
# Create a fresh Control object and add the five rewritten rules to it through
# a ProgramBuilder (used as a context manager).
ctl = clingo.Control()
with ProgramBuilder(ctl) as pb:
    pb.add(new_rule_1)
    pb.add(new_rule_2)
    pb.add(new_rule_3)
    pb.add(new_rule_4)
    pb.add(new_rule_5)

In [138]:
# Ground the `base` part with an empty parameter tuple. The info messages in the
# output report that body1/body2/body3 do not occur in any rule head.
ctl.ground([('base', ())])

<string>:3:17-25: info: atom does not occur in any rule head:
  body1(X)

<string>:3:27-35: info: atom does not occur in any rule head:
  body2(Y)

<string>:3:37-45: info: atom does not occur in any rule head:
  body3(Z)



In [None]:
def _head_insert_symbol(rule: clingo.ast.AST):
    """Unimplemented stub: the body is empty, so a call returns ``None``.

    TODO: implement; the intended behaviour is not visible in this notebook.
    """
    pass