<a href="https://colab.research.google.com/github/alim98/Thesis/blob/main/UTSRMorph.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
################################################################################
# 1) INSTALL DEPENDENCIES
################################################################################

!pip install einops timm ml_collections

################################################################################
# 2) IMPORTS
################################################################################

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import math
import numpy as np

from timm.models.layers import DropPath, trunc_normal_, to_3tuple
from einops import rearrange
from torch.distributions.normal import Normal

# needed for the config dictionary
import ml_collections
from typing import Dict, List, Optional, Sequence, Tuple, Union

Collecting ml_collections
  Downloading ml_collections-1.0.0-py3-none-any.whl.metadata (22 kB)
Downloading ml_collections-1.0.0-py3-none-any.whl (76 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.5/76.5 kB[0m [31m6.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: ml_collections
Successfully installed ml_collections-1.0.0




In [None]:


################################################################################
# 3) CONFIGS (merged from configs_UTSRMorph_dice.py as an example)
################################################################################

def get_UTSRMorph_config():
    """
    Return the default UTSRMorph hyper-parameters as an ``ml_collections.ConfigDict``.

    Edit the returned object (or the table below) to customise the model.
    """
    defaults = (
        ('if_transskip', True),     # use Transformer skip connections
        ('if_convskip', True),      # use convolutional skip connections
        ('patch_size', 4),
        ('in_chans', 2),
        ('embed_dim', 96),
        ('depths', (2, 2, 2, 2)),
        ('num_heads', (4, 4, 4, 4)),
        ('window_size', (5, 6, 7)),
        ('mlp_ratio', 4),
        ('pat_merg_rf', 4),
        ('qkv_bias', False),
        ('drop_rate', 0),
        ('drop_path_rate', 0.3),
        ('ape', False),
        ('spe', False),
        ('rpe', True),
        ('patch_norm', True),
        ('use_checkpoint', False),
        ('out_indices', (0, 1, 2, 3)),
        ('reg_head_chan', 16),
        ('img_size', (160, 192, 224)),
    )
    config = ml_collections.ConfigDict()
    for name, value in defaults:
        setattr(config, name, value)
    return config


################################################################################
# 4) MODEL CODE (merged from UTSRMorph_dice.py and internal references)
################################################################################

############################
# Channel Attention (CA)
############################
class CA(nn.Module):
    """
    3-D channel attention in the style of RCAN (squeeze-and-excitation).

    Each channel is global-average-pooled and passed through a 1x1x1
    bottleneck (num_feat -> num_feat // squeeze_factor -> num_feat). The
    input channels are then rescaled by the resulting sigmoid gates.
    """
    def __init__(self, num_feat, squeeze_factor=16):
        super().__init__()
        hidden = num_feat // squeeze_factor
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool3d(1),
            nn.Conv3d(num_feat, hidden, 1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv3d(hidden, num_feat, 1, padding=0),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # gates: (B, C, 1, 1, 1), broadcast over the spatial axes
        gates = self.attention(x)
        return x * gates

############################
# CAB: small Conv+CA block
############################
class CAB(nn.Module):
    """
    Convolutional channel-attention block.

    Applies two 3x3x3 convolutions (num_feat -> num_feat // compress_ratio
    -> num_feat) with a GELU between them, followed by ``CA`` channel
    re-weighting. Channel count and spatial size are preserved.
    """
    def __init__(self, num_feat, compress_ratio=3, squeeze_factor=30):
        super().__init__()
        mid = num_feat // compress_ratio
        self.cab = nn.Sequential(
            nn.Conv3d(num_feat, mid, 3, 1, 1),
            nn.GELU(),
            nn.Conv3d(mid, num_feat, 3, 1, 1),
            CA(num_feat, squeeze_factor),
        )

    def forward(self, x):
        out = self.cab(x)
        return out

############################
# Basic MLP
############################
class Mlp(nn.Module):
    """
    Two-layer feed-forward network: Linear -> act -> Dropout -> Linear -> Dropout.

    ``hidden_features`` and ``out_features`` default to ``in_features``.
    One dropout module (probability ``drop``) is shared by both positions.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

############################
# Patch partition & reverse
############################
def window_partition(x, window_size):
    """
    Cut a channels-last volume into non-overlapping 3-D windows.

    x: (B, H, W, L, C); H, W, L must be multiples of the window sides.
    window_size: (Wx, Wy, Wz)
    Returns (B * num_windows, Wx, Wy, Wz, C). Windows are ordered batch first,
    then by window index along H, W and L.
    """
    batch, size_h, size_w, size_l, chans = x.shape
    win_h, win_w, win_l = window_size
    grid = x.view(batch,
                  size_h // win_h, win_h,
                  size_w // win_w, win_w,
                  size_l // win_l, win_l,
                  chans)
    # move the three window-index axes next to the batch axis
    grid = grid.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous()
    return grid.view(-1, win_h, win_w, win_l, chans)

def window_reverse(windows, window_size, H, W, L):
    """
    Stitch windows produced by ``window_partition`` back into a volume.

    windows: (num_windows*B, Wx, Wy, Wz, C)
    window_size: (Wx, Wy, Wz)
    H, W, L: spatial size of the reconstructed volume.
    Returns (B, H, W, L, C).
    """
    win_h, win_w, win_l = window_size
    # windows per volume = (H/Wx)(W/Wy)(L/Wz); the batch size is what remains
    num_batch = int(windows.shape[0] / (H * W * L / win_h / win_w / win_l))
    grid = windows.view(num_batch,
                        H // win_h, W // win_w, L // win_l,
                        win_h, win_w, win_l,
                        -1)
    # interleave window-index and intra-window axes: (B, nH, Wx, nW, Wy, nL, Wz, C)
    grid = grid.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous()
    return grid.view(num_batch, H, W, L, -1)

############################
# 3D Unfold Helpers (for OAB)
############################
def filter_dilated_rows(tensor: torch.Tensor,
                        dilation: Tuple[int, int, int],
                        dilated_kernel_size: Tuple[int, int, int],
                        kernel_size: Tuple[int, int, int]) -> torch.Tensor:
    """
    Drop the "holes" of a dilated kernel from unfolded patches.

    The trailing ``len(kernel_size)`` axes of ``tensor`` span the dilated
    kernel, with lengths ``dilated_kernel_size``. Along axis ``i`` only every
    ``dilation[i]``-th entry (starting at 0) is a real kernel tap.

    Strided slicing is used rather than a NumPy round-trip. The result
    therefore stays on the same device with the same dtype, and it remains
    differentiable. NumPy conversion fails for tensors that require grad and
    for dtypes such as bfloat16. The result is a view of ``tensor``.
    """
    kernel_rank = len(kernel_size)
    axis_offset = tensor.dim() - kernel_rank
    index = [slice(None)] * tensor.dim()
    for i in range(kernel_rank):
        index[axis_offset + i] = slice(0, dilated_kernel_size[i], dilation[i])
    return tensor[tuple(index)]

def unfold3d(tensor: torch.Tensor,
             kernel_size: Union[int, Tuple[int,int,int]],
             padding: Union[int, Tuple[int,int,int]]=0,
             stride: Union[int, Tuple[int,int,int]]=1,
             dilation: Union[int, Tuple[int,int,int]]=1):
    """
    3-D analogue of ``torch.nn.functional.unfold`` (im2col for volumes).

    Args:
        tensor: 5-D input ``(B, C, D, H, W)``.
        kernel_size, padding, stride, dilation: int (same for all axes) or a
            length-3 sequence. Padding is zero padding applied symmetrically
            on both sides of each spatial axis.

    Returns:
        Tensor of shape ``(B, C * kD * kH * kW, L)``, where ``L`` is the
        number of sliding-window positions ``D_out * H_out * W_out``. The
        second axis is laid out channel-major, i.e. ``(C, kD, kH, kW)``
        flattened, as in ``F.unfold``.

    Raises:
        ValueError: if ``tensor`` is not 5-D.
    """
    if len(tensor.shape) != 5:
        raise ValueError(f"Input must be 5D [B, C, D, H, W]. Got {tensor.shape}")

    def _as_triple(value):
        # int -> same value on every axis; any length-3 sequence -> tuple
        if isinstance(value, int):
            return (value, value, value)
        return tuple(value)

    kernel_size = _as_triple(kernel_size)
    padding = _as_triple(padding)
    stride = _as_triple(stride)
    dilation = _as_triple(dilation)

    B, C = tensor.shape[0], tensor.shape[1]
    pad_d, pad_h, pad_w = padding
    # F.pad lists the last axis first: W, then H, then D
    tensor = F.pad(tensor, (pad_w, pad_w, pad_h, pad_h, pad_d, pad_d))

    # span covered by one dilated kernel: k + (k-1)*(d-1)
    dilated_kernel = tuple(k + (k - 1) * (d - 1) for k, d in zip(kernel_size, dilation))

    # sliding windows along D, H, W -> (B, C, D_out, H_out, W_out, kD', kH', kW')
    for axis, (size, step) in enumerate(zip(dilated_kernel, stride), start=2):
        tensor = tensor.unfold(axis, size, step)

    # keep only the real taps when the kernel is dilated
    if dilation != (1, 1, 1):
        tensor = filter_dilated_rows(tensor, dilation, dilated_kernel, kernel_size)

    # (B, D_out, H_out, W_out, C, kD, kH, kW) -> (B, L, C*kD*kH*kW) -> (B, C*kD*kH*kW, L)
    patch_len = C * kernel_size[0] * kernel_size[1] * kernel_size[2]
    tensor = tensor.permute(0, 2, 3, 4, 1, 5, 6, 7).reshape(B, -1, patch_len)
    return tensor.transpose(1, 2)

############################
# OAB: Overlapping Attention
############################
class OAB(nn.Module):
    """
    Overlapping cross-attention block.

      - Queries come from non-overlapping cubic windows of side ``window_size``.
      - Keys/values come from a larger cubic window of side
        ``overlap_win_size = window_size + int(window_size * overlap_ratio)``,
        centred on each query window, so neighbouring windows share context.

    Both sub-layers are residual:
        x = x + proj(attention(norm1(x)))
        x = x + mlp(norm2(x))

    The caller must set ``self.H, self.W, self.T`` (token-grid size) before
    calling ``forward``.
    """
    def __init__(self, dim, window_size, overlap_ratio, num_heads,
                 qkv_bias=True, qk_scale=None, mlp_ratio=2,
                 norm_layer=nn.LayerNorm, rpe=True):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        # side of the key/value window: the query window enlarged by overlap_ratio
        self.overlap_win_size = int(window_size * overlap_ratio) + window_size

        self.norm1 = norm_layer(dim)
        self.qkv = nn.Linear(dim, dim*3, bias=qkv_bias)

        # one learned bias row per (query voxel, key voxel) relative offset
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((window_size + self.overlap_win_size - 1)**3, num_heads)
        )
        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.register_buffer("relative_position_index_OAB",
            self._create_rpe_index(window_size, self.overlap_win_size))

        self.softmax = nn.Softmax(dim=-1)
        self.proj = nn.Linear(dim, dim)

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(dim, mlp_hidden_dim, act_layer=nn.GELU)
        # spatial size of the token grid; assigned by the caller before forward()
        self.H = self.W = self.T = None
        self.rpe = rpe

    def _create_rpe_index(self, window_size, window_size_ext):
        """
        Build the flat relative-position index between a ws^3 query window
        and a wse^3 key window (ws = window_size, wse = window_size_ext).

        Returns a 1-D int64 tensor of length ws^3 * wse^3 (query-major) whose
        entries index rows of ``relative_position_bias_table``.
        """
        ws = window_size
        wse = window_size_ext
        coords_h = torch.arange(ws)
        coords_w = torch.arange(ws)
        coords_t = torch.arange(ws)
        coords_ori = torch.stack(torch.meshgrid([coords_h, coords_w, coords_t]))
        coords_ori_flat = coords_ori.flatten(1)  # shape (3, ws^3)

        coords_he = torch.arange(wse)
        coords_we = torch.arange(wse)
        coords_te = torch.arange(wse)
        coords_ext = torch.stack(torch.meshgrid([coords_he, coords_we, coords_te]))
        coords_ext_flat = coords_ext.flatten(1)  # shape (3, wse^3)

        rel_coords = coords_ext_flat[:,None,:] - coords_ori_flat[:,:,None]  # (3, ws^3, wse^3)
        rel_coords = rel_coords.permute(1,2,0).contiguous()  # (ws^3, wse^3, 3)

        # Per-axis offsets (key - query) lie in [-(ws-1), wse-1], i.e. ws+wse-1
        # distinct values. This shift maps them to [2-wse, ws], which can be
        # NEGATIVE. The combined index below still spans (ws+wse-1)^3
        # consecutive integers, and negative entries wrap around via PyTorch's
        # negative indexing, so every offset gets its own table row.
        rel_coords[...,0] += ws - wse + 1
        rel_coords[...,1] += ws - wse + 1
        rel_coords[...,2] += ws - wse + 1

        # mixed-radix flattening of the three per-axis offsets
        factor = (ws + wse - 1)*(ws + wse - 1)
        rel_coords[...,0] *= factor
        rel_coords[...,1] *= (ws + wse - 1)
        rel_index = rel_coords.sum(-1)  # (ws^3, wse^3)
        return rel_index.view(-1)

    def forward(self, x, mask=None):
        """
        Apply overlapping cross-attention and the MLP, both residual.

        x: (B, H*W*T, C) tokens on the grid given by ``self.H, self.W, self.T``.
        mask: unused; accepted so the block can be called like FAB.
        Returns a tensor with the same shape as ``x``.
        """
        B, L, C = x.shape
        H, W, T = self.H, self.W, self.T
        assert L == H*W*T
        ws = self.window_size
        ow = self.overlap_win_size

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, T, C)

        # pad at the far end of each axis up to a multiple of ws
        pad_r = (ws - H % ws) % ws
        pad_b = (ws - W % ws) % ws
        pad_h = (ws - T % ws) % ws
        x = F.pad(x, (0,0, 0,pad_h, 0,pad_b, 0,pad_r))
        _, Hp, Wp, Tp, _ = x.shape

        # fused Q/K/V projection -> (3, B, C, Hp, Wp, Tp)
        qkv = self.qkv(x).reshape(B, Hp, Wp, Tp, 3, C).permute(4,0,5,1,2,3)
        q = qkv[0].permute(0,2,3,4,1)            # (B, Hp, Wp, Tp, C)
        kv = torch.cat([qkv[1], qkv[2]], dim=1)  # (B, 2*C, Hp, Wp, Tp): K channels, then V

        # queries: non-overlapping ws^3 windows -> (B*nW, ws^3, C)
        q_windows = window_partition(q, (ws,)*3)
        q_windows = q_windows.view(-1, ws**3, C)

        # Keys/values: ow^3 windows at stride ws, centred on their query window.
        # Pad (ow - ws) voxels per axis, split as evenly as possible (any extra
        # voxel goes on the far side). Then a stride-ws unfold yields exactly
        # (Hp/ws)*(Wp/ws)*(Tp/ws) windows, in the same order as window_partition.
        pad_lo = (ow - ws) // 2
        pad_hi = (ow - ws) - pad_lo
        kv = F.pad(kv, (pad_lo, pad_hi) * 3)
        kv_windows = unfold3d(kv, kernel_size=ow, stride=ws, padding=0)
        # kv_windows: (B, 2*C*ow^3, nW), channel-major layout (2*C, ow, ow, ow)

        kv_windows = rearrange(
            kv_windows,
            'b (nc ch owh oww owt) nw -> nc (b nw) (owh oww owt) ch',
            nc=2, ch=C, owh=ow, oww=ow, owt=ow
        )
        k_windows, v_windows = kv_windows[0], kv_windows[1]  # (B*nW, ow^3, C) each

        # multi-head attention
        b_, nq, _ = q_windows.shape
        _, n_, _ = k_windows.shape
        d = self.dim // self.num_heads

        q = q_windows.reshape(b_, nq, self.num_heads, d).permute(0, 2, 1, 3)
        k = k_windows.reshape(b_, n_, self.num_heads, d).permute(0, 2, 1, 3)
        v = v_windows.reshape(b_, n_, self.num_heads, d).permute(0, 2, 1, 3)

        q = q*self.scale
        attn = q @ k.transpose(-2, -1)  # (B*nW, heads, ws^3, ow^3)

        if self.rpe:
            rpb = self.relative_position_bias_table[self.relative_position_index_OAB.view(-1)]
            rpb = rpb.view(ws**3, ow**3, self.num_heads).permute(2,0,1)  # (heads, ws^3, ow^3)
            attn = attn + rpb.unsqueeze(0)

        attn = self.softmax(attn)
        out = (attn @ v).transpose(1,2).reshape(b_, nq, self.dim)

        # merge query windows back into the padded volume, then crop the padding
        out = out.view(-1, ws, ws, ws, C)
        out = window_reverse(out, (ws,)*3, Hp, Wp, Tp)
        if pad_r>0 or pad_b>0 or pad_h>0:
            out = out[:, :H, :W, :T, :].contiguous()

        out = out.view(B, H*W*T, C)
        out = self.proj(out) + shortcut

        # final MLP (residual)
        out = out + self.mlp(self.norm2(out))
        return out


############################
# WindowAttention: W-MSA
############################
class WindowAttention(nn.Module):
    """
    Multi-head self-attention inside 3-D windows (W-MSA / SW-MSA).

    Args:
        dim: token channels C.
        window_size: (wH, wW, wT); each input sequence holds wH*wW*wT tokens.
        num_heads: number of heads; C is split evenly among them.
        qkv_bias: learnable bias on the fused q/k/v projection.
        qk_scale: replaces the default 1/sqrt(head_dim) factor when truthy.
        rpe: add a learned relative-position bias to the attention logits.
        attn_drop, proj_drop: dropout on attention weights / projected output.
    """
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None,
                 rpe=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.scale = qk_scale or (dim // num_heads)**-0.5
        self.rpe = rpe

        wh, ww, wt = window_size[0], window_size[1], window_size[2]
        # per-axis offsets range over [-(w-1), w-1] -> 2w-1 distinct values
        span_h, span_w, span_t = 2*wh - 1, 2*ww - 1, 2*wt - 1
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(span_h * span_w * span_t, num_heads))
        trunc_normal_(self.relative_position_bias_table, std=.02)

        if rpe:
            grid = torch.stack(torch.meshgrid([torch.arange(wh),
                                               torch.arange(ww),
                                               torch.arange(wt)]))  # (3, wH, wW, wT)
            flat = grid.flatten(1)                                    # (3, n)
            offsets = (flat[:, :, None] - flat[:, None, :]).permute(1, 2, 0).contiguous()
            # make each axis non-negative, then fold the 3 axes into one index
            offsets[..., 0] += wh - 1
            offsets[..., 1] += ww - 1
            offsets[..., 2] += wt - 1
            offsets[..., 0] *= span_w * span_t
            offsets[..., 1] *= span_t
            self.register_buffer("relative_position_index", offsets.sum(-1))

        self.qkv = nn.Linear(dim, dim*3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        x: (nW*B, N, C) with N = wH*wW*wT tokens per window.
        mask: optional (nW, N, N) additive mask (used for shifted windows).
        Returns (nW*B, N, C).
        """
        B_, N, C = x.shape
        head_dim = C // self.num_heads
        q, k, v = self.qkv(x).reshape(B_, N, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)

        attn = (q * self.scale) @ k.transpose(-2, -1)  # (B_, heads, N, N)

        if self.rpe:
            tokens = self.window_size[0] * self.window_size[1] * self.window_size[2]
            bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
            bias = bias.view(tokens, tokens, self.num_heads).permute(2, 0, 1)
            attn = attn + bias.unsqueeze(0)

        if mask is not None:
            n_masks = mask.shape[0]
            attn = attn.view(B_ // n_masks, n_masks, self.num_heads, N, N)
            attn = attn + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)

        attn = self.attn_drop(self.softmax(attn))
        out = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        return self.proj_drop(self.proj(out))

############################
# SwinTransformerBlock -> FAB
############################
class FAB(nn.Module):
    """
    Fusion Attention Block = W-MSA (or SW-MSA) + local convolution (CAB).

    forward computes (with n = norm1(x)):
        x = x + drop_path(window_attention(n)) + 0.01 * CAB(n)
        x = x + drop_path(mlp(norm2(x)))
    The caller must set ``self.H, self.W, self.T`` (token-grid size) before
    calling forward.

    Args:
        dim: token channels C.
        num_heads: heads for WindowAttention.
        window_size: (wH, wW, wT) attention window.
        shift_size: cyclic shift applied before windowing. All zeros gives
            W-MSA; otherwise SW-MSA using the mask passed to forward.
        mlp_ratio: MLP hidden width relative to dim.
        qkv_bias, qk_scale, rpe, attn_drop: forwarded to WindowAttention.
        drop: used as WindowAttention's proj_drop and as the MLP dropout.
        drop_path: stochastic-depth rate for the attention and MLP branches.
        act_layer: MLP activation. norm_layer: normalisation layer class.
    """
    def __init__(self, dim, num_heads, window_size=(7,7,7), shift_size=(0,0,0),
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, rpe=True,
                 drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(dim, window_size, num_heads, qkv_bias, qk_scale,
                                    rpe, attn_drop, drop)
        self.drop_path = DropPath(drop_path) if drop_path>0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden = int(dim*mlp_ratio)
        self.mlp = Mlp(dim, mlp_hidden, act_layer=act_layer, drop=drop)

        # local conv branch (CAB), added to the attention output with weight 0.01
        self.conv_block = CAB(num_feat=dim, compress_ratio=3, squeeze_factor=30)
        # spatial size of the token grid; assigned by the caller before forward()
        self.H = self.W = self.T = None

    def forward(self, x, mask_matrix):
        """
        x: (B, H*W*T, C)
        mask_matrix: (nW, N, N) additive attention mask; only used when
            shift_size is non-zero.
        Returns a tensor of shape (B, H*W*T, C).
        """
        H, W, T = self.H, self.W, self.T
        B, L, C = x.shape
        assert L == H*W*T

        shortcut = x
        x = self.norm1(x)
        x_3d = x.view(B, H, W, T, C)

        # local conv
        conv_x = self.conv_block(x_3d.permute(0,4,1,2,3))  # (B,C,H,W,T)
        conv_x = conv_x.permute(0,2,3,4,1).contiguous().view(B,H*W*T,C)

        # pad each spatial axis at its far end to a multiple of the window size
        # (pad_r -> H, pad_b -> W, pad_h -> T; F.pad lists the last axis first)
        pad_r = (self.window_size[0]-H%self.window_size[0])%self.window_size[0]
        pad_b = (self.window_size[1]-W%self.window_size[1])%self.window_size[1]
        pad_h = (self.window_size[2]-T%self.window_size[2])%self.window_size[2]
        x_3d = F.pad(x_3d, (0,0, 0,pad_h, 0,pad_b, 0,pad_r))
        Hp, Wp, Tp = x_3d.shape[1], x_3d.shape[2], x_3d.shape[3]

        # cyclic shift for SW-MSA; the mask keeps wrapped-around regions apart
        if any(self.shift_size):
            shifted_x = torch.roll(x_3d, shifts=(-self.shift_size[0],
                                                 -self.shift_size[1],
                                                 -self.shift_size[2]),
                                   dims=(1,2,3))
            attn_mask = mask_matrix
        else:
            shifted_x = x_3d
            attn_mask = None

        # partition into windows: (nW*B, wH*wW*wT, C)
        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size[0]*self.window_size[1]*self.window_size[2], C)

        # W-MSA or SW-MSA
        attn_windows = self.attn(x_windows, attn_mask)

        # reverse windows back to the padded volume (B, Hp, Wp, Tp, C)
        attn_windows = attn_windows.view(-1, *self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp, Tp)

        # undo the cyclic shift
        if any(self.shift_size):
            x_3d = torch.roll(shifted_x, shifts=(self.shift_size[0],
                                                 self.shift_size[1],
                                                 self.shift_size[2]),
                              dims=(1,2,3))
        else:
            x_3d = shifted_x

        # crop away the window padding
        if pad_r>0 or pad_b>0 or pad_h>0:
            x_3d = x_3d[:,:H,:W,:T,:].contiguous()

        # reshape back
        x_attn = x_3d.view(B, H*W*T, C)

        # residual: attention branch + conv branch scaled by 0.01, then MLP
        x = shortcut + self.drop_path(x_attn) + conv_x*0.01
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

############################
# Patch Merging
############################
class PatchMerging(nn.Module):
    """
    Downsample tokens by 2 along every spatial axis (3-D Swin patch merging).

    The 8 voxels of each 2x2x2 neighbourhood are concatenated (8*C
    features), normalised and linearly projected to
    ``(8 // reduce_factor) * C`` channels (e.g. 2*C for reduce_factor=4).

    An odd spatial size is zero-padded by one voxel at the far end of that
    axis. The output grid is therefore ceil(H/2) x ceil(W/2) x ceil(T/2),
    which matches the ``(H+1)//2`` sizes BasicLayer reports for the next stage.
    """
    def __init__(self, dim, norm_layer=nn.LayerNorm, reduce_factor=2):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(8*dim, (8//reduce_factor)*dim, bias=False)
        self.norm = norm_layer(8*dim)

    def forward(self, x, H, W, T):
        """
        x: (B, H*W*T, C).
        Returns (B, ceil(H/2)*ceil(W/2)*ceil(T/2), (8//reduce_factor)*C).
        """
        B, L, C = x.shape
        assert L == H*W*T, "input feature has wrong size"

        x_3d = x.view(B, H, W, T, C)
        if H % 2 or W % 2 or T % 2:
            # F.pad lists the last axis first: (C, T, W, H) pairs
            x_3d = F.pad(x_3d, (0, 0, 0, T % 2, 0, W % 2, 0, H % 2))

        # Fixed neighbour order; it determines the input layout of `reduction`,
        # so it must not change (checkpoint compatibility).
        offsets = ((0, 0, 0), (1, 0, 0), (0, 1, 0), (0, 0, 1),
                   (1, 1, 0), (0, 1, 1), (1, 0, 1), (1, 1, 1))
        parts = [x_3d[:, i::2, j::2, k::2, :] for i, j, k in offsets]
        x_cat = torch.cat(parts, dim=-1)  # (B, H/2, W/2, T/2, 8*C)
        x_cat = x_cat.view(B, -1, 8*C)
        x_cat = self.norm(x_cat)
        return self.reduction(x_cat)

############################
# BasicLayer: stack of FAB + OAB
############################
class BasicLayer(nn.Module):
    """
    One encoder stage: ``depth`` FAB blocks, then one OAB, then an optional
    downsampling module.

    FAB blocks alternate between regular windows (even index, shift (0,0,0))
    and shifted windows (odd index, shift = window_size // 2 per axis). The
    OAB is built with a cubic window of 4, overlap_ratio 0.5 and 4 heads,
    regardless of ``window_size`` / ``num_heads``.

    Args:
        dim: token channels of this stage.
        depth: number of FAB blocks.
        num_heads: attention heads for the FAB blocks.
        window_size: (wH, wW, wT) FAB attention window.
        mlp_ratio, qkv_bias, qk_scale, rpe, drop, attn_drop: forwarded to each
            FAB. mlp_ratio, qkv_bias and qk_scale also go to the OAB.
        drop_path: float, or a list with one stochastic-depth rate per FAB.
        norm_layer: normalisation layer class.
        downsample: class instantiated as
            ``downsample(dim, norm_layer, reduce_factor=pat_merg_rf)``, or None.
        use_checkpoint: run blocks through torch.utils.checkpoint.
        pat_merg_rf: reduce_factor passed to ``downsample``.
    """
    def __init__(self, dim, depth, num_heads, window_size=(7,7,7),
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, rpe=True,
                 drop=0., attn_drop=0., drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 use_checkpoint=False,
                 pat_merg_rf=2):
        super().__init__()
        self.window_size = window_size
        # SW-MSA shift: half a window along each axis
        self.shift_size = (window_size[0]//2, window_size[1]//2, window_size[2]//2)
        self.depth = depth
        self.use_checkpoint = use_checkpoint
        self.blocks = nn.ModuleList([
            FAB(dim, num_heads, window_size,
                (0,0,0) if (i%2==0) else self.shift_size,
                mlp_ratio, qkv_bias, qk_scale, rpe, drop, attn_drop,
                drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer)
            for i in range(depth)
        ])

        # NOTE(review): OAB window size and head count are hard-coded here, so
        # dim must be divisible by 4. rpe is not forwarded (OAB default is True).
        self.overlap_attn = OAB(dim, window_size=4, overlap_ratio=0.5,
                                num_heads=4, qkv_bias=qkv_bias, qk_scale=qk_scale,
                                mlp_ratio=mlp_ratio, norm_layer=norm_layer)

        if downsample is not None:
            self.downsample = downsample(dim, norm_layer, reduce_factor=pat_merg_rf)
        else:
            self.downsample = None

    def forward(self, x, H, W, T):
        """
        x: (B, H*W*T, C)

        Returns an 8-tuple ``(x, H, W, T, x_next, H_next, W_next, T_next)``:
        the stage output at the input resolution, followed by the input for
        the next stage. That is the downsample output with sizes (H+1)//2
        etc., or the same tensor and sizes when there is no downsample.
        """
        # Build the SW-MSA mask on the window-padded grid: label the 3x3x3
        # regions created by the cyclic shift, then forbid attention between
        # tokens with different labels.
        Hp = int(math.ceil(H/self.window_size[0]))*self.window_size[0]
        Wp = int(math.ceil(W/self.window_size[1]))*self.window_size[1]
        Tp = int(math.ceil(T/self.window_size[2]))*self.window_size[2]
        img_mask = torch.zeros((1, Hp, Wp, Tp, 1), device=x.device)

        h_slices = (slice(0,-self.window_size[0]),
                    slice(-self.window_size[0], -self.shift_size[0]),
                    slice(-self.shift_size[0], None))
        w_slices = (slice(0,-self.window_size[1]),
                    slice(-self.window_size[1], -self.shift_size[1]),
                    slice(-self.shift_size[1], None))
        t_slices = (slice(0,-self.window_size[2]),
                    slice(-self.window_size[2], -self.shift_size[2]),
                    slice(-self.shift_size[2], None))
        cnt=0
        for hh in h_slices:
            for ww in w_slices:
                for tt in t_slices:
                    img_mask[:, hh, ww, tt, :] = cnt
                    cnt+=1

        # per-window label vectors -> (nW, N, N) additive mask:
        # 0 for same-region pairs, -100 for cross-region pairs
        mask_windows = window_partition(img_mask, self.window_size)
        mask_windows = mask_windows.view(-1, self.window_size[0]*self.window_size[1]*self.window_size[2])
        attn_mask = mask_windows.unsqueeze(1)-mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask!=0, float(-100.0)).masked_fill(attn_mask==0, float(0.0))

        # pass blocks (each FAB reads the grid size from its H/W/T attributes)
        for blk in self.blocks:
            blk.H, blk.W, blk.T = H, W, T
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask)

        # OverlapAttn
        self.overlap_attn.H, self.overlap_attn.W, self.overlap_attn.T = H, W, T
        if self.use_checkpoint:
            x = checkpoint.checkpoint(self.overlap_attn, x, None)
        else:
            x = self.overlap_attn(x, None)

        # Patch Merging?
        if self.downsample is not None:
            x_down = self.downsample(x, H, W, T)
            Wh, Ww, Wt = (H+1)//2, (W+1)//2, (T+1)//2
            return x, H, W, T, x_down, Wh, Ww, Wt
        else:
            return x, H, W, T, x, H, W, T

############################
# PatchEmbed
############################
class PatchEmbed(nn.Module):
    """
    Split a volume into non-overlapping patches and embed each one linearly.

    Input (B, in_chans, H, W, T) is zero-padded at the far end of each
    spatial axis up to a multiple of ``patch_size``. It is then projected by
    a Conv3d whose kernel and stride equal the patch size, and optionally
    normalised over channels.
    Output: (B, embed_dim, ceil(H/p0), ceil(W/p1), ceil(T/p2)).
    """
    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        self.patch_size = to_3tuple(patch_size)
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.proj = nn.Conv3d(in_chans, embed_dim,
                              kernel_size=self.patch_size, stride=self.patch_size)
        self.norm = None if norm_layer is None else norm_layer(embed_dim)

    def forward(self, x):
        B, _, H, W, T = x.shape
        ph, pw, pt = self.patch_size
        # pad T, then W, then H (F.pad lists the last axis first)
        if T % pt:
            x = F.pad(x, (0, pt - T % pt))
        if W % pw:
            x = F.pad(x, (0, 0, 0, pw - W % pw))
        if H % ph:
            x = F.pad(x, (0, 0, 0, 0, 0, ph - H % ph))
        x = self.proj(x)
        if self.norm is None:
            return x
        grid = x.shape[2:]
        tokens = self.norm(x.flatten(2).transpose(1, 2))  # (B, n, embed_dim)
        return tokens.transpose(1, 2).view(B, self.embed_dim, *grid)

############################
# The main SwinTransformer
############################
class SwinTransformer(nn.Module):
    """
    Hierarchical 3-D Swin-style encoder (UTSRMorph backbone).

    Pipeline: PatchEmbed -> optional positional embedding -> dropout ->
    ``len(depths)`` BasicLayer stages. Every stage except the last ends with
    PatchMerging. ``forward`` returns a list with one tensor per stage index
    in ``out_indices``: the normalised stage output as (B, C_i, H_i, W_i, T_i),
    where C_i = embed_dim * 2**i.

    Positional embedding (``ape`` takes precedence over ``spe``):
      ape: learned absolute embedding on the ``pretrain_img_size // patch_size``
           grid, trilinearly resized to the actual token grid in forward.
      spe: fixed sinusoidal embedding (SinPositionalEncoding3D).

    frozen_stages follows the Swin convention:
      -1: freeze nothing.
      >=0: freeze the patch embedding.
      >=1: also freeze the absolute position embedding (if ape).
      >=2: also freeze stages 0 .. frozen_stages-2.
    """
    def __init__(self, pretrain_img_size=224, patch_size=4, in_chans=3,
                 embed_dim=96, depths=(2,2,6,2), num_heads=(3,6,12,24),
                 window_size=(7,7,7), mlp_ratio=4.,
                 qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0.2, norm_layer=nn.LayerNorm, ape=False, spe=False,
                 rpe=True, patch_norm=True, out_indices=(0,1,2,3),
                 frozen_stages=-1, use_checkpoint=False, pat_merg_rf=2):
        super().__init__()
        self.pretrain_img_size = pretrain_img_size
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.spe = spe
        self.rpe = rpe
        self.patch_norm = patch_norm
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages

        # patch embed: (B, in_chans, H, W, T) -> (B, embed_dim, H/p, W/p, T/p)
        self.patch_embed = PatchEmbed(patch_size, in_chans, embed_dim,
                                      norm_layer if patch_norm else None)
        self.pos_drop = nn.Dropout(p=drop_rate)

        if self.ape:
            # learned absolute embedding defined on the pre-training token grid
            img_size = to_3tuple(pretrain_img_size)
            patch = to_3tuple(patch_size)
            grid = [img_size[d] // patch[d] for d in range(3)]
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *grid))
            trunc_normal_(self.absolute_pos_embed, std=.02)
        elif self.spe:
            # Fixed sinusoidal embedding. It holds no parameters or buffers, so
            # an explicit .cuda() would be a no-op; it follows the input device.
            self.pos_embd = SinPositionalEncoding3D(embed_dim)

        # linearly increasing stochastic-depth rates across all FAB blocks
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # build layers; channels double at every stage
        self.layers = nn.ModuleList()
        for i in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim*(2**i)),
                               depth=depths[i],
                               num_heads=num_heads[i],
                               window_size=window_size,
                               mlp_ratio=mlp_ratio,
                               qkv_bias=qkv_bias,
                               qk_scale=qk_scale,
                               rpe=rpe,
                               drop=drop_rate,
                               attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i]):sum(depths[:i+1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging if i<self.num_layers-1 else None,
                               use_checkpoint=use_checkpoint,
                               pat_merg_rf=pat_merg_rf)
            self.layers.append(layer)

        # final out dims
        num_features = [int(embed_dim*(2**i)) for i in range(self.num_layers)]
        self.num_features = num_features

        # one output norm per returned stage, registered as norm{i}
        for i_layer in out_indices:
            layer_ = norm_layer(num_features[i_layer])
            self.add_module(f"norm{i_layer}", layer_)

        self._freeze_stages()

    def _freeze_stages(self):
        """Put frozen parts in eval mode and disable their gradients."""
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for p in self.patch_embed.parameters():
                p.requires_grad = False

        if self.frozen_stages >= 1 and self.ape:
            self.absolute_pos_embed.requires_grad = False

        if self.frozen_stages >= 2:
            self.pos_drop.eval()
            for stage in self.layers[:self.frozen_stages - 1]:
                stage.eval()
                for p in stage.parameters():
                    p.requires_grad = False

    def forward(self, x):
        """
        x: (B, in_chans, H, W, T). Returns a list of (B, C_i, H_i, W_i, T_i)
        feature maps, one per stage index in ``out_indices``.
        """
        x = self.patch_embed(x)
        B, C, Hp, Wp, Tp = x.shape

        if self.ape:
            # resize the learned grid to the actual token grid
            pos = F.interpolate(self.absolute_pos_embed, size=(Hp, Wp, Tp),
                                mode='trilinear', align_corners=False)
            x = x + pos
        elif self.spe:
            x = x + self.pos_embd(x)
        x = x.flatten(2).transpose(1, 2)  # (B, Hp*Wp*Tp, C)
        x = self.pos_drop(x)

        outs = []
        for i, layer in enumerate(self.layers):
            # x_out: stage output at (H, W, T); x: input for the next stage
            x_out, H, W, T, x, Hp, Wp, Tp = layer(x, Hp, Wp, Tp)
            if i in self.out_indices:
                x_out = getattr(self, f"norm{i}")(x_out)
                out_f = x_out.view(B, H, W, T, self.num_features[i]).permute(0,4,1,2,3)
                outs.append(out_f)
        return outs

    def train(self, mode=True):
        """Set train/eval mode, re-apply stage freezing, and return self (nn.Module contract)."""
        super().train(mode)
        self._freeze_stages()
        return self

############################
# Sinusoidal (optionally)
############################
class SinPositionalEncoding3D(nn.Module):
    """Fixed 3D sinusoidal positional encoding.

    The channel budget is split into three equal groups (x, y, z). Each group holds
    sin/cos of its axis index at geometric frequencies. The result is truncated to
    the input's channel count and broadcast over the batch.
    """
    def __init__(self, channels):
        super().__init__()
        channels = int(np.ceil(channels / 6) * 2)
        if channels % 2:  # never true (ceil(.)*2 is even); kept as a guard
            channels += 1
        self.channels = channels
        inv_freq = 1. / (10000 ** (torch.arange(0, channels, 2).float() / channels))
        # Non-persistent buffer: it follows .to()/.cuda() with the module but adds no
        # state_dict key, so existing checkpoints still load.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, tensor):
        """Return the encoding for ``tensor`` (B, C, X, Y, Z), with the same shape and dtype."""
        B, C, X, Y, Z = tensor.shape
        device = tensor.device
        # Compute in float32 on the input's device. The old `.type(self.inv_freq.type())`
        # silently moved the positions to the CPU.
        inv_freq = self.inv_freq.to(device=device, dtype=torch.float32)

        def axis_embedding(n):
            # (n, channels): [sin(p*f)..., cos(p*f)...] for positions p = 0..n-1
            pos = torch.arange(n, device=device, dtype=torch.float32)
            sin_inp = torch.einsum("i,j->ij", pos, inv_freq)
            return torch.cat([sin_inp.sin(), sin_inp.cos()], dim=-1)

        emb = torch.zeros((X, Y, Z, self.channels * 3), device=device, dtype=tensor.dtype)
        emb[:, :, :, :self.channels] = axis_embedding(X)[:, None, None, :]
        emb[:, :, :, self.channels:2 * self.channels] = axis_embedding(Y)[:, None, :]
        emb[:, :, :, 2 * self.channels:] = axis_embedding(Z)
        emb = emb[None, :, :, :, :C].repeat(B, 1, 1, 1, 1)
        return emb.permute(0, 4, 1, 2, 3)

############################
# PixelShuffle3D
############################
class PixelShuffle3d(nn.Module):
    """3D analogue of ``nn.PixelShuffle``.

    Rearranges (B, C * r**3, D, H, W) into (B, C, D * r, H * r, W * r), where
    ``r = scale``. Channel ``c * r**3 + i * r**2 + j * r + k`` becomes sub-voxel
    (i, j, k) of output channel ``c``.
    """
    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, inp):
        r = self.scale
        batch, channels, depth, height, width = inp.size()
        out_channels = channels // (r ** 3)
        # Split the channel axis into (C, r_d, r_h, r_w) ...
        blocks = inp.view(batch, out_channels, r, r, r, depth, height, width)
        # ... then interleave each r-factor right after its spatial axis.
        interleaved = blocks.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
        return interleaved.view(batch, out_channels, depth * r, height * r, width * r)

############################
# ConvergeHead (used for up)
############################
class ConvergeHead(nn.Module):
    """Learned upsampler.

    A grouped conv (``groups=in_dim``) expands every channel to ``up_ratio**3``
    channels. A 3D pixel shuffle then folds them back into ``in_dim`` channels at
    ``up_ratio`` times the spatial resolution.

    Args:
        in_dim: input (and output) channels.
        up_ratio: spatial upscale factor.
        kernel_size, padding: conv kernel size and padding (stride 1).
    """
    def __init__(self, in_dim, up_ratio, kernel_size, padding):
        super().__init__()
        self.in_dim = in_dim
        self.up_ratio = up_ratio
        # Positional args: stride=1, padding, dilation=1, groups=in_dim.
        self.conv = nn.Conv3d(in_dim, (up_ratio ** 3) * in_dim, kernel_size, 1, padding, 1, in_dim)
        # Built once instead of on every forward call. It has no parameters, so the
        # state_dict is unchanged.
        self.shuffle = PixelShuffle3d(up_ratio)
        self.apply(self._init_weights)

    def forward(self, x):
        return self.shuffle(self.conv(x))

    def _init_weights(self, m):
        # Near-zero weights (std 1e-3) so the upsampled output starts close to zero.
        if isinstance(m, (nn.Conv3d, nn.Linear)):
            nn.init.normal_(m.weight, std=0.001)
            if m.bias is not None:  # bias-free layers have bias=None
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm3d) and m.affine:
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

############################
# SR block for upsampling
############################
class Conv3dReLU(nn.Sequential):
    """Bias-free Conv3d, then normalization, then LeakyReLU.

    ``use_batchnorm=True`` selects BatchNorm3d. Otherwise InstanceNorm3d is used
    with its default (non-affine) settings. Layer indices: 0 conv, 1 norm, 2 activation.
    """
    def __init__(self, in_channels, out_channels, kernel_size, padding=0, stride=1, use_batchnorm=True):
        conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        norm_cls = nn.BatchNorm3d if use_batchnorm else nn.InstanceNorm3d
        super().__init__(conv, norm_cls(out_channels), nn.LeakyReLU(inplace=True))

class SR(nn.Module):
    """Decoder stage.

    Learned 2x upsampling (ConvergeHead + 3D pixel shuffle), optional channel-wise
    concatenation of a skip tensor, then two 3x3x3 Conv3dReLU layers.

    Args:
        in_channels: channels of the input (kept by the upsampler).
        out_channels: channels produced by the two conv layers.
        skip_channels: channels of the skip tensor concatenated after upsampling (0 = none).
        use_batchnorm: BatchNorm3d if True, InstanceNorm3d otherwise.
    """
    def __init__(self, in_channels, out_channels, skip_channels=0, use_batchnorm=True):
        super().__init__()
        self.up = ConvergeHead(in_channels, 2, 3, 1)
        merged_channels = in_channels + skip_channels
        self.conv1 = Conv3dReLU(merged_channels, out_channels, 3, 1, use_batchnorm=use_batchnorm)
        self.conv2 = Conv3dReLU(out_channels, out_channels, 3, 1, use_batchnorm=use_batchnorm)

    def forward(self, x, skip=None):
        upsampled = self.up(x)
        merged = upsampled if skip is None else torch.cat([upsampled, skip], dim=1)
        return self.conv2(self.conv1(merged))

############################
# Registration Head
############################
class RegistrationHead(nn.Sequential):
    """Final conv that predicts the displacement field.

    Weights are drawn from N(0, 1e-5) and the bias is zero, so the initial output
    is almost zero. If ``upsampling > 1``, a trilinear ``nn.Upsample`` with that
    factor is appended. With the default ``upsampling=1`` the head is the single
    conv at index 0, exactly as before.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
        conv3d = nn.Conv3d(in_channels, out_channels, kernel_size, padding=kernel_size // 2)
        conv3d.weight = nn.Parameter(Normal(0, 1e-5).sample(conv3d.weight.shape))
        conv3d.bias = nn.Parameter(torch.zeros(conv3d.bias.shape))
        layers = [conv3d]
        if upsampling > 1:
            # `upsampling` used to be accepted but silently ignored.
            layers.append(nn.Upsample(scale_factor=upsampling, mode='trilinear', align_corners=False))
        super().__init__(*layers)

############################
# Spatial Transformer
############################
class SpatialTransformer(nn.Module):
    """N-D spatial transformer: warp an image with a dense displacement field (via grid_sample).

    Args:
        size: spatial shape of the fields, (D, H, W) or (H, W).
        mode: interpolation mode passed to ``F.grid_sample``.
    """
    def __init__(self, size, mode='bilinear'):
        super().__init__()
        self.mode = mode
        vectors = [torch.arange(0, s) for s in size]
        grids = torch.meshgrid(vectors, indexing='ij')
        grid = torch.stack(grids).unsqueeze(0).float()  # (1, ndim, *size): identity voxel coordinates
        self.register_buffer('grid', grid)

    def forward(self, src, flow):
        """
        src:  (B, C, *size) image to warp.
        flow: (B, ndim, *size) displacement in voxels; channel i moves along spatial axis i.
        Returns the warped image, (B, C, *size).
        """
        new_locs = self.grid + flow
        shape = flow.shape[2:]

        # Map voxel index 0 -> -1 and index s-1 -> +1. This is the align_corners=True
        # convention, so grid_sample must use it too. With align_corners=False a zero
        # flow would resample the image instead of returning it unchanged.
        for i in range(len(shape)):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)

        # grid_sample wants coordinates last and in reversed axis order (x first).
        if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)[..., [1, 0]]
        elif len(shape) == 3:
            new_locs = new_locs.permute(0, 2, 3, 4, 1)[..., [2, 1, 0]]
        return F.grid_sample(src, new_locs, align_corners=True, mode=self.mode)

############################
# UTSRMorph Model
############################
class UTSRMorph(nn.Module):
    """UTSRMorph deformable registration network.

    Pipeline:
    * a Swin-style transformer encoder over the stacked input channels;
    * a decoder of SR blocks (pixel-shuffle upsampling) with optional transformer
      skips and an optional convolutional skip from the pooled input;
    * a registration head that predicts a 3-channel flow at half resolution;
    * a learned 2x flow upsampling;
    * a spatial transformer that warps the moving image.

    Args:
        config: ConfigDict as produced by ``get_UTSRMorph_config``.
    """
    def __init__(self, config):
        super().__init__()
        self.if_convskip = config.if_convskip
        self.if_transskip = config.if_transskip
        embed_dim = config.embed_dim

        # Submodules are created in the original order, so random initialisation and
        # state_dict keys are unchanged.

        # 1) Swin Transformer encoder
        self.transformer = SwinTransformer(
            patch_size=config.patch_size,
            in_chans=config.in_chans,
            embed_dim=config.embed_dim,
            depths=config.depths,
            num_heads=config.num_heads,
            window_size=config.window_size,
            mlp_ratio=config.mlp_ratio,
            qkv_bias=config.qkv_bias,
            drop_rate=config.drop_rate,
            drop_path_rate=config.drop_path_rate,
            ape=config.ape,
            spe=config.spe,
            rpe=config.rpe,
            patch_norm=config.patch_norm,
            use_checkpoint=config.use_checkpoint,
            out_indices=config.out_indices,
            pat_merg_rf=config.pat_merg_rf
        )

        # 2) Decoder. Each SR block doubles the resolution. The channel counts assume
        #    4 encoder stages with embed_dim * 2**i channels at stage i.
        self.up0 = SR(embed_dim*8, embed_dim*4,
                      skip_channels=(embed_dim*4 if self.if_transskip else 0),
                      use_batchnorm=False)
        self.up1 = SR(embed_dim*4, embed_dim*2,
                      skip_channels=(embed_dim*2 if self.if_transskip else 0),
                      use_batchnorm=False)
        self.up2 = SR(embed_dim*2, embed_dim,
                      skip_channels=(embed_dim if self.if_transskip else 0),
                      use_batchnorm=False)
        self.up3 = SR(embed_dim,
                      config.reg_head_chan,
                      skip_channels=(embed_dim//2 if self.if_convskip else 0),
                      use_batchnorm=False)

        # Conv skip on the avg-pooled input. Its input width follows config.in_chans
        # (previously hard-coded to 2).
        self.c1 = Conv3dReLU(config.in_chans, embed_dim//2, 3, 1, use_batchnorm=False)

        # 3) Head
        self.reg_head = RegistrationHead(config.reg_head_chan, 3, 3)
        self.spatial_trans = SpatialTransformer(config.img_size)
        self.avg_pool = nn.AvgPool3d(3, stride=2, padding=1)
        self.up = ConvergeHead(3, 2, 3, 1)  # learned 2x upsampling of the flow

    def forward(self, x):
        """
        x: (B, in_chans, D, H, W). Channel 0 is the moving image and is the one warped.
           The other channel(s), presumably the fixed image, only feed the network.
        Returns (warped moving image (B, 1, D, H, W), flow (B, 3, D, H, W)).
        """
        source = x[:, :1, ...]

        # Conv skip: stride-2 average pooling then conv, at half resolution.
        f4 = self.c1(self.avg_pool(x)) if self.if_convskip else None

        # Encoder features in stage order: [0] is the highest resolution, [-1] the deepest.
        # NOTE(review): the indexing below assumes out_indices == (0, 1, 2, 3).
        out_feats = self.transformer(x)

        if self.if_transskip:
            f1, f2, f3 = out_feats[-2], out_feats[-3], out_feats[-4]
        else:
            f1 = f2 = f3 = None

        # Decode from the deepest features upwards.
        x = self.up0(out_feats[-1], f1)
        x = self.up1(x, f2)
        x = self.up2(x, f3)
        x = self.up3(x, f4)

        flow = self.reg_head(x)   # 3-channel flow at half resolution
        flow = self.up(flow)      # learned 2x upsampling to full resolution
        out = self.spatial_trans(source, flow)
        return out, flow

################################################################################
# 5) QUICK TEST
################################################################################

if __name__ == "__main__":
    # Smoke test: build the model from the example config and run one forward pass.
    # Falls back to CPU when no GPU is available; at this volume size that is slow and
    # memory-hungry, but it no longer crashes on CPU-only machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config = get_UTSRMorph_config()
    model = UTSRMorph(config).to(device)
    model.eval()

    # Dummy input (B=1, 2 channels [moving, fixed], D, H, W). It must match
    # config.img_size, which sizes the spatial transformer's grid.
    dummy = torch.randn(1, 2, *config.img_size).to(device)

    with torch.no_grad():
        out, flow = model(dummy)

    print("Output shape:", out.shape)  # expect (1, 1, *config.img_size)
    print("Flow shape:  ", flow.shape) # expect (1, 3, *config.img_size)


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Output shape: torch.Size([1, 1, 160, 192, 224])
Flow shape:   torch.Size([1, 3, 160, 192, 224])


In [None]:
################################################################################
# 0) INSTALL DEPENDENCIES
################################################################################

!pip install einops timm ml_collections surface-distance

################################################################################
# 1) IMPORTS
################################################################################

import os, sys, math, glob, random, pickle, datetime
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import matplotlib.pyplot as plt
from natsort import natsorted
from typing import Dict, List, Optional, Sequence, Tuple, Union
from math import exp
import scipy.ndimage
import torch.optim as optim

# third-party
import ml_collections
from einops import rearrange
from timm.models.layers import DropPath, trunc_normal_, to_3tuple
from surface_distance import compute_surface_distances, compute_robust_hausdorff, compute_dice_coefficient


################################################################################
# 2) UTILS: pkload, register_model, AverageMeter, dice_val, etc.
#    (Merged from 'utils.py')
################################################################################

def pkload(fname):
    """Load and return the single pickled object stored in the file ``fname``.

    Security: unpickling can execute arbitrary code. Only call this on trusted files.
    """
    with open(fname, 'rb') as handle:
        obj = pickle.load(handle)
    return obj

class AverageMeter(object):
    """Running tracker for a scalar metric.

    Attributes:
        val: most recent value.
        sum, count: weighted running sum and total weight.
        avg: ``sum / count``.
        vals: one entry per ``update`` call (unweighted).
        std: numpy std of ``vals``.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = self.avg = self.sum = self.count = 0
        self.vals = []
        self.std = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n``. ``n`` affects only sum/count/avg."""
        self.val = val
        self.sum, self.count = self.sum + val * n, self.count + n
        self.avg = self.sum / self.count
        self.vals.append(val)
        self.std = np.std(self.vals)

def pad_image(img, target_size):
    """Zero-pad a 5D tensor (B, C, X, Y, Z) at the high end of each spatial axis.

    Each axis is grown to at least ``target_size[i]``. Axes that are already
    large enough are left unchanged; nothing is cropped.
    """
    extra = [max(want - have, 0) for want, have in zip(target_size, img.shape[2:5])]
    # F.pad takes (low, high) pairs starting from the LAST dimension.
    pad_spec = (0, extra[2], 0, extra[1], 0, extra[0])
    return F.pad(img, pad_spec, "constant", 0)

class SpatialTransformer(nn.Module):
    """
    N-D Spatial Transformer, originally from Voxelmorph.

    Warps ``src`` with a dense displacement field ``flow`` (in voxels) using
    ``F.grid_sample`` with ``align_corners=True``.

    Args:
        size: spatial shape of the fields, (H, W) or (D, H, W).
        mode: interpolation mode for grid_sample ('bilinear', 'nearest', ...).
    """
    def __init__(self, size, mode='bilinear'):
        super().__init__()
        self.mode = mode
        vectors = [torch.arange(0, s) for s in size]
        # Explicit 'ij' (the legacy default) avoids the torch.meshgrid deprecation warning.
        grids = torch.meshgrid(vectors, indexing='ij')
        grid = torch.stack(grids)
        grid = torch.unsqueeze(grid, 0).type(torch.FloatTensor)
        # The buffer follows .to()/.cuda() on the module. It is no longer forced onto
        # CUDA at construction, which crashed on machines without a GPU.
        self.register_buffer('grid', grid)

    def forward(self, src, flow):
        # Use the identity grid on the flow's device (a no-op if the module was moved there).
        new_locs = self.grid.to(flow.device) + flow
        shape = flow.shape[2:]
        # Normalize voxel coordinates to [-1, 1] (align_corners=True convention).
        for i in range(len(shape)):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)
        # grid_sample expects coordinates last and in reversed axis order.
        if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)
            new_locs = new_locs[..., [1, 0]]
        elif len(shape) == 3:
            new_locs = new_locs.permute(0, 2, 3, 4, 1)
            new_locs = new_locs[..., [2, 1, 0]]
        return F.grid_sample(src, new_locs, align_corners=True, mode=self.mode)

class register_model(nn.Module):
    """Warp an image with a displacement field via SpatialTransformer.

    ``forward`` takes a pair ``(img, flow)``. Both are moved to the current CUDA
    device before warping, so this wrapper requires a GPU.
    """
    def __init__(self, img_size=(64, 256, 256), mode='bilinear'):
        super().__init__()
        self.spatial_trans = SpatialTransformer(img_size, mode)

    def forward(self, x):
        img, flow = x[0].cuda(), x[1].cuda()
        return self.spatial_trans(img, flow)

def dice_val_VOI(y_pred, y_true):
    """Mean Dice over label values 1..35 (label 0 is not scored).

    Only sample 0, channel 0 of each tensor is used. Each label's Dice is
    2|P∩T| / (|P| + |T| + 1e-5), so a label absent from both volumes contributes
    0 (not 1) to the mean.
    """
    pred = y_pred.detach().cpu().numpy()[0, 0, ...]
    true = y_true.detach().cpu().numpy()[0, 0, ...]

    def dice_for(label):
        in_pred = pred == label
        in_true = true == label
        overlap = np.sum(in_pred * in_true)
        total = np.sum(in_pred) + np.sum(in_true)
        return 2.0 * overlap / (total + 1e-5)

    return np.mean(np.array([dice_for(label) for label in range(1, 36)]))

def dice_val(y_pred, y_true, num_clus):
    """Mean Dice over all ``num_clus`` classes. Background (class 0) is included.

    Args:
        y_pred, y_true: int64 label maps of shape (B, 1, D, H, W) (one_hot needs int64).
        num_clus: number of classes; labels must lie in [0, num_clus).
    Returns:
        0-dim tensor: per-class Dice averaged over classes, then over the batch.
        A class absent from both inputs scores 0.
    """
    def to_onehot(labels):
        # (B, 1, D, H, W) -> (B, num_clus, D, H, W)
        encoded = nn.functional.one_hot(labels, num_classes=num_clus)
        return torch.squeeze(encoded, 1).permute(0, 4, 1, 2, 3).contiguous()

    pred = to_onehot(y_pred)
    true = to_onehot(y_true)
    spatial = [2, 3, 4]
    overlap = (pred * true).sum(dim=spatial)
    total = pred.sum(dim=spatial) + true.sum(dim=spatial)
    dsc = (2. * overlap) / (total + 1e-5)
    return dsc.mean(dim=1).mean()

################################################################################
# 3) LOSSES (Merged from 'losses.py')
################################################################################

class Grad3d(torch.nn.Module):
    """
    N-D gradient loss, from VoxelMorph

    Smoothness penalty for a 5D field (B, C, D, H, W): the mean absolute ('l1') or
    squared ('l2') forward difference along each spatial axis, averaged over the three
    axes and optionally multiplied by ``loss_mult``. ``y_true`` is accepted for
    loss-API compatibility and ignored.
    """
    def __init__(self, penalty='l1', loss_mult=None):
        super().__init__()
        self.penalty = penalty
        self.loss_mult = loss_mult

    def forward(self, y_pred, y_true):
        d_ax2 = torch.abs(y_pred[:, :, 1:, :, :] - y_pred[:, :, :-1, :, :])
        d_ax3 = torch.abs(y_pred[:, :, :, 1:, :] - y_pred[:, :, :, :-1, :])
        d_ax4 = torch.abs(y_pred[:, :, :, :, 1:] - y_pred[:, :, :, :, :-1])
        if self.penalty == 'l2':
            d_ax2, d_ax3, d_ax4 = d_ax2 * d_ax2, d_ax3 * d_ax3, d_ax4 * d_ax4
        # summation order: axis 3, axis 2, axis 4
        grad = (torch.mean(d_ax3) + torch.mean(d_ax2) + torch.mean(d_ax4)) / 3.0
        if self.loss_mult is not None:
            grad = grad * self.loss_mult
        return grad

class DiceLoss(nn.Module):
    """Soft Dice loss.

    Computes 1 - mean over batch and classes of
    2*sum(p*g) / (sum(p^2) + sum(g^2) + 1e-5), where g is the one-hot target.

    ``y_pred``: per-class scores (B, num_class, D, H, W), presumably probabilities.
    ``y_true``: int64 label map (B, 1, D, H, W).
    """
    def __init__(self, num_class=36):
        super().__init__()
        self.num_class = num_class

    def forward(self, y_pred, y_true):
        target = nn.functional.one_hot(y_true, num_classes=self.num_class)
        target = torch.squeeze(target, 1).permute(0, 4, 1, 2, 3).contiguous()
        dims = [2, 3, 4]
        numerator = 2. * (y_pred * target).sum(dim=dims)
        denominator = y_pred.pow(2).sum(dim=dims) + target.pow(2).sum(dim=dims) + 1e-5
        return 1 - torch.mean(numerator / denominator)

class NCC_vxm(torch.nn.Module):
    """
    Local normalized cross correlation (VoxelMorph)

    Returns ``-mean(cc)``, where cc is the squared correlation of y_true and y_pred
    inside a sliding window (zero-padded "same" convolution). Perfectly correlated
    inputs therefore give a loss close to -1.

    Args:
        win: window size per spatial dim (list/tuple), or a single int for a cubic
            window. Defaults to 9 along every spatial dim. Odd sizes keep the output
            the same size as the input.
    """
    def __init__(self, win=None):
        super(NCC_vxm, self).__init__()
        self.win = win

    def forward(self, y_true, y_pred):
        Ii = y_true
        Ji = y_pred
        ndims = len(list(Ii.size())) - 2
        assert ndims in [1, 2, 3]
        if self.win is None:
            win = [9] * ndims
        elif isinstance(self.win, int):
            win = [self.win] * ndims
        else:
            win = list(self.win)

        # Summation kernel on the input's device and dtype (was hard-coded to "cuda").
        sum_filt = torch.ones([1, 1, *win], device=Ii.device, dtype=Ii.dtype)
        stride = (1,) * ndims
        # Per-dimension padding; the old code used win[0] for every dimension.
        padding = tuple(w // 2 for w in win)
        conv_fn = getattr(F, 'conv%dd' % ndims)

        I2 = Ii * Ii
        J2 = Ji * Ji
        IJ = Ii * Ji
        I_sum = conv_fn(Ii, sum_filt, stride=stride, padding=padding)
        J_sum = conv_fn(Ji, sum_filt, stride=stride, padding=padding)
        I2_sum = conv_fn(I2, sum_filt, stride=stride, padding=padding)
        J2_sum = conv_fn(J2, sum_filt, stride=stride, padding=padding)
        IJ_sum = conv_fn(IJ, sum_filt, stride=stride, padding=padding)

        win_size = np.prod(win)
        u_I = I_sum / win_size
        u_J = J_sum / win_size
        # Windowed covariance and variances (each scaled by win_size).
        cross = IJ_sum - u_J * I_sum - u_I * J_sum + u_I * u_J * win_size
        I_var = I2_sum - 2 * u_I * I_sum + u_I * u_I * win_size
        J_var = J2_sum - 2 * u_J * J_sum + u_J * u_J * win_size
        cc = cross * cross / (I_var * J_var + 1e-5)
        return -torch.mean(cc)

################################################################################
# 4) SWIN TRANSFORMER + UTSRMorph MODEL (Single-Notebook Version)
################################################################################

#
# The model definitions below repeat the single-notebook code from the previous cell
# so that this cell is self-contained; the main classes are "SwinTransformer" and "UTSRMorph".
#

# ---- config example (like get_UTSRMorph_config)
def get_UTSRMorph_config():
    """Return the example UTSRMorph configuration as an ``ml_collections.ConfigDict``.

    Encoder: 4 stages of depth 2 with 4 heads each, window (5, 6, 7), patch size 4,
    embed dim 96, 2 input channels. Expected input volume size: (160, 192, 224).
    """
    settings = {
        # skip connections in the decoder
        'if_transskip': True,
        'if_convskip': True,
        # patch embedding
        'patch_size': 4,
        'in_chans': 2,
        'embed_dim': 96,
        # transformer stages
        'depths': (2, 2, 2, 2),
        'num_heads': (4, 4, 4, 4),
        'window_size': (5, 6, 7),
        'mlp_ratio': 4,
        'pat_merg_rf': 4,
        'qkv_bias': False,
        # regularisation
        'drop_rate': 0,
        'drop_path_rate': 0.3,
        # position encodings
        'ape': False,
        'spe': False,
        'rpe': True,
        'patch_norm': True,
        'use_checkpoint': False,
        'out_indices': (0, 1, 2, 3),
        # head / data
        'reg_head_chan': 16,
        'img_size': (160, 192, 224),
    }
    return ml_collections.ConfigDict(settings)

# --- basic building blocks for the model (Swin parts) ---
class CA(nn.Module):
    """Channel attention (squeeze-and-excitation) for 5D tensors (B, C, D, H, W).

    Global average pool, then a 1x1x1 conv reducing by ``squeeze_factor``, ReLU,
    a 1x1x1 conv expanding back, and a sigmoid. The input is rescaled channel-wise
    by the resulting weights.
    """
    def __init__(self, num_feat, squeeze_factor=16):
        super().__init__()
        hidden = num_feat // squeeze_factor
        # The Sequential indices (attention.1, attention.3) are state_dict keys; keep this order.
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool3d(1),
            nn.Conv3d(num_feat, hidden, 1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv3d(hidden, num_feat, 1, padding=0),
            nn.Sigmoid(),
        )

    def forward(self, x):
        channel_weights = self.attention(x)
        return x * channel_weights

class CAB(nn.Module):
    """Channel attention block.

    3x3x3 conv reducing by ``compress_ratio``, GELU, 3x3x3 conv restoring
    ``num_feat`` channels, then CA with ``squeeze_factor``.
    """
    def __init__(self, num_feat, compress_ratio=3, squeeze_factor=30):
        super().__init__()
        mid = num_feat // compress_ratio
        # The Sequential indices (cab.0, cab.2, cab.3) are state_dict keys; keep this order.
        self.cab = nn.Sequential(
            nn.Conv3d(num_feat, mid, 3, 1, 1),
            nn.GELU(),
            nn.Conv3d(mid, num_feat, 3, 1, 1),
            CA(num_feat, squeeze_factor),
        )

    def forward(self, x):
        out = self.cab(x)
        return out

class Mlp(nn.Module):
    """Two-layer feed-forward block.

    Linear, activation, dropout, Linear, dropout. ``hidden_features`` and
    ``out_features`` default to ``in_features``.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

def window_partition(x, window_size):
    """Split (B, H, W, L, C) into non-overlapping 3D windows.

    Returns (num_windows * B, wh, ww, wl, C). Windows are ordered by batch, then by
    H-block, W-block and L-block index. H, W and L must be divisible by the window size.
    """
    B, H, W, L, C = x.shape
    wh, ww, wl = window_size[0], window_size[1], window_size[2]
    blocks = x.view(B, H // wh, wh, W // ww, ww, L // wl, wl, C)
    # Move the three block indices ahead of the three in-window offsets.
    blocks = blocks.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous()
    return blocks.view(-1, wh, ww, wl, C)

def window_reverse(windows, window_size, H, W, L):
    """Inverse of window_partition: (num_windows * B, wh, ww, wl, C) -> (B, H, W, L, C)."""
    wh, ww, wl = window_size[0], window_size[1], window_size[2]
    windows_per_image = H * W * L / wh / ww / wl
    B = int(windows.shape[0] / windows_per_image)
    grid = windows.view(B, H // wh, W // ww, L // wl, wh, ww, wl, -1)
    # Interleave each block index with its in-window offset: (B, Hb, wh, Wb, ww, Lb, wl, C).
    grid = grid.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous()
    return grid.view(B, H, W, L, -1)

class WindowAttention(nn.Module):
    """Multi-head self-attention inside 3D windows, with optional relative position bias.

    Args:
        dim: channels C of each token.
        window_size: (Wh, Ww, Wt). Attention runs over the Wh*Ww*Wt tokens of a window.
        num_heads: number of heads; each head uses dim // num_heads channels.
        qkv_bias: learnable bias in the joint q/k/v projection.
        qk_scale: query scale; falls back to head_dim ** -0.5 when None (or 0).
        rpe: add a learned bias indexed by the relative offset of each token pair.
        attn_drop: dropout on the attention weights.
        proj_drop: dropout after the output projection.
    """
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None,
                 rpe=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.rpe = rpe
        # One learnable bias per (relative offset, head), with (2W-1) offsets per axis.
        # It is allocated and initialised even when rpe=False, in which case it is unused.
        table_size = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1)
        self.relative_position_bias_table = nn.Parameter(torch.zeros(table_size, num_heads))
        trunc_normal_(self.relative_position_bias_table, std=.02)
        coords_h = torch.arange(window_size[0])
        coords_w = torch.arange(window_size[1])
        coords_t = torch.arange(window_size[2])
        # indexing='ij' equals the legacy default. Passing it explicitly silences the
        # torch.meshgrid deprecation warning without changing the result.
        coords = torch.stack(torch.meshgrid([coords_h, coords_w, coords_t], indexing='ij'))
        coords_flat = coords.flatten(1)  # (3, N), N = Wh*Ww*Wt
        if rpe:
            # Pairwise offsets (N, N, 3), shifted to be non-negative and then flattened
            # (mixed radix) into one row index of the bias table.
            rel = coords_flat[:, :, None] - coords_flat[:, None, :]
            rel = rel.permute(1, 2, 0).contiguous()
            rel[..., 0] += window_size[0] - 1
            rel[..., 1] += window_size[1] - 1
            rel[..., 2] += window_size[2] - 1
            rel[..., 0] *= (2 * window_size[1] - 1) * (2 * window_size[2] - 1)
            rel[..., 1] *= (2 * window_size[2] - 1)
            self.register_buffer("relative_position_index", rel.sum(-1))
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        x: (num_windows * B, N, C) tokens of each window.
        mask: optional (num_windows, N, N) additive mask, added to the logits of every
              batch element and head.
        Returns a tensor with the same shape as x.
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each (B_, heads, N, head_dim)
        attn = (q * self.scale) @ k.transpose(-2, -1)
        if self.rpe:
            n_tokens = self.window_size[0] * self.window_size[1] * self.window_size[2]
            bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
            bias = bias.view(n_tokens, n_tokens, self.num_heads).permute(2, 0, 1)
            attn = attn + bias.unsqueeze(0)
        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.attn_drop(self.softmax(attn))
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        return self.proj_drop(self.proj(x))

class FAB(nn.Module):
    """
    Fusion Attention Block = (W-MSA or SW-MSA) + local conv

    Operates on a token sequence (B, H*W*T, C). The caller must set ``self.H``,
    ``self.W`` and ``self.T`` (the token grid size) before calling forward; they are
    None after construction.

    Output: x + DropPath(window attention) + 0.01 * CAB(conv branch), followed by a
    pre-norm MLP residual.

    Args:
        dim: token channels C.
        num_heads: attention heads.
        window_size: 3-tuple window size.
        shift_size: 3-tuple cyclic shift. All zeros gives regular windows; otherwise
            the grid is rolled and the mask passed to forward is applied.
        mlp_ratio: MLP hidden width relative to dim.
        qkv_bias, qk_scale, rpe, attn_drop: forwarded to WindowAttention.
        drop: dropout for WindowAttention's output projection and for the MLP.
        drop_path: stochastic-depth rate used on both residual branches.
        act_layer, norm_layer: MLP activation and normalization layer types.
    """
    def __init__(self, dim, num_heads, window_size=(7,7,7), shift_size=(0,0,0),
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, rpe=True,
                 drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim=dim
        self.num_heads=num_heads
        self.window_size=window_size
        self.shift_size=shift_size
        self.mlp_ratio=mlp_ratio
        self.norm1=norm_layer(dim)
        # positional args: ..., attn_drop, proj_drop=drop
        self.attn=WindowAttention(dim, window_size, num_heads, qkv_bias, qk_scale, rpe, attn_drop, drop)
        self.drop_path=DropPath(drop_path) if drop_path>0 else nn.Identity()
        self.norm2=norm_layer(dim)
        mlp_hidden=int(dim*mlp_ratio)
        self.mlp=Mlp(dim, mlp_hidden, act_layer=act_layer, drop=drop)
        # Local convolutional branch (channel attention block)
        self.conv_block = CAB(num_feat=dim, compress_ratio=3, squeeze_factor=30)
        # Token grid size; set by the caller before each forward
        self.H=self.W=self.T=None
    def forward(self, x, mask_matrix):
        # x: (B, H*W*T, C). mask_matrix is only used when shift_size is non-zero.
        H,W,T=self.H,self.W,self.T
        B,L,C=x.shape
        shortcut=x
        x=self.norm1(x)
        x_3d=x.view(B,H,W,T,C)
        # Conv branch on the normalized tokens, returned to (B, H*W*T, C)
        conv_x=self.conv_block(x_3d.permute(0,4,1,2,3))
        conv_x=conv_x.permute(0,2,3,4,1).contiguous().view(B,H*W*T,C)
        # Pad the grid up to a multiple of the window size. Naming caveat: pad_r pads H,
        # pad_b pads W, pad_h pads T (F.pad lists dims from last to first).
        pad_r=(self.window_size[0]-H%self.window_size[0])%self.window_size[0]
        pad_b=(self.window_size[1]-W%self.window_size[1])%self.window_size[1]
        pad_h=(self.window_size[2]-T%self.window_size[2])%self.window_size[2]
        x_3d=F.pad(x_3d,(0,0,0,pad_h,0,pad_b,0,pad_r))
        Hp, Wp, Tp=x_3d.shape[1], x_3d.shape[2], x_3d.shape[3]
        # Cyclic shift for shifted-window attention; the mask is added to the logits
        if any(self.shift_size):
            shifted_x=torch.roll(x_3d, shifts=(-self.shift_size[0],-self.shift_size[1],-self.shift_size[2]), dims=(1,2,3))
            attn_mask=mask_matrix
        else:
            shifted_x=x_3d
            attn_mask=None
        # Attention inside each window: (num_windows*B, Wh*Ww*Wt, C)
        x_windows=window_partition(shifted_x,self.window_size)
        x_windows=x_windows.view(-1,self.window_size[0]*self.window_size[1]*self.window_size[2],C)
        attn_windows=self.attn(x_windows, attn_mask)
        attn_windows=attn_windows.view(-1,*self.window_size,C)
        shifted_x=window_reverse(attn_windows,self.window_size,Hp,Wp,Tp)
        # Undo the cyclic shift
        if any(self.shift_size):
            x_3d=torch.roll(shifted_x,shifts=(self.shift_size[0],self.shift_size[1],self.shift_size[2]),dims=(1,2,3))
        else:
            x_3d=shifted_x
        # Drop the padding added above
        if pad_r>0 or pad_b>0 or pad_h>0:
            x_3d=x_3d[:,:H,:W,:T,:].contiguous()
        x_attn=x_3d.view(B,H*W*T,C)
        # Residual: attention (with stochastic depth) + conv branch at a fixed 0.01 weight
        x=shortcut + self.drop_path(x_attn) + conv_x*0.01
        # Pre-norm MLP residual
        x=x + self.drop_path(self.mlp(self.norm2(x)))
        return x

class PatchMerging(nn.Module):
    """
    3D Patch Merging

    Halves the token grid. Each 2x2x2 neighbourhood of tokens is concatenated
    (8*C channels), normalized, and linearly projected to
    (8 // reduce_factor) * C channels.

    Args:
        dim: input channels C.
        norm_layer: normalization applied to the 8*C concatenation.
        reduce_factor: output channels are (8 // reduce_factor) * dim.
    """
    def __init__(self, dim, norm_layer=nn.LayerNorm, reduce_factor=2):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(8 * dim, (8 // reduce_factor) * dim, bias=False)
        self.norm = norm_layer(8 * dim)

    def forward(self, x, H, W, T):
        """
        x: tokens (B, H*W*T, C) laid out as an (H, W, T) grid.
        Returns tokens (B, ceil(H/2)*ceil(W/2)*ceil(T/2), (8 // reduce_factor) * C).
        """
        B, L, C = x.shape
        x_3d = x.view(B, H, W, T, C)
        # Zero-pad odd sizes so every 2x2x2 neighbourhood is complete. Without this the
        # strided slices below have mismatched shapes and torch.cat fails. The padding
        # yields the (H+1)//2 grid that BasicLayer reports for the next stage.
        if H % 2 or W % 2 or T % 2:
            x_3d = F.pad(x_3d, (0, 0, 0, T % 2, 0, W % 2, 0, H % 2))
        x0 = x_3d[:, 0::2, 0::2, 0::2, :]
        x1 = x_3d[:, 1::2, 0::2, 0::2, :]
        x2 = x_3d[:, 0::2, 1::2, 0::2, :]
        x3 = x_3d[:, 0::2, 0::2, 1::2, :]
        x4 = x_3d[:, 1::2, 1::2, 0::2, :]
        x5 = x_3d[:, 0::2, 1::2, 1::2, :]
        x6 = x_3d[:, 1::2, 0::2, 1::2, :]
        x7 = x_3d[:, 1::2, 1::2, 1::2, :]
        # The concatenation order is tied to the trained projection weights; keep it.
        x_cat = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], dim=-1)
        x_cat = x_cat.view(B, -1, 8 * C)
        x_cat = self.norm(x_cat)
        return self.reduction(x_cat)

class BasicLayer(nn.Module):
    """One encoder stage.

    Consists of ``depth`` FAB blocks that alternate regular windows (even block index)
    and shifted windows (odd index, shift = window_size // 2), optionally followed by a
    downsampling layer.

    Args:
        dim: token channels.
        depth: number of FAB blocks.
        num_heads, window_size, mlp_ratio, qkv_bias, qk_scale, rpe, drop, attn_drop:
            forwarded to every FAB.
        drop_path: a float shared by all blocks, or a ``list`` with one rate per block.
            Only ``list`` is indexed per block; any other type is passed to each FAB as-is.
        norm_layer: normalization layer type for the blocks and the downsample layer.
        downsample: class called as ``downsample(dim=..., norm_layer=..., reduce_factor=pat_merg_rf)``,
            or None for no downsampling.
        use_checkpoint: run each block through ``torch.utils.checkpoint``.
        pat_merg_rf: reduce_factor passed to ``downsample``.
    """
    def __init__(self, dim, depth, num_heads, window_size=(7,7,7),
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, rpe=True,
                 drop=0., attn_drop=0., drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None, use_checkpoint=False,
                 pat_merg_rf=2):
        super().__init__()
        self.window_size=window_size
        self.shift_size=(window_size[0]//2, window_size[1]//2, window_size[2]//2)
        self.depth=depth
        self.use_checkpoint=use_checkpoint
        # Even-indexed blocks use regular windows, odd-indexed blocks shifted windows
        self.blocks=nn.ModuleList([
            FAB(dim,num_heads,window_size,(0,0,0) if (i%2==0) else self.shift_size,
                mlp_ratio,qkv_bias,qk_scale,rpe,drop,attn_drop,
                drop_path[i] if isinstance(drop_path,list) else drop_path,
                norm_layer=norm_layer)
            for i in range(depth)
        ])
        self.overlap_attn=None
        # Overlapping block can be added here if needed
        self.downsample=downsample(dim=dim, norm_layer=norm_layer, reduce_factor=pat_merg_rf) if downsample else None
    def forward(self,x,H,W,T):
        # x: (B, H*W*T, dim).
        # Returns (x, H, W, T, x_down, Wh, Ww, Wt): this stage's output at its input
        # resolution, plus the downsampled tokens and their grid size. Without a
        # downsample layer the stage output and resolution are returned twice.
        #
        # Build the shifted-window attention mask on the grid padded up to a multiple of
        # the window size (the same padding FAB applies).
        Hp=int(math.ceil(H/self.window_size[0]))*self.window_size[0]
        Wp=int(math.ceil(W/self.window_size[1]))*self.window_size[1]
        Tp=int(math.ceil(T/self.window_size[2]))*self.window_size[2]
        img_mask=torch.zeros((1,Hp,Wp,Tp,1), device=x.device)
        # Up to 3 regions per axis: before the last window, the last window minus the
        # shifted strip, and the strip that the cyclic shift wraps around
        h_slices=(slice(0,-self.window_size[0]),
                  slice(-self.window_size[0], -self.shift_size[0]),
                  slice(-self.shift_size[0], None))
        w_slices=(slice(0,-self.window_size[1]),
                  slice(-self.window_size[1], -self.shift_size[1]),
                  slice(-self.shift_size[1], None))
        t_slices=(slice(0,-self.window_size[2]),
                  slice(-self.window_size[2], -self.shift_size[2]),
                  slice(-self.shift_size[2], None))
        # Give each region combination a distinct integer label
        cnt=0
        for hh in h_slices:
            for ww in w_slices:
                for tt in t_slices:
                    img_mask[:,hh,ww,tt,:]=cnt
                    cnt+=1
        mask_windows=window_partition(img_mask,self.window_size)
        mask_windows=mask_windows.view(-1,self.window_size[0]*self.window_size[1]*self.window_size[2])
        # (num_windows, N, N): 0 for token pairs with the same label, -100 otherwise
        attn_mask=mask_windows.unsqueeze(1)-mask_windows.unsqueeze(2)
        attn_mask=attn_mask.masked_fill(attn_mask!=0,float(-100.0)).masked_fill(attn_mask==0,float(0.0))
        # Every block receives the mask; FAB uses it only when its shift_size is non-zero
        for blk in self.blocks:
            blk.H, blk.W, blk.T=H,W,T
            if self.use_checkpoint:
                x=checkpoint.checkpoint(blk,x,attn_mask)
            else:
                x=blk(x,attn_mask)
        if self.downsample is not None:
            x_down=self.downsample(x,H,W,T)
            # ceil(H/2) etc.; the downsample layer must produce exactly this grid
            Wh, Ww, Wt=(H+1)//2,(W+1)//2,(T+1)//2
            return x,H,W,T, x_down,Wh,Ww,Wt
        else:
            return x,H,W,T, x,H,W,T

class PatchEmbed(nn.Module):
    """Split a 5D volume into non-overlapping patches and embed each one.

    The input (B, in_chans, H, W, T) is zero-padded at the high end of each spatial
    axis up to a multiple of ``patch_size``. A strided Conv3d then produces
    (B, embed_dim, Hp, Wp, Tp). If ``norm_layer`` is given, it is applied to each
    patch's channel vector.
    """
    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = to_3tuple(patch_size)
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else None

    def forward(self, x):
        B, C, H, W, T = x.shape
        ph, pw, pt = self.patch_size[0], self.patch_size[1], self.patch_size[2]
        # F.pad pairs run from the last dim backwards: (T_lo, T_hi, W_lo, W_hi, H_lo, H_hi)
        if T % pt != 0:
            x = F.pad(x, (0, pt - T % pt))
        if W % pw != 0:
            x = F.pad(x, (0, 0, 0, pw - W % pw))
        if H % ph != 0:
            x = F.pad(x, (0, 0, 0, 0, 0, ph - H % ph))
        x = self.proj(x)
        if self.norm is None:
            return x
        Hp, Wp, Tp = x.shape[2], x.shape[3], x.shape[4]
        tokens = self.norm(x.flatten(2).transpose(1, 2))
        return tokens.transpose(1, 2).view(B, self.embed_dim, Hp, Wp, Tp)

class SinPositionalEncoding3D(nn.Module):
    """Fixed 3-D sinusoidal positional encoding (optional helper).

    The channel budget is split into three groups of ``self.channels`` for the
    x, y and z axes. Each group stores ``[sin, cos]`` of ``position * inv_freq``
    for its own axis. The result is truncated to the input's channel count.
    """
    def __init__(self, channels):
        super().__init__()
        channels = int(np.ceil(channels / 6) * 2)
        if channels % 2:
            channels += 1
        self.channels = channels
        inv_freq = 1. / (10000 ** (torch.arange(0, channels, 2).float() / channels))
        # Non-persistent buffer: follows .to()/.cuda()/.double() with the module
        # but is not written to state_dict() (same keys as before).
        self.register_buffer('inv_freq', inv_freq, persistent=False)

    def forward(self, tensor):
        """Return an encoding with the same shape as ``tensor`` (B, C, X, Y, Z)."""
        B, C, X, Y, Z = tensor.shape
        device = tensor.device
        # Build positions on the input's device, so no host round-trip happens
        # even if the module itself was never moved.
        inv_freq = self.inv_freq.to(device)

        def _axis_embedding(length):
            # (length, channels): [sin(pos * f), cos(pos * f)] for each frequency f.
            pos = torch.arange(length, device=device, dtype=inv_freq.dtype)
            sin_inp = torch.einsum("i,j->ij", pos, inv_freq)
            return torch.cat([sin_inp.sin(), sin_inp.cos()], dim=-1)

        emb_x = _axis_embedding(X)[:, None, None, :]
        emb_y = _axis_embedding(Y)[:, None, :]
        emb_z = _axis_embedding(Z)
        emb = torch.zeros((X, Y, Z, self.channels * 3), device=device, dtype=tensor.dtype)
        emb[:, :, :, :self.channels] = emb_x
        emb[:, :, :, self.channels:2 * self.channels] = emb_y
        emb[:, :, :, 2 * self.channels:] = emb_z
        emb = emb[None, :, :, :, :C].repeat(B, 1, 1, 1, 1)
        return emb.permute(0, 4, 1, 2, 3)

class SwinTransformer(nn.Module):
    """3-D Swin Transformer encoder producing multi-scale feature maps.

    ``forward`` does the following:
    - embeds the input into patches;
    - runs ``len(depths)`` BasicLayer stages, each but the last followed by
      PatchMerging;
    - for every stage index ``i`` in ``out_indices``, returns that stage's
      output normalised with ``norm_layer`` and reshaped to
      (B, C_i, H_i, W_i, T_i), where C_i = embed_dim * 2**i.

    Notes:
        ``pretrain_img_size``, ``ape``, ``spe`` and ``frozen_stages`` are
        accepted, and the last three are stored, but nothing in this class uses
        them. In particular, no absolute or sinusoidal position embedding is
        added in ``forward``. ``depths`` and ``num_heads`` default to tuples,
        avoiding shared mutable default arguments; any sequence works.
    """
    def __init__(self, pretrain_img_size=224,
                 patch_size=4, in_chans=3, embed_dim=96,
                 depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), window_size=(7, 7, 7),
                 mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.2,
                 norm_layer=nn.LayerNorm, ape=False, spe=False, rpe=True,
                 patch_norm=True, out_indices=(0, 1, 2, 3), frozen_stages=-1,
                 use_checkpoint=False, pat_merg_rf=2):
        super().__init__()
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.spe = spe
        self.rpe = rpe
        self.patch_norm = patch_norm
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.patch_embed = PatchEmbed(patch_size, in_chans, embed_dim, norm_layer if patch_norm else None)
        self.pos_drop = nn.Dropout(p=drop_rate)
        # Stochastic-depth rates increase linearly over all blocks of all stages.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=mlp_ratio,
                               qkv_bias=qkv_bias,
                               qk_scale=qk_scale,
                               rpe=rpe,
                               drop=drop_rate,
                               attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint,
                               pat_merg_rf=pat_merg_rf)
            self.layers.append(layer)
        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
        self.num_features = num_features
        # One output norm per requested stage, registered as norm{i}.
        for i_layer in out_indices:
            layer_ = norm_layer(num_features[i_layer])
            self.add_module(f"norm{i_layer}", layer_)

    def forward(self, x):
        """Return the list of per-stage feature maps selected by ``out_indices``."""
        x = self.patch_embed(x)
        B, C, Hp, Wp, Tp = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.pos_drop(x)
        outs = []
        for i in range(self.num_layers):
            layer = self.layers[i]
            # x_out is this stage's output at (H, W, T); x is the (possibly)
            # downsampled input for the next stage at (Wh, Ww, Wt).
            x_out, H, W, T, x, Wh, Ww, Wt = layer(x, Hp, Wp, Tp)
            if i in self.out_indices:
                norm_layer = getattr(self, f"norm{i}")
                x_out = norm_layer(x_out)
                out_f = x_out.view(B, H, W, T, self.num_features[i]).permute(0, 4, 1, 2, 3).contiguous()
                outs.append(out_f)
            Hp, Wp, Tp = Wh, Ww, Wt
        return outs

class PixelShuffle3d(nn.Module):
    """3-D analogue of nn.PixelShuffle.

    Rearranges (B, C * r**3, D, H, W) into (B, C, D * r, H * r, W * r), where
    r is ``scale``.
    """
    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, inp):
        r = self.scale
        n, ch, depth, height, width = inp.size()
        out_ch = ch // (r ** 3)
        # Split channels into (out_ch, r, r, r), then interleave each r factor
        # with its matching spatial axis.
        expanded = inp.view(n, out_ch, r, r, r, depth, height, width)
        interleaved = expanded.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
        return interleaved.view(n, out_ch, depth * r, height * r, width * r)

class ConvergeHead(nn.Module):
    """Upsample by ``up_ratio`` with a grouped Conv3d and a 3-D pixel shuffle.

    The conv has ``groups=in_dim``, so each input channel independently
    produces ``up_ratio**3`` channels. PixelShuffle3d then folds those channels
    back into space. Conv/Linear weights start at N(0, 0.001) with zero bias.
    """
    def __init__(self, in_dim, up_ratio, kernel_size, padding):
        super().__init__()
        self.in_dim = in_dim
        self.up_ratio = up_ratio
        out_dim = in_dim * up_ratio ** 3
        self.conv = nn.Conv3d(in_dim, out_dim, kernel_size, stride=1,
                              padding=padding, dilation=1, groups=in_dim)
        self.apply(self._init_weights)

    def forward(self, x):
        return PixelShuffle3d(self.up_ratio)(self.conv(x))

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv3d, nn.Linear)):
            nn.init.normal_(m.weight, std=0.001)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm3d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

class Conv3dReLU(nn.Sequential):
    """Bias-free Conv3d -> normalisation -> LeakyReLU, packaged as nn.Sequential.

    Despite the name, the activation is ``nn.LeakyReLU`` (in-place, default
    slope). ``use_batchnorm`` selects BatchNorm3d; otherwise InstanceNorm3d.
    """
    def __init__(self, in_channels, out_channels, kernel_size, padding=0, stride=1, use_batchnorm=True):
        norm_cls = nn.BatchNorm3d if use_batchnorm else nn.InstanceNorm3d
        layers = (
            nn.Conv3d(in_channels, out_channels, kernel_size,
                      stride=stride, padding=padding, bias=False),
            norm_cls(out_channels),
            nn.LeakyReLU(inplace=True),
        )
        super(Conv3dReLU, self).__init__(*layers)

class SR(nn.Module):
    """Decoder stage: 2x ConvergeHead upsampling, optional skip concat, two Conv3dReLU.

    ``skip_channels`` must equal the channel count of the ``skip`` tensor
    passed to ``forward`` (0 when no skip is used).
    """
    def __init__(self, in_channels, out_channels, skip_channels=0, use_batchnorm=True):
        super().__init__()
        self.up = ConvergeHead(in_channels, 2, 3, 1)
        merged = in_channels + skip_channels
        self.conv1 = Conv3dReLU(merged, out_channels, 3, 1, use_batchnorm=use_batchnorm)
        self.conv2 = Conv3dReLU(out_channels, out_channels, 3, 1, use_batchnorm=use_batchnorm)

    def forward(self, x, skip=None):
        x = self.up(x)
        x = x if skip is None else torch.cat([x, skip], dim=1)
        return self.conv2(self.conv1(x))

class RegistrationHead(nn.Sequential):
    """Single 'same'-padded Conv3d that maps decoder features to a displacement field.

    Weights are drawn from Normal(0, 1e-5) and the bias is zero, so the
    initial output is close to zero. ``upsampling`` is accepted but unused.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
        head = nn.Conv3d(in_channels, out_channels,
                         kernel_size=kernel_size, padding=kernel_size // 2)
        near_zero = Normal(0, 1e-5)
        head.weight = nn.Parameter(near_zero.sample(head.weight.shape))
        head.bias = nn.Parameter(torch.zeros(head.bias.shape))
        super().__init__(head)

class SpatialTransformer_forUTSR(nn.Module):
    """Warp ``src`` with a dense voxel displacement field ``flow`` (2-D or 3-D).

    Args:
        size: spatial shape of the sampling grid (2 or 3 entries).
        mode: interpolation mode for F.grid_sample ('bilinear' or 'nearest').
        align_corners: forwarded to F.grid_sample. Defaults to True, the
            convention that matches the ``(shape - 1)`` normalisation in
            ``forward``, so a zero flow is an exact identity warp. The
            previous code always used False, which shifts and scales the
            sampling grid slightly. Pass False only to reproduce results from
            models trained with that setting.
    """
    def __init__(self, size, mode='bilinear', align_corners=True):
        super().__init__()
        self.mode = mode
        self.align_corners = align_corners
        vectors = [torch.arange(0, s) for s in size]
        grids = torch.meshgrid(vectors, indexing='ij')
        grid = torch.stack(grids)
        grid = torch.unsqueeze(grid, 0).float()
        # Identity grid of voxel indices, (1, ndims, *size); a buffer so it
        # moves with the module across devices.
        self.register_buffer('grid', grid)

    def forward(self, src, flow):
        """Return ``src`` resampled at ``grid + flow`` (voxel units)."""
        new_locs = self.grid + flow
        shape = flow.shape[2:]
        # Normalise voxel coordinates to [-1, 1]: index 0 -> -1, index s-1 -> +1.
        for i in range(len(shape)):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)
        # grid_sample wants coordinates in the last dim, ordered from the last
        # spatial axis to the first (x, y[, z]).
        if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)
            new_locs = new_locs[..., [1, 0]]
        elif len(shape) == 3:
            new_locs = new_locs.permute(0, 2, 3, 4, 1)
            new_locs = new_locs[..., [2, 1, 0]]
        return F.grid_sample(src, new_locs, align_corners=self.align_corners, mode=self.mode)

class UTSRMorph(nn.Module):
    """UTSRMorph deformable registration network.

    The pipeline, as wired below:
    - a Swin Transformer encoder on the full input;
    - four ``SR`` decoder stages, each upsampling 2x and optionally fusing a
      transformer or conv skip feature;
    - a ``RegistrationHead`` producing a 3-channel field;
    - a final 2x ``ConvergeHead`` upsampling of that field;
    - a spatial transformer that warps channel 0 of the input with the field.

    Args:
        config: ConfigDict holding the fields read below
            (see get_UTSRMorph_config).
    """
    def __init__(self, config):
        super(UTSRMorph, self).__init__()
        self.if_convskip = config.if_convskip
        self.if_transskip = config.if_transskip
        embed_dim = config.embed_dim
        self.transformer = SwinTransformer(
            patch_size=config.patch_size,
            in_chans=config.in_chans,
            embed_dim=config.embed_dim,
            depths=config.depths,
            num_heads=config.num_heads,
            window_size=config.window_size,
            mlp_ratio=config.mlp_ratio,
            qkv_bias=config.qkv_bias,
            drop_rate=config.drop_rate,
            drop_path_rate=config.drop_path_rate,
            ape=config.ape,
            spe=config.spe,
            rpe=config.rpe,
            patch_norm=config.patch_norm,
            use_checkpoint=config.use_checkpoint,
            out_indices=config.out_indices,
            pat_merg_rf=config.pat_merg_rf
        )
        self.up0 = SR(embed_dim * 8, embed_dim * 4, skip_channels=(embed_dim * 4 if self.if_transskip else 0), use_batchnorm=False)
        self.up1 = SR(embed_dim * 4, embed_dim * 2, skip_channels=(embed_dim * 2 if self.if_transskip else 0), use_batchnorm=False)
        self.up2 = SR(embed_dim * 2, embed_dim, skip_channels=(embed_dim if self.if_transskip else 0), use_batchnorm=False)
        self.up3 = SR(embed_dim, config.reg_head_chan, skip_channels=(embed_dim // 2 if self.if_convskip else 0), use_batchnorm=False)
        # The conv skip path pools the raw input, so its input channel count
        # must match config.in_chans (it was hard-coded to 2 before, which
        # broke any config with a different channel count).
        self.c1 = Conv3dReLU(config.in_chans, embed_dim // 2, 3, 1, use_batchnorm=False)
        self.reg_head = RegistrationHead(config.reg_head_chan, 3, 3)
        self.spatial_trans = SpatialTransformer_forUTSR(config.img_size)
        self.avg_pool = nn.AvgPool3d(3, stride=2, padding=1)
        self.up = ConvergeHead(3, 2, 3, 1)

    def forward(self, x):
        """Register channel 0 of ``x`` (assumed (B, in_chans, H, W, D)).

        Returns:
            (warped source, flow), where flow has 3 channels.
        """
        source = x[:, 0:1, :, :, :]
        # Conv skip features from the stride-2 average-pooled input.
        f4 = self.c1(self.avg_pool(x)) if self.if_convskip else None
        out_feats = self.transformer(x)
        if self.if_transskip:
            f1, f2, f3 = out_feats[-2], out_feats[-3], out_feats[-4]
        else:
            f1 = f2 = f3 = None
        x = self.up0(out_feats[-1], f1)
        x = self.up1(x, f2)
        x = self.up2(x, f3)
        x = self.up3(x, f4)
        flow = self.up(self.reg_head(x))
        out = self.spatial_trans(source, flow)
        return out, flow

################################################################################
# 5) THE TRAINING SCRIPT (Merged from 'train_UTSRMorph_oasis.py')
################################################################################

class Logger(object):
    """Tee writes to the original ``sys.stdout`` and to ``save_dir + "logfile.log"``.

    ``save_dir`` is concatenated directly with the file name, so it should end
    with a path separator. The object is meant to be installed as
    ``sys.stdout``.
    """
    def __init__(self, save_dir):
        self.terminal = sys.stdout
        self.log = open(save_dir + "logfile.log", "a")

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        # This used to be a no-op, which left log lines sitting in the file
        # buffer (lost on a crash) and ignored print(..., flush=True).
        self.terminal.flush()
        self.log.flush()

    def close(self):
        """Close the log file. The terminal stream is left open."""
        if not self.log.closed:
            self.log.close()

def mk_grid_img(grid_step, line_thickness=1, grid_sz=(160, 192, 224), device='cuda'):
    """Build a binary grid volume for visualising deformations.

    Lines are drawn every ``grid_step`` voxels along axes 1 and 2 of
    ``grid_sz``, and each one spans all of axis 0. Each line is
    ``line_thickness`` voxels thick. Lines that would run past the volume edge
    are clipped rather than raising IndexError. With the default thickness of
    1, the output is the same as before.

    Returns:
        float32 tensor of shape (1, 1, *grid_sz) on ``device``.
    """
    grid_img = np.zeros(grid_sz)
    for j in range(0, grid_img.shape[1], grid_step):
        grid_img[:, j:j + line_thickness, :] = 1
    for i in range(0, grid_img.shape[2], grid_step):
        grid_img[:, :, i:i + line_thickness] = 1
    grid_img = grid_img[None, None, ...]
    return torch.from_numpy(grid_img).to(device).float()

def comput_fig(img):
    """Plot slices 48-63 along the first spatial axis of ``img[0, 0]`` as a 4x4 grid.

    Each slice is drawn in grayscale without axes or spacing. Returns the
    matplotlib Figure.
    """
    slices = img.detach().cpu().numpy()[0, 0, 48:64, :, :]
    fig = plt.figure(figsize=(12, 12), dpi=180)
    for pos, plane in enumerate(slices, start=1):
        plt.subplot(4, 4, pos)
        plt.axis('off')
        plt.imshow(plane, cmap='gray')
    fig.subplots_adjust(wspace=0, hspace=0)
    return fig

def save_checkpoint(state, save_dir='models', filename='checkpoint.pth.tar', max_model_num=8):
    """Save ``state`` into ``save_dir`` and prune that directory to ``max_model_num`` files.

    Regular files in ``save_dir`` are natural-sorted by path, and the first
    ones are deleted until at most ``max_model_num`` remain.

    Paths are built with os.path.join. With plain string concatenation, a
    ``save_dir`` without a trailing slash (including the default 'models')
    would save to 'modelscheckpoint.pth.tar'. The prune glob 'models*' would
    then match, and could delete, unrelated files in the working directory.
    """
    os.makedirs(save_dir, exist_ok=True)
    torch.save(state, os.path.join(save_dir, filename))
    model_lists = natsorted([p for p in glob.glob(os.path.join(save_dir, '*')) if os.path.isfile(p)])
    excess = len(model_lists) - max_model_num
    for path in model_lists[:max(excess, 0)]:
        os.remove(path)

def adjust_learning_rate(optimizer, epoch, MAX_EPOCHES, INIT_LR, power=0.9):
    """Apply polynomial decay: lr = INIT_LR * (1 - epoch / MAX_EPOCHES) ** power.

    The value is rounded to 8 decimals and written to every param group.
    """
    groups = optimizer.param_groups
    if not groups:
        return
    new_lr = round(INIT_LR * np.power(1 - (epoch) / MAX_EPOCHES, power), 8)
    for group in groups:
        group['lr'] = new_lr

# Dummy placeholders for the dataset classes in data/datasets.py.
# They ignore the contents of `files` and return random volumes; in real use,
# replace them with the classes from the 'data' folder.
class DummyOASISBrainDataset(torch.utils.data.Dataset):
    """Random stand-in for OASISBrainDataset; only ``len(files)`` is used.

    Each item is (x, y, x_seg, y_seg), each shaped (1, 160, 192, 224):
    - x, y: float32 images in [0, 1);
    - x_seg, y_seg: int16 label maps obtained by truncating image * 10
      (values 0..9).
    ``transforms`` is stored but never applied.
    """
    def __init__(self, files, transforms=None):
        self.files = files
        self.transforms = transforms

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        shape = (160, 192, 224)
        moving = np.random.rand(*shape).astype(np.float32)
        fixed = np.random.rand(*shape).astype(np.float32)
        moving_seg = (moving * 10).astype(np.int16)
        fixed_seg = (fixed * 10).astype(np.int16)
        return tuple(arr[np.newaxis] for arr in (moving, fixed, moving_seg, fixed_seg))

class DummyOASISBrainInferDataset(torch.utils.data.Dataset):
    """Random stand-in for OASISBrainInferDataset; only ``len(files)`` is used.

    Yields (x, y, x_seg, y_seg):
    - x, y: random float32 volumes in [0, 1);
    - x_seg, y_seg: int16 label maps (values 0..9).
    Each array has a leading channel axis: (1, 160, 192, 224).
    ``transforms`` is stored but never applied.
    """
    def __init__(self, files, transforms=None):
        self.files = files
        self.transforms = transforms

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        images = [np.random.rand(160, 192, 224).astype(np.float32) for _ in range(2)]
        segs = [(img * 10).astype(np.int16) for img in images]
        x, y = images
        x_seg, y_seg = segs
        return x[None, :, :, :], y[None, :, :, :], x_seg[None, :, :, :], y_seg[None, :, :, :]

def train_UTSRMorph():
    """Train UTSRMorph with NCC similarity plus Grad3d (diffusion) regularisation.

    Uses the dummy datasets defined above. After every epoch it validates with
    Dice over warped label maps and writes a checkpoint named after the
    epoch's validation DSC to ``experiments/<save_dir>``. TensorBoard scalars
    go to ``logs/<save_dir>``.
    """
    import time  # `time` is not among the notebook's top-level imports

    batch_size = 1
    train_dir = '(example path)'
    val_dir = '(example val path)'
    weights = [1, 1]  # [NCC weight, regularisation weight]
    save_dir = 'UTSRMorph_ncc_{}_diffusion_{}/'.format(weights[0], weights[1])
    os.makedirs('experiments/' + save_dir, exist_ok=True)
    os.makedirs('logs/' + save_dir, exist_ok=True)
    # sys.stdout=Logger('logs/'+save_dir) # optional
    lr = 1e-4
    epoch_start = 0
    max_epoch = 5
    cont_training = False
    config = get_UTSRMorph_config()
    model = UTSRMorph(config).cuda()
    # Nearest-neighbour warper for label maps during validation.
    reg_model = register_model(config.img_size, 'nearest').cuda()
    if cont_training:
        pass  # resuming from a checkpoint is not implemented here
    updated_lr = lr
    train_set = DummyOASISBrainDataset(glob.glob(train_dir + '*.pkl'))
    val_set = DummyOASISBrainInferDataset(glob.glob(val_dir + '*.pkl'))
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)
    val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=0, pin_memory=True, drop_last=True)
    optimizer = optim.Adam(model.parameters(), lr=updated_lr, weight_decay=0, amsgrad=True)
    criterion_ncc = NCC_vxm()
    criterion_reg = Grad3d(penalty='l2')
    best_dsc = 0
    writer = SummaryWriter(log_dir='logs/' + save_dir)
    for epoch in range(epoch_start, max_epoch):
        print(f'--- Epoch {epoch} / {max_epoch} ---')
        loss_all = AverageMeter()
        time_start = time.time()
        # The poly schedule depends only on `epoch`, so set it once per epoch
        # instead of on every iteration.
        adjust_learning_rate(optimizer, epoch, max_epoch, lr)
        model.train()
        for idx, data in enumerate(train_loader, start=1):
            x, y = data[0].cuda(), data[1].cuda()
            # Register in a random direction (x->y or y->x) at each step.
            if random.random() <= 0.5:
                moving, fixed = x, y
            else:
                moving, fixed = y, x
            output, flow = model(torch.cat((moving, fixed), dim=1))
            loss_ncc = criterion_ncc(output, fixed) * weights[0]
            loss_reg = criterion_reg(flow, fixed) * weights[1]
            loss = loss_ncc + loss_reg
            loss_all.update(loss.item(), fixed.numel())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            del output, flow
            print('Iter {} of {} loss {:.4f}, ImgSim {:.6f}, Reg {:.6f}'.format(
                idx, len(train_loader), loss.item(), loss_ncc.item(), loss_reg.item()))
        writer.add_scalar('Loss/train', loss_all.avg, epoch)
        print('Epoch {} loss {:.4f}'.format(epoch, loss_all.avg))
        eval_dsc = AverageMeter()
        model.eval()
        with torch.no_grad():
            for data in val_loader:
                x, y, x_seg, y_seg = (t.cuda() for t in data)
                _, flow = model(torch.cat((x, y), dim=1))
                def_out = reg_model([x_seg.float(), flow])
                dsc_val = dice_val_VOI(def_out.long(), y_seg.long())
                eval_dsc.update(dsc_val.item(), x.size(0))
                print(eval_dsc.avg)
        # Deformed-grid visualisation (mk_grid_img + a bilinear register_model)
        # is omitted because its result was never used; add it here if plotting.
        best_dsc = max(eval_dsc.avg, best_dsc)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_dsc': best_dsc,
            'optimizer': optimizer.state_dict()
        }, save_dir='experiments/' + save_dir, filename='dsc{:.4f}.pth.tar'.format(eval_dsc.avg))
        time_end = time.time()
        # Rough ETA: this epoch's duration times the number of remaining epochs.
        alltime = (time_end - time_start) * (max_epoch - 1 - epoch)
        timeresult = str(datetime.timedelta(seconds=alltime))
        print("time:" + timeresult)
        writer.add_scalar('DSC/validate', eval_dsc.avg, epoch)
        # maybe skip plotting if in colab
        loss_all.reset()
    writer.close()

################################################################################
# 6) INFERENCE SCRIPT (Merged from 'infer_UTSRMorph_oasis.py')
################################################################################

def infer_UTSRMorph():
    """Evaluate UTSRMorph on pickled test pairs and print the mean DSC.

    Each ``*.pkl`` file in ``test_dir`` is unpickled as
    ``(x, y, x_seg, y_seg)``. The arrays are assumed to have no batch or
    channel axes, matching what OASISBrainInferDataset reads (which adds a
    channel axis itself). NOTE(review): confirm this against your data.
    """
    test_dir = '(example test path)'
    save_dir = '(example submit path)'
    model_idx = -1
    weights = [1, 1, 1]
    model_folder = 'UTSRMorph_ncc_{}_diffusion_{}/'.format(weights[0], weights[1])
    model_dir = 'experiments/' + model_folder
    config = get_UTSRMorph_config()
    model = UTSRMorph(config)
    # load some best model
    # best_model=torch.load(model_dir + natsorted(os.listdir(model_dir))[model_idx])['state_dict']
    # model.load_state_dict(best_model)
    model.cuda()
    model.eval()
    # Use the configured size so the warper matches the model's grid.
    reg_model = register_model(config.img_size, 'nearest').cuda()
    file_names = glob.glob(test_dir + '*.pkl')

    def _to_batch(arr):
        # (H, W, D) -> (1, 1, H, W, D) on the GPU. Without these axes,
        # torch.cat(..., dim=1) below joins along a spatial axis instead of
        # the channel axis.
        return torch.from_numpy(arr[None, None, ...]).cuda()

    dice_all = []
    with torch.no_grad():
        for data in file_names:
            x, y, x_seg, y_seg = pkload(data)
            x, y = _to_batch(x).float(), _to_batch(y).float()
            x_seg, y_seg = _to_batch(x_seg), _to_batch(y_seg)
            x_in = torch.cat((x, y), dim=1)
            _, flow = model(x_in)
            def_out = reg_model([x_seg.float(), flow])
            dsc_val = dice_val_VOI(def_out.long(), y_seg.long())
            dice_all.append(dsc_val)
    if not dice_all:
        print("No test files found in", test_dir)
        return
    print("Mean DSC:", np.array(dice_all).mean())

################################################################################
# 7) analysis_oasis.py (Optional - for advanced metrics)
################################################################################

def jacobian_determinant_np(disp):
    """Voxel-wise Jacobian determinant of the map ``identity + disp``.

    ``disp`` must be 5-D, indexed as ``disp[:, c, ...]`` for the components
    c = 0, 1, 2. Per-axis gradients are concatenated along axis 0, which only
    forms a valid 3x3 Jacobian when the leading (batch) size is 1.
    NOTE(review): batched input is not supported.
    Derivatives are central differences with zero padding. A 2-voxel border is
    cropped, so a (1, 3, H, W, D) input gives an (H-4, W-4, D-4) result.
    """
    _batch, _ncomp, _h, _w, _d = disp.shape  # also enforces a 5-D input
    central = np.array([-0.5, 0, 0.5])
    kernels = (central.reshape(1, 3, 1, 1),
               central.reshape(1, 1, 3, 1),
               central.reshape(1, 1, 1, 3))

    def _grad(kernel):
        # Derivative of every displacement component along one spatial axis.
        return np.stack([scipy.ndimage.correlate(disp[:, comp, :, :, :], kernel, mode='constant', cval=0.0)
                         for comp in range(3)], axis=1)

    grad_disp = np.concatenate([_grad(kern) for kern in kernels], 0)
    jac = grad_disp + np.eye(3, 3).reshape(3, 3, 1, 1, 1)
    jac = jac[:, :, 2:-2, 2:-2, 2:-2]
    # Cofactor expansion down the first column.
    minor_0 = jac[1, 1] * jac[2, 2] - jac[1, 2] * jac[2, 1]
    minor_1 = jac[0, 1] * jac[2, 2] - jac[0, 2] * jac[2, 1]
    minor_2 = jac[0, 1] * jac[1, 2] - jac[0, 2] * jac[1, 1]
    return jac[0, 0] * minor_0 - jac[1, 0] * minor_1 + jac[2, 0] * minor_2

################################################################################
# 8) MAIN LAUNCH (uncomment as needed)
################################################################################

if __name__ == "__main__":
    # Nothing runs automatically; just print how to start training or inference.
    for line in ("All code loaded. Now you can train or infer as desired.\n",
                 "To train, call train_UTSRMorph()",
                 "To infer, call infer_UTSRMorph()"):
        print(line)

    # Example usage:
    # train_UTSRMorph()
    # infer_UTSRMorph()


Collecting surface-distance
  Downloading surface_distance-0.1-py3-none-any.whl.metadata (1.7 kB)
Downloading surface_distance-0.1-py3-none-any.whl (14 kB)
Installing collected packages: surface-distance
Successfully installed surface-distance-0.1
All code loaded. Now you can train or infer as desired.

To train, call train_UTSRMorph()
To infer, call infer_UTSRMorph()


In [None]:
################################################################################
# 0) INSTALL DEPENDENCIES
################################################################################

!pip install einops timm ml_collections surface-distance

################################################################################
# 1) IMPORTS
################################################################################

import os, sys, math, glob, random, pickle, datetime, collections
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import matplotlib.pyplot as plt
from natsort import natsorted
from typing import Dict, List, Optional, Sequence, Tuple, Union
from math import exp
import scipy.ndimage
import torch.optim as optim

# third-party
import ml_collections
from einops import rearrange
from timm.models.layers import DropPath, trunc_normal_, to_3tuple
from surface_distance import compute_surface_distances, compute_robust_hausdorff, compute_dice_coefficient



In [None]:


################################################################################
# 2) DATA UTILS: from data/data_utils.py
################################################################################

M = 2**32 - 1  # keeps derived seeds below 2**32, the limit np.random.seed accepts


def init_fn(worker):
    """Reseed numpy and ``random`` for a DataLoader worker.

    The signature matches DataLoader's ``worker_init_fn``, which is presumably
    the intended use. A base seed is drawn from torch's global RNG, offset by
    ``worker``, and reduced modulo ``M``.
    """
    import random, torch, numpy as np
    base_seed = torch.LongTensor(1).random_().item()
    worker_seed = (base_seed + worker) % M
    for reseed in (np.random.seed, random.seed):
        reseed(worker_seed)

def pkload(fname):
    """Unpickle and return the object stored in ``fname``.

    Security: pickle can run arbitrary code while loading, so only use this on
    trusted files.
    """
    with open(fname, 'rb') as handle:
        obj = pickle.load(handle)
    return obj

################################################################################
# 3) DATASET CLASSES: OASISBrainDataset and OASISBrainInferDataset
#    from data/datasets.py
################################################################################

class OASISBrainDataset(Dataset):
    """Training pairs for OASIS.

    Item ``index`` is the moving volume from ``paths[index]``. The fixed
    target comes from a uniformly random *different* path. Each pkl file is
    expected to unpickle to ``(image, seg)``.
    """
    def __init__(self, data_path, transforms):
        """
        data_path: list of pkl file paths
        transforms: a Compose of transformations, called on [image, seg]
        """
        self.paths = data_path
        self.transforms = transforms

    def _load(self, path):
        # Load one (image, seg) pair, add a channel axis, transform jointly,
        # and return contiguous torch tensors.
        img, seg = pkload(path)
        img, seg = self.transforms([img[None, ...], seg[None, ...]])
        return (torch.from_numpy(np.ascontiguousarray(img)),
                torch.from_numpy(np.ascontiguousarray(seg)))

    def __getitem__(self, index):
        path = self.paths[index]  # IndexError for out-of-range indices, as before
        n = len(self.paths)
        if n < 2:
            # ValueError rather than IndexError: IndexError would silently end
            # a plain `for item in dataset` loop.
            raise ValueError("OASISBrainDataset needs at least two files to form a pair")
        if index < 0:
            index += n
        # Choose a random index other than `index` in O(1), instead of copying
        # and shuffling the whole path list for every sample.
        other = random.randrange(n - 1)
        if other >= index:
            other += 1
        x, x_seg = self._load(path)
        y, y_seg = self._load(self.paths[other])
        return x, y, x_seg, y_seg

    def __len__(self):
        return len(self.paths)

class OASISBrainInferDataset(Dataset):
    """Evaluation pairs for OASIS.

    Each pkl file holds a complete ``(x, y, x_seg, y_seg)`` pair.
    """
    def __init__(self, data_path, transforms):
        """
        data_path: list of pkl file paths
        transforms: a Compose of transformations, called on [image, seg]
        """
        self.paths = data_path
        self.transforms = transforms

    def __getitem__(self, index):
        x, y, x_seg, y_seg = pkload(self.paths[index])
        # Add a channel axis, then transform each (image, seg) pair jointly.
        x, x_seg = self.transforms([x[None, ...], x_seg[None, ...]])
        y, y_seg = self.transforms([y[None, ...], y_seg[None, ...]])
        return tuple(torch.from_numpy(np.ascontiguousarray(arr)) for arr in (x, y, x_seg, y_seg))

    def __len__(self):
        return len(self.paths)

################################################################################
# 4) TRANSFORMATIONS: from data/trans.py (plus data/rand.py embedded)
################################################################################

class Uniform(object):
    """Sampler returning ``random.uniform(a, b)`` on each call to ``sample``."""
    def __init__(self, a, b):
        self.a, self.b = a, b

    def sample(self):
        import random
        lo, hi = self.a, self.b
        return random.uniform(lo, hi)

class Gaussian(object):
    """Sampler returning ``random.gauss(mean, std)`` on each call to ``sample``."""
    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def sample(self):
        import random
        mu, sigma = self.mean, self.std
        return random.gauss(mu, sigma)

class Constant(object):
    """Degenerate sampler that always returns the value it was built with."""
    def __init__(self, val):
        self.val = val

    def sample(self):
        value = self.val
        return value

# Base class for transforms
class Base(object):
    """Base transform: sample per-call parameters, then apply ``tf`` to each item.

    Subclasses override two hooks:
    - ``sample(*shape)``: receives the spatial shape of the first array
      (its axes 1..dim);
    - ``tf(img, k)``: receives each array and its position ``k`` in the input.
    Both default to pass-through.
    """
    def sample(self, *shape):
        return shape

    def tf(self, img, k=0):
        return img

    def __call__(self, img, dim=3, reuse=False):
        if not reuse:
            reference = img if isinstance(img, np.ndarray) else img[0]
            # Skip axis 0 (channel) and take the next `dim` axes as the spatial shape.
            self.sample(*reference.shape[1:dim + 1])
        if not isinstance(img, (list, tuple)):
            return self.tf(img)
        return [self.tf(item, k) for k, item in enumerate(img)]

class Identity(Base):
    """No-op transform; it inherits the pass-through ``sample``/``tf`` from Base."""
    def __str__(self):
        name = 'Identity'
        return name + '()'

class Compose(Base):
    """Chain transforms.

    ``sample`` threads the shape through each op's ``sample``; ``tf`` applies
    each op's ``tf`` in order. A single op is wrapped in a list.
    """
    def __init__(self, ops):
        if not isinstance(ops, (list, tuple)):
            ops = [ops]
        self.ops = ops

    def sample(self, *shape):
        for op in self.ops:
            shape = op.sample(*shape)
        # Return the final shape, matching Base.sample's contract. This used to
        # return None, which broke a Compose nested inside another Compose
        # whenever an op followed it.
        return shape

    def tf(self, img, k=0):
        for op in self.ops:
            img = op.tf(img, k)
        return img

    def __str__(self):
        return f"Compose({self.ops})"

class NumpyType(Base):
    """
    Convert to specific numpy dtype per item
    e.g. NumpyType((np.float32, np.int16))

    Item ``k`` of the input list is cast with ``astype(types[k])``. When
    ``num > 0``, only the first ``num`` items are cast and the rest are
    returned unchanged.
    """
    def __init__(self, types, num=-1):
        self.types = types
        self.num = num

    def tf(self, img, k=0):
        if 0 < self.num <= k:
            return img  # beyond the first `num` items: leave untouched
        target_dtype = self.types[k]
        return img.astype(target_dtype)

    def __str__(self):
        return f"NumpyType({self.types})"

# Add more transforms as needed from trans.py
# e.g. RandomRotion, RandomFlip, etc.

################################################################################
# 5) TRAIN/INFERENCE UTILS
################################################################################

class AverageMeter(object):
    """Track the latest value, an n-weighted running mean, and the std of all values.

    ``std`` is the unweighted population std over every value passed to
    ``update``; the weights ``n`` are ignored for it.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = self.avg = self.sum = self.count = 0
        self.vals = []
        self.std = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` (e.g. a batch size)."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        self.vals.append(val)
        self.std = np.std(self.vals)

def dice_val_VOI(y_pred, y_true, labels=None):
    """Mean Dice similarity coefficient over a set of integer labels.

    Args:
        y_pred, y_true: tensors of integer label maps, indexed as ``[0, 0, ...]``
            (only the first batch item and channel are scored).
        labels: iterable of label ids to score. Defaults to 1..35, the
            previously hard-coded set.

    Returns:
        numpy float: the mean over labels of ``2|P∩T| / (|P| + |T| + 1e-5)``.
        A label absent from both volumes contributes 0.
    """
    if labels is None:
        labels = range(1, 36)
    pred = y_pred.detach().cpu().numpy()[0, 0, ...]
    true = y_true.detach().cpu().numpy()[0, 0, ...]
    DSCs = []
    for i in labels:
        p = (pred == i)
        t = (true == i)
        inter = (p & t).sum()
        denom = p.sum() + t.sum()
        DSCs.append(2. * inter / (denom + 1e-5))
    return np.mean(DSCs)

################################################################################
# 6) LOSSES
################################################################################

class Grad3d(nn.Module):
    """Smoothness penalty: mean forward-difference gradient of a 5-D field.

    With ``penalty='l2'`` the absolute differences are squared; any other value
    uses them as they are (L1). The three axis means are averaged, then scaled
    by ``loss_mult`` when it is given. ``y_true`` is accepted for API symmetry
    with other losses and is ignored.
    """
    def __init__(self, penalty='l2', loss_mult=None):
        super().__init__()
        self.penalty = penalty
        self.loss_mult = loss_mult

    def forward(self, y_pred, y_true):
        d_ax2 = torch.abs(y_pred[:, :, 1:, :, :] - y_pred[:, :, :-1, :, :])
        d_ax3 = torch.abs(y_pred[:, :, :, 1:, :] - y_pred[:, :, :, :-1, :])
        d_ax4 = torch.abs(y_pred[:, :, :, :, 1:] - y_pred[:, :, :, :, :-1])
        if self.penalty == 'l2':
            d_ax2, d_ax3, d_ax4 = d_ax2 * d_ax2, d_ax3 * d_ax3, d_ax4 * d_ax4
        total = d_ax3.mean() + d_ax2.mean() + d_ax4.mean()
        grad = total / 3.0
        if self.loss_mult is not None:
            grad *= self.loss_mult
        return grad

class NCC_vxm(nn.Module):
    """Local (windowed) normalised cross-correlation loss, VoxelMorph style.

    Returns the negative mean of the squared local correlation, which
    approaches -1 for perfectly correlated inputs. The inputs must have 1-3
    spatial dims and a single channel (the box filter is built with one
    input channel). ``win`` defaults to 9 along every spatial dim.
    """
    def __init__(self, win=None):
        super().__init__()
        self.win = win

    def forward(self, y_true, y_pred):
        Ii, Ji = y_true, y_pred
        ndims = Ii.dim() - 2
        assert ndims in [1, 2, 3]
        win = self.win if self.win is not None else [9] * ndims
        sum_filt = torch.ones([1, 1, *win]).to(Ii.device)
        pad_no = win[0] // 2
        stride = (1,) * ndims
        padding = (pad_no,) * ndims
        conv_fn = getattr(F, f'conv{ndims}d')

        def box_sum(t):
            # Sum of `t` over the local window centred on each position.
            return conv_fn(t, sum_filt, stride=stride, padding=padding)

        I_sum, J_sum = box_sum(Ii), box_sum(Ji)
        I2_sum, J2_sum, IJ_sum = box_sum(Ii * Ii), box_sum(Ji * Ji), box_sum(Ii * Ji)
        win_size = np.prod(win)
        u_I = I_sum / win_size
        u_J = J_sum / win_size
        cross = IJ_sum - u_J * I_sum - u_I * J_sum + u_I * u_J * win_size
        I_var = I2_sum - 2 * u_I * I_sum + u_I * u_I * win_size
        J_var = J2_sum - 2 * u_J * J_sum + u_J * u_J * win_size
        cc = cross * cross / (I_var * J_var + 1e-5)
        return -torch.mean(cc)

################################################################################
# 7) SWIN TRANSFORMER + UTSRMorph MODEL
################################################################################

class CA(nn.Module):
    """Channel attention used in RCAN (squeeze-and-excitation style gate).

    The input is global-average-pooled to one value per channel, sent through
    a 1x1x1 bottleneck (num_feat -> num_feat // squeeze_factor -> num_feat)
    ending in a sigmoid, and the result rescales the input channel-wise.
    """
    def __init__(self, num_feat, squeeze_factor=16):
        super().__init__()
        bottleneck = num_feat // squeeze_factor
        layers = [
            nn.AdaptiveAvgPool3d(1),
            nn.Conv3d(num_feat, bottleneck, 1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv3d(bottleneck, num_feat, 1, padding=0),
            nn.Sigmoid(),
        ]
        self.attention = nn.Sequential(*layers)

    def forward(self, x):
        gate = self.attention(x)
        return x * gate

class CAB(nn.Module):
    """Channel attention block: conv -> GELU -> conv -> channel attention (CA).

    The first 3x3x3 conv compresses to ``num_feat // compress_ratio`` channels,
    the second restores ``num_feat``, and CA then re-weights the channels.
    Spatial size is preserved (stride 1, padding 1).
    """
    def __init__(self, num_feat, compress_ratio=3, squeeze_factor=30):
        super().__init__()
        squeezed = num_feat // compress_ratio
        self.cab = nn.Sequential(
            nn.Conv3d(num_feat, squeezed, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.Conv3d(squeezed, num_feat, kernel_size=3, stride=1, padding=1),
            CA(num_feat, squeeze_factor),
        )

    def forward(self, x):
        out = self.cab(x)
        return out

class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> act -> dropout -> Linear -> dropout.

    ``hidden_features`` and ``out_features`` fall back to ``in_features``
    when they are falsy. The same dropout module is applied after both layers.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

def window_partition(x, window_size):
    """Split a (B, H, W, L, C) grid into non-overlapping windows.

    Each spatial size must be divisible by the matching ``window_size``
    entry. Returns (B * nH * nW * nL, wh, ww, wl, C), with windows ordered
    batch-major and the L-window index varying fastest.
    """
    batch, height, width, length, channels = x.shape
    wh, ww, wl = window_size[0], window_size[1], window_size[2]
    grid = x.view(batch, height // wh, wh, width // ww, ww, length // wl, wl, channels)
    # group the three window-index axes ahead of the three in-window axes
    grid = grid.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous()
    return grid.view(-1, wh, ww, wl, channels)

def window_reverse(windows, window_size, H, W, L):
    """Inverse of ``window_partition``: stitch windows back into (B, H, W, L, C).

    ``windows`` is (B * nH * nW * nL, wh, ww, wl, C) in the order produced by
    ``window_partition``; B is recovered from the window count.
    """
    wh, ww, wl = window_size[0], window_size[1], window_size[2]
    n_h, n_w, n_l = H // wh, W // ww, L // wl
    batch = windows.shape[0] // (n_h * n_w * n_l)
    grid = windows.view(batch, n_h, n_w, n_l, wh, ww, wl, -1)
    # interleave each window-index axis with its in-window axis (H, W, L order)
    grid = grid.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous()
    return grid.view(batch, H, W, L, -1)

class WindowAttention(nn.Module):
    """Multi-head self-attention over the tokens of one 3-D window.

    Optionally adds a learned relative-position bias (``rpe=True``) to the
    attention logits, and an additive mask passed in by the caller (used for
    shifted windows).

    Args:
        dim: channel count C of each token.
        window_size: 3-sequence window extent; a window holds
            window_size[0] * window_size[1] * window_size[2] tokens.
        num_heads: number of attention heads. NOTE(review): forward reshapes
            C into num_heads * (C // num_heads), so dim is assumed divisible
            by num_heads.
        qkv_bias: whether the fused q/k/v projection has a bias.
        qk_scale: overrides the default head_dim ** -0.5 query scaling (any
            falsy value falls back to the default).
        rpe: enable the relative-position bias.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
    """
    def __init__(self,dim,window_size,num_heads,qkv_bias=True,qk_scale=None,rpe=True,attn_drop=0.,proj_drop=0.):
        super().__init__()
        self.dim=dim
        self.window_size=window_size
        self.num_heads=num_heads
        head_dim=dim//num_heads
        self.scale=qk_scale or head_dim**-0.5
        self.rpe=rpe
        # relative pos
        # One learnable bias per head for every possible relative offset
        # between two tokens of a window: (2*w - 1) offsets along each axis.
        # NOTE: the table is created and initialised even when rpe=False; it
        # is simply never read in that case.
        table_size=(2*window_size[0]-1)*(2*window_size[1]-1)*(2*window_size[2]-1)
        self.relative_position_bias_table=nn.Parameter(torch.zeros(table_size,num_heads))
        trunc_normal_(self.relative_position_bias_table,std=.02)
        # Integer coordinates of every token in a window, shape (3, N) after flatten.
        # NOTE(review): torch.meshgrid without indexing= warns on recent
        # PyTorch; its default is currently 'ij'.
        coords_h=torch.arange(window_size[0])
        coords_w=torch.arange(window_size[1])
        coords_t=torch.arange(window_size[2])
        coords=torch.stack(torch.meshgrid([coords_h,coords_w,coords_t]))
        coords_flat=coords.flatten(1)
        if rpe:
            # Pairwise coordinate differences between all tokens -> (N, N, 3).
            rel=coords_flat[:,:,None]-coords_flat[:,None,:]
            rel=rel.permute(1,2,0).contiguous()
            # Shift offsets from [-(w-1), w-1] to the non-negative [0, 2w-2].
            rel[...,0]+=window_size[0]-1
            rel[...,1]+=window_size[1]-1
            rel[...,2]+=window_size[2]-1
            # Flatten the 3-D offset into one row-major index into the bias
            # table (last axis varies fastest).
            pos_factor=(2*window_size[1]-1)*(2*window_size[2]-1)
            rel[...,0]*=pos_factor
            rel[...,1]*=(2*window_size[2]-1)
            rel_index=rel.sum(-1)
            self.register_buffer("relative_position_index",rel_index)
        self.qkv=nn.Linear(dim,dim*3,bias=qkv_bias)
        self.attn_drop=nn.Dropout(attn_drop)
        self.proj=nn.Linear(dim,dim)
        self.proj_drop=nn.Dropout(proj_drop)
        self.softmax=nn.Softmax(dim=-1)
    def forward(self,x,mask=None):
        """Attend among the tokens of each window independently.

        Args:
            x: (B_, N, C) tokens, one row of N tokens per window. B_ is
                presumably batch * windows-per-image (see the mask reshape).
            mask: optional (nW, N, N) additive mask; B_ must be a multiple of
                nW, with windows ordered batch-major.

        Returns:
            (B_, N, C) tensor.
        """
        B_,N,C=x.shape
        # Fused projection reshaped to (3, B_, heads, N, head_dim).
        qkv=self.qkv(x).reshape(B_,N,3,self.num_heads,C//self.num_heads).permute(2,0,3,1,4)
        q,k,v=qkv[0],qkv[1],qkv[2]
        q=q*self.scale
        attn=q@k.transpose(-2,-1)
        if self.rpe:
            # NOTE: despite the name, this is the token count per window (N),
            # not the size of the bias table.
            table_size=self.window_size[0]*self.window_size[1]*self.window_size[2]
            # Gather the (N*N, heads) biases and lay them out as (heads, N, N).
            rpb=self.relative_position_bias_table[self.relative_position_index.view(-1)]
            rpb=rpb.view(table_size,table_size,self.num_heads).permute(2,0,1)
            attn=attn+rpb.unsqueeze(0)
        if mask is not None:
            nW=mask.shape[0]
            # Split B_ into (images, windows) so the per-window mask
            # broadcasts over images and heads.
            attn=attn.view(B_//nW,nW,self.num_heads,N,N)
            attn=attn+mask.unsqueeze(1).unsqueeze(0)
            attn=attn.view(-1,self.num_heads,N,N)
            attn=self.softmax(attn)
        else:
            attn=self.softmax(attn)
        attn=self.attn_drop(attn)
        # Merge heads back into the channel dimension.
        x=(attn@v).transpose(1,2).reshape(B_,N,C)
        x=self.proj(x)
        x=self.proj_drop(x)
        return x

class FAB(nn.Module):
    """
    Fusion Attention Block = W-MSA + local conv

    Pre-norm block on a flattened 3-D token grid:
        x = x + DropPath(window_attention(norm1(x))) + 0.01 * CAB(norm1(x))
        x = x + DropPath(mlp(norm2(x)))
    The window attention is cyclically shifted when ``shift_size`` is non-zero.

    The grid size (``self.H``, ``self.W``, ``self.T``) is not a forward
    argument: it starts as None and must be set by the caller before forward
    (BasicLayer.forward assigns it on every block).
    """
    def __init__(self,dim,num_heads,window_size=(7,7,7),shift_size=(0,0,0),
                 mlp_ratio=4.,qkv_bias=True,qk_scale=None,rpe=True,drop=0.,
                 attn_drop=0.,drop_path=0.,act_layer=nn.GELU,norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim=dim
        self.num_heads=num_heads
        self.window_size=window_size
        self.shift_size=shift_size
        self.norm1=norm_layer(dim)
        # `drop` is used as the attention's output-projection dropout.
        self.attn=WindowAttention(dim,window_size,num_heads,qkv_bias,qk_scale,rpe,attn_drop,drop)
        self.drop_path=DropPath(drop_path) if drop_path>0 else nn.Identity()
        self.norm2=norm_layer(dim)
        mlp_hidden=int(dim*mlp_ratio)
        self.mlp=Mlp(dim,mlp_hidden,act_layer=act_layer,drop=drop)
        # Local conv branch: compress_ratio=3, squeeze_factor=30.
        self.conv_block=CAB(dim,3,30)
        self.H=self.W=self.T=None
    def forward(self,x,mask_matrix):
        # x: (B, H*W*T, C) tokens. mask_matrix is the shifted-window attention
        # mask; it is only used when shift_size is non-zero.
        H,W,T=self.H,self.W,self.T
        B,L,C=x.shape
        shortcut=x
        x=self.norm1(x)
        x_3d=x.view(B,H,W,T,C)
        # Conv branch on the unpadded, normalised grid (channels-first for
        # Conv3d), flattened back to (B, H*W*T, C).
        conv_x=self.conv_block(x_3d.permute(0,4,1,2,3))
        conv_x=conv_x.permute(0,2,3,4,1).contiguous().view(B,H*W*T,C)
        # Pad H, W, T up to multiples of the window size (pad_h pads the T axis).
        pad_r=(self.window_size[0]-H%self.window_size[0])%self.window_size[0]
        pad_b=(self.window_size[1]-W%self.window_size[1])%self.window_size[1]
        pad_h=(self.window_size[2]-T%self.window_size[2])%self.window_size[2]
        # F.pad lists pads from the last dim backwards: C, T, W, H.
        x_3d=F.pad(x_3d,(0,0,0,pad_h,0,pad_b,0,pad_r))
        Hp,Wp,Tp=x_3d.shape[1],x_3d.shape[2],x_3d.shape[3]
        # Cyclically roll the grid by -shift_size for shifted-window attention.
        if any(self.shift_size):
            shifted_x=torch.roll(x_3d,shifts=(-self.shift_size[0],-self.shift_size[1],-self.shift_size[2]),dims=(1,2,3))
            attn_mask=mask_matrix
        else:
            shifted_x=x_3d
            attn_mask=None
        # Attend within each window, then stitch the windows back together.
        x_windows=window_partition(shifted_x,self.window_size)
        x_windows=x_windows.view(-1,self.window_size[0]*self.window_size[1]*self.window_size[2],C)
        attn_windows=self.attn(x_windows,attn_mask)
        attn_windows=attn_windows.view(-1,self.window_size[0],self.window_size[1],self.window_size[2],C)
        shifted_x=window_reverse(attn_windows,self.window_size,Hp,Wp,Tp)
        # Undo the roll, then crop away the padding.
        if any(self.shift_size):
            x_3d=torch.roll(shifted_x,shifts=(self.shift_size[0],self.shift_size[1],self.shift_size[2]),dims=(1,2,3))
        else:
            x_3d=shifted_x
        if pad_r>0 or pad_b>0 or pad_h>0:
            x_3d=x_3d[:,:H,:W,:T,:].contiguous()
        x_attn=x_3d.view(B,H*W*T,C)
        # The conv branch is added with a fixed 0.01 scale and no DropPath.
        x=shortcut+self.drop_path(x_attn)+conv_x*0.01
        x=x+self.drop_path(self.mlp(self.norm2(x)))
        return x

class PatchMerging(nn.Module):
    """Downsample a 3-D token grid by 2 along each axis.

    Each 2x2x2 neighbourhood of C-channel tokens is concatenated into an 8C
    vector, normalised, and linearly projected to ``(8 // reduce_factor) * C``
    channels.

    Args:
        dim: input channel count C.
        norm_layer: normalisation applied to the concatenated 8C features.
        reduce_factor: the output width is ``(8 // reduce_factor) * dim``.
    """
    def __init__(self, dim, norm_layer=nn.LayerNorm, reduce_factor=2):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(8 * dim, (8 // reduce_factor) * dim, bias=False)
        self.norm = norm_layer(8 * dim)

    def forward(self, x, H, W, T):
        """Merge (B, H*W*T, C) tokens into (B, ceil(H/2)*ceil(W/2)*ceil(T/2), C_out).

        Odd spatial sizes are zero-padded on the far side first. This matches
        the ``(H + 1) // 2`` sizes BasicLayer reports after downsampling.
        """
        B, L, C = x.shape
        x_3d = x.view(B, H, W, T, C)
        # Without this pad, an odd size makes the even/odd strided slices
        # below differ in length and torch.cat fails.
        if H % 2 or W % 2 or T % 2:
            x_3d = F.pad(x_3d, (0, 0, 0, T % 2, 0, W % 2, 0, H % 2))
        x0 = x_3d[:, 0::2, 0::2, 0::2, :]
        x1 = x_3d[:, 1::2, 0::2, 0::2, :]
        x2 = x_3d[:, 0::2, 1::2, 0::2, :]
        x3 = x_3d[:, 0::2, 0::2, 1::2, :]
        x4 = x_3d[:, 1::2, 1::2, 0::2, :]
        x5 = x_3d[:, 0::2, 1::2, 1::2, :]
        x6 = x_3d[:, 1::2, 0::2, 1::2, :]
        x7 = x_3d[:, 1::2, 1::2, 1::2, :]
        x_cat = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1)
        x_cat = x_cat.view(B, -1, 8 * C)
        x_cat = self.norm(x_cat)
        x_cat = self.reduction(x_cat)
        return x_cat

class BasicLayer(nn.Module):
    """One encoder stage: ``depth`` FAB blocks plus optional downsampling.

    Even-indexed blocks use regular windows; odd-indexed blocks use windows
    shifted by half the window size. When ``downsample`` is given (a class
    such as PatchMerging), it is applied after the blocks.
    """
    def __init__(self,dim,depth,num_heads,window_size=(7,7,7),mlp_ratio=4.,qkv_bias=True,
                 qk_scale=None,rpe=True,drop=0.,attn_drop=0.,drop_path=0.,norm_layer=nn.LayerNorm,
                 downsample=None,use_checkpoint=False,pat_merg_rf=2):
        super().__init__()
        self.window_size=window_size
        self.shift_size=(window_size[0]//2,window_size[1]//2,window_size[2]//2)
        self.depth=depth
        self.use_checkpoint=use_checkpoint
        # drop_path is either one rate per block (list) or a scalar shared by all.
        dpr=drop_path if isinstance(drop_path,list) else [drop_path]*depth
        self.blocks=nn.ModuleList([
            FAB(dim,num_heads,window_size,(0,0,0) if (i%2==0) else self.shift_size,
                mlp_ratio,qkv_bias,qk_scale,rpe,drop,attn_drop,
                dpr[i],norm_layer=norm_layer)
            for i in range(depth)
        ])
        self.downsample=downsample(dim=dim,norm_layer=norm_layer,reduce_factor=pat_merg_rf) if downsample else None
    def forward(self,x,H,W,T):
        """Run all blocks on (B, H*W*T, C) tokens, then optionally downsample.

        Returns:
            (x, H, W, T, x_down, Wh, Ww, Wt): the stage output at full
            resolution, and the downsampled tokens with sizes
            ((H+1)//2, (W+1)//2, (T+1)//2). Without a downsample layer, the
            second half repeats the first.
        """
        # NOTE: redundant with the module-level numpy import; harmless.
        import numpy as np
        # Grid size padded up to whole windows (FAB pads the same way).
        Hp=int(np.ceil(H/self.window_size[0]))*self.window_size[0]
        Wp=int(np.ceil(W/self.window_size[1]))*self.window_size[1]
        Tp=int(np.ceil(T/self.window_size[2]))*self.window_size[2]
        # Label the up-to-27 regions that a cyclic shift brings together, so
        # tokens from different regions cannot attend to each other.
        # NOTE(review): if some window_size entry is 1, its shift is 0 and
        # slice(-0, None) spans the whole axis; the mask is then degenerate.
        img_mask=torch.zeros((1,Hp,Wp,Tp,1),device=x.device)
        h_slices=(slice(0,-self.window_size[0]),slice(-self.window_size[0],-self.shift_size[0]),slice(-self.shift_size[0],None))
        w_slices=(slice(0,-self.window_size[1]),slice(-self.window_size[1],-self.shift_size[1]),slice(-self.shift_size[1],None))
        t_slices=(slice(0,-self.window_size[2]),slice(-self.window_size[2],-self.shift_size[2]),slice(-self.shift_size[2],None))
        cnt=0
        for hh in h_slices:
            for ww in w_slices:
                for tt in t_slices:
                    img_mask[:,hh,ww,tt,:]=cnt
                    cnt+=1
        # Per-window pairwise mask of shape (nW, N, N): 0 where both tokens
        # share a region label, -100 otherwise. It is rebuilt on every call
        # and is only consumed by the shifted blocks.
        mask_windows=window_partition(img_mask,self.window_size)
        mask_windows=mask_windows.view(-1,self.window_size[0]*self.window_size[1]*self.window_size[2])
        attn_mask=mask_windows.unsqueeze(1)-mask_windows.unsqueeze(2)
        attn_mask=attn_mask.masked_fill(attn_mask!=0,float(-100.0)).masked_fill(attn_mask==0,float(0.0))
        for blk in self.blocks:
            # FAB reads the grid size from attributes rather than arguments.
            blk.H,blk.W,blk.T=H,W,T
            if self.use_checkpoint:
                # NOTE(review): recent PyTorch warns when use_reentrant is not
                # passed to torch.utils.checkpoint.checkpoint.
                x=checkpoint.checkpoint(blk,x,attn_mask)
            else:
                x=blk(x,attn_mask)
        if self.downsample is not None:
            x_down=self.downsample(x,H,W,T)
            Wh,Ww,Wt=(H+1)//2,(W+1)//2,(T+1)//2
            return x,H,W,T,x_down,Wh,Ww,Wt
        else:
            return x,H,W,T,x,H,W,T

class PatchEmbed(nn.Module):
    """Embed non-overlapping 3-D patches with a strided Conv3d.

    The input (B, C, H, W, D) is zero-padded at the far end of each spatial
    axis up to a multiple of the patch size, projected to ``embed_dim``
    channels, and (if ``norm_layer`` is given) normalised over channels.
    Returns a channels-first (B, embed_dim, H', W', D') tensor.
    """
    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = to_3tuple(patch_size)
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else None

    def forward(self, x):
        B, _, H, W, D = x.shape
        ph, pw, pd = self.patch_size
        # F.pad order runs from the last axis backwards: D, then W, then H.
        pad = (0, (-D) % pd, 0, (-W) % pw, 0, (-H) % ph)
        if any(pad):
            x = F.pad(x, pad)
        x = self.proj(x)
        if self.norm is None:
            return x
        Hp, Wp, Dp = x.shape[2:]
        # LayerNorm expects channels last: normalise as (B, N, C) tokens.
        tokens = self.norm(x.flatten(2).transpose(1, 2))
        return tokens.transpose(1, 2).view(B, self.embed_dim, Hp, Wp, Dp)

class SwinTransformer(nn.Module):
    """Hierarchical 3-D Swin-style encoder built from BasicLayer/FAB stages.

    Stage i works at ``embed_dim * 2**i`` channels; every stage except the
    last ends with PatchMerging. ``forward`` returns, for each stage index in
    ``out_indices``, the normalised stage output as a channels-first
    (B, C_i, H_i, W_i, D_i) tensor. ``pretrain_img_size``, ``ape``, ``spe``
    and ``frozen_stages`` are accepted but not used.
    """
    def __init__(self,pretrain_img_size=224,patch_size=4,in_chans=3,embed_dim=96,
                 depths=[2,2,2,2],num_heads=[3,3,6,6],window_size=(7,7,7),mlp_ratio=4.,
                 qkv_bias=True,qk_scale=None,drop_rate=0.,attn_drop_rate=0.,drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm,ape=False,spe=False,rpe=True,patch_norm=True,
                 out_indices=(0,1,2,3),frozen_stages=-1,use_checkpoint=False,pat_merg_rf=2):
        super().__init__()
        embed_norm = norm_layer if patch_norm else None
        self.patch_embed = PatchEmbed(patch_size, in_chans, embed_dim, embed_norm)
        # Stochastic-depth rates grow linearly across all blocks of all stages.
        drop_rates = [r.item() for r in torch.linspace(0, drop_path_rate, sum(depths))]
        self.layers = nn.ModuleList()
        self.num_layers = len(depths)
        self.out_indices = out_indices
        start = 0
        for stage_idx in range(self.num_layers):
            depth = depths[stage_idx]
            is_last = stage_idx == self.num_layers - 1
            self.layers.append(BasicLayer(
                dim=int(embed_dim * 2 ** stage_idx),
                depth=depth,
                num_heads=num_heads[stage_idx],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                rpe=rpe,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=drop_rates[start:start + depth],
                norm_layer=norm_layer,
                downsample=None if is_last else PatchMerging,
                use_checkpoint=use_checkpoint,
                pat_merg_rf=pat_merg_rf,
            ))
            start += depth
        self.num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
        for out_idx in out_indices:
            self.add_module(f"norm{out_idx}", norm_layer(self.num_features[out_idx]))

    def forward(self, x):
        feats = self.patch_embed(x)
        batch, _, h, w, d = feats.shape
        tokens = feats.flatten(2).transpose(1, 2)
        outs = []
        for stage_idx, stage in enumerate(self.layers):
            # `tokens`/(h, w, d) advance to the downsampled grid for the next stage.
            x_out, H, W, D, tokens, h, w, d = stage(tokens, h, w, d)
            if stage_idx not in self.out_indices:
                continue
            normed = getattr(self, f"norm{stage_idx}")(x_out)
            grid = normed.view(batch, H, W, D, self.num_features[stage_idx])
            outs.append(grid.permute(0, 4, 1, 2, 3).contiguous())
        return outs

class PixelShuffle3d(nn.Module):
    """3-D sub-pixel shuffle: (B, C*s^3, D, H, W) -> (B, C, D*s, H*s, W*s)."""
    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, inp):
        s = self.scale
        batch, channels, depth, height, width = inp.size()
        out_channels = channels // (s ** 3)
        expanded = inp.view(batch, out_channels, s, s, s, depth, height, width)
        # place each sub-position axis right after the spatial axis it refines
        interleaved = expanded.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
        return interleaved.view(batch, out_channels, depth * s, height * s, width * s)

class ConvergeHead(nn.Module):
    """Learned upsampler: grouped Conv3d to in_dim * up_ratio^3 channels + 3-D pixel shuffle.

    The conv uses ``groups=in_dim``, so each input channel produces its own
    up_ratio^3 sub-voxel channels, which PixelShuffle3d then rearranges into
    an ``up_ratio``-times larger grid with ``in_dim`` channels.
    """
    def __init__(self, in_dim, up_ratio, kernel_size, padding):
        super().__init__()
        self.in_dim = in_dim
        self.up_ratio = up_ratio
        self.conv = nn.Conv3d(
            in_dim,
            (up_ratio ** 3) * in_dim,
            kernel_size,
            stride=1,
            padding=padding,
            dilation=1,
            groups=in_dim,
        )
        self.apply(self._init_weights)

    def forward(self, x):
        shuffle = PixelShuffle3d(self.up_ratio)
        return shuffle(self.conv(x))

    def _init_weights(self, m):
        # Near-zero conv/linear weights with zero bias; identity-like BatchNorm.
        if isinstance(m, (nn.Conv3d, nn.Linear)):
            nn.init.normal_(m.weight, std=0.001)
        elif isinstance(m, nn.BatchNorm3d):
            nn.init.constant_(m.weight, 1)
        else:
            return
        nn.init.constant_(m.bias, 0)

class Conv3dReLU(nn.Sequential):
    """Bias-free Conv3d -> (BatchNorm3d | InstanceNorm3d) -> LeakyReLU."""
    def __init__(self, in_channels, out_channels, kernel_size, padding=0, stride=1, use_batchnorm=True):
        norm_cls = nn.BatchNorm3d if use_batchnorm else nn.InstanceNorm3d
        super().__init__(
            nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False),
            norm_cls(out_channels),
            nn.LeakyReLU(inplace=True),
        )

class SR(nn.Module):
    """Decoder stage: 2x learned upsampling, optional skip concat, two conv blocks.

    ``skip`` (if given) is concatenated on the channel axis after upsampling,
    so it must have ``skip_channels`` channels and the upsampled spatial size.
    """
    def __init__(self, in_channels, out_channels, skip_channels=0, use_batchnorm=True):
        super().__init__()
        self.up = ConvergeHead(in_channels, 2, 3, 1)
        merged_channels = in_channels + skip_channels
        self.conv1 = Conv3dReLU(merged_channels, out_channels, 3, 1, use_batchnorm=use_batchnorm)
        self.conv2 = Conv3dReLU(out_channels, out_channels, 3, 1, use_batchnorm=use_batchnorm)

    def forward(self, x, skip=None):
        upsampled = self.up(x)
        merged = upsampled if skip is None else torch.cat([upsampled, skip], dim=1)
        return self.conv2(self.conv1(merged))

class RegistrationHead(nn.Sequential):
    """Single same-padded Conv3d mapping decoder features to a flow field.

    Weights are drawn from N(0, 1e-5) and the bias is zero, so the predicted
    flow starts out close to zero. ``upsampling`` is accepted but unused.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
        head = nn.Conv3d(in_channels, out_channels, kernel_size, padding=kernel_size // 2)
        weight_shape = head.weight.shape
        head.weight = nn.Parameter(torch.normal(0, 1e-5, size=weight_shape))
        head.bias = nn.Parameter(torch.zeros(head.bias.shape))
        super().__init__(head)

class SpatialTransformer(nn.Module):
    """Warp an image with a dense displacement field using ``grid_sample``.

    Args:
        size: spatial size of the flow/image grid, 2 or 3 entries.
        mode: interpolation mode passed to ``F.grid_sample``.
        align_corners: passed to ``F.grid_sample``. The coordinate
            normalisation below maps voxel 0 -> -1 and voxel n-1 -> +1, which
            is the ``align_corners=True`` convention. With True, a zero flow
            is an exact identity warp. The default stays False for
            compatibility with existing trained weights; in that case a zero
            flow resamples the image slightly.
    """
    def __init__(self, size, mode='bilinear', align_corners=False):
        super().__init__()
        self.mode = mode
        self.align_corners = align_corners
        # Identity sampling grid (1, ndim, *size) holding voxel coordinates.
        vectors = [torch.arange(0, s) for s in size]
        grids = torch.meshgrid(*vectors, indexing='ij')
        grid = torch.stack(grids)
        grid = torch.unsqueeze(grid, 0).float()
        self.register_buffer('grid', grid)

    def forward(self, src, flow):
        """Sample ``src`` at ``grid + flow``; flow is (B, ndim, *size) in voxels."""
        new_locs = self.grid + flow
        shape = flow.shape[2:]
        # Normalise voxel coordinates to [-1, 1] per axis for grid_sample.
        for i in range(len(shape)):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)
        # grid_sample wants coordinates last and ordered (x, y[, z]), i.e.
        # reversed relative to the tensor's axis order.
        if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)
            new_locs = new_locs[..., [1, 0]]
        elif len(shape) == 3:
            new_locs = new_locs.permute(0, 2, 3, 4, 1)
            new_locs = new_locs[..., [2, 1, 0]]
        return F.grid_sample(src, new_locs, align_corners=self.align_corners, mode=self.mode)

def get_UTSRMorph_config():
    """Return the ConfigDict used to build UTSRMorph.

    NOTE: this redefinition replaces any earlier ``get_UTSRMorph_config`` in
    the notebook; the values below are the ones actually used.
    """
    settings = (
        ("if_transskip", True),       # Transformer-feature skip connections
        ("if_convskip", True),        # convolutional (half-res) skip connection
        ("patch_size", 4),
        ("in_chans", 2),
        ("embed_dim", 96),
        ("depths", (2, 2, 2, 2)),
        ("num_heads", (4, 4, 4, 4)),
        ("window_size", (5, 6, 7)),
        ("mlp_ratio", 4),
        ("pat_merg_rf", 4),
        ("qkv_bias", False),
        ("drop_rate", 0),
        ("drop_path_rate", 0.3),
        ("ape", False),
        ("spe", False),
        ("rpe", True),
        ("patch_norm", True),
        ("use_checkpoint", False),
        ("out_indices", (0, 1, 2, 3)),
        ("reg_head_chan", 16),
        ("img_size", (160, 192, 224)),
    )
    config = ml_collections.ConfigDict()
    for key, value in settings:
        setattr(config, key, value)
    return config

class UTSRMorph(nn.Module):
    """UTSRMorph registration network: Swin/FAB encoder, SR decoder, warper.

    ``forward(x)`` takes an input with ``config.in_chans`` channels. Channel
    0 is the image that gets warped (presumably the moving image; confirm
    against the data pipeline). It returns ``(warped, flow)``, where flow has
    3 channels at the input resolution.
    """
    def __init__(self, config):
        super().__init__()
        self.if_convskip = config.if_convskip
        self.if_transskip = config.if_transskip
        embed_dim = config.embed_dim
        self.transformer = SwinTransformer(
            patch_size=config.patch_size,
            in_chans=config.in_chans,
            embed_dim=config.embed_dim,
            depths=config.depths,
            num_heads=config.num_heads,
            window_size=config.window_size,
            mlp_ratio=config.mlp_ratio,
            qkv_bias=config.qkv_bias,
            drop_rate=config.drop_rate,
            drop_path_rate=config.drop_path_rate,
            ape=config.ape,
            spe=config.spe,
            rpe=config.rpe,
            patch_norm=config.patch_norm,
            use_checkpoint=config.use_checkpoint,
            out_indices=config.out_indices,
            pat_merg_rf=config.pat_merg_rf
        )
        # Decoder: each SR stage upsamples 2x and optionally fuses a skip.
        self.up0 = SR(embed_dim*8, embed_dim*4, skip_channels=(embed_dim*4 if self.if_transskip else 0), use_batchnorm=False)
        self.up1 = SR(embed_dim*4, embed_dim*2, skip_channels=(embed_dim*2 if self.if_transskip else 0), use_batchnorm=False)
        self.up2 = SR(embed_dim*2, embed_dim, skip_channels=(embed_dim if self.if_transskip else 0), use_batchnorm=False)
        self.up3 = SR(embed_dim, config.reg_head_chan, skip_channels=(embed_dim//2 if self.if_convskip else 0), use_batchnorm=False)
        # Half-resolution conv skip built from the raw input, so its input
        # channel count must follow config.in_chans (the same value the
        # transformer's patch embedding uses).
        self.c1 = Conv3dReLU(config.in_chans, embed_dim//2, 3, 1, use_batchnorm=False)
        self.reg_head = RegistrationHead(config.reg_head_chan, 3, 3)
        self.spatial_trans = SpatialTransformer(config.img_size)
        self.avg_pool = nn.AvgPool3d(3, stride=2, padding=1)
        # Final 2x learned upsampling of the 3-channel flow.
        self.up = ConvergeHead(3, 2, 3, 1)

    def forward(self, x):
        source = x[:, 0:1, :, :, :]
        if self.if_convskip:
            x_s1 = self.avg_pool(x)
            f4 = self.c1(x_s1)
        else:
            f4 = None
        out_feats = self.transformer(x)
        # Encoder outputs are ordered fine -> coarse; skips use them coarse -> fine.
        if self.if_transskip:
            f1 = out_feats[-2]
            f2 = out_feats[-3]
            f3 = out_feats[-4]
        else:
            f1 = None; f2 = None; f3 = None
        x = self.up0(out_feats[-1], f1)
        x = self.up1(x, f2)
        x = self.up2(x, f3)
        x = self.up3(x, f4)
        flow = self.reg_head(x)
        flow = self.up(flow)
        out = self.spatial_trans(source, flow)
        return out, flow

################################################################################
# 8) TRAIN SCRIPT / INFERENCE SCRIPT
################################################################################

def train_UTSRMorph(device=None):
    """Build a UTSRMorph model ready for training (training loop not implemented).

    Args:
        device: target device. Defaults to CUDA when available, else CPU
            (previously CUDA was required unconditionally).

    Returns:
        The constructed model on ``device``.
    """
    print("Train script: you can implement your training logic here.")
    config = get_UTSRMorph_config()
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = UTSRMorph(config).to(device)
    # TODO: create dataset, loader, optimizer and the training loop.
    print("Model built. Ready to train...")
    return model

def infer_UTSRMorph(device=None):
    """Build a UTSRMorph model ready for inference (inference logic not implemented).

    Args:
        device: target device. Defaults to CUDA when available, else CPU
            (previously CUDA was required unconditionally).

    Returns:
        The constructed model on ``device``.
    """
    print("Inference script: implement your inference logic here.")
    config = get_UTSRMorph_config()
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = UTSRMorph(config).to(device)
    print("Model built. Ready to infer...")
    return model

################################################################################
# 9) MAIN (Optional)
################################################################################

if __name__ == "__main__":
    banner = "All code loaded in one place!"
    print(banner)
    # To exercise the model builders, call train_UTSRMorph() or infer_UTSRMorph() here.


All code loaded in one place!
