In [1]:
import numpy as np
import pandas as pd
import torch 
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import sys
import os
from encoders.base import EncoderBase
from monai.networks.nets import SegResNet
from typing import Dict, List, Optional


In [2]:
from building_blocks import UAFS
from building_blocks import MAFS


In [3]:
class MonaiSegResNetEncoder(EncoderBase):
    """
    Wraps MONAI's SegResNet to function as a pure encoder.
    
    Why SegResNet?
    - Uses GroupNorm by default (better for small batch sizes in 3D).
    - Highly optimized for GPU memory.
    - Native .encode() method returns hierarchical features.
    """
    def __init__(
        self, 
        in_channels: int = 1, 
        feature_channels: List[int] = [16, 32, 64, 128], 
        spatial_dims: int = 3
    ):
        # SegResNet expects an `init_filters` arg (the first stage width)
        # and `dropout_prob` etc.
        # We assume feature_channels follows a doubling pattern like [16, 32, 64, 128]
        
        super().__init__(in_channels, feature_channels)
        
        self.net = SegResNet(
            spatial_dims=spatial_dims,
            init_filters=feature_channels[0], # e.g. 16
            in_channels=in_channels,
            out_channels=1, # Dummy value, we won't use the decoder
            dropout_prob=0.2,
            # blocks_down defines how many ResNet blocks per stage. 
            # [1, 2, 2, 4] is a common default configuration.
            blocks_down=(1, 2, 2, 4), 
        )

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        SegResNet.encode(x) returns:
          - x_final (Tensor): The bottleneck feature (lowest resolution)
          - down_x (List[Tensor]): A list of intermediate features (high to low res)
                                   BUT usually in reverse order of generation? 
                                   Let's verify the standard behavior.
        """
        # MONAI SegResNet encode returns:
        # x (bottleneck), layers (list of skip connections)
        bottleneck, skips = self.net.encode(x)
        
        # skips contains features from [Resolution 1, Resolution 1/2, Resolution 1/4 ...]
        # bottleneck is Resolution 1/8 (if 4 stages)
        
        features = {}
        
        # 1. Add the high-res stages from the skip connections
        # Note: skips[0] is usually the input convolution output (Resolution 1)
        for i, skip in enumerate(skips):
            features[f"stage{i+1}"] = skip
            
        # 2. Add the bottleneck as the final stage
        final_stage_idx = len(skips) + 1
        features[f"stage{final_stage_idx}"] = bottleneck
        
        return features
    

In [None]:
class AFN3D(nn.Module): ## I assume symetrical input 

    def __init__(self):
        super().__init__()
        self.encoder = MonaiSegResNetEncoder(in_channels=1,feature_channels=[16,32,64,128],spatial_dims = 3)
        self.uafs1 = UAFS(128,64)
        self.uafs2 = UAFS(64, 32)
        self.uafs3 = UAFS(32,16)
        self.mafs = MAFS(16)

    def forward(self,x:torch.Tensor):

        features = self.encoder(x) ## stage1-16 stage2-32 stage3-64 
        x_s = x_t = features["stage4"]
        x_s, x_t = self.uafs1(x_s,x_t,features["stage3"])
        x_s, x_t = self.uafs2(x_s,x_t,features["stage2"])
        x_s, x_t = self.uafs3(x_s,x_t,features["stage1"])
        pred,affs = self.mafs(x_s,x_t)
        print(pred.shape,affs.shape)
        return pred, affs

        



In [11]:
model = AFN3D()

In [6]:
input = torch.rand([3,1,64,64,64])

In [7]:
input.shape

torch.Size([3, 1, 64, 64, 64])

In [8]:
output = model(input)

torch.Size([3, 128, 8, 8, 8])
torch.Size([3, 64, 16, 16, 16])
torch.Size([3, 32, 32, 32, 32])
torch.Size([3, 1, 64, 64, 64]) torch.Size([3, 78, 64, 64, 64])
