# Matching pipeline

Matching is used in statistical analysis to remove distortions caused by differences in the baseline characteristics of the groups being compared. Put simply, matching helps ensure that the observed result is caused by the effect under study rather than by external factors.

Matching is most often used when a standard A/B test is not possible.

In [117]:

from hypex import Matching
from hypex.dataset import (
    Dataset,
    FeatureRole,
    GroupingRole,
    InfoRole,
    TargetRole,
    TreatmentRole,
)

## Data preparation 

Mark the data fields by assigning the appropriate roles:

* FeatureRole: a role for columns that contain features or predictor variables. Matching is based on them. Applied by default if no role is specified for a column.
* TreatmentRole: a role for columns that indicate the treatment or intervention.
* TargetRole: a role for columns that contain the target or outcome variable.
* InfoRole: a role for columns that contain information about the data, such as user IDs.
* GroupingRole: a role for columns that define groups; it is used in the Group Matching section to restrict matches to rows of the same group.

In [118]:
data = Dataset(
    roles={
        "user_id": InfoRole(int),
        "treat": TreatmentRole(int),
        "post_spends": TargetRole(float),
        # "gender": FeatureRole(str),
        # "pre_spends": FeatureRole(float),
        # "industry": FeatureRole(str),
    },
    data="data.csv",
    default_role=FeatureRole(),
)
data

Unnamed: 0,user_id,signup_month,treat,pre_spends,post_spends,age,gender,industry
0,0,0,0,488.0,414.444444,,M,E-commerce
1,1,8,1,512.5,462.222222,26.0,,E-commerce
2,2,7,1,483.0,479.444444,25.0,M,Logistics
3,3,0,0,501.5,424.333333,39.0,M,E-commerce
4,4,1,1,543.0,514.555556,18.0,F,E-commerce
...,...,...,...,...,...,...,...,...
9995,9995,10,1,538.5,450.444444,42.0,M,Logistics
9996,9996,0,0,500.5,430.888889,26.0,F,Logistics
9997,9997,3,1,473.0,534.111111,22.0,F,E-commerce
9998,9998,2,1,495.0,523.222222,67.0,F,E-commerce


In [119]:
data.roles

{'user_id': Info(<class 'int'>),
 'treat': Treatment(<class 'int'>),
 'post_spends': Target(<class 'float'>),
 'signup_month': Feature(<class 'int'>),
 'pre_spends': Feature(<class 'float'>),
 'age': Feature(<class 'float'>),
 'gender': Feature(<class 'str'>),
 'industry': Feature(<class 'str'>)}
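
The role mapping above shows the three explicit roles first, followed by every remaining column tagged as a feature. The fallback logic can be pictured with a minimal sketch. It uses plain strings instead of HypEx role objects, and `resolve_roles` is a hypothetical helper written for illustration, not part of the library.

```python
# Illustrative only: mimics how unspecified columns fall back to a default role.
# Role names are plain strings here, not HypEx role objects.

def resolve_roles(columns, explicit_roles, default_role="Feature"):
    """Return a role for every column: explicit roles first, then defaults."""
    resolved = dict(explicit_roles)
    for column in columns:
        # Only fill in columns that were not given an explicit role.
        resolved.setdefault(column, default_role)
    return resolved


columns = ["user_id", "signup_month", "treat", "pre_spends",
           "post_spends", "age", "gender", "industry"]
explicit = {"user_id": "Info", "treat": "Treatment", "post_spends": "Target"}

roles = resolve_roles(columns, explicit)
for name, role in roles.items():
    print(f"{name}: {role}")
```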

## Simple Matching  
Matching currently runs in four steps:
1. Dummy encoding of categorical features
2. Mahalanobis distance computation
3. Two-sided pair search with Faiss
4. Estimation of the metrics (ATT, ATC, ATE), depending on your data
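
The four steps can be sketched on synthetic data with NumPy alone. This is a simplified illustration, not HypEx's implementation: the data are randomly generated, a brute-force search stands in for Faiss, and no bias correction is applied. All variable names are chosen for this example.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic toy data (for illustration only); the true effect is 50.
n = 200
age = rng.integers(18, 70, size=n).astype(float)
industry = rng.choice(["E-commerce", "Logistics"], size=n)
treat = rng.integers(0, 2, size=n)
outcome = 400 + 0.5 * age + 50.0 * treat + rng.normal(0, 5, size=n)

# Step 1: dummy-encode the categorical feature (one level is dropped).
industry_dummy = (industry == "Logistics").astype(float)
X = np.column_stack([age, industry_dummy])

# Step 2: Mahalanobis transform. With cov = L @ L.T, the Euclidean distance
# between rows of inv(L) @ x equals the Mahalanobis distance between the x's.
cov = np.cov(X, rowvar=False)
L = np.linalg.cholesky(cov)
X_m = np.linalg.solve(L, X.T).T


# Step 3: two-sided nearest-neighbour search (brute force instead of Faiss).
def nearest(source, target):
    """For each row of `source`, return the index of the closest row in `target`."""
    d2 = ((source[:, None, :] - target[None, :, :]) ** 2).sum(axis=2)
    return d2.argmin(axis=1)


t_idx = np.flatnonzero(treat == 1)
c_idx = np.flatnonzero(treat == 0)
match_for_treated = c_idx[nearest(X_m[t_idx], X_m[c_idx])]
match_for_control = t_idx[nearest(X_m[c_idx], X_m[t_idx])]

# Step 4: effect estimates from the matched pairs.
att = np.mean(outcome[t_idx] - outcome[match_for_treated])
atc = np.mean(outcome[match_for_control] - outcome[c_idx])
ate = (len(t_idx) * att + len(c_idx) * atc) / n
print(f"ATT={att:.2f}  ATC={atc:.2f}  ATE={ate:.2f}")
```

Because the search runs in both directions, every treated row gets a control partner and every control row gets a treated partner, which is what allows ATT, ATC and ATE to be estimated together.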

In [120]:
data = data.fillna(method="bfill")
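
Backward fill replaces each missing value with the next non-missing value further down the column. The sketch below is a plain-Python illustration of that rule, assuming `fillna(method="bfill")` follows the usual pandas semantics; `bfill` is a helper defined only for this example. The sample values are the first five `age` and `gender` entries shown earlier.

```python
def bfill(values):
    """Backward fill: replace each None with the next non-missing value."""
    filled = list(values)
    next_valid = None
    # Walk from the end so the "next" valid value is always known.
    for i in range(len(filled) - 1, -1, -1):
        if filled[i] is None:
            filled[i] = next_valid
        else:
            next_valid = filled[i]
    return filled


age = [None, 26.0, 25.0, 39.0, 18.0]
gender = ["M", None, "M", "M", "F"]

print(bfill(age))     # [26.0, 26.0, 25.0, 39.0, 18.0]
print(bfill(gender))  # ['M', 'M', 'M', 'M', 'F']
```

A missing value at the very end of a column has nothing below it to copy from and stays missing.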

In [121]:
test = Matching(quality_tests=["t-test", "ks-test"])
result = test.execute(data)

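**ATT** (average treatment effect on the treated) is the estimated effect for the treated group.   
**ATC** (average treatment effect on the control) is the estimated effect for the untreated group.   
**ATE** (average treatment effect) is the weighted average of ATT and ATC.  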
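
In standard potential-outcome notation the three quantities can be written as follows. The notation is added here for reference and is not taken from the HypEx documentation: $Y(1)$ and $Y(0)$ are the outcomes with and without treatment, $T$ is the treatment indicator, and $p$ is the share of treated units.

```latex
\begin{aligned}
\mathrm{ATT} &= \mathbb{E}\,[\,Y(1) - Y(0) \mid T = 1\,] \\
\mathrm{ATC} &= \mathbb{E}\,[\,Y(1) - Y(0) \mid T = 0\,] \\
\mathrm{ATE} &= \mathbb{E}\,[\,Y(1) - Y(0)\,] = p\,\mathrm{ATT} + (1 - p)\,\mathrm{ATC},
\qquad p = P(T = 1)
\end{aligned}
```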

In [88]:
result.resume

Unnamed: 0,Effect Size,Standard Error,P-value,CI Lower,CI Upper,outcome
ATT,63.48,1.08,0.0,61.36,65.6,post_spends
ATC,97.71,1.23,0.0,95.3,100.13,post_spends
ATE,80.82,0.78,0.0,79.29,82.35,post_spends
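
The reported intervals are consistent with a normal-approximation 95% confidence interval, effect ± 1.96 × standard error. The check below uses the rounded values from the table above. The 1.96 multiplier is an assumption inferred from the numbers, not something the output states.

```python
# Values copied from the summary table above: effect, SE, CI lower, CI upper.
rows = {
    "ATT": (63.48, 1.08, 61.36, 65.60),
    "ATC": (97.71, 1.23, 95.30, 100.13),
    "ATE": (80.82, 0.78, 79.29, 82.35),
}


def normal_ci(effect, se, z=1.96):
    """Two-sided normal-approximation interval (z=1.96 is assumed)."""
    return effect - z * se, effect + z * se


for name, (effect, se, lo, hi) in rows.items():
    calc_lo, calc_hi = normal_ci(effect, se)
    print(f"{name}: computed [{calc_lo:.2f}, {calc_hi:.2f}], "
          f"reported [{lo:.2f}, {hi:.2f}]")
```

Small differences in the last digit are expected because the table values are already rounded.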


In [89]:
result.full_data

Unnamed: 0,user_id,signup_month,treat,pre_spends,post_spends,age,gender,industry,user_id_matched_0,signup_month_matched_0,treat_matched_0,pre_spends_matched_0,post_spends_matched_0,age_matched_0,gender_matched_0,industry_matched_0
0,0,0,0,488.0,414.444444,26.0,M,E-commerce,4297,2,1,482.0,519.666667,26.0,M,E-commerce
1,1,8,1,512.5,462.222222,26.0,M,E-commerce,4961,0,0,526.5,416.666667,23.0,M,E-commerce
2,2,7,1,483.0,479.444444,25.0,M,Logistics,3594,0,0,499.5,424.666667,25.0,M,Logistics
3,3,0,0,501.5,424.333333,39.0,M,E-commerce,3987,1,1,507.5,522.000000,34.0,M,E-commerce
4,4,1,1,543.0,514.555556,18.0,F,E-commerce,539,0,0,531.0,414.000000,20.0,F,E-commerce
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,9995,10,1,538.5,450.444444,42.0,M,Logistics,5517,0,0,529.5,427.555556,44.0,M,Logistics
9996,9996,0,0,500.5,430.888889,26.0,F,Logistics,924,1,1,503.0,531.555556,27.0,F,Logistics
9997,9997,3,1,473.0,534.111111,22.0,F,E-commerce,4237,0,0,479.5,430.333333,23.0,F,E-commerce
9998,9998,2,1,495.0,523.222222,67.0,F,E-commerce,754,0,0,499.0,397.666667,66.0,F,E-commerce


In [90]:
result.quality_results

Unnamed: 0,feature,group,TTest pass,TTest p-value,KSTest pass,KSTest p-value
0,signup_month,0┆signup_month,OK,0.0,OK,0.0
1,signup_month,1┆signup_month,OK,0.0,OK,0.0
2,pre_spends,0┆pre_spends,OK,5.073003e-07,OK,2.938172e-24
3,pre_spends,1┆pre_spends,OK,7.154501000000001e-204,OK,8.668437e-232
4,age,0┆age,OK,0.8563897,OK,1.546155e-05
5,age,1┆age,OK,0.7210192,OK,0.8134795


In [91]:
result.indexes

Unnamed: 0,indexes_0
0,4297
1,4961
2,3594
3,3987
4,539
...,...
9995,5517
9996,924
9997,4237
9998,754


In [92]:
result.full_data.roles

{'user_id': Info(<class 'int'>),
 'treat': Treatment(<class 'int'>),
 'post_spends': Target(<class 'float'>),
 'signup_month': Feature(<class 'int'>),
 'pre_spends': Feature(<class 'float'>),
 'age': Feature(<class 'float'>),
 'gender': Feature(<class 'str'>),
 'industry': Feature(<class 'str'>),
 'user_id_matched_0': Info(<class 'int'>),
 'treat_matched_0': Treatment(<class 'int'>),
 'post_spends_matched_0': Target(<class 'float'>),
 'signup_month_matched_0': Feature(<class 'int'>),
 'pre_spends_matched_0': Feature(<class 'float'>),
 'age_matched_0': Feature(<class 'float'>),
 'gender_matched_0': Feature(<class 'str'>),
 'industry_matched_0': Feature(<class 'str'>)}

We can change the **metric** and run the estimation again. With `metric="atc"`, only **ATC** is estimated.

In [93]:
test = Matching(metric="atc")
result = test.execute(data)



In [94]:
result.resume

Unnamed: 0,Effect Size,Standard Error,P-value,CI Lower,CI Upper,outcome
ATC,97.71,0.15,0.0,97.42,98.01,post_spends


In [95]:
result.quality_results

Unnamed: 0,feature,group,TTest pass,TTest p-value,KSTest pass,KSTest p-value,Chi2Test pass,Chi2Test p-value
0,signup_month,0┆signup_month,OK,0.0,OK,0.0,,
1,signup_month,1┆signup_month,OK,3.403901e-108,OK,0.0,,
2,pre_spends,0┆pre_spends,OK,5.073003e-07,OK,2.938172e-24,,
3,pre_spends,1┆pre_spends,OK,0.0,OK,0.0,,
4,age,0┆age,OK,0.8563897,OK,1.546155e-05,,
5,age,1┆age,OK,4.893333e-145,OK,0.0,,
6,gender,0┆gender,,,,,OK,1.0
7,gender,1┆gender,,,,,OK,
8,industry,0┆industry,,,,,OK,1.0
9,industry,1┆industry,,,,,OK,


In [96]:
result.indexes

Unnamed: 0,indexes_0
0,4297
1,-1
2,-1
3,3987
4,-1
...,...
9995,-1
9996,924
9997,-1
9998,-1
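
Judging from the output above, `-1` appears to mark rows for which no pair was searched. Here those are the treated rows, since only ATC is estimated. That reading is inferred from the data and is not a documented guarantee. Under that assumption, the matched pairs can be extracted by filtering out the sentinel, as in this sketch built from the first five rows.

```python
# First five rows of `treat` (from the dataset preview) and of `indexes_0`.
treat = [0, 1, 1, 0, 1]
indexes_0 = [4297, -1, -1, 3987, -1]

# Keep only rows that actually received a partner.
pairs = [(row, partner) for row, partner in enumerate(indexes_0) if partner != -1]
print(pairs)  # [(0, 4297), (3, 3987)]

# In this excerpt every matched row belongs to the control group (treat == 0).
matched_are_control = all(treat[row] == 0 for row, _ in pairs)
print(matched_are_control)  # True
```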


Pairs can also be searched for the **test group** only. With `metric="att"`, only **ATT** is estimated.

In [97]:
test = Matching(metric='att')
result = test.execute(data)



In [98]:
result.resume

Unnamed: 0,Effect Size,Standard Error,P-value,CI Lower,CI Upper,outcome
ATT,63.48,0.47,0.0,62.57,64.4,post_spends


In [99]:
result.indexes

Unnamed: 0,indexes_0
0,-1
1,4961
2,3594
3,-1
4,539
...,...
9995,5517
9996,-1
9997,4237
9998,754


In [100]:
result.full_data

Unnamed: 0,user_id,signup_month,treat,pre_spends,post_spends,age,gender,industry,user_id_matched_0,signup_month_matched_0,treat_matched_0,pre_spends_matched_0,post_spends_matched_0,age_matched_0,gender_matched_0,industry_matched_0
0,0,0,0,488.0,414.444444,26.0,M,E-commerce,,,,,,,,
1,1,8,1,512.5,462.222222,26.0,M,E-commerce,4961.0,0.0,0.0,526.5,416.666667,23.0,M,E-commerce
2,2,7,1,483.0,479.444444,25.0,M,Logistics,3594.0,0.0,0.0,499.5,424.666667,25.0,M,Logistics
3,3,0,0,501.5,424.333333,39.0,M,E-commerce,,,,,,,,
4,4,1,1,543.0,514.555556,18.0,F,E-commerce,539.0,0.0,0.0,531.0,414.000000,20.0,F,E-commerce
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,9995,10,1,538.5,450.444444,42.0,M,Logistics,5517.0,0.0,0.0,529.5,427.555556,44.0,M,Logistics
9996,9996,0,0,500.5,430.888889,26.0,F,Logistics,,,,,,,,
9997,9997,3,1,473.0,534.111111,22.0,F,E-commerce,4237.0,0.0,0.0,479.5,430.333333,23.0,F,E-commerce
9998,9998,2,1,495.0,523.222222,67.0,F,E-commerce,754.0,0.0,0.0,499.0,397.666667,66.0,F,E-commerce


Finally, pairs can be searched using the L2 (Euclidean) distance instead of the Mahalanobis distance.

In [101]:
test = Matching(distance="l2", metric='att')
result = test.execute(data)



In [102]:
result.resume

Unnamed: 0,Effect Size,Standard Error,P-value,CI Lower,CI Upper,outcome
ATT,63.37,0.46,0.0,62.46,64.27,post_spends


In [103]:
result.indexes

Unnamed: 0,indexes_0
0,-1
1,2490
2,5493
3,-1
4,321
...,...
9995,5893
9996,-1
9997,8670
9998,507


In [104]:
result.full_data

Unnamed: 0,user_id,signup_month,treat,pre_spends,post_spends,age,gender,industry,user_id_matched_0,signup_month_matched_0,treat_matched_0,pre_spends_matched_0,post_spends_matched_0,age_matched_0,gender_matched_0,industry_matched_0
0,0,0,0,488.0,414.444444,26.0,M,E-commerce,,,,,,,,
1,1,8,1,512.5,462.222222,26.0,M,E-commerce,2490.0,0.0,0.0,511.5,417.444444,27.0,F,E-commerce
2,2,7,1,483.0,479.444444,25.0,M,Logistics,5493.0,0.0,0.0,483.0,408.000000,25.0,M,E-commerce
3,3,0,0,501.5,424.333333,39.0,M,E-commerce,,,,,,,,
4,4,1,1,543.0,514.555556,18.0,F,E-commerce,321.0,0.0,0.0,538.0,421.444444,29.0,M,E-commerce
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,9995,10,1,538.5,450.444444,42.0,M,Logistics,5893.0,0.0,0.0,535.0,414.555556,40.0,M,E-commerce
9996,9996,0,0,500.5,430.888889,26.0,F,Logistics,,,,,,,,
9997,9997,3,1,473.0,534.111111,22.0,F,E-commerce,8670.0,0.0,0.0,473.0,415.777778,22.0,F,Logistics
9998,9998,2,1,495.0,523.222222,67.0,F,E-commerce,507.0,0.0,0.0,495.0,429.777778,67.0,F,Logistics
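
In the L2 results above, several matched rows agree exactly on `pre_spends` while `gender` or `industry` differ. That is what one would expect when the feature with the largest numeric scale dominates an unscaled distance. The sketch below contrasts the two distances on made-up points; the covariance matrix is hypothetical and chosen only to show the effect.

```python
import numpy as np


def l2_distance(x, y):
    """Plain Euclidean distance."""
    return float(np.linalg.norm(np.asarray(x, float) - np.asarray(y, float)))


def mahalanobis_distance(x, y, cov):
    """Distance that rescales each direction by the feature covariance."""
    diff = np.asarray(x, float) - np.asarray(y, float)
    return float(np.sqrt(diff @ np.linalg.solve(cov, diff)))


# Hypothetical covariance for (pre_spends, age): spends vary more than age.
cov = np.array([[400.0, 0.0],
                [0.0, 100.0]])

a = [500.0, 30.0]
b = [510.0, 30.0]  # differs from a by 10 in pre_spends
c = [500.0, 40.0]  # differs from a by 10 in age

# L2 treats both differences as equally large.
print(l2_distance(a, b), l2_distance(a, c))
# Mahalanobis: 10 units of pre_spends is half a standard deviation,
# 10 units of age is a full one, so c is farther from a than b is.
print(mahalanobis_distance(a, b, cov), mahalanobis_distance(a, c, cov))
```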


## Group Matching

Finds matches strictly within the groups defined by `GroupingRole`.
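
The idea of matching strictly within groups can be sketched as below. This is an illustration with a brute-force search on toy data, not HypEx's implementation. `match_within_groups` is a hypothetical helper, and the `-1` sentinel for unmatched rows is a convention chosen for this example.

```python
import numpy as np


def match_within_groups(X, treat, groups):
    """For each treated row, find the nearest control row of the same group.

    Returns partner indexes; -1 for control rows and for treated rows
    whose group contains no control row.
    """
    X = np.asarray(X, dtype=float)
    treat = np.asarray(treat)
    groups = np.asarray(groups)
    partner = np.full(len(X), -1)
    for g in np.unique(groups):
        t_idx = np.flatnonzero((groups == g) & (treat == 1))
        c_idx = np.flatnonzero((groups == g) & (treat == 0))
        if len(t_idx) == 0 or len(c_idx) == 0:
            continue  # nothing to match inside this group
        d2 = ((X[t_idx][:, None, :] - X[c_idx][None, :, :]) ** 2).sum(axis=2)
        partner[t_idx] = c_idx[d2.argmin(axis=1)]
    return partner


# Toy data: one numeric feature, a treatment flag and a grouping column.
X = [[30.0], [31.0], [50.0], [29.0], [52.0], [30.5]]
treat = [1, 0, 0, 1, 1, 0]
gender = ["F", "F", "M", "M", "M", "F"]

partner = match_within_groups(X, treat, gender)
print(partner.tolist())  # [5, -1, -1, 2, 2, -1]
```

Row 3 (feature value 29) would be closest to row 5 (30.5) in an unrestricted search, but because row 5 belongs to another group, it is paired with row 2 instead.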

In [105]:
data = Dataset(
    roles={
        "user_id": InfoRole(int),
        "treat": TreatmentRole(int),
        "post_spends": TargetRole(float),
        "gender": GroupingRole(str),
        "industry": FeatureRole(),
    },
    data="data.csv",
    # default_role=FeatureRole(),
)
data

Unnamed: 0,user_id,signup_month,treat,pre_spends,post_spends,age,gender,industry
0,0,0,0,488.0,414.444444,,M,E-commerce
1,1,8,1,512.5,462.222222,26.0,,E-commerce
2,2,7,1,483.0,479.444444,25.0,M,Logistics
3,3,0,0,501.5,424.333333,39.0,M,E-commerce
4,4,1,1,543.0,514.555556,18.0,F,E-commerce
...,...,...,...,...,...,...,...,...
9995,9995,10,1,538.5,450.444444,42.0,M,Logistics
9996,9996,0,0,500.5,430.888889,26.0,F,Logistics
9997,9997,3,1,473.0,534.111111,22.0,F,E-commerce
9998,9998,2,1,495.0,523.222222,67.0,F,E-commerce


In [106]:
data = data.fillna(method="bfill")
test = Matching(group_match=True, n_neighbors=2)
result = test.execute(data)



In [107]:
result.resume

Unnamed: 0,F Effect Size,M Effect Size,F Standard Error,M Standard Error,F P-value,M P-value,F CI Lower,M CI Lower,F CI Upper,M CI Upper,outcome
ATT,59.8,64.36,26.86,8.39,0.03,0.0,7.15,47.92,112.45,80.81,post_spends
ATC,54.0,50.94,24.31,22.47,0.03,0.02,6.34,6.91,101.66,94.98,post_spends
ATE,56.79,57.75,18.19,12.12,0.0,0.0,21.13,34.0,92.44,81.5,post_spends


In [108]:
result.indexes

Unnamed: 0,indexes_0,indexes_1
0,1,5
1,0,3
2,12,16
3,1,5
4,25,27
...,...,...
9995,12,16
9996,6,19
9997,25,27
9998,25,27


In [109]:
result.quality_results

In [110]:
result.full_data

Unnamed: 0,user_id,signup_month,treat,pre_spends,post_spends,age,gender,industry,user_id_matched_0,signup_month_matched_0,...,gender_matched_0,industry_matched_0,user_id_matched_1,signup_month_matched_1,treat_matched_1,pre_spends_matched_1,post_spends_matched_1,age_matched_1,gender_matched_1,industry_matched_1
0,0,0,0,488.0,414.444444,26.0,M,E-commerce,1,8,...,M,E-commerce,1,8,1,512.5,462.222222,26.0,M,E-commerce
1,1,8,1,512.5,462.222222,26.0,M,E-commerce,0,0,...,M,E-commerce,0,0,0,488.0,414.444444,26.0,M,E-commerce
2,2,7,1,483.0,479.444444,25.0,M,Logistics,12,0,...,M,Logistics,12,0,0,472.0,423.777778,42.0,M,Logistics
3,3,0,0,501.5,424.333333,39.0,M,E-commerce,1,8,...,M,E-commerce,1,8,1,512.5,462.222222,26.0,M,E-commerce
4,4,1,1,543.0,514.555556,18.0,F,E-commerce,25,0,...,F,E-commerce,25,0,0,499.5,425.777778,24.0,F,E-commerce
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,9995,10,1,538.5,450.444444,42.0,M,Logistics,12,0,...,M,Logistics,12,0,0,472.0,423.777778,42.0,M,Logistics
9996,9996,0,0,500.5,430.888889,26.0,F,Logistics,6,11,...,F,Logistics,6,11,1,483.5,433.888889,28.0,F,Logistics
9997,9997,3,1,473.0,534.111111,22.0,F,E-commerce,25,0,...,F,E-commerce,25,0,0,499.5,425.777778,24.0,F,E-commerce
9998,9998,2,1,495.0,523.222222,67.0,F,E-commerce,25,0,...,F,E-commerce,25,0,0,499.5,425.777778,24.0,F,E-commerce


## Custom feature weights

You can assign custom weights to features to tune the matching to your specific research needs.

In [111]:
data = Dataset(
    roles={
        "user_id": InfoRole(int),
        "treat": TreatmentRole(int),
        "post_spends": TargetRole(float),
        "gender": FeatureRole(str),
        "industry": FeatureRole(str),
        "age": FeatureRole(float),
        "signup_month": FeatureRole(int),
        "pre_spends": FeatureRole(float),
    },
    data="data.csv",
)
data

Unnamed: 0,user_id,signup_month,treat,pre_spends,post_spends,age,gender,industry
0,0,0,0,488.0,414.444444,,M,E-commerce
1,1,8,1,512.5,462.222222,26.0,,E-commerce
2,2,7,1,483.0,479.444444,25.0,M,Logistics
3,3,0,0,501.5,424.333333,39.0,M,E-commerce
4,4,1,1,543.0,514.555556,18.0,F,E-commerce
...,...,...,...,...,...,...,...,...
9995,9995,10,1,538.5,450.444444,42.0,M,Logistics
9996,9996,0,0,500.5,430.888889,26.0,F,Logistics
9997,9997,3,1,473.0,534.111111,22.0,F,E-commerce
9998,9998,2,1,495.0,523.222222,67.0,F,E-commerce


Make sure that every feature is assigned a weight and that the weights sum to 1.
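
Both conditions can be checked before running the matching. The helper below is a hypothetical pre-flight check written for this tutorial; HypEx may perform its own validation, and `validate_weights` is not part of its API.

```python
import math


def validate_weights(weights, features, tol=1e-9):
    """Raise ValueError unless every feature has a weight and they sum to 1."""
    missing = sorted(set(features) - set(weights))
    unknown = sorted(set(weights) - set(features))
    if missing or unknown:
        raise ValueError(f"missing weights: {missing}, unknown features: {unknown}")
    total = sum(weights.values())
    # Compare with a tolerance because decimal fractions are inexact in floats.
    if not math.isclose(total, 1.0, abs_tol=tol):
        raise ValueError(f"weights sum to {total}, expected 1")
    return True


features = ["gender", "industry", "age", "signup_month", "pre_spends"]
weights = {"gender": 0.2, "industry": 0.3, "age": 0.1,
           "signup_month": 0.1, "pre_spends": 0.3}

print(validate_weights(weights, features))  # True
```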

In [112]:
data = data.fillna(method="bfill")
test = Matching(weights={"gender": 0.2, "industry": 0.3, "age": 0.1, "signup_month": 0.1, "pre_spends": 0.3})
result = test.execute(data)

In [113]:
result.full_data

Unnamed: 0,user_id,signup_month,treat,pre_spends,post_spends,age,gender,industry,user_id_matched_0,signup_month_matched_0,treat_matched_0,pre_spends_matched_0,post_spends_matched_0,age_matched_0,gender_matched_0,industry_matched_0
0,0,0,0,488.0,414.444444,26.0,M,E-commerce,7507,2,1,487.0,510.666667,28.0,M,E-commerce
1,1,8,1,512.5,462.222222,26.0,M,E-commerce,260,0,0,523.0,414.333333,26.0,M,E-commerce
2,2,7,1,483.0,479.444444,25.0,M,Logistics,9769,0,0,492.5,417.000000,25.0,M,Logistics
3,3,0,0,501.5,424.333333,39.0,M,E-commerce,3987,1,1,507.5,522.000000,34.0,M,E-commerce
4,4,1,1,543.0,514.555556,18.0,F,E-commerce,539,0,0,531.0,414.000000,20.0,F,E-commerce
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,9995,10,1,538.5,450.444444,42.0,M,Logistics,5517,0,0,529.5,427.555556,44.0,M,Logistics
9996,9996,0,0,500.5,430.888889,26.0,F,Logistics,924,1,1,503.0,531.555556,27.0,F,Logistics
9997,9997,3,1,473.0,534.111111,22.0,F,E-commerce,4937,0,0,476.5,425.888889,23.0,F,E-commerce
9998,9998,2,1,495.0,523.222222,67.0,F,E-commerce,4968,0,0,497.5,426.444444,68.0,F,E-commerce


## Bias estimation

Bias estimation can be disabled by setting the `bias_estimation` argument to `False`.

In [114]:
data.data.head()

Unnamed: 0,user_id,signup_month,treat,pre_spends,post_spends,age,gender,industry
0,0,0,0,488.0,414.444444,26.0,M,E-commerce
1,1,8,1,512.5,462.222222,26.0,M,E-commerce
2,2,7,1,483.0,479.444444,25.0,M,Logistics
3,3,0,0,501.5,424.333333,39.0,M,E-commerce
4,4,1,1,543.0,514.555556,18.0,F,E-commerce


In [115]:
test = Matching(bias_estimation=False)
result = test.execute(data)

In [116]:
result.resume

Unnamed: 0,Effect Size,Standard Error,P-value,CI Lower,CI Upper,outcome
ATT,63.56,1.08,0.0,61.44,65.68,post_spends
ATC,99.74,1.23,0.0,97.32,102.15,post_spends
ATE,81.88,0.78,0.0,80.35,83.41,post_spends
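
To see how much bias estimation changes the result, the effect sizes above can be compared with those from the first Simple Matching run. The comparison assumes both runs use the same features and matching settings, which the identical standard errors suggest but do not prove.

```python
# Effect sizes copied from the two summary tables in this tutorial.
with_bias = {"ATT": 63.48, "ATC": 97.71, "ATE": 80.82}
without_bias = {"ATT": 63.56, "ATC": 99.74, "ATE": 81.88}

# Shift in each estimate when bias estimation is switched off.
shift = {name: round(without_bias[name] - with_bias[name], 2) for name in with_bias}
print(shift)  # {'ATT': 0.08, 'ATC': 2.03, 'ATE': 1.06}
```

In this dataset the ATC estimate moves the most when the adjustment is switched off, while ATT barely changes.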
