In [1]:
import torch
import monai

from monai.transforms import (
    Compose,
    NormalizeIntensityd,
    RandSpatialCropd,
    RandFlipd,
    RandScaleIntensityd,
    RandShiftIntensityd,
    Activations, AsDiscrete,Resized
)
from torch.utils.tensorboard import SummaryWriter

from dataset import HNTSDataset
from trainer import MedSegTrainer
from monai.metrics import DiceMetric

In [2]:
class Config:
    """Hyper-parameters and runtime settings shared by the cells below."""

    # --- runtime ---
    device = "cuda"
    cpu_cores = 8  # used as DataLoader num_workers
    seed = 0  # global RNG seed for reproducible crops / flips / weight init

    # --- optimisation schedule ---
    batch_size = 4
    epochs = 400  # also the T_max of the cosine LR schedule
    lr_init = 1e-4
    lr_min = 1e-10  # eta_min of the cosine LR schedule
    weight_decay = 1e-5

    # --- Dice smoothing terms passed to DiceCELoss (numerator / denominator) ---
    smooth_nr = 0
    smooth_dr = 1e-5


config = Config()
# Seed Python `random`, NumPy and torch before any random transform, DataLoader
# or model is created, so augmentation and weight initialisation are reproducible.
# Note: set_determinism also switches cuDNN to deterministic mode, which can make
# 3-D convolutions somewhat slower.
monai.utils.set_determinism(seed=config.seed)

In [3]:
# Keys that hold spatial volumes and must receive the same geometric augmentation.
spatial_keys = ["image", "mask"]

train_transforms = Compose(
    [
        # Geometry: one random 224x224x96 patch, then an independent 50% flip
        # along each of the three spatial axes.
        RandSpatialCropd(keys=spatial_keys, roi_size=[224, 224, 96], random_size=False),
        *[RandFlipd(keys=spatial_keys, prob=0.5, spatial_axis=axis) for axis in (0, 1, 2)],
        # Intensity (image only): normalise using non-zero voxels, then apply
        # a random scale and a random shift on every sample.
        NormalizeIntensityd(keys="image", nonzero=True),
        RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
    ]
)

train_dataset = HNTSDataset("data/train", transform=train_transforms)
train_loader = monai.data.DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    num_workers=config.cpu_cores,
    shuffle=True,
)

Loading dataset: 100%|██████████| 130/130 [01:10<00:00,  1.84it/s]


In [44]:
# model = monai.networks.nets.SwinUNETR(
#     in_channels=1,
#     out_channels=2,
#     img_size=(224, 224, 96),
#     spatial_dims=3,
#     use_checkpoint=False,
#     use_v2=True,
# ).to(config.device)

In [45]:
# model = monai.networks.nets.SegResNet(
#     blocks_down=[2, 4, 4, 4],
#     blocks_up=[2, 4, 4],
#     init_filters=16,
#     in_channels=1,
#     out_channels=2,
#     dropout_prob=0.2,
# ).to(config.device)

In [3]:
# UNETR (transformer encoder + convolutional decoder) for 1 input channel and
# 2 output channels.
unetr_kwargs = dict(
    in_channels=1,
    out_channels=2,
    img_size=(224, 224, 96),  # kept equal to the training crop and the inference ROI
    feature_size=16,  # base channel width of the decoder
    hidden_size=768,  # transformer embedding size (768 / 16 heads = 48 per head)
    num_heads=16,
    proj_type="conv",  # patch embedding via convolution
    norm_name="instance",
)
model = monai.networks.nets.UNETR(**unetr_kwargs).to(config.device)

In [47]:
# model = monai.networks.nets.UNet(
#     spatial_dims=3,
#     in_channels=1,
#     out_channels=2,
#     channels=(32, 64, 128, 256),
#     strides=(2, 2, 2),
#     num_res_units=2,
# ).to(config.device)

In [4]:
# --- optimisation ---
optimizer = torch.optim.Adam(
    model.parameters(), config.lr_init, weight_decay=config.weight_decay
)
# Cosine decay from lr_init to lr_min with T_max = config.epochs
# (assumes MedSegTrainer steps the scheduler once per epoch — TODO confirm).
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=config.epochs, eta_min=config.lr_min
)

# --- loss ---
# Per-channel weights. NOTE(review): their origin is not recorded here
# (presumably derived from label frequencies) — document how they were computed.
class_weights = torch.tensor([1.1698134, 0.8732383], device=config.device)
loss_function = monai.losses.DiceCELoss(
    to_onehot_y=False,  # masks are already stored one channel per label
    # Independent per-channel sigmoid (multi-label). NOTE(review): the previous
    # comment said "0 is background, 1 is label", but both channels are plotted
    # as labels further down and metrics use include_background=True, so they
    # look like two foreground labels — confirm. Also check whether the installed
    # MONAI DiceCELoss applies softmax cross-entropy (not BCE) to multi-channel
    # input even when sigmoid=True.
    sigmoid=True,
    squared_pred=True,
    smooth_nr=config.smooth_nr,
    smooth_dr=config.smooth_dr,
    weight=class_weights,
)

# --- experiment bookkeeping ---
experiment_name = "debug-UNETR"
writer = SummaryWriter(f"logs/writer/{experiment_name}")
trainer = MedSegTrainer(
    experiment_name=experiment_name,
    model=model,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    loss_f=loss_function,
    epochs=config.epochs,
    writer=writer,
    save_every=50,
)

In [5]:
# Resume from a checkpoint directory. Per the warning trace printed below, the
# trainer reads model.pth, optimizer.pth, lr_scheduler.pth and trainer_state.pth.
# NOTE(review): that trace shows `torch.load(".../lr_scheduler.pth")` with its
# result unused, so the LR scheduler state is probably NOT restored — verify
# MedSegTrainer.load_checkpoint. Its torch.load calls also trigger the
# `weights_only` FutureWarning.
checkpoint_dir = "logs/debug-UNETR/checkpoints/4950"
trainer.load_checkpoint(checkpoint_dir)

  self.model.load_state_dict(torch.load(f"{checkpoint_path}/model.pth"))
  self.optimizer.load_state_dict(torch.load(f"{checkpoint_path}/optimizer.pth"))
  torch.load(f"{checkpoint_path}/lr_scheduler.pth")
  trainer_state_dict: dict = torch.load(f"{checkpoint_path}/trainer_state.pth")


In [50]:
# Train with the epochs / optimizer / scheduler / loss / writer / save_every
# that were passed to MedSegTrainer when `trainer` was constructed.
# NOTE(review): the captured log below has execution count 50, while the
# checkpoint-loading cell has count 5, so the log comes from an earlier kernel
# session and may not reflect the loaded checkpoint. The run was interrupted at
# epoch 151.
# NOTE(review): "Loss epoch" is about 33 with 33 steps per epoch, which suggests
# the trainer sums per-batch losses rather than averaging them — check MedSegTrainer.
trainer.fit(train_loader=train_loader)

Epoch 1/400


Training step: 100%|██████████| 33/33 [00:40<00:00,  1.23s/it]


Loss epoch: 33.573
Epoch 2/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 33.473
Epoch 3/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 33.430
Epoch 4/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 33.390
Epoch 5/400


Training step: 100%|██████████| 33/33 [00:36<00:00,  1.11s/it]


Loss epoch: 33.365
Epoch 6/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.09s/it]


Loss epoch: 33.312
Epoch 7/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 33.286
Epoch 8/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 33.271
Epoch 9/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 33.246
Epoch 10/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 33.226
Epoch 11/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 33.178
Epoch 12/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 33.170
Epoch 13/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 33.183
Epoch 14/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 33.115
Epoch 15/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 33.101
Epoch 16/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 33.091
Epoch 17/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 33.092
Epoch 18/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 32.968
Epoch 19/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.09s/it]


Loss epoch: 33.017
Epoch 20/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 32.961
Epoch 21/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 32.925
Epoch 22/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 32.790
Epoch 23/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 32.769
Epoch 24/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 32.773
Epoch 25/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 32.711
Epoch 26/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 32.689
Epoch 27/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 32.594
Epoch 28/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 32.612
Epoch 29/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 32.571
Epoch 30/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 32.490
Epoch 31/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 32.435
Epoch 32/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 32.375
Epoch 33/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 32.361
Epoch 34/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 32.114
Epoch 35/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 32.072
Epoch 36/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 32.031
Epoch 37/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 31.909
Epoch 38/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.09s/it]


Loss epoch: 31.948
Epoch 39/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 31.857
Epoch 40/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 31.702
Epoch 41/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 31.627
Epoch 42/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 31.297
Epoch 43/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 31.100
Epoch 44/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 31.178
Epoch 45/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 31.144
Epoch 46/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 30.855
Epoch 47/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 30.712
Epoch 48/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 30.456
Epoch 49/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 30.242
Epoch 50/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 29.968
Epoch 51/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 29.711
Epoch 52/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 29.670
Epoch 53/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 29.451
Epoch 54/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 29.136
Epoch 55/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 29.082
Epoch 56/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 29.237
Epoch 57/400


Training step: 100%|██████████| 33/33 [00:36<00:00,  1.09s/it]


Loss epoch: 28.874
Epoch 58/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 29.022
Epoch 59/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 28.583
Epoch 60/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 28.324
Epoch 61/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 28.376
Epoch 62/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 27.925
Epoch 63/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 27.719
Epoch 64/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 27.554
Epoch 65/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 27.513
Epoch 66/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 26.741
Epoch 67/400


Training step: 100%|██████████| 33/33 [00:36<00:00,  1.12s/it]


Loss epoch: 27.023
Epoch 68/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 26.265
Epoch 69/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 27.026
Epoch 70/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 26.443
Epoch 71/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 26.160
Epoch 72/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 26.131
Epoch 73/400


Training step: 100%|██████████| 33/33 [00:36<00:00,  1.11s/it]


Loss epoch: 26.036
Epoch 74/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 25.508
Epoch 75/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 24.959
Epoch 76/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 24.833
Epoch 77/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 24.551
Epoch 78/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 24.496
Epoch 79/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 24.149
Epoch 80/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 23.553
Epoch 81/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 23.454
Epoch 82/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 23.929
Epoch 83/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 23.999
Epoch 84/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 23.487
Epoch 85/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 23.223
Epoch 86/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 23.038
Epoch 87/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 22.881
Epoch 88/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.09s/it]


Loss epoch: 22.558
Epoch 89/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 22.436
Epoch 90/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 22.458
Epoch 91/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 22.447
Epoch 92/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 22.491
Epoch 93/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 22.281
Epoch 94/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 21.833
Epoch 95/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 22.271
Epoch 96/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 21.716
Epoch 97/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 21.892
Epoch 98/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 21.160
Epoch 99/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 22.104
Epoch 100/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 22.291
Epoch 101/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 21.242
Epoch 102/400


Training step: 100%|██████████| 33/33 [00:36<00:00,  1.10s/it]


Loss epoch: 21.512
Epoch 103/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 20.906
Epoch 104/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 20.339
Epoch 105/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 20.321
Epoch 106/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 20.841
Epoch 107/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 20.087
Epoch 108/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 20.389
Epoch 109/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 20.448
Epoch 110/400


Training step: 100%|██████████| 33/33 [00:36<00:00,  1.10s/it]


Loss epoch: 20.876
Epoch 111/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 20.237
Epoch 112/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 20.695
Epoch 113/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 20.442
Epoch 114/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 19.869
Epoch 115/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 19.470
Epoch 116/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 19.802
Epoch 117/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 20.048
Epoch 118/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 20.002
Epoch 119/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 19.307
Epoch 120/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 19.513
Epoch 121/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 19.807
Epoch 122/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 19.597
Epoch 123/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 19.982
Epoch 124/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 19.877
Epoch 125/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 19.387
Epoch 126/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 19.002
Epoch 127/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 19.017
Epoch 128/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 19.503
Epoch 129/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 19.624
Epoch 130/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 19.699
Epoch 131/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 19.143
Epoch 132/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 19.277
Epoch 133/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 19.258
Epoch 134/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 19.477
Epoch 135/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 19.197
Epoch 136/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 18.986
Epoch 137/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 18.785
Epoch 138/400


Training step: 100%|██████████| 33/33 [00:36<00:00,  1.12s/it]


Loss epoch: 19.033
Epoch 139/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 19.354
Epoch 140/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 18.547
Epoch 141/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 18.195
Epoch 142/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 18.340
Epoch 143/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 18.166
Epoch 144/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 18.519
Epoch 145/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 17.763
Epoch 146/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 18.230
Epoch 147/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 19.102
Epoch 148/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 18.310
Epoch 149/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]


Loss epoch: 18.240
Epoch 150/400


Training step: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]


Loss epoch: 18.464
Epoch 151/400


Training step:  12%|█▏        | 4/33 [00:08<01:01,  2.13s/it]


KeyboardInterrupt: 

In [None]:
# model.load_state_dict(torch.load("logs/exp1-UNETR/model.pth"))

In [6]:
# No-op unless autocast is active; kept from the original cell.
torch.clear_autocast_cache()

# Use the same intensity normalisation as training (nonzero=True). Without it,
# the test volumes are normalised over all voxels, including the zero background,
# so the model sees a different intensity distribution than it was trained on.
test_dataset = HNTSDataset(
    "data/test", transform=NormalizeIntensityd(keys="image", nonzero=True)
)
test_loader = monai.data.DataLoader(
    test_dataset,
    batch_size=1,  # full, uncropped volumes (presumably of differing sizes)
    shuffle=False,
    num_workers=config.cpu_cores,
)

Loading dataset:   0%|          | 0/20 [00:00<?, ?it/s]

Loading dataset: 100%|██████████| 20/20 [00:13<00:00,  1.53it/s]


In [13]:
# Both output channels are scored. They appear to be two foreground labels
# (sigmoid multi-label), so channel 0 is not dropped as background.
include_bg = True
metrics_dict = {
    "IoU": monai.metrics.MeanIoU(include_background=include_bg, reduction="mean"),
    "Dice": monai.metrics.DiceMetric(include_background=include_bg, reduction="mean"),
    # Per-channel Dice: averaged over the batch only, one value per label.
    "DiceBatched": monai.metrics.DiceMetric(
        include_background=include_bg, reduction="mean_batch"
    ),
}

# NOTE(review): the recorded run failed with
# "TypeError: test() missing 1 required positional argument: 'self'", meaning
# `test` was not bound to an instance. Check how MedSegTrainer.test is declared
# (e.g. a stray @staticmethod), or whether `trainer` was rebound to the class.
score = trainer.test(test_loader=test_loader, metrics=metrics_dict)
print(f"Metrics: {score}")

TypeError: test() missing 1 required positional argument: 'self'

In [None]:
from monai.inferers import sliding_window_inference


def inference(model, input_, roi_size=(224, 224, 96), overlap=0.2):
    """Run sliding-window inference and binarise the per-channel sigmoid output.

    Args:
        model: segmentation network, called on ``roi_size`` windows by
            ``sliding_window_inference``.
        input_: batched input tensor (presumably ``(B, C, H, W, D)`` — TODO confirm);
            moved to ``config.device`` before inference.
        roi_size: sliding-window size. The default matches the UNETR ``img_size``
            and the training crop used in this notebook.
        overlap: fractional overlap between neighbouring windows.

    Returns:
        The network output after ``sigmoid`` and a 0.5 threshold, i.e. 0/1 per channel.
    """
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

    # Without no_grad, every window's activations would be kept for autograd.
    # eval() disables training-only behaviour such as dropout. The caller's mode
    # is restored afterwards so a later trainer.fit() is unaffected.
    was_training = model.training
    model.eval()
    try:
        # with torch.amp.autocast('cuda'):  # optional mixed precision, currently disabled
        with torch.no_grad():
            logits = sliding_window_inference(
                inputs=input_.to(config.device),
                roi_size=roi_size,
                sw_batch_size=1,
                predictor=model,
                overlap=overlap,
            )
    finally:
        model.train(was_training)
    return post_trans(logits)


In [11]:
from matplotlib import pyplot as plt
import ipywidgets as widgets
from ipywidgets import interact

sample_idx = 14
sample = test_dataset[sample_idx]
image = sample['image']
mask_pred = inference(model, image.unsqueeze(0).to(config.device))[0].cpu()

# Re-orient everything to SPL, so that dim 1 (the S axis) indexes axial slices.
image = monai.transforms.Orientation(axcodes="SPL")(sample['image'])
mask_target = monai.transforms.Orientation(axcodes="SPL")(sample['mask'])
mask_pred = monai.transforms.Orientation(axcodes="SPL")(mask_pred)


def _label_overlay(mask_slice):
    """Return a float copy of a binary mask slice with 0-pixels set to NaN.

    imshow draws NaN ("bad") pixels in the colormap's bad colour, which is
    transparent by default. Only labelled pixels are therefore tinted, instead of
    the whole slice being washed over by the colormap's 0-colour.
    """
    mask_slice = torch.as_tensor(mask_slice, dtype=torch.float32)
    return torch.where(mask_slice > 0, mask_slice, torch.full_like(mask_slice, float("nan")))


def plot_slice(slice_idx):
    """Show slice ``slice_idx`` with predicted (left) and target (right) label overlays."""
    image_slice = image[0, slice_idx]
    label0 = [mask_pred[0, slice_idx], mask_target[0, slice_idx]]
    label1 = [mask_pred[1, slice_idx], mask_target[1, slice_idx]]
    title = ['Prediction', 'Target']

    fig, axs = plt.subplots(1, 2, figsize=(16, 8))
    for i in range(2):
        axs[i].imshow(image_slice, cmap="gray", alpha=1.0)
        # Fixed vmin/vmax, so label value 1 maps to the top of each colormap even
        # when the non-NaN pixels are all 1 (autoscaling would collapse them to 0).
        axs[i].imshow(_label_overlay(label0[i]), cmap="Reds", alpha=0.3, vmin=0, vmax=1)
        axs[i].imshow(_label_overlay(label1[i]), cmap="plasma", alpha=0.3, vmin=0, vmax=1)
        axs[i].set_title(title[i])
        axs[i].axis('off')
    plt.tight_layout()
    plt.show()


slider = widgets.IntSlider(value=mask_pred.shape[1]//2, min=0, max=mask_pred.shape[1]-1, step=1, description='Slice Index')
interact(plot_slice, slice_idx=slider);

interactive(children=(IntSlider(value=74, description='Slice Index', max=148), Output()), _dom_classes=('widge…

In [None]:
image = sample["image"]
mask_target = sample["mask"]
mask_pred = inference(trainer.model, image.unsqueeze(0))[0].cpu()

# Re-orient to SPL so that dim 1 indexes axial slices.
image = monai.transforms.Orientation(axcodes="SPL")(image)
mask_target = monai.transforms.Orientation(axcodes="SPL")(mask_target)
mask_pred = monai.transforms.Orientation(axcodes="SPL")(mask_pred)


def _transparent_zeros(mask_slice):
    """Float copy of a binary mask slice with 0-pixels set to NaN.

    NaN pixels use the colormap's (by default transparent) "bad" colour, so the
    overlay tints only labelled pixels instead of the whole slice.
    """
    mask_slice = torch.as_tensor(mask_slice, dtype=torch.float32)
    return torch.where(mask_slice > 0, mask_slice, torch.full_like(mask_slice, float("nan")))


title = ["Prediction", "Target"]
for slice_idx in range(mask_target.shape[1]):
    image_slice = image[0, slice_idx]
    label0 = [mask_pred[0, slice_idx], mask_target[0, slice_idx]]
    label1 = [mask_pred[1, slice_idx], mask_target[1, slice_idx]]
    fig, axs = plt.subplots(1, 2, figsize=(16, 8))
    for i in range(2):
        axs[i].imshow(image_slice, cmap="gray", alpha=1.0)
        # Fixed vmin/vmax keep label value 1 at the top of each colormap.
        axs[i].imshow(_transparent_zeros(label0[i]), cmap="Reds", alpha=0.3, vmin=0, vmax=1)
        axs[i].imshow(_transparent_zeros(label1[i]), cmap="plasma", alpha=0.3, vmin=0, vmax=1)
        axs[i].set_title(title[i])
        axs[i].axis("off")
    plt.tight_layout()
    # SummaryWriter.add_figure closes `fig` by default (close=True).
    writer.add_figure("prediction_vs_target", fig, global_step=slice_idx)