In [1]:
'''
使用KNN做车型分析
'''
# 导入包
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report
import numpy as np
from tqdm import tqdm

#### 加载数据集

In [2]:
# Data preprocessing: resize every image to 64x64, convert it to a tensor and
# map each RGB channel from [0, 1] to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Tunable settings for the split. SEED makes the train/test partition (and
# therefore the classification report further down) reproducible across runs.
DATA_DIR = '../TypeData'
TRAIN_FRACTION = 0.75
SEED = 42

# ImageFolder derives one class per sub-directory of DATA_DIR.
dataset = ImageFolder(DATA_DIR, transform=transform)
train_size = int(TRAIN_FRACTION * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SEED),
)

# KNN is fitted in one shot, not by mini-batch optimisation, so shuffling the
# training loader adds nothing except run-to-run variation in the row order of
# the flattened feature matrix built below.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

#### 特征提取：将图像展平为特征向量

In [3]:
def flatten_data(dataloader):
    """Flatten every image batch from ``dataloader`` into feature rows.

    Each batch of images is reshaped to ``(batch_size, -1)`` so that every
    image becomes one row of pixel values, the 2-D layout scikit-learn
    estimators such as ``KNeighborsClassifier`` consume.

    Args:
        dataloader: iterable yielding ``(images, labels)`` tensor pairs.

    Returns:
        ``(features, targets)``: all flattened batches stacked row-wise, and
        all label batches concatenated into one 1-D array.
    """
    feature_batches = []
    label_batches = []
    for batch_images, batch_labels in tqdm(dataloader, desc="Flattening"):
        # Collapse every non-batch dimension, e.g. (B, 3, 64, 64) -> (B, 3*64*64).
        flat_images = batch_images.view(batch_images.size(0), -1)
        feature_batches.append(flat_images.numpy())
        label_batches.append(batch_labels.numpy())
    features = np.vstack(feature_batches)
    targets = np.concatenate(label_batches)
    return features, targets


# Materialise both splits as dense (n_samples, n_features) matrices plus label vectors.
X_train, y_train = flatten_data(train_loader)
X_test, y_test = flatten_data(test_loader)

# Sanity checks: one label per row, every sample of each split was flattened,
# and both splits share the same feature width (KNN compares test rows against
# train rows feature-by-feature).
assert X_train.shape[0] == y_train.shape[0] == len(train_dataset)
assert X_test.shape[0] == y_test.shape[0] == len(test_dataset)
assert X_train.shape[1] == X_test.shape[1]

Flattening: 100%|██████████| 23/23 [00:01<00:00, 16.68it/s]
Flattening: 100%|██████████| 8/8 [00:00<00:00, 17.99it/s]


#### 训练模型

In [4]:
# k for the k-nearest-neighbours majority vote; worth experimenting with other values.
N_NEIGHBORS = 3
knn = KNeighborsClassifier(n_neighbors=N_NEIGHBORS)
knn.fit(X_train, y_train)

In [5]:
# Evaluate the fitted classifier on the held-out split.
y_pred = knn.predict(X_test)
print("\nKNN 分类报告：")
# Pass `labels` explicitly so the report rows line up with `target_names`
# (ImageFolder indexes classes 0..n-1 in `dataset.classes` order) even when a
# random split leaves some class absent from both y_test and y_pred; without
# it, classification_report raises a ValueError on the length mismatch.
print(classification_report(
    y_test,
    y_pred,
    labels=np.arange(len(dataset.classes)),
    target_names=dataset.classes,
))


KNN 分类报告：
              precision    recall  f1-score   support

         bus       0.97      0.94      0.95        31
         car       1.00      0.97      0.98        29
     minibus       0.94      1.00      0.97        31
       truck       0.94      0.94      0.94        31

    accuracy                           0.96       122
   macro avg       0.96      0.96      0.96       122
weighted avg       0.96      0.96      0.96       122

