In [17]:
import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax
import pickle
from dataclasses import dataclass

plt.style.use('default')

In [2]:
a = np.load(rf"E:\Aeronautics\fyp\DiffWake\data\horn\weather_data.npz", allow_pickle=True)
a

NpzFile 'E:\\Aeronautics\\fyp\\DiffWake\\data\\horn\\weather_data.npz' with keys: wind_speed, wind_direction, weight, windrose

In [3]:
ws = a['wind_speed']
wd = a['wind_direction']
w = a['weight']
wr = a['windrose']

In [9]:
a['wind_direction']

array([ 22.5,  67.5, 112.5, 157.5, 202.5, 247.5, 292.5, 337.5,  22.5,
        67.5, 112.5, 157.5, 202.5, 247.5, 292.5, 337.5,  22.5,  67.5,
       112.5, 157.5, 202.5, 247.5, 292.5, 337.5,  22.5,  67.5, 112.5,
       157.5, 202.5, 247.5, 292.5, 337.5,  22.5,  67.5, 112.5, 157.5,
       202.5, 247.5, 292.5, 337.5,  22.5,  67.5, 112.5, 157.5, 202.5,
       247.5, 292.5, 337.5,  22.5,  67.5, 112.5, 157.5, 202.5, 247.5,
       292.5, 337.5,  22.5,  67.5, 112.5, 157.5, 202.5, 247.5, 292.5,
       337.5,  22.5,  67.5])

In [18]:
wr

array({'wd': array([270.]), 'ws': array([13.5]), 'freq': array([0.06])},
      dtype=object)

In [22]:
# Create a simple weather case, single wd, ws and weight
wd_ = np.asarray([270.0])
ws_ = np.asarray([13.5])
w_ = np.asarray([1.0])
wr = np.array(dict(
    wd=wd_, ws=ws_, freq=w_
))

wd_, ws_, w_, wr

(array([270.]),
 array([13.5]),
 array([1.]),
 array({'wd': array([270.]), 'ws': array([13.5]), 'freq': array([1.])},
       dtype=object))

In [23]:
np.savez(rf"./data/simple/weather_data.npz",
         wind_speed=ws_,
         wind_direction=wd_,
         weight=w_,
         )

In [20]:
# reload
b = np.load(rf"./data/simple/weather_data.npz", allow_pickle=True)
b

NpzFile './data/simple/weather_data.npz' with keys: wind_speed, wind_direction, weight

In [21]:
b['wind_speed']

array([13.5])

In [24]:
c = np.load(rf"E:\Aeronautics\fyp\DiffWake\results\yaw_lbfgs\20260219_003433\arrays.npz", allow_pickle=True)
c

NpzFile 'E:\\Aeronautics\\fyp\\DiffWake\\results\\yaw_lbfgs\\20260219_003433\\arrays.npz' with keys: best_yaw, best_omega, per_case_power_MW, wind_dir_deg, wind_speed...

In [27]:
np.rad2deg(c['best_yaw'])

array([[25.        , 25.        , 12.52128107]])

In [18]:
# verify LHS sampling
def _latin_hypercube_unit(key, M: int, D: int, dtype) -> jax.Array:
    base = jnp.arange(M, dtype=jnp.int32)
    keys = jax.random.split(key, D + 1)
    k_jit, k_perms = keys[0], keys[1:]
    jitter = jax.random.uniform(k_jit, shape=(D, M), dtype=dtype)

    def _perm_one(k):
        return jax.random.permutation(k, base)

    perms = jax.vmap(_perm_one)(k_perms)
    U = (perms.astype(dtype) + jitter) / dtype(M)
    return U.T


@dataclass
class YawConstraints:
    gamma_min: float  # minimum allowable yaw angle
    gamma_max: float  # maximum allowable yaw angle

    def mins(self, dtype) -> jax.Array:
        return jnp.deg2rad(jnp.array([self.gamma_min], dtype=dtype))

    def maxs(self, dtype) -> jax.Array:
        return jnp.deg2rad(jnp.array([self.gamma_max], dtype=dtype))

def test_latin_hypercube_properties():
    # Setup
    M, D = 10, 3
    key = jax.random.PRNGKey(42)
    dtype = jnp.float32

    # Execute
    result = _latin_hypercube_unit(key, M, D, dtype)

    # 1. Check Shape
    assert result.shape == (M, D)

    # 2. Check Range (should be within [0, 1])
    assert jnp.all(result >= 0) and jnp.all(result <= 1)

    # 3. Check LHS Property:
    # If we multiply by M and floor the values, each column should
    # contain a permutation of integers 0 to M-1.
    bins = jnp.floor(result * M).astype(jnp.int32)
    for d in range(D):
        column_bins = bins[:, d]
        assert len(jnp.unique(column_bins)) == M, f"Dimension {d} is missing bins!"

test_latin_hypercube_properties()
print("test passed")

test passed


In [12]:
# Parameters
T = 3 # Number of turbines (Dimensions)
M = 1  # Number of initial configurations to test (Samples)
yaw_min = jnp.deg2rad(0.0)   # Minimum yaw offset in degrees
yaw_max = jnp.deg2rad(25.0)  # Maximum yaw offset in degrees

# 1. Generate the unit samples
key = jax.random.PRNGKey(0)
unit_samples = _latin_hypercube_unit(key, M, T, jnp.float32)

yaw_configs = yaw_min + (yaw_max - yaw_min) * unit_samples

In [15]:
print(f"Unit samples: {unit_samples}")
print(f"Yaw configs (radians): {yaw_configs}")
print(f"Yaw configs (degrees): {jnp.rad2deg(yaw_configs)}")

Unit samples: [[0.8423141  0.18237865 0.2271781 ]]
Yaw configs (radians): [[0.3675289  0.0795777  0.09912515]]
Yaw configs (degrees): [[21.057854   4.5594664  5.6794524]]


In [16]:
def omega_from_gamma_sig(gamma: jax.Array,
                          gamma_max: jax.Array,
                          eps=1e-7) -> jax.Array:
    """
    Map points in [gamma_min, gamma_max] to unconstrained omega via sigmoid function
    :param gamma: JAX array of physical yaw angles in radians
    :param gamma_max: maximum allowable physical yaw angle in radians
        """
    p = gamma / gamma_max
    p = jnp.clip(p, eps, 1.0 - eps)  # avoid log(0) or log(1)

    return jnp.log(p / (1.0 - p))

omega_from_gamma_sig()

In [45]:
def yaw_penalty_sq(gamma: jax.Array,
                   gamma_min: jax.Array,
                   gamma_max: jax.Array,
                   ) -> jax.Array:
    violation_min = jax.nn.softplus(gamma_min - gamma)
    print(violation_min)
    violation_max = jax.nn.softplus(gamma - gamma_max)
    print(violation_max)

    return jnp.sum(violation_min + violation_max)

def yaw_penalty_sq_v2(gamma: jax.Array,
                       gamma_min: jax.Array,
                       gamma_max: jax.Array) -> jax.Array:

    beta_1, beta_2 = 45, 146
    alpha_1, alpha_2 = 9.8, 4
    # penalise yaw below gamma_min (should be positive if violated)

    violation_min = (1.0 / beta_1) * jnp.log(1 + jnp.exp(beta_2 * (gamma_min - gamma)))
    print(violation_min)
    # penalise yaw above gamma_max (should be positive if violated)
    violation_max = (1.0 / alpha_1) * jnp.log(1 + jnp.exp(alpha_2 * (gamma_max - gamma)))
    print(violation_max)

    return jnp.sum(violation_max + violation_min)

def yaw_penalty_sq_v3(gamma: jax.Array,
                       gamma_min: jax.Array,
                       gamma_max: jax.Array) -> jax.Array:
    # penalise any absolute yaw value above 25 degrees
    yaw_diff = jnp.abs(gamma - gamma_max)

In [46]:
gamma = jnp.deg2rad(jnp.array([25.0, 19.0, 0.0]))
gamma_min = jnp.deg2rad(0.0)
gamma_max = jnp.deg2rad(25.0)
yaw_penalty_sq(gamma, gamma_min, gamma_max)

[0.4985928 0.5410242 0.6931472]
[0.6931472  0.64215744 0.4985928 ]


Array(3.5666618, dtype=float32)

In [47]:
gamma_max

Array(0.43633232, dtype=float32, weak_type=True)

In [48]:
yaw_penalty_sq_v2(gamma, gamma_min, gamma_max)

[0.         0.         0.01540327]
[0.07072931 0.09432252 0.19451493]


Array(0.37497002, dtype=float32)

In [40]:
(1.0 / 9.8) * jnp.log(1 + jnp.exp(4 * (0.4363 - 0.1)))

Array(0.16088761, dtype=float32, weak_type=True)

In [None]:
def make_losses(state,
                runner,
                weights: jax.Array,
                penalty_weight: float,
                dtype):

    pw = jnp.asarray(penalty_weight, dtype=dtype)

    def loss_from_yaw(yaw_angles: jnp.ndarray):
        """Physics-based objective: negative total farm power (weighted over wind cases)."""
        out = runner(yaw_angles)
        vel = average_velocity_jax(out.u_sorted)
        pow_mw = power(
            state.farm.power_thrust_table,
            vel,
            state.flow.air_density,
            yaw_angles=yaw_angles
        )
        case_power = jnp.sum(pow_mw, axis=1)
        return -jnp.sum(case_power * weights) / 1e6  # scalar in MW

    def loss_from_omega(omega: jax.Array,
                        gamma_min: jax.Array,
                        gamma_max: jax.Array) -> jax.Array:
        gammas = gamma_from_omega_sig(omega, gamma_max)
        phys = loss_from_yaw(gammas)
        yaw = yaw_penalty_sq(gammas, gamma_min, gamma_max)
        # Power is in MW (~1-100), yaw penalty is in rad^2 (~0-1)
        # We need to scale penalty_weight accordingly.
        return phys + pw * yaw

    return loss_from_yaw, loss_from_omega, pw