# Inference

### Imports

In [None]:
import sys
import numpy as np
import pandas as pd
import pydicom
import logging
import tensorflow as tf
from typing import Tuple, Dict, Optional, Union
from pathlib import Path

%matplotlib inline
import matplotlib.pyplot as plt
import keras

In [None]:
# Make the project's local modules under ../src importable from this notebook.
# The membership check keeps re-runs of this cell from piling up duplicate entries.
src_path: str = "../src"
if src_path not in sys.path:
    sys.path.append(src_path)

In [None]:
# Replace any previously installed root handlers (force=True) and emit INFO and above.
logging.basicConfig(level=logging.INFO, force=True)

In [None]:
# Configuration: model input size and locations of the test data and trained artifacts.
IMG_SIZE: int = 260  # for EfficientNetB2
DATA_ROOT: Path = Path("../data")
# Directory scanned for *.dcm test files.
DICOM_ROOT: Path = DATA_ROOT.joinpath("test_dicom")
OUTPUTS_DIR: Path = DATA_ROOT.joinpath("model_outputs")
# NOTE(review): machine-specific absolute path; not referenced by the visible cells.
XRAY_IMAGES_ROOT: Path = Path("/home/uziel/Downloads/nih_chest_x_rays")
# Weights loaded on top of the saved model by load_model().
CHECKPOINT_PATH: Path = OUTPUTS_DIR.joinpath("model_checkpoint")
# Saved Keras model passed to tf.keras.models.load_model().
MODEL_PATH: Path = OUTPUTS_DIR.joinpath("pneumonia_xray_classifier")
# Text file holding the decision threshold applied to the model's output.
BEST_TH_PATH: Path = OUTPUTS_DIR.joinpath("best_th.txt")

## 1. Load and pre-process test DICOM images

In [None]:
# Read a .dcm file, check the fields that matter for our device (modality and
# body part) and return its pixel data together with its ground-truth label.
def check_dicom(dicom_file: Path) -> Optional[Tuple[np.ndarray, bool]]:
    """Load and check the validity of a DICOM file.

    The file is accepted only if its Modality is "DX" and its BodyPartExamined
    is "CHEST". A missing attribute is treated as invalid rather than raising.

    Args:
        dicom_file: Path to a DICOM file.

    Returns:
        ``None`` if the file fails the checks above; otherwise a tuple of
            - the pixel data of the DICOM image (``Dataset.pixel_array``), and
            - a bool that is True when the StudyDescription contains
              "pneumonia" (case-insensitive), used as the ground-truth label.
    """
    ds = pydicom.dcmread(dicom_file)
    # Dataset.get returns the default for absent attributes, whereas ds[...]
    # raises KeyError and would abort the whole batch on one malformed file.
    modality = ds.get("Modality")
    body_part = ds.get("BodyPartExamined")
    # A missing or empty description is treated as "no pneumonia" instead of
    # crashing on .lower() below.
    label = ds.get("StudyDescription") or ""

    if modality != "DX" or body_part != "CHEST":
        logging.error(f"Modality ({modality}) or body part ({body_part}) are invalid.")
        return None

    logging.info(
        "Image loaded successfully. "
        f"Modality: {modality}. Body part: {body_part}. Label: {label}"
    )

    return ds.pixel_array, "pneumonia" in label.lower()


def preprocess_image(img: np.ndarray) -> tf.Tensor:
    """Ensure the dicom image has three channels, as expected by the network.

    NOTE: Further pre-processing steps are built-in model layers.

    Args:
        img: Chest x-ray image as a 2D (H, W) array. A single-channel (H, W, 1)
            or already three-channel (H, W, 3) array is also accepted.

    Returns:
        Tensor of shape (H, W, 3); a single grayscale plane is replicated
        across the three channels.
    """
    img = tf.convert_to_tensor(img)
    if img.shape.rank == 2:
        # Add a trailing channel axis so the grayscale plane can be broadcast.
        img = tf.expand_dims(img, -1)
    # broadcast_to already yields exactly (H, W, 3), so no extra reshape is
    # needed; for an input that is already (H, W, 3) it is a no-op.
    return tf.broadcast_to(img, (*img.shape[:2], 3))

In [None]:
# Path.glob yields files in arbitrary filesystem order; sort them so the images
# (and therefore the predictions) come out in a reproducible order.
loaded_dicoms = [
    check_dicom(dicom_file) for dicom_file in sorted(DICOM_ROOT.glob("*.dcm"))
]
# Keep only the files that passed validation, as (3-channel image, label) pairs.
dicom_dataset = [
    (preprocess_image(dicom_data[0]), dicom_data[1])
    for dicom_data in loaded_dicoms
    if dicom_data is not None
]
if not dicom_dataset:
    logging.warning("No valid DICOM files found in %s", DICOM_ROOT)

## 2. Load model and make predictions

In [None]:
def load_model(model_path: Path, checkpoint_path: Path) -> tf.keras.Model:
    """Load model, its weights from a checkpoint and compile it.

    Args:
        model_path: Path to the saved model.
        checkpoint_path: Path to model checkpoint containing model weights.

    Returns:
        The loaded model with the checkpoint weights applied, compiled without
        a loss or optimizer.
    """
    model = tf.keras.models.load_model(str(model_path))
    # Pass a plain string, as for load_model above: some TF 2.x releases call
    # str methods (e.g. endswith) on this argument and reject pathlib.Path.
    model.load_weights(str(checkpoint_path))
    model.compile()

    return model


def predict_image(model: tf.keras.Model, img: tf.Tensor, th: float) -> bool:
    """Use model to predict whether an image shows the presence of pneumonia.

    Args:
        model: A trained model to detect the presence of pneumonia on chest x-rays.
            Its output is read at ``[0][0]``, i.e. the first output unit for the
            single image in the batch.
        img: A chest x-ray image, without a batch dimension.
        th: The threshold to determine the presence of pneumonia from the model
            probability output. Scores strictly greater than ``th`` are positive.

    Returns:
        A Python bool indicating whether the image is predicted to show presence
        of pneumonia or not.
    """
    # Add a batch dimension of size 1.
    batch = tf.expand_dims(img, 0)
    # Call the model directly in inference mode: Keras recommends __call__ over
    # predict() for inputs that fit in one batch, since predict() builds a
    # data pipeline on every call.
    score = model(batch, training=False)
    # float() yields a plain Python float, so the comparison returns a real bool
    # (matching the annotation) rather than a numpy/tensor boolean.
    return float(score[0][0]) > th

In [None]:
# Restore the trained classifier, then the decision threshold stored next to it.
model = load_model(
    model_path=MODEL_PATH,
    checkpoint_path=CHECKPOINT_PATH,
)
best_th = float(BEST_TH_PATH.read_text())

In [None]:
# One (ground truth, prediction) pair per validated DICOM image.
y_true_pred = [
    (gt_label, predict_image(model=model, img=image, th=best_th))
    for image, gt_label in dicom_dataset
]

## 3. Evaluate predictions

In [None]:
# Tabulate ground truth vs. prediction per image, plus whether the two agree,
# so the table actually evaluates the predictions. The last expression is
# displayed by the notebook.
results = pd.DataFrame(y_true_pred, columns=["Ground truth", "Predicted"])
results["Correct"] = results["Ground truth"] == results["Predicted"]
results