In [1]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline
from sklearn.model_selection import train_test_split
sns.set()
colors = sns.color_palette("husl")

## 基本使用

In [29]:
X = np.random.random(size=(10,4))
y = np.random.randint(0,2,size=10)

In [30]:
from sklearn.linear_model import LogisticRegression

In [31]:
lr = LogisticRegression()
lr.fit(X,y)

LogisticRegression()

### 特征系数

In [32]:
lr.coef_

array([[-0.26357219,  0.37842564,  0.33051234, -0.31007345]])

In [33]:
lr.predict(X)

array([1, 0, 0, 1, 0, 1, 1, 1, 0, 1])

### 预测结果的概率

In [34]:
lr.predict_proba(X)

array([[0.48461664, 0.51538336],
       [0.52282813, 0.47717187],
       [0.50390781, 0.49609219],
       [0.46635565, 0.53364435],
       [0.56330619, 0.43669381],
       [0.49693808, 0.50306192],
       [0.49924916, 0.50075084],
       [0.44361584, 0.55638416],
       [0.57852686, 0.42147314],
       [0.4406548 , 0.5593452 ]])

In [35]:
threshold = 0.5
(lr.predict_proba(X)[:,1] > threshold).astype(np.int8)

array([1, 0, 0, 1, 0, 1, 1, 1, 0, 1], dtype=int8)

## 手写数字识别

In [38]:
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import cross_val_score

In [37]:
digits = load_digits()

data = digits.data
target = digits.target

In [61]:
X_train,X_test,y_train,y_test = train_test_split(data,target,test_size=0.2,random_state=1)

In [65]:
lr = LogisticRegression(max_iter=10000)
knn = KNeighborsClassifier(weights="distance")
lr_y_ = cross_val_score(lr,data,target,cv=5)
knn_y_ = cross_val_score(knn,data,target,cv=5)
lr.fit(X_train,y_train)

LogisticRegression(max_iter=10000)

In [67]:
print(f"lr:{lr_y_}\nknn:{knn_y_}\ntype:{type(lr_y_)}")
pd.DataFrame(data={
    "LR":lr,
    "KNN":knn
},index = [0,1,2,3,4]).plot()

lr:[0.925      0.875      0.93871866 0.93314763 0.89693593]
knn:[0.95277778 0.95555556 0.96657382 0.98050139 0.96100279]
type:<class 'numpy.ndarray'>


TypeError: no numeric data to plot

### 调优

In [66]:
# C
lr1 = LogisticRegression(penalty="l2",C=0.1,max_iter=10000)
lr1.fit(X_train,y_train)
print(f"lr:{lr.score(X_test,y_test)}\nlr1:{lr1.score(X_test,y_test)}")

lr:0.9722222222222222
lr1:0.975


### 网格搜索调参

In [69]:
from sklearn.model_selection import GridSearchCV

In [71]:
lr = LogisticRegression(max_iter=1000)

In [73]:
# 要搜索的参数的集合
parm_grid = {
    "penalty":["l1","l2"],
    "C":[0.1,0.5,1,5,10]
}

# 构造网格搜索对象
gscv = GridSearchCV(estimator=lr,param_grid=parm_grid,cv=5)

# 开始搜索
gscv.fit(data,target)

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

GridSearchCV(cv=5, estimator=LogisticRegression(max_iter=1000),
             param_grid={'C': [0.1, 0.5, 1, 5, 10], 'penalty': ['l1', 'l2']})

In [74]:
# 获取最best_params_(gscv.best_params_)
print(gbest_estimator_t_estimator_)
print(gscv.predict(X_test))

{'C': 0.1, 'penalty': 'l2'}
LogisticRegression(C=0.1, max_iter=1000)
[1 5 0 7 1 0 6 1 5 4 9 2 7 8 4 6 9 3 7 4 7 1 8 6 0 9 6 1 3 7 5 9 8 3 2 8 8
 1 1 0 7 9 0 0 8 7 2 7 4 3 4 3 4 0 4 7 0 5 5 5 2 1 7 0 5 1 8 3 3 4 0 3 7 4
 3 4 2 9 7 3 2 5 3 4 1 5 5 2 5 2 2 2 2 7 0 8 1 7 4 2 3 8 2 3 3 0 2 9 9 2 3
 2 8 1 1 9 1 2 0 4 8 5 4 4 7 6 7 6 6 1 7 5 6 3 8 3 7 1 8 5 3 4 7 8 5 0 6 0
 6 3 7 6 5 6 2 2 2 3 0 7 6 5 6 4 1 0 6 0 6 4 0 9 3 8 1 2 3 1 9 0 7 6 2 9 3
 5 3 4 6 3 3 7 4 9 2 7 6 1 6 8 4 0 3 1 0 9 9 9 0 1 8 6 8 0 9 5 9 8 2 3 5 3
 0 8 7 4 0 3 3 3 6 3 3 2 9 1 6 9 0 4 2 2 7 9 1 6 7 6 3 7 1 9 3 4 0 6 4 8 5
 3 6 3 1 4 0 4 4 8 7 9 1 5 2 7 0 9 0 4 4 0 1 0 6 4 2 8 5 0 2 6 0 1 8 2 0 9
 5 6 2 0 5 0 9 1 4 7 1 7 0 6 6 8 0 2 2 6 9 9 7 5 1 7 6 4 6 1 9 4 7 1 3 7 8
 1 6 9 8 3 2 4 8 7 5 5 6 9 9 8 5 0 0 4 9 3 0 4 9 4 2 5]


In [76]:
best_model = gscv.best_estimator_
best_model.fit(X_train,y_train)
best_model.score(X_test,y_test)

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


0.975