<a href="https://colab.research.google.com/github/alim98/Thesis/blob/main/h_vit_working.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the required packages
# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
!pip install einops timm lightning wandb monai gitpython




# Correct Code

In [None]:
import logging
import os

class Logger:
    """Small wrapper around a ``logging.Logger`` that logs to console and file.

    The wrapped logger is ``logging.getLogger(__name__)`` with its level set to
    ``INFO``. Records go to the console and to ``<save_dir>/logfile.log``.

    ``logging.getLogger`` hands back the same object for the same name, so
    handlers are attached only when an equivalent one is not already present.
    Without that guard, every new ``Logger`` (for example after re-running a
    notebook cell) would add another pair of handlers and each message would
    be written several times.

    Args:
        save_dir: Directory that receives ``logfile.log``; created if missing.
    """

    def __init__(self, save_dir):
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)

        # Create the directory if it doesn't exist
        os.makedirs(save_dir, exist_ok=True)
        # FileHandler stores an absolute path in ``baseFilename``; compare like with like.
        log_path = os.path.abspath(os.path.join(save_dir, "logfile.log"))

        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

        # FileHandler subclasses StreamHandler, so exclude it when looking for a console handler.
        has_console = any(
            isinstance(handler, logging.StreamHandler)
            and not isinstance(handler, logging.FileHandler)
            for handler in self.logger.handlers
        )
        if not has_console:
            console_handler = logging.StreamHandler()
            console_handler.setFormatter(formatter)
            self.logger.addHandler(console_handler)

        has_file = any(
            isinstance(handler, logging.FileHandler)
            and handler.baseFilename == log_path
            for handler in self.logger.handlers
        )
        if not has_file:
            file_handler = logging.FileHandler(log_path)
            file_handler.setFormatter(formatter)
            self.logger.addHandler(file_handler)

    def info(self, message):
        """Log ``message`` at INFO level."""
        self.logger.info(message)

    def warning(self, message):
        """Log ``message`` at WARNING level."""
        self.logger.warning(message)

    def error(self, message):
        """Log ``message`` at ERROR level."""
        self.logger.error(message)

    def debug(self, message):
        """Log ``message`` at DEBUG level (filtered out while the level is INFO)."""
        self.logger.debug(message)


In [None]:
import os
import sys
import math
import yaml
import glob
import pickle
import random
import logging
from functools import reduce
from typing import Tuple, Dict, Any, List, Set, Optional, Union, Callable, Type

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import lightning as L
from lightning import LightningModule, Trainer
from lightning.pytorch.loggers import WandbLogger

from einops import rearrange
import timm
import wandb
import monai


In [None]:
# Utility Functions and Classes

def read_yaml_file(file_path):
    """
    Parse the YAML document stored at ``file_path``.

    Returns the value produced by ``yaml.safe_load``. If the parser raises
    ``yaml.YAMLError`` the error is printed and ``None`` is returned instead.
    Errors from opening the file are not handled here and propagate.
    """
    with open(file_path, 'r') as stream:
        try:
            return yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            print(f"Error reading YAML file: {exc}")
    return None

class Logger:
    """Small wrapper around a ``logging.Logger`` that logs to console and file.

    The wrapped logger is ``logging.getLogger(__name__)`` with its level set to
    ``INFO``. Records go to the console and to ``<save_dir>/logfile.log``.

    ``logging.getLogger`` hands back the same object for the same name, so
    handlers are attached only when an equivalent one is not already present.
    Without that guard, every new ``Logger`` (for example after re-running a
    notebook cell) would add another pair of handlers and each message would
    be written several times.

    Args:
        save_dir: Directory that receives ``logfile.log``; created if missing.
    """

    def __init__(self, save_dir):
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)

        # Create the directory if it doesn't exist
        os.makedirs(save_dir, exist_ok=True)
        # FileHandler stores an absolute path in ``baseFilename``; compare like with like.
        log_path = os.path.abspath(os.path.join(save_dir, "logfile.log"))

        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

        # FileHandler subclasses StreamHandler, so exclude it when looking for a console handler.
        has_console = any(
            isinstance(handler, logging.StreamHandler)
            and not isinstance(handler, logging.FileHandler)
            for handler in self.logger.handlers
        )
        if not has_console:
            console_handler = logging.StreamHandler()
            console_handler.setFormatter(formatter)
            self.logger.addHandler(console_handler)

        has_file = any(
            isinstance(handler, logging.FileHandler)
            and handler.baseFilename == log_path
            for handler in self.logger.handlers
        )
        if not has_file:
            file_handler = logging.FileHandler(log_path)
            file_handler.setFormatter(formatter)
            self.logger.addHandler(file_handler)

    def info(self, message):
        """Log ``message`` at INFO level."""
        self.logger.info(message)

    def warning(self, message):
        """Log ``message`` at WARNING level."""
        self.logger.warning(message)

    def error(self, message):
        """Log ``message`` at ERROR level."""
        self.logger.error(message)

    def debug(self, message):
        """Log ``message`` at DEBUG level (filtered out while the level is INFO)."""
        self.logger.debug(message)

def count_parameters(model):
    """Return the number of scalar elements in ``model``'s trainable parameters.

    Only parameters whose ``requires_grad`` flag is set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def get_one_hot(inp_seg, num_labels):
    """One-hot encode a 5-D label volume with the classes on axis 1.

    ``inp_seg`` is cast to ``long``, one-hot encoded into ``num_labels``
    classes, squeezed on axis 1 and permuted so the class axis comes right
    after the batch axis.
    """
    # The unpack is kept purely as a rank check: a non-5-D input fails here.
    _batch, _channels, _height, _width, _depth = inp_seg.shape
    labels = inp_seg.long()
    encoded = nn.functional.one_hot(labels, num_classes=num_labels)
    # Drop axis 1 when it is a singleton, then bring the trailing class axis forward.
    without_channel = encoded.squeeze(dim=1)
    return without_channel.permute(0, 4, 1, 2, 3).contiguous()


In [None]:
# Loss Functions

class Grad3D(torch.nn.Module):
    """
    Finite-difference gradient penalty over the last three axes of a 5-D tensor.

    ``penalty='l2'`` squares the absolute differences; any other value leaves
    them as absolute differences. ``loss_mult``, when not ``None``, scales the
    result.
    """
    def __init__(self, penalty='l1', loss_mult=None):
        super().__init__()
        self.penalty = penalty
        self.loss_mult = loss_mult

    def forward(self, y_pred):
        """Return the mean neighbour difference averaged over the three axes."""
        # Neighbouring-element differences along axes 2, 3 and 4 respectively.
        step_y = (y_pred[:, :, 1:, :, :] - y_pred[:, :, :-1, :, :]).abs()
        step_x = (y_pred[:, :, :, 1:, :] - y_pred[:, :, :, :-1, :]).abs()
        step_z = (y_pred[:, :, :, :, 1:] - y_pred[:, :, :, :, :-1]).abs()

        if self.penalty == 'l2':
            step_y, step_x, step_z = (s * s for s in (step_y, step_x, step_z))

        grad = (step_x.mean() + step_y.mean() + step_z.mean()) / 3.0

        if self.loss_mult is not None:
            grad *= self.loss_mult
        return grad

class DiceLoss(nn.Module):
    """Dice loss: one minus the mean Dice coefficient over batch and classes.

    Args:
        num_class: Number of classes used to one-hot encode the labels.
    """
    def __init__(self, num_class=36):
        super().__init__()
        self.num_class = num_class

    def forward(self, y_pred, y_true):
        """Compute the loss.

        Args:
            y_pred: Per-class predictions; reduced over axes 2, 3 and 4, so it
                must broadcast against the one-hot labels laid out as
                (batch, class, *spatial).
            y_true: Label volume with a singleton axis 1. Non-integer dtypes
                are accepted and cast to ``long``, matching ``get_one_hot``.

        Returns:
            Scalar tensor ``1 - mean(dice)``.
        """
        # one_hot only accepts integer index tensors; cast so that float-typed
        # label volumes (e.g. after a resampling transform) do not raise.
        y_true = nn.functional.one_hot(y_true.long(), num_classes=self.num_class)
        y_true = torch.squeeze(y_true, 1)
        # Move the trailing class axis to position 1.
        y_true = y_true.permute(0, 4, 1, 2, 3).contiguous()
        intersection = (y_pred * y_true).sum(dim=[2, 3, 4])
        union = torch.pow(y_pred, 2).sum(dim=[2, 3, 4]) + torch.pow(y_true, 2).sum(dim=[2, 3, 4])
        # The small constant keeps the ratio finite when a class is absent from both inputs.
        dsc = (2. * intersection) / (union + 1e-5)
        return 1 - torch.mean(dsc)

def DiceScore(y_pred, y_true, num_class):
    """Return the per-sample, per-class Dice coefficient.

    Args:
        y_pred: Per-class predictions; reduced over axes 2, 3 and 4, so it
            must broadcast against the one-hot labels laid out as
            (batch, class, *spatial).
        y_true: Label volume with a singleton axis 1. Non-integer dtypes are
            accepted and cast to ``long``, matching ``get_one_hot``.
        num_class: Number of classes used for the one-hot encoding.

    Returns:
        Tensor of Dice coefficients with axes (batch, class).
    """
    # one_hot only accepts integer index tensors; cast so float-typed labels work too.
    y_true = nn.functional.one_hot(y_true.long(), num_classes=num_class)
    y_true = torch.squeeze(y_true, 1)
    # Move the trailing class axis to position 1.
    y_true = y_true.permute(0, 4, 1, 2, 3).contiguous()
    intersection = (y_pred * y_true).sum(dim=[2, 3, 4])
    union = torch.pow(y_pred, 2).sum(dim=[2, 3, 4]) + torch.pow(y_true, 2).sum(dim=[2, 3, 4])
    # The small constant keeps the ratio finite when a class is absent from both inputs.
    dsc = (2. * intersection) / (union + 1e-5)
    return dsc

# Registry of ready-to-call loss modules, keyed by short name.
# The instances are created once here, at cell execution time:
#   "mse"  -> nn.MSELoss with default arguments
#   "dice" -> DiceLoss with its default num_class
#   "grad" -> Grad3D using the squared ('l2') difference penalty
loss_functions = {
    "mse": nn.MSELoss(),
    "dice": DiceLoss(),
    "grad": Grad3D(penalty='l2')
}


In [None]:
# Data Loading

class OASIS_Dataset(Dataset):
    """Dataset of pickled (image, label) volumes used for registration.

    Two modes, selected by ``is_pair``:

    * ``is_pair=True``: each file holds ``(src, tgt, src_lbl, tgt_lbl)`` and the
      dataset length is the number of files.
    * ``is_pair=False``: each file holds ``(image, label)``; every item is built
      from two distinct randomly chosen files (the index is ignored) and the
      dataset length is ``num_steps``.

    Args:
        input_dim: Spatial size passed to the MONAI ``Resize`` transforms.
        data_path: Directory searched for ``*.{ext}`` files.
        num_steps: Reported length in unpaired mode.
        is_pair: Whether the files already contain source/target pairs.
        ext: File extension (without the dot) to collect.
    """
    def __init__(self, input_dim, data_path, num_steps=1000, is_pair: bool = False, ext="pkl"):
        # glob returns files in arbitrary, platform-dependent order; sorting
        # makes the index -> file mapping reproducible across runs and machines.
        self.paths = sorted(glob.glob(os.path.join(data_path, f"*.{ext}")))
        self.num_steps = num_steps
        self.input_dim = input_dim
        self.is_pair = is_pair

        # Labels are resized with nearest-neighbour so no new label values are invented.
        self.transforms_mask = monai.transforms.Compose([
            monai.transforms.Resize(spatial_size=input_dim, mode="nearest")
        ])
        self.transforms_image = monai.transforms.Compose([
            monai.transforms.Resize(spatial_size=input_dim)
        ])

    def _pkload(self, filename: str) -> tuple:
        """
        Load a pickled file and return its contents.

        Raises:
            FileNotFoundError: If ``filename`` does not exist.
            pickle.UnpicklingError: If the file cannot be unpickled.
        """
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # this dataset at files from a trusted source.
        try:
            with open(filename, 'rb') as file:
                return pickle.load(file)
        except FileNotFoundError as err:
            raise FileNotFoundError(f"The file {filename} was not found.") from err
        except pickle.UnpicklingError as err:
            raise pickle.UnpicklingError(f"Error unpickling the file {filename}.") from err

    def __getitem__(self, index):
        """Return ``(src, tgt, src_lbl, tgt_lbl)`` tensors, each with a leading channel axis."""
        if self.is_pair:
            src, tgt, src_lbl, tgt_lbl = self._pkload(self.paths[index])
        else:
            # Unpaired mode: draw two different files; ``index`` is not used.
            src_path, tgt_path = random.sample(self.paths, 2)
            src, src_lbl = self._pkload(src_path)
            tgt, tgt_lbl = self._pkload(tgt_path)

        # Images become float tensors, labels long tensors; unsqueeze adds the channel axis.
        src = torch.from_numpy(src).float().unsqueeze(0)
        src_lbl = torch.from_numpy(src_lbl).long().unsqueeze(0)
        tgt = torch.from_numpy(tgt).float().unsqueeze(0)
        tgt_lbl = torch.from_numpy(tgt_lbl).long().unsqueeze(0)

        src = self.transforms_image(src)
        tgt = self.transforms_image(tgt)
        src_lbl = self.transforms_mask(src_lbl)
        tgt_lbl = self.transforms_mask(tgt_lbl)

        return src, tgt, src_lbl, tgt_lbl

    def __len__(self):
        """Number of files in paired mode, ``num_steps`` otherwise."""
        return self.num_steps if not self.is_pair else len(self.paths)

def get_dataloader(data_path, input_dim, batch_size, shuffle: bool = True, is_pair: bool = False,
                   num_workers: int = 4):
    """Build a ``DataLoader`` over an ``OASIS_Dataset``.

    Args:
        data_path: Directory containing the pickled samples.
        input_dim: Spatial size the dataset resizes volumes to.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader shuffles indices.
        is_pair: Forwarded to ``OASIS_Dataset`` (paired vs. unpaired files).
        num_workers: Number of loader worker processes. Defaults to 4, the
            value that used to be hard-coded; lower it on machines with fewer
            CPU cores.

    Returns:
        The configured ``DataLoader``.
    """
    ds = OASIS_Dataset(input_dim=input_dim, data_path=data_path, is_pair=is_pair)
    dataloader = DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
    return dataloader


In [None]:
# Model Components: Blocks and Transformer Layers

# Drop Path (Stochastic Depth) Implementation
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Randomly zero whole samples of ``x`` (stochastic depth).

    ``x`` is returned untouched when ``drop_prob`` is zero or ``training`` is
    false. Otherwise one Bernoulli(keep) draw is made per entry along axis 0
    and broadcast over the remaining axes; with ``scale_by_keep`` the kept
    samples are divided by the keep probability.
    """
    active = training and drop_prob != 0.
    if not active:
        return x

    keep_prob = 1 - drop_prob
    # One mask value per sample, broadcastable across every other axis.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if scale_by_keep and keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask

class timm_DropPath(nn.Module):
    """Module wrapper around ``drop_path`` (per-sample stochastic depth).

    Dropping is only active while the module is in training mode.
    """
    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super().__init__()
        self.scale_by_keep = scale_by_keep
        self.drop_prob = drop_prob

    def forward(self, x):
        """Apply ``drop_path`` using this module's settings and training flag."""
        return drop_path(
            x,
            self.drop_prob,
            self.training,
            self.scale_by_keep,
        )

    def extra_repr(self):
        """Show the drop probability, rounded to three decimals, in ``repr``."""
        rounded = round(self.drop_prob, 3)
        return f'drop_prob={rounded:0.3f}'

# Truncated Normal Initialization
def _trunc_normal_(tensor, mean, std, a, b):
    def norm_cdf(x):
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    l = norm_cdf((a - mean) / std)
    u = norm_cdf((b - mean) / std)

    tensor.uniform_(2 * l - 1, 2 * u - 1)
    tensor.erfinv_()
    tensor.mul_(std * math.sqrt(2.))
    tensor.add_(mean)
    tensor.clamp_(min=a, max=b)
    return tensor

def timm_trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    """Fill ``tensor`` in place from a truncated normal, with autograd disabled.

    Delegates to ``_trunc_normal_`` inside ``torch.no_grad()`` and returns
    whatever that helper returns.
    """
    with torch.no_grad():
        filled = _trunc_normal_(tensor, mean, std, a, b)
    return filled

# Convolutional Block with ReLU and Normalization
class Conv3dReLU(nn.Sequential):
    """Bias-free 3-D convolution, then a normalisation layer, then LeakyReLU.

    ``use_batchnorm`` picks ``BatchNorm3d`` when true and ``InstanceNorm3d``
    otherwise. The activation operates in place.
    """
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            padding=0,
            stride=1,
            use_batchnorm=True,
    ):
        norm_cls = nn.BatchNorm3d if use_batchnorm else nn.InstanceNorm3d
        layers = (
            nn.Conv3d(
                in_channels,
                out_channels,
                kernel_size,
                stride=stride,
                padding=padding,
                bias=False,
            ),
            norm_cls(out_channels),
            nn.LeakyReLU(inplace=True),
        )
        super().__init__(*layers)

# Normalization and Activation Getters
def get_norm(name, **kwargs):
    """Build a normalisation layer from a case-insensitive name.

    ``'batchnorm2d'`` and ``'instance'`` resolve to ``nn.BatchNorm{ndims}d`` and
    ``nn.InstanceNorm{ndims}d`` (``ndims`` is the module-level dimensionality)
    and receive ``kwargs``. ``'none'`` gives ``nn.Identity()``. Any other name
    raises ``NotImplementedError``.
    """
    key = name.lower()
    if key == 'none':
        return nn.Identity()
    if key == 'batchnorm2d':
        batch_norm_cls = getattr(nn, f'BatchNorm{ndims}d')
        return batch_norm_cls(**kwargs)
    if key == 'instance':
        instance_norm_cls = getattr(nn, f'InstanceNorm{ndims}d')
        return instance_norm_cls(**kwargs)
    raise NotImplementedError(f"Normalization '{name}' not implemented.")

def get_activation(name, **kwargs):
    """Build an activation module from a case-insensitive name.

    Recognised names are ``'relu'``, ``'gelu'`` and ``'none'`` (identity).
    ``kwargs`` are accepted but not used. Any other name raises
    ``NotImplementedError``.
    """
    factories = {
        'relu': nn.ReLU,
        'gelu': nn.GELU,
        'none': nn.Identity,
    }
    factory = factories.get(name.lower())
    if factory is None:
        raise NotImplementedError(f"Activation '{name}' not implemented.")
    return factory()

def prod_func(Vec):
    """Return the product of all elements of ``Vec``.

    Uses ``reduce`` without an initial value, so an empty ``Vec`` raises
    ``TypeError`` and a single-element ``Vec`` yields that element.
    """
    def _multiply(accumulated, factor):
        return accumulated * factor

    return reduce(_multiply, Vec)

def downsampler_fn(data, out_size):
    """
    Resample ``data`` to spatial size ``out_size`` with trilinear interpolation
    (``align_corners=False``).
    """
    resized = nn.functional.interpolate(
        data,
        size=out_size,
        mode='trilinear',
        align_corners=False,
    )
    return resized

# Spatial Transformer
class SpatialTransformer(nn.Module):
    """
    N-D Spatial Transformer
    Obtained from https://github.com/voxelmorph/voxelmorph

    Warps a source volume with a dense displacement field given in voxel units.

    Args:
        size: Spatial size of the volumes; fixes the sampling grid.
        mode: Interpolation mode forwarded to ``F.grid_sample``.
    """
    def __init__(self, size, mode='bilinear'):
        super().__init__()

        self.mode = mode

        # Identity sampling grid of voxel indices, laid out as (1, ndim, *size).
        # 'ij' indexing is the behaviour the old implicit default provided;
        # passing it explicitly avoids the deprecation warning.
        vectors = [torch.arange(0, s) for s in size]
        grids = torch.meshgrid(*vectors, indexing='ij')
        grid = torch.stack(grids).unsqueeze(0).float()

        # Register the grid as a buffer so it follows the module across devices.
        self.register_buffer('grid', grid)

    def forward(self, src, flow):
        """Sample ``src`` at ``grid + flow``.

        Args:
            src: Volume to warp, laid out as (batch, channels, *spatial).
            flow: Displacement field in voxels, laid out as (batch, ndim, *spatial).

        Returns:
            The warped volume produced by ``F.grid_sample``.
        """
        # new locations
        new_locs = self.grid + flow
        shape = flow.shape[2:]

        # Normalize grid values to [-1, 1]: index 0 -> -1 and index size-1 -> +1.
        for i, extent in enumerate(shape):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (extent - 1) - 0.5)

        # grid_sample wants the coordinate axis last and ordered x, y(, z),
        # i.e. the reverse of the tensor's spatial axis order.
        if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)
            new_locs = new_locs[..., [1, 0]]
        elif len(shape) == 3:
            new_locs = new_locs.permute(0, 2, 3, 4, 1)
            new_locs = new_locs[..., [2, 1, 0]]

        # The normalization above maps voxel *centres* onto [-1, 1], which is
        # the align_corners=True convention. With align_corners=False a zero
        # flow would not reproduce ``src``.
        return F.grid_sample(src, new_locs, align_corners=True, mode=self.mode)


# Model

In [None]:
# Model Components: HViT and Related Classes
from torch import Tensor

# Number of spatial dimensions. Read as a module-level global to pick the
# nn.Conv{ndims}d / nn.BatchNorm{ndims}d / nn.InstanceNorm{ndims}d classes and
# to expand scalar patch sizes to one entry per spatial axis.
ndims = 3  # Spatial dimensions

class Attention(nn.Module):
    """
    Attention module for hierarchical vision transformer.
    Implements both local and global attention mechanisms.

    * ``"local"``: queries, keys and values are all projected from the input.
    * ``"global"``: keys and values are projected from the input while the
      queries are supplied by the caller through ``q_ms``.

    Args:
        dim: Channel dimension of the tokens; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        patch_size: Window size; a scalar is expanded to one entry per spatial axis.
        attention_type: ``"local"`` or ``"global"``.
        qkv_bias: Whether the projection producing q/k/v has a bias term.
        qk_scale: Query scaling factor; falls back to ``head_dim ** -0.5`` when falsy.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability after the output projection.

    Raises:
        NotImplementedError: If ``attention_type`` is not recognised.
    """
    def __init__(
        self,
        dim: int,
        num_heads: int,
        patch_size: Union[int, List[int]],
        attention_type: str = "local",
        qkv_bias: bool = True,
        qk_scale: Optional[float] = None,
        attn_drop: float = 0.,
        proj_drop: float = 0.
    ) -> None:
        super().__init__()

        self.dim = dim
        self.num_heads = num_heads
        self.patch_size = [patch_size] * ndims if isinstance(patch_size, int) else patch_size
        self.attention_type = attention_type

        assert dim % num_heads == 0, "Dimension must be divisible by number of heads"
        self.head_dim = dim // num_heads
        self.scale = qk_scale or self.head_dim ** -0.5

        # Local attention needs q, k and v (3 * dim); global attention only k and v.
        if self.attention_type == "local":
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        elif self.attention_type == "global":
            self.qkv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        else:
            raise NotImplementedError(f"Attention type '{self.attention_type}' not implemented.")

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor, q_ms: Optional[Tensor] = None) -> Tensor:
        """Attend over the tokens of each window.

        Args:
            x: Window tokens laid out as (num_windows * batch, tokens, channels).
            q_ms: External queries, required for ``"global"`` attention and
                ignored for ``"local"``.

        Returns:
            Tensor with the same layout as ``x``.
        """
        B_, N, C = x.size()

        if self.attention_type == "local":
            # (B_, N, 3C) -> (3, B_, heads, N, head_dim)
            qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]
            q = q * self.scale
        else:
            B = q_ms.size()[0]
            # (B_, N, 2C) -> (2, B_, heads, N, head_dim)
            kv = self.qkv(x).reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            k, v = kv[0], kv[1]
            q = self._process_global_query(q_ms, B, B_, N, C)

        attn = (q @ k.transpose(-2, -1))
        attn = F.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)

        # Merge the heads back into the channel axis.
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x

    def _process_global_query(self, q_ms: Tensor, B: int, B_: int, N: int, C: int) -> Tensor:
        """Tile the external queries so there is one per key/value window.

        The ``B`` query entries are repeated cyclically until ``B_`` entries
        exist, then scaled by ``self.scale``. Building the result from
        ``q_ms`` itself (rather than copying into a freshly allocated
        float32 buffer) keeps the query's dtype and device, and also covers
        ``B_ < B``, where the result is simply the first ``B_`` queries.

        NOTE(review): ``q_ms`` is reshaped straight to
        (B, heads, N, head_dim) without a permute; kept as in the original.
        """
        head_dim = C // self.num_heads
        q_heads = q_ms.reshape(B, self.num_heads, N, head_dim)

        # Ceil division: enough whole copies to cover B_, trimmed afterwards.
        repeats = -(-B_ // B)
        q = q_heads.repeat(repeats, 1, 1, 1)[:B_]

        return q * self.scale

def get_patches(x: Tensor, patch_size: int) -> Tuple[Tensor, int, int, int]:
    """
    Split a channels-last volume into non-overlapping cubic windows.

    When a spatial extent is not a multiple of ``patch_size`` the volume is
    first resampled (via ``downsampler_fn``) to the largest multiples that fit.
    Returns the windows stacked along axis 0 together with the spatial extents
    that were actually partitioned.
    """
    B, H, W, D, C = x.size()

    needs_resize = any(extent % patch_size for extent in (H, W, D))
    if needs_resize:
        target = [(extent // patch_size) * patch_size for extent in (H, W, D)]
        # Resampling works channels-first; convert there and back.
        channels_first = x.permute(0, 4, 1, 2, 3)
        x = downsampler_fn(channels_first, target).permute(0, 2, 3, 4, 1)
        B, H, W, D, C = x.size()

    n_h, n_w, n_d = H // patch_size, W // patch_size, D // patch_size
    blocked = x.view(B, n_h, patch_size, n_w, patch_size, n_d, patch_size, C)

    # Gather the three window-index axes first, then the three within-window axes.
    gathered = blocked.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous()
    windows = gathered.view(-1, patch_size, patch_size, patch_size, C)

    return windows, H, W, D

def get_image(windows: Tensor, patch_size: int, Hatt: int, Watt: int, Datt: int, H: int, W: int, D: int) -> Tensor:
    """
    Reassemble stacked cubic windows into a channels-last volume.

    ``Hatt``/``Watt``/``Datt`` are the extents the windows tile exactly. When
    they differ from ``H``/``W``/``D`` the assembled volume is resampled to
    ``[H, W, D]`` via ``downsampler_fn``.
    """
    windows_per_volume = (Hatt * Watt * Datt) // (patch_size ** 3)
    B = int(windows.size(0) / windows_per_volume)

    n_h, n_w, n_d = Hatt // patch_size, Watt // patch_size, Datt // patch_size
    blocked = windows.view(B, n_h, n_w, n_d, patch_size, patch_size, patch_size, -1)
    # Interleave each window-index axis with its within-window axis.
    interleaved = blocked.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous()
    x = interleaved.view(B, Hatt, Watt, Datt, -1)

    same_size = (H == Hatt) and (W == Watt) and (D == Datt)
    if not same_size:
        x = downsampler_fn(x.permute(0, 4, 1, 2, 3), [H, W, D]).permute(0, 2, 3, 4, 1)
    return x

class ViTBlock(nn.Module):
    """
    Vision Transformer Block.

    Applies windowed attention followed by a convolutional feed-forward stage,
    each wrapped in a residual connection with optional layer scale and
    stochastic depth.

    Args:
        embed_dim: Channel dimension of the tokens.
        input_dims: Spatial size of the input; only used to compute ``num_windows``.
        num_heads: Number of attention heads.
        mlp_type: Accepted for interface compatibility; not used here.
        patch_size: Edge length of the cubic attention windows.
        mlp_ratio: Expansion factor for the hidden channels of the conv stage.
        qkv_bias: Whether the attention projections carry a bias.
        qk_scale: Optional override of the query scaling factor.
        drop: Dropout probability after the attention output projection.
        attn_drop: Dropout probability on the attention weights.
        drop_path: Stochastic-depth probability; 0 disables it.
        act_layer: Accepted for interface compatibility; not used here.
        attention_type: ``"local"`` or ``"global"``, forwarded to ``Attention``.
        norm_layer: Callable building a normalisation layer from ``embed_dim``.
        layer_scale: Initial value of the learnable residual scales, or
            ``None`` for a fixed scale of 1.
    """
    def __init__(self,
                 embed_dim: int,
                 input_dims: List[int],
                 num_heads: int,
                 mlp_type: str,
                 patch_size: int,
                 mlp_ratio: float,
                 qkv_bias: bool,
                 qk_scale: Optional[float],
                 drop: float,
                 attn_drop: float,
                 drop_path: float,
                 act_layer: str,
                 attention_type: str,
                 norm_layer: Callable[..., nn.Module],
                 layer_scale: Optional[float]):
        super().__init__()
        self.patch_size = patch_size
        self.num_windows = prod_func([d // patch_size for d in input_dims])

        self.norm1 = norm_layer(embed_dim)
        self.attn = Attention(
            dim=embed_dim,
            num_heads=num_heads,
            patch_size=patch_size,
            attention_type=attention_type,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        self.drop_path = timm_DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(embed_dim)

        # Feed-forward stage implemented as a 3x3x3 conv block that expands the channels...
        self.mlp = Conv3dReLU(
            in_channels=embed_dim,
            out_channels=int(embed_dim * mlp_ratio),
            kernel_size=3,  # Assuming kernel_size=3 for MLP
            padding=1,
            stride=1,
            use_batchnorm=True,
        )

        # ...followed by a 1x1x1 projection back to embed_dim so the residual add lines up.
        self.proj = nn.Conv3d(
            in_channels=int(embed_dim * mlp_ratio),
            out_channels=embed_dim,
            kernel_size=1
        )
        self.layer_scale = layer_scale is not None and isinstance(layer_scale, (int, float))
        if self.layer_scale:
            self.gamma1 = nn.Parameter(layer_scale * torch.ones(embed_dim), requires_grad=True)
            self.gamma2 = nn.Parameter(layer_scale * torch.ones(embed_dim), requires_grad=True)
        else:
            self.gamma1 = 1.0
            self.gamma2 = 1.0

    def forward(self, x: Tensor, q_ms: Optional[Tensor]) -> Tensor:
        """Run the block on a channels-last volume.

        Args:
            x: Input laid out as (batch, H, W, D, channels).
            q_ms: Pre-windowed external queries for global attention, or ``None``.

        Returns:
            Tensor with the same layout as ``x``.
        """
        B, H, W, D, C = x.size()
        shortcut = x

        # Normalize, split into windows and flatten each window into a token sequence.
        x = self.norm1(x)
        x_windows, Hatt, Watt, Datt = get_patches(x, self.patch_size)
        x_windows = x_windows.view(-1, self.patch_size ** 3, C)

        # Compute attention and reconstruct image
        attn_windows = self.attn(x_windows, q_ms)
        x = get_image(attn_windows, self.patch_size, Hatt, Watt, Datt, H, W, D)

        # Apply shortcut and drop path
        x = shortcut + self.drop_path(self.gamma1 * x)

        # The conv stage expects channels-first; permute in and back out.
        # (The per-call shape print that used to sit here was debugging
        # residue and flooded stdout on every forward pass.)
        x_mlp_input = self.norm2(x).permute(0, 4, 1, 2, 3)
        x_mlp_output = self.mlp(x_mlp_input)
        x_mlp_output = self.proj(x_mlp_output).permute(0, 2, 3, 4, 1)

        # Add MLP output with drop path and gamma scaling
        x = x + self.drop_path(self.gamma2 * x_mlp_output)
        return x

class PatchEmbed(nn.Module):
    """
    Convolutional embedding: a single ``nn.Conv{ndims}d`` followed by dropout.

    The convolution class is looked up from the module-level ``ndims``. All
    convolution hyper-parameters are forwarded as given; ``drop_rate`` is the
    dropout probability applied to the convolution output.
    """
    def __init__(self, in_chans: int = 3, out_chans: int = 32,
                 drop_rate: float = 0,
                 kernel_size: int = 3,
                 stride: int = 1, padding: int = 1,
                 dilation: int = 1, groups: int = 1, bias: bool = False) -> None:
        super().__init__()

        conv_cls = getattr(nn, f"Conv{ndims}d")
        conv_kwargs = dict(
            in_channels=in_chans,
            out_channels=out_chans,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        self.proj = conv_cls(**conv_kwargs)

        self.drop = nn.Dropout(p=drop_rate)

    def forward(self, x: Tensor) -> Tensor:
        """Project ``x`` with the convolution, then apply dropout."""
        embedded = self.proj(x)
        return self.drop(embedded)

class ViTLayer(nn.Module):
    """
    Vision Transformer Layer.

    A stack of ``depth`` ``ViTBlock`` modules applied to a channels-first
    volume, with the layer input combined with the stack output at the end.

    Args:
        attention_type: ``"local"`` or ``"global"``, forwarded to every block.
        dim: Channel dimension of the tokens.
        dim_out: Accepted for interface compatibility; not used here.
        depth: Number of blocks.
        input_dims: Spatial size of the input, forwarded to every block.
        num_heads: Number of attention heads.
        patch_size: Edge length of the cubic attention windows.
        mlp_type: Forwarded to every block.
        mlp_ratio: Forwarded to every block.
        qkv_bias: Forwarded to every block.
        qk_scale: Forwarded to every block.
        drop: Forwarded to every block.
        attn_drop: Forwarded to every block.
        drop_path: One stochastic-depth rate for all blocks, or a list with
            one rate per block.
        norm_layer: Forwarded to every block.
        norm_type: Accepted for interface compatibility; not used here.
        layer_scale: Forwarded to every block.
        act_layer: Forwarded to every block.
    """
    def __init__(
        self,
        attention_type: str,
        dim: int,
        dim_out: int,
        depth: int,
        input_dims: List[int],
        num_heads: int,
        patch_size: int,
        mlp_type: str,
        mlp_ratio: float,
        qkv_bias: bool,
        qk_scale: Optional[float],
        drop: float,
        attn_drop: float,
        drop_path: Union[float, List[float]],
        norm_layer: Callable[..., nn.Module],
        norm_type: str,
        layer_scale: Optional[float],
        act_layer: str
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = dim
        self.input_dims = input_dims
        self.blocks = nn.ModuleList([
            ViTBlock(
                embed_dim=dim,
                input_dims=input_dims,
                num_heads=num_heads,
                mlp_type=mlp_type,
                patch_size=patch_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                attention_type=attention_type,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[k] if isinstance(drop_path, list) else drop_path,
                act_layer=act_layer,
                norm_layer=norm_layer,
                layer_scale=layer_scale
            )
            for k in range(depth)
        ])

    def forward(self, inp: Tensor, q_ms: Optional[Tensor], CONCAT_ok: bool) -> Tensor:
        """Run all blocks and merge the result with the layer input.

        Args:
            inp: Input laid out as (batch, channels, H, W, D).
            q_ms: Optional external query volume in the same layout; ``None``
                means the blocks run without external queries.
            CONCAT_ok: When true the input and output are concatenated along
                the last axis; otherwise they are summed.

        Returns:
            The merged channels-first tensor.
        """
        x = inp.clone()
        x = rearrange(x, 'b c h w d -> b h w d c')

        # The query windows depend only on q_ms, the patch size and the channel
        # count, none of which change between blocks, so build them once
        # instead of once per block.
        q_ms_patches = None
        if q_ms is not None:
            q_ms = rearrange(q_ms, 'b c h w d -> b h w d c')
            if len(self.blocks) > 0:
                q_ms_patches, _, _, _ = get_patches(q_ms, self.patch_size)
                q_ms_patches = q_ms_patches.view(-1, self.patch_size ** ndims, x.size()[-1])

        for blk in self.blocks:
            x = blk(x, q_ms_patches)

        x = rearrange(x, 'b h w d c -> b c h w d')

        if CONCAT_ok:
            # NOTE(review): dim=-1 is the last spatial axis in this
            # channels-first layout, not the channel axis - confirm intent.
            x = torch.cat((inp, x), dim=-1)
        else:
            x = inp + x
        return x

class ViT(nn.Module):
    """
    Vision Transformer (ViT) module for hierarchical feature processing.

    NOTE: a second ``class ViT`` further down this file rebinds the name, so
    the later definition is the one in effect once the whole file has run.

    One ``PatchEmbed`` and one ``ViTLayer`` are built per processed level. The
    first level uses local attention, every later level global attention.

    Args:
        PYR_SCALES: Accepted for interface compatibility; not used here.
        feats_num: Input channel count per level; its length sets the number
            of levels before capping.
        hid_dim: Channel dimension used by every level.
        depths: Number of blocks per level.
        patch_size: Window size, either one value for all levels or a list.
        mlp_ratio: Forwarded to every level.
        num_heads: Number of attention heads per level.
        mlp_type: Forwarded to every level.
        norm_type: Forwarded to every level.
        act_layer: Forwarded to every level.
        drop_path_rate: Upper end of the linearly increasing stochastic-depth schedule.
        qkv_bias: Forwarded to every level.
        qk_scale: Forwarded to every level.
        drop_rate: Dropout probability for the embeddings and attention projections.
        attn_drop_rate: Dropout probability on the attention weights.
        norm_layer: Normalisation layer factory forwarded to every level.
        layer_scale: Forwarded to every level.
        img_size: Per-level spatial sizes; ``img_size[i]`` goes to level ``i``.
        NUM_CROSS_ATT: When positive, caps the number of levels.
        WO_SELF_ATT: When true, one fewer level is built. ``None`` (the
            default) falls back to a module-level global named
            ``WO_SELF_ATT`` if one exists and to ``False`` otherwise, so
            construction no longer fails when that global was never defined.
    """
    def __init__(self,
                 PYR_SCALES=None,
                 feats_num=None,
                 hid_dim=None,
                 depths=None,
                 patch_size=None,
                 mlp_ratio=None,
                 num_heads=None,
                 mlp_type=None,
                 norm_type=None,
                 act_layer=None,
                 drop_path_rate: float = 0.2,
                 qkv_bias: bool = True,
                 qk_scale: Optional[float] = None,
                 drop_rate: float = 0.,
                 attn_drop_rate: float = 0.,
                 norm_layer=nn.LayerNorm,
                 layer_scale=None,
                 img_size=None,
                 NUM_CROSS_ATT=-1,
                 WO_SELF_ATT=None):
        super().__init__()

        # Determine the number of levels for processing
        num_levels = len(feats_num)
        num_levels = min(num_levels, NUM_CROSS_ATT) if NUM_CROSS_ATT > 0 else num_levels
        if WO_SELF_ATT is None:
            # Backward compatible with the old global switch, without raising
            # NameError when it is absent.
            WO_SELF_ATT = globals().get('WO_SELF_ATT', False)
        if WO_SELF_ATT:
            num_levels -= 1

        # Ensure patch_size is a list
        patch_size = patch_size if isinstance(patch_size, list) else [patch_size for _ in range(num_levels)]

        # Create patch embedding layers
        self.patch_embed = nn.ModuleList([
            PatchEmbed(
                in_chans=feats_num[i],
                out_chans=hid_dim,
                drop_rate=drop_rate
            ) for i in range(num_levels)
        ])

        # Stochastic-depth rates rise linearly from 0 to drop_path_rate across all blocks.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # Create ViT layers
        self.levels = nn.ModuleList()
        for i in range(num_levels):
            level = ViTLayer(
                dim=hid_dim,
                dim_out=hid_dim,
                depth=depths[i],
                num_heads=num_heads[i],
                patch_size=patch_size[i],
                mlp_type=mlp_type,
                attention_type="local" if i == 0 else "global",
                # Slice of the schedule belonging to this level's blocks.
                drop_path=dpr[sum(depths[:i]):sum(depths[:i+1])],
                input_dims=img_size[i],
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                norm_layer=norm_layer,
                layer_scale=layer_scale,
                norm_type=norm_type,
                act_layer=act_layer
            )
            self.levels.append(level)

    def _init_weights(self, m):
        """Initialize the weights of the module.

        Linear layers get truncated-normal weights (std 0.02) and zero bias;
        LayerNorm layers get unit weight and zero bias. This method is not
        applied automatically by ``__init__``.
        """
        if isinstance(m, nn.Linear):
            timm_trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        """Return keywords for no weight decay."""
        return {'rpb'}

    def forward(self, KQs, CONCAT_ok: bool = False):
        """
        Forward pass of the ViT module.

        Args:
            KQs: Sequence of per-level feature maps; ``KQs[i]`` feeds level ``i``.
            CONCAT_ok: Forwarded to every level.

        Returns:
            Output of the last level.
        """
        for i, (patch_embed_, level) in enumerate(zip(self.patch_embed, self.levels)):
            if i == 0:
                # First level: process input without cross-attention
                Q = patch_embed_(KQs[i])
                x = level(Q, None, CONCAT_ok=CONCAT_ok)
                Q = patch_embed_(x)
            else:
                # Subsequent levels: process with cross-attention
                K = patch_embed_(KQs[i])
                x = level(Q, K, CONCAT_ok=CONCAT_ok)
                Q = x.clone()

        return x

class EncoderCnnBlock(nn.Module):
    """
    Encoder building block: two Conv3d -> InstanceNorm3d -> ReLU stages in a row.

    Only the first convolution uses ``stride`` (and changes the channel count);
    the second keeps stride 1 and the channel count at ``out_channels``.
    """
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding=1,
        bias=False,
        affine=True,
        eps=1e-05
    ):
        super().__init__()

        def conv_norm_relu(channels_in, conv_stride):
            # One stage: convolution, instance normalisation, in-place ReLU.
            return [
                nn.Conv3d(
                    in_channels=channels_in, out_channels=out_channels,
                    kernel_size=kernel_size, stride=conv_stride, padding=padding,
                    bias=bias
                ),
                nn.InstanceNorm3d(num_features=out_channels, affine=affine, eps=eps),
                nn.ReLU(inplace=True),
            ]

        stages = conv_norm_relu(in_channels, stride) + conv_norm_relu(out_channels, 1)

        # Kept as a single flat Sequential so parameter names stay `_block.<idx>.*`.
        self._block = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both convolutional stages."""
        return self._block(x)

class ViT(nn.Module):
    """
    Vision Transformer (ViT) module for hierarchical feature processing.

    One ``PatchEmbed`` and one ``ViTLayer`` are built per processed level. The
    first level uses local attention, every later level global attention.

    Args:
        PYR_SCALES: Accepted for interface compatibility; not used here.
        feats_num: Input channel count per level; its length sets the number
            of levels before capping.
        hid_dim: Channel dimension used by every level.
        depths: Number of blocks per level.
        patch_size: Window size, either one value for all levels or a list.
        mlp_ratio: Forwarded to every level.
        num_heads: Number of attention heads per level.
        mlp_type: Forwarded to every level.
        norm_type: Forwarded to every level.
        act_layer: Forwarded to every level.
        drop_path_rate: Upper end of the linearly increasing stochastic-depth schedule.
        qkv_bias: Forwarded to every level.
        qk_scale: Forwarded to every level.
        drop_rate: Dropout probability for the embeddings and attention projections.
        attn_drop_rate: Dropout probability on the attention weights.
        norm_layer: Normalisation layer factory forwarded to every level.
        layer_scale: Forwarded to every level.
        img_size: Per-level spatial sizes; ``img_size[i]`` goes to level ``i``.
        WO_SELF_ATT: When true, one fewer level is built.
        NUM_CROSS_ATT: When positive, caps the number of levels.
    """
    def __init__(self,
                 PYR_SCALES=None,
                 feats_num=None,
                 hid_dim=None,
                 depths=None,
                 patch_size=None,
                 mlp_ratio=None,
                 num_heads=None,
                 mlp_type=None,
                 norm_type=None,
                 act_layer=None,
                 drop_path_rate: float = 0.2,
                 qkv_bias: bool = True,
                 qk_scale: Optional[float] = None,
                 drop_rate: float = 0.,
                 attn_drop_rate: float = 0.,
                 norm_layer=nn.LayerNorm,
                 layer_scale=None,
                 img_size=None,
                 WO_SELF_ATT=False,  # Added WO_SELF_ATT parameter
                 NUM_CROSS_ATT=-1):
        super().__init__()

        # Determine the number of levels for processing
        num_levels = len(feats_num)
        num_levels = min(num_levels, NUM_CROSS_ATT) if NUM_CROSS_ATT > 0 else num_levels
        if WO_SELF_ATT:
            num_levels -= 1

        # Ensure patch_size is a list
        patch_size = patch_size if isinstance(patch_size, list) else [patch_size for _ in range(num_levels)]

        # Create patch embedding layers
        self.patch_embed = nn.ModuleList([
            PatchEmbed(
                in_chans=feats_num[i],
                out_chans=hid_dim,
                drop_rate=drop_rate
            ) for i in range(num_levels)
        ])

        # Stochastic-depth rates rise linearly from 0 to drop_path_rate across all blocks.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # Create ViT layers
        self.levels = nn.ModuleList()
        for i in range(num_levels):
            level = ViTLayer(
                dim=hid_dim,
                dim_out=hid_dim,
                depth=depths[i],
                num_heads=num_heads[i],
                patch_size=patch_size[i],
                mlp_type=mlp_type,
                attention_type="local" if i == 0 else "global",
                # Slice of the schedule belonging to this level's blocks.
                drop_path=dpr[sum(depths[:i]):sum(depths[:i+1])],
                input_dims=img_size[i],
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                norm_layer=norm_layer,
                layer_scale=layer_scale,
                norm_type=norm_type,
                act_layer=act_layer
            )
            self.levels.append(level)

    def _init_weights(self, m):
        """Initialize the weights of the module.

        Linear layers get truncated-normal weights (std 0.02) and zero bias;
        LayerNorm layers get unit weight and zero bias. This method is not
        applied automatically by ``__init__``.
        """
        if isinstance(m, nn.Linear):
            timm_trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        """Return keywords for no weight decay."""
        return {'rpb'}

    def forward(self, KQs, CONCAT_ok: bool = False):
        """
        Forward pass of the ViT module.

        Args:
            KQs: Sequence of per-level feature maps; ``KQs[i]`` feeds level ``i``.
            CONCAT_ok: Forwarded to every level.

        Returns:
            Output of the last level.
        """
        for i, (patch_embed_, level) in enumerate(zip(self.patch_embed, self.levels)):
            if i == 0:
                # First level: process input without cross-attention
                Q = patch_embed_(KQs[i])
                x = level(Q, None, CONCAT_ok=CONCAT_ok)
                Q = patch_embed_(x)
            else:
                # Subsequent levels: process with cross-attention
                K = patch_embed_(KQs[i])
                x = level(Q, K, CONCAT_ok=CONCAT_ok)
                Q = x.clone()

        return x



class Decoder(nn.Module):
    """
    Decoder module for the hierarchical vision transformer.

    Encoder feature maps are projected by 1x1x1 lateral convolutions, merged
    coarse-to-fine through transposed-convolution upsampling, passed through
    3x3x3 output convolutions and finally, coarsest level first, through the
    per-level ``ViT`` modules. When ``use_seg_loss`` is set, a 1x1x1
    segmentation head is additionally applied to every output level.

    Required config keys: ``num_stages``, ``use_seg_loss``, ``start_channels``,
    ``out_fmaps``, ``fpn_channels``, ``strides``, ``data_size`` (plus
    ``num_organs`` when ``use_seg_loss`` is true). Optional keys with defaults:
    ``NUM_CROSS_ATT``, ``depths``, ``num_heads``, ``patch_size``,
    ``mlp_ratio``, ``drop_path_rate``, ``qkv_bias``, ``drop_rate``,
    ``attn_drop_rate``.
    """
    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        self._num_stages: int = config['num_stages']
        self.use_seg: bool = config['use_seg_loss']
        self._NUM_CROSS_ATT = config.get('NUM_CROSS_ATT', -1)

        # Stage ids requested as outputs: the last character of each map name
        # ('P3' -> 3).
        required_stages: Set[int] = set(int(fmap[-1]) for fmap in config['out_fmaps'])
        if not required_stages:
            raise ValueError("config['out_fmaps'] must name at least one feature map")
        self._required_stages: Set[int] = required_stages

        # A set has no guaranteed iteration order; every place that pairs
        # stages with layers or feature maps uses this explicit ascending order.
        stage_order: List[int] = sorted(required_stages)
        earliest_required_stage: int = stage_order[0]

        # Channel bookkeeping is kept as plain Python ints so the convolution
        # layers receive ints rather than 0-dim tensors.
        fpn_channels: int = int(config['fpn_channels'])
        encoder_out_channels: List[int] = [
            int(config['start_channels']) * 2 ** stage for stage in range(self._num_stages)
        ]

        # Lateral connections: one 1x1x1 conv per encoder stage from the
        # earliest required stage upwards, capped at ``fpn_channels``.
        lateral_in_channels: List[int] = encoder_out_channels[earliest_required_stage:]
        lateral_out_channels: List[int] = [min(ch, fpn_channels) for ch in lateral_in_channels]

        self._lateral: nn.ModuleList = nn.ModuleList([
            nn.Conv3d(in_channels=in_ch, out_channels=out_ch, kernel_size=1)
            for in_ch, out_ch in zip(lateral_in_channels, lateral_out_channels)
        ])
        self._lateral_levels: int = len(self._lateral)

        # Output layers: one 3x3x3 conv per required stage, in ascending stage order.
        out_in_channels: List[int] = [
            lateral_out_channels[-self._num_stages + stage] for stage in stage_order
        ]
        out_out_channels: List[int] = [fpn_channels] * len(out_in_channels)

        self._out: nn.ModuleList = nn.ModuleList([
            nn.Conv3d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, padding=1)
            for in_ch, out_ch in zip(out_in_channels, out_out_channels)
        ])

        # Upsampling layers, indexed from the coarsest level downwards.
        top_down_channels: List[int] = lateral_out_channels[::-1]
        top_down_strides = list(reversed(config['strides']))
        self._up: nn.ModuleList = nn.ModuleList([
            nn.ConvTranspose3d(
                in_channels=top_down_channels[level],
                out_channels=top_down_channels[level + 1],
                kernel_size=top_down_strides[level],
                stride=top_down_strides[level]
            )
            for level in range(len(top_down_channels) - 1)
        ])

        # Multi-scale attention
        self.hierarchical_dec: nn.ModuleList = self._create_hierarchical_layers(config, out_out_channels)

        if self.use_seg:
            self._seg_head: nn.ModuleList = nn.ModuleList([
                nn.Conv3d(out_ch, config['num_organs'] + 1, kernel_size=1, stride=1)
                for out_ch in out_out_channels
            ])

    def _create_hierarchical_layers(self, config: Dict[str, Any], out_out_channels: List[int]) -> nn.ModuleList:
        """
        Create one module per output level for multi-scale attention.

        Level 0 is an ``nn.Identity``; every later level ``k`` is a ``ViT``
        built with the channel counts and image sizes of levels ``0..k``.
        """
        out: nn.ModuleList = nn.ModuleList()
        img_size: List[List[int]] = []
        feats_num: List[int] = []

        num_levels = len(out_out_channels)

        # These lookups do not depend on the level, so they are read once.
        depths = config.get('depths', [1] * num_levels)
        num_heads = config.get('num_heads', [32] * num_levels)
        num_cross_att = config.get('NUM_CROSS_ATT', self._NUM_CROSS_ATT)
        hid_dim = int(config.get('fpn_channels', 64))

        for k, out_ch in enumerate(out_out_channels):
            # Spatial size of level k: data_size divided by 2**(num_stages - k - 1).
            img_size.append([int(item / (2 ** (self._num_stages - k - 1))) for item in config['data_size']])
            feats_num.append(out_ch)
            n: int = len(feats_num)

            if k == 0:
                out.append(nn.Identity())
                continue

            out.append(
                ViT(
                    NUM_CROSS_ATT=num_cross_att,
                    PYR_SCALES=[1.],
                    feats_num=feats_num,
                    hid_dim=hid_dim,
                    depths=depths,
                    patch_size=config.get('patch_size', [2] * n),
                    mlp_ratio=int(config.get('mlp_ratio', 2)),
                    num_heads=num_heads,
                    mlp_type='basic',
                    norm_type='BatchNorm2d',
                    act_layer='gelu',
                    drop_path_rate=config.get('drop_path_rate', 0.2),
                    qkv_bias=config.get('qkv_bias', True),
                    qk_scale=None,
                    drop_rate=config.get('drop_rate', 0.),
                    attn_drop_rate=config.get('attn_drop_rate', 0.),
                    norm_layer=nn.LayerNorm,
                    layer_scale=1e-5,
                    img_size=img_size
                )
            )
        return out

    def forward(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """
        Forward pass of the Decoder.

        Args:
            x: Encoder feature maps in insertion order; only the last
                ``len(self._lateral)`` values are used.

        Returns:
            Dict with one ``'P<stage>'`` entry per produced level and, when
            ``use_seg`` is set, a matching ``'S<stage>'`` entry.
        """
        feature_maps = list(x.values())[-self._lateral_levels:]
        lateral_out: List[Tensor] = [lateral(fmap) for lateral, fmap in zip(self._lateral, feature_maps)]

        # Top-down pathway: start at the coarsest level and add the upsampled
        # result of the previous (coarser) level to each finer one.
        up_out: List[Tensor] = []
        upsampled = None
        for idx, fmap in enumerate(reversed(lateral_out)):
            if upsampled is not None:
                fmap = fmap + upsampled

            if idx < self._lateral_levels - 1:
                upsampled = self._up[idx](fmap)

            up_out.append(fmap)

        # Pair the finest-first feature maps with ascending stage ids; zip
        # stops at the shorter of the two sequences.
        cnn_outputs: Dict[int, Tensor] = {
            stage: self._out[idx](fmap)
            for idx, (fmap, stage) in enumerate(zip(reversed(up_out), sorted(self._required_stages)))
        }
        return self._forward_hierarchical(cnn_outputs)

    def _forward_hierarchical(self, cnn_outputs: Dict[int, Tensor]) -> Dict[str, Tensor]:
        """
        Forward pass through the hierarchical decoder.

        Stages are visited from the highest key to the lowest. The running
        list ``QK`` holds the current level first followed by all previously
        processed levels; its first entry is replaced by the level's output.
        """
        stages = range(max(cnn_outputs), min(cnn_outputs) - 1, -1)
        xs: List[Tensor] = [cnn_outputs[key].clone() for key in stages]

        out_dict: Dict[str, Tensor] = {}
        QK: List[Tensor] = []
        for i, key in enumerate(stages):
            QK = [xs[i]] + QK
            # The first (coarsest) level has nothing to attend to yet.
            Pi = QK[0] if i == 0 else self.hierarchical_dec[i](QK)
            QK[0] = Pi
            out_dict[f'P{key}'] = Pi

            if self.use_seg:
                out_dict[f'S{key}'] = self._seg_head[i](Pi)

        return out_dict



class HierarchicalViT(nn.Module):
    """
    Hierarchical Vision Transformer (HViT) for image processing tasks.

    A stack of ``EncoderCnnBlock`` stages produces feature maps ``C0..C{n-1}``
    which are handed to a ``Decoder``. The constructor writes the derived
    ``num_stages`` and ``strides`` back into ``config`` because the decoder
    reads them from there.

    Raises:
        ValueError: if ``config['backbone_net']`` is not a supported backbone.
    """
    # Backbones this class knows how to build (compared case-insensitively).
    _SUPPORTED_BACKBONES = ('fpn',)

    def __init__(self, config: Dict[str, Any]):
        super().__init__()

        # Fail fast: without a supported backbone neither encoder nor decoder
        # would be built and forward() could not produce any output.
        self.backbone = config['backbone_net']
        if self.backbone.lower() not in self._SUPPORTED_BACKBONES:
            raise ValueError(
                f"Unsupported backbone_net {self.backbone!r}; "
                f"expected one of {self._SUPPORTED_BACKBONES}"
            )

        # Configuration parameters
        in_channels = 2 * config.get('in_channels', 1)  # source + target
        kernel_size = config.get('kernel_size', 3)
        emb_dim = config.get('start_channels', 32)
        data_size = config.get('data_size', [160, 192, 224])
        self.out_fmaps = config.get('out_fmaps', ['P4', 'P3', 'P2', 'P1'])

        # Number of stages: limited both by the smallest spatial dimension and
        # by the highest requested output stage.
        num_stages = min(int(math.log2(min(data_size))) - 1,
                         max(int(fmap[-1]) for fmap in self.out_fmaps) + 1)

        strides = [1] + [2] * (num_stages - 1)
        kernel_sizes = [kernel_size] * num_stages

        # The Decoder reads these two derived values from the shared config.
        config['num_stages'] = num_stages
        config['strides'] = strides

        # Build encoder: channel width doubles after every stage.
        self._encoder = nn.ModuleList()
        for k in range(num_stages):
            self._encoder.append(
                EncoderCnnBlock(
                    in_channels=in_channels,
                    out_channels=emb_dim,
                    kernel_size=kernel_sizes[k],
                    stride=strides[k]
                )
            )
            in_channels = emb_dim
            emb_dim *= 2

        # Build decoder
        self._decoder = Decoder(config)

    def init_weights(self) -> None:
        """Initialize model weights by visiting every sub-module."""
        for m in self.modules():
            self._init_weights(m)

    def _init_weights(self, m):
        """
        Initialize the weights of a single module.

        ``nn.Linear`` weights go through ``timm_trunc_normal_`` (std 0.02) and
        an existing bias is zeroed; ``nn.LayerNorm`` is reset to weight 1 and
        bias 0. Other module types are left untouched.
        """
        if isinstance(m, nn.Linear):
            timm_trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x: Tensor, verbose: bool = False) -> Dict[str, Tensor]:
        """
        Forward pass of the HierarchicalViT model.

        Args:
            x: Input tensor fed to the first encoder stage.
            verbose: When true, print the shape of every encoder and decoder
                feature map.

        Returns:
            The dictionary produced by the decoder.
        """
        down = {}
        for stage_id, module in enumerate(self._encoder):
            x = module(x)
            down[f'C{stage_id}'] = x
        up = self._decoder(down)

        if verbose:
            for key, item in down.items():
                print(f'down {key}', item.shape)
            for key, item in up.items():
                print(f'up {key}', item.shape)
        return up

class RegistrationHead(nn.Sequential):
    """
    Registration head for generating displacement fields.

    A single ``nn.Conv3d`` registered under the name ``conv3d``. Its weights
    are redrawn from a normal distribution with std 1e-5 and its bias is set
    to zero, so the head initially outputs values very close to zero.
    """
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
        super().__init__()
        flow_conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
        )
        # Overwrite the default initialization in place.
        nn.init.normal_(flow_conv.weight, mean=0, std=1e-5)
        nn.init.zeros_(flow_conv.bias)
        self.add_module('conv3d', flow_conv)

class HierarchicalViT_Light(nn.Module):
    """
    Light Hierarchical Vision Transformer (HViT) model for image registration.

    The source and target images are concatenated along the channel axis and
    passed through ``HierarchicalViT``; the feature map named by
    ``scale_level_df`` is turned into a flow field by ``RegistrationHead``,
    optionally upsampled, and used by ``SpatialTransformer`` to warp the
    source image.

    Config keys read here (with defaults): ``upsample_df`` (False),
    ``upsample_scale_factor`` (2), ``scale_level_df`` ('P1'), ``ndims`` (3),
    ``NUM_CROSS_ATT`` (-1), ``fpn_channels`` (64) and the required
    ``data_size``.
    """
    def __init__(self, config: dict):
        super(HierarchicalViT_Light, self).__init__()
        self.upsample_df = config.get('upsample_df', False)
        self.upsample_scale_factor = config.get('upsample_scale_factor', 2)
        self.scale_level_df = config.get('scale_level_df', 'P1')
        self.ndims = config.get('ndims', 3)
        self._NUM_CROSS_ATT = config.get('NUM_CROSS_ATT', -1)
        self.deformable = HierarchicalViT(config)
        self.avg_pool = nn.AvgPool3d(3, stride=2, padding=1)
        self.spatial_trans = SpatialTransformer(config['data_size'])
        # Use the configured dimensionality rather than a free ``ndims`` name,
        # so the head always agrees with ``self.ndims``. The kernel size is
        # tied to the same value, as it was before.
        self.reg_head = RegistrationHead(
            in_channels=config.get('fpn_channels', 64),
            out_channels=self.ndims,
            kernel_size=self.ndims,
        )

    def forward(self, source: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Forward pass of the HViT model.

        Args:
            source: Image to be warped.
            target: Image the source is registered to.

        Returns:
            ``(moved, flow)``: the warped source image and the flow field
            that produced it.
        """
        x = torch.cat((source, target), dim=1)
        x_dec = self.deformable(x)

        # Extract features at the specified scale level
        x_dec = x_dec[self.scale_level_df]
        flow = self.reg_head(x_dec)

        if self.upsample_df:
            # Functional form of nn.Upsample; avoids building a module per call.
            flow = F.interpolate(
                flow,
                scale_factor=self.upsample_scale_factor,
                mode='trilinear',
                align_corners=False,
            )

        moved = self.spatial_trans(source, flow)
        return moved, flow
# Trainer: Lightning Module

class LiTHViT(LightningModule):
    """
    Lightning training/evaluation wrapper around the HViT registration models.

    Optimization is manual (``automatic_optimization = False``):
    ``training_step`` runs backward and the optimizer step itself, and
    ``on_train_epoch_end`` advances the learning-rate scheduler, which
    Lightning does not do on its own in manual mode.

    NOTE(review): ``_forward`` unpacks ``(moved, flow)`` from the model and
    ``_get_one_hot_from_src`` uses ``self.hvit.spatial_trans``; in this file
    only ``HierarchicalViT_Light`` visibly provides both, so
    ``args.hvit_light=False`` looks unsupported by these code paths -- confirm.
    """
    # Maps ``args.precision`` to the autocast dtype; unknown values fall back
    # to float32.
    _PRECISION_DTYPES = {
        'bf16': torch.bfloat16,
        'fp32': torch.float32,
        'fp16': torch.float16,
    }

    def __init__(self, args, config, wandb_logger=None, save_model_every_n_epochs=10):
        super().__init__()
        self.automatic_optimization = False
        self.args = args
        self.config = config
        self.best_val_loss = 1e8
        self.save_model_every_n_epochs = save_model_every_n_epochs
        self.lr = args.lr
        self.last_epoch = 0
        self.tgt2src_reg = args.tgt2src_reg
        self.hvit_light = args.hvit_light
        self.precision = args.precision

        # Initialize logger
        self.custom_logger = Logger(save_dir="./logs")

        self.hvit = HierarchicalViT_Light(config) if self.hvit_light else HierarchicalViT(config)

        self.loss_weights = {
            "mse": self.args.mse_weights,
            "dice": self.args.dice_weights,
            "grad": self.args.grad_weights
        }
        self.wandb_logger = wandb_logger
        self.test_step_outputs = []

    def _forward(self, batch, calc_score: bool = False, tgt2src_reg: bool = False):
        """
        Run the model on one batch and compute the weighted losses.

        Args:
            batch: Sequence ``(image_a, image_b, seg_a, seg_b)``.
            calc_score: Also compute ``DiceScore`` on the warped segmentation.
            tgt2src_reg: Swap the roles of the two images (and segmentations),
                i.e. register the second image onto the first.

        Returns:
            ``(_loss, _score)``: a dict with one entry per configured loss
            plus ``"avg_loss"``, and the score (``0.`` when not requested).
        """
        _loss = {}
        _score = 0.

        dtype_ = self._PRECISION_DTYPES.get(self.precision, torch.float32)
        # Autocast on the device the module actually lives on, and only for
        # reduced precision, instead of unconditionally requesting CUDA.
        device_type = "cuda" if self.device.type == "cuda" else "cpu"

        with torch.amp.autocast(device_type=device_type, dtype=dtype_, enabled=dtype_ != torch.float32):
            if tgt2src_reg:
                target, source = batch[0].to(dtype=dtype_), batch[1].to(dtype=dtype_)
                tgt_seg, src_seg = batch[2], batch[3]
            else:
                source, target = batch[0].to(dtype=dtype_), batch[1].to(dtype=dtype_)
                src_seg, tgt_seg = batch[2], batch[3]

            moved, flow = self.hvit(source, target)

            # The warped one-hot segmentation is shared by the score and the
            # dice loss, so it is computed at most once.
            moved_seg = None
            if calc_score:
                moved_seg = self._get_one_hot_from_src(src_seg, flow, self.args.num_labels)
                _score = DiceScore(moved_seg, tgt_seg.long(), self.args.num_labels)

            for key, weight in self.loss_weights.items():
                if key == "mse":
                    _loss[key] = weight * loss_functions[key](moved, target)
                elif key == "dice":
                    if moved_seg is None:
                        moved_seg = self._get_one_hot_from_src(src_seg, flow, self.args.num_labels)
                    _loss[key] = weight * loss_functions[key](moved_seg, tgt_seg.long())
                elif key == "grad":
                    _loss[key] = weight * loss_functions[key](flow)

            _loss["avg_loss"] = sum(_loss.values()) / len(_loss)
        return _loss, _score

    def training_step(self, batch, batch_idx):
        """
        One manual-optimization step.

        Always optimizes the first->second registration; when
        ``tgt2src_reg`` is set, a second backward/step is run for the
        swapped direction and the reported losses are the mean of both.
        """
        self.hvit.train()
        opt = self.optimizers()

        loss1, _ = self._forward(batch, calc_score=False)
        self.manual_backward(loss1["avg_loss"])
        opt.step()
        opt.zero_grad()

        loss2 = None
        if self.tgt2src_reg:
            loss2, _ = self._forward(batch, tgt2src_reg=True, calc_score=False)
            self.manual_backward(loss2["avg_loss"])
            opt.step()
            opt.zero_grad()

        total_loss = {key: value.item() for key, value in loss1.items()}
        if loss2 is not None:
            for key in total_loss:
                if key in loss2:
                    total_loss[key] = (total_loss[key] + loss2[key].item()) / 2

        if self.wandb_logger:
            self.wandb_logger.log_metrics(total_loss, step=self.global_step)
        return total_loss

    def on_train_epoch_end(self):
        """Save a periodic checkpoint, log the learning rate and step the LR scheduler."""
        if self.current_epoch % self.save_model_every_n_epochs == 0:
            checkpoints_dir = f"./checkpoints/{self.current_epoch}"
            os.makedirs(checkpoints_dir, exist_ok=True)
            checkpoint_path = f"{checkpoints_dir}/model_epoch_{self.current_epoch}.ckpt"
            self.trainer.save_checkpoint(checkpoint_path)
            self.custom_logger.info(f"Saved model at epoch {self.current_epoch}")

        # Log the rate that was used during the epoch that just finished.
        current_lr = self.optimizers().param_groups[0]['lr']
        if self.wandb_logger:
            self.wandb_logger.log_metrics({"learning_rate": current_lr}, step=self.global_step)

        # With manual optimization Lightning never steps schedulers itself,
        # so the epoch-wise schedule has to be advanced here.
        schedulers = self.lr_schedulers()
        if schedulers is not None:
            if not isinstance(schedulers, (list, tuple)):
                schedulers = [schedulers]
            for scheduler in schedulers:
                scheduler.step()

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation losses and the mean score for one batch."""
        with torch.no_grad():
            self.hvit.eval()
            _loss, _score = self._forward(batch, calc_score=True)

        mean_score = _score.mean() if isinstance(_score, torch.Tensor) else torch.tensor(float(_score))

        # Log each component of the validation loss
        for loss_name, loss_value in _loss.items():
            self.log(f"val_{loss_name}", loss_value, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)

        # Log the mean validation score
        self.log("val_score", mean_score, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)

        # Log to wandb
        if self.wandb_logger:
            log_dict = {f"val_{k}": v.item() for k, v in _loss.items()}
            log_dict["val_score_mean"] = mean_score.item()
            self.wandb_logger.log_metrics(log_dict, step=self.global_step)

        return {"val_loss": _loss["avg_loss"], "val_score": mean_score.item()}

    def on_validation_epoch_end(self):
        """
        Callback method called at the end of the validation epoch.
        Saves the best model based on validation loss and logs metrics.

        ``validation_step`` logs the averaged loss as ``val_avg_loss``, so
        that name is consulted when no ``val_loss`` metric exists. Nothing is
        saved during epoch 0.
        """
        metrics = self.trainer.callback_metrics
        val_loss = None
        for name in ("val_loss", "val_avg_loss", "val_avg_loss_epoch"):
            val_loss = metrics.get(name)
            if val_loss is not None:
                break

        if val_loss is None or self.current_epoch <= 0:
            return

        # Stored as a plain float so it formats and serializes cleanly.
        val_loss = float(val_loss)
        if val_loss >= self.best_val_loss:
            return

        self.best_val_loss = val_loss
        checkpoints_dir = f"./checkpoints/{self.current_epoch}"
        os.makedirs(checkpoints_dir, exist_ok=True)
        best_model_path = f"{checkpoints_dir}/best_model.ckpt"
        self.trainer.save_checkpoint(best_model_path)
        if self.wandb_logger:
            self.wandb_logger.experiment.log({
                "best_model_saved": best_model_path,
                "best_val_loss": self.best_val_loss
            })
        self.custom_logger.info(f"New best model saved with validation loss: {self.best_val_loss:.4f}")

    def test_step(self, batch, batch_idx):
        """
        Performs a single test step on a batch of data.

        Returns a dict holding the mean score of the batch as a tensor.
        """
        with torch.no_grad():
            self.hvit.eval()
            _, _score = self._forward(batch, calc_score=True)

        _score = _score.mean() if isinstance(_score, torch.Tensor) else torch.tensor(_score).mean()

        self.test_step_outputs.append(_score)

        # Log to wandb only if the logger is available
        if self.wandb_logger:
            self.wandb_logger.log_metrics({"test_dice": _score.item()}, step=self.global_step)

        # Return as a dict with tensor values
        return {"test_dice": _score}

    def on_test_epoch_end(self):
        """
        Callback method called at the end of the test epoch.
        Computes and logs the average test Dice score.
        """
        # torch.stack rejects an empty list; nothing to report in that case.
        if not self.test_step_outputs:
            return

        avg_test_dice = torch.stack(self.test_step_outputs).mean()

        self.log("avg_test_dice", avg_test_dice, prog_bar=True)

        if self.wandb_logger:
            self.wandb_logger.log_metrics({"total_test_dice_avg": avg_test_dice.item()})

        # Clear the test step outputs list for the next test epoch
        self.test_step_outputs.clear()

    def configure_optimizers(self):
        """
        Configures the optimizer and learning rate scheduler for the model.

        Adam (AMSGrad, no weight decay) with a ``LambdaLR`` schedule driven
        by ``lr_lambda``.
        """
        optimizer = torch.optim.Adam(self.hvit.parameters(), lr=self.lr, weight_decay=0, amsgrad=True)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=self.lr_lambda)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
                "frequency": 1,
            },
        }

    def lr_lambda(self, epoch):
        """
        Defines the learning rate schedule: ``(1 - epoch / max_epochs) ** 0.9``.

        The base is clamped at zero so epochs past ``max_epochs`` yield 0
        instead of a math domain error; a missing or non-positive
        ``max_epochs`` keeps the learning rate constant.
        """
        max_epochs = self.trainer.max_epochs
        if not max_epochs or max_epochs <= 0:
            return 1.0
        return math.pow(max(0.0, 1 - epoch / max_epochs), 0.9)

    @classmethod
    def load_from_checkpoint(cls, checkpoint_path, args=None, wandb_logger=None):
        """
        Loads a model from a checkpoint file.

        Args:
            checkpoint_path: Path of a checkpoint written by this module.
            args: Training arguments. ``on_save_checkpoint`` does not store
                them, so they normally have to be passed in.
            wandb_logger: Optional logger forwarded to the constructor.

        Raises:
            ValueError: if neither the caller nor the checkpoint provides
                ``args``, or the checkpoint holds no ``config``.
        """
        # NOTE: torch.load unpickles data; only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        hyper_params = checkpoint.get('hyper_parameters', {})

        args = args or hyper_params.get('args')
        config = hyper_params.get('config')
        if args is None:
            raise ValueError("`args` is not stored in the checkpoint and must be passed explicitly")
        if config is None:
            raise ValueError("checkpoint does not contain a model `config`")

        model = cls(args, config, wandb_logger)
        model.load_state_dict(checkpoint['state_dict'])

        # Restore bookkeeping values, keeping the fresh defaults when absent.
        for attr in ['lr', 'best_val_loss', 'last_epoch']:
            setattr(model, attr, hyper_params.get(attr, getattr(model, attr)))

        return model

    def on_save_checkpoint(self, checkpoint):
        """
        Callback to save additional information in the checkpoint.
        """
        checkpoint['hyper_parameters'] = {
            'config': self.config,
            'lr': self.lr,
            'best_val_loss': self.best_val_loss,
            'last_epoch': self.current_epoch
        }

    def _get_one_hot_from_src(self, src_seg, flow, num_labels):
        """
        Converts source segmentation to one-hot encoding and applies deformation.

        Each of the ``num_labels`` one-hot channels is warped separately with
        the model's spatial transformer and the results are concatenated
        along dim 1.
        """
        src_seg_onehot = get_one_hot(src_seg, num_labels)
        flow = flow.float()
        deformed_segs = [
            self.hvit.spatial_trans(src_seg_onehot[:, i:i+1, ...].float(), flow)
            for i in range(num_labels)
        ]
        return torch.cat(deformed_segs, dim=1)
# Instantiate the Model and Run with Dummy Data


In [None]:
# Smoke test: build HierarchicalViT_Light from a small config and run one
# forward pass on random volumes.
config = {
    'WO_SELF_ATT': False,
    '_NUM_CROSS_ATT': -1,  # NOTE(review): Decoder and HierarchicalViT_Light read 'NUM_CROSS_ATT' (no leading underscore) -- confirm the intended key
    'out_fmaps': ['P4', 'P3', 'P2', 'P1'],  # Number of levels = 4
    'scale_level_df': 'P1',
    'upsample_df': True,
    'upsample_scale_factor': 2,
    'fpn_channels': 64,
    'start_channels': 32,
    'patch_size': [2, 2, 2, 2],  # Matches number of levels
    'backbone_net': 'fpn',
    'in_channels': 1,
    'data_size': [40, 48, 56],
    'bias': True,
    'norm_type': 'instance',
    'kernel_size': 3,
    'depths': [1, 1, 1, 1],  # Matches number of levels
    'mlp_ratio': 2,
    'num_heads': [4, 8, 16, 32],  # Matches number of levels
    'drop_path_rate': 0.,
    'qkv_bias': True,
    'drop_rate': 0.,
    'attn_drop_rate': 0.,
    'use_seg_loss': False,
    'use_seg_proxy_loss': False,
    'num_organs': -1,
}


# Initialize the model (HierarchicalViT adds 'num_stages' and 'strides' to config)
model = HierarchicalViT_Light(config)

# Move model to device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define dummy inputs; the spatial size matches config['data_size']
B, C, H, W, D = 1, 1, 40, 48, 56
source = torch.rand([B, C, H, W, D]).to(device)
target = torch.rand([B, C, H, W, D]).to(device)

# Perform a forward pass
moved, flow = model(source, target)
print(f'moved shape: {moved.shape}, flow shape: {flow.shape}')



  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


MLP input shape after permute: torch.Size([1, 64, 10, 12, 14])
MLP input shape after permute: torch.Size([1, 64, 10, 12, 14])
MLP input shape after permute: torch.Size([1, 64, 20, 24, 28])
MLP input shape after permute: torch.Size([1, 64, 20, 24, 28])
MLP input shape after permute: torch.Size([1, 64, 20, 24, 28])
moved shape: torch.Size([1, 1, 40, 48, 56]), flow shape: torch.Size([1, 3, 40, 48, 56])


In [None]:
# Training arguments
class Args:
    """Plain attribute container holding the hyper-parameters read by LiTHViT."""
    def __init__(self):
        settings = {
            'lr': 0.001,
            'mse_weights': 1.0,
            'dice_weights': 1.0,
            'grad_weights': 1.0,
            'tgt2src_reg': False,
            'hvit_light': True,
            # Training precision (e.g., 'bf16', 'fp16', 'fp32')
            'precision': 'fp32',
            'num_labels': 36,
        }
        for attr_name, attr_value in settings.items():
            setattr(self, attr_name, attr_value)


args = Args()

# Define a dummy dataset loader
class DummyDataset(Dataset):
    """
    Synthetic dataset of random image pairs and random label maps.

    Every item is ``(source, target, source_seg, target_seg)``: two float
    tensors drawn with ``torch.rand`` and two integer tensors drawn with
    ``torch.randint`` in ``[0, num_classes)``, all of shape ``input_dim``.
    Items are regenerated on every access.
    """
    def __init__(self, num_samples=10, input_dim=(1, 40, 48, 56), num_classes=36):
        self.num_samples = num_samples
        self.input_dim = input_dim
        self.num_classes = num_classes

    def __getitem__(self, index):
        # Honour the sequence protocol: without an IndexError, plain iteration
        # over the dataset (``for item in ds``) would never terminate.
        if not -self.num_samples <= index < self.num_samples:
            raise IndexError(f"index {index} is out of range for {self.num_samples} samples")

        # Create dummy source, target, and their segmentations
        source = torch.rand(self.input_dim)
        target = torch.rand(self.input_dim)
        source_seg = torch.randint(0, self.num_classes, self.input_dim)
        target_seg = torch.randint(0, self.num_classes, self.input_dim)
        return source, target, source_seg, target_seg

    def __len__(self):
        return self.num_samples

# Instantiate dummy dataloader (default DummyDataset: 10 random samples)
dummy_dataloader = DataLoader(DummyDataset(), batch_size=2, shuffle=True, num_workers=0)

# Initialize WandB logger (optional, replace with None if not using WandB)
wandb_logger = WandbLogger(project="hvit_dummy_test")

# Instantiate the Lightning module
lit_model = LiTHViT(args, config, wandb_logger=wandb_logger)

# Define the PyTorch Lightning Trainer
trainer = Trainer(
    max_epochs=5,  # Number of epochs
    logger=wandb_logger,  # Log training metrics
    enable_checkpointing=False,  # Disables the Trainer's own checkpointing; LiTHViT.on_train_epoch_end still saves checkpoints itself
    devices=1,  # Number of devices to use on the accelerator selected below
    accelerator="gpu" if torch.cuda.is_available() else "cpu",  # Use GPU if available
)

# Train the model (no validation dataloader is passed, so only the training loop runs)
trainer.fit(lit_model, train_dataloaders=dummy_dataloader)


INFO: GPU available: False, used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33malim98barnet[0m ([33malim98barnet-university-of-tehran[0m). Use [1m`wandb login --relogin`[0m to force relogin


INFO: 
  | Name | Type                  | Params | Mode 
-------------------------------------------------------
0 | hvit | HierarchicalViT_Light | 7.2 M  | train
-------------------------------------------------------
7.2 M     Trainable params
0         Non-trainable params
7.2 M     Total params
28.993    Total estimated model params size (MB)
234       Modules in train mode
0         Modules in eval mode
INFO:lightning.pytorch.callbacks.model_summary:
  | Name | Type                  | Params | Mode 
-------------------------------------------------------
0 | hvit | HierarchicalViT_Light | 7.2 M  | train
-------------------------------------------------------
7.2 M     Trainable params
0         Non-trainable params
7.2 M     Total params
28.993    Total estimated model params size (MB)
234       Modules in train mode
0         Modules in eval mode
/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (5) is smaller than the

Training: |          | 0/? [00:00<?, ?it/s]



MLP input shape after permute: torch.Size([2, 64, 10, 12, 14])
MLP input shape after permute: torch.Size([2, 64, 10, 12, 14])
MLP input shape after permute: torch.Size([2, 64, 20, 24, 28])
MLP input shape after permute: torch.Size([2, 64, 20, 24, 28])
MLP input shape after permute: torch.Size([2, 64, 20, 24, 28])
MLP input shape after permute: torch.Size([2, 64, 10, 12, 14])
MLP input shape after permute: torch.Size([2, 64, 10, 12, 14])
MLP input shape after permute: torch.Size([2, 64, 20, 24, 28])
MLP input shape after permute: torch.Size([2, 64, 20, 24, 28])
MLP input shape after permute: torch.Size([2, 64, 20, 24, 28])
MLP input shape after permute: torch.Size([2, 64, 10, 12, 14])
MLP input shape after permute: torch.Size([2, 64, 10, 12, 14])
MLP input shape after permute: torch.Size([2, 64, 20, 24, 28])
MLP input shape after permute: torch.Size([2, 64, 20, 24, 28])
MLP input shape after permute: torch.Size([2, 64, 20, 24, 28])
MLP input shape after permute: torch.Size([2, 64, 10, 1

2025-01-10 08:02:59,619 - __main__ - INFO - Saved model at epoch 0
2025-01-10 08:02:59,619 - __main__ - INFO - Saved model at epoch 0
INFO:__main__:Saved model at epoch 0


MLP input shape after permute: torch.Size([2, 64, 10, 12, 14])
MLP input shape after permute: torch.Size([2, 64, 10, 12, 14])
MLP input shape after permute: torch.Size([2, 64, 20, 24, 28])
MLP input shape after permute: torch.Size([2, 64, 20, 24, 28])
MLP input shape after permute: torch.Size([2, 64, 20, 24, 28])
MLP input shape after permute: torch.Size([2, 64, 10, 12, 14])
MLP input shape after permute: torch.Size([2, 64, 10, 12, 14])
MLP input shape after permute: torch.Size([2, 64, 20, 24, 28])
MLP input shape after permute: torch.Size([2, 64, 20, 24, 28])
MLP input shape after permute: torch.Size([2, 64, 20, 24, 28])
MLP input shape after permute: torch.Size([2, 64, 10, 12, 14])
MLP input shape after permute: torch.Size([2, 64, 10, 12, 14])
MLP input shape after permute: torch.Size([2, 64, 20, 24, 28])
MLP input shape after permute: torch.Size([2, 64, 20, 24, 28])
MLP input shape after permute: torch.Size([2, 64, 20, 24, 28])
MLP input shape after permute: torch.Size([2, 64, 10, 1

INFO: `Trainer.fit` stopped: `max_epochs=5` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.


# Tests

In [None]:

# # Define a dummy configuration
# config = {
#     'WO_SELF_ATT': False,  # Add this parameter

#     '_NUM_CROSS_ATT': -1,
#     'out_fmaps': ['P4', 'P3', 'P2', 'P1'],
#     'scale_level_df': 'P1',
#     'upsample_df': True,
#     'upsample_scale_factor': 2,
#     'fpn_channels': 64,
#     'start_channels': 32,
#     'patch_size': 2,
#     'bspl': False,

#     'backbone_net': 'fpn',
#     'in_channels': 1,
#     'data_size': [40, 48, 56],  # Adjusted to match the dummy data dimensions below
#     'bias': True,
#     'norm_type': 'instance',
#     'cuda': 0,
#     'kernel_size': 3,
#     'depths': [1],
#     'mlp_ratio': 2,

#     'num_heads': [32],
#     'drop_path_rate': 0.,
#     'qkv_bias': True,
#     'drop_rate': 0.,
#     'attn_drop_rate': 0.,

#     'use_seg_loss': False,
#     'use_seg_proxy_loss': False,
#     'num_organs': -1
# }

# # Instantiate the model
# model = HierarchicalViT_Light(config)

# # Move model to appropriate device
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)

# # Create dummy data
# # Assuming input dimensions as per config['data_size'] = [40, 48, 56]
# B, C, H, W, D = 1, 1, 40, 48, 56
# source = torch.rand([B, C, H, W, D]).to(device)
# target = torch.rand([B, C, H, W, D]).to(device)

# # Run a forward pass
# with torch.no_grad():
#     moved, flow = model(source, target)
#     print(f'moved shape: {moved.shape}, flow shape: {flow.shape}')

# # Check trainable parameters
# total_params = count_parameters(model)
# print(f"Total trainable parameters: {total_params / 1e6:.5f} million")


In [None]:
# Sanity checks for the Attention module and the get_patches / get_image helpers.
import torch
from torch import nn
device='cpu'
# Initialize Attention
attention = Attention(dim=64, num_heads=8, patch_size=2, attention_type="local").to(device)
x = torch.rand(4, 16, 64).to(device)  # B_, N, C

# Forward pass: the output is expected to keep the input shape
output = attention(x)
print(f'Attention output shape: {output.shape}')
assert output.shape == (4, 16, 64), "Attention output shape mismatch"
x = torch.rand(2, 40, 48, 56, 16).to(device)  # B, H, W, D, C
patch_size = 4

# Extract patches
patches, H, W, D = get_patches(x, patch_size)
print(f'Patches shape: {patches.shape}')

# Reconstruct image from the patches and check that the round trip reproduces x
reconstructed = get_image(patches, patch_size, H, W, D, x.shape[1], x.shape[2], x.shape[3])
print(f'Reconstructed shape: {reconstructed.shape}')
assert torch.allclose(x, reconstructed, atol=1e-5), "Reconstruction mismatch"


Attention output shape: torch.Size([4, 16, 64])


In [None]:

class ViTBlock(nn.Module):
    """
    Vision Transformer Block.

    Runs two residual sub-layers on a channels-last 5-D tensor:

    1. ``norm1`` -> split into windows of ``patch_size ** 3`` tokens ->
       ``Attention`` -> reassemble the volume -> residual add.
    2. ``norm2`` -> 3x3x3 ``Conv3dReLU`` expansion -> 1x1x1 projection back to
       ``embed_dim`` -> residual add.

    Both residual branches are scaled by a learnable per-channel ``gamma``
    when ``layer_scale`` is a number, and pass through ``drop_path``.

    Args:
        embed_dim: Number of channels of the input and output.
        input_dims: Spatial size of the input; only used to compute
            ``num_windows``.
        num_heads: Number of attention heads.
        mlp_type: Accepted for interface compatibility; not used by this block.
        patch_size: Edge length of the cubic attention window.
        mlp_ratio: Expansion factor of the hidden MLP width.
        qkv_bias: Forwarded to ``Attention``.
        qk_scale: Forwarded to ``Attention``.
        drop: Forwarded to ``Attention`` as ``proj_drop``.
        attn_drop: Forwarded to ``Attention``.
        drop_path: Drop-path rate; ``0`` disables it (``nn.Identity``).
        act_layer: Accepted for interface compatibility; not used by this block.
        attention_type: Forwarded to ``Attention``.
        norm_layer: Factory called as ``norm_layer(embed_dim)``.
        layer_scale: Initial value of the per-channel residual scales, or
            ``None`` (or any non-number) to use a fixed scale of ``1.0``.
    """
    def __init__(self,
                 embed_dim: int,
                 input_dims: List[int],
                 num_heads: int,
                 mlp_type: str,
                 patch_size: int,
                 mlp_ratio: float,
                 qkv_bias: bool,
                 qk_scale: Optional[float],
                 drop: float,
                 attn_drop: float,
                 drop_path: float,
                 act_layer: str,
                 attention_type: str,
                 norm_layer: Callable[..., nn.Module],
                 layer_scale: Optional[float]):
        super().__init__()
        self.patch_size = patch_size
        self.num_windows = prod_func([d // patch_size for d in input_dims])

        self.norm1 = norm_layer(embed_dim)
        self.attn = Attention(
            dim=embed_dim,
            num_heads=num_heads,
            patch_size=patch_size,
            attention_type=attention_type,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        self.drop_path = timm_DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(embed_dim)

        # Hidden width of the convolutional MLP, shared by the two layers below.
        hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = Conv3dReLU(
            in_channels=embed_dim,
            out_channels=hidden_dim,
            kernel_size=3,  # Assuming kernel_size=3 for MLP
            padding=1,
            stride=1,
            use_batchnorm=True,
        )

        # Projection layer so the MLP branch returns to embed_dim channels.
        self.proj = nn.Conv3d(
            in_channels=hidden_dim,
            out_channels=embed_dim,
            kernel_size=1
        )

        # isinstance already rejects None, so no separate None check is needed.
        self.layer_scale = isinstance(layer_scale, (int, float))
        if self.layer_scale:
            self.gamma1 = nn.Parameter(layer_scale * torch.ones(embed_dim), requires_grad=True)
            self.gamma2 = nn.Parameter(layer_scale * torch.ones(embed_dim), requires_grad=True)
        else:
            self.gamma1 = 1.0
            self.gamma2 = 1.0

    def forward(self, x: Tensor, q_ms: Optional[Tensor]) -> Tensor:
        """
        Apply the attention and MLP sub-layers.

        Args:
            x: Input of shape ``(B, H, W, D, C)``.
            q_ms: Optional extra input passed unchanged to the attention module.

        Returns:
            Tensor with the same ``(B, H, W, D, C)`` layout as ``x``.
        """
        _, H, W, D, C = x.size()
        shortcut = x

        # Normalize, then flatten each window into a sequence of tokens.
        x = self.norm1(x)
        x_windows, Hatt, Watt, Datt = get_patches(x, self.patch_size)
        x_windows = x_windows.view(-1, self.patch_size ** 3, C)

        # Attend within windows and stitch the windows back into a volume.
        attn_windows = self.attn(x_windows, q_ms)
        x = get_image(attn_windows, self.patch_size, Hatt, Watt, Datt, H, W, D)

        # First residual connection (gamma broadcasts over the channel axis).
        x = shortcut + self.drop_path(self.gamma1 * x)

        # The conv MLP works channels-first, so permute in and back out.
        x_mlp_input = self.norm2(x).permute(0, 4, 1, 2, 3)
        x_mlp_output = self.mlp(x_mlp_input)
        x_mlp_output = self.proj(x_mlp_output).permute(0, 2, 3, 4, 1)

        # Second residual connection.
        x = x + self.drop_path(self.gamma2 * x_mlp_output)
        return x


# Smoke test: a ViTBlock should return a tensor with the same shape as its input.
# mlp_type and act_layer are required by the signature but ignored by ViTBlock.
vit_block = ViTBlock(embed_dim=64, input_dims=[40, 48, 56], num_heads=4, mlp_type="basic",
                     patch_size=4, mlp_ratio=2, qkv_bias=True, qk_scale=None, drop=0.1,
                     attn_drop=0.1, drop_path=0.1, act_layer="relu", attention_type="local",
                     norm_layer=nn.LayerNorm, layer_scale=1e-5).to(device)

# Channels-last input (B, H, W, D, C); the last axis must equal embed_dim.
x = torch.rand(2, 40, 48, 56, 64).to(device)  # Correct input shape
output = vit_block(x, q_ms=None)
print(f'ViTBlock output shape: {output.shape}')
assert output.shape == (2, 40, 48, 56, 64), "ViTBlock output shape mismatch"


MLP input shape after permute: torch.Size([2, 64, 40, 48, 56])
ViTBlock output shape: torch.Size([2, 40, 48, 56, 64])


In [None]:
# Smoke test: run PatchEmbed on a random single-channel volume and print the
# resulting shape (no assertion is made here).
patch_embed = PatchEmbed(in_chans=1, out_chans=32, kernel_size=3, stride=1).to(device)
x = torch.rand(2, 1, 40, 48, 56).to(device)  # B, C, H, W, D

output = patch_embed(x)
print(f'PatchEmbed output shape: {output.shape}')


PatchEmbed output shape: torch.Size([2, 32, 40, 48, 56])


In [None]:
# Smoke test: build a depth-2 ViTLayer and print the shape it produces
# (no assertion is made here).
vit_layer = ViTLayer(attention_type="local", dim=64, dim_out=64, depth=2, input_dims=[40, 48, 56],
                     num_heads=4, patch_size=4, mlp_type="basic", mlp_ratio=2, qkv_bias=True,
                     qk_scale=None, drop=0.1, attn_drop=0.1, drop_path=0.1, norm_layer=nn.LayerNorm,
                     norm_type="instance", layer_scale=1e-5, act_layer="relu").to(device)

# NOTE(review): presumably channels-first (B, C, H, W, D) with C == dim -- confirm against ViTLayer.
x = torch.rand(2, 64, 40, 48, 56).to(device)
output = vit_layer(x, q_ms=None, CONCAT_ok=False)
print(f'ViTLayer output shape: {output.shape}')


MLP input shape after permute: torch.Size([2, 64, 40, 48, 56])
MLP input shape after permute: torch.Size([2, 64, 40, 48, 56])
ViTLayer output shape: torch.Size([2, 64, 40, 48, 56])


In [None]:

class Decoder(nn.Module):
    """
    Decoder module for the hierarchical vision transformer.

    The decoder has two parts:

    1. A feature-pyramid style CNN path: 1x1x1 lateral convolutions on the
       encoder feature maps, a top-down pathway of transposed convolutions
       whose outputs are added to the lateral maps, and one 3x3x3 output
       convolution per required stage.
    2. A hierarchical path that walks the stages from the highest stage index
       to the lowest. The first level is passed through unchanged; every
       later level is produced by a ``ViT`` module that receives the current
       map together with all previously produced levels.

    Config keys read:
        Required: ``num_stages``, ``use_seg_loss``, ``start_channels``,
        ``out_fmaps``, ``fpn_channels``, ``data_size``; ``strides`` when more
        than one lateral level exists; ``num_organs`` when ``use_seg_loss`` is
        true.
        Optional (with defaults): ``NUM_CROSS_ATT``, ``depths``, ``num_heads``,
        ``patch_size``, ``mlp_ratio``, ``drop_path_rate``, ``qkv_bias``,
        ``drop_rate``, ``attn_drop_rate``.

    The last character of every ``out_fmaps`` entry is parsed as the stage
    index. The hierarchical path indexes every stage between the smallest and
    largest required stage, so the required stages must be contiguous.
    """
    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        self._num_stages: int = config['num_stages']
        self.use_seg: bool = config['use_seg_loss']

        # Channels of the encoder feature maps double with every stage. Plain
        # Python ints are used so the layers below receive int sizes rather
        # than 0-dim tensors.
        encoder_out_channels: List[int] = [
            int(config['start_channels'] * 2 ** stage) for stage in range(self._num_stages)
        ]
        self._NUM_CROSS_ATT = config.get('NUM_CROSS_ATT', -1)

        # Stage indices requested as outputs. The set is kept as an attribute;
        # every place that pairs stages with layers iterates it in sorted
        # order so the pairing never depends on set iteration order.
        required_stages: Set[int] = {int(fmap[-1]) for fmap in config['out_fmaps']}
        self._required_stages: Set[int] = required_stages
        ordered_stages: List[int] = sorted(required_stages)

        earliest_required_stage: int = min(required_stages)
        fpn_channels: int = int(config['fpn_channels'])

        # Lateral connections: one 1x1x1 conv per stage from the earliest
        # required stage upwards, with output width capped at fpn_channels.
        lateral_in_channels: List[int] = encoder_out_channels[earliest_required_stage:]
        lateral_out_channels: List[int] = [min(ch, fpn_channels) for ch in lateral_in_channels]

        self._lateral: nn.ModuleList = nn.ModuleList([
            nn.Conv3d(in_channels=in_ch, out_channels=out_ch, kernel_size=1)
            for in_ch, out_ch in zip(lateral_in_channels, lateral_out_channels)
        ])
        self._lateral_levels: int = len(self._lateral)

        # Output layers: one 3x3x3 conv per required stage. The negative index
        # counts back from the last stage, which is the last lateral entry.
        out_in_channels: List[int] = [
            lateral_out_channels[-self._num_stages + required_stage]
            for required_stage in ordered_stages
        ]
        out_out_channels: List[int] = [fpn_channels] * len(out_in_channels)

        self._out: nn.ModuleList = nn.ModuleList([
            nn.Conv3d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, padding=1)
            for in_ch, out_ch in zip(out_in_channels, out_out_channels)
        ])

        # Upsampling layers for the top-down pathway, ordered from the last
        # lateral level towards the first. 'strides' is only read when at
        # least one upsampling layer is needed.
        top_down_channels: List[int] = lateral_out_channels[::-1]
        top_down_strides: List[Any] = (
            list(reversed(config['strides'])) if len(top_down_channels) > 1 else []
        )
        self._up: nn.ModuleList = nn.ModuleList([
            nn.ConvTranspose3d(
                in_channels=top_down_channels[level],
                out_channels=top_down_channels[level + 1],
                kernel_size=top_down_strides[level],
                stride=top_down_strides[level],
            )
            for level in range(len(top_down_channels) - 1)
        ])

        # Multi-scale attention
        self.hierarchical_dec: nn.ModuleList = self._create_hierarchical_layers(config, out_out_channels)

        if self.use_seg:
            # One 1x1x1 head per output level; num_organs + 1 output channels.
            self._seg_head: nn.ModuleList = nn.ModuleList([
                nn.Conv3d(out_ch, config['num_organs'] + 1, kernel_size=1, stride=1)
                for out_ch in out_out_channels
            ])

    def _create_hierarchical_layers(self, config: Dict[str, Any], out_out_channels: List[int]) -> nn.ModuleList:
        """
        Create hierarchical layers for multi-scale attention.

        Level 0 is an ``nn.Identity``; every later level ``k`` is a ``ViT``
        built from the channel counts and image sizes of levels ``0..k``.

        Args:
            config: Model configuration (see the class docstring).
            out_out_channels: Output channel count of each level.

        Returns:
            ``nn.ModuleList`` with one module per entry of ``out_out_channels``.
        """
        out: nn.ModuleList = nn.ModuleList()
        img_size: List[List[int]] = []
        feats_num: List[int] = []

        num_levels = len(out_out_channels)

        # Loop-invariant settings, looked up once.
        depths = config.get('depths', [1] * num_levels)
        num_heads = config.get('num_heads', [32] * num_levels)

        for k, out_ch in enumerate(out_out_channels):
            # Level k works on data_size divided by 2 ** (num_stages - k - 1).
            img_size.append([int(item / (2 ** (self._num_stages - k - 1))) for item in config['data_size']])
            feats_num.append(out_ch)

            if k == 0:
                out.append(nn.Identity())
                continue

            # NOTE(review): the same growing `feats_num` / `img_size` list
            # objects are handed to every ViT. If ViT keeps a reference rather
            # than copying, later appends become visible to earlier levels --
            # confirm this is intended.
            out.append(
                ViT(
                    NUM_CROSS_ATT=config.get('NUM_CROSS_ATT', self._NUM_CROSS_ATT),
                    PYR_SCALES=[1.],
                    feats_num=feats_num,
                    hid_dim=int(config.get('fpn_channels', 64)),
                    depths=depths,
                    patch_size=config.get('patch_size', [2] * len(feats_num)),
                    mlp_ratio=int(config.get('mlp_ratio', 2)),
                    num_heads=num_heads,
                    mlp_type='basic',
                    norm_type='BatchNorm2d',
                    act_layer='gelu',
                    drop_path_rate=config.get('drop_path_rate', 0.2),
                    qkv_bias=config.get('qkv_bias', True),
                    qk_scale=None,
                    drop_rate=config.get('drop_rate', 0.),
                    attn_drop_rate=config.get('attn_drop_rate', 0.),
                    norm_layer=nn.LayerNorm,
                    layer_scale=1e-5,
                    img_size=img_size
                )
            )
        return out

    def forward(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """
        Forward pass of the Decoder.

        Args:
            x: Encoder feature maps in insertion order; only the last
                ``_lateral_levels`` values are used.

        Returns:
            Dict with a ``'P<stage>'`` entry per stage and, when segmentation
            is enabled, a matching ``'S<stage>'`` entry.
        """
        encoder_fmaps: List[Tensor] = list(x.values())[-self._lateral_levels:]
        lateral_out: List[Tensor] = [
            lateral(fmap) for lateral, fmap in zip(self._lateral, encoder_fmaps)
        ]

        # Top-down pathway: start at the last lateral map and add each
        # upsampled result to the next lateral map.
        up_out: List[Tensor] = []
        up: Optional[Tensor] = None
        for idx, fmap in enumerate(reversed(lateral_out)):
            if idx != 0:
                fmap = fmap + up

            if idx < self._lateral_levels - 1:
                up = self._up[idx](fmap)

            up_out.append(fmap)

        # Pair maps with stages in ascending stage order, matching the order
        # in which the output convolutions were created.
        cnn_outputs: Dict[int, Tensor] = {
            stage: self._out[idx](fmap)
            for idx, (fmap, stage) in enumerate(zip(reversed(up_out), sorted(self._required_stages)))
        }
        return self._forward_hierarchical(cnn_outputs)

    def _forward_hierarchical(self, cnn_outputs: Dict[int, Tensor]) -> Dict[str, Tensor]:
        """
        Forward pass through the hierarchical decoder.

        Args:
            cnn_outputs: Output-conv feature map per stage index; must hold
                every stage between its smallest and largest key.

        Returns:
            Dict of ``'P<stage>'`` maps (plus ``'S<stage>'`` when segmentation
            is enabled).
        """
        # Highest stage index first, down to the lowest.
        stage_keys = range(max(cnn_outputs.keys()), min(cnn_outputs.keys()) - 1, -1)
        xs: List[Tensor] = [cnn_outputs[key].clone() for key in stage_keys]

        out_dict: Dict[str, Tensor] = {}
        QK: List[Tensor] = []
        for i, key in enumerate(stage_keys):
            # Newest level goes first; earlier results follow it.
            QK = [xs[i]] + QK
            if i == 0:
                Pi = QK[0]
            else:
                Pi = self.hierarchical_dec[i](QK)
            # Later levels see the refined map rather than the raw one.
            QK[0] = Pi
            out_dict[f'P{key}'] = Pi

            if self.use_seg:
                Pi_seg = self._seg_head[i](Pi)
                out_dict[f'S{key}'] = Pi_seg

        return out_dict




In [None]:
def count_parameters(model):
    """Return the total number of elements across the trainable parameters of ``model``.

    Only parameters whose ``requires_grad`` flag is set are counted.
    """
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable

# Report the trainable-parameter count of the global `model` (created in another cell), in millions.
total_params = count_parameters(model)
print(f"Total trainable parameters: {total_params / 1e6:.5f} million")


Total trainable parameters: 7.24813 million
