In [20]:
from typing import Union, Iterable

import numpy as np
import pandas as pd

from hypex import AATest
from hypex.utils.tutorial_data_creation import create_test_data

np.random.seed(42)  # seed NumPy's global RNG so the example is reproducible (RandomModel.predict draws from it)
# Render floats in every displayed frame with thousands separators and two decimals.
pd.set_option("display.float_format", "{:,.2f}".format)

In [51]:
class RandomModel:
    """Toy "model" that perturbs a target column with group-wise Gaussian noise.

    ``fit`` records the mean and sample standard deviation of ``target_field``
    over the whole frame and, when ``group_fields`` is given, for every group.
    ``predict`` returns the target plus noise drawn from
    ``N(mean * diff_percentage, sd * sd_percentage)``, using the statistics of
    the row's group and falling back to the whole-frame statistics for groups
    not seen during ``fit``. Noise comes from NumPy's global random state, so
    ``np.random.seed`` makes results reproducible.

    Attributes:
        params: Fitted statistics. The key ``None`` holds the whole-frame
            ``{"mean", "sd"}``; every other key is a group label exactly as
            produced by iterating ``DataFrame.groupby(group_fields)``.
    """

    # Substituted for missing group values so NaN rows form their own group
    # instead of being silently dropped by groupby.
    NA_LABEL = "None"

    def __init__(
        self,
        target_field: str,
        group_fields: Union[str, Iterable[str]] = None,
        diff_percentage: float = 0.01,
        sd_percentage: float = 0.1,
    ):
        """
        Args:
            target_field: Name of the numeric column to perturb.
            group_fields: A column name, an iterable of column names, or None
                to use only whole-frame statistics.
            diff_percentage: Noise mean as a fraction of the (group) mean.
            sd_percentage: Noise standard deviation as a fraction of the
                (group) standard deviation.
        """
        # A tuple would be read by pandas as a single (tuple) column key, and a
        # generator can only be consumed once, so materialise non-str iterables.
        if group_fields is not None and not isinstance(group_fields, str):
            group_fields = list(group_fields)
        self.target_field = target_field
        self.group_fields = group_fields
        self.diff_percentage = diff_percentage
        self.sd_percentage = sd_percentage

        self.params = {None: {}}

    def fill_na(self, data: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of ``data`` with missing group values set to ``NA_LABEL``."""
        t_data = data.copy()
        if self.group_fields is not None:
            t_data[self.group_fields] = t_data[self.group_fields].fillna(self.NA_LABEL)
        return t_data

    def _stats(self, frame: pd.DataFrame) -> dict:
        """Mean and sample standard deviation of the target column of ``frame``."""
        target = frame[self.target_field]
        return {"mean": target.mean(), "sd": target.std()}

    def fit(self, data: pd.DataFrame) -> "RandomModel":
        """Compute whole-frame and per-group target statistics; return ``self``."""
        t_data = self.fill_na(data)
        # Build a fresh dict so refitting does not retain groups from older data.
        # Whole-frame stats live under the None key, which can never collide with
        # a group label (missing labels are filled with the string "None").
        params = {None: self._stats(t_data)}
        if self.group_fields is not None:
            for group, group_data in t_data.groupby(self.group_fields):
                params[group] = self._stats(group_data)
        self.params = params
        return self

    def _draw_noise(self, stats: dict, size: int) -> np.ndarray:
        """Draw ``size`` Gaussian noise values parameterised by ``stats``."""
        sd = stats["sd"]
        # A single-row group has an undefined (NaN) sd; borrow the whole-frame sd,
        # and use no spread at all if even that is undefined.
        if pd.isna(sd):
            sd = self.params[None]["sd"]
        if pd.isna(sd):
            sd = 0.0
        return np.random.normal(
            stats["mean"] * self.diff_percentage,
            sd * self.sd_percentage,
            size,
        )

    def predict(self, data: pd.DataFrame) -> pd.Series:
        """Return ``data[target_field]`` plus group-wise random noise.

        The result keeps ``data``'s index. Raises KeyError if called before ``fit``.
        """
        t_data = self.fill_na(data)
        if self.group_fields is None:
            noise = self._draw_noise(self.params[None], len(t_data))
        else:
            noise = np.empty(len(t_data))
            # Group on a 0..n-1 index so each group's rows are array positions;
            # this stays correct when `data` has a duplicated index.
            positional = t_data.reset_index(drop=True)
            for group, group_data in positional.groupby(self.group_fields):
                rows = group_data.index.to_numpy()
                stats = self.params.get(group, self.params[None])
                noise[rows] = self._draw_noise(stats, len(rows))
        return t_data[self.target_field] + noise

In [52]:
# Synthetic demo data; `nan_cols` / `na_step` inject missing values into `age` and `gender`.
data = create_test_data(
    rs=52,
    na_step=10,
    nan_cols=["age", "gender"],
)
data

In [53]:
# Fit per-(age, gender, industry) statistics of post_spends, then add random noise.
model = RandomModel(
    target_field="post_spends",
    group_fields=["age", "gender", "industry"],
).fit(data)
data["corrected_spends"] = model.predict(data)
data

Unnamed: 0,user_id,signup_month,treat,pre_spends,post_spends,age,gender,industry,corrected_spends
0,0,0,0,488.00,414.44,,M,E-commerce,413.60
1,1,8,1,512.50,462.22,26.00,,E-commerce,462.97
2,2,7,1,483.00,479.44,25.00,M,Logistics,481.38
3,3,0,0,501.50,424.33,39.00,M,E-commerce,433.31
4,4,1,1,543.00,514.56,18.00,F,E-commerce,515.90
...,...,...,...,...,...,...,...,...,...
9995,9995,10,1,538.50,450.44,42.00,M,Logistics,455.44
9996,9996,0,0,500.50,430.89,26.00,F,Logistics,430.69
9997,9997,3,1,473.00,534.11,22.00,F,E-commerce,540.99
9998,9998,2,1,495.00,523.22,67.00,F,E-commerce,524.65
