In [1]:
from pathlib import Path

import cv2
import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torchvision.io import ImageReadMode, read_image, write_png
from torchvision.transforms import Normalize
from torchvision.transforms.functional import InterpolationMode, resize
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks

from model.model import BackboneType, MultiNet

In [2]:
def rle_to_mask(rle: str, height: int, width: int):
    runs = torch.tensor([int(x) for x in rle.split()])
    starts = runs[::2]
    lengths = runs[1::2]
    mask = torch.zeros([height * width], dtype=torch.uint8)

    for start, lengths in zip(starts, lengths):
        start -= 1
        end = start + lengths
        mask[start:end] = 255
    return mask.reshape((height, width))


def generate_mask(rle_lung_left, rle_lung_right, rle_heart, height, width):
    mask_lung_left = rle_to_mask(rle_lung_left, height=height, width=width)
    mask_lung_right = rle_to_mask(rle_lung_right, height=height, width=width)
    mask_heart = rle_to_mask(rle_heart, height=height, width=width)
    mask_lung = mask_lung_left + mask_lung_right
    foreground = (mask_lung + mask_heart) - (mask_lung * mask_heart)
    background = torch.abs(255 - foreground)
    return torch.stack([background, mask_lung, mask_heart])

In [3]:
# with h5py.File('/Volumes/storage/train_image.hdf5', 'r') as file:
#     dataset_image = file['image']
#     # dataset_label = file['label']
#     image = dataset_image[3]
#     # label = dataset_label[0]
#     plt.figure()
#     plt.imshow(np.transpose(image, [1, 2, 0]).astype(np.uint8))
#     plt.figure()
#     # plt.imshow(np.transpose(label, [1, 2, 0]).astype(np.uint8))

In [4]:
# torch.tensor(image, dtype=torch.uint8)

In [63]:
def get_ctr(output_tensor: torch.Tensor, dilation=False):
    heart_mask = (
        (output_tensor[0].sigmoid() > 0.5).permute([1, 2, 0])[..., 2].to(torch.uint8)
    )
    contours, _ = cv2.findContours(
        heart_mask.numpy(),
        cv2.RETR_TREE,
        cv2.CHAIN_APPROX_SIMPLE,
    )
    center_x_y, width_height, angle = cv2.minAreaRect(contours[0])
    w, h = width_height
    heart_width = w

    lung_mask = (
        (output_tensor[0].sigmoid() > 0.5).permute([1, 2, 0])[..., 1].to(torch.uint8)
    )
    contours, _ = cv2.findContours(
        lung_mask.numpy(),
        cv2.RETR_TREE,
        cv2.CHAIN_APPROX_SIMPLE,
    )
    center_x_y, width_height, angle = cv2.minAreaRect(contours[0])
    w, h = width_height
    lung_width = w
    return (heart_width + 0.0001) / (lung_width + 0.0001)


def generate_visualization(original_image, output_tensor, dilation=False):
    heart_mask = output_tensor[0].softmax(0).permute([1, 2, 0])[..., 2].to(torch.uint8)
    if dilation:
        heart_mask = cv2.dilate(heart_mask.numpy(), np.ones([5, 5]), iterations=5)
        x, y, w, h = cv2.boundingRect(heart_mask)
    else:
        x, y, w, h = cv2.boundingRect(heart_mask.numpy())

    image_bounding_box = draw_bounding_boxes(
        (original_image[0] * 255).to(torch.uint8),
        torch.tensor([[x, y, x + w, y + h]]),
        colors=(0, 0, 128),
        width=2,
    )
    lung_mask = output_tensor[0].softmax(0).permute([1, 2, 0])[..., 1].to(torch.uint8)
    x, y, w, h = cv2.boundingRect(lung_mask.numpy())
    image_bounding_box = draw_bounding_boxes(
        image_bounding_box,
        torch.tensor([[x, y, x + w, y + h]]),
        colors=(128, 0, 0),
        width=2,
    )
    image_bounding_box = draw_segmentation_masks(
        image_bounding_box,
        torch.Tensor(heart_mask).to(torch.bool),
        alpha=0.5,
        colors=(0, 0, 128),
    )
    image_bounding_box = draw_segmentation_masks(
        image_bounding_box,
        lung_mask.to(torch.bool),
        alpha=0.5,
        colors=(128, 64, 128),
    )
    return image_bounding_box

In [16]:
mask_pd = pd.read_csv(
    "data/cardiac/chestxray.csv",
    engine="pyarrow",
    index_col=0,
)
mask_pd.head(5)

Unnamed: 0_level_0,Dice RCA (Mean),Dice RCA (Max),Landmarks,Left Lung,Right Lung,Heart,Height,Width
Image Index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
00025787_047.png,0.826465,0.870102,"488,115,449,118,411,132,374,153,344,177,324,20...",200402 3 201422 11 202442 18 203463 24 204483 ...,118243 8 119254 24 120265 40 121281 51 122302 ...,448036 9 449048 27 450064 41 451086 48 452107 ...,1024,1024
00026251_000.png,0.833988,0.928886,"426,124,386,128,345,146,303,171,270,203,246,23...",136822 22 137826 46 138848 52 139870 58 140892...,127399 7 128413 21 129427 35 130441 49 131458 ...,445964 26 446985 33 448007 40 449028 47 450049...,1024,1024
00026194_002.png,0.759308,0.794598,"439,217,402,226,371,245,345,271,324,297,311,32...",189085 9 190096 27 191113 39 192135 45 193158 ...,221645 20 222646 45 223666 50 224686 56 225706...,395796 5 396816 13 397836 21 398856 29 399877 ...,1024,1024
00025227_012.png,0.903819,0.93099,"397,79,357,82,314,98,274,128,236,165,212,205,1...",100976 22 101998 45 103020 49 104042 53 105064...,81288 8 82299 23 83309 40 84325 50 85346 55 86...,498141 13 499161 29 500181 36 501201 43 502221...,1024,1024
00028166_003.png,0.850303,0.869093,"392,180,374,184,353,201,327,225,304,250,284,27...",163435 4 164455 12 165475 20 166496 27 167516 ...,184711 4 185731 12 186750 20 187770 27 188791 ...,464455 11 465455 37 466458 59 467469 74 468485...,1024,1024


In [15]:
pd_data = pd.read_csv("data/cardiac/chestxray/Data_Entry_2017.csv", index_col=0)
pd_data.head(5)

Unnamed: 0_level_0,Finding Labels,Follow-up #,Patient ID,Patient Age,Patient Gender,View Position,OriginalImage[Width,Height],OriginalImagePixelSpacing[x,y],Unnamed: 11
Image Index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
00000001_000.png,Cardiomegaly,0,1,58,M,PA,2682,2749,0.143,0.143,
00000001_001.png,Cardiomegaly|Emphysema,1,1,58,M,PA,2894,2729,0.143,0.143,
00000001_002.png,Cardiomegaly|Effusion,2,1,58,M,PA,2500,2048,0.168,0.168,
00000002_000.png,No Finding,0,2,81,M,PA,2500,2048,0.171,0.171,
00000003_000.png,Hernia,0,3,81,F,PA,2582,2991,0.143,0.143,


In [65]:
if torch.backends.mps.is_available():
    print("Using MPS engine")
    device = "mps"
elif torch.cuda.is_available():
    print("Using CUDA engine")
    device = "cuda"
else:
    print("Using CPU engine")
    device = "cpu"

device = "cpu"

model = MultiNet(numberClass=3, backboneType=BackboneType.RESNET50)
preprocessor = Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]).to(device)

# Load weights into model
model.load_state_dict(
    torch.load(
        "data/model/cardiac_model_new_new_new_new.pt",
        map_location=device,
    )
)

CARDIOMEGALY_THRESHOLD = 0.5
bunch_of_images = [
    x for x in Path("data/cardiac/chestxray/").glob("images_*/images/*.png")
]
true_positive, true_negative, false_positive, false_negative = 0, 0, 0, 0
for img_path in bunch_of_images[-200:]:
    # Load image
    file_path = str(img_path)
    image = read_image(file_path, ImageReadMode.RGB)
    image = (image / 255).float().unsqueeze(0)

    label_ctr_result = (
        pd_data.loc[str(img_path.name)]["Finding Labels"] == "Cardiomegaly"
    )
    model.eval()
    with torch.no_grad():
        original_image = resize(
            image, [512, 512], interpolation=InterpolationMode.NEAREST
        )
        image = preprocessor(original_image)
        output_tensor = model(image)
        pred_ctr = get_ctr(output_tensor, dilation=False)
        pred_is_cardiomegaly = pred_ctr > CARDIOMEGALY_THRESHOLD

        row_of_interest = mask_pd.loc[str(img_path.name)]
        mask_roi = generate_mask(
            rle_heart=row_of_interest["Heart"],
            rle_lung_left=row_of_interest["Left Lung"],
            rle_lung_right=row_of_interest["Right Lung"],
            height=row_of_interest["Height"],
            width=row_of_interest["Width"],
        )
        mask_roi = resize(mask_roi.float().unsqueeze(0), [512, 512])
        label_ctr = get_ctr(mask_roi)
        label_is_cardiomegaly = label_ctr > CARDIOMEGALY_THRESHOLD

        if pred_is_cardiomegaly and label_is_cardiomegaly:
            true_positive += 1
        elif pred_is_cardiomegaly and not label_is_cardiomegaly:
            false_positive += 1
        elif not pred_is_cardiomegaly and label_is_cardiomegaly:
            false_negative += 1
            # visualization_tensor = generate_visualization(
            #     original_image,
            #     output_tensor,
            #     dilation=True,
            # )
            # write_png(visualization_tensor, f"visualization_{str(img_path.name)}")
            # print(label_ctr, pred_ctr)
        else:
            true_negative += 1

        # if pred_is_cardiomegaly and label_ctr_result:
        #     true_positive += 1
        # elif pred_is_cardiomegaly and not label_ctr_result:
        #     false_positive += 1
        # elif not pred_is_cardiomegaly and label_ctr_result:
        #     false_negative += 1
        # else:
        #     true_negative += 1
        # print(pred_ctr, label_ctr, pred_is_cardiomegaly)

precision = (true_positive + 0.00001) / (true_positive + false_positive + 0.00001)
recall = (true_positive + 0.00001) / (true_positive + false_negative + 0.00001)
accuracy = (true_positive + true_negative) / (
    true_positive + true_negative + false_negative + false_positive
)

print("TruePositive TrueNegative FalsePositive FalseNegative")
print(true_positive, true_negative, false_positive, false_negative)
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(
    f"F1: {((2 * true_positive + 0.00001) / ((2 * true_positive) + false_positive + false_negative + 0.00001)):.4f}"
)

Using MPS engine
TruePositive TrueNegative FalsePositive FalseNegative
161 33 5 1
Accuracy: 0.9700
Precision: 0.9699
Recall: 0.9938
F1: 0.9817
