In [None]:
# Install the package from bitbucket.
!pip uninstall arqee -y # remove existing installation, if any
!pip install ../

In [None]:
# Some additional libraries for the toturial
!pip install matplotlib
!pip install tqdm
!pip install gif

In [None]:
# Download a small sample of test data.
# This fetches roughly 240Mb of data into ./local_data/sample_data
# (relative to the current working directory).
import os

import arqee

current_directory = os.getcwd()
download_loc = os.path.join(current_directory, "local_data")
arqee.download_data_sample(download_loc, verbose=True)

In [None]:
# Load one sample recording from the downloaded data and display a single frame.
import os

import matplotlib.pyplot as plt
import numpy as np

sample_alax = np.load(os.path.join(download_loc, "sample_data/a2c.npy"))
# sample_alax: ndarray of shape (nb_frames, height, width), values in range [0, 255]
sample_frame = sample_alax[-1]  # use the last frame of the recording
plt.imshow(sample_frame, cmap="gray")
plt.show()

In [None]:
# Select a segmentation model by name and set it up through arqee.
# The bare string below is an expression statement that serves as a block comment.
'''
we apply classic left ventricle and myocardium segmentation and divide the segmentation into 8 regions:
basal_left,mid_left,apical_left,apical_right,mid_right,basal_right,annulus_left,annulus_right
There are two segmentation models available:
    - 'nnunet_hunt4_alax' (33M parameters)
    - 'nnunet_hunt4_a2c_a4c'  (33M parameters)
'''
# The a2c/a4c model is selected; the sample loaded earlier is read from a2c.npy.
model_name_seg = 'nnunet_hunt4_a2c_a4c'
# Presumably downloads the model files and registers them under this name so that
# arqee.load_model can find them later -- TODO confirm against the arqee docs.
arqee.download_and_set_up_model(model_name_seg)

In [None]:
# Once the model is set up in arqee, we can load it by name as follows.
# model_name_seg is the model name chosen in the set-up cell; the returned object
# is used below through its predict_img method.
model_object_seg = arqee.load_model(model_name_seg)

In [None]:
# Run the segmentation on the sample frame, then split the result into the
# 8 regions specified above and overlay them on the frame.
sample_frame_with_channel = sample_frame[np.newaxis, ...]  # add a leading channel dimension
segmentation_sample = model_object_seg.predict_img(
    sample_frame_with_channel,
    verbose=False,
)

divided_indices = list(range(8))  # one label per region: 0..7
segmentation_regions = arqee.divide_segmentation(
    segmentation_sample,
    new_labels=divided_indices,
)
# segmentation_regions is an array of shape (8,256,256) holding one mask per region.
# Note that some regions might overlap.

visualization = arqee.create_visualization(
    sample_frame,
    segmentation_regions,
    labels=divided_indices,
)
plt.imshow(visualization)
plt.show()

In [None]:
# We calculate the gcnr quality metric for each region, using the myocardium region
# as the region of interest and the LV lumen as the background.
# First, set up the pixel-based method.
# Available pixel-based models are:
# - 'pixel_based_gcnr'
# - 'pixel_based_cnr'
# - 'pixel_based_cr'
# - 'pixel_based_intensity'
pixel_based_model=arqee.set_up_pixel_based_method("pixel_based_gcnr")

In [None]:
# Raw gcnr values for the sample frame, without the linear model applied.
gcnr_values = pixel_based_model.predict_img(
    sample_frame_with_channel,
    segmentation=segmentation_sample,
    apply_linear_model=False,
)
print(gcnr_values)

# A linear model can also be applied to the gcnr values to obtain quality labels.
# That linear model was fitted on the validation set to map quality metrics to quality labels.
quality_labels = pixel_based_model.predict_img(
    sample_frame_with_channel,
    segmentation=segmentation_sample,
    apply_linear_model=True,
    convert_to_labels=True,
)
print(quality_labels)

In [None]:
# Visualize the quality predictions produced by the pixel-based method.
from skimage.transform import resize

# The frame is resized to 256x256 while keeping its original value range;
# presumably this matches the resolution of the segmentation -- TODO confirm.
sample_frame_resized = resize(
    sample_frame,
    (256, 256),
    preserve_range=True,
)
_ = arqee.plot_quality_prediction_result(
    sample_frame_resized,
    segmentation_sample,
    quality_labels,
)
plt.show()