In [46]:
from pprint import pprint

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import statsmodels.api as sm
# import seaborn as sns
from patsy import dmatrices
from sklearn import metrics

np.random.seed(1)

In [60]:
def get_lin_reg_model(model_formula, df_in, print_MSE=False, return_MSE=False):
    """
    Function returns the summary for fitted linear model.

    Parameter "model_formula" should be a patsy formula describing the model.
    Parameter "df" is a dataframe.
    """

    # Split the data into training (80%) and validation set (20%)
    mask = np.random.rand(len(df_in)) < 0.8
    train = df_in[mask]
    valid = df_in[~mask]

    # Prepare the data (dmatrices is from patsy library)
    y_train, X_train = dmatrices(model_formula, data=train, return_type='dataframe')
    y_valid, X_valid = dmatrices(model_formula, data=valid, return_type='dataframe')

    # Train the model
    model = sm.OLS(y_train, X_train)
    fitted_model = model.fit()
    y_train_pred = fitted_model.predict(X_train)
    train_MSE = metrics.mean_squared_error(y_train, y_train_pred)
    y_valid_pred = fitted_model.predict(X_valid)
    test_MSE = metrics.mean_squared_error(y_valid, y_valid_pred)

    if print_MSE is True:
        # Show MSE for training set
        print(f'{train_MSE=}')

        # Show MSE for validation set
        print(f'{test_MSE=}\n')

    # Return fitted model
    if return_MSE is True:
        return fitted_model, train_MSE, test_MSE
    else:
        return fitted_model

In [61]:
PATH = "lab/data/"
df_raw = pd.read_csv(f'{PATH}Auto.csv')

In [62]:
df_raw.isnull().sum()


mpg             0
cylinders       0
displacement    0
horsepower      0
weight          0
acceleration    0
year            0
origin          0
name            0
dtype: int64

In [63]:
# horsepower has some missing ('?') values
bad_rows = []
for index, row in df_raw.iterrows():
    if row['horsepower'] == "?":
        bad_rows.append(index)
df_raw = df_raw.drop(bad_rows)
print(f"dropped: {len(bad_rows)}")

dropped: 5


In [64]:
df_cont = df_raw.astype({"horsepower": 'int'})
df_cat = df_raw.astype({"horsepower": 'int', "year": 'category'})

In [65]:
year_continuous_model = get_lin_reg_model("mpg ~ year", df_cont, print_MSE=True)
year_categorical_model = get_lin_reg_model("mpg ~ year", df_cat, print_MSE=True)

train_MSE=39.6751221984536
test_MSE=43.28152633337354

train_MSE=33.12762922434946
test_MSE=43.60249855609495



In [66]:
year_continuous_model.summary()

0,1,2,3
Dep. Variable:,mpg,R-squared:,0.349
Model:,OLS,Adj. R-squared:,0.347
Method:,Least Squares,F-statistic:,172.8
Date:,"Tue, 06 Dec 2022",Prob (F-statistic):,6.75e-32
Time:,21:55:32,Log-Likelihood:,-1056.0
No. Observations:,324,AIC:,2116.0
Df Residuals:,322,BIC:,2124.0
Df Model:,1,,
Covariance Type:,nonrobust,,

0,1,2,3,4,5,6
,coef,std err,t,P>|t|,[0.025,0.975]
Intercept,-72.6835,7.307,-9.947,0.000,-87.059,-58.308
year,1.2651,0.096,13.147,0.000,1.076,1.454

0,1,2,3
Omnibus:,19.746,Durbin-Watson:,0.866
Prob(Omnibus):,0.0,Jarque-Bera (JB):,15.898
Skew:,0.452,Prob(JB):,0.000353
Kurtosis:,2.4,Cond. No.,1580.0


In [67]:
year_categorical_model.summary()

0,1,2,3
Dep. Variable:,mpg,R-squared:,0.463
Model:,OLS,Adj. R-squared:,0.442
Method:,Least Squares,F-statistic:,21.7
Date:,"Tue, 06 Dec 2022",Prob (F-statistic):,2.67e-34
Time:,21:55:48,Log-Likelihood:,-998.27
No. Observations:,315,AIC:,2023.0
Df Residuals:,302,BIC:,2071.0
Df Model:,12,,
Covariance Type:,nonrobust,,

0,1,2,3,4,5,6
,coef,std err,t,P>|t|,[0.025,0.975]
Intercept,17.8333,1.200,14.862,0.000,15.472,20.195
year[T.71],3.6884,1.715,2.150,0.032,0.313,7.064
year[T.72],0.2619,1.756,0.149,0.882,-3.195,3.718
year[T.73],-1.7667,1.610,-1.097,0.273,-4.935,1.401
year[T.74],5.1232,1.715,2.987,0.003,1.748,8.499
year[T.75],2.2037,1.649,1.336,0.182,-1.041,5.449
year[T.76],3.5000,1.610,2.174,0.030,0.332,6.668
year[T.77],6.3867,1.680,3.802,0.000,3.081,9.692
year[T.78],5.6148,1.649,3.405,0.001,2.370,8.860

0,1,2,3
Omnibus:,12.859,Durbin-Watson:,1.016
Prob(Omnibus):,0.002,Jarque-Bera (JB):,11.234
Skew:,0.391,Prob(JB):,0.00364
Kurtosis:,2.505,Cond. No.,14.0


In [68]:
print(df_cat['year'])

0      70
1      70
2      70
3      70
4      70
       ..
392    82
393    82
394    82
395    82
396    82
Name: year, Length: 392, dtype: category
Categories (13, int64): [70, 71, 72, 73, ..., 79, 80, 81, 82]
