# Causal effect estimation of diversity on weight

This notebook reproduces the real-data diversity analysis in the paper [__A causal view on compositional data__](https://arxiv.org/abs/2106.11234) by E. Ailer, C. L. Müller and N. Kilbertus.

In this section we explain the figure on the one-dimensional causal effect estimation of diversity on weight. We apply several methods to both the Shannon and the Simpson diversity to show that using diversity as a one-dimensional summary of the microbiome may produce misleading results. This mainly motivates our subsequent, higher-dimensional approach, which uses the whole compositional vector of the microbiome.


<div class="alert alert-block alert-info">
<b>Note:</b>    
    
We perform the full analysis for __Shannon diversity__. The analysis for __Simpson diversity__ follows exactly the same steps, except that $X$ holds the Simpson diversity estimate instead of the Shannon diversity estimate.
</div>

In [1]:
# standard libraries
import sys
import os
sys.path.append(os.path.join(os.getcwd(), "../src"))
import numpy as np
import pandas as pd
import statsmodels.api as sm
from statsmodels.sandbox.regression.gmm import IV2SLS
import plotly.graph_objects as go


# paper relevant function files
from helper_fct import *
import boundiv
import kiv
from plot_fct import colours, update_layout, plot_diversity_methods

# Shannon Diversity

## Read in data

<div class="alert alert-block alert-info">
<b>Note:</b>    
    
Unfortunately, we cannot share the data. We refer to the paper for details on the data source. However, the methods are implemented generically and work in any one-dimensional instrumental variable setting.
</div>


Both the Shannon and the Simpson diversity were estimated in R with the packages [DivNet](https://github.com/adw96/DivNet) and [breakaway](https://github.com/adw96/breakaway) by Amy Willis.
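To make clear what the two indices measure, here is a minimal sketch of their naive plug-in versions, computed from a hypothetical vector of taxon counts for a single sample. DivNet and breakaway use more elaborate statistical estimators, so these plug-in values are not what the notebook loads. The Simpson convention used here is $\sum_i p_i^2$; some references instead report $1 - \sum_i p_i^2$ or its inverse.

```python
import numpy as np

def plugin_shannon(counts):
    """Naive plug-in Shannon diversity H = -sum_i p_i log p_i (natural log)."""
    counts = np.asarray(counts, dtype=float)
    p = counts / counts.sum()
    p = p[p > 0]  # zero-count taxa contribute 0 (limit of p log p as p -> 0)
    return -np.sum(p * np.log(p))

def plugin_simpson(counts):
    """Naive plug-in Simpson index D = sum_i p_i^2."""
    counts = np.asarray(counts, dtype=float)
    p = counts / counts.sum()
    return np.sum(p ** 2)

sample = np.array([5, 5, 5, 5, 0])  # hypothetical counts for five taxa
print(plugin_shannon(sample))  # log(4), about 1.386
print(plugin_simpson(sample))  # 0.25
```

For a sample spread evenly over $k$ taxa, Shannon diversity equals $\log k$ and the Simpson index equals $1/k$.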



In [4]:
data = pd.read_csv("data_shannon.csv")

X = data["estimate"]
n = X.shape[0]
Y = data["weight"]
# binary instrument: 1 if the treatment label does not contain "control", 0 otherwise
Z = (~data["Treatment_x"].str.contains("control")).astype(int).to_numpy()



We whiten (standardize) the data. Make sure to also whiten $Z$ if it is not a binary indicator.

In [5]:
Z = Z  # Z is binary here, so it is not whitened; use whiten_data(Z) otherwise
X = whiten_data(X)
Y = whiten_data(Y)
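`whiten_data` comes from the repository's `helper_fct` module and its code is not shown here. Judging from the "standardized" axis label in the plots below, we assume it centres and scales a one-dimensional variable. A minimal sketch under that assumption (note it returns a NumPy array, while the notebook's `whiten_data` apparently returns a pandas object, given the later `.values` calls):

```python
import numpy as np
import pandas as pd

def standardize(v):
    """Shift v to zero mean and scale it to unit (population) standard deviation."""
    v = np.asarray(v, dtype=float)
    return (v - v.mean()) / v.std()

# usage on a toy pandas Series (hypothetical weights)
weights = pd.Series([20.1, 22.4, 19.8, 25.0, 23.3])
w_std = standardize(weights)
print(w_std.mean(), w_std.std())  # numerically 0 and 1
```

Whether `whiten_data` divides by the population (`ddof=0`) or sample (`ddof=1`) standard deviation is an open assumption; the choice only rescales the variable by a constant factor.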

In [36]:
z = Z.squeeze()
y = Y.values.squeeze()
x = X.values.squeeze()

# Add constants to the feature matrices
zz = sm.add_constant(z)
xx = sm.add_constant(x)

# evaluation grid xstar: 30 equally spaced points spanning the observed range of x
xstar = np.linspace(x.min(), x.max(), 30)

## First Stage Regression for F Test


<div class="alert alert-block alert-info">
<b>Note:</b> 
The first-stage F-test evaluates the strength of the instrument. As a rule of thumb, an F-statistic above 10 indicates a sufficiently strong instrument. Note, however, that this rule only applies in the one-dimensional setting.
</div>

In [10]:
ols1 = sm.OLS(x, zz).fit()
xhat = ols1.predict(zz)
print(ols1.summary())

                            OLS Regression Results                            
Dep. Variable:                      y   R-squared:                       0.129
Model:                            OLS   Adj. R-squared:                  0.113
Method:                 Least Squares   F-statistic:                     8.147
Date:                 Mo, 12 Jul 2021   Prob (F-statistic):            0.00607
Time:                        11:12:51   Log-Likelihood:                -76.438
No. Observations:                  57   AIC:                             156.9
Df Residuals:                      55   BIC:                             161.0
Df Model:                           1                                         
Covariance Type:            nonrobust                                         
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
const         -0.4490      0.201     -2.237      0.0
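With a single instrument, the first-stage F-statistic is a simple function of $R^2$ and the sample size: $F = \frac{R^2 / k}{(1 - R^2)/(n - k - 1)}$ with $k = 1$ regressor. The summary above reports $R^2 = 0.129$ with $n = 57$ and $F = 8.147$, slightly below the rule-of-thumb threshold of 10. The following NumPy sketch reproduces this relation; the function names are hypothetical helpers, not part of the repository.

```python
import numpy as np

def f_from_r2(r2, n, k=1):
    """Overall regression F-statistic from R^2, with k regressors plus an intercept."""
    return (r2 / k) / ((1 - r2) / (n - k - 1))

def first_stage_f(x, z):
    """F-statistic of the least-squares regression x ~ 1 + z."""
    x = np.asarray(x, dtype=float)
    z = np.asarray(z, dtype=float)
    Z = np.column_stack([np.ones_like(z), z])
    coef, *_ = np.linalg.lstsq(Z, x, rcond=None)
    resid = x - Z @ coef
    r2 = 1 - resid @ resid / np.sum((x - x.mean()) ** 2)
    return f_from_r2(r2, len(x), k=1)

print(round(f_from_r2(0.129, 57), 2))  # 8.15
```

The small gap to the reported 8.147 comes from $R^2$ being rounded to three digits in the summary.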

## Run Methods on Real Microbiome Dataset

In the following we apply four methods to the dataset, __OLS__, __2SLS__, __KIV__ and the __Bounding Method__, to get a complete picture of the causal effect of diversity on weight. The four methods rest on different assumptions about the underlying data.

### Ordinary Least Squares


__OLS__ serves as a baseline. It ignores the instrument and simply regresses $Y$ on $X$, so under unobserved confounding its estimate is in general biased.

In [22]:
ols = sm.OLS(y, xx).fit()
coeff_ols = ols.params
ystar_ols = coeff_ols[0] + coeff_ols[1]*xstar
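To illustrate why OLS is only a baseline, the sketch below simulates hypothetical data with a binary instrument and an unobserved confounder $u$ that affects both $X$ and $Y$. The true causal slope is 1, but OLS picks up the confounding path and overestimates it. The data-generating process is an assumption chosen for illustration only.

```python
import numpy as np

def ols_slope(x, y):
    """Slope of the least-squares fit y ~ 1 + x."""
    X = np.column_stack([np.ones_like(x), x])
    coef, *_ = np.linalg.lstsq(X, y, rcond=None)
    return coef[1]

# hypothetical data: binary instrument z, unobserved confounder u
rng = np.random.default_rng(0)
n_sim = 20_000
z_sim = rng.integers(0, 2, size=n_sim).astype(float)
u_sim = rng.normal(size=n_sim)
x_sim = z_sim + u_sim + rng.normal(size=n_sim)
y_sim = 1.0 * x_sim + 2.0 * u_sim + rng.normal(size=n_sim)  # true causal slope: 1

# expected value: 1 + 2 * Cov(x, u) / Var(x) = 1 + 2 / 2.25, about 1.89
print(ols_slope(x_sim, y_sim))
```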

### Two Stage Least Squares

__2SLS__ assumes that the relationships between $Z$, $X$ and $Y$ are linear and that the noise, including the confounding term, enters additively.

In [23]:
iv2sls = IV2SLS(y, xx, zz).fit()
coeff_2sls = iv2sls.params
ystar_2sls = coeff_2sls[0] + coeff_2sls[1]*xstar
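For intuition, here is a minimal manual version of 2SLS on hypothetical simulated data (same kind of confounded setup as above, true slope 1). The first stage projects $X$ onto the instrument; the second stage regresses $Y$ on the fitted values. With one binary instrument and one regressor, the slope equals the Wald ratio $\mathrm{Cov}(Z, Y) / \mathrm{Cov}(Z, X)$. `IV2SLS` in statsmodels yields the same point estimate; the standard errors from a naive manual second stage are not valid, so this sketch only returns coefficients.

```python
import numpy as np

def two_stage_least_squares(y, x, z):
    """Manual 2SLS point estimates (intercept, slope) for one regressor and one instrument."""
    Z = np.column_stack([np.ones_like(z), z])
    # stage 1: project x onto the instrument
    g, *_ = np.linalg.lstsq(Z, x, rcond=None)
    xhat = Z @ g
    # stage 2: regress y on the fitted values
    Xhat = np.column_stack([np.ones_like(xhat), xhat])
    beta, *_ = np.linalg.lstsq(Xhat, y, rcond=None)
    return beta

# hypothetical confounded data with true causal slope 1
rng = np.random.default_rng(0)
n_sim = 20_000
z_sim = rng.integers(0, 2, size=n_sim).astype(float)
u_sim = rng.normal(size=n_sim)
x_sim = z_sim + u_sim + rng.normal(size=n_sim)
y_sim = 1.0 * x_sim + 2.0 * u_sim + rng.normal(size=n_sim)

beta_sim = two_stage_least_squares(y_sim, x_sim, z_sim)
wald = np.cov(z_sim, y_sim)[0, 1] / np.cov(z_sim, x_sim)[0, 1]
print(beta_sim[1], wald)  # equal to each other and close to 1
```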

### KIV

__KIV__ (kernel instrumental variable regression) allows for non-linear relationships between $Z$, $X$ and $Y$, while still assuming additive noise.

In [24]:
xstar, ystar_kiv = kiv.fit_kiv(z, x, y, xstar=xstar)
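The repository's `kiv.fit_kiv` is not shown here. As a rough illustration of the technique, the sketch below implements the closed-form two-stage kernel ridge regression formulation of KIV (Singh et al., 2019) with Gaussian kernels. Assumptions: the sample is split in half for the two stages, lengthscales use the median heuristic, the regularization parameters `lam` and `xi` are fixed rather than tuned, and a tiny jitter is added for numerical stability. The repository implementation may differ in all of these choices.

```python
import numpy as np

def rbf_kernel(a, b, lengthscale):
    """Gaussian kernel matrix exp(-(a_i - b_j)^2 / (2 * lengthscale^2)) for 1-D inputs."""
    d = a[:, None] - b[None, :]
    return np.exp(-d ** 2 / (2 * lengthscale ** 2))

def median_lengthscale(v):
    """Median heuristic: median of the nonzero pairwise distances."""
    d = np.abs(v[:, None] - v[None, :])
    return np.median(d[d > 0])

def kiv_sketch(z, x, y, xstar, lam=1e-2, xi=1e-2, seed=0):
    """Closed-form kernel IV estimate of the structural function at xstar."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(x))
    i1, i2 = idx[: len(x) // 2], idx[len(x) // 2:]
    x1, z1 = x[i1], z[i1]  # stage 1 sample: conditional mean embedding of X given Z
    z2, y2 = z[i2], y[i2]  # stage 2 sample: regress Y on the embedded features
    n, m = len(i1), len(i2)
    lx, lz = median_lengthscale(x), median_lengthscale(z)
    K_xx = rbf_kernel(x1, x1, lx)
    K_zz = rbf_kernel(z1, z1, lz)
    K_zz2 = rbf_kernel(z1, z2, lz)
    # stage 1: W = K_XX (K_ZZ + n*lam*I)^{-1} K_{Z,z2}
    W = K_xx @ np.linalg.solve(K_zz + n * lam * np.eye(n), K_zz2)
    # stage 2: alpha = (W W^T + m*xi*K_XX)^{-1} W y2 (plus jitter, not in the original formula)
    A = W @ W.T + m * xi * K_xx + 1e-8 * np.eye(n)
    alpha = np.linalg.solve(A, W @ y2)
    # prediction: f(x*) = sum_i alpha_i k(x1_i, x*)
    return rbf_kernel(xstar, x1, lx) @ alpha

# usage on hypothetical confounded data (true structural function: y = x)
rng = np.random.default_rng(1)
z_sim = rng.integers(0, 2, size=200).astype(float)
u_sim = rng.normal(size=200)
x_sim = z_sim + u_sim + 0.5 * rng.normal(size=200)
y_sim = x_sim + u_sim + 0.5 * rng.normal(size=200)
xstar_sim = np.linspace(x_sim.min(), x_sim.max(), 30)
ystar_kiv_sim = kiv_sketch(z_sim, x_sim, y_sim, xstar_sim)
```

With a binary instrument, the first stage can only distinguish two groups, so the non-linear fit is weakly constrained; this is one reason the hyperparameter choice matters in practice.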

### Bounding Method

The __Bounding Method__ allows for a large function class for the relationships between $Z$, $X$ and $Y$ and no longer requires additive noise. The price is that the causal effect is no longer point-identified: we can only estimate lower and upper bounds.

In [12]:
# grid for the bounds: 10 points between the 10% and 90% quantiles of x
xstar_bound = np.linspace(np.quantile(x, 0.1), np.quantile(x, 0.9), 10)

In [13]:
satis, results = boundiv.fit_bounds(z, x, y, xstar_bound, slack=0.6)

{'indices': DeviceArray([[150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150]], dtype=int32), 'objective': DeviceArray([[-2.3275762 , -0.13899913],
             [-1.7910798 , -0.0131385 ],
             [-1.3313822 ,  0.1547686 ],
             [-0.9799459 ,  0.30470592],
             [-0.6756184 ,  0.45809567],
             [-0.4607192 ,  0.6803309 ],
             [-0.29215193,  0.96470225],
             [-0.08326203,  1.2855396 ],
             [ 0.07979719,  1.6231467 ],
             [ 0.19556758,  2.017192  ]], dtype=float32), 'maxabsdiff': DeviceArray([[0.42314655, 0.2819839 ],
             [0.42729062, 0.3739583 ],
             [0.43405968, 0.3785016 ],
             [0.45560998, 0.3506425 ],
             [0.43326217, 0.38278085],
             [0.41741925, 0.39086348],
             [0.4055093 , 0.4102782

## Plot approaches for Shannon Diversity

In [31]:
# set this to your own folder for temporary files or figures
path = os.getcwd()
img_path = os.path.join(path, "temp")
os.makedirs(img_path, exist_ok=True)  # create the folder if it does not exist yet

In [29]:
fig = plot_diversity_methods(x, y, xstar, xstar_bound, ystar_ols, ystar_2sls, ystar_kiv, results)
fig.update_layout(xaxis=dict(title="Shannon Diversity (standardized)"))
fig.show()
fig.write_image(os.path.join(img_path, "ShannonDiversity.pdf"))

# Simpson Diversity

For the __Simpson Diversity__ we follow exactly the same steps as before. The only change is the underlying $X$, which now holds the Simpson diversity estimate.

In [30]:
data = pd.read_csv("data_simpson.csv")
X = data["estimate"]
n = X.shape[0]
Y = data["weight"]
# binary instrument: 1 if the treatment label does not contain "control", 0 otherwise
Z = (~data["Treatment_x"].str.contains("control")).astype(int).to_numpy()
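Since only $X$ changes between the Shannon and the Simpson analysis, the loading and preprocessing could be wrapped in a single function. The sketch below assumes the column names used in this notebook (`Treatment_x`, `estimate`, `weight`) and assumes that standardizing to zero mean and unit variance matches `whiten_data`. The function name and the toy DataFrame are hypothetical.

```python
import numpy as np
import pandas as pd

def load_iv_arrays(df):
    """Return (z, x, y) as NumPy arrays from a DataFrame with the notebook's column names."""
    # binary instrument: 1 if the treatment label does not contain "control", else 0
    z = (~df["Treatment_x"].str.contains("control")).astype(int).to_numpy()
    x = df["estimate"].to_numpy(dtype=float)  # diversity estimate (Shannon or Simpson)
    y = df["weight"].to_numpy(dtype=float)
    # standardize x and y (assumed to match whiten_data up to the ddof convention)
    x = (x - x.mean()) / x.std()
    y = (y - y.mean()) / y.std()
    return z, x, y

# toy DataFrame with hypothetical labels and values
toy = pd.DataFrame({
    "Treatment_x": ["control", "antibiotic", "control_diet", "antibiotic_diet"],
    "estimate": [1.0, 2.0, 3.0, 4.0],
    "weight": [10.0, 12.0, 11.0, 13.0],
})
z_toy, x_toy, y_toy = load_iv_arrays(toy)
print(z_toy)  # [0 1 0 1]
```

On the real data, one would call `load_iv_arrays(pd.read_csv("data_simpson.csv"))` and reuse the same estimation and plotting cells.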

... continue ...