#### <u>Another approach to finding bi-factual causes</u>

In this notebook, we implement a novel algorithm for identifying bi-factual causes, in the sense defined by Tim Miller in his article 'Contrastive Explanation: A Structural Model Approach'. The algorithm builds on the counterfactual explanations proposed by Ryma Boumazouza in her thesis.<br><br>

We work within a specific setting: the DAC software project "CausaLytics", carried out by Faten Racha Said, Ahmed Abdelaziz Mokeddem and Yacine B.D.Chettab.

In [2]:
import sympy as sp
from util import *  # provides xPred, used below to compute the counterfactual explanations

##### <u>1. The skeleton of the algorithm</u>:
In this section, we build the classifier and implement the helper functions on which the algorithm relies.

In [3]:
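# Building the classifier:

# Propositional variables: L, B, P = module passed (mark >= 12) in Logique, BDD, PS;
# M1 = semester average >= 12, M2 = semester average >= 14 (see numeric_to_symbolic).
L, B, P, M1, M2 = sp.symbols('L B P M1 M2')

# Decision rules: `admit` requires M2, `waiti` (presumably the waiting list) only M1.
admit = sp.And(L, B, P, M2)
waiti = sp.And(L, B, P, M1)

# Symbols map: each variable gets a positive index; a negated literal is encoded as -index.
mapping = {M1: 1, M2: 2, L: 3, B: 4, P: 5}

# Classes: index -> feature class (M1 and M2 both belong to the average class 'M').
classes: dict[int, int] = {1: 1, 2: 1, 3: 2, 4: 3, 5: 4}
inv_class: dict[int, str] = {1: 'M', 2: 'L', 3: 'B', 4: 'P'}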
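To make the two rules concrete, here is a minimal sympy-free sketch, assuming a student is described by a plain truth assignment of the five variables. `evaluate_rules` is a hypothetical helper, not part of the notebook or of `util`.

```python
def evaluate_rules(assignment: dict[str, bool]) -> dict[str, bool]:
    """Hypothetical helper: evaluates `admit` and `waiti` on a truth assignment."""
    # Both rules share the module conditions L, B and P.
    modules_ok = assignment["L"] and assignment["B"] and assignment["P"]
    return {
        "admit": modules_ok and assignment["M2"],  # admit = L & B & P & M2
        "waiti": modules_ok and assignment["M1"],  # waiti = L & B & P & M1
    }


strong = {"M1": True, "M2": True, "L": True, "B": True, "P": True}
medium = {"M1": True, "M2": False, "L": True, "B": True, "P": True}
failed_bdd = {"M1": True, "M2": True, "L": True, "B": False, "P": True}

print(evaluate_rules(strong))      # {'admit': True, 'waiti': True}
print(evaluate_rules(medium))      # {'admit': False, 'waiti': True}
print(evaluate_rules(failed_bdd))  # {'admit': False, 'waiti': False}
```

Since `numeric_to_symbolic` only sets M2 together with M1, every student who satisfies `admit` also satisfies `waiti`.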

In [4]:
# Building the student's formula:

def numeric_to_symbolic(student: list[int]) -> sp.logic.boolalg.Boolean:
    """
    Converts the student's numerical marks into a symbolic logic formula.

    Parameters:
    ----------
    student: list[int]
        A list of integers representing the student's marks in semesters S1 to S5, followed by
        the marks in the modules "Logique", "BDD", and "PS".

    Precondition:
        len(student) == 8
        The order of the marks is as follows: S1, S2, S3, S4, S5, Logique, BDD, PS.

    Returns:
    -------
    sympy.logic.boolalg.Boolean
        A conjunction of literals encoding the student's marks.
    """

    # A module is passed (positive literal) when its mark is at least 12.
    l = L if student[5] >= 12 else ~L
    b = B if student[6] >= 12 else ~B
    p = P if student[7] >= 12 else ~P

    res = sp.And(l, b, p)

    # Average over the five semesters: M1 <=> average >= 12, M2 <=> average >= 14.
    moy = sum(student[:5]) / 5
    if moy >= 14:
        res = sp.And(M1, M2, res)
    elif moy >= 12:
        res = sp.And(M1, ~M2, res)
    else:
        res = sp.And(~M1, ~M2, res)

    return res
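The encoding can be checked by hand without sympy. The sketch below is a hypothetical helper (not part of the notebook) that directly produces the signed-integer events that `extract_events` would return for the formula built by `numeric_to_symbolic`, using the same indices as `mapping`.

```python
# Literal indices copied from the notebook's `mapping` (M1, M2, L, B, P).
MAPPING = {"M1": 1, "M2": 2, "L": 3, "B": 4, "P": 5}


def marks_to_literals(student: list[int]) -> set[int]:
    """Hypothetical sympy-free variant of numeric_to_symbolic + extract_events."""
    assert len(student) == 8, "expected S1..S5, Logique, BDD, PS"
    average = sum(student[:5]) / 5
    truth = {
        "M1": average >= 12,     # M1 <=> average >= 12
        "M2": average >= 14,     # M2 <=> average >= 14
        "L": student[5] >= 12,   # Logique passed
        "B": student[6] >= 12,   # BDD passed
        "P": student[7] >= 12,   # PS passed
    }
    # A true variable becomes +index, a false one -index.
    return {MAPPING[name] if value else -MAPPING[name] for name, value in truth.items()}


# Average 14.6, Logique 12, BDD 10, PS 14.
print(sorted(marks_to_literals([15, 14, 16, 13, 15, 12, 10, 14])))  # [-4, 1, 2, 3, 5]
```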

In [5]:
def extract_events(formula) -> set[int]:
    """
    Helper function to extract the events (signed literals) from a sympy logical formula.
    
    Parameters:
    ----------
    formula: sympy.logic.boolalg.Boolean
        A sympy logical formula.
        
    Returns:
    -------
    set[int]
        A set of integers, one per literal: +mapping[x] for x and -mapping[x] for ~x.
    """
    events = set()
    
    def _extract_symbols(expr, sign=1):
        if isinstance(expr, sp.Symbol):
            events.add(sign * mapping[expr])
        elif isinstance(expr, sp.Not):
            # Negation flips the sign of the literal below it.
            _extract_symbols(expr.args[0], -sign)
        else:
            # And, Or and any other connective: recurse into the operands.
            for arg in expr.args:
                _extract_symbols(arg, sign)
    
    _extract_symbols(formula)
    
    return events

def commun_events(student1, student2) -> set[int]:
    """
    Returns a set of integers representing the common events between student1 and student2.

    Parameters:
    ----------
    student1 : sympy.logic.boolalg.Boolean
        A sympy logical formula representing the events associated with student1.
        
    student2 : sympy.logic.boolalg.Boolean
        A sympy logical formula representing the events associated with student2.

    Returns:
    -------
    set[int]
        A set of integers representing the common events between student1 and student2 based on their logical formulas.
    """
    events1 = extract_events(student1)
    events2 = extract_events(student2)
    
    common_events = events1 & events2  # Intersection of two sets
    
    return common_events
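The recursive walk of `extract_events` and the intersection in `commun_events` can be reproduced on a toy representation. In this sketch, formulas are assumed to be nested tuples (a variable name, `("not", f)`, or `("and", f1, f2, ...)`) instead of sympy objects; `extract_literals` is a hypothetical mirror of `extract_events`.

```python
MAPPING = {"M1": 1, "M2": 2, "L": 3, "B": 4, "P": 5}


def extract_literals(formula, sign: int = 1) -> set[int]:
    """Signed indices of the literals in a nested-tuple formula (mirrors extract_events)."""
    if isinstance(formula, str):          # a variable
        return {sign * MAPPING[formula]}
    op, *args = formula
    if op == "not":                       # negation flips the sign
        return extract_literals(args[0], -sign)
    events: set[int] = set()
    for arg in args:                      # "and", "or": recurse into the operands
        events |= extract_literals(arg, sign)
    return events


student1 = ("and", "M1", "M2", "L", ("not", "B"), "P")
student2 = ("and", "M1", ("not", "M2"), "L", "B", "P")

events1 = extract_literals(student1)      # {1, 2, 3, -4, 5}
events2 = extract_literals(student2)      # {1, -2, 3, 4, 5}
print(sorted(events1 & events2))          # [1, 3, 5]
```

Because the sign is part of the event, B and ¬B (4 and -4) are not common events: only literals with the same polarity in both formulas survive the intersection.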

In [6]:
def filtering(cFx1: list[set[int]], cFx2: list[set[int]], communEvents: set[int])-> tuple[list[set[int]], list[set[int]]]:
    """
    Filters the counterfactual explanations of the decisions taken by the classifier regarding the two students.

    Parameters:
    ----------
    cFx1 : list[set[int]]
        A list of sets representing counterfactual explanations for the classifier's decision regarding student 1.

    cFx2 : list[set[int]]
        A list of sets representing counterfactual explanations for the classifier's decision regarding student 2.

    communEvents : set[int]
        A set of integers representing the common events that should be removed from cFx1 and cFx2.

    Returns:
    -------
    tuple
        A tuple containing two lists:
        - The first list contains cFx1 with communEvents removed.
        - The second list contains cFx2 with communEvents removed.
        Explanations that become empty once the common events are removed are discarded.
    """

    # Remove the common events from every explanation, then drop the empty explanations.
    _cFx1 = [kept for x in cFx1 if (kept := set(x) - communEvents)]
    _cFx2 = [kept for x in cFx2 if (kept := set(x) - communEvents)]

    return _cFx1, _cFx2
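On toy data, the filtering step behaves as follows. The explanation sets are made up for illustration (they are not outputs of `xPred`), and `filter_explanations` is a hypothetical one-list version of `filtering`.

```python
def filter_explanations(explanations: list[set[int]], common: set[int]) -> list[set[int]]:
    """Removes the common events from each explanation and drops the empty ones."""
    return [kept for e in explanations if (kept := set(e) - common)]


common = {1, 3, 5}                    # M1, L and P hold for both students
cfx1 = [{-4, 5}, {1, 3}, {-2}]        # hypothetical counterfactual explanations
print(filter_explanations(cfx1, common))  # [{-4}, {-2}]
```

An explanation that only mentions events shared by both students, such as `{1, 3}`, cannot discriminate between them, so it disappears.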

In [10]:
def removeContainedSubsets(r: list[set[int]])-> list[set[int]]:
    """
    Remove subsets sets from the collection of sets.
    
    Parameters:
    ----------
    r (list[set[int]]): The list of sets

    Returns:
    List of sets: A list containing only the non Contained Subsets.
    """
    res = [s1 for s1 in r if not any(s2 > s1 for s2 in r)]
    return [res[i] for i in range (len(res)) if (res[i] not in res[i + 1: ])]

def scs(set1: list[set[int]], set2: list[set[int]])-> list[set[int]]:
    """
    Returns the intersection of sets from set1 and set2, removing those that are subsets of others.

    Parameters:
    ----------
    set1 : list[set[int]]
        List of sets of integers.
        
    set2 : list[set[int]]
        List of sets of integers.
    """

    res = []

    set1 = list(map(lambda x: set(map(lambda y: classes[abs(y)], x)), set1))
    set2 = list(map(lambda x: set(map(lambda y: classes[abs(y)], x)), set2))

    for x in set1:
        for y in set2:
            tmp = x & y
            if (tmp != set()):
                res.append(tmp)
            
    return removeContainedSubsets(res)
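A small worked example helps to see what `scs` keeps. The sketch below re-implements it under hypothetical names (`shared_classes`, `maximal_sets`) with the `classes` and `inv_class` tables copied from above; unlike `removeContainedSubsets`, it keeps the first occurrence of a duplicate, which only affects the order of the result.

```python
CLASSES = {1: 1, 2: 1, 3: 2, 4: 3, 5: 4}         # copied from `classes`
INV_CLASS = {1: "M", 2: "L", 3: "B", 4: "P"}     # copied from `inv_class`


def maximal_sets(sets: list[set[int]]) -> list[set[int]]:
    """Keeps the sets not strictly contained in another one, without duplicates."""
    maximal = [s for s in sets if not any(other > s for other in sets)]
    unique: list[set[int]] = []
    for s in maximal:
        if s not in unique:
            unique.append(s)
    return unique


def shared_classes(cfx1: list[set[int]], cfx2: list[set[int]]) -> list[set[int]]:
    """Sketch of scs: literals -> classes, pairwise intersections, maximal ones."""
    def to_classes(e: set[int]) -> set[int]:
        return {CLASSES[abs(lit)] for lit in e}

    c1 = [to_classes(e) for e in cfx1]
    c2 = [to_classes(e) for e in cfx2]
    overlaps = [x & y for x in c1 for y in c2 if x & y]
    return maximal_sets(overlaps)


# Hypothetical filtered explanations (signed literals) for two students.
result = shared_classes([{-4, -5}, {-2}], [{4, 5}, {4}])
print([sorted(INV_CLASS[c] for c in s) for s in result])  # [['B', 'P']]
```

The pairwise intersections are `{3, 4}` and `{3}`; the latter is discarded because it is strictly contained in the former, and `{-2}` has no class in common with the second student's explanations.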

##### <u>2. At the core of the algorithm</u>:
In this section, we will implement an algorithm to find bi-factual contrastive causes for two instances with **different predictions**.

In [8]:
def causes(student1, student2)-> list[set[str]]:
    """
    Returns the bi-factual contrastive causes of the two predictions.

    Parameters:
    ----------
    student1 : tuple[list[int], int]
        The marks of student1 (same layout as in numeric_to_symbolic) and its prediction Δ(student1).
        
    student2 : tuple[list[int], int]
        The marks of student2 (same layout as in numeric_to_symbolic) and its prediction Δ(student2).

    Precondition:
    ------------
    Δ(student1) != Δ(student2)

    Returns:
    -------
    list[set[str]]
        Each set gathers the feature classes ('M', 'L', 'B', 'P') of one contrastive cause.
    """

    st1, pred1 = student1
    st2, pred2 = student2
    assert pred1 != pred2, "causes() expects two different predictions"

    # Encode the marks as symbolic formulas and display them.
    st1 = numeric_to_symbolic(st1)
    print(st1)
    st2 = numeric_to_symbolic(st2)
    print(st2)

    # Computing the counterfactual explanations with xPred (from util): when one of the
    # predictions is 2, they are computed with respect to `admit`, otherwise to `waiti`.
    if max(pred1, pred2) == 2:
        _, _, cFx1 = xPred(admit, st1, mapping)
        _, _, cFx2 = xPred(admit, st2, mapping)
    else:
        _, _, cFx1 = xPred(waiti, st1, mapping)
        _, _, cFx2 = xPred(waiti, st2, mapping)

    # Computing the common events between student1 and student2.
    communEvents = commun_events(st1, st2)
    # Filtering: remove the common events from both lists of explanations.
    _cFx1, _cFx2 = filtering(cFx1, cFx2, communEvents)

    # Keep the maximal sets of classes shared by both students' explanations, and name them.
    return list(map(lambda x: set(map(lambda y: inv_class[y], x)), scs(_cFx1, _cFx2)))
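The whole pipeline can be exercised end to end on toy data. This sketch is not the notebook's method: `xPred` lives in `util` and is not shown here, so `counterfactuals` is a hypothetical brute-force stand-in that, for a conjunctive rule, returns the minimal sets of an instance's literals whose negation changes the rule's verdict. It assumes, as the filtering step suggests, that explanations are expressed with the instance's own literals, and it ignores the coupling between M1 and M2 (M2 implies M1).

```python
from itertools import combinations

CLASSES = {1: 1, 2: 1, 3: 2, 4: 3, 5: 4}
INV_CLASS = {1: "M", 2: "L", 3: "B", 4: "P"}
ADMIT = frozenset({2, 3, 4, 5})  # admit = M2 & L & B & P, as required positive literals


def counterfactuals(rule: frozenset[int], instance: set[int]) -> list[set[int]]:
    """Brute-force stand-in for xPred: minimal sets of the instance's literals
    whose negation changes the verdict of the conjunctive rule."""
    verdict = rule <= instance
    found: list[set[int]] = []
    for size in range(1, len(instance) + 1):
        for subset in map(set, combinations(sorted(instance), size)):
            flipped = (instance - subset) | {-lit for lit in subset}
            # Keep the subset only if it flips the verdict and contains no smaller explanation.
            if (rule <= flipped) != verdict and not any(f <= subset for f in found):
                found.append(subset)
    return found


def bifactual_causes(rule: frozenset[int], inst1: set[int], inst2: set[int]) -> list[set[str]]:
    common = inst1 & inst2
    # Filtering: drop the common events, then the empty explanations.
    cfx1 = [kept for e in counterfactuals(rule, inst1) if (kept := e - common)]
    cfx2 = [kept for e in counterfactuals(rule, inst2) if (kept := e - common)]
    # Map literals to classes and intersect the explanations pairwise.
    c1 = [{CLASSES[abs(lit)] for lit in e} for e in cfx1]
    c2 = [{CLASSES[abs(lit)] for lit in e} for e in cfx2]
    overlaps = [x & y for x in c1 for y in c2 if x & y]
    # Keep the maximal overlaps, without duplicates, and name their classes.
    maximal: list[set[int]] = []
    for s in overlaps:
        if not any(o > s for o in overlaps) and s not in maximal:
            maximal.append(s)
    return [{INV_CLASS[c] for c in s} for s in maximal]


admitted = {1, 2, 3, 4, 5}    # M1, M2, L, B, P
waiting = {1, -2, 3, 4, 5}    # M1, not M2, L, B, P
print(bifactual_causes(ADMIT, admitted, waiting))  # [{'M'}]
```

Here the two students agree on M1, L, B and P, so every explanation of the admitted student except `{2}` is filtered out, and the only contrastive cause left is the class M, i.e. the semester average.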
