In [None]:
import logging
import sys

import numpy

import matplotlib.pyplot as plt

import astropy.constants as constants
from rascil.data_models import PolarisationFrame
from rascil.processing_components import create_blockvisibility_from_ms, \
    create_image_from_visibility, qa_image, export_image_to_fits, \
    flagging_blockvisibility, plot_configuration, plot_uvcoverage, \
    create_visibility_from_rows, weight_visibility, export_blockvisibility_to_ms, \
    create_calibration_controls, gaintable_plot

from rascil.workflows import calibrate_list_rsexecute_workflow
from rascil.processing_components.imaging.ng import invert_ng, predict_ng

from rascil.workflows.rsexecute.execution_support import rsexecute

# Notebook-wide logger writing DEBUG-and-above messages to stdout.
log = logging.getLogger("logger")
log.setLevel(logging.DEBUG)
# getLogger returns the same object on every call, so re-running this cell would
# otherwise attach one more handler each time and duplicate every log line.
if not log.handlers:
    log.addHandler(logging.StreamHandler(sys.stdout))

In [None]:
def flag(bv, uvmin=100.0,
         antenna_groups=((0, 8, 9, 10, 11, 13, 27), (18, 20, 33)),
         channels=(0, 1, 16, 17, 32, 33, 48, 49, 12, 28, 44, 60, 46),
         max_amplitude=6e4):
    """Apply the observation-specific flags to a BlockVisibility and return it.

    Flags (sets to 1) visibilities that:
      * lie on baselines with sqrt(u**2 + v**2) below ``uvmin``;
      * involve the antennas in ``antenna_groups`` (via ``flagging_blockvisibility``);
      * fall in any channel listed in ``channels`` (all other axes);
      * have amplitude ``abs(vis)`` above ``max_amplitude``.

    The defaults reproduce the original hand-chosen values. The uv-distance step
    writes into ``bv``'s flag array in place; always use the returned object,
    since ``flagging_blockvisibility`` may return a different one.

    :param bv: BlockVisibility to flag
    :param uvmin: uv-distance cut, in the units of ``bv.u``/``bv.v``
        (presumably metres — confirm); ``None`` skips this step
    :param antenna_groups: iterable of antenna-index groups, each passed in turn
        to ``flagging_blockvisibility``. Defaults: dishes that do not see the
        transient, then dishes that look anomalous
    :param channels: channel indices to flag entirely. Defaults: obviously bad
        channels (0, 1, 16, 17, 32, 33, 48, 49) plus channels with strange
        visibilities (12, 28, 44, 60, 46)
    :param max_amplitude: amplitude threshold; ``None`` skips this step
    :return: the flagged BlockVisibility
    """
    if uvmin is not None:
        # Short spacings are flagged by masking directly (original note: the
        # only way to do it currently).
        uvdist = numpy.sqrt(bv.u ** 2 + bv.v ** 2)
        bv.data['flags'][uvdist < uvmin] = 1

    for ants in antenna_groups:
        bv = flagging_blockvisibility(bv, antenna=list(ants))

    channels = list(channels)
    if channels:
        # Integer-array index on the channel axis (second to last) flags every
        # listed channel in one assignment.
        bv.data['flags'][..., channels, :] = 1

    if max_amplitude is not None:
        # Flag implausibly high points
        bad = numpy.abs(bv.vis) > max_amplitude
        bv.flags[bad] = 1

    return bv

def plot_grid(dirty, name, refchan=52, ncols=8):
    """Plot every channel of a Stokes I/V image cube as a grid of panels.

    One figure is drawn per polarisation (I, then V). Panels are filled row by
    row in channel order and titled ``Chan <n>``. All panels share one colour
    scale: ``vmax`` is the peak of Stokes I in ``refchan`` and
    ``vmin = -0.3 * vmax``. Each figure is saved to
    ``"<name>_Stokes_<pol>.png"``, shown without blocking, then closed.

    :param dirty: image whose ``data`` is indexed [chan, pol, ...]; the first two
        polarisations are treated as I and V
    :param name: prefix for the figure titles and PNG file names
    :param refchan: channel whose Stokes I peak sets the colour scale
    :param ncols: panels per grid row; with the default of 8, a 64-channel cube
        gives the original 8x8 layout
    """
    polnames = ['I', 'V']
    nchan = dirty.data.shape[0]
    nrows = max(1, -(-nchan // ncols))  # ceiling division
    vmax = numpy.max(dirty.data[refchan, 0])
    vmin = -0.3 * vmax
    for pol, polname in enumerate(polnames):
        fig, axes = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False,
                                 figsize=(8 * ncols, 8 * nrows))
        # axes.flat walks the grid row by row, which matches channel order
        for chan, ax in enumerate(axes.flat):
            if chan >= nchan:
                # More panels than channels: hide the spare ones
                ax.set_visible(False)
                continue
            ax.imshow(dirty.data[chan, pol], aspect='equal',
                      vmax=vmax, vmin=vmin, cmap='rainbow')
            ax.set_title('Chan {}'.format(chan), fontsize=32)
            ax.invert_yaxis()
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)

        fig.suptitle("{} Stokes {}".format(name, polname), fontsize=64)
        fig.savefig("{}_Stokes_{}.png".format(name, polname))
        plt.show(block=False)
        # Close explicitly so repeated calls do not accumulate open figures
        plt.close(fig)

In [None]:
# Imaging parameters shared by the later cells.
# NOTE(review): cellsize is presumably in radians (RASCIL convention; ~4.1 arcsec) — confirm.
cellsize = 2e-05
npixel = 1024
# NOTE(review): imaging_context is not referenced in the visible cells; invert_ng/predict_ng are called directly.
imaging_context = 'ng'
# Images carry two polarisations: Stokes I and V
image_stokes = PolarisationFrame("stokesIV")

# Load the measurement set; only the first BlockVisibility of the returned sequence is used
entire_bvis = create_blockvisibility_from_ms("selfmodel.ms")[0]

In [None]:
# Plot the antenna layout of the array
plt.clf()
plot_configuration([entire_bvis], title='MUSER configuration')
plt.show(block=False)

In [None]:
# Print one tab-separated line per antenna: index, name, position, dish diameter
configuration = entire_bvis.configuration
names = configuration.names
diameters = configuration.diameter
for ant, antxyz in enumerate(configuration.xyz):
    line = "{} \t{} \t{}m \t{}m".format(ant, names[ant], antxyz, diameters[ant])
    print(line)

In [None]:
# Plot the uv coverage of the observation
plt.clf()
plot_uvcoverage([entire_bvis], title='MUSER uvcoverage')
plt.show(block=False)

In [None]:
# Apply the flags defined in flag() and write the flagged data to a separate measurement set
entire_bvis = flag(entire_bvis)
export_blockvisibility_to_ms("selfmodel_flagged.ms", [entire_bvis])

In [None]:
# Number of time samples and frequency channels in the full observation
ntimes = len(entire_bvis.time)
nchan = entire_bvis.nchan

# Empty Stokes I/V template image with one plane per visibility channel
# (rather than a hard-coded 64); it is filled as the sky model when self-calibrating.
model = create_image_from_visibility(entire_bvis, cellsize=cellsize, npixel=npixel,
                                     nchan=nchan, polarisation_frame=image_stokes)

print("Size of BlockVis = {:.1f} GB".format(entire_bvis.size()))

In [None]:
selfcal = False


def read_sun_disk_data(sundisk_file_path="quietsundisk400_2000MHz.txt", nsamples=200):
    """Read quiet-Sun disk brightness profiles from a text file.

    The file alternates a key line (a number, stored as its integer string,
    e.g. ``"400"``) with a value line of comma-separated numbers whose first
    field is skipped. The first ``nsamples`` numbers of each value line become
    a float32 profile. Reading stops at the first missing key/value pair.

    :param sundisk_file_path: path to the profile file
    :param nsamples: number of samples kept per profile
    :return: dict mapping key string (presumably frequency in MHz, matching the
        lookup in fill_solar_model_jy) to a float32 array of length ``nsamples``
    :raises ValueError: if a value line holds fewer than ``nsamples`` numbers
    """
    sundisk = {}
    with open(sundisk_file_path, "r") as sundisk_file:
        while True:
            linekey = sundisk_file.readline()
            linevalue = sundisk_file.readline()
            if not (linekey and linevalue):
                return sundisk
            key = str(int(float(linekey.strip())))
            # float() rather than eval(): the file should only hold numbers, and
            # eval would execute any expression found in it.
            values = [float(v) for v in linevalue.strip().split(',')[1:]]
            if len(values) < nsamples:
                raise ValueError("Profile {} has {} samples, expected at least {}".format(
                    key, len(values), nsamples))
            sundisk[key] = numpy.array(values[:nsamples], dtype=numpy.float32)


def fill_solar_model_jy(m, sundisk=None, sun_radius=16):
    """Fill a Stokes I/V image with a quiet-Sun disk model converted to Jy.

    For every channel, pixels within ``sun_radius`` arcminutes of the image
    centre get Stokes I from that channel's disk profile, Stokes V is set to
    zero, and the channel is scaled from brightness temperature to Jy per pixel.
    ``m`` is modified in place and also returned.

    :param m: image whose ``data`` is indexed [chan, pol, y, x] with pol 0 = I
        and pol 1 = V; ``m.frequency`` is presumably in Hz (keys are looked up
        as rounded ``frequency * 1e-6``)
    :param sundisk: profile table from read_sun_disk_data; defaults to the
        module-level ``sundisk_dict``
    :param sun_radius: disk radius in arcminutes (disk diameter ~32 arcmin)
    :return: ``m``
    """
    if sundisk is None:
        sundisk = sundisk_dict
    size = m.data.shape[2]
    # cdelt is in degrees; convert the pixel size to radians
    cellsize_rad = m.wcs.wcs.cdelt[1] * numpy.pi / 180.0
    # Field of view in arcminutes
    fov = cellsize_rad * 180 / numpy.pi * size * 60
    # Radius (arcmin) of every pixel from the centre, computed once for all channels
    offsets = numpy.arange(size) - size // 2
    radius = numpy.sqrt(offsets[:, numpy.newaxis] ** 2 + offsets[numpy.newaxis, :] ** 2) * fov / size
    on_disk = radius <= sun_radius
    profile_index = numpy.round(radius[on_disk] / 32 * 200).astype(int)
    for chan in range(m.data.shape[0]):
        frequency = str(int(round(m.frequency[chan] * 1e-6)))
        assert frequency in sundisk, "Key {} not present".format(frequency)
        profile = sundisk[frequency]
        m.data[chan, 0][on_disk] = profile[profile_index]
        # The second polarisation is V which should be zero (or much less than I)
        m.data[chan, 1, ...] = 0.0
        # Convert this channel (only) from brightness temperature to Jy
        wavelength = constants.c.value / m.frequency[chan]
        t_to_jy = 1e26 * 2 * constants.k_B.value * cellsize_rad ** 2 / wavelength ** 2
        m.data[chan] *= t_to_jy
    return m


if selfcal:
    sundisk_dict = read_sun_disk_data()
    solar_model = fill_solar_model_jy(model, sundisk_dict)

    # Model visibilities predicted from the solar disk image
    model_bvis = predict_ng(entire_bvis, solar_model, do_wstacking=False, verbosity=2)

    # Solve for a bandpass (B) term against the model, one solution per t_sol seconds
    controls = create_calibration_controls()
    controls['B']['first_selfcal'] = 0
    controls['B']['phase_only'] = False
    t_sol = 1.0
    controls['B']['timeslice'] = t_sol
    rsexecute.set_client(use_dask=False)
    cal_graph = calibrate_list_rsexecute_workflow([entire_bvis],
                                                  [model_bvis],
                                                  gt_list=None,
                                                  calibration_context='B',
                                                  controls=controls,
                                                  global_solution=False)
    cal_bvis_list, gt_list = rsexecute.compute(cal_graph, sync=True)

    plt.clf()
    title = "Bandpass_{:.2f}s_{:.3f}GHz".format(t_sol, cal_bvis_list[0].frequency[0] / 1e9)
    gaintable_plot(gt_list[0]['B'], 'B', title=title)
    plt.savefig("{}.png".format(title))
    plt.show(block=False)

    # Continue with the calibrated visibilities
    entire_bvis = cal_bvis_list[0]


In [None]:
# Time-sample window (indices along the time axis) presumably containing the
# transient: "on" selects samples strictly between the two bounds.
on_after, on_before = 2224, 2947
allvis = numpy.arange(ntimes)
on = (allvis > on_after) & (allvis < on_before)
# "off" is exactly the complement of "on", so the two can never drift apart
off = ~on

In [None]:
# Image the "on" and "off" time ranges separately. For each, print QA
# statistics, write a FITS cube, and save per-Stokes channel-grid PNGs.
for mode_name, rows in [("on", on), ("off", off)]:
    output_name = "muser_movie_channels_baselines_flag_{}".format(mode_name)
    bvis = create_visibility_from_rows(entire_bvis, rows)
    # Template image with one plane per visibility channel (not a hard-coded 64)
    model = create_image_from_visibility(bvis, cellsize=cellsize, npixel=npixel,
                                         nchan=bvis.nchan, polarisation_frame=image_stokes)
    bvis = weight_visibility(bvis, model, weighting="robust")
    dirty, sumwt = invert_ng(bvis, model, do_wstacking=False, threads=8, verbosity=2)
    print(qa_image(dirty, context=mode_name))
    export_image_to_fits(dirty, "{}.fits".format(output_name))
    plot_grid(dirty, output_name)

