<a href="https://colab.research.google.com/github/alim98/Thesis/blob/main/H_vit.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!git clone https://github.com/your-username/hvit.git
!cd hvit

Cloning into 'hvit'...
fatal: could not read Username for 'https://github.com': No such device or address
/bin/bash: line 1: cd: hvit: No such file or directory


# Dataset

In [None]:
import os, glob
import monai
import torch
import pickle
import random
import numpy as np
from torch.utils.data import Dataset


class OASIS_Dataset(Dataset):
    """Dataset of pickled volumes that yields (source, target) pairs with labels.

    Two modes are supported:

    * ``is_pair=True``: every pickle file holds one ready-made pair
      ``(src, tgt, src_lbl, tgt_lbl)`` and the dataset length is the number
      of files.
    * ``is_pair=False``: every pickle file holds ``(image, label)``; each item
      is built from two files drawn at random, and the dataset length is the
      fixed ``num_steps``.

    Args:
        input_dim: Spatial size passed to ``monai.transforms.Resize``.
        data_path: Directory searched for ``*.{ext}`` files.
        num_steps: Reported length when ``is_pair`` is False.
        is_pair: Selects the file layout described above.
        ext: File extension (without the dot) of the pickled samples.
    """

    def __init__(self, input_dim, data_path, num_steps=1000, is_pair: bool = False, ext="pkl"):
        # glob returns files in filesystem order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        self.paths = sorted(glob.glob(os.path.join(data_path, f"*.{ext}")))
        self.num_steps = num_steps
        self.input_dim = input_dim
        self.is_pair = is_pair

        # Labels are resized with nearest-neighbour so no new label ids appear.
        self.transforms_mask = monai.transforms.Compose([
            monai.transforms.Resize(spatial_size=input_dim, mode="nearest")
        ])
        self.transforms_image = monai.transforms.Compose([
            monai.transforms.Resize(spatial_size=input_dim)
        ])

    def _pkload(self, filename: str) -> tuple:
        """
        Load a pickled file and return its contents.

        Args:
            filename (str): The path to the pickled file.

        Returns:
            tuple: The unpickled contents of the file.

        Raises:
            FileNotFoundError: If the file does not exist.
            pickle.UnpicklingError: If there's an error during unpickling.
        """
        # NOTE: pickle.load executes arbitrary code embedded in the file; only
        # use this loader on data files from a trusted source.
        try:
            with open(filename, 'rb') as file:
                return pickle.load(file)
        except FileNotFoundError as err:
            # Chain the original error so the OS-level details are not lost.
            raise FileNotFoundError(f"The file {filename} was not found.") from err
        except pickle.UnpicklingError as err:
            raise pickle.UnpicklingError(f"Error unpickling the file {filename}.") from err

    def __getitem__(self, index):
        """Return ``(src, tgt, src_lbl, tgt_lbl)`` tensors, each with a leading channel axis."""
        if self.is_pair:
            src, tgt, src_lbl, tgt_lbl = self._pkload(self.paths[index])
        else:
            # `index` is ignored here: two distinct files are drawn at random.
            src_path, tgt_path = random.sample(self.paths, 2)
            src, src_lbl = self._pkload(src_path)
            tgt, tgt_lbl = self._pkload(tgt_path)

        # Images become float tensors, labels integer tensors; unsqueeze(0)
        # adds the channel axis in front of the spatial axes.
        src = torch.from_numpy(src).float().unsqueeze(0)
        src_lbl = torch.from_numpy(src_lbl).long().unsqueeze(0)
        tgt = torch.from_numpy(tgt).float().unsqueeze(0)
        tgt_lbl = torch.from_numpy(tgt_lbl).long().unsqueeze(0)

        src = self.transforms_image(src)
        tgt = self.transforms_image(tgt)
        src_lbl = self.transforms_mask(src_lbl)
        tgt_lbl = self.transforms_mask(tgt_lbl)

        return src, tgt, src_lbl, tgt_lbl

    def __len__(self):
        """Number of files in pair mode, otherwise the configured ``num_steps``."""
        return len(self.paths) if self.is_pair else self.num_steps

def get_dataloader(data_path, input_dim, batch_size, shuffle: bool = True, is_pair: bool = False,
                   num_workers: int = 4):
    """Build a ``DataLoader`` over an ``OASIS_Dataset``.

    Args:
        data_path: Directory containing the pickled samples.
        input_dim: Spatial size the samples are resized to.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader shuffles indices each epoch.
        is_pair: Forwarded to ``OASIS_Dataset`` (pre-paired files or not).
        num_workers: Number of loader worker processes. Defaults to 4, the
            value that used to be hard-coded; pass 0 to load in the main process.

    Returns:
        torch.utils.data.DataLoader: Loader yielding batched dataset items.
    """
    ds = OASIS_Dataset(input_dim=input_dim, data_path=data_path, is_pair=is_pair)
    return torch.utils.data.DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )

# Blocks


In [None]:
import math
import warnings

import torch
import torch.nn as nn
from torch.nn import functional as F

from torch import Tensor
from einops import rearrange
from functools import reduce

# Number of spatial dimensions; used below to pick the matching torch.nn
# layer classes by name (e.g. 'BatchNorm%dd' % ndims).
ndims = 3 # H,W,D


# these functions are adopted from timm.
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Stochastic depth: randomly zero whole samples of ``x`` during training.

    One Bernoulli draw is made per sample (first axis) and broadcast over all
    remaining axes, so a sample is either kept in full or dropped in full.

    Args:
        x: Input tensor; the first axis is treated as the sample axis.
        drop_prob: Probability of dropping a sample.
        training: The function is a no-op unless this is truthy.
        scale_by_keep: If True, kept samples are divided by the keep
            probability.

    Returns:
        ``x`` itself when inactive, otherwise ``x`` times the per-sample mask.
    """
    is_active = training and drop_prob != 0.
    if not is_active:
        return x

    survival = 1 - drop_prob
    # One mask entry per sample, singleton in every other axis so it broadcasts.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(survival)
    if scale_by_keep and survival > 0.0:
        mask.div_(survival)
    return x * mask


class timm_DropPath(nn.Module):
    """Module wrapper around ``drop_path`` (per-sample stochastic depth).

    The drop is only active while the module is in training mode, because
    ``self.training`` is forwarded to ``drop_path``.
    """

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super().__init__()
        self.scale_by_keep = scale_by_keep
        self.drop_prob = drop_prob

    def forward(self, x):
        """Apply ``drop_path`` with this module's settings and training flag."""
        return drop_path(
            x,
            drop_prob=self.drop_prob,
            training=self.training,
            scale_by_keep=self.scale_by_keep,
        )

    def extra_repr(self):
        """Show the drop probability (three decimals) in the module repr."""
        rounded = round(self.drop_prob, 3)
        return f'drop_prob={rounded:0.3f}'


def _trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    # Values are generated by using a truncated uniform distribution and
    # then using the inverse CDF for the normal distribution.
    # Get upper and lower cdf values
    l = norm_cdf((a - mean) / std)
    u = norm_cdf((b - mean) / std)

    # Uniformly fill tensor with values from [l, u], then translate to
    # [2l-1, 2u-1].
    tensor.uniform_(2 * l - 1, 2 * u - 1)

    # Use inverse cdf transform for normal distribution to get truncated
    # standard normal
    tensor.erfinv_()

    # Transform to proper mean, std
    tensor.mul_(std * math.sqrt(2.))
    tensor.add_(mean)

    # Clamp to ensure it's in the proper range
    tensor.clamp_(min=a, max=b)
    return tensor



def timm_trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fill ``tensor`` in place from a truncated normal distribution.

    Values follow :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted to
    :math:`[a, b]`. The cutoffs ``a`` and ``b`` are absolute values (they are
    applied after mean/std), so they should be chosen to match the range of
    ``mean`` and ``std``. The sampling works best when
    :math:`a \leq \text{mean} \leq b`.

    The fill runs under ``torch.no_grad()`` so it is not recorded by autograd.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        The input ``tensor``, filled in place.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> timm_trunc_normal_(w, std=.02)
    """
    with torch.no_grad():
        filled = _trunc_normal_(tensor, mean, std, a, b)
    return filled



class Conv3dReLU(nn.Sequential):
    """Bias-free 3-D convolution followed by a normalisation and LeakyReLU.

    Despite the name, the activation is ``nn.LeakyReLU``. The normalisation is
    ``BatchNorm3d`` when ``use_batchnorm`` is True, otherwise ``InstanceNorm3d``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel size.
        padding: Convolution padding.
        stride: Convolution stride.
        use_batchnorm: Selects batch norm (True) or instance norm (False).
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            padding=0,
            stride=1,
            use_batchnorm=True,
    ):
        norm_cls = nn.BatchNorm3d if use_batchnorm else nn.InstanceNorm3d
        layers = (
            nn.Conv3d(
                in_channels,
                out_channels,
                kernel_size,
                stride=stride,
                padding=padding,
                bias=False,
            ),
            norm_cls(out_channels),
            nn.LeakyReLU(inplace=True),
        )
        super().__init__(*layers)


def get_norm(name, **kwargs):
    """Create a normalisation layer for ``ndims`` spatial dimensions by name.

    NOTE: a later definition of ``get_norm`` in this file replaces this one at
    import time; this version recognises 'BatchNorm2d' rather than 'BatchNorm'.

    Args:
        name: Case-insensitive layer name: 'BatchNorm2d', 'instance' or 'None'.
        **kwargs: Forwarded to the layer constructor (ignored for 'None').

    Returns:
        nn.Module: The batch/instance norm layer, or ``nn.Identity`` for 'None'.

    Raises:
        NotImplementedError: If ``name`` is not one of the supported names.
    """
    key = name.lower()
    if key == 'BatchNorm2d'.lower():
        batch_norm_cls = getattr(nn, 'BatchNorm%dd' % ndims)
        return batch_norm_cls(**kwargs)
    if key == 'instance':
        instance_norm_cls = getattr(nn, 'InstanceNorm%dd' % ndims)
        return instance_norm_cls(**kwargs)
    if key == 'None'.lower():
        return nn.Identity()
    # Previously the exception class was *returned*, which handed callers a
    # non-module object instead of failing at the point of the mistake.
    raise NotImplementedError(f"Unsupported normalization layer: {name!r}")


def get_activation(name, **kwargs):
    """Create an activation layer by (case-insensitive) name.

    Args:
        name: One of 'ReLU', 'GELU' or 'None'.
        **kwargs: Accepted for signature compatibility; currently unused.

    Returns:
        nn.Module: ``nn.ReLU``, ``nn.GELU`` or ``nn.Identity``.

    Raises:
        NotImplementedError: If ``name`` is not one of the supported names.
    """
    key = name.lower()
    if key == 'relu':
        return nn.ReLU()
    if key == 'gelu':
        return nn.GELU()
    if key == 'none':
        return nn.Identity()
    # Raise instead of returning the exception class, so a typo in the name
    # fails here rather than later inside nn.Sequential.
    raise NotImplementedError(f"Unsupported activation: {name!r}")

def count_parameters(model):
    """Return the total number of elements in ``model``'s trainable parameters.

    Only parameters with ``requires_grad`` set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def prod_func(Vec):
    """Return the product of all items in ``Vec``, multiplied left to right.

    No initial value is given to ``reduce``, so an empty ``Vec`` raises
    ``TypeError`` and a single-item ``Vec`` returns that item unchanged.
    """
    def _multiply(running, item):
        return running * item

    return reduce(_multiply, Vec)

def downsampler_fn(img, out_size):
    """Resize ``img`` to ``out_size`` with nearest-neighbour interpolation.

    NOTE: a later definition of ``downsampler_fn`` in this file (trilinear)
    replaces this one at import time.

    input shape: B,C,H,W,D
    output shape: B,C,H,W,D (spatial axes resized to ``out_size``)
    """
    resize_options = dict(
        size=out_size,
        mode='nearest',
        align_corners=None,
        recompute_scale_factor=None,
    )
    return nn.functional.interpolate(img, **resize_options)


def get_norm(name, **kwargs):
    """Create a normalisation layer for ``ndims`` spatial dimensions by name.

    Args:
        name: Case-insensitive layer name: 'BatchNorm', 'instance' /
            'InstanceNorm', or 'None'.
        **kwargs: Forwarded to the layer constructor (ignored for 'None').

    Returns:
        nn.Module: The batch/instance norm layer, or ``nn.Identity`` for 'None'.

    Raises:
        NotImplementedError: If ``name`` is not one of the supported names.
    """
    key = name.lower()
    if key == 'batchnorm':
        batch_norm_cls = getattr(nn, 'BatchNorm%dd' % ndims)
        return batch_norm_cls(**kwargs)
    if key in ('instance', 'instancenorm'):
        instance_norm_cls = getattr(nn, 'InstanceNorm%dd' % ndims)
        return instance_norm_cls(**kwargs)
    if key == 'none':
        return nn.Identity()
    # Previously the exception class was *returned*, which handed callers a
    # non-module object instead of failing at the point of the mistake.
    raise NotImplementedError(f"Unsupported normalization layer: {name!r}")


def get_activation(name, **kwargs):
    """Create an activation layer by (case-insensitive) name.

    Args:
        name: One of 'ReLU', 'GELU' or 'None'.
        **kwargs: Accepted for signature compatibility; currently unused.

    Returns:
        nn.Module: ``nn.ReLU``, ``nn.GELU`` or ``nn.Identity``.

    Raises:
        NotImplementedError: If ``name`` is not one of the supported names.
    """
    key = name.lower()
    if key == 'relu':
        return nn.ReLU()
    if key == 'gelu':
        return nn.GELU()
    if key == 'none':
        return nn.Identity()
    # Raise instead of returning the exception class, so a typo in the name
    # fails here rather than later inside nn.Sequential.
    raise NotImplementedError(f"Unsupported activation: {name!r}")



def downsampler_fn(data, out_size):
    """Resize ``data`` to ``out_size`` with trilinear interpolation.

    input shape: B,Ci,Hi,Wi,Di
    output shape: B,C,H,W,D (spatial axes resized to ``out_size``)

    The result is moved to ``data.device`` before being returned.
    """
    resize_options = dict(
        size=out_size,
        mode='trilinear',
        align_corners=None,
        recompute_scale_factor=None,
    )
    resized = nn.functional.interpolate(data, **resize_options)
    target_device = data.device
    return resized.to(target_device)


class MLP(nn.Module):
    """Feed-forward sub-block with three selectable variants.

    ``MLP_type`` (case-insensitive) selects the layout of ``self.net``:

    * ``"basic"``: Linear -> activation -> Dropout -> Linear -> Dropout,
      applied on the last axis of the input.
    * ``"scmlp"``: 1x1x1 Conv3d + BatchNorm3d + activation in channels-first
      layout, followed by two Linear/activation/Dropout stages in
      channels-last layout.
    * ``"conv"``: 1x1x1 conv -> depthwise conv -> 1x1x1 conv, each followed by
      BatchNorm3d (RVT, CVPR 2022 style).

    The two convolutional variants take and return a channels-last 5-D tensor
    ``(B, H, W, D, C)``; the layout changes happen inside ``self.net``.

    Args:
        in_feats: Number of input features/channels.
        MLP_type: "basic", "scmlp" or "conv".
        hid_feats: Hidden width; defaults to ``in_feats``.
        out_feats: Output width; defaults to ``in_feats``.
        kernel_size: Kernel size of the depthwise conv in the "conv" variant.
        act_name: Activation name passed to ``get_activation``.
        drop: Dropout probability.
        bias: Whether the convolutions carry a bias (Linear layers keep their
            default bias).

    Raises:
        ValueError: If ``MLP_type`` is not one of the supported variants.
    """

    def __init__(self,
                in_feats,
                MLP_type="basic", # scmlp conv basic
                hid_feats=None,
                out_feats=None,
                kernel_size=3,
                act_name="GELU",
                drop=0.,
                bias=False,
            )->None:
        super(MLP, self).__init__()

        out_feats = out_feats or in_feats
        hid_feats = hid_feats or in_feats
        variant = MLP_type.lower()

        if variant == "basic":
            self.net = nn.Sequential(
                nn.Linear(in_feats, hid_feats),
                get_activation(act_name),
                nn.Dropout(drop),
                nn.Linear(hid_feats, out_feats),
                nn.Dropout(drop),
            )
            return

        if variant not in ("scmlp", "conv"):
            # Without this, self.net would be missing and forward() would fail
            # with an unrelated AttributeError.
            raise ValueError(
                f"Unknown MLP_type {MLP_type!r}; expected 'basic', 'scmlp' or 'conv'."
            )

        # einops.rearrange is a function that needs a tensor; inside
        # nn.Sequential the layer form (Rearrange) is required. Imported here
        # so the "basic" variant does not depend on einops.layers.
        from einops.layers.torch import Rearrange

        # Conv3d/BatchNorm3d need channels-first 5-D input.
        to_channels_first = Rearrange('b h w d c -> b c h w d')
        to_channels_last = Rearrange('b c h w d -> b h w d c')

        if variant == "scmlp":
            # improved MLP : 1x1x1 conv (channel mixing) -> mlp
            self.net = nn.Sequential(
                to_channels_first,
                nn.Conv3d(in_feats, out_feats, kernel_size=1, bias=bias),
                nn.BatchNorm3d(out_feats),
                get_activation(act_name),
                to_channels_last,
                nn.Linear(out_feats, out_feats),
                get_activation(act_name),
                nn.Dropout(drop),
                nn.Linear(out_feats, out_feats),
                get_activation(act_name),
                nn.Dropout(drop),
            )
        else:
            # improved MLP # RVT cvpr2022
            self.net = nn.Sequential(
                to_channels_first,
                nn.Conv3d(in_feats, hid_feats, kernel_size=1, bias=bias),
                nn.BatchNorm3d(hid_feats),
                get_activation(act_name),
                nn.Dropout(drop),
                # groups=hid_feats makes this a depthwise (per-channel) conv.
                nn.Conv3d(hid_feats, hid_feats, kernel_size=kernel_size,
                    padding=int(kernel_size//2), groups=hid_feats, bias=bias),
                nn.BatchNorm3d(hid_feats),
                get_activation(act_name),
                nn.Conv3d(hid_feats, out_feats, kernel_size=1, bias=bias),
                nn.BatchNorm3d(out_feats),
                nn.Dropout(drop),
                to_channels_last,
            )

    def forward(self, x)-> torch.Tensor:
        """Apply the configured network to ``x``."""
        return self.net(x)

# Transformation

In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import Union



class SpatialTransformer(nn.Module):
    """
    N-D Spatial Transformer
    Obtained from https://github.com/voxelmorph/voxelmorph

    Warps a source image by a dense displacement field given in voxel units:
    the output at voxel ``p`` is the source sampled at ``p + flow[p]``.

    Args:
        size: Spatial size of the images/flows, e.g. ``(H, W, D)``.
        mode: Interpolation mode passed to ``F.grid_sample``.
    """
    def __init__(self, size, mode='bilinear'):
        super().__init__()

        self.mode = mode

        # create sampling grid: one integer coordinate channel per spatial axis
        vectors = [torch.arange(0, s) for s in size]
        # indexing='ij' is the matrix-style ordering the permutation in
        # forward() relies on; passing it explicitly also avoids the
        # torch.meshgrid deprecation warning.
        grids = torch.meshgrid(*vectors, indexing='ij')
        grid = torch.stack(grids).unsqueeze(0).float()

        # registering the grid as a buffer cleanly moves it to the GPU, but it also
        # adds it to the state dict. this is annoying since everything in the state dict
        # is included when saving weights to disk, so the model files are way bigger
        # than they need to be. so far, there does not appear to be an elegant solution.
        # see: https://discuss.pytorch.org/t/how-to-register-buffer-without-polluting-state-dict
        self.register_buffer('grid', grid)

    def forward(self, src, flow):
        """Warp ``src`` with ``flow``.

        Args:
            src: Source tensor of shape (B, C, *size).
            flow: Displacement in voxels, shape (B, ndim, *size).

        Returns:
            Tensor: The warped source, same spatial size as ``flow``.
        """
        # new locations (absolute voxel coordinates to sample from)
        new_locs = self.grid + flow
        shape = flow.shape[2:]

        # normalize grid values to [-1, 1] for the resampler: voxel 0 maps to
        # -1 and voxel (size - 1) maps to +1.
        for axis, extent in enumerate(shape):
            new_locs[:, axis, ...] = 2 * (new_locs[:, axis, ...] / (extent - 1) - 0.5)

        # move channels dim to last position; grid_sample expects the
        # coordinates ordered (x, y[, z]), i.e. last spatial axis first,
        # hence the channel reversal.
        if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)
            new_locs = new_locs[..., [1, 0]]
        elif len(shape) == 3:
            new_locs = new_locs.permute(0, 2, 3, 4, 1)
            new_locs = new_locs[..., [2, 1, 0]]

        # align_corners=True matches the (size - 1) normalization above, where
        # -1/+1 refer to the *centres* of the corner voxels. With
        # align_corners=False a zero flow would not reproduce the input.
        return F.grid_sample(src, new_locs, align_corners=True, mode=self.mode)


def normalize_displacement(displacement: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
    """
    Spatially normalize the displacement vector field to the [-1, 1] coordinate system utilized by PyTorch's `grid_sample()` function.

    Channel ``i`` of the field is multiplied by ``2 / size_i``, where
    ``size_i`` is the extent of the i-th spatial axis.

    This function assumes that the displacement field size is identical to the corresponding image dimensions.

    Args:
        displacement (Union[np.ndarray, torch.Tensor]): The displacement field with shape (N, ndim, *size).

    Returns:
        Union[np.ndarray, torch.Tensor]: The normalized displacement field.

    Raises:
        TypeError: If the input type is neither numpy.ndarray nor torch.Tensor.
    """
    # The type check comes first: reading `.ndim` from an unsupported object
    # would raise AttributeError instead of the documented TypeError.
    if isinstance(displacement, np.ndarray):
        number_of_dimensions = displacement.ndim - 2
        normalization_factors = 2.0 / np.array(displacement.shape[2:])
        # Shape (1, ndim, 1, ..., 1) so the factors broadcast per channel.
        normalization_factors = normalization_factors.reshape(1, number_of_dimensions, *(1,) * number_of_dimensions)

    elif isinstance(displacement, torch.Tensor):
        number_of_dimensions = displacement.ndim - 2
        normalization_factors = torch.tensor(2.0) / torch.tensor(
            displacement.size()[2:], dtype=displacement.dtype, device=displacement.device)
        normalization_factors = normalization_factors.view(1, number_of_dimensions, *(1,) * number_of_dimensions)

    else:
        raise TypeError("Input data type not recognized. Expected numpy.ndarray or torch.Tensor.")

    return displacement * normalization_factors


# src/model/hvit_light.py

In [None]:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Tuple, Dict, Any, List, Set, Optional, Union, Callable, Type
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor
from einops import rearrange
import lightning as L

from src.model.blocks import *
from src.model.transformation import *


# When True, Attention modules of type "local" create no projections and
# return their input unchanged, and ViT builds one level fewer.
WO_SELF_ATT = False # without self attention
# Not referenced in the visible part of this file; presumably the default for
# the NUM_CROSS_ATT argument of ViT (-1 = no limit) — TODO confirm.
_NUM_CROSS_ATT = -1
# Number of spatial dimensions; selects N-d layer classes such as Conv{ndims}d.
ndims = 3 # H,W,D

class Attention(nn.Module):
    """
    Attention module for hierarchical vision transformer.

    This module implements both local and global attention mechanisms:

    * "local": queries, keys and values are all projected from the input
      tokens (self-attention).
    * "global": keys and values are projected from the input tokens, while the
      queries are supplied by the caller (``q_ms``).
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        patch_size: Union[int, List[int]],
        attention_type: str = "local",
        qkv_bias: bool = True,
        qk_scale: Optional[float] = None,
        attn_drop: float = 0.,
        proj_drop: float = 0.
    ) -> None:
        """
        Initialize the Attention module.

        Args:
            dim (int): Input dimension.
            num_heads (int): Number of attention heads.
            patch_size (Union[int, List[int]]): Size of the patches.
            attention_type (str): Type of attention mechanism ("local" or "global").
            qkv_bias (bool): Whether to use bias in query, key, value projections.
            qk_scale (Optional[float]): Scale factor for query-key dot product.
            attn_drop (float): Dropout rate for attention matrix.
            proj_drop (float): Dropout rate for output projection.

        Raises:
            ValueError: If ``attention_type`` is neither "local" nor "global".
        """
        super().__init__()

        self.dim: int = dim
        self.num_heads: int = num_heads
        self.patch_size: List[int] = [patch_size] * ndims if isinstance(patch_size, int) else patch_size
        self.attention_type: str = attention_type

        assert dim % num_heads == 0, "Dimension must be divisible by number of heads"
        self.head_dim: int = dim // num_heads
        # Falls back to 1/sqrt(head_dim) when qk_scale is None (or 0).
        self.scale: float = qk_scale or self.head_dim ** -0.5

        # Skip initialization if using local attention without self-attention
        if self.attention_type == "local" and WO_SELF_ATT:
            return

        # Initialize query, key, value projections based on attention type
        if attention_type == "local":
            # One projection producing q, k and v.
            self.qkv: nn.Linear = nn.Linear(dim, dim * 3, bias=qkv_bias)
        elif attention_type == "global":
            # Only k and v are projected; q is supplied at call time.
            self.qkv: nn.Linear = nn.Linear(dim, dim * 2, bias=qkv_bias)
        else:
            # Fail at construction instead of with a missing-attribute error
            # on the first forward pass.
            raise ValueError(
                f"Unknown attention_type {attention_type!r}; expected 'local' or 'global'."
            )

        self.attn_drop: nn.Dropout = nn.Dropout(attn_drop)
        self.proj: nn.Linear = nn.Linear(dim, dim)
        self.proj_drop: nn.Dropout = nn.Dropout(proj_drop)

    def forward(self, x: Tensor, q_ms: Optional[Tensor] = None) -> Tensor:
        """
        Forward pass of the Attention module.

        Args:
            x (Tensor): Input tokens of shape (B_, N, C).
            q_ms (Optional[Tensor]): Query tensor for global attention.

        Returns:
            Tensor: Output tensor of shape (B_, N, C) after applying attention.

        Raises:
            ValueError: If global attention is used without ``q_ms``.
        """
        B_, N, C = x.size()

        # Return input if using local attention without self-attention
        if self.attention_type == "local" and WO_SELF_ATT:
            return x

        head_dim = C // self.num_heads
        if self.attention_type == "local":
            # (B_, N, 3C) -> (3, B_, heads, N, head_dim)
            qkv: Tensor = self.qkv(x).reshape(B_, N, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]
            q = q * self.scale
        else:
            if q_ms is None:
                raise ValueError("Global attention requires the query tensor `q_ms`.")
            B: int = q_ms.size()[0]
            # (B_, N, 2C) -> (2, B_, heads, N, head_dim)
            kv: Tensor = self.qkv(x).reshape(B_, N, 2, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
            k, v = kv[0], kv[1]
            q: Tensor = self._process_global_query(q_ms, B, B_, N, C)

        # Compute attention scores and apply attention
        attn: Tensor = (q @ k.transpose(-2, -1))
        attn = F.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)

        # (B_, heads, N, head_dim) -> (B_, N, C)
        x: Tensor = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x

    def _process_global_query(self, q_ms: Tensor, B: int, B_: int, N: int, C: int) -> Tensor:
        """
        Process the global query tensor.

        The query batch is tiled cyclically along the batch axis until it
        covers ``B_`` entries (entry ``j`` of the result is query ``j % B``),
        then scaled by ``self.scale``. The result keeps the dtype and device
        of ``q_ms``.

        Args:
            q_ms (Tensor): Global query tensor.
            B (int): Batch size of q_ms.
            B_ (int): Batch size of input tensor.
            N (int): Number of patches.
            C (int): Channel dimension.

        Returns:
            Tensor: Processed global query tensor of shape
            (B_, num_heads, N, C // num_heads).
        """
        # NOTE(review): the query is reshaped directly to (B, heads, N, head_dim)
        # without the reshape+permute used for keys/values in forward() —
        # confirm this layout is intended.
        q_heads: Tensor = q_ms.reshape(B, self.num_heads, N, C // self.num_heads)

        # ceil(B_ / B) copies are enough to cover B_ rows; slicing trims the
        # surplus. Unlike filling a pre-allocated float32 buffer, this keeps
        # q_ms's dtype and also works when B_ < B.
        copies = -(-B_ // B)
        q: Tensor = q_heads.repeat(copies, 1, 1, 1)[:B_]

        return q * self.scale


def get_patches(x: Tensor, patch_size: int) -> Tuple[Tensor, int, int, int]:
    """
    Split a channels-last volume into non-overlapping cubic patches.

    If any spatial extent is not a multiple of ``patch_size``, the volume is
    first resized (via ``downsampler_fn``) to the largest multiples that fit.

    Args:
        x (Tensor): Input tensor of shape (B, H, W, D, C).
        patch_size (int): Edge length of each cubic patch.

    Returns:
        Tuple[Tensor, int, int, int]: A tuple containing:
            - windows: Tensor of shape (num_patches, patch_size, patch_size, patch_size, C).
            - H, W, D: Spatial size actually used for patching (after any resize).
    """
    B, H, W, D, C = x.size()

    # (full patches, leftover voxels) along each spatial axis
    splits = [divmod(extent, patch_size) for extent in (H, W, D)]

    if any(leftover for _, leftover in splits):
        # Resize so every axis holds a whole number of patches.
        fitted_dims: List[int] = [count * patch_size for count, _ in splits]
        channels_first = x.permute(0, 4, 1, 2, 3)
        x = downsampler_fn(channels_first, fitted_dims).permute(0, 2, 3, 4, 1)
        B, H, W, D, C = x.size()

    # Split each spatial axis into (patch index, offset inside the patch).
    grid_h, grid_w, grid_d = H // patch_size, W // patch_size, D // patch_size
    x = x.view(B, grid_h, patch_size, grid_w, patch_size, grid_d, patch_size, C)

    # Bring the three patch indices to the front, then merge them with the batch axis.
    regrouped = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous()
    windows: Tensor = regrouped.view(-1, patch_size, patch_size, patch_size, C)

    return windows, H, W, D


def get_image(windows: Tensor, patch_size: int, Hatt: int, Watt: int, Datt: int, H: int, W: int, D: int) -> Tensor:
    """
    Reassemble cubic patches into a channels-last volume.

    Args:
        windows (Tensor): Patches; the first axis enumerates patches over batch
            and patch grid, the last axis is inferred as the channel axis.
        patch_size (int): Edge length of each cubic patch.
        Hatt, Watt, Datt (int): Spatial size the patches were cut from.
        H, W, D (int): Spatial size the result should have.

    Returns:
        Tensor: Volume of shape (B, H, W, D, C). If (H, W, D) differs from
        (Hatt, Watt, Datt) the volume is resized with ``downsampler_fn``.
    """
    # Number of patches per sample -> batch size
    patches_per_sample = (Hatt * Watt * Datt) // (patch_size ** 3)
    B: int = int(windows.size(0) / patches_per_sample)

    # Unflatten to (B, grid_h, grid_w, grid_d, p, p, p, C)
    patch_grid = (Hatt // patch_size, Watt // patch_size, Datt // patch_size)
    x: Tensor = windows.view(B, *patch_grid, patch_size, patch_size, patch_size, -1)

    # Interleave each grid index with its in-patch offset, then merge the pairs.
    x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous()
    x = x.view(B, Hatt, Watt, Datt, -1)

    # Resize back if patching happened at a different spatial size
    if (H, W, D) != (Hatt, Watt, Datt):
        channels_first = x.permute(0, 4, 1, 2, 3)
        x = downsampler_fn(channels_first, [H, W, D]).permute(0, 2, 3, 4, 1)
    return x

class ViTBlock(nn.Module):
    """
    Single transformer block: windowed attention followed by an MLP.

    Both sub-layers are pre-normalised and wrapped in a residual connection
    with optional layer scale (``gamma1``/``gamma2``) and drop path.
    """
    def __init__(self,
                 embed_dim: int,
                 input_dims: List[int],
                 num_heads: int,
                 mlp_type: str,
                 patch_size: int,
                 mlp_ratio: float,
                 qkv_bias: bool,
                 qk_scale: Optional[float],
                 drop: float,
                 attn_drop: float,
                 drop_path: float,
                 act_layer: str,
                 attention_type: str,
                 norm_layer: Callable[..., nn.Module],
                 layer_scale: Optional[float]):
        super().__init__()
        self.patch_size: int = patch_size
        # Number of patches that fit into the given input size.
        patches_per_axis = [extent // patch_size for extent in input_dims]
        self.num_windows: int = prod_func(patches_per_axis)

        # --- attention branch ---
        self.norm1: nn.Module = norm_layer(embed_dim)
        attention_kwargs = dict(
            attention_type=attention_type,
            num_heads=num_heads,
            patch_size=patch_size,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.attn: nn.Module = Attention(embed_dim, **attention_kwargs)

        if drop_path > 0.:
            self.drop_path: nn.Module = timm_DropPath(drop_path)
        else:
            self.drop_path: nn.Module = nn.Identity()

        # --- MLP branch ---
        self.norm2: nn.Module = norm_layer(embed_dim)
        hidden_width = int(embed_dim * mlp_ratio)
        self.mlp: nn.Module = MLP(
            in_feats=embed_dim,
            hid_feats=hidden_width,
            act_name=act_layer,
            drop=drop,
            MLP_type=mlp_type
        )

        # Layer scale is enabled only for a numeric `layer_scale`.
        self.layer_scale: bool = isinstance(layer_scale, (int, float))
        if not self.layer_scale:
            self.gamma1: float = 1.0
            self.gamma2: float = 1.0
        else:
            self.gamma1: nn.Parameter = nn.Parameter(layer_scale * torch.ones(embed_dim), requires_grad=True)
            self.gamma2: nn.Parameter = nn.Parameter(layer_scale * torch.ones(embed_dim), requires_grad=True)

    def forward(self, x: Tensor, q_ms: Optional[Tensor]) -> Tensor:
        """Run the block on a channels-last volume.

        Args:
            x: Tensor of shape (B, H, W, D, C).
            q_ms: Optional query tokens forwarded to the attention module.

        Returns:
            Tensor with the same shape as ``x``.
        """
        _, H, W, D, C = x.size()
        residual: Tensor = x

        # Attention operates on flattened patches of patch_size**3 tokens.
        normed = self.norm1(x)
        patches, Hatt, Watt, Datt = get_patches(normed, self.patch_size)
        tokens = patches.view(-1, self.patch_size ** 3, C)
        attended: Tensor = self.attn(tokens, q_ms)

        # Back to a volume of the original spatial size, then first residual.
        volume = get_image(attended, self.patch_size, Hatt, Watt, Datt, H, W, D)
        x = residual + self.drop_path(self.gamma1 * volume)

        # Second residual around the MLP.
        mlp_out = self.mlp(self.norm2(x))
        x = x + self.drop_path(self.gamma2 * mlp_out)
        return x

class PatchEmbed(nn.Module):
    """
    Embedding layer: an N-d convolution (N = ``ndims``) followed by dropout.

    With the default kernel size 3, stride 1 and padding 1 the spatial size of
    the input is preserved and only the channel count changes.
    """
    def __init__(self, in_chans: int = 3, out_chans: int = 32,
                 drop_rate: float = 0,
                 kernel_size: int = 3,
                 stride: int = 1, padding: int = 1,
                 dilation: int = 1, groups: int = 1, bias: bool = False) -> None:
        super().__init__()

        # Resolve nn.Conv1d / nn.Conv2d / nn.Conv3d from the module-level ndims.
        conv_cls: Type[nn.Module] = getattr(nn, f"Conv{ndims}d")
        conv_options = dict(
            in_channels=in_chans,
            out_channels=out_chans,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        self.proj: nn.Module = conv_cls(**conv_options)

        self.drop: nn.Module = nn.Dropout(p=drop_rate)

    def forward(self, x: Tensor) -> Tensor:
        """Project ``x`` with the convolution, then apply dropout."""
        projected = self.proj(x)
        return self.drop(projected)

class ViTLayer(nn.Module):
    """
    Vision Transformer Layer: a stack of ``depth`` ViTBlocks applied to a
    channels-first volume, with an outer residual (or concatenation).

    ``dim_out`` and ``norm_type`` are accepted but not used by this class.
    """
    def __init__(
        self,
        attention_type: str,
        dim: int,
        dim_out: int,
        depth: int,
        input_dims: List[int],
        num_heads: int,
        patch_size: int,
        mlp_type: str,
        mlp_ratio: float,
        qkv_bias: bool,
        qk_scale: Optional[float],
        drop: float,
        attn_drop: float,
        drop_path: Union[float, List[float]],
        norm_layer: Callable[..., nn.Module],
        norm_type: str,
        layer_scale: Optional[float],
        act_layer: str
    ) -> None:
        super().__init__()
        self.patch_size: int = patch_size
        self.embed_dim: int = dim
        self.input_dims: List[int] = input_dims
        self.blocks: nn.ModuleList = nn.ModuleList([
            ViTBlock(
                embed_dim=dim,
                num_heads=num_heads,
                mlp_type=mlp_type,
                patch_size=patch_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                attention_type=attention_type,
                drop=drop,
                attn_drop=attn_drop,
                # A list gives one drop-path rate per block; a float is shared.
                drop_path=drop_path[k] if isinstance(drop_path, list) else drop_path,
                act_layer=act_layer,
                norm_layer=norm_layer,
                layer_scale=layer_scale,
                input_dims=input_dims
            )
            for k in range(depth)
        ])

    def forward(self, inp: Tensor, q_ms: Optional[Tensor], CONCAT_ok: bool) -> Tensor:
        """Apply all blocks and merge the result with the input.

        Args:
            inp: Channels-first input of shape (b, c, h, w, d).
            q_ms: Optional channels-first query volume; when given, it is cut
                into patches and passed to every block as attention queries.
            CONCAT_ok: If True the input and the block output are concatenated
                along the last axis; otherwise they are added.

        Returns:
            Tensor: The merged result in channels-first layout.
        """
        x: Tensor = inp.clone()
        x = rearrange(x, 'b c h w d -> b h w d c')

        # The query patches do not depend on the block, so they are computed
        # once here instead of once per block inside the loop.
        q_ms_patches: Optional[Tensor] = None
        if q_ms is not None:
            q_ms = rearrange(q_ms, 'b c h w d -> b h w d c')
            q_ms_patches, _, _, _ = get_patches(q_ms, self.patch_size)
            q_ms_patches = q_ms_patches.view(-1, self.patch_size ** ndims, x.size()[-1])

        for blk in self.blocks:
            x = blk(x, q_ms_patches)

        x = rearrange(x, 'b h w d c -> b c h w d')

        if CONCAT_ok:
            # NOTE(review): dim=-1 is the last spatial axis of a (b, c, h, w, d)
            # tensor, not the channel axis — confirm this is intended.
            x = torch.cat((inp, x), dim=-1)
        else:
            x = inp + x
        return x


class ViT(nn.Module):
    """
    Vision Transformer (ViT) module for hierarchical feature processing.

    One ``PatchEmbed`` and one ``ViTLayer`` are created per level. Level 0
    uses "local" attention; every further level uses "global" attention, with
    the previous level's output acting as the query.
    """
    def __init__(self,
                 PYR_SCALES=None,
                 feats_num=None,
                 hid_dim=None,
                 depths=None,
                 patch_size=None,
                 mlp_ratio=None,
                 num_heads=None,
                 mlp_type=None,
                 norm_type=None,
                 act_layer=None,
                 drop_path_rate: float = 0.2,
                 qkv_bias: bool = True,
                 qk_scale: Optional[float] = None,
                 drop_rate: float = 0.,
                 attn_drop_rate: float = 0.,
                 norm_layer=nn.LayerNorm,
                 layer_scale=None,
                 img_size=None,
                 NUM_CROSS_ATT=-1):
        """
        Args:
            PYR_SCALES: Accepted but not used by this class.
            feats_num: Input channel count per level; its length sets the
                maximum number of levels.
            hid_dim: Embedding width shared by all levels.
            depths: Number of blocks per level.
            patch_size: Patch size per level (a single value is reused).
            mlp_ratio: Hidden-width multiplier of the block MLPs.
            num_heads: Attention heads per level.
            mlp_type: MLP variant name forwarded to the blocks.
            norm_type: Forwarded to ``ViTLayer``.
            act_layer: Activation name forwarded to the blocks.
            drop_path_rate: Maximum drop-path rate; rates grow linearly from 0
                over all blocks.
            qkv_bias: Whether the attention projections carry a bias.
            qk_scale: Optional override of the attention scale.
            drop_rate: Dropout rate for embeddings, projections and MLPs.
            attn_drop_rate: Dropout rate on the attention weights.
            norm_layer: Normalisation layer constructor for the blocks.
            layer_scale: Optional initial layer-scale value.
            img_size: Spatial size per level, indexed by level.
            NUM_CROSS_ATT: If positive, caps the number of levels.
        """
        super().__init__()

        # Determine the number of levels for processing
        num_levels = len(feats_num)
        num_levels = min(num_levels, NUM_CROSS_ATT) if NUM_CROSS_ATT > 0 else num_levels
        if WO_SELF_ATT:
            num_levels -= 1

        # Ensure patch_size is a list
        patch_size = patch_size if isinstance(patch_size, list) else [patch_size for _ in range(num_levels)]

        # Create patch embedding layers
        self.patch_embed = nn.ModuleList([
            PatchEmbed(
                in_chans=feats_num[i],
                out_chans=hid_dim,
                drop_rate=drop_rate
            ) for i in range(num_levels)
        ])

        # Drop-path rate per block, increasing linearly over all blocks
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # Create ViT layers
        self.levels = nn.ModuleList()
        for i in range(num_levels):
            level = ViTLayer(
                dim=hid_dim,
                dim_out=hid_dim,
                depth=depths[i],
                num_heads=num_heads[i],
                patch_size=patch_size[i],
                mlp_type=mlp_type,
                attention_type="local" if i == 0 else "global",
                # Slice of the global schedule that belongs to this level.
                drop_path=dpr[sum(depths[:i]):sum(depths[:i+1])],
                input_dims=img_size[i],
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                norm_layer=norm_layer,
                layer_scale=layer_scale,
                norm_type=norm_type,
                act_layer=act_layer
            )
            self.levels.append(level)

    def _init_weights(self, m):
        """Initialize the weights of the module.

        Linear weights get a truncated normal (std 0.02) and zero bias;
        LayerNorm gets unit weight and zero bias. This method is not applied
        anywhere in this class; a caller has to use ``module.apply`` itself.
        """
        if isinstance(m, nn.Linear):
            timm_trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        """Return keywords for no weight decay."""
        return {'rpb'}

    def forward(self, KQs, CONCAT_ok: bool = False):
        """
        Forward pass of the ViT module.

        Args:
            KQs (List[Tensor]): List of input tensors for each level.
            CONCAT_ok (bool): Flag to determine if concatenation is allowed.

        Returns:
            Tensor: Processed output tensor of the last level.
        """
        for i, (patch_embed_, level) in enumerate(zip(self.patch_embed, self.levels)):
            if i == 0:
                # First level: process input without cross-attention
                Q = patch_embed_(KQs[i])
                x = level(Q, None, CONCAT_ok=CONCAT_ok)
                # The level output is passed through the same embedding again
                # to form the query for the next level.
                Q = patch_embed_(x)
            else:
                # Subsequent levels: process with cross-attention. The running
                # Q is the layer input; the embedded KQs[i] is passed as q_ms.
                K = patch_embed_(KQs[i])
                x = level(Q, K, CONCAT_ok=CONCAT_ok)
                Q = x.clone()

        return x


class EncoderCnnBlock(nn.Module):
    """
    Encoder building block: two Conv3d -> InstanceNorm3d -> ReLU stages.

    Only the first convolution uses ``stride``; the second always has a
    stride of 1. Both stages share ``kernel_size``, ``padding`` and ``bias``,
    and both normalisation layers share ``affine`` and ``eps``.
    """
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding=1,
        bias=False,
        affine=True,
        eps=1e-05
    ):
        super().__init__()

        # (input channels, stride) of each stage, in execution order.
        stage_specs = ((in_channels, stride), (out_channels, 1))

        layers = []
        for stage_in, stage_stride in stage_specs:
            layers.append(
                nn.Conv3d(
                    in_channels=stage_in, out_channels=out_channels,
                    kernel_size=kernel_size, stride=stage_stride,
                    padding=padding, bias=bias
                )
            )
            layers.append(nn.InstanceNorm3d(num_features=out_channels, affine=affine, eps=eps))
            layers.append(nn.ReLU(inplace=True))

        self._block = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both convolutional stages to ``x``."""
        return self._block(x)




class Decoder(nn.Module):
    """
    Feature-pyramid style decoder with per-scale hierarchical ViT modules.

    The decoder is built from four groups of layers:

    * ``_lateral``: 1x1x1 convolutions applied to the encoder feature maps
      from the earliest required stage onwards.
    * ``_up``: transposed convolutions forming the top-down path.
    * ``_out``: 3x3x3 convolutions producing one map per required stage.
    * ``hierarchical_dec``: one entry per output map; the first is an
      ``nn.Identity`` and the others are ``ViT`` modules.

    When ``config['use_seg_loss']`` is true, a 1x1x1 segmentation head is
    added for every output map.

    Config keys read here: ``num_stages``, ``use_seg_loss``,
    ``start_channels``, ``out_fmaps``, ``fpn_channels``, ``strides``,
    ``data_size`` and, if segmentation is enabled, ``num_organs``. The ViT
    hyper-parameters are read with defaults in
    ``_create_hierarchical_layers``.
    """
    def __init__(self, config):
        super().__init__()
        self._num_stages: int = config['num_stages']
        self.use_seg: bool = config['use_seg_loss']
        fpn_channels: int = int(config['fpn_channels'])

        # Channels of encoder feature maps: doubled at every stage. Plain ints
        # are used so the layers below are configured with ints, not tensors.
        encoder_out_channels: List[int] = [
            int(config['start_channels'] * 2**stage) for stage in range(self._num_stages)
        ]

        # Stages requested by the caller, parsed from the last character of
        # each feature-map name (e.g. 'P3' -> 3).
        required_stages: Set[int] = set(int(fmap[-1]) for fmap in config['out_fmaps'])
        self._required_stages: Set[int] = required_stages
        earliest_required_stage: int = min(required_stages)
        # Sets are unordered; the layers below are indexed by position, so the
        # stage order is fixed explicitly (ascending).
        stage_order: List[int] = sorted(required_stages)

        # Lateral connections
        lateral_in_channels: List[int] = encoder_out_channels[earliest_required_stage:]
        lateral_out_channels: List[int] = [min(ch, fpn_channels) for ch in lateral_in_channels]

        self._lateral: nn.ModuleList = nn.ModuleList([
            nn.Conv3d(in_channels=in_ch, out_channels=out_ch, kernel_size=1)
            for in_ch, out_ch in zip(lateral_in_channels, lateral_out_channels)
        ])
        self._lateral_levels: int = len(self._lateral)

        # Output layers
        out_in_channels: List[int] = [
            lateral_out_channels[-self._num_stages + required_stage]
            for required_stage in stage_order
        ]
        out_out_channels: List[int] = [fpn_channels] * len(out_in_channels)

        self._out: nn.ModuleList = nn.ModuleList([
            nn.Conv3d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, padding=1)
            for in_ch, out_ch in zip(out_in_channels, out_out_channels)
        ])

        # Upsampling layers, ordered from the deepest level upwards. Kernel
        # size and stride both follow the reversed encoder strides.
        top_down_channels: List[int] = lateral_out_channels[::-1]
        top_down_strides: List[int] = list(reversed(config['strides']))
        self._up: nn.ModuleList = nn.ModuleList([
            nn.ConvTranspose3d(
                in_channels=top_down_channels[level],
                out_channels=top_down_channels[level + 1],
                kernel_size=top_down_strides[level],
                stride=top_down_strides[level]
            )
            for level in range(len(lateral_out_channels) - 1)
        ])

        # Multi-scale attention
        self.hierarchical_dec: nn.ModuleList = self._create_hierarchical_layers(config, out_out_channels)

        if self.use_seg:
            self._seg_head: nn.ModuleList = nn.ModuleList([
                nn.Conv3d(out_ch, config['num_organs'] + 1, kernel_size=1, stride=1)
                for out_ch in out_out_channels
            ])

    def _create_hierarchical_layers(self, config: Dict[str, Any], out_out_channels: List[int]) -> nn.ModuleList:
        """
        Build one hierarchical module per output feature map.

        Entry ``k`` is created for a spatial size of
        ``data_size / 2**(num_stages - k - 1)``. Entry 0 is an ``nn.Identity``;
        entry ``k > 0`` is a ``ViT`` configured with the ``k + 1`` channel
        counts and image sizes collected so far.

        Args:
            config: Model configuration; ViT hyper-parameters fall back to the
                defaults given in the ``config.get`` calls below.
            out_out_channels: Channel count of every decoder output map.

        Returns:
            nn.ModuleList: The modules, in the order described above.
        """
        out: nn.ModuleList = nn.ModuleList()
        img_size: List[List[int]] = []
        feats_num: List[int] = []

        for k, out_ch in enumerate(out_out_channels):
            img_size.append([int(item/(2**(self._num_stages-k-1))) for item in config['data_size']])
            feats_num.append(out_ch)
            n: int = len(feats_num)

            if k == 0:
                out.append(nn.Identity())
            else:
                out.append(
                    ViT(
                        NUM_CROSS_ATT=config.get('NUM_CROSS_ATT', _NUM_CROSS_ATT),
                        PYR_SCALES=[1.],
                        feats_num=feats_num,
                        hid_dim=int(config.get('fpn_channels', 64)),
                        depths=[int(config.get('depths', 1))]*n,
                        patch_size=[int(config.get('patch_size', 2))]*n,
                        mlp_ratio=int(config.get('mlp_ratio', 2)),
                        num_heads=[int(config.get('num_heads', 32))]*n,
                        mlp_type='basic',
                        norm_type='BatchNorm2d',
                        act_layer='gelu',
                        drop_path_rate=config.get('drop_path_rate', 0.2),
                        qkv_bias=config.get('qkv_bias', True),
                        qk_scale=None,
                        drop_rate=config.get('drop_rate', 0.),
                        attn_drop_rate=config.get('attn_drop_rate', 0.),
                        norm_layer=nn.LayerNorm,
                        layer_scale=1e-5,
                        img_size=img_size
                    )
                )
        return out

    def forward(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """
        Decode encoder feature maps into the requested output maps.

        Args:
            x: Encoder feature maps. Only the last ``_lateral_levels`` values
                (in dict order) are used.

        Returns:
            Dict[str, Tensor]: ``'P<stage>'`` feature maps and, when the
            segmentation heads are enabled, ``'S<stage>'`` maps.
        """
        encoder_maps: List[Tensor] = list(x.values())[-self._lateral_levels:]
        lateral_out: List[Tensor] = [
            lateral(fmap) for lateral, fmap in zip(self._lateral, encoder_maps)
        ]

        # Top-down path: start at the deepest map and add the upsampled result
        # of the previous (deeper) level to every following one.
        up_out: List[Tensor] = []
        up = None
        for idx, fmap in enumerate(reversed(lateral_out)):
            if idx != 0:
                fmap = fmap + up

            if idx < self._lateral_levels - 1:
                up = self._up[idx](fmap)

            up_out.append(fmap)

        # ``reversed(up_out)`` runs from the shallowest to the deepest map, so
        # it is paired with the stages in ascending order.
        stage_order: List[int] = sorted(self._required_stages)
        cnn_outputs: Dict[int, Tensor] = {
            stage: self._out[idx](fmap)
            for idx, (fmap, stage) in enumerate(zip(reversed(up_out), stage_order))
        }
        return self._forward_hierarchical(cnn_outputs)

    def _forward_hierarchical(self, cnn_outputs: Dict[int, Tensor]) -> Dict[str, Tensor]:
        """
        Apply the hierarchical modules from the highest stage key downwards.

        For the ``i``-th key (counting from the highest), module ``i`` receives
        the maps of the current stage followed by those of all higher stages.
        The first key is passed through unchanged.

        Args:
            cnn_outputs: Output maps keyed by stage; keys are expected to be
                contiguous integers.

        Returns:
            Dict[str, Tensor]: ``'P<key>'`` entries plus ``'S<key>'`` entries
            when segmentation heads are enabled.
        """
        keys_desc = range(max(cnn_outputs.keys()), min(cnn_outputs.keys()) - 1, -1)
        xs: List[Tensor] = [cnn_outputs[key].clone() for key in keys_desc]

        out_dict: Dict[str, Tensor] = {}
        for i, key in enumerate(keys_desc):
            # Current stage first, then progressively higher stages.
            QK = xs[0:i+1]
            QK.reverse()

            if i == 0:
                Pi = QK[0]
            else:
                Pi = self.hierarchical_dec[i](QK)
            out_dict[f'P{key}'] = Pi

            if self.use_seg:
                Pi_seg = self._seg_head[i](Pi)
                out_dict[f'S{key}'] = Pi_seg
        return out_dict








class HierarchicalViT(nn.Module):
    """
    Hierarchical Vision Transformer (HViT) for image processing tasks.

    For the ``'fpn'``/``'FPN'`` backbone the model is a stack of
    ``EncoderCnnBlock`` stages followed by a ``Decoder``. No layers are built
    for any other backbone name, and ``forward`` rejects such a model.

    Note:
        The constructor writes ``'num_stages'`` and ``'strides'`` into the
        ``config`` dict it is given; the decoder reads them from there.
    """
    def __init__(self, config: Dict[str, Any]):
        super().__init__()

        # Configuration parameters
        self.backbone: str = config['backbone_net']
        in_channels: int = 2 * config.get('in_channels', 1)  # source + target
        kernel_size: int = config.get('kernel_size', 3)
        emb_dim: int = config.get('start_channels', 32)
        data_size: Tuple[int, ...] = config.get('data_size', [160, 192, 224])
        self.out_fmaps: List[str] = config.get('out_fmaps', ['P4', 'P3', 'P2', 'P1'])

        # Number of stages: bounded by the smallest spatial extent and by the
        # deepest requested feature map (stage index is the last character).
        num_stages: int = min(int(math.log2(min(data_size))) - 1,
                              max(int(fmap[-1]) for fmap in self.out_fmaps) + 1)

        # The first stage keeps the resolution; every later one halves it.
        strides: List[int] = [1] + [2] * (num_stages - 1)
        kernel_sizes: List[int] = [kernel_size] * num_stages

        config['num_stages'] = num_stages
        config['strides'] = strides

        self._encoder: nn.ModuleList = nn.ModuleList()
        if self.backbone in ['fpn', 'FPN']:
            # Build encoder: channels double at every stage.
            for k in range(num_stages):
                blk = EncoderCnnBlock(
                    in_channels=in_channels,
                    out_channels=emb_dim,
                    kernel_size=kernel_sizes[k],
                    stride=strides[k],
                    # Keep the padding in step with the kernel so stride-1
                    # convs preserve the spatial size (1 for the default
                    # kernel size of 3, as before).
                    padding=kernel_sizes[k] // 2
                )
                self._encoder.append(blk)

                in_channels = emb_dim
                emb_dim *= 2

            # Build decoder
            self._decoder: Decoder = Decoder(config)

    def init_weights(self) -> None:
        """
        Initialize model weights.
        """
        # TODO: Implement weight initialization

    def forward(self, x: Tensor, verbose: bool = False) -> Dict[str, Tensor]:
        """
        Forward pass of the HierarchicalViT model.

        Args:
            x (Tensor): Input tensor.
            verbose (bool): If True, print shape information.

        Returns:
            Dict[str, Tensor]: Output feature maps.

        Raises:
            ValueError: If the model was created with a backbone other than
                ``'fpn'``/``'FPN'``, for which no layers exist.
        """
        if self.backbone not in ['fpn', 'FPN']:
            # Previously this path ended in an UnboundLocalError on ``up``.
            raise ValueError(
                f"Unsupported backbone {self.backbone!r}; expected 'fpn' or 'FPN'."
            )

        # Encoder outputs are collected as 'C<stage>' in execution order.
        down: Dict[str, Tensor] = {}
        for stage_id, module in enumerate(self._encoder):
            x = module(x)
            down[f'C{stage_id}'] = x
        up = self._decoder(down)

        if verbose:
            for key, item in down.items():
                print(f'down {key}', item.shape)
            for key, item in up.items():
                print(f'up {key}', item.shape)
        return up


class RegistrationHead(nn.Sequential):
    """
    Single-convolution head that maps features to a displacement field.

    The convolution is registered under the name ``conv3d``. Its weights are
    re-drawn from a normal distribution with mean 0 and std 1e-5 and its bias
    is set to zero, so the initial output is close to zero.

    Args:
        in_channels (int): Channels of the incoming feature map.
        out_channels (int): Channels of the produced field.
        kernel_size (int): Convolution kernel size; padding is
            ``kernel_size // 2``.
    """
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
        super().__init__()
        flow_conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2
        )

        # Replace the default initialisation with near-zero weights and an
        # exactly-zero bias.
        tiny_weight = torch.zeros_like(flow_conv.weight).normal_(0, 1e-5)
        zero_bias = torch.zeros(flow_conv.bias.shape)
        flow_conv.weight = nn.Parameter(tiny_weight)
        flow_conv.bias = nn.Parameter(zero_bias)

        self.add_module('conv3d', flow_conv)


class HViT_Light(nn.Module):
    """
    Light HViT registration model.

    A ``HierarchicalViT`` produces multi-scale features from the concatenated
    source/target pair, a ``RegistrationHead`` turns one of those scales into
    a flow field, and a ``SpatialTransformer`` applies the flow to the source.

    Config keys read here: ``upsample_df`` (default False),
    ``upsample_scale_factor`` (default 2), ``scale_level_df`` (default
    ``'P1'``), ``data_size`` (required) and ``fpn_channels`` (default 64).
    """
    def __init__(self, config: dict):
        super(HViT_Light, self).__init__()
        # Flow post-processing options.
        self.upsample_df = config.get('upsample_df', False)
        self.upsample_scale_factor = config.get('upsample_scale_factor', 2)
        # Name of the decoder output the flow is predicted from.
        self.scale_level_df = config.get('scale_level_df', 'P1')

        self.deformable = HierarchicalViT(config)
        self.avg_pool = nn.AvgPool3d(3, stride=2, padding=1)
        self.spatial_trans = SpatialTransformer(config['data_size'])

        head_in_channels = config.get('fpn_channels', 64)
        self.reg_head = RegistrationHead(
            in_channels=head_in_channels,
            out_channels=ndims,
            kernel_size=ndims,
        )

    def forward(self, source: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Register ``source`` onto ``target``.

        Args:
            source (Tensor): Source image tensor.
            target (Tensor): Target image tensor.

        Returns:
            Tuple[Tensor, Tensor]: The warped source image and the flow field
            used to warp it.
        """
        # Source and target are stacked along dim 1 before feature extraction.
        stacked = torch.cat((source, target), dim=1)
        pyramid = self.deformable(stacked)

        # Predict the flow from the configured scale level only.
        level_features = pyramid[self.scale_level_df]
        flow = self.reg_head(level_features)

        if self.upsample_df:
            upsampler = nn.Upsample(scale_factor=self.upsample_scale_factor,
                                    mode='trilinear',
                                    align_corners=False)
            flow = upsampler(flow)

        moved = self.spatial_trans(source, flow)
        return moved, flow


if __name__ == "__main__":
    # Smoke test: build HViT_Light, run one forward pass on random volumes and
    # report output shapes, parameter count and (on CUDA) memory usage.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    on_cuda = device.type == "cuda"
    H, W, D = 160//2, 192//2, 224//2

    for fpn_channels in [64]:
        config = {
            'NUM_CROSS_ATT': _NUM_CROSS_ATT,
            'out_fmaps': ['P4', 'P3', 'P2', 'P1'],
            'scale_level_df': 'P1',
            'upsample_df': True,
            'upsample_scale_factor': 2,
            'fpn_channels': fpn_channels,
            'start_channels': 32,
            'patch_size': 2,
            'bspl': False,

            'backbone_net': 'fpn',
            'in_channels': 1,
            'data_size': [H, W, D],
            'bias': True,
            'norm_type': 'instance',
            'cuda': 0,
            'kernel_size': 3,
            'depths': 1,
            'mlp_ratio': 2,

            'num_heads': 32,
            'drop_path_rate': 0.,
            'qkv_bias': True,
            'drop_rate': 0.,
            'attn_drop_rate': 0.,

            'use_seg_loss': False,
            'use_seg_proxy_loss': False,
            'num_organs': -1
        }

        source = torch.rand([1, 1, H, W, D])
        tgt = torch.rand([1, 1, H, W, D])

        model = HViT_Light(config)
        model.to(device)

        # Half-precision inputs and autocast are only used on CUDA; on a
        # CPU-only host the pass runs in float32 with autocast disabled.
        if on_cuda:
            source = source.to(dtype=torch.float16)
            tgt = tgt.to(dtype=torch.float16)
        source = source.to(device)
        tgt = tgt.to(device)

        autocast_dtype = torch.float16 if on_cuda else torch.bfloat16
        with torch.amp.autocast(device_type=device.type, dtype=autocast_dtype, enabled=on_cuda):
            moved, flow = model(source, tgt)

        print('\n\nmoved {} flow {}'.format(moved.shape, flow.shape))

        # CUDA memory statistics are unavailable without a GPU.
        if on_cuda:
            max_mem_gb = torch.cuda.max_memory_allocated(device) / 1024**3
            print("[+] Maximum memory:\t{:.2f}GB: >>> \t{:.0f} feats".format(max_mem_gb, config['fpn_channels']))
            print("[+] Required Total memory:\t{:.2f}GB".format(torch.cuda.get_device_properties(device).total_memory/1024**3))
        print("[+] Trainable params:\t{:.5f} m".format(count_parameters(model)/1e6))


# Hvit

In [None]:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Tuple, Dict, Any, List, Set, Optional, Union, Callable, Type
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor
from einops import rearrange
import lightning as L

from src.model.blocks import *
from src.model.transformation import *


# When True, "local" Attention modules build no layers and return their input
# unchanged, and ViT creates one level fewer.
WO_SELF_ATT = False # without self attention
# Default for the 'NUM_CROSS_ATT' config entry. ViT only caps its number of
# levels when the value is > 0, so -1 leaves the level count uncapped.
_NUM_CROSS_ATT = -1
# Number of spatial dimensions; selects nn.Conv{ndims}d in PatchEmbed and the
# tokens-per-patch count patch_size ** ndims.
ndims = 3 # H,W,D

class Attention(nn.Module):
    """
    Attention module for hierarchical vision transformer.

    Two attention types are supported:

    * ``"local"``: queries, keys and values are all projected from the input
      tokens (self-attention).
    * ``"global"``: keys and values are projected from the input tokens while
      the queries come from an externally supplied tensor ``q_ms`` that is
      reshaped into heads without a learned projection.

    When the module-level flag ``WO_SELF_ATT`` is set, a "local" instance
    builds no layers and its forward pass returns the input unchanged.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        patch_size: Union[int, List[int]],
        attention_type: str = "local",
        qkv_bias: bool = True,
        qk_scale: Optional[float] = None,
        attn_drop: float = 0.,
        proj_drop: float = 0.
    ) -> None:
        """
        Initialize the Attention module.

        Args:
            dim (int): Input dimension.
            num_heads (int): Number of attention heads.
            patch_size (Union[int, List[int]]): Size of the patches.
            attention_type (str): Type of attention mechanism ("local" or "global").
            qkv_bias (bool): Whether to use bias in query, key, value projections.
            qk_scale (Optional[float]): Scale factor for query-key dot product;
                defaults to ``head_dim ** -0.5`` when None (or 0).
            attn_drop (float): Dropout rate for attention matrix.
            proj_drop (float): Dropout rate for output projection.

        Raises:
            ValueError: If ``attention_type`` is neither "local" nor "global".
        """
        super().__init__()

        # Fail early: any other value would leave the module without a
        # ``qkv`` projection and only break later inside ``forward``.
        if attention_type not in ("local", "global"):
            raise ValueError(
                f"attention_type must be 'local' or 'global', got {attention_type!r}"
            )

        self.dim: int = dim
        self.num_heads: int = num_heads
        self.patch_size: List[int] = [patch_size] * ndims if isinstance(patch_size, int) else patch_size
        self.attention_type: str = attention_type

        assert dim % num_heads == 0, "Dimension must be divisible by number of heads"
        self.head_dim: int = dim // num_heads
        self.scale: float = qk_scale or self.head_dim ** -0.5

        # Skip initialization if using local attention without self-attention
        if self.attention_type == "local" and WO_SELF_ATT:
            return

        # Local attention projects q, k and v; global attention only k and v
        # because the query is supplied from outside.
        if attention_type == "local":
            self.qkv: nn.Linear = nn.Linear(dim, dim * 3, bias=qkv_bias)
        else:
            self.qkv: nn.Linear = nn.Linear(dim, dim * 2, bias=qkv_bias)

        self.attn_drop: nn.Dropout = nn.Dropout(attn_drop)
        self.proj: nn.Linear = nn.Linear(dim, dim)
        self.proj_drop: nn.Dropout = nn.Dropout(proj_drop)

    def forward(self, x: Tensor, q_ms: Optional[Tensor] = None) -> Tensor:
        """
        Forward pass of the Attention module.

        Args:
            x (Tensor): Input tokens of shape (B_, N, C).
            q_ms (Optional[Tensor]): Query tensor for global attention; ignored
                for local attention.

        Returns:
            Tensor: Output tensor of shape (B_, N, C) after applying attention.

        Raises:
            ValueError: If global attention is used without ``q_ms``.
        """
        B_, N, C = x.size()

        # Return input if using local attention without self-attention
        if self.attention_type == "local" and WO_SELF_ATT:
            return x

        if self.attention_type == "local":
            qkv: Tensor = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]
            q = q * self.scale
        else:
            if q_ms is None:
                raise ValueError("Global attention requires the query tensor q_ms.")
            B: int = q_ms.size()[0]
            kv: Tensor = self.qkv(x).reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            k, v = kv[0], kv[1]
            q: Tensor = self._process_global_query(q_ms, B, B_, N, C)

        # Compute attention scores and apply attention (q is already scaled)
        attn: Tensor = (q @ k.transpose(-2, -1))
        attn = F.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)

        x: Tensor = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x

    def _process_global_query(self, q_ms: Tensor, B: int, B_: int, N: int, C: int) -> Tensor:
        """
        Process the global query tensor.

        The query is reshaped to (B, num_heads, N, C // num_heads) and tiled
        cyclically along the batch axis until it has ``B_`` entries, i.e.
        output entry ``b`` is query entry ``b % B``. The result is multiplied
        by the attention scale and keeps the dtype and device of ``q_ms``.

        Args:
            q_ms (Tensor): Global query tensor.
            B (int): Batch size of q_ms.
            B_ (int): Batch size of input tensor.
            N (int): Number of patches.
            C (int): Channel dimension.

        Returns:
            Tensor: Scaled query of shape (B_, num_heads, N, C // num_heads).
        """
        q_heads: Tensor = q_ms.reshape(B, self.num_heads, N, C // self.num_heads)

        # Ceil division: enough full copies to cover B_, trimmed afterwards.
        repeats: int = -(-B_ // B)
        q: Tensor = q_heads.repeat(repeats, 1, 1, 1)[:B_]

        return q * self.scale


def get_patches(x: Tensor, patch_size: int) -> Tuple[Tensor, int, int, int]:
    """
    Split a channels-last volume into non-overlapping cubic patches.

    If any spatial extent is not a multiple of ``patch_size``, the volume is
    first resized with ``downsampler_fn`` to the largest multiples that fit.

    Args:
        x (Tensor): Input tensor of shape (B, H, W, D, C).
        patch_size (int): Edge length of each patch.

    Returns:
        Tuple[Tensor, int, int, int]: A tuple containing:
            - windows: Tensor of shape
              (num_patches * B, patch_size, patch_size, patch_size, C).
            - H, W, D: Spatial extents actually used (after any resizing).
    """
    B, H, W, D, C = x.size()

    # Resize only when at least one extent leaves a remainder.
    if H % patch_size or W % patch_size or D % patch_size:
        fitted_dims: List[int] = [
            (extent // patch_size) * patch_size for extent in (H, W, D)
        ]
        channels_first = x.permute(0, 4, 1, 2, 3)
        x = downsampler_fn(channels_first, fitted_dims).permute(0, 2, 3, 4, 1)
        B, H, W, D, C = x.size()

    # Split every spatial axis into (patch index, offset inside the patch).
    grid = x.view(
        B,
        H // patch_size, patch_size,
        W // patch_size, patch_size,
        D // patch_size, patch_size,
        C,
    )

    # Move the three patch indices in front of the three offsets, then fold
    # batch and patch indices into one leading axis.
    grid = grid.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous()
    windows: Tensor = grid.view(-1, patch_size, patch_size, patch_size, C)

    return windows, H, W, D


def get_image(windows: Tensor, patch_size: int, Hatt: int, Watt: int, Datt: int, H: int, W: int, D: int) -> Tensor:
    """
    Reassemble a channels-last volume from its patches.

    Args:
        windows (Tensor): Patches, with batch and patch indices folded into
            the leading axis.
        patch_size (int): Edge length of each patch.
        Hatt, Watt, Datt (int): Spatial extents the patches were cut from.
        H, W, D (int): Spatial extents the result should have.

    Returns:
        Tensor: Volume of shape (B, H, W, D, C). When (H, W, D) differs from
        (Hatt, Watt, Datt) the assembled volume is resized with
        ``downsampler_fn``.
    """
    patches_per_volume = (Hatt * Watt * Datt) // (patch_size ** 3)
    B: int = windows.size(0) // patches_per_volume

    # Unfold the leading axis into batch plus a 3-D grid of patches.
    grid = windows.view(
        B,
        Hatt // patch_size, Watt // patch_size, Datt // patch_size,
        patch_size, patch_size, patch_size,
        -1,
    )
    # Interleave each patch index with its in-patch offset, then merge them.
    grid = grid.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous()
    x: Tensor = grid.view(B, Hatt, Watt, Datt, -1)

    if (H, W, D) == (Hatt, Watt, Datt):
        return x

    # Restore the requested spatial size.
    return downsampler_fn(x.permute(0, 4, 1, 2, 3), [H, W, D]).permute(0, 2, 3, 4, 1)

class ViTBlock(nn.Module):
    """
    One transformer block: patch-wise attention followed by an MLP.

    Both sub-layers are wrapped as ``x + drop_path(gamma * sublayer(norm(x)))``.
    ``gamma1``/``gamma2`` are learnable per-channel scales initialised to
    ``layer_scale`` when a number is given, and the constant 1.0 otherwise.
    """
    def __init__(self,
                 embed_dim: int,
                 input_dims: List[int],
                 num_heads: int,
                 mlp_type: str,
                 patch_size: int,
                 mlp_ratio: float,
                 qkv_bias: bool,
                 qk_scale: Optional[float],
                 drop: float,
                 attn_drop: float,
                 drop_path: float,
                 act_layer: str,
                 attention_type: str,
                 norm_layer: Callable[..., nn.Module],
                 layer_scale: Optional[float]):
        super().__init__()
        self.patch_size = patch_size
        # Number of whole patches that fit into the given input extents.
        patches_per_axis = [d // patch_size for d in input_dims]
        self.num_windows = prod_func(patches_per_axis)

        # Attention sub-layer.
        self.norm1 = norm_layer(embed_dim)
        self.attn = Attention(
            embed_dim,
            attention_type=attention_type,
            num_heads=num_heads,
            patch_size=patch_size,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        # Stochastic depth is only instantiated for a positive rate.
        if drop_path > 0.:
            self.drop_path = timm_DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        # MLP sub-layer.
        self.norm2 = norm_layer(embed_dim)
        hidden_features = int(embed_dim * mlp_ratio)
        self.mlp = MLP(
            in_feats=embed_dim,
            hid_feats=hidden_features,
            act_name=act_layer,
            drop=drop,
            MLP_type=mlp_type
        )

        # None is neither int nor float, so this also covers "no layer scale".
        self.layer_scale = isinstance(layer_scale, (int, float))
        if not self.layer_scale:
            self.gamma1 = 1.0
            self.gamma2 = 1.0
        else:
            self.gamma1 = nn.Parameter(layer_scale * torch.ones(embed_dim), requires_grad=True)
            self.gamma2 = nn.Parameter(layer_scale * torch.ones(embed_dim), requires_grad=True)

    def forward(self, x: Tensor, q_ms: Optional[Tensor]) -> Tensor:
        """
        Apply the block to a channels-last volume.

        Args:
            x (Tensor): Input of shape (B, H, W, D, C).
            q_ms (Optional[Tensor]): Passed through to the attention module.

        Returns:
            Tensor: Output with the same layout as ``x``.
        """
        _, H, W, D, C = x.size()
        residual = x

        # Attention runs on flattened patches of patch_size ** 3 tokens each.
        normed = self.norm1(x)
        tokens, Hatt, Watt, Datt = get_patches(normed, self.patch_size)
        tokens = tokens.view(-1, self.patch_size ** 3, C)
        attended = self.attn(tokens, q_ms)

        # Back to a volume with the original spatial extents.
        restored = get_image(attended, self.patch_size, Hatt, Watt, Datt, H, W, D)

        x = residual + self.drop_path(self.gamma1 * restored)
        mlp_out = self.mlp(self.norm2(x))
        return x + self.drop_path(self.gamma2 * mlp_out)

class PatchEmbed(nn.Module):
    """
    Convolutional embedding followed by dropout.

    The convolution class is looked up on ``torch.nn`` from the module-level
    ``ndims`` value (``Conv{ndims}d``).
    """
    def __init__(self, in_chans: int = 3, out_chans: int = 32,
                 drop_rate: float = 0,
                 kernel_size: int = 3,
                 stride: int = 1, padding: int = 1,
                 dilation: int = 1, groups: int = 1, bias: bool = False) -> None:
        super().__init__()

        conv_cls = getattr(nn, f"Conv{ndims}d")
        conv_options = dict(
            in_channels=in_chans,
            out_channels=out_chans,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        self.proj = conv_cls(**conv_options)
        self.drop = nn.Dropout(p=drop_rate)

    def forward(self, x: Tensor) -> Tensor:
        """Project ``x`` with the convolution, then apply dropout."""
        projected = self.proj(x)
        return self.drop(projected)

class ViTLayer(nn.Module):
    """
    Vision Transformer Layer.

    A stack of ``depth`` ``ViTBlock`` modules that share one attention type,
    applied to a channels-first volume and combined with the layer input
    (added, or concatenated when ``CONCAT_ok`` is set).

    ``dim_out`` and ``norm_type`` are accepted but not used by this class.
    """
    def __init__(
        self,
        attention_type: str,
        dim: int,
        dim_out: int,
        depth: int,
        input_dims: List[int],
        num_heads: int,
        patch_size: int,
        mlp_type: str,
        mlp_ratio: float,
        qkv_bias: bool,
        qk_scale: Optional[float],
        drop: float,
        attn_drop: float,
        drop_path: Union[float, List[float]],
        norm_layer: Callable[..., nn.Module],
        norm_type: str,
        layer_scale: Optional[float],
        act_layer: str
    ) -> None:
        super().__init__()
        self.patch_size: int = patch_size
        self.embed_dim: int = dim
        self.input_dims: List[int] = input_dims
        self.blocks: nn.ModuleList = nn.ModuleList([
            ViTBlock(
                embed_dim=dim,
                num_heads=num_heads,
                mlp_type=mlp_type,
                patch_size=patch_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                attention_type=attention_type,
                drop=drop,
                attn_drop=attn_drop,
                # A list supplies one drop-path rate per block.
                drop_path=drop_path[k] if isinstance(drop_path, list) else drop_path,
                act_layer=act_layer,
                norm_layer=norm_layer,
                layer_scale=layer_scale,
                input_dims=input_dims
            )
            for k in range(depth)
        ])

    def forward(self, inp: Tensor, q_ms: Optional[Tensor], CONCAT_ok: bool) -> Tensor:
        """
        Run all blocks and merge the result with the input.

        Args:
            inp (Tensor): Channels-first input of shape (B, C, H, W, D).
            q_ms (Optional[Tensor]): Optional channels-first tensor that is
                cut into patches and handed to every block; None passes None
                to the blocks.
            CONCAT_ok (bool): Concatenate input and output instead of adding
                them.

        Returns:
            Tensor: ``inp + x`` or the concatenation of ``inp`` and ``x``.
        """
        x: Tensor = inp.clone()
        x = rearrange(x, 'b c h w d -> b h w d c')

        q_ms_patches: Optional[Tensor] = None
        if q_ms is not None:
            q_ms = rearrange(q_ms, 'b c h w d -> b h w d c')
            # The patches do not depend on the block, so they are extracted
            # once instead of once per block.
            if len(self.blocks) > 0:
                q_ms_patches, _, _, _ = get_patches(q_ms, self.patch_size)
                q_ms_patches = q_ms_patches.view(-1, self.patch_size ** ndims, x.size()[-1])

        for blk in self.blocks:
            x = blk(x, q_ms_patches)

        x = rearrange(x, 'b h w d c -> b c h w d')

        if CONCAT_ok:
            # NOTE(review): dim=-1 is the last spatial axis in this
            # channels-first layout, not the channel axis — confirm intent.
            x = torch.cat((inp, x), dim=-1)
        else:
            x = inp + x
        return x


class ViT(nn.Module):
    """
    Vision Transformer (ViT) module for hierarchical feature processing.

    The module owns one ``PatchEmbed`` and one ``ViTLayer`` per level. Level 0
    uses "local" attention and every later level uses "global" attention.
    The number of levels is ``len(feats_num)``, capped by ``NUM_CROSS_ATT``
    when that value is positive and reduced by one when the module-level flag
    ``WO_SELF_ATT`` is set.

    ``PYR_SCALES`` is accepted but not used by this class.
    """
    def __init__(self,
                 PYR_SCALES=None,
                 feats_num=None,
                 hid_dim=None,
                 depths=None,
                 patch_size=None,
                 mlp_ratio=None,
                 num_heads=None,
                 mlp_type=None,
                 norm_type=None,
                 act_layer=None,
                 drop_path_rate: float = 0.2,
                 qkv_bias: bool = True,
                 qk_scale: Optional[float] = None,
                 drop_rate: float = 0.,
                 attn_drop_rate: float = 0.,
                 norm_layer=nn.LayerNorm,
                 layer_scale=None,
                 img_size=None,
                 NUM_CROSS_ATT=-1):
        super().__init__()

        # Determine the number of levels for processing
        num_levels = len(feats_num)
        num_levels = min(num_levels, NUM_CROSS_ATT) if NUM_CROSS_ATT > 0 else num_levels
        if WO_SELF_ATT:
            num_levels -= 1

        # Ensure patch_size is a list
        patch_size = patch_size if isinstance(patch_size, list) else [patch_size for _ in range(num_levels)]

        # Create patch embedding layers: feats_num[i] -> hid_dim channels
        self.patch_embed = nn.ModuleList([
            PatchEmbed(
                in_chans=feats_num[i],
                out_chans=hid_dim,
                drop_rate=drop_rate
            ) for i in range(num_levels)
        ])

        # Drop path rates rise linearly from 0 to drop_path_rate over all
        # blocks; each level takes the slice belonging to its own blocks.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # Create ViT layers
        self.levels = nn.ModuleList()
        for i in range(num_levels):
            level = ViTLayer(
                dim=hid_dim,
                dim_out=hid_dim,
                depth=depths[i],
                num_heads=num_heads[i],
                patch_size=patch_size[i],
                mlp_type=mlp_type,
                attention_type="local" if i == 0 else "global",
                drop_path=dpr[sum(depths[:i]):sum(depths[:i+1])],
                input_dims=img_size[i],
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                norm_layer=norm_layer,
                layer_scale=layer_scale,
                norm_type=norm_type,
                act_layer=act_layer
            )
            self.levels.append(level)

    def _init_weights(self, m):
        """
        Initialize the weights of the module.

        ``nn.Linear`` weights use ``timm_trunc_normal_`` (std 0.02) with a
        zero bias; ``nn.LayerNorm`` gets zero bias and unit weight.

        NOTE(review): nothing in this class applies this initialiser — confirm
        whether a ``self.apply(self._init_weights)`` call is expected.
        """
        if isinstance(m, nn.Linear):
            timm_trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        """Return keywords for no weight decay."""
        return {'rpb'}

    def forward(self, KQs, CONCAT_ok: bool = False):
        """
        Forward pass of the ViT module.

        The first level runs on ``KQs[0]`` alone. Its output is passed through
        the level-0 patch embedding once more and becomes the running tensor
        ``Q``. Every later level receives ``Q`` together with the embedded
        ``KQs[i]`` and its output replaces ``Q``.

        Args:
            KQs (List[Tensor]): List of input tensors for each level.
            CONCAT_ok (bool): Forwarded unchanged to every level.

        Returns:
            Tensor: Output of the last level.

        Raises:
            ValueError: If the module was built without any level.
        """
        # Without this guard an empty module fails below with an opaque
        # UnboundLocalError on ``x``.
        if len(self.patch_embed) == 0 or len(self.levels) == 0:
            raise ValueError(
                "ViT has no levels to run; check feats_num, NUM_CROSS_ATT and WO_SELF_ATT."
            )

        for i, (patch_embed_, level) in enumerate(zip(self.patch_embed, self.levels)):
            if i == 0:
                # First level: process input without cross-attention
                Q = patch_embed_(KQs[i])
                x = level(Q, None, CONCAT_ok=CONCAT_ok)
                # The level output has hid_dim channels, so re-embedding it
                # requires feats_num[0] == hid_dim.
                Q = patch_embed_(x)
            else:
                # Subsequent levels: process with cross-attention
                K = patch_embed_(KQs[i])
                x = level(Q, K, CONCAT_ok=CONCAT_ok)
                Q = x.clone()

        return x


class EncoderCnnBlock(nn.Module):
    """
    One encoder stage: two stacked Conv3d -> InstanceNorm3d -> ReLU units.

    The first convolution uses ``stride``; the second always uses stride 1.
    Both units share ``kernel_size``, ``padding`` and ``bias``.
    """
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding=1,
        bias=False,
        affine=True,
        eps=1e-05
    ):
        super().__init__()

        layers = []
        # (input channels, stride) per unit; only the first unit may change channels/resolution
        for unit_in, unit_stride in ((in_channels, stride), (out_channels, 1)):
            layers.append(
                nn.Conv3d(
                    in_channels=unit_in, out_channels=out_channels,
                    kernel_size=kernel_size, stride=unit_stride, padding=padding,
                    bias=bias
                )
            )
            layers.append(nn.InstanceNorm3d(num_features=out_channels, affine=affine, eps=eps))
            layers.append(nn.ReLU(inplace=True))

        # Same six modules, in the same order, as indices 0..5 of the sequential container
        self._block = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both conv units to ``x``."""
        return self._block(x)


class Decoder(nn.Module):
    """
    Decoder module for the hierarchical vision transformer.

    Builds a top-down path over the encoder feature maps (1x1 lateral convs,
    transposed-conv upsampling, 3x3 output convs) and then passes the per-stage
    outputs through the ``hierarchical_dec`` blocks, from the highest stage
    index down to the lowest.
    """
    def __init__(self, config: Dict[str, Any]):
        """
        Args:
            config: Model configuration. Keys read here: ``num_stages``,
                ``use_seg_loss``, ``start_channels``, ``out_fmaps``,
                ``fpn_channels``, ``strides``, ``data_size`` and, when
                ``use_seg_loss`` is true, ``num_organs``. Further optional keys
                are read in ``_create_hierarchical_layers``.
        """
        super().__init__()
        self._num_stages: int = config['num_stages']
        self.use_seg: bool = config['use_seg_loss']
        fpn_channels: int = int(config['fpn_channels'])

        # Channels of the encoder feature maps (doubling per stage), kept as plain
        # Python ints so the nn layers below never receive 0-d tensors as sizes.
        encoder_out_channels: List[int] = [
            int(config['start_channels'] * 2**stage) for stage in range(self._num_stages)
        ]

        # Required stages are taken from the trailing digit of each name in 'out_fmaps'
        required_stages: Set[int] = set(int(fmap[-1]) for fmap in config['out_fmaps'])
        self._required_stages: Set[int] = required_stages
        # A set has no guaranteed iteration order, but the output convs below and
        # the stage/feature-map pairing in forward() both need "lowest stage first".
        self._stage_order: List[int] = sorted(required_stages)

        earliest_required_stage: int = min(required_stages)

        # Lateral connections: 1x1 convs that cap the channel count at fpn_channels
        lateral_in_channels: List[int] = encoder_out_channels[earliest_required_stage:]
        lateral_out_channels: List[int] = [min(ch, fpn_channels) for ch in lateral_in_channels]

        self._lateral: nn.ModuleList = nn.ModuleList([
            nn.Conv3d(in_channels=in_ch, out_channels=out_ch, kernel_size=1)
            for in_ch, out_ch in zip(lateral_in_channels, lateral_out_channels)
        ])
        self._lateral_levels: int = len(self._lateral)

        # Output layers, one per required stage in ascending stage order
        out_in_channels: List[int] = [
            lateral_out_channels[-self._num_stages + required_stage]
            for required_stage in self._stage_order
        ]
        out_out_channels: List[int] = [fpn_channels] * len(out_in_channels)

        self._out: nn.ModuleList = nn.ModuleList([
            nn.Conv3d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, padding=1)
            for in_ch, out_ch in zip(out_in_channels, out_out_channels)
        ])

        # Upsampling layers, ordered from the deepest lateral level upwards
        top_down_channels: List[int] = lateral_out_channels[::-1]
        top_down_strides: List[int] = list(reversed(config['strides']))
        self._up: nn.ModuleList = nn.ModuleList([
            nn.ConvTranspose3d(
                in_channels=top_down_channels[level],
                out_channels=top_down_channels[level+1],
                kernel_size=top_down_strides[level],
                stride=top_down_strides[level]
            )
            for level in range(len(lateral_out_channels)-1)
        ])

        # Multi-scale attention
        self.hierarchical_dec: nn.ModuleList = self._create_hierarchical_layers(config, out_out_channels)

        if self.use_seg:
            self._seg_head: nn.ModuleList = nn.ModuleList([
                nn.Conv3d(out_ch, config['num_organs'] + 1, kernel_size=1, stride=1)
                for out_ch in out_out_channels
            ])

    def _create_hierarchical_layers(self, config: Dict[str, Any], out_out_channels: List[int]) -> nn.ModuleList:
        """
        Create hierarchical layers for multi-scale attention.

        Entry 0 is an ``nn.Identity``; every later entry ``k`` is a ``ViT`` built
        for ``k + 1`` levels.
        """
        out: nn.ModuleList = nn.ModuleList()
        img_size: List[List[int]] = []
        feats_num: List[int] = []

        for k, out_ch in enumerate(out_out_channels):
            # Spatial size for entry k: data_size divided by 2**(num_stages - k - 1)
            img_size.append([int(item/(2**(self._num_stages-k-1))) for item in config['data_size']])
            feats_num.append(out_ch)
            n: int = len(feats_num)

            if k == 0:
                out.append(nn.Identity())
            else:
                # NOTE(review): `feats_num` and `img_size` are handed over as the same
                # list objects that keep growing in later iterations — confirm that
                # ViT does not hold on to them past construction.
                out.append(
                    ViT(
                        NUM_CROSS_ATT=config.get('NUM_CROSS_ATT', _NUM_CROSS_ATT),
                        PYR_SCALES=[1.],
                        feats_num=feats_num,
                        hid_dim=int(config.get('fpn_channels', 64)),
                        depths=[int(config.get('depths', 1))]*n,
                        patch_size=[int(config.get('patch_size', 2))]*n,
                        mlp_ratio=int(config.get('mlp_ratio', 2)),
                        num_heads=[int(config.get('num_heads', 32))]*n,
                        mlp_type='basic',
                        norm_type='BatchNorm2d',
                        act_layer='gelu',
                        drop_path_rate=config.get('drop_path_rate', 0.2),
                        qkv_bias=config.get('qkv_bias', True),
                        qk_scale=None,
                        drop_rate=config.get('drop_rate', 0.),
                        attn_drop_rate=config.get('attn_drop_rate', 0.),
                        norm_layer=nn.LayerNorm,
                        layer_scale=1e-5,
                        img_size=img_size
                    )
                )
        return out

    def forward(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """
        Forward pass of the Decoder.

        Args:
            x: Encoder feature maps; only the last ``_lateral_levels`` values
                (in dict order) are used.

        Returns:
            Dict with a ``'P{stage}'`` entry per stage and, when segmentation
            heads are enabled, a matching ``'S{stage}'`` entry.
        """
        encoder_maps: List[Tensor] = list(x.values())[-self._lateral_levels:]
        lateral_out: List[Tensor] = [lateral(fmap) for lateral, fmap in zip(self._lateral, encoder_maps)]

        # Top-down pass: start at the deepest level and add the upsampled result
        # of the previous (deeper) level to each following one.
        up_out: List[Tensor] = []
        for idx, fmap in enumerate(reversed(lateral_out)):
            if idx != 0:
                fmap = fmap + up

            if idx < self._lateral_levels - 1:
                up = self._up[idx](fmap)

            up_out.append(fmap)

        # reversed(up_out) runs from the earliest level on, matching the ascending stage order
        cnn_outputs: Dict[int, Tensor] = {
            stage: self._out[idx](fmap)
            for idx, (fmap, stage) in enumerate(zip(reversed(up_out), self._stage_order))
        }
        return self._forward_hierarchical(cnn_outputs)

    def _forward_hierarchical(self, cnn_outputs: Dict[int, Tensor]) -> Dict[str, Tensor]:
        """
        Forward pass through the hierarchical decoder.

        Stages are visited from the highest key down to the lowest; each new
        stage is prepended to the list handed to its ``hierarchical_dec`` block.
        """
        stage_keys = range(max(cnn_outputs.keys()), min(cnn_outputs.keys())-1, -1)
        xs: List[Tensor] = [cnn_outputs[key].clone() for key in stage_keys]

        out_dict: Dict[str, Tensor] = {}
        QK: List[Tensor] = []
        for i, key in enumerate(stage_keys):
            QK = [xs[i]] + QK
            if i == 0:
                Pi = QK[0]
            else:
                Pi = self.hierarchical_dec[i](QK)
            # The refined map replaces the raw one for all following stages
            QK[0] = Pi
            out_dict[f'P{key}'] = Pi

            if self.use_seg:
                Pi_seg = self._seg_head[i](Pi)
                out_dict[f'S{key}'] = Pi_seg

        return out_dict




class HierarchicalViT(nn.Module):
    """
    Hierarchical Vision Transformer (HViT) for image processing tasks.

    A stack of ``EncoderCnnBlock`` stages followed by a ``Decoder``. Only the
    ``'fpn'`` / ``'FPN'`` backbone is implemented.
    """
    def __init__(self, config: Dict[str, Any]):
        """
        Args:
            config: Model configuration. ``'backbone_net'`` is required; the
                other keys read here have defaults. NOTE: the dict is modified
                in place — ``'num_stages'`` and ``'strides'`` are written into
                it before it is handed to ``Decoder``.
        """
        super().__init__()

        # Configuration parameters
        self.backbone: str = config['backbone_net']
        in_channels: int = 2 * config.get('in_channels', 1)  # source + target
        kernel_size: int = config.get('kernel_size', 3)
        emb_dim: int = config.get('start_channels', 32)
        data_size: Tuple[int, ...] = config.get('data_size', [160, 192, 224])
        self.out_fmaps: List[str] = config.get('out_fmaps', ['P4', 'P3', 'P2', 'P1'])

        # Number of stages: limited both by the smallest spatial extent and by
        # the deepest requested output map (trailing digit of its name) + 1
        num_stages: int = min(int(math.log2(min(data_size))) - 1,
                              max(int(fmap[-1]) for fmap in self.out_fmaps) + 1)

        # The first stage keeps resolution, every later stage halves it
        strides: List[int] = [1] + [2] * (num_stages - 1)
        kernel_sizes: List[int] = [kernel_size] * num_stages

        config['num_stages'] = num_stages
        config['strides'] = strides

        # Build encoder
        self._encoder: nn.ModuleList = nn.ModuleList()
        if self.backbone in ['fpn', 'FPN']:
            for k in range(num_stages):
                blk = EncoderCnnBlock(
                    in_channels=in_channels,
                    out_channels=emb_dim,
                    kernel_size=kernel_sizes[k],
                    stride=strides[k]
                )
                self._encoder.append(blk)

                # Next stage consumes this stage's output and doubles the width
                in_channels = emb_dim
                emb_dim *= 2

        # Build decoder
        if self.backbone in ['fpn', 'FPN']:
            self._decoder: Decoder = Decoder(config)

    def init_weights(self) -> None:
        """
        Initialize model weights.
        """
        # TODO: Implement weight initialization

    def forward(self, x: Tensor, verbose: bool = False) -> Dict[str, Tensor]:
        """
        Forward pass of the HierarchicalViT model.

        Args:
            x (Tensor): Input tensor.
            verbose (bool): If True, print shape information.

        Returns:
            Dict[str, Tensor]: Output feature maps.

        Raises:
            NotImplementedError: If the configured backbone is not 'fpn'/'FPN'.
        """
        if self.backbone not in ['fpn', 'FPN']:
            # No encoder/decoder exists for other backbones; fail with a clear message
            # instead of an UnboundLocalError further down.
            raise NotImplementedError(
                f"Unsupported backbone '{self.backbone}': only 'fpn'/'FPN' is implemented."
            )

        down: Dict[str, Tensor] = {}
        for stage_id, module in enumerate(self._encoder):
            x = module(x)
            down[f'C{stage_id}'] = x
        up = self._decoder(down)

        if verbose:
            for key, item in down.items():
                print(f'down {key}', item.shape)
            for key, item in up.items():
                print(f'up {key}', item.shape)
        return up


class RegistrationHead(nn.Sequential):
    """
    Single 3-D convolution that maps features to a displacement field.

    The layer is registered under the name ``conv3d`` and starts out close to
    zero: weights are drawn from N(0, 1e-5) and the bias is all zeros.
    """
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
        super().__init__()
        head = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
        )
        # Replace the default initialisation with tiny random weights and a zero bias
        tiny_weights = torch.empty_like(head.weight).normal_(0, 1e-5)
        head.weight = nn.Parameter(tiny_weights)
        head.bias = nn.Parameter(torch.zeros(head.bias.shape))
        self.add_module('conv3d', head)


class HViT(nn.Module):
    """
    Registration model: ``HierarchicalViT`` features -> ``RegistrationHead``
    flow -> ``SpatialTransformer`` warp of the source image.
    """
    def __init__(self, config: dict):
        super(HViT, self).__init__()
        # Flow post-processing options
        self.upsample_df: bool = config.get('upsample_df', False)
        self.upsample_scale_factor: int = config.get('upsample_scale_factor', 2)
        # Name of the decoder output the flow is predicted from
        self.scale_level_df: str = config.get('scale_level_df', 'P1')

        self.deformable: HierarchicalViT = HierarchicalViT(config)
        self.avg_pool: nn.AvgPool3d = nn.AvgPool3d(3, stride=2, padding=1)
        self.spatial_trans: SpatialTransformer = SpatialTransformer(config['data_size'])
        # `ndims` is used both as the number of flow channels and as the kernel size
        self.reg_head: RegistrationHead = RegistrationHead(
            in_channels=config.get('fpn_channels', 64),
            out_channels=ndims,
            kernel_size=ndims,
        )

    def forward(self, source: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Predict a flow from the (source, target) pair and warp the source with it.

        Args:
            source (Tensor): Source image tensor.
            target (Tensor): Target image tensor.

        Returns:
            Tuple[Tensor, Tensor]: ``(moved, flow)`` — the warped source and the
            flow it was warped with.
        """
        # Source and target are stacked along the channel axis
        stacked: Tensor = torch.cat((source, target), dim=1)
        pyramid: Dict[str, Tensor] = self.deformable(stacked)

        # Only the configured scale level feeds the registration head
        features: Tensor = pyramid[self.scale_level_df]
        flow: Tensor = self.reg_head(features)

        if self.upsample_df:
            upsampler = nn.Upsample(
                scale_factor=self.upsample_scale_factor,
                mode='trilinear',
                align_corners=False,
            )
            flow = upsampler(flow)

        moved: Tensor = self.spatial_trans(source, flow)
        return moved, flow


if __name__ == "__main__":
    # Smoke test for the HViT model: build it on a half-resolution volume, run
    # one forward pass and report output shapes, memory use and parameter count.
    use_cuda = torch.cuda.is_available()
    # Single device choice (the model and the inputs must live on the same device)
    device = torch.device("cuda:0" if use_cuda else "cpu")
    H, W, D = 160//2, 192//2, 224//2

    for fpn_channels in [64]:
        config = {
            'NUM_CROSS_ATT': _NUM_CROSS_ATT,
            'out_fmaps': ['P4', 'P3', 'P2', 'P1'],
            'scale_level_df': 'P1',
            'upsample_df': True,
            'upsample_scale_factor': 2,
            'fpn_channels': fpn_channels,
            'start_channels': 32,
            'patch_size': 2,
            'bspl': False,

            'backbone_net': 'fpn',
            'in_channels': 1,
            'data_size': [H, W, D],
            'bias': True,
            'norm_type': 'instance',
            'cuda': 0,
            'kernel_size': 3,
            'depths': 1,
            'mlp_ratio': 2,

            'num_heads': 32,
            'drop_path_rate': 0.,
            'qkv_bias': True,
            'drop_rate': 0.,
            'attn_drop_rate': 0.,

            'use_seg_loss': False,
            'use_seg_proxy_loss': False,
            'num_organs': -1
        }

        source = torch.rand([1, 1, H, W, D])
        tgt = torch.rand([1, 1, H, W, D])

        model = HViT(config)
        model.to(device)

        # fp16 autocast and fp16 inputs are only used on the GPU; on the CPU
        # fallback autocast is disabled and the pass runs in fp32.
        amp_dtype = torch.float16 if use_cuda else torch.bfloat16
        with torch.amp.autocast(device_type=device.type, dtype=amp_dtype, enabled=use_cuda):
            if use_cuda:
                source = source.to(dtype=torch.float16)
                tgt = tgt.to(dtype=torch.float16)
            source = source.to(device)
            tgt = tgt.to(device)

            moved, flow = model(source, tgt)
            print('\n\nmoved {} flow {}'.format(moved.shape, flow.shape))

            # CUDA memory statistics only exist when a CUDA device is in use
            if use_cuda:
                max_mem_gb = torch.cuda.max_memory_allocated(device) / 1024**3
                print("[+] Maximum memory:\t{:.2f}GB: >>> \t{:.0f} feats".format(max_mem_gb, config['fpn_channels']))
                print("[+] Total device memory:\t{:.2f}GB".format(torch.cuda.get_device_properties(device).total_memory/1024**3))
            print("[+] Trainable params:\t{:.5f} m".format(count_parameters(model)/1e6))

# Loss

In [2]:
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp
import math
import torch.nn as nn


class Grad3D(torch.nn.Module):
    """
    Gradient (smoothness) loss over the three trailing axes of a 5-D tensor.

    Averages the absolute (``penalty='l1'``) or squared (``penalty='l2'``)
    neighbour differences per axis, then averages the three axis terms.
    """

    def __init__(self, penalty='l1', loss_mult=None):
        super().__init__()
        self.penalty = penalty
        self.loss_mult = loss_mult

    def forward(self, y_pred):
        squared = self.penalty == 'l2'

        # Axis order 3, 2, 4 keeps the summation order dx + dy + dz
        axis_means = []
        for axis in (3, 2, 4):
            step = torch.abs(torch.diff(y_pred, dim=axis))
            if squared:
                step = step * step
            axis_means.append(torch.mean(step))

        grad = (axis_means[0] + axis_means[1] + axis_means[2]) / 3.0

        # Optional global scaling of the loss
        if self.loss_mult is not None:
            grad *= self.loss_mult
        return grad



class DiceLoss(nn.Module):
    """Dice loss: one minus the mean per-class Dice coefficient."""

    def __init__(self, num_class=36):
        super().__init__()
        self.num_class = num_class

    def forward(self, y_pred, y_true):
        # Integer labels -> one-hot with the class axis moved to position 1
        target = nn.functional.one_hot(y_true, num_classes=self.num_class)
        target = torch.squeeze(target, 1).permute(0, 4, 1, 2, 3).contiguous()

        spatial_axes = [2, 3, 4]
        overlap = (y_pred * target).sum(dim=spatial_axes)
        # Denominator is the sum of squared predictions plus squared targets
        denom = torch.pow(y_pred, 2).sum(dim=spatial_axes) + torch.pow(target, 2).sum(dim=spatial_axes)

        per_class = (2.*overlap) / (denom + 1e-5)
        return (1-torch.mean(per_class))

def DiceScore(y_pred, y_true, num_class):
    """Return the Dice coefficient per (batch item, class) without reducing it."""
    # Integer labels -> one-hot with the class axis moved to position 1
    target = nn.functional.one_hot(y_true, num_classes=num_class)
    target = torch.squeeze(target, 1).permute(0, 4, 1, 2, 3).contiguous()

    spatial_axes = [2, 3, 4]
    overlap = (y_pred * target).sum(dim=spatial_axes)
    denom = torch.pow(y_pred, 2).sum(dim=spatial_axes) + torch.pow(target, 2).sum(dim=spatial_axes)
    return (2.*overlap) / (denom + 1e-5)


# Loss callables looked up by name. DiceLoss is created with its default
# num_class (36); Grad3D uses the squared ('l2') penalty.
loss_functions = dict(
    mse=nn.MSELoss(),
    dice=DiceLoss(),
    grad=Grad3D(penalty='l2'),
)


# Utils

In [None]:
import logging
import os
import yaml
from torch import nn

def read_yaml_file(file_path):
    """
    Reads a YAML file and returns the content as a dictionary.

    Parameters:
    file_path (str): The path to the YAML file to read.

    Returns:
    dict: The content of the YAML file as a dictionary, or None when the file
    is not valid YAML (the parser error is printed in that case).

    Raises:
    OSError: If the file cannot be opened (e.g. it does not exist).
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale default encoding.
    with open(file_path, 'r', encoding='utf-8') as file:
        try:
            return yaml.safe_load(file)  # safe_load: no arbitrary object construction
        except yaml.YAMLError as e:
            print(f"Error reading YAML file: {e}")
            return None


class Logger:
    """
    Small wrapper around the module-level ``logging`` logger that writes to the
    console and to ``<save_dir>/logfile.log``.

    The level is set to INFO, so ``debug`` messages are not emitted.
    """

    def __init__(self, save_dir):
        """
        Args:
            save_dir: Directory for ``logfile.log``; created if it does not exist.
        """
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)

        # Create the directory if it doesn't exist
        os.makedirs(save_dir, exist_ok=True)
        log_path = os.path.abspath(os.path.join(save_dir, "logfile.log"))

        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

        # logging.getLogger(__name__) returns the same object on every call, so a
        # second Logger instance must not attach another copy of a handler that
        # is already there — otherwise every message is emitted multiple times.
        existing = self.logger.handlers
        # FileHandler subclasses StreamHandler, hence the exact type comparison
        has_console = any(type(h) is logging.StreamHandler for h in existing)
        has_file = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == log_path
            for h in existing
        )

        if not has_console:
            console_handler = logging.StreamHandler()
            console_handler.setFormatter(formatter)
            self.logger.addHandler(console_handler)

        if not has_file:
            file_handler = logging.FileHandler(log_path)
            file_handler.setFormatter(formatter)
            self.logger.addHandler(file_handler)

    def info(self, message):
        """Log ``message`` at INFO level."""
        self.logger.info(message)

    def warning(self, message):
        """Log ``message`` at WARNING level."""
        self.logger.warning(message)

    def error(self, message):
        """Log ``message`` at ERROR level."""
        self.logger.error(message)

    def debug(self, message):
        """Log ``message`` at DEBUG level (dropped while the level is INFO)."""
        self.logger.debug(message)


def count_parameters(model):
    """Return the total element count of the model's parameters that require gradients."""
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total



def get_one_hot(inp_seg, num_labels):
    """
    One-hot encode a 5-D label tensor.

    The class axis produced by ``one_hot`` is moved to position 1 after axis 1
    of the input has been squeezed away.
    """
    # The unpacking doubles as a rank check: a non-5-D input raises ValueError here
    _batch, _chan, _h, _w, _d = inp_seg.shape
    labels = inp_seg.long()
    encoded = nn.functional.one_hot(labels, num_classes=num_labels)  # class axis is appended last
    encoded = encoded.squeeze(dim=1).permute(0, 4, 1, 2, 3)
    return encoded.contiguous()


# Trainer

In [None]:
import math
import torch
from lightning import LightningModule

from src import logger, checkpoints_dir
from src.model.hvit import HViT
from src.model.hvit_light import HViT_Light
from src.loss import loss_functions, DiceScore
from src.utils import get_one_hot

# Precision name -> torch dtype
dtype_map = dict(
    bf16=torch.bfloat16,
    fp32=torch.float32,
    fp16=torch.float16,
)

class LiTHViT(LightningModule):
    """
    Lightning module that trains and evaluates HViT / HViT-Light.

    Optimization is manual (``automatic_optimization = False``): the training
    step calls backward and the optimizer itself, and the LR scheduler is
    stepped once per epoch in ``on_train_epoch_end``.
    """

    def __init__(self, args, config, wandb_logger=None, save_model_every_n_epochs=10):
        """
        Args:
            args: Parsed run arguments; read here: ``lr``, ``tgt2src_reg``,
                ``hvit_light``, ``precision``, ``mse_weights``,
                ``dice_weights``, ``grad_weights`` (``num_labels`` is read later).
            config: Model configuration passed to the HViT constructor.
            wandb_logger: Optional WandB logger; when None, WandB logging is skipped.
            save_model_every_n_epochs: Period of the regular checkpoints.
        """
        super().__init__()
        self.automatic_optimization = False
        self.args = args
        self.config = config
        self.best_val_loss = 1e8
        self.save_model_every_n_epochs = save_model_every_n_epochs
        self.lr = args.lr
        self.last_epoch = 0
        self.tgt2src_reg = args.tgt2src_reg
        self.hvit_light = args.hvit_light
        self.precision = args.precision

        self.hvit = HViT_Light(config) if self.hvit_light else HViT(config)

        self.loss_weights = {
            "mse": self.args.mse_weights,
            "dice": self.args.dice_weights,
            "grad": self.args.grad_weights
        }
        self.wandb_logger = wandb_logger
        self.test_step_outputs = []

    def _log_wandb(self, metrics, step=None):
        """Send ``metrics`` to the WandB logger if one was supplied; otherwise do nothing."""
        if self.wandb_logger:
            self.wandb_logger.log_metrics(metrics, step=step)

    def _forward(self, batch, calc_score: bool = False, tgt2src_reg: bool = False):
        """
        Register one batch and compute the weighted losses.

        Args:
            batch: Sequence whose items 0/1 are the two images and 2/3 their
                segmentations.
            calc_score: Also compute the Dice score of the warped segmentation.
            tgt2src_reg: Swap the roles of the two images (and segmentations).

        Returns:
            ``(_loss, _score)``: ``_loss`` maps each loss name plus
            ``"avg_loss"`` to its weighted value; ``_score`` is the Dice score
            tensor, or ``0.`` when ``calc_score`` is False.
        """
        _score = 0.

        # Unknown precision names fall back to float32
        dtype_ = dtype_map.get(self.precision, torch.float32)

        with torch.amp.autocast(device_type="cuda", dtype=dtype_):
            if tgt2src_reg:
                target, source = batch[0].to(dtype=dtype_), batch[1].to(dtype=dtype_)
                tgt_seg, src_seg = batch[2], batch[3]
            else:
                source, target = batch[0].to(dtype=dtype_), batch[1].to(dtype=dtype_)
                src_seg, tgt_seg = batch[2], batch[3]

            moved, flow = self.hvit(source, target)

            # The warped one-hot segmentation is needed by both the score and the
            # dice loss; compute it at most once.
            moved_seg = None
            if calc_score:
                moved_seg = self._get_one_hot_from_src(src_seg, flow, self.args.num_labels)
                _score = DiceScore(moved_seg, tgt_seg.long(), self.args.num_labels)

            _loss = {}
            for key, weight in self.loss_weights.items():
                if key == "mse":
                    _loss[key] = weight * loss_functions[key](moved, target)
                elif key == "dice":
                    if moved_seg is None:
                        moved_seg = self._get_one_hot_from_src(src_seg, flow, self.args.num_labels)
                    _loss[key] = weight * loss_functions[key](moved_seg, tgt_seg.long())
                elif key == "grad":
                    _loss[key] = weight * loss_functions[key](flow)

            _loss["avg_loss"] = sum(_loss.values()) / len(_loss)
        return _loss, _score

    def training_step(self, batch, batch_idx):
        """
        One manual optimization step (two when ``tgt2src_reg`` is enabled: the
        second registers in the reverse direction).

        Returns:
            Dict of float loss values, averaged over both directions when the
            reverse pass ran.
        """
        self.hvit.train()
        opt = self.optimizers()

        loss1, _ = self._forward(batch, calc_score=False)
        self.manual_backward(loss1["avg_loss"])
        opt.step()
        opt.zero_grad()

        loss2 = None
        if self.tgt2src_reg:
            loss2, _ = self._forward(batch, tgt2src_reg=True, calc_score=False)
            self.manual_backward(loss2["avg_loss"])
            opt.step()
            opt.zero_grad()

        total_loss = {}
        for key, value in loss1.items():
            if loss2 is not None and key in loss2:
                total_loss[key] = (value.item() + loss2[key].item()) / 2
            else:
                total_loss[key] = value.item()

        self._log_wandb(total_loss, step=self.global_step)
        return total_loss

    def on_train_epoch_end(self):
        """Save the periodic checkpoint, log the learning rate and step the LR scheduler."""
        if self.current_epoch % self.save_model_every_n_epochs == 0:
            checkpoint_path = f"{checkpoints_dir}/model_epoch_{self.current_epoch}.ckpt"
            self.trainer.save_checkpoint(checkpoint_path)
            logger.info(f"Saved model at epoch {self.current_epoch}")

        current_lr = self.optimizers().param_groups[0]['lr']
        self._log_wandb({"learning_rate": current_lr}, step=self.global_step)

        # With manual optimization Lightning does not step LR schedulers on its
        # own, so the epoch-wise schedule has to be advanced here.
        scheduler = self.lr_schedulers()
        if scheduler is not None:
            scheduler.step()

    def validation_step(self, batch, batch_idx):
        """
        Evaluate one validation batch and log every loss component and the mean Dice score.

        Returns:
            Dict with ``"val_loss"`` (the average loss tensor) and
            ``"val_score"`` (mean Dice score as float).
        """
        with torch.no_grad():
            self.hvit.eval()
            _loss, _score = self._forward(batch, calc_score=True)

        # Log each component of the validation loss
        for loss_name, loss_value in _loss.items():
            self.log(f"val_{loss_name}", loss_value, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)

        # Log the mean validation score if available
        if _score is not None:
            self.log("val_score", _score.mean(), on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)

        # Log to wandb
        log_dict = {f"val_{k}": v.item() for k, v in _loss.items()}
        log_dict.update({
            "val_score_mean": _score.mean().item() if _score is not None else None,
        })
        self._log_wandb({k: v for k, v in log_dict.items() if v is not None}, step=self.global_step)

        return {"val_loss": _loss["avg_loss"], "val_score": _score.mean().item()}

    def on_validation_epoch_end(self):
        """
        Callback method called at the end of the validation epoch.
        Saves the best model based on validation loss and logs metrics.
        """
        # validation_step logs the average loss under "val_avg_loss"
        # (f"val_{loss_name}" with loss_name == "avg_loss"); "val_loss" is never logged.
        val_loss = self.trainer.callback_metrics.get("val_avg_loss")

        if val_loss is not None and self.current_epoch > 0:
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                best_model_path = f"{checkpoints_dir}/best_model.ckpt"
                self.trainer.save_checkpoint(best_model_path)
                if self.wandb_logger:
                    self.wandb_logger.experiment.log({
                        "best_model_saved": best_model_path,
                        "best_val_loss": float(self.best_val_loss)
                    })
                logger.info(f"New best model saved with validation loss: {self.best_val_loss:.4f}")

    def test_step(self, batch, batch_idx):
        """
        Performs a single test step on a batch of data.

        Args:
            batch: The input batch of data.
            batch_idx: The index of the current batch.

        Returns:
            A dictionary containing the test Dice score.
        """
        with torch.no_grad():
            self.hvit.eval()
            _, _score = self._forward(batch, calc_score=True)

        # Ensure _score is a tensor and take the mean
        _score = _score.mean() if isinstance(_score, torch.Tensor) else torch.tensor(_score).mean()

        self.test_step_outputs.append(_score)

        # Log to wandb only if the logger is available
        self._log_wandb({"test_dice": _score.item()}, step=self.global_step)

        # Return as a dict with tensor values
        return {"test_dice": _score}

    def on_test_epoch_end(self):
        """
        Callback method called at the end of the test epoch.
        Computes and logs the average test Dice score.
        """
        # Calculate the average Dice score across all test steps
        avg_test_dice = torch.stack(self.test_step_outputs).mean()

        # Log the average test Dice score
        self.log("avg_test_dice", avg_test_dice, prog_bar=True)

        # Log to wandb if available
        self._log_wandb({"total_test_dice_avg": avg_test_dice.item()})

        # Clear the test step outputs list for the next test epoch
        self.test_step_outputs.clear()

    def configure_optimizers(self):
        """
        Configures the optimizer and learning rate scheduler for the model.

        Returns:
            A dictionary containing the optimizer and learning rate scheduler configuration.
        """
        optimizer = torch.optim.Adam(self.hvit.parameters(), lr=self.lr, weight_decay=0, amsgrad=True)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=self.lr_lambda)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
                "frequency": 1,
            },
        }

    def lr_lambda(self, epoch):
        """
        Defines the learning rate schedule.

        Args:
            epoch: The current epoch number.

        Returns:
            The learning rate multiplier for the given epoch.
        """
        max_epochs = self.trainer.max_epochs
        # Clamp at 0: a negative base would make math.pow raise once
        # `epoch` exceeds `max_epochs`.
        return math.pow(max(0.0, 1 - epoch / max_epochs), 0.9)

    @classmethod
    def load_from_checkpoint(cls, checkpoint_path, args=None, wandb_logger=None):
        """
        Loads a model from a checkpoint file.

        Args:
            checkpoint_path: Path to the checkpoint file.
            args: Optional arguments to override saved ones.
            wandb_logger: Optional WandB logger instance.

        Returns:
            An instance of the model loaded from the checkpoint.
        """
        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')

        args = args or checkpoint.get('hyper_parameters', {}).get('args')
        config = checkpoint.get('hyper_parameters', {}).get('config')

        model = cls(args, config, wandb_logger)
        model.load_state_dict(checkpoint['state_dict'])

        # Restore the bookkeeping values stored by on_save_checkpoint, if present
        if 'hyper_parameters' in checkpoint:
            hyper_params = checkpoint['hyper_parameters']
            for attr in ['lr', 'best_val_loss', 'last_epoch']:
                setattr(model, attr, hyper_params.get(attr, getattr(model, attr)))

        return model

    def on_save_checkpoint(self, checkpoint):
        """
        Callback to save additional information in the checkpoint.

        Args:
            checkpoint: The checkpoint dictionary to be saved.
        """
        checkpoint['hyper_parameters'] = {
            'config': self.config,
            'lr': self.lr,
            'best_val_loss': self.best_val_loss,
            'last_epoch': self.current_epoch
        }

    def _get_one_hot_from_src(self, src_seg, flow, num_labels):
        """
        Converts source segmentation to one-hot encoding and applies deformation.

        Args:
            src_seg: Source segmentation.
            flow: Deformation flow.
            num_labels: Number of segmentation labels.

        Returns:
            Deformed one-hot encoded segmentation.
        """
        # Use the same label count for the encoding and for the channel loop below
        src_seg_onehot = get_one_hot(src_seg, num_labels)
        # Each class channel is warped separately, then the channels are re-stacked
        deformed_segs = [
            self.hvit.spatial_trans(src_seg_onehot[:, i:i+1, ...].float(), flow.float())
            for i in range(num_labels)
        ]
        return torch.cat(deformed_segs, dim=1)


# Main

In [None]:
import sys, os
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))


import argparse
import wandb
import torch
from lightning import Trainer
from lightning.pytorch.loggers import WandbLogger

# Trade some float32 matmul precision for speed ('medium' allows lower-precision internal matmul formats)
torch.set_float32_matmul_precision('medium')

from src import logger
from src.trainer import LiTHViT
from src.utils import read_yaml_file
from src.data.datasets import get_dataloader


def _str_to_bool(value):
    """
    Convert a command-line token to a bool.

    ``type=bool`` cannot be used for this: ``bool("False")`` is True because any
    non-empty string is truthy.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def parse_arguments():
    """
    Parse the command-line options of the training / inference script.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(description="Run training or inference")
    parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use. Use '-1' for all available GPUs.")
    parser.add_argument("--experiment_name", type=str, default="OASIS", help="Experiment name")
    parser.add_argument("--mode", choices=["train", "inference"], default="train", help="Mode to run: train or inference")
    parser.add_argument("--train_data_path", type=str, default="/dss/dssmcmlfs01/pr62la/pr62la-dss-0002/Mori/DATA/OASIS/OASIS_L2R_2021_task03/train", help="Path to the train set")
    parser.add_argument("--val_data_path", type=str, default="/dss/dssmcmlfs01/pr62la/pr62la-dss-0002/Mori/DATA/OASIS/OASIS_L2R_2021_task03/test", help="Path to the validation set")
    parser.add_argument("--test_data_path", type=str, default="/home/mori/HViT/OASIS_small/test", help="Path to the test set")
    parser.add_argument("--checkpoint_path", type=str, default=None, help="Path to the model/checkpoint_path to load")
    parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="Path to the best model")
    parser.add_argument("--mse_weights", type=float, default=1, help="MSE Loss weights")
    parser.add_argument("--dice_weights", type=float, default=1, help="Dice Loss weights")
    parser.add_argument("--grad_weights", type=float, default=0.02, help="Grad Loss weights")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--batch_size", type=int, default=2, help="Batch size")
    # Boolean switches take an explicit value, e.g. `--tgt2src_reg False`
    parser.add_argument("--tgt2src_reg", type=_str_to_bool, default=True, help="target to source registration during training")
    parser.add_argument("--num_workers", type=int, default=4, help="Number of workers")
    parser.add_argument("--max_epochs", type=int, default=1000, help="Maximum number of epochs")
    parser.add_argument("--num_labels", type=int, default=36, help="Number of labels")
    parser.add_argument("--precision", type=str, default='bf16', help="Precision")
    parser.add_argument("--hvit_light", type=_str_to_bool, default=True, help="Use HViT-Light")
    args = parser.parse_args()
    return args

def main():
    """Run HViT training or inference as selected on the command line.

    Reads the CLI arguments and ``./config/config.yaml``, creates one
    ``WandbLogger`` and one ``Trainer``, then either fits a ``LiTHViT`` model
    (``--mode train``) or evaluates a checkpoint on the test set
    (``--mode inference``). The wandb run is closed at the end.

    Raises:
        ValueError: If ``--mode inference`` is requested without
            ``--checkpoint_path``.
    """
    args = parse_arguments()

    # Fail fast: inference cannot proceed without a checkpoint, so report it
    # before any logger, trainer or dataloader is created.
    if args.mode == "inference" and not args.checkpoint_path:
        raise ValueError("No checkpoint found: inference mode requires --checkpoint_path")

    config = read_yaml_file("./config/config.yaml")

    # Initialize a single WandbLogger instance
    wandb_logger = WandbLogger(project="wandb_HViT", name=args.experiment_name)

    # Determine number of GPUs to use: a positive request is capped at the
    # number of visible CUDA devices; anything else is forwarded as -1.
    if args.num_gpus > 0:
        devices = min(int(args.num_gpus), torch.cuda.device_count())
    else:
        devices = -1
    gpu_description = "all available" if devices == -1 else devices
    print(f"Using {gpu_description} GPUs ...")

    # setup trainer for specified number of GPUs
    # NOTE: with devices == -1 the comparison below is false, so the strategy
    # stays "auto" and the choice is left to the trainer.
    trainer = Trainer(max_epochs=args.max_epochs,
                      logger=[wandb_logger],
                      precision=args.precision,
                      accelerator="gpu",
                      devices=devices,
                      strategy="ddp" if devices > 1 else "auto")  # Use "auto" for single GPU

    # train/test
    if args.mode == "train":
        # Train/validation loaders are only needed (and only built) here, so
        # inference runs never touch the training data directories.
        train_dataloader = get_dataloader(data_path=args.train_data_path,
                                          input_dim=config["data_size"],
                                          batch_size=args.batch_size,
                                          is_pair=False)

        val_dataloader = get_dataloader(data_path=args.val_data_path,
                                        input_dim=config["data_size"],
                                        batch_size=args.batch_size,
                                        shuffle=False,
                                        is_pair=True)

        if args.resume_from_checkpoint:
            model = LiTHViT.load_from_checkpoint(args.resume_from_checkpoint, args=args, wandb_logger=wandb_logger)
            print(f"Resuming training from epoch {model.last_epoch + 1}")
        else:
            model = LiTHViT(args, config, wandb_logger=wandb_logger)
            print("Starting new training run")
        logger.info("Starting training")
        # ckpt_path is None for a fresh run and the checkpoint path when resuming.
        trainer.fit(model,
                    train_dataloaders=train_dataloader,
                    val_dataloaders=val_dataloader,
                    ckpt_path=args.resume_from_checkpoint)

    elif args.mode == "inference":
        logger.info("Starting inference")

        test_dataloader = get_dataloader(data_path=args.test_data_path,
                                         input_dim=config["data_size"],
                                         is_pair=True,
                                         batch_size=args.batch_size,
                                         shuffle=False)

        # checkpoint_path is guaranteed to be set by the check at the top.
        model = LiTHViT.load_from_checkpoint(args.checkpoint_path, args=args, wandb_logger=wandb_logger)
        print(f"Checkpoint loaded. Resuming from epoch {model.last_epoch + 1}")
        trainer.test(model, dataloaders=test_dataloader)

    # Finish the wandb run
    wandb.finish()

# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()