In [1]:
import jax
import jax.numpy as jnp
import jax.random as jrandom
import equinox as eqx

In [2]:
from src.zpinn.models.SIREN import SIREN
from src.zpinn.models.SIREN import SIREN
from src.zpinn.modules.sine_layer import SineLayer

model_key = jrandom.PRNGKey(0)

model = SIREN(
    key=model_key,
    in_features=4,
    out_features=2,
    hidden_features=256,
    outermost_linear=True,
    hidden_layers=3,
    first_omega_0=30.0,
    hidden_omega_0=30.0,
)

# test forward pass
val = model(*tuple([0.0] * 4))


val = jax.vmap(model, in_axes=(0, 0, 0, 0))(jnp.zeros((4)), jnp.zeros((4)), jnp.zeros((4)), jnp.zeros((4)))
val

Array([[ 0.01592879, -0.01134333],
       [ 0.01592879, -0.01134333],
       [ 0.01592879, -0.01134333],
       [ 0.01592879, -0.01134333]], dtype=float32)

In [3]:
im, re = val

ValueError: too many values to unpack (expected 2)

In [None]:
a, b = val
print(a,b)

0.01592879 -0.011343336


In [None]:
def apply_model(params, model, *args):
    """Forward pass of the model as a function of its parameters."""

    get_params = lambda m: m.get_params()
    model = eqx.tree_at(get_params, model, params)
    return model(*args)

In [None]:
apply_model(model.get_params(), model, *tuple([0.0] * 4))

Array([ 0.01592879, -0.01134334], dtype=float32)

In [None]:
result_tuple = (None, None, *[0.0] * 4)
result_tuple

(None, None, 0.0, 0.0, 0.0, 0.0)

In [4]:
from jax.tree_util import tree_map, tree_leaves, tree_reduce
from jax.flatten_util import ravel_pytree

flatten_pytree = lambda pytree: ravel_pytree(pytree)[0]

def physics_loss_fn(params, model, x, y):
    y_pred = apply_model(params, model, *x)
    return jnp.mean(jnp.square(y - y_pred))


def bc_loss_fn(params, model, x, y):
    y_pred = apply_model(params, model, *x)
    return jnp.mean(jnp.square(y - y_pred))


def data_loss_fn(params, model, x, y):
    y_pred = apply_model(params, model, *x)
    return jnp.mean(jnp.square(y - y_pred))


def total_loss_fn(params, model, x, y):
    return dict(
        physics_loss=physics_loss_fn(params, model, x, y),
        bc_loss=bc_loss_fn(params, model, x, y),
        data_loss=data_loss_fn(params, model, x, y),
    )
    
x = tuple([0.0] * 4)
y = jnp.zeros((2,))

# Compute the gradient of each loss w.r.t. the parameters
grads = jax.jacrev(total_loss_fn, argnums=0)(model.get_params(), model, x, y)

# Compute the grad norm of each loss
grad_norm_dict = {}
for key, value in grads.items():
    flattened_grad = flatten_pytree(value)
    grad_norm_dict[key] = jnp.linalg.norm(flattened_grad)


# Compute the mean of grad norms over all losses
mean_grad_norm = jnp.mean(jnp.stack(tree_leaves(grad_norm_dict)))

# Grad Norm Weighting
w = tree_map(lambda x: (mean_grad_norm / x), grad_norm_dict)
w

# add losses together
loss = tree_reduce(lambda x, y: x * y, w) 
loss

AttributeError: 'SIREN' object has no attribute 'get_params'

In [None]:
'Adaptive loss weighting'
            if epoch % update_loss_weights_every_iter == 0:
                
                # All parameters in the MLP
                if pnet_helmholtz.__class__.__name__ == 'MLP' or pnet_helmholtz.__class__.__name__ == 'ModMLP':
                    thetas = list(map(lambda layer: layer.weight, pnet_helmholtz.linears)) 

                # Calculate derivative with respect to parameters
                loss_data_grads = torch.autograd.grad(loss_data, thetas, torch.ones_like(loss_data), retain_graph=True)[0]  # 
                loss_r_grads = torch.autograd.grad(loss_r, thetas, torch.ones_like(loss_r), retain_graph=True)[0]
                loss_std_grads = list(map(lambda l: torch.autograd.grad(l, thetas, torch.ones_like(l), retain_graph = True)[0], loss_std))    
               

                lambda_data, lambda_residual,lambda_std,loss_logger = calc_loss_weights(loss_logger,loss_data_grads, loss_r_grads, loss_std_grads, 
                                                                            lambda_data, lambda_residual, lambda_std,
                                                                             alpha=alpha_lambda, alpha_std=alpha_lambda_std)


def calc_loss_weights(loss_logger,loss_data_grad, loss_residual_grad, loss_std_grad, 
                      lambda_data_old, lambda_residual_old,lambda_std_old, alpha, alpha_std):
    # Calculate the dynamic loss weightening as proposed in 
    # - "Understanding and mitigating gradient flow pathologies in physics-informed neural networks", S. Wang et al., and 
    # - "An experts's guide to training physics-informed neural networks", S. Wang et al
    

    # Calculate the weights from the relative change between losses
    'Scheme 1: Mean-based weighting scheme'
    data_metrics = torch.norm(loss_data_grad, p=2) 
    res_metrics = torch.norm(loss_residual_grad, p=2)
    std_metrics = list(map(lambda l: torch.norm(l, p=2), loss_std_grad))


    'Scheme 2: STD based weighting scheme'
    # data_metrics = torch.std(loss_data_grad) 
    # res_metrics = torch.std(loss_residual_grad)
    # std_metrics = list(map(lambda l: torch.std(l), loss_std_grad))

    'Calculate weights'
    lambda_data_temp = data_metrics + res_metrics/data_metrics
    lambda_r_temp = data_metrics + res_metrics/res_metrics
    lambda_std_temp = list(map(lambda l: data_metrics + res_metrics/l, std_metrics))

    'Scheme 3: Kurtosis and standard deviation-based'
    # lambda_data_temp = kurtosis(loss_data_grad)/torch.std(loss_data_grad)
    # lambda_r_temp = kurtosis(loss_residual_grad)/torch.std(loss_residual_grad)
    # lambda_std_temp = list(map(lambda l: kurtosis(l)/torch.std(l), loss_std_grad))

    # moving average
    lambda_data = alpha*lambda_data_old + (1-alpha)*lambda_data_temp
    lambda_r = alpha*lambda_residual_old + (1-alpha)*lambda_r_temp
    lambda_std = [alpha_std*lambda_std_old[i] + (1-alpha_std)*lambda_std_temp[i] for i in range(len(lambda_std_old))]

    # Add to loss logger
    loss_logger.set_weights(data_metrics, res_metrics, std_metrics,lambda_data_temp,lambda_r_temp,lambda_std_temp,
                              lambda_data,lambda_r,lambda_std)

IndentationError: unexpected indent (1530294719.py, line 2)

In [None]:
from src.zpinn.dataio import PressureDataset

dataset = PressureDataset(
    path=r"C:\Users\STNj\dtu\thesis\code\data\processed\inf_baffle.pkl",
)

dataloader = dataset.get_dataloader(batch_size=32, shuffle=True)

f, x, y, z = next(iter(dataloader))[0].keys()


'z'

In [None]:
t = (1,2,3,4)

def test(*args):
    print(args)

def nested_test(*args):
    test(*args)
    
nested_test(*t)

(1, 2, 3, 4)


In [None]:
from sympy import symbols, re, im, diff, pi, I

# Define symbols
p_re, p_im, dp_re, dp_im, a_c, a_0, b_c, b_0, f, f_c, f_0, z_c, rho_0, z = symbols(
    "p_re p_im d_re d_im a_c a_0 b_c b_0 f f_c f_0 z_c rho_0 z", real=True
)

# Define the expression

p = (p_re *a_c + a_0) + 1j*(p_im *b_c + b_0)

un = 1 / (1j * 2 * pi * f * rho_0) / z_c
un *=  (a_c * dp_re + 1j*b_c * dp_im) 

Z_n = p / un

# Extract real and imaginary parts
real_part = re(Z_n).simplify()
imag_part = im(Z_n).simplify()

print("Real part:")
print(real_part)
print("\nImaginary part:")
print(imag_part)

Real part:
2.0*pi*f*rho_0*z_c*(-a_c*d_re*(b_0 + b_c*p_im) + b_c*d_im*(a_0 + a_c*p_re))/(a_c**2*d_re**2 + 1.0*b_c**2*d_im**2)

Imaginary part:
2.0*pi*f*rho_0*z_c*(a_c*d_re*(a_0 + a_c*p_re) + b_c*d_im*(b_0 + b_c*p_im))/(a_c**2*d_re**2 + 1.0*b_c**2*d_im**2)


In [None]:
real_part.factor()

2.0*pi*f*rho_0*z_c*(a_0*b_c*d_im - a_c*b_0*d_re + a_c*b_c*d_im*p_re - a_c*b_c*d_re*p_im)/(1.0*a_c**2*d_re**2 + 1.0*b_c**2*d_im**2)

In [None]:
from abc import ABC, abstractmethod
from functools import partial

import jax.numpy as jnp
from jax import random, pmap, local_device_count

from torch.utils.data import Dataset


class BaseSampler(Dataset):
    def __init__(self, batch_size, rng_key=random.PRNGKey(1234)):
        self.batch_size = batch_size
        self.key = rng_key
        self.num_devices = local_device_count()

    def __getitem__(self, index):
        "Generate one batch of data"
        self.key, subkey = random.split(self.key)
        keys = random.split(subkey, self.num_devices)
        batch = self.data_generation(keys)
        return batch

    def data_generation(self, key):
        raise NotImplementedError("Subclasses should implement this!")


class UniformSampler(BaseSampler):
    def __init__(self, dom, batch_size, rng_key=random.PRNGKey(1234)):
        super().__init__(batch_size, rng_key)
        self.dom = dom
        self.dim = dom.shape[0]

    @partial(pmap, static_broadcasted_argnums=(0,))
    def data_generation(self, key):
        "Generates data containing batch_size samples"
        batch = random.uniform(
            key,
            shape=(self.batch_size, self.dim),
            minval=self.dom[:, 0],
            maxval=self.dom[:, 1],
        )

        return batch
    
    
sampler = UniformSampler(dom=jnp.array([[0.0, 1.0]]), batch_size=32)
it = iter(sampler)

NameError: name 'dataclass' is not defined

In [None]:
next(it)

Array([[[0.09473562],
        [0.80631065],
        [0.96966016],
        [0.7157742 ],
        [0.07006764],
        [0.6013645 ],
        [0.7976681 ],
        [0.03148794],
        [0.55536425],
        [0.6960144 ],
        [0.5531845 ],
        [0.9130745 ],
        [0.8807967 ],
        [0.82231283],
        [0.6479956 ],
        [0.1361841 ],
        [0.08883166],
        [0.00661266],
        [0.09548712],
        [0.34428847],
        [0.04157531],
        [0.33518732],
        [0.33928144],
        [0.58889234],
        [0.34694982],
        [0.20022178],
        [0.23062909],
        [0.18157935],
        [0.62718284],
        [0.45432687],
        [0.5500257 ],
        [0.1126709 ]]], dtype=float32)