In [None]:
# Alperen Erol - 200051583 - Unsupervised Anomaly Detection on Medical Images project notebook

In [None]:
import os
import time
import numpy as np
import matplotlib.pyplot as plt
import nibabel as nib
import seaborn as sns
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, L1Loss
from tqdm import tqdm
import sklearn

from monai import transforms
from monai.config import print_config
from monai.data import DataLoader, Dataset
from monai.utils import first, set_determinism

from GenerativeModels.generative.inferers import VQVAETransformerInferer
from GenerativeModels.generative.networks.nets import VQVAE, DecoderOnlyTransformer
from GenerativeModels.generative.utils.enums import OrderingType
from GenerativeModels.generative.utils.ordering import Ordering

from skimage.metrics import structural_similarity as ssim

In [None]:
# Min-max normalisation function
def min_max_normalize(tensor):
    """Rescale ``tensor`` linearly so its minimum maps to 0 and its maximum to 1.

    The minimum and maximum are taken over the whole tensor (all dimensions
    together). A constant tensor (max == min) would otherwise divide 0 by 0 and
    produce NaNs, so it is mapped to all zeros instead.

    Args:
        tensor: torch tensor to normalise.

    Returns:
        A new tensor of the same shape with values in [0, 1].
    """
    tensor_min = tensor.min()
    value_range = tensor.max() - tensor_min
    if value_range == 0:
        # Blank slice / empty mask: nothing to stretch, avoid 0/0 -> NaN.
        return tensor - tensor_min
    return (tensor - tensor_min) / value_range

In [None]:
# Custom dataset class to load .npz files (image-only)

In [None]:
class NPZDataset(Dataset):
    """Dataset of single-slice ``.npz`` files that returns only the ``"image"`` array.

    Each item is read from disk, given a channel axis, resized with bicubic
    interpolation to ``image_size`` and min-max normalised to [0, 1].

    Args:
        root_dir: directory containing the ``.npz`` files.
        indices: optional positions into the directory listing that select a
            subset of files (used to carve train/val/test splits from one folder).
        transform: optional callable applied to the (1, H, W) tensor after
            normalisation.
        image_size: output spatial size (H, W); defaults to the 128x128 used
            throughout this notebook.
    """

    def __init__(self, root_dir, indices=None, transform=None, image_size=(128, 128)):
        self.root_dir = root_dir
        # NOTE(review): os.listdir order is filesystem-dependent, so the
        # index-based splits are only reproducible against the same directory.
        # Sorting would make them deterministic but would also change which
        # files fall into each split (possibly leaking slices used to train the
        # saved checkpoints into the test range), so the order is left as-is.
        self.files = [f for f in os.listdir(root_dir) if f.endswith('.npz')]
        if indices is not None:
            self.files = [self.files[i] for i in indices]
        self.transform = transform
        self.image_size = tuple(image_size)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        file_path = os.path.join(self.root_dir, self.files[idx])
        # The context manager closes the file handle held by the NpzFile.
        with np.load(file_path) as data:
            sample = data["image"]

        # Add channel and batch axes ((H, W) -> (1, 1, H, W)) as F.interpolate expects.
        sample = torch.tensor(sample, dtype=torch.float32)[None, None]
        sample = F.interpolate(sample, size=self.image_size, mode='bicubic', align_corners=False)

        # Normalise after resizing, since bicubic interpolation can overshoot the input range.
        sample = min_max_normalize(sample)

        sample = sample.squeeze(0)  # drop the batch axis -> (1, H, W)

        if self.transform:
            sample = self.transform(sample)

        return sample.clone().detach().to(dtype=torch.float32)

In [None]:
# Create and load the training and validation sets

In [None]:
root_dir = "normal_slices_training"

# Index-based split of the normal slices: [0, 12000) trains the models and
# [12000, 14000) is held out for validation.
train_data = NPZDataset(root_dir, indices=[*range(0, 12000)])
val_data_normal = NPZDataset(root_dir=root_dir, indices=[*range(12000, 14000)])
val_data = val_data_normal  # validation uses normal slices only

In [None]:
# Training and validation loaders share the same settings: batches of 32, shuffled, 4 workers.
loader_kwargs = {"batch_size": 32, "shuffle": True, "num_workers": 4}
train_loader = DataLoader(train_data, **loader_kwargs)
val_loader = DataLoader(val_data, **loader_kwargs)

In [None]:
# Check that the data is min-max normalised: the first validation image should span [0, 1].
check_data = next(iter(val_loader))[0]

print("This is minimum value of Sample tensor:", check_data.min())
print("This is maximum value of Sample tensor:", check_data.max())

In [None]:
# Plot 3 examples from the training set

In [None]:
check_data = next(iter(train_loader))
fig, ax = plt.subplots(nrows=1, ncols=3)
for image_n, panel in enumerate(ax):
    # Each item is (1, H, W); drop the channel axis for imshow.
    panel.imshow(check_data[image_n].squeeze(), cmap="gray")
    panel.axis("off")
plt.show()

In [None]:
# VQ-VAE : Define network, optimizer and losses

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using {device}")

# 2D VQ-VAE for single-channel slices: two downsampling stages and a codebook of
# 16 embeddings of dimension 64 (tuple layouts follow the GenerativeModels VQVAE API).
vqvae_kwargs = dict(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_res_layers=2,
    downsample_parameters=((2, 4, 1, 1), (2, 4, 1, 1)),
    upsample_parameters=((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),
    num_channels=(256, 256),
    num_res_channels=(256, 256),
    num_embeddings=16,
    embedding_dim=64,
)
vqvae_model = VQVAE(**vqvae_kwargs)
vqvae_model.to(device)

In [None]:
# Total number of VQ-VAE parameters (trainable or not).
total_params = sum(map(torch.numel, vqvae_model.parameters()))
print(f"Number of parameters: {total_params}")

In [None]:
# VQ-VAE Model training

In [None]:
# Train the VQ-VAE on normal slices with an L1 reconstruction loss plus the
# quantisation loss returned by the model.
optimizer = torch.optim.Adam(params=vqvae_model.parameters(), lr=5e-4)  # optimiser and learning rate
l1_loss = L1Loss()
n_epochs = 100  # maximum number of training epochs
val_interval = 10  # save a checkpoint and plot reconstructions every `val_interval` epochs
epoch_losses = []  # mean training reconstruction (L1) loss per epoch
val_epoch_losses = []  # mean validation reconstruction (L1) loss per epoch

# Early-stopping parameters. Validation runs every epoch, so patience counts epochs.
patience = 10  # epochs without improvement before stopping
best_loss = float('inf')
epochs_without_improvement = 0

total_start = time.time()
for epoch in range(n_epochs):
    vqvae_model.train()
    epoch_loss = 0.0
    epoch_quantization_loss = 0.0
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), ncols=110)
    progress_bar.set_description(f"Epoch {epoch}")
    for step, batch in progress_bar:
        images = batch.to(device)
        optimizer.zero_grad(set_to_none=True)

        # The model returns the reconstruction and the quantisation loss.
        reconstruction, quantization_loss = vqvae_model(images=images)
        recons_loss = l1_loss(reconstruction.float(), images.float())
        loss = recons_loss + quantization_loss

        loss.backward()
        optimizer.step()

        epoch_loss += recons_loss.item()
        epoch_quantization_loss += quantization_loss.item()

        # Both values are running means over the batches seen so far this epoch.
        progress_bar.set_postfix(
            {"recons_loss": epoch_loss / (step + 1), "quantization_loss": epoch_quantization_loss / (step + 1)}
        )
    epoch_losses.append(epoch_loss / max(len(train_loader), 1))

    # Validation (every epoch): mean L1 reconstruction loss over the validation batches.
    vqvae_model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch_valid in val_loader:
            images_valid = batch_valid.to(device)
            reconstruction, _ = vqvae_model(images=images_valid)
            val_loss += l1_loss(reconstruction.float(), images_valid.float()).item()
    val_loss /= max(len(val_loader), 1)
    val_epoch_losses.append(val_loss)

    if (epoch + 1) % val_interval == 0:
        torch.save(vqvae_model.state_dict(), f"demo_vqvae_training_epoch_{epoch}.pth")

        # Original vs reconstruction for the first 3 images of the last validation batch.
        with torch.no_grad():
            sample_images = images_valid[:3]
            reconstructions = vqvae_model(images=sample_images)[0]

        fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(12, 8))
        for i in range(3):
            axes[0, i].imshow(sample_images[i].squeeze().cpu().numpy(), cmap='gray')
            axes[0, i].set_title(f"Original {i+1}")
            axes[0, i].axis('off')

            axes[1, i].imshow(reconstructions[i].squeeze().cpu().numpy(), cmap='gray')
            axes[1, i].set_title(f"Reconstruction {i+1}")
            axes[1, i].axis('off')
        plt.show()

    # Early stopping on the validation loss. It is checked every epoch so that
    # `patience` is measured in epochs rather than in validation intervals.
    if val_loss < best_loss:
        best_loss = val_loss
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1

    if epochs_without_improvement >= patience:
        print("Early stopping!")
        break

total_time = time.time() - total_start
print(f"train completed, total time: {total_time}.")

In [None]:
# Plotting Training and Validation Losses

In [None]:
epochs = range(1, len(epoch_losses) + 1)

# Per-epoch training vs validation loss on one set of axes.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(epochs, epoch_losses, 'b', label='Training loss')
ax.plot(epochs, val_epoch_losses, 'r', label='Validation loss')
ax.set_title('Training and Validation Loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
plt.show()

In [None]:
# Load pre-trained VQ-VAE model after training

In [None]:
# Load the saved state dictionary of the pre-trained VQ-VAE model.
# map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
VQVAE_CHECKPOINT = "vqvae_training5_epoch_99.pth"  # replace with the actual path
saved_state_dict = torch.load(VQVAE_CHECKPOINT, map_location=device)

# Load the state dictionary into the VQ-VAE model
vqvae_model.load_state_dict(saved_state_dict)

In [None]:
# Calculate Average SSIM of Normal Images Only Validation Set

In [None]:
# Average SSIM between each validation image and its VQ-VAE reconstruction.
# `ssim` is skimage's structural_similarity, imported in the imports cell.
val_loader_ssim = DataLoader(val_data, batch_size=32, shuffle=False, num_workers=4)
total_ssim = 0.0
total_images = 0
vqvae_model.eval()

with torch.no_grad():  # No need to track gradients
    for data in val_loader_ssim:
        images = data.to(device)
        reconstructions = vqvae_model(images)[0]
        # Keep the batch axis and drop only the single channel: (B, 1, H, W) -> (B, H, W).
        # A bare squeeze() with batch_size=1 yields a single (H, W) array, and zipping
        # over it would compute SSIM row by row instead of once per image.
        original_images = images[:, 0].cpu().numpy()
        reconstructed_images = reconstructions[:, 0].float().cpu().numpy()
        for original, reconstructed in zip(original_images, reconstructed_images):
            # Inputs are min-max normalised to [0, 1], hence data_range=1.
            total_ssim += ssim(original, reconstructed, data_range=1)
            total_images += 1

average_ssim = total_ssim / total_images
print(f'Average SSIM: {average_ssim:.2f}')

In [None]:
# Calculate MAE(L1 loss)

In [None]:
# Mean absolute error (L1) of the VQ-VAE reconstructions, averaged over validation batches.
l1_loss = L1Loss()
vqvae_model.eval()
batch_maes = []

with torch.no_grad():
    for batch_valid in val_loader:
        images_valid = batch_valid.to(device)
        reconstruction, _ = vqvae_model(images=images_valid)
        batch_maes.append(l1_loss(reconstruction.float(), images_valid.float()).item())

val_loss = sum(batch_maes) / len(batch_maes)
print(val_loss)

In [None]:
# Projection of 2D latent representation to 1D sequence

In [None]:
# Get the spatial size of the VQ-VAE latent grid from one training batch.
with torch.no_grad():  # shape probe only; no autograd graph is needed
    test_data = next(iter(train_loader)).to(device)
    spatial_shape = vqvae_model.encode_stage_2_inputs(test_data).shape[2:]  # spatial dims of the quantised latents

# Initialize an Ordering class that projects the 2D latent grid into a 1D (raster-scan) sequence.
ordering = Ordering(ordering_type=OrderingType.RASTER_SCAN.value, spatial_dims=2, dimensions=(1,) + spatial_shape)

In [None]:
# Define auto-regressive transformer network, VQ-VAE/Transformer inferer, optimizer and loss function

In [None]:
# Decoder-only transformer over the flattened latent sequence. Its vocabulary is
# the VQ-VAE codebook plus one extra token (index num_embeddings), which the
# sampling and healing code uses as the beginning-of-sequence token.
# `device` is defined in the VQ-VAE definition cell.
transformer_model = DecoderOnlyTransformer(
    num_tokens=vqvae_model.num_embeddings + 1,
    max_seq_len=spatial_shape[0] * spatial_shape[1],
    attn_layers_dim=128,
    attn_layers_depth=16,
    attn_layers_heads=16,
)
transformer_model.to(device)

inferer = VQVAETransformerInferer()

In [None]:
# Next-token cross-entropy over codebook indices, optimised with Adam.
ce_loss = CrossEntropyLoss()
optimizer = torch.optim.Adam(transformer_model.parameters(), lr=5e-3)

In [None]:
# Transformer Training

In [None]:
n_epochs = 150  # maximum number of training epochs
val_interval = 10  # save a checkpoint and show a generated sample every `val_interval` epochs
epoch_losses = []  # mean training cross-entropy per epoch
val_epoch_losses = []  # mean validation cross-entropy per epoch
vqvae_model.eval()  # the VQ-VAE stays frozen; only the transformer is trained

# Early-stopping parameters. Validation runs every epoch, so patience counts epochs.
patience = 10  # epochs without improvement before stopping
best_loss = float('inf')
epochs_without_improvement = 0

total_start = time.time()
for epoch in range(n_epochs):
    transformer_model.train()
    epoch_loss = 0.0
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), ncols=110)
    progress_bar.set_description(f"Epoch {epoch}")
    for step, batch in progress_bar:
        images = batch.to(device)

        optimizer.zero_grad(set_to_none=True)

        # The inferer returns logits already aligned with `target` (it handles the
        # beginning-of-sequence shift internally), so no extra shift is applied here.
        logits, target, _ = inferer(images, vqvae_model, transformer_model, ordering, return_latent=True)
        logits = logits.transpose(1, 2)  # (B, seq, tokens) -> (B, tokens, seq) for CrossEntropyLoss

        loss = ce_loss(logits, target)

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

        progress_bar.set_postfix({"ce_loss": epoch_loss / (step + 1)})
    epoch_losses.append(epoch_loss / max(len(train_loader), 1))

    # Validation (every epoch), using the same logits/target alignment as training
    # so the two losses are directly comparable.
    transformer_model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            images = batch.to(device)

            logits, quantizations_target, _ = inferer(
                images, vqvae_model, transformer_model, ordering, return_latent=True
            )
            logits = logits.transpose(1, 2)

            val_loss += ce_loss(logits, quantizations_target).item()
    val_loss /= max(len(val_loader), 1)
    val_epoch_losses.append(val_loss)

    if (epoch + 1) % val_interval == 0:
        torch.save(transformer_model.state_dict(), f"demo_transformer_model_training_epoch_{epoch}.pth")

        # Generate and show a sample, starting from the beginning-of-sequence token.
        sample = inferer.sample(
            vqvae_model=vqvae_model,
            transformer_model=transformer_model,
            ordering=ordering,
            latent_spatial_dim=(spatial_shape[0], spatial_shape[1]),
            starting_tokens=vqvae_model.num_embeddings * torch.ones((1, 1), device=device),
        )
        plt.imshow(sample[0, 0, ...].cpu().detach())
        plt.title(f"Sample epoch {epoch}")
        plt.show()

    # Early stopping on the validation loss, checked every epoch.
    if val_loss < best_loss:
        best_loss = val_loss
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1

    if epochs_without_improvement >= patience:
        print("Early stopping!")
        break

total_time = time.time() - total_start
print(f"train completed, total time: {total_time}.")

In [None]:
# Load pre-trained transformer model

In [None]:
# Load the saved state dictionary of the pre-trained transformer.
# map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
TRANSFORMER_CHECKPOINT = "transformer_model_training_25092023_epoch_99.pth"  # replace with the actual path
saved_state_dict = torch.load(TRANSFORMER_CHECKPOINT, map_location=device)

# Load the state dictionary into the transformer model
transformer_model.load_state_dict(saved_state_dict)

In [None]:
# Put both models in inference mode. train(False) is what eval() calls and it also
# returns the module, so the cell still displays the transformer summary.
vqvae_model.eval()
transformer_model.train(False)

In [None]:
# Draw one sample from the pre-trained transformer, starting from the
# beginning-of-sequence token (index num_embeddings), and show the decoded image.
bos_tokens = torch.full((1, 1), float(vqvae_model.num_embeddings), device=device)
sample = inferer.sample(
    vqvae_model=vqvae_model,
    transformer_model=transformer_model,
    ordering=ordering,
    latent_spatial_dim=tuple(spatial_shape),
    starting_tokens=bos_tokens,
)
plt.imshow(sample[0, 0].cpu().detach())
plt.show()

In [None]:
# Image-wise anomaly detection

In [None]:
# load normal and abnormal test data

In [None]:
vqvae_model.eval()
transformer_model.eval()

# Held-out normal slices [14000, 16000) (disjoint from the train/val index ranges)
# and the first 1600 abnormal slices.
test_data_normal = NPZDataset(root_dir="normal_slices_training", indices=[*range(14000, 16000)])
test_data_abnormal = NPZDataset(root_dir="abnormal_slices", indices=[*range(1600)])

test_data_normal_loader = DataLoader(test_data_normal, batch_size=32, shuffle=True, num_workers=4)
test_data_abnormal_loader = DataLoader(test_data_abnormal, batch_size=32, shuffle=True, num_workers=4)

In [None]:
# get normal distribution log likelihood

In [None]:
# Image-level log-likelihood of each normal test slice: the per-token
# log-likelihood map from the inferer summed over both latent axes.
per_batch_scores = []
progress_bar = tqdm(test_data_normal_loader, total=len(test_data_normal_loader), ncols=110)
progress_bar.set_description("Normal-distribution data")
for batch in progress_bar:
    token_log_likelihood = inferer.get_likelihood(
        inputs=batch.to(device),
        vqvae_model=vqvae_model,
        transformer_model=transformer_model,
        ordering=ordering,
    )
    per_batch_scores.append(token_log_likelihood.sum(dim=(1, 2)).cpu().numpy())

normal_likelihoods = np.concatenate(per_batch_scores)

In [None]:
# get abnormal distribution log likelihood

In [None]:
# Image-level log-likelihood of each abnormal test slice: the per-token
# log-likelihood map from the inferer summed over both latent axes.
per_batch_scores = []
progress_bar = tqdm(test_data_abnormal_loader, total=len(test_data_abnormal_loader), ncols=110)
progress_bar.set_description("Abnormal-distribution data")
for batch in progress_bar:
    token_log_likelihood = inferer.get_likelihood(
        inputs=batch.to(device),
        vqvae_model=vqvae_model,
        transformer_model=transformer_model,
        ordering=ordering,
    )
    per_batch_scores.append(token_log_likelihood.sum(dim=(1, 2)).cpu().numpy())

abnormal_likelihoods = np.concatenate(per_batch_scores)

In [None]:
# Normal and Abnormal Log-likelihood plot

In [None]:
sns.set_style("whitegrid", {"axes.grid": False})
# Overlay the kernel density estimates of both image-level log-likelihood sets.
for scores, legend_name in (
    (normal_likelihoods, "Normal-distribution"),
    (abnormal_likelihoods, "Abnormal-distribution"),
):
    sns.kdeplot(scores, bw_adjust=1, label=legend_name, fill=True, cut=True)
plt.legend(loc="upper right")
plt.xlabel("Log-likelihood")

In [None]:
# Calculate AUC-ROC score, True Positive Rate, False Positive Rate

In [None]:
from sklearn.metrics import roc_curve, auc

# Labels: abnormal = 0, normal = 1 (normal is the positive class). The score is
# the image-level log-likelihood, which is expected to be higher for normal images.
log_likelihood = np.concatenate((abnormal_likelihoods, normal_likelihoods))
likelihood_labels = np.concatenate((np.zeros_like(abnormal_likelihoods), np.ones_like(normal_likelihoods)))

# Threshold-free ROC from the continuous scores. Passing thresholded 0/1
# predictions instead yields a single operating point, and the "AUC" collapses
# to balanced accuracy at that threshold.
fpr, tpr, roc_thresholds = roc_curve(likelihood_labels, log_likelihood)
roc_auc = auc(fpr, tpr)

# Operating point at a fixed threshold: the intersection of the two KDE curves,
# read off the log-likelihood plot.
threshold = -1515.0
predicted = (log_likelihood > threshold).astype(int)
is_normal = likelihood_labels == 1
tpr_at_threshold = predicted[is_normal].mean()
fpr_at_threshold = predicted[~is_normal].mean()

print(f"True Positive Rate (threshold {threshold}): {tpr_at_threshold}")
print(f"False Positive Rate (threshold {threshold}): {fpr_at_threshold}")
print(f"AUC-ROC: {roc_auc}")


In [None]:
# Visualizing the ROC Curve

In [None]:
# ROC curve with the chance-level diagonal for reference.
fig, ax = plt.subplots()
ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('Receiver Operating Characteristic')
ax.legend(loc="lower right")
plt.show()


In [None]:
# Localised anomaly detection, segmentation and healing on a synthetic abnormal image

In [None]:
# Use the first image of a (shuffled) normal test batch as the clean reference.
input_image = first(test_data_normal_loader)
image_clean = input_image[0]  # (1, H, W)
plt.subplot(1, 2, 1)
plt.imshow(image_clean[0], cmap="gray")
plt.axis("off")
plt.title("Clean image")

In [None]:
# Synthetic anomaly: set a 30x20 rectangle (rows 50-79, cols 40-59) to the maximum intensity.
image_corrupted = image_clean.clone()
image_corrupted[0, 50:80, 40:60] = 1

plt.subplot(1, 2, 2)
plt.imshow(image_corrupted[0], cmap="gray")
plt.axis("off")
plt.title("Corrupted image")
plt.show()

In [None]:
# Get the log-likelihood and convert into a mask

In [None]:
MASK_QUANTILE = 0.04  # flag the least likely 4% of latent tokens as anomalous

# Per-token log-likelihood of the corrupted image under the transformer prior.
log_likelihood = inferer.get_likelihood(
    inputs=image_corrupted[None, ...].to(device),
    vqvae_model=vqvae_model,
    transformer_model=transformer_model,
    ordering=ordering,
)
likelihood = torch.exp(log_likelihood)
plt.subplot(1, 2, 1)
plt.imshow(likelihood.cpu()[0, ...])
plt.axis("off")
plt.title("Likelihood")  # exp(log-likelihood) is what is displayed
plt.subplot(1, 2, 2)
mask = log_likelihood.cpu()[0, ...] < torch.quantile(log_likelihood, MASK_QUANTILE).item()
# Upsample the latent-resolution healing mask to the image resolution for display.
resizer = torch.nn.Upsample(size=tuple(image_corrupted.shape[-2:]), mode="nearest")
mask_upsampled = resizer(mask[None, None, ...].float()).int().squeeze()
plt.imshow(mask_upsampled)
plt.axis("off")
plt.title("Healing mask")
plt.show()

In [None]:
# Use this mask and the trained transformer to 'heal' the sequence

In [None]:
# Heal the corrupted image: walk the latent token sequence in transformer order and
# replace every token flagged by `mask` with the transformer's most likely token,
# conditioning each choice on the already-healed prefix.
sequence_ordering = ordering.get_sequence_ordering()

# Flatten the latent-resolution mask into the same 1D order as the token sequence.
mask_flattened = mask.reshape(-1)[sequence_ordering]

with torch.no_grad():  # pure inference: avoid building a graph for every transformer call
    latent = vqvae_model.index_quantize(image_corrupted[None, ...].to(device))
    latent = latent.reshape(latent.shape[0], -1)[:, sequence_ordering]
    # Prepend the beginning-of-sequence token (index num_embeddings).
    latent = F.pad(latent, (1, 0), "constant", vqvae_model.num_embeddings).long()
    latent_healed = latent.clone()

    # Position i of the padded sequence holds image token i - 1; position 0 is the BOS token.
    for i in range(1, latent.shape[1]):
        if mask_flattened[i - 1]:
            # Low-likelihood token: replace it with the transformer's most likely token.
            logits = transformer_model(latent_healed[:, :i])
            probs = F.softmax(logits, dim=-1)
            probs[:, :, vqvae_model.num_embeddings] = 0  # never predict the BOS token
            latent_healed[:, i] = torch.argmax(probs[0, -1, :])

    # Drop BOS, undo the ordering and restore the latent grid shape before decoding.
    latent_healed = latent_healed[:, 1:][:, ordering.get_revert_sequence_ordering()]
    latent_healed = latent_healed.reshape(tuple(spatial_shape))
    image_healed = vqvae_model.decode_samples(latent_healed[None, ...]).cpu().detach()

plt.imshow(image_healed[0, 0, ...], cmap="gray")
plt.axis("off")
plt.title("Healed image")
plt.show()

In [None]:
# Create anomaly maps

In [None]:
# Naive anomaly map: absolute difference between the healed and corrupted images.
difference_map = torch.abs(image_healed[0, 0, ...] - image_corrupted[0, ...])

# Restrict the difference map to the healed (low-likelihood) region by
# upsampling the latent-resolution mask to the image resolution.
resizer = torch.nn.Upsample(size=tuple(difference_map.shape), mode="nearest")
mask_upsampled = resizer(mask[None, None, ...].float()).int()

panels = [
    (image_clean[0, ...], "Clean image"),
    (image_corrupted[0, ...], "Corrupted image"),
    (image_corrupted[0, ...] - image_clean[0, ...], "Ground-Truth anomaly mask"),
    (mask_upsampled[0, 0, ...] * difference_map, "Predicted anomaly mask"),
    (image_healed[0, 0, ...], "Healed image"),
]

# A single figure with one panel per entry. plt.show() is called once, at the end,
# so that all five panels render together.
fig, axes = plt.subplots(1, len(panels), figsize=(14, 8))
for ax, (panel_image, title) in zip(axes, panels):
    ax.imshow(panel_image, cmap="gray")
    ax.axis("off")
    ax.set_title(title)
plt.show()

In [None]:
# Dice score calculation function
def dice_score(predicted, target, epsilon=1e-7):
    """Dice coefficient ``2|P∩T| / (|P| + |T|)`` between two (soft or binary) masks.

    Both inputs are flattened and converted to float tensors, so torch tensors and
    numpy arrays can be mixed, and masks that differ only by singleton axes
    (e.g. (H, W) vs (1, H, W)) compare element by element rather than broadcasting.

    Args:
        predicted: predicted mask (torch tensor or array-like).
        target: ground-truth mask with the same number of elements as ``predicted``;
            it is moved to ``predicted``'s device.
        epsilon: smoothing term; two empty masks score 1.

    Returns:
        A 0-d float tensor with the Dice score.
    """
    predicted = torch.as_tensor(predicted).reshape(-1).float()
    target = torch.as_tensor(target, device=predicted.device).reshape(-1).float()
    intersection = (predicted * target).sum()
    return (2. * intersection + epsilon) / (predicted.sum() + target.sum() + epsilon)

In [None]:
# Synthetic anomalies dice score

In [None]:
# Dice between the predicted anomaly mask and a known synthetic anomaly, over the
# held-out normal test images (test_data_normal_loader yields (B, 1, H, W) batches).
ANOMALY_ROWS = slice(50, 80)  # same rectangle as the single-image demo above
ANOMALY_COLS = slice(40, 60)
MASK_QUANTILE = 0.04  # flag the least likely 4% of latent tokens
IMAGE_SIZE = (128, 128)  # NPZDataset output size

dice_scores = []  # Dice score per test image

# Ground-truth mask of the synthetic anomaly
gt_mask = torch.zeros(IMAGE_SIZE)
gt_mask[ANOMALY_ROWS, ANOMALY_COLS] = 1.0

# Upsamples latent-resolution masks to the image resolution (built once, reused).
resizer = torch.nn.Upsample(size=IMAGE_SIZE, mode="nearest")

with torch.no_grad():
    for batch in test_data_normal_loader:
        # Corrupt the whole batch at once.
        images_corrupted = batch.to(device).clone()
        images_corrupted[:, 0, ANOMALY_ROWS, ANOMALY_COLS] = 1

        # Per-token log-likelihood map for every image in the batch: (B, H', W').
        log_likelihood = inferer.get_likelihood(
            inputs=images_corrupted,
            vqvae_model=vqvae_model,
            transformer_model=transformer_model,
            ordering=ordering,
        )
        # Per-image threshold: the MASK_QUANTILE quantile of that image's own map.
        thresholds = torch.quantile(log_likelihood.flatten(start_dim=1), MASK_QUANTILE, dim=1)
        masks = log_likelihood < thresholds[:, None, None]
        predicted_masks = resizer(masks[:, None].float())[:, 0].cpu()

        dice_scores.extend(dice_score(predicted_mask, gt_mask).item() for predicted_mask in predicted_masks)

if not dice_scores:
    raise ValueError("test_data_normal_loader yielded no images; cannot compute a Dice score")

# Average Dice score over the entire test dataset
average_dice_score = sum(dice_scores) / len(dice_scores)

print(f'Average Dice Score on Test Dataset: {average_dice_score:.4f}')

In [None]:
class NPZDataset2(Dataset):
    """Dataset of ``.npz`` slices returning ``(image, label_mask)`` pairs.

    The image is resized with bicubic interpolation and min-max normalised to
    [0, 1]. The ``"label"`` array is resized with nearest-neighbour interpolation
    and binarised, so any non-zero label value counts as anomalous.

    Args:
        root_dir: directory containing the ``.npz`` files.
        indices: optional positions into the directory listing selecting a subset.
        transform: optional callable applied separately to the image and the label.
        image_size: output spatial size (H, W) for both tensors.
    """

    def __init__(self, root_dir, indices=None, transform=None, image_size=(128, 128)):
        self.root_dir = root_dir
        # NOTE(review): os.listdir order is filesystem-dependent; indices are only
        # reproducible against the same directory.
        self.files = [f for f in os.listdir(root_dir) if f.endswith('.npz')]
        if indices is not None:
            self.files = [self.files[i] for i in indices]
        self.transform = transform
        self.image_size = tuple(image_size)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        file_path = os.path.join(self.root_dir, self.files[idx])
        # The context manager closes the file handle held by the NpzFile.
        with np.load(file_path) as data:
            sample = data["image"]
            sample_label = data["label"]

        # Image: (H, W) -> (1, 1, H, W), bicubic resize, then min-max normalise.
        sample = torch.tensor(sample, dtype=torch.float32)[None, None]
        sample = F.interpolate(sample, size=self.image_size, mode='bicubic', align_corners=False)
        sample = min_max_normalize(sample).squeeze(0)  # -> (1, H, W)

        if self.transform:
            sample = self.transform(sample)

        # Label: nearest-neighbour resize keeps the mask hard-edged, and binarising
        # replaces min-max normalisation, which would divide by zero (NaNs) for a
        # slice whose mask has no positive pixels.
        sample_label = torch.tensor(sample_label, dtype=torch.float32)[None, None]
        sample_label = F.interpolate(sample_label, size=self.image_size, mode='nearest')
        sample_label = (sample_label > 0).float().squeeze(0)  # -> (1, H, W)

        if self.transform:
            # NOTE(review): a random transform is sampled independently for image and label.
            sample_label = self.transform(sample_label)

        return sample.clone().detach().to(dtype=torch.float32), sample_label.clone().detach().to(dtype=torch.float32)

In [None]:
# Abnormal slices again, now paired with their ground-truth masks.
# Note: this rebinds test_data_abnormal_loader to the labelled version.
test_data_abnormal_labels = NPZDataset2(root_dir="abnormal_slices", indices=[*range(1600)])
test_data_abnormal_loader = DataLoader(
    test_data_abnormal_labels,
    batch_size=32,
    shuffle=True,
    num_workers=4,
)

In [None]:
# Real anomalies dice score

In [None]:
# Dice between predicted anomaly masks and the ground-truth lesion masks of the
# abnormal test slices; test_data_abnormal_loader yields (images, masks) batches.
MASK_QUANTILE = 0.04  # flag the least likely 4% of latent tokens

dice_scores = []  # Dice score per test image

with torch.no_grad():
    for images, ground_truth_masks in test_data_abnormal_loader:
        images = images.to(device)

        # Per-token log-likelihood map for every image in the batch: (B, H', W').
        log_likelihood = inferer.get_likelihood(
            inputs=images,
            vqvae_model=vqvae_model,
            transformer_model=transformer_model,
            ordering=ordering,
        )
        # Per-image threshold: the MASK_QUANTILE quantile of that image's own map.
        thresholds = torch.quantile(log_likelihood.flatten(start_dim=1), MASK_QUANTILE, dim=1)
        masks = log_likelihood < thresholds[:, None, None]
        # Upsample the latent-resolution masks to the ground-truth resolution.
        predicted_masks = F.interpolate(
            masks[:, None].float(), size=tuple(ground_truth_masks.shape[-2:]), mode="nearest"
        )

        for predicted_mask, ground_truth_mask in zip(predicted_masks[:, 0].cpu(), ground_truth_masks[:, 0]):
            dice_scores.append(dice_score(predicted_mask, ground_truth_mask).item())

if not dice_scores:
    raise ValueError("test_data_abnormal_loader yielded no images; cannot compute a Dice score")

# Average Dice score over the entire test dataset
average_dice_score = sum(dice_scores) / len(dice_scores)

print(f'Average Dice Score on Test Dataset: {average_dice_score:.4f}')
