In [1]:
!pip install torch torchvision torchsummary einops transformers nibabel tqdm



In [2]:
import nibabel as nib
import torch
from torch.utils.data import Dataset
import numpy as np
import glob

class CTDataset(Dataset):
    """Pairs 3D CT volumes with label volumes read from two NIfTI directories.

    Image ``i`` is paired with label ``i`` purely by *sorted filename order*;
    the file names themselves are not compared. Both directories must therefore
    contain the same number of ``*.nii.gz`` files, with names that sort into
    corresponding order.

    Args:
        image_dir: Directory containing the image volumes (``*.nii.gz``).
        label_dir: Directory containing the label volumes (``*.nii.gz``).
        transform: Optional callable ``(img, lbl) -> (img, lbl)`` applied in
            ``__getitem__`` after the channel axis is added to the image.

    Raises:
        ValueError: If the two directories contain different numbers of
            volumes. Without this check the mismatch would only show up later,
            as an IndexError or as silently shifted image/label pairs.
    """

    def __init__(self, image_dir, label_dir, transform=None):
        # glob.escape keeps directory names containing glob metacharacters
        # (e.g. "[" or "*") from being interpreted as patterns.
        self.images = sorted(glob.glob(f"{glob.escape(image_dir)}/*.nii.gz"))
        self.labels = sorted(glob.glob(f"{glob.escape(label_dir)}/*.nii.gz"))
        if len(self.images) != len(self.labels):
            raise ValueError(
                f"Found {len(self.images)} images in {image_dir!r} but "
                f"{len(self.labels)} labels in {label_dir!r}; images and labels "
                "are paired by sorted order and must match one-to-one."
            )
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Load pair ``idx``.

        Without a transform, returns ``(image, label)``. ``image`` is float32
        with a leading channel axis, ``(1, *volume_shape)``. ``label`` is int64
        and keeps the volume's own shape (no channel axis).
        """
        # dtype=np.float32 lets nibabel produce float32 directly instead of
        # allocating a float64 volume and then copying it.
        img = nib.load(self.images[idx]).get_fdata(dtype=np.float32)
        lbl = nib.load(self.labels[idx]).get_fdata(dtype=np.float32).astype(np.int64)

        img = np.expand_dims(img, axis=0)  # Add channel dimension
        if self.transform:
            img, lbl = self.transform(img, lbl)
        # as_tensor accepts arrays or tensors (no copy warning if the transform
        # already returned tensors). The arrays here are freshly loaded per
        # call, so sharing their memory is safe.
        return torch.as_tensor(img), torch.as_tensor(lbl)


In [4]:
import os
import nibabel as nib
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class KiTS23Dataset(Dataset):
    """One item per KiTS23 case: the full 3D CT volume and its segmentation.

    Every sub-directory of ``root_dir`` is treated as a case folder and is
    expected to contain ``imaging.nii.gz`` and ``segmentation.nii.gz``.
    Plain files directly inside ``root_dir`` are ignored.

    Args:
        root_dir: Dataset root, e.g. ``"/content/kits23/dataset/"``.
        transform: Optional callable ``(img, lbl) -> (img, lbl)`` applied in
            ``__getitem__`` after the channel axis is added to the image.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.case_folders = sorted(
            os.path.join(root_dir, name)
            for name in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, name))
        )
        self.transform = transform

    def __len__(self):
        return len(self.case_folders)

    def __getitem__(self, idx):
        """Load case ``idx``.

        Without a transform, returns ``(image, label)``. ``image`` is float32
        shaped ``(1, *volume_shape)``. ``label`` is int64 shaped
        ``volume_shape``.
        """
        case_path = self.case_folders[idx]
        img_path = os.path.join(case_path, "imaging.nii.gz")
        lbl_path = os.path.join(case_path, "segmentation.nii.gz")

        # Load NIfTI. dtype=np.float32 avoids an intermediate float64 copy of
        # the whole volume (a large saving for full-resolution CT).
        img = nib.load(img_path).get_fdata(dtype=np.float32)
        lbl = nib.load(lbl_path).get_fdata(dtype=np.float32).astype(np.int64)

        # Prepend a channel axis. The spatial axes keep the order stored in
        # the file. NOTE(review): this cell's sample output is
        # (1, 553, 512, 512), so the stored slice axis appears to come first,
        # i.e. (C, slices, H, W) rather than (C, H, W, D). Confirm per dataset.
        img = np.expand_dims(img, axis=0)

        if self.transform:
            img, lbl = self.transform(img, lbl)

        # Arrays are freshly loaded per call, so as_tensor's memory sharing is
        # safe; it also avoids torch.tensor's warning when given tensors.
        return torch.as_tensor(img), torch.as_tensor(lbl)

# Example usage: wrap the whole-volume dataset in a loader and peek at one batch.
KITS23_ROOT = "/content/kits23/dataset/"
dataset = KiTS23Dataset(KITS23_ROOT)
train_loader = DataLoader(dataset, shuffle=True, batch_size=1)

# Each batch holds a single full CT case; inspect only the first one.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    imgs, lbls = first_batch
    print("Image shape:", imgs.shape)
    print("Label shape:", lbls.shape)


Image shape: torch.Size([1, 1, 553, 512, 512])
Label shape: torch.Size([1, 553, 512, 512])


In [None]:
import os
import nibabel as nib
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class KiTS23SliceDataset(Dataset):
    """2D slices of every KiTS23 case, keeping only slices that contain labels.

    All cases are read eagerly in ``__init__``. Each CT volume is clipped to
    ``hu_window`` and rescaled linearly to ``[-1, 1]``. It is then cut into 2D
    slices along ``slice_axis``. Slices whose label map is entirely zero are
    dropped. The kept slices are stored as independent copies, so a case's
    full volumes are not kept alive by the slices taken from them.

    Args:
        root_dir: Dataset root. Every sub-directory is treated as a case folder
            containing ``imaging.nii.gz`` and ``segmentation.nii.gz``.
        transform: Optional callable ``(img, lbl) -> (img, lbl)`` applied in
            ``__getitem__`` after the channel axis is added to the image.
        slice_axis: Volume axis to slice along. The 3D loader in this notebook
            shows a KiTS23 case loading as ``(553, 512, 512)``, so the slice
            axis appears to be axis 0, which yields axial ``(512, 512)``
            slices. TODO confirm for other datasets.
        hu_window: ``(min_hu, max_hu)`` clipping window applied before scaling.

    Raises:
        ValueError: If ``hu_window`` is not increasing, or if a case's image
            and label volumes have different shapes.
    """

    def __init__(self, root_dir, transform=None, slice_axis=0, hu_window=(-1000, 400)):
        self.slice_data = []  # list of tuples: (image_slice, label_slice)
        self.transform = transform

        hu_min, hu_max = hu_window
        if hu_max <= hu_min:
            raise ValueError(f"hu_window must be increasing, got {hu_window!r}")

        case_folders = sorted(
            os.path.join(root_dir, d) for d in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, d))
        )

        for case in case_folders:
            img_nii = os.path.join(case, "imaging.nii.gz")
            lbl_nii = os.path.join(case, "segmentation.nii.gz")

            # dtype=np.float32 avoids a full float64 copy of each volume.
            img = nib.load(img_nii).get_fdata(dtype=np.float32)
            lbl = nib.load(lbl_nii).get_fdata(dtype=np.float32).astype(np.int64)

            img = self._normalize_hu(img, hu_min, hu_max)
            self.slice_data.extend(self._labeled_slices(img, lbl, slice_axis))

    @staticmethod
    def _normalize_hu(img, hu_min=-1000, hu_max=400):
        """Clip values to ``[hu_min, hu_max]`` and rescale linearly to ``[-1, 1]``."""
        img = np.clip(img, hu_min, hu_max)
        img = (img - hu_min) / (hu_max - hu_min)  # 0 to 1
        return img * 2 - 1  # -1 to 1

    @staticmethod
    def _labeled_slices(img, lbl, axis):
        """Yield ``(image_slice, label_slice)`` along ``axis`` where the label slice is non-zero.

        ``np.take`` always returns a new array, so the yielded slices do not
        hold references to (and keep alive) the full input volumes.
        """
        if img.shape != lbl.shape:
            raise ValueError(
                f"image shape {img.shape} does not match label shape {lbl.shape}"
            )
        for i in range(img.shape[axis]):
            lbl_slice = np.take(lbl, i, axis=axis)
            # skip slices without any labelled voxel
            if not lbl_slice.any():
                continue
            yield np.take(img, i, axis=axis), lbl_slice

    def __len__(self):
        return len(self.slice_data)

    def __getitem__(self, idx):
        """Return slice ``idx`` as ``(image, label)`` tensors.

        Without a transform, ``image`` is float32 shaped ``(1, *slice_shape)``
        and ``label`` is int64 shaped ``slice_shape``.
        """
        img, lbl = self.slice_data[idx]
        img = np.expand_dims(img, 0)  # add channel dimension
        if self.transform:
            img, lbl = self.transform(img, lbl)
        # torch.tensor copies, so in-place edits on the returned tensors can
        # never corrupt the cached slices in self.slice_data.
        return torch.tensor(img, dtype=torch.float32), torch.tensor(lbl, dtype=torch.long)

# Example usage: build the 2D slice dataset, then peek at one shuffled batch.
KITS23_ROOT = "/content/kits23/dataset/"
dataset = KiTS23SliceDataset(root_dir=KITS23_ROOT)
train_loader = DataLoader(dataset, shuffle=True, batch_size=4)
print("Total slices:", len(dataset))

first_batch = next(iter(train_loader))
imgs, lbls = first_batch
print("Batch shapes:", imgs.shape, lbls.shape)
