In [1]:
import torch

from models_torch.utils_3d import ResizeLayer3D, DropPath3D

In [2]:
def test_resize_layer():
    # Define the target size
    target_depth = 64
    target_height = 64
    target_width = 64
    
    # Create an instance of ResizeLayer
    resize_layer = ResizeLayer3D(target_depth, target_height, target_width)
    
    # Create a sample input tensor with a different size
    input_tensor = torch.randn(1, 3, 32, 32, 32)  # Batch size of 1, 3 channels, 32x32 image
    
    # Perform the resizing operation
    output_tensor = resize_layer(input_tensor)
    
    # Check the output size
    assert output_tensor.shape == (1, 3, target_depth, target_height, target_width), "Output shape should match the target size."
    print(output_tensor.shape)
    
    print("ResizeLayer test passed.")

# Run the test
test_resize_layer()

torch.Size([1, 3, 64, 64, 64])
ResizeLayer test passed.


In [3]:
def test_drop_path():
    drop_path_rate = 0.2
    module = DropPath3D(drop_path_rate)
    x = torch.ones(5, 3, 32, 32, 32)  # Example input tensor

    # Test during training
    output_training = module(x)
    
    zero_elements = 0
    while zero_elements == 0:
        # Test during training
        output_training = module(x)
        zero_elements = (output_training == 0).sum().item()
        
    print(f"Number of zero elements during training: {zero_elements}")
 
    assert (output_training == 0).any().item(), "Some elements should be zeroed out during training."
    assert output_training.shape == x.shape, "Output shape should match input shape during training."
    print(output_training.shape)

    print("All tests passed.")

# Run the test
test_drop_path()

Number of zero elements during training: 97538
torch.Size([5, 3, 32, 32, 32])
All tests passed.


In [4]:
import torch

from models_torch.attention_3d import Attention3D

In [5]:
def test_attention():
    # Define parameters
    dim = 64
    num_heads = 8
    sr_ratio = 1
    batch_size = 2
    depth = 8
    height = 8
    width = 8

    # Create an instance of Attention
    attention_layer = Attention3D(dim, num_heads, sr_ratio)

    # Create a sample input tensor
    input_tensor = torch.randn(batch_size, depth * height * width, dim)  # B, N, C

    # Perform the attention operation
    output_tensor = attention_layer(input_tensor, depth, height, width)

    # Check the output size
    assert output_tensor.shape == (batch_size, depth * height * width, dim), "Output shape should match the input shape."
    print(output_tensor.shape)

    print("Attention test passed.")

# Run the test
test_attention()

torch.Size([2, 512, 64])
Attention test passed.


In [6]:
import torch

from models_torch.head_3d import MLP3D, ConvModule3D, SegFormerHead3D

In [7]:
def test_mlp():
    # Define the dimension
    batch_size = 4
    input_dim = 3
    decode_dim = 2
    additional_dims = (5, 5)  # Example additional dimensions
    
    # Create an instance of MLP
    mlp_layer = MLP3D(input_dim, decode_dim)
    
    # Create a sample input tensor
    input_tensor = torch.randn(batch_size, *additional_dims, input_dim)  # B, decode_dim
    print("shape of input: ", input_tensor.shape)

    
    # Perform the MLP operation
    output_tensor = mlp_layer(input_tensor)
    
    # Check the output size
    expected_shape = (batch_size, *additional_dims, decode_dim)
    assert output_tensor.shape == expected_shape, "Output shape should match the input shape."
    print(output_tensor.shape)
    
    print("MLP test passed.")

# Run the test
test_mlp()

shape of input:  torch.Size([4, 5, 5, 3])
torch.Size([4, 5, 5, 2])
MLP test passed.


In [8]:
def test_conv_module():
    """Smoke-test ConvModule3D: channels go from 4*64 to 64 while the 32x32x32 extent is kept.

    Raises:
        AssertionError: if the output is not (batch, out_channels, depth, height, width).
    """
    in_channels = 4 * 64
    out_channels = 64
    module = ConvModule3D(in_channels, out_channels)

    batch_size = 4
    depth = height = width = 32
    volume = torch.randn(batch_size, in_channels, depth, height, width)  # (B, C, D, H, W)

    result = module(volume)

    expected = (batch_size, out_channels, depth, height, width)
    assert result.shape == expected, "Output shape should match the input shape in training mode."
    print(result.shape)

    print("ConvModule test passed.")
# Run the test
test_conv_module()

torch.Size([4, 64, 32, 32, 32])
ConvModule test passed.


In [9]:
def test_segformer_head():
    """Smoke-test SegFormerHead3D on four feature maps of 64/128/256/512 channels.

    All four inputs share the same 32x32x32 spatial extent here; the head must
    return one map with ``num_classes`` channels at that same extent.

    Raises:
        AssertionError: if the output is not (batch, num_classes, 32, 32, 32).
    """
    input_dims = [64, 128, 256, 512]
    decode_dim, num_classes = 768, 19
    batch_size = 2
    spatial = (32, 32, 32)  # (depth, height, width)

    head = SegFormerHead3D(input_dims, decode_dim, num_classes)

    # One random feature map per entry of input_dims, in the same order.
    feature_maps = []
    for channels in input_dims:
        feature_maps.append(torch.randn(batch_size, channels, *spatial))

    logits = head(feature_maps)

    expected = (batch_size, num_classes) + spatial
    assert logits.shape == expected, "Output shape should match the expected shape in training mode."
    print(logits.shape)

    print("SegFormerHead test passed.")

# Run the test
test_segformer_head()

torch.Size([2, 19, 32, 32, 32])
SegFormerHead test passed.


In [10]:
import torch
import torch.nn as nn

from models_torch.modules_3d import DWConv3D, Mlp3D, Block3D, OverlapPatchEmbed3D, MixVisionTransformer3D

In [11]:
def test_dwconv():
    """Smoke-test DWConv3D on flattened tokens: the (batch, tokens, features) shape must be kept.

    Raises:
        AssertionError: if the output shape differs from the input shape.
    """
    hidden_features = 768
    layer = DWConv3D(hidden_features)

    batch_size = 2
    depth = height = width = 32
    num_tokens = depth * height * width
    tokens = torch.randn(batch_size, num_tokens, hidden_features)  # (B, D*H*W, C)

    # The grid sizes are passed alongside the flattened tokens.
    output = layer(tokens, depth, height, width)
    print("Output shape:", output.shape)

    expected = (batch_size, num_tokens, hidden_features)
    assert output.shape == expected, "Output shape mismatch"
    print("Test passed successfully!")
    
# Run the test function
test_dwconv()

Output shape: torch.Size([2, 32768, 768])
Test passed successfully!


In [12]:
def test_mlp():
    # Initialize the module
    in_features = 256
    hidden_features = 128
    out_features = 256
    drop_rate = 0.1
    mlp = Mlp3D(in_features, hidden_features, out_features, drop_rate)
    
    # Create mock input data
    batch_size = 1
    depth, height, width = 32, 32, 32
    input_tensor = torch.randn(batch_size, depth * height * width, in_features)
    
    # Test forward pass in training mode
    output_training = mlp(input_tensor, depth, height, width)
    print("Output shape in training mode:", output_training.shape)
    
    # Check output shape
    assert output_training.shape == (batch_size, depth * height * width, out_features), "Output shape mismatch in training mode"
    print("Tests passed successfully!")
    
# Run the test function
test_mlp()

Output shape in training mode: torch.Size([1, 32768, 256])
Tests passed successfully!


In [13]:
def test_block():
    # Initialize the module
    dim = 64
    num_heads = 8
    mlp_ratio = 4.0
    qkv_bias = True
    drop = 0.1
    attn_drop = 0.1
    drop_path = 0.1
    sr_ratio = 1.0
    block = Block3D(dim, num_heads, mlp_ratio, qkv_bias, drop, attn_drop, drop_path, sr_ratio)

    # Create mock input data
    batch_size = 1
    depth, height, width = 8, 8, 8
    input_tensor = torch.randn(batch_size, depth * height * width, dim)

    # Test forward pass in training mode
    output_training = block(input_tensor, depth, height, width)
    print("Output shape in training mode:", output_training.shape)

    # Check output shape
    assert output_training.shape == (batch_size, depth * height * width, dim), "Output shape mismatch in training mode"

    print("Tests passed successfully!")

# Run the test function
test_block()

Output shape in training mode: torch.Size([1, 512, 64])
Tests passed successfully!


In [14]:
def test_overlap_patch_embed():
    """Check OverlapPatchEmbed3D's token tensor and the depth/height/width it reports.

    A cubic 224x224x224, 3-channel volume is embedded with patch size 7 and
    stride 4. The module returns the flattened tokens together with D, H and W;
    the token shape and all three grid sizes are verified (depth was previously
    printed but never checked).

    Raises:
        AssertionError: if the token shape, or any of D, H, W, differs from the
            expected value.
    """
    img_size = 224
    img_channels = 3
    patch_size = 7
    stride = 4
    filters = 768
    overlap_patch_embed = OverlapPatchEmbed3D(img_size, img_channels, patch_size, stride, filters)

    batch_size = 1
    # Cubic volume: the same img_size is used for depth, height and width.
    input_tensor = torch.randn(batch_size, img_channels, img_size, img_size, img_size)

    output, D, H, W = overlap_patch_embed(input_tensor)
    print("Output shape:", output.shape)
    print("Depth:", D)
    print("Height:", H)
    print("Width:", W)

    # Convolution output-size formula with padding patch_size // 2 on each side.
    # NOTE(review): the padding value is assumed, not visible here — it matches the
    # observed 56; confirm against OverlapPatchEmbed3D.
    expected_H = (img_size + patch_size // 2 * 2 - patch_size) // stride + 1
    expected_W = expected_H  # cubic input: every axis has the same extent
    expected_D = expected_H

    assert output.shape == (batch_size, expected_D * expected_H * expected_W, filters), "Output shape mismatch"
    assert D == expected_D, "Depth mismatch"
    assert H == expected_H, "Height mismatch"
    assert W == expected_W, "Width mismatch"
    print("Tests passed successfully!")
    
# Run the test function
test_overlap_patch_embed()

Output shape: torch.Size([1, 175616, 768])
Depth: 56
Height: 56
Width: 56
Tests passed successfully!


In [15]:
def test_mix_vision_transformer():
    """Check MixVisionTransformer3D's default construction and the shapes of its four outputs.

    First verifies that the default model is an ``nn.Module`` exposing four
    patch embeddings, four block groups and four norms. Then runs a
    1x3x224x224x224 volume through it and checks that the result is a list of
    four feature maps with the expected channel counts and spatial sizes.

    Raises:
        AssertionError: on any structural or shape mismatch.
    """
    model = MixVisionTransformer3D()

    # --- construction checks ---
    assert isinstance(model, nn.Module), "Model is not an instance of nn.Module"
    stage_checks = (
        ("patch_embeds", "Incorrect number of patch embeddings"),
        ("blocks", "Incorrect number of blocks"),
        ("norms", "Incorrect number of norms"),
    )
    for attribute, failure_message in stage_checks:
        assert len(getattr(model, attribute)) == 4, failure_message

    print("Initialization test passed.")

    # --- forward pass ---
    volume = torch.randn(1, 3, 224, 224, 224)  # (batch, channels, depth, height, width)
    feature_maps = model(volume)

    assert isinstance(feature_maps, list), "Output is not a list"
    assert len(feature_maps) == 4, "Output does not have 4 feature maps"

    expected_shapes = [
        (1, 64, 56, 56, 56),
        (1, 128, 28, 28, 28),
        (1, 256, 14, 14, 14),
        (1, 512, 7, 7, 7),
    ]
    for index, wanted in enumerate(expected_shapes):
        feature_map = feature_maps[index]
        print(feature_map.shape)
        assert feature_map.shape == wanted, f"Feature map shape {feature_map.shape} does not match expected {wanted}"

    print("Forward pass test passed.")
    
    
test_mix_vision_transformer()

Initialization test passed.
torch.Size([1, 64, 56, 56, 56])
torch.Size([1, 128, 28, 28, 28])
torch.Size([1, 256, 14, 14, 14])
torch.Size([1, 512, 7, 7, 7])
Forward pass test passed.


In [16]:
import torch
import torch.nn as nn

from models_torch.segformer_3d import SegFormer3D, SegFormer3D_SHViT

In [17]:
def test_segformer_b0():
    """Check that SegFormer3D "B0" with use_resize=True returns logits at the input resolution.

    The expected shape is (1, num_classes, depth, height, width), with each
    spatial size taken from its own entry of ``input_shape``. The earlier
    version reused the depth entry for all three axes, which is only correct
    for cubic volumes.

    Raises:
        AssertionError: if the output shape differs from the expected shape.
    """
    input_shape = (3, 224, 224, 224)  # (channels, depth, height, width)
    num_classes = 10
    model = SegFormer3D(model_type="B0", input_shape=input_shape, num_classes=num_classes, use_resize=True)

    dummy_input = torch.rand(1, *input_shape)  # batch size of 1

    output = model(dummy_input)

    # input_shape[1:] is (depth, height, width); use each axis as-is instead of depth three times.
    expected_output_shape = (1, num_classes, *input_shape[1:])
    print(output.shape, expected_output_shape)
    assert output.shape == expected_output_shape, f"Expected output shape {expected_output_shape}, but got {output.shape}"

    print("Test passed!")

# Run the test
test_segformer_b0()

torch.Size([1, 10, 224, 224, 224]) (1, 10, 224, 224, 224)
Test passed!


In [18]:
def test_segformer_shvit_b0_s4(use_resize=True):
    """Check the output shape of SegFormer3D_SHViT ("B0" / "S4", three stages).

    The resize flag is a parameter so one definition covers both cases: with
    ``use_resize=True`` the output is expected at the input resolution, and
    with ``use_resize=False`` each spatial size is expected to be the input
    size floor-divided by 16.

    Args:
        use_resize: forwarded to the model's ``use_resize`` argument; defaults
            to True, the value this cell originally hard-coded.

    Raises:
        AssertionError: if the output shape differs from the expected shape.
    """
    input_shape = (1, 272, 272, 272)  # (channels, depth, height, width)
    num_classes = 10
    model = SegFormer3D_SHViT(model_type="B0", shvit_type="S4", input_shape=input_shape, num_stages=3, num_classes=num_classes, use_resize=use_resize)

    dummy_input = torch.rand(1, *input_shape)  # batch size of 1

    output = model(dummy_input)

    # NOTE(review): 16 is presumably the model's overall downsampling in this
    # configuration — confirm against SegFormer3D_SHViT.
    factor = 1 if use_resize else 16
    expected_output_shape = (1, num_classes, *(size // factor for size in input_shape[1:]))
    print(output.shape, expected_output_shape)
    assert output.shape == expected_output_shape, f"Expected output shape {expected_output_shape}, but got {output.shape}"

    print("Test passed!")

# Run the test
test_segformer_shvit_b0_s4()

torch.Size([1, 10, 272, 272, 272]) (1, 10, 272, 272, 272)
Test passed!


In [19]:
def test_segformer_shvit_b0_s4(use_resize=False):
    """Check the output shape of SegFormer3D_SHViT ("B0" / "S4", three stages).

    The resize flag is a parameter so one definition covers both cases: with
    ``use_resize=False`` each spatial size is expected to be the input size
    floor-divided by 16, and with ``use_resize=True`` the output is expected
    at the input resolution.

    Args:
        use_resize: forwarded to the model's ``use_resize`` argument; defaults
            to False, the value this cell originally hard-coded.

    Raises:
        AssertionError: if the output shape differs from the expected shape.
    """
    input_shape = (1, 272, 272, 272)  # (channels, depth, height, width)
    num_classes = 10
    model = SegFormer3D_SHViT(model_type="B0", shvit_type="S4", input_shape=input_shape, num_stages=3, num_classes=num_classes, use_resize=use_resize)

    dummy_input = torch.rand(1, *input_shape)  # batch size of 1

    output = model(dummy_input)

    # NOTE(review): 16 is presumably the model's overall downsampling in this
    # configuration — confirm against SegFormer3D_SHViT.
    factor = 1 if use_resize else 16
    expected_output_shape = (1, num_classes, *(size // factor for size in input_shape[1:]))
    print(output.shape, expected_output_shape)
    assert output.shape == expected_output_shape, f"Expected output shape {expected_output_shape}, but got {output.shape}"

    print("Test passed!")

# Run the test
test_segformer_shvit_b0_s4()

torch.Size([1, 10, 17, 17, 17]) (1, 10, 17, 17, 17)
Test passed!
