# Fulmine Labs mini-PACS

## Overview

Fulmine Labs will use medical images for various quality- and testing-related machine learning (ML) initiatives.
The best practice for managing this data is to store images that comply with the Digital Imaging and Communications in Medicine (DICOM) standard in a PACS-like (Picture Archiving and Communication System) system.

The code in this project implements and tests a basic PACS with the following architecture:

```

[ Orthanc Repository (Open Source component) ]
       |
       | (DICOM Images) <----------------------------------------->  [ OHIF Viewer (Open Source component) ]
       v
[ Fulmine-Labs-Mini-PACS - Data Setup Script ]      
       |								
       | (Metadata and generated images)   
       |                                          
[ SQLite Database ]							
       |								
       | (API Requests)						 
       v
[ Flask Application ]
       |
       | (HTTP Requests for Data)
       v
[ Client (Pytest, Browser) ]
       |
       | (Model Training Data)  [ Data Enhancements ]
       v
[ Fulmine-Labs-Mini-PACS - Data Setup Script ]
       |
       | (Data Enhancements)
       v
[ Anomaly Detection Model Training ]


```

The data setup script traverses all folders in a specified location and identifies DICOM images. Images that have appropriate Window Center and Window Width DICOM header information are converted to PNG files at another specified location, and their related metadata is added to an SQLite database.
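
The folder traversal described above can be sketched with the standard library alone. This is an illustrative alternative, not the notebook's method (the notebook relies on `pydicom.dcmread`): it checks for the `DICM` marker that DICOM Part 10 files carry after a 128-byte preamble. The helper names `has_dicom_magic` and `find_dicom_candidates` are hypothetical, and files written without the preamble would be missed by this check.

```python
import os
import tempfile


def has_dicom_magic(path):
    """Return True if the file has the 4-byte 'DICM' marker after the 128-byte preamble."""
    with open(path, "rb") as handle:
        handle.seek(128)
        return handle.read(4) == b"DICM"


def find_dicom_candidates(root_dir):
    """Walk root_dir recursively and collect paths that look like DICOM Part 10 files."""
    matches = []
    for folder, _subfolders, files in os.walk(root_dir):
        for name in files:
            path = os.path.join(folder, name)
            if has_dicom_magic(path):
                matches.append(path)
    return sorted(matches)


# Demonstration on throwaway files (hypothetical names and contents)
with tempfile.TemporaryDirectory() as demo_dir:
    os.makedirs(os.path.join(demo_dir, "series1"))
    with open(os.path.join(demo_dir, "series1", "slice1"), "wb") as f:
        f.write(b"\x00" * 128 + b"DICM")  # minimal stand-in for a DICOM header
    with open(os.path.join(demo_dir, "notes.txt"), "wb") as f:
        f.write(b"not an image")
    found = [os.path.relpath(p, demo_dir) for p in find_dicom_candidates(demo_dir)]
    print(found)
```

A marker check like this is cheap enough to run as a pre-filter, but only a full parse confirms that the header and pixel data are usable.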

The database maintains the Patient -> Study -> Series -> Image relationship, as well as tracking the output image file names and parameters used in their creation, allowing PACS-like SQL queries to be constructed.
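
As a minimal sketch of such a PACS-like query, the example below builds a trimmed-down copy of the schema in memory (only the key columns of the `Patients`, `Studies`, `Series` and `Images` tables) and counts the studies, series and images for one patient. The sample IDs are hypothetical.

```python
import sqlite3

# In-memory database with a reduced version of the notebook's tables
conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE Patients (PatientID TEXT PRIMARY KEY);
CREATE TABLE Studies (StudyID TEXT PRIMARY KEY, PatientID TEXT);
CREATE TABLE Series (SeriesID TEXT PRIMARY KEY, StudyID TEXT, Modality TEXT);
CREATE TABLE Images (ImageID TEXT PRIMARY KEY, SeriesID TEXT);
INSERT INTO Patients VALUES ('P1');
INSERT INTO Studies VALUES ('ST1', 'P1'), ('ST2', 'P1');
INSERT INTO Series VALUES ('SE1', 'ST1', 'CT'), ('SE2', 'ST2', 'CT');
INSERT INTO Images VALUES ('I1', 'SE1'), ('I2', 'SE1'), ('I3', 'SE2');
""")

# Walk the Patient -> Study -> Series -> Image hierarchy with joins
query = """
SELECT COUNT(DISTINCT st.StudyID), COUNT(DISTINCT se.SeriesID), COUNT(im.ImageID)
FROM Patients p
JOIN Studies st ON st.PatientID = p.PatientID
JOIN Series se ON se.StudyID = st.StudyID
JOIN Images im ON im.SeriesID = se.SeriesID
WHERE p.PatientID = ?
"""
counts = conn.execute(query, ("P1",)).fetchone()
print(counts)  # (2, 2, 3): two studies, two series, three images
conn.close()
```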

Currently supported endpoints (usually at http://127.0.0.1:5000) are:

- `/`: welcome message
- `/patients/<patient_id>`: get patient information
- `/studies/<study_id>`: get study information
- `/series/<series_id>`: get series information
- `/images/<image_id>`: get image information
- `/patients/<patient_id>/studies`: get studies for a patient
- `/patients/<patient_id>/studycount`: get study count for a patient
- `/patients/<patient_id>/seriescount`: get series count for a patient
- `/patients/<patient_id>/imagecount`: get image count for a patient
- `/patients/<patient_id>/counts`: get all counts for a patient
- `/patients/count`: get total patient count
- `/studies/count`: get total studies count
- `/series/count`: get total series count
- `/images/count`: get total images count
- `/imageinfo/`: get image info by providing the file name

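
A client can address these endpoints with nothing more than the standard library. The sketch below is an assumption-laden illustration: `endpoint_url` and `fetch_json` are hypothetical helpers, the patient ID is made up, and `fetch_json` assumes the API answers with JSON, which this section does not state.

```python
import json
from urllib.parse import quote
from urllib.request import urlopen

BASE_URL = "http://127.0.0.1:5000"  # the usual address given above


def endpoint_url(*parts, base_url=BASE_URL):
    """Join path segments into an endpoint URL, percent-encoding each segment."""
    path = "/".join(quote(str(part), safe="") for part in parts)
    return f"{base_url}/{path}"


def fetch_json(url, timeout=5):
    """GET a URL and decode the body as JSON (assumes a JSON response)."""
    with urlopen(url, timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))


# Building URLs needs no running server; fetch_json(...) does.
print(endpoint_url("patients", "count"))
print(endpoint_url("patients", "PAT 001", "studies"))
```

Encoding each segment separately keeps IDs that contain spaces or slashes from changing the route.
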
Once PNG images have been generated from the DICOM images, they are used as the basis of the 'valid' class in an ML image classifier. To reduce overfitting, additional images are generated and added to the valid class. These include:

- The same images with random window centers and widths
- The same images with light random blurring to simulate pixel interpolation or compression
- The same images with flips and rotations
- The same images zoomed in and out, also with random window centers and widths

The 'invalid' class will comprise, for example:

- Non-medical images selected from the Kaggle 'real and fake' dataset
- The same valid images as above with simulated error/message boxes, to help detect anomalous conditions
- Some custom anomalous images, including AI-generated medical images

All of the images above will be distributed randomly among training, validation and testing folders in order to train and test the model. In addition, to test how well the model recognizes previously unseen medical images of the same type, some custom images will be selected from Google searches and used only for testing.
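
One way to picture the random distribution is a shuffle followed by slicing. The sketch below uses the 0.7 / 0.15 ratios from the notebook's configuration; the function name `split_dataset`, the `seed` parameter and the file names are hypothetical.

```python
import random


def split_dataset(items, train_ratio=0.7, validation_ratio=0.15, seed=None):
    """Shuffle items and slice them into disjoint train, validation and test lists."""
    rng = random.Random(seed)
    shuffled = list(items)
    rng.shuffle(shuffled)
    n_train = int(len(shuffled) * train_ratio)
    n_validate = int(len(shuffled) * validation_ratio)
    train = shuffled[:n_train]
    validate = shuffled[n_train:n_train + n_validate]
    test = shuffled[n_train + n_validate:]  # the remainder goes to testing
    return train, validate, test


files = [f"image_{i:03d}.png" for i in range(100)]
train, validate, test = split_dataset(files, seed=42)
print(len(train), len(validate), len(test))
```

Because the three lists are slices of one shuffled sequence, no image can land in more than one set.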

Overall, the folder structure looks like this:

```

Orthanc (DICOM images)
  ├── subfolders

training (PNG)
  ├──train
  ├──validate
  ├──test

Kaggle_real_and_fake_images (PNG)
  ├── subfolders

Custom_invalid (MIX)
  ├── subfolders

Custom_test_valid

training_images
  ├──train
  │		├── valid
  │		│  original training
  │		│   ├── blurred
  │		│   └── window_leveled
  │		│   └── rotate_and_flip
  │		│   └── zoomed
  |             |     └── window_leveled
  │		└── invalid/
  │   		 	├── Kaggle_real_and_fake_images
  │    	 		├── copied from Custom_invalid
  ├──validate
  │		├── valid
  │		│  original validate
  │		│   ├── blurred
  │		│   └── window_leveled
  │		│   └── rotate_and_flip
  │		│   └── zoomed
  |             |     └── window_leveled
  │		└── invalid/
  │   		 	├── Kaggle_real_and_fake_images
  │    	 		├── copied from Custom_invalid
  ├──test
 	├── valid
 	│  original test
		├── copied from Custom_test_valid
 	└── invalid/
 	    ├── Kaggle_real_and_fake_images
 	    ├── copied from Custom_invalid

```

Once the data is prepared, the classifier model training is initiated. The model is saved, reloaded, and used to classify the images selected for testing, producing metrics on Accuracy, Precision, Recall and F1 score.
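
For reference, the four metrics follow directly from the confusion-matrix counts. The sketch below is a plain-Python illustration with made-up counts; which class is treated as "positive" is an assumption here, since this section does not specify it.

```python
def classification_metrics(tp, fp, fn, tn):
    """Compute accuracy, precision, recall and F1 from confusion-matrix counts."""
    total = tp + fp + fn + tn
    accuracy = (tp + tn) / total if total else 0.0
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    if precision + recall:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


# Hypothetical counts: 8 true positives, 2 false positives, 1 false negative, 9 true negatives
metrics = classification_metrics(tp=8, fp=2, fn=1, tn=9)
for name, value in metrics.items():
    print(f"{name}: {value:.3f}")
```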


## Author
Duncan Henderson
Fulmine Labs LLC

In [1]:
# For the image generation pipeline
import os
import random
from random import sample
import pydicom
import sqlite3
import shutil 
import logging
from datetime import datetime
import numpy as np
from PIL import Image, ImageDraw, ImageFont, ImageOps
import textwrap
from scipy.ndimage import gaussian_filter, rotate
from PIL import ImageFilter

In [2]:
# For training and testing the model
import scipy
from matplotlib import pyplot as plt
from tensorflow import keras
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import Conv2D, Dense, Dropout, Flatten, Input, MaxPooling2D, UpSampling2D
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img
from tensorflow.keras.utils import to_categorical




In [3]:
# Test run variables

# Define a verbose flag (set it to True for verbose output)
#verbose = True
verbose = False

source_dir = r'D:\Orthanc'  # Raw string: a single backslash is enough
target_dir = r'D:\training' # The output PNG files will be written to the same folder name with _images appended
training_ratio, validation_ratio = 0.7, 0.15

# Variables to control database and image deletion
delete_db = True
delete_images = True
db_path = 'medical_imaging.db'

# Image dimensions
img_width, img_height = 152, 152

# Training parameters
batch_size = 32
epochs = 20 # Can increase the epochs since early stopping will handle overfitting
threshold = 0.5

# File name for saved model
model_name = 'lung_ct_classification_model.h5'

# Percentage of images to apply message_boxes to
message_box_percentage = 100 

# Maximum number of invalid images to process (should approximately balance the number of valid images)
max_invalid_images = 15000 
# Maximum number of custom invalid images to process
max_custom_invalid_images = 100 
# Maximum number of custom valid images to process
max_custom_valid_images = 100 


In [4]:
# Append '_images' to the end of target_dir
training_images_dir = f"{target_dir}_images"

In [5]:
# Log to a log file that is specific for the test run and also to the screen if verbose is set

class CustomLogger:
    def __init__(self, verbose=False):
        self.verbose = verbose

        # Define log format to include date and time
        log_format = '%(asctime)s - %(levelname)s - %(message)s'
        log_filename = f'log_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'

        # Configure logging with UTF-8 encoding and specified log format
        logging.basicConfig(filename=log_filename, format=log_format, level=logging.INFO, encoding='utf-8', filemode='w', datefmt='%Y-%m-%d %H:%M:%S')

    def iprint(self, message):
        if self.verbose:
            print(message)
        logging.info(message)
        
    def eprint(self, message):
        print(message)
        logging.error(message)

In [6]:
def setup_database(db_path):

    logger.iprint ("In function setup_database")
    
    # Connect to SQLite database (this will create the database if it doesn't exist).
    # Connecting before the try block ensures conn is defined when the finally clause runs.
    conn = sqlite3.connect(db_path)

    try:
        cursor = conn.cursor()

        # Create tables
        cursor.execute('''CREATE TABLE IF NOT EXISTS Patients (
                            PatientID TEXT PRIMARY KEY,
                            PatientInfo TEXT);''')

        cursor.execute('''CREATE TABLE IF NOT EXISTS Studies (
                            StudyID TEXT PRIMARY KEY,
                            PatientID TEXT,
                            StudyDate TEXT,
                            StudyDescription TEXT,
                            BodyPartExamined TEXT,
                            FOREIGN KEY (PatientID) REFERENCES Patients (PatientID));''')

        cursor.execute('''CREATE TABLE IF NOT EXISTS Series (
                            SeriesID TEXT PRIMARY KEY,
                            StudyID TEXT,
                            SeriesDate TEXT,
                            SeriesDescription TEXT,
                            Modality TEXT,
                            FOREIGN KEY (StudyID) REFERENCES Studies (StudyID));''')
        
        cursor.execute('''CREATE TABLE IF NOT EXISTS Images (
                            ImageID TEXT PRIMARY KEY,
                            SeriesID TEXT,
                            FilePath TEXT,
                            State TEXT,
                            InstanceNumber TEXT,
                            FOREIGN KEY (SeriesID) REFERENCES Series (SeriesID));''')

        cursor.execute('''CREATE TABLE IF NOT EXISTS AugmentedImages (
                            AugmentedImageID INTEGER PRIMARY KEY AUTOINCREMENT,
                            ImageID TEXT,
                            PngFilePath TEXT,
                            Transformation TEXT,
                            WindowCenter TEXT,
                            WindowWidth TEXT,
                            RescaleIntercept TEXT,
                            RescaleSlope TEXT,
                            FOREIGN KEY (ImageID) REFERENCES Images (ImageID));''')
        conn.commit()
    except sqlite3.DatabaseError as e:
        logger.eprint(f"Database error: {e}")
    finally:
        conn.close()

In [7]:
def is_dicom_file(file_path):

    logger.iprint ("In function is_dicom_file")
    
    try:
        pydicom.dcmread(file_path, stop_before_pixels=True)
        return True
    except Exception as e:
        logger.eprint(f"Error reading DICOM file {file_path}: {e}")
        return False

In [8]:
def clear_and_create_directory(directory):

    logger.iprint ("In function clear_and_create_directory")
    
    if os.path.exists(directory):
        shutil.rmtree(directory)
    os.makedirs(directory)

In [9]:
def copy_files(file_list, destination):

    logger.iprint ("In function copy_files ")
    
    try:
        os.makedirs(destination, exist_ok=True)
        for file_path in file_list:
            shutil.copy(file_path, os.path.join(destination, os.path.basename(file_path)))
    except (FileNotFoundError, PermissionError) as e:
        logger.eprint(f"File I/O error: {e}")

In [10]:
def extract_metadata(dicom_file_path):
    
    logger.iprint ("In function extract_metadata ")
    
    # Extract metadata like PatientID, StudyID, SeriesID, ImageID, Modality, and BodyPart
    ds = pydicom.dcmread(dicom_file_path, stop_before_pixels=True)
    patient_id = ds.PatientID
    study_id = ds.StudyInstanceUID
    series_id = ds.SeriesInstanceUID
    image_id = ds.SOPInstanceUID
    study_description = ds.StudyDescription if 'StudyDescription' in ds else 'N/A'
    series_description = ds.SeriesDescription if 'SeriesDescription' in ds else 'N/A'
    instance_number = ds.InstanceNumber if 'InstanceNumber' in ds else 'N/A'
    modality = ds.Modality if 'Modality' in ds else 'N/A'
    body_part_examined = ds.BodyPartExamined if 'BodyPartExamined' in ds else 'N/A'
    study_date = ds.StudyDate if 'StudyDate' in ds else 'N/A'
    series_date = ds.SeriesDate if 'SeriesDate' in ds else 'N/A'
 
    return patient_id, study_id, series_id, image_id, modality, body_part_examined, instance_number, study_description, series_description, study_date, series_date

In [11]:
def insert_metadata_into_db(cursor, patient_id, study_id, series_id, image_id, modality, body_part_examined, instance_number, study_description, series_description, study_date, series_date, file_path, state):
        
    logger.iprint ("In function insert_metadata_into_db")
    
    # Insert data into the Patients table
    cursor.execute("INSERT OR IGNORE INTO Patients (PatientID) VALUES (?)", (patient_id,))

    # Insert data into the Studies table
    cursor.execute("INSERT OR IGNORE INTO Studies (StudyID, PatientID, StudyDate, StudyDescription, BodyPartExamined) VALUES (?, ?, ?, ?, ?)",
                   (study_id, patient_id, study_date, study_description, body_part_examined))

    # Insert data into the Series table
    cursor.execute("INSERT OR IGNORE INTO Series (SeriesID, StudyID, SeriesDate, SeriesDescription, Modality) VALUES (?, ?, ?, ?, ?)",
                   (series_id, study_id, series_date, series_description, modality))

    # Insert data into the Images table
    cursor.execute("INSERT OR IGNORE INTO Images (ImageID, SeriesID, FilePath, State, InstanceNumber) VALUES (?, ?, ?, ?, ?)", 
                   (image_id, series_id, file_path, state, instance_number))

In [12]:
def apply_rescale_and_window_level(dcm):
    """
    Apply the rescale slope and intercept, and window center and width to the DICOM image data.

    Parameters:
    - dcm: DICOM dataset.

    Returns:
    - rescaled_and_windowed_image: The image after applying rescale and window level.
    - window_center: Window center used for windowing.
    - window_width: Window width used for windowing.
    - rescale_slope: Rescale slope used for rescaling.
    - rescale_intercept: Rescale intercept used for rescaling.
    """
    
    logger.iprint ("In function apply_rescale_and_window_level")    
    
    # Apply rescale slope and intercept if available
    rescale_slope = getattr(dcm, 'RescaleSlope', 1)  # Default to 1 if not present
    rescale_intercept = getattr(dcm, 'RescaleIntercept', 0)  # Default to 0 if not present
    rescaled_image = dcm.pixel_array.astype(np.float64) * rescale_slope + rescale_intercept

    # Apply window center and width if available
    if hasattr(dcm, 'WindowCenter') and hasattr(dcm, 'WindowWidth'):
        window_center = dcm.WindowCenter
        window_width = dcm.WindowWidth
        if isinstance(window_center, pydicom.multival.MultiValue):
            window_center = window_center[0]
        if isinstance(window_width, pydicom.multival.MultiValue):
            window_width = window_width[0]
        window_center = float(window_center)
        window_width = float(window_width)

        low = window_center - window_width / 2
        high = window_center + window_width / 2
        rescaled_and_windowed_image = np.clip(rescaled_image, low, high)
    else:
        # If window level is not specified, use the rescaled image
        rescaled_and_windowed_image = rescaled_image
        window_center = None
        window_width = None

    return rescaled_and_windowed_image, window_center, window_width, rescale_slope, rescale_intercept
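
To make the windowing arithmetic concrete, here is a small worked example in plain Python. It is a simplified sketch, not the notebook's code path: it maps each value against the window bounds directly, whereas the notebook clips with NumPy and then normalizes by the clipped image's own minimum and maximum. The function name `window_to_gray` and the window values are hypothetical.

```python
def window_to_gray(value, center, width):
    """Map one rescaled pixel value to 0-255 using a linear window (width must be > 0)."""
    low = center - width / 2
    high = center + width / 2
    clipped = min(max(value, low), high)  # same effect as np.clip for a single value
    return int((clipped - low) / (high - low) * 255)


center, width = -600, 1500  # hypothetical lung-style window: covers -1350 to 150
for value in (-2000, -1350, -600, 150, 1000):
    print(value, "->", window_to_gray(value, center, width))
```

Values at or below the lower bound map to 0, values at or above the upper bound map to 255, and the window center lands at mid-gray.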

In [13]:
def normalize_image(image):
    """Normalize the image data to 0-255 and convert to uint8."""

    logger.iprint ("In function normalize_image")

    image = image - np.min(image)
    max_value = np.max(image)
    if max_value > 0:  # Avoid division by zero for constant-valued images
        image = image / max_value
    image = (image * 255).astype(np.uint8)
    return image

In [14]:
def random_rotate_and_flip(image):
  
    logger.iprint ("In function random_rotate_and_flip")    
  
    try:
        # Convert PIL Image to numpy array for processing
        image_np = np.array(image)

        # Random rotation by 90, 180, or 270 degrees
        rotations = [0, 90, 180, 270]
        rotation_choice = random.choice(rotations)
        if rotation_choice != 0:
            image_np = np.rot90(image_np, rotation_choice // 90)  # np.rot90 expects k=1,2,3 for 90,180,270 degrees

        # Random flip
        if random.choice([True, False]):
            image_np = np.fliplr(image_np)
        if random.choice([True, False]):
            image_np = np.flipud(image_np)

        # Convert numpy array back to PIL Image
        rotated_and_flipped_image = Image.fromarray(image_np)

        return rotated_and_flipped_image
    except Exception as e:
        # Log the error or print it out
        logger.eprint(f"Error during rotation and flip: {e}")
        # Return the original image in case of error
        return image

In [15]:
def random_zoom(image):
     
    logger.iprint ("In function random_zoom")    
 
    # Randomly choose a zoom factor from 80% to 120%
    zoom_factor = random.uniform(0.8, 1.2)
    width, height = image.size
    
    # Calculate the new width and height based on the zoom factor
    new_width = int(width * zoom_factor)
    new_height = int(height * zoom_factor)
    
    # Resize the image
    resized_image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
    
    # If zooming out, pad the image with black pixels
    if zoom_factor < 1.0:
        delta_w = width - new_width
        delta_h = height - new_height
        padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
        return ImageOps.expand(resized_image, padding, fill=0)
    elif zoom_factor > 1.0:
        # If zooming in, crop the image to the original size
        return resized_image.crop(((new_width - width) // 2,
                                   (new_height - height) // 2,
                                   (new_width + width) // 2,
                                   (new_height + height) // 2))
    else:
        # If zoom_factor is 1, return the original image
        return image

In [16]:
def random_blur(image):
     
    logger.iprint ("In function random_blur")    
 
    # Randomly choose a blur radius from 0 to 5
    blur_radius = random.uniform(0, 5)
    
    # Apply Gaussian blur with the chosen radius
    blurred_image = image.filter(ImageFilter.GaussianBlur(blur_radius))
    
    return blurred_image

In [17]:
def construct_transformation_description(randomized_wl, blurred, rotate_and_flip, zoom, invert):
  
    logger.iprint ("In function construct_transformation_description")    
    
    transformations = []

    if randomized_wl:
        transformations.append("randomized_window_level")
    if blurred:
        transformations.append("blurred")
    if rotate_and_flip:
        transformations.append("rotated_flipped")
    if zoom:
        transformations.append("zoomed")
    if invert:
        transformations.append("invert")

    # Join all transformations with a plus sign or another separator that makes sense in your context
    transformation_description = "+".join(transformations) or "original"

    return transformation_description

In [18]:
def random_window(image, min_level=-700, max_level=100, min_width=1200, max_width=2000):

    logger.iprint ("In function random_window")    
    
    window_level = random.randint(min_level, max_level)
    window_width = random.randint(min_width, max_width)

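    img_min = window_level - window_width / 2
    img_max = window_level + window_width / 2
    # Work in float so the in-place normalization below is also valid for integer input
    windowed_img = np.clip(np.asarray(image, dtype=np.float64), img_min, img_max)

    windowed_img -= windowed_img.min()
    max_value = windowed_img.max()
    if max_value > 0:  # Avoid division by zero when the window contains no contrast
        windowed_img /= max_value
    return windowed_img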

In [19]:
def convert_dicom_to_png(dicom_dir, output_base_dir, state, cursor, randomized_wl=False, blurred=False, rotate_and_flip=False, zoom=False, invert=False):
    logger.iprint("In function convert_dicom_to_png")

    try:
        output_dir = os.path.join(output_base_dir, 
                                  "zoomed" if zoom else "",
                                  "blurred" if blurred else "", 
                                  "randomized_wl" if randomized_wl else "",
                                  "rotate_and_flip" if rotate_and_flip else "",
                                  "inverted" if invert else "")
        os.makedirs(output_dir, exist_ok=True)
        
        transformation_description = construct_transformation_description(randomized_wl, blurred, rotate_and_flip, zoom, invert)
  
        for entry in os.listdir(dicom_dir):
            logger.iprint("Processing DICOM file: " + entry)
            dicom_path = os.path.join(dicom_dir, entry)

            if os.path.isfile(dicom_path):
                try:
                    dcm = pydicom.dcmread(dicom_path)
                    if not hasattr(dcm, 'PixelData') and not hasattr(dcm, 'FloatPixelData') and not hasattr(dcm, 'DoubleFloatPixelData'):
                        raise ValueError("DICOM file does not contain image pixel data")
                    
                    rescaled_and_windowed_image, window_center, window_width, rescale_slope, rescale_intercept = apply_rescale_and_window_level(dcm)

                    normalized_image = normalize_image(rescaled_and_windowed_image)

                    pil_image = Image.fromarray(normalized_image.astype(np.uint8))

                    if zoom:
                        pil_image = random_zoom(pil_image)
                    if blurred:
                        pil_image = random_blur(pil_image)
                    if rotate_and_flip:
                        pil_image = random_rotate_and_flip(pil_image)
                    if invert:
                        pil_image = ImageOps.invert(pil_image)

                    png_filename = entry + '.png'
                    png_path = os.path.join(output_dir, png_filename)
                    pil_image.save(png_path)
                    
                    cursor.execute("INSERT INTO AugmentedImages (ImageID, PngFilePath, Transformation, WindowCenter, WindowWidth, RescaleIntercept, RescaleSlope) VALUES (?, ?, ?, ?, ?, ?, ?)", 
                                   (dcm.SOPInstanceUID, png_path, transformation_description, window_center, window_width, rescale_intercept, rescale_slope))
                    
                except Exception as e:
                    logger.eprint(f"Failed to convert: {dicom_path}, Error: {e}")

    except Exception as e:
        logger.eprint(f"Unexpected error: {e}")

In [20]:
def generate_error_message():

    logger.iprint ("In function generate_error_message")    

    base_messages = [
        "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
        "Error: The operation could not be completed.",
        "Warning: System memory is running low.",
        "Alert: Unrecognized device detected.",        
        "System failure: Please contact support.",
        "Update failed: Please retry or check your connection."
    ]
    # Select a base message randomly
    message = random.choice(base_messages)
    
    # Optionally append a random error code or numerical detail
    if random.choice([True, False]):  # Decide randomly whether to add a number
        number = random.randint(100, 9999)
        # Randomly choose how to add the number (end as a code or within the text)
        if random.choice([True, False]):
            message += f" Code: {number}."
        else:
            parts = message.split()
            insert_at = random.randint(1, len(parts) - 1)
            parts.insert(insert_at, str(number))
            message = ' '.join(parts)
    
    return message

In [21]:
def apply_error_message_box(image_path, output_path, font_paths, message_box_percentage):

    logger.iprint ("In function apply_error_message_box")    
    
    if random.randint(1, 100) > message_box_percentage:
        # Skip this image; do not apply message_box
        return

    image = Image.open(image_path).convert('L')  # Ensure image is in grayscale
    draw = ImageDraw.Draw(image)

    # Random font and size
    font_path = random.choice(font_paths)
    font_size = random.randint(15, 25)  # Adjust range as needed
    font = ImageFont.truetype(font_path, font_size)

    # Generate message
    message = generate_error_message()

    # Wrap the message
    wrapped_text = "\n".join(textwrap.wrap(message, width=40))

    # Calculate the bounding box for the wrapped text
    text_bbox = draw.textbbox((0, 0), wrapped_text, font=font)
    # Cast to int so the random.randint calls below always receive integer bounds
    text_width = int(text_bbox[2] - text_bbox[0])
    text_height = int(text_bbox[3] - text_bbox[1])

    # Determine box size based on text dimensions, making box_height randomly larger
    box_width = max(text_width + 20, image.width // 4)  # Ensure minimum width and add padding
    # Randomize the height, guarding against an empty range when the text is tall
    min_box_height = text_height + 20
    box_height = random.randint(min_box_height, max(min_box_height, int(1.5 * image.height // 4)))

    # Randomly calculate the box position to fit within the image
    max_x = max(image.width - box_width, 0)
    max_y = max(image.height - box_height, 0)
    box_x = random.randint(0, max_x)
    box_y = random.randint(0, max_y)

    # Choose a random grayscale value for the box fill
    box_fill = random.randint(0, 255)
    # Use white text if the box is dark, black otherwise
    text_color = 255 if box_fill < 128 else 0

    # Draw the message box
    draw.rectangle([box_x, box_y, box_x + box_width, box_y + box_height], fill=box_fill)

    # Draw the text, left and top justified within the box
    # By starting the text at the top left corner of the box (with a little padding)
    text_start_x = box_x + 10  # Add some padding from the left edge
    text_start_y = box_y + 10  # Add some padding from the top edge
    draw.text((text_start_x, text_start_y), wrapped_text, fill=text_color, font=font, align='left')

    # Save the modified image
    image.save(output_path)

In [22]:
def augment_images_with_message_boxes(source_dir, output_dir, message_box_percentage, font_path="arial.ttf"):

    logger.iprint ("In function augment_images_with_message_boxes")    
    
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    
    all_images = os.listdir(source_dir)
    num_images_to_augment = int(len(all_images) * (message_box_percentage / 100.0))
    images_to_augment = random.sample(all_images, num_images_to_augment)
    logger.iprint (f"Num images to add message boxes to: {num_images_to_augment}")

    for image_file in images_to_augment:
        logger.iprint (image_file)
        if image_file.endswith(('.png', '.jpg', '.jpeg')):
            image_path = os.path.join(source_dir, image_file)
            output_path = os.path.join(output_dir, "message_box_" + image_file)
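            # The percentage was already applied when sampling above, so process every selected image
            apply_error_message_box(image_path, output_path, [font_path], 100)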

In [23]:
def copy_random_images(source_dir, target_dir, class_name, max_images=None, train_ratio=0.7, validate_ratio=0.15, test_ratio=0.15, subfolder_name=None):

    """
    Copies a random subset of images from the source directory to the train, validate, and test directories within the target directory.
    Each set will be placed in a subdirectory named after the specified class, and optionally, within a further specified subfolder.

    Parameters:
    - source_dir: The directory containing the original images.
    - target_dir: The base directory where the train, validate, and test directories will be created.
    - class_name: The class name (e.g., 'valid' or 'invalid') specifying the subdirectory within train, validate, and test where images should be copied.
    - max_images: The maximum number of images to copy. If None, all images in source_dir are considered.
    - train_ratio: The fraction of images to copy to the training set.
    - validate_ratio: The fraction of images to copy to the validation set.
    - test_ratio: The fraction of images to copy to the testing set.
    - subfolder_name: Optional name of a subfolder within each set directory where images will be copied.
    """

    logger.iprint ("In function copy_random_images")

    # Ensure the ratios sum to 1 (with a tolerance, since exact float equality is fragile)
    assert abs(train_ratio + validate_ratio + test_ratio - 1) < 1e-9, "Ratios must sum to 1"
    
    # Recursively get all image filenames in the source directory and its subdirectories
    all_images = []
    for root, dirs, files in os.walk(source_dir):
        for file in files:
            
            if os.path.isfile(os.path.join(root, file)):
                all_images.append(os.path.join(root, file))
    
    if max_images is not None and max_images < len(all_images):
        all_images = sample(all_images, max_images)
    
    # Calculate the number of images for each set
    num_train = int(len(all_images) * train_ratio)
    num_validate = int(len(all_images) * validate_ratio)
    num_test = len(all_images) - num_train - num_validate
    
    # Shuffle once, then slice into disjoint train, validate and test sets
    shuffled_images = sample(all_images, len(all_images))
    train_images = shuffled_images[:num_train]
    validate_images = shuffled_images[num_train:num_train + num_validate]
    test_images = shuffled_images[num_train + num_validate:]
    
    def copy_images(images, target_subdir):
        if subfolder_name:
            target_subdir = os.path.join(target_subdir, subfolder_name)
        os.makedirs(target_subdir, exist_ok=True)
        for img_path in images:
            filename = os.path.basename(img_path)
            target_file_path = os.path.join(target_subdir, filename)
            if os.path.abspath(img_path) != os.path.abspath(target_file_path):
                shutil.copy(img_path, target_file_path)
    
    # Copy images to their respective directories (os.path.join keeps the paths portable)
    copy_images(train_images, os.path.join(target_dir, 'train', class_name))
    copy_images(validate_images, os.path.join(target_dir, 'validate', class_name))
    copy_images(test_images, os.path.join(target_dir, 'test', class_name))

In [24]:
def apply_message_boxes_to_dataset(source_base_dir, output_base_dir, subset, font_paths, message_box_percentage):

    logger.iprint ("In function apply_message_boxes_to_dataset")    
    
    source_dir = os.path.join(source_base_dir, subset, 'valid')
    output_dir = os.path.join(output_base_dir, subset, 'invalid', 'message_boxes')
    
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    
    for image_file in os.listdir(source_dir):
        if image_file.endswith(('.png', '.jpg', '.jpeg')):
            image_path = os.path.join(source_dir, image_file)
            output_path = os.path.join(output_dir, "message_box_" + image_file)
            apply_error_message_box(image_path, output_path, font_paths, message_box_percentage)


In [25]:
def is_dicom_file(file_path):
    
    logger.iprint ("In function is_dicom_file")    
    
    try:
        pydicom.dcmread(file_path, stop_before_pixels=True)
        return True
    except Exception:  # Any read failure means the file is not a readable DICOM file
        return False

In [26]:
# Function to display the false negative image
def display_fn_image(image_path):
 
    logger.iprint ("In function display_fn_image") 
    
    image = load_img(image_path, color_mode='grayscale')
    plt.figure(figsize=(4, 4))
    plt.imshow(image, cmap='gray')
    plt.title("False Negative")
    plt.axis('off')
    plt.show()

In [27]:
# Function to display the false positive image
def display_fp_image(image_path):
     
    logger.iprint ("In function display_fp_image") 
    
    image = plt.imread(image_path)
    plt.figure(figsize=(5, 5))
    plt.imshow(image, cmap='gray')
    plt.title("False Positive")
    plt.axis('off')
    plt.show()

In [28]:
def preprocess_image(image_path, img_width, img_height):

    logger.iprint ("In function preprocess_image") 
   
    image = load_img(image_path, target_size=(img_height, img_width), color_mode='grayscale')  # target_size is (height, width)
    image = img_to_array(image)
    image = np.expand_dims(image, axis=0)  # Add batch dimension
    image /= 255.0  # Normalize to [0, 1]
    return image

In [29]:
# Main code execution for image pipeline

logger = CustomLogger(verbose)

# Check if the database exists and delete it if delete_db is True
if delete_db and os.path.exists(db_path):
    os.remove(db_path)
    logger.iprint("Existing database removed.")
else:
    logger.iprint("Skipping database removal.")

In [30]:
setup_database(db_path)

In [31]:
# Open a connection to the database for the metadata inserts that follow
conn = sqlite3.connect(db_path)
cursor = conn.cursor()

In [32]:
if delete_images and (os.path.exists(target_dir) or os.path.exists(training_images_dir)):
    clear_and_create_directory(target_dir)
    clear_and_create_directory(training_images_dir)
    logger.iprint("Existing images removed.")
else:
    logger.iprint("Skipping image removal.")

In [33]:
# Copy the DICOM files from the image archive and add them to the database
dicom_files = []
for root, dirs, files in os.walk(source_dir):
    for file in files:
        file_path = os.path.join(root, file)
        if is_dicom_file(file_path):
            dicom_files.append(file_path)

random.shuffle(dicom_files)
total_files = len(dicom_files)
training_count = int(total_files * training_ratio)
validation_count = int(total_files * validation_ratio)

training_files = dicom_files[:training_count]
validation_files = dicom_files[training_count:training_count + validation_count]
test_files = dicom_files[training_count + validation_count:]

# Insert the metadata for each file, tagged with the subset it was assigned to
for subset_name, subset_files in (('train', training_files), ('validate', validation_files), ('test', test_files)):
    for file_path in subset_files:
        # extract_metadata returns the fields in the order insert_metadata_into_db expects
        metadata = extract_metadata(file_path)
        insert_metadata_into_db(cursor, *metadata, file_path, subset_name)

conn.commit()

copy_files(training_files, os.path.join(target_dir, 'train', 'valid'))
copy_files(validation_files, os.path.join(target_dir, 'validate', 'valid'))
copy_files(test_files, os.path.join(target_dir, 'test', 'valid'))

logger.iprint(f"Total DICOM files: {total_files}")
logger.iprint(f"Training files: {len(training_files)}")
logger.iprint(f"Validation files: {len(validation_files)}")
logger.iprint(f"Test files: {len(test_files)}")
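
The split in the cell above shuffles the file list and slices it by ratio, with the test subset receiving whatever remains after rounding down. A minimal standalone sketch of that logic follows; the function name `split_files`, the `seed` parameter and the example file names are assumptions added for illustration.

```python
import random


def split_files(files, training_ratio, validation_ratio, seed=None):
    """Shuffle a copy of files and slice it into train/validate/test lists."""
    shuffled = list(files)
    random.Random(seed).shuffle(shuffled)  # Seeded only to make the example repeatable
    training_count = int(len(shuffled) * training_ratio)
    validation_count = int(len(shuffled) * validation_ratio)
    train = shuffled[:training_count]
    validate = shuffled[training_count:training_count + validation_count]
    # The test subset is the remainder, so it absorbs any rounding leftovers
    test = shuffled[training_count + validation_count:]
    return train, validate, test


example_files = [f'image_{i}.dcm' for i in range(10)]
train, validate, test = split_files(example_files, 0.5, 0.25, seed=42)
print(len(train), len(validate), len(test))  # 5 2 3
```

With 10 files and ratios of 0.5 and 0.25, the validation count rounds down from 2.5 to 2, so the test subset ends up with 3 files rather than 2.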


In [None]:
# After inserting DICOM metadata...
# Create the PNG files for training based on the DICOM tags
randomized_wl=False
blurred=False
rotate_and_flip=False
zoomed=False
invert=False

convert_dicom_to_png(os.path.join(target_dir, 'train\\valid'), os.path.join(target_dir, '..', 'training_images', 'train', 'valid'), 'train', cursor, randomized_wl, blurred, rotate_and_flip, zoomed, invert)
convert_dicom_to_png(os.path.join(target_dir, 'validate\\valid'), os.path.join(target_dir, '..', 'training_images', 'validate', 'valid'), 'validate', cursor, randomized_wl, blurred, rotate_and_flip, zoomed, invert)
convert_dicom_to_png(os.path.join(target_dir, 'test\\valid'),os.path.join(target_dir, '..', 'training_images', 'test', 'valid'), 'test', cursor, randomized_wl, blurred, rotate_and_flip, zoomed, invert)

conn.commit()

In [None]:
# The PNGs have been created from the original DICOM images 
# Now create copies of the same files with randomized window levels to reduce overfitting
randomized_wl=True
blurred=False
rotate_and_flip=False
zoomed=False
invert=False

convert_dicom_to_png(os.path.join(target_dir, 'train\\valid'), os.path.join(target_dir, '..', 'training_images', 'train', 'valid'), 'train', cursor, randomized_wl, blurred, rotate_and_flip, zoomed, invert)
convert_dicom_to_png(os.path.join(target_dir, 'validate\\valid'), os.path.join(target_dir, '..', 'training_images', 'validate', 'valid'), 'validate', cursor, randomized_wl, blurred, rotate_and_flip, zoomed, invert)
convert_dicom_to_png(os.path.join(target_dir, 'test\\valid'),os.path.join(target_dir, '..', 'training_images', 'test', 'valid'), 'test', cursor, randomized_wl, blurred, rotate_and_flip, zoomed, invert)

conn.commit()

In [None]:
# Now create blurred copies of the same files
from scipy.ndimage import gaussian_filter

randomized_wl=False
blurred=True
rotate_and_flip=False
zoomed=False
invert=False

convert_dicom_to_png(os.path.join(target_dir, 'train\\valid'), os.path.join(target_dir, '..', 'training_images', 'train', 'valid'), 'train', cursor, randomized_wl, blurred, rotate_and_flip, zoomed, invert)
convert_dicom_to_png(os.path.join(target_dir, 'validate\\valid'), os.path.join(target_dir, '..', 'training_images', 'validate', 'valid'), 'validate', cursor, randomized_wl, blurred, rotate_and_flip, zoomed, invert)
convert_dicom_to_png(os.path.join(target_dir, 'test\\valid'),os.path.join(target_dir, '..', 'training_images', 'test', 'valid'), 'test', cursor, randomized_wl, blurred, rotate_and_flip, zoomed, invert)

conn.commit()


In [None]:
# Now create copies of the same files with randomized rotation and flipping

randomized_wl=False
blurred=False
rotate_and_flip=True
zoomed=False
invert=False

convert_dicom_to_png(os.path.join(target_dir, 'train\\valid'), os.path.join(target_dir, '..', 'training_images', 'train', 'valid'), 'train', cursor, randomized_wl, blurred, rotate_and_flip, zoomed, invert)
convert_dicom_to_png(os.path.join(target_dir, 'validate\\valid'), os.path.join(target_dir, '..', 'training_images', 'validate', 'valid'), 'validate', cursor, randomized_wl, blurred, rotate_and_flip, zoomed, invert)
convert_dicom_to_png(os.path.join(target_dir, 'test\\valid'),os.path.join(target_dir, '..', 'training_images', 'test', 'valid'), 'test', cursor, randomized_wl, blurred, rotate_and_flip, zoomed, invert)

conn.commit()

In [None]:
# Now create copies of the same files with randomized zoom as well as window level to simulate some internet studies
from scipy.ndimage import gaussian_filter

randomized_wl=True
blurred=False
rotate_and_flip=False
zoomed=True
invert=False

convert_dicom_to_png(os.path.join(target_dir, 'train\\valid'), os.path.join(target_dir, '..', 'training_images', 'train', 'valid'), 'train', cursor, randomized_wl, blurred, rotate_and_flip, zoomed, invert)
convert_dicom_to_png(os.path.join(target_dir, 'validate\\valid'), os.path.join(target_dir, '..', 'training_images', 'validate', 'valid'), 'validate', cursor, randomized_wl, blurred, rotate_and_flip, zoomed, invert)
convert_dicom_to_png(os.path.join(target_dir, 'test\\valid'),os.path.join(target_dir, '..', 'training_images', 'test', 'valid'), 'test', cursor, randomized_wl, blurred, rotate_and_flip, zoomed, invert)

conn.commit()
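
The five conversion cells above differ only in their flag values. One possible way to express the same sequence is a table of passes driven by a loop, sketched below. The names `AUGMENTATION_PASSES`, `run_passes` and `fake_convert` are hypothetical; `fake_convert` is a stand-in that only records its arguments and does not call `convert_dicom_to_png`.

```python
# Each entry: (description, randomized_wl, blurred, rotate_and_flip, zoomed, invert)
AUGMENTATION_PASSES = [
    ('original',            False, False, False, False, False),
    ('randomized window',   True,  False, False, False, False),
    ('blurred',             False, True,  False, False, False),
    ('rotated and flipped', False, False, True,  False, False),
    ('zoomed and windowed', True,  False, False, True,  False),
]
SUBSETS = ['train', 'validate', 'test']


def run_passes(convert, passes=AUGMENTATION_PASSES, subsets=SUBSETS):
    """Call convert(subset, *flags) once per pass and subset; return the call count."""
    calls = 0
    for _description, *flags in passes:
        for subset in subsets:
            convert(subset, *flags)
            calls += 1
    return calls


# Hypothetical stand-in that records what it was asked to do
recorded = []


def fake_convert(subset, randomized_wl, blurred, rotate_and_flip, zoomed, invert):
    recorded.append((subset, randomized_wl, blurred, rotate_and_flip, zoomed, invert))


total_calls = run_passes(fake_convert)
print(total_calls)  # 15
```

In the notebook itself, the stand-in would be replaced by a small wrapper that builds the source and output paths for each subset and commits after each pass.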

In [None]:
# Database updates complete
conn.close()

In [None]:
# Copy some random invalid Kaggle real and fake images
source_dir = 'D:\\Kaggle_real_and_fake_images'
copy_random_images(source_dir, training_images_dir, 'invalid', max_invalid_images)

In [None]:
# Copy some random custom invalid images
source_dir = 'D:\\Custom_invalid'
copy_random_images(source_dir, training_images_dir, 'invalid', max_custom_invalid_images)

In [None]:
# Copy some random custom valid images (from internet sources) for testing only
source_dir = 'D:\\Custom_test_valid'
copy_random_images(source_dir, training_images_dir, 'valid', max_custom_valid_images, 0, 0, 1, "InternetTest")

In [None]:
# Add message_boxes to images
# Note: this reuses images already selected for training, validation and test;
# a better approach might be to use images that were not selected

font_paths = ['arial.ttf', 'times.ttf']  # Add paths to any other fonts you'd like to use

# Apply message_boxes to each dataset subset
for subset in ['train', 'validate', 'test']:
    apply_message_boxes_to_dataset(training_images_dir, training_images_dir, subset, font_paths, message_box_percentage)

In [None]:
# Increase the number of message box images, as some failures still occur with 100%

# Apply message_boxes to each dataset subset
for subset in ['train', 'validate', 'test']:
    apply_message_boxes_to_dataset(training_images_dir, training_images_dir, subset, font_paths, message_box_percentage)

In [None]:
# The following code creates and tests the model based on the image pipeline 

# Rescale images from [0, 255] to [0, 1]
datagen = ImageDataGenerator(rescale=1. / 255)

In [None]:
train_data_dir = f'{target_dir}_images\\train'

train_generator = datagen.flow_from_directory(
    directory=train_data_dir,
    target_size=(img_width, img_height),  # Ensure this matches your image dimensions
    batch_size=batch_size,
    class_mode='binary',  # Assuming a binary classification setup
    color_mode='grayscale',  # Assuming grayscale images
    shuffle=True
)

# This will print out the image counts, confirming that images from 'valid' and 'invalid' are included
logger.iprint("Training found " + str(train_generator.samples) + " images belonging to " + str(train_generator.num_classes) + " classes.")

In [None]:
val_data_dir = f'{target_dir}_images\\validate'

val_generator = datagen.flow_from_directory(
    directory=val_data_dir,
    target_size=(img_width, img_height),  # Ensure this matches your image dimensions
    batch_size=batch_size,
    class_mode='binary',  # Assuming a binary classification setup
    color_mode='grayscale',  # Assuming grayscale images
    shuffle=True
)

# This will print out the counts, confirming that images from 'valid' and 'invalid' are included
logger.iprint("Validation found " + str(val_generator.samples) + " images belonging to " + str(val_generator.num_classes) + " classes.")

In [None]:
# Define the model architecture
input_img = Input(shape=(img_width, img_height, 1))  # 1 channel for grayscale images
x = Conv2D(32, (3, 3), activation='relu', padding='same')(input_img)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Dropout(0.25)(x)  # Dropout layer after MaxPooling
x = Conv2D(64, (3, 3), activation='relu', padding='same')(x)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Dropout(0.25)(x)  # Dropout layer after MaxPooling
x = Flatten()(x)
x = Dense(64, activation='relu')(x)
x = Dropout(0.5)(x)  # Dropout layer before the output layer
classifier_output = Dense(1, activation='sigmoid', name='classifier_output')(x)  # Use sigmoid for binary classification

# Create and compile the model
model = Model(inputs=input_img, outputs=classifier_output)
model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy'])

In [None]:
# EarlyStopping callback
early_stopping = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)

# Train the model with EarlyStopping
model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    epochs=epochs,
    validation_data=val_generator,
    validation_steps=len(val_generator),
    callbacks=[early_stopping]  # Add the EarlyStopping callback here
)

In [None]:
# Save the model

model.save(model_name)

In [None]:
# Reload the model

autoencoder_classifier = keras.models.load_model(model_name)

In [None]:
# Test the model

true_positives = 0  # Correctly identified as normal (valid images)
false_negatives = 0  # Incorrectly identified as anomalous (valid images)
true_negatives = 0  # Correctly identified as anomalous (invalid images)
false_positives = 0  # Incorrectly identified as normal (invalid images)

In [None]:
valid_image_dir = f'{target_dir}_images\\test\\valid'

logger.iprint("Testing valid images")
# Recursively test with valid images (normal medical images)
for root, dirs, files in os.walk(valid_image_dir):
    for filename in files:
        if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
            image_path = os.path.join(root, filename)
            image = preprocess_image(image_path, img_width, img_height)

            classifier_pred = autoencoder_classifier.predict(image)
            logger.iprint("classifier_pred " + str(classifier_pred))
            predicted_class = (classifier_pred > threshold).astype(int)[0]

            if predicted_class == 1:  # Correctly identified as normal
                true_positives += 1
                logger.iprint("True positive")
            else:  # Incorrectly identified as anomalous
                false_negatives += 1
                logger.eprint("Image is:" + filename) 
                logger.eprint("False negative")
                display_fn_image(image_path)  # Display the image that was incorrectly classified as anomalous

In [None]:
logger.iprint("Testing invalid images")

invalid_image_dir = f'{target_dir}_images\\test\\invalid'
# Recursively test with invalid images (anomalous or non-medical images)
for root, dirs, files in os.walk(invalid_image_dir):
    for filename in files:
        if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
            image_path = os.path.join(root, filename)
            image = preprocess_image(image_path, img_width, img_height)

            # Predict the class of the image
            classifier_pred = autoencoder_classifier.predict(image)
            logger.iprint("classifier_pred " + str(classifier_pred))
            predicted_class = (classifier_pred > threshold).astype(int)[0]
            
            if predicted_class == 0:  # Correctly identified as anomalous
                true_negatives += 1
                logger.iprint("True negative")
            else:  # Incorrectly identified as normal
                logger.eprint("Image is:" + filename) 
                false_positives += 1
                logger.eprint("False positive")
                # Display the image that was incorrectly classified as valid
                display_fp_image(image_path)

In [None]:
# Calculate performance metrics
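# Guard against division by zero when no test images were found
total_predictions = true_positives + true_negatives + false_positives + false_negatives
accuracy = (true_positives + true_negatives) / total_predictions if total_predictions > 0 else 0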
precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

logger.iprint(f"True Positives: {true_positives}")
logger.iprint(f"False Positives: {false_positives}")
logger.iprint(f"True Negatives: {true_negatives}")
logger.iprint(f"False Negatives: {false_negatives}")
logger.iprint(f"Accuracy: {accuracy:.4f}")
logger.iprint(f"Precision: {precision:.4f}")
logger.iprint(f"Recall: {recall:.4f}")
logger.iprint(f"F1 Score: {f1_score:.4f}")

In [None]:
print(f"True Positives: {true_positives}")
print(f"False Positives: {false_positives}")
print(f"True Negatives: {true_negatives}")
print(f"False Negatives: {false_negatives}")
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1_score:.4f}")
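
As a worked example of the metric formulas used above, the sketch below applies them to a set of made-up counts. The function name `classification_metrics` and the counts are assumptions chosen for illustration and are not results from the model.

```python
def classification_metrics(tp, fp, tn, fn):
    """Compute accuracy, precision, recall and F1 from confusion-matrix counts."""
    total = tp + tn + fp + fn
    accuracy = (tp + tn) / total if total > 0 else 0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1}


# Made-up counts: 100 valid images (90 accepted), 100 invalid images (95 rejected)
metrics = classification_metrics(tp=90, fp=5, tn=95, fn=10)
for name, value in metrics.items():
    print(f"{name}: {value:.4f}")
# accuracy: 0.9250
# precision: 0.9474
# recall: 0.9000
# f1: 0.9231
```

Here "positive" means a valid medical image, so precision measures how many accepted images were really valid, and recall measures how many valid images were accepted.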