In [1]:
import os
import sys

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import statsmodels.api as sm
from statsmodels.sandbox.regression.gmm import IV2SLS

sys.path.append(os.path.join(os.getcwd(), "../src"))

from helper_fct import *
import boundiv
import kiv
from plot_fct import colours, img_path, update_layout, plot_diversity_methods

In [2]:
# Automatically reload imported modules (e.g. the project modules imported in the
# first cell) before executing code, so edits to them are picked up without a restart.
%load_ext autoreload
%autoreload 2

# Simpson and Shannon Diversity

## Read in data

<div class="alert alert-block alert-info">
<b>Note:</b> 
Unfortunately, we cannot share the data ourselves; please refer to the paper for references to the data source. The methods themselves are fully implemented and work in any setting with a one-dimensional instrumental variable.
</div>

In [10]:
data = pd.read_csv("data_simpson.csv")  # raw data (not distributed, see note above)
X = data["estimate"]  # exposure / regressor
n = X.shape[0]
Y = data["weight"]  # outcome
# Binary instrument: 1 for samples whose treatment label does not mention "control",
# 0 for control samples. Substring membership is used instead of
# `str.find("control") * -1`, because `find` returns the *position* of the match:
# it only produced 0/1 when "control" was absent (-1 -> 1) or at position 0,
# and any label containing "control" further in yielded a negative, non-binary value.
Z = (~data["Treatment_x"].str.contains("control", regex=False)).astype(int).to_numpy()

We whiten the data. Please make sure to also whiten $Z$ if it is not a binary indicator; here $Z$ is a binary treatment indicator, so it is left as is.

In [12]:
# Z is a 0/1 treatment indicator and is therefore deliberately not whitened.
# Check that precondition explicitly instead of relying on it silently.
assert set(np.unique(Z)) <= {0, 1}, "Z is not a binary indicator; whiten it as well"
# NOTE(review): X and Y are rebound in place, so re-running this cell applies
# whiten_data to already-whitened data — confirm whiten_data is idempotent.
X = whiten_data(X)
Y = whiten_data(Y)

# Run Methods on Real Microbiome Dataset

In [8]:
# Plain NumPy arrays (length-1 axes squeezed away) for the estimators below.
y = Y.values.squeeze()
x = X.values.squeeze()
z = Z.squeeze()

# Design matrices with an intercept column added by statsmodels.
xx = sm.add_constant(x)
zz = sm.add_constant(z)

# Evaluation grid spanning the observed range of x (used by OLS, 2SLS and KIV).
n_grid = 30
xstar = np.linspace(x.min(), x.max(), n_grid)

### First Stage Regression for F Test

In [10]:
# First stage: regress the regressor x on the instrument z (with intercept).
ols1 = sm.OLS(x, zz).fit()
xhat = ols1.predict(zz)  # first-stage fitted values
print(ols1.summary())

# Common rule of thumb (Staiger & Stock, 1997): a first-stage F-statistic below 10
# indicates a weak instrument, in which case 2SLS estimates can be strongly biased.
F_RULE_OF_THUMB = 10
print(f"First-stage F-statistic: {ols1.fvalue:.3f} (p = {ols1.f_pvalue:.4g})")
if ols1.fvalue < F_RULE_OF_THUMB:
    print(f"Warning: F < {F_RULE_OF_THUMB}; the instrument may be weak.")

                            OLS Regression Results                            
Dep. Variable:                      y   R-squared:                       0.128
Model:                            OLS   Adj. R-squared:                  0.113
Method:                 Least Squares   F-statistic:                     8.104
Date:                 Fr, 11 Jun 2021   Prob (F-statistic):            0.00620
Time:                        10:16:23   Log-Likelihood:                -76.458
No. Observations:                  57   AIC:                             156.9
Df Residuals:                      55   BIC:                             161.0
Df Model:                           1                                         
Covariance Type:            nonrobust                                         
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
const          0.4480      0.201      2.231      0.0

### Ordinary Least Squares

In [11]:
# Naive OLS of y on x (ignores the instrument), evaluated on the xstar grid.
ols = sm.OLS(y, xx).fit()
coeff_ols = ols.params
intercept_ols, slope_ols = coeff_ols[0], coeff_ols[1]
ystar_ols = intercept_ols + slope_ols * xstar

### Two Stage Least Squares

In [12]:
# Linear two-stage least squares with z as instrument, evaluated on the xstar grid.
iv2sls = IV2SLS(y, xx, zz).fit()
coeff_2sls = iv2sls.params
intercept_2sls, slope_2sls = coeff_2sls[0], coeff_2sls[1]
ystar_2sls = intercept_2sls + slope_2sls * xstar

### KIV

In [13]:
import kiv

# Presumably kernel instrumental-variable regression evaluated on the xstar grid.
# NOTE(review): the returned grid rebinds `xstar`, which the plotting cells share
# with the OLS/2SLS curves — confirm fit_kiv returns the grid it was given.
kiv_fit = kiv.fit_kiv(z, x, y, xstar=xstar)
xstar, ystar_kiv = kiv_fit

### Bounding Method

In [14]:
import boundiv

In [15]:
# Bounds are evaluated only on the central 10%-90% quantile range of x.
x_q10 = np.quantile(x, 0.1)
x_q90 = np.quantile(x, 0.9)
n_bound_points = 10
xstar_bound = np.linspace(x_q10, x_q90, n_bound_points)

In [16]:
import boundiv
# Bounding method on the xstar_bound grid; `results` feeds plot_diversity_methods.
# NOTE(review): `satis` is presumably a feasibility/satisfaction indicator and the
# slack a constraint tolerance — confirm against boundiv.fit_bounds. The verbose
# dict/progress output below is printed from inside boundiv, not by this cell.
bound_slack = 0.6
satis, results = boundiv.fit_bounds(z, x, y, xstar_bound, slack=bound_slack)

2
<function api_boundary.<locals>.reraise_with_filtered_traceback at 0x7fdd1fb92b70>
0
1
0
1


100%|██████████| 150/150 [00:13<00:00, 11.15it/s]
  1%|▏         | 2/150 [00:00<00:09, 15.85it/s]

{'indices': DeviceArray([[150,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0]], dtype=int32), 'objective': DeviceArray([[0.20934077, 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ]], dtype=float32), 'maxabsdiff': DeviceArray([[0.27916303, 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.

100%|██████████| 150/150 [00:10<00:00, 14.60it/s]
  1%|▏         | 2/150 [00:00<00:12, 11.96it/s]

{'indices': DeviceArray([[150, 150],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0]], dtype=int32), 'objective': DeviceArray([[0.20934077, 2.0305836 ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ]], dtype=float32), 'maxabsdiff': DeviceArray([[0.27916303, 0.3768974 ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.

100%|██████████| 150/150 [00:09<00:00, 15.11it/s]
  1%|▏         | 2/150 [00:00<00:10, 14.46it/s]

{'indices': DeviceArray([[150, 150],
             [150,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0]], dtype=int32), 'objective': DeviceArray([[0.20934077, 2.0305836 ],
             [0.08623122, 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ]], dtype=float32), 'maxabsdiff': DeviceArray([[0.27916303, 0.3768974 ],
             [0.2834019 , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.

100%|██████████| 150/150 [00:10<00:00, 14.58it/s]
  1%|▏         | 2/150 [00:00<00:09, 14.87it/s]

{'indices': DeviceArray([[150, 150],
             [150, 150],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0]], dtype=int32), 'objective': DeviceArray([[0.20934077, 2.0305836 ],
             [0.08623122, 1.6260333 ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ]], dtype=float32), 'maxabsdiff': DeviceArray([[0.27916303, 0.3768974 ],
             [0.2834019 , 0.3759325 ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.

100%|██████████| 150/150 [00:09<00:00, 15.18it/s]
  1%|▏         | 2/150 [00:00<00:11, 12.46it/s]

{'indices': DeviceArray([[150, 150],
             [150, 150],
             [150,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0]], dtype=int32), 'objective': DeviceArray([[ 0.20934077,  2.0305836 ],
             [ 0.08623122,  1.6260333 ],
             [-0.07718299,  0.        ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ]], dtype=float32), 'maxabsdiff': DeviceArray([[0.27916303, 0.3768974 ],
             [0.2834019 , 0.3759325 ],
             [0.3538366 , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.       

100%|██████████| 150/150 [00:09<00:00, 15.52it/s]
  1%|▏         | 2/150 [00:00<00:10, 14.29it/s]

{'indices': DeviceArray([[150, 150],
             [150, 150],
             [150, 150],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0]], dtype=int32), 'objective': DeviceArray([[ 0.20934077,  2.0305836 ],
             [ 0.08623122,  1.6260333 ],
             [-0.07718299,  1.2878089 ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ]], dtype=float32), 'maxabsdiff': DeviceArray([[0.27916303, 0.3768974 ],
             [0.2834019 , 0.3759325 ],
             [0.3538366 , 0.33831406],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.       

100%|██████████| 150/150 [00:09<00:00, 15.62it/s]
  1%|▏         | 2/150 [00:00<00:09, 15.91it/s]

{'indices': DeviceArray([[150, 150],
             [150, 150],
             [150, 150],
             [150,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0]], dtype=int32), 'objective': DeviceArray([[ 0.20934077,  2.0305836 ],
             [ 0.08623122,  1.6260333 ],
             [-0.07718299,  1.2878089 ],
             [-0.27837473,  0.        ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ]], dtype=float32), 'maxabsdiff': DeviceArray([[0.27916303, 0.3768974 ],
             [0.2834019 , 0.3759325 ],
             [0.3538366 , 0.33831406],
             [0.34781808, 0.        ],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.       

100%|██████████| 150/150 [00:10<00:00, 14.97it/s]
  1%|▏         | 2/150 [00:00<00:11, 12.99it/s]

{'indices': DeviceArray([[150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0]], dtype=int32), 'objective': DeviceArray([[ 0.20934077,  2.0305836 ],
             [ 0.08623122,  1.6260333 ],
             [-0.07718299,  1.2878089 ],
             [-0.27837473,  0.95241207],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ]], dtype=float32), 'maxabsdiff': DeviceArray([[0.27916303, 0.3768974 ],
             [0.2834019 , 0.3759325 ],
             [0.3538366 , 0.33831406],
             [0.34781808, 0.34442526],
             [0.        , 0.        ],
             [0.        , 0.        ],
             [0.        , 0.       

100%|██████████| 150/150 [00:09<00:00, 15.49it/s]
  1%|▏         | 2/150 [00:00<00:09, 15.01it/s]

{'indices': DeviceArray([[150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0]], dtype=int32), 'objective': DeviceArray([[ 0.20934077,  2.0305836 ],
             [ 0.08623122,  1.6260333 ],
             [-0.07718299,  1.2878089 ],
             [-0.27837473,  0.95241207],
             [-0.4408887 ,  0.        ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ]], dtype=float32), 'maxabsdiff': DeviceArray([[0.27916303, 0.3768974 ],
             [0.2834019 , 0.3759325 ],
             [0.3538366 , 0.33831406],
             [0.34781808, 0.34442526],
             [0.33759212, 0.        ],
             [0.        , 0.        ],
             [0.        , 0.       

100%|██████████| 150/150 [00:09<00:00, 15.37it/s]
  1%|▏         | 2/150 [00:00<00:09, 15.18it/s]

{'indices': DeviceArray([[150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0]], dtype=int32), 'objective': DeviceArray([[ 0.20934077,  2.0305836 ],
             [ 0.08623122,  1.6260333 ],
             [-0.07718299,  1.2878089 ],
             [-0.27837473,  0.95241207],
             [-0.4408887 ,  0.6720197 ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ]], dtype=float32), 'maxabsdiff': DeviceArray([[0.27916303, 0.3768974 ],
             [0.2834019 , 0.3759325 ],
             [0.3538366 , 0.33831406],
             [0.34781808, 0.34442526],
             [0.33759212, 0.3374936 ],
             [0.        , 0.        ],
             [0.        , 0.       

100%|██████████| 150/150 [00:09<00:00, 15.19it/s]
  1%|▏         | 2/150 [00:00<00:09, 15.03it/s]

{'indices': DeviceArray([[150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0]], dtype=int32), 'objective': DeviceArray([[ 0.20934077,  2.0305836 ],
             [ 0.08623122,  1.6260333 ],
             [-0.07718299,  1.2878089 ],
             [-0.27837473,  0.95241207],
             [-0.4408887 ,  0.6720197 ],
             [-0.6471962 ,  0.        ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ]], dtype=float32), 'maxabsdiff': DeviceArray([[0.27916303, 0.3768974 ],
             [0.2834019 , 0.3759325 ],
             [0.3538366 , 0.33831406],
             [0.34781808, 0.34442526],
             [0.33759212, 0.3374936 ],
             [0.34903783, 0.        ],
             [0.        , 0.       

100%|██████████| 150/150 [00:09<00:00, 15.35it/s]
  1%|▏         | 2/150 [00:00<00:09, 15.83it/s]

{'indices': DeviceArray([[150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [  0,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0]], dtype=int32), 'objective': DeviceArray([[ 0.20934077,  2.0305836 ],
             [ 0.08623122,  1.6260333 ],
             [-0.07718299,  1.2878089 ],
             [-0.27837473,  0.95241207],
             [-0.4408887 ,  0.6720197 ],
             [-0.6471962 ,  0.45772836],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ]], dtype=float32), 'maxabsdiff': DeviceArray([[0.27916303, 0.3768974 ],
             [0.2834019 , 0.3759325 ],
             [0.3538366 , 0.33831406],
             [0.34781808, 0.34442526],
             [0.33759212, 0.3374936 ],
             [0.34903783, 0.33908075],
             [0.        , 0.       

100%|██████████| 150/150 [00:09<00:00, 15.68it/s]
  1%|▏         | 2/150 [00:00<00:09, 15.71it/s]

{'indices': DeviceArray([[150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150,   0],
             [  0,   0],
             [  0,   0],
             [  0,   0]], dtype=int32), 'objective': DeviceArray([[ 0.20934077,  2.0305836 ],
             [ 0.08623122,  1.6260333 ],
             [-0.07718299,  1.2878089 ],
             [-0.27837473,  0.95241207],
             [-0.4408887 ,  0.6720197 ],
             [-0.6471962 ,  0.45772836],
             [-0.9357966 ,  0.        ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ]], dtype=float32), 'maxabsdiff': DeviceArray([[0.27916303, 0.3768974 ],
             [0.2834019 , 0.3759325 ],
             [0.3538366 , 0.33831406],
             [0.34781808, 0.34442526],
             [0.33759212, 0.3374936 ],
             [0.34903783, 0.33908075],
             [0.35973734, 0.       

100%|██████████| 150/150 [00:13<00:00, 11.49it/s]
  1%|▏         | 2/150 [00:00<00:11, 12.80it/s]

{'indices': DeviceArray([[150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [  0,   0],
             [  0,   0],
             [  0,   0]], dtype=int32), 'objective': DeviceArray([[ 0.20934077,  2.0305836 ],
             [ 0.08623122,  1.6260333 ],
             [-0.07718299,  1.2878089 ],
             [-0.27837473,  0.95241207],
             [-0.4408887 ,  0.6720197 ],
             [-0.6471962 ,  0.45772836],
             [-0.9357966 ,  0.3015917 ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ]], dtype=float32), 'maxabsdiff': DeviceArray([[0.27916303, 0.3768974 ],
             [0.2834019 , 0.3759325 ],
             [0.3538366 , 0.33831406],
             [0.34781808, 0.34442526],
             [0.33759212, 0.3374936 ],
             [0.34903783, 0.33908075],
             [0.35973734, 0.3276164

100%|██████████| 150/150 [00:12<00:00, 11.84it/s]
  1%|          | 1/150 [00:00<00:15,  9.89it/s]

{'indices': DeviceArray([[150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150,   0],
             [  0,   0],
             [  0,   0]], dtype=int32), 'objective': DeviceArray([[ 0.20934077,  2.0305836 ],
             [ 0.08623122,  1.6260333 ],
             [-0.07718299,  1.2878089 ],
             [-0.27837473,  0.95241207],
             [-0.4408887 ,  0.6720197 ],
             [-0.6471962 ,  0.45772836],
             [-0.9357966 ,  0.3015917 ],
             [-1.2716264 ,  0.        ],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ]], dtype=float32), 'maxabsdiff': DeviceArray([[0.27916303, 0.3768974 ],
             [0.2834019 , 0.3759325 ],
             [0.3538366 , 0.33831406],
             [0.34781808, 0.34442526],
             [0.33759212, 0.3374936 ],
             [0.34903783, 0.33908075],
             [0.35973734, 0.3276164

100%|██████████| 150/150 [00:13<00:00, 11.46it/s]
  1%|▏         | 2/150 [00:00<00:10, 14.67it/s]

{'indices': DeviceArray([[150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [  0,   0],
             [  0,   0]], dtype=int32), 'objective': DeviceArray([[ 0.20934077,  2.0305836 ],
             [ 0.08623122,  1.6260333 ],
             [-0.07718299,  1.2878089 ],
             [-0.27837473,  0.95241207],
             [-0.4408887 ,  0.6720197 ],
             [-0.6471962 ,  0.45772836],
             [-0.9357966 ,  0.3015917 ],
             [-1.2716264 ,  0.12118108],
             [ 0.        ,  0.        ],
             [ 0.        ,  0.        ]], dtype=float32), 'maxabsdiff': DeviceArray([[0.27916303, 0.3768974 ],
             [0.2834019 , 0.3759325 ],
             [0.3538366 , 0.33831406],
             [0.34781808, 0.34442526],
             [0.33759212, 0.3374936 ],
             [0.34903783, 0.33908075],
             [0.35973734, 0.3276164

100%|██████████| 150/150 [00:10<00:00, 13.71it/s]
  1%|▏         | 2/150 [00:00<00:10, 14.61it/s]

{'indices': DeviceArray([[150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150,   0],
             [  0,   0]], dtype=int32), 'objective': DeviceArray([[ 0.20934077,  2.0305836 ],
             [ 0.08623122,  1.6260333 ],
             [-0.07718299,  1.2878089 ],
             [-0.27837473,  0.95241207],
             [-0.4408887 ,  0.6720197 ],
             [-0.6471962 ,  0.45772836],
             [-0.9357966 ,  0.3015917 ],
             [-1.2716264 ,  0.12118108],
             [-1.7137612 ,  0.        ],
             [ 0.        ,  0.        ]], dtype=float32), 'maxabsdiff': DeviceArray([[0.27916303, 0.3768974 ],
             [0.2834019 , 0.3759325 ],
             [0.3538366 , 0.33831406],
             [0.34781808, 0.34442526],
             [0.33759212, 0.3374936 ],
             [0.34903783, 0.33908075],
             [0.35973734, 0.3276164

100%|██████████| 150/150 [00:10<00:00, 14.33it/s]
  1%|▏         | 2/150 [00:00<00:11, 13.39it/s]

{'indices': DeviceArray([[150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [  0,   0]], dtype=int32), 'objective': DeviceArray([[ 0.20934077,  2.0305836 ],
             [ 0.08623122,  1.6260333 ],
             [-0.07718299,  1.2878089 ],
             [-0.27837473,  0.95241207],
             [-0.4408887 ,  0.6720197 ],
             [-0.6471962 ,  0.45772836],
             [-0.9357966 ,  0.3015917 ],
             [-1.2716264 ,  0.12118108],
             [-1.7137612 , -0.04830782],
             [ 0.        ,  0.        ]], dtype=float32), 'maxabsdiff': DeviceArray([[0.27916303, 0.3768974 ],
             [0.2834019 , 0.3759325 ],
             [0.3538366 , 0.33831406],
             [0.34781808, 0.34442526],
             [0.33759212, 0.3374936 ],
             [0.34903783, 0.33908075],
             [0.35973734, 0.3276164

100%|██████████| 150/150 [00:10<00:00, 13.89it/s]
  1%|▏         | 2/150 [00:00<00:10, 14.66it/s]

{'indices': DeviceArray([[150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150,   0]], dtype=int32), 'objective': DeviceArray([[ 0.20934077,  2.0305836 ],
             [ 0.08623122,  1.6260333 ],
             [-0.07718299,  1.2878089 ],
             [-0.27837473,  0.95241207],
             [-0.4408887 ,  0.6720197 ],
             [-0.6471962 ,  0.45772836],
             [-0.9357966 ,  0.3015917 ],
             [-1.2716264 ,  0.12118108],
             [-1.7137612 , -0.04830782],
             [-2.2402294 ,  0.        ]], dtype=float32), 'maxabsdiff': DeviceArray([[0.27916303, 0.3768974 ],
             [0.2834019 , 0.3759325 ],
             [0.3538366 , 0.33831406],
             [0.34781808, 0.34442526],
             [0.33759212, 0.3374936 ],
             [0.34903783, 0.33908075],
             [0.35973734, 0.3276164

100%|██████████| 150/150 [00:11<00:00, 13.54it/s]


{'indices': DeviceArray([[150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150],
             [150, 150]], dtype=int32), 'objective': DeviceArray([[ 0.20934077,  2.0305836 ],
             [ 0.08623122,  1.6260333 ],
             [-0.07718299,  1.2878089 ],
             [-0.27837473,  0.95241207],
             [-0.4408887 ,  0.6720197 ],
             [-0.6471962 ,  0.45772836],
             [-0.9357966 ,  0.3015917 ],
             [-1.2716264 ,  0.12118108],
             [-1.7137612 , -0.04830782],
             [-2.2402294 , -0.1784195 ]], dtype=float32), 'maxabsdiff': DeviceArray([[0.27916303, 0.3768974 ],
             [0.2834019 , 0.3759325 ],
             [0.3538366 , 0.33831406],
             [0.34781808, 0.34442526],
             [0.33759212, 0.3374936 ],
             [0.34903783, 0.33908075],
             [0.35973734, 0.3276164

# Plot approaches

In [17]:
from plot_fct import colours, img_path, update_layout, plot_diversity_methods

In [37]:
# Comparison of OLS, 2SLS, KIV and the bounding method with the default legend.
fig = plot_diversity_methods(x, y, xstar, xstar_bound, ystar_ols, ystar_2sls, ystar_kiv, results)
fig.update_layout(xaxis=dict(title="Simpson Diversity (standardized)"))
fig.show()
# Save under a distinct name: the wide, horizontal-legend variant further down also
# writes "SimpsonDiversity.pdf", so sharing that path silently discarded this figure.
fig.write_image(os.path.join(img_path, "SimpsonDiversity_defaultLegend.pdf"))

In [52]:
fig = plot_diversity_methods(x, y, xstar, xstar_bound, ystar_ols, ystar_2sls, ystar_kiv, results)

# Wide, short variant with a horizontal legend anchored just above the plot area.
horizontal_legend = dict(
    orientation="h",
    yanchor="bottom",
    xanchor="right",
    y=1.02,
    x=1,
)
fig.update_layout(
    xaxis=dict(title="Simpson Diversity (standardized)"),
    legend=horizontal_legend,
    width=2000,
    height=500,
)
fig.show()
fig.write_image(os.path.join(img_path, "SimpsonDiversity.pdf"))

In [51]:
fig = plot_diversity_methods(x, y, xstar, xstar_bound, ystar_ols, ystar_2sls, ystar_kiv, results)
# Same comparison with the legend hidden, exported under its own file name.
fig.update_layout(
    xaxis=dict(title="Simpson Diversity (standardized)"),
    showlegend=False,
)
fig.show()
fig.write_image(os.path.join(img_path, "SimpsonDiversity_withoutLegend.pdf"))