---
# **II. Discover and Infer the Causality Graph**
---

# **A. Initialize the Project**

This top-level section initiates the entire notebook project by establishing the computational environment and configurations required for subsequent analytical workflows. It comprises two primary subsections: (1) loading all necessary libraries and modules, and (2) configuring essential tools and settings for reproducibility, visualization, and environment setup. These foundational steps ensure that all dependencies are correctly initialized and the environment is consistently reproducible across executions.

# 1. Load the Packages and Libraries

This subsection is responsible for importing all required Python packages and modules that will be utilized throughout the notebook. It includes essential standard library modules, scientific computing tools, machine learning frameworks, deep learning utilities, and specialized packages for causal inference and structural equation modeling.

## 1.1. Import Essential Python Standard Library Modules

This sub-subsection imports foundational Python standard library modules that provide base functionalities such as abstract base classes (abc), combinatorics (itertools), JSON processing, OS-level operations, warning handling, and logging utilities. These libraries serve as foundational components supporting utility functions, configuration management, and logging mechanisms used throughout the project.

In [1]:
# Import necessary modules.
from abc import ABC, abstractmethod
import itertools
import json
import logging
import os
import warnings

## 1.2. Import Scientific Computing, Machine Learning, and Utility Libraries

This sub-subsection imports a suite of scientific and machine learning libraries crucial for data manipulation, statistical analysis, visualization, and model building. These include:
- numpy, scipy, and pandas for numerical and statistical computations.
- matplotlib and seaborn for data visualization.
- networkx for graph-based structures.
- scikit-learn for data splitting, preprocessing, and regression.
- pytorch_lightning and torch for deep learning infrastructure and training utilities.
- fsspec for filesystem-agnostic file access.

These libraries establish the computational and analytical core of the notebook.

In [2]:
# Import additional data handling, visualization, machine learning, and deep learning modules.
import fsspec
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
from pytorch_lightning.callbacks import TQDMProgressBar
import pytorch_lightning as pl
from scipy import stats
import scipy.linalg as slin
import scipy.optimize as sopt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import LabelEncoder
import torch

## 1.3. Import Causal Inference and Structural Equation Modeling Modules

This sub-subsection loads specialized modules from the causica, causalnex, and lingam libraries, which are essential for performing causal discovery, structural learning, and related inference tasks. It includes dataset handling classes, probabilistic distribution tools, Lightning modules for model training, components for augmented Lagrangian training, the NOTEARS structure learner with its structure model and plotting helpers, and linear non-Gaussian acyclic modeling. These imports directly support the notebook's focus on causal inference methodologies.

In [3]:
# Import modules from the causica, causalnex, and lingam packages.
from causica.datasets.causica_dataset_format import Variable
from causica.distributions import ContinuousNoiseDist
from causica.lightning.data_modules.basic_data_module import BasicDECIDataModule
from causica.lightning.modules.deci_module import DECIModule
from causica.training.auglag import AugLagLRConfig
from causalnex.structure import notears
from causalnex.structure.structuremodel import StructureModel
from causalnex.plots import plot_structure
import lingam

# 2. Configure the Tools

This subsection configures critical operational settings for numerical precision, plotting aesthetics, reproducibility, and environmental setup. It ensures standardized formatting and controlled randomness to support robust, interpretable, and reproducible results.

## 2.1. Configure NumPy, Pandas, and Reproducibility Settings

This sub-subsection sets options for numpy and pandas to ensure readable and consistent output formatting for arrays and dataframes. It also seeds random number generators across numpy, torch, and pytorch_lightning to ensure deterministic behavior, which is crucial for reproducibility in experimental workflows.

In [4]:
# Set NumPy configuration options.
np.set_printoptions(precision=3, suppress=True)
np.random.seed(42)

# Set pandas display options.
pd.set_option("display.max_columns", None)
pd.set_option("display.max_colwidth", None)
pd.set_option("display.precision", 2)

# Seed random number generators for PyTorch and PyTorch Lightning.
torch.manual_seed(42)
pl.seed_everything(42)

Seed set to 42


42
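As a quick sanity check of what seeding provides, the sketch below reseeds NumPy's global generator and confirms that the same draws come back. It uses only NumPy; the helper name `draw_sample` is hypothetical and not part of the notebook.

```python
import numpy as np

def draw_sample(seed, size=3):
    """Reseed NumPy's global generator and draw a few uniform samples."""
    np.random.seed(seed)
    return np.random.rand(size)

first = draw_sample(42)
second = draw_sample(42)
other = draw_sample(7)

# Same seed, same draws; a different seed gives a different sequence.
print(np.array_equal(first, second))
print(np.array_equal(first, other))
```

The same idea applies to the PyTorch seeds, although full determinism on accelerators can depend on additional backend settings that this sketch does not cover.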

## 2.2. Configure Matplotlib Settings

This sub-subsection customizes matplotlib's font family to "Times New Roman" to standardize the visual appearance of all plots generated, enhancing readability and stylistic uniformity in reports or publications.

In [5]:
# Set Matplotlib font family.
plt.rcParams["font.family"] = "Times New Roman"

## 2.3. Environment and Path Configurations

This sub-subsection sets project-specific path variables for datasets and configuration files, loads environment variables using dotenv, configures fallback options for PyTorch hardware acceleration, and initializes logging settings. The configuration facilitates modular control of data access, secure variable management, and robust runtime diagnostics through logging.

In [6]:
# Import environment variable loader.
from dotenv import load_dotenv

# Load environment variables from a .env file before reading any of them.
load_dotenv()

# Set configuration variables.
# Parse TEST_RUN explicitly: bool("0") and bool("false") are both True in Python.
test_run = os.environ.get("TEST_RUN", "").strip().lower() in {"1", "true", "yes"}
DATA_PATH = "data/dataset.csv"
VARIABLES_PATH = "data/variables.json"

# Set PyTorch environment variable for MPS fallback.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Configure logging settings.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# **B. Get the Data Prepared**

This section performs comprehensive preprocessing of the input dataset to ensure its readiness for causal analysis and modeling. It includes steps for loading, cleaning, transforming, encoding, and splitting the data. Emphasis is placed on systematically identifying relevant features, removing redundancies and data leaks, handling missing values and outliers, and encoding categorical variables appropriately.

# 3. Explore and Clean the Data

This subsection initiates the data preparation pipeline. It sequentially processes the dataset by loading it, inspecting the content, renaming and selecting pertinent features, handling anomalies, and preparing the data for modeling. It ensures data integrity, feature relevance, and analytical consistency.

## 3.1. Load the Data

Loads the dataset from a CSV file into a Pandas DataFrame. This provides the raw input data necessary for subsequent preprocessing steps.

In [7]:
# Load the dataset CSV file into a Pandas DataFrame.
data = pd.read_csv(
    DATA_PATH
)

## 3.2. Show the Data

Displays the first 5 rows of the dataset to facilitate initial inspection and validation of the dataset structure and content.

In [8]:
# Display the first 5 rows of the dataset.
data.head(5).style.set_caption(
    "<b>IBM Telco Customer Churn Dataset (First 5 Rows)</b>"
)

Unnamed: 0,Customer ID,Gender,Age,Under 30,Senior Citizen,Married,Dependents,Number of Dependents,Country,State,City,Zip Code,Latitude,Longitude,Population,Quarter,Referred a Friend,Number of Referrals,Tenure in Months,Offer,Phone Service,Avg Monthly Long Distance Charges,Multiple Lines,Internet Service,Internet Type,Avg Monthly GB Download,Online Security,Online Backup,Device Protection Plan,Premium Tech Support,Streaming TV,Streaming Movies,Streaming Music,Unlimited Data,Contract,Paperless Billing,Payment Method,Monthly Charge,Total Charges,Total Refunds,Total Extra Data Charges,Total Long Distance Charges,Total Revenue,Satisfaction Score,Customer Status,Churn Label,Churn Score,CLTV,Churn Category,Churn Reason
0,8779-QRDMV,Male,78,No,Yes,No,No,0,United States,California,Los Angeles,90022,34.02381,-118.156582,68701,Q3,No,0,1,,No,0.0,No,Yes,DSL,8,No,No,Yes,No,No,Yes,No,No,Month-to-Month,Yes,Bank Withdrawal,39.65,39.65,0.0,20,0.0,59.65,3,Churned,Yes,91,5433,Competitor,Competitor offered more data
1,7495-OOKFY,Female,74,No,Yes,Yes,Yes,1,United States,California,Los Angeles,90063,34.044271,-118.185237,55668,Q3,Yes,1,8,Offer E,Yes,48.85,Yes,Yes,Fiber Optic,17,No,Yes,No,No,No,No,No,Yes,Month-to-Month,Yes,Credit Card,80.65,633.3,0.0,0,390.8,1024.1,3,Churned,Yes,69,5302,Competitor,Competitor made better offer
2,1658-BYGOY,Male,71,No,Yes,No,Yes,3,United States,California,Los Angeles,90065,34.108833,-118.229715,47534,Q3,No,0,18,Offer D,Yes,11.33,Yes,Yes,Fiber Optic,52,No,No,No,No,Yes,Yes,Yes,Yes,Month-to-Month,Yes,Bank Withdrawal,95.45,1752.55,45.61,0,203.94,1910.88,2,Churned,Yes,81,3179,Competitor,Competitor made better offer
3,4598-XLKNJ,Female,78,No,Yes,Yes,Yes,1,United States,California,Inglewood,90303,33.936291,-118.332639,27778,Q3,Yes,1,25,Offer C,Yes,19.76,No,Yes,Fiber Optic,12,No,Yes,Yes,No,Yes,Yes,No,Yes,Month-to-Month,Yes,Bank Withdrawal,98.5,2514.5,13.43,0,494.0,2995.07,2,Churned,Yes,88,5337,Dissatisfaction,Limited range of services
4,4846-WHAFZ,Female,80,No,Yes,Yes,Yes,1,United States,California,Whittier,90602,33.972119,-118.020188,26265,Q3,Yes,1,37,Offer C,Yes,6.33,Yes,Yes,Fiber Optic,14,No,No,No,No,No,No,No,Yes,Month-to-Month,Yes,Bank Withdrawal,76.5,2868.15,0.0,0,234.21,3102.36,2,Churned,Yes,67,2793,Price,Extra data charges


## 3.3. Rename and Reorder the Features

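Groups the dataset columns into domain-specific categories (e.g., demographic, service usage, billing), flattens the grouping into a single ordered list, and reorders the DataFrame accordingly. This improves interpretability and facilitates targeted cleaning and feature selection. No columns are renamed in this step, and any column that is absent from the mapping is dropped by the selection.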

In [9]:
# Define the feature mapping for the dataset.
feature_mapping = {
    "Customer Info": [
        "Customer ID", "Gender", "Age", "Under 30", "Senior Citizen", "Married", "Dependents", "Number of Dependents"
    ],
    "Location Info": [
        "Country", "State", "City", "Zip Code", "Latitude", "Longitude", "Population"
    ],
    "Referral & Tenure": [
        "Quarter", "Referred a Friend", "Number of Referrals", "Tenure in Months", "Offer"
    ],
    "Services Signed Up": [
        "Phone Service", "Multiple Lines", "Internet Service", "Internet Type", "Unlimited Data"
    ],
    "Internet Features": [
        "Online Security", "Online Backup", "Device Protection Plan", "Premium Tech Support",
        "Streaming TV", "Streaming Movies", "Streaming Music"
    ],
    "Billing & Payment": [
        "Avg Monthly Long Distance Charges", "Avg Monthly GB Download", "Monthly Charge",
        "Total Charges", "Total Refunds", "Total Extra Data Charges",
        "Total Long Distance Charges", "Total Revenue", "Paperless Billing", "Payment Method"
    ],
    "Customer Scores": [
        "Satisfaction Score", "CLTV", "Churn Score"
    ],
    "Churn Info": [
        "Customer Status", "Churn Label", "Churn Category", "Churn Reason"
    ]
}

In [10]:
# Flatten the feature_mapping into a single list.
desired_order = [col for group in feature_mapping.values() for col in group]

# Reorder the DataFrame.
data = data[desired_order]
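Selecting columns with a list raises a `KeyError` if any listed name is missing and silently drops every column that is not listed (in the preview above, `Contract` appears in the header but not in `feature_mapping`). A small pre-check can make both cases visible before reordering. The sketch below is stdlib-only; `check_column_order` and the toy column names are illustrative assumptions, not part of the notebook.

```python
def check_column_order(available, mapping):
    """Flatten a group->columns mapping and compare it with the available columns."""
    desired = [col for group in mapping.values() for col in group]
    missing = [col for col in desired if col not in available]
    dropped = [col for col in available if col not in desired]
    duplicated = sorted({col for col in desired if desired.count(col) > 1})
    return desired, missing, dropped, duplicated

# Toy stand-ins for the real header and feature_mapping.
toy_columns = ["Gender", "Age", "Contract", "Churn Label"]
toy_mapping = {
    "Customer Info": ["Gender", "Age"],
    "Churn Info": ["Churn Label"],
}

desired, missing, dropped, duplicated = check_column_order(toy_columns, toy_mapping)
print(desired)     # ['Gender', 'Age', 'Churn Label']
print(missing)     # []
print(dropped)     # ['Contract']
print(duplicated)  # []
```

With the real data, the call would take `list(data.columns)` and `feature_mapping` as arguments.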

## 3.4. Remove the Unnecessary Features

Identifies and removes features based on redundancy, low variance, potential data leakage, high cardinality, and irrelevance. This ensures the modeling pipeline remains efficient and statistically sound.

In [11]:
# Define the features to remove from the dataset.
features_to_remove = [
    # Identifiers & Redundant Demographics
    "Customer ID",           # High cardinality identifier
    "Under 30",              # Redundant (derivable from Age)
    "Dependents",            # Redundant (derivable from Number of Dependents)

    # Location Info (low variance or low utility)
    "Country",               # Constant (all United States)
    "State",                 # Constant (all California)
    "Zip Code",              # Too granular
    "City",                  # High cardinality, many unique values
    "Latitude",              # Granular
    "Longitude",             # Granular
    "Population",            # Possibly low variation or correlated with city

    # Referral
    "Referred a Friend",     # Redundant (derivable from Number of Referrals)

    # Subscription Redundancy
    "Internet Service",      # Redundant (inferable from Internet Type)

    # Derived or Leaky Features
    "Customer Status",       # Leaks churn label
    "Churn Score",           # Usually post-hoc score, potential leakage
    "Churn Category",        # Sparse & derived from churn
    "Churn Reason",          # Sparse & derived from churn
    "CLTV",                  # Leaks churn label
    "Satisfaction Score",    # Leaks churn label

    # Time Feature
    "Quarter",               # Possibly low relevance unless time modeling is intended

    # Financial features removed in favor of only keeping Total Revenue
    "Avg Monthly Long Distance Charges",    # Usage-level detail removed
    "Avg Monthly GB Download",              # Usage-level detail removed
    "Monthly Charge",                       # Snapshot charge removed
    "Total Charges",                        # Cumulative but derived
    "Total Refunds",                        # Post-hoc financial info
    "Total Extra Data Charges",             # Specific fee detail removed
    "Total Long Distance Charges"           # Specific usage-based revenue removed
]

In [12]:
# Remove the specified columns from the DataFrame.
data = data.drop(
    columns=features_to_remove
)

## 3.5. Distinguish the Categorical and Numeric Features

Defines explicit lists of categorical and numeric features to enable precise and appropriate transformations in later preprocessing steps.

In [13]:
# Define the categorical features in the dataset.
categorical_features = [
    "Gender",
    "Senior Citizen",
    "Married",
    "Offer",
    "Phone Service",
    "Multiple Lines",
    "Internet Type",
    "Unlimited Data",
    "Online Security",
    "Online Backup",
    "Device Protection Plan",
    "Premium Tech Support",
    "Streaming TV",
    "Streaming Movies",
    "Streaming Music",
    "Paperless Billing",
    "Payment Method",
    "Churn Label"
]

# Define the numeric features in the dataset.
numeric_features = [
    "Age",
    "Number of Dependents",
    "Number of Referrals",
    "Tenure in Months",
    "Total Revenue"
]

## 3.6. Handle Outlier Values for Numeric Features

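Applies interquartile range (IQR) based clipping to cap outliers in the numeric features, limiting each value to the range [Q1 - 1.5 × IQR, Q3 + 1.5 × IQR]. This mitigates the impact of extreme values on modeling algorithms while preserving the core data distributions. "Number of Dependents" is excluded from clipping, presumably because its quartiles are all zero (see the summary statistics in Section 3.8), so both bounds would collapse to zero and clipping would erase every nonzero value.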

In [14]:
# Iterate over each numeric feature to compute and apply IQR-based clipping.
for feature in numeric_features:
    if feature != "Number of Dependents":
        q1, q3 = data[feature].quantile(
            [0.25, 0.75]
        )
        
        iqr = q3 - q1

        lower_bound, upper_bound = (
            q1 - 1.5 * iqr,
            q3 + 1.5 * iqr
        )

        data[feature] = data[feature].clip(
            lower=lower_bound,
            upper=upper_bound
        )
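To make the clipping rule concrete, here is a minimal worked example on a toy series. The helper name `clip_iqr` and the toy values are assumptions for illustration; the logic mirrors the loop above.

```python
import pandas as pd

def clip_iqr(series, k=1.5):
    """Clip a numeric Series to [Q1 - k*IQR, Q3 + k*IQR]."""
    q1, q3 = series.quantile([0.25, 0.75])
    iqr = q3 - q1
    return series.clip(lower=q1 - k * iqr, upper=q3 + k * iqr)

toy = pd.Series([1, 2, 3, 4, 100])

# Q1 = 2, Q3 = 4, IQR = 2, so the bounds are [-1, 7] and only 100 is capped.
clipped = clip_iqr(toy)
print(clipped.tolist())
```

The same arithmetic is consistent with the summary statistics in Section 3.8: for "Number of Referrals", Q1 = 0 and Q3 = 3 give an upper bound of 3 + 1.5 × 3 = 7.5, which matches the reported maximum.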

## 3.7. Check the Categorical Features' Unique Values

Computes frequency and percentage distributions for unique values in categorical features, aiding in understanding class distributions and guiding encoding strategies.

In [15]:
# Define an empty list to store the rows.
rows = []

# Iterate through the categorical features.
for feature in categorical_features:

    # Get the unique values and their counts.
    value_counts = data[feature].value_counts()
    first_row = True

    # Iterate through the unique values.
    for value, count in value_counts.items():

        # Calculate the percentage.
        percentage = str(round((count / len(data)) * 100, 2))
        # Set the feature name.
        feature_name = feature if first_row else ""
        # Append the unique value, count, and percentage to the rows.
        rows.append([
            feature_name,
            value,
            count,
            percentage
        ])
        # Set the first row to False.
        first_row = False
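A more compact way to obtain the same counts and percentages for a single feature is `value_counts(normalize=True)`. The sketch below uses a toy column. One difference to keep in mind: `normalize=True` divides by the number of non-null entries, whereas the loop above divides by `len(data)`, so the two agree only when the column has no missing values.

```python
import pandas as pd

toy = pd.DataFrame({"Married": ["Yes", "No", "No", "Yes", "No"]})

# Absolute counts and percentage shares for one categorical column.
counts = toy["Married"].value_counts()
shares = toy["Married"].value_counts(normalize=True).mul(100).round(2)

# Both Series are indexed by the category label, so they align by label.
summary = pd.DataFrame({"Frequency": counts, "Percentage": shares})
print(summary)
```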

In [16]:
# Create a DataFrame of unique values.
unique_values_df = pd.DataFrame(
    rows,
    columns=["Feature", "Unique Value", "Frequency", "Percentage"]
)

# The percentages were rounded to two decimals and stored as strings when
# the rows were built, so no further rounding is needed here.

# Style the unique values DataFrame.
unique_values_df = (
    unique_values_df.style
    .set_caption("<b>Unique Values in Categorical Features</b>")
    .hide(axis="index")
)

# Display the unique values in the categorical features.
unique_values_df

Feature,Unique Value,Frequency,Percentage
Gender,Male,3555,50.48
,Female,3488,49.52
Senior Citizen,No,5901,83.79
,Yes,1142,16.21
Married,No,3641,51.7
,Yes,3402,48.3
Offer,,3877,55.05
,Offer B,824,11.7
,Offer E,805,11.43
,Offer D,602,8.55


## 3.8. Check the Numeric Features' Statistics

Generates summary statistics (e.g., mean, std, min, max) for numeric features to inspect data ranges, central tendencies, and spread.

In [17]:
data[numeric_features].describe()

Unnamed: 0,Age,Number of Dependents,Number of Referrals,Tenure in Months,Total Revenue
count,7043.0,7043.0,7043.0,7043.0,7043.0
mean,46.51,0.47,1.81,32.39,3033.27
std,16.75,0.96,2.66,24.54,2861.98
min,19.0,0.0,0.0,1.0,21.36
25%,32.0,0.0,0.0,9.0,605.61
50%,46.0,0.0,0.0,29.0,2108.64
75%,60.0,0.0,3.0,55.0,4801.15
max,80.0,9.0,7.5,72.0,11094.45


## 3.9. Check for Missing Values in the Data

Computes the percentage of non-missing entries per feature, presented in a tabular format, to diagnose data completeness and identify missing value patterns.

In [18]:
# Compute the percentage of non-missing values for each feature.
non_missing_percentage = data.notnull().mean() * 100

In [19]:
# Create a DataFrame from the non-missing percentage series.
non_missing_df = pd.DataFrame(
    non_missing_percentage,
    columns=["Non-Missing Percentage"]
)

# Change the index to a column named "Feature".
non_missing_df = non_missing_df.reset_index().rename(
    columns={
        "index": "Feature"
    }
)

# Increment the DataFrame index to start from 1.
non_missing_df.index = non_missing_df.index + 1

# Display the non-missing percentage table with two decimal places.
non_missing_df.style.set_caption(
    "<b>Non-Missing Percentage of Features</b>"
).format(
    {
        "Non-Missing Percentage": "{:.2f}"
    }
)

Unnamed: 0,Feature,Non-Missing Percentage
1,Gender,100.0
2,Age,100.0
3,Senior Citizen,100.0
4,Married,100.0
5,Number of Dependents,100.0
6,Number of Referrals,100.0
7,Tenure in Months,100.0
8,Offer,100.0
9,Phone Service,100.0
10,Multiple Lines,100.0


## 3.10. Fill the Missing Values

Fills missing values in specific categorical features using domain-appropriate default values (e.g., "No Offer" for offers, "No Internet" for connection type).

In [20]:
# Fill missing values in the dataset.
data["Offer"] = data["Offer"].fillna("No Offer")
data["Internet Type"] = data["Internet Type"].fillna("No Internet")
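`fillna` only replaces values that pandas treats as missing (NaN or `None` objects); placeholder strings are left untouched. The non-missing table above reports 100% for `Offer` even though the frequency table shows 3,877 rows with a blank label, which suggests the blank may be stored as a string rather than as NaN. That is an assumption worth checking against the actual data; the sketch below shows one way to do so on toy values.

```python
import numpy as np
import pandas as pd

toy = pd.DataFrame({"Offer": ["Offer A", np.nan, "None", "Offer B"]})

# fillna replaces the real missing value but not the "None" string.
toy["Offer"] = toy["Offer"].fillna("No Offer")

remaining_nulls = int(toy["Offer"].isna().sum())
placeholder_rows = int(toy["Offer"].isin(["", "None"]).sum())

print(remaining_nulls)   # 0
print(placeholder_rows)  # 1
```

If placeholder rows turn up in the real data, `Series.replace` with the same default label would bring them in line with the filled values.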

## 3.11. Categorize and Display the Features

Creates a side-by-side tabular display of the final sets of categorical and numeric features, providing a transparent overview of feature classification.

In [21]:
# Pad the shorter list with empty strings.
max_length = max(
    len(categorical_features),
    len(numeric_features)
)

categorical_features_feature = categorical_features.copy()
numeric_features_feature = numeric_features.copy()

categorical_features_feature += [""] * (
    max_length - len(categorical_features)
)
numeric_features_feature += [""] * (
    max_length - len(numeric_features)
)

In [22]:
# Create a DataFrame to display feature categorization.
feature_types_df = pd.DataFrame(
    {
        "Categorical Features": categorical_features_feature,
        "Numeric Features": numeric_features_feature
    }
)

# Increment the DataFrame index to start from 1.
feature_types_df.index = feature_types_df.index + 1

# Display the feature categorization table.
feature_types_df.style.set_caption(
    "<b>Categorization of Features by Type</b>"
)

Unnamed: 0,Categorical Features,Numeric Features
1,Gender,Age
2,Senior Citizen,Number of Dependents
3,Married,Number of Referrals
4,Offer,Tenure in Months
5,Phone Service,Total Revenue
6,Multiple Lines,
7,Internet Type,
8,Unlimited Data,
9,Online Security,
10,Online Backup,


## 3.12. Check the Features' Datatypes

Displays data types of both categorical and numeric features to verify type correctness and compatibility for encoding and modeling operations.

In [23]:
# Create a DataFrame to display the data types of categorical features.
categorical_features_data_types = pd.DataFrame(
    data[categorical_features].dtypes,
    columns=["Categorical Features' Data Types"]
)

# Display the data types table with a caption.
categorical_features_data_types.style.set_caption(
    "<b>Categorical Features' Data Types</b>"
)

Unnamed: 0,Categorical Features' Data Types
Gender,object
Senior Citizen,object
Married,object
Offer,object
Phone Service,object
Multiple Lines,object
Internet Type,object
Unlimited Data,object
Online Security,object
Online Backup,object


In [24]:
# Create a DataFrame to display the data types of numeric features.
numeric_features_data_types = pd.DataFrame(
    data[numeric_features].dtypes,
    columns=["Numeric Features' Data Types"]
)

# Display the data types table with a caption.
numeric_features_data_types.style.set_caption(
    "<b>Numeric Features' Data Types</b>"
)

Unnamed: 0,Numeric Features' Data Types
Age,int64
Number of Dependents,int64
Number of Referrals,float64
Tenure in Months,int64
Total Revenue,float64


## 3.13. Encode the Binary Features

Defines and applies a function to map binary categorical values (e.g., "Yes"/"No", "Male"/"Female") to numeric binary format (1/0), ensuring compatibility with ML algorithms.

In [25]:
# Define a function to encode binary features.
def encode_binary_features(datasets, features, mapping):
    """
    Applies binary encoding to specified features across multiple datasets.

    Args:
        datasets (List[pd.DataFrame]): A list of DataFrames to be modified in-place.
        features (List[str]): The names of binary categorical features to encode.
        mapping (dict): A dictionary mapping string categories to binary values.
    """
    for df in datasets:
        for feature in features:
            df[feature] = df[feature].astype(str).map(mapping)

In [26]:
# Define binary categorical features to be encoded.
binary_features = [
    "Gender",
    "Senior Citizen",
    "Married",
    "Phone Service",
    "Multiple Lines",
    "Unlimited Data",
    "Online Security",
    "Online Backup",
    "Device Protection Plan",
    "Premium Tech Support",
    "Streaming TV",
    "Streaming Movies",
    "Streaming Music",
    "Paperless Billing",
    "Churn Label"
]

# Define mapping for binary categories.
binary_mapping = {
    "Yes": 1,
    "No": 0,
    "Male": 1,
    "Female": 0
}

In [27]:
# Apply binary encoding to all datasets.
encode_binary_features(
    datasets=[data],
    features=binary_features,
    mapping=binary_mapping
)
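`Series.map` returns NaN for any value that is not a key of the mapping, so an unexpected category would silently become a missing value. A quick check after encoding can surface such cases. In the sketch below, the label "No phone service" is a hypothetical value used only to trigger the situation.

```python
import pandas as pd

binary_mapping = {"Yes": 1, "No": 0, "Male": 1, "Female": 0}

toy = pd.DataFrame({"Multiple Lines": ["Yes", "No", "No phone service"]})
encoded = toy["Multiple Lines"].astype(str).map(binary_mapping)

# Collect the original labels that the mapping could not translate.
unmapped = toy.loc[encoded.isna(), "Multiple Lines"].unique().tolist()
print(unmapped)  # ['No phone service']
```

An empty list means every value was covered by the mapping.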

## 3.14. Encode the Ordinal Features

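Maps the remaining multi-level categorical features (Offer, Internet Type, and Payment Method) to integer codes using hand-specified mappings. The integer codes impose an order on the categories; for Offer and Payment Method this order is a modeling assumption rather than an intrinsic ranking.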

In [28]:
# Define a function to encode ordinal features.
def encode_ordinal_features(datasets, mappings):
    """
    Applies ordinal encoding to specified features across multiple datasets.

    Args:
        datasets (List[pd.DataFrame]): A list of DataFrames to be modified in-place.
        mappings (dict): A dictionary where keys are feature names and values are mapping dicts.
    """
    for df in datasets:
        for feature, mapping in mappings.items():
            df[feature] = df[feature].map(mapping)

In [29]:
# Define the mappings for ordinal features.
offer_mapping = {
    "No Offer": 0,
    "Offer A": 1,
    "Offer B": 2,
    "Offer C": 3,
    "Offer D": 4,
    "Offer E": 5
}

internet_type_mapping = {
    "No Internet": 0,
    "DSL": 1,
    "Cable": 2,
    "Fiber Optic": 3
}

payment_method_mapping = {
    "Mailed Check": 1,
    "Bank Withdrawal": 2,
    "Credit Card": 3
}

# Create a dictionary of ordinal mappings.
ordinal_mappings = {
    "Offer": offer_mapping,
    "Internet Type": internet_type_mapping,
    "Payment Method": payment_method_mapping
}

In [30]:
# Apply ordinal encoding to all datasets.
encode_ordinal_features(
    datasets=[data],
    mappings=ordinal_mappings
)
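An alternative to a hand-written dictionary is an ordered `pd.Categorical`, which derives the integer codes from the position of each level and marks unknown labels with -1 instead of NaN. The sketch below reproduces the Internet Type codes on toy values; "Satellite" is a hypothetical label that is not in the mapping.

```python
import pandas as pd

levels = ["No Internet", "DSL", "Cable", "Fiber Optic"]
toy = pd.Series(["DSL", "Fiber Optic", "No Internet", "Satellite"])

# Codes follow the position in `levels`; labels outside `levels` get -1.
codes = pd.Categorical(toy, categories=levels, ordered=True).codes
print(codes.tolist())  # [1, 3, 0, -1]
```

Because the sentinel is an integer, the column keeps an integer dtype, and unknown labels can be found with a simple `codes == -1` test.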

## 3.15. Standard Scale the Numeric Features

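Standardizes each numeric feature to zero mean and unit variance using StandardScaler so that features measured on different scales become numerically comparable. Note that the scaler is fit on the full dataset, before the train/test split in Section 3.16.

In [31]:
# Fit the scaler and standardize all numeric features in the dataset.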
scaler = StandardScaler()

data[numeric_features] = scaler.fit_transform(data[numeric_features])
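The transformation itself is the z-score, $z = (x - \mu) / \sigma$, where StandardScaler uses the population standard deviation (`ddof=0`). The following NumPy-only sketch works through the arithmetic on toy values chosen so that the mean and standard deviation are round numbers.

```python
import numpy as np

x = np.array([2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0])

# Population statistics: mean = 5, variance = 32 / 8 = 4, std = 2.
mean = x.mean()
std = x.std(ddof=0)
z = (x - mean) / std

print(mean, std)  # 5.0 2.0
print(z.tolist())  # [-1.5, -0.5, -0.5, -0.5, 0.0, 0.0, 1.0, 2.0]
```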

## 3.16. Split the Data into Train and Test Sets

Divides the cleaned and transformed dataset into training and testing subsets with a reproducible 80/20 split, preparing for model development and evaluation.

In [32]:
# Split data into train and test sets
train_data, test_data = train_test_split(
    data,
    test_size=0.2,
    random_state=42
)

# **C. Define the Functions to Use in the Pipeline**

This primary section initiates the modular construction of a causal discovery pipeline. It encapsulates function definitions and abstractions critical to utility, validation, visualization, and algorithmic implementation. The ensuing subsections provide the computational foundation for generating constraint-aware causal graphs, analyzing their structural properties, and preparing the environment for comparative algorithm execution.

# 4. Define Utility Functions

This subsection introduces a suite of foundational utility functions that serve to enforce domain-specific constraints, validate structural integrity, visualize causal graphs, extract interpretable relationships, and evaluate algorithmic outputs. Each function is essential for preprocessing, diagnostics, and ensuring consistency throughout the causal modeling workflow.

## 4.1. Define the Constraint Matrix Creator Function

This function establishes a constraint matrix that encodes permissible and forbidden causal relationships among variables based on predefined hierarchical tiers and user-specified restrictions. The matrix (with NaN for allowed and 0.0 for disallowed edges) ensures that the causal discovery algorithms conform to domain knowledge, including treating demographic variables as root nodes and the churn label as a terminal node.

## Methodology

### Purpose and Context

The purpose of the constraint matrix construction procedure is to encode prior knowledge and domain-specific assumptions about the causal structure among a set of variables. This matrix serves as a foundational component for constraint-based and score-based causal discovery algorithms by explicitly delineating which directed edges are permitted or forbidden. The encoded constraints enforce logical tier-based causal ordering, restrict self-contradictory relationships, and preserve interpretability of the inferred graph.

### Inputs and Parameters

Let $\mathcal{N} = \{v_1, \dots, v_d\}$ denote the set of variable names, partitioned into $T$ hierarchical tiers $\mathcal{T}_1, \dots, \mathcal{T}_T$. The function takes the following inputs:

- $\mathcal{N}$: Ordered list of variable names.
- $\mathcal{T} = [\mathcal{T}_1, \dots, \mathcal{T}_T]$: List of tier sets, where $\mathcal{T}_k \subseteq \mathcal{N}$.
- $\mathcal{C}_{\text{spec}} = (\mathcal{C}_{\text{forbid}}, \mathcal{C}_{\text{allow}})$: Optional set of user-defined constraints.

The output is a matrix $\mathcal{C} \in \mathbb{R}^{d \times d}$ such that:

- $\mathcal{C}_{ij} = 0.0$: edge $v_i \to v_j$ is forbidden.
- $\mathcal{C}_{ij} = \text{NaN}$: edge $v_i \to v_j$ is allowed.

### Algorithmic Procedure

1. **Matrix Initialization**:
   - Initialize $\mathcal{C}$ as a $d \times d$ matrix with all entries set to $\text{NaN}$ (permitting all edges by default).

2. **Terminal Node Enforcement**:
   - Identify the index $i^*$ of the "Churn Label" variable (if present).
   - Set $\mathcal{C}_{i^*, j} = 0.0$ for all $j$, prohibiting outgoing edges from "Churn Label".

3. **Root Node Constraints**:
   - For all variables $v_j \in \mathcal{T}_1$ (demographic tier), enforce:
     $$
     \mathcal{C}_{i, j} = 0.0, \quad \forall i
     $$
     thereby restricting any incoming edges to demographic variables.

4. **Intra-Tier Edge Blocking**:
   - For $v_i, v_j \in \mathcal{T}_1$, $i \neq j$, set $\mathcal{C}_{ij} = 0.0$ to forbid causal edges among demographic variables.

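5. **Inter-Tier Edge Permissions**:
   - For each tier $\mathcal{T}_k$ with $k < T$ and its subsequent tier $\mathcal{T}_{k+1}$:
     - Allow edges from $\mathcal{T}_k$ to $\mathcal{T}_{k+1}$ by setting $\mathcal{C}_{ij} = \text{NaN}$ for $v_i \in \mathcal{T}_k$, $v_j \in \mathcal{T}_{k+1}$.
     - For $v_i \in \mathcal{T}_k$ and every $v_j$ in any tier other than $\mathcal{T}_{k+1}$ (including $\mathcal{T}_k$ itself), set $\mathcal{C}_{ij} = 0.0$, restricting edge targets to the immediate downstream tier.
   - Outgoing edges of the final tier $\mathcal{T}_T$ are not modified by this step, and variables that belong to no tier are unaffected as targets.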

6. **Specific Constraints Incorporation**:
   - Apply custom constraints:
     - For $(v_i, v_j) \in \mathcal{C}_{\text{forbid}}$, set $\mathcal{C}_{ij} = 0.0$.
     - For $(v_i, v_j) \in \mathcal{C}_{\text{allow}}$, set $\mathcal{C}_{ij} = \text{NaN}$.

### Theoretical Justification

This constraint encoding strategy ensures that the causal discovery process adheres to essential prior knowledge, including known causal ordering (tiers), variable roles (e.g., exogenous vs endogenous), and explicit domain rules. By defining the admissible edge set $\mathcal{E}_{\text{valid}} = \{(i,j) : \mathcal{C}_{ij} = \text{NaN}\}$, the matrix $\mathcal{C}$ facilitates efficient search or pruning in downstream algorithms, thereby reducing the hypothesis space and improving estimation accuracy.

### Outcomes and Limitations

The output includes:
- Constraint matrix $\mathcal{C} \in \mathbb{R}^{d \times d}$.
- Mapping $\phi: \mathcal{N} \rightarrow \{0, \dots, d-1\}$ from variable names to zero-based node indices.

This approach assumes that tiers and specific constraints are accurate and comprehensive. Errors or omissions in the provided tier structure or specific constraints may lead to incorrect causal exclusion or inclusion. The method does not support probabilistic or soft constraints.

In [33]:
def create_constraint_matrix(node_names, tiers, specific_constraints=None):
    """
    Create a constraint matrix for causal discovery algorithms.
    
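    Args:
        node_names (list): List of variable names.
        tiers (list): List of tier lists (e.g., [demographic, customer, ...]).
        specific_constraints (dict): Additional constraints with optional keys
            "forbidden" and "allowed", each a list of (src, dst) pairs.
    
    Returns:
        tuple: (constraint_matrix, node_name_to_idx), where constraint_matrix is an
            np.ndarray (np.nan for allowed edges, 0.0 for forbidden) and
            node_name_to_idx maps each node name to its zero-based index.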
    """
    num_nodes = len(node_names)
    node_name_to_idx = {name: i for i, name in enumerate(node_names)}
    constraint_matrix = np.full((num_nodes, num_nodes), np.nan, dtype=np.float32)
    
    # Set Churn Label as sink node (no outgoing edges)
    churn_idx = node_name_to_idx.get("Churn Label")
    if churn_idx is not None:
        constraint_matrix[churn_idx, :] = 0.0
    
    # Set demographic variables as root nodes (no incoming edges)
    for feature in tiers[0]:  # Tier 1: Demographic
        if feature in node_name_to_idx:
            feature_idx = node_name_to_idx[feature]
            constraint_matrix[:, feature_idx] = 0.0
    
    # Prevent edges within Tier 1
    for src in tiers[0]:
        for dst in tiers[0]:
            if src != dst and src in node_name_to_idx and dst in node_name_to_idx:
                constraint_matrix[node_name_to_idx[src], node_name_to_idx[dst]] = 0.0
    
    # Allow edges only from Tier N to Tier N+1
    for src_tier_idx, src_tier in enumerate(tiers[:-1]):
        dst_tier = tiers[src_tier_idx + 1]
        for src in src_tier:
            for dst in dst_tier:
                if src in node_name_to_idx and dst in node_name_to_idx:
                    constraint_matrix[node_name_to_idx[src], node_name_to_idx[dst]] = np.nan
        # Block edges to every tier except the next one (including the source tier itself)
        for other_tier_idx, other_tier in enumerate(tiers):
            if other_tier_idx != src_tier_idx + 1:
                for src in src_tier:
                    for dst in other_tier:
                        if src in node_name_to_idx and dst in node_name_to_idx:
                            constraint_matrix[node_name_to_idx[src], node_name_to_idx[dst]] = 0.0
    
    # Apply specific constraints (e.g., Gender → Service/Billing forbidden)
    if specific_constraints:
        for src, dst in specific_constraints.get("forbidden", []):
            if src in node_name_to_idx and dst in node_name_to_idx:
                constraint_matrix[node_name_to_idx[src], node_name_to_idx[dst]] = 0.0
        for src, dst in specific_constraints.get("allowed", []):
            if src in node_name_to_idx and dst in node_name_to_idx:
                constraint_matrix[node_name_to_idx[src], node_name_to_idx[dst]] = np.nan
    
    logger.info("Constraint matrix created with shape: %s", constraint_matrix.shape)
    return constraint_matrix, node_name_to_idx

## 4.2. Define the Constraint Matrix Creator Function (Annotated Variant)

This variant of the constraint matrix creator redefines `create_constraint_matrix` with the same logic and more detailed inline comments. It applies each constraint (root variables, terminal nodes, tier transitions, and explicitly forbidden or allowed relationships) systematically to the adjacency space of causal discovery.

## Methodology

### Purpose and Context

The constraint matrix construction function is a systematic approach for encoding prior structural knowledge into the causal discovery process. It supports tier-based hierarchical modeling and the enforcement of both general and specific domain constraints. The resulting matrix is instrumental for constraining the hypothesis space of causal discovery algorithms, improving both interpretability and computational efficiency.

### Inputs and Parameters

Let $\mathcal{N} = \{v_1, \dots, v_d\}$ denote a set of $d$ variable names and $\mathcal{T} = [\mathcal{T}_1, \dots, \mathcal{T}_K]$ a list of $K$ tiers. The function accepts:

- $\mathcal{N}$: List of variable names.
- $\mathcal{T}$: List of disjoint variable subsets (tiers).
- $\mathcal{C}_{\text{spec}}$: A dictionary of user-specified constraints containing:
  - $\mathcal{C}_{\text{forbid}} = \{(v_i, v_j)\}$: Edges to forbid,
  - $\mathcal{C}_{\text{allow}} = \{(v_i, v_j)\}$: Edges to explicitly permit.

The output is:

- $\mathcal{C} \in \mathbb{R}^{d \times d}$: A constraint matrix with:
  - $\mathcal{C}_{ij} = 0.0$ if edge $v_i \to v_j$ is forbidden,
  - $\mathcal{C}_{ij} = \text{NaN}$ if edge $v_i \to v_j$ is allowed.
- $\phi: \mathcal{N} \rightarrow \{0, \dots, d-1\}$: Mapping from variable names to zero-based matrix indices.

### Algorithmic Procedure

1. **Initialization**:
   - Create a $d \times d$ matrix $\mathcal{C}$ initialized to $\text{NaN}$.

2. **Terminal Variable Constraints**:
   - Identify "Churn Label" as a terminal (sink) node. For its index $i$, set:
     $$
     \mathcal{C}_{ij} = 0.0 \quad \forall j
     $$

3. **Exogenous Variable Constraints**:
   - For each feature $v_j \in \mathcal{T}_1$ (demographics), enforce:
     $$
     \mathcal{C}_{ij} = 0.0 \quad \forall i
     $$
   - Prevent intra-tier edges in $\mathcal{T}_1$: for $v_i, v_j \in \mathcal{T}_1$, $i \neq j$,
     $$
     \mathcal{C}_{ij} = 0.0
     $$

4. **Inter-Tier Causal Flow**:
   - For each tier index $k < K$, permit causal links only to the immediate downstream tier $\mathcal{T}_{k+1}$:
     $$
     \mathcal{C}_{ij} = \text{NaN} \quad \text{if } v_i \in \mathcal{T}_k, v_j \in \mathcal{T}_{k+1}
     $$
   - Block edges from $\mathcal{T}_k$ to every other tier, including $\mathcal{T}_k$ itself and all upstream tiers:
     $$
     \mathcal{C}_{ij} = 0.0 \quad \text{if } v_i \in \mathcal{T}_k, v_j \in \mathcal{T}_\ell, \ell \neq k+1
     $$

5. **Application of Specific Constraints**:
   - For $(v_i, v_j) \in \mathcal{C}_{\text{forbid}}$, set $\mathcal{C}_{ij} = 0.0$.
   - For $(v_i, v_j) \in \mathcal{C}_{\text{allow}}$, set $\mathcal{C}_{ij} = \text{NaN}$.
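
The tier rule in step 4 can also be written in vectorized form. The sketch below is an illustration only: the array `tier_of` and the three-tier toy layout are hypothetical, and it covers step 1 and step 4 but not the sink, root, or user-specified rules.

```python
import numpy as np

# Hypothetical tier assignment for five nodes (tier index per node).
tier_of = np.array([0, 0, 1, 1, 2])
num_tiers = int(tier_of.max()) + 1

src_tier = tier_of[:, None]   # shape (d, 1): tier of the source node
dst_tier = tier_of[None, :]   # shape (1, d): tier of the destination node

# Step 1: start with every edge allowed (NaN).
constraints = np.full((tier_of.size, tier_of.size), np.nan, dtype=np.float32)

# Step 4: for sources outside the last tier, forbid every destination that is
# not in the immediately following tier.
not_last = src_tier < num_tiers - 1
forbidden = not_last & (dst_tier != src_tier + 1)
constraints[forbidden] = 0.0
```

In this sketch, as in `create_constraint_matrix`, sources in the last tier are not restricted by the tier rule; their rows stay NaN unless another rule (such as the sink constraint) zeroes them.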

### Theoretical Justification

The constraint matrix $\mathcal{C}$ guides causal discovery algorithms by eliminating implausible edges and encoding domain knowledge. This framework is especially effective for settings with known partial orderings (e.g., time-series or hierarchical data) or where certain causal mechanisms are theoretically or empirically refuted. By reducing the hypothesis space, it improves both accuracy and interpretability of the learned causal graph.

In [34]:
# Define a function to create a constraint matrix based on node names, tiers, and specific constraints.
def create_constraint_matrix(
    node_names,
    tiers,
    specific_constraints=None
):

    # Determine the number of nodes.
    num_nodes = len(node_names)

    # Map each node name to its index.
    node_name_to_idx = {
        name: i
        for i, name in enumerate(node_names)
    }

    # Initialize the constraint matrix with NaN values.
    constraint_matrix = np.full(
        (num_nodes, num_nodes),
        np.nan,
        dtype=np.float32
    )

    # Set Churn Label as sink node (no outgoing edges)
    churn_idx = node_name_to_idx.get("Churn Label")
    # Check if Churn Label exists and set its row to forbidden edges.
    if churn_idx is not None:
        constraint_matrix[churn_idx, :] = 0.0

    # Set demographic variables as root nodes (no incoming edges)
    # Iterate over tier 0 features to block incoming edges.
    for feature in tiers[0]:
        if feature in node_name_to_idx:
            feature_idx = node_name_to_idx[feature]
            constraint_matrix[:, feature_idx] = 0.0

    # Prevent edges within Tier 1
    # Iterate over tier 0 pairs to block edges between distinct features.
    for src in tiers[0]:
        for dst in tiers[0]:
            if (
                src != dst and
                src in node_name_to_idx and
                dst in node_name_to_idx
            ):
                constraint_matrix[
                    node_name_to_idx[src],
                    node_name_to_idx[dst]
                ] = 0.0

    # Allow edges only from Tier N to Tier N+1
    # Iterate over tier pairs to allow edges only to the immediate next tier.
    for src_tier_idx, src_tier in enumerate(tiers[:-1]):
        dst_tier = tiers[src_tier_idx + 1]

        # Allow edges from current tier to next tier.
        for src in src_tier:
            for dst in dst_tier:
                if (
                    src in node_name_to_idx and
                    dst in node_name_to_idx
                ):
                    constraint_matrix[
                        node_name_to_idx[src],
                        node_name_to_idx[dst]
                    ] = np.nan

        # Block edges to every tier except the immediate next one
        # (the source tier itself and all earlier tiers included).
        for other_tier_idx, other_tier in enumerate(tiers):
            if other_tier_idx != src_tier_idx + 1:
                for src in src_tier:
                    for dst in other_tier:
                        if (
                            src in node_name_to_idx and
                            dst in node_name_to_idx
                        ):
                            constraint_matrix[
                                node_name_to_idx[src],
                                node_name_to_idx[dst]
                            ] = 0.0

    # Apply specific constraints
    # Check for forbidden constraints and apply them.
    if specific_constraints:
        for src, dst in specific_constraints.get("forbidden", []):
            if (
                src in node_name_to_idx and
                dst in node_name_to_idx
            ):
                constraint_matrix[
                    node_name_to_idx[src],
                    node_name_to_idx[dst]
                ] = 0.0

        # Check for allowed constraints and apply them.
        for src, dst in specific_constraints.get("allowed", []):
            if (
                src in node_name_to_idx and
                dst in node_name_to_idx
            ):
                constraint_matrix[
                    node_name_to_idx[src],
                    node_name_to_idx[dst]
                ] = np.nan

    # Log the shape of the created constraint matrix.
    logger.info(
        "Constraint matrix created with shape: %s",
        constraint_matrix.shape
    )

    return constraint_matrix, node_name_to_idx

## 4.3. Define the Constraints Validator Function

This diagnostic function evaluates a directed acyclic graph (DAG) or adjacency matrix against domain-imposed constraints. It detects violations such as improper directional edges from or into constrained variables, ensuring the integrity of the causal graph with respect to prior knowledge and structural assumptions.

## Methodology

### Purpose and Context

The constraint validation function is a post-processing utility that evaluates whether a directed acyclic graph (DAG) or adjacency matrix adheres to predefined structural constraints. It is primarily designed to assess the logical consistency of inferred causal graphs against prior knowledge, including tier-based acyclicity, variable roles (e.g., sinks or roots), and domain-specific prohibitions.

### Inputs and Parameters

Let $G = (V, E)$ denote a directed graph representing the learned causal structure. The inputs are:

- $G$: Either a `networkx.DiGraph` or a real-valued adjacency matrix $A \in [0,1]^{d \times d}$.
- $\phi: \mathcal{N} \rightarrow \{0, \dots, d-1\}$: Mapping from node names to zero-based indices.
- $\mathcal{T} = [\mathcal{T}_1, \dots, \mathcal{T}_K]$: Tiered structure over $\mathcal{N}$.
- $\tau$: Threshold for edge inclusion (default $\tau = 0.5$).

### Algorithmic Procedure

1. **Graph Construction**:
   - If $G$ is an adjacency matrix $A$, construct a directed graph:
     $$
     (i, j) \in E \iff A_{ij} > \tau
     $$

2. **Sink Node Constraint (Terminal Label)**:
   - Identify index $i^*$ of "Churn Label".
   - If $\exists j$ such that $(i^*, j) \in E$, record violation:
     $$
     \text{"Churn Label has outgoing edges"}
     $$

3. **Root Node Constraints (Tier 1)**:
   - For each $v_i \in \mathcal{T}_1$, if $\exists j$ such that $(j, i) \in E$, record:
     $$
     \text{"$v_i$ has incoming edges"}
     $$

4. **Intra-Tier Constraints**:
   - For $v_i, v_j \in \mathcal{T}_1$, $i \neq j$, if $(i, j) \in E$, record:
     $$
     \text{"T1↛T1 edge: $v_i \to v_j$"}
     $$

5. **Domain-Specific Forbidden Edges**:
   - Identify the index of "Gender", denoted $g$.
   - For all $v_k \in \mathcal{T}_3 \cup \mathcal{T}_4$ (e.g., service and billing), if $(g, k) \in E$, record:
     $$
     \text{"Gender→$v_k$ edge exists"}
     $$

6. **Logging and Return**:
   - If no violations are detected, log success.
   - Otherwise, log and return the list of violations.

### Theoretical Justification

This validation method operationalizes domain constraints through syntactic graph queries, facilitating systematic assessment of learned graph structures. Tier-based constraints approximate temporal or functional hierarchies, while node-specific rules reflect expert knowledge or known exogenous/endogenous roles. Because the checks run post hoc, they can audit graphs from algorithms that enforce the constraints during search as well as graphs from algorithms that do not.

### Outcomes and Limitations

The output is a list of textual descriptions of all constraint violations. This approach provides immediate diagnostic feedback but does not quantify the severity or probabilistic deviation of violations. It is most suitable for hard constraint evaluation and assumes accurate mapping and tier definitions.
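
The matrix branch of these checks can be sketched without a graph object. The example below is illustrative: the 3×3 `adjacency` values and the node roles are hypothetical, and only the sink and root checks are shown.

```python
import numpy as np

# Hypothetical edge probabilities for nodes [root, middle, sink].
adjacency = np.array(
    [
        [0.0, 0.9, 0.1],
        [0.6, 0.0, 0.8],
        [0.0, 0.7, 0.0],
    ]
)
tau = 0.5
root_idx, sink_idx = 0, 2

# Step 1: an edge (i, j) exists if and only if adjacency[i, j] > tau.
edges = adjacency > tau

violations = []
# Sink check: the sink's row must contain no edges.
if edges[sink_idx, :].any():
    violations.append("sink has outgoing edges")
# Root check: the root's column must contain no edges.
if edges[:, root_idx].any():
    violations.append("root has incoming edges")
```

Both checks fire here because the toy matrix contains the edges sink → middle (0.7) and middle → root (0.6), each above the threshold.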

In [35]:
# Define a function to validate constraints on a DAG or adjacency matrix.
def validate_constraints(
    dag,
    node_name_to_idx,
    tiers,
    threshold=0.5
):

    violations = []
    num_nodes = len(node_name_to_idx)
    churn_idx = node_name_to_idx.get("Churn Label")

    # Convert adjacency matrix to DAG if needed.
    if isinstance(dag, np.ndarray):
        G = nx.DiGraph()
        G.add_nodes_from(range(num_nodes))

        # Add edges to the graph based on threshold.
        for i in range(num_nodes):
            for j in range(num_nodes):
                if dag[i, j] > threshold:
                    G.add_edge(i, j)
    else:
        G = dag

    # Check if Churn Label has outgoing edges.
    if (
        churn_idx is not None and
        any(G.has_edge(churn_idx, j) for j in range(num_nodes))
    ):
        violations.append("Churn Label has outgoing edges")

    # Check if Tier 1 variables have incoming edges.
    for var in tiers[0]:
        var_idx = node_name_to_idx.get(var)

        if (
            var_idx is not None and
            any(G.has_edge(j, var_idx) for j in range(num_nodes))
        ):
            violations.append(f"{var} has incoming edges")

    # Check for edges within Tier 1.
    for src in tiers[0]:
        src_idx = node_name_to_idx.get(src)

        for dst in tiers[0]:
            dst_idx = node_name_to_idx.get(dst)

            if (
                src != dst and
                src_idx is not None and
                dst_idx is not None and
                G.has_edge(src_idx, dst_idx)
            ):
                violations.append(f"T1↛T1 edge: {src}→{dst}")

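    # Check for forbidden Gender → Service/Billing edges (Tiers 3 and 4).
    gender_idx = node_name_to_idx.get("Gender")

    if gender_idx is not None:
        # Slicing avoids an IndexError when fewer than four tiers are passed.
        for dst in [name for tier in tiers[2:4] for name in tier]: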
            dst_idx = node_name_to_idx.get(dst)

            if (
                dst_idx is not None and
                G.has_edge(gender_idx, dst_idx)
            ):
                violations.append(f"Gender→{dst} edge exists")

    # Log the validation result.
    if not violations:
        logger.info("✅ All constraints validated successfully")
    else:
        logger.warning(
            "❌ Constraint violations detected: %s",
            violations
        )

    return violations

## 4.4. Define the Relation Saver Function

This function exports identified causal relationships—either from probabilistic matrices or NetworkX graphs—to a text file. It facilitates result documentation and interpretability by formatting edge weights and directions, ensuring results are both human-readable and persistently stored.

In [36]:
# Define a function to save causal relationships to a text file.
def save_relations_to_text(
    dag,
    node_names,
    filename,
    threshold=0.2
):
    """
    Save causal relationships to a text file and print them to the console.

    Args:
        dag: NetworkX DiGraph or np.ndarray (adjacency/probability matrix).
        node_names (list): List of node names.
        filename (str): Output text file name.
        threshold (float): Probability/weight threshold for matrix-based DAGs.
    """
    # Handle potential errors.
    try:
        # Initialize list to hold relationship data.
        relations = []

        print("\n=== Causal Relationships ===")

        # Check if the DAG is a NumPy array.
        if isinstance(dag, np.ndarray):
            # Iterate over matrix rows.
            for i in range(dag.shape[0]):
                # Iterate over matrix columns.
                for j in range(dag.shape[1]):
                    # Add relationship if weight exceeds threshold.
                    if dag[i, j] > threshold:
                        relations.append({
                            "source": node_names[i],
                            "destination": node_names[j],
                            "weight": float(dag[i, j])
                        })

                        print(
                            f"{node_names[i]} -> {node_names[j]} (weight: {dag[i, j]:.3f})"
                        )
                        print("---")

        # Handle case where DAG is a graph object.
        else:
            # Iterate over edges in the graph.
            for src, dst in dag.edges():
                weight = dag[src][dst].get("weight", 1.0)
                relations.append({
                    "source": node_names[src],
                    "destination": node_names[dst],
                    "weight": float(weight)
                })

                print(
                    f"{node_names[src]} -> {node_names[dst]} (weight: {weight:.3f})"
                )
                print("---")

        # Convert relationships to DataFrame.
        relations_df = pd.DataFrame(relations)

        # Write relationships to file if any exist.
        if not relations_df.empty:
            # Open file context.
            with open(filename, "w") as f:
                f.write("Source -> Destination (Weight)\n")
                f.write("=" * 30 + "\n")

                # Iterate over DataFrame rows.
                for _, row in relations_df.iterrows():
                    f.write(
                        f"{row['source']:<25} -> {row['destination']:<25} ({row['weight']:.3f})\n"
                    )
                    f.write("-" * 30 + "\n")

            logger.info(
                "Causal relationships saved to %s (%d relations)",
                filename,
                len(relations_df)
            )

        # Handle case where no relationships are found.
        else:
            # Open file context.
            with open(filename, "w") as f:
                f.write("No causal relationships found.\n")

            logger.warning(
                "No causal relationships to save for %s",
                filename
            )

    # Handle any exception that occurs during execution.
    except Exception as e:
        logger.error(
            "Failed to save relations to %s: %s",
            filename,
            str(e)
        )
        raise

## 4.5. Define the Causal Graph Visualizer Function

A visualization utility that renders causal DAGs using uniform spatial distribution (circular layout) and annotated nodes. This function provides graphical representations of causal structures to assist in exploratory analysis and model verification.
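
For intuition about the layout, the sketch below places the nodes at equal angular steps on the unit circle, which is the idea behind a circular layout. It is a stand-alone illustration (the helper `circle_positions` is hypothetical) and does not reproduce NetworkX's exact coordinates or scaling.

```python
import math


def circle_positions(names):
    """Place each name at an equal angular step on the unit circle."""
    n = len(names)
    return {
        name: (math.cos(2 * math.pi * k / n), math.sin(2 * math.pi * k / n))
        for k, name in enumerate(names)
    }


positions = circle_positions(["Age", "Tenure in Months", "Churn Label"])
```

Note that even spacing of the nodes does not make all edges the same length: an edge between neighbours on the circle is shorter than one that crosses it.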

In [37]:
# Define function to visualize a causal graph.
def visualize_causal_graph(
    dag,
    node_names,
    filename="causal_graph.png"
):
    """
    Visualize a causal graph on a circular layout, focusing on nodes and relationships.
    
    Args:
        dag: NetworkX DiGraph or np.ndarray (adjacency matrix).
        node_names (list): List of node names.
        filename (str): Output file name.
    
    Returns:
        nx.DiGraph: Visualized graph.
    """

    # Check if input is a NumPy array and convert to DiGraph if needed.
    if isinstance(dag, np.ndarray):
        G = nx.DiGraph(dag)
    else:
        G = dag

    # Relabel nodes with names.
    G = nx.relabel_nodes(
        G,
        dict(enumerate(node_names))
    )

    # Create a matplotlib figure with specified size and resolution.
    plt.figure(
        figsize=(12, 10),
        dpi=300
    )
    # Use circular layout for evenly spaced nodes and clear node placement.
    pos = nx.circular_layout(G)

    # Draw graph nodes.
    nx.draw_networkx_nodes(
        G,
        pos,
        node_size=1000,
        node_color="lightblue",
        alpha=0.8
    )

    # Draw graph edges.
    nx.draw_networkx_edges(
        G,
        pos,
        width=1,
        alpha=0.6,
        arrowsize=15
    )

    # Draw graph labels.
    nx.draw_networkx_labels(
        G,
        pos,
        font_size=8,
        font_weight="bold"
    )

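    # Set the plot title from the file's base name (directory and extension
    # removed) and turn off axes.
    title_stem = os.path.splitext(os.path.basename(filename))[0]
    plt.title(f"Causal Graph ({title_stem})")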
    plt.axis("off")

    # Save the figure to file and close the plot.
    plt.savefig(
        filename,
        format="png",
        bbox_inches="tight"
    )
    plt.close()

    # Log the successful save operation.
    logger.info("Causal graph saved as %s", filename)

    # Return the graph.
    return G

## 4.6. Define the Structure Learning Analyzer Function

This function extracts descriptive metrics from a learned DAG or adjacency matrix, including graph density, in/out-degree distributions, and most influential/affected nodes. These metrics enable a deeper understanding of the structural characteristics of the causal model and facilitate comparative evaluations across algorithms.
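
The density and degree metrics reduce to simple counts on the thresholded matrix. A minimal sketch, assuming a hypothetical 3-node edge-probability matrix and a threshold of 0.5:

```python
import numpy as np

# Hypothetical edge-probability matrix for three nodes.
probs = np.array(
    [
        [0.0, 0.9, 0.7],
        [0.1, 0.0, 0.6],
        [0.0, 0.2, 0.0],
    ]
)
adj = (probs > 0.5).astype(int)   # keep only edges above the threshold

num_nodes = adj.shape[0]
num_edges = int(adj.sum())
# Density of a directed graph without self-loops: edges / (d * (d - 1)).
density = num_edges / (num_nodes * (num_nodes - 1))

out_degree = adj.sum(axis=1)   # row sums: edges leaving each node
in_degree = adj.sum(axis=0)    # column sums: edges entering each node
```

Because every edge contributes exactly one out-degree and one in-degree, the average in-degree and average out-degree are always equal; the per-node rankings are what distinguish influential nodes from affected ones.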

In [38]:
# Define function to analyze the structure of a learned DAG.
def analyze_structure_learning(
    dag,
    node_names,
    threshold=0.5, # This threshold is primarily for when dag is a numpy array
    output_dir="causal_discovery_output", # Added for future use (e.g., plots)
    algo_prefix="" # Added for future use (e.g., unique plot filenames)
):
    """
    Analyze the structure of a learned DAG.

    Args:
        dag: NetworkX DiGraph or np.ndarray (adjacency/probability matrix).
        node_names (list): List of node names.
        threshold (float): Probability threshold for matrix-based DAGs.
        output_dir (str): Directory to save potential output files.
        algo_prefix (str): Prefix for algorithm specific outputs.

    Returns:
        dict: Structure learning metrics.
    """

    G_for_analysis = None
    adj_matrix_for_analysis = None
    num_nodes = len(node_names)

    if isinstance(dag, np.ndarray):
        # If input is a matrix, threshold it
        adj_matrix_for_analysis = (dag > threshold).astype(int)
        # Create a named graph from this matrix for consistent metric calculation
        G_for_analysis = nx.DiGraph()
        G_for_analysis.add_nodes_from(node_names) # Add nodes with their actual names
        for r_idx in range(num_nodes):
            for c_idx in range(num_nodes):
                if adj_matrix_for_analysis[r_idx, c_idx] != 0:
                    G_for_analysis.add_edge(node_names[r_idx], node_names[c_idx])
    elif isinstance(dag, nx.DiGraph):
        G_for_analysis = dag # Assume dag is the graph to be analyzed (e.g., G_viz_primary_named)
        # Ensure the graph has string node names as expected if it came from visualize_causal_graph
        # Convert G_for_analysis to adj_matrix using node_names as the nodelist
        adj_matrix_for_analysis = nx.to_numpy_array(
            G_for_analysis,
            nodelist=node_names # Use the actual string node names for the nodelist
        )
    else:
        logger.error(
            "analyze_structure_learning: Unsupported dag type: %s",
            type(dag)
        )
        return {} # Return empty metrics on error

    num_edges = G_for_analysis.number_of_edges()

    graph_density = (
        num_edges / (num_nodes * (num_nodes - 1))
        if num_nodes > 1 else 0.0
    )

    # Calculate in-degree and out-degree directly from the named graph
    in_degrees_dict = dict(G_for_analysis.in_degree())
    out_degrees_dict = dict(G_for_analysis.out_degree())

    avg_in_degree = sum(in_degrees_dict.values()) / num_nodes if num_nodes > 0 else 0.0
    avg_out_degree = sum(out_degrees_dict.values()) / num_nodes if num_nodes > 0 else 0.0

    most_influential = sorted(
        out_degrees_dict.items(), # Already (node_name, degree)
        key=lambda item: item[1],
        reverse=True
    )[:5]

    most_affected = sorted(
        in_degrees_dict.items(), # Already (node_name, degree)
        key=lambda item: item[1],
        reverse=True
    )[:5]

    metrics = {
        "num_edges": int(num_edges),
        "graph_density": graph_density,
        "avg_in_degree": avg_in_degree,
        "avg_out_degree": avg_out_degree,
        "most_influential": most_influential,
        "most_affected": most_affected,
    }

    # (Optional: Future addition for centrality plots as discussed before)
    # if num_nodes > 0 and algo_prefix: # Check if algo_prefix is provided
    #     try:
    #         # ... (centrality calculation and plotting code would go here) ...
    #         # Example: metrics['degree_centrality'] = nx.degree_centrality(G_for_analysis)
    #         # ... (code to plot and save centrality bar charts using output_dir and algo_prefix) ...
    #         pass 
    #     except Exception as e_metrics_plot:
    #         logger.error(f"Could not compute/plot centrality metrics for {algo_prefix}: {e_metrics_plot}")

    logger.info(
        "Structure Learning Metrics for %s: %s",
        algo_prefix.upper() if algo_prefix else "Algorithm",
        {k: (f"{v:.3f}" if isinstance(v, float) else v) for k, v in metrics.items() if not isinstance(v, list)}
    )

    return metrics

## 4.7. Define the Causal Evaluator Function

This evaluator counts, for each algorithm, the structural constraint violations in its learned graph. The test data is cleaned first (constant columns are dropped and missing values are mean-filled) and is then used to select the continuous feature subset for LiNGAM, which is validated against a reduced tier structure. The function returns a per-algorithm summary of violation counts and details.

In [39]:
# Define function to evaluate causal discovery algorithms on test data.
def evaluate_causal_discovery(
    results,
    test_data,
    tiers,
    node_name_to_idx,
    threshold=0.5
):
    """
    Evaluate causal discovery algorithms on test data, focusing on constraint violations.

    Args:
        results (dict): Results from run_causal_discovery_pipeline.
        test_data (pd.DataFrame): Test data.
        tiers (list): List of tier lists.
        node_name_to_idx (dict): Mapping of node names to indices.
        threshold (float): Probability threshold for matrix-based DAGs.

    Returns:
        pd.DataFrame: Evaluation metrics for each algorithm.
    """
    
    # Log the start of evaluation.
    logger.info("Evaluating constraint violations on test data...")
    evaluation_metrics = {}

    # Identify and remove constant columns in test data.
    constant_cols = [
        col for col in test_data.columns
        if test_data[col].std() == 0
    ]

    # Handle constant columns if found.
    if constant_cols:
        logger.warning("Constant columns in test data: %s", constant_cols)
        test_data = test_data.drop(columns=constant_cols)

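    # Handle missing values in test data (mean imputation on numeric columns;
    # numeric_only avoids a TypeError when non-numeric columns are present).
    if test_data.isna().any().any():
        logger.warning("NaNs in test data, filling with mean")
        test_data = test_data.fillna(test_data.mean(numeric_only=True))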

    # Iterate over results from each algorithm.
    for algo_name, result in results.items():
        # Skip evaluation if there was an error.
        if "error" in result:
            evaluation_metrics[algo_name] = {
                "constraint_violations": "N/A",
                "violation_details": "N/A"
            }
            continue

        try:
            # Extract DAG and adjacency matrix.
            dag = result["dag"]
            adj_matrix = result["adj_matrix"]

            # Set up node index mapping for LiNGAM if applicable.
            algo_node_name_to_idx = node_name_to_idx

            # Filter for continuous features if algorithm is LiNGAM.
            if algo_name == "LiNGAM":
                continuous_features = [
                    f for f in test_data.columns
                    if f in [
                        "Age",
                        "Number of Dependents",
                        "Number of Referrals",
                        "Tenure in Months",
                        "Total Revenue"
                    ]
                ]

                algo_node_name_to_idx = {
                    name: i for i, name in enumerate(continuous_features)
                }

                # Define tier structure for LiNGAM.
                lingam_tiers = [
                    [f for f in tiers[0] if f in continuous_features],
                    [f for f in tiers[1] if f in continuous_features],
                    [f for f in tiers[3] if f in continuous_features]
                ]
            else:
                lingam_tiers = tiers

            # Validate causal constraints.
            violations = validate_constraints(
                dag if algo_name != "LiNGAM" else adj_matrix,
                algo_node_name_to_idx,
                lingam_tiers,
                threshold
            )

            # Record evaluation metrics.
            evaluation_metrics[algo_name] = {
                "constraint_violations": len(violations),
                "violation_details": violations if violations else "None"
            }

            # Log the number of constraint violations.
            logger.info(
                "%s: %d constraint violations on test data: %s",
                algo_name,
                len(violations),
                violations if violations else "None"
            )

        # Handle potential errors during evaluation.
        except Exception as e:
            logger.error("Evaluation failed for %s: %s", algo_name, str(e))
            evaluation_metrics[algo_name] = {
                "constraint_violations": "N/A",
                "violation_details": "N/A"
            }

    # Create summary DataFrame of evaluation metrics.
    summary = pd.DataFrame(evaluation_metrics).T

    # Log the final summary.
    logger.info("Constraint violation summary:\n%s", summary)

    # Return the summary DataFrame.
    return summary

# 5. Define the Causal Discovery and Inference Functions

This section implements a class-based framework for encapsulating causal discovery algorithms. Each subclass of the CausalDiscoveryAlgorithm base class adheres to a unified fit() interface, allowing for standardized training, evaluation, and output handling across multiple algorithmic strategies.

## 5.1. Define the Causal Discovery Base Class

An abstract base class defining the required interface (fit() method) for all causal discovery algorithm implementations. This abstraction ensures structural uniformity and enforces contract adherence for algorithm-specific subclasses.
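
The contract enforced by an abstract base class can be seen in a small stand-alone sketch. The class names and the shortened `fit(data, node_names)` signature below are hypothetical and exist only for illustration; they are not the notebook's classes.

```python
from abc import ABC, abstractmethod


class ToyDiscoveryAlgorithm(ABC):
    """Hypothetical stand-in for a base class, with a shortened signature."""

    @abstractmethod
    def fit(self, data, node_names):
        """Return a results dictionary."""


class EmptyGraphAlgorithm(ToyDiscoveryAlgorithm):
    """Trivial subclass that always reports a graph with no edges."""

    def fit(self, data, node_names):
        return {"edges": [], "num_nodes": len(node_names)}


# The abstract class cannot be instantiated directly.
try:
    ToyDiscoveryAlgorithm()
    instantiation_failed = False
except TypeError:
    instantiation_failed = True

result = EmptyGraphAlgorithm().fit(data=None, node_names=["Age", "Churn Label"])
```

A subclass that forgets to implement `fit()` fails at instantiation time with the same `TypeError`, so a missing method surfaces before any pipeline run starts.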

In [40]:
# Define an abstract base class for causal discovery algorithms.
class CausalDiscoveryAlgorithm(ABC):
    """Base class for causal discovery algorithms."""

    # Define an abstract method to fit the causal discovery algorithm.
    @abstractmethod
    def fit(
        self,
        data,
        constraint_matrix,
        node_names,
        node_name_to_idx,
        tiers,
        output_dir
    ):
        """
        Fit the causal discovery algorithm.

        Args:
            data (pd.DataFrame): Input data.
            constraint_matrix (np.ndarray): Constraint matrix.
            node_names (list): List of node names.
            node_name_to_idx (dict): Mapping of node names to indices.
            tiers (list): List of tier lists.
            output_dir (str): Directory for output files produced during fitting.

        Returns:
            dict: Results including DAG, adjacency matrix, metrics, and violations.
        """
        pass

## 5.2. Define the DECI Algorithm

Implements the DECI (Deep End-to-end Causal Inference) method. This probabilistic model fits a DAG using deep learning-based techniques and acyclicity constraints. The implementation includes data preparation, constraint integration, model training, graph extraction, evaluation, and visualization.

## Methodology

### Purpose and Context

The DECI (Deep End-to-end Causal Inference) algorithm is designed to infer the causal structure of a set of observed variables from data by learning the probabilistic adjacency relations that best explain the observed joint distribution. It utilizes a variational inference framework to approximate the posterior distribution over directed acyclic graphs (DAGs), allowing the recovery of causal relationships under both observational and interventional settings. DECI operates under the structural causal model (SCM) formalism, representing causal mechanisms as functions with additive noise.

### Inputs and Parameters

Let $X \in \mathbb{R}^{n \times d}$ denote the observational dataset with $n$ samples and $d$ variables. The following inputs and parameters define the algorithm:

- $X$: Input dataset, where each column corresponds to a variable.
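- $\mathcal{C}$: $d \times d$ constraint matrix encoding known causal restrictions. In the code below, NaN entries mark unconstrained pairs, and every non-NaN entry is masked out of the final probability matrix.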
- $\mathcal{N}$: A list of variable names.
- $\phi: \mathcal{N} \rightarrow \{1, \dots, d\}$: Mapping from variable names to indices.
- $\mathcal{T}$: Tiered structure of variables specifying causal ordering constraints.
- $\lambda$: Prior sparsity coefficient, penalizing edge density in the DAG.
- $\rho_0$: Initial value for augmented Lagrangian penalty.
- $\alpha_0$: Initial Lagrangian multiplier.
- $\text{AugLagLRConfig}$: Hyperparameters for the augmented Lagrangian optimization.
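
One plausible way to turn the tier structure $\mathcal{T}$ and the name-to-index mapping $\phi$ into a constraint matrix is sketched below. It assumes that entry $(i, j)$ refers to the edge $i \to j$, that NaN marks an unconstrained pair, and that 0 forbids an edge. The helper name `tiers_to_constraints` is hypothetical, and the notebook's actual construction may differ.

```python
import numpy as np


def tiers_to_constraints(tiers, node_name_to_idx):
    """Forbid self-loops and edges that point from a later tier to an earlier one."""
    d = len(node_name_to_idx)
    constraints = np.full((d, d), np.nan)  # NaN = no prior knowledge (assumed encoding)
    np.fill_diagonal(constraints, 0.0)     # no self-loops
    tier_of = {name: t for t, names in enumerate(tiers) for name in names}
    for src, src_tier in tier_of.items():
        for dst, dst_tier in tier_of.items():
            if src_tier > dst_tier:
                constraints[node_name_to_idx[src], node_name_to_idx[dst]] = 0.0
    return constraints


tiers = [["Age"], ["Tenure in Months"], ["Total Revenue"]]
idx = {"Age": 0, "Tenure in Months": 1, "Total Revenue": 2}
C = tiers_to_constraints(tiers, idx)
```

Here only the three forward pairs (earlier tier to later tier) remain NaN, so a learner is free to place edges there and nowhere else.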

### Algorithmic Procedure

1. **Variable and Data Initialization**:
   - Parse variable metadata to extract expected variables $\mathcal{V} = \{v_1, \dots, v_d\}$.
   - Validate that $\text{columns}(X) = \mathcal{V}$.
   - Normalize $X$ and construct a data module for mini-batch training.

2. **Model Setup**:
   - Define a DECI module consisting of:
     - A generative model with additive Gaussian noise: $X_i = f_i(\mathrm{PA}(X_i)) + \varepsilon_i$ for each variable $X_i$, with $\varepsilon \sim \mathcal{N}(0, \sigma^2 I)$.
     - An adjacency distribution with parameters $\theta_{\text{exist}}$ and $\theta_{\text{orient}}$.
   - Set prior sparsity with $\lambda$ and initialize augmented Lagrangian parameters $(\rho_0, \alpha_0)$.

3. **Variational Inference and Optimization**:
   - Optimize the evidence lower bound (ELBO) using a doubly-stochastic gradient descent scheme.
   - Update parameters to maximize:
     $$
     \mathbb{E}_{q(A)}[\log p(X \mid A, \theta)] - \text{KL}(q(A) \| p(A))
     $$
     subject to tiered acyclicity constraints and sparsity prior.

4. **Adjacency Probability Estimation**:
   - Let $\text{logits}_{\text{exist}}, \text{logits}_{\text{orient}}$ denote logits for edge existence and orientation.
   - Construct a skew-symmetric matrix:
     $$
     \Theta = \text{fill}_{\text{upper}}(\text{logits}_{\text{orient}}) - \text{fill}_{\text{lower}}(\text{logits}_{\text{orient}})
     $$
   - Compute logit-based score matrix:
     $$
     S_{ij} = -\log\left(\exp(-\text{logits}_{\text{exist}}) + \exp(\Theta_{ij}) + \exp(\Theta_{ij} - \text{logits}_{\text{exist}})\right)
     $$
   - Apply sigmoid transformation:
     $$
     P_{ij} = \sigma(S_{ij}) = \frac{1}{1 + \exp(-S_{ij})}
     $$

5. **Constraint Enforcement and Graph Extraction**:
   - Enforce structural constraints: $P = P \odot \text{mask}(\mathcal{C})$.
   - Identify constraint violations using $\mathcal{T}$ and $\phi$.
   - Extract DAG $G = (V, E)$ such that $(i, j) \in E$ iff $P_{ij} > 0.5$.

6. **Visualization and Output**:
   - Generate a graph visualization of the inferred DAG.
   - Export adjacency probabilities and causal metrics to the output directory.
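
Steps 4 and 5 can be sketched in plain NumPy as follows. This is an illustration under stated assumptions, not the library's implementation: $\Theta$ is built as $U - U^\top$ so that it is skew-symmetric by construction, the diagonal is zeroed to exclude self-loops, and the function names are hypothetical. Expanding the score gives $e^{-S_{ij}} = (1 + e^{-\text{logits}_{\text{exist}}})(1 + e^{\Theta_{ij}}) - 1$, so $P_{ij} = \sigma(\text{logits}_{\text{exist}})\,\sigma(-\Theta_{ij})$: the probability that an edge exists times the probability that it points from $i$ to $j$.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def skew_symmetric(logits_orient, d):
    """Place the d*(d-1)/2 orientation logits above the diagonal and mirror them with a sign flip."""
    upper = np.zeros((d, d))
    upper[np.triu_indices(d, k=1)] = logits_orient
    return upper - upper.T


def edge_probabilities(logits_exist, logits_orient):
    d = logits_exist.shape[0]
    theta = skew_symmetric(logits_orient, d)
    stacked = np.stack([-logits_exist, theta, theta - logits_exist], axis=-1)
    score = -np.logaddexp.reduce(stacked, axis=-1)  # S_ij via a stable log-sum-exp
    prob = sigmoid(score)
    np.fill_diagonal(prob, 0.0)  # assumption: self-loops are excluded
    return prob


logits_exist = np.array([[0.0, 2.0, -1.0], [2.0, 0.0, 1.0], [-1.0, 1.0, 0.0]])
logits_orient = np.array([-3.0, 0.0, 1.5])  # entries (0,1), (0,2), (1,2)
P = edge_probabilities(logits_exist, logits_orient)

# Step 5: keep only pairs left unconstrained (NaN) and threshold at 0.5.
constraints = np.full((3, 3), np.nan)
constraints[1, 0] = 0.0  # example: forbid the edge 1 -> 0
adjacency = ((P * np.isnan(constraints)) > 0.5).astype(int)
```

In this toy input the pair (0, 1) has a high existence logit and a strongly negative orientation logit, so the edge $0 \to 1$ survives the threshold while $1 \to 0$ does not.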

### Theoretical Justification

DECI builds on variational inference over the space of DAGs, using a continuous relaxation of the acyclicity constraint and a parametrized distribution over adjacency matrices (the edge-existence and orientation logits). The optimization employs an augmented Lagrangian method to handle the acyclicity constraint, so that the inferred graph converges toward a DAG while respecting prior structural knowledge.
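
To make the role of the augmented Lagrangian concrete, the sketch below scores how far a weighted adjacency matrix is from being acyclic and combines that score with a multiplier $\alpha$ and a penalty weight $\rho$. It assumes the trace-exponential measure $h(A) = \operatorname{tr}(e^{A \odot A}) - d$ popularized by NOTEARS, evaluated through its power series truncated after $d$ terms. The source does not spell out which measure is used, so treat this as an illustration and not as the library's code. The values $\alpha = 0.2$ and $\rho = 30$ mirror `init_alpha` and `init_rho` in the code cell below.

```python
import math

import numpy as np


def acyclicity_measure(adjacency):
    """Truncated series of tr(exp(A * A)) - d; zero exactly when there is no directed cycle."""
    d = adjacency.shape[0]
    squared = adjacency * adjacency  # elementwise square keeps every entry non-negative
    power = np.eye(d)
    total = 0.0
    for k in range(1, d + 1):
        power = power @ squared
        total += np.trace(power) / math.factorial(k)  # closed walks of length k
    return total


def augmented_lagrangian_penalty(adjacency, alpha, rho):
    """Standard augmented Lagrangian form: alpha * h + (rho / 2) * h**2."""
    h = acyclicity_measure(adjacency)
    return alpha * h + 0.5 * rho * h ** 2


dag = np.array([[0.0, 1.0, 1.0], [0.0, 0.0, 1.0], [0.0, 0.0, 0.0]])
cyclic = dag.copy()
cyclic[2, 0] = 1.0  # adds 2 -> 0, creating directed cycles through node 0

penalty_dag = augmented_lagrangian_penalty(dag, alpha=0.2, rho=30.0)
penalty_cyclic = augmented_lagrangian_penalty(cyclic, alpha=0.2, rho=30.0)
```

The penalty vanishes for the acyclic matrix and is strictly positive once a cycle appears, which is what pushes the optimizer back toward DAGs as $\rho$ and $\alpha$ grow over the outer steps.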

### Outcomes and Limitations

The output includes the learned DAG $G$, an adjacency matrix $P \in [0,1]^{d \times d}$ of edge probabilities, a set of structural violations, and evaluation metrics such as precision, recall, and SHD (Structural Hamming Distance). The adjacency distribution has $O(d^2)$ parameters, and results are sensitive to initialization and tier definitions. Because acyclicity is enforced through a continuous relaxation, the thresholded graph is not guaranteed to be exactly acyclic.
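
When a reference graph is available, the structural metrics named above can be computed directly from binary adjacency matrices. The sketch below is one common formulation, assuming entry $(i, j) = 1$ means $i \to j$ and counting a reversed edge as a single SHD error. The function name and both example matrices are hypothetical, and the notebook's `analyze_structure_learning` helper may define these quantities differently.

```python
import numpy as np


def structure_metrics(pred, true):
    """Edge precision, recall, and structural Hamming distance for binary adjacency matrices."""
    pred = np.asarray(pred, dtype=bool)
    true = np.asarray(true, dtype=bool)
    tp = np.sum(pred & true)
    precision = tp / pred.sum() if pred.sum() else 0.0
    recall = tp / true.sum() if true.sum() else 0.0
    # A node pair contributes one error if its edge status differs
    # (missing, extra, or reversed).
    diff = pred != true
    pair_differs = diff | diff.T
    shd = int(np.triu(pair_differs, k=1).sum())
    return {"precision": float(precision), "recall": float(recall), "shd": shd}


true_graph = np.array([[0, 1, 0], [0, 0, 1], [0, 0, 0]])  # 0 -> 1 -> 2
pred_graph = np.array([[0, 1, 0], [0, 0, 0], [0, 1, 0]])  # 0 -> 1, 2 -> 1 (reversed)
metrics = structure_metrics(pred_graph, true_graph)
```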

In [41]:
class DECIAlgorithm(CausalDiscoveryAlgorithm):
    def fit(
        self,
        data,
        constraint_matrix,
        node_names,
        node_name_to_idx,
        tiers,
        output_dir
    ):
        logger.info("Running DECI algorithm...")
        try:
            # Load variable type specifications for DECI
            # Ensure VARIABLES_PATH is correctly defined (e.g., "data/variables.json")
            with fsspec.open(VARIABLES_PATH, mode="r", encoding="utf-8") as f:
                variables_json = json.load(f)
                if "variables" not in variables_json:
                    raise KeyError("Key 'variables' not found in variables.json.")
                variables = variables_json["variables"]

            # Validate data columns against variable specifications
            expected_columns = [var["name"] for var in variables]
            if set(expected_columns) != set(data.columns):
                missing = set(expected_columns) - set(data.columns)
                extra = set(data.columns) - set(expected_columns)
                logger.error(f"DECI data columns mismatch. Missing: {missing}, Extra: {extra}")
                raise ValueError("Data columns do not match variables.json.")

            # Prepare PyTorch Lightning DataModule for DECI
            data_module = BasicDECIDataModule(
                data,
                variables=[Variable.from_dict(d) for d in variables],
                batch_size=128,
                normalize=True
            )

            # Initialize DECI Lightning Module with specified hyperparameters
            lightning_module = DECIModule(
                noise_dist=ContinuousNoiseDist.GAUSSIAN,
                prior_sparsity_lambda=100.0,
                init_rho=30.0,
                init_alpha=0.20,
                auglag_config=AugLagLRConfig(
                    max_inner_steps=1500,
                    max_outer_steps=8,
                    lr_init_dict={
                        "icgnn": 0.00076,
                        "vardist": 0.0098,
                        "functional_relationships": 3e-4,
                        "noise_dist": 0.0070
                    }
                )
            )
            # Apply domain knowledge constraints to the DECI module
            lightning_module.constraint_matrix = torch.tensor(constraint_matrix, dtype=torch.float32)

            # Configure and run PyTorch Lightning Trainer
            trainer = pl.Trainer(
                accelerator="gpu" if torch.cuda.is_available() else "cpu",
                devices=1,
                max_epochs=10,
                callbacks=[],  # Empty callbacks as progress bar is disabled
                enable_checkpointing=False,
                enable_progress_bar=False, # Disables training progress bar for cleaner logs
                logger=False # Disables default Pytorch Lightning logger
            )
            trainer.fit(lightning_module, datamodule=data_module)

            # Save the trained DECI model
            model_save_path = os.path.join(output_dir, "deci_model.pt")
            torch.save(lightning_module.sem_module, model_save_path)
            logger.info(f"DECI model saved to {model_save_path}")

            # Extract logits for edge existence and orientation
            logits_exist = lightning_module.sem_module.adjacency_module.adjacency_distribution.logits_exist
            logits_orient = lightning_module.sem_module.adjacency_module.adjacency_distribution.logits_orient

            # Helper to reconstruct matrix from vector of triangular elements
            def fill_triangular(vec, upper=False):
                d = int((-1 + np.sqrt(1 + 8 * len(vec))) / 2) + 1 # Number of nodes
                mat = vec.new_zeros(d, d)
                if upper:
                    mat[np.triu_indices(d, k=1)] = vec
                else:
                    mat[np.tril_indices(d, k=-1)] = vec
                return mat

            # Compute probability matrix from learned logits
            neg_theta = fill_triangular(logits_orient, upper=True) - fill_triangular(logits_orient, upper=False)
            logits_matrix = -torch.logsumexp(
                torch.stack([-logits_exist, neg_theta, neg_theta - logits_exist], dim=-1),
                dim=-1
            )
            prob_matrix = 1 / (1 + np.exp(-logits_matrix.cpu().detach().numpy()))
            prob_matrix = prob_matrix * np.isnan(constraint_matrix) # Enforce hard constraints

            # --- 1. Heatmap of the DECI Probability Matrix ---
            plt.figure(figsize=(14, 12), dpi=300)
            sns.heatmap(prob_matrix, xticklabels=node_names, yticklabels=node_names,
                        cmap="viridis", annot=True, fmt=".2f", annot_kws={"size": 6})
            plt.title("DECI - Estimated Probability Matrix")
            heatmap_filename = os.path.join(output_dir, "deci_prob_matrix_heatmap.png")
            plt.savefig(heatmap_filename, bbox_inches='tight')
            plt.close()
            logger.info(f"DECI probability matrix heatmap saved as {heatmap_filename}")

            # --- 2. Generate Causal Graphs at Different Additional Thresholds ---
            thresholds_to_explore_deci = [0.3, 0.6, 0.7, 0.8]
            logger.info(f"DECI: Generating graphs for varying thresholds: {thresholds_to_explore_deci}")

            for thresh_val in thresholds_to_explore_deci:
                adj_matrix_loop = (prob_matrix > thresh_val).astype(int) # Binarize
                G_loop_for_viz = nx.DiGraph(adj_matrix_loop)
                
                loop_graph_filename = os.path.join(output_dir, f"deci_graph_thresh_{thresh_val:.2f}.png")
                visualize_causal_graph(G_loop_for_viz, node_names, loop_graph_filename)
                logger.info(f"DECI graph for threshold {thresh_val:.2f} saved as {loop_graph_filename}")

                loop_relations_filename = os.path.join(output_dir, f"deci_relations_thresh_{thresh_val:.2f}.txt")
                save_relations_to_text(adj_matrix_loop, node_names, loop_relations_filename, threshold=0.01)

            # --- Primary graph generation and analysis (using a standard 0.5 threshold) ---
            primary_deci_threshold = 0.5
            adj_matrix_primary = (prob_matrix > primary_deci_threshold).astype(int)
            
            violations = validate_constraints(
                adj_matrix_primary, 
                node_name_to_idx,
                tiers,
                threshold=0.0 # adj_matrix_primary is already 0/1
            )

            G_primary_int_indexed = nx.DiGraph(adj_matrix_primary)
            G_viz_primary_named = visualize_causal_graph(
                G_primary_int_indexed, 
                node_names,
                os.path.join(output_dir, "deci_graph_primary.png")
            )

            save_relations_to_text(
                prob_matrix, # Use original probability matrix for relation text (uses its own threshold)
                node_names,
                os.path.join(output_dir, "deci_relations_primary.txt"),
                threshold=0.2 # Default threshold in save_relations_to_text
            )

            metrics = analyze_structure_learning(
                G_viz_primary_named, 
                node_names,
                threshold=0.0, # Graph is already binarized
                output_dir=output_dir,
                algo_prefix="deci"
            )

            return {
                "dag": G_viz_primary_named,
                "adj_matrix": prob_matrix, # Full probability matrix from DECI
                "metrics": metrics,
                "violations": violations
            }

        except Exception as e:
            logger.error("DECI failed: %s", str(e))
            import traceback
            logger.error(traceback.format_exc()) # Provides detailed traceback for debugging
            raise

# 5.3. Define the LiNGAM Algorithm

Implements the DirectLiNGAM (Linear Non-Gaussian Acyclic Model) algorithm tailored for continuous variables. It constructs a tier-aware constraint matrix and identifies linear causal relationships, offering precise control over permissible edge directions and robust validation for structural assumptions.

## Methodology

### Purpose and Context

The Linear Non-Gaussian Acyclic Model (LiNGAM) algorithm is designed to recover the causal structure among a set of continuous variables under the assumption of linear relationships and non-Gaussian noise. Unlike traditional structural equation modeling approaches that often require prior knowledge of the causal ordering, LiNGAM exploits non-Gaussianity to identify a unique directed acyclic graph (DAG) structure without such assumptions. The DirectLiNGAM variant estimates the causal ordering directly, through repeated regression and independence evaluation between regressors and residuals instead of ICA, and it can incorporate prior knowledge constraints.

### Inputs and Parameters

Let $X \in \mathbb{R}^{n \times d}$ be a data matrix of $n$ observations over $d$ continuous variables. The algorithm requires the following inputs:

- $X$: Filtered data matrix of continuous-valued features.
- $\mathcal{C} \in \{-1, 0\}^{d \times d}$: Prior knowledge matrix where:
  - $-1$ indicates an unknown causal relation,
  - $0$ forbids a causal relation.
- $\mathcal{T}$: A tiered structure on variables indicating causal ordering constraints.
- $\phi: \mathcal{N} \rightarrow \{1, \dots, d\}$: Mapping from feature names to indices.
- $\mathcal{N}$: A list of continuous variable names.

### Algorithmic Procedure

1. **Feature Filtering**:
   - Select a subset $\mathcal{F} \subseteq \mathcal{N}$ of continuous features suitable for linear causal modeling.
   - Remove constant-valued features to ensure identifiability.

2. **Constraint Matrix Construction**:
   - Initialize $\mathcal{C}$ as a $d \times d$ matrix with all entries set to $-1$ (unknown).
   - Enforce domain-specific constraints:
     - For instance, enforce $C_{i,:} = 0$ for variables in tier 0 or fixed exogenous variables.
     - Disallow edges that violate the temporal or tiered structure by setting $C_{ij} = 0$.
     - Maintain acyclicity and remove self-causal loops.

3. **Model Fitting**:
   - Apply the DirectLiNGAM algorithm using the constraint matrix $\mathcal{C}$.
   - DirectLiNGAM assumes the data generation model:
     $$
     X = B X + \varepsilon
     $$
     where $B$ is a matrix of causal coefficients that can be permuted to strictly lower-triangular form by a causal ordering of the variables, and $\varepsilon$ are independent non-Gaussian error terms.
   - The model estimates $B$ such that:
     - $B_{ij} \neq 0 \implies$ variable $j$ is a direct cause of variable $i$.
     - $B$ respects the ordering induced by non-Gaussianity and enforced constraints.

4. **Adjacency Matrix Extraction**:
   - Construct the adjacency matrix $A \in \{0,1\}^{d \times d}$ where $A_{ij} = 1$ iff $B_{ij} \neq 0$.

5. **Constraint Validation**:
   - Validate $A$ against the tiered structure $\mathcal{T}$ to identify violations of assumed causal directions.

6. **Graph Construction and Visualization**:
   - Generate a directed graph $G = (V, E)$ with $V = \mathcal{F}$ and an edge $j \to i$ in $E$ iff $A_{ij} = 1$, following the convention for $B$ in step 3.
   - Visualize and save the graph and adjacency matrix to disk for interpretability and further analysis.
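
The index convention in step 3 is easy to get backwards, so here is a small NumPy sketch of the data-generating model. It samples from $X = BX + \varepsilon$ with uniform (non-Gaussian) noise by solving $X = (I - B)^{-1}\varepsilon$, then reads edges off $B$ so that $B_{ij} \neq 0$ yields the edge $j \to i$. The coefficient values and variable names are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
names = ["x0", "x1", "x2"]

# B[i, j] is the direct effect of variable j on variable i (strictly lower-triangular here).
B = np.array([
    [0.0, 0.0, 0.0],
    [0.8, 0.0, 0.0],   # x0 -> x1
    [0.0, -0.5, 0.0],  # x1 -> x2
])

n = 5000
noise = rng.uniform(-1.0, 1.0, size=(n, 3))  # independent, non-Gaussian
# Row-wise samples: x = (I - B)^{-1} e, hence X = E @ inv(I - B).T
X = noise @ np.linalg.inv(np.eye(3) - B).T

# Edges as (cause, effect) pairs: a non-zero B[i, j] means j -> i.
edges = [(names[j], names[i]) for i, j in zip(*np.nonzero(B))]

# An adjacency matrix in the "row -> column" convention is the transpose of B != 0.
adjacency_row_to_col = (B != 0).T.astype(int)
```

Helpers that expect entry $(i, j)$ to mean $i \to j$ would need the transposed matrix, so it is worth checking which convention each downstream function assumes.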

### Theoretical Justification

LiNGAM rests on the identifiability of linear non-Gaussian acyclic models. Under the assumption of linear structural equations and statistically independent, non-Gaussian noise, the causal ordering and structure are identifiable from observational data, a result not guaranteed in Gaussian settings. The DirectLiNGAM algorithm determines the causal ordering by repeatedly identifying the variable that is most independent of the residuals obtained when the remaining variables are regressed on it, then estimates the structural coefficients via regression.
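
The identifiability argument can be seen numerically in the two-variable case. In the sketch below, $y = x + e$ with uniform noise; regressing in the causal direction leaves a residual that is independent of the regressor, while regressing in the reverse direction does not. Dependence is measured with the correlation between squared values, which is a crude stand-in chosen for brevity and not the independence measure DirectLiNGAM actually uses.

```python
import numpy as np

rng = np.random.default_rng(42)
n = 20000
x = rng.uniform(-1.0, 1.0, n)      # cause
y = x + rng.uniform(-1.0, 1.0, n)  # effect with non-Gaussian noise


def residual_dependence(cause, effect):
    """Regress effect on cause, then correlate the squared regressor with the squared residual."""
    slope = np.cov(cause, effect)[0, 1] / np.var(cause, ddof=1)
    residual = (effect - effect.mean()) - slope * (cause - cause.mean())
    return abs(np.corrcoef(cause ** 2, residual ** 2)[0, 1])


forward = residual_dependence(x, y)   # hypothesis x -> y: close to zero
backward = residual_dependence(y, x)  # hypothesis y -> x: clearly non-zero
inferred_direction = "x -> y" if forward < backward else "y -> x"
```

With Gaussian noise both values would be near zero and the two directions could not be told apart, which is why the non-Gaussianity assumption matters.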

### Outcomes and Limitations

The algorithm outputs a DAG $G$, an adjacency matrix $A$, a list of tier violations, and structural quality metrics (e.g., SHD, precision, recall). Its computational cost is dominated by the repeated independence evaluations and regression steps, typically scaling as $O(d^3)$ for $d$ variables. Limitations include its restriction to linear causal relationships and its reliance on non-Gaussian noise for identifiability. It also does not handle latent confounders or feedback loops.

In [42]:
class LiNGAMAlgorithm(CausalDiscoveryAlgorithm):
    def fit(
        self,
        data,
        constraint_matrix, # Main constraint matrix; LiNGAM will use a filtered version
        node_names,        # Full list of node names
        node_name_to_idx,  # Mapping for full list
        tiers,             # Tiers for full list
        output_dir         # Algorithm-specific sub-directory, e.g., "causal_discovery_output/LiNGAM"
    ):
        logger.info("Running LiNGAM algorithm...")
        try:
            # --- LiNGAM Specific Data Preparation ---
            # LiNGAM operates on continuous (numeric) features only.
            # Define the list of numeric features intended for LiNGAM (e.g., from data exploration phase).
            # Example: numeric_features_for_lingam = ["Age", "Number of Dependents", ..., "Total Revenue"]
            # Here, we'll use a predefined list and filter by columns present in the current dataset.
            potential_lingam_features = ["Age", "Number of Dependents", "Number of Referrals", "Tenure in Months", "Total Revenue"]
            current_numeric_features = [nf for nf in potential_lingam_features if nf in node_names]

            if not current_numeric_features:
                logger.error("LiNGAM: No pre-defined numeric features found in the dataset.")
                return {"dag": nx.DiGraph(), "adj_matrix": np.array([]), "metrics": {}, "violations": ["No numeric data"], "error": "No numeric data"}

            lingam_data = data[current_numeric_features].copy()

            # Remove constant columns from LiNGAM's input data as they provide no variance
            constant_cols = [col for col in lingam_data.columns if lingam_data[col].std() == 0]
            if constant_cols:
                logger.warning(f"LiNGAM: Removing constant columns: {constant_cols}")
                lingam_data = lingam_data.drop(columns=constant_cols)
                current_numeric_features = [f for f in current_numeric_features if f not in constant_cols]

            # Ensure sufficient features remain after filtering
            if lingam_data.empty or lingam_data.shape[1] < 2:
                logger.error("LiNGAM: Not enough valid features to run after preprocessing.")
                return {"dag": nx.DiGraph(), "adj_matrix": np.array([]), "metrics": {}, "violations": ["Insufficient features"], "error": "Insufficient features"}

            # Handle NaNs by mean imputation for LiNGAM
            if lingam_data.isna().any().any():
                logger.warning("LiNGAM: Data contains NaNs. Filling with column means.")
                lingam_data = lingam_data.fillna(lingam_data.mean())

            # --- LiNGAM Specific Constraint Matrix Setup ---
            # Map selected continuous feature names to their new indices for LiNGAM
            lingam_node_to_idx_map = {name: i for i, name in enumerate(current_numeric_features)}
            
            # Initialize LiNGAM's prior_knowledge matrix (-1: unknown, 0: forbidden, 1: required)
            lingam_prior_knowledge = np.full(
                (len(current_numeric_features), len(current_numeric_features)), -1, dtype=np.int32
            )

            # Adapt general tiers to the subset of features used by LiNGAM
            lingam_tiers_filtered = []
            for tier_group in tiers:
                filtered_tier_group = [f for f in tier_group if f in current_numeric_features]
                if filtered_tier_group:
                    lingam_tiers_filtered.append(filtered_tier_group)
            
            # Apply tier-based constraints to LiNGAM's prior_knowledge matrix
            if lingam_tiers_filtered and lingam_tiers_filtered[0]: # Tier 0 (e.g., demographics)
                for feature_in_tier0 in lingam_tiers_filtered[0]:
                    if feature_in_tier0 in lingam_node_to_idx_map:
                        idx = lingam_node_to_idx_map[feature_in_tier0]
                        lingam_prior_knowledge[:, idx] = 0  # No incoming edges to Tier 0 features
                        for other_feature_in_tier0 in lingam_tiers_filtered[0]: # No edges within Tier 0
                            if feature_in_tier0 != other_feature_in_tier0 and other_feature_in_tier0 in lingam_node_to_idx_map:
                                other_idx = lingam_node_to_idx_map[other_feature_in_tier0]
                                lingam_prior_knowledge[idx, other_idx] = 0
            
            # Example: If "Total Revenue" is a known sink among continuous variables for LiNGAM
            if "Total Revenue" in lingam_node_to_idx_map:
                tr_idx = lingam_node_to_idx_map["Total Revenue"]
                lingam_prior_knowledge[tr_idx, :] = 0 # No outgoing edges from Total Revenue

            # Fit the DirectLiNGAM model
            model = lingam.DirectLiNGAM(prior_knowledge=lingam_prior_knowledge)
            model.fit(lingam_data)
            
            # Extract the adjacency matrix containing causal coefficients
            adj_matrix_coeffs = model.adjacency_matrix_

            # --- 1. Heatmap of LiNGAM Coefficients ---
            heatmap_fig_size = (max(8, len(current_numeric_features) + 2), max(6, len(current_numeric_features)))
            plt.figure(figsize=heatmap_fig_size, dpi=300)
            sns.heatmap(adj_matrix_coeffs, xticklabels=current_numeric_features, yticklabels=current_numeric_features,
                        cmap="vlag", center=0, annot=True, fmt=".2f", annot_kws={"size": 8})
            plt.title("LiNGAM - Estimated Adjacency Matrix (Causal Coefficients)")
            heatmap_filename = os.path.join(output_dir, "lingam_adj_matrix_heatmap.png")
            plt.savefig(heatmap_filename, bbox_inches='tight')
            plt.close()
            logger.info(f"LiNGAM coefficient matrix heatmap saved as {heatmap_filename}")

            # --- 2. Generate Causal Graphs at Different Coefficient Thresholds ---
            thresholds_to_explore_lingam = [0.1, 0.3, 0.5, 0.7] # Thresholds for absolute coefficient values
            logger.info(f"LiNGAM: Generating graphs for varying thresholds on absolute coefficients: {thresholds_to_explore_lingam}")

            for thresh_val in thresholds_to_explore_lingam:
                # Threshold based on absolute coefficient strength
                adj_matrix_loop = np.where(np.abs(adj_matrix_coeffs) > thresh_val, adj_matrix_coeffs, 0)
                G_loop_for_viz = nx.DiGraph(adj_matrix_loop) # Integer-indexed for visualization function
                
                loop_graph_filename = os.path.join(output_dir, f"lingam_graph_thresh_{thresh_val:.2f}.png")
                visualize_causal_graph(G_loop_for_viz, current_numeric_features, loop_graph_filename)
                logger.info(f"LiNGAM graph for threshold {thresh_val:.2f} saved as {loop_graph_filename}")

                loop_relations_filename = os.path.join(output_dir, f"lingam_relations_thresh_{thresh_val:.2f}.txt")
                save_relations_to_text(adj_matrix_loop, current_numeric_features, loop_relations_filename, threshold=0.001) # adj_matrix_loop holds the surviving coefficients (zeros elsewhere), not 0/1 flags

            # --- Primary Graph Generation & Analysis (based on non-zero LiNGAM coefficients) ---
            # For LiNGAM, an edge exists if its corresponding coefficient is non-zero.
            violations = validate_constraints(
                adj_matrix_coeffs,        # Validate using the raw coefficient matrix
                lingam_node_to_idx_map,   # LiNGAM-specific node to index map
                lingam_tiers_filtered,    # LiNGAM-specific tiers
                threshold=1e-9            # Consider any non-zero coefficient as an edge
            )

            # Create the primary graph structure (integer-indexed)
            G_primary_int_indexed = nx.DiGraph()
            G_primary_int_indexed.add_nodes_from(range(len(current_numeric_features)))
            for i in range(len(current_numeric_features)):
                for j in range(len(current_numeric_features)):
                    if adj_matrix_coeffs[i, j] != 0:
                        G_primary_int_indexed.add_edge(i, j, weight=adj_matrix_coeffs[i, j])
            
            # Visualize and get the named primary graph
            G_viz_primary_named = visualize_causal_graph(
                G_primary_int_indexed,
                current_numeric_features,
                os.path.join(output_dir, "lingam_graph_primary.png")
            )

            # Save primary relations based on actual coefficients
            save_relations_to_text(
                adj_matrix_coeffs, 
                current_numeric_features,
                os.path.join(output_dir, "lingam_relations_primary.txt"),
                threshold=0.001 # Save any non-negligible coefficient
            )
            
            # Analyze the structure of the primary graph
            metrics = analyze_structure_learning(
                G_viz_primary_named, 
                current_numeric_features,
                threshold=0.0, # Graph structure is already determined by non-zero check
                output_dir=output_dir,
                algo_prefix="lingam"
            )

            return {
                "dag": G_viz_primary_named,       # The named primary graph
                "adj_matrix": adj_matrix_coeffs,  # The raw coefficient matrix from LiNGAM
                "metrics": metrics,
                "violations": violations
            }

        except Exception as e:
            logger.error("LiNGAM failed: %s", str(e))
            import traceback
            logger.error(traceback.format_exc()) # Detailed traceback for debugging
            raise

# 5.4. Define the PC-GIN Algorithm

A constraint-aware extension of the PC algorithm using the GIN (Generalized Independence with Noise) test. This variant handles both categorical and numerical data through integer encoding and residual-based independence testing. It builds a skeleton, applies orientation rules, and outputs a DAG that respects prior constraints.

## Methodology

### Purpose and Context

The PC-GIN (Peter-Clark with Generalized Independence with Noise) algorithm is a constraint-based method for causal discovery that extends the classical PC algorithm by integrating a regression-based residual independence test—termed the Generalized Independence with Noise (GIN) test. This test evaluates conditional independence through residual correlations after regressing on conditioning sets, enabling causal structure discovery from mixed-type data, including encoded categorical features.

### Inputs and Parameters

Let $X \in \mathbb{R}^{n \times d}$ denote the dataset of $n$ observations over $d$ variables, where each column may represent either continuous or categorical variables (appropriately encoded). The following inputs are utilized:

- $X$: Input data matrix with preprocessed and encoded features.
- $\mathcal{C} \in \{-1, 0\}^{d \times d}$: Constraint matrix, where:
  - $-1$ denotes unknown or unconstrained relationships (the code below represents these entries as NaN and detects them with `np.isnan`),
  - $0$ forbids a causal edge from variable $i$ to $j$.
- $\alpha$: Significance level for statistical independence testing (the `gin_test` helper defaults to $\alpha = 0.05$, while `pc_gin` defaults to $\alpha = 0.01$ and passes its value through).
- $\mathcal{T}$: Tier structure defining partial ordering constraints among variables.
- $\phi: \mathcal{N} \rightarrow \{1, \dots, d\}$: Mapping from variable names to indices.
- $\mathcal{N}$: List of variable names.

### Algorithmic Procedure

1. **Data Encoding**:
   - Categorical variables are encoded using integer encoding.
   - The dataset $X$ is transformed accordingly.

2. **GIN Conditional Independence Test**:
   - For variables $X_i$, $X_j$ conditioned on a set $Z$, fit:
     - $X_i = \beta_{Z \rightarrow i} Z + \varepsilon_i$
     - $X_j = \beta_{Z \rightarrow j} Z + \varepsilon_j$
   - Evaluate the correlation $\rho(\varepsilon_i, \varepsilon_j)$.
   - Accept $X_i \perp X_j \mid Z$ if the $p$-value from Pearson correlation exceeds $\alpha$.

3. **Skeleton Construction**:
   - Initialize a complete undirected graph with edges $(i, j)$ iff $\mathcal{C}_{ij} = -1$ or $\mathcal{C}_{ji} = -1$.
   - Iteratively remove edges based on conditional independence tests using conditioning sets of increasing size $k = 0, 1, 2, \dots$:
     $$
     \text{If } p > \alpha \Rightarrow \text{remove edge } (i, j)
     $$
   - Maintain a record of separating sets $\text{Sep}(i, j)$.

4. **Initial DAG Construction**:
   - For each remaining undirected edge $(i, j)$:
     - If $\mathcal{C}_{ij} = -1$ and $\mathcal{C}_{ji} = 0$, orient $i \to j$.
     - If $\mathcal{C}_{ji} = -1$ and $\mathcal{C}_{ij} = 0$, orient $j \to i$.
     - If both directions are unconstrained, include both $i \to j$ and $j \to i$ provisionally.

5. **V-Structure Orientation**:
   - For triplets $(i, j, k)$:
     - If $i$ and $k$ are not adjacent, and both have edges to $j$:
       - If $j \notin \text{Sep}(i, k)$, orient as $i \to j \leftarrow k$.

6. **Conflict Resolution**:
   - Eliminate contradictory edges using constraints $\mathcal{C}$.
   - If both $i \to j$ and $j \to i$ exist, retain only the edge consistent with $\mathcal{C}$ or remove arbitrarily if both directions are allowed.
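
Step 5 can be sketched without any graph library. Given the skeleton as a list of undirected pairs and the separating sets recorded during step 3, the function below returns the colliders $i \to j \leftarrow k$. The data structures and the function name are assumptions made for this illustration.

```python
from itertools import combinations


def find_v_structures(nodes, skeleton_edges, separating_sets):
    """Return (i, j, k) triples to be oriented as i -> j <- k."""
    adjacent = {frozenset(edge) for edge in skeleton_edges}
    colliders = []
    for j in nodes:
        neighbours = [v for v in nodes if frozenset((v, j)) in adjacent]
        for i, k in combinations(neighbours, 2):
            if frozenset((i, k)) in adjacent:
                continue  # i and k are adjacent: shielded triple, no v-structure
            if j not in separating_sets.get((i, k), ()):
                colliders.append((i, j, k))
    return colliders


# Skeleton 0 - 2 - 1 plus 2 - 3; nodes 0 and 1 were separated by the empty set,
# while node 3 was separated from 0 and from 1 by conditioning on node 2.
nodes = [0, 1, 2, 3]
skeleton_edges = [(0, 2), (1, 2), (2, 3)]
separating_sets = {
    (0, 1): (), (1, 0): (),
    (0, 3): (2,), (3, 0): (2,),
    (1, 3): (2,), (3, 1): (2,),
}
colliders = find_v_structures(nodes, skeleton_edges, separating_sets)
```

Only the triple around node 2 formed by nodes 0 and 1 qualifies, because node 2 is absent from their separating set; the triples involving node 3 are left unoriented.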

### Theoretical Justification

PC-GIN combines the classical PC algorithm, which is asymptotically correct under the causal Markov and faithfulness assumptions given a valid conditional independence test, with a residual-based conditional independence test. The GIN test approximates conditional independence by correlating residuals after linear regressions; it is therefore sensitive to linear dependence only, and is exact for jointly Gaussian variables but an approximation for nonlinear or non-Gaussian data. The algorithm exploits sparsity in the conditional independence graph to limit the size of conditioning sets, reducing computational burden.
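
A compact, dependency-light version of the residual-based test is sketched below. It regresses both variables on the conditioning set with least squares and correlates the residuals, but obtains the $p$-value from the Fisher $z$-transform with a normal approximation instead of `stats.pearsonr`; that substitution and the function name `residual_ci_test` are choices made for this illustration. The example simulates a chain $X \to Z \to Y$, where $X$ and $Y$ are dependent marginally but independent given $Z$.

```python
import math

import numpy as np


def residual_ci_test(x, y, z=None):
    """Return (partial correlation, two-sided p-value) for x and y given the columns of z."""
    n = len(x)
    k = 0 if z is None else z.shape[1]
    if k:
        design = np.column_stack([np.ones(n), z])  # intercept plus conditioning variables
        x = x - design @ np.linalg.lstsq(design, x, rcond=None)[0]
        y = y - design @ np.linalg.lstsq(design, y, rcond=None)[0]
    r = float(np.corrcoef(x, y)[0, 1])
    r = max(min(r, 0.999999), -0.999999)  # keep atanh finite
    z_stat = math.atanh(r) * math.sqrt(n - k - 3)
    p_value = math.erfc(abs(z_stat) / math.sqrt(2.0))  # two-sided normal tail
    return r, p_value


rng = np.random.default_rng(7)
n = 5000
X = rng.normal(size=n)
Z = X + rng.normal(size=n)
Y = Z + rng.normal(size=n)

r_marginal, p_marginal = residual_ci_test(X, Y)
r_given_z, p_given_z = residual_ci_test(X, Y, Z.reshape(-1, 1))
```

In the skeleton phase, the edge between $X$ and $Y$ would be kept at conditioning size 0 and removed once $Z$ enters the conditioning set, provided the second $p$-value exceeds $\alpha$.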

### Outcomes and Limitations

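The output includes a DAG $G$, an adjacency matrix $A \in \{0, 1\}^{d \times d}$, a list of constraint violations, and structural metrics, matching the result dictionary defined by the base class.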

In [43]:
# Define the PCGINAlgorithm class.
class PCGINAlgorithm(CausalDiscoveryAlgorithm):

    # Define the fit method for the algorithm.
    def fit(
        self,
        data,
        constraint_matrix,
        node_names,
        node_name_to_idx,
        tiers,
        output_dir
    ):
        logger.info("Running PC-GIN algorithm...")

        # Handle potential errors.
        try:
            # Encode categorical columns.
            categorical_cols = [
                col for col in data.columns if col in [
                    "Gender",
                    "Internet Type",
                    "Offer",
                    "Payment Method"
                ]
            ]

            encoded_data = data.copy()

            # Encode each categorical column.
            for col in categorical_cols:
                le = LabelEncoder()
                encoded_data[col] = le.fit_transform(
                    encoded_data[col].astype(str)
                )

            # Define the GIN conditional independence test function.
            def gin_test(X, Y, Z=None, alpha=0.05):
                n = len(X)

                # Check independence without conditioning variables.
                if Z is None or Z.shape[1] == 0:
                    corr, p_value = stats.pearsonr(X, Y)
                    return p_value

                # Fit regression model for X given Z.
                model_x = LinearRegression().fit(
                    Z,
                    X
                )
                residuals_x = X - model_x.predict(Z)

                # Fit regression model for Y given Z.
                model_y = LinearRegression().fit(
                    Z,
                    Y
                )
                residuals_y = Y - model_y.predict(Z)

                # Compute correlation of residuals.
                corr, p_value = stats.pearsonr(
                    residuals_x,
                    residuals_y
                )

                return p_value

            # Define the PC-GIN algorithm function.
            def pc_gin(data, constraint_matrix, alpha=0.01):
                n = data.shape[1]

                # Initialize the skeleton graph.
                skeleton = nx.Graph()
                skeleton.add_nodes_from(range(n))

                separating_sets = {}

                # Add initial edges based on constraint matrix.
                for i in range(n):
                    for j in range(i + 1, n):
                        if (
                            np.isnan(constraint_matrix[i, j])
                            or np.isnan(constraint_matrix[j, i])
                        ):
                            skeleton.add_edge(i, j)

                # Remove edges based on conditional independence tests.
                for d in range(n):
                    edges = list(skeleton.edges())

                    for i, j in edges:
                        if not skeleton.has_edge(i, j):
                            continue

                        adj_i = set(skeleton.neighbors(i)) - {j}

                        if len(adj_i) >= d:
                            for subset in itertools.combinations(adj_i, d):
                                subset_list = list(subset)

                                conditioning_set = (
                                    data[:, subset_list]
                                    if subset_list else None
                                )

                                p_val = gin_test(
                                    data[:, i],
                                    data[:, j],
                                    conditioning_set.reshape(data.shape[0], -1)
                                    if conditioning_set is not None else None,
                                    alpha=alpha
                                )

                                if p_val > alpha:
                                    skeleton.remove_edge(i, j)
                                    separating_sets[(i, j)] = subset
                                    separating_sets[(j, i)] = subset
                                    break

                # Create a directed acyclic graph (DAG) from the skeleton.
                dag = nx.DiGraph()
                dag.add_nodes_from(range(n))

                for i, j in skeleton.edges():
                    if (
                        np.isnan(constraint_matrix[i, j])
                        and not np.isnan(constraint_matrix[j, i])
                    ):
                        dag.add_edge(i, j)

                    elif (
                        np.isnan(constraint_matrix[j, i])
                        and not np.isnan(constraint_matrix[i, j])
                    ):
                        dag.add_edge(j, i)

                    else:
                        dag.add_edge(i, j)
                        dag.add_edge(j, i)

                # Orient edges based on separating sets.
                for i in range(n):
                    for j in range(n):
                        if i == j or not dag.has_edge(i, j):
                            continue

                        for k in range(n):
                            if k == i or k == j:
                                continue

                            if dag.has_edge(k, j) and not skeleton.has_edge(i, k):
                                if (
                                    (i, k) in separating_sets
                                    and j not in separating_sets[(i, k)]
                                ) or (
                                    (k, i) in separating_sets
                                    and j not in separating_sets[(k, i)]
                                ):
                                    if dag.has_edge(j, i):
                                        dag.remove_edge(j, i)

                                    if dag.has_edge(j, k):
                                        dag.remove_edge(j, k)

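                # Remove conflicting edges based on constraint matrix.
                for i, j in list(dag.edges()):
                    # Both directions must still be present. Without the first
                    # check, the stale reverse entry in the edge list would
                    # delete the direction that was kept earlier in this loop.
                    if dag.has_edge(i, j) and dag.has_edge(j, i):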
                        if (
                            np.isnan(constraint_matrix[i, j])
                            and not np.isnan(constraint_matrix[j, i])
                        ):
                            dag.remove_edge(j, i)

                        elif (
                            np.isnan(constraint_matrix[j, i])
                            and not np.isnan(constraint_matrix[i, j])
                        ):
                            dag.remove_edge(i, j)

                        else:
                            dag.remove_edge(j, i)

                return dag

            # Run the PC-GIN algorithm.
            dag = pc_gin(
                encoded_data.values,
                constraint_matrix,
                alpha=0.05
            )

            # Create adjacency matrix from DAG.
            adj_matrix = nx.to_numpy_array(
                dag,
                nodelist=range(len(node_names))
            )

            # Validate constraints.
            violations = validate_constraints(
                dag,
                node_name_to_idx,
                tiers
            )

            # Visualize and save results.
            G_viz = visualize_causal_graph(
                dag,
                node_names,
                os.path.join(
                    output_dir,
                    "pcgin_graph.png"
                )
            )

            save_relations_to_text(
                dag,
                node_names,
                os.path.join(
                    output_dir,
                    "pcgin_relations.txt"
                )
            )

            # Analyze structure.
            metrics = analyze_structure_learning(
                dag,
                node_names
            )

            return {
                "dag": dag,
                "adj_matrix": adj_matrix,
                "metrics": metrics,
                "violations": violations
            }

        # Log and raise exception if an error occurs.
        except Exception as e:
            logger.error(
                "PC-GIN failed: %s",
                str(e)
            )
            raise

# 5.5. Define the NOTEARS Algorithm

Integrates NOTEARS (Non-combinatorial Optimization via Trace Exponential and Augmented lagRangian for Structure learning) with soft constraint enforcement. The implementation focuses on constrained matrix optimization to produce acyclic structures, followed by graph conversion, evaluation, and visualization.

## Methodology

### Purpose and Context

The NOTEARS algorithm is a continuous optimization approach to learning a directed acyclic graph (DAG) from observational data. It reformulates the combinatorial structure learning problem as a smooth constrained optimization problem, enabling the use of gradient-based solvers. The key innovation lies in enforcing the acyclicity constraint via a differentiable function, allowing efficient optimization over weighted adjacency matrices.

### Inputs and Parameters

Given a standardized data matrix $X \in \mathbb{R}^{n \times d}$ with $n$ samples and $d$ variables, the algorithm utilizes the following:

- $X$: Standardized input data matrix.
- $\mathcal{C} \in \{0, 1, \text{NaN}\}^{d \times d}$: Constraint matrix where:
  - $0$ forbids an edge,
  - $\text{NaN}$ allows edge inference.
- $\lambda_1$: Regularization parameter for L1 penalty.
- $h_{\text{tol}}$: Tolerance for acyclicity constraint.
- $\rho_{\max}$: Maximum penalty for augmented Lagrangian.
- $w_{\text{thresh}}$: Threshold for edge weight post-processing.
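
As a small illustration of the constraint convention, the sketch below builds a toy constraint matrix and reads the forbidden edges off its zero entries. The node names and matrix values are hypothetical, and entry $[i, j]$ is assumed to refer to the directed edge from node $i$ to node $j$.

```python
import numpy as np

# Hypothetical three-variable example; the names are illustrative only.
node_names = ["Age", "Contract", "Churn"]

# NaN leaves an edge free to be learned; 0 forbids it.
# Entry [i, j] refers to the directed edge node_names[i] -> node_names[j].
C = np.full((3, 3), np.nan)
np.fill_diagonal(C, 0.0)   # no self-loops
C[1, 0] = 0.0              # Contract -> Age forbidden
C[2, 0] = 0.0              # Churn -> Age forbidden
C[2, 1] = 0.0              # Churn -> Contract forbidden

# Collect forbidden (source, target) pairs, skipping the diagonal.
forbidden = [
    (node_names[i], node_names[j])
    for i, j in zip(*np.where(C == 0.0))
    if i != j
]

# Boolean mask of entries the optimizer is allowed to change.
allowed_mask = np.isnan(C)

print(forbidden)
print(int(allowed_mask.sum()))
```

A list of this shape is what a tabu-edge argument would typically consume, while the Boolean mask is convenient when the weights are optimized directly.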

### Algorithmic Procedure

1. **Data Standardization**:
   - Normalize each column of $X$ to zero mean and unit variance:
     $$
     X_{\text{standardized}} = \frac{X - \mu}{\sigma}
     $$

2. **Acyclicity Constraint Function**:
   - For weight matrix $W \in \mathbb{R}^{d \times d}$, define:
     $$
     h(W) = \text{tr} \left( \exp(W \circ W / d) \right) - d
     $$
     where $\circ$ denotes the Hadamard product. This function satisfies $h(W) = 0$ iff $W$ corresponds to a DAG.

3. **Objective Function**:
   - Define the loss as:
     $$
     \mathcal{L}(W) = \frac{1}{2n} \| X - XW \|_F^2 + \lambda_1 \| W \|_1
     $$
     subject to $h(W) = 0$ and $W_{ij} = 0$ if $\mathcal{C}_{ij} = 0$.

4. **Optimization via Augmented Lagrangian**:
   - Solve:
     $$
     \min_W \mathcal{L}(W) + \frac{\rho}{2} h(W)^2 + \alpha h(W)
     $$
     using the L-BFGS-B method, with updates to the penalty parameter $\rho$ and the multiplier $\alpha$ after each solve:
     - If $|h(W^{(t)})| > h_{\text{tol}}$, increase $\rho$ (up to $\rho_{\max}$).
     - Update $\alpha \leftarrow \alpha + \rho h(W^{(t)})$.

5. **Post-Processing**:
   - Threshold weights to induce sparsity:
     $$
     W_{ij} = 0 \text{ if } |W_{ij}| < w_{\text{thresh}}
     $$
   - Construct graph $G = (V, E)$ with edges from non-zero entries in $W$.

6. **Cycle Removal**:
   - Iteratively remove the weakest edge in detected cycles to ensure acyclicity:
     - Identify a cycle $C \subset G$.
     - Remove edge $(u, v) \in C$ with minimal $|W_{uv}|$.
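
The acyclicity function and the thresholding step can be tried on a toy example. The sketch below evaluates $h(W)$ on a small acyclic weight matrix and on the same matrix with one edge added to close a cycle, then prunes weak weights. The weights and the threshold are hypothetical values chosen for illustration.

```python
import numpy as np
from scipy.linalg import expm


def acyclicity(W):
    """Return h(W) = tr(exp(W * W / d)) - d, with * the elementwise product."""
    d = W.shape[0]
    return np.trace(expm(W * W / d)) - d


# Toy weights: the chain 0 -> 1 -> 2 is acyclic.
W_dag = np.array([
    [0.0, 0.8, 0.0],
    [0.0, 0.0, -0.5],
    [0.0, 0.0, 0.0],
])

# Adding the edge 2 -> 0 closes a directed cycle.
W_cyc = W_dag.copy()
W_cyc[2, 0] = 0.6

h_dag = acyclicity(W_dag)   # numerically zero for an acyclic matrix
h_cyc = acyclicity(W_cyc)   # strictly positive once a cycle exists

# Post-processing: zero out weights whose magnitude is below the threshold.
w_thresh = 0.55
W_pruned = np.where(np.abs(W_dag) < w_thresh, 0.0, W_dag)

print(h_dag, h_cyc)
print(W_pruned)
```

The value of $h$ grows with the weight mass carried by cycles, which is what lets a gradient-based solver push the estimate toward an acyclic matrix.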

### Theoretical Justification

NOTEARS is grounded in an acyclicity characterization based on properties of the matrix exponential. By expressing the DAG constraint as a smooth function, the method enables direct optimization over real-valued adjacency matrices without resorting to discrete search. Because the resulting problem is nonconvex, the augmented Lagrangian method can be expected to reach a stationary point at which the acyclicity constraint holds up to the tolerance $h_{\text{tol}}$, rather than a guaranteed global optimum. L1 regularization induces sparsity, promoting interpretability of the resulting causal graph.

### Outcomes and Limitations

The algorithm outputs a DAG $G$, adjacency matrix $W \in \mathbb{R}^{d \times d}$, tier violation list, and structural metrics (e.g., edge count and graph density). Its computational cost is dominated by the matrix exponential and L-BFGS optimization, typically $O(d^3)$ per iteration. NOTEARS assumes linear relationships and may be sensitive to noise or model misspecification. Additionally, because acyclicity is enforced only up to a numerical tolerance, the thresholded result may still contain cycles that require post-hoc correction.

In [44]:
# Define the NOTEARSAlgorithm class.
# Names required from earlier cells: os, np, pd, nx, plt (matplotlib.pyplot),
# sns (seaborn), StandardScaler (sklearn.preprocessing), notears
# (causalnex.structure), CausalDiscoveryAlgorithm, visualize_causal_graph,
# save_relations_to_text, analyze_structure_learning, validate_constraints,
# and logger.

class NOTEARSAlgorithm(CausalDiscoveryAlgorithm):
    def fit(self, data, constraint_matrix, node_names, node_name_to_idx, tiers, output_dir):
        logger.info("Running NOTEARS algorithm using CausalNex...")
        try:
            # Prepare DataFrame for NOTEARS
            if isinstance(data, pd.DataFrame):
                df = data.copy()
            else:
                df = pd.DataFrame(data, columns=node_names)

            # Standardize data column-wise, handling zero-variance columns
            scaler = StandardScaler()
            df_standardized = pd.DataFrame(index=df.index, columns=df.columns, dtype=float)

            for col in df.columns:
                if df[col].std() == 0:
                    logger.warning(
                        f"NOTEARS: Column '{col}' has zero std. deviation. Standardized values set to 0."
                    )
                    df_standardized[col] = 0.0
                else:
                    df_standardized[col] = scaler.fit_transform(df[[col]]).flatten()
            
            # Final check to ensure all data is finite for CausalNex
            if not np.all(np.isfinite(df_standardized.values)):
                logger.warning(
                    "NOTEARS: Non-finite values in standardized data. Replacing with 0."
                )
                df_standardized.replace([np.inf, -np.inf, np.nan], 0, inplace=True)

            # Define NOTEARS hyperparameters
            max_iter = 200
            h_tol = 1e-8
            w_threshold = 0.2  # Primary threshold for CausalNex's NOTEARS edge pruning

            # Prepare tabu edges (forbidden edges) from the constraint matrix
            tabu_edges = []
            if constraint_matrix is not None:
                for i in range(constraint_matrix.shape[0]):
                    for j in range(constraint_matrix.shape[1]):
                        if constraint_matrix[i, j] == 0.0: # 0.0 indicates a forbidden edge
                            tabu_edges.append((node_names[i], node_names[j]))

            logger.info(f"NOTEARS: Running with {len(tabu_edges)} tabu edges...")
            
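            # Learn causal structure using CausalNex's NOTEARS implementation.
            # Note: from_pandas applies no L1 penalty; CausalNex provides
            # from_pandas_lasso (with a beta weight) for the regularized variant.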
            structure_model = notears.from_pandas(
                df_standardized,
                max_iter=max_iter,
                h_tol=h_tol,
                w_threshold=w_threshold, # CausalNex uses this to prune edges in the returned model
                tabu_edges=tabu_edges if tabu_edges else None
            )
            
            # Extract the weighted adjacency matrix from the learned CausalNex model
            # This reflects edges and weights after CausalNex's internal thresholding.
            initial_weighted_adj_matrix_df = nx.to_pandas_adjacency(
                structure_model, nodelist=node_names, weight="weight"
            )
            initial_weighted_adj_matrix = initial_weighted_adj_matrix_df.fillna(0).values

            # --- 1. Heatmap of the Learned Weighted Adjacency Matrix ---
            plt.figure(figsize=(14, 12), dpi=300)
            sns.heatmap(initial_weighted_adj_matrix, xticklabels=node_names, yticklabels=node_names, 
                        cmap="viridis", annot=True, fmt=".2f", annot_kws={"size": 6})
            plt.title("NOTEARS - Learned Weighted Adjacency Matrix (after CausalNex thresholding)")
            heatmap_filename = os.path.join(output_dir, "notears_adj_matrix_heatmap.png")
            plt.savefig(heatmap_filename, bbox_inches='tight')
            plt.close()
            logger.info(f"NOTEARS adjacency matrix heatmap saved as {heatmap_filename}")

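            # --- 2. Generate Causal Graphs at Different Additional Thresholds ---
            # These thresholds are applied on top of 'initial_weighted_adj_matrix',
            # which CausalNex has already pruned at w_threshold. Values at or below
            # w_threshold therefore reproduce the same graph; only larger values
            # prune further.
            thresholds_to_explore = [0.05, 0.15, 0.25, 0.3] # Customize as needed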
            logger.info(f"NOTEARS: Generating graphs for varying additional thresholds: {thresholds_to_explore}")

            for thresh_val in thresholds_to_explore:
                # Apply current additional threshold to the (absolute) weights
                adj_matrix_loop = np.where(
                    np.abs(initial_weighted_adj_matrix) > thresh_val, 
                    initial_weighted_adj_matrix, 
                    0
                )
                G_loop_for_viz = nx.DiGraph(adj_matrix_loop) # Integer-indexed graph
                
                loop_graph_filename = os.path.join(output_dir, f"notears_graph_additional_thresh_{thresh_val:.2f}.png")
                visualize_causal_graph(G_loop_for_viz, node_names, loop_graph_filename)
                logger.info(f"NOTEARS graph for additional threshold {thresh_val:.2f} saved as {loop_graph_filename}")

                loop_relations_filename = os.path.join(output_dir, f"notears_relations_additional_thresh_{thresh_val:.2f}.txt")
                save_relations_to_text(adj_matrix_loop, node_names, loop_relations_filename, threshold=0.01)

            # --- Primary Graph Generation & Analysis (using the main w_threshold) ---
            # This ensures the primary output is based on the main CausalNex threshold consistently.
            adj_matrix_est = np.where(
                np.abs(initial_weighted_adj_matrix) > w_threshold, 
                initial_weighted_adj_matrix, 
                0
            )
            
            violations = validate_constraints(
                adj_matrix_est, 
                node_name_to_idx, 
                tiers, 
                threshold=w_threshold / 2 # Threshold for validate_constraints (if it expects probabilities)
            )
            
            G_primary_int_indexed = nx.DiGraph(adj_matrix_est) # Integer-indexed graph
            
            # Visualize and save the primary graph (visualize_causal_graph returns the named graph)
            G_viz_primary_named = visualize_causal_graph(
                G_primary_int_indexed, 
                node_names, 
                os.path.join(output_dir, "notears_graph_primary.png")
            )
            
            # Save primary relations, passing an explicit threshold of w_threshold / 2
            save_relations_to_text(
                adj_matrix_est, 
                node_names, 
                os.path.join(output_dir, "notears_relations_primary.txt"), 
                threshold=w_threshold / 2 
            )
            
            # Analyze the structure of the primary graph
            metrics = analyze_structure_learning(
                G_viz_primary_named, # Pass the named graph
                node_names, 
                threshold=0.0, # Graph is already thresholded at w_threshold (weights are kept)
                output_dir=output_dir,
                algo_prefix="notears"
            )
            
            logger.info(f"NOTEARS completed. Found {np.sum(adj_matrix_est != 0)} edges for primary graph.")
            
            return {
                "dag": G_viz_primary_named,          # The primary named graph
                "adj_matrix": adj_matrix_est,        # The primary thresholded adjacency matrix
                "structure_model": structure_model,  # Original CausalNex model
                "metrics": metrics,
                "violations": violations
            }

        except Exception as e:
            # logger.exception logs the message together with the full traceback.
            logger.exception("NOTEARS (CausalNex) failed: %s", e)
            raise

# 5.6. Define the GRaSP Algorithm

Implements GRaSP, expanded in this notebook as Gradient-based Regularized Structure learning with Penalties, using a custom acyclicity-constrained optimization scheme. The method identifies sparse DAGs by minimizing loss and penalty terms over standardized data, respecting constraint masks throughout optimization.

## Methodology

### Purpose and Context

GRaSP, as defined in this notebook, is a score-based causal discovery method that extends NOTEARS by embedding structure- and residual-based penalties within a constrained optimization framework. It should not be confused with the permutation-based GRaSP (Greedy Relaxations of the Sparsest Permutation) from the causal discovery literature, which works differently. The method is designed to enforce acyclicity while accommodating domain-specific constraints. Like NOTEARS, it uses a continuous optimization formulation; here the loss, the L1 penalty, and the augmented Lagrangian terms are implemented explicitly so that the constraint mask can be applied at every step.

### Inputs and Parameters

Let $X \in \mathbb{R}^{n \times d}$ represent a dataset of $n$ samples across $d$ standardized variables. The algorithm is parameterized as follows:

- $X$: Standardized input data matrix.
- $\mathcal{C} \in \{0, 1, \text{NaN}\}^{d \times d}$: Constraint matrix, where:
  - $0$ indicates forbidden edges,
  - $\text{NaN}$ allows free optimization.
- $\lambda_1$: L1 regularization parameter enforcing sparsity.
- $h_{\text{tol}}$: Tolerance threshold for enforcing acyclicity.
- $\rho_{\max}$: Maximum penalty for the augmented Lagrangian term.
- $w_{\text{thresh}}$: Threshold applied to final edge weights.
- $\mathcal{T}$: Tier structure for constraint validation.
- $\phi: \mathcal{N} \rightarrow \{1, \dots, d\}$: Node-to-index map.

### Algorithmic Procedure

1. **Data Standardization**:
   - For each variable $x_j$, compute:
     $$
     x_j' = \frac{x_j - \mu_j}{\sigma_j}, \quad \text{if } \sigma_j > 0
     $$
     or $x_j' = 0$ if $\sigma_j = 0$.

2. **Acyclicity Characterization**:
   - Define the smooth acyclicity constraint:
     $$
     h(W) = \text{tr}\left(\exp(W \circ W / d)\right) - d
     $$
     where $W \in \mathbb{R}^{d \times d}$ is the weighted adjacency matrix.

3. **Penalized Loss Function**:
   - The optimization objective combines squared loss, L1 regularization, and acyclicity terms:
     $$
     \mathcal{L}(W) = \frac{1}{2n} \|X - XW\|_F^2 + \lambda_1 \|W\|_1 + \frac{\rho}{2} h(W)^2 + \alpha h(W)
     $$
   - Masking is applied such that $W_{ij} = 0$ where $\mathcal{C}_{ij} = 0$.

4. **Gradient Computation**:
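   - Gradients are computed for each term of the objective:
     - Residual loss: $-\frac{1}{n} X^\top (X - XW)$
     - L1 regularization: $\lambda_1 \cdot \text{sign}(W)$ (a subgradient, since $\|W\|_1$ is not differentiable at zero)
     - Acyclicity: $(\rho h(W) + \alpha) \cdot \nabla h(W)$, with $\nabla h(W) = \exp(W \circ W / d)^\top \circ \frac{2W}{d}$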

5. **Augmented Lagrangian Optimization**:
   - Use L-BFGS-B to iteratively minimize $\mathcal{L}(W)$.
   - Update multipliers:
     $$
     \rho \leftarrow \min(\rho \cdot 10, \rho_{\max}), \quad \alpha \leftarrow \alpha + \rho \cdot h(W)
     $$
     until $|h(W)| \leq h_{\text{tol}}$.

6. **Post-Processing**:
   - Threshold small weights: $W_{ij} = 0$ if $|W_{ij}| < w_{\text{thresh}}$.
   - Construct graph $G = (V, E)$ from non-zero $W_{ij}$.

7. **Cycle Removal**:
   - Ensure $G$ is a DAG by iteratively:
     - Detecting cycles,
     - Removing the edge with the smallest absolute weight within the cycle.
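
The cycle-removal step can be sketched on its own. The example below repeatedly finds one directed cycle and deletes its weakest edge until the graph is acyclic. The helper name `break_cycles` and the toy weights are hypothetical.

```python
import networkx as nx
import numpy as np


def break_cycles(W):
    """Return a copy of W with the weakest edge of each detected cycle removed."""
    W = W.copy()
    G = nx.from_numpy_array(W, create_using=nx.DiGraph)
    while not nx.is_directed_acyclic_graph(G):
        # find_cycle returns the (u, v) edges of one directed cycle.
        cycle = nx.find_cycle(G)
        u, v = min(cycle, key=lambda e: abs(W[e[0], e[1]]))
        G.remove_edge(u, v)
        W[u, v] = 0.0
    return W


# Toy weights forming the cycle 0 -> 1 -> 2 -> 0; the edge 2 -> 0 is weakest.
W = np.array([
    [0.0, 0.9, 0.0],
    [0.0, 0.0, 0.7],
    [0.3, 0.0, 0.0],
])
W_acyclic = break_cycles(W)
print(W_acyclic)
```

Removing the smallest-magnitude edge is a heuristic: it keeps the strongest dependencies but does not guarantee that the fewest possible edges are deleted overall.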

### Theoretical Justification

GRaSP maintains the theoretical foundation of continuous DAG learning by enforcing a differentiable acyclicity constraint. The use of augmented Lagrangian multipliers allows for a flexible enforcement of structural validity. The combination of L1 sparsity, Frobenius norm loss, and cycle penalization encourages sparse, data-consistent graphs. The least-squares loss is convex in $W$ and $h(W)$ is smooth, but the constrained problem as a whole is nonconvex, so the optimizer can be expected to reach a stationary point rather than a global optimum.
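
The differentiability of $h(W)$ can be checked numerically. The sketch below compares a closed-form gradient, $\exp(W \circ W / d)^\top \circ 2W/d$ (obtained by the chain rule), with central finite differences on a random matrix. The matrix size, seed, and step size are arbitrary choices for illustration.

```python
import numpy as np
from scipy.linalg import expm


def h(W):
    """Acyclicity function tr(exp(W * W / d)) - d."""
    d = W.shape[0]
    return np.trace(expm(W * W / d)) - d


def h_grad(W):
    """Closed-form gradient: exp(W * W / d)^T * (2 W / d), elementwise."""
    d = W.shape[0]
    return expm(W * W / d).T * (2.0 * W / d)


rng = np.random.default_rng(0)
d = 4
W = rng.normal(scale=0.5, size=(d, d))

# Central finite differences, one entry at a time.
eps = 1e-6
numeric = np.zeros_like(W)
for i in range(d):
    for j in range(d):
        step = np.zeros_like(W)
        step[i, j] = eps
        numeric[i, j] = (h(W + step) - h(W - step)) / (2 * eps)

max_err = np.max(np.abs(numeric - h_grad(W)))
print(max_err)
```

A check of this kind is a cheap way to catch a mismatch between the function and the gradient handed to L-BFGS-B, which otherwise tends to show up only as slow or stalled convergence.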

### Outcomes and Limitations

The algorithm outputs a DAG $G$, adjacency matrix $W \in \mathbb{R}^{d \times d}$, constraint violations, and structure learning metrics (e.g., edge count and graph density). Its complexity is primarily governed by L-BFGS iterations and matrix exponentials, typically $O(d^3)$ per iteration. The main limitations include sensitivity to regularization tuning, assumption of linearity in structural equations, and approximate enforcement of acyclicity that may necessitate post hoc corrections.

In [45]:
# Define the GRaSPAlgorithm class.
class GRaSPAlgorithm(CausalDiscoveryAlgorithm):

    # Define the fit method for the GRaSPAlgorithm class.
    def fit(
        self,
        data,
        constraint_matrix,
        node_names,
        node_name_to_idx,
        tiers,
        output_dir
    ):
        # Log the start of the GRaSP algorithm.
        logger.info("Running GRaSP algorithm...")

        # Handle potential errors.
        try:
            # Extract values from the data.
            X = data.values

            # Compute standard deviations and means.
            stds = np.std(X, axis=0)
            means = np.mean(X, axis=0)

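            # Standardize the data; zero-variance columns are set to 0.
            # A safe divisor avoids divide-by-zero warnings for those columns.
            safe_stds = np.where(stds != 0, stds, 1.0)
            X_standardized = np.where(
                stds != 0,
                (X - means) / safe_stds,
                0
            )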

            # Define the GRaSP algorithm with constraints.
            def grasp_with_constraints(
                X,
                constraint_matrix,
                lambda1=0.01,
                max_iter=200,
                h_tol=1e-8,
                rho_max=1e+16,
                w_threshold=0.1
            ):
                # Initialize dimensions and mask.
                # mask is 1.0 where the constraint matrix holds a value (edge
                # constrained) and 0.0 where it is NaN (edge free), so
                # multiplying by (1.0 - mask) zeroes every constrained entry.
                n, d = X.shape
                mask = 1.0 - np.isnan(constraint_matrix).astype(float)

                # Define the acyclicity function h(W) = tr(exp(W * W / d)) - d.
                # The exponent must not include the identity matrix; otherwise
                # h(0) = d * (e - 1) and the constraint can never reach zero.
                def _h(w):
                    W = w.reshape((d, d))
                    M = W * W / d
                    return np.trace(slin.expm(M)) - d

                # Define the loss function with penalties.
                def _func(w, rho, alpha):
                    W = w.reshape((d, d))
                    W = W * (1.0 - mask)
                    R = X - X @ W

                    loss = 0.5 / n * np.sum(R * R)
                    l1_penalty = lambda1 * np.sum(np.abs(W))
                    h_val = _h(w)

                    return (
                        loss
                        + l1_penalty
                        + 0.5 * rho * h_val ** 2
                        + alpha * h_val
                    )

                # Define the gradient of the loss function.
                def _grad(w, rho, alpha):
                    W = w.reshape((d, d))
                    W = W * (1.0 - mask)
                    R = X - X @ W

                    G_loss = -1.0 / n * X.T @ R
                    G_l1 = lambda1 * np.sign(W)

                    h_val = _h(w)
                    h_gradient = _h_grad(w).reshape((d, d))
                    G_acyclicity = (rho * h_val + alpha) * h_gradient

                    G = (G_loss + G_l1 + G_acyclicity) * (1.0 - mask)
                    return G.flatten()

                # Define the gradient of the acyclicity function:
                # grad h(W) = exp(W * W / d)^T * (2 W / d), elementwise.
                def _h_grad(w):
                    W = w.reshape((d, d))
                    M = W * W / d
                    E = slin.expm(M)

                    G = E.T * (2 * W / d)
                    G = G * (1.0 - mask)

                    return G.flatten()

                # Initialize optimization parameters.
                w_est = np.zeros(d * d)
                rho, alpha, h = 1.0, 0.0, np.inf

                # Run the augmented Lagrangian loop.
                for _ in range(max_iter):
                    w_new = sopt.minimize(
                        lambda w: _func(w, rho, alpha),
                        w_est,
                        method="L-BFGS-B",
                        jac=lambda w: _grad(w, rho, alpha),
                        options={
                            "ftol": 1e-6,
                            "gtol": 1e-6
                        }
                    ).x

                    h_new = _h(w_new)

                    # Increase rho if the acyclicity violation did not shrink enough.
                    if abs(h_new) > 0.25 * abs(h):
                        rho *= 10

                    # Accept the new iterate before testing for termination so
                    # that the final solution is not discarded on break.
                    w_est, h = w_new, h_new
                    alpha += rho * h

                    # Stop once the tolerance or the maximum rho is reached.
                    if abs(h) <= h_tol or rho >= rho_max:
                        break

                # Reshape and threshold the estimated weights.
                W_est = w_est.reshape((d, d))
                W_est = W_est * (1.0 - mask)
                W_est[np.abs(W_est) < w_threshold] = 0

                # Create a directed graph from the weight matrix.
                G = nx.DiGraph(W_est)

                # Remove cycles from the graph.
                while not nx.is_directed_acyclic_graph(G):
                    # Handle potential cycles in the graph.
                    try:
                        cycle = nx.find_cycle(G)
                        min_weight = float("inf")
                        min_edge = None

                        # Identify the edge with minimum weight.
                        for u, v in cycle:
                            if abs(W_est[u, v]) < min_weight:
                                min_weight = abs(W_est[u, v])
                                min_edge = (u, v)

                        # Remove the weakest edge in the cycle.
                        if min_edge:
                            G.remove_edge(*min_edge)
                            W_est[min_edge[0], min_edge[1]] = 0

                    except nx.NetworkXNoCycle:
                        break

                return W_est

            # Run GRaSP with constraints.
            adj_matrix = grasp_with_constraints(
                X_standardized,
                constraint_matrix
            )

            # Validate constraints.
            violations = validate_constraints(
                adj_matrix,
                node_name_to_idx,
                tiers
            )

            # Create a directed graph for visualization.
            G = nx.DiGraph(adj_matrix)

            # Visualize and save results.
            G_viz = visualize_causal_graph(
                G,
                node_names,
                os.path.join(output_dir, "grasp_graph.png")
            )

            save_relations_to_text(
                adj_matrix,
                node_names,
                os.path.join(output_dir, "grasp_relations.txt")
            )

            # Analyze structure.
            metrics = analyze_structure_learning(
                adj_matrix,
                node_names
            )

            return {
                "dag": G,
                "adj_matrix": adj_matrix,
                "metrics": metrics,
                "violations": violations
            }

        # Log and raise exceptions.
        except Exception as e:
            logger.error("GRaSP failed: %s", str(e))
            raise

# 6. Prepare the Causal Discovery Pipeline

This section defines the high-level execution pipeline that orchestrates causal discovery across multiple algorithms. It initializes constraints, prepares the modeling environment, invokes each algorithm's fit() method, captures and logs outputs, and compiles a comparative summary of structural metrics and constraint adherence.

## 6.1. Define the Causal Discovery Pipeline Function

The central orchestration function that drives the end-to-end pipeline. It applies preprocessing, initializes the constraint matrix, executes all registered causal discovery algorithms, handles errors, and compiles a results dictionary and summary table, enabling comprehensive comparative analysis.
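
The summary step described above can be sketched in isolation. The example below collapses a results dictionary into a comparison table and falls back to "N/A" for an algorithm that failed. The algorithm names, metric values, and the `summarize` helper are hypothetical.

```python
import pandas as pd

# Hypothetical results; real entries would come from each algorithm's fit().
results = {
    "AlgoA": {
        "metrics": {"num_edges": 12, "graph_density": 0.08},
        "violations": [],
    },
    "AlgoB": {"error": "did not converge"},
}


def summarize(results):
    """Build one summary row per algorithm, tolerating failed runs."""
    rows = {}
    for name, res in results.items():
        metrics = res.get("metrics") or {}
        rows[name] = {
            "num_edges": metrics.get("num_edges", "N/A"),
            "graph_density": metrics.get("graph_density", "N/A"),
            "violations": len(res["violations"]) if "violations" in res else "N/A",
        }
    return pd.DataFrame.from_dict(rows, orient="index")


summary = summarize(results)
print(summary)
```

Recording a failure as an `{"error": ...}` entry rather than aborting lets the remaining algorithms run and keeps the failed one visible in the comparison.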

In [46]:
def run_causal_discovery_pipeline(
    train_data,
    tiers,
    specific_constraints=None,
    output_dir="causal_discovery_output"  # This will be the main parent directory
):
    """
    Run the causal discovery pipeline for all algorithms.
    Each algorithm's output will be saved in a dedicated sub-folder.

    Args:
        train_data (pd.DataFrame): Training data.
        tiers (list): List of tier lists.
        specific_constraints (dict): Additional constraints.
        output_dir (str): Parent directory to save output files.

    Returns:
        dict: Results for each algorithm.
    """
    logger.info("Starting causal discovery pipeline...")

    # Create the main output directory if it doesn't exist
    main_parent_output_dir = output_dir 
    os.makedirs(main_parent_output_dir, exist_ok=True)

    node_names = list(train_data.columns)
    constraint_matrix, node_name_to_idx = create_constraint_matrix(
        node_names,
        tiers,
        specific_constraints
    )

    algorithms = {
        "DECI": DECIAlgorithm(),
        "LiNGAM": LiNGAMAlgorithm(),
        "NOTEARS": NOTEARSAlgorithm(),
        # "PC-GIN": PCGINAlgorithm(), # Uncomment if you want to run these
        # "GRaSP": GRaSPAlgorithm()  # Uncomment if you want to run these
    }

    results = {}

    for algo_name, algo in algorithms.items():
        try:
            logger.info("Executing %s...", algo_name)
            
            # --- Create a specific sub-directory for the current algorithm ---
            algo_specific_output_dir = os.path.join(main_parent_output_dir, algo_name)
            os.makedirs(algo_specific_output_dir, exist_ok=True)
            # --- End of sub-directory creation ---

            result = algo.fit(
                train_data,
                constraint_matrix,
                node_names,
                node_name_to_idx,
                tiers,
                algo_specific_output_dir # Pass the algorithm-specific directory
            )
            results[algo_name] = result

        except Exception as e:
            logger.error("%s failed: %s", algo_name, str(e))
            results[algo_name] = {"error": str(e)}

    # Build one summary row per algorithm; failed runs (error entries) show "N/A".
    summary = pd.DataFrame({
        algo_name: {
            "num_edges": result["metrics"]["num_edges"] if result.get("metrics") else "N/A",
            "graph_density": result["metrics"]["graph_density"] if result.get("metrics") else "N/A",
            "violations": len(result["violations"]) if "violations" in result else "N/A"
        } for algo_name, result in results.items()
    }).T

    logger.info("Summary of results:\n%s", summary)
    return results
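
The loop above isolates failures per algorithm: an exception raised by one `fit` call is logged and stored as an `{"error": ...}` entry, so the remaining algorithms still run and the summary falls back to "N/A" for the failed one. The following minimal sketch reproduces that pattern without pandas or file output. `WorkingAlgo`, `FailingAlgo`, `run_all`, and `summarize` are hypothetical stand-ins for illustration and are not part of the notebook.

```python
# Illustrative sketch only: WorkingAlgo and FailingAlgo are hypothetical stand-ins,
# not the notebook's DECI/LiNGAM/NOTEARS wrappers.

class WorkingAlgo:
    def fit(self):
        # Mimic the result layout that the pipeline's summary step reads.
        return {"metrics": {"num_edges": 3, "graph_density": "0.500"}, "violations": []}


class FailingAlgo:
    def fit(self):
        raise RuntimeError("solver did not converge")


def run_all(algorithms):
    """Run every algorithm, recording an error entry instead of aborting."""
    results = {}
    for name, algo in algorithms.items():
        try:
            results[name] = algo.fit()
        except Exception as exc:
            results[name] = {"error": str(exc)}
    return results


def summarize(results):
    """Build one summary row per algorithm, using "N/A" where data is missing."""
    summary = {}
    for name, result in results.items():
        metrics = result.get("metrics") or {}
        summary[name] = {
            "num_edges": metrics.get("num_edges", "N/A"),
            "graph_density": metrics.get("graph_density", "N/A"),
            "violations": len(result["violations"]) if "violations" in result else "N/A",
        }
    return summary


demo_results = run_all({"ok": WorkingAlgo(), "broken": FailingAlgo()})
demo_summary = summarize(demo_results)
print(demo_summary)
```

Catching the exception inside the loop trades early failure for completeness: a comparative run still produces rows for the algorithms that succeeded, and the error message is kept for later inspection.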

# **D. Let the Machine Learn the Causal Graphs and Infer Them**

This section transitions from defining preparatory functions to applying them for causal structure discovery. It initiates the actual learning of causal graphs from data by first categorizing features into semantically meaningful tiers, imposing domain-informed constraints, and then running the complete causal discovery pipeline. The output includes both learned causal graphs and a validation of these graphs against the constraints using unseen test data.

# 7. Define the Tiers and Constraints, then Learn the Causal Graphs

This subsection lays the foundation for constraint-aware causal discovery by explicitly defining the hierarchical tiers of variables and specifying inter-variable restrictions. These configurations guide the structure learning process, ensuring the inferred graphs are both statistically plausible and domain-compliant.

## 7.1. Define Tiers

The tier structure segments variables into five conceptual layers: Demographic, Customer, Service, Billing, and Outcome. These layers encode a temporal or logical flow of causality, where earlier tiers can influence later ones but not vice versa. This hierarchy is used to constrain edge directionality in causal discovery, ensuring consistency with prior knowledge.

In [47]:
# Define the five variable tiers, ordered from earliest (Demographic) to latest (Outcome).
tiers = [
    [
        "Gender",
        "Age",
        "Senior Citizen",
        "Married",
        "Number of Dependents"
    ],  # Tier 1: Demographic

    [
        "Number of Referrals",
        "Tenure in Months",
        "Offer"
    ],  # Tier 2: Customer

    [
        "Phone Service",
        "Multiple Lines",
        "Internet Type",
        "Unlimited Data",
        "Online Security",
        "Online Backup",
        "Device Protection Plan",
        "Premium Tech Support",
        "Streaming TV",
        "Streaming Movies",
        "Streaming Music"
    ],  # Tier 3: Service

    [
        "Total Revenue",
        "Paperless Billing",
        "Payment Method"
    ],  # Tier 4: Billing

    [
        "Churn Label"
    ]  # Tier 5: Outcome
]
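
The tier rule (earlier tiers may influence later ones, but not the reverse) can be checked mechanically against any list of directed edges. The sketch below is a minimal illustration of that idea. `build_tier_index` and `find_tier_violations` are hypothetical helpers, not the notebook's `create_constraint_matrix` or its validation routine, and the sketch assumes that edges within the same tier are allowed.

```python
# Illustrative sketch: these helpers are hypothetical and only demonstrate the
# tier-ordering rule; they are not the notebook's own implementation.

def build_tier_index(tiers):
    """Map each variable name to the index of the tier that contains it."""
    return {name: level for level, tier in enumerate(tiers) for name in tier}


def find_tier_violations(edges, tiers):
    """Return the edges whose source sits in a later tier than their target."""
    tier_of = build_tier_index(tiers)
    return [(src, dst) for src, dst in edges if tier_of[src] > tier_of[dst]]


# A reduced tier list using a few variable names from the cell above.
demo_tiers = [
    ["Married", "Age"],      # Demographic
    ["Tenure in Months"],    # Customer
    ["Multiple Lines"],      # Service
    ["Total Revenue"],       # Billing
    ["Churn Label"],         # Outcome
]

demo_edges = [
    ("Married", "Tenure in Months"),         # earlier -> later: allowed
    ("Tenure in Months", "Multiple Lines"),  # earlier -> later: allowed
    ("Total Revenue", "Age"),                # later -> earlier: violation
]

violations = find_tier_violations(demo_edges, demo_tiers)
print(violations)  # [('Total Revenue', 'Age')]
```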

## 7.2. Define Specific Constraints

This section defines manually imposed constraints on individual variable pairs. It forbids causal edges from "Gender" to every Service and Billing tier variable, and from "Internet Type" to eight other Service tier variables (unlimited data, security, backup, device protection, tech support, and the three streaming services). No edges are explicitly whitelisted, so the "allowed" list stays empty. These constraints remove relationships that are known to be spurious, which improves the reliability of the learned causal structure.

In [48]:
# Define specific feature-pair constraints including forbidden relationships.
specific_constraints = {
    "forbidden": [
        ("Gender", dst)
        # Disallow connections from "Gender" to all Tier 3 and Tier 4 features.
        for dst in tiers[2] + tiers[3]
    ] + [
        ("Internet Type", dst)
        # Disallow connections from "Internet Type" to the listed Service tier add-ons.
        for dst in [
            "Unlimited Data",
            "Online Security",
            "Online Backup",
            "Device Protection Plan",
            "Premium Tech Support",
            "Streaming TV",
            "Streaming Movies",
            "Streaming Music"
        ]
    ],
    "allowed": []
}
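
Because the forbidden list is built from string literals, a misspelled variable name would pass silently. A minimal sanity check along the following lines can catch that before the pipeline runs. `validate_forbidden_pairs` is a hypothetical helper introduced here for illustration. The Service and Billing lists are copied from the tier definition so that the block runs on its own.

```python
# Illustrative sketch: validate_forbidden_pairs is a hypothetical helper.
# The two lists below are copied from Tier 3 and Tier 4 of the tier definition.
service_tier = [
    "Phone Service", "Multiple Lines", "Internet Type", "Unlimited Data",
    "Online Security", "Online Backup", "Device Protection Plan",
    "Premium Tech Support", "Streaming TV", "Streaming Movies", "Streaming Music",
]
billing_tier = ["Total Revenue", "Paperless Billing", "Payment Method"]
known_names = {"Gender", *service_tier, *billing_tier}

internet_targets = [
    "Unlimited Data", "Online Security", "Online Backup", "Device Protection Plan",
    "Premium Tech Support", "Streaming TV", "Streaming Movies", "Streaming Music",
]

# Same construction as the cell above, written with named lists.
forbidden = (
    [("Gender", dst) for dst in service_tier + billing_tier]
    + [("Internet Type", dst) for dst in internet_targets]
)


def validate_forbidden_pairs(pairs, names):
    """Return a list of problems: unknown names, self-loops, or duplicate pairs."""
    problems = []
    seen = set()
    for src, dst in pairs:
        if src not in names or dst not in names:
            problems.append(f"unknown variable in {(src, dst)}")
        if src == dst:
            problems.append(f"self-loop {(src, dst)}")
        if (src, dst) in seen:
            problems.append(f"duplicate {(src, dst)}")
        seen.add((src, dst))
    return problems


print(len(forbidden))  # 22 = 14 (Gender) + 8 (Internet Type)
print(validate_forbidden_pairs(forbidden, known_names))  # []
```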

## 7.3. Discover the Causality

The final step in this block runs the complete causal discovery pipeline for the three enabled algorithms (DECI, LiNGAM, and NOTEARS) using the tiers and constraints defined above. It then checks each discovered graph for constraint violations on the test data. The result is a side-by-side view of how many edges each algorithm finds and whether those edges respect the domain constraints.
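
When reading the comparison, it helps to know how the reported graph metrics relate to edge and node counts. The notebook's metric function is not shown in this section, so the sketch below is an assumption: it uses the standard directed-graph density E / (N (N - 1)) and average degree E / N, and `graph_metrics` is a hypothetical helper.

```python
# Illustrative sketch: graph_metrics is a hypothetical helper that assumes the
# standard definitions for a directed graph without self-loops.

def graph_metrics(num_edges, num_nodes):
    """Return edge count, density, and average in/out degree as formatted values."""
    max_edges = num_nodes * (num_nodes - 1)
    density = num_edges / max_edges if max_edges else 0.0
    avg_degree = num_edges / num_nodes if num_nodes else 0.0
    return {
        "num_edges": num_edges,
        "graph_density": f"{density:.3f}",
        "avg_in_degree": f"{avg_degree:.3f}",
        "avg_out_degree": f"{avg_degree:.3f}",
    }


print(graph_metrics(18, 23))  # matches the values logged for DECI
print(graph_metrics(14, 23))  # matches the values logged for NOTEARS
print(graph_metrics(7, 5))    # matches the values logged for LiNGAM
```

Under these formulas, the logged DECI and NOTEARS metrics are reproduced with N = 23 (all columns), while the logged LiNGAM metrics are reproduced only with N = 5. This suggests, without confirming, that the LiNGAM metrics were computed over a smaller node set, so the `graph_density` column may not be directly comparable across algorithms.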

In [49]:
# Run the causal discovery pipeline.
results = run_causal_discovery_pipeline(train_data, tiers, specific_constraints, output_dir="causal_discovery_output")

# Evaluate constraint violations on test data.
node_names = list(train_data.columns)
_, node_name_to_idx = create_constraint_matrix(node_names, tiers, specific_constraints)
evaluation_summary = evaluate_causal_discovery(results, test_data, tiers, node_name_to_idx)

INFO:__main__:Starting causal discovery pipeline...
INFO:__main__:Constraint matrix created with shape: (23, 23)
INFO:__main__:Executing DECI...
INFO:__main__:Running DECI algorithm...
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name        | Type                  | Params
------------------------------------------------------
0 | auglag_loss | AugLagLossCalculator  | 0     
1 | sem_module  | SEMDistributionModule | 36.7 K
------------------------------------------------------
36.2 K    Trainable params
529       Non-trainable params
36.7 K    Total params
0.147     Total estimated model params size (MB)
c:\Users\aafz1\miniconda3\envs\causal_env_shahriyar\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_work


=== Causal Relationships ===
Married -> Number of Referrals (weight: 1.000)
---
Married -> Tenure in Months (weight: 1.000)
---
Number of Dependents -> Number of Referrals (weight: 1.000)
---
Tenure in Months -> Multiple Lines (weight: 1.000)
---
Tenure in Months -> Online Security (weight: 1.000)
---
Tenure in Months -> Online Backup (weight: 1.000)
---
Tenure in Months -> Device Protection Plan (weight: 1.000)
---
Tenure in Months -> Premium Tech Support (weight: 1.000)
---
Tenure in Months -> Streaming TV (weight: 1.000)
---
Tenure in Months -> Streaming Movies (weight: 1.000)
---
Tenure in Months -> Streaming Music (weight: 1.000)
---
Multiple Lines -> Total Revenue (weight: 1.000)
---
Internet Type -> Total Revenue (weight: 1.000)
---
Internet Type -> Paperless Billing (weight: 1.000)
---
Unlimited Data -> Paperless Billing (weight: 1.000)
---
Online Security -> Total Revenue (weight: 1.000)
---
Online Backup -> Total Revenue (weight: 1.000)
---
Device Protection Plan -> Total Re

INFO:__main__:Causal graph saved as causal_discovery_output\DECI\deci_graph_thresh_0.60.png
INFO:__main__:DECI graph for threshold 0.60 saved as causal_discovery_output\DECI\deci_graph_thresh_0.60.png
INFO:__main__:Causal relationships saved to causal_discovery_output\DECI\deci_relations_thresh_0.60.txt (12 relations)



=== Causal Relationships ===
Married -> Number of Referrals (weight: 1.000)
---
Married -> Tenure in Months (weight: 1.000)
---
Tenure in Months -> Multiple Lines (weight: 1.000)
---
Tenure in Months -> Online Security (weight: 1.000)
---
Tenure in Months -> Online Backup (weight: 1.000)
---
Tenure in Months -> Device Protection Plan (weight: 1.000)
---
Tenure in Months -> Premium Tech Support (weight: 1.000)
---
Tenure in Months -> Streaming Movies (weight: 1.000)
---
Multiple Lines -> Total Revenue (weight: 1.000)
---
Internet Type -> Paperless Billing (weight: 1.000)
---
Online Backup -> Total Revenue (weight: 1.000)
---
Device Protection Plan -> Total Revenue (weight: 1.000)
---


INFO:__main__:Causal graph saved as causal_discovery_output\DECI\deci_graph_thresh_0.70.png
INFO:__main__:DECI graph for threshold 0.70 saved as causal_discovery_output\DECI\deci_graph_thresh_0.70.png
INFO:__main__:Causal relationships saved to causal_discovery_output\DECI\deci_relations_thresh_0.70.txt (6 relations)



=== Causal Relationships ===
Married -> Number of Referrals (weight: 1.000)
---
Married -> Tenure in Months (weight: 1.000)
---
Tenure in Months -> Multiple Lines (weight: 1.000)
---
Tenure in Months -> Online Backup (weight: 1.000)
---
Tenure in Months -> Device Protection Plan (weight: 1.000)
---
Tenure in Months -> Streaming Movies (weight: 1.000)
---


INFO:__main__:Causal graph saved as causal_discovery_output\DECI\deci_graph_thresh_0.80.png
INFO:__main__:DECI graph for threshold 0.80 saved as causal_discovery_output\DECI\deci_graph_thresh_0.80.png
INFO:__main__:✅ All constraints validated successfully



=== Causal Relationships ===


INFO:__main__:Causal graph saved as causal_discovery_output\DECI\deci_graph_primary.png
INFO:__main__:Causal relationships saved to causal_discovery_output\DECI\deci_relations_primary.txt (26 relations)
INFO:__main__:Structure Learning Metrics for DECI: {'num_edges': 18, 'graph_density': '0.036', 'avg_in_degree': '0.783', 'avg_out_degree': '0.783'}
INFO:__main__:Executing LiNGAM...
INFO:__main__:Running LiNGAM algorithm...



=== Causal Relationships ===
Married -> Number of Referrals (weight: 0.726)
---
Married -> Tenure in Months (weight: 0.704)
---
Number of Dependents -> Number of Referrals (weight: 0.441)
---
Tenure in Months -> Multiple Lines (weight: 0.709)
---
Tenure in Months -> Online Security (weight: 0.647)
---
Tenure in Months -> Online Backup (weight: 0.728)
---
Tenure in Months -> Device Protection Plan (weight: 0.751)
---
Tenure in Months -> Premium Tech Support (weight: 0.648)
---
Tenure in Months -> Streaming TV (weight: 0.576)
---
Tenure in Months -> Streaming Movies (weight: 0.719)
---
Tenure in Months -> Streaming Music (weight: 0.545)
---
Offer -> Online Security (weight: 0.202)
---
Multiple Lines -> Total Revenue (weight: 0.668)
---
Internet Type -> Total Revenue (weight: 0.321)
---
Internet Type -> Paperless Billing (weight: 0.683)
---
Unlimited Data -> Paperless Billing (weight: 0.316)
---
Online Security -> Total Revenue (weight: 0.538)
---
Online Backup -> Total Revenue (weight: 

INFO:__main__:LiNGAM coefficient matrix heatmap saved as causal_discovery_output\LiNGAM\lingam_adj_matrix_heatmap.png
INFO:__main__:LiNGAM: Generating graphs for varying thresholds on absolute coefficients: [0.1, 0.3, 0.5, 0.7]
INFO:__main__:Causal graph saved as causal_discovery_output\LiNGAM\lingam_graph_thresh_0.10.png
INFO:__main__:LiNGAM graph for threshold 0.10 saved as causal_discovery_output\LiNGAM\lingam_graph_thresh_0.10.png
INFO:__main__:Causal relationships saved to causal_discovery_output\LiNGAM\lingam_relations_thresh_0.10.txt (5 relations)



=== Causal Relationships ===
Age -> Total Revenue (weight: 0.150)
---
Number of Dependents -> Number of Referrals (weight: 0.281)
---
Number of Dependents -> Tenure in Months (weight: 0.170)
---
Number of Referrals -> Tenure in Months (weight: 0.337)
---
Tenure in Months -> Total Revenue (weight: 0.850)
---


INFO:__main__:Causal graph saved as causal_discovery_output\LiNGAM\lingam_graph_thresh_0.30.png
INFO:__main__:LiNGAM graph for threshold 0.30 saved as causal_discovery_output\LiNGAM\lingam_graph_thresh_0.30.png
INFO:__main__:Causal relationships saved to causal_discovery_output\LiNGAM\lingam_relations_thresh_0.30.txt (2 relations)



=== Causal Relationships ===
Number of Referrals -> Tenure in Months (weight: 0.337)
---
Tenure in Months -> Total Revenue (weight: 0.850)
---


INFO:__main__:Causal graph saved as causal_discovery_output\LiNGAM\lingam_graph_thresh_0.50.png
INFO:__main__:LiNGAM graph for threshold 0.50 saved as causal_discovery_output\LiNGAM\lingam_graph_thresh_0.50.png
INFO:__main__:Causal relationships saved to causal_discovery_output\LiNGAM\lingam_relations_thresh_0.50.txt (1 relations)



=== Causal Relationships ===
Tenure in Months -> Total Revenue (weight: 0.850)
---


INFO:__main__:Causal graph saved as causal_discovery_output\LiNGAM\lingam_graph_thresh_0.70.png
INFO:__main__:LiNGAM graph for threshold 0.70 saved as causal_discovery_output\LiNGAM\lingam_graph_thresh_0.70.png
INFO:__main__:Causal relationships saved to causal_discovery_output\LiNGAM\lingam_relations_thresh_0.70.txt (1 relations)
INFO:__main__:✅ All constraints validated successfully



=== Causal Relationships ===
Tenure in Months -> Total Revenue (weight: 0.850)
---


INFO:__main__:Causal graph saved as causal_discovery_output\LiNGAM\lingam_graph_primary.png
INFO:__main__:Causal relationships saved to causal_discovery_output\LiNGAM\lingam_relations_primary.txt (5 relations)
INFO:__main__:Structure Learning Metrics for LINGAM: {'num_edges': 7, 'graph_density': '0.350', 'avg_in_degree': '1.400', 'avg_out_degree': '1.400'}
INFO:__main__:Executing NOTEARS...
INFO:__main__:Running NOTEARS algorithm using CausalNex...
INFO:__main__:NOTEARS: Running with 445 tabu edges...
INFO:root:Learning structure using 'NOTEARS' optimisation.



=== Causal Relationships ===
Age -> Total Revenue (weight: 0.150)
---
Number of Dependents -> Number of Referrals (weight: 0.281)
---
Number of Dependents -> Tenure in Months (weight: 0.170)
---
Number of Referrals -> Tenure in Months (weight: 0.337)
---
Tenure in Months -> Total Revenue (weight: 0.850)
---


INFO:__main__:NOTEARS adjacency matrix heatmap saved as causal_discovery_output\NOTEARS\notears_adj_matrix_heatmap.png
INFO:__main__:NOTEARS: Generating graphs for varying additional thresholds: [0.05, 0.15, 0.25, 0.3]
INFO:__main__:Causal graph saved as causal_discovery_output\NOTEARS\notears_graph_additional_thresh_0.05.png
INFO:__main__:NOTEARS graph for additional threshold 0.05 saved as causal_discovery_output\NOTEARS\notears_graph_additional_thresh_0.05.png
INFO:__main__:Causal relationships saved to causal_discovery_output\NOTEARS\notears_relations_additional_thresh_0.05.txt (13 relations)



=== Causal Relationships ===
Married -> Number of Referrals (weight: 0.681)
---
Married -> Tenure in Months (weight: 0.393)
---
Tenure in Months -> Multiple Lines (weight: 0.319)
---
Tenure in Months -> Online Security (weight: 0.298)
---
Tenure in Months -> Online Backup (weight: 0.356)
---
Tenure in Months -> Device Protection Plan (weight: 0.356)
---
Tenure in Months -> Premium Tech Support (weight: 0.302)
---
Tenure in Months -> Streaming TV (weight: 0.289)
---
Tenure in Months -> Streaming Movies (weight: 0.302)
---
Tenure in Months -> Streaming Music (weight: 0.252)
---
Multiple Lines -> Total Revenue (weight: 0.227)
---
Online Security -> Total Revenue (weight: 0.202)
---
Online Backup -> Total Revenue (weight: 0.236)
---


INFO:__main__:Causal graph saved as causal_discovery_output\NOTEARS\notears_graph_additional_thresh_0.15.png
INFO:__main__:NOTEARS graph for additional threshold 0.15 saved as causal_discovery_output\NOTEARS\notears_graph_additional_thresh_0.15.png
INFO:__main__:Causal relationships saved to causal_discovery_output\NOTEARS\notears_relations_additional_thresh_0.15.txt (13 relations)



=== Causal Relationships ===
Married -> Number of Referrals (weight: 0.681)
---
Married -> Tenure in Months (weight: 0.393)
---
Tenure in Months -> Multiple Lines (weight: 0.319)
---
Tenure in Months -> Online Security (weight: 0.298)
---
Tenure in Months -> Online Backup (weight: 0.356)
---
Tenure in Months -> Device Protection Plan (weight: 0.356)
---
Tenure in Months -> Premium Tech Support (weight: 0.302)
---
Tenure in Months -> Streaming TV (weight: 0.289)
---
Tenure in Months -> Streaming Movies (weight: 0.302)
---
Tenure in Months -> Streaming Music (weight: 0.252)
---
Multiple Lines -> Total Revenue (weight: 0.227)
---
Online Security -> Total Revenue (weight: 0.202)
---
Online Backup -> Total Revenue (weight: 0.236)
---


INFO:__main__:Causal graph saved as causal_discovery_output\NOTEARS\notears_graph_additional_thresh_0.25.png
INFO:__main__:NOTEARS graph for additional threshold 0.25 saved as causal_discovery_output\NOTEARS\notears_graph_additional_thresh_0.25.png
INFO:__main__:Causal relationships saved to causal_discovery_output\NOTEARS\notears_relations_additional_thresh_0.25.txt (10 relations)



=== Causal Relationships ===
Married -> Number of Referrals (weight: 0.681)
---
Married -> Tenure in Months (weight: 0.393)
---
Tenure in Months -> Multiple Lines (weight: 0.319)
---
Tenure in Months -> Online Security (weight: 0.298)
---
Tenure in Months -> Online Backup (weight: 0.356)
---
Tenure in Months -> Device Protection Plan (weight: 0.356)
---
Tenure in Months -> Premium Tech Support (weight: 0.302)
---
Tenure in Months -> Streaming TV (weight: 0.289)
---
Tenure in Months -> Streaming Movies (weight: 0.302)
---
Tenure in Months -> Streaming Music (weight: 0.252)
---


INFO:__main__:Causal graph saved as causal_discovery_output\NOTEARS\notears_graph_additional_thresh_0.30.png
INFO:__main__:NOTEARS graph for additional threshold 0.30 saved as causal_discovery_output\NOTEARS\notears_graph_additional_thresh_0.30.png
INFO:__main__:Causal relationships saved to causal_discovery_output\NOTEARS\notears_relations_additional_thresh_0.30.txt (7 relations)
INFO:__main__:✅ All constraints validated successfully



=== Causal Relationships ===
Married -> Number of Referrals (weight: 0.681)
---
Married -> Tenure in Months (weight: 0.393)
---
Tenure in Months -> Multiple Lines (weight: 0.319)
---
Tenure in Months -> Online Backup (weight: 0.356)
---
Tenure in Months -> Device Protection Plan (weight: 0.356)
---
Tenure in Months -> Premium Tech Support (weight: 0.302)
---
Tenure in Months -> Streaming Movies (weight: 0.302)
---


INFO:__main__:Causal graph saved as causal_discovery_output\NOTEARS\notears_graph_primary.png
INFO:__main__:Causal relationships saved to causal_discovery_output\NOTEARS\notears_relations_primary.txt (13 relations)
INFO:__main__:Structure Learning Metrics for NOTEARS: {'num_edges': 14, 'graph_density': '0.028', 'avg_in_degree': '0.609', 'avg_out_degree': '0.609'}
INFO:__main__:NOTEARS completed. Found 14 edges for primary graph.
INFO:__main__:Summary of results:
         num_edges  graph_density  violations
DECI          18.0           0.04         0.0
LiNGAM         7.0           0.35         0.0
NOTEARS       14.0           0.03         0.0
INFO:__main__:Constraint matrix created with shape: (23, 23)
INFO:__main__:Evaluating constraint violations on test data...
INFO:__main__:✅ All constraints validated successfully
INFO:__main__:DECI: 0 constraint violations on test data: None
INFO:__main__:✅ All constraints validated successfully
INFO:__main__:LiNGAM: 0 constraint violations on tes


=== Causal Relationships ===
Married -> Number of Referrals (weight: 0.681)
---
Married -> Tenure in Months (weight: 0.393)
---
Tenure in Months -> Multiple Lines (weight: 0.319)
---
Tenure in Months -> Online Security (weight: 0.298)
---
Tenure in Months -> Online Backup (weight: 0.356)
---
Tenure in Months -> Device Protection Plan (weight: 0.356)
---
Tenure in Months -> Premium Tech Support (weight: 0.302)
---
Tenure in Months -> Streaming TV (weight: 0.289)
---
Tenure in Months -> Streaming Movies (weight: 0.302)
---
Tenure in Months -> Streaming Music (weight: 0.252)
---
Multiple Lines -> Total Revenue (weight: 0.227)
---
Online Security -> Total Revenue (weight: 0.202)
---
Online Backup -> Total Revenue (weight: 0.236)
---
