# DSC180A Causal Discovery of Remote Work and Mental Health
Group members: Evelyn, Vivan, Jason, Yishan

## Setup

In [2]:
from causallearn.utils.TXT2GeneralGraph import txt2generalgraph
from causallearn.graph.ArrowConfusion import ArrowConfusion
from causallearn.graph.AdjacencyConfusion import AdjacencyConfusion
from causallearn.graph.SHD import SHD

In [3]:
def display_causal_evaluation(truth_cpdag, estimated_graph, show_matrices=False):
    """
    Print evaluation metrics comparing an estimated causal graph with the ground truth.

    The report contains arrow and adjacency confusion counts with their
    precision and recall, the Structural Hamming Distance (SHD), and F1 scores
    derived from each precision/recall pair. Ratios that are undefined
    (None, NaN, or a zero denominator) are printed as "N/A".

    Parameters:
    -----------
    truth_cpdag : GeneralGraph
        The ground truth CPDAG
    estimated_graph : GeneralGraph
        The estimated graph to evaluate
    show_matrices : bool, optional
        Whether to show the confusion matrices (default: False)

    Returns:
    --------
    None
        All results are written to stdout.
    """
    def safe_ratio(getter):
        # Precision and recall are ratios; their denominator is zero when the
        # relevant graph has no arrows/adjacencies. Report that as "undefined"
        # (None) instead of letting the exception abort the report halfway.
        # NOTE(review): adjacency counts print as plain ints, so the adjacency
        # ratios presumably raise ZeroDivisionError here - confirm against
        # the causal-learn implementation.
        try:
            return getter()
        except ZeroDivisionError:
            return None

    # Format precision and recall as percentages
    def format_metric(value):
        # NaN is the only value that is not equal to itself
        if value is None or value != value:
            return "N/A"
        return f"{value * 100:.2f}%"

    # Harmonic mean of precision and recall; None when either input is undefined
    def calculate_f1(precision, recall):
        if precision is None or recall is None or precision != precision or recall != recall:
            return None
        if precision + recall == 0:
            return 0
        return 2 * (precision * recall) / (precision + recall)

    # Calculate all metrics
    arrow = ArrowConfusion(truth_cpdag, estimated_graph)
    adj = AdjacencyConfusion(truth_cpdag, estimated_graph)
    shd = SHD(truth_cpdag, estimated_graph).get_shd()

    # Fetch every ratio exactly once so the printed value and the F1 input agree
    arrow_precision = safe_ratio(arrow.get_arrows_precision)
    arrow_recall = safe_ratio(arrow.get_arrows_recall)
    adj_precision = safe_ratio(adj.get_adj_precision)
    adj_recall = safe_ratio(adj.get_adj_recall)

    print("\n=== Causal Graph Evaluation Results ===\n")

    # Arrow Metrics
    print("Arrow Metrics:")
    print("-" * 40)
    print(f"True Positives (TP):  {arrow.get_arrows_tp()}")
    print(f"False Positives (FP): {arrow.get_arrows_fp()}")
    print(f"False Negatives (FN): {arrow.get_arrows_fn()}")
    print(f"True Negatives (TN):  {arrow.get_arrows_tn()}")
    print(f"Precision:            {format_metric(arrow_precision)}")
    print(f"Recall:              {format_metric(arrow_recall)}")

    if show_matrices:
        print("\nArrow Confusion Matrix:")
        print("-" * 40)
        print("      Predicted")
        print("Actual  Arrow  No Arrow")
        print(f"Arrow    {arrow.get_arrows_tp()}       {arrow.get_arrows_fn()}")
        print(f"No Arrow {arrow.get_arrows_fp()}       {arrow.get_arrows_tn()}")

    # Adjacency Metrics
    print("\nAdjacency Metrics:")
    print("-" * 40)
    print(f"True Positives (TP):  {adj.get_adj_tp()}")
    print(f"False Positives (FP): {adj.get_adj_fp()}")
    print(f"False Negatives (FN): {adj.get_adj_fn()}")
    print(f"True Negatives (TN):  {adj.get_adj_tn()}")
    print(f"Precision:            {format_metric(adj_precision)}")
    print(f"Recall:              {format_metric(adj_recall)}")

    if show_matrices:
        print("\nAdjacency Confusion Matrix:")
        print("-" * 40)
        print("      Predicted")
        print("Actual  Edge  No Edge")
        print(f"Edge     {adj.get_adj_tp()}       {adj.get_adj_fn()}")
        print(f"No Edge  {adj.get_adj_fp()}       {adj.get_adj_tn()}")

    # Structural Hamming Distance
    print("\nStructural Hamming Distance (SHD):")
    print("-" * 40)
    print(f"SHD: {shd}")

    # Calculate F1 scores from the ratios fetched above
    arrow_f1 = calculate_f1(arrow_precision, arrow_recall)
    adj_f1 = calculate_f1(adj_precision, adj_recall)

    # Summary metrics (format_metric already maps None to "N/A")
    print("\nSummary Metrics:")
    print("-" * 40)
    print(f"Arrow F1 Score:     {format_metric(arrow_f1)}")
    print(f"Adjacency F1 Score: {format_metric(adj_f1)}")

## Evaluation

In [4]:
# Ground truth graph, read from a text file into the general graph format.
# G is passed as the reference (first argument) in every evaluation call below.
G = txt2generalgraph('file/groundtruth.txt')

In [5]:
# Estimated graphs to compare against the ground truth G.
# Each variable name mirrors the file it is loaded from.
# NOTE(review): presumably pc/fci is the search algorithm, 0.05/0.01 the
# significance level and fisherz/kci/fastkci the independence test -
# confirm against the code that produced these files.
pc_5_fisherz = txt2generalgraph('file/pc_0.05_fisherz.txt')
fci_1_fisherz = txt2generalgraph('file/fci_0.01_fisherz.txt')
fci_1_kci = txt2generalgraph('file/fci_0.01_kci.txt')
fci_1_fastkci = txt2generalgraph('file/fci_0.01_fastkci.txt')
fci_5_kci = txt2generalgraph('file/fci_0.05_kci.txt')
fci_5_fisherz = txt2generalgraph('file/fci_0.05_fisherz.txt')
fci_5_fastkci = txt2generalgraph('file/fci_0.05_fastkci.txt')

In [6]:
# Summary report for pc_5_fisherz (confusion matrices hidden)
display_causal_evaluation(truth_cpdag=G, estimated_graph=pc_5_fisherz, show_matrices=False)

# Same report again, this time including the confusion matrices
display_causal_evaluation(truth_cpdag=G, estimated_graph=pc_5_fisherz, show_matrices=True)


=== Causal Graph Evaluation Results ===

Arrow Metrics:
----------------------------------------
True Positives (TP):  7.0
False Positives (FP): 2.0
False Negatives (FN): 0.0
True Negatives (TN):  55.0
Precision:            77.78%
Recall:              100.00%

Adjacency Metrics:
----------------------------------------
True Positives (TP):  7
False Positives (FP): 2
False Negatives (FN): 0
True Negatives (TN):  19
Precision:            77.78%
Recall:              100.00%

Structural Hamming Distance (SHD):
----------------------------------------
SHD: 2

Summary Metrics:
----------------------------------------
Arrow F1 Score:     87.50%
Adjacency F1 Score: 87.50%

=== Causal Graph Evaluation Results ===

Arrow Metrics:
----------------------------------------
True Positives (TP):  7.0
False Positives (FP): 2.0
False Negatives (FN): 0.0
True Negatives (TN):  55.0
Precision:            77.78%
Recall:              100.00%

Arrow Confusion Matrix:
----------------------------------------

In [7]:
# Summary report for fci_1_fisherz (confusion matrices hidden)
display_causal_evaluation(truth_cpdag=G, estimated_graph=fci_1_fisherz, show_matrices=False)
# Same report again, this time including the confusion matrices
display_causal_evaluation(truth_cpdag=G, estimated_graph=fci_1_fisherz, show_matrices=True)


=== Causal Graph Evaluation Results ===

Arrow Metrics:
----------------------------------------
True Positives (TP):  7.0
False Positives (FP): 4.0
False Negatives (FN): 0.0
True Negatives (TN):  53.0
Precision:            63.64%
Recall:              100.00%

Adjacency Metrics:
----------------------------------------
True Positives (TP):  7
False Positives (FP): 2
False Negatives (FN): 0
True Negatives (TN):  19
Precision:            77.78%
Recall:              100.00%

Structural Hamming Distance (SHD):
----------------------------------------
SHD: 3

Summary Metrics:
----------------------------------------
Arrow F1 Score:     77.78%
Adjacency F1 Score: 87.50%

=== Causal Graph Evaluation Results ===

Arrow Metrics:
----------------------------------------
True Positives (TP):  7.0
False Positives (FP): 4.0
False Negatives (FN): 0.0
True Negatives (TN):  53.0
Precision:            63.64%
Recall:              100.00%

Arrow Confusion Matrix:
----------------------------------------

In [8]:
# Summary report for fci_1_kci (confusion matrices hidden)
display_causal_evaluation(truth_cpdag=G, estimated_graph=fci_1_kci, show_matrices=False)
# Same report again, this time including the confusion matrices
display_causal_evaluation(truth_cpdag=G, estimated_graph=fci_1_kci, show_matrices=True)


=== Causal Graph Evaluation Results ===

Arrow Metrics:
----------------------------------------
True Positives (TP):  6.0
False Positives (FP): 2.0
False Negatives (FN): 1.0
True Negatives (TN):  55.0
Precision:            75.00%
Recall:              85.71%

Adjacency Metrics:
----------------------------------------
True Positives (TP):  6
False Positives (FP): 1
False Negatives (FN): 1
True Negatives (TN):  20
Precision:            85.71%
Recall:              85.71%

Structural Hamming Distance (SHD):
----------------------------------------
SHD: 2

Summary Metrics:
----------------------------------------
Arrow F1 Score:     80.00%
Adjacency F1 Score: 85.71%

=== Causal Graph Evaluation Results ===

Arrow Metrics:
----------------------------------------
True Positives (TP):  6.0
False Positives (FP): 2.0
False Negatives (FN): 1.0
True Negatives (TN):  55.0
Precision:            75.00%
Recall:              85.71%

Arrow Confusion Matrix:
----------------------------------------
  

In [9]:
# Summary report for fci_1_fastkci (confusion matrices hidden)
display_causal_evaluation(truth_cpdag=G, estimated_graph=fci_1_fastkci, show_matrices=False)
# Same report again, this time including the confusion matrices
display_causal_evaluation(truth_cpdag=G, estimated_graph=fci_1_fastkci, show_matrices=True)


=== Causal Graph Evaluation Results ===

Arrow Metrics:
----------------------------------------
True Positives (TP):  7.0
False Positives (FP): 4.0
False Negatives (FN): 0.0
True Negatives (TN):  53.0
Precision:            63.64%
Recall:              100.00%

Adjacency Metrics:
----------------------------------------
True Positives (TP):  7
False Positives (FP): 2
False Negatives (FN): 0
True Negatives (TN):  19
Precision:            77.78%
Recall:              100.00%

Structural Hamming Distance (SHD):
----------------------------------------
SHD: 3

Summary Metrics:
----------------------------------------
Arrow F1 Score:     77.78%
Adjacency F1 Score: 87.50%

=== Causal Graph Evaluation Results ===

Arrow Metrics:
----------------------------------------
True Positives (TP):  7.0
False Positives (FP): 4.0
False Negatives (FN): 0.0
True Negatives (TN):  53.0
Precision:            63.64%
Recall:              100.00%

Arrow Confusion Matrix:
----------------------------------------

In [10]:
# Summary report for fci_5_kci (confusion matrices hidden)
display_causal_evaluation(truth_cpdag=G, estimated_graph=fci_5_kci, show_matrices=False)
# Same report again, this time including the confusion matrices
display_causal_evaluation(truth_cpdag=G, estimated_graph=fci_5_kci, show_matrices=True)


=== Causal Graph Evaluation Results ===

Arrow Metrics:
----------------------------------------
True Positives (TP):  6.0
False Positives (FP): 4.0
False Negatives (FN): 1.0
True Negatives (TN):  53.0
Precision:            60.00%
Recall:              85.71%

Adjacency Metrics:
----------------------------------------
True Positives (TP):  6
False Positives (FP): 2
False Negatives (FN): 1
True Negatives (TN):  19
Precision:            75.00%
Recall:              85.71%

Structural Hamming Distance (SHD):
----------------------------------------
SHD: 3

Summary Metrics:
----------------------------------------
Arrow F1 Score:     70.59%
Adjacency F1 Score: 80.00%

=== Causal Graph Evaluation Results ===

Arrow Metrics:
----------------------------------------
True Positives (TP):  6.0
False Positives (FP): 4.0
False Negatives (FN): 1.0
True Negatives (TN):  53.0
Precision:            60.00%
Recall:              85.71%

Arrow Confusion Matrix:
----------------------------------------
  

In [11]:
# Summary report for fci_5_fisherz (confusion matrices hidden)
display_causal_evaluation(truth_cpdag=G, estimated_graph=fci_5_fisherz, show_matrices=False)
# Same report again, this time including the confusion matrices
display_causal_evaluation(truth_cpdag=G, estimated_graph=fci_5_fisherz, show_matrices=True)


=== Causal Graph Evaluation Results ===

Arrow Metrics:
----------------------------------------
True Positives (TP):  7.0
False Positives (FP): 5.0
False Negatives (FN): 0.0
True Negatives (TN):  52.0
Precision:            58.33%
Recall:              100.00%

Adjacency Metrics:
----------------------------------------
True Positives (TP):  7
False Positives (FP): 3
False Negatives (FN): 0
True Negatives (TN):  18
Precision:            70.00%
Recall:              100.00%

Structural Hamming Distance (SHD):
----------------------------------------
SHD: 4

Summary Metrics:
----------------------------------------
Arrow F1 Score:     73.68%
Adjacency F1 Score: 82.35%

=== Causal Graph Evaluation Results ===

Arrow Metrics:
----------------------------------------
True Positives (TP):  7.0
False Positives (FP): 5.0
False Negatives (FN): 0.0
True Negatives (TN):  52.0
Precision:            58.33%
Recall:              100.00%

Arrow Confusion Matrix:
----------------------------------------

In [12]:
# Summary report for fci_5_fastkci (confusion matrices hidden)
display_causal_evaluation(truth_cpdag=G, estimated_graph=fci_5_fastkci, show_matrices=False)
# Same report again, this time including the confusion matrices
display_causal_evaluation(truth_cpdag=G, estimated_graph=fci_5_fastkci, show_matrices=True)


=== Causal Graph Evaluation Results ===

Arrow Metrics:
----------------------------------------
True Positives (TP):  7.0
False Positives (FP): 5.0
False Negatives (FN): 0.0
True Negatives (TN):  52.0
Precision:            58.33%
Recall:              100.00%

Adjacency Metrics:
----------------------------------------
True Positives (TP):  7
False Positives (FP): 3
False Negatives (FN): 0
True Negatives (TN):  18
Precision:            70.00%
Recall:              100.00%

Structural Hamming Distance (SHD):
----------------------------------------
SHD: 4

Summary Metrics:
----------------------------------------
Arrow F1 Score:     73.68%
Adjacency F1 Score: 82.35%

=== Causal Graph Evaluation Results ===

Arrow Metrics:
----------------------------------------
True Positives (TP):  7.0
False Positives (FP): 5.0
False Negatives (FN): 0.0
True Negatives (TN):  52.0
Precision:            58.33%
Recall:              100.00%

Arrow Confusion Matrix:
----------------------------------------