In [None]:
import os
os.environ['CASTLE_BACKEND'] = 'pytorch'

from collections import OrderedDict
import warnings

import numpy as np
import networkx as nx

import ges
from castle.common import GraphDAG
from castle.metrics import MetricsDAG
from castle.datasets import IIDSimulation, DAG
from castle.algorithms import PC, ICALiNGAM, GOLEM

import matplotlib.pyplot as plt

In [None]:
# Silence every warning so the output stays readable during the presentation.
# Drop this line in real-life applications - warnings usually carry useful hints.
warnings.simplefilter(action='ignore')

# Causal Discovery in Python


Over the last decade, causal inference has gained a lot of traction in academia and in industry. Causal models can be immensely helpful in various areas – from marketing to medicine and from finance to cybersecurity. To make these models work, we need not only data, as in traditional machine learning, but also a causal structure. The traditional way to obtain the latter is through well-designed experiments. Unfortunately, experiments can be tricky – difficult to design, expensive or unethical. Causal discovery (also known as structure learning) is an umbrella term that describes several families of methods aiming at discovering causal structure from observational data. During the talk, we will review the basics of causal inference and introduce the concept of causal discovery. Next, we will discuss differences between various approaches to causal discovery. Finally, we will see a series of practical examples of causal discovery using Python.

## Installing the environment

* Using **Conda**:

    `conda env create --file econml-dowhy-py38.yml`


* Installing `gcastle` only:

    `pip install gcastle==1.0.3rc3`

In [None]:
def get_n_undirected(g):
    """Count the undirected edges encoded in an adjacency matrix.

    An edge between two distinct nodes ``i`` and ``j`` is treated as
    undirected when both ``g[i, j]`` and ``g[j, i]`` are equal to 1.

    Parameters
    ----------
    g : array-like of shape (n_nodes, n_nodes)
        Square adjacency matrix.

    Returns
    -------
    int
        Number of undirected edges. Every node pair is counted exactly once
        and entries on the diagonal (self-loops) are ignored.

    Raises
    ------
    ValueError
        If ``g`` is not a square 2-D matrix.
    """
    adjacency = np.asarray(g)

    if adjacency.ndim != 2 or adjacency.shape[0] != adjacency.shape[1]:
        raise ValueError(
            f'Expected a square adjacency matrix, got shape {adjacency.shape}'
        )

    # True wherever an edge is present in both directions
    is_edge = adjacency == 1
    bidirectional = is_edge & is_edge.T

    # The strict upper triangle holds each unordered pair once and leaves the
    # diagonal out, so no half-counting or self-loop correction is needed
    return int(np.triu(bidirectional, k=1).sum())

## PC algorithm

The **PC algorithm** starts with a **fully connected** graph and then performs a series of steps to remove edges, based on the (conditional) independence relations found in the data. Finally, it tries to orient as many edges as possible.

Figure 1 presents a visual representation of these steps.

<br><br>

<img src="img/glymour_et_al_pc.jpg">

<br>

<figcaption><center><b>Figure 1. </b>Original graph and PC algorithm steps. (Glymour et al., 2019)</center></figcaption>

<br>


Interested in more details? 
[Glymour et al. - Review of Causal Discovery Methods Based on Graphical Models (2019)](https://www.frontiersin.org/articles/10.3389/fgene.2019.00524/full)

In [None]:
# Let's implement this structure: X -> Z <- Y and Z -> W
# A locally seeded generator keeps the simulated sample identical on every
# Restart & Run All without touching NumPy's global random state
N_SAMPLES = 1000
rng = np.random.default_rng(seed=42)

# Two independent root causes
x = rng.standard_normal(N_SAMPLES)
y = rng.standard_normal(N_SAMPLES)

# Z is a noisy sum of X and Y; W is a noisy, scaled copy of Z
z = x + y + .1 * rng.standard_normal(N_SAMPLES)
w = .7 * z + .1 * rng.standard_normal(N_SAMPLES)

In [None]:
# Arrange the variables as a matrix: one row per sample, one column per variable
pc_dataset = np.transpose(np.vstack((x, y, z, w)))

In [None]:
# Sanity check: display the data matrix together with its shape
pc_dataset, pc_dataset.shape

In [None]:
# Build the model: create a PC instance with its default arguments
# and run it on the simulated data matrix
pc = PC()
pc.learn(pc_dataset)

In [None]:
# Matrix form of the structure found by PC (it is turned into a graph below)
pc.causal_matrix

In [None]:
# Node index -> variable name, following the column order of the dataset
MAPPING = dict(enumerate(['X', 'Y', 'Z', 'W']))

# Build a directed graph from the learned matrix and give its nodes readable names
learned_graph = nx.relabel_nodes(
    nx.DiGraph(pc.causal_matrix),
    MAPPING,
    copy=True
)

# Plot the graph
draw_options = dict(
    with_labels=True,
    node_size=1800,
    font_size=18,
    font_color='white'
)
nx.draw(learned_graph, **draw_options)

## Let's do some more discovery!

### Generate datasets

We'll use a [scale-free](https://en.wikipedia.org/wiki/Scale-free_network) model to generate graphs.

Then we'll use three different causal models on this graph:

* linear Gaussian
* linear exp
* non-linear quadratic

In [None]:
# Data simulation: draw the true causal DAG first, then sample one
# training dataset per causal model from it.
true_dag = DAG.scale_free(n_nodes=10, n_edges=15, seed=18)

# Which functional forms to simulate, and which variants to use for each of them
DATA_PARAMS = {
    'linearity': ['linear', 'nonlinear'],
    'distribution': {
        'linear': ['gauss', 'exp'],
        'nonlinear': ['quadratic']
    }
}

# Dataset name ('<linearity>_<variant>') -> simulation object
datasets = {}

for linearity in DATA_PARAMS['linearity']:
    datasets.update({
        f'{linearity}_{distr}': IIDSimulation(
            W=true_dag,
            n=2000,
            method=linearity,
            sem_type=distr
        )
        for distr in DATA_PARAMS['distribution'][linearity]
    })


In [None]:
# Sanity check: list the simulated datasets by name
datasets

In [None]:
plt.figure(figsize=(16, 8))

for position, name in enumerate(datasets, start=1):
    samples = datasets[name].X

    # Upper part of the 4 x 2 grid: histogram of the first variable
    plt.subplot(4, 2, position)
    plt.hist(samples[:, 0], bins=100)
    plt.title(name)
    plt.axis('off')

    # Four slots further down: scatter plot of variable 8 against variable 4
    plt.subplot(4, 2, position + 4)
    plt.scatter(samples[:, 8], samples[:, 4], alpha=.3)
    plt.title(name)
    plt.axis('off')

# Extra vertical spacing so the titles do not overlap the panels above them
plt.subplots_adjust(hspace=.7)
plt.show()

### Visualize the true graph

In [None]:
# Build the directed graph once and reuse it for both drawing and layout
true_graph = nx.DiGraph(true_dag)

nx.draw(
    true_graph,
    pos=nx.circular_layout(true_graph),
    node_size=1800,
    alpha=.7
)

In [None]:
# Visualize the true DAG with gCastle's GraphDAG helper
GraphDAG(true_dag)
plt.show()

## Method comparison 

In [None]:
# Causal discovery methods to compare, in the order they will be run.
# 'GES' points at the `ges` module, the other entries are gCastle classes.
methods = OrderedDict([
    ('PC', PC),
    ('GES', ges),
    ('LiNGAM', ICALiNGAM),
    ('GOLEM', GOLEM),
])

In [None]:
%%time

results = {}

for k, dataset in datasets.items():
    print(f'************* Current dataset: {k}\n')
    X = dataset.X
    
    results[dataset] = {}
    
    for method in methods:
        
        if method not in ['GES', 'CORL']:
            print(f'Method: {method}')
            
            # Fit the model
            if method == 'GOLEM':
                model = methods[method](num_iter=2.5e4)
            else:
                model = methods[method]()
            
            model.learn(X)
            
            pred_dag = model.causal_matrix

        elif method == 'GES':
            print(f'Method: {method}')
            
            # Fit the model
            pred_dag, _ = methods[method].fit_bic(X)
              
        # Get n undir edges
        n_undir = get_n_undirected(pred_dag)

        # Plot results
        GraphDAG(pred_dag, true_dag, 'result')

        mt = MetricsDAG(pred_dag, true_dag)
        print(f'FDR: {mt.metrics["fdr"]}')
        print(f'Recall: {mt.metrics["recall"]}')
        print(f'Precision: {mt.metrics["precision"]}')
        print(f'F1 score: {mt.metrics["F1"]}')
        print(f'No. of undir. edges: {n_undir}\n')
        print('-' * 50, '\n')

        results[dataset][method] = pred_dag      
            
    print('\n')         