In [146]:
import tensorflow as tf
import numpy as np
import scipy
import pydicom
import glob
import os
import re
import nibabel

In [81]:
PATIENTS_SRC_FOLDER = "/docs/src/kt/datasets/ct-150/data/"
LABELS_SRC_FOLDER = "/docs/src/kt/datasets/ct-150/labels/"
TFRECORD_FOLDER = "/docs/src/kt/datasets/ct-150/tfrecords/"
INPUT_DIMS = np.array([256, 256, 256])
AUGMENT_SCALE_FACTOR = 0.1
AUGMENT_SCALED_DIMS = tf.cast(tf.constant(input_dims, dtype=tf.float32) * (1 + scale_factor), dtype=tf.int32).numpy()

In [160]:
class CT150:
    def __init__(self, patients_src_folder, labels_src_folder, tfrecord_folder):
        self.patients_src_folder = patients_src_folder
        self.labels_src_folder = labels_src_folder
        self.tfrecord_folder = tfrecord_folder

    def GetPatientIDFromFolder(self, folder):
        result = None
        m = re.search("PANCREAS_(\d+)", folder)
        if(m):
            result = m.group(1)
        return result
        
    def ReadDICOMDataFromFiles(self, files):
        result = np.array([])
        
        slices = [pydicom.dcmread(file) for file in files]
        slices.sort(key = lambda x: float(x.ImagePositionPatient[2]) )

        if(len(slices)):
            result = np.stack([_.pixel_array for _ in slices], axis=-1)
        else:
            print("ERROR: can't dcmrad from files:", files)
        
        return result
        
    def GetDICOMData(self, folder):
        result = np.array([])
        file_list = glob.glob(os.path.join(folder, "*.dcm"))
        if(len(file_list)):
            result = self.ReadDICOMDataFromFiles(file_list)
            if(result.shape[0]):
                # --- nothing to do
                a = 1
            else:
                print("ERROR: reading DICOM data from files:", file_list)
        else:
            folders = glob.glob(os.path.join(folder, "*"))
            if(len(folders)):
                result = self.GetDICOMData(folders[0])
                
        return result
        
    def GetNiftiData(self, patient_id):
        nifti_data = nibabel.load(os.path.join(self.labels_src_folder, "label" + patient_id + ".nii"))
        return nifti_data.get_fdata()
        
    def ConsistencyCheck(self, data, label):
        result = False
    
        if(data.shape[0]):
            if(label.shape[0]):
                if(data.shape == label.shape):
                    result = True
                else:
                    print("ERROR: data shape is not equal to label shape")                
            else:
                print("ERROR: label shape is incorrect (", label.shape, ")")
        else:
            print("ERROR: data shape is incorrect (", data.shape, ")")
    
        return result
    
    def SaveTFRecords(self, patient_id, data, label):
        result = True
        
        with tf.io.TFRecordWriter(os.path.join(self.TFRECORD_FOLDER, patient_id + ".tfrecord")) as f:
            feature = {
                "original_shape": Feature(int64_list=tf.train.Int64List(value=data.shape())),
                "data": Feature(float_list=tf.train.FloatList(value=tf.random.normal([15]))),
            }

            example_proto = tf.train.Example(features = tf.train.Features(feature=feature))
            f.write(example_proto.SerializeToString())
        
        
        return result
    
    def PreprocessData(self, data, label):
        data_processed = (data - data.min()) / (data.max() - data.min())
        return data_processed, label
    
    def ReadSrcDataAndLabels_SaveAsTFRecords(self):
        folder_list = glob.glob(os.path.join(self.patients_src_folder, "*"))
        
        for folder in folder_list:
            patient_id = self.GetPatientIDFromFolder(folder)
            if(patient_id):
                # --- read data about patient
                print("Read data about patient", patient_id)
                
                src_data = self.GetDICOMData(folder)
                if(src_data.shape[0]):
                    # print("\tread from dicom files:", src_data.shape)
                    pass
                else:
                    print("ERROR: can't read DICOM data from folder:", folder)
                    
                label_data = self.GetNiftiData(patient_id)
                if(label_data.shape[0]):
                    # print("label mean:", label_data.get_fdata().mean())
                    pass
                else:
                    print("ERROR: can't find nifti labels:", patient_id)
                    
                if(self.ConsistencyCheck(src_data, label_data)):
                    # print("\tdata & labels are consistent")

                    src_data, label_data = self.PreprocessData(src_data, label_data)

                    if(self.SaveTFRecords(patient_id, src_data, label_data)):
                        pass
                    else:
                        print("ERROR: can't save TFRecord patient id:", patient_id)

                else:
                    print("ERROR: data & labels are not consistent")
            else:
                print("ERROR: identifying patient_id from folder:", folder)

ct150 = CT150(PATIENTS_SRC_FOLDER, LABELS_SRC_FOLDER, TFRECORD_FOLDER)
ct150.ReadSrcDataAndLabels_SaveAsTFRecords()

Read data about patient 0001


AttributeError: 'CT150' object has no attribute 'TFRECORD_FOLDER'

In [104]:
m = re.search("PANCREAS_(\d+)", "folder\\PANCREAS_00001")
if(m):
    print("yeah", m.group())
else:
    print("no")

yeah 00001
