In [10]:
import os

# Expose only physical GPU 5 to this process. CUDA reads this variable when it initialises,
# so it has to be set before torch is imported (next cell). Adjust the index for your machine.
os.environ.update(CUDA_VISIBLE_DEVICES="5")

In [11]:
import torch
import ast
import numpy as onp
import seaborn as sns
from matplotlib import pyplot as plt

from Problems.CNOBenchmarks import Darcy, Airfoil, DiscContTranslation, ContTranslation, AllenCahn, SinFrequency, WaveEquation, ShearLayer
from Problems.PDEArenaBenchmarks import StandardNavierStokes

In [12]:
# Reset to Matplotlib's defaults, then render all figure text with LaTeX in a serif font.
# plt.rcdefaults() leaves style-blacklisted keys such as ``backend`` alone, so -- unlike
# plt.rcParams.update(plt.rcParamsDefault) -- it does not undo the notebook's inline backend.
# NOTE: text.usetex=True requires a working LaTeX installation on the machine running this.
plt.rcdefaults()
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "text.latex.preamble": r"\usepackage{amsmath}",
})

In [13]:
# Use the (single visible) GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [14]:
def _read_key_value_file(path):
    """Return the ``(key, raw_value)`` pairs of a ``key,value``-per-line text file, skipping blank lines."""
    pairs = []
    with open(path) as fh:
        for line in fh:
            line = line.rstrip("\n")
            if not line.strip():
                continue  # tolerate blank / trailing lines instead of failing to unpack them
            # Split on the first comma only, so values that themselves contain commas survive.
            key, _, value = line.partition(",")
            pairs.append((key, value))
    return pairs


def load_model(root_folder, which_example, device, batch_size, in_dist, prefix = 'CNO_', postfix = "_1"):
    """Rebuild a trained benchmark from its run folder and attach the saved model.

    The run folder is ``root_folder + "/" + prefix + which_example + postfix``. It must contain
    ``net_architecture.txt`` and ``training_properties.txt`` (one ``key,value`` pair per line)
    and the pickled model ``model.pkl``.

    Args:
        root_folder: Directory that holds the per-example run folders.
        which_example: Benchmark key: "shear_layer", "poisson", "wave_0_5", "allen",
            "cont_tran", "disc_tran", "airfoil", "darcy" or "ns".
        device: Device (string or torch.device) forwarded to the benchmark and used as
            ``map_location`` when loading the model.
        batch_size: Forwarded to the benchmark constructor.
        in_dist: Forwarded to the benchmark constructor (the evaluation loop passes True for
            the in-distribution test set and False for the out-of-distribution one).
        prefix, postfix: Decoration around ``which_example`` in the run-folder name.

    Returns:
        ``(example, folder)``: the benchmark object with ``.model`` set, and the run-folder path.

    Raises:
        FileNotFoundError: if the run folder or one of its files is missing.
        ValueError: if ``which_example`` is unknown or ``training_samples`` is not recorded.
    """
    folder = root_folder + "/" + prefix + which_example + postfix

    # Architecture hyper-parameters: every value is an int literal except the activation name.
    model_architecture_ = dict()
    for key, value in _read_key_value_file(folder + "/net_architecture.txt"):
        if key == 'activation':
            model_architecture_[key] = value
        else:
            model_architecture_[key] = int(ast.literal_eval(value))

    # Only the number of training samples is needed to rebuild the benchmark's data split.
    training_samples = None
    for key, value in _read_key_value_file(folder + "/training_properties.txt"):
        if key == 'training_samples':
            training_samples = int(ast.literal_eval(value))
    if training_samples is None:
        raise ValueError(f"'training_samples' not found in {folder}/training_properties.txt")

    if which_example == "shear_layer":
        example = ShearLayer(model_architecture_, device, batch_size, training_samples, size = model_architecture_['in_size'], in_dist = in_dist)
    elif which_example == "poisson":
        example = SinFrequency(model_architecture_, device, batch_size, training_samples, in_dist = in_dist)
    elif which_example == "wave_0_5":
        example = WaveEquation(model_architecture_, device, batch_size, training_samples, in_dist = in_dist)
    elif which_example == "allen":
        example = AllenCahn(model_architecture_, device, batch_size, training_samples, in_dist = in_dist)
    elif which_example == "cont_tran":
        example = ContTranslation(model_architecture_, device, batch_size, training_samples, in_dist = in_dist)
    elif which_example == "disc_tran":
        example = DiscContTranslation(model_architecture_, device, batch_size, training_samples, in_dist = in_dist)
    elif which_example == "airfoil":
        example = Airfoil(model_architecture_, device, batch_size, training_samples, in_dist = in_dist)
    elif which_example == "darcy":
        example = Darcy(model_architecture_, device, batch_size, training_samples, in_dist = in_dist)
    elif which_example == "ns":
        example = StandardNavierStokes(model_architecture_, device, batch_size, training_samples, size = None, in_dist = in_dist)
    else:
        raise ValueError(f"Unknown example {which_example!r}")

    # SECURITY NOTE: torch.load unpickles arbitrary objects -- only load model.pkl files you trust.
    # NOTE(review): newer torch releases default to weights_only=True, which may reject a pickled
    # full nn.Module; pass weights_only=False explicitly if loading fails for that reason.
    example.model = torch.load(folder + "/model.pkl", map_location=torch.device(device))

    return example, folder

In [15]:
def relative_errors(prediction, target, p, axis = (-2, -1), reduction_axis = 1):
    """Relative L^p error reduced over ``axis`` and then averaged over ``reduction_axis``.

    Computes ``(sum |prediction - target|^p / sum |target|^p) ** (1/p)`` with both sums taken
    over ``axis``, then averages the result over ``reduction_axis`` (the feature axis), so the
    relative errors of the individual feature dimensions are averaged. Assuming inputs shaped
    ``(batch, channels, H, W)`` -- TODO confirm for every benchmark -- the defaults give one
    error per sample.

    Args:
        prediction, target: Tensors of identical shape.
        p: Norm exponent (e.g. 1 or 2).
        axis: Dimension(s) summed by the norm; an int or a sequence of ints. Default ``(-2, -1)``.
        reduction_axis: Axis over which the per-feature errors are averaged.

    Returns:
        Tensor of relative errors. Entries are inf/nan where ``target`` is all zero over ``axis``
        (the denominator is zero).
    """
    # Accept an int or any sequence; an immutable tuple default avoids the shared-list pitfall.
    dims = (axis,) if isinstance(axis, int) else tuple(axis)
    diff = torch.abs(prediction - target)
    target_abs = torch.abs(target)

    ratio = torch.sum(diff ** p, dim=dims) / torch.sum(target_abs ** p, dim=dims)
    return torch.pow(ratio, 1 / p).mean(reduction_axis)

def relative_errors_orig(prediction, target):
    """Batch-level relative L1 error: ``mean|prediction - target| / mean|target|``.

    Both means run over *all* elements, so a whole batch collapses into a single 0-d tensor
    rather than one value per sample.
    """
    mean_abs_error = (prediction - target).abs().mean()
    mean_abs_target = target.abs().mean()
    return mean_abs_error / mean_abs_target


In [16]:
def _collect_test_errors(example, metric):
    """Run ``example.model`` over ``example.test_loader`` and concatenate the flattened ``metric`` values.

    ``metric(predictions, targets)`` must return a tensor. Batches are moved to the notebook-level
    ``device``. The model is switched to eval mode and gradients are disabled.

    Raises:
        ValueError: if the test loader yields no batches.
    """
    batch_errors = []
    with torch.no_grad():
        example.model.eval()
        for inputs, targets in example.test_loader:
            targets = targets.to(device)
            predictions = example.model(inputs.to(device))
            # No .detach() needed: nothing computed under no_grad tracks gradients.
            batch_errors.append(metric(predictions, targets).cpu().numpy().flatten())

    if not batch_errors:
        raise ValueError("example.test_loader yielded no batches; nothing to evaluate")
    return onp.concatenate(batch_errors, 0)


def calculate_error_distribution(example):
    """Relative L2 errors of ``example.model`` on its test loader, as a flat numpy array.

    One entry per element of ``relative_errors``' output; presumably one per test sample
    for ``(batch, channels, H, W)`` outputs -- TODO confirm.
    """
    return _collect_test_errors(example, lambda pred, tgt: relative_errors(pred, tgt, 2))


def calculate_error_distribution_orig(example):
    """Batch-level relative L1 errors (``relative_errors_orig``), one entry per test batch.

    Averaging the result weights every batch equally, including a smaller final batch.
    """
    return _collect_test_errors(example, relative_errors_orig)
        

In [17]:

# Benchmarks of the CNO suite; the evaluation loop appends the PDEArena Navier-Stokes case "ns".
CNO_examples = ["shear_layer", "poisson", "wave_0_5", "allen",
                "cont_tran", "disc_tran", "airfoil", "darcy"]

# Set to True to also display every saved error-distribution plot inline.
show = False

In [18]:
# Evaluate every trained model in each results folder and write one summary.txt per folder.
# In-distribution error distributions are computed for all examples. Out-of-distribution
# (in_dist=False) errors are computed for every example except "ns".
def _save_error_kde(errors, xlabel, path):
    """Save a KDE plot of ``errors`` (fractions, plotted in percent) to ``path``; display it if ``show``."""
    sns.displot(errors * 100, kind = "kde")
    plt.xlabel(xlabel)
    plt.savefig(path, bbox_inches = 'tight', dpi = 300)
    if show:
        plt.show()
    plt.close()


for root_folder in ['results/cno_rel_l1_TrainedModels', 'results/mse_TrainedModels', 'results/rel_l1_TrainedModels', 'results/rel_l2_TrainedModels']:

    if not os.path.isdir(root_folder):
        print(f"{root_folder}: not found, skipped")
        continue

    # ``with`` guarantees summary.txt is flushed and closed even if an unexpected error aborts the loop.
    with open(root_folder + '/summary.txt', mode = 'w') as summary_file:
        print(root_folder)
        for which_example in CNO_examples + ['ns']:
            try:
                example, example_folder = load_model(root_folder, which_example, device, batch_size=16, in_dist = True)
                errors = calculate_error_distribution(example)
                # Raw string: "\%" is an invalid escape in a normal literal, but LaTeX needs the backslash.
                _save_error_kde(errors, r"Test Error Distribution (in \%)", example_folder + "/test_error_distribution.png")

                text = f"{which_example: <12}, in dist test rel. l2 error {errors.mean() * 100:^8.2f}%"
                if which_example != "ns":
                    example_ood, _ = load_model(root_folder, which_example, device, batch_size=16, in_dist = False)
                    errors_ood = calculate_error_distribution(example_ood)
                    errors_ood_orig = calculate_error_distribution_orig(example_ood)
                    _save_error_kde(errors_ood, r"OOD Error Distribution (in \%)", example_folder + "/ood_error_distribution.png")
                    # relative_errors_orig is mean|pred - target| / mean|target|, i.e. a relative L1 error.
                    text += f", out dist test rel. l2 error {errors_ood.mean() * 100:^8.2f}%, original out dist test rel. l1 error {errors_ood_orig.mean() * 100:<3.2f}%"

                print(text)
                summary_file.write(text + "\n")

            except FileNotFoundError as exc:
                # Missing run folder or data for this example: report it rather than skipping silently.
                print(f"{which_example: <12}, skipped (file not found: {exc.filename})")

results/cno_rel_l1_TrainedModels
shear_layer , in dist test rel. l2 error   3.92  %, out dist test rel. l2 error  16.23  %, original out dist test rel. l2 error 14.16%
poisson     , in dist test rel. l2 error   1.52  %, out dist test rel. l2 error  675.56 %, original out dist test rel. l2 error 585.31%
wave_0_5    , in dist test rel. l2 error   0.57  %, out dist test rel. l2 error   6.23  %, original out dist test rel. l2 error 5.15%
allen       , in dist test rel. l2 error   0.98  %, out dist test rel. l2 error  49.26  %, original out dist test rel. l2 error 42.34%
cont_tran   , in dist test rel. l2 error   0.22  %, out dist test rel. l2 error   0.34  %, original out dist test rel. l2 error 0.40%
disc_tran   , in dist test rel. l2 error   1.66  %, out dist test rel. l2 error   1.70  %, original out dist test rel. l2 error 1.14%
airfoil     , in dist test rel. l2 error  10.08  %, out dist test rel. l2 error  10.10  %, original out dist test rel. l2 error 2.90%
results/mse_TrainedModels