<a href="https://colab.research.google.com/github/Achillesy/Fetal_Functional_MRI_Segmentation/blob/master/fmri_vnet_interface.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Automated Brain Masking of Fetal Functional MRI with Open Data

![MONAI](https://miro.medium.com/v2/resize:fit:4800/format:webp/1*2bl-7kOoc3sONYgm5InOUQ.png)

## Please upload your Fetal Functional MRI files

In [1]:
import tempfile
import io
import ipywidgets as widgets

def upload_and_continue():
    upload_button = widgets.FileUpload()
    display(upload_button)

    def handle_upload_button(sender):
        for filename, content in upload_button.value.items():
            with tempfile.NamedTemporaryFile(delete=False, prefix=filename) as temp_file:
                temp_file.write(content['content'])
                temp_file_path = temp_file.name
                print("fMRI file:", temp_file_path)

    upload_button.observe(handle_upload_button, names='value')

upload_and_continue()


FileUpload(value={}, description='Upload')

fMRI file: /tmp/sub-2013_ses-T1_task-rest_bold.nii.gz5zdt6u0a


By pressing the **Enter** key in the input box below, the fMRI mask will be automatically generated and downloaded in a short time. 

In [2]:
input()





''

In [3]:
!rm /tmp/*.*

rm: cannot remove '/tmp/initgoogle_syslog_dir.0': Is a directory


In [16]:
!apt-get install -qq -y git
!git clone https://github.com/Achillesy/Fetal_Functional_MRI_Segmentation.git


fatal: destination path 'Fetal_Functional_MRI_Segmentation' already exists and is not an empty directory.


In [5]:
import subprocess

!cat Fetal_Functional_MRI_Segmentation/models/fold4_train_metric_vnet_part_* > fold4_train_metric_vnet.pth
file_pth = "fold4_train_metric_vnet.pth"
output = subprocess.check_output(["md5sum", file_pth])
md5 = output.split()[0].decode()
expected_md5 = "cd8284f0e56f21a422b277f3be79ae10"
assert md5 == expected_md5, "MD5 value does not match"


In [6]:
!pip install monai

from monai.config import print_config
print_config()


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
MONAI version: 1.1.0
Numpy version: 1.22.4
Pytorch version: 2.0.1+cu118
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: a2ec3752f54bfc3b40e7952234fbeb5452ed63e3
MONAI __file__: /usr/local/lib/python3.10/dist-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 3.0.2
scikit-image version: 0.19.3
Pillow version: 8.4.0
Tensorboard version: 2.12.2
gdown version: 4.6.6
TorchVision version: 0.15.2+cu118
tqdm version: 4.65.0
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.5
pandas version: 1.5.3
einops version: NOT INSTALLED or UNKNOWN VERSION.
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visi

In [7]:
import os
import numpy as np
import nibabel as nib
from glob import glob

import torch
from types import SimpleNamespace
from google.colab import files

cfg = SimpleNamespace(**{})
cfg.pixdim = (3.5, 3.5, 3.5)
cfg.roi_size = [64, 64, 64]
cfg.sw_batch_size = 4

cfg.mri_dir = "mri"
cfg.mask_dir = "mask"
os.makedirs(cfg.mri_dir, exist_ok=True)
os.makedirs(cfg.mask_dir, exist_ok=True)

cfg.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


In [8]:
from monai.networks.nets import VNet

model = VNet(
  spatial_dims=3,
  in_channels=1,
  out_channels=2,
  act=("elu", {"inplace": True}),
  dropout_dim=3,
  bias=False,
).to(cfg.device)
model.load_state_dict(torch.load(file_pth, cfg.device))


<All keys matched successfully>

In [9]:
from monai.transforms import (
  AsDiscreted,
  Compose,
  # CropForegroundd,
  EnsureChannelFirstd,
  Invertd,
  # Lambda,
  LoadImaged,
  NormalizeIntensityd,
  Orientationd,
  SaveImaged,
  Spacingd,
)

test_transforms = Compose(
  [
    LoadImaged(keys=["image"]),
    EnsureChannelFirstd(keys=["image"]),
    Orientationd(keys=["image"], axcodes="RAS"),
    Spacingd(keys=["image"], pixdim=cfg.pixdim, mode="bilinear"),
    NormalizeIntensityd(keys="image", nonzero=True),
  ]
)

post_transforms = Compose(
  [
    Invertd(
      keys="pred",
      transform=test_transforms,
      orig_keys="image",
      meta_keys="pred_meta_dict",
      orig_meta_keys="image_meta_dict",
      meta_key_postfix="meta_dict",
      nearest_interp=False,
      to_tensor=True,
    ),
    AsDiscreted(keys="pred", argmax=True),
    SaveImaged(
      keys="pred",
      meta_keys="pred_meta_dict",
      output_dir=cfg.mask_dir,
      output_postfix="vnet",
      resample=False,
    ),
  ]
)

In [17]:
frmi_files = glob("/tmp.*.nii.gz*")
for fmri_data in frmi_files:
  fmri_data_name = os.path.basename(fmri_data).replace(".nii.gz*", "")
  image = nib.load(fmri_data)
  data = image.get_fdata()
  if len(data.shape) != 4:
    raise ValueError("Invalid shape of fMRI file format. Expected 4D shape: [x, y, z, t]")
  channel_list = np.split(data, data.shape[-1], axis=-1)
  for i, channel in enumerate(channel_list):
    channel_image = nib.Nifti1Image(channel, image.affine)
    channel_file_name = os.path.join(cfg.mri_dir, f"{fmri_data_name}_{i+1}.nii.gz")
    nib.save(channel_image, channel_file_name)


In [18]:
test_files = []
rmi_files = glob(os.path.join(cfg.mri_dir, "*.nii.gz"))
for f_file in rmi_files:
  test_files.append({"image": f_file})
print(test_files)


[]


In [12]:
from monai.inferers import sliding_window_inference
from monai.data import DataLoader, Dataset, decollate_batch

test_ds = Dataset(data=test_files, transform=test_transforms)
test_loader = DataLoader(test_ds, batch_size=1)


In [13]:
with torch.no_grad():
  for test_data in test_loader:
    test_inputs = test_data["image"].to(cfg.device)
    test_data["pred"] = sliding_window_inference(
      test_inputs, cfg.roi_size, cfg.sw_batch_size, model
    )
    test_data = [post_transforms(i) for i in decollate_batch(test_data)]

In [14]:
for fmri_data in frmi_files:
  image = nib.load(fmri_data)
  data = image.get_fdata()

  fmri_data_name = os.path.basename(fmri_data).replace(".nii.gz", "")
  mask_data = np.zeros_like(data)
  for i in range(data.shape[-1]):
    i_mask_file = os.path.join(cfg.mask_dir,  f"{fmri_data_name}_{i+1}", f"{fmri_data_name}_{i+1}_vnet.nii.gz")
    i_mask_data = nib.load(i_mask_file).get_fdata()
    mask_data[:,:,:,i] = i_mask_data
  fmri_mask = nib.Nifti1Image(mask_data, affine=image.affine, header=image.header)
  mask_data_name = f"{fmri_data_name}_vnet.nii.gz"
  fmri_mask.to_filename(mask_data_name)
  files.download(mask_data_name)

In [15]:
!rm -rf {cfg.mri_dir}
!rm -rf {cfg.mask_dir}
