In [15]:
import os
# Expose only GPU 0 to this process; this must run before torch initializes CUDA to take effect.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [16]:
import random

import h5py
import matplotlib.pylab as plt
import numpy as np
import pandas as pd
#import seaborn as sn
import torch
import ast
import scipy
import torch.nn as nn

#import seaborn as sns
import pandas as pd
from scipy import stats

import scipy
from random import randint

In [17]:
def _read_architecture(path):
    """Parse a ``key,value`` text file into a dict of int-valued hyper-parameters.

    Only the field after the first comma is used as the value. Lines with no comma
    (e.g. blank lines) and values that cannot be evaluated and converted to int
    (e.g. the activation name, which is a str) are skipped.
    """
    architecture = dict()
    with open(path) as f:
        for line in f:
            key_values = line.replace("\n", "").split(",")
            if len(key_values) < 2:
                # blank or malformed line: nothing to parse
                continue
            key = key_values[0]
            try:
                # NOTE(review): values go through int(), so a float such as 0.5 would be
                # truncated to 0 — confirm the file only holds integer hyper-parameters.
                architecture[key] = int(ast.literal_eval(key_values[1]))
            except (ValueError, SyntaxError, TypeError):
                # non-numeric value (e.g. activation name) or empty/None value: skip it
                continue
    return architecture


def load_data(folder, which_model, device, which_example, in_size = 64, batch_size = 32, training_samples = 1, in_dist = True, return_example = False):
    """Build a benchmark problem from a trained-model folder and return its test loader.

    Reads ``<folder>/net_architecture.txt``, overrides ``in_size``, instantiates the
    benchmark class selected by ``which_example`` from the ``Problems`` module that
    matches ``which_model``, and returns its ``test_loader`` attribute.

    Parameters
    ----------
    folder : str
        Directory containing ``net_architecture.txt`` (one ``key,value`` pair per line).
    which_model : {"CNO", "UNET", "FNO"}
        Model family. Selects the ``Problems`` module and the constructor arguments.
    device
        Forwarded to the benchmark constructor.
    which_example : str
        One of "shear_layer", "poisson", "wave_0_5", "allen", "cont_tran", "darcy",
        "disc_tran", "airfoil".
    in_size : int
        Stored as ``in_size`` in the architecture dict passed to the benchmark.
    batch_size, in_dist
        Forwarded to the benchmark constructor.
    training_samples : int
        Forwarded only to the benchmarks whose constructor takes it (see the tables below).
    return_example : bool
        If True, return ``(test_loader, example)`` instead of just the loader.

    Returns
    -------
    The benchmark's ``test_loader``, or ``(test_loader, example)`` if ``return_example``.

    Raises
    ------
    ValueError
        If ``which_model`` or ``which_example`` is not supported.
    """
    if which_model not in ("CNO", "UNET", "FNO"):
        raise ValueError(
            f"Unsupported which_model {which_model!r}; expected 'CNO', 'UNET' or 'FNO'"
        )

    if which_model == "FNO":
        from Problems.FNOBenchmarks import Darcy, Airfoil, DiscContTranslation, ContTranslation, AllenCahn, SinFrequency, WaveEquation, ShearLayer
    else:
        # NOTE(review): UNET previously had no import at all and failed with NameError.
        # It is dispatched exactly like CNO below, so it reuses the CNO benchmarks — confirm.
        from Problems.CNOBenchmarks_new_normalization import Darcy, Airfoil, DiscContTranslation, ContTranslation, AllenCahn, SinFrequency, WaveEquation, ShearLayer

    model_architecture_ = _read_architecture(folder + "/net_architecture.txt")
    model_architecture_["in_size"] = in_size
    arch = model_architecture_

    # Constructor signatures differ per model family, hence one table per family.
    if which_model == "FNO":
        builders = {
            "shear_layer": lambda: ShearLayer(arch, device, batch_size, training_samples, in_dist = in_dist),
            "poisson": lambda: SinFrequency(arch, device, batch_size, s = 64, in_dist = in_dist),
            "wave_0_5": lambda: WaveEquation(arch, device, batch_size, in_dist = in_dist),
            "allen": lambda: AllenCahn(arch, device, batch_size, in_dist = in_dist),
            "cont_tran": lambda: ContTranslation(arch, device, batch_size, training_samples, in_dist = in_dist),
            "darcy": lambda: Darcy(arch, device, batch_size, training_samples, in_dist = in_dist),
            "disc_tran": lambda: DiscContTranslation(arch, device, batch_size, training_samples, in_dist = in_dist),
            "airfoil": lambda: Airfoil(arch, device, batch_size, training_samples, in_dist = in_dist),
        }
    else:
        builders = {
            "shear_layer": lambda: ShearLayer(arch, device, batch_size, training_samples, in_dist = in_dist),
            "poisson": lambda: SinFrequency(arch, device, batch_size, in_dist = in_dist),
            "wave_0_5": lambda: WaveEquation(arch, device, batch_size, in_dist = in_dist),
            "allen": lambda: AllenCahn(arch, device, batch_size, in_dist = in_dist),
            "cont_tran": lambda: ContTranslation(arch, device, batch_size, in_dist = in_dist),
            "darcy": lambda: Darcy(arch, device, batch_size, in_dist = in_dist),
            "disc_tran": lambda: DiscContTranslation(arch, device, batch_size, training_samples, in_dist = in_dist),
            "airfoil": lambda: Airfoil(arch, device, batch_size, training_samples, in_dist = in_dist),
        }

    if which_example not in builders:
        raise ValueError(
            f"Unsupported which_example {which_example!r} for model {which_model!r}; "
            f"expected one of {sorted(builders)}"
        )
    example = builders[which_example]()

    testing_loader = example.test_loader
    if return_example:
        return testing_loader, example
    return testing_loader

In [18]:
def error_distribution(which_model, model, testing_loader, p, N, device, which = "shear_layer"):
    """Per-sample relative L^p error (in percent) of ``model`` over ``testing_loader``.

    For each sample the error is
    ``100 * (mean(|pred - true| ** p) / mean(|true| ** p)) ** (1 / p)``.
    The means are taken over the last two axes of the selected 3-D slice.

    Parameters
    ----------
    which_model : {"CNO", "UNET", "FNO"}
        Selects the slice of 4-D tensors that is compared: ``[:, 0, :, :]`` for
        CNO/UNET, ``[:, :, :, 0]`` for FNO (presumably channel-first vs channel-last).
    model : callable
        Maps a batch of inputs to predictions. For FNO, tensors are moved to
        ``model.device`` when the model has that attribute, otherwise to ``device``.
    testing_loader : iterable of (inputs, outputs) batches
    p : float
        Exponent of the L^p norm.
    N : int
        Maximum number of samples to evaluate.
    device
        Device for CNO/UNET tensors (and the FNO fallback).
    which : str
        Example name. Only "airfoil" changes behavior: where ``inputs == 1``, both
        prediction and target are set to 1, so those points add nothing to the
        numerator.

    Returns
    -------
    np.ndarray of length ``min(N, number of samples in the loader)``. Samples beyond
    ``N`` are ignored. There is no zero padding when the loader holds fewer samples.

    Raises
    ------
    ValueError
        If ``which_model`` is not supported.

    NOTE(review): this function does not call ``model.eval()``, so dropout and
    batch-norm layers run in whatever mode the model is currently in.
    """
    if which_model in ("CNO", "UNET"):
        target_device = device
        select = (slice(None), 0, slice(None), slice(None))  # [:, 0, :, :]
    elif which_model == "FNO":
        target_device = getattr(model, "device", device)
        select = (slice(None), slice(None), slice(None), 0)  # [:, :, :, 0]
    else:
        raise ValueError(
            f"Unsupported which_model {which_model!r}; expected 'CNO', 'UNET' or 'FNO'"
        )

    E_diss = np.zeros(N)
    cnt = 0

    with torch.no_grad():
        for inputs, outputs in testing_loader:
            # Stop on the number of *samples* collected, not on the batch index.
            if cnt >= N:
                break

            inputs = inputs.to(target_device)
            outputs = outputs.to(target_device)
            prediction = model(inputs)

            if which == "airfoil":
                # Out-of-place masking: .to() may return the loader's own tensor,
                # so in-place assignment would corrupt the dataset.
                mask = inputs == 1
                outputs = outputs.masked_fill(mask, 1)
                prediction = prediction.masked_fill(mask, 1)

            target = outputs[select]
            numerator = torch.mean(torch.abs(prediction[select] - target) ** p, (-2, -1))
            denominator = torch.mean(torch.abs(target) ** p, (-2, -1))
            err = (numerator / denominator) ** (1 / p) * 100

            # Only keep as many samples as still fit into the N-sized buffer.
            take = min(inputs.shape[0], N - cnt)
            E_diss[cnt: cnt + take] = err[:take].detach().cpu().numpy()
            cnt += take

    return E_diss[:cnt]

In [19]:
# Evaluation setup: checkpoint folder of the CNO poisson model and the test
# loaders for in_dist=True and in_dist=False.
# Fall back to CPU when no GPU is visible so the notebook still runs.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# NOTE: despite its name, `which_model` holds the benchmark *example* name
# ('poisson'). It is passed as load_data's `which_example`; the model family is
# the literal 'CNO'.
which_model = 'poisson'
folder = 'TrainedReportedModels_NoEarlyStopping_new_normalization/CNO_poisson_1'
# Both loaders share every setting except `in_dist`.
test_loader, ood_test_loader = (
    load_data(folder, 'CNO', device, which_model, in_size = 64, batch_size = 32, training_samples = 1, in_dist = in_dist)
    for in_dist in (True, False)
)

In [28]:
# Evaluate the eth poisson checkpoint and print the median test / OOD-test error.
# torch.load unpickles arbitrary objects: only load trusted checkpoints.
# NOTE(review): on PyTorch >= 2.6 torch.load defaults to weights_only=True, which may
# reject a fully pickled model; pass weights_only=False for trusted files if loading fails.
model = torch.load(
    'eth/eth_poisson_model.pkl',
    map_location=torch.device(device),
)
# `which` only changes error_distribution's behavior for "airfoil"; "shear_layer"
# therefore just selects the unmasked error computation.
test_errors, ood_test_errors = (
    error_distribution('CNO', model, loader, 1, 256, device, which = "shear_layer")
    for loader in (test_loader, ood_test_loader)
)
print(np.median(test_errors), np.median(ood_test_errors), sep="\n")

0.2109132930636406
0.2718319892883301


In [29]:
# Evaluate the reported CNO poisson checkpoint and print the median test / OOD-test error.
# torch.load unpickles arbitrary objects: only load trusted checkpoints.
# NOTE(review): on PyTorch >= 2.6 torch.load defaults to weights_only=True, which may
# reject a fully pickled model; pass weights_only=False for trusted files if loading fails.
model = torch.load(
    'TrainedReportedModels_NoEarlyStopping_new_normalization/CNO_poisson_1/model.pkl',
    map_location=torch.device(device),
)
# `which` only changes error_distribution's behavior for "airfoil"; "shear_layer"
# therefore just selects the unmasked error computation.
test_errors, ood_test_errors = (
    error_distribution('CNO', model, loader, 1, 256, device, which = "shear_layer")
    for loader in (test_loader, ood_test_loader)
)
print(np.median(test_errors), np.median(ood_test_errors), sep="\n")

0.2851915657520294
0.43258820474147797
