# Model analysis

In [None]:
import ast
import os
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import mlflow
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from autoencoder.datasets import DiffusionMRIDataset, SphericalTransformer
from autoencoder.logger import logger, set_log_level

In [None]:
# Verbose logging. 10 is the numeric value of logging.DEBUG
# (assumes set_log_level takes a stdlib logging level — confirm).
set_log_level(10)

In [None]:
# Default figure size: half of an A4 landscape page (11.7 x 8.27 in) in each direction.
half_a4_figsize = (11.7 / 2, 8.27 / 2)
sns.set_theme(style="ticks", context="notebook", rc={"figure.figsize": half_a4_figsize})

In [None]:
# Record the PyTorch version the analysis was run with.
torch_version = torch.__version__
logger.info("torch version %s", torch_version)

In [None]:
# Use the GPU if one is available, otherwise fall back to the CPU.
has_cuda = torch.cuda.is_available()

logger.info("Is the GPU available? %s", has_cuda)
device = torch.device("cuda" if has_cuda else "cpu")

if has_cuda:
    logger.info("Current device: %s", torch.cuda.current_device())
    logger.info("Device count: %s", torch.cuda.device_count())
    # Pin everything to the first GPU.
    torch.cuda.set_device(0)
    logger.info("Using device: %s", torch.cuda.get_device_properties(device))
else:
    logger.warning("No GPU detected! Training will be extremely slow")

## Loading the models

In [None]:
server_ip = "localhost"

# Credentials and endpoint of the S3-compatible artifact store used by MLflow
# (presumably a local MinIO instance, given the "minio" credentials).
os.environ.update(
    {
        "AWS_ACCESS_KEY_ID": "minio",
        "AWS_SECRET_ACCESS_KEY": "minio123",
        "MLFLOW_S3_ENDPOINT_URL": f"http://{server_ip}:9000",
    }
)

tracking_uri = f"http://{server_ip}:5000"
mlflow.set_tracking_uri(tracking_uri)


def get_mlflow_runs(experiment_id: int, filter_tags=None):
    """Collect the active MLflow runs of an experiment into one DataFrame.

    Each matching run becomes one row holding its ``RunInfo`` fields plus its
    metrics, params and tags, prefixed with ``metrics_``, ``params_`` and ``tags_``.

    Args:
        experiment_id: MLflow experiment id.
        filter_tags: Optional mapping ``{tag: value}``. A run is kept only if it
            carries every one of these tags with exactly the given value.

    Returns:
        DataFrame indexed by the run start time (``start_time``), with ``end_time``
        also converted to a datetime. When no run matches, an empty frame (still
        indexed by ``start_time``) is returned and a warning is logged.
    """
    filter_tags = {} if filter_tags is None else filter_tags
    records = []
    # NOTE: mlflow.list_run_infos is the MLflow 1.x API (deprecated/removed in 2.x).
    run_infos = mlflow.list_run_infos(str(experiment_id), run_view_type=mlflow.entities.ViewType.ACTIVE_ONLY)
    for run_info in run_infos:
        # run_id replaces the deprecated run_uuid attribute.
        run = mlflow.get_run(run_info.run_id)
        tags = run.data.tags

        if not all(key in tags and tags[key] == value for key, value in filter_tags.items()):
            continue

        logger.debug("Loading run info for: %s", run_info.run_id)
        records.append(
            {
                **dict(run.info),
                **{f"metrics_{key}": val for key, val in run.data.metrics.items()},
                **{f"params_{key}": val for key, val in run.data.params.items()},
                **{f"tags_{key}": val for key, val in tags.items()},
            }
        )

    if not records:
        # pd.concat([]) would raise "No objects to concatenate"; return an empty table instead.
        logger.warning("No runs in experiment %s match tags %s", experiment_id, filter_tags)
        return pd.DataFrame(columns=["start_time", "end_time"]).set_index("start_time")

    df_runs = pd.DataFrame.from_records(records)
    # MLflow stores timestamps as milliseconds since the epoch.
    df_runs["end_time"] = pd.to_datetime(df_runs["end_time"], unit="ms")
    df_runs["start_time"] = pd.to_datetime(df_runs["start_time"], unit="ms")
    return df_runs.set_index("start_time")

In [None]:
# Runs of MLflow experiment 3 (presumably the FCN-decoder experiment), MUDI data only.
df_runs_fcn = get_mlflow_runs(3, filter_tags={"data": "MUDI"})

In [None]:
# Runs of MLflow experiment 4 (presumably the Fourier S2-decoder experiment), MUDI data only.
df_runs_fourier_s2 = get_mlflow_runs(4, filter_tags={"data": "MUDI"})

## Model evaluation

In [None]:
ROOT_PATH = ".."
# Figures are written to <ROOT_PATH>/images; create it up front so savefig never hits a missing directory.
IMAGES_PATH = Path(ROOT_PATH) / "images"
IMAGES_PATH.mkdir(exist_ok=True, parents=True)

### Reconstruction loss (MSE) for each model

In [None]:
def predict(subject, row, tissue="wb"):
    """Run a logged model over one subject and return ground truth and prediction.

    Args:
        subject: Subject id passed (wrapped in an array) to DiffusionMRIDataset.
        row: A run row from get_mlflow_runs. Reads ``artifact_uri`` and, when
            present, ``params_transform``.
        tissue: Tissue selection forwarded to the dataset (default ``"wb"``).

    Returns:
        Tuple ``(target, prediction)`` of CPU tensors. Each batch is rescaled by
        ``max_data / lstsq_coefficient`` taken from the subject's metadata.

    Raises:
        Exception: if the dataset cannot map batches to subject metadata.
    """
    artifact_path = str(row.artifact_uri)

    # Rebuild the data transformer if one was used during training.
    transform = None
    if "params_transform" in row.keys() and row.params_transform != "None":
        transform_args = ast.literal_eval(row.params_transform)
        if transform_args["class_path"] == "autoencoder.datasets.SphericalTransformer":
            transform = SphericalTransformer(**transform_args["init_args"])

    # Load the latent features: replace the leading "file:" URI part with "/".
    # NOTE: this does not work on Windows.
    p = list(Path(artifact_path, "latent_features.txt").parts)
    p[0] = "/"
    features = np.loadtxt(Path(*p), dtype=np.int32)

    data_set = DiffusionMRIDataset(
        Path("..", "data", "prj_MUDI_parameters.hdf5"),
        Path("..", "data", "prj_MUDI_data.hdf5"),
        np.array([subject]),
        tissue,
        batch_size=256,
        return_target=True,
        include_parameters=features,
        transform=transform,
    )
    # Fail fast, before loading the model, if batches cannot be mapped to metadata.
    if not hasattr(data_set, "get_subject_id_by_batch_id"):
        raise Exception("Unknown dataset type. Could not get metadata")

    data_gen = DataLoader(
        data_set,
        batch_size=None,
        batch_sampler=None,
        shuffle=False,
        num_workers=0,
        # Pinned memory only helps host-to-GPU copies.
        pin_memory=device.type == "cuda",
        drop_last=False,
    )

    model = mlflow.pytorch.load_model(str(Path(artifact_path, "scripted_model")))
    # Use the notebook-wide device instead of hard-coding CUDA, so CPU-only machines work.
    model.to(device)
    model.eval()

    predictions = []
    targets = []
    with torch.inference_mode():
        # tqdm wraps the loader directly so it can show the number of batches.
        for batch_idx, batch in enumerate(tqdm(data_gen)):
            subject_id = data_set.get_subject_id_by_batch_id(batch_idx)
            metadata = data_set.get_metadata_by_subject_id(subject_id)

            sample, target = batch["sample"], batch["target"]
            if isinstance(sample, dict):
                sample = {key: value.to(device) for key, value in sample.items()}
            else:
                sample = sample.to(device)

            prediction = model(sample)
            # Undo the per-subject normalisation (applied per voxel row via the transposes).
            prediction = (prediction.T / metadata["lstsq_coefficient"] * metadata["max_data"]).T
            target = (target.T / metadata["lstsq_coefficient"] * metadata["max_data"]).T

            # Move each batch off the accelerator right away to bound device memory.
            predictions.append(prediction.cpu())
            targets.append(target.cpu())

    prediction = torch.cat(predictions)
    target = torch.cat(targets)

    # Return ground truth and prediction.
    return target.cpu(), prediction.cpu()


# Mean-reduced MSE criterion (reduction="mean" is the MSELoss default), placed on `device`.
mse_loss = torch.nn.MSELoss().to(device)


def calc_loss(row, subject, tissue):
    """Compute reconstruction MSE statistics of one run for one subject and tissue.

    Args:
        row: Run row forwarded to predict.
        subject: Subject id forwarded to predict.
        tissue: Tissue selection forwarded to predict.

    Returns:
        1-D float64 array. It holds ``[mean, median, 85th percentile]`` of the
        per-feature MSEs, followed by the MSE of each output feature (column),
        averaged over all voxels (rows).
    """
    target, prediction = predict(subject, row, tissue=tissue)

    # Per-feature MSE in one vectorised call (mean over rows), instead of one
    # mse_loss call per column. The former unused whole-tensor MSE is dropped.
    feature_losses = torch.nn.functional.mse_loss(prediction, target, reduction="none").mean(dim=0)

    losses = np.full(prediction.shape[1] + 3, np.nan)
    losses[3:] = feature_losses.cpu().numpy()

    per_feature = losses[3:]
    losses[0] = np.mean(per_feature)
    losses[1] = np.median(per_feature)
    losses[2] = np.percentile(per_feature, 85)

    return losses


def add_loss_metrics(df, subject=15, tissues=("gm", "wm", "csf", "wb")):
    """Evaluate every run in ``df`` per tissue and return a long-format loss table.

    For each tissue, calc_loss is applied to every row (run) of ``df``. The result
    has one row per (tissue, run, output feature), holding:
    the run's original columns; its summary losses (``mean_loss``,
    ``median_loss``, ``85th_percentile_loss``); ``tissue``; and the per-feature
    ``loss`` plus ``feature`` (feature position as a string).

    Args:
        df: Run table from get_mlflow_runs, indexed by ``start_time``.
        subject: Subject id to evaluate on (default 15, the previously hard-coded value).
        tissues: Tissue types to evaluate, in output order.

    Returns:
        Concatenated per-tissue tables, indexed by ``start_time``.
    """
    summary_columns = ["mean_loss", "median_loss", "85th_percentile_loss"]
    n_summary = len(summary_columns)

    tissue_dfs = []
    for tissue in tissues:
        df_loss = df.apply(calc_loss, axis=1, args=[subject, tissue], result_type="expand")
        n_features = len(df_loss.columns) - n_summary
        df_loss.columns = summary_columns + [str(i) for i in range(n_features)]

        run_summary = pd.merge(df, df_loss[summary_columns], on="start_time")
        run_summary["tissue"] = tissue

        feature_losses = _feature_losses_to_long(df_loss.iloc[:, n_summary:])
        tissue_dfs.append(pd.merge(run_summary, feature_losses, on=["start_time"]))

    return pd.concat(tissue_dfs)


def _feature_losses_to_long(df_feature_losses):
    """Reshape a wide (run x feature) loss table into rows of (start_time, loss, feature)."""
    # One small frame per run keeps the original order: runs outer, features inner.
    frames = [
        pd.DataFrame({"loss": losses.to_numpy(), "start_time": start_time, "feature": losses.index})
        for start_time, losses in df_feature_losses.iterrows()
    ]
    return pd.concat(frames).set_index("start_time")


def clean_up_df(df):
    """Return a copy of ``df`` with numeric ``tags_input_size``, sorted by it.

    The input frame is left untouched. Rows with equal input size keep their
    original relative order, because a stable (merge) sort is used.
    """
    numeric = df.assign(tags_input_size=pd.to_numeric(df["tags_input_size"]))
    return numeric.sort_values(by=["tags_input_size"], kind="mergesort")

In [None]:
df_runs_fourier_s2 = add_loss_metrics(df_runs_fourier_s2)
df_runs_fourier_s2 = clean_up_df(df_runs_fourier_s2)

In [None]:
df_runs_fcn = add_loss_metrics(df_runs_fcn)
df_runs_fcn = clean_up_df(df_runs_fcn)

In [None]:
# One row per (run, tissue): add_loss_metrics repeated each run once per feature.
x = df_runs_fcn.drop_duplicates(subset=["artifact_uri", "tissue"]).sort_values(by=["tags_run_group", "tissue"])

latent_size = 250
# Alternative selection: run_group = ("FCN: random samples 1",)
run_group = ("FCN: no regularisation",)


def to_table(row):
    """Print a tissue's (mean, median, 85th percentile) loss in scientific notation."""
    mean, median, p85 = row.mean_loss, row.median_loss, row["85th_percentile_loss"]
    print(f"{row.tissue}: ({mean:.2e}, {median:.2e}, {p85:.2e})")


selected = (x.tags_input_size == latent_size) & (x.tags_run_group.isin(run_group))
x = x.loc[selected, ["tags_run_group", "tissue", "mean_loss", "median_loss", "85th_percentile_loss"]]
# to_table can be used instead: x.apply(to_table, axis=1)

# Table cells joined with " &" separators (presumably for a LaTeX table); the last cell ends with " &".
cells = []
for tissue in ["wb", "gm", "wm", "csf"]:
    row = x[x.tissue == tissue].iloc[0]
    cells.append(f"  ({row.mean_loss:.4f}, {row.median_loss:.4f})")
val = " &\n".join(cells) + " &"

print(val)

In [None]:
# Sort both run tables by run group, then tissue (also fixes the hue order of later plots).
sort_keys = ["tags_run_group", "tissue"]
df_runs_fcn = df_runs_fcn.sort_values(by=sort_keys)
df_runs_fourier_s2 = df_runs_fourier_s2.sort_values(by=sort_keys)

In [None]:
# All evaluated runs (FCN first, then Fourier S2) in one table.
df_runs = pd.concat((df_runs_fcn, df_runs_fourier_s2))

In [None]:
# Mean MSE per latent size, one line per value of params_hidden_layers.
# NOTE(review): hue_order is ["2", "1", "3"] but the legend below relabels those
# handles as "2", "3", "4" — confirm this mapping to "Num. decoder layers" is intended.
point_kws = dict(
    x="params_input_size",
    y="mean_loss",
    hue="params_hidden_layers",
    hue_order=["2", "1", "3"],
    markers=["o", "s", "x"],
    palette="rocket",
    dodge=True,
    ci=None,
)
ax = sns.pointplot(data=df_runs_fcn, **point_kws)
handles, labels = ax.get_legend_handles_labels()

ax.set(xlabel="Latent size", ylabel="MSE")
sns.despine(trim=True, bottom=True)
plt.legend(title="Num. decoder layers", handles=handles, labels=["2", "3", "4"])
plt.tight_layout()

plt.savefig(Path(IMAGES_PATH, "FCN_decoder_sizes.pdf"), bbox_inches="tight")

### Plot average loss

In [None]:
# Per-feature loss distribution per tissue, one panel per latent size (FCN decoder runs).
grid = sns.catplot(
    data=df_runs_fcn,
    kind="box",
    x="tissue",
    y="loss",
    hue="tags_run_group",
    col="tags_input_size",
    col_wrap=2,
    palette="rocket",
)
grid.set(yscale="log")
grid.despine(trim=True, bottom=True)
grid.tight_layout()
grid.set_titles(template="Latent size: {col_name}")
grid.set_ylabels("MSE")
grid.set_xlabels("Tissue type")

# Swap seaborn's legend for a figure-level one with readable names.
# NOTE(review): names are matched to handles by position; check them against the printed labels.
grid.legend.remove()
handles, labels = grid.axes[1].get_legend_handles_labels()
print(labels)
legend_names = ["no regularisation", "regularisation"]
grid.fig.legend(handles, legend_names, loc="upper right", title="FCN Decoder:")
grid.savefig(Path(IMAGES_PATH, "HCP_FCN_loss_boxplot.pdf"), bbox_inches="tight")

In [None]:
# Mean MSE against latent size, one panel per tissue type (FCN decoder runs).
grid = sns.catplot(
    x="tags_input_size",
    y="mean_loss",
    hue="tags_run_group",
    col="tissue",
    col_wrap=2,
    data=df_runs_fcn,
    kind="point",
    sharey=False,
    palette="rocket",
    dodge=True,
    markers=["o", "s", "x", "^"],
    # Was misspelled `ic=None`, which seaborn does not recognise. `ci=None` disables
    # the confidence-interval bars, as in the decoder-size pointplot.
    ci=None,
)
grid.despine(trim=True, bottom=True)
grid.tight_layout()
grid.set_titles(template="Tissue type: {col_name}")
grid.set_ylabels("MSE")
grid.set_xlabels("Latent size")

# Swap seaborn's legend for a figure-level one with readable names.
# NOTE(review): names are matched to handles by position; check them against the printed labels.
grid.legend.remove()
handles, labels = grid.axes[1].get_legend_handles_labels()
print(labels)
grid.fig.legend(
    handles,
    ["no regularisation", "regularisation"],
    loc="upper right",
    title="FCN Decoder:",
)
grid.savefig(Path(IMAGES_PATH, "FCN_loss_latent_size.pdf"), bbox_inches="tight")

In [None]:
# Per-feature loss distribution per tissue, one panel (own y-scale) per latent size (Fourier S2 runs).
grid = sns.catplot(
    data=df_runs_fourier_s2,
    kind="box",
    x="tissue",
    y="loss",
    hue="tags_run_group",
    col="tags_input_size",
    col_wrap=2,
    sharey=False,
    palette="rocket",
)
grid.set(yscale="log")
grid.despine(trim=True, bottom=True)
grid.tight_layout()
grid.set_titles(template="Latent size: {col_name}")
grid.set_ylabels("MSE")
grid.set_xlabels("Tissue type")

# Figure-level legend; names are matched to handles by position (see printed labels).
grid.legend.remove()
handles, labels = grid.axes[1].get_legend_handles_labels()
print(labels)
legend_names = ["regularisation", "no regularisation"]
grid.fig.legend(handles, legend_names, loc="upper right", title="Fourier S2 Decoder:")
grid.savefig(Path(IMAGES_PATH, "HCP_Fourier_loss_boxplot.pdf"), bbox_inches="tight")

In [None]:
# Mean MSE against latent size, one panel per tissue type (Fourier S2 runs).
grid = sns.catplot(
    data=df_runs_fourier_s2,
    kind="point",
    x="tags_input_size",
    y="mean_loss",
    hue="tags_run_group",
    col="tissue",
    col_wrap=2,
    sharey=False,
    palette="rocket",
    dodge=True,
    markers=["o", "s", "x", "^"],
)
grid.despine(trim=True, bottom=True)
grid.tight_layout()
grid.set_titles(template="Tissue type: {col_name}")
grid.set_ylabels("MSE")
grid.set_xlabels("Latent size")

# Figure-level legend; names are matched to handles by position (see printed labels).
grid.legend.remove()
handles, labels = grid.axes[1].get_legend_handles_labels()
print(labels)
legend_names = ["regularisation", "random samples", "no regularisation"]
grid.fig.legend(handles, legend_names, loc="upper right", title="Fourier S2 Decoder:")
grid.savefig(Path(IMAGES_PATH, "Fourier_loss_latent_size.pdf"), bbox_inches="tight")

In [None]:
# Mean MSE against latent size for every FCN and Fourier S2 run group, one panel per tissue.
all_runs = pd.concat([df_runs_fcn, df_runs_fourier_s2])
grid = sns.catplot(
    data=all_runs,
    kind="point",
    x="tags_input_size",
    y="mean_loss",
    hue="tags_run_group",
    col="tissue",
    col_wrap=2,
    sharey=False,
    palette="rocket",
    dodge=True,
    markers=["o", "s", "x", "^"],
)
grid.despine(trim=True, bottom=True)
grid.tight_layout()
grid.set_titles(template="Tissue type: {col_name}")
grid.set_ylabels("MSE")
grid.set_xlabels("Latent size")

# Keep seaborn's run-group names, but show them in a figure-level legend.
grid.legend.remove()
handles, labels = grid.axes[1].get_legend_handles_labels()
print(labels)
grid.fig.legend(handles, labels, loc="upper right", title="Decoder:")
grid.savefig(Path(IMAGES_PATH, "HCP_loss_latent_size.pdf"), bbox_inches="tight")

### Plot feature occurrence count

Some features occur multiple times in a single model. Let's plot the most frequently occurring features (the top 10 per model, the default of `plot_counts`).

In [None]:
def get_feat_count(row):
    """Count how often each input feature is selected by a run's encoder.

    Loads ``<artifact_uri>/model``, takes the argmax over the last axis of
    ``model.encoder.logits`` and bincounts the result.

    Returns:
        DataFrame with a single ``count`` column; row position = feature index.
    """
    model = mlflow.pytorch.load_model(row.artifact_uri + "/model")
    logits = model.encoder.logits
    # detach().cpu() so np.bincount also works when the logits live on the GPU.
    features = torch.argmax(logits, dim=-1).detach().cpu().numpy()
    return pd.DataFrame(np.bincount(features), columns=["count"])


def get_feat_counts(data):
    """Sum the per-feature selection counts of one run or of every run in a table.

    Args:
        data: A DataFrame of runs (each row is passed to get_feat_count) or a
            single run row (Series).

    Returns:
        DataFrame with columns ``index`` (feature index) and ``count`` (summed over
        runs). It is empty, with the same columns, when ``data`` has no rows.
    """
    if isinstance(data, pd.DataFrame):
        counts_dfs = [get_feat_count(row) for _, row in data.iterrows()]
    else:  # assume it is a single run (Series)
        counts_dfs = [get_feat_count(data)]

    if not counts_dfs:
        # pd.concat([]) would raise; an empty table is the natural "no runs" answer.
        return pd.DataFrame(columns=["index", "count"])

    # Align the bincounts on their row position (= feature index) and add them up.
    return pd.concat(counts_dfs).groupby(level=0).sum().reset_index()


# Selection counts of the first run as a (index, count) table.
# NOTE(review): `df` is not defined anywhere in this notebook, so this raises NameError
# as written; it presumably should be one of the run tables (e.g. df_runs) — confirm.
get_feat_count(df.iloc[0]).reset_index()

In [None]:
import torch.nn.functional as F


def feature_count(row):
    """Print a run's regularisation diagnostics and return its selected features.

    Prints ``params_lambda_reg``, ``metrics_val_loss`` and a recomputed
    regularisation term.

    Returns:
        DataFrame with one ``feature`` row per encoder output (argmax over the
        last axis of ``model.encoder.logits``).
    """
    print("lambda:", row.params_lambda_reg)
    print("val loss:", row.metrics_val_loss)

    model = mlflow.pytorch.load_model(row.artifact_uri + "/model")
    logits = model.encoder.logits
    # .cpu() before .numpy(): a GPU-resident tensor cannot be converted directly.
    features = torch.argmax(logits, dim=-1).detach().cpu().numpy()

    eps = 1e-10
    threshold = 3.0
    # NOTE(review): the softmax is over dim 0 while the feature argmax is over the
    # last dim — presumably mirrors the training-time regulariser; confirm.
    selection = torch.clamp(F.softmax(logits, dim=0), eps, 1)
    print("reg term:", torch.sum(F.relu(torch.norm(selection, 1, dim=1) - threshold)))

    return pd.DataFrame(features, columns=["feature"])


# Histogram of the features selected by the third run.
# NOTE(review): `df` is not defined anywhere in this notebook (NameError as written); it
# presumably needs params_lambda_reg / metrics_val_loss / artifact_uri columns — confirm.
# sns.set is an alias of set_theme with default arguments, so this also resets the global
# style (to seaborn's default) for every later plot, not just the figure size.
sns.set(rc={"figure.figsize": (40, 4)})
sns.countplot(data=feature_count(df.iloc[2]), x="feature");

In [None]:
# Histogram of the features selected by the second run.
# NOTE(review): `df` is not defined anywhere in this notebook (NameError as written).
sns.countplot(data=feature_count(df.iloc[1]), x="feature");

In [None]:
# Histogram of the features selected by the first run.
# NOTE(review): `df` is not defined anywhere in this notebook (NameError as written).
sns.countplot(data=feature_count(df.iloc[0]), x="feature");

In [None]:
# NOTE(review): `df_counts_total` is never assigned in this notebook, so this cell raises NameError.
df_counts_total

In [None]:
def show_values_on_bars(axs, h_v="v", space=0.4):
    """Annotate every bar of one axes (or a numpy array of axes) with its integer size.

    Code from https://stackoverflow.com/a/56780852/6131485

    Args:
        axs: A single matplotlib axes or a numpy array of axes.
        h_v: "v" for vertical bars (label centred on the bar top), "h" for
            horizontal bars (white label right-aligned inside the bar end).
            Any other value draws nothing.
        space: For "h", offset of the label from the bar end, in data units.
    """

    def _label_vertical(ax):
        for patch in ax.patches:
            x_pos = patch.get_x() + patch.get_width() / 2
            y_pos = patch.get_y() + patch.get_height()
            ax.text(x_pos, y_pos, int(patch.get_height()), ha="center")

    def _label_horizontal(ax):
        for patch in ax.patches:
            x_pos = patch.get_x() + patch.get_width() - float(space)
            y_pos = patch.get_y() + patch.get_height() - 0.2
            ax.text(x_pos, y_pos, int(patch.get_width()), ha="right", c="white")

    if h_v == "v":
        label_bars = _label_vertical
    elif h_v == "h":
        label_bars = _label_horizontal
    else:
        return

    all_axes = axs.flat if isinstance(axs, np.ndarray) else [axs]
    for ax in all_axes:
        label_bars(ax)

In [None]:
def plot_counts(df, top_size=10, ax=None, title=None):
    """Horizontal bar chart of the ``top_size`` rows of ``df`` with the largest ``count``.

    ``df`` needs ``index`` (bar label) and ``count`` columns, as returned by
    get_feat_counts. Bars are drawn largest first and annotated with their count.
    When ``ax`` is given, a light-grey grid is also drawn on it.
    """
    top_counts = df.sort_values(by="count", ascending=False)[:top_size]
    bar_ax = sns.barplot(
        data=top_counts,
        x="count",
        y="index",
        orient="h",
        order=top_counts["index"].values,
        palette="rocket",
        ax=ax,
    )
    bar_ax.set(xlabel=None, ylabel=None, title=title)
    show_values_on_bars(bar_ax, "h")
    if ax is None:
        return
    ax.grid(True, which="both", ls="-", c="lightgray")


# 5 x 3 grid of bar charts: the 10 most frequently selected volumes (plot_counts default)
# for each of the first 15 rows of `df`, saved to feature_count.pdf.
# NOTE(review): `df` is not defined anywhere in this notebook, and its rows must provide
# artifact_uri, n_features and decoder — confirm the intended run table.
# NOTE(review): df.iloc[i] raises IndexError if `df` has fewer than 15 rows.
fig, axes = plt.subplots(5, 3, figsize=(5 * 3, 5 * 5))
# fig.suptitle("Volume counts for each model", x=0.5, y=1)
fig.text(0.5, -0.01, "Count", ha="center")
fig.text(-0.01, 0.5, "Volume", va="center", rotation="vertical")

for i, ax in enumerate(axes.flatten()):
    row = df.iloc[i]
    df_counts = get_feat_counts(row)
    plot_counts(df_counts, ax=ax, title=f"feat={row['n_features']} decoder={row['decoder']}")

sns.despine(left=True)

plt.tight_layout()

image_path = Path(IMAGES_PATH, "feature_count.pdf")
plt.savefig(image_path, bbox_inches="tight");

In [None]:
# Same 5 x 3 grid for the next 15 rows of `df` (positions 15-29), saved to feature_count_exclude.pdf.
# NOTE(review): `df` is not defined anywhere in this notebook — confirm the intended run table.
fig, axes = plt.subplots(5, 3, figsize=(5 * 3, 5 * 5))
fig.suptitle("Feature counts for each model", x=0.5, y=0.9)
fig.text(0.5, 0.1, "Count", ha="center")
fig.text(0.07, 0.5, "Feature", va="center", rotation="vertical")

for i, ax in enumerate(axes.flatten()):
    try:
        row = df.iloc[15 + i]
        df_counts = get_feat_counts(row)
        plot_counts(df_counts, ax=ax, title=f"feat={row['n_features']} decoder={row['decoder']}")
    except IndexError:
        # Fewer than 30 rows: leave the remaining panels empty.
        # NOTE(review): this also silently swallows any IndexError raised while loading or plotting a model.
        continue
image_path = Path(IMAGES_PATH, "feature_count_exclude.pdf")
plt.savefig(image_path, bbox_inches="tight");

### Interactive model plot

In [None]:
from nilearn import image, masking

In [None]:
import matplotlib.gridspec as gridspec
import plotly
import plotly.express as px
import plotly.graph_objects as go
from bokeh.io import output_notebook, show
from bokeh.layouts import column, row
from bokeh.models import (
    ColorBar,
    ColumnDataSource,
    LinearColorMapper,
    LogColorMapper,
    PreText,
    RadioButtonGroup,
    Select,
    Slider,
    Spinner,
)
from bokeh.plotting import figure
from plotly.subplots import make_subplots

# Send bokeh output (e.g. show(...)) inline into this notebook.
output_notebook()

In [None]:
# One row per run (artifact): add_loss_metrics repeated each run per tissue and feature.
df_runs_fcn_f, df_runs_fourier_f = (
    runs.drop_duplicates(subset=["artifact_uri"]) for runs in (df_runs_fcn, df_runs_fourier_s2)
)


def get_predict_mse(df, latent_size, run_group):
    """Return the per-voxel MSE volume (subject 15) of the first matching run.

    ``latent_size`` may be an int or a numeric string. It is compared numerically
    with ``tags_input_size`` (which clean_up_df converts to numbers), so both
    ``"500"`` and ``500`` match. Comparing the str against the numeric column
    previously never matched.

    Raises:
        IndexError: if no run has this latent size and run group.
    """
    is_match = (pd.to_numeric(df.tags_input_size) == pd.to_numeric(latent_size)) & (
        df.tags_run_group == run_group
    )
    matches = df[is_match]
    if matches.empty:
        raise IndexError(f"No run with latent size {latent_size!r} in run group {run_group!r}")
    # NOTE: load_dmri is defined in a later cell; run that cell before calling this.
    _, _, predict_mse = load_dmri(matches.iloc[0], 15, return_mse=True)
    return predict_mse


def plot_imshow_rgb(axes, data, do_rot90: bool):
    """Show a colour slice on ``axes``, transparent wherever channel 1 is zero.

    ``data`` is presumably (H, W, 3) so that, with the added alpha channel, it
    forms an RGBA image — confirm against the caller.
    """
    data = np.rot90(data) if do_rot90 else data
    # Append an alpha channel: opaque where channel 1 is non-zero, transparent elsewhere.
    masked = np.dstack([data, ~(data[..., 1] == 0)])
    axes.imshow(masked, interpolation="none")
    # Hide ticks and frame on *this* axes; plt.axis("off") targets the global current axes.
    axes.axis("off")

def plot_heatmap(axes, data, mask, do_rot90: bool, cbar_axes=None):
    """Draw ``data`` as a heatmap on ``axes``, hiding cells where ``mask == 0``.

    The colour range is fixed to [0, 0.06] so panels are comparable. A horizontal
    colour bar is drawn into ``cbar_axes`` only when one is given.
    """
    if do_rot90:
        data, mask = np.rot90(data), np.rot90(mask)
    draw_cbar = cbar_axes is not None
    sns.heatmap(
        data,
        mask=mask == 0,
        square=True,
        xticklabels=False,
        yticklabels=False,
        ax=axes,
        vmin=0,
        vmax=0.06,
        cbar=draw_cbar,
        cbar_ax=cbar_axes,
        cbar_kws={"orientation": "horizontal"},
    )


fig = plt.figure(figsize=(19, 15))
# Outer 2x2 layout:
# - top row: per-voxel MSE maps (FCN left, Fourier S2 right);
# - bottom-left: tissue overlay;
# - bottom-right: the shared colour bar.
gs0 = gridspec.GridSpec(2, 2, height_ratios=[4, 1], figure=fig)
gs00 = gs0[0, 0].subgridspec(4, 3)

latent_sizes = [500, 250, 100, 50]

brain_mask = image.load_img("/media/maarten/disk1/MUDI/cdmri0015/brain_mask.nii.gz")
brain_mask = np.asanyarray(brain_mask.dataobj)

# One row per latent size. Columns: middle of axis 0, middle of axis 1 (both rotated),
# index 40 of axis 2. Latent sizes are passed as ints because tags_input_size is
# numeric after clean_up_df.
for i in range(gs00.nrows):
    data = get_predict_mse(df_runs_fcn_f, latent_sizes[i], "FCN: no regularisation")
    s = data.shape

    plot_heatmap(fig.add_subplot(gs00[i, 0]), data[s[0] // 2], brain_mask[s[0] // 2], True)
    plot_heatmap(fig.add_subplot(gs00[i, 1]), data[:, s[1] // 2], brain_mask[:, s[1] // 2], True)
    plot_heatmap(fig.add_subplot(gs00[i, 2]), data[:, :, 40], brain_mask[:, :, 40], False)


gs01 = gs0[0, 1].subgridspec(4, 3)
gs03 = gs0[1, 1].subgridspec(5, 1)
# Create the colour-bar axes once, outside the loop, so only one colour bar is drawn.
cbar_axes = fig.add_subplot(gs03[-1, 0])

for i in range(gs01.nrows):
    data = get_predict_mse(df_runs_fourier_f, latent_sizes[i], "Fourier S2: no regularisation")
    s = data.shape

    plot_heatmap(fig.add_subplot(gs01[i, 0]), data[s[0] // 2], brain_mask[s[0] // 2], True)
    plot_heatmap(fig.add_subplot(gs01[i, 1]), data[:, s[1] // 2], brain_mask[:, s[1] // 2], True)
    # All heatmaps share vmin/vmax, so the last panel alone draws the colour bar.
    is_last_row = i == gs01.nrows - 1
    plot_heatmap(
        fig.add_subplot(gs01[i, 2]),
        data[:, :, 40],
        brain_mask[:, :, 40],
        False,
        cbar_axes if is_last_row else None,
    )

gs02 = gs0[1, 0].subgridspec(1, 3)

for i in range(gs02.nrows):
    mask_3tt = image.get_data(image.load_img("/media/maarten/disk1/MUDI/cdmri0015/3tt.nii"))
    s = mask_3tt.shape

    plot_imshow_rgb(fig.add_subplot(gs02[i, 0]), mask_3tt[s[0] // 2], True)
    plot_imshow_rgb(fig.add_subplot(gs02[i, 1]), mask_3tt[:, s[1] // 2], True)
    plot_imshow_rgb(fig.add_subplot(gs02[i, 2]), mask_3tt[:, :, 40], False)

plt.tight_layout()
plt.savefig(Path(IMAGES_PATH, "cdb_mse_voxel.png"))

In [None]:
def load_dmri(row, subject=15, return_mse=False):
    """Predict one subject with a run's model and unmask the results into volumes.

    Args:
        row: Run row passed to predict (uses predict's default tissue, "wb").
        subject: Subject number. The brain mask is read from
            ``/media/maarten/disk1/MUDI/cdmri<subject as 4 digits>/brain_mask.nii.gz``.
        return_mse: Also return the per-voxel squared error averaged over all features.

    Returns:
        ``(target_img, prediction_img)``, each a 4-D array with features on the
        last axis. When ``return_mse`` is true, a third 3-D array
        ``prediction_mse_img`` is appended.
    """
    target, prediction = predict(subject, row)

    # Zero-pad to four digits (cdmri0015); "cdmri00{subject}" was only right for two-digit ids.
    mask_path = f"/media/maarten/disk1/MUDI/cdmri{int(subject):04d}/brain_mask.nii.gz"

    target_img = masking.unmask(np.transpose(target.numpy()), mask_path)
    target_img = image.get_data(target_img)
    prediction_img = masking.unmask(np.transpose(prediction.numpy()), mask_path)
    prediction_img = image.get_data(prediction_img)

    if not return_mse:
        return target_img, prediction_img

    # Average the squared error over features per voxel *before* unmasking. Only a 3-D
    # volume is then materialised, not a full 4-D (x, y, z, feature) error image.
    voxel_mse = torch.nn.functional.mse_loss(target, prediction, reduction="none").mean(dim=1)
    prediction_mse_img = image.get_data(masking.unmask(voxel_mse.numpy(), mask_path))

    return target_img, prediction_img, prediction_mse_img

In [None]:
df_runs_fcn_f = df_runs_fcn.drop_duplicates(subset=["artifact_uri"])
# tags_input_size is numeric after clean_up_df, so compare with the number 500; the
# string "500" never matched. to_numeric keeps this working if the tag is still a string.
df_runs_fcn_f = df_runs_fcn_f[pd.to_numeric(df_runs_fcn_f.tags_input_size) == 500]

df_runs_fcn_f

In [None]:
df_runs_fourier_f = df_runs_fourier_s2.drop_duplicates(subset=["artifact_uri"])
# Numeric comparison: tags_input_size is numeric after clean_up_df ("500" never matched).
df_runs_fourier_f = df_runs_fourier_f[pd.to_numeric(df_runs_fourier_f.tags_input_size) == 500]
df_runs_fourier_f

In [None]:
# Keep the ground truth and both 4-D predictions, which the slice comparison plot
# (target_img / prediction_img_fcn / prediction_img_fourier) needs, plus the
# per-voxel MSE maps for the heatmaps.
# NOTE: each 4-D volume holds every feature, so this can use several GB of RAM.
target_img, prediction_img_fcn, prediction_img_fcn_mse = load_dmri(df_runs_fcn_f.iloc[-2], 15, return_mse=True)
_, prediction_img_fourier, prediction_img_fourier_mse = load_dmri(df_runs_fourier_f.iloc[1], 15, return_mse=True)

In [None]:
x = 22

# Expects target_img, prediction_img_fcn and prediction_img_fourier (4-D, features on the
# last axis) from an earlier cell. Shows feature 0 of slice `x` along axis 0, rotated.
target = np.rot90(target_img[x, :, :, 0])
predict_fcn = np.rot90(prediction_img_fcn[x, :, :, 0])
predict_fourier = np.rot90(prediction_img_fourier[x, :, :, 0])

# Drop the columns that are all zero in the ground truth, from every panel.
idx = np.argwhere(np.all(target == 0, axis=0))
target, predict_fcn, predict_fourier = (
    np.delete(panel, idx, axis=1) for panel in (target, predict_fcn, predict_fourier)
)

fig = px.imshow(np.array([target, predict_fcn, predict_fourier]), facet_col=0, template="seaborn")
panel_titles = [
    "Ground truth",
    "Prediction: FCN Decoder<br>(latent size=500, no regularisation)",
    "Prediction: Fourier S2 Decoder<br>(latent size=500, no regularisation)",
]
for annotation, title in zip(fig.layout.annotations, panel_titles):
    annotation["text"] = title
fig.update_layout(width=800, height=350, margin=dict(l=10, r=10, t=60, b=10))
fig.update_xaxes(showticklabels=False)
fig.update_yaxes(showticklabels=False)
fig.show()

In [None]:
# dir(image)

In [None]:
s = prediction_img_fcn_mse.shape

# Rows 0-1: per-voxel MSE maps of the two decoders; row 2: thin row for the colour bar.
grid_kws = {"height_ratios": (0.45, 0.45, 0.05), "hspace": 0.1}
fig, axes = plt.subplots(3, 3, figsize=(15, 10), gridspec_kw=grid_kws)

# (per-voxel MSE volume, title above the middle panel) for each decoder row.
mse_panels = [
    (prediction_img_fcn_mse, "FCN decoder: no regularisation, latent size=500"),
    (prediction_img_fourier_mse, "Fourier S2 decoder: no regularisation, latent size=500"),
]
# plot_heatmap replaces six copy-pasted sns.heatmap calls with identical settings.
for row_idx, (mse_volume, title) in enumerate(mse_panels):
    is_last_row = row_idx == len(mse_panels) - 1
    # Views: middle of axis 0, middle of axis 1 (both rotated), index 40 of axis 2.
    plot_heatmap(axes[row_idx, 0], mse_volume[s[0] // 2], brain_mask[s[0] // 2], True)
    plot_heatmap(axes[row_idx, 1], mse_volume[:, s[1] // 2], brain_mask[:, s[1] // 2], True)
    plot_heatmap(
        axes[row_idx, 2],
        mse_volume[:, :, 40],
        brain_mask[:, :, 40],
        False,
        # Shared colour range, so the last panel alone draws the horizontal colour bar.
        axes[2, 1] if is_last_row else None,
    )
    axes[row_idx, 1].set(title=title)

# Only the middle cell of the bottom row holds the colour bar.
axes[2, 0].remove()
axes[2, 2].remove()

plt.savefig(Path(IMAGES_PATH, "mse_voxel.png"))

In [None]:
def bkapp(doc):
    """Bokeh app comparing ground-truth and predicted dMRI slices side by side.

    Widgets select:
    - the model (one entry per MLflow run in ``df_runs``);
    - the slicing axis;
    - the slice index;
    - the feature (volume) displayed.

    Both panels share axis ranges and a log colour map whose upper bound follows
    the ground truth's maximum.

    Args:
        doc: Bokeh document the layout is added to.
    """
    # df_runs has one row per (run, tissue, feature) after add_loss_metrics;
    # the model picker needs exactly one entry per run.
    model_runs = df_runs.drop_duplicates(subset=["run_id"])

    def load_images(model_index):
        # Was `load_drmi(0)`: an undefined name, and load_dmri expects a run row, not an index.
        return load_dmri(model_runs.iloc[model_index])

    target_img, prediction_img = load_images(0)
    source = ColumnDataSource(dict(target=[], prediction=[]))

    x_max = target_img.shape[0] - 1
    y_max = target_img.shape[1] - 1
    z_max = target_img.shape[2] - 1
    max_values = [x_max, y_max, z_max]

    color_map = LogColorMapper(palette="Greys256", low=0.01, high=255)

    target_fig = figure(
        title="Truth",
        tooltips=[("X", "$sx"), ("Y", "$sy"), ("Value", "@target")],
        toolbar_location="below",
        output_backend="webgl",
    )
    target_fig.image(image="target", source=source, x=0, y=0, dw=10, dh=10, color_mapper=color_map)

    prediction_fig = figure(
        title="Prediction",
        tooltips=[("X", "$sx"), ("Y", "$sy"), ("Value", "@prediction")],
        x_range=target_fig.x_range,
        y_range=target_fig.y_range,
        toolbar_location="below",
        output_backend="webgl",
    )
    prediction_fig.image(
        image="prediction",
        source=source,
        x=0,
        y=0,
        dw=10,
        dh=10,
        color_mapper=color_map,
    )
    color_bar = ColorBar(color_mapper=color_map, label_standoff=12)
    prediction_fig.add_layout(color_bar, "right")

    options = list(
        zip(
            np.arange(len(model_runs)).astype(str),
            list(model_runs.run_id),
        )
    )

    model_select = Select(
        title="Model:",
        value="0",
        options=options,
    )
    slice_slider = Slider(start=0, end=x_max, value=0, step=1, title="Slice")
    # Last valid feature index comes from the loaded volume (features on the last axis).
    feature_slider = Spinner(low=0, high=target_img.shape[3] - 1, value=0, step=1, title="Feature:")
    axis_radio = RadioButtonGroup(labels=["X", "Y", "Z"], active=0)

    def model_update():
        # nonlocal so update() sees the new model; plain assignment only created
        # locals, and switching model had no visible effect.
        nonlocal target_img, prediction_img
        target_img, prediction_img = load_images(int(model_select.value))
        update()

    def update():
        # Widget values may arrive from the browser as floats; numpy indices must be ints.
        feature_value = int(feature_slider.value)
        axis_value = axis_radio.active

        slice_slider.end = max_values[axis_value]
        if slice_slider.value > slice_slider.end:
            slice_slider.value = slice_slider.end

        slice_value = int(slice_slider.value)

        color_map.high = np.max(target_img)

        if axis_value == 0:  # X
            source.data = dict(
                target=[target_img[slice_value, :, :, feature_value]],
                prediction=[prediction_img[slice_value, :, :, feature_value]],
            )
        elif axis_value == 1:  # Y
            source.data = dict(
                target=[target_img[:, slice_value, :, feature_value]],
                prediction=[prediction_img[:, slice_value, :, feature_value]],
            )
        elif axis_value == 2:  # Z
            source.data = dict(
                target=[target_img[:, :, slice_value, feature_value]],
                prediction=[prediction_img[:, :, slice_value, feature_value]],
            )

    model_select.on_change("value", lambda attr, old, new: model_update())
    slice_slider.on_change("value", lambda attr, old, new: update())
    feature_slider.on_change("value", lambda attr, old, new: update())
    axis_radio.on_change("active", lambda attr, old, new: update())

    layout = row(
        column(model_select, axis_radio, feature_slider, slice_slider),
        target_fig,
        prediction_fig,
    )
    # Populate the plots immediately instead of waiting for the first widget change.
    update()
    doc.add_root(layout)

In [None]:
# Allow the notebook server's origin to open the embedded Bokeh app's websocket.
os.environ.update(BOKEH_ALLOW_WS_ORIGIN="127.0.0.1:8888")
show(bkapp)

In [None]:
# NOTE(review): repeats show(bkapp) from the previous cell; presumably displays a second,
# independent instance of the app — remove if unintended.
show(bkapp)