# "Adult" dataset
Predict whether income exceeds $50K/yr based on census data. Also known as "Census Income" dataset.
https://archive.ics.uci.edu/dataset/2/adult

In [1]:
import dice_ml
from dice_ml.utils import helpers
import pandas as pd
dataset = helpers.load_adult_income_dataset()
dataset.head()

Unnamed: 0,age,workclass,education,marital_status,occupation,race,gender,hours_per_week,income
0,28,Private,Bachelors,Single,White-Collar,White,Female,60,0
1,30,Self-Employed,Assoc,Married,Professional,White,Male,65,1
2,32,Private,Some-college,Married,White-Collar,White,Male,50,0
3,20,Private,Some-college,Single,Service,White,Female,35,0
4,41,Self-Employed,Some-college,Married,White-Collar,White,Male,50,0


In [2]:
# description of transformed features
adult_info = helpers.get_adult_data_info()
adult_info

{'age': 'age',
 'workclass': 'type of industry (Government, Other/Unknown, Private, Self-Employed)',
 'education': 'education level (Assoc, Bachelors, Doctorate, HS-grad, Masters, Prof-school, School, Some-college)',
 'marital_status': 'marital status (Divorced, Married, Separated, Single, Widowed)',
 'occupation': 'occupation (Blue-Collar, Other/Unknown, Professional, Sales, Service, White-Collar)',
 'race': 'white or other race?',
 'gender': 'male or female?',
 'hours_per_week': 'total work hours per week',
 'income': '0 (<=50K) vs 1 (>50K)'}

# Utils

## Creating a prompt collection for different purposes

### Discovery of Causal relationships

In [3]:
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain_community.chat_models import ChatOpenAI

llm = ChatOpenAI(model_name="gpt-4")
template= """
First, lets define some variables:
X: {A}
Y:{B}
Which cause-and-effect relationship is more likely?
1. changing X causes a change in Y.
-1. changing Y causes a change in X.
0.there is no direct causal relationship between X and Y.
Let’s work this out in a step by step way to be sure that we have the right answer. Then provide your final within the tags <Answer>1/-1/0</Answer>.
"""



prompt = PromptTemplate(
    input_variables=["A", "B"],
    template=template,
)

reasoning_chain = LLMChain(llm=llm, prompt=prompt)

In [4]:
import re
def extract_answer(text):
    # Regular expression pattern to find text between <Answer> and </Answer>
    pattern = "<Answer>(.+)</Answer>"
    
    # Use re.findall to find all occurrences that match the pattern
    matches = re.findall(pattern, text)
    
    return matches[0]

def extract_reasoning(text):
    # Split the string into a list of lines
    lines = text.split('\n')
    #delete the last line
    lines = lines[:-1]
    # Join the list back into a string
    return '\n'.join(lines)



In [5]:
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain_community.chat_models import ChatOpenAI

def discover_relationship(event1, event2):
    llm = ChatOpenAI(model_name="gpt-4")
    template= """
    First, lets define some variables:
    X: {A}
    Y:{B}
    Which cause-and-effect relationship is more likely?
    1. changing X causes a change in Y.
    -1. changing Y causes a change in X.
    0.there is no direct causal relationship between X and Y.
    Let’s work this out in a step by step way to be sure that we have the right answer. Then provide your final within the tags <Answer>1/-1/0</Answer>.
    """



    prompt = PromptTemplate(
        input_variables=["A", "B"],
        template=template,
    )

    reasoning_chain = LLMChain(llm=llm, prompt=prompt)

    response = reasoning_chain.run({
    'A': event1,
    'B': event2
    })
    return int(extract_answer(response)), extract_reasoning(response)

### Validation of reasoning consistency

In [6]:
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain_community.chat_models import ChatOpenAI

def check_relationship_consistency(event1, event2,llm_reasoning, llm_direction):
    llm = ChatOpenAI(model_name="gpt-4")
    #First we need to recovewr the original question and response
    template= """
    First, lets define some variables:
    X: {A}
    Y:{B}
    Which cause-and-effect relationship is more likely?
    1. changing X causes a change in Y.
    -1. changing Y causes a change in X.
    0.there is no direct causal relationship between X and Y.
    Let’s work this out in a step by step way to be sure that we have the right answer. Then provide your final within the tags <Answer>1/-1/0</Answer>.
    """
    question = template.format(A=event1, B=event2)

    previous_answer = llm_reasoning + '\n<Answer>' + str(llm_direction) + '<\Answer>'
    template2 = """Analyze the output from an AI assistant. Is the final answer consistent with the reasoning provided by the assistant? Give a final answer after reasoning this issue.
    Question:
    {question}
    AI assistant: {previous_answer}
    Explain your decision and then provide your final within the tags 
    <Answer>1/0/-1</Answer>."""

    prompt2 = PromptTemplate(
        input_variables=["question", "previous_answer"],
        template=template2,
    )

    autoconsistency_chain = LLMChain(llm=llm, prompt=prompt2)
    
    response = autoconsistency_chain.run({
    "question": question,
    "previous_answer": previous_answer
    })
    
    return int(extract_answer(response)), extract_reasoning(response)


### Mediators thought experiment

In [7]:
llm = ChatOpenAI(model_name="gpt-4")
template3 = """
A: {A}
B: {B}
MEDIATORS:
{MEDIATORS}

Given that there exists a causal relationship from A to B, and acknowledging that there are also causal pathways from A to MEDIATORS, and from these variables (MEDIATORS) to B, we aim to discern whether the causal influence of A on B is direct or mediated through these other variables.

Causal Framework Establishment: We start by affirming the existence of a causal relationship from A to B, acknowledging that A also influences intermediary variables, which in turn have their own effects on B.

Hypothesis of Direct vs. Indirect Causation: The core question is whether A's impact on B is direct, not merely transmitted through its effects on MEDIATORS. In other words, does A influence B independently of the pathways through MEDIATORS?

Intervention Thought Experiment:

Control for Mediators: Imagine we can control or "fix" the values of MEDIATORS, such that any change in these variables does not affect B. This setup mimics a scenario where the only pathway from A to B that can manifest is the direct one, if it exists.
Observation after Intervention: After fixing MEDIATORS, we then assess the effect of a change in A on B. If altering A still changes B even when MEDIATORS, are held constant, this indicates a direct causal relationship between A and B. Conversely, if changes in A no longer influence B once we control for MEDIATORS, the causal path from A to B is likely indirect, fully mediated by these variables.
Conclusion Drawing: Based on the observations from the intervention thought experiment, we can conclude whether the causal path from A to B is direct 1 (A directly influences B, independent of other variables) or indirect -1 (A's influence on B is mediated through MEDIATORS). 

Reason step by step andat the end answer 1 (direct), -1 (indirect) or 0 (not sure).

Reasoning:

Answer: <Answer>1/0</Answer>"""




prompt3 = PromptTemplate(
    input_variables=["A", "B","MEDIATORS"],
    template=template3,
)

direct_chain = LLMChain(llm=llm, prompt=prompt3)


In [8]:
A = 'income'


mediators = """
 'marital_status': 'marital status (Divorced, Married, Separated, Single, Widowed)',
 'occupation': 'occupation (Blue-Collar, Other/Unknown, Professional, Sales, Service, White-Collar)',
 'hours_per_week': 'total work hours per week'
 """

B = 'age'


response = direct_chain.run({
"A": A,
"B": B,
"MEDIATORS":mediators
})

print(response)

In order to determine if the causal relationship from income (A) to age (B) is direct or indirect, we need to assess the potential mediating effects of marital status, occupation, and hours worked per week.

1. Income can influence marital status – people with higher incomes may be more likely to get married or remain married. Marital status can also influence age, as individuals tend to get married at a certain age and can potentially live longer due to the benefits of companionship.

2. Income can influence occupation – individuals with higher incomes are more likely to have professional or white-collar jobs. Occupation can also influence age as certain jobs might have age restrictions or certain jobs might be more prevalent among different age groups.

3. Income can influence hours worked per week – individuals with higher incomes may work more hours per week. The number of hours worked per week can also influence age because certain age groups might have more or less available time

In [9]:
A = """ 'workclass': 'type of industry (Government, Other/Unknown, Private, Self-Employed)'"""

mediators = """
 'occupation': 'occupation (Blue-Collar, Other/Unknown, Professional, Sales, Service, White-Collar)',
 'hours_per_week': 'total work hours per week'
 """

B = 'income'


response = direct_chain.run({
"A": A,
"B": B,
"MEDIATORS":mediators
})

print(response)

Given the described scenario, it's hard to conclude definitively without empirical data. However, we can hypothesize possible outcomes based on general knowledge.

Workclass (A) could directly influence income (B) because typically different industries have different pay scales. For example, a job in the private sector might generally pay more than a job in the government sector. This would suggest a direct causal relationship.

On the other hand, the type of work (workclass) could influence one's occupation and hours per week (MEDIATORS), which in turn could affect income. For instance, certain industries might have more opportunities for high-paying occupations or may require more working hours per week, leading to higher income. In this case, the causal path would be indirect.

Given these potential scenarios, it's plausible that the causal relationship could be either direct or indirect. However, without specific data or more information, it is impossible to declare with certainty 

## Creating classes

In [10]:
class Event:
    """Represents an event with a name and a description."""
    def __init__(self, event_name: str, event_description: str) -> None:
        self.event_name = event_name
        self.event_description = event_description

    def __str__(self) -> str:
        return f"{self.event_name}: {self.event_description}"

    def to_dict(self):
        return {"event_name": self.event_name, "event_description": self.event_description}

    @classmethod
    def from_dict(cls, data):
        return cls(data["event_name"], data["event_description"])

In [11]:
class CausalRelation:
    def __init__(self, event1: Event, event2: Event):
        self.event1 = event1
        self.event2 = event2
        self.direction = None  
        self.reasoning = ""
        self.consistency_check = None
        self.consistency_reasoning = ""

    def discover(self) -> (str, int):
        try:
            # Assuming discover_relationship expects string representations
            self.direction, self.reasoning = discover_relationship(str(self.event1), str(self.event2))
        except Exception as e:
            self.reasoning = f"Error discovering relationship: {e}"
            self.direction = None
        return self.reasoning, self.direction

    def check_consistency(self) -> (str, bool):
        if self.direction is not None:
            try:
                # Assuming check_relationship_consistency is adjusted to handle Event objects or their string representations
                self.consistency_check, self.consistency_reasoning = check_relationship_consistency(str(self.event1), str(self.event2), self.reasoning, self.direction)
            except Exception as e:
                self.consistency_reasoning = f"Error checking consistency: {e}"
                self.consistency_check = None
        else:
            self.consistency_reasoning = "Discovery must be performed before checking consistency."
            self.consistency_check = None
        return self.consistency_reasoning, self.consistency_check
    
    def to_dict(self):
        return {
            "event1": self.event1.event_name,
            "event2": self.event2.event_name,
            "direction": self.direction,
            "reasoning": self.reasoning,
            "consistency_check": self.consistency_check,
            "consistency_reasoning": self.consistency_reasoning,
        }

    @staticmethod
    def from_dict(data, events):
        event1 = events[data["event1"]]
        event2 = events[data["event2"]]
        relation = CausalRelation(event1, event2)
        relation.direction = data["direction"]
        relation.reasoning = data["reasoning"]
        relation.consistency_check = data.get("consistency_check")
        relation.consistency_reasoning = data.get("consistency_reasoning")
        return relation


In [15]:
import json
from graphviz import Digraph
class CausalGraph:
    def __init__(self):
        self.events = {}  # Key: event name, Value: Event object
        self.relations = []  # List of CausalRelation objects
    def add_events(self, events: [Event]):
        """Add a list of events to the graph."""
        for event in events:
            if event.event_name in self.events:
                print(f"Warning: Event named '{event.event_name}' already exists. Skipping.")
            else:
                self.events[event.event_name] = event

    def discover_graph(self, check_consistency = False):
        """Automatically discover and add all possible causal relationships between events in the graph."""
        event_names = list(self.events.keys())
        for i, event_name1 in enumerate(event_names):
            for event_name2 in event_names[i + 1:]:
                #we create a CausalRelation class and discover if there are or not direct causal relations between the variables
                relation = CausalRelation(self.events[event_name1], self.events[event_name2])
                relation.discover()
                if check_consistency:
                    relation.check_consistency()
                self.relations.append(relation)

    def consistency(self):
        for relation in self.relations:
            relation.check_consistency()
            

    def save_to_json(self, file_path):
        graph_data = {
            "events": [event.to_dict() for event in self.events.values()],
            "relations": [relation.to_dict() for relation in self.relations],
        }
        with open(file_path, 'w') as file:
            json.dump(graph_data, file, indent=4)

    def load_from_json(self, file_path):
        with open(file_path, 'r') as file:
            graph_data = json.load(file)
        self.events = {data["event_name"]: Event.from_dict(data) for data in graph_data["events"]}
        self.relations = [CausalRelation.from_dict(data, self.events) for data in graph_data["relations"]]

    def plot_graph(self, filename='causal_graph', format='png', after_consistency = False):
        dot = Digraph(comment='The Causal Graph')
        
        # Add nodes (events)
        for event_name, event in self.events.items():
            dot.node(event_name, f'{event_name}')
        
        # Add edges (relations)
        for relation in self.relations:
            if after_consistency:
                direction = relation.consistency_check
            else:
                direction = relation.direction

            if direction >= 1:
                start = relation.event1.event_name
                end = relation.event2.event_name
                # Determine edge color based on consistency
                dot.edge(start, end, color = 'black')

            elif direction <= -1:
                end = relation.event1.event_name
                start = relation.event2.event_name
                # Determine edge color based on consistency
                dot.edge(start, end, color = 'black')

            elif direction == 0:
                continue
            else:
                start = relation.event1.event_name
                end = relation.event2.event_name
                dot.edge(start, end, color ='red', style = 'dashed',dir='none')
                continue

        # Render the graph to a file and optionally display it
        dot.render(filename, format=format, view=True)




# Causal Discovery

## Discovering the graph

In [13]:
# We will start by adding all the events and their descriptions from the adults dataset
adult_events = []
for i in range(len(dataset.columns)):
    adult_events.append(Event(dataset.columns[i],adult_info[dataset.columns[i]]))
print(adult_events)

[<__main__.Event object at 0x000001584171E2C0>, <__main__.Event object at 0x000001584173FB50>, <__main__.Event object at 0x000001584173F9A0>, <__main__.Event object at 0x000001584173D8D0>, <__main__.Event object at 0x000001584173E0E0>, <__main__.Event object at 0x000001584173F6A0>, <__main__.Event object at 0x000001584173F670>, <__main__.Event object at 0x000001584173F4F0>, <__main__.Event object at 0x000001584173F970>]


In [None]:
#We initialize the object
graph=CausalGraph()

#We add the list of events we created
graph.add_events(adult_events)

#We will check all the possible causal relationships
graph.discover_graph(check_consistency=True)

#plotting the DAG
graph.plot_graph('Adult_DAG3')
graph.plot_graph('Adult_DAG_checked3',after_consistency=True)
graph.save_to_json('adult_DAG3.json')

DAG after discovering relationships


![Alt text](./DAGs/Adult_DAG.png "DAG")

Dag after checking the consistency of the reasonings



![Alt text](./DAGs/Adult_DAG_checked.png "DAG")

## Loading the graph

In [69]:
graph=CausalGraph()
graph.load_from_json('adult_DAG.json')


In [70]:
graph.plot_graph('Adult_DAG')
graph.plot_graph('Adult_DAG_checked', after_consistency=True)

## Leverage 15 graphs

In [14]:
for i in range (4,16):
    graph=CausalGraph()

    #We add the list of events we created
    graph.add_events(adult_events)

    #We will check all the possible causal relationships
    graph.discover_graph(check_consistency=True)

    #plotting the DAG
    graph.plot_graph('Adult_DAG'+str(i))
    graph.plot_graph('Adult_DAG_checked'+str(i),after_consistency=True)
    graph.save_to_json('adult_DAG'+str(i)+'.json')

In [21]:
total_graph = CausalGraph()
total_graph.load_from_json('./DAGs/adult_DAG1.json')
for i in range (2,16):
    graph=CausalGraph()
    graph.load_from_json(f'./DAGs/adult_DAG{i}.json')
    for j in range(len(graph.relations)):
        relation = graph.relations[j]
        direction = relation.direction
        direction_c = relation.consistency_check

        total_graph.relations[j].direction = total_graph.relations[j].direction + direction
        total_graph.relations[j].consistency_check = total_graph.relations[j].consistency_check + direction_c


total_graph.save_to_json('adult_DAG_sum.json')

In [26]:
import numpy as np
final_graph = CausalGraph()
final_graph.load_from_json('./DAGs/adult_DAG_sum.json')

threshold = 7.5
for relation in final_graph.relations:

    if relation.direction >= threshold:
        relation.direction = 1
    elif relation.direction <= -threshold:
        relation.direction = -1
    else:
        relation.direction = 0

    if relation.consistency_check >= threshold:
        relation.consistency_check = 1
    elif relation.consistency_check  <= -threshold:
        relation.consistency_check  = -1
    else:
        relation.consistency_check  = 0



final_graph.plot_graph(f'Adult_DAG_final_{threshold}')
final_graph.plot_graph(f'Adult_DAG_final_checked_{threshold}',after_consistency=True)


DAG with threshold 7.5


![Alt text](./DAGs/Adult_DAG_final_7.5.png "DAG")



DAG with threshold5


![Alt text](./DAGs/Adult_DAG_final_5.png "DAG")

In [27]:
final_graph.save_to_json('adult_DAG_final.json')

## Comaparing with ground truth DAG

Ground truth found on the internet

![Alt text](https://www.researchgate.net/publication/352436215/figure/fig3/AS:1035182521991168@1623818154645/Ground-truth-DAG-for-the-simulated-adult-dataset-Gray-edges-indicate-parent-edges-for.png "DAG")


https://www.researchgate.net/publication/352436215/figure/fig3/AS:1035182521991168@1623818154645/Ground-truth-DAG-for-the-simulated-adult-dataset-Gray-edges-indicate-parent-edges-for.png

In [30]:
gt_graph = CausalGraph()
gt_graph.load_from_json('./DAGs/adult_DAG_groundtruth.json')


In [31]:
gt_graph.plot_graph('Adult_DAG_GT')

In [35]:
#Lets check how many edges were predicted corectly
count = 0
for j in range(len(gt_graph.relations)):
    relation = final_graph.relations[j].direction
    gt_relation = gt_graph.relations[j].direction 
    if relation == gt_relation:
        count = count + 1
    else:
        continue

acc = count/len(gt_graph.relations)

print(f'{count} edges predicted correctly out of {len(gt_graph.relations)}')
print('Accuracy: ',acc*100, '%')

22 edges predicted correctly out of 36
Accuracy:  61.111111111111114 %


In [32]:
final_graph

<__main__.CausalGraph at 0x158b3325360>