In [None]:
# Install the package from bitbucket.
!pip uninstall arqee -y # remove existing installation, if any
!pip install ../

In [None]:
# Some additional libraries for the toturial
!pip install matplotlib
!pip install gif

In [None]:
# Download a small sample of test data.
# This will download approximately 240 MB of data under ./local_data/sample_data
import os
import arqee
current_directory = os.getcwd()
# The data is stored in a "local_data" folder inside the current working directory
download_loc = os.path.join(current_directory,"local_data")
arqee.download_data_sample(download_loc,verbose=True)

In [None]:
# Load the ALAX sample recording from the download location
import os
import numpy as np
import matplotlib.pyplot as plt

alax_path = os.path.join(download_loc, "sample_data/alax.npy")
sample_alax = np.load(alax_path)
# sample_alax is an ndarray of shape (nb_frames,height,width) with values in range [0,255]

# Pick the last frame of the recording and show it in grayscale
sample_frame = sample_alax[-1]
plt.imshow(sample_frame, cmap='gray')
plt.show()

In [None]:
'''
Download the end-to-end quality model.
Remarks:
    - The model is trained and tested on apical 4 chamber (A4C), apical 2 chamber (A2C) and apical long axis (ALAX) views.
'''
# Name that identifies the regional quality model in arqee
model_name = 'mobilenetv2_regional_quality'
# Download the model and set it up in arqee, so that it can be loaded by name afterwards
arqee.download_and_set_up_model(model_name)

In [None]:
# Once the model is set up in arqee, you can load it as follows.
# The loaded object provides predict_img (single frame) and predict_recording
# (sequence of frames), both of which are used in the cells below.
model_object = arqee.load_model(model_name)

In [None]:
# Run quality inference on a single frame
# model_object.predict_img expects the data to be in the format (nb_channels,height,width)
sample_frame_with_channel = np.expand_dims(sample_frame, axis=0) # add channel dimension
# Print the shape instead of the raw pixel values: the layout is what matters for predict_img
print(sample_frame_with_channel.shape)
res_labels = model_object.predict_img(sample_frame_with_channel,verbose=True)
print(res_labels)
# This gives quality scores on a continuous scale fitted on the annotations of the clinicians, with meaning:
# 0: not visible, 1 poor, 3 ok, 4 good, 5 excellent
# NOTE(review): the scale listed above skips the value 2 — confirm the intended mapping.
# If you want the labels in categorical format, set convert_to_labels to True
res_labels_cat = model_object.predict_img(sample_frame_with_channel,convert_to_labels=True,verbose=False)
# The output of the inference is an ndarray with size 1x8 with the quality labels in the following order:
# basal_left,mid_left,apical_left,apical_right,mid_right,basal_right,annulus_left,annulus_right
print(res_labels_cat)

In [None]:
# Inference also works on a complete recording.
# model_object.predict_recording takes data laid out as (nb_frames,nb_channels,height,width),
# so a channel axis is inserted right after the frame axis.
sample_alax_with_channel = sample_alax[:, np.newaxis, ...]
res_labels_rec = model_object.predict_recording(
    sample_alax_with_channel,
    convert_to_labels=True,
    verbose=True,
)
# The result is an ndarray of size nb_frames x 8; for every frame the quality labels are ordered as:
# basal_left,mid_left,apical_left,apical_right,mid_right,basal_right,annulus_left,annulus_right
print(res_labels_rec)
print(res_labels_rec.shape)

In [None]:
'''
Let's visualize the results
For this, we first need a segmentation model to get the regions of interest the model is referring to when making predictions. To download the model, we use the same procedure as the quality model.
There are two segmentation models available:
    - 'nnunet_hunt4_alax' (33M parameters)
    - 'nnunet_hunt4_a2c_a4c'  (33M parameters)
Remarks:
    - The models are trained on the NTNU internal HUNT4 dataset. 
'''
# The sample recording used in this notebook is alax.npy, so the ALAX segmentation model is selected
model_name_seg = 'nnunet_hunt4_alax'
# Same download-and-set-up call as used for the quality model
arqee.download_and_set_up_model(model_name_seg)

In [None]:
# Once the model is set up in arqee, we can load it as follows.
# The segmentation model is loaded with the same arqee.load_model call as the quality model.
model_object_seg = arqee.load_model(model_name_seg)

In [None]:
# Now we can run the segmentation on the single frame (with channel dimension) prepared earlier
segmentation_sample = model_object_seg.predict_img(sample_frame_with_channel,verbose=False)
# Display the predicted segmentation
# NOTE(review): presumably segmentation_sample is a 2D label map — confirm
plt.imshow(segmentation_sample)
plt.show()

In [None]:
# With the segmentation available, we can visualize the quality predictions
from skimage.transform import resize

# Resize the frame to 256x256; preserve_range=True keeps the original intensity values.
# NOTE(review): presumably 256x256 matches the resolution of segmentation_sample — confirm
sample_frame_resized = resize(sample_frame, (256, 256), preserve_range=True)
# Plot the frame together with its segmentation and the categorical quality labels;
# the return value of the plotting helper is discarded.
_ = arqee.plot_quality_prediction_result(sample_frame_resized,segmentation_sample,res_labels_cat)
plt.show()

In [None]:
# Visualize the quality predictions for the whole recording as an animated GIF
from tqdm import tqdm
import gif

print('Running inference on the recording')
res_segmentations = model_object_seg.predict_recording(sample_alax_with_channel, verbose=True)


@gif.frame
def plot_frame(sample_frame,sample_seg,quality_labels):
    """Draw a single animation frame.

    sample_frame is indexed with [0] to drop its leading channel axis, resized to
    256x256 (keeping the original intensity range) and plotted together with its
    segmentation and quality labels. `resize` is imported from skimage.transform
    in an earlier cell.
    """
    channel_free = sample_frame[0]
    image_256 = resize(channel_free, (256, 256), preserve_range=True)
    arqee.plot_quality_prediction_result(image_256, sample_seg, quality_labels)


print('Creating gif')
# Build one GIF frame per frame of the recording
frames = []
for frame_idx in tqdm(range(len(sample_alax_with_channel))):
    rendered = plot_frame(
        sample_alax_with_channel[frame_idx],
        res_segmentations[frame_idx],
        res_labels_rec[frame_idx],
    )
    frames.append(rendered)
gif.save(frames, "image_quality_prediction.gif", duration=19)

In [None]:
from IPython.display import Image, display

# Read the GIF written by the previous cell and embed its bytes in the notebook output.
# The data is GIF-encoded, so it is declared as 'gif'; declaring it as 'png' would
# attach the wrong MIME type to the output.
# NOTE(review): embedding the 'gif' format requires an IPython version that supports it — confirm
with open('./image_quality_prediction.gif','rb') as f:
    display(Image(data=f.read(), format='gif'))