In [None]:
import torch
from torchvision import transforms
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.datasets import VOCDetection
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np

# Image preprocessing: convert each PIL image to a float tensor
# (ToTensor scales pixel values to [0, 1]); nothing else is applied here.
transform = transforms.Compose([transforms.ToTensor()])

# Compute device: use the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def custom_collate_fn(batch, label_map=None, target_device=None):
    """Collate VOCDetection samples into the (images, targets) pair used for training.

    Args:
        batch: list of ``(image, annotation)`` pairs as yielded by ``VOCDetection``,
            where ``annotation["annotation"]["object"]`` holds the parsed objects,
            each with a ``"name"`` and a ``"bndbox"`` of string coordinates.
        label_map: mapping from VOC class name to integer label. Defaults to the
            module-level ``class_to_idx`` (resolved at call time).
        target_device: device the target tensors are moved to. Defaults to the
            module-level ``device``.

    Returns:
        ``(images, targets)``: ``images`` is a list of the image tensors, unchanged
        and not moved. ``targets`` is a list of dicts with ``"boxes"`` (float32,
        shape ``(N, 4)``, ``[xmin, ymin, xmax, ymax]``) and ``"labels"`` (int64,
        shape ``(N,)``).

    Raises:
        KeyError: if an object's class name is missing from ``label_map``.
    """
    if label_map is None:
        label_map = class_to_idx
    if target_device is None:
        target_device = device

    images = []
    targets = []
    for img, annotation in batch:
        images.append(img)
        objects = annotation["annotation"]["object"]
        # Defensive: if the XML parser ever yields a single object as a bare dict
        # instead of a one-element list, normalize it to a list.
        if isinstance(objects, dict):
            objects = [objects]

        boxes = [
            [float(obj["bndbox"][key]) for key in ("xmin", "ymin", "xmax", "ymax")]
            for obj in objects
        ]
        labels = [label_map[obj["name"]] for obj in objects]  # class name -> integer label

        # reshape(-1, 4) keeps an image without objects as a (0, 4) tensor;
        # torchvision's detection models validate target boxes as [N, 4].
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4).to(target_device)
        labels = torch.as_tensor(labels, dtype=torch.int64).to(target_device)

        targets.append({"boxes": boxes, "labels": labels})
    return images, targets


# Root directory passed to VOCDetection; it is expected to contain
# VOCdevkit/VOC2007/ (download=False below, so the data must already exist).
data_path = 'VOCtrainval_06-Nov-2007'

# Pascal VOC 2007 detection splits: "train" for training, "val" for evaluation.
train_dataset = VOCDetection(
    root=data_path,
    year='2007',
    image_set='train',
    download=False,
    transform=transform,
)
test_dataset = VOCDetection(
    root=data_path,
    year='2007',
    image_set='val',
    download=False,
    transform=transform,
)

# Training batches are shuffled pairs; evaluation goes one image at a time.
# NOTE(review): num_workers=0 keeps collation in the main process, and
# custom_collate_fn moves targets to `device`. Keep it that way unless the
# collate function stops touching the GPU.
train_loader = DataLoader(
    train_dataset,
    batch_size=2,
    shuffle=True,
    num_workers=0,
    collate_fn=custom_collate_fn,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    collate_fn=custom_collate_fn,
)

# Pascal VOC class names, indexed by integer label. torchvision's Faster R-CNN
# reserves label 0 for the background class, so background sits at index 0
# and the 20 object classes get labels 1..20.
VOC_CLASSES = (
    '__background__',  # always index 0
    'aeroplane', 'bicycle', 'bird', 'boat',
    'bottle', 'bus', 'car', 'cat', 'chair',
    'cow', 'diningtable', 'dog', 'horse',
    'motorbike', 'person', 'pottedplant',
    'sheep', 'sofa', 'train', 'tvmonitor')
# Class name -> integer label (consumed by custom_collate_fn).
class_to_idx = {class_name: idx for idx, class_name in enumerate(VOC_CLASSES)}


# Load a pretrained Faster R-CNN and swap its box predictor for one sized to VOC.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=`; kept here for compatibility with older versions.
model = fasterrcnn_resnet50_fpn(pretrained=True)
# Number of output classes *including* background: 20 VOC classes + 1 = 21.
num_classes = len(VOC_CLASSES)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)


model.to(device)

# Optimizer: SGD with momentum and weight decay over the trainable parameters
# only. Scheduler: StepLR multiplies the learning rate by 0.1 every 3 steps.
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=3,
    gamma=0.1,
)

# Train the model for `num_epochs` passes over train_loader. The mean loss of
# each epoch is printed and recorded in `loss_all` for plotting.
num_epochs = 10
loss_all = []  # mean training loss per epoch (the same value that is printed)
for epoch in range(num_epochs):
    model.train()
    all_loss = 0.0
    num_batches = len(train_loader)
    for i, (images, targets) in enumerate(train_loader):
        # custom_collate_fn already put the targets on `device`; move the images too.
        images = [image.to(device) for image in images]
        loss_dict = model(images, targets)
        # In train mode the model returns a dict of loss terms; optimize their sum.
        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        all_loss += losses.item()
        if i % 100 == 0:
            print(f"小epoch:  [{i}/{num_batches}]")
    # Guard against an empty loader instead of dividing by zero.
    epoch_loss = all_loss / num_batches if num_batches else 0.0
    loss_all.append(epoch_loss)
    # Advance the learning-rate schedule once per epoch.
    lr_scheduler.step()
    print('====================================================')
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}")

# Save the weights before plotting, so a blocking or failing plot window
# cannot cost the trained model.
torch.save(model.state_dict(), 'fastercnn_model.pth')

# Plot the per-epoch mean loss against 1-based epoch numbers (matching the log).
plt.plot(range(1, len(loss_all) + 1), loss_all)
plt.xlabel('epoch')
plt.ylabel('loss')
plt.show()


# Run inference on the validation split (test_loader). Draw detections scoring
# above `score_threshold` for the first `max_images` images.
score_threshold = 0.5  # only detections scoring strictly above this are drawn
max_images = 40        # stop after visualizing this many images in total

model.eval()
shown = 0
with torch.no_grad():
    for images, _targets in test_loader:
        # Count images across batches; a per-batch index never reaches the
        # limit when batch_size is 1.
        if shown >= max_images:
            break
        images = [image.to(device) for image in images]
        predictions = model(images)

        for image, prediction in zip(images, predictions):
            if shown >= max_images:
                break
            img = image.permute(1, 2, 0).cpu().numpy()  # CHW -> HWC for imshow
            boxes = prediction['boxes'].cpu().numpy()
            scores = prediction['scores'].cpu().numpy()
            labels = prediction['labels'].cpu().numpy()

            # Show the image.
            fig, ax = plt.subplots(1)
            ax.imshow(img)

            # Overlay confident detections as red boxes with class-name labels.
            for box, score, label in zip(boxes, scores, labels):
                if score <= score_threshold:
                    continue
                box = box.astype(np.int32)
                rect = patches.Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1],
                                         linewidth=1, edgecolor='r', facecolor='none')
                ax.add_patch(rect)

                # VOC_CLASSES is indexed by label; fall back to the raw id if it is
                # out of range, so a mismatched head cannot raise IndexError here.
                class_name = VOC_CLASSES[label] if 0 <= label < len(VOC_CLASSES) else str(label)
                ax.text(box[0], box[1], class_name, color='red')

            ax.axis('off')
            plt.show()
            plt.close(fig)  # release the figure so memory does not grow per image
            shown += 1

小epoch:  [0/1251]
小epoch:  [100/1251]
小epoch:  [200/1251]
小epoch:  [300/1251]
小epoch:  [400/1251]
小epoch:  [500/1251]
小epoch:  [600/1251]
小epoch:  [700/1251]
小epoch:  [800/1251]
小epoch:  [900/1251]
小epoch:  [1000/1251]
小epoch:  [1100/1251]
小epoch:  [1200/1251]
Epoch [1/10], Loss: 0.5372
小epoch:  [0/1251]
小epoch:  [100/1251]
小epoch:  [200/1251]
小epoch:  [300/1251]
小epoch:  [400/1251]
小epoch:  [500/1251]
小epoch:  [600/1251]
小epoch:  [700/1251]
小epoch:  [800/1251]
小epoch:  [900/1251]
小epoch:  [1000/1251]
小epoch:  [1100/1251]
小epoch:  [1200/1251]
Epoch [2/10], Loss: 0.3697
小epoch:  [0/1251]
小epoch:  [100/1251]
小epoch:  [200/1251]
小epoch:  [300/1251]
小epoch:  [400/1251]
小epoch:  [500/1251]
小epoch:  [600/1251]
小epoch:  [700/1251]
小epoch:  [800/1251]
小epoch:  [900/1251]
小epoch:  [1000/1251]
小epoch:  [1100/1251]
小epoch:  [1200/1251]
Epoch [3/10], Loss: 0.3174
小epoch:  [0/1251]
小epoch:  [100/1251]
小epoch:  [200/1251]
小epoch:  [300/1251]
小epoch:  [400/1251]
小epoch:  [500/1251]
小epoch:  [600/1251