In [1]:
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from prettytable import PrettyTable


def load_data(data_path='data_reduce.npy', labels_path='labels_01.npy',
              test_size=0.30, random_state=0):
    """Load features/labels from ``.npy`` files and return a standardised split.

    The scaler is fit on the training portion only and then applied to the
    test portion, so no test-set statistics leak into training.

    Args:
        data_path: Path of the feature array, read with ``np.load``.
        labels_path: Path of the label array, read with ``np.load``.
            Presumably 0/1 labels given the default file name -- TODO confirm.
        test_size: Fraction of samples held out for testing (30% by default).
        random_state: Seed for the shuffle performed by ``train_test_split``,
            so the split is reproducible across runs.

    Returns:
        ``(X_train, X_test, y_train, y_test)``. The ``X`` arrays are
        standardised; the ``y`` arrays are returned as split, unscaled.
    """
    features = np.load(data_path)
    labels = np.load(labels_path)
    X_train, X_test, y_train, y_test = train_test_split(
        features, labels, test_size=test_size, random_state=random_state)

    # Fit on train only; the test split reuses the training mean/std.
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)
    return X_train, X_test, y_train, y_test


def _fmt_ratio(numerator, denominator):
    """Format ``numerator / denominator`` to 4 decimals, or ``'n/a'`` if the denominator is 0."""
    # A class that is never predicted (or absent from y_test) would otherwise
    # raise ZeroDivisionError while building the table.
    if denominator == 0:
        return "n/a"
    return "%.4f" % (numerator / denominator)


def test_LR_multiomaial(X_train, X_test, y_train, y_test, random_state=None):
    """Fit an L1-penalised logistic regression and print its test metrics.

    Prints the mean accuracy on the test split, then a PrettyTable whose rows
    are predicted classes and whose columns are true classes. Label ``1`` is
    reported as "Cancer"; any other label is reported as "Normal". The last
    column holds per-class precision and the last row per-class recall.

    NOTE(review): the name starts with ``test_``, so pytest would try to
    collect it if this code is ever moved into a module.

    Args:
        X_train, y_train: Training features and labels.
        X_test, y_test: Held-out features and labels.
        random_state: Forwarded to ``LogisticRegression``. The saga solver
            shuffles the data, so pass an int for reproducible results.
            ``None`` keeps the previous, unseeded behaviour.
    """
    # NOTE(review): ``multi_class`` is deprecated in recent scikit-learn
    # releases. Dropping it switches binary problems to the default binary
    # formulation, which may change the results -- confirm before removing.
    clf = LogisticRegression(
        penalty='l1', multi_class='multinomial', solver='saga',
        random_state=random_state)
    clf.fit(X_train, y_train)

    # Flatten both sides so a column-vector y_test cannot broadcast the
    # comparison into an (n, n) matrix.
    y_pred = np.ravel(clf.predict(X_test))
    y_true = np.ravel(np.asarray(y_test))

    print('Score: %.4f' % clf.score(X_test, y_test))

    hit = y_pred == y_true
    pred_cancer = y_pred == 1  # anything other than 1 counts as "Normal"

    correct1 = int(np.count_nonzero(hit & pred_cancer))    # predicted 1, truly 1
    correct0 = int(np.count_nonzero(hit & ~pred_cancer))   # predicted 0, truly 0
    false1 = int(np.count_nonzero(~hit & pred_cancer))     # predicted 1, truly 0
    false0 = int(np.count_nonzero(~hit & ~pred_cancer))    # predicted 0, truly 1

    tab = PrettyTable()
    tab.field_names = ["", "Cancer", "Normal", "Precision"]
    tab.add_row(['Predict Cancer', correct1, false1,
                 _fmt_ratio(correct1, correct1 + false1)])
    tab.add_row(['Predict Normal', false0, correct0,
                 _fmt_ratio(correct0, correct0 + false0)])
    tab.add_row(['Recall',
                 _fmt_ratio(correct1, correct1 + false0),
                 _fmt_ratio(correct0, correct0 + false1), ""])
    print(tab)


if __name__ == '__main__':
    # Build the scaled train/test split once, keeping the four pieces as
    # module-level names so later notebook cells can still inspect them.
    (X_train, X_test,
     y_train, y_test) = load_data()
    # Fit the classifier and print the score plus the confusion table.
    test_LR_multiomaial(X_train=X_train, X_test=X_test,
                        y_train=y_train, y_test=y_test)

Score: 0.9356
+----------------+--------+--------+-----------+
|                | Cancer | Normal | Precision |
+----------------+--------+--------+-----------+
| Predict Cancer |  1118  |   69   |   0.9419  |
| Predict Normal |   45   |  537   |   0.9227  |
|     Recall     | 0.9613 | 0.8861 |           |
+----------------+--------+--------+-----------+


