In [1]:
import numpy as np
import pandas as pd
import pathlib, sys, os, random, time
import cv2, gc
from tqdm.notebook import tqdm

import matplotlib.pyplot as plt
%matplotlib inline

import warnings
warnings.filterwarnings('ignore')

import albumentations as A
from albumentations.pytorch import ToTensorV2

import rasterio
from rasterio.windows import Window

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as D
import torchvision
from torchvision import transforms as T

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

[HAMI-core Msg(3352:140406248606592:libvgpu.c:840)]: Initializing.....
[HAMI-core Warn(3352:140406248606592:libvgpu.c:96)]: recursive dlsym : ompt_start_tool



In [2]:
# Run-length encoding helpers (column-major order, 1-based run starts).
def rle_encode(im):
    '''
    Run-length encode a binary mask.

    im: numpy array, 1 - mask, 0 - background
    Returns "start length start length ..." where starts are 1-based
    positions in column-major (Fortran) order; an all-zero mask gives ''.
    '''
    flat = np.concatenate([[0], im.flatten(order='F'), [0]])
    # 1-based positions where the value changes: alternately run starts and run ends.
    boundaries = np.flatnonzero(flat[1:] != flat[:-1]) + 1
    # Even slots keep the run starts; each odd slot (run end) becomes a run length.
    boundaries[1::2] -= boundaries[0::2]
    return ' '.join(map(str, boundaries))

def rle_decode(mask_rle, shape=(512, 512)):
    '''
    Decode a run-length string back into a binary mask.

    mask_rle: space-separated "start length" pairs, starts 1-based and
              counted in column-major (Fortran) order
    shape: (height, width) of the array to return
    Returns a uint8 numpy array, 1 - mask, 0 - background.
    An empty string or a missing value (NaN/None) yields an all-zero mask.
    '''
    if mask_rle == '' or pd.isna(mask_rle):
        return np.zeros(shape, dtype=np.uint8)

    tokens = mask_rle.split()
    run_starts = np.asarray(tokens[0::2], dtype=int) - 1  # to 0-based offsets
    run_lengths = np.asarray(tokens[1::2], dtype=int)

    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for begin, length in zip(run_starts, run_lengths):
        flat[begin:begin + length] = 1
    return flat.reshape(shape, order='F')

In [3]:
# Configuration parameters
SEED = 42
EPOCHS = 30
# NOTE(review): the U-Net uses BatchNorm2d; with batch size 1 the batch statistics are
# per-image and noisy. Gradient accumulation in train_one_epoch does not change that.
BATCH_SIZE = 1
# Side length the train/valid transforms resize images and masks to.
# NOTE(review): tiles appear to be 512x512 (rle_decode's default shape and the resize
# back to 512 in predict), so 256 is a downscale, not an increase in detail.
IMAGE_SIZE = 256
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
THRESHOLD = 0.5  # probability cutoff for binarizing sigmoid outputs; can be tuned on the validation set

[HAMI-core Msg(3352:140406248606592:libvgpu.c:859)]: Initialized


In [4]:
# Seed every random number generator used by the notebook.
def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed (int): seed applied to every generator.

    Notes:
        PYTHONHASHSEED only affects Python processes started after this call;
        the running interpreter's hash seed is already fixed.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU (lazy, so safe on CPU-only machines).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick possibly nondeterministic
    # algorithms per input shape, which defeats `deterministic` above.
    torch.backends.cudnn.benchmark = False

# Seed all RNGs once when this cell runs; main() seeds again before training.
seed_everything(SEED)

In [5]:
# Training-time augmentation: resize, random horizontal flip, normalize, convert to a CHW tensor.
# NOTE(review): mean/std are presumably per-channel RGB statistics of the training images
# on a 0-1 scale (A.Normalize divides by 255 by default) — TODO confirm.
train_transform = A.Compose([
    A.Resize(IMAGE_SIZE, IMAGE_SIZE),
    A.HorizontalFlip(p=0.5),
    # Some augmentations that could cause problems were removed here.
    A.Normalize(
        mean=[0.625, 0.448, 0.688],
        std=[0.131, 0.177, 0.101],
    ),
    ToTensorV2(),
])

# Validation/test preprocessing: same resize and normalization, no random augmentation.
# The same mean/std also appear in predict_with_tta; keep them in sync.
valid_transform = A.Compose([
    A.Resize(IMAGE_SIZE, IMAGE_SIZE),
    A.Normalize(
        mean=[0.625, 0.448, 0.688],
        std=[0.131, 0.177, 0.101],
    ),
    ToTensorV2(),
])

In [6]:
class BuildingSegmentationDataset(torch.utils.data.Dataset):
    """Image dataset with optional per-image masks for building segmentation.

    Args:
        img_paths: sequence of image file paths.
        mask_paths: optional sequence aligned with ``img_paths``. Despite the
            name, each entry is the mask data itself: an RLE string (decoded
            with ``rle_decode`` at the image's own height/width), a numpy
            array or a torch tensor. ``None`` means an unlabeled (test) set.
        transform: optional albumentations-style callable used as
            ``transform(image=..., mask=...)`` and returning a dict.

    ``__getitem__`` returns ``(image, mask)`` when masks are given (``mask``
    is a float tensor, with a leading channel dim added when it is 2-D),
    otherwise just ``image``.
    """

    def __init__(self, img_paths, mask_paths=None, transform=None):
        self.img_paths = img_paths
        self.mask_paths = mask_paths
        self.transform = transform

    def __len__(self):
        return len(self.img_paths)

    def _read_rgb(self, idx):
        """Load image ``idx`` as an RGB uint8 array; raise if it cannot be decoded."""
        path = self.img_paths[idx]
        img = cv2.imread(path)
        if img is None:
            # cv2.imread returns None (no exception) for missing/corrupt files and,
            # on some platforms, for non-ASCII paths such as '数据集/...'.
            # Retry by decoding the raw bytes before failing loudly.
            try:
                raw = np.fromfile(path, dtype=np.uint8)
            except OSError:
                raw = None
            if raw is not None and raw.size:
                img = cv2.imdecode(raw, cv2.IMREAD_COLOR)
        if img is None:
            raise FileNotFoundError(f"无法读取图像 (cannot read image): {path}")
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def _mask_array(self, idx, shape):
        """Turn mask entry ``idx`` into an array; ``shape`` is the image's (height, width)."""
        mask_data = self.mask_paths[idx]
        if isinstance(mask_data, str):
            return rle_decode(mask_data, shape=shape)
        if isinstance(mask_data, np.ndarray):
            return mask_data
        if isinstance(mask_data, torch.Tensor):
            return mask_data.cpu().numpy()
        if mask_data is None or (isinstance(mask_data, float) and np.isnan(mask_data)):
            # Missing annotation: treat as an empty mask, as rle_decode does for NaN.
            return np.zeros(shape, dtype=np.uint8)
        print(f"警告：未知的掩码数据类型 - {type(mask_data)}")
        return np.zeros(shape, dtype=np.uint8)

    def __getitem__(self, idx):
        img = self._read_rgb(idx)

        if self.mask_paths is None:
            # Unlabeled (test) sample: return the image only.
            if self.transform is not None:
                img = self.transform(image=img)["image"]
            return img

        # Decode at the image's own size so image and mask always line up.
        mask = self._mask_array(idx, img.shape[:2])

        if self.transform is not None:
            transformed = self.transform(image=img, mask=mask)
            img = transformed["image"]
            mask = transformed["mask"]

        # The losses expect a float tensor whatever the transform returned.
        if isinstance(mask, np.ndarray):
            mask = torch.from_numpy(np.ascontiguousarray(mask)).float()
        else:
            mask = mask.float()

        if mask.dim() == 2:
            mask = mask.unsqueeze(0)  # (H, W) -> (1, H, W), one channel like the model output

        return img, mask

In [7]:
# U-Net model definition - basic building blocks
class DoubleConv(nn.Module):
    """Two stacked (3x3 Conv2d -> BatchNorm2d -> ReLU) stages.

    Spatial size is preserved (padding=1). The first conv maps
    ``in_channels`` -> ``mid_channels`` (defaults to ``out_channels`` when
    falsy), the second maps ``mid_channels`` -> ``out_channels``.
    """
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid_channels = mid_channels or out_channels
        layers = []
        for c_in, c_out in ((in_channels, mid_channels), (mid_channels, out_channels)):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        # The attribute name and layer order define the state_dict keys
        # (double_conv.0 ... double_conv.5); keep them stable for saved checkpoints.
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    """Halve the spatial resolution with 2x2 max-pooling, then apply DoubleConv.

    Pooling and convolution share one ``nn.Sequential`` named ``maxpool_conv``
    (indices 0 and 1), which fixes the checkpoint key layout.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        downsampled = self.maxpool_conv(x)
        return downsampled

In [8]:
class Up(nn.Module):
    """Upsample the deeper map x2, concatenate the skip connection, then DoubleConv.

    Args:
        in_channels: channel count DoubleConv receives, i.e. channels of the
            upsampled ``x1`` plus channels of the skip tensor ``x2``.
        out_channels: channels produced.
        bilinear: if True, use parameter-free bilinear upsampling (DoubleConv
            then uses ``in_channels // 2`` mid channels); otherwise use a learned
            2x2 transposed convolution mapping ``in_channels`` -> ``in_channels // 2``.
    """
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)

    def forward(self, x1, x2):
        upsampled = self.up(x1)
        # Tensors are NCHW. Pad the upsampled map so its H and W match the skip
        # tensor (an odd-sized encoder map leaves it one pixel short).
        pad_h = x2.size(2) - upsampled.size(2)
        pad_w = x2.size(3) - upsampled.size(3)
        left, top = pad_w // 2, pad_h // 2
        upsampled = F.pad(upsampled, [left, pad_w - left, top, pad_h - top])
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)

class OutConv(nn.Module):
    """1x1 convolution that projects features to per-pixel class logits."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        logits = self.conv(x)
        return logits

In [9]:
class UNet(nn.Module):
    """Four-level U-Net encoder/decoder with skip connections.

    Args:
        n_channels: input image channels.
        n_classes: output logit channels.
        bilinear: use bilinear upsampling in the decoder instead of transposed convs.

    ``forward`` maps an (N, n_channels, H, W) batch to (N, n_classes, H, W) logits.
    """
    def __init__(self, n_channels=3, n_classes=1, bilinear=False):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Bilinear upsampling keeps the channel count, so the deepest stage and
        # the decoder outputs are halved to keep the concatenations consistent.
        factor = 2 if bilinear else 1

        # Encoder: 64 -> 128 -> 256 -> 512 -> 1024 // factor channels.
        # (Attribute names and registration order define checkpoint keys.)
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)
        # Decoder: each stage upsamples, concatenates the matching skip and narrows channels.
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        # Encoder outputs, shallowest first; the last one is the bottleneck.
        skips = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            skips.append(down(skips[-1]))

        features = skips.pop()
        # Decoder consumes the skips deepest-first (x4, x3, x2, x1).
        for up in (self.up1, self.up2, self.up3, self.up4):
            features = up(features, skips.pop())
        return self.outc(features)

In [10]:
# Loss functions - soft Dice loss
class DiceLoss(nn.Module):
    """Soft Dice loss on logits.

    ``pred`` holds raw logits and ``target`` the matching labels, both with at
    least 4 dims. The Dice coefficient is computed over dims 2 and 3 (H, W)
    and averaged over all remaining dims. The returned loss is 1 minus that average.

    Args:
        smooth: additive smoothing in numerator and denominator.
    """
    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        probs = pred.sigmoid()
        spatial = (2, 3)
        overlap = torch.sum(probs * target, dim=spatial)
        total = torch.sum(probs, dim=spatial) + torch.sum(target, dim=spatial)
        # `smooth` keeps the ratio defined (and equal to 1) when both are empty.
        coeff = (2.0 * overlap + self.smooth) / (total + self.smooth)
        return 1.0 - coeff.mean()

In [11]:
class FocalLoss(nn.Module):
    """Binary focal loss on logits: mean of ``(1 - p_t) ** gamma * BCE``.

    Args:
        gamma: focusing parameter; 0 reduces to plain binary cross-entropy.

    ``pred`` holds raw logits and ``target`` labels of the same shape. BCE is
    computed directly from the logits, so saturated logits neither overflow
    nor lose their gradient.
    """
    def __init__(self, gamma=2.0):
        super().__init__()
        self.gamma = gamma
        # Kept for backward compatibility with code reading `.eps`; the
        # logits-based BCE below does not need clamping.
        self.eps = 1e-7

    def forward(self, pred, target):
        target = target.float()
        # sigmoid -> clamp(eps, 1 - eps) -> BCE would zero the gradient of any
        # confidently wrong pixel (|logit| > ~16), so it could never be corrected.
        bce = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
        # For hard 0/1 targets, pt is the predicted probability of the true class.
        pt = torch.exp(-bce)
        focal_loss = (1.0 - pt) ** self.gamma * bce
        return focal_loss.mean()

In [12]:
# Combined loss: weighted Dice + focal
class CombinedLoss(nn.Module):
    """Weighted sum ``dice_weight * DiceLoss + focal_weight * FocalLoss`` on logits."""
    def __init__(self, dice_weight=0.5, focal_weight=0.5):
        super().__init__()
        self.dice_weight = dice_weight
        self.focal_weight = focal_weight
        self.dice_loss = DiceLoss()
        self.focal_loss = FocalLoss()

    def forward(self, pred, target):
        weighted_dice = self.dice_weight * self.dice_loss(pred, target)
        weighted_focal = self.focal_weight * self.focal_loss(pred, target)
        return weighted_dice + weighted_focal

In [13]:
class EarlyStopping:
    """Stop training once the validation loss stops improving.

    Each call that improves the best loss by more than `delta` writes the
    model's state_dict to `path`. After `patience` consecutive calls without
    improvement, `early_stop` is set. Non-finite losses (NaN/Inf) count as
    "no improvement" instead of being recorded as the new best.
    """
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
        """
        Args:
            patience (int): consecutive non-improving calls before `early_stop` is set.
            verbose (bool): print a message whenever a checkpoint is saved.
            delta (float): minimum loss decrease that counts as an improvement.
            path (str): file the best state_dict is saved to.
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        score = -val_loss

        if not np.isfinite(score):
            # `nan < best` is False, so without this guard a NaN loss would be
            # stored as the new best and every later call would "improve" on it.
            self._no_improvement()
            return

        if self.best_score is None or score >= self.best_score + self.delta:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
        else:
            self._no_improvement()

    def _no_improvement(self):
        """Advance the patience counter and set `early_stop` once it is exhausted."""
        self.counter += 1
        print(f'早停计数: {self.counter}/{self.patience}')
        if self.counter >= self.patience:
            self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        '''Save the model's state_dict to `self.path` and record `val_loss` as the new minimum.'''
        if self.verbose:
            print(f'验证损失从 ({self.val_loss_min:.6f} 降至 {val_loss:.6f})。保存模型...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [14]:
def debug_tensors(images, masks, outputs=None):
    """Print shape/dtype/device of a batch and warn about NaN or Inf values.

    Args:
        images: input batch tensor.
        masks: target tensor for the batch.
        outputs: optional model logits; when given, their shape and the
            [min, max] range of sigmoid(outputs) are printed too.
    """
    print(f"Images: shape={images.shape}, type={images.dtype}, device={images.device}")
    print(f"Masks: shape={masks.shape}, type={masks.dtype}, device={masks.device}")

    checked = [("图像", images), ("掩码", masks)]
    if outputs is not None:
        print(f"Outputs: shape={outputs.shape}, type={outputs.dtype}, device={outputs.device}")
        with torch.no_grad():
            probs = torch.sigmoid(outputs)
            print(f"输出sigmoid后的值范围: [{probs.min().item():.6f}, {probs.max().item():.6f}]")
        # sigmoid is bounded to [0, 1], so a range check on `probs` can never fire;
        # the useful signal is non-finite logits, checked below with the inputs.
        checked.append(("输出", outputs))

    for label, tensor in checked:
        if torch.isnan(tensor).any():
            print(f"警告: {label}包含NaN值!")
        if torch.isinf(tensor).any():
            print(f"警告: {label}包含Inf值!")

In [15]:
def train_one_epoch(model, dataloader, optimizer, criterion, device, accumulation_steps=4):
    """Train `model` for one pass over `dataloader` with gradient accumulation.

    Each loss is divided by `accumulation_steps` and its gradients are summed
    until a window of `accumulation_steps` batches ends. The optimizer then steps,
    and a final step flushes any leftover partial window. Batches whose loss
    is NaN or Inf are skipped (no backward pass).

    Args:
        model: network to train (switched to train mode here).
        dataloader: yields ``(images, masks)`` batches.
        optimizer: optimizer over the model's parameters.
        criterion: callable ``criterion(outputs, masks) -> scalar loss``.
        device: device the batches are moved to.
        accumulation_steps (int): batches per optimizer step, >= 1.

    Returns:
        float: mean (unscaled) loss over the batches actually used, or 0.0 if none were.

    Raises:
        ValueError: if `accumulation_steps` < 1.
    """
    if accumulation_steps < 1:
        raise ValueError(f"accumulation_steps must be >= 1, got {accumulation_steps}")

    model.train()
    total_loss = 0.0
    used_batches = 0
    has_pending_grads = False  # gradients accumulated since the last optimizer.step()

    optimizer.zero_grad()

    for i, (images, masks) in enumerate(tqdm(dataloader)):
        images = images.to(device)
        masks = masks.to(device).float()

        # Print debug info only every 500 batches to keep the output small.
        log_this_batch = i % 500 == 0
        if log_this_batch:
            debug_tensors(images, masks)

        outputs = model(images)

        if log_this_batch:
            debug_tensors(images, masks, outputs)

        loss = criterion(outputs, masks) / accumulation_steps

        if not torch.isfinite(loss):
            print(f"警告: 损失值异常: {loss.item()}")  # skip this batch
        else:
            loss.backward()
            has_pending_grads = True
            total_loss += loss.item() * accumulation_steps
            used_batches += 1

        # Step at every window boundary even when this batch was skipped, so a
        # bad batch cannot merge two windows into one oversized update.
        if (i + 1) % accumulation_steps == 0 and has_pending_grads:
            optimizer.step()
            optimizer.zero_grad()
            has_pending_grads = False

    # Flush a trailing partial window (a no-op for an empty dataloader).
    if has_pending_grads:
        optimizer.step()
        optimizer.zero_grad()

    return total_loss / used_batches if used_batches else 0.0

In [16]:
def clear_memory():
    """Free unreachable Python objects, then release cached CUDA memory.

    gc.collect() runs first so tensors kept alive only by reference cycles
    are freed before torch.cuda.empty_cache() returns the allocator's unused
    cached blocks (empty_cache is a no-op when CUDA is not initialized).
    """
    gc.collect()
    torch.cuda.empty_cache()


# train_with_checkpoints calls this after training and after validation in each epoch;
# it also runs once here when the cell executes.
clear_memory()

In [17]:
# Validation
@torch.no_grad()
def validate(model, dataloader, criterion, device, threshold=None):
    """Evaluate `model` on `dataloader`.

    Args:
        model: segmentation network producing logits.
        dataloader: yields ``(images, masks)`` batches.
        criterion: callable ``criterion(outputs, masks) -> scalar loss``.
        device: device the batches are moved to.
        threshold: probability cutoff for the Dice prediction; defaults to
            the module-level THRESHOLD, read at call time.

    Returns:
        (mean_loss, mean_dice): mean criterion value per batch and mean
        per-image Dice of the thresholded prediction. An image whose mask and
        prediction are both empty scores 1.0 (a correct "no building"
        answer), not 0.0. Both values are NaN for an empty dataloader.
    """
    if threshold is None:
        threshold = THRESHOLD
    eps = 1e-8

    model.eval()
    total_loss = 0.0
    n_batches = 0
    dice_scores = []

    for images, masks in tqdm(dataloader):
        images = images.to(device)
        masks = masks.to(device).float()

        outputs = model(images)
        total_loss += criterion(outputs, masks).item()
        n_batches += 1

        preds = (torch.sigmoid(outputs) > threshold).float()
        # Per-image Dice over every non-batch dimension; eps in both terms
        # makes empty-vs-empty equal 1 instead of 0 / 1e-8 = 0.
        dims = tuple(range(1, preds.dim()))
        inter = (preds * masks).sum(dim=dims)
        denom = preds.sum(dim=dims) + masks.sum(dim=dims)
        dice = (2 * inter + eps) / (denom + eps)
        dice_scores.extend(dice.tolist())

    if n_batches == 0:
        return float('nan'), float('nan')
    return total_loss / n_batches, float(np.mean(dice_scores))

In [18]:
# Prediction
@torch.no_grad()
def predict(model, dataloader, device, threshold=THRESHOLD, filenames=None, out_size=(512, 512)):
    """Predict binary masks for every image in `dataloader` and RLE-encode them.

    Batches may be ``(images, names)`` pairs or bare image tensors (what
    BuildingSegmentationDataset yields when it has no masks). For bare
    tensors, names are taken in order from `filenames`, so the loader must not shuffle.

    Args:
        model: segmentation network producing one-channel logits.
        dataloader: yields image batches as described above.
        device: device the images are moved to.
        threshold: probability cutoff for the binary mask.
        filenames: names for bare-tensor batches, in loader order.
        out_size: ``(width, height)`` passed to ``cv2.resize`` to bring each
            probability map back to the original tile size before thresholding.

    Returns:
        list of ``[name, rle]`` pairs in loader order.

    Raises:
        ValueError: if a bare-tensor batch has no (or not enough) `filenames`.
    """
    model.eval()
    results = []
    seen = 0

    for batch in tqdm(dataloader):
        if isinstance(batch, (list, tuple)):
            images, names = batch[0], list(batch[1])
        else:
            images = batch
            if filenames is None:
                raise ValueError("dataloader yields images only; pass `filenames` to name the predictions")
            names = list(filenames[seen:seen + len(images)])
        if len(names) != len(images):
            raise ValueError(f"got {len(names)} names for a batch of {len(images)} images")
        seen += len(images)

        probs = torch.sigmoid(model(images.to(device)))
        for prob, name in zip(probs, names):
            prob = cv2.resize(prob.cpu().numpy().squeeze(), out_size)  # back to original size
            mask = (prob > threshold).astype(np.uint8)
            results.append([name, rle_encode(mask)])

    return results

In [19]:
@torch.no_grad()
def predict_with_tta(model, image, device, threshold=THRESHOLD, tta_transforms=None, image_size=None):
    """Predict a binary mask for one image, averaging over test-time augmentations.

    Default TTA (``tta_transforms=None``): the image is normalized once (and
    first resized to ``image_size`` x ``image_size`` if given, matching the
    train/valid pipelines). The model is then run on the identity, horizontal
    flip, vertical flip and transpose of that tensor. Each prediction is
    mapped back with the exact inverse of its own op before averaging.

    Custom ``tta_transforms``: each albumentations transform is applied to the
    image and its prediction is un-flipped by looking for HorizontalFlip /
    VerticalFlip / Transpose in ``str(transform)``. This only works for
    transforms with a single deterministic geometric op.

    If ``image_size`` is given and the averaged map differs from the input's
    (height, width), it is resized back before thresholding.

    Args:
        model: network returning logits; presumably shaped (1, 1, H, W) — TODO confirm.
        image: HWC image array; presumably RGB uint8 like the dataset output — TODO confirm.
        device: device for inference.
        threshold: probability cutoff for the returned mask.
        tta_transforms: optional list of albumentations transforms (see above).
        image_size: optional square side to resize to before inference.

    Returns:
        np.uint8 array of 0/1 values.
    """
    model.eval()
    preds = []

    if tta_transforms is None:
        steps = [A.Normalize(mean=[0.625, 0.448, 0.688], std=[0.131, 0.177, 0.101]), ToTensorV2()]
        if image_size is not None:
            steps.insert(0, A.Resize(image_size, image_size))
        base = A.Compose(steps)(image=image)['image'].unsqueeze(0).to(device)

        # (forward op on the NCHW tensor, inverse op on the 2-D numpy prediction)
        ops = [
            (lambda t: t, lambda p: p),
            (lambda t: torch.flip(t, dims=[3]), np.fliplr),
            (lambda t: torch.flip(t, dims=[2]), np.flipud),
            (lambda t: t.transpose(2, 3), np.transpose),
        ]
        for forward_op, inverse_op in ops:
            output = model(forward_op(base))
            pred = torch.sigmoid(output).cpu().numpy().squeeze()
            preds.append(inverse_op(pred))
    else:
        for transform in tta_transforms:
            augmented = transform(image=image)
            img_tensor = augmented['image'].unsqueeze(0).to(device)
            output = model(img_tensor)
            pred = torch.sigmoid(output).cpu().numpy().squeeze()

            # Undo the geometric op, detected from the transform's repr.
            if 'HorizontalFlip' in str(transform):
                pred = np.fliplr(pred)
            if 'VerticalFlip' in str(transform):
                pred = np.flipud(pred)
            if 'Transpose' in str(transform):
                pred = np.transpose(pred)

            preds.append(pred)

    final_pred = np.mean(preds, axis=0)
    if image_size is not None and final_pred.shape != tuple(image.shape[:2]):
        # cv2 takes dsize as (width, height).
        final_pred = cv2.resize(final_pred.astype(np.float32), (image.shape[1], image.shape[0]))
    return (final_pred > threshold).astype(np.uint8)

In [20]:
def train_with_checkpoints(model, train_loader, valid_loader, optimizer,
                          criterion, scheduler, device, num_epochs,
                          checkpoint_dir='checkpoints', accumulation_steps=4):
    """Training loop with CSV logging, periodic checkpoints, best-model saving and early stopping.

    Files written to `checkpoint_dir`:
      - training_log.csv: one row per epoch (epoch, train_loss, val_loss, val_dice, learning_rate)
      - checkpoint_epoch_{N}.pth: every 5 epochs and at epoch `num_epochs`
        (model, optimizer and, when available, scheduler state)
      - best_model.pth: state_dict with the highest validation Dice so far
      - early_stop_model.pth: state_dict saved by EarlyStopping on each val-loss improvement

    `scheduler.step` receives the validation loss (ReduceLROnPlateau-style).

    Returns:
        (best_dice, best_epoch)
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    best_dice = 0
    best_epoch = 0

    early_stopping = EarlyStopping(patience=7, verbose=True,
                                   path=os.path.join(checkpoint_dir, "early_stop_model.pth"))

    # `with` closes the CSV log even if training raises or is interrupted.
    with open(os.path.join(checkpoint_dir, "training_log.csv"), "w") as log_file:
        log_file.write("epoch,train_loss,val_loss,val_dice,learning_rate\n")

        for epoch in range(1, num_epochs + 1):
            print(f"第 {epoch}/{num_epochs} 轮")

            # Train with gradient accumulation.
            train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device, accumulation_steps)
            clear_memory()

            print("\n" + "="*50)
            print(f"完成第 {epoch} 轮训练，开始验证...")

            val_loss, val_dice = validate(model, valid_loader, criterion, device)

            print(f"验证完成! 损失: {val_loss:.4f}, Dice: {val_dice:.4f}")
            print("="*50 + "\n")

            clear_memory()

            # Learning rate used during this epoch (read before the scheduler updates it).
            current_lr = optimizer.param_groups[0]['lr']

            log_file.write(f"{epoch},{train_loss:.4f},{val_loss:.4f},{val_dice:.4f},{current_lr:.8f}\n")
            log_file.flush()

            scheduler.step(val_loss)

            print(f"训练损失: {train_loss:.4f} | 验证损失: {val_loss:.4f} | Dice分数: {val_dice:.4f} | 学习率: {current_lr:.8f}")

            # Full resumable checkpoint every 5 epochs and at the last scheduled epoch.
            if epoch % 5 == 0 or epoch == num_epochs:
                checkpoint = {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_dice': val_dice,
                    'val_loss': val_loss,
                }
                # Scheduler state is needed to resume the LR schedule correctly.
                if hasattr(scheduler, 'state_dict'):
                    checkpoint['scheduler_state_dict'] = scheduler.state_dict()
                torch.save(checkpoint, os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pth"))

            if val_dice > best_dice:
                print(f"Dice分数从 {best_dice:.4f} 提高到 {val_dice:.4f}. 正在保存模型...")
                best_dice = val_dice
                best_epoch = epoch
                torch.save(model.state_dict(), os.path.join(checkpoint_dir, "best_model.pth"))

            early_stopping(val_loss, model)
            if early_stopping.early_stop:
                print("触发早停! 训练停止。")
                break

    print(f"训练完成! 最佳Dice分数: {best_dice:.4f} (第{best_epoch}轮)")
    return best_dice, best_epoch

In [21]:
def main():
    """Train the U-Net on 数据集/train_mask.csv and return the model after training.

    Rows whose index is divisible by 7 form the validation set; the rest are
    used for training. Returns None if the mask CSV cannot be loaded.
    """
    seed_everything(SEED)

    # Load the tab-separated (file name, RLE mask) table.
    try:
        print("正在加载训练数据...")
        train_mask = pd.read_csv('数据集/train_mask.csv', sep='\t', names=['name', 'mask'])
        train_mask['name'] = train_mask['name'].apply(lambda x: '数据集/train/' + x)
    except Exception as e:
        print(f"加载数据时出错: {e}")
        print("请确保'数据集/train_mask.csv'文件存在并且格式正确!")
        return

    print(f"已加载 {len(train_mask)} 条训练数据")

    # Deterministic split: every 7th row is held out for validation.
    row_ids = range(len(train_mask))
    valid_idx = [r for r in row_ids if r % 7 == 0]
    train_idx = [r for r in row_ids if r % 7 != 0]
    train_df = train_mask.iloc[train_idx].reset_index(drop=True)
    valid_df = train_mask.iloc[valid_idx].reset_index(drop=True)

    print(f"训练集: {len(train_df)} 样本, 验证集: {len(valid_df)} 样本")

    def make_loader(frame, transform, shuffle):
        """Wrap one split in a BuildingSegmentationDataset and a DataLoader."""
        dataset = BuildingSegmentationDataset(
            frame['name'].values,
            frame['mask'].fillna('').values,
            transform=transform,
        )
        return D.DataLoader(
            dataset, batch_size=BATCH_SIZE, shuffle=shuffle,
            num_workers=2, pin_memory=True,
        )

    train_loader = make_loader(train_df, train_transform, shuffle=True)
    valid_loader = make_loader(valid_df, valid_transform, shuffle=False)

    model = UNet(n_channels=3, n_classes=1, bilinear=False).to(DEVICE)

    print(f"模型已创建并加载到设备: {DEVICE}")
    print(f"模型总参数量: {sum(p.numel() for p in model.parameters()):,}")

    # AdamW with plateau-based LR halving driven by the validation loss.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4, eps=1e-8)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=3, verbose=True
    )

    criterion = CombinedLoss(dice_weight=0.7, focal_weight=0.3)

    best_dice, best_epoch = train_with_checkpoints(
        model, train_loader, valid_loader, optimizer,
        criterion, scheduler, DEVICE, EPOCHS, checkpoint_dir='model_checkpoints'
    )

    print(f"训练完成! 最佳Dice分数: {best_dice:.4f} 在第{best_epoch}轮")
    return model

In [22]:
def predict_test_set():
    """Predict masks for every .jpg/.tif image in 数据集/test_a and write submission.csv.

    Weights come from model_checkpoints/best_model.pth, or from
    best_building_segmentation_model.pth when the former does not exist.
    Each prediction is resized to 512x512, thresholded at THRESHOLD and
    RLE-encoded. submission.csv is tab-separated (name, rle) with no header,
    rows in sorted file-name order.
    """
    # Only a missing file triggers the fallback; any other load error
    # (e.g. a state_dict mismatch) is raised instead of being swallowed.
    primary_weights = 'model_checkpoints/best_model.pth'
    weights_path = primary_weights if os.path.exists(primary_weights) else 'best_building_segmentation_model.pth'

    model = UNet(n_channels=3, n_classes=1, bilinear=False)
    # map_location lets GPU-saved weights load on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(weights_path, map_location=DEVICE))
    model.to(DEVICE)
    model.eval()

    # Make sure this points at the test image directory. Sorting makes the row order deterministic.
    test_dir = '数据集/test_a'
    test_paths = [
        os.path.join(test_dir, file)
        for file in sorted(os.listdir(test_dir))
        if file.lower().endswith(('.jpg', '.tif'))
    ]

    test_batch_size = 1  # one image per batch to keep memory use low

    test_ds = BuildingSegmentationDataset(
        test_paths,
        None,  # the test set has no masks
        transform=valid_transform
    )

    # shuffle=False keeps loader order aligned with test_files below.
    test_loader = D.DataLoader(
        test_ds, batch_size=test_batch_size, shuffle=False,
        num_workers=1, pin_memory=True
    )

    results = []
    test_files = [os.path.basename(p) for p in test_paths]

    with torch.no_grad():
        i = 0
        for images in tqdm(test_loader):
            preds = torch.sigmoid(model(images.to(DEVICE)))

            for pred in preds:
                if i >= len(test_files):  # safety check
                    break
                pred = pred.cpu().numpy().squeeze()
                pred = cv2.resize(pred, (512, 512))  # back to the original tile size
                mask = (pred > THRESHOLD).astype(np.uint8)
                results.append([test_files[i], rle_encode(mask)])
                i += 1

    # Release memory once at the end; running clear_memory() (gc.collect())
    # for every tile only slows prediction down.
    clear_memory()

    submission = pd.DataFrame(results, columns=['name', 'mask'])
    submission.to_csv('submission.csv', index=False, header=False, sep='\t')
    print("预测完成! 结果已保存到 submission.csv")