# In [None]:
import os
import random
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from svgpathtools import svg2paths
from tqdm import tqdm

# 1. 数据预处理和数据集准备
class FloorPlanDataset(Dataset):
    """Floor-plan segmentation dataset pairing PNG images with SVG annotations.

    Every ``<name>.png`` in ``image_dir`` is paired with an SVG in ``mask_dir``:
    an exact ``<name>.svg`` is preferred; otherwise a single SVG whose file name
    starts with ``<name>`` is used. Images with no candidate are skipped, and
    images with several ambiguous prefix candidates are skipped with a warning.

    Masks are rasterized from the SVG paths: 0 = background, 2 = paths whose
    attributes mention "wall", 1 = any other filled path.

    Args:
        image_dir: directory containing the ``.png`` images.
        mask_dir: directory containing the ``.svg`` annotations.
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            and returning a dict with ``'image'`` and ``'mask'`` keys
            (albumentations-style interface — TODO confirm). It receives the
            resized uint8 HWC image array and the label mask array.
        target_size: ``(width, height)`` both image and mask are resized to.

    Raises:
        ValueError: if no image/mask pair could be matched.
    """

    def __init__(self, image_dir, mask_dir, transform=None, target_size=(256, 256)):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.target_size = target_size

        # Sorted so the pairing (and thus any random_split indices) does not
        # depend on the arbitrary order returned by os.listdir.
        self.image_files = sorted(f for f in os.listdir(image_dir) if f.endswith('.png'))
        # List the mask directory once instead of once per image.
        svg_files = sorted(f for f in os.listdir(mask_dir) if f.endswith('.svg'))
        svg_set = set(svg_files)

        # Build the image -> mask mapping.
        self.file_pairs = []
        for img_file in self.image_files:
            base_name = os.path.splitext(img_file)[0]
            exact_match = base_name + '.svg'
            if exact_match in svg_set:
                # An exact name is unambiguous even when other SVGs share the
                # prefix (e.g. "1.svg" vs "10.svg" for image "1.png").
                self.file_pairs.append((img_file, exact_match))
                continue

            possible_masks = [f for f in svg_files if f.startswith(base_name)]
            if len(possible_masks) == 1:
                self.file_pairs.append((img_file, possible_masks[0]))
            elif len(possible_masks) > 1:
                print(f"⚠️ 多个匹配的SVG文件: {img_file} -> {possible_masks}")

        if not self.file_pairs:
            raise ValueError("没有找到有效的图像-mask对！")
        print(f"成功匹配 {len(self.file_pairs)} 对图像-mask文件")

    def __len__(self):
        return len(self.file_pairs)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for pair ``idx``.

        ``image`` is the output of ``transforms.ToTensor`` (float, channels
        first); ``mask`` is a ``long`` tensor of per-pixel labels.
        """
        img_file, mask_file = self.file_pairs[idx]
        img_path = os.path.join(self.image_dir, img_file)
        mask_path = os.path.join(self.mask_dir, mask_file)

        # Load the image; the context manager releases the file handle.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        # Rasterize the mask at the image's original resolution so that image
        # and mask are scaled by the same factor to target_size. Assumes the
        # SVG coordinates are in the image's pixel space — TODO confirm against
        # the annotations. A fixed 512x512 canvas clipped geometry beyond 512 px
        # and misaligned the mask for any image that is not 512x512.
        canvas_size = image.size  # (width, height)
        image = np.array(image.resize(self.target_size))

        mask = self.svg_to_mask(mask_path, canvas_size=canvas_size)
        mask = cv2.resize(mask, self.target_size, interpolation=cv2.INTER_NEAREST)

        # Optional joint augmentation of image and mask.
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # Convert to PyTorch tensors. ascontiguousarray guards torch.from_numpy
        # against negative-stride views that some augmentations produce.
        image = transforms.ToTensor()(image)
        mask = torch.from_numpy(np.ascontiguousarray(mask)).long()

        return image, mask

    def svg_to_mask(self, svg_path, canvas_size=None):
        """Rasterize the filled paths of an SVG file into a uint8 label mask.

        Args:
            svg_path: path of the SVG annotation.
            canvas_size: ``(width, height)`` of the output mask, in SVG
                coordinate units. Defaults to ``(512, 512)``.

        Returns:
            ``np.ndarray`` of shape ``(height, width)`` and dtype uint8 holding
            0 for background, 2 for paths whose attributes contain "wall" and
            1 for every other path. Later paths overwrite earlier ones where
            they overlap.
        """
        paths, attributes = svg2paths(svg_path)
        width, height = canvas_size if canvas_size is not None else (512, 512)
        mask = np.zeros((height, width), dtype=np.uint8)

        for path, attr in zip(paths, attributes):
            # Class is decided by a substring search over the path's attributes.
            class_id = 2 if "wall" in str(attr) else 1

            for contour in path.continuous_subpaths():
                # Polygon vertices are the start points of the segments; curve
                # control points are ignored (curves become straight edges).
                points = [
                    [segment.start.real, segment.start.imag]
                    for segment in contour
                    if segment.start is not None
                ]
                if len(points) > 2:
                    cv2.fillPoly(mask, [np.array(points, dtype=np.int32)], class_id)
        return mask

# 2. 数据集划分
def prepare_datasets(data_dir, test_size=0.2, seed=None):
    """Build the floor-plan dataset under ``data_dir`` and split it into train/test.

    Images are read from ``<data_dir>/images`` and SVG masks from
    ``<data_dir>/masks``. Both directories are created if missing, so an empty
    layout surfaces as the dataset's "no pairs" ``ValueError``.

    Args:
        data_dir: dataset root directory.
        test_size: fraction of samples held out for testing, in ``[0, 1]``.
        seed: optional integer seed for a reproducible split. ``None`` keeps
            using torch's global random generator.

    Returns:
        ``(train_subset, test_subset)`` as produced by ``random_split``.

    Raises:
        ValueError: if ``test_size`` is outside ``[0, 1]`` or no image/mask
            pairs are found.
    """
    if not 0.0 <= test_size <= 1.0:
        raise ValueError(f"test_size must be in [0, 1], got {test_size}")

    image_dir = os.path.join(data_dir, 'images')
    mask_dir = os.path.join(data_dir, 'masks')

    # Create the folders if they do not exist yet.
    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(mask_dir, exist_ok=True)

    full_dataset = FloorPlanDataset(image_dir, mask_dir)
    n_total = len(full_dataset)
    # Separate names for the sample counts instead of re-binding the
    # ``test_size`` fraction parameter.
    n_train = int((1 - test_size) * n_total)
    n_test = n_total - n_train

    split_kwargs = {}
    if seed is not None:
        split_kwargs['generator'] = torch.Generator().manual_seed(seed)
    return random_split(full_dataset, [n_train, n_test], **split_kwargs)

# 3. U-Net模型定义
class UNet(nn.Module):
    """U-Net style encoder/decoder with skip connections.

    The encoder widens 64 -> 128 -> 256 -> 512 -> 1024 channels over four
    ``Down`` stages. Each of the four ``Up`` stages takes the previous decoder
    output plus the encoder feature map of matching depth. ``DoubleConv``,
    ``Down``, ``Up`` and ``OutConv`` are defined elsewhere (not visible here).

    Args:
        n_channels: channels of the input image.
        n_classes: number of output channels passed to ``OutConv``.
    """

    def __init__(self, n_channels=3, n_classes=2):
        super().__init__()
        # Keep this construction order and these attribute names: they fix the
        # parameter-initialization order and the state_dict keys of checkpoints.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)
        self.up1 = Up(1024, 512)
        self.up2 = Up(512, 256)
        self.up3 = Up(256, 128)
        self.up4 = Up(128, 64)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        # Encoder: remember every level's output for the skip connections.
        features = [self.inc(x)]
        for encoder_stage in (self.down1, self.down2, self.down3, self.down4):
            features.append(encoder_stage(features[-1]))

        # Decoder: start from the deepest map and consume skips deepest-first.
        out = features.pop()
        for decoder_stage in (self.up1, self.up2, self.up3, self.up4):
            out = decoder_stage(out, features.pop())
        return self.outc(out)

# 4. 训练和评估函数
def _compute_class_weights(dataset, num_classes, device):
    """Return inverse-frequency ``CrossEntropyLoss`` weights from the masks in ``dataset``.

    The weight of class ``c`` is ``total_pixels / pixels_of_class_c``. A class
    that never occurs gets weight 0 instead of ``inf``.

    Raises:
        ValueError: if a mask holds a label ``>= num_classes`` or the dataset
            contains no pixels.
    """
    counts = torch.zeros(num_classes, dtype=torch.float64)
    for _, mask in dataset:
        per_mask = torch.bincount(mask.flatten(), minlength=num_classes)
        # bincount grows past minlength when a label is out of range; fail
        # loudly instead of indexing out of bounds later.
        if per_mask.numel() > num_classes:
            raise ValueError(
                f"mask contains label {per_mask.numel() - 1}, "
                f"but num_classes={num_classes}"
            )
        counts += per_mask.to(torch.float64)

    total = counts.sum()
    if total == 0:
        raise ValueError("cannot compute class weights from an empty dataset")
    weights = torch.where(counts > 0, total / counts.clamp(min=1), torch.zeros_like(counts))
    return weights.to(dtype=torch.float32, device=device)


def _evaluate_iou(model, loader, device, num_classes):
    """Return the IoU of each non-background class (labels ``1..num_classes-1``).

    Intersections and unions are summed over the whole loader before dividing,
    so the score does not depend on how pixels are grouped into batches.
    Classes absent from both predictions and targets are omitted.
    """
    intersections = [0] * num_classes
    unions = [0] * num_classes
    model.eval()
    with torch.no_grad():
        for images, masks in loader:
            images, masks = images.to(device), masks.to(device)
            preds = model(images).argmax(dim=1)
            for cls in range(1, num_classes):
                pred = preds == cls
                true = masks == cls
                intersections[cls] += (pred & true).sum().item()
                unions[cls] += (pred | true).sum().item()
    return [intersections[c] / unions[c] for c in range(1, num_classes) if unions[c] > 0]


def train_and_evaluate(
    data_dir=r"C:\Users\TroubleMo\Desktop\CVC-FP",
    epochs=50,
    batch_size=8,
    lr=0.001,
    num_classes=3,
    save_path='unet_model.pth',
):
    """Train a U-Net on the floor-plan data, print test IoU per epoch, save weights.

    Args:
        data_dir: dataset root containing ``images/`` and ``masks/``.
        epochs: number of training epochs.
        batch_size: mini-batch size for both the train and test loaders.
        lr: Adam learning rate.
        num_classes: number of mask labels. ``FloorPlanDataset.svg_to_mask``
            writes 0 (background), 1 (other paths) and 2 (walls), hence 3.
        save_path: file the trained ``state_dict`` is written to.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Data
    train_dataset, test_dataset = prepare_datasets(data_dir)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    # The model needs one output channel per mask label; UNet's default of two
    # classes cannot represent the wall label (2), which makes
    # CrossEntropyLoss reject the targets.
    model = UNet(n_classes=num_classes).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss(
        weight=_compute_class_weights(train_dataset, num_classes, device)
    )

    for epoch in range(epochs):
        model.train()
        for images, masks in tqdm(train_loader):
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), masks)
            loss.backward()
            optimizer.step()

        # Mean IoU over the non-background classes on the test split.
        ious = _evaluate_iou(model, test_loader, device, num_classes)
        print(f"Epoch {epoch+1}, IoU: {np.mean(ious) if ious else 0:.4f}")

    torch.save(model.state_dict(), save_path)
    print("训练完成！")

if __name__ == "__main__":
    # Script entry point: run the full train / evaluate / save pipeline.
    train_and_evaluate()

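# Notebook output recorded from an earlier run (the code above binds train_dataset before using it):
# NameError: name 'train_dataset' is not defined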