-
Notifications
You must be signed in to change notification settings - Fork 5
Test adequacy #215
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Test adequacy #215
Changes from all commits
64cd901
2014b41
80f1342
192d6e4
8a4d42f
7d6ec17
4a470c9
ff78e40
a01190a
5069c7f
5a2d4fb
221a3cc
3d51429
7620f96
d53fef2
886b911
1604d7a
8830175
d8e2a40
aaf49ed
f0ca6a5
046b3a5
933a168
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
""" | ||
This module contains code to measure various aspects of causal test adequacy. | ||
""" | ||
from itertools import combinations | ||
from copy import deepcopy | ||
import pandas as pd | ||
|
||
from causal_testing.testing.causal_test_suite import CausalTestSuite | ||
from causal_testing.data_collection.data_collector import DataCollector | ||
from causal_testing.specification.causal_dag import CausalDAG | ||
from causal_testing.testing.estimators import Estimator | ||
from causal_testing.testing.causal_test_case import CausalTestCase | ||
|
||
|
||
class DAGAdequacy:
    """
    Measures the adequacy of a given DAG by how many edges and independences are tested.

    After `measure_adequacy` has run, the following fields are populated:
    - `tested_pairs`: the (treatment, outcome) pairs exercised by the test suite.
    - `pairs_to_test`: every 2-combination of the DAG's nodes.
    - `untested_edges`: the members of `pairs_to_test` that no test exercises.
    - `dag_adequacy`: `len(tested_pairs) / len(pairs_to_test)`.
    """

    def __init__(
        self,
        causal_dag: CausalDAG,
        test_suite: CausalTestSuite,
    ):
        """
        :param causal_dag: The DAG whose node pairs should be covered by tests.
        :param test_suite: Iterable of tests, each exposing `treatment_variable` and `outcome_variable`.
        """
        self.causal_dag = causal_dag
        self.test_suite = test_suite
        # All results stay None until measure_adequacy is called.
        self.tested_pairs = None
        self.pairs_to_test = None
        self.untested_edges = None
        self.dag_adequacy = None

    def measure_adequacy(self):
        """
        Calculate the adequacy measurement, and populate the `dag_adequacy` field.

        A DAG with fewer than two nodes has no pairs to test; in that case `dag_adequacy` is reported as 0.0
        rather than raising a ZeroDivisionError.
        """
        self.tested_pairs = {(t.treatment_variable, t.outcome_variable) for t in self.test_suite}
        self.pairs_to_test = set(combinations(self.causal_dag.graph.nodes, 2))
        self.untested_edges = self.pairs_to_test.difference(self.tested_pairs)
        # NOTE(review): the numerator counts every tested pair, including any that are not in `pairs_to_test`
        # (e.g. a pair given in the opposite order), so the ratio is not capped at 1 - confirm this is intended.
        if self.pairs_to_test:
            self.dag_adequacy = len(self.tested_pairs) / len(self.pairs_to_test)
        else:
            self.dag_adequacy = 0.0

    def to_dict(self):
        "Returns the adequacy object as a dictionary."
        return {
            "causal_dag": self.causal_dag,
            "test_suite": self.test_suite,
            "tested_pairs": self.tested_pairs,
            "pairs_to_test": self.pairs_to_test,
            "untested_edges": self.untested_edges,
            "dag_adequacy": self.dag_adequacy,
        }
|
||
|
||
class DataAdequacy:
    """
    Measures the adequacy of a given test according to the Fisher kurtosis of the bootstrapped result.
    - Positive kurtoses indicate the model doesn't have enough data so is unstable.
    - Negative kurtoses indicate the model doesn't have enough data, but is too stable, indicating that the spread of
      inputs is insufficient.
    - Zero kurtosis is optimal.
    """

    def __init__(
        self, test_case: CausalTestCase, estimator: Estimator, data_collector: DataCollector, bootstrap_size: int = 100
    ):
        """
        :param test_case: The causal test case to execute on each bootstrap resample.
        :param estimator: Estimator whose `df` attribute holds the data to resample.
        :param data_collector: Passed through unchanged to `test_case.execute_test`.
        :param bootstrap_size: Number of bootstrap resamples to draw.
        """
        self.test_case = test_case
        self.estimator = estimator
        self.data_collector = data_collector
        # Both results stay None until measure_adequacy is called.
        self.kurtosis = None
        self.outcomes = None
        self.bootstrap_size = bootstrap_size

    def measure_adequacy(self):
        """
        Calculate the adequacy measurement, and populate the `kurtosis` and `outcomes` fields.

        The test case is executed once per bootstrap resample of the estimator's data. `kurtosis` becomes the
        per-column kurtosis of the resulting effect estimates and `outcomes` the sum of the values returned by
        `expected_causal_effect.apply` over the bootstrapped results.

        :raises ValueError: If `bootstrap_size` is less than 1, since there would be no estimates to summarise.
        """
        if self.bootstrap_size < 1:
            raise ValueError(f"bootstrap_size must be at least 1, got {self.bootstrap_size}")

        results = []
        for seed in range(self.bootstrap_size):
            # Resample a private copy so that self.estimator keeps its original data. Using the iteration
            # number as the random state makes every resample different but reproducible.
            estimator = deepcopy(self.estimator)
            estimator.df = estimator.df.sample(len(estimator.df), replace=True, random_state=seed)
            # NOTE(review): an earlier, commented-out guard skipped resamples that raised a LinAlgError here;
            # as written, any exception from execute_test propagates to the caller.
            results.append(self.test_case.execute_test(estimator, self.data_collector))

        outcomes = [self.test_case.expected_causal_effect.apply(result) for result in results]
        estimates = [self._as_frame(result.to_dict()["effect_estimate"]) for result in results]

        # One column per resample -> transpose so that each row is a resample and each column a coefficient.
        effect_estimate = pd.concat(estimates, axis=1).transpose().reset_index(drop=True)
        self.kurtosis = effect_estimate.kurtosis()
        self.outcomes = sum(outcomes)

    def _as_frame(self, estimate):
        """
        Wrap a scalar effect estimate in a single-cell dataframe indexed by the treatment variable's name.

        Per the review discussion on this code, statsmodels dummy-encodes categorical variables, so an estimate
        may already be a dataframe with one coefficient per category. Converting scalars to the same shape lets
        both cases be concatenated uniformly; non-float estimates are returned unchanged.
        """
        if isinstance(estimate, float):
            return pd.DataFrame({self.test_case.base_test_case.treatment_variable.name: [estimate]}).transpose()
        return estimate

    def to_dict(self):
        "Returns the adequacy object as a dictionary."
        # kurtosis is None until measure_adequacy has run; report that rather than failing on None.to_dict().
        kurtosis = None if self.kurtosis is None else self.kurtosis.to_dict()
        return {"kurtosis": kurtosis, "bootstrap_size": self.bootstrap_size, "passing": self.outcomes}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can't see this data parameter used in any `setup` calls; is this for some future use?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am using it as part of my case study so I can pass in the data directly rather than having to pass in filepaths