Commit

ttest rel core

Artem Vasin committed Jun 6, 2023
1 parent 89a4194 commit 564edd8
Showing 7 changed files with 154 additions and 40 deletions.
15 changes: 15 additions & 0 deletions ambrosia/spark_tools/constants.py
@@ -0,0 +1,15 @@
# Copyright 2022 MTS (Mobile Telesystems)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

EMPTY_VALUE_PARTITION: int = 0
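The constant above replaces the ad-hoc EMPTY_VALUE literals previously defined in split_tools.py and stratification.py (see the hunks below). Partitioning a window by a constant collapses the whole DataFrame into a single window partition, so functions such as row_number() produce one global sequence. A minimal sketch of the pattern, assuming a local SparkSession and an illustrative "value" column:

# Sketch of the partition-by-constant window pattern (illustrative only).
import pyspark.sql.functions as F
from pyspark.sql import SparkSession, Window

EMPTY_VALUE_PARTITION: int = 0

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(30,), (10,), (20,)], ["value"])

# A single constant partition => one global ordering, row numbers 1..N.
window = Window.orderBy("value").partitionBy(F.lit(EMPTY_VALUE_PARTITION))
df.withColumn("__row_number", F.row_number().over(window)).show()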
8 changes: 5 additions & 3 deletions ambrosia/spark_tools/split_tools.py
@@ -16,6 +16,7 @@

import ambrosia.spark_tools.stratification as strat_pkg
from ambrosia import types
from ambrosia.spark_tools.constants import EMPTY_VALUE_PARTITION
from ambrosia.tools import split_tools
from ambrosia.tools.import_tools import spark_installed

@@ -26,7 +27,6 @@
HASH_COLUMN_NAME: str = "__hashed_ambrosia_column"
GROUPS_COLUMN: str = "group"
ROW_NUMBER: str = "__row_number"
EMPTY_VALUE: int = 0


def unite_spark_tables(*dataframes: types.SparkDataFrame) -> types.SparkDataFrame:
@@ -90,7 +90,7 @@ def udf_make_labels(row_number: int) -> str:
label_ind = (row_number - 1) // groups_size
return labels[label_ind]

window = Window.orderBy(HASH_COLUMN_NAME).partitionBy(spark_funcs.lit(EMPTY_VALUE))
window = Window.orderBy(HASH_COLUMN_NAME).partitionBy(spark_funcs.lit(EMPTY_VALUE_PARTITION))
result = hashed_dataframe.withColumn(ROW_NUMBER, spark_funcs.row_number().over(window)).withColumn(
GROUPS_COLUMN, spark_funcs.udf(udf_make_labels)(spark_funcs.col(ROW_NUMBER))
)
@@ -128,7 +128,9 @@ def udf_make_labels_with_find(row_number: int):
not_used_ids.withColumn(
ROW_NUMBER,
spark_funcs.row_number().over(
Window.orderBy(spark_funcs.lit(EMPTY_VALUE)).partitionBy(spark_funcs.lit(EMPTY_VALUE))
Window.orderBy(spark_funcs.lit(EMPTY_VALUE_PARTITION)).partitionBy(
spark_funcs.lit(EMPTY_VALUE_PARTITION)
)
),
)
.withColumn(GROUPS_COLUMN, spark_funcs.udf(udf_make_labels_with_find)(ROW_NUMBER))
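The hunks above only swap the module-local EMPTY_VALUE for the shared EMPTY_VALUE_PARTITION; the label-assignment logic itself is unchanged. For context, a pure-Python sketch of what udf_make_labels computes, with illustrative labels and block size:

# Pure-Python sketch of the row-number -> group-label mapping.
labels = ["A", "B"]  # illustrative group labels
groups_size = 3      # illustrative rows per group

def make_label(row_number: int) -> str:
    # Rows 1..groups_size get labels[0], the next groups_size rows labels[1], ...
    return labels[(row_number - 1) // groups_size]

assert [make_label(i) for i in range(1, 7)] == ["A", "A", "A", "B", "B", "B"]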
114 changes: 91 additions & 23 deletions ambrosia/spark_tools/stat_criteria.py
@@ -20,14 +20,17 @@
import ambrosia.tools.pvalue_tools as pvalue_pkg
import ambrosia.tools.theoretical_tools as theory_pkg
from ambrosia import types
from ambrosia.spark_tools.constants import EMPTY_VALUE_PARTITION
from ambrosia.spark_tools.theory import get_stats_from_table
from ambrosia.tools.ab_abstract_component import ABStatCriterion
from ambrosia.tools.configs import Effects
from ambrosia.tools.import_tools import spark_installed
from ambrosia.tools.stat_criteria import TtestRelHelpful

if spark_installed():
import pyspark.sql.functions as F
from pyspark.sql.functions import col, row_number
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.functions import col, mean, row_number, variance
from pyspark.sql.window import Window


@@ -88,8 +91,7 @@ class TtestIndCriterionSpark(ABSparkCriterion):
Unit for pyspark independent T-test.
"""

__implemented_effect_types: List = ["absolute", "relative"]
__type_error_msg: str = f"Choose effect type from {__implemented_effect_types}"
implemented_effect_types: List = ["absolute", "relative"]
__data_parameters = ["mean_group_a", "mean_group_b", "std_group_a", "std_group_b", "nobs_group_a", "nobs_group_b"]

def __calc_and_cache_data_parameters(
@@ -127,8 +129,8 @@ def calculate_pvalue(
effect_type: str = "absolute",
**kwargs,
):
if effect_type not in TtestIndCriterionSpark.__implemented_effect_types:
raise ValueError(TtestIndCriterionSpark.__type_error_msg)
if effect_type not in self.implemented_effect_types:
raise ValueError(self._send_type_error_msg())

if not self.parameters_are_cached:
self.__calc_and_cache_data_parameters(group_a, group_b, column)
if effect_type == "absolute":
@@ -163,7 +165,7 @@ def calculate_effect(
"mean_group_a"
]
else:
raise ValueError(TtestIndCriterionSpark.__type_error_msg)
raise ValueError(self._send_type_error_msg())

return effect

def calculate_conf_interval(
@@ -194,10 +196,10 @@ def calculate_conf_interval(
conf_interval = self._apply_delta_method(alpha, **kwargs)[0]
return conf_interval
else:
raise ValueError(TtestIndCriterionSpark.__type_error_msg)
raise ValueError(self._send_type_error_msg())



class TtestRelativeCriterionSpark(ABSparkCriterion):
class TtestRelativeCriterionSpark(ABSparkCriterion, TtestRelHelpful):
"""
Relative paired t-test for Spark
"""
@@ -213,15 +215,23 @@ def _rename_col(column: str, group: str) -> str:
def _calc_and_cache_data_parameters(
self, group_a: types.SparkDataFrame, group_b: types.SparkDataFrame, column: types.ColumnNameType
) -> None:
a_ = (
col_a: str = self._rename_col(column, "a")
col_b: str = self._rename_col(column, "b")
a_: DataFrame = (

group_a.withColumn(self.__ord_col, F.lit(1))
.withColumn(self.__add_index_name, row_number().over(Window().orderBy(self.__ord_col)))
.withColumnRenamed(column, self._rename_col(column, "a"))
.withColumn(
self.__add_index_name,
row_number().over(Window().orderBy(self.__ord_col).partitionBy(F.lit(EMPTY_VALUE_PARTITION))),
)
.withColumnRenamed(column, col_a)
)
b_ = (
b_: DataFrame = (

group_b.withColumn(self.__ord_col, F.lit(1))
.withColumn(self.__add_index_name, row_number().over(Window().orderBy(self.__ord_col)))
.withColumnRenamed(column, self._rename_col(column, "b"))
.withColumn(
self.__add_index_name,
row_number().over(Window().orderBy(self.__ord_col).partitionBy(F.lit(EMPTY_VALUE_PARTITION))),
)
.withColumnRenamed(column, col_b)
)

n_a_obs: int = group_a.count()
Expand All @@ -230,11 +240,25 @@ def _calc_and_cache_data_parameters(
if n_a_obs != n_b_obs:
raise ValueError("Size of group A and B must be equal")

both = a_.join(b_, self.__add_index_name, "inner").withColumn(
self.__diff, col(self._rename_col(column, "b")) - col(self._rename_col(column, "a"))
)
both: DataFrame = a_.join(b_, self.__add_index_name, "inner").withColumn(self.__diff, col(col_b) - col(col_a))

cov: float = both.stat.cov(col_a, col_b)
stats = both.select(

variance(col_a).alias("__var_a"),
variance(col_b).alias("__var_b"),
mean(col_a).alias("__mean_a"),
mean(col_b).alias("__mean_b"),
).first()
var_a: float = theory_pkg.unbiased_to_sufficient(stats["__var_a"], n_a_obs, is_std=False)
var_b: float = theory_pkg.unbiased_to_sufficient(stats["__var_b"], n_a_obs, is_std=False)

self.data_stats["mean"], self.data_stats["std"] = get_stats_from_table(both, self.__diff)
self.data_stats["n_obs"] = n_a_obs
self.data_stats["cov"] = cov
self.data_stats["var_a"] = var_a
self.data_stats["var_b"] = var_b
self.data_stats["mean_a"] = stats["__mean_a"]
self.data_stats["mean_b"] = stats["__mean_b"]

self.parameters_are_cached = True
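In short, the rewritten _calc_and_cache_data_parameters aligns both groups on a synthetic row index, joins them, and caches everything the paired tests need: mean and std of the pairwise differences, per-group means, sufficient (ddof = 0) variances, and the covariance of the pairs. A condensed standalone sketch of that computation, with illustrative names and data (note that ordering by a constant makes the pairing depend on the incoming row order, as in the code above):

# Condensed sketch of the paired-statistics step (illustrative names/data).
import pyspark.sql.functions as F
from pyspark.sql import SparkSession, Window

spark = SparkSession.builder.master("local[1]").getOrCreate()
df_a = spark.createDataFrame([(1.0,), (2.0,), (3.0,)], ["metric"])
df_b = spark.createDataFrame([(1.5,), (2.5,), (2.0,)], ["metric"])

w = Window.orderBy(F.lit(1)).partitionBy(F.lit(0))
a = df_a.withColumn("__idx", F.row_number().over(w)).withColumnRenamed("metric", "metric_a")
b = df_b.withColumn("__idx", F.row_number().over(w)).withColumnRenamed("metric", "metric_b")

both = a.join(b, "__idx", "inner").withColumn("__diff", F.col("metric_b") - F.col("metric_a"))
cov_ab = both.stat.cov("metric_a", "metric_b")  # sample covariance of the pairs
stats = both.select(
    F.variance("metric_a").alias("var_a"),  # unbiased (ddof = 1), rescaled later
    F.variance("metric_b").alias("var_b"),
    F.mean("metric_a").alias("mean_a"),
    F.mean("metric_b").alias("mean_b"),
).first()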

def calculate_pvalue(
@@ -247,30 +271,74 @@
):
self._recalc_cache(group_a, group_b, column)
if effect_type == Effects.abs.value:
if "alternative" in kwargs:
kwargs["alternative"] = theory_pkg.switch_alternative(kwargs["alternative"])

p_value = theory_pkg.ttest_1samp_from_stats(
mean=self.data_stats["mean"], std=self.data_stats["std"], n_obs=self.data_stats["n_obs"], **kwargs
)
)[
1
] # (stat, pvalue)
elif effect_type == Effects.rel.value:
raise NotImplementedError("Will be implemented later")
_, p_value = theory_pkg.apply_delta_method_by_stats(

size=self.data_stats["n_obs"],
mean_group_a=self.data_stats["mean_a"],
mean_group_b=self.data_stats["mean_b"],
var_group_a=self.data_stats["var_a"],
var_group_b=self.data_stats["var_b"],
cov_groups=self.data_stats["cov"],
transformation="fraction",
)
else:
raise ValueError(self._send_type_error_msg())

self._check_clear_cache()
return p_value

def calculate_conf_interval(
self,
group_a: types.SparkDataFrame,
group_b: types.SparkDataFrame,
column: str,
alpha: types.StatErrorType,
effect_type: str,
effect_type: str = Effects.abs.value,
**kwargs,
) -> List[Tuple]:
raise NotImplementedError("Will be implemented later")
self._recalc_cache(group_a, group_b, column)
if effect_type == Effects.abs.value:
confidence_intervals = self._build_intervals_absolute_from_stats(

center=self.data_stats["mean"],
sd_1=self.data_stats["std"],
n_obs=self.data_stats["n_obs"],
alpha=alpha,
**kwargs,
)
elif effect_type == Effects.rel.value:
confidence_intervals, _ = theory_pkg.apply_delta_method_by_stats(

size=self.data_stats["n_obs"],
mean_group_a=self.data_stats["mean_a"],
mean_group_b=self.data_stats["mean_b"],
var_group_a=self.data_stats["var_a"],
var_group_b=self.data_stats["var_b"],
cov_groups=self.data_stats["cov"],
alpha=alpha,
transformation="fraction",
**kwargs,
)
else:
raise ValueError(self._send_type_error_msg())
return confidence_intervals


def calculate_effect(
self, group_a: types.SparkDataFrame, group_b: types.SparkDataFrame, column: str, effect_type: str
self,
group_a: types.SparkDataFrame,
group_b: types.SparkDataFrame,
column: str,
effect_type: str = Effects.abs.value,
) -> float:
self._recalc_cache(group_a, group_b, column)
if effect_type == Effects.abs.value:
effect: float = self.data_stats["mean"]
elif effect_type == Effects.rel.value:
effect: float = (self.data_stats["mean_b"] - self.data_stats["mean_a"]) / self.data_stats["mean_a"]

else:
raise NotImplementedError("Will be implemented later")
raise ValueError(self._send_type_error_msg())

return effect
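With these hunks, the relative paired criterion is functional end to end: absolute effects go through the one-sample t-test on the differences, relative effects through the delta method on the cached statistics. A usage sketch under stated assumptions: the effect-type strings are taken to be "absolute" and "relative" per implemented_effect_types above, and group_a/group_b are assumed to be equally sized Spark DataFrames with a numeric "metric" column.

# Hedged usage sketch; argument names follow the signatures in this diff.
import numpy as np
from ambrosia.spark_tools.stat_criteria import TtestRelativeCriterionSpark

criterion = TtestRelativeCriterionSpark()
# group_a, group_b: equally sized Spark DataFrames (assumed defined elsewhere).
p_value = criterion.calculate_pvalue(group_a, group_b, column="metric", effect_type="relative")
intervals = criterion.calculate_conf_interval(
    group_a, group_b, column="metric", alpha=np.array([0.05]), effect_type="relative"
)
effect = criterion.calculate_effect(group_a, group_b, column="metric", effect_type="relative")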
4 changes: 2 additions & 2 deletions ambrosia/spark_tools/stratification.py
@@ -16,14 +16,14 @@

import ambrosia.tools.ab_abstract_component as ab_abstract
from ambrosia import types
from ambrosia.spark_tools.constants import EMPTY_VALUE_PARTITION
from ambrosia.tools.import_tools import spark_installed

if spark_installed():
import pyspark.sql.functions as spark_funcs
from pyspark.sql import Window


EMPTY_VALUE: int = 0
STRAT_GROUPS: str = "__ambrosia_strat"


@@ -38,7 +38,7 @@ def fit(self, dataframe: types.SparkDataFrame, columns: Optional[Iterable[types.
self.strats = {ab_abstract.EmptyStratValue.NO_STRATIFICATION: dataframe}
return

window = Window.orderBy(*columns).partitionBy(spark_funcs.lit(EMPTY_VALUE))
window = Window.orderBy(*columns).partitionBy(spark_funcs.lit(EMPTY_VALUE_PARTITION))
with_groups = dataframe.withColumn(STRAT_GROUPS, spark_funcs.dense_rank().over(window))
amount_of_strats: int = with_groups.select(spark_funcs.max(STRAT_GROUPS)).collect()[0][0]

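Here the same constant swap applies to the stratification window: dense_rank() over a single constant partition numbers each distinct combination of the stratification columns. A minimal sketch with illustrative data:

# Sketch of dense_rank-based strata numbering (illustrative only).
import pyspark.sql.functions as F
from pyspark.sql import SparkSession, Window

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame(
    [("android", 1), ("ios", 2), ("android", 3), ("ios", 4)], ["platform", "id"]
)

window = Window.orderBy("platform").partitionBy(F.lit(0))
# Rows sharing the same "platform" value land in the same stratum id.
with_groups = df.withColumn("__ambrosia_strat", F.dense_rank().over(window))
amount_of_strats = with_groups.select(F.max("__ambrosia_strat")).collect()[0][0]  # 2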
2 changes: 1 addition & 1 deletion ambrosia/tools/configs.py
@@ -39,7 +39,7 @@ def get_all_enum_values(cls) -> tp.List[str]:
@classmethod
def raise_if_value_incorrect_enum(cls, value: tp.Any) -> None:
if not cls.check_value_in_enum(value):
msg: str = f"Choose value from " + ", ".join(cls.get_all_enum_values())
msg: str = f"Choose value from {', '.join(cls.get_all_enum_values())}, your value - {value}"

raise ValueError(msg)


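The rewritten message drops the stray f-prefix on a placeholder-free string and now reports the offending value alongside the allowed ones. A quick illustration of the resulting text, with hypothetical enum values:

# Illustration of the improved error message (hypothetical values).
values = ["absolute", "relative"]
value = "relativ"  # a typo a caller might pass
msg = f"Choose value from {', '.join(values)}, your value - {value}"
print(msg)  # Choose value from absolute, relative, your value - relativ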
37 changes: 28 additions & 9 deletions ambrosia/tools/stat_criteria.py
@@ -118,7 +118,30 @@ def get_results(
return super().get_results(group_a, group_b, alpha, effect_type, **kwargs)


class TtestRelCriterion(ABStatCriterion):
class TtestRelHelpful:
def _build_intervals_absolute_from_stats(
self,
center: float,
sd_1: float,
n_obs: int,
alpha: types.StatErrorType = np.array([0.05]),
alternative: str = "two-sided",
):
"""
Helps handle different alternatives and build the confidence interval
for related samples
"""
alpha_corrected: float = pvalue_pkg.corrected_alpha(alpha, alternative)
std_error = sd_1 / np.sqrt(n_obs)
quantiles = sps.t.ppf(1 - alpha_corrected / 2, df=n_obs - 1)
left_ci: float = center - quantiles * std_error
right_ci: float = center + quantiles * std_error
left_ci, right_ci = pvalue_pkg.choose_from_bounds(left_ci, right_ci, alternative)
conf_intervals = list(zip(left_ci, right_ci))
return conf_intervals


class TtestRelCriterion(ABStatCriterion, TtestRelHelpful):
"""
Unit for relative paired T-test.
"""
@@ -149,14 +172,10 @@ def _build_intervals_absolute(
Helps handle different alternatives and build the confidence interval
for related samples
"""
alpha_corrected: float = pvalue_pkg.corrected_alpha(alpha, alternative)
std_error = np.sqrt(np.var(group_b - group_a, ddof=1) / len(group_a))
quantiles = sps.t.ppf(1 - alpha_corrected / 2, df=len(group_a) - 1)
left_ci: float = center - quantiles * std_error
right_ci: float = center + quantiles * std_error
left_ci, right_ci = pvalue_pkg.choose_from_bounds(left_ci, right_ci, alternative)
conf_intervals = list(zip(left_ci, right_ci))
return conf_intervals
sd_1: float = np.sqrt(np.var(group_b - group_a, ddof=1))
return self._build_intervals_absolute_from_stats(
center=center, sd_1=sd_1, n_obs=len(group_a), alpha=alpha, alternative=alternative
)

def calculate_conf_interval(
self,
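The extraction into TtestRelHelpful lets the Spark criterion reuse the interval construction from cached statistics instead of raw arrays; only the input changes (sd_1 and n_obs are passed in rather than derived from group_b - group_a). A self-contained numpy/scipy sketch of the two-sided case, center ± t(1 - alpha/2, n_obs - 1) * sd_1 / sqrt(n_obs); the helper additionally corrects alpha and picks bounds for one-sided alternatives:

# Two-sided sketch of the interval built from cached stats (illustrative numbers).
import numpy as np
import scipy.stats as sps

center, sd_1, n_obs = 0.4, 1.2, 100  # illustrative paired-difference stats
alpha = np.array([0.05])

std_error = sd_1 / np.sqrt(n_obs)
quantiles = sps.t.ppf(1 - alpha / 2, df=n_obs - 1)
left_ci = center - quantiles * std_error
right_ci = center + quantiles * std_error
conf_intervals = list(zip(left_ci, right_ci))  # [(~0.162, ~0.638)]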
14 changes: 12 additions & 2 deletions ambrosia/tools/theoretical_tools.py
@@ -34,6 +34,15 @@
ROUND_DIGITS_PERCENT: int = 1


def switch_alternative(alternative: str) -> str:
Alternatives.raise_if_value_incorrect_enum(alternative)
if alternative == Alternatives.ts.value:
return alternative
if alternative == Alternatives.less.value:
return Alternatives.gr.value
return Alternatives.less.value


def get_stats(values: Iterable[float], ddof: int = 1) -> Tuple[float, float]:
"""
Calculate the mean and standard deviation for a list of values.
@@ -60,12 +69,13 @@ def check_encode_alternative(alternative: str) -> str:
return statsmodels_alternatives_encoding[alternative]


def unbiased_to_sufficient(std: float, size: int) -> float:
def unbiased_to_sufficient(std: float, size: int, is_std: bool = True) -> float:
"""
Transforms an unbiased estimate of the standard deviation to a sufficient one
(ddof = 1) => (ddof = 0).
If is_std is True, a std is transformed, otherwise a variance.
"""
return std * np.sqrt((size - 1) / size)
return std * np.sqrt((size - 1) / size) if is_std else std * (size - 1) / size


def check_target_type(
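Two utilities round out the commit: switch_alternative mirrors a one-sided alternative ("less" <-> "greater", two-sided unchanged), and unbiased_to_sufficient gains an is_std flag so it can rescale a variance (factor (n - 1) / n) as well as a standard deviation (factor sqrt((n - 1) / n)). A quick numeric check of the rescaling with illustrative data:

# Numeric check of the std vs. variance rescaling (ddof = 1 -> ddof = 0).
import numpy as np

x = np.array([1.0, 2.0, 4.0])
n = len(x)

assert np.isclose(x.std(ddof=1) * np.sqrt((n - 1) / n), x.std(ddof=0))
assert np.isclose(x.var(ddof=1) * (n - 1) / n, x.var(ddof=0))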
