In [None]:
%pip install equinox

In [None]:
%pip install --force-reinstall scipy==1.12.0

In [None]:
import jax
import jax.numpy as jnp

import scipy
scipy.__version__

In [None]:
#List the devices JAX can see; as the last expression of the cell the result is displayed.
jax.devices()

In [None]:
#Which aerodynamic coefficient to train a surrogate for.
#Values handled by the label-selection code further down: "lift", "drag", "pdrag", "moment".
coefficient = "drag"

#Whether I am testing out architecture vs serious training.
#True selects the "airfoils_xs.csv" dataset and skips loading a base checkpoint before training.
trial_mode = False

In [None]:
from torch.utils.data import Dataset
import pandas as pd
import numpy as np

#Columns of the CSV that are fed to the network, in this order.
#NOTE: other cells index positionally into this order (Alpha at -2, Re at -1).
input_labels = ["B", "T", "P", "C", "E", "R", "Alpha", "Re"]

#Maps the `coefficient` setting to the single CSV column used as the training target.
#The possible targets are "Cl", "Cd", "Cdp" and "Cm"; one surrogate is trained per coefficient.
_OUTPUT_COLUMN_BY_COEFFICIENT = {
    "lift": "Cl",
    "drag": "Cd",
    "pdrag": "Cdp",
    "moment": "Cm",
}

#Fail fast on a typo instead of leaving `output_labels` undefined (or stale from a previous run).
if coefficient not in _OUTPUT_COLUMN_BY_COEFFICIENT:
    raise ValueError(
        f"Unknown coefficient {coefficient!r}; expected one of {sorted(_OUTPUT_COLUMN_BY_COEFFICIENT)}"
    )

output_labels = [_OUTPUT_COLUMN_BY_COEFFICIENT[coefficient]]
    
def clean_data(labels):
    """Return only the rows with -10 < Alpha < 10 and Re > 200_000.

    Both bounds on "Alpha" are exclusive, as is the lower bound on "Re".
    The Alpha window is there to filter out stall behaviour.

    Args:
        labels: DataFrame with at least the columns "Alpha" and "Re".

    Returns:
        A new DataFrame holding the surviving rows (original index labels kept).
    """
    alpha = labels["Alpha"]
    inside_alpha_window = (alpha > -10) & (alpha < 10)
    above_reynolds_floor = labels["Re"] > 200_000

    return labels[inside_alpha_window & above_reynolds_floor]

class _AirfoilSplitDataset(Dataset):
    """Shared implementation of the train / test airfoil datasets.

    Reads the CSV, applies `clean_data`, and keeps either the first 70% of the
    remaining rows (training) or the last 30% (test). The split is positional:
    rows are NOT shuffled before splitting, so both splits are contiguous
    regions of the cleaned file.

    Attributes:
        labels: DataFrame with every column of the selected rows.
        input_labels: DataFrame restricted to the module-level `input_labels` columns.
        output_labels: DataFrame restricted to the module-level `output_labels` columns.
    """

    #Fraction of the cleaned rows that goes to the training split
    train_fraction = 0.7

    def __init__(self, data_file, use_training_rows):
        """Load `data_file` and keep the requested side of the split.

        Args:
            data_file: path (or anything `pd.read_csv` accepts) of the airfoil CSV.
            use_training_rows: True for the leading rows, False for the trailing rows.
        """
        #Filter out stall behaviour
        cleaned = clean_data(pd.read_csv(data_file))

        #Split into train and validation
        split_point = int(len(cleaned) * self.train_fraction)
        if use_training_rows:
            self.labels = cleaned.iloc[:split_point]
        else:
            self.labels = cleaned.iloc[split_point:]

        self.input_labels = self.labels[input_labels]
        self.output_labels = self.labels[output_labels]

        #Convert once up front: a pandas .iloc lookup per sample is far slower
        #than indexing a contiguous float32 array.
        self._inputs = self.input_labels.to_numpy(dtype=np.float32)
        self._outputs = self.output_labels.to_numpy(dtype=np.float32)

    def __len__(self):
        """Number of rows in this split."""
        return len(self.labels)

    def __getitem__(self, index):
        """Return `(inputs, outputs)` for `index` as float32 numpy arrays.

        Copies are returned so callers cannot mutate the cached arrays.
        """
        return (self._inputs[index].copy(), self._outputs[index].copy())


class AirfoilTrainDataset(_AirfoilSplitDataset):
    """First 70% of the cleaned rows of the airfoil CSV."""

    def __init__(self, data_file):
        super().__init__(data_file, use_training_rows=True)


class AirfoilTestDataset(_AirfoilSplitDataset):
    """Remaining 30% of the cleaned rows of the airfoil CSV."""

    def __init__(self, data_file):
        super().__init__(data_file, use_training_rows=False)

In [None]:
import equinox as eqx

#Defined once at module level (instead of inside __init__) so every model
#instance shares the same function object; a fresh closure per instance is a
#different static leaf and forces jitted callers to recompile.
@jax.jit
def normalize_reynolds_number(x):
    """Return `x` with its last entry (Reynolds number) replaced by its log10.

    Args:
        x: 1-D feature vector whose final element is the Reynolds number.
    """
    return jnp.hstack((x[:-1], jnp.log10(x[-1])))


class SurrogateModel(eqx.Module):
    """Fully connected network mapping one feature vector to the coefficient(s).

    The layer stack is: log10 of the Reynolds number -> `jax.nn.standardize`
    -> Linear(in_size, width_size) -> activation -> `depth` blocks of
    [Linear(width_size, width_size), activation] -> Linear(width_size, out_size).

    The model operates on a single sample; use `jax.vmap(model)` for batches.
    """

    layers: list

    def __init__(self, in_size, out_size, width_size, depth, activation, key):
        """Build the layer stack.

        Args:
            in_size: number of input features.
            out_size: number of outputs.
            width_size: width of every hidden layer.
            depth: number of hidden (width_size -> width_size) layers.
            activation: callable applied after the input and each hidden layer.
            key: JAX PRNG key; split into one sub-key per Linear layer.
        """
        #One key for the input layer, one per hidden layer, one for the output layer
        keys = jax.random.split(key, depth + 2)

        layers = [
            normalize_reynolds_number,
            #Zero mean / unit variance along the last axis, i.e. across the
            #features of ONE sample (this is not a rescale to [-1, 1]).
            #NOTE(review): this mixes features of different scales - confirm it is intended.
            jax.nn.standardize,
            eqx.nn.Linear(in_size, width_size, key=keys[0]),
            activation,
        ]
        for hidden_key in keys[1:-1]:
            layers.append(eqx.nn.Linear(width_size, width_size, key=hidden_key))
            layers.append(activation)

        layers.append(eqx.nn.Linear(width_size, out_size, key=keys[-1]))

        self.layers = layers

    def __call__(self, x):
        """Apply every layer in order to the single sample `x`."""
        for layer in self.layers:
            x = layer(x)

        return x
        

In [None]:
#Fixed seed so the initial weights are the same on every run
key = jax.random.PRNGKey(42)

#One input per entry of `input_labels`, one output per entry of `output_labels`.
#depth=4 gives four 64x64 hidden layers between the input and output layers.
model = SurrogateModel(
    in_size=len(input_labels),
    out_size=len(output_labels),
    width_size=64,
    depth=4,
    activation=jax.nn.silu,
    key=key
)

#Display the model structure as the cell output
model

In [None]:
@eqx.filter_jit
def loss(model, x, y):
    """Mean squared error between the batched model predictions and `y`.

    `model` is mapped over the leading axis of `x` with `jax.vmap`, so `x`
    holds one sample per row.
    """
    predictions = jax.vmap(model)(x)
    squared_error = jnp.square(y - predictions)

    #MSE Loss
    return squared_error.mean()

@eqx.filter_jit
def training_step(model, opt_state, optim, x, y):
    """Run one optimisation step on the batch `(x, y)`.

    Args:
        model: the Equinox model to update.
        opt_state: optimiser state matching `optim`.
        optim: object exposing `update(grads, state, params)` (an optax optimiser).
        x: batch of inputs, one sample per row.
        y: batch of targets aligned with `x`.

    Returns:
        `(model, opt_state, loss_val)`: the updated model, the updated optimiser
        state, and the loss evaluated BEFORE the update.
    """
    loss_val, loss_grad = eqx.filter_value_and_grad(loss)(model, x, y)

    #Pass the current parameters too: optimisers that need them (e.g. weight
    #decay) fail without this argument, and those that do not simply ignore it.
    params = eqx.filter(model, eqx.is_array)
    updates, opt_state = optim.update(loss_grad, opt_state, params)

    model = eqx.apply_updates(model, updates)

    return model, opt_state, loss_val

In [None]:
#CSV file name used by the training and validation cells (looked up under /kaggle/input/).
#Trial mode switches to "airfoils_xs.csv" - presumably a reduced subset; verify against the dataset.
dataset_file = "airfoils_xs.csv" if trial_mode else "airfoils.csv"

In [None]:
import optax
from torch.utils.data import DataLoader
from IPython.display import clear_output
import math
import time

batch_size = 200
lr = 1e-3
#NOTE(review): trial mode runs MORE epochs than a full run - confirm this is intended.
epochs = 2000 if trial_mode else 1000

#Where checkpoints of the model being trained are written
checkpoint_path = f"/kaggle/working/{coefficient}_surrogate.eqx"

#Serious runs continue from a previously trained model when one is available (best effort).
if not trial_mode:
    try:
        model = eqx.tree_deserialise_leaves(f"/kaggle/input/{coefficient}_surrogate.eqx", model)
    except Exception as err:
        #`Exception` rather than a bare except so Ctrl-C / kernel interrupts still propagate
        print("WARNING: No base model to continue training off of. Starting training from scratch.")
        print(f"Reason: {err!r}")

dataset = AirfoilTrainDataset(f"/kaggle/input/{dataset_file}")
training_dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

optim = optax.adabelief(learning_rate=lr)
opt_state = optim.init(eqx.filter(model, eqx.is_array))

for i in range(epochs):
    print("Training epoch",i)
    start_time = time.time()

    #Loss of every mini-batch of this epoch
    total_losses = []
    for x, y in training_dataloader:
        x = jnp.array(x, dtype=jnp.float32)
        y = jnp.array(y, dtype=jnp.float32)

        model, opt_state, loss_val = training_step(model, opt_state, optim, x, y)
        total_losses.append(loss_val)
    #Replace the previous epoch's report instead of flooding the cell output
    clear_output()

    total_losses = jnp.array(total_losses)
    avg_loss = jnp.mean(total_losses)
    print(f"Avg. Loss: {avg_loss:e}")
    print(f"Min. Loss: {jnp.min(total_losses):e}")
    print(f"Max. Loss {jnp.max(total_losses):e}")

    duration = time.time() - start_time
    print("Epoch took", duration, "seconds")

    #Checkpoint every 5th epoch (skipping epoch 0)
    if i % 5 == 0 and i != 0:
        eqx.tree_serialise_leaves(checkpoint_path, model)
        print(f"Saved checkpoint to {checkpoint_path}")
        #Decide on the epoch average; the loss of the last (random, possibly
        #smaller) mini-batch alone is too noisy to stop on.
        if avg_loss < 1e-6:
            print("Loss value is", avg_loss, "cutting training early.")
            break

#Always persist the final weights: without this the epochs after the last
#multiple-of-5 checkpoint would be lost.
eqx.tree_serialise_leaves(checkpoint_path, model)
print(f"Saved final model to {checkpoint_path}")

In [None]:
import matplotlib.pyplot as plt
import os

#Evaluate the model that was just trained: prefer the checkpoint written to
#/kaggle/working by the training cell and only fall back to the base model
#shipped in /kaggle/input when no fresh checkpoint exists.
trained_checkpoint = f"/kaggle/working/{coefficient}_surrogate.eqx"
base_checkpoint = f"/kaggle/input/{coefficient}_surrogate.eqx"
checkpoint_path = trained_checkpoint if os.path.exists(trained_checkpoint) else base_checkpoint
model = eqx.tree_deserialise_leaves(checkpoint_path, model)
print("Loaded model from", checkpoint_path)

print("Running validation...")

validation_dataset = AirfoilTestDataset(f"/kaggle/input/{dataset_file}")
#A single batch spanning the whole split, in file order
x, y = next(iter(DataLoader(dataset=validation_dataset, batch_size=len(validation_dataset), shuffle=False)))

x = jnp.array(x, dtype=jnp.float32)
y = jnp.array(y, dtype=jnp.float32)

#Alpha is the second-to-last input column; bucket it to the nearest integer
alphas = x[:, -2]
alphas = jnp.round(alphas)
unique_alphas = jnp.unique(alphas)

y_hat = jax.vmap(model)(x)

abs_error = jnp.abs(y - y_hat)

#Mean absolute error per rounded angle of attack
alpha_errors = dict()

for a in unique_alphas:
    alpha_errors[float(a)] = abs_error[alphas == a].mean().item()

print(f"Validation MAE loss ({coefficient})", jnp.mean(abs_error))
print("Validation standard deviation", abs_error.std())

keys = jnp.array(list(alpha_errors.keys()))
values = jnp.array(list(alpha_errors.values()))

fig, ax = plt.subplots()
ax.plot(keys, values)
ax.set_xlabel("Alpha (rounded)")
ax.set_ylabel(f"Mean absolute error ({coefficient})")
ax.set_title("Validation error by angle of attack");

print("[" + ", ".join([str(k) for k in keys]) + "]")
print("[" + ", ".join([str(v) for v in values]) + "]")

In [None]:
#! rm /kaggle/working/surrogate.eqx

In [None]:
import jax
import jax.numpy as jnp
import equinox as eqx
import matplotlib.pyplot as plt
import random

#Columns of the CSV that are fed to the network, in this order
#(first six: geometry parameters, then Alpha, then Re).
input_labels = ["B", "T", "P", "C", "E", "R", "Alpha", "Re"]

#Maps the `coefficient` setting to the single CSV column used as the target
_OUTPUT_COLUMN_BY_COEFFICIENT = {
    "lift": "Cl",
    "drag": "Cd",
    "pdrag": "Cdp",
    "moment": "Cm",
}

#Fail fast on a typo instead of silently reusing a stale `output_labels`
if coefficient not in _OUTPUT_COLUMN_BY_COEFFICIENT:
    raise ValueError(
        f"Unknown coefficient {coefficient!r}; expected one of {sorted(_OUTPUT_COLUMN_BY_COEFFICIENT)}"
    )

output_labels = [_OUTPUT_COLUMN_BY_COEFFICIENT[coefficient]]

key = jax.random.PRNGKey(42)

#Same architecture as the training cell so the checkpoint's leaves line up
model = SurrogateModel(
    in_size=len(input_labels),
    out_size=len(output_labels),
    width_size=64,
    depth=4,
    activation=jax.nn.silu,
    key=key
)

#Best effort: without a checkpoint the plot below shows the untrained network
try:
    model = eqx.tree_deserialise_leaves(f"/kaggle/working/{coefficient}_surrogate.eqx", model)
except Exception as err:
    #`Exception` rather than a bare except so kernel interrupts still propagate
    print("No model available")
    print(f"Reason: {err!r}")

#It's fine if it's overfitting
validation_dataset = AirfoilTrainDataset(f"/kaggle/input/{dataset_file}")

#The first row defines the reference airfoil (geometry) and Reynolds number
geometry_params = validation_dataset[0][0][:6]
Re = validation_dataset[0][0][-1]

print("Reynolds number", Re)

#Collect the leading run of rows that share that geometry and Reynolds number.
#Bounded by the dataset length so a file holding a single polar does not run
#past the end, and each row is fetched only once.
alpha = []
#Real coefficients
real_polars = []
for row in range(len(validation_dataset)):
    row_inputs, row_outputs = validation_dataset[row]
    same_airfoil = (row_inputs[:6] == geometry_params).all()
    same_reynolds = (row_inputs[-1] == Re).all()
    if not (same_airfoil and same_reynolds):
        break
    alpha.append(row_inputs[-2])
    real_polars.append(row_outputs)

num_rows = len(alpha)

real_coefficients = jnp.vstack(real_polars).T.flatten()

#Get Jacobian of example airfoil: sensitivity of the output to every input at Alpha = 0
jacobian = jax.jacrev(model)(jnp.hstack((geometry_params, 0, Re)))
print(jacobian)

#Build one input row per angle of attack: constant geometry and Re, varying Alpha
alpha = jnp.array(alpha).reshape(1, -1).T
Re = jnp.full(alpha.shape, Re)

geometry_params = jnp.full((alpha.size, geometry_params.size), geometry_params)

all_params = jnp.hstack((geometry_params, alpha, Re))

results = jax.vmap(model)(all_params)
coefficients = results.T.flatten()

alpha = alpha.flatten()

f, ax = plt.subplots(1, 1)

ax.plot(alpha, coefficients, label="Prediction")
ax.plot(alpha, real_coefficients, label="Real")
ax.set_xlabel("Alpha")
ax.set_ylabel(output_labels[0])
ax.legend()

plt.show()