In [4]:
import torch
from pytorch_model_summary import summary
from fvcore.nn import FlopCountAnalysis
from network.dsunetr import DS_UNETR

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


DS-UNETR Model

In [6]:
model = DS_UNETR(in_channels=1,
                out_channels=8,
                img_size=[64, 128, 128],
                feature_size=16,
                num_heads=[4,4,4,4],
                depths=[3, 3, 3, 3],
                dims=[32, 64, 128, 256],
                do_ds=False,
                ).to(device)

In [7]:
# forward
x = torch.ones([1, 1, 64, 128, 128]).cuda()
y = model(x)

# summary
print(summary(model, x))
print('input:', x.shape)
print('output:', y.shape)

----------------------------------------------------------------------------------------------------------------------------------------------------
           Layer (type)                                                                                Output Shape         Param #     Tr. Param #
       SpatialEncoder-1     [1, 64, 256], [1, 32, 32, 32, 32], [1, 64, 16, 16, 16], [1, 128, 8, 8, 8], [1, 64, 256]       6,683,296       6,683,296
       ChannelEncoder-2     [1, 64, 256], [1, 32, 32, 32, 32], [1, 64, 16, 16, 16], [1, 128, 8, 8, 8], [1, 64, 256]       7,671,856       7,671,856
         UnetResBlock-3                                                                       [1, 16, 64, 128, 128]           7,360           7,360
           ConcatConv-4                                                                         [1, 32, 32, 32, 32]         111,904         111,904
           ConcatConv-5                                                                         [1, 64, 16, 16,

In [8]:
# n_parameters & flops
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
flops = FlopCountAnalysis(model, x)
model_flops = flops.total()

print(f"Total trainable parameters: {round(n_parameters * 1e-6, 2)} M")
print(f"MAdds: {round(model_flops * 1e-9, 2)} G")

Unsupported operator aten::mul encountered 543 time(s)
Unsupported operator aten::fill_ encountered 486 time(s)
Unsupported operator aten::sub encountered 18 time(s)
Unsupported operator aten::ne encountered 18 time(s)
Unsupported operator aten::clone encountered 42 time(s)
Unsupported operator aten::add encountered 226 time(s)
Unsupported operator aten::softmax encountered 54 time(s)
Unsupported operator aten::gelu encountered 54 time(s)
Unsupported operator aten::mul_ encountered 13 time(s)
Unsupported operator aten::leaky_relu_ encountered 20 time(s)
Unsupported operator aten::add_ encountered 44 time(s)
Unsupported operator aten::sum encountered 18 time(s)
Unsupported operator aten::div encountered 18 time(s)
The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will stil

Total trainable parameters: 29.0 M
MAdds: 44.99 G


In [9]:
# parameters
print(model.parameters)

<bound method Module.parameters of DS_UNETR(
  (spatial_encoder): SpatialEncoder(
    (downsample_layers): ModuleList(
      (0): Sequential(
        (0): Convolution(
          (conv): Conv3d(1, 32, kernel_size=(2, 4, 4), stride=(2, 4, 4), bias=False)
        )
        (1): GroupNorm(1, 32, eps=1e-05, affine=True)
      )
      (1): Sequential(
        (0): PatchMerging3D(
          (reduction): Linear(in_features=256, out_features=64, bias=False)
        )
        (1): GroupNorm(32, 64, eps=1e-05, affine=True)
      )
      (2): Sequential(
        (0): PatchMerging3D(
          (reduction): Linear(in_features=512, out_features=128, bias=False)
        )
        (1): GroupNorm(64, 128, eps=1e-05, affine=True)
      )
      (3): Sequential(
        (0): PatchMerging3D(
          (reduction): Linear(in_features=1024, out_features=256, bias=False)
        )
        (1): GroupNorm(128, 256, eps=1e-05, affine=True)
      )
    )
    (stages): ModuleList(
      (0): Sequential(
        (0)

Verify training feasibility

In [11]:
# Create a dummy input tensor
dummy_input = torch.randn(1, 1, 64, 128, 128).to(device)  # Batch size of 1, in_channels of 1, and image size of 64x128x128

# Set the model to training mode
model.train()

# Define an optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Perform a forward pass
output = model(dummy_input)

# Create a dummy target tensor with the same shape as the output
dummy_target = torch.randn_like(output)

# Define a loss function
criterion = torch.nn.MSELoss()

# Compute the loss
loss = criterion(output, dummy_target)

# Perform a backward pass to check if gradients can be computed
try:
    loss.backward()
    print("Gradients computed successfully.")

    # Update the model parameters
    optimizer.step()
    print("Model parameters updated successfully.")
except RuntimeError as e:
    print(f"An error occurred: {e}")

Gradients computed successfully.
Model parameters updated successfully.
