# This notebook creates and updates the lists of objects that are queried each day

### Creating new lists:
- Go to the sources page on Fritz
- Filter the table on Type (Ia & all Ia subtypes)
- Download the csv (takes a while, but it's the fastest method I could find)
- Follow the steps below for creating new lists
- Put the lists in the required location for the real-time program to work properly

Note: keep a copy of the csv file, so that it does not have to be downloaded again if the lists need to be remade at some point

### Updating lists:
- Check up to which date objects have already been saved in the lists
- Go to Fritz and download the objects saved later than that (See above for notes)
- Follow the steps below for updating lists
- Replace the lists in the required location for the real-time program to work properly

Note: add the newly downloaded sources to the old csv file to reduce download time the next time lists are updated

In [None]:
#Useful imports and functions
import pandas as pd
from pathlib import Path
from astropy import time
import numpy as np
from datetime import datetime
import matplotlib.pyplot as plt

def load_sources(path):
    """Read a sources table exported from Fritz.

    Only the id, position, redshift, classification and save-time columns are
    read. They are renamed to ``ztfname``, ``ra``, ``dec`` and ``saved_at``;
    ``redshift`` and ``classification`` keep their names. A ``save_mjd`` column
    holding the save time converted to MJD is appended.

    Parameters
    ----------
    path : path-like
        Location of the CSV file downloaded from the Fritz sources page.

    Returns
    -------
    pandas.DataFrame
    """
    wanted = ['id', 'ra [deg]', 'dec [deg]', 'redshift', 'classification', 'Saved at']
    new_names = {'id': 'ztfname', 'ra [deg]': 'ra', 'dec [deg]': 'dec', 'Saved at': 'saved_at'}
    sources = pd.read_csv(path, header=0, usecols=wanted).rename(columns=new_names)
    # One astropy Time per row, converted to Modified Julian Date
    sources['save_mjd'] = [time.Time(stamp).mjd for stamp in sources['saved_at']]
    return sources

def load_oldlists(paths):
    """Load the lists currently used in the real-time program so they can be updated.

    ``save_lists`` writes each DataFrame together with its index (the
    ``to_csv`` default). ``read_csv`` reads that index back as an extra column
    named ``'Unnamed: 0'``. Such unnamed columns are dropped here. Otherwise
    ``pd.concat`` with freshly built lists fills them with NaN, and every
    update cycle saves one more nested index column.

    Parameters
    ----------
    paths : iterable of path-like
        The list files to read, in list order.

    Returns
    -------
    list of pandas.DataFrame
        One DataFrame per path, without leftover index columns.
    """
    oldlists = []
    for path in paths:
        oldlist = pd.read_csv(path, header=0)
        # Columns without a header name are saved indices, not list data
        unnamed = oldlist.columns.astype(str).str.startswith('Unnamed:')
        oldlists.append(oldlist.loc[:, ~unnamed])
    return oldlists

def save_lists(lists, saveloc):
    """Write each list to ``saveloc/list_<i>.csv`` and report when this happened.

    Parameters
    ----------
    lists : sequence of pandas.DataFrame
        The lists to save. List ``i`` is written to ``list_{i}.csv``.
    saveloc : pathlib.Path or str
        Directory to write the files into. A plain string is accepted too.
    """
    saveloc = Path(saveloc)
    for i, source_list in enumerate(lists):
        # The DataFrame index is written as well (pandas to_csv default)
        source_list.to_csv(saveloc / f'list_{i}.csv')
    # A single timestamp, so the reported date and time always belong together
    now = datetime.now()
    print(f'Lists saved on {now.date()} at {now.time()}')

def restrict_age(sources, min_age=0, max_age=None):
    """Keep only the sources saved between ``max_age`` and ``min_age`` days ago.

    The 18+ restriction question, but for transients, in days.

    Parameters
    ----------
    sources : pandas.DataFrame
        Must contain a ``save_mjd`` column, as added by ``load_sources``.
    min_age : float
        Keep sources saved strictly more than ``min_age`` days ago.
    max_age : float or None
        If given, also require that the source was saved strictly less than
        ``max_age`` days ago. ``None`` means there is no upper age limit.

    Returns
    -------
    pandas.DataFrame
        The selected rows, re-indexed 0..n-1.
    """
    # Time.now() is UTC. time.Time(datetime.now()) would interpret the naive
    # local-time datetime as UTC and shift "now" by the local UTC offset.
    mjd_now = time.Time.now().mjd
    keep = sources.save_mjd < mjd_now - min_age
    if max_age is not None:
        keep = keep & (sources.save_mjd > mjd_now - max_age)
    return sources[keep].reset_index(drop=True)

def transform_format(sources):
    """Convert a Fritz sources table into the column layout of the final lists.

    Every column except ``redshift``, ``classification``, ``saved_at`` and
    ``save_mjd`` is kept. For a table from ``load_sources`` that leaves
    ``ztfname``, ``ra`` and ``dec``. Two bookkeeping columns are then appended,
    ``last_checked_mjd`` and ``last_checked_nr``. Both start at 0 and are
    updated when a source is checked.

    The input DataFrame is not modified; a new one is returned.
    """
    unused = ['redshift', 'classification', 'saved_at', 'save_mjd']
    return sources.drop(columns=unused).assign(last_checked_mjd=0,
                                                last_checked_nr=0)

def distribute_sources(sources, n, random_state=None):
    """Shuffle the sources and split them over ``n`` lists of (nearly) equal size.

    The sample is shuffled first, so each list gets a random subset of the
    sources rather than a contiguous chunk of the input order. That way each
    list more or less covers the entire sky evenly.

    When ``len(sources)`` is not a multiple of ``n``, the first
    ``len(sources) % n`` lists get one extra source. ``add_to_existing_lists``
    expects this layout (longer lists first).

    Parameters
    ----------
    sources : pandas.DataFrame
        The sources to distribute.
    n : int
        Number of lists. Must be at least 1.
    random_state : int, numpy random generator or None
        Passed to ``DataFrame.sample``. Give a value for a reproducible split.
        The default ``None`` gives a fresh random shuffle each call.

    Returns
    -------
    list of pandas.DataFrame
        ``n`` lists. Rows keep their original index labels.

    Raises
    ------
    ValueError
        If ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError(f'n must be at least 1, got {n}')
    shuffled = sources.sample(frac=1, random_state=random_state)
    listsize, left = divmod(len(shuffled), n)
    # `left` lists of size listsize+1 come first, then n-left lists of size listsize
    lists = []
    start = 0
    for i in range(n):
        stop = start + listsize + (1 if i < left else 0)
        lists.append(shuffled.iloc[start:stop])
        start = stop
    return lists

def add_to_existing_lists(oldlists, newlists):
    """Append each new list to one of the old lists, keeping the list sizes balanced.

    Both inputs are expected to have the layout produced by
    ``distribute_sources``: a run of lists holding one extra source, followed
    by the shorter ones. Adding ``newlists[i]`` straight to ``oldlists[i]``
    would make the first lists grow faster than the rest. The new lists are
    therefore rotated so that the longer new lists go onto the shorter old
    lists.

    Parameters
    ----------
    oldlists : list of pandas.DataFrame
        The lists currently in use.
    newlists : list of pandas.DataFrame
        The lists of sources to add. There must be as many as ``oldlists``.

    Returns
    -------
    list of pandas.DataFrame
        ``combined[i]`` is ``oldlists[i]`` followed by
        ``newlists[(i - step) % n]``, where ``step`` is the number of old lists
        that are longer than the shortest one. The index is reset.

    Raises
    ------
    ValueError
        If ``oldlists`` and ``newlists`` do not contain the same number of lists.
    """
    if len(oldlists) != len(newlists):
        raise ValueError(
            f'Not the same amount of lists: {len(oldlists)} old vs {len(newlists)} new')
    if not oldlists:
        return []
    oldlens = [len(old) for old in oldlists]
    shortest = min(oldlens)
    # Number of old lists holding one extra source. For the expected layout this
    # equals the index of the first shorter list (0 when all lengths are equal),
    # but it does not fail if the lengths differ by more than one.
    step = sum(length > shortest for length in oldlens)
    print(f'Offset applied to the new lists: {step}')
    n = len(newlists)
    return [pd.concat([old, newlists[(i - step) % n]], ignore_index=True)
            for i, old in enumerate(oldlists)]

def plot_list(i, j=None):
    """Show the sky positions of the sources in list ``i`` on a polar plot.

    The angle is RA (given in degrees, converted to radians) and the radius is
    Dec. The radial axis runs from +90 at the centre to -90 at the edge. The
    sources of ``i`` are drawn as blue dots. If ``j`` is given, its sources are
    overplotted as red stars.
    """
    ax = plt.subplots(subplot_kw={'projection': 'polar'})[1]
    layers = [(i, 'b', '.')]
    if j is not None:
        layers.append((j, 'r', '*'))
    for table, colour, symbol in layers:
        ax.scatter(table.ra*np.pi/180, table.dec, color=colour, marker=symbol)
    # rmin is placed at the origin, so the north pole ends up in the centre
    ax.set_rmin(90)
    ax.set_rmax(-90)
    plt.show()

### Create new lists from scratch

In [None]:
# Path to the sources csv downloaded from Fritz (fill in before running)
source_loc = Path('')
# Load it; load_sources also adds the save_mjd column used for the age cuts
sources = load_sources(source_loc)

In [None]:
# Number of lists to spread the sources over, and the minimum age (days since
# the source was saved) a source needs before it is put in a list
list_nr = 28
min_age = 100

# Age cut -> list column format -> shuffled split over list_nr lists
old_enough = restrict_age(sources, min_age=min_age)
formatted = transform_format(old_enough)
lists = distribute_sources(formatted, list_nr)

In [None]:
# Visual check that every list covers the sky (more or less) evenly
for source_list in lists:
    plot_list(source_list)

In [None]:
# Give every list a fresh 0..n-1 index (the shuffle left it scrambled) and
# write the lists to save_loc
save_loc = Path('')
lists = [source_list.reset_index(drop=True) for source_list in lists]
save_lists(lists, save_loc)

### Add sources to existing lists

In [None]:
# Read in the csv with the newly saved sources downloaded from Fritz
source_loc = Path('')
sources = load_sources(source_loc)

# Read in the lists currently used by the real-time program.
# oldlist_loc is the directory holding list_0.csv ... list_{nr_oldlists-1}.csv
# (fill it in before running); nr_oldlists must match the number of saved lists.
oldlist_loc = Path('')
nr_oldlists = 28
oldlistpaths = [oldlist_loc / f'list_{i}.csv' for i in range(nr_oldlists)]
oldlists = load_oldlists(oldlistpaths)

In [None]:
# Take the read-in sources, restrict their age, reformat, and spread them over
# the same number of lists as currently in use
list_nr = len(oldlists)
min_age = 100
# Upper age limit in days. It is not needed when the csv only holds sources saved
# after the newest source already in the lists (the update workflow described at
# the top). Otherwise set it to min_age + the number of days since the lists were
# last updated, so that sources already in the lists are not added again.
max_age = None

newlists = distribute_sources(transform_format(restrict_age(sources, min_age=min_age, max_age=max_age)), list_nr)

In [None]:
# Append each new list to one of the current lists, rotated so no list grows faster
combined_lists = add_to_existing_lists(oldlists=oldlists, newlists=newlists)

In [None]:
# Check where the new sources ended up in each list. The rows appended to list i
# are the ones after oldlists[i] in combined_lists[i]. This is not necessarily
# newlists[i], because add_to_existing_lists rotates the new lists.
for old, combined in zip(oldlists, combined_lists):
    plot_list(old, combined.iloc[len(old):])

In [None]:
# Save the combined (old + new) lists in the given location, replacing the old files
save_loc = Path('')
save_lists(combined_lists, save_loc)