diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 0000000..4996f57 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,11 @@ +## Checklist + +- [ ] I've formatted the new code by running `hatch run dev:format` before committing. +- [ ] I've added tests for new code. +- [ ] I've added docstrings for the new code. + +## Description + +Please describe your changes here. If this fixes a bug, please link to the issue, if possible. + +Issue Number: N/A \ No newline at end of file diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml new file mode 100644 index 0000000..2539eeb --- /dev/null +++ b/.github/workflows/ruff.yml @@ -0,0 +1,12 @@ +name: Check linting +on: + pull_request: + push: + branches: + - main +jobs: + ruff: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3.5.2 + - uses: chartboost/ruff-action@v1 \ No newline at end of file diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..a15ee52 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,34 @@ +name: Run Tests +on: + pull_request: + push: + branches: + - main + +jobs: + unit-tests: + name: Run Tests + runs-on: ubuntu-latest + strategy: + matrix: + # Select the Python versions to test against + os: ["ubuntu-latest", "macos-latest"] + python-version: ["3.10", "3.11"] + fail-fast: true + steps: + - name: Check out the code + uses: actions/checkout@v3.5.2 + with: + fetch-depth: 1 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + # Install Hatch + - name: Install Hatch + uses: pypa/hatch@install + + # Run the unit tests and build the coverage report + - name: Run Tests + run: hatch run dev:test \ No newline at end of file diff --git a/README.md b/README.md index 847260c..c6ff01a 100644 --- a/README.md +++ b/README.md @@ -1,17 +1,64 @@ -## My Project +# SyntheticCausalDataGen -TODO: Fill this README out! +This package provides functionality to define your own causal data generation process and then simulate data from the process. Within the package, there is functionality to include complex components to your process, such as periodic and temporal trends, and all of these operations are fully composable with one another. -Be sure to: +A short example is given below +```python +from causal_validation import Config, simulate +from causal_validation.effects import StaticEffect +from causal_validation.plotters import plot +from causal_validation.transforms import Trend, Periodic +from causal_validation.transforms.parameter import UnitVaryingParameter +from scipy.stats import norm -* Change the title in this README -* Edit your repository description on GitHub +cfg = Config( + n_control_units=10, + n_pre_intervention_timepoints=60, + n_post_intervention_timepoints=30, +) -## Security +# Simulate the base observation +base_data = simulate(cfg) -See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information. +# Apply a linear trend with unit-varying intercept +intercept = UnitVaryingParameter(sampling_dist = norm(0, 1)) +trend_component = Trend(degree=1, coefficient=0.1, intercept=intercept) +trended_data = trend_component(base_data) -## License +# Simulate a 5% lift in the treated unit's post-intervention data +effect = StaticEffect(0.05) +inflated_data = effect(trended_data) -This project is licensed under the Apache-2.0 License. +# Plot your data +plot(inflated_data) +``` + +## Examples + +To supplement the above example, we have two more detailed notebooks which exhaustively present and explain the functionalty in this package, along with how the generated data may be integrated with [AZCausal](https://github.com/amazon-science/azcausal). +1. [Basic notebook](): We here show the full range of available functions for data generation +2. [AZCausal notebook](): We here show how the generated data may be used within an AZCausal model. + +## Installation + +In this section we guide the user through the installation of this package. We distinguish here between _users_ of the package who seek to define their own data generating processes, and _developers_ who wish to extend the existing functionality of the package. + +### Prerequisites + +- Python 3.10 or higher +- [Poetry](https://python-poetry.org/) (optional, but recommended) + +### For Users + +1. It's strongly recommended to use a virtual environment. Create and activate one using your preferred method before proceeding with the installation. +2. Clone the package `git clone git@github.com:amazon-science/causal-validation.git` +3. Enter the package's root directory `cd SyntheticCausalDataGen` +4. Install the package `pip install -e .` + +### For Developers + +1. Follow steps 1-3 from `For Users` +2. Create a hatch environment `hatch env create` +3. Open a hatch shell `hatch shell` +4. Validate your installation by running `hatch run tests:test` diff --git a/examples/azcausal.pct.py b/examples/azcausal.pct.py new file mode 100644 index 0000000..1ac2bbb --- /dev/null +++ b/examples/azcausal.pct.py @@ -0,0 +1,113 @@ +# %% +from azcausal.estimators.panel.sdid import SDID +import scipy.stats as st + +from causal_validation import ( + Config, + simulate, +) +from causal_validation.effects import StaticEffect +from causal_validation.plotters import plot +from causal_validation.transforms import ( + Periodic, + Trend, +) +from causal_validation.transforms.parameter import UnitVaryingParameter + +# %% [markdown] +# ## AZCausal Integration +# +# Amazon's [AZCausal](https://github.com/amazon-science/azcausal) library provides the +# functionality to fit synthetic control and difference-in-difference models to your +# data. Integrating the synthetic data generating process of `causal_validation` with +# AZCausal is trivial, as we show in this notebook. To start, we'll simulate a toy +# dataset. + +# %% +cfg = Config( + n_control_units=10, + n_pre_intervention_timepoints=60, + n_post_intervention_timepoints=30, + seed=123, +) + +linear_trend = Trend(degree=1, coefficient=0.05) +data = linear_trend(simulate(cfg)) +plot(data) + +# %% We'll now simulate a 5% lift in the treatment group's observations. This [markdown] +# will inflate the treated group's observations in the post-intervention window. + +# %% +TRUE_EFFECT = 0.05 +effect = StaticEffect(effect=TRUE_EFFECT) +inflated_data = effect(data) +plot(inflated_data) + +# %% [markdown] +# ### Fitting a model +# +# We now have some very toy data on which we may apply a model. For this demonstration +# we shall use the Synthetic Difference-in-Differences model implemented in AZCausal; +# however, the approach shown here will work for any model implemented in AZCausal. To +# achieve this, we must first coerce the data into a format that is digestible for +# AZCausal. Through the `.to_azcausal()` method implemented here, this is +# straightforward to achieve. Once we have a AZCausal compatible dataset, the modelling +# is very simple by virtue of the clean design of AZCausal. + +# %% +panel = inflated_data.to_azcausal() +model = SDID() +result = model.fit(panel) +print(f"Delta: {TRUE_EFFECT - result.effect.percentage().value / 100}") +print(result.summary(title="Synthetic Data Experiment")) + +# %% We see that SDID has done an excellent job of estimating the treatment [markdown] +# effect. However, given the simplicity of the data, this is not surprising. With the +# functionality within this package though we can easily construct more complex datasets +# in effort to fully stress-test any new model and identify its limitations. +# +# To achieve this, we'll simulate 10 control units, 60 pre-intervention time points, and +# 30 post-intervention time points according to the following process: $$ \begin{align} +# \mu_{n, t} & \sim\mathcal{N}(20, 0.5^2)\\ +# \alpha_{n} & \sim \mathcal{N}(0, 1^2)\\ +# \beta_{n} & \sim \mathcal{N}(0.05, 0.01^2)\\ +# \nu_n & \sim \mathcal{N}(1, 1^2)\\ +# \gamma_n & \sim \operatorname{Student-t}_{10}(1, 1^2)\\ +# \mathbf{Y}_{n, t} & = \mu_{n, t} + \alpha_{n} + \beta_{n}t + \nu_n\sin\left(3\times +# 2\pi t + \gamma\right) + \delta_{t, n} \end{align} $$ where the true treatment effect +# $\delta_{t, n}$ is 5% when $n=1$ and $t\geq 60$ and 0 otherwise. Meanwhile, +# $\mathbf{Y}$ is the matrix of observations, long in the number of time points and wide +# in the number of units. + +# %% +cfg = Config( + n_control_units=10, + n_pre_intervention_timepoints=60, + n_post_intervention_timepoints=30, + global_mean=20, + global_scale=1, + seed=123, +) + +intercept = UnitVaryingParameter(sampling_dist=st.norm(loc=0.0, scale=1)) +coefficient = UnitVaryingParameter(sampling_dist=st.norm(loc=0.05, scale=0.01)) +linear_trend = Trend(degree=1, coefficient=coefficient, intercept=intercept) + +amplitude = UnitVaryingParameter(sampling_dist=st.norm(loc=1.0, scale=2)) +shift = UnitVaryingParameter(sampling_dist=st.t(df=10)) +periodic = Periodic(amplitude=amplitude, shift=shift, frequency=3) + +data = effect(periodic(linear_trend(simulate(cfg)))) +plot(data) + +# %% As before, we may now go about estimating the treatment. However, this [markdown] +# time we see that the delta between the estaimted and true effect is much larger than +# before. + +# %% +panel = data.to_azcausal() +model = SDID() +result = model.fit(panel) +print(f"Delta: {100*(TRUE_EFFECT - result.effect.percentage().value / 100): .2f}%") +print(result.summary(title="Synthetic Data Experiment")) diff --git a/examples/basic.pct.py b/examples/basic.pct.py new file mode 100644 index 0000000..60c5217 --- /dev/null +++ b/examples/basic.pct.py @@ -0,0 +1,169 @@ +# %% +from itertools import product + +import matplotlib.pyplot as plt +import numpy as np +from scipy.stats import ( + norm, + poisson, +) + +from causal_validation import ( + Config, + simulate, +) +from causal_validation.effects import StaticEffect +from causal_validation.plotters import plot +from causal_validation.transforms import ( + Periodic, + Trend, +) +from causal_validation.transforms.parameter import UnitVaryingParameter + +# %% [markdown] +# ## Simulating a Dataset + +# %% Simulating a dataset is as simple as specifying a `Config` object and [markdown] +# then invoking the `simulate` function. Once simulated, we may visualise the data +# through the `plot` function. + +# %% +cfg = Config( + n_control_units=10, + n_pre_intervention_timepoints=60, + n_post_intervention_timepoints=30, + seed=123, +) + +data = simulate(cfg) +plot(data) + +# %% [markdown] +# ### Controlling baseline behaviour +# +# We observe that we have 10 control units, each of which were sampled from a Gaussian +# distribution with mean 20 and scale 0.2. Had we wished for our underlying observations +# to have more or less noise, or to have a different global mean, then we can simply +# specify that through the config file. + +# %% +means = [10, 50] +scales = [0.1, 0.5] + +fig, axes = plt.subplots(ncols=2, nrows=2, figsize=(10, 6), tight_layout=True) +for (m, s), ax in zip(product(means, scales), axes.ravel()): + cfg = Config( + n_control_units=10, + n_pre_intervention_timepoints=60, + n_post_intervention_timepoints=30, + global_mean=m, + global_scale=s, + ) + data = simulate(cfg) + plot(data, ax=ax, title=f"Mean: {m}, Scale: {s}") + +# %% [markdown] +# ### Reproducibility +# +# In the above four panels, we can see that whilst the mean and scale of the underlying +# data generating process is varying, the functional form of the data is the same. This +# is by design to ensure that data sampling is reproducible. To sample a new dataset, +# you may either change the underlying seed in the config file. + +# %% +cfg = Config( + n_control_units=10, + n_pre_intervention_timepoints=60, + n_post_intervention_timepoints=30, + seed=42, +) + +# %% [markdown] +# Reusing the same config file across simulations + +# %% +fig, axes = plt.subplots(ncols=2, figsize=(10, 3)) +for ax in axes: + data = simulate(cfg) + plot(data, ax=ax) + +# %% [markdown] +# Or manually specifying and passing your own pseudorandom number generator key + +# %% + +rng = np.random.RandomState(42) + +fig, axes = plt.subplots(ncols=2, figsize=(10, 3)) +for ax in axes: + data = simulate(cfg, key=rng) + plot(data, ax=ax) + +# %% [markdown] +# ### Simulating an effect +# +# In the data we have seen up until now, the treated unit has been drawn from the same +# data generating process as the control units. However, it can be helpful to also +# inflate the treated unit to observe how well our model can recover the the true +# treatment effect. To do this, we simply compose our dataset with an `Effect` object. +# In the below, we shall inflate our data by 2%. + +# %% +effect = StaticEffect(effect=0.02) +inflated_data = effect(data) +fig, (ax0, ax1) = plt.subplots(ncols=2, figsize=(10, 3)) +plot(data, ax=ax0, title="Original data") +plot(inflated_data, ax=ax1, title="Inflated data") + +# %% [markdown] +# ### More complex generation processes +# +# The example presented above shows a very simple stationary data generation process. +# However, we may make our example more complex by including a non-stationary trend to +# the data. + +# %% +trend_term = Trend(degree=1, coefficient=0.1) +data_with_trend = effect(trend_term(data)) +plot(data_with_trend) + +# %% +trend_term = Trend(degree=2, coefficient=0.0025) +data_with_trend = effect(trend_term(data)) +plot(data_with_trend) + +# %% [markdown] +# We may also include periodic components in our data + +# %% +periodicity = Periodic(amplitude=2, frequency=6) +perioidic_data = effect(periodicity(trend_term(data))) +plot(perioidic_data) + +# %% [markdown] +# ### Unit-level parameterisation + +# %% +sampling_dist = norm(0.0, 1.0) +intercept = UnitVaryingParameter(sampling_dist=sampling_dist) +trend_term = Trend(degree=1, intercept=intercept, coefficient=0.1) +data_with_trend = effect(trend_term(data)) +plot(data_with_trend) + +# %% +sampling_dist = poisson(2) +frequency = UnitVaryingParameter(sampling_dist=sampling_dist) + +p = Periodic(frequency=frequency) +plot(p(data)) + +# %% [markdown] +# ## Conclusions +# +# In this notebook we have shown how one can define their model's true underlying data +# generating process, starting from simple white-noise samples through to more complex +# example with periodic and temporal components, perhaps containing unit-level +# variation. In a follow-up notebook, we show how these datasets may be integrated with +# Amazon's own AZCausal library to compare the effect estimated by a model with the true +# effect of the underlying data generating process. A link to this notebook is +# [here](PLACEHOLDER). diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..4ea4544 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,154 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "causal_validation" +dynamic = ["version"] +description = 'A validation framework for causal models.' +readme = "README.md" +requires-python = ">=3.10,<4.0" +license = "MIT" +keywords = [ + "synthetic data", "causal model", "machine learning" +] +authors = [ + { name = "Thomas Pinder", email = "pinthoma@amazon.nl" }, +] +classifiers = [ + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dependencies = [ + "azcausal", + "beartype", + "jaxtyping", + "matplotlib", + "numpy", + "pandas", +] + +[tool.hatch.build] +include = ["src/causal_validation"] +packages = ["src/causal_validation"] + +[tool.hatch.envs.default] +installer = "uv" +python = "3.10" + +[tool.hatch.envs.dev] +dependencies = [ + "mypy", + "black", + "isort", + "pytest", + "pytest-xdist", + "pytest-cov", + "pytest-sugar", + "coverage", + "autoflake", + "ruff", + "hypothesis", + "pre-commit", + "absolufy-imports", + "ipykernel", + "ipython", + "jupytext", + ] + +[tool.hatch.envs.dev.scripts] +test = "pytest --hypothesis-profile causal_validation" +ptest = "pytest -n auto . --hypothesis-profile causal_validation" +black-format = ["black src tests", "jupytext --pipe black examples/*.py"] +imports-format = [ + "isort src tests", + "isort examples/*.py --treat-comment-as-code '# %%' --float-to-top", +] +lint-format = ['ruff format src tests examples'] +format = ["black-format", "imports-format", "lint-format"] +build_nbs = [ + "jupytext --to notebook examples/*.pct.py", + "mv examples/*.ipynb nbs" +] + +[tool.hatch.version] +path = "src/causal_validation/__about__.py" + +[tool.coverage.run] +source_pkgs = ["causal_validation", "tests"] +branch = true +parallel = true +omit = [ + "src/causal_validation/__about__.py", +] + +[tool.coverage.paths] +causal_validation = ["src/causal_validation", "*/causal_validation/src/causal_validation"] +tests = ["tests", "*/causal_validation/tests"] + +[tool.coverage.report] +exclude_lines = [ + "no cov", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", +] + +[tool.black] +line-length = 88 +target-version = ["py310"] + + +[tool.isort] +profile = "black" +line_length = 88 +known_first_party = [ "causal_validation" ] +combine_as_imports = true +force_sort_within_sections = true +force_grid_wrap = 2 + +[tool.pytest.ini_options] +addopts = [ + "--durations=5", + "--color=yes", + "--cov=causal_validation" +] +testpaths = [ "test" ] +looponfailroots = [ + "src", + "test", +] + +[tool.ruff] +fix = true +cache-dir = "~/.cache/ruff" +line-length = 88 +src = ["src", "test"] +target-version = "py310" + +[tool.ruff.lint] +dummy-variable-rgx = "^_$" +select = [ + "F", + "E", + "W", + "YTT", + "B", + "Q", + "PLE", + "PLR", + "PLW", + "PIE", + "PYI", + "TID", + "ISC", +] +ignore = ["F722"] + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" diff --git a/src/causal_validation/__about__.py b/src/causal_validation/__about__.py new file mode 100644 index 0000000..8d885d2 --- /dev/null +++ b/src/causal_validation/__about__.py @@ -0,0 +1,3 @@ +__version__ = "0.0.1" + +__all__ = ["__version__"] diff --git a/src/causal_validation/__init__.py b/src/causal_validation/__init__.py new file mode 100644 index 0000000..9ae3c8c --- /dev/null +++ b/src/causal_validation/__init__.py @@ -0,0 +1,4 @@ +from causal_validation.config import Config +from causal_validation.simulate import simulate + +__all__ = ["Config", "simulate"] diff --git a/src/causal_validation/base.py b/src/causal_validation/base.py new file mode 100644 index 0000000..68031dd --- /dev/null +++ b/src/causal_validation/base.py @@ -0,0 +1,6 @@ +from dataclasses import dataclass + + +@dataclass +class BaseObject: + name: str = "Abstract Object" diff --git a/src/causal_validation/config.py b/src/causal_validation/config.py new file mode 100644 index 0000000..901b0f2 --- /dev/null +++ b/src/causal_validation/config.py @@ -0,0 +1,33 @@ +from dataclasses import ( + dataclass, + field, +) +import datetime as dt +import typing as tp + +import numpy as np + +from causal_validation.weights import UniformWeights + +if tp.TYPE_CHECKING: + from causal_validation.types import WeightTypes + + +@dataclass(kw_only=True, frozen=True) +class WeightConfig: + weight_type: "WeightTypes" = field(default_factory=UniformWeights) + + +@dataclass(kw_only=True) +class Config: + n_control_units: int + n_pre_intervention_timepoints: int + n_post_intervention_timepoints: int + global_mean: float = 20.0 + global_scale: float = 0.2 + start_date: dt.date = dt.date(year=2023, month=1, day=1) + seed: int = 123 + weights_cfg: WeightConfig = field(default_factory=WeightConfig) + + def __post_init__(self): + self.rng = np.random.RandomState(self.seed) diff --git a/src/causal_validation/data.py b/src/causal_validation/data.py new file mode 100644 index 0000000..202cb63 --- /dev/null +++ b/src/causal_validation/data.py @@ -0,0 +1,130 @@ +from copy import deepcopy +from dataclasses import dataclass +import datetime as dt +import typing as tp + +from azcausal.core.panel import CausalPanel +from azcausal.util import to_panels +from jaxtyping import ( + Float, + Integer, +) +import numpy as np +import pandas as pd +from pandas._libs.tslibs.timestamps import Timestamp +from pandas.core.indexes.datetimes import DatetimeIndex + +from causal_validation.types import InterventionTypes + + +@dataclass(frozen=True) +class Dataset: + Xtr: Float[np.ndarray, "N D"] + Xte: Float[np.ndarray, "M D"] + ytr: Float[np.ndarray, "N 1"] + yte: Float[np.ndarray, "M 1"] + _start_date: dt.date + counterfactual: tp.Optional[Float[np.ndarray, "M 1"]] = None + + def to_df(self, index_start: str = "2023-01-01") -> pd.DataFrame: + inputs = np.vstack([self.Xtr, self.Xte]) + outputs = np.vstack([self.ytr, self.yte]) + data = np.hstack([outputs, inputs]) + index = self._get_index(index_start) + colnames = self._get_columns() + indicator = self._get_indicator() + df = pd.DataFrame(data, index=index, columns=colnames) + df = df.assign(treated=indicator) + return df + + @property + def n_post_intervention(self) -> int: + return self.Xte.shape[0] + + @property + def n_pre_intervention(self) -> int: + return self.Xtr.shape[0] + + @property + def n_units(self) -> int: + return self.Xtr.shape[1] + + @property + def n_timepoints(self) -> int: + return self.n_post_intervention + self.n_pre_intervention + + @property + def control_units(self) -> Float[np.ndarray, "N+M 1"]: + return np.vstack([self.Xtr, self.Xte]) + + @property + def treated_units(self) -> Float[np.ndarray, "N+M 1"]: + return np.vstack([self.ytr, self.yte]) + + @property + def pre_intervention_obs( + self, + ) -> tp.Tuple[Float[np.ndarray, "N D"], Float[np.ndarray, "N 1"]]: + return self.Xtr, self.ytr + + @property + def post_intervention_obs( + self, + ) -> tp.Tuple[Float[np.ndarray, "M D"], Float[np.ndarray, "M 1"]]: + return self.Xte, self.yte + + @property + def full_index(self) -> DatetimeIndex: + return self._get_index(self._start_date) + + @property + def treatment_date(self) -> Timestamp: + idxs = self.full_index + return idxs[self.n_pre_intervention] + + def get_index(self, period: InterventionTypes) -> DatetimeIndex: + if period == "pre-intervention": + return self.full_index[: self.n_pre_intervention] + elif period == "post-intervention": + return self.full_index[self.n_pre_intervention :] + else: + return self.full_index + + def _get_columns(self) -> tp.List[str]: + colnames = ["T"] + [f"C{i}" for i in range(self.n_units)] + return colnames + + def _get_index(self, start_date: str) -> pd.Series: + return pd.date_range(start=start_date, freq="D", periods=self.n_timepoints) + + def _get_indicator(self) -> Integer[np.ndarray, "N 1"]: + indicator = np.vstack( + [ + np.zeros(shape=(self.n_pre_intervention, 1)), + np.ones(shape=(self.n_post_intervention, 1)), + ] + ) + return indicator + + def inflate(self, inflation_vals: Float[np.ndarray, "M 1"]) -> "Dataset": + Xtr, ytr = [deepcopy(i) for i in self.pre_intervention_obs] + Xte, yte = [deepcopy(i) for i in self.post_intervention_obs] + inflated_yte = yte * inflation_vals + return Dataset(Xtr, Xte, ytr, inflated_yte, self._start_date, yte) + + def to_azcausal(self): + time_index = np.arange(self.n_timepoints) + data = self.to_df().assign(time=time_index).melt(id_vars=["time", "treated"]) + data.loc[:, "treated"] = np.where( + (data["variable"] == "T") & (data["treated"] == 1.0), 1, 0 + ) + panels = to_panels(data, "time", "variable", ["value", "treated"]) + ctypes = dict( + outcome="value", time="time", unit="variable", intervention="treated" + ) + panel = CausalPanel(panels).setup(**ctypes) + return panel + + @property + def _slots(self) -> tp.Dict[str, int]: + return {"n_units": self.n_units + 1, "n_timepoints": self.n_timepoints} diff --git a/src/causal_validation/effects.py b/src/causal_validation/effects.py new file mode 100644 index 0000000..accf3b5 --- /dev/null +++ b/src/causal_validation/effects.py @@ -0,0 +1,67 @@ +from dataclasses import dataclass +import typing as tp + +from jaxtyping import Float +import numpy as np + +from causal_validation.base import BaseObject +from causal_validation.data import Dataset + +if tp.TYPE_CHECKING: + from causal_validation.config import EffectConfig + + +@dataclass +class AbstractEffect(BaseObject): + name: str = "Abstract Effect" + + def get_effect(self, data: Dataset, **kwargs) -> Float[np.ndarray, "N 1"]: + raise NotImplementedError("Please implement `get_effect` in all subclasses.") + + def __call__(self, data: Dataset, **kwargs) -> Dataset: + inflation_vals = self.get_effect(data) + return data.inflate(inflation_vals) + + +@dataclass +class _StaticEffect: + effect: float + + +@dataclass +class _RandomEffect: + mean_effect: float + stddev_effect: float + + +@dataclass +class StaticEffect(AbstractEffect, _StaticEffect): + effect: float + name = "Static Effect" + + def get_effect(self, data: Dataset, **kwargs) -> Float[np.ndarray, "N 1"]: + n_post_intervention = data.n_post_intervention + return np.repeat(1.0 + self.effect, repeats=n_post_intervention)[:, None] + + +@dataclass +class RandomEffect(AbstractEffect, _RandomEffect): + mean_effect: float + stddev_effect: float + name: str = "Random Effect" + + def get_effect( + self, data: Dataset, key: np.random.RandomState + ) -> Float[np.ndarray, "N 1"]: + n_post_intervention = data.n_post_intervention + effect_sample = key.normal( + loc=1.0 + self.mean_effect, + scale=self.stddev_effect, + size=(n_post_intervention, 1), + ) + return effect_sample + + +# Placeholder for now. +def resolve_effect(cfg: "EffectConfig") -> AbstractEffect: + return StaticEffect(effect=cfg.effect) diff --git a/src/causal_validation/plotters.py b/src/causal_validation/plotters.py new file mode 100644 index 0000000..be5d91b --- /dev/null +++ b/src/causal_validation/plotters.py @@ -0,0 +1,49 @@ +import typing as tp + +import matplotlib as mpl +from matplotlib.axes._axes import Axes +import matplotlib.dates as mdates +import matplotlib.pyplot as plt + +from causal_validation.data import Dataset + + +def clean_legend(ax: Axes) -> Axes: + """Remove duplicate legend entries from a plot. + + Args: + ax (Axes): The matplotlib axes containing the legend to be formatted. + + Returns: + Axes: The cleaned matplotlib axes. + """ + handles, labels = ax.get_legend_handles_labels() + by_label = dict(zip(labels, handles, strict=False)) + ax.legend(by_label.values(), by_label.keys(), loc="best") + return ax + + +def plot( + data: Dataset, + ax: tp.Optional[Axes] = None, + title: tp.Optional[str] = None, +) -> Axes: + cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] + X = data.control_units + y = data.treated_units + idx = data.full_index + treatment_date = data.treatment_date + + if ax is None: + _, ax = plt.subplots(figsize=(6, 3), tight_layout=True) + ax.plot(idx, X, color=cols[0], label="Control", alpha=0.5) + ax.plot(idx, y, color=cols[1], label="Treated") + ax.axvline(x=treatment_date, color=cols[2], label="Intervention", linestyle="--") + ax.xaxis.set_major_formatter( + mdates.ConciseDateFormatter(ax.xaxis.get_major_locator()) + ) + clean_legend(ax) + ax.spines["right"].set_visible(False) + ax.spines["top"].set_visible(False) + ax.set(xlabel="Time", ylabel="Observed", title=title) + return ax diff --git a/src/causal_validation/py.typed b/src/causal_validation/py.typed new file mode 100644 index 0000000..7ef2116 --- /dev/null +++ b/src/causal_validation/py.typed @@ -0,0 +1 @@ +# Marker file that indicates this package supports typing diff --git a/src/causal_validation/simulate.py b/src/causal_validation/simulate.py new file mode 100644 index 0000000..2f02c10 --- /dev/null +++ b/src/causal_validation/simulate.py @@ -0,0 +1,37 @@ +import typing as tp + +import numpy as np + +from causal_validation.config import Config +from causal_validation.data import Dataset +from causal_validation.weights import ( + AbstractWeights, + UniformWeights, +) + + +def simulate(config: Config, key: tp.Optional[np.random.RandomState] = None) -> Dataset: + if key is None: + key = config.rng + weights = UniformWeights() + + base_data = _simulate_base_obs(config, weights, key) + return base_data + + +def _simulate_base_obs( + config: Config, weights: AbstractWeights, key: np.random.RandomState +) -> Dataset: + n_timepoints = ( + config.n_pre_intervention_timepoints + config.n_post_intervention_timepoints + ) + n_units = config.n_control_units + obs = key.normal( + loc=config.global_mean, scale=config.global_scale, size=(n_timepoints, n_units) + ) + Xtr = obs[: config.n_pre_intervention_timepoints, :] + Xte = obs[config.n_pre_intervention_timepoints :, :] + ytr = weights.weight_obs(Xtr) + yte = weights.weight_obs(Xte) + data = Dataset(Xtr, Xte, ytr, yte, _start_date=config.start_date) + return data diff --git a/src/causal_validation/testing.py b/src/causal_validation/testing.py new file mode 100644 index 0000000..a1bcb22 --- /dev/null +++ b/src/causal_validation/testing.py @@ -0,0 +1,33 @@ +from dataclasses import dataclass +import typing as tp + +from causal_validation.config import Config +from causal_validation.data import Dataset +from causal_validation.simulate import simulate + + +@dataclass(frozen=True, kw_only=True) +class TestConstants: + N_CONTROL: int = 10 + N_PRE_TREATMENT: int = 500 + N_POST_TREATMENT: int = 500 + DATA_SLOTS: tp.Tuple[str, str, str, str] = ("Xtr", "Xte", "ytr", "yte") + ZERO_DIVISION_ERROR: float = 1e-6 + GLOBAL_SCALE: float = 1.0 + __test__: bool = False + + +def simulate_data( + global_mean: float, seed: int, constants: tp.Optional[TestConstants] = None +) -> Dataset: + if not constants: + constants = TestConstants() + cfg = Config( + n_control_units=constants.N_CONTROL, + n_pre_intervention_timepoints=constants.N_PRE_TREATMENT, + n_post_intervention_timepoints=constants.N_POST_TREATMENT, + global_mean=global_mean, + global_scale=constants.GLOBAL_SCALE, + seed=seed, + ) + return simulate(config=cfg) diff --git a/src/causal_validation/transforms/__init__.py b/src/causal_validation/transforms/__init__.py new file mode 100644 index 0000000..2707bfc --- /dev/null +++ b/src/causal_validation/transforms/__init__.py @@ -0,0 +1,4 @@ +from causal_validation.transforms.periodic import Periodic +from causal_validation.transforms.trends import Trend + +__all__ = ["Trend", "Periodic"] diff --git a/src/causal_validation/transforms/base.py b/src/causal_validation/transforms/base.py new file mode 100644 index 0000000..ea15109 --- /dev/null +++ b/src/causal_validation/transforms/base.py @@ -0,0 +1,94 @@ +from copy import deepcopy +from dataclasses import dataclass +import typing as tp + +from jaxtyping import Float +import numpy as np + +from causal_validation.data import Dataset +from causal_validation.transforms.parameter import resolve_parameter + +if tp.TYPE_CHECKING: + from causal_validation.transforms.parameter import ( + Parameter, + resolve_parameter, + ) + + +@dataclass(kw_only=True) +class AbstractTransform: + _slots: tp.Optional[tp.Tuple[str]] = None + + def __post_init__(self): + if self._slots: + for slot in self._slots: + coerced_param = resolve_parameter(getattr(self, slot)) + setattr(self, slot, coerced_param) + + def __call__(self, data: Dataset) -> Dataset: + vals = self.get_values(data) + pre_intervention_trend = vals[: data.n_pre_intervention] + post_intervention_trend = vals[data.n_pre_intervention :] + return self.apply_values( + pre_intervention_trend, post_intervention_trend, data=data + ) + + def get_values(self, data: Dataset) -> Float[np.ndarray, "N D"]: + raise NotImplementedError + + def apply_values( + self, + pre_intervention_vals: np.ndarray, + post_intervention_vals: np.ndarray, + data: Dataset, + ) -> Dataset: + raise NotImplementedError + + @staticmethod + def _resolve_parameter( + data: Dataset, parameter: "Parameter" + ) -> Float[np.ndarray, "..."]: + data_params = data._slots + return parameter.get_value(**data_params) + + def _get_parameter_values(self, data: Dataset) -> tp.Dict[str, np.ndarray]: + param_vals = {} + if self._slots: + for slot in self._slots: + param = getattr(self, slot) + param_vals[slot] = self._resolve_parameter(data, param) + return param_vals + + +@dataclass(kw_only=True) +class AdditiveTransform(AbstractTransform): + def apply_values( + self, + pre_intervention_vals: np.ndarray, + post_intervention_vals: np.ndarray, + data: Dataset, + ) -> Dataset: + Xtr, ytr = [deepcopy(i) for i in data.pre_intervention_obs] + Xte, yte = [deepcopy(i) for i in data.post_intervention_obs] + Xtr = Xtr + pre_intervention_vals[:, 1:] + ytr = ytr + pre_intervention_vals[:, :1] + Xte = Xte + post_intervention_vals[:, 1:] + yte = yte + post_intervention_vals[:, :1] + return Dataset(Xtr, Xte, ytr, yte, data._start_date, data.counterfactual) + + +@dataclass(kw_only=True) +class MultiplicativeTransform(AbstractTransform): + def apply_values( + self, + pre_intervention_vals: np.ndarray, + post_intervention_vals: np.ndarray, + data: Dataset, + ) -> Dataset: + Xtr, ytr = [deepcopy(i) for i in data.pre_intervention_obs] + Xte, yte = [deepcopy(i) for i in data.post_intervention_obs] + Xtr = Xtr * pre_intervention_vals + ytr = ytr * pre_intervention_vals + Xte = Xte * post_intervention_vals + yte = yte * post_intervention_vals + return Dataset(Xtr, Xte, ytr, yte, data._start_date, data.counterfactual) diff --git a/src/causal_validation/transforms/parameter.py b/src/causal_validation/transforms/parameter.py new file mode 100644 index 0000000..5ff9362 --- /dev/null +++ b/src/causal_validation/transforms/parameter.py @@ -0,0 +1,63 @@ +from dataclasses import dataclass +import typing as tp + +from jaxtyping import Float +import numpy as np + +from causal_validation.types import RandomVariable + + +@dataclass +class Parameter: + def get_value(self, **kwargs) -> Float[np.ndarray, "..."]: + raise NotImplementedError + + +@dataclass +class FixedParameter(Parameter): + value: float + + def get_value( + self, n_units: int, n_timepoints: int + ) -> Float[np.ndarray, "{n_timepoints} {n_units}"]: + return np.ones(shape=(n_timepoints, n_units)) * self.value + + +@dataclass +class RandomParameter(Parameter): + sampling_dist: RandomVariable + random_state: int = 123 + + +@dataclass +class UnitVaryingParameter(RandomParameter): + def get_value( + self, n_units: int, n_timepoints: int + ) -> Float[np.ndarray, "{n_timepoints} {n_units}"]: + unit_param = self.sampling_dist.rvs( + size=(n_units,), random_state=self.random_state + ) + return np.stack([unit_param] * n_timepoints) + + +@dataclass +class TimeVaryingParameter(RandomParameter): + def get_value( + self, n_units: int, n_timepoints: int + ) -> Float[np.ndarray, "{n_timepoints} {n_units}"]: + time_param = self.sampling_dist.rvs( + size=(n_timepoints, 1), random_state=self.random_state + ) + return np.tile(time_param, reps=n_units) + + +ParameterOrFloat = tp.Union[Parameter, float] + + +def resolve_parameter(value: ParameterOrFloat) -> Parameter: + if isinstance(value, tp.Union[int, float]): + return FixedParameter(value=value) + elif isinstance(value, Parameter): + return value + else: + raise TypeError("`value` argument must be either a `Parameter` or `float`.") diff --git a/src/causal_validation/transforms/periodic.py b/src/causal_validation/transforms/periodic.py new file mode 100644 index 0000000..d9a1376 --- /dev/null +++ b/src/causal_validation/transforms/periodic.py @@ -0,0 +1,35 @@ +from dataclasses import dataclass +from typing import Tuple + +from jaxtyping import Float +import numpy as np + +from causal_validation.data import Dataset +from causal_validation.transforms.base import AdditiveTransform +from causal_validation.transforms.parameter import ParameterOrFloat + + +@dataclass(kw_only=True) +class Periodic(AdditiveTransform): + amplitude: ParameterOrFloat = 1.0 + frequency: ParameterOrFloat = 1.0 + shift: ParameterOrFloat = 0.0 + offset: ParameterOrFloat = 0.0 + _slots: Tuple[str, str, str, str] = ( + "amplitude", + "frequency", + "shift", + "offset", + ) + + def get_values(self, data: Dataset) -> Float[np.ndarray, "N D"]: + amplitude = self.amplitude.get_value(**data._slots) + frequency = self.frequency.get_value(**data._slots) + shift = self.shift.get_value(**data._slots) + offset = self.offset.get_value(**data._slots) + x_vals = np.tile( + np.linspace(0, 2 * np.pi, num=data.n_timepoints).reshape(-1, 1), + reps=data.n_units + 1, + ) + sine_curve = amplitude * np.sin((x_vals * np.abs(frequency)) + shift) + offset + return sine_curve diff --git a/src/causal_validation/transforms/trends.py b/src/causal_validation/transforms/trends.py new file mode 100644 index 0000000..56018d5 --- /dev/null +++ b/src/causal_validation/transforms/trends.py @@ -0,0 +1,26 @@ +from dataclasses import dataclass +from typing import Tuple + +from jaxtyping import Float +import numpy as np + +from causal_validation.data import Dataset +from causal_validation.transforms.base import AdditiveTransform +from causal_validation.transforms.parameter import ParameterOrFloat + + +@dataclass(kw_only=True) +class Trend(AdditiveTransform): + degree: int = 1 + coefficient: ParameterOrFloat = 1.0 + intercept: ParameterOrFloat = 0.0 + _slots: Tuple[str, str] = ("coefficient", "intercept") + + def get_values(self, data: Dataset) -> Float[np.ndarray, "N D"]: + coefficient = self._resolve_parameter(data, self.coefficient) + intercept = self._resolve_parameter(data, self.intercept) + trend = np.tile( + np.arange(data.n_timepoints)[:, None] ** self.degree, data.n_units + 1 + ) + scaled_trend = intercept + coefficient * trend + return scaled_trend diff --git a/src/causal_validation/types.py b/src/causal_validation/types.py new file mode 100644 index 0000000..34f5104 --- /dev/null +++ b/src/causal_validation/types.py @@ -0,0 +1,11 @@ +import typing as tp + +from scipy.stats._distn_infrastructure import ( + rv_continuous, + rv_discrete, +) + +EffectTypes = tp.Literal["fixed", "random"] +WeightTypes = tp.Literal["uniform", "non-uniform"] +InterventionTypes = tp.Literal["pre-intervention", "post-intervention", "both"] +RandomVariable = tp.Union[rv_continuous, rv_discrete] diff --git a/src/causal_validation/weights.py b/src/causal_validation/weights.py new file mode 100644 index 0000000..f108a7d --- /dev/null +++ b/src/causal_validation/weights.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from dataclasses import dataclass +import typing as tp + +from jaxtyping import Float +import numpy as np + +from causal_validation.base import BaseObject + +if tp.TYPE_CHECKING: + from causal_validation.config import WeightConfig + + +@dataclass +class AbstractWeights(BaseObject): + name: str = "Abstract Weights" + + def _get_weights(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, " D"]: + raise NotImplementedError("Please implement `_get_weights` in all subclasses.") + + def get_weights(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, " D"]: + weights = self._get_weights(obs) + + np.testing.assert_almost_equal( + weights.sum(), 1.0, decimal=1.0, err_msg="Weights must sum to 1." + ) + assert min(weights >= 0), "Weights should be non-negative" + return weights + + def __call__(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, "N 1"]: + return self.weight_obs(obs) + + def weight_obs(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, "N 1"]: + weights = self.get_weights(obs) + + weighted_obs = obs @ weights + return weighted_obs + + +@dataclass +class UniformWeights(AbstractWeights): + name: str = "Uniform Weights" + + def _get_weights(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, " D"]: + n_units = obs.shape[1] + return np.repeat(1.0 / n_units, repeats=n_units).reshape(-1, 1) + + +def resolve_weights(config: "WeightConfig") -> AbstractWeights: + if config.weight_type == "uniform": + return UniformWeights() diff --git a/static/fig_creation.py b/static/fig_creation.py new file mode 100644 index 0000000..83ff916 --- /dev/null +++ b/static/fig_creation.py @@ -0,0 +1,38 @@ +import matplotlib.pyplot as plt +from scipy.stats import norm + +from causal_validation import ( + Config, + simulate, +) +from causal_validation.effects import StaticEffect +from causal_validation.plotters import plot +from causal_validation.transforms import ( + Periodic, + Trend, +) +from causal_validation.transforms.parameter import UnitVaryingParameter + +plt.style.use("style.mplstyle") + +if __name__ == "__main__": + + cfg = Config( + n_control_units=10, + n_pre_intervention_timepoints=60, + n_post_intervention_timepoints=30, + ) + + # Simulate the base observation + base_data = simulate(cfg) + + # Apply a linear trend with unit-varying intercept + intercept = UnitVaryingParameter(sampling_dist=norm(0, 1)) + trend_component = Trend(degree=1, coefficient=0.1, intercept=intercept) + trended_data = trend_component(base_data) + + # Simulate a 5% lift in the treated unit's post-intervention data + effect = StaticEffect(0.05) + inflated_data = effect(trended_data) + plot(inflated_data) + plt.savefig("readme_fig.png", dpi=150) diff --git a/static/readme_fig.png b/static/readme_fig.png new file mode 100644 index 0000000..06a78f0 Binary files /dev/null and b/static/readme_fig.png differ diff --git a/static/style.mplstyle b/static/style.mplstyle new file mode 100644 index 0000000..4f9421d --- /dev/null +++ b/static/style.mplstyle @@ -0,0 +1,52 @@ +figure.figsize: 5.5, 2.5 +figure.constrained_layout.use: True +figure.autolayout: False +savefig.bbox: tight +figure.dpi: 120 + +# Axes +axes.spines.left: True # display axis spines +axes.spines.bottom: True +axes.spines.top: False +axes.spines.right: False +axes.grid: true +axes.axisbelow: true + +### Fonts +mathtext.fontset: cm +font.serif: Computer Modern Roman +font.size: 10 + +# Axes ticks +ytick.left: True +xtick.bottom: True +xtick.direction: out +ytick.direction: out + +# Colour palettes +axes.prop_cycle: cycler('color', ['2F83B4','B5121B', '0B6E4F','F77F00', '7A68A6', 'C5BB36', '8c564b', 'e377c2']) +lines.color: B5121B +scatter.marker: x +image.cmap: inferno + +### Grids +grid.linestyle: - +grid.linewidth: 0.2 +grid.color: cbcbcb + +### Legend +legend.frameon: True +legend.loc: best +legend.fontsize: 8 +legend.fancybox: True +legend.scatterpoints: 1 +legend.numpoints: 1 + +patch.antialiased: True + +# set text objects edidable in Adobe Illustrator +pdf.fonttype: 42 +ps.fonttype: 42 + +# no background +savefig.transparent: True \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..3a3512e --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,6 @@ +from hypothesis import settings + +settings.register_profile( + "causal_validation", database=None, max_examples=10, deadline=None +) +settings.load_profile("causal_validation") diff --git a/tests/test_causal_validation/README.md b/tests/test_causal_validation/README.md new file mode 100644 index 0000000..9f99dbe --- /dev/null +++ b/tests/test_causal_validation/README.md @@ -0,0 +1,29 @@ +By default, this package is configured to run PyTest tests +(http://pytest.org/). + +## Writing tests + +Place test files in this directory, using file names that start with `test_`. + +## Running tests + +To run the full suite against all interpreters, run: +``` +$ brazil-build test [] +``` + +By default, the package is set up to automatically pass any unknown flags forwards to pytest. Check the tox and pytest documentation for more information. + +Code coverage is automatically reported for causal_validation; +to add other packages, modify `pyproject.toml` in the package root directory. + +To debug failing tests, use the helpful `guard` command which runs the testing on a watch looponfailroots + +``` +$ brazil-build guard +``` + +Or if you want to debug the tests with a debugger open to the failed test, use pytest's pdb option: +``` +$ brazil-build pytest --pdb +``` diff --git a/tests/test_causal_validation/__init__.py b/tests/test_causal_validation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_causal_validation/test_amzn_synthetic_causal_data_gen.py b/tests/test_causal_validation/test_amzn_synthetic_causal_data_gen.py new file mode 100644 index 0000000..a244514 --- /dev/null +++ b/tests/test_causal_validation/test_amzn_synthetic_causal_data_gen.py @@ -0,0 +1,2 @@ +def test_causal_validation_importable(): + assert True diff --git a/tests/test_causal_validation/test_base.py b/tests/test_causal_validation/test_base.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_causal_validation/test_data.py b/tests/test_causal_validation/test_data.py new file mode 100644 index 0000000..fd79430 --- /dev/null +++ b/tests/test_causal_validation/test_data.py @@ -0,0 +1,168 @@ +from azcausal.estimators.panel.did import DID +from hypothesis import ( + given, + settings, + strategies as st, +) +import numpy as np +import pandas as pd +from pandas.core.indexes.datetimes import DatetimeIndex + +from causal_validation.data import Dataset +from causal_validation.testing import ( + TestConstants, + simulate_data, +) +from causal_validation.types import InterventionTypes + +DEFAULT_SEED = 123 +NUM_NON_CONTROL_COLS = 2 +LARGE_N_POST = 5000 +LARGE_N_PRE = 5000 + + +@given( + seed=st.integers(min_value=1, max_value=30), + global_mean=st.floats( + min_value=-5.0, max_value=5.0, allow_infinity=False, allow_nan=False + ), +) +def test_global_mean(seed: int, global_mean: float): + constants = TestConstants( + N_POST_TREATMENT=LARGE_N_POST, N_PRE_TREATMENT=LARGE_N_PRE, GLOBAL_SCALE=0.01 + ) + data = simulate_data(global_mean, seed, constants=constants) + assert isinstance(data, Dataset) + + control_units = data.control_units + treated_units = data.treated_units + + np.testing.assert_almost_equal( + np.mean(control_units, axis=0), global_mean, decimal=0 + ) + np.testing.assert_almost_equal( + np.mean(treated_units, axis=0), global_mean, decimal=0 + ) + + +@given( + n_control=st.integers(min_value=1, max_value=50), + n_pre_treatment=st.integers(min_value=1, max_value=50), + n_post_treatment=st.integers(min_value=1, max_value=50), +) +def test_array_shapes(n_control: int, n_pre_treatment: int, n_post_treatment: int): + constants = TestConstants( + N_POST_TREATMENT=n_post_treatment, + N_PRE_TREATMENT=n_pre_treatment, + N_CONTROL=n_control, + ) + data = simulate_data(0.0, DEFAULT_SEED, constants=constants) + + # Test high-level property values + assert data.n_units == n_control + assert data.n_timepoints == n_pre_treatment + n_post_treatment + assert data.n_pre_intervention == n_pre_treatment + assert data.n_post_intervention == n_post_treatment + + # Test field shapes + assert data.Xtr.shape == (n_pre_treatment, n_control) + assert data.Xte.shape == (n_post_treatment, n_control) + assert data.ytr.shape == (n_pre_treatment, 1) + assert data.yte.shape == (n_post_treatment, 1) + + # Test property shapes + Xtr, ytr = data.pre_intervention_obs + Xte, yte = data.post_intervention_obs + assert Xtr.shape == (n_pre_treatment, n_control) + assert ytr.shape == (n_pre_treatment, 1) + assert Xte.shape == (n_post_treatment, n_control) + assert yte.shape == (n_post_treatment, 1) + + +@given( + n_pre_treatment=st.integers(min_value=1, max_value=50), + n_post_treatment=st.integers(min_value=1, max_value=50), +) +def test_indicator(n_pre_treatment: int, n_post_treatment: int): + constants = TestConstants( + N_POST_TREATMENT=n_post_treatment, + N_PRE_TREATMENT=n_pre_treatment, + ) + data = simulate_data(0.0, DEFAULT_SEED, constants=constants) + assert data._get_indicator().sum() == n_post_treatment + + +@given( + n_control=st.integers(min_value=1, max_value=50), + n_pre_treatment=st.integers(min_value=1, max_value=50), + n_post_treatment=st.integers(min_value=1, max_value=50), +) +def test_to_df(n_control: int, n_pre_treatment: int, n_post_treatment: int): + constants = TestConstants( + N_POST_TREATMENT=n_post_treatment, + N_PRE_TREATMENT=n_pre_treatment, + N_CONTROL=n_control, + ) + data = simulate_data(0.0, DEFAULT_SEED, constants=constants) + + df = data.to_df() + assert isinstance(df, pd.DataFrame) + assert df.shape == ( + n_pre_treatment + n_post_treatment, + n_control + NUM_NON_CONTROL_COLS, + ) + + colnames = data._get_columns() + assert isinstance(colnames, list) + assert colnames[0] == "T" + assert len(colnames) == n_control + 1 + + index = data.full_index + assert isinstance(index, DatetimeIndex) + assert index[0].strftime("%Y-%m-%d") == data._start_date.strftime("%Y-%m-%d") + + +@given( + n_control=st.integers(min_value=2, max_value=50), + n_pre_treatment=st.integers(min_value=10, max_value=50), + n_post_treatment=st.integers(min_value=10, max_value=50), + global_mean=st.floats( + min_value=-5.0, max_value=5.0, allow_infinity=False, allow_nan=False + ), +) +@settings(max_examples=5) +def test_to_azcausal( + n_control: int, n_pre_treatment: int, n_post_treatment: int, global_mean: float +): + constants = TestConstants( + N_POST_TREATMENT=n_post_treatment, + N_PRE_TREATMENT=n_pre_treatment, + N_CONTROL=n_control, + ) + data = simulate_data(global_mean, DEFAULT_SEED, constants=constants) + + panel = data.to_azcausal() + model = DID() + result = model.fit(panel) + assert not np.isnan(result.effect.value) + + +@given( + n_post_treatment=st.integers(min_value=10, max_value=50), + n_pre_treatment=st.integers(min_value=10, max_value=50), + idx=st.sampled_from(["pre-intervention", "post-intervention", "both"]), +) +def test_get_index(n_post_treatment: int, n_pre_treatment: int, idx: InterventionTypes): + constants = TestConstants( + N_POST_TREATMENT=n_post_treatment, + N_PRE_TREATMENT=n_pre_treatment, + ) + data = simulate_data(0.0, DEFAULT_SEED, constants=constants) + idx_vals = data.get_index(idx) + assert isinstance(idx_vals, DatetimeIndex) + if idx == "both": + assert len(idx_vals) == n_pre_treatment + n_post_treatment + elif idx == "post-intervention": + assert len(idx_vals) == n_post_treatment + elif idx == "pre-intervention": + assert len(idx_vals) == n_pre_treatment diff --git a/tests/test_causal_validation/test_effect.py b/tests/test_causal_validation/test_effect.py new file mode 100644 index 0000000..f959315 --- /dev/null +++ b/tests/test_causal_validation/test_effect.py @@ -0,0 +1,58 @@ +from hypothesis import ( + given, + strategies as st, +) + +from causal_validation.effects import StaticEffect +from causal_validation.testing import ( + TestConstants, + simulate_data, +) + +EFFECT_LOWER_BOUND = 1e-3 + + +@st.composite +def effect_strategy(draw): + lower_range = st.floats( + min_value=-0.1, + max_value=-1e-4, + exclude_max=True, + allow_infinity=False, + allow_nan=False, + ) + upper_range = st.floats( + min_value=1e-4, + max_value=0.1, + exclude_min=True, + allow_infinity=False, + allow_nan=False, + ) + combined_strategy = st.one_of(lower_range, upper_range) + return draw(combined_strategy) + + +@given( + global_mean=st.floats( + min_value=20.0, max_value=50.0, allow_nan=False, allow_infinity=False + ), + effect_val=effect_strategy(), + seed=st.integers(min_value=1, max_value=10), +) +def test_array_shapes(global_mean: float, effect_val: float, seed: int): + constants = TestConstants(GLOBAL_SCALE=0.01) + data = simulate_data(global_mean, seed, constants=constants) + effect = StaticEffect(effect=effect_val) + + inflated_data = effect(data) + if effect_val == 0: + assert inflated_data.yte.sum() == data.yte.sum() + elif effect_val < 0: + assert inflated_data.yte.sum() < data.yte.sum() + elif effect_val > 0: + assert inflated_data.yte.sum() > data.yte.sum() + + assert inflated_data.counterfactual.sum() == data.yte.sum() + + _effects = effect.get_effect(data) + assert _effects.shape == (data.n_post_intervention, 1) diff --git a/tests/test_causal_validation/test_integration.py b/tests/test_causal_validation/test_integration.py new file mode 100644 index 0000000..0525bec --- /dev/null +++ b/tests/test_causal_validation/test_integration.py @@ -0,0 +1,37 @@ +import numpy as np +import pytest + +from causal_validation import ( + Config, + simulate, +) +from causal_validation.data import Dataset +from causal_validation.transforms import ( + Periodic, + Trend, +) + + +def _sum_data(data: Dataset) -> float: + return data.Xtr.sum() + data.ytr.sum() + data.Xte.sum() + data.yte.sum() + + +@pytest.mark.parametrize( + "seed,e1,e2", [(123, 19794.92, 63849.93), (42, 19803.64, 63858.64)] +) +def test_end_to_end(seed: int, e1: float, e2: float): + cfg = Config( + n_control_units=10, + n_pre_intervention_timepoints=60, + n_post_intervention_timepoints=30, + seed=seed, + ) + + data = simulate(cfg) + np.testing.assert_approx_equal(_sum_data(data), e1, significant=2) + + t = Trend() + np.testing.assert_approx_equal(_sum_data(t(data)), e2, significant=2) + + p = Periodic() + np.testing.assert_approx_equal(_sum_data(p(t(data))), e2, significant=2) diff --git a/tests/test_causal_validation/test_plotters.py b/tests/test_causal_validation/test_plotters.py new file mode 100644 index 0000000..fcdd286 --- /dev/null +++ b/tests/test_causal_validation/test_plotters.py @@ -0,0 +1,55 @@ +from hypothesis import ( + given, + settings, + strategies as st, +) +from matplotlib.axes._axes import Axes +import matplotlib.pyplot as plt + +from causal_validation.plotters import plot +from causal_validation.testing import ( + TestConstants, + simulate_data, +) + +DEFAULT_SEED = 123 +NUM_AUX_LINES = 2 +LARGE_N_POST = 5000 +LARGE_N_PRE = 5000 +N_LEGEND_ENTRIES = 3 + + +# Define a strategy for generating titles +title_strategy = st.text( + alphabet=st.characters( + whitelist_categories=("Lu", "Ll", "Nd"), whitelist_characters=" " + ), + min_size=1, + max_size=50, +) + + +@given( + n_control=st.integers(min_value=1, max_value=50), + n_pre_treatment=st.integers(min_value=1, max_value=50), + n_post_treatment=st.integers(min_value=1, max_value=50), + ax_bool=st.booleans(), +) +@settings(max_examples=5) +def test_plot( + n_control: int, n_pre_treatment: int, n_post_treatment: int, ax_bool: bool +): + constants = TestConstants( + N_POST_TREATMENT=n_post_treatment, + N_PRE_TREATMENT=n_pre_treatment, + N_CONTROL=n_control, + ) + data = simulate_data(0.0, DEFAULT_SEED, constants=constants) + if ax_bool: + _, ax = plt.subplots() + ax = plot(data) + assert isinstance(ax, Axes) + assert len(ax.lines) == n_control + 2 + assert ax.get_legend() is not None + assert len(ax.get_legend().get_texts()) == N_LEGEND_ENTRIES + plt.close() diff --git a/tests/test_causal_validation/test_transforms/test_periodic.py b/tests/test_causal_validation/test_transforms/test_periodic.py new file mode 100644 index 0000000..07f845f --- /dev/null +++ b/tests/test_causal_validation/test_transforms/test_periodic.py @@ -0,0 +1,178 @@ +from hypothesis import ( + given, + settings, + strategies as st, +) +import numpy as np +from scipy.stats import norm + +from causal_validation.data import Dataset +from causal_validation.testing import ( + TestConstants, + simulate_data, +) +from causal_validation.transforms import Periodic +from causal_validation.transforms.parameter import UnitVaryingParameter + +CONSTANTS = TestConstants() +DEFAULT_SEED = 123 +GLOBAL_MEAN = 20 +GLOBAL_SCALE = 0.5 + + +@given( + frequency=st.integers(min_value=1, max_value=20), + amplitude=st.floats( + min_value=-100, max_value=100, allow_infinity=False, allow_nan=False + ), + shift=st.floats( + min_value=-100, max_value=100, allow_infinity=False, allow_nan=False + ), + offset=st.floats( + min_value=-100, max_value=100, allow_infinity=False, allow_nan=False + ), + global_mean=st.floats( + min_value=-5.0, max_value=5.0, allow_infinity=False, allow_nan=False + ), +) +@settings(max_examples=5) +def test_periodic_initialisation( + frequency: int, + amplitude: float, + shift: float, + offset: float, + global_mean: float, +): + periodic_transform = Periodic( + amplitude=amplitude, frequency=frequency, shift=shift, offset=offset + ) + base_data = simulate_data(global_mean, DEFAULT_SEED) + data = periodic_transform(base_data) + assert isinstance(data, Dataset) + for slot in CONSTANTS.DATA_SLOTS: + _base_data_array = getattr(base_data, slot) + _data_array = getattr(data, slot) + assert _base_data_array.shape == _data_array.shape + assert np.sum(np.isnan(_base_data_array)) == 0 + assert np.sum(np.isnan(_data_array)) == 0 + + +@given( + frequency=st.integers(min_value=1, max_value=20), + seed=st.integers(min_value=1, max_value=30), + global_mean=st.floats( + min_value=-5.0, max_value=5.0, allow_infinity=False, allow_nan=False + ), +) +def test_frequency_param(frequency: int, seed: int, global_mean: float): + periodic_transform = Periodic(amplitude=1, frequency=frequency, shift=0, offset=0) + base_data = simulate_data(global_mean, seed) + data = periodic_transform(base_data) + np.testing.assert_array_almost_equal( + np.mean(data.control_units, axis=0), np.mean(base_data.control_units, axis=0) + ) + np.testing.assert_array_almost_equal( + np.mean(data.treated_units, axis=0), np.mean(base_data.treated_units, axis=0) + ) + + +@st.composite +def amplitude_strategy(draw): + lower_range = st.floats( + min_value=-100, + max_value=-1e-6, + exclude_max=True, + allow_infinity=False, + allow_nan=False, + ) + upper_range = st.floats( + min_value=1e-6, + max_value=100, + exclude_min=True, + allow_infinity=False, + allow_nan=False, + ) + combined_strategy = st.one_of(lower_range, upper_range) + return draw(combined_strategy) + + +@given( + amplitude=amplitude_strategy(), + seed=st.integers(min_value=1, max_value=30), + global_mean=st.floats( + min_value=-5.0, max_value=5.0, allow_infinity=False, allow_nan=False + ), +) +def test_amplitude_param(amplitude: float, seed: int, global_mean: float): + periodic_transform = Periodic(frequency=1, amplitude=amplitude, shift=0, offset=0) + base_data = simulate_data(global_mean, seed) + data = periodic_transform(base_data) + + assert np.isclose( + np.max(data.control_units - base_data.control_units), np.abs(amplitude), rtol=1 + ) + assert np.isclose( + np.max(data.treated_units - base_data.treated_units), np.abs(amplitude), rtol=1 + ) + + +@given( + frequency=st.integers(min_value=1, max_value=20), + seed=st.integers(min_value=1, max_value=30), + global_mean=st.floats( + min_value=-5.0, max_value=5.0, allow_infinity=False, allow_nan=False + ), +) +def test_num_frequencies(frequency: int, seed: int, global_mean: float): + periodic_transform = Periodic(frequency=frequency, amplitude=1, shift=0, offset=0) + base_data = simulate_data(global_mean, seed) + data = periodic_transform(base_data) + control_units = data.control_units + treated_units = data.treated_units + for d in [control_units, treated_units]: + num_samples = d.shape[0] + fft_vals = np.fft.fft(d, axis=0) + peak_frequency = np.argmax(np.abs(fft_vals[1 : num_samples // 2]), axis=0) + 1 + np.testing.assert_equal(peak_frequency, frequency) + + +@given( + offset=st.floats( + min_value=-100, max_value=100, allow_infinity=False, allow_nan=False + ), + seed=st.integers(min_value=1, max_value=30), + global_mean=st.floats( + min_value=-5.0, max_value=5.0, allow_infinity=False, allow_nan=False + ), +) +def test_offset(offset: float, seed: int, global_mean: float): + periodic_transform = Periodic(frequency=1, amplitude=5, shift=0, offset=offset) + base_data = simulate_data(global_mean, seed) + data = periodic_transform(base_data) + original_array = base_data.treated_units.squeeze() + offset_array = data.treated_units.squeeze() + + normal_mean = np.mean(original_array) + offset_mean = np.mean(offset_array) + assert np.isclose(offset_mean - normal_mean, offset, atol=0.1) + + +def test_varying_parameters(): + periodic_transform = Periodic() + param_slots = periodic_transform._slots + constants = TestConstants(N_CONTROL=2) + data_slots = constants.DATA_SLOTS + base_data = simulate_data(GLOBAL_MEAN, DEFAULT_SEED, constants=constants) + base_data_transform = periodic_transform(base_data) + for slot in param_slots: + setattr( + periodic_transform, + slot, + UnitVaryingParameter(sampling_dist=norm(GLOBAL_MEAN, GLOBAL_SCALE)), + ) + data = periodic_transform(base_data) + for dslot in data_slots: + assert not np.any(np.isnan(getattr(data, dslot))) + assert not np.array_equal( + getattr(data, dslot), getattr(base_data_transform, dslot) + ) diff --git a/tests/test_causal_validation/test_transforms/test_trends.py b/tests/test_causal_validation/test_transforms/test_trends.py new file mode 100644 index 0000000..1c9e3da --- /dev/null +++ b/tests/test_causal_validation/test_transforms/test_trends.py @@ -0,0 +1,122 @@ +from hypothesis import ( + given, + settings, + strategies as st, +) +import numpy as np +from scipy.stats import norm + +from causal_validation.testing import ( + TestConstants, + simulate_data, +) +from causal_validation.transforms import Trend +from causal_validation.transforms.parameter import UnitVaryingParameter + +CONSTANTS = TestConstants() +DEFAULT_SEED = 123 +GLOBAL_MEAN = 20 +STATES = [42, 123] + + +@st.composite +def coefficient_strategy(draw): + lower_range = st.floats( + min_value=-1, + max_value=-1e-6, + exclude_max=True, + allow_infinity=False, + allow_nan=False, + ) + upper_range = st.floats( + min_value=1e-6, + max_value=1, + exclude_min=True, + allow_infinity=False, + allow_nan=False, + ) + combined_strategy = st.one_of(lower_range, upper_range) + return draw(combined_strategy) + + +@given(degree=st.integers(min_value=1, max_value=3), coefficient=coefficient_strategy()) +@settings(max_examples=5) +def test_trend_coefficient(degree: int, coefficient: float): + trend_transform = Trend(degree=degree, coefficient=coefficient, intercept=0) + base_data = simulate_data(GLOBAL_MEAN, DEFAULT_SEED) + data = trend_transform(base_data) + + if coefficient > 1: + assert np.all(data.Xtr[-1, :] > base_data.Xtr[-1, :]) + elif coefficient < 0: + assert np.all(data.Xtr[-1, :] < base_data.Xtr[-1, :]) + + +@given(intercept=coefficient_strategy()) +@settings(max_examples=5) +def test_trend_intercept(intercept: float): + trend_transform = Trend(degree=1, coefficient=0, intercept=intercept) + base_data = simulate_data(GLOBAL_MEAN, DEFAULT_SEED) + data = trend_transform(base_data) + + if intercept > 0: + assert np.all(data.Xtr > base_data.Xtr) + elif intercept < 0: + assert np.all(data.Xtr < base_data.Xtr) + + +@given( + loc=st.floats( + min_value=-10.0, max_value=10.0, allow_nan=False, allow_infinity=False + ), + scale=st.floats( + min_value=1e-3, max_value=10, allow_infinity=False, allow_nan=False + ), +) +def test_varying_trend(loc: float, scale: float): + constants = TestConstants( + N_CONTROL=2, + ) + data = simulate_data(GLOBAL_MEAN, DEFAULT_SEED, constants=constants) + sampling_dist = norm(loc, scale) + param = UnitVaryingParameter(sampling_dist=sampling_dist) + trend = Trend(degree=1, coefficient=0.0, intercept=param) + transformed_data = trend(data) + assert not np.array_equal(transformed_data.Xtr[:, 0], transformed_data.Xtr[:, 1]) + + trend = Trend(degree=1, coefficient=param, intercept=0.0) + transformed_data = trend(data) + assert not np.array_equal(transformed_data.Xtr[:, 0], transformed_data.Xtr[:, 1]) + + +@given( + loc=st.floats( + min_value=-10.0, max_value=10.0, allow_nan=False, allow_infinity=False + ), + scale=st.floats( + min_value=1e-3, max_value=10, allow_infinity=False, allow_nan=False + ), +) +@settings(max_examples=5) +def test_randomness(loc: float, scale: float): + constants = TestConstants( + N_CONTROL=2, + ) + data = simulate_data(GLOBAL_MEAN, DEFAULT_SEED, constants=constants) + SLOTS = TestConstants().DATA_SLOTS + + for slot in SLOTS: + transformed_datas = [] + for random_state in STATES: + sampling_dist = norm(loc, scale) + param = UnitVaryingParameter( + sampling_dist=sampling_dist, random_state=random_state + ) + trend = Trend(degree=1, coefficient=0.0, intercept=param) + transformed_datas.append(trend(data)) + assert not np.array_equal( + getattr(transformed_datas[0], slot), getattr(transformed_datas[1], slot) + ) + assert not np.array_equal( + getattr(transformed_datas[0], slot), getattr(data, slot) + ) diff --git a/tests/test_causal_validation/test_weights.py b/tests/test_causal_validation/test_weights.py new file mode 100644 index 0000000..1b5ed6d --- /dev/null +++ b/tests/test_causal_validation/test_weights.py @@ -0,0 +1,32 @@ +from hypothesis import ( + given, + strategies as st, +) +import numpy as np + +from causal_validation.weights import UniformWeights + + +@given( + n_units=st.integers(min_value=1, max_value=100), + n_time=st.integers(min_value=1, max_value=100), +) +def test_uniform_weights(n_units: int, n_time: int): + weights = UniformWeights() + data = np.random.random(size=(n_time, n_units)) + weight_vals = weights.get_weights(data) + np.testing.assert_almost_equal(np.mean(weight_vals), weight_vals, decimal=6) + assert weight_vals.shape == (n_units, 1) + + +@given( + n_units=st.integers(min_value=1, max_value=100), + n_time=st.integers(min_value=1, max_value=100), +) +def test_weight_obs(n_units: int, n_time: int): + obs = np.ones(shape=(n_time, n_units)) + weighted_obs = UniformWeights()(obs) + np.testing.assert_almost_equal(np.mean(weighted_obs), weighted_obs, decimal=6) + np.testing.assert_almost_equal( + obs @ UniformWeights().get_weights(obs), weighted_obs, decimal=6 + )