In [None]:

import os
from typing import List, Tuple

import astropy.coordinates as ac
import astropy.time as at
import matplotlib.pyplot as plt

from dsa2000_assets.array_constraints.v6.array_constraint import ArrayConstraintsV6
from dsa2000_fm.array_layout.pareto_front_search import Results
from dsa2000_fm.array_layout.sample_constraints import RegionSampler


def plot_solution(plot_folder: str, antennas: ac.EarthLocation, obstime: at.Time, array_location: ac.EarthLocation,
                  aoi_data: List[Tuple[RegionSampler, float]],
                  constraint_data: List[Tuple[RegionSampler, float]],
                  xlim: Tuple[float, float] = (-114.6, -114.3),
                  ylim: Tuple[float, float] = (39.45, 39.70),
                  filename: str = 'array_solution.png'):
    """Plot an antenna layout over the area-of-interest and constraint regions.

    The figure is saved to ``os.path.join(plot_folder, filename)`` and then shown.

    Args:
        plot_folder: Directory the figure is written to. It is created if missing.
            An empty string writes to the current working directory.
        antennas: Antenna positions, scattered by geodetic longitude/latitude in degrees.
        obstime: Not used by this function; kept for interface compatibility.
        array_location: Not used by this function; kept for interface compatibility.
        aoi_data: (sampler, buffer) pairs for areas of interest, drawn with color='blue'.
            The buffers are not used by this plot.
        constraint_data: (sampler, buffer) pairs for constraint regions, drawn with color='none'.
            The buffers are not used by this plot.
        xlim: Longitude axis limits in degrees (defaults to the original hard-coded view).
        ylim: Latitude axis limits in degrees (defaults to the original hard-coded view).
        filename: File name of the saved image inside ``plot_folder``.
    """
    fig, ax = plt.subplots(1, 1, figsize=(6, 6))

    # Region outlines first, so the antenna scatter is drawn on top of them.
    for sampler, _ in aoi_data:
        sampler.plot_region(ax=ax, color='blue')
    for sampler, _ in constraint_data:
        sampler.plot_region(ax=ax, color='none')

    ax.scatter(antennas.geodetic.lon.deg, antennas.geodetic.lat.deg, s=1, c='green', alpha=0.5, marker='.')
    ax.set_xlabel('Longitude [deg]')
    ax.set_ylabel('Latitude [deg]')
    ax.set_title('Antenna layout')
    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)

    # savefig does not create missing directories, and the caller's folder may not exist yet.
    # Guard the empty string: os.makedirs('') raises, but savefig to a bare filename is valid.
    if plot_folder:
        os.makedirs(plot_folder, exist_ok=True)
    fig.savefig(os.path.join(plot_folder, filename))
    plt.show()

    # # Plot violations
    # for idx, point in enumerate(antennas):
    #     for sampler, buffer in constraint_data:
    #         (px, py), dist = sampler.closest_approach(point.geodetic.lon.deg, point.geodetic.lat.deg)
    #         earth_radius = np.linalg.norm(point.get_itrs().cartesian.xyz.to(au.m).value)
    #         dist = np.pi / 180 * dist * earth_radius

    #         if dist < buffer:
    #             print('Agree')
    #             sampler.info()
    #             fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    #             sampler.plot_region(ax=ax, color='none')
    #             ax.scatter(px, py, c='g')
    #             ax.scatter(point.geodetic.lon.deg, point.geodetic.lat.deg, c='b')
    #             bbox = min(point.geodetic.lon.deg, px), max(point.geodetic.lon.deg, px), min(point.geodetic.lat.deg,
    #                                                                                          py), max(
    #                 point.geodetic.lat.deg, py)
    #             ax.set_xlim(bbox[0] - 0.005, bbox[1] + 0.005)
    #             ax.set_ylim(bbox[2] - 0.005, bbox[3] + 0.005)
    #             ax.set_title(f"{dist} {buffer}")
    #             plt.show()


# Load the optimisation results produced for array-constraint extension 'b'.
results_file = "pareto_opt_v6_b/results.json"
results = Results.parse_file(results_file)

# Build the region data for the same extension. All area-of-interest regions
# are merged into one sampler, which carries the largest individual buffer.
array_constraint = ArrayConstraintsV6(extension='b')
aoi_data = array_constraint.get_area_of_interest_regions()
merged_aoi_sampler = RegionSampler.merge([aoi_sampler for aoi_sampler, _ in aoi_data])
merged_buffer = max(aoi_buffer for _, aoi_buffer in aoi_data)
aoi_data = [(merged_aoi_sampler, merged_buffer)]
constraint_data = array_constraint.get_constraint_regions()

# The lowest-cost evaluated layout is the one that gets plotted.
evaluation = min(results.evaluations, key=lambda candidate: candidate.cost)

plot_solution(plot_folder='./pareto_opt_target/',
              antennas=evaluation.antennas,
              obstime=results.obstime,
              array_location=results.array_location,
              aoi_data=aoi_data,
              constraint_data=constraint_data)

# with open('solution.txt', 'w') as f:
#     f.write('#X,Y,Z\n')
#     for antenna in evaluation.antennas:
#         f.write(f"{antenna.x.to('m').value},{antenna.y.to('m').value},{antenna.z.to('m').value}\n")


In [None]:
import os
import pylab as plt

from dsa2000_fm.array_layout.pareto_front_search import Results

# Compare how PSF quality evolves over the optimisation for each array-constraint extension.
# Uses an explicit Figure/Axes so the curves cannot land on a stale "current" figure.
# NOTE(review): there is no 'g' entry in the prefix list — confirm this is intentional.
fig, ax = plt.subplots(1, 1)
for prefix in ['a', 'b', 'c', 'd', 'e', 'f', 'h', 'full']:
    results_file = f"pareto_opt_v6_{prefix}/results.json"
    results = Results.parse_file(results_file)

    quality = [r.quality for r in results.evaluations]

    ax.plot(quality, alpha=0.5, label=prefix)
# NOTE: after the loop, `results` stays bound to the last prefix ('full').
ax.legend()
ax.set_xlabel('Iteration')
ax.set_ylabel('PSF quality')
ax.set_title('PSF quality per iteration by constraint extension')
# Fixed view window; quality values outside [-7, 0] are not visible.
ax.set_ylim(-7, 0)
plt.show()


In [None]:

from dsa2000_fm.array_layout.pareto_front_search import _get_pareto_eqs
import numpy as np
from dsa2000_common.common.plot_utils import figs_to_gif
import itertools
from tqdm import tqdm
from scipy.spatial import ConvexHull


def gen_figs(evaluations):
    """Yield one frame per optimisation step showing the Pareto front as it grows.

    For every ``idx`` from 3 to ``len(evaluations) - 1``:

    * a 2-D convex hull is built from the (cost, quality) pairs of ``evaluations[:idx]``;
    * all of those points are scattered, coloured by their index;
    * the vertices returned by ``_get_pareto_eqs`` are highlighted as the previous Pareto front;
    * ``evaluations[idx]`` is marked as the new point.

    Args:
        evaluations: Sequence of objects exposing ``cost`` and ``quality`` attributes.

    Yields:
        matplotlib Figure objects. Each one is released from pyplot with ``plt.close``
        before being yielded, so figures do not accumulate while the generator runs.
    """
    first_idx = 3  # the initial hull is built from the first 3 evaluations
    with tqdm(total=max(len(evaluations) - first_idx, 0)) as pbar:
        for idx in range(first_idx, len(evaluations)):
            # Qhull needs at least 3 non-collinear points for a 2-D hull.
            hull = ConvexHull(points=np.asarray([[e.cost, e.quality] for e in evaluations[:idx]]))
            # Only the front's vertex indices are used for this plot.
            _, _, _, _, vertex_idxs = _get_pareto_eqs(hull)

            new_evaluation = evaluations[idx]
            new_point = np.asarray([new_evaluation.cost, new_evaluation.quality])

            fig, ax = plt.subplots(1, 1, figsize=(10, 10))
            sc = ax.scatter(hull.points[:, 0], hull.points[:, 1], c=np.arange(len(hull.points)), cmap='jet')
            fig.colorbar(sc, ax=ax, label='Iteration')
            ax.scatter(hull.points[vertex_idxs, 0], hull.points[vertex_idxs, 1], c='black', marker='*',
                       label='Previous Pareto front')
            ax.scatter(new_point[0], new_point[1], c='blue', marker='x',
                       label='New point')
            ax.set_xlabel('Cost')
            ax.set_ylabel('Quality')
            ax.legend()
            plt.close(fig)
            pbar.update(1)
            yield fig


# Load the 'full' run explicitly. Previously this relied on `results` leaking out of the
# PSF-quality comparison loop (whose last prefix is 'full'), which breaks when cells run
# out of order.
from dsa2000_fm.array_layout.pareto_front_search import Results

full_results = Results.parse_file("pareto_opt_v6_full/results.json")
figs_to_gif(gen_figs(full_results.evaluations),
            gif_path='pareto_optimisation.gif', duration=10, loop=0, dpi=100)