In [2]:
from itertools import cycle

import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm, datasets
from sklearn.metrics import roc_curve, auc
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize
from sklearn.multiclass import OneVsRestClassifier
try:
    from scipy import interp
except ImportError:
    # Recent SciPy releases no longer re-export NumPy functions such as
    # `interp` from the top-level namespace; numpy.interp is the same function.
    from numpy import interp

%matplotlib inline

In [3]:
# Load the Iris dataset: the feature matrix goes to X, the integer class labels to y.
iris = datasets.load_iris()
X, y = iris.data, iris.target

In [4]:
# Inspect the original label format: a 1-D array of integer class ids.
y_raw = iris.target
print(y_raw.shape)
print(y_raw)
# Label conversion: one-hot encode the labels (one column per class) so that
# OneVsRestClassifier can fit one binary classifier per column.
# Building y from iris.target (instead of re-assigning y from itself) keeps this
# cell idempotent: re-running it does not re-binarize already-binarized labels.
y = label_binarize(y_raw, classes=np.unique(y_raw))
assert y.shape == (len(y_raw), len(np.unique(y_raw)))
print(y[:3])

(150,)
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2]
[[1 0 0]
 [1 0 0]
 [1 0 0]]


In [8]:
# Number of classes = number of columns in the one-hot label matrix.
n_classes = y.shape[1]

# Seeded random source handed to the SVC (it only matters because probability=True).
random_state = np.random.RandomState(0)
n_samples, n_features = X.shape

# Shuffle the data and split it 50/50 into training and test sets.
split = train_test_split(X, y, test_size=0.5, random_state=0)
X_train, X_test, y_train, y_test = split

# One-vs-rest: learn to predict each class against all of the others.
base_svc = svm.SVC(kernel='linear', probability=True, random_state=random_state)
model = OneVsRestClassifier(base_svc)
clt = model.fit(X_train, y_train)

In [10]:
# Score on the training set (the following cell does the same for the test set).
# Because y is a one-hot label matrix, `score` reports subset accuracy: a sample
# counts as correct only if every one-vs-rest column matches its label row.
train_score = clt.score(X_train, y_train)
train_score

0.8133333333333334

In [11]:
# Subset accuracy on the held-out test set.
test_score = clt.score(X_test, y_test)
test_score

0.6533333333333333

In [12]:
# Use the SVM's decision_function to give every class of every test sample a
# score (one column per one-vs-rest classifier); preview the first five rows.
y_preds_scores = clt.decision_function(X_test)
y_preds_scores[:5, :]

array([[-3.58459897, -0.31176426,  1.78242707],
       [-2.15411929,  1.11402775, -2.393737  ],
       [ 1.89199335, -3.89624382, -6.29685764],
       [-4.52609987, -0.63389114,  1.96065819],
       [ 1.39684192, -1.77742447, -6.26300472]])

In [13]:
# Convert the scores back to the original integer label format: the predicted
# class is the column with the highest decision score.
# Reuses y_preds_scores (computed in the previous cell) instead of calling
# decision_function on the whole test set a second time.
np.argmax(y_preds_scores, axis=1)[:5]

array([2, 1, 0, 2, 0])

In [14]:
# Use predict_proba to view the predicted probability of each class (first 4 rows).
# Each column comes from an independent one-vs-rest classifier, so a row does not
# necessarily sum to 1 (e.g. the first row of the output below).
proba_test = clt.predict_proba(X_test)
proba_test[:4]

array([[3.37637164e-03, 3.50100247e-01, 9.11691552e-01],
       [4.16381095e-02, 7.00996181e-01, 3.02284397e-02],
       [9.81254695e-01, 1.28566941e-02, 1.30907603e-04],
       [6.42843364e-04, 2.76107276e-01, 9.29762069e-01]])

In [15]:
# Predicted labels taken from the largest per-class probability; for these five
# samples they agree with the decision_function argmax shown earlier.
clt.predict_proba(X_test).argmax(axis=1)[:5]

array([2, 1, 0, 2, 0])