<a href="https://colab.research.google.com/github/MichalGrzebyk/Medical_segmentation/blob/master/Medical_segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive; model checkpoints, the trained model and the predictions
# produced further down are all written below /content/drive/My Drive.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
#preparing directories for .png files
!mkdir /content/train
!mkdir /content/train/data
!mkdir /content/train/data/1
!mkdir /content/train/mask
!mkdir /content/train/mask/1
!mkdir /content/val
!mkdir /content/val/data
!mkdir /content/val/data/1
!mkdir /content/val/mask
!mkdir /content/val/mask/1

In [None]:
!pip install --upgrade nibabel

In [None]:
#!wget "DATASET LINK"  -O public.zip
#!unzip -q public.zip
#!rm public.zip

In [None]:
# Install segmentation_models, which provides Unet, dice_loss and the metrics imported below.
!pip install segmentation_models
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
import requests
import zlib
import tensorflow as tf
from keras.callbacks import ModelCheckpoint
from sklearn.preprocessing import MinMaxScaler
from segmentation_models.losses import dice_loss
from segmentation_models import Unet
from segmentation_models import get_preprocessing
from segmentation_models.metrics import iou_score, f1_score
from typing import Tuple, List
from pathlib import Path
import cv2


In [None]:
def load_raw_volume(path: Path) -> Tuple[np.ndarray, np.ndarray]:
    """Load a NIfTI volume reoriented to the closest canonical orientation.

    Args:
        path: Location of the NIfTI file (anything accepted by ``str()``).

    Returns:
        ``(voxels, affine)``: the voxel data as a float32 array and the affine
        of the reoriented image.
    """
    canonical = nib.as_closest_canonical(nib.load(str(path)))
    voxels = canonical.get_fdata(caching='unchanged', dtype=np.float32)
    return voxels, canonical.affine


def load_labels_volume(path: Path) -> np.ndarray:
    """Load a label (mask) NIfTI volume as a uint8 array.

    ``get_fdata`` returns header-scaled floating point values, so a stored
    label of 1 may come back as e.g. 0.99999. Values are therefore rounded to
    the nearest integer before the cast; plain truncation would drop them to 0.
    """
    voxels, _ = load_raw_volume(path)
    return np.rint(voxels).astype(np.uint8)


def save_labels(data: np.ndarray, affine: np.ndarray, path: Path):
    """Write *data* to *path* as a NIfTI-1 image carrying the given affine."""
    image = nib.Nifti1Image(data, affine)
    nib.save(image, str(path))


def show_slices(slices: List[np.ndarray]):
    """Plot 2D slices side by side in a single row of grayscale images.

    Each slice is transposed and drawn with ``origin="lower"``.
    """
    # squeeze=False keeps `axes` a 2D array even for a single slice; with the
    # default, plt.subplots(1, 1) returns a bare Axes that cannot be indexed.
    fig, axes = plt.subplots(1, len(slices), squeeze=False)
    for ax, data_slice in zip(axes[0], slices):
        ax.imshow(data_slice.T, cmap="gray", origin="lower")

In [None]:
# Extract 2D .png slices from the 3D scans (the training data comes from two different datasets)
def _export_slice_pairs(image: np.ndarray, mask: np.ndarray, stem: str, is_train: bool):
    """Save every axis-1 slice of *image* and *mask* as a pair of 256x256 grayscale PNGs.

    Args:
        image: 3D scan volume.
        mask: 3D label volume; assumed to have the same shape as *image* (not checked).
        stem: File-name prefix. It must be unique per volume so that slices of
            different volumes never overwrite each other.
        is_train: Write into the /content/train folders when True, otherwise
            into the /content/val folders.
    """
    split = 'train' if is_train else 'val'
    data_dir = '/content/%s/data/1/' % split
    mask_dir = '/content/%s/mask/1/' % split
    for y_index in range(image.shape[1]):
        data_slice = cv2.resize(image[:, y_index], (256, 256))
        # Nearest-neighbour keeps label values intact; bilinear interpolation
        # would invent intermediate values along label boundaries.
        mask_slice = cv2.resize(mask[:, y_index], (256, 256), interpolation=cv2.INTER_NEAREST)
        file_name = '%s%04d.png' % (stem, y_index)
        # origin='lower' makes imsave store the rows in reverse order. Data and
        # mask get the same flip, so the pairs stay aligned, but inference has
        # to flip its input slices the same way.
        plt.imsave(data_dir + file_name, data_slice, format='png', cmap='gray', origin='lower')
        plt.imsave(mask_dir + file_name, mask_slice, format='png', cmap='gray', origin='lower')


def predata():
    """Export the training scans of both datasets as 2D PNG image/mask pairs.

    The split is done per scan: the first 90% of each dataset's scans (in
    sorted order) go to /content/train and the rest to /content/val, so the
    slices of one scan never end up in both splits.
    """
    inputs = ['FirstDataset/train/', 'SecondDataset/train/']

    # SecondDataset: one sub-directory per scan; its first sorted entry is the
    # image and its second entry the mask.
    scan_dirs = sorted(Path(inputs[1]).iterdir())
    for num, scan_dir in enumerate(scan_dirs):
        entries = sorted(scan_dir.iterdir())
        image, _ = load_raw_volume(entries[0])
        mask = load_labels_volume(entries[1])
        # The per-dataset tag ('2' here, '1' below) keeps file names unique
        # across datasets; with shared names the later dataset would overwrite
        # the earlier one's slices.
        _export_slice_pairs(image, mask, 'x2%04d' % num, num / len(scan_dirs) < 0.9)

    # FirstDataset: flat directory where <name>.nii.gz has its mask stored
    # next to it as <name>_mask.nii.gz.
    scans = [p for p in sorted(Path(inputs[0]).iterdir()) if not p.name.endswith('mask.nii.gz')]
    for num, scan in enumerate(scans):
        image, _ = load_raw_volume(scan)
        mask = load_labels_volume(scan.with_name(scan.name.replace('.nii.gz', '_mask.nii.gz')))
        _export_slice_pairs(image, mask, 'x1%04d' % num, num / len(scans) < 0.9)


predata()

In [None]:
def from_directory_datagen():
    """Create the image/mask iterators over the exported PNG folders.

    Every iterator yields batches of 32 RGB images resized to 256x256 and
    scaled by 1/255, without labels (class_mode=None). Image and mask
    iterators use the same seed, so with identical file names in both folders
    they shuffle in the same order and yield aligned batches.

    Returns:
        (train_images, train_masks, val_images, val_masks)
    """
    def new_datagen():
        return tf.keras.preprocessing.image.ImageDataGenerator(rescale=1 / 255)

    def flow(datagen, directory):
        return datagen.flow_from_directory(
            directory,
            batch_size=32,
            seed=42,
            target_size=(256, 256),
            class_mode=None,
            color_mode='rgb',
        )

    images_datagen = new_datagen()
    mask_datagen = new_datagen()

    train_images = flow(images_datagen, '/content/train/data/')
    train_masks = flow(mask_datagen, '/content/train/mask/')
    val_images = flow(images_datagen, '/content/val/data/')
    val_masks = flow(mask_datagen, '/content/val/mask/')
    return train_images, train_masks, val_images, val_masks

In [None]:
def train_net():
    """Train a U-Net with a ResNet50 encoder on the exported PNG slices.

    Whenever the validation F1 score improves, a checkpoint is written to
    Drive. The final model is saved to '/content/drive/My Drive/unet2.h5'.
    """
    with tf.device("/gpu:0"):
        backbone = 'resnet50'
        preprocess_input = get_preprocessing(backbone)

        x_train, y_train, x_val, y_val = from_directory_datagen()

        model = Unet(backbone, encoder_weights='imagenet', input_shape=(256, 256, 3))
        # `learning_rate` is the supported argument name; the old `lr` alias
        # is deprecated and rejected by recent TensorFlow releases.
        model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3), loss=dice_loss,
                      metrics=[f1_score, iou_score])

        check_point = [ModelCheckpoint('/content/drive/My Drive/model-{epoch:03d}-{val_f1-score:03f}.h5', verbose=1,
                                       monitor='val_f1-score',
                                       save_best_only=True, mode='max')]

        # Backbone preprocessing operates on image arrays, so apply it to each
        # image batch rather than to the DirectoryIterator objects.
        train_pairs = ((preprocess_input(images), masks) for images, masks in zip(x_train, y_train))
        val_pairs = ((preprocess_input(images), masks) for images, masks in zip(x_val, y_val))

        # `shuffle` is not passed: Keras ignores it for generator input;
        # shuffling is done by the directory iterators themselves.
        model.fit(
            x=train_pairs,
            epochs=10,
            steps_per_epoch=x_train.n // x_train.batch_size,
            validation_data=val_pairs,
            validation_steps=x_val.n // x_val.batch_size,
            verbose=1,
            callbacks=check_point,
        )
        model.save('/content/drive/My Drive/unet2.h5')


train_net()

In [None]:
def _segment_volume(model, volume: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Segment a 3D scan slice by slice along axis 1.

    Each slice is prepared like the training PNGs written by predata() and
    read back by from_directory_datagen():
    - resized to 256x256;
    - min-max normalised over the whole slice to [0, 1], as plt.imsave does
      before the 1/255 rescale restores that range;
    - flipped vertically, because imsave(..., origin='lower') stores rows in
      reverse order;
    - replicated to 3 channels.

    Returns:
        A uint8 volume with the shape of *volume*: 1 where the predicted
        probability is >= *threshold*, 0 elsewhere.
    """
    x_size, y_size, z_size = volume.shape
    batch = np.empty((y_size, 256, 256, 3), dtype=np.float32)
    for y_index in range(y_size):
        data_slice = cv2.resize(volume[:, y_index, :], (256, 256))
        low, high = data_slice.min(), data_slice.max()
        if high > low:
            data_slice = (data_slice - low) / (high - low)
        else:
            # matplotlib's normalisation maps a constant slice to 0 (black).
            data_slice = np.zeros_like(data_slice)
        # Broadcasting the trailing axis fills all 3 channels with the slice.
        batch[y_index] = data_slice[::-1, :, None]

    # One batched predict call per volume instead of one call per slice.
    probabilities = model.predict(batch, batch_size=32)

    labels = np.zeros(volume.shape, dtype=np.uint8)
    for y_index in range(y_size):
        mask = (probabilities[y_index, :, :, 0] >= threshold).astype(np.uint8)
        # Undo the vertical flip, then resize back with nearest-neighbour so
        # the mask stays strictly binary (bilinear values would be truncated
        # by the uint8 volume and erode the boundaries).
        mask = np.ascontiguousarray(mask[::-1])
        labels[:, y_index, :] = cv2.resize(mask, (z_size, x_size), interpolation=cv2.INTER_NEAREST)
    return labels


def predict_3D():
    """Predict masks for the test scans of both datasets and save them to Drive.

    Predictions are written as NIfTI files to
    '/content/drive/My Drive/Predictions/first' and '.../second'.
    """
    first_dataset_test_path = Path('FirstDataset/test')
    second_dataset_test_path = Path('SecondDataset/test')

    # All weights (encoder included) are replaced by load_weights below, so
    # there is no need to download the ImageNet encoder weights first.
    model = Unet('resnet50', encoder_weights=None, input_shape=(None, None, 3))
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3), loss=dice_loss,
                  metrics=[f1_score, iou_score])
    model.load_weights('/content/drive/My Drive/model-final.h5')

    # Diagnostic output: voxel sizes of the test scans.
    for scan_path in first_dataset_test_path.iterdir():
        if scan_path.name.endswith('mask.nii.gz'):
            print(nib.load(str(scan_path)).header.get_zooms())

    print()

    for scan_path in second_dataset_test_path.iterdir():
        print(nib.load(str(scan_path / 'T1w.nii.gz')).header.get_zooms())

    predictions_base_path = Path('/content/drive/My Drive/Predictions')
    first_dataset_predictions_path = predictions_base_path / 'first'
    second_dataset_predictions_path = predictions_base_path / 'second'

    first_dataset_predictions_path.mkdir(exist_ok=True, parents=True)
    second_dataset_predictions_path.mkdir(exist_ok=True, parents=True)

    for scan_path in first_dataset_test_path.iterdir():
        if scan_path.name.endswith('mask.nii.gz'):
            continue  # ground-truth masks, if present, are not inputs
        data, affine = load_raw_volume(scan_path)
        save_labels(_segment_volume(model, data), affine, first_dataset_predictions_path / scan_path.name)

    for scan_path in second_dataset_test_path.iterdir():
        data, affine = load_raw_volume(scan_path / 'T1w.nii.gz')
        save_labels(_segment_volume(model, data), affine,
                    second_dataset_predictions_path / f'{scan_path.name}.nii.gz')


predict_3D()

In [None]:
def check_3D():
    """Send every saved prediction to the remote checker and print the mean Dice.

    Each NIfTI prediction is zlib-compressed and POSTed to the checker URL
    with the file name (minus '.nii.gz') appended. Per-file results are
    printed; failed requests are reported and excluded from the mean.
    """
    dice_scores = []
    predictions_base_path = Path('/content/drive/My Drive/Predictions')
    for dataset_predictions_path in (predictions_base_path / 'first', predictions_base_path / 'second'):
        for prediction_path in dataset_predictions_path.iterdir():
            prediction_name = prediction_path.name[:-7]  # deleting '.nii.gz' from filename
            prediction = nib.load(str(prediction_path))

            response = requests.post(f'link to prediction checker{prediction_name}',
                                     data=zlib.compress(prediction.to_bytes()))
            if response.status_code == 200:
                result = response.json()
                print(dataset_predictions_path.name, prediction_path.name, result)
                dice_scores.append(result['dice'])
            else:
                print(f'Error processing prediction {dataset_predictions_path.name}/{prediction_name}: {response.text}')

    # Guard against dividing by zero when nothing was scored successfully.
    if not dice_scores:
        print('No predictions were scored successfully.')
        return
    print(sum(dice_scores) / len(dice_scores))


check_3D()
