In [1]:
# import gzip
# import shutil
# from astropy.io import fits

# # Input and output file paths
# input_gz_file = '/u/ywagh/gwemopt_tests/006.gz'
# output_fits_file = '/u/ywagh/gwemopt_tests/006.fits'

# # Decompress the .gz file
# with gzip.open(input_gz_file, 'rb') as f_in:
#     with open(output_fits_file, 'wb') as f_out:
#         shutil.copyfileobj(f_in, f_out)

# # Open the FITS file to ensure it's correctly decompressed
# with fits.open(output_fits_file) as hdul:
#     print(hdul.info())


In [2]:
import astroplan
import regions
from astropy.coordinates import ICRS, SkyCoord, AltAz, get_moon, EarthLocation, get_body
from astropy import units as u
from astropy.utils.data import download_file
from astropy.table import Table, QTable, join
from astropy.time import Time, TimeDelta
from astropy_healpix import *
from ligo.skymap import plot
from ligo.skymap.io import read_sky_map
import healpy as hp
import os
from matplotlib import pyplot as plt
import numpy as np
from tqdm.auto import tqdm
import datetime as dt
import pickle
import pandas as pd
from docplex.mp.model import Model

import warnings
warnings.filterwarnings("ignore", "Wswiglal-redir-stdio")
warnings.simplefilter('ignore', astroplan.TargetNeverUpWarning)
warnings.simplefilter('ignore', astroplan.TargetAlwaysUpWarning)

# directory_path = "/u/ywagh/test_skymaps/S240422ed.fits"
# skymap, metadata = read_sky_map(os.path.join(directory_path))

directory_path = "/u/ywagh/test_skymaps/"
filelist = sorted([f for f in os.listdir(directory_path) if f.endswith('.gz')])

slew_speed = 2.5 * u.deg / u.s
slew_accel = 0.4 * u.deg / u.s**2
readout = 8.2 * u.s

ns_nchips = 4
ew_nchips = 4
ns_npix = 6144
ew_npix = 6160
plate_scale = 1.01 * u.arcsec
ns_chip_gap = 0.205 * u.deg
ew_chip_gap = 0.140 * u.deg

ns_total = ns_nchips * ns_npix * plate_scale + (ns_nchips - 1) * ns_chip_gap
ew_total = ew_nchips * ew_npix * plate_scale + (ew_nchips - 1) * ew_chip_gap

rcid = np.arange(64)

chipid, rc_in_chip_id = np.divmod(rcid, 4)
ns_chip_index, ew_chip_index = np.divmod(chipid, ew_nchips)
ns_rc_in_chip_index = np.where(rc_in_chip_id <= 1, 1, 0)
ew_rc_in_chip_index = np.where((rc_in_chip_id == 0) | (rc_in_chip_id == 3), 0, 1)

ew_offsets = ew_chip_gap * (ew_chip_index - (ew_nchips - 1) / 2) + ew_npix * plate_scale * (ew_chip_index - ew_nchips / 2) + 0.5 * ew_rc_in_chip_index * plate_scale * ew_npix
ns_offsets = ns_chip_gap * (ns_chip_index - (ns_nchips - 1) / 2) + ns_npix * plate_scale * (ns_chip_index - ns_nchips / 2) + 0.5 * ns_rc_in_chip_index * plate_scale * ns_npix

ew_ccd_corners = 0.5 * plate_scale * np.asarray([ew_npix, 0, 0, ew_npix])
ns_ccd_corners = 0.5 * plate_scale * np.asarray([ns_npix, ns_npix, 0, 0])

ew_vertices = ew_offsets[:, np.newaxis] + ew_ccd_corners[np.newaxis, :]
ns_vertices = ns_offsets[:, np.newaxis] + ns_ccd_corners[np.newaxis, :]

def get_footprint(center):
    return SkyCoord(
        ew_vertices, ns_vertices,
        frame=center[..., np.newaxis, np.newaxis].skyoffset_frame()
    ).icrs

url = 'https://github.com/ZwickyTransientFacility/ztf_information/raw/master/field_grid/ZTF_Fields.txt'
filename = download_file(url)
field_grid = QTable(np.recfromtxt(filename, comments='%', usecols=range(3), names=['field_id', 'ra', 'dec']))
field_grid['coord'] = SkyCoord(field_grid.columns.pop('ra') * u.deg, field_grid.columns.pop('dec') * u.deg)
field_grid = field_grid[0:881]

#******************************************************************************
skymap, metadata = read_sky_map(os.path.join(directory_path, filelist[5]))

plot_filename = os.path.basename(filelist[5])
print(plot_filename)
# plot_filename = 'S240422ed'
# ci
#******************************************************************************

event_time = Time(metadata['gps_time'], format='gps').utc
gps_time = Time(metadata['gps_time'], format='gps')

event_time.format = 'iso'
print('event time:',event_time)
observer = astroplan.Observer.at_site('Palomar')
night_horizon = -18 * u.deg
if observer.is_night(event_time, horizon=night_horizon):
    start_time = event_time
else:
    start_time = observer.sun_set_time(
        event_time, horizon=night_horizon, which='next')

# Find the latest possible end time of observations: the time of sunrise.
end_time = observer.sun_rise_time(
    start_time, horizon=night_horizon, which='next')

min_airmass = 2.5 * u.dimensionless_unscaled
airmass_horizon = (90 * u.deg - np.arccos(1 / min_airmass))
targets = field_grid['coord']

# Find the time that each field rises and sets above an airmass of 2.5.
target_start_time = Time(np.where(
    observer.target_is_up(start_time, targets, horizon=airmass_horizon),
    start_time,
    observer.target_rise_time(start_time, targets, which='next', horizon=airmass_horizon)))
target_start_time.format = 'iso'

# Find the time that each field sets below the airmass limit. If the target
# is always up (i.e., it's circumpolar) or if it sets after surnsise,
# then set the end time to sunrise.
target_end_time = observer.target_set_time(
    target_start_time, targets, which='next', horizon=airmass_horizon)
target_end_time[
    (target_end_time.mask & ~target_start_time.mask) | (target_end_time > end_time)
] = end_time
target_end_time.format = 'iso'
# Select fields that are observable for long enough for at least one exposure
##############################################################################
exposure_time = 180 * u.second
exposure_time_day = exposure_time.to_value(u.day)

num_visits = 3
num_filters = 1

cadence = 30         #minutes
cadence_days = cadence / (60 * 24)
##############################################################################
field_grid['start_time'] = target_start_time
field_grid['end_time'] = target_end_time
observable_fields = field_grid[target_end_time - target_start_time >= exposure_time]

# print(observable_fields)
hpx = HEALPix(nside=256, frame=ICRS())

footprint = np.moveaxis(
    get_footprint(SkyCoord(0 * u.deg, 0 * u.deg)).cartesian.xyz.value, 0, -1)
footprint_healpix = np.unique(np.concatenate(
    [hp.query_polygon(hpx.nside, v, nest=(hpx.order == 'nested')) for v in footprint]))

'''
# computing the footprints of every ZTF field as HEALPix indices. Downsampling skymap to same resolution.
'''
footprints = np.moveaxis(get_footprint(observable_fields['coord']).cartesian.xyz.value, 0, -1)
footprints_healpix = [
    np.unique(np.concatenate([hp.query_polygon(hpx.nside, v) for v in footprint]))
    for footprint in tqdm(footprints)]

prob = hp.ud_grade(skymap, hpx.nside, power=-2)

# k = max number of 300s exposures 
min_start = min(observable_fields['start_time'])
max_end =max(observable_fields['end_time'])
# min_start.format = 'jd'
# max_end.format = 'jd'

# k=30

k = int(np.floor((max_end - min_start)/(exposure_time.to(u.day))))
k = np.floor(k/(num_visits*num_filters))
print(k," number of exposures could be taken tonight")

print("problem setup completed")


SWIGLAL standard output/error redirection is enabled in IPython.
This may lead to performance penalties. To disable locally, use:

with lal.no_swig_redirect_standard_output_error():
    ...

To disable globally, use:

lal.swig_redirect_standard_output_error(False)

Note however that this will likely lead to error messages from
LAL functions being either misdirected or lost when called from
Jupyter notebooks.


import lal

  import lal


006.gz
event time: 2024-05-14 08:03:21.048


  result = super().__array_ufunc__(function, method, *arrays, **kwargs)


  0%|          | 0/351 [00:00<?, ?it/s]

20.0  number of exposures could be taken tonight
problem setup completed


In [3]:
# print(start_time.isot)
# print(end_time.iso)
# print(end_time-start_time)
# start = start_time.mjd
# f = (end_time-start_time).value
# f

In [4]:
m1 = Model('max coverage problem')

field_vars = m1.binary_var_list(len(footprints), name='field')
pixel_vars = m1.binary_var_list(hpx.npix, name='pixel')

footprints_healpix_inverse = [[] for _ in range(hpx.npix)]

for field, pixels in enumerate(footprints_healpix):
    for pixel in pixels:
        footprints_healpix_inverse[pixel].append(field)

for i_pixel, i_fields in enumerate(footprints_healpix_inverse):
     m1.add_constraint(m1.sum(field_vars[i] for i in i_fields) >= pixel_vars[i_pixel])

m1.add_constraint(m1.sum(field_vars) <= k)
m1.maximize(m1.dot(pixel_vars, prob))
print(f"number fo fields observed should be less than {k}")

solution = m1.solve(log_output=True)

print("optimization completed")
total_prob_covered = solution.objective_value

print("Total probability covered:",total_prob_covered)

selected_fields_ID = [i for i, v in enumerate(field_vars) if v.solution_value == 1]
print(len(selected_fields_ID), "fields selected")
selected_fields = observable_fields[selected_fields_ID]
# print(selected_fields)

separation_matrix = selected_fields['coord'][:,np.newaxis].separation(selected_fields['coord'][np.newaxis,:])

def slew_time(separation):
   return np.where(separation <= (slew_speed**2 / slew_accel),
                   np.sqrt(2 * separation / slew_accel),
                   (2 * slew_speed / slew_accel) + (separation - slew_speed**2 / slew_accel) / slew_speed)

slew_times = slew_time(separation_matrix).value

slew_time_value = slew_times*u.second
slew_time_day = slew_time_value.to_value(u.day)

footprints_selected = np.moveaxis(get_footprint(selected_fields['coord']).cartesian.xyz.value, 0, -1)
footprints_healpix_selected = [
    np.unique(np.concatenate([hp.query_polygon(hpx.nside, v) for v in footprint]))
    for footprint in tqdm(footprints_selected)]

probabilities = []

for field_index in range(len(footprints_healpix_selected)):
    probability_field = np.sum(prob[footprints_healpix_selected[field_index]])
    probabilities.append(probability_field)
print("worked for",len(probabilities),"fields")

selected_fields['probabilities'] = probabilities

number fo fields observed should be less than 20.0
Version identifier: 22.1.1.0 | 2022-11-28 | 9160aff4d
CPXPARAM_Read_DataCheck                          1
Found incumbent of value 0.000000 after 0.03 sec. (24.20 ticks)
Tried aggregator 2 times.
MIP Presolve eliminated 784344 rows and 514583 columns.
Aggregator did 351 substitutions.
Reduced MIP has 1738 rows, 2088 columns, and 6504 nonzeros.
Reduced MIP has 2088 binaries, 0 generals, 0 SOSs, and 0 indicators.
Presolve time = 0.69 sec. (593.80 ticks)
Probing time = 0.00 sec. (0.44 ticks)
Tried aggregator 1 time.
Detecting symmetries...
Reduced MIP has 1738 rows, 2088 columns, and 6504 nonzeros.
Reduced MIP has 2088 binaries, 0 generals, 0 SOSs, and 0 indicators.
Presolve time = 0.01 sec. (4.02 ticks)
Probing time = 0.00 sec. (0.44 ticks)
MIP emphasis: balance optimality and feasibility.
MIP search method: dynamic search.
Parallel mode: deterministic, using up to 32 threads.
Root relaxation solution time = 0.01 sec. (5.49 ticks)

      

  0%|          | 0/20 [00:00<?, ?it/s]

worked for 20 fields


In [16]:
from docplex.mp.model import Model
import numpy as np
from astropy import units as u
from astropy.time import Time, TimeDelta

def create_scheduling_model(observable_fields, slew_times, num_visits=3, cadence_minutes=30,
                          exposure_time=180*u.second, prob_weight=0.9):
    """
    Creates and solves a MILP model for scheduling field observations.
    Times are normalized relative to the earliest start time.
    
    Parameters:
    -----------
    observable_fields : astropy.table.Table
        Table containing field information including IDs, coordinates, start/end times
    slew_times : numpy.ndarray
        Matrix of slew times between fields
    num_visits : int
        Maximum number of visits per field
    cadence_minutes : int
        Minimum time between revisits in minutes
    exposure_time : astropy.units.Quantity
        Duration of each exposure
    prob_weight : float
        Weight given to probability maximization (vs time minimization)
    """
    
    # Convert other time quantities to days
    exposure_time_days = exposure_time.to(u.day).value
    cadence_days = cadence_minutes / (24 * 60)
    
    # Create model
    m = Model('field_scheduling')
    
    num_fields = len(selected_fields)
    cad_list = np.ones(num_fields)*cadence_days
    # Define variables
    # Binary variable x[i,v]: 1 if field i is observed for visit v
    x = m.binary_var_matrix(num_fields, num_visits, name='field_observation')
    
    # Continuous variable t[i,v]: start time of observation for field i, visit v (in days from reference)
    # t = m.continuous_var_matrix(num_fields, num_visits, lb=0,ub = reference_time, name='start_time')
    
    tc = [[m.continuous_var(
    lb=(row['start_time'] - start_time).to_value(u.day),
    ub=(row['end_time'] - start_time - exposure_time).to_value(u.day),
    name=f"start_time_field_{i}_visit_{v}")
    for v in range(num_visits*num_filters)] 
    for i, row in enumerate(selected_fields)]
    
    # Constraints
    for i in range(num_fields):
        field_start = tc[i]
        field_end = field_start+cad_list
        
        for v in range(num_visits):
            # Time window constraints
            m.add_constraint(tc[i,v] >= field_start * x[i,v])
            m.add_constraint(tc[i,v] <= field_end * x[i,v] + (1 - x[i,v]) * field_end)
            
            # Ensure visits happen in order
            if v > 0:
                m.add_constraint(tc[i,v] >= tc[i,v-1] + cadence_days * x[i,v])
                # If a visit v is scheduled, all previous visits must be scheduled
                m.add_constraint(x[i,v] <= x[i,v-1])
    
    # Non-overlap constraints between all fields
    for i in range(num_fields):
        for j in range(i+1, num_fields):
            for v1 in range(num_visits):
                for v2 in range(num_visits):
                    slew_days = slew_times[i,j] / (24 * 3600)  # Convert to days
                    m.add_indicator_constraint(
                        x[i,v1],
                        tc[j,v2] >= tc[i,v1] + exposure_time_days + slew_days - 
                        (1 - x[j,v2]) * max(normalized_end_times)
                    )
                    m.add_indicator_constraint(
                        x[j,v2],
                        tc[i,v1] >= tc[j,v2] + exposure_time_days + slew_days - 
                        (1 - x[i,v1]) * max(normalized_end_times)
                    )
    
    # Objective function
    total_prob = m.sum(x[i,v] * selected_fields['probabilities'][i] 
                      for i in range(num_fields) 
                      for v in range(num_visits))
    
    total_time = m.sum(t[i,v] for i in range(num_fields) for v in range(num_visits))
    
    # Maximize probability while minimizing total time
    m.maximize(prob_weight * total_prob - (1 - prob_weight) * total_time)
    
    # Solve the model
    solution = m.solve(log_output=True)
    
    if solution is None:
        raise ValueError("No feasible solution found")
    
    # Extract results
    field_matrix = np.zeros((num_fields, num_visits))
    time_matrix = np.zeros((num_fields, 2 * num_visits))  # Start and end times for each visit
    
    for i in range(num_fields):
        for v in range(num_visits):
            field_matrix[i,v] = x[i,v].solution_value
            if x[i,v].solution_value > 0.5:  # Account for floating point errors
                time_matrix[i,2*v] = tc[i,v].solution_value  # Start time (days from reference)
                time_matrix[i,2*v+1] = tc[i,v].solution_value + exposure_time_days  # End time
    
    return field_matrix, time_matrix, solution.objective_value, reference_time

# Helper function to convert normalized times to readable format if needed
def format_time_matrix(time_matrix, reference_time):
    """
    Convert normalized days to readable times
    """
    formatted_times = []
    for row in time_matrix:
        row_times = []
        for tc in row:
            if tc > 0:  # Only convert non-zero times
                absolute_time = Time(reference_time + tc, format='jd')
                row_times.append(absolute_time.iso)
            else:
                row_times.append('')
        formatted_times.append(row_times)
    return formatted_times

In [17]:
field_matrix, time_matrix, objective_value = create_scheduling_model(
    selected_fields,
    slew_time_day,
    num_visits=3,
    cadence_minutes=30
)

# Convert times to readable format
formatted_times = format_time_matrix(time_matrix, event_time)

ValueError: operands could not be broadcast together with shapes (3,) (20,) 

In [None]:
# # Instead of filtering out fields, add an "availability factor" to each field
# # Compute availability
# availability = [
#     (row['end_time'] - row['start_time']).to_value(u.day) / (end_time - start_time).to_value(u.day)
#     for row in selected_fields
# ]

# # Add 'availability' to selected_fields
# selected_fields['availability'] = availability

# # Create selected_fil_fileds with all required columns
# selected_fil_fileds = QTable(rows=selected_fields, names=selected_fields.colnames)

# # Verify column names
# print("Columns in selected_fil_fileds:", selected_fil_fileds.colnames)

# # Convert availability to a NumPy array for safe indexing
# availabilities = np.array(selected_fil_fileds['availability'])

# # Apply constraints using the availability array
# for i in range(len(selected_fil_fileds)):
#     min_obs = max(1, int(np.ceil(3 * availabilities[i])))
#     m2.add_constraint(
#         m2.sum(x[i][v] for v in range(num_visits*num_filters)) >= min_obs,
#         ctname=f"min_obs_field_{i}"
#     )


# m2 = Model("Telescope Scheduling")
# footprints_selected = np.moveaxis(get_footprint(selected_fil_fileds['coord']).cartesian.xyz.value, 0, -1)
# footprints_healpix_selected = [
#     np.unique(np.concatenate([hp.query_polygon(hpx.nside, v) for v in footprint]))
#     for footprint in tqdm(footprints_selected)]

# # Compute probabilities for each field
# probabilities = []
# for field_index in range(len(footprints_healpix_selected)):
#     probability_field = np.sum(prob[footprints_healpix_selected[field_index]])
#     probabilities.append(probability_field)

# # Ensure 'probability' column is properly assigned to selected_fields
# selected_fields['probability'] = probabilities

# # Create QTable with probabilities included
# selected_fil_fileds = QTable(rows=selected_fields, names=selected_fields.colnames)

# # Convert probabilities to NumPy array before using in optimization
# probabilities = np.array(selected_fil_fileds['probability'])



# # selected_fields['probabilities'] = probabilities
# # Define observation duration and total observation window
# delta = exposure_time.to_value(u.day)  # Observation duration per field
# M = (selected_fil_fileds['end_time'].max() - selected_fil_fileds['start_time'].min()).to_value(u.day).item()

# # Decision variables: Binary variables for each field and visit
# x = [[m2.binary_var(name=f"x_{i}_visit_{v}") 
#       for v in range(num_visits*num_filters)] 
#       for i in range(len(selected_fil_fileds))]

# # Continuous variables for start time of each visit
# tc = [[m2.continuous_var(
#     lb=(row['start_time'] - start_time).to_value(u.day),
#     ub=(row['end_time'] - start_time - exposure_time).to_value(u.day),
#     name=f"start_time_field_{i}_visit_{v}")
#     for v in range(num_visits*num_filters)] 
#     for i, row in enumerate(selected_fil_fileds)]

# # Variables for enforcing non-overlapping observations
# visit_transition_times = [m2.continuous_var(
#     lb=0, ub=M, name=f"visit_transition_{v}")
#     for v in range(num_visits*num_filters - 1)]

# # **1. Enforce Minimum Observations Proportional to Availability**
# for i, row in enumerate(selected_fil_fileds):
#     # min_obs = max(1, round(3 * row['availability']))  # Ensure at least 1 observation
#     min_obs = max(1, int(np.ceil(3 * row['availability'])))  # Use ceil instead of round

#     m2.add_constraint(
#         m2.sum(x[i][v] for v in range(num_visits*num_filters)) >= min_obs,
#         ctname=f"min_obs_field_{i}"
#     )

# # **2. Enforce Minimum Time Gap Between Observations (30 min)**
# cadence_days = (30 * u.minute).to_value(u.day)  # 30 minutes in days

# for i in range(len(selected_fil_fileds)):
#     for v in range(1, num_visits*num_filters):
#         m2.add_constraint(
#             tc[i][v] - tc[i][v-1] >= (cadence_days + delta) * (x[i][v] + x[i][v-1] - 1),
#             ctname=f"cadence_constraint_field_{i}_visits_{v}"
#         )

# for v in range(num_visits * num_filters):
#     for i in range(len(selected_fil_fileds)):
#         for j in range(i):
#             # Ensure i finishes before j starts OR j finishes before i starts
#             m2.add_constraint(
#                 tc[i][v] + delta * x[i][v] + slew_time_day[i][j] <= tc[j][v] + M * (2 - x[i][v] - x[j][v]),
#                 ctname=f"non_overlapping_field_{i}_{j}_visit_{v}"
#             )
#             m2.add_constraint(
#                 tc[j][v] + delta * x[j][v] + slew_time_day[i][j] <= tc[i][v] + M * (2 - x[i][v] - x[j][v]),
#                 ctname=f"non_overlapping_field_{j}_{i}_visit_{v}"
#             )


# m2.add_constraint(
#     m2.sum(x[i][v] for v in range(num_visits*num_filters)) <= 3,
#     ctname=f"max_obs_field_{i}"
# )
# # **4. Ensure Observations Stay Within Availability Windows**
# for i, row in enumerate(selected_fil_fileds):
#     for v in range(num_visits*num_filters):
#         m2.add_constraint(
#             tc[i][v] >= (row['start_time'] - start_time).to_value(u.day) * x[i][v],
#             ctname=f"start_time_restrict_field_{i}_visit_{v}"
#         )
#         m2.add_constraint(
#             tc[i][v] <= (row['end_time'] - start_time).to_value(u.day) * x[i][v],
#             ctname=f"end_time_restrict_field_{i}_visit_{v}"
#         )

# # **5. Maximize Probability-Weighted Observations**
# probabilities = np.array([row['probability'] for row in selected_fil_fileds])
# m2.maximize(
#     m2.sum(probabilities[i] * x[i][v]
#            for i in range(len(selected_fil_fileds))
#            for v in range(num_visits*num_filters))
# )

# # **6. Solver Parameters**
# m2.parameters.mip.tolerances.mipgap = 0.01  # 1% optimality gap
# m2.parameters.emphasis.mip = 2  # Focus on optimality over feasibility

# # **7. Solve the Model**
# solution = m2.solve(log_output=True)


## finding coverage

In [5]:
# for var in field_vars:
#     print(solution.get_value(var))

In [None]:
# # Instead of filtering out fields, add an "availability factor" to each field
# # Compute availability
# availability = [
#     (row['end_time'] - row['start_time']).to_value(u.day) / (end_time - start_time).to_value(u.day)
#     for row in selected_fields
# ]

# # Add 'availability' to selected_fields
# selected_fields['availability'] = availability

# # Create selected_fil_fileds with all required columns
# selected_fil_fileds = QTable(rows=selected_fields, names=selected_fields.colnames)

# # Verify column names
# print("Columns in selected_fil_fileds:", selected_fil_fileds.colnames)

# # Convert availability to a NumPy array for safe indexing
# availabilities = np.array(selected_fil_fileds['availability'])

# # Apply constraints using the availability array
# for i in range(len(selected_fil_fileds)):
#     min_obs = max(1, int(np.ceil(3 * availabilities[i])))
#     m2.add_constraint(
#         m2.sum(x[i][v] for v in range(num_visits*num_filters)) >= min_obs,
#         ctname=f"min_obs_field_{i}"
#     )


# m2 = Model("Telescope Scheduling")
# footprints_selected = np.moveaxis(get_footprint(selected_fil_fileds['coord']).cartesian.xyz.value, 0, -1)
# footprints_healpix_selected = [
#     np.unique(np.concatenate([hp.query_polygon(hpx.nside, v) for v in footprint]))
#     for footprint in tqdm(footprints_selected)]

# # Compute probabilities for each field
# probabilities = []
# for field_index in range(len(footprints_healpix_selected)):
#     probability_field = np.sum(prob[footprints_healpix_selected[field_index]])
#     probabilities.append(probability_field)

# # Ensure 'probability' column is properly assigned to selected_fields
# selected_fields['probability'] = probabilities

# # Create QTable with probabilities included
# selected_fil_fileds = QTable(rows=selected_fields, names=selected_fields.colnames)

# # Convert probabilities to NumPy array before using in optimization
# probabilities = np.array(selected_fil_fileds['probability'])



# # selected_fields['probabilities'] = probabilities
# # Define observation duration and total observation window
# delta = exposure_time.to_value(u.day)  # Observation duration per field
# M = (selected_fil_fileds['end_time'].max() - selected_fil_fileds['start_time'].min()).to_value(u.day).item()

# # Decision variables: Binary variables for each field and visit
# x = [[m2.binary_var(name=f"x_{i}_visit_{v}") 
#       for v in range(num_visits*num_filters)] 
#       for i in range(len(selected_fil_fileds))]

# # Continuous variables for start time of each visit
# tc = [[m2.continuous_var(
#     lb=(row['start_time'] - start_time).to_value(u.day),
#     ub=(row['end_time'] - start_time - exposure_time).to_value(u.day),
#     name=f"start_time_field_{i}_visit_{v}")
#     for v in range(num_visits*num_filters)] 
#     for i, row in enumerate(selected_fil_fileds)]

# # Variables for enforcing non-overlapping observations
# visit_transition_times = [m2.continuous_var(
#     lb=0, ub=M, name=f"visit_transition_{v}")
#     for v in range(num_visits*num_filters - 1)]

# # **1. Enforce Minimum Observations Proportional to Availability**
# for i, row in enumerate(selected_fil_fileds):
#     # min_obs = max(1, round(3 * row['availability']))  # Ensure at least 1 observation
#     min_obs = max(1, int(np.ceil(3 * row['availability'])))  # Use ceil instead of round

#     m2.add_constraint(
#         m2.sum(x[i][v] for v in range(num_visits*num_filters)) >= min_obs,
#         ctname=f"min_obs_field_{i}"
#     )

# # **2. Enforce Minimum Time Gap Between Observations (30 min)**
# cadence_days = (30 * u.minute).to_value(u.day)  # 30 minutes in days

# for i in range(len(selected_fil_fileds)):
#     for v in range(1, num_visits*num_filters):
#         m2.add_constraint(
#             tc[i][v] - tc[i][v-1] >= (cadence_days + delta) * (x[i][v] + x[i][v-1] - 1),
#             ctname=f"cadence_constraint_field_{i}_visits_{v}"
#         )

# for v in range(num_visits * num_filters):
#     for i in range(len(selected_fil_fileds)):
#         for j in range(i):
#             # Ensure i finishes before j starts OR j finishes before i starts
#             m2.add_constraint(
#                 tc[i][v] + delta * x[i][v] + slew_time_day[i][j] <= tc[j][v] + M * (2 - x[i][v] - x[j][v]),
#                 ctname=f"non_overlapping_field_{i}_{j}_visit_{v}"
#             )
#             m2.add_constraint(
#                 tc[j][v] + delta * x[j][v] + slew_time_day[i][j] <= tc[i][v] + M * (2 - x[i][v] - x[j][v]),
#                 ctname=f"non_overlapping_field_{j}_{i}_visit_{v}"
#             )


# m2.add_constraint(
#     m2.sum(x[i][v] for v in range(num_visits*num_filters)) <= 3,
#     ctname=f"max_obs_field_{i}"
# )
# # **4. Ensure Observations Stay Within Availability Windows**
# for i, row in enumerate(selected_fil_fileds):
#     for v in range(num_visits*num_filters):
#         m2.add_constraint(
#             tc[i][v] >= (row['start_time'] - start_time).to_value(u.day) * x[i][v],
#             ctname=f"start_time_restrict_field_{i}_visit_{v}"
#         )
#         m2.add_constraint(
#             tc[i][v] <= (row['end_time'] - start_time).to_value(u.day) * x[i][v],
#             ctname=f"end_time_restrict_field_{i}_visit_{v}"
#         )

# # **5. Maximize Probability-Weighted Observations**
# probabilities = np.array([row['probability'] for row in selected_fil_fileds])
# m2.maximize(
#     m2.sum(probabilities[i] * x[i][v]
#            for i in range(len(selected_fil_fileds))
#            for v in range(num_visits*num_filters))
# )

# # **6. Solver Parameters**
# m2.parameters.mip.tolerances.mipgap = 0.01  # 1% optimality gap
# m2.parameters.emphasis.mip = 2  # Focus on optimality over feasibility

# # **7. Solve the Model**
# solution = m2.solve(log_output=True)


In [None]:
# First model (m1) remains the same as it correctly selects fields covering LIGO localization

# Modify the availability calculation to be more granular
availability = [
    (row['end_time'] - row['start_time']).to_value(u.day) / (end_time - start_time).to_value(u.day)
    for row in selected_fields
]
# Instead of filtering out fields, add an "availability factor" to each field

# Add 'availability' to selected_fields
selected_fields['availability'] = availability

# Create selected_fil_fileds with all required columns
selected_fil_fileds = QTable(rows=selected_fields, names=selected_fields.colnames)

# Verify column names
print("Columns in selected_fil_fileds:", selected_fil_fileds.colnames)

# Convert availability to a NumPy array for safe indexing
availabilities = np.array(selected_fil_fileds['availability'])

# Apply constraints using the availability array



footprints_selected = np.moveaxis(get_footprint(selected_fil_fileds['coord']).cartesian.xyz.value, 0, -1)
footprints_healpix_selected = [
    np.unique(np.concatenate([hp.query_polygon(hpx.nside, v) for v in footprint]))
    for footprint in tqdm(footprints_selected)]

# Compute probabilities for each field
probabilities = []
for field_index in range(len(footprints_healpix_selected)):
    probability_field = np.sum(prob[footprints_healpix_selected[field_index]])
    probabilities.append(probability_field)

# Ensure 'probability' column is properly assigned to selected_fields
selected_fields['probability'] = probabilities

# Create QTable with probabilities included
selected_fil_fileds = QTable(rows=selected_fields, names=selected_fields.colnames)

# Convert probabilities to NumPy array before using in optimization
probabilities = np.array(selected_fil_fileds['probability'])
# Create the second optimization model
m2 = Model("Flexible Telescope Scheduling")

# Constants
delta = exposure_time.to_value(u.day)
M = (observable_fields['end_time'].max() - observable_fields['start_time'].min()).to_value(u.day).item()
cadence_days = (30 * u.minute).to_value(u.day)

# Decision variables for all fields, not just fully available ones
x = [[m2.binary_var(name=f"x_{i}_visit_{v}") 
      for v in range(num_visits*num_filters)] 
      for i in range(len(selected_fields))]

# Continuous variables for observation start times
tc = [[m2.continuous_var(
    lb=(row['start_time'] - start_time).to_value(u.day),
    ub=(row['end_time'] - start_time - exposure_time).to_value(u.day),
    name=f"start_time_field_{i}_visit_{v}")
    for v in range(num_visits*num_filters)] 
    for i, row in enumerate(selected_fields)]

# Modified constraints based on availability
for i, row in enumerate(selected_fields):
    availability_factor = availability[i]
    
    # Dynamic minimum and maximum observations based on availability
    if availability_factor >= 0.8:
        min_obs = 1  # Changed from 3 to 1
        max_obs = 3
    elif availability_factor >= 0.5:
        min_obs = 1  # Changed from 2 to 1
        max_obs = 2
    else:
        min_obs = 1
        max_obs = 1
    
    # Add minimum observation constraint
    m2.add_constraint(
        m2.sum(x[i][v] for v in range(num_visits*num_filters)) >= min_obs,
        ctname=f"min_obs_field_{i}"
    )
    
    # Add maximum observation constraint
    m2.add_constraint(
        m2.sum(x[i][v] for v in range(num_visits*num_filters)) <= max_obs,
        ctname=f"max_obs_field_{i}"
    )

    # Add time window constraints
    valid_time_windows = []
    for v in range(num_visits*num_filters):
        # Only add time constraints if the field is actually observed
        m2.add_constraint(
            tc[i][v] >= (row['start_time'] - start_time).to_value(u.day) * x[i][v],
            ctname=f"start_time_restrict_field_{i}_visit_{v}"
        )
        m2.add_constraint(
            tc[i][v] <= (row['end_time'] - start_time - exposure_time).to_value(u.day) * x[i][v],
            ctname=f"end_time_restrict_field_{i}_visit_{v}"
        )

# Maintain minimum time gap between observations (30 min)
for i in range(len(selected_fields)):
    for v in range(1, num_visits*num_filters):
        m2.add_constraint(
            tc[i][v] - tc[i][v-1] >= (cadence_days + delta) * (x[i][v] + x[i][v-1] - 1),
            ctname=f"cadence_constraint_field_{i}_visits_{v}"
        )

# Non-overlapping observations constraint
for v in range(num_visits * num_filters):
    for i in range(len(selected_fields)):
        for j in range(i):
            m2.add_constraint(
                tc[i][v] + delta * x[i][v] + slew_time_day[i][j] <= tc[j][v] + M * (2 - x[i][v] - x[j][v]),
                ctname=f"non_overlapping_field_{i}_{j}_visit_{v}"
            )
            m2.add_constraint(
                tc[j][v] + delta * x[j][v] + slew_time_day[i][j] <= tc[i][v] + M * (2 - x[i][v] - x[j][v]),
                ctname=f"non_overlapping_field_{j}_{i}_visit_{v}"
            )

# Modified objective function incorporating availability and probability
probabilities = np.array([row['probability'] for row in selected_fields])
availability_weights = np.array(availability)

m2.maximize(
    m2.sum(probabilities[i] * x[i][v]
           for i in range(len(selected_fields))
           for v in range(num_visits*num_filters))
)

# Solver parameters for better convergence
# m2.parameters.mip.tolerances.mipgap = 0.05  # 5% optimality gap
# m2.parameters.timelimit = 600  # 10 minute time limit
# m2.parameters.emphasis.mip = 3  # Balance between feasibility and optimality
# m2.parameters.mip.strategy.variableselect = 3  # Strong branching

# Solve the model
solution = m2.solve(log_output=True)

In [19]:
schedule_matrix = np.zeros((len(selected_fields), num_visits*num_filters))
for i in range(len(selected_fields)):
    for v in range(num_visits*num_filters):
        if solution.get_value(x[i][v]) > 0.0000:  # Using 0.5 as threshold for binary variables
            schedule_matrix[i,v] = 1

In [None]:
schedule_matrix

In [None]:
m2 = Model("Telescope timings")
footprints_selected = np.moveaxis(get_footprint(selected_fields['coord']).cartesian.xyz.value, 0, -1)
footprints_healpix_selected = [
    np.unique(np.concatenate([hp.query_polygon(hpx.nside, v) for v in footprint]))
    for footprint in tqdm(footprints_selected)]

probabilities = []

for field_index in range(len(footprints_healpix_selected)):
    probability_field = np.sum(prob[footprints_healpix_selected[field_index]])
    probabilities.append(probability_field)
print("worked for",len(probabilities),"fields")

selected_fields['probabilities'] = probabilities
# Calculate available observation windows
available_duration = (selected_fields['end_time'] - selected_fields['start_time']).to_value(u.day)
min_required_duration = exposure_time.to_value(u.day) * num_visits
field_feasibility = available_duration >= min_required_duration

# Keep all fields but mark their feasibility
selected_fields['is_feasible'] = field_feasibility

delta = exposure_time.to_value(u.day)
M = (selected_fields['end_time'].max() - selected_fields['start_time'].min()).to_value(u.day).item()

# Decision variables
x = [[m2.binary_var(name=f"x_{i}_visit_{v}") 
      for v in range(num_visits)] 
      for i in range(len(selected_fields))]

tc = [[m2.continuous_var(
    lb=(row['start_time'] - start_time).to_value(u.day),
    ub=(row['end_time'] - start_time - exposure_time).to_value(u.day),
    name=f"start_time_field_{i}_visit_{v}")
    for v in range(num_visits)] 
    for i, row in enumerate(selected_fields)]

# Variables to track if a field completes all visits
all_visits_completed = [m2.binary_var(name=f"complete_{i}") 
                       for i in range(len(selected_fields))]

# Constraint: Fields must be scheduled within their available windows
for i in range(len(selected_fields)):
    for v in range(num_visits):
        m2.add_constraint(
            tc[i][v] + delta * x[i][v] <= 
            (selected_fields['end_time'][i] - start_time).to_value(u.day),
            ctname=f"end_window_{i}_visit_{v}"
        )
        m2.add_constraint(
            tc[i][v] >= 
            (selected_fields['start_time'][i] - start_time).to_value(u.day) * x[i][v],
            ctname=f"start_window_{i}_visit_{v}"
        )

# Constraint: Minimum time between visits for the same field
for i in range(len(selected_fields)):
    for v in range(1, num_visits):
        m2.add_constraint(
            tc[i][v] - tc[i][v-1] >= cadence_days * (x[i][v] + x[i][v-1] - 1),
            ctname=f"min_separation_{i}_visit_{v}"
        )

# Non-overlapping constraint between fields within each visit
for v in range(num_visits):
    for i in range(len(selected_fields)):
        for j in range(i):
            # Binary variable for ordering within this visit
            zij_v = m2.binary_var(name=f"z_{i}_{j}_visit_{v}")
            
            m2.add_constraint(
                tc[i][v] + delta * x[i][v] + slew_time_day[i][j] - tc[j][v] <= 
                M * (1 - zij_v + (2 - x[i][v] - x[j][v])),
                ctname=f"sequence1_{i}_{j}_visit_{v}"
            )
            m2.add_constraint(
                tc[j][v] + delta * x[j][v] + slew_time_day[i][j] - tc[i][v] <= 
                M * (zij_v + (2 - x[i][v] - x[j][v])),
                ctname=f"sequence2_{i}_{j}_visit_{v}"
            )

# Constraint: Link completion variable to visits
for i in range(len(selected_fields)):
    visit_sum = m2.sum(x[i][v] for v in range(num_visits))
    m2.add_constraint(
        visit_sum >= num_visits * all_visits_completed[i],
        ctname=f"completion_link1_{i}"
    )
    m2.add_constraint(
        visit_sum <= num_visits - 1 + all_visits_completed[i],
        ctname=f"completion_link2_{i}"
    )

# Constraint: Each field can only be scheduled if it's feasible
for i in range(len(selected_fields)):
    if not field_feasibility[i]:
        for v in range(num_visits):
            m2.add_constraint(x[i][v] == 0, ctname=f"infeasible_{i}_visit_{v}")

# Linear objective function
m2.maximize(
    m2.sum(
        probabilities[i] * (
            m2.sum(x[i][v] for v in range(num_visits)) +  # Base reward for each visit
            0.2 * all_visits_completed[i]  # Bonus for completing all visits
        )
        for i in range(len(selected_fields))
    )
)

# Solver parameters
m2.parameters.mip.tolerances.mipgap = 0.05
m2.parameters.timelimit = 300
m2.parameters.emphasis.mip = 3
m2.parameters.mip.strategy.variableselect = 4
m2.parameters.mip.pool.intensity = 2

# Solve the model
solution = m2.solve(log_output=True)

# Process results
if solution:
    scheduled_fields = []
    visit_schedules = []
    for i in range(len(selected_fields)):
        field_visits = []
        for v in range(num_visits):
            if solution.get_value(x[i][v]) > 0.5:
                field_visits.append({
                    'visit': v,
                    'time': solution.get_value(tc[i][v])
                })
        if field_visits:
            scheduled_fields.append(i)
            visit_schedules.append(field_visits)
            
    print(f"Successfully scheduled {len(scheduled_fields)} fields")
    print(f"Total objective value: {solution.objective_value:.4f}")
else:
    print("No solution found within the time limit")

In [None]:
a= np.mean(slew_times)
a

In [None]:
plt.figure(figsize=(10, 8))
#off-center case
# ax = plt.axes(projection='astro mollweide', center='0h 60d')
ax = plt.axes(projection='astro mollweide', center='0h 0d')

for row in selected_fields:
    coords = SkyCoord(
        [ew_total, -ew_total, -ew_total, ew_total],
        [ns_total, ns_total, -ns_total, -ns_total],
        frame=row['coord'].skyoffset_frame()
    ).icrs
    ax.add_patch(plt.Polygon(
        np.column_stack((coords.ra.deg, coords.dec.deg)),
        alpha=0.5,
        facecolor='lightgray',
        edgecolor='black',
        transform=ax.get_transform('world')
    ))
# plot_filename = os.path.basename(skymap_file)
plot_filename = 'S240910ci'
ax.grid()
ax.imshow_hpx(prob, cmap='cylon')
plt.text(0.05, 0.95, f'Total Probability Covered: {total_prob_covered:.2f}', transform=ax.transAxes,
        fontsize=12, verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.5))

In [None]:
from astropy import units as u

exp_time_list = [60, 120, 180, 240]
for i in range(len(exp_time_list)):
    exposure_time = exp_time_list[i] * u.second
    print(exposure_time)

In [None]:
end_time.jd

In [None]:
delta = exposure_time.to_value(u.day)

limit_duration = ((end_time-start_time).value*2/3) 
filtered_rows = [
    row for row in selected_fields
    if (row['end_time'] - row['start_time']).to_value(u.day) > limit_duration
]

# Create a new QTable with the filtered rows
selected_fil_fileds = QTable(rows=filtered_rows, names=selected_fields.colnames)

selected_fil_fileds

In [None]:
selected_fil_fileds

In [12]:
# # Verify time windows for each field
# for i, row in enumerate(selected_fields):
#     print(f"Field {i}:")
#     print(f"Start time: {row['start_time']}")
#     print(f"End time: {row['end_time']}")
#     print(f"Window duration: {row['end_time'] - row['start_time']}")
#     delta = exposure_time.to_value(u.day)

#     if (row['end_time'] - row['start_time']).to_value(u.day) < (num_visits * num_filters * delta):
#         print(f"Warning: Time window might be too short for field {i}")

In [None]:
m2 = Model("Telescope timings")

observer_location = EarthLocation.of_site('Palomar')

footprints_selected = np.moveaxis(get_footprint(selected_fil_fileds['coord']).cartesian.xyz.value, 0, -1)
footprints_healpix_selected = [
    np.unique(np.concatenate([hp.query_polygon(hpx.nside, v) for v in footprint]))
    for footprint in tqdm(footprints_selected)]

probabilities = []

for field_index in range(len(footprints_healpix_selected)):
    probability_field = np.sum(prob[footprints_healpix_selected[field_index]])
    probabilities.append(probability_field)
print("worked for",len(probabilities),"fields")

selected_fil_fileds['probabilities'] = probabilities

delta = exposure_time.to_value(u.day)
M = (selected_fil_fileds['end_time'].max() - selected_fil_fileds['start_time'].min()).to_value(u.day).item()

x = [[m2.binary_var(name=f"x_{i}_visit_{v}") 
      for v in range(num_visits*num_filters)] 
      for i in range(len(selected_fil_fileds))]

tc = [[m2.continuous_var(
    lb=(row['start_time'] - start_time).to_value(u.day),
    ub=(row['end_time'] - start_time - exposure_time).to_value(u.day),
    name=f"start_time_field_{i}_visit_{v}")
    for v in range(num_visits*num_filters)] 
    for i, row in enumerate(selected_fil_fileds)]

visit_transition_times = [m2.continuous_var(
    lb=0,ub=M,name=f"visit_transition_{v}")
                          for v in range(num_visits*num_filters-1)]  

# Isolating visits
for v in range(1, num_visits*num_filters):
    for i in range(len(selected_fil_fileds)):
        m2.add_constraint(tc[i][v-1] + delta * x[i][v-1] <= visit_transition_times[v-1],
            ctname=f"visit_end_{i}_visit_{v-1}")
        m2.add_constraint(tc[i][v] >= visit_transition_times[v-1],
            ctname=f"visit_start_{i}_visit_{v}")


# Cadence constraints
for i in range(len(selected_fil_fileds)):
    for v in range(1, num_visits*num_filters):
        m2.add_constraint(tc[i][v] - tc[i][v-1] >= (cadence_days+delta) * (x[i][v] + x[i][v-1] - 1),
            ctname=f"cadence_constraint_field_{i}_visits_{v}")

#non-overlapping
for v in range(num_visits*num_filters):
    for i in range(len(selected_fil_fileds)):
        for j in range(i):
            m2.add_constraint(tc[i][v] + delta * x[i][v] + slew_time_day[i][j] - tc[j][v] <= M * (2 - x[i][v] - x[j][v]),
                              ctname=f"non_overlapping_cross_fields_{i}_{j}_visits_{v}")
            m2.add_constraint(tc[j][v] + delta * x[j][v] + slew_time_day[i][j] - tc[i][v] <= M * (-1 + x[i][v] + x[j][v]),
                ctname=f"non_overlapping_cross_fields_{j}_{i}_visits_{v}")

# Initialize the objective
m2.maximize(
    m2.sum(probabilities[i] * x[i][v]
           for i in range(len(selected_fil_fileds))
           for v in range(num_visits*num_filters))
)

# m2.parameters.timelimit = 60
m2.parameters.mip.tolerances.mipgap = 0.01  # 1% optimality gap
m2.parameters.emphasis.mip = 2  # Emphasize optimality over feasibility
solution = m2.solve(log_output=True)

In [None]:
for v in range(num_visits * num_filters):
    visit_fields = [i for i in range(len(selected_fil_fileds)) if solution.get_value(x[i][v]) == 1]
    print(visit_fields)

In [None]:
solution.objective_value/(num_visits * num_filters)

In [21]:

scheduled_fields_by_visit = []
for v in range(num_visits * num_filters):
    visit_fields = [i for i in range(len(selected_fil_fileds)) if solution.get_value(x[i][v]) == 1]
    scheduled_fields_by_visit.append(visit_fields)

scheduled_fields = selected_fil_fileds.copy()

scheduled_tc = []
for v in range(num_visits * num_filters):
    visit_times = []
    for i in range(len(selected_fil_fileds)):
        if i in scheduled_fields_by_visit[v]:
            visit_times.append(solution.get_value(tc[i][v]))
        else:
            visit_times.append(np.nan) 
    scheduled_tc.append(visit_times)

scheduled_tc = np.array(scheduled_tc).T  

for i in range(num_visits * num_filters):
    scheduled_fields[f"Scheduled_start_filt_times_{i}"] = scheduled_tc[:, i]

for v in range(num_visits * num_filters):
    scheduled_fields[f"Selected_in_visit_{v}"] = [1 if i in scheduled_fields_by_visit[v] else 0 
                                                 for i in range(len(scheduled_fields))]

In [None]:
scheduled_tc


In [None]:
n_visits = num_visits * num_filters  

fig, axes = plt.subplots(n_visits, 1, figsize=(8, 3 * n_visits), sharex=True)

for i in range(n_visits):
    start_col = f'Scheduled_start_filt_times_{i}'
    end_col = f'Scheduled_end_filt_times_{i}'
    
    # Filter valid rows
    valid_rows = ~np.isnan(scheduled_tc).any(axis=1)
    valid_scheduled_tc = scheduled_tc[valid_rows]
    valid_scheduled_fields = scheduled_fields[valid_rows]
    
    if len(valid_scheduled_fields) == 0:
        print(f"No entries found for Visit {i + 1}. Skipping plot.")
        continue  # Skip this visit
    
    # Convert start times to ISO format and compute end times
    valid_scheduled_fields[start_col] = Time(valid_scheduled_fields[start_col], format='mjd')
    valid_scheduled_fields[start_col].format = 'iso'
    valid_scheduled_fields[end_col] = valid_scheduled_fields[start_col] + exposure_time_day
    
    # Sort fields by end time
    valid_scheduled_fields.sort(end_col)
    
    # Get the start and end times for plotting
    first_start_time = valid_scheduled_fields[start_col].mjd[0]
    last_end_time = valid_scheduled_fields[end_col].mjd[-1]
    
    ax = axes[i]
    ax.hlines(
        np.arange(len(valid_scheduled_fields)),
        valid_scheduled_fields[start_col].mjd,
        valid_scheduled_fields[end_col].mjd,
        colors='blue',
        linewidth=2
    )
    # Plot small vertical lines at start and end times of each interval
    for j in range(len(valid_scheduled_fields)):
        ax.vlines(
            valid_scheduled_fields[start_col][j].mjd,
            ymin=j - 0.2,
            ymax=j + 0.2,
            color='black',
            linewidth=0.5,
            linestyle='-'
        )
        ax.vlines(
            valid_scheduled_fields[end_col][j].mjd,
            ymin=j - 0.2,
            ymax=j + 0.2,
            color='black',
            linewidth=0.5,
            linestyle='-'
        )
    
    # Highlight first start and last end times
    if i == 0:  # Add legend only for the first subplot
        ax.axvline(first_start_time, color='red', linestyle='--', linewidth=1.5, label='Start of First Field')
        ax.axvline(last_end_time, color='green', linestyle='--', linewidth=1.5, label='End of Last Field')
        ax.legend(loc='upper right')
    else:
        ax.axvline(first_start_time, color='red', linestyle='--', linewidth=1.5)
        ax.axvline(last_end_time, color='green', linestyle='--', linewidth=1.5)
    
    # Add labels and title
    ax.set_yticks(np.arange(len(valid_scheduled_fields)))
    ax.set_yticklabels(valid_scheduled_fields['field_id'].astype(str))
    ax.set_ylabel('Field ID')
    ax.set_title(f'Observation Schedule for Visit {i + 1}')
    
axes[-1].set_xlabel('Observation time (MJD)')

plt.tight_layout()
# save_path = '/u/ywagh/scheduler_results/plots_manuscript'
# os.makedirs(save_path, exist_ok=True)  # Ensure directory exists
# full_path = os.path.join(save_path, f'schedule_revisit_.png')
# plt.savefig(full_path, dpi=300, bbox_inches='tight')
plt.show()


## model two

In [None]:
m2 = Model("Telescope timings")

observer_location = EarthLocation.of_site('Palomar')

footprints_selected = np.moveaxis(get_footprint(selected_fields['coord']).cartesian.xyz.value, 0, -1)
footprints_healpix_selected = [
    np.unique(np.concatenate([hp.query_polygon(hpx.nside, v) for v in footprint]))
    for footprint in tqdm(footprints_selected)]

probabilities = []

for field_index in range(len(footprints_healpix_selected)):
    probability_field = np.sum(prob[footprints_healpix_selected[field_index]])
    probabilities.append(probability_field)
print("worked for",len(probabilities),"fields")

selected_fields['probabilities'] = probabilities

delta = exposure_time.to_value(u.day)
M = (selected_fields['end_time'].max() - selected_fields['start_time'].min()).to_value(u.day).item()

x = [[m2.binary_var(name=f"x_{i}_visit_{v}") 
      for v in range(num_visits*num_filters)] 
      for i in range(len(selected_fields))]

tc = [[m2.continuous_var(
    lb=(row['start_time'] - start_time).to_value(u.day),
    ub=(row['end_time'] - start_time - exposure_time).to_value(u.day),
    name=f"start_time_field_{i}_visit_{v}")
    for v in range(num_visits*num_filters)] 
    for i, row in enumerate(selected_fields)]

# Cadence constraints
for i in range(len(selected_fields)):
    for v in range(1, num_visits*num_filters):
        m2.add_constraint(tc[i][v] - tc[i][v-1] >= cadence_days * (x[i][v] + x[i][v-1] - 1),
            ctname=f"cadence_constraint_field_{i}_visits_{v}")

#non-overlapping
for v in range(num_visits*num_filters):
    for i in range(len(selected_fields)):
        for j in range(i):
            m2.add_constraint(tc[i][v] + delta * x[i][v] + slew_time_day[i][j] - tc[j][v] <= M * (2 - x[i][v] - x[j][v]),
                              ctname=f"non_overlapping_cross_fields_{i}_{j}_visits_{v}")
            m2.add_constraint(tc[j][v] + delta * x[j][v] + slew_time_day[i][j] - tc[i][v] <= M * (-1 + x[i][v] + x[j][v]),
                ctname=f"non_overlapping_cross_fields_{j}_{i}_visits_{v}")
            
# Create auxiliary variables for visit transition times
# visit_transition_times = [m2.continuous_var(
#     lb=0,  # or appropriate lower bound based on your problem
#     ub=M,  # your big-M value
#     name=f"visit_transition_{v}"
# ) for v in range(num_visits*num_filters-1)]  # one less than total visits

# # Add constraints for visit transitions
# for v in range(1, num_visits*num_filters):
#     # All fields from previous visit must end before transition time
#     for i in range(len(selected_fields)):
#         m2.add_constraint(
#             tc[i][v-1] + delta * x[i][v-1] <= visit_transition_times[v-1],
#             ctname=f"visit_end_{i}_visit_{v-1}"
#         )
        
#         # All fields in current visit must start after transition time
#         m2.add_constraint(
#             tc[i][v] >= visit_transition_times[v-1],
#             ctname=f"visit_start_{i}_visit_{v}"
#         )

# Isolating visits
for v in range(1, num_visits*num_filters):
    prev_visit_end = m2.max([tc[i][v-1] + 2 * delta * x[i][v-1] for i in range(len(selected_fields))])
    for i in range(len(selected_fields)):
        m2.add_constraint(tc[i][v] >= prev_visit_end,
            ctname=f"visit_sequence_field_{i}_visit_{v}")

m2.maximize(m2.sum([probabilities[i] * x[i][v]
                    for i in range(len(selected_fields))
                    for v in range(num_visits*num_filters)]))

m2.parameters.timelimit = 60
solution2 = m2.solve(log_output=True)

'''

# Visit ordering constraints
# for v in range(1, num_visits*num_filters):
#     for i in range(len(selected_fields)):
#         m2.add_constraint(tc[i][v] >= tc[len(selected_fields)-1][v-1] + delta + slew_time_day[][],
#                           ctname=f"visit_ordering_constraint_field_{i}_visit_{v}")
# Modified objective function to sum over fields and visits

for i in range(len(selected_fields)):
    for v in range(1, num_visits*num_filters):
        m2.add_constraint(tc[i][v] - tc[i][v-1] >= cadence_days * (x[i][v] + x[i][v-1] - 1),
            ctname=f"cadence_constraint_field_{i}_visits_{v}")

# For the first visit (v=0), use original constraints
for i in range(len(selected_fields)):
    for j in range(i):
        m2.add_constraint(
            tc[i][0] + delta * x[i][0] + slew_time_day[i][j] - tc[j][0] <= M * (2 - x[i][0] - x[j][0]),
            ctname=f"non_overlapping_cross_fields_{i}_{j}_visit_0"
        )
        m2.add_constraint(
            tc[j][0] + delta * x[j][0] + slew_time_day[i][j] - tc[i][0] <= M * (-1 + x[i][0] + x[j][0]),
            ctname=f"non_overlapping_cross_fields_{j}_{i}_visit_0"
        )

# For subsequent visits
for v in range(1, num_visits*num_filters):
    # Calculate end time of previous visit using all fields
    prev_visit_end = m2.max([
        tc[k][v-1] + delta * x[k][v-1]  
        for k in range(len(selected_fields))
    ])
    
    # For current visit
    for i in range(len(selected_fields)):
        # Ensure field starts after previous visit ends if it's selected
        m2.add_constraint(
            tc[i][v] >= prev_visit_end - M * (1 - x[i][v]),
            ctname=f"sequential_start_field_{i}_visit_{v}"
        )
        
        # Non-overlapping constraints within current visit
        for j in range(i):
            m2.add_constraint(
                tc[i][v] + delta * x[i][v] + slew_time_day[i][j] - tc[j][v] <= M * (2 - x[i][v] - x[j][v]),
                ctname=f"non_overlapping_cross_fields_{i}_{j}_visit_{v}"
            )
            m2.add_constraint(
                tc[j][v] + delta * x[j][v] + slew_time_day[i][j] - tc[i][v] <= M * (-1 + x[i][v] + x[j][v]),
                ctname=f"non_overlapping_cross_fields_{j}_{i}_visit_{v}"
            )
m2.maximize(m2.sum([
    probabilities[i] * x[i][v] 
    for i in range(len(selected_fields))
    for v in range(num_visits*num_filters)
]))

m2.parameters.timelimit = 60
solution = m2.solve(log_output=True)
'''

### exctracting solution

In [None]:
solution2


In [None]:
# scheduled_fields_by_visit = []
# for v in range(num_visits * num_filters):
#     visit_fields = [i for i in range(len(selected_fields)) if solution.get_value(x[i][v]) == 1]
#     scheduled_fields_by_visit.append(visit_fields)
# # scheduled_fields_by_visit

# scheduled_tc = []
# for v in range(num_visits * num_filters):
#     visit_times = []
#     for i in range(len(selected_fields)):
#         if i in scheduled_fields_by_visit[v]:
#             visit_times.append(solution.get_value(tc[i][v]))
#         else:
#             visit_times.append(np.nan) 
#     scheduled_tc.append(visit_times)
# scheduled_tc

In [None]:
'''
# Get the indices of scheduled fields
scheduled_fields_ID = [i for i, v in enumerate(x) if v.solution_value == 1]
scheduled_fields = selected_fields[scheduled_fields_ID]
# scheduled_fields
scheduled_tc = [[solution.get_value(tc[i][v]) for v in range(num_visits * num_filters)] for i in scheduled_fields_ID]
scheduled_tc = np.asarray(scheduled_tc)
# scheduled_fields
for i in range(num_visits*num_filters):
    scheduled_fields[f"Scheduled_start_filt_times_{i}"] = scheduled_tc[:,i] 
'''
scheduled_fields_by_visit = []
for v in range(num_visits * num_filters):
    visit_fields = [i for i in range(len(selected_fields)) if solution.get_value(x[i][v]) == 1]
    scheduled_fields_by_visit.append(visit_fields)

scheduled_fields = selected_fields.copy()

scheduled_tc = []
for v in range(num_visits * num_filters):
    visit_times = []
    for i in range(len(selected_fields)):
        if i in scheduled_fields_by_visit[v]:
            visit_times.append(solution.get_value(tc[i][v]))
        else:
            visit_times.append(np.nan) 
    scheduled_tc.append(visit_times)

scheduled_tc = np.array(scheduled_tc).T  

for i in range(num_visits * num_filters):
    scheduled_fields[f"Scheduled_start_filt_times_{i}"] = scheduled_tc[:, i]

for v in range(num_visits * num_filters):
    scheduled_fields[f"Selected_in_visit_{v}"] = [1 if i in scheduled_fields_by_visit[v] else 0 
                                                 for i in range(len(selected_fields))]


In [None]:
scheduled_tc

#### garbage

## plotting

In [None]:
n_visits = num_visits * num_filters  

fig, axes = plt.subplots(n_visits, 1, figsize=(8, 3 * n_visits), sharex=True)

for i in range(n_visits):
    start_col = f'Scheduled_start_filt_times_{i}'
    end_col = f'Scheduled_end_filt_times_{i}'
    
    # Convert start times to MJD and set format
    scheduled_fields[start_col] = Time(scheduled_fields[start_col], format='mjd')
    scheduled_fields[start_col].format = 'iso'
    scheduled_fields[end_col] = scheduled_fields[start_col] + exposure_time_day
    
    # Sort fields by end time for better visualization
    scheduled_fields.sort(end_col)
    
    # Get the start and end times for the vertical lines
    first_start_time = scheduled_fields[start_col].mjd[0]
    last_end_time = scheduled_fields[end_col].mjd[-1]
    
    ax = axes[i]  
    # Plot observation time intervals as horizontal lines
    ax.hlines(
        np.arange(len(scheduled_fields)),
        scheduled_fields[start_col].mjd,
        scheduled_fields[end_col].mjd,
        colors='blue',
        linewidth=2
    )
    # Plot small vertical lines at start and end times of each interval
    for j in range(len(scheduled_fields)):
        ax.vlines(
            scheduled_fields[start_col][j].mjd,
            ymin=j - 0.2,
            ymax=j + 0.2,
            color='black',
            linewidth=0.5,
            linestyle='-'
        )
        ax.vlines(
            scheduled_fields[end_col][j].mjd,
            ymin=j - 0.2,
            ymax=j + 0.2,
            color='black',
            linewidth=0.5,
            linestyle='-'
        )
    
    # Plot big vertical lines at the start of the first field and end of the last field
    ax.axvline(first_start_time, color='red', linestyle='--', linewidth=1.5, label='Start of First Field')
    ax.axvline(last_end_time, color='green', linestyle='--', linewidth=1.5, label='End of Last Field')
    
    # Add labels and title
    ax.set_yticks(np.arange(len(scheduled_fields)))
    ax.set_yticklabels(scheduled_fields['field_id'].astype(str))
    ax.set_ylabel('Field ID')
    ax.set_title(f'Observation Schedule for Visit {i + 1}')
    ax.legend(loc='upper right')  # Add legend to distinguish vertical lines
    
axes[-1].set_xlabel('Observation time (MJD)')

plt.tight_layout()
plt.show()


In [None]:
scheduled_fields

In [None]:
# from docplex.mp.model import Model

# # Make sure all time values are in days
# exposure_time_day = exposure_time.to_value(u.day)
# cadence_days = cadence / (24 * 60)  # Convert minutes to days

# # Create and solve the model


# def create_observation_model(prob, observable_fields, exposure_time, cadence_days, slew_time_day, num_visits):
#     m = Model("Telescope Observation Schedule")
    
#     # Index sets
#     n_fields = len(observable_fields)
    
#     # Decision Variables
#     # p[i]: pixel i is inside footprint of selected fields (binary)
#     p = m.binary_var_list(len(prob), name='pixel')
    
#     # r[j]: field j is selected (binary)
#     r = m.binary_var_list(n_fields, name='field')
    
#     # t[j,k]: start time of observation j visit k (continuous)
#     t = [[m.continuous_var(
#         lb=(row['start_time'] - observable_fields['start_time'].min()).to_value(u.day),
#         ub=(row['end_time'] - observable_fields['start_time'].min() - exposure_time).to_value(u.day),
#         name=f"start_time_field_{j}_visit_{k}")
#         for k in range(num_visits)] for j, row in enumerate(observable_fields)]
    
#     # Containment Constraints
#     # A pixel is only counted if it's in a selected field
#     for i, fields_containing_pixel in enumerate(footprints_healpix_inverse):
#         m.add_constraint(
#             p[i] <= m.sum(r[j] for j in fields_containing_pixel),
#             ctname=f'containment_{i}'
#         )
    
#     # Cadence Constraints
#     # Minimum time between visits of the same field
#     for j in range(n_fields):
#         for k in range(1, num_visits):
#             m.add_constraint(
#                 t[j][k] - t[j][k-1] >= cadence_days * r[j],
#                 ctname=f'cadence_field_{j}_visit_{k}'
#             )
    
#     # No Overlap Constraints
#     # Observations must be separated by exposure + slew time
#     for j1 in range(n_fields):
#         for j2 in range(j1):
#             for k1 in range(num_visits):
#                 for k2 in range(num_visits):
#                     min_separation = exposure_time_day + slew_time_day[j1][j2]
#                     # Either j1,k1 happens after j2,k2 or vice versa
#                     m.add_constraint(
#                         (t[j1][k1] - t[j2][k2] >= min_separation * (r[j1] + r[j2] - 1)) |
#                         (t[j2][k2] - t[j1][k1] >= min_separation * (r[j1] + r[j2] - 1)),
#                         ctname=f'no_overlap_{j1}_{k1}_{j2}_{k2}'
#                     )
    
#     # Field of Regard Constraints
#     # Start time must be within observable window
#     for j, field in enumerate(observable_fields):
#         for k in range(num_visits):
#             m.add_constraint(
#                 t[j][k] >= (field['start_time'] - observable_fields['start_time'].min()).to_value(u.day) * r[j],
#                 ctname=f'for_start_{j}_{k}'
#             )
#             m.add_constraint(
#                 t[j][k] <= (field['end_time'] - observable_fields['start_time'].min() - exposure_time).to_value(u.day) * r[j],
#                 ctname=f'for_end_{j}_{k}'
#             )
    
#     # Objective: Maximize probability coverage
#     objective = m.sum(p[i] * prob[i] for i in range(len(prob)))
#     m.maximize(objective)
    
#     return m

In [None]:
# model = create_observation_model(prob, observable_fields, exposure_time_day, cadence_days, slew_time_day, num_visits)
# solution = model.solve(log_output=True)

In [None]:
import astroplan
from astropy.coordinates import ICRS, SkyCoord, AltAz, get_moon, EarthLocation, get_body
from astropy import units as u
from astropy.utils.data import download_file
from astropy.table import Table, QTable, join
from astropy.time import Time, TimeDelta
from astropy_healpix import *
from ligo.skymap import plot
from ligo.skymap.io import read_sky_map
import healpy as hp
import os
from matplotlib import pyplot as plt
import numpy as np
from tqdm.auto import tqdm
import datetime as dt
import pickle
import pandas as pd
from docplex.mp.model import Model

import warnings
warnings.filterwarnings("ignore", "Wswiglal-redir-stdio")
warnings.simplefilter('ignore', astroplan.TargetNeverUpWarning)
warnings.simplefilter('ignore', astroplan.TargetAlwaysUpWarning)

# directory_path = "/u/ywagh/test_skymaps/S240422ed.fits"
# skymap, metadata = read_sky_map(os.path.join(directory_path))

directory_path = "/u/ywagh/test_skymaps/"
filelist = sorted([f for f in os.listdir(directory_path) if f.endswith('.gz')])

slew_speed = 2.5 * u.deg / u.s
slew_accel = 0.4 * u.deg / u.s**2
readout = 8.2 * u.s

ns_nchips = 4
ew_nchips = 4
ns_npix = 6144
ew_npix = 6160
plate_scale = 1.01 * u.arcsec
ns_chip_gap = 0.205 * u.deg
ew_chip_gap = 0.140 * u.deg

ns_total = ns_nchips * ns_npix * plate_scale + (ns_nchips - 1) * ns_chip_gap
ew_total = ew_nchips * ew_npix * plate_scale + (ew_nchips - 1) * ew_chip_gap

rcid = np.arange(64)

chipid, rc_in_chip_id = np.divmod(rcid, 4)
ns_chip_index, ew_chip_index = np.divmod(chipid, ew_nchips)
ns_rc_in_chip_index = np.where(rc_in_chip_id <= 1, 1, 0)
ew_rc_in_chip_index = np.where((rc_in_chip_id == 0) | (rc_in_chip_id == 3), 0, 1)

ew_offsets = ew_chip_gap * (ew_chip_index - (ew_nchips - 1) / 2) + ew_npix * plate_scale * (ew_chip_index - ew_nchips / 2) + 0.5 * ew_rc_in_chip_index * plate_scale * ew_npix
ns_offsets = ns_chip_gap * (ns_chip_index - (ns_nchips - 1) / 2) + ns_npix * plate_scale * (ns_chip_index - ns_nchips / 2) + 0.5 * ns_rc_in_chip_index * plate_scale * ns_npix

ew_ccd_corners = 0.5 * plate_scale * np.asarray([ew_npix, 0, 0, ew_npix])
ns_ccd_corners = 0.5 * plate_scale * np.asarray([ns_npix, ns_npix, 0, 0])

ew_vertices = ew_offsets[:, np.newaxis] + ew_ccd_corners[np.newaxis, :]
ns_vertices = ns_offsets[:, np.newaxis] + ns_ccd_corners[np.newaxis, :]

def get_footprint(center):
    return SkyCoord(
        ew_vertices, ns_vertices,
        frame=center[..., np.newaxis, np.newaxis].skyoffset_frame()
    ).icrs

url = 'https://github.com/ZwickyTransientFacility/ztf_information/raw/master/field_grid/ZTF_Fields.txt'
filename = download_file(url)
field_grid = QTable(np.recfromtxt(filename, comments='%', usecols=range(3), names=['field_id', 'ra', 'dec']))
field_grid['coord'] = SkyCoord(field_grid.columns.pop('ra') * u.deg, field_grid.columns.pop('dec') * u.deg)
field_grid = field_grid[0:881]

#******************************************************************************
skymap, metadata = read_sky_map(os.path.join(directory_path, filelist[40]))

plot_filename = os.path.basename(filelist[40])
# plot_filename = 'S240422ed'
# ci
#******************************************************************************

event_time = Time(metadata['gps_time'], format='gps').utc
event_time.format = 'iso'

event_time = Time(metadata['gps_time'], format='gps').utc
event_time.format = 'iso'
print('event time:',event_time)
observer = astroplan.Observer.at_site('Palomar')
night_horizon = -18 * u.deg
if observer.is_night(event_time, horizon=night_horizon):
    start_time = event_time
else:
    start_time = observer.sun_set_time(
        event_time, horizon=night_horizon, which='next')

# Find the latest possible end time of observations: the time of sunrise.
end_time = observer.sun_rise_time(
    start_time, horizon=night_horizon, which='next')

min_airmass = 2.5 * u.dimensionless_unscaled
airmass_horizon = (90 * u.deg - np.arccos(1 / min_airmass))
targets = field_grid['coord']

# Find the time that each field rises and sets above an airmass of 2.5.
target_start_time = Time(np.where(
    observer.target_is_up(start_time, targets, horizon=airmass_horizon),
    start_time,
    observer.target_rise_time(start_time, targets, which='next', horizon=airmass_horizon)))
target_start_time.format = 'iso'

# Find the time that each field sets below the airmass limit. If the target
# is always up (i.e., it's circumpolar) or if it sets after surnsise,
# then set the end time to sunrise.
target_end_time = observer.target_set_time(
    target_start_time, targets, which='next', horizon=airmass_horizon)
target_end_time[
    (target_end_time.mask & ~target_start_time.mask) | (target_end_time > end_time)
] = end_time
target_end_time.format = 'iso'
# Select fields that are observable for long enough for at least one exposure
##############################################################################
exposure_time = 180 * u.second
exposure_time_day = exposure_time.to_value(u.day)

num_visits = 2
num_filters = 2

cadence = 60         #minutes
cadence_days = cadence / (60 * 24)
##############################################################################
field_grid['start_time'] = target_start_time
field_grid['end_time'] = target_end_time
observable_fields = field_grid[target_end_time - target_start_time >= exposure_time]

# print(observable_fields)
hpx = HEALPix(nside=256, frame=ICRS())

footprint = np.moveaxis(
    get_footprint(SkyCoord(0 * u.deg, 0 * u.deg)).cartesian.xyz.value, 0, -1)
footprint_healpix = np.unique(np.concatenate(
    [hp.query_polygon(hpx.nside, v, nest=(hpx.order == 'nested')) for v in footprint]))

'''
# computing the footprints of every ZTF field as HEALPix indices. Downsampling skymap to same resolution.
'''
footprints = np.moveaxis(get_footprint(observable_fields['coord']).cartesian.xyz.value, 0, -1)
footprints_healpix = [
    np.unique(np.concatenate([hp.query_polygon(hpx.nside, v) for v in footprint]))
    for footprint in tqdm(footprints)]

prob = hp.ud_grade(skymap, hpx.nside, power=-2)

# k = max number of 300s exposures 
min_start = min(observable_fields['start_time'])
max_end =max(observable_fields['end_time'])
# min_start.format = 'jd'
# max_end.format = 'jd'



k = int(np.floor((max_end - min_start)/(2*exposure_time.to(u.day))))
k = np.floor(k/(num_visits*num_filters))
print(k," number of exposures could be taken tonight")

print("problem setup completed")

m1 = Model('max coverage problem')

field_vars = m1.binary_var_list(len(footprints), name='field')
pixel_vars = m1.binary_var_list(hpx.npix, name='pixel')

footprints_healpix_inverse = [[] for _ in range(hpx.npix)]

for field, pixels in enumerate(footprints_healpix):
    for pixel in pixels:
        footprints_healpix_inverse[pixel].append(field)

for i_pixel, i_fields in enumerate(footprints_healpix_inverse):
     m1.add_constraint(m1.sum(field_vars[i] for i in i_fields) >= pixel_vars[i_pixel])

m1.add_constraint(m1.sum(field_vars) <= k)
m1.maximize(m1.dot(pixel_vars, prob))
print(f"number fo fields observed should be less than {k}")

solution = m1.solve(log_output=True)

print("optimization completed")
total_prob_covered = solution.objective_value

print("Total probability covered:",total_prob_covered)

selected_fields_ID = [i for i, v in enumerate(field_vars) if v.solution_value == 1]
print(len(selected_fields_ID), "fields selected")
selected_fields = observable_fields[selected_fields_ID]
# print(selected_fields)

separation_matrix = selected_fields['coord'][:,np.newaxis].separation(selected_fields['coord'][np.newaxis,:])

def slew_time(separation):
   return np.where(separation <= (slew_speed**2 / slew_accel),
                   np.sqrt(2 * separation / slew_accel),
                   (2 * slew_speed / slew_accel) + (separation - slew_speed**2 / slew_accel) / slew_speed)

slew_times = slew_time(separation_matrix).value

slew_time_value = slew_times*u.second
slew_time_day = slew_time_value.to_value(u.day)

m2 = Model("Telescope timings")

observer_location = EarthLocation.of_site('Palomar')

footprints_selected = np.moveaxis(get_footprint(selected_fields['coord']).cartesian.xyz.value, 0, -1)
footprints_healpix_selected = [
    np.unique(np.concatenate([hp.query_polygon(hpx.nside, v) for v in footprint]))
    for footprint in tqdm(footprints_selected)]

probabilities = []

for field_index in range(len(footprints_healpix_selected)):
    probability_field = np.sum(prob[footprints_healpix_selected[field_index]])
    probabilities.append(probability_field)
print("worked for",len(probabilities),"fields")

selected_fields['probabilities'] = probabilities

delta = exposure_time.to_value(u.day)
M = (selected_fields['end_time'].max() - selected_fields['start_time'].min()).to_value(u.day).item()

x = [[m2.binary_var(name=f"x_{i}_visit_{v}") 
      for v in range(num_visits*num_filters)] 
      for i in range(len(selected_fields))]

tc = [[m2.continuous_var(
    lb=(row['start_time'] - start_time).to_value(u.day),
    ub=(row['end_time'] - start_time - exposure_time).to_value(u.day),
    name=f"start_time_field_{i}_visit_{v}")
    for v in range(num_visits*num_filters)] 
    for i, row in enumerate(selected_fields)]

visit_transition_times = [m2.continuous_var(
    lb=0,ub=M,name=f"visit_transition_{v}")
                          for v in range(num_visits*num_filters-1)]  

# Isolating visits
for v in range(1, num_visits*num_filters):
    for i in range(len(selected_fields)):
        m2.add_constraint(tc[i][v-1] + delta * x[i][v-1] <= visit_transition_times[v-1],
            ctname=f"visit_end_{i}_visit_{v-1}")
        m2.add_constraint(tc[i][v] >= visit_transition_times[v-1],
            ctname=f"visit_start_{i}_visit_{v}")

# Cadence constraints
for i in range(len(selected_fields)):
    for v in range(1, num_visits*num_filters):
        m2.add_constraint(tc[i][v] - tc[i][v-1] >= (cadence_days+delta) * (x[i][v] + x[i][v-1] - 1),
            ctname=f"cadence_constraint_field_{i}_visits_{v}")

#non-overlapping constraints
# for v in range(num_visits * num_filters):
#     for i in range(len(selected_fields)):
#         for j in range(i):  # Ensure j < i to avoid duplicate constraints
#             buffer_time = 0.001  # Small buffer to prevent exact equality issues
            
#             m2.add_indicator(x[i][v], 
#                            tc[i][v] + delta + slew_time_day[i][j] + buffer_time <= tc[j][v],
#                            name=f"indicator_constraint_{i}_to_{j}_visit_{v}")
            
#             m2.add_indicator(x[j][v], 
#                            tc[j][v] + delta + slew_time_day[i][j] + buffer_time <= tc[i][v],
#                            name=f"indicator_constraint_{j}_to_{i}_visit_{v}")


#non-overlapping
for v in range(num_visits*num_filters):
    for i in range(len(selected_fields)):
        for j in range(i):
            m2.add_constraint(tc[i][v] + delta * x[i][v] + slew_time_day[i][j] - tc[j][v] <= M * (2 - x[i][v] - x[j][v]),
                              ctname=f"non_overlapping_cross_fields_{i}_{j}_visits_{v}")
            m2.add_constraint(tc[j][v] + delta * x[j][v] + slew_time_day[i][j] - tc[i][v] <= M * (-1 + x[i][v] + x[j][v]),
                ctname=f"non_overlapping_cross_fields_{j}_{i}_visits_{v}")

m2.maximize(m2.sum([probabilities[i] * x[i][v]
                    for i in range(len(selected_fields))
                    for v in range(num_visits*num_filters)]))

m2.parameters.timelimit = 60
m2.parameters.mip.tolerances.mipgap = 0.01  # 1% optimality gap
m2.parameters.emphasis.mip = 2  # Emphasize optimality over feasibility
solution = m2.solve(log_output=True)

scheduled_fields_by_visit = []
for v in range(num_visits * num_filters):
    visit_fields = [i for i in range(len(selected_fields)) if solution.get_value(x[i][v]) == 1]
    scheduled_fields_by_visit.append(visit_fields)

scheduled_fields = selected_fields.copy()

scheduled_tc = []
for v in range(num_visits * num_filters):
    visit_times = []
    for i in range(len(selected_fields)):
        if i in scheduled_fields_by_visit[v]:
            visit_times.append(solution.get_value(tc[i][v]))
        else:
            visit_times.append(np.nan) 
    scheduled_tc.append(visit_times)

scheduled_tc = np.array(scheduled_tc).T  

for i in range(num_visits * num_filters):
    scheduled_fields[f"Scheduled_start_filt_times_{i}"] = scheduled_tc[:, i]

for v in range(num_visits * num_filters):
    scheduled_fields[f"Selected_in_visit_{v}"] = [1 if i in scheduled_fields_by_visit[v] else 0 
                                                 for i in range(len(scheduled_fields))]



n_visits = num_visits * num_filters  

fig, axes = plt.subplots(n_visits, 1, figsize=(8, 3 * n_visits), sharex=True)

for i in range(n_visits):
    start_col = f'Scheduled_start_filt_times_{i}'
    end_col = f'Scheduled_end_filt_times_{i}'
    valid_rows = ~np.isnan(scheduled_tc).any(axis=1)
    valid_scheduled_tc = scheduled_tc[valid_rows]
    valid_scheduled_fields = scheduled_fields[valid_rows]
    valid_field_ids = scheduled_fields['field_id'][valid_rows]


    # Also get corresponding field IDs if needed
    valid_scheduled_fields[start_col] = Time(valid_scheduled_fields[start_col], format='mjd')
    valid_scheduled_fields[start_col].format = 'iso'
    valid_scheduled_fields[end_col] = valid_scheduled_fields[start_col] + exposure_time_day
    
    # Sort fields by end time for better visualization
    valid_scheduled_fields.sort(end_col)
    
    # Get the start and end times for the vertical lines
    first_start_time = valid_scheduled_fields[start_col].mjd[0]
    last_end_time = valid_scheduled_fields[end_col].mjd[-1]
    
    ax = axes[i]  
    ax.hlines(
        np.arange(len(valid_scheduled_fields)),
        valid_scheduled_fields[start_col].mjd,
        valid_scheduled_fields[end_col].mjd,
        colors='blue',
        linewidth=2
    )
    # Plot small vertical lines at start and end times of each interval
    for j in range(len(valid_scheduled_fields)):
        ax.vlines(
            valid_scheduled_fields[start_col][j].mjd,
            ymin=j - 0.2,
            ymax=j + 0.2,
            color='black',
            linewidth=0.5,
            linestyle='-'
        )
        ax.vlines(
            valid_scheduled_fields[end_col][j].mjd,
            ymin=j - 0.2,
            ymax=j + 0.2,
            color='black',
            linewidth=0.5,
            linestyle='-'
        )
    
    # Plot big vertical lines at the start of the first field and end of the last field
    ax.axvline(first_start_time, color='red', linestyle='--', linewidth=1.5, label='Start of First Field')
    ax.axvline(last_end_time, color='green', linestyle='--', linewidth=1.5, label='End of Last Field')
    
    # Add labels and title
    ax.set_yticks(np.arange(len(valid_scheduled_fields)))
    ax.set_yticklabels(valid_scheduled_fields['field_id'].astype(str))
    ax.set_ylabel('Field ID')
    ax.set_title(f'Observation Schedule for Visit {i + 1}')
    ax.legend(loc='upper right')  # Add legend to distinguish vertical lines
    
axes[-1].set_xlabel('Observation time (MJD)')

plt.tight_layout()
# plt.title(f'Total Cumulative Probability per Field:{total_cum_prob}')
plt.savefig('revisit_plots.png', dpi=300)
plt.show()


In [7]:
scheduled_fields_by_visit = []
for v in range(num_visits * num_filters):
    visit_fields = [i for i in range(len(selected_fields)) if solution.get_value(x[i][v]) == 1]
    scheduled_fields_by_visit.append(visit_fields)

scheduled_fields = selected_fields.copy()

scheduled_tc = []
for v in range(num_visits * num_filters):
    visit_times = []
    for i in range(len(selected_fields)):
        if i in scheduled_fields_by_visit[v]:
            visit_times.append(solution.get_value(tc[i][v]))
        else:
            visit_times.append(np.nan) 
    scheduled_tc.append(visit_times)

scheduled_tc = np.array(scheduled_tc).T  

for i in range(num_visits * num_filters):
    scheduled_fields[f"Scheduled_start_filt_times_{i}"] = scheduled_tc[:, i]

for v in range(num_visits * num_filters):
    scheduled_fields[f"Selected_in_visit_{v}"] = [1 if i in scheduled_fields_by_visit[v] else 0 
                                                 for i in range(len(selected_fields))]

In [None]:
# Get indices of rows that don't have any NaN values
valid_rows = ~np.isnan(scheduled_tc).any(axis=1)
# valid_rows
# Create new array with only valid rows
valid_scheduled_tc = scheduled_tc[valid_rows]
valid_scheduled_tc
# Also get corresponding field IDs if needed
# valid_field_ids = scheduled_fields['field_id'][valid_rows]

# dumping yard

In [None]:
# Sliding time window parameters
window_size = 1 * u.hour  # 1-hour windows
window_duration = window_size.to_value(u.day)  # Convert to days
time_windows = np.arange(
    min_start.to_value(u.day), 
    max_end.to_value(u.day), 
    window_duration
)

# Collect results across all time windows
scheduled_fields_all_windows = []
probabilities_covered = []

for window_start in time_windows:
    window_end = window_start + window_duration

    # Filter fields that can be observed within this time window
    window_fields = selected_fields[
        (selected_fields['start_time'].to_value(u.day) <= window_end) &
        (selected_fields['end_time'].to_value(u.day) >= window_start)
    ]
    if len(window_fields) == 0:
        continue

    # Precompute valid field pairs for the window
    window_separation_matrix = window_fields['coord'][:, np.newaxis].separation(
        window_fields['coord'][np.newaxis, :]
    )
    window_slew_times = slew_time(window_separation_matrix).to_value(u.day)
    valid_pairs = [
        (i, j)
        for i in range(len(window_fields))
        for j in range(i)
        if window_slew_times[i, j] + delta <= M
    ]

    # Create new Model 2 for the time window
    m2 = Model(f"Telescope timings (Window {window_start}-{window_end})")
    
    # Decision variables
    x = [[m2.binary_var(name=f"x_{i}_visit_{v}") 
          for v in range(num_visits*num_filters)] 
          for i in range(len(window_fields))]

    tc = [[m2.continuous_var(
        lb=max(window_fields['start_time'][i].to_value(u.day), window_start),
        ub=min(window_fields['end_time'][i].to_value(u.day), window_end - delta),
        name=f"start_time_field_{i}_visit_{v}")
        for v in range(num_visits*num_filters)] 
        for i in range(len(window_fields))]

    visit_transition_times = [m2.continuous_var(
        lb=window_start, ub=window_end, name=f"visit_transition_{v}")
                              for v in range(num_visits*num_filters - 1)]

    # Constraints
    for v in range(1, num_visits*num_filters):
        for i in range(len(window_fields)):
            m2.add_constraint(tc[i][v-1] + delta * x[i][v-1] <= visit_transition_times[v-1],
                              ctname=f"visit_end_{i}_visit_{v-1}")
            m2.add_constraint(tc[i][v] >= visit_transition_times[v-1],
                              ctname=f"visit_start_{i}_visit_{v}")

    for i in range(len(window_fields)):
        for v in range(1, num_visits*num_filters):
            m2.add_constraint(tc[i][v] - tc[i][v-1] >= cadence_days * (x[i][v] + x[i][v-1] - 1),
                              ctname=f"cadence_constraint_field_{i}_visits_{v}")

    # Non-overlapping constraints for valid pairs
    for v in range(num_visits*num_filters):
        for i, j in valid_pairs:
            m2.add_constraint(tc[i][v] + delta * x[i][v] + window_slew_times[i, j] - tc[j][v] <= M * (2 - x[i][v] - x[j][v]),
                              ctname=f"non_overlapping_cross_fields_{i}_{j}_visits_{v}")
            m2.add_constraint(tc[j][v] + delta * x[j][v] + window_slew_times[i, j] - tc[i][v] <= M * (-1 + x[i][v] + x[j][v]),
                              ctname=f"non_overlapping_cross_fields_{j}_{i}_visits_{v}")

    # Objective
    m2.maximize(m2.sum([probabilities[i] * x[i][v]
                        for i in range(len(window_fields))
                        for v in range(num_visits*num_filters)]))

    # Solve
    m2.parameters.timelimit = 300
    m2.parameters.mip.tolerances.mipgap = 0.05
    solution = m2.solve(log_output=True)

    if solution:
        total_prob = solution.objective_value
        probabilities_covered.append(total_prob)

        # Extract scheduled fields and times
        scheduled_fields_by_visit = []
        for v in range(num_visits * num_filters):
            visit_fields = [i for i in range(len(window_fields)) if solution.get_value(x[i][v]) == 1]
            scheduled_fields_by_visit.append(visit_fields)
        scheduled_fields_all_windows.append((window_start, window_end, scheduled_fields_by_visit))

# Combine results across windows if necessary
# (Implement further logic to consolidate results across time windows if needed)
