<b><p style="font-size: 40px;">Causal Discovery and Inference in Customer Churn</p></b>

<b><p style="font-size: 35px;">I. First Phase: Prepare the Data</p></b>

---
# 1. Import the Relevant Packages and Configuire Them if Needed
---

## 1.1. Import the Libraries and Packages

In [1]:
# Import core Python utilities for iteration, serialization, logging, OS operations, warnings, and abstract base classes.
import itertools
import json
import logging
import os
import warnings
from abc import ABC, abstractmethod
from operator import itemgetter

In [2]:
# Import filesystem, plotting, graph, array, data handling, ML frameworks, and utility libraries.
import fsspec
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import scipy.linalg as slin
import scipy.optimize as sopt
from scipy import stats
import torch
import pytorch_lightning as pl
from dotenv import load_dotenv
from sklearn.linear_model import LinearRegression
from sklearn.metrics import classification_report
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split
from tensordict import TensorDict

In [3]:
# Import Causica and LiNGAM modules for causal inference, SEMs, and optimization routines.
import causica.distributions as _continuous_noise
from causica.distributions import ContinuousNoiseDist
from causica.datasets.causica_dataset_format import CAUSICA_DATASETS_PATH, Variable
from causica.lightning.data_modules.basic_data_module import BasicDECIDataModule
from causica.lightning.modules.deci_module import DECIModule
from causica.sem.sem_distribution import SEMDistributionModule
from causica.sem.structural_equation_model import ite
from causica.training.auglag import AugLagLRConfig
import lingam
from lingam.utils import make_dot

## 1.2. Configure the Tools

In [4]:
# Configure NumPy, Pandas, and Matplotlib display settings for clear and consistent outputs.
np.set_printoptions(precision=3, suppress=True)
pd.set_option("display.max_columns", None)
pd.set_option("display.max_colwidth", None)
pd.set_option("display.precision", 2)
plt.rcParams["font.family"] = "Times New Roman"

## 1.3. Setup the Environment

In [5]:
# Enable PyTorch MPS fallback for macOS GPU support.
PYTORCH_ENABLE_MPS_FALLBACK = 1

In [6]:
# Determine if running in test mode and set file paths for data and variables.
test_run = bool(os.getenv("TEST_RUN", False))
DATA_PATH = "data/dataset.csv"
VARIABLES_PATH = "data/variables.json"

In [7]:
# Set global random seed for reproducible experiments across NumPy, PyTorch, and PyTorch Lightning.
SEED = 100
np.random.seed(SEED)
torch.manual_seed(SEED)
pl.seed_everything(SEED)

Seed set to 100


100

## 1.4. Setup the Logs

In [8]:
# Initialize application-wide logging to file and console at INFO level.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[
        logging.FileHandler("causal_discovery.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

---
# 2. Read and Show the Dataset
---

## 2.1. Read the Dataset

In [9]:
# Load the dataset xslsx file into a Pandas DataFrame.
data = pd.read_csv(
    DATA_PATH
)

## 2.2. Display the Dataset

In [10]:
# Display the first 10 rows of the dataset.
data.head(5).style.set_caption(
    "<b>IBM Telco Customer Churn Dataset (First 5 Rows)</b>"
)

Unnamed: 0,Customer ID,Gender,Age,Under 30,Senior Citizen,Married,Dependents,Number of Dependents,Country,State,City,Zip Code,Latitude,Longitude,Population,Quarter,Referred a Friend,Number of Referrals,Tenure in Months,Offer,Phone Service,Avg Monthly Long Distance Charges,Multiple Lines,Internet Service,Internet Type,Avg Monthly GB Download,Online Security,Online Backup,Device Protection Plan,Premium Tech Support,Streaming TV,Streaming Movies,Streaming Music,Unlimited Data,Contract,Paperless Billing,Payment Method,Monthly Charge,Total Charges,Total Refunds,Total Extra Data Charges,Total Long Distance Charges,Total Revenue,Satisfaction Score,Customer Status,Churn Label,Churn Score,CLTV,Churn Category,Churn Reason
0,8779-QRDMV,Male,78,No,Yes,No,No,0,United States,California,Los Angeles,90022,34.02381,-118.156582,68701,Q3,No,0,1,,No,0.0,No,Yes,DSL,8,No,No,Yes,No,No,Yes,No,No,Month-to-Month,Yes,Bank Withdrawal,39.65,39.65,0.0,20,0.0,59.65,3,Churned,Yes,91,5433,Competitor,Competitor offered more data
1,7495-OOKFY,Female,74,No,Yes,Yes,Yes,1,United States,California,Los Angeles,90063,34.044271,-118.185237,55668,Q3,Yes,1,8,Offer E,Yes,48.85,Yes,Yes,Fiber Optic,17,No,Yes,No,No,No,No,No,Yes,Month-to-Month,Yes,Credit Card,80.65,633.3,0.0,0,390.8,1024.1,3,Churned,Yes,69,5302,Competitor,Competitor made better offer
2,1658-BYGOY,Male,71,No,Yes,No,Yes,3,United States,California,Los Angeles,90065,34.108833,-118.229715,47534,Q3,No,0,18,Offer D,Yes,11.33,Yes,Yes,Fiber Optic,52,No,No,No,No,Yes,Yes,Yes,Yes,Month-to-Month,Yes,Bank Withdrawal,95.45,1752.55,45.61,0,203.94,1910.88,2,Churned,Yes,81,3179,Competitor,Competitor made better offer
3,4598-XLKNJ,Female,78,No,Yes,Yes,Yes,1,United States,California,Inglewood,90303,33.936291,-118.332639,27778,Q3,Yes,1,25,Offer C,Yes,19.76,No,Yes,Fiber Optic,12,No,Yes,Yes,No,Yes,Yes,No,Yes,Month-to-Month,Yes,Bank Withdrawal,98.5,2514.5,13.43,0,494.0,2995.07,2,Churned,Yes,88,5337,Dissatisfaction,Limited range of services
4,4846-WHAFZ,Female,80,No,Yes,Yes,Yes,1,United States,California,Whittier,90602,33.972119,-118.020188,26265,Q3,Yes,1,37,Offer C,Yes,6.33,Yes,Yes,Fiber Optic,14,No,No,No,No,No,No,No,Yes,Month-to-Month,Yes,Bank Withdrawal,76.5,2868.15,0.0,0,234.21,3102.36,2,Churned,Yes,67,2793,Price,Extra data charges


---
# 3. Clean and Prepare the Dataset
---

## 3.1. Rename and Reorder the Features

In [11]:
# Define the feature mapping for the dataset.
feature_mapping = {
    "Customer Info": [
        "Customer ID", "Gender", "Age", "Under 30", "Senior Citizen", "Married", "Dependents", "Number of Dependents"
    ],
    "Location Info": [
        "Country", "State", "City", "Zip Code", "Latitude", "Longitude", "Population"
    ],
    "Referral & Tenure": [
        "Quarter", "Referred a Friend", "Number of Referrals", "Tenure in Months", "Offer"
    ],
    "Services Signed Up": [
        "Phone Service", "Multiple Lines", "Internet Service", "Internet Type", "Unlimited Data"
    ],
    "Internet Features": [
        "Online Security", "Online Backup", "Device Protection Plan", "Premium Tech Support",
        "Streaming TV", "Streaming Movies", "Streaming Music"
    ],
    "Billing & Payment": [
        "Avg Monthly Long Distance Charges", "Avg Monthly GB Download", "Monthly Charge",
        "Total Charges", "Total Refunds", "Total Extra Data Charges",
        "Total Long Distance Charges", "Total Revenue", "Paperless Billing", "Payment Method"
    ],
    "Customer Scores": [
        "Satisfaction Score", "CLTV", "Churn Score"
    ],
    "Churn Info": [
        "Customer Status", "Churn Label", "Churn Category", "Churn Reason"
    ]
}

In [12]:
# Flatten the feature_mapping into a single list.
desired_order = [col for group in feature_mapping.values() for col in group]

# Reorder the DataFrame.
data = data[desired_order]

## 3.2. Remove the Unnecessary Features

In [13]:
# Define the features to remove from the dataset.
features_to_remove = [
    # Identifiers & Redundant Demographics
    "Customer ID",           # High cardinality identifier
    "Under 30",              # Redundant (derivable from Age)
    "Dependents",            # Redundant (derivable from Number of Dependents)

    # Location Info (low variance or low utility)
    "Country",               # Constant (all United States)
    "State",                 # Constant (all California)
    "Zip Code",              # Too granular
    "City",                  # High cardinality, many unique values
    "Latitude",              # Granular
    "Longitude",             # Granular
    "Population",            # Possibly low variation or correlated with city

    # Referral
    "Referred a Friend",     # Redundant (derivable from Number of Referrals)

    # Subscription Redundancy
    "Internet Service",      # Redundant (inferable from Internet Type)

    # Derived or Leaky Features
    "Customer Status",       # Leaks churn label
    "Churn Score",           # Usually post-hoc score, potential leakage
    "Churn Category",        # Sparse & derived from churn
    "Churn Reason",          # Sparse & derived from churn
    "CLTV",                  # Leaks churn label
    "Satisfaction Score",    # Leaks churn label

    # Time Feature
    "Quarter",               # Possibly low relevance unless time modeling is intended

    # Financial features removed in favor of only keeping Total Revenue
    "Avg Monthly Long Distance Charges",    # Usage-level detail removed
    "Avg Monthly GB Download",              # Usage-level detail removed
    "Monthly Charge",                       # Snapshot charge removed
    "Total Charges",                        # Cumulative but derived
    "Total Refunds",                        # Post-hoc financial info
    "Total Extra Data Charges",             # Specific fee detail removed
    "Total Long Distance Charges"           # Specific usage-based revenue removed
]

In [14]:
# Remove the specified columns from the DataFrame.
data = data.drop(
    columns=features_to_remove
)

## 3.3. Distinguish the Categorical and Numeric Features

In [15]:
# Define the categorical features in the dataset.
categorical_features = [
    "Gender",
    "Senior Citizen",
    "Married",
    "Offer",
    "Phone Service",
    "Multiple Lines",
    "Internet Type",
    "Unlimited Data",
    "Online Security",
    "Online Backup",
    "Device Protection Plan",
    "Premium Tech Support",
    "Streaming TV",
    "Streaming Movies",
    "Streaming Music",
    "Paperless Billing",
    "Payment Method",
    "Churn Label"
]

# Define the numeric features in the dataset.
numeric_features = [
    "Age",
    "Number of Dependents",
    "Number of Referrals",
    "Tenure in Months",
    "Total Revenue"
]

## 3.4. Check the Categorical Features' Unique Values

In [16]:
# Define a vacant list to store the rows.
rows = []

# Iterate through the categorical features.
for feature in categorical_features:

    # Get the unique values and their counts.
    value_counts = data[feature].value_counts()
    first_row = True

    # Iterate through the unique values.
    for value, count in value_counts.items():

        # Calculate the percentage.
        percentage = str(round((count / len(data)) * 100, 2))
        # Set the feature name.
        feature_name = feature if first_row else ""
        # Append the unique value, count, and percentage to the rows.
        rows.append([
            feature_name,
            value,
            count,
            percentage
        ])
        # Set the first row to False.
        first_row = False

In [17]:
# Create a dataFrame of unique values.
unique_values_df = pd.DataFrame(
    rows,
    columns=["Feature", "Unique Value", "Frequency", "Percentage"]
)

# Round the percentage to 2 decimal places.
unique_values_df["Percentage"] = unique_values_df["Percentage"].round(2)

# Style the unique values dataFrame.
unique_values_df = (
    unique_values_df.style
    .set_caption("<b>Unique Values in Categorical Features</b>")
    .hide(axis="index")
)

# Display the unique values in the categorical features.
unique_values_df

TypeError: can't multiply sequence by non-int of type 'float'

## 3.5. Check the Numeric Features' Statistics

In [None]:
data[numeric_features].describe()

## 3.6. Implement IQR-Based Clipping of Numeric Outliers

In [None]:
# Handle outliers in numeric features.
for feature in numeric_features:

    q1, q3 = data[feature].quantile([0.25, 0.75])
    iqr = q3 - q1
    lower_bound, upper_bound = q1 - 1.5 * iqr, q3 + 1.5 * iqr

    data[feature] = data[feature].clip(lower=lower_bound, upper=upper_bound)

## 3.6. Check for Missing Values in the Data

In [None]:
# Find out the missing values percentage in the dataset.
non_missing_percentage = data.notnull().mean() * 100

In [None]:
# Create a DataFrame from the non-missing percentage series.
non_missing_df = pd.DataFrame(
    non_missing_percentage,
    columns=["Non-Missing Percentage"]
)

# Change the index to a column named "Feature".
non_missing_df = non_missing_df.reset_index().rename(
    columns={
        "index": "Feature"
    }
)

# Increment the DataFrame index to start from 1.
non_missing_df.index = non_missing_df.index + 1

# Display the non-missing percentage table with two decimal places.
non_missing_df.style.set_caption(
    "<b>Non-Missing Percentage of Features</b>"
).format(
    {
        "Non-Missing Percentage": "{:.2f}"
    }
)

## 3.7. Fill the Missing Values

In [None]:
# Fill missing values in the dataset.
data["Offer"] = data["Offer"].fillna("No Offer")
data["Internet Type"] = data["Internet Type"].fillna("No Internet")

## 3.8. Fix Data Types

In [None]:
# Pad the shorter list with empty strings.
max_length = max(
    len(categorical_features),
    len(numeric_features)
)

categorical_features_feature = categorical_features.copy()
numeric_features_feature = numeric_features.copy()

categorical_features_feature += [""] * (
    max_length - len(categorical_features)
)
numeric_features_feature += [""] * (
    max_length - len(numeric_features)
)

In [None]:
# Create a DataFrame to display feature categorization.
feature_types_df = pd.DataFrame(
    {
        "Categorical Features": categorical_features_feature,
        "Numeric Features": numeric_features_feature
    }
)

# Increment the DataFrame index to start from 1.
feature_types_df.index = feature_types_df.index + 1

# Display the feature categorization table.
feature_types_df.style.set_caption(
    "<b>Categorization of Features by Type</b>"
)

## 3.9. Check the Features Data Types

In [None]:
# Create a DataFrame to display the data types of categorical features.
categorical_features_data_types = pd.DataFrame(
    data[categorical_features].dtypes,
    columns=["Categorical Features' Data Types"]
)

# Display the data types table with a caption.
categorical_features_data_types.style.set_caption(
    "<b>Categorical Features' Data Types</b>"
)

In [None]:
# Create a DataFrame to display the data types of numeric features.
numeric_features_data_types = pd.DataFrame(
    data[numeric_features].dtypes,
    columns=["Numeric Features' Data Types"]
)

# Display the data types table with a caption.
numeric_features_data_types.style.set_caption(
    "Numeric Features' Data Types"
)

## 3.10. Encode the Binary Features

In [None]:
# Define a function to encode binary features.
def encode_binary_features(datasets, features, mapping):
    """
    Applies binary encoding to specified features across multiple datasets.

    Args:
        datasets (List[pd.DataFrame]): A list of DataFrames to be modified in-place.
        features (List[str]): The names of binary categorical features to encode.
        mapping (dict): A dictionary mapping string categories to binary values.
    """
    for df in datasets:
        for feature in features:
            df[feature] = df[feature].astype(str).map(mapping)

In [None]:
# Define binary categorical features to be encoded.
binary_features = [
    "Gender",
    "Senior Citizen",
    "Married",
    "Phone Service",
    "Multiple Lines",
    "Unlimited Data",
    "Online Security",
    "Online Backup",
    "Device Protection Plan",
    "Premium Tech Support",
    "Streaming TV",
    "Streaming Movies",
    "Streaming Music",
    "Paperless Billing",
    "Churn Label"
]

# Define mapping for binary categories.
binary_mapping = {
    "Yes": 1,
    "No": 0,
    "Male": 1,
    "Female": 0
}

In [None]:
# Apply binary encoding to all datasets.
encode_binary_features(
    datasets=[data],
    features=binary_features,
    mapping=binary_mapping
)

## 3.11. Encode the Ordinal Features

In [None]:
# Define a function to encode ordinal features.
def encode_ordinal_features(datasets, mappings):
    """
    Applies ordinal encoding to specified features across multiple datasets.

    Args:
        datasets (List[pd.DataFrame]): A list of DataFrames to be modified in-place.
        mappings (dict): A dictionary where keys are feature names and values are mapping dicts.
    """
    for df in datasets:
        for feature, mapping in mappings.items():
            df[feature] = df[feature].map(mapping)

In [None]:
# Define the mappings for ordinal features.
offer_mapping = {
    "No Offer": 0,
    "Offer A": 1,
    "Offer B": 2,
    "Offer C": 3,
    "Offer D": 4,
    "Offer E": 5
}

internet_type_mapping = {
    "No Internet": 0,
    "DSL": 1,
    "Cable": 2,
    "Fiber Optic": 3
}

payment_method_mapping = {
    "Mailed Check": 1,
    "Bank Withdrawal": 2,
    "Credit Card": 3
}

# Create a dictionary of ordinal mappings.
ordinal_mappings = {
    "Offer": offer_mapping,
    "Internet Type": internet_type_mapping,
    "Payment Method": payment_method_mapping
}

In [None]:
# Apply ordinal encoding to all datasets.
encode_ordinal_features(
    datasets=[data],
    mappings=ordinal_mappings
)

## 3.12. Standard Scale the Numeric Features

In [None]:
# Apply the scaler to normalize all numeric features in the dataset.
scaler = StandardScaler()

data[numeric_features] = scaler.fit_transform(data[numeric_features])

<b><p style="font-size: 35px;">II. Second Phase: Causal Discovery and Inference</p></b>

# 1. Define Utility Functions

In [None]:
def create_constraint_matrix(node_names, tiers, specific_constraints=None):
    """
    Create a constraint matrix for causal discovery algorithms.

    Args:
        node_names (list): List of variable names.
        tiers (list): List of tier lists (e.g., [demographic, customer, ...]).
        specific_constraints (dict): Additional constraints (e.g., {"forbidden": [(src, dst), ...]}).

    Returns:
        np.ndarray: Constraint matrix (np.nan for allowed edges, 0.0 for forbidden).
    """
    num_nodes = len(node_names)
    node_name_to_idx = {name: i for i, name in enumerate(node_names)}
    constraint_matrix = np.full((num_nodes, num_nodes), np.nan, dtype=np.float32)

    # Set Churn Label as sink node (no outgoing edges)
    churn_idx = node_name_to_idx.get("Churn Label")
    if churn_idx is not None:
        constraint_matrix[churn_idx, :] = 0.0

    # Set demographic variables as root nodes (no incoming edges)
    for feature in tiers[0]:  # Tier 1: Demographic
        if feature in node_name_to_idx:
            feature_idx = node_name_to_idx[feature]
            constraint_matrix[:, feature_idx] = 0.0

    # Prevent edges within Tier 1
    for src in tiers[0]:
        for dst in tiers[0]:
            if src != dst and src in node_name_to_idx and dst in node_name_to_idx:
                constraint_matrix[node_name_to_idx[src], node_name_to_idx[dst]] = 0.0

    # Allow edges only from Tier N to Tier N+1
    for src_tier_idx, src_tier in enumerate(tiers[:-1]):
        dst_tier = tiers[src_tier_idx + 1]
        for src in src_tier:
            for dst in dst_tier:
                if src in node_name_to_idx and dst in node_name_to_idx:
                    constraint_matrix[node_name_to_idx[src], node_name_to_idx[dst]] = np.nan
        # Block edges to other tiers
        for other_tier_idx, other_tier in enumerate(tiers):
            if other_tier_idx != src_tier_idx + 1:
                for src in src_tier:
                    for dst in other_tier:
                        if src in node_name_to_idx and dst in node_name_to_idx:
                            constraint_matrix[node_name_to_idx[src], node_name_to_idx[dst]] = 0.0

    # Apply specific constraints (e.g., Gender → Service/Billing forbidden)
    if specific_constraints:
        for src, dst in specific_constraints.get("forbidden", []):
            if src in node_name_to_idx and dst in node_name_to_idx:
                constraint_matrix[node_name_to_idx[src], node_name_to_idx[dst]] = 0.0
        for src, dst in specific_constraints.get("allowed", []):
            if src in node_name_to_idx and dst in node_name_to_idx:
                constraint_matrix[node_name_to_idx[src], node_name_to_idx[dst]] = np.nan

    logger.info("Constraint matrix created with shape: %s", constraint_matrix.shape)
    return constraint_matrix, node_name_to_idx

In [None]:
def validate_constraints(dag, node_name_to_idx, tiers, threshold=0.5):
    """
    Validate constraints on a DAG or adjacency matrix.

    Args:
        dag: NetworkX DiGraph or np.ndarray (adjacency/probability matrix).
        node_name_to_idx (dict): Mapping of node names to indices.
        tiers (list): List of tier lists.
        threshold (float): Probability threshold for matrix-based DAGs.

    Returns:
        list: List of constraint violation messages (empty if none).
    """
    violations = []
    num_nodes = len(node_name_to_idx)
    churn_idx = node_name_to_idx.get("Churn Label")

    # Convert adjacency matrix to DAG if needed
    if isinstance(dag, np.ndarray):
        G = nx.DiGraph()
        G.add_nodes_from(range(num_nodes))
        for i in range(num_nodes):
            for j in range(num_nodes):
                if dag[i, j] > threshold:
                    G.add_edge(i, j)
    else:
        G = dag

    # Check if Churn Label has outgoing edges
    if churn_idx is not None and any(G.has_edge(churn_idx, j) for j in range(num_nodes)):
        violations.append("Churn Label has outgoing edges")

    # Check if Tier 1 variables have incoming edges
    for var in tiers[0]:
        var_idx = node_name_to_idx.get(var)
        if var_idx is not None and any(G.has_edge(j, var_idx) for j in range(num_nodes)):
            violations.append(f"{var} has incoming edges")

    # Check for edges within Tier 1
    for src in tiers[0]:
        src_idx = node_name_to_idx.get(src)
        for dst in tiers[0]:
            dst_idx = node_name_to_idx.get(dst)
            if src != dst and src_idx is not None and dst_idx is not None and G.has_edge(src_idx, dst_idx):
                violations.append(f"T1↛T1 edge: {src}→{dst}")

    # Check for forbidden Gender → Service/Billing edges
    gender_idx = node_name_to_idx.get("Gender")
    if gender_idx is not None:
        for dst in tiers[2] + tiers[3]:  # Service + Billing
            dst_idx = node_name_to_idx.get(dst)
            if dst_idx is not None and G.has_edge(gender_idx, dst_idx):
                violations.append(f"Gender→{dst} edge exists")

    if not violations:
        logger.info("✅ All constraints validated successfully")
    else:
        logger.warning("❌ Constraint violations detected: %s", violations)

    return violations

In [None]:
def save_relations_to_text(dag, node_names, filename, threshold=0.2):
    """
    Save causal relationships to a text file.

    Args:
        dag: NetworkX DiGraph or np.ndarray (adjacency/probability matrix).
        node_names (list): List of node names.
        filename (str): Output text file name.
        threshold (float): Probability/weight threshold for matrix-based DAGs.
    """
    try:
        relations = []
        if isinstance(dag, np.ndarray):
            for i in range(dag.shape[0]):
                for j in range(dag.shape[1]):
                    if dag[i, j] > threshold:
                        relations.append({
                            "source": node_names[i],
                            "destination": node_names[j],
                            "weight": float(dag[i, j])
                        })
        else:
            for src, dst in dag.edges():
                weight = dag[src][dst].get("weight", 1.0)  # Default to 1.0 for unweighted edges
                relations.append({
                    "source": node_names[src],
                    "destination": node_names[dst],
                    "weight": float(weight)
                })

        # Convert to DataFrame and save as text
        relations_df = pd.DataFrame(relations)
        if not relations_df.empty:
            relations_df.to_csv(filename, sep="\t", index=False, columns=["source", "destination", "weight"])
            logger.info("Causal relationships saved to %s (%d relations)", filename, len(relations_df))
        else:
            with open(filename, "w") as f:
                f.write("No causal relationships found.")
            logger.warning("No causal relationships to save for %s", filename)

    except Exception as e:
        logger.error("Failed to save relations to %s: %s", filename, str(e))
        raise

In [None]:
def visualize_causal_graph(dag, node_names, filename="causal_graph.png"):
    """
    Visualize a causal graph and save it to a file.

    Args:
        dag: NetworkX DiGraph or np.ndarray (adjacency matrix).
        node_names (list): List of node names.
        filename (str): Output file name.

    Returns:
        nx.DiGraph: Visualized graph.
    """
    if isinstance(dag, np.ndarray):
        G = nx.DiGraph(dag)
    else:
        G = dag

    plt.figure(figsize=(10, 8))
    pos = nx.spring_layout(G, seed=42)
    nx.draw(G, pos, with_labels=True, labels={i: node_names[i] for i in G.nodes()},
            node_color='lightblue', node_size=2000, font_size=10, font_weight='bold', arrowsize=20)

    if isinstance(dag, np.ndarray):
        edge_labels = {(i, j): f"{dag[i, j]:.3f}" for i, j in G.edges()}
        nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)

    plt.title(f"Causal Graph ({filename.split('.')[0]})")
    plt.savefig(filename, format="png", dpi=300, bbox_inches="tight")
    plt.close()
    logger.info("Causal graph saved as %s", filename)
    return G

In [None]:
def analyze_structure_learning(dag, node_names, threshold=0.5):
    """
    Analyze the structure of a learned DAG.

    Args:
        dag: NetworkX DiGraph or np.ndarray (adjacency/probability matrix).
        node_names (list): List of node names.
        threshold (float): Probability threshold for matrix-based DAGs.

    Returns:
        dict: Structure learning metrics.
    """
    if isinstance(dag, np.ndarray):
        adj_matrix = (dag > threshold).astype(int)
        G = nx.DiGraph(adj_matrix)
    else:
        G = dag
        adj_matrix = nx.to_numpy_array(G, nodelist=range(len(node_names)))

    num_edges = np.sum(adj_matrix)
    num_nodes = len(node_names)
    graph_density = num_edges / (num_nodes * (num_nodes - 1)) if num_nodes > 1 else 0
    in_degree = np.sum(adj_matrix, axis=0)
    out_degree = np.sum(adj_matrix, axis=1)
    avg_in_degree = np.mean(in_degree)
    avg_out_degree = np.mean(out_degree)

    most_influential = [(node_names[i], out_degree[i]) for i in np.argsort(-out_degree)[:5]]
    most_affected = [(node_names[i], in_degree[i]) for i in np.argsort(-in_degree)[:5]]

    metrics = {
        "num_edges": num_edges,
        "graph_density": graph_density,
        "avg_in_degree": avg_in_degree,
        "avg_out_degree": avg_out_degree,
        "most_influential": most_influential,
        "most_affected": most_affected
    }

    logger.info("Structure Learning Metrics: %s", metrics)
    return metrics

In [None]:
def evaluate_causal_discovery(results, test_data, tiers, node_name_to_idx, threshold=0.5):
    """
    Evaluate causal discovery algorithms on test data, focusing on constraint violations.

    Args:
        results (dict): Results from run_causal_discovery_pipeline.
        test_data (pd.DataFrame): Test data.
        tiers (list): List of tier lists.
        node_name_to_idx (dict): Mapping of node names to indices.
        threshold (float): Probability threshold for matrix-based DAGs.

    Returns:
        pd.DataFrame: Evaluation metrics for each algorithm.
    """
    logger.info("Evaluating constraint violations on test data...")
    evaluation_metrics = {}

    # Preprocess test data to handle issues
    constant_cols = [col for col in test_data.columns if test_data[col].std() == 0]
    if constant_cols:
        logger.warning("Constant columns in test data: %s", constant_cols)
        test_data = test_data.drop(columns=constant_cols)
    if test_data.isna().any().any():
        logger.warning("NaNs in test data, filling with mean")
        test_data = test_data.fillna(test_data.mean())

    for algo_name, result in results.items():
        if "error" in result:
            evaluation_metrics[algo_name] = {
                "constraint_violations": "N/A",
                "violation_details": "N/A"
            }
            continue

        try:
            dag = result["dag"]
            adj_matrix = result["adj_matrix"]

            # Adjust node_name_to_idx for LiNGAM (uses subset of features)
            algo_node_name_to_idx = node_name_to_idx
            if algo_name == "LiNGAM":
                continuous_features = [f for f in test_data.columns if f in ["Age", "Number of Dependents", "Number of Referrals", "Tenure in Months", "Total Revenue"]]
                algo_node_name_to_idx = {name: i for i, name in enumerate(continuous_features)}
                # Update tiers for LiNGAM
                lingam_tiers = [
                    [f for f in tiers[0] if f in continuous_features],
                    [f for f in tiers[1] if f in continuous_features],
                    [f for f in tiers[3] if f in continuous_features]
                ]
            else:
                lingam_tiers = tiers

            # Validate constraints
            violations = validate_constraints(dag if algo_name != "LiNGAM" else adj_matrix,
                                           algo_node_name_to_idx, lingam_tiers, threshold)

            evaluation_metrics[algo_name] = {
                "constraint_violations": len(violations),
                "violation_details": violations if violations else "None"
            }

            logger.info("%s: %d constraint violations on test data: %s",
                        algo_name, len(violations), violations if violations else "None")

        except Exception as e:
            logger.error("Evaluation failed for %s: %s", algo_name, str(e))
            evaluation_metrics[algo_name] = {
                "constraint_violations": "N/A",
                "violation_details": "N/A"
            }

    # Create summary DataFrame
    summary = pd.DataFrame(evaluation_metrics).T
    logger.info("Constraint violation summary:\n%s", summary)

    return summary

# 2. Define the Causal Discovery and Inference Functions

# 2.1. DECI Algorithm

In [None]:
def run_deci_algorithm(data, constraint_matrix, node_names, node_name_to_idx, tiers):
    logger.info("Running DECI algorithm...")
    try:
        # Load variables.json
        with fsspec.open("data/variables.json", mode="r", encoding="utf-8") as f:
            variables = json.load(f)["variables"]

        # Validate columns
        expected_columns = [var["name"] for var in variables]
        if set(expected_columns) != set(data.columns):
            raise ValueError(f"Columns mismatch: {set(data.columns)} vs {set(expected_columns)}")

        # Prepare data module
        data_module = BasicDECIDataModule(
            data,
            variables=[Variable.from_dict(d) for d in variables],
            batch_size=128,
            normalize=True
        )

        # Initialize DECI module
        lightning_module = DECIModule(
            noise_dist=ContinuousNoiseDist.GAUSSIAN,
            prior_sparsity_lambda=200.0,
            init_rho=30.0,
            init_alpha=0.20,
            auglag_config=AugLagLRConfig(
                max_inner_steps=1500,
                max_outer_steps=8,
                lr_init_dict={
                    "icgnn": 0.00076,
                    "vardist": 0.0098,
                    "functional_relationships": 3e-4,
                    "noise_dist": 0.0070,
                }
            )
        )
        lightning_module.constraint_matrix = torch.tensor(constraint_matrix)

        # Train
        trainer = pl.Trainer(
            accelerator="gpu" if torch.cuda.is_available() else "cpu",
            devices=1,
            max_epochs=10,
            callbacks=[TQDMProgressBar(refresh_rate=19)],
            enable_checkpointing=False
        )
        trainer.fit(lightning_module, datamodule=data_module)

        # Save model
        torch.save(lightning_module.sem_module, "deci.pt")

        # Compute probability matrix
        logits_exist = lightning_module.sem_module.adjacency_module.adjacency_distribution.logits_exist
        logits_orient = lightning_module.sem_module.adjacency_module.adjacency_distribution.logits_orient

        def fill_triangular(vec, upper=False):
            n = int(np.sqrt(2 * len(vec))) + 1
            if upper:
                return vec.new_zeros(n, n).triu(1).masked_scatter_(
                    torch.triu(torch.ones(n, n, device=vec.device), 1).bool(), vec
                )
            return vec.new_zeros(n, n).tril(-1).masked_scatter_(
                    torch.tril(torch.ones(n, n, device=vec.device), -1).bool(), vec
                )

        neg_theta = fill_triangular(logits_orient, upper=True) - fill_triangular(logits_orient, upper=False)
        logits_matrix = -torch.logsumexp(torch.stack([-logits_exist, neg_theta, neg_theta - logits_exist], dim=-1), dim=-1)
        prob_matrix = 1 / (1 + np.exp(-logits_matrix.cpu().detach().numpy()))
        prob_matrix = prob_matrix * np.isnan(constraint_matrix)

        # Validate constraints
        violations = validate_constraints(prob_matrix, node_name_to_idx, tiers)

        # Visualize
        G = visualize_causal_graph(prob_matrix, node_names, "deci_graph.png")

        # Save relations to text
        save_relations_to_text(prob_matrix, node_names, "deci_relations.txt")

        # Analyze structure
        metrics = analyze_structure_learning(prob_matrix, node_names)

        return {
            "dag": G,
            "adj_matrix": prob_matrix,
            "metrics": metrics,
            "violations": violations
        }

    except Exception as e:
        logger.error("DECI failed: %s", str(e))
        raise

# 2.2. LiNGAM Algorithm

In [None]:
def run_lingam_algorithm(data, constraint_matrix, node_names, node_name_to_idx, tiers):
    logger.info("Running LiNGAM algorithm...")
    try:
        # Filter continuous features
        continuous_features = [f for f in node_names if f in ["Age", "Number of Dependents", "Number of Referrals", "Tenure in Months", "Total Revenue"]]
        lingam_data = data[continuous_features].copy()

        # Remove constant columns
        constant_cols = [col for col in lingam_data.columns if lingam_data[col].std() == 0]
        if constant_cols:
            logger.warning("Removing constant columns: %s", constant_cols)
            lingam_data = lingam_data.drop(columns=constant_cols)
            continuous_features = [f for f in continuous_features if f not in constant_cols]

        # Validate data
        if lingam_data.isna().any().any():
            raise ValueError("LiNGAM data contains NaNs")

        # Create LiNGAM-specific constraint matrix
        lingam_node_to_idx = {name: i for i, name in enumerate(continuous_features)}
        lingam_constraint_matrix = np.full((len(continuous_features), len(continuous_features)), -1, dtype=np.int32)

        lingam_tiers = [
            [f for f in tiers[0] if f in continuous_features],  # Demographic
            [f for f in tiers[1] if f in continuous_features],  # Customer
            [f for f in tiers[3] if f in continuous_features]   # Billing
        ]

        if "Total Revenue" in lingam_node_to_idx:
            lingam_constraint_matrix[lingam_node_to_idx["Total Revenue"], :] = 0
        for feature in lingam_tiers[0]:
            lingam_constraint_matrix[:, lingam_node_to_idx[feature]] = 0
        for src_tier_idx, src_tier in enumerate(lingam_tiers[:-1]):
            dst_tier = lingam_tiers[src_tier_idx + 1]
            for src in src_tier:
                for dst in dst_tier:
                    lingam_constraint_matrix[lingam_node_to_idx[src], lingam_node_to_idx[dst]] = -1
            for other_tier_idx, other_tier in enumerate(lingam_tiers):
                if other_tier_idx != src_tier_idx + 1:
                    for src in src_tier:
                        for dst in other_tier:
                            lingam_constraint_matrix[lingam_node_to_idx[src], lingam_node_to_idx[dst]] = 0
        for src in lingam_tiers[0]:
            for dst in lingam_tiers[0]:
                if src != dst:
                    lingam_constraint_matrix[lingam_node_to_idx[src], lingam_node_to_idx[dst]] = 0

        # Fit LiNGAM
        model = lingam.DirectLiNGAM(prior_knowledge=lingam_constraint_matrix)
        model.fit(lingam_data)

        adj_matrix = model.adjacency_matrix_

        # Validate constraints
        violations = validate_constraints(adj_matrix, lingam_node_to_idx, lingam_tiers)

        # Visualize
        G = visualize_causal_graph(adj_matrix, continuous_features, "lingam_graph.png")

        # Save relations to text
        save_relations_to_text(adj_matrix, continuous_features, "lingam_relations.txt")

        # Analyze structure
        metrics = analyze_structure_learning(adj_matrix, continuous_features)

        return {
            "dag": G,
            "adj_matrix": adj_matrix,
            "metrics": metrics,
            "violations": violations
        }

    except Exception as e:
        logger.error("LiNGAM failed: %s", str(e))
        raise

# 2.3. PC-GIN Algorithm

In [None]:
def run_pcgins_algorithm(data, constraint_matrix, node_names, node_name_to_idx, tiers):
    logger.info("Running PC-GIN algorithm...")
    try:
        # Encode categorical columns
        categorical_cols = [col for col in data.columns if col in ["Gender", "Internet Type", "Offer", "Payment Method"]]
        encoded_data = data.copy()
        for col in categorical_cols:
            le = LabelEncoder()
            encoded_data[col] = le.fit_transform(encoded_data[col].astype(str))

        def gin_test(X, Y, Z=None, alpha=0.01):
            n = len(X)
            if Z is None or Z.shape[1] == 0:
                corr, p_value = stats.pearsonr(X, Y)
                return p_value
            model_x = LinearRegression().fit(Z, X)
            residuals_x = X - model_x.predict(Z)
            model_y = LinearRegression().fit(Z, Y)
            residuals_y = Y - model_y.predict(Z)
            corr, p_value = stats.pearsonr(residuals_x, residuals_y)
            return p_value

        def pc_gin(data, constraint_matrix, alpha=0.01):
            n = data.shape[1]
            skeleton = nx.Graph()
            skeleton.add_nodes_from(range(n))
            separating_sets = {}

            for i in range(n):
                for j in range(i + 1, n):
                    if np.isnan(constraint_matrix[i, j]) or np.isnan(constraint_matrix[j, i]):
                        skeleton.add_edge(i, j)

            for d in range(n):
                edges = list(skeleton.edges())
                for i, j in edges:
                    if not skeleton.has_edge(i, j):
                        continue
                    adj_i = set(skeleton.neighbors(i)) - {j}
                    if len(adj_i) >= d:
                        for subset in itertools.combinations(adj_i, d):
                            subset_list = list(subset)
                            conditioning_set = data[:, subset_list] if subset_list else None
                            p_val = gin_test(
                                data[:, i], data[:, j],
                                conditioning_set.reshape(data.shape[0], -1) if conditioning_set is not None else None,
                                alpha=alpha
                            )
                            if p_val > alpha:
                                skeleton.remove_edge(i, j)
                                separating_sets[(i, j)] = subset
                                separating_sets[(j, i)] = subset
                                break

            dag = nx.DiGraph()
            dag.add_nodes_from(range(n))
            for i, j in skeleton.edges():
                if np.isnan(constraint_matrix[i, j]) and not np.isnan(constraint_matrix[j, i]):
                    dag.add_edge(i, j)
                elif np.isnan(constraint_matrix[j, i]) and not np.isnan(constraint_matrix[i, j]):
                    dag.add_edge(j, i)
                else:
                    dag.add_edge(i, j)
                    dag.add_edge(j, i)

            for i in range(n):
                for j in range(n):
                    if i == j or not dag.has_edge(i, j):
                        continue
                    for k in range(n):
                        if k == i or k == j:
                            continue
                        if dag.has_edge(k, j) and not skeleton.has_edge(i, k):
                            if ((i, k) in separating_sets and j not in separating_sets[(i, k)]) or \
                               ((k, i) in separating_sets and j not in separating_sets[(k, i)]):
                                if dag.has_edge(j, i):
                                    dag.remove_edge(j, i)
                                if dag.has_edge(j, k):
                                    dag.remove_edge(j, k)

            for i, j in list(dag.edges()):
                if dag.has_edge(j, i):
                    if np.isnan(constraint_matrix[i, j]) and not np.isnan(constraint_matrix[j, i]):
                        dag.remove_edge(j, i)
                    elif np.isnan(constraint_matrix[j, i]) and not np.isnan(constraint_matrix[i, j]):
                        dag.remove_edge(i, j)
                    else:
                        dag.remove_edge(j, i)

            return dag

        dag = pc_gin(encoded_data.values, constraint_matrix)

        # Validate constraints
        violations = validate_constraints(dag, node_name_to_idx, tiers)

        # Visualize
        G = visualize_causal_graph(dag, node_names, "pcgin_graph.png")

        # Save relations to text
        save_relations_to_text(dag, node_names, "pcgin_relations.txt")

        # Analyze structure
        metrics = analyze_structure_learning(dag, node_names)

        return {
            "dag": G,
            "adj_matrix": nx.to_numpy_array(dag, nodelist=range(len(node_names))),
            "metrics": metrics,
            "violations": violations
        }

    except Exception as e:
        logger.error("PC-GIN failed: %s", str(e))
        raise

# 2.4. NOTEARS Algorithm

In [None]:
def run_notears_algorithm(data, constraint_matrix, node_names, node_name_to_idx, tiers):
    logger.info("Running NOTEARS algorithm...")
    try:
        X = data.values
        stds = np.std(X, axis=0)
        means = np.mean(X, axis=0)
        X_standardized = np.where(stds != 0, (X - means) / stds, 0)

        def notears_with_constraints(X, constraint_matrix, lambda1=0.1, max_iter=200, h_tol=1e-8, rho_max=1e+16, w_threshold=0.1):
            n, d = X.shape
            mask = 1.0 - np.isnan(constraint_matrix).astype(float)

            def _h(w):
                W = w.reshape((d, d))
                M = np.eye(d) + W * W / d
                return np.trace(slin.expm(M)) - d

            def _func(w):
                W = w.reshape((d, d))
                W = W * (1.0 - mask)
                R = X - X @ W
                loss = 0.5 / n * np.sum(R * R)
                l1_penalty = lambda1 * np.sum(np.abs(W))
                return loss + l1_penalty

            def _grad(w):
                W = w.reshape((d, d))
                W = W * (1.0 - mask)
                R = X - X @ W
                G = -1.0 / n * X.T @ R
                G_l1 = lambda1 * np.sign(W)
                G = (G + G_l1) * (1.0 - mask)
                return G.flatten()

            def _h_grad(w):
                W = w.reshape((d, d))
                M = np.eye(d) + W * W / d
                E = slin.expm(M)
                G = E.T * (2 * W / d)
                G = G * (1.0 - mask)
                return G.flatten()

            w_est = np.zeros(d * d)
            rho, alpha, h = 1.0, 0.0, np.inf
            for _ in range(max_iter):
                w_new = sopt.minimize(
                    lambda w: _func(w) + 0.5 * rho * _h(w) ** 2 + alpha * _h(w),
                    w_est,
                    method='L-BFGS-B',
                    jac=lambda w: _grad(w) + rho * _h(w) * _h_grad(w) + alpha * _h_grad(w),
                    options={'ftol': 1e-6, 'gtol': 1e-6}
                ).x
                h_new = _h(w_new)
                if abs(h_new) <= h_tol or rho >= rho_max:
                    break
                if abs(h_new) > 0.25 * abs(h):
                    rho *= 10
                alpha += rho * h_new
                w_est, h = w_new, h_new

            W_est = w_est.reshape((d, d))
            W_est = W_est * (1.0 - mask)
            W_est[np.abs(W_est) < w_threshold] = 0

            G = nx.DiGraph(W_est)
            while not nx.is_directed_acyclic_graph(G):
                try:
                    cycle = nx.find_cycle(G)
                    min_weight = float('inf')
                    min_edge = None
                    for u, v in cycle:
                        if abs(W_est[u, v]) < min_weight:
                            min_weight = abs(W_est[u, v])
                            min_edge = (u, v)
                    if min_edge:
                        G.remove_edge(*min_edge)
                        W_est[min_edge[0], min_edge[1]] = 0
                except nx.NetworkXNoCycle:
                    break

            return W_est

        adj_matrix = notears_with_constraints(X_standardized, constraint_matrix)

        # Validate constraints
        violations = validate_constraints(adj_matrix, node_name_to_idx, tiers)

        # Visualize
        G = visualize_causal_graph(adj_matrix, node_names, "notears_graph.png")

        # Save relations to text
        save_relations_to_text(adj_matrix, node_names, "notears_relations.txt")

        # Analyze structure
        metrics = analyze_structure_learning(adj_matrix, node_names)

        return {
            "dag": G,
            "adj_matrix": adj_matrix,
            "metrics": metrics,
            "violations": violations
        }

    except Exception as e:
        logger.error("NOTEARS failed: %s", str(e))
        raise

# 2.5. GRaSP Algorithm

In [None]:
def run_grasp_algorithm(data, constraint_matrix, node_names, node_name_to_idx, tiers):
    logger.info("Running GRaSP algorithm...")
    try:
        X = data.values
        stds = np.std(X, axis=0)
        means = np.mean(X, axis=0)
        X_standardized = np.where(stds != 0, (X - means) / stds, 0)

        def grasp_with_constraints(X, constraint_matrix, lambda1=0.1, max_iter=200, h_tol=1e-8, rho_max=1e+16, w_threshold=0.1):
            n, d = X.shape
            mask = 1.0 - np.isnan(constraint_matrix).astype(float)

            def _h(w):
                W = w.reshape((d, d))
                M = np.eye(d) + W * W / d
                return np.trace(slin.expm(M)) - d

            def _func(w, rho, alpha):
                W = w.reshape((d, d))
                W = W * (1.0 - mask)
                R = X - X @ W
                loss = 0.5 / n * np.sum(R * R)
                l1_penalty = lambda1 * np.sum(np.abs(W))
                h_val = _h(w)
                return loss + l1_penalty + 0.5 * rho * h_val ** 2 + alpha * h_val

            def _grad(w, rho, alpha):
                W = w.reshape((d, d))
                W = W * (1.0 - mask)
                R = X - X @ W
                G_loss = -1.0 / n * X.T @ R
                G_l1 = lambda1 * np.sign(W)
                h_val = _h(w)
                h_gradient = _h_grad(w).reshape((d, d))
                G_acyclicity = (rho * h_val + alpha) * h_gradient
                G = (G_loss + G_l1 + G_acyclicity) * (1.0 - mask)
                return G.flatten()

            def _h_grad(w):
                W = w.reshape((d, d))
                M = np.eye(d) + W * W / d
                E = slin.expm(M)
                G = E.T * (2 * W / d)
                G = G * (1.0 - mask)
                return G.flatten()

            w_est = np.zeros(d * d)
            rho, alpha, h = 1.0, 0.0, np.inf
            for _ in range(max_iter):
                w_new = sopt.minimize(
                    lambda w: _func(w, rho, alpha),
                    w_est,
                    method='L-BFGS-B',
                    jac=lambda w: _grad(w, rho, alpha),
                    options={'ftol': 1e-6, 'gtol': 1e-6}
                ).x
                h_new = _h(w_new)
                if abs(h_new) <= h_tol or rho >= rho_max:
                    break
                if abs(h_new) > 0.25 * abs(h):
                    rho *= 10
                alpha += rho * h_new
                w_est, h = w_new, h_new

            W_est = w_est.reshape((d, d))
            W_est = W_est * (1.0 - mask)
            W_est[np.abs(W_est) < w_threshold] = 0

            G = nx.DiGraph(W_est)
            while not nx.is_directed_acyclic_graph(G):
                try:
                    cycle = nx.find_cycle(G)
                    min_weight = float('inf')
                    min_edge = None
                    for u, v in cycle:
                        if abs(W_est[u, v]) < min_weight:
                            min_weight = abs(W_est[u, v])
                            min_edge = (u, v)
                    if min_edge:
                        G.remove_edge(*min_edge)
                        W_est[min_edge[0], min_edge[1]] = 0
                except nx.NetworkXNoCycle:
                    break

            return W_est

        adj_matrix = grasp_with_constraints(X_standardized, constraint_matrix)

        # Validate constraints
        violations = validate_constraints(adj_matrix, node_name_to_idx, tiers)

        # Visualize
        G = visualize_causal_graph(adj_matrix, node_names, "grasp_graph.png")

        # Save relations to text
        save_relations_to_text(adj_matrix, node_names, "grasp_relations.txt")

        # Analyze structure
        metrics = analyze_structure_learning(adj_matrix, node_names)

        return {
            "dag": G,
            "adj_matrix": adj_matrix,
            "metrics": metrics,
            "violations": violations
        }

    except Exception as e:
        logger.error("GRaSP failed: %s", str(e))
        raise

# 3. Define the Causal Discovery Pipeline

In [None]:
def run_causal_discovery_pipeline(train_data, tiers, specific_constraints=None):
    """
    Run the causal discovery pipeline for all algorithms.

    Args:
        train_data (pd.DataFrame): Training data.
        tiers (list): List of tier lists (e.g., [demographic, customer, billing]).
        specific_constraints (dict): Additional constraints (e.g., {"forbidden": [(src, dst), ...]}).

    Returns:
        dict: Results for each algorithm.
    """
    logger.info("Starting causal discovery pipeline...")

    # Validate input data
    if train_data.empty or train_data.isna().all().all():
        logger.error("Training data is empty or contains only NaNs")
        raise ValueError("Invalid training data")

    # Create constraint matrix
    node_names = list(train_data.columns)
    constraint_matrix, node_name_to_idx = create_constraint_matrix(node_names, tiers, specific_constraints)

    # Define algorithm functions
    algorithms = {
        "DECI": run_deci_algorithm,
        "LiNGAM": run_lingam_algorithm,
        "PC-GIN": run_pcgins_algorithm,
        "NOTEARS": run_notears_algorithm,
        "GRaSP": run_grasp_algorithm
    }

    # Run algorithms
    results = {}
    for algo_name, algo_func in algorithms.items():
        try:
            logger.info("Executing %s...", algo_name)
            # Call the algorithm function with the required arguments
            result = algo_func(train_data, constraint_matrix, node_names, node_name_to_idx, tiers)
            results[algo_name] = result
        except Exception as e:
            logger.error("%s failed: %s", algo_name, str(e))
            results[algo_name] = {"error": str(e)}

    # Summarize results
    summary = pd.DataFrame({
        algo_name: {
            "num_edges": result["metrics"]["num_edges"] if "metrics" in result else "N/A",
            "graph_density": result["metrics"]["graph_density"] if "metrics" in result else "N/A",
            "violations": len(result["violations"]) if "violations" in result else "N/A"
        } for algo_name, result in results.items()
    }).T
    logger.info("Summary of results:\n%s", summary)

    return results

# 4. Split Data into Train and Test Sets

In [None]:
train_data, test_data = train_test_split(data, test_size=0.2, random_state=100)

# 5. Discover the Causality

In [None]:
# Define tiers and specific constraints
tiers = [
    ["Gender", "Age", "Senior Citizen", "Married", "Number of Dependents"],  # Tier 1: Demographic
    ["Number of Referrals", "Tenure in Months", "Offer"],                    # Tier 2: Customer
    ["Phone Service", "Multiple Lines", "Internet Type", "Unlimited Data",
     "Online Security", "Online Backup", "Device Protection Plan",
     "Premium Tech Support", "Streaming TV", "Streaming Movies", "Streaming Music"],  # Tier 3: Service
    ["Total Revenue", "Paperless Billing", "Payment Method"],                # Tier 4: Billing
    ["Churn Label"]                                                         # Tier 5: Outcome
]

specific_constraints = {
    "forbidden": [
        ("Gender", dst) for dst in tiers[2] + tiers[3]
    ] + [
        ("Internet Type", dst) for dst in ["Unlimited Data", "Online Security", "Online Backup",
                                          "Device Protection Plan", "Premium Tech Support",
                                          "Streaming TV", "Streaming Movies", "Streaming Music"]
    ],
    "allowed": []
}

In [None]:
# Run the pipeline (assuming train_data is available from preprocessing)
results = run_causal_discovery_pipeline(train_data, tiers, specific_constraints)


# Evaluate constraint violations on test data
node_names = list(train_data.columns)
_, node_name_to_idx = create_constraint_matrix(node_names, tiers, specific_constraints)
evaluation_summary = evaluate_causal_discovery(results, test_data, tiers, node_name_to_idx)