In [1]:
# Start from an empty StructureModel; causal edges are attached to it afterwards.

import warnings

from causalnex.structure import StructureModel

# Install a blanket "ignore" filter so that no warning is shown from here on.
warnings.filterwarnings(action="ignore")

sm = StructureModel()

In [2]:
# Hand-specify two directed edges, both leaving the 'health' node:
# health -> absences and health -> G1.
sm.add_edges_from([("health", effect) for effect in ("absences", "G1")])

In [3]:
# Inspect the edges currently stored in the StructureModel; calling the edge view
# with no arguments hands back the same view, which is displayed as the cell output.
sm.edges()

OutEdgeView([('health', 'absences'), ('health', 'G1')])

In [None]:
# Draw the hand-built structure model with the library's WEAK node/edge styling
# and show it via the HTML file given below.
from causalnex.plots import EDGE_STYLE, NODE_STYLE, plot_structure

viz = plot_structure(sm, all_node_attributes=NODE_STYLE.WEAK, all_edge_attributes=EDGE_STYLE.WEAK)
viz.show("../screenshots/01_simple_plot.html")

In [4]:
# Read the merged CSV into a DataFrame and preview its first rows
# (DataFrame.head() shows five rows by default).
import pandas as pd

data = pd.read_csv("../data/merged_result.csv")
data.head()

Unnamed: 0,id,order_id,driver_id,driver_action,lat,lng,trip_origin,trip_destination,weekday,trip_start_date,holiday,rain,distance,speed
0,1,392001,243828,accepted,6.602207,3.270465,"6.6010417,3.2766339","6.4501069,3.3916154",1.0,2021-07-01,0.0,0.0,13.039051,216.316059
1,2,392001,243588,rejected,6.592097,3.287445,"6.6010417,3.2766339","6.4501069,3.3916154",1.0,2021-07-01,0.0,0.0,13.039051,216.316059
2,3,392001,243830,rejected,6.596133,3.281784,"6.6010417,3.2766339","6.4501069,3.3916154",1.0,2021-07-01,0.0,0.0,13.039051,216.316059
3,4,392001,243539,rejected,6.596142,3.280526,"6.6010417,3.2766339","6.4501069,3.3916154",1.0,2021-07-01,0.0,0.0,13.039051,216.316059
4,5,392001,171653,rejected,6.609232,3.2888,"6.6010417,3.2766339","6.4501069,3.3916154",1.0,2021-07-01,0.0,0.0,13.039051,216.316059


In [6]:
# Drop columns that should not take part in structure learning.
#
# 'Unnamed: 0' appears to be the row index that was written into the CSV (its
# values in the preview are just 0, 1, 2, ...), so it carries no information;
# 'trip_start_date' is the raw date string.
# NOTE(review): 'id', 'order_id' and 'driver_id' look like identifiers as well --
# confirm whether they should also be excluded before learning a structure.
drop_col = ['Unnamed: 0', 'trip_start_date']

# errors="ignore" keeps this cell re-runnable: `data` is overwritten below, so
# without it a second execution raises KeyError for columns that are already gone.
data = data.drop(columns=drop_col, errors="ignore")
data.head(5)

Unnamed: 0,id,order_id,driver_id,driver_action,lat,lng,trip_origin,trip_destination,weekday,holiday,rain,distance,speed
0,1,392001,243828,accepted,6.602207,3.270465,"6.6010417,3.2766339","6.4501069,3.3916154",1.0,0.0,0.0,13.039051,216.316059
1,2,392001,243588,rejected,6.592097,3.287445,"6.6010417,3.2766339","6.4501069,3.3916154",1.0,0.0,0.0,13.039051,216.316059
2,3,392001,243830,rejected,6.596133,3.281784,"6.6010417,3.2766339","6.4501069,3.3916154",1.0,0.0,0.0,13.039051,216.316059
3,4,392001,243539,rejected,6.596142,3.280526,"6.6010417,3.2766339","6.4501069,3.3916154",1.0,0.0,0.0,13.039051,216.316059
4,5,392001,171653,rejected,6.609232,3.2888,"6.6010417,3.2766339","6.4501069,3.3916154",1.0,0.0,0.0,13.039051,216.316059


In [7]:
# Work on a copy of `data` and collect the names of every column whose dtype
# is not numeric; these are the columns that get label-encoded afterwards.
import numpy as np

struct_data = data.copy()
non_numeric_columns = struct_data.select_dtypes(exclude=[np.number]).columns.tolist()

print(non_numeric_columns)

['driver_action', 'trip_origin', 'trip_destination']


In [10]:
from sklearn.preprocessing import LabelEncoder

# A single encoder is re-fitted on each non-numeric column in turn, so after the
# loop `le` only holds the fit for the last column processed.
le = LabelEncoder()

for column_name in non_numeric_columns:
    encoded_values = le.fit_transform(struct_data[column_name])
    struct_data[column_name] = encoded_values

# Preview the encoded frame (DataFrame.head() shows five rows by default).
struct_data.head()

Unnamed: 0,address,famsize,Pstatus,Medu,Fedu,traveltime,studytime,failures,schoolsup,famsup,...,famrel,freetime,goout,Dalc,Walc,health,absences,G1,G2,G3
0,1,0,0,4,4,2,2,0,1,0,...,4,3,4,1,1,3,4,0,11,11
1,1,0,1,1,1,1,2,0,0,1,...,5,3,3,1,1,3,2,9,11,11
2,1,1,1,1,1,1,2,0,1,0,...,4,3,2,2,3,3,6,12,13,12
3,1,0,1,4,2,1,3,0,0,1,...,3,2,2,1,1,5,0,14,14,14
4,1,0,1,3,3,1,2,0,0,1,...,4,3,2,1,2,5,0,11,13,13


In [11]:
# Learn a structure from the encoded data with the NOTEARS implementation;
# the result replaces the hand-built model previously bound to `sm`.
from causalnex.structure.notears import from_pandas

sm = from_pandas(struct_data)

In [None]:
# Draw the structure learned from the data, before any thresholding.
viz = plot_structure(sm, all_node_attributes=NODE_STYLE.WEAK, all_edge_attributes=EDGE_STYLE.WEAK)

# Switch the layout physics off before the plot is shown.
viz.toggle_physics(False)
viz.show("../screenshots/01_fully_connected.html")

In [None]:
# Prune the learned model in place with a threshold of 0.8 via
# remove_edges_below_threshold, then redraw what is left.
sm.remove_edges_below_threshold(0.8)

viz = plot_structure(sm, all_node_attributes=NODE_STYLE.WEAK, all_edge_attributes=EDGE_STYLE.WEAK)
viz.show("../screenshots/01_thresholded.html")

In [None]:
# Re-learn the structure while forbidding relationships known to be erroneous.
#
# The requested tabu edge names the columns 'higher' and 'Medu'. The CSV loaded
# by this notebook does not contain them (see its header in the load cell), and
# from_pandas resolves tabu nodes against the DataFrame's column names, so an
# edge naming an absent column fails with a KeyError instead of being ignored.
# Keep only the edges whose two endpoints are real columns and report the rest.
# NOTE(review): replace the requested edge with constraints that make sense for
# the dataset actually in use.
requested_tabu_edges = [("higher", "Medu")]
available_columns = set(struct_data.columns)
usable_tabu_edges = [
    (parent, child)
    for parent, child in requested_tabu_edges
    if parent in available_columns and child in available_columns
]
skipped_tabu_edges = [edge for edge in requested_tabu_edges if edge not in usable_tabu_edges]
if skipped_tabu_edges:
    # print() rather than warnings.warn(): warnings are silenced at the top of the notebook.
    print(f"Skipping tabu edges whose columns are not in struct_data: {skipped_tabu_edges}")

# w_threshold=0.8 is the same threshold value used when pruning the previous model.
sm = from_pandas(struct_data, tabu_edges=usable_tabu_edges, w_threshold=0.8)
viz = plot_structure(
    sm,
    all_node_attributes=NODE_STYLE.WEAK,
    all_edge_attributes=EDGE_STYLE.WEAK,
)
viz.show("../screenshots/01_edge_added.html")

In [15]:

# Correct the learned structure by hand: add one edge and remove two.
# NOTE(review): 'failures', 'G1', 'Pstatus' and 'address' are not columns of the
# CSV loaded by this notebook -- confirm these edits match the dataset in use.
sm.add_edge("failures", "G1")

# remove_edges_from() silently skips edges that are not in the graph, whereas
# remove_edge() raises NetworkXError for a missing edge. That made this cell fail
# when re-run, or whenever the learned structure did not contain these edges.
sm.remove_edges_from([("Pstatus", "G1"), ("address", "G1")])

In [None]:

# Redraw the model after the manual edge corrections.
viz = plot_structure(sm, all_node_attributes=NODE_STYLE.WEAK, all_edge_attributes=EDGE_STYLE.WEAK)
viz.show("../screenshots/01_modified_structure.html")