# Inference


### Imports


In [None]:
# Reload edited modules automatically before each cell executes, so changes to
# the project sources (e.g. under ../src) are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import sys
import logging
import numpy as np
import pandas as pd

from pydicom import dcmread
from typing import Tuple, Dict, Optional, Union
from pathlib import Path
from torch import Tensor
from rich import traceback
from torchvision.models import MobileNet_V3_Large_Weights

%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
src_path: str = "../src"
# Put the project sources FIRST on the import path, so local packages such as
# `utils` and `datasets` win over identically named installed packages (e.g.
# Hugging Face `datasets`, if installed). Appending would let those shadow ours.
# The path is resolved to an absolute one so it stays valid if the working
# directory changes, and it is not re-inserted when this cell is re-run.
_abs_src_path = str(Path(src_path).resolve())
if _abs_src_path not in sys.path:
    sys.path.insert(0, _abs_src_path)
# Render uncaught exceptions with rich tracebacks; the return value is discarded
# so the cell does not echo it.
_ = traceback.install()

In [None]:
from utils import dictify_dicom
from datasets.pneumonia_dicom_dataset import PneumoniaDicomDataset
from models.pneumonia_classifier import PneumoniaClassifier

In [None]:
# Route log records through a fresh root handler at INFO level. `force=True`
# first removes any handlers already attached to the root logger (e.g. from a
# previous run of this cell), so records are not emitted twice.
logging.basicConfig(level=logging.INFO, force=True)

In [None]:
import os

# Preprocessing bundled with torchvision's default MobileNetV3-Large weights.
REQUIRED_TRANSFORMS = MobileNet_V3_Large_Weights.DEFAULT.transforms()
DATA_ROOT: Path = Path("../data")
DICOM_ROOT: Path = DATA_ROOT.joinpath("test_dicom")  # test *.dcm files
OUTPUTS_DIR: Path = DATA_ROOT.joinpath("model_outputs")
# Machine-specific location of the NIH chest X-ray images. Set the
# XRAY_IMAGES_ROOT environment variable to override it instead of editing the
# notebook; the original path is kept as the default.
XRAY_IMAGES_ROOT: Path = Path(
    os.environ.get("XRAY_IMAGES_ROOT", "/home/uziel/Downloads/nih_chest_x_rays")
)
LOGS_PATH: Path = OUTPUTS_DIR.joinpath("mobilenet_v3_large")  # searched for *.ckpt
BEST_TH_PATH: Path = OUTPUTS_DIR.joinpath("best_th.txt")  # decision threshold as text

## 1. Load and pre-process test DICOM images


In [None]:
def gather_dicoms(dicom_root: Path) -> pd.DataFrame:
    """Gather DICOM files under a directory and tabulate their metadata.

    Every ``*.dcm`` file found recursively under ``dicom_root`` becomes one row,
    indexed by its file stem (index name ``"id"``) and sorted by that id. Each row
    holds the resolved ``file_path`` plus the fields produced by ``dictify_dicom``.
    Column names are lower-cased, with spaces replaced by underscores. The
    ``pixel_data`` column is dropped, so only metadata is kept.

    Args:
        dicom_root: Directory under which all DICOM files are located.

    Returns:
        A dataframe with one row of metadata per DICOM file.

    Raises:
        FileNotFoundError: If no ``*.dcm`` file exists under ``dicom_root``.
        ValueError: If two files share the same stem. The stem is used as the row
            id, so one file would otherwise silently overwrite the other.
    """
    dicom_files = sorted(dicom_root.glob("**/*.dcm"))
    if not dicom_files:
        # Without this check, an empty directory surfaces later as a confusing
        # KeyError from ``drop(columns="pixel_data")``.
        raise FileNotFoundError(f"No DICOM (*.dcm) files found under {dicom_root}")

    # The search is recursive, so same-named files in different sub-directories
    # are possible. Refuse them rather than silently keeping only one.
    files_by_id = {}
    for dicom_file in dicom_files:
        duplicate = files_by_id.get(dicom_file.stem)
        if duplicate is not None:
            raise ValueError(
                f"DICOM files {duplicate} and {dicom_file} share the id "
                f"'{dicom_file.stem}'; ids must be unique."
            )
        files_by_id[dicom_file.stem] = dicom_file

    dicom_meta = {
        dicom_id: {
            "file_path": str(dicom_file.resolve()),
            **dictify_dicom(dcmread(dicom_file)),
        }
        for dicom_id, dicom_file in files_by_id.items()
    }

    # Outer keys (ids) become columns here, hence the transpose to one row per file.
    dicom_meta_df = pd.DataFrame(dicom_meta).transpose()
    dicom_meta_df.columns = [c.lower().replace(" ", "_") for c in dicom_meta_df.columns]

    return dicom_meta_df.drop(columns="pixel_data").sort_index().rename_axis("id")

In [None]:
# Collect the metadata (including each file's resolved path) of every test DICOM
# file under DICOM_ROOT, then display the resulting table.
dicoms_data = gather_dicoms(DICOM_ROOT)
dicoms_data

In [None]:
# Wrap the metadata table in a dataset whose `transform` is the MobileNetV3-Large
# preprocessing (presumably applied to each image; confirm in
# PneumoniaDicomDataset). It is iterated below as (image, label) pairs.
dicoms_dataset = PneumoniaDicomDataset(dicoms_data, transform=REQUIRED_TRANSFORMS)

## 2. Load model and make predictions


In [None]:
def load_model(checkpoint_path: Path) -> PneumoniaClassifier:
    """Load a trained ``PneumoniaClassifier`` from a checkpoint, ready for inference.

    Args:
        checkpoint_path: Path to a model checkpoint containing the model weights.

    Returns:
        The restored model, switched to evaluation mode with ``eval()``.
    """
    model = PneumoniaClassifier.load_from_checkpoint(checkpoint_path)
    # Evaluation mode switches off training-only layer behaviour
    # (e.g. dropout, batch-norm statistic updates) before predicting.
    model.eval()

    return model


def predict_image(model: PneumoniaClassifier, img: Tensor, th: float) -> bool:
    """Use model to predict whether an image shows the presence of pneumonia.

    Args:
        model: A trained model to detect the presence of pneumonia on chest x-rays.
        img: A single (unbatched) chest x-ray image tensor.
        th: The threshold to determine the presence of pneumonia from the model
            probability output. Outputs strictly greater than ``th`` count as positive.

    Returns:
        A plain Python bool indicating whether the image is predicted to show the
        presence of pneumonia.
    """
    import torch

    # Inference only: avoid building an autograd graph for every image.
    with torch.no_grad():
        output = model(img.unsqueeze(0))  # add a batch dimension of size 1

    # Take the first output of the single batch element, then convert the 0-d bool
    # tensor into a Python bool so the result matches the annotation.
    return bool((output > th)[0][0])

In [None]:
# Sort the checkpoints so the choice does not depend on filesystem order. Fail
# clearly when none exist, and warn when several could be meant.
checkpoints = sorted(LOGS_PATH.glob("**/*.ckpt"))
if not checkpoints:
    raise FileNotFoundError(f"No model checkpoint (*.ckpt) found under {LOGS_PATH}")
if len(checkpoints) > 1:
    logging.warning(
        "Found %d checkpoints under %s; using %s", len(checkpoints), LOGS_PATH, checkpoints[0]
    )
model = load_model(checkpoints[0])
# The decision threshold is stored as a plain-text number.
best_th = float(BEST_TH_PATH.read_text())

In [None]:
# For every test image, pair the ground-truth label (cast with `.bool()`) with
# the model's thresholded prediction.
y_true_pred = [
    (label.bool(), predict_image(model, img, best_th))
    for img, label in dicoms_dataset
]

## 3. Evaluate predictions


In [None]:
# Display each test image's ground-truth label next to the model's prediction,
# with both columns cast to plain booleans.
pd.DataFrame(y_true_pred, columns=["Ground truth", "Predicted"]).astype(bool)