In [None]:
import os

import numpy as np
import pylab as plt
import pyrap.tables as pt
from daskms.experimental.zarr import xds_from_zarr
from h5parm import DataPack

# --- User configuration: edit these for the run being inspected ---
# h5parm with the simulated ionosphere (opened with DataPack further down).
ionosphere_file = './output/sim_dsa2000W_dawn_30.0_1.5_2src/sim_dsa2000W_dawn_30.0_1.5_2src.h5'
# Zarr store holding the solved gain term (read with xds_from_zarr).
solutions_file = './output/W-64chan-20int-2src-iono/32chan_6s_2src/gains::G'
# Measurement set whose tables are read with pyrap.
ms_file = './data/W-64chan-20int-2src-iono.ms/'
# Index of the direction to inspect.
dir_idx = 0


# Exactly one dataset is expected in the store; the unpacking fails otherwise.
(gains,) = xds_from_zarr(solutions_file)

def wrap(phi):
    """Fold angle(s) in radians into the half-open interval [-pi, pi).

    Works on scalars and numpy arrays alike (element-wise for arrays).
    """
    full_turn = 2 * np.pi
    shifted = phi + np.pi
    return shifted % full_turn - np.pi

# Pull the simulated (ground-truth) phase and TEC out of the h5parm.
with DataPack(ionosphere_file, readonly=True) as dp:
    # Phase: wrap, reference to the antenna at index 0 of axis 2, keep index 0
    # of the leading axis, then re-wrap and scale radians -> degrees.
    sim_phase, axes = dp.phase
    sim_phase = wrap(sim_phase)
    sim_phase = sim_phase - sim_phase[:, :, :1]
    sim_phase = wrap(sim_phase[0]) * 180. / np.pi

    # TEC: same referencing and leading-axis selection, but no wrapping.
    sim_tec, _ = dp.tec
    sim_tec -= sim_tec[:, :, :1]
    sim_tec = sim_tec[0]

    # Coordinate objects for each labelled axis of the phase array.
    _, sim_directions = dp.get_directions(axes['dir'])
    _, sim_times = dp.get_times(axes['time'])
    _, sim_freqs = dp.get_freqs(axes['freq'])
    _, sim_ants = dp.get_antennas(axes['ant'])

    print(dp.axes_order)

    # With the leading axis dropped the phase array is 4-D; the sizes are
    # named here as directions, antennas, frequencies and times.
    Nd, Na, Nf, Nt = sim_phase.shape

# Validate the requested direction against the simulation before using it.
# Negative values are rejected too: numpy would accept them as plain indices,
# but the later slice ``dir_idx:dir_idx + 1`` on the solved gains does not
# behave like an index (e.g. ``-1:0`` is empty).
if dir_idx < 0:
    raise IndexError(f"Direction index {dir_idx} must be non-negative.")
if dir_idx >= Nd:
    raise IndexError(f"Direction index {dir_idx} too big.")
print(f"DD simulation, with {Nd} directions.")
print(f"Inspecting direction {dir_idx}")


# Antenna positions from the measurement set's ANTENNA sub-table.
antenna_table_path = os.path.join(ms_file, "ANTENNA")
with pt.table(antenna_table_path) as t:
    ant_pos = t.getcol('POSITION')
    # Lookup from each STATION entry to its POSITION row (read separately so
    # the map does not alias ``ant_pos``).
    ant_pos_map = {
        station: position
        for station, position in zip(t.getcol('STATION'), t.getcol('POSITION'))
    }

In [None]:
# Channel frequencies from the SPECTRAL_WINDOW sub-table.
spw_table_path = os.path.join(ms_file, "SPECTRAL_WINDOW")
with pt.table(spw_table_path) as t:
    ms_freqs = t.getcol('CHAN_FREQ')

# Distinct TIME values of the main table, divided by 86400 (seconds per day).
with pt.table(ms_file) as t:
    ms_times = np.unique(t.getcol('TIME')) / 86400.

# Side-by-side report of the MS and simulation time axes, plus the offset
# between their first samples converted back to seconds.
print('MS times (mjd):', ms_times)
print('Sim times (mjd):', sim_times.mjd)
print('MS dt (mjd):', np.diff(ms_times)[0])
print('Sim dt (mjd):', np.diff(sim_times.mjd)[0])
offset = (ms_times[0] - sim_times.mjd[0]) * 86400.
print('MS Offset (s):', offset)

In [None]:

# Compare the number of solved directions with the number simulated.
n_solve_dirs = len(gains.params.direction)
if n_solve_dirs > 1:
    print(f"DD solve with {n_solve_dirs} dirs")
if n_solve_dirs != Nd:
    print(f"DD simulation directions {Nd} doesn't match number of solve directions {n_solve_dirs}.")
# Last expression of the cell: let the notebook render the params array.
gains.params

In [None]:

# Ground-truth differential TEC for the selected direction (antenna x time).
fig, ax = plt.subplots()
tec_image = ax.imshow(sim_tec[dir_idx, :, :], aspect='auto', cmap='jet')
fig.colorbar(tec_image, ax=ax, label=r'$\Delta$TEC [mTECU]')
ax.set_xlabel('Time stamp')
ax.set_ylabel('Antenna Index')
ax.set_title(r'Simulated $\Delta$ Tec Ground Truth')
plt.show()

# Ground-truth phase at the first simulated frequency (antenna x time).
fig, ax = plt.subplots()
phase_image = ax.imshow(sim_phase[dir_idx, :, 0, :], aspect='auto', cmap='jet')
fig.colorbar(phase_image, ax=ax, label='Phase [deg]')
ax.set_xlabel('Time stamp')
ax.set_ylabel('Antenna Index')
ax.set_title(f'Simulated Phase Ground Truth @ {sim_freqs[0]}')
plt.show()


TEC_CONV = -8.4479745e6  # Hz/mTECU

# Phase implied by the TEC of the last antenna, evaluated on the MS channel
# frequencies: rows are time stamps, columns are channels.
tec_last_antenna = sim_tec[dir_idx, -1, :, None]
sim_phase_over_freq = wrap(tec_last_antenna * (TEC_CONV / ms_freqs)) * 180/np.pi
fig, ax = plt.subplots()
freq_image = ax.imshow(sim_phase_over_freq, aspect='auto', cmap='jet')
fig.colorbar(freq_image, ax=ax, label='Phase [deg]')
ax.set_xlabel('Channel index')
ax.set_ylabel('Time stamp')
ax.set_title('Simulated Phase Ground Truth')
plt.show()


# Phase of the first frequency and first time stamp, laid out on the
# simulation's antenna x/y coordinates.
fig, ax = plt.subplots()
sc = ax.scatter(sim_ants.x, sim_ants.y, c=sim_phase[dir_idx, :, 0, 0], cmap='jet')
fig.colorbar(sc, ax=ax, label='phase [deg]')
ax.set_title('Simulation ground truth phase XX (gridded)')
ax.set_xlabel('X')
ax.set_ylabel('Y')
plt.show()

In [None]:
# Coordinates of the solved parameters.
freqs = gains.params.param_freq.values
times = gains.params.param_time.values

# Wrap the solved parameters and reference every antenna to the one at
# index 0 of axis 2.
phase = wrap(gains.params.values)
phase -= phase[:, :, 0:1, ...]

# Slicing never raises: an out-of-range ``dir_idx`` would just produce an
# empty direction axis and surface later as an opaque indexing error, so the
# request is checked explicitly against the second-to-last axis.
n_solve_dirs = phase.shape[-2]
if not 0 <= dir_idx < n_solve_dirs:
    raise IndexError(
        f"Direction index {dir_idx} is not available in the solutions "
        f"({n_solve_dirs} solved directions)."
    )

# Keep only the requested direction (as a length-1 axis), then re-wrap after
# the referencing and convert radians -> degrees.
phase = phase[..., dir_idx:dir_idx + 1, :]
phase = wrap(phase) * 180/np.pi

In [None]:
# Simulated phase per antenna at the first frequency and first time stamp.
fig, ax = plt.subplots()
ax.plot(sim_phase[dir_idx, :, 0, 0])
ax.set_title("Simulated Phase XX")
ax.set_xlabel("Ant Index")
ax.set_ylabel("Phase [deg]")
plt.show()

# Solved phase per antenna for the first entry of every other axis.
fig, ax = plt.subplots()
ax.plot(phase[0, 0, :, 0, 0])
ax.set_title("Solved Phase XX")
ax.set_xlabel("Ant Index")
ax.set_ylabel("Phase [deg]")
plt.show()

# What the simulation looks like

In [None]:
# Time-averaged simulated phase per antenna at the first frequency.
sim_phase_mean = np.mean(sim_phase[dir_idx, :, 0, :], axis=-1)

# Colour limits from the 5th-95th percentile of the average, shared by all the
# maps below so they can be compared by eye.
vmin, vmax = np.percentile(sim_phase_mean, [5, 95])


# NOTE(review): positions come from the MS while the phases come from the
# simulation; this assumes both list the antennas in the same order - confirm.
plt.scatter(ant_pos[:, 0], ant_pos[:, 1], c=sim_phase_mean, cmap='jet', vmin=vmin, vmax=vmax)
plt.colorbar(label='Phase [deg]')
plt.xlabel('X [m]')
plt.ylabel('Y [m]')
plt.title("Average Simulated Phase XX")
plt.show()



# One snapshot every ``Nt // 5`` time steps. For Nt < 5 that quotient is 0 and
# range() rejects a zero step, so fall back to showing every time step.
snapshot_step = max(1, Nt // 5)
for t in range(0, Nt, snapshot_step):

    plt.scatter(ant_pos[:, 0], ant_pos[:, 1], c=sim_phase[dir_idx, :, 0, t], cmap='jet', vmin=vmin, vmax=vmax)
    plt.colorbar(label='Phase [deg]')
    plt.xlabel('X [m]')
    plt.ylabel('Y [m]')
    plt.title(f"timestep: {t} Simulated Phase XX")
    plt.show()

In [None]:
# Same colour limits as the simulated-phase maps.
vmin, vmax = np.percentile(sim_phase_mean, [5, 95])

# Let's average the simulated phase in blocks, one block per solution
# interval. Integer division means trailing simulated time steps that do not
# fill a whole block are left out of the comparison.
n_intervals = phase.shape[0]
blocksize = Nt // n_intervals
if blocksize < 1:
    # A zero block size would make every slice below empty and every mean NaN.
    raise ValueError(
        f"Cannot block-average {Nt} simulated time steps over {n_intervals} solution intervals."
    )

variance_per_block = []
for t in range(n_intervals):

    # Solved phase for this solution interval.
    plt.scatter(ant_pos[:, 0], ant_pos[:, 1], c=phase[t, 0, :, 0, 0],
                vmin=vmin, vmax=vmax,
                cmap='jet')

    plt.colorbar(label='Phase [deg]')
    plt.xlabel('X [m]')
    plt.ylabel('Y [m]')
    plt.title(f"Timstep {t}: Solved Phase XX")

    plt.show()

    start = t * blocksize
    stop = (t+1) * blocksize

    # Ground truth averaged over the simulated time steps of this block.
    # NOTE(review): this is an arithmetic mean of wrapped phases; it is only
    # meaningful while the phases in a block do not straddle +/-180 deg.
    block_averaged = np.mean(sim_phase[dir_idx, :, 0, start:stop], axis=-1)

    plt.scatter(ant_pos[:, 0], ant_pos[:, 1], c=block_averaged, cmap='jet', vmin=vmin, vmax=vmax)
    plt.colorbar(label='Phase [deg]')
    plt.xlabel('X [m]')
    plt.ylabel('Y [m]')
    plt.title(f"Block Averaged {start} to {stop}: Simulated Phase XX")
    plt.show()

    # Phases are only defined modulo 360 deg, so fold the residual back into
    # [-180, 180). Without this a pair such as +179 / -179 deg would count as
    # a 358 deg error instead of 2 deg and inflate the statistics below.
    diff_phase = (block_averaged - phase[t, 0, :, 0, 0] + 180.) % 360. - 180.
    _vmin, _vmax = np.percentile(diff_phase, [5, 95])
    plt.scatter(ant_pos[:, 0], ant_pos[:, 1], c=diff_phase, cmap='jet', vmin=_vmin, vmax=_vmax)
    plt.colorbar(label='Phase [deg]')
    plt.xlabel('X [m]')
    plt.ylabel('Y [m]')
    plt.title(f"Block Averaged {start} to {stop}: Difference from ground truth Phase XX")
    plt.show()

    # Residual statistics over antennas for this block.
    bias = np.mean(diff_phase)
    stddev = np.std(diff_phase)
    variance_per_block.append(stddev**2)
    print(f"Bias: {bias} deg")
    print(f"Sigma: {stddev} deg")

    # Histogram with the mean (solid) and mean +/- one sigma (dashed) marked.
    plt.hist(diff_phase, bins='auto')
    plt.gca().axvline(bias, c='red', ls='solid')
    plt.gca().axvline(bias + stddev, c='red', ls='dashed')
    plt.gca().axvline(bias - stddev, c='red', ls='dashed')
    plt.xlabel(r'$\Delta$ Phase [deg]')
    plt.title(f"Block Averaged {start} to {stop}: Histogram of phase errors")
    plt.show()

# Mean residual variance across blocks, in deg^2.
noise_variance = np.mean(variance_per_block)

In [None]:
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel

def smooth_values(x, y, z, f, kernel_length_scale, noise_variance):
    """Smooth per-point values with a Gaussian-process fit over 3-D positions.

    Args:
        x, y, z: Coordinate arrays, one entry per point; they are stacked
            column-wise into the regressor's input.
        f: Values observed at those points.
        kernel_length_scale: Length scale handed to the RBF kernel.
        noise_variance: Noise level of the white-noise kernel; its bounds are
            set to 'fixed'.

    Returns:
        The regressor's prediction evaluated back at the input coordinates.
    """
    # RBF plus white noise; the RBF term carries no amplitude factor.
    covariance = RBF(length_scale=kernel_length_scale) + WhiteKernel(
        noise_level=noise_variance, noise_level_bounds='fixed'
    )

    points = np.column_stack((x, y, z))
    regressor = GaussianProcessRegressor(kernel=covariance).fit(points, f)
    return regressor.predict(points)



# Repeat the block-averaged comparison with the solved phase smoothed over the
# array by a Gaussian process. ``blocksize``, ``vmin``, ``vmax`` and
# ``noise_variance`` are reused from the earlier cells.
for t in range(phase.shape[0]):

    phase_smooth = smooth_values(ant_pos[:,0], ant_pos[:,1], ant_pos[:,2], f=phase[t, 0, :, 0, 0],
                                 kernel_length_scale=1000, noise_variance=noise_variance)
    # Re-reference the smoothed phase to the antenna at index 0.
    phase_smooth -= phase_smooth[0]

    plt.scatter(ant_pos[:, 0], ant_pos[:, 1], c=phase_smooth,
                vmin=vmin, vmax=vmax,
                cmap='jet')

    plt.colorbar(label='Phase [deg]')
    plt.xlabel('X [m]')
    plt.ylabel('Y [m]')
    plt.title(f"Timstep {t}: Solved Phase XX (GP smoothed)")

    plt.show()

    start = t * blocksize
    stop = (t+1) * blocksize

    # Ground truth averaged over the simulated time steps of this block.
    block_averaged = np.mean(sim_phase[dir_idx, :, 0, start:stop], axis=-1)

    plt.scatter(ant_pos[:, 0], ant_pos[:, 1], c=block_averaged, cmap='jet', vmin=vmin, vmax=vmax)
    plt.colorbar(label='Phase [deg]')
    plt.xlabel('X [m]')
    plt.ylabel('Y [m]')
    plt.title(f"Block Averaged {start} to {stop}: Simulated Phase XX")
    plt.show()

    # Phases are only defined modulo 360 deg, so fold the residual back into
    # [-180, 180); otherwise values on opposite sides of the +/-180 deg branch
    # cut show up as errors of nearly a full turn.
    diff_phase = (block_averaged - phase_smooth + 180.) % 360. - 180.
    _vmin, _vmax = np.percentile(diff_phase, [5, 95])
    plt.scatter(ant_pos[:, 0], ant_pos[:, 1], c=diff_phase, cmap='jet', vmin=_vmin, vmax=_vmax)
    plt.colorbar(label='Phase [deg]')
    plt.xlabel('X [m]')
    plt.ylabel('Y [m]')
    plt.title(f"Block Averaged {start} to {stop}: Difference from ground truth Phase XX")
    plt.show()

    # Residual statistics over antennas for this block.
    bias = np.mean(diff_phase)
    stddev = np.std(diff_phase)
    print(f"Bias: {bias} deg")
    print(f"Sigma: {stddev} deg")

    # Histogram with the mean (solid) and mean +/- one sigma (dashed) marked.
    plt.hist(diff_phase, bins='auto')
    plt.gca().axvline(bias, c='red', ls='solid')
    plt.gca().axvline(bias + stddev, c='red', ls='dashed')
    plt.gca().axvline(bias - stddev, c='red', ls='dashed')
    plt.xlabel(r'$\Delta$ Phase [deg]')
    plt.title(f"Block Averaged {start} to {stop}: Histogram of phase errors")
    plt.show()