In [2]:
##优化过拟合版本
import glob
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image, ImageFile
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from torchvision import models
from torchvision import transforms as T
from torchvision.transforms.functional import InterpolationMode, resize
from torchvision.models import ResNet50_Weights
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path


# 防止 PIL 出错,尽量读取jpg
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [3]:
class Net(nn.Module):
    """ 修改后模型架构，增加 Dropout 防止过拟合 """

    def __init__(self):
        super().__init__()
        # self.net = models.resnet50(weights=ResNet50_Weights.DEFAULT)
        self.net = models.resnet50()
        
        # 修改 ResNet50 的全连接层（fc）
        self.net.fc = nn.Sequential(
            nn.Linear(2048, 1024),  # 先缩小维度
            nn.ReLU(),
            nn.Dropout(0.5),  # 添加 Dropout
            nn.Linear(1024, 58)  # 输出类别数
        )

    def forward(self, X):
        return self.net(X)


In [4]:
class MyDataset(Dataset):
    classes = [
        '5kilometer',
        '15kilometer',
        '30kilometer',
        '40kilometer',
        '50kilometer',
        '60kilometer',
        '70kilometer',
        '90kilometer',
        'No Left Turn or Straight Ahead',
        'No Right Turn or Straight Ahead',
        'No Straight Ahead',
        'No Left Turn',
        'No Left or Right Turn',
        'No Right Turn',
        'No Overtaking',
        'No U-turn',
        'No Entry for Motor Vehicles',
        'No Horn',
        'End Speed Limit 40',
        'End Speed Limit 50',
        'Turn Right or Go Straight Ahead',
        'Ahead Only',
        'Left Turn Only',
        'Left or Right Turn Only',
        'Right Turn Only',
        'Keep Left',
        'Keep Right',
        'Roundabout',
        'Motor Vehicles Only',
        'Sound Horn',
        'Bicycles Only',
        'U-turn Only',
        'Divided Road Ahead',
        'Traffic Signals Ahead',
        'General Warning',
        'Pedestrian Crossing Ahead',
        'Cyclists Ahead',
        'Children Crossing Ahead',
        'Right Curve Ahead',
        'Left Curve Ahead',
        'Steep Descent',
        'Steep Ascent',
        'SLOW',
        'Side Road Junction Ahead',
        'Side Road Junction (left) Ahead',
        'Built-up Area Warning',
        'Winding Road Ahead',
        'train ahead',
        'Road Works Ahead',
        'Continuous sharp turn sign',
        'Railway level crossing',
        'Rear End Collision',
        'STOP',
        'No Entry for Vehicles',
        'No Stopping',
        'No Entry',
        'Give Way',
        'Stop - Police'
    ]

    label = {
        c: index
        for index, c in enumerate(classes)
    }

    def __init__(
        self,
        root=r'./data',
        is_train=True
    ):
        self.path = os.path.join(root, 'train' if is_train else 'validation')
        self.images = []
        self.merge_images()
        self.t = T.Compose([
            T.RandomResizedCrop(224, scale=(0.8, 1.0)),  # 随机裁剪
            T.RandomHorizontalFlip(p=0.5),  # 随机水平翻转
            T.RandomRotation(15),  # 轻微旋转
            T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # 颜色抖动
            T.RandomAffine(degrees=0, translate=(0.1, 0.1)),  # 轻微平移
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 归一化
        ])


    def merge_images(self):
        valid_exts = [".jpg", ".jpeg", ".png", ".bmp"]
        for c in self.classes:
            p = Path(self.path) / c
            if not p.exists():
                continue
            for img in p.iterdir():
                if img.suffix.lower() in valid_exts:
                    self.images.append((str(img), c))


    def __getitem__(self, index):
        image_path, image_label = self.images[index]
        image_label = self.label[image_label]
        try:
            image = Image.open(image_path).convert("RGB")  # 确保是 RGB
            image = self.t(image)
            return image, torch.LongTensor([image_label])
        except Exception as e:
            print(f"Warning: Skipping corrupted image {image_path}. Error: {e}")
            return self.__getitem__((index + 1) % len(self.images))


    def __len__(self):
        return len(self.images)

In [5]:
def dataload(
    is_train=True,
    batch_size=64
):
    dataset = MyDataset(is_train=is_train)
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=is_train
    )


def train_batch(
    batch,
    optimizer,
    loss,
    net
):
    optimizer.zero_grad()
    if torch.cuda.is_available():
        batch = [i.cuda() for i in batch]
    X, Y = batch
    Y_hat = net(X)
    #Y_hat = F.softmax(Y_hat, -1)
    l = loss(Y_hat, Y.flatten())
    l.sum().backward()
    optimizer.step()
    with torch.no_grad():
        return l.sum().cpu()
    
def test_batch(
    batch,
    net,
    loss
):
    if torch.cuda.is_available():
        batch = [i.cuda() for i in batch]
    X, Y = batch
    Y_hat = net(X)
    #Y_hat = F.softmax(Y_hat, -1)
    l = loss(Y_hat, Y.flatten())
    return l.sum().cpu()


def save_checkpoint(net, path, epoch):
    torch.save(net.state_dict(), path)
    print(f'已保存模型 (Epoch {epoch})')  #  添加 epoch 信息

In [6]:
def train(
    epoch,
    lr=1e-4, #更改
    train_batch_size=32,
    test_batch_size=32,
    path='parameters_Resnet50.cpt'
):
    net = Net()
    if torch.cuda.is_available():
        net.cuda()
    loss = nn.CrossEntropyLoss()
    optimizer = Adam(net.parameters(), lr=lr, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epoch) #更改
    dataset = dataload(is_train=True, batch_size=train_batch_size)
    valid_dataset = dataload(is_train=False, batch_size=test_batch_size)
    
    writer = SummaryWriter(log_dir='logs')  # 指定日志目录

    min_mean_loss = None
    min_valid_loss = None

    for i in range(epoch):
        print(f'epoch: {i+1}/{epoch}')
        # 打印当前学习率，便于监控
        current_lr = optimizer.param_groups[0]['lr']
        print(f'当前学习率: {current_lr:.6f}')
        
        mean_loss = []
        mean_valid_loss = []
        net.train()
        for batch in dataset:
            l = train_batch(
                batch=batch,
                optimizer=optimizer,
                loss=loss,
                net=net
            )
            mean_loss.append(l)

        with torch.no_grad(): #关闭梯度计算
            net.eval()
            for batch in valid_dataset:
                l = test_batch(batch=batch, net=net, loss=loss)
                mean_valid_loss.append(l)

        writer.add_scalar('Resnet50_Loss/train', np.mean(mean_loss), i+1)
        writer.add_scalar('Resnet50_Loss/test', np.mean(mean_valid_loss), i+1)
        
        print(f'train loss: {np.mean(mean_loss):.5f}')
        print(f'test loss: {np.mean(mean_valid_loss):.5f}')

        # if min_mean_loss is None or min_mean_loss > np.mean(mean_loss):
        #     save_checkpoint(net, path, i+1)  #  传递当前 epoch
        #     min_mean_loss = np.mean(mean_loss)
            
        if min_valid_loss is None or min_valid_loss > np.mean(mean_valid_loss): #更改
            save_checkpoint(net, path, i+1)  # 传递当前 epoch
            min_valid_loss = np.mean(mean_valid_loss)
            
        #让学习率衰减
        scheduler.step()
            
    writer.close()

In [7]:
def load_net_from_hdf(path):
    net = Net()  # 先实例化网络结构
    net.load_state_dict(torch.load(path, map_location=torch.device('cuda:0')))  # 只加载权重
    net.eval()
    return net

def predict(X, Y, net=None):
    return net(X).argmax(-1).flatten(), Y.flatten()

In [15]:
if __name__ == '__main__':
    # 训练
    train(
        epoch=100,
        train_batch_size=16,
        test_batch_size=16
    )


epoch: 1/100
当前学习率: 0.000100
train loss: 3.40104
test loss: 3.55096
已保存模型 (Epoch 1)
epoch: 2/100
当前学习率: 0.000100
train loss: 2.77887
test loss: 3.14914
已保存模型 (Epoch 2)
epoch: 3/100
当前学习率: 0.000100
train loss: 2.36102
test loss: 2.70288
已保存模型 (Epoch 3)
epoch: 4/100
当前学习率: 0.000100
train loss: 1.94401
test loss: 2.43751
已保存模型 (Epoch 4)
epoch: 5/100
当前学习率: 0.000100
train loss: 1.60678
test loss: 2.49246
epoch: 6/100
当前学习率: 0.000099
train loss: 1.33995
test loss: 2.60530
epoch: 7/100
当前学习率: 0.000099
train loss: 1.11906
test loss: 2.07296
已保存模型 (Epoch 7)
epoch: 8/100
当前学习率: 0.000099
train loss: 0.93252
test loss: 1.67700
已保存模型 (Epoch 8)
epoch: 9/100
当前学习率: 0.000098
train loss: 0.77976
test loss: 1.55289
已保存模型 (Epoch 9)
epoch: 10/100
当前学习率: 0.000098
train loss: 0.62543
test loss: 1.49299
已保存模型 (Epoch 10)
epoch: 11/100
当前学习率: 0.000098
train loss: 0.54161
test loss: 1.17674
已保存模型 (Epoch 11)
epoch: 12/100
当前学习率: 0.000097
train loss: 0.44220
test loss: 1.35473
epoch: 13/100
当前学习率: 0.000096
train

KeyboardInterrupt: 

In [8]:
path = 'parameters_Resnet50.cpt' 
net = load_net_from_hdf(path)

classes = [
    '5kilometer',
    '15kilometer',
    '30kilometer',
    '40kilometer',
    '50kilometer',
    '60kilometer',
    '70kilometer',
    '90kilometer',
    'No Left Turn or Straight Ahead',
    'No Right Turn or Straight Ahead',
    'No Straight Ahead',
    'No Left Turn',
    'No Left or Right Turn',
    'No Right Turn',
    'No Overtaking',
    'No U-turn',
    'No Entry for Motor Vehicles',
    'No Horn',
    'End Speed Limit 40',
    'End Speed Limit 50',
    'Turn Right or Go Straight Ahead',
    'Ahead Only',
    'Left Turn Only',
    'Left or Right Turn Only',
    'Right Turn Only',
    'Keep Left',
    'Keep Right',
    'Roundabout',
    'Motor Vehicles Only',
    'Sound Horn',
    'Bicycles Only',
    'U-turn Only',
    'Divided Road Ahead',
    'Traffic Signals Ahead',
    'General Warning',
    'Pedestrian Crossing Ahead',
    'Cyclists Ahead',
    'Children Crossing Ahead',
    'Right Curve Ahead',
    'Left Curve Ahead',
    'Steep Descent',
    'Steep Ascent',
    'SLOW',
    'Side Road Junction Ahead',
    'Side Road Junction (left) Ahead',
    'Built-up Area Warning',
    'Winding Road Ahead',
    'train ahead',
    'Road Works Ahead',
    'Continuous sharp turn sign',
    'Railway level crossing',
    'Rear End Collision',
    'STOP',
    'No Entry for Vehicles',
    'No Stopping',
    'No Entry',
    'Give Way',
    'Stop - Police'
]

# **计算最终的验证集准确率**
total_correct = 0
total_samples = 0

# **计算每个类别的正确率**
correct_counts = {c: 0 for c in classes}  # 统计每个类别预测正确的样本数
total_counts = {c: 0 for c in classes}    # 统计每个类别的总样本数

for X, Y in dataload(False, batch_size=16):
    Y_hat, Y = predict(X, Y, net)
    
    total_correct += (Y_hat == Y).sum().item()  # 总正确数
    total_samples += Y.shape[0]  # 总样本数
    
    for i in range(Y.shape[0]):
        true_class = classes[Y[i].item()]
        pred_class = classes[Y_hat[i].item()]

        total_counts[true_class] += 1  # 该类别总样本数+1
        if true_class == pred_class:
            correct_counts[true_class] += 1  # 该类别预测正确数+1

# **打印最终整体准确率**
accuracy = total_correct / total_samples * 100
print(f'Final Validation Accuracy: {accuracy:.2f}%')

# **打印每个类别的准确率**
print("\n分类正确率：")
for c in classes:
    class_acc = (correct_counts[c] / total_counts[c]) * 100 if total_counts[c] > 0 else 0
    print(f"{c}: {class_acc:.2f}% ({correct_counts[c]}/{total_counts[c]})")



Final Validation Accuracy: 85.96%

分类正确率：
5kilometer: 85.71% (12/14)
15kilometer: 83.33% (10/12)
30kilometer: 98.33% (59/60)
40kilometer: 100.00% (84/84)
50kilometer: 91.38% (53/58)
60kilometer: 68.00% (34/50)
70kilometer: 100.00% (30/30)
90kilometer: 84.00% (42/50)
No Left Turn or Straight Ahead: 92.86% (13/14)
No Right Turn or Straight Ahead: 0.00% (0/2)
No Straight Ahead: 96.67% (58/60)
No Left Turn: 93.08% (121/130)
No Left or Right Turn: 100.00% (22/22)
No Right Turn: 78.26% (72/92)
No Overtaking: 100.00% (12/12)
No U-turn: 50.00% (18/36)
No Entry for Motor Vehicles: 97.37% (74/76)
No Horn: 78.57% (66/84)
End Speed Limit 40: 0.00% (0/0)
End Speed Limit 50: 0.00% (0/0)
Turn Right or Go Straight Ahead: 100.00% (2/2)
Ahead Only: 75.00% (9/12)
Left Turn Only: 62.50% (5/8)
Left or Right Turn Only: 60.00% (6/10)
Right Turn Only: 84.62% (22/26)
Keep Left: 100.00% (36/36)
Keep Right: 90.30% (121/134)
Roundabout: 100.00% (24/24)
Motor Vehicles Only: 89.71% (61/68)
Sound Horn: 100.00% (26/2