In [1]:
import tensorflow as tf
import numpy as np
import pandas as pd
import math
import time
from matplotlib import pyplot as plt

In [2]:
# Initialize Global Variables

NUM_MODELS = 10

In [3]:
# Initialize Dataset

# Read .csv files for training

xdata = pd.read_csv('ML_inputs_tear.csv')
ydata = pd.read_csv('U_tear.csv').T

# Adding Constant Parameters to xdata
xdata.insert(0, 'c', 6)
xdata.insert(1, 'D', .05)
xdata.insert(2, 'k1', 15)
xdata.insert(3, 'k2', 15)
xdata.insert(4, 'N', 2)
xdata.insert(5, 'kappa_1', .25)
xdata.insert(6, 'kappa_2', .25)
xdata.insert(7, 'kappa_3', .25)

# Switch to numpy array
x = xdata.to_numpy()
y = ydata.to_numpy()


In [4]:
# Import Models

annmodels = []

for i in range(NUM_MODELS):
    annmodels.append(tf.keras.models.load_model('model' + str(i) + '.h5'))

In [5]:
# Timing Calculations
modelTime = []

for model in annmodels:
    
    # Perform 50 timed predictions, average result
    tmpTime = []
    for i in range(50):
        start = time.time()
        model.predict(x)
        end = time.time()
        tmpTime.append(end - start)
        
    modelTime.append(np.mean(tmpTime))

In [6]:
# Error Calculations
modelError = []

for model in annmodels:
    # find the relative frobenius error of the prediction
    pred = model.predict(x)
    
    toterr = 0
    for i in range(len(x)):
        Norm1 = np.linalg.norm(y[i] - pred[i], 2)
        Norm2 = np.linalg.norm(pred[i], 2)
        toterr = toterr + ((Norm1**2)/(Norm2**2))
        
    modelError.append(math.sqrt(toterr))

In [8]:
export = pd.DataFrame({'offline time': modelTime, 'error': modelError})
onlineTime = pd.read_csv('onlineTime.csv')

export = pd.concat((onlineTime, export), axis = 1)
export.to_csv('MLModelData.csv', encoding = 'utf-8', index=False)
export

Unnamed: 0,online time,offline time,error
0,7.276224,0.048223,0.135837
1,5.599475,0.034671,0.062366
2,5.5925,0.035692,0.072565
3,5.595942,0.034637,0.077484
4,5.709719,0.035693,0.10336
5,5.478569,0.036794,0.078452
6,5.45663,0.03585,0.086418
7,5.810365,0.03996,0.065731
8,6.055452,0.036042,0.169024
9,5.873522,0.035463,0.149312
