In [None]:
# Notebook setup: imports, logging, IPython extensions and plotting defaults.
from __future__ import absolute_import, division, print_function
import logging, os, sys

# Send INFO-level (and above) log records to stdout so they show up in the cell output.
logging.basicConfig(format='[%(levelname)s] %(message)s', level=logging.INFO, stream=sys.stdout)

# Load the memory_profiler and autoreload IPython extensions.
# `%autoreload 2` re-imports edited modules before each cell is executed.
%load_ext memory_profiler
%load_ext autoreload
%autoreload 2

# Raise Python's recursion limit.
# NOTE(review): this value is far beyond what the interpreter stack can normally sustain, so
# runaway recursion would likely crash the kernel instead of raising RecursionError - confirm
# that such a high limit is really required.
sys.setrecursionlimit(1000000000)

# Initialize tqdm to always use the notebook progress bar
import tqdm

# Monkeypatch: code that looks up `tqdm.tqdm` afterwards gets the notebook variant.
# NOTE(review): `tqdm.tqdm_notebook` is deprecated in recent tqdm releases - verify against the
# installed version.
tqdm.tqdm = tqdm.tqdm_notebook

# Third-party libraries
import comet_ml

import numpy as np
import pandas as pd
import nilearn.plotting as nip
import matplotlib.pyplot as plt
import nibabel as nib
# NOTE(review): numpy is imported a second time here; the import is redundant.
import numpy as np
import torch
import collections
%matplotlib inline
# Default figure size for all plots in this notebook.
plt.rcParams["figure.figsize"] = (12,6)
%config InlineBackend.figure_format='retina'  # adapt plots for retina displays
import git
# NOTE(review): `evaluation` is bound again by the `from aneurysm_utils import ...` line below.
import aneurysm_utils.evaluation as evaluation

# Project utils

import aneurysm_utils
from aneurysm_utils import evaluation, training,preprocessing


In [None]:
# Pick the project root folder from the directory the notebook is started in.
if "workspace" in os.getcwd():
    ROOT = "/workspace"  # local
elif "/group/cake" in os.getcwd():
    ROOT = "/group/cake"  # Jupyter Lab
else:
    # Without this branch ROOT would stay undefined and the cell that creates the
    # environment would fail with an unhelpful NameError.
    ROOT = os.getcwd()
    logging.warning(
        "Could not detect a known root folder; falling back to the working directory %s. "
        "Set ROOT manually if this is not the project root.",
        ROOT,
    )


In [None]:
# Create the project environment rooted at ROOT.
env = aneurysm_utils.Environment(project="our-git-project", root_folder=ROOT)

# The Comet API key is a credential and must not be stored in the notebook, where it ends up in
# version control. It is read from the COMET_API_KEY environment variable instead.
# NOTE(review): the key that used to be hard-coded here should be treated as leaked and rotated.
comet_key = os.environ.get("COMET_API_KEY")
if comet_key:
    env.cached_data["comet_key"] = comet_key
else:
    logging.warning("COMET_API_KEY is not set; no Comet key was stored in the environment.")

env.print_info()

In [None]:
# This notebook needs the same seed as the training notebook to get disjoint test samples.
# The dataset and preprocessing parameters should also be the same as in the training notebook.

# Options forwarded to dataset loading and splitting.
dataset_params = dict(
    prediction="mask",
    mri_data_selection="",
    balance_data=False,
    seed=1,
    resample_voxel_dim=(1.2, 1.2, 1.2),
)

# Options forwarded to preprocessing.preprocess.
preprocessing_params = dict(
    min_max_normalize=True,
    mean_std_normalize=False,
    smooth_img=False,  # can contain a number: smoothing factor
    intensity_segmentation=False,
)


In [None]:
from aneurysm_utils.data_collection import load_aneurysm_dataset

# Load the dataset table using the data selection and seed configured in dataset_params.
df = load_aneurysm_dataset(
    env, random_state=dataset_params["seed"], mri_data_selection=dataset_params["mri_data_selection"]
)

# Show the first rows as a quick sanity check.
df.head()

In [None]:
# Load the MRI images and divide them into train, test and validation sets.
from aneurysm_utils.data_collection import split_mri_images

# To debug on a handful of cases only, restrict `df` before splitting, e.g.
# df = df.loc[df["Case"].isin(["A123", "A121", "A124"])]

train_data, test_data, val_data, _ = split_mri_images(
    env,
    df,
    encode_labels=False,
    prediction=dataset_params["prediction"],
    balance_data=dataset_params["balance_data"],
    random_state=dataset_params["seed"],
    resample_voxel_dim=dataset_params["resample_voxel_dim"],
)

# Every split unpacks into (images, labels, participants).
(mri_imgs_train, labels_train, train_participants) = train_data
(mri_imgs_test, labels_test, test_participants) = test_data
(mri_imgs_val, labels_val, val_participants) = val_data

In [None]:
from aneurysm_utils import preprocessing

# Inspect the shapes of the training images. The returned value serves as the reference shape
# for the shape filter in the next cell (presumably the most frequent shape, judging by the
# variable name - confirm in preprocessing.check_mri_shapes).
most_common_shape = preprocessing.check_mri_shapes(mri_imgs_train)

In [None]:
def _filter_by_shape(images, labels, expected_shape):
    """Keep only the image/label pairs whose image has ``expected_shape``.

    Parameters
    ----------
    images : sequence of objects exposing a ``.shape`` attribute
    labels : sequence indexed in parallel to ``images``
    expected_shape : the shape every kept image must have

    Returns
    -------
    tuple
        ``(kept_images, kept_labels, dropped_indices)`` where ``dropped_indices`` is the
        ascending list of positions that were removed.
    """
    # A set gives constant-time membership tests (the list version was quadratic).
    dropped = {idx for idx, img in enumerate(images) if img.shape != expected_shape}
    kept_images = [img for idx, img in enumerate(images) if idx not in dropped]
    kept_labels = [lab for idx, lab in enumerate(labels) if idx not in dropped]
    return kept_images, kept_labels, sorted(dropped)


size = most_common_shape  # e.g. (139, 139, 120)

# Drop every image (and its label) whose shape differs from the reference shape.
# NOTE(review): train_participants / test_participants / val_participants are NOT filtered here;
# if they are per-image sequences they no longer line up with the filtered images - confirm.
mri_imgs_train, labels_train, train_index = _filter_by_shape(mri_imgs_train, labels_train, size)
mri_imgs_test, labels_test, test_index = _filter_by_shape(mri_imgs_test, labels_test, size)
mri_imgs_val, labels_val, val_index = _filter_by_shape(mri_imgs_val, labels_val, size)

# Re-check the shapes after filtering.
preprocessing.check_mri_shapes(mri_imgs_train)

# Label values (and their counts) of the first validation mask, if any mask is left.
if len(labels_val) > 0:
    print(np.unique(labels_val[0], return_counts=True))

In [None]:
from aneurysm_utils import preprocessing

patch_size = 64
size_of_train = len(mri_imgs_train)
size_of_test = len(mri_imgs_test)
size_of_val = len(mri_imgs_val)

# Preprocess all splits as one list to have a working mean_std_normalization.
mri_imgs = mri_imgs_train + mri_imgs_test + mri_imgs_val
mri_imgs = preprocessing.preprocess(env, mri_imgs, preprocessing_params)

# Only the test split is used in this notebook, so only it is sliced out of the combined
# list and cut into patches.
mri_imgs_test = np.asarray(mri_imgs[size_of_train : size_of_train + size_of_test])
mri_imgs_test = preprocessing.patch_list(mri_imgs_test, patch_size)

# Preprocess the masks: the test masks are cut into patches the same way as the test images.
x, y, h = labels_train[0].shape
labels_test = np.asarray(labels_test)
# patch_list must be accessed through the preprocessing module; the bare name is never
# imported in this notebook and raised a NameError on a fresh kernel.
labels_test = preprocessing.patch_list(labels_test, patch_size)

In [None]:
## to save RAM
# The train and validation images are not needed for inference.
del mri_imgs_train
del mri_imgs_val
# The combined list built for preprocessing still references the images of every split, so it
# has to be dropped as well; the test patches were already extracted from it.
del mri_imgs

In [None]:
from aneurysm_utils.utils.pytorch_utils import predict
from aneurysm_utils.models.unet_3d_oktay import unet_3D
from aneurysm_utils.models.attention_unet import unet_grid_attention_3D

In [None]:
# Alternative architecture - use it when the checkpoint was trained with the Attention U-Net:
# model = unet_grid_attention_3D(feature_scale=4, n_classes=2, is_deconv=True, in_channels=1,
#                                nonlocal_mode='concatenation', attention_dsample=(2, 2, 2),
#                                is_batchnorm=True)

model = unet_3D(feature_scale=2, n_classes=2, is_deconv=True, in_channels=1, is_batchnorm=True)

# Path of the trained checkpoint; replace the placeholder file name with your model file.
# The path is absolute (leading slash), consistent with the path used for saving predictions;
# without the slash it was resolved relative to the current working directory.
PATH = "/group/cake/our-git-project/models/insert_your_model_either U-Net or Attention U-net.pt"
device = torch.device('cpu')

# This notebook only runs inference: put the model into evaluation mode so that layers such as
# batch normalization use their stored statistics instead of per-batch statistics.
model.eval()

# NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
model.load_state_dict(torch.load(PATH, map_location=device))

In [None]:
## save some RAM
import gc
# Run a full garbage collection pass now that the train/validation image lists have been
# deleted; the displayed number is the value returned by gc.collect().
gc.collect()

In [None]:
from aneurysm_utils.utils import pytorch_utils
from torch.utils.data.dataloader import DataLoader

# Wrap the patched test images and their masks in a dataset object for the data loader.
test_dataset = pytorch_utils.PytorchDataset(mri_imgs_test, labels_test, dtype=np.float64)



In [None]:
# Feed the test patches to the model one at a time and in their original order.
# TODO: use fixed batch size of 5
test_loader = DataLoader(
    test_dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=True
)

# Run the model over every test patch.
predictions = predict(model, test_loader, apply_softmax=False)

In [None]:

# Change the file name and location if necessary.
# The predictions are saved as a NumPy array.
# The shape of the array is (number_of_patches x length_test_set, tuple(masks, probabilities), h, w, d).
# PATH: choose the path where the predictions are saved.
PATH = "/group/cake/our-git-project/predictions/preds.npy"

# np.save does not create missing directories; make sure the target folder exists so that the
# result of the inference run above is not lost to a FileNotFoundError.
os.makedirs(os.path.dirname(PATH), exist_ok=True)
np.save(PATH, predictions)