# SpineNet Lumbar Grading Example

Last edited: 05/07/2022

This notebook shows an example of using SpineNet to grade a typical T2 lumbar scan
on CUDA-enabled hardware. SpineNet can also run on a CPU alone, although this is
slower.


## 01. Loading In SpineNet + Data

In [1]:
import os
import sys
from pathlib import Path

import torch

# Add the parent directory to the import path so a local (not pip-installed) copy of
# SpineNet can be imported. This is not needed if SpineNet is pip-installed.
sys.path.insert(0, str(Path(os.getcwd()).parent))

import spinenet
from spinenet import SpineNet, download_example_scan
from spinenet.io import load_dicoms_from_folder

# Folder that the example scans are downloaded into (created if it does not exist).
example_scan_folder = './example_scans'
os.makedirs(example_scan_folder, exist_ok=True)

# Download one example scan. Two are available: 't2_lumbar_scan_1' and 't2_lumbar_scan_2'.
scan_name = 't2_lumbar_scan_2'
download_example_scan(scan_name, file_path=example_scan_folder)

# Download weights from the server. This may take a minute or two.
# You do not need to run this line if the weights have already been downloaded.
spinenet.download_weights(verbose=True, force=False)

# Load SpineNet on the first CUDA GPU when one is available; otherwise fall back to the
# CPU (supported, but slower).
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
spnt = SpineNet(device=device, verbose=True)

# Metadata to overwrite in the scan - useful if important values are missing from some/all
# DICOM files. In this example, slice thickness and image orientation are missing from the
# DICOM files, so placeholder values are used: a 2 mm slice thickness and the sagittal
# orientation code [0, 1, 0, 0, 0, -1].
# Do not overwrite this metadata if it already exists in the DICOM files being used.
overwrite_dict = {'SliceThickness': [2], 'ImageOrientationPatient': [0, 1, 0, 0, 0, -1]}

# Load the DICOM files of the example scan. If set, the `require_extensions` flag requires
# that files end with `.dcm`; it is disabled here.
scan = load_dicoms_from_folder(
    os.path.join(example_scan_folder, scan_name),
    require_extensions=False,
    metadata_overwrites=overwrite_dict,
)


ModuleNotFoundError: No module named 'spinenet'

## 02. Visualize Scan Slices

Print the scan's metadata, then show each slice of the sagittal T2 lumbar scan.

In [None]:
import math

import matplotlib.pyplot as plt

num_slices = scan.volume.shape[-1]
print(f'Scan has {scan.volume.shape[-1]} sagittal slices, of dimension {scan.volume.shape[0]}x{scan.volume.shape[1]} ({scan.pixel_spacing} mm pixel spacing) and {scan.slice_thickness} mm slice thickness.')

# Lay the slices out on a 3-column grid with as many rows as needed. A fixed 4x3 grid
# makes add_subplot raise for scans with more than 12 slices. Figure height scales with
# the row count (8x8 for 4 rows, as before).
n_cols = 3
n_rows = max(1, math.ceil(num_slices / n_cols))
fig = plt.figure(figsize=(8, 2 * n_rows))

# Show each sagittal slice (the last axis of the volume indexes slices).
for slice_idx in range(num_slices):
    ax = fig.add_subplot(n_rows, n_cols, slice_idx + 1)
    ax.imshow(scan.volume[:, :, slice_idx], cmap='gray')
    ax.set_title(f'Slice {slice_idx+1}')
    ax.axis('off')

plt.show()


## 03. Detect Vertebrae
Use SpineNet to detect vertebrae, then show the detections on each slice.

In [None]:
import math

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Polygon

# Detect and identify vertebrae in the scan. Pixel spacing information is required so
# SpineNet knows what size to split patches into.
vert_dicts = spnt.detect_vb(scan.volume, scan.pixel_spacing)

print(f'{len(vert_dicts)} vertebrae detected; {[vert_dict["predicted_label"] for vert_dict in vert_dicts]}')

# Visualize the vertebra detections on every slice, using a 3-column grid with enough
# rows for all slices (a fixed 4x3 grid fails for scans with more than 12 slices).
num_slices = scan.volume.shape[-1]
n_cols = 3
n_rows = max(1, math.ceil(num_slices / n_cols))
fig = plt.figure(figsize=(8, 2 * n_rows))
for slice_idx in range(num_slices):
    ax = fig.add_subplot(n_rows, n_cols, slice_idx + 1)
    ax.imshow(scan.volume[:, :, slice_idx], cmap='gray')
    ax.set_title(f'Slice {slice_idx+1}')
    ax.axis('off')
    for vert_dict in vert_dicts:
        # Only outline vertebrae detected on this slice. The polygon is taken from
        # 'polys' at the same position that slice_idx occupies in 'slice_nos'.
        if slice_idx not in vert_dict['slice_nos']:
            continue
        poly = np.array(vert_dict['polys'][vert_dict['slice_nos'].index(slice_idx)])
        ax.add_patch(Polygon(poly, ec='y', fc='none'))
        # Label the vertebra at the mean of its polygon vertices.
        ax.text(poly[:, 0].mean(), poly[:, 1].mean(), vert_dict['predicted_label'], c='y', ha='center', va='center')

fig.suptitle('Detected Vertebrae (all slices)')
plt.show()

## 04. Show vertebrae detections in the mid sagittal slice

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Polygon

# Index of the mid sagittal slice (for an even slice count, integer division picks the
# later of the two middle slices).
mid_slice_idx = scan.volume.shape[-1] // 2

# Plot the mid sagittal slice and outline the vertebrae detected on it.
fig, ax = plt.subplots(figsize=(5, 5))
ax.imshow(scan.volume[:, :, mid_slice_idx], cmap='gray')
for vert_dict in vert_dicts:
    # Skip vertebrae that were not detected on the mid slice.
    if mid_slice_idx not in vert_dict['slice_nos']:
        continue
    poly = np.array(vert_dict['polys'][vert_dict['slice_nos'].index(mid_slice_idx)])
    ax.add_patch(Polygon(poly, fc='none', ec='y'))
    # Label the vertebra at the mean of its polygon vertices.
    ax.text(poly[:, 0].mean(), poly[:, 1].mean(), vert_dict['predicted_label'], color='y', fontsize=20, va='center', ha='center')

ax.axis('off')
ax.set_title('Detected Vertebrae (Mid Sagittal Slice)')
plt.show()


## 05. Perform grading of T2 scans for common radiological conditions

Since this is a T2 sagittal lumbar scan, SpineNet can be used to perform radiological grading for several common degenerative spinal conditions. Note that the grading models were trained on IVDs from T12/L1 to L5/S1, so grades may not be accurate for discs outside this range.

In [None]:
# extract IVDs using the detections from the previous stage.
ivd_dicts = spnt.get_ivds_from_vert_dicts(vert_dicts, scan.volume)

# grade IVDs - note that this is only validated on IVDs from L5/S1 to T12/L5 vertebrae.
# IVDs gradings are output as a pandas dictionary. For information on the grading schemes used, see http://zeus.robots.ox.ac.uk/spinenet2/
ivd_grades = spnt.grade_ivds(ivd_dicts)
ivd_grades.head(len(ivd_dicts))



In [None]:
# Write the IVD gradings to ./results/<scan_name>_ivd_grades.csv, creating the folder if needed.
results_folder = './results'
os.makedirs(results_folder, exist_ok=True)
grades_csv_path = f"{results_folder}/{scan_name}_ivd_grades.csv"
ivd_grades.to_csv(grades_csv_path)


In [6]:
!pip install keras-core

Collecting keras-core
  Downloading keras_core-0.1.7-py3-none-any.whl.metadata (4.3 kB)
Downloading keras_core-0.1.7-py3-none-any.whl (950 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m950.8/950.8 kB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0mta [36m0:00:01[0m
[?25hInstalling collected packages: keras-core
Successfully installed keras-core-0.1.7


In [10]:
!pip install keras-core
!pip install wurlitzer
!pip install tensorflow==2.15.0  # Downgrade TensorFlow if necessary
!pip install tensorflow-decision-forests
!pip check  # Optional: Check for any remaining conflicts

Collecting tensorflow==2.15.0
  Downloading tensorflow-2.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.4 kB)
Collecting ml-dtypes~=0.2.0 (from tensorflow==2.15.0)
  Downloading ml_dtypes-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)
Collecting tensorboard<2.16,>=2.15 (from tensorflow==2.15.0)
  Downloading tensorboard-2.15.2-py3-none-any.whl.metadata (1.7 kB)
Collecting keras<2.16,>=2.15.0 (from tensorflow==2.15.0)
  Using cached keras-2.15.0-py3-none-any.whl.metadata (2.4 kB)
Downloading tensorflow-2.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (475.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m475.2/475.2 MB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hUsing cached keras-2.15.0-py3-none-any.whl (1.7 MB)
Downloading ml_dtypes-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [

In [11]:
!git clone --depth 1 https://github.com/tensorflow/models

Cloning into 'models'...
remote: Enumerating objects: 4245, done.[K
remote: Counting objects: 100% (4245/4245), done.[K
remote: Compressing objects: 100% (3235/3235), done.[K
remote: Total 4245 (delta 1193), reused 2123 (delta 939), pack-reused 0[K
Receiving objects: 100% (4245/4245), 48.86 MiB | 32.64 MiB/s, done.
Resolving deltas: 100% (1193/1193), done.


In [13]:
import os

In [14]:
# Display the kernel's current working directory (absolute path).
os.path.abspath(os.curdir)

'/kaggle/working'

In [15]:
# Move the kernel into the Object Detection API root of the cloned TensorFlow Models repo.
# NOTE: relative paths used by later cells resolve against this directory.
os.chdir(os.path.join("/kaggle/working", "models", "research"))

In [21]:
!cd models/research && protoc object_detection/protos/*.proto --python_out=.

In [18]:
# Display the kernel's current working directory again, to confirm the chdir above took effect.
os.path.abspath(".")

'/kaggle/working/models/research'

In [20]:
%cd ..

/kaggle/working


In [22]:
import os
import sys

# Make the TensorFlow Models packages importable.
# - os.environ['PYTHONPATH'] only affects subprocesses started from here (e.g. `!python ...`).
#   The running kernel resolves imports through sys.path, so that is updated as well.
# - Only missing entries are added, so re-running the cell does not keep growing the paths.
# - PYTHONPATH may be unset; indexing it directly would raise KeyError.
tf_models_paths = [
    "/kaggle/working/models",
    "/kaggle/working/models/research",
    "/kaggle/working/models/research/slim",
]
existing_pythonpath = [p for p in os.environ.get('PYTHONPATH', '').split(os.pathsep) if p]
missing_paths = [p for p in tf_models_paths if p not in existing_pythonpath]
os.environ['PYTHONPATH'] = os.pathsep.join(existing_pythonpath + missing_paths)
for path_entry in tf_models_paths:
    if path_entry not in sys.path:
        sys.path.append(path_entry)

In [23]:
!ls models/research/object_detection/protos

__init__.py			       input_reader_pb2.py
__pycache__			       keypoint_box_coder.proto
anchor_generator.proto		       keypoint_box_coder_pb2.py
anchor_generator_pb2.py		       losses.proto
argmax_matcher.proto		       losses_pb2.py
argmax_matcher_pb2.py		       matcher.proto
bipartite_matcher.proto		       matcher_pb2.py
bipartite_matcher_pb2.py	       mean_stddev_box_coder.proto
box_coder.proto			       mean_stddev_box_coder_pb2.py
box_coder_pb2.py		       model.proto
box_predictor.proto		       model_pb2.py
box_predictor_pb2.py		       multiscale_anchor_generator.proto
calibration.proto		       multiscale_anchor_generator_pb2.py
calibration_pb2.py		       optimizer.proto
center_net.proto		       optimizer_pb2.py
center_net_pb2.py		       pipeline.proto
eval.proto			       pipeline_pb2.py
eval_pb2.py			       post_processing.proto
faster_rcnn.proto		       post_processing_pb2.py
faster_rcnn_box_coder.proto	       preprocessor.proto
faster_rcnn_box_coder_pb2.py	       preprocessor_pb2.

In [24]:
from object_detection.builders import model_builder