In [1]:
import torchvision
import numpy as np
import torch
import pydicom as dicom
import nibabel as nib
from torchvision import transforms
from torchvision.io import read_image
from torchvision.datasets import ImageFolder
import os

# File suffixes accepted by ImageFolderDCM when no custom ``is_valid_file``
# predicate is supplied: common raster formats plus DICOM and NIfTI.
IMG_EXTENSIONS = (
    ".jpg",
    ".jpeg",
    ".png",
    ".ppm",
    ".bmp",
    ".pgm",
    ".tif",
    ".tiff",
    ".webp",
    ".dcm",
    ".nii",
    ".nii.gz",
)


def img_normalize(img) -> torch.Tensor:
    """
    Min-max normalize an image tensor.

    Args:
    img: (Tensor) - raw image, any shape; integer or floating dtype

    Return:
    Floating-point tensor of the same shape with values in 0-1.
    A constant image (max == min) is returned as all zeros instead of NaN.
    """
    lowest = torch.min(img)
    span = torch.max(img) - lowest
    # A flat image has span 0; dividing by it would give 0/0 = NaN everywhere.
    # Substituting 1 keeps the numerator (all zeros) as the result.
    span = torch.where(span == 0, torch.ones_like(span), span)
    # Alternative considered earlier: z-score, (img - mean) / std(unbiased=False).
    return (img - lowest) / span


class ImageFolderDCM(torchvision.datasets.DatasetFolder):
    """
    Folder-per-class dataset that also accepts DICOM and NIfTI files.

    Thin wrapper around ``torchvision.datasets.DatasetFolder`` that uses
    ``IMG_EXTENSIONS`` as the file filter and exposes ``imgs`` as an alias of
    ``samples``.

    Args:
        root (str): dataset root directory, forwarded to ``DatasetFolder``.
        transform: optional callable forwarded as ``transform``.
        target_transform: optional callable forwarded as ``target_transform``.
        loader: callable taking a file path and returning the sample. The
            default ``None`` is forwarded unchanged, so callers are expected
            to pass one (e.g. ``conv_1img_3D_totensor``).
        is_valid_file: optional predicate on a file path. When given, the
            extension list is not passed.
    """

    def __init__(
        self,
        root: str,
        transform=None,
        target_transform=None,
        loader=None,
        is_valid_file=None,
    ):
        # Pass either the extension whitelist or the predicate, never both.
        extensions = None if is_valid_file is not None else IMG_EXTENSIONS
        super().__init__(
            root,
            loader,
            extensions,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
        )
        # Same attribute name that torchvision's ImageFolder exposes.
        self.imgs = self.samples

    # Disabled experiments, kept for reference:
    #
    # 1) Fixed class list instead of directory discovery
    # def find_classes(self, root):
    #     classes = ['segmentations', 'volume']
    #     return classes, {cls_name: i for i, cls_name in enumerate(classes)}
    #
    # 2) Return the file path together with (sample, target)
    # def __getitem__(self, index):
    #     original_tuple = super(ImageFolderDCM, self).__getitem__(index)
    #     path = self.imgs[index][0]
    #     return original_tuple + (path,)

def _window_to_uint8(values, low, high):
    """
    Convert an intensity array to uint8 without integer wrap-around.

    Data that already lies inside 0..255 (for example label maps) is cast
    directly, which is what a plain clip-and-cast yields for such input.
    Anything else is clipped to ``[low, high]`` and rescaled linearly so that
    ``low`` maps to 0 and ``high`` maps to 255.
    """
    values = np.asarray(values)
    if values.size == 0 or (values.min() >= 0 and values.max() <= 255):
        return values.astype(np.uint8)
    windowed = np.clip(values.astype(np.float64), low, high)
    return np.rint((windowed - low) * (255.0 / (high - low))).astype(np.uint8)


def conv_1img_3D_totensor(img_path) -> torch.Tensor:
    """
    Convert NIFTI or DICOM image file to a uint8 pytorch tensor

    Intensities are windowed (DICOM: 0..2550, NIfTI: -350..350) and rescaled
    to 0..255. The previous clip followed by ``astype(np.uint8)`` wrapped
    every value above 255 and was undefined for negative ones. Arrays whose
    values already fit in 0..255 are returned unscaled.
    NOTE(review): models fed with the old, wrapped intensities will see a
    different input distribution after this fix.

    Args:
        img_path (str): path of the nii, nii.gz or dcm image.

    Returns:
        uint8 Tensor. DICOM: same shape as ``pixel_array``. NIfTI: the volume
        with its three axes reversed (``permute(2, 1, 0)``).
        ``None`` if the file extension is not supported.
    """
    if img_path.endswith(".dcm"):
        pixels = dicom.dcmread(img_path).pixel_array
        return torch.from_numpy(_window_to_uint8(pixels, 0, 2550))
    if img_path.endswith((".nii", ".nii.gz")):
        volume = nib.load(img_path).get_fdata()
        return torch.from_numpy(_window_to_uint8(volume, -350, 350)).permute(2, 1, 0)
    return None

def conv_1img_totensor(img_path) -> torch.Tensor:
    """
    Convert a JPEG/PNG, DICOM or NIfTI image file to pytorch tensor

    DICOM and NIfTI intensities are windowed (0..2550 and -350..350) and
    rescaled to 0..255 before the uint8 cast; a plain clip-and-cast wraps
    every value above 255. Arrays already inside 0..255 are cast unscaled.
    NOTE(review): models fed with the old, wrapped DICOM intensities will see
    a different input distribution after this fix.

    Args:
        img_path (str): path of the jpg, jpeg, png, dcm, nii or nii.gz image.

    Returns:
        Tensor. 2D results (DICOM slice, single-channel JPEG/PNG) are
        broadcast to ``[3, image_height, image_width]``; multi-channel
        JPEG/PNG is returned as read; NIfTI is returned with its three axes
        reversed (``permute(2, 1, 0)``).
        ``None`` if the file extension is not supported.
    """
    def to_uint8(values, low, high):
        # Cast directly when lossless, otherwise window and rescale to 0..255.
        values = np.asarray(values)
        if values.size == 0 or (values.min() >= 0 and values.max() <= 255):
            return values.astype(np.uint8)
        windowed = np.clip(values.astype(np.float64), low, high)
        return np.rint((windowed - low) * (255.0 / (high - low))).astype(np.uint8)

    if img_path.endswith(".dcm"):
        pixels = dicom.dcmread(img_path).pixel_array
        img = torch.from_numpy(to_uint8(pixels, 0, 2550))
    elif img_path.endswith((".jpg", ".png", ".jpeg")):
        img = read_image(img_path)
        # Drop only a singleton channel axis; a bare squeeze() would also
        # remove a height or width of 1.
        img = torch.squeeze(img, 0)
    elif img_path.endswith((".nii", ".nii.gz")):
        # get_fdata() returns a NumPy array, which has no .permute();
        # convert to a tensor first (same windowing as conv_1img_3D_totensor).
        volume = nib.load(img_path).get_fdata()
        img = torch.from_numpy(to_uint8(volume, -350, 350)).permute(2, 1, 0)
    else:
        return None

    if img.dim() == 2:
        # Replicate the single plane into three channels.
        img = torch.broadcast_to(img, (3,) + tuple(img.shape))

    return img

def _resize_and_normalize():
    """Build a fresh pipeline: resize to 512x512, then min-max normalize."""
    steps = [
        # transforms.CenterCrop(img_size),  # disabled in the val/test variants
        transforms.Resize((512, 512)),
        transforms.Lambda(img_normalize),
    ]
    return transforms.Compose(steps)


# The three splits currently share one recipe; each gets its own Compose
# instance so they can diverge later without affecting one another.
train_transforms = _resize_and_normalize()
val_transforms = _resize_and_normalize()
test_transforms = _resize_and_normalize()


In [14]:
# Dataset root. Set the LIVER_DATA_PATH environment variable to point at a
# different copy; the default is the original local location.
data_path = os.environ.get(
    "LIVER_DATA_PATH",
    "/run/media/alex/PartOfDiskWithWin7/LEARN/Longevity/Liver/datasets/kaggle.com_andrewmvd_liver-tumor-segmentation/",
)
# test_dir = data_path + "test"
# data_path = "/run/media/alex/PartOfDiskWithWin7/LEARN/Longevity/Liver/datasets/wiki.cancerimagingarchive.net/Crowds Cure Cancer: Data collected at the RSNA 2017 annual meeting/manifest-KTt2tScD7164745271364348431/TCGA-LIHC/1-test/"

# Sub-directories of data_path: index 0 holds the targets, index 1 the samples.
sample_targets = ['segmentations', 'volume']
targets_path = os.path.join(data_path, sample_targets[0])
samples_path = os.path.join(data_path, sample_targets[1])
# Sorted listings are paired positionally, so both directories must list
# their files in matching order.
targets = sorted(os.listdir(targets_path))
samples = sorted(os.listdir(samples_path))
st = zip(targets, samples)
batch_size = 4
for ts in st:
    t = conv_1img_3D_totensor(os.path.join(targets_path, ts[0]))
    s = conv_1img_3D_totensor(os.path.join(samples_path, ts[1]))
    if t is None or s is None:
        # The loader returns None for unsupported extensions (stray files).
        continue
    if s.shape[0] < t.shape[0]:
        raise ValueError(
            f"{ts[1]} has fewer slices ({s.shape[0]}) than {ts[0]} ({t.shape[0]})")
    # print(t.shape)
    n_batches = t.shape[0] // batch_size
    # Slices left over after the last full batch; they are not batched here.
    tail = t.shape[0] % batch_size
    if n_batches == 0:
        # Fewer slices than one batch: nothing to build, and printing
        # target/sample would show a previous volume (or raise NameError).
        continue
    for i in range(n_batches):
        # Batch i covers slices [i*batch_size, (i+1)*batch_size). Indexing
        # with i + j made consecutive batches overlap.
        first = i * batch_size
        # (B, H, W) -> (B, 1, H, W) -> (B, 3, H, W) as float32
        target = t[first:first + batch_size].unsqueeze(1).repeat(1, 3, 1, 1).float()
        sample = s[first:first + batch_size].unsqueeze(1).repeat(1, 3, 1, 1).float()
    print(target.shape)
    print(sample.shape)
# print(st, next(st))


torch.Size([4, 3, 512, 512])
torch.Size([4, 3, 512, 512])
torch.Size([4, 3, 512, 512])
torch.Size([4, 3, 512, 512])
torch.Size([4, 3, 512, 512])
torch.Size([4, 3, 512, 512])


KeyboardInterrupt: 

In [25]:
batch_size = 1

# Every supported file below data_path is loaded with conv_1img_3D_totensor
# and then passed through test_transforms.
test_dataset = ImageFolderDCM(
    root=data_path,
    loader=conv_1img_3D_totensor,
    transform=test_transforms,
)

# num_workers is derived from batch_size; with batch_size = 1 it is 0, so
# samples are loaded in the main process.
test_dataloader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    num_workers=batch_size // 2,
    shuffle=False,
)


In [26]:
# Print the DataLoader object itself (its default repr) as a quick check.
print(test_dataloader)


<torch.utils.data.dataloader.DataLoader object at 0x7fc5c934e8e0>


In [27]:
# Walk the loader once, printing each batch's tensor shape and its labels on
# separate lines. (Wrap test_dataloader in tqdm(...) for a progress bar.)
for inputs, labels in test_dataloader:
    print(inputs.shape, labels, sep="\n")


torch.Size([1, 75, 512, 512])
tensor([0])
torch.Size([1, 123, 512, 512])
tensor([0])
torch.Size([1, 501, 512, 512])
tensor([0])


KeyboardInterrupt: 

In [29]:
# Scratch check of torch.unsqueeze: insert a length-1 axis at position 0
# (row vector) and at position 1 (column vector) of a 1-D tensor.
x = torch.tensor([1, 2, 3, 4])
print(x.shape)
for dim in (0, 1):
    expanded = torch.unsqueeze(x, dim)
    print(expanded)
    print(expanded.shape)


torch.Size([4])
tensor([[1, 2, 3, 4]])
torch.Size([1, 4])
tensor([[1],
        [2],
        [3],
        [4]])
torch.Size([4, 1])
