<a href="https://colab.research.google.com/github/arejimon/Artifact-Removal/blob/main/flairseg_2channel_pytorch_unet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### This notebook is optionally accelerated with a GPU runtime.
### If you would like to use this acceleration, please select the menu option "Runtime" -> "Change runtime type", select "Hardware Accelerator" -> "GPU" and click "SAVE"

----------------------------------------------------------------------

# U-Net for brain MRI

*Author: mateuszbuda*

**U-Net with batch normalization for biomedical image segmentation with pretrained weights for abnormality segmentation in brain MRI**

<img src="https://pytorch.org/assets/images/unet_brain_mri.png" alt="alt" width="50%"/>

In [1]:
!pip install dash
!pip install itk
!pip install SimpleITK


In [2]:
import dash
import itk
import SimpleITK as sitk
import matplotlib.pyplot as plt
import numpy as np
import os
from pathlib import Path
import torch

# from nnfit.utils.image import itk_to_sitk
# from nnfit.data.midas import *
# from nnfit.xtra.dash_slicer import VolumeSlicer

In [3]:
model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
    in_channels=3, out_channels=1, init_features=32, pretrained=True)




In [None]:
# Load the T1 volume, its mask, and the FLAIR volume, printing each size as a sanity check.
t1_img = sitk.ReadImage("/workspace/Dropbox/nn/artifact/t1_EM004_0_strip.nii.gz", imageIO="NiftiImageIO")
print(t1_img.GetSize())

mask_img = sitk.ReadImage("/workspace/Dropbox/nn/artifact/t1_EM004_0_strip_mask.nii.gz", imageIO="NiftiImageIO")
print(mask_img.GetSize())

flair_img = sitk.ReadImage("/workspace/Dropbox/nn/artifact/flair_EM004_0_strip.nii.gz", imageIO="NiftiImageIO")
print(flair_img.GetSize())

# Register FLAIR to T1 in two passes: first without rotation, then with rotation and a smaller learning rate.
# NOTE: register_routine is neither defined nor imported in this notebook; it must come from an external helper module.
flair_img, _ = register_routine(t1_img, flair_img, learn_rate=4.0, stop=0.01, max_steps=50, rotate=False, log=True)
flair_img, _ = register_routine(t1_img, flair_img, learn_rate=0.01, stop=0.001, max_steps=50, rotate=True, log=True)


In [None]:
os.makedirs("/workspace/Dropbox/nn/artifact/data/", exist_ok=True)

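In [None]:
# Stack three copies of the FLAIR volume as channels and move the channel axis first.
flair_vol = sitk.PermuteAxes(sitk.JoinSeries([flair_img, flair_img, flair_img]), [3,0,1,2])

In [None]:
# Write each slice along the last axis to its own TIFF file (flair_vol must already be defined).
writer = sitk.ImageFileWriter()
for i in range(flair_img.GetSize()[2]):
    writer.SetFileName(f"/workspace/Dropbox/nn/artifact/data/flair_EM004_0_slice_{i}.tif")
    writer.Execute(flair_vol[:,:,:,i])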

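The axis handling behind `flair_vol` can be hard to follow, so here is a minimal NumPy analogue of the `JoinSeries` and `PermuteAxes` calls. The volume size is hypothetical, and the arrays are laid out in SimpleITK's (x, y, z) size order purely for illustration; arrays returned by `sitk.GetArrayFromImage` have the axis order reversed.

```python
import numpy as np

# Hypothetical volume size, written in SimpleITK's (x, y, z) order.
x, y, z = 6, 5, 4
volume = np.zeros((x, y, z), dtype=np.float32)

# JoinSeries analogue: a new last axis holds the three copies -> (x, y, z, 3).
joined = np.stack([volume, volume, volume], axis=-1)

# PermuteAxes analogue with order [3, 0, 1, 2]:
# output axis i is input axis order[i] -> (3, x, y, z).
permuted = np.transpose(joined, (3, 0, 1, 2))

# Fixing the last index picks one z position and keeps the channel axis -> (3, x, y).
slice_i = permuted[:, :, :, 0]

print(joined.shape, permuted.shape, slice_i.shape)
```

Under these assumptions, each written TIFF holds one slice with the three identical FLAIR copies along its first axis.
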
In [None]:
flair_vol.GetSize()

In [None]:
plt.imshow(sitk.GetArrayFromImage(t1_img)[70, ...])
plt.figure()
plt.imshow(sitk.GetArrayFromImage(flair_img)[70, ...])
#plt.imshow(sitk.GetArrayFromImage(t1_img)[10, ...])
#plt.figure()
#plt.imshow(sitk.GetArrayFromImage(flair_img)[10, ...])

The `torch.hub.load` call above loads a U-Net model pre-trained for abnormality segmentation on a dataset of brain MRI volumes: [kaggle.com/mateuszbuda/lgg-mri-segmentation](https://www.kaggle.com/mateuszbuda/lgg-mri-segmentation).
The pre-trained model requires 3 input channels, 1 output channel, and 32 features in the first layer.

### Model Description

This U-Net model comprises four levels of blocks. Each block contains two convolutional layers with batch normalization and ReLU activation; blocks in the encoding part are followed by one max pooling layer, while blocks in the decoding part use up-convolutional layers instead.
The number of convolutional filters per block is 32, 64, 128, and 256, from the first level to the fourth.
The bottleneck layer has 512 convolutional filters.
Skip connections link each encoding level to the corresponding level in the decoding part.
The input image is a 3-channel brain MRI slice from pre-contrast, FLAIR, and post-contrast sequences, respectively.
The output is a one-channel probability map of abnormality regions with the same size as the input image.
It can be transformed to a binary segmentation mask by thresholding, for example by rounding the output as in the final inference cell.
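
To make the filter counts concrete, the following sketch tabulates channels and feature-map sizes per level. It assumes a 256x256 input and that each max pooling layer halves the height and width; `unet_shapes` is a hypothetical helper written for this illustration, not part of the model's code.

```python
def unet_shapes(init_features=32, levels=4, input_size=256):
    """Return (name, channels, spatial_size) for each encoder level and the bottleneck."""
    shapes = []
    size = input_size
    for level in range(levels):
        # The filter count doubles at every level: 32, 64, 128, 256.
        channels = init_features * 2 ** level
        shapes.append((f"encoder{level + 1}", channels, size))
        size //= 2  # assumed: max pooling halves height and width
    # The bottleneck doubles the count once more: 512 filters.
    shapes.append(("bottleneck", init_features * 2 ** levels, size))
    return shapes


for name, channels, size in unet_shapes():
    print(f"{name}: {channels} channels at {size}x{size}")
```

With these assumptions the bottleneck works on 16x16 feature maps, and the decoder mirrors the encoder back up to the input resolution.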

### Example

Input images for the pre-trained model should have 3 channels, be resized to 256x256 pixels, and be z-score normalized per volume.
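
As a rough illustration of per-volume z-score normalization, the sketch below uses one mean and one standard deviation for the whole volume rather than per slice. The volume is synthetic and `zscore_volume` is a hypothetical helper; resizing to 256x256 is left out.

```python
import numpy as np


def zscore_volume(volume):
    """Normalize a (slices, height, width) volume with a single mean and std."""
    volume = volume.astype(np.float32)
    return (volume - volume.mean()) / volume.std()


# Synthetic stand-in for an MRI volume that is already 256x256 in-plane.
rng = np.random.default_rng(0)
volume = rng.uniform(0, 255, size=(8, 256, 256))

normalized = zscore_volume(volume)
print(normalized.mean(), normalized.std())
```

After normalization the volume has a mean close to 0 and a standard deviation close to 1, while intensity differences between slices are preserved.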

In [None]:
# Download an example image
import urllib.request
url, filename = ("https://github.com/mateuszbuda/brain-segmentation-pytorch/raw/master/assets/TCGA_CS_4944.png", "TCGA_CS_4944.png")
urllib.request.urlretrieve(url, filename)

In [None]:
import numpy as np
from PIL import Image
from torchvision import transforms

input_image = Image.open(filename)
m, s = np.mean(input_image, axis=(0, 1)), np.std(input_image, axis=(0, 1))
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=m, std=s),
])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0)



In [None]:
flair_img = preprocess_img(flair_img)
t1_img = preprocess_img(t1_img)

In [None]:
plt.imshow(sitk.GetArrayFromImage(t1_img)[70, ...])
plt.figure()
plt.imshow(sitk.GetArrayFromImage(flair_img)[70, ...])
#plt.imshow(sitk.GetArrayFromImage(t1_img)[10, ...])
#plt.figure()
#plt.imshow(sitk.GetArrayFromImage(flair_img)[10, ...])

In [None]:
arr_t1 = sitk.GetArrayFromImage(t1_img)
arr_flair = sitk.GetArrayFromImage(flair_img)
arr_mask = sitk.GetArrayFromImage(mask_img)
plt.hist(arr_t1[np.where(arr_mask)])
plt.figure()
plt.hist(arr_t1.flatten())

In [None]:
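def load_img(t1, flair, mask):
    """Build a (slices, 3, H, W) float32 tensor from the FLAIR volume.

    Both volumes are z-score normalized with statistics taken inside the mask.
    The normalized T1 is currently unused: FLAIR fills all three channels.
    """
    t1 = sitk.GetArrayFromImage(t1)
    flair = sitk.GetArrayFromImage(flair)
    mask = sitk.GetArrayFromImage(mask)

    arg_mask = np.where(mask == 1)
    anti_mask = np.where(mask == 0)

    # Mean and standard deviation of the voxels inside the mask.
    t1_mean = np.mean(t1[arg_mask])
    t1_std = np.std(t1[arg_mask])

    flair_mean = np.mean(flair[arg_mask])
    flair_std = np.std(flair[arg_mask])

    t1 = (t1 - t1_mean) / t1_std
    flair = (flair - flair_mean) / flair_std

    #t1[anti_mask] = 0.0
    #flair[anti_mask] = 0.0

    # Add a channel axis: (slices, H, W) -> (slices, 1, H, W).
    t1 = t1[:, None, ...].astype(np.float32)
    flair = flair[:, None, ...].astype(np.float32)

    t1 = torch.from_numpy(t1)
    flair = torch.from_numpy(flair)

    # Repeat FLAIR along the channel axis to match the model's 3 input channels.
    image = torch.cat([flair, flair, flair], dim=1)
    #image = (image - torch.mean(image)) / torch.std(image)

    return image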

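The two ideas in `load_img`, mask-based normalization and channel replication, can be tried without any image files. This is a minimal NumPy sketch on synthetic data: `masked_zscore` and `to_three_channels` are hypothetical helpers, and zeroing the background is an assumption about what the stripped volumes look like.

```python
import numpy as np


def masked_zscore(volume, mask):
    """Z-score a volume using the mean and std of the voxels inside the mask only."""
    inside = volume[mask == 1]
    return (volume - inside.mean()) / inside.std()


def to_three_channels(volume):
    """Turn (slices, H, W) into (slices, 3, H, W) by repeating the volume as channels."""
    channel = volume[:, None, ...].astype(np.float32)
    return np.concatenate([channel, channel, channel], axis=1)


# Synthetic stand-ins for a stripped volume and its mask.
rng = np.random.default_rng(0)
volume = rng.normal(100.0, 20.0, size=(4, 16, 16))
mask = np.zeros(volume.shape, dtype=np.uint8)
mask[:, 4:12, 4:12] = 1
volume[mask == 0] = 0.0  # assumed: background is zero outside the mask

normalized = masked_zscore(volume, mask)
batch = to_three_channels(normalized)
print(batch.shape)
```

Taking the statistics inside the mask keeps the large zero background from dominating the mean and standard deviation, which is presumably why `load_img` does the same.
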
In [None]:
image = load_img(t1_img, flair_img, mask_img)

In [None]:
with torch.no_grad():
    result = model(image)

In [None]:
fig, ax = plt.subplots(1,3,figsize=(18,6))
z = 70
ax[0].imshow(result[z,0,...].detach().numpy())
ax[1].imshow(image[z,0,...].detach().numpy(), cmap='gray')
ax[2].imshow(image[z,1,...].detach().numpy(), cmap='gray')

In [None]:
# if torch.cuda.is_available():
#     input_batch = input_batch.to('cuda')
#     model = model.to('cuda')

# with torch.no_grad():
#     output = model(input_batch)

# print(torch.round(output[0]))

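Rounding the model output, as the commented-out cell does with `torch.round`, amounts to thresholding the probability map at 0.5. Here is a minimal NumPy sketch of that step on a made-up probability map; `binarize` is a hypothetical helper and the 0.5 default simply mirrors the rounding.

```python
import numpy as np


def binarize(prob_map, threshold=0.5):
    """Return a uint8 mask that is 1 where the probability exceeds the threshold."""
    return (prob_map > threshold).astype(np.uint8)


# Made-up probabilities standing in for a model output.
prob_map = np.array([[0.02, 0.40, 0.51],
                     [0.49, 0.90, 0.99]], dtype=np.float32)

mask = binarize(prob_map)
print(mask)
print("abnormal pixels:", int(mask.sum()))
```

A different threshold trades sensitivity against false positives; the right value depends on the data and is not specified in this notebook.
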
### References

- [Association of genomic subtypes of lower-grade gliomas with shape features automatically extracted by a deep learning algorithm](http://arxiv.org/abs/1906.03720)
- [U-Net: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/abs/1505.04597)
- [Brain MRI segmentation dataset](https://www.kaggle.com/mateuszbuda/lgg-mri-segmentation)