### 导入包及数据部分

In [None]:
import collections
from collections import defaultdict
from os.path import join

import matplotlib.pyplot as plt
import numpy as np  # 导入numpy，用于科学计算，如，矩阵运算
from sklearn.neighbors import KNeighborsClassifier  # 包装好的knn算法


def file2matrix(filename, num_features=5):
    """Load a tab-separated feature/label file into a matrix and a label list.

    Each non-blank line is stripped and split on tabs. Its first
    ``num_features`` columns become one row of the feature matrix and its
    last column is parsed as the class label.

    Args:
        filename: Path of the file to read.
        num_features: Number of leading columns used as features. Defaults
            to 5, which is what the classifiers in this notebook are
            trained on.

    Returns:
        ``(return_mat, class_label_vector)``. ``return_mat`` is a float
        ndarray of shape ``(n_rows, num_features)``. ``class_label_vector``
        is a list of ``n_rows`` floats.

    Raises:
        ValueError: If a line has fewer than ``num_features`` feature
            columns or a field is not numeric.
    """
    # Read the file once and close it. The previous version opened it twice
    # and leaked both handles.
    with open(filename) as fr:
        # Skip blank lines (e.g. a trailing newline). Otherwise they would
        # add an all-zero row or crash the float conversion.
        rows = [line.strip().split("\t") for line in fr if line.strip()]

    return_mat = np.zeros((len(rows), num_features))
    class_label_vector = []
    for index, fields in enumerate(rows):
        # The first num_features columns are features; the last is the label.
        return_mat[index, :] = [float(v) for v in fields[:num_features]]
        class_label_vector.append(float(fields[-1]))
    return return_mat, class_label_vector


### 生成KNN分类器

In [None]:
# Training set: features and labels from data/sample.dat.
data_x, data_y = file2matrix(join("data", "sample.dat"))

# Build and fit one KNN classifier per neighbourhood size (K = 3, 10, 24).
# KNeighborsClassifier.fit returns the estimator itself, so each name is
# bound to a fitted model.
kNN_classifier_3, kNN_classifier_10, kNN_classifier_24 = (
    KNeighborsClassifier(n_neighbors=k).fit(data_x, data_y)
    for k in (3, 10, 24)
)


In [None]:
def ModulationClassTest(SNR, method, n, labels, kNN_classifier, data_dir="data"):
    """Evaluate one classifier on the test files of one modulation type.

    For every value ``snr`` in ``SNR`` this reads the file
    ``<data_dir>/test<labels[method]>-<n>-<snr>.dat`` with ``file2matrix``
    and classifies every row with ``kNN_classifier``. It then prints the
    fraction of rows predicted as class ``method`` and how many rows went
    to each of the four classes.

    Args:
        SNR: Iterable of SNR values (dB) used to build the test file names.
        method: Index into ``labels`` of the modulation the test files
            contain. Rows predicted as this class count as correct.
        n: Value embedded in the test file names (the caller's ``N``).
        labels: Class names, ordered to match predicted label values
            2, 4, 16, 64 (BPSK, QPSK, 16QAM, 64QAM in the caller).
        kNN_classifier: A fitted estimator exposing ``predict``.
        data_dir: Directory holding the test files. Defaults to ``"data"``.

    Returns:
        An ``OrderedDict`` mapping each SNR value to its accuracy, sorted
        by SNR. An SNR whose test file has no rows maps to ``nan``.
    """
    # Predicted label value -> index into labels / numbers.
    # Index 0..3 is BPSK, QPSK, 16QAM, 64QAM.
    class_index_of = {2: 0, 4: 1, 16: 2, 64: 3}
    accuracy = {}
    # numbers[i][snr]: number of rows at that SNR predicted as class i.
    numbers = [{} for _ in range(4)]
    for snr in SNR:
        filename = (
            "test" + labels[method] + "-" + str(n) + "-" + str(snr) + ".dat"
        )
        testDataMat, _ = file2matrix(join(data_dir, filename))
        numTestVecs = testDataMat.shape[0]
        for counts in numbers:
            counts[snr] = 0.0
        if numTestVecs == 0:
            # Nothing to classify. predict() rejects empty input and the
            # accuracy would divide by zero, so record NaN (a gap in plots).
            accuracy[snr] = float("nan")
        else:
            # One batched predict() call instead of one call per row.
            for y_predict in kNN_classifier.predict(testDataMat):
                class_index = class_index_of.get(y_predict)
                # Labels outside the four known classes are not tallied for
                # any class, but they still count in the denominator.
                if class_index is not None:
                    numbers[class_index][snr] += 1.0
            accuracy[snr] = numbers[method][snr] / numTestVecs
        print("the total correct rate on %d dB SNR is:" % snr, (accuracy[snr]))
        for i in range(4):
            print(
                ("正确" if i == method else "") + "判断为",
                labels[i] + ":",
                numbers[i][snr],
            )
    # Sort once, after all SNRs are processed (previously re-sorted each pass).
    return collections.OrderedDict(sorted(accuracy.items()))


def ModulationClassTests(SNR, method, sample_sizes=(200, 500)):
    """Plot test accuracy against SNR for K = 3, 10 and 24 on one modulation.

    For every ``n`` in ``sample_sizes``, the module-level classifiers
    ``kNN_classifier_3``, ``kNN_classifier_10`` and ``kNN_classifier_24``
    are evaluated with ``ModulationClassTest``. Their accuracy dicts are
    printed and drawn as three curves on one subplot. The figure is then
    displayed with ``plt.show()``.

    Args:
        SNR: SNR values (dB) to test.
        method: Index of the modulation under test
            (0 = BPSK, 1 = QPSK, 2 = 16QAM, 3 = 64QAM).
        sample_sizes: Values of N used in the test file names, one subplot
            each. Defaults to ``(200, 500)``.
    """
    labels = ["BPSK", "QPSK", "16QAM", "64QAM"]
    # (K, fitted classifier, curve colour) for each plotted line.
    classifiers = (
        (3, kNN_classifier_3, "red"),
        (10, kNN_classifier_10, "green"),
        (24, kNN_classifier_24, "blue"),
    )
    # squeeze=False keeps axs 2-D even for a single subplot, so the code
    # works for any number of sample sizes.
    fig, axs = plt.subplots(
        len(sample_sizes), squeeze=False, figsize=(20, 8 * len(sample_sizes))
    )
    # fontsize was 3224, an obvious typo. 32 fits with the 24 / 18 sizes below.
    fig.suptitle("SNR vs Accuracy - " + labels[method], fontsize=32)

    for ax, n in zip(axs[:, 0], sample_sizes):
        for k, classifier, color in classifiers:
            print(f"K = {k}:")
            accuracy = ModulationClassTest(SNR, method, n, labels, classifier)
            print(accuracy)
            # accuracy is sorted by SNR, so use its own keys as x-values.
            # Pairing with SNR would misalign points when SNR is unsorted.
            ax.plot(
                list(accuracy.keys()),
                list(accuracy.values()),
                label=f"K = {k}",
                marker="o",
                linewidth=2.0,
                linestyle="dashed",
                color=color,
            )
        ax.set_title("N =" + str(n), fontsize=24)
        ax.legend(loc="upper left", frameon=False, fontsize=18)
        ax.grid()
        ax.set(
            # Tick at each tested SNR, so ticks align with the data points.
            xticks=sorted(set(SNR)),
            # Include 1.0 (perfect accuracy); arange(0, 1, 0.1) stopped at 0.9.
            yticks=np.linspace(0.0, 1.0, 11),
            xlabel="SNR (dB)",
            ylabel="Test accuracy",
        )

    plt.show()


### BPSK ModulationClassTest

In [None]:
# BPSK (method 0): SNR from -4 dB to 10 dB in 2 dB steps.
SNR = list(range(-4, 11, 2))
ModulationClassTests(SNR, method=0)

### QPSK ModulationClassTest

In [None]:
# QPSK (method 1): SNR from -4 dB to 10 dB in 2 dB steps.
SNR = list(range(-4, 11, 2))
ModulationClassTests(SNR, method=1)

### 16QAM ModulationClassTest

In [None]:
# 16QAM (method 2): SNR from 0 dB to 40 dB in 5 dB steps.
SNR = list(range(0, 41, 5))
ModulationClassTests(SNR, method=2)