<b><p style="font-size: 40px;">Causal Discovery and Inference in Customer Churn</p></b>

<b><p style="font-size: 35px;">I. First Phase: Prepare the Data</p></b>

---
# 1. Import the Relevant Packages and Configuire Them if Needed
---

## 1.1. Import the Libraries and Packages

In [109]:
# Standard library imports
from abc import ABC, abstractmethod
import itertools
import json
import logging
import os
import warnings

# Third-party imports
import fsspec
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
from pytorch_lightning.callbacks import TQDMProgressBar
import pytorch_lightning as pl
from scipy import stats
import scipy.linalg as slin
import scipy.optimize as sopt
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import LabelEncoder
import torch
from causica.datasets.causica_dataset_format import Variable
from causica.distributions import ContinuousNoiseDist
from causica.lightning.data_modules.basic_data_module import BasicDECIDataModule
from causica.lightning.modules.deci_module import DECIModule
from causica.training.auglag import AugLagLRConfig
import lingam



## 1.2. Configure the Tools

In [110]:
np.set_printoptions(precision=3, suppress=True)
np.random.seed(100)
pd.set_option("display.max_columns", None)
pd.set_option("display.max_colwidth", None)
pd.set_option("display.precision", 2)
plt.rcParams["font.family"] = "Times New Roman"

## 1.3. Setup the Environment

In [111]:
test_run = bool(os.environ.get("TEST_RUN", False))
DATA_PATH = "data/dataset.csv"
VARIABLES_PATH = "data/variables.json"
from dotenv import load_dotenv
load_dotenv()
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

## 1.4 Data Loading and Preview

In [112]:
data = pd.read_csv(DATA_PATH)
data.head(5).style.set_caption("<b>IBM Telco Customer Churn Dataset (First 5 Rows)</b>")

Unnamed: 0,Customer ID,Gender,Age,Under 30,Senior Citizen,Married,Dependents,Number of Dependents,Country,State,City,Zip Code,Latitude,Longitude,Population,Quarter,Referred a Friend,Number of Referrals,Tenure in Months,Offer,Phone Service,Avg Monthly Long Distance Charges,Multiple Lines,Internet Service,Internet Type,Avg Monthly GB Download,Online Security,Online Backup,Device Protection Plan,Premium Tech Support,Streaming TV,Streaming Movies,Streaming Music,Unlimited Data,Contract,Paperless Billing,Payment Method,Monthly Charge,Total Charges,Total Refunds,Total Extra Data Charges,Total Long Distance Charges,Total Revenue,Satisfaction Score,Customer Status,Churn Label,Churn Score,CLTV,Churn Category,Churn Reason
0,8779-QRDMV,Male,78,No,Yes,No,No,0,United States,California,Los Angeles,90022,34.02381,-118.156582,68701,Q3,No,0,1,,No,0.0,No,Yes,DSL,8,No,No,Yes,No,No,Yes,No,No,Month-to-Month,Yes,Bank Withdrawal,39.65,39.65,0.0,20,0.0,59.65,3,Churned,Yes,91,5433,Competitor,Competitor offered more data
1,7495-OOKFY,Female,74,No,Yes,Yes,Yes,1,United States,California,Los Angeles,90063,34.044271,-118.185237,55668,Q3,Yes,1,8,Offer E,Yes,48.85,Yes,Yes,Fiber Optic,17,No,Yes,No,No,No,No,No,Yes,Month-to-Month,Yes,Credit Card,80.65,633.3,0.0,0,390.8,1024.1,3,Churned,Yes,69,5302,Competitor,Competitor made better offer
2,1658-BYGOY,Male,71,No,Yes,No,Yes,3,United States,California,Los Angeles,90065,34.108833,-118.229715,47534,Q3,No,0,18,Offer D,Yes,11.33,Yes,Yes,Fiber Optic,52,No,No,No,No,Yes,Yes,Yes,Yes,Month-to-Month,Yes,Bank Withdrawal,95.45,1752.55,45.61,0,203.94,1910.88,2,Churned,Yes,81,3179,Competitor,Competitor made better offer
3,4598-XLKNJ,Female,78,No,Yes,Yes,Yes,1,United States,California,Inglewood,90303,33.936291,-118.332639,27778,Q3,Yes,1,25,Offer C,Yes,19.76,No,Yes,Fiber Optic,12,No,Yes,Yes,No,Yes,Yes,No,Yes,Month-to-Month,Yes,Bank Withdrawal,98.5,2514.5,13.43,0,494.0,2995.07,2,Churned,Yes,88,5337,Dissatisfaction,Limited range of services
4,4846-WHAFZ,Female,80,No,Yes,Yes,Yes,1,United States,California,Whittier,90602,33.972119,-118.020188,26265,Q3,Yes,1,37,Offer C,Yes,6.33,Yes,Yes,Fiber Optic,14,No,No,No,No,No,No,No,Yes,Month-to-Month,Yes,Bank Withdrawal,76.5,2868.15,0.0,0,234.21,3102.36,2,Churned,Yes,67,2793,Price,Extra data charges


## 1.5. Feature Organization and Data Reordering

In [113]:
feature_mapping = {
    "Customer Info": [
        "Customer ID", "Gender", "Age", "Under 30", "Senior Citizen", "Married", "Dependents", "Number of Dependents"
    ],
    "Location Info": [
        "Country", "State", "City", "Zip Code", "Latitude", "Longitude", "Population"
    ],
    "Referral & Tenure": [
        "Quarter", "Referred a Friend", "Number of Referrals", "Tenure in Months", "Offer"
    ],
    "Services Signed Up": [
        "Phone Service", "Multiple Lines", "Internet Service", "Internet Type", "Unlimited Data"
    ],
    "Internet Features": [
        "Online Security", "Online Backup", "Device Protection Plan", "Premium Tech Support",
        "Streaming TV", "Streaming Movies", "Streaming Music"
    ],
    "Billing & Payment": [
        "Avg Monthly Long Distance Charges", "Avg Monthly GB Download", "Monthly Charge",
        "Total Charges", "Total Refunds", "Total Extra Data Charges",
        "Total Long Distance Charges", "Total Revenue", "Paperless Billing", "Payment Method"
    ],
    "Customer Scores": [
        "Satisfaction Score", "CLTV", "Churn Score"
    ],
    "Churn Info": [
        "Customer Status", "Churn Label", "Churn Category", "Churn Reason"
    ]
}
desired_order = [col for group in feature_mapping.values() for col in group]
data = data[desired_order]

## 1.6. Feature Removal

In [114]:
features_to_remove = [
    "Customer ID", "Under 30", "Dependents", "Country", "State", "Zip Code", "City", "Latitude", "Longitude",
    "Population", "Referred a Friend", "Internet Service", "Customer Status", "Churn Score", "Churn Category",
    "Churn Reason", "CLTV", "Satisfaction Score", "Quarter", "Avg Monthly Long Distance Charges",
    "Avg Monthly GB Download", "Monthly Charge", "Total Charges", "Total Refunds", "Total Extra Data Charges",
    "Total Long Distance Charges"
]
data = data.drop(columns=features_to_remove)

## 1.7. Outlier Handling for Numeric Features

In [115]:
# Handle outliers in numeric features
numeric_features = ["Age", "Number of Dependents", "Number of Referrals", "Tenure in Months", "Total Revenue"]
print("Summary statistics before outlier handling:\n", data[numeric_features].describe())
for feature in numeric_features:
    q1, q3 = data[feature].quantile([0.25, 0.75])
    iqr = q3 - q1
    lower_bound, upper_bound = q1 - 1.5 * iqr, q3 + 1.5 * iqr
    data[feature] = data[feature].clip(lower=lower_bound, upper=upper_bound)
print("Summary statistics after outlier handling:\n", data[numeric_features].describe())

Summary statistics before outlier handling:
            Age  Number of Dependents  Number of Referrals  Tenure in Months  \
count  7043.00               7043.00              7043.00           7043.00   
mean     46.51                  0.47                 1.95             32.39   
std      16.75                  0.96                 3.00             24.54   
min      19.00                  0.00                 0.00              1.00   
25%      32.00                  0.00                 0.00              9.00   
50%      46.00                  0.00                 0.00             29.00   
75%      60.00                  0.00                 3.00             55.00   
max      80.00                  9.00                11.00             72.00   

       Total Revenue  
count        7043.00  
mean         3034.38  
std          2865.20  
min            21.36  
25%           605.61  
50%          2108.64  
75%          4801.15  
max         11979.34  
Summary statistics after outlier han

## 1.8. Missing Value Imputation and Verification

In [116]:
# Fill missing values and verify none remain
data["Offer"] = data["Offer"].fillna("None")
data["Internet Type"] = data["Internet Type"].fillna("No Internet")
missing_values = data.isnull().sum()
if missing_values.sum() > 0:
    print("Warning: Missing values found after preprocessing:\n", missing_values[missing_values > 0])
else:
    print("No missing values after preprocessing.")

No missing values after preprocessing.


In [117]:
non_missing_percentage = data.notnull().mean() * 100
non_missing_df = pd.DataFrame(non_missing_percentage, columns=["Non-Missing Percentage"])
non_missing_df = non_missing_df.reset_index().rename(columns={"index": "Feature"})
non_missing_df.index = non_missing_df.index + 1
non_missing_df.style.set_caption("<b>Non-Missing Percentage of Features</b>").format({"Non-Missing Percentage": "{:.2f}"})

Unnamed: 0,Feature,Non-Missing Percentage
1,Gender,100.0
2,Age,100.0
3,Senior Citizen,100.0
4,Married,100.0
5,Number of Dependents,100.0
6,Number of Referrals,100.0
7,Tenure in Months,100.0
8,Offer,100.0
9,Phone Service,100.0
10,Multiple Lines,100.0


## 1.9. Feature Type Categorization and Display

In [118]:
categorical_features = [
    "Gender", "Senior Citizen", "Married", "Offer", "Phone Service", "Multiple Lines", "Internet Type",
    "Unlimited Data", "Online Security", "Online Backup", "Device Protection Plan", "Premium Tech Support",
    "Streaming TV", "Streaming Movies", "Streaming Music", "Paperless Billing", "Payment Method", "Churn Label"
]
numeric_features = ["Age", "Number of Dependents", "Number of Referrals", "Tenure in Months", "Total Revenue"]
max_length = max(len(categorical_features), len(numeric_features))
categorical_features_feature = categorical_features.copy()
numeric_features_feature = numeric_features.copy()
categorical_features_feature += [""] * (max_length - len(categorical_features))
numeric_features_feature += [""] * (max_length - len(numeric_features))
feature_types_df = pd.DataFrame({
    "Categorical Features": categorical_features_feature,
    "Numeric Features": numeric_features_feature
})
feature_types_df.index = feature_types_df.index + 1
feature_types_df.style.set_caption("<b>Categorization of Features by Type</b>")

Unnamed: 0,Categorical Features,Numeric Features
1,Gender,Age
2,Senior Citizen,Number of Dependents
3,Married,Number of Referrals
4,Offer,Tenure in Months
5,Phone Service,Total Revenue
6,Multiple Lines,
7,Internet Type,
8,Unlimited Data,
9,Online Security,
10,Online Backup,


In [119]:
categorical_features_data_types = pd.DataFrame(data[categorical_features].dtypes, columns=["Categorical Features' Data Types"])
categorical_features_data_types.style.set_caption("<b>Categorical Features' Data Types</b>")
numeric_features_data_types = pd.DataFrame(data[numeric_features].dtypes, columns=["Numeric Features' Data Types"])
numeric_features_data_types.style.set_caption("Numeric Features' Data Types")

Unnamed: 0,Numeric Features' Data Types
Age,int64
Number of Dependents,int64
Number of Referrals,float64
Tenure in Months,int64
Total Revenue,float64


## 1.10. Analysis of Categorical Feature Unique Values

In [120]:
rows = []
for feature in categorical_features:
    value_counts = data[feature].value_counts()
    first_row = True
    for value, count in value_counts.items():
        percentage = round((count / len(data)) * 100, 2)
        feature_name = feature if first_row else ""
        rows.append([feature_name, value, count, percentage])
        first_row = False
unique_values_df = pd.DataFrame(rows, columns=["Feature", "Unique Value", "Frequency", "Percentage"])
unique_values_df["Percentage"] = unique_values_df["Percentage"].round(2)
unique_values_df = (
    unique_values_df.style
    .set_caption("<b>Unique Values in Categorical Features</b>")
    .hide(axis="index")
)
unique_values_df

Feature,Unique Value,Frequency,Percentage
Gender,Male,3555,50.48
,Female,3488,49.52
Senior Citizen,No,5901,83.79
,Yes,1142,16.21
Married,No,3641,51.7
,Yes,3402,48.3
Offer,,3877,55.05
,Offer B,824,11.7
,Offer E,805,11.43
,Offer D,602,8.55


In [121]:
# Split data into train and test sets
from sklearn.model_selection import train_test_split
train_data, test_data = train_test_split(data, test_size=0.2, random_state=100)
print(f"Train set shape: {train_data.shape}, Test set shape: {test_data.shape}")

Train set shape: (5634, 23), Test set shape: (1409, 23)


## 1.11. Binary Feature Encoding

In [122]:
# Encode binary features for train and test
def encode_binary_features(datasets, features, mapping):
    for df in datasets:
        for feature in features:
            print(f"Unique values in {feature} before encoding:", df[feature].unique())
            df[feature] = df[feature].astype(str).map(mapping)
            df[feature] = df[feature].fillna(0)
binary_features = [
    "Gender", "Senior Citizen", "Married", "Phone Service", "Multiple Lines", "Unlimited Data",
    "Online Security", "Online Backup", "Device Protection Plan", "Premium Tech Support",
    "Streaming TV", "Streaming Movies", "Streaming Music", "Paperless Billing", "Churn Label"
]
binary_mapping = {"Yes": 1, "No": 0, "Male": 1, "Female": 0}
encode_binary_features(datasets=[train_data, test_data], features=binary_features, mapping=binary_mapping)

Unique values in Gender before encoding: ['Female' 'Male']
Unique values in Senior Citizen before encoding: ['No' 'Yes']
Unique values in Married before encoding: ['Yes' 'No']
Unique values in Phone Service before encoding: ['Yes' 'No']
Unique values in Multiple Lines before encoding: ['Yes' 'No']
Unique values in Unlimited Data before encoding: ['Yes' 'No']
Unique values in Online Security before encoding: ['Yes' 'No']
Unique values in Online Backup before encoding: ['No' 'Yes']
Unique values in Device Protection Plan before encoding: ['Yes' 'No']
Unique values in Premium Tech Support before encoding: ['Yes' 'No']
Unique values in Streaming TV before encoding: ['Yes' 'No']
Unique values in Streaming Movies before encoding: ['Yes' 'No']
Unique values in Streaming Music before encoding: ['Yes' 'No']
Unique values in Paperless Billing before encoding: ['Yes' 'No']
Unique values in Churn Label before encoding: ['No' 'Yes']
Unique values in Gender before encoding: ['Female' 'Male']
Unique 

## 1.12. Ordinal Feature Encoding and Type Conversion

In [123]:
# Encode ordinal features and ensure float types
def encode_ordinal_features(datasets, mappings):
    for df in datasets:
        for feature, mapping in mappings.items():
            print(f"Unique values in {feature} before encoding:", df[feature].unique())
            unmapped = set(df[feature].astype(str)) - set(mapping.keys())
            if unmapped:
                warnings.warn(f"Unmapped values in {feature}: {unmapped}")
            df[feature] = df[feature].astype(str).map(lambda x: mapping.get(x, 0))
            df[feature] = df[feature].fillna(0)
offer_mapping = {
    "None": 0, "Offer A": 1, "Offer B": 2, "Offer C": 3, "Offer D": 4, "Offer E": 5
}
internet_type_mapping = {
    "No Internet": 0, "DSL": 1, "Cable": 2, "Fiber Optic": 3
}
payment_method_mapping = {
    "Mailed Check": 1, "Bank Withdrawal": 2, "Credit Card": 3
}
ordinal_mappings = {
    "Offer": offer_mapping,
    "Internet Type": internet_type_mapping,
    "Payment Method": payment_method_mapping
}
encode_ordinal_features(datasets=[train_data, test_data], mappings=ordinal_mappings)
train_data = train_data.astype(float)
test_data = test_data.astype(float)
print("Train data types:\n", train_data.dtypes)
print("Test data types:\n", test_data.dtypes)

Unique values in Offer before encoding: ['None' 'Offer D' 'Offer A' 'Offer B' 'Offer C' 'Offer E']
Unique values in Internet Type before encoding: ['Fiber Optic' 'None' 'DSL' 'Cable']
Unique values in Payment Method before encoding: ['Bank Withdrawal' 'Credit Card' 'Mailed Check']
Unique values in Offer before encoding: ['Offer B' 'None' 'Offer C' 'Offer D' 'Offer E' 'Offer A']
Unique values in Internet Type before encoding: ['None' 'Cable' 'DSL' 'Fiber Optic']
Unique values in Payment Method before encoding: ['Credit Card' 'Bank Withdrawal' 'Mailed Check']
Train data types:
 Gender                    float64
Age                       float64
Senior Citizen            float64
Married                   float64
Number of Dependents      float64
Number of Referrals       float64
Tenure in Months          float64
Offer                     float64
Phone Service             float64
Multiple Lines            float64
Internet Type             float64
Unlimited Data            float64
Online Se



## 1.13. Numeric Feature Scaling

In [124]:
# Scale numeric features for train and test
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
numeric_features = ["Age", "Number of Dependents", "Number of Referrals", "Tenure in Months", "Total Revenue"]
train_data[numeric_features] = scaler.fit_transform(train_data[numeric_features])
test_data[numeric_features] = scaler.transform(test_data[numeric_features])
print("Train numeric features scaled:\n", train_data[numeric_features].describe())
print("Test numeric features scaled:\n", test_data[numeric_features].describe())

Train numeric features scaled:
             Age  Number of Dependents  Number of Referrals  Tenure in Months  \
count  5.63e+03                5634.0             5.63e+03          5.63e+03   
mean   1.08e-16                   0.0            -3.56e-17         -1.14e-16   
std    1.00e+00                   0.0             1.00e+00          1.00e+00   
min   -1.64e+00                   0.0            -6.88e-01         -1.29e+00   
25%   -8.64e-01                   0.0            -6.88e-01         -9.61e-01   
50%   -3.11e-02                   0.0            -6.88e-01         -1.45e-01   
75%    8.02e-01                   0.0             4.33e-01          9.15e-01   
max    1.99e+00                   0.0             2.12e+00          1.61e+00   

       Total Revenue  
count       5.63e+03  
mean        1.46e-16  
std         1.00e+00  
min        -1.06e+00  
25%        -8.48e-01  
50%        -3.18e-01  
75%         6.14e-01  
max         2.82e+00  
Test numeric features scaled:
          

<b><p style="font-size: 35px;">II. Second Phase: Causal Discovery and Inference</p></b>

## 2.1. Logging and Random Seed Configuration

In [125]:
# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[logging.FileHandler("causal_discovery.log"), logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

# Set random seeds for reproducibility
np.random.seed(100)
torch.manual_seed(100)
pl.seed_everything(100)

Seed set to 100


100

## 2.2.1 constraint_matrix

In [126]:
logger = logging.getLogger(__name__)

def create_constraint_matrix(node_names, tiers, specific_constraints=None):
    """
    Create a constraint matrix for causal discovery algorithms.
    
    Args:
        node_names (list): List of variable names.
        tiers (list): List of tier lists (e.g., [demographic, customer, ...]).
        specific_constraints (dict): Additional constraints (e.g., {"forbidden": [(src, dst), ...]}).
    
    Returns:
        np.ndarray: Constraint matrix (np.nan for allowed edges, 0.0 for forbidden).
    """
    num_nodes = len(node_names)
    node_name_to_idx = {name: i for i, name in enumerate(node_names)}
    constraint_matrix = np.full((num_nodes, num_nodes), np.nan, dtype=np.float32)
    
    # Set Churn Label as sink node (no outgoing edges)
    churn_idx = node_name_to_idx.get("Churn Label")
    if churn_idx is not None:
        constraint_matrix[churn_idx, :] = 0.0
    
    # Set demographic variables as root nodes (no incoming edges)
    for feature in tiers[0]:  # Tier 1: Demographic
        if feature in node_name_to_idx:
            feature_idx = node_name_to_idx[feature]
            constraint_matrix[:, feature_idx] = 0.0
    
    # Prevent edges within Tier 1
    for src in tiers[0]:
        for dst in tiers[0]:
            if src != dst and src in node_name_to_idx and dst in node_name_to_idx:
                constraint_matrix[node_name_to_idx[src], node_name_to_idx[dst]] = 0.0
    
    # Allow edges only from Tier N to Tier N+1
    for src_tier_idx, src_tier in enumerate(tiers[:-1]):
        dst_tier = tiers[src_tier_idx + 1]
        for src in src_tier:
            for dst in dst_tier:
                if src in node_name_to_idx and dst in node_name_to_idx:
                    constraint_matrix[node_name_to_idx[src], node_name_to_idx[dst]] = np.nan
        # Block edges to other tiers
        for other_tier_idx, other_tier in enumerate(tiers):
            if other_tier_idx != src_tier_idx + 1:
                for src in src_tier:
                    for dst in other_tier:
                        if src in node_name_to_idx and dst in node_name_to_idx:
                            constraint_matrix[node_name_to_idx[src], node_name_to_idx[dst]] = 0.0
    
    # Apply specific constraints (e.g., Gender → Service/Billing forbidden)
    if specific_constraints:
        for src, dst in specific_constraints.get("forbidden", []):
            if src in node_name_to_idx and dst in node_name_to_idx:
                constraint_matrix[node_name_to_idx[src], node_name_to_idx[dst]] = 0.0
        for src, dst in specific_constraints.get("allowed", []):
            if src in node_name_to_idx and dst in node_name_to_idx:
                constraint_matrix[node_name_to_idx[src], node_name_to_idx[dst]] = np.nan
    
    logger.info("Constraint matrix created with shape: %s", constraint_matrix.shape)
    return constraint_matrix, node_name_to_idx

## 2.2.2 validatation constraints

In [127]:


def validate_constraints(dag, node_name_to_idx, tiers, threshold=0.5):
    """
    Validate constraints on a DAG or adjacency matrix.
    
    Args:
        dag: NetworkX DiGraph or np.ndarray (adjacency/probability matrix).
        node_name_to_idx (dict): Mapping of node names to indices.
        tiers (list): List of tier lists.
        threshold (float): Probability threshold for matrix-based DAGs.
    
    Returns:
        list: List of constraint violation messages (empty if none).
    """
    violations = []
    num_nodes = len(node_name_to_idx)
    churn_idx = node_name_to_idx.get("Churn Label")
    
    # Convert adjacency matrix to DAG if needed
    if isinstance(dag, np.ndarray):
        G = nx.DiGraph()
        G.add_nodes_from(range(num_nodes))
        for i in range(num_nodes):
            for j in range(num_nodes):
                if dag[i, j] > threshold:
                    G.add_edge(i, j)
    else:
        G = dag
    
    # Check if Churn Label has outgoing edges
    if churn_idx is not None and any(G.has_edge(churn_idx, j) for j in range(num_nodes)):
        violations.append("Churn Label has outgoing edges")
    
    # Check if Tier 1 variables have incoming edges
    for var in tiers[0]:
        var_idx = node_name_to_idx.get(var)
        if var_idx is not None and any(G.has_edge(j, var_idx) for j in range(num_nodes)):
            violations.append(f"{var} has incoming edges")
    
    # Check for edges within Tier 1
    for src in tiers[0]:
        src_idx = node_name_to_idx.get(src)
        for dst in tiers[0]:
            dst_idx = node_name_to_idx.get(dst)
            if src != dst and src_idx is not None and dst_idx is not None and G.has_edge(src_idx, dst_idx):
                violations.append(f"T1↛T1 edge: {src}→{dst}")
    
    # Check for forbidden Gender → Service/Billing edges
    gender_idx = node_name_to_idx.get("Gender")
    if gender_idx is not None:
        for dst in tiers[2] + tiers[3]:  # Service + Billing
            dst_idx = node_name_to_idx.get(dst)
            if dst_idx is not None and G.has_edge(gender_idx, dst_idx):
                violations.append(f"Gender→{dst} edge exists")
    
    if not violations:
        logger.info("✅ All constraints validated successfully")
    else:
        logger.warning("❌ Constraint violations detected: %s", violations)
    
    return violations

    
    




## 2.2.3 save relations to text

In [128]:

def save_relations_to_text(dag, node_names, filename, threshold=0.2):
    """
    Save causal relationships to a text file and print to console with improved formatting.
    
    Args:
        dag: NetworkX DiGraph or np.ndarray (adjacency/probability matrix).
        node_names (list): List of node names.
        filename (str): Output text file name.
        threshold (float): Probability/weight threshold for matrix-based DAGs.
    """
    try:
        relations = []
        print("\n=== Causal Relationships ===")  # Header for console output
        if isinstance(dag, np.ndarray):
            for i in range(dag.shape[0]):
                for j in range(dag.shape[1]):
                    if dag[i, j] > threshold:
                        relations.append({
                            "source": node_names[i],
                            "destination": node_names[j],
                            "weight": float(dag[i, j])
                        })
                        # Improved console output with weight and separator
                        print(f"{node_names[i]} -> {node_names[j]} (weight: {dag[i, j]:.3f})")
                        print("---")
        else:
            for src, dst in dag.edges():
                weight = dag[src][dst].get("weight", 1.0)
                relations.append({
                    "source": node_names[src],
                    "destination": node_names[dst],
                    "weight": float(weight)
                })
                # Improved console output with weight and separator
                print(f"{node_names[src]} -> {node_names[dst]} (weight: {weight:.3f})")
                print("---")
        
        # Convert to DataFrame and save with better formatting
        relations_df = pd.DataFrame(relations)
        if not relations_df.empty:
            # Write to file with a header and formatted entries
            with open(filename, "w") as f:
                f.write("Source -> Destination (Weight)\n")
                f.write("=" * 30 + "\n")
                for _, row in relations_df.iterrows():
                    f.write(f"{row['source']:<25} -> {row['destination']:<25} ({row['weight']:.3f})\n")
                    f.write("-" * 30 + "\n")
            logger.info("Causal relationships saved to %s (%d relations)", filename, len(relations_df))
        else:
            with open(filename, "w") as f:
                f.write("No causal relationships found.\n")
            logger.warning("No causal relationships to save for %s", filename)
    
    except Exception as e:
        logger.error("Failed to save relations to %s: %s", filename, str(e))
        raise

## 2.2.4 visualize causal graph

In [129]:
def visualize_causal_graph(dag, node_names, filename="causal_graph.png"):
    """
    Visualize a causal graph with uniform edge lengths, focusing on nodes and relationships.
    
    Args:
        dag: NetworkX DiGraph or np.ndarray (adjacency matrix).
        node_names (list): List of node names.
        filename (str): Output file name.
    
    Returns:
        nx.DiGraph: Visualized graph.
    """
    if isinstance(dag, np.ndarray):
        G = nx.DiGraph(dag)
    else:
        G = dag

    # Relabel nodes with names
    G = nx.relabel_nodes(G, dict(enumerate(node_names)))

    # Use circular layout for uniform edge lengths and clear node placement
    plt.figure(figsize=(12, 10), dpi=300)
    pos = nx.circular_layout(G)  # Changed to circular_layout for uniform edge lengths

    # Draw the graph with uniform edges (no weight consideration)
    nx.draw_networkx_nodes(G, pos, node_size=1000, node_color="lightblue", alpha=0.8)
    nx.draw_networkx_edges(G, pos, width=1, alpha=0.6, arrowsize=15)  # Fixed width for all edges
    nx.draw_networkx_labels(G, pos, font_size=8, font_weight="bold")

    plt.title(f"Causal Graph ({filename.split('.')[0]})")
    plt.axis("off")
    plt.savefig(filename, format="png", bbox_inches="tight")
    plt.close()
    logger.info("Causal graph saved as %s", filename)
    return G

## 2.2.5 evaluating the learned graph

In [130]:
def analyze_structure_learning(dag, node_names, threshold=0.5):
    """
    Analyze the structure of a learned DAG.
    
    Args:
        dag: NetworkX DiGraph or np.ndarray (adjacency/probability matrix).
        node_names (list): List of node names.
        threshold (float): Probability threshold for matrix-based DAGs.
    
    Returns:
        dict: Structure learning metrics.
    """
    if isinstance(dag, np.ndarray):
        adj_matrix = (dag > threshold).astype(int)
        G = nx.DiGraph(adj_matrix)
    else:
        G = dag
        adj_matrix = nx.to_numpy_array(G, nodelist=range(len(node_names)))
    
    num_edges = np.sum(adj_matrix)
    num_nodes = len(node_names)
    graph_density = num_edges / (num_nodes * (num_nodes - 1)) if num_nodes > 1 else 0
    in_degree = np.sum(adj_matrix, axis=0)
    out_degree = np.sum(adj_matrix, axis=1)
    avg_in_degree = np.mean(in_degree)
    avg_out_degree = np.mean(out_degree)
    
    most_influential = [(node_names[i], out_degree[i]) for i in np.argsort(-out_degree)[:5]]
    most_affected = [(node_names[i], in_degree[i]) for i in np.argsort(-in_degree)[:5]]
    
    metrics = {
        "num_edges": num_edges,
        "graph_density": graph_density,
        "avg_in_degree": avg_in_degree,
        "avg_out_degree": avg_out_degree,
        "most_influential": most_influential,
        "most_affected": most_affected
    }
    
    logger.info("Structure Learning Metrics: %s", metrics)
    return metrics



def evaluate_causal_discovery(results, test_data, tiers, node_name_to_idx, threshold=0.5):
    """
    Evaluate causal discovery algorithms on test data, focusing on constraint violations.
    
    Args:
        results (dict): Results from run_causal_discovery_pipeline.
        test_data (pd.DataFrame): Test data.
        tiers (list): List of tier lists.
        node_name_to_idx (dict): Mapping of node names to indices.
        threshold (float): Probability threshold for matrix-based DAGs.
    
    Returns:
        pd.DataFrame: Evaluation metrics for each algorithm.
    """
    logger.info("Evaluating constraint violations on test data...")
    evaluation_metrics = {}
    
    # Preprocess test data to handle issues
    constant_cols = [col for col in test_data.columns if test_data[col].std() == 0]
    if constant_cols:
        logger.warning("Constant columns in test data: %s", constant_cols)
        test_data = test_data.drop(columns=constant_cols)
    if test_data.isna().any().any():
        logger.warning("NaNs in test data, filling with mean")
        test_data = test_data.fillna(test_data.mean())
    
    for algo_name, result in results.items():
        if "error" in result:
            evaluation_metrics[algo_name] = {
                "constraint_violations": "N/A",
                "violation_details": "N/A"
            }
            continue
        
        try:
            dag = result["dag"]
            adj_matrix = result["adj_matrix"]
            
            # Adjust node_name_to_idx for LiNGAM (uses subset of features)
            algo_node_name_to_idx = node_name_to_idx
            if algo_name == "LiNGAM":
                continuous_features = [f for f in test_data.columns if f in ["Age", "Number of Dependents", "Number of Referrals", "Tenure in Months", "Total Revenue"]]
                algo_node_name_to_idx = {name: i for i, name in enumerate(continuous_features)}
                # Update tiers for LiNGAM
                lingam_tiers = [
                    [f for f in tiers[0] if f in continuous_features],
                    [f for f in tiers[1] if f in continuous_features],
                    [f for f in tiers[3] if f in continuous_features]
                ]
            else:
                lingam_tiers = tiers
            
            # Validate constraints
            violations = validate_constraints(dag if algo_name != "LiNGAM" else adj_matrix, 
                                           algo_node_name_to_idx, lingam_tiers, threshold)
            
            evaluation_metrics[algo_name] = {
                "constraint_violations": len(violations),
                "violation_details": violations if violations else "None"
            }
            
            logger.info("%s: %d constraint violations on test data: %s",
                        algo_name, len(violations), violations if violations else "None")
        
        except Exception as e:
            logger.error("Evaluation failed for %s: %s", algo_name, str(e))
            evaluation_metrics[algo_name] = {
                "constraint_violations": "N/A",
                "violation_details": "N/A"
            }
    
    # Create summary DataFrame
    summary = pd.DataFrame(evaluation_metrics).T
    logger.info("Constraint violation summary:\n%s", summary)
    
    return summary

## Algorithms

## 2.3.1. CausalDiscoveryAlgorithm Base Class


In [131]:
class CausalDiscoveryAlgorithm(ABC):
    """Base class for causal discovery algorithms."""
    
    @abstractmethod
    def fit(self, data, constraint_matrix, node_names, node_name_to_idx, tiers, output_dir):
        """
        Fit the causal discovery algorithm.
        
        Args:
            data (pd.DataFrame): Input data.
            constraint_matrix (np.ndarray): Constraint matrix.
            node_names (list): List of node names.
            node_name_to_idx (dict): Mapping of node names to indices.
            tiers (list): List of tier lists.
        
        Returns:
            dict: Results including DAG, adjacency matrix, metrics, and violations.
        """
        pass

## 2.3.2. DECI Algorithm


In [132]:
class DECIAlgorithm(CausalDiscoveryAlgorithm):
    def fit(self, data, constraint_matrix, node_names, node_name_to_idx, tiers, output_dir):
        logger.info("Running DECI algorithm...")
        try:
            # Load variables.json
            with fsspec.open("data/variables.json", mode="r", encoding="utf-8") as f:
                variables = json.load(f)["variables"]
            
            # Validate columns
            expected_columns = [var["name"] for var in variables]
            if set(expected_columns) != set(data.columns):
                raise ValueError(f"Columns mismatch: {set(data.columns)} vs {set(expected_columns)}")
            
            # Prepare data module
            data_module = BasicDECIDataModule(
                data,
                variables=[Variable.from_dict(d) for d in variables],
                batch_size=128,
                normalize=True
            )
            
            # Initialize DECI module
            lightning_module = DECIModule(
                noise_dist=ContinuousNoiseDist.GAUSSIAN,
                prior_sparsity_lambda=100.0,
                init_rho=30.0,
                init_alpha=0.20,
                auglag_config=AugLagLRConfig(
                    max_inner_steps=1500,
                    max_outer_steps=8,
                    lr_init_dict={
                        "icgnn": 0.00076,
                        "vardist": 0.0098,
                        "functional_relationships": 3e-4,
                        "noise_dist": 0.0070,
                    }
                )
            )
            lightning_module.constraint_matrix = torch.tensor(constraint_matrix)
            
            # Train
            trainer = pl.Trainer(
                accelerator="gpu" if torch.cuda.is_available() else "cpu",
                devices=1,
                max_epochs=10,
                callbacks=[TQDMProgressBar(refresh_rate=19)],
                enable_checkpointing=False
            )
            trainer.fit(lightning_module, datamodule=data_module)
            
            # Save model
            torch.save(lightning_module.sem_module, "deci.pt")
            
            # Compute probability matrix
            logits_exist = lightning_module.sem_module.adjacency_module.adjacency_distribution.logits_exist
            logits_orient = lightning_module.sem_module.adjacency_module.adjacency_distribution.logits_orient
            
            def fill_triangular(vec, upper=False):
                n = int(np.sqrt(2 * len(vec))) + 1
                if upper:
                    return vec.new_zeros(n, n).triu(1).masked_scatter_(
                        torch.triu(torch.ones(n, n, device=vec.device), 1).bool(), vec
                    )
                return vec.new_zeros(n, n).tril(-1).masked_scatter_(
                        torch.tril(torch.ones(n, n, device=vec.device), -1).bool(), vec
                    )
            
            neg_theta = fill_triangular(logits_orient, upper=True) - fill_triangular(logits_orient, upper=False)
            logits_matrix = -torch.logsumexp(torch.stack([-logits_exist, neg_theta, neg_theta - logits_exist], dim=-1), dim=-1)
            prob_matrix = 1 / (1 + np.exp(-logits_matrix.cpu().detach().numpy()))
            prob_matrix = prob_matrix * np.isnan(constraint_matrix)
            
            # Validate constraints
            violations = validate_constraints(prob_matrix, node_name_to_idx, tiers)
            
            # Create a directed graph for visualization
            G = nx.DiGraph()
            for i in range(len(node_names)):
                G.add_node(i)
            for i in range(len(node_names)):
                for j in range(len(node_names)):
                    if prob_matrix[i, j] > 0.5:  # Threshold for edge existence
                        G.add_edge(i, j)
            
            # Visualize and save results
            G_viz = visualize_causal_graph(G, node_names, os.path.join(output_dir, "deci_graph.png"))
            save_relations_to_text(prob_matrix, node_names, os.path.join(output_dir, "deci_relations.txt"))
            
            # Analyze structure
            metrics = analyze_structure_learning(prob_matrix, node_names)
            
            return {
                "dag": G,
                "adj_matrix": prob_matrix,
                "metrics": metrics,
                "violations": violations
            }
        
        except Exception as e:
            logger.error("DECI failed: %s", str(e))
            raise


## 2.3.3. LiNGAM Algorithm


In [133]:

class LiNGAMAlgorithm(CausalDiscoveryAlgorithm):
    def fit(self, data, constraint_matrix, node_names, node_name_to_idx, tiers, output_dir):
        logger.info("Running LiNGAM algorithm...")
        try:
            # Filter continuous features
            continuous_features = [f for f in node_names if f in ["Age", "Number of Dependents", "Number of Referrals", "Tenure in Months", "Total Revenue"]]
            lingam_data = data[continuous_features].copy()
            
            # Remove constant columns
            constant_cols = [col for col in lingam_data.columns if lingam_data[col].std() == 0]
            if constant_cols:
                logger.warning("Removing constant columns: %s", constant_cols)
                lingam_data = lingam_data.drop(columns=constant_cols)
                continuous_features = [f for f in continuous_features if f not in constant_cols]
            
            # Validate data
            if lingam_data.isna().any().any():
                raise ValueError("LiNGAM data contains NaNs")
            
            # Create LiNGAM-specific constraint matrix
            lingam_node_to_idx = {name: i for i, name in enumerate(continuous_features)}
            lingam_constraint_matrix = np.full((len(continuous_features), len(continuous_features)), -1, dtype=np.int32)
            
            lingam_tiers = [
                [f for f in tiers[0] if f in continuous_features],  # Demographic
                [f for f in tiers[1] if f in continuous_features],  # Customer
                [f for f in tiers[3] if f in continuous_features]   # Billing
            ]
            
            if "Total Revenue" in lingam_node_to_idx:
                lingam_constraint_matrix[lingam_node_to_idx["Total Revenue"], :] = 0
            for feature in lingam_tiers[0]:
                lingam_constraint_matrix[:, lingam_node_to_idx[feature]] = 0
            for src_tier_idx, src_tier in enumerate(lingam_tiers[:-1]):
                dst_tier = lingam_tiers[src_tier_idx + 1]
                for src in src_tier:
                    for dst in dst_tier:
                        lingam_constraint_matrix[lingam_node_to_idx[src], lingam_node_to_idx[dst]] = -1
                for other_tier_idx, other_tier in enumerate(lingam_tiers):
                    if other_tier_idx != src_tier_idx + 1:
                        for src in src_tier:
                            for dst in other_tier:
                                lingam_constraint_matrix[lingam_node_to_idx[src], lingam_node_to_idx[dst]] = 0
            for src in lingam_tiers[0]:
                for dst in lingam_tiers[0]:
                    if src != dst:
                        lingam_constraint_matrix[lingam_node_to_idx[src], lingam_node_to_idx[dst]] = 0
            
            # Fit LiNGAM
            model = lingam.DirectLiNGAM(prior_knowledge=lingam_constraint_matrix)
            model.fit(lingam_data)
            
            adj_matrix = model.adjacency_matrix_
            
            # Validate constraints
            violations = validate_constraints(adj_matrix, lingam_node_to_idx, lingam_tiers)
            
            # Create a directed graph for visualization
            G = nx.DiGraph()
            for i in range(len(continuous_features)):
                G.add_node(i)
            for i in range(len(continuous_features)):
                for j in range(len(continuous_features)):
                    if adj_matrix[i, j] != 0:
                        G.add_edge(i, j)
                        
            # Visualize and save results
            G_viz = visualize_causal_graph(G, continuous_features, os.path.join(output_dir, "lingam_graph.png"))
            save_relations_to_text(adj_matrix, continuous_features, os.path.join(output_dir, "lingam_relations.txt"))
            
            # Analyze structure
            metrics = analyze_structure_learning(adj_matrix, continuous_features)
            
            return {
                "dag": G,
                "adj_matrix": adj_matrix,
                "metrics": metrics,
                "violations": violations
            }
        
        except Exception as e:
            logger.error("LiNGAM failed: %s", str(e))
            raise

## 2.3.4 NOTEARS Algorithm


In [134]:
class NOTEARSAlgorithm(CausalDiscoveryAlgorithm):
    def fit(self, data, constraint_matrix, node_names, node_name_to_idx, tiers, output_dir):
        logger.info("Running NOTEARS algorithm...")
        try:
            X = data.values
            stds = np.std(X, axis=0)
            means = np.mean(X, axis=0)
            X_standardized = np.where(stds != 0, (X - means) / stds, 0)
            
            def notears_with_constraints(X, constraint_matrix, lambda1=0.01, max_iter=200, h_tol=1e-8, rho_max=1e+16, w_threshold=0.1):
                n, d = X.shape
                mask = 1.0 - np.isnan(constraint_matrix).astype(float)
                
                def _h(w):
                    W = w.reshape((d, d))
                    M = np.eye(d) + W * W / d
                    return np.trace(slin.expm(M)) - d
                
                def _func(w):
                    W = w.reshape((d, d))
                    W = W * (1.0 - mask)
                    R = X - X @ W
                    loss = 0.5 / n * np.sum(R * R)
                    l1_penalty = lambda1 * np.sum(np.abs(W))
                    return loss + l1_penalty
                
                def _grad(w):
                    W = w.reshape((d, d))
                    W = W * (1.0 - mask)
                    R = X - X @ W
                    G = -1.0 / n * X.T @ R
                    G_l1 = lambda1 * np.sign(W)
                    G = (G + G_l1) * (1.0 - mask)
                    return G.flatten()
                
                def _h_grad(w):
                    W = w.reshape((d, d))
                    M = np.eye(d) + W * W / d
                    E = slin.expm(M)
                    G = E.T * (2 * W / d)
                    G = G * (1.0 - mask)
                    return G.flatten()
                
                w_est = np.zeros(d * d)
                rho, alpha, h = 1.0, 0.0, np.inf
                for _ in range(max_iter):
                    w_new = sopt.minimize(
                        lambda w: _func(w) + 0.5 * rho * _h(w) ** 2 + alpha * _h(w),
                        w_est,
                        method='L-BFGS-B',
                        jac=lambda w: _grad(w) + rho * _h(w) * _h_grad(w) + alpha * _h_grad(w),
                        options={'ftol': 1e-6, 'gtol': 1e-6}
                    ).x
                    h_new = _h(w_new)
                    if abs(h_new) <= h_tol or rho >= rho_max:
                        break
                    if abs(h_new) > 0.25 * abs(h):
                        rho *= 10
                    alpha += rho * h_new
                    w_est, h = w_new, h_new
                
                W_est = w_est.reshape((d, d))
                W_est = W_est * (1.0 - mask)
                W_est[np.abs(W_est) < w_threshold] = 0
                
                G = nx.DiGraph(W_est)
                while not nx.is_directed_acyclic_graph(G):
                    try:
                        cycle = nx.find_cycle(G)
                        min_weight = float('inf')
                        min_edge = None
                        for u, v in cycle:
                            if abs(W_est[u, v]) < min_weight:
                                min_weight = abs(W_est[u, v])
                                min_edge = (u, v)
                        if min_edge:
                            G.remove_edge(*min_edge)
                            W_est[min_edge[0], min_edge[1]] = 0
                    except nx.NetworkXNoCycle:
                        break
                
                return W_est
            
            adj_matrix = notears_with_constraints(X_standardized, constraint_matrix, lambda1=0.01)
            
            # Validate constraints
            violations = validate_constraints(adj_matrix, node_name_to_idx, tiers)
            
            # Create a directed graph for visualization
            G = nx.DiGraph(adj_matrix)
            
            # Visualize and save results
            G_viz = visualize_causal_graph(G, node_names, os.path.join(output_dir, "notears_graph.png"))
            save_relations_to_text(adj_matrix, node_names, os.path.join(output_dir, "notears_relations.txt"))
            
            # Analyze structure
            metrics = analyze_structure_learning(adj_matrix, node_names)
            
            return {
                "dag": G,
                "adj_matrix": adj_matrix,
                "metrics": metrics,
                "violations": violations
            }
        
        except Exception as e:
            logger.error("NOTEARS failed: %s", str(e))
            raise



## 2.3.5 PC-Algorithm

In [135]:


class PCGINAlgorithm(CausalDiscoveryAlgorithm):
    def fit(self, data, constraint_matrix, node_names, node_name_to_idx, tiers, output_dir):
        logger.info("Running PC-GIN algorithm...")
        try:
            # Encode categorical columns
            categorical_cols = [col for col in data.columns if col in ["Gender", "Internet Type", "Offer", "Payment Method"]]
            encoded_data = data.copy()
            for col in categorical_cols:
                le = LabelEncoder()
                encoded_data[col] = le.fit_transform(encoded_data[col].astype(str))
            
            def gin_test(X, Y, Z=None, alpha=0.05):
                n = len(X)
                if Z is None or Z.shape[1] == 0:
                    corr, p_value = stats.pearsonr(X, Y)
                    return p_value
                model_x = LinearRegression().fit(Z, X)
                residuals_x = X - model_x.predict(Z)
                model_y = LinearRegression().fit(Z, Y)
                residuals_y = Y - model_y.predict(Z)
                corr, p_value = stats.pearsonr(residuals_x, residuals_y)
                return p_value
            
            def pc_gin(data, constraint_matrix, alpha=0.01):
                n = data.shape[1]
                skeleton = nx.Graph()
                skeleton.add_nodes_from(range(n))
                separating_sets = {}
                
                for i in range(n):
                    for j in range(i + 1, n):
                        if np.isnan(constraint_matrix[i, j]) or np.isnan(constraint_matrix[j, i]):
                            skeleton.add_edge(i, j)
                
                for d in range(n):
                    edges = list(skeleton.edges())
                    for i, j in edges:
                        if not skeleton.has_edge(i, j):
                            continue
                        adj_i = set(skeleton.neighbors(i)) - {j}
                        if len(adj_i) >= d:
                            for subset in itertools.combinations(adj_i, d):
                                subset_list = list(subset)
                                conditioning_set = data[:, subset_list] if subset_list else None
                                p_val = gin_test(
                                    data[:, i], data[:, j],
                                    conditioning_set.reshape(data.shape[0], -1) if conditioning_set is not None else None,
                                    alpha=alpha
                                )
                                if p_val > alpha:
                                    skeleton.remove_edge(i, j)
                                    separating_sets[(i, j)] = subset
                                    separating_sets[(j, i)] = subset
                                    break
                
                dag = nx.DiGraph()
                dag.add_nodes_from(range(n))
                for i, j in skeleton.edges():
                    if np.isnan(constraint_matrix[i, j]) and not np.isnan(constraint_matrix[j, i]):
                        dag.add_edge(i, j)
                    elif np.isnan(constraint_matrix[j, i]) and not np.isnan(constraint_matrix[i, j]):
                        dag.add_edge(j, i)
                    else:
                        dag.add_edge(i, j)
                        dag.add_edge(j, i)
                
                for i in range(n):
                    for j in range(n):
                        if i == j or not dag.has_edge(i, j):
                            continue
                        for k in range(n):
                            if k == i or k == j:
                                continue
                            if dag.has_edge(k, j) and not skeleton.has_edge(i, k):
                                if ((i, k) in separating_sets and j not in separating_sets[(i, k)]) or \
                                   ((k, i) in separating_sets and j not in separating_sets[(k, i)]):
                                    if dag.has_edge(j, i):
                                        dag.remove_edge(j, i)
                                    if dag.has_edge(j, k):
                                        dag.remove_edge(j, k)
                
                for i, j in list(dag.edges()):
                    if dag.has_edge(j, i):
                        if np.isnan(constraint_matrix[i, j]) and not np.isnan(constraint_matrix[j, i]):
                            dag.remove_edge(j, i)
                        elif np.isnan(constraint_matrix[j, i]) and not np.isnan(constraint_matrix[i, j]):
                            dag.remove_edge(i, j)
                        else:
                            dag.remove_edge(j, i)
                
                return dag
            
            dag = pc_gin(encoded_data.values, constraint_matrix, alpha=0.05)
            
            # Create adjacency matrix from DAG
            adj_matrix = nx.to_numpy_array(dag, nodelist=range(len(node_names)))
            
            # Validate constraints
            violations = validate_constraints(dag, node_name_to_idx, tiers)
            
            # Visualize and save results
            G_viz = visualize_causal_graph(dag, node_names, os.path.join(output_dir, "pcgin_graph.png"))
            save_relations_to_text(dag, node_names, os.path.join(output_dir, "pcgin_relations.txt"))
            
            # Analyze structure
            metrics = analyze_structure_learning(dag, node_names)
            
            return {
                "dag": dag,
                "adj_matrix": adj_matrix,
                "metrics": metrics,
                "violations": violations
            }
        
        except Exception as e:
            logger.error("PC-GIN failed: %s", str(e))
            raise



## 2.3.6 GRaSPAlgorithm

In [136]:
class GRaSPAlgorithm(CausalDiscoveryAlgorithm):
    def fit(self, data, constraint_matrix, node_names, node_name_to_idx, tiers, output_dir):
        logger.info("Running GRaSP algorithm...")
        try:
            X = data.values
            stds = np.std(X, axis=0)
            means = np.mean(X, axis=0)
            X_standardized = np.where(stds != 0, (X - means) / stds, 0)
            
            def grasp_with_constraints(X, constraint_matrix, lambda1=0.01, max_iter=200, h_tol=1e-8, rho_max=1e+16, w_threshold=0.1):
                n, d = X.shape
                mask = 1.0 - np.isnan(constraint_matrix).astype(float)
                
                def _h(w):
                    W = w.reshape((d, d))
                    M = np.eye(d) + W * W / d
                    return np.trace(slin.expm(M)) - d
                
                def _func(w, rho, alpha):
                    W = w.reshape((d, d))
                    W = W * (1.0 - mask)
                    R = X - X @ W
                    loss = 0.5 / n * np.sum(R * R)
                    l1_penalty = lambda1 * np.sum(np.abs(W))
                    h_val = _h(w)
                    return loss + l1_penalty + 0.5 * rho * h_val ** 2 + alpha * h_val
                
                def _grad(w, rho, alpha):
                    W = w.reshape((d, d))
                    W = W * (1.0 - mask)
                    R = X - X @ W
                    G_loss = -1.0 / n * X.T @ R
                    G_l1 = lambda1 * np.sign(W)
                    h_val = _h(w)
                    h_gradient = _h_grad(w).reshape((d, d))
                    G_acyclicity = (rho * h_val + alpha) * h_gradient
                    G = (G_loss + G_l1 + G_acyclicity) * (1.0 - mask)
                    return G.flatten()
                
                def _h_grad(w):
                    W = w.reshape((d, d))
                    M = np.eye(d) + W * W / d
                    E = slin.expm(M)
                    G = E.T * (2 * W / d)
                    G = G * (1.0 - mask)
                    return G.flatten()
                
                w_est = np.zeros(d * d)
                rho, alpha, h = 1.0, 0.0, np.inf
                for _ in range(max_iter):
                    w_new = sopt.minimize(
                        lambda w: _func(w, rho, alpha),
                        w_est,
                        method='L-BFGS-B',
                        jac=lambda w: _grad(w, rho, alpha),
                        options={'ftol': 1e-6, 'gtol': 1e-6}
                    ).x
                    h_new = _h(w_new)
                    if abs(h_new) <= h_tol or rho >= rho_max:
                        break
                    if abs(h_new) > 0.25 * abs(h):
                        rho *= 10
                    alpha += rho * h_new
                    w_est, h = w_new, h_new
                
                W_est = w_est.reshape((d, d))
                W_est = W_est * (1.0 - mask)
                W_est[np.abs(W_est) < w_threshold] = 0
                
                G = nx.DiGraph(W_est)
                while not nx.is_directed_acyclic_graph(G):
                    try:
                        cycle = nx.find_cycle(G)
                        min_weight = float('inf')
                        min_edge = None
                        for u, v in cycle:
                            if abs(W_est[u, v]) < min_weight:
                                min_weight = abs(W_est[u, v])
                                min_edge = (u, v)
                        if min_edge:
                            G.remove_edge(*min_edge)
                            W_est[min_edge[0], min_edge[1]] = 0
                    except nx.NetworkXNoCycle:
                        break
                
                return W_est
            
            adj_matrix = grasp_with_constraints(X_standardized, constraint_matrix)
            
            # Validate constraints
            violations = validate_constraints(adj_matrix, node_name_to_idx, tiers)
            
            # Create a directed graph for visualization
            G = nx.DiGraph(adj_matrix)
            
            # Visualize and save results
            G_viz = visualize_causal_graph(G, node_names, os.path.join(output_dir, "grasp_graph.png"))
            save_relations_to_text(adj_matrix, node_names, os.path.join(output_dir, "grasp_relations.txt"))
            
            # Analyze structure
            metrics = analyze_structure_learning(adj_matrix, node_names)
            
            return {
                "dag": G,
                "adj_matrix": adj_matrix,
                "metrics": metrics,
                "violations": violations
            }
        
        except Exception as e:
            logger.error("GRaSP failed: %s", str(e))
            raise

## 2.4.1. run_causal_discovery_pipeline Function


In [137]:
logger = logging.getLogger(__name__)

def run_causal_discovery_pipeline(train_data, tiers, specific_constraints=None, output_dir="causal_discovery_output"):
    """
    Run the causal discovery pipeline for all algorithms.
    
    Args:
        train_data (pd.DataFrame): Training data.
        tiers (list): List of tier lists.
        specific_constraints (dict): Additional constraints.
        output_dir (str): Directory to save output files.
    
    Returns:
        dict: Results for each algorithm.
    """
    logger.info("Starting causal discovery pipeline...")
    
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)
    
    # Create constraint matrix
    node_names = list(train_data.columns)
    constraint_matrix, node_name_to_idx = create_constraint_matrix(node_names, tiers, specific_constraints)
    
    # Initialize algorithms
    algorithms = {
        "DECI": DECIAlgorithm(),
        "LiNGAM": LiNGAMAlgorithm(),
        "PC-GIN": PCGINAlgorithm(),
        "NOTEARS": NOTEARSAlgorithm(),
        "GRaSP": GRaSPAlgorithm()
    }
    
    # Run algorithms
    results = {}
    for algo_name, algo in algorithms.items():
        try:
            logger.info("Executing %s...", algo_name)
            print(f"\n{algo_name} Causal Relationships:")
            result = algo.fit(train_data, constraint_matrix, node_names, node_name_to_idx, tiers, output_dir)
            results[algo_name] = result
        except Exception as e:
            logger.error("%s failed: %s", algo_name, str(e))
            results[algo_name] = {"error": str(e)}
    
    # Summarize results
    summary = pd.DataFrame({
        algo_name: {
            "num_edges": result["metrics"]["num_edges"] if "metrics" in result else "N/A",
            "graph_density": result["metrics"]["graph_density"] if "metrics" in result else "N/A",
            "violations": len(result["violations"]) if "violations" in result else "N/A"
        } for algo_name, result in results.items()
    }).T
    logger.info("Summary of results:\n%s", summary)
    
    return results

## 2.4.2. Tiers Definition


In [138]:
# Define tiers and specific constraints
tiers = [
    ["Gender", "Age", "Senior Citizen", "Married", "Number of Dependents"],  # Tier 1: Demographic
    ["Number of Referrals", "Tenure in Months", "Offer"],                    # Tier 2: Customer
    ["Phone Service", "Multiple Lines", "Internet Type", "Unlimited Data",
     "Online Security", "Online Backup", "Device Protection Plan",
     "Premium Tech Support", "Streaming TV", "Streaming Movies", "Streaming Music"],  # Tier 3: Service
    ["Total Revenue", "Paperless Billing", "Payment Method"],                # Tier 4: Billing
    ["Churn Label"]                                                         # Tier 5: Outcome
]

## 2.4.3. Specific Constraints


In [139]:
specific_constraints = {
    "forbidden": [
        ("Gender", dst) for dst in tiers[2] + tiers[3]
    ] + [
        ("Internet Type", dst) for dst in ["Unlimited Data", "Online Security", "Online Backup",
                                          "Device Protection Plan", "Premium Tech Support",
                                          "Streaming TV", "Streaming Movies", "Streaming Music"]
    ],
    "allowed": []
}


## 2.4.4. Pipeline Execution
## 2.4.5. Constraint Violations Evaluation

In [140]:
# Run the pipeline (assuming train_data is available from preprocessing)
results = run_causal_discovery_pipeline(train_data, tiers, specific_constraints, output_dir="causal_discovery_output")

# Evaluate constraint violations on test data
node_names = list(train_data.columns)
_, node_name_to_idx = create_constraint_matrix(node_names, tiers, specific_constraints)
evaluation_summary = evaluate_causal_discovery(results, test_data, tiers, node_name_to_idx)

INFO:__main__:Starting causal discovery pipeline...
INFO:__main__:Constraint matrix created with shape: (23, 23)
INFO:__main__:Executing DECI...
INFO:__main__:Running DECI algorithm...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type                  | Params
------------------------------------------------------
0 | auglag_loss | AugLagLossCalculator  | 0     
1 | sem_module  | SEMDistributionModule | 36.7 K
------------------------------------------------------
36.2 K    Trainable params
529       Non-trainable params
36.7 K    Total params
0.147     Total estimated model params size (MB)
c:\Users\aafz1\miniconda3\envs\project-env\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of th


DECI Causal Relationships:
Epoch 9: 100%|██████████| 44/44 [00:02<00:00, 20.48it/s, v_num=11]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 44/44 [00:02<00:00, 20.33it/s, v_num=11]

INFO:__main__:✅ All constraints validated successfully





INFO:__main__:Causal graph saved as causal_discovery_output\deci_graph.png
INFO:__main__:Causal relationships saved to causal_discovery_output\deci_relations.txt (25 relations)
INFO:__main__:Structure Learning Metrics: {'num_edges': 17, 'graph_density': 0.03359683794466403, 'avg_in_degree': 0.7391304347826086, 'avg_out_degree': 0.7391304347826086, 'most_influential': [('Tenure in Months', 8), ('Married', 2), ('Streaming TV', 1), ('Premium Tech Support', 1), ('Multiple Lines', 1)], 'most_affected': [('Total Revenue', 6), ('Paperless Billing', 1), ('Streaming Music', 1), ('Streaming Movies', 1), ('Number of Referrals', 1)]}
INFO:__main__:Executing LiNGAM...
INFO:__main__:Running LiNGAM algorithm...
INFO:__main__:✅ All constraints validated successfully



=== Causal Relationships ===
Married -> Number of Referrals (weight: 0.767)
---
Married -> Tenure in Months (weight: 0.672)
---
Tenure in Months -> Multiple Lines (weight: 0.661)
---
Tenure in Months -> Online Security (weight: 0.748)
---
Tenure in Months -> Online Backup (weight: 0.712)
---
Tenure in Months -> Device Protection Plan (weight: 0.723)
---
Tenure in Months -> Premium Tech Support (weight: 0.618)
---
Tenure in Months -> Streaming TV (weight: 0.683)
---
Tenure in Months -> Streaming Movies (weight: 0.635)
---
Tenure in Months -> Streaming Music (weight: 0.607)
---
Phone Service -> Total Revenue (weight: 0.384)
---
Multiple Lines -> Total Revenue (weight: 0.702)
---
Internet Type -> Total Revenue (weight: 0.370)
---
Internet Type -> Paperless Billing (weight: 0.691)
---
Unlimited Data -> Paperless Billing (weight: 0.396)
---
Online Security -> Total Revenue (weight: 0.623)
---
Online Backup -> Total Revenue (weight: 0.733)
---
Device Protection Plan -> Total Revenue (weight

INFO:__main__:Causal graph saved as causal_discovery_output\lingam_graph.png
INFO:__main__:Causal relationships saved to causal_discovery_output\lingam_relations.txt (2 relations)
INFO:__main__:Structure Learning Metrics: {'num_edges': 1, 'graph_density': 0.08333333333333333, 'avg_in_degree': 0.25, 'avg_out_degree': 0.25, 'most_influential': [('Tenure in Months', 1), ('Age', 0), ('Number of Referrals', 0), ('Total Revenue', 0)], 'most_affected': [('Total Revenue', 1), ('Age', 0), ('Number of Referrals', 0), ('Tenure in Months', 0)]}
INFO:__main__:Executing PC-GIN...
INFO:__main__:Running PC-GIN algorithm...



=== Causal Relationships ===
Number of Referrals -> Total Revenue (weight: 0.269)
---
Tenure in Months -> Total Revenue (weight: 0.852)
---

PC-GIN Causal Relationships:


INFO:__main__:✅ All constraints validated successfully
INFO:__main__:Causal graph saved as causal_discovery_output\pcgin_graph.png
INFO:__main__:Causal relationships saved to causal_discovery_output\pcgin_relations.txt (50 relations)
INFO:__main__:Structure Learning Metrics: {'num_edges': 50.0, 'graph_density': 0.09881422924901186, 'avg_in_degree': 2.1739130434782608, 'avg_out_degree': 2.1739130434782608, 'most_influential': [('Tenure in Months', 6.0), ('Offer', 6.0), ('Unlimited Data', 3.0), ('Number of Dependents', 3.0), ('Streaming Music', 3.0)], 'most_affected': [('Total Revenue', 11.0), ('Paperless Billing', 8.0), ('Payment Method', 7.0), ('Churn Label', 3.0), ('Offer', 3.0)]}
INFO:__main__:Executing NOTEARS...
INFO:__main__:Running NOTEARS algorithm...
  X_standardized = np.where(stds != 0, (X - means) / stds, 0)



=== Causal Relationships ===
Gender -> Offer (weight: 1.000)
---
Senior Citizen -> Offer (weight: 1.000)
---
Married -> Number of Referrals (weight: 1.000)
---
Married -> Tenure in Months (weight: 1.000)
---
Number of Dependents -> Number of Referrals (weight: 1.000)
---
Number of Dependents -> Tenure in Months (weight: 1.000)
---
Number of Dependents -> Offer (weight: 1.000)
---
Number of Referrals -> Internet Type (weight: 1.000)
---
Number of Referrals -> Online Security (weight: 1.000)
---
Tenure in Months -> Multiple Lines (weight: 1.000)
---
Tenure in Months -> Online Security (weight: 1.000)
---
Tenure in Months -> Online Backup (weight: 1.000)
---
Tenure in Months -> Device Protection Plan (weight: 1.000)
---
Tenure in Months -> Premium Tech Support (weight: 1.000)
---
Tenure in Months -> Streaming Movies (weight: 1.000)
---
Offer -> Multiple Lines (weight: 1.000)
---
Offer -> Online Security (weight: 1.000)
---
Offer -> Online Backup (weight: 1.000)
---
Offer -> Device Protec

INFO:__main__:✅ All constraints validated successfully
INFO:__main__:Causal graph saved as causal_discovery_output\notears_graph.png
INFO:__main__:Causal relationships saved to causal_discovery_output\notears_relations.txt (13 relations)
INFO:__main__:Structure Learning Metrics: {'num_edges': 1, 'graph_density': 0.001976284584980237, 'avg_in_degree': 0.043478260869565216, 'avg_out_degree': 0.043478260869565216, 'most_influential': [('Married', 1), ('Gender', 0), ('Paperless Billing', 0), ('Total Revenue', 0), ('Streaming Music', 0)], 'most_affected': [('Number of Referrals', 1), ('Gender', 0), ('Paperless Billing', 0), ('Total Revenue', 0), ('Streaming Music', 0)]}
INFO:__main__:Executing GRaSP...
INFO:__main__:Running GRaSP algorithm...
  X_standardized = np.where(stds != 0, (X - means) / stds, 0)



=== Causal Relationships ===
Married -> Number of Referrals (weight: 0.692)
---
Married -> Tenure in Months (weight: 0.366)
---
Tenure in Months -> Multiple Lines (weight: 0.326)
---
Tenure in Months -> Online Security (weight: 0.308)
---
Tenure in Months -> Online Backup (weight: 0.351)
---
Tenure in Months -> Device Protection Plan (weight: 0.350)
---
Tenure in Months -> Premium Tech Support (weight: 0.307)
---
Tenure in Months -> Streaming TV (weight: 0.272)
---
Tenure in Months -> Streaming Movies (weight: 0.283)
---
Tenure in Months -> Streaming Music (weight: 0.234)
---
Multiple Lines -> Total Revenue (weight: 0.235)
---
Internet Type -> Paperless Billing (weight: 0.285)
---
Online Backup -> Total Revenue (weight: 0.222)
---

GRaSP Causal Relationships:


INFO:__main__:✅ All constraints validated successfully
INFO:__main__:Causal graph saved as causal_discovery_output\grasp_graph.png
INFO:__main__:Causal relationships saved to causal_discovery_output\grasp_relations.txt (13 relations)
INFO:__main__:Structure Learning Metrics: {'num_edges': 1, 'graph_density': 0.001976284584980237, 'avg_in_degree': 0.043478260869565216, 'avg_out_degree': 0.043478260869565216, 'most_influential': [('Married', 1), ('Gender', 0), ('Paperless Billing', 0), ('Total Revenue', 0), ('Streaming Music', 0)], 'most_affected': [('Number of Referrals', 1), ('Gender', 0), ('Paperless Billing', 0), ('Total Revenue', 0), ('Streaming Music', 0)]}
INFO:__main__:Summary of results:
         num_edges  graph_density  violations
DECI          17.0       3.36e-02         0.0
LiNGAM         1.0       8.33e-02         0.0
PC-GIN        50.0       9.88e-02         0.0
NOTEARS        1.0       1.98e-03         0.0
GRaSP          1.0       1.98e-03         0.0
INFO:__main__:Constr


=== Causal Relationships ===
Married -> Number of Referrals (weight: 0.692)
---
Married -> Tenure in Months (weight: 0.366)
---
Tenure in Months -> Multiple Lines (weight: 0.326)
---
Tenure in Months -> Online Security (weight: 0.308)
---
Tenure in Months -> Online Backup (weight: 0.351)
---
Tenure in Months -> Device Protection Plan (weight: 0.350)
---
Tenure in Months -> Premium Tech Support (weight: 0.307)
---
Tenure in Months -> Streaming TV (weight: 0.272)
---
Tenure in Months -> Streaming Movies (weight: 0.283)
---
Tenure in Months -> Streaming Music (weight: 0.234)
---
Multiple Lines -> Total Revenue (weight: 0.235)
---
Internet Type -> Paperless Billing (weight: 0.285)
---
Online Backup -> Total Revenue (weight: 0.222)
---
