Skip to content

Commit

Permalink
refactor stardata, improve date-guess
Browse files Browse the repository at this point in the history
  • Loading branch information
andrew551 committed Feb 28, 2024
1 parent 63d61f8 commit ff82446
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 61 deletions.
2 changes: 2 additions & 0 deletions MEE2024util.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ def write_ini(options):
def date_string_to_float(x):
return datetime.datetime.fromisoformat(x).toordinal()/365.24+1

def date_from_float(x):
return datetime.datetime.fromordinal(int((x - 1) * 365.24)).date().isoformat()

'''
logging setup
Expand Down
111 changes: 96 additions & 15 deletions StarData.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,116 @@
import pandas as pd
import numpy as np
import astropy.units as u
from astropy.coordinates import SkyCoord, Distance
from astropy.time import Time
from MEE2024util import date_string_to_float

'''
remove NaNs
set anything smaller than 1e-3 arcseconds (1 mas) to a constant
including negative parallax from gaia measurement error
'''
def regularize_parallax(parallax, minimum=1):
x = np.copy(parallax)
x[np.isnan(x)] = 0
x[x < minimum] = minimum
return x

def regularize_pm(pm):
x = np.copy(pm)
x[np.isnan(x)] = 0
return x

# wrapper for gaia star data
class StarData:

# source: gaia or tycho
def __init__(self, ids, data, epoch):
self.data = data
self.ids = ids
self.epoch = epoch
# r: a gaia result object
def __init__(self, r=None, date=None, has_pm=None):
if r is None:
return # make empty stardata
self.epoch = Time(date, format='jyear', scale='tcb')
print('epoch', self.epoch)
self.has_pm = has_pm
self.mags = r['phot_g_mean_mag']
self.ids = r['source_id']
self.pm = np.zeros((self.nstars(), 2))
self.parallax = np.ones(self.nstars())*1e-4
if has_pm:
self.pm[:, 0] = r['pmra']
self.pm[:, 1] = r['pmdec']
self.parallax = regularize_parallax(r['parallax'])
self.c = c = SkyCoord(ra=r['ra'],
dec=r['dec'],
distance=Distance(parallax= self.parallax * u.mas),
pm_ra_cosdec=regularize_pm(self.pm[:, 0]) * u.mas / u.yr,
pm_dec=regularize_pm(self.pm[:, 1]) * u.mas / u.yr,
obstime=self.epoch)
else:
self.c = c = SkyCoord(ra=r['ra'],
dec=r['dec'],
obstime=self.epoch)

self._update_vectors()

def nstars(self):
return self.ids.shape[0]

def get_ra(self):
return self.data[:, 0]
return self.c.ra.rad

def get_dec(self):
return self.data[:, 1]
return self.c.dec.rad

def get_ra_dec(self):
return np.c_[self.get_ra(), self.get_dec()]

def _update_vectors(self):
self.vectors = np.zeros((self.ids.shape[0], 3))
star_table = self.get_ra_dec()
self.vectors[:, 0] = np.cos(star_table[:, 0]) * np.cos(star_table[:, 1])
self.vectors[:, 1] = np.sin(star_table[:, 0]) * np.cos(star_table[:, 1])
self.vectors[:, 2] = np.sin(star_table[:, 1])

# return unit vectors for each star as np array
def get_vectors(self):
return self.data[:, 2:5]
return self.vectors

def get_mags(self):
return self.data[:, 5]
return self.mags

def get_parallax(self):
return self.data[:, 6]
return self.parallax

def get_pmotion(self):
return self.data[:, 7:9]
return self.pm

# return star ids array
def get_ids(self):
return self.ids

#def update_epoch(self, new_epoch):
# pass

def update_epoch(self, date):
if not self.has_pm:
raise Exception("cannot update epoch without pm")
self.epoch = Time(date, format='jyear', scale='tcb')
#print('updating epoch to', self.epoch)
self.c = self.c.apply_space_motion(self.epoch)
self._update_vectors()

def select_indices(self, indices):
self.data = self.data[indices, :]
self.mags = self.mags[indices]
self.vectors = self.vectors[indices, :]
self.ids = self.ids[indices]
self.c = self.c[indices]
self.pm = self.pm[indices, :]
self.parallax = self.parallax[indices]

def get_epoch_float(self):
# TODO: make less dodgy
return float(str(self.epoch))#date_string_to_float(self.epoch.TimeISO())

'''
# TODO: fix me or delete me?
def update_data(self, newdata):
my_ids = self.get_ids()
other_ids = dict(zip(newdata.get_ids(), np.arange(newdata.data.shape[0])))
Expand All @@ -49,8 +120,18 @@ def update_data(self, newdata):
for i in range(my_ids.shape[0]):
j = other_ids[my_ids[i]]
self.data[i, :] = newdata.data[j, :]
self._update_vectors()
'''

def __copy__(self):
newone = type(self)(self.ids, self.data, self.epoch)
newone = type(self)()
newone.epch = self.epoch
newone.mags = self.mags
newone.vectors = self.vectors
newone.ids = self.ids
newone.c = self.c
newone.pm = self.pm
newone.parallax = self.parallax
newone.has_pm = self.has_pm
return newone

68 changes: 61 additions & 7 deletions distortion_cubic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import matplotlib.pyplot as plt
import scipy
import datetime
from MEE2024util import date_string_to_float
from MEE2024util import date_string_to_float, date_from_float
import copy

mapping = {'linear':1, 'cubic':3, 'quintic':5, 'septic':7}

Expand All @@ -28,6 +29,18 @@ def get_coeff_names(options):
names = [name.replace('x^1', 'x').replace('y^1', 'y') for name in names]
return names

'''
performs linear regression on errors, return the rms residual error
'''

def _regression_helper(errors, basis_x, basis_y):
reg_x = LinearRegression().fit(basis_x, errors[:, 1])
reg_y = LinearRegression().fit(basis_y, errors[:, 0])
res_x = reg_x.predict(basis_x) - errors[:, 1]
res_y = reg_y.predict(basis_y) - errors[:, 0]
rms = np.mean(res_x**2+res_y**2)**0.5
return rms

'''
absorb two constant and two linear degrees of freedom in (reg_x, reg_y) into shifts in
shifts in q
Expand All @@ -48,10 +61,12 @@ def _get_corrected_q(q, reg_x, reg_y, w):
return : improved date guess, pmotion correction
'''
def _date_guess(date_guess, q, plate, stardata, img_shape, options):
target = stardata.get_vectors()
pmotion = stardata.get_pmotion()
w = (max(img_shape)/2) # 1 # for astrometrica convention
m = 1 #q[0] # for astrometrica convention
'''
target = stardata.get_vectors()
pmotion = stardata.get_pmotion()
detransformed = transforms.detransform_vectors(q, target)
Expand All @@ -65,13 +80,13 @@ def _date_guess(date_guess, q, plate, stardata, img_shape, options):
pm_pixel = np.einsum('ij, ...j-> ...j', rmatrix, pmotion)
pm_pixel[:, [0, 1]] = pm_pixel[:, [1, 0]] # swap columns of pm_pixel
# apply date_guess to correct pmotion
errors += pm_pixel * (date_string_to_float(date_guess) - stardata.epoch)
errors_p = errors + pm_pixel * (date_string_to_float(date_guess) - stardata.get_epoch_float())
basis_x = np.c_[basis, pm_pixel[:, 1]]
basis_y = np.c_[basis, pm_pixel[:, 0]]
reg_x = LinearRegression().fit(basis_x, errors[:, 1]*m)
reg_y = LinearRegression().fit(basis_y, errors[:, 0]*m)
reg_x = LinearRegression().fit(basis_x, errors_p[:, 1]*m)
reg_y = LinearRegression().fit(basis_y, errors_p[:, 0]*m)
plate_corrected = plate + np.array([reg_y.predict(basis_x), reg_x.predict(basis_y)]).T / m
#print(reg_x.coef_, reg_x.intercept_)
#print(reg_y.coef_, reg_y.intercept_)
Expand All @@ -80,7 +95,46 @@ def _date_guess(date_guess, q, plate, stardata, img_shape, options):
t_guess = (t0 + datetime.timedelta(days=-int((reg_x.coef_[-1]+ reg_y.coef_[-1])*365.25/2))).date().isoformat()
print('I guess image was taken on date:', date_guess, t_guess, int((reg_x.coef_[-1]+ reg_y.coef_[-1])*365.25/2))
pmotion_correction = pm_pixel * (date_string_to_float(t_guess) - date_string_to_float(options['observation_date']))
return t_guess, pmotion_correction
'''
# show plot of rms vs t


dtt = np.linspace(-15, 15, num=40)
rmss = []
basis = get_basis(plate[:, 0], plate[:, 1], w, m, options)
t0 = date_string_to_float(date_guess)
for dt in dtt:
stardata_copy = copy.copy(stardata)
stardata_copy.update_epoch(dt+t0)
target_t = stardata_copy.get_vectors()

detransformed = transforms.detransform_vectors(q, target_t)
errors = detransformed - plate
rms = np.degrees(_regression_helper(errors, basis, basis)*q[0])*3600
rmss.append(rms)
plt.plot(dtt+t0, rmss)
plt.ylabel('rms / arcsec')
plt.xlabel('date (years)')
if options['flag_display2']:
plt.show()
plt.close()

def rms_func(t):
stardata_copy = copy.copy(stardata)
stardata_copy.update_epoch(t)
target_t = stardata_copy.get_vectors()
detransformed = transforms.detransform_vectors(q, target_t)
errors = detransformed - plate
rms = _regression_helper(errors, basis, basis)
return rms

min_result = scipy.optimize.minimize_scalar(rms_func, bounds = (t0-50, t0+50), method='bounded')

print('min_result', min_result)

min_date = date_from_float(min_result.x)
print('min_date', min_date)
return min_date

'''
perform requested linear regression with general
Expand Down
28 changes: 14 additions & 14 deletions distortion_fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ def match_centroids(other_stars_df, result, dbs, corners, image_size, lookupdate
transformed_all = transforms.to_polar(transforms.linear_transform(result.x, all_star_plate))

# match nearest neighbours
candidate_stars = np.zeros((stardata.data.shape[0], 2))
candidate_stars[:, 0] = np.degrees(stardata.data[:, 1])
candidate_stars[:, 1] = np.degrees(stardata.data[:, 0])
candidate_stars = np.zeros((stardata.nstars(), 2))
candidate_stars[:, 0] = np.degrees(stardata.get_dec())
candidate_stars[:, 1] = np.degrees(stardata.get_ra())

# find nearest two catalogue stars to each observed star
neigh = NearestNeighbors(n_neighbors=2)
Expand Down Expand Up @@ -109,9 +109,9 @@ def match_centroids(other_stars_df, result, dbs, corners, image_size, lookupdate

plt.scatter(cata_matched[:, 1], cata_matched[:, 0], label='catalogue')
plt.scatter(obs_matched[:, 1], obs_matched[:, 0], marker='+', label='observations')
for i in range(stardata.data.shape[0]):
for i in range(stardata.nstars()):
if i in indices[keep_i, 0]:
plt.gca().annotate(str(stardata.ids[i]) + f'\nMag={stardata.data[i, 5]:.1f}', (np.degrees(stardata.data[i, 0]), np.degrees(stardata.data[i, 1])), color='black', fontsize=5)
plt.gca().annotate(str(stardata.ids[i]), (np.degrees(stardata.get_ra()[i]), np.degrees(stardata.get_dec()[i])), color='black', fontsize=5)
plt.xlabel('RA')
plt.ylabel('DEC')
plt.title('initial rough fit')
Expand Down Expand Up @@ -186,7 +186,7 @@ def match_and_fit_distortion(path_data, options, debug_folder=None):

if options['guess_date']:
dateguess = options['DEFAULT_DATE'] # initial guess
dateguess, _ = distortion_cubic._date_guess(dateguess, initial_guess, plate2, stardata, image_size, options)
dateguess = distortion_cubic._date_guess(dateguess, initial_guess, plate2, stardata, image_size, dict(options, **{'flag_display2':False}))
# re-get gaia database
stardata, plate2, alt, az = match_centroids(other_stars_df, result, dbs, corners, image_size, dateguess, dict(options, **{'flag_display2':False}))

Expand All @@ -204,10 +204,10 @@ def match_and_fit_distortion(path_data, options, debug_folder=None):
flag_is_double = np.zeros(stardata.ids.shape[0], int)
neigh_all = gaia_search.lookup_nearby(stardata, options['double_star_cutoff'], options['double_star_mag'])
neigh = NearestNeighbors(n_neighbors=2)
neigh_all_data_extra2 = np.r_[neigh_all.data[:, :2], np.array([[-99999,-99999], [-99999, -99999]])] # ensure at least 2 "pseudo-neighbours"
neigh_all_data_extra2 = np.r_[neigh_all.get_ra_dec(), np.array([[-99999,-99999], [-99999, -99999]])] # ensure at least 2 "pseudo-neighbours"

neigh.fit(neigh_all_data_extra2)
distances, indices = neigh.kneighbors(stardata.data[:, :2])
distances, indices = neigh.kneighbors(stardata.get_ra_dec())

flag_is_double = distances[:, 1] < np.radians(options['double_star_cutoff']/3600)
flag_missing_pm = np.isnan(stardata.get_pmotion()[:, 0])
Expand All @@ -227,10 +227,10 @@ def match_and_fit_distortion(path_data, options, debug_folder=None):
# do 2nd fit with outliers removed

if options['guess_date']:
dateguess, _ = distortion_cubic._date_guess(dateguess, initial_guess, plate2, stardata, image_size, options)
# re-get gaia database # TODO: it would be nice to epoch propagate offline, since we have the pmra, and pmdec
stardata_new = dbs.lookup_objects(*get_bbox(corners), star_max_magnitude=options['max_star_mag_dist'], time=date_string_to_float(dateguess))
stardata.update_data(stardata_new)
dateguess = distortion_cubic._date_guess(dateguess, initial_guess, plate2, stardata, image_size, options)
#stardata_new = dbs.lookup_objects(*get_bbox(corners), star_max_magnitude=options['max_star_mag_dist'], time=date_string_to_float(dateguess))
#stardata.update_data(stardata_new)
stardata.update_epoch(date_string_to_float(dateguess))

result, plate2_corrected, reg_x, reg_y = distortion_cubic.do_cubic_fit(plate2, stardata, initial_guess, image_size, options)
transformed_final = transforms.linear_transform(result, plate2_corrected, image_size)
Expand Down Expand Up @@ -320,8 +320,8 @@ def match_and_fit_distortion(path_data, options, debug_folder=None):
'px_dist': plate2_unfiltered_corrected[:, 1]+image_size[1]/2,
'py_dist': plate2_unfiltered_corrected[:, 0]+image_size[0]/2,
'ID': ['gaia:'+str(_) for _ in stardata_unfiltered.ids],
'RA(catalog)': np.degrees(stardata_unfiltered.data[:, 0]),
'DEC(catalog)': np.degrees(stardata_unfiltered.data[:, 1]),
'RA(catalog)': np.degrees(stardata_unfiltered.get_ra()),
'DEC(catalog)': np.degrees(stardata_unfiltered.get_dec()),
'RA(obs)': transforms.to_polar(transformed_final)[:, 1],
'DEC(obs)': transforms.to_polar(transformed_final)[:, 0],
'magV': stardata_unfiltered.get_mags(),
Expand Down
13 changes: 7 additions & 6 deletions gaia_search.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from astroquery.gaia import Gaia
#import astropy.units as u
import astropy.units as u
#from astropy.coordinates import SkyCoord
import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -45,7 +45,7 @@ def get_prop_pos(T1):

def select_in_box(T1, ra_range, dec_range, max_mag):
query = f"SELECT source_id, phot_g_mean_mag, COORD1(ESDC_EPOCH_PROP_POS(ra, dec, parallax, pmra, pmdec, radial_velocity, ref_epoch, {T1})),\
COORD2(ESDC_EPOCH_PROP_POS(ra, dec, parallax, pmra, pmdec, radial_velocity, ref_epoch, {T1})), parallax, pmra, pmdec \
COORD2(ESDC_EPOCH_PROP_POS(ra, dec, parallax, pmra, pmdec, radial_velocity, ref_epoch, {T1})), parallax, pmra, pmdec, ref_epoch \
FROM gaiadr3.gaia_source \
WHERE ra BETWEEN {ra_range[0]} AND {ra_range[1]} AND \
dec BETWEEN {dec_range[0]} AND {dec_range[1]} AND \
Expand All @@ -59,7 +59,7 @@ def select_in_box(T1, ra_range, dec_range, max_mag):
return results

def lookup_nearby(startable, distance, max_mag_neighbours):
query = f"SELECT source_id, phot_g_mean_mag, ra, dec \
query = f"SELECT source_id, phot_g_mean_mag, ra, dec, ref_epoch \
FROM gaiadr3.gaia_source \
WHERE "

Expand All @@ -85,7 +85,7 @@ def helper(ra, dec):
star_table[:, 3] = np.sin(star_table[:, 0]) * np.cos(star_table[:, 1])
star_table[:, 4] = np.sin(star_table[:, 1])
star_catID = results['source_id']
return StarData.StarData(star_catID, star_table, 2016)
return StarData.StarData(results, 2016, False)

gaia_limit=13
class dbs_gaia:
Expand All @@ -97,7 +97,8 @@ def lookup_objects(self, range_ra, range_dec, star_max_magnitude=12, time=2024):
l = len(results)

star_table = np.zeros((l, 9), dtype=float)

results['ra'] = results['COORD1'] * u.deg
results['dec'] = results['COORD2'] * u.deg
star_table[:, 0] = np.radians(results['COORD1'])
star_table[:, 1] = np.radians(results['COORD2'])
star_table[:, 5] = results['phot_g_mean_mag']
Expand All @@ -108,7 +109,7 @@ def lookup_objects(self, range_ra, range_dec, star_max_magnitude=12, time=2024):
star_table[:, 7] = results['pmra']
star_table[:, 8] = results['pmdec']
star_catID = results['source_id']
return StarData.StarData(star_catID, star_table, time)
return StarData.StarData(results, time, True)

if __name__ == '__main__':
#l = select_in_box(2024, (37.4, 37.5), (0.35, 0.45), 16)
Expand Down
Loading

0 comments on commit ff82446

Please sign in to comment.