In [1]:
import numpy as np
from numpy import exp
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.metrics import recall_score
from sklearn.metrics import precision_score
from sklearn.metrics import classification_report
from sklearn.metrics import accuracy_score

In [2]:
def loadTrainData():
    cancer = load_breast_cancer()  # 加载乳腺癌数据
    X = cancer.data  # 加载乳腺癌判别特征
    y = cancer.target  # 两个TAG，y = 0时为阴性，y = 1时为阳性
    # 将数据集划分为训练集和测试集，测试集占比为0.2
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
    X_train = X_train.T
    X_test = X_test.T
    return X_train, X_test, y_train, y_test


In [3]:

def sigmoid(inx):
    print(inx)
    return 1.0 / (1 + exp(-inx))

In [4]:
# 初始化参数
def initialize_para(dim):
    mu = 0
    sigma = 0.1
    np.random.seed(0)
    w = np.random.normal(mu, sigma, dim)
    w = np.reshape(w, (dim, 1))
    b = 0
    return w, b

In [5]:
# 前向传播
def propagate(w, b, X, Y):
    # eps防止log运算遇到0
    eps = 1e-5
    m = X.shape[1]
    # 计算初步运算结果
    A = sigmoid(np.dot(w.T, X) + b)
    # 计算损失函数值大小
    cost = -1 / m * np.sum(np.multiply(Y, np.log(A + eps)) +
                           np.multiply(1 - Y, np.log(1 - A + eps)))
    # 计算梯度值
    dw = 1 / m * np.dot(X, (A - Y).T)
    db = 1 / m * np.sum(A-Y)
    cost = np.squeeze(cost)

    grads = {"dw": dw,
             "db": db}
    # 返回损失函数大小以及反向传播的梯度值
    return grads, cost, A


In [6]:
# num_iterations 梯度下降次数
# learning_rate 学习率
def optimize(w, b, X, Y, num_iterations, learning_rate):
    costs = []  # 记录损失函数值

    # 循环进行梯度下降
    for i in range(num_iterations):
        # print(i)
        grads, cost, pre_Y = propagate(w, b, X, Y)
        dw = grads["dw"]
        db = grads["db"]

        w = w - learning_rate * dw
        b = b - learning_rate * db

        # 每100次循环记录一次损失函数大小并打印
        if i % 100 == 0:
            costs.append(cost)

        if i % 100 == 0:
            pre_Y[pre_Y >= 0.5] = 1
            pre_Y[pre_Y < 0.5] = 0
            pre_Y = pre_Y.astype(np.int64)
            acc = 1 - np.sum(pre_Y ^ Y) / len(Y)
            print("Iteration:{} Loss = {}, Acc = {}".format(i, cost, acc))

    # 最终参数值
    params = {"w": w,
              "b": b}

    return params, costs

In [7]:
def predict(w, b, X):
    # 样本个数
    m = X.shape[1]
    # 初始化预测输出
    Y_prediction = np.zeros((1, m))
    # 转置参数向量w
    w = w.reshape(X.shape[0], 1)

    # 预测结果
    Y_hat = sigmoid(np.dot(w.T, X) + b)

    # 将结果按照0.5的阈值转化为0/1
    for i in range(Y_hat.shape[1]):
        if Y_hat[:, i] > 0.5:
            Y_prediction[:, i] = 1
        else:
            Y_prediction[:, i] = 0

    return Y_prediction


初始化参数
对模型进行训练
在训练过程中进行前向传播，损失函数计算，梯度下降
找到最优解
测试集预测
模型评价


In [8]:
# 训练以及预测
def Logisticmodel(X_train, Y_train, X_test, Y_test, num_iterations=1000, learning_rate=0.1):
    # 初始化参数w，b
    w, b = initialize_para(X_train.shape[0])
    # 梯度下降找到最优参数
    parameters, costs = optimize(
        w, b, X_train, Y_train, num_iterations, learning_rate)

    w = parameters["w"]
    b = parameters["b"]

    # 训练集测试集的预测结果
    Y_prediction_train = predict(w, b, X_train)
    Y_prediction_test = predict(w, b, X_test)
    Y_prediction_test = Y_prediction_test.T

    # 模型评价
    accuracy_score_value = accuracy_score(Y_test, Y_prediction_test)
    recall_score_value = recall_score(Y_test, Y_prediction_test)
    precision_score_value = precision_score(Y_test, Y_prediction_test)
    classification_report_value = classification_report(
        Y_test, Y_prediction_test)

    print("准确率:", accuracy_score_value)
    print("召回率:", recall_score_value)
    print("精确率:", precision_score_value)
    print(classification_report_value)

    d = {"costs": costs,
         "Y_prediction_test": Y_prediction_test,
         "Y_prediction_train": Y_prediction_train,
         "w": w,
         "b": b,
         "learning_rate": learning_rate,
         "num_iterations": num_iterations}

    return d

In [9]:
if __name__ == '__main__':
    X_train, X_test, y_train, y_test = loadTrainData()
    Logisticmodel(X_train, y_train, X_test, y_test)


[[ 53.01708631  85.82838916  41.1237452   66.74038708  89.12143669
  263.63181697  64.61841388  75.05352001  66.87210744  68.95912898
   79.09328344  74.39297594 122.64451636 132.33986946 116.18778711
  182.69416595  74.86522853  44.27406699 127.11510027  54.58061857
  185.30955086  98.08419179  85.90034038  92.64676235  95.91715539
  135.17478581  81.27733665 167.50492361  97.55375144  71.46605862
   82.60839103  94.80369034  83.93374367  71.53414171  71.88842311
  142.62278122  93.03377304  92.30660011  81.56203856 139.53260535
   95.48829884  88.02505569  70.96294115 177.63992107 109.12929583
  135.35573691  98.45173637  70.70887843 180.88827145 107.04121975
  104.38980331 197.56045818 124.01452598 114.72272059  65.37677122
  106.55301875 116.38179528  71.91923065  72.17824841  95.32330211
   95.25399435 145.86870858 102.74919683  62.332325   152.5549306
   80.53972859 135.93605729  76.51547422  73.76782004 113.69430114
   78.2909623   65.4270962  104.09122215  28.63151493  77.01532

  return 1.0 / (1 + exp(-inx))


[[ 13770.10901927  23322.15743596  11663.81951239  17737.07319636
   24344.43413502  67451.3303009   17456.73944986  20744.19671212
   18294.50510804  18922.7929662   21606.5406781   20059.90212255
   30132.55542054  35696.64714195  27677.00197083  46557.78202292
   20209.85709841  12634.46427443  20676.84985536  15206.42862661
   48513.15301652  -3412.80254349  23387.05416449  24196.08091514
   25833.39872107  22610.98215495  20452.37993267  40216.86939814
   26707.40990994  19289.06061859  22675.72962375  19688.72731021
   22436.39205234  18772.11951087  19922.36136835  35665.41318939
   24935.89370614  24351.39175286  23023.30966591  31641.21769573
   25259.91134725  23074.45108068  19470.31991132  43041.78675797
   23104.93244538  35390.85695886  24462.9002733   19453.94099941
   42087.88190102  28918.73793727  28658.84543135  48729.24082227
   33754.7937722   30989.29610802   -253.06799596  26316.83790977
   31654.34627916  18719.34182481  19051.71276164  25572.77903069
   22415.3