In [1]:
!pip install monai

Collecting monai
  Downloading monai-1.5.1-py3-none-any.whl.metadata (13 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.4.1->monai)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.4.1->monai)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=2.4.1->monai)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=2.4.1->monai)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=2.4.1->monai)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch>=2.4.1->monai)
  Downloading nvidia_cufft

In [2]:
import torch
import torch.nn as nn
from monai.networks.nets.swin_unetr import SwinTransformer
import torch.nn.functional as F
from typing import List, Optional, Tuple

2025-11-05 17:11:35.402812: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1762362695.576149      37 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1762362695.626166      37 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [13]:
class AttentionFaultFormerEncoder(nn.Module):
    """
    Wrapper around MONAI's ``SwinTransformer`` that returns exactly 3 cascaded stages.

    The stock patch-embedding projection ``encoder.patch_embed.proj`` is replaced by
    ``Conv3d(kernel_size=patch_kernel, stride=patch_stride, padding=patch_padding)``
    (5 / 2 / 2 by default), He-initialised for ReLU.

    The installed MONAI version must provide ``SwinTransformer`` with a
    ``patch_embed.proj`` attribute. If it does not, construction raises
    ``AttributeError``. Without this check, the assignment would silently add an
    unused attribute and the replacement convolution would never run.

    Args:
        in_chans: number of input channels.
        embed_dim: number of channels produced by the patch embedding.
        window_size, patch_size, depths, num_heads, mlp_ratio, qkv_bias, drop_rate,
        attn_drop_rate, drop_path_rate, norm_layer, patch_norm, use_checkpoint,
        spatial_dims: forwarded unchanged to MONAI ``SwinTransformer``.
        patch_kernel, patch_stride, patch_padding: geometry of the replacement
            patch-embedding convolution.

    Raises:
        AttributeError: if the MONAI backbone has no ``patch_embed.proj``.
    """
    def __init__(
        self,
        in_chans: int = 1,
        embed_dim: int = 48,
        window_size=(7,7,7),
        patch_size=(2,2,2),
        depths=(2,2,2,1),
        num_heads=(3,6,12,12),
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.1,
        norm_layer=nn.LayerNorm,
        patch_norm: bool = True,
        use_checkpoint: bool = True,
        spatial_dims: int = 3,
        patch_kernel: int = 5,
        patch_stride: int = 2,
        patch_padding: int = 2,
    ):
        super().__init__()

        # Standard MONAI SwinTransformer backbone.
        self.encoder = SwinTransformer(
            in_chans=in_chans,
            embed_dim=embed_dim,
            window_size=window_size,
            patch_size=patch_size,
            depths=depths,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            norm_layer=norm_layer,
            patch_norm=patch_norm,
            use_checkpoint=use_checkpoint,
            spatial_dims=spatial_dims,
        )

        # Fail loudly when the MONAI layout differs. Assigning to a missing attribute
        # would not error; it would just create a dead submodule.
        patch_embed = getattr(self.encoder, "patch_embed", None)
        if patch_embed is None or not hasattr(patch_embed, "proj"):
            raise AttributeError(
                "MONAI SwinTransformer has no 'patch_embed.proj' attribute; "
                "this MONAI version is not supported by AttentionFaultFormerEncoder."
            )

        # Replace the patch-embedding projection with the configurable strided conv.
        patch_embed.proj = nn.Conv3d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=patch_kernel,
            stride=patch_stride,
            padding=patch_padding,
            bias=False,
        )
        nn.init.kaiming_normal_(patch_embed.proj.weight, nonlinearity="relu")

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """
        Run the Swin backbone and return the features ``[stage1, stage2, stage3]``.

        These are elements 1, 2 and 3 of the list returned by MONAI's
        ``SwinTransformer``. The method uses three stages, matching the reference
        architecture.

        NOTE(review): element 0 and any element after index 3 are computed by the
        backbone but discarded here.
        """
        feats = self.encoder(x)
        return [feats[1], feats[2], feats[3]]


In [14]:
class ChannelAttention3D(nn.Module):
    """
    Channel-attention gate for 3D feature maps (the channel half of CBAM).

    Global average-pooled and global max-pooled descriptors each pass through a
    shared bottleneck of two 1x1x1 convolutions with a ReLU between them. The two
    results are summed and squashed with a sigmoid, giving one weight per channel.

    Args:
        channels: number of input (and output) channels.
        reduction_ratio: bottleneck reduction factor. The hidden width is
            ``channels // reduction_ratio``, but never less than 1.

    Returns (forward):
        For a 5D input ``(N, channels, D, H, W)``, a gate of shape
        ``(N, channels, 1, 1, 1)`` with values in (0, 1).
    """
    def __init__(self, channels: int, reduction_ratio: int = 4):
        super().__init__()
        bottleneck = channels // reduction_ratio
        if bottleneck < 1:
            bottleneck = 1
        # Convs are created squeeze-then-excite, the same order as before, so seeded
        # initialisation is unchanged.
        squeeze = nn.Conv3d(channels, bottleneck, kernel_size=1, bias=False)
        excite = nn.Conv3d(bottleneck, channels, kernel_size=1, bias=False)
        self.mlp = nn.Sequential(squeeze, nn.ReLU(inplace=True), excite)
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.max_pool = nn.AdaptiveMaxPool3d(1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        avg_logits, max_logits = (
            self.mlp(descriptor) for descriptor in (self.avg_pool(x), self.max_pool(x))
        )
        return self.sigmoid(avg_logits + max_logits)

In [15]:
class SpatialAttention3D(nn.Module):
    """
    Spatial-attention gate for 3D feature maps (the spatial half of CBAM).

    The channel-wise mean and channel-wise max are stacked into a 2-channel map.
    That map is convolved with one cubic ``kernel_size`` filter and passed through
    a sigmoid, giving one weight per voxel (output has 1 channel).

    Args:
        kernel_size: edge length of the cubic convolution. It must be a positive odd
            integer so that the symmetric padding ``(kernel_size - 1) // 2`` keeps
            the spatial size unchanged. An even size would shrink every spatial
            dimension by one voxel, so the map would no longer line up with the
            input it is meant to gate.

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer.
    """
    def __init__(self, kernel_size: int = 7):
        super().__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"SpatialAttention3D kernel_size must be a positive odd integer, got {kernel_size}"
            )
        padding = (kernel_size - 1) // 2
        self.conv = nn.Conv3d(2, 1, kernel_size=kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a ``(N, 1, *spatial)`` gate in (0, 1) computed from ``x``."""
        avg = torch.mean(x, dim=1, keepdim=True)
        max_ = torch.max(x, dim=1, keepdim=True)[0]
        cat = torch.cat([avg, max_], dim=1)
        return self.sigmoid(self.conv(cat))

In [16]:
class CBAM3D(nn.Module):
    """
    3D Convolutional Block Attention Module: channel gating followed by optional
    spatial gating.

    Args:
        channels: number of channels of the input feature map.
        reduction_ratio: bottleneck reduction used by the channel-attention MLP.
        spatial_kernel: convolution size used by the spatial-attention branch.
        apply_spatial: when False, ``forward`` skips the spatial gate. The
            spatial-attention submodule is still built, so its weights stay in
            ``state_dict``; it is simply not used.
            NOTE(review): those unused parameters may need
            ``find_unused_parameters=True`` under DistributedDataParallel.
    """
    def __init__(self, channels: int, reduction_ratio: int = 4, spatial_kernel: int = 7, apply_spatial: bool = True):
        super().__init__()
        self.channel_attn = ChannelAttention3D(channels, reduction_ratio)
        self.apply_spatial = apply_spatial
        self.spatial_attn = SpatialAttention3D(kernel_size=spatial_kernel)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        refined = x * self.channel_attn(x)
        if not self.apply_spatial:
            return refined
        return refined * self.spatial_attn(refined)

In [17]:
class AttentionSkipBlock(nn.Module):
    """
    Skip-connection refinement: Conv3d(5x5x5) -> CBAM -> BatchNorm3d -> ReLU.

    The 5x5x5 convolution uses stride 1 and padding 2. It therefore keeps the
    spatial size and only changes the channel count from ``in_channels`` to
    ``out_channels``.

    Args:
        in_channels: channels of the incoming skip feature.
        out_channels: channels after projection (and of the output).
        reduction_ratio: forwarded to CBAM's channel attention.
        apply_spatial: forwarded to CBAM; enables its spatial gate.
    """
    def __init__(self, in_channels: int, out_channels: int, reduction_ratio: int = 4, apply_spatial: bool = True):
        super().__init__()
        self.proj = nn.Conv3d(in_channels, out_channels, kernel_size=5, padding=2, bias=False)
        self.cbam = CBAM3D(out_channels, reduction_ratio=reduction_ratio, apply_spatial=apply_spatial)
        self.bn = nn.BatchNorm3d(out_channels)
        self.act = nn.ReLU(inplace=True)
        # He (fan-in) init for the projection, done last exactly as before so the
        # seeded random stream is consumed in the same order.
        nn.init.kaiming_normal_(self.proj.weight, mode="fan_in", nonlinearity="relu")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for stage in (self.proj, self.cbam, self.bn, self.act):
            x = stage(x)
        return x

In [18]:
class MASCA3D(nn.Module):
    """
    MASCA attention block for 3D feature maps.

    Steps:
      1. A 5x5x5 convolution produces ``f1``.
      2. Each length ``N`` in ``N_list`` adds a branch of three axis-aligned strip
         convolutions applied to ``f1``: (1,1,N) -> (1,N,1) -> (N,1,1), each
         followed by ReLU.
      3. ``f1`` and all branch outputs are summed and mixed by a 1x1x1 convolution.
      4. A sigmoid turns the result into a per-voxel, per-channel mask, and the
         input is multiplied by that mask.

    Args:
        channels: number of input/output channels (the block keeps it unchanged).
        N_list: strip lengths, one branch per entry. Each must be a positive odd
            integer so that padding ``(N - 1) // 2`` preserves the spatial size.
            An even length would give branch outputs that cannot be summed with
            ``f1``.

    Raises:
        ValueError: if any entry of ``N_list`` is not a positive odd integer.
    """
    def __init__(self, channels: int, N_list: Tuple[int, ...] = (7, 11)):
        super().__init__()
        # Materialise once so generators work and validation sees the same values.
        strip_lengths = list(N_list)
        invalid = [n for n in strip_lengths if n < 1 or n % 2 == 0]
        if invalid:
            raise ValueError(f"MASCA3D strip lengths must be positive odd integers, got {invalid}")

        self.f1_conv = nn.Conv3d(channels, channels, kernel_size=5, padding=2, bias=False)
        self.N_list = strip_lengths
        self.strip_blocks = nn.ModuleList()
        for N in self.N_list:
            pad = (N - 1) // 2
            block = nn.Sequential(
                nn.Conv3d(channels, channels, kernel_size=(1,1,N), padding=(0,0,pad), bias=False),
                nn.ReLU(inplace=True),
                nn.Conv3d(channels, channels, kernel_size=(1,N,1), padding=(0,pad,0), bias=False),
                nn.ReLU(inplace=True),
                nn.Conv3d(channels, channels, kernel_size=(N,1,1), padding=(pad,0,0), bias=False),
                nn.ReLU(inplace=True),
            )
            self.strip_blocks.append(block)
        self.out_conv = nn.Conv3d(channels, channels, kernel_size=1, bias=False)
        self.sigmoid = nn.Sigmoid()
        # He init for every convolution in the block, including f1_conv and out_conv.
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` gated element-wise by the learned mask (same shape as ``x``)."""
        f1 = self.f1_conv(x)
        summ = f1
        for block in self.strip_blocks:
            # The first op of each branch is a conv, so the in-place ReLUs never touch f1.
            summ = summ + block(f1)
        w = self.sigmoid(self.out_conv(summ))
        return x * w

In [19]:
class ResBlock3D(nn.Module):
    """
    Residual block: three convolutions with InstanceNorm and LeakyReLU.

    Main path: (conv -> InstanceNorm -> LeakyReLU) twice, then conv -> InstanceNorm.
    The input is added back, through a bias-free 1x1x1 projection when
    ``in_channels != out_channels``. The sum then goes through a final LeakyReLU.
    The spatial size is preserved.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        kernel_size: edge length of the three cubic convolutions. It must be a
            positive odd integer so that padding ``(kernel_size - 1) // 2`` keeps
            the spatial size. Otherwise the residual addition fails on a shape
            mismatch.
        negative_slope: slope of every LeakyReLU. It is also passed as ``a`` to the
            He initialisation of the three convolutions, so the init gain matches
            the activation actually used.

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer.
    """
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, negative_slope: float = 0.01):
        super().__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(f"ResBlock3D kernel_size must be a positive odd integer, got {kernel_size}")
        pad = (kernel_size - 1) // 2
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=pad, bias=False)
        self.in1 = nn.InstanceNorm3d(out_channels, affine=True)
        self.act1 = nn.LeakyReLU(negative_slope=negative_slope, inplace=True)

        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=kernel_size, padding=pad, bias=False)
        self.in2 = nn.InstanceNorm3d(out_channels, affine=True)
        self.act2 = nn.LeakyReLU(negative_slope=negative_slope, inplace=True)

        self.conv3 = nn.Conv3d(out_channels, out_channels, kernel_size=kernel_size, padding=pad, bias=False)
        self.in3 = nn.InstanceNorm3d(out_channels, affine=True)

        # 1x1x1 projection only when the residual needs a channel change.
        self.skip_proj = None
        if in_channels != out_channels:
            self.skip_proj = nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=False)
            nn.init.kaiming_normal_(self.skip_proj.weight, nonlinearity="linear")

        self.final_act = nn.LeakyReLU(negative_slope=negative_slope, inplace=True)
        for m in (self.conv1, self.conv2, self.conv3):
            # Pass the real slope: the default a=0 gives the plain-ReLU gain.
            nn.init.kaiming_normal_(m.weight, a=negative_slope, nonlinearity="leaky_relu")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x if self.skip_proj is None else self.skip_proj(x)
        out = self.act1(self.in1(self.conv1(x)))
        out = self.act2(self.in2(self.conv2(out)))
        out = self.in3(self.conv3(out))
        # `out + identity` is a fresh tensor, so the in-place activation is safe.
        return self.final_act(out + identity)

In [20]:
class AttentionUNetDecoderRes(nn.Module):
    """
    U-Net style decoder with attention-refined skips and residual fusion blocks.

    Decoding proceeds from ``feats[-1]`` towards ``feats[0]``:
      1. The last feature is projected to ``decoder_channels[0]`` channels
         (1x1x1 conv -> BatchNorm -> ReLU).
      2. For each earlier feature (the skip), the running tensor is upsampled to
         the skip's spatial size.
      3. The skip is refined by an ``AttentionSkipBlock``; MASCA, when enabled, is
         applied only to the last skip processed, ``feats[0]``.
      4. The running tensor and the refined skip are concatenated along channels
         and fused by a ``ResBlock3D``.

    Args:
        encoder_channels: channel count of each entry of ``feats``, in the same order.
        decoder_channels: decoder widths in decoding order (index 0 for the deepest
            projection); must have the same length as ``encoder_channels``.
        reduction_ratio: CBAM reduction ratio used by the skip blocks.
        use_deconv: if True, upsample with a learnable ``ConvTranspose3d``
            (kernel 2, stride 2), snapping to the skip size when the two differ.
            Otherwise use nearest-neighbour interpolation to the skip's size.
        use_masca: whether to build and apply ``MASCA3D`` on the last skip.
        masca_N_list: strip lengths forwarded to ``MASCA3D``.
        masca_on_shallow: stored on the instance but not read anywhere in this class.
            NOTE(review): intended semantics unclear — confirm before relying on it.
    """
    def __init__(self,
                 encoder_channels: List[int],
                 decoder_channels: List[int],
                 reduction_ratio: int = 4,
                 use_deconv: bool = False,
                 use_masca: bool = True,
                 masca_N_list: Tuple[int, ...] = (7, 11),
                 masca_on_shallow: bool = False):
        super().__init__()
        assert len(encoder_channels) == len(decoder_channels), (
            f"encoder_channels ({len(encoder_channels)}) and decoder_channels "
            f"({len(decoder_channels)}) must have the same length"
        )
        self.n_stages = len(encoder_channels)
        self.use_deconv = use_deconv
        self.masca_on_shallow = masca_on_shallow

        self.deep_proj = nn.Conv3d(encoder_channels[-1], decoder_channels[0], kernel_size=1, bias=False)
        self.deep_bn = nn.BatchNorm3d(decoder_channels[0])
        self.deep_act = nn.ReLU(inplace=True)

        # Stage i refines skip feats[n_stages - 2 - i] to decoder_channels[i + 1] channels.
        self.attn_skips = nn.ModuleList()
        for i in range(self.n_stages - 1):
            enc_idx = self.n_stages - 2 - i
            c_enc = encoder_channels[enc_idx]
            c_dec = decoder_channels[i + 1]
            self.attn_skips.append(
                AttentionSkipBlock(in_channels=c_enc, out_channels=c_dec,
                                   reduction_ratio=reduction_ratio)
            )

        if use_masca:
            self.masca = MASCA3D(decoder_channels[-1], N_list=masca_N_list)
        else:
            self.masca = None

        # The tensor entering stage i has decoder_channels[i] channels: the deep
        # projection's output for i == 0, otherwise the previous ResBlock's output.
        self.upsamplers = nn.ModuleList()
        for i in range(self.n_stages - 1):
            ch = decoder_channels[i]
            if use_deconv:
                self.upsamplers.append(nn.ConvTranspose3d(ch, ch, kernel_size=2, stride=2))
            else:
                self.upsamplers.append(nn.Identity())

        self.decode_resblocks = nn.ModuleList()
        for i in range(self.n_stages - 1):
            in_ch = decoder_channels[i] + decoder_channels[i + 1]
            out_ch = decoder_channels[i + 1]
            self.decode_resblocks.append(ResBlock3D(in_ch, out_ch))

    def _upsample(self, i: int, x: torch.Tensor, target_size) -> torch.Tensor:
        """Bring ``x`` to spatial size ``target_size`` using stage ``i``'s upsampler."""
        up = self.upsamplers[i]
        if isinstance(up, nn.ConvTranspose3d):
            x = up(x)
            # A k=2, s=2 transposed conv exactly doubles each dimension. When the skip
            # is not exactly twice as large (e.g. odd sizes), snap to the skip so the
            # decoder keeps following the encoder's resolution.
            if x.shape[2:] != target_size:
                x = F.interpolate(x, size=target_size, mode='nearest')
            return x
        return F.interpolate(x, size=target_size, mode='nearest')

    def forward(self, feats: List[torch.Tensor]) -> torch.Tensor:
        """Decode ``feats`` and return a tensor with ``decoder_channels[-1]`` channels."""
        assert len(feats) == self.n_stages, f"Ожидал {self.n_stages} фич, получил {len(feats)}"
        x = feats[-1]
        x = self.deep_proj(x); x = self.deep_bn(x); x = self.deep_act(x)

        for i in range(self.n_stages - 1):
            skip = feats[self.n_stages - 2 - i]
            x = self._upsample(i, x, skip.shape[2:])
            skip_proj = self.attn_skips[i](skip)
            if (self.masca is not None) and (i == self.n_stages - 2):
                skip_proj = self.masca(skip_proj)
            # Defensive: only triggers if a skip block ever changes the spatial size.
            if x.shape[2:] != skip_proj.shape[2:]:
                skip_proj = F.interpolate(skip_proj, size=x.shape[2:], mode='trilinear', align_corners=False)
            x = torch.cat([x, skip_proj], dim=1)
            x = self.decode_resblocks[i](x)
        return x

In [21]:
class AttentionFaultFormerNet(nn.Module):
    """
    Encoder + attention U-Net decoder + residual head for voxel-wise prediction.

    The decoder channel widths are the encoder feature widths reversed. The head is
    a ``ResBlock3D`` followed by a 1x1x1 convolution to ``num_classes`` channels.

    Construction of the decoder and head:
      * Eager: if ``encoder_channels`` is given, they are built in ``__init__``.
        ``model.parameters()`` is then complete before any forward, so parameter
        counts and optimizers see every weight.
      * Lazy: otherwise they are built on the first forward pass from the channel
        dimension of the encoder's features.

    Lazily built modules are:
      * placed on the device of the incoming features;
      * cast to the floating dtype of the encoder's parameters, if it has any;
      * switched to this model's current train/eval mode.

    Args:
        encoder_callable: ``nn.Module`` or any callable mapping the input to a list
            of feature tensors with channels on dim 1.
        num_classes: number of output channels.
        use_deconv: forwarded to ``AttentionUNetDecoderRes``.
        final_sigmoid: if True, ``forward`` returns ``sigmoid(logits)`` instead of
            raw logits. NOTE(review): pair with a probability-based loss (e.g.
            ``BCELoss``), not a with-logits loss — confirm against training code.
        encoder_channels: optional channel counts of the encoder features, in the
            encoder's output order; enables eager construction.
    """

    def __init__(self, encoder_callable, num_classes: int = 1, use_deconv: bool = False,
                 final_sigmoid: bool = True, encoder_channels: Optional[List[int]] = None):
        super().__init__()
        # encoder_callable may be an nn.Module or any callable returning list[Tensor].
        self.encoder = encoder_callable
        self.num_classes = num_classes
        self.use_deconv = use_deconv
        self.final_sigmoid = final_sigmoid
        self.decoder = None
        self.final_resblock = None
        self.final_conv = None
        if encoder_channels is not None:
            self._build_head(list(encoder_channels))

    def _build_head(self, enc_ch: List[int]):
        """Create the decoder, final residual block and 1x1x1 classifier for ``enc_ch``."""
        dec_ch = enc_ch[::-1]
        self.decoder = AttentionUNetDecoderRes(encoder_channels=enc_ch, decoder_channels=dec_ch, use_deconv=self.use_deconv)
        self.final_resblock = ResBlock3D(dec_ch[-1], dec_ch[-1])
        self.final_conv = nn.Conv3d(dec_ch[-1], self.num_classes, kernel_size=1, bias=True)

    def _encoder_param_dtype(self):
        """Floating dtype of the encoder's first floating parameter, or None if unknown."""
        if isinstance(self.encoder, nn.Module):
            for p in self.encoder.parameters():
                if p.is_floating_point():
                    return p.dtype
        return None

    def _lazy_init(self, feats: List[torch.Tensor]):
        """Build the head from ``feats`` and align it with the already-running model."""
        if len(feats) == 0:
            raise ValueError("encoder returned no feature maps")
        self._build_head([f.shape[1] for f in feats])
        # Fresh modules live on the CPU in the default dtype and in train mode. Match
        # the encoder's device/dtype and this model's mode, e.g. when model.cuda(),
        # model.double() or model.eval() were called before the first forward.
        target = {"device": feats[0].device}
        dtype = self._encoder_param_dtype()
        if dtype is not None:
            target["dtype"] = dtype
        for module in (self.decoder, self.final_resblock, self.final_conv):
            module.to(**target)
            module.train(self.training)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-voxel outputs (probabilities if ``final_sigmoid`` else logits)."""
        feats = self.encoder(x)
        if self.decoder is None:
            self._lazy_init(feats)
        out = self.final_resblock(self.decoder(feats))
        logits = self.final_conv(out)
        return torch.sigmoid(logits) if self.final_sigmoid else logits

In [22]:
encoder = AttentionFaultFormerEncoder(in_chans=1, embed_dim=48, use_checkpoint=False)
model = AttentionFaultFormerNet(encoder_callable=encoder, num_classes=1, use_deconv=False, final_sigmoid=True)

# AttentionFaultFormerNet builds its decoder/head lazily from the encoder's feature
# shapes. Until that happens, model.parameters() (parameter counts, optimizers)
# silently misses every decoder weight, so materialise it now. Only the encoder is
# run, so the decoder's BatchNorm running statistics are not touched by this pass.
DUMMY_INPUT_SHAPE = (1, 1, 32, 32, 32)  # assumed to be a valid input size for this Swin encoder — adjust if needed
with torch.no_grad():
    dummy_feats = encoder(torch.zeros(DUMMY_INPUT_SHAPE))
model._lazy_init(dummy_feats)
del dummy_feats

n_params = sum(p.numel() for p in model.parameters())
n_params