In [1]:
import torch
import torchinfo
from models.lightning_models import Unet3D
from modules.segmentation_model import Unet, PositionalUnet

In [2]:
# Baseline (non-positional) U-Net inside the Unet3D wrapper: BCE loss on a
# sigmoid output, 64^3 patches with (32, 32, 32) strides and "same" padding.
model = Unet3D(
    model=Unet,
    loss=torch.nn.BCELoss(),
    patch_size=(64, 64, 64),
    strides=(32, 32, 32),
    padding="same",
    final_activation=torch.nn.Sigmoid(),
)

In [3]:
# Same wrapper settings as the baseline `model`, but built around PositionalUnet
# with the wrapper's `positional` flag switched on.
pos_model = Unet3D(
    model=PositionalUnet, positional=True,
    loss=torch.nn.BCELoss(), final_activation=torch.nn.Sigmoid(),
    patch_size=(64, 64, 64), strides=(32, 32, 32), padding="same",
)

In [4]:
# Uniform-random dummy volume in [0, 1); layout presumably (batch, channel, D, H, W)
# to match the 3D U-Nets above — confirm against Unet3D.predict_step.
volume_shape = (1, 1, 128, 128, 128)
test_tensor = torch.rand(*volume_shape)

In [5]:
# pred_mod = model.predict_step(test_tensor, patch_size=(64, 64, 64), strides=(32, 32, 32), padding="same", unpad=True, verbose=False)

In [11]:
# Patch-wise inference with the positional model; the patch_size / strides /
# padding values mirror the ones given to the Unet3D constructor.
# verbose=True prints the elapsed prediction time.
pred = pos_model.predict_step(
    test_tensor,
    patch_size=(64, 64, 64),
    strides=(32, 32, 32),
    padding="same",
    unpad=True,
    verbose=True,
    positional=True,
)

prediction took 3.78 seconds


In [7]:
# Smoke-test a single Lightning training step outside of a Trainer (batch_idx=0).
# The batch tuple is presumably (input, target, positional input); the random
# volume doubles as its own target here. The `self.log()` warning is expected
# because the module is not attached to a Trainer.
positions = torch.rand((1, 3, 1))
pos_model.training_step((test_tensor, test_tensor, positions), 0)

/home/romanuccio/RomanuccioDiff/.venv/lib/python3.10/site-packages/pytorch_lightning/core/module.py:441: You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet. This is most likely because the model hasn't been passed to the `Trainer`


tensor(100.0452, grad_fn=<MulBackward0>)

In [8]:
# Standalone PositionalUnet (no Lightning wrapper), used below for a layer summary.
# NOTE: this rebinds `model`, replacing the baseline Unet3D wrapper created earlier.
model = PositionalUnet(
    in_channels=1,
    positional_channels=3,  # presumably must match dim 1 of the positional input (1, 3, 1)
    classes=1,
    depths=[2, 2, 2, 2],
    channel_multipliers=[1, 2, 4, 8],
    embed_dim=48,
    positional_embed_dim=32,
    final_activation=torch.nn.Sigmoid(),
)
# Fall back to CPU so the notebook still runs end-to-end on machines without CUDA.
model = model.to("cuda" if torch.cuda.is_available() else "cpu")


In [9]:
# Dummy inputs for the summary, allocated directly on the model's device so they
# always match it (no hard-coded "cuda", no CPU allocation followed by a copy).
# Shapes mirror the Lightning smoke test: a (1, 1, 128, 128, 128) volume and a
# (1, 3, 1) positional tensor.
model_device = next(model.parameters()).device
x = torch.rand((1, 1, 128, 128, 128), device=model_device)
y = torch.rand((1, 3, 1), device=model_device)

In [10]:
# Layer-by-layer summary (nesting depth up to 5) for the standalone PositionalUnet.
# The device follows the inputs instead of a hard-coded "cuda", so the cell also
# works when the model was placed on CPU.
torchinfo.summary(model, input_data=(x, y), device=x.device, depth=5)

Layer (type:depth-idx)                             Output Shape              Param #
PositionalUnet                                     [1, 1, 128, 128, 128]     --
├─Conv3d: 1-1                                      [1, 48, 64, 64, 64]       432
├─SinusoidalEmbedding: 1-2                         [1, 3, 32]                --
├─ModuleList: 1-3                                  --                        --
│    └─ModuleList: 2-1                             --                        --
│    │    └─ResidualLayer: 3-1                     [1, 48, 64, 64, 64]       --
│    │    │    └─ModuleList: 4-1                   --                        --
│    │    │    │    └─ResidualBlock: 5-1           [1, 48, 64, 64, 64]       5,256
│    │    │    │    └─ResidualBlock: 5-2           [1, 48, 64, 64, 64]       5,256
│    │    └─Downsample: 3-2                        [1, 96, 32, 32, 32]       --
│    │    │    └─Conv3d: 4-2                       [1, 96, 32, 32, 32]       36,960
│    └─ModuleList: 2-2  