In [None]:
pip install segmentation-models-pytorch

In [None]:
import torch
import torch.nn as nn
import timm
import segmentation_models_pytorch as smp
import segmentation_models_pytorch.utils as smp_utils
import numpy as np
import rasterio
from torch.utils.data import Dataset, DataLoader
from scipy.sparse import load_npz
import cv2
from torchvision import transforms
import torch.nn.functional as F
import scipy.sparse as sp
from timm.layers import trunc_normal_
DEVICE = 'cuda'
from torch.optim.lr_scheduler import CosineAnnealingLR
from segmentation_models_pytorch.utils.meter import AverageValueMeter
from tqdm import tqdm as tqdm
import sys
import pickle
import networkx as nx

In [None]:
# Sample-id arrays for each data split. Every variable is loaded from the file
# of the same name: `val_list` feeds the loader used for early stopping in the
# training loop and `test_list` feeds the loader used for the final evaluation.
# (Previously `test_list` read valid.npy and `val_list` read test.npy.)
train_list = np.load("/kaggle/input/data-split-water/train.npy")
val_list = np.load("/kaggle/input/data-split-water/valid.npy")
test_list = np.load("/kaggle/input/data-split-water/test.npy")
class CustomDataset(Dataset):
    """Dataset pairing a raster image with its superpixel graph and label mask.

    Each item is ``((image, feature, adj, mapping), mask)``:

    * ``image``   - first three bands of the ``.tif`` file, passed through
      ``transforms.ToTensor``.
    * ``feature`` - superpixel feature matrix (float32), min-max normalised
      per column.
    * ``adj``     - sparse COO tensor built from ``water_<id>_adj.npz``.
    * ``mapping`` - sparse COO tensor built from ``water_<id>.npz``
      (presumably the pixel-to-superpixel assignment; verify against the
      code that produced the files).
    * ``mask``    - label image read as grayscale, float32, shape ``(1, H, W)``.

    Args:
        graph_root_dir: Folder holding the ``*_superpixel.npy``, ``*_adj.npz``
            and mapping ``*.npz`` files.
        label_root_dir: Folder holding the ``water_<id>.png`` masks and the
            ``sar_image_<id>.tif`` images.
        num_list: Iterable of sample ids used to build the file names.
    """

    def __init__(self, graph_root_dir,label_root_dir,num_list):
        self.features = []
        self.mapping = []
        self.label = []
        self.images = []
        self.adj = []
        # Only the file paths are collected here; files are read lazily in
        # __getitem__.
        for num in num_list:
            sample_id = str(num)
            self.features.append(graph_root_dir+"/water_"+sample_id+"_superpixel.npy")
            self.adj.append(graph_root_dir+"/water_"+sample_id+"_adj.npz")
            self.mapping.append(graph_root_dir+"/"+"water_"+sample_id+".npz")
            self.label.append(label_root_dir+"/"+"water_"+sample_id+".png")
            self.images.append(label_root_dir+"/"+"sar_image_"+sample_id+".tif")
        self.transformer = transforms.Compose([
            transforms.ToTensor(),
        ])

    def __len__(self):
        """Return the number of samples."""
        return len(self.features)

    @staticmethod
    def _sparse_to_tensor(matrix):
        """Convert any SciPy sparse matrix into a float32 ``torch`` sparse COO tensor.

        ``load_npz`` returns the matrix in whatever format it was saved in, and
        only the COO format exposes ``.row`` / ``.col``; converting with
        ``tocoo()`` first makes the loader independent of the on-disk format.
        """
        coo = matrix.tocoo()
        indices = torch.from_numpy(np.vstack((coo.row, coo.col)).astype(np.int64))
        values = torch.from_numpy(coo.data.astype(np.float32))
        return torch.sparse_coo_tensor(indices=indices, values=values, size=torch.Size(coo.shape))

    def __getitem__(self, idx):
        """Load sample ``idx`` from disk.

        Raises:
            FileNotFoundError: If the label image cannot be read by OpenCV.
        """
        with rasterio.open(self.images[idx]) as src:
            data = src.read()
        # Keep the first three bands and move channels last for ToTensor.
        image = data[:3]
        image = image.transpose(1, 2, 0)
        image = self.transformer(image)

        # Superpixel-level node features.
        feature = np.load(self.features[idx]).astype(np.float32)
        feature = self.normalize_columns(feature)

        adj = self._sparse_to_tensor(load_npz(self.adj[idx]))
        mapping = self._sparse_to_tensor(load_npz(self.mapping[idx]))

        # cv2.imread returns None instead of raising when the file is missing
        # or unreadable; fail with a clear message rather than AttributeError.
        mask = cv2.imread(self.label[idx],0)
        if mask is None:
            raise FileNotFoundError("Could not read label image: {}".format(self.label[idx]))
        mask = mask.astype(np.float32)
        mask = mask.reshape(1, mask.shape[0], mask.shape[1])
        return ((image,feature,adj,mapping),mask)

    def normalize_columns(self,matrix):
        """Min-max normalise every column of ``matrix`` into the range [0, 1].

        Columns whose values are all equal are mapped to zeros (their
        denominator is replaced by 1 to avoid division by zero).
        """
        col_max = np.max(matrix, axis=0)
        col_min = np.min(matrix, axis=0)
        denominator = col_max - col_min
        denominator[denominator == 0] = 1  # constant column: avoid dividing by 0
        return (matrix - col_min) / denominator

# Mini-batch size shared by the three loaders below.
batch_size = 8
# CustomDataset takes (graph_root_dir, label_root_dir, num_list): the first
# path is searched for the superpixel features / adjacency / mapping files,
# the second for the label PNGs and the .tif images.
train_dataset = CustomDataset("/kaggle/input/train-image-512/train_image_512","/kaggle/input/image-graph-data",train_list)  # samples listed in train_list
# Only the training loader shuffles.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = CustomDataset("/kaggle/input/train-image-512/train_image_512","/kaggle/input/image-graph-data",test_list)  # samples listed in test_list
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

val_dataset = CustomDataset("/kaggle/input/train-image-512/train_image_512","/kaggle/input/image-graph-data",val_list)  # samples listed in val_list
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
# for image,label in test_dataloader:
#     print(image[0].dtype,image[1].dtype,image[2].dtype,image[3].dtype)

In [None]:
# Convolution block: 3x3 convolution + batch norm + ReLU
class Conv_block(nn.Module):
    """Bias-free 3x3 convolution (padding 1) -> BatchNorm2d -> in-place ReLU.

    The convolution weight is re-initialised with Kaiming-uniform (fan-in,
    ReLU gain); batch-norm scale/shift are set to 1/0.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*layers)
        # Initialise the parameters of every sub-module.
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(module):
        """Parameter initialisation applied to each sub-module."""
        if isinstance(module, nn.Conv2d):
            nn.init.kaiming_uniform_(module.weight, mode="fan_in", nonlinearity="relu")
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.BatchNorm2d):
            nn.init.constant_(module.weight, 1)
            nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU to ``x``."""
        return self.conv(x)
    
# Convolution block: 1x1 convolution + batch norm + ReLU
class Conv_block_1(nn.Module):
    """Bias-free 1x1 convolution -> BatchNorm2d -> in-place ReLU.

    The convolution weight is re-initialised with Kaiming-uniform (fan-in,
    ReLU gain); batch-norm scale/shift are set to 1/0.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = [
            nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*layers)
        # Initialise the parameters of every sub-module.
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(module):
        """Parameter initialisation applied to each sub-module."""
        if isinstance(module, nn.Conv2d):
            nn.init.kaiming_uniform_(module.weight, mode="fan_in", nonlinearity="relu")
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.BatchNorm2d):
            nn.init.constant_(module.weight, 1)
            nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU to ``x``."""
        return self.conv(x)
    
# Plain 1x1 convolution
class Conv(nn.Module):
    """Bias-free 1x1 convolution with Kaiming-uniform (fan-in, ReLU gain) weights."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        pointwise = nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False)
        self.conv = nn.Sequential(pointwise)
        # Initialise the parameters of every convolution in the module.
        for layer in self.modules():
            if not isinstance(layer, nn.Conv2d):
                continue
            nn.init.kaiming_uniform_(layer.weight, mode="fan_in", nonlinearity="relu")
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """Apply the 1x1 convolution to ``x``."""
        return self.conv(x)
    
# Up-sampling module (up-sampling + concatenation + convolution block)
class Up(nn.Module):
    """Decoder stage: 2x transposed-conv up-sampling, skip concatenation, Conv_block.

    ``in_ch`` is the channel count of the low-resolution input; the skip
    tensor must carry ``out_ch`` channels so the concatenation has
    ``out_ch*2`` channels.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2)
        self.conv = Conv_block(out_ch*2, out_ch)

    def forward(self, x1, x2):
        """Up-sample ``x1``, concatenate it after the skip tensor ``x2`` and fuse."""
        upsampled = self.up(x1)
        fused = torch.cat([x2, upsampled], dim=1)
        return self.conv(fused)


# Multi-head attention; takes the token features and an optional mask and
# returns the attended features with the residual connection already added.
class MultiHeadAttention(nn.Module):
    """Pre-norm multi-head self-attention with a residual connection.

    Args:
        n_head: Number of attention heads.
        d_model: Feature size of the input/output tokens.
        qkv_channels: Per-head size of the query/key/value vectors.
        dropout: Dropout probability applied to the attention weights and to
            the output projection.
    """

    def __init__(self, n_head, d_model, qkv_channels, dropout=0.1):
        super().__init__()

        self.n_head = n_head
        self.qkv_channels = qkv_channels
        self.scale = qkv_channels ** 0.5

        # One fused projection produces q, k and v for every head.
        self.w_qkv = nn.Linear(d_model, 3*n_head * qkv_channels, bias=False)
        self.fc = nn.Linear(n_head * qkv_channels, d_model, bias=False)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self,x, mask=None):
        """Attend over the token axis of ``x`` (shape ``(N, L, C)``).

        ``mask``, when given, is densified and broadcast over the head axis;
        positions where it equals 0 are excluded from the softmax.
        """
        batch, tokens, _ = x.shape
        shortcut = x

        normed = self.layer_norm(x)
        projected = self.w_qkv(normed).reshape(batch, tokens, 3, self.n_head, self.qkv_channels)
        # -> (3, N, heads, L, qkv_channels), then split into q / k / v
        q, k, v = projected.permute(2, 0, 3, 1, 4).unbind(0)

        scores = (q / self.scale) @ k.transpose(-2, -1)
        if mask is not None:
            dense_mask = mask.unsqueeze(1).to_dense()
            scores = scores.masked_fill(dense_mask == 0, -1e9)

        weights = self.dropout1(scores.softmax(dim=-1))
        context = (weights @ v).transpose(1, 2).reshape(batch, tokens, -1)
        projected_out = self.dropout2(self.fc(context))
        return projected_out + shortcut
# Multi-layer perceptron with optional graph propagation
class PositionwiseFeedForward(nn.Module):
    """Pre-norm two-layer feed-forward block with a residual connection.

    When an adjacency matrix is supplied, the output of each linear layer is
    left-multiplied by it (``adj @ x``), so features are propagated between
    tokens; without one the block is a plain position-wise MLP.

    Args:
        d_in: Feature size of the input/output tokens.
        d_hid: Hidden feature size between the two linear layers.
        dropout: Dropout probability applied before the residual addition.
    """

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_in, d_hid) # position-wise
        self.w_2 = nn.Linear(d_hid, d_in) # position-wise
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, adj = None):
        """Return ``x + dropout(adj @ W2(relu(adj @ W1(LN(x)))))``.

        Args:
            x: Token features whose last dimension is ``d_in``.
            adj: Optional matrix applied as ``adj @ x`` after each linear
                layer. ``None`` (the default) skips the propagation; it
                previously raised ``TypeError``.
        """
        residual = x

        x = self.layer_norm(x)
        x = self.w_1(x)
        if adj is not None:
            x = adj @ x
        x = F.relu(x)
        x = self.w_2(x)
        if adj is not None:
            x = adj @ x
        x = self.dropout(x)
        # Out-of-place add: never writes into a tensor autograd may still need.
        return x + residual
    
class SuperPixelBlock(nn.Module):
    """One encoder layer: multi-head self-attention followed by the feed-forward block.

    Args:
        d_model: Feature size of the model input.
        d_inner: Hidden feature size inside the feed-forward block.
        n_head: Number of attention heads.
        qkv_channels: Per-head feature size of the query/key/value vectors.
        dropout: Dropout probability passed to both sub-layers.
    """

    def __init__(self, d_model, d_inner, n_head, qkv_channels, dropout=0.1):
        super().__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, qkv_channels, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(self, enc_input, slf_attn_mask=None):
        """Run attention, then the feed-forward block.

        Note that ``slf_attn_mask`` is NOT given to the attention layer (which
        therefore attends over all tokens); it is forwarded to the
        feed-forward block as its ``adj`` argument.
        """
        attended = self.slf_attn(enc_input)
        return self.pos_ffn(attended, slf_attn_mask)
    
class SPTUNet(nn.Module):
    """Two-branch segmentation network: pixel-level U-Net + superpixel transformer.

    The pixel branch is an EfficientNet-B4 feature encoder (via ``timm``)
    followed by a U-Net style decoder. The superpixel branch embeds the
    per-superpixel features, runs a stack of ``SuperPixelBlock`` layers and
    projects the result back onto the pixel grid with ``mapping_matrix``.
    Both 24-channel maps are summed and reduced to one sigmoid probability map.

    Args:
        projection_dim: Embedding size of the superpixel tokens.
        transformer_layers: Number of ``SuperPixelBlock`` layers.
        transformer_heads: Attention heads per layer.
        qkv_dim: Per-head query/key/value size.
        hidden_dim: Hidden size of the feed-forward blocks.
        dropout: Dropout probability inside the transformer layers.
    """

    def __init__(self, projection_dim, transformer_layers, transformer_heads,qkv_dim, hidden_dim,dropout):
        super(SPTUNet, self).__init__()
        self.projection_dim = projection_dim
        self.transformer_layers = transformer_layers
        self.transformer_heads = transformer_heads

        # Pixel-level branch
        self.encoder = timm.create_model("efficientnet_b4",pretrained = True,in_chans=3,features_only = True,drop_rate = 0.2,drop_path_rate=0.2)
        self.up6 = Up(448, 160)
        self.up7 = Up(160, 56)
        self.up8 = Up(56, 32)
        self.up9 = Up(32, 24)
        self.up10 = nn.ConvTranspose2d(24,24,2,stride = 2)

        # Superpixel-level branch: stack of transformer layers
        self.transformers = nn.ModuleList()
        for _ in range(transformer_layers):
            self.transformers.append(SuperPixelBlock(projection_dim, hidden_dim, transformer_heads, qkv_dim, dropout=dropout))

        self.layer_norm = nn.LayerNorm(projection_dim, eps=1e-6)
        self.graph_outputs = nn.Linear(projection_dim, 24, bias=True)
        # Learned positional embedding for 512 superpixel tokens.
        self.pos_embed = nn.Parameter(torch.randn(1, 512, projection_dim) * .02)
        trunc_normal_(self.pos_embed, std=.02)
        # Each superpixel is described by 5 input features.
        self.graph_inputs = nn.Linear(5, projection_dim)
        self.batch_norm = nn.BatchNorm2d(24)

        # Final output head
        self.output_conv = Conv_block(24,24)
        self.conv_1 = Conv(24,1)

    def forward(self,images,graph_inputs, adj_inputs, mapping_matrix):
        """Predict a per-pixel probability map.

        Args:
            images: Image batch fed to the encoder.
            graph_inputs: Superpixel features, last dimension 5.
            adj_inputs: Batched adjacency; densified and used by every
                transformer layer's feed-forward block.
            mapping_matrix: Batched matrix multiplied (``torch.bmm``) with the
                superpixel features to scatter them onto the pixels.
                Assumes shape ``(B, H*W, num_superpixels)`` with pixels in
                row-major order - TODO confirm against the data pipeline.

        Returns:
            Sigmoid-activated tensor with one channel.
        """
        # Pixel-level branch
        c1,c2,c3,c4,c5 = self.encoder(images)
        up_6 = self.up6(c5, c4)
        up_7 = self.up7(up_6, c3)
        up_8 = self.up8(up_7, c2)
        up_9 = self.up9(up_8, c1)
        pixels = self.up10(up_9)

        # Superpixel-level branch
        adj_inputs = adj_inputs.to_dense()
        nbatches = graph_inputs.size(0)
        # Feature embedding + layer norm
        superPixels = self.graph_inputs(graph_inputs)
        superPixels = self.layer_norm(superPixels)
        # Add the positional embedding
        superPixels = superPixels + self.pos_embed
        for transformer in self.transformers:
            superPixels = transformer(superPixels,adj_inputs)

        # Project to 24 channels and scatter onto the pixel grid.
        superPixels = self.graph_outputs(superPixels)
        superPixels = torch.bmm(mapping_matrix,superPixels)
        # The bmm result is channel-LAST (B, H*W, 24). Move channels in front of
        # the pixel axis before unfolding it into (H, W); reshaping directly to
        # (B, 24, H, W) would interleave channel and spatial values.
        # NOTE: checkpoints trained with the previous direct reshape learned the
        # scrambled layout and need to be retrained.
        height, width = pixels.shape[-2:]
        superPixels = superPixels.transpose(1, 2).reshape(nbatches, 24, height, width)
        superPixels = self.batch_norm(superPixels)

        # Output head
        outputs = torch.add(pixels,superPixels)
        outputs = self.output_conv(outputs)
        outputs = self.conv_1(outputs)
        outputs = torch.sigmoid(outputs)

        return outputs
    
# Build the network used for training below: 8 transformer layers with 4 heads
# of size 32 on 128-dimensional superpixel tokens, 512-wide feed-forward blocks.
model = SPTUNet(projection_dim=128,  transformer_layers=8, transformer_heads=4, qkv_dim = 32, hidden_dim=512,dropout = 0.2)



In [None]:
# from torchinfo import summary

# summary(model, [(2,3,512,512),(2,512,5),(2,512,512),(2,262144,512)])

In [None]:
# Early-stopping helper
class EarlyStopping:
    """Stop training when the monitored score stops improving.

    The monitored value is treated as *higher is better* (this notebook
    passes the validation IoU). The parameter and attribute names still say
    ``val_loss`` for backward compatibility.
    """
    def __init__(self, patience=10, verbose=False, delta=0):
        """
        Args:
            patience (int): Number of consecutive non-improving calls after
                            which ``early_stop`` is set.
                            Default: 10
            verbose (bool): If True, prints a message on every call.
                            Default: False
            delta (float): Minimum increase over the best score required to
                            count as an improvement.
                            Default: 0
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # Score of the last saved checkpoint. Starts at -inf because the score
        # is maximised (``np.Inf`` was also removed in NumPy 2.0).
        self.val_loss_min = -np.inf
        self.delta = delta

    def __call__(self, val_loss, model):
        """Record a new score; checkpoint ``model`` if it improved.

        A score below ``best_score + delta`` counts as no improvement and
        increments the patience counter; otherwise the counter is reset and
        the model is saved.
        """
        score = val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Pickle the whole model to ``./best_model.pth`` and remember its score."""
        if self.verbose:
            print(f'Validation iou increased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model, './best_model.pth')
        self.val_loss_min = val_loss

In [None]:
class BCE_BinaryDiceLoss(torch.nn.Module):
    """Weighted sum of binary cross-entropy and soft Dice loss.

    ``loss = w * BCE + (1 - w) * Dice``. The Dice term is computed per sample
    and reduced according to ``reduction``; the BCE term is always the mean
    over all elements (for ``'none'`` that scalar is broadcast onto the
    per-sample Dice values).

    Args:
        ignore_index: Target value to exclude; matching positions are zeroed
            in both prediction and target before the loss is computed.
        w: Weight of the BCE term (the Dice term gets ``1 - w``).
        reduction: ``'mean'``, ``'sum'`` or ``'none'`` (applies to the Dice term).
        **kwargs: Accepted for signature compatibility and ignored.
    """

    def __init__(self, ignore_index=None,w = 0.6, reduction='mean', **kwargs):
        super(BCE_BinaryDiceLoss, self).__init__()
        self.smooth = 1  # suggest set a large number when target area is large,like '10|100'
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.w = w  # weight of BCE relative to the Dice loss
        self._name = "BCE_BinaryDiceLoss"  # name reported through __name__

    @property
    def __name__(self):
        """Display name used as the key of the loss entry in the epoch logs."""
        if self._name is None:
            # Local import: `re` is not imported at the top of the notebook, so
            # this branch used to raise NameError.
            import re
            name = self.__class__.__name__
            s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
            return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
        else:
            return self._name

    def forward(self, output, target, use_sigmoid=False):
        """Compute the combined loss.

        Args:
            output: Predicted probabilities (or logits when ``use_sigmoid`` is
                True); first dimension is the batch.
            target: Ground-truth tensor with the same batch size.
            use_sigmoid: Apply ``torch.sigmoid`` to ``output`` first.

        Raises:
            AssertionError: If the batch sizes differ.
            ValueError: If ``reduction`` is not one of the supported values.
        """
        assert output.shape[0] == target.shape[0], "output & target batch size don't match"

        if use_sigmoid:
            output = torch.sigmoid(output)

        if self.ignore_index is not None:
            validmask = (target != self.ignore_index).float()
            output = output.mul(validmask)  # can not use inplace for bp
            target = target.float().mul(validmask)

        dim0= output.shape[0]

        # Flatten everything but the batch axis.
        output = output.contiguous().view(dim0,-1)
        target = target.contiguous().view(dim0,-1).float()

        num = 2 * torch.sum(torch.mul(output, target), dim=1) + self.smooth
        den = torch.sum(output.abs() + target.abs(), dim=1) + self.smooth

        dice = 1 - (num / den)

        if self.reduction == 'mean':
            dice_term = dice.mean()
        elif self.reduction == 'sum':
            dice_term = dice.sum()
        elif self.reduction == 'none':
            dice_term = dice
        else:
            raise ValueError('Unexpected reduction {}'.format(self.reduction))

        bce_term = F.binary_cross_entropy(output, target)
        return self.w * bce_term+(1-self.w)*dice_term
            
# Combined BCE + Dice criterion with its default weighting (w=0.6, mean reduction).
loss = BCE_BinaryDiceLoss()


# Metrics reported for every batch; predictions are binarised at 0.5.
metrics = [
    smp.utils.metrics.IoU(threshold=0.5),
    smp.utils.metrics.Precision(threshold=0.5),
    smp.utils.metrics.Recall(threshold=0.5)
]

# Single parameter group covering the whole model.
optimizer = torch.optim.Adam([ 
    dict(params=model.parameters(), lr=0.0003),
])
# NOTE(review): TrainEpoch_avert.batch_update calls scheduler.step() after every
# batch, so T_max=10 is counted in batches, not epochs - confirm this is intended.
scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=0)
# Stop after 30 consecutive validation rounds without improvement.
early_stopping = EarlyStopping(patience=30, verbose=True)

class _MultiInputEpoch(smp.utils.train.Epoch):
    """Epoch loop for models whose input is a sequence of tensors.

    Overrides ``run`` so that every element of the input sequence is moved to
    the target device individually. Subclasses provide ``on_epoch_start`` and
    ``batch_update``.
    """

    def run(self, dataloader):
        """Iterate once over ``dataloader`` and return the averaged logs.

        Returns:
            dict mapping the loss name and each metric name to its running mean.
        """
        self.on_epoch_start()
        logs = {}
        loss_meter = AverageValueMeter()
        metrics_meters = {metric.__name__: AverageValueMeter() for metric in self.metrics}

        with tqdm(
            dataloader,
            desc=self.stage_name,
            file=sys.stdout,
            disable=not (self.verbose),
        ) as iterator:
            for x, y in iterator:
                # Build a new list rather than assigning into `x`, so the batch
                # may be a tuple and the loader's object is left untouched.
                x = [item.to(self.device) for item in x]
                y = y.to(self.device)
                loss, y_pred = self.batch_update(x, y)

                # update loss logs
                loss_value = loss.cpu().detach().numpy()
                loss_meter.add(loss_value)
                loss_logs = {self.loss.__name__: loss_meter.mean}
                logs.update(loss_logs)

                # update metrics logs
                for metric_fn in self.metrics:
                    metric_value = metric_fn(y_pred, y).cpu().detach().numpy()
                    metrics_meters[metric_fn.__name__].add(metric_value)
                metrics_logs = {k: v.mean for k, v in metrics_meters.items()}
                logs.update(metrics_logs)

                if self.verbose:
                    s = self._format_logs(logs)
                    iterator.set_postfix_str(s)

        return logs


class ValidEpoch_avert(_MultiInputEpoch):
    """Evaluation epoch: model in eval mode, no gradient tracking."""

    def __init__(self, model, loss, metrics, device="cpu", verbose=True):
        super().__init__(
            model=model,
            loss=loss,
            metrics=metrics,
            stage_name="valid",
            device=device,
            verbose=verbose,
        )

    def on_epoch_start(self):
        """Switch the model to evaluation mode."""
        self.model.eval()

    def batch_update(self, x, y):
        """Run a forward pass on the four inputs in ``x`` and compute the loss."""
        with torch.no_grad():
            prediction = self.model(x[0], x[1], x[2], x[3])
            loss = self.loss(prediction, y)
        return loss, prediction


class TrainEpoch_avert(_MultiInputEpoch):
    """Training epoch: forward, backward, optimizer step and scheduler step per batch."""

    def __init__(self, model, loss, metrics, optimizer,scheduler, device="cpu", verbose=True):
        super().__init__(
            model=model,
            loss=loss,
            metrics=metrics,
            stage_name="train",
            device=device,
            verbose=verbose,
        )
        self.optimizer = optimizer
        self.scheduler = scheduler

    def on_epoch_start(self):
        """Switch the model to training mode."""
        self.model.train()

    def batch_update(self, x, y):
        """Perform one optimisation step on the batch and return (loss, prediction)."""
        self.optimizer.zero_grad()
        prediction = self.model(x[0], x[1], x[2], x[3])
        loss = self.loss(prediction, y)
        loss.backward()
        self.optimizer.step()
        # NOTE(review): the learning-rate scheduler advances once per BATCH here,
        # so its T_max is measured in batches - confirm this is intended.
        self.scheduler.step()
        return loss, prediction
    
# Epoch runner that optimises `model` (and steps the scheduler) on each batch.
train_epoch = TrainEpoch_avert(
    model, 
    loss=loss, 
    metrics=metrics, 
    optimizer=optimizer,
    scheduler = scheduler,
    device=DEVICE,
    verbose=True,
)

# Epoch runner that evaluates `model` without gradient tracking.
valid_epoch = ValidEpoch_avert(
    model, 
    loss=loss, 
    metrics=metrics, 
    device=DEVICE,
    verbose=True,
)

# Train for at most 300 epochs; after every epoch the validation IoU is handed
# to `early_stopping`, which checkpoints the model on improvement and raises
# its `early_stop` flag once patience is exhausted.
for i in range(0, 300):   
    print('\nEpoch: {}'.format(i))
    train_logs = train_epoch.run(train_dataloader)
    valid_logs = valid_epoch.run(val_dataloader)
    # 'iou_score' is presumably the __name__ of the smp IoU metric - verify
    # against the installed segmentation_models_pytorch version.
    val_iou = valid_logs['iou_score']
    early_stopping(val_iou, model)
    if early_stopping.early_stop:
        print("Early stopping")
        break

In [None]:
# Load a previously trained checkpoint and evaluate it on the test loader.
# SECURITY NOTE: the file stores a fully pickled nn.Module, and unpickling can
# execute arbitrary code - only load checkpoints from a trusted source.
# `weights_only=False` is required for such files on PyTorch >= 2.6, where the
# default switched to weights-only loading; `map_location` lets a checkpoint
# saved on another device be restored onto DEVICE.
model = torch.load(
    '/kaggle/input/spt-unet-512/spt_Unet_512_80.pth',
    map_location=DEVICE,
    weights_only=False,
)
test_epoch = ValidEpoch_avert(
    model,
    loss=loss,
    metrics=metrics,
    device=DEVICE,
    verbose=True,
)
test_logs = test_epoch.run(test_dataloader)