Reference: scikit-learn user guide, "Cross-validation: evaluating estimator performance" — https://scikit-learn.org/stable/modules/cross_validation.html#cross-validation

In [1]:
import pandas as pd
import numpy as np
import seaborn as sns
import statsmodels.formula.api as smf
import statsmodels.api as sm

from sklearn.metrics import r2_score, confusion_matrix
from sklearn.model_selection import cross_val_score, cross_validate, train_test_split
from sklearn.linear_model import LinearRegression


In [2]:
def statsmodels_train_test_split(df, stratify=None, **kwargs):
    """Split a whole DataFrame into train/test frames for the statsmodels formula API.

    The formula API consumes one DataFrame holding both the response and the
    predictors, so the frame is split as a unit rather than as separate X / y.

    Parameters
    ----------
    df : pandas.DataFrame
        Data to split; every column is kept in both outputs.
    stratify : pandas.Series, array-like, str or None, default None
        Class labels for a stratified split, forwarded to ``train_test_split``.
        A string is interpreted as the name of a column of ``df``. Labels are
        matched to the rows of ``df`` by position, as ``train_test_split`` does.
    **kwargs
        Forwarded to ``sklearn.model_selection.train_test_split``
        (e.g. ``test_size``, ``random_state``, ``shuffle``).

    Returns
    -------
    tuple of (pandas.DataFrame, pandas.DataFrame)
        ``(train, test)``, each keeping the index and column order of ``df``.
    """
    if isinstance(stratify, str):
        stratify = df[stratify]
    # Splitting df directly keeps every row intact. Re-joining separately split
    # X / y pieces on the index would misalign (NaN rows) whenever a stratify
    # Series carries an index that differs from df's.
    train, test = train_test_split(df, stratify=stratify, **kwargs)
    return train, test

In [3]:
from sklearn.base import BaseEstimator, RegressorMixin
class SMWrapper(BaseEstimator, RegressorMixin):
    """A universal sklearn-style wrapper for statsmodels regressors.

    Lets a statsmodels model class be used where scikit-learn expects a
    regressor (e.g. ``cross_val_score`` / ``cross_validate``).

    Parameters
    ----------
    model_class : callable
        A statsmodels model class (e.g. ``sm.OLS``). It is called as
        ``model_class(y, X)``, and the result's ``fit()`` is called with no
        arguments.
    fit_intercept : bool, default True
        If True, a constant column is prepended to X in both ``fit`` and
        ``predict``.
    """

    def __init__(self, model_class, fit_intercept=True):
        # Stored unchanged under the constructor argument names so that
        # BaseEstimator.get_params / sklearn.clone can rebuild the estimator.
        self.model_class = model_class
        self.fit_intercept = fit_intercept

    def _design(self, X):
        """Return X with a constant column prepended when fit_intercept is set."""
        if self.fit_intercept:
            # has_constant='add' forces the intercept column even when X already
            # contains a constant column (e.g. a dummy that is all-0 inside one CV
            # fold, or a single-row X). The default 'skip' would omit it, making
            # the design matrix width differ between fit and predict.
            return sm.add_constant(X, has_constant='add')
        return X

    def fit(self, X, y):
        """Fit ``model_class(y, X)``, storing ``model_`` and ``results_``.

        Returns
        -------
        self
            Following the scikit-learn estimator convention.
        """
        self.model_ = self.model_class(y, self._design(X))
        self.results_ = self.model_.fit()
        return self

    def predict(self, X):
        """Predict on X with the fitted results (intercept handled as in ``fit``)."""
        return self.results_.predict(self._design(X))

In [4]:
# Load seaborn's example "iris" dataset and preview its first ten rows.
df = sns.load_dataset(name='iris')
df.head(n=10)

Unnamed: 0,sepal_length,sepal_width,petal_length,petal_width,species
0,5.1,3.5,1.4,0.2,setosa
1,4.9,3.0,1.4,0.2,setosa
2,4.7,3.2,1.3,0.2,setosa
3,4.6,3.1,1.5,0.2,setosa
4,5.0,3.6,1.4,0.2,setosa
5,5.4,3.9,1.7,0.4,setosa
6,4.6,3.4,1.4,0.3,setosa
7,5.0,3.4,1.5,0.2,setosa
8,4.4,2.9,1.4,0.2,setosa
9,4.9,3.1,1.5,0.1,setosa


In [None]:
# First, linear regression with the statsmodels formula API

In [5]:
# Fixed seed so the split (and the R^2 of 0.8275418768476765 cited for
# random_state=3 in the scoring cell) is reproducible on Restart & Run All.
train, test = statsmodels_train_test_split(df, random_state=3)

In [6]:
# OLS of sepal_length on the remaining measurements plus species as a
# categorical term; fit on the training split and show the summary table.
formula = (
    "sepal_length ~ sepal_width + petal_length + petal_width + C(species)"
)
model = smf.ols(formula, data=train)
fitted_model = model.fit()
fitted_model.summary()

0,1,2,3
Dep. Variable:,sepal_length,R-squared:,0.876
Model:,OLS,Adj. R-squared:,0.87
Method:,Least Squares,F-statistic:,149.7
Date:,"Thu, 17 Sep 2020",Prob (F-statistic):,2.29e-46
Time:,08:27:20,Log-Likelihood:,-23.989
No. Observations:,112,AIC:,59.98
Df Residuals:,106,BIC:,76.29
Df Model:,5,,
Covariance Type:,nonrobust,,

0,1,2,3,4,5,6
,coef,std err,t,P>|t|,[0.025,0.975]
Intercept,2.1358,0.334,6.386,0.000,1.473,2.799
C(species)[T.versicolor],-0.7010,0.280,-2.504,0.014,-1.256,-0.146
C(species)[T.virginica],-0.9599,0.390,-2.463,0.015,-1.733,-0.187
sepal_width,0.5096,0.102,5.003,0.000,0.308,0.711
petal_length,0.8380,0.080,10.535,0.000,0.680,0.996
petal_width,-0.3607,0.186,-1.940,0.055,-0.729,0.008

0,1,2,3
Omnibus:,0.741,Durbin-Watson:,2.206
Prob(Omnibus):,0.69,Jarque-Bera (JB):,0.508
Skew:,-0.162,Prob(JB):,0.776
Kurtosis:,3.061,Cond. No.,96.9


In [8]:
r2_score(y_true=test['sepal_length'], y_pred=fitted_model.predict(test))  # 0.8275418768476765 with random_state=3

0.8275418768476765

In [None]:
# Now do it scikit-learn style: split into an explicit feature matrix X and target y

In [None]:
# Explicit feature matrix / target vector for sklearn-style estimators.
# 'species' is a string column, which sklearn estimators cannot consume, so it
# is one-hot encoded; drop_first=True leaves setosa as the reference level,
# matching the treatment coding C(species) used in the formula above.
X = pd.get_dummies(df.drop(columns='sepal_length'), drop_first=True, dtype=np.uint8)
y = df['sepal_length']
# Fixed seed so the split is reproducible.
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=3)

In [25]:
# Start this section from a fresh copy of seaborn's "iris" dataset.
df = sns.load_dataset(name='iris')

In [26]:
# One-hot encode 'species' for the whole dataset (drop_first=True leaves setosa
# as the reference level). dtype is pinned to uint8 because pandas >= 2.0
# defaults to bool dummies; mixing bool with float columns converts the frame
# to an object array, which statsmodels models reject.

df_dummies = pd.get_dummies(df, drop_first=True, dtype=np.uint8)

In [27]:
# Response on the left of '~'; species enters as a categorical term.
formula = (
    "sepal_length ~ sepal_width + petal_length + petal_width + C(species)"
)

In [32]:
formula.partition('~')[0].strip()  # left-hand side of the formula: the response (y) column name

'sepal_length'

In [33]:
# Parse the response column from the formula's left-hand side once, then use it
# to separate the target y from the feature matrix X.
target_col = formula.split('~')[0].strip()
y = df_dummies[target_col]
X = df_dummies.drop(columns=target_col)

In [34]:
# Column dtypes and non-null counts of the feature matrix.
X.info(verbose=True)

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 150 entries, 0 to 149
Data columns (total 5 columns):
 #   Column              Non-Null Count  Dtype  
---  ------              --------------  -----  
 0   sepal_width         150 non-null    float64
 1   petal_length        150 non-null    float64
 2   petal_width         150 non-null    float64
 3   species_versicolor  150 non-null    uint8  
 4   species_virginica   150 non-null    uint8  
dtypes: float64(3), uint8(2)
memory usage: 3.9 KB
