# Matching pipeline

Matching is a method used in statistical analysis to reduce bias caused by differences in the baseline characteristics of the groups being compared. Simply put, matching helps to ensure that the measured effect is really caused by the treatment under study, and not by external factors.

Matching is most often used when a standard A/B test is not possible.

See the [Wiki Matching](https://github.com/sb-ai-lab/HypEx/wiki/Matching) page for a more detailed description of the terms used in matching.

In [25]:
from hypex import Matching
from hypex.dataset import Dataset, FeatureRole, InfoRole, TargetRole, TreatmentRole

## Data preparation 

It is important to mark the data fields by assigning the appropriate roles:

* **FeatureRole**: columns with features or predictor variables. Matching is based on these. Applied by default if the role is not specified for the column.
* **TreatmentRole**: column indicating the treatment or intervention (should be binary: 0/1 or True/False).
* **TargetRole**: column with the target or outcome variable (numeric, e.g., spend, conversion).
* **InfoRole**: columns with information about the data, such as user IDs (should be unique identifiers).


In [26]:
data = Dataset(
    roles={
        "user_id": InfoRole(int),
        "treat": TreatmentRole(int),
        "post_spends": TargetRole(float)
    },
    data="data.csv",
    default_role=FeatureRole(),
)
data

Unnamed: 0,user_id,signup_month,treat,pre_spends,post_spends,age,gender,industry
0,0,0,0,488.0,414.444444,,M,E-commerce
1,1,8,1,512.5,462.222222,26.0,,E-commerce
2,2,7,1,483.0,479.444444,25.0,M,Logistics
3,3,0,0,501.5,424.333333,39.0,M,E-commerce
4,4,1,1,543.0,514.555556,18.0,F,E-commerce
...,...,...,...,...,...,...,...,...
9995,9995,10,1,538.5,450.444444,42.0,M,Logistics
9996,9996,0,0,500.5,430.888889,26.0,F,Logistics
9997,9997,3,1,473.0,534.111111,22.0,F,E-commerce
9998,9998,2,1,495.0,523.222222,67.0,F,E-commerce


In [27]:
data.roles

{'user_id': Info(<class 'int'>),
 'treat': Treatment(<class 'int'>),
 'post_spends': Target(<class 'float'>),
 'signup_month': Feature(<class 'int'>),
 'pre_spends': Feature(<class 'float'>),
 'age': Feature(<class 'float'>),
 'gender': Feature(<class 'str'>),
 'industry': Feature(<class 'str'>)}

## Simple Matching

Matching consists of four main steps:
1. **Dummy Encoder**: Converts categorical features to numeric (one-hot encoding).
2. **Process Mahalanobis distance**: Calculates distances between units using all features (default is Mahalanobis, can be changed).
3. **Two sides pairs searching by faiss**: Finds the best matches between treated and control units using fast nearest neighbor search.
4. **Metrics (ATT, ATC, ATE) estimation**: Calculates the effect based on matched pairs.

> **Common issues:**
> - If you get errors about categorical features, check that all non-numeric columns are intended as features and will be encoded.
> - If matching quality is poor, try changing the distance metric or reviewing your feature selection.
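
The four steps listed above can be illustrated with a small NumPy sketch. This is not the HypEx implementation: the helper names (`one_hot`, `mahalanobis_transform`, `nearest_neighbors`, `att_estimate`) and the synthetic data are hypothetical, the neighbour search is brute force rather than faiss, and no bias correction is applied. It only shows the idea behind each step.

```python
import numpy as np

def one_hot(values):
    """Step 1: turn a categorical column into 0/1 dummy columns (first level dropped)."""
    levels = sorted(set(values))
    return np.array([[float(v == lvl) for lvl in levels[1:]] for v in values])

def mahalanobis_transform(X):
    """Step 2: whiten X so Euclidean distance on the result equals Mahalanobis distance."""
    inv_cov = np.linalg.inv(np.cov(X, rowvar=False))
    L = np.linalg.cholesky(inv_cov)  # inv_cov == L @ L.T
    return X @ L

def nearest_neighbors(queries, candidates):
    """Step 3: brute-force 1-nearest-neighbour search (faiss does this at scale)."""
    sq_dist = ((queries[:, None, :] - candidates[None, :, :]) ** 2).sum(axis=2)
    return sq_dist.argmin(axis=1)

def att_estimate(X, treat, y):
    """Step 4: ATT as the mean outcome gap between treated units and matched controls."""
    Z = mahalanobis_transform(X)
    treated = np.flatnonzero(treat == 1)
    control = np.flatnonzero(treat == 0)
    matched = control[nearest_neighbors(Z[treated], Z[control])]
    return float(np.mean(y[treated] - y[matched]))

# Synthetic data (hypothetical): the true treatment effect is 10 for every unit.
rng = np.random.default_rng(0)
n = 200
pre_spends = rng.normal(500, 20, size=n)
gender = rng.choice(["F", "M"], size=n)
treat = rng.integers(0, 2, size=n)
y = 0.5 * pre_spends + 10 * treat

X = np.column_stack([pre_spends, one_hot(gender)])
att = att_estimate(X, treat, y)
print(round(att, 2))  # should land near the true effect of 10
```

Whitening the features first means an ordinary Euclidean nearest-neighbour index can then be used for the search, which is one common way to combine Mahalanobis distance with a fast search library.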

In [28]:
# Backfill missing values: the raw data above has gaps in `age` and `gender`.
data = data.fillna(method="bfill")

In [29]:
test = Matching()
result = test.execute(data)



**ATT** (Average Treatment effect on the Treated): the estimated effect of the treatment for those who actually received it.

**ATC** (Average Treatment effect on the Controls): the estimated effect if the control group had received the treatment.

**ATE** (Average Treatment Effect): the overall average effect, combining ATT and ATC, weighted by group sizes.
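
The relationship between the three estimates described above can be written as a short function. This is a sketch of the weighting as stated in the definition, not HypEx code; the function name and the toy numbers are hypothetical and are not taken from the dataset.

```python
def ate_from_att_atc(att, atc, n_treated, n_control):
    """Combine ATT and ATC into ATE, weighting each by the size of its group."""
    n_total = n_treated + n_control
    return (n_treated * att + n_control * atc) / n_total

# Toy numbers (not from the dataset above): 1 treated unit, 3 control units.
print(ate_from_att_atc(att=2.0, atc=4.0, n_treated=1, n_control=3))  # 3.5
```

When the two groups are of similar size, ATE falls roughly midway between ATT and ATC.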


In [30]:
result.resume

Unnamed: 0,Effect Size,Standard Error,P-value,CI Lower,CI Upper,outcome
ATT,63.37,2.45,0.0,58.57,68.16,post_spends
ATC,96.47,1.57,0.0,93.4,99.55,post_spends
ATE,80.13,1.44,0.0,77.31,82.95,post_spends


In [31]:
result.full_data

Unnamed: 0,user_id,signup_month,treat,pre_spends,post_spends,age,gender,industry,user_id_matched,signup_month_matched,treat_matched,pre_spends_matched,post_spends_matched,age_matched,gender_matched,industry_matched
0,0,0,0,488.0,414.444444,26.0,M,E-commerce,9433,1,1,488.5,518.444444,37.0,F,Logistics
1,1,8,1,512.5,462.222222,26.0,M,E-commerce,5438,0,0,529.0,417.111111,23.0,F,E-commerce
2,2,7,1,483.0,479.444444,25.0,M,Logistics,5165,0,0,498.5,412.222222,25.0,F,Logistics
3,3,0,0,501.5,424.333333,39.0,M,E-commerce,1735,1,1,504.0,516.333333,33.0,M,Logistics
4,4,1,1,543.0,514.555556,18.0,F,E-commerce,539,0,0,531.0,414.000000,20.0,F,E-commerce
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,9995,10,1,538.5,450.444444,42.0,M,Logistics,5893,0,0,535.0,414.555556,40.0,M,E-commerce
9996,9996,0,0,500.5,430.888889,26.0,F,Logistics,7731,1,1,500.0,515.888889,25.0,M,Logistics
9997,9997,3,1,473.0,534.111111,22.0,F,E-commerce,7066,0,0,480.0,423.222222,22.0,F,Logistics
9998,9998,2,1,495.0,523.222222,67.0,F,E-commerce,1885,0,0,499.0,423.000000,67.0,F,Logistics


In [32]:
result.indexes

Unnamed: 0,indexes
0,9433
1,5438
2,5165
3,1735
4,539
...,...
9995,5893
9996,7731
9997,7066
9998,1885


In [33]:
result.full_data.roles

{'user_id': Info(<class 'int'>),
 'treat': Treatment(<class 'int'>),
 'post_spends': Target(<class 'float'>),
 'signup_month': Feature(<class 'int'>),
 'pre_spends': Feature(<class 'float'>),
 'age': Feature(<class 'float'>),
 'gender': Feature(<class 'str'>),
 'industry': Feature(<class 'str'>),
 'user_id_matched': Info(<class 'int'>),
 'treat_matched': Treatment(<class 'int'>),
 'post_spends_matched': Target(<class 'float'>),
 'signup_month_matched': Feature(<class 'int'>),
 'pre_spends_matched': Feature(<class 'float'>),
 'age_matched': Feature(<class 'float'>),
 'gender_matched': Feature(<class 'str'>),
 'industry_matched': Feature(<class 'str'>)}

We can pass **quality_tests** to evaluate how well the features are balanced after matching.
- **t-test** checks if feature means are similar across treatment and control groups.
- **ks-test** (Kolmogorov-Smirnov) checks if feature distributions are similar.
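
To make the two checks above concrete, here is a minimal sketch using SciPy directly on two samples of one feature. It is not how HypEx runs its quality tests internally: the `balance_check` helper, the 0.05 threshold, and the "OK" / "NOT OK" labelling rule are assumptions made for illustration.

```python
import numpy as np
from scipy import stats

def balance_check(treated_values, control_values, alpha=0.05):
    """Run a t-test (means) and a KS test (distributions) on one feature.

    A p-value above `alpha` is labelled "OK" (no evidence of imbalance);
    this threshold is an assumption of the sketch.
    """
    t_p = float(stats.ttest_ind(treated_values, control_values).pvalue)
    ks_p = float(stats.ks_2samp(treated_values, control_values).pvalue)
    return {
        "t-test": ("OK" if t_p > alpha else "NOT OK", t_p),
        "ks-test": ("OK" if ks_p > alpha else "NOT OK", ks_p),
    }

# Hypothetical feature values: identical samples vs. a strongly shifted sample.
base = np.arange(100.0)
print(balance_check(base, base))         # both tests report "OK"
print(balance_check(base + 1000, base))  # both tests report "NOT OK"
```

A feature can pass the t-test and still fail the KS test when the group means agree but the shapes of the distributions differ, which is why the two checks are usually read together.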

In [34]:
test = Matching(quality_tests=['t-test', 'ks-test'])
result = test.execute(data)



In [35]:
result.quality_results

Unnamed: 0,feature,group,TTest pass,TTest p-value,KSTest pass,KSTest p-value
0,signup_month,1┆signup_month,NOT OK,0.0,NOT OK,0.0
1,pre_spends,1┆pre_spends,NOT OK,1.8024199999999998e-212,NOT OK,3.2847500000000003e-231
2,age,1┆age,OK,0.9602563,OK,0.7186624


We can change the **metric** parameter to estimate different effects. The run above, which did not set `metric`, reported ATT, ATC and ATE together.
- `'att'`: effect for the treated group only
- `'atc'`: effect for the control group only
- `'ate'`: average effect for all units
- `'auto'`: automatically selects based on data

In [36]:
test = Matching(metric="atc")
result = test.execute(data)



In [37]:
result.resume

Unnamed: 0,Effect Size,Standard Error,P-value,CI Lower,CI Upper,outcome
ATC,96.47,0.14,0.0,96.21,96.74,post_spends


In [38]:
result.indexes

Unnamed: 0,indexes
0,9433
1,-1
2,-1
3,1735
4,-1
...,...
9995,-1
9996,7731
9997,-1
9998,-1
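
In the `result.indexes` output above, `-1` appears on the rows for which no pair was searched (here the treated rows, since only ATC is estimated). A minimal sketch of keeping only the matched pairs, assuming the index column has already been extracted into a plain NumPy array (that extraction step is not shown and is an assumption):

```python
import numpy as np

# First five values of the index column shown above; -1 marks rows without a match.
indexes = np.array([9433, -1, -1, 1735, -1])

has_match = indexes != -1
rows = np.flatnonzero(has_match)
pairs = list(zip(rows.tolist(), indexes[has_match].tolist()))
print(pairs)  # [(0, 9433), (3, 1735)]
```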


It is also possible to search for pairs only for the **test group** by setting `metric='att'`. In this case only **ATT** is estimated.

In [39]:
test = Matching(metric='att')
result = test.execute(data)



In [40]:
result.resume

Unnamed: 0,Effect Size,Standard Error,P-value,CI Lower,CI Upper,outcome
ATT,63.37,0.46,0.0,62.46,64.28,post_spends


In [41]:
result.indexes

Unnamed: 0,indexes
0,-1
1,5438
2,5165
3,-1
4,539
...,...
9995,5893
9996,-1
9997,7066
9998,1885


In [42]:
result.full_data

Unnamed: 0,user_id,signup_month,treat,pre_spends,post_spends,age,gender,industry,user_id_matched,signup_month_matched,treat_matched,pre_spends_matched,post_spends_matched,age_matched,gender_matched,industry_matched
0,0,0,0,488.0,414.444444,26.0,M,E-commerce,,,,,,,,
1,1,8,1,512.5,462.222222,26.0,M,E-commerce,5438.0,0.0,0.0,529.0,417.111111,23.0,F,E-commerce
2,2,7,1,483.0,479.444444,25.0,M,Logistics,5165.0,0.0,0.0,498.5,412.222222,25.0,F,Logistics
3,3,0,0,501.5,424.333333,39.0,M,E-commerce,,,,,,,,
4,4,1,1,543.0,514.555556,18.0,F,E-commerce,539.0,0.0,0.0,531.0,414.000000,20.0,F,E-commerce
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,9995,10,1,538.5,450.444444,42.0,M,Logistics,5893.0,0.0,0.0,535.0,414.555556,40.0,M,E-commerce
9996,9996,0,0,500.5,430.888889,26.0,F,Logistics,,,,,,,,
9997,9997,3,1,473.0,534.111111,22.0,F,E-commerce,7066.0,0.0,0.0,480.0,423.222222,22.0,F,Logistics
9998,9998,2,1,495.0,523.222222,67.0,F,E-commerce,1885.0,0.0,0.0,499.0,423.000000,67.0,F,Logistics


Finally, you can change the distance metric used for matching. By default, Mahalanobis distance is used, but you can also use L2 (Euclidean) distance.

- **Mahalanobis**: Takes into account correlations between features; recommended for most cases.
- **L2 (Euclidean)**: Simpler, may work well if features are uncorrelated and similarly scaled.

> **Tip:** If matching quality is poor or you get warnings about singular matrices, try switching the distance metric.
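
The difference between the two distances can be seen on a toy example. The sketch below uses plain NumPy with a hypothetical covariance matrix and three hypothetical points; it does not call HypEx and only illustrates why the choice of distance can change which unit counts as the nearest neighbour.

```python
import numpy as np

def l2(a, b):
    """Euclidean distance: every feature counts in its raw units."""
    return float(np.linalg.norm(a - b))

def mahalanobis(a, b, cov):
    """Mahalanobis distance: differences are rescaled by the feature covariance."""
    diff = a - b
    return float(np.sqrt(diff @ np.linalg.inv(cov) @ diff))

# Hypothetical covariance: feature 0 has variance 100, feature 1 has variance 1.
cov = np.array([[100.0, 0.0], [0.0, 1.0]])
x = np.array([0.0, 0.0])
y = np.array([10.0, 0.0])  # one standard deviation away on feature 0
z = np.array([0.0, 3.0])   # three standard deviations away on feature 1

print(l2(x, y), l2(x, z))  # 10.0 3.0
print(round(mahalanobis(x, y, cov), 6), round(mahalanobis(x, z, cov), 6))  # 1.0 3.0
```

Under L2 the nearest neighbour of `x` is `z`, while under Mahalanobis it is `y`, because the large raw gap on feature 0 is small relative to that feature's spread.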

In [43]:
test = Matching(distance="l2", metric='att')
result = test.execute(data)



In [44]:
result.resume

Unnamed: 0,Effect Size,Standard Error,P-value,CI Lower,CI Upper,outcome
ATT,63.37,0.46,0.0,62.46,64.27,post_spends


In [45]:
result.indexes

Unnamed: 0,indexes
0,-1
1,2490
2,5493
3,-1
4,321
...,...
9995,5893
9996,-1
9997,8670
9998,507


In [46]:
result.full_data

Unnamed: 0,user_id,signup_month,treat,pre_spends,post_spends,age,gender,industry,user_id_matched,signup_month_matched,treat_matched,pre_spends_matched,post_spends_matched,age_matched,gender_matched,industry_matched
0,0,0,0,488.0,414.444444,26.0,M,E-commerce,,,,,,,,
1,1,8,1,512.5,462.222222,26.0,M,E-commerce,2490.0,0.0,0.0,511.5,417.444444,27.0,F,E-commerce
2,2,7,1,483.0,479.444444,25.0,M,Logistics,5493.0,0.0,0.0,483.0,408.000000,25.0,M,E-commerce
3,3,0,0,501.5,424.333333,39.0,M,E-commerce,,,,,,,,
4,4,1,1,543.0,514.555556,18.0,F,E-commerce,321.0,0.0,0.0,538.0,421.444444,29.0,M,E-commerce
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,9995,10,1,538.5,450.444444,42.0,M,Logistics,5893.0,0.0,0.0,535.0,414.555556,40.0,M,E-commerce
9996,9996,0,0,500.5,430.888889,26.0,F,Logistics,,,,,,,,
9997,9997,3,1,473.0,534.111111,22.0,F,E-commerce,8670.0,0.0,0.0,473.0,415.777778,22.0,F,Logistics
9998,9998,2,1,495.0,523.222222,67.0,F,E-commerce,507.0,0.0,0.0,495.0,429.777778,67.0,F,Logistics
