In [61]:
import numpy as np
import pandas as pd
import pathlib, sys, os, random, time
import cv2, gc
from tqdm.notebook import tqdm

import matplotlib.pyplot as plt
%matplotlib inline

import warnings
warnings.filterwarnings('ignore')

import albumentations as A
from albumentations.pytorch import ToTensorV2

import rasterio
from rasterio.windows import Window

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as D
import torchvision
from torchvision import transforms as T

In [62]:
# RLE编码和解码函数
def rle_encode(im):
    '''
    im: numpy array, 1 - mask, 0 - background
    Returns run length as string formated
    '''
    pixels = im.flatten(order = 'F')
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)

def rle_decode(mask_rle, shape=(512, 512)):
    '''
    mask_rle: run-length as string formated (start length)
    shape: (height,width) of array to return 
    Returns numpy array, 1 - mask, 0 - background
    '''
    if mask_rle == '' or pd.isna(mask_rle):
        return np.zeros(shape, dtype=np.uint8)
    
    s = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0]*shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img.reshape(shape, order='F')

In [63]:
# 配置参数
SEED = 42
EPOCHS = 30
BATCH_SIZE = 16
IMAGE_SIZE = 384  # 增加图像尺寸以获取更多细节
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
THRESHOLD = 0.5  # 二值化阈值，可以通过验证集调整

In [64]:
# 设置随机种子
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(SEED)

In [65]:
# 增强的数据增强策略
train_transform = A.Compose([
    A.Resize(IMAGE_SIZE, IMAGE_SIZE),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.ShiftScaleRotate(p=0.5),
    A.RandomBrightnessContrast(p=0.5),
    A.GaussNoise(p=0.3),
    A.Normalize(
        mean=[0.625, 0.448, 0.688],
        std=[0.131, 0.177, 0.101],
    ),
    ToTensorV2(),
])

valid_transform = A.Compose([
    A.Resize(IMAGE_SIZE, IMAGE_SIZE),
    A.Normalize(
        mean=[0.625, 0.448, 0.688],
        std=[0.131, 0.177, 0.101],
    ),
    ToTensorV2(),
])

In [66]:
class BuildingSegmentationDataset(torch.utils.data.Dataset):
    def __init__(self, img_paths, mask_paths=None, transform=None):
        self.img_paths = img_paths
        self.mask_paths = mask_paths  # 这里可能存储的是掩码数据而不是路径
        self.transform = transform
        
    def __len__(self):
        return len(self.img_paths)
    
    def __getitem__(self, idx):
        # 加载图像
        img = cv2.imread(self.img_paths[idx])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        # 处理掩码数据
        if self.mask_paths is not None:
            # 检查 mask_paths 是否是数据而不是路径
            mask_data = self.mask_paths[idx]
            
            # 如果 mask_data 是字符串，且看起来是空格分隔的数字序列
            if isinstance(mask_data, str) and ' ' in mask_data:
                try:
                    # 将空格分隔的数字转换为数组
                    mask_values = list(map(int, mask_data.split()))
                    
                    # 假设这是一个扁平化的二值掩码，将其重构为方形矩阵
                    # 估计尺寸（取近似平方根）
                    size = int(np.sqrt(len(mask_values) / 2)) * 2  # 确保是偶数
                    
                    # 将扁平数组转为掩码矩阵（如果长度不匹配则填充0）
                    if len(mask_values) >= size * size:
                        mask = np.array(mask_values[:size*size]).reshape(size, size)
                    else:
                        # 填充
                        padded = mask_values + [0] * (size * size - len(mask_values))
                        mask = np.array(padded).reshape(size, size)
                    
                except Exception as e:
                    # 如果转换失败，创建空掩码
                    print(f"警告：掩码转换失败 - {str(e)}")
                    mask = np.zeros((384, 384), dtype=np.float32)
            
            # 如果是路径，尝试加载文件（但很可能不是这种情况）
            elif isinstance(mask_data, str) and (mask_data.endswith('.npy') or 
                                              mask_data.endswith('.png') or 
                                              mask_data.endswith('.jpg')):
                try:
                    if mask_data.endswith('.npy'):
                        mask = np.load(mask_data)
                    else:
                        mask = cv2.imread(mask_data, cv2.IMREAD_GRAYSCALE)
                except Exception as e:
                    print(f"警告：掩码文件加载失败 - {str(e)}")
                    mask = np.zeros((384, 384), dtype=np.float32)
            
            # 如果是NumPy数组，直接使用
            elif isinstance(mask_data, np.ndarray):
                mask = mask_data
            
            # 如果是torch.Tensor，转为NumPy
            elif isinstance(mask_data, torch.Tensor):
                mask = mask_data.cpu().numpy()
            
            # 其他情况，创建空掩码
            else:
                print(f"警告：未知的掩码数据类型 - {type(mask_data)}")
                mask = np.zeros((384, 384), dtype=np.float32)
            
            # 应用变换（如果有）
            if self.transform is not None:
                # 如果使用 albumentations
                try:
                    transformed = self.transform(image=img, mask=mask)
                    img = transformed['image']
                    mask = transformed['mask']
                # 如果使用 torchvision.transforms
                except:
                    img = self.transform(img)
                    mask = cv2.resize(mask, (384, 384), interpolation=cv2.INTER_NEAREST)
            else:
                # 确保mask格式正确
                if not isinstance(mask, np.ndarray):
                    mask = np.array(mask)
                
                # 调整尺寸
                mask = cv2.resize(mask, (384, 384), interpolation=cv2.INTER_NEAREST)
            
            # 转换为tensor
            if not isinstance(img, torch.Tensor):
                img = torch.from_numpy(img.transpose(2, 0, 1)).float() / 255.0
            
            if isinstance(mask, np.ndarray):
                mask = torch.from_numpy(mask).float()
                
            # 确保mask有通道维度
            if mask.dim() == 2:
                mask = mask.unsqueeze(0)
                
            return img, mask
        else:
            # 只返回图像
            if self.transform is not None:
                try:  # albumentations
                    transformed = self.transform(image=img)
                    img = transformed['image']
                except:  # torchvision
                    img = self.transform(img)
            else:
                img = torch.from_numpy(img.transpose(2, 0, 1)).float() / 255.0
                
            return img

In [67]:
# U-Net模型定义 - 基础模块
class DoubleConv(nn.Module):
    """(Conv2D -> BN -> ReLU) * 2"""
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    """Downscaling with maxpool then double conv"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

In [68]:
class Up(nn.Module):
    """Upscaling then double conv"""
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

In [69]:
class UNet(nn.Module):
    def __init__(self, n_channels=3, n_classes=1, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        factor = 2 if bilinear else 1
        
        # 使用更多滤波器以增加模型容量
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

In [70]:
# 优化的损失函数 - Dice Loss
class DiceLoss(nn.Module):
    def __init__(self, smooth=1.0):
        super(DiceLoss, self).__init__()
        self.smooth = smooth
        
    def forward(self, pred, target):
        pred = torch.sigmoid(pred)
        
        # 平滑处理以避免0/0的情况
        intersection = (pred * target).sum(dim=(2,3))
        union = pred.sum(dim=(2,3)) + target.sum(dim=(2,3))
        
        dice = (2.0 * intersection + self.smooth) / (union + self.smooth)
        return 1.0 - dice.mean()

In [71]:
# 优化的损失函数 - Focal Loss
class FocalLoss(nn.Module):
    def __init__(self, gamma=2.0):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        
    def forward(self, pred, target):
        pred = torch.sigmoid(pred)
        
        # 二元交叉熵损失
        bce = F.binary_cross_entropy(pred, target, reduction='none')
        
        # 应用focal loss公式
        pt = torch.exp(-bce)
        focal_loss = (1-pt)**self.gamma * bce
        
        return focal_loss.mean()

In [72]:
# 优化的组合损失函数
class CombinedLoss(nn.Module):
    def __init__(self, dice_weight=0.5, focal_weight=0.5):
        super(CombinedLoss, self).__init__()
        self.dice_weight = dice_weight
        self.focal_weight = focal_weight
        self.dice_loss = DiceLoss()
        self.focal_loss = FocalLoss()
        
    def forward(self, pred, target):
        dice = self.dice_loss(pred, target)
        focal = self.focal_loss(pred, target)
        return self.dice_weight * dice + self.focal_weight * focal

In [73]:
# 训练函数
def train_one_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    
    for images, masks in tqdm(dataloader):
        images = images.to(device)
        masks = masks.to(device)
        
        # 前向传播
        outputs = model(images)
        loss = criterion(outputs, masks)
        
        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(dataloader)

In [74]:
# 验证函数
@torch.no_grad()
def validate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    dice_scores = []
    
    for images, masks in tqdm(dataloader):
        images = images.to(device)
        masks = masks.to(device)
        
        # 前向传播
        outputs = model(images)
        loss = criterion(outputs, masks)
        
        # 计算Dice分数
        preds = (torch.sigmoid(outputs) > THRESHOLD).float()
        dice = (2 * (preds * masks).sum()) / (preds.sum() + masks.sum() + 1e-8)
        dice_scores.append(dice.item())
        
        total_loss += loss.item()
    
    return total_loss / len(dataloader), np.mean(dice_scores)

In [75]:
# 预测函数
@torch.no_grad()
def predict(model, dataloader, device, threshold=THRESHOLD):
    model.eval()
    results = []
    
    for images, filenames in tqdm(dataloader):
        images = images.to(device)
        outputs = model(images)
        preds = torch.sigmoid(outputs)
        
        # 处理每个批次的预测
        for pred, filename in zip(preds, filenames):
            pred = pred.cpu().numpy().squeeze()
            pred = cv2.resize(pred, (512, 512))  # 调整为原始大小
            mask = (pred > threshold).astype(np.uint8)
            rle = rle_encode(mask)
            results.append([filename, rle])
    
    return results

In [76]:
# 主训练循环
def main():
    # 加载数据
    # 确保这个路径正确指向您的train_mask.csv文件位置
    train_mask = pd.read_csv('数据集/train_mask.csv', sep='\t', names=['name', 'mask'])
    # 确保这个路径前缀与您的训练图像目录一致
    train_mask['name'] = train_mask['name'].apply(lambda x: '数据集/train/' + x)
    
    # 分割训练集和验证集
    train_idx, valid_idx = [], []
    for i in range(len(train_mask)):
        if i % 7 == 0:
            valid_idx.append(i)
        else:
            train_idx.append(i)
    
    train_df = train_mask.iloc[train_idx].reset_index(drop=True)
    valid_df = train_mask.iloc[valid_idx].reset_index(drop=True)
    
    # 创建数据集和数据加载器
    train_ds = BuildingSegmentationDataset(
        train_df['name'].values,
        train_df['mask'].fillna('').values,
        transform=train_transform
    )
    
    valid_ds = BuildingSegmentationDataset(
        valid_df['name'].values,
        valid_df['mask'].fillna('').values,
        transform=valid_transform
    )
    
    train_loader = D.DataLoader(
        train_ds, batch_size=BATCH_SIZE, shuffle=True, 
        num_workers=4, pin_memory=True
    )
    
    valid_loader = D.DataLoader(
        valid_ds, batch_size=BATCH_SIZE, shuffle=False, 
        num_workers=4, pin_memory=True
    )
    
    # 初始化模型
    model = UNet(n_channels=3, n_classes=1, bilinear=False)
    model.to(DEVICE)
    
    # 优化器和学习率调度
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=3, verbose=True
    )
    
    # 损失函数
    criterion = CombinedLoss(dice_weight=0.7, focal_weight=0.3)
    
    # 训练循环
    best_dice = 0
    best_epoch = 0
    
    for epoch in range(1, EPOCHS + 1):
        print(f"Epoch {epoch}/{EPOCHS}")
        
        # 训练
        train_loss = train_one_epoch(model, train_loader, optimizer, criterion, DEVICE)
        
        # 验证
        val_loss, val_dice = validate(model, valid_loader, criterion, DEVICE)
        
        # 调整学习率
        scheduler.step(val_loss)
        
        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Dice: {val_dice:.4f}")
        
        # 保存最佳模型
        if val_dice > best_dice:
            print(f"Dice improved from {best_dice:.4f} to {val_dice:.4f}. Saving model...")
            best_dice = val_dice
            best_epoch = epoch
            torch.save(model.state_dict(), 'best_building_segmentation_model.pth')
    
    print(f"Training complete! Best Dice: {best_dice:.4f} at epoch {best_epoch}")

In [77]:
def predict_test_set():
    # 加载最佳模型进行预测
    model = UNet(n_channels=3, n_classes=1, bilinear=False)
    model.load_state_dict(torch.load('best_building_segmentation_model.pth'))
    model.to(DEVICE)
    
    # 创建测试数据集
    test_paths = []
    # 确保这个路径指向您的测试图像目录
    for file in os.listdir('数据集/test_a'):
        if file.endswith('.jpg') or file.endswith('.tif'):
            test_paths.append(os.path.join('数据集/test_a', file))
    
    test_ds = BuildingSegmentationDataset(
        test_paths,
        [""] * len(test_paths),
        transform=valid_transform,
        test_mode=True
    )
    
    test_loader = D.DataLoader(
        test_ds, batch_size=BATCH_SIZE, shuffle=False,
        num_workers=4, pin_memory=True
    )
    
    # 预测并保存结果
    results = predict(model, test_loader, DEVICE)
    submission = pd.DataFrame(results, columns=['name', 'mask'])
    submission.to_csv('submission.csv', index=False, header=False, sep='\t')
    print("Prediction complete! Results saved to submission.csv")

In [78]:
if __name__ == "__main__":
    main()
    predict_test_set()

Epoch 1/30


  0%|          | 0/1608 [00:00<?, ?it/s]

警告：未知的掩码数据类型 - <class 'str'>
警告：未知的掩码数据类型 - <class 'str'>


KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/tmp/ipykernel_10028/2992724474.py", line 73, in __getitem__
    transformed = self.transform(image=img, mask=mask)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/albumentations/core/composition.py", line 493, in __call__
    self.preprocess(data)
  File "/root/miniconda3/lib/python3.12/site-packages/albumentations/core/composition.py", line 527, in preprocess
    self._check_shape_consistency(shapes, volume_shapes)
  File "/root/miniconda3/lib/python3.12/site-packages/albumentations/core/composition.py", line 778, in _check_shape_consistency
    self._check_shapes(shapes, self.is_check_shapes)
  File "/root/miniconda3/lib/python3.12/site-packages/albumentations/core/composition.py", line 705, in _check_shapes
    raise ValueError(
ValueError: Height and Width of image, mask or masks should be equal. You can disable shapes check by setting a parameter is_check_shapes=False of Compose class (do it only if you are sure about your data consistency).

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/data/_utils/worker.py", line 351, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
            ~~~~~~~~~~~~^^^^^
  File "/tmp/ipykernel_10028/2992724474.py", line 78, in __getitem__
    img = self.transform(img)
          ^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/albumentations/core/composition.py", line 479, in __call__
    raise KeyError(msg)
KeyError: 'You have to pass data to augmentations as named arguments, for example: aug(image=image)'


警告：未知的掩码数据类型 - <class 'str'>
