# Evaluation

In [2]:
import os
import re
import glob
import numpy as np
import pandas as pd
import json
import matplotlib.pyplot as plt
from matplotlib import cm, rc
import torch

from gliomagrowth.experiment.continuous_image import ContinuousTumorGrowth
from gliomagrowth.data import glioma
# Root directory of the dataset; it is handed to the glioma loaders in the cells below.
# NOTE(review): machine-specific absolute path — adjust when running elsewhere.
glioma.data_dir = "/media/jens/SSD/bovarec/multi_128/"

# Expose only GPU index 1 to CUDA.
# NOTE(review): presumably this has to run before CUDA is first initialised in
# this kernel to take effect — confirm on a fresh kernel.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# rc('font',**{'family':'serif','serif':['Computer Modern Roman']})
# rc('text', usetex=True)

In [5]:
# Work from inside the MLflow run store so that the relative paths used in the
# cells below ("10", "10/<run>", "dice_2d.csv") resolve against it.
# NOTE(review): machine-specific absolute path — adjust when running elsewhere.
os.chdir("/media/jens/Data4TB/experiments/mlflow/mlruns/")

### Utilities

In [8]:
def load_data_module(path):
    """Restore a ``GliomaModule`` from a checkpoint and prepare it for use.

    Args:
        path: Checkpoint file passed to ``GliomaModule.load_from_checkpoint``.

    Returns:
        The restored module, after its ``setup()`` method has been called.
        The module is created with ``data_dir=glioma.data_dir``, i.e. the
        notebook-wide data directory is passed explicitly.
    """
    data_module = glioma.GliomaModule.load_from_checkpoint(
        path,
        data_dir=glioma.data_dir,
    )
    data_module.setup()
    return data_module

def load_experiment(path, device="cuda"):
    """Restore a ``ContinuousTumorGrowth`` experiment ready for inference.

    Args:
        path: Checkpoint file passed to
            ``ContinuousTumorGrowth.load_from_checkpoint``.
        device: Device the experiment is moved to (default ``"cuda"``).

    Returns:
        The restored experiment, with ``model.eval()`` called on its model and
        the experiment moved to ``device``.
    """
    experiment = ContinuousTumorGrowth.load_from_checkpoint(path)
    # Only the wrapped model is switched to eval mode, as in the checkpointed setup.
    experiment.model.eval()
    experiment.to(device=device)
    return experiment

def get_experiment_from_name(dir_, name):
    """Find the run folder inside ``dir_`` whose MLflow run name equals ``name``.

    Every entry of ``dir_`` is inspected for a ``tags/mlflow.runName`` file.
    The first line of that file, without its line terminator, is compared to
    ``name``. Entries that are not directories or have no such tag file are
    skipped.

    Args:
        dir_: Directory containing one sub-folder per run.
        name: Run name to look for (exact match).

    Returns:
        The name of the matching entry in ``dir_`` (not a full path). If
        several runs share the name, the first one in ``os.listdir`` order is
        returned, and that order is arbitrary.

    Raises:
        FileNotFoundError: If no entry carries the requested run name (or if
            ``dir_`` itself does not exist).
    """
    for run_id in os.listdir(dir_):
        tag_path = os.path.join(dir_, run_id, "tags", "mlflow.runName")
        try:
            with open(tag_path, "r") as infile:
                # readline() keeps the line terminator; strip it so that a tag
                # file ending in a newline still matches the plain run name.
                run_name = infile.readline().rstrip("\r\n")
        except (FileNotFoundError, NotADirectoryError):
            # Plain files in dir_ or runs without a name tag are not candidates.
            continue
        if run_name == name:
            return run_id

    raise FileNotFoundError("Could not find model with name {}.".format(name))

def load_experiment_from_name(dir_, name, device="cuda"):
    """Load the highest-epoch checkpoint of the run called ``name``.

    The run folder is resolved with ``get_experiment_from_name``. All files
    matching ``artifacts/checkpoint*`` inside it are candidates, and the one
    with the largest epoch number is loaded with ``load_experiment``.

    Args:
        dir_: Directory containing one sub-folder per run.
        name: MLflow run name to look for.
        device: Device passed on to ``load_experiment`` (default ``"cuda"``).

    Returns:
        The experiment returned by ``load_experiment``.

    Raises:
        FileNotFoundError: If the run has no ``checkpoint*`` artifact (or the
            run itself cannot be found).
    """
    run_dir = get_experiment_from_name(dir_, name)
    artifact_dir = os.path.join(dir_, run_dir, "artifacts")

    candidates = glob.glob(os.path.join(artifact_dir, "checkpoint*"))
    if not candidates:
        raise FileNotFoundError("Couldn't find any checkpoints in {}.".format(artifact_dir))

    def epoch_of(checkpoint_path):
        # The epoch is read from the second "_"-separated token of the file
        # name after dropping its first six characters — presumably an
        # "epoch=" prefix; TODO confirm against the checkpoint naming scheme.
        token = os.path.basename(checkpoint_path).split("_")[1]
        return int(token[6:])

    # max() keeps the first of several equally large epochs.
    latest = max(candidates, key=epoch_of)
    print("Loading checkpoint from {}.".format(latest))
    return load_experiment(latest, device)

### Look at scores

In [9]:
def load_test(exp):
    """Read the test scores of a run into a DataFrame.

    Args:
        exp: Path of the run folder; the scores are read from
            ``<exp>/artifacts/test.csv``.

    Returns:
        A DataFrame indexed by the ``"Subject and Timestep"`` column, without
        any column whose name starts with ``"Unnamed"`` (such as the ones
        pandas creates for header-less CSV columns).
    """
    csv_path = os.path.join(exp, "artifacts", "test.csv")
    table = pd.read_csv(csv_path).set_index("Subject and Timestep")
    unnamed = [column for column in table.columns if column.startswith("Unnamed")]
    return table.drop(unnamed, axis=1)

In [234]:
# Look up the run folder inside the directory "10" by its MLflow run name.
# Note: `exp` is the folder name of the run (a string), not a loaded model.
exp = get_experiment_from_name("10", "att2_temp0_128-64_split0_lw0.00001_convup")
# Read that run's artifacts/test.csv and sort the rows by their index.
scores = load_test("10/" + exp).sort_index()

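Add a "True Dice" column if it is missing: the Dice overlap between the foreground of the last context segmentation and the foreground of the target segmentation. The values are cached in `dice_2d.csv`.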

In [235]:
# Only compute and attach the column when the loaded scores do not carry it yet.
if "True Dice" not in scores.columns:
    # Cache file, resolved relative to the current working directory.
    storage = "dice_2d.csv"
    if os.path.exists(storage):
        # Reuse the cached values: the first CSV column becomes the index,
        # the second one the Dice value.
        dice_list = pd.read_csv(storage)
        dice_list.columns = ["Subject and Timestep", "True Dice"]
        dice_list = dice_list.set_index("Subject and Timestep")
    else:
        # NOTE(review): the meaning of glioma.load("r") and of the generator
        # arguments 1 and (2, 3, 4, 5) is not visible here — confirm against
        # the gliomagrowth library.
        data = glioma.load("r")
        gen = glioma.FutureContextGenerator2D(data, 1, (2, 3, 4, 5), ddir=glioma.data_dir)
        info = []
        dice_list = []
        for (context, target) in gen:
            # Row identifier built from: subject, first context timestep plus
            # the size of axis 1 of the context data, slice, and "it" followed
            # by that same size. It has to equal the index values of `scores`,
            # because the lookup at the end of this cell is done by label.
            # NOTE(review): presumably axis 1 is the time axis — TODO confirm.
            subject_info = "_".join(
                [
                    context["subjects"][0],
                    str(context["timesteps"][0] + context["data"].shape[1]),
                    str(context["slices"][0]),
                    "it" + str(context["data"].shape[1]),
                ]
            )
            # Foreground masks (any value > 0): the last entry along axis 1 of
            # the context segmentation versus the first entry of the target.
            seg_in = context["seg"][0, -1, 0] > 0
            seg_out = target["seg"][0, 0, 0] > 0
            gt_volume = np.sum(seg_out)
            # Samples whose target mask is empty are left out entirely.
            if gt_volume == 0:
                continue
            else:
                info.append(subject_info)
                # NOTE(review): `dice` is neither defined nor imported in the
                # visible part of this notebook — make sure it exists when
                # running on a fresh kernel.
                dice_list.append(dice(seg_in, seg_out))
        # From here on `dice_list` is a DataFrame (it was a plain list above).
        dice_list = pd.DataFrame(dice_list, columns=["True Dice"], index=info)
        dice_list.to_csv(storage)
    # Select the cached/computed values by the index labels of `scores` and
    # insert them as the first column (modifies `scores` in place).
    scores.insert(0, "True Dice", dice_list.loc[scores.index])

Remove cases where there is no true overlap

In [236]:
# Keep only the rows with a strictly positive "True Dice"; rebinds `scores`
# to the filtered frame.
scores = scores[scores["True Dice"] > 0]

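Print the mean scores over the cases whose true Dice lies below a threshold (with a threshold of 1, every case is included except those with a true Dice of exactly 1)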

In [237]:
# 10th percentile and mean of the "True Dice" column.
# NOTE(review): neither value is used in the visible part of this cell —
# presumably candidate values for `thresh`; confirm or remove.
true_overlap_low10 = np.percentile(scores["True Dice"], 10)
true_overlap_mean = scores["True Dice"].mean()

# Exclusive upper bound on "True Dice" for the rows included in the summary.
thresh = 1.

# Column-wise mean over the selected rows.
# NOTE(review): the comparison is strict, so rows with a True Dice of exactly
# 1.0 are excluded even for thresh = 1.
print(scores[scores["True Dice"] < thresh].mean())
# Disabled: standard error of the mean (std / sqrt(n)) for the same rows.
# print(scores[scores["True Dice"] < thresh].std() / np.sqrt(scores[scores["True Dice"] < thresh].shape[0]))

True Dice                        0.670820
GT Volume                      410.448786
Loss Task                        0.356302
Loss Latent                    589.061122
Loss                             0.362192
Dice Class 0                     0.966361
Dice Class 1                     0.583129
Dice Class 2                     0.273732
Dice Foreground                  0.620503
Best Volume Dice Class 0         0.967554
Best Volume Dice Class 1         0.600307
Best Volume Dice Class 2         0.277348
Best Volume Dice Foreground      0.639785
Best Dice Class 0                0.968904
Best Dice Class 1                0.621220
Best Dice Class 2                0.182551
Best Dice Foreground             0.658738
dtype: float64
