In [None]:
import torch
from data_module import DataModule
from utility.utility import get_args
import numpy as np
from models.nodf import NODF
import nibabel as nib
from models.posterior import FVRF
from utility.utility import get_mask

# Reload edited project modules automatically before each cell executes,
# so changes to data_module / models / utility are picked up without a restart.
%load_ext autoreload
%autoreload 2

## Args

In [None]:
# Build the experiment configuration used by the data module and model below.
# NOTE(review): `cmd=False` presumably skips parsing the notebook process's own
# command line and returns the default arguments — confirm in utility.get_args.
args = get_args(cmd=False)
# TODO: modify arguments here if needed
# args.ckpt_path = "..."

## Data

In [None]:
# Construct the data module from the configuration and set it up for the
# "fit" stage; presumably this is required before `.dataset` and
# `.train_dataloader()` are used in the next cell — confirm in DataModule.
data_module = DataModule(args)
data_module.setup("fit")

In [None]:
# Grab the dataset and the training dataloader, then inspect the coordinates
# stored on the dataloader's dataset.
dataset = data_module.dataset
dataloader = data_module.train_dataloader()
coords = dataloader.dataset.coords
# Shown as the cell output; presumably one row per voxel coordinate — confirm.
coords.shape

In [None]:
# Pull a single training batch (used for the forward pass below) and display it.
# The temporary iterator is not kept, so any DataLoader workers can shut down.
batch = next(iter(dataloader))
batch

## Model

In [None]:
# Start from a freshly initialised model unless a checkpoint path is
# configured, in which case restore the trained weights from it.
if not args.ckpt_path:
    model = NODF(args)
else:
    print("Loading model from checkpoint")
    # Keep the restored model on the CPU, presumably so the exploratory forward
    # pass below runs alongside the CPU-side batch.
    model = NODF.load_from_checkpoint(args.ckpt_path).cpu()

In [None]:
# Display the model's repr (for a torch module this lists its submodules).
model

In [None]:
# Report the model's parameter count via its `count_parameters` helper.
model.count_parameters()

## Forward Pass

In [None]:
# Forward the sample batch through the model; the output holds the ODF
# coefficients (presumably "chat" = c-hat, the estimated coefficients).
# NOTE(review): this exploratory pass records an autograd graph; wrapping it in
# `torch.no_grad()` would save memory if NODF's forward does not itself rely on
# autograd — confirm before changing.
chat = model(batch)
chat.shape

## Posterior

In [None]:
# posterior = FVRF(args)

In [None]:
# Restrict the brain mask to a box-shaped region of interest (ROI); presumably
# this limits where posterior samples are drawn (see `sample_posterior(mask)`).
# Bounds are half-open [start, stop) index ranges, one pair per array axis of
# the mask returned by `get_mask`; voxels outside the box are set to False.
# NOTE(review): the original comments labelled axis 0 "axial", axis 1
# "sagittal" and axis 2 "coronal" — confirm against the image orientation.
ROI_BOUNDS = (
    (168, 169),  # axis 0 ("axial"): a single slice
    (74, 88),    # axis 1 ("sagittal")
    (67, 85),    # axis 2 ("coronal")
)

mask = get_mask(args)
for axis, (start, stop) in enumerate(ROI_BOUNDS):
    # Select everything along the preceding axes, then clear the voxels that
    # fall before and after the ROI on the current axis.
    leading = (slice(None),) * axis
    mask[leading + (slice(None, start),)] = False
    mask[leading + (slice(stop, None),)] = False

In [None]:
# generate posterior samples
# post_samples_chat = posterior.sample_posterior(mask)

## Bayesian Optimization

In [None]:
from ax.service.ax_client import AxClient
from ax.plot.contour import plot_contour
from ax.plot.trace import optimization_trace_single_method
from ax.service.ax_client import AxClient
from ax.utils.notebook.plotting import render, init_notebook_plotting
from ax.utils.tutorials.cnn_utils import load_mnist, train, evaluate

# Set up Ax's notebook plotting so the `render(...)` calls below display
# their figures inline.
init_notebook_plotting()

In [None]:
# TODO: set path
# JSON snapshot of an AxClient (as written by `AxClient.save_to_json_file`).
bo_client_path = "XXX"
if bo_client_path == "XXX":
    # Fail with an actionable message instead of an opaque "file not found"
    # for the placeholder path.
    raise ValueError(
        "Set `bo_client_path` to the saved AxClient JSON file before running this cell."
    )
map_ax_client = AxClient.load_from_json_file(filepath=bo_client_path)

In [None]:
# Show the best trial found by the optimisation. In recent Ax versions this is
# a tuple (trial index, parameters, model-predicted metrics) or None.
map_ax_client.get_best_trial()

In [None]:
# Fit the generation strategy's surrogate model, which predicts the objective
# from the hyperparameters; the contour plot below needs a fitted model.
map_ax_client.fit_model()

In [None]:
# Plot the surrogate model's predicted objective as a contour over two
# hyperparameters: `n_levels` (x-axis) vs `base_resolution` (y-axis).
render(map_ax_client.get_contour_plot(param_x='n_levels', param_y='base_resolution'))

In [None]:
# Plot the best objective value found so far against the trial index.
# `experiment.trials` is keyed by trial index, so values() are in trial order.
objective = map_ax_client.experiment.optimization_config.objective
trial_objectives = np.array(
    [[trial.objective_mean for trial in map_ax_client.experiment.trials.values()]]
)
# Raw per-trial values are not a "best so far" curve: take the running best in
# the experiment's optimisation direction (cumulative min when minimising).
running_best = np.minimum if objective.minimize else np.maximum
best_objectives = running_best.accumulate(trial_objectives, axis=1)
best_objective_plot = optimization_trace_single_method(
    y=best_objectives,
    title="Model performance vs. # of iterations",
    # Label with the actual objective metric instead of a hard-coded "Accuracy".
    ylabel=objective.metric.name,
)
render(best_objective_plot)