In [1]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm
from sklearn.metrics import roc_curve, auc

import torch
import torch.utils as utils
from torchvision import datasets, transforms

In [2]:
dataset_name = 'train_data9'
dir_name = ['abnormal', 'normal']


def make_loader(root, batch_size, shuffle=False):
    """Build a DataLoader over the images found under ``root``.

    Each image is converted to a single grayscale channel, resized to 64x64,
    turned into a tensor and normalized with mean 0.5 / std 0.5.

    Args:
        root: Directory laid out the way ``datasets.ImageFolder`` expects.
        batch_size: Number of images per batch.
        shuffle: Whether the loader shuffles the images.

    Returns:
        The ``DataLoader``; the underlying dataset is ``loader.dataset``.
    """
    preprocess = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        # Grayscale() leaves exactly one channel, so Normalize must receive
        # one mean and one std rather than three.
        transforms.Normalize((0.5,), (0.5,))])
    image_set = datasets.ImageFolder(root=root, transform=preprocess)
    return utils.data.DataLoader(dataset=image_set, batch_size=batch_size, shuffle=shuffle)


def collect_images(loader):
    """Return all images yielded by ``loader`` as one flattened NumPy array.

    Every batch is kept and concatenated, so the result covers the whole
    dataset even when it is larger than the loader's batch size. The labels
    yielded by the loader are ignored.

    Returns:
        Array with one row per image; all remaining axes are flattened.
    """
    batches = [data.cpu().numpy() for data, _ in loader]
    images = np.concatenate(batches, axis=0)
    return np.reshape(images, (len(images), -1))


# Training images.
train_loader = make_loader('dataset/preproced_data/%s/'%dataset_name, batch_size=1000, shuffle=True)
train_set = train_loader.dataset
x_train = collect_images(train_loader)

# Held-out images from the 'normal' directory.
score_num = 1
test_loader = make_loader('dataset/dongdong2/%s/'%dir_name[score_num], batch_size=100)
test_set = test_loader.dataset
x_test = collect_images(test_loader)

# Held-out images from the 'abnormal' directory.
score_num = 0
test_loader = make_loader('dataset/dongdong2/%s/'%dir_name[score_num], batch_size=100)
test_set = test_loader.dataset
x_out = collect_images(test_loader)

print(np.shape(x_train), np.shape(x_test), np.shape(x_out))

(631, 4096) (25, 4096) (31, 4096)


In [7]:
def cal_AUC(sensitivity, specificity):
    """Approximate the area under a curve with right-hand rectangles.

    The x coordinate of each point is taken as ``1 - specificity``. The width
    of rectangle ``i`` is ``(1 - specificity)[i] - (1 - specificity)[i + 1]``
    and its height is ``sensitivity[i + 1]``; ``sensitivity[0]`` is not used.

    NOTE(review): the caller in this notebook passes FPR values as
    ``specificity``, which makes every width ``fpr[i + 1] - fpr[i]``. Points
    must therefore be ordered so that these widths are non-negative.

    Args:
        sensitivity: Array-like of curve heights (lists and tuples are accepted).
        specificity: Array-like of the same length as ``sensitivity``.

    Returns:
        The summed rectangle area.
    """
    # Coerce to float arrays so plain Python sequences work with `1 - x`.
    sensitivity = np.asarray(sensitivity, dtype=float)
    specificity = np.asarray(specificity, dtype=float)
    complement = 1 - specificity
    wid = complement[:-1] - complement[1:]
    area = np.sum(sensitivity[1:] * wid)
    return area

# Sweep nu for a one-class SVM and record one (FPR, TPR) operating point per
# setting. "Positive" here means "flagged as an outlier": x_out rows are the
# positives and x_test rows are the negatives.
fpr_list = []
tpr_list = []
for k in range(1,21):
    clf = svm.OneClassSVM(nu=0.05*k, kernel='rbf', gamma='auto')
    clf.fit(x_train)
    y_pred_test = clf.predict(x_test)
    y_pred_out = clf.predict(x_out)

    # Compute the rates directly from the hard predictions (anything other
    # than 1 counts as flagged). Indexing roc_curve(...)[1] is unreliable
    # here: when every prediction is identical the curve has only two points
    # and index 1 is always (1, 1), even if nothing was flagged.
    fpr_list.append(np.mean(y_pred_test != 1))
    tpr_list.append(np.mean(y_pred_out != 1))

# cal_AUC sums rectangles between consecutive points, so the points must be
# ordered by increasing FPR (ties broken by TPR) and anchored at (0, 0) and
# (1, 1) for the area to cover the whole [0, 1] range.
order = np.lexsort((tpr_list, fpr_list))
roc_fpr = np.concatenate(([0.0], np.asarray(fpr_list)[order], [1.0]))
roc_tpr = np.concatenate(([0.0], np.asarray(tpr_list)[order], [1.0]))

# Named auc_value so the `auc` function imported from sklearn.metrics is not
# overwritten. cal_AUC takes differences of (1 - second argument), so passing
# FPR there yields the FPR step widths.
auc_value = cal_AUC(roc_tpr, roc_fpr)
plt.plot(roc_fpr, roc_tpr, 'ro-')
plt.xlabel('1-Specificity  (FPR)', fontsize=12)
plt.ylabel('Sensitivity  (TPR)', fontsize=12)
plt.plot([0,1],[0,1],'k--')
plt.title('AUC : %.4f'%auc_value, fontsize=14)
plt.savefig('ROC_plot.png')
plt.show()

(631,) (25,) (31,)
[29.17122338 29.23231893 29.28678663 29.32727464 29.15823093 29.24488894
 29.27506313 29.21697635 29.29733379 29.28047626 29.3332268  29.18416487
 29.28047327 29.29501997 29.34726851 29.2309149  29.30331095 29.3381879
 29.35037389 29.25497877 29.25998913 29.21468286 29.12564626 29.15279672
 29.26954637] [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
(631,) (25,) (31,)
[59.02405302 59.14998599 59.28239266 59.31044581 58.99840828 59.19560473
 59.262406   59.13626645 59.27222376 59.27715642 59.35671049 59.06392532
 59.26716463 59.29712263 59.43025226 59.144433   59.34565805 59.35607962
 59.40087425 59.24471663 59.24378    59.15792894 58.89163904 59.03040411
 59.21208865] [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
(631,) (25,) (31,)
[89.22326906 89.45802973 89.64651101 89.66634467 89.17692037 89.52358923
 89.64587802 89.39231806 89.65365296 89.63870158 89.74786278 89.3596809
 89.64428653 89.68123998 89.88166701 89.40470509 89.77285356 89.76785921
 89.865047

KeyboardInterrupt: 