In [1]:
from data.dataset import REFLACXWithClinicalDataset
from model.xami import  XAMIMultiCocatModal

from utils.gradcam import get_df_label_pred_img_input_loss, show_gradCAMpp_result

import torch
import os
import numpy as np
import pandas as pd
from utils.train import get_aus_loss
import matplotlib.pylab as plt
import pandas as pd
from utils.plot import relabel_ellipse_df, get_ellipses_patch
from utils.transform import TransformFuncs
from copy import copy
from tqdm.notebook import trange
import PIL
from data.data_loader import MIMICDataloader
from eye_tracking.plot import draw_fixations, draw_scanpath, draw_heatmap, draw_raw, get_fixations_dict_from_reflacx_eye_tracking

In [2]:
# Notebook-wide settings:
# - disable pandas' chained-assignment (SettingWithCopy) warning;
# - turn Matplotlib interactive mode off, since the figures below are written to disk.
pd.set_option('mode.chained_assignment', None)
plt.ioff()

<matplotlib.pyplot._IoffContext at 0x224422d7880>

In [3]:
# checking if the GPU is available
use_gpu = torch.cuda.is_available()
device = 'cuda' if use_gpu else 'cpu'
print(f"Will be using {device}")

Will be using cuda


In [4]:
# load the dataset
# Load the REFLACX dataset (with clinical data) at the configured image size.
REFLACX_IMAGE_SIZE = 256
reflacx_dataset = REFLACXWithClinicalDataset(image_size=REFLACX_IMAGE_SIZE)

Positive Loss weight:
[0.768546   0.78635013 0.7937685  0.8338279  0.87388724]
Negative Loss weight:
[0.231454   0.21364985 0.20623146 0.1661721  0.12611276]
Random Loss:
0.21026036153991162


In [5]:
# Checkpoint file (under saved_models/) to analyse. Alternative checkpoint, whose name
# suggests it was trained without clinical data:
#   'test_0_8162_epoch300_WithoutClincal_dim32_2022-02-07 21_43_31_353207'
model_name = 'test_0_8260__epoch300_WithClincal_dim32_2022-02-08 10_03_56_953198'

In [6]:
# get the trained model
xami_mutlimodal = XAMIMultiCocatModal(
    reflacx_dataset,
    device,
    use_clinical=True,
    model_dim=32, # was 64
    embeding_dim=64,
    dropout=.2,  # increase the dropout rate did improve the regularization.
    pretrained=True
)
xami_mutlimodal.load_state_dict(torch.load(os.path.join('saved_models', model_name ), map_location=device))

xami_mutlimodal = xami_mutlimodal.to(device)
xami_mutlimodal.eval()


XAMIMultiCocatModal(
  (image_net): ImageDenseNet(
    (model_ft): DenseNet(
      (features): Sequential(
        (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu0): ReLU(inplace=True)
        (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        (denseblock1): _DenseBlock(
          (denselayer1): _DenseLayer(
            (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (relu1): ReLU(inplace=True)
            (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (relu2): ReLU(inplace=True)
            (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          )
          (denselayer

In [7]:
# Loss used to score each instance in the export loop; it is built from the dataset
# (presumably from its positive/negative label weights — TODO confirm in utils.train).
loss_fn = get_aus_loss(reflacx_dataset)

In [8]:
# Every artefact produced for this checkpoint goes into "<model_name>_result".
saving_folder = '{}_result'.format(model_name)
os.makedirs(saving_folder, exist_ok=True)

In [9]:
# Root of the local XAMI-MIMIC copy. A raw string is required: "\X" is an invalid
# escape sequence (SyntaxWarning on Python 3.12+). Set the XAMI_MIMIC_PATH environment
# variable to point elsewhere without editing the notebook.
XAMI_MIMIC_PATH = os.environ.get("XAMI_MIMIC_PATH", r"D:\XAMI-MIMIC")
mimic_dataloader = MIMICDataloader(XAMI_MIMIC_PATH)

In [10]:
# For every REFLACX instance, export into `saving_folder/<reflacx id>/`:
# - the CXR;
# - the eye-tracking heatmap, when a recording exists;
# - per-label anomaly-ellipse overlays and Grad-CAM++ maps;
# - a `df.csv` holding the instance row plus the model's predictions and loss.
model_input_image_size = reflacx_dataset.image_size
transform_fun = TransformFuncs(image_size=model_input_image_size)


def _figure_to_rgb_image(figure):
    """Render ``figure`` and return its pixels as an RGB ``PIL.Image``.

    Reads the canvas through ``buffer_rgba()`` because ``canvas.tostring_rgb()`` is
    deprecated since Matplotlib 3.8 and removed in 3.10. The buffer's own shape is
    used, so the image size always matches the rendered pixels.
    """
    figure.canvas.draw()
    rgba = np.asarray(figure.canvas.buffer_rgba())
    # convert() drops the alpha channel and returns a new image, so the result does
    # not keep referencing the figure's render buffer.
    return PIL.Image.fromarray(rgba).convert('RGB')


def _save_image_figure(image, save_path, patches=()):
    """Draw ``image`` on a new 10x10 in, 80 dpi figure, overlay ``patches``, save it to ``save_path``, then close the figure."""
    fig, ax = plt.subplots(1, figsize=(10, 10), dpi=80)
    ax.imshow(image)
    for patch in patches:
        # A patch cannot be shared between figures, so each figure gets its own copy.
        ax.add_patch(copy(patch))
    fig.savefig(save_path)
    plt.close(fig)


for i in trange(len(reflacx_dataset)):

    df, labels_df, pred_df, img, loss, model_input = get_df_label_pred_img_input_loss(
        xami_mutlimodal, loss_fn, reflacx_dataset, i, device)

    # The instance's metadata is read from the first row of `df`.
    instance = df.iloc[0]
    image_size_x = instance['image_size_x']
    image_size_y = instance['image_size_y']
    image_path = instance['image_path']
    subject_id = instance['subject_id']
    reflacx_id = instance['id']

    instance_save_dir = os.path.join(saving_folder, reflacx_id)
    os.makedirs(instance_save_dir, exist_ok=True)

    ellipse_df = pd.read_csv(instance['anomaly_location_ellipses_path'])
    relabeled_ellipse_df = relabel_ellipse_df(ellipse_df)

    eye_tracking_path = mimic_dataloader.get_reflacx_eye_tracking_path(
        subject_id, reflacx_id)
    has_eye_tracking = os.path.isfile(eye_tracking_path)

    eye_tracking_display = None
    if has_eye_tracking:
        fix = get_fixations_dict_from_reflacx_eye_tracking(
            pd.read_csv(eye_tracking_path))

        eye_tracking_heatmap_fig = draw_heatmap(fix, imagefile=image_path)
        eye_tracking_heatmap_img = _figure_to_rgb_image(eye_tracking_heatmap_fig)
        plt.close(eye_tracking_heatmap_fig)

        eye_tracking_display = transform_fun.display_transform(eye_tracking_heatmap_img)
        _save_image_figure(eye_tracking_display,
                           os.path.join(instance_save_dir, 'eye_tracking_heatmap.png'))

    # The display version of the CXR is the same for every label; compute it once.
    cxr_display = transform_fun.display_transform(img)
    _save_image_figure(cxr_display, os.path.join(instance_save_dir, 'cxr.png'))

    for d in reflacx_dataset.labels_cols:
        ellipses = get_ellipses_patch(
            relabeled_ellipse_df, d, image_size_x, image_size_y, model_input_image_size)
        n_ellipses = len(ellipses)

        if n_ellipses > 0:
            # Original image with this label's anomaly ellipses.
            _save_image_figure(
                cxr_display,
                os.path.join(instance_save_dir, f'{d}_cxr_with_ellipse_{n_ellipses}.png'),
                ellipses)

            # Eye-tracking heatmap with the same ellipses.
            if has_eye_tracking:
                _save_image_figure(
                    eye_tracking_display,
                    os.path.join(instance_save_dir,
                                 f'{d}_eye_tracking_heatmap_with_ellipse_{n_ellipses}.png'),
                    ellipses)

        # Grad-CAM++ map for this label, without and (when available) with the ellipses.
        gradcam_img = show_gradCAMpp_result(
            reflacx_dataset, xami_mutlimodal, d, img, model_input, use_full_features=True)

        _save_image_figure(gradcam_img,
                           os.path.join(instance_save_dir, f'{d}_cxr_gradcampp.png'))
        if n_ellipses > 0:
            _save_image_figure(
                gradcam_img,
                os.path.join(instance_save_dir,
                             f'{d}_cxr_gradcampp_with_ellipse_{n_ellipses}.png'),
                ellipses)

        # Also release any figures the project helpers above may have left open.
        plt.close('all')

    # Store predictions and loss on the instance row. Address that row by its actual
    # label: `df.at[0, ...]` would append a new row whenever df's index does not start
    # at 0, while everything above reads the row positionally (iloc[0]).
    instance_label = df.index[0]
    for l_col in reflacx_dataset.labels_cols:
        df[f"pred_{l_col}"] = None
        df.at[instance_label, f"pred_{l_col}"] = pred_df[l_col].iloc[0]

    df.at[instance_label, 'loss'] = loss

    df.to_csv(os.path.join(instance_save_dir, "df.csv"))


  0%|          | 0/674 [00:00<?, ?it/s]