In [None]:
# imports

import itertools
import os
import time

import jax
import jax.numpy as jnp
import jax.nn as nn
from jax import vmap, jit, grad
from jax.example_libraries import optimizers
import numpy as np
import numpy.random as npr
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.ticker import MaxNLocator
from scipy.interpolate import interp1d
import seaborn as sns
try:
    from sklearn.model_selection import train_test_split
except ImportError as e:
    install = input(f'{e}, Do you want to install it? [Y/n]')
    if install == 'Y':
        import sys
        !{sys.executable} -m pip install 'scikit-learn'
        import sklearn
        print(sklearn.__version__)
        from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm

from dosipy.utils.viz import fig_config, save_fig, set_colorblind
from dosipy.utils.dataloader import load_antenna_el_properties
from utils import *

In [None]:
# jax.config.update('jax_enable_x64', True)

In [None]:
# report which XLA backend JAX dispatches to by default (e.g. cpu / gpu / tpu);
# `jax.default_backend()` is the public replacement for the deprecated
# `jax.lib.xla_bridge.get_backend().platform`
print(f'platform: {jax.default_backend()}')

In [None]:
# project plotting defaults from `dosipy.utils.viz`; presumably a
# colorblind-friendly palette, judging by the name — confirm in dosipy docs
set_colorblind()
# render inline matplotlib figures at high ("retina") resolution
%config InlineBackend.figure_format = 'retina'

In [None]:
# load pre-computed source data

f = 60e9  # operating frequency of the antenna
antenna_data = load_antenna_el_properties(f)

# complex current samples, Re(I) + j*Im(I), at the positions `xs`
Is = np.asarray(antenna_data.ireal.to_numpy()
                + antenna_data.iimag.to_numpy() * 1j)
xs = antenna_data.x.to_numpy()

L = xs[-1]  # last sample position, used as the right end of the x-axis
Imax = np.max(np.abs(Is))  # peak current magnitude

# x-axis ticks at both ends and the midpoint; the labels are scaled by 1000
# (presumably metres -> millimetres) and centred on the midpoint
xticks = [0, L / 2, L]
xticklabels = [-round(L * 1000 / 2, 2), 0.0, round(L * 1000 / 2, 2)]
xlabel = '$x$ [mm]'

In [None]:
# visualize spatial current distribution

# y-axis ticks/labels shared with the later current plots
yticks = [0, Imax/2, Imax]
yticklabels = [0.0, round(Imax/2, 3), round(Imax, 3)]
ylabel = '$I(x)$ [A]'

fig_config(latex=True, text_size=20, line_width=2.5)
fig = plt.figure()
ax = fig.add_subplot()
ax.plot(xs, np.abs(Is), 'k-', label='$|I(x)|$')
# raw strings: `\R` and `\I` are invalid Python escape sequences (a
# SyntaxWarning since Python 3.12); the LaTeX must reach matplotlib verbatim
ax.plot(xs, np.real(Is), 'k--', label=r'$\Re{[I(x)]}$')
ax.plot(xs, np.imag(Is), 'k:', label=r'$\Im{[I(x)]}$')
ax.set(xlabel=xlabel,
       xticks=xticks,
       xticklabels=xticklabels,
       ylabel=ylabel,
       yticks=[-Imax/2, *yticks],
       yticklabels=[-round(Imax/2, 3), *yticklabels])
ax.legend(prop={'size': 20});

#fname = os.path.join('figures', 'complex_current')
#save_fig(fig, fname=fname, formats=['png'])

In [None]:
def forward_diff(fn, h=1e-6):
    """Build a first-order forward-difference derivative of `fn`.

    Parameters
    ----------
    fn : callable
        Single-argument function to differentiate.
    h : float, optional
        Step size; the estimate is ``(fn(x + h) - fn(x)) / h``.

    Returns
    -------
    callable
        Function mapping `x` to the derivative estimate at `x`.
    """
    def estimate(x):
        ahead = fn(x + h)
        here = fn(x)
        return (ahead - here) / h
    return estimate
    
    
def backward_diff(fn, h=1e-6):
    """Build a first-order backward-difference derivative of `fn`.

    Parameters
    ----------
    fn : callable
        Single-argument function to differentiate.
    h : float, optional
        Step size; the estimate is ``(fn(x) - fn(x - h)) / h``.

    Returns
    -------
    callable
        Function mapping `x` to the derivative estimate at `x`.
    """
    def estimate(x):
        here = fn(x)
        behind = fn(x - h)
        return (here - behind) / h
    return estimate

    
def central_diff(fn, h=1e-6):
    """Build a second-order central-difference derivative of `fn`.

    Parameters
    ----------
    fn : callable
        Single-argument function to differentiate.
    h : float, optional
        Half-width of the stencil; the estimate is
        ``(fn(x + h) - fn(x - h)) / (2 * h)``.

    Returns
    -------
    callable
        Function mapping `x` to the derivative estimate at `x`.
    """
    width = 2 * h

    def estimate(x):
        upper = fn(x + h)
        lower = fn(x - h)
        return (upper - lower) / width
    return estimate


def complex_step_diff(fn, h=1e-6):
    """Build a complex-step derivative of `fn`.

    The estimate is ``Im(fn(x + i*h)) / h``. `fn` must therefore accept
    complex arguments.

    Note: Incompatible with SciPy interpolation module.

    Parameters
    ----------
    fn : callable
        Single-argument function, evaluated at a complex-shifted point.
    h : float, optional
        Size of the imaginary step.

    Returns
    -------
    callable
        Function mapping `x` to the derivative estimate at `x`.
    """
    def estimate(x):
        perturbed = fn(x + 1j * h)
        return np.imag(perturbed) / h
    return estimate

In [None]:
# quadratic-spline interpolant of the current magnitude |I(x)|
Is_fn = interp1d(x=xs, y=np.abs(Is), kind='quadratic')

# dense, uniformly spaced grid (961 points) spanning the sampled positions,
# and the interpolant evaluated on it
xs_interp = np.linspace(start=xs.min(), stop=xs.max(), num=961)
Is_interp = Is_fn(xs_interp)

In [None]:
# mean relative error between |I| and its interpolant at the interior sample
# points; the first and last samples are left out of the ratio (presumably
# because |I| is ~0 at the antenna ends — confirm)
# NOTE(review): `Is_fn` is evaluated at the very nodes it was built from, so
# this mostly reflects floating-point round-off, not interpolation accuracy
Is_abs = np.abs(Is)
error = np.abs(Is_abs[1:-1] - Is_fn(xs)[1:-1]) / Is_abs[1:-1]
mean_error = error.mean()
print(f'relative error = {mean_error * 100:.2e} %')

In [None]:
# visualize spatial current distribution as calculated and interpolated

fig_config(latex=True, text_size=20, line_width=2.5, marker_size=7)
fig = plt.figure()
ax = fig.add_subplot()
ax.plot(xs, np.abs(Is), 'k-', label='$|I(x)|$')
# raw string: `\h` is an invalid Python escape sequence (a SyntaxWarning
# since Python 3.12); the LaTeX must reach matplotlib verbatim
ax.plot(xs_interp, Is_interp, 'ko', markevery=30, label=r'$|\hat{I}(x)|$')
ax.set(xlabel=xlabel,
       xticks=xticks,
       xticklabels=xticklabels,
       ylabel=ylabel,
       yticks=yticks,
       yticklabels=yticklabels)
ax.legend();

# fname = os.path.join('figures', 'interp_current')
# save_fig(fig, fname=fname, formats=['png'])

In [None]:
# finite differences of the interpolant `Is_fn`, evaluated on the original
# sample grid (`xs`) and on the dense interpolation grid (`xs_interp`)


def _fd_gradient(fn, x):
    """Finite-difference derivative of `fn` at every point of the 1-D grid `x`.

    Interior points use central differences. The first point uses a forward
    difference and the last a backward difference, so `fn` is only evaluated
    inside ``[x[0], x[-1]]`` as long as the step is smaller than the spacing.
    """
    return np.r_[forward_diff(fn)(x[0]),
                 central_diff(fn)(x[1:-1]),
                 backward_diff(fn)(x[-1])]


dIsdxs = _fd_gradient(Is_fn, xs)
dIsdxs_interp = _fd_gradient(Is_fn, xs_interp)

In [None]:
# visualize current gradient distribution as calculated and interpolated

fig_config(latex=True, text_size=20, line_width=2.5, marker_size=7)
fig = plt.figure()
ax = fig.add_subplot()
# raw strings: `\m` is an invalid Python escape sequence (a SyntaxWarning
# since Python 3.12); the LaTeX must reach matplotlib verbatim
ax.plot(xs, dIsdxs, 'k-', label=r'$\mathrm{d}|I|$/$\mathrm{d}x$')
ax.plot(xs_interp, dIsdxs_interp, 'ko', markevery=30, label=r'$\mathrm{d}|\hat{I}|$/$\mathrm{d}x$')
# large hollow circles around the second and second-to-last gradient samples
ax.plot([xs[1], xs[-2]], [dIsdxs[1], dIsdxs[-2]], 'ko', fillstyle='none',
        markersize=50, markeredgewidth=2.5)
# text box and arrows; positions/offsets are hand-tuned for this dataset
textbox = ax.text(xs[0], -dIsdxs[1]-4, s='numerical artifacts', fontweight='bold',
                  size=18,
                  bbox={'facecolor': 'lightgray',
                        'edgecolor': 'black',
                        'alpha': 1,
                        'pad': 5})
ax.annotate('', xy=(xs[2], dIsdxs[5]-2),
            xytext=(xs[11], -dIsdxs[2]),
            arrowprops={'facecolor': 'black'})
ax.annotate('', xy=(xs[-6], dIsdxs[-2]),
            xytext=(xs[11]+0.000925, -dIsdxs[2]-5),
            arrowprops={'facecolor': 'black',})
ax.set(xlabel=xlabel, ylabel=r'$\mathrm{d}I$/$\mathrm{d}x$ [A/m]',
       xticks=xticks,
       xticklabels=xticklabels)
ax.legend(prop={'size': 18});

# fname = os.path.join('figures', 'grad_interp_current')
# save_fig(fig, fname=fname, formats=['png'])

In [None]:
rng = jax.random.PRNGKey(0)  # seed for the network parameter initialization

def init_network_params(sizes, key, scale=1e-2):
    """Initialize the weights and biases of a fully connected network.

    Parameters
    ----------
    sizes : sequence of int
        Layer widths, input first and output last; ``len(sizes) - 1`` dense
        layers are created.
    key : jax.random.PRNGKey
        Random key. It is split into ``len(sizes)`` subkeys, one per layer
        (the last subkey is unused).
    scale : float, optional
        Standard deviation of the zero-mean normal draws used for both the
        weights and the biases. Default 1e-2.

    Returns
    -------
    list of (w, b) tuples
        ``w`` has shape ``(n_out, n_in)`` and ``b`` has shape ``(n_out,)``.
    """
    keys = jax.random.split(key, len(sizes))

    def random_layer_params(m, n, layer_key):
        w_key, b_key = jax.random.split(layer_key)
        return (scale * jax.random.normal(w_key, (n, m)),
                scale * jax.random.normal(b_key, (n, )))

    return [random_layer_params(m, n, k)
            for m, n, k in zip(sizes[:-1], sizes[1:], keys)]


def forward(params, X, scaler):
    """Evaluate the fully connected network on a single input vector `X`.

    Every layer except the last applies an affine map followed by ``tanh``.
    The last layer is affine only. The result is multiplied by `scaler` to
    put the output on the scale of the training targets.
    """
    hidden_layers, (w_out, b_out) = params[:-1], params[-1]
    activation = X
    for weight, bias in hidden_layers:
        activation = nn.tanh(weight @ activation + bias)
    return (w_out @ activation + b_out) * scaler


# `forward` mapped over the leading axis of `X`; the parameters and the
# output scale are shared (not mapped) across the batch
batch_forward = vmap(forward, in_axes=(None, 0, None))


@jit
def loss_fn(params, batch, scaler):
    """Sum of squared errors of the network predictions over one batch.

    Parameters
    ----------
    params : list of (w, b) tuples
        Network weights and biases.
    batch : tuple
        ``(X, y)`` pair of inputs and targets.
    scaler : float
        Output scale forwarded to `batch_forward`.
    """
    inputs, targets = batch
    residual = batch_forward(params, inputs, scaler) - targets
    return jnp.sum(jnp.square(residual))


# gradient of the loss with respect to the network parameters (argument 0)
grad_fn = jit(grad(loss_fn))


@jit
def update(step, optim_state, batch, scaler):
    """Apply one optimizer step on `batch` and return the new optimizer state.

    NOTE(review): `optim_params` and `optim_update` are module-level names that
    are only assigned in the training cell. They are looked up when `update`
    is first traced, so that cell must assign them before the first call.
    """
    gradients = grad_fn(optim_params(optim_state), batch, scaler)
    return optim_update(step, gradients, optim_state)

In [None]:
# network inputs: interpolation-grid positions passed through the project
# helper `normalize` (from `utils`), shaped as a column of 1-feature samples
xs_norm = normalize(xs_interp, xs_interp)
xs_data = jnp.array(xs_norm).reshape(-1, 1)
# the following scaling is a bit weird but is needed in order to keep
# gradients of the nn output wrt input (not parameters, but actual input) in
# the same scale with the ones that are computed via FDM
Is_norm = normalize(xs_interp, Is_interp)
Is_data = jnp.array(Is_norm).reshape(-1, 1)
# fixed `random_state` so the train/test split is reproducible across kernel
# restarts. This matters because the split determines which parameter snapshot
# is later selected as best, and it matches the fixed seeds already used for
# initialization and batching.
X_train, X_test, y_train, y_test = train_test_split(
    xs_data, Is_data, test_size=0.25, random_state=0)

In [None]:
# set network hyperparameters and train

# to set the output of the nn in scale with the target data, we define scaler
scaler = np.abs(y_train).max()
step_size = 1e-3
n_epochs = 10_000
printout = int(n_epochs / 100.)  # losses/params are logged every `printout` epochs
epochs = np.arange(0, n_epochs+1, step=printout)  # x-axis of the loss plot
batch_size = 128
momentum_mass = 0.9  # for momentum and adagrad (not used by adam below)
sizes = [1, 128, 256, 128, 1]

num_train = X_train.shape[0]
num_complete_batches, leftover = divmod(num_train, batch_size)
num_batches = num_complete_batches + bool(leftover)


def data_stream(num_train, num_batches):
    """Endless generator of shuffled ``(X, y)`` training mini-batches.

    Each pass draws a fresh permutation of the `num_train` training indices
    from a seeded NumPy RNG and yields `num_batches` slices of up to
    `batch_size` samples. The last slice holds any remainder.
    """
    rng = npr.RandomState(0)  # local NumPy RNG; shadows the JAX `rng` key
    while True:
        perm = rng.permutation(num_train)
        for i in range(num_batches):
            batch_idx = perm[i * batch_size:(i + 1) * batch_size]
            yield X_train[batch_idx], y_train[batch_idx]


batches = data_stream(num_train, num_batches)

optim_init, optim_update, optim_params = optimizers.adam(step_size)
init_params = init_network_params(sizes, rng)
optim_state = optim_init(init_params)
itercount = itertools.count()

loss_train, loss_test = [], []
params_list = []
start_time = time.time()
pbar = tqdm(range(n_epochs))
for epoch in pbar:
    for _ in range(num_batches):
        optim_state = update(next(itercount), optim_state, next(batches), scaler)

    params = optim_params(optim_state)
    # snapshot after the first epoch and then after every `printout` epochs;
    # with the settings above this yields len(epochs) snapshots
    if (epoch == 0) or (epoch % printout == (printout - 1)):
        params_list.append(params)
        # store Python floats rather than 0-d device arrays, so the later
        # min/index lookup and the loss plot operate on plain numbers
        curr_loss_train_val = float(loss_fn(params, (X_train, y_train), scaler))
        curr_loss_test_val = float(loss_fn(params, (X_test, y_test), scaler))
        loss_train.append(curr_loss_train_val)
        loss_test.append(curr_loss_test_val)
        pbar.set_description(f'Loss (test): {curr_loss_test_val:.4e}')
training_duration = time.time() - start_time
print(f'Training time: {training_duration:.2f} s')

In [None]:
# visualize training loss dynamics (every 10th logged value, log-scale y-axis)

fig_config(latex=True, text_size=24, line_width=4)
fig, ax = plt.subplots()
ax.plot(epochs[::10], loss_train[::10], 'k-', label='train set')
ax.plot(epochs[::10], loss_test[::10], 'k--', label='test set')
ax.set_xlabel('epoch')
ax.set_ylabel('loss')
ax.set_yscale('log')
ax.set_yticks([1e-1, 1e1, 1e3])
ax.legend(prop={'size': 24})

# fname = os.path.join('figures', 'loss')
# save_fig(fig, fname=fname, formats=['png'])

In [None]:
# choose params with the best performance on test set:
# the first logged snapshot whose test loss is the smallest
best_params_idx = min(range(len(loss_test)), key=loss_test.__getitem__)
params = params_list[best_params_idx]

In [None]:
# visualize spatial current distribution as calculated and fitted

# network prediction on the full (train + test) input grid; `inv_normalize`
# (project helper from `utils`) presumably undoes the target normalization
Is_fit = batch_forward(params, xs_data.reshape(-1, 1), scaler)
Is_fit_inv_norm = inv_normalize(Is_fit, xs_interp)

fig_config(latex=True, text_size=20, line_width=2.5, marker_size=7)
fig, ax = plt.subplots()
ax.plot(xs_interp, Is_interp, 'k-', label='$|I(x)|$')
ax.plot(xs_interp, Is_fit_inv_norm, 'ko', markevery=30, label='NN$(x)$')
ax.set_xlabel(xlabel)
ax.set_xticks(xticks)
ax.set_xticklabels(xticklabels)
ax.set_ylabel(ylabel)
ax.set_yticks(yticks)
ax.set_yticklabels(yticklabels)
ax.legend()

# fname = os.path.join('figures', 'nn_current')
# save_fig(fig, fname=fname, formats=['png'])

In [None]:
def Is_nn(xs):
    """Scalar network prediction of the current at a single location `xs`.

    Thin wrapper around `forward` that uses the module-level `params` and
    `scaler` and unwraps the one-element output. `grad` can then
    differentiate it with respect to the input position.

    NOTE(review): `params`/`scaler` are read from module scope. Under `jit`
    they are captured when `grad_Is_nn` is first traced, so later changes to
    the global `params` may not be reflected in `grad_Is_nn`.
    """
    prediction = forward(params, xs, scaler)
    return prediction[0]


# d(NN)/dx for a batch of inputs: gradient w.r.t. the input position,
# vectorized over the leading axis and JIT-compiled
grad_Is_nn = jit(vmap(grad(Is_nn)))

In [None]:
# visualize current gradient distribution as calculated and fitted

fig_config(latex=True, text_size=20, line_width=2.5, marker_size=7)
fig = plt.figure()
ax = fig.add_subplot()
# raw strings: `\m` is an invalid Python escape sequence (a SyntaxWarning
# since Python 3.12); the LaTeX must reach matplotlib verbatim
ax.plot(xs_interp, dIsdxs_interp, 'k-', label=r'$\mathrm{d}|I|$/$\mathrm{d}x$')
ax.plot(xs_interp, grad_Is_nn(xs_data), 'ko', markevery=30, label=r'$\mathrm{d}$NN/$\mathrm{d}x$')
# same hollow circles as in the finite-difference plot, for comparison
ax.plot([xs[1], xs[-2]], [dIsdxs[1], dIsdxs[-2]], 'ko', fillstyle='none',
        markersize=50, markeredgewidth=2.5)
# text box and arrows; positions/offsets are hand-tuned for this dataset
textbox = ax.text(xs[0], -dIsdxs[1]-4,
                  s='smooth gradients with\nthe network interpolant', size=18,
                  bbox={'facecolor': 'lightgray',
                        'edgecolor': 'black',
                        'alpha': 1,
                        'pad': 5})
ax.annotate('', xy=(xs[2], dIsdxs[5]-2),
            xytext=(xs[11], -dIsdxs[2]+6),
            arrowprops={'facecolor': 'black'})
ax.annotate('', xy=(xs[-6], dIsdxs[-2]),
            xytext=(xs[11]+0.0013, -dIsdxs[2]-2.5),
            arrowprops={'facecolor': 'black',})
ax.set(xlabel=xlabel, ylabel=r'$\mathrm{d}I$/$\mathrm{d}x$ [A/m]',
       xticks=xticks,
       xticklabels=xticklabels)
ax.legend(prop={'size': 18});

# fname = os.path.join('figures', 'grad_nn_current')
# save_fig(fig, fname=fname, formats=['png'])