<a href="https://colab.research.google.com/github/cerr/pycerr-notebooks/blob/main/autosegment_CT_HeadAndNeck_OARs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction

In this tutorial, we demonstrate how to apply a pre-trained AI model to segment organs at risk (OARs) on head and neck CT scans.

## Requirements
* Python>=3.7
* This model can be run on a CPU.   
  To use a GPU on Colab, select `Runtime` > `Change runtime type` and choose a GPU.
* Data I/O, pre-processing, and post-processing are performed using [***pyCERR***](https://github.com/cerr/pyCERR).

## I/O
* **Input**: DICOM-format head and neck CT scan(s).
* **Output**: DICOM RTStruct-format segmentations.

Input data should be organized as one directory of DICOM images per patient:

    Input dir
    |------Pat1
    |        |------img1.dcm
    |               img2.dcm
    |               ....
    |------Pat2
             |------img1.dcm
                    img2.dcm
                    ....
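
The processing loop further down treats every entry of `inputDicomPath` as one patient directory. A quick sanity check of that layout before running the model could look like the sketch below; `list_patient_dirs` and `count_dicom_files` are hypothetical helpers, not pyCERR functions, and counting by the `.dcm` extension is an assumption (DICOM files are not always named that way).

```python
import os

def list_patient_dirs(input_dir):
    """Return sorted paths of the per-patient sub-directories in input_dir.

    Assumes one sub-directory of DICOM files per patient; loose files at the
    top level are skipped.
    """
    return sorted(
        os.path.join(input_dir, name)
        for name in os.listdir(input_dir)
        if os.path.isdir(os.path.join(input_dir, name))
    )

def count_dicom_files(patient_dir):
    """Count files ending in .dcm (case-insensitive) inside patient_dir."""
    return sum(1 for f in os.listdir(patient_dir) if f.lower().endswith('.dcm'))
```

For example, `[count_dicom_files(p) for p in list_patient_dirs(inputDicomPath)]` gives a per-patient file count to compare against the expected number of slices.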


## AI model
* The segmentation model used here was trained and validated on CT scans acquired for radiotherapy (RT) planning. Its performance on diagnostic scans is expected to be sub-optimal.
* The trained model is packaged as a Conda environment archive containing Python libraries and other dependencies.

### Installing the model and its dependencies

* Installation is performed using CERR's [***model installer***](https://github.com/cerr/model_installer).  

* A Conda archive containing dependencies is downloaded to the `conda-pack`   
  sub-directory of a configurable `scriptInstallDir`.  
  By default, `scriptInstallDir = '/content/CT_HeadAndNeck_OARs'`, so `condaEnvPath = '/content/CT_HeadAndNeck_OARs/conda-pack'`.
  
* The inference script is located at   
  `scriptPath = os.path.join(scriptInstallDir, 'model_wrapper', 'run_inference_nii.py')`  
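
Before running inference it can help to confirm that the installer produced the files this notebook expects. The sketch below is a hypothetical check (`missing_install_paths` is not part of the model installer); the layout it tests mirrors the `scriptPath` and `activateScript` paths built later in this notebook and is an assumption about what the installer creates.

```python
import os

def missing_install_paths(script_install_dir):
    """Return the expected installer outputs that do not exist on disk.

    Assumed layout (mirrors the paths this notebook builds):
      <dir>/model_wrapper/run_inference_nii.py
      <dir>/conda-pack/bin/activate
    """
    expected = [
        os.path.join(script_install_dir, 'model_wrapper', 'run_inference_nii.py'),
        os.path.join(script_install_dir, 'conda-pack', 'bin', 'activate'),
    ]
    return [p for p in expected if not os.path.exists(p)]
```

An empty list from `missing_install_paths(scriptInstallDir)` suggests the installation step completed.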

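### Running the model
With the packaged Conda environment activated, run the inference script on a directory of NIfTI inputs:
```python
!python {scriptPath} {input_nii_directory} {output_nii_directory}
```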
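
The segmentation loop below activates the Conda environment and calls the script in a single bash command. A minimal sketch of building that command with quoted paths (so directories containing spaces still work) is shown here; `build_inference_cmd` is a hypothetical helper, and the resulting string would be passed to `subprocess.run(cmd, shell=True, executable="/bin/bash", check=True)`.

```python
import shlex

def build_inference_cmd(activate_script, script_path, input_dir, output_dir):
    """Build a bash command that activates the packaged Conda env, then runs inference."""
    q = shlex.quote  # quote each path so spaces or shell metacharacters are safe
    return (f"source {q(activate_script)} && "
            f"python {q(script_path)} {q(input_dir)} {q(output_dir)}")
```

Using `source` requires bash, which is why the loop sets `executable="/bin/bash"` rather than relying on the default `/bin/sh`.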

## License

By downloading the software you are agreeing to the following terms and conditions as well as to the Terms of Use of CERR software.

**`THE SOFTWARE IS PROVIDED "AS IS" AND CERR DEVELOPMENT TEAM AND ITS COLLABORATORS DO NOT MAKE ANY WARRANTY, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE, NOR DO THEY ASSUME ANY LIABILITY OR RESPONSIBILITY FOR THE USE OF THIS SOFTWARE.`**

`This software is for research purposes only and has not been approved for clinical use.`

`Software has not been reviewed or approved by the Food and Drug Administration, and is for non-clinical, IRB-approved Research Use Only. In no event shall data or images generated through the use of the Software be used in the provision of patient care.`
  
`YOU MAY NOT DISTRIBUTE COPIES of this software, or copies of software derived from this software, to others outside your organization without specific prior written permission from the CERR development team except where noted for specific software products.`

`All Technology and technical data delivered under this Agreement are subject to US export control laws and may be subject to export or import regulations in other countries. You agree to comply strictly with all such laws and regulations and acknowledge that you have the responsibility to obtain
such licenses to export, re-export, or import as may be required after delivery to you.`

**`You may publish papers and books using results produced using software provided you cite the following`**:
  
  * **AI model**: https://doi.org/10.48550/arXiv.1909.05054
  * **CERR model library**: https://doi.org/10.1016/j.ejmp.2020.04.011



# Downloads

### Install ***pyCERR***

pyCERR handles data import/export and the pre- and post-processing transformations for this auto-segmentation model.

In [None]:
%%capture
!pip install "pyCERR[napari] @ git+https://github.com/cerr/pyCERR.git@testing"


## Prepare planning CTs (DICOM)

In [None]:
import os
workDir = r'/content' # For Colab

### We will use the sample head & neck CT dataset distributed with pyCERR for this demo.

In [None]:
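import cerr
# Sample head & neck CT dataset bundled with pyCERR, located via the installed package
# instead of a hard-coded Python version in the path
inputDicomPath = os.path.join(list(cerr.__path__)[0], 'datasets', 'sample_ct')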

### Alternatives

Uncomment the following cell to download data from a user-specified ***dataUrl*** to ***dataDownloadDir***.

See also: [demo notebook](https://github.com/cerr/pyCERR-Notebooks/blob/main/download_data_from_xnat.ipynb) for downloading data from XNAT.

In [None]:
# dataUrl = 'http://path.to/data'
# dataDownloadDir = os.path.join(workDir, 'sampleData')
# os.makedirs(dataDownloadDir, exist_ok=True)
# ! wget -O sampleData.gz -L {dataUrl}
# ! tar xf sampleData.gz -C {dataDownloadDir}
# ! rm sampleData.gz

# #Paths to input data
# inputDicomPath = os.path.join(dataDownloadDir,'your_dir_name_here')  # Replace with the appropriate path to your dataset

### Define paths to the output directory and to a session directory that stores temporary files during model execution

In [None]:
# Output directory for RTSTRUCT files and session directory for temporary files
outputDicomPath = os.path.join(workDir, 'AIoutput')
sessionPath = os.path.join(workDir, 'temp')

os.makedirs(outputDicomPath, exist_ok=True)
os.makedirs(sessionPath, exist_ok=True)

## Download the pre-trained model, inference script, and packaged Conda environment to ***scriptInstallDir***


In [None]:
%%capture

# Download model installer
os.chdir(workDir)
!git clone https://github.com/cerr/model_installer.git
os.chdir(os.path.join(workDir,'model_installer'))

# Install CT HN OAR model
!./installer.sh
modelOpt = '6'  # CT_HeadAndNeck_OARs
pythonOpt = 'C' # Download packaged Conda environment

! source ./installer.sh -m {modelOpt} -d {workDir} -p {pythonOpt}

In [None]:
# Location of inference script
scriptInstallDir = os.path.join(workDir, 'CT_HeadAndNeck_OARs')
scriptPath = os.path.join(scriptInstallDir,
                         'model_wrapper',
                         'run_inference_nii.py')

# Location of Conda archive
condaEnvPath = os.path.join(scriptInstallDir, 'conda-pack')

# Location of activation script for Conda environment
activateScript = os.path.join(condaEnvPath,'bin','activate')

# Data processing

## Pre-processing

### `processInputData`: Identify the input scan and resample it in-plane to 1 mm x 1 mm. Next, crop each slice to the patient outline and resize the cropped region in-plane to 256 x 256 voxels by padding with zeros.
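
pyCERR coordinates are in centimetres, which is why `outputResV = [0.1, 0.1, 0]` corresponds to 1 mm in-plane. To make the arithmetic concrete, the sketch below builds a centre-aligned resampled axis; `centered_resample_axis` is a hypothetical illustration, not pyCERR's `getResampledGrid`, whose exact grid conventions may differ.

```python
import numpy as np

def centered_resample_axis(vals, new_res):
    """Return ascending, evenly spaced coordinates (spacing new_res) centred on vals.

    Illustrative only: the number of samples is chosen so the new grid covers
    approximately the same extent as the original axis.
    """
    vals = np.asarray(vals, dtype=float)
    extent = abs(vals[-1] - vals[0])
    n = int(np.floor(extent / new_res + 1e-6)) + 1  # small tolerance for float error
    center = 0.5 * (vals[0] + vals[-1])
    return center + (np.arange(n) - (n - 1) / 2.0) * new_res
```

For instance, 512 voxels at 0.0977 cm span about 49.9 cm, so resampling to 0.1 cm yields 500 samples per row under this convention.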

In [None]:
import numpy as np

import cerr
from cerr import plan_container as pc
from cerr.utils.ai_pipeline import getScanNumFromIdentifier
from cerr.dataclasses import scan
from cerr.radiomics.preprocess import getResampledGrid, imgResample3D
from cerr.utils.mask import computeBoundingBox, getPatientOutline
from cerr.utils.image_proc import resizeScanAndMask

def processInputData(planC):

  # Processing parameters
  modality = 'CT'
  identifier = {'imageType':'CT SCAN'}

  gridType = 'center'
  resampMethod = 'sitkBSpline'
  outputResV = [0.1, 0.1, 0]  # Output res in cm (pyCERR units): 1 mm x 1 mm in-plane

  resizeMethod = 'pad2d'
  outSizeV = [256,256]
  inputMask3M = None

  outlineStructName = 'outline'
  intensityThreshold = -400   #Air intensity for outline detection

  #--------------------------------------------------
  #          Extract input scan
  #---------------------------------------------------
  # Get scan array
  scanNum = getScanNumFromIdentifier(identifier, planC)[0]
  xValsV, yValsV, zValsV = planC.scan[scanNum].getScanXYZVals()
  scan3M = planC.scan[scanNum].getScanArray()

  #--------------------------------------------------
  #          Process input scan
  #---------------------------------------------------
  # 1. Resample
  [xResampleV,yResampleV,zResampleV] = getResampledGrid(outputResV,
                                                        xValsV, yValsV, zValsV,\
                                                        gridType)
  resampScan3M = imgResample3D(scan3M,
                               xValsV, yValsV, zValsV,\
                               xResampleV, yResampleV, zResampleV,\
                               resampMethod, inPlane=True)

  resampleGridS = [xResampleV, yResampleV, zResampleV]
  planC = pc.importScanArray(resampScan3M,
                             resampleGridS[0], resampleGridS[1], resampleGridS[2],\
                             modality, scanNum, planC)
  resampleScanNum = len(planC.scan) - 1

  # 2. Extract patient outline (mask is defined on the resampled grid)
  replaceStrNum = None
  outline3M = getPatientOutline(resampScan3M, intensityThreshold)
  planC = pc.importStructureMask(outline3M, resampleScanNum,
                                 outlineStructName,
                                 planC, replaceStrNum)

  # 3. Crop to patient outline on each slice
  sumSlices = np.sum(outline3M, axis=(0, 1))
  validSlicesV = np.where(sumSlices > 0)[0]
  numSlcs = len(validSlicesV)
  limitsM = np.zeros((numSlcs,4))

  for slc in range(numSlcs):
    minr, maxr, minc, maxc, _, _, _ = computeBoundingBox(\
                                               outline3M[:,:,validSlicesV[slc]],
                                               is2DFlag=True)
    limitsM[slc,:] = [minr, maxr, minc, maxc]

  # 4. Resize to 256 x 256 in-plane
  resampSlc3M = resampScan3M[:,:,validSlicesV]
  slcGridS = (resampleGridS[0], resampleGridS[1], resampleGridS[2][validSlicesV])
  procScan3M, maskOut4M, resizeGridS = resizeScanAndMask(resampSlc3M,
                                                         inputMask3M,
                                                         slcGridS,
                                                         outSizeV,
                                                         resizeMethod,
                                                         limitsM=limitsM)
  planC = pc.importScanArray(procScan3M,
                             resizeGridS[0], resizeGridS[1], resizeGridS[2],\
                             modality, scanNum, planC)
  procScanNum = len(planC.scan) - 1

  scanList = [scanNum, resampleScanNum, procScanNum]

  return scanList, validSlicesV, resizeGridS, limitsM, planC
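
Steps 2 to 4 above can be illustrated with plain NumPy. The sketch below is a simplified stand-in, not pyCERR's `getPatientOutline` or `resizeScanAndMask`: the outline here is only a threshold at -400 HU (the real function may also fill holes or keep the largest region), `slice_bbox` plays the role of one row of `limitsM`, and the crop is zero-padded into the centre of a 256 x 256 canvas. All helper names are hypothetical, and the cropped region is assumed to fit within the output size.

```python
import numpy as np

def threshold_outline(scan3M, threshold=-400):
    """Crude body mask: voxels brighter than threshold (HU)."""
    return scan3M > threshold

def slice_bbox(mask2M):
    """Return inclusive (minr, maxr, minc, maxc) bounds of a non-empty 2D mask."""
    rows = np.where(mask2M.any(axis=1))[0]
    cols = np.where(mask2M.any(axis=0))[0]
    return rows[0], rows[-1], cols[0], cols[-1]

def crop_and_pad(slice2M, bbox, out_size=(256, 256), fill=0):
    """Crop slice2M to bbox and pad it with `fill` into the centre of an out_size canvas."""
    minr, maxr, minc, maxc = bbox
    crop = slice2M[minr:maxr + 1, minc:maxc + 1]
    h, w = crop.shape
    if h > out_size[0] or w > out_size[1]:
        raise ValueError('cropped region is larger than the output size')
    out = np.full(out_size, fill, dtype=slice2M.dtype)
    r0, c0 = (out_size[0] - h) // 2, (out_size[1] - w) // 2
    out[r0:r0 + h, c0:c0 + w] = crop
    return out

def unpad_and_uncrop(padded2M, bbox, orig_shape, fill=0):
    """Inverse of crop_and_pad: put the central region back at bbox in orig_shape."""
    minr, maxr, minc, maxc = bbox
    h, w = maxr - minr + 1, maxc - minc + 1
    r0, c0 = (padded2M.shape[0] - h) // 2, (padded2M.shape[1] - w) // 2
    out = np.full(orig_shape, fill, dtype=padded2M.dtype)
    out[minr:maxr + 1, minc:maxc + 1] = padded2M[r0:r0 + h, c0:c0 + w]
    return out
```

Keeping the bounding box per slice is what makes the post-processing step reversible: the same limits are needed to map the model's 256 x 256 output back onto the resampled grid.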


## Post-processing

### `postProcAndImportSeg`: Import the label maps to planC, undo the pre-processing transformations, and retain only the largest connected component of each structure.

In [None]:
# Map output labels to structure names
strToLabelMap = {1:"Left Parotid", 2:"Right Parotid", 3:"Left Submandibular",\
                 4:"Right Submandibular", 7:"Mandible", 8:"Spinal cord",\
                 9:"Brain stem", 10:"Oral cavity"}
outputStrLabels = list(strToLabelMap.keys())
outputStrNames = list(strToLabelMap.values())
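
`postProcAndImportSeg` reads channel `label - 1` of the 4D model output (indexed as `[:, :, :, channel]`) for each entry of `strToLabelMap`. A small NumPy sketch of that indexing is below; `split_label_channels` is a hypothetical helper and the test uses a made-up array in place of real model output.

```python
import numpy as np

def split_label_channels(mask4M, str_to_label):
    """Map structure name -> boolean 3D mask, taking channel (label - 1) of mask4M.

    mask4M is indexed as [rows, cols, slices, channel]; labels are 1-based,
    so label 7 ("Mandible") reads channel 6.
    """
    return {name: mask4M[..., label - 1] > 0 for label, name in str_to_label.items()}
```

Labels absent from the map (5 and 6 here) are simply never read.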

In [None]:
#Import label map to CERR
import glob
import SimpleITK as sitk
from cerr.dataclasses import structure
from cerr.contour import rasterseg as rs

def postProcAndImportSeg(outputDir, scanList, outSlicesV, resizeGridS,
                         limitsM, planC):


  origScanNum = scanList[0]
  resampScanNum = scanList[1]
  procScanNum = scanList[2]
  numStrOrig = len(planC.structure)

  #--------------------------------------------------
  #              Read AI-generated mask
  #---------------------------------------------------
  niiGlob = glob.glob(os.path.join(outputDir, '*.nii.gz'))
  print('Importing ' + niiGlob[0] + '...')
  outputMask4M = sitk.ReadImage(niiGlob[0])  # glob already returns the full path
  outputMask4M = np.moveaxis(sitk.GetArrayFromImage(outputMask4M),0,2)

  #--------------------------------------------------
  #      Undo pre-processing transformations
  #---------------------------------------------------
  #1. Undo resizing
  outputScan3M = None
  method = 'unpad2d'
  resampSizeV = planC.scan[resampScanNum].getScanSize()
  resampOutSizeV = [resampSizeV[0], resampSizeV[1], len(outSlicesV)]
  print('O/p mask shape')
  print(outputMask4M.shape)

  _, unPadMask4M, unPadGridS = resizeScanAndMask(outputScan3M,
                                                outputMask4M,
                                                resizeGridS,
                                                resampOutSizeV,
                                                method,
                                                limitsM=limitsM)
  #2. Undo outline crop
  print('Undo crop')
  resampMask4M = np.full((resampSizeV[0], resampSizeV[1],
                          resampSizeV[2], unPadMask4M.shape[3]), 0)
  resampMask4M[:,:,outSlicesV,:] = unPadMask4M

  # Extract binary masks and retain largest connected component
  numComponents = 1
  replaceStrNum = None
  procStrV = []
  # Loop over labels
  for labelIdx in range(len(strToLabelMap)):

    # Import mask to processed scan
    strName = outputStrNames[labelIdx]
    maskIdx = outputStrLabels[labelIdx]-1
    outputMask = resampMask4M[:, :, :, maskIdx]
    planC = pc.importStructureMask(outputMask, resampScanNum,
                                   strName, planC, replaceStrNum)
    procStr = len(planC.structure)-1
    procStrV.append(procStr)

    # Copy to original scan
    planC = structure.copyToScan(procStr, origScanNum, planC)
    scanStr = len(planC.structure) - 1

    # Post-process and replace input structure in planC
    procMask3M = structure.getLargestConnComps(scanStr, numComponents,
                                               planC, saveFlag=True,
                                               replaceFlag=True,
                                               procSructName=strName)
    del planC.structure[procStr]


  return planC, procStrV
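
`structure.getLargestConnComps` keeps only the largest connected component of each structure, which removes small spurious islands from the AI output. The idea can be sketched with a breadth-first flood fill; this is an illustrative stand-in (6-connectivity assumed), not the pyCERR implementation, and in practice `scipy.ndimage.label` performs the labelling much faster.

```python
from collections import deque

import numpy as np

# 6-connectivity: neighbours share a face (illustrative choice).
OFFSETS = [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)]

def largest_connected_component(mask3M):
    """Return a boolean mask keeping only the largest 6-connected component of mask3M."""
    mask3M = np.asarray(mask3M, dtype=bool)
    shape = mask3M.shape
    labels = np.zeros(shape, dtype=np.int32)  # 0 = background or unvisited
    best_label, best_size, current = 0, 0, 0
    for seed in zip(*np.nonzero(mask3M)):
        if labels[seed]:
            continue  # voxel already belongs to a labelled component
        current += 1
        labels[seed] = current
        queue = deque([seed])
        size = 0
        while queue:  # breadth-first flood fill of one component
            r, c, s = queue.popleft()
            size += 1
            for dr, dc, ds in OFFSETS:
                nb = (r + dr, c + dc, s + ds)
                if not all(0 <= nb[i] < shape[i] for i in range(3)):
                    continue  # neighbour lies outside the volume
                if mask3M[nb] and not labels[nb]:
                    labels[nb] = current
                    queue.append(nb)
        if size > best_size:
            best_label, best_size = current, size
    if best_label == 0:
        return np.zeros(shape, dtype=bool)  # empty input mask
    return labels == best_label
```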

# Segment OARs

## Apply the AI model to all CT scans

### located in ***inputDicomPath***, and write the auto-segmentation results to ***outputDicomPath***

In [None]:
%%capture
import subprocess
import numpy as np

from cerr.dataclasses import scan as cerrScn
from cerr.dcm_export import rtstruct_iod
from cerr.utils.ai_pipeline import createSessionDir


# Loop over DICOM directories
fileList = os.listdir(inputDicomPath)
numFiles = len(fileList)
modality = 'CT'

for iFile in range(numFiles):

    inputFilename = fileList[iFile]
    dcmDir = os.path.join(inputDicomPath, inputFilename)

    # Create session dir to store temporary data
    modInputPath, modOutputPath = createSessionDir(sessionPath,
                                                   inputDicomPath)

    # Import DICOM scan to planC
    planC = pc.loadDcmDir(dcmDir)
    scanList, validSlicesV, procGridS, limitsM, planC = processInputData(planC)

    # Export processed scan to NIfTI
    origScanNum = scanList[0]
    procScanNum = scanList[2]
    numExistingStructs = len(planC.structure)
    scanFilename = os.path.join(modInputPath,
                                f"{inputFilename}_scan_3D.nii.gz")
    planC.scan[procScanNum].saveNii(scanFilename)

    # Apply model
    subprocess.run(f"source {activateScript} && python {scriptPath} \
                  {modInputPath} {modOutputPath}", \
                  capture_output=False, shell=True, executable="/bin/bash")

    # Import results to planC
    planC, procStrV = postProcAndImportSeg(modOutputPath, scanList,
                                           validSlicesV, procGridS,
                                           limitsM, planC)

    # Export segmentations to DICOM
    newNumStructs = len(planC.structure)
    structFileName = inputFilename + '_AI_seg.dcm'
    structFilePath = os.path.join(outputDicomPath, structFileName)
    seriesDescription = "AI Generated"
    exportOpts = {'seriesDescription': seriesDescription}

    structNumV = np.arange(numExistingStructs, newNumStructs)
    indOrigV = np.array([cerrScn.getScanNumFromUID(planC.structure[structNum].assocScanUID,\
                        planC) for structNum in structNumV], dtype=int)
    structsToExportV = structNumV[indOrigV == origScanNum]
    rtstruct_iod.create(structsToExportV, structFilePath, planC, exportOpts)
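
Because the cell above runs under `%%capture`, its printed progress is hidden. The loop names each RTSTRUCT file `<patient directory>_AI_seg.dcm` inside `outputDicomPath`, so a simple check for patients without output could look like this; both helpers are hypothetical and only reproduce that naming pattern.

```python
import os

def expected_rtstruct_paths(input_dicom_path, output_dicom_path, suffix='_AI_seg.dcm'):
    """Return {patient_dir_name: expected RTSTRUCT path}, following the loop's naming."""
    return {
        name: os.path.join(output_dicom_path, name + suffix)
        for name in sorted(os.listdir(input_dicom_path))
    }

def missing_outputs(input_dicom_path, output_dicom_path):
    """List patient directories whose RTSTRUCT file was not written."""
    expected = expected_rtstruct_paths(input_dicom_path, output_dicom_path)
    return [name for name, path in expected.items() if not os.path.isfile(path)]
```

Calling `missing_outputs(inputDicomPath, outputDicomPath)` should return an empty list when every patient was processed.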

## **Optional**: Uncomment the following to copy the output segmentations to your workspace bucket (requires the `WORKSPACE_BUCKET` environment variable to be set).

In [None]:
# workspaceBucket = os.environ['WORKSPACE_BUCKET']
# !gcloud storage cp -r {outputDicomPath} {workspaceBucket}

# Display results

## Overlay AI segmentations on scan for visualization using ***Matplotlib***

Note: This example displays the last segmented dataset by default.    
To view results for a different dataset, load the corresponding pyCERR archive into `planC`.

In [None]:
from cerr.viewer import showMplNb

showMplNb(planC=planC, scan_nums=origScanNum,
          struct_nums=structsToExportV,
          windowCenter=50, windowWidth=450)
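
`windowCenter=50, windowWidth=450` select which HU values are visible: the display range is center ± width/2, i.e. -175 HU to 275 HU here. A minimal NumPy sketch of a linear window follows; it is a common convention and not necessarily the exact mapping `showMplNb` applies internally.

```python
import numpy as np

def apply_window(hu, center=50.0, width=450.0):
    """Linearly map [center - width/2, center + width/2] HU to [0, 1], clipping outside."""
    low = center - width / 2.0
    high = center + width / 2.0
    return np.clip((np.asarray(hu, dtype=float) - low) / (high - low), 0.0, 1.0)
```

With these settings, air (-1000 HU) maps to black, soft tissue near 50 HU to mid-grey, and dense bone saturates to white.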
