In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd

# Root directory that the relative image paths from the CSV are joined onto.
unprocessed_data_root_dir = Path("/home/haim/code/tumors/data")
# Root directory plus the "volumes" / "segmentations" sub-directories beneath it.
processed_data_root_dir = Path("/home/haim/code/tumors/data/processed")
processed_images_dir: Path = processed_data_root_dir / "volumes"
processed_segments_dir: Path = processed_data_root_dir / "segmentations"

# CSV index of the dataset; the preview cell reads its "image_path" column.
df = pd.read_csv("/home/haim/code/tumors/liver_tumors/image_and_segment_paths.csv")
vol = 0  # label of the `df` row whose volume is previewed
idx = 60  # position along the last axis of the volume that is displayed

In [None]:
# Load the volume selected by `vol` and show the slice at position `idx` of its last axis.
img = nib.load(unprocessed_data_root_dir / df["image_path"][vol])
img_data = img.get_fdata()
print(img_data.shape)
# The slice is transposed before display.
plt.imshow(img_data[:, :, idx].T, cmap="bone")
plt.axis('off')
plt.show()

In [None]:
%%writefile /home/haim/code/tumors/liver_tumors/src/liver_tk/transforms/transforms.py
from typing import Tuple
import torch


class PadOrTrim:
    """
    A torch transform that adjusts the depth of a tensor to match a target depth.

    Tensors deeper than ``target_depth`` are trimmed at the end of the depth axis;
    shallower tensors are padded at the end of the depth axis with a constant value.
    Tensors that already have the target depth are returned unchanged.

    Args:
        target_depth (int): Size the depth axis must have after the transform.
        sparse_result (bool): Stored on the instance; not used by ``__call__``.
        depth_dim (int): Index of the depth axis in the input tensor. Negative
            values count from the last axis.
        pad (float): Constant value written into the padded region.

    Returns:
        torch.Tensor: The input with its depth axis trimmed or padded.
    """
    def __init__(self, target_depth: int, sparse_result: bool = False, depth_dim: int = 2, pad: float = 0):
        self.target_depth = target_depth
        self.depth_dim = depth_dim
        self.pad = pad
        self.sparse_result = sparse_result

    def __call__(self, sample: torch.Tensor) -> torch.Tensor:
        # Indexing the shape first keeps the IndexError for an out-of-range depth_dim.
        current_depth = sample.shape[self.depth_dim]
        depth_dim = self.depth_dim if self.depth_dim >= 0 else self.depth_dim + sample.dim()

        if current_depth > self.target_depth:
            # Trim along the configured axis (not a hard-coded third axis).
            sample = sample.narrow(depth_dim, 0, self.target_depth)
        elif current_depth < self.target_depth:
            pad_after = self.target_depth - current_depth
            # F.pad lists (before, after) pairs starting from the LAST axis, so every
            # axis that follows depth_dim needs an explicit (0, 0) pair first.
            trailing_axes = sample.dim() - 1 - depth_dim
            padding = (0, 0) * trailing_axes + (0, pad_after)
            # self.pad is the fill value, not a padding amount.
            sample = torch.nn.functional.pad(sample, padding, mode="constant", value=self.pad)

        return sample

class WindowImage:
    """
    A torch transform that applies Hounsfield windowing to a CT image for better visualization.

    Intensities are clamped to ``[window_level - window_width / 2,
    window_level + window_width / 2]``, rescaled linearly to ``[0, 255]`` and cast
    to ``torch.uint8`` (the cast truncates the fractional part).

    Args:
        window_level (float): The center of the windowing range.
        window_width (float): The width of the windowing range. Must be positive.

    Returns:
        torch.Tensor: The windowed image as a ``torch.uint8`` tensor in the range [0, 255].

    Raises:
        ValueError: If ``window_width`` is not positive.
    """
    def __init__(self, window_level: float, window_width: float):
        # A zero width would divide by zero during rescaling and a negative width
        # would put the lower bound above the upper bound, so reject both up front.
        if window_width <= 0:
            raise ValueError(f"window_width must be positive, got {window_width}")
        self.window_level = window_level
        self.window_width = window_width

    def __call__(self, image: torch.Tensor) -> torch.Tensor:
        min_intensity = self.window_level - (self.window_width / 2)
        max_intensity = self.window_level + (self.window_width / 2)

        # Saturate everything outside the window, then map the window onto [0, 255].
        windowed_image = torch.clamp(image, min_intensity, max_intensity)
        windowed_image = ((windowed_image - min_intensity) / (max_intensity - min_intensity) * 255).to(torch.uint8)

        return windowed_image

In [None]:
import torch
from torchvision import transforms
from liver_tk.transforms import PadOrTrim, WindowImage

In [None]:
# Windowing parameters and target depth shared by the two pipelines below.
window_level: int = 30
window_width: int = 150
target_depth: int = 851
# Images are windowed first and then padded/trimmed to `target_depth`.
image_transforms = transforms.Compose([
        WindowImage(window_level=window_level, window_width=window_width),
        PadOrTrim(target_depth=target_depth),
        # transforms.Lambda(lambda x: x.to_sparse(2)),
    ])
# Masks only get the depth adjustment.
mask_transforms = transforms.Compose([PadOrTrim(target_depth=target_depth)])

In [None]:
%%writefile /home/haim/code/tumors/liver_tumors/src/liver_tk/datamodule/segmentation_dataset.py
from pathlib import Path

import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import nibabel as nib
from typing import Tuple, Optional

class SegmentationDataset(Dataset):
    """Dataset of (image, segmentation) volume pairs listed in a DataFrame and read with nibabel."""

    def __init__(self, data_root_path: Path, data_frame: pd.DataFrame, image_transform: Optional[torch.nn.Module] = None, mask_transform: Optional[torch.nn.Module] = None):
        """
        Store the dataset location, its index table and the optional transforms.

        Args:
            data_root_path (Path): Directory that the relative paths in ``data_frame`` are joined onto.
            data_frame (pd.DataFrame): Index table; the second column is read as the image path
                and the third column as the segmentation path.
            image_transform (Optional[torch.nn.Module], optional): Callable applied to the image tensor.
            mask_transform (Optional[torch.nn.Module], optional): Callable applied to the mask tensor.
        """
        self.data_root_path = data_root_path
        self.data_frame = data_frame
        self.image_transform = image_transform
        self.mask_transform = mask_transform

    def __len__(self) -> int:
        """Return the number of rows in the index table."""
        return len(self.data_frame)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load, convert and transform the image/mask pair stored at row ``idx``."""
        row = idx.tolist() if torch.is_tensor(idx) else idx

        # Column positions 1 and 2 hold the image and segmentation paths respectively.
        image_file, mask_file = (
            self.data_root_path / self.data_frame.iloc[row, column] for column in (1, 2)
        )

        image_array = nib.load(image_file).get_fdata()
        mask_array = nib.load(mask_file).get_fdata()

        # Images become half-precision floats, masks become unsigned bytes.
        image_tensor = torch.tensor(image_array, dtype=torch.float16)
        mask_tensor = torch.tensor(mask_array, dtype=torch.uint8)

        if self.image_transform:
            image_tensor = self.image_transform(image_tensor)

        if self.mask_transform:
            mask_tensor = self.mask_transform(mask_tensor)

        return image_tensor, mask_tensor


In [None]:
# Usage example: build the dataset and a shuffled loader directly, without the DataModule.
# DataLoader and SegmentationDataset are imported here because no earlier notebook cell
# brings them into the kernel namespace (%%writefile only writes the module to disk).
from torch.utils.data import DataLoader

from liver_tk.datamodule.segmentation_dataset import SegmentationDataset

data_frame = pd.read_csv("/home/haim/code/tumors/liver_tumors/image_and_segment_paths.csv")  # Load your DataFrame

dataset = SegmentationDataset(
    data_root_path=Path("/home/haim/code/tumors/data"),
    data_frame=data_frame,
    image_transform=image_transforms,
    mask_transform=mask_transforms,
)

dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

In [None]:
# Pull a single (image, mask) batch from the loader for inspection.
X, y = next(iter(dataloader))

In [None]:
# Show position 85 of the last axis for the first sample in the batch,
# with the mask drawn semi-transparently on top of the image.
plt.imshow(X.numpy()[0, :, :, 85].T, cmap="bone")
plt.imshow(y.numpy()[0, :, :, 85].T, cmap="viridis", alpha=0.3)
plt.axis('off')
plt.title("Segmented Liver")
plt.show()

In [None]:
# Raw mask values for the same slice that the overlay plot displays.
y.numpy()[0, :, :, 85].T

In [None]:
%%writefile /home/haim/code/tumors/liver_tumors/src/liver_tk/datamodule/segmentation_datamodule.py
from pathlib import Path
from typing import Optional

import lightning as L
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torchvision import transforms

from liver_tk.datamodule.segmentation_dataset import SegmentationDataset
from liver_tk.transforms import PadOrTrim, WindowImage

class SegmentationDataModule(L.LightningDataModule):
    """
    Lightning data module that splits the CSV index into train/val/test subsets
    and serves each subset through a ``SegmentationDataset``.

    Args:
        data_root_path (str): Directory that the relative paths in the CSV are joined onto.
        csv_file_path (str): Path of the CSV file listing the image and segmentation paths.
        batch_size (int): Batch size used by all three dataloaders.
        num_workers (int): Number of worker processes used by all three dataloaders.
        window_level (int): Window center passed to ``WindowImage``.
        window_width (int): Window width passed to ``WindowImage``.
        target_depth (int): Depth passed to ``PadOrTrim`` for both images and masks.
    """
    def __init__(self, data_root_path: str, csv_file_path: str, batch_size: int, num_workers: int = 4, window_level: int = 30, window_width: int = 150, target_depth: int = 851):
        super().__init__()
        self.data_root_path = Path(data_root_path)
        self.csv_file_path = Path(csv_file_path)
        self.batch_size: int = batch_size
        self.num_workers: int = num_workers

        # Populated by setup().
        self.df_train: Optional[pd.DataFrame] = None
        self.df_val: Optional[pd.DataFrame] = None
        self.df_test: Optional[pd.DataFrame] = None

        # Images are windowed and then depth-adjusted; masks are only depth-adjusted.
        self.image_transform = transforms.Compose([
            WindowImage(window_level=window_level, window_width=window_width),
            PadOrTrim(target_depth=target_depth),
        ])

        self.mask_transform = transforms.Compose([PadOrTrim(target_depth=target_depth)])

    def _build_dataset(self, data_frame: pd.DataFrame) -> SegmentationDataset:
        """Wrap one split in a dataset that uses this module's own transforms."""
        return SegmentationDataset(
            data_root_path=self.data_root_path,
            data_frame=data_frame,
            image_transform=self.image_transform,
            mask_transform=self.mask_transform,
        )

    def setup(self, stage: Optional[str] = None):
        """Read the CSV, split it 70/15/15 with a fixed seed and build the three datasets."""
        train_size: float = 0.7
        val_size: float = 0.15
        test_size: float = 0.15
        # The attribute set in __init__ is csv_file_path (there is no csv_file).
        df = pd.read_csv(self.csv_file_path)
        self.df_train, df_temp = train_test_split(df, test_size=(1.0 - train_size), random_state=42)
        # Second split divides the held-out part between validation and test.
        test_frac = test_size / (val_size + test_size)
        self.df_val, self.df_test = train_test_split(df_temp, test_size=test_frac, random_state=42)

        # Use the transforms configured in __init__ rather than notebook-level globals,
        # so the window/depth arguments given to the constructor take effect.
        self.train_set = self._build_dataset(self.df_train)
        self.val_set = self._build_dataset(self.df_val)
        self.test_set = self._build_dataset(self.df_test)

    def train_dataloader(self):
        """Return the shuffled loader over the training split."""
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        """Return the unshuffled loader over the validation split."""
        return DataLoader(self.val_set, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)

    def test_dataloader(self):
        """Return the unshuffled loader over the test split."""
        return DataLoader(self.test_set, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)

In [None]:
# from sklearn.model_selection import train_test_split
# train_size = 0.7
# val_size = 0.15
# test_size = 0.15
# df = pd.read_csv("/home/haim/code/tumors/liver_tumors/image_and_segment_paths.csv")
# df_train, df_temp = train_test_split(df, test_size=(1.0 - train_size), random_state=42)
# val_frac = val_size / (val_size + test_size)
# test_frac = test_size / (val_size + test_size)
# df_val, df_test = train_test_split(df_temp, test_size=test_frac, random_state=42)


In [1]:
from liver_tk.datamodule.segmentation_datamodule import SegmentationDataModule
import torch
# Select the 'medium' float32 matmul precision setting for this session.
torch.set_float32_matmul_precision('medium')

In [2]:
# Build the data module with a batch size of 1 and a single loader worker,
# then call setup() explicitly since no Trainer is driving it here.
datamodule = SegmentationDataModule(
        batch_size=1,
        data_root_path="/home/haim/code/tumors/data",
        csv_file_path="/home/haim/code/tumors/liver_tumors/image_and_segment_paths.csv",
        num_workers=1,
)
datamodule.setup()

In [3]:
# Loader over the training split created by setup().
train_dataloader = datamodule.train_dataloader()

In [4]:
# Pull one training batch; this rebinds the X and y names used earlier in the notebook.
X, y = next(iter(train_dataloader))

In [7]:
# Move the batch to the GPU (requires a CUDA device).
X = X.to("cuda")
y = y.to("cuda")

In [None]:
# Debug helper: print the shapes of the image and mask batches.
from icecream import ic
ic(X.shape)
ic(y.shape)

In [8]:
import lightning as L
from liver_tk.nets.unet import SegmentationModel
from liver_tk.datamodule.segmentation_datamodule import SegmentationDataModule
import torch
# NOTE(review): SegmentationDataModule, torch and the precision setting repeat an earlier cell.
torch.set_float32_matmul_precision('medium')
# Instantiate the model with its default arguments and move it to the GPU.
model = SegmentationModel().to("cuda")

In [9]:
# Run the model once on the batch fetched above.
y_hat = model(X)

In [None]:
# NOTE(review): this comment previously read "N, D, H, W", but the rest of the notebook
# treats depth as the LAST axis (slices are taken as [:, :, idx] and PadOrTrim defaults
# to depth_dim=2), so the batch is presumably N, H, W, D — confirm.
X.shape

In [None]:
# unsqueeze(1) inserts a channel axis after the batch axis: N, Cin, then the original
# remaining axes (presumably H, W, D — the previous "N, Cin, H, D" note dropped one).
X.unsqueeze(1).shape

In [10]:
# Shape of the model output for the batch above.
y_hat.shape

torch.Size([1, 1, 512, 512, 852])