In [5]:
import os

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import nrrd
from IPython.display import Image as show_gif

import matplotlib.pyplot as plt
import matplotlib.animation as anim
import pydicom as pdm

from monai.networks.nets import UNet
from torch.optim.lr_scheduler import ReduceLROnPlateau
from monai.transforms import Compose, ToTensor, Transform, NormalizeIntensity, RandGaussianNoise
from scipy.spatial.distance import directed_hausdorff
from monai.losses import DiceFocalLoss
from sklearn.metrics import jaccard_score
from torchvision.ops import box_iou
from tqdm import tqdm

import yaml 

try:
    from yaml import CLoader as Loader
except ImportError:
    from yaml import Loader


In [None]:
class ImageToGIF:
    """Create GIF without saving image files."""
    def __init__(self,
                 size=(600, 400), 
                 xy_text=(80, 10),
                 dpi=100, 
                 cmap='CMRmap'):

        self.fig = plt.figure()
        self.fig.set_size_inches(size[0] / dpi, size[1] / dpi)
        self.xy_text = xy_text
        self.cmap = cmap
        
        self.ax = self.fig.add_axes([0, 0, 1, 1])
        self.ax.set_xticks([])
        self.ax.set_yticks([])
        self.images = []
 
    def add(self, *args, label, with_mask=True):
        
        image = args[0]
        mask = args[-1]
        plt.set_cmap(self.cmap)
        plt_img = self.ax.imshow(image, animated=True)
        if with_mask:
            plt_mask = self.ax.imshow(np.ma.masked_where(mask == False, mask),
                                      alpha=0.7, animated=True)

        plt_text = self.ax.text(*self.xy_text, label, color='red')
        to_plot = [plt_img, plt_mask, plt_text] if with_mask else [plt_img, plt_text]
        self.images.append(to_plot)
        plt.close()
 
    def save(self, filename, fps):
        animation = anim.ArtistAnimation(self.fig, self.images)
        animation.save(filename, writer='ffmpeg', fps=fps)
        

In [None]:
def read_nrrd_file(path: str, 
                   tensor_shape: tuple ) -> np.ndarray:
    if os.path.exists(path):
        tensor = nrrd.read(path)[0]                             
        tensor = np.flip(tensor, -1)                   # Warning! slice order of images and masks does not match.
    else: 
        tensor = np.zeros(tensor_shape, dtype=np.float32)
    return tensor


def nrrd_to_numpy(id_: str, tensor_shape: tuple):
    '''
    Returns:  all id masks in single numpy tensor.
    '''
    #lung_file_path = 'data/nrrd_lung/nrrd_lung/' + id_ + '_lung.nrrd'
    heart_file_path  = 'data/nrrd_heart/nrrd_heart/' + id_ + '_heart.nrrd'         # Here path hardcoded
    #trachea_file_path = 'data/nrrd_trachea/nrrd_trachea/' + id_ + '_trachea.nrrd'
    #lung_tensor = read_nrrd_file(lung_file_path, tensor_shape)
    heart_tensor = read_nrrd_file(heart_file_path, tensor_shape)
    #trachea_tensor = read_nrrd_file(trachea_file_path, tensor_shape)
    
        
    # now each tensor channel is a mask with a unique label
    #full_mask = np.stack([lung_tensor, heart_tensor, trachea_tensor])

    # reorient the axes from CHWB to BWHC
    heart_tensor = np.moveaxis(heart_tensor,
                            [0, 1, 2],
                            [2, 1, 0]).astype(np.float32)
    

    return heart_tensor

In [None]:
id_ = 'ID00015637202177877247924'

sample_masks = nrrd_to_numpy(id_, (768, 768))
sample_masks.shape

In [None]:
'''
sample_data_gif = ImageToGIF(size=(768, 768),
                             xy_text=(250, 15))

label = 'ID00015637202177877247924'


for i in range(sample_masks.shape[0]):
    sample_data_gif.add(sample_masks[i],label=f'{label}_{str(i)}', with_mask=False)
 
sample_data_gif.save(f'{label}.gif', fps=15)
show_gif(f'{label}.gif', format='png')
'''

In [None]:
label = 'ID00015637202177877247924'

sample_path = 'data/train/' + label
sample_path_files = sorted(os.listdir(sample_path), key=lambda x: int(x[:-4]))

In [None]:
'''
sample_data_gif = ImageToGIF()
labelOUT = label + '_with_masks'
for i in range(sample_masks.shape[0]):
    path = os.path.join(sample_path, sample_path_files[i])
    image = pdm.dcmread(path).pixel_array
    mask = sample_masks[i]
    sample_data_gif.add(image, mask, label=f'{labelOUT}_{str(i)}',)
 
sample_data_gif.save(f'{labelOUT}.gif', fps=15)
show_gif(f'{labelOUT}.gif', format='png')

'''



In [None]:
'''
train_data_paths = pd.read_csv('data/train.csv')
test_data_paths = pd.read_csv('data/test.csv')
train_data_paths.head()
'''



        

In [None]:
segment_ids = os.listdir('data/nrrd_heart/nrrd_heart/')
segment_ids = map(lambda x: x.split('_')[0], segment_ids)
segment_ids = pd.DataFrame(segment_ids, columns=['Patient'])
segment_ids.head()

In [None]:
class HeartDataset(Dataset):
    def __init__(self, 
                 imgs_dir: str,
                 masks_dir:str,
                 df: pd.DataFrame,
                 transform = None,
                 mask_transform = None):
        """Initialization."""
        self.root_imgs_dir = imgs_dir
        self.root_masks_dir = masks_dir
        self.df = df
        
        self.transform = transform
        self.mask_transform = mask_transform
    
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        id = self.df.loc[idx, "Patient"]
        full_image = []
        for image_id in sorted(os.listdir(self.root_imgs_dir + id), key=lambda x: int(x[:-4])):
            image = pdm.dcmread(self.root_imgs_dir + id + '/' + image_id)
            #image.PhotometricInterpretation  = 'YBR_FULL'
            image = image.pixel_array
            full_image.append(image)
            
        full_image = np.array(full_image, dtype=np.float32)
        mask_shape = full_image.shape[1:]
        mask = nrrd_to_numpy(id, mask_shape)
        

        
        if full_image is None or mask is None:
            raise FileNotFoundError(f"Image with id {id} not found.")
        
        if self.transform:
            full_image = self.transform(full_image)
            mask = self.mask_transform(mask)
        else:
            full_image = torch.tensor(full_image, dtype=torch.float32).unsqueeze(0)
            mask = torch.tensor(mask, dtype=torch.float32).unsqueeze(0)
        return full_image, mask

In [None]:
dataset = HeartDataset('data/train/', 'data/nrrd_heart/nrrd_heart/', segment_ids)
item, label = dataset.__getitem__(0)

print(item.shape, label.shape)

del dataset

In [1]:
def get_dataloader(
    imgs_dir: str,
    masks_dir: str,
    batch_size: int = 8,
    test_size: float = 0.2,
    df: pd.DataFrame = segment_ids,
    train_transforms=None,
    test_transforms=None,
    num_workers = 1,
    mask_transforms=None
):
    '''Returns: dataloader for the model training'''
    

    train_df, val_df = train_test_split(df, 
                                          test_size=test_size, 
                                          random_state=69)
    train_df, val_df = train_df.reset_index(drop=True), val_df.reset_index(drop=True)
    
    train_data_set = HeartDataset(imgs_dir, masks_dir, train_df, transform=train_transforms, mask_transform=mask_transforms)
    train_loader = DataLoader(
        train_data_set,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=True,   
        num_workers=num_workers
    )
    
    test_data_set = HeartDataset(imgs_dir, masks_dir, val_df, transform=test_transforms, mask_transform=mask_transforms)
    test_loader = DataLoader(
        test_data_set,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,
        num_workers=num_workers
    )

    return train_loader, test_loader

NameError: name 'segment_ids' is not defined

In [None]:
class Downsample3d(Transform):
    def __init__(self, out_size: tuple):
        super().__init__()
        self.out_size = out_size
        
    def __call__(self, inpt):
        return torch.nn.functional.interpolate(inpt.unsqueeze(0).unsqueeze(0), size=self.out_size, mode='trilinear').squeeze(0)
    
class CenterCrop3d(Transform):
    def __init__(self, roi_size: tuple):
        super().__init__()
        self.roi_size = roi_size
        
    def __call__(self, inpt):
        return inpt[inpt.shape[0]//2 - self.roi_size[0]//2: inpt.shape[0]//2 + self.roi_size[0]//2,
                    inpt.shape[1]//2 - self.roi_size[1]//2: inpt.shape[1]//2 + self.roi_size[1]//2,
                    inpt.shape[2]//2 - self.roi_size[2]//2: inpt.shape[2]//2 + self.roi_size[2]//2]

class RandomCrop3d(Transform):
    def __init__(self, roi_size: tuple):
        super().__init__()
        self.roi_size = roi_size
        
    def __call__(self, inpt):
        x = np.random.randint(0, inpt.shape[0] - self.roi_size[0])
        y = np.random.randint(0, inpt.shape[1] - self.roi_size[1])
        z = np.random.randint(0, inpt.shape[2] - self.roi_size[2])
        return inpt[x: x + self.roi_size[0],
                    y: y + self.roi_size[1],
                    z: z + self.roi_size[2]]


In [None]:

monai_transform = Compose([
    ToTensor(),
    CenterCrop3d((88, 400, 400)),
    RandomCrop3d((80, 350, 350)),
    Downsample3d((80, 256, 256)),
    NormalizeIntensity(),
    RandGaussianNoise(prob=0.3)
])
dataset = HeartDataset('data/train/', 'data/nrrd_heart/nrrd_heart/', segment_ids, transform=monai_transform)
item, label = dataset.__getitem__(2)
item = np.array(item, dtype=np.float32)
print(item.shape, label.shape)
print(np.max(item), np.min(item))

del dataset

In [None]:

train_dataloader, test_dataloader = get_dataloader('data/train/', 'data/nrrd_heart/nrrd_heart/', train_transforms=monai_transform, test_transforms=None)
print(len(train_dataloader), len(test_dataloader))

print(np.array(iter(train_dataloader).__next__()).shape)

del train_dataloader, test_dataloader


In [None]:
index = 2


sample_data_gif = ImageToGIF()
labelOUT = segment_ids.loc[index, "Patient"] + '_with_masks'
print(item.shape, label.shape)
item = item[0]
label = label[0]

print(item.shape, label.shape)
for i in range(item.shape[0]):
    sample_data_gif.add(item[i], label[i], label=f'{labelOUT}_{str(i)}',)
 
print(labelOUT)
sample_data_gif.save(f'{labelOUT}.gif', fps=15)
show_gif(f'{labelOUT}.gif', format='png')

In [4]:


def dice_coef_metric(probabilities: torch.Tensor,
                     truth: torch.Tensor,
                     treshold: float = 0.5,
                     eps: float = 1e-9) -> np.ndarray:
    """
    Calculate Dice score for data batch.
    Params:
        probobilities: model outputs after activation function.
        truth: truth values.
        threshold: threshold for probabilities.
        eps: additive to refine the estimate.
        Returns: dice score aka f1.
    """
    scores = []
    num = probabilities.shape[0]
    predictions = (probabilities >= treshold).float()
    assert(predictions.shape == truth.shape)
    for i in range(num):
        prediction = predictions[i]
        truth_ = truth[i]
        intersection = 2.0 * (truth_ * prediction).sum()
        union = truth_.sum() + prediction.sum()
        if truth_.sum() == 0 and prediction.sum() == 0:
            scores.append(1.0)
        else:
            scores.append((intersection + eps) / union)
    return np.mean(scores)

def hausdorff_distance_metric (probabilities: torch.Tensor,
                                truth: torch.Tensor) -> np.ndarray:
    metric = directed_hausdorff(probabilities, truth)
    return metric

def IOU_metric(probabilities: torch.Tensor,
               truth: torch.Tensor,
               treshold: float = 0.5,
               eps: float = 1e-9) -> np.ndarray:
    box_iou_score = box_iou(probabilities, truth)
    # jaccard_score(truth.flatten(), (probabilities.flatten() > treshold).astype(int))
    return box_iou_score 

class Meter:
    '''factory for storing and updating iou and dice scores.'''
    def __init__(self, treshold: float = 0.5):
        self.threshold: float = treshold
        self.scores: dict = {
            'dice': [],
            'iou': [],
            'hausdorff': []
        }
    
    def update(self, logits: torch.Tensor, targets: torch.Tensor):
        """
        Takes: logits from output model and targets,
        calculates dice and iou scores, and stores them in lists.
        """
        probs = torch.sigmoid(logits)
        dice = dice_coef_metric(probs, targets, self.threshold)
        IOU = IOU_metric(probs, targets, self.threshold)
        Hausdorff = hausdorff_distance_metric(probs, targets)
        self.scores['dice'].append(dice)
        self.scores['iou'].append(IOU)
        self.scores['hausdorff'].append(Hausdorff)
    
    def get_metrics(self):
        """
        Returns: the average of the accumulated dice and iou scores.
        """
        dice = np.mean(self.scores['dice'])
        iou = np.mean(self.scores['iou'])
        hausdorff = np.mean(self.scores['hausdorff'])
        return dice, iou, hausdorff
    

In [8]:
class Trainer:
    def __init__(self,
                 model: nn.Module,
                 optimizer: torch.optim.Optimizer,
                 criterion: nn.Module,
                 train_loader: DataLoader,
                 test_loader: DataLoader,
                 epochs: int = 100,
                 device: str = 'cuda',
                 plot=False,
                 log_dir: str = 'logs',
                 checkpoint_dir: str = 'checkpoints',
                 output_dir: str = 'outputs',
                 checkpoint_interval: int = 10):
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.epochs = epochs
        self.device = device
        self.model.to(self.device)
        
        self.plot = plot
        
        self.losses = {'train': [], 'test': []}
        self.dice_scores = {'train': [], 'test': []}
        self.iou_scores = {'train': [], 'test': []}
        self.hausdorff_scores = {'train': [], 'test': []}
        
        self.log_dir = log_dir
        self.checkpoint_dir = checkpoint_dir
        self.output_dir = output_dir
        
        self.checkpoint_interval = checkpoint_interval
        
        if not os.path.exists(self.log_dir):
            os.makedirs(self.log_dir)
        
        if not os.path.exists(self.checkpoint_dir):
            os.makedirs(self.checkpoint_dir)
            
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)
        
        self.best_test_loss = np.inf
        
    def loss_and_logits(self, images: torch.Tensor, masks: torch.Tensor):
        images = images.to(self.device)
        masks = masks.to(self.device)
        
        logits = self.model(images)
        loss = self.criterion(logits, masks)
        return loss, logits
    
    def next_epoch(self, epoch, test=False):
        self.model.train() if not test else self.model.eval()
        meter = Meter()
        running_loss = 0.0
        
        if not test:
            self.optimizer.zero_grad()
        
        for i, (images, masks) in enumerate(self.train_loader if not test else self.test_loader):
            loss, logits = self.loss_and_logits(images, masks)
            
            if not test:
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()
            
            running_loss += loss.item()
            meter.update(logits.detach().cpu(), masks.detach().cpu())
            
        epoch_loss = running_loss / len(self.train_loader if not test else self.test_loader)
        dice, iou, hausdorff = meter.get_metrics()
        
        self.losses['train' if not test else 'test'].append(epoch_loss)
        self.dice_scores['train' if not test else 'test'].append(dice)
        self.iou_scores['train' if not test else 'test'].append(iou)
        self.hausdorff_scores['train' if not test else 'test'].append(hausdorff)
        
        return epoch_loss, (dice, iou, hausdorff)
    
    def train(self):
        for epoch in tqdm(range(self.epochs)):
            self.next_epoch(epoch, test=False)
            
            with torch.no_grad():
                test_loss, metrics = self.next_epoch(epoch, test=True)
                lr_scheduler.step(test_loss)
            
            if self.plot:
                self.plot_metrics()
            
            if test_loss < self.best_test_loss:
                print(f"Saving best model with test loss: {test_loss:.4f} and metrics: dice - {metrics[0]:.4f} iou - {metrics[1]:.4f} hassdorf - {metrics[2]:.4f} at epoch: {epoch + 1}")
                self.best_test_loss = test_loss
                torch.save(self.model.state_dict(), self.output_dir + "/" + 'best_model.pth')
                
            if (epoch + 1) % self.checkpoint_interval == 0:
                print(f"Saving checkpoint at epoch: {epoch + 1}")
                torch.save(self.model.state_dict(), self.checkpoint_dir + "/" + f'epoch_{epoch + 1}.pth')

        self.save_log() 
        
    def plot_metrics(self):
        plt.figure(figsize=(12, 6))
        plt.subplot(1, 2, 1)
        plt.plot(self.losses['train'], label='train')
        plt.plot(self.losses['test'], label='test')
        plt.title('Loss')
        plt.legend()
        
        plt.subplot(1, 2, 2)
        plt.plot(self.dice_scores['train'], label='train')
        plt.plot(self.dice_scores['test'], label='test')
        plt.title('Dice')
        plt.legend()
        
        
        plt.subplot(2, 2, 1)
        plt.plot(self.iou_scores['train'], label='train')
        plt.plot(self.iou_scores['test'], label='test')
        plt.title('IOU')
        plt.legend()
        
        plt.subplot(2, 2, 2)
        plt.plot(self.hausdorff_scores['train'], label='train')
        plt.plot(self.hausdorff_scores['test'], label='test')
        plt.title('Hausdorff')
        plt.legend()
        
        plt.show()
    
    def save_log(self):
        torch.save(self.model.state_dict(), self.output_dir + "/" + 'last_epoch.pth')
        
        log = pd.DataFrame({
            'train_loss': self.losses['train'],
            'test_loss': self.losses['test'],
            'train_dice': self.dice_scores['train'],
            'test_dice': self.dice_scores['test'],
            'train_iou': self.iou_scores['train'],
            'test_iou': self.iou_scores['test'],
            'train_hausdorff': self.hausdorff_scores['train'],
            'test_hausdorff': self.hausdorff_scores['test']
        })
        log.to_csv(self.log_dir + "/" + 'log.csv', index=False)
    
    
    def load_model(self, path: str):
        print(f"Loading model from {path}")
        self.model.load_state_dict(torch.load(path))
        

In [6]:
config_dict = {}
try:
    with open("configs.yaml", 'r') as stream:
        config_dict = yaml.load(stream, Loader)
except FileNotFoundError:
    print("Config file not found.")

model_dict = config_dict['model']
trainer_dict = config_dict['trainer']
dataset_dict = config_dict['dataset']
optimizer_dict = config_dict['optimizer']
del config_dict

In [None]:
model = UNet(
    spatial_dims= model_dict['spatial_dims'],
    in_channels= model_dict['in_channels'],
    out_channels= model_dict['out_channels'],
    channels= model_dict['channels'],
    strides= model_dict['strides'],
    num_res_units= model_dict['num_res_units'],
)

train_transforms = Compose([
    ToTensor(),
    CenterCrop3d((88, 400, 400)),
    RandomCrop3d((80, 350, 350)),
    Downsample3d((80, 256, 256)),
    NormalizeIntensity(),
    RandGaussianNoise(prob=0.3)
])

test_transforms = Compose([
    ToTensor(),
    CenterCrop3d((88, 400, 400)),
    RandomCrop3d((80, 350, 350)),
    Downsample3d((80, 256, 256)),
    NormalizeIntensity(),
])

mask_transforms = Compose([
    ToTensor(),
    CenterCrop3d((88, 400, 400)),
    RandomCrop3d((80, 350, 350)),
    Downsample3d((80, 256, 256)),
])

train_dataloader, test_dataloader = get_dataloader(dataset_dict['image_path'], dataset_dict['mask_path'], train_transforms=train_transforms, test_transforms=test_transforms, mask_transforms=mask_transforms, batch_size=trainer_dict['batch_size'], num_workers=trainer_dict['num_workers'])

if optimizer_dict['name'] == 'Adam':
    optimizer = torch.optim.Adam(model.parameters(), optimizer_dict['params']['lr'])
else:
    optimizer = torch.optim.SGD(model.parameters(), optimizer_dict['params']['lr'], momentum=optimizer_dict['momentum'])


lr_scheduler = ReduceLROnPlateau(optimizer, optimizer_dict['scheduler']['params']['mode'], factor=optimizer_dict['scheduler']['params']['factor'], patience=optimizer_dict['scheduler']['params']['patience'])

criterion = DiceFocalLoss()

device = 'cuda' if torch.cuda.is_available() else 'cpu'

print("running on device: ", device)

trainer = Trainer(model, optimizer, criterion, train_dataloader, test_dataloader, epochs=trainer_dict['epochs'], plot=trainer_dict['plot'], device=device, log_dir=trainer_dict['log_dir'], checkpoint_dir=trainer_dict['checkpoint_dir'], output_dir=trainer_dict['output_dir'], checkpoint_interval=trainer_dict['checkpoint_interval'])


In [None]:
%%time
trainer.train()