In [None]:
import os, sys
# Make the directory two levels above the working directory importable,
# so that project-local modules can be imported below.
project_dir = os.path.join(os.getcwd(), '../..')
already_importable = project_dir in sys.path
if not already_importable:
    sys.path.append(project_dir)

import pandas as pd
import numpy as np
import config
import h5py
from matplotlib import pyplot as plt

# Directory (below the configured BRATS root) from which the HDF5 sample
# files are read further down.
data_directory = os.path.join(
    config.BRATS_DATASET_PATH,
    'BraTS2020_training_data/content/data/',
)

In [None]:
# Label table: the code below reads its 'Filename' and 'Label' columns.
labels_csv_path = os.path.join(config.BRATS_DATASET_PATH, 'tumour_labels.csv')
df = pd.read_csv(labels_csv_path)
filenames = df['Filename'].values

# Rows whose label is 0 (named "no tumor" by the variable below).
label_is_zero = df['Label'] == 0
no_tumor_filenames = df[label_is_zero]['Filename'].values

# Load every dataset stored in one label-0 file into a plain dict so it can
# be inspected after the file is closed.
sample_file_path = os.path.join(data_directory, no_tumor_filenames[10])
data = {}
with h5py.File(sample_file_path, 'r') as file:
    data.update((key, file[key][()]) for key in file.keys())

In [None]:

# Show the first four channels (last axis) of the loaded sample in a 2x2 grid.
for i in range(4):
    plt.subplot(2, 2, i + 1)
    channel_image = data['image'][:, :, i].T  # transposed before display
    plt.imshow(channel_image, cmap='gray')
    plt.title(f'Channel {i}')
    plt.axis('off')

plt.show()

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

class AnomalyBrainTumor(Dataset):
    '''
        BRATS 2020 dataset adapted for anomaly detection.

        Every item is a pair ``(image, label)``: ``image`` is the ``'image'``
        dataset of one HDF5 file cast to float32 (and passed through
        ``transform`` when one is given), ``label`` is the matching entry of
        the ``Label`` column of ``tumour_labels.csv``.

        NOTE(review): the original description says that, since tumors mostly
        occur in the middle of the brain, the lowest 80 slices and the
        uppermost 26 slices are excluded from the dataset. This class performs
        no such filtering itself; presumably it is already reflected in
        ``tumour_labels.csv`` -- confirm.

        Parameters:
        -----------
            dataset_path: str
                Root folder containing ``tumour_labels.csv`` and the
                ``BraTS2020_training_data/content/data/`` directory.
            transform: callable, optional
                Applied to every image in ``__getitem__``.
            on_memory: bool
                When True, all images are read once in ``__init__`` and kept
                in a single float32 array; otherwise each access reads its
                file from disk.
    '''

    # Per-sample shape used only when there is no file to infer the shape from.
    _DEFAULT_SAMPLE_SHAPE = (240, 240, 4)

    def __init__(self, dataset_path:str, transform=None, on_memory=False) -> None:
        super().__init__()

        self.on_memory = on_memory
        self.dataset_path = os.path.join(dataset_path, 'BraTS2020_training_data/content/data/')
        self.transform = transform

        self.df = pd.read_csv(os.path.join(dataset_path, 'tumour_labels.csv'))
        self.filenames = [os.path.join(self.dataset_path, name) for name in self.df['Filename'].values]
        self.labels = self.df['Label'].values

        if self.on_memory:
            self.data = self.__loaddata__(self.filenames)

    def __len__(self) -> int:
        ''' Number of files listed in the label table. '''
        return len(self.filenames)

    def __getitem__(self, idx:int) -> tuple:
        ''' Return ``(image, label)`` for position ``idx``; the transform, if any, is applied to the image only. '''
        if self.on_memory:
            data = self.data[idx]
        else:
            data = self.__readfile__(self.filenames[idx])

        if self.transform:
            data = self.transform(data)

        return data, self.labels[idx]

    def __loaddata__(self, filenames:list) -> np.ndarray:
        '''
            Read every file in ``filenames`` into one float32 array of shape
            ``(len(filenames), *sample_shape)``.

            The per-sample shape is taken from the first file instead of being
            hard-coded; an empty list yields an empty array with the default
            BRATS 2020 sample shape (240, 240, 4).
        '''
        if len(filenames) == 0:
            return np.zeros((0, *self._DEFAULT_SAMPLE_SHAPE), dtype=np.float32)

        first = self.__readfile__(filenames[0])
        dataset = np.zeros((len(filenames), *first.shape), dtype=np.float32)
        dataset[0] = first
        for idx, filename in enumerate(filenames[1:], start=1):
            dataset[idx] = self.__readfile__(filename)

        return dataset

    def __readfile__(self, filename:str)-> np.ndarray:
        ''' Read the ``'image'`` dataset of one HDF5 file as a float32 array. '''
        # The context manager closes the file; no explicit close() is needed.
        with h5py.File(filename, 'r') as file:
            data = file['image'][()].astype(np.float32)

        return data

In [None]:
def normalize_brats_tensor(tensor: torch.Tensor) -> torch.Tensor:
    '''
        Min-max scale every channel of a channel-first tensor independently.

        Each channel (first axis) is shifted so its minimum becomes 0 and then
        divided by its shifted maximum, so non-constant channels end up in
        [0, 1]. A constant channel (e.g. an all-zero slice) has a zero range;
        it is returned as all zeros instead of NaN from a 0/0 division.

        Parameters:
        -----------
            tensor: torch.Tensor
                3-D tensor whose first axis indexes the channels. The number
                of channels is read from the tensor.

        Returns:
        --------
            torch.Tensor
                Tensor of the same shape with per-channel min-max scaling.
    '''
    n_channels = tensor.shape[0]

    channel_min = tensor.reshape(n_channels, -1).min(dim=1).values.reshape(n_channels, 1, 1)
    shifted = tensor - channel_min

    channel_max = shifted.reshape(n_channels, -1).max(dim=1).values.reshape(n_channels, 1, 1)
    # After the shift a zero maximum means the channel is constant: divide by 1
    # so the channel stays at zero rather than turning into NaN.
    safe_range = torch.where(channel_max > 0, channel_max, torch.ones_like(channel_max))

    return shifted / safe_range

from torchvision import transforms
# Preprocessing applied to every sample: convert to a tensor, min-max scale
# each channel, then apply (x - 0.5) / 0.5.
preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Lambda(lambda img: normalize_brats_tensor(img)),
    transforms.Normalize(mean=[0.5], std=[0.5]),
]
transform = transforms.Compose(preprocessing_steps)

# Samples are read lazily from disk (on_memory=False).
dataset = AnomalyBrainTumor(
    config.BRATS_DATASET_PATH,
    transform=transform,
    on_memory=False,
)

In [None]:
from torch.utils.data import Subset

def generate_dataset(dataset:Dataset, percent_normal_samples=.75, percent_anomalies=.1, seed=None) -> tuple:
    '''
        Randomly split a labelled dataset into a training and a test subset.

        The training subset contains a fraction ``percent_normal_samples`` of
        the normal samples (label 0) plus a number of anomalies (any other
        label) equal to ``percent_anomalies`` times the number of selected
        normal samples. Every sample that is not selected for training goes
        to the test subset.

        Parameters:
        -----------
            dataset: Dataset
                Dataset exposing a ``labels`` attribute with one label per sample.
            percent_normal_samples: float
                Fraction of the normal samples placed in the training subset.
            percent_anomalies: float
                Number of training anomalies, as a fraction of the number of
                selected normal samples.
            seed: int, optional
                When given, the selection is drawn from a dedicated generator
                seeded with this value, so the split is reproducible. When
                None, numpy's global random state is used.

        Returns:
        --------
            (Subset, Subset)
                The training subset and the test subset.
    '''
    # np.asarray keeps the element-wise comparison valid even for list labels.
    labels = np.asarray(dataset.labels)
    normal_idx = np.argwhere(labels == 0).flatten()
    anomaly_idx = np.argwhere(labels != 0).flatten()

    if seed is None:
        permute = np.random.permutation
    else:
        permute = np.random.default_rng(seed).permutation

    # random selection of samples
    n_normal = int(len(normal_idx) * percent_normal_samples)
    normal_selected = permute(normal_idx)[:n_normal]
    n_anomalies = int(len(normal_selected) * percent_anomalies)
    anomaly_selected = permute(anomaly_idx)[:n_anomalies]

    training_subset = np.zeros(len(dataset), dtype=bool)
    training_subset[np.concatenate([normal_selected, anomaly_selected])] = True

    train_dataset = Subset(dataset, np.argwhere(training_subset).flatten())
    test_dataset = Subset(dataset, np.argwhere(~training_subset).flatten())

    return train_dataset, test_dataset


In [None]:
# Draw one random train/test partition using generate_dataset's default ratios.
train_dataset, test_dataset = generate_dataset(dataset=dataset)

In [None]:
from ADeLEn.model import ADeLEn
from torch.nn.functional import mse_loss
from VAE.loss import SGVBL
from experiments.utils.ADeLEn import train
# Value passed as the model's `bottleneck` argument; the evaluation loop
# later passes the same value to classification_metrics.
d = 10
model = ADeLEn(
    (240, 240),
    [4, 8, 24, 32, 48],
    [1024, 128, 32],
    bottleneck=d,
    skip_connection=False,
)
# NOTE(review): the loss receives len(dataset) (the full dataset), not the
# size of a training subset -- confirm this is the intended argument.
sgvbl = SGVBL(model, len(dataset), mle=mse_loss)

In [None]:
from utils import generate_multi_df, generate_roc_df

def threshold(sigma, d) -> float:
    '''
        Score threshold computed as 0.5 * (d * log(2*pi*e) + d * log(sigma)).

        This matches the differential entropy of a d-dimensional isotropic
        Gaussian when ``sigma`` is read as the per-dimension variance.

        Parameters:
        -----------
            sigma: float
                Positive spread parameter (its logarithm is taken).
            d: int
                Number of dimensions.

        Returns:
        --------
            float
                The threshold value.
    '''
    gaussian_term = d * np.log(2*torch.pi*torch.e)
    sigma_term = d * np.log(sigma)
    return .5 * (gaussian_term + sigma_term)

def roc_curve(model, test_dataset):
    '''
        Compute the ROC curve of the model on the test dataset.

        All samples of ``test_dataset`` are stacked into a single batch and
        scored with ``model.score_samples``; the scores are compared with the
        labels through scikit-learn.

        Returns:
        --------
            fpr: array
                False positive rates, one per threshold.
            tpr: array
                True positive rates, one per threshold.
            roc_auc: float
                Area under the curve.
    '''
    # Imported as a module so the scikit-learn function is not confused with
    # this function, which has the same name.
    from sklearn import metrics as sk_metrics

    samples, labels = zip(*test_dataset)
    batch = torch.stack(samples)
    targets = torch.tensor(labels).flatten()

    anomaly_scores = model.score_samples(batch)
    fpr, tpr, _ = sk_metrics.roc_curve(targets, anomaly_scores)

    return (fpr, tpr, sk_metrics.auc(fpr, tpr))

def classification_metrics(model, test_dataset, sigma=1.2, d=10) -> tuple:
    '''
        Evaluate the model as a binary classifier on the test dataset.

        A sample is predicted as class 1 when its ``model.score_samples``
        score is strictly greater than ``threshold(sigma, d)``, and as class 0
        otherwise.

        Returns:
        --------
            accuracy: float
                The accuracy of the model.
            precision: float
                The precision of the model.
            recall: float
                The recall of the model.
            f1: float
                The f1 score of the model.
    '''
    from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

    samples, labels = zip(*test_dataset)
    batch = torch.stack(samples)
    targets = torch.tensor(labels).flatten()

    anomaly_scores = model.score_samples(batch)
    cutoff = threshold(sigma, d)
    predictions = np.where(anomaly_scores > cutoff, 1, 0)

    # Evaluated in the order accuracy, precision, recall, f1.
    metric_functions = (accuracy_score, precision_score, recall_score, f1_score)
    return tuple(metric(targets, predictions) for metric in metric_functions)

def score_per_label(model, test_dataset):
    '''
        Score the whole test dataset and split the scores by label.

        Returns:
        --------
            (scores of the samples labelled 0, scores of the samples labelled 1)
    '''
    samples, labels = zip(*test_dataset)
    batch = torch.stack(samples)
    targets = torch.tensor(labels).flatten()

    all_scores = model.score_samples(batch)
    label_is_zero = targets == 0
    label_is_one = targets == 1
    return (all_scores[label_is_zero], all_scores[label_is_one])

In [None]:
# Repeated random-split evaluation: every iteration draws a new train/test
# partition, trains a model on it and records its test metrics.
n_iter = 25
roc, scores = [], []  # per-iteration (fpr, tpr) pairs and AUC values
metrics = np.empty((n_iter, 5)) # acc, prec, rec, f1, auc

for i in range(n_iter):
    train_dataset, test_dataset = generate_dataset(dataset)

    # Start from a freshly initialised model each time. Reusing the model from
    # the previous iteration would keep training already-fitted weights, and
    # because every split is random the new test set would contain samples
    # the model was trained on earlier, inflating the metrics.
    model = ADeLEn((240, 240), [4, 8, 24, 32, 48], [1024, 128, 32], bottleneck=d, skip_connection=False)
    # NOTE(review): the meaning of the positional arguments 100 and 1 is
    # defined by experiments.utils.ADeLEn.train -- confirm there.
    model = train(model, train_dataset, 100, 1)

    fpr, tpr, roc_auc = roc_curve(model, test_dataset)
    acc, prec, rec, f1 = classification_metrics(model, test_dataset, sigma=1.2, d=d)

    roc.append((fpr, tpr))
    scores.append(roc_auc)
    metrics[i] = [acc, prec, rec, f1, roc_auc]