In [None]:
import os

import numpy as np
import torch
import matplotlib.pyplot as plt

from tqdm import tqdm
from sklearn.metrics import auc

from Exp3_VAE_utils import *

In [None]:
# Locations of the dataset and of the trained checkpoint.
DATA_PATH = "data"
MODEL_PATH = "models/VAE.pt"

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Hyper-parameters handed to Exp3VariationalAutoEncoder.
# The size comments assume plain Conv2d / ConvTranspose2d arithmetic
# (conv: out = (in + 2p - k) // s + 1, deconv: out = (in - 1) * s - 2p + k)
# with no extra pooling inside the model — TODO confirm against
# Exp3VariationalAutoEncoder. Under that assumption the stride-1 layers keep
# the spatial size, so the bottleneck is 16x16 rather than 8x8.
config = {
    'input_size': 256,
    'first_ch': 32,
    'latent_channels': 40,
    'conv1': {'kernel_size': 4, 'stride': 2, 'padding': 1}, # 256 -> 128
    'conv2': {'kernel_size': 2, 'stride': 2, 'padding': 0}, # 128 -> 64
    'conv3': {'kernel_size': 4, 'stride': 2, 'padding': 1}, # 64 -> 32
    'conv4': {'kernel_size': 3, 'stride': 1, 'padding': 1}, # 32 -> 32 (stride 1 keeps the size)
    'conv5': {'kernel_size': 4, 'stride': 2, 'padding': 1}, # 32 -> 16
    'deconv1': {'kernel_size': 4, 'stride': 2, 'padding': 1}, # 16 -> 32
    'deconv2': {'kernel_size': 3, 'stride': 1, 'padding': 1}, # 32 -> 32 (stride 1 keeps the size)
    'deconv3': {'kernel_size': 4, 'stride': 2, 'padding': 1}, # 32 -> 64
    'deconv4': {'kernel_size': 4, 'stride': 2, 'padding': 1}, # 64 -> 128
    'deconv5': {'kernel_size': 2, 'stride': 2, 'padding': 0}, # 128 -> 256
    'out': {'kernel_size': 3, 'stride': 1, 'padding': 1}, # 256 -> 256
}
model = Exp3VariationalAutoEncoder(config)
# map_location remaps the stored tensors onto DEVICE, so a checkpoint that was
# saved on a GPU can still be restored on a CPU-only machine.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE)['model_state_dict'])

In [None]:
checkpoint_path_idx = 0
test_data_loader = load_dataset(DATA_PATH, BraTS2020Dataset, train=False, healthy=False, batch_size=1, shuffle=False)

print("Model loaded from checkpoint.")

# Samples whose channels and mask are also written to disk as PNG files
image_index_to_save = [0, 1, 6, 16, 18, 27, 28]
# Only the first samples of the loader are visualised
num_samples_to_plot = 30
# plt.imsave does not create missing directories
os.makedirs("pictures", exist_ok=True)

# Run the model on one picture at a time and plot the result
model.eval()
model.to(DEVICE)
for i, sample in enumerate(test_data_loader):
    if i == num_samples_to_plot:
        break
    with torch.no_grad():
        test_data = sample['image'].to(DEVICE)
        (q1, q2), mu, logvar = model(test_data)
        # q2 is passed as the reconstruction and q2 - q1 as the spread term
        reject_mask = calculate_rejection_mask(test_data, q2, q2 - q1, threshold=0.05)[0].squeeze()

    input_channels = test_data.cpu().numpy()[0]
    # (subplot position, title, picture) for every panel of the 2x4 grid
    panels = [
        (1, "Original Channel 1", input_channels[0]),
        (2, "Ground Truth Mask", sample['mask'].numpy()[0, 0]),
        (3, "Reconstruction", q2.cpu()[0, 0]),
        (4, "Rejection Mask", reject_mask),
        (5, "Original Channel 2", input_channels[1]),
        (6, "Original Channel 3", input_channels[2]),
        (7, "Reconstruction (q1)", q1.cpu()[0, 0]),
    ]
    plt.figure(figsize=(10, 6))
    plt.suptitle("Model Evaluation on Sample {}".format(i))
    for position, title, picture in panels:
        plt.subplot(2, 4, position)
        plt.title(title)
        plt.imshow(picture, cmap='gray')
    # Render the figure now instead of keeping every figure of the loop open
    # until the cell finishes (matplotlib warns once more than 20 are open).
    plt.show()

    # Save the input channels and the ground truth mask if the sample is in the list
    if i in image_index_to_save:
        pic_to_save = sample["image"]
        print(pic_to_save.shape)
        pic_to_save = pic_to_save.cpu().squeeze().numpy().transpose(1, 2, 0)
        # NOTE: with a colormap and no vmin/vmax, imsave rescales each array to
        # its own min/max and writes 8-bit colours, so the PNGs are not a
        # lossless copy of the tensors.
        for j in range(3):
            plt.imsave(f"pictures/sample_{i}_channel_{j}.png", pic_to_save[:, :, j], cmap='gray')
        mask_to_save = sample["mask"].cpu().squeeze().numpy()
        plt.imsave(f"pictures/mask_{i}.png", mask_to_save, cmap='gray')

            
        

In [None]:
# Load the saved images and test the model on them
# (eval mode / device placement repeated here so the cell does not depend on
# the previous cell having been run in this kernel session)
model.eval()
model.to(DEVICE)
for i in image_index_to_save:
    # Each PNG is read back with several colour channels (4 according to the
    # original comment); only the first one is kept. Stacking the three files
    # gives a (3, H, W) array without hard-coding the image size.
    channels = [plt.imread(f"pictures/sample_{i}_channel_{j}.png")[:, :, 0] for j in range(3)]
    image = torch.tensor(np.stack(channels)).unsqueeze(0).float().to(DEVICE)
    mask = plt.imread(f"pictures/mask_{i}.png")
    with torch.no_grad():
        (q1, q2), mu, logvar = model(image)
        # NOTE(review): threshold is 0.02 here but 0.05 where the samples are
        # taken straight from the data loader — confirm this is intentional.
        reject_mask = calculate_rejection_mask(image, q2, q2 - q1, threshold=0.02)[0].squeeze()

    input_channels = image.cpu().numpy()[0]
    # (subplot position, title, picture) for every panel of the 2x4 grid
    panels = [
        (1, "Original Channel 1", input_channels[0]),
        (2, "Ground Truth Mask", mask),
        (3, "Reconstruction", q2.cpu()[0, 0]),
        (4, "Rejection Mask", reject_mask),
        (5, "Original Channel 2", input_channels[1]),
        (6, "Original Channel 3", input_channels[2]),
        (7, "Reconstruction (q1)", q1.cpu()[0, 0]),
    ]
    plt.figure(figsize=(10, 6))
    plt.suptitle("Model Evaluation on Sample {}".format(i))
    for position, title, picture in panels:
        plt.subplot(2, 4, position)
        plt.title(title)
        plt.imshow(picture, cmap='gray')
    plt.show()


In [None]:
def test_qr_vae(device, model, h_test_loader, uh_test_loader, num_thresholds=20):
    model.eval()
    model.to(device)

    # Generate a list of thresholds from 0 to 1
    thresholds = np.linspace(0, 0.985, num=num_thresholds)

    def cal_tpr_fpr(mask, image, mean_recon, std_recon, thresholds):
        tprs = np.zeros_like(thresholds)
        fprs = np.zeros_like(thresholds)
        for i, threshold in enumerate(tqdm(thresholds)):
            label = calculate_rejection_mask(image, mean_recon, std_recon, threshold=threshold)
            # print(f"Label: {label.shape}")
            label = label.flatten()
            tp = np.sum(mask & label)
            fp = np.sum(~mask & label)
            tn = np.sum(~mask & ~label)
            fn = np.sum(mask & ~label)
            if tp + fn == 0:
                tpr = 1
            else:
                tpr = tp / (tp + fn)
            if fp + tn == 0:
                fpr = 1
            else:
                fpr = fp / (fp + tn)
            tprs[i] += tpr
            fprs[i] += fpr
            # print(f"TP: {tp}, FP: {fp}, TN: {tn}, FN: {fn}, TPR: {tpr}, FPR: {fpr}")
        return tprs, fprs

    data = next(iter(h_test_loader))
    images = data['image'].to(device)
    masks = np.zeros(images.shape[0]*256*256, dtype=bool)
    length = images.shape[0]
    print(f"Length: {length}")
    with torch.no_grad():
        recon, _, _ = model(images)
        mean_recon = recon[1]
        std_recon = mean_recon - recon[0]

        tprs_1, fprs_1 = cal_tpr_fpr(masks, images, mean_recon, std_recon, thresholds)
        

    data = next(iter(uh_test_loader))
    images = data['image'].to(device)
    
    # Assuming 'mask' indicates unhealthy regions
    masks = data['mask'].numpy().flatten()
    length += images.shape[0]
    print(f"Length: {length}")

    with torch.no_grad():
        recon, _, _ = model(images)
        mean_recon = recon[1]
        std_recon = mean_recon - recon[0]

        tprs_2, fprs_2 = cal_tpr_fpr(masks, images, mean_recon, std_recon, thresholds)

    return (tprs_1 + tprs_2) / 2, (fprs_1 + fprs_2) / 2


In [None]:
# Build the healthy and the unhealthy test loaders (same settings, only the
# `healthy` flag differs), then evaluate the model and integrate the ROC curve.
h_test_loader, uh_test_loader = (
    load_dataset(DATA_PATH, BraTS2020Dataset, train=False, healthy=is_healthy, batch_size=128)
    for is_healthy in (True, False)
)
tprs, fprs = test_qr_vae(
    DEVICE,
    model,
    h_test_loader,
    uh_test_loader,
)
# Area under the (FPR, TPR) curve
auc_val = auc(fprs, tprs)

In [None]:
# ROC curve of the rejection mask, drawn against the chance diagonal
lw = 2
fig, ax = plt.subplots()
ax.plot(
    fprs,
    tprs,
    color='darkorange',
    lw=lw,
    label='ROC curve (area = %0.2f)' % auc_val,
)
# Dashed diagonal: performance of a random classifier
ax.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('Receiver Operating Characteristic across all data')
ax.legend(loc="lower right")
plt.show()
