44 changes: 30 additions & 14 deletions src/causal_validation/data.py
@@ -62,16 +62,36 @@ def __post_init__(self):

def to_df(
self, index_start: str = dt.date(year=2023, month=1, day=1)
) -> pd.DataFrame:
inputs = np.vstack([self.Xtr, self.Xte])
outputs = np.vstack([self.ytr, self.yte])
data = np.hstack([outputs, inputs])
) -> tp.Tuple[pd.DataFrame, tp.Optional[pd.DataFrame]]:
control_outputs = np.vstack([self.Xtr, self.Xte])
treated_outputs = np.vstack([self.ytr, self.yte])
data = np.hstack([treated_outputs, control_outputs])
index = self._get_index(index_start)
colnames = self._get_columns()
indicator = self._get_indicator()
df = pd.DataFrame(data, index=index, columns=colnames)
df = df.assign(treated=indicator)
return df
df_outputs = pd.DataFrame(data, index=index, columns=colnames)
df_outputs = df_outputs.assign(treated=indicator)

if not self.has_covariates:
cov_df = None
else:
control_covs = np.concatenate([self.Ptr, self.Pte], axis=0)
treated_covs = np.concatenate([self.Rtr, self.Rte], axis=0)

all_covs = np.concatenate([treated_covs, control_covs], axis=1)

unit_cols = self._get_columns()
covariate_cols = [f"F{i}" for i in range(self.n_covariates)]

cov_data = all_covs.reshape(self.n_timepoints, -1)

col_tuples = [(unit, cov) for unit in unit_cols for cov in covariate_cols]
multi_cols = pd.MultiIndex.from_tuples(col_tuples)

cov_df = pd.DataFrame(cov_data, index=index, columns=multi_cols)
cov_df = cov_df.assign(treated=indicator)

return df_outputs, cov_df

@property
def n_post_intervention(self) -> int:
@@ -180,12 +200,7 @@ def get_index(self, period: InterventionTypes) -> DatetimeIndex:
return self.full_index

def _get_columns(self) -> tp.List[str]:
if self.has_covariates:
colnames = ["T"] + [f"C{i}" for i in range(self.n_units)] + [
f"F{i}" for i in range(self.n_covariates)
]
else:
colnames = ["T"] + [f"C{i}" for i in range(self.n_units)]
colnames = ["T"] + [f"C{i}" for i in range(self.n_units)]
return colnames

def _get_index(self, start_date: dt.date) -> DatetimeIndex:
@@ -224,7 +239,8 @@ def __eq__(self, other: Dataset) -> bool:

def to_azcausal(self):
time_index = np.arange(self.n_timepoints)
data = self.to_df().assign(time=time_index).melt(id_vars=["time", "treated"])
data_df, _ = self.to_df()
data = data_df.assign(time=time_index).melt(id_vars=["time", "treated"])
data.loc[:, "treated"] = np.where(
(data["variable"] == "T") & (data["treated"] == 1.0), 1, 0
)
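The to_df change above alters the public contract: callers now receive a tuple of the output frame and an optional covariate frame. A minimal usage sketch (not part of the diff), assuming the simulate_data and TestConstants helpers from causal_validation.testing that this PR updates below, and mirroring the call pattern used in the updated tests:

# Illustrative sketch only, not library code.
from causal_validation.testing import TestConstants, simulate_data

data = simulate_data(0.0, 123, constants=TestConstants(N_COVARIATES=2))
df_outputs, cov_df = data.to_df()
# df_outputs: the "T" column, "C0".."C{n-1}" control columns, and the treated indicator.
# cov_df: MultiIndex (unit, "F{i}") columns plus the treated indicator.

no_cov = simulate_data(0.0, 123, constants=TestConstants())
_, maybe_covs = no_cov.to_df()
assert maybe_covs is None  # datasets without covariates return None as the second element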
2 changes: 2 additions & 0 deletions src/causal_validation/testing.py
@@ -11,6 +11,7 @@ class TestConstants:
N_CONTROL: int = 10
N_PRE_TREATMENT: int = 500
N_POST_TREATMENT: int = 500
N_COVARIATES: tp.Optional[int] = None
DATA_SLOTS: tp.Tuple[str, str, str, str] = ("Xtr", "Xte", "ytr", "yte")
ZERO_DIVISION_ERROR: float = 1e-6
GLOBAL_SCALE: float = 1.0
@@ -26,6 +27,7 @@ def simulate_data(
n_control_units=constants.N_CONTROL,
n_pre_intervention_timepoints=constants.N_PRE_TREATMENT,
n_post_intervention_timepoints=constants.N_POST_TREATMENT,
n_covariates=constants.N_COVARIATES,
global_mean=global_mean,
global_scale=constants.GLOBAL_SCALE,
seed=seed,
4 changes: 2 additions & 2 deletions src/causal_validation/transforms/__init__.py
@@ -1,5 +1,5 @@
from causal_validation.transforms.noise import Noise
from causal_validation.transforms.noise import Noise, CovariateNoise
from causal_validation.transforms.periodic import Periodic
from causal_validation.transforms.trends import Trend

__all__ = ["Trend", "Periodic", "Noise"]
__all__ = ["Trend", "Periodic", "Noise", "CovariateNoise"]
32 changes: 28 additions & 4 deletions src/causal_validation/transforms/base.py
@@ -61,7 +61,7 @@ def _get_parameter_values(self, data: Dataset) -> tp.Dict[str, np.ndarray]:


@dataclass(kw_only=True)
class AdditiveTransform(AbstractTransform):
class AdditiveOutputTransform(AbstractTransform):
def apply_values(
self,
pre_intervention_vals: np.ndarray,
@@ -75,12 +75,14 @@ def apply_values(
Xte = Xte + post_intervention_vals[:, 1:]
yte = yte + post_intervention_vals[:, :1]
return Dataset(
Xtr, Xte, ytr, yte, data._start_date, data.counterfactual, data.synthetic
Xtr, Xte, ytr, yte, data._start_date,
data.Ptr, data.Pte, data.Rtr, data.Rte,
data.counterfactual, data.synthetic
)


@dataclass(kw_only=True)
class MultiplicativeTransform(AbstractTransform):
class MultiplicativeOutputTransform(AbstractTransform):
def apply_values(
self,
pre_intervention_vals: np.ndarray,
@@ -94,5 +96,27 @@ def apply_values(
Xte = Xte * post_intervention_vals
yte = yte * post_intervention_vals
return Dataset(
Xtr, Xte, ytr, yte, data._start_date, data.counterfactual, data.synthetic
Xtr, Xte, ytr, yte, data._start_date,
data.Ptr, data.Pte, data.Rtr, data.Rte,
data.counterfactual, data.synthetic
)

@dataclass(kw_only=True)
class AdditiveCovariateTransform(AbstractTransform):
def apply_values(
self,
pre_intervention_vals: np.ndarray,
post_intervention_vals: np.ndarray,
data: Dataset,
) -> Dataset:
Ptr, Rtr = [deepcopy(i) for i in data.pre_intervention_covariates]
Pte, Rte = [deepcopy(i) for i in data.post_intervention_covariates]
Ptr = Ptr + pre_intervention_vals[:, 1:, :]
Rtr = Rtr + pre_intervention_vals[:, :1, :]
Pte = Pte + post_intervention_vals[:, 1:, :]
Rte = Rte + post_intervention_vals[:, :1, :]
return Dataset(
data.Xtr, data.Xte, data.ytr, data.yte,
data._start_date, Ptr, Pte, Rtr, Rte,
data.counterfactual, data.synthetic
)
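The new AdditiveCovariateTransform follows the same column convention as the output transforms: index 0 on the unit axis is the treated unit, the remaining indices are controls. A standalone numpy sketch of the expected 3-D shapes (illustrative only, with made-up sizes):

import numpy as np

# Hypothetical sizes: 4 timepoints, 3 control units, 2 covariates.
n_timepoints, n_control, n_covariates = 4, 3, 2
vals = np.zeros((n_timepoints, n_control + 1, n_covariates))
treated_vals = vals[:, :1, :]  # added to Rtr / Rte (treated-unit covariates)
control_vals = vals[:, 1:, :]  # added to Ptr / Pte (control-unit covariates)
assert treated_vals.shape == (4, 1, 2)
assert control_vals.shape == (4, 3, 2)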
34 changes: 31 additions & 3 deletions src/causal_validation/transforms/noise.py
@@ -6,12 +6,18 @@
from scipy.stats import norm

from causal_validation.data import Dataset
from causal_validation.transforms.base import AdditiveTransform
from causal_validation.transforms.parameter import TimeVaryingParameter
from causal_validation.transforms.base import (
AdditiveOutputTransform,
AdditiveCovariateTransform
)
from causal_validation.transforms.parameter import (
TimeVaryingParameter,
CovariateNoiseParameter
)


@dataclass(kw_only=True)
class Noise(AdditiveTransform):
class Noise(AdditiveOutputTransform):
"""
Transform the treatment by adding TimeVaryingParameter noise terms sampled from
a specified sampling distribution. By default, the sampling distribution is
@@ -30,3 +36,25 @@ def get_values(self, data: Dataset) -> Float[np.ndarray, "N D"]:
).reshape(-1)
noise[:, 0] = noise_treatment
return noise


@dataclass(kw_only=True)
class CovariateNoise(AdditiveCovariateTransform):
"""
Transform the covariates by adding CovariateNoiseParameter noise terms sampled from
a specified sampling distribution. By default, the sampling distribution is
Normal with 0 loc and 0.1 scale.
"""

noise_dist: CovariateNoiseParameter = field(
default_factory=lambda: CovariateNoiseParameter(sampling_dist=norm(0, 0.1))
)
_slots: Tuple[str] = ("noise_dist",)

def get_values(self, data: Dataset) -> Float[np.ndarray, "N D C"]:
noise = self.noise_dist.get_value(
n_units=data.n_units+1,
n_timepoints=data.n_timepoints,
n_covariates=data.n_covariates
)
return noise
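A short usage sketch for CovariateNoise (illustrative, not part of the diff); it assumes the transform is constructible with its defaults, as Noise is, and only exercises get_values, whose output shape follows from the code above:

from causal_validation.testing import TestConstants, simulate_data
from causal_validation.transforms import CovariateNoise

data = simulate_data(0.0, 123, constants=TestConstants(N_COVARIATES=3))
noise_tf = CovariateNoise()  # defaults to Normal(loc=0, scale=0.1) covariate noise
vals = noise_tf.get_values(data)
# One draw per (timepoint, unit, covariate); the unit axis includes the treated unit.
assert vals.shape == (data.n_timepoints, data.n_units + 1, data.n_covariates)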
12 changes: 12 additions & 0 deletions src/causal_validation/transforms/parameter.py
@@ -51,6 +51,18 @@ def get_value(
return np.tile(time_param, reps=n_units)


@dataclass
class CovariateNoiseParameter(RandomParameter):
def get_value(
self, n_units: int, n_timepoints: int, n_covariates: int
) -> Float[np.ndarray, "{n_timepoints} {n_units} {n_covariates}"]:
covariate_noise = self.sampling_dist.rvs(
size=(n_timepoints, n_units, n_covariates),
random_state=self.random_state
)
return covariate_noise


ParameterOrFloat = tp.Union[Parameter, float]


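CovariateNoiseParameter draws one i.i.d. value per (timepoint, unit, covariate) triple from its sampling distribution. A standalone sketch (illustrative only; the constructor call mirrors the default used by CovariateNoise above):

from scipy.stats import norm

from causal_validation.transforms.parameter import CovariateNoiseParameter

param = CovariateNoiseParameter(sampling_dist=norm(0, 0.1))
draw = param.get_value(n_units=4, n_timepoints=10, n_covariates=3)
assert draw.shape == (10, 4, 3)  # (n_timepoints, n_units, n_covariates)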
4 changes: 2 additions & 2 deletions src/causal_validation/transforms/periodic.py
@@ -5,12 +5,12 @@
import numpy as np

from causal_validation.data import Dataset
from causal_validation.transforms.base import AdditiveTransform
from causal_validation.transforms.base import AdditiveOutputTransform
from causal_validation.transforms.parameter import ParameterOrFloat


@dataclass(kw_only=True)
class Periodic(AdditiveTransform):
class Periodic(AdditiveOutputTransform):
amplitude: ParameterOrFloat = 1.0
frequency: ParameterOrFloat = 1.0
shift: ParameterOrFloat = 0.0
4 changes: 2 additions & 2 deletions src/causal_validation/transforms/trends.py
@@ -5,12 +5,12 @@
import numpy as np

from causal_validation.data import Dataset
from causal_validation.transforms.base import AdditiveTransform
from causal_validation.transforms.base import AdditiveOutputTransform
from causal_validation.transforms.parameter import ParameterOrFloat


@dataclass(kw_only=True)
class Trend(AdditiveTransform):
class Trend(AdditiveOutputTransform):
degree: int = 1
coefficient: ParameterOrFloat = 1.0
intercept: ParameterOrFloat = 0.0
46 changes: 44 additions & 2 deletions tests/test_causal_validation/test_data.py
@@ -29,6 +29,7 @@
MAX_STRING_LENGTH = 20
DEFAULT_SEED = 123
NUM_NON_CONTROL_COLS = 2
NUM_TREATED = 1
LARGE_N_POST = 5000
LARGE_N_PRE = 5000

@@ -109,15 +110,15 @@ def test_indicator(n_pre_treatment: int, n_post_treatment: int):
n_pre_treatment=st.integers(min_value=1, max_value=50),
n_post_treatment=st.integers(min_value=1, max_value=50),
)
def test_to_df(n_control: int, n_pre_treatment: int, n_post_treatment: int):
def test_to_df_no_cov(n_control: int, n_pre_treatment: int, n_post_treatment: int):
constants = TestConstants(
N_POST_TREATMENT=n_post_treatment,
N_PRE_TREATMENT=n_pre_treatment,
N_CONTROL=n_control,
)
data = simulate_data(0.0, DEFAULT_SEED, constants=constants)

df = data.to_df()
df, _ = data.to_df()
assert isinstance(df, pd.DataFrame)
assert df.shape == (
n_pre_treatment + n_post_treatment,
@@ -133,6 +134,47 @@ def test_to_df(n_control: int, n_pre_treatment: int, n_post_treatment: int):
assert isinstance(index, DatetimeIndex)
assert index[0].strftime("%Y-%m-%d") == data._start_date.strftime("%Y-%m-%d")

@given(
n_control=st.integers(min_value=1, max_value=50),
n_pre_treatment=st.integers(min_value=1, max_value=50),
n_post_treatment=st.integers(min_value=1, max_value=50),
n_covariates=st.integers(min_value=1, max_value=50),
)
def test_to_df_with_cov(n_control: int,
n_pre_treatment: int,
n_post_treatment: int,
n_covariates: int):
constants = TestConstants(
N_POST_TREATMENT=n_post_treatment,
N_PRE_TREATMENT=n_pre_treatment,
N_CONTROL=n_control,
N_COVARIATES=n_covariates,
)
data = simulate_data(0.0, DEFAULT_SEED, constants=constants)

df_outs, df_covs = data.to_df()
assert isinstance(df_outs, pd.DataFrame)
assert df_outs.shape == (
n_pre_treatment + n_post_treatment,
n_control + NUM_NON_CONTROL_COLS,
)

assert isinstance(df_covs, pd.DataFrame)
assert df_covs.shape == (
n_pre_treatment + n_post_treatment,
n_covariates * (n_control + NUM_TREATED)
+ NUM_NON_CONTROL_COLS - NUM_TREATED,
)

colnames = data._get_columns()
assert isinstance(colnames, list)
assert colnames[0] == "T"
assert len(colnames) == n_control + 1

index = data.full_index
assert isinstance(index, DatetimeIndex)
assert index[0].strftime("%Y-%m-%d") == data._start_date.strftime("%Y-%m-%d")


@given(
n_control=st.integers(min_value=2, max_value=50),
Expand Down