In [14]:
import pandas as pd
import numpy as np
from patsy import dmatrices
from statsmodels.stats.outliers_influence import variance_inflation_factor
import altair as alt

def vif_bar_plot(x, y, df, thresh):
    """
    Returns a list containing a dataframe of Variance Inflation Factor (VIF) scores and
    a horizontal bar chart of those scores with a vertical rule drawn at the threshold.

    The design matrix is built with patsy from the formula ``y ~ x1 + x2 + ...``.
    One VIF score is computed per column of that design matrix, so the result also
    contains a row for the intercept column that patsy adds to the design.

    Parameters
    ----------
    x : list
        A non-empty list of the names of the explanatory variables. Repeated names
        are used once; the order of first appearance is kept.
    y : str
        The response variable.
    df : pandas.DataFrame
        A dataframe containing the data.
    thresh : int
        An integer specifying the threshold.

    Returns
    -------
    list
        A two-element list ``[vif_df, vif_plot]``: ``vif_df`` is a dataframe with the
        columns ``vif_score`` and ``explanatory_var``; ``vif_plot`` is the layered
        Altair chart (bars for the scores plus a red rule at ``thresh``).

    Raises
    ------
    ValueError
        If ``x`` is not a list or is empty, ``y`` is not a string, ``df`` is not a
        pandas data frame, or ``thresh`` is not an integer.

    Examples
    --------
    >>> from collinearity_tool.collinearity_tool import vif_bar_plot
    >>> vif_bar_plot(["exp1", "exp2", "exp3"], "response", data, 5)
    """
    if type(x) is not list:
        raise ValueError("x must be a list of explanatory variables!")
    if not x:
        # Without at least one term there is no model to score; fail here with a
        # clear message instead of deep inside the formula/VIF machinery.
        raise ValueError("x must contain at least one explanatory variable!")
    if type(y) is not str:
        raise ValueError("y must be a string!")
    if type(df) is not pd.DataFrame:
        raise ValueError("df must be a pandas data frame!")
    if type(thresh) is not int:
        raise ValueError("thresh must be an integer!")

    # Data frame containing VIF scores.
    # dict.fromkeys drops repeated names while keeping first-seen order, so the
    # formula (and therefore the row order of the output) is identical on every
    # run. A set would reorder the names depending on string-hash randomisation.
    rhs = " + ".join(dict.fromkeys(x))

    # Only the explanatory design matrix is needed; the response matrix is discarded.
    _, design = dmatrices(y + " ~ " + rhs, df, return_type = "dataframe")

    design_values = design.values
    vif_df = pd.DataFrame(
        {
            "vif_score": [
                variance_inflation_factor(design_values, col_idx)
                for col_idx in range(design_values.shape[1])
            ],
            "explanatory_var": design.columns,
        }
    )

    # Plotting the VIF scores
    hbar_plot = alt.Chart(vif_df).mark_bar(
        ).encode(
            x = alt.X("vif_score",
              title = "VIF Score"),
            y = alt.Y("explanatory_var",
              title = "Explanatory Variable")
    ).properties(
        width = 400,
        height = 300,
        title = "VIF Scores for Each Explanatory Variable in Linear Regression"
    )
    # Single-row frame so the rule is drawn once, at x == thresh.
    thresh_plot = alt.Chart(pd.DataFrame({"x": [thresh]})).mark_rule(
        color = "red"
    ).encode(
        x = "x")
    vif_plot = hbar_plot + thresh_plot

    return [vif_df, vif_plot]


## Tests

In [2]:
def test_vif_bar_plot():
    """
    Tests the vif_bar_plot function.
    The tests cover the output of the function including the VIF score dataframe and the bar plot for each explanatory variable.

    The iris data set is downloaded from the seaborn-data repository, so this test
    needs network access.

    Examples
    --------
    >>> test_vif_bar_plot()
    """

    explanatory_vars = ["sepal_length", "sepal_width"]
    iris_df = pd.read_csv('https://raw.githubusercontent.com/mwaskom/seaborn-data/master/iris.csv')
    vif_scores_and_plot = vif_bar_plot(explanatory_vars, "petal_width", iris_df, 5)

    # data types of outputs
    assert type(vif_scores_and_plot) == list, "The output should be a list."
    vif_df, vif_plot = vif_scores_and_plot
    assert type(vif_df) == pd.DataFrame, "The first element of the list should be a data frame."
    assert type(vif_plot) == alt.LayerChart, "The second element of the list should be a layered Altair object."

    # dataframe tests
    assert vif_df.columns.tolist() == ['vif_score', 'explanatory_var'], "Wrong column names."
    # Look the columns up by name: integer keys on the label-indexed dtypes Series
    # rely on a deprecated positional fallback.
    assert vif_df["vif_score"].dtype == "float64", "Wrong data type for the VIF scores column."
    # Accept both the classic object dtype and pandas' dedicated string dtype.
    assert pd.api.types.is_string_dtype(vif_df["explanatory_var"]), "Wrong data type for the explantory variable column."
    assert vif_df.shape == (len(explanatory_vars) + 1, 2), "Expected one row per explanatory variable plus the intercept."

    # Compare scores by variable name so the check does not depend on row order.
    expected_scores = {"Intercept": 113.940, "sepal_length": 1.014, "sepal_width": 1.014}
    actual_scores = dict(zip(vif_df["explanatory_var"], vif_df["vif_score"]))
    assert set(actual_scores) == set(expected_scores), "Unexpected set of explanatory variables in the VIF data frame."
    for var_name, expected in expected_scores.items():
        assert round(float(actual_scores[var_name]), 3) == expected, "Wrong VIF score for " + var_name + "."

    # plot
    assert len(vif_plot.layer) == 2, "The altair plot must have two layers."
    assert vif_plot.layer[0].mark == "bar", "Mark should be a bar."
    assert vif_plot.layer[0].encoding.x.shorthand == "vif_score", "x-axis should be mapped to vif_score."
    assert vif_plot.layer[0].encoding.y.shorthand == "explanatory_var", "y-axis should be mapped to explanatory_var."
    assert vif_plot.layer[1].mark.color == "red", "The threshold should be a red line."
    assert vif_plot.layer[1].mark.type == "rule", "The threshold should be a line spanning the axis."
    assert vif_plot == vif_plot.layer[0] + vif_plot.layer[1], "The plot should be a layered plot (bar + line)."
