# Setup

## Imports

In [None]:
import pyplastimatch as pypla
from monailabel.datastore.utils.convert import nifti_to_dicom_seg
import nibabel as nib
from glob import glob
import numpy as np 
import pydicom
import tempfile
import pydicom
import SimpleITK as sitk
import os

import logging
# Configure the root logger to emit records at DEBUG and above, so both the
# logging.info(...) progress messages and logging.debug(...) details in the
# cells below are shown.
# NOTE: because this is the root logger, debug output from imported libraries
# may appear as well.
# NOTE(review): basicConfig does nothing if the root logger already has
# handlers (this can happen inside some Jupyter kernels); pass force=True
# (Python 3.8+) if the messages below do not appear.
logging.basicConfig(level=logging.DEBUG)

## Config

In [None]:
# Root folders. INPUT_DIR may start with '~'; it is expanded below.
INPUT_DIR  = '~/data/nifti'
OUTPUT_DIR = './output'

# Expand '~' once per root. os.path.join never touches a leading '~', so
# expanding before joining yields exactly the same paths as expanding after.
_input_root  = os.path.expanduser(INPUT_DIR)
_output_root = os.path.expanduser(OUTPUT_DIR)

# Source NIfTI volume and its label map.
INPUT_SERIES       = os.path.join(_input_root, 'sample_mri.nii.gz')
INPUT_SEGMENTATION = os.path.join(_input_root, 'sample_segmentation.nii.gz')

# Destination for the DICOM series (a directory; note the trailing slash)
# and for the single DICOM SEG file.
OUTPUT_SERIES       = os.path.join(_output_root, 'series/')
OUTPUT_SEGMENTATION = os.path.join(_output_root, 'segmentation/seg.dcm')

# Conversion

## Series

In [None]:
# Empty the series output directory before plastimatch writes into it.
# The '*' pattern matches only non-hidden entries, and os.remove expects
# plain files (a subdirectory here would raise).
logging.info(f'🗑️ Deleting any existing files in {OUTPUT_SERIES}...')
existing_files = glob(OUTPUT_SERIES + '/*')
n_existing = len(existing_files)
preview = existing_files[:5]
logging.debug(f'Found {n_existing} existing files: {preview}...')
for stale_path in existing_files:
    logging.debug(f'Deleting {stale_path}...')
    os.remove(stale_path)
logging.info(f'✅ Deleted {n_existing} files from {OUTPUT_SERIES}.')

In [None]:
# Convert the input NIfTI series into a directory of DICOM files with
# plastimatch. The volume is round-tripped through nibabel into a temporary
# .nii.gz file, which is the file handed to plastimatch.
import subprocess

logging.info(f'🔄 Loading the series nifti file into memory...')
input_series_nifti = nib.load(INPUT_SERIES)
logging.info(f'🧠 Series nifti file loaded into memory.')

logging.info('🔄 Saving the series to a temporary file...')
with tempfile.NamedTemporaryFile(suffix='.nii.gz') as temp_series_file:
    nib.save(input_series_nifti, temp_series_file.name)
    logging.info(f'📁 Temporary series file created at: {temp_series_file.name}')

    logging.info('🔄 Converting series to DICOM with plastimatch...')
    # Run plastimatch with an argument list rather than the IPython `!` escape.
    # Paths containing spaces are passed intact, and a failing conversion
    # stops here instead of letting the segmentation cell run against a
    # missing or partial series. A missing binary raises FileNotFoundError.
    plastimatch_result = subprocess.run(
        [
            'plastimatch', 'convert',
            '--input', temp_series_file.name,
            '--output-dicom', OUTPUT_SERIES,
        ],
        capture_output=True,
        text=True,
    )
    logging.debug(f'plastimatch stdout:\n{plastimatch_result.stdout}')
    if plastimatch_result.returncode != 0:
        raise RuntimeError(
            f'plastimatch convert failed (exit code {plastimatch_result.returncode}):\n'
            f'{plastimatch_result.stderr}'
        )

# glob once and reuse the result for both the preview and the count.
created_files = glob(OUTPUT_SERIES + "/*")
logging.info(f'📂 Created files = {created_files[:5]}\n... ({len(created_files)} files in total)')

## Segmentation

In [None]:
# Convert the NIfTI segmentation into a DICOM SEG that references the DICOM
# series written by the previous cell, then save it to OUTPUT_SEGMENTATION.
logging.info(f'📊 Input series: {INPUT_SERIES = }')
logging.info(f'📊 Input segmentation: {INPUT_SEGMENTATION = }')

logging.info('🔄 Loading the nifti file into memory...')
input_segmentation_nifti = nib.load(INPUT_SEGMENTATION)
logging.info(f'🧠 Nifti file loaded into memory.')

logging.debug(f'{str(input_segmentation_nifti.header) = }')
logging.debug(f'{input_segmentation_nifti.header.get_data_shape() = }')

logging.info('🔄 Creating a temporary copy of the nifti file using nibabel...')
with tempfile.NamedTemporaryFile(suffix='.nii.gz') as temp_nifti_file:

    logging.info(f'🔄 Processing segmentation data...')
    internal_type = np.int16
    segmentation_data = input_segmentation_nifti.get_fdata().astype(internal_type)
    logging.debug(f'{segmentation_data.shape = }')
    logging.debug(f'{segmentation_data.dtype = }')

    # NOTE(review): an earlier comment here said "Isolate label 2 and convert
    # to binary format (0 or 1)", but no such filtering is done. Every label
    # value is kept as-is and only cast to `internal_type`. If a single label
    # is wanted, use e.g. (segmentation_data == 2).astype(internal_type).

    # nibabel keeps the on-disk dtype of a header passed to the constructor.
    # Without this copy+set_data_dtype, the saved file would use the source
    # file's dtype (possibly float), undoing the int16 cast above.
    segmentation_header = input_segmentation_nifti.header.copy()
    segmentation_header.set_data_dtype(internal_type)
    input_segmentation_nifti = nib.Nifti1Image(
        segmentation_data,
        input_segmentation_nifti.affine,
        segmentation_header,
    )
    logging.info('✅ Updated nifti file with integer segmentation data.')

    logging.info('🔄 Saving the nifti file to the temporary file...')
    nib.save(input_segmentation_nifti, temp_nifti_file.name)
    logging.info(f'📁 Temporary nifti file created at: {temp_nifti_file.name}')

    # The conversion must happen inside the `with` block, while the temporary
    # NIfTI file still exists.
    logging.info('🔄 Converting nifti to DICOM...')
    temp_segmentation_dicom_path = nifti_to_dicom_seg(
        series_dir=OUTPUT_SERIES,
        label=temp_nifti_file.name,
        # No per-label metadata is supplied. Per-label names could be passed
        # here, e.g. label_info={0: {'name': SEGMENTATION_LABEL}}.
        # NOTE(review): SEGMENTATION_LABEL is not defined in this notebook.
        label_info={},
        use_itk=False,
    )
    logging.debug(f'📊 temp_segmentation_dicom_path: {temp_segmentation_dicom_path}')

logging.info(f'📂 Loading DICOM SEG from {temp_segmentation_dicom_path}...')
try:
    dcm_seg = pydicom.dcmread(temp_segmentation_dicom_path)
finally:
    # The intermediate DICOM SEG is removed even if reading it fails.
    logging.info(f'🗑️ Deleting initial nifti to DICOM temporary file...')
    os.unlink(temp_segmentation_dicom_path)
logging.debug(f'📊 dcm_seg: {dcm_seg}')

logging.info('🔍 Finding a sample DICOM series to get FrameOfReferenceUID...')
series_files = glob(OUTPUT_SERIES + '/*')
if not series_files:
    raise FileNotFoundError(
        f'No DICOM files found in {OUTPUT_SERIES}; run the series conversion first.'
    )
dcm_series_sample = series_files[0]
logging.info(f'📂 Found DICOM series: {dcm_series_sample}')
logging.debug(f'📊 dcm_series_sample: {dcm_series_sample}')

logging.info('📝 Copying FrameOfReferenceUID from DICOM series to DICOM SEG...')
dcm_seg.FrameOfReferenceUID = pydicom.dcmread(dcm_series_sample).FrameOfReferenceUID
logging.debug(f'📊 dcm_seg.FrameOfReferenceUID: {dcm_seg.FrameOfReferenceUID}')

logging.info('💾 Saving the final DICOM SEG...')
output_segmentation_dir = os.path.dirname(OUTPUT_SEGMENTATION)
if output_segmentation_dir:
    # exist_ok=True so re-running this cell does not fail with FileExistsError.
    os.makedirs(output_segmentation_dir, exist_ok=True)
dcm_seg.save_as(OUTPUT_SEGMENTATION)
logging.info(f'📊 Saved to: {OUTPUT_SEGMENTATION}')

logging.info('✅ Done.')