In [1]:
%config Completer.use_jedi = False

In [2]:
import copy
import os
import warnings
from pathlib import Path

import albumentations
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from skimage import io
from tqdm.notebook import tqdm

warnings.filterwarnings('ignore')

In [3]:
import sys
sys.path.append('../')

from utils.general import *
import utils.dataload as d
from models import model_selector
from utils.data_augmentation import data_augmentation_selector
from medpy.metric.binary import hd, dc, jc, assd

from utils.neural import *
from utils.datasets import *
from utils.metrics import *

In [4]:
def find_path(directory, filename):
    """Recursively search ``directory`` for ``filename`` and return the first hit.

    ``filename`` is passed to ``Path.rglob``, so it may also be a glob pattern.
    The "first" hit is whatever ``rglob`` yields first; its traversal order is
    not sorted.

    Returns:
        The matching ``Path``, or ``None`` when nothing matches.
    """
    candidates = Path(directory).rglob(filename)
    return next(candidates, None)

In [5]:
def save_pred(image, pred_mask, case, out_dir="acdc2mca_preds_overlays"):
    """Save a two-panel JPEG: the input image and the image with its mask overlaid.

    The left panel shows ``image`` in grayscale. The right panel shows the same
    image with ``pred_mask`` drawn on top. Zero-valued mask pixels are masked
    out, so only non-zero labels are coloured.

    Args:
        image: 2D array displayed with the ``gray`` colormap.
        pred_mask: 2D label mask aligned with ``image``; 0 is treated as background.
        case: identifier used as the figure title and as the output file stem.
        out_dir: directory the ``{case}.jpg`` file is written to. It is created
            if missing; the default keeps the original location.
    """
    fig, (ax_input, ax_overlay) = plt.subplots(1, 2, figsize=(17, 10))
    try:
        fig.tight_layout(pad=3)  # Set spacing between plots

        ax_input.imshow(image, cmap="gray")
        ax_input.axis("off")
        ax_input.set_title("Input Image")

        # Hide background (label 0) so only predicted structures are tinted.
        masked_labels = np.ma.masked_where(pred_mask == 0, pred_mask)
        ax_overlay.imshow(image, cmap="gray")
        ax_overlay.imshow(masked_labels, 'hsv', interpolation='bilinear', alpha=0.33)
        ax_overlay.axis("off")
        ax_overlay.set_title("Automatic Segmentation")

        fig.suptitle(case, y=0.9)
        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(os.path.join(out_dir, f"{case}.jpg"), dpi=300)
    finally:
        # Close this specific figure even if saving fails. This is called in a loop,
        # and unclosed figures stay registered in pyplot and accumulate memory.
        plt.close(fig)

In [6]:
# Segmentation network restored from the SWA checkpoint trained on ACDC.
checkpoint_path = "../checks/acdc_model_resnet34_unet_imagenet_encoder_40epochs_swalr0.00256.pt"
model = model_selector(
    "segmentation",
    "resnet34_unet_imagenet_encoder",
    num_classes=4,
    in_channels=3,
    from_swa=True,
    checkpoint=checkpoint_path,
    devices=[0],
)


--- Frosted pretrained backbone! ---
Model total number of parameters: 36290058
Loaded model from checkpoint: ../checks/acdc_model_resnet34_unet_imagenet_encoder_40epochs_swalr0.00256.pt


In [7]:
# The selector returns three transforms. Only the last one is kept; the inference
# loop below uses it. The first two are discarded.
aug_size = 224
_, _, val_aug = data_augmentation_selector("acdc172d", aug_size, aug_size, "padd")

Using LVSC 2D Segmentation Data Augmentation Combinations
Padding masks!
Padding masks!


In [8]:
# Root directory (relative to the notebook) that is searched recursively for .nii.gz volumes.
data_root = "../data"
mca_path = f"{data_root}/MCA/"

In [9]:
# Collect every NIfTI volume under mca_path. os.walk yields entries in arbitrary,
# filesystem-dependent order. Sorting makes the processing order reproducible,
# and with it which slices consume the overlay budget in the inference cell.
mca_volume_paths = sorted(
    os.path.join(root, file_name)
    for root, _dirs, files in os.walk(mca_path)
    for file_name in files
    if file_name.endswith(".nii.gz")
)
# os.walk silently yields nothing for a missing directory; fail loudly instead of
# running inference over zero volumes.
assert mca_volume_paths, f"No .nii.gz volumes found under {mca_path!r}"

In [10]:
# Inference over every volume: predict a label mask for each phase, save the stacked
# mask as NIfTI, and save overlay figures for the first `plot_preds` slices seen.
add_depth = True  # expand slices with d.add_volume_depth_channels (model is built with in_channels=3)
preds_dir = "acdc2mca_preds"
os.makedirs(preds_dir, exist_ok=True)

# Total overlay figures to save across ALL volumes/phases/slices; 0 disables plotting.
plot_preds = 100


def predict_phase(net, phase_volume, out_hw, transform, with_depth):
    """Segment every slice of one phase and return the label mask.

    Args:
        net: segmentation model; must be in eval mode and run under ``torch.no_grad``.
        phase_volume: array of shape (height, width, slices), i.e. one phase
            of the loaded 4D volume.
        out_hw: (height, width) the predicted mask is reshaped back to.
        transform: augmentation passed to ``d.apply_volume_2Daugmentations``.
        with_depth: whether to apply ``d.add_volume_depth_channels`` to the batch.

    Returns:
        uint8 mask with one entry per slice along axis 0, resized to ``out_hw``.
    """
    slices = phase_volume.transpose(2, 0, 1)  # (H, W, S) -> (S, H, W)
    slices, _ = d.apply_volume_2Daugmentations(slices, transform, [])
    slices = d.apply_normalization(slices, "standardize")

    # Stack the volume as a batch: one slice per batch element with one channel,
    # (S, 1, H', W'). This assumes the augmentation keeps the slice axis first (TODO confirm).
    # The same layout is used with or without depth channels.
    batch = torch.from_numpy(slices).unsqueeze(1)
    if with_depth:
        batch = d.add_volume_depth_channels(batch)

    # Put the input on the model's device and in float32. Otherwise this relies on
    # a wrapper to move it, and on the numpy dtype happening to match the weights.
    device = next(net.parameters()).device
    logits = net(batch.float().to(device))

    mask = convert_multiclass_mask(logits).detach().cpu().numpy()
    mask = reshape_volume(mask, out_hw, "padd")
    return mask.astype(np.uint8)


model.eval()
with torch.no_grad():
    for volume_path in tqdm(mca_volume_paths):
        volume, affine, header = d.load_nii(volume_path)  # [height, width, slices, phases]
        # Untouched copy used for the overlays, independent of any preprocessing.
        original_volume = copy.deepcopy(volume)
        h, w, _, _ = original_volume.shape
        patient = os.path.basename(volume_path)  # portable, unlike split("/")
        case_name = patient[:-len(".nii.gz")]  # all paths were filtered on this suffix

        full_mask = []
        for c_phase in range(volume.shape[3]):
            pred_mask = predict_phase(model, volume[..., c_phase], (h, w), val_aug, add_depth)
            full_mask.append(pred_mask)

            for pred_indx, pred_slice in enumerate(pred_mask):
                if plot_preds <= 0:
                    break  # overlay budget exhausted; no need to scan remaining slices
                save_pred(
                    original_volume[..., pred_indx, c_phase], pred_slice,
                    f"{case_name}_phase{c_phase}_slice{pred_indx}",
                )
                plot_preds -= 1

        # (phases, slices, h, w) -> (h, w, slices, phases): same axis order as the input.
        full_mask = np.array(full_mask).transpose(2, 3, 1, 0)
        d.save_nii(os.path.join(preds_dir, f"pred_{patient}"), full_mask, affine, header)

  0%|          | 0/82 [00:00<?, ?it/s]