In [1]:
# Import libraries
import sys

sys.path.append("..")
from src.preprocessing import get_transforms, get_datasets, get_dataloaders

import torch
import matplotlib.pyplot as plt
from src.config import config
from src.model import SegFormer3D

# Set the device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: ldeben20 (luciano-deben). Use `wandb login --relogin` to force relogin


cuda:0


In [2]:
# Get the transforms
transform = get_transforms(
    resize_shape=config["resize_shape"], contrast_value=config["contrast_value"]
)

# Get the datasets
train_dataset, val_dataset = get_datasets(
    root_dir="../data",
    collection="HCC-TACE-Seg",
    transform=transform,
    download=True,
    download_len=5,
    seg_type="SEG",
    val_frac=config["val_frac"],
    seed=config["seed"],
)

# Get the dataloaders
train_loader, val_loader = get_dataloaders(
    train_dataset,
    val_dataset,
    batch_size=config["batch_size"],
    num_workers=config["num_workers"],
)

2024-05-12 21:22:37,578 - INFO - Expected md5 is None, skip md5 check for file ..\data\HCC-TACE-Seg\1.2.276.0.7230010.3.1.3.8323329.41.1604860085.518229.zip.
2024-05-12 21:22:37,578 - INFO - File exists: ..\data\HCC-TACE-Seg\1.2.276.0.7230010.3.1.3.8323329.41.1604860085.518229.zip, skipped downloading.
2024-05-12 21:22:37,578 - INFO - Writing into directory: ..\data\HCC-TACE-Seg\raw\1.2.276.0.7230010.3.1.3.8323329.41.1604860085.518229.
2024-05-12 21:22:38,067 - INFO - Expected md5 is None, skip md5 check for file ..\data\HCC-TACE-Seg\1.3.6.1.4.1.14519.5.2.1.1706.8374.172517341095680731665822868712.zip.
2024-05-12 21:22:38,067 - INFO - File exists: ..\data\HCC-TACE-Seg\1.3.6.1.4.1.14519.5.2.1.1706.8374.172517341095680731665822868712.zip, skipped downloading.
2024-05-12 21:22:38,067 - INFO - Writing into directory: ..\data\HCC-TACE-Seg\HCC_017\300\image.
2024-05-12 21:22:38,370 - INFO - Expected md5 is None, skip md5 check for file ..\data\HCC-TACE-Seg\1.2.276.0.7230010.3.1.3.8323329.208

In [10]:
# Sample a batch of data from the dataloader
batch = next(iter(train_loader))

# Separate the image and segmentation from the batch
image, seg = batch["image"], batch["seg"][:, 2]

print(image.shape, seg.shape)

torch.Size([1, 1, 256, 256, 48]) torch.Size([1, 256, 256, 48])


In [11]:
# Evaluate the model
model = SegFormer3D(num_classes=config["num_classes"])
model.load_state_dict(torch.load("../models/model.pth"))
model.eval()

# Move the model to the device
model = model.to(device)

In [12]:
# Set a evaluation script
outputs = model(image.to(device))
prediction = (torch.sigmoid(outputs) > 0.5).float()

print(prediction.shape)
print(prediction.unique())

torch.Size([1, 1, 256, 256, 48])
metatensor([0., 1.], device='cuda:0')


In [13]:
import nibabel as nib
import numpy as np

print(prediction.shape, seg.shape)

# Check if the first dimension is 1 before squeezing
prediction = prediction.squeeze((0, 1))

# Convert to numpy array and change dtype to float32
prediction_numpy = prediction.cpu().numpy().astype(np.float32)

# Save the segmentation target
if seg.size(0) == 1:
    seg = seg.squeeze(0)

seg_numpy = seg.cpu().numpy().astype(np.float32)


torch.Size([1, 1, 256, 256, 48]) torch.Size([1, 256, 256, 48])


In [14]:
# Check the unique values and their count
print(np.unique(prediction_numpy, return_counts=True))

print(np.unique(seg_numpy, return_counts=True))

difference = prediction_numpy - seg_numpy
num_differing_voxels = np.count_nonzero(difference)
print(f"Number of differing voxels: {num_differing_voxels}")

(array([0., 1.], dtype=float32), array([3015732,  129996], dtype=int64))
(array([0., 1.], dtype=float32), array([3137964,    7764], dtype=int64))
Number of differing voxels: 123870


In [9]:
# Save prediction and target volumes as nifti files
new_image = nib.Nifti1Image(prediction_numpy, np.eye(4))
nib.save(new_image, "prediction.nii")

new_image = nib.Nifti1Image(seg_numpy, np.eye(4))
nib.save(new_image, "target.nii")