In [None]:
import os, sys

# Resolve the notebook's own src/ folder and the repository-level src/ folder
# (two directories above the current working directory).
cwd = os.getcwd()
parent_dir = os.path.dirname(cwd)
base_dir = os.path.dirname(parent_dir)
main_src_dir = os.path.join(base_dir, "src")
file_src_dir = os.path.join(cwd, "src")

# Prepend each directory at most once. Every insert goes to index 0, so
# file_src_dir (handled last) ends up ahead of main_src_dir on sys.path and
# wins if both folders define a module with the same name.
for src_dir in [main_src_dir, file_src_dir]:
    if src_dir in sys.path:
        continue
    sys.path.insert(0, src_dir)

from evaluate import ModelPhantom, EvaluationPackage, metrics # From main src dir 
import eval_grad_plots

## Define paths

In [None]:
#############################
######## Which eMNS? ########
#############################

emns = "octomag" # "octomag" or "navion"

# Directory holding the evaluation packages for the selected eMNS. The trailing
# separator is kept because later cells build file paths by concatenating
# package_dir with a file name.
if emns == "octomag":
    package_dir = os.path.join(cwd, "evaluation_packages", "")
elif emns == "navion":
    package_dir = os.path.join(cwd, "evaluation_packages", "navion", "")
else:
    # Without this branch an unknown value leaves package_dir undefined on a
    # fresh kernel, or silently stale from a previous run of this cell, which
    # would load the wrong eMNS's data.
    raise ValueError(f'Unknown eMNS "{emns}"; expected "octomag" or "navion".')

test_package_name = "test_eval_pack.pkl"
training_package_name = "train_eval_pack.pkl"

## Load packages

In [None]:
test_package, training_package = None, None

try:
    print("Loading test set evaluation packages...")
    test_package = EvaluationPackage.load_from(package_dir + test_package_name)
    print("Successfully loaded test set evaluation packages.\n")
except Exception as e:
    print(f"Error loading test set evaluation package: {e}\n")

try:
    print("Loading training set evaluation packages...")
    training_package = EvaluationPackage.load_from(package_dir + training_package_name)
    print("Successfully loaded training set evaluation packages.\n")
except Exception as e:
    print(f"Error loading training set evaluation package: {e}\n")

## Define the gradient metrics to compute and the models to focus on

In [None]:
###############################
####### Which metrics? ########
###############################

# Gradient-based metrics applied to the test set in a later cell.
metrics_list = [metrics.grad_curl_div]

##############################
####### Which models? ########
##############################

# Every family is evaluated at three architectures on 100% of the data; the
# first (largest) architecture of each family is additionally evaluated on
# reduced dataset fractions. Dict order fixes the order of `models`.
model_structures = {
    "ActuationNet": [(512, 512, 512), (256, 256, 256), (256, 256)],
    "PotentialNet": [(512, 512, 512), (256, 256, 256), (256, 256)],
    "DirectNet": [(512, 512, 512), (256, 256, 256), (256, 256)],
    "DirectGBT": [(128,), (64,), (32,)],
}
reduced_percentages = [50, 20, 5, 1]

models = []
for family, structures in model_structures.items():
    models.extend(
        ModelPhantom(name=family, dataset_percentage=100, structure=structure)
        for structure in structures
    )
    models.extend(
        ModelPhantom(name=family, dataset_percentage=percentage, structure=structures[0])
        for percentage in reduced_percentages
    )

# Largest architectures: (512, 512, 512) for the nets, (128,) for the GBT.
largest_structures = ((512, 512, 512), (128,))

full_models = [model for model in models if model.dataset_percentage == 100]
large_models = [model for model in models if model.structure in largest_structures]
large_full_models = [model for model in large_models if model.dataset_percentage == 100]

## Compute gradient metrics on the test set

In [None]:
# Compute each gradient metric on the test set.
# The loading cell only prints a message on failure and leaves test_package as
# None, so fail here with an explicit message rather than an opaque
# AttributeError on NoneType.
if test_package is None:
    raise RuntimeError(
        "Test set evaluation package is not loaded; check the output of the "
        "'Load packages' cell."
    )

for metric_function in metrics_list:
    # getattr fallback: callables such as functools.partial have no __name__.
    metric_name = getattr(metric_function, "__name__", repr(metric_function))
    print(f"Computing metric function: {metric_name}...\n")
    # NOTE(review): apply_gradient_metric presumably stores results on
    # test_package itself — confirm re-running this cell is harmless.
    test_package.apply_gradient_metric(metric_function)

## Plot divergence and curl

### Largest models trained on the full dataset

In [None]:
# Divergence/curl plot on the test set, restricted to the largest architecture
# of each model family trained on 100% of the data.
eval_grad_plots.div_curl_plot(test_package, large_full_models, title="", verbose=False)