In [51]:
import os

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import nrrd
from IPython.display import Image as show_gif

import matplotlib.pyplot as plt
import matplotlib.animation as anim
import pydicom as pdm


In [2]:
class ImageToGIF:
    """Create GIF without saving image files."""
    def __init__(self,
                 size=(600, 400), 
                 xy_text=(80, 10),
                 dpi=100, 
                 cmap='CMRmap'):

        self.fig = plt.figure()
        self.fig.set_size_inches(size[0] / dpi, size[1] / dpi)
        self.xy_text = xy_text
        self.cmap = cmap
        
        self.ax = self.fig.add_axes([0, 0, 1, 1])
        self.ax.set_xticks([])
        self.ax.set_yticks([])
        self.images = []
 
    def add(self, *args, label, with_mask=True):
        
        image = args[0]
        mask = args[-1]
        plt.set_cmap(self.cmap)
        plt_img = self.ax.imshow(image, animated=True)
        if with_mask:
            plt_mask = self.ax.imshow(np.ma.masked_where(mask == False, mask),
                                      alpha=0.7, animated=True)

        plt_text = self.ax.text(*self.xy_text, label, color='red')
        to_plot = [plt_img, plt_mask, plt_text] if with_mask else [plt_img, plt_text]
        self.images.append(to_plot)
        plt.close()
 
    def save(self, filename, fps):
        animation = anim.ArtistAnimation(self.fig, self.images)
        animation.save(filename, writer='ffmpeg', fps=fps)
        

In [3]:
def read_nrrd_file(path: str, 
                   tensor_shape: tuple ) -> np.ndarray:
    if os.path.exists(path):
        tensor = nrrd.read(path)[0]                             
        tensor = np.flip(tensor, -1)                   # Warning! slice order of images and masks does not match.
    else: 
        tensor = np.zeros(tensor_shape, dtype=np.float32)
    return tensor


def nrrd_to_numpy(id_: str, tensor_shape: tuple):
    '''
    Returns:  all id masks in single numpy tensor.
    '''
    #lung_file_path = 'data/nrrd_lung/nrrd_lung/' + id_ + '_lung.nrrd'
    heart_file_path  = 'data/nrrd_heart/nrrd_heart/' + id_ + '_heart.nrrd'         # Here path hardcoded
    #trachea_file_path = 'data/nrrd_trachea/nrrd_trachea/' + id_ + '_trachea.nrrd'
    #lung_tensor = read_nrrd_file(lung_file_path, tensor_shape)
    heart_tensor = read_nrrd_file(heart_file_path, tensor_shape)
    #trachea_tensor = read_nrrd_file(trachea_file_path, tensor_shape)
    
        
    # now each tensor channel is a mask with a unique label
    #full_mask = np.stack([lung_tensor, heart_tensor, trachea_tensor])

    # reorient the axes from CHWB to BWHC
    heart_tensor = np.moveaxis(heart_tensor,
                            [0, 1, 2],
                            [2, 1, 0]).astype(np.float32)
    

    return heart_tensor

In [4]:
id_ = 'ID00015637202177877247924'

sample_masks = nrrd_to_numpy(id_, (768, 768))
sample_masks.shape

(295, 768, 768)

In [5]:
'''
sample_data_gif = ImageToGIF(size=(768, 768),
                             xy_text=(250, 15))

label = 'ID00015637202177877247924'


for i in range(sample_masks.shape[0]):
    sample_data_gif.add(sample_masks[i],label=f'{label}_{str(i)}', with_mask=False)
 
sample_data_gif.save(f'{label}.gif', fps=15)
show_gif(f'{label}.gif', format='png')
'''

"\nsample_data_gif = ImageToGIF(size=(768, 768),\n                             xy_text=(250, 15))\n\nlabel = 'ID00015637202177877247924'\n\n\nfor i in range(sample_masks.shape[0]):\n    sample_data_gif.add(sample_masks[i],label=f'{label}_{str(i)}', with_mask=False)\n \nsample_data_gif.save(f'{label}.gif', fps=15)\nshow_gif(f'{label}.gif', format='png')\n"

In [6]:
label = 'ID00015637202177877247924'

sample_path = 'data/train/' + label
sample_path_files = sorted(os.listdir(sample_path), key=lambda x: int(x[:-4]))

In [7]:
'''
sample_data_gif = ImageToGIF()
labelOUT = label + '_with_masks'
for i in range(sample_masks.shape[0]):
    path = os.path.join(sample_path, sample_path_files[i])
    image = pdm.dcmread(path).pixel_array
    mask = sample_masks[i]
    sample_data_gif.add(image, mask, label=f'{labelOUT}_{str(i)}',)
 
sample_data_gif.save(f'{labelOUT}.gif', fps=15)
show_gif(f'{labelOUT}.gif', format='png')

'''



"\nsample_data_gif = ImageToGIF()\nlabelOUT = label + '_with_masks'\nfor i in range(sample_masks.shape[0]):\n    path = os.path.join(sample_path, sample_path_files[i])\n    image = pdm.dcmread(path).pixel_array\n    mask = sample_masks[i]\n    sample_data_gif.add(image, mask, label=f'{labelOUT}_{str(i)}',)\n \nsample_data_gif.save(f'{labelOUT}.gif', fps=15)\nshow_gif(f'{labelOUT}.gif', format='png')\n\n"

In [8]:
'''
train_data_paths = pd.read_csv('data/train.csv')
test_data_paths = pd.read_csv('data/test.csv')
train_data_paths.head()
'''



        

"\ntrain_data_paths = pd.read_csv('data/train.csv')\ntest_data_paths = pd.read_csv('data/test.csv')\ntrain_data_paths.head()\n"

In [9]:
segment_ids = os.listdir('data/nrrd_heart/nrrd_heart/')
segment_ids = map(lambda x: x.split('_')[0], segment_ids)
segment_ids = pd.DataFrame(segment_ids, columns=['Patient'])
segment_ids.head()

Unnamed: 0,Patient
0,ID00007637202177411956430
1,ID00009637202177434476278
2,ID00010637202177584971671
3,ID00012637202177665765362
4,ID00014637202177757139317


In [48]:
class HeartDataset(Dataset):
    def __init__(self, 
                 imgs_dir: str,
                 masks_dir:str,
                 df: pd.DataFrame,
                 transform = None):
        """Initialization."""
        self.root_imgs_dir = imgs_dir
        self.root_masks_dir = masks_dir
        self.df = df
        
        self.transform = transform
    
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        id = self.df.loc[idx, "Patient"]
        full_image = []
        for image_id in sorted(os.listdir(self.root_imgs_dir + id), key=lambda x: int(x[:-4])):
            image = pdm.dcmread(self.root_imgs_dir + id + '/' + image_id)
            #image.PhotometricInterpretation  = 'YBR_FULL'
            image = image.pixel_array
            full_image.append(image)
            
        full_image = np.array(full_image, dtype=np.float32)
        mask_shape = full_image.shape[1:]
        mask = nrrd_to_numpy(id, mask_shape)
        

        
        if full_image is None or mask is None:
            raise FileNotFoundError(f"Image with id {id} not found.")
        
        
        if full_image.shape[0] < 200:
            full_image = np.pad(full_image, ((0, 100 - full_image.shape[0]), (0, 0), (0, 0)), mode='constant')
            mask = np.pad(mask, ((0, 100 - mask.shape[0]), (0, 0), (0, 0)), mode='constant')
        
        if full_image.shape[0] > 200:
            full_image = full_image[:100]
            mask = mask[:100]

        if full_image.shape[1] < 512:
            full_image = np.pad(full_image, ((0, 0), (0, 512 - full_image.shape[1]), (0, 0)), mode='constant')
            mask = np.pad(mask, ((0, 0), (0, 512 - mask.shape[1]), (0, 0)), mode='constant')
        
        if full_image.shape[1] > 512:
            full_image = full_image[:, :512, :512]
            mask = mask[:, :512, :512]
        
        full_image = self.transform(full_image)
        mask = self.transform(mask)
            
        
        return full_image, mask

In [33]:
dataset = HeartDataset('data/train/', 'data/nrrd_heart/nrrd_heart/', segment_ids)
item, label = dataset.__getitem__(0)

print(item.shape, label.shape)

del dataset

(512, 512, 30) (512, 512, 30)


TypeError: 'NoneType' object is not callable

In [17]:
def get_dataloader(
    imgs_dir: str,
    masks_dir: str,
    batch_size: int = 8,
    test_size: float = 0.2,
    df: pd.DataFrame = segment_ids,
    train_transforms=None,
    test_transforms=None,
):
    '''Returns: dataloader for the model training'''
    

    train_df, val_df = train_test_split(df, 
                                          test_size=test_size, 
                                          random_state=69)
    train_df, val_df = train_df.reset_index(drop=True), val_df.reset_index(drop=True)
    
    print(f"Train: {len(train_df)}")
    print(f"Val: {len(val_df)}")
    print(train_df.head())
    print(val_df.head())
    train_data_set = HeartDataset(imgs_dir, masks_dir, train_df, transform=train_transforms)
    train_loader = DataLoader(
        train_data_set,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=True,   
    )
    
    test_data_set = HeartDataset(imgs_dir, masks_dir, val_df, transform=test_transforms)
    test_loader = DataLoader(
        test_data_set,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,
    )

    return train_loader, test_loader

In [49]:

from torchvision.transforms.v2 import Compose, Normalize, ToDtype, RandomVerticalFlip

train_transforms = Compose([
    Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225)),
    ToDtype(torch.float32, scale=True),
])

test_transforms = Compose([
    Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225)),
    ToDtype(torch.float32, scale=True),
])
train_dataloader, test_dataloader = get_dataloader('data/train/', 'data/nrrd_heart/nrrd_heart/', train_transforms=train_transforms, test_transforms=test_transforms)
item, label = next(iter(train_dataloader))
print(item.shape, label.shape)


Train: 69
Val: 18
                     Patient
0  ID00398637202303897337979
1  ID00009637202177434476278
2  ID00411637202309374271828
3  ID00417637202310901214011
4  ID00123637202217151272140
                     Patient
0  ID00072637202198161894406
1  ID00360637202295712204040
2  ID00323637202285211956970
3  ID00358637202295388077032
4  ID00032637202181710233084
(240, 768, 768) (240, 768, 768)
(64, 512, 512) (64, 512, 512)
(278, 512, 512) (278, 512, 512)
(233, 512, 512) (233, 512, 512)
(355, 512, 512) (355, 512, 512)
(33, 512, 512) (33, 512, 512)
(302, 512, 512) (302, 512, 512)
(32, 512, 512) (32, 512, 512)
torch.Size([8, 100, 512, 512]) torch.Size([8, 100, 512, 512])


In [None]:
index = 2


sample_data_gif = ImageToGIF()
labelOUT = segment_ids.loc[index, "Patient"] + '_with_masks'
for i in range(item.shape[0]):
    sample_data_gif.add(item[i], label[i], label=f'{labelOUT}_{str(i)}',)
 
print(labelOUT)
sample_data_gif.save(f'{labelOUT}.gif', fps=15)
show_gif(f'{labelOUT}.gif', format='png')

In [53]:
def dice_coef_metric(probabilities: torch.Tensor,
                     truth: torch.Tensor,
                     treshold: float = 0.5,
                     eps: float = 1e-9) -> np.ndarray:
    """
    Calculate Dice score for data batch.
    Params:
        probobilities: model outputs after activation function.
        truth: truth values.
        threshold: threshold for probabilities.
        eps: additive to refine the estimate.
        Returns: dice score aka f1.
    """
    scores = []
    num = probabilities.shape[0]
    predictions = (probabilities >= treshold).float()
    assert(predictions.shape == truth.shape)
    for i in range(num):
        prediction = predictions[i]
        truth_ = truth[i]
        intersection = 2.0 * (truth_ * prediction).sum()
        union = truth_.sum() + prediction.sum()
        if truth_.sum() == 0 and prediction.sum() == 0:
            scores.append(1.0)
        else:
            scores.append((intersection + eps) / union)
    return np.mean(scores)


def jaccard_coef_metric(probabilities: torch.Tensor,
               truth: torch.Tensor,
               treshold: float = 0.5,
               eps: float = 1e-9) -> np.ndarray:
    """
    Calculate Jaccard index for data batch.
    Params:
        probobilities: model outputs after activation function.
        truth: truth values.
        threshold: threshold for probabilities.
        eps: additive to refine the estimate.
        Returns: jaccard score aka iou."
    """
    scores = []
    num = probabilities.shape[0]
    predictions = (probabilities >= treshold).float()
    assert(predictions.shape == truth.shape)

    for i in range(num):
        prediction = predictions[i]
        truth_ = truth[i]
        intersection = (prediction * truth_).sum()
        union = (prediction.sum() + truth_.sum()) - intersection + eps
        if truth_.sum() == 0 and prediction.sum() == 0:
            scores.append(1.0)
        else:
            scores.append((intersection + eps) / union)
    return np.mean(scores)


class Meter:
    '''factory for storing and updating iou and dice scores.'''
    def __init__(self, treshold: float = 0.5):
        self.threshold: float = treshold
        self.dice_scores: list = []
        self.iou_scores: list = []
    
    def update(self, logits: torch.Tensor, targets: torch.Tensor):
        """
        Takes: logits from output model and targets,
        calculates dice and iou scores, and stores them in lists.
        """
        probs = torch.sigmoid(logits)
        dice = dice_coef_metric(probs, targets, self.threshold)
        iou = jaccard_coef_metric(probs, targets, self.threshold)
        
        self.dice_scores.append(dice)
        self.iou_scores.append(iou)
    
    def get_metrics(self):
        """
        Returns: the average of the accumulated dice and iou scores.
        """
        dice = np.mean(self.dice_scores)
        iou = np.mean(self.iou_scores)
        return dice, iou
    

class DiceLoss(nn.Module):
    """Calculate dice loss."""
    def __init__(self, eps: float = 1e-9):
        super(DiceLoss, self).__init__()
        self.eps = eps
        
    def forward(self,
                logits: torch.Tensor,
                targets: torch.Tensor) -> torch.Tensor:
        
        num = targets.size(0)
        probability = torch.sigmoid(logits)
        probability = probability.view(num, -1)
        targets = targets.view(num, -1)
        assert(probability.shape == targets.shape)
        
        intersection = 2.0 * (probability * targets).sum()
        union = probability.sum() + targets.sum()
        dice_score = (intersection + self.eps) / union
        #print("intersection", intersection, union, dice_score)
        return 1.0 - dice_score
        
        
class BCEDiceLoss(nn.Module):
    """Compute objective loss: BCE loss + DICE loss."""
    def __init__(self):
        super(BCEDiceLoss, self).__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = DiceLoss()
        
    def forward(self, 
                logits: torch.Tensor,
                targets: torch.Tensor) -> torch.Tensor:
        assert(logits.shape == targets.shape)
        dice_loss = self.dice(logits, targets)
        bce_loss = self.bce(logits, targets)
        
        return bce_loss + dice_loss

In [ ]:
model = Unet('efficientnet-b2', encoder_weights="imagenet", classes=3, activation=None)
