In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import Image
from scipy import stats
from sklearn import linear_model
import statsmodels.api as sm
import numpy as np
from matplotlib.offsetbox import AnchoredText
import statsmodels.formula.api as smf

In [5]:
def regplot(X_column, y_column, df, order):
    """Plot ``y_column`` against ``X_column`` with a polynomial fit and its statistics.

    An ordinary-least-squares polynomial of degree ``order`` is fitted and the
    seaborn regression plot is annotated with the fit's R2 and the p-value of
    its overall F-test.

    Parameters
    ----------
    X_column : label
        Column of ``df`` used as the predictor.
    y_column : label
        Column of ``df`` used as the response.
    df : pandas.DataFrame
        Data containing both columns.
    order : int
        Degree of the fitted polynomial (also passed to ``sns.regplot``).

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``.
    """
    X = df[X_column]
    y = df[y_column]

    # Regress on the explicit polynomial terms [1, x, ..., x**order] so the
    # F-test uses `order` model degrees of freedom. Regressing y on the values
    # of an already-fitted polynomial (a single regressor) yields the same R2
    # but an overly small p-value whenever order > 1. Passing the arrays
    # directly also avoids name-based lookup of 'X'/'y' inside `df`.
    design = np.vander(np.asarray(X, dtype=float), order + 1, increasing=True)
    # missing='drop' mirrors seaborn, which drops NaN rows before plotting.
    results = sm.OLS(np.asarray(y, dtype=float), design, missing='drop').fit()
    R2 = results.rsquared
    p_value = results.f_pvalue

    sns.regplot(x=X, y=y, order=order)
    plt.figtext(0.72, 0.78, f' R2: {R2} \n p: {p_value}', fontsize=10)
    plt.xlabel(X_column)
    plt.ylabel(y_column)
    plt.show()

In [3]:
def z_scored_df(df, X_column, cut_off, print_zscores):
    """Remove rows whose ``X_column`` value is an outlier by absolute z-score.

    Parameters
    ----------
    df : pandas.DataFrame
        Input data; it is not modified.
    X_column : label
        Numeric column on which the z-scores are computed.
    cut_off : float
        Rows are kept only when ``|z| < cut_off``.
    print_zscores : int
        When equal to 1, the absolute z-scores are printed.

    Returns
    -------
    df_new : pandas.DataFrame
        The rows of ``df`` that pass the cut-off.
    nr_excluded : int
        Number of rows of ``df`` that were removed.
    """
    values = np.asarray(df[X_column], dtype=float)

    # Ignore NaNs when estimating mean and standard deviation; with the default
    # 'propagate' policy one missing value turns every z-score into NaN and
    # every row would be excluded.
    z_scored = np.abs(stats.zscore(values, nan_policy='omit'))
    if print_zscores == 1:
        print(z_scored)

    # NaN < cut_off evaluates to False, so rows with a missing value are
    # still excluded (and counted in nr_excluded).
    filtered = z_scored < cut_off
    df_new = df[filtered]

    nr_excluded = df.shape[0] - df_new.shape[0]

    return df_new, nr_excluded

In [4]:
from sklearn.metrics import mutual_info_score

def calc_MI(x, y, bins):
    """Mutual information of two samples, estimated from a 2-D histogram.

    ``x`` and ``y`` are binned jointly with ``np.histogram2d`` using ``bins``,
    and the resulting count table is handed to ``mutual_info_score`` as the
    contingency matrix.
    """
    joint_counts, _x_edges, _y_edges = np.histogram2d(x, y, bins)
    return mutual_info_score(None, None, contingency=joint_counts)

def mutual_information_matrix(df, to_compare, bins=10, separator_offset=8):
    """Draw a heatmap of pairwise mutual information between columns of ``df``.

    Parameters
    ----------
    df : pandas.DataFrame
        Data containing the columns listed in ``to_compare``.
    to_compare : sequence of labels
        Columns to compare; also used as tick labels on both axes.
    bins : int, optional
        Number of histogram bins per axis passed to ``calc_MI`` (default 10).
    separator_offset : int or None, optional
        White separator lines are drawn this many cells before the end of each
        axis (default 8). Pass ``None`` to draw no separator lines.

    Returns
    -------
    None
        The heatmap is drawn on a new figure; the diagonal is left at zero.
    """
    columns = list(to_compare)

    # Use the frame that was passed in rather than a notebook-level global, and
    # drop rows with a missing value in any compared column because
    # np.histogram2d cannot bin NaN.
    data = df[columns].dropna()

    n_vars = len(columns)
    corr_mat = np.zeros((n_vars, n_vars))

    # The matrix is symmetric, so each pair is computed once and mirrored.
    for i in range(n_vars):
        for j in range(i + 1, n_vars):
            mi = calc_MI(data[columns[i]].to_numpy(), data[columns[j]].to_numpy(), bins)
            corr_mat[i, j] = mi
            corr_mat[j, i] = mi

    fig, ax = plt.subplots(figsize=(20, 12))
    sns.heatmap(corr_mat, annot=True, cmap='rocket_r', ax=ax)
    ax.set_xticklabels(to_compare, rotation=45, ha='right')
    ax.set_yticklabels(to_compare, rotation=0)
    if separator_offset is not None:
        position = ax.get_xlim()[1] - separator_offset
        ax.hlines(position, *ax.get_xlim(), colors='w')
        ax.vlines(position, *ax.get_xlim(), colors='w')
    plt.tight_layout()

In [6]:
def plot_regressions(X_column, y_column, df, zscore):
    """Show linear and quadratic regressions of ``y_column`` on ``X_column`` side by side.

    Parameters
    ----------
    X_column : label
        Column of ``df`` used as the predictor (also used as figure title).
    y_column : label
        Column of ``df`` used as the response.
    df : pandas.DataFrame
        Data containing both columns.
    zscore : float
        When non-zero, rows are first filtered with ``z_scored_df`` using this
        value as the cut-off and the number of removed rows is printed.
        ``0`` disables the filtering.

    Returns
    -------
    None
        The two-panel figure is displayed with ``plt.show()``. Each panel is
        annotated with R2 and a p-value: the slope p-value for the linear fit,
        the overall F-test p-value for the quadratic fit.
    """
    order = 2

    if zscore != 0:
        df, nr_excluded = z_scored_df(df, X_column, cut_off=zscore, print_zscores=0)
        print(f'Nr excluded: {nr_excluded}')

    X = df[X_column]
    y = df[y_column]

    # Quadratic regression on the explicit terms [1, x, x**2] so the F-test
    # uses `order` model degrees of freedom. Regressing y on the values of an
    # already-fitted polynomial (one regressor) gives the same R2 but a p-value
    # that is too small.
    design = np.vander(np.asarray(X, dtype=float), order + 1, increasing=True)
    # missing='drop' mirrors seaborn, which drops NaN rows before plotting.
    quad_fit = sm.OLS(np.asarray(y, dtype=float), design, missing='drop').fit()
    R2_quad = quad_fit.rsquared
    p_quad = quad_fit.f_pvalue

    # Linear regression
    lin_fit = sm.OLS(y, sm.add_constant(X), missing='drop').fit()
    # pvalues is indexed by column label ('const', X_column); select the slope
    # by position explicitly instead of relying on integer-key fallback.
    p_lin = lin_fit.pvalues.iloc[1]
    R2_lin = lin_fit.rsquared

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 4))
    fig.suptitle(X_column)

    # Plot linear regression
    sns.regplot(x=X, y=y, ax=ax1)
    anchored_text = AnchoredText(f' R2: {R2_lin:.5f} \n p: {p_lin:.5f}', loc=1)
    ax1.add_artist(anchored_text)

    # Plot polynomial regression
    sns.regplot(x=X, y=y, order=order, ax=ax2)
    anchored_text = AnchoredText(f' R2: {R2_quad:.5f} \n p: {p_quad:.5f}', loc=1)
    ax2.add_artist(anchored_text)

    plt.show()

In [8]:
def regplots(X_column, y_column, df, zscore, select, alpha=0.05):
    """Plot the regressions of ``y_column`` on ``X_column`` only when they are significant.

    Parameters
    ----------
    X_column : label
        Column of ``df`` used as the predictor.
    y_column : label
        Column of ``df`` used as the response.
    df : pandas.DataFrame
        Data containing both columns.
    zscore : float
        When non-zero, rows are filtered with ``z_scored_df`` using this value
        as the cut-off before fitting. ``0`` disables the filtering.
    select : str
        Which significance pattern triggers plotting:
        ``'linear'`` (linear p < alpha), ``'quadratic'`` (quadratic p < alpha),
        ``'only_linear'``, ``'only_quadratic'`` (one significant, the other
        not) or ``'both'``.
    alpha : float, optional
        Significance threshold (default 0.05).

    Returns
    -------
    None
        ``plot_regressions`` is called when the selected criterion is met.

    Raises
    ------
    ValueError
        If ``select`` is not one of the supported names.
    """
    valid_selections = ('linear', 'quadratic', 'only_linear', 'only_quadratic', 'both')
    if select not in valid_selections:
        # An unrecognised name previously resulted in silently never plotting.
        raise ValueError(f'select must be one of {valid_selections}, got {select!r}')

    order = 2

    if zscore == 0:
        data = df
    else:
        data, _ = z_scored_df(df, X_column, cut_off=zscore, print_zscores=0)
    X = data[X_column]
    y = data[y_column]

    # Quadratic regression on the explicit terms [1, x, x**2] so the F-test
    # uses `order` model degrees of freedom rather than one.
    design = np.vander(np.asarray(X, dtype=float), order + 1, increasing=True)
    quad_fit = sm.OLS(np.asarray(y, dtype=float), design, missing='drop').fit()
    p_quad = quad_fit.f_pvalue

    # Linear regression; the slope p-value is taken by position explicitly
    # because pvalues is indexed by column label.
    lin_fit = sm.OLS(y, sm.add_constant(X), missing='drop').fit()
    p_lin = lin_fit.pvalues.iloc[1]

    # Choose selection method
    criteria = {
        'linear': p_lin < alpha,
        'quadratic': p_quad < alpha,
        'only_linear': p_lin < alpha and p_quad > alpha,
        'only_quadratic': p_quad < alpha and p_lin > alpha,
        'both': p_lin < alpha and p_quad < alpha,
    }

    # If significant, make plots. The unfiltered frame is passed on purpose:
    # plot_regressions applies the z-score filter itself, and handing it the
    # already-filtered data would remove outliers a second time and plot
    # statistics that differ from the ones tested above.
    if criteria[select]:
        plot_regressions(X_column, y_column, df, zscore)