In [None]:
'''Kmeans'''

In [3]:
from sklearn.cluster import KMeans
import numpy as np
from scipy import stats
import pickle
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, label_binarize
from sklearn.metrics import accuracy_score
from sklearn.multiclass import OneVsRestClassifier
from sklearn.metrics import confusion_matrix, precision_score, accuracy_score, recall_score
from sklearn.metrics import f1_score, roc_auc_score, roc_curve, auc
import matplotlib.pyplot as plt

In [4]:
'''加载数据集&数据预处理'''
def loadDataSet():
    # 加载CIFAR-10数据集
    def unpickle(file):
        with open(file, 'rb') as fo:
            dict = pickle.load(fo, encoding='bytes')
        return dict

    def load_cifar10_batch(file):
        batch = unpickle(file)
        images = batch[b'data']
        labels = batch[b'labels']
        images = images.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1).astype("float32")
        labels = np.array(labels)
        return images, labels

    # 加载训练集和测试集
    train_images = []
    train_labels = []
    for i in range(1, 6):
        file = f'C:/Users/ClairDeLune/Desktop/机器学习大作业/data-CIFAR-10/cifar-10-batches-py/data_batch_{i}'
        images, labels = load_cifar10_batch(file)
        train_images.append(images)
        train_labels.append(labels)

    train_images = np.concatenate(train_images)
    train_labels = np.concatenate(train_labels)

    test_images, test_labels = load_cifar10_batch('C:/Users/ClairDeLune/Desktop/机器学习大作业/data-CIFAR-10/cifar-10-batches-py/test_batch')

    # 展平图像
    train_images = train_images.reshape(-1, 3 * 32 * 32)
    test_images = test_images.reshape(-1, 3 * 32 * 32)

    # 标准化特征
    scaler = StandardScaler()
    train_images = scaler.fit_transform(train_images)
    test_images = scaler.transform(test_images)

    return train_images, test_images, train_labels, test_labels

In [5]:
'''定义训练函数、预测函数及ROC曲线绘制函数'''

def test(model, x_test, y_test):
    y_pred = model.predict(x_test)
    con_matrix = confusion_matrix(y_test, y_pred)
    print('confusion_matrix:\n', con_matrix)
    print('accuracy:{}'.format(accuracy_score(y_test, y_pred)))
    print('precision:{}'.format(precision_score(y_test, y_pred, average='micro')))
    print('recall:{}'.format(recall_score(y_test, y_pred, average='micro')))
    print('f1-score:{}'.format(f1_score(y_test, y_pred, average='micro')))

def plot_roc_curve(model, x_test, y_test, n_classes):
    # 计算测试集上的概率预测
    y_prob = model.predict_proba(x_test)
    
    # 计算每个类别的ROC曲线和AUC
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(y_test == i, y_prob[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])
    
    # 绘制所有类别的ROC曲线
    plt.figure(figsize=(8, 6))
    for i in range(n_classes):
        plt.plot(fpr[i], tpr[i], label=f'Class {i + 1} (area = {roc_auc[i]:.2f})')
    
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic')
    plt.legend(loc="lower right")
    plt.show()

In [6]:
X_train, X_test, y_train, y_test = loadDataSet()

In [None]:
n_classes = 10
model = KMeans(n_clusters=n_classes, random_state=42)
model.fit(X_train)

In [11]:
def cluster_indices(cluster_num, labels):
    return np.where(labels == cluster_num)[0]

def classification_accuracy(predicted_labels, true_labels):
    return np.mean(predicted_labels == true_labels)

def visualize(cluster_centroids):
    plt.figure(figsize=(8, 8))
    for i, centroid in enumerate(cluster_centroids):
        plt.subplot(5, 5, i + 1)
        plt.imshow(centroid.reshape(8, 8), cmap='gray')
        plt.axis('off')
    plt.show()

In [17]:
test(model, X_test, y_test)

confusion_matrix:
 [[184  33  81  63 109 105  42  33 211 139]
 [ 25 106 157  92 185  98 141  59  42  95]
 [109 119  45 251  69  54 161  66  84  42]
 [ 91 119  59 143 119  82 155 152  62  18]
 [ 73 120  51 243  49  39 213 172  20  20]
 [134  88  37 162 181  71 102 155  39  31]
 [ 31 210  27 210  93  86 207  86  39  11]
 [ 60  58 136 202  89 111 131 137  23  53]
 [ 94  33 193  45 168  32  31  39  39 326]
 [ 18  18 329  82  82 132 124  38  37 140]]
accuracy:0.1121
precision:0.1121
recall:0.1121
f1-score:0.1121
