In [2]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision import models, datasets

# Dataset wrapper
class CustomDataset(Dataset):
    """Image-classification dataset backed by ``torchvision.datasets.ImageFolder``.

    Indexing and length are delegated to the wrapped ``ImageFolder``, so each
    item is whatever ``ImageFolder`` returns for that index (an
    ``(image, class_index)`` pair per the torchvision documentation).

    Args:
        data_dir: Directory passed to ``ImageFolder`` as ``root``.
        transform: Optional callable applied to every loaded image. When it
            is ``None`` (the default) the images are resized to 480x640 and
            converted to tensors, which is the pipeline this class always
            used before the parameter existed.
    """

    def __init__(self, data_dir, transform=None):
        if transform is None:
            # Default preprocessing: fixed-size resize, then PIL -> tensor.
            transform = transforms.Compose([
                transforms.Resize((480, 640)),
                transforms.ToTensor(),
            ])
        # The wrapped ImageFolder does all of the actual loading.
        self.data = datasets.ImageFolder(root=data_dir, transform=transform)
        # Class names, re-exported from the wrapped ImageFolder.
        self.classes = self.data.classes

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx`` of the wrapped dataset."""
        return self.data[idx]

# Feature-pyramid-style classification network
class FeaturePyramidNetwork(nn.Module):
    """Classifier made of a convolutional trunk plus a small stack of convs.

    The trunk is ``base_model`` without its last two children. Its output is
    projected to 256 channels with a 1x1 convolution, passed through three
    3x3 convolutions, averaged over the two spatial dimensions and fed to a
    linear classifier.

    Args:
        base_model: Backbone module. Assumed to be laid out like a
            torchvision ResNet, i.e. its last two children are the global
            average pool and the final ``fc`` layer.
        num_classes: Number of output logits.
    """

    def __init__(self, base_model, num_classes):
        super().__init__()

        # Keep only the convolutional trunk. Dropping just the last child
        # would leave the global average pool in place, which collapses the
        # feature map to 1x1: the 3x3 convs below would then see nothing but
        # zero padding around a single pixel and the spatial mean in
        # forward() would be a no-op. The pool has no parameters, so the
        # state-dict keys are the same either way.
        self.base_model = nn.Sequential(*list(base_model.children())[:-2])

        # Channel count of the trunk output, read from the backbone's fc
        # layer when it has one; 2048 (the previous hard-coded value) is the
        # fallback.
        backbone_fc = getattr(base_model, "fc", None)
        trunk_channels = getattr(backbone_fc, "in_features", 2048)

        # 1x1 projection down to the 256 channels used by the pyramid convs.
        self.additional_conv = nn.Conv2d(trunk_channels, 256, kernel_size=1)

        # Convolutions applied one after another on the projected features;
        # padding=1 keeps the spatial size unchanged.
        # NOTE(review): there is no non-linearity between these layers, so
        # together they act as a single linear map; consider adding
        # activations.
        self.pyramid_conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.pyramid_conv2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.pyramid_conv3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)

        # Classifier
        self.classifier = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return class logits for ``x``.

        For a batched 4-D input the result has one row per sample and
        ``num_classes`` columns.
        """
        # Feature extraction and channel projection
        features = self.additional_conv(self.base_model(x))

        # Sequential pyramid convolutions
        level1 = self.pyramid_conv1(features)
        level2 = self.pyramid_conv2(level1)
        level3 = self.pyramid_conv3(level2)

        # Global average pooling over the two spatial dimensions
        pooled = level3.mean(dim=(2, 3))

        # Classification
        return self.classifier(pooled)

# Data location and hyperparameters
data_dir = "./MO_106/"
batch_size = 16
num_epochs = 10
test_fraction = 0.2  # share of the images held out for evaluation

# Build the dataset once and take the class count from it, so the classifier
# width always matches the labels the dataset actually produces.
dataset = CustomDataset(data_dir)
num_classes = len(dataset.classes)

# Hold out part of the data so that the reported accuracy is measured on
# images the model was not trained on. The fixed seed makes the split
# reproducible between runs.
num_test = int(len(dataset) * test_fraction)
num_train = len(dataset) - num_test
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [num_train, num_test],
    generator=torch.Generator().manual_seed(0),
)
data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_data_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Pretrained ResNet-50 as the backbone.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
# of `weights=`; switch once the minimum torchvision version is known.
# NOTE(review): ImageNet-pretrained weights normally expect mean/std
# normalized inputs, and CustomDataset's default transform does not
# normalize; consider adding transforms.Normalize.
base_model = models.resnet50(pretrained=True)

# Build the feature-pyramid network
fpn = FeaturePyramidNetwork(base_model, num_classes)

# Pick the best available accelerator instead of assuming Apple MPS exists.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
fpn.to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(fpn.parameters(), lr=0.001)

# Training loop
for epoch in range(num_epochs):
    fpn.train()
    running_loss = 0.0
    num_batches = 0

    for images, labels in data_loader:
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = fpn(images)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    # Report the mean loss over the epoch rather than the last batch only;
    # max() keeps this safe when the loader produced no batches.
    epoch_loss = running_loss / max(num_batches, 1)
    print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, epoch_loss))

# Save the trained weights
torch.save(fpn.state_dict(), "fpn_model.pth")

# Evaluate on the held-out split
fpn.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_data_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = fpn(images)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

if total:
    print('Test Accuracy: {:.2f}%'.format(100 * correct / total))
else:
    # Happens when the dataset is too small for the held-out split.
    print('Test Accuracy: n/a (no held-out samples)')


KeyboardInterrupt: 