Skip to content

Commit

Permalink
Added plot function, which automatically computes 90 percent confiden…
Browse files Browse the repository at this point in the history
…ce bands and adds them to rslt dictionary, and minor code improvements.
  • Loading branch information
Sebastian Gsell authored and Sebastian Gsell committed Mar 23, 2020
1 parent 8b6301c commit cedd8ea
Show file tree
Hide file tree
Showing 9 changed files with 301 additions and 29 deletions.
15 changes: 15 additions & 0 deletions grmpy/check/auxiliary.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
import pandas as pd

from grmpy.read.read import read
from grmpy.simulate.simulate_auxiliary import construct_covariance_matrix


Expand Down Expand Up @@ -41,3 +42,17 @@ def check_special_conf(dict_):
return True, msg

return False, " "


def check_append_constant(init_file, dict_, data, semipar=False):
    """Make sure the data set carries a constant column named ``const``.

    If ``data`` already contains ``"const"``, both ``dict_`` and ``data`` are
    returned untouched. Otherwise the initialization file is read again with
    ``include_constant=True`` and a column named ``const`` filled with ``1.0``
    is inserted at position 0 of ``data``.

    Parameters
    ----------
    init_file
        Initialization file that is handed to ``read`` when the constant
        is missing.
    dict_
        Initialization dictionary obtained so far. It is returned as is when
        the constant is already present and replaced by the freshly read
        dictionary otherwise.
    data
        Data set that is checked for a ``"const"`` entry (expected to be a
        pandas DataFrame, since ``insert(loc, column, value)`` is used).
    semipar
        Forwarded unchanged to ``read``.

    Returns
    -------
    tuple
        ``(dict_, data)`` after the check.
    """
    # Constant supplied by the user: nothing to do.
    if "const" in data:
        return dict_, data

    # Re-read first so that ``data`` stays untouched if reading fails.
    dict_ = read(init_file, semipar, include_constant=True)

    # NOTE: the result of ``insert`` is not re-assigned, i.e. the object
    # passed in by the caller is modified in place.
    data.insert(0, "const", 1.0)

    return dict_, data
16 changes: 16 additions & 0 deletions grmpy/check/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from grmpy.check.custom_exceptions import UserError
from grmpy.check.auxiliary import is_pos_def

from grmpy.read.read import read


def check_presence_init(fname):
"""This function checks whether the model initialization file does in fact exist."""
Expand Down Expand Up @@ -109,3 +111,17 @@ def check_start_values(x0):
"start values for the estimation process."
)
raise UserError(msg)


def check_append_constant(init_file, dict_, data, semipar=False):
    """Make sure the data set carries a constant column named ``const``.

    If ``data`` already contains ``"const"``, both ``dict_`` and ``data`` are
    returned untouched. Otherwise the initialization file is read again with
    ``include_constant=True`` and a column named ``const`` filled with ``1.0``
    is inserted at position 0 of ``data``.

    Parameters
    ----------
    init_file
        Initialization file that is handed to ``read`` when the constant
        is missing.
    dict_
        Initialization dictionary obtained so far. It is returned as is when
        the constant is already present and replaced by the freshly read
        dictionary otherwise.
    data
        Data set that is checked for a ``"const"`` entry (expected to be a
        pandas DataFrame, since ``insert(loc, column, value)`` is used).
    semipar
        Forwarded unchanged to ``read``.

    Returns
    -------
    tuple
        ``(dict_, data)`` after the check.
    """
    # Constant supplied by the user: nothing to do.
    if "const" in data:
        return dict_, data

    # Re-read first so that ``data`` stays untouched if reading fails.
    dict_ = read(init_file, semipar, include_constant=True)

    # NOTE: the result of ``insert`` is not re-assigned, i.e. the object
    # passed in by the caller is modified in place.
    data.insert(0, "const", 1.0)

    return dict_, data
27 changes: 10 additions & 17 deletions grmpy/estimate/estimate.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,24 @@
"""The module provides an estimation process given the simulated data set and the
initialization file.
"""
import numpy as np

from grmpy.check.check import check_presence_estimation_dataset
from grmpy.check.check import check_basic_init_basic
from grmpy.check.check import check_par_init_dict
from grmpy.check.check import check_presence_init
from grmpy.check.check import check_basic_init_basic
from grmpy.check.check import check_par_init_file
from grmpy.check.auxiliary import read_data
from grmpy.read.read import read

from grmpy.read.read import read, check_append_constant

from grmpy.estimate.estimate_semipar import semipar_fit
from grmpy.estimate.estimate_par import par_fit


def fit(init_file, semipar=False):
"""This function estimates the MTE based on a parametric normal model or,
alternatively, via the semiparametric method of local instrumental variables (LIV)"""
"""This function estimates the MTE based on a parametric normal model
or, alternatively, via the semiparametric method of
local instrumental variables (LIV)"""

# Load the estimation file
check_presence_init(init_file)
Expand All @@ -32,13 +33,9 @@ def fit(init_file, semipar=False):

# Distribute initialization information.
data = read_data(dict_["ESTIMATION"]["file"])
dict_, data = check_append_constant(init_file, dict_, data, semipar=True)

# Check if constant already provided by user, but with name
# other than 'const'. If so, drop auto-generated constant.
if np.array_equal(np.asarray(data.iloc[:, 0]), np.ones(len(data))) is False:
dict_ = read(init_file, semipar, include_constant=True)

rslt = semipar_fit(dict_)
rslt = semipar_fit(dict_, data)

# Parametric Normal Model
else:
Expand All @@ -47,12 +44,8 @@ def fit(init_file, semipar=False):

# Distribute initialization information.
data = read_data(dict_["ESTIMATION"]["file"])
dict_, data = check_append_constant(init_file, dict_, data, semipar=False)

# Check if constant already provided by user, but with name
# other than 'const'. If so, drop auto-generated constant.
if np.array_equal(np.asarray(data.iloc[:, 0]), np.ones(len(data))) is False:
dict_ = read(init_file, semipar=False, include_constant=True)

rslt = par_fit(dict_)
rslt = par_fit(dict_, data)

return rslt
6 changes: 1 addition & 5 deletions grmpy/estimate/estimate_par.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,12 @@
from grmpy.estimate.estimate_output import print_logfile

from grmpy.check.check import UserError, check_start_values
from grmpy.check.auxiliary import read_data


def par_fit(dict_):
def par_fit(dict_, data):
"""The function estimates the coefficients of the simulated data set."""
np.random.seed(dict_["SIMULATION"]["seed"])

# Distribute initialization information.
data = read_data(dict_["ESTIMATION"]["file"])

_, X1, X0, Z1, Z0, Y1, Y0 = process_data(data, dict_)

num_treated = dict_["AUX"]["num_covars_treated"]
Expand Down
11 changes: 5 additions & 6 deletions grmpy/estimate/estimate_semipar.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,14 @@
import statsmodels.api as sm
import matplotlib.pyplot as plt

from grmpy.check.auxiliary import read_data
from grmpy.KernReg.locpoly import locpoly

from skmisc.loess import loess

lowess = sm.nonparametric.lowess


def semipar_fit(dict_):
def semipar_fit(dict_, data):
""""This function runs the semiparametric estimation via
local instrumental variables"""
# Process the information specified in the initialization file
Expand All @@ -23,9 +22,6 @@ def semipar_fit(dict_):

show_output = dict_["ESTIMATION"]["show_output"]

# Load data
data = read_data(dict_["ESTIMATION"]["file"])

# Prepare the sample for the estimation process
# Compute propensity score, define common support and trim the sample
data, ps = process_mte_data(
Expand Down Expand Up @@ -154,6 +150,9 @@ def trim_support(
# Re-estimate propensity score P(z)
ps = estimate_treatment_propensity(D, Z, logit, show_output)

else:
pass

data = data.sort_values(by="ps", ascending=True)
ps = np.sort(ps)

Expand Down Expand Up @@ -256,7 +255,7 @@ def estimate_treatment_propensity(D, Z, logit, show_output):
def plot_common_support(
ps, indicator, data, nbins, show_output, figsize, fontsize, plot_title
):
data["ps"] = ps
data.loc[:, "ps"] = ps

treated = data[[indicator, "ps"]][data[indicator] == 1].values
untreated = data[[indicator, "ps"]][data[indicator] == 0].values
Expand Down
64 changes: 64 additions & 0 deletions grmpy/plot/plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import matplotlib.pyplot as plt

from grmpy.plot.plot_auxiliary import mte_and_cof_int_semipar
from grmpy.plot.plot_auxiliary import mte_and_cof_int_par

from grmpy.check.check import check_append_constant
from grmpy.check.auxiliary import read_data
from grmpy.read.read import read


def plot_mte(
    rslt,
    init_file,
    college_years=4,
    font_size=20,
    label_size=14,
    color="blue",
    semipar=False,
    nboot=250,
):
    """Plot the marginal treatment effect (MTE) over the quantiles u_D.

    Depending on the model specification, either the parametric or the
    semiparametric MTE is plotted along with the corresponding
    90 percent confidence bands.

    Parameters
    ----------
    rslt
        Estimation result that is passed on to the MTE helper functions.
    init_file
        Initialization file; it is read to locate the estimation data set.
    college_years
        Forwarded to ``mte_and_cof_int_semipar``; not used in the
        parametric case.
    font_size
        Font size of the axis labels.
    label_size
        Size of the tick labels.
    color
        Color of the MTE curve and of both confidence bands.
    semipar
        If true, the semiparametric MTE is plotted, otherwise the
        parametric one.
    nboot
        Forwarded to ``mte_and_cof_int_semipar``; not used in the
        parametric case.

    Returns
    -------
    tuple
        ``(mte, quantiles)`` as delivered by the respective helper function.
    """
    # Read init dict and data
    init_dict = read(init_file)
    data = read_data(init_dict["ESTIMATION"]["file"])

    # ``dict_`` may differ from ``init_dict``: if the data had no "const"
    # column, the init file is read again including the constant.
    dict_, data = check_append_constant(init_file, init_dict, data, semipar)

    if semipar:
        quantiles, mte, mte_up, mte_d = mte_and_cof_int_semipar(
            rslt, init_file, college_years, nboot
        )
    else:
        # Pass the dictionary that belongs to the (possibly extended) data,
        # not the one read before the constant check.
        quantiles, mte, mte_up, mte_d = mte_and_cof_int_par(rslt, dict_, data)

    # Plot curve
    ax = plt.figure(figsize=(17.5, 10)).add_subplot(111)

    ax.set_ylabel(r"$MTE$", fontsize=font_size)
    ax.set_xlabel("$u_D$", fontsize=font_size)
    ax.tick_params(
        axis="both",
        direction="in",
        length=5,
        width=1,
        grid_alpha=0.25,
        labelsize=label_size,
    )
    ax.xaxis.set_ticks_position("both")
    ax.yaxis.set_ticks_position("both")

    # Point estimate as solid line, upper and lower band as dotted lines.
    ax.plot(quantiles, mte, color=color, linewidth=4)
    ax.plot(quantiles, mte_up, color=color, linestyle=":", linewidth=3)
    ax.plot(quantiles, mte_d, color=color, linestyle=":", linewidth=3)

    plt.show()

    return mte, quantiles

0 comments on commit cedd8ea

Please sign in to comment.