In [1]:
!pip install monai

Collecting monai
  Downloading monai-1.5.1-py3-none-any.whl.metadata (13 kB)
Downloading monai-1.5.1-py3-none-any.whl (2.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.7/2.7 MB[0m [31m46.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: monai
Successfully installed monai-1.5.1


In [1]:
import torch
import torch.nn as nn
from monai.networks.nets.swin_unetr import SwinTransformer
import torch.nn.functional as F
from typing import List, Optional, Tuple
import os
# Synchronous CUDA kernel launches: device-side errors surface at the call
# that caused them, at the cost of throughput. Debug setting — consider
# removing it for real training runs.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# CUDA caching-allocator configuration. Previously 'max_split_size_mb:128' was
# assigned first and immediately overwritten, so only this value ever applied.
# Alternative worth trying if fragmentation persists: 'max_split_size_mb:128'.
# NOTE(review): presumably must be set before the first CUDA allocation — confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"




In [2]:
class AttentionFaultFormerEncoder(nn.Module):
    """
    Wrapper around MONAI ``SwinTransformer`` that returns exactly 4 cascaded stages:
      [C x D/2 x H/2 x W/2, 2C x D/4 x H/4 x W/4, 4C x D/8 x H/8 x W/8, 8C x D/16 x H/16 x W/16]
    where C = ``embed_dim``.

    The backbone's ``encoder.patch_embed.proj`` is replaced by an overlapping
    ``Conv3d(kernel=patch_kernel, stride=patch_stride, padding=patch_padding)``
    (defaults 5 / 2 / 2).

    Configurations considered:
      1) embed_dim = 96,  heads = [3, 6, 12]
      2) embed_dim = 192, heads = [3, 6, 12] or [6, 12, 24]
      3) embed_dim = 256, heads = [4, 8, 16]
    """
    def __init__(
        self,
        in_chans: int = 1,
        embed_dim: int = 48,
        window_size=(7,7,7),
        patch_size=(2,2,2),   # <- important: patch_size is usually (2,2,2) in Swin
        depths=(2,2,2,1),     # MONAI needs >= 4 stages, so 4 are kept
        num_heads=(3,6,12,12),
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.1,
        norm_layer=nn.LayerNorm,
        patch_norm: bool = True,
        use_checkpoint: bool = True,
        spatial_dims: int = 3,
        patch_kernel: int = 5,
        patch_stride: int = 2,
        patch_padding: int = 2,
    ):
        super().__init__()
        if SwinTransformer is None:
            raise ImportError("MONAI SwinTransformer не найден. Установите monai и убедитесь, что swin_unetr доступен.")

        # Build the 4-stage MONAI SwinTransformer backbone.
        self.encoder = SwinTransformer(
            in_chans=in_chans,
            embed_dim=embed_dim,
            window_size=window_size,
            patch_size=patch_size,
            depths=depths,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            norm_layer=norm_layer,
            patch_norm=patch_norm,
            use_checkpoint=use_checkpoint,
            spatial_dims=spatial_dims,
        )

        # Replace the inner projection (PatchEmbed.proj) with an overlapping conv
        # (paper setting: kernel=5, stride=2, padding=2).
        self.encoder.patch_embed.proj = nn.Conv3d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=patch_kernel,
            stride=patch_stride,
            padding=patch_padding,
            bias=False,
        )
        nn.init.kaiming_normal_(self.encoder.patch_embed.proj.weight, nonlinearity="relu")

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """
        Run the backbone once and return 4 feature maps.

        - feat0: raw output of ``patch_embed.proj`` (no norm), (B, embed_dim, D', H', W')
        - feat1..feat3: ``feats[1]``..``feats[3]`` returned by the SwinTransformer.

        The stem conv output is captured with a forward hook during the single
        backbone pass instead of calling ``patch_embed.proj`` a second time.
        This way the full-resolution conv and the activation autograd keeps for
        it are computed and stored once, not twice.
        NOTE(review): MONAI's PatchEmbed may pad the input to a multiple of
        ``patch_size`` before ``proj``. For the default 5/2/2 conv this yields the
        same tensor as ``proj(x)`` — confirm if other patch settings are used.

        Raises:
            RuntimeError: if the backbone does not return a list/tuple of at
                least 4 feature maps.
        """
        captured: List[torch.Tensor] = []

        def _capture_proj_output(_module, _inputs, output):
            captured.append(output)

        proj = self.encoder.patch_embed.proj
        handle = proj.register_forward_hook(_capture_proj_output)
        try:
            feats = self.encoder(x)
        finally:
            # Always detach the hook so it never leaks into later calls.
            handle.remove()

        if not isinstance(feats, (list, tuple)) or len(feats) < 4:
            raise RuntimeError("SwinTransformer вернул неожидаемый формат/количество фичей")

        if len(captured) == 1:
            feat0_spatial = captured[0]
        else:
            # Fallback: the backbone did not invoke proj exactly once through
            # __call__, so compute it explicitly as before.
            feat0_spatial = proj(x)

        # Assumes feats[i] are (B, C_i, D_i, H_i, W_i) — TODO confirm for the MONAI version in use.
        return [feat0_spatial, feats[1], feats[2], feats[3]]


In [3]:
class ChannelAttention3D(nn.Module):
    """
    Channel attention for 3D feature maps (the channel half of CBAM).

    A shared MLP (Linear -> ReLU -> Linear) is applied to the global-average
    and global-max pooled channel descriptors. The two results are summed and
    passed through a sigmoid.

    Args:
        channels: number of input channels.
        reduction_ratio: hidden width is ``max(1, channels // reduction_ratio)``.
            The paper does not state it; 16 is the default (8 or 32 are other
            common choices).

    ``forward`` expects a 5D (b, c, d, h, w) tensor. It returns channel weights
    of shape (b, c, 1, 1, 1) in (0, 1).
    """

    def __init__(self, channels: int, reduction_ratio: int = 16):
        super().__init__()
        bottleneck = max(1, channels // reduction_ratio)
        # Shared MLP; Linear layers (with bias) act on the flattened (b, c) descriptors.
        self.mlp = nn.Sequential(
            nn.Linear(channels, bottleneck),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, channels),
        )
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.max_pool = nn.AdaptiveMaxPool3d(1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Unpacking exactly five dims enforces a 5D input.
        batch, chans, _, _, _ = x.shape
        pooled_avg = self.avg_pool(x).view(batch, chans)
        pooled_max = self.max_pool(x).view(batch, chans)
        weights = self.sigmoid(self.mlp(pooled_avg) + self.mlp(pooled_max))
        return weights.view(batch, chans, 1, 1, 1)

In [4]:
class SpatialAttention3D(nn.Module):
    """
    Spatial attention for 3D feature maps (the spatial half of CBAM).

    The channel-wise mean and max maps are stacked into a 2-channel volume and
    passed through a single ``Conv3d(2 -> 1)`` and a sigmoid.

    Args:
        kernel_size: cubic conv kernel size. Must be a positive odd integer so
            that the padding ``(kernel_size - 1) // 2`` keeps the spatial size.
            An even kernel would shrink the mask by one voxel per axis and
            break ``x * mask``.

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer.
    """
    def __init__(self, kernel_size: int = 7):
        super().__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size}")
        padding = (kernel_size - 1) // 2
        self.conv = nn.Conv3d(2, 1, kernel_size=kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a (B, 1, D, H, W) mask in (0, 1) for a (B, C, D, H, W) input."""
        avg = torch.mean(x, dim=1, keepdim=True)
        max_ = torch.max(x, dim=1, keepdim=True).values
        cat = torch.cat([avg, max_], dim=1)
        return self.sigmoid(self.conv(cat))

In [5]:
class CBAM3D(nn.Module):
    """
    3D CBAM block: channel attention first, then spatial attention.

    Args:
        channels: number of input channels, forwarded to ``ChannelAttention3D``.
        reduction_ratio: MLP reduction ratio of the channel attention.
        spatial_kernel: conv kernel size of ``SpatialAttention3D``.

    ``forward`` multiplies the input by the channel weights. It then multiplies
    that result by the spatial mask computed from it.
    """
    def __init__(self, channels: int, reduction_ratio: int = 16, spatial_kernel: int = 7):
        super().__init__()
        self.channel_attn = ChannelAttention3D(channels, reduction_ratio)
        self.spatial_attn = SpatialAttention3D(kernel_size=spatial_kernel)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        channel_refined = x * self.channel_attn(x)
        return channel_refined * self.spatial_attn(channel_refined)

In [6]:
class AttentionSkipBlock(nn.Module):
    """
    Skip-connection refinement: Conv3d(5x5x5, stride 1, padding 2, no bias)
    -> CBAM3D -> BatchNorm3d -> optional LeakyReLU.

    The paper leaves the conv parameters and the final activation unspecified.
    Here the activation is a LeakyReLU and is disabled by default.

    Args:
        in_channels: expected channel count of the input; checked in ``forward``.
        out_channels: channels produced by the conv (and kept through CBAM/BN).
        reduction_ratio: forwarded to CBAM3D.
        activation: apply the LeakyReLU after BatchNorm when True.
        negative_slope: LeakyReLU negative slope.

    Raises (forward):
        RuntimeError: if the input channel count differs from ``in_channels``.
    """
    def __init__(self, in_channels: int, out_channels: int, reduction_ratio: int = 16, activation: bool = False, negative_slope=0.01):
        super().__init__()
        self.in_channels = in_channels
        self.proj = nn.Conv3d(in_channels, out_channels, kernel_size=5, stride=1, padding=2, bias=False)
        self.cbam = CBAM3D(out_channels, reduction_ratio=reduction_ratio)
        self.bn = nn.BatchNorm3d(out_channels)
        self.act = nn.LeakyReLU(negative_slope=negative_slope, inplace=True)
        self.activation = activation
        # Re-initialise the conv after all submodules exist (same order as before).
        nn.init.kaiming_normal_(self.proj.weight, nonlinearity="relu")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        got = x.shape[1]
        if got != self.in_channels:
            raise RuntimeError(
                f"AttentionSkipBlock expected input with {self.in_channels} channels, "
                f"but got {got} channels. This indicates mismatch between encoder "
                f"feature channels and decoder skip-block configuration."
            )
        normed = self.bn(self.cbam(self.proj(x)))
        return self.act(normed) if self.activation else normed

In [7]:
class MASCA3D(nn.Module):
    """
    MASCA: multi-scale strip-convolution attention producing a voxel-wise gate.

    All convs are bias-free and keep the channel count:
      f1     = ReLU(Conv3d 5x5x5)(x)
      branch = [Conv(1,1,N) -> ReLU -> Conv(1,N,1) -> ReLU -> Conv(N,1,1) -> ReLU](f1),
               one branch per strip length N in ``N_list``
      gate   = sigmoid(Conv3d 1x1x1(f1 + sum(branches)))
      out    = x * gate

    Args:
        channels: number of input/output channels.
        N_list: strip lengths, one branch each (at least two). The first two
            branches are stored as ``f2_triple_conv_block`` and
            ``f3_triple_conv_block``. Any further ones go into
            ``extra_triple_conv_blocks``, which is empty (no state-dict keys)
            for the default ``(7, 11)``.

    Raises:
        ValueError: if ``N_list`` has fewer than two entries, or any entry is
            not a positive odd integer. An even length would change the spatial
            size under the symmetric padding ``(N - 1) // 2``.
    """
    def __init__(self, channels: int, N_list: Tuple[int, ...] = (7, 11)):
        super().__init__()
        strip_lengths = tuple(N_list)
        if len(strip_lengths) < 2:
            raise ValueError(f"N_list must contain at least two strip lengths, got {strip_lengths}")
        for n in strip_lengths:
            if n < 1 or n % 2 == 0:
                raise ValueError(f"every strip length must be a positive odd integer, got {n}")

        self.first_common_block = nn.Sequential(
            nn.Conv3d(channels, channels, kernel_size=5, padding=2, bias=False),
            nn.ReLU(inplace=True),
        )

        def _get_triple_conv_block(N):
            # Separable strip convs along W, then H, then D, each followed by ReLU.
            pad = (N - 1) // 2
            return nn.Sequential(
                nn.Conv3d(channels, channels, kernel_size=(1,1,N), padding=(0,0,pad), bias=False),
                nn.ReLU(inplace=True),
                nn.Conv3d(channels, channels, kernel_size=(1,N,1), padding=(0,pad,0), bias=False),
                nn.ReLU(inplace=True),
                nn.Conv3d(channels, channels, kernel_size=(N,1,1), padding=(pad,0,0), bias=False),
                nn.ReLU(inplace=True),
            )

        self.f2_triple_conv_block = _get_triple_conv_block(strip_lengths[0])
        self.f3_triple_conv_block = _get_triple_conv_block(strip_lengths[1])
        self.extra_triple_conv_blocks = nn.ModuleList(
            _get_triple_conv_block(n) for n in strip_lengths[2:]
        )

        self.out_conv = nn.Conv3d(channels, channels, kernel_size=1, bias=False)
        self.sigmoid = nn.Sigmoid()
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        f1 = self.first_common_block(x)
        sum_f = f1 + self.f2_triple_conv_block(f1) + self.f3_triple_conv_block(f1)
        for block in self.extra_triple_conv_blocks:
            sum_f = sum_f + block(f1)
        gate = self.sigmoid(self.out_conv(sum_f))
        return x * gate

In [8]:
# ResBlock as in the paper: in_channels == out_channels
class ResBlock3D(nn.Module):
    """Residual block with equal input and output channels.

    body = 3 x [Conv3d(kernel_size, same padding, no bias) -> InstanceNorm3d(affine)],
           with a LeakyReLU after the first two stages;
    out  = LeakyReLU(body(x) + x).

    Args:
        channels: input and output channel count (must match for the residual add).
        kernel_size: conv kernel size; padding is ``(kernel_size - 1) // 2``.
        negative_slope: LeakyReLU negative slope.
    """
    def __init__(self, channels: int, kernel_size: int = 3, negative_slope: float = 0.01):
        super().__init__()
        padding = (kernel_size - 1) // 2
        layers = []
        for stage in range(3):
            layers.append(nn.Conv3d(channels, channels, kernel_size=kernel_size, padding=padding, bias=False))
            layers.append(nn.InstanceNorm3d(channels, affine=True))
            # The last stage has no activation: it is applied after the residual add.
            if stage < 2:
                layers.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=True))
        self.body = nn.Sequential(*layers)
        self.final_act = nn.LeakyReLU(negative_slope=negative_slope, inplace=True)
        self.apply(self._init_module)

    @staticmethod
    def _init_module(m):
        """Kaiming-normal conv weights (leaky_relu, a=0.01); unit scale / zero shift for norm layers."""
        if isinstance(m, nn.Conv3d):
            # a=0.01 is fixed here regardless of the block's negative_slope.
            nn.init.kaiming_normal_(m.weight, a=0.01, nonlinearity='leaky_relu')
            if getattr(m, "bias", None) is not None:
                nn.init.zeros_(m.bias)
            return
        if not isinstance(m, (nn.BatchNorm3d, nn.InstanceNorm3d)):
            return
        if getattr(m, "weight", None) is not None:
            nn.init.ones_(m.weight)
        if getattr(m, "bias", None) is not None:
            nn.init.zeros_(m.bias)

    def forward(self, x):
        return self.final_act(self.body(x) + x)

In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn

class AttentionDecoder(nn.Module):
    """
    Three-stage decoder with attention-refined skip connections.

    Expects ``feats = (feat1, feat2, feat3, bottom)`` with channels
    (C, 2C, 4C, 8C), each deeper map at (roughly) half the spatial size of the
    previous one. At each stage the deeper map is upsampled (nearest) to the
    skip's exact spatial size and concatenated with the refined skip. The result
    is projected by a 1x1x1 conv (+ InstanceNorm + LeakyReLU) and passed
    through a ResBlock3D. A final x2 nearest upsample and a 1x1x1 conv produce
    the output.
    """
    def __init__(self, C, use_masca: bool = True, final_sigmoid: bool = False, out_channels: int = 1):
        """
        Args:
            C: base channel count (channels of ``feat1``).
            use_masca: if True, the finest skip is refined by MASCA3D after its
                AttentionSkipBlock; if False only the AttentionSkipBlock is used.
            final_sigmoid: return ``sigmoid(logits)`` instead of logits.
            out_channels: output channels of the final 1x1x1 conv.
        """
        super().__init__()
        self.final_sigmoid = final_sigmoid
        self.use_masca = use_masca

        # Attention skip blocks (+ optional MASCA on the finest level).
        skip1_layers = [AttentionSkipBlock(C, out_channels=C)]
        if use_masca:
            skip1_layers.append(MASCA3D(C))
        self.skip1 = nn.Sequential(*skip1_layers)
        self.skip2 = AttentionSkipBlock(2*C, out_channels=2*C)
        self.skip3 = AttentionSkipBlock(4*C, out_channels=4*C)

        # 1x1x1 projections after concat, reducing channels before each ResBlock.
        self.proj_12_to_4 = nn.Sequential(
            nn.Conv3d(12*C, 4*C, kernel_size=1, bias=False),
            nn.InstanceNorm3d(4*C, affine=True),
            nn.LeakyReLU(0.01, inplace=True)
        )
        self.proj_6_to_2 = nn.Sequential(
            nn.Conv3d(6*C, 2*C, kernel_size=1, bias=False),
            nn.InstanceNorm3d(2*C, affine=True),
            nn.LeakyReLU(0.01, inplace=True)
        )
        self.proj_3_to_1 = nn.Sequential(
            nn.Conv3d(3*C, C, kernel_size=1, bias=False),
            nn.InstanceNorm3d(C, affine=True),
            nn.LeakyReLU(0.01, inplace=True)
        )

        # ResBlocks keep their channel count so the residual add is valid.
        self.dec1 = ResBlock3D(channels=4*C)
        self.dec2 = ResBlock3D(channels=2*C)
        self.dec3 = ResBlock3D(channels=C)

        # Output head.
        self.final_conv = nn.Conv3d(C, out_channels, kernel_size=1)

    @staticmethod
    def _decode_stage(deep, skip, proj, block):
        """Upsample ``deep`` to ``skip``'s spatial size, concat on channels, project, refine."""
        # Upsampling to the exact skip size is identical to scale_factor=2 when the
        # sizes are exactly double. It also avoids a torch.cat shape error when an
        # encoder stage has an odd size.
        up = F.interpolate(deep, size=skip.shape[2:], mode='nearest')
        return block(proj(torch.cat([up, skip], dim=1)))

    def forward(self, feats):
        feat1, feat2, feat3, bottom = feats  # (B,C,...), (B,2C,...), (B,4C,...), (B,8C,...)

        out1 = self._decode_stage(bottom, self.skip3(feat3), self.proj_12_to_4, self.dec1)  # (B,4C,...)
        out2 = self._decode_stage(out1, self.skip2(feat2), self.proj_6_to_2, self.dec2)     # (B,2C,...)
        out3 = self._decode_stage(out2, self.skip1(feat1), self.proj_3_to_1, self.dec3)     # (B,C,...)

        # Restore the input resolution (the stem conv halves it). The paper does
        # not state this explicitly.
        out4 = F.interpolate(out3, scale_factor=2, mode='nearest')

        logits = self.final_conv(out4)
        return torch.sigmoid(logits) if self.final_sigmoid else logits


In [10]:
class AttentionFaultFormerNet(nn.Module):
    """
    Encoder + AttentionDecoder segmentation network.

    Args:
        encoder_callable: nn.Module or any callable mapping a (B, C_in, D, H, W)
            tensor to 4 feature maps with channels [C, 2C, 4C, 8C].
        num_classes: output channels of the final 1x1x1 conv (must be >= 1).
        use_deconv: kept for API compatibility; currently unused.
        final_sigmoid: forwarded to the decoder (sigmoid on the output).
        base_channels: C. If None, it is inferred from
            ``encoder_callable.encoder.patch_embed.proj.out_channels`` when that
            path exists (as in AttentionFaultFormerEncoder). Otherwise the
            decoder is built lazily on the first forward pass.

    Building the decoder in ``__init__`` matters. With lazy construction,
    ``parameters()`` / ``state_dict()`` taken before the first forward (e.g. to
    create an optimizer or load a checkpoint) would not contain the decoder.
    """

    def __init__(self, encoder_callable, num_classes: int = 1, use_deconv: bool = False,
                 final_sigmoid: bool = False, base_channels: Optional[int] = None):
        super().__init__()
        if num_classes < 1:
            raise ValueError(f"num_classes must be >= 1, got {num_classes}")
        self.encoder = encoder_callable
        self.num_classes = num_classes
        self.use_deconv = use_deconv
        self.final_sigmoid = final_sigmoid
        self.decoder = None
        self.final_resblock = None
        self.final_conv = None
        self._decoder_initialized = False
        self._base_channels = None

        C = base_channels if base_channels is not None else self._infer_base_channels(encoder_callable)
        if C is not None:
            self._build_decoder(C)

    @staticmethod
    def _infer_base_channels(encoder) -> Optional[int]:
        """Return ``encoder.encoder.patch_embed.proj.out_channels`` if that Conv3d exists, else None."""
        proj = getattr(getattr(getattr(encoder, "encoder", None), "patch_embed", None), "proj", None)
        if isinstance(proj, nn.Conv3d):
            return proj.out_channels
        return None

    def _build_decoder(self, C: int) -> None:
        """Create the decoder for base width C, honouring final_sigmoid and num_classes."""
        decoder = AttentionDecoder(C, final_sigmoid=self.final_sigmoid)
        if self.num_classes != 1:
            # AttentionDecoder's default head is a single-channel 1x1x1 conv over C channels.
            decoder.final_conv = nn.Conv3d(C, self.num_classes, kernel_size=1)
        self.decoder = decoder
        self._base_channels = C
        self._decoder_initialized = True

    def _lazy_init(self, C, device):
        """Fallback: build the decoder on first use and move the whole model to ``device``."""
        self._build_decoder(C)
        self.to(device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert isinstance(x, torch.Tensor), f"expected a torch.Tensor, got {type(x).__name__}"
        assert x.ndim == 5, f"expected a 5D (B, C, D, H, W) input, got shape {tuple(x.shape)}"
        feats = self.encoder(x)

        # Fail early with a clear message if the encoder does not produce [C, 2C, 4C, 8C].
        C = feats[0].shape[1]
        got = [f.shape[1] for f in feats]
        expected = [C, 2 * C, 4 * C, 8 * C]
        if got != expected:
            raise RuntimeError(
                f"Encoder returned feature channels {got}; expected {expected} (4 maps)."
            )
        if self._base_channels is not None and C != self._base_channels:
            raise RuntimeError(
                f"Encoder base width {C} does not match the decoder built for {self._base_channels}."
            )

        if self.decoder is None:
            self._lazy_init(C, feats[0].device)

        return self.decoder(feats)


In [None]:
# Encoder configuration. Per the author's note, the MONAI Swin backbone needs
# 4 stages, so `depths` and `num_heads` both carry exactly 4 entries.
ENCODER_KWARGS = dict(
    in_chans=1,
    embed_dim=48,
    depths=(2, 2, 2, 1),
    num_heads=(3, 6, 12, 12),
    use_checkpoint=False,
)
enc = AttentionFaultFormerEncoder(**ENCODER_KWARGS)

model = AttentionFaultFormerNet(enc, num_classes=1, final_sigmoid=False)