In [1]:
import data_setup, engine, u_net, utils
from pathlib import Path
from monai import losses
import torch

## 1. Create the train and test_dataloader

In [2]:
in_dir = (Path.cwd().parent) / 'data'
train_dataloader, test_dataloader =data_setup.prepare_train_eval_data(in_dir = in_dir,
                                                      a_max = 300)



### 1.1 Check the length of the dataloader

In [3]:
print(f'Length of the Train dataloader: {len(train_dataloader)}\nLength of the Test dataloader: {len(test_dataloader)}')

Length of the Train dataloader: 6
Length of the Test dataloader: 3


## 2. Creating a model, the loss function and the optimizer

In [4]:
num_classes = utils.number_of_classes(in_dir = in_dir / 'train_segmentations')
model, device = u_net.unet(num_classes = num_classes)
loss_fn = losses.DiceLoss(to_onehot_y = True, sigmoid = True)
optimizer = torch.optim.Adam(params = model.parameters(),
                             lr = 0.001)
target_dir = (Path.cwd().parent) / 'models'
target_dir.is_dir()

[INFO] Number of classes: 8
Layer (type (var_name))                                                               Input Shape               Output Shape              Param #                   Trainable
UNet (UNet)                                                                           [1, 1, 128, 128, 64]      [1, 8, 128, 128, 64]      --                        True
├─Sequential (model)                                                                  [1, 1, 128, 128, 64]      [1, 8, 128, 128, 64]      --                        True
│    └─ResidualUnit (0)                                                               [1, 1, 128, 128, 64]      [1, 16, 64, 64, 32]       --                        True
│    │    └─Conv3d (residual)                                                         [1, 1, 128, 128, 64]      [1, 16, 64, 64, 32]       448                       True
│    │    └─Sequential (conv)                                                         [1, 1, 128, 128, 64]      [1, 16, 64

False

## 3. Start the training loop

In [5]:
if __name__ == '__main__':
    engine.train(model = model,
                 train_dataloader = train_dataloader,
                 test_dataloader = test_dataloader,
                 loss_fn = loss_fn,
                 optimizer = optimizer,
                 device = device,
                 target_dir = target_dir,
                 model_name = 'test.pth',
                 epochs = 2,
                 writer = utils.create_writer(model_name = 'U-net',
                                              extra = 'DiceLoss short test'))

[INFO Created SummaryWriter saving to c:\Users\graumnitz\Desktop\Heart_segmentation\runs\2024-03-13\U-net\DiceLoss short test]


  0%|          | 0/2 [00:00<?, ?it/s]

Step 1 of 6 | train loss: 0.8743 | train metric: 17.03%
Step 2 of 6 | train loss: 0.8638 | train metric: 18.64%
Step 3 of 6 | train loss: 0.8496 | train metric: 20.87%
Step 4 of 6 | train loss: 0.8575 | train metric: 19.50%
Step 5 of 6 | train loss: 0.8434 | train metric: 21.57%
Step 6 of 6 | train loss: 0.8387 | train metric: 22.06%

[INFO] E: 0 | Epoch train loss: 0.8546 | Epoch train metric: 19.94%
--------------------------------------------------

Step: 1 of 3 | test loss: 0.8646 | test metric: 16.16%
Step: 2 of 3 | test loss: 0.8669 | test metric: 15.77%
Step: 3 of 3 | test loss: 0.8689 | test metric: 15.54%

[INFO] E: 0 | Epoch test loss: 0.8668 | Epoch test metric: 15.82%
--------------------------------------------------

[INFO] Saving model to: c:\Users\graumnitz\Desktop\Heart_segmentation\models\test.pth
Step 1 of 6 | train loss: 0.8568 | train metric: 19.16%
Step 2 of 6 | train loss: 0.8564 | train metric: 19.28%
Step 3 of 6 | train loss: 0.8421 | train metric: 21.20%
Step 