In [None]:
import numpy as np
from matplotlib import pyplot as plt
import bgk

In [None]:
# Path to the BGK case-1 (B=1) input profile.
# NOTE(review): absolute cluster path — adjust when running on another machine.
BGK_INPUT_PATH = "/mnt/lustre/IAM851/jm1667/psc/inputs/bgk/case1-B=1-input.txt"
# Sample selection passed to `bgk.Input.truncate` (presumably every 10th of the first
# 3000 samples — confirm against bgk). An alternative that was tried is
# slice(1, 3000, 10), which skips the first sample.
BGK_TRUNCATE = slice(0, 3000, 10)

# NOTE(review): `input` shadows the Python builtin `input()`; the cells below refer to
# this name, so it is kept as-is.
input = bgk.Input(BGK_INPUT_PATH)
input.truncate(BGK_TRUNCATE)
input.convert_to_cs_units()

In [None]:
# Model parameters shared by the eq. 19 right-hand side and by f(w, l).
# NOTE(review): B presumably has to match the "B=1" case named in the input
# filename — keep the two in sync.
h0, k, B = 0.9, 0.1, 1

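## Compare LHS and RHS of eq. 19

Check numerically that the loaded $\psi(\rho)$ profile satisfies eq. 19,

$$\frac{1}{\rho}\frac{d}{d\rho}\left(\rho\,\frac{d\psi}{d\rho}\right) = e^{\psi}\left[1 - \frac{h_0}{\sqrt{1 + 8k\rho^2}}\,\exp\left(-\frac{k B^2 \rho^4}{1 + 8k\rho^2}\right)\right] - 1,$$

with the left-hand side evaluated by finite differences.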

In [None]:
psi = input.Psi
rho = input.rho

# Left-hand side of eq. 19: (1/rho) d/drho (rho dpsi/drho), by finite differences.
# np.gradient is given the rho coordinates, so a non-uniform grid is handled.
# A grid point with rho == 0 would divide by zero; mark it NaN (skipped when
# plotted) instead of producing inf plus a RuntimeWarning.
with np.errstate(divide="ignore", invalid="ignore"):
    lhs = np.where(rho != 0, np.gradient(np.gradient(psi, rho) * rho, rho) / rho, np.nan)

# Right-hand side of eq. 19, using the parameters h0, k, B.
rhs_denom = 1 + 8 * k * rho**2
rhs = np.exp(psi) * (1 - h0 / np.sqrt(rhs_denom) * np.exp(-k * B**2 * rho**4 / rhs_denom)) - 1

# Fresh figure: once `%matplotlib widget` is active, figures stay open, and bare
# plt.* calls on a re-run would draw onto another cell's figure.
fig, ax = plt.subplots()
ax.plot(rho, lhs, ".", label="lhs")
ax.plot(rho, rhs, label="rhs")
ax.set_xlabel("rho")
ax.set_title("Eq. 19: LHS vs RHS")
ax.legend()

plt.show()

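## View $f(w, l)$

$$f = (2\pi)^{-3/2}\, e^{-w}\left(1 - h_0\, e^{-k l^2}\right), \qquad w = \tfrac{1}{2} v_\phi^2 - \psi, \qquad l = 2\rho\, v_\phi - B\rho^2,$$

shown first along the sampled radial profile, then on a regular $(w, l)$ grid.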

In [None]:
from math import exp, pi  # `exp` kept in the notebook namespace; calc_f uses np.exp
from typing import Optional


def calc_v_phi(rho: float) -> float:
    """Return "v_phi" interpolated from the loaded input profile at radius ``rho``."""
    return input.interpolate_value(rho, "v_phi")


def calc_psi(rho: float) -> float:
    """Return "Psi" interpolated from the loaded input profile at radius ``rho``."""
    return input.interpolate_value(rho, "Psi")


def calc_l(rho: float) -> float:
    """Return l = 2 * rho * v_phi(rho) - B * rho**2 (uses the global ``B``)."""
    return 2 * rho * calc_v_phi(rho) - B * rho**2


def calc_w(*, rho: float) -> float:
    """Return w = v_phi(rho)**2 / 2 - Psi(rho)."""
    return calc_v_phi(rho)**2 / 2 - calc_psi(rho)


def calc_f(*, rho: Optional[float] = None, w: Optional[float] = None,
           l: Optional[float] = None) -> float:
    """Evaluate f = (2*pi)**(-3/2) * exp(-w) * (1 - h0 * exp(-k * l**2)).

    Pass either ``rho`` (w and l are then computed at that radius via calc_w and
    calc_l, overriding any ``w``/``l`` also given) or both ``w`` and ``l``
    directly. ``w`` and ``l`` may also be NumPy arrays, in which case f is
    evaluated element-wise. Uses the global parameters ``h0`` and ``k``.

    Raises:
        ValueError: if ``rho`` is None and ``w`` or ``l`` is missing.
    """
    if rho is not None:
        w = calc_w(rho=rho)
        l = calc_l(rho=rho)
    elif w is None or l is None:
        raise ValueError("calc_f needs either rho, or both w and l")
    # np.exp (rather than math.exp) lets w and l be arrays as well as scalars.
    return (2 * pi) ** (-3 / 2) * np.exp(-w) * (1 - h0 * np.exp(-k * l**2))

In [None]:
%matplotlib widget
rhos = np.linspace(input.rho[0], input.rho[-2], 100)
fs = [calc_f(rho=rho) for rho in rhos]
ws = [calc_w(rho=rho) for rho in rhos]
ls = [calc_l(rho=rho) for rho in rhos]

plt.hexbin(ws, ls, fs, gridsize=50, cmap="Greens", vmax=.13)
plt.colorbar()
plt.xlabel("w")
plt.ylabel("l");

In [None]:
%matplotlib widget
from itertools import product


ws = np.array([calc_w(rho=rho) for rho in rhos])
ls = np.array([calc_l(rho=rho) for rho in rhos])
ws = np.linspace(np.min(ws), np.max(ws), 100)
ls = np.linspace(np.min(ls)/15, np.max(ls), 100)
ws, ls = np.meshgrid(ws, ls)
ws = ws.flatten()
ls = ls.flatten()
fs2 = np.array([calc_f(w=w, l=l) for w, l in zip(ws, ls)])

# fs2 = fs2.flatten()
plt.hexbin(ws, ls, fs2, gridsize=50, cmap="Greens", vmax=.13)
plt.colorbar()
plt.xlabel("w")
plt.ylabel("l");