Skip to content

Commit

Permalink
add more fits sourcelist tests
Browse files Browse the repository at this point in the history
  • Loading branch information
d3v-null committed Feb 2, 2024
1 parent c1acda6 commit 9065adf
Show file tree
Hide file tree
Showing 7 changed files with 403 additions and 23 deletions.
41 changes: 29 additions & 12 deletions src/srclist/fits/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ impl CommonCols {
hdu: &FitsHdu,
col_names: &[String],
) -> Result<Self, FitsError> {
macro_rules! read_col {
macro_rules! read_optional_col {
($possible_col_names: expr) => {{
let mut maybe_col = None;
for possible_col_name in $possible_col_names {
Expand All @@ -116,10 +116,18 @@ impl CommonCols {
maybe_col = Some(fe!(file, hdu.read_col(fptr, possible_col_name)));
}
}
maybe_col.unwrap_or_else(|| {
if !maybe_col.is_some() {
debug!("None of {:?} were available columns!", $possible_col_names)
}
maybe_col
}}
}
macro_rules! read_mandatory_col {
($possible_col_names: expr) => {{
read_optional_col!($possible_col_names).unwrap_or_else(|| {
panic!("None of {:?} were available columns!", $possible_col_names)
})
}};
}}
}

let unq_source_id = if col_names.iter().any(|col_name| col_name == "UNQ_SOURCE_ID") {
Expand All @@ -128,12 +136,21 @@ impl CommonCols {
vec![]
};

let names = read_col!(["NAME", "Name"]);
let ra_degrees = read_col!(["RA", "RAJ2000"]);
let dec_degrees = read_col!(["DEC", "DEJ2000"]);
let majors = read_col!(["MAJOR_DC", "a"]);
let minors = read_col!(["MINOR_DC", "b"]);
let pas = read_col!(["PA_DC", "pa"]);
let names = read_mandatory_col!(["NAME", "Name"]);
let ra_degrees = read_mandatory_col!(["RA", "RAJ2000"]);
let dec_degrees = read_mandatory_col!(["DEC", "DEJ2000"]);
let majors = read_mandatory_col!(["MAJOR_DC", "a"]);
let minors = read_mandatory_col!(["MINOR_DC", "b"]);
let pas = read_mandatory_col!(["PA_DC", "pa"]);
// let majors = read_optional_col!(["MAJOR_DC", "a"]).unwrap_or_else(|| {
// vec![0.0; names.len()]
// });
// let minors = read_optional_col!(["MINOR_DC", "b"]).unwrap_or_else(|| {
// vec![0.0; names.len()]
// });
// let pas = read_optional_col!(["PA_DC", "pa"]).unwrap_or_else(|| {
// vec![0.0; names.len()]
// });

// Get any shapelet info ready. We assume that the info lives in HDU 3
// (index 2 in sane languages), and if there's an error, we assume it's
Expand Down Expand Up @@ -230,8 +247,8 @@ impl CommonCols {
comp_types
};

let power_law_stokes_is: Vec<f64> = read_col!(["NORM_COMP_PL", "S_200"]);
let power_law_alphas = read_col!(["ALPHA_PL", "alpha"]);
let power_law_stokes_is: Vec<f64> = read_mandatory_col!(["NORM_COMP_PL", "S_200"]);
let power_law_alphas = read_mandatory_col!(["ALPHA_PL", "alpha"]);

let (curved_power_law_stokes_is, curved_power_law_alphas, curved_power_law_qs): (
Vec<f64>,
Expand All @@ -241,7 +258,7 @@ impl CommonCols {
(
fe!(file, hdu.read_col(fptr, "NORM_COMP_CPL")),
fe!(file, hdu.read_col(fptr, "ALPHA_CPL")),
fe!(file, hdu.read_col(fptr, "CURVE_CPL")),
fe!(file, hdu.read_col(fptr, "CURVE_CPL")), // todo: gsm beta for cpl?
)
} else {
(vec![], vec![], vec![])
Expand Down
140 changes: 129 additions & 11 deletions src/srclist/general_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
use std::{
fs::File,
io::{BufReader, Cursor, Read},
path::PathBuf
};

use approx::assert_abs_diff_eq;
use approx::{abs_diff_eq, assert_abs_diff_eq};
use marlu::RADec;
use tempfile::NamedTempFile;
use vec1::vec1;

use super::*;
use super::{*, fits::parse_source_list};
use crate::constants::DEFAULT_SPEC_INDEX;

fn test_two_sources_lists_are_the_same(sl1: &SourceList, sl2: &SourceList) {
Expand Down Expand Up @@ -107,10 +108,10 @@ fn test_two_sources_lists_are_the_same(sl1: &SourceList, sl2: &SourceList) {
}

FluxDensityType::PowerLaw { .. } => {
assert!(matches!(
s2_comp.flux_type,
FluxDensityType::PowerLaw { .. }
));
assert!(
matches!( s2_comp.flux_type, FluxDensityType::PowerLaw { .. } ),
"{sl1_name}: fdtype mismatch {s1_comp:?} {s2_comp:?}"
);
match s2_comp.flux_type {
FluxDensityType::PowerLaw { .. } => {
// The parameters of the power law may not
Expand All @@ -119,7 +120,7 @@ fn test_two_sources_lists_are_the_same(sl1: &SourceList, sl2: &SourceList) {
let s1_fd = s1_comp.flux_type.estimate_at_freq(150e6);
let s2_fd = s2_comp.flux_type.estimate_at_freq(150e6);
assert_abs_diff_eq!(s1_fd.freq, s2_fd.freq, epsilon = 1e-10);
assert_abs_diff_eq!(s1_fd.i, s2_fd.i, epsilon = 1e-10);
assert!(abs_diff_eq!(s1_fd.i, s2_fd.i, epsilon = 1e-10), "{sl1_name}: i flux mismatch {s1_comp:?} != {s2_comp:?}");
assert_abs_diff_eq!(s1_fd.q, s2_fd.q, epsilon = 1e-10);
assert_abs_diff_eq!(s1_fd.u, s2_fd.u, epsilon = 1e-10);
assert_abs_diff_eq!(s1_fd.v, s2_fd.v, epsilon = 1e-10);
Expand All @@ -136,10 +137,10 @@ fn test_two_sources_lists_are_the_same(sl1: &SourceList, sl2: &SourceList) {
}

FluxDensityType::CurvedPowerLaw { .. } => {
assert!(matches!(
s2_comp.flux_type,
FluxDensityType::PowerLaw { .. }
));
assert!(
matches!( s2_comp.flux_type, FluxDensityType::CurvedPowerLaw { .. } ),
"{sl1_name}: fdtype mismatch {s1_comp:?} {s2_comp:?}"
);
match s2_comp.flux_type {
FluxDensityType::CurvedPowerLaw { .. } => {
// The parameters of the curved power law may
Expand Down Expand Up @@ -873,3 +874,120 @@ fn read_invalid_json_file() {
));
}
}

fn get_fits_expected_srclist(ref_freq: f64, include_list: bool, include_cpl: bool) -> SourceList {
let mut expected_srclist = SourceList::new();
let cmp_type_gaussian = ComponentType::Gaussian {
maj: 20.0_f64.to_radians(),
min: 10.0_f64.to_radians(),
pa: 75.0_f64.to_radians(),
};
let flux_type_list = FluxDensityType::List(vec1![
FluxDensity { freq: ref_freq, i: 1.0, q: 0.0, u: 0.0, v: 0.0, },
]);
let flux_type_pl = FluxDensityType::PowerLaw {
si: -0.8,
fd: FluxDensity { freq: ref_freq, i: 2.0, q: 0.0, u: 0.0, v: 0.0, },
};
let flux_type_cpl = FluxDensityType::CurvedPowerLaw {
si: -0.9,
fd: FluxDensity { freq: ref_freq, i: 3.0, q: 0.0, u: 0.0, v: 0.0, },
q: 0.2,
};
if include_list {
expected_srclist.insert(
"point-list".into(),
Source { components: vec![SourceComponent {
radec: RADec::from_degrees(0.0, 1.0),
comp_type: ComponentType::Point,
flux_type: flux_type_list.clone()
}].into()}
);
}
expected_srclist.insert(
"point-pl".into(),
Source { components: vec![SourceComponent {
radec: RADec::from_degrees(1.0, 2.0),
comp_type: ComponentType::Point,
flux_type: flux_type_pl.clone(),
}].into()}
);
if include_cpl {
expected_srclist.insert(
"point-cpl".into(),
Source { components: vec![SourceComponent {
radec: RADec::from_degrees(3.0, 4.0),
comp_type: ComponentType::Point,
flux_type: flux_type_cpl.clone(),
}].into()}
);
}
if include_list {
expected_srclist.insert(
"gauss-list".into(),
Source { components: vec![SourceComponent {
radec: RADec::from_degrees(0.0, 1.0),
comp_type: cmp_type_gaussian.clone(),
flux_type: flux_type_list
}].into()}
);
}
expected_srclist.insert(
"gauss-pl".into(),
Source { components: vec![SourceComponent {
radec: RADec::from_degrees(1.0, 2.0),
comp_type: cmp_type_gaussian.clone(),
flux_type: flux_type_pl,
}].into()}
);
if include_cpl {
expected_srclist.insert(
"gauss-cpl".into(),
Source { components: vec![SourceComponent {
radec: RADec::from_degrees(3.0, 4.0),
comp_type: cmp_type_gaussian,
flux_type: flux_type_cpl,
}].into()}
);
}
expected_srclist
}

#[test]
fn test_parse_gleam_fits() {
// TODO(Dev): CPL and List

// python -c 'from astropy.io import fits; import sys; from tabulate import tabulate; print(tabulate((i:=fits.open(sys.argv[-1])[1]).data, headers=[c.name for c in i.columns]))' /home/dev/src/hyperdrive_main/test_files/gleam.fits
// Name RAJ2000 DEJ2000 S_200 alpha beta a b pa
// ---------- --------- --------- ------- ------- ------ --- --- ----
// point-pl 1 2 1 -0.8 0 0 0 0
// gauss-pl 1 2 1 -0.8 0 20 10 75


let res_srclist = parse_source_list(&PathBuf::from("test_files/gleam.fits")).unwrap();
let expected_srclist = get_fits_expected_srclist(200e6, false, false);
dbg!(&res_srclist, &expected_srclist);
test_two_sources_lists_are_the_same(&res_srclist, &expected_srclist);
}

#[test]
fn test_parse_jack_fits() {
// TODO(Dev): List

// python -c 'from astropy.io import fits; import sys; from tabulate import tabulate; print(tabulate((i:=fits.open(sys.argv[-1])[1]).data, headers=[c.name for c in i.columns]))' /home/dev/src/hyperdrive_main/test_files/jack.fits
// UNQ_SOURCE_ID NAME RA DEC INT_FLX150 MAJOR_DC MINOR_DC PA_DC MOD_TYPE NORM_COMP_PL ALPHA_PL NORM_COMP_CPL ALPHA_CPL CURVE_CPL
// --------------- ------------- ---- ----- ------------ ---------- ---------- ------- ---------- -------------- ---------- --------------- ----------- -----------
// point-pl point-pl_C0 1 2 1 0 0 0 pl 1 -0.8 0 0 0
// point-cpl point-cpl_C0 3 4 1 0 0 0 cpl 0 0 1 -0.8 0.2
// gauss-pl gauss-pl_C0 1 2 1 20 10 75 pl 1 -0.8 0 0 0
// gauss-cpl gauss-cpl_C0 3 4 1 20 10 75 cpl 0 0 1 -0.8 0.2

// setup logging
// use crate::cli::setup_logging;
// setup_logging(3).expect("Failed to setup logging");

let res_srclist = parse_source_list(&PathBuf::from("test_files/jack.fits")).unwrap();
let expected_srclist = get_fits_expected_srclist(200e6, false, true);
dbg!(&res_srclist, &expected_srclist);
test_two_sources_lists_are_the_same(&res_srclist, &expected_srclist);
}
102 changes: 102 additions & 0 deletions test_files/gen_fits_srclists.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
#!/usr/bin/env python

import pandas as pd
from astropy.coordinates import Angle
import astropy.units as u
from astropy.io import fits
from astropy.table import Table
import numpy as np
import re
import io

arrays_jack = {
"UNQ_SOURCE_ID": [],
"NAME": [],
"RA": [],
"DEC": [],
"INT_FLX150": [],
"MAJOR_DC": [],
"MINOR_DC": [],
"PA_DC": [],
"MOD_TYPE": [],
"NORM_COMP_PL": [],
"ALPHA_PL": [],
"NORM_COMP_CPL": [],
"ALPHA_CPL": [],
"CURVE_CPL": [],
}
arrays_gleam = {
"Name": [],
"RAJ2000": [],
"DEJ2000": [],
"S_200": [],
"alpha": [],
"beta": [],
"a": [],
"b": [],
"pa": [],
}
for (i,( name, cmp, ra, dec, fd, maj, min, pa, sn1, sn2, sv, typ, alpha, q, )) in enumerate([
["point-list", 0, 0., 1., 1., 0., 0., 0., 0., 0., 0.0, "nan", 0.0, 0.0 ],
["point-pl", 0, 1., 2., 2., 0., 0., 0., 0., 0., 0.0, "pl", -0.8, 0.0 ],
["point-cpl", 0, 3., 4., 3., 0., 0., 0., 0., 0., 0.0, "cpl", -0.9, 0.2 ],
["gauss-list", 0, 0., 1., 1., 20., 10., 75., 0., 0., 0.0, "nan", 0.0, 0.0 ],
["gauss-pl", 0, 1., 2., 2., 20., 10., 75., 0., 0., 0.0, "pl", -0.8, 0.0 ],
["gauss-cpl", 0, 3., 4., 3., 20., 10., 75., 0., 0., 0.0, "cpl", -0.9, 0.2 ],
# ["shape-list", 0, 0., 1., 1., 20., 10., 75., 0., 1., 0.5, "nan", 0.0, 0.0 ],
# ["shape-pl", 0, 1., 2., 1., 20., 10., 75., 0., 1., 0.5, "pl", -0.8, 0.0 ],
# ["shape-cpl", 0, 3., 4., 1., 20., 10., 75., 0., 1., 0.5, "cpl", -0.8, 0.2 ],
# todo: shapelets
]):
if typ == "nan":
continue

# i_ = f"{i:04d}"
arrays_jack["UNQ_SOURCE_ID"].append(f"{name}")
arrays_jack["NAME"].append(f"{name}_C{cmp}")
arrays_jack["RA"].append(ra)
arrays_jack["DEC"].append(dec)
arrays_jack["INT_FLX150"].append(fd)
arrays_jack["MAJOR_DC"].append(maj)
arrays_jack["MINOR_DC"].append(min)
arrays_jack["PA_DC"].append(pa)
arrays_jack["MOD_TYPE"].append(typ)
# arrays_jack["NORM_COMP_PL"].append(fd)
# arrays_jack["ALPHA_PL"].append(alpha)

if typ == "cpl":
arrays_jack["NORM_COMP_PL"].append(0.0)
arrays_jack["ALPHA_PL"].append(0.0)
arrays_jack["NORM_COMP_CPL"].append(fd)
arrays_jack["ALPHA_CPL"].append(alpha)
arrays_jack["CURVE_CPL"].append(q)
else:
arrays_jack["NORM_COMP_PL"].append(fd)
arrays_jack["ALPHA_PL"].append(alpha)
arrays_jack["NORM_COMP_CPL"].append(0.0)
arrays_jack["ALPHA_CPL"].append(0.0)
arrays_jack["CURVE_CPL"].append(0.0)

if typ == "cpl":
continue

arrays_gleam['Name'].append(f"{name}")
arrays_gleam['RAJ2000'].append(ra)
arrays_gleam['DEJ2000'].append(dec)
arrays_gleam['S_200'].append(fd)
arrays_gleam['alpha'].append(alpha)
arrays_gleam['beta'].append(q)
arrays_gleam['a'].append(maj * 3600)
arrays_gleam['b'].append(min * 3600)
arrays_gleam['pa'].append(pa)


table = Table(arrays_jack, )
table.write('test_files/jack.fits', overwrite=True)
df = table.to_pandas()
print("jack\n", df[[df.columns[0], *df.columns[2:]]].to_string(index=False))

table = Table(arrays_gleam, )
table.write('test_files/gleam.fits', overwrite=True)
df = table.to_pandas()
print("gleam\n", df.to_string(index=False))
Binary file added test_files/gleam.fits
Binary file not shown.
Loading

0 comments on commit 9065adf

Please sign in to comment.