In [110]:
import numpy as np
import re
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression

DEGREE = 3

# Notebook for fitting a polynomial for the correction exponent for glass dielectrics 
# given a certain IOR and roughness
#
# The 'values'  array has been eyeballed manually to try and minimize the visual energy loss/gain
# The idea of the polynomial is to replace the massive if(),  else if(), else if() that would have been
# needed otherwise
#
# TURNS OUT THAT THE FITTING ERROR IS TOO LARGE AND THE POLYNOMIAL
# FITTED BY THIS SCRIPT IS THUS NEVER USED IN THE RENDERER
# 
# Instead, we just translate the double entry table IO-Roughness into a massive and disgusting
# if(), else if(). This is exactly what we wanted to avoid but in the end, it's efficient and fits well.
# The only single downside is that it looks disgusting but who cares?

ior_values = [1.01, 1.02, 1.03, 1.1, 1.2, 1.4, 1.5, 2.0, 2.4, 3.0]
roughness_values = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]

# Table of output values (for each combination of IOR and roughness)
# For example, these could represent reflectance values:
values = np.array([
    [2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5],  # Roughness 0.0
    [2.5, 2.5, 2.5, 2.5, 1.8, 1.8, 1.6, 1.5, 1.5, 1.5],  # Roughness 0.1
    [2.3, 2.3, 2.3, 2.3, 2.3, 2.3, 2.3, 2.2, 2, 1.7],  # Roughness 0.2
    [2.4, 2.4, 2.4, 2.38, 2.38, 2.38, 2.38, 2.38, 2.44, 2.38],  # Roughness 0.3
    [2.45, 2.475, 2.475, 2.475, 2.475, 2.475, 2.475, 2.475, 2.475, 2.475],  # Roughness 0.4
    [2.4665, 2.51, 2.51, 2.54, 2.55, 2.7, 2.7, 2.75, 3, 2.9],  # Roughness 0.5
    [2.52, 2.54, 2.544, 2.575, 2.65, 2.875, 2.95, 3.5, 3.8, 3.8],  # Roughness 0.6
    [2.55, 2.565, 2.565, 2.61, 2.675, 2.925, 3.1, 4.85, 7, 7.5],  # Roughness 0.7
    [2.55, 2.57, 2.58, 2.63, 2.7, 2.95, 3.1, 6, 10, 12],  # Roughness 0.8
    [2.585, 2.59, 2.6, 2.6, 2.675, 2.8, 3.05, 7, 12, 13.75],  # Roughness 0.9
    [2.5, 2.5, 2.5, 2.5, 2.5, 2.55, 2.57, 2.57, 3.9, 2.5],  # Roughness 1.0
])

# Create combinations of IOR and roughness
ior, roughness = np.meshgrid(ior_values, roughness_values)  # Create grid
ior = ior.ravel()  # Flatten to 1D array
roughness = roughness.ravel()  # Flatten to 1D array
outputs = values.ravel()  # Flatten table of outputs

# Prepare design matrix for polynomial features
X = np.column_stack((ior, roughness))
poly = PolynomialFeatures(degree=DEGREE, include_bias=False)  # Bias=False avoids adding constant
X_poly = poly.fit_transform(X)

# Fit the polynomial model
model = LinearRegression()
model.fit(X_poly, outputs)

# Extract coefficients
intercept = model.intercept_  # Constant term
coefficients = model.coef_  # Remaining terms

# Display results
# print("Intercept (constant term):", intercept)
# print("Coefficients (for polynomial terms):", coefficients)

# Polynomial interpretation
terms = poly.get_feature_names_out(["relative_eta", "roughness"])
for term, coef in zip(terms, coefficients):
    # Replace '^2' with ' * IOR' or ' * Roughness' for terms like IOR^2
    term = re.sub(r"(relative_eta|roughness)\^([0-9]+)", lambda m: "*".join([m.group(1)] * int(m.group(2))), term)

    # If the term has more than one factor, join them with ' * '
    formatted_term = " * ".join(term.split(" "))
    
    print(f"{formatted_term} * {coef}", end='+')

from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

# Predict the values using the fitted model
predicted_outputs = model.predict(X_poly)

# Calculate error metrics
mse = mean_squared_error(outputs, predicted_outputs)
rmse = np.sqrt(mse)
mae = mean_absolute_error(outputs, predicted_outputs)

# Display the errors
print("\n\nFitting Errors:")
print(f"Mean Squared Error (MSE): {mse}")
print(f"Root Mean Squared Error (RMSE): {rmse}")
print(f"Mean Absolute Error (MAE): {mae}")

#######
# Printing the massive if(), else if() block so that it is ready to be copy pasted in the shader
#######

print()
print()

print("float lower_relative_eta_bound;")
print("float lower_correction;")
for i in range(0, len(ior_values)):
    relative_eta = ior_values[i]
    next_relative_eta = 1
    if (i != len(ior_values) - 1):
        next_relative_eta = ior_values[i + 1]

    if (i == 0):
        print("if (", end='');
    else:
        print("else if (", end='');
    if (i != len(ior_values) - 1):
        print("relative_eta > " + str(relative_eta) + "f && relative_eta <= " + str(next_relative_eta) + "f)\n{")
    else:
        print("relative_eta > " + str(relative_eta) + "f)\n{")
    
    print("\tlower_relative_eta_bound = " + str(relative_eta) + "f;\n")
    for j in range(0, len(roughness_values)):
        roughness = roughness_values[j]
        str_roughness = ""
        str_roughness_minus_1 = ""
        if (j == 0):
            str_roughness = "0.0f"
        elif (j == len(roughness_values) - 1):
            str_roughness = "1.0f"
            str_roughness_minus_1 = str(round(roughness - 1 / (len(roughness_values) - 1), 1)) + "f"
        else:
            str_roughness = str(roughness) + "f"
            str_roughness_minus_1 = str(round(roughness - 1 / (len(roughness_values) - 1), 1)) + "f"

        print("\t", end='')
        if (j == 0):
            print("if (", end='')
        else:
            print("else if (", end='')

        print("roughness <= " + str_roughness + ")")

        if (j == 0):
            print("\t\tlower_correction = " + str(values[j][i]) +  "f;")
        else:
            lower_lerp = values[j - 1][i]
            higher_lerp = values[j][i]
            if (lower_lerp == higher_lerp):
                print("\t\tlower_correction = " + str(lower_lerp) + "f;")
            else:
                print("\t\tlower_correction = hippt::lerp(" + str(lower_lerp) + "f, " + str(higher_lerp) + "f, (roughness - " + str_roughness_minus_1 + ") / 0.1f);")

    print("}")
    
print("float higher_relative_eta_bound;")
print("float higher_correction;")
for i in range(0, len(ior_values)):
    relative_eta = ior_values[i]

    if (i == 0):
        print("if (", end='');
    else:
        print("else if (", end='');
    print("relative_eta <= " + str(relative_eta) + "f)\n{")

    print("\thigher_relative_eta_bound = " + str(relative_eta) + "f;\n")
    for j in range(0, len(roughness_values)):
        roughness = roughness_values[j]
        str_roughness = ""
        str_roughness_minus_1 = ""
        if (j == 0):
            str_roughness = "0.0f"
        elif (j == len(roughness_values) - 1):
            str_roughness = "1.0f"
            str_roughness_minus_1 = str(round(roughness - 1 / (len(roughness_values) - 1), 1)) + "f"
        else:
            str_roughness = str(roughness) + "f"
            str_roughness_minus_1 = str(round(roughness - 1 / (len(roughness_values) - 1), 1)) + "f"

        print("\t", end='')
        if (j == 0):
            print("if (", end='')
        else:
            print("else if (", end='')

        print("roughness <= " + str_roughness + ")")

        if (j == 0):
            print("\t\thigher_correction = " + str(values[j][i]) +  "f;")
        else:
            lower_lerp = values[j - 1][i]
            higher_lerp = values[j][i]
            if (lower_lerp == higher_lerp):
                print("\t\thigher_correction = " + str(lower_lerp) + "f;")
            else:
                print("\t\thigher_correction = hippt::lerp(" + str(lower_lerp) + "f, " + str(higher_lerp) + "f, (roughness - " + str_roughness_minus_1 + ") / 0.1f);")

    print("}")

print("\nreturn hippt::lerp(lower_correction, higher_correction, (relative_eta - lower_relative_eta_bound) / (higher_relative_eta_bound - lower_relative_eta_bound));")

relative_eta * -13.439691641502959+roughness * -14.648819997830152+relative_eta*relative_eta * 6.8789072129524165+relative_eta * roughness * 2.277848368999256+roughness*roughness * 30.73190238514913+relative_eta*relative_eta*relative_eta * -1.1621500927731714+relative_eta*relative_eta * roughness * 0.37498830694981306+relative_eta * roughness*roughness * 0.8224134608255266+roughness*roughness*roughness * -21.325291375291403+

Fitting Errors:
Mean Squared Error (MSE): 1.513004991070829
Root Mean Squared Error (RMSE): 1.2300426785566543
Mean Absolute Error (MAE): 0.6879266691669605


float lower_relative_eta_bound;
float lower_correction;
if (relative_eta > 1.01f && relative_eta <= 1.02f)
{
	lower_relative_eta_bound = 1.01f;

	if (roughness <= 0.0f)
		lower_correction = 2.5f;
	else if (roughness <= 0.1f)
		lower_correction = 2.5f;
	else if (roughness <= 0.2f)
		lower_correction = hippt::lerp(2.5f, 2.3f, (roughness - 0.1f) / 0.1f);
	else if (roughness <= 0.3f)
		lower_correction = hippt::