diff --git a/src/causal_validation/data.py b/src/causal_validation/data.py
index a81f13a..e7922a8 100644
--- a/src/causal_validation/data.py
+++ b/src/causal_validation/data.py
@@ -62,16 +62,36 @@ def __post_init__(self):
 
     def to_df(
         self, index_start: str = dt.date(year=2023, month=1, day=1)
-    ) -> pd.DataFrame:
-        inputs = np.vstack([self.Xtr, self.Xte])
-        outputs = np.vstack([self.ytr, self.yte])
-        data = np.hstack([outputs, inputs])
+    ) -> tp.Tuple[pd.DataFrame, tp.Optional[pd.DataFrame]]:
+        control_outputs = np.vstack([self.Xtr, self.Xte])
+        treated_outputs = np.vstack([self.ytr, self.yte])
+        data = np.hstack([treated_outputs, control_outputs])
         index = self._get_index(index_start)
         colnames = self._get_columns()
         indicator = self._get_indicator()
-        df = pd.DataFrame(data, index=index, columns=colnames)
-        df = df.assign(treated=indicator)
-        return df
+        df_outputs = pd.DataFrame(data, index=index, columns=colnames)
+        df_outputs = df_outputs.assign(treated=indicator)
+
+        if not self.has_covariates:
+            cov_df = None
+        else:
+            control_covs = np.concatenate([self.Ptr, self.Pte], axis=0)
+            treated_covs = np.concatenate([self.Rtr, self.Rte], axis=0)
+
+            all_covs = np.concatenate([treated_covs, control_covs], axis=1)
+
+            unit_cols = self._get_columns()
+            covariate_cols = [f"F{i}" for i in range(self.n_covariates)]
+
+            cov_data = all_covs.reshape(self.n_timepoints, -1)
+
+            col_tuples = [(unit, cov) for unit in unit_cols for cov in covariate_cols]
+            multi_cols = pd.MultiIndex.from_tuples(col_tuples)
+
+            cov_df = pd.DataFrame(cov_data, index=index, columns=multi_cols)
+            cov_df = cov_df.assign(treated=indicator)
+
+        return df_outputs, cov_df
 
     @property
     def n_post_intervention(self) -> int:
@@ -180,12 +200,7 @@ def get_index(self, period: InterventionTypes) -> DatetimeIndex:
         return self.full_index
 
     def _get_columns(self) -> tp.List[str]:
-        if self.has_covariates:
-            colnames = ["T"] + [f"C{i}" for i in range(self.n_units)] + [
-                f"F{i}" for i in range(self.n_covariates)
-            ]
-        else:
-            colnames = ["T"] + [f"C{i}" for i in range(self.n_units)]
+        colnames = ["T"] + [f"C{i}" for i in range(self.n_units)]
         return colnames
 
     def _get_index(self, start_date: dt.date) -> DatetimeIndex:
@@ -224,7 +239,8 @@ def __eq__(self, other: Dataset) -> bool:
 
     def to_azcausal(self):
         time_index = np.arange(self.n_timepoints)
-        data = self.to_df().assign(time=time_index).melt(id_vars=["time", "treated"])
+        data_df, _ = self.to_df()
+        data = data_df.assign(time=time_index).melt(id_vars=["time", "treated"])
         data.loc[:, "treated"] = np.where(
             (data["variable"] == "T") & (data["treated"] == 1.0), 1, 0
         )
diff --git a/src/causal_validation/testing.py b/src/causal_validation/testing.py
index a1bcb22..56a1561 100644
--- a/src/causal_validation/testing.py
+++ b/src/causal_validation/testing.py
@@ -11,6 +11,7 @@ class TestConstants:
     N_CONTROL: int = 10
     N_PRE_TREATMENT: int = 500
     N_POST_TREATMENT: int = 500
+    N_COVARIATES: tp.Optional[int] = None
     DATA_SLOTS: tp.Tuple[str, str, str, str] = ("Xtr", "Xte", "ytr", "yte")
     ZERO_DIVISION_ERROR: float = 1e-6
     GLOBAL_SCALE: float = 1.0
@@ -26,6 +27,7 @@ def simulate_data(
         n_control_units=constants.N_CONTROL,
         n_pre_intervention_timepoints=constants.N_PRE_TREATMENT,
         n_post_intervention_timepoints=constants.N_POST_TREATMENT,
+        n_covariates=constants.N_COVARIATES,
         global_mean=global_mean,
         global_scale=constants.GLOBAL_SCALE,
         seed=seed,
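Note on the new contract: `to_df` now returns a pair, with the covariate frame `None` whenever the dataset carries no covariates. A minimal round-trip sketch, assuming the `simulate_data`/`TestConstants` helpers behave as patched above (the concrete sizes here are illustrative, not from the patch):

```python
from causal_validation.testing import TestConstants, simulate_data

constants = TestConstants(
    N_CONTROL=3, N_PRE_TREATMENT=5, N_POST_TREATMENT=5, N_COVARIATES=2
)
data = simulate_data(0.0, 123, constants=constants)

# Outputs frame: treated column "T", controls C0..C2, plus the treated indicator.
df_outputs, cov_df = data.to_df()
assert list(df_outputs.columns) == ["T", "C0", "C1", "C2", "treated"]
assert df_outputs.shape == (10, 5)

# Covariate frame: MultiIndex columns keyed by (unit, covariate), e.g. ("T", "F0"),
# plus the same treated indicator: n_covariates * (n_control + 1) + 1 columns.
assert ("T", "F0") in cov_df.columns
assert cov_df.shape == (10, 2 * (3 + 1) + 1)
```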
diff --git a/src/causal_validation/transforms/__init__.py b/src/causal_validation/transforms/__init__.py
index c5cf07f..44ac584 100644
--- a/src/causal_validation/transforms/__init__.py
+++ b/src/causal_validation/transforms/__init__.py
@@ -1,5 +1,5 @@
-from causal_validation.transforms.noise import Noise
+from causal_validation.transforms.noise import Noise, CovariateNoise
 from causal_validation.transforms.periodic import Periodic
 from causal_validation.transforms.trends import Trend
 
-__all__ = ["Trend", "Periodic", "Noise"]
+__all__ = ["Trend", "Periodic", "Noise", "CovariateNoise"]
diff --git a/src/causal_validation/transforms/base.py b/src/causal_validation/transforms/base.py
index 6ef7a97..c4258fe 100644
--- a/src/causal_validation/transforms/base.py
+++ b/src/causal_validation/transforms/base.py
@@ -61,7 +61,7 @@ def _get_parameter_values(self, data: Dataset) -> tp.Dict[str, np.ndarray]:
 
 
 @dataclass(kw_only=True)
-class AdditiveTransform(AbstractTransform):
+class AdditiveOutputTransform(AbstractTransform):
     def apply_values(
         self,
         pre_intervention_vals: np.ndarray,
@@ -75,12 +75,14 @@ def apply_values(
         Xte = Xte + post_intervention_vals[:, 1:]
         yte = yte + post_intervention_vals[:, :1]
         return Dataset(
-            Xtr, Xte, ytr, yte, data._start_date, data.counterfactual, data.synthetic
+            Xtr, Xte, ytr, yte, data._start_date,
+            data.Ptr, data.Pte, data.Rtr, data.Rte,
+            data.counterfactual, data.synthetic
         )
 
 
 @dataclass(kw_only=True)
-class MultiplicativeTransform(AbstractTransform):
+class MultiplicativeOutputTransform(AbstractTransform):
     def apply_values(
         self,
         pre_intervention_vals: np.ndarray,
@@ -94,5 +96,27 @@ def apply_values(
         Xte = Xte * post_intervention_vals
         yte = yte * post_intervention_vals
         return Dataset(
-            Xtr, Xte, ytr, yte, data._start_date, data.counterfactual, data.synthetic
+            Xtr, Xte, ytr, yte, data._start_date,
+            data.Ptr, data.Pte, data.Rtr, data.Rte,
+            data.counterfactual, data.synthetic
+        )
+
+@dataclass(kw_only=True)
+class AdditiveCovariateTransform(AbstractTransform):
+    def apply_values(
+        self,
+        pre_intervention_vals: np.ndarray,
+        post_intervention_vals: np.ndarray,
+        data: Dataset,
+    ) -> Dataset:
+        Ptr, Rtr = [deepcopy(i) for i in data.pre_intervention_covariates]
+        Pte, Rte = [deepcopy(i) for i in data.post_intervention_covariates]
+        Ptr = Ptr + pre_intervention_vals[:, 1:, :]
+        Rtr = Rtr + pre_intervention_vals[:, :1, :]
+        Pte = Pte + post_intervention_vals[:, 1:, :]
+        Rte = Rte + post_intervention_vals[:, :1, :]
+        return Dataset(
+            data.Xtr, data.Xte, data.ytr, data.yte,
+            data._start_date, Ptr, Pte, Rtr, Rte,
+            data.counterfactual, data.synthetic
         )
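The renaming above splits the transform hierarchy: output transforms touch `Xtr`/`Xte`/`ytr`/`yte`, while the new `AdditiveCovariateTransform` touches only the covariate slots. As a quick illustration of the extension point, a hypothetical constant-shift covariate transform (not part of this patch; the `_slots` parameter plumbing used by `Noise`/`CovariateNoise` is omitted, and the value shape follows the `(timepoints, units, covariates)` convention that `CovariateNoise` uses below):

```python
from dataclasses import dataclass

import numpy as np

from causal_validation.data import Dataset
from causal_validation.transforms.base import AdditiveCovariateTransform


@dataclass(kw_only=True)
class CovariateShift(AdditiveCovariateTransform):
    """Hypothetical transform: add a constant offset to every covariate."""

    shift: float = 1.0

    def get_values(self, data: Dataset) -> np.ndarray:
        # One value per (timepoint, unit, covariate); the treated unit sits in
        # column 0, matching apply_values' [:, :1, :] vs. [:, 1:, :] slicing.
        return np.full(
            (data.n_timepoints, data.n_units + 1, data.n_covariates), self.shift
        )
```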
diff --git a/src/causal_validation/transforms/noise.py b/src/causal_validation/transforms/noise.py
index 1116471..e251225 100644
--- a/src/causal_validation/transforms/noise.py
+++ b/src/causal_validation/transforms/noise.py
@@ -6,12 +6,18 @@
 from scipy.stats import norm
 
 from causal_validation.data import Dataset
-from causal_validation.transforms.base import AdditiveTransform
-from causal_validation.transforms.parameter import TimeVaryingParameter
+from causal_validation.transforms.base import (
+    AdditiveOutputTransform,
+    AdditiveCovariateTransform,
+)
+from causal_validation.transforms.parameter import (
+    TimeVaryingParameter,
+    CovariateNoiseParameter,
+)
 
 
 @dataclass(kw_only=True)
-class Noise(AdditiveTransform):
+class Noise(AdditiveOutputTransform):
     """
     Transform the treatment by adding TimeVaryingParameter noise terms sampled from
     a specified sampling distribution. By default, the sampling distribution is
@@ -30,3 +36,25 @@ def get_values(self, data: Dataset) -> Float[np.ndarray, "N D"]:
         ).reshape(-1)
         noise[:, 0] = noise_treatment
         return noise
+
+
+@dataclass(kw_only=True)
+class CovariateNoise(AdditiveCovariateTransform):
+    """
+    Transform the covariates by adding CovariateNoiseParameter noise terms sampled
+    from a specified sampling distribution. By default, the sampling distribution
+    is Normal with 0 loc and 0.1 scale.
+    """
+
+    noise_dist: CovariateNoiseParameter = field(
+        default_factory=lambda: CovariateNoiseParameter(sampling_dist=norm(0, 0.1))
+    )
+    _slots: Tuple[str] = ("noise_dist",)
+
+    def get_values(self, data: Dataset) -> Float[np.ndarray, "N D C"]:
+        noise = self.noise_dist.get_value(
+            n_units=data.n_units + 1,
+            n_timepoints=data.n_timepoints,
+            n_covariates=data.n_covariates,
+        )
+        return noise
diff --git a/src/causal_validation/transforms/parameter.py b/src/causal_validation/transforms/parameter.py
index 5ff9362..5806b80 100644
--- a/src/causal_validation/transforms/parameter.py
+++ b/src/causal_validation/transforms/parameter.py
@@ -51,6 +51,18 @@ def get_value(
         return np.tile(time_param, reps=n_units)
 
 
+@dataclass
+class CovariateNoiseParameter(RandomParameter):
+    def get_value(
+        self, n_units: int, n_timepoints: int, n_covariates: int
+    ) -> Float[np.ndarray, "{n_timepoints} {n_units} {n_covariates}"]:
+        covariate_noise = self.sampling_dist.rvs(
+            size=(n_timepoints, n_units, n_covariates),
+            random_state=self.random_state,
+        )
+        return covariate_noise
+
+
 ParameterOrFloat = tp.Union[Parameter, float]
diff --git a/src/causal_validation/transforms/periodic.py b/src/causal_validation/transforms/periodic.py
index d9a1376..9283cb0 100644
--- a/src/causal_validation/transforms/periodic.py
+++ b/src/causal_validation/transforms/periodic.py
@@ -5,12 +5,12 @@
 import numpy as np
 
 from causal_validation.data import Dataset
-from causal_validation.transforms.base import AdditiveTransform
+from causal_validation.transforms.base import AdditiveOutputTransform
 from causal_validation.transforms.parameter import ParameterOrFloat
 
 
 @dataclass(kw_only=True)
-class Periodic(AdditiveTransform):
+class Periodic(AdditiveOutputTransform):
     amplitude: ParameterOrFloat = 1.0
     frequency: ParameterOrFloat = 1.0
     shift: ParameterOrFloat = 0.0
diff --git a/src/causal_validation/transforms/trends.py b/src/causal_validation/transforms/trends.py
index 56018d5..6143f83 100644
--- a/src/causal_validation/transforms/trends.py
+++ b/src/causal_validation/transforms/trends.py
@@ -5,12 +5,12 @@
 import numpy as np
 
 from causal_validation.data import Dataset
-from causal_validation.transforms.base import AdditiveTransform
+from causal_validation.transforms.base import AdditiveOutputTransform
 from causal_validation.transforms.parameter import ParameterOrFloat
 
 
 @dataclass(kw_only=True)
-class Trend(AdditiveTransform):
+class Trend(AdditiveOutputTransform):
     degree: int = 1
     coefficient: ParameterOrFloat = 1.0
     intercept: ParameterOrFloat = 0.0
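For reference, the shape contract of the new parameter: one independent draw per (timepoint, unit, covariate), with the treated unit occupying column 0. A small sketch, assuming `CovariateNoiseParameter` inherits `sampling_dist` and `random_state` from `RandomParameter` as the existing parameters do (the concrete sizes are illustrative):

```python
from scipy.stats import norm

from causal_validation.transforms.parameter import CovariateNoiseParameter

param = CovariateNoiseParameter(sampling_dist=norm(0, 0.1))
# 10 timepoints, 4 units (1 treated + 3 controls), 2 covariates.
noise = param.get_value(n_units=4, n_timepoints=10, n_covariates=2)
assert noise.shape == (10, 4, 2)
```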
diff --git a/tests/test_causal_validation/test_data.py b/tests/test_causal_validation/test_data.py
index 1b7705c..28b4638 100644
--- a/tests/test_causal_validation/test_data.py
+++ b/tests/test_causal_validation/test_data.py
@@ -29,6 +29,7 @@
 MAX_STRING_LENGTH = 20
 DEFAULT_SEED = 123
 NUM_NON_CONTROL_COLS = 2
+NUM_TREATED = 1
 LARGE_N_POST = 5000
 LARGE_N_PRE = 5000
 
@@ -109,7 +110,7 @@ def test_indicator(n_pre_treatment: int, n_post_treatment: int):
     n_pre_treatment=st.integers(min_value=1, max_value=50),
     n_post_treatment=st.integers(min_value=1, max_value=50),
 )
-def test_to_df(n_control: int, n_pre_treatment: int, n_post_treatment: int):
+def test_to_df_no_cov(n_control: int, n_pre_treatment: int, n_post_treatment: int):
     constants = TestConstants(
         N_POST_TREATMENT=n_post_treatment,
         N_PRE_TREATMENT=n_pre_treatment,
@@ -117,7 +118,7 @@ def test_to_df(n_control: int, n_pre_treatment: int, n_post_treatment: int):
     )
     data = simulate_data(0.0, DEFAULT_SEED, constants=constants)
 
-    df = data.to_df()
+    df, _ = data.to_df()
     assert isinstance(df, pd.DataFrame)
     assert df.shape == (
         n_pre_treatment + n_post_treatment,
@@ -133,6 +134,47 @@ def test_to_df(n_control: int, n_pre_treatment: int, n_post_treatment: int):
     assert isinstance(index, DatetimeIndex)
     assert index[0].strftime("%Y-%m-%d") == data._start_date.strftime("%Y-%m-%d")
 
+
+@given(
+    n_control=st.integers(min_value=1, max_value=50),
+    n_pre_treatment=st.integers(min_value=1, max_value=50),
+    n_post_treatment=st.integers(min_value=1, max_value=50),
+    n_covariates=st.integers(min_value=1, max_value=50),
+)
+def test_to_df_with_cov(n_control: int,
+                        n_pre_treatment: int,
+                        n_post_treatment: int,
+                        n_covariates: int):
+    constants = TestConstants(
+        N_POST_TREATMENT=n_post_treatment,
+        N_PRE_TREATMENT=n_pre_treatment,
+        N_CONTROL=n_control,
+        N_COVARIATES=n_covariates,
+    )
+    data = simulate_data(0.0, DEFAULT_SEED, constants=constants)
+
+    df_outs, df_covs = data.to_df()
+    assert isinstance(df_outs, pd.DataFrame)
+    assert df_outs.shape == (
+        n_pre_treatment + n_post_treatment,
+        n_control + NUM_NON_CONTROL_COLS,
+    )
+
+    assert isinstance(df_covs, pd.DataFrame)
+    assert df_covs.shape == (
+        n_pre_treatment + n_post_treatment,
+        n_covariates * (n_control + NUM_TREATED)
+        + NUM_NON_CONTROL_COLS - NUM_TREATED,
+    )
+
+    colnames = data._get_columns()
+    assert isinstance(colnames, list)
+    assert colnames[0] == "T"
+    assert len(colnames) == n_control + 1
+
+    index = data.full_index
+    assert isinstance(index, DatetimeIndex)
+    assert index[0].strftime("%Y-%m-%d") == data._start_date.strftime("%Y-%m-%d")
 
 @given(
     n_control=st.integers(min_value=2, max_value=50),
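A note on the expected covariate-frame width in `test_to_df_with_cov`: the frame holds one column per (unit, covariate) pair plus the single `treated` indicator, so `n_covariates * (n_control + NUM_TREATED) + NUM_NON_CONTROL_COLS - NUM_TREATED` simplifies to `n_covariates * (n_control + 1) + 1`. For example, with `n_control=3` and `n_covariates=2` that is `2 * 4 + 2 - 1 = 9` columns: eight covariate columns and `treated`.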
diff --git a/tests/test_causal_validation/test_transforms/test_noise.py b/tests/test_causal_validation/test_transforms/test_noise.py
index 94608c3..8ffef8c 100644
--- a/tests/test_causal_validation/test_transforms/test_noise.py
+++ b/tests/test_causal_validation/test_transforms/test_noise.py
@@ -11,10 +11,14 @@
     simulate_data,
 )
 from causal_validation.transforms import (
+    CovariateNoise,
     Noise,
     Trend,
 )
-from causal_validation.transforms.parameter import TimeVaryingParameter
+from causal_validation.transforms.parameter import (
+    CovariateNoiseParameter,
+    TimeVaryingParameter,
+)
 
 CONSTANTS = TestConstants()
 DEFAULT_SEED = 123
@@ -125,3 +129,130 @@ def test_perturbation_impact(
     assert np.min(diff_te_list[0]) > np.min(diff_te_list[1])
     assert np.max(diff_te_list[0]) < np.max(diff_te_list[2])
     assert np.min(diff_te_list[0]) < np.min(diff_te_list[2])
+
+# Covariate Noise Tests
+def test_cov_slot_type():
+    noise_transform = CovariateNoise()
+    assert isinstance(noise_transform.noise_dist, CovariateNoiseParameter)
+
+
+@given(n_covariates=st.integers(min_value=1, max_value=50))
+@settings(max_examples=5)
+def test_output_covariate_transform(n_covariates: int):
+    CONSTANTS2 = TestConstants(N_COVARIATES=n_covariates)
+    base_data = simulate_data(GLOBAL_MEAN, DEFAULT_SEED, CONSTANTS2)
+
+    covariate_noise_transform = CovariateNoise()
+    noisy_data = covariate_noise_transform(base_data)
+
+    assert np.all(noisy_data.ytr == base_data.ytr)
+    assert np.all(noisy_data.yte == base_data.yte)
+    assert np.all(noisy_data.Xtr == base_data.Xtr)
+    assert np.all(noisy_data.Xte == base_data.Xte)
+
+    diff_Ptr = (noisy_data.Ptr - base_data.Ptr).reshape(-1)
+    diff_Pte = (noisy_data.Pte - base_data.Pte).reshape(-1)
+
+    assert np.all(diff_Ptr != diff_Pte)
+
+    diff_Rtr = (noisy_data.Rtr - base_data.Rtr).reshape(-1)
+    diff_Rte = (noisy_data.Rte - base_data.Rte).reshape(-1)
+
+    assert np.all(diff_Rtr != diff_Rte)
+
+    diff_Ptr_permute = np.random.permutation(diff_Ptr)
+    diff_Pte_permute = np.random.permutation(diff_Pte)
+    diff_Rtr_permute = np.random.permutation(diff_Rtr)
+    diff_Rte_permute = np.random.permutation(diff_Rte)
+
+    assert not np.all(diff_Ptr == diff_Ptr_permute)
+    assert not np.all(diff_Pte == diff_Pte_permute)
+    assert not np.all(diff_Rtr == diff_Rtr_permute)
+    assert not np.all(diff_Rte == diff_Rte_permute)
+
+
+@given(n_covariates=st.integers(min_value=1, max_value=50))
+@settings(max_examples=5)
+def test_cov_composite_transform(n_covariates: int):
+    CONSTANTS2 = TestConstants(N_COVARIATES=n_covariates)
+    base_data = simulate_data(GLOBAL_MEAN, DEFAULT_SEED, CONSTANTS2)
+
+    covariate_noise_transform = CovariateNoise()
+    cov_noisy_data = covariate_noise_transform(base_data)
+
+    noise_transform = Noise()
+    noisy_data = noise_transform(cov_noisy_data)
+
+    assert np.all(noisy_data.Xtr == cov_noisy_data.Xtr)
+    assert np.all(noisy_data.Xte == cov_noisy_data.Xte)
+    assert np.all(noisy_data.Ptr == cov_noisy_data.Ptr)
+    assert np.all(noisy_data.Pte == cov_noisy_data.Pte)
+    assert np.all(noisy_data.Rtr == cov_noisy_data.Rtr)
+    assert np.all(noisy_data.Rte == cov_noisy_data.Rte)
+    assert np.all(noisy_data.ytr != cov_noisy_data.ytr)
+    assert np.all(noisy_data.yte != cov_noisy_data.yte)
+
+
+@given(
+    loc_large=st.floats(min_value=10.0, max_value=15.0),
+    loc_small=st.floats(min_value=-2.5, max_value=2.5),
+    scale_large=st.floats(min_value=10.0, max_value=15.0),
+    scale_small=st.floats(min_value=0.1, max_value=1.0),
+    n_covariates=st.integers(min_value=1, max_value=50),
+)
+@settings(max_examples=5)
+def test_cov_perturbation_impact(
+    loc_large: float,
+    loc_small: float,
+    scale_large: float,
+    scale_small: float,
+    n_covariates: int,
+):
+    CONSTANTS2 = TestConstants(N_COVARIATES=n_covariates)
+    base_data = simulate_data(GLOBAL_MEAN, DEFAULT_SEED, CONSTANTS2)
+
+    noise_transform1 = CovariateNoise(
+        noise_dist=CovariateNoiseParameter(sampling_dist=norm(loc_small, scale_small))
+    )
+    noise_transform2 = CovariateNoise(
+        noise_dist=CovariateNoiseParameter(sampling_dist=norm(loc_small, scale_large))
+    )
+    noise_transform3 = CovariateNoise(
+        noise_dist=CovariateNoiseParameter(sampling_dist=norm(loc_large, scale_small))
+    )
+
+    noise_transforms = [noise_transform1, noise_transform2, noise_transform3]
+
+    diff_Rtr_list, diff_Rte_list = [], []
+    diff_Ptr_list, diff_Pte_list = [], []
+
+    for noise_transform in noise_transforms:
+        noisy_data = noise_transform(base_data)
+        diff_Rtr = noisy_data.Rtr - base_data.Rtr
+        diff_Rte = noisy_data.Rte - base_data.Rte
+        diff_Ptr = noisy_data.Ptr - base_data.Ptr
+        diff_Pte = noisy_data.Pte - base_data.Pte
+        diff_Rtr_list.append(diff_Rtr)
+        diff_Rte_list.append(diff_Rte)
+        diff_Ptr_list.append(diff_Ptr)
+        diff_Pte_list.append(diff_Pte)
+
+    assert np.max(diff_Rtr_list[0]) < np.max(diff_Rtr_list[1])
+    assert np.min(diff_Rtr_list[0]) > np.min(diff_Rtr_list[1])
+    assert np.max(diff_Rtr_list[0]) < np.max(diff_Rtr_list[2])
+    assert np.min(diff_Rtr_list[0]) < np.min(diff_Rtr_list[2])
+
+    assert np.max(diff_Rte_list[0]) < np.max(diff_Rte_list[1])
+    assert np.min(diff_Rte_list[0]) > np.min(diff_Rte_list[1])
+    assert np.max(diff_Rte_list[0]) < np.max(diff_Rte_list[2])
+    assert np.min(diff_Rte_list[0]) < np.min(diff_Rte_list[2])
+
+    assert np.max(diff_Ptr_list[0]) < np.max(diff_Ptr_list[1])
+    assert np.min(diff_Ptr_list[0]) > np.min(diff_Ptr_list[1])
+    assert np.max(diff_Ptr_list[0]) < np.max(diff_Ptr_list[2])
+    assert np.min(diff_Ptr_list[0]) < np.min(diff_Ptr_list[2])
+
+    assert np.max(diff_Pte_list[0]) < np.max(diff_Pte_list[1])
+    assert np.min(diff_Pte_list[0]) > np.min(diff_Pte_list[1])
+    assert np.max(diff_Pte_list[0]) < np.max(diff_Pte_list[2])
+    assert np.min(diff_Pte_list[0]) < np.min(diff_Pte_list[2])
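Finally, a composition sketch mirroring what `test_cov_composite_transform` above exercises: `CovariateNoise` perturbs only the covariate slots (`Ptr`/`Pte`/`Rtr`/`Rte`) while `Noise` perturbs only the treated outputs, so the two chain cleanly. The mean and seed values here are illustrative, not taken from the patch:

```python
from causal_validation.testing import TestConstants, simulate_data
from causal_validation.transforms import CovariateNoise, Noise

base_data = simulate_data(1.0, 123, TestConstants(N_COVARIATES=3))

# Transforms compose by chained calls; each call returns a new Dataset.
noisy_data = Noise()(CovariateNoise()(base_data))

assert (noisy_data.Xtr == base_data.Xtr).all()      # control outputs untouched
assert not (noisy_data.Ptr == base_data.Ptr).all()  # covariates perturbed
assert not (noisy_data.ytr == base_data.ytr).all()  # treated outputs perturbed
```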