#### <center>Forward Selection Regression Function</center>
https://planspace.org/20150423-forward_selection_with_statsmodels/

In [1]:
import statsmodels.formula.api as smf

def forward_selected(df, target):
    """Fit a linear model whose predictors are chosen by forward selection.

    Each round fits one OLS model per remaining column (added to the
    columns already selected) and keeps the column that yields the highest
    adjusted R-squared.  Selection stops as soon as no remaining column
    strictly improves on the current score (which starts at 0.0), or when
    every column has been selected.

    Parameters:
    -----------
    df : pandas DataFrame with all possible predictors and target.
         Column names are pasted into the model formula verbatim, so they
         must be usable as formula terms.

    target: string, name of target column in df

    Returns:
    --------
    model: an "optimal" fitted statsmodels linear model
           with an intercept
           selected by forward selection
           evaluated by adjusted R-squared

    Raises:
    -------
    KeyError: if target is not a column of df
    """
    remaining = set(df.columns)
    remaining.remove(target)
    selected = []
    current_score = 0.0
    while remaining:
        scores_with_candidates = []
        for candidate in remaining:
            formula = "{} ~ {} + 1".format(
                target, ' + '.join(selected + [candidate]))
            score = smf.ols(formula, df).fit().rsquared_adj
            scores_with_candidates.append((score, candidate))
        # Tuples compare by score first, then by column name, so the winner
        # does not depend on the iteration order of the set.
        best_new_score, best_candidate = max(scores_with_candidates)
        if not current_score < best_new_score:
            # No strict improvement.  Stopping explicitly here (instead of
            # testing float equality in the loop condition) also ends the
            # search when the best candidate exactly ties the current
            # score, a case that previously looped forever.
            break
        remaining.remove(best_candidate)
        selected.append(best_candidate)
        current_score = best_new_score
    # Joining with the intercept term keeps the formula well-formed
    # ("target ~ 1") even when no predictor was selected.
    formula = "{} ~ {}".format(target, ' + '.join(selected + ['1']))
    model = smf.ols(formula, df).fit()
    return model

In [4]:
import pandas as pd

# Whitespace-delimited salary dataset, fetched over HTTP when the cell runs.
url = "http://data.princeton.edu/wws509/datasets/salary.dat"
# sep is a regex: any run of whitespace separates two fields.
data = pd.read_csv(url, sep='\\s+')

# Forward-select predictors for the 'sl' column from the remaining columns.
model = forward_selected(data, 'sl')

# Formula of the selected model.
print (model.model.formula)
# expected output: sl ~ rk + yr + 1

# Adjusted R-squared of the selected model.
print (model.rsquared_adj)
# expected output: 0.835190760538

sl ~ rk + yr + 1
0.8351907605379858


In [5]:
# Preview the first rows of the loaded dataset.
data.head()

Unnamed: 0,sx,rk,yr,dg,yd,sl
0,male,full,25,doctorate,35,36350
1,male,full,13,doctorate,22,35350
2,male,full,10,doctorate,23,28200
3,female,full,7,doctorate,27,26775
4,male,full,19,masters,30,33696


In [2]:
from ggplot import *

  from pandas.core import datetools
