# Forward Selection with statsmodels

In [6]:
import statsmodels.formula.api as smf

In [7]:
def forward_selected(data, response):
    """Fit a linear model whose predictors are chosen by forward selection.

    Starting from an intercept-only model, repeatedly add the single
    remaining column of ``data`` that gives the highest adjusted R-squared.
    Stop as soon as the best candidate no longer strictly improves on the
    current score.

    Parameters
    ----------
    data : pandas.DataFrame
        All candidate predictor columns plus the response column. Column
        names are spliced into the model formula unquoted, so they must be
        valid formula identifiers.
    response : str
        Name of the response column in ``data``.

    Returns
    -------
    statsmodels regression results
        OLS fit (with an intercept) of ``response`` on the selected
        predictors, listed in the order they were selected. If no predictor
        raises the adjusted R-squared above 0.0, this is the intercept-only
        fit ``response ~ 1``.

    Raises
    ------
    KeyError
        If ``response`` is not a column of ``data``.
    """
    remaining = set(data.columns)
    remaining.remove(response)
    selected = []
    # An intercept-only model has adjusted R-squared 0.0, so a candidate
    # must beat that to be admitted.
    current_score = 0.0
    best_model = None
    while remaining:
        # Score every one-column extension of the current predictor set.
        # Ties on score are broken by column name (the larger name wins), so
        # the choice does not depend on set iteration order.
        best = None
        for candidate in remaining:
            formula = "{} ~ {} + 1".format(
                response, " + ".join(selected + [candidate]))
            fit = smf.ols(formula, data).fit()
            key = (fit.rsquared_adj, candidate)
            if best is None or key > best[0]:
                best = (key, fit)
        (best_new_score, best_candidate), best_fit = best
        # Stop on "no strict improvement". Merely looping while the scores
        # compare equal would spin forever when the best candidate exactly
        # ties the current score.
        if best_new_score <= current_score:
            break
        remaining.remove(best_candidate)
        selected.append(best_candidate)
        current_score = best_new_score
        # This fit uses exactly the final formula for the current
        # selection, so keep it instead of refitting at the end.
        best_model = best_fit
    if best_model is None:
        # Nothing was selected: fit the intercept-only model explicitly
        # rather than building a formula with an empty predictor list.
        best_model = smf.ols("{} ~ 1".format(response), data).fit()
    return best_model

In [8]:
import pandas as pd

In [9]:
url = "salary.dat"
# The table is whitespace-delimited: split each row on runs of whitespace.
data = pd.read_csv(url, sep=r"\s+")

In [19]:
# Preview the first five rows of the loaded table.
data.head(n=5)

Unnamed: 0,sx,rk,yr,dg,yd,sl
0,male,full,25,doctorate,35,36350
1,male,full,13,doctorate,22,35350
2,male,full,10,doctorate,23,28200
3,female,full,7,doctorate,27,26775
4,male,full,19,masters,30,33696


In [11]:
# Run forward selection with salary ("sl") as the response column.
model = forward_selected(data, response="sl")

In [12]:
# Notebook display of the fitted results object returned by forward_selected.
model

<statsmodels.regression.linear_model.RegressionResultsWrapper at 0x1136dc240>

In [16]:
# Formula of the final model, i.e. the predictors forward selection kept.
print(model.model.formula)

sl ~ rk + yr + 1


In [18]:
# Adjusted R-squared of the selected model (the selection criterion).
print(model.rsquared_adj)

0.8351907605379858
