In [2]:
import copy
import re

import doubleml as dml
import numpy as np
import pandas as pd
import pytest
from sklearn.exceptions import NotFittedError
from sklearn.tree import DecisionTreeClassifier

from _utils_pt_manual import fit_policytree

In [3]:
def dml_policytree_fixture(depth, n=50, seed=42):
    """Fit a DoubleMLPolicyTree and the manual reference policy tree on the same random data.

    Simulates ``n`` rows of three standard-normal features and a standard-normal
    signal, then fits both ``dml.DoubleMLPolicyTree`` and the manual
    ``fit_policytree`` reference with the same ``depth``.

    Parameters
    ----------
    depth : int
        Tree depth passed to both implementations.
    n : int, default 50
        Number of simulated observations.
    seed : int, default 42
        Seed for the global NumPy RNG. It is reset before simulating the data and
        again before each of the two fits.

    Returns
    -------
    dict
        ``'tree'`` / ``'tree_manual'``: the ``tree_`` attribute of the package and manual fit;
        ``'x_vars'`` / ``'signal'``: the inputs as stored on the fitted model;
        ``'policytree_model'``: the fitted model;
        ``'unfitted_policytree_model'``: a shallow copy taken before ``fit()``.
    """
    np.random.seed(seed)
    random_x_var = pd.DataFrame(np.random.normal(0, 1, size=(n, 3)))
    random_signal = np.random.normal(0, 1, size=(n, ))

    policy_tree = dml.DoubleMLPolicyTree(random_signal, random_x_var, depth=depth)

    # Snapshot taken before fit() to serve as the "unfitted" model.
    # NOTE(review): copy.copy is shallow -- the snapshot shares mutable attributes with
    # `policy_tree`, so it only stays unfitted if fit() rebinds rather than mutates them.
    policy_tree_obj = copy.copy(policy_tree)

    # Reset the global RNG before each fit so both implementations start from the same state.
    np.random.seed(seed)
    policy_tree.fit()
    np.random.seed(seed)
    policy_tree_manual = fit_policytree(random_signal, random_x_var, depth=depth)

    return {'tree': policy_tree.policy_tree.tree_,
            'tree_manual': policy_tree_manual.tree_,
            'x_vars': policy_tree.x_vars,
            'signal': policy_tree.orth_signal,
            'policytree_model': policy_tree,
            'unfitted_policytree_model': policy_tree_obj}


In [4]:
def compare_nodes(node1, node2, index1=0, index2=0, atol=0.0):
    """Recursively check whether two fitted tree structures have the same shape and splits.

    Parameters
    ----------
    node1, node2 : scikit-learn ``Tree``-like objects
        Objects exposing parallel per-node arrays ``children_left``, ``children_right``,
        ``feature`` and ``threshold`` (e.g. the ``tree_`` attribute of a fitted
        ``DecisionTreeClassifier``). A child index of ``-1`` marks a leaf, as in scikit-learn.
    index1, index2 : int, default 0
        Node indices to start the comparison from (0 is the root).
    atol : float, default 0.0
        Absolute tolerance on split thresholds; the default requires exact equality.

    Returns
    -------
    bool
        True if both subtrees have the same topology and every internal node splits on
        the same feature at the same threshold (within ``atol``).
    """
    # -1 is scikit-learn's TREE_LEAF sentinel meaning "no child".
    left1 = node1.children_left[index1]
    left2 = node2.children_left[index2]
    is_leaf1 = left1 == -1
    is_leaf2 = left2 == -1
    if is_leaf1 or is_leaf2:
        # Two leaves match; a leaf never matches an internal node.
        return bool(is_leaf1 and is_leaf2)
    return bool(
        node1.feature[index1] == node2.feature[index2]
        and abs(node1.threshold[index1] - node2.threshold[index2]) <= atol
        and compare_nodes(node1, node2, left1, left2, atol)
        and compare_nodes(node1, node2,
                          node1.children_right[index1], node2.children_right[index2], atol)
    )

In [5]:
hallo = dml_policytree_fixture(depth=2)

In [6]:
np.allclose(
    hallo['tree'].threshold,
    hallo['tree_manual'].threshold,
    rtol=1e-9,
    atol=1e-4,
)

True

In [7]:
# Smoke checks on the fitted model: string representation, summary table and
# the underlying scikit-learn estimator.
assert isinstance(hallo['policytree_model'].__str__(), str)
assert isinstance(hallo['policytree_model'].summary, pd.DataFrame)
assert isinstance(hallo['policytree_model'].policy_tree, DecisionTreeClassifier)

In [8]:
# Expected KeyError message when predicting on columns that differ from the fitted ones.
# Escaped so it can be used directly as a `pytest.raises(match=...)` regex pattern
# (the text contains `(`, `[` and `.`).
msg = re.escape("The features must have the keys Index(['a', 'b', 'c'], dtype='object'). "
                "Features with keys Index(['d'], dtype='object') were passed.")

In [13]:
# Input validation of DoubleMLPolicyTree: constructor arguments, predict() and plot_tree().
# `pytest.raises(match=...)` searches the exception text with a regular expression, so
# each expected message is passed through `re.escape` to be matched literally.
rng = np.random.default_rng(42)  # seeded local generator; leaves the global NumPy RNG untouched
random_x_vars = pd.DataFrame(rng.normal(0, 1, size=(2, 3)), columns=['a', 'b', 'c'])
signal = np.array([1, 2])

# Constructor: invalid signal / feature inputs.
msg = re.escape("The signal must be of np.ndarray type. Signal of type <class 'int'> was passed.")
with pytest.raises(TypeError, match=msg):
    dml.DoubleMLPolicyTree(orth_signal=1, x_vars=random_x_vars)
msg = re.escape('The signal must be of one dimensional. Signal of dimensions 2 was passed.')
with pytest.raises(ValueError, match=msg):
    dml.DoubleMLPolicyTree(orth_signal=np.array([[1], [2]]), x_vars=random_x_vars)
msg = re.escape("The features must be of DataFrame type. Features of type <class 'int'> was passed.")
with pytest.raises(TypeError, match=msg):
    dml.DoubleMLPolicyTree(orth_signal=signal, x_vars=1)
msg = re.escape('Invalid pd.DataFrame: Contains duplicate column names.')
with pytest.raises(ValueError, match=msg):
    dml.DoubleMLPolicyTree(
        orth_signal=signal,
        x_vars=pd.DataFrame(np.array([[1, 2], [4, 5]]), columns=['a_1', 'a_1']),
    )

# predict(): before fit, then with invalid features after fit.
dml_policytree_predict = dml.DoubleMLPolicyTree(orth_signal=signal, x_vars=random_x_vars)
msg = re.escape('Policy Tree not yet fitted. Call fit before predict.')
with pytest.raises(NotFittedError, match=msg):
    dml_policytree_predict.predict(random_x_vars)

dml_policytree_predict.fit()
msg = re.escape("The features must be of DataFrame type. Features of type <class 'int'> was passed.")
with pytest.raises(TypeError, match=msg):
    dml_policytree_predict.predict(x_vars=1)
# NOTE(review): the `dtype='object'` fragments come from the pandas Index repr and may
# differ between pandas versions -- confirm against the pinned pandas.
msg = re.escape("The features must have the keys Index(['a', 'b', 'c'], dtype='object'). "
                "Features with keys Index(['d'], dtype='object') were passed.")
with pytest.raises(KeyError, match=msg):
    dml_policytree_predict.predict(x_vars=pd.DataFrame({"d": [3, 4]}))

# plot_tree(): before fit.
dml_policytree_plot = dml.DoubleMLPolicyTree(orth_signal=signal, x_vars=random_x_vars)
msg = re.escape('Policy Tree not yet fitted. Call fit before plot_tree.')
with pytest.raises(NotFittedError, match=msg):
    dml_policytree_plot.plot_tree()

In [21]:
# Inline replay of the `dml_policytree_fixture` function body for depth=1, keeping the
# intermediate objects (random_x_var, random_signal, policy_tree, policy_tree_manual, ...)
# available at notebook level for inspection.
# NOTE(review): this duplicates the function body; unlike the function, it passes `depth`
# positionally to `dml.DoubleMLPolicyTree` and `fit_policytree` -- confirm both take
# depth as their third positional argument.
depth=1
n = 50
# Same seed and draw order as the function version, so the simulated data match it.
np.random.seed(42)
random_x_var = pd.DataFrame(np.random.normal(0, 1, size=(n, 3)))
random_signal = np.random.normal(0, 1, size=(n, ))

policy_tree = dml.DoubleMLPolicyTree(random_signal, random_x_var, depth)

# Shallow copy taken before fit(), stored below as the "unfitted" model.
# NOTE(review): copy.copy shares mutable attributes with `policy_tree`; it only stays
# unfitted if fit() rebinds rather than mutates them.
policy_tree_obj = copy.copy(policy_tree)
# Re-seed before each fit so both fits start from the same global NumPy RNG state
# (relevant only if they draw from it).
np.random.seed(42)
policy_tree.fit()
np.random.seed(42)
policy_tree_manual = fit_policytree(random_signal, random_x_var, depth)

res_dict = {'tree': policy_tree.policy_tree.tree_,
            'tree_manual': policy_tree_manual.tree_,
            'x_vars': policy_tree.x_vars,
            'signal': policy_tree.orth_signal,
            'policytree_model': policy_tree,
            'unfitted_policytree_model': policy_tree_obj}

# NOTE(review): this rebinds the name of the `dml_policytree_fixture` function to a dict.
# After this cell runs, calling `dml_policytree_fixture(2)` again (e.g. re-running the cell
# that builds `hallo`) raises TypeError because a dict is not callable. Presumably this
# mimics pytest fixture injection for later test cells -- consider a distinct name.
dml_policytree_fixture = res_dict

In [37]:
list(random_x_var.columns)

[0, 1, 2]