In [1]:
import os

# Expose only the GPU with index 1 to this process. This runs in the first
# cell, before torch is imported in the following cell.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [2]:
import torch
import ast
import numpy as onp
import seaborn as sns
from matplotlib import pyplot as plt

from Problems.CNOBenchmarks import Darcy, Airfoil, DiscContTranslation, ContTranslation, AllenCahn, SinFrequency, WaveEquation, ShearLayer
from Problems.PDEArenaBenchmarks import StandardNavierStokes

In [3]:
# Start from matplotlib's defaults so the figure style does not depend on
# whatever rc state was active before this cell ran.
plt.rcParams.update(plt.rcParamsDefault)

# Serif fonts with LaTeX text rendering; amsmath is loaded in the preamble.
plt.rcParams.update({
    "font.family": "serif",
    "text.usetex": True,
    "text.latex.preamble": r"\usepackage{amsmath}",
})

In [4]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [5]:
def load_model(root_folder, which_example, device, batch_size, in_dist, prefix = 'CNO_', postfix = "_1"):
    
    folder = root_folder + "/" + prefix + which_example + postfix

    model_architecture_ = dict()
    with open(folder + "/net_architecture.txt") as f:
        for line in f:
            key_values = line.replace("\n", "").split(",")
            key = key_values[0]
            if key_values[0] == 'activation':
                value = key_values[1]
                model_architecture_[key] = value
            else:
                value = ast.literal_eval(key_values[1])
                model_architecture_[key] = int(value)

    with open(folder + "/training_properties.txt") as f:
        for line in f:
            key, value = line.replace("\n", "").split(",")
            if key == 'training_samples':
                training_samples = int(ast.literal_eval(value))

    


    if which_example == "shear_layer":
        example = ShearLayer(model_architecture_, device, batch_size, training_samples, size = model_architecture_['in_size'], in_dist = in_dist)
    elif which_example == "poisson":
        example = SinFrequency(model_architecture_, device, batch_size, training_samples, in_dist = in_dist)
    elif which_example == "wave_0_5":
        example = WaveEquation(model_architecture_, device, batch_size, training_samples, in_dist = in_dist)
    elif which_example == "allen":
        example = AllenCahn(model_architecture_, device, batch_size, training_samples, in_dist = in_dist)
    elif which_example == "cont_tran":
        example = ContTranslation(model_architecture_, device, batch_size, training_samples, in_dist = in_dist)
    elif which_example == "disc_tran":
        example = DiscContTranslation(model_architecture_, device, batch_size, training_samples, in_dist = in_dist)
    elif which_example == "airfoil":
        example = Airfoil(model_architecture_, device, batch_size, training_samples, in_dist = in_dist)
    elif which_example == "darcy":
        example = Darcy(model_architecture_, device, batch_size, training_samples, in_dist = in_dist)
    elif which_example == "ns":
        example = StandardNavierStokes(model_architecture_, device, batch_size, training_samples, size = None, in_dist = in_dist)
    else:
        raise ValueError()
    
    example.model = torch.load(folder + "/model.pkl", map_location=torch.device(device))
        
    return example, folder

In [6]:
def relative_errors(prediction, target, p, axis = [-2, -1], reduction_axis = 1):
    """Relative L^p error between ``prediction`` and ``target``.

    For every slice left after summing over ``axis`` this computes
    ``(sum |prediction - target|^p / sum |target|^p) ** (1/p)``, then takes
    the mean over ``reduction_axis``.

    Args:
        prediction: predicted tensor.
        target: reference tensor, same shape as ``prediction``.
        p: exponent of the norm.
        axis: axes summed over inside the norm; by default the last two
            (presumably the spatial axes; TODO confirm).
        reduction_axis: axis of the per-slice errors that is averaged.
            NOTE: this is the feature axis; the relative errors of the
            individual feature dimensions are averaged.

    Returns:
        Tensor of relative errors with ``axis`` and ``reduction_axis`` reduced.
        There is no guard against a target whose norm is zero.
    """
    numerator = (prediction - target).abs().pow(p).sum(dim = axis)
    denominator = target.abs().pow(p).sum(dim = axis)
    ratio = numerator / denominator
    return ratio.pow(1 / p).mean(reduction_axis)

In [7]:
def calculate_error_distribution(example):
    """Collect the per-sample relative L2 test errors of ``example.model``.

    The model is switched to eval mode and run, without gradient tracking,
    over every batch of ``example.test_loader`` on the notebook-level
    ``device``.

    Args:
        example: object exposing ``model`` and ``test_loader``; the loader
            yields ``(inputs, outputs)`` batches.

    Returns:
        1-D numpy array with the concatenated relative errors of all batches.
    """
    example.model.eval()

    per_batch_errors = []
    with torch.no_grad():
        for inputs, outputs in example.test_loader:
            targets = outputs.to(device)
            predictions = example.model(inputs.to(device))

            batch_errors = relative_errors(predictions, targets, 2)
            per_batch_errors.append(batch_errors.detach().cpu().numpy().flatten())

    return onp.concatenate(per_batch_errors, 0)
        

In [8]:
# Directory that holds one sub-folder per trained model. It can be overridden
# through the CNO_TRAINED_MODELS_DIR environment variable so the notebook also
# runs on machines where the default path does not exist.
root_folder = os.environ.get(
    'CNO_TRAINED_MODELS_DIR',
    '/scratch/wangh19/ConvolutionalNeuralOperator/TrainedModels',
)

# Examples evaluated both in-distribution and out-of-distribution.
CNO_examples = [
    "shear_layer",
    "poisson",
    "wave_0_5",
    "allen",
    "cont_tran",
    "disc_tran",
    "airfoil",
    "darcy",
]

# When True, each figure is also displayed with plt.show() after being saved.
show = False

In [9]:
def _save_error_kde(errors, xlabel, path):
    """Plot a KDE of ``errors`` (converted to percent) and save it to ``path``.

    The figure is shown as well when the notebook-level ``show`` flag is set,
    and is closed afterwards so figures do not pile up across iterations.
    """
    sns.displot(errors * 100, kind = "kde")
    plt.xlabel(xlabel)
    plt.savefig(path, bbox_inches = 'tight', dpi = 300)
    if show:
        plt.show()
    plt.close()


# Evaluate every example in-distribution; all of them except "ns" are also
# evaluated out-of-distribution. Both plots go into the example's run folder.
for which_example in CNO_examples + ['ns']:

    example, example_folder = load_model(root_folder, which_example, device, batch_size=16, in_dist = True)
    errors = calculate_error_distribution(example)

    # Raw strings: "\%" is the LaTeX-escaped percent sign (text.usetex is on),
    # and in a normal string literal "\%" is an invalid escape sequence.
    _save_error_kde(errors, r"Test Error Distribution (in \%)", example_folder + "/test_error_distribution.png")

    if which_example != "ns":
        example, _ = load_model(root_folder, which_example, device, batch_size=16, in_dist = False)
        errors_ood = calculate_error_distribution(example)

        _save_error_kde(errors_ood, r"OOD Error Distribution (in \%)", example_folder + "/ood_error_distribution.png")

        print(f"{which_example: <12}, in dist test rel. l2 error {errors.mean() * 100:.2f}%, out dist test rel. l2 error {errors_ood.mean() * 100:.2f}%")
    else:
        print(f"{which_example: <12}, in dist test rel. l2 error {errors.mean() * 100:.2f}%")

Setting up PyTorch plugin "filtered_lrelu_plugin"... Done.
shear_layer , in dist test rel. l2 error 3.85%, out dist test rel. l2 error 16.67%
poisson     , in dist test rel. l2 error 1.18%, out dist test rel. l2 error 529.34%
wave_0_5    , in dist test rel. l2 error 0.57%, out dist test rel. l2 error 6.24%
allen       , in dist test rel. l2 error 1.11%, out dist test rel. l2 error 49.40%
cont_tran   , in dist test rel. l2 error 0.21%, out dist test rel. l2 error 0.32%
disc_tran   , in dist test rel. l2 error 1.82%, out dist test rel. l2 error 2.10%
airfoil     , in dist test rel. l2 error 6.82%, out dist test rel. l2 error 6.98%
darcy       , in dist test rel. l2 error 0.66%, out dist test rel. l2 error 0.44%
ns          , in dist test rel. l2 error 5.84%
