In [10]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

data = pd.read_csv('data/student-mat.csv', sep=';')
data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 395 entries, 0 to 394
Data columns (total 33 columns):
 #   Column      Non-Null Count  Dtype 
---  ------      --------------  ----- 
 0   school      395 non-null    object
 1   sex         395 non-null    object
 2   age         395 non-null    int64 
 3   address     395 non-null    object
 4   famsize     395 non-null    object
 5   Pstatus     395 non-null    object
 6   Medu        395 non-null    int64 
 7   Fedu        395 non-null    int64 
 8   Mjob        395 non-null    object
 9   Fjob        395 non-null    object
 10  reason      395 non-null    object
 11  guardian    395 non-null    object
 12  traveltime  395 non-null    int64 
 13  studytime   395 non-null    int64 
 14  failures    395 non-null    int64 
 15  schoolsup   395 non-null    object
 16  famsup      395 non-null    object
 17  paid        395 non-null    object
 18  activities  395 non-null    object
 19  nursery     395 non-null    object
 20  higher    

In [11]:
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.impute import SimpleImputer
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import make_pipeline
from sklearn.metrics import r2_score
num_pipeline = make_pipeline(
    SimpleImputer(strategy='median'),
    StandardScaler()
)
str_pipeline = make_pipeline(
    SimpleImputer(strategy='most_frequent'),
    OneHotEncoder(handle_unknown='ignore')
)

nums = ['age', 'Medu', 'Fedu', 'traveltime', 'studytime', 'failures', 'famrel', 'freetime', 'goout', 'Dalc', 'Walc', 'health', 'absences']
strs = ['school', 'sex', 'address', 'famsize', 'Pstatus', 'Mjob', 'Fjob', 'reason', 'guardian', 'schoolsup', 'famsup', 'paid', 'activities', 'nursery', 'higher', 'internet', 'romantic']

preprocessing = ColumnTransformer([
    ('num', num_pipeline, nums),
    ('str', str_pipeline, strs)
])
train_model = make_pipeline(preprocessing, LogisticRegression())

data_train, data_test = train_test_split(data, test_size=0.2, random_state=42)
y_train = data_train['G3'].copy()
X_train = data_train.drop(['G1', 'G2', 'G3'], axis=1)
y_test = data_test['G3'].copy()
X_test = data_test.drop(['G1', 'G2', 'G3'], axis=1)

train_model.fit(X_train, y_train)
# print(train_model.score(X_train, y_train))
y_pred = train_model.predict(X_test)
# print(np.sum(np.abs(y_pred - y_test)) / len(y_test))
print(f"mse: {mean_squared_error(y_test, y_pred):.4f}")
print(f"r2: {r2_score(y_test, y_pred)}")

mse: 23.6456
r2: -0.1531585034226235


STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


In [12]:
# 决策树
from sklearn.tree import DecisionTreeRegressor

train_model1 = make_pipeline(preprocessing, DecisionTreeRegressor(random_state=42))
train_model1.fit(X_train, y_train)
# print(train_model1.score(X_train, y_train))
y_pred = train_model1.predict(X_test)
# print(np.sum(np.abs(y_pred - y_test)) / len(y_test))
print(f"mse: {mean_squared_error(y_test, y_pred):.4f}")
print(f"r2: {r2_score(y_test, y_pred)}")

mse: 23.4051
r2: -0.14142937517581955


In [13]:
# 多项式回归
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures

train_model2 = make_pipeline(preprocessing, PolynomialFeatures(degree=2), LinearRegression())
train_model2.fit(X_train, y_train)
# print(train_model2.score(X_train, y_train))
y_pred = train_model2.predict(X_test)
print(f"mse: {mean_squared_error(y_test, y_pred):.4f}")
print(f"r2: {r2_score(y_test, y_pred)}")

mse: 45.6595
r2: -1.226744584556911
