In [14]:
import re
import math
import sys
sys.path.append("../../src/")
import uncertainpy.gradual as grad
import time
from statistics import median
import random
random.seed(0)

In [15]:
# Parse a QBAF file to a signed adjacency matrix
def parse_qbaf(filename):
    """Parse a QBAF/BAG text file into a signed adjacency matrix.

    Every non-blank line must start with ``kind(a, b)``. Anything after the
    closing parenthesis, such as a trailing ``.``, is ignored.

    * ``arg(id, weight)`` declares an argument. The number of ``arg`` lines
      fixes the matrix size, so argument ids are assumed to be the integers
      ``0 .. n-1``.
    * ``att(src, dst)`` records an attack: ``adj_matrix[src][dst] = -1``.
    * ``sup(src, dst)`` records a support: ``adj_matrix[src][dst] = 1``.

    Any other ``kind`` that matches the pattern is ignored. Later edges
    between the same pair overwrite earlier ones.

    Args:
        filename: path of the file to read.

    Returns:
        A list of ``n`` lists of ``n`` ints, each in ``{-1, 0, 1}``.

    Raises:
        ValueError: if a non-blank line does not match the pattern.
        IndexError: if an edge references an id outside ``0 .. n-1``.
    """
    # Any whitespace is accepted after the comma; previously exactly one space was required.
    line_pattern = re.compile(r'(\w+)\((\d+),\s*([\d.]+)\)')
    edge_sign = {'att': -1, 'sup': 1}

    num_nodes = 0
    edges = []
    # A single pass counts the arguments and remembers the edges, in file order.
    with open(filename, 'r') as file:
        for line in file:
            stripped = line.strip()
            if not stripped:
                continue  # blank lines (e.g. a trailing empty line) are not errors
            match = line_pattern.match(stripped)
            if not match:
                raise ValueError(f"can't parse this line: {line}")
            kind, first, second = match.groups()
            if kind == 'arg':
                num_nodes += 1
            elif kind in edge_sign:
                edges.append((int(first), int(second), edge_sign[kind]))

    adj_matrix = [[0] * num_nodes for _ in range(num_nodes)]
    for src, dst, sign in edges:
        adj_matrix[src][dst] = sign
    return adj_matrix

In [16]:
# Convert the adjacency matrix to a Python dictionary (adjacency list).
def adjacency_matrix_to_dict(adj_matrix):
    """Turn an adjacency matrix into ``{row_index: set_of_column_indices}``.

    Column ``j`` is listed under row ``i`` whenever ``adj_matrix[i][j]`` is
    non-zero, whatever its sign.
    """
    return {
        src: {dst for dst, value in enumerate(row) if value != 0}
        for src, row in enumerate(adj_matrix)
    }

In [17]:
# Depth-first search collecting every simple cycle that returns to `start`.
def find_cycles(graph, start, current, visited, path, cycles):
    """Append to ``cycles`` each simple cycle through ``start`` reachable via ``current``.

    Args:
        graph: mapping ``node -> iterable of successor nodes``.
        start: node that every reported cycle begins and ends at.
        current: node expanded by this call.
        visited: list of booleans indexed by node; marks nodes on the current path.
        path: nodes on the current DFS path (temporarily extended).
        cycles: output list; each cycle is appended as ``[start, ..., start]``.

    ``visited`` and ``path`` are back in their entry state when this returns.
    """
    visited[current] = True
    path.append(current)
    for successor in graph[current]:
        if successor == start:
            # `path` always holds at least `current` here, so reaching
            # `start` again closes a cycle.
            cycles.append(path + [successor])
        elif not visited[successor]:
            find_cycles(graph, start, successor, visited, path, cycles)
    # Backtrack so sibling branches see the caller's state.
    path.pop()
    visited[current] = False

# Search every node for cycles back to itself; the QBAF is cyclic iff any are found.
def find_all_cycles(graph):
    """Return all simple cycles of ``graph``, searching from each node in turn.

    Nodes are assumed to be the integers ``0 .. len(graph) - 1``. A cycle is
    reported once for every node on it, since each search starts it at a
    different node. Each reported cycle begins and ends at its start node.
    """
    found = []
    size = len(graph)
    on_path = [False] * size
    for origin in range(size):
        find_cycles(graph, origin, origin, on_path, [], found)
    return found

# Compute, per start node, whether any of its cycles is negative (cycles are also paths).
def compute_cycle_polarity(cycles, adj_matrix):
    """Map each cycle's start node to ``-1`` or ``1``.

    The sign of a cycle is the product of the ``adj_matrix`` entries along
    its edges. A start node maps to ``-1`` if at least one of its cycles has
    a product of exactly ``-1``, and to ``1`` otherwise. Nodes that start no
    cycle are absent from the result.
    """
    polarity_by_start = {}
    for cycle in cycles:
        head = cycle[0]
        polarity_by_start.setdefault(head, 1)
        if polarity_by_start[head] == -1:
            continue  # already known to lie on a negative cycle
        sign = 1
        for src, dst in zip(cycle, cycle[1:]):
            sign *= adj_matrix[src][dst]
        if sign == -1:
            polarity_by_start[head] = -1
    return polarity_by_start

In [18]:
# Compute all the simple paths between two arguments.
def find_all_paths_between_two_args(graph, start, end, path=None):
    """Return every simple path from ``start`` to ``end`` in ``graph``.

    Args:
        graph: mapping ``node -> iterable of successor nodes``.
        start: first node of every returned path.
        end: last node of every returned path.
        path: nodes already visited on the way to ``start``. The recursion
            uses this; callers normally omit it. It is never mutated.

    Returns:
        A list of paths. Each path is a list ``[*path, start, ..., end]``
        with no repeated node. When ``start == end`` the result is
        ``[path + [start]]``. A ``start`` that is not a key of ``graph``
        yields ``[]``, unless it equals ``end``.
    """
    # `None` replaces a mutable `[]` default, which is a classic shared-state
    # trap even though the list was never mutated here.
    path = ([] if path is None else path) + [start]
    if start == end:
        return [path]
    if start not in graph:
        return []
    paths = []
    for successor in graph[start]:
        if successor not in path:
            paths.extend(find_all_paths_between_two_args(graph, successor, end, path))
    return paths


def find_all_paths_between_two_args_complete(graph, start, end, path=None):
    """Like ``find_all_paths_between_two_args`` but without the trivial path.

    Returns ``[]`` when ``start == end`` instead of ``[[start]]``. Otherwise
    returns every simple path from ``start`` to ``end``.

    ``path`` is unused; it is kept only so existing callers that pass it
    keep working.
    """
    if start == end:
        return []
    return find_all_paths_between_two_args(graph, start, end)

In [19]:
# Compute the (direct or indirect) polarity from one argument to another as an integer code.
def compute_polarity_between_two_args(graph, start_node, end_node, cycles_polarity, adj_matrix):
    """Classify how ``start_node`` influences ``end_node`` along simple paths.

    Return codes:
        ``1``  -- ``start_node == end_node``, or every path has a positive sign;
        ``-1`` -- every path has a negative sign;
        ``0``  -- unknown: the paths disagree in sign, or some node on a path
                  is marked ``-1`` in ``cycles_polarity`` (lies on a negative cycle);
        ``-2`` -- no path exists, i.e. the argument is disconnected from ``end_node``.

    The sign of a path is the product of the ``adj_matrix`` entries along its edges.
    """
    if start_node == end_node:
        return 1

    paths = find_all_paths_between_two_args_complete(graph, start_node, end_node)
    if not paths:
        return -2

    # A negative cycle touching any path makes the overall influence unknown.
    touched = {node for p in paths for node in p}
    if any(node in cycles_polarity and cycles_polarity[node] == -1 for node in touched):
        return 0

    path_signs = []
    for p in paths:
        sign = 1
        for src, dst in zip(p, p[1:]):
            sign *= adj_matrix[src][dst]
        path_signs.append(sign)

    if all(s == 1 for s in path_signs):
        return 1
    if all(s == -1 for s in path_signs):
        return -1
    return 0  # mixed signs: unknown

In [20]:
# Compute the polarity of every argument towards the topic argument, as a list.
def compute_polarity_vector(graph, topic_arg, cycles_polarity, adj_matrix):
    """Return ``[polarity(i -> topic_arg) for i in 0 .. len(graph) - 1]``.

    Entries use the codes returned by ``compute_polarity_between_two_args``.
    """
    return [
        compute_polarity_between_two_args(graph, node, topic_arg, cycles_polarity, adj_matrix)
        for node in range(len(graph))
    ]

In [21]:
# Compute priority (not polarity) between two arguments
def compute_priority_between_two_args(graph, start, end):
    """Return how strongly ``start`` is prioritised when targeting ``end``.

    Returns:
        ``3`` if ``start == end``. ``0`` if ``end`` is unreachable from
        ``start``. Otherwise ``1 / d``, where ``d`` is the number of edges on
        a shortest path from ``start`` to ``end``.

    The shortest length is found by breadth-first search in O(V + E).
    The previous version enumerated every simple path, which is exponential
    in the worst case. A shortest walk never repeats a node, so the result
    equals the shortest *simple* path length used before.
    """
    from collections import deque

    if start == end:
        return 3  # itself
    if start not in graph:
        return 0  # no known successors: disconnected
    distance = {start: 0}
    frontier = deque([start])
    while frontier:
        node = frontier.popleft()
        # Successors that are not keys of `graph` have no outgoing edges.
        for successor in graph.get(node, ()):
            if successor in distance:
                continue
            if successor == end:
                # BFS discovers nodes in order of distance, so this is minimal.
                return 1 / (distance[node] + 1)  # single/multi-path connected
            distance[successor] = distance[node] + 1
            frontier.append(successor)
    return 0  # disconnected

In [22]:
# Compute the priority of every argument towards the topic argument, as a list.
def compute_priority_vector(graph, topic_arg):
    """Return ``[priority(i -> topic_arg) for i in 0 .. len(graph) - 1]``.

    Entries come from ``compute_priority_between_two_args``.
    """
    return [compute_priority_between_two_args(graph, node, topic_arg) for node in range(len(graph))]

In [23]:
# Compute the difference quotient of the topic argument's strength with respect
# to one argument's base score.
def diff_quotient(bag, arg, topic_arg, h, agg_f, inf_f):
    """Forward-difference estimate of d(strength of topic_arg) / d(base score of arg).

    Temporarily raises ``arg``'s initial weight by ``h`` and recomputes the
    strengths with ``grad.algorithms.computeStrengthValues(bag, agg_f, inf_f)``.
    It then restores the weight and recomputes again, even if the first
    computation raised.

    Args:
        bag: object whose ``arguments`` mapping is keyed by ``str(topic_arg)``.
        arg: argument whose initial weight is perturbed.
        topic_arg: id of the argument whose strength is observed.
        h: step size; must be non-zero.
        agg_f, inf_f: forwarded unchanged to ``computeStrengthValues``.

    Returns:
        ``(sigma_new - sigma) / h``.

    Raises:
        ZeroDivisionError: if ``h == 0``. The check runs before ``bag`` is touched.

    NOTE(review): ``sigma`` is read from ``bag`` without recomputing first, so the
    bag's strengths are assumed to be up to date on entry. Confirm at call sites.
    """
    if h == 0:
        # The original would perturb and recompute twice, then fail with this error anyway.
        raise ZeroDivisionError("diff_quotient: step h must be non-zero")

    sigma = bag.arguments[str(topic_arg)].strength

    arg_initial = arg.get_initial_weight()
    arg.reset_initial_weight(arg_initial + h)
    try:
        grad.algorithms.computeStrengthValues(bag, agg_f, inf_f)
        sigma_new = bag.arguments[str(topic_arg)].strength
    finally:
        # Always undo the perturbation so the caller's BAG keeps its base score.
        arg.reset_initial_weight(arg_initial)
        grad.algorithms.computeStrengthValues(bag, agg_f, inf_f)

    return (sigma_new - sigma) / h

In [24]:
# Compute the L1 and L2 distances between two base-score dictionaries.
def compute_bs_dist(bs1, bs2):
    """Return ``(l1_distance, l2_distance)`` between two base-score mappings.

    Keys are taken from ``bs1``, and each must also exist in ``bs2``
    (``KeyError`` otherwise). Keys present only in ``bs2`` are ignored.
    """
    diffs = [bs1[key] - bs2[key] for key in bs1]
    manhattan = sum(abs(d) for d in diffs)
    squared = sum(d ** 2 for d in diffs)
    return manhattan, math.sqrt(squared)

In [25]:
def compute_potential_cause_cyclic(filename, topic_arg, desired_strength, updating_step, Epoch, polarity, priority, model):
    """Iteratively nudge base scores until ``topic_arg`` reaches ``desired_strength``.

    The BAG in ``filename`` is loaded into ``model`` and solved with an RK4
    approximator. In each epoch, every argument's base score moves by
    ``update * priority[i]``, clamped to ``[0, 1]``. The ``update`` is
    ``+updating_step``, ``-updating_step`` or ``0``, chosen by polarity:

    * ``polarity[i] == -2`` (disconnected): 0.
    * ``polarity[i]`` in ``{1, -1}``: the sign that moves the topic strength
      towards ``desired_strength``. It is decided once, from the initial strength.
    * ``polarity[i] == 0`` (unknown): re-probed every epoch. The base score is
      raised by ``updating_step``, the model is re-solved, and the direction
      that brought the topic strength closer to ``desired_strength`` is kept.

    Iteration stops after ``Epoch`` epochs. It also stops when the topic
    strength no longer changes, or when it reaches or crosses ``desired_strength``.

    Args:
        filename: BAG file passed to ``grad.BAG``.
        topic_arg: id of the target argument, looked up as ``str(topic_arg)``.
        desired_strength: target strength for ``topic_arg``.
        updating_step: magnitude of one base-score step.
        Epoch: maximum number of epochs.
        polarity: per-argument polarity codes, indexed by ``int(arg.name)``.
        priority: per-argument step multipliers, indexed the same way.
        model: semantics object; its ``approximator`` and ``BAG`` attributes are replaced.

    Returns:
        ``{argument name: base score}`` after the last update. If no epoch
        made an update, these are the base scores loaded from ``filename``.
    """

    def _solve():
        # Re-solve the model in place.
        # NOTE(review): 10e-2 == 0.1 and 10e-4 == 0.001; 1e-2 / 1e-4 may have been intended.
        model.solve(delta=10e-2, epsilon=10e-4, verbose=False, generate_plot=False)

    def _topic_strength():
        # Current strength of the topic argument as stored in the BAG.
        return model.BAG.arguments[str(topic_arg)].strength

    def _base_scores():
        # Snapshot of every argument's current base score, keyed by name.
        return {arg.get_name(): arg.get_initial_weight() for arg in model.BAG.arguments.values()}

    model.approximator = grad.algorithms.RK4(model)
    model.BAG = grad.BAG(filename)
    _solve()

    pre_strength = -1
    update_list = [0] * len(model.BAG.arguments)

    initial_strength = _topic_strength()
    cur_strength = initial_strength
    print(f"cur_strength:{initial_strength}")

    # Fixed update directions for disconnected / positive / negative arguments.
    # Unknown-polarity arguments (0) match neither inequality and are probed per epoch.
    for arg in model.BAG.arguments.values():
        idx = int(arg.name)
        if polarity[idx] == -2:  # disconnected
            update_list[idx] = 0
        elif (desired_strength - cur_strength) * polarity[idx] > 0:
            update_list[idx] = updating_step
        elif (desired_strength - cur_strength) * polarity[idx] < 0:
            update_list[idx] = -updating_step

    # Initialise up front: if no epoch runs (Epoch == 0, or the topic already
    # has the desired strength), the original raised UnboundLocalError here.
    current_bs_dict = _base_scores()

    for epoch in range(Epoch):
        if pre_strength == cur_strength:
            break  # no progress in the last epoch
        if (desired_strength - cur_strength) * (desired_strength - initial_strength) <= 0:
            break  # reached or crossed the target; later epochs would do nothing
        print(f"Epoch:{epoch}=====================================================")

        # Probe each unknown-polarity argument. The model must be re-solved after
        # the perturbation; without that, the stored strength equals cur_strength
        # and every unknown argument would always get update 0.
        current_gap = abs(desired_strength - cur_strength)
        for arg in model.BAG.arguments.values():
            idx = int(arg.name)
            if polarity[idx] == 0:
                initial = arg.get_initial_weight()
                arg.reset_initial_weight(initial + updating_step)
                _solve()
                probed_gap = abs(desired_strength - _topic_strength())
                if probed_gap < current_gap:
                    update_list[idx] = updating_step
                elif probed_gap > current_gap:
                    update_list[idx] = -updating_step
                else:
                    update_list[idx] = 0
                arg.reset_initial_weight(initial)

        # Apply the updates, scaled by priority and clamped to [0, 1].
        for arg in model.BAG.arguments.values():
            idx = int(arg.get_name())
            arg.reset_initial_weight(max(0, min(1, arg.get_initial_weight() + update_list[idx] * priority[idx])))

        # Re-compute the strength with the updated base scores.
        print(f"desired_strength:{desired_strength}")
        pre_strength = cur_strength
        _solve()
        cur_strength = _topic_strength()
        print(f"cur_strength:{cur_strength}")

        current_bs_dict = _base_scores()
    return current_bs_dict

In [26]:
def main(filename='../../bags/approx_0.bag', topic_arg=9, desired_strength=0.5,
         updating_step=0.01, epochs=100):
    """Search for base scores that give ``topic_arg`` the desired strength, and report the change.

    Every parameter defaults to the value previously hard-coded here, so
    ``main()`` behaves as before. Pass other values to run a different BAG or
    target without editing the function.

    Args:
        filename: BAG/QBAF file to analyse.
        topic_arg: id of the argument whose strength should be changed.
        desired_strength: target strength for ``topic_arg``.
        updating_step: size of each base-score step.
        epochs: maximum number of update epochs for the QBAF.

    Returns:
        dict with ``l1_dist`` and ``l2_dist``, the distances between the
        original and the found base scores, and ``runtime`` in seconds.
    """
    # perf_counter is the clock intended for measuring elapsed time;
    # time.time() can jump if the system clock is adjusted.
    start_time = time.perf_counter()

    # Original base scores, keyed by argument name.
    bag = grad.BAG(filename)
    origin_base_score_dict = {arg.name: arg.get_initial_weight() for arg in bag.arguments.values()}

    # Graph structure: cycles, then polarity and priority of each argument w.r.t. the topic.
    adj_matrix = parse_qbaf(filename)
    graph_dict = adjacency_matrix_to_dict(adj_matrix)
    cycles = find_all_cycles(graph_dict)
    cycles_polarity_dict = compute_cycle_polarity(cycles, adj_matrix)
    polarity = compute_polarity_vector(graph_dict, topic_arg, cycles_polarity_dict, adj_matrix)
    priority = compute_priority_vector(graph_dict, topic_arg)

    # Search for a potential cause: new base scores giving the desired strength.
    model = grad.semantics.QuadraticEnergyModel()
    potential_cause_dict = compute_potential_cause_cyclic(
        filename, topic_arg, desired_strength, updating_step, epochs, polarity, priority, model)

    # How far the base scores had to move.
    l1_dist, l2_dist = compute_bs_dist(origin_base_score_dict, potential_cause_dict)
    print(f"l1_dist:{l1_dist}")
    print(f"l2_dist:{l2_dist}")

    runtime = time.perf_counter() - start_time
    print(f"Runtime: {runtime}")

    print("================================================")
    print("Summary Results:")
    print(f"l1_dist:{l1_dist}")
    print(f"l2_dist:{l2_dist}")
    print(f"runtime:{runtime}")

    return {"l1_dist": l1_dist, "l2_dist": l2_dist, "runtime": runtime}


if __name__ == "__main__":
    main()

cur_strength:0.3329364165950608
desired_strength:0.5
cur_strength:0.3594487548895628
desired_strength:0.5
cur_strength:0.3870045456717425
desired_strength:0.5
cur_strength:0.41425275725019534
desired_strength:0.5
cur_strength:0.4410466132388782
desired_strength:0.5
cur_strength:0.4675605451281033
desired_strength:0.5
cur_strength:0.4937024838001584
desired_strength:0.5
cur_strength:0.5194632455731726
l1_dist:0.5074999999999993
l2_dist:0.23456964140030284
Runtime: 0.20605087280273438
Summary Results:
l1_dist:0.5074999999999993
l2_dist:0.23456964140030284
runtime:0.20605087280273438
cur_strength:0.3329364165950608
desired_strength:0.5
cur_strength:0.3594487548895628
desired_strength:0.5
cur_strength:0.3870045456717425
desired_strength:0.5
cur_strength:0.41425275725019534
desired_strength:0.5
cur_strength:0.4410466132388782
desired_strength:0.5
cur_strength:0.4675605451281033
desired_strength:0.5
cur_strength:0.4937024838001584
desired_strength:0.5
cur_strength:0.5194632455731726
l1_dist: