In [5]:
from typing import Optional
from dataclasses import dataclass


# Define the Context Attributes Type
# (Variable name, is it deterministic)
ContextAttribute = tuple[str, bool]


# Define the context using NamedTuple
@dataclass
class Context:
    id: int
    attributes: dict[str, ContextAttribute]

    def __init__(self, id: int, attributes: dict[str, str]):
        self.id = id
        self.attributes = {
            name: (value, True) for name, value in attributes.items()
        }

    # Define the __hash__ method
    def __hash__(self) -> int:
        return hash(self.id)


# Define the Varible type
Variable = str

# Define the Assignment type
Assignment = dict[Variable, Context]

# Assignment binding


def bind(assignment: Assignment, variable: Variable, context: Context) -> Assignment:
    assignment = assignment.copy()
    assignment[variable] = context
    return assignment


# Define the ContextSet type
ContextSet = frozenset[Context]

# Define the ContextSet Map type
# (Set name, Set)
ContextSetMap = dict[str, ContextSet]

# Define the ContextPool type
ContextPool = dict[int, Context]


# Define the ContextPool initialization function
def make_context_pool(contexts: list[Context]) -> ContextPool:
    return {context.id: context for context in contexts}


# Define the ContextPool to ContextSetMap function
def make_context_set_map(pool: ContextPool, name_id: dict[str, frozenset[int]]) -> ContextSetMap:
    return {
        name: frozenset(pool[id] for id in ids)
        for name, ids in name_id.items()
    }

In [9]:
# Define the base class for formula nodes
class FormulaNode:
    def evaluate(self, assignment: Assignment, context_set_map: ContextSetMap) -> bool:
        raise NotImplementedError(
            "Evaluate method should be implemented by subclasses")

    # False to True repair suite under given assignment
    def repair_f2t(self, assignment: Assignment, context_set_map: ContextSetMap, lk: bool = False) -> "RepairSuite":
        raise NotImplementedError(
            "Repair_f2t method should be implemented by subclasses")

    # True to False repair suite under given assignment
    def repair_t2f(self, assignment: Assignment, context_set_map: ContextSetMap, lk: bool = False) -> "RepairSuite":
        raise NotImplementedError(
            "Repair_t2f method should be implemented by subclasses")


# Define NotNode class
class NotNode(FormulaNode):
    def __init__(self, node: FormulaNode):
        self.node = node

    def evaluate(self, assignment: Assignment, context_set_map: ContextSetMap) -> bool:
        return not self.node.evaluate(assignment, context_set_map)

    def repair_f2t(self, assignment: Assignment, context_set_map: ContextSetMap, lk: bool = False) -> "RepairSuite":
        return self.node.repair_t2f(assignment, context_set_map)

    def repair_t2f(self, assignment: Assignment, context_set_map: ContextSetMap, lk: bool = False) -> "RepairSuite":
        return self.node.repair_f2t(assignment, context_set_map)


# Define AndNode class, which may have multiple children
class AndNode(FormulaNode):
    def __init__(self, *nodes: FormulaNode):
        if len(nodes) != 2:
            raise ValueError("AndNode should have 2 children")
        self.nodes = nodes

    def evaluate(self, assignment: Assignment, context_set_map: ContextSetMap) -> bool:
        return all(node.evaluate(assignment, context_set_map) for node in self.nodes)

    def repair_f2t(self, assignment: Assignment, context_set_map: ContextSetMap, lk: bool = False) -> "RepairSuite":
        if not lk:
            match (self.nodes[0].evaluate(assignment), self.nodes[1].evaluate(assignment)):
                case (False, False):
                    return self.nodes[0].repair_f2t(assignment, context_set_map) & self.nodes[1].repair_f2t(assignment, context_set_map)
                case (False, True):
                    return self.nodes[0].repair_f2t(assignment, context_set_map)
                case (True, False):
                    return self.nodes[1].repair_f2t(assignment, context_set_map)
                case (True, True):
                    raise ValueError(
                        "AndNode should not be True under given assignment")
        # Todo: Implement the LK repair
        return RepairSuite({RepairCase(frozenset(), 0)})

    def repair_t2f(self, assignment: Assignment, context_set_map: ContextSetMap, lk: bool = False) -> "RepairSuite":
        return self.nodes[0].repair_t2f(assignment, context_set_map, lk) | self.nodes[1].repair_t2f(assignment, context_set_map, lk)


# Define OrNode class
class OrNode(FormulaNode):
    def __init__(self, *nodes: FormulaNode):
        if len(nodes) < 2:
            raise ValueError("OrNode should have at least 2 children")
        self.nodes = nodes

    def evaluate(self, assignment: Assignment, context_set_map: ContextSetMap) -> bool:
        return any(node.evaluate(assignment, context_set_map) for node in self.nodes)


# Define ImpliesNode class
@dataclass
class ImpliesNode(FormulaNode):
    left: FormulaNode
    right: FormulaNode

    def evaluate(self, assignment: Assignment, context_set_map: ContextSetMap) -> bool:
        return not self.left.evaluate(assignment, context_set_map) or self.right.evaluate(assignment, context_set_map)


# Define EqualsNode class
@dataclass
class EqualsNode(FormulaNode):
    var1: Variable
    attr1: str
    var2: Variable
    attr2: str
    weight: float = 1.0

    def evaluate(self, assignment: Assignment, context_set_map: ContextSetMap) -> bool:
        if self.var1 not in assignment or self.var2 not in assignment:
            raise ValueError("Variables not found in assignment")
        if assignment[self.var1].attributes[self.attr1][1] == False or assignment[self.var2].attributes[self.attr2][1] == False:
            return False
        return assignment[self.var1].attributes[self.attr1][0] == assignment[self.var2].attributes[self.attr2][0]

    def repair_f2t(self, assignment: Assignment, context_set_map: ContextSetMap, lk: bool = False) -> "RepairSuite":
        repair_case = RepairCase(frozenset({EqualAttributeAction(
            assignment[self.var1], self.attr1, assignment[self.var2], self.attr2)}), self.weight)
        return RepairSuite({repair_case})

    def repair_t2f(self, assignment: Assignment, context_set_map: ContextSetMap, lk: bool = False) -> "RepairSuite":
        repair_case = RepairCase(frozenset({UnequalAttributeAction(
            assignment[self.var1], self.attr1, assignment[self.var2], self.attr2)}), self.weight)
        return RepairSuite({repair_case})


# Define ForAllNode class
@dataclass
class ForAllNode(FormulaNode):
    variable: Variable
    node: FormulaNode
    context_set: str
    weight: float = 1.0

    def evaluate(self, assignment: Assignment, context_set_map: ContextSetMap) -> bool:
        for context in context_set_map[self.context_set]:
            if not self.node.evaluate(bind(assignment, self.variable, context), context_set_map):
                return False
        return True

    def repair_f2t(self, assignment: Assignment, context_set_map: ContextSetMap, lk: bool = False) -> "RepairSuite":
        def del_case(context: Context) -> RepairCase:
            return RepairCase(frozenset({RemoveContextAction(context, self.context_set)}), self.weight)

        def bind_assignment(context: Context) -> Assignment:
            return bind(assignment, self.variable, context)

        if lk:
            suites = [RepairSuite({del_case(context)})
                      for context in context_set_map[self.context_set] if not self.node.evaluate(bind_assignment(context), context_set_map)]
        else:
            suites = [RepairSuite({del_case(context)}) | self.node.repair_f2t(bind_assignment(context), context_set_map)
                      for context in context_set_map[self.context_set] if not self.node.evaluate(bind_assignment(context), context_set_map)]
        return RepairSuite.and_all(*suites)

    def repair_t2f(self, assignment: Assignment, context_set_map: ContextSetMap, lk: bool = False) -> "RepairSuite":
        RepairSuite({RepairCase(frozenset(), 0)})

# Define ExistsNode class
@dataclass
class ExistsNode(FormulaNode):
    variable: Variable
    node: FormulaNode
    context_set: str

    def evaluate(self, assignment: Assignment, context_set_map: ContextSetMap) -> bool:
        for context in self.context_set:
            if self.node.evaluate(bind(assignment, self.variable, context), context_set_map):
                return True
        return False


# Abstract Repair Action class
class RepairAction:
    def apply(self, set_map: ContextSetMap) -> ContextSetMap:
        raise NotImplementedError("Apply method should be implemented by subclasses")

    def __str__(self) -> str:
        raise NotImplementedError("__str__ method should be implemented by subclasses")


# Define the AddAttributeAction class
@dataclass(unsafe_hash=True)
class AddContextAction(RepairAction):
    context: Context
    context_set: ContextSet

    def apply(self, set_map: ContextSetMap) -> ContextSetMap:
        set_map = set_map.copy()
        set_map[self.context_set] = set_map[self.context_set].union({self.context})
        return set_map

    def __str__(self) -> str:
        return f"{self.context.id}: +{self.attribute}"


# Define the RemoveAttributeAction class
@dataclass(unsafe_hash=True)
class RemoveContextAction(RepairAction):
    context: Context
    context_set: ContextSet

    def apply(self, set_map: ContextSetMap) -> ContextSetMap:
        set_map = set_map.copy()
        set_map[self.context_set] = set_map[self.context_set].difference({self.context})
        return set_map

    def __str__(self) -> str:
        return f"{self.context.id}: -{self.attribute}"


# Define the UnequalAttributeAction class
@dataclass(unsafe_hash=True)
class UnequalAttributeAction(RepairAction):
    context1: Context
    attribute1: str
    context2: Context
    attribute2: str

    def apply(self, set_map: ContextSetMap) -> ContextSetMap:
        pass

    def __str__(self) -> str:
        return f"{self.context1.id}.{self.attribute1} != {self.context2.id}.{self.attribute2}"

    def eq_action(self) -> list["EqualAttributeAction"]:
        return [EqualAttributeAction(self.context1, self.attribute1, self.context2, self.attribute2), EqualAttributeAction(self.context2, self.attribute2, self.context1, self.attribute1)]

    def is_contradictory(self, disjoint_set: "DisjointSet[Attribute]") -> bool:
        return any(self.eq_action(), lambda action: action.is_contradictory(disjoint_set))


# Define the type for context attribute value pairs
Attribute = tuple[Context, str]


# Define the EqualAttributeAction class
@dataclass(unsafe_hash=True)
class EqualAttributeAction(RepairAction):
    context1: Context
    attribute1: str
    context2: Context
    attribute2: str

    def apply(self, set_map: ContextSetMap) -> ContextSetMap:
        pass

    def __str__(self) -> str:
        return f"{self.context1.id}.{self.attribute1} == {self.context2.id}.{self.attribute2}"
    
    def is_contradictory(self, disjoint_set: "DisjointSet[Attribute]") -> bool:
        return disjoint_set.find((self.context1, self.attribute1)) == disjoint_set.find((self.context2, self.attribute2))


# Define the Union-Find data structure for EqualAttributeAction
class DisjointSet[Element]:
    def __init__(self):
        self.parent = {}
        self.rank = {}

    def find(self, element: Element) -> Element:
        if element not in self.parent:
            self.parent[element] = element
            self.rank[element] = 0
        if self.parent[element] != element:
            self.parent[element] = self.find(self.parent[element])
        return self.parent[element]

    def union(self, element1: Element, element2: Element) -> None:
        root1 = self.find(element1)
        root2 = self.find(element2)
        if root1 == root2:
            return
        if self.rank[root1] < self.rank[root2]:
            root1, root2 = root2, root1
        self.parent[root2] = root1
        if self.rank[root1] == self.rank[root2]:
            self.rank[root1] += 1


# Define the Repair Case class
# A repair case is a tuple of a repair action set and a weight
@dataclass
class RepairCase:
    action: frozenset[RepairAction]
    weight: float
    equal_disjoint_set: Optional[DisjointSet[Attribute]] = None

    # Define the __str__ method
    def __str__(self) -> str:
        actions = "\n".join(str(action) for action in self.action)
        return f"Repair Case (Weight: {self.weight}):\n{actions}\n"

    # Define the __and__ method
    def __and__(self, other: "RepairCase") -> "RepairCase":
        return RepairCase(self.action.union(other.action), self.weight + other.weight)

    # 'And' of multiple repair cases
    @staticmethod
    def and_all(*cases: "RepairCase") -> "RepairCase":
        return RepairCase(frozenset.union(*(case.action for case in cases)), sum(case.weight for case in cases))

    # Define the __hash__ method
    def __hash__(self) -> int:
        return hash((self.action, self.weight))

    # Get the disjoint set of equal attribute actions
    def get_equal_disjoint_set(self) -> DisjointSet[Attribute]:
        disjoint_set = DisjointSet()
        for action in self.action:
            if isinstance(action, EqualAttributeAction):
                disjoint_set.union(
                    (action.context1, action.attribute1), (action.context2, action.attribute2))
        self.equal_disjoint_set = disjoint_set
        return disjoint_set

    # Check if the repair case is contradictory
    def is_contradictory(self) -> bool:
        return any(action.is_contradictory(self.get_equal_disjoint_set()) for action in self.action if isinstance(action, UnequalAttributeAction))


# Define the Repair Suite class
# A repair suite is a set of repair cases
class RepairSuite(frozenset[RepairCase]):
    # Define the __str__ method
    def __str__(self) -> str:
        cases = "\n".join(str(repair_case) for repair_case in self)
        return f"Repair Suite:\n{cases}\n"

    # Define the __or__ method
    def __or__(self, other: "RepairSuite") -> "RepairSuite":
        return RepairSuite(self.union(other))

    # 'Or' of multiple repair suites
    @staticmethod
    def or_all(*suites: "RepairSuite") -> "RepairSuite":
        return RepairSuite(frozenset.union(*(suite for suite in suites)))

    # Define the __and__ method
    def __and__(self, other: "RepairSuite") -> "RepairSuite":
        if self.isdisjoint(other):
            return RepairSuite(self | other)
        return RepairSuite((repair_case1 & repair_case2 for repair_case1 in self for repair_case2 in other))

    # 'And' of multiple repair suites
    @staticmethod
    def and_all(*suites: "RepairSuite") -> "RepairSuite":
        return RepairSuite((repair_case1 & repair_case2 for repair_case1 in suites for repair_case2 in suites))

In [13]:
def example_usage():
    contexts = [
        Context(1, {"Value": "A", "Type": "str"}),
        Context(2, {"Value": "B", "Type": "str"}),
        Context(3, {"Value": "A", "Type": "str"}),
        Context(4, {"Value": "B", "Type": "str"}),
        Context(5, {"Value": 1, "Type": "int"}),
        Context(6, {"Value": 1, "Type": "int"}),
        Context(7, {"Value": 2, "Type": "int"}),
        Context(8, {"Value": 1, "Type": "int"}),
    ]
    name_id = {
        "Set_A": frozenset({1, 2, 3, 4}),
        "Set_B": frozenset({5, 6, 7, 8}),
        "Set_C": frozenset({1, 5}),
        "Set_D": frozenset({1, 3}),
    }

    context_pool = make_context_pool(contexts)
    context_set_map = make_context_set_map(context_pool, name_id)

    print(context_set_map)

    # forall x in Set_A, not forall y in Set_B, x.Value == y.Value

    equals_node = EqualsNode("x", "Value", "y", "Value")
    forall_y_node = ForAllNode("y", NotNode(equals_node), "Set_B")
    forall_x_node = ForAllNode("x", forall_y_node, "Set_A")

    print(forall_x_node.evaluate({}, context_set_map))

    repair = forall_x_node.repair_f2t({}, context_set_map)
    print(repair)

In [14]:
example_usage()

{'Set_A': frozenset({Context(id=1, attributes={'Value': ('A', True), 'Type': ('str', True)}), Context(id=2, attributes={'Value': ('B', True), 'Type': ('str', True)}), Context(id=3, attributes={'Value': ('A', True), 'Type': ('str', True)}), Context(id=4, attributes={'Value': ('B', True), 'Type': ('str', True)})}), 'Set_B': frozenset({Context(id=8, attributes={'Value': (1, True), 'Type': ('int', True)}), Context(id=5, attributes={'Value': (1, True), 'Type': ('int', True)}), Context(id=6, attributes={'Value': (1, True), 'Type': ('int', True)}), Context(id=7, attributes={'Value': (2, True), 'Type': ('int', True)})}), 'Set_C': frozenset({Context(id=1, attributes={'Value': ('A', True), 'Type': ('str', True)}), Context(id=5, attributes={'Value': (1, True), 'Type': ('int', True)})}), 'Set_D': frozenset({Context(id=1, attributes={'Value': ('A', True), 'Type': ('str', True)}), Context(id=3, attributes={'Value': ('A', True), 'Type': ('str', True)})})}
True
Repair Suite:




In [20]:
# Create and apply repair actions
context_to_repair = Context(3, {"attribute1": "value3"})
context_set = {context_to_repair}
add_action = AddContextAction(context_to_repair, context_set)
remove_action1 = RemoveContextAction(context_to_repair, context_set)
remove_action2 = RemoveContextAction(context_to_repair, context_set)

# Apply repair actions
add_action.apply()

# Create repair cases and suite
repair_case1 = RepairCase(frozenset([add_action]), 1.0)
repair_case2 = RepairCase(frozenset([add_action, remove_action2]), 1.0)
repair_case3 = repair_case1 & repair_case2

print(repair_case1)
print(repair_case2)
print(repair_case3)

# Create repair suites
repair_suite1 = RepairSuite([repair_case1, repair_case2])
repair_suite2 = RepairSuite([repair_case2])

# Combine repair suites
repair_suite3 = repair_suite1 | repair_suite2
repair_suite4 = repair_suite1 & repair_suite2

print(repair_suite3)
print(repair_suite4)

TypeError: unhashable type: 'set'