In [1]:
import numpy as np
import pandas as pd
import pathlib, sys, os, random, time
import cv2, gc
from tqdm.notebook import tqdm

import matplotlib.pyplot as plt
%matplotlib inline

import warnings
warnings.filterwarnings('ignore')

import albumentations as A
from albumentations.pytorch import ToTensorV2

import rasterio
from rasterio.windows import Window

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as D
import torchvision
from torchvision import transforms as T

In [2]:
# RLE编码和解码函数
def rle_encode(im):
    '''
    im: numpy array, 1 - mask, 0 - background
    Returns run length as string formated
    '''
    pixels = im.flatten(order = 'F')
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)

def rle_decode(mask_rle, shape=(512, 512)):
    '''
    mask_rle: run-length as string formated (start length)
    shape: (height,width) of array to return 
    Returns numpy array, 1 - mask, 0 - background
    '''
    if mask_rle == '' or pd.isna(mask_rle):
        return np.zeros(shape, dtype=np.uint8)
    
    s = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0]*shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img.reshape(shape, order='F')

In [3]:
# 配置参数
SEED = 42
EPOCHS = 30
BATCH_SIZE = 4
IMAGE_SIZE = 384  # 增加图像尺寸以获取更多细节
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
THRESHOLD = 0.5  # 二值化阈值，可以通过验证集调整

In [4]:
# 设置随机种子
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(SEED)

In [None]:
# 增强的数据增强策略
train_transform = A.Compose([
    A.Resize(IMAGE_SIZE, IMAGE_SIZE),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.ShiftScaleRotate(p=0.5, shift_limit=0.1, scale_limit=0.2, rotate_limit=30),
    A.OneOf([
        A.RandomBrightnessContrast(p=1.0),
        A.RandomGamma(p=1.0),
        A.HueSaturationValue(p=1.0)
    ], p=0.5),
    A.OneOf([
        A.ElasticTransform(alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03, p=1.0),
        A.GridDistortion(p=1.0),
        A.OpticalDistortion(distort_limit=1.0, shift_limit=0.5, p=1.0),
    ], p=0.3),
    A.CoarseDropout(max_holes=8, max_height=32, max_width=32, p=0.3),
    A.GaussNoise(p=0.3),
    A.Normalize(
        mean=[0.625, 0.448, 0.688],
        std=[0.131, 0.177, 0.101],
    ),
    ToTensorV2(),
])

In [None]:
class BuildingSegmentationDataset(torch.utils.data.Dataset):
    def __init__(self, img_paths, mask_paths=None, transform=None):
        self.img_paths = img_paths
        self.mask_paths = mask_paths
        self.transform = transform
        
    def __len__(self):
        return len(self.img_paths)
    
    def __getitem__(self, idx):
        # 加载图像
        img = cv2.imread(self.img_paths[idx])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        # 处理掩码数据
        if self.mask_paths is not None:
            mask_data = self.mask_paths[idx]
            
            # 使用已有的rle_decode函数处理RLE格式的掩码
            if isinstance(mask_data, str):
                mask = rle_decode(mask_data, shape=(512, 512))
            elif isinstance(mask_data, np.ndarray):
                mask = mask_data
            elif isinstance(mask_data, torch.Tensor):
                mask = mask_data.cpu().numpy()
            else:
                print(f"警告：未知的掩码数据类型 - {type(mask_data)}")
                mask = np.zeros((512, 512), dtype=np.uint8)
            
            # 应用变换
            if self.transform is not None:
                transformed = self.transform(image=img, mask=mask)
                img = transformed["image"]
                mask = transformed["mask"]
            
            # 关键修复: 始终确保mask是float类型的tensor
            if isinstance(mask, np.ndarray):
                mask = torch.from_numpy(mask).float()
            else:
                # 确保tensor是float类型
                mask = mask.float()
            
            if mask.dim() == 2:
                mask = mask.unsqueeze(0)
                
            return img, mask
        else:
            # 只返回图像（用于测试集）
            if self.transform is not None:
                transformed = self.transform(image=img)
                img = transformed["image"]
                
            return img

In [7]:
# U-Net模型定义 - 基础模块
class DoubleConv(nn.Module):
    """(Conv2D -> BN -> ReLU) * 2"""
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    """Downscaling with maxpool then double conv"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

In [8]:
class Up(nn.Module):
    """Upscaling then double conv"""
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

In [9]:
class UNet(nn.Module):
    def __init__(self, n_channels=3, n_classes=1, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        factor = 2 if bilinear else 1
        
        # 使用更多滤波器以增加模型容量
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

In [10]:
# 优化的损失函数 - Dice Loss
class DiceLoss(nn.Module):
    def __init__(self, smooth=1.0):
        super(DiceLoss, self).__init__()
        self.smooth = smooth
        
    def forward(self, pred, target):
        pred = torch.sigmoid(pred)
        
        # 平滑处理以避免0/0的情况
        intersection = (pred * target).sum(dim=(2,3))
        union = pred.sum(dim=(2,3)) + target.sum(dim=(2,3))
        
        dice = (2.0 * intersection + self.smooth) / (union + self.smooth)
        return 1.0 - dice.mean()

In [None]:
class FocalLoss(nn.Module):
    def __init__(self, gamma=2.0):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.eps = 1e-7  # 添加极小值避免数值不稳定
        
    def forward(self, pred, target):
        # 确保输入是float类型
        pred = torch.sigmoid(pred)
        target = target.float()  # 确保target也是float
        
        # 二元交叉熵损失
        bce = F.binary_cross_entropy(pred + self.eps, target, reduction='none')
        
        # 应用focal loss公式
        pt = torch.exp(-bce)
        focal_loss = (1-pt)**self.gamma * bce
        
        return focal_loss.mean()

In [12]:
# 优化的组合损失函数
class CombinedLoss(nn.Module):
    def __init__(self, dice_weight=0.5, focal_weight=0.5):
        super(CombinedLoss, self).__init__()
        self.dice_weight = dice_weight
        self.focal_weight = focal_weight
        self.dice_loss = DiceLoss()
        self.focal_loss = FocalLoss()
        
    def forward(self, pred, target):
        dice = self.dice_loss(pred, target)
        focal = self.focal_loss(pred, target)
        return self.dice_weight * dice + self.focal_weight * focal

In [None]:
class EarlyStopping:
    """当验证集性能不再提升时提前停止训练"""
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
        """
        Args:
            patience (int): 验证集性能不提升后等待多少轮停止训练
            verbose (bool): 是否打印详细信息
            delta (float): 性能变化的最小阈值
            path (str): 保存检查点路径
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path
        
    def __call__(self, val_loss, model):
        score = -val_loss
        
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'早停计数: {self.counter}/{self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
            
    def save_checkpoint(self, val_loss, model):
        '''当验证损失减小时保存模型'''
        if self.verbose:
            print(f'验证损失从 ({self.val_loss_min:.6f} 降至 {val_loss:.6f})。保存模型...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [13]:
# 训练函数
def train_one_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    
    for images, masks in tqdm(dataloader):
        images = images.to(device)
        masks = masks.to(device)
        
        # 前向传播
        outputs = model(images)
        loss = criterion(outputs, masks)
        
        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(dataloader)

In [14]:
# 验证函数
@torch.no_grad()
def validate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    dice_scores = []
    
    for images, masks in tqdm(dataloader):
        images = images.to(device)
        masks = masks.to(device)
        
        # 前向传播
        outputs = model(images)
        loss = criterion(outputs, masks)
        
        # 计算Dice分数
        preds = (torch.sigmoid(outputs) > THRESHOLD).float()
        dice = (2 * (preds * masks).sum()) / (preds.sum() + masks.sum() + 1e-8)
        dice_scores.append(dice.item())
        
        total_loss += loss.item()
    
    return total_loss / len(dataloader), np.mean(dice_scores)

In [15]:
# 预测函数
@torch.no_grad()
def predict(model, dataloader, device, threshold=THRESHOLD):
    model.eval()
    results = []
    
    for images, filenames in tqdm(dataloader):
        images = images.to(device)
        outputs = model(images)
        preds = torch.sigmoid(outputs)
        
        # 处理每个批次的预测
        for pred, filename in zip(preds, filenames):
            pred = pred.cpu().numpy().squeeze()
            pred = cv2.resize(pred, (512, 512))  # 调整为原始大小
            mask = (pred > threshold).astype(np.uint8)
            rle = rle_encode(mask)
            results.append([filename, rle])
    
    return results

In [None]:
@torch.no_grad()
def predict_with_tta(model, image, device, threshold=THRESHOLD, tta_transforms=None):
    """测试时增强提高预测质量"""
    model.eval()
    
    # 如果没有提供TTA变换，则使用基本变换
    if tta_transforms is None:
        tta_transforms = [
            A.Compose([A.Normalize(mean=[0.625, 0.448, 0.688], std=[0.131, 0.177, 0.101]), ToTensorV2()]),
            A.Compose([A.HorizontalFlip(p=1.0), A.Normalize(mean=[0.625, 0.448, 0.688], std=[0.131, 0.177, 0.101]), ToTensorV2()]),
            A.Compose([A.VerticalFlip(p=1.0), A.Normalize(mean=[0.625, 0.448, 0.688], std=[0.131, 0.177, 0.101]), ToTensorV2()]),
            A.Compose([A.Transpose(p=1.0), A.Normalize(mean=[0.625, 0.448, 0.688], std=[0.131, 0.177, 0.101]), ToTensorV2()])
        ]
    
    # 应用所有变换并预测
    preds = []
    for transform in tta_transforms:
        augmented = transform(image=image)
        img_tensor = augmented['image'].unsqueeze(0).to(device)
        output = model(img_tensor)
        pred = torch.sigmoid(output).cpu().numpy().squeeze()
        
        # 还原变换
        if 'HorizontalFlip' in str(transform):
            pred = np.fliplr(pred)
        if 'VerticalFlip' in str(transform):
            pred = np.flipud(pred)
        if 'Transpose' in str(transform):
            pred = np.transpose(pred)
            
        preds.append(pred)
    
    # 平均所有预测结果
    final_pred = np.mean(preds, axis=0)
    return (final_pred > threshold).astype(np.uint8)

In [None]:
def main():
    # 设置随机种子
    seed_everything(SEED)
    
    # 加载数据
    try:
        print("正在加载训练数据...")
        train_mask = pd.read_csv('数据集/train_mask.csv', sep='\t', names=['name', 'mask'])
        train_mask['name'] = train_mask['name'].apply(lambda x: '数据集/train/' + x)
    except Exception as e:
        print(f"加载数据时出错: {e}")
        print("请确保'数据集/train_mask.csv'文件存在并且格式正确!")
        return
    
    print(f"已加载 {len(train_mask)} 条训练数据")
    
    # 分割训练集和验证集
    train_idx, valid_idx = [], []
    for i in range(len(train_mask)):
        if i % 7 == 0:
            valid_idx.append(i)
        else:
            train_idx.append(i)
    
    train_df = train_mask.iloc[train_idx].reset_index(drop=True)
    valid_df = train_mask.iloc[valid_idx].reset_index(drop=True)
    
    print(f"训练集: {len(train_df)} 样本, 验证集: {len(valid_df)} 样本")
    
    # 创建数据集和数据加载器
    train_ds = BuildingSegmentationDataset(
        train_df['name'].values,
        train_df['mask'].fillna('').values,
        transform=train_transform
    )
    
    valid_ds = BuildingSegmentationDataset(
        valid_df['name'].values,
        valid_df['mask'].fillna('').values,
        transform=valid_transform
    )
    
    train_loader = D.DataLoader(
        train_ds, batch_size=BATCH_SIZE, shuffle=True, 
        num_workers=4, pin_memory=True
    )
    
    valid_loader = D.DataLoader(
        valid_ds, batch_size=BATCH_SIZE, shuffle=False, 
        num_workers=4, pin_memory=True
    )
    
    # 初始化模型
    model = UNet(n_channels=3, n_classes=1, bilinear=False)
    model.to(DEVICE)
    
    # 打印模型摘要
    print(f"模型已创建并加载到设备: {DEVICE}")
    total_params = sum(p.numel() for p in model.parameters())
    print(f"模型总参数量: {total_params:,}")
    
    # 优化器和学习率调度
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=3, verbose=True
    )
    
    # 损失函数
    criterion = CombinedLoss(dice_weight=0.7, focal_weight=0.3)
    
    # 使用改进的训练循环
    best_dice, best_epoch = train_with_checkpoints(
        model, train_loader, valid_loader, optimizer, 
        criterion, scheduler, DEVICE, EPOCHS, checkpoint_dir='model_checkpoints'
    )
    
    print(f"训练完成! 最佳Dice分数: {best_dice:.4f} 在第{best_epoch}轮")
    return model

In [17]:
def predict_test_set():
    # 加载最佳模型进行预测
    model = UNet(n_channels=3, n_classes=1, bilinear=False)
    model.load_state_dict(torch.load('best_building_segmentation_model.pth'))
    model.to(DEVICE)
    
    # 创建测试数据集
    test_paths = []
    # 确保这个路径指向您的测试图像目录
    for file in os.listdir('数据集/test_a'):
        if file.endswith('.jpg') or file.endswith('.tif'):
            test_paths.append(os.path.join('数据集/test_a', file))
    
    # 确保使用的是正确的BuildingSegmentationDataset类
    test_ds = BuildingSegmentationDataset(
        test_paths,
        None,  # 测试集没有掩码数据
        transform=valid_transform
    )
    
    test_loader = D.DataLoader(
        test_ds, batch_size=BATCH_SIZE, shuffle=False,
        num_workers=4, pin_memory=True
    )
    
    # 修改predict函数的调用，因为测试集没有文件名
    model.eval()
    results = []
    test_files = [os.path.basename(p) for p in test_paths]
    
    with torch.no_grad():
        i = 0
        for images in tqdm(test_loader):
            images = images.to(DEVICE)
            outputs = model(images)
            preds = torch.sigmoid(outputs)
            
            # 处理每个批次的预测
            for pred in preds:
                if i >= len(test_files):  # 安全检查
                    break
                    
                pred = pred.cpu().numpy().squeeze()
                pred = cv2.resize(pred, (512, 512))  # 调整为原始大小
                mask = (pred > THRESHOLD).astype(np.uint8)
                rle = rle_encode(mask)
                results.append([test_files[i], rle])
                i += 1
    
    submission = pd.DataFrame(results, columns=['name', 'mask'])
    submission.to_csv('submission.csv', index=False, header=False, sep='\t')
    print("Prediction complete! Results saved to submission.csv")

In [18]:
if __name__ == "__main__":
    main()
    predict_test_set()

Epoch 1/30


  0%|          | 0/6429 [00:00<?, ?it/s]

RuntimeError: Found dtype Byte but expected Float