### Utility Functions

In [8]:
import time
import torch
import glob
import numpy as np
import os
# flake8: noqa: E501


class AverageMeter(object):
    """
    Track the most recent value and the running weighted average of a metric.

    Adapted from https://github.com/pytorch/examples/blob/master/imagenet/main.py    # noqa: E501

    Attributes:
        val: Value passed to the most recent ``update`` call.
        avg: ``sum / count``; 0 while ``count`` is 0.
        sum: Accumulated ``val * n`` over all updates.
        count: Accumulated ``n`` over all updates.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset every statistic to zero."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """
        Record a new observation.

        Args:
            val: Observed value (typically a per-batch mean).
            n (int, optional): Weight of the observation, e.g. the batch size. Defaults to 1.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. update(..., n=0) on a fresh meter) used to
        # raise ZeroDivisionError; report 0 as the average in that case.
        self.avg = self.sum / self.count if self.count else 0


def to_device(data, device):
    """
    Move every value of a batch dictionary to ``device``.

    The previous version only moved data when ``device`` was a string starting
    with ``'npu'``; for anything else it built a ``Warning`` object without
    emitting it and returned the batch untouched, and a ``torch.device``
    argument failed on ``.startswith``. The batch is now moved to whatever
    device is requested (a no-op for data already on that device).

    Args:
        data (dict): Mapping of names to values. Values exposing a ``.to``
            method (e.g. tensors) are moved; other values are passed through.
        device (str or torch.device): Target device, e.g. ``'npu:0'`` or ``'cpu'``.

    Returns:
        dict: New dictionary with the same keys as ``data``.
    """
    return {
        key: value.to(device, non_blocking=True)
        if hasattr(value, "to") else value
        for key, value in data.items()
    }


def print_epoch_stats(phase,
                      epoch,
                      loss_meter,
                      batch_time,
                      data_time,
                      start_time,
                      iou=None):
    """
    Print a one-line summary for an epoch.

    Args:
        phase (str): Label printed after the epoch number, e.g. ``'Train'`` or ``'Val'``.
        epoch (int): Epoch number to print.
        loss_meter: Object whose ``avg`` attribute holds the mean loss.
        batch_time: Object whose ``avg`` attribute holds the mean seconds per batch.
        data_time: Object whose ``avg`` attribute holds the mean data-loading seconds.
        start_time (float): ``time.time()`` timestamp the elapsed time is measured from.
        iou (float, optional): Global IoU. When given (including ``0.0``) the
            IoU format is printed, otherwise the timing format is printed.
    """
    elapsed_min = (time.time() - start_time) / 60
    header = f"Epoch: [{epoch}] {phase} - TotalT: {elapsed_min:.1f} min, "
    # Compare against None: an IoU of exactly 0.0 is a valid result and must
    # not fall through to the timing branch.
    if iou is not None:
        print(f"{header}Loss: {loss_meter.avg:.4f}, Global IoU: {iou:.4f}")
    else:
        print(
            f"{header}BatchT: {batch_time.avg:.3f}s, DataT: {data_time.avg:.3f}s, Loss: {loss_meter.avg:.4f}"
        )


def save_model(epoch, iou, model, optimizer, experiment_name, log_path,
               best_metric, best_metric_epoch):
    """
    Save a checkpoint as ``best_iou_<epoch>_<iou>.pt`` and drop older ones.

    Args:
        epoch (int): Epoch number stored in the checkpoint and the file name.
        iou (float): IoU stored in the checkpoint and the file name.
        model: Object whose ``state_dict()`` is saved.
        optimizer: Object whose ``state_dict()`` is saved.
        experiment_name (str): Stored under the ``"model_name"`` key.
        log_path (str): Directory the checkpoint is written to.
        best_metric (float): Best IoU seen so far.
        best_metric_epoch (int): Epoch at which ``best_metric`` was reached.

    Returns:
        tuple: ``(best_metric, best_metric_epoch)``, replaced by
        ``(iou, epoch)`` when ``iou`` exceeds ``best_metric``.
    """
    save_dict = {
        "model_name": experiment_name,
        "epoch_num": epoch,
        "model_state_dict": model.state_dict(),
        "optim_state_dict": optimizer.state_dict(),
        "iou": iou
    }
    new_model_path = f"{log_path}/best_iou_{epoch}_{iou:.4f}.pt"
    # Write the new checkpoint first so that a failed save never leaves the
    # directory without any checkpoint.
    torch.save(save_dict, new_model_path)

    # Remove every stale checkpoint (previously only the first match was
    # deleted). glob.escape keeps special characters in log_path literal.
    keep = os.path.abspath(new_model_path)
    for old_model_path in glob.glob(f"{glob.escape(log_path)}/best_iou*"):
        if os.path.abspath(old_model_path) != keep:
            os.remove(old_model_path)

    if iou > best_metric:
        best_metric, best_metric_epoch = iou, epoch

    return best_metric, best_metric_epoch


def create_data_loader(dataset, batch_size, shuffle, num_workers, pin_memory):
    """
    Wrap ``dataset`` in a ``torch.utils.data.DataLoader``.

    Args:
        dataset: Dataset to iterate over.
        batch_size (int): Number of samples per batch.
        shuffle (bool): Forwarded to the loader's ``shuffle`` option.
        num_workers (int): Forwarded to the loader's ``num_workers`` option.
        pin_memory (bool): Forwarded to the loader's ``pin_memory`` option.

    Returns:
        torch.utils.data.DataLoader: Loader configured with the given options.
    """
    loader_options = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "num_workers": num_workers,
        "pin_memory": pin_memory,
    }
    return torch.utils.data.DataLoader(dataset, **loader_options)


def generate_paths(base_path, pattern, replacements):
    """
    Glob for files and derive a list of sibling paths from every match.

    Args:
        base_path (str): Directory prefix of the glob expression.
        pattern (str): Glob pattern appended to ``base_path`` after a ``/``.
        replacements (list of tuple): Argument tuples for ``str.replace``;
            each one is applied to the matched path independently.

    Returns:
        list of list of str: For every match (in sorted order), one derived
        path per entry of ``replacements``.
    """
    matches = glob.glob(f"{base_path}/{pattern}")
    matches.sort()

    derived = []
    for match in matches:
        derived.append([match.replace(*args) for args in replacements])
    return derived


def split_data(paths, train_perc, val_perc):
    """
    Shuffle ``paths`` with a fixed seed and cut it into train/val/test parts.

    Note that ``paths`` is shuffled in place and the global NumPy random
    generator is re-seeded with 17 on every call, so two lists of equal
    length are permuted identically.

    Args:
        paths (list): Items to split; modified in place by the shuffle.
        train_perc (float): Fraction of items assigned to ``"train"``.
        val_perc (float): Fraction of items assigned to ``"val"``.

    Returns:
        dict: ``{"train": [...], "val": [...], "test": [...]}`` where
        ``"test"`` receives whatever remains after train and val.
    """
    np.random.seed(17)
    np.random.shuffle(paths)
    n = len(paths)
    # Round before truncating: binary floating point makes e.g.
    # (0.7 + 0.1) * 10 evaluate to 7.999999999999999, which int() would cut
    # down to 7 and leave the validation split empty.
    train_end = int(round(train_perc * n, 9))
    val_end = int(round((train_perc + val_perc) * n, 9))
    return {
        "train": paths[:train_end],
        "val": paths[train_end:val_end],
        "test": paths[val_end:]
    }


### Data Processing

In [9]:
import tifffile
import numpy as np
import torch.utils.data
import matplotlib.pyplot as plt
# flake8: noqa: E501


def visualize_s2_img(s2_channel_paths):
    """
    Build a 'SWIR' Sentinel-2 false color composite for display.

    Args:
        s2_channel_paths (list of str): Paths to the ['B12', 'B7', 'B4'] bands of a S2 image.

    Returns:
        np.array: Image (H, W, 3) ready for visualization with e.g. plt.imshow(..)
    """
    # Read every band and stack them along a new trailing channel axis.
    bands = [tifffile.imread(band_path) for band_path in s2_channel_paths]
    composite = np.stack(bands, axis=-1)
    return scale_S1_S2_img(composite, sentinel=2)


def visualize_s1_img(path_vv, path_vh):
    """
    Calculates and returns a S1 false color composite ready for visualization.

    The three channels are VV, VH and the ratio VV / VH.

    Args:
        path_vv (str): Path to the VV band.
        path_vh (str): Path to the VH band.

    Returns:
        np.array: Image (H, W, 3) ready for visualization with e.g. plt.imshow(..)
    """
    s1_img = np.stack((tifffile.imread(path_vv), tifffile.imread(path_vh)),
                      axis=-1)
    img = np.zeros((s1_img.shape[0], s1_img.shape[1], 3), dtype=np.float32)
    img[:, :, :2] = s1_img
    # Division by zero / NaN inputs are expected here and cleaned up below.
    with np.errstate(divide="ignore", invalid="ignore"):
        img[:, :, 2] = s1_img[:, :, 0] / s1_img[:, :, 1]
    # Sanitize the composite itself. The previous code applied nan_to_num to
    # `s1_img` after its last use, so NaN/inf values still reached the output.
    img = np.nan_to_num(img)
    return scale_S1_S2_img(img, sentinel=1)


def scale_S1_S2_img(matrix, sentinel=2):
    """
    Linearly rescale a 3-channel image to [0, 1] for visual inspection.

    Every channel is mapped from its own fixed (min, max) range to (0, 1);
    values outside that range are clipped.

    Args:
        matrix (np.array): 3-dimensional image with 3 channels on the last axis.
        sentinel (int, optional): 2 selects the Sentinel-2 value ranges, any
            other value selects the Sentinel-1 ranges. Defaults to 2.

    Returns:
        np.array: float64 image with the same shape as ``matrix``, ready for
        visualization with e.g. plt.imshow(..)
    """
    if sentinel == 2:
        lower = np.array([100, 100, 100])
        upper = np.array([3500, 3500, 3500])
    else:
        lower = np.array([-23, -28, 0.2])
        upper = np.array([0, -5, 1])

    dim0, dim1, channels = matrix.shape

    # Flatten the two spatial axes so the per-channel bounds broadcast row-wise.
    flat = np.reshape(matrix, [dim0 * dim1, channels]).astype(np.float64)
    span = upper[None, :] - lower[None, :]
    flat = (flat - lower[None, :]) / span

    restored = np.reshape(flat, [dim0, dim1, channels])
    return restored.clip(0, 1)


def mask_to_img(label, color_dict):
    """
    Recode a class mask to an RGB image according to ``color_dict``.

    Args:
        label (np.array): Integer class mask, e.g. of shape (H, W).
        color_dict (dict): Maps a class value to its (R, G, B) color. Pixels
            whose value is not a key stay black.

    Returns:
        np.array: uint8 image of shape ``label.shape + (3,)``.
    """
    rgb = np.zeros(label.shape + (3, ), dtype=np.uint8)
    # Iterate the actual keys: the previous range(1, len + 1) loop raised
    # KeyError for any dictionary whose keys were not exactly 1..N.
    for class_value, color in color_dict.items():
        rgb[label == class_value] = color
    return rgb


class Sentinel1_Dataset(torch.utils.data.Dataset):
    """
    Dataset yielding min-max normalized image stacks and their label masks.

    Each entry of ``img_paths`` is an iterable of band file paths that are
    read with ``tifffile`` and stacked on the last axis; the matching entry
    of ``mask_paths`` is passed to ``tifffile.imread`` as is.
    """

    def __init__(
        self,
        img_paths,
        mask_paths,
        transforms=None,
        min_normalize=-77,
        max_normalize=26,
    ):
        """
        Args:
            img_paths (list): Per-sample collections of band file paths.
            mask_paths (list): Per-sample mask path(s).
            transforms (callable, optional): Called as
                ``transforms(image=..., mask=...)``; its return value
                replaces the sample. Defaults to None.
            min_normalize (number): Lower clipping bound of the normalization.
            max_normalize (number): Upper clipping bound of the normalization.
        """
        self.img_paths = img_paths
        self.mask_paths = mask_paths
        self.transforms = transforms
        self.min_normalize = min_normalize
        self.max_normalize = max_normalize

    def __getitem__(self, idx):
        """Return ``{"image": ..., "mask": ...}`` for sample ``idx``."""
        # Read every band of the sample and stack them channels-last.
        bands = [tifffile.imread(band_path) for band_path in self.img_paths[idx]]
        image = np.stack(bands, axis=-1)

        # Clip to the configured range, then rescale that range to [0, 1].
        image = np.clip(image, self.min_normalize, self.max_normalize)
        value_range = self.max_normalize - self.min_normalize
        image = (image - self.min_normalize) / value_range

        sample = {"image": image}
        sample["mask"] = tifffile.imread(self.mask_paths[idx])

        # Optional data augmentation on image and mask together.
        if self.transforms:
            sample = self.transforms(image=sample["image"],
                                     mask=sample["mask"])

        # A last axis shorter than 20 is treated as the channel axis and
        # moved to the front (channels-first).
        if sample["image"].shape[-1] < 20:
            sample["image"] = sample["image"].transpose((2, 0, 1))

        return sample

    def __len__(self):
        """Number of samples, i.e. entries in ``img_paths``."""
        return len(self.img_paths)

    def visualize(self, how_many=1, show_specific_index=None):
        """
        Plot the false color composite, the first returned band and the mask.

        Args:
            how_many (int): Number of figures to draw.
            show_specific_index (int, optional): When given, this sample is
                shown in every figure instead of a randomly drawn one.
        """
        for _ in range(how_many):
            index = np.random.randint(len(self.img_paths))
            if show_specific_index is not None:
                index = show_specific_index

            band_paths = self.img_paths[index]
            print(band_paths[0])

            _, axes = plt.subplots(1, 3, figsize=(30, 9))
            ax_fcc, ax_band, ax_mask = axes[0], axes[1], axes[2]

            ax_fcc.imshow(visualize_s1_img(band_paths[0], band_paths[1]))
            sample = self.__getitem__(index)

            img = sample["image"]
            ax_fcc.set_title("FCC of original S1 image", fontsize=15)  # noqa
            ax_band.imshow(img[0])  # Only the first band (VV) is displayed
            ax_band.set_title(
                f"VV band returned from the dataset, Min: {img.min():.4f}, Max: {img.max():.4f}",
                fontsize=15)

            if "mask" in sample:
                ax_mask.set_title(
                    f"Corresponding water mask: {(sample['mask'] == 1).sum()} px",
                    fontsize=15)
                ax_mask.imshow(mask_to_img(sample["mask"], {1: (0, 0, 255)}))

            plt.tight_layout()
            plt.show()


### Loss Metrics

In [10]:
import torch
import torch.nn as nn


class XEDiceLoss(nn.Module):
    """
    Weighted combination of cross-entropy loss and Dice loss.

    The result is ``alpha * cross_entropy + (1 - alpha) * dice``, where the
    Dice term is computed on the softmax probability of class index 1.
    """

    def __init__(self, alpha=0.5, EPS=1e-7):
        """
        Args:
            alpha (float): Weight of the cross-entropy term.
            EPS (float): Constant added to the Dice denominator.
        """
        super().__init__()
        self.cross_entropy = nn.CrossEntropyLoss()
        self.alpha = alpha
        self.EPS = EPS

    def forward(self, preds, targets):
        """Return the combined loss for logits ``preds`` and class-index ``targets``."""
        xe_term = self.alpha * self.cross_entropy(preds, targets)
        dice_term = (1 - self.alpha) * self.calculate_dice_loss(preds, targets)
        return xe_term + dice_term

    def calculate_dice_loss(self, preds, targets):
        """Return ``1 - 2 * sum(p * t) / (sum(p + t) + EPS)`` for class 1."""
        target_mask = targets.float()
        # Probability assigned to class index 1 along the channel dimension.
        class1_prob = torch.softmax(preds, dim=1)[:, 1]
        overlap = (class1_prob * target_mask).sum()
        total = (class1_prob + target_mask).sum()
        return 1 - (2.0 * overlap) / (total + self.EPS)


# Metric Calculation
def tp_fp_fn(preds, targets):
    """
    Count true positives, false positives and false negatives.

    Args:
        preds (torch.Tensor): Binary (0/1) predictions.
        targets (torch.Tensor): Binary (0/1) ground truth of matching shape.

    Returns:
        tuple: ``(tp, fp, fn)`` as plain Python numbers.
    """
    true_pos = (preds * targets).sum()
    false_pos = preds.sum() - true_pos
    false_neg = targets.sum() - true_pos
    return true_pos.item(), false_pos.item(), false_neg.item()


### Training Function

In [11]:
import time
import torch


def train_epochs(model, train_loader, val_loader, optimizer, loss_func,
                 experiment_name, log_path, scheduler, N_EPOCHS,
                 EARLY_STOP_THRESHOLD, EARLY_STOP_PATIENCE, DEVICE, EPS):
    """
    Train and validate ``model`` for up to ``N_EPOCHS`` epochs.

    After every epoch the global validation IoU is computed; a checkpoint is
    saved whenever it beats the best value so far, and training stops early
    once the IoU has dropped below its previous best more than
    ``EARLY_STOP_PATIENCE`` times in a row.

    Args:
        model: Network to train; called on ``data["image"]``.
        train_loader: Iterable of batch dicts with ``"image"`` and ``"mask"``.
        val_loader: Iterable of batch dicts with ``"image"`` and ``"mask"``.
        optimizer: Optimizer stepped once per training batch.
        loss_func: Callable ``loss_func(preds, mask)`` returning the loss.
        experiment_name (str): Forwarded to ``save_model``.
        log_path (str): Forwarded to ``save_model``.
        scheduler: Stepped with the validation IoU once the epoch number
            exceeds ``EARLY_STOP_THRESHOLD``.
        N_EPOCHS (int): Maximum number of epochs.
        EARLY_STOP_THRESHOLD (int): Epoch number after which the scheduler
            starts being stepped.
        EARLY_STOP_PATIENCE (int): Allowed number of consecutive
            non-improving epochs before stopping.
        DEVICE: Device the batches are moved to via ``to_device``.
        EPS (float): Constant added to the IoU denominator.

    Returns:
        dict: ``model``, ``optimizer``, ``scheduler``, ``best_metric``,
        ``best_metric_epoch``, ``total_epochs`` and ``training_time``
        (seconds).
    """
    best_metric, best_metric_epoch = 0, 0
    early_stop_counter, best_val_early_stopping_metric = 0, 0
    start_time = time.time()
    # Defined up front so the summary below is valid even when the loop body
    # never runs (N_EPOCHS == 0 previously raised UnboundLocalError).
    curr_epoch_num = 0

    for curr_epoch_num in range(1, N_EPOCHS + 1):
        # Training
        model.train()
        train_loss = AverageMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()
        end = time.time()

        for data in train_loader:
            data_time.update(time.time() - end)
            data = to_device(data, DEVICE)
            optimizer.zero_grad()
            preds = model(data["image"])
            loss = loss_func(preds, data["mask"].long())
            loss.backward()
            optimizer.step()
            train_loss.update(loss.item(), data["image"].size(0))
            batch_time.update(time.time() - end)
            end = time.time()

        print_epoch_stats('Train', curr_epoch_num, train_loss, batch_time,
                          data_time, start_time)

        # Validation
        model.eval()
        val_loss, tps, fps, fns = AverageMeter(), 0, 0, 0
        with torch.no_grad():
            for data in val_loader:
                data = to_device(data, DEVICE)
                preds = model(data["image"])
                loss = loss_func(preds, data["mask"].long())
                val_loss.update(loss.item(), data["image"].size(0))
                # Class 1 is predicted where its softmax probability exceeds 0.5.
                preds_binary = (torch.softmax(preds, dim=1)[:, 1] > 0.5).long()
                tp, fp, fn = tp_fp_fn(preds_binary, data["mask"])
                tps += tp
                fps += fp
                fns += fn

        # IoU accumulated over the whole validation set, not averaged per batch.
        iou_global = tps / (tps + fps + fns + EPS)
        print_epoch_stats('Val', curr_epoch_num, val_loss, batch_time,
                          data_time, start_time, iou_global)

        # Scheduler step, checkpointing and early stopping
        if curr_epoch_num > EARLY_STOP_THRESHOLD:
            scheduler.step(iou_global)
        if iou_global > best_metric:
            best_metric, best_metric_epoch = save_model(
                curr_epoch_num, iou_global, model, optimizer, experiment_name,
                log_path, best_metric, best_metric_epoch)
        if iou_global < best_val_early_stopping_metric:
            early_stop_counter += 1
        else:
            best_val_early_stopping_metric, early_stop_counter = iou_global, 0
        if early_stop_counter > EARLY_STOP_PATIENCE:
            print("Early Stopping")
            break

    # Summary of the run
    return {
        "model": model,
        "optimizer": optimizer,
        "scheduler": scheduler,
        "best_metric": best_metric,
        "best_metric_epoch": best_metric_epoch,
        "total_epochs": curr_epoch_num,
        "training_time": time.time() - start_time
    }


### Model

In [12]:
import torch.nn as nn
import segmentation_models_pytorch as smp


class FloodwaterSegmentationModel(nn.Module):
    """Thin ``nn.Module`` wrapper around a ``segmentation_models_pytorch`` U-Net."""

    def __init__(self,
                 encoder_name="resnet34",
                 encoder_weights="imagenet",
                 in_channels=2,
                 classes=2):
        """
        Build the wrapped U-Net.

        Args:
            encoder_name (str): Name of the encoder to use. Defaults to 'resnet34'.
            encoder_weights (str): Pretrained weights of the encoder. Defaults to 'imagenet'.
            in_channels (int): Number of input channels. Defaults to 2 (VV and VH) for Sentinel-1 data.
            classes (int): Number of output classes. Defaults to 2 for binary classification.
        """
        super().__init__()
        unet_options = {
            "encoder_name": encoder_name,
            "encoder_weights": encoder_weights,
            "in_channels": in_channels,
            "classes": classes,
        }
        self.model = smp.Unet(**unet_options)

    def forward(self, x):
        """
        Run the forward pass of the wrapped model.

        Args:
            x (torch.Tensor): Input data.

        Returns:
            torch.Tensor: Output of the wrapped model.
        """
        output = self.model(x)
        return output


## Main

In [44]:
import torch_npu  # noqa:F401
import torch
import albumentations
import torch.optim
import torch.utils.data
import sys
import os
import segmentation_models_pytorch as smp
import time
# flake8: noqa: E402


# Dataset location, split fractions, normalization bounds and training crop size.
BASE_PATH = "/home/HW/Pein/Floodwater/dataset"
# NOTE(review): TRAIN_PERCENTAGE is reassigned to 0.7 in a later cell, but the
# splits below are computed here with 0.1 — confirm which value is intended.
TRAIN_PERCENTAGE = 0.1
VAL_PERCENTAGE = 0.1
# Only informational: split_data assigns everything after train+val to "test".
TEST_PERCENTAGE = round(1 - TRAIN_PERCENTAGE - VAL_PERCENTAGE, 2)
MIN_NORMALIZE = -77
MAX_NORMALIZE = 26
TRAIN_CROP_SIZE = 256

# Every sample is located through its LabelWater.tif file; the band paths are
# derived by replacing the file name. The label "replacement" is an identity.
label_replacements = [("LabelWater.tif", "LabelWater.tif")]
image_replacements = [("LabelWater.tif", "VV.tif"),
                      ("LabelWater.tif", "VH.tif")]

# Each entry of label_paths is a one-element list; each entry of image_paths
# is a [VV, VH] pair.
label_paths = generate_paths(BASE_PATH, "chips/*/s1/*/LabelWater.tif",
                             label_replacements)
image_paths = generate_paths(BASE_PATH, "chips/*/s1/*/LabelWater.tif",
                             image_replacements)

# Both lists are split separately. They stay aligned only because split_data
# re-seeds the RNG on every call and both lists have the same length.
label_splits = split_data(label_paths, TRAIN_PERCENTAGE, VAL_PERCENTAGE)
image_splits = split_data(image_paths, TRAIN_PERCENTAGE, VAL_PERCENTAGE)


In [48]:
# Data augmentation for training: random crop followed by random 90-degree
# rotations and horizontal/vertical flips.
train_transforms = albumentations.Compose(transforms=[
    albumentations.RandomCrop(height=TRAIN_CROP_SIZE, width=TRAIN_CROP_SIZE),
    albumentations.RandomRotate90(),
    albumentations.HorizontalFlip(),
    albumentations.VerticalFlip(),
])



In [49]:

# Build the Sentinel1_Dataset instances. Only the training set is augmented;
# validation and test samples are returned without transforms.
train_dataset = Sentinel1_Dataset(img_paths=image_splits['train'],
                                  mask_paths=label_splits['train'],
                                  transforms=train_transforms,
                                  min_normalize=MIN_NORMALIZE,
                                  max_normalize=MAX_NORMALIZE)
val_dataset = Sentinel1_Dataset(img_paths=image_splits['val'],
                                mask_paths=label_splits['val'],
                                transforms=None,
                                min_normalize=MIN_NORMALIZE,
                                max_normalize=MAX_NORMALIZE)
test_dataset = Sentinel1_Dataset(img_paths=image_splits['test'],
                                 mask_paths=label_splits['test'],
                                 transforms=None,
                                 min_normalize=MIN_NORMALIZE,
                                 max_normalize=MAX_NORMALIZE)


### Original input shape: (2, 512, 512), cropped to (2, 256, 256) for training. The 2 channels are the VV and VH bands.
### Original output shape: (512, 512), cropped to (256, 256) for training.

In [50]:
# Inspect up to the first 10 training samples: print the shapes of the first
# one and, for each, the fraction of non-zero pixels in the label mask.
# Bounded by the dataset length so a small split does not raise IndexError.
for i in range(min(10, len(train_dataset))):
    sample = train_dataset[i]
    image = sample["image"]
    label = sample["mask"]
    if i == 0:
        print(f"Sample {i} dimensions: {image.shape}")
        print(f"Sample {i} label dimensions: {label.shape}")
    # label.size is the total pixel count, so non-square masks are handled too.
    print(f'ratio is {np.round(np.count_nonzero(label) / label.size, 3)}')


Sample 0 dimensions: (2, 512, 512)
Sample 0 label dimensions: (512, 512)
ratio is 1.0
ratio is 0.523
ratio is 0.008
ratio is 0.107
ratio is 0.0
ratio is 0.023
ratio is 0.113
ratio is 0.851
ratio is 0.994
ratio is 0.264


In [None]:

# Constants
NUM_WORKERS = 4
PIN_MEMORY = False
BATCH_SIZE = 64
EPS = 1e-7
EXPERIMENT_NAME = "Exp1_res34"
BACKBONE = "resnet34"
PATIENCE = 5
N_EPOCHS = 200
LEARNING_RATE = 1e-2
LOG_PATH = f"/home/HW/Pein/Floodwater/trained_models/{EXPERIMENT_NAME}"
os.makedirs(LOG_PATH, exist_ok=True)

# Epoch number after which train_epochs starts stepping the LR scheduler.
EARLY_STOP_THRESHOLD = 5
EARLY_STOP_PATIENCE = PATIENCE * 5

# Use NPU index 7 when an NPU is available, otherwise fall back to the CPU.
DEVICE = 'npu:7' if torch.npu.is_available() else 'cpu'
if torch.npu.is_available() and DEVICE.startswith('npu'):
    print(f'Using NPU: {DEVICE} ...')
else:
    print('Using CPU ...')


# NOTE(review): the datasets were already split in an earlier cell using
# TRAIN_PERCENTAGE = 0.1; reassigning it here only changes the values printed
# below, not the actual split — confirm which fraction is intended.
TRAIN_PERCENTAGE = 0.7
VAL_PERCENTAGE = 0.1
TEST_PERCENTAGE = round(1 - TRAIN_PERCENTAGE - VAL_PERCENTAGE, 2)
# Print the basic configuration
print(f"""
TRAIN_PERCENTAGE = {TRAIN_PERCENTAGE}
VAL_PERCENTAGE = {VAL_PERCENTAGE}
TEST_PERCENTAGE = {TEST_PERCENTAGE}
NUM_WORKERS = {NUM_WORKERS}
PIN_MEMORY = {PIN_MEMORY}
BATCH_SIZE = {BATCH_SIZE}
EPS = {EPS}
EXPERIMENT_NAME = '{EXPERIMENT_NAME}'
BACKBONE = '{BACKBONE}'
PATIENCE = {PATIENCE}
N_EPOCHS = {N_EPOCHS}
LEARNING_RATE = {LEARNING_RATE}
EARLY_STOP_THRESHOLD = {EARLY_STOP_THRESHOLD}
EARLY_STOP_PATIENCE = {EARLY_STOP_PATIENCE}
""")

# Data loaders for the training, validation and test sets
# (only the training loader shuffles).
train_loader = create_data_loader(train_dataset,
                                  BATCH_SIZE,
                                  shuffle=True,
                                  num_workers=NUM_WORKERS,
                                  pin_memory=PIN_MEMORY)
val_loader = create_data_loader(val_dataset,
                                BATCH_SIZE,
                                shuffle=False,
                                num_workers=NUM_WORKERS,
                                pin_memory=PIN_MEMORY)
test_loader = create_data_loader(test_dataset,
                                 BATCH_SIZE,
                                 shuffle=False,
                                 num_workers=NUM_WORKERS,
                                 pin_memory=PIN_MEMORY)

# Model initialization.
# NOTE(review): smp.Unet is instantiated directly here; the
# FloodwaterSegmentationModel wrapper defined in this file is not used.
model = smp.Unet(encoder_name=BACKBONE,
                 encoder_weights="imagenet",
                 in_channels=2,
                 classes=2).to(DEVICE)

loss_func = XEDiceLoss().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# mode="max" because the scheduler is stepped with the validation IoU.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                       mode="max",
                                                       factor=0.5,
                                                       patience=PATIENCE,
                                                       verbose=True)
# Training metric initialization.
# NOTE(review): train_epochs initializes its own copies of these four values,
# so these module-level variables appear to be unused.
best_metric, best_metric_epoch = 0, 0
early_stop_counter, best_val_early_stopping_metric = 0, 0


In [None]:

# Start training
train_result = train_epochs(model, train_loader, val_loader, optimizer,
                            loss_func, EXPERIMENT_NAME, LOG_PATH, scheduler,
                            N_EPOCHS, EARLY_STOP_THRESHOLD,
                            EARLY_STOP_PATIENCE, DEVICE, EPS)

# Print the summary returned by train_epochs
print("Training Completed:")
print(
    f"Best IOU: {train_result['best_metric']} at Epoch: {train_result['best_metric_epoch']}"
)
print(f"Total Training Time: {train_result['training_time']} seconds")