In [1]:
import torch
import numpy as np
import torch.nn.functional as F
from scipy.spatial.distance import directed_hausdorff

def compute_hausdorff_distance(mask1, mask2):
    """
    计算两个掩码之间的豪斯多夫距离。
    
    参数:
    mask1 (torch.Tensor): 第一个掩码，形状为 (H, W) 或 (D, H, W)。
    mask2 (torch.Tensor): 第二个掩码，形状为 (H, W) 或 (D, H, W)。
    
    返回:
    float: 豪斯多夫距离。
    """
    # 将掩码转换为点集
    points1 = torch.nonzero(mask1).float()  # 获取非零元素的坐标
    points2 = torch.nonzero(mask2).float()  # 获取非零元素的坐标
    
    # 将点集转换为 NumPy 数组，因为 scipy 的 directed_hausdorff 需要 NumPy 数组
    points1_np = points1.cpu().numpy()
    points2_np = points2.cpu().numpy()
    
    # 计算双向豪斯多夫距离
    hd1 = directed_hausdorff(points1_np, points2_np)[0]
    hd2 = directed_hausdorff(points2_np, points1_np)[0]
    
    # 取最大值作为豪斯多夫距离
    hd = max(hd1, hd2)
    
    return hd

def compute_h95(mask1, mask2):
    """
    计算两个掩码之间的 H95 指标。
    
    参数:
    mask1 (torch.Tensor): 第一个掩码，形状为 (H, W) 或 (D, H, W)。
    mask2 (torch.Tensor): 第二个掩码，形状为 (H, W) 或 (D, H, W)。
    
    返回:
    float: H95 指标。
    """
    # 计算豪斯多夫距离
    hd = compute_hausdorff_distance(mask1, mask2)
    
    # 计算 95 百分位数
    h95 = np.percentile(hd, 95)
    
    return h95

In [2]:
# from tqdm import tqdm
# # 测试数据
# batch_size = 2
# num_classes = 4
# spatial_dims = (128, 128, 128)

# # 生成随机标签和预测张量
# labels = torch.randint(0, 2, (batch_size, num_classes, *spatial_dims)).float()  # 标签
# predictions = torch.randint(0, 2, (batch_size, num_classes, *spatial_dims)).float()  # 预测

# # 计算每个样本和每个类别的 H95 指标
# h95_results = torch.zeros(batch_size, num_classes)  # 存储 H95 结果

# for i in tqdm(range(batch_size)):  # 遍历每个样本
#     for j in tqdm(range(num_classes)):  # 遍历每个类别
#         label_mask = labels[i, j]  # 获取当前样本和类别的标签掩码
#         pred_mask = predictions[i, j]  # 获取当前样本和类别的预测掩码
        
#         # 计算 H95 指标
#         h95 = compute_h95(label_mask, pred_mask)
#         h95_results[i, j] = h95

# print("H95 Results for each sample and class:")
# print(h95_results)

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pred = torch.rand([2, 4, 128, 128, 128]).float().to(device=device)
mask = torch.randint(0, 4, [2, 128, 128, 128])
mask = F.one_hot(mask).permute(0, 4, 1, 2, 3).float().to(device=device)

pred.shape, mask.shape


(torch.Size([2, 4, 128, 128, 128]), torch.Size([2, 4, 128, 128, 128]))

In [4]:
# compute_h95(pred, mask)

In [11]:
distance = torch.cdist(pred.view(2, -1, 4), mask.reshape(2, -1, 4).shape)
# min_distance_pred_to_mask = torch.min(distance, dim=1)[0]


RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

In [16]:
mask.reshape(2, -1, 4).shape

torch.Size([2, 2097152, 4])

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Deformable3DConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        
        # 常规3D卷积用于特征提取
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding)
        
        # 偏移量预测分支（输出3*K^3个偏移量，对应三维坐标偏移）
        self.offset_conv = nn.Conv3d(in_channels, 3 * kernel_size**3, 
                                   kernel_size, stride, padding)
        
        # 初始化偏移量卷积权重为零
        nn.init.constant_(self.offset_conv.weight, 0)
        nn.init.constant_(self.offset_conv.bias, 0)

    def forward(self, x):
        # 生成偏移量场 [B, 3*K^3, D, H, W]
        offset = self.offset_conv(x)
        
        # 生成常规卷积的输出特征
        base_feature = self.conv(x)
        
        # 获取输入特征图尺寸
        B, C, D, H, W = x.shape
        K = self.kernel_size
        
        # 生成采样网格
        grid = self._get_grid(offset)
        
        # 可变形采样
        deformed_feature = F.grid_sample(
            x, 
            grid, 
            mode='bilinear',
            padding_mode='zeros',
            align_corners=False
        )
        
        # 重塑采样结果并进行卷积融合
        deformed_feature = deformed_feature.view(B, C, K, K, K, D, H, W)
        output = torch.einsum('bcklmnop,qcklmn->bqp', deformed_feature, self.conv.weight)
        
        return output + base_feature  # 残差连接

    def _get_grid(self, offset):
        B, _, D, H, W = offset.shape
        K = self.kernel_size
        
        # 生成基础网格坐标
        x_grid, y_grid, z_grid = torch.meshgrid(
            torch.linspace(-1, 1, K),
            torch.linspace(-1, 1, K),
            torch.linspace(-1, 1, K)
        )
        base_grid = torch.stack((z_grid, y_grid, x_grid), 3)  # [K, K, K, 3]
        base_grid = base_grid.unsqueeze(0).repeat(B, 1, 1, 1, 1)  # [B, K, K, K, 3]
        
        # 添加偏移量并归一化
        offset = offset.permute(0, 2, 3, 4, 1).view(B, D, H, W, K, K, K, 3)
        deformed_grid = base_grid + offset * 0.1  # 控制偏移量幅度
        
        return deformed_grid.reshape(B, D*K, H*K, W*K, 3)


In [6]:
# BraTS输入维度：(batch_size, 4, 128, 128, 128) 对应4个模态
input_tensor = torch.randn(1, 4, 128, 128, 128)
conv_layer = Deformable3DConv(in_channels=4, out_channels=64)
output = conv_layer(input_tensor)  # 输出维度：(batch_size, 64, 128, 128, 128)


RuntimeError: einsum(): the number of subscripts in the equation (6) does not match the number of dimensions (5) for operand 1 and no ellipsis was given

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Deformable3DConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        super().__init__()
        self.kernel_size = kernel_size
        self.padding = padding
        
        # 偏移量生成卷积层（输出通道数为3倍卷积核体积）
        self.offset_conv = nn.Conv3d(
            in_channels,
            3 * kernel_size**3,  # 每个采样点3个方向偏移
            kernel_size=kernel_size,
            padding=padding
        )
        
        # 常规卷积层
        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=padding
        )
        
    def _get_base_grid(self, x):
        """生成归一化基础网格坐标[-1, 1]"""
        batch_size, _, D, H, W = x.size()
        grid_d, grid_h, grid_w = torch.meshgrid(
            torch.linspace(-1, 1, D, device=x.device),
            torch.linspace(-1, 1, H, device=x.device),
            torch.linspace(-1, 1, W, device=x.device),
            indexing='ij'
        )
        grid = torch.stack((grid_w, grid_h, grid_d), dim=-1)  # [D, H, W, 3]
        return grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)  # [B, D, H, W, 3]

    def forward(self, x):
        # 生成偏移量 [B, 3*k^3, D, H, W]
        offsets = self.offset_conv(x)
        batch_size, _, D, H, W = offsets.size()
        
        # 调整偏移量形状 [B, D, H, W, k^3, 3]
        offsets = offsets.view(
            batch_size, 
            self.kernel_size**3, 
            3, 
            D, H, W
        ).permute(0, 3, 4, 5, 1, 2)
        
        # 生成基础网格 [B, D, H, W, 3]
        grid = self._get_base_grid(x)
        
        # 生成卷积核相对偏移量
        kernel_offset = torch.stack(
            torch.meshgrid(
                torch.linspace(-1, 1, self.kernel_size, device=x.device),
                torch.linspace(-1, 1, self.kernel_size, device=x.device),
                torch.linspace(-1, 1, self.kernel_size, device=x.device),
                indexing='ij'
            ), dim=-1
        ).view(-1, 3)  # [k^3, 3]
        
        # 应用相对偏移和可学习偏移
        grid = grid.unsqueeze(4) + kernel_offset.view(1, 1, 1, 1, -1, 3)  # [B, D, H, W, k^3, 3]
        grid = grid + offsets
        
        # 执行可变形采样
        deformed_x = F.grid_sample(
            input=x,
            grid=grid,
            mode='bilinear',
            padding_mode='zeros',
            align_corners=True
        )
        
        # 应用常规卷积
        return self.conv(deformed_x)


In [None]:
# BraTS输入维度：(batch_size, 4, 128, 128, 128) 对应4个模态
input_tensor = torch.randn(1, 4, 128, 128, 128)
conv_layer = Deformable3DConv(in_channels=4, out_channels=64)
output = conv_layer(input_tensor)  # 输出维度：(batch_size, 64, 128, 128, 128)

RuntimeError: grid_sampler(): expected 5D input and grid with same number of dimensions, but got input with sizes [1, 4, 128, 128, 128] and grid with sizes [1, 128, 128, 128, 27, 3]

: 