# 采用RBF kernel 的SVM

数据集是UCI的CTGs。根据胎儿心率等特征，帮助产科医生对胎儿3种状态分类。

In [1]:
import pandas as pd
df = pd.read_excel('CTG.xls', 'Raw Data') # 里面有3各sheet，取‘Raw Data‘

In [2]:
df.shape

(2130, 40)

In [3]:
df.head(3)

Unnamed: 0,FileName,Date,SegFile,b,e,LBE,LB,AC,FM,UC,...,C,D,E,AD,DE,LD,FS,SUSP,CLASS,NSP
0,,NaT,,,,,,,,,...,,,,,,,,,,
1,Variab10.txt,1996-12-01,CTG0001.txt,240.0,357.0,120.0,120.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,9.0,2.0
2,Fmcs_1.txt,1996-05-03,CTG0002.txt,5.0,632.0,132.0,132.0,4.0,0.0,4.0,...,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,6.0,1.0


In [4]:
X = df.iloc[1:2126, 3:-2].values
Y = df.iloc[1:2126,-1].values 

确认类别比例

In [5]:
from collections import Counter
Counter(Y)

Counter({1.0: 1654, 2.0: 295, 3.0: 176})

分出20%作testing set

In [6]:
from sklearn.model_selection import train_test_split
X_train, X_test, Y_train, Y_test = train_test_split(X,Y,test_size=0.2, random_state=42)

采用‘RBF‘kernel，对参数C、gamma 进行调优

In [7]:
from sklearn.svm import SVC
svc = SVC(kernel='rbf')

In [8]:
parameters = {'C': (100, 1e3, 1e4, 1e5),
              'gamma': (1e-08, 1e-7, 1e-6, 1e-5)
             }

In [9]:
from sklearn.model_selection import GridSearchCV
grid_search = GridSearchCV(svc,parameters, n_jobs=-1, cv=3)

In [10]:
import timeit
start_time = timeit.default_timer()
grid_search.fit(X_train, Y_train)
print("--- %0.3fs seconds ---" % (timeit.default_timer() - start_time))

--- 9.051s seconds ---


In [11]:
print(grid_search.best_params_)
print(grid_search.best_score_)

{'C': 100000.0, 'gamma': 1e-07}
0.944705882353


In [12]:
svc_best = grid_search.best_estimator_
accuracy = svc_best.score(X_test, Y_test)
print("The accuracy on testing set is : {0:.1f}%".format(accuracy*100))

The accuracy on testing set is : 95.5%


In [13]:
prediction = svc_best.predict(X_test)
from sklearn.metrics import classification_report
report = classification_report(Y_test, prediction)
print(report)

             precision    recall  f1-score   support

        1.0       0.96      0.98      0.97       324
        2.0       0.89      0.91      0.90        65
        3.0       1.00      0.78      0.88        36

avg / total       0.96      0.96      0.95       425

