In [1]:
import sys
import time

from freamon import Freamon
from freamon.templates import Output, SourceType


# Create the Freamon handle used to capture the pipeline below.
# NOTE(review): the two arguments look like an experiment name and an MLflow
# tracking directory — confirm against the Freamon constructor.
frm = Freamon('freamon-benchmarks-fairness', './mlruns')

# Module-level reference to the pipeline object; it is filled inside the `with`
# block and read by the later benchmark cells.
captured_pipeline = None
# The pipeline script is invoked without extra command-line arguments.
cmd_args = []


# Open the pipeline defined by the script through Freamon's context manager and
# keep the yielded object so it stays reachable after the block exits.
with frm.pipeline_from_py_file('pipelines--mlinspect--credit.py', cmd_args=cmd_args) as pipeline:
    captured_pipeline = pipeline

In [24]:
def compute_opt(pipeline, sensitive_attribute, non_protected_class):
    """Count per-group confusion-matrix cells in a single pass over the prediction lineage.

    For every prediction, each lineage entry that originates from the fact table
    (the first test source whose type is ``SourceType.ENTITIES``) contributes one
    count. The count goes to the cell selected by the group of the referenced
    fact-table row, the ground-truth label and the predicted label. A label is
    treated as positive when it equals ``1.0``.

    Args:
        pipeline: Object exposing ``test_sources``, ``outputs`` and
            ``output_lineage`` (the latter two indexed by ``Output`` members).
        sensitive_attribute: Column of the fact table that defines the groups.
        non_protected_class: Value of ``sensitive_attribute`` identifying the
            non-protected group; every other value counts as protected.

    Returns:
        A tuple of eight integers in this order:
        ``(non_protected_tn, non_protected_fp, non_protected_fn, non_protected_tp,
        protected_tn, protected_fp, protected_fn, protected_tp)``.

    Raises:
        IndexError: If the pipeline has no test source of type
            ``SourceType.ENTITIES``.
    """
    fact_table_source = next(
        (source for source in pipeline.test_sources
         if source.source_type == SourceType.ENTITIES),
        None)
    if fact_table_source is None:
        # Same exception type the former `[...][0]` lookup raised, with a usable message.
        raise IndexError('pipeline has no test source of type SourceType.ENTITIES')

    # Group membership per fact-table row, materialised once so the loop below
    # only does list lookups.
    is_in_non_protected = list(fact_table_source.data[sensitive_attribute] == non_protected_class)

    y_pred = pipeline.outputs[Output.Y_PRED]
    y_test = pipeline.outputs[Output.Y_TEST]
    lineage_y_pred = pipeline.output_lineage[Output.Y_PRED]

    # One counter per (is_non_protected, truth_is_positive, prediction_is_positive).
    counts = {
        (non_protected, truth_positive, predicted_positive): 0
        for non_protected in (True, False)
        for truth_positive in (True, False)
        for predicted_positive in (True, False)
    }

    fact_table_operator_id = fact_table_source.operator_id
    for index, polynomial in enumerate(lineage_y_pred):
        for entry in polynomial:
            # Only lineage entries coming from the fact table identify a group member.
            if entry.operator_id != fact_table_operator_id:
                continue
            truth_positive = bool(y_test[index] == 1.0)
            # NOTE(review): row_id is used as a positional index into the fact
            # table, i.e. this assumes lineage row ids equal row positions — confirm.
            non_protected = bool(is_in_non_protected[entry.row_id])
            predicted_positive = bool(y_pred[index] == 1.0)
            counts[non_protected, truth_positive, predicted_positive] += 1

    return (
        counts[True, False, False],   # non-protected true negatives
        counts[True, False, True],    # non-protected false positives
        counts[True, True, False],    # non-protected false negatives
        counts[True, True, True],     # non-protected true positives
        counts[False, False, False],  # protected true negatives
        counts[False, False, True],   # protected false positives
        counts[False, True, False],   # protected false negatives
        counts[False, True, True],    # protected true positives
    )

In [19]:
def compute_naive(pipeline, sensitive_attribute, non_protected_class):
    """Count per-group confusion-matrix cells by brute-force lineage evaluation.

    For each tuple of the fact table (the first test source whose type is
    ``SourceType.ENTITIES``), the tuple's lineage variable is set to zero and
    every prediction's lineage polynomial is evaluated as a product. The first
    prediction whose product becomes zero is attributed to that tuple, counted in
    the matching cell, and the search for that tuple stops. A label is treated
    as positive when it equals ``1.0``.

    This is the slow baseline that the timing cells compare against
    ``compute_opt``; the exhaustive evaluation is kept on purpose.

    Args:
        pipeline: Object exposing ``test_sources``, ``test_source_lineage``,
            ``outputs`` and ``output_lineage``.
        sensitive_attribute: Column of the fact table that defines the groups.
        non_protected_class: Value of ``sensitive_attribute`` identifying the
            non-protected group; every other value counts as protected.

    Returns:
        A tuple of eight integers in this order:
        ``(non_protected_tn, non_protected_fp, non_protected_fn, non_protected_tp,
        protected_tn, protected_fp, protected_fn, protected_tp)``.
    """
    entity_sources = [
        (position, source) for position, source in enumerate(pipeline.test_sources)
        if source.source_type == SourceType.ENTITIES
    ]
    fact_table_index, fact_table_source = entity_sources[0]
    source_lineage = pipeline.test_source_lineage[fact_table_index]

    predictions = pipeline.outputs[Output.Y_PRED]
    prediction_lineage = pipeline.output_lineage[Output.Y_PRED]
    labels = pipeline.outputs[Output.Y_TEST]

    # One counter per (is_non_protected, truth_is_positive, prediction_is_positive).
    tally = {
        (non_protected, truth_positive, predicted_positive): 0
        for non_protected in (True, False)
        for truth_positive in (True, False)
        for predicted_positive in (True, False)
    }

    for tuple_position, annotation in enumerate(source_lineage):
        # The lineage variable that stands for this fact-table tuple.
        removed = list(annotation)[0]
        for output_position, polynomial in enumerate(prediction_lineage):
            # Evaluate the polynomial with the tuple's variable set to 0 and all
            # other variables set to 1; every variable is visited, no early exit.
            value = 1.0
            for variable in polynomial:
                hits = (variable.operator_id == removed.operator_id
                        and variable.row_id == removed.row_id)
                value *= 0.0 if hits else 1.0
            if value != 0.0:
                continue

            record = fact_table_source.data.iloc[tuple_position]
            is_non_protected = record[sensitive_attribute] == non_protected_class

            truth_positive = bool(labels[output_position] == 1.0)
            non_protected = bool(is_non_protected)
            predicted_positive = bool(predictions[output_position] == 1.0)
            tally[non_protected, truth_positive, predicted_positive] += 1

            # Only the first prediction that depends on this tuple is counted.
            break

    return (
        tally[True, False, False],   # non-protected true negatives
        tally[True, False, True],    # non-protected false positives
        tally[True, True, False],    # non-protected false negatives
        tally[True, True, True],     # non-protected true positives
        tally[False, False, False],  # protected true negatives
        tally[False, False, True],   # protected false positives
        tally[False, True, False],   # protected false negatives
        tally[False, True, True],    # protected true positives
    )

In [22]:
# Time the brute-force baseline. perf_counter() is monotonic and uses the
# highest-resolution clock available, so it is the right tool for measuring an
# interval; time.time() can jump when the system clock is adjusted.
naive_start = time.perf_counter()
# Keep the counts instead of discarding them so they can be inspected or compared.
naive_res = compute_naive(captured_pipeline, 'sex', 'Male')
naive_duration = time.perf_counter() - naive_start
naive_duration

7.774752140045166

In [26]:
# Time the single-pass implementation. The interval is only a few milliseconds,
# so the monotonic, high-resolution perf_counter() is used rather than
# time.time(), whose resolution can be coarser than the interval being measured.
opt_start = time.perf_counter()
res = compute_opt(captured_pipeline, 'sex', 'Male')
opt_duration = time.perf_counter() - opt_start
opt_duration

0.008813142776489258

In [27]:
# Display the counts returned by compute_opt. Tuple order:
# (non-protected TN, FP, FN, TP, protected TN, FP, FN, TP).
res

(828, 95, 187, 173, 713, 49, 55, 28)