In [None]:
import numpy as np
import os
import h5py

# Directory containing the per-slice .h5 files
directory = '/media/abian/Extreme SSD/WorkSpace/Dataset/BRATS/archive/BraTS2020_training_data/content/data/'

# List all .h5 files. os.listdir returns entries in arbitrary order, so sort them:
# every positional index used in this notebook then selects the same file on every run.
h5_files = sorted(f for f in os.listdir(directory) if f.endswith('.h5'))
print(f"Found {len(h5_files)} .h5 files:\nExample file names:{h5_files[:3]}")

# Inspect one sample file (index 25070, clamped to the last file if fewer files exist)
if h5_files:
    file_path = os.path.join(directory, h5_files[min(25070, len(h5_files) - 1)])
    with h5py.File(file_path, 'r') as file:
        print("\nKeys for each file:", list(file.keys()))
        for key in file.keys():
            array = file[key][()]  # read each dataset once, then compute stats in memory
            print(f"\nData type of {key}:", type(array))
            print(f"Shape of {key}:", array.shape)
            print(f"Array dtype: {array.dtype}")
            print(f"Array max val: {np.max(array)}")
            print(f"Array min val: {np.min(array)}")
else:
    print("No .h5 files found in the directory.")

In [None]:
import matplotlib.pyplot as plt
# ggplot style with a dark figure background and light text
plt.style.use('ggplot')
plt.rcParams.update({
    'figure.facecolor': '#171717',
    'text.color': '#DDDDDD',
})

def display_image_channels(image, title='Image Channels'):
    """Plot the first four channels of one slice in a 2x2 grid.

    Args:
        image: array indexed as image[channel, row, col]. Panels are titled T1, T1c, T2
            and FLAIR for channels 0-3; NOTE(review): this assumes the files store the
            modalities in that order — confirm against the dataset.
        title: figure-level title.
    """
    modality_titles = (
        'T1-weighted (T1)',
        'T1-weighted post contrast (T1c)',
        'T2-weighted (T2)',
        'Fluid Attenuated Inversion Recovery (FLAIR)',
    )
    fig, axes = plt.subplots(2, 2, figsize=(10, 10))
    for channel, (ax, modality) in enumerate(zip(axes.flatten(), modality_titles)):
        # Plain channel selection (no transpose).
        ax.imshow(image[channel, :, :], cmap='magma')
        ax.axis('off')
        ax.set_title(modality)
    plt.tight_layout()
    plt.suptitle(title, fontsize=20, y=1.03)
    plt.show()

def display_mask_channels_as_rgb(mask, title='Mask Channels as RGB'):
    """Show each of the first three mask channels in its own panel, tinted R, G or B.

    Channel k is scaled by 255 and written into colour component k (0=red, 1=green,
    2=blue) of an otherwise black uint8 RGB image.

    Args:
        mask: array indexed as mask[channel, row, col] with at least three channels.
        title: figure-level title.
    """
    region_titles = ['Necrotic (NEC)', 'Edema (ED)', 'Tumour (ET)']
    height, width = mask.shape[1], mask.shape[2]
    fig, axes = plt.subplots(1, 3, figsize=(9.75, 5))
    for channel in range(len(axes)):
        ax = axes[channel]
        tinted = np.zeros((height, width, 3), dtype=np.uint8)
        tinted[..., channel] = mask[channel, :, :] * 255  # plain channel selection, no transpose
        ax.imshow(tinted)
        ax.axis('off')
        ax.set_title(region_titles[channel])
    plt.suptitle(title, fontsize=20, y=0.93)
    plt.tight_layout()
    plt.show()

def overlay_masks_on_image(image, mask, title='Brain MRI with Tumour Masks Overlay'):
    """Overlay three mask channels, as RGB colour components, on a grayscale channel 0.

    Channel 0 of `image` is min-max scaled to [0, 1] and replicated to RGB. Wherever mask
    channel k is non-zero, colour component k is replaced by the mask value.

    Args:
        image: array indexed as image[channel, row, col].
        mask: array indexed as mask[channel, row, col] with at least three channels.
        title: plot title.
    """
    t1_image = image[0, :, :]  # first channel of the image
    value_range = t1_image.max() - t1_image.min()
    if value_range > 0:
        t1_image_normalized = (t1_image - t1_image.min()) / value_range
    else:
        # Constant slice (e.g. empty background): avoid 0/0 -> NaN and draw it black.
        t1_image_normalized = np.zeros_like(t1_image, dtype=float)

    rgb_image = np.stack([t1_image_normalized] * 3, axis=-1)
    color_mask = np.stack([mask[0, :, :], mask[1, :, :], mask[2, :, :]], axis=-1)
    rgb_image = np.where(color_mask, color_mask, rgb_image)

    plt.figure(figsize=(8, 8))
    plt.imshow(rgb_image)
    plt.title(title, fontsize=18, y=1.02)
    plt.axis('off')
    plt.show()
    
    
# Load every dataset of one sample slice for visualisation
sample_file_path = os.path.join(directory, h5_files[25070])
with h5py.File(sample_file_path, 'r') as file:
    data = {key: file[key][()] for key in file.keys()}

# Move the last axis (channels) to the front: transpose(2, 0, 1)
image, mask = (data[name].transpose(2, 0, 1) for name in ('image', 'mask'))

# Plot the channels, the mask channels and the overlay
display_image_channels(image)
display_mask_channels_as_rgb(mask)
overlay_masks_on_image(image, mask)

In [None]:
np.shape(mask)  # channels-first mask shape

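> "Since tumors mostly occur in the middle of the brain, we exclude the lowest 80 slices and the uppermost 26 slices." — Wolleb et al., *Diffusion Models for Medical Anomaly Detection* (MICCAI 2022)

Each BraTS volume has 155 axial slices (indices 0–154), so the intended range is slices 80–128 inclusive (49 slices per patient).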

In [None]:
# Extract (patient_id, slice_idx) string pairs from the file names.
# filename: volume_{id}_slice_{slice}.h5
import re

FILENAME_RE = re.compile(r'volume_(\d+)_slice_(\d+)')

# Slice range from Wolleb et al.: drop the lowest 80 and the uppermost 26 slices.
# NOTE(review): assumes 155 axial slices per volume (standard BraTS) — confirm against the data.
N_SLICES = 155
FIRST_KEPT_SLICE = 80                  # slices 0..79 are excluded
LAST_KEPT_SLICE = N_SLICES - 26 - 1    # = 128 (inclusive); slices 129..154 are excluded

matches = [FILENAME_RE.search(f) for f in h5_files]
patient_slices = [m.groups() for m in matches if m is not None]
n_unmatched = len(matches) - len(patient_slices)
if n_unmatched:
    print(f"Skipped {n_unmatched} file(s) not matching 'volume_<id>_slice_<n>'")

filtered_patient_slices = [
    p for p in patient_slices if FIRST_KEPT_SLICE <= int(p[1]) <= LAST_KEPT_SLICE
]
len(filtered_patient_slices)

In [None]:
# Rebuild the .h5 file name of every kept slice
filenames = [
    f'volume_{patient_id}_slice_{slice_idx}.h5'
    for patient_id, slice_idx in filtered_patient_slices
]
filenames[:5]

In [None]:
from tqdm import tqdm

# Slice label: 1 if the segmentation mask has any non-zero voxel (tumour present), else 0.
labels = np.zeros(len(filenames), dtype=int)

for i, name in enumerate(tqdm(filenames)):
    with h5py.File(os.path.join(directory, name), 'r') as file:
        # Only the mask is needed, so skip reading the image dataset.
        mask_array = file['mask'][()]
    # np.any is a linear scan; the previous len(np.unique(...)) > 1 sorted the whole mask.
    labels[i] = int(np.any(mask_array))

    

In [None]:
import pandas as pd

# One row per kept slice: file name and 0/1 tumour label (local copy in the working directory)
df = pd.DataFrame(dict(Filename=filenames, Label=labels.astype(int)))
df.to_csv('tumour_labels.csv', index=False)

In [None]:
dataset_path = '/media/abian/Extreme SSD/WorkSpace/Dataset/BRATS/archive/'
# Use the file name that is read back later (pd.read_csv(..., 'tumor_labels.csv') and the
# AnoamlyBrainTumor dataset); writing 'tumour_labels.csv' here left that reader without input.
df.to_csv(os.path.join(dataset_path, 'tumor_labels.csv'), index=False)

In [None]:
# Display the directory that holds the per-slice .h5 files
directory

In [None]:
def _read_all_datasets(path):
    """Return {key: in-memory array} for every key at the root of the HDF5 file at `path`."""
    with h5py.File(path, 'r') as h5:
        return {name: h5[name][()] for name in h5.keys()}


# Two example slices, picked by position in `filenames`
idx_0 = 1209
print(filenames[idx_0])
sample_file_path = os.path.join(directory, filenames[idx_0])
data_0 = _read_all_datasets(sample_file_path)

idx_1 = 2512
print(filenames[idx_1])
sample_file_path = os.path.join(directory, filenames[idx_1])
data_1 = _read_all_datasets(sample_file_path)

In [None]:
# Channel 0 of each example slice, transposed for display
examples = [(data_0, 'Patient 0'), (data_1, 'Patient 1')]
for panel, (sample, label) in enumerate(examples, start=1):
    plt.subplot(1, 2, panel)
    plt.imshow(sample['image'][:, :, 0].T, cmap='gray')
    plt.title(label)
    plt.axis('off')

plt.show()

In [None]:
np.shape(data_0['image'])  # raw (unchanged) layout of the stored image

# Images

In [None]:
import pandas as pd

dataset_path = '/media/abian/Extreme SSD/WorkSpace/Dataset/BRATS/archive/'

# Slice-level labels (columns 'Filename', 'Label')
df = pd.read_csv(os.path.join(dataset_path, 'tumor_labels.csv'))
filenames = df['Filename'].values
no_tumor_filenames = df.loc[df['Label'] == 0, 'Filename'].values

# Load one tumour-free slice for inspection
sample_file_path = os.path.join(directory, no_tumor_filenames[10])
with h5py.File(sample_file_path, 'r') as file:
    data = {key: file[key][()] for key in file.keys()}

In [None]:
np.bincount(df['Label'].to_numpy())  # slices per label: [label 0, label 1]

In [None]:
# Rows (filename and label) of the slices labelled 0 (no tumour)
df.query('Label == 0')

In [None]:
from matplotlib import pyplot as plt

# The four channels of the loaded tumour-free slice, each transposed for display
for channel in range(4):
    plt.subplot(2, 2, channel + 1)
    plt.imshow(data['image'][:, :, channel].T, cmap='gray')
    plt.title(f'Channel {channel}')
    plt.axis('off')

plt.show()


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

class AnoamlyBrainTumor(Dataset):
    '''
    BraTS 2020 slice dataset adapted for anomaly detection.

    Items are listed by a label CSV (columns 'Filename', 'Label'). That CSV is expected to
    contain only mid-brain slices: since tumors mostly occur in the middle of the brain, the
    lowest 80 and the uppermost 26 slices of each volume are excluded.

    Each item is `(image, label)`: the 'image' dataset of one .h5 slice file cast to float32
    (then passed through `transform` if given) and the integer label from the CSV.

    Args:
        dataset_path: archive root; slice files are read from
            `<dataset_path>/BraTS2020_training_data/content/data/`.
        transform: optional callable applied to the float32 image array.
        labels_file: name of the label CSV inside `dataset_path`.
    '''
    def __init__(self, dataset_path: str, transform=None, labels_file: str = 'tumor_labels.csv') -> None:
        super().__init__()

        self.dataset_path = os.path.join(dataset_path, 'BraTS2020_training_data/content/data/')
        self.transform = transform

        self.df = pd.read_csv(os.path.join(dataset_path, labels_file))
        self.filenames = self.df['Filename'].values
        self.labels = self.df['Label'].values

    def __len__(self) -> int:
        return len(self.filenames)

    def __getitem__(self, idx: int) -> tuple:
        sample_file_path = os.path.join(self.dataset_path, self.filenames[idx])
        # The `with` block closes the file on exit; no explicit close() is needed.
        with h5py.File(sample_file_path, 'r') as file:
            data = file['image'][()].astype(np.float32)

        if self.transform:
            data = self.transform(data)

        return data, self.labels[idx]

In [None]:
def normalize_brats_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Min-max scale every channel of a channel-first tensor to [0, 1] independently.

    For each channel c: (x - min_c) / (max_c - min_c). A constant channel (max == min,
    e.g. an all-zero modality) maps to all zeros instead of NaN from 0/0.

    Args:
        tensor: floating tensor whose first dimension indexes channels, e.g. (C, H, W).

    Returns:
        Tensor of the same shape with every channel scaled to [0, 1].
    """
    n_channels = tensor.shape[0]
    # (C, 1, 1, ...) so per-channel statistics broadcast over the remaining dimensions.
    stat_shape = (n_channels,) + (1,) * (tensor.dim() - 1)
    # reshape (not view) also accepts non-contiguous input.
    flat = tensor.reshape(n_channels, -1)
    channel_min = flat.min(dim=1).values.reshape(stat_shape)
    channel_range = flat.max(dim=1).values.reshape(stat_shape) - channel_min
    shifted = tensor - channel_min
    safe_range = torch.where(channel_range > 0, channel_range, torch.ones_like(channel_range))
    return shifted / safe_range

In [None]:
from torchvision import transforms

# ToTensor: numpy (H, W, C) -> tensor (C, H, W); then per-channel min-max to [0, 1];
# Normalize(0.5, 0.5) finally maps [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(normalize_brats_tensor),  # the function itself; no wrapping lambda needed
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

dataset = AnoamlyBrainTumor(dataset_path, transform=transform)

In [None]:
from torch.utils.data import Subset

N_TEST_PER_CLASS = 100  # held-out slices per label
N_TRAIN_TUMOR = 512     # tumour slices included in the training set

no_tumor_idx = np.where(dataset.labels == 0)[0]
tumor_idx = np.where(dataset.labels == 1)[0]

# Test subset: the first N_TEST_PER_CLASS slices of each label, in CSV order
test_set = Subset(dataset, np.concatenate([no_tumor_idx[:N_TEST_PER_CLASS],
                                           tumor_idx[:N_TEST_PER_CLASS]]))
print(f"Number of samples in test subset: {len(test_set)}")

# Training subset: all remaining tumour-free slices plus the next N_TRAIN_TUMOR tumour slices
train_set = Subset(dataset, np.concatenate([
    no_tumor_idx[N_TEST_PER_CLASS:],
    tumor_idx[N_TEST_PER_CLASS:N_TEST_PER_CLASS + N_TRAIN_TUMOR],
]))
print(f"Number of samples in training subset: {len(train_set)}")
assert not set(test_set.indices) & set(train_set.indices), "train/test subsets overlap"

# Label distribution of the training set. Printed explicitly: a bare expression that is not
# the last line of a cell is never displayed.
print("Training label counts:", np.bincount(dataset.labels[train_set.indices]))

train_loader = DataLoader(train_set, batch_size=128, shuffle=True)

In [None]:
import os, sys
project_dir = os.path.join(os.getcwd(),'..')
if project_dir not in sys.path:
    sys.path.append(project_dir)

import torch
from torch import nn
from VAE.AnomalyDetector import AnomalyDetector
from VAE.utils import SGVBL, cosine_scheduler

class VAEModel(nn.Module):
    """Convolutional VAE-style autoencoder for square 4-channel inputs.

    Encoder: three stages of Conv3x3 -> BatchNorm -> Dropout -> ReLU -> MaxPool2d(3,
    stride=2, padding=1), taking 4 -> 16 -> 32 -> 64 channels, then Flatten -> Linear to
    128 features. `botleneck` is an AnomalyDetector built with (128, latent_space); the
    decoder's first Linear takes `latent_space` inputs, so the bottleneck output must have
    that width. The decoder mirrors the encoder with three Upsample(x2) stages and ends in
    a Conv to 4 channels with no output activation (the training code applies tanh).

    Args:
        input_size: side length of the square input. Must be a positive multiple of 8:
            each MaxPool2d(3, stride=2, padding=1) maps n -> ceil(n / 2) while the Linear
            layers are sized with input_size // 8, and the decoder upsamples by exactly 8.
            Any other size gives a Linear size mismatch or an output shaped unlike the input.
        latent_space: width of the latent vector fed to the decoder.

    Raises:
        ValueError: if `input_size` is not a positive multiple of 8.
    """
    def __init__(self, input_size, latent_space):
        super().__init__()
        if input_size <= 0 or input_size % 8 != 0:
            raise ValueError(f"input_size must be a positive multiple of 8, got {input_size}")
        conv_out_size = input_size // (2*2*2)  # spatial side after the three stride-2 pools
        self.encoder = nn.Sequential(
            nn.Conv2d(4, 16, 3, stride=1, padding=1),
            nn.BatchNorm2d(16),
            nn.Dropout2d(0.2),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2, padding=1),
            nn.Conv2d(16, 32, 3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.Dropout2d(0.2),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2, padding=1),
            nn.Conv2d(32, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.Dropout2d(0.2),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2, padding=1),
            nn.Flatten(),
            nn.Linear(64*(conv_out_size**2), 128),
            nn.BatchNorm1d(128),
            nn.Dropout(0.5),
            nn.ReLU(),
        )

        # Attribute name kept as-is (including spelling): other cells access model.botleneck.
        self.botleneck = AnomalyDetector(128, latent_space)

        self.decoder = nn.Sequential(
            nn.Linear(latent_space, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 64*(conv_out_size**2)),
            nn.BatchNorm1d(64*(conv_out_size**2)),
            nn.ReLU(),
            nn.Unflatten(1, (64, conv_out_size, conv_out_size)),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, 32, 3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(32, 16, 3, stride=1, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(16, 4, 3, stride=1, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode, pass through the bottleneck, and decode back to a 4-channel map."""
        x = self.encoder(x)
        x = self.botleneck(x)
        x = self.decoder(x)
        return x

from torch.nn.functional import mse_loss

# 240x240 inputs, 10-dim latent; SGVBL (from VAE.utils) receives mse_loss as its `mle` term
model = VAEModel(input_size=240, latent_space=10)
sgvbl = SGVBL(model, len(train_set), mle=mse_loss)

In [None]:
# Display the module's repr (layer-by-layer summary)
model

In [None]:
240 // 2 // 2  # 240 halved twice (scratch feature-map size check)

In [None]:
# Shape smoke test: push a random batch through encoder -> bottleneck -> decoder.
# Eval mode + no_grad so the random batch neither updates BatchNorm running statistics
# (train mode would) nor builds an autograd graph; the previous mode is restored afterwards.
was_training = model.training
model.eval()
with torch.no_grad():
    x = torch.rand(2, 4, 240, 240, device=next(model.parameters()).device)
    z = model.encoder(x)
    b = model.botleneck(z)
    d = model.decoder(b)
model.train(was_training)
assert d.shape == x.shape, f"decoder output {tuple(d.shape)} != input {tuple(x.shape)}"



In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.train()
n_epochs = 100
from tqdm import tqdm

# Cosine schedule, one value per epoch. NOTE: currently unused — the loss below passes a
# constant 1 as sgvbl's 4th argument; pass kl_weight[epoch] there to enable annealing.
kl_weight = cosine_scheduler(n_epochs)

opt = torch.optim.Adam(model.parameters(), lr=1e-3)
epoch_iterator = tqdm(
    range(n_epochs),
    leave=True,
    unit="epoch",
    postfix={"tls": "%.3f" % -1},
)

for epoch in epoch_iterator:
    epoch_loss = 0.
    for x, y in train_loader:
        x = x.to(device)
        opt.zero_grad()
        # tanh bounds the reconstruction to [-1, 1], the range of the normalised inputs.
        x_hat = torch.tanh(model(x))
        loss = sgvbl(x, x_hat, y, 1)
        loss.backward()
        opt.step()
        epoch_loss += loss.item()

    # Mean per-batch training loss of the epoch
    epoch_iterator.set_postfix(tls="%.3f" % (epoch_loss / len(train_loader)))

In [None]:
# Reconstruct the last training batch (`x` left over from the training loop) in eval mode
# (dropout off, BatchNorm running statistics) and show channel 0 of its first sample.
model.eval()
with torch.no_grad():
    x_hat = torch.tanh(model(x))[0, 0]

plt.imshow(x_hat.cpu().numpy(), cmap='gray')

In [None]:
import numpy as np
def plot_reconstructed(autoencoder, r0=(-10, 10), r1=(-10, 10), n=12, channel=1):
    """Decode an n x n grid of latent points and tile one output channel of each.

    Latent dimension 0 sweeps `r0` (columns, left to right) and dimension 1 sweeps `r1`
    (rows, bottom to top); all other latent dimensions are held at 0. The latent width is
    taken from the decoder's first Linear layer, so any latent size >= 2 works.

    Args:
        autoencoder: model whose `decoder[0]` is an nn.Linear on the latent vector and
            whose decoder output reshapes to (4, 240, 240).
        r0, r1: (low, high) ranges for latent dimensions 0 and 1.
        n: number of grid points per axis.
        channel: decoder output channel to display (default 1).

    Raises:
        ValueError: if the decoder's latent input has fewer than 2 dimensions.
    """
    w = 240
    latent_dim = autoencoder.decoder[0].in_features
    if latent_dim < 2:
        raise ValueError(f"need a latent space of at least 2 dims, got {latent_dim}")
    model_device = next(autoencoder.parameters()).device
    img = np.zeros((n * w, n * w))
    with torch.no_grad():
        for i, y in enumerate(np.linspace(*r1, n)):
            for j, x in enumerate(np.linspace(*r0, n)):
                z = torch.zeros(1, latent_dim, device=model_device)
                z[0, 0] = float(x)
                z[0, 1] = float(y)
                x_hat = torch.tanh(autoencoder.decoder(z))
                x_hat = x_hat.reshape(4, w, w)[channel].cpu().numpy()
                # Grid row i is drawn from the bottom, so latent dim 1 increases upwards.
                img[(n - 1 - i) * w:(n - i) * w, j * w:(j + 1) * w] = x_hat

    # Raw strings: '\m' and '\s' are invalid escape sequences in normal literals.
    plt.xlabel(r'$\mathcal{N}(0, \sigma_1)$', fontsize='x-large')
    plt.ylabel(r'$\mathcal{N}(0, \sigma_2)$', fontsize='x-large')
    plt.imshow(img, extent=[*r0, *r1], cmap='viridis')

model.train(False)  # same as model.eval(): dropout off, BatchNorm uses running statistics
plot_reconstructed(model, r0=(-6, 6), r1=(-6, 6), n=16)

In [None]:
# Output shape of the first 11 encoder layers for one sample (batch of 1)
x, y = dataset[0]
model.encoder[:11](x[None].to(device)).shape

In [None]:
# Reconstruct the first tumour-free slice and the first tumour slice, then plot channel 0
# of each original next to its reconstruction (row 1: no tumour, row 2: tumour).
# Eval mode is set here rather than relying on an earlier cell; no_grad skips the graph.
model.eval()
with torch.no_grad():
    x_0 = dataset[no_tumor_idx[0]][0].unsqueeze(0).to(device)
    x_1 = dataset[tumor_idx[0]][0].unsqueeze(0).to(device)
    x = torch.vstack([x_0, x_1])
    x_hat = torch.tanh(model(x))

panels = [
    (x[0, 0], 'Original Image'), (x_hat[0, 0], 'Reconstructed Image'),
    (x[1, 0], 'Original Image'), (x_hat[1, 0], 'Reconstructed Image'),
]
for position, (panel, panel_title) in enumerate(panels, start=1):
    plt.subplot(2, 2, position)
    plt.imshow(panel.cpu().numpy(), cmap='gray')
    plt.title(panel_title)
    plt.axis('off')

plt.show()

model.botleneck.mu, model.botleneck.sigma


In [None]:
# Display the encoder's layers with their indices (used when slicing, e.g. model.encoder[:11])
model.encoder

In [None]:
240 // 4  # 240 halved twice (scratch feature-map size check)