In [1]:
import torch
import torch.nn.functional as F


class MultiScaleLoss(torch.nn.Module):
    """Weighted sum of MSE and Spectral Angle Mapper (SAM) losses over several resolutions.

    For every factor in ``scales`` the loss term
    ``weight_mse * MSE + weight_sam * SAM`` is computed and the terms are
    summed (not averaged) across scales.
    """

    def __init__(self, scales=(1, 0.5, 0.25), weight_mse=1.0, weight_sam=1.0):
        """
        Multi-Scale Loss combining MSE and SAM.

        Args:
            scales (sequence of float): Resampling factors, one loss term per entry.
                Factors < 1 downsample both inputs with bilinear interpolation;
                factors >= 1 compare the inputs at their original resolution.
                Must be non-empty and strictly positive.
            weight_mse (float): Weight for the MSE loss component.
            weight_sam (float): Weight for the SAM loss component.

        Raises:
            ValueError: If ``scales`` is empty or contains a non-positive factor.
        """
        super().__init__()
        # Copy into a fresh list so instances never share (or mutate) the
        # caller's sequence or the default argument.
        scales = list(scales)
        if not scales:
            # With no scales, forward() would return the Python float 0.0,
            # which has no .item()/.backward().
            raise ValueError("scales must contain at least one factor")
        if any(scale <= 0 for scale in scales):
            raise ValueError(f"scales must be strictly positive, got {scales}")
        self.scales = scales
        self.weight_mse = weight_mse
        self.weight_sam = weight_sam

    def forward(self, y_true, y_pred):
        """
        Compute the combined loss.

        Args:
            y_true (torch.Tensor): Ground truth tensor of shape (N, C, H, W).
            y_pred (torch.Tensor): Predicted tensor of shape (N, C, H, W).

        Returns:
            torch.Tensor: Scalar tensor, the sum over scales of
            ``weight_mse * MSE + weight_sam * SAM``.
        """
        total_loss = 0.0

        for scale in self.scales:
            if scale < 1:
                # Bilinear resampling expects 4-D (N, C, H, W) inputs.
                y_true_scaled = F.interpolate(
                    y_true, scale_factor=scale, mode="bilinear", align_corners=False
                )
                y_pred_scaled = F.interpolate(
                    y_pred, scale_factor=scale, mode="bilinear", align_corners=False
                )
            else:
                y_true_scaled = y_true
                y_pred_scaled = y_pred

            # Pixel-wise intensity error.
            mse_loss = F.mse_loss(y_pred_scaled, y_true_scaled)

            # Per-pixel spectral angle between channel vectors.
            sam_loss = self.sam_loss(y_true_scaled, y_pred_scaled)

            total_loss = (
                total_loss + self.weight_mse * mse_loss + self.weight_sam * sam_loss
            )

        return total_loss

    @staticmethod
    def sam_loss(y_true, y_pred):
        """
        Spectral Angle Mapper (SAM) loss.

        The angle (in radians) between the channel vectors of ``y_true`` and
        ``y_pred`` is computed at every spatial location and averaged.

        Args:
            y_true (torch.Tensor): Ground truth tensor of shape (N, C, H, W).
            y_pred (torch.Tensor): Predicted tensor of shape (N, C, H, W).

        Returns:
            torch.Tensor: Scalar tensor, the mean spectral angle. Identical
            spectra give a tiny positive value (about sqrt(2 * eps) of the
            dtype, ~5e-4 for float32) because of the gradient-safe clamp below.
        """
        # reshape (not view) so non-contiguous inputs, e.g. transposed H/W, work.
        y_true_flat = y_true.reshape(y_true.size(0), y_true.size(1), -1)  # (N, C, H*W)
        y_pred_flat = y_pred.reshape(y_pred.size(0), y_pred.size(1), -1)  # (N, C, H*W)

        # Dot product and norms over the channel axis.
        dot_product = torch.sum(y_true_flat * y_pred_flat, dim=1)  # (N, H*W)
        norm_true = torch.norm(y_true_flat, dim=1) + 1e-8  # Avoid division by zero
        norm_pred = torch.norm(y_pred_flat, dim=1) + 1e-8

        cos_sim = dot_product / (norm_true * norm_pred)

        # d/dx acos(x) = -1 / sqrt(1 - x^2) is infinite at |x| = 1, so clamping
        # to [-1, 1] still yields inf/NaN gradients for perfectly aligned spectra.
        # Keep the argument strictly inside the domain (dtype-aware margin).
        bound = 1.0 - torch.finfo(cos_sim.dtype).eps
        sam = torch.acos(torch.clamp(cos_sim, -bound, bound))  # (N, H*W)
        return torch.mean(sam)  # Mean over all pixels

In [2]:
# Smoke test of MultiScaleLoss on random (N=4, C=10, H=128, W=128) tensors.
# A seeded local generator makes the printed loss reproducible without
# touching the global RNG state.
gen = torch.Generator().manual_seed(0)
y_true = torch.rand((4, 10, 128, 128), generator=gen)  # Example ground truth with 10 channels
y_pred = torch.rand((4, 10, 128, 128), generator=gen)  # Example prediction with 10 channels

loss_fn = MultiScaleLoss(scales=[1, 0.5, 0.25], weight_mse=1.0, weight_sam=1.0)

loss = loss_fn(y_true, y_pred)
assert torch.isfinite(loss), "loss should be a finite scalar"
print("Loss:", loss.item())

NameError: name 'loss_fn' is not defined

In [17]:
import torch

# Demo: split a stacked (B, T*C, H, W) tensor into (B, T, C, H, W) time steps.
# Input tensor (B=1, channels=18, H=2, W=2) holding the values 0..71.
x = torch.arange(72).view(1, 18, 2, 2)
print("原始张量:\n", x)

# Parameters
time_span = 3
assert x.size(1) % time_span == 0, "channel count must be divisible by time_span"
conv_input_channel = x.size(1) // time_span  # channels per time step (18 // 3 = 6)

# view splits the channel axis time-major: channel index = t * conv_input_channel + c.
test = x.view(x.size(0), time_span, conv_input_channel, *x.shape[2:])

# Compare the original and the split result
print("\n切割后的张量形状:", test.shape)
print("切割后的张量:\n", test)

# Check the content of each time step
for t in range(time_span):
    # Time step t is dim 1. (test[0, :, t] would instead pick channel t of
    # every time step.) It owns the channel block [t*C, (t+1)*C) of x.
    assert torch.equal(
        test[0, t], x[0, t * conv_input_channel:(t + 1) * conv_input_channel]
    )
    print(f"\n时间步 {t}: 通道数据")
    print(test[0, t])

原始张量:
 tensor([[[[ 0,  1],
          [ 2,  3]],

         [[ 4,  5],
          [ 6,  7]],

         [[ 8,  9],
          [10, 11]],

         [[12, 13],
          [14, 15]],

         [[16, 17],
          [18, 19]],

         [[20, 21],
          [22, 23]],

         [[24, 25],
          [26, 27]],

         [[28, 29],
          [30, 31]],

         [[32, 33],
          [34, 35]],

         [[36, 37],
          [38, 39]],

         [[40, 41],
          [42, 43]],

         [[44, 45],
          [46, 47]],

         [[48, 49],
          [50, 51]],

         [[52, 53],
          [54, 55]],

         [[56, 57],
          [58, 59]],

         [[60, 61],
          [62, 63]],

         [[64, 65],
          [66, 67]],

         [[68, 69],
          [70, 71]]]])

切割后的张量形状: torch.Size([1, 3, 6, 2, 2])
切割后的张量:
 tensor([[[[[ 0,  1],
           [ 2,  3]],

          [[ 4,  5],
           [ 6,  7]],

          [[ 8,  9],
           [10, 11]],

          [[12, 13],
           [14, 15]],

          [[

In [18]:
print(x.size())  # torch.Size of the stacked input, printed to stdout

torch.Size([1, 18, 2, 2])


In [19]:
test[:, 0]  # time step 0 of the split view: all its channels (trailing dims kept)

tensor([[[[ 0,  1],
          [ 2,  3]],

         [[ 4,  5],
          [ 6,  7]],

         [[ 8,  9],
          [10, 11]],

         [[12, 13],
          [14, 15]],

         [[16, 17],
          [18, 19]],

         [[20, 21],
          [22, 23]]]])

In [20]:
# Merge the (time_span, conv_input_channel) axes back into a single channel axis.
# NOTE: this rebinds `test` to a 4-D tensor, so re-running the 5-D indexing
# cell above (test[:, 0, :, :, :]) afterwards raises an IndexError.
test = test.view(1, conv_input_channel * time_span, 2, 2)
test[:, 2]  # merged channel 2 (= time step 0, channel 2 of the split view)

tensor([[[ 8,  9],
         [10, 11]]])