From 495873fd0998ac338cd180cb8b98f6bcf07e9164 Mon Sep 17 00:00:00 2001 From: Semih Akbayrak Date: Tue, 23 Sep 2025 13:38:47 +0000 Subject: [PATCH 1/7] Covariate support for Dataset class --- pyproject.toml | 2 +- src/causal_validation/data.py | 106 +++++++++- src/causal_validation/validation/placebo.py | 9 +- src/causal_validation/validation/rmspe.py | 6 +- tests/test_causal_validation/test_data.py | 222 +++++++++++++++++--- 5 files changed, 295 insertions(+), 50 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index accaaef..2cf98ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -171,7 +171,7 @@ select = [ "TID", "ISC", ] -ignore = ["F722"] +ignore = ["F722", "PLW1641"] [tool.ruff.format] quote-style = "double" diff --git a/src/causal_validation/data.py b/src/causal_validation/data.py index 057acfa..a81f13a 100644 --- a/src/causal_validation/data.py +++ b/src/causal_validation/data.py @@ -21,15 +21,45 @@ @dataclass class Dataset: + """A causal inference dataset containing pre/post intervention observations + and optional associated covariates. + + Attributes: + Xtr: Pre-intervention control unit observations (N x D) + Xte: Post-intervention control unit observations (M x D) + ytr: Pre-intervention treated unit observations (N x 1) + yte: Post-intervention treated unit observations (M x 1) + _start_date: Start date for time indexing + Ptr: Pre-intervention control unit covariates (N x D x F) + Pte: Post-intervention control unit covariates (M x D x F) + Rtr: Pre-intervention treated unit covariates (N x 1 x F) + Rte: Post-intervention treated unit covariates (M x 1 x F) + counterfactual: Optional counterfactual outcomes (M x 1) + synthetic: Optional synthetic control outcomes (M x 1). + This is weighted combination of control units + minimizing a distance-based error w.r.t. the + treated in pre-intervention period. + _name: Optional name identifier for the dataset + """ Xtr: Float[np.ndarray, "N D"] Xte: Float[np.ndarray, "M D"] ytr: Float[np.ndarray, "N 1"] yte: Float[np.ndarray, "M 1"] _start_date: dt.date + Ptr: tp.Optional[Float[np.ndarray, "N D F"]] = None + Pte: tp.Optional[Float[np.ndarray, "M D F"]] = None + Rtr: tp.Optional[Float[np.ndarray, "N 1 F"]] = None + Rte: tp.Optional[Float[np.ndarray, "M 1 F"]] = None counterfactual: tp.Optional[Float[np.ndarray, "M 1"]] = None synthetic: tp.Optional[Float[np.ndarray, "M 1"]] = None _name: str = None + def __post_init__(self): + covariates = [self.Ptr, self.Pte, self.Rtr, self.Rte] + self.has_covariates = all(cov is not None for cov in covariates) + if not self.has_covariates: + assert all(cov is None for cov in covariates) + def to_df( self, index_start: str = dt.date(year=2023, month=1, day=1) ) -> pd.DataFrame: @@ -59,6 +89,13 @@ def n_units(self) -> int: def n_timepoints(self) -> int: return self.n_post_intervention + self.n_pre_intervention + @property + def n_covariates(self) -> int: + if self.has_covariates: + return self.Ptr.shape[2] + else: + return 0 + @property def control_units(self) -> Float[np.ndarray, "{self.n_timepoints} {self.n_units}"]: return np.vstack([self.Xtr, self.Xte]) @@ -67,6 +104,26 @@ def control_units(self) -> Float[np.ndarray, "{self.n_timepoints} {self.n_units} def treated_units(self) -> Float[np.ndarray, "{self.n_timepoints} 1"]: return np.vstack([self.ytr, self.yte]) + @property + def control_covariates( + self, + ) -> tp.Optional[ + Float[np.ndarray, "{self.n_timepoints} {self.n_units} {self.n_covariates}"] + ]: + if self.has_covariates: + return np.vstack([self.Ptr, self.Pte]) + else: + return None + + @property + def treated_covariates( + self, + ) -> tp.Optional[Float[np.ndarray, "{self.n_timepoints} 1 {self.n_covariates}"]]: + if self.has_covariates: + return np.vstack([self.Rtr, self.Rte]) + else: + return None + @property def pre_intervention_obs( self, @@ -79,6 +136,32 @@ def post_intervention_obs( ) -> tp.Tuple[Float[np.ndarray, "M D"], Float[np.ndarray, "M 1"]]: return self.Xte, self.yte + @property + def pre_intervention_covariates( + self, + ) -> tp.Optional[ + tp.Tuple[ + Float[np.ndarray, "N D F"], Float[np.ndarray, "N 1 F"], + ] + ]: + if self.has_covariates: + return self.Ptr, self.Rtr + else: + return None + + @property + def post_intervention_covariates( + self, + ) -> tp.Optional[ + tp.Tuple[ + Float[np.ndarray, "M D F"], Float[np.ndarray, "M 1 F"], + ] + ]: + if self.has_covariates: + return self.Pte, self.Rte + else: + return None + @property def full_index(self) -> DatetimeIndex: return self._get_index(self._start_date) @@ -97,7 +180,12 @@ def get_index(self, period: InterventionTypes) -> DatetimeIndex: return self.full_index def _get_columns(self) -> tp.List[str]: - colnames = ["T"] + [f"C{i}" for i in range(self.n_units)] + if self.has_covariates: + colnames = ["T"] + [f"C{i}" for i in range(self.n_units)] + [ + f"F{i}" for i in range(self.n_covariates) + ] + else: + colnames = ["T"] + [f"C{i}" for i in range(self.n_units)] return colnames def _get_index(self, start_date: dt.date) -> DatetimeIndex: @@ -116,7 +204,10 @@ def inflate(self, inflation_vals: Float[np.ndarray, "M 1"]) -> Dataset: Xtr, ytr = [deepcopy(i) for i in self.pre_intervention_obs] Xte, yte = [deepcopy(i) for i in self.post_intervention_obs] inflated_yte = yte * inflation_vals - return Dataset(Xtr, Xte, ytr, inflated_yte, self._start_date, yte) + return Dataset( + Xtr, Xte, ytr, inflated_yte, self._start_date, + self.Ptr, self.Pte, self.Rtr, self.Rte, yte, self.synthetic, self._name + ) def __eq__(self, other: Dataset) -> bool: ytr = np.allclose(self.ytr, other.ytr) @@ -151,14 +242,21 @@ def _slots(self) -> tp.Dict[str, int]: def drop_unit(self, idx: int) -> Dataset: Xtr = np.delete(self.Xtr, [idx], axis=1) Xte = np.delete(self.Xte, [idx], axis=1) + Ptr = np.delete(self.Ptr, [idx], axis=1) if self.Ptr is not None else None + Pte = np.delete(self.Pte, [idx], axis=1) if self.Pte is not None else None return Dataset( Xtr, Xte, self.ytr, self.yte, self._start_date, + Ptr, + Pte, + self.Rtr, + self.Rte, self.counterfactual, self.synthetic, + self._name, ) def to_placebo_data(self, to_treat_idx: int) -> Dataset: @@ -212,5 +310,7 @@ def reassign_treatment( Xtr = data.Xtr Xte = data.Xte return Dataset( - Xtr, Xte, ytr, yte, data._start_date, data.counterfactual, data.synthetic + Xtr, Xte, ytr, yte, data._start_date, + data.Ptr, data.Pte, data.Rtr, data.Rte, + data.counterfactual, data.synthetic, data._name ) diff --git a/src/causal_validation/validation/placebo.py b/src/causal_validation/validation/placebo.py index b8f7c36..b8143fb 100644 --- a/src/causal_validation/validation/placebo.py +++ b/src/causal_validation/validation/placebo.py @@ -1,7 +1,6 @@ from dataclasses import dataclass import typing as tp -from azcausal.core.effect import Effect import numpy as np import pandas as pd from pandera import ( @@ -11,14 +10,8 @@ ) from rich.progress import ( Progress, - ProgressBar, - track, ) from scipy.stats import ttest_1samp -from tqdm import ( - tqdm, - trange, -) from causal_validation.data import ( Dataset, @@ -108,7 +101,7 @@ def execute(self, verbose: bool = True) -> PlaceboTestResult: "[blue]Datasets", total=n_datasets, visible=verbose ) unit_task = progress.add_task( - f"[green]Control Units", + "[green]Control Units", total=n_control, visible=verbose, ) diff --git a/src/causal_validation/validation/rmspe.py b/src/causal_validation/validation/rmspe.py index 6b541ff..b606722 100644 --- a/src/causal_validation/validation/rmspe.py +++ b/src/causal_validation/validation/rmspe.py @@ -2,18 +2,14 @@ import typing as tp from jaxtyping import Float -import numpy as np import pandas as pd from pandera import ( Check, Column, DataFrameSchema, ) -from rich import box from rich.progress import ( Progress, - ProgressBar, - track, ) from causal_validation.validation.placebo import PlaceboTest @@ -87,7 +83,7 @@ def execute(self, verbose: bool = True) -> RMSPETestResult: "[blue]Datasets", total=n_datasets, visible=verbose ) unit_task = progress.add_task( - f"[green]Treatment and Control Units", + "[green]Treatment and Control Units", total=n_control + 1, visible=verbose, ) diff --git a/tests/test_causal_validation/test_data.py b/tests/test_causal_validation/test_data.py index 07554a4..7434bf1 100644 --- a/tests/test_causal_validation/test_data.py +++ b/tests/test_causal_validation/test_data.py @@ -23,6 +23,7 @@ simulate_data, ) from causal_validation.types import InterventionTypes +import datetime as dt MIN_STRING_LENGTH = 1 MAX_STRING_LENGTH = 20 @@ -198,6 +199,10 @@ def test_drop_unit(n_pre: int, n_post: int, n_control: int): assert reduced_data.Xte.shape == desired_shape_Xte assert reduced_data.ytr.shape == desired_shape_ytr assert reduced_data.yte.shape == desired_shape_yte + + assert reduced_data.counterfactual == data.counterfactual + assert reduced_data.synthetic == data.synthetic + assert reduced_data._name == data._name @pytest.mark.parametrize("n_pre, n_post, n_control", [(60, 30, 10), (60, 30, 20)]) @@ -288,37 +293,188 @@ def test_naming_setter(name: str, extra_chars: str): @given( - seeds=st.lists( - elements=st.integers(min_value=1, max_value=1000), min_size=1, max_size=10 - ), - to_name=st.booleans(), + n_pre=st.integers(min_value=10, max_value=100), + n_post=st.integers(min_value=10, max_value=100), + n_control=st.integers(min_value=2, max_value=20), +) +@settings(max_examples=5) +def test_counterfactual_synthetic_attributes(n_pre: int, n_post: int, n_control: int): + constants = TestConstants( + N_POST_TREATMENT=n_post, + N_PRE_TREATMENT=n_pre, + N_CONTROL=n_control, + ) + data = simulate_data(0.0, DEFAULT_SEED, constants=constants) + + assert data.counterfactual is None + assert data.synthetic is None + + counterfactual_vals = np.random.randn(n_post, 1) + synthetic_vals = np.random.randn(n_post, 1) + + data_with_attrs = Dataset( + data.Xtr, data.Xte, data.ytr, data.yte, data._start_date, + data.Ptr, data.Pte, data.Rtr, data.Rte, + counterfactual_vals, synthetic_vals, "test_dataset" + ) + + np.testing.assert_array_equal(data_with_attrs.counterfactual, counterfactual_vals) + np.testing.assert_array_equal(data_with_attrs.synthetic, synthetic_vals) + assert data_with_attrs.name == "test_dataset" + + +@given( + n_pre=st.integers(min_value=10, max_value=100), + n_post=st.integers(min_value=10, max_value=100), + n_control=st.integers(min_value=2, max_value=20), +) +@settings(max_examples=5) +def test_inflate_method(n_pre: int, n_post: int, n_control: int): + constants = TestConstants( + N_POST_TREATMENT=n_post, + N_PRE_TREATMENT=n_pre, + N_CONTROL=n_control, + ) + data = simulate_data(0.0, DEFAULT_SEED, constants=constants) + + inflation_vals = np.ones((n_post, 1)) * 1.1 + inflated_data = data.inflate(inflation_vals) + + np.testing.assert_array_equal(inflated_data.Xtr, data.Xtr) + np.testing.assert_array_equal(inflated_data.ytr, data.ytr) + np.testing.assert_array_equal(inflated_data.Xte, data.Xte) + + expected_yte = data.yte * inflation_vals + np.testing.assert_array_equal(inflated_data.yte, expected_yte) + + np.testing.assert_array_equal(inflated_data.counterfactual, data.yte) + + +@given( + n_pre=st.integers(min_value=10, max_value=100), + n_post=st.integers(min_value=10, max_value=100), + n_control=st.integers(min_value=2, max_value=20), +) +@settings(max_examples=5) +def test_control_treated_properties(n_pre: int, n_post: int, n_control: int): + constants = TestConstants( + N_POST_TREATMENT=n_post, + N_PRE_TREATMENT=n_pre, + N_CONTROL=n_control, + ) + data = simulate_data(0.0, DEFAULT_SEED, constants=constants) + + control_units = data.control_units + expected_control = np.vstack([data.Xtr, data.Xte]) + np.testing.assert_array_equal(control_units, expected_control) + assert control_units.shape == (n_pre + n_post, n_control) + + treated_units = data.treated_units + expected_treated = np.vstack([data.ytr, data.yte]) + np.testing.assert_array_equal(treated_units, expected_treated) + assert treated_units.shape == (n_pre + n_post, 1) + + +@given( + n_pre=st.integers(min_value=10, max_value=100), + n_post=st.integers(min_value=10, max_value=100), + n_control=st.integers(min_value=2, max_value=20), ) -def test_dataset_container(seeds: tp.List[int], to_name: bool): - datasets = [simulate_data(0.0, s) for s in seeds] - if to_name: - names = [f"D_{idx}" for idx in range(len(datasets))] - else: - names = None - container = DatasetContainer(datasets, names) - - # Test names were correctly assigned - if to_name: - assert container.names == names - else: - assert container.names == [f"Dataset {idx}" for idx in range(len(datasets))] - - # Assert ordering - for idx, dataset in enumerate(container): - assert dataset == datasets[idx] - - # Assert no data was dropped/added - assert len(container) == len(datasets) - - # Test `as_dict()` method preserves order - container_dict = container.as_dict() - for idx, (k, v) in enumerate(container_dict.items()): - if to_name: - assert k == names[idx] - else: - assert k == f"Dataset {idx}" - assert v == datasets[idx] +@settings(max_examples=5) +def test_covariate_properties_without_covariates(n_pre: int, n_post: int, n_control: int): + constants = TestConstants( + N_POST_TREATMENT=n_post, + N_PRE_TREATMENT=n_pre, + N_CONTROL=n_control, + ) + data = simulate_data(0.0, DEFAULT_SEED, constants=constants) + + assert data.has_covariates is False + assert data.control_covariates is None + assert data.treated_covariates is None + assert data.pre_intervention_covariates is None + assert data.post_intervention_covariates is None + assert data.n_covariates == 0 + + +@given( + n_pre=st.integers(min_value=10, max_value=50), + n_post=st.integers(min_value=10, max_value=50), + n_control=st.integers(min_value=2, max_value=10), + n_covariates=st.integers(min_value=1, max_value=5), + Xtr=st.data(), + Xte=st.data(), + ytr=st.data(), + yte=st.data(), + Ptr=st.data(), + Pte=st.data(), + Rtr=st.data(), + Rte=st.data(), +) +@settings(max_examples=5) +def test_covariate_properties_with_covariates(n_pre: int, + n_post: int, + n_control: int, + n_covariates: int, + Xtr, + Xte, + ytr, + yte, + Ptr, + Pte, + Rtr, + Rte): + + Xtr = Xtr.draw(st.lists(st.floats(min_value=-10, max_value=10), + min_size=n_pre*n_control, max_size=n_pre*n_control)) + Xtr = np.array(Xtr).reshape(n_pre, n_control) + + Xte = Xte.draw(st.lists(st.floats(min_value=-10, max_value=10), + min_size=n_post*n_control, max_size=n_post*n_control)) + Xte = np.array(Xte).reshape(n_post, n_control) + + ytr = ytr.draw(st.lists(st.floats(min_value=-10, max_value=10), + min_size=n_pre, max_size=n_pre)) + ytr = np.array(ytr).reshape(n_pre, 1) + + yte = yte.draw(st.lists(st.floats(min_value=-10, max_value=10), + min_size=n_post, max_size=n_post)) + yte = np.array(yte).reshape(n_post, 1) + + Ptr = Ptr.draw(st.lists(st.floats(min_value=-10, max_value=10), + min_size=n_pre*n_control*n_covariates, max_size=n_pre*n_control*n_covariates)) + Ptr = np.array(Ptr).reshape(n_pre, n_control, n_covariates) + + Pte = Pte.draw(st.lists(st.floats(min_value=-10, max_value=10), + min_size=n_post*n_control*n_covariates, max_size=n_post*n_control*n_covariates)) + Pte = np.array(Pte).reshape(n_post, n_control, n_covariates) + + Rtr = Rtr.draw(st.lists(st.floats(min_value=-10, max_value=10), + min_size=n_pre*n_covariates, max_size=n_pre*n_covariates)) + Rtr = np.array(Rtr).reshape(n_pre, 1, n_covariates) + + Rte = Rte.draw(st.lists(st.floats(min_value=-10, max_value=10), + min_size=n_post*n_covariates, max_size=n_post*n_covariates)) + Rte = np.array(Rte).reshape(n_post, 1, n_covariates) + + data = Dataset(Xtr, Xte, ytr, yte, dt.date(2023, 1, 1), Ptr, Pte, Rtr, Rte) + + assert data.n_covariates == n_covariates + assert data.has_covariates is True + + control_covariates = data.control_covariates + expected_control_cov = np.vstack([Ptr, Pte]) + np.testing.assert_array_equal(control_covariates, expected_control_cov) + assert control_covariates.shape == (n_pre + n_post, n_control, n_covariates) + + treated_covariates = data.treated_covariates + expected_treated_cov = np.vstack([Rtr, Rte]) + np.testing.assert_array_equal(treated_covariates, expected_treated_cov) + assert treated_covariates.shape == (n_pre + n_post, 1, n_covariates) + + pre_cov = data.pre_intervention_covariates + assert pre_cov == (Ptr, Rtr) + + post_cov = data.post_intervention_covariates + assert post_cov == (Pte, Rte) + From 6682af4fbe5c7925d23e0aea230cdb16e293b0f3 Mon Sep 17 00:00:00 2001 From: Semih Akbayrak Date: Tue, 23 Sep 2025 13:47:10 +0000 Subject: [PATCH 2/7] Add dataset container test back --- tests/test_causal_validation/test_data.py | 35 +++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/tests/test_causal_validation/test_data.py b/tests/test_causal_validation/test_data.py index 7434bf1..70befc4 100644 --- a/tests/test_causal_validation/test_data.py +++ b/tests/test_causal_validation/test_data.py @@ -374,6 +374,41 @@ def test_control_treated_properties(n_pre: int, n_post: int, n_control: int): np.testing.assert_array_equal(treated_units, expected_treated) assert treated_units.shape == (n_pre + n_post, 1) +@given( + seeds=st.lists( + elements=st.integers(min_value=1, max_value=1000), min_size=1, max_size=10 + ), + to_name=st.booleans(), +) +def test_dataset_container(seeds: tp.List[int], to_name: bool): + datasets = [simulate_data(0.0, s) for s in seeds] + if to_name: + names = [f"D_{idx}" for idx in range(len(datasets))] + else: + names = None + container = DatasetContainer(datasets, names) + + # Test names were correctly assigned + if to_name: + assert container.names == names + else: + assert container.names == [f"Dataset {idx}" for idx in range(len(datasets))] + + # Assert ordering + for idx, dataset in enumerate(container): + assert dataset == datasets[idx] + + # Assert no data was dropped/added + assert len(container) == len(datasets) + + # Test `as_dict()` method preserves order + container_dict = container.as_dict() + for idx, (k, v) in enumerate(container_dict.items()): + if to_name: + assert k == names[idx] + else: + assert k == f"Dataset {idx}" + assert v == datasets[idx] @given( n_pre=st.integers(min_value=10, max_value=100), From 0dc4c3d44df393dcb11a751da04561705f5df428 Mon Sep 17 00:00:00 2001 From: Semih Akbayrak Date: Tue, 23 Sep 2025 14:15:35 +0000 Subject: [PATCH 3/7] Fix linting errors in tests. --- tests/test_causal_validation/test_data.py | 112 +++++++----------- .../test_validation/test_placebo.py | 2 +- .../test_validation/test_rmspe.py | 2 +- 3 files changed, 42 insertions(+), 74 deletions(-) diff --git a/tests/test_causal_validation/test_data.py b/tests/test_causal_validation/test_data.py index 70befc4..1b7705c 100644 --- a/tests/test_causal_validation/test_data.py +++ b/tests/test_causal_validation/test_data.py @@ -199,7 +199,7 @@ def test_drop_unit(n_pre: int, n_post: int, n_control: int): assert reduced_data.Xte.shape == desired_shape_Xte assert reduced_data.ytr.shape == desired_shape_ytr assert reduced_data.yte.shape == desired_shape_yte - + assert reduced_data.counterfactual == data.counterfactual assert reduced_data.synthetic == data.synthetic assert reduced_data._name == data._name @@ -305,19 +305,19 @@ def test_counterfactual_synthetic_attributes(n_pre: int, n_post: int, n_control: N_CONTROL=n_control, ) data = simulate_data(0.0, DEFAULT_SEED, constants=constants) - + assert data.counterfactual is None assert data.synthetic is None - + counterfactual_vals = np.random.randn(n_post, 1) synthetic_vals = np.random.randn(n_post, 1) - + data_with_attrs = Dataset( data.Xtr, data.Xte, data.ytr, data.yte, data._start_date, data.Ptr, data.Pte, data.Rtr, data.Rte, counterfactual_vals, synthetic_vals, "test_dataset" ) - + np.testing.assert_array_equal(data_with_attrs.counterfactual, counterfactual_vals) np.testing.assert_array_equal(data_with_attrs.synthetic, synthetic_vals) assert data_with_attrs.name == "test_dataset" @@ -336,17 +336,17 @@ def test_inflate_method(n_pre: int, n_post: int, n_control: int): N_CONTROL=n_control, ) data = simulate_data(0.0, DEFAULT_SEED, constants=constants) - - inflation_vals = np.ones((n_post, 1)) * 1.1 + + inflation_vals = np.ones((n_post, 1)) * 1.1 inflated_data = data.inflate(inflation_vals) - + np.testing.assert_array_equal(inflated_data.Xtr, data.Xtr) np.testing.assert_array_equal(inflated_data.ytr, data.ytr) np.testing.assert_array_equal(inflated_data.Xte, data.Xte) - + expected_yte = data.yte * inflation_vals np.testing.assert_array_equal(inflated_data.yte, expected_yte) - + np.testing.assert_array_equal(inflated_data.counterfactual, data.yte) @@ -363,12 +363,12 @@ def test_control_treated_properties(n_pre: int, n_post: int, n_control: int): N_CONTROL=n_control, ) data = simulate_data(0.0, DEFAULT_SEED, constants=constants) - + control_units = data.control_units expected_control = np.vstack([data.Xtr, data.Xte]) np.testing.assert_array_equal(control_units, expected_control) assert control_units.shape == (n_pre + n_post, n_control) - + treated_units = data.treated_units expected_treated = np.vstack([data.ytr, data.yte]) np.testing.assert_array_equal(treated_units, expected_treated) @@ -416,14 +416,16 @@ def test_dataset_container(seeds: tp.List[int], to_name: bool): n_control=st.integers(min_value=2, max_value=20), ) @settings(max_examples=5) -def test_covariate_properties_without_covariates(n_pre: int, n_post: int, n_control: int): +def test_covariate_properties_without_covariates( + n_pre: int, n_post: int, n_control: int +): constants = TestConstants( N_POST_TREATMENT=n_post, N_PRE_TREATMENT=n_pre, N_CONTROL=n_control, ) data = simulate_data(0.0, DEFAULT_SEED, constants=constants) - + assert data.has_covariates is False assert data.control_covariates is None assert data.treated_covariates is None @@ -437,79 +439,45 @@ def test_covariate_properties_without_covariates(n_pre: int, n_post: int, n_cont n_post=st.integers(min_value=10, max_value=50), n_control=st.integers(min_value=2, max_value=10), n_covariates=st.integers(min_value=1, max_value=5), - Xtr=st.data(), - Xte=st.data(), - ytr=st.data(), - yte=st.data(), - Ptr=st.data(), - Pte=st.data(), - Rtr=st.data(), - Rte=st.data(), + seed=st.integers(min_value=1, max_value=10000), ) @settings(max_examples=5) -def test_covariate_properties_with_covariates(n_pre: int, - n_post: int, - n_control: int, - n_covariates: int, - Xtr, - Xte, - ytr, - yte, - Ptr, - Pte, - Rtr, - Rte): - - Xtr = Xtr.draw(st.lists(st.floats(min_value=-10, max_value=10), - min_size=n_pre*n_control, max_size=n_pre*n_control)) - Xtr = np.array(Xtr).reshape(n_pre, n_control) - - Xte = Xte.draw(st.lists(st.floats(min_value=-10, max_value=10), - min_size=n_post*n_control, max_size=n_post*n_control)) - Xte = np.array(Xte).reshape(n_post, n_control) - - ytr = ytr.draw(st.lists(st.floats(min_value=-10, max_value=10), - min_size=n_pre, max_size=n_pre)) - ytr = np.array(ytr).reshape(n_pre, 1) - - yte = yte.draw(st.lists(st.floats(min_value=-10, max_value=10), - min_size=n_post, max_size=n_post)) - yte = np.array(yte).reshape(n_post, 1) - - Ptr = Ptr.draw(st.lists(st.floats(min_value=-10, max_value=10), - min_size=n_pre*n_control*n_covariates, max_size=n_pre*n_control*n_covariates)) - Ptr = np.array(Ptr).reshape(n_pre, n_control, n_covariates) - - Pte = Pte.draw(st.lists(st.floats(min_value=-10, max_value=10), - min_size=n_post*n_control*n_covariates, max_size=n_post*n_control*n_covariates)) - Pte = np.array(Pte).reshape(n_post, n_control, n_covariates) - - Rtr = Rtr.draw(st.lists(st.floats(min_value=-10, max_value=10), - min_size=n_pre*n_covariates, max_size=n_pre*n_covariates)) - Rtr = np.array(Rtr).reshape(n_pre, 1, n_covariates) - - Rte = Rte.draw(st.lists(st.floats(min_value=-10, max_value=10), - min_size=n_post*n_covariates, max_size=n_post*n_covariates)) - Rte = np.array(Rte).reshape(n_post, 1, n_covariates) - +def test_covariate_properties_with_covariates( + n_pre: int, + n_post: int, + n_control: int, + n_covariates: int, + seed: int, +): + rng = np.random.RandomState(seed) + + Xtr = rng.uniform(-10, 10, (n_pre, n_control)) + Xte = rng.uniform(-10, 10, (n_post, n_control)) + ytr = rng.uniform(-10, 10, (n_pre, 1)) + yte = rng.uniform(-10, 10, (n_post, 1)) + Ptr = rng.uniform(-10, 10, (n_pre, n_control, n_covariates)) + Pte = rng.uniform(-10, 10, (n_post, n_control, n_covariates)) + Rtr = rng.uniform(-10, 10, (n_pre, 1, n_covariates)) + Rte = rng.uniform(-10, 10, (n_post, 1, n_covariates)) + data = Dataset(Xtr, Xte, ytr, yte, dt.date(2023, 1, 1), Ptr, Pte, Rtr, Rte) - + assert data.n_covariates == n_covariates assert data.has_covariates is True - + control_covariates = data.control_covariates expected_control_cov = np.vstack([Ptr, Pte]) np.testing.assert_array_equal(control_covariates, expected_control_cov) assert control_covariates.shape == (n_pre + n_post, n_control, n_covariates) - + treated_covariates = data.treated_covariates expected_treated_cov = np.vstack([Rtr, Rte]) np.testing.assert_array_equal(treated_covariates, expected_treated_cov) assert treated_covariates.shape == (n_pre + n_post, 1, n_covariates) - + pre_cov = data.pre_intervention_covariates assert pre_cov == (Ptr, Rtr) - + post_cov = data.post_intervention_covariates assert post_cov == (Pte, Rte) diff --git a/tests/test_causal_validation/test_validation/test_placebo.py b/tests/test_causal_validation/test_validation/test_placebo.py index 858c5f3..172de83 100644 --- a/tests/test_causal_validation/test_validation/test_placebo.py +++ b/tests/test_causal_validation/test_validation/test_placebo.py @@ -30,7 +30,7 @@ def test_schema_coerce(): df = PlaceboSchema.example() cols = df.columns for col in cols: - if not col in ["Model", "Dataset"]: + if col not in ["Model", "Dataset"]: df[col] = np.ceil((df[col])) PlaceboSchema.validate(df) diff --git a/tests/test_causal_validation/test_validation/test_rmspe.py b/tests/test_causal_validation/test_validation/test_rmspe.py index 1bc6b37..ead1dfa 100644 --- a/tests/test_causal_validation/test_validation/test_rmspe.py +++ b/tests/test_causal_validation/test_validation/test_rmspe.py @@ -35,7 +35,7 @@ def test_schema_coerce(): df = RMSPESchema.example() cols = df.columns for col in cols: - if not col in ["Model", "Dataset"]: + if col not in ["Model", "Dataset"]: df[col] = np.ceil((df[col])) RMSPESchema.validate(df) From 51def3b8565fe0bc2be51fa0ab37a58d7b9a500b Mon Sep 17 00:00:00 2001 From: Semih Akbayrak Date: Wed, 24 Sep 2025 15:08:48 +0000 Subject: [PATCH 4/7] Covariates can be included in simulations. --- src/causal_validation/config.py | 59 +++++++ src/causal_validation/simulate.py | 41 ++++- src/causal_validation/weights.py | 32 +++- tests/test_causal_validation/test_config.py | 80 ++++++++++ tests/test_causal_validation/test_simulate.py | 145 ++++++++++++++++++ tests/test_causal_validation/test_weights.py | 32 +++- 6 files changed, 376 insertions(+), 13 deletions(-) create mode 100644 tests/test_causal_validation/test_config.py create mode 100644 tests/test_causal_validation/test_simulate.py diff --git a/src/causal_validation/config.py b/src/causal_validation/config.py index 8040be1..ca7928a 100644 --- a/src/causal_validation/config.py +++ b/src/causal_validation/config.py @@ -4,7 +4,11 @@ ) import datetime as dt +from jaxtyping import Float +import typing as tp + import numpy as np +from scipy.stats import halfcauchy from causal_validation.types import ( Number, @@ -20,9 +24,38 @@ class WeightConfig: @dataclass(kw_only=True) class Config: + """Configuration for causal data generation. + + Args: + n_control_units (int): Number of control units in the synthetic dataset. + n_pre_intervention_timepoints (int): Number of time points before intervention. + n_post_intervention_timepoints (int): Number of time points after intervention. + n_covariates (Optional[int]): Number of covariates. Defaults to None. + covariate_means (Optional[Float[np.ndarray, "D K"]]): Mean values for covariates + D is n_control_units and K is n_covariates. Defaults to None. If it is set + to None while n_covariates is provided, covariate_means will be generated + randomly from Normal distribution. + covariate_stds (Optional[Float[np.ndarray, "D K"]]): Standard deviations for + covariates. D is n_control_units and K is n_covariates. Defaults to None. + If it is set to None while n_covariates is provided, covariate_stds + will be generated randomly from Half-Cauchy distribution. + covariate_coeffs (Optional[np.ndarray]): Linear regression + coefficients to map covariates to output observations. K is n_covariates. + Defaults to None. + global_mean (Number): Global mean for data generation. Defaults to 20.0. + global_scale (Number): Global scale for data generation. Defaults to 0.2. + start_date (dt.date): Start date for time series. Defaults to 2023-01-01. + seed (int): Random seed for reproducibility. Defaults to 123. + weights_cfg (WeightConfig): Configuration for unit weights. Defaults to + UniformWeights. + """ n_control_units: int n_pre_intervention_timepoints: int n_post_intervention_timepoints: int + n_covariates: tp.Optional[int] = None + covariate_means: tp.Optional[Float[np.ndarray, "D K"]] = None + covariate_stds: tp.Optional[Float[np.ndarray, "D K"]] = None + covariate_coeffs: tp.Optional[np.ndarray] = None global_mean: Number = 20.0 global_scale: Number = 0.2 start_date: dt.date = dt.date(year=2023, month=1, day=1) @@ -31,3 +64,29 @@ class Config: def __post_init__(self): self.rng = np.random.RandomState(self.seed) + if self.covariate_means is not None: + assert self.covariate_means.shape == (self.n_control_units, + self.n_covariates) + + if self.covariate_stds is not None: + assert self.covariate_stds.shape == (self.n_control_units, + self.n_covariates) + + if (self.n_covariates is not None) & (self.covariate_means is None): + self.covariate_means = self.rng.normal( + loc=0.0, scale=5.0, size=(self.n_control_units, + self.n_covariates) + ) + + if (self.n_covariates is not None) & (self.covariate_stds is None): + self.covariate_stds = ( + halfcauchy.rvs(scale=0.5, + size=(self.n_control_units, + self.n_covariates), + random_state=self.rng) + ) + + if (self.n_covariates is not None) & (self.covariate_coeffs is None): + self.covariate_coeffs = self.rng.normal( + loc=0.0, scale=5.0, size=self.n_covariates + ) diff --git a/src/causal_validation/simulate.py b/src/causal_validation/simulate.py index 2f02c10..e85e80f 100644 --- a/src/causal_validation/simulate.py +++ b/src/causal_validation/simulate.py @@ -29,9 +29,40 @@ def _simulate_base_obs( obs = key.normal( loc=config.global_mean, scale=config.global_scale, size=(n_timepoints, n_units) ) - Xtr = obs[: config.n_pre_intervention_timepoints, :] - Xte = obs[config.n_pre_intervention_timepoints :, :] - ytr = weights.weight_obs(Xtr) - yte = weights.weight_obs(Xte) - data = Dataset(Xtr, Xte, ytr, yte, _start_date=config.start_date) + + if config.n_covariates is not None: + Xtr_ = obs[: config.n_pre_intervention_timepoints, :] + Xte_ = obs[config.n_pre_intervention_timepoints :, :] + + covariates = key.normal( + loc=config.covariate_means, + scale=config.covariate_stds, + size=(n_timepoints, n_units, config.n_covariates) + ) + + Ptr = covariates[:config.n_pre_intervention_timepoints, :, :] + Pte = covariates[config.n_pre_intervention_timepoints:, :, :] + + Xtr = Xtr_ + Ptr @ config.covariate_coeffs + Xte = Xte_ + Pte @ config.covariate_coeffs + + ytr = weights.weight_contr(Xtr) + yte = weights.weight_contr(Xte) + + Rtr = weights.weight_contr(Ptr) + Rte = weights.weight_contr(Pte) + + data = Dataset( + Xtr, Xte, ytr, yte, _start_date=config.start_date, + Ptr=Ptr, Pte=Pte, Rtr=Rtr, Rte=Rte + ) + else: + Xtr = obs[: config.n_pre_intervention_timepoints, :] + Xte = obs[config.n_pre_intervention_timepoints :, :] + + ytr = weights.weight_contr(Xtr) + yte = weights.weight_contr(Xte) + + data = Dataset(Xtr, Xte, ytr, yte, _start_date=config.start_date) + return data diff --git a/src/causal_validation/weights.py b/src/causal_validation/weights.py index 42234f9..8d6c6f2 100644 --- a/src/causal_validation/weights.py +++ b/src/causal_validation/weights.py @@ -11,15 +11,23 @@ if tp.TYPE_CHECKING: from causal_validation.config import WeightConfig +# Constants for array dimensions +_NDIM_2D = 2 +_NDIM_3D = 3 + @dataclass class AbstractWeights(BaseObject): name: str = "Abstract Weights" - def _get_weights(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, "D 1"]: + def _get_weights( + self, obs: Float[np.ndarray, "N D"] | Float[np.ndarray, "N D K"] + ) -> Float[np.ndarray, "D 1"]: raise NotImplementedError("Please implement `_get_weights` in all subclasses.") - def get_weights(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, "D 1"]: + def get_weights( + self, obs: Float[np.ndarray, "N D"] | Float[np.ndarray, "N D K"] + ) -> Float[np.ndarray, "D 1"]: weights = self._get_weights(obs) np.testing.assert_almost_equal( @@ -28,13 +36,21 @@ def get_weights(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, "D 1"] assert min(weights >= 0), "Weights should be non-negative" return weights - def __call__(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, "N 1"]: - return self.weight_obs(obs) + def __call__( + self, obs: Float[np.ndarray, "N D"] | Float[np.ndarray, "N D K"] + ) -> Float[np.ndarray, "N 1"] | Float[np.ndarray, "N 1 K"]: + return self.weight_contr(obs) - def weight_obs(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, "N 1"]: + def weight_contr( + self, obs: Float[np.ndarray, "N D"] | Float[np.ndarray, "N D K"] + ) -> Float[np.ndarray, "N 1"] | Float[np.ndarray, "N 1 K"]: weights = self.get_weights(obs) - weighted_obs = obs @ weights + if obs.ndim == _NDIM_2D: + weighted_obs = obs @ weights + elif obs.ndim == _NDIM_3D: + weighted_obs = np.einsum("n d k, d i -> n i k", obs, weights) + return weighted_obs @@ -42,7 +58,9 @@ def weight_obs(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, "N 1"]: class UniformWeights(AbstractWeights): name: str = "Uniform Weights" - def _get_weights(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, "D 1"]: + def _get_weights( + self, obs: Float[np.ndarray, "N D"] | Float[np.ndarray, "N D K"] + ) -> Float[np.ndarray, "D 1"]: n_units = obs.shape[1] return np.repeat(1.0 / n_units, repeats=n_units).reshape(-1, 1) diff --git a/tests/test_causal_validation/test_config.py b/tests/test_causal_validation/test_config.py new file mode 100644 index 0000000..afca1b4 --- /dev/null +++ b/tests/test_causal_validation/test_config.py @@ -0,0 +1,80 @@ +import numpy as np +from hypothesis import given, strategies as st + +from causal_validation.config import Config + + +@given( + n_units=st.integers(min_value=1, max_value=10), + n_pre=st.integers(min_value=1, max_value=20), + n_post=st.integers(min_value=1, max_value=20) +) +def test_config_basic_initialization(n_units, n_pre, n_post): + cfg = Config( + n_control_units=n_units, + n_pre_intervention_timepoints=n_pre, + n_post_intervention_timepoints=n_post + ) + assert cfg.n_control_units == n_units + assert cfg.n_pre_intervention_timepoints == n_pre + assert cfg.n_post_intervention_timepoints == n_post + assert cfg.n_covariates is None + assert cfg.covariate_means is None + assert cfg.covariate_stds is None + assert cfg.covariate_coeffs is None + + +@given( + n_units=st.integers(min_value=1, max_value=5), + n_pre=st.integers(min_value=1, max_value=10), + n_post=st.integers(min_value=1, max_value=10), + n_covariates=st.integers(min_value=1, max_value=3), + seed=st.integers(min_value=1, max_value=1000) +) +def test_config_with_covariates_auto_generation( + n_units, n_pre, n_post, n_covariates, seed +): + cfg = Config( + n_control_units=n_units, + n_pre_intervention_timepoints=n_pre, + n_post_intervention_timepoints=n_post, + n_covariates=n_covariates, + seed=seed + ) + assert cfg.n_covariates == n_covariates + assert cfg.covariate_means.shape == (n_units, n_covariates) + assert cfg.covariate_stds.shape == (n_units, n_covariates) + assert cfg.covariate_coeffs.shape == (n_covariates,) + assert np.all(cfg.covariate_stds >= 0) + + +@given( + n_units=st.integers(min_value=1, max_value=3), + n_covariates=st.integers(min_value=1, max_value=3) +) +def test_config_with_explicit_covariate_means(n_units, n_covariates): + means = np.random.random((n_units, n_covariates)) + cfg = Config( + n_control_units=n_units, + n_pre_intervention_timepoints=10, + n_post_intervention_timepoints=5, + n_covariates=n_covariates, + covariate_means=means + ) + np.testing.assert_array_equal(cfg.covariate_means, means) + + +@given( + n_units=st.integers(min_value=1, max_value=3), + n_covariates=st.integers(min_value=1, max_value=3) +) +def test_config_with_explicit_covariate_stds(n_units, n_covariates): + stds = np.random.random((n_units, n_covariates)) + 0.1 + cfg = Config( + n_control_units=n_units, + n_pre_intervention_timepoints=10, + n_post_intervention_timepoints=5, + n_covariates=n_covariates, + covariate_stds=stds + ) + np.testing.assert_array_equal(cfg.covariate_stds, stds) diff --git a/tests/test_causal_validation/test_simulate.py b/tests/test_causal_validation/test_simulate.py new file mode 100644 index 0000000..52213dc --- /dev/null +++ b/tests/test_causal_validation/test_simulate.py @@ -0,0 +1,145 @@ +import numpy as np +from hypothesis import given, strategies as st + +from causal_validation.config import Config +from causal_validation.simulate import simulate + + +@given( + n_units=st.integers(min_value=1, max_value=5), + n_pre=st.integers(min_value=1, max_value=10), + n_post=st.integers(min_value=1, max_value=10), + seed=st.integers(min_value=1, max_value=1000) +) +def test_simulate_basic(n_units, n_pre, n_post, seed): + cfg = Config( + n_control_units=n_units, + n_pre_intervention_timepoints=n_pre, + n_post_intervention_timepoints=n_post, + seed=seed + ) + data = simulate(cfg) + + assert data.Xtr.shape == (n_pre, n_units) + assert data.Xte.shape == (n_post, n_units) + assert data.ytr.shape == (n_pre, 1) + assert data.yte.shape == (n_post, 1) + assert not data.has_covariates + + +@given( + n_units=st.integers(min_value=1, max_value=5), + n_pre=st.integers(min_value=1, max_value=10), + n_post=st.integers(min_value=1, max_value=10), + n_covariates=st.integers(min_value=1, max_value=3), + seed=st.integers(min_value=1, max_value=1000) +) +def test_simulate_with_covariates(n_units, n_pre, n_post, n_covariates, seed): + cfg = Config( + n_control_units=n_units, + n_pre_intervention_timepoints=n_pre, + n_post_intervention_timepoints=n_post, + n_covariates=n_covariates, + seed=seed + ) + data = simulate(cfg) + + assert data.Xtr.shape == (n_pre, n_units) + assert data.Xte.shape == (n_post, n_units) + assert data.ytr.shape == (n_pre, 1) + assert data.yte.shape == (n_post, 1) + assert data.has_covariates + assert data.Ptr.shape == (n_pre, n_units, n_covariates) + assert data.Pte.shape == (n_post, n_units, n_covariates) + assert data.Rtr.shape == (n_pre, 1, n_covariates) + assert data.Rte.shape == (n_post, 1, n_covariates) + + +@given( + n_units=st.integers(min_value=1, max_value=5), + n_pre=st.integers(min_value=1, max_value=10), + n_post=st.integers(min_value=1, max_value=10), + seed=st.integers(min_value=1, max_value=1000) +) +def test_simulate_reproducible(n_units, n_pre, n_post, seed): + cfg1 = Config( + n_control_units=n_units, + n_pre_intervention_timepoints=n_pre, + n_post_intervention_timepoints=n_post, + seed=seed + ) + data1 = simulate(cfg1) + + cfg2 = Config( + n_control_units=n_units, + n_pre_intervention_timepoints=n_pre, + n_post_intervention_timepoints=n_post, + seed=seed + ) + data2 = simulate(cfg2) + + np.testing.assert_array_equal(data1.Xtr, data2.Xtr) + np.testing.assert_array_equal(data1.Xte, data2.Xte) + np.testing.assert_array_equal(data1.ytr, data2.ytr) + np.testing.assert_array_equal(data1.yte, data2.yte) + + +@given( + n_units=st.integers(min_value=1, max_value=3), + n_pre=st.integers(min_value=3, max_value=10), + n_post=st.integers(min_value=1, max_value=5), + seed=st.integers(min_value=1, max_value=1000) +) +def test_simulate_covariate_effects(n_units, n_pre, n_post, seed): + cfg = Config( + n_control_units=n_units, + n_pre_intervention_timepoints=n_pre, + n_post_intervention_timepoints=n_post, + n_covariates=1, + covariate_coeffs=np.array([10.0]), + seed=seed + ) + data_with_cov = simulate(cfg) + + cfg_no_cov = Config( + n_control_units=n_units, + n_pre_intervention_timepoints=n_pre, + n_post_intervention_timepoints=n_post, + seed=seed + ) + data_no_cov = simulate(cfg_no_cov) + + assert not np.allclose(data_with_cov.ytr, data_no_cov.ytr) + assert not np.allclose(data_with_cov.yte, data_no_cov.yte) + + +@given( + n_units=st.integers(min_value=1, max_value=3), + n_pre=st.integers(min_value=3, max_value=10), + n_post=st.integers(min_value=1, max_value=5), + seed=st.integers(min_value=1, max_value=1000) +) +def test_simulate_exact_covariate_effects(n_units, n_pre, n_post, seed): + cfg = Config( + n_control_units=n_units, + n_pre_intervention_timepoints=n_pre, + n_post_intervention_timepoints=n_post, + n_covariates=2, + covariate_means=np.ones((n_units,2)), + covariate_stds= 1e-12*np.ones((n_units,2)), + covariate_coeffs=np.array([10.0, 5.0]), + seed=seed + ) + data_with_cov = simulate(cfg) + + cfg_no_cov = Config( + n_control_units=n_units, + n_pre_intervention_timepoints=n_pre, + n_post_intervention_timepoints=n_post, + seed=seed + ) + data_no_cov = simulate(cfg_no_cov) + + assert np.allclose(data_with_cov.Xtr-15, data_no_cov.Xtr) + assert np.allclose(data_with_cov.Xte-15, data_no_cov.Xte) + diff --git a/tests/test_causal_validation/test_weights.py b/tests/test_causal_validation/test_weights.py index 1b5ed6d..543d711 100644 --- a/tests/test_causal_validation/test_weights.py +++ b/tests/test_causal_validation/test_weights.py @@ -23,10 +23,40 @@ def test_uniform_weights(n_units: int, n_time: int): n_units=st.integers(min_value=1, max_value=100), n_time=st.integers(min_value=1, max_value=100), ) -def test_weight_obs(n_units: int, n_time: int): +def test_weight_contr(n_units: int, n_time: int): obs = np.ones(shape=(n_time, n_units)) weighted_obs = UniformWeights()(obs) np.testing.assert_almost_equal(np.mean(weighted_obs), weighted_obs, decimal=6) np.testing.assert_almost_equal( obs @ UniformWeights().get_weights(obs), weighted_obs, decimal=6 ) + + +@given( + n_units=st.integers(min_value=1, max_value=10), + n_time=st.integers(min_value=1, max_value=10), + n_covariates=st.integers(min_value=1, max_value=5), +) +def test_weight_contr_3d(n_units: int, n_time: int, n_covariates: int): + covariates = np.ones(shape=(n_time, n_units, n_covariates)) + weights = UniformWeights() + weighted_covs = weights.weight_contr(covariates) + + assert weighted_covs.shape == (n_time, 1, n_covariates) + expected = np.einsum("n d k, d i -> n i k", + covariates, weights.get_weights(covariates)) + np.testing.assert_almost_equal(weighted_covs, expected, decimal=6) + + +def test_weights_sum_to_one(): + obs = np.random.random((10, 5)) + weights = UniformWeights() + weight_vals = weights.get_weights(obs) + np.testing.assert_almost_equal(weight_vals.sum(), 1.0, decimal=6) + + +def test_weights_non_negative(): + obs = np.random.random((10, 5)) + weights = UniformWeights() + weight_vals = weights.get_weights(obs) + assert np.all(weight_vals >= 0) From 3675aebf7f3edcdedb62cea138684fbaae6a23ec Mon Sep 17 00:00:00 2001 From: Semih Akbayrak Date: Fri, 26 Sep 2025 10:24:42 +0000 Subject: [PATCH 5/7] Dataset to_df revision to support covariates properly. --- src/causal_validation/data.py | 44 +++++++++++++++------- src/causal_validation/testing.py | 2 + tests/test_causal_validation/test_data.py | 46 ++++++++++++++++++++++- 3 files changed, 76 insertions(+), 16 deletions(-) diff --git a/src/causal_validation/data.py b/src/causal_validation/data.py index a81f13a..e7922a8 100644 --- a/src/causal_validation/data.py +++ b/src/causal_validation/data.py @@ -62,16 +62,36 @@ def __post_init__(self): def to_df( self, index_start: str = dt.date(year=2023, month=1, day=1) - ) -> pd.DataFrame: - inputs = np.vstack([self.Xtr, self.Xte]) - outputs = np.vstack([self.ytr, self.yte]) - data = np.hstack([outputs, inputs]) + ) -> tp.Tuple[pd.DataFrame, tp.Optional[pd.DataFrame]]: + control_outputs = np.vstack([self.Xtr, self.Xte]) + treated_outputs = np.vstack([self.ytr, self.yte]) + data = np.hstack([treated_outputs, control_outputs]) index = self._get_index(index_start) colnames = self._get_columns() indicator = self._get_indicator() - df = pd.DataFrame(data, index=index, columns=colnames) - df = df.assign(treated=indicator) - return df + df_outputs = pd.DataFrame(data, index=index, columns=colnames) + df_outputs = df_outputs.assign(treated=indicator) + + if not self.has_covariates: + cov_df = None + else: + control_covs = np.concatenate([self.Ptr, self.Pte], axis=0) + treated_covs = np.concatenate([self.Rtr, self.Rte], axis=0) + + all_covs = np.concatenate([treated_covs, control_covs], axis=1) + + unit_cols = self._get_columns() + covariate_cols = [f"F{i}" for i in range(self.n_covariates)] + + cov_data = all_covs.reshape(self.n_timepoints, -1) + + col_tuples = [(unit, cov) for unit in unit_cols for cov in covariate_cols] + multi_cols = pd.MultiIndex.from_tuples(col_tuples) + + cov_df = pd.DataFrame(cov_data, index=index, columns=multi_cols) + cov_df = cov_df.assign(treated=indicator) + + return df_outputs, cov_df @property def n_post_intervention(self) -> int: @@ -180,12 +200,7 @@ def get_index(self, period: InterventionTypes) -> DatetimeIndex: return self.full_index def _get_columns(self) -> tp.List[str]: - if self.has_covariates: - colnames = ["T"] + [f"C{i}" for i in range(self.n_units)] + [ - f"F{i}" for i in range(self.n_covariates) - ] - else: - colnames = ["T"] + [f"C{i}" for i in range(self.n_units)] + colnames = ["T"] + [f"C{i}" for i in range(self.n_units)] return colnames def _get_index(self, start_date: dt.date) -> DatetimeIndex: @@ -224,7 +239,8 @@ def __eq__(self, other: Dataset) -> bool: def to_azcausal(self): time_index = np.arange(self.n_timepoints) - data = self.to_df().assign(time=time_index).melt(id_vars=["time", "treated"]) + data_df, _ = self.to_df() + data = data_df.assign(time=time_index).melt(id_vars=["time", "treated"]) data.loc[:, "treated"] = np.where( (data["variable"] == "T") & (data["treated"] == 1.0), 1, 0 ) diff --git a/src/causal_validation/testing.py b/src/causal_validation/testing.py index a1bcb22..56a1561 100644 --- a/src/causal_validation/testing.py +++ b/src/causal_validation/testing.py @@ -11,6 +11,7 @@ class TestConstants: N_CONTROL: int = 10 N_PRE_TREATMENT: int = 500 N_POST_TREATMENT: int = 500 + N_COVARIATES: tp.Optional[int] = None DATA_SLOTS: tp.Tuple[str, str, str, str] = ("Xtr", "Xte", "ytr", "yte") ZERO_DIVISION_ERROR: float = 1e-6 GLOBAL_SCALE: float = 1.0 @@ -26,6 +27,7 @@ def simulate_data( n_control_units=constants.N_CONTROL, n_pre_intervention_timepoints=constants.N_PRE_TREATMENT, n_post_intervention_timepoints=constants.N_POST_TREATMENT, + n_covariates=constants.N_COVARIATES, global_mean=global_mean, global_scale=constants.GLOBAL_SCALE, seed=seed, diff --git a/tests/test_causal_validation/test_data.py b/tests/test_causal_validation/test_data.py index 1b7705c..28b4638 100644 --- a/tests/test_causal_validation/test_data.py +++ b/tests/test_causal_validation/test_data.py @@ -29,6 +29,7 @@ MAX_STRING_LENGTH = 20 DEFAULT_SEED = 123 NUM_NON_CONTROL_COLS = 2 +NUM_TREATED = 1 LARGE_N_POST = 5000 LARGE_N_PRE = 5000 @@ -109,7 +110,7 @@ def test_indicator(n_pre_treatment: int, n_post_treatment: int): n_pre_treatment=st.integers(min_value=1, max_value=50), n_post_treatment=st.integers(min_value=1, max_value=50), ) -def test_to_df(n_control: int, n_pre_treatment: int, n_post_treatment: int): +def test_to_df_no_cov(n_control: int, n_pre_treatment: int, n_post_treatment: int): constants = TestConstants( N_POST_TREATMENT=n_post_treatment, N_PRE_TREATMENT=n_pre_treatment, @@ -117,7 +118,7 @@ def test_to_df(n_control: int, n_pre_treatment: int, n_post_treatment: int): ) data = simulate_data(0.0, DEFAULT_SEED, constants=constants) - df = data.to_df() + df, _ = data.to_df() assert isinstance(df, pd.DataFrame) assert df.shape == ( n_pre_treatment + n_post_treatment, @@ -133,6 +134,47 @@ def test_to_df(n_control: int, n_pre_treatment: int, n_post_treatment: int): assert isinstance(index, DatetimeIndex) assert index[0].strftime("%Y-%m-%d") == data._start_date.strftime("%Y-%m-%d") +@given( + n_control=st.integers(min_value=1, max_value=50), + n_pre_treatment=st.integers(min_value=1, max_value=50), + n_post_treatment=st.integers(min_value=1, max_value=50), + n_covariates=st.integers(min_value=1, max_value=50), +) +def test_to_df_with_cov(n_control: int, + n_pre_treatment: int, + n_post_treatment: int, + n_covariates:int): + constants = TestConstants( + N_POST_TREATMENT=n_post_treatment, + N_PRE_TREATMENT=n_pre_treatment, + N_CONTROL=n_control, + N_COVARIATES=n_covariates, + ) + data = simulate_data(0.0, DEFAULT_SEED, constants=constants) + + df_outs, df_covs = data.to_df() + assert isinstance(df_outs, pd.DataFrame) + assert df_outs.shape == ( + n_pre_treatment + n_post_treatment, + n_control + NUM_NON_CONTROL_COLS, + ) + + assert isinstance(df_covs, pd.DataFrame) + assert df_covs.shape == ( + n_pre_treatment + n_post_treatment, + n_covariates * (n_control + NUM_TREATED) + + NUM_NON_CONTROL_COLS - NUM_TREATED, + ) + + colnames = data._get_columns() + assert isinstance(colnames, list) + assert colnames[0] == "T" + assert len(colnames) == n_control + 1 + + index = data.full_index + assert isinstance(index, DatetimeIndex) + assert index[0].strftime("%Y-%m-%d") == data._start_date.strftime("%Y-%m-%d") + @given( n_control=st.integers(min_value=2, max_value=50), From 94cfe183ee363d525eefc1996b28133880ee29c8 Mon Sep 17 00:00:00 2001 From: Semih Akbayrak Date: Fri, 26 Sep 2025 11:35:47 +0000 Subject: [PATCH 6/7] Preserve covariates after transformation --- src/causal_validation/transforms/base.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/causal_validation/transforms/base.py b/src/causal_validation/transforms/base.py index 6ef7a97..03bd62e 100644 --- a/src/causal_validation/transforms/base.py +++ b/src/causal_validation/transforms/base.py @@ -75,7 +75,9 @@ def apply_values( Xte = Xte + post_intervention_vals[:, 1:] yte = yte + post_intervention_vals[:, :1] return Dataset( - Xtr, Xte, ytr, yte, data._start_date, data.counterfactual, data.synthetic + Xtr, Xte, ytr, yte, data._start_date, + data.Ptr, data.Pte, data.Rtr, data.Rte, + data.counterfactual, data.synthetic ) @@ -94,5 +96,7 @@ def apply_values( Xte = Xte * post_intervention_vals yte = yte * post_intervention_vals return Dataset( - Xtr, Xte, ytr, yte, data._start_date, data.counterfactual, data.synthetic + Xtr, Xte, ytr, yte, data._start_date, + data.Ptr, data.Pte, data.Rtr, data.Rte, + data.counterfactual, data.synthetic ) From 13ffdb539797d24f8e2015f1008fb563e9a3f2f0 Mon Sep 17 00:00:00 2001 From: Semih Akbayrak Date: Fri, 26 Sep 2025 15:33:23 +0000 Subject: [PATCH 7/7] Add noise transformation for covariates --- src/causal_validation/transforms/__init__.py | 4 +- src/causal_validation/transforms/base.py | 24 +++- src/causal_validation/transforms/noise.py | 34 ++++- src/causal_validation/transforms/parameter.py | 12 ++ src/causal_validation/transforms/periodic.py | 4 +- src/causal_validation/transforms/trends.py | 4 +- .../test_transforms/test_noise.py | 133 +++++++++++++++++- 7 files changed, 203 insertions(+), 12 deletions(-) diff --git a/src/causal_validation/transforms/__init__.py b/src/causal_validation/transforms/__init__.py index c5cf07f..44ac584 100644 --- a/src/causal_validation/transforms/__init__.py +++ b/src/causal_validation/transforms/__init__.py @@ -1,5 +1,5 @@ -from causal_validation.transforms.noise import Noise +from causal_validation.transforms.noise import Noise, CovariateNoise from causal_validation.transforms.periodic import Periodic from causal_validation.transforms.trends import Trend -__all__ = ["Trend", "Periodic", "Noise"] +__all__ = ["Trend", "Periodic", "Noise", "CovariateNoise"] diff --git a/src/causal_validation/transforms/base.py b/src/causal_validation/transforms/base.py index 03bd62e..c4258fe 100644 --- a/src/causal_validation/transforms/base.py +++ b/src/causal_validation/transforms/base.py @@ -61,7 +61,7 @@ def _get_parameter_values(self, data: Dataset) -> tp.Dict[str, np.ndarray]: @dataclass(kw_only=True) -class AdditiveTransform(AbstractTransform): +class AdditiveOutputTransform(AbstractTransform): def apply_values( self, pre_intervention_vals: np.ndarray, @@ -82,7 +82,7 @@ def apply_values( @dataclass(kw_only=True) -class MultiplicativeTransform(AbstractTransform): +class MultiplicativeOutputTransform(AbstractTransform): def apply_values( self, pre_intervention_vals: np.ndarray, @@ -100,3 +100,23 @@ def apply_values( data.Ptr, data.Pte, data.Rtr, data.Rte, data.counterfactual, data.synthetic ) + +@dataclass(kw_only=True) +class AdditiveCovariateTransform(AbstractTransform): + def apply_values( + self, + pre_intervention_vals: np.ndarray, + post_intervention_vals: np.ndarray, + data: Dataset, + ) -> Dataset: + Ptr, Rtr = [deepcopy(i) for i in data.pre_intervention_covariates] + Pte, Rte = [deepcopy(i) for i in data.post_intervention_covariates] + Ptr = Ptr + pre_intervention_vals[:, 1:, :] + Rtr = Rtr + pre_intervention_vals[:, :1, :] + Pte = Pte + post_intervention_vals[:, 1:, :] + Rte = Rte + post_intervention_vals[:, :1, :] + return Dataset( + data.Xtr, data.Xte, data.ytr, data.yte, + data._start_date, Ptr, Pte, Rtr, Rte, + data.counterfactual, data.synthetic + ) diff --git a/src/causal_validation/transforms/noise.py b/src/causal_validation/transforms/noise.py index 1116471..e251225 100644 --- a/src/causal_validation/transforms/noise.py +++ b/src/causal_validation/transforms/noise.py @@ -6,12 +6,18 @@ from scipy.stats import norm from causal_validation.data import Dataset -from causal_validation.transforms.base import AdditiveTransform -from causal_validation.transforms.parameter import TimeVaryingParameter +from causal_validation.transforms.base import ( + AdditiveOutputTransform, + AdditiveCovariateTransform +) +from causal_validation.transforms.parameter import ( + TimeVaryingParameter, + CovariateNoiseParameter +) @dataclass(kw_only=True) -class Noise(AdditiveTransform): +class Noise(AdditiveOutputTransform): """ Transform the treatment by adding TimeVaryingParameter noise terms sampled from a specified sampling distribution. By default, the sampling distribution is @@ -30,3 +36,25 @@ def get_values(self, data: Dataset) -> Float[np.ndarray, "N D"]: ).reshape(-1) noise[:, 0] = noise_treatment return noise + + +@dataclass(kw_only=True) +class CovariateNoise(AdditiveCovariateTransform): + """ + Transform the covariates by adding CovariateNoiseParameter noise terms sampled from + a specified sampling distribution. By default, the sampling distribution is + Normal with 0 loc and 0.1 scale. + """ + + noise_dist: CovariateNoiseParameter = field( + default_factory=lambda: CovariateNoiseParameter(sampling_dist=norm(0, 0.1)) + ) + _slots: Tuple[str] = ("noise_dist",) + + def get_values(self, data: Dataset) -> Float[np.ndarray, "N D"]: + noise = self.noise_dist.get_value( + n_units=data.n_units+1, + n_timepoints=data.n_timepoints, + n_covariates=data.n_covariates + ) + return noise diff --git a/src/causal_validation/transforms/parameter.py b/src/causal_validation/transforms/parameter.py index 5ff9362..5806b80 100644 --- a/src/causal_validation/transforms/parameter.py +++ b/src/causal_validation/transforms/parameter.py @@ -51,6 +51,18 @@ def get_value( return np.tile(time_param, reps=n_units) +@dataclass +class CovariateNoiseParameter(RandomParameter): + def get_value( + self, n_units: int, n_timepoints: int, n_covariates: int + ) -> Float[np.ndarray, "{n_timepoints} {n_units} {n_covariates}"]: + covariate_noise = self.sampling_dist.rvs( + size=(n_timepoints, n_units, n_covariates), + random_state=self.random_state + ) + return covariate_noise + + ParameterOrFloat = tp.Union[Parameter, float] diff --git a/src/causal_validation/transforms/periodic.py b/src/causal_validation/transforms/periodic.py index d9a1376..9283cb0 100644 --- a/src/causal_validation/transforms/periodic.py +++ b/src/causal_validation/transforms/periodic.py @@ -5,12 +5,12 @@ import numpy as np from causal_validation.data import Dataset -from causal_validation.transforms.base import AdditiveTransform +from causal_validation.transforms.base import AdditiveOutputTransform from causal_validation.transforms.parameter import ParameterOrFloat @dataclass(kw_only=True) -class Periodic(AdditiveTransform): +class Periodic(AdditiveOutputTransform): amplitude: ParameterOrFloat = 1.0 frequency: ParameterOrFloat = 1.0 shift: ParameterOrFloat = 0.0 diff --git a/src/causal_validation/transforms/trends.py b/src/causal_validation/transforms/trends.py index 56018d5..6143f83 100644 --- a/src/causal_validation/transforms/trends.py +++ b/src/causal_validation/transforms/trends.py @@ -5,12 +5,12 @@ import numpy as np from causal_validation.data import Dataset -from causal_validation.transforms.base import AdditiveTransform +from causal_validation.transforms.base import AdditiveOutputTransform from causal_validation.transforms.parameter import ParameterOrFloat @dataclass(kw_only=True) -class Trend(AdditiveTransform): +class Trend(AdditiveOutputTransform): degree: int = 1 coefficient: ParameterOrFloat = 1.0 intercept: ParameterOrFloat = 0.0 diff --git a/tests/test_causal_validation/test_transforms/test_noise.py b/tests/test_causal_validation/test_transforms/test_noise.py index 94608c3..8ffef8c 100644 --- a/tests/test_causal_validation/test_transforms/test_noise.py +++ b/tests/test_causal_validation/test_transforms/test_noise.py @@ -11,10 +11,14 @@ simulate_data, ) from causal_validation.transforms import ( + CovariateNoise, Noise, Trend, ) -from causal_validation.transforms.parameter import TimeVaryingParameter +from causal_validation.transforms.parameter import ( + CovariateNoiseParameter, + TimeVaryingParameter, +) CONSTANTS = TestConstants() DEFAULT_SEED = 123 @@ -125,3 +129,130 @@ def test_perturbation_impact( assert np.min(diff_te_list[0]) > np.min(diff_te_list[1]) assert np.max(diff_te_list[0]) < np.max(diff_te_list[2]) assert np.min(diff_te_list[0]) < np.min(diff_te_list[2]) + +# Covariate Noise Test +def test_cov_slot_type(): + noise_transform = CovariateNoise() + assert isinstance(noise_transform.noise_dist, CovariateNoiseParameter) + + +@given(n_covariates=st.integers(min_value=1, max_value=50)) +@settings(max_examples=5) +def test_output_covariate_transform(n_covariates:int): + CONSTANTS2 = TestConstants(N_COVARIATES=n_covariates) + base_data = simulate_data(GLOBAL_MEAN, DEFAULT_SEED, CONSTANTS2) + + covariate_noise_transform = CovariateNoise() + noisy_data = covariate_noise_transform(base_data) + + assert np.all(noisy_data.ytr == base_data.ytr) + assert np.all(noisy_data.yte == base_data.yte) + assert np.all(noisy_data.Xtr == base_data.Xtr) + assert np.all(noisy_data.Xte == base_data.Xte) + + diff_Ptr = (noisy_data.Ptr - base_data.Ptr).reshape(-1) + diff_Pte = (noisy_data.Pte - base_data.Pte).reshape(-1) + + assert np.all(diff_Ptr != diff_Pte) + + diff_Rtr = (noisy_data.Rtr - base_data.Rtr).reshape(-1) + diff_Rte = (noisy_data.Rte - base_data.Rte).reshape(-1) + + assert np.all(diff_Rtr != diff_Rte) + + diff_Ptr_permute = np.random.permutation(diff_Ptr) + diff_Pte_permute = np.random.permutation(diff_Pte) + diff_Rtr_permute = np.random.permutation(diff_Rtr) + diff_Rte_permute = np.random.permutation(diff_Rte) + + assert not np.all(diff_Ptr == diff_Ptr_permute) + assert not np.all(diff_Pte == diff_Pte_permute) + assert not np.all(diff_Rtr == diff_Rtr_permute) + assert not np.all(diff_Rte == diff_Rte_permute) + + +@given(n_covariates=st.integers(min_value=1, max_value=50)) +@settings(max_examples=5) +def test_cov_composite_transform(n_covariates:int): + CONSTANTS2 = TestConstants(N_COVARIATES=n_covariates) + base_data = simulate_data(GLOBAL_MEAN, DEFAULT_SEED, CONSTANTS2) + + covariate_noise_transform = CovariateNoise() + cov_noisy_data = covariate_noise_transform(base_data) + + noise_transform = Noise() + noisy_data = noise_transform(cov_noisy_data) + + assert np.all(noisy_data.Xtr == cov_noisy_data.Xtr) + assert np.all(noisy_data.Xte == cov_noisy_data.Xte) + assert np.all(noisy_data.Ptr == cov_noisy_data.Ptr) + assert np.all(noisy_data.Pte == cov_noisy_data.Pte) + assert np.all(noisy_data.Rtr == cov_noisy_data.Rtr) + assert np.all(noisy_data.Rte == cov_noisy_data.Rte) + assert np.all(noisy_data.ytr != cov_noisy_data.ytr) + assert np.all(noisy_data.yte != cov_noisy_data.yte) + + +@given( + loc_large=st.floats(min_value=10.0, max_value=15.0), + loc_small=st.floats(min_value=-2.5, max_value=2.5), + scale_large=st.floats(min_value=10.0, max_value=15.0), + scale_small=st.floats(min_value=0.1, max_value=1.0), + n_covariates=st.integers(min_value=1, max_value=50), +) +@settings(max_examples=5) +def test_cov_perturbation_impact( + loc_large: float, + loc_small: float, + scale_large: float, + scale_small: float, + n_covariates:int +): + CONSTANTS2 = TestConstants(N_COVARIATES=n_covariates) + base_data = simulate_data(GLOBAL_MEAN, DEFAULT_SEED, CONSTANTS2) + + noise_transform1 = CovariateNoise( + noise_dist=CovariateNoiseParameter(sampling_dist=norm(loc_small, scale_small)) + ) + noise_transform2 = CovariateNoise( + noise_dist=CovariateNoiseParameter(sampling_dist=norm(loc_small, scale_large)) + ) + noise_transform3 = CovariateNoise( + noise_dist=CovariateNoiseParameter(sampling_dist=norm(loc_large, scale_small)) + ) + + noise_transforms = [noise_transform1, noise_transform2, noise_transform3] + + diff_Rtr_list, diff_Rte_list = [], [] + diff_Ptr_list, diff_Pte_list = [], [] + + for noise_transform in noise_transforms: + noisy_data = noise_transform(base_data) + diff_Rtr = noisy_data.Rtr - base_data.Rtr + diff_Rte = noisy_data.Rte - base_data.Rte + diff_Ptr = noisy_data.Ptr - base_data.Ptr + diff_Pte = noisy_data.Pte - base_data.Pte + diff_Rtr_list.append(diff_Rtr) + diff_Rte_list.append(diff_Rte) + diff_Ptr_list.append(diff_Ptr) + diff_Pte_list.append(diff_Pte) + + assert np.max(diff_Rtr_list[0]) < np.max(diff_Rtr_list[1]) + assert np.min(diff_Rtr_list[0]) > np.min(diff_Rtr_list[1]) + assert np.max(diff_Rtr_list[0]) < np.max(diff_Rtr_list[2]) + assert np.min(diff_Rtr_list[0]) < np.min(diff_Rtr_list[2]) + + assert np.max(diff_Rte_list[0]) < np.max(diff_Rte_list[1]) + assert np.min(diff_Rte_list[0]) > np.min(diff_Rte_list[1]) + assert np.max(diff_Rte_list[0]) < np.max(diff_Rte_list[2]) + assert np.min(diff_Rte_list[0]) < np.min(diff_Rte_list[2]) + + assert np.max(diff_Ptr_list[0]) < np.max(diff_Ptr_list[1]) + assert np.min(diff_Ptr_list[0]) > np.min(diff_Ptr_list[1]) + assert np.max(diff_Ptr_list[0]) < np.max(diff_Ptr_list[2]) + assert np.min(diff_Ptr_list[0]) < np.min(diff_Ptr_list[2]) + + assert np.max(diff_Pte_list[0]) < np.max(diff_Pte_list[1]) + assert np.min(diff_Pte_list[0]) > np.min(diff_Pte_list[1]) + assert np.max(diff_Pte_list[0]) < np.max(diff_Pte_list[2]) + assert np.min(diff_Pte_list[0]) < np.min(diff_Pte_list[2])