In [1]:
import numpy as np
import torch
import wandb
from monai.losses.dice import DiceLoss
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.trainer import Trainer
from torch.nn import BCELoss
from torch.optim import Adam
from torch.utils.data import DataLoader

from loaders.lazy_loaders import PatchDataloader
from losses.jaccard_loss import JaccardLoss # TODO for later
from models.lightning_models import Unet3D
from modules.segmentation_model import Unet, PositionalUnet
# from losses.dice_loss import DiceLoss

2025-06-30 14:23:02.453239: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-06-30 14:23:02.457518: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-06-30 14:23:02.466441: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1751286182.480892  389517 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1751286182.484475  389517 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been regist

In [None]:
# Suffix digits of the image channels to load; each is matched as "*000{channel}.nii.gz".
channel_list = [2]
# slices = 80

# NOTE(review): allow_pickle=True unpickles arbitrary objects and can execute code
# embedded in the file -- only load ID files from a trusted source.
# The entries are used as directory objects below (``.glob`` is called on them),
# so they are presumably pickled pathlib.Path instances -- TODO confirm.
train = np.load("../MAMAMIA_Challenge/train_ids_v_pojeb.npy", allow_pickle=True)
# Only the first 200 validation cases are kept.
validation = np.load("../MAMAMIA_Challenge/validation_ids_v_pojeb.npy", allow_pickle=True)[:200]
sets = [train, validation]


def _first_match(case_dir, pattern):
    """Return the first entry yielded by ``case_dir.glob(pattern)``.

    Raises:
        FileNotFoundError: if nothing in ``case_dir`` matches ``pattern``.
            (A bare ``next()`` would raise an uninformative ``StopIteration``.)
    """
    match = next(case_dir.glob(pattern), None)
    if match is None:
        raise FileNotFoundError(f"No file matching {pattern!r} found in {case_dir}")
    return match


def _build_dicts(split_channel_images, split_labels):
    """Pair each case's per-channel image paths (as a tuple) with its label path."""
    return [
        {"image": image_names, "label": label_name}
        for image_names, label_name in zip(zip(*split_channel_images), split_labels)
    ]


# get filepaths depending on channel count
# Layout: channel_images[split_index][channel_index][case_index]
channel_images = [
    [
        [_first_match(case_dir, f"*000{channel}.nii.gz") for case_dir in split_ids]
        for channel in channel_list
    ]
    for split_ids in sets
]

# get labels
# Layout: labels[split_index][case_index]
labels = [
    [_first_match(case_dir, "*segmentation.nii.gz") for case_dir in split_ids]
    for split_ids in sets
]

# get dicts (index 0 = train, index 1 = validation, matching the order of `sets`)
train_dicts = _build_dicts(channel_images[0], labels[0])
val_dicts = _build_dicts(channel_images[1], labels[1])

In [3]:
# Size of the 3D training patches; also passed to the model further down.
patch_size = (64, 64, 64)

# Training dataset: yields patches of `patch_size`.
# NOTE(review): `repeat=4` presumably draws several patches per volume -- confirm
# against PatchDataloader.
train_dataset = PatchDataloader(
    images_labels_dict=train_dicts,
    patch_size=patch_size,
    repeat=4,
)

# Validation dataset: patch_size=-1 (the shape check in the next cell shows a
# full-size volume for validation samples rather than a 64^3 patch).
validation_dataset = PatchDataloader(
    images_labels_dict=val_dicts,
    patch_size=-1,
    repeat=1,
)

In [4]:
# Sanity check: shapes of the first training image/label pair and of the first
# validation image. The trailing tuple is the cell's displayed output.
train_image_shape = train_dataset[0][0].shape
train_label_shape = train_dataset[0][1].shape
val_image_shape = validation_dataset[0][0].shape
train_image_shape, train_label_shape, val_image_shape

(torch.Size([1, 64, 64, 64]),
 torch.Size([64, 64, 64]),
 torch.Size([1, 154, 146, 161]))

In [5]:
# import matplotlib.pyplot as plt
# dat = train_dataset[12]
# plt.imshow(dat[0][0,30,...], cmap='gray')
# plt.show()
# plt.imshow(dat[1][30,...], cmap='gray')
# plt.show()
# # TODO: export the 3D volumes and inspect them in 3D Slicer

In [6]:
# Training loader: batches of 16 patches, loaded by 16 worker processes.
# NOTE(review): shuffle=False means samples are served in dataset order every
# epoch -- confirm this is intended (e.g. PatchDataloader may randomise patches
# itself or rely on sequential access).
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=16,
    num_workers=16,
    shuffle=False,
)
# train_loader = DataLoader(train_dataset, batch_size=min(slices, 16), shuffle=False, num_workers=min(slices, 16))

# Validation loader: one sample per batch, in fixed order.
validation_loader = DataLoader(
    dataset=validation_dataset,
    batch_size=1,
    num_workers=1,
    shuffle=False,
)

In [None]:
# Stride passed to the model alongside `patch_size` (64^3); 32 is half the patch
# edge. NOTE(review): presumably the sliding-window step for patch-wise
# prediction -- confirm against Unet3D.
stride = (32, 32, 32)

# Lightning wrapper around the plain Unet backbone, configured for a single
# output channel (classes=1) trained with a sigmoid Dice loss.
model = Unet3D(
    model=Unet,
    depths=[3, 3, 3, 3],
    # loss=BCELoss(),
    # loss=JaccardLoss(from_logits=True, reduce=True),
    # loss=DiceLoss(convert_logits_to_probs=True),
    # to_onehot_y is False because the prediction has a single channel: MONAI
    # ignores to_onehot_y=True in that case and emits
    # "single channel prediction, `to_onehot_y=True` ignored." on every run.
    loss=DiceLoss(include_background=True, to_onehot_y=False, sigmoid=True),
    patch_size=patch_size,
    strides=stride,
    padding="same",
    classes=1,
    beta=1,
    initial_LR=1e-4,
    # final_activation=torch.nn.Sigmoid(),
    # optimizer=Adam
)

In [8]:
# device = torch.device('cuda:0')
# checkpoint_path = (
#     "/home/romanuccio/RomanuccioDiff/Segmentation3D/Gianluca_Mamamia/3mkossiu/checkpoints/best-checkpoint-epoch=11-val_loss=0.91.ckpt"
# )
# model = Unet3D.load_from_checkpoint(
#     checkpoint_path,
#     model=Unet,
#     depths=[3, 3, 3, 3],
#     # loss=BCELoss(),
#     # loss=JaccardLoss(from_logits=True, reduce=True),
#     loss=DiceLoss(convert_logits_to_probs=True),
#     # loss=DiceLoss(include_background=True, to_onehot_y=False, sigmoid=True),
#     patch_size=patch_size,
#     strides=stride,
#     padding="same",
#     classes=1,
#     beta=1,
# )
# # model.eval()  # or model.train() depending on your use case

In [9]:
# data = []
# for i, val_data in enumerate(validation_loader):
#     if i > 1:
#         break
#     # model.validation_step(val_data, i)
#     # image, label = val_data
#     # print(image.shape)
#     # # model.predict_step(
#     # #     image,
#     # #     patch_size=patch_size,
#     # #     strides=stride,
#     # #     padding="same", unpad=True, verbose=True
#     # # )
    
#     data.append(val_data)

In [10]:
# asdf = data[0].to(device)
# model.validation_step(asdf, 0)

In [None]:
# ModelCheckpoint and Trainer are imported in the notebook's top import cell.

# Keep only the single best checkpoint, ranked by the lowest monitored "val_loss".
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",         # Metric to monitor
    mode="min",                 # 'min' for loss, 'max' for accuracy
    save_top_k=1,               # Save only the best model
    filename="best-checkpoint-{epoch:02d}-{val_loss:.2f}",
    verbose=True,
)

# Metrics are sent to the "Gianluca_Mamamia" Weights & Biases project.
wandb_logger = WandbLogger(project="Gianluca_Mamamia")

# devices=[1] selects the device with index 1 (the run log lists CUDA devices 0 and 1).
# NOTE(review): log_every_n_steps=0 presumably disables per-step logging so only
# epoch-level metrics reach the logger -- confirm this is intended.
trainer = Trainer(
    max_epochs=10,
    callbacks=[checkpoint_callback],
    devices=[1],
    log_every_n_steps=0,
    logger=wandb_logger,
)

# Train for up to 10 epochs, validating on `validation_loader`.
trainer.fit(
    model=model,
    train_dataloaders=train_loader,
    val_dataloaders=validation_loader,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mromanuccio[0m ([33mromanuccio-brno-university-of-technology[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name             | Type     | Params | Mode 
------------------------------------------------------
0 | final_activation | Identity | 0      | train
1 | model            | Unet     | 2.0 M  | train
2 | loss             | DiceLoss | 0      | train
------------------------------------------------------
2.0 M     Trainable params
0         Non-trainable params
2.0 M     Total params
8.089     Total estimated model params size (MB)
287       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.
single channel prediction, `to_onehot_y=True` ignored.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]