In [157]:
import ast
import glob
from io import BytesIO
from pathlib import Path

import matplotlib.pylab as plt
import numpy as np
import pydicom
from PIL import Image

from src import gen_hash
from src import sqlite_functions


In [158]:
# Path to the SQLite file this notebook reads from (relative to the notebook's working directory).
database = Path('../Feb2024_Test_Extraction/dicom_processing/MRI_Segmentation-4_Participants_Data.sqlite3')
# Alternative database; uncomment to use it instead of the one above.
#database = Path('../Feb2024_Test_Extraction/dicom_processing/MRI_Segmentation-Selected_Daphne_Data.sqlite3')

# Open the connection through the project helper; every later cell reuses `sqlite_connection`.
# NOTE(review): `timeout` is presumably the SQLite busy timeout in seconds — confirm in src.sqlite_functions.
sqlite_connection = sqlite_functions.connect_to_database(database, timeout=30.0)

Trying to connect to ../Feb2024_Test_Extraction/dicom_processing/MRI_Segmentation-4_Participants_Data.sqlite3...

Connected to ../Feb2024_Test_Extraction/dicom_processing/MRI_Segmentation-4_Participants_Data.sqlite3!
SQLite Database Version is: [('3.44.2',)]



In [159]:
# Sanity check: ask the SQLite engine for its version through the open connection.
cursor = sqlite_connection.cursor()
version_query = "select sqlite_version()"
cursor.execute(version_query)
# `version_record` holds the fetchall() rows; none of the visible cells display or reuse it.
version_record = cursor.fetchall()

In [160]:
def get_column_values(connection, table, column):
    """
    Fetch every value stored in one column of a table.

    Returns the raw ``fetchall()`` result, i.e. a list holding one
    single-element tuple per row. ``table`` and ``column`` are pasted
    directly into the SQL text, so they must be trusted identifiers.
    """
    sql = f"SELECT {column} FROM {table}"
    db_cursor = connection.cursor()
    db_cursor.execute(sql)
    return db_cursor.fetchall()

def get_column_a_for_column_b(connection, table, column_a, column_b, value_b):
    """
    Return the ``column_a`` values of every row where ``column_b`` equals ``value_b``.

    Parameters
    ----------
    connection : DB-API connection (anything exposing ``cursor()``).
    table, column_a, column_b : str
        SQL identifiers. Identifiers cannot be bound as parameters, so they
        are still interpolated into the statement and must be trusted.
    value_b : value to match. It is compared as text, exactly like the
        quoted literal this function used to build.

    Returns
    -------
    list of tuple
        The raw ``fetchall()`` result: one 1-tuple per matching row.
    """
    cursor = connection.cursor()
    query = f"SELECT {column_a} FROM {table} WHERE {column_b} = ?"
    # Bind the value instead of splicing it between quotes: a value containing
    # a single quote used to break (or silently change) the statement.
    # str() keeps the comparison textual, matching the former '{value_b}' literal.
    cursor.execute(query, (str(value_b),))
    return cursor.fetchall()

def get_column_a_for_column_b_and_column_c(connection, table, column_a, column_b, value_b, column_c, value_c):
    """
    Return the ``column_a`` values of every row matching both conditions
    ``column_b == value_b`` and ``column_c == value_c``.

    Parameters
    ----------
    connection : DB-API connection (anything exposing ``cursor()``).
    table, column_a, column_b, column_c : str
        SQL identifiers. Identifiers cannot be bound as parameters, so they
        are still interpolated into the statement and must be trusted.
    value_b, value_c : values to match. They are compared as text, exactly
        like the quoted literals this function used to build.

    Returns
    -------
    list of tuple
        The raw ``fetchall()`` result: one 1-tuple per matching row.
    """
    cursor = connection.cursor()
    query = f"SELECT {column_a} FROM {table} WHERE {column_b} = ? AND {column_c} = ?"
    # Bound parameters: values containing a single quote no longer break or
    # rewrite the statement. str() keeps the former text-literal semantics.
    cursor.execute(query, (str(value_b), str(value_c)))
    return cursor.fetchall()

def convert_image_bytes_to_numpy(image_data, image_dims):
    """
    Decode a serialized NumPy array (the bytes ``np.save`` produces) into a
    ``uint8`` array reshaped to ``image_dims``.
    """
    # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects;
    # only safe while the database contents are trusted.
    buffer = BytesIO(image_data)
    pixels = np.load(buffer, allow_pickle=True)
    return pixels.astype(np.uint8).reshape(image_dims)

def convert_timeseries_bytes_to_numpy(timeseries_data):
    """
    Decode a serialized NumPy array (the bytes ``np.save`` produces) and
    return it cast to ``float32``; the stored shape is kept.
    """
    # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects;
    # only safe while the database contents are trusted.
    buffer = BytesIO(timeseries_data)
    samples = np.load(buffer, allow_pickle=True)
    return samples.astype(np.float32)

def create_image_file(connection, table, patient_id, dataofexam, image_hash):
    """
    Save the scan stored under ``image_hash`` as a PNG in the current working directory.

    The file name is ``{patient_id}_{dataofexam}_{first 8 characters of image_hash}.png``.

    Parameters
    ----------
    connection : DB-API connection to the image database.
    table : str
        Table name; interpolated into the SQL text, so it must be trusted.
    patient_id, dataofexam : used only to build the output file name.
    image_hash : str
        Hash identifying the row to export. If several rows share the hash,
        the first one returned is used.

    Raises
    ------
    IndexError
        If no row has the given ``image_hash``.
    """
    cursor = connection.cursor()
    # The hash is bound as a parameter instead of being spliced into the SQL text.
    query = f"SELECT image, image_dims FROM {table} WHERE image_hash = ?"
    cursor.execute(query, (str(image_hash),))
    row = cursor.fetchone()
    if row is None:
        # Same exception type the former `fetchall()[0]` raised, with a usable message.
        raise IndexError(f"no row with image_hash {image_hash!r} in table {table!r}")
    image_data, image_dims_text = row
    # literal_eval parses the stored dimensions without executing database text as code
    # (assumes image_dims is stored as a text literal such as "(852, 1136, 3)").
    image_dims = ast.literal_eval(image_dims_text)
    image_np = convert_image_bytes_to_numpy(image_data, image_dims)
    image = Image.fromarray(image_np)
    image.save(f"{patient_id}_{dataofexam}_{image_hash[:8]}.png")

def create_image_mask(image_np, mask_rows=31, mask_cols=600):
    """
    Build a multiplicative mask that blanks the top-left corner of an image.

    Parameters
    ----------
    image_np : numpy.ndarray
        Image whose shape the mask copies. Must have at least two dimensions
        (rows, columns); any trailing dimensions (e.g. channels) are masked whole.
    mask_rows, mask_cols : int, optional
        Extent of the blanked rectangle, measured from the top-left corner.
        The defaults (31 rows by 600 columns) reproduce the original fixed mask.

    Returns
    -------
    numpy.ndarray
        ``uint8`` array shaped like ``image_np``: 0 inside the rectangle, 1 elsewhere.
    """
    # Allocate as uint8 directly instead of building float64 ones and casting.
    img_mask = np.ones(image_np.shape, dtype=np.uint8)
    # The ellipsis covers zero or more trailing axes, so 2-D grayscale images
    # work as well as the 3-D images the fixed `[..., :]` indexing required.
    img_mask[:mask_rows, :mask_cols, ...] = 0
    return img_mask

def plot_image_file(connection, table, image_hash):
    """
    Display the masked scan stored under ``image_hash``.

    The image is loaded from the database, multiplied by the mask from
    ``create_image_mask`` and shown with ``PIL.Image.show()``. Nothing is
    written to disk by this function itself.

    Parameters
    ----------
    connection : DB-API connection to the image database.
    table : str
        Table name; interpolated into the SQL text, so it must be trusted.
    image_hash : str
        Hash identifying the row to display. If several rows share the hash,
        the first one returned is used.

    Raises
    ------
    IndexError
        If no row has the given ``image_hash``.
    """
    cursor = connection.cursor()
    # The hash is bound as a parameter instead of being spliced into the SQL text.
    query = f"SELECT image, image_dims FROM {table} WHERE image_hash = ?"
    cursor.execute(query, (str(image_hash),))
    row = cursor.fetchone()
    if row is None:
        # Same exception type the former `fetchall()[0]` raised, with a usable message.
        raise IndexError(f"no row with image_hash {image_hash!r} in table {table!r}")
    image_data, image_dims_text = row
    # literal_eval parses the stored dimensions without executing database text as code
    # (assumes image_dims is stored as a text literal such as "(852, 1136, 3)").
    image_dims = ast.literal_eval(image_dims_text)
    image_np = convert_image_bytes_to_numpy(image_data, image_dims)
    img_mask = create_image_mask(image_np)
    image_np = image_np * img_mask
    image = Image.fromarray(image_np)
    image.show()


def extract_all_images_for_patient(connection, table, patient_id):
    """
    Export every image stored for ``patient_id`` as a PNG file.

    Files are written to ``example_data/{patient_id}/`` and named
    ``{patient_id}_{dateofexam}_{first 8 characters of image_hash}.png``.
    The directory is created when missing; nothing is created when the
    patient has no rows.

    Parameters
    ----------
    connection : DB-API connection to the image database.
    table : str
        Table name; interpolated into the SQL text, so it must be trusted.
    patient_id : patient identifier, matched as text against the ``patient_id`` column.
    """
    cursor = connection.cursor()
    # The patient id is bound as a parameter instead of being spliced into the SQL text.
    query = f"SELECT image, image_dims, dateofexam, image_hash FROM {table} WHERE patient_id = ?"
    cursor.execute(query, (str(patient_id),))
    image_collection = cursor.fetchall()
    if not image_collection:
        return
    # Saving used to fail with FileNotFoundError when the folder did not exist yet.
    output_dir = Path(f"example_data/{patient_id}")
    output_dir.mkdir(parents=True, exist_ok=True)
    for image_data, image_dims, dateofexam, image_hash in image_collection:
        # literal_eval parses the stored dimensions without executing database text as code
        # (assumes image_dims is stored as a text literal such as "(852, 1136, 3)").
        image_dims = ast.literal_eval(image_dims)
        image_np = convert_image_bytes_to_numpy(image_data, image_dims)
        image = Image.fromarray(image_np)
        image.save(f"example_data/{patient_id}/{patient_id}_{dateofexam}_{image_hash[:8]}.png")

def extract_and_mask_all_images_for_patient(connection, table, patient_id):
    """
    Export every image stored for ``patient_id`` as a masked PNG file.

    Each image is multiplied by the mask from ``create_image_mask`` before
    saving. Files are written to the current working directory and named
    ``{patient_id}_{dateofexam}_{first 8 characters of image_hash}.png``.

    Parameters
    ----------
    connection : DB-API connection to the image database.
    table : str
        Table name; interpolated into the SQL text, so it must be trusted.
    patient_id : patient identifier, matched as text against the ``patient_id`` column.
    """
    cursor = connection.cursor()
    # The patient id is bound as a parameter instead of being spliced into the SQL text.
    query = f"SELECT image, image_dims, dateofexam, image_hash FROM {table} WHERE patient_id = ?"
    cursor.execute(query, (str(patient_id),))
    image_collection = cursor.fetchall()
    for image_data, image_dims, dateofexam, image_hash in image_collection:
        # literal_eval parses the stored dimensions without executing database text as code
        # (assumes image_dims is stored as a text literal such as "(852, 1136, 3)").
        image_dims = ast.literal_eval(image_dims)
        image_np = convert_image_bytes_to_numpy(image_data, image_dims)
        img_mask = create_image_mask(image_np)
        image_np = image_np * img_mask
        image = Image.fromarray(image_np)
        image.save(f"{patient_id}_{dateofexam}_{image_hash[:8]}.png")


def clean_up_and_replace_images(connection, table):
    """
    Mask every image in ``table`` and write the masked version back to the database.

    The list of image hashes is read first; each image is then loaded,
    multiplied by the mask from ``create_image_mask``, re-serialized with
    ``np.save`` and stored in place of the original.

    WARNING: this overwrites the ``image`` column; the unmasked data is not kept.

    Parameters
    ----------
    connection : DB-API connection to the image database.
    table : str
        Table name; interpolated into the SQL text, so it must be trusted.
    """
    select_query = f"SELECT image, image_dims FROM {table} WHERE image_hash = ?"
    update_query = f"UPDATE {table} SET image = ? WHERE image_hash = ?"
    cursor = connection.cursor()
    for (image_hash,) in get_column_values(connection, table, 'image_hash'):
        # Hashes are bound as parameters instead of being spliced into the SQL text;
        # str() keeps the comparison textual, as with the former quoted literal.
        hash_param = str(image_hash)
        cursor.execute(select_query, (hash_param,))
        image_data, image_dims_text = cursor.fetchall()[0]
        # literal_eval parses the stored dimensions without executing database text as code
        # (assumes image_dims is stored as a text literal such as "(852, 1136, 3)").
        image_dims = ast.literal_eval(image_dims_text)
        # Shared decoder replaces the inline copy of the same load/astype/reshape steps.
        image_np = convert_image_bytes_to_numpy(image_data, image_dims)
        image_np = image_np * create_image_mask(image_np)
        image_bytes = BytesIO()
        np.save(image_bytes, image_np)
        # getvalue() returns the whole buffer regardless of position, so no seek is needed.
        cursor.execute(update_query, (image_bytes.getvalue(), hash_param))
        # Commit per image, as before, so completed rows survive a failure later in the loop.
        connection.commit()


def extract_numpy_data(connection, table, column, image_hash):
    """
    Load the array stored in ``column`` for the row identified by ``image_hash``.

    Parameters
    ----------
    connection : DB-API connection to the image database.
    table, column : str
        SQL identifiers; interpolated into the SQL text, so they must be trusted.
    image_hash : hash identifying the row, matched as text. If several rows
        share the hash, the first one returned is used.

    Returns
    -------
    numpy.ndarray or None
        The decoded ``float32`` array, or ``None`` when there is no matching
        row or the stored value cannot be decoded (deliberate best-effort).
    """
    cursor = connection.cursor()
    # The hash is bound as a parameter instead of being spliced into the SQL text.
    query = f"SELECT {column} FROM {table} WHERE image_hash = ?"
    cursor.execute(query, (str(image_hash),))
    column_data = cursor.fetchall()
    try:
        column_np = convert_timeseries_bytes_to_numpy(column_data[0][0])
    except Exception:
        # Still best-effort, but no longer a bare `except:` that would also
        # swallow KeyboardInterrupt and SystemExit.
        column_np = None
    return column_np

def extract_and_save_signal_and_taxis(connection, table):
    """
    Export the ``signal`` and ``t_axis`` arrays of every row in ``table`` as text files.

    For each image hash the arrays are written with ``np.savetxt`` to
    ``example_data/{patient_id}/signal_{patient_id}_{dateofexam}_{hash8}.csv`` and
    ``example_data/{patient_id}/t_axis_{patient_id}_{dateofexam}_{hash8}.csv``,
    where ``hash8`` is the first 8 characters of the hash. An array that is
    missing or cannot be decoded is skipped. The output directory is created
    when missing.

    Parameters
    ----------
    connection : DB-API connection to the image database.
    table : str
        Table name; interpolated into the SQL text, so it must be trusted.
    """
    image_hashes = get_column_values(connection, table, 'image_hash')
    # The hash is bound as a parameter instead of being spliced into the SQL text.
    query = f"SELECT patient_id, dateofexam FROM {table} WHERE image_hash = ?"
    cursor = connection.cursor()
    for (image_hash,) in image_hashes:
        signal_np = extract_numpy_data(connection, table, 'signal', image_hash)
        taxis_np = extract_numpy_data(connection, table, 't_axis', image_hash)
        if signal_np is None and taxis_np is None:
            # Nothing to write for this row: skip the lookup and the directory creation.
            continue
        cursor.execute(query, (str(image_hash),))
        patient_id, dateofexam = cursor.fetchall()[0]
        # np.savetxt used to fail with FileNotFoundError when the folder did not exist yet.
        Path(f"example_data/{patient_id}").mkdir(parents=True, exist_ok=True)
        if signal_np is not None:
            np.savetxt(f"example_data/{patient_id}/signal_{patient_id}_{dateofexam}_{image_hash[:8]}.csv", signal_np)
        if taxis_np is not None:
            np.savetxt(f"example_data/{patient_id}/t_axis_{patient_id}_{dateofexam}_{image_hash[:8]}.csv", taxis_np)


In [161]:
# Image hashes recorded for one participant, as fetchall() rows (one 1-tuple per image).
# `table_name` was never defined in this notebook; the table is named explicitly,
# matching the 'images' table every other cell queries.
image_hashes = get_column_a_for_column_b(sqlite_connection, 'images', 'image_hash', 'patient_id', 'DAPHNE-14')

In [162]:
# Write the first DAPHNE-14 image to a PNG in the working directory.
# `dataofexam` (sic) is the keyword name create_image_file declares; it is only used in the file name.
create_image_file(connection=sqlite_connection, table='images', patient_id='DAPHNE-14', dataofexam='2024-02-01', image_hash=image_hashes[0][0])
# Alternative: export every image of a participant with the corner mask applied.
#extract_and_mask_all_images_for_patient(connection=sqlite_connection, table='images', patient_id='DAPHNE-43')

In [116]:
# Export every stored image of DAPHNE-1 as PNG files under example_data/DAPHNE-1/.
extract_all_images_for_patient(connection=sqlite_connection, table='images', patient_id='DAPHNE-1')

In [137]:
#plot_image_file(connection=sqlite_connection, table='images', image_hash=image_hashes[0][0])

# Fetch the raw serialized `signal` column for the fifth DAPHNE-14 image.
# `extract_signal` is not defined in this notebook; get_column_a_for_column_b
# returns the fetchall() rows that the next cell indexes as signal_test[0][0].
signal_test = get_column_a_for_column_b(sqlite_connection, 'images', 'signal', 'image_hash', image_hashes[4][0])

In [142]:
# Decode the serialized signal by hand: first row, first column of the query result.
# These are the same steps convert_timeseries_bytes_to_numpy performs.
signal_bytes = BytesIO(signal_test[0][0])
signal_np = np.load(signal_bytes, allow_pickle=True).astype(np.float32)

In [156]:
# Write the signal and t_axis arrays of every row in 'images' to text files under example_data/<patient_id>/.
extract_and_save_signal_and_taxis(connection=sqlite_connection, table='images')

In [114]:
# WARNING: modifies the database — every stored image is replaced by its masked version and committed.
clean_up_and_replace_images(connection=sqlite_connection, table='images')

In [155]:
# Export every stored image, one folder per participant under example_data/.
patients = get_column_values(sqlite_connection, 'images', 'patient_id')
# np.unique flattens the fetchall() list of 1-tuples and drops duplicate IDs.
for patient in np.unique(patients):
    # parents/exist_ok: a bare mkdir() raised FileNotFoundError when example_data/
    # itself did not exist; this also replaces the separate exists() check.
    Path(f"example_data/{patient}").mkdir(parents=True, exist_ok=True)
    extract_all_images_for_patient(connection=sqlite_connection, table='images', patient_id=patient)

In [9]:
# NOTE(review): `data_out` is not defined in any visible cell, so this fails on a fresh kernel.
# Takes the second-to-last field of the first record; presumably the image array — confirm.
image_data = data_out[0][-2]

In [10]:
# Show the dimensions of the array picked in the previous cell.
np.shape(image_data)

(852, 1136, 3)

In [None]:
# Render the array as an image with the axes hidden.
plt.imshow(image_data[:,:,:])
plt.axis('off')
plt.show()

# Leftover inspection of another field of the first record.
#data_out[0][7]

In [None]:
# NOTE(review): `data_out` is not defined in any visible cell; this displays its first record.
data_out[0]