这个文件用来训练Segmentation模型

In [1]:
import torch
from torch import nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch import optim
from utils import *

In [2]:
# Build a Faster R-CNN detector with a ResNet-50 FPN backbone.
# NOTE: pretrained=False is passed, so no pretrained detection weights are
# requested here.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False)

# Replace the box-prediction head with a fresh one sized for two classes
# (background and word), reusing the input width of the existing classifier.
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(
    in_features,
    num_classes=2,
)



In [3]:
class ModifiedDataset(torch.utils.data.Dataset):
    """Adapt a box-annotated dataset to a detection-style target format.

    Each item of the wrapped dataset is unpacked as ``(img, label)``, where
    ``label`` is indexed as a 2-D tensor with one ``(x, y, w, h)`` box per
    row. Items produced by this wrapper are
    ``(img, {'boxes': xyxy_boxes, 'labels': ones})``.
    """

    def __init__(self, dataset):
        """Store the wrapped dataset.

        Args:
            dataset: Sized, indexable collection of ``(img, label)`` pairs.
        """
        self.dataset = dataset
        # Not used by __getitem__ (which standardizes each image by its own
        # statistics); kept so the attribute remains available to callers.
        self.trans = transforms.Normalize(mean=(0.5), std=(0.5))

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the standardized image and its detection target.

        Args:
            idx: Index forwarded to the wrapped dataset.

        Returns:
            A pair ``(img, target)`` where ``img`` has its overall mean
            subtracted and is divided by its overall standard deviation, and
            ``target`` is a dict with ``'boxes'`` (rows of
            ``(x1, y1, x2, y2)``) and ``'labels'`` (a ``torch.long`` tensor
            of ones, one per kept box).
        """
        img, label = self.dataset[idx]

        # Normalize out of place: the wrapped dataset may hand back the same
        # tensor object on every access, and it must not be modified.
        img = (img - img.mean()) / img.std()

        # Convert (x, y, w, h) to (x1, y1, x2, y2) on a private copy. Doing
        # this in place on a tensor owned by the wrapped dataset would add the
        # origin to the size again on every later access.
        boxes = label.clone()
        boxes[:, 2] = boxes[:, 0] + boxes[:, 2]
        boxes[:, 3] = boxes[:, 1] + boxes[:, 3]

        # Keep only rows whose coordinates sum to a positive value; all-zero
        # rows are dropped.
        boxes = boxes[boxes.sum(dim=-1) > 0]

        # Every remaining box gets class index 1.
        labels = torch.ones(len(boxes), dtype=torch.long)

        return img, {'boxes': boxes, 'labels': labels}


from torch.utils.data.dataloader import default_collate

def collate_fn(batch):
    """Collate ``(image, target)`` samples into a batch.

    Images are merged with ``default_collate``; this relies on all images in
    the batch being collatable together (same size). Targets are returned as
    a plain list, one entry per sample, because each one may hold a different
    number of bounding boxes.

    Args:
        batch: Sequence of ``(image, target)`` pairs.

    Returns:
        A pair ``(images, targets)``: the collated images and the list of
        per-sample targets in batch order.
    """
    image_list = []
    target_list = []
    for sample in batch:
        image_list.append(sample[0])
        target_list.append(sample[1])

    return default_collate(image_list), target_list

In [4]:
# Wrap the 'IAM' / 'train' SegDataset so each item carries a detection-style
# target dict.
dataset = ModifiedDataset(SegDataset('IAM', 'train'))

# Shuffled batches of 16, loaded in the main process; collate_fn keeps the
# per-image targets as a list.
dataloader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn,
)

In [5]:
# Train on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

# SGD with momentum and weight decay over all model parameters.
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0005)

# Put the model in training mode before the forward passes below.
model.train()

num_epochs = 10

for epoch in range(num_epochs):
    # Running totals so the end-of-epoch line reports the mean loss over the
    # epoch instead of repeating the last step's loss.
    epoch_loss = 0.0
    num_steps = 0

    for step, (images, targets) in enumerate(dataloader):
        # The model is called with a list of image tensors and a list of
        # target dicts, all moved to the training device.
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # The forward pass yields a dict of loss terms; the training
        # objective is their sum.
        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())

        # Backpropagate and update the parameters.
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        step_loss = losses.item()
        epoch_loss += step_loss
        num_steps += 1
        print(f"Epoch {epoch} step {step} loss: {step_loss}")

    # Guard against a dataloader that yields no batches (previously a
    # NameError on the undefined `losses`).
    mean_loss = epoch_loss / num_steps if num_steps else float('nan')
    print(f"Epoch {epoch} loss: {mean_loss}")