In [2]:
#| default_exp backbones
#| export
from typing import Tuple, Dict, Literal, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from loguru import logger
import math
from qct_3d_nod_detect.SimpleFPN import BackboneFPN, SimpleFPN

class DropPath(nn.Module):
    """Stochastic depth: randomly zero whole samples of a residual branch while training.

    A single Bernoulli(keep_prob) draw is made per sample (dim 0) and broadcast over all
    remaining dims. Kept samples are rescaled by ``1 / keep_prob`` so the expectation is
    unchanged. In eval mode, or when ``drop_prob == 0``, the input is returned as-is.

    Args:
        drop_prob: Probability of dropping a sample's branch output.
    """

    def __init__(self, drop_prob: float = 0.0) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        # Shape [B, 1, 1, ...] so the mask broadcasts across every non-batch dim.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()
        if keep_prob <= 0.0:
            # drop_prob == 1: every sample is dropped. Dividing by keep_prob would give
            # inf * 0 = NaN, so skip the rescale (the mask is all zeros anyway).
            return x * random_tensor
        return x.div(keep_prob) * random_tensor


class LayerScale(nn.Module):
    """Scale the last dimension of the input by a learnable per-channel vector ``gamma``.

    Args:
        dim: Size of the last input dimension (number of channels).
        init_values: Initial value of every entry of ``gamma``.
    """

    def __init__(self, dim: int, init_values: float = 1e-5) -> None:
        super().__init__()
        # One trainable factor per channel, all starting at ``init_values``.
        initial_gamma = torch.ones(dim).mul(init_values)
        self.gamma = nn.Parameter(initial_gamma)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        # Broadcasts ``gamma`` over all leading dims.
        scaled = x * self.gamma
        return scaled


class TubePatchEmbed3D(nn.Module):
    """Cut a batched 3-D volume into non-overlapping patches and embed each one linearly.

    Implemented as a ``Conv3d`` whose kernel size equals its stride, so each output
    voxel corresponds to exactly one input patch.

    Args:
        img_size: Expected spatial size of the input volume (3 values).
        patch_size: Patch size along each of the 3 spatial axes.
        in_channels: Number of input channels.
        embed_dim: Output embedding width per patch.

    Forward:
        Takes a 5-D tensor ``[B, in_channels, *spatial]`` and returns tokens of shape
        ``[B, gz * gy * gx, embed_dim]``, ordered with the first spatial axis outermost.
    """

    def __init__(
        self,
        img_size: Tuple[int, int, int],
        patch_size: Tuple[int, int, int],
        in_channels: int,
        embed_dim: int,
    ) -> None:
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        # Number of whole patches along each axis (any remainder is dropped by the conv).
        self.grid_size = tuple(img_size[axis] // patch_size[axis] for axis in range(3))
        self.num_patches = math.prod(self.grid_size)
        self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        patch_grid = self.proj(x)
        # Unpacking also enforces a 5-D (batched) result.
        n_batch, n_chan, gz, gy, gx = patch_grid.shape
        channels_last = patch_grid.permute(0, 2, 3, 4, 1)
        return channels_last.reshape(n_batch, gz * gy * gx, n_chan)


def get_3d_sincos_pos_embed(embed_dim: int, grid_size: Tuple[int, int, int]) -> torch.Tensor:
    """Build a fixed 3-D sine/cosine positional embedding for a ``gz x gy x gx`` grid.

    Channels are split across the three axes: ``embed_dim // 3`` for z, the same for y,
    and the remainder for x. For an axis with ``d`` channels, ``d // 2`` frequencies
    ``w_i = 10000 ** (-i / (d // 2 - 1))`` are used, and the axis embedding is
    ``[sin(p * w), cos(p * w)]``. Odd per-axis widths lose one channel, so the result is
    right-padded with zeros up to ``embed_dim``.

    Args:
        embed_dim: Total embedding width.
        grid_size: ``(gz, gy, gx)`` number of patches along each axis.

    Returns:
        float32 tensor ``[gz * gy * gx, embed_dim]`` with rows in ``ij`` (z-major) order.
    """
    gz, gy, gx = grid_size

    def axis_embedding(n_channels: int, positions: torch.Tensor) -> torch.Tensor:
        half = n_channels // 2
        freqs = torch.arange(half, dtype=torch.float32)
        if half > 1:
            # Normalise exponents to [0, 1] so the last frequency is exactly 1 / 10000.
            freqs = freqs / (half - 1)
        freqs = 1.0 / (10000**freqs)
        angles = positions[:, None] * freqs[None, :]
        return torch.cat([angles.sin(), angles.cos()], dim=1)

    axes = [torch.arange(n, dtype=torch.float32) for n in (gz, gy, gx)]
    coords = torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1).reshape(-1, 3)

    third = embed_dim // 3
    channels_per_axis = (third, third, embed_dim - 2 * third)
    pos = torch.cat(
        [axis_embedding(width, coords[:, axis]) for axis, width in enumerate(channels_per_axis)],
        dim=1,
    )

    shortfall = embed_dim - pos.shape[1]
    if shortfall > 0:
        pos = F.pad(pos, (0, shortfall))
    return pos


class MLP(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: Input and output width.
        mlp_ratio: Hidden width is ``int(dim * mlp_ratio)``.
        drop: Dropout probability, shared by both dropout applications.
    """

    def __init__(self, dim: int, mlp_ratio: float, drop: float) -> None:
        super().__init__()
        hidden_features = int(dim * mlp_ratio)
        # Registration order (fc1, act, fc2, drop) is kept stable for state_dict keys.
        self.fc1 = nn.Linear(dim, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class Attention(nn.Module):
    """Multi-head self-attention over a token sequence ``[B, N, C]``.

    Uses ``F.scaled_dot_product_attention`` when ``use_sdpa`` is set and the function
    exists in the installed PyTorch; otherwise it falls back to an explicit
    ``softmax(q @ k^T * scale) @ v``. Both paths apply the same attention dropout, and
    only in training mode.

    Args:
        dim: Token width ``C``; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused q/k/v projection has a bias.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability after the output projection.
        use_sdpa: Prefer the fused SDPA kernel when available.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        qkv_bias: bool,
        attn_drop: float,
        proj_drop: float,
        use_sdpa: bool,
    ) -> None:
        super().__init__()
        if dim % num_heads != 0:
            # Otherwise the q/k/v reshape in forward fails with an opaque size error.
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads}).")
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.use_sdpa = use_sdpa and hasattr(F, "scaled_dot_product_attention")

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        b, n, c = x.shape
        # [B, N, 3C] -> [3, B, heads, N, head_dim]
        qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        if self.use_sdpa:
            # SDPA's default scale is 1/sqrt(head_dim) == self.scale. Its dropout must be
            # passed explicitly (and disabled in eval) to match the fallback path.
            dropout_p = self.attn_drop.p if self.training else 0.0
            x = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
        else:
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v
        x = x.transpose(1, 2).reshape(b, n, c)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):
    """Pre-norm transformer block: ``x += attn(norm1(x))`` then ``x += mlp(norm2(x))``.

    Each residual branch is optionally multiplied by a learnable ``LayerScale`` and passed
    through ``DropPath`` (stochastic depth) when ``drop_path > 0``.

    Args:
        dim: Token width.
        num_heads: Attention heads.
        mlp_ratio: Hidden-width multiplier of the MLP.
        qkv_bias: Bias on the fused q/k/v projection.
        drop: Dropout used after the attention projection and inside the MLP.
        attn_drop: Dropout on attention weights.
        drop_path: Stochastic-depth probability for both residual branches.
        use_layer_scale: Whether to scale branch outputs with ``LayerScale``.
        layer_scale_init: Initial ``LayerScale`` value.
        use_sdpa: Prefer the fused SDPA attention kernel.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        drop_path: float = 0.0,
        use_layer_scale: bool = True,
        layer_scale_init: float = 1e-5,
        use_sdpa: bool = True,
    ) -> None:
        super().__init__()
        # Submodule creation order is kept stable: it fixes state_dict keys and init RNG order.
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(
            dim=dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            use_sdpa=use_sdpa,
        )
        if drop_path > 0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim=dim, mlp_ratio=mlp_ratio, drop=drop)
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.ls1 = LayerScale(dim, layer_scale_init)
            self.ls2 = LayerScale(dim, layer_scale_init)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        attn_branch = self.attn(self.norm1(x))
        if self.use_layer_scale:
            attn_branch = self.ls1(attn_branch)
        x = x + self.drop_path(attn_branch)

        mlp_branch = self.mlp(self.norm2(x))
        if self.use_layer_scale:
            mlp_branch = self.ls2(mlp_branch)
        return x + self.drop_path(mlp_branch)


class ViT3D(nn.Module):
    """Plain 3-D Vision Transformer encoder over non-overlapping volumetric patches.

    Token sequence layout: ``[CLS (optional)] [register tokens (optional)] [patch tokens]``.
    A fixed (non-trainable) sin-cos positional table is added to every token, and the
    prefix (CLS/register) positions get a zero positional embedding. Because that table is
    built once for ``img_size``, inputs must produce exactly ``num_patches`` patches.

    Forward input: 5-D tensor ``[B, in_channels, *spatial]``.
    Forward output dict:
        ``"feat_tokens"``: all normalised tokens ``[B, prefix_tokens + num_patches, D]``.
        ``"patch_tokens"``: patch tokens only ``[B, num_patches, D]``.
        ``"global"``: CLS token if ``use_cls_token`` else the mean of the patch tokens, ``[B, D]``.
    """

    def __init__(
        self,
        img_size: Tuple[int, int, int],
        patch_size: Tuple[int, int, int],
        in_channels: int,
        embed_dim: int,
        depth: int,
        num_heads: int,
        mlp_ratio: float,
        qkv_bias: bool,
        drop_rate: float,
        attn_drop_rate: float,
        drop_path_rate: float,
        use_cls_token: bool,
        use_sdpa: bool,
        use_layer_scale: bool = True,
        layer_scale_init: float = 1e-5,
        num_register_tokens: int = 0,
    ) -> None:
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.use_cls_token = use_cls_token
        self.num_register_tokens = num_register_tokens

        self.patch_embed = TubePatchEmbed3D(img_size, patch_size, in_channels, embed_dim)
        num_patches = self.patch_embed.num_patches
        # Non-patch tokens prepended to the sequence (CLS + registers).
        self.prefix_tokens = (1 if use_cls_token else 0) + num_register_tokens
        # Same (img // patch) grid the patch embedder computed; reused rather than recomputed.
        self.grid_size = self.patch_embed.grid_size

        if use_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        if num_register_tokens > 0:
            self.register_tokens = nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim))

        # Fixed positional table: zero rows for prefix tokens, then sin-cos rows for patches.
        pos_embed = get_3d_sincos_pos_embed(embed_dim, self.grid_size)
        seq = []
        if use_cls_token:
            seq.append(torch.zeros(1, embed_dim))
        if num_register_tokens > 0:
            seq.append(torch.zeros(num_register_tokens, embed_dim))
        seq.append(pos_embed)
        full_pos = torch.cat(seq, dim=0).unsqueeze(0)
        self.pos_embed = nn.Parameter(full_pos, requires_grad=False)
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Stochastic-depth rate grows linearly from 0 (first block) to drop_path_rate (last).
        dpr = torch.linspace(0, drop_path_rate, depth).tolist()
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    use_layer_scale=use_layer_scale,
                    layer_scale_init=layer_scale_init,
                    use_sdpa=use_sdpa,
                )
                for i in range(depth)
            ]
        )
        self.norm = nn.LayerNorm(embed_dim)
        self.output_dim = embed_dim
        self.num_patches = num_patches

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:  # type: ignore[override]
        b = x.shape[0]
        tokens = self.patch_embed(x)  # [B, N, D]
        if tokens.shape[1] != self.num_patches:
            # Without this check the positional add fails with an opaque size mismatch
            # (or silently broadcasts when there is a single token).
            raise RuntimeError(
                f"Input spatial size {tuple(x.shape[2:])} yields {tokens.shape[1]} patches, "
                f"but the positional embedding was built for img_size={self.img_size} "
                f"({self.num_patches} patches with patch_size={self.patch_size})."
            )
        if self.use_cls_token:
            cls_tokens = self.cls_token.expand(b, -1, -1)
            tokens = torch.cat([cls_tokens, tokens], dim=1)
        if self.num_register_tokens > 0:
            regs = self.register_tokens.expand(b, -1, -1)
            # Registers go after CLS (if any) so the order matches the positional table.
            tokens = (
                torch.cat([tokens[:, :1, :], regs, tokens[:, 1:, :]], dim=1)
                if self.use_cls_token
                else torch.cat([regs, tokens], dim=1)
            )
        tokens = tokens + self.pos_embed.to(device=tokens.device, dtype=tokens.dtype)
        tokens = self.pos_drop(tokens)
        for blk in self.blocks:
            tokens = blk(tokens)
        tokens = self.norm(tokens)
        patch_tokens = tokens[:, self.prefix_tokens :, :]
        global_token = tokens[:, 0] if self.use_cls_token else patch_tokens.mean(dim=1)
        return {
            "feat_tokens": tokens,
            "patch_tokens": patch_tokens,
            "global": global_token,
        }

In [1]:
#| export
def _clean_key(key: str) -> str:
    for prefix in ("model.", "module."):
        if key.startswith(prefix):
            key = key[len(prefix) :]
    if key.startswith("encoder."):
        key = key[len("encoder.") :]
    return key


def _adapt_patch_embed_channels(
    weight: torch.Tensor, target_shape: torch.Size, in_chans: int
) -> torch.Tensor | None:
    if weight.ndim != 5 or weight.shape[2:] != target_shape[2:]:
        return None
    out_channels, old_in, _, _, _ = weight.shape
    if out_channels != target_shape[0]:
        return None
    if old_in == in_chans:
        return weight
    if in_chans == 1:
        return weight.mean(dim=1, keepdim=True)
    if in_chans < old_in:
        trimmed = weight[:, :in_chans, :, :, :]
        return trimmed * (old_in / in_chans)
    repeat = math.ceil(in_chans / old_in)
    expanded = weight.repeat(1, repeat, 1, 1, 1)[:, :in_chans, :, :, :]
    expanded *= old_in / in_chans
    return expanded

def load_state_dict_into_vit(
    vit: ViT3D, ckpt_path: Union[Path, str]
) -> tuple[list[str], list[str]]:
    """Load a (possibly wrapped) pretrained checkpoint into ``vit`` non-strictly.

    Accepted checkpoint layouts: a raw state dict, a dict holding one under ``"state_dict"``
    or ``"model"``, or an object exposing ``.state_dict()`` (directly or under those keys).
    Keys are normalised with ``_clean_key``. Only tensors whose cleaned key exists in ``vit``
    with an identical shape are loaded, so e.g. a ``pos_embed`` of a different size is
    skipped and ``vit`` keeps its own. A ``patch_embed.proj.weight`` with a different
    input-channel count is adapted with ``_adapt_patch_embed_channels`` when possible.

    Args:
        vit: Target model; modified in place.
        ckpt_path: File readable by ``torch.load``.

    Returns:
        ``(missing, unexpected)`` key lists from ``vit.load_state_dict(..., strict=False)``.

    Raises:
        RuntimeError: If no state dict can be extracted from the checkpoint.
    """
    # SECURITY: weights_only=False unpickles arbitrary objects and can execute code.
    # Only load checkpoints from trusted sources.
    raw = torch.load(str(ckpt_path), map_location="cpu", weights_only=False)
    # Checked first: a checkpoint that is itself a pickled module has no ``.get``.
    sd = raw.get("state_dict", raw.get("model", raw)) if isinstance(raw, dict) else raw
    if not isinstance(sd, dict):
        if hasattr(sd, "state_dict"):
            sd = sd.state_dict()
        else:
            raise RuntimeError("Unsupported checkpoint format.")

    target = vit.state_dict()
    cleaned: Dict[str, torch.Tensor] = {}
    for raw_key, value in sd.items():
        if not isinstance(value, torch.Tensor):
            continue  # skip non-tensor metadata instead of crashing on ``.shape``
        key = _clean_key(raw_key)
        if key not in target:
            continue
        expected = target[key]
        if (
            key == "patch_embed.proj.weight"
            and value.ndim > 1
            and value.shape[1] != expected.shape[1]
        ):
            adapted = _adapt_patch_embed_channels(value, expected.shape, expected.shape[1])
            if adapted is not None:
                cleaned[key] = adapted
            continue
        if value.shape == expected.shape:
            cleaned[key] = value

    if not cleaned:
        logger.warning("No tensors from {} matched the ViT3D state dict.", ckpt_path)
    missing, unexpected = vit.load_state_dict(cleaned, strict=False)
    # loguru uses str.format-style "{}" placeholders, not printf-style "%s".
    if missing:
        logger.info("Missing MAE keys during load: {}", missing)
    if unexpected:
        logger.info("Unexpected MAE keys during load: {}", unexpected)
    return list(missing), list(unexpected)

def build_vit_backbone_with_fpn(
    variant: Literal["S", "L"],
    ckpt_path: Union[Path, str, None],
    scales: Tuple[float, ...] = (1, 2, 0.5, 0.25),
    out_channels: int = 256,
) -> nn.Module:
    """Build a ``ViT3D`` backbone (optionally from a checkpoint) wrapped with a ``SimpleFPN``.

    Args:
        variant: ``"S"`` (384-d, 12 blocks, 16^3 patches) or ``"L"`` (1024-d, 24 blocks,
            8^3 patches); both take 128^3 single-channel inputs.
        ckpt_path: Checkpoint to load with ``load_state_dict_into_vit``; ``None`` keeps
            random initialisation.
        scales: Scale factors forwarded to ``SimpleFPN``.
        out_channels: FPN output channels.

    Returns:
        A ``BackboneFPN`` combining the backbone and the FPN.

    Raises:
        ValueError: If ``variant`` is not ``"S"`` or ``"L"``.
    """
    variant_configs = {
        "S": {
            "img_size": (128, 128, 128),
            "patch_size": (16, 16, 16),
            "in_channels": 1,
            "embed_dim": 384,
            "depth": 12,
            "num_heads": 6,
            "mlp_ratio": 4.0,
            "qkv_bias": True,
            "drop_rate": 0.0,
            "attn_drop_rate": 0.0,
            "drop_path_rate": 0.1,
            "use_cls_token": True,
            "use_sdpa": True,  # TODO Populated with sample, will change later
        },
        "L": {
            "img_size": (128, 128, 128),
            "patch_size": (8, 8, 8),
            "in_channels": 1,
            "embed_dim": 1024,
            "depth": 24,
            "num_heads": 16,
            "mlp_ratio": 4.0,
            "qkv_bias": True,
            "drop_rate": 0.0,
            "attn_drop_rate": 0.0,
            "drop_path_rate": 0.0,
            "use_cls_token": True,
            "use_sdpa": True,
        },
    }
    if variant not in variant_configs:
        # Previously fell through and crashed with UnboundLocalError on ``model_config``.
        raise ValueError(
            f"Unknown ViT variant {variant!r}; expected one of {sorted(variant_configs)}."
        )
    model_config = variant_configs[variant]

    # Load backbone
    backbone = ViT3D(**model_config)
    if ckpt_path is not None:
        ckpt_path = Path(ckpt_path)
        # Printed for both str and Path inputs (was previously skipped for Path).
        print(f"Loading from pretrained checkpoint - {ckpt_path}")
        missing, unexpected = load_state_dict_into_vit(backbone, ckpt_path)
        print(f"Missing keys: {missing}, Unexpected keys: {unexpected}")
    else:
        print("loading random weights")

    # Build with FPN; the patch grid is the one the backbone already derived from its config.
    fpn = SimpleFPN(dim=model_config["embed_dim"], out_channels=out_channels, scales=scales)
    return BackboneFPN(
        backbone=backbone,
        fpn=fpn,
        patch_grid_size=backbone.grid_size,
    )

NameError: name 'torch' is not defined

In [4]:
#| hide
# Smoke test: ViT-S backbone + FPN from random initialisation (no checkpoint).
backbone_fpn = build_vit_backbone_with_fpn(
    variant="S",
    ckpt_path=None,
    out_channels=256,
    scales=[1, 2, 0.5, 0.25],
)



loading random weights


In [5]:
# Smoke-test forward pass on one random single-channel 128^3 volume (B=1, C=1).
# A dedicated seeded generator makes the input reproducible without touching the global RNG.
x = torch.randn(1, 1, 128, 128, 128, generator=torch.Generator().manual_seed(0))
out = backbone_fpn(x)

RuntimeError: The size of tensor a (513) must match the size of tensor b (385) at non-singleton dimension 1

device(type='cpu')