From e7be574c6dfc041dddd20d50db0b000dcbd0aaf4 Mon Sep 17 00:00:00 2001 From: "Christopher H. Jordan" Date: Fri, 11 Nov 2022 11:17:53 +0800 Subject: [PATCH] WIP. Peeling! WIP: Write thorough tests that intended visibilities make their way to the peel functions. Currently, averaging to 1.28 MHz isn't happening WIP: Add UVW cutoffs WIP: Fix error handling in src/cli/error.rs This is a squashed commit of months of work by CHJ and Dev. Big props to Dev for their help. --- .github/workflows/bench.yml | 44 + src/averaging/mod.rs | 95 +- src/cli/di_calibrate/tests.rs | 28 +- src/cli/error.rs | 50 +- src/cli/mod.rs | 9 + src/cli/peel/error.rs | 72 + src/cli/peel/mod.rs | 651 +++++++ src/di_calibrate/mod.rs | 2 +- src/gpu/common.cuh | 235 ++- src/gpu/mod.rs | 5 +- src/gpu/peel.cu | 681 +++++++ src/gpu/peel.h | 38 + src/gpu/peel_double.rs | 78 + src/gpu/peel_single.rs | 78 + src/gpu/types.h | 33 + src/gpu/types_double.rs | 75 + src/gpu/types_single.rs | 75 + src/gpu/update_rust_bindings.sh | 12 + src/io/write/mod.rs | 1 + src/model/cpu.rs | 49 +- src/model/gpu.rs | 427 +++-- src/model/mod.rs | 19 +- src/params/di_calibration.rs | 106 +- src/params/mod.rs | 4 +- src/params/peel/mod.rs | 2336 +++++++++++++++++++++++ src/params/peel/tests.rs | 2459 +++++++++++++++++++++++++ src/srclist/types/components/mod.rs | 21 +- src/srclist/types/components/tests.rs | 6 +- src/unit_parsing/mod.rs | 4 + 29 files changed, 7404 insertions(+), 289 deletions(-) create mode 100644 .github/workflows/bench.yml create mode 100644 src/cli/peel/error.rs create mode 100644 src/cli/peel/mod.rs create mode 100644 src/gpu/peel.cu create mode 100644 src/gpu/peel.h create mode 100644 src/gpu/peel_double.rs create mode 100644 src/gpu/peel_single.rs create mode 100644 src/params/peel/mod.rs create mode 100644 src/params/peel/tests.rs diff --git a/.github/workflows/bench.yml b/.github/workflows/bench.yml new file mode 100644 index 00000000..e5c2a695 --- /dev/null +++ b/.github/workflows/bench.yml @@ -0,0 +1,44 
@@ +--- +name: Benchmarks + +on: + push: + pull_request: + branches: + - "**" + +env: + CARGO_TERM_COLOR: always + CARGO_INCREMENTAL: 0 + MWA_BEAM_FILE: /usr/local/mwa_full_embedded_element_pattern.h5 + +jobs: + test: + name: Benchmarks + runs-on: self-hosted + steps: + - name: Checkout sources + uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Install stable toolchain + uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: stable + override: true + + - name: Cargo Bench + run: | + cargo bench + env: + HYPERDRIVE_TEST_DIR: /home/runner/data + + - name: Zip benchmark results + run: zip -r criterion.zip target/criterion/* + + - uses: actions/upload-artifact@v2 + with: + name: criterion.zip + path: criterion.zip diff --git a/src/averaging/mod.rs b/src/averaging/mod.rs index b93e27e0..5dadcec4 100644 --- a/src/averaging/mod.rs +++ b/src/averaging/mod.rs @@ -478,7 +478,8 @@ pub(super) fn parse_freq_average_factor( // Scale the quantity by the unit, if required. let quantity = match time_format { FreqFormat::Hz => quantity, - FreqFormat::kHz => 1000.0 * quantity, + FreqFormat::kHz => 1e3 * quantity, + FreqFormat::MHz => 1e6 * quantity, }; let factor = quantity / freq_res; // Reject non-integer floats. @@ -615,3 +616,95 @@ fn vis_average_weights_non_zero( } *weight_to = weight_sum as f32; } + +/// This function is the same as `vis_average`, except it assumes that any +/// flagged visibilities are indicated by weights equal to 0.0 (or -0.0). This +/// allows the code to be more efficient. 
+pub(crate) fn vis_average_no_negative_weights( + jones_from_tfb: ArrayView3>, + mut jones_to_fb: ArrayViewMut2>, + weight_from_tfb: ArrayView3, + mut weight_to_fb: ArrayViewMut2, + flagged_chan_indices: &HashSet, +) { + let avg_time = jones_from_tfb.len_of(Axis(0)); + let avg_freq = (jones_from_tfb.len_of(Axis(1)) as f64 / jones_to_fb.len_of(Axis(0)) as f64) + .ceil() as usize; + + // iterate along time axis in chunks of avg_time + jones_from_tfb + .axis_chunks_iter(Axis(0), avg_time) + .zip(weight_from_tfb.axis_chunks_iter(Axis(0), avg_time)) + .for_each(|(jones_chunk_tfb, weight_chunk_tfb)| { + jones_chunk_tfb + .axis_iter(Axis(2)) + .zip(weight_chunk_tfb.axis_iter(Axis(2))) + .zip(jones_to_fb.axis_iter_mut(Axis(1))) + .zip(weight_to_fb.axis_iter_mut(Axis(1))) + .for_each( + |(((jones_chunk_tf, weight_chunk_tf), mut jones_to_f), mut weight_to_f)| { + jones_chunk_tf + .axis_chunks_iter(Axis(1), avg_freq) + .zip(weight_chunk_tf.axis_chunks_iter(Axis(1), avg_freq)) + .enumerate() + .filter(|(i, _)| !flagged_chan_indices.contains(&(*i as u16))) + .map(|(_, d)| d) + .zip(jones_to_f.iter_mut()) + .zip(weight_to_f.iter_mut()) + .for_each( + |(((jones_chunk_tf, weight_chunk_tf), jones_to), weight_to)| { + vis_average_weights_are_zero( + jones_chunk_tf, + weight_chunk_tf, + jones_to, + weight_to, + ); + }, + ); + }, + ); + }); +} + +/// Average a chunk of visibilities and weights (both must have the same +/// dimensions) into an output vis and weight. This function needs the input +/// weights to be 0 or greater; this allows the averaging algorithm to be +/// simpler (as well as hopefully being faster), while also allowing further +/// averaging of the output visibilities without complicated logic. 
+#[inline] +fn vis_average_weights_are_zero( + jones_chunk_tf: ArrayView2>, + weight_chunk_tf: ArrayView2, + jones_to: &mut Jones, + weight_to: &mut f32, +) { + let mut jones_weighted_sum = Jones::default(); + let mut weight_sum = 0.0; + + // iterate through time chunks + jones_chunk_tf + .outer_iter() + .zip_eq(weight_chunk_tf.outer_iter()) + .for_each(|(jones_chunk_f, weights_chunk_f)| { + jones_chunk_f + .iter() + .zip_eq(weights_chunk_f.iter()) + .for_each(|(jones, weight)| { + // Any flagged visibilities would have a weight <= 0, but + // we've already capped them to 0. This means we don't need + // to check the value of the weight when accumulating + // unflagged visibilities; the flagged ones contribute + // nothing. + + let jones = Jones::::from(*jones); + let weight = *weight as f64; + jones_weighted_sum += jones * weight; + weight_sum += weight; + }); + }); + + if weight_sum > 0.0 { + *jones_to = Jones::from(jones_weighted_sum / weight_sum); + *weight_to = weight_sum as f32; + } +} diff --git a/src/cli/di_calibrate/tests.rs b/src/cli/di_calibrate/tests.rs index 825359f0..1fe90ce7 100644 --- a/src/cli/di_calibrate/tests.rs +++ b/src/cli/di_calibrate/tests.rs @@ -27,7 +27,6 @@ use crate::{ MsReader, VisRead, }, math::TileBaselineFlags, - params::CalVis, tests::{ get_reduced_1090008640_ms, get_reduced_1090008640_raw, get_reduced_1090008640_uvfits, DataAsStrings, @@ -918,14 +917,11 @@ fn test_1090008640_calibration_quality_raw() { }; let params = args.parse().unwrap(); - let CalVis { - vis_data, - vis_model, - .. 
- } = params + let mut cal_vis = params .get_cal_vis() .expect("Couldn't read data and generate a model"); - test_1090008640_quality(params, vis_data.view(), vis_model.view()); + cal_vis.scale_by_weights(Some(¶ms.baseline_weights)); + test_1090008640_quality(params, cal_vis.vis_data.view(), cal_vis.vis_model.view()); } #[test] @@ -962,14 +958,11 @@ fn test_1090008640_calibration_quality_ms() { }; let params = args.parse().unwrap(); - let CalVis { - vis_data, - vis_model, - .. - } = params + let mut cal_vis = params .get_cal_vis() .expect("Couldn't read data and generate a model"); - test_1090008640_quality(params, vis_data.view(), vis_model.view()); + cal_vis.scale_by_weights(Some(¶ms.baseline_weights)); + test_1090008640_quality(params, cal_vis.vis_data.view(), cal_vis.vis_model.view()); } #[test] @@ -1005,12 +998,9 @@ fn test_1090008640_calibration_quality_uvfits() { }; let params = args.parse().unwrap(); - let CalVis { - vis_data, - vis_model, - .. - } = params + let mut cal_vis = params .get_cal_vis() .expect("Couldn't read data and generate a model"); - test_1090008640_quality(params, vis_data.view(), vis_model.view()); + cal_vis.scale_by_weights(Some(¶ms.baseline_weights)); + test_1090008640_quality(params, cal_vis.vis_data.view(), cal_vis.vis_model.view()); } diff --git a/src/cli/error.rs b/src/cli/error.rs index 907f7830..92799597 100644 --- a/src/cli/error.rs +++ b/src/cli/error.rs @@ -10,6 +10,7 @@ use thiserror::Error; use super::{ common::InputVisArgsError, di_calibrate::DiCalArgsError, + peel::PeelArgsError, solutions::{SolutionsApplyArgsError, SolutionsPlotError}, srclist::SrclistByBeamError, vis_convert::VisConvertArgsError, @@ -24,7 +25,7 @@ use crate::{ GlobError, }, model::ModelError, - params::{DiCalibrateError, VisConvertError, VisSimulateError, VisSubtractError}, + params::{DiCalibrateError, PeelError, VisConvertError, VisSimulateError, VisSubtractError}, solutions::{SolutionsReadError, SolutionsWriteError}, srclist::{ReadSourceListError, 
SrclistError, WriteSourceListError}, }; @@ -39,6 +40,10 @@ pub enum HyperdriveError { #[error("{0}\n\nSee for more info: {URL}/user/di_cal/intro.html")] DiCalibrate(String), + /// An error related to peeling. + #[error("{0}\n\nSee for more info: {URL}/*****.html")] + Peel(String), + /// An error related to solutions-apply. #[error("{0}\n\nSee for more info: {URL}/user/solutions_apply/intro.html")] SolutionsApply(String), @@ -164,6 +169,49 @@ impl From for HyperdriveError { } } +impl From for HyperdriveError { + fn from(e: PeelArgsError) -> Self { + match e { + PeelArgsError::NoOutput + | PeelArgsError::NoChannels + | PeelArgsError::ZeroPasses + | PeelArgsError::ParseIonoTimeAverageFactor(_) + | PeelArgsError::ParseIonoFreqAverageFactor(_) + | PeelArgsError::IonoTimeFactorNotInteger + | PeelArgsError::IonoFreqFactorNotInteger + | PeelArgsError::IonoTimeResNotMultiple { .. } + | PeelArgsError::IonoFreqResNotMultiple { .. } + | PeelArgsError::IonoTimeFactorZero + | PeelArgsError::IonoFreqFactorZero + | PeelArgsError::ParseUvwMin(_) + | PeelArgsError::ParseUvwMax(_) => Self::Generic(e.to_string()), + PeelArgsError::Glob(e) => Self::from(e), + PeelArgsError::VisRead(e) => Self::from(e), + PeelArgsError::FileWrite(e) => Self::from(e), + PeelArgsError::SourceList(e) => Self::from(e), + PeelArgsError::Beam(e) => Self::from(e), + PeelArgsError::Model(e) => Self::from(e), + PeelArgsError::IO(e) => Self::from(e), + #[cfg(any(feature = "cuda", feature = "hip"))] + PeelArgsError::Gpu(e) => Self::from(e), + } + } +} + +impl From for HyperdriveError { + fn from(e: PeelError) -> Self { + match e { + PeelError::VisRead(e) => Self::from(e), + PeelError::FileWrite(e) => Self::from(e), + PeelError::Beam(e) => Self::from(e), + PeelError::Model(e) => Self::from(e), + PeelError::IO(e) => Self::from(e), + #[cfg(any(feature = "cuda", feature = "hip"))] + PeelError::Gpu(e) => Self::from(e), + } + } +} + impl From for HyperdriveError { fn from(e: SolutionsApplyArgsError) -> Self { let s = 
e.to_string(); diff --git a/src/cli/mod.rs b/src/cli/mod.rs index 4d720e3d..1e0e84e4 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -18,6 +18,7 @@ mod beam; mod di_calibrate; mod dipole_gains; mod error; +mod peel; mod solutions; mod srclist; mod vis_convert; @@ -94,6 +95,9 @@ https://mwatelescope.github.io/mwa_hyperdrive/user/di_cal/intro.html"# )] DiCalibrate(di_calibrate::DiCalArgs), + #[clap(about = r#"Peeling!"#)] + Peel(peel::PeelArgs), + #[clap(alias = "convert-vis")] #[clap(about = r#"Convert visibilities from one type to another. https://mwatelescope.github.io/mwa_hyperdrive/user/vis_convert/intro.html"#)] @@ -171,6 +175,7 @@ impl Hyperdrive { // Print the version of hyperdrive and its build-time information. let sub_command = match &self.command { Command::DiCalibrate(_) => "di-calibrate", + Command::Peel(_) => "peel", Command::VisConvert(_) => "vis-convert", Command::VisSimulate(_) => "vis-simulate", Command::VisSubtract(_) => "vis-subtract", @@ -209,6 +214,10 @@ impl Hyperdrive { merge_save_run!(args) } + Command::Peel(args) => { + merge_save_run!(args) + } + Command::VisConvert(args) => { merge_save_run!(args) } diff --git a/src/cli/peel/error.rs b/src/cli/peel/error.rs new file mode 100644 index 00000000..2cdbe769 --- /dev/null +++ b/src/cli/peel/error.rs @@ -0,0 +1,72 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#[derive(thiserror::Error, Debug)] +pub(crate) enum PeelArgsError { + #[error("No calibration output was specified. 
There must be at least one calibration solution file.")] + NoOutput, + + #[error("The data either contains no frequency channels or all channels are flagged")] + NoChannels, + + #[error("The number of iono sub passes cannot be 0")] + ZeroPasses, + + #[error("Error when parsing iono time average factor: {0}")] + ParseIonoTimeAverageFactor(crate::unit_parsing::UnitParseError), + + #[error("Error when parsing iono freq. average factor: {0}")] + ParseIonoFreqAverageFactor(crate::unit_parsing::UnitParseError), + + #[error("Iono time average factor isn't an integer")] + IonoTimeFactorNotInteger, + + #[error("Iono freq. average factor isn't an integer")] + IonoFreqFactorNotInteger, + + #[error( + "Iono time resolution isn't a multiple of input data's: {out} seconds vs {inp} seconds" + )] + IonoTimeResNotMultiple { out: f64, inp: f64 }, + + #[error("Iono freq. resolution isn't a multiple of input data's: {out} Hz vs {inp} Hz")] + IonoFreqResNotMultiple { out: f64, inp: f64 }, + + #[error("Iono time average factor cannot be 0")] + IonoTimeFactorZero, + + #[error("Iono freq. 
average factor cannot be 0")] + IonoFreqFactorZero, + + #[error("Error when parsing minimum UVW cutoff: {0}")] + ParseUvwMin(crate::unit_parsing::UnitParseError), + + #[error("Error when parsing maximum UVW cutoff: {0}")] + ParseUvwMax(crate::unit_parsing::UnitParseError), + + #[error(transparent)] + Glob(#[from] crate::io::GlobError), + + #[error(transparent)] + VisRead(#[from] crate::io::read::VisReadError), + + #[error(transparent)] + FileWrite(#[from] crate::io::write::FileWriteError), + + #[error("Error when trying to read source list: {0}")] + SourceList(#[from] crate::srclist::ReadSourceListError), + + #[error(transparent)] + Beam(#[from] crate::beam::BeamError), + + #[error(transparent)] + Model(#[from] crate::model::ModelError), + + #[error(transparent)] + IO(#[from] std::io::Error), + + #[cfg(any(feature = "cuda", feature = "hip"))] + #[error(transparent)] + Gpu(#[from] crate::gpu::GpuError), +} diff --git a/src/cli/peel/mod.rs b/src/cli/peel/mod.rs new file mode 100644 index 00000000..2da81bcb --- /dev/null +++ b/src/cli/peel/mod.rs @@ -0,0 +1,651 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. 
+ +mod error; +pub(crate) use error::PeelArgsError; + +use std::{collections::HashSet, num::NonZeroUsize, path::PathBuf, str::FromStr}; + +use clap::Parser; +use log::{debug, info, trace}; +use marlu::{precession::precess_time, LatLngHeight}; +use serde::{Deserialize, Serialize}; +use vec1::Vec1; + +use super::common::{BeamArgs, InputVisArgs, ModellingArgs, SkyModelWithVetoArgs, ARG_FILE_HELP}; +use crate::{ + averaging::{ + channels_to_chanblocks, parse_freq_average_factor, parse_time_average_factor, + timesteps_to_timeblocks, AverageFactorError, + }, + cli::{ + common::{display_warnings, OutputVisArgs}, + Warn, + }, + io::write::{VisOutputType, VIS_OUTPUT_EXTENSIONS}, + params::{ModellingParams, PeelParams}, + unit_parsing::WAVELENGTH_FORMATS, + HyperdriveError, +}; + +const DEFAULT_OUTPUT_PEEL_FILENAME: &str = "hyperdrive_peeled.uvfits"; +const DEFAULT_OUTPUT_IONO_CONSTS: &str = "hyperdrive_iono_consts.json"; +#[cfg(not(any(feature = "cuda", feature = "hip")))] +const DEFAULT_NUM_PASSES: usize = 1; +#[cfg(any(feature = "cuda", feature = "hip"))] +const DEFAULT_NUM_PASSES: usize = 3; +const DEFAULT_TIME_AVERAGE_FACTOR: &str = "8s"; +const DEFAULT_FREQ_AVERAGE_FACTOR: &str = "80kHz"; +const DEFAULT_IONO_FREQ_AVERAGE_FACTOR: &str = "1.28MHz"; +const DEFAULT_OUTPUT_TIME_AVERAGE_FACTOR: &str = "8s"; +const DEFAULT_OUTPUT_FREQ_AVERAGE_FACTOR: &str = "80kHz"; +const DEFAULT_UVW_MIN: &str = "0λ"; + +lazy_static::lazy_static! { + static ref VIS_OUTPUTS_HELP: String = format!("The paths to the files where the peeled visibilities are written. Supported formats: {}", *VIS_OUTPUT_EXTENSIONS); + + static ref NUM_PASSES_HELP: String = format!("The number of times to iterate over all sources per iono timeblock. Default: {DEFAULT_NUM_PASSES}"); + + static ref TIME_AVERAGE_FACTOR_HELP: String = format!("The number of timesteps to use per timeblock *during* peeling. Also supports a target time resolution (e.g. 8s). If this is 0, then all data are averaged together. 
Default: {DEFAULT_TIME_AVERAGE_FACTOR}. e.g. If this variable is 4, then peeling is performed with 4 timesteps per timeblock. If the variable is instead 4s, then each timeblock contains up to 4s worth of data."); + + static ref FREQ_AVERAGE_FACTOR_HELP: String = format!("The number of fine-frequency channels to average together *before* peeling. Also supports a target time resolution (e.g. 80kHz). If this is 0, then all data is averaged together. Default: {DEFAULT_FREQ_AVERAGE_FACTOR}. e.g. If the input data is in 20kHz resolution and this variable was 2, then we average 40kHz worth of data into a chanblock before peeling. If the variable is instead 40kHz, then each chanblock contains up to 40kHz worth of data."); + + static ref IONO_FREQ_AVERAGE_FACTOR_HELP: String = format!("The number of fine-frequency channels to average together *during* peeling. Also supports a target time resolution (e.g. 1.28MHz). Cannot be 0. Default: {DEFAULT_IONO_FREQ_AVERAGE_FACTOR}. e.g. If the input data is in 40kHz resolution and this variable was 2, then we average 80kHz worth of data into a chanblock during peeling. If the variable is instead 1.28MHz, then each chanblock contains 32 fine channels."); + + static ref OUTPUT_TIME_AVERAGE_FACTOR_HELP: String = format!("The number of timeblocks to average together when writing out visibilities. Also supports a target time resolution (e.g. 8s). If this is 0, then all data are averaged together. Default: {DEFAULT_OUTPUT_TIME_AVERAGE_FACTOR}. e.g. If this variable is 4, then 8 timesteps are averaged together as a timeblock in the output visibilities."); + + static ref OUTPUT_FREQ_AVERAGE_FACTOR_HELP: String = format!("The number of fine-frequency channels to average together when writing out visibilities. Also supports a target time resolution (e.g. 80kHz). If this is 0, then all data are averaged together. Default: {DEFAULT_OUTPUT_FREQ_AVERAGE_FACTOR}. This is multiplicative with the freq average factor; e.g. 
If this variable is 4, and the freq average factor is 2, then 8 fine-frequency channels are averaged together as a chanblock in the output visibilities."); + + static ref UVW_MIN_HELP: String = format!("The minimum UVW length to use. This value must have a unit annotated. Allowed units: {}. Default: {DEFAULT_UVW_MIN}", *WAVELENGTH_FORMATS); + + static ref UVW_MAX_HELP: String = format!("The maximum UVW length to use. This value must have a unit annotated. Allowed units: {}. No default.", *WAVELENGTH_FORMATS); +} + +#[derive(Parser, Debug, Clone, Default, Serialize, Deserialize)] +pub(crate) struct PeelCliArgs { + /// The number of sources to peel. Peel sources are treated the same as + /// "ionospherically subtracted" sources, except before subtracting, a "DI + /// calibration" is done between the iono-rotated model and the data. This + /// allows for scintillation and any other phase shift to be corrected. + #[clap(long = "peel", help_heading = "PEELING")] + pub(super) num_sources_to_peel: Option, + + /// The number of sources to "ionospherically subtract". That is, a λ² + /// dependence is found for each of these sources and removed. The number of + /// iono subtract sources cannot be more than the number of sources to + /// subtract. The default is the number of sources in the source list, after + /// vetoing. + #[clap(long = "iono-sub", help_heading = "PEELING")] + pub(super) num_sources_to_iono_subtract: Option, + + /// The number of sources to subtract. This subtraction just uses the sky + /// model directly; no peeling or ionospheric λ² is found. There must be at + /// least as many sources subtracted as there are ionospherically + /// subtracted. The default is the number of sources in the source list, + /// after vetoing. 
+ #[clap(long = "sub", help_heading = "PEELING")] + pub(super) num_sources_to_subtract: Option, + + #[clap(long, help = NUM_PASSES_HELP.as_str(), help_heading = "PEELING")] + pub(super) num_passes: Option, + + #[clap(short, long, help = TIME_AVERAGE_FACTOR_HELP.as_str(), help_heading = "PEELING")] + pub(super) iono_time_average_factor: Option, + + #[clap(short, long, help = FREQ_AVERAGE_FACTOR_HELP.as_str(), help_heading = "PEELING")] + pub(super) iono_freq_average_factor: Option, + + #[clap(short, long, help = IONO_FREQ_AVERAGE_FACTOR_HELP.as_str(), help_heading = "PEELING")] + pub(super) low_res_iono_freq_average: Option, + + #[clap(long, help = UVW_MIN_HELP.as_str(), help_heading = "CALIBRATION")] + pub(super) uvw_min: Option, + + #[clap(long, help = UVW_MAX_HELP.as_str(), help_heading = "CALIBRATION")] + pub(super) uvw_max: Option, + + #[clap(short, long, multiple_values(true), help = VIS_OUTPUTS_HELP.as_str(), help_heading = "OUTPUT FILES")] + pub(super) outputs: Option>, + + #[clap(long, help = OUTPUT_TIME_AVERAGE_FACTOR_HELP.as_str(), help_heading = "AVERAGING")] + pub(super) output_vis_time_average: Option, + + #[clap(long, help = OUTPUT_FREQ_AVERAGE_FACTOR_HELP.as_str(), help_heading = "AVERAGING")] + pub(super) output_vis_freq_average: Option, +} + +#[derive(Parser, Debug, Clone, Default, Serialize, Deserialize)] +pub(super) struct PeelArgs { + #[clap(name = "ARGUMENTS_FILE", help = ARG_FILE_HELP.as_str(), parse(from_os_str))] + pub(super) args_file: Option, + + #[clap(flatten)] + #[serde(rename = "data")] + #[serde(default)] + pub(super) data_args: InputVisArgs, + + #[clap(flatten)] + #[serde(rename = "sky-model")] + #[serde(default)] + pub(super) srclist_args: SkyModelWithVetoArgs, + + #[clap(flatten)] + #[serde(rename = "model")] + #[serde(default)] + pub(super) model_args: ModellingArgs, + + #[clap(flatten)] + #[serde(rename = "beam")] + #[serde(default)] + pub(super) beam_args: BeamArgs, + + #[clap(flatten)] + #[serde(rename = "peel")] + 
#[serde(default)] + pub(super) peel_args: PeelCliArgs, +} + +impl PeelArgs { + pub(crate) fn merge(self) -> Result { + debug!("Merging command-line arguments with the argument file"); + + let cli_args = self; + + if let Some(arg_file) = cli_args.args_file { + // Read in the file arguments. Ensure all of the file args are + // accounted for by pattern matching. + let PeelArgs { + args_file: _, + data_args, + srclist_args, + model_args, + beam_args, + peel_args, + } = unpack_arg_file!(arg_file); + + // Merge all the arguments, preferring the CLI args when available. + Ok(PeelArgs { + args_file: None, + data_args: cli_args.data_args.merge(data_args), + srclist_args: cli_args.srclist_args.merge(srclist_args), + model_args: cli_args.model_args.merge(model_args), + beam_args: cli_args.beam_args.merge(beam_args), + peel_args: cli_args.peel_args.merge(peel_args), + }) + } else { + Ok(cli_args) + } + } + + fn parse(self) -> Result { + debug!("{:#?}", self); + + let Self { + args_file: _, + data_args, + mut srclist_args, + model_args, + beam_args, + peel_args: + PeelCliArgs { + num_sources_to_peel, + num_sources_to_iono_subtract, + num_sources_to_subtract, + num_passes, + iono_time_average_factor, + iono_freq_average_factor, + low_res_iono_freq_average, + uvw_min, + uvw_max, + outputs, + output_vis_time_average, + output_vis_freq_average, + }, + } = self; + + let input_vis_params = data_args.parse("Peeling")?; + let obs_context = input_vis_params.get_obs_context(); + let total_num_tiles = input_vis_params.get_total_num_tiles(); + + let beam = beam_args.parse( + total_num_tiles, + obs_context.dipole_delays.clone(), + obs_context.dipole_gains.clone(), + Some(obs_context.input_data_type), + )?; + let modelling_params @ ModellingParams { apply_precession } = model_args.parse(); + + let LatLngHeight { + longitude_rad, + latitude_rad, + height_metres: _, + } = obs_context.array_position; + let source_list = { + srclist_args.num_sources = num_sources_to_subtract; + + let 
precession_info = precess_time( + longitude_rad, + latitude_rad, + obs_context.phase_centre, + input_vis_params.timeblocks.first().median, + input_vis_params.dut1, + ); + if apply_precession { + srclist_args.parse( + obs_context.phase_centre, + precession_info.lmst_j2000, + precession_info.array_latitude_j2000, + &obs_context.get_veto_freqs(), + &*beam, + )? + } else { + srclist_args.parse( + obs_context.phase_centre, + precession_info.lmst, + latitude_rad, + &obs_context.get_veto_freqs(), + &*beam, + )? + } + }; + + // Check that the number of sources to peel, iono subtract and subtract + // are sensible. When that's done, veto up to the maximum number of + // sources to subtract. + let _max_num_sources = match (num_sources_to_iono_subtract, num_sources_to_subtract) { + (Some(is), Some(s)) => { + if s < is { + panic!("The number of sources to subtract ({s}) must be at least equal to the number of sources to iono subtract ({is})"); + } + Some(s) + } + (None, Some(s)) => Some(s), + (Some(_), None) => None, + (None, None) => None, + }; + + let num_passes = NonZeroUsize::try_from(num_passes.unwrap_or(DEFAULT_NUM_PASSES)) + .map_err(|_| PeelArgsError::ZeroPasses)?; + + // Set up the iono timeblocks. These break up the input data timesteps + // (which may be averaged into timeblocks) into groups, each of which + // will be peeled together. 
+ let iono_time_average_factor = { + let default_time_average_factor = parse_time_average_factor( + Some(input_vis_params.time_res), + Some(DEFAULT_TIME_AVERAGE_FACTOR), + NonZeroUsize::new(1).unwrap(), + ) + .unwrap_or(NonZeroUsize::new(1).unwrap()); + + let f = parse_time_average_factor( + Some(input_vis_params.time_res), + iono_time_average_factor.as_deref(), + default_time_average_factor, + ) + .map_err(|e| match e { + AverageFactorError::Zero => PeelArgsError::IonoTimeFactorZero, + AverageFactorError::NotInteger => PeelArgsError::IonoTimeFactorNotInteger, + AverageFactorError::NotIntegerMultiple { out, inp } => { + PeelArgsError::IonoTimeResNotMultiple { out, inp } + } + AverageFactorError::Parse(e) => PeelArgsError::ParseIonoTimeAverageFactor(e), + })?; + + // Check that the factor is not too big. + if f.get() > input_vis_params.timeblocks.len() { + format!( + "Cannot average {f} timeblocks; only {} are being used. Capping.", + input_vis_params.timeblocks.len() + ) + .warn(); + NonZeroUsize::new(input_vis_params.timeblocks.len()) + .expect("timeblocks is Vec1, which cannot be empty") + } else { + f + } + }; + let input_timestamps = Vec1::try_from_vec( + input_vis_params + .timeblocks + .iter() + .flat_map(|t| t.timestamps.iter().map(|(e, _)| *e)) + .collect(), + ) + .unwrap(); + let iono_timeblocks = + timesteps_to_timeblocks(&input_timestamps, iono_time_average_factor, None); + + // // Set up the chanblocks. 
+ // let iono_freq_average_factor = { + // let default_freq_average_factor = parse_freq_average_factor( + // Some(input_vis_params.spw.freq_res), + // Some(DEFAULT_FREQ_AVERAGE_FACTOR), + // 1, + // ) + // .unwrap_or(1); + + // parse_freq_average_factor( + // Some(input_vis_params.spw.freq_res), + // iono_freq_average_factor.as_deref(), + // default_freq_average_factor, + // ) + // .map_err(|e| match e { + // AverageFactorError::Zero => PeelArgsError::IonoFreqFactorZero, + // AverageFactorError::NotInteger => PeelArgsError::IonoFreqFactorNotInteger, + // AverageFactorError::NotIntegerMultiple { out, inp } => { + // PeelArgsError::IonoFreqResNotMultiple { out, inp } + // } + // AverageFactorError::Parse(e) => PeelArgsError::ParseIonoFreqAverageFactor(e), + // })? + // }; + // // Check that the factor is not too big. + // let freq_average_factor = if freq_average_factor > input_vis_params.spw.chanblocks.len() { + // format!( + // "Cannot average {} channels; only {} are being used. Capping.", + // freq_average_factor, + // input_vis_params.spw.chanblocks.len() + // ) + // .warn(); + // input_vis_params.spw.chanblocks.len() + // } else { + // freq_average_factor + // }; + + // let mut iono_spws = { + // let all_freqs = { + // let n = input_vis_params.spw.chanblocks.len() + // + input_vis_params.spw.flagged_chanblock_indices.len(); + // let mut freqs = Vec::with_capacity(n); + // let first_freq = input_vis_params.spw.first_freq.round() as u64; + // let freq_res = input_vis_params.spw.freq_res.round() as u64; + // for i in 0..n as u64 { + // freqs.push(first_freq + freq_res * i); + // } + // freqs + // }; + + // channels_to_chanblocks( + // &all_freqs, + // input_vis_params.spw.freq_res, + // freq_average_factor, + // &input_vis_params.spw.flagged_chanblock_indices, + // ) + // }; + // // There must be at least one chanblock to do anything. + // let iono_spw = match iono_spws.as_slice() { + // // No spectral windows is the same as no chanblocks. 
+ // [] => return Err(PeelArgsError::NoChannels.into()), + // [f] => { + // // Check that the chanblocks aren't all flagged. + // if f.chanblocks.is_empty() { + // return Err(PeelArgsError::NoChannels.into()); + // } + // iono_spws.swap_remove(0) + // } + // [f, ..] => { + // // Check that the chanblocks aren't all flagged. + // if f.chanblocks.is_empty() { + // return Err(PeelArgsError::NoChannels.into()); + // } + // // TODO: Allow picket fence. + // eprintln!("\"Picket fence\" data detected. hyperdrive does not support this right now -- exiting."); + // eprintln!("See for more info: https://MWATelescope.github.io/mwa_hyperdrive/defs/mwa/picket_fence.html"); + // std::process::exit(1); + // } + // }; + + let low_res_freq_average_factor = { + let default_iono_freq_average_factor = parse_freq_average_factor( + Some(input_vis_params.spw.freq_res), + Some(DEFAULT_IONO_FREQ_AVERAGE_FACTOR), + NonZeroUsize::new(1).unwrap(), + ) + .unwrap_or(NonZeroUsize::new(1).unwrap()); + + parse_freq_average_factor( + Some(input_vis_params.spw.freq_res), + low_res_iono_freq_average.as_deref(), + default_iono_freq_average_factor, + ) + .unwrap() + }; + let mut low_res_spws = { + let spw = &input_vis_params.spw; + let all_freqs = { + let n = spw.chanblocks.len() + spw.flagged_chanblock_indices.len(); + let mut freqs = Vec::with_capacity(n); + let first_freq = spw.first_freq.round() as u64; + let freq_res = spw.freq_res.round() as u64; + for i in 0..n as u64 { + freqs.push(first_freq + freq_res * i); + } + freqs + }; + + channels_to_chanblocks( + &all_freqs, + spw.freq_res * low_res_freq_average_factor.get() as f64, + low_res_freq_average_factor, + &HashSet::new(), + ) + }; + assert_eq!( + low_res_spws.len(), + 1, + "There should only be 1 low-res SPW, because there's only 1 high-res SPW" + ); + let low_res_spw = low_res_spws.swap_remove(0); + + // Parse vis and iono const outputs. 
+ let (vis_outputs, iono_outputs) = { + let mut vis_outputs = vec![]; + let mut iono_outputs = vec![]; + match outputs { + // Defaults. + None => { + let pb = PathBuf::from(DEFAULT_OUTPUT_PEEL_FILENAME); + pb.extension() + .and_then(|os_str| os_str.to_str()) + .and_then(|s| VisOutputType::from_str(s).ok()) + // Tests should pick up a bad default filename. + .expect("DEFAULT_OUTPUT_PEEL_FILENAME has an unhandled extension!"); + vis_outputs.push(pb); + // TODO: Type this and clean up + let pb = PathBuf::from(DEFAULT_OUTPUT_IONO_CONSTS); + if pb.extension().and_then(|os_str| os_str.to_str()) != Some("json") { + // Tests should pick up a bad default filename. + panic!("DEFAULT_OUTPUT_IONO_CONSTS has an unhandled extension!"); + } + iono_outputs.push(pb); + } + Some(os) => { + // Just find the .json files; other code will parse the + // visibility outputs. + for file in os { + let ext = file.extension().and_then(|os_str| os_str.to_str()); + match ext.map(|s| s == "json") { + Some(true) => { + iono_outputs.push(file); + } + _ => { + vis_outputs.push(file); + } + } + } + } + }; + (vis_outputs, iono_outputs) + }; + if vis_outputs.len() + iono_outputs.len() == 0 { + return Err(PeelArgsError::NoOutput.into()); + } + + let output_vis_params = if vis_outputs.is_empty() { + None + } else { + let params = OutputVisArgs { + outputs: Some(vis_outputs), + output_vis_time_average, + output_vis_freq_average, + } + .parse( + input_vis_params.time_res, + input_vis_params.spw.freq_res, + &Vec1::try_from_vec( + input_vis_params + .timeblocks + .iter() + // .map(|t| t.median) + .flat_map(|t| t.timestamps.iter().map(|(e, _)| *e)) + .collect(), + ) + .unwrap(), + DEFAULT_OUTPUT_PEEL_FILENAME, + Some("peeled"), + )?; + Some(params) + }; + + // let tile_baseline_flags = &input_vis_params.tile_baseline_flags; + // let flagged_tiles = &tile_baseline_flags.flagged_tiles; + + // let unflagged_tile_xyzs: Vec = obs_context + // .tile_xyzs + // .par_iter() + // .enumerate() + // 
.filter(|(tile_index, _)| !flagged_tiles.contains(tile_index)) + // .map(|(_, xyz)| *xyz) + // .collect(); + + // // Set baseline weights from UVW cuts. Use a lambda from the centroid + // // frequency if UVW cutoffs are specified as wavelengths. + // let freq_centroid = obs_context + // .fine_chan_freqs + // .iter() + // .map(|&u| u as f64) + // .sum::() + // / obs_context.fine_chan_freqs.len() as f64; + // let lambda = marlu::constants::VEL_C / freq_centroid; + // let (uvw_min, uvw_min_metres) = { + // let (quantity, unit) = parse_wavelength( + // uvw_min + // .as_deref() + // .unwrap_or(crate::cli::di_calibrate::DEFAULT_UVW_MIN), + // ) + // .map_err(PeelError::ParseUvwMin)?; + // match unit { + // WavelengthUnit::M => ((quantity, unit), quantity), + // WavelengthUnit::L => ((quantity, unit), quantity * lambda), + // } + // }; + // let (uvw_max, uvw_max_metres) = match uvw_max { + // None => ((f64::INFINITY, WavelengthUnit::M), f64::INFINITY), + // Some(s) => { + // let (quantity, unit) = parse_wavelength(&s).map_err(PeelError::ParseUvwMax)?; + // match unit { + // WavelengthUnit::M => ((quantity, unit), quantity), + // WavelengthUnit::L => ((quantity, unit), quantity * lambda), + // } + // } + // }; + + // let (baseline_weights, num_flagged_baselines) = { + // let mut baseline_weights = Vec1::try_from_vec(vec![ + // 1.0; + // tile_baseline_flags + // .unflagged_cross_baseline_to_tile_map + // .len() + // ]) + // .map_err(|_| PeelError::NoTiles)?; + // let uvws = xyzs_to_cross_uvws( + // &unflagged_tile_xyzs, + // obs_context.phase_centre.to_hadec(lmst), + // ); + // assert_eq!(baseline_weights.len(), uvws.len()); + // let uvw_min = uvw_min_metres.powi(2); + // let uvw_max = uvw_max_metres.powi(2); + // let mut num_flagged_baselines = 0; + // for (uvw, baseline_weight) in uvws.into_iter().zip_eq(baseline_weights.iter_mut()) { + // let uvw_length = uvw.u.powi(2) + uvw.v.powi(2) + uvw.w.powi(2); + // if uvw_length < uvw_min || uvw_length > uvw_max { + // 
*baseline_weight = 0.0; + // num_flagged_baselines += 1; + // } + // } + // (baseline_weights, num_flagged_baselines) + // }; + + // TODO: Ensure that the order of the sources are brightest first, + // dimmest last. + let num_sources_to_iono_subtract = + num_sources_to_iono_subtract.unwrap_or(source_list.len()); + if num_sources_to_iono_subtract == 0 { + "No sources are being iono subtracted; behaviour will match vis-subtract".warn() + } + + display_warnings(); + + Ok(PeelParams { + input_vis_params, + output_vis_params, + iono_outputs, + beam, + source_list, + modelling_params, + iono_timeblocks, + iono_time_average_factor, + low_res_spw, + num_sources_to_iono_subtract, + num_passes, + }) + }
+
+    /// Parse these arguments into parameters and then run them.
+    ///
+    /// The arguments are always parsed (so argument errors surface even on a
+    /// dry run). If `dry_run` is true, this returns `Ok(())` straight after
+    /// parsing, without calling `run` on the parsed parameters.
+    ///
+    /// # Errors
+    ///
+    /// Propagates any error from `self.parse()` or from running the parsed
+    /// parameters.
+    pub(super) fn run(self, dry_run: bool) -> Result<(), HyperdriveError> {
+        debug!("Converting arguments into parameters");
+        trace!("{:#?}", self);
+        let params = self.parse()?;
+
+        if dry_run {
+            info!("Dry run -- exiting now.");
+            return Ok(());
+        }
+
+        params.run()?;
+        Ok(())
+    }
+}
+
+impl PeelCliArgs {
+    /// Combine two sets of peel arguments, field by field.
+    ///
+    /// Every field is combined with `.or`, so a value present in `self` takes
+    /// precedence and `other` only fills in what `self` leaves unset.
+    fn merge(self, other: Self) -> Self {
+        Self {
+            num_sources_to_peel: self.num_sources_to_peel.or(other.num_sources_to_peel),
+            num_sources_to_iono_subtract: self
+                .num_sources_to_iono_subtract
+                .or(other.num_sources_to_iono_subtract),
+            num_sources_to_subtract: self
+                .num_sources_to_subtract
+                .or(other.num_sources_to_subtract),
+            num_passes: self.num_passes.or(other.num_passes),
+            iono_time_average_factor: self
+                .iono_time_average_factor
+                .or(other.iono_time_average_factor),
+            iono_freq_average_factor: self
+                .iono_freq_average_factor
+                .or(other.iono_freq_average_factor),
+            low_res_iono_freq_average: self
+                .low_res_iono_freq_average
+                .or(other.low_res_iono_freq_average),
+            uvw_min: self.uvw_min.or(other.uvw_min),
+            uvw_max: self.uvw_max.or(other.uvw_max),
+            outputs: self.outputs.or(other.outputs),
+            output_vis_time_average: self
+                .output_vis_time_average
+                .or(other.output_vis_time_average),
+            output_vis_freq_average: self
+                .output_vis_freq_average
+                .or(other.output_vis_freq_average),
+        }
+    }
+}
diff --git a/src/di_calibrate/mod.rs b/src/di_calibrate/mod.rs index 2c932acf..69c8b2e4 100644 --- a/src/di_calibrate/mod.rs +++ b/src/di_calibrate/mod.rs @@ -429,7 +429,7 @@ fn make_calibration_progress_bar(num_chanblocks: usize, message: String) -> Prog /// Worker function to do calibration. #[allow(clippy::too_many_arguments)] -fn calibrate_timeblock( +pub(crate) fn calibrate_timeblock( vis_data_tfb: ArrayView3>, vis_model_tfb: ArrayView3>, mut di_jones: ArrayViewMut3>, diff --git a/src/gpu/common.cuh b/src/gpu/common.cuh index 90dc777e..f64abe81 100644 --- a/src/gpu/common.cuh +++ b/src/gpu/common.cuh @@ -15,11 +15,17 @@ #define gpuFree hipFree #define gpuMemcpy hipMemcpy #define gpuMemcpyHostToDevice hipMemcpyHostToDevice +#define gpuMemcpyDeviceToHost hipMemcpyDeviceToHost #define gpuGetErrorString hipGetErrorString #define gpuGetLastError hipGetLastError #define gpuDeviceSynchronize hipDeviceSynchronize #define gpuError_t hipError_t #define gpuSuccess hipSuccess +#define C32 hipFloatComplex +#define C64 hipDoubleComplex +#define MAKE_C32 make_hipFloatComplex +#define MAKE_C64 make_hipDoubleComplex +#define __syncwarp __syncthreads // If SINGLE is enabled, use single-precision floats everywhere. Otherwise // default to double-precision.
@@ -54,12 +60,17 @@ #define gpuFree cudaFree #define gpuMemcpy cudaMemcpy #define gpuMemcpyHostToDevice cudaMemcpyHostToDevice +#define gpuMemcpyDeviceToHost cudaMemcpyDeviceToHost #define gpuGetErrorString cudaGetErrorString #define gpuGetLastError cudaGetLastError #define gpuDeviceSynchronize cudaDeviceSynchronize #define gpuError_t cudaError_t #define gpuSuccess cudaSuccess #define warpSize 32 +#define C32 cuFloatComplex +#define C64 cuDoubleComplex +#define MAKE_C32 make_cuFloatComplex +#define MAKE_C64 make_cuDoubleComplex #ifdef SINGLE #define FLOAT4 float4 @@ -85,8 +96,6 @@ #define EXP exp #endif // SINGLE #endif // __HIPCC__ -// #define C32 cuFloatComplex -// #define C64 cuDoubleComplex #ifdef __CUDACC__ #include @@ -120,6 +129,10 @@ inline __device__ COMPLEX operator*(const COMPLEX a, const COMPLEX b) { return MAKE_COMPLEX(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x); } +inline __device__ C32 operator*(const C32 a, const C64 b) { + return MAKE_C32(a.x * (float)b.x - a.y * (float)b.y, a.x * (float)b.y + a.y * (float)b.x); +} + inline __device__ void operator*=(COMPLEX &a, const COMPLEX b) { a = MAKE_COMPLEX(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x); } @@ -131,8 +144,8 @@ inline __device__ void operator+=(COMPLEX &a, const COMPLEX b) { a.y += b.y; } -inline __device__ JONES operator*(const JONES a, const FLOAT b) { - return JONES{ +inline __device__ JonesF32 operator*(const JonesF32 a, const float b) { + return JonesF32{ .j00_re = a.j00_re * b, .j00_im = a.j00_im * b, .j01_re = a.j01_re * b, @@ -144,8 +157,21 @@ inline __device__ JONES operator*(const JONES a, const FLOAT b) { }; } -inline __device__ JONES operator*(const JONES a, const COMPLEX b) { - return JONES{ +inline __device__ JonesF64 operator*(const JonesF64 a, const double b) { + return JonesF64{ + .j00_re = a.j00_re * b, + .j00_im = a.j00_im * b, + .j01_re = a.j01_re * b, + .j01_im = a.j01_im * b, + .j10_re = a.j10_re * b, + .j10_im = a.j10_im * b, + .j11_re = a.j11_re * b, + .j11_im = 
a.j11_im * b, + }; +} + +inline __device__ JonesF32 operator*(const JonesF32 a, const C32 b) { + return JonesF32{ .j00_re = a.j00_re * b.x - a.j00_im * b.y, .j00_im = a.j00_re * b.y + a.j00_im * b.x, .j01_re = a.j01_re * b.x - a.j01_im * b.y, @@ -157,18 +183,72 @@ inline __device__ JONES operator*(const JONES a, const COMPLEX b) { }; } -inline __device__ void operator+=(JONES &a, const JONES b) { - a.j00_re += b.j00_re; - a.j00_im += b.j00_im; - a.j01_re += b.j01_re; - a.j01_im += b.j01_im; - a.j10_re += b.j10_re; - a.j10_im += b.j10_im; - a.j11_re += b.j11_re; - a.j11_im += b.j11_im; +inline __device__ JonesF32 operator*(const JonesF32 a, const C64 b) { + return JonesF32{ + .j00_re = a.j00_re * (float)b.x - a.j00_im * (float)b.y, + .j00_im = a.j00_re * (float)b.y + a.j00_im * (float)b.x, + .j01_re = a.j01_re * (float)b.x - a.j01_im * (float)b.y, + .j01_im = a.j01_re * (float)b.y + a.j01_im * (float)b.x, + .j10_re = a.j10_re * (float)b.x - a.j10_im * (float)b.y, + .j10_im = a.j10_re * (float)b.y + a.j10_im * (float)b.x, + .j11_re = a.j11_re * (float)b.x - a.j11_im * (float)b.y, + .j11_im = a.j11_re * (float)b.y + a.j11_im * (float)b.x, + }; +} + +inline __device__ JonesF64 operator*(const JonesF64 a, const C32 b) { + return JonesF64{ + .j00_re = a.j00_re * b.x - a.j00_im * b.y, + .j00_im = a.j00_re * b.y + a.j00_im * b.x, + .j01_re = a.j01_re * b.x - a.j01_im * b.y, + .j01_im = a.j01_re * b.y + a.j01_im * b.x, + .j10_re = a.j10_re * b.x - a.j10_im * b.y, + .j10_im = a.j10_re * b.y + a.j10_im * b.x, + .j11_re = a.j11_re * b.x - a.j11_im * b.y, + .j11_im = a.j11_re * b.y + a.j11_im * b.x, + }; } -inline __device__ void operator+=(JonesF32 &a, const JonesF64 b) { +inline __device__ JonesF64 operator*(const JonesF64 a, const C64 b) { + return JonesF64{ + .j00_re = a.j00_re * b.x - a.j00_im * b.y, + .j00_im = a.j00_re * b.y + a.j00_im * b.x, + .j01_re = a.j01_re * b.x - a.j01_im * b.y, + .j01_im = a.j01_re * b.y + a.j01_im * b.x, + .j10_re = a.j10_re * b.x - a.j10_im * 
b.y, + .j10_im = a.j10_re * b.y + a.j10_im * b.x, + .j11_re = a.j11_re * b.x - a.j11_im * b.y, + .j11_im = a.j11_re * b.y + a.j11_im * b.x, + }; +} + +inline __device__ JonesF32 operator+(JonesF32 a, JonesF32 b) { + return JonesF32{ + .j00_re = a.j00_re + b.j00_re, + .j00_im = a.j00_im + b.j00_im, + .j01_re = a.j01_re + b.j01_re, + .j01_im = a.j01_im + b.j01_im, + .j10_re = a.j10_re + b.j10_re, + .j10_im = a.j10_im + b.j10_im, + .j11_re = a.j11_re + b.j11_re, + .j11_im = a.j11_im + b.j11_im, + }; +} + +inline __device__ JonesF64 operator+(JonesF64 a, JonesF64 b) { + return JonesF64{ + .j00_re = a.j00_re + b.j00_re, + .j00_im = a.j00_im + b.j00_im, + .j01_re = a.j01_re + b.j01_re, + .j01_im = a.j01_im + b.j01_im, + .j10_re = a.j10_re + b.j10_re, + .j10_im = a.j10_im + b.j10_im, + .j11_re = a.j11_re + b.j11_re, + .j11_im = a.j11_im + b.j11_im, + }; +} + +inline __device__ void operator+=(JonesF32 &a, const JonesF64 &b) { a.j00_re += (float)b.j00_re; a.j00_im += (float)b.j00_im; a.j01_re += (float)b.j01_re; @@ -179,6 +259,129 @@ inline __device__ void operator+=(JonesF32 &a, const JonesF64 b) { a.j11_im += (float)b.j11_im; } +inline __device__ void operator+=(JonesF32 &a, JonesF32 b) { + a.j00_re += b.j00_re; + a.j00_im += b.j00_im; + a.j01_re += b.j01_re; + a.j01_im += b.j01_im; + a.j10_re += b.j10_re; + a.j10_im += b.j10_im; + a.j11_re += b.j11_re; + a.j11_im += b.j11_im; +} + +inline __device__ void operator+=(JonesF64 &a, JonesF64 b) { + a.j00_re += b.j00_re; + a.j00_im += b.j00_im; + a.j01_re += b.j01_re; + a.j01_im += b.j01_im; + a.j10_re += b.j10_re; + a.j10_im += b.j10_im; + a.j11_re += b.j11_re; + a.j11_im += b.j11_im; +} + +inline __device__ void operator+=(volatile JonesF32 &a, volatile JonesF32 &b) { + a.j00_re += b.j00_re; + a.j00_im += b.j00_im; + a.j01_re += b.j01_re; + a.j01_im += b.j01_im; + a.j10_re += b.j10_re; + a.j10_im += b.j10_im; + a.j11_re += b.j11_re; + a.j11_im += b.j11_im; +} + +inline __device__ void operator+=(volatile JonesF64 &a, 
volatile JonesF64 &b) { + a.j00_re += b.j00_re; + a.j00_im += b.j00_im; + a.j01_re += b.j01_re; + a.j01_im += b.j01_im; + a.j10_re += b.j10_re; + a.j10_im += b.j10_im; + a.j11_re += b.j11_re; + a.j11_im += b.j11_im; +} + +inline __device__ void operator+=(JonesF64 &a, JonesF32 &b) { + a.j00_re += (double)b.j00_re; + a.j00_im += (double)b.j00_im; + a.j01_re += (double)b.j01_re; + a.j01_im += (double)b.j01_im; + a.j10_re += (double)b.j10_re; + a.j10_im += (double)b.j10_im; + a.j11_re += (double)b.j11_re; + a.j11_im += (double)b.j11_im; +} + +inline __device__ JONES operator/(JONES a, FLOAT b) { + return JONES{ + .j00_re = a.j00_re / b, + .j00_im = a.j00_im / b, + .j01_re = a.j01_re / b, + .j01_im = a.j01_im / b, + .j10_re = a.j10_re / b, + .j10_im = a.j10_im / b, + .j11_re = a.j11_re / b, + .j11_im = a.j11_im / b, + }; +} + +inline __device__ void operator/=(JonesF64 &a, double b) { + a.j00_re /= b; + a.j00_im /= b; + a.j01_re /= b; + a.j01_im /= b; + a.j10_re /= b; + a.j10_im /= b; + a.j11_re /= b; + a.j11_im /= b; +} + +inline __device__ void operator-=(JonesF32 &a, JonesF32 b) { + a.j00_re -= b.j00_re; + a.j00_im -= b.j00_im; + a.j01_re -= b.j01_re; + a.j01_im -= b.j01_im; + a.j10_re -= b.j10_re; + a.j10_im -= b.j10_im; + a.j11_re -= b.j11_re; + a.j11_im -= b.j11_im; +} + +inline __device__ void operator-=(JonesF64 &a, JonesF64 b) { + a.j00_re -= b.j00_re; + a.j00_im -= b.j00_im; + a.j01_re -= b.j01_re; + a.j01_im -= b.j01_im; + a.j10_re -= b.j10_re; + a.j10_im -= b.j10_im; + a.j11_re -= b.j11_re; + a.j11_im -= b.j11_im; +} + +inline __device__ void operator-=(JonesF32 &a, JonesF64 b) { + a.j00_re -= (float)b.j00_re; + a.j00_im -= (float)b.j00_im; + a.j01_re -= (float)b.j01_re; + a.j01_im -= (float)b.j01_im; + a.j10_re -= (float)b.j10_re; + a.j10_im -= (float)b.j10_im; + a.j11_re -= (float)b.j11_re; + a.j11_im -= (float)b.j11_im; +} + +inline __device__ void operator-=(JonesF64 &a, JonesF32 b) { + a.j00_re -= (double)b.j00_re; + a.j00_im -= (double)b.j00_im; + 
a.j01_re -= (double)b.j01_re; + a.j01_im -= (double)b.j01_im; + a.j10_re -= (double)b.j10_re; + a.j10_im -= (double)b.j10_im; + a.j11_re -= (double)b.j11_re; + a.j11_im -= (double)b.j11_im; +} + inline __device__ UVW operator*(const UVW a, const FLOAT b) { return UVW{ .u = a.u * b, diff --git a/src/gpu/mod.rs b/src/gpu/mod.rs index 801973c9..38455c7d 100644 --- a/src/gpu/mod.rs +++ b/src/gpu/mod.rs @@ -31,6 +31,7 @@ cfg_if::cfg_if! { include!("types_single.rs"); include!("model_single.rs"); + include!("peel_single.rs"); } else if #[cfg(all(any(feature = "cuda", feature = "hip"), not(feature = "gpu-single")))] { /// f64 (not using "gpu-single") pub(crate) type GpuFloat = f64; @@ -38,6 +39,7 @@ cfg_if::cfg_if! { include!("types_double.rs"); include!("model_double.rs"); + include!("peel_double.rs"); } } @@ -132,7 +134,7 @@ pub(crate) unsafe fn peek_and_sync(gpu_call: GpuCall) -> Result<(), GpuError> { /// [`gpuFree`] is called on the pointer. #[derive(Debug)] pub(crate) struct DevicePointer { - ptr: *mut T, + pub(crate) ptr: *mut T, /// The number of bytes allocated against `ptr`. size: usize, @@ -267,7 +269,6 @@ impl DevicePointer { } /// Clear all of the bytes in the buffer by writing zeros. - #[cfg(test)] pub(crate) fn clear(&mut self) { #[cfg(feature = "cuda")] use cuda_runtime_sys::cudaMemset as gpuMemset; diff --git a/src/gpu/peel.cu b/src/gpu/peel.cu new file mode 100644 index 00000000..4c0381e2 --- /dev/null +++ b/src/gpu/peel.cu @@ -0,0 +1,681 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#include + +#include "common.cuh" +#include "peel.h" + +/** + * Turn XYZs into UVWs. Multiple sets of XYZs over time can be converted. + * Expects the device to be parallel over baselines. 
+ */
+// Memory layout used below:
+//   xyzs[i_time * num_tiles + i_tile]      (one set of tile XYZs per timestep)
+//   lmsts[i_time]
+//   uvws[i_time * num_baselines + i_bl]    (output)
+// Each thread handles one baseline index and loops over all timesteps.
+__global__ void xyzs_to_uvws_kernel(const XYZ *xyzs, const FLOAT *lmsts, UVW *uvws, RADec pointing_centre,
+                                    int num_tiles, int num_baselines, int num_timesteps) {
+    const int i_bl = threadIdx.x + (blockDim.x * blockIdx.x);
+    // Surplus threads in the last thread block have no baseline to work on.
+    if (i_bl >= num_baselines)
+        return;
+
+    // Find the tile indices from the baseline index. `num_tiles` has to be
+    // subtracted by 1 to make it "0 index".
+    // NOTE(review): the result of `sqrtf`/`floorf` is single precision;
+    // presumably exact for the tile counts used here -- confirm for very
+    // large arrays.
+    const float n = (float)(num_tiles - 1);
+    const float tile1f = floorf(-0.5 * sqrtf(4.0 * n * (n + 1.0) - 8.0 * i_bl + 1.0) + n + 0.5);
+    const int tile2 = (int)(i_bl - tile1f * (n - (tile1f + 1.0) / 2.0) + 1.0);
+    const int tile1 = (int)tile1f;
+
+    FLOAT s_ha, c_ha, s_dec, c_dec;
+    // The declination of the pointing centre doesn't change with time, so its
+    // sine and cosine are computed once, outside the loop.
+    SINCOS(pointing_centre.dec, &s_dec, &c_dec);
+
+    for (int i_time = 0; i_time < num_timesteps; i_time++) {
+        // Baseline vector for this timestep: tile1 minus tile2.
+        XYZ xyz = xyzs[i_time * num_tiles + tile1];
+        const XYZ xyz2 = xyzs[i_time * num_tiles + tile2];
+        xyz.x -= xyz2.x;
+        xyz.y -= xyz2.y;
+        xyz.z -= xyz2.z;
+
+        // Hour angle of the pointing centre at this timestep (LMST - RA).
+        const FLOAT hour_angle = lmsts[i_time] - pointing_centre.ra;
+        SINCOS(hour_angle, &s_ha, &c_ha);
+
+        uvws[i_time * num_baselines + (int)i_bl] = UVW{
+            .u = s_ha * xyz.x + c_ha * xyz.y,
+            .v = -s_dec * c_ha * xyz.x + s_dec * s_ha * xyz.y + c_dec * xyz.z,
+            .w = c_dec * c_ha * xyz.x - c_dec * s_ha * xyz.y + s_dec * xyz.z,
+        };
+    }
+}
+
+/**
+ * Kernel for rotating visibilities and averaging them into "low-resolution"
+ * visibilities.
+ *
+ * The visibilities should be ordered in time, frequency and baseline (slowest
+ * to fastest). The weights should never be negative; this allows us to avoid
+ * special logic when averaging.
+ */
+// Each thread handles one baseline. For every group of `freq_average_factor`
+// channels, all timesteps and all channels of the group are phase-rotated by
+// the difference between the "to" and "from" w terms, weighted, and averaged
+// into one low-resolution visibility at
+// low_res_vis[(i_freq / freq_average_factor) * num_baselines + i_bl].
+// `pointing_centre`, `num_tiles`, `lmsts` and `xyzs` are accepted but not read
+// by this kernel.
+__global__ void rotate_average_kernel(const JonesF32 *high_res_vis, const float *high_res_weights,
+                                      JonesF32 *low_res_vis, RADec pointing_centre, const int num_timesteps,
+                                      const int num_tiles, const int num_baselines, const int num_freqs,
+                                      const int freq_average_factor, const FLOAT *lmsts, const XYZ *xyzs,
+                                      const UVW *uvws_from, UVW *uvws_to, const FLOAT *lambdas) {
+    const int i_bl = threadIdx.x + (blockDim.x * blockIdx.x);
+    if (i_bl >= num_baselines)
+        return;
+
+    for (int i_freq = 0; i_freq < num_freqs; i_freq += freq_average_factor) {
+        // The final group is short when `num_freqs` is not a multiple of
+        // `freq_average_factor`. Clamp the end of the group so we never read
+        // `lambdas`, `high_res_vis` or `high_res_weights` past `num_freqs`.
+        const int i_freq_chunk_end =
+            (i_freq + freq_average_factor < num_freqs) ? (i_freq + freq_average_factor) : num_freqs;
+
+        JonesF64 vis_weighted_sum = JonesF64{
+            .j00_re = 0.0,
+            .j00_im = 0.0,
+            .j01_re = 0.0,
+            .j01_im = 0.0,
+            .j10_re = 0.0,
+            .j10_im = 0.0,
+            .j11_re = 0.0,
+            .j11_im = 0.0,
+        };
+        double weight_sum = 0.0;
+
+        for (int i_time = 0; i_time < num_timesteps; i_time++) {
+            // Prepare an "argument" for later.
+            const double arg = -TAU * ((double)uvws_to[i_time * num_baselines + i_bl].w -
+                                       (double)uvws_from[i_time * num_baselines + i_bl].w);
+            for (int i_freq_chunk = i_freq; i_freq_chunk < i_freq_chunk_end; i_freq_chunk++) {
+                // complex = exp(i * arg / λ); `.x` is the real part and `.y`
+                // the imaginary part.
+                C64 complex;
+                sincos(arg / lambdas[i_freq_chunk], &complex.y, &complex.x);
+
+                const int step = (i_time * num_freqs + i_freq_chunk) * num_baselines + i_bl;
+                const double weight = high_res_weights[step];
+                const JonesF32 vis_single = high_res_vis[step];
+                // Accumulate in double precision.
+                const JonesF64 vis_double = JonesF64{
+                    .j00_re = vis_single.j00_re,
+                    .j00_im = vis_single.j00_im,
+                    .j01_re = vis_single.j01_re,
+                    .j01_im = vis_single.j01_im,
+                    .j10_re = vis_single.j10_re,
+                    .j10_im = vis_single.j10_im,
+                    .j11_re = vis_single.j11_re,
+                    .j11_im = vis_single.j11_im,
+                };
+                const JonesF64 rotated_weighted_vis = vis_double * weight * complex;
+
+                vis_weighted_sum += rotated_weighted_vis;
+                weight_sum += weight;
+            }
+        }
+
+        // If `weight_sum` is bigger than 0, use it in division, otherwise just
+        // divide by 1. We do this so we don't get NaN values, and we don't use
+        // if statements in case the compiler optimises this better to avoid
+        // warp divergence.
+        vis_weighted_sum /= (weight_sum > 0.0) ? weight_sum : 1.0;
+
+        const int low_res_step = (i_freq / freq_average_factor) * num_baselines + i_bl;
+        low_res_vis[low_res_step] = JonesF32{
+            .j00_re = (float)vis_weighted_sum.j00_re,
+            .j00_im = (float)vis_weighted_sum.j00_im,
+            .j01_re = (float)vis_weighted_sum.j01_re,
+            .j01_im = (float)vis_weighted_sum.j01_im,
+            .j10_re = (float)vis_weighted_sum.j10_re,
+            .j10_im = (float)vis_weighted_sum.j10_im,
+            .j11_re = (float)vis_weighted_sum.j11_re,
+            .j11_im = (float)vis_weighted_sum.j11_im,
+        };
+        // low_res_weights[low_res_step] = weight_sum;
+    }
+}
+
+/**
+ * Apply an ionospheric phase to the visibilities of the calling thread's
+ * baseline, for every frequency.
+ *
+ * For each frequency, `vis_out` = `vis` * exp(i * arg * λ), where
+ * arg = -TAU * (u * `iono_const_alpha` + v * `iono_const_beta`) and (u, v)
+ * come from `uvws[i_bl]`. Both `vis` and `vis_out` are indexed as
+ * [i_freq * num_baselines + i_bl], i.e. there is no time axis here.
+ *
+ * The baseline index is derived from the thread and block indices and is not
+ * bounds-checked; the calling kernel must have already returned for threads
+ * with i_bl >= num_baselines.
+ */
+__device__ void apply_iono(const JonesF32 *vis, JonesF32 *vis_out, const FLOAT iono_const_alpha,
+                           const FLOAT iono_const_beta, const int num_baselines, const int num_freqs, const UVW *uvws,
+                           const FLOAT *lambdas_m) {
+    const int i_bl = threadIdx.x + (blockDim.x * blockIdx.x);
+    // No need to check if this thread should continue; this is a device
+    // function.
+
+    const UVW uvw = uvws[i_bl];
+    const FLOAT arg = -TAU * (uvw.u * iono_const_alpha + uvw.v * iono_const_beta);
+
+    for (int i_freq = 0; i_freq < num_freqs; i_freq++) {
+        COMPLEX complex;
+        // The baseline UV is in units of metres, so we need to divide by λ to
+        // use it in an exponential. But we're also multiplying by λ², so just
+        // multiply by λ.
+        SINCOS(arg * lambdas_m[i_freq], &complex.y, &complex.x);
+
+        const int step = i_freq * num_baselines + i_bl;
+        vis_out[step] = vis[step] * complex;
+    }
+}
+
+/**
+ * This is extremely difficult to explain in a comment. This code is derived
+ * from this presentation
+ * (https://developer.download.nvidia.com/assets/cuda/files/reduction.pdf), but
+ * I've had to add `__syncwarp` calls otherwise the sum is incorrect.
+ */
+// Sums sdata[0..BLOCK_SIZE) (at most 64 elements) into sdata[0]. Only sdata[0]
+// is meaningful afterwards. This must be called by every thread with
+// tid < 32, so that all lanes of the first warp reach each `__syncwarp`.
+//
+// From the 32-element step onwards, only the lower half of the remaining
+// lanes does the add. A lane therefore never reads an element that another
+// lane is writing in the same step, and no lane reads past the end of `sdata`
+// when BLOCK_SIZE is small.
+//
+// NOTE(review): for HIP builds, common.cuh defines `__syncwarp` as
+// `__syncthreads`, and this function is called from inside `if (tid < 32)`,
+// so that barrier is not reached by every thread of the block -- confirm this
+// is safe on the targeted hardware.
+template <int BLOCK_SIZE> __device__ void warp_reduce(volatile JonesF64 *sdata, int tid) {
+    if (BLOCK_SIZE >= 64) {
+        // Reads sdata[32..64), writes sdata[0..32): no overlap.
+        sdata[tid] += sdata[tid + 32];
+        __syncwarp();
+    }
+    if (BLOCK_SIZE >= 32) {
+        if (tid < 16)
+            sdata[tid] += sdata[tid + 16];
+        __syncwarp();
+    }
+    if (BLOCK_SIZE >= 16) {
+        if (tid < 8)
+            sdata[tid] += sdata[tid + 8];
+        __syncwarp();
+    }
+    if (BLOCK_SIZE >= 8) {
+        if (tid < 4)
+            sdata[tid] += sdata[tid + 4];
+        __syncwarp();
+    }
+    if (BLOCK_SIZE >= 4) {
+        if (tid < 2)
+            sdata[tid] += sdata[tid + 2];
+        __syncwarp();
+    }
+    if (BLOCK_SIZE >= 2) {
+        if (tid < 1)
+            sdata[tid] += sdata[tid + 1];
+    }
+}
+
+/**
+ * Kernel to add ionospherically-related values (all baselines for a frequency).
+ *
+ * `data` is read as [i_freq * num_baselines + i_bl], with the frequency index
+ * taken from `blockIdx.x`, and must be launched with BLOCK_SIZE threads per
+ * block. The per-frequency sum is written back into `data[blockIdx.x]`.
+ *
+ * NOTE(review): the result is written in place, into a region that other
+ * thread blocks read as their input (e.g. block 0 reads
+ * data[0..num_baselines)). Nothing here orders those reads before the writes
+ * of other blocks -- confirm, or write the sums to a separate buffer.
+ */
+template <int BLOCK_SIZE> __global__ void reduce_baselines(JonesF64 *data, const int num_baselines) {
+    // Every thread has an element of shared memory. This is useful for speeding
+    // up accumulation.
+    __shared__ JonesF64 sdata[BLOCK_SIZE];
+    // tid is "thread ID".
+    int tid = threadIdx.x;
+    // This thread will start accessing data from this index. It is intended to
+    // be targeting a specific frequency (`blockIdx.x`).
+    int i = num_baselines * blockIdx.x + tid;
+
+    // Initialise the thread's shared memory.
+    sdata[tid] = JonesF64{
+        .j00_re = 0.0,
+        .j00_im = 0.0,
+        .j01_re = 0.0,
+        .j01_im = 0.0,
+        .j10_re = 0.0,
+        .j10_im = 0.0,
+        .j11_re = 0.0,
+        .j11_im = 0.0,
+    };
+
+    // Accumulate all the baselines for this frequency. `i` is incremented so
+    // that all threads can do coalesced reads.
+    while (i < num_baselines * (int)(blockIdx.x + 1)) {
+        sdata[tid] += data[i];
+        i += BLOCK_SIZE;
+    }
+    // The threads may be out of sync because some have a not-too-big index and
+    // others have a too-big index. So we sync. The following syncs are done for
+    // the same reason.
+    __syncthreads();
+
+    if (BLOCK_SIZE >= 512) {
+        if (tid < 256) {
+            sdata[tid] += sdata[tid + 256];
+        }
+        __syncthreads();
+    }
+    if (BLOCK_SIZE >= 256) {
+        if (tid < 128) {
+            sdata[tid] += sdata[tid + 128];
+        }
+        __syncthreads();
+    }
+    if (BLOCK_SIZE >= 128) {
+        if (tid < 64) {
+            sdata[tid] += sdata[tid + 64];
+        }
+        __syncthreads();
+    }
+
+    // At this point, we need to add the results for the first warp (first 32
+    // threads). We no longer need to sync threads because we don't care about
+    // thread indices >= 32.
+    if (tid < 32)
+        warp_reduce<BLOCK_SIZE>(sdata, tid);
+    // The first index has the sum; write it out so we can use it later. The
+    // index is the frequency.
+    if (tid == 0)
+        data[blockIdx.x] = sdata[0];
+}
+
+/**
+ * Kernel to add ionospherically-related values (all frequencies, after all
+ * baselines have been added in a per-frequency basis). There should only be one
+ * thread block running this kernel.
+ *
+ * `data[i_freq]` holds the per-frequency sums. After scaling by λ and summing
+ * over frequency, thread 0 adds the resulting increments to `iono_consts[0]`
+ * and `iono_consts[1]`.
+ *
+ * NOTE(review): `denom` is not checked before dividing; if it is zero (e.g.
+ * every weight is zero), the constants become non-finite -- confirm the
+ * caller handles that.
+ */
+template <int BLOCK_SIZE>
+__global__ void reduce_freqs(JonesF64 *data, const FLOAT *lambdas_m, const int num_freqs, double *iono_consts) {
+    // Every thread has an element of shared memory. This is useful for speeding
+    // up accumulation.
+    __shared__ JonesF64 sdata[BLOCK_SIZE];
+    // tid is "thread ID".
+    int tid = threadIdx.x;
+
+    // Initialise the thread's shared memory.
+    sdata[tid] = JonesF64{
+        .j00_re = 0.0,
+        .j00_im = 0.0,
+        .j01_re = 0.0,
+        .j01_im = 0.0,
+        .j10_re = 0.0,
+        .j10_im = 0.0,
+        .j11_re = 0.0,
+        .j11_im = 0.0,
+    };
+
+    // Accumulate the per-frequency ionospheric fits. `i_freq` is incremented so
+    // that all threads can do coalesced reads.
+    for (int i_freq = tid; i_freq < num_freqs; i_freq += BLOCK_SIZE) {
+        // The data we're accessing here represents all of the ionospheric
+        // values for each frequency. These values have not been scaled by λ, so
+        // we do that here. When the values were generated, UV were not scaled
+        // by λ, so below we use λ² for λ⁴, and λ for λ².
+        const double lambda = (double)lambdas_m[i_freq];
+        const double lambda_2 = lambda * lambda;
+
+        // Scale the ionospheric values by lambda.
+        JonesF64 j = data[i_freq];
+        j.j00_re *= lambda_2; // a_uu
+        j.j00_im *= lambda_2; // a_uv
+        j.j01_re *= lambda_2; // a_vv
+        j.j01_im *= -lambda;  // aa_u
+        j.j10_re *= -lambda;  // aa_v
+
+        sdata[tid] += j;
+    }
+    // The threads may be out of sync because some have a not-too-big index and
+    // others have a too-big index. So we sync. The following syncs are done for
+    // the same reason.
+    __syncthreads();
+
+    if (BLOCK_SIZE >= 512) {
+        if (tid < 256) {
+            sdata[tid] += sdata[tid + 256];
+        }
+        __syncthreads();
+    }
+    if (BLOCK_SIZE >= 256) {
+        if (tid < 128) {
+            sdata[tid] += sdata[tid + 128];
+        }
+        __syncthreads();
+    }
+    if (BLOCK_SIZE >= 128) {
+        if (tid < 64) {
+            sdata[tid] += sdata[tid + 64];
+        }
+        __syncthreads();
+    }
+
+    if (tid < 32)
+        warp_reduce<BLOCK_SIZE>(sdata, tid);
+    if (tid == 0) {
+        const double a_uu = sdata[0].j00_re;
+        const double a_uv = sdata[0].j00_im;
+        const double a_vv = sdata[0].j01_re;
+        const double aa_u = sdata[0].j01_im;
+        const double aa_v = sdata[0].j10_re;
+        // const double s_vm = sdata[0].j10_im;
+        // const double s_mm = sdata[0].j11_re;
+
+        // Not necessary, but might be useful for checking things.
+        // data[0] = sdata[0];
+
+        // The new constants are added to the existing ones.
+        const double denom = TAU * (a_uu * a_vv - a_uv * a_uv);
+        iono_consts[0] += (aa_u * a_vv - aa_v * a_uv) / denom;
+        iono_consts[1] += (aa_v * a_uu - aa_u * a_uv) / denom;
+    }
+}
+
+/**
+ * Kernel for one iteration of the ionospheric fit. Each thread (one per
+ * baseline) applies the current ionospheric constants to the model
+ * visibilities, then writes the per-frequency, per-baseline terms of the fit
+ * (a_uu, a_uv, a_vv, aa_u, aa_v, s_vm, s_mm) into `iono_fits`.
+ */ +__global__ void iono_loop_kernel(const JonesF32 *vis_residual, const float *vis_weights, const JonesF32 *vis_model, + JonesF32 *vis_model_rotated, const double *iono_consts, JonesF64 *iono_fits, + const int num_iterations, const int num_baselines, const int num_freqs, + const FLOAT *lmsts, const UVW *uvws, const FLOAT *lambdas_m) { + const int i_bl = threadIdx.x + (blockDim.x * blockIdx.x); + if (i_bl >= num_baselines) + return; + + const UVW uvw = uvws[i_bl]; + + // Apply the latest iono constants to the model visibilities. + const double iono_const_alpha = iono_consts[0]; + const double iono_const_beta = iono_consts[1]; + + // TODO: Would it be better to avoid the function call? + // TODO: Use the updated source position for the UVWs? + apply_iono(vis_model, vis_model_rotated, iono_const_alpha, iono_const_beta, num_baselines, num_freqs, uvws, + lambdas_m); + + for (int i_freq = 0; i_freq < num_freqs; i_freq++) { + // Normally, we would divide by λ to get dimensionless UV. However, UV + // are only used to determine a_uu, a_uv, a_vv, which are also scaled by + // lambda. So... don't divide by λ. + const double u = (double)uvw.u; + const double v = (double)uvw.v; + + const int step = i_freq * num_baselines + i_bl; + const double weight = (double)vis_weights[step]; + const JonesF32 *residual = &vis_residual[step]; + const double residual_i_re = residual->j00_re + residual->j11_re; + const double residual_i_im = residual->j00_im + residual->j11_im; + const JonesF32 *model = &vis_model_rotated[step]; + const double model_i_re = model->j00_re + model->j11_re; + const double model_i_im = model->j00_im + model->j11_im; + + const double mr = model_i_re * (residual_i_im - model_i_im); + const double mm = model_i_re * model_i_re; + + JonesF64 j = JonesF64{ + // Rather than multiplying by λ here, do it later, when all these + // values are added together for a single frequency. This means + // we'll have higher precision overall and fewer multiplies. 
+ .j00_re = weight * mm * u * u, // a_uu + .j00_im = weight * mm * u * v, // a_uv + .j01_re = weight * mm * v * v, // a_vv + .j01_im = weight * mr * u, // aa_u + .j10_re = weight * mr * v, // aa_v + .j10_im = weight * model_i_re * residual_i_re, // s_vm + .j11_re = weight * mm, // s_mm + .j11_im = 1.0, + }; + iono_fits[step] = j; + } +} + +__global__ void subtract_iono_kernel(JonesF32 *vis_residual, const JonesF32 *vis_model, const double iono_const_alpha, + const double iono_const_beta, const double old_iono_const_alpha, + const double old_iono_const_beta, const UVW *uvws, const FLOAT *lambdas_m, + const int num_timesteps, const int num_baselines, const int num_freqs) { + const int i_bl = threadIdx.x + (blockDim.x * blockIdx.x); + if (i_bl >= num_baselines) + return; + + for (int i_time = 0; i_time < num_timesteps; i_time++) { + const UVW uvw = uvws[i_time * num_baselines + i_bl]; + const FLOAT arg = -TAU * (uvw.u * iono_const_alpha + uvw.v * iono_const_beta); + const FLOAT old_arg = -TAU * (uvw.u * old_iono_const_alpha + uvw.v * old_iono_const_beta); + for (int i_freq = 0; i_freq < num_freqs; i_freq++) { + const FLOAT lambda = lambdas_m[i_freq]; + + COMPLEX complex; + // The baseline UV is in units of metres, so we need to divide by λ to + // use it in an exponential. But we're also multiplying by λ², so just + // multiply by λ. + SINCOS(arg * lambda, &complex.y, &complex.x); + COMPLEX old_complex; + SINCOS(old_arg * lambda, &old_complex.y, &old_complex.x); + + const int step = (i_time * num_freqs + i_freq) * num_baselines + i_bl; + JonesF32 r = vis_residual[step]; + const JonesF32 m = vis_model[step]; + + // Promoting the Jones matrices makes things demonstrably more + // precise. 
+ JonesF64 r2 = JonesF64{ + .j00_re = r.j00_re, + .j00_im = r.j00_im, + .j01_re = r.j01_re, + .j01_im = r.j01_im, + .j10_re = r.j10_re, + .j10_im = r.j10_im, + .j11_re = r.j11_re, + .j11_im = r.j11_im, + }; + JonesF64 m2 = JonesF64{ + .j00_re = m.j00_re, + .j00_im = m.j00_im, + .j01_re = m.j01_re, + .j01_im = m.j01_im, + .j10_re = m.j10_re, + .j10_im = m.j10_im, + .j11_re = m.j11_re, + .j11_im = m.j11_im, + }; + + r2 += m2 * old_complex; + r2 -= m2 * complex; + vis_residual[step] = JonesF32{ + .j00_re = (float)r2.j00_re, + .j00_im = (float)r2.j00_im, + .j01_re = (float)r2.j01_re, + .j01_im = (float)r2.j01_im, + .j10_re = (float)r2.j10_re, + .j10_im = (float)r2.j10_im, + .j11_re = (float)r2.j11_re, + .j11_im = (float)r2.j11_im, + }; + } + } +} + +__global__ void add_model_kernel(JonesF32 *vis_residual, const JonesF32 *vis_model, const FLOAT iono_const_alpha, + const FLOAT iono_const_beta, const FLOAT *lambdas_m, const UVW *uvws, + const int num_timesteps, const int num_baselines, const int num_freqs) { + const int i_bl = threadIdx.x + (blockDim.x * blockIdx.x); + if (i_bl >= num_baselines) + return; + + for (int i_time = 0; i_time < num_timesteps; i_time++) { + const UVW uvw = uvws[i_time * num_baselines + i_bl]; + const FLOAT arg = -TAU * (uvw.u * iono_const_alpha + uvw.v * iono_const_beta); + + for (int i_freq = 0; i_freq < num_freqs; i_freq++) { + COMPLEX complex; + if (iono_const_alpha == 0.0 && iono_const_beta == 0.0) { + complex.x = 1.0; + complex.y = 0.0; + } else { + // The baseline UV is in units of metres, so we need to divide + // by λ to use it in an exponential. But we're also multiplying + // by λ², so just multiply by λ. 
+ SINCOS(arg * lambdas_m[i_freq], &complex.y, &complex.x); + } + + const int step = (i_time * num_freqs + i_freq) * num_baselines + i_bl; + vis_residual[step] += vis_model[step] * complex; + } + } +} + +/* Host functions */ + +extern "C" const char *xyzs_to_uvws(const XYZ *d_xyzs, const FLOAT *d_lmsts, UVW *d_uvws, RADec pointing_centre, + int num_tiles, int num_baselines, int num_timesteps) { + dim3 gridDim, blockDim; + // Thread blocks are distributed by baseline indices. + blockDim.x = 256; + blockDim.y = 1; + blockDim.z = 1; + gridDim.x = (int)ceil((double)num_baselines / (double)blockDim.x); + gridDim.y = 1; + gridDim.z = 1; + + xyzs_to_uvws_kernel<<>>(d_xyzs, d_lmsts, d_uvws, pointing_centre, num_tiles, num_baselines, + num_timesteps); + gpuError_t error_id = gpuDeviceSynchronize(); + if (error_id != gpuSuccess) { + return gpuGetErrorString(error_id); + } + error_id = gpuGetLastError(); + if (error_id != gpuSuccess) { + return gpuGetErrorString(error_id); + } + + return NULL; +} + +extern "C" const char *rotate_average(const JonesF32 *d_high_res_vis, const float *d_high_res_weights, + JonesF32 *d_low_res_vis, RADec pointing_centre, const int num_timesteps, + const int num_tiles, const int num_baselines, const int num_freqs, + const int freq_average_factor, const FLOAT *d_lmsts, const XYZ *d_xyzs, + const UVW *d_uvws_from, UVW *d_uvws_to, const FLOAT *d_lambdas) { + dim3 gridDim, blockDim; + // Thread blocks are distributed by baseline indices. + blockDim.x = 256; + gridDim.x = (int)ceil((double)num_baselines / (double)blockDim.x); + + gpuError_t error_id; + + // Prepare the "to" UVWs. + xyzs_to_uvws_kernel<<>>(d_xyzs, d_lmsts, d_uvws_to, pointing_centre, num_tiles, num_baselines, + num_timesteps); + // This function is unlikely to fail. 
+    // error_id = gpuDeviceSynchronize();
+    // if (error_id != gpuSuccess) {
+    //     return gpuGetErrorString(error_id);
+    // }
+    // error_id = gpuGetLastError();
+    // if (error_id != gpuSuccess) {
+    //     return gpuGetErrorString(error_id);
+    // }
+
+    // Launch configuration restored from the dim3 pair set up above.
+    rotate_average_kernel<<<gridDim, blockDim>>>(
+        d_high_res_vis, d_high_res_weights, d_low_res_vis, pointing_centre, num_timesteps, num_tiles, num_baselines,
+        num_freqs, freq_average_factor, d_lmsts, d_xyzs, d_uvws_from, d_uvws_to, d_lambdas);
+    error_id = gpuDeviceSynchronize();
+    if (error_id != gpuSuccess) {
+        return gpuGetErrorString(error_id);
+    }
+    error_id = gpuGetLastError();
+    if (error_id != gpuSuccess) {
+        return gpuGetErrorString(error_id);
+    }
+
+    return NULL;
+}
+
+/**
+ * Run `num_iterations` rounds of the ionospheric-offset fit on the GPU.
+ *
+ * Each round launches `iono_loop_kernel`, then `reduce_baselines`, then
+ * `reduce_freqs`, synchronising with the device and checking for errors after
+ * every launch. `iono_const_alpha` and `iono_const_beta` are host pointers:
+ * their values seed a two-element device buffer before the first round and
+ * receive that buffer's contents after the last one. Nothing is copied back if
+ * an allocation, copy or launch fails beforehand.
+ *
+ * The temporary device buffer is released on every exit path.
+ *
+ * Returns NULL on success, otherwise the GPU runtime's error string for the
+ * first failure encountered.
+ */
+extern "C" const char *iono_loop(const JonesF32 *d_vis_residual, const float *d_vis_weights,
+                                 const JonesF32 *d_vis_model, JonesF32 *d_vis_model_rotated, JonesF64 *d_iono_fits,
+                                 double *iono_const_alpha, double *iono_const_beta, const int num_timesteps,
+                                 const int num_tiles, const int num_baselines, const int num_freqs,
+                                 const int num_iterations, const FLOAT *d_lmsts, const UVW *d_uvws,
+                                 const FLOAT *d_lambdas_m) {
+    // Thread blocks are distributed by baseline indices.
+    dim3 gridDim, blockDim;
+    blockDim.x = 256;
+    gridDim.x = (int)ceil((double)num_baselines / (double)blockDim.x);
+    // These are used to add ionospheric fits (all baselines per frequency).
+    dim3 gridDimAdd, blockDimAdd;
+    const int NUM_ADD_THREADS = 256;
+    blockDimAdd.x = NUM_ADD_THREADS;
+    gridDimAdd.x = num_freqs;
+    // These are used to accumulate the per-frequency ionospheric fits.
+    dim3 gridDimAdd2, blockDimAdd2;
+    const int NUM_ADD_THREADS2 = 256;
+    blockDimAdd2.x = NUM_ADD_THREADS2;
+    gridDimAdd2.x = 1;
+
+    // Device-side {alpha, beta} pair handed to the kernels and read back below.
+    double *d_iono_consts = NULL;
+    gpuError_t error_id = gpuMalloc(&d_iono_consts, 2 * sizeof(double));
+    if (error_id != gpuSuccess) {
+        return gpuGetErrorString(error_id);
+    }
+
+    // From here on every path must reach the gpuFree at the bottom, so a
+    // failure leaves error_id set and falls through instead of returning early.
+    error_id = gpuMemcpy(d_iono_consts, iono_const_alpha, sizeof(double), gpuMemcpyHostToDevice);
+    if (error_id == gpuSuccess) {
+        error_id = gpuMemcpy(d_iono_consts + 1, iono_const_beta, sizeof(double), gpuMemcpyHostToDevice);
+    }
+
+    for (int iteration = 0; error_id == gpuSuccess && iteration < num_iterations; iteration++) {
+        // Do the work for one loop of the iteration.
+        iono_loop_kernel<<<gridDim, blockDim>>>(d_vis_residual, d_vis_weights, d_vis_model, d_vis_model_rotated,
+                                                d_iono_consts, d_iono_fits, num_tiles, num_baselines, num_freqs,
+                                                d_lmsts, d_uvws, d_lambdas_m);
+        error_id = gpuDeviceSynchronize();
+        if (error_id == gpuSuccess) {
+            error_id = gpuGetLastError();
+        }
+        if (error_id != gpuSuccess) {
+            break;
+        }
+
+        // Sum the iono fits.
+        // NOTE(review): launch configuration reconstructed from the dim3 pair
+        // described above. If reduce_baselines is templated on the block size,
+        // the <NUM_ADD_THREADS> argument must be restored as well.
+        reduce_baselines<<<gridDimAdd, blockDimAdd>>>(d_iono_fits, num_baselines);
+        error_id = gpuDeviceSynchronize();
+        if (error_id == gpuSuccess) {
+            error_id = gpuGetLastError();
+        }
+        if (error_id != gpuSuccess) {
+            break;
+        }
+
+        // NOTE(review): as above; restore <NUM_ADD_THREADS2> if reduce_freqs
+        // is templated on the block size.
+        reduce_freqs<<<gridDimAdd2, blockDimAdd2>>>(d_iono_fits, d_lambdas_m, num_freqs, d_iono_consts);
+        error_id = gpuDeviceSynchronize();
+        if (error_id == gpuSuccess) {
+            error_id = gpuGetLastError();
+        }
+    }
+
+    // Only hand results back to the caller if every round succeeded.
+    if (error_id == gpuSuccess) {
+        error_id = gpuMemcpy(iono_const_alpha, d_iono_consts, sizeof(double), gpuMemcpyDeviceToHost);
+    }
+    if (error_id == gpuSuccess) {
+        error_id = gpuMemcpy(iono_const_beta, d_iono_consts + 1, sizeof(double), gpuMemcpyDeviceToHost);
+    }
+
+    // Best effort: a failure to free must not mask the error reported above.
+    gpuFree(d_iono_consts);
+
+    return error_id == gpuSuccess ? NULL : gpuGetErrorString(error_id);
+}
+
+/**
+ * Launch `subtract_iono_kernel` with one thread per baseline (256 threads per
+ * block) on `d_vis_residual`, passing both the new and the previous pair of
+ * ionospheric constants, and block until the device has finished.
+ *
+ * Returns NULL on success, otherwise the GPU runtime's error string.
+ */
+extern "C" const char *subtract_iono(JonesF32 *d_vis_residual, const JonesF32 *d_vis_model, double iono_const_alpha,
+                                     double iono_const_beta, double old_iono_const_alpha, double old_iono_const_beta,
+                                     const UVW *d_uvws, const FLOAT *d_lambdas_m, const int num_timesteps,
+                                     const int num_baselines, const int num_freqs) {
+    // Thread blocks are distributed by baseline indices.
+    dim3 gridDim, blockDim;
+    blockDim.x = 256;
+    gridDim.x = (int)ceil((double)num_baselines / (double)blockDim.x);
+
+    // Launch configuration restored from the dim3 pair set up above.
+    subtract_iono_kernel<<<gridDim, blockDim>>>(d_vis_residual, d_vis_model, iono_const_alpha, iono_const_beta,
+                                                old_iono_const_alpha, old_iono_const_beta, d_uvws, d_lambdas_m,
+                                                num_timesteps, num_baselines, num_freqs);
+    gpuError_t error_id = gpuDeviceSynchronize();
+    if (error_id != gpuSuccess) {
+        return gpuGetErrorString(error_id);
+    }
+    error_id = gpuGetLastError();
+    if (error_id != gpuSuccess) {
+        return gpuGetErrorString(error_id);
+    }
+
+    return NULL;
+}
+
+/**
+ * Launch `add_model_kernel` with one thread per baseline (256 threads per
+ * block) and block until the device has finished.
+ *
+ * Returns NULL on success, otherwise the GPU runtime's error string.
+ */
+extern "C" const char *add_model(JonesF32 *d_vis_residual, const JonesF32 *d_vis_model, const FLOAT iono_const_alpha,
+                                 const FLOAT iono_const_beta, const FLOAT *d_lambdas_m, const UVW *d_uvws,
+                                 const int num_timesteps, const int num_baselines, const int num_freqs) {
+    // Thread blocks are distributed by baseline indices.
+    dim3 gridDim, blockDim;
+    blockDim.x = 256;
+    gridDim.x = (int)ceil((double)num_baselines / (double)blockDim.x);
+
+    // Launch configuration restored from the dim3 pair set up above.
+    add_model_kernel<<<gridDim, blockDim>>>(d_vis_residual, d_vis_model, iono_const_alpha, iono_const_beta, d_lambdas_m,
+                                            d_uvws, num_timesteps, num_baselines, num_freqs);
+
+    gpuError_t error_id = gpuDeviceSynchronize();
+    if (error_id != gpuSuccess) {
+        return gpuGetErrorString(error_id);
+    }
+    error_id = gpuGetLastError();
+    if (error_id != gpuSuccess) {
+        return gpuGetErrorString(error_id);
+    }
+
+    return NULL;
+}
diff --git a/src/gpu/peel.h b/src/gpu/peel.h
new file mode 100644
index 00000000..8c6d9b8f
--- /dev/null
+++ b/src/gpu/peel.h
@@ -0,0 +1,38 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#pragma once
+
+#include "types.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+const char *xyzs_to_uvws(const XYZ *d_xyzs, const FLOAT *d_lmsts, UVW *d_uvws, RADec pointing_centre, int num_tiles,
+                         int num_baselines, int num_timesteps);
+
+const char *rotate_average(const JonesF32 *d_high_res_vis, const float *d_high_res_weights, JonesF32 *d_low_res_vis,
+                           RADec pointing_centre, const int num_timesteps, const int num_tiles, const int num_baselines,
+                           const int num_freqs, const int freq_average_factor, const FLOAT *d_lmsts, const XYZ *d_xyzs,
+                           const UVW *d_uvws_from, UVW *d_uvws_to, const FLOAT *d_lambdas);
+
+const char *iono_loop(const JonesF32 *d_vis_residual, const float *d_vis_weights, const JonesF32 *d_vis_model,
+                      JonesF32 *d_vis_model_rotated, JonesF64 *d_iono_fits, double *iono_const_alpha,
+                      double *iono_const_beta, const int num_timesteps, const int num_tiles, const int num_baselines,
+                      const int num_freqs, const int num_iterations, const FLOAT *d_lmsts, const UVW *d_uvws,
+                      const FLOAT *d_lambdas_m);
+
+const char *subtract_iono(JonesF32 *d_vis_residual, const JonesF32 *d_vis_model,
double iono_const_alpha, + double iono_const_beta, double old_iono_const_alpha, double old_iono_const_beta, + const UVW *d_uvws, const FLOAT *d_lambdas_m, const int num_timesteps, const int num_baselines, + const int num_freqs); + +const char *add_model(JonesF32 *d_vis_residual, const JonesF32 *d_vis_model, const FLOAT iono_const_alpha, + const FLOAT iono_const_beta, const FLOAT *d_lambdas_m, const UVW *d_uvws, + const int num_timesteps, const int num_baselines, const int num_freqs); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/src/gpu/peel_double.rs b/src/gpu/peel_double.rs new file mode 100644 index 00000000..0776fffe --- /dev/null +++ b/src/gpu/peel_double.rs @@ -0,0 +1,78 @@ +/* automatically generated by rust-bindgen 0.65.1 */ + +extern "C" { + pub fn xyzs_to_uvws( + d_xyzs: *const XYZ, + d_lmsts: *const f64, + d_uvws: *mut UVW, + pointing_centre: RADec, + num_tiles: ::std::os::raw::c_int, + num_baselines: ::std::os::raw::c_int, + num_timesteps: ::std::os::raw::c_int, + ) -> *const ::std::os::raw::c_char; +} +extern "C" { + pub fn rotate_average( + d_high_res_vis: *const JonesF32, + d_high_res_weights: *const f32, + d_low_res_vis: *mut JonesF32, + pointing_centre: RADec, + num_timesteps: ::std::os::raw::c_int, + num_tiles: ::std::os::raw::c_int, + num_baselines: ::std::os::raw::c_int, + num_freqs: ::std::os::raw::c_int, + freq_average_factor: ::std::os::raw::c_int, + d_lmsts: *const f64, + d_xyzs: *const XYZ, + d_uvws_from: *const UVW, + d_uvws_to: *mut UVW, + d_lambdas: *const f64, + ) -> *const ::std::os::raw::c_char; +} +extern "C" { + pub fn iono_loop( + d_vis_residual: *const JonesF32, + d_vis_weights: *const f32, + d_vis_model: *const JonesF32, + d_vis_model_rotated: *mut JonesF32, + d_iono_fits: *mut JonesF64, + iono_const_alpha: *mut f64, + iono_const_beta: *mut f64, + num_timesteps: ::std::os::raw::c_int, + num_tiles: ::std::os::raw::c_int, + num_baselines: ::std::os::raw::c_int, + num_freqs: ::std::os::raw::c_int, + 
num_iterations: ::std::os::raw::c_int, + d_lmsts: *const f64, + d_uvws: *const UVW, + d_lambdas_m: *const f64, + ) -> *const ::std::os::raw::c_char; +} +extern "C" { + pub fn subtract_iono( + d_vis_residual: *mut JonesF32, + d_vis_model: *const JonesF32, + iono_const_alpha: f64, + iono_const_beta: f64, + old_iono_const_alpha: f64, + old_iono_const_beta: f64, + d_uvws: *const UVW, + d_lambdas_m: *const f64, + num_timesteps: ::std::os::raw::c_int, + num_baselines: ::std::os::raw::c_int, + num_freqs: ::std::os::raw::c_int, + ) -> *const ::std::os::raw::c_char; +} +extern "C" { + pub fn add_model( + d_vis_residual: *mut JonesF32, + d_vis_model: *const JonesF32, + iono_const_alpha: f64, + iono_const_beta: f64, + d_lambdas_m: *const f64, + d_uvws: *const UVW, + num_timesteps: ::std::os::raw::c_int, + num_baselines: ::std::os::raw::c_int, + num_freqs: ::std::os::raw::c_int, + ) -> *const ::std::os::raw::c_char; +} diff --git a/src/gpu/peel_single.rs b/src/gpu/peel_single.rs new file mode 100644 index 00000000..290b6a77 --- /dev/null +++ b/src/gpu/peel_single.rs @@ -0,0 +1,78 @@ +/* automatically generated by rust-bindgen 0.65.1 */ + +extern "C" { + pub fn xyzs_to_uvws( + d_xyzs: *const XYZ, + d_lmsts: *const f32, + d_uvws: *mut UVW, + pointing_centre: RADec, + num_tiles: ::std::os::raw::c_int, + num_baselines: ::std::os::raw::c_int, + num_timesteps: ::std::os::raw::c_int, + ) -> *const ::std::os::raw::c_char; +} +extern "C" { + pub fn rotate_average( + d_high_res_vis: *const JonesF32, + d_high_res_weights: *const f32, + d_low_res_vis: *mut JonesF32, + pointing_centre: RADec, + num_timesteps: ::std::os::raw::c_int, + num_tiles: ::std::os::raw::c_int, + num_baselines: ::std::os::raw::c_int, + num_freqs: ::std::os::raw::c_int, + freq_average_factor: ::std::os::raw::c_int, + d_lmsts: *const f32, + d_xyzs: *const XYZ, + d_uvws_from: *const UVW, + d_uvws_to: *mut UVW, + d_lambdas: *const f32, + ) -> *const ::std::os::raw::c_char; +} +extern "C" { + pub fn iono_loop( + 
d_vis_residual: *const JonesF32, + d_vis_weights: *const f32, + d_vis_model: *const JonesF32, + d_vis_model_rotated: *mut JonesF32, + d_iono_fits: *mut JonesF64, + iono_const_alpha: *mut f64, + iono_const_beta: *mut f64, + num_timesteps: ::std::os::raw::c_int, + num_tiles: ::std::os::raw::c_int, + num_baselines: ::std::os::raw::c_int, + num_freqs: ::std::os::raw::c_int, + num_iterations: ::std::os::raw::c_int, + d_lmsts: *const f32, + d_uvws: *const UVW, + d_lambdas_m: *const f32, + ) -> *const ::std::os::raw::c_char; +} +extern "C" { + pub fn subtract_iono( + d_vis_residual: *mut JonesF32, + d_vis_model: *const JonesF32, + iono_const_alpha: f64, + iono_const_beta: f64, + old_iono_const_alpha: f64, + old_iono_const_beta: f64, + d_uvws: *const UVW, + d_lambdas_m: *const f32, + num_timesteps: ::std::os::raw::c_int, + num_baselines: ::std::os::raw::c_int, + num_freqs: ::std::os::raw::c_int, + ) -> *const ::std::os::raw::c_char; +} +extern "C" { + pub fn add_model( + d_vis_residual: *mut JonesF32, + d_vis_model: *const JonesF32, + iono_const_alpha: f32, + iono_const_beta: f32, + d_lambdas_m: *const f32, + d_uvws: *const UVW, + num_timesteps: ::std::os::raw::c_int, + num_baselines: ::std::os::raw::c_int, + num_freqs: ::std::os::raw::c_int, + ) -> *const ::std::os::raw::c_char; +} diff --git a/src/gpu/types.h b/src/gpu/types.h index 318236f8..4c15626b 100644 --- a/src/gpu/types.h +++ b/src/gpu/types.h @@ -29,6 +29,39 @@ const FLOAT SBF_DX = 0.01; extern "C" { #endif // __cplusplus +/** + * (right ascension, declination) coordinates. Each is in units of radians. + */ +typedef struct RADec { + // right ascension coordinate [radians] + FLOAT ra; + // declination coordinate [radians] + FLOAT dec; +} RADec; + +/** + * (hour angle, declination) coordinates. Each is in units of radians. 
+ */ +typedef struct HADec { + // hour angle coordinate [radians] + FLOAT ha; + // declination coordinate [radians] + FLOAT dec; +} HADec; + +/** + * The (x,y,z) coordinates of an antenna/tile/station. They are in units of + * metres. + */ +typedef struct XYZ { + // x coordinate [metres] + FLOAT x; + // y coordinate [metres] + FLOAT y; + // z coordinate [metres] + FLOAT z; +} XYZ; + /** * The (u,v,w) coordinates of a baseline. They are in units of metres. */ diff --git a/src/gpu/types_double.rs b/src/gpu/types_double.rs index 19b557f4..63a6352c 100644 --- a/src/gpu/types_double.rs +++ b/src/gpu/types_double.rs @@ -4,6 +4,81 @@ pub const SBF_C: f64 = 5000.0; pub const SBF_L: ::std::os::raw::c_int = 10001; pub const SBF_N: ::std::os::raw::c_int = 101; pub const SBF_DX: f64 = 0.01; +#[doc = " (right ascension, declination) coordinates. Each is in units of radians."] +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct RADec { + pub ra: f64, + pub dec: f64, +} +#[test] +fn bindgen_test_layout_RADec() { + const UNINIT: ::std::mem::MaybeUninit = ::std::mem::MaybeUninit::uninit(); + let ptr = UNINIT.as_ptr(); + assert_eq!( + ::std::mem::size_of::(), + 16usize, + concat!("Size of: ", stringify!(RADec)) + ); + assert_eq!( + ::std::mem::align_of::(), + 8usize, + concat!("Alignment of ", stringify!(RADec)) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).ra) as usize - ptr as usize }, + 0usize, + concat!("Offset of field: ", stringify!(RADec), "::", stringify!(ra)) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).dec) as usize - ptr as usize }, + 8usize, + concat!( + "Offset of field: ", + stringify!(RADec), + "::", + stringify!(dec) + ) + ); +} +#[doc = " The (x,y,z) coordinates of an antenna/tile/station. 
They are in units of\n metres."] +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct XYZ { + pub x: f64, + pub y: f64, + pub z: f64, +} +#[test] +fn bindgen_test_layout_XYZ() { + const UNINIT: ::std::mem::MaybeUninit = ::std::mem::MaybeUninit::uninit(); + let ptr = UNINIT.as_ptr(); + assert_eq!( + ::std::mem::size_of::(), + 24usize, + concat!("Size of: ", stringify!(XYZ)) + ); + assert_eq!( + ::std::mem::align_of::(), + 8usize, + concat!("Alignment of ", stringify!(XYZ)) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).x) as usize - ptr as usize }, + 0usize, + concat!("Offset of field: ", stringify!(XYZ), "::", stringify!(x)) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).y) as usize - ptr as usize }, + 8usize, + concat!("Offset of field: ", stringify!(XYZ), "::", stringify!(y)) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).z) as usize - ptr as usize }, + 16usize, + concat!("Offset of field: ", stringify!(XYZ), "::", stringify!(z)) + ); +} #[doc = " The (u,v,w) coordinates of a baseline. They are in units of metres."] #[repr(C)] #[derive(Debug, Default, Copy, Clone, PartialEq)] diff --git a/src/gpu/types_single.rs b/src/gpu/types_single.rs index 0c9f7497..7a21f632 100644 --- a/src/gpu/types_single.rs +++ b/src/gpu/types_single.rs @@ -4,6 +4,81 @@ pub const SBF_C: f32 = 5000.0; pub const SBF_L: ::std::os::raw::c_int = 10001; pub const SBF_N: ::std::os::raw::c_int = 101; pub const SBF_DX: f32 = 0.009999999776482582; +#[doc = " (right ascension, declination) coordinates. 
Each is in units of radians."] +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct RADec { + pub ra: f32, + pub dec: f32, +} +#[test] +fn bindgen_test_layout_RADec() { + const UNINIT: ::std::mem::MaybeUninit = ::std::mem::MaybeUninit::uninit(); + let ptr = UNINIT.as_ptr(); + assert_eq!( + ::std::mem::size_of::(), + 8usize, + concat!("Size of: ", stringify!(RADec)) + ); + assert_eq!( + ::std::mem::align_of::(), + 4usize, + concat!("Alignment of ", stringify!(RADec)) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).ra) as usize - ptr as usize }, + 0usize, + concat!("Offset of field: ", stringify!(RADec), "::", stringify!(ra)) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).dec) as usize - ptr as usize }, + 4usize, + concat!( + "Offset of field: ", + stringify!(RADec), + "::", + stringify!(dec) + ) + ); +} +#[doc = " The (x,y,z) coordinates of an antenna/tile/station. They are in units of\n metres."] +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct XYZ { + pub x: f32, + pub y: f32, + pub z: f32, +} +#[test] +fn bindgen_test_layout_XYZ() { + const UNINIT: ::std::mem::MaybeUninit = ::std::mem::MaybeUninit::uninit(); + let ptr = UNINIT.as_ptr(); + assert_eq!( + ::std::mem::size_of::(), + 12usize, + concat!("Size of: ", stringify!(XYZ)) + ); + assert_eq!( + ::std::mem::align_of::(), + 4usize, + concat!("Alignment of ", stringify!(XYZ)) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).x) as usize - ptr as usize }, + 0usize, + concat!("Offset of field: ", stringify!(XYZ), "::", stringify!(x)) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).y) as usize - ptr as usize }, + 4usize, + concat!("Offset of field: ", stringify!(XYZ), "::", stringify!(y)) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).z) as usize - ptr as usize }, + 8usize, + concat!("Offset of field: ", stringify!(XYZ), "::", stringify!(z)) + ); +} #[doc = " The (u,v,w) coordinates of a baseline. 
They are in units of metres."] #[repr(C)] #[derive(Debug, Default, Copy, Clone, PartialEq)] diff --git a/src/gpu/update_rust_bindings.sh b/src/gpu/update_rust_bindings.sh index 5070e37d..2de36af6 100755 --- a/src/gpu/update_rust_bindings.sh +++ b/src/gpu/update_rust_bindings.sh @@ -20,6 +20,8 @@ for PRECISION in SINGLE DOUBLE; do bindgen "${SCRIPTPATH}"/types.h \ --ignore-functions \ --blocklist-type "__int8_t" \ + --allowlist-type "RADec" \ + --allowlist-type "XYZ" \ --allowlist-type "UVW" \ --allowlist-type "LmnRime" \ --allowlist-type "ShapeletCoeff" \ @@ -42,4 +44,14 @@ for PRECISION in SINGLE DOUBLE; do --blocklist-type ".*" \ -- -D "${PRECISION}" \ > "${SCRIPTPATH}/model_${LOWER_CASE}.rs" + + bindgen "${SCRIPTPATH}"/peel.h \ + --allowlist-function "rotate_average" \ + --allowlist-function "xyzs_to_uvws" \ + --allowlist-function "iono_loop" \ + --allowlist-function "subtract_iono" \ + --allowlist-function "add_model" \ + --blocklist-type ".*" \ + -- -D "${PRECISION}" \ + > "${SCRIPTPATH}/peel_${LOWER_CASE}.rs" done diff --git a/src/io/write/mod.rs b/src/io/write/mod.rs index 499cc103..ecffdb8f 100644 --- a/src/io/write/mod.rs +++ b/src/io/write/mod.rs @@ -429,6 +429,7 @@ pub(crate) fn write_vis( ..vis_ctx.clone() }; + trace!("this_timeblock.range: {:?}", this_timeblock.range); for vis_writer in writers.iter_mut() { vis_writer.write_vis( out_data_tfb.slice(s![0..this_timeblock.range.len(), .., ..]), diff --git a/src/model/cpu.rs b/src/model/cpu.rs index f7af4c71..6f4eee07 100644 --- a/src/model/cpu.rs +++ b/src/model/cpu.rs @@ -27,38 +27,38 @@ use crate::{ constants::*, context::Polarisations, model::mask_pols, - srclist::{ComponentList, GaussianParams, PerComponentParams, SourceList}, + srclist::{ComponentList, GaussianParams, PerComponentParams, Source, SourceList}, }; const GAUSSIAN_EXP_CONST: f64 = -(FRAC_PI_2 * FRAC_PI_2) / LN_2; const SHAPELET_CONST: f64 = SQRT_FRAC_PI_SQ_2_LN_2 / shapelets::SBF_DX; pub struct SkyModellerCpu<'a> { - pub(super) beam: &'a 
dyn Beam, + pub(crate) beam: &'a dyn Beam, /// The phase centre used for all modelling. - pub(super) phase_centre: RADec, + pub(crate) phase_centre: RADec, /// The longitude of the array we're using \[radians\]. - pub(super) array_longitude: f64, + pub(crate) array_longitude: f64, /// The latitude of the array we're using \[radians\]. - pub(super) array_latitude: f64, + pub(crate) array_latitude: f64, /// The UT1 - UTC offset. If this is 0, effectively UT1 == UTC, which is a /// wrong assumption by up to 0.9s. We assume the this value does not change /// over the timestamps given to this `SkyModellerCpu`. - pub(super) dut1: Duration, + pub(crate) dut1: Duration, /// Shift baselines and LSTs back to J2000. - pub(super) precess: bool, + pub(crate) precess: bool, - pub(super) unflagged_fine_chan_freqs: &'a [f64], + pub(crate) unflagged_fine_chan_freqs: &'a [f64], /// The [`XyzGeodetic`] positions of each of the unflagged tiles. - pub(super) unflagged_tile_xyzs: &'a [XyzGeodetic], - pub(super) flagged_tiles: &'a HashSet, - pub(super) unflagged_baseline_to_tile_map: HashMap, + pub(crate) unflagged_tile_xyzs: &'a [XyzGeodetic], + pub(crate) flagged_tiles: &'a HashSet, + pub(crate) unflagged_baseline_to_tile_map: HashMap, - pub(super) components: ComponentList, + pub(crate) components: ComponentList, - pub(super) pols: Polarisations, + pub(crate) pols: Polarisations, } impl<'a> SkyModellerCpu<'a> { @@ -76,7 +76,14 @@ impl<'a> SkyModellerCpu<'a> { dut1: Duration, apply_precession: bool, ) -> SkyModellerCpu<'a> { - let components = ComponentList::new(source_list, unflagged_fine_chan_freqs, phase_centre); + let components = ComponentList::new( + source_list + .values() + .rev() + .flat_map(|src| src.components.iter()), + unflagged_fine_chan_freqs, + phase_centre, + ); let maps = crate::math::TileBaselineFlags::new( unflagged_tile_xyzs.len() + flagged_tiles.len(), flagged_tiles.clone(), @@ -727,4 +734,18 @@ impl<'a> super::SkyModeller<'a> for SkyModellerCpu<'a> { Ok(uvws) } + 
+ fn update_with_a_source( + &mut self, + source: &Source, + phase_centre: RADec, + ) -> Result<(), ModelError> { + self.phase_centre = phase_centre; + self.components = ComponentList::new( + source.components.iter(), + self.unflagged_fine_chan_freqs, + phase_centre, + ); + Ok(()) + } } diff --git a/src/model/gpu.rs b/src/model/gpu.rs index 54a84efc..29643c7a 100644 --- a/src/model/gpu.rs +++ b/src/model/gpu.rs @@ -21,7 +21,8 @@ use crate::{ context::Polarisations, gpu::{self, gpu_kernel_call, DevicePointer, GpuError, GpuFloat, GpuJones}, srclist::{ - get_instrumental_flux_densities, ComponentType, FluxDensityType, ShapeletCoeff, SourceList, + get_instrumental_flux_densities, ComponentType, FluxDensityType, ShapeletCoeff, Source, + SourceComponent, SourceList, }, }; @@ -63,6 +64,7 @@ pub struct SkyModellerGpu<'a> { /// perhaps 33. tile_index_to_unflagged_tile_index_map: DevicePointer, + freqs: &'a [f64], d_freqs: DevicePointer, d_shapelet_basis_values: DevicePointer, @@ -133,6 +135,10 @@ pub struct SkyModellerGpu<'a> { shapelet_list_coeff_lens: DevicePointer, } +// You're not re-using pointers after they've been sent to another thread, +// right? +unsafe impl<'a> Send for SkyModellerGpu<'a> {} + impl<'a> SkyModellerGpu<'a> { /// Given a source list, split the components into each component type (e.g. /// points, shapelets) and by each flux density type (e.g. list, power law), @@ -146,7 +152,7 @@ impl<'a> SkyModellerGpu<'a> { source_list: &SourceList, pols: Polarisations, unflagged_tile_xyzs: &'a [XyzGeodetic], - unflagged_fine_chan_freqs: &[f64], + unflagged_fine_chan_freqs: &'a [f64], flagged_tiles: &HashSet, phase_centre: RADec, array_longitude_rad: f64, @@ -154,47 +160,177 @@ impl<'a> SkyModellerGpu<'a> { dut1: Duration, apply_precession: bool, ) -> Result, ModelError> { - let mut point_power_law_radecs: Vec = vec![]; + // Variables for CUDA/HIP. They're made flexible in their types for + // whichever precision is being used. 
+ let (unflagged_fine_chan_freqs_ints, unflagged_fine_chan_freqs_floats): (Vec<_>, Vec<_>) = + unflagged_fine_chan_freqs + .iter() + .map(|&f| (f as u32, f as GpuFloat)) + .unzip(); + let shapelet_basis_values: Vec = shapelets::SHAPELET_BASIS_VALUES + .iter() + .map(|&f| f as GpuFloat) + .collect(); + + let num_baselines = (unflagged_tile_xyzs.len() * (unflagged_tile_xyzs.len() - 1)) / 2; + let num_freqs = unflagged_fine_chan_freqs.len(); + + let d_freqs = DevicePointer::copy_to_device(&unflagged_fine_chan_freqs_floats)?; + let d_shapelet_basis_values = DevicePointer::copy_to_device(&shapelet_basis_values)?; + + let mut tile_index_to_unflagged_tile_index_map: Vec = + Vec::with_capacity(unflagged_tile_xyzs.len()); + let mut i_unflagged_tile = 0; + for i_tile in 0..unflagged_tile_xyzs.len() + flagged_tiles.len() { + if flagged_tiles.contains(&i_tile) { + i_unflagged_tile += 1; + continue; + } + tile_index_to_unflagged_tile_index_map.push(i_unflagged_tile); + i_unflagged_tile += 1; + } + let d_tile_index_to_unflagged_tile_index_map = + DevicePointer::copy_to_device(&tile_index_to_unflagged_tile_index_map)?; + + let mut modeller = SkyModellerGpu { + gpu_beam: beam.prepare_gpu_beam(&unflagged_fine_chan_freqs_ints)?, + + phase_centre, + array_longitude: array_longitude_rad, + array_latitude: array_latitude_rad, + dut1, + precess: apply_precession, + + unflagged_tile_xyzs, + num_baselines: num_baselines.try_into().expect("not bigger than i32::MAX"), + num_freqs: num_freqs.try_into().expect("not bigger than i32::MAX"), + + pols, + + tile_index_to_unflagged_tile_index_map: d_tile_index_to_unflagged_tile_index_map, + + freqs: unflagged_fine_chan_freqs, + d_freqs, + d_shapelet_basis_values, + + point_power_law_radecs: vec![], + point_power_law_lmns: DevicePointer::default(), + point_power_law_fds: DevicePointer::default(), + point_power_law_sis: DevicePointer::default(), + + point_curved_power_law_radecs: vec![], + point_curved_power_law_lmns: DevicePointer::default(), + 
point_curved_power_law_fds: DevicePointer::default(), + point_curved_power_law_sis: DevicePointer::default(), + point_curved_power_law_qs: DevicePointer::default(), + + point_list_radecs: vec![], + point_list_lmns: DevicePointer::default(), + point_list_fds: DevicePointer::default(), + + gaussian_power_law_radecs: vec![], + gaussian_power_law_lmns: DevicePointer::default(), + gaussian_power_law_fds: DevicePointer::default(), + gaussian_power_law_sis: DevicePointer::default(), + gaussian_power_law_gps: DevicePointer::default(), + + gaussian_curved_power_law_radecs: vec![], + gaussian_curved_power_law_lmns: DevicePointer::default(), + gaussian_curved_power_law_fds: DevicePointer::default(), + gaussian_curved_power_law_sis: DevicePointer::default(), + gaussian_curved_power_law_qs: DevicePointer::default(), + gaussian_curved_power_law_gps: DevicePointer::default(), + + gaussian_list_radecs: vec![], + gaussian_list_lmns: DevicePointer::default(), + gaussian_list_fds: DevicePointer::default(), + gaussian_list_gps: DevicePointer::default(), + + shapelet_power_law_radecs: vec![], + shapelet_power_law_lmns: DevicePointer::default(), + shapelet_power_law_fds: DevicePointer::default(), + shapelet_power_law_sis: DevicePointer::default(), + shapelet_power_law_gps: DevicePointer::default(), + shapelet_power_law_coeffs: DevicePointer::default(), + shapelet_power_law_coeff_lens: DevicePointer::default(), + + shapelet_curved_power_law_radecs: vec![], + shapelet_curved_power_law_lmns: DevicePointer::default(), + shapelet_curved_power_law_fds: DevicePointer::default(), + shapelet_curved_power_law_sis: DevicePointer::default(), + shapelet_curved_power_law_qs: DevicePointer::default(), + shapelet_curved_power_law_gps: DevicePointer::default(), + shapelet_curved_power_law_coeffs: DevicePointer::default(), + shapelet_curved_power_law_coeff_lens: DevicePointer::default(), + + shapelet_list_radecs: vec![], + shapelet_list_lmns: DevicePointer::default(), + shapelet_list_fds: 
DevicePointer::default(), + shapelet_list_gps: DevicePointer::default(), + shapelet_list_coeffs: DevicePointer::default(), + shapelet_list_coeff_lens: DevicePointer::default(), + }; + modeller.update_source_list( + source_list + .values() + .rev() + .flat_map(|src| src.components.iter()), + phase_centre, + )?; + Ok(modeller) + } + + fn update_source_list<'b, I>( + &mut self, + components: I, + phase_centre: RADec, + ) -> Result<(), ModelError> + where + I: IntoIterator, + { + self.phase_centre = phase_centre; + + self.point_power_law_radecs.clear(); let mut point_power_law_lmns: Vec = vec![]; let mut point_power_law_fds: Vec<_> = vec![]; let mut point_power_law_sis: Vec<_> = vec![]; - let mut point_curved_power_law_radecs: Vec = vec![]; + self.point_curved_power_law_radecs.clear(); let mut point_curved_power_law_lmns: Vec = vec![]; let mut point_curved_power_law_fds: Vec<_> = vec![]; let mut point_curved_power_law_sis: Vec<_> = vec![]; let mut point_curved_power_law_qs: Vec<_> = vec![]; - let mut point_list_radecs: Vec = vec![]; + self.point_list_radecs.clear(); let mut point_list_lmns: Vec = vec![]; let mut point_list_fds: Vec<&FluxDensityType> = vec![]; - let mut gaussian_power_law_radecs: Vec = vec![]; + self.gaussian_power_law_radecs.clear(); let mut gaussian_power_law_lmns: Vec = vec![]; let mut gaussian_power_law_fds: Vec<_> = vec![]; let mut gaussian_power_law_sis: Vec<_> = vec![]; let mut gaussian_power_law_gps: Vec = vec![]; - let mut gaussian_curved_power_law_radecs: Vec = vec![]; + self.gaussian_curved_power_law_radecs.clear(); let mut gaussian_curved_power_law_lmns: Vec = vec![]; let mut gaussian_curved_power_law_fds: Vec<_> = vec![]; let mut gaussian_curved_power_law_sis: Vec<_> = vec![]; let mut gaussian_curved_power_law_qs: Vec<_> = vec![]; let mut gaussian_curved_power_law_gps: Vec = vec![]; - let mut gaussian_list_radecs: Vec = vec![]; + self.gaussian_list_radecs.clear(); let mut gaussian_list_lmns: Vec = vec![]; let mut gaussian_list_fds: 
Vec<&FluxDensityType> = vec![]; let mut gaussian_list_gps: Vec = vec![]; - let mut shapelet_power_law_radecs: Vec = vec![]; + self.shapelet_power_law_radecs.clear(); let mut shapelet_power_law_lmns: Vec = vec![]; let mut shapelet_power_law_fds: Vec<_> = vec![]; let mut shapelet_power_law_sis: Vec<_> = vec![]; let mut shapelet_power_law_gps: Vec = vec![]; let mut shapelet_power_law_coeffs: Vec<&[ShapeletCoeff]> = vec![]; - let mut shapelet_curved_power_law_radecs: Vec = vec![]; + self.shapelet_curved_power_law_radecs.clear(); let mut shapelet_curved_power_law_lmns: Vec = vec![]; let mut shapelet_curved_power_law_fds: Vec<_> = vec![]; let mut shapelet_curved_power_law_sis: Vec<_> = vec![]; @@ -202,7 +338,7 @@ impl<'a> SkyModellerGpu<'a> { let mut shapelet_curved_power_law_gps: Vec = vec![]; let mut shapelet_curved_power_law_coeffs: Vec<&[ShapeletCoeff]> = vec![]; - let mut shapelet_list_radecs: Vec = vec![]; + self.shapelet_list_radecs.clear(); let mut shapelet_list_lmns: Vec = vec![]; let mut shapelet_list_fds: Vec<&FluxDensityType> = vec![]; let mut shapelet_list_gps: Vec = vec![]; @@ -228,13 +364,9 @@ impl<'a> SkyModellerGpu<'a> { // float starting from the brightest component means that the // floating-point precision errors are greater as we work through the // source list. - for comp in source_list - .iter() - .rev() - .flat_map(|(_, src)| src.components.iter()) - { + for comp in components { let radec = comp.radec; - let LmnRime { l, m, n } = comp.radec.to_lmn(phase_centre).prepare_for_rime(); + let LmnRime { l, m, n } = radec.to_lmn(phase_centre).prepare_for_rime(); let lmn = gpu::LmnRime { l: l as GpuFloat, m: m as GpuFloat, @@ -243,7 +375,7 @@ impl<'a> SkyModellerGpu<'a> { match &comp.comp_type { ComponentType::Point => match comp.flux_type { FluxDensityType::PowerLaw { si, .. 
} => { - point_power_law_radecs.push(radec); + self.point_power_law_radecs.push(radec); point_power_law_lmns.push(lmn); let fd_at_150mhz = comp.estimate_at_freq(gpu::POWER_LAW_FD_REF_FREQ as _); let inst_fd: Jones = fd_at_150mhz.to_inst_stokes(); @@ -253,7 +385,7 @@ impl<'a> SkyModellerGpu<'a> { } FluxDensityType::CurvedPowerLaw { si, q, .. } => { - point_curved_power_law_radecs.push(radec); + self.point_curved_power_law_radecs.push(radec); point_curved_power_law_lmns.push(lmn); let fd_at_150mhz = comp.estimate_at_freq(gpu::POWER_LAW_FD_REF_FREQ as _); let inst_fd: Jones = fd_at_150mhz.to_inst_stokes(); @@ -264,7 +396,7 @@ impl<'a> SkyModellerGpu<'a> { } FluxDensityType::List { .. } => { - point_list_radecs.push(radec); + self.point_list_radecs.push(radec); point_list_lmns.push(lmn); point_list_fds.push(&comp.flux_type); } @@ -278,7 +410,7 @@ impl<'a> SkyModellerGpu<'a> { }; match comp.flux_type { FluxDensityType::PowerLaw { si, .. } => { - gaussian_power_law_radecs.push(radec); + self.gaussian_power_law_radecs.push(radec); gaussian_power_law_lmns.push(lmn); let fd_at_150mhz = comp.estimate_at_freq(gpu::POWER_LAW_FD_REF_FREQ as _); @@ -290,7 +422,7 @@ impl<'a> SkyModellerGpu<'a> { } FluxDensityType::CurvedPowerLaw { si, q, .. } => { - gaussian_curved_power_law_radecs.push(radec); + self.gaussian_curved_power_law_radecs.push(radec); gaussian_curved_power_law_lmns.push(lmn); let fd_at_150mhz = comp.estimate_at_freq(gpu::POWER_LAW_FD_REF_FREQ as _); @@ -303,7 +435,7 @@ impl<'a> SkyModellerGpu<'a> { } FluxDensityType::List { .. } => { - gaussian_list_radecs.push(radec); + self.gaussian_list_radecs.push(radec); gaussian_list_lmns.push(lmn); gaussian_list_fds.push(&comp.flux_type); gaussian_list_gps.push(gp); @@ -324,7 +456,7 @@ impl<'a> SkyModellerGpu<'a> { }; match comp.flux_type { FluxDensityType::PowerLaw { si, .. 
} => { - shapelet_power_law_radecs.push(radec); + self.shapelet_power_law_radecs.push(radec); shapelet_power_law_lmns.push(lmn); let fd_at_150mhz = comp .flux_type @@ -338,7 +470,7 @@ impl<'a> SkyModellerGpu<'a> { } FluxDensityType::CurvedPowerLaw { si, q, .. } => { - shapelet_curved_power_law_radecs.push(radec); + self.shapelet_curved_power_law_radecs.push(radec); shapelet_curved_power_law_lmns.push(lmn); let fd_at_150mhz = comp.estimate_at_freq(gpu::POWER_LAW_FD_REF_FREQ as _); @@ -352,7 +484,7 @@ impl<'a> SkyModellerGpu<'a> { } FluxDensityType::List { .. } => { - shapelet_list_radecs.push(radec); + self.shapelet_list_radecs.push(radec); shapelet_list_lmns.push(lmn); shapelet_list_fds.push(&comp.flux_type); shapelet_list_gps.push(gp); @@ -364,14 +496,11 @@ impl<'a> SkyModellerGpu<'a> { } let point_list_fds = - get_instrumental_flux_densities(&point_list_fds, unflagged_fine_chan_freqs) - .mapv(jones_to_gpu_jones); - let gaussian_list_fds = - get_instrumental_flux_densities(&gaussian_list_fds, unflagged_fine_chan_freqs) - .mapv(jones_to_gpu_jones); - let shapelet_list_fds = - get_instrumental_flux_densities(&shapelet_list_fds, unflagged_fine_chan_freqs) - .mapv(jones_to_gpu_jones); + get_instrumental_flux_densities(&point_list_fds, self.freqs).mapv(jones_to_gpu_jones); + let gaussian_list_fds = get_instrumental_flux_densities(&gaussian_list_fds, self.freqs) + .mapv(jones_to_gpu_jones); + let shapelet_list_fds = get_instrumental_flux_densities(&shapelet_list_fds, self.freqs) + .mapv(jones_to_gpu_jones); let (shapelet_power_law_coeffs, shapelet_power_law_coeff_lens) = get_flattened_coeffs(shapelet_power_law_coeffs); @@ -380,149 +509,85 @@ impl<'a> SkyModellerGpu<'a> { let (shapelet_list_coeffs, shapelet_list_coeff_lens) = get_flattened_coeffs(shapelet_list_coeffs); - // Variables for CUDA/HIP. They're made flexible in their types for - // whichever precision is being used. 
- let (unflagged_fine_chan_freqs_ints, unflagged_fine_chan_freqs_floats): (Vec<_>, Vec<_>) = - unflagged_fine_chan_freqs - .iter() - .map(|&f| (f as u32, f as GpuFloat)) - .unzip(); - let shapelet_basis_values: Vec = shapelets::SHAPELET_BASIS_VALUES - .iter() - .map(|&f| f as GpuFloat) - .collect(); - - let num_baselines = (unflagged_tile_xyzs.len() * (unflagged_tile_xyzs.len() - 1)) / 2; - let num_freqs = unflagged_fine_chan_freqs.len(); - - let d_freqs = DevicePointer::copy_to_device(&unflagged_fine_chan_freqs_floats)?; - let d_shapelet_basis_values = DevicePointer::copy_to_device(&shapelet_basis_values)?; - - let mut tile_index_to_unflagged_tile_index_map: Vec = - Vec::with_capacity(unflagged_tile_xyzs.len()); - let mut i_unflagged_tile = 0; - for i_tile in 0..unflagged_tile_xyzs.len() + flagged_tiles.len() { - if flagged_tiles.contains(&i_tile) { - i_unflagged_tile += 1; - continue; - } - tile_index_to_unflagged_tile_index_map.push(i_unflagged_tile); - i_unflagged_tile += 1; - } - let d_tile_index_to_unflagged_tile_index_map = - DevicePointer::copy_to_device(&tile_index_to_unflagged_tile_index_map)?; - - Ok(SkyModellerGpu { - gpu_beam: beam.prepare_gpu_beam(&unflagged_fine_chan_freqs_ints)?, + self.point_power_law_lmns.overwrite(&point_power_law_lmns)?; + self.point_power_law_fds.overwrite(&point_power_law_fds)?; + self.point_power_law_sis.overwrite(&point_power_law_sis)?; + + self.point_curved_power_law_lmns + .overwrite(&point_curved_power_law_lmns)?; + self.point_curved_power_law_fds + .overwrite(&point_curved_power_law_fds)?; + self.point_curved_power_law_sis + .overwrite(&point_curved_power_law_sis)?; + self.point_curved_power_law_qs + .overwrite(&point_curved_power_law_qs)?; + + self.point_list_lmns.overwrite(&point_list_lmns)?; + self.point_list_fds + .overwrite(point_list_fds.as_slice().unwrap())?; + + self.gaussian_power_law_lmns + .overwrite(&gaussian_power_law_lmns)?; + self.gaussian_power_law_fds + .overwrite(&gaussian_power_law_fds)?; + 
self.gaussian_power_law_sis + .overwrite(&gaussian_power_law_sis)?; + self.gaussian_power_law_gps + .overwrite(&gaussian_power_law_gps)?; + + self.gaussian_curved_power_law_lmns + .overwrite(&gaussian_curved_power_law_lmns)?; + self.gaussian_curved_power_law_fds + .overwrite(&gaussian_curved_power_law_fds)?; + self.gaussian_curved_power_law_sis + .overwrite(&gaussian_curved_power_law_sis)?; + self.gaussian_curved_power_law_qs + .overwrite(&gaussian_curved_power_law_qs)?; + self.gaussian_curved_power_law_gps + .overwrite(&gaussian_curved_power_law_gps)?; + + self.gaussian_list_lmns.overwrite(&gaussian_list_lmns)?; + self.gaussian_list_fds + .overwrite(gaussian_list_fds.as_slice().unwrap())?; + self.gaussian_list_gps.overwrite(&gaussian_list_gps)?; + + self.shapelet_power_law_lmns + .overwrite(&shapelet_power_law_lmns)?; + self.shapelet_power_law_fds + .overwrite(&shapelet_power_law_fds)?; + self.shapelet_power_law_sis + .overwrite(&shapelet_power_law_sis)?; + self.shapelet_power_law_gps + .overwrite(&shapelet_power_law_gps)?; + self.shapelet_power_law_coeffs + .overwrite(&shapelet_power_law_coeffs)?; + self.shapelet_power_law_coeff_lens + .overwrite(&shapelet_power_law_coeff_lens)?; + + self.shapelet_curved_power_law_lmns + .overwrite(&shapelet_curved_power_law_lmns)?; + self.shapelet_curved_power_law_fds + .overwrite(&shapelet_curved_power_law_fds)?; + self.shapelet_curved_power_law_sis + .overwrite(&shapelet_curved_power_law_sis)?; + self.shapelet_curved_power_law_qs + .overwrite(&shapelet_curved_power_law_qs)?; + self.shapelet_curved_power_law_gps + .overwrite(&shapelet_curved_power_law_gps)?; + self.shapelet_curved_power_law_coeffs + .overwrite(&shapelet_curved_power_law_coeffs)?; + self.shapelet_curved_power_law_coeff_lens + .overwrite(&shapelet_curved_power_law_coeff_lens)?; + + self.shapelet_list_lmns.overwrite(&shapelet_list_lmns)?; + self.shapelet_list_fds + .overwrite(shapelet_list_fds.as_slice().unwrap())?; + 
self.shapelet_list_gps.overwrite(&shapelet_list_gps)?; + self.shapelet_list_coeffs.overwrite(&shapelet_list_coeffs)?; + self.shapelet_list_coeff_lens + .overwrite(&shapelet_list_coeff_lens)?; - phase_centre, - array_longitude: array_longitude_rad, - array_latitude: array_latitude_rad, - dut1, - precess: apply_precession, - - unflagged_tile_xyzs, - num_baselines: num_baselines.try_into().expect("not bigger than i32::MAX"), - num_freqs: num_freqs.try_into().expect("not bigger than i32::MAX"), - - pols, - - tile_index_to_unflagged_tile_index_map: d_tile_index_to_unflagged_tile_index_map, - - d_freqs, - d_shapelet_basis_values, - - point_power_law_radecs, - point_power_law_lmns: DevicePointer::copy_to_device(&point_power_law_lmns)?, - point_power_law_fds: DevicePointer::copy_to_device(&point_power_law_fds)?, - point_power_law_sis: DevicePointer::copy_to_device(&point_power_law_sis)?, - - point_curved_power_law_radecs, - point_curved_power_law_lmns: DevicePointer::copy_to_device( - &point_curved_power_law_lmns, - )?, - point_curved_power_law_fds: DevicePointer::copy_to_device(&point_curved_power_law_fds)?, - point_curved_power_law_sis: DevicePointer::copy_to_device(&point_curved_power_law_sis)?, - point_curved_power_law_qs: DevicePointer::copy_to_device(&point_curved_power_law_qs)?, - - point_list_radecs, - point_list_lmns: DevicePointer::copy_to_device(&point_list_lmns)?, - point_list_fds: DevicePointer::copy_to_device( - point_list_fds.as_slice().expect("is contiguous"), - )?, - - gaussian_power_law_radecs, - gaussian_power_law_lmns: DevicePointer::copy_to_device(&gaussian_power_law_lmns)?, - gaussian_power_law_fds: DevicePointer::copy_to_device(&gaussian_power_law_fds)?, - gaussian_power_law_sis: DevicePointer::copy_to_device(&gaussian_power_law_sis)?, - gaussian_power_law_gps: DevicePointer::copy_to_device(&gaussian_power_law_gps)?, - - gaussian_curved_power_law_radecs, - gaussian_curved_power_law_lmns: DevicePointer::copy_to_device( - 
&gaussian_curved_power_law_lmns, - )?, - gaussian_curved_power_law_fds: DevicePointer::copy_to_device( - &gaussian_curved_power_law_fds, - )?, - gaussian_curved_power_law_sis: DevicePointer::copy_to_device( - &gaussian_curved_power_law_sis, - )?, - gaussian_curved_power_law_qs: DevicePointer::copy_to_device( - &gaussian_curved_power_law_qs, - )?, - gaussian_curved_power_law_gps: DevicePointer::copy_to_device( - &gaussian_curved_power_law_gps, - )?, - - gaussian_list_radecs, - gaussian_list_lmns: DevicePointer::copy_to_device(&gaussian_list_lmns)?, - gaussian_list_fds: DevicePointer::copy_to_device( - gaussian_list_fds.as_slice().expect("is contiguous"), - )?, - gaussian_list_gps: DevicePointer::copy_to_device(&gaussian_list_gps)?, - - shapelet_power_law_radecs, - shapelet_power_law_lmns: DevicePointer::copy_to_device(&shapelet_power_law_lmns)?, - shapelet_power_law_fds: DevicePointer::copy_to_device(&shapelet_power_law_fds)?, - shapelet_power_law_sis: DevicePointer::copy_to_device(&shapelet_power_law_sis)?, - shapelet_power_law_gps: DevicePointer::copy_to_device(&shapelet_power_law_gps)?, - shapelet_power_law_coeffs: DevicePointer::copy_to_device(&shapelet_power_law_coeffs)?, - shapelet_power_law_coeff_lens: DevicePointer::copy_to_device( - &shapelet_power_law_coeff_lens, - )?, - - shapelet_curved_power_law_radecs, - shapelet_curved_power_law_lmns: DevicePointer::copy_to_device( - &shapelet_curved_power_law_lmns, - )?, - shapelet_curved_power_law_fds: DevicePointer::copy_to_device( - &shapelet_curved_power_law_fds, - )?, - shapelet_curved_power_law_sis: DevicePointer::copy_to_device( - &shapelet_curved_power_law_sis, - )?, - shapelet_curved_power_law_qs: DevicePointer::copy_to_device( - &shapelet_curved_power_law_qs, - )?, - shapelet_curved_power_law_gps: DevicePointer::copy_to_device( - &shapelet_curved_power_law_gps, - )?, - shapelet_curved_power_law_coeffs: DevicePointer::copy_to_device( - &shapelet_curved_power_law_coeffs, - )?, - 
shapelet_curved_power_law_coeff_lens: DevicePointer::copy_to_device( - &shapelet_curved_power_law_coeff_lens, - )?, - - shapelet_list_radecs, - shapelet_list_lmns: DevicePointer::copy_to_device(&shapelet_list_lmns)?, - shapelet_list_fds: DevicePointer::copy_to_device( - shapelet_list_fds.as_slice().expect("is contiguous"), - )?, - shapelet_list_gps: DevicePointer::copy_to_device(&shapelet_list_gps)?, - shapelet_list_coeffs: DevicePointer::copy_to_device(&shapelet_list_coeffs)?, - shapelet_list_coeff_lens: DevicePointer::copy_to_device(&shapelet_list_coeff_lens)?, - }) + Ok(()) } /// This function is mostly used for testing. For a single timestep, over @@ -811,7 +876,7 @@ impl<'a> SkyModellerGpu<'a> { /// it accepts GPU buffers directly, saving some allocations. Unlike the /// aforementioned function, the incoming visibilities *are not* cleared; /// visibilities are accumulated. - fn model_timestep_with( + pub(crate) fn model_timestep_with( &self, lst_rad: f64, array_latitude_rad: f64, @@ -972,6 +1037,14 @@ impl<'a> SkyModeller<'a> for SkyModellerGpu<'a> { Ok(uvws) } + + fn update_with_a_source( + &mut self, + source: &Source, + phase_centre: RADec, + ) -> Result<(), ModelError> { + self.update_source_list(source.components.iter(), phase_centre) + } } /// The return type of [SkyModellerGpu::get_shapelet_uvs]. 
These arrays have diff --git a/src/model/mod.rs b/src/model/mod.rs index 2125bf32..0363c800 100644 --- a/src/model/mod.rs +++ b/src/model/mod.rs @@ -23,7 +23,12 @@ use hifitime::{Duration, Epoch}; use marlu::{c32, Jones, RADec, XyzGeodetic, UVW}; use ndarray::{Array2, ArrayViewMut2}; -use crate::{beam::Beam, context::Polarisations, srclist::SourceList, MODEL_DEVICE}; +use crate::{ + beam::Beam, + context::Polarisations, + srclist::{Source, SourceList}, + MODEL_DEVICE, +}; #[derive(Debug, Clone, Copy)] pub enum ModelDevice { @@ -111,11 +116,11 @@ pub(crate) fn get_cpu_info() -> String { } #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))] - Ok(format!("{} CPU", std::env::consts::ARCH)); + format!("{} CPU", std::env::consts::ARCH) } /// An object that simulates sky-model visibilities. -pub trait SkyModeller<'a> { +pub trait SkyModeller<'a>: Send { /// Generate sky-model visibilities for a single timestep. The visibilities /// as well as the [`UVW`] coordinates used in generating the visibilities /// are returned. The visibilities are ordered by frequency and then @@ -163,6 +168,14 @@ pub trait SkyModeller<'a> { timestamp: Epoch, vis_fb: ArrayViewMut2>, ) -> Result, ModelError>; + + /// Replace the sky model with a single source. This is mostly useful for + /// peeling. + fn update_with_a_source( + &mut self, + source: &Source, + phase_centre: RADec, + ) -> Result<(), ModelError>; } /// Create a [`SkyModeller`] trait object that generates sky-model visibilities diff --git a/src/params/di_calibration.rs b/src/params/di_calibration.rs index 48c54919..a676ac20 100644 --- a/src/params/di_calibration.rs +++ b/src/params/di_calibration.rs @@ -471,52 +471,17 @@ impl DiCalParams { // Propagate errors. scoped_threads_result?; - debug!("Multiplying visibilities by weights"); - - // Multiply the visibilities by the weights (and baseline weights based on - // UVW cuts). 
If a weight is negative, it means the corresponding visibility - // should be flagged, so that visibility is set to 0; this means it does not - // affect calibration. Not iterating over weights during calibration makes - // makes calibration run significantly faster. - vis_data - .outer_iter_mut() - .into_par_iter() - .zip(vis_model.outer_iter_mut()) - .zip(vis_weights.outer_iter()) - .for_each(|((mut vis_data, mut vis_model), vis_weights)| { - vis_data - .outer_iter_mut() - .zip(vis_model.outer_iter_mut()) - .zip(vis_weights.outer_iter()) - .for_each(|((mut vis_data, mut vis_model), vis_weights)| { - vis_data - .iter_mut() - .zip(vis_model.iter_mut()) - .zip(vis_weights.iter()) - .zip(self.baseline_weights.iter()) - .for_each(|(((vis_data, vis_model), &vis_weight), baseline_weight)| { - let weight = f64::from(vis_weight) * *baseline_weight; - if weight <= 0.0 { - *vis_data = Jones::default(); - *vis_model = Jones::default(); - } else { - *vis_data = - Jones::::from(Jones::::from(*vis_data) * weight); - *vis_model = - Jones::::from(Jones::::from(*vis_model) * weight); - } - }); - }); - }); - - info!("Finished reading input data and sky modelling"); - - Ok(CalVis { + let mut cal_vis = CalVis { vis_data, vis_weights, vis_model, pols: obs_context.polarisations, - }) + }; + cal_vis.scale_by_weights(Some(&self.baseline_weights)); + + info!("Finished reading input data and sky modelling"); + + Ok(cal_vis) } } @@ -534,6 +499,63 @@ pub(crate) struct CalVis { pub(crate) pols: Polarisations, } +impl CalVis { + // Multiply the data and model visibilities by the weights (and baseline + // weights that could be e.g. based on UVW cuts). If a weight is negative, + // it means the corresponding visibility should be flagged, so that + // visibility is set to 0; this means it does not affect calibration. Not + // iterating over weights during calibration makes makes calibration run + // significantly faster. 
+ pub(crate) fn scale_by_weights(&mut self, baseline_weights: Option<&[f64]>) { + debug!("Multiplying visibilities by weights"); + + // Ensure that the number of baseline weights is the same as the number + // of baselines. + if let Some(w) = baseline_weights { + assert_eq!(w.len(), self.vis_data.len_of(Axis(2))); + } + + self.vis_data + .outer_iter_mut() + .into_par_iter() + .zip(self.vis_model.outer_iter_mut()) + .zip(self.vis_weights.outer_iter()) + .for_each(|((mut vis_data, mut vis_model), vis_weights)| { + vis_data + .outer_iter_mut() + .zip(vis_model.outer_iter_mut()) + .zip(vis_weights.outer_iter()) + .for_each(|((mut vis_data, mut vis_model), vis_weights)| { + vis_data + .iter_mut() + .zip(vis_model.iter_mut()) + .zip(vis_weights.iter()) + .zip( + baseline_weights + .map(|w| w.iter().cycle()) + .unwrap_or_else(|| [1.0].iter().cycle()), + ) + .for_each( + |(((vis_data, vis_model), &vis_weight), &baseline_weight)| { + let weight = f64::from(vis_weight) * baseline_weight; + if weight <= 0.0 { + *vis_data = Jones::default(); + *vis_model = Jones::default(); + } else { + *vis_data = Jones::::from( + Jones::::from(*vis_data) * weight, + ); + *vis_model = Jones::::from( + Jones::::from(*vis_model) * weight, + ); + } + }, + ); + }); + }); + } +} + #[derive(thiserror::Error, Debug)] pub(crate) enum DiCalibrateError { #[error("Insufficient memory available to perform calibration; need {need_gib} of memory.\nYou could try using fewer timesteps and channels.")] diff --git a/src/params/mod.rs b/src/params/mod.rs index c57289ef..d518aedd 100644 --- a/src/params/mod.rs +++ b/src/params/mod.rs @@ -12,15 +12,15 @@ mod di_calibration; mod input_vis; +mod peel; mod solutions_apply; mod vis_convert; mod vis_simulate; mod vis_subtract; -#[cfg(test)] -pub(crate) use di_calibration::CalVis; pub(crate) use di_calibration::{DiCalParams, DiCalibrateError}; pub(crate) use input_vis::InputVisParams; +pub(crate) use peel::{PeelError, PeelParams}; pub(crate) use 
solutions_apply::SolutionsApplyParams; pub(crate) use vis_convert::{VisConvertError, VisConvertParams}; pub(crate) use vis_simulate::{VisSimulateError, VisSimulateParams}; diff --git a/src/params/peel/mod.rs b/src/params/peel/mod.rs new file mode 100644 index 00000000..93be8464 --- /dev/null +++ b/src/params/peel/mod.rs @@ -0,0 +1,2336 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#[cfg(test)] +mod tests; + +use std::{ + f64::consts::TAU, + io::Write, + num::NonZeroUsize, + ops::{Div, Neg, Sub}, + path::PathBuf, + thread::{self, ScopedJoinHandle}, +}; + +use crossbeam_channel::{bounded, unbounded}; +use crossbeam_utils::atomic::AtomicCell; +use hifitime::{Duration, Epoch}; +use indexmap::IndexMap; +use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressStyle}; +use itertools::{izip, Itertools}; +use log::{debug, info, trace}; +#[cfg(any(feature = "cuda", feature = "hip"))] +use marlu::pos::xyz::xyzs_to_cross_uvws; +use marlu::{ + constants::{FREQ_WEIGHT_FACTOR, TIME_WEIGHT_FACTOR, VEL_C}, + precession::{get_lmst, precess_time}, + HADec, Jones, LatLngHeight, RADec, XyzGeodetic, UVW, +}; +use ndarray::{prelude::*, Zip}; +use num_complex::Complex; +use num_traits::Zero; +use rayon::prelude::*; +use scopeguard::defer_on_unwind; +use serde::{Deserialize, Serialize}; +use vec1::Vec1; + +use crate::{ + averaging::{vis_average_no_negative_weights, Spw, Timeblock}, + beam::Beam, + context::ObsContext, + io::{ + read::VisReadError, + write::{write_vis, VisTimestep}, + }, + model::{new_sky_modeller, ModelDevice, ModelError, SkyModeller, SkyModellerCpu}, + srclist::SourceList, + MODEL_DEVICE, PROGRESS_BARS, +}; +#[cfg(any(feature = "cuda", feature = "hip"))] +use crate::{ + gpu::{self, DevicePointer, GpuError, GpuFloat}, + model::SkyModellerGpu, +}; + +use super::{InputVisParams, ModellingParams, 
OutputVisParams}; + +#[derive(Debug, Serialize, Deserialize)] +pub struct SourceIonoConsts { + pub(crate) alphas: Vec, + pub(crate) betas: Vec, + pub(crate) gains: Vec, + pub(crate) weighted_catalogue_pos_j2000: RADec, + // pub(crate) centroid_timestamps: Vec, +} + +pub(crate) struct PeelParams { + pub(crate) input_vis_params: InputVisParams, + pub(crate) output_vis_params: Option, + pub(crate) iono_outputs: Vec, + pub(crate) beam: Box, + pub(crate) source_list: SourceList, + pub(crate) modelling_params: ModellingParams, + pub(crate) iono_timeblocks: Vec1, + pub(crate) iono_time_average_factor: NonZeroUsize, + pub(crate) low_res_spw: Spw, + pub(crate) num_sources_to_iono_subtract: usize, + pub(crate) num_passes: NonZeroUsize, +} + +impl PeelParams { + pub(crate) fn run(&self) -> Result<(), PeelError> { + // Expose all the struct fields to ensure they're all used. + let PeelParams { + input_vis_params, + output_vis_params, + iono_outputs, + beam, + source_list, + modelling_params: ModellingParams { apply_precession }, + iono_timeblocks, + iono_time_average_factor, + low_res_spw, + num_sources_to_iono_subtract, + num_passes, + } = self; + + let obs_context = input_vis_params.get_obs_context(); + let num_unflagged_tiles = input_vis_params.get_num_unflagged_tiles(); + let num_unflagged_cross_baselines = (num_unflagged_tiles * (num_unflagged_tiles - 1)) / 2; + let array_position = obs_context.array_position; + let tile_baseline_flags = &input_vis_params.tile_baseline_flags; + let flagged_tiles = &tile_baseline_flags.flagged_tiles; + + let unflagged_tile_xyzs: Vec = obs_context + .tile_xyzs + .par_iter() + .enumerate() + .filter(|(tile_index, _)| !flagged_tiles.contains(tile_index)) + .map(|(_, xyz)| *xyz) + .collect(); + + let spw = &input_vis_params.spw; + let all_fine_chan_freqs_hz = + Vec1::try_from_vec(spw.chanblocks.iter().map(|c| c.freq).collect()).unwrap(); + let all_fine_chan_lambdas_m = all_fine_chan_freqs_hz.mapped_ref(|f| VEL_C / *f); + let (low_res_freqs_hz, 
low_res_lambdas_m): (Vec<_>, Vec<_>) = low_res_spw + .chanblocks + .iter() + .map(|c| { + let f = c.freq; + (f, VEL_C / f) + }) + .unzip(); + + // Finding the Stokes-I-weighted `RADec` of each source. + let source_weighted_positions = { + let mut component_radecs = vec![]; + let mut component_stokes_is = vec![]; + let mut source_weighted_positions = Vec::with_capacity(*num_sources_to_iono_subtract); + for source in source_list.values().take(*num_sources_to_iono_subtract) { + component_radecs.clear(); + component_stokes_is.clear(); + for comp in source.components.iter() { + component_radecs.push(comp.radec); + // TODO: Do this properly. + component_stokes_is.push(1.0); + } + + source_weighted_positions.push( + RADec::weighted_average(&component_radecs, &component_stokes_is) + .expect("component RAs aren't too far apart from one another"), + ); + } + source_weighted_positions + }; + + let error = AtomicCell::new(false); + let (tx_data, rx_data) = bounded(1); + let (tx_residual, rx_residual) = bounded(1); + let (tx_write, rx_write) = bounded(2); + let (tx_iono_consts, rx_iono_consts) = unbounded(); + + // Progress bars. Courtesy Dev. 
+ let multi_progress = MultiProgress::with_draw_target(if PROGRESS_BARS.load() { + ProgressDrawTarget::stdout() + } else { + ProgressDrawTarget::hidden() + }); + let pb = ProgressBar::new(input_vis_params.timeblocks.len() as _) + .with_style( + ProgressStyle::default_bar() + .template("{msg:17}: [{wide_bar:.blue}] {pos:2}/{len:2} timesteps ({elapsed_precise}<{eta_precise})").unwrap() + .progress_chars("=> "), + ) + .with_position(0) + .with_message("Reading data"); + let read_progress = multi_progress.add(pb); + let pb = ProgressBar::new(input_vis_params.timeblocks.len() as _) + .with_style( + ProgressStyle::default_bar() + .template("{msg:17}: [{wide_bar:.blue}] {pos:2}/{len:2} timesteps ({elapsed_precise}<{eta_precise})").unwrap() + .progress_chars("=> "), + ) + .with_position(0) + .with_message("Sky modelling"); + let model_progress = multi_progress.add(pb); + let pb = ProgressBar::new(iono_timeblocks.len() as _) + .with_style( + ProgressStyle::default_bar() + .template("{msg:17}: [{wide_bar:.blue}] {pos:2}/{len:2} timeblocks ({elapsed_precise}<{eta_precise})").unwrap() + .progress_chars("=> "), + ) + .with_position(0) + .with_message("Peeling timeblocks"); + let overall_peel_progress = multi_progress.add(pb); + let write_progress = if let Some(output_vis_params) = output_vis_params { + let pb = ProgressBar::new(output_vis_params.output_timeblocks.len() as _) + .with_style( + ProgressStyle::default_bar() + .template("{msg:17}: [{wide_bar:.blue}] {pos:2}/{len:2} timeblocks ({elapsed_precise}<{eta_precise})").unwrap() + .progress_chars("=> "), + ) + .with_position(0) + .with_message("Writing visibilities"); + Some(multi_progress.add(pb)) + } else { + None + }; + + thread::scope(|scope| { + // Input visibility-data reading thread. + let read_handle: ScopedJoinHandle> = thread::Builder::new() + .name("read".to_string()) + .spawn_scoped(scope, || { + // If a panic happens, update our atomic error. + defer_on_unwind! 
{ error.store(true); } + read_progress.tick(); + + for iono_timeblock in iono_timeblocks { + // Make a new block of data to be passed along; this + // contains the target number of timesteps per peel + // cadence but is averaged to the + // `iono_freq_average_factor`. + let mut iono_data_tfb = Array3::zeros(( + iono_timeblock.timestamps.len(), + spw.chanblocks.len(), + num_unflagged_cross_baselines, + )); + let mut iono_weights_tfb = Array3::zeros(iono_data_tfb.dim()); + + for (&(timestamp, i_timeblock), iono_data_fb, iono_weights_fb) in izip!( + iono_timeblock.timestamps.iter(), + iono_data_tfb.outer_iter_mut(), + iono_weights_tfb.outer_iter_mut() + ) { + let timeblock = &input_vis_params.timeblocks[i_timeblock]; + let result = input_vis_params.read_timeblock( + timeblock, + iono_data_fb, + iono_weights_fb, + None, + &error, + ); + + // If the result of reading data was an error, allow + // the other threads to see this so they can abandon + // their work early. + if result.is_err() { + error.store(true); + } + result?; + + // Should we continue? + if error.load() { + return Ok(()); + } + read_progress.inc(1); + } + + // Cap negative weights to 0. + iono_weights_tfb.iter_mut().for_each(|w| { + if *w <= 0.0 { + *w = -0.0; + } + }); + + match tx_data.send((iono_data_tfb, iono_weights_tfb, iono_timeblock)) { + Ok(()) => (), + // If we can't send the message, it's because the + // channel has been closed on the other side. That + // should only happen because the writer has exited + // due to error; in that case, just exit this + // thread. + Err(_) => return Ok(()), + } + } + + read_progress.abandon_with_message("Finished reading input data"); + drop(tx_data); + Ok(()) + }) + .expect("OS can create threads"); + + let model_handle: ScopedJoinHandle> = thread::Builder::new() + .name("model".to_string()) + .spawn_scoped(scope, || { + defer_on_unwind! 
{ error.store(true); } + model_progress.tick(); + + let modeller = new_sky_modeller( + &**beam, + source_list, + obs_context.polarisations, + &unflagged_tile_xyzs, + &all_fine_chan_freqs_hz, + &tile_baseline_flags.flagged_tiles, + obs_context.phase_centre, + array_position.longitude_rad, + array_position.latitude_rad, + input_vis_params.dut1, + *apply_precession, + ) + .unwrap(); + + let mut vis_model_tfb = Array3::zeros(( + iono_time_average_factor.get(), + spw.chanblocks.len(), + num_unflagged_cross_baselines, + )); + for iono_timeblock in iono_timeblocks { + // The number of timestamps in a timeblock can vary; + // don't use zip_eq. + for (vis_model_slice, (timestamp, _)) in vis_model_tfb + .outer_iter_mut() + .zip(iono_timeblock.timestamps.iter()) + { + // Should we continue? + if error.load() { + return Ok(()); + } + + let result = modeller.model_timestep_with(*timestamp, vis_model_slice); + if let Err(e) = result { + error.store(true); + return Err(e); + } + + model_progress.inc(1); + } + + // We call the received data "residual", but the model + // needs to be subtracted for this to be a true + // "residual". That's what happens next. + let (mut vis_residual, vis_weights, timeblock) = rx_data.recv().unwrap(); + + // Don't rotate to each source and subtract; just + // subtract the full model. + for (mut vis_residual_slice, vis_model_slice) in vis_residual + .outer_iter_mut() + .zip(vis_model_tfb.outer_iter()) + { + vis_residual_slice -= &vis_model_slice; + } + + match tx_residual.send((vis_residual, vis_weights, timeblock)) { + Ok(()) => (), + Err(_) => return Ok(()), + } + } + + drop(tx_residual); + model_progress.abandon_with_message("Finished generating sky model"); + Ok(()) + }) + .expect("OS can create threads"); + + let peel_handle: ScopedJoinHandle> = thread::Builder::new() + .name("peel".to_string()) + .spawn_scoped(scope, || { + defer_on_unwind! 
{ error.store(true); } + overall_peel_progress.tick(); + + for (i, (mut vis_residual_tfb, vis_weights_tfb, timeblock)) in + rx_residual.iter().enumerate() + { + // Should we continue? + if error.load() { + return Ok(()); + } + + let mut iono_consts = vec![(0.0, 0.0); *num_sources_to_iono_subtract]; + if *num_sources_to_iono_subtract > 0 { + if matches!(MODEL_DEVICE.load(), ModelDevice::Cpu) { + let mut low_res_modeller = SkyModellerCpu::new( + &**beam, + &SourceList::new(), + obs_context.polarisations, + &unflagged_tile_xyzs, + &low_res_freqs_hz, + flagged_tiles, + RADec::default(), + array_position.longitude_rad, + array_position.latitude_rad, + input_vis_params.dut1, + *apply_precession, + ); + let mut high_res_modeller = SkyModellerCpu::new( + &**beam, + &SourceList::new(), + obs_context.polarisations, + &unflagged_tile_xyzs, + &all_fine_chan_freqs_hz, + flagged_tiles, + RADec::default(), + array_position.longitude_rad, + array_position.latitude_rad, + input_vis_params.dut1, + *apply_precession, + ); + + peel_cpu( + vis_residual_tfb.view_mut(), + vis_weights_tfb.view(), + timeblock, + source_list, + &mut iono_consts, + &source_weighted_positions, + num_passes.get(), + &low_res_freqs_hz, + &all_fine_chan_lambdas_m, + &low_res_lambdas_m, + obs_context, + array_position, + &unflagged_tile_xyzs, + &mut low_res_modeller, + &mut high_res_modeller, + input_vis_params.dut1, + !apply_precession, + &multi_progress, + ) + .unwrap(); + } + + #[cfg(any(feature = "cuda", feature = "hip"))] + if matches!(MODEL_DEVICE.load(), ModelDevice::Gpu) { + let mut low_res_modeller = SkyModellerGpu::new( + &**beam, + &SourceList::new(), + obs_context.polarisations, + &unflagged_tile_xyzs, + &low_res_freqs_hz, + flagged_tiles, + RADec::default(), + array_position.longitude_rad, + array_position.latitude_rad, + input_vis_params.dut1, + *apply_precession, + )?; + let mut high_res_modeller = SkyModellerGpu::new( + &**beam, + &SourceList::new(), + obs_context.polarisations, + 
&unflagged_tile_xyzs, + &all_fine_chan_freqs_hz, + flagged_tiles, + RADec::default(), + array_position.longitude_rad, + array_position.latitude_rad, + input_vis_params.dut1, + *apply_precession, + )?; + peel_gpu( + vis_residual_tfb.view_mut(), + vis_weights_tfb.view(), + timeblock, + source_list, + &mut iono_consts, + &source_weighted_positions, + num_passes.get(), + &low_res_freqs_hz, + &all_fine_chan_lambdas_m, + &low_res_lambdas_m, + obs_context, + array_position, + &unflagged_tile_xyzs, + &mut low_res_modeller, + &mut high_res_modeller, + input_vis_params.dut1, + !apply_precession, + &multi_progress, + ) + .unwrap(); + } + + // dev: what's with this? + if i == 0 { + source_list + .iter() + .take(10) + .zip(iono_consts.iter()) + .for_each(|((name, src), iono_consts)| { + multi_progress + .println(format!( + "{name}: {:+.5e} {:+.5e} ({})", + iono_consts.0, + iono_consts.1, + src.components[0].radec, + )) + .unwrap(); + }); + } + } + + match tx_iono_consts.send(iono_consts) { + Ok(()) => (), + Err(_) => return Ok(()), + } + + for ((cross_data_fb, cross_weights_fb), (timestamp, _)) in vis_residual_tfb + .outer_iter() + .zip(vis_weights_tfb.outer_iter()) + .zip(timeblock.timestamps.iter()) + { + // TODO: Puke. + let cross_data_fb = cross_data_fb.to_shared(); + let cross_weights_fb = cross_weights_fb.to_shared(); + if output_vis_params.is_some() { + match tx_write.send(VisTimestep { + cross_data_fb, + cross_weights_fb, + autos: None, + timestamp: *timestamp, + }) { + Ok(()) => (), + Err(_) => return Ok(()), + } + } + } + + overall_peel_progress.inc(1); + } + overall_peel_progress.abandon_with_message("Finished peeling"); + drop(tx_write); + drop(tx_iono_consts); + + Ok(()) + }) + .expect("OS can create threads"); + + let write_handle = thread::Builder::new() + .name("write".to_string()) + .spawn_scoped(scope, || { + defer_on_unwind! 
{ error.store(true); } + + if let Some(output_vis_params) = output_vis_params.as_ref() { + write_progress + .as_ref() + .expect("is available if we're writing output vis") + .tick(); + + let result = write_vis( + &output_vis_params.output_files, + array_position, + obs_context.phase_centre, + obs_context.pointing_centre, + &obs_context.tile_xyzs, + &obs_context.tile_names, + obs_context.obsid, + &output_vis_params.output_timeblocks, + input_vis_params.time_res, + input_vis_params.dut1, + &input_vis_params.spw, + &tile_baseline_flags + .unflagged_cross_baseline_to_tile_map + .values() + .copied() + .sorted() + .collect::>(), + output_vis_params.output_time_average_factor, + output_vis_params.output_freq_average_factor, + input_vis_params.vis_reader.get_marlu_mwa_info().as_ref(), + rx_write, + &error, + write_progress, + ); + match result { + Ok(m) => info!("{m}"), + Err(e) => { + error.store(true); + return Err(e); + } + } + } + + if !iono_outputs.is_empty() { + // Write out the iono consts. First, allocate a space + // for all the results. We use an IndexMap to keep the + // order of the sources preserved while also being able + // to write out a "HashMap-style" json. + let mut output_iono_consts: IndexMap<&str, SourceIonoConsts> = source_list + .iter() + .take(*num_sources_to_iono_subtract) + .zip_eq(source_weighted_positions.iter().copied()) + .map(|((name, _src), weighted_pos)| { + ( + name.as_str(), + SourceIonoConsts { + alphas: vec![], + betas: vec![], + gains: vec![], + weighted_catalogue_pos_j2000: weighted_pos, + }, + ) + }) + .collect(); + + // Store the results as they are received on the + // channel. 
+ while let Ok(incoming_iono_consts) = rx_iono_consts.recv() { + incoming_iono_consts + .into_iter() + .zip_eq(output_iono_consts.iter_mut()) + .for_each(|((alpha, beta), (_src_name, src_iono_consts))| { + src_iono_consts.alphas.push(alpha); + src_iono_consts.betas.push(beta); + }); + } + + // The channel has stopped sending results; write them + // out to a file. + let output_json_string = + serde_json::to_string_pretty(&output_iono_consts).unwrap(); + for iono_output in iono_outputs { + let mut file = std::fs::File::create(iono_output)?; + file.write_all(output_json_string.as_bytes())?; + } + } + + Ok(()) + }) + .expect("OS can create threads"); + + read_handle.join().unwrap().unwrap(); + model_handle.join().unwrap().unwrap(); + peel_handle.join().unwrap().unwrap(); + write_handle.join().unwrap().unwrap(); + }); + + Ok(()) + } +} + +fn _get_weights_rts( + tile_uvs: ArrayView2, + lambdas_m: &[f64], + short_sigma: f64, + weight_factor: f64, +) -> Array3 { + let (num_timesteps, num_tiles) = tile_uvs.dim(); + let num_cross_baselines = (num_tiles * (num_tiles - 1)) / 2; + + let mut weights = Array3::zeros((num_timesteps, lambdas_m.len(), num_cross_baselines)); + weights + .outer_iter_mut() + .into_par_iter() + .zip_eq(tile_uvs.outer_iter()) + .for_each(|(mut weights, tile_uvs)| { + let mut i_tile1 = 0; + let mut i_tile2 = 0; + let mut tile1_uv = tile_uvs[i_tile1]; + let mut tile2_uv = tile_uvs[i_tile2]; + weights.axis_iter_mut(Axis(1)).for_each(|mut weights| { + i_tile2 += 1; + if i_tile2 == num_tiles { + i_tile1 += 1; + i_tile2 = i_tile1 + 1; + tile1_uv = tile_uvs[i_tile1]; + } + tile2_uv = tile_uvs[i_tile2]; + let uv = tile1_uv - tile2_uv; + + weights + .iter_mut() + .zip_eq(lambdas_m) + .for_each(|(weight, lambda_m)| { + let UV { u, v } = uv / *lambda_m; + // 1 - exp(-(u*u+v*v)/(2*sig^2)) + let uv_sq = u * u + v * v; + let exp = (-uv_sq / (2.0 * short_sigma * short_sigma)).exp(); + *weight = (weight_factor * (1.0 - exp)) as f32; + }); + }); + }); + weights +} + +// 
/// Average "high-res" vis and weights to "low-res" vis and weights +// /// arguments are all 3D arrays with axes (time, freq, baseline) +// // TODO (Dev): rename to vis_weight_average_tfb +// fn vis_average( +// jones_from: ArrayView3>, +// mut jones_to: ArrayViewMut3>, +// weight_from: ArrayView3, +// mut weight_to: ArrayViewMut3, +// ) { +// let from_dims = jones_from.dim(); +// let (time_axis, freq_axis, baseline_axis) = (Axis(0), Axis(1), Axis(2)); +// let avg_time = jones_from.len_of(time_axis) / jones_to.len_of(time_axis); +// let avg_freq = jones_from.len_of(freq_axis) / jones_to.len_of(freq_axis); + +// assert_eq!(from_dims, weight_from.dim()); +// let to_dims = jones_to.dim(); +// assert_eq!( +// to_dims, +// ( +// (from_dims.0 as f64 / avg_time as f64).floor() as usize, +// (from_dims.1 as f64 / avg_freq as f64).floor() as usize, +// from_dims.2, +// ) +// ); +// assert_eq!(to_dims, weight_to.dim()); + +// let num_tiles = +// num_tiles_from_num_cross_correlation_baselines(jones_from.len_of(baseline_axis)); +// assert_eq!( +// (num_tiles * (num_tiles - 1)) / 2, +// jones_from.len_of(baseline_axis) +// ); + +// // iterate along time axis in chunks of avg_time +// jones_from +// .axis_chunks_iter(time_axis, avg_time) +// .zip_eq(weight_from.axis_chunks_iter(time_axis, avg_time)) +// .zip_eq(jones_to.outer_iter_mut()) +// .zip_eq(weight_to.outer_iter_mut()) +// .for_each( +// |(((jones_chunk, weight_chunk), mut jones_to), mut weight_to)| { +// // iterate along baseline axis +// let mut i_tile1 = 0; +// let mut i_tile2 = 0; +// jones_chunk +// .axis_iter(Axis(2)) +// .zip_eq(weight_chunk.axis_iter(Axis(2))) +// .zip_eq(jones_to.axis_iter_mut(Axis(1))) +// .zip_eq(weight_to.axis_iter_mut(Axis(1))) +// .for_each( +// |(((jones_chunk, weight_chunk), mut jones_to), mut weight_to)| { +// i_tile2 += 1; +// if i_tile2 == num_tiles { +// i_tile1 += 1; +// i_tile2 = i_tile1 + 1; +// } + +// jones_chunk +// .axis_chunks_iter(Axis(1), avg_freq) +// 
.zip_eq(weight_chunk.axis_chunks_iter(Axis(1), avg_freq)) +// .zip_eq(jones_to.iter_mut()) +// .zip_eq(weight_to.iter_mut()) +// .for_each( +// |(((jones_chunk, weight_chunk), jones_to), weight_to)| { +// let mut jones_weighted_sum = Jones::default(); +// let mut weight_sum = 0.0; + +// // iterate through time chunks +// jones_chunk +// .outer_iter() +// .zip_eq(weight_chunk.outer_iter()) +// .for_each(|(jones_chunk, weights_chunk)| { +// jones_chunk +// .iter() +// .zip_eq(weights_chunk.iter()) +// .for_each(|(jones, weight)| { +// // Any flagged +// // visibilities would +// // have a weight <= 0, +// // but we've already +// // capped them to 0. +// // This means we don't +// // need to check the +// // value of the weight +// // when accumulating +// // unflagged +// // visibilities; the +// // flagged ones +// // contribute nothing. + +// let jones = Jones::::from(*jones); +// let weight = *weight as f64; +// jones_weighted_sum += jones * weight; +// weight_sum += weight; +// }); +// }); + +// if weight_sum > 0.0 { +// *jones_to = +// Jones::from(jones_weighted_sum / weight_sum); +// *weight_to = weight_sum as f32; +// } +// }, +// ); +// }, +// ); +// }, +// ); +// } + +// TODO (dev): a.div_ceil(b) would be better, but it's nightly: +// https://doc.rust-lang.org/std/primitive.i32.html#method.div_ceil +fn div_ceil(a: usize, b: usize) -> usize { + (a + b - 1) / b +} + +/// Like `vis_weight_average_tfb`, but for when we don't need to keep the low-res weights +/// Average "high-res" vis and weights to "low-res" vis (no low-res weights) +/// arguments are all 3D arrays with axes (time, freq, baseline). 
+/// assumes weights are capped to 0 +// TODO (Dev): rename to vis_average_tfb +fn vis_average2( + jones_from_tfb: ArrayView3>, + mut jones_to_tfb: ArrayViewMut3>, + weight_tfb: ArrayView3, +) { + let from_dims = jones_from_tfb.dim(); + let (time_axis, freq_axis, baseline_axis) = (Axis(0), Axis(1), Axis(2)); + let avg_time = div_ceil( + jones_from_tfb.len_of(time_axis), + jones_to_tfb.len_of(time_axis), + ); + let avg_freq = div_ceil( + jones_from_tfb.len_of(freq_axis), + jones_to_tfb.len_of(freq_axis), + ); + + assert_eq!(from_dims, weight_tfb.dim()); + let to_dims = jones_to_tfb.dim(); + assert_eq!( + to_dims, + ( + div_ceil(from_dims.0, avg_time), + div_ceil(from_dims.1, avg_freq), + from_dims.2, + ) + ); + + // iterate along time axis in chunks of avg_time + for (jones_chunk_tfb, weight_chunk_tfb, mut jones_to_fb) in izip!( + jones_from_tfb.axis_chunks_iter(time_axis, avg_time), + weight_tfb.axis_chunks_iter(time_axis, avg_time), + jones_to_tfb.outer_iter_mut() + ) { + for (jones_chunk_tfb, weight_chunk_tfb, mut jones_to_b) in izip!( + jones_chunk_tfb.axis_chunks_iter(freq_axis, avg_freq), + weight_chunk_tfb.axis_chunks_iter(freq_axis, avg_freq), + jones_to_fb.outer_iter_mut() + ) { + // iterate along baseline axis + for (jones_chunk_tf, weight_chunk_tf, jones_to) in izip!( + jones_chunk_tfb.axis_iter(baseline_axis), + weight_chunk_tfb.axis_iter(baseline_axis), + jones_to_b.iter_mut() + ) { + let mut jones_weighted_sum = Jones::zero(); + let mut weight_sum: f64 = 0.0; + for (&jones, &weight) in jones_chunk_tf.iter().zip_eq(weight_chunk_tf.iter()) { + // assumes weights are capped to 0. 
otherwise we would need to check weight >= 0 + debug_assert!(weight >= 0.0, "weight was not capped to zero: {}", weight); + jones_weighted_sum += Jones::::from(jones) * weight as f64; + weight_sum += weight as f64; + } + + if weight_sum > 0.0 { + *jones_to = Jones::from(jones_weighted_sum / weight_sum); + } + } + } + } +} + +fn weights_average(weight_tfb: ArrayView3, mut weight_avg_tfb: ArrayViewMut3) { + let from_dims = weight_tfb.dim(); + let (time_axis, freq_axis, baseline_axis) = (Axis(0), Axis(1), Axis(2)); + let avg_time = div_ceil( + weight_tfb.len_of(time_axis), + weight_avg_tfb.len_of(time_axis), + ); + let avg_freq = div_ceil( + weight_tfb.len_of(freq_axis), + weight_avg_tfb.len_of(freq_axis), + ); + + let to_dims = weight_avg_tfb.dim(); + assert_eq!( + to_dims, + ( + div_ceil(from_dims.0, avg_time), + div_ceil(from_dims.1, avg_freq), + from_dims.2, + ) + ); + + // iterate along time axis in chunks of avg_time + for (weight_chunk_tfb, mut weight_avg_fb) in izip!( + weight_tfb.axis_chunks_iter(time_axis, avg_time), + weight_avg_tfb.outer_iter_mut() + ) { + // iterate along frequency axis in chunks of avg_freq + for (weight_chunk_tfb, mut weight_avg_b) in izip!( + weight_chunk_tfb.axis_chunks_iter(freq_axis, avg_freq), + weight_avg_fb.outer_iter_mut() + ) { + // iterate along baseline axis + for (weight_chunk_tf, weight_avg) in izip!( + weight_chunk_tfb.axis_iter(baseline_axis), + weight_avg_b.iter_mut() + ) { + let mut weight_sum: f64 = 0.0; + for &weight in weight_chunk_tf.iter() { + weight_sum += weight as f64; + } + + *weight_avg = (weight_sum as f32).max(0.); + } + } + } +} + +// #[allow(clippy::too_many_arguments)] +// #[deprecated = "doesn't support --no-precession"] +// fn vis_rotate2( +// jones_array: ArrayView3>, +// mut jones_array2: ArrayViewMut3>, +// phase_to: RADec, +// tile_xyzs: ArrayView2, +// tile_ws_from: ArrayView2, +// mut tile_ws_to: ArrayViewMut2, +// lmsts: &[f64], +// fine_chan_lambdas_m: &[f64], +// ) { +// let num_tiles = 
tile_xyzs.len_of(Axis(1)); +// assert_eq!(tile_ws_from.len_of(Axis(1)), num_tiles); +// assert_eq!(tile_ws_to.len_of(Axis(1)), num_tiles); + +// // iterate along time axis in chunks of avg_time +// jones_array +// .outer_iter() +// .into_par_iter() +// .zip(jones_array2.outer_iter_mut()) +// .zip(tile_ws_from.outer_iter()) +// .zip(tile_ws_to.outer_iter_mut()) +// .zip(tile_xyzs.outer_iter()) +// .zip(lmsts.par_iter()) +// .for_each( +// |(((((vis_tfb, mut vis_rot_tfb), tile_ws_from), mut tile_ws_to), tile_xyzs), lmst)| { +// assert_eq!(tile_ws_from.len(), num_tiles); +// // Generate the "to" Ws. +// let phase_to = phase_to.to_hadec(*lmst); +// setup_ws( +// tile_ws_to.as_slice_mut().unwrap(), +// tile_xyzs.as_slice().unwrap(), +// phase_to, +// ); + +// vis_rotate_fb( +// vis_tfb, +// vis_rot_tfb, +// tile_ws_from.as_slice().unwrap(), +// tile_ws_to.as_slice().unwrap(), +// fine_chan_lambdas_m, +// ); +// }, +// ); +// } + +fn vis_rotate_fb( + vis_fb: ArrayView2>, + mut vis_rot_fb: ArrayViewMut2>, + tile_ws_from: &[W], + tile_ws_to: &[W], + fine_chan_lambdas_m: &[f64], +) { + let num_tiles = tile_ws_from.len(); + let mut i_tile1 = 0; + let mut i_tile2 = 0; + let mut tile1_w_from = tile_ws_from[i_tile1]; + let mut tile2_w_from = tile_ws_from[i_tile2]; + let mut tile1_w_to = tile_ws_to[i_tile1]; + let mut tile2_w_to = tile_ws_to[i_tile2]; + // iterate along baseline axis + vis_fb + .axis_iter(Axis(1)) + .zip(vis_rot_fb.axis_iter_mut(Axis(1))) + .for_each(|(vis_f, mut vis_rot_f)| { + i_tile2 += 1; + if i_tile2 == num_tiles { + i_tile1 += 1; + i_tile2 = i_tile1 + 1; + tile1_w_from = tile_ws_from[i_tile1]; + tile1_w_to = tile_ws_to[i_tile1]; + } + tile2_w_from = tile_ws_from[i_tile2]; + tile2_w_to = tile_ws_to[i_tile2]; + + let w_diff = (tile1_w_to - tile2_w_to) - (tile1_w_from - tile2_w_from); + let arg = -TAU * w_diff; + // iterate along frequency axis + vis_f + .iter() + .zip(vis_rot_f.iter_mut()) + .zip(fine_chan_lambdas_m.iter()) + .for_each(|((jones, jones_rot), 
lambda_m)| { + let rotation = Complex::cis(arg / *lambda_m); + *jones_rot = Jones::::from(Jones::::from(*jones) * rotation); + }); + }); +} + +/// Rotate the supplied visibilities according to the `λ²` constants of +/// proportionality with `exp(-2πi(αu+βv)λ²)`. +fn apply_iono2( + vis_tfb: ArrayView3>, + mut vis_iono_tfb: ArrayViewMut3>, + tile_uvs: ArrayView2, + const_lm: (f64, f64), + lambdas_m: &[f64], +) { + let num_tiles = tile_uvs.len_of(Axis(1)); + + // iterate along time axis + vis_tfb + .outer_iter() + .zip(vis_iono_tfb.outer_iter_mut()) + .zip(tile_uvs.outer_iter()) + .for_each(|((vis_fb, mut vis_iono_fb), tile_uvs)| { + // Just in case the compiler can't understand how an ndarray is laid + // out. + assert_eq!(tile_uvs.len(), num_tiles); + + // iterate along baseline axis + let mut i_tile1 = 0; + let mut i_tile2 = 0; + vis_fb + .axis_iter(Axis(1)) + .zip(vis_iono_fb.axis_iter_mut(Axis(1))) + .for_each(|(vis_f, mut vis_iono_f)| { + i_tile2 += 1; + if i_tile2 == num_tiles { + i_tile1 += 1; + i_tile2 = i_tile1 + 1; + } + + let UV { u, v } = tile_uvs[i_tile1] - tile_uvs[i_tile2]; + let arg = -TAU * (u * const_lm.0 + v * const_lm.1); + // iterate along frequency axis + vis_f + .iter() + .zip(vis_iono_f.iter_mut()) + .zip(lambdas_m.iter()) + .for_each(|((jones, jones_iono), lambda_m)| { + let j = Jones::::from(*jones); + // The baseline UV is in units of metres, so we need + // to divide by λ to use it in an exponential. But + // we're also multiplying by λ², so just multiply by + // λ. 
+ let rotation = Complex::cis(arg * *lambda_m); + *jones_iono = Jones::from(j * rotation); + }); + }); + }); +} + +/// unpeel model, peel iono model +/// this is useful when vis_model has already been subtracted from vis_residual +/// TODO (Dev): rename to unpeel_model +fn apply_iono3( + vis_model: ArrayView3>, + mut vis_residual: ArrayViewMut3>, + tile_uvs: ArrayView2, + const_lm: (f64, f64), + old_const_lm: (f64, f64), + lambdas_m: &[f64], +) { + let num_tiles = tile_uvs.len_of(Axis(1)); + + // iterate along time axis + vis_model + .outer_iter() + .into_par_iter() + .zip(vis_residual.outer_iter_mut()) + .zip(tile_uvs.outer_iter()) + .for_each(|((vis_model, mut vis_residual), tile_uvs)| { + // Just in case the compiler can't understand how an ndarray is laid + // out. + assert_eq!(tile_uvs.len(), num_tiles); + + // iterate along baseline axis + let mut i_tile1 = 0; + let mut i_tile2 = 0; + vis_model + .axis_iter(Axis(1)) + .zip(vis_residual.axis_iter_mut(Axis(1))) + .for_each(|(vis_model, mut vis_residual)| { + i_tile2 += 1; + if i_tile2 == num_tiles { + i_tile1 += 1; + i_tile2 = i_tile1 + 1; + } + + let UV { u, v } = tile_uvs[i_tile1] - tile_uvs[i_tile2]; + let arg = -TAU * (u * const_lm.0 + v * const_lm.1); + let old_arg = -TAU * (u * old_const_lm.0 + v * old_const_lm.1); + // iterate along frequency axis + vis_model + .iter() + .zip(vis_residual.iter_mut()) + .zip(lambdas_m.iter()) + .for_each(|((vis_model, vis_residual), lambda_m)| { + let mut j = Jones::::from(*vis_residual); + let m = Jones::::from(*vis_model); + // The baseline UV is in units of metres, so we need + // to divide by λ to use it in an exponential. But + // we're also multiplying by λ², so just multiply by + // λ. 
+ let old_rotation = Complex::cis(old_arg * *lambda_m); + j += m * old_rotation; + + let rotation = Complex::cis(arg * *lambda_m); + j -= m * rotation; + *vis_residual = Jones::from(j); + }); + }); + }); +} + +// the offsets as defined by the RTS code +// TODO: Assume there's only 1 timestep, because this is low res data? +fn iono_fit( + residual: ArrayView3>, + weights: ArrayView3, + model: ArrayView3>, + lambdas_m: &[f64], + tile_uvs_low_res: ArrayView2, +) -> [f64; 4] { + let num_tiles = tile_uvs_low_res.len_of(Axis(1)); + + // a-terms used in least-squares estimator + let (mut a_uu, mut a_uv, mut a_vv) = (0.0, 0.0, 0.0); + // A-terms used in least-squares estimator + let (mut aa_u, mut aa_v) = (0.0, 0.0); + // Excess amplitude in the visibilities (V) over the models (M) + let (mut s_vm, mut s_mm) = (0.0, 0.0); + + // iterate over time + residual + .outer_iter() + .zip(weights.outer_iter()) + .zip(model.outer_iter()) + .zip(tile_uvs_low_res.outer_iter()) + .for_each(|(((residual, weights), model), tile_uvs_low_res)| { + // iterate over frequency + residual + .outer_iter() + .zip(weights.outer_iter()) + .zip(model.outer_iter()) + .zip(lambdas_m.iter()) + .for_each(|(((residual, weights), model), &lambda)| { + let lambda_2 = lambda * lambda; + + let mut i_tile1 = 0; + let mut i_tile2 = 0; + let mut uv_tile1 = tile_uvs_low_res[i_tile1]; + let mut uv_tile2 = tile_uvs_low_res[i_tile2]; + + let mut a_uu_bl = 0.0; + let mut a_uv_bl = 0.0; + let mut a_vv_bl = 0.0; + let mut aa_u_bl = 0.0; + let mut aa_v_bl = 0.0; + let mut s_vm_bl = 0.0; + let mut s_mm_bl = 0.0; + + // iterate over baseline + residual + .iter() + .zip(weights.iter()) + .zip(model.iter()) + .for_each(|((residual, weight), model)| { + i_tile2 += 1; + if i_tile2 == num_tiles { + i_tile1 += 1; + i_tile2 = i_tile1 + 1; + uv_tile1 = tile_uvs_low_res[i_tile1]; + } + + if *weight > 0.0 { + uv_tile2 = tile_uvs_low_res[i_tile2]; + // Normally, we would divide by λ to get + // dimensionless UV. 
However, UV are only used + // to determine a_uu, a_uv, a_vv, which are also + // scaled by lambda. So... don't divide by λ. + let UV { u, v } = uv_tile1 - uv_tile2; + + // Stokes I of the residual visibilities and + // model visibilities. It doesn't matter if the + // convention is to divide by 2 or not; the + // algorithm's result is algebraically the same. + let residual_i = residual[0] + residual[3]; + let model_i = model[0] + model[3]; + + let model_i_re = model_i.re as f64; + let mr = model_i_re * (residual_i.im as f64 - model_i.im as f64); + let mm = model_i_re * model_i_re; + let s_vm = model_i_re * residual_i.re as f64; + let s_mm = mm; + let weight = *weight as f64; + + #[cfg(test)] + { + println!("uv ({:+1.5}, {:+1.5}) l{:+1.3} | RI {:+1.5} @{:+1.5}pi | MI {:+1.5} @{:+1.5}pi", u, v, lambda, residual_i.norm(), residual_i.arg(), model_i.norm(), model_i.arg()); + if i_tile1 == 0 && i_tile2 == 1 { + let a_uu_asdf = weight * mm * u * u; + let a_uv_asdf = weight * mm * u * v; + let a_vv_asdf = weight * mm * v * v; + dbg!(residual_i, model_i, weight, u, v, mr, mm, a_uu_asdf, a_uv_asdf, a_vv_asdf); + } + } + + // To avoid accumulating floating-point errors + // (and save some multiplies), multiplications + // with powers of lambda are done outside the + // loop. + a_uu_bl += weight * mm * u * u; + a_uv_bl += weight * mm * u * v; + a_vv_bl += weight * mm * v * v; + aa_u_bl += weight * mr * u; + aa_v_bl += weight * mr * v; + s_vm_bl += weight * s_vm; + s_mm_bl += weight * s_mm; + } + }); + + // As above, we didn't divide UV by lambda, so below we use + // λ² for λ⁴, and λ for λ². 
+ a_uu += a_uu_bl * lambda_2; + a_uv += a_uv_bl * lambda_2; + a_vv += a_vv_bl * lambda_2; + aa_u += aa_u_bl * -lambda; + aa_v += aa_v_bl * -lambda; + s_vm += s_vm_bl; + s_mm += s_mm_bl; + }); + }); + + let denom = TAU * (a_uu * a_vv - a_uv * a_uv); + #[cfg(test)] + dbg!(a_uu, a_vv, a_uv, denom); + [ + (aa_u * a_vv - aa_v * a_uv) / denom, + (aa_v * a_uu - aa_u * a_uv) / denom, + s_vm, + s_mm, + ] +} + +#[cfg(test)] +fn setup_ws(tile_ws: &mut [W], tile_xyzs: &[XyzGeodetic], phase_centre: HADec) { + assert_eq!(tile_ws.len(), tile_xyzs.len()); + let (s_ha, c_ha) = phase_centre.ha.sin_cos(); + let (s_dec, c_dec) = phase_centre.dec.sin_cos(); + tile_ws + .iter_mut() + .zip(tile_xyzs.iter().copied()) + .for_each(|(tile_w, tile_xyz)| { + *tile_w = W::_from_xyz(tile_xyz, s_ha, c_ha, s_dec, c_dec); + }); +} + +fn setup_uvs(tile_uvs: &mut [UV], tile_xyzs: &[XyzGeodetic], phase_centre: HADec) { + assert_eq!(tile_uvs.len(), tile_xyzs.len()); + let (s_ha, c_ha) = phase_centre.ha.sin_cos(); + let (s_dec, c_dec) = phase_centre.dec.sin_cos(); + tile_uvs + .iter_mut() + .zip(tile_xyzs.iter().copied()) + .for_each(|(tile_uv, tile_xyz)| { + *tile_uv = UV::from_xyz(tile_xyz, s_ha, c_ha, s_dec, c_dec); + }); +} + +fn model_timesteps( + modeller: &dyn SkyModeller, + timestamps: &[Epoch], + mut vis_result_tfb: ArrayViewMut3>, +) -> Result<(), ModelError> { + vis_result_tfb + .outer_iter_mut() + .zip(timestamps.iter()) + .try_for_each(|(mut vis_result, epoch)| { + modeller + .model_timestep_with(*epoch, vis_result.view_mut()) + .map(|_| ()) + }) +} + +#[allow(clippy::too_many_arguments)] +fn peel_cpu( + // TODO (Dev): I would name this vis_residual_tfb + mut vis_residual: ArrayViewMut3>, + // TODO (Dev): I would name this vis_weights + vis_weights: ArrayView3, + timeblock: &Timeblock, + source_list: &SourceList, + iono_consts: &mut [(f64, f64)], + source_weighted_positions: &[RADec], + num_passes: usize, + // TODO (dev): Why do we need both this and low_res_lambdas_m? 
it's not even used + _low_res_freqs_hz: &[f64], + all_fine_chan_lambdas_m: &[f64], + low_res_lambdas_m: &[f64], + obs_context: &ObsContext, + // TODO (dev): array_position is available from obs_context + array_position: LatLngHeight, + // TODO (dev): unflagged_tile_xyzs is available from obs_context + unflagged_tile_xyzs: &[XyzGeodetic], + low_res_modeller: &mut dyn SkyModeller, + high_res_modeller: &mut dyn SkyModeller, + // TODO (dev): dut1 is available from obs_context + dut1: Duration, + no_precession: bool, + multi_progress_bar: &MultiProgress, +) -> Result<(), PeelError> { + // TODO: Do we allow multiple timesteps in the low-res data? + + let timestamps = &timeblock.timestamps; + let num_timestamps_high_res = timestamps.len(); + let num_timestamps_low_res = 1; + let avg_time = num_timestamps_high_res / num_timestamps_low_res; + + let num_tiles = unflagged_tile_xyzs.len(); + let num_cross_baselines = (num_tiles * (num_tiles - 1)) / 2; + + let num_freqs_high_res = all_fine_chan_lambdas_m.len(); + let num_freqs_low_res = low_res_lambdas_m.len(); + + let num_sources = source_list.len(); + let num_sources_to_iono_subtract = iono_consts.len(); + + // TODO: these assertions should be actual errors. 
+ let (time_axis, freq_axis, baseline_axis) = (Axis(0), Axis(1), Axis(2)); + + assert_eq!(vis_residual.len_of(time_axis), num_timestamps_high_res); + assert_eq!(vis_weights.len_of(time_axis), num_timestamps_high_res); + + assert_eq!(vis_residual.len_of(baseline_axis), num_cross_baselines); + assert_eq!(vis_weights.len_of(baseline_axis), num_cross_baselines); + + assert_eq!(vis_residual.len_of(freq_axis), num_freqs_high_res); + assert_eq!(vis_weights.len_of(freq_axis), num_freqs_high_res); + + assert_eq!(iono_consts.len(), num_sources_to_iono_subtract); + assert!(num_sources_to_iono_subtract <= num_sources); + + let peel_progress = multi_progress_bar.add( + ProgressBar::new(num_sources_to_iono_subtract as _) + .with_style( + ProgressStyle::default_bar() + .template("{msg:17}: [{wide_bar:.blue}] {pos:2}/{len:2} sources ({elapsed_precise}<{eta_precise})").unwrap() + .progress_chars("=> "), + ) + .with_position(0) + .with_message(format!("Peeling timeblock {}", timeblock.index + 1)), + ); + peel_progress.tick(); + + // observation phase center + let mut tile_uvs_high_res = Array2::::default((timestamps.len(), num_tiles)); + // TODO (Dev): I would name this "tile_ws_high_res" + let mut tile_ws_from = Array2::::default((timestamps.len(), num_tiles)); + // source phase center + // TODO (Dev): I would name this tile_uvs_high_res_rot + let mut tile_uvs_high_res_rot = tile_uvs_high_res.clone(); + // TODO (Dev): I would name this tile_ws_high_res_rot + let mut tile_ws_to = tile_ws_from.clone(); + // TODO (Dev): I would name this tile_uvs_low_res_rot + let mut tile_uvs_low_res = Array2::::default((1, num_tiles)); + + // Pre-compute high-res tile UVs and Ws at observation phase centre. 
+ for (&(time, _), mut tile_uvs, mut tile_ws) in izip!( + timestamps.iter(), + tile_uvs_high_res.outer_iter_mut(), + tile_ws_from.outer_iter_mut(), + ) { + let (lmst, precessed_xyzs) = if !no_precession { + let precession_info = precess_time( + array_position.longitude_rad, + array_position.latitude_rad, + obs_context.phase_centre, + time, + dut1, + ); + let precessed_xyzs = precession_info.precess_xyz(unflagged_tile_xyzs); + (precession_info.lmst_j2000, precessed_xyzs) + } else { + let lmst = get_lmst(array_position.longitude_rad, time, dut1); + (lmst, unflagged_tile_xyzs.into()) + }; + let hadec_phase = obs_context.phase_centre.to_hadec(lmst); + let (s_ha, c_ha) = hadec_phase.ha.sin_cos(); + let (s_dec, c_dec) = hadec_phase.dec.sin_cos(); + for (tile_uv, tile_w, &precessed_xyzs) in izip!( + tile_uvs.iter_mut(), + tile_ws.iter_mut(), + precessed_xyzs.iter(), + ) { + let uvw = UVW::from_xyz_inner(precessed_xyzs, s_ha, c_ha, s_dec, c_dec); + *tile_uv = UV { u: uvw.u, v: uvw.v }; + *tile_w = W(uvw.w); + } + } + + // TODO (Dev): iono_taper_weights could be supplied to peel + // use the baseline taper from the RTS, 1-exp(-(u*u+v*v)/(2*sig^2)); + // let short_baseline_sigma = 20.; + // TODO: Do we care about weights changing over time? + // let vis_weights = { + // let mut iono_taper = get_weights_rts( + // tile_uvs_high_res.view(), + // all_fine_chan_lambdas_m, + // short_baseline_sigma, + // (obs_context.guess_freq_res() / FREQ_WEIGHT_FACTOR) + // * (obs_context.guess_time_res().to_seconds() / TIME_WEIGHT_FACTOR), + // ); + // iono_taper *= &vis_weights; + // iono_taper + // }; + + // Temporary visibility array, re-used for each timestep + // TODO (Dev): I would name this vis_residual_rot_tfb + let mut vis_residual_tmp = vis_residual.to_owned(); + let high_res_vis_dims = vis_residual.dim(); + let mut vis_model_high_res = Array3::default(high_res_vis_dims); + + // temporary arrays for accumulation + // TODO: Do a stocktake of arrays that are lying around! 
+ // TODO (Dev): I would name this vis_residual_low_res_rot + let mut vis_residual_low_res: Array3> = Array3::zeros(( + num_timestamps_low_res, + num_freqs_low_res, + num_cross_baselines, + )); + let mut vis_model_low_res = vis_residual_low_res.clone(); + // TODO (Dev): I would name this vis_model_low_res_rot + let mut vis_model_low_res_tmp = vis_residual_low_res.clone(); + let mut vis_weights_low_res: Array3 = Array3::zeros(vis_residual_low_res.dim()); + + // The low-res weights only need to be populated once. + weights_average(vis_weights.view(), vis_weights_low_res.view_mut()); + + for pass in 0..num_passes { + for (((source_name, source), iono_consts), source_phase_centre) in source_list + .iter() + .take(num_sources_to_iono_subtract) + .zip_eq(iono_consts.iter_mut()) + .zip_eq(source_weighted_positions.iter().copied()) + { + multi_progress_bar.suspend(|| { + debug!("peel loop {pass}: {source_name} at {source_phase_centre} (has iono {iono_consts:?})") + }); + let start = std::time::Instant::now(); + let old_iono_consts = *iono_consts; + + low_res_modeller.update_with_a_source(source, source_phase_centre)?; + high_res_modeller.update_with_a_source(source, obs_context.phase_centre)?; + // this is only necessary for cpu modeller. + vis_model_low_res.fill(Jones::zero()); + + multi_progress_bar.suspend(|| { + trace!( + "{:?}: initialise modellers", + std::time::Instant::now() - start + ) + }); + // TODO (dev): model_timestep returns uvws, could re-use these here. 
+ // iterate along time chunks: + // - calculate high-res uvws in source phase centre + // - rotate residuals to source phase centre + // - model low res visibilities in source phase centre + // - calculate low-res uvws in source phase centre + for ( + timestamps, + vis_residual_tfb, + mut vis_residual_rot_tfb, + mut vis_model_low_res_rot_fb, + tile_ws_high_res, + mut tile_uvs_high_res_rot, + mut tile_ws_high_res_rot, + mut tile_uvs_low_res_rot, + ) in izip!( + timestamps.chunks(avg_time), + vis_residual.axis_chunks_iter(time_axis, avg_time), + vis_residual_tmp.axis_chunks_iter_mut(time_axis, avg_time), + vis_model_low_res.outer_iter_mut(), + tile_ws_from.axis_chunks_iter(time_axis, avg_time), + tile_uvs_high_res_rot.axis_chunks_iter_mut(time_axis, avg_time), + tile_ws_to.axis_chunks_iter_mut(time_axis, avg_time), + tile_uvs_low_res.outer_iter_mut(), + ) { + multi_progress_bar.suspend(|| { + trace!( + "{:?}: calc source uvw, rotate residual", + std::time::Instant::now() - start + ) + }); + // iterate along high res times + for ( + &(time, _), + vis_residual_fb, + mut vis_residual_rot_fb, + tile_ws_high_res, + mut tile_uvs_high_res_rot, + mut tile_ws_high_res_rot, + ) in izip!( + timestamps, + vis_residual_tfb.outer_iter(), + vis_residual_rot_tfb.outer_iter_mut(), + tile_ws_high_res.outer_iter(), + tile_uvs_high_res_rot.outer_iter_mut(), + tile_ws_high_res_rot.outer_iter_mut(), + ) { + let (lmst, precessed_xyzs) = if !no_precession { + let precession_info = precess_time( + array_position.longitude_rad, + array_position.latitude_rad, + obs_context.phase_centre, + time, + dut1, + ); + let precessed_xyzs = precession_info.precess_xyz(unflagged_tile_xyzs); + (precession_info.lmst_j2000, precessed_xyzs) + } else { + let lmst = get_lmst(array_position.longitude_rad, time, dut1); + (lmst, unflagged_tile_xyzs.into()) + }; + let hadec_source = source_phase_centre.to_hadec(lmst); + let (s_ha, c_ha) = hadec_source.ha.sin_cos(); + let (s_dec, c_dec) = 
hadec_source.dec.sin_cos(); + for (tile_uv, tile_w, &precessed_xyz) in izip!( + tile_uvs_high_res_rot.iter_mut(), + tile_ws_high_res_rot.iter_mut(), + precessed_xyzs.iter(), + ) { + let UVW { u, v, w } = + UVW::from_xyz_inner(precessed_xyz, s_ha, c_ha, s_dec, c_dec); + *tile_uv = UV { u, v }; + *tile_w = W(w); + } + + vis_rotate_fb( + vis_residual_fb.view(), + vis_residual_rot_fb.view_mut(), + tile_ws_high_res.as_slice().unwrap(), + tile_ws_high_res_rot.as_slice().unwrap(), + all_fine_chan_lambdas_m, + ); + } + multi_progress_bar + .suspend(|| trace!("{:?}: low-res uvws", std::time::Instant::now() - start)); + + let low_res_epoch = timeblock.median; + // compute low-res tile UVs at source phase centre. + let (lmst, precessed_xyzs) = if !no_precession { + let precession_info = precess_time( + array_position.longitude_rad, + array_position.latitude_rad, + obs_context.phase_centre, + low_res_epoch, + dut1, + ); + let precessed_xyzs = precession_info.precess_xyz(unflagged_tile_xyzs); + (precession_info.lmst_j2000, precessed_xyzs) + } else { + let lmst = get_lmst(array_position.longitude_rad, low_res_epoch, dut1); + (lmst, unflagged_tile_xyzs.into()) + }; + let hadec_source = source_phase_centre.to_hadec(lmst); + setup_uvs( + tile_uvs_low_res_rot.as_slice_mut().unwrap(), + &precessed_xyzs, + hadec_source, + ); + + multi_progress_bar + .suspend(|| trace!("{:?}: low-res model", std::time::Instant::now() - start)); + low_res_modeller + .model_timestep_with(low_res_epoch, vis_model_low_res_rot_fb.view_mut())?; + } + + multi_progress_bar + .suspend(|| trace!("{:?}: vis_average", std::time::Instant::now() - start)); + vis_average2( + vis_residual_tmp.view(), + vis_residual_low_res.view_mut(), + vis_weights.view(), + ); + + // Add the low-res model to the residuals. If the iono consts are + // non-zero, then also shift the model before adding it. 
+ multi_progress_bar + .suspend(|| trace!("{:?}: add low-res model", std::time::Instant::now() - start)); + if iono_consts.0.abs() > f64::EPSILON && iono_consts.1.abs() > f64::EPSILON { + apply_iono2( + vis_model_low_res.view(), + vis_model_low_res_tmp.view_mut(), + tile_uvs_low_res.view(), + *iono_consts, + low_res_lambdas_m, + ); + Zip::from(&mut vis_residual_low_res) + .and(&vis_model_low_res_tmp) + .for_each(|r, m| { + *r += *m; + }); + } else { + Zip::from(&mut vis_residual_low_res) + .and(&vis_model_low_res) + .for_each(|r, m| { + *r += *m; + }); + } + + multi_progress_bar + .suspend(|| trace!("{:?}: alpha/beta loop", std::time::Instant::now() - start)); + // let mut gain_update = 1.0; + let mut iteration = 0; + while iteration != 10 { + iteration += 1; + multi_progress_bar.suspend(|| debug!("iter {iteration}, consts: {iono_consts:?}")); + + // iono rotate model using existing iono consts + apply_iono2( + vis_model_low_res.view(), + vis_model_low_res_tmp.view_mut(), + tile_uvs_low_res.view(), + *iono_consts, + low_res_lambdas_m, + ); + + let iono_fits = iono_fit( + vis_residual_low_res.view(), + vis_weights_low_res.view(), + vis_model_low_res_tmp.view(), + low_res_lambdas_m, + tile_uvs_low_res.view(), + ); + multi_progress_bar.suspend(|| trace!("iono_fits: {iono_fits:?}")); + + iono_consts.0 += iono_fits[0]; + iono_consts.1 += iono_fits[1]; + // gain_update *= iono_fits[2] / iono_fits[3]; + // vis_model_low_res + // .iter_mut() + // .for_each(|v| *v *= gain_update as f32); + + // if the offset is small, we've converged. 
+ if iono_fits[0].abs() < 1e-12 && iono_fits[1].abs() < 1e-12 { + debug!("iter {iteration}, consts: {iono_consts:?}, finished"); + break; + } + } + + multi_progress_bar + .suspend(|| trace!("{:?}: high res model", std::time::Instant::now() - start)); + vis_model_high_res.fill(Jones::default()); + model_timesteps( + high_res_modeller, + &timestamps.mapped_ref(|(e, _)| *e), + vis_model_high_res.view_mut(), + )?; + + multi_progress_bar + .suspend(|| trace!("{:?}: apply_iono3", std::time::Instant::now() - start)); + // add the model to residual, and subtract the iono rotated model + apply_iono3( + vis_model_high_res.view(), + vis_residual.view_mut(), + // TODO: pretty sure this needs to be tile_uvs_high_res_rot + // tile_uvs_high_res.view(), + tile_uvs_high_res_rot.view(), + *iono_consts, + old_iono_consts, + all_fine_chan_lambdas_m, + ); + + multi_progress_bar.suspend(|| { + debug!( + "peel loop finished: {source_name} at {source_phase_centre} (has iono {iono_consts:?})" + ) + }); + peel_progress.inc(1); + } + } + + Ok(()) +} + +#[cfg(any(feature = "cuda", feature = "hip"))] +#[allow(clippy::too_many_arguments)] +fn peel_gpu( + mut vis_residual: ArrayViewMut3>, + vis_weights: ArrayView3, + timeblock: &Timeblock, + source_list: &SourceList, + iono_consts: &mut [(f64, f64)], + source_weighted_positions: &[RADec], + num_passes: usize, + low_res_freqs_hz: &[f64], + all_fine_chan_lambdas_m: &[f64], + low_res_lambdas_m: &[f64], + obs_context: &ObsContext, + array_position: LatLngHeight, + unflagged_tile_xyzs: &[XyzGeodetic], + low_res_modeller: &mut SkyModellerGpu, + high_res_modeller: &mut SkyModellerGpu, + dut1: Duration, + no_precession: bool, + multi_progress_bar: &MultiProgress, +) -> Result<(), PeelError> { + use crate::gpu::gpu_kernel_call; + + let num_sources_to_iono_subtract = iono_consts.len(); + + let timestamps = &timeblock.timestamps; + let num_timestamps_high_res = timestamps.len(); + let num_timestamps_low_res = 1; + + let num_tiles = unflagged_tile_xyzs.len(); + 
let num_cross_baselines = (num_tiles * (num_tiles - 1)) / 2; + + let num_freqs_high_res = all_fine_chan_lambdas_m.len(); + let num_freqs_low_res = low_res_lambdas_m.len(); + + let num_sources = source_list.len(); + let num_sources_to_iono_subtract = iono_consts.len(); + + let (time_axis, freq_axis, baseline_axis) = (Axis(0), Axis(1), Axis(2)); + + assert_eq!(vis_residual.len_of(time_axis), num_timestamps_high_res); + assert_eq!(vis_weights.len_of(time_axis), num_timestamps_high_res); + + assert_eq!(vis_residual.len_of(baseline_axis), num_cross_baselines); + assert_eq!(vis_weights.len_of(baseline_axis), num_cross_baselines); + + assert_eq!(vis_residual.len_of(freq_axis), num_freqs_high_res); + assert_eq!(vis_weights.len_of(freq_axis), num_freqs_high_res); + + assert_eq!(iono_consts.len(), num_sources_to_iono_subtract); + assert!(num_sources_to_iono_subtract <= num_sources); + + // TODO: Do we allow multiple timesteps in the low-res data? + + let timestamps = &timeblock.timestamps; + let peel_progress = multi_progress_bar.add( + ProgressBar::new(num_sources_to_iono_subtract as _) + .with_style( + ProgressStyle::default_bar() + .template("{msg:17}: [{wide_bar:.blue}] {pos:2}/{len:2} sources ({elapsed_precise}<{eta_precise})").unwrap() + .progress_chars("=> "), + ) + .with_position(0) + ); + peel_progress.tick(); + + let num_timesteps = vis_residual.len_of(Axis(0)); + let num_tiles = unflagged_tile_xyzs.len(); + let num_cross_baselines = (num_tiles * (num_tiles - 1)) / 2; + + let mut lmsts = vec![0.; timestamps.len()]; + let mut latitudes = vec![0.; timestamps.len()]; + let mut tile_xyzs_high_res = Array2::::default((timestamps.len(), num_tiles)); + let mut high_res_uvws = Array2::default((timestamps.len(), num_cross_baselines)); + let mut tile_uvs_high_res = Array2::::default((timestamps.len(), num_tiles)); + let mut tile_ws_high_res = Array2::::default((timestamps.len(), num_tiles)); + + // Pre-compute high-res tile UVs and Ws at observation phase centre. 
+ for ( + &(time, _), + lmst, + latitude, + mut tile_xyzs_high_res, + mut high_res_uvws, + mut tile_uvs_high_res, + mut tile_ws_high_res, + ) in izip!( + timestamps.iter(), + lmsts.iter_mut(), + latitudes.iter_mut(), + tile_xyzs_high_res.outer_iter_mut(), + high_res_uvws.outer_iter_mut(), + tile_uvs_high_res.outer_iter_mut(), + tile_ws_high_res.outer_iter_mut(), + ) { + if !no_precession { + let precession_info = precess_time( + array_position.longitude_rad, + array_position.latitude_rad, + obs_context.phase_centre, + time, + dut1, + ); + tile_xyzs_high_res + .iter_mut() + .zip_eq(&precession_info.precess_xyz(unflagged_tile_xyzs)) + .for_each(|(a, b)| *a = *b); + *lmst = precession_info.lmst_j2000; + *latitude = precession_info.array_latitude_j2000; + } else { + tile_xyzs_high_res + .iter_mut() + .zip_eq(unflagged_tile_xyzs) + .for_each(|(a, b)| *a = *b); + *lmst = get_lmst(array_position.longitude_rad, time, dut1); + *latitude = array_position.latitude_rad; + }; + let hadec_phase = obs_context.phase_centre.to_hadec(*lmst); + let (s_ha, c_ha) = hadec_phase.ha.sin_cos(); + let (s_dec, c_dec) = hadec_phase.dec.sin_cos(); + let mut tile_uvws_high_res = vec![UVW::default(); num_tiles]; + for (tile_uvw, tile_uv, tile_w, &tile_xyz) in izip!( + tile_uvws_high_res.iter_mut(), + tile_uvs_high_res.iter_mut(), + tile_ws_high_res.iter_mut(), + tile_xyzs_high_res.iter(), + ) { + let uvw = UVW::from_xyz_inner(tile_xyz, s_ha, c_ha, s_dec, c_dec); + *tile_uvw = uvw; + *tile_uv = UV { u: uvw.u, v: uvw.v }; + *tile_w = W(uvw.w); + } + + // The UVWs for every timestep will be the same (because the phase + // centres are always the same). Make these ahead of time for + // efficiency. 
+ let mut count = 0; + for (i, t1) in tile_uvws_high_res.iter().enumerate() { + for t2 in tile_uvws_high_res.iter().skip(i + 1) { + high_res_uvws[count] = *t1 - *t2; + count += 1; + } + } + } + + // use the baseline taper from the RTS, 1-exp(-(u*u+v*v)/(2*sig^2)); + // let short_baseline_sigma = 20.; + // TODO: Do we care about weights changing over time? + // let vis_weights = { + // let mut iono_taper = get_weights_rts( + // tile_uvs_high_res.view(), + // all_fine_chan_lambdas_m, + // short_baseline_sigma, + // (obs_context.guess_freq_res() / FREQ_WEIGHT_FACTOR) + // * (obs_context.guess_time_res().to_seconds() / TIME_WEIGHT_FACTOR), + // ); + // iono_taper *= &vis_weights; + // iono_taper + // }; + + let (average_lmst, average_latitude, average_tile_xyzs) = if no_precession { + let average_timestamp = timeblock.median; + let average_tile_xyzs = + ArrayView2::from_shape((1, num_tiles), unflagged_tile_xyzs).expect("correct shape"); + ( + get_lmst(array_position.longitude_rad, average_timestamp, dut1), + array_position.latitude_rad, + CowArray::from(average_tile_xyzs), + ) + } else { + let average_timestamp = timeblock.median; + let average_precession_info = precess_time( + array_position.longitude_rad, + array_position.latitude_rad, + obs_context.phase_centre, + average_timestamp, + dut1, + ); + let average_precessed_tile_xyzs = Array2::from_shape_vec( + (1, num_tiles), + average_precession_info.precess_xyz(unflagged_tile_xyzs), + ) + .expect("correct shape"); + + ( + average_precession_info.lmst_j2000, + average_precession_info.array_latitude_j2000, + CowArray::from(average_precessed_tile_xyzs), + ) + }; + + // temporary arrays for accumulation + // TODO: Do a stocktake of arrays that are lying around! 
+ // These are time, bl, channel + let vis_residual_low_res: Array3> = + Array3::zeros((1, low_res_freqs_hz.len(), num_cross_baselines)); + let mut vis_weights_low_res: Array3 = Array3::zeros(vis_residual_low_res.dim()); + + // The low-res weights only need to be populated once. + weights_average(vis_weights.view(), vis_weights_low_res.view_mut()); + + let freq_average_factor = all_fine_chan_lambdas_m.len() / low_res_freqs_hz.len(); + + unsafe { + let cuda_xyzs_high_res: Vec<_> = tile_xyzs_high_res + .iter() + .copied() + .map(|XyzGeodetic { x, y, z }| gpu::XYZ { + x: x as GpuFloat, + y: y as GpuFloat, + z: z as GpuFloat, + }) + .collect(); + let mut cuda_uvws = Array2::from_elem( + (num_timesteps, num_cross_baselines), + gpu::UVW { + u: 0.0, + v: 0.0, + w: 0.0, + }, + ); + cuda_uvws + .outer_iter_mut() + .zip(tile_xyzs_high_res.outer_iter()) + .zip(lmsts.iter()) + .for_each(|((mut cuda_uvws, xyzs), lmst)| { + let phase_centre = obs_context.phase_centre.to_hadec(*lmst); + let v = xyzs_to_cross_uvws(xyzs.as_slice().unwrap(), phase_centre) + .into_iter() + .map(|uvw| gpu::UVW { + u: uvw.u as GpuFloat, + v: uvw.v as GpuFloat, + w: uvw.w as GpuFloat, + }) + .collect::>(); + cuda_uvws.assign(&ArrayView1::from(&v)); + }); + let cuda_lmsts: Vec = lmsts.iter().map(|l| *l as GpuFloat).collect(); + let cuda_lambdas: Vec = all_fine_chan_lambdas_m + .iter() + .map(|l| *l as GpuFloat) + .collect(); + let cuda_xyzs_low_res: Vec<_> = average_tile_xyzs + .iter() + .copied() + .map(|XyzGeodetic { x, y, z }| gpu::XYZ { + x: x as _, + y: y as _, + z: z as _, + }) + .collect(); + let cuda_low_res_lambdas: Vec = + low_res_lambdas_m.iter().map(|l| *l as GpuFloat).collect(); + + let d_xyzs = DevicePointer::copy_to_device(&cuda_xyzs_high_res).unwrap(); + let d_uvws_from = DevicePointer::copy_to_device(cuda_uvws.as_slice().unwrap()).unwrap(); + let mut d_uvws_to = + DevicePointer::malloc(cuda_uvws.len() * std::mem::size_of::()).unwrap(); + let d_lmsts = 
DevicePointer::copy_to_device(&cuda_lmsts).unwrap(); + let d_lambdas = DevicePointer::copy_to_device(&cuda_lambdas).unwrap(); + let d_xyzs_low_res = DevicePointer::copy_to_device(&cuda_xyzs_low_res).unwrap(); + let d_average_lmsts = DevicePointer::copy_to_device(&[average_lmst as GpuFloat]).unwrap(); + let mut d_uvws_source_low_res: DevicePointer = + DevicePointer::malloc(cuda_uvws.len() * std::mem::size_of::()).unwrap(); + let d_low_res_lambdas = DevicePointer::copy_to_device(&cuda_low_res_lambdas).unwrap(); + // Make the amount of elements in `d_iono_fits` a power of 2, for + // efficiency. + let mut d_iono_fits = { + let min_size = + num_cross_baselines * low_res_freqs_hz.len() * std::mem::size_of::>(); + let n = (min_size as f64).log2().ceil() as u32; + let size = 2_usize.pow(n); + let mut d: DevicePointer> = DevicePointer::malloc(size).unwrap(); + d.clear(); + d + }; + + // let mut d_low_res_vis = DevicePointer::malloc( + // num_cross_baselines * low_res_freqs_hz.len() * std::mem::size_of::>(), + // ); + // let mut d_low_res_weights = DevicePointer::malloc( + // num_cross_baselines * low_res_freqs_hz.len() * std::mem::size_of::(), + // ); + + let mut d_high_res_vis = + DevicePointer::copy_to_device(vis_residual.as_slice().unwrap()).unwrap(); + let d_high_res_weights = + DevicePointer::copy_to_device(vis_weights.as_slice().unwrap()).unwrap(); + + let mut d_low_res_vis = + DevicePointer::copy_to_device(vis_residual_low_res.as_slice().unwrap()).unwrap(); + let d_low_res_weights = + DevicePointer::copy_to_device(vis_weights_low_res.as_slice().unwrap()).unwrap(); + let mut d_low_res_model = + DevicePointer::copy_to_device(vis_residual_low_res.as_slice().unwrap()).unwrap(); + let mut d_low_res_model_rotated = + DevicePointer::copy_to_device(vis_residual_low_res.as_slice().unwrap()).unwrap(); + + let mut d_high_res_model: DevicePointer> = DevicePointer::malloc( + timestamps.len() + * num_cross_baselines + * all_fine_chan_lambdas_m.len() + * std::mem::size_of::>(), + 
) + .unwrap(); + + // One pointer per timestep. + let mut d_uvws = Vec::with_capacity(high_res_uvws.len_of(Axis(0))); + // Temp vector to store results. + let mut cuda_uvws = vec![ + gpu::UVW { + u: 0.0, + v: 0.0, + w: 0.0 + }; + high_res_uvws.len_of(Axis(1)) + ]; + for uvws in high_res_uvws.outer_iter() { + // Convert the type and push the results to the device, + // saving the resulting pointer. + uvws.iter() + .zip_eq(cuda_uvws.iter_mut()) + .for_each(|(&UVW { u, v, w }, cuda_uvw)| { + *cuda_uvw = gpu::UVW { + u: u as GpuFloat, + v: v as GpuFloat, + w: w as GpuFloat, + } + }); + d_uvws.push(DevicePointer::copy_to_device(&cuda_uvws).unwrap()); + } + let mut d_beam_jones = DevicePointer::default(); + + for pass in 0..num_passes { + peel_progress.reset(); + peel_progress.set_message(format!( + "Peeling timeblock {}, pass {}", + timeblock.index + 1, + pass + 1 + )); + + for (((source_name, source), iono_consts), source_phase_centre) in source_list + .iter() + .take(num_sources_to_iono_subtract) + .zip_eq(iono_consts.iter_mut()) + .zip_eq(source_weighted_positions.iter().copied()) + { + let start = std::time::Instant::now(); + multi_progress_bar.suspend(|| { + debug!( + "peel loop {pass}: {source_name} at {source_phase_centre} (has iono {iono_consts:?})" + ) + }); + + let old_iono_consts = *iono_consts; + + gpu_kernel_call!("rotate_average", || gpu::rotate_average( + d_high_res_vis.get().cast(), + d_high_res_weights.get(), + d_low_res_vis.get_mut().cast(), + gpu::RADec { + ra: source_phase_centre.ra as _, + dec: source_phase_centre.dec as _, + }, + timestamps.len().try_into().unwrap(), + num_tiles.try_into().unwrap(), + num_cross_baselines.try_into().unwrap(), + all_fine_chan_lambdas_m.len().try_into().unwrap(), + freq_average_factor.try_into().unwrap(), + d_lmsts.get(), + d_xyzs.get(), + d_uvws_from.get(), + d_uvws_to.get_mut(), + d_lambdas.get(), + ))?; + + multi_progress_bar + .suspend(|| trace!("{:?}: rotate_average", std::time::Instant::now() - start)); + + 
low_res_modeller.update_with_a_source(source, source_phase_centre)?; + multi_progress_bar.suspend(|| { + trace!( + "{:?}: low res update and clear", + std::time::Instant::now() - start + ) + }); + + gpu_kernel_call!("xyzs_to_uvws", || gpu::xyzs_to_uvws( + d_xyzs_low_res.get(), + d_average_lmsts.get(), + d_uvws_source_low_res.get_mut(), + gpu::RADec { + ra: source_phase_centre.ra as GpuFloat, + dec: source_phase_centre.dec as GpuFloat, + }, + num_tiles.try_into().unwrap(), + num_cross_baselines.try_into().unwrap(), + 1, + ))?; + + multi_progress_bar.suspend(|| { + trace!( + "{:?}: low res xyzs_to_uvws", + std::time::Instant::now() - start + ) + }); + + d_low_res_model.clear(); + low_res_modeller.model_timestep_with( + average_lmst, + average_latitude, + &d_uvws_source_low_res, + &mut d_beam_jones, + &mut d_low_res_model, + )?; + multi_progress_bar + .suspend(|| trace!("{:?}: low res model", std::time::Instant::now() - start)); + + gpu_kernel_call!("add_model", || gpu::add_model( + d_low_res_vis.get_mut().cast(), + d_low_res_model.get().cast(), + iono_consts.0 as GpuFloat, + iono_consts.1 as GpuFloat, + d_low_res_lambdas.get(), + d_uvws_source_low_res.get(), + vis_residual_low_res.len_of(Axis(0)).try_into().unwrap(), + num_cross_baselines.try_into().unwrap(), + low_res_freqs_hz.len().try_into().unwrap(), + ))?; + + #[cfg(test)] + { + let vis_residual_low_res = d_low_res_vis.copy_from_device_new().unwrap(); + let vis_residual_low_res = Array3::from_shape_vec( + (1, low_res_freqs_hz.len(), num_cross_baselines), + vis_residual_low_res, + ) + .unwrap(); + + let vis_model_low_res = d_low_res_model.copy_from_device_new().unwrap(); + let vis_model_low_res = Array3::from_shape_vec( + (1, low_res_freqs_hz.len(), num_cross_baselines), + vis_model_low_res, + ) + .unwrap(); + + let uvws_low_res = d_uvws_source_low_res.copy_from_device_new()?; + + for (vis_residual_low_res, vis_model_low_res, &lambda) in izip!( + vis_residual_low_res.slice(s![0, .., ..]).outer_iter(), + 
vis_model_low_res.slice(s![0, .., ..]).outer_iter(), + low_res_lambdas_m + ) { + for (residual, model, &gpu::UVW { u, v, w: _ }) in izip!( + vis_residual_low_res.iter(), + vis_model_low_res.iter(), + &uvws_low_res, + ) { + let residual_i = residual[0] + residual[3]; + let model_i = model[0] + model[3]; + let u = u / lambda as GpuFloat; + let v = v / lambda as GpuFloat; + + println!("uv ({:+1.5}, {:+1.5}) l{:+1.3} | RI {:+1.5} @{:+1.5}pi | MI {:+1.5} @{:+1.5}pi", u, v, lambda, residual_i.norm(), residual_i.arg(), model_i.norm(), model_i.arg()); + } + } + } + + gpu_kernel_call!("iono_loop", || gpu::iono_loop( + d_low_res_vis.get().cast(), + d_low_res_weights.get(), + d_low_res_model.get().cast(), + d_low_res_model_rotated.get_mut().cast(), + d_iono_fits.get_mut().cast(), + &mut iono_consts.0, + &mut iono_consts.1, + num_timesteps.try_into().unwrap(), + num_tiles.try_into().unwrap(), + num_cross_baselines.try_into().unwrap(), + low_res_freqs_hz.len().try_into().unwrap(), + 10, + d_average_lmsts.get(), + d_uvws_source_low_res.get(), + d_low_res_lambdas.get(), + ))?; + + multi_progress_bar + .suspend(|| trace!("{:?}: iono_loop", std::time::Instant::now() - start)); + + high_res_modeller.update_with_a_source(source, obs_context.phase_centre)?; + // high_res_modeller.clear_vis(); + // Clear the old memory before reusing the buffer. 
+ d_high_res_model.clear(); + d_uvws + .iter() + .zip(lmsts.iter()) + .zip(latitudes.iter()) + .enumerate() + .try_for_each(|(i_time, ((d_uvws, lmst), latitude))| { + let original = d_high_res_model.ptr; + d_high_res_model.ptr = d_high_res_model + .ptr + .add(i_time * num_cross_baselines * all_fine_chan_lambdas_m.len()); + let result = SkyModellerGpu::model_timestep_with( + high_res_modeller, + *lmst, + *latitude, + d_uvws, + &mut d_beam_jones, + &mut d_high_res_model, + ); + d_high_res_model.ptr = original; + result + })?; + multi_progress_bar + .suspend(|| trace!("{:?}: high res model", std::time::Instant::now() - start)); + + gpu_kernel_call!("subtract_iono", || gpu::subtract_iono( + d_high_res_vis.get_mut().cast(), + d_high_res_model.get().cast(), + iono_consts.0, + iono_consts.1, + old_iono_consts.0, + old_iono_consts.1, + d_uvws_to.get(), + d_lambdas.get(), + num_timesteps.try_into().unwrap(), + num_cross_baselines.try_into().unwrap(), + all_fine_chan_lambdas_m.len().try_into().unwrap(), + ))?; + + multi_progress_bar + .suspend(|| trace!("{:?}: subtract_iono", std::time::Instant::now() - start)); + debug!("peel loop finished: {source_name} at {source_phase_centre} (has iono {iono_consts:?})"); + + peel_progress.inc(1); + } + } + + // copy results back to host + d_high_res_vis + .copy_from_device(vis_residual.as_slice_mut().unwrap()) + .unwrap(); + } + + Ok(()) +} + +/// Just the W terms of [`UVW`] coordinates. 
+#[derive(Clone, Copy, Default, PartialEq, Debug)]
+struct W(f64);
+
+impl W {
+    /// Compute only the `w` term for a tile position, given the sine and
+    /// cosine of the hour angle and declination.
+    ///
+    /// The leading underscore suggests this is currently unused - confirm.
+    fn _from_xyz(xyz: XyzGeodetic, s_ha: f64, c_ha: f64, s_dec: f64, c_dec: f64) -> W {
+        W(c_dec * c_ha * xyz.x - c_dec * s_ha * xyz.y + s_dec * xyz.z)
+    }
+}
+
+/// Subtracting two tile `W`s yields the plain `f64` difference (not a `W`).
+impl Sub for W {
+    type Output = f64;
+
+    fn sub(self, rhs: Self) -> Self::Output {
+        self.0 - rhs.0
+    }
+}
+
+/// Negation flips the sign of the wrapped value.
+impl Neg for W {
+    type Output = Self;
+
+    fn neg(self) -> Self::Output {
+        W(-self.0)
+    }
+}
+
+/// Test-only approximate comparison, delegating to `f64`'s implementation.
+#[cfg(test)]
+impl approx::AbsDiffEq for W {
+    type Epsilon = f64;
+
+    fn default_epsilon() -> Self::Epsilon {
+        f64::EPSILON
+    }
+
+    fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
+        f64::abs_diff_eq(&self.0, &other.0, epsilon)
+    }
+}
+
+/// Just the U and V terms of [`UVW`] coordinates.
+#[derive(Clone, Copy, Default, PartialEq, Debug)]
+struct UV {
+    u: f64,
+    v: f64,
+}
+
+impl UV {
+    /// Compute only the `u` and `v` terms for a tile position, given the sine
+    /// and cosine of the hour angle and declination.
+    fn from_xyz(xyz: XyzGeodetic, s_ha: f64, c_ha: f64, s_dec: f64, c_dec: f64) -> UV {
+        UV {
+            u: s_ha * xyz.x + c_ha * xyz.y,
+            v: -s_dec * c_ha * xyz.x + s_dec * s_ha * xyz.y + c_dec * xyz.z,
+        }
+    }
+}
+
+/// Component-wise difference, e.g. to turn two tile `UV`s into a baseline `UV`.
+impl Sub for UV {
+    type Output = UV;
+
+    fn sub(self, rhs: Self) -> Self::Output {
+        UV {
+            u: self.u - rhs.u,
+            v: self.v - rhs.v,
+        }
+    }
+}
+
+/// Divide both components by the same scalar.
+impl Div for UV {
+    type Output = UV;
+
+    fn div(self, rhs: f64) -> Self::Output {
+        UV {
+            u: self.u / rhs,
+            v: self.v / rhs,
+        }
+    }
+}
+
+/// Test-only approximate comparison; both components must be within `epsilon`.
+#[cfg(test)]
+impl approx::AbsDiffEq for UV {
+    type Epsilon = f64;
+
+    fn default_epsilon() -> Self::Epsilon {
+        f64::EPSILON
+    }
+
+    fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
+        f64::abs_diff_eq(&self.u, &other.u, epsilon) && f64::abs_diff_eq(&self.v, &other.v, epsilon)
+    }
+}
+
+/// Errors returned by the peeling code.
+///
+/// Every variant is a transparent wrapper: the message and source come from
+/// the wrapped error, and `#[from]` lets `?` convert into `PeelError`.
+#[derive(thiserror::Error, Debug)]
+pub(crate) enum PeelError {
+    #[error(transparent)]
+    VisRead(#[from] crate::io::read::VisReadError),
+
+    #[error(transparent)]
+    FileWrite(#[from] crate::io::write::FileWriteError),
+
+    #[error(transparent)]
+    Beam(#[from] crate::beam::BeamError),
+
+    #[error(transparent)]
+    Model(#[from] crate::model::ModelError),
+
+    #[error(transparent)]
+    IO(#[from] std::io::Error),
+
+    // Only present when the crate is built with a GPU feature.
+    #[cfg(any(feature = "cuda", feature = "hip"))]
+    #[error(transparent)]
+    Gpu(#[from] crate::gpu::GpuError),
+}
diff --git a/src/params/peel/tests.rs b/src/params/peel/tests.rs
new file mode 100644
index 00000000..d48283f0
--- /dev/null
+++ b/src/params/peel/tests.rs
@@ -0,0 +1,2459 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//! Tests against peeling
+
+use std::{collections::HashSet, f64::consts::TAU};
+
+use approx::assert_abs_diff_eq;
+use hifitime::{Duration, Epoch};
+use indexmap::indexmap;
+use indicatif::{MultiProgress, ProgressDrawTarget};
+use itertools::{izip, Itertools};
+use marlu::{
+    constants::VEL_C,
+    math::cross_correlation_baseline_to_tiles,
+    precession::{get_lmst, precess_time},
+    Complex, HADec, Jones, LatLngHeight, RADec, XyzGeodetic,
+};
+use ndarray::{prelude::*, Zip};
+use num_traits::Zero;
+use vec1::{vec1, Vec1};
+
+use super::*;
+use crate::{
+    averaging::Timeblock,
+    beam::{Delays, FEEBeam},
+    context::{ObsContext, Polarisations},
+    io::read::VisInputType,
+    model::{new_sky_modeller, SkyModellerCpu},
+    srclist::{ComponentType, FluxDensity, FluxDensityType, Source, SourceComponent, SourceList},
+};
+
+// a single-component point source, stokes I.
+// Arguments: position (`$radec`), spectral index (`$si`), reference frequency
+// (`$freq`) and Stokes I flux density (`$i`); Q, U and V are set to zero.
+macro_rules! point_src_i {
+    ($radec:expr, $si:expr, $freq:expr, $i:expr) => {
+        Source {
+            components: vec![SourceComponent {
+                radec: $radec,
+                comp_type: ComponentType::Point,
+                flux_type: FluxDensityType::PowerLaw {
+                    si: $si,
+                    fd: FluxDensity {
+                        freq: $freq,
+                        i: $i,
+                        q: 0.0,
+                        u: 0.0,
+                        v: 0.0,
+                    },
+                },
+            }]
+            .into_boxed_slice(),
+        }
+    };
+}
+
+/// Build an FEE beam for `num_tiles` tiles with all 16 dipole delays set to 0.
+///
+/// Panics if `FEEBeam::new_from_env` fails.
+fn get_beam(num_tiles: usize) -> FEEBeam {
+    #[rustfmt::skip]
+    let delays = vec![
+        0, 0, 0, 0,
+        0, 0, 0, 0,
+        0, 0, 0, 0,
+        0, 0, 0, 0,
+    ];
+    // https://github.com/MWATelescope/mwa_pb/blob/90d6fbfc11bf4fca35796e3d5bde3ab7c9833b66/mwa_pb/mwa_sweet_spots.py#L60
+    // let delays = vec![0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12];
+
+    FEEBeam::new_from_env(num_tiles, Delays::Partial(delays), None).unwrap()
+}
+
+// get a timestamp at lmst=0 around the year 2100
+// precessing to j2000 will introduce a noticeable difference.
+//
+// Starts from UTC midnight on 2100-01-01 and, if the LMST there differs from 0
+// by more than 1e-6 rad, steps back by the corresponding fraction of a day.
+fn get_j2100(array_position: &LatLngHeight, dut1: Duration) -> Epoch {
+    let mut epoch = Epoch::from_gregorian_utc_at_midnight(2100, 1, 1);
+
+    // shift zenith_time to the nearest time when the phase centre is at zenith
+    let sidereal2solar = 365.24 / 366.24;
+    let obs_lst_rad = get_lmst(array_position.longitude_rad, epoch, dut1);
+    if obs_lst_rad.abs() > 1e-6 {
+        epoch -= Duration::from_days(sidereal2solar * obs_lst_rad / TAU);
+    }
+    epoch
+}
+
+/// get 3 simple tiles:
+/// - tile "o" is at origin
+/// - tile "u" has a u-component of s at lambda = 1m
+/// - tile "v" has a v-component of s at lambda = 1m
+///
+/// Returns the tile names and the matching tile positions; `s_` is used as the
+/// `y` coordinate of tile "u" and the `z` coordinate of tile "v".
+#[rustfmt::skip]
+fn get_simple_tiles(s_: f64) -> (Vec1, Vec1) {
+    (
+        vec1!["o", "u", "v"].mapped(|s| s.into()),
+        vec1![
+            XyzGeodetic { x: 0., y: 0., z: 0., },
+            XyzGeodetic { x: 0., y: s_, z: 0., },
+            XyzGeodetic { x: 0., y: 0., z: s_, },
+        ],
+    )
+}
+
+/// get an observation context with:
+/// - array positioned at LatLngHeight = 0, 0, 100m
+/// - 2 timestamps:
+/// - first: phase centre is at zenith on j2100
+/// - second: an hour later,
+/// - 2 frequencies: lambda = 2m, 1m
+/// - tiles from 
[get_simple_tiles], s=1
+fn get_simple_obs_context() -> ObsContext {
+    let array_position = LatLngHeight {
+        longitude_rad: 0.,
+        latitude_rad: 0.,
+        height_metres: 100.,
+    };
+
+    let dut1 = Duration::from_seconds(0.0);
+    let obs_epoch = get_j2100(&array_position, dut1);
+
+    // at first timestep phase centre is at zenith
+    let lst_zenith_rad = get_lmst(array_position.longitude_rad, obs_epoch, dut1);
+    let phase_centre = RADec::from_hadec(
+        HADec::from_radians(0., array_position.latitude_rad),
+        lst_zenith_rad,
+    );
+
+    // second timestep is at 1h
+    let hour_epoch = obs_epoch + Duration::from_hours(1.0);
+    let timestamps = vec1![obs_epoch, hour_epoch];
+
+    let (tile_names, tile_xyzs) = get_simple_tiles(1.);
+    // Frequencies are derived from the wavelengths and truncated to integers.
+    let lambdas_m = vec1![2., 1.];
+    let fine_chan_freqs: Vec1 = lambdas_m.mapped(|l| (VEL_C / l) as u64);
+
+    ObsContext {
+        input_data_type: VisInputType::Raw,
+        obsid: None,
+        timestamps,
+        all_timesteps: vec1![0, 1],
+        unflagged_timesteps: vec![0, 1],
+        phase_centre,
+        pointing_centre: None,
+        array_position,
+        supplied_array_position: array_position,
+        dut1: Some(dut1),
+        tile_names,
+        tile_xyzs,
+        flagged_tiles: vec![],
+        unavailable_tiles: vec![],
+        autocorrelations_present: false,
+        dipole_delays: Some(Delays::Partial(vec![
+            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+        ])),
+        dipole_gains: None,
+        time_res: Some(hour_epoch - obs_epoch),
+        mwa_coarse_chan_nums: None,
+        num_fine_chans_per_coarse_chan: None,
+        freq_res: Some((fine_chan_freqs[1] - fine_chan_freqs[0]) as f64),
+        fine_chan_freqs,
+        flagged_fine_chans: vec![],
+        flagged_fine_chans_per_coarse_chan: None,
+        polarisations: Polarisations::default(),
+    }
+}
+
+/// get an observation context with:
+/// - array positioned at `LatLngHeight::mwa()`
+/// - 2 timestamps:
+///   - first: the epoch given by the obsid (as GPS seconds) read from
+///     `test_files/1090008640/1090008640.metafits`
+///   - second: an hour later
+/// - phase centre at hour angle 0 and declination equal to the array
+///   latitude, for the LMST of the first timestamp
+/// - 2 frequencies: lambda = 2m, 1m
+/// - tile names for every antenna in the metafits, but tile positions for
+///   only the first `tile_limit` (32) tiles
+///
+/// NOTE(review): `tile_names` is not truncated to `tile_limit`, so it can be
+/// longer than `tile_xyzs` - confirm that this is intended.
+///
+/// Panics if the metafits file cannot be read.
+fn get_complex_obs_context() -> ObsContext {
+    let tile_limit = 32;
+    let array_position = LatLngHeight::mwa();
+
+    let meta_path = "test_files/1090008640/1090008640.metafits";
+    let meta_ctx = mwalib::MetafitsContext::new(meta_path, None).unwrap();
+
+    let obsid = meta_ctx.obs_id;
+    let obs_time = Epoch::from_gpst_seconds(obsid as _);
+    let dut1 = Duration::from_seconds(0.);
+
+    // let obs_lst_rad = get_lmst(array_position.longitude_rad, obs_time, dut1);
+    // shift obs_time to the nearest time when the phase centre is at zenith
+    // NOTE(review): no shift of `obs_time` actually happens here; the phase
+    // centre is instead defined from the LMST at `obs_time`.
+    let zenith_lst_rad = get_lmst(array_position.longitude_rad, obs_time, dut1);
+    eprintln!("lst % 𝜏 should be 0: {zenith_lst_rad:?}");
+    let phase_centre = RADec::from_hadec(
+        HADec::from_radians(0., array_position.latitude_rad),
+        zenith_lst_rad,
+    );
+    eprintln!("phase centre: {phase_centre:?}");
+    let hadec = phase_centre.to_hadec(zenith_lst_rad);
+    eprintln!("ha % 𝜏 should be 0: {hadec:?}");
+    let azel = hadec.to_azel(array_position.latitude_rad);
+    eprintln!("(az, el) % 𝜏 should be 0, pi/2: {azel:?}");
+    let tile_names: Vec = meta_ctx
+        .antennas
+        .iter()
+        .map(|ant| ant.tile_name.clone())
+        .collect();
+    let tile_names = Vec1::try_from_vec(tile_names).unwrap();
+    let tile_xyzs: Vec = XyzGeodetic::get_tiles_mwa(&meta_ctx)
+        .into_iter()
+        .take(tile_limit)
+        .collect();
+    let tile_xyzs = Vec1::try_from_vec(tile_xyzs).unwrap();
+
+    // at first timestep phase centre is at zenith
+    // NOTE(review): these two bindings repeat the `zenith_lst_rad` /
+    // `phase_centre` computation above with identical inputs.
+    let lst_zenith_rad = get_lmst(array_position.longitude_rad, obs_time, dut1);
+    let phase_centre = RADec::from_hadec(
+        HADec::from_radians(0., array_position.latitude_rad),
+        lst_zenith_rad,
+    );
+
+    // second timestep is at 1h
+    let hour_epoch = obs_time + Duration::from_hours(1.0);
+    let timestamps = vec1![obs_time, hour_epoch];
+
+    let lambdas_m = vec1![2., 1.];
+    let fine_chan_freqs: Vec1 = lambdas_m.mapped(|l| (VEL_C / l) as u64);
+
+    ObsContext {
+        input_data_type: VisInputType::Raw,
+        obsid: None,
+        timestamps,
+        all_timesteps: vec1![0, 1],
+        unflagged_timesteps: vec![0, 1],
+        phase_centre,
+        pointing_centre: None,
+        array_position,
+        supplied_array_position: array_position,
+        dut1: Some(dut1),
+        tile_names,
+        tile_xyzs,
+        flagged_tiles: vec![],
+        unavailable_tiles: vec![],
+        autocorrelations_present: false,
+        dipole_delays: Some(Delays::Partial(vec![
+            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+        ])),
+        dipole_gains: None,
+        time_res: Some(hour_epoch - obs_time),
+        mwa_coarse_chan_nums: None,
+        num_fine_chans_per_coarse_chan: None,
+        freq_res: Some((fine_chan_freqs[1] - fine_chan_freqs[0]) as f64),
+        fine_chan_freqs,
+        flagged_fine_chans: vec![],
+        flagged_fine_chans_per_coarse_chan: None,
+        polarisations: Polarisations::default(),
+    }
+}
+
+// these are used for debugging the tests
+use ndarray::{ArrayView2, ArrayView3};
+/// Print a header line followed by one row per baseline for a single time
+/// and channel.
+///
+/// Each row shows the two tile names, the baseline u, v, w divided by
+/// `lambda`, and for each of the four Jones elements its magnitude and its
+/// phase as a fraction of pi.
+///
+/// `uvs`, `ws` and `tile_names` are indexed by the tile numbers in
+/// `ant_pairs`. Panics if `vis_b` and `ant_pairs` differ in length.
+#[allow(clippy::too_many_arguments)]
+fn display_vis_b(
+    name: &String,
+    vis_b: &[Jones],
+    ant_pairs: &[(usize, usize)],
+    uvs: &[UV],
+    ws: &[W],
+    tile_names: &[String],
+    seconds: f64,
+    lambda: f64,
+) {
+    use std::f64::consts::PI;
+    println!(
+        "bl u v w | @ time={:>+9.3}s, lam={:+1.3}m, {}",
+        seconds, lambda, name
+    );
+    for (jones, &(ant1, ant2)) in vis_b.iter().zip_eq(ant_pairs.iter()) {
+        // Baseline coordinates are differences of the per-tile values.
+        let uv = uvs[ant1] - uvs[ant2];
+        let w = ws[ant1] - ws[ant2];
+        let (name1, name2) = (&tile_names[ant1], &tile_names[ant2]);
+        let (xx, xy, yx, yy) = jones.iter().collect_tuple().unwrap();
+        println!(
+            "{:1}-{:1} {:+1.5} {:+1.5} {:+1.5} | \
+            XX {:07.5} @{:+08.5}pi XY {:07.5} @{:+08.5}pi \
+            YX {:07.5} @{:+08.5}pi YY {:07.5} @{:+08.5}pi",
+            name1,
+            name2,
+            uv.u / lambda,
+            uv.v / lambda,
+            w / lambda,
+            xx.norm(),
+            xx.arg() as f64 / PI,
+            xy.norm(),
+            xy.arg() as f64 / PI,
+            yx.norm(),
+            yx.arg() as f64 / PI,
+            yy.norm(),
+            yy.arg() as f64 / PI,
+        );
+    }
+}
+
+/// Call [display_vis_b] once per channel of a [channel, baseline] view,
+/// pairing each row with its wavelength from `lambdas_m`.
+///
+/// Panics if the number of rows differs from `lambdas_m.len()` or if a row is
+/// not contiguous.
+#[allow(clippy::too_many_arguments)]
+fn display_vis_fb(
+    name: &String,
+    vis_fb: ArrayView2>,
+    seconds: f64,
+    uvs: &[UV],
+    ws: &[W],
+    lambdas_m: &[f64],
+    ant_pairs: &[(usize, usize)],
+    tile_names: &[String],
+) {
+    // println!("{:9} | {:7} | bl u v w | {}", "time", "lam", name);
+    for (vis_b, &lambda) in vis_fb.outer_iter().zip_eq(lambdas_m.iter()) {
+        display_vis_b(
+            name,
+            vis_b.as_slice().unwrap(),
+            ant_pairs,
+            uvs,
+            ws,
+            tile_names,
+            seconds,
+            lambda,
+        );
+    }
+}
+
+/// display named visibilities and uvws in table format
+///
+/// For every timestamp in `obs_context` the tile UVs and Ws are recomputed
+/// toward `phase_centre` (from precessed tile positions when
+/// `apply_precession` is set) and the [channel, baseline] slice is handed to
+/// [display_vis_fb]. Times are printed relative to the first timestamp.
+///
+/// NOTE(review): precession here is evaluated with `phase_centre`, whereas
+/// `setup_tile_uv_w_arrays` uses `obs_context.phase_centre` - confirm which is
+/// intended.
+#[allow(clippy::too_many_arguments)]
+fn display_vis_tfb(
+    name: &String,
+    vis_tfb: ArrayView3>,
+    obs_context: &ObsContext,
+    phase_centre: RADec,
+    apply_precession: bool,
+) {
+    let array_pos = obs_context.array_position;
+    let num_tiles = obs_context.get_total_num_tiles();
+    let num_baselines = (num_tiles * (num_tiles - 1)) / 2;
+    // Tile pair for each cross-correlation baseline index.
+    let ant_pairs = (0..num_baselines)
+        .map(|bl_idx| cross_correlation_baseline_to_tiles(num_tiles, bl_idx))
+        .collect_vec();
+    let fine_chan_freqs_hz = obs_context
+        .fine_chan_freqs
+        .iter()
+        .map(|&f| f as f64)
+        .collect_vec();
+    let lambdas_m = fine_chan_freqs_hz.iter().map(|&f| VEL_C / f).collect_vec();
+
+    let start_seconds = obs_context.timestamps[0].to_gpst_seconds();
+    // Scratch buffers reused for every timestamp.
+    let mut tile_uvs_tmp = vec![UV::default(); num_tiles];
+    let mut tile_ws_tmp = vec![W::default(); num_tiles];
+    // println!("{:9} | {:7} | bl u v w | {}", "time", "lam", name);
+    for (vis_fb, &time) in vis_tfb.outer_iter().zip_eq(obs_context.timestamps.iter()) {
+        if apply_precession {
+            let precession_info = precess_time(
+                array_pos.longitude_rad,
+                array_pos.latitude_rad,
+                phase_centre,
+                time,
+                obs_context.dut1.unwrap_or_default(),
+            );
+            let hadec = phase_centre.to_hadec(precession_info.lmst_j2000);
+            let precessed_xyzs = precession_info.precess_xyz(&obs_context.tile_xyzs);
+            setup_uvs(&mut tile_uvs_tmp, &precessed_xyzs, hadec);
+            setup_ws(&mut tile_ws_tmp, &precessed_xyzs, hadec);
+        } else {
+            let lmst = get_lmst(
+                array_pos.longitude_rad,
+                time,
+                obs_context.dut1.unwrap_or_default(),
+            );
+            let hadec = phase_centre.to_hadec(lmst);
+            setup_uvs(&mut tile_uvs_tmp, &obs_context.tile_xyzs, hadec);
+            setup_ws(&mut tile_ws_tmp, &obs_context.tile_xyzs, hadec);
+        }
+        let seconds = time.to_gpst_seconds() - start_seconds;
+        display_vis_fb(
+            name,
+            vis_fb.view(),
+            seconds,
+            tile_uvs_tmp.as_slice(),
+            tile_ws_tmp.as_slice(),
+            &lambdas_m,
+            &ant_pairs,
+            &obs_context.tile_names,
+        );
+    }
+}
+
+/// Populate the [UV] and [W] arrays ([times, tiles]) for the given
+/// [ObsContext], returning the LMSTs and, per timestamp, the tile XYZs that
+/// were used (precessed when `apply_precession` is set).
+///
+/// The hour angle is always taken toward `phase_centre`; when precessing,
+/// `precess_time` is evaluated with `obs_context.phase_centre`.
+///
+/// Panics if a row of `tile_uvs` or `tile_ws` is not contiguous.
+fn setup_tile_uv_w_arrays(
+    mut tile_uvs: ArrayViewMut2,
+    mut tile_ws: ArrayViewMut2,
+    obs_context: &ObsContext,
+    phase_centre: RADec,
+    apply_precession: bool,
+) -> (Vec, Array2) {
+    let mut lmsts = vec![0.0; obs_context.timestamps.len()];
+    let mut xyzs = Array2::default(tile_uvs.dim());
+
+    let array_pos = obs_context.array_position;
+    for (&time, mut tile_uvs, mut tile_ws, lmst, mut xyzs) in izip!(
+        obs_context.timestamps.iter(),
+        tile_uvs.outer_iter_mut(),
+        tile_ws.outer_iter_mut(),
+        lmsts.iter_mut(),
+        xyzs.outer_iter_mut()
+    ) {
+        if apply_precession {
+            let precession_info = precess_time(
+                array_pos.longitude_rad,
+                array_pos.latitude_rad,
+                obs_context.phase_centre,
+                time,
+                obs_context.dut1.unwrap_or_default(),
+            );
+            *lmst = precession_info.lmst_j2000;
+            let hadec = phase_centre.to_hadec(*lmst);
+            let precessed_xyzs = precession_info.precess_xyz(&obs_context.tile_xyzs);
+            setup_uvs(tile_uvs.as_slice_mut().unwrap(), &precessed_xyzs, hadec);
+            setup_ws(tile_ws.as_slice_mut().unwrap(), &precessed_xyzs, hadec);
+            xyzs.assign(&ArrayView1::from(&precessed_xyzs));
+        } else {
+            *lmst = get_lmst(
+                array_pos.longitude_rad,
+                time,
+                obs_context.dut1.unwrap_or_default(),
+            );
+            let hadec = phase_centre.to_hadec(*lmst);
+            setup_uvs(
+                tile_uvs.as_slice_mut().unwrap(),
+                &obs_context.tile_xyzs,
+                hadec,
+            );
+            setup_ws(
+                tile_ws.as_slice_mut().unwrap(),
+                &obs_context.tile_xyzs,
+                hadec,
+            );
+            xyzs.assign(&ArrayView1::from(&obs_context.tile_xyzs));
+        }
+    }
+
+    (lmsts, xyzs)
+}
+
+#[test]
+/// test [setup_uvs], [setup_ws]
+///
+/// Uses the simple 3-tile context and compares tile UVs/Ws computed toward
+/// the observation phase centre with those computed toward a source placed at
+/// hour angle 0 for the second timestamp. The symmetry checks between the two
+/// timesteps only run without precession; the v-component and origin-tile
+/// checks run in both modes.
+fn test_setup_uv() {
+    let obs_context = get_simple_obs_context();
+    let array_pos = obs_context.array_position;
+
+    let num_tiles = obs_context.get_total_num_tiles();
+    let num_times = obs_context.timestamps.len();
+
+    // source is at zenith at 1h
+    let hour_epoch = obs_context.timestamps[1];
+    let lst_1h_rad = get_lmst(
+        array_pos.longitude_rad,
+        hour_epoch,
+        obs_context.dut1.unwrap_or_default(),
+    );
+    let source_radec =
+        RADec::from_hadec(HADec::from_radians(0., array_pos.latitude_rad), lst_1h_rad);
+
+    // tile uvs and ws in the observation phase centre
+    let mut tile_uvs_obs = Array2::default((num_times, num_tiles));
+    let mut tile_ws_obs = Array2::default((num_times, num_tiles));
+    // tile uvs and ws in the source phase centre
+    let mut tile_uvs_src = Array2::default((num_times, num_tiles));
+    let mut tile_ws_src = Array2::default((num_times, num_tiles));
+
+    for apply_precession in [false, true] {
+        // The returned LMSTs and XYZs are not needed here.
+        setup_tile_uv_w_arrays(
+            tile_uvs_obs.view_mut(),
+            tile_ws_obs.view_mut(),
+            &obs_context,
+            obs_context.phase_centre,
+            apply_precession,
+        );
+        setup_tile_uv_w_arrays(
+            tile_uvs_src.view_mut(),
+            tile_ws_src.view_mut(),
+            &obs_context,
+            source_radec,
+            apply_precession,
+        );
+
+        if !apply_precession {
+            for a in 0..num_tiles {
+                // uvws for the phase centre at first timestep should be the same as
+                // uvws for the source at the second timestep
+                assert_abs_diff_eq!(tile_uvs_obs[[0, a]], tile_uvs_src[[1, a]], epsilon = 1e-6);
+                assert_abs_diff_eq!(tile_ws_obs[[0, a]], tile_ws_src[[1, a]], epsilon = 1e-6);
+                // uvws for the phase centre at the second timestep should be the same as
+                // uvws for the source at the first timestep, rotated in the opposite direction.
+                // since all the baselines sit flat on the uv plane, only the w component is negative.
+                assert_abs_diff_eq!(tile_uvs_obs[[1, a]], tile_uvs_src[[0, a]], epsilon = 1e-6);
+                assert_abs_diff_eq!(tile_ws_obs[[1, a]], -tile_ws_src[[0, a]], epsilon = 1e-6);
+            }
+            for t in 0..num_times {
+                // tile 2 is a special case with only a v component, so should be unchanged
+                assert_abs_diff_eq!(tile_uvs_obs[[t, 2]].v, 1., epsilon = 1e-6);
+                assert_abs_diff_eq!(tile_uvs_obs[[t, 2]], tile_uvs_src[[t, 2]], epsilon = 1e-6);
+                assert_abs_diff_eq!(tile_ws_obs[[t, 2]], tile_ws_src[[t, 2]], epsilon = 1e-6);
+            }
+            // tile 1 is aligned with zenith at t=0
+            assert_abs_diff_eq!(tile_uvs_obs[[0, 1]].u, 1., epsilon = 1e-6);
+            // tile 1 is aligned with source at t=1
+            assert_abs_diff_eq!(tile_uvs_src[[1, 1]].u, 1., epsilon = 1e-6);
+        }
+        for t in 0..num_times {
+            for a in 0..num_tiles {
+                // println!(
+                //     "prec={:5} t={} a={} obs=({:+1.6} {:+1.6} {:+1.6}), src=({:+1.6} {:+1.6} {:+1.6})",
+                //     apply_precession, t, a,
+                //     tile_uvs_obs[[t, a]].u, tile_uvs_obs[[t, a]].v, tile_ws_obs[[t, a]].0,
+                //     tile_uvs_src[[t, a]].u, tile_uvs_src[[t, a]].v, tile_ws_src[[t, a]].0,
+                // );
+                // no difference between the two phase centres for v component
+                assert_abs_diff_eq!(
+                    tile_uvs_obs[[t, a]].v,
+                    tile_uvs_src[[t, a]].v,
+                    epsilon = 1e-6
+                );
+            }
+            // tile 0 is the origin tile
+            assert_abs_diff_eq!(tile_uvs_obs[[t, 0]].u, 0., epsilon = 1e-6);
+            assert_abs_diff_eq!(tile_uvs_obs[[t, 0]].v, 0., epsilon = 1e-6);
+            assert_abs_diff_eq!(tile_ws_obs[[t, 0]].0, 0., epsilon = 1e-6);
+            assert_abs_diff_eq!(tile_uvs_src[[t, 0]].u, 0., epsilon = 1e-6);
+            assert_abs_diff_eq!(tile_uvs_src[[t, 0]].v, 0., epsilon = 1e-6);
+            assert_abs_diff_eq!(tile_ws_src[[t, 0]].0, 0., epsilon = 1e-6);
+        }
+    }
+}
+
+#[test]
+/// tests vis_rotate_fb by asserting that:
+/// - rotated visibilities have the source at the phase centre
+/// simulate vis, where at the first timestep, the phase centre is at zenith
+/// and a t the second timestep, 1h later, the source is at zenithiono rotated model
+fn test_vis_rotation() {
+    let 
obs_context = get_simple_obs_context(); + let array_pos = obs_context.array_position; + + let num_tiles = obs_context.get_total_num_tiles(); + let num_times = obs_context.timestamps.len(); + let num_baselines = (num_tiles * (num_tiles - 1)) / 2; + let ant_pairs = (0..num_baselines) + .map(|bl_idx| cross_correlation_baseline_to_tiles(num_tiles, bl_idx)) + .collect_vec(); + let flagged_tiles = HashSet::new(); + let num_chans = obs_context.fine_chan_freqs.len(); + + let fine_chan_freqs_hz = obs_context + .fine_chan_freqs + .iter() + .map(|&f| f as f64) + .collect_vec(); + let lambdas_m = fine_chan_freqs_hz.iter().map(|&f| VEL_C / f).collect_vec(); + + // source is at zenith at 1h + let hour_epoch = obs_context.timestamps[1]; + let lst_1h_rad = get_lmst( + array_pos.longitude_rad, + hour_epoch, + obs_context.dut1.unwrap_or_default(), + ); + let source_radec = + RADec::from_hadec(HADec::from_radians(0., array_pos.latitude_rad), lst_1h_rad); + let source_fd = 1.; + let source_list = SourceList::from(indexmap! 
{ + "One".into() => point_src_i!(source_radec, 0., fine_chan_freqs_hz[0], source_fd), + }); + + let beam = get_beam(num_tiles); + + let mut vis_tfb = Array3::default((num_times, num_chans, num_baselines)); + let mut vis_rot_tfb = Array3::default((num_times, num_chans, num_baselines)); + + // tile uvs and ws in the observation phase centre + let mut tile_uvs_obs = Array2::default((num_times, num_tiles)); + let mut tile_ws_obs = Array2::default((num_times, num_tiles)); + // tile uvs and ws in the source phase centre + let mut tile_uvs_src = Array2::default((num_times, num_tiles)); + let mut tile_ws_src = Array2::default((num_times, num_tiles)); + + for apply_precession in [false, true] { + let modeller = SkyModellerCpu::new( + &beam, + &source_list, + Polarisations::default(), + &obs_context.tile_xyzs, + &fine_chan_freqs_hz, + &flagged_tiles, + obs_context.phase_centre, + array_pos.longitude_rad, + array_pos.latitude_rad, + obs_context.dut1.unwrap_or_default(), + apply_precession, + ); + + vis_tfb.fill(Jones::zero()); + model_timesteps(&modeller, &obs_context.timestamps, vis_tfb.view_mut()).unwrap(); + + setup_tile_uv_w_arrays( + tile_uvs_obs.view_mut(), + tile_ws_obs.view_mut(), + &obs_context, + obs_context.phase_centre, + apply_precession, + ); + setup_tile_uv_w_arrays( + tile_uvs_src.view_mut(), + tile_ws_src.view_mut(), + &obs_context, + source_radec, + apply_precession, + ); + + // iterate over time, rotating visibilities + for (vis_fb, mut vis_rot_fb, tile_ws_obs, tile_ws_src) in izip!( + vis_tfb.outer_iter(), + vis_rot_tfb.view_mut().outer_iter_mut(), + tile_ws_obs.outer_iter(), + tile_ws_src.outer_iter(), + ) { + vis_rotate_fb( + vis_fb.view(), + vis_rot_fb.view_mut(), + tile_ws_obs.as_slice().unwrap(), + tile_ws_src.as_slice().unwrap(), + &lambdas_m, + ); + } + + // display_vis_tfb( + // &"model@obs".into(), + // vis_tfb.view(), + // &obs_context, + // obs_context.phase_centre, + // apply_precession, + // ); + // display_vis_tfb( + // 
&"rotated@source".into(), + // vis_rot_tfb.view(), + // &obs_context, + // source_radec, + // apply_precession, + // ); + + if !apply_precession { + // rotated vis should always have the source in phase, so no angle in pols XX, YY + for vis_rot in vis_rot_tfb.iter() { + assert_abs_diff_eq!(vis_rot[0].arg(), 0., epsilon = 1e-6); // XX + assert_abs_diff_eq!(vis_rot[3].arg(), 0., epsilon = 1e-6); // YY + } + // baseline 1, from origin to v has no u or w component, should not be affected by the rotation + for (vis, vis_rot) in izip!( + vis_tfb.slice(s![.., .., 1]), + vis_rot_tfb.slice(s![.., .., 1]), + ) { + assert_abs_diff_eq!(vis, vis_rot, epsilon = 1e-6); + } + // in the second timestep, the source should be at the pointing centre, so should not be + // attenuated by the beam + for (vis, vis_rot) in izip!( + vis_tfb.slice(s![1, .., ..]), + vis_rot_tfb.slice(s![1, .., ..]), + ) { + // XX + assert_abs_diff_eq!(vis[0].norm(), source_fd as f32, epsilon = 1e-6); + assert_abs_diff_eq!(vis_rot[0].norm(), source_fd as f32, epsilon = 1e-6); + // YY + assert_abs_diff_eq!(vis[3].norm(), source_fd as f32, epsilon = 1e-6); + assert_abs_diff_eq!(vis_rot[3].norm(), source_fd as f32, epsilon = 1e-6); + } + } + + for (tile_ws_obs, tile_ws_src, vis_fb, vis_rot_fb) in izip!( + tile_ws_obs.outer_iter(), + tile_ws_src.outer_iter(), + vis_tfb.outer_iter(), + vis_rot_tfb.outer_iter(), + ) { + for (lambda_m, vis_b, vis_rot_b) in izip!( + lambdas_m.iter(), + vis_fb.outer_iter(), + vis_rot_fb.outer_iter(), + ) { + for (&(ant1, ant2), vis, vis_rot) in + izip!(ant_pairs.iter(), vis_b.iter(), vis_rot_b.iter(),) + { + let w_obs = tile_ws_obs[ant1] - tile_ws_obs[ant2]; + let w_src = tile_ws_src[ant1] - tile_ws_src[ant2]; + let arg = TAU * (w_src - w_obs) / lambda_m; + for (pol_model, pol_model_rot) in vis.iter().zip_eq(vis_rot.iter()) { + // magnitudes shoud not be affected by rotation + assert_abs_diff_eq!(pol_model.norm(), pol_model_rot.norm(), epsilon = 1e-6); + let pol_model_rot_expected = 
Complex::from_polar( + pol_model.norm(), + (pol_model.arg() as f64 - arg) as f32, + ); + assert_abs_diff_eq!( + pol_model_rot_expected.arg(), + pol_model_rot.arg(), + epsilon = 1e-6 + ); + } + } + } + } + } +} + +#[test] +// +fn test_weight_average() { + let weights_tfb: Array3 = array![ + [[1., 1., 1., -1.], [2., 2., 2., -2.], [4., 4., 4., -4.],], + [ + [8., 8., 8., -8.], + [16., 16., 16., -16.], + [32., 32., 32., -32.], + ], + ]; + + // 2, 3, 4 + let weights_shape = weights_tfb.dim(); + // 1, 2, 4 + let avg_shape = (1, 2, weights_shape.2); + + let mut weights_avg_tfb = Array3::zeros(avg_shape); + + weights_average(weights_tfb.view(), weights_avg_tfb.view_mut()); + + assert_eq!( + weights_avg_tfb, + array![[[27., 27., 27., 0.], [36., 36., 36., 0.],],] + ); +} + +#[test] +// +fn test_vis_average() { + #[rustfmt::skip] + let vis_tfb: Array3> = array![ + [ + [ Jones::zero(), Jones::zero(), Jones::zero(), Jones::identity() ], + [ Jones::zero(), Jones::zero(), Jones::identity(), Jones::zero() ], + [ Jones::zero(), Jones::zero(), Jones::identity(), Jones::identity() ], + ], + [ + [ Jones::zero(), Jones::identity(), Jones::zero(), Jones::zero() ], + [ Jones::zero(), Jones::identity(), Jones::zero(), Jones::identity() ], + [ Jones::zero(), Jones::identity(), Jones::identity(), Jones::zero() ], + ], + ]; + + #[rustfmt::skip] + let weights_tfb: Array3 = array![ + [ + [1., 1., 1., 1.], + [2., 2., 2., 2.], + [4., 4., 4., 4.], + ], + [ + [8., 8., 8., 8.], + [16., 16., 16., 16.], + [32., 32., 32., 32.], + ], + ]; + + // 2, 3, 4 + let vis_shape = vis_tfb.dim(); + // 1, 2, 4 + let avg_shape = (1, 2, vis_shape.2); + + let mut vis_avg_tfb = Array3::zeros(avg_shape); + + vis_average2(vis_tfb.view(), vis_avg_tfb.view_mut(), weights_tfb.view()); + + assert_eq!( + vis_avg_tfb.slice(s![.., .., 0]), + array![[Jones::zero(), Jones::zero()]] + ); + + assert_eq!( + vis_avg_tfb.slice(s![.., 0, 1]), + array![Jones::identity() * 24. / 27.] 
+ ); + assert_eq!( + vis_avg_tfb.slice(s![.., 0, 2]), + array![Jones::identity() * 2. / 27.] + ); + assert_eq!( + vis_avg_tfb.slice(s![.., 0, 3]), + array![Jones::identity() * 17. / 27.] + ); + assert_eq!( + vis_avg_tfb.slice(s![.., 1, 1]), + array![Jones::identity() * 32. / 36.] + ); + assert_eq!( + vis_avg_tfb.slice(s![.., 1, 2]), + array![Jones::identity() * 36. / 36.] + ); + assert_eq!( + vis_avg_tfb.slice(s![.., 1, 3]), + array![Jones::identity() * 4. / 36.] + ); +} + +#[test] +// +fn test_apply_iono2() { + let obs_context = get_simple_obs_context(); + let array_pos = obs_context.array_position; + + // second timestep is at 1h + let hour_epoch = obs_context.timestamps[1]; + let num_tiles = obs_context.get_total_num_tiles(); + let num_times = obs_context.timestamps.len(); + let num_baselines = (num_tiles * (num_tiles - 1)) / 2; + let ant_pairs = (0..num_baselines) + .map(|bl_idx| cross_correlation_baseline_to_tiles(num_tiles, bl_idx)) + .collect_vec(); + let flagged_tiles = HashSet::new(); + let num_chans = obs_context.fine_chan_freqs.len(); + + let fine_chan_freqs_hz = obs_context + .fine_chan_freqs + .iter() + .map(|&f| f as f64) + .collect_vec(); + let lambdas_m = fine_chan_freqs_hz.iter().map(|&f| VEL_C / f).collect_vec(); + + // source is at zenith at 1h + let lst_1h_rad = get_lmst( + array_pos.longitude_rad, + hour_epoch, + obs_context.dut1.unwrap_or_default(), + ); + let source_radec = + RADec::from_hadec(HADec::from_radians(0., array_pos.latitude_rad), lst_1h_rad); + let source_fd = 1.; + let source_list = SourceList::from(indexmap! 
{ + "One".into() => point_src_i!(source_radec, 0., fine_chan_freqs_hz[0], source_fd), + }); + + let mut vis_tfb = Array3::default((num_times, num_chans, num_baselines)); + let mut vis_iono_tfb = Array3::default((num_times, num_chans, num_baselines)); + + // tile uvs and ws in the source phase centre + let mut tile_uvs_src = Array2::default((num_times, num_tiles)); + let mut tile_ws_src = Array2::default((num_times, num_tiles)); + + let beam = get_beam(num_tiles); + + for apply_precession in [false, true] { + let modeller = SkyModellerCpu::new( + &beam, + &source_list, + Polarisations::default(), + &obs_context.tile_xyzs, + &fine_chan_freqs_hz, + &flagged_tiles, + obs_context.phase_centre, + array_pos.longitude_rad, + array_pos.latitude_rad, + obs_context.dut1.unwrap_or_default(), + apply_precession, + ); + + vis_tfb.fill(Jones::zero()); + model_timesteps(&modeller, &obs_context.timestamps, vis_tfb.view_mut()).unwrap(); + + setup_tile_uv_w_arrays( + tile_uvs_src.view_mut(), + tile_ws_src.view_mut(), + &obs_context, + source_radec, + apply_precession, + ); + + // we want consts such that at lambda = 2m, the shift moves the source to the phase centre + let iono_lmn = source_radec.to_lmn(obs_context.phase_centre); + let consts_lm = (iono_lmn.l / 4., iono_lmn.m / 4.); + // let consts_lm = ((lst_1h_rad-lst_zenith_rad)/4., 0.); + + apply_iono2( + vis_tfb.view(), + vis_iono_tfb.view_mut(), + tile_uvs_src.view(), + consts_lm, + &lambdas_m, + ); + + // display_vis_tfb( + // &"model@obs".into(), + // vis_tfb.view(), + // &obs_context, + // obs_context.phase_centre, + // apply_precession, + // ); + // display_vis_tfb( + // &"iono@source".into(), + // vis_iono_tfb.view(), + // &obs_context, + // source_radec, + // apply_precession, + // ); + + for (time, tile_uvs_src, vis_fb, vis_iono_fb) in izip!( + obs_context.timestamps.iter(), + tile_uvs_src.outer_iter(), + vis_tfb.outer_iter(), + vis_iono_tfb.outer_iter(), + ) { + if !apply_precession { + // baseline 1, from origin to v 
has no u or w component, should not be affected by the rotation + for (vis, vis_iono) in izip!(vis_fb.slice(s![.., 1]), vis_iono_fb.slice(s![.., 1]),) + { + assert_abs_diff_eq!(vis, vis_iono, epsilon = 1e-6); + } + // in the second timestep, the source should be at the pointing centre, so: + // - should not be attenuated by the beam + // - at lambda=2, iono vis should have the source at the phase centre, so no angle in pols XX, YY + if time == &hour_epoch { + for (jones, jones_iono) in izip!(vis_fb.iter(), vis_iono_fb.iter(),) { + // XX + assert_abs_diff_eq!(jones[0].norm(), source_fd as f32, epsilon = 1e-6); + assert_abs_diff_eq!(jones_iono[0].norm(), source_fd as f32, epsilon = 1e-6); + // YY + assert_abs_diff_eq!(jones[3].norm(), source_fd as f32, epsilon = 1e-6); + assert_abs_diff_eq!(jones_iono[3].norm(), source_fd as f32, epsilon = 1e-6); + } + for vis_iono in vis_iono_fb.slice(s![0, ..]).iter() { + assert_abs_diff_eq!(vis_iono[0].arg(), 0., epsilon = 1e-6); // XX + assert_abs_diff_eq!(vis_iono[3].arg(), 0., epsilon = 1e-6); + // YY + } + } + } + + for (&lambda_m, vis_b, vis_iono_b) in izip!( + lambdas_m.iter(), + vis_fb.outer_iter(), + vis_iono_fb.outer_iter(), + ) { + for (&(ant1, ant2), vis, vis_iono) in + izip!(ant_pairs.iter(), vis_b.iter(), vis_iono_b.iter(),) + { + let UV { u, v } = tile_uvs_src[ant1] - tile_uvs_src[ant2]; + let arg = TAU * (u * consts_lm.0 + v * consts_lm.1) * lambda_m; + for (pol_model, pol_model_iono) in vis.iter().zip_eq(vis_iono.iter()) { + // magnitudes shoud not be affected by iono rotation + assert_abs_diff_eq!( + pol_model.norm(), + pol_model_iono.norm(), + epsilon = 1e-6 + ); + let pol_model_iono_expected = Complex::from_polar( + pol_model.norm(), + (pol_model.arg() as f64 - arg) as f32, + ); + assert_abs_diff_eq!( + pol_model_iono_expected.arg(), + pol_model_iono.arg(), + epsilon = 1e-6 + ); + } + } + } + } + } +} + +#[test] +/// test iono_fit, where residual is just iono rotated model +fn test_iono_fit() { + let obs_context 
= get_simple_obs_context(); + let array_pos = obs_context.array_position; + + // second timestep is at 1h + let hour_epoch = obs_context.timestamps[1]; + + let num_tiles = obs_context.get_total_num_tiles(); + let num_times = obs_context.timestamps.len(); + let num_baselines = (num_tiles * (num_tiles - 1)) / 2; + let flagged_tiles = HashSet::new(); + let num_chans = obs_context.fine_chan_freqs.len(); + + // lambda = 1m + let fine_chan_freqs_hz = obs_context + .fine_chan_freqs + .iter() + .map(|&f| f as f64) + .collect_vec(); + let lambdas_m = fine_chan_freqs_hz.iter().map(|&f| VEL_C / f).collect_vec(); + + // source is at zenith at 1h + let lst_1h_rad = get_lmst( + array_pos.longitude_rad, + hour_epoch, + obs_context.dut1.unwrap_or_default(), + ); + let source_radec = + RADec::from_hadec(HADec::from_radians(0., array_pos.latitude_rad), lst_1h_rad); + let source_fd = 1.; + let source_list = SourceList::from(indexmap! { + "One".into() => point_src_i!(source_radec, 0., fine_chan_freqs_hz[0], source_fd), + }); + + let mut vis_tfb = Array3::>::zeros((num_times, num_chans, num_baselines)); + let mut vis_iono_tfb = Array3::>::zeros((num_times, num_chans, num_baselines)); + + // tile uvs and ws in the source phase centre + let mut tile_uvs_src = Array2::default((num_times, num_tiles)); + let mut tile_ws_src = Array2::default((num_times, num_tiles)); + + let beam = get_beam(num_tiles); + + for apply_precession in [false, true] { + // unlike the other tests, this is in the SOURCE phase centre + let modeller = SkyModellerCpu::new( + &beam, + &source_list, + Polarisations::default(), + &obs_context.tile_xyzs, + &fine_chan_freqs_hz, + &flagged_tiles, + source_radec, + array_pos.longitude_rad, + array_pos.latitude_rad, + obs_context.dut1.unwrap_or_default(), + apply_precession, + ); + + vis_tfb.fill(Jones::zero()); + + model_timesteps(&modeller, &obs_context.timestamps, vis_tfb.view_mut()).unwrap(); + + setup_tile_uv_w_arrays( + tile_uvs_src.view_mut(), + tile_ws_src.view_mut(), 
+ &obs_context, + source_radec, + apply_precession, + ); + + // display_vis_tfb( + // &"model@obs".into(), + // vis_tfb.view(), + // &obs_context, + // obs_context.phase_centre, + // apply_precession, + // ); + + let shape = vis_tfb.shape(); + let weights = Array3::ones((shape[0], shape[1], shape[2])); + + for consts_lm in [ + (0., 0.), + (0.0001, -0.0003), + (0.0003, -0.0001), + (-0.0007, 0.0001), + ] { + apply_iono2( + vis_tfb.view(), + vis_iono_tfb.view_mut(), + tile_uvs_src.view(), + consts_lm, + &lambdas_m, + ); + + // display_vis_tfb( + // &"iono@obs".into(), + // vis_iono_tfb.view(), + // &obs_context, + // obs_context.phase_centre, + // apply_precession, + // ); + + let results = iono_fit( + vis_iono_tfb.view(), + weights.view(), + vis_tfb.view(), + &lambdas_m, + tile_uvs_src.view(), + ); + + // println!("prec: {:?}, expected: {:?}, got: {:?}", apply_precession, consts_lm, &results); + + assert_abs_diff_eq!(results[0], consts_lm.0, epsilon = 1e-8); + assert_abs_diff_eq!(results[1], consts_lm.1, epsilon = 1e-8); + } + } +} + +#[test] +/// - synthesize model visibilities +/// - apply ionospheric rotation +/// - create residual: ionospheric - model +/// - ap ply_iono3 should result in empty visibilitiesiono rotated model +fn test_apply_iono3() { + let obs_context = get_simple_obs_context(); + let array_pos = obs_context.array_position; + + let num_tiles = obs_context.get_total_num_tiles(); + let num_times = obs_context.timestamps.len(); + let num_baselines = (num_tiles * (num_tiles - 1)) / 2; + let flagged_tiles = HashSet::new(); + let num_chans = obs_context.fine_chan_freqs.len(); + + // lambda = 1m + let fine_chan_freqs_hz = obs_context + .fine_chan_freqs + .iter() + .map(|&f| f as f64) + .collect_vec(); + let lambdas_m = fine_chan_freqs_hz.iter().map(|&f| VEL_C / f).collect_vec(); + + // source is at zenith at 1h + let lst_1h_rad = get_lmst( + array_pos.longitude_rad, + obs_context.timestamps[1], + obs_context.dut1.unwrap_or_default(), + ); + let 
source_radec = + RADec::from_hadec(HADec::from_radians(0., array_pos.latitude_rad), lst_1h_rad); + let source_fd = 1.; + let source_list = SourceList::from(indexmap! { + "One".into() => point_src_i!(source_radec, 0., fine_chan_freqs_hz[0], source_fd), + }); + + let beam = get_beam(num_tiles); + + // residual visibilities in the observation phase centre + let mut vis_resid_obs_tfb = Array3::>::zeros((num_times, num_chans, num_baselines)); + // model visibilities in the observation phase centre + let mut vis_model_obs_tfb = Array3::>::zeros((num_times, num_chans, num_baselines)); + // iono rotated model visibilities in the observation phase centre + let mut vis_iono_obs_tfb = Array3::>::zeros((num_times, num_chans, num_baselines)); + + // tile uvs and ws in the source phase centre + let mut tile_uvs_src = Array2::default((num_times, num_tiles)); + let mut tile_ws_src = Array2::default((num_times, num_tiles)); + + for apply_precession in [false, true] { + let modeller = SkyModellerCpu::new( + &beam, + &source_list, + Polarisations::default(), + &obs_context.tile_xyzs, + &fine_chan_freqs_hz, + &flagged_tiles, + obs_context.phase_centre, + array_pos.longitude_rad, + array_pos.latitude_rad, + obs_context.dut1.unwrap_or_default(), + apply_precession, + ); + + vis_model_obs_tfb.fill(Jones::zero()); + + model_timesteps( + &modeller, + &obs_context.timestamps, + vis_model_obs_tfb.view_mut(), + ) + .unwrap(); + + setup_tile_uv_w_arrays( + tile_uvs_src.view_mut(), + tile_ws_src.view_mut(), + &obs_context, + source_radec, + apply_precession, + ); + + // display_vis_tfb( + // &"model@obs".into(), + // vis_model_obs_tfb.view(), + // &obs_context, + // obs_context.phase_centre, + // apply_precession, + // ); + + for consts_lm in [(0.0001, -0.0003), (0.0003, -0.0001), (-0.0007, 0.0001)] { + // apply iono rotation at source phase to model at observation phase + apply_iono2( + vis_model_obs_tfb.view(), + vis_iono_obs_tfb.view_mut(), + tile_uvs_src.view(), + consts_lm, + &lambdas_m, + 
); + + // subtract model from iono at observation phase centre + vis_resid_obs_tfb.assign(&vis_iono_obs_tfb); + vis_resid_obs_tfb -= &vis_model_obs_tfb; + + apply_iono3( + vis_model_obs_tfb.view(), + vis_resid_obs_tfb.view_mut(), + tile_uvs_src.view(), + consts_lm, + (0.0, 0.0), + &lambdas_m, + ); + + // display_vis_tfb( + // &"residual@obs".into(), + // vis_residual_obs_tfb.view(), + // &obs_context, + // obs_context.phase_centre, + // apply_precession, + // ); + + for jones_residual in vis_resid_obs_tfb.iter() { + for pol_residual in jones_residual.iter() { + assert_abs_diff_eq!(pol_residual.norm(), 0., epsilon = 1e-6); + } + } + } + } +} + +#[derive(Clone, Copy)] +#[allow(clippy::upper_case_acronyms)] +enum PeelType { + CPU, + + #[cfg(any(feature = "cuda", feature = "hip"))] + Gpu, +} + +/// Test a peel function with and without precession on a single source +#[track_caller] +fn test_peel_single_source(peel_type: PeelType) { + // enable trace + // let mut builder = env_logger::Builder::from_default_env(); + // builder.target(env_logger::Target::Stdout); + // builder.format_target(false); + // builder.filter_level(log::LevelFilter::Trace); + // builder.init(); + + // modify obs_context so that timesteps are closer together + let mut obs_context = get_simple_obs_context(); + let hour_epoch = obs_context.timestamps[1]; + let time_res = Duration::from_seconds(1.0); + let second_epoch = obs_context.timestamps[0] + time_res; + obs_context.time_res = Some(time_res); + obs_context.timestamps[1] = second_epoch; + + let array_pos = obs_context.array_position; + let num_tiles = obs_context.get_total_num_tiles(); + let num_times = obs_context.timestamps.len(); + let num_baselines = (num_tiles * (num_tiles - 1)) / 2; + let flagged_tiles = HashSet::new(); + let num_chans = obs_context.fine_chan_freqs.len(); + + // lambda = 1m + let fine_chan_freqs_hz = obs_context + .fine_chan_freqs + .iter() + .map(|&f| f as f64) + .collect_vec(); + let lambdas_m = 
fine_chan_freqs_hz.iter().map(|&f| VEL_C / f).collect_vec(); + + // source is at zenith at 1h (before precession) + let lst_1h_rad = get_lmst( + array_pos.longitude_rad, + hour_epoch, + obs_context.dut1.unwrap_or_default(), + ); + let source_radec = + RADec::from_hadec(HADec::from_radians(0., array_pos.latitude_rad), lst_1h_rad); + let source_fd = 1.; + let source_list = SourceList::from(indexmap! { + "One".into() => point_src_i!(source_radec, 0., fine_chan_freqs_hz[0], source_fd), + }); + + let beam = get_beam(num_tiles); + + // model visibilities in the observation phase centre + let mut vis_model_obs_tfb = Array3::zeros((num_times, num_chans, num_baselines)); + // iono rotated model visibilities in the observation phase centre + let mut vis_iono_obs_tfb = Array3::zeros((num_times, num_chans, num_baselines)); + // residual visibilities in the observation phase centre + let mut vis_residual_obs_tfb = Array3::zeros((num_times, num_chans, num_baselines)); + + // tile uvs and ws in the source phase centre + let mut tile_uvs_src = Array2::default((num_times, num_tiles)); + let mut tile_ws_src = Array2::default((num_times, num_tiles)); + + let timeblock = Timeblock { + index: 0, + range: 0..2, + timestamps: Vec1::try_from_vec( + obs_context + .timestamps + .iter() + .enumerate() + .map(|(i, e)| (*e, i)) + .collect(), + ) + .unwrap(), + median: obs_context.timestamps[0], + }; + + let vis_shape = vis_residual_obs_tfb.dim(); + let vis_weights = Array3::::ones(vis_shape); + let source_weighted_positions = [source_radec]; + + let multi_progress = MultiProgress::with_draw_target(ProgressDrawTarget::hidden()); + + for apply_precession in [false, true] { + let mut high_res_modeller = new_sky_modeller( + &beam, + &source_list, + Polarisations::default(), + &obs_context.tile_xyzs, + &fine_chan_freqs_hz, + &flagged_tiles, + obs_context.phase_centre, + array_pos.longitude_rad, + array_pos.latitude_rad, + obs_context.dut1.unwrap_or_default(), + apply_precession, + ) + .unwrap(); + 
+ let mut low_res_modeller = new_sky_modeller( + &beam, + &source_list, + Polarisations::default(), + &obs_context.tile_xyzs, + &fine_chan_freqs_hz, + &flagged_tiles, + obs_context.phase_centre, + array_pos.longitude_rad, + array_pos.latitude_rad, + obs_context.dut1.unwrap_or_default(), + apply_precession, + ) + .unwrap(); + + vis_model_obs_tfb.fill(Jones::zero()); + + model_timesteps( + &*high_res_modeller, + &obs_context.timestamps, + vis_model_obs_tfb.view_mut(), + ) + .unwrap(); + + setup_tile_uv_w_arrays( + tile_uvs_src.view_mut(), + tile_ws_src.view_mut(), + &obs_context, + source_radec, + apply_precession, + ); + + // display_vis_tfb( + // &"model@obs".into(), + // vis_model_obs_tfb.view(), + // &obs_context, + // obs_context.phase_centre, + // apply_precession, + // ); + + for consts_lm in [ + (0., 0.), + (0.0001, -0.0003), + (0.0003, -0.0001), + (-0.0007, 0.0001), + ] { + log::info!("Testing with iono consts {consts_lm:?}"); + apply_iono2( + vis_model_obs_tfb.view(), + vis_iono_obs_tfb.view_mut(), + tile_uvs_src.view(), + consts_lm, + &lambdas_m, + ); + + // display_vis_tfb( + // &format!("iono@obs prec={}, ({}, {})", apply_precession, &consts_lm.0, &consts_lm.1), + // vis_iono_obs_tfb.view(), + // &obs_context, + // obs_context.phase_centre, + // apply_precession, + // ); + + // subtract model from iono at observation phase centre + vis_residual_obs_tfb.assign(&vis_iono_obs_tfb); + vis_residual_obs_tfb -= &vis_model_obs_tfb; + + // display_vis_tfb( + // &"residual@obs".into(), + // vis_residual_obs_tfb.view(), + // &obs_context, + // obs_context.phase_centre, + // apply_precession, + // ); + + let mut iono_consts = vec![(0., 0.); 1]; + + // When peel_cpu and peel_gpu are able to take generic + // `SkyModeller` objects (requires the generic objects to take + // currently CUDA-only methods), uncomment the following code and + // delete what follows. 
+ + // let function = match peel_type { + // CPU => peel_cpu, + // #[cfg(any(feature = "cuda", feature = "hip"))] + // CUDA => peel_gpu, + // }; + // function( + // vis_residual_obs_tfb.view_mut(), + // vis_weights.view(), + // &timeblock, + // &source_list, + // &mut iono_consts, + // &source_weighted_positions, + // num_sources_to_iono_subtract, + // &fine_chan_freqs_hz, + // &lambdas_m, + // &lambdas_m, + // &obs_context, + // obs_context.array_position.unwrap(), + // &obs_context.tile_xyzs, + // low_res_modeller.deref_mut(), + // high_res_modeller.deref_mut(), + // obs_context.dut1.unwrap_or_default(), + // !apply_precession, + // &multi_progress, + // ) + // .unwrap(); + match peel_type { + PeelType::CPU => peel_cpu( + vis_residual_obs_tfb.view_mut(), + vis_weights.view(), + &timeblock, + &source_list, + &mut iono_consts, + &source_weighted_positions, + 3, + &fine_chan_freqs_hz, + &lambdas_m, + &lambdas_m, + &obs_context, + obs_context.array_position, + &obs_context.tile_xyzs, + &mut *low_res_modeller, + &mut *high_res_modeller, + obs_context.dut1.unwrap_or_default(), + !apply_precession, + &multi_progress, + ) + .unwrap(), + + #[cfg(any(feature = "cuda", feature = "hip"))] + PeelType::Gpu => { + let mut high_res_modeller = crate::model::SkyModellerGpu::new( + &beam, + &source_list, + Polarisations::default(), + &obs_context.tile_xyzs, + &fine_chan_freqs_hz, + &flagged_tiles, + obs_context.phase_centre, + array_pos.longitude_rad, + array_pos.latitude_rad, + obs_context.dut1.unwrap_or_default(), + apply_precession, + ) + .unwrap(); + + let mut low_res_modeller = crate::model::SkyModellerGpu::new( + &beam, + &source_list, + Polarisations::default(), + &obs_context.tile_xyzs, + &fine_chan_freqs_hz, + &flagged_tiles, + obs_context.phase_centre, + array_pos.longitude_rad, + array_pos.latitude_rad, + obs_context.dut1.unwrap_or_default(), + apply_precession, + ) + .unwrap(); + + peel_gpu( + vis_residual_obs_tfb.view_mut(), + vis_weights.view(), + &timeblock, + 
&source_list, + &mut iono_consts, + &source_weighted_positions, + 3, + &fine_chan_freqs_hz, + &lambdas_m, + &lambdas_m, + &obs_context, + obs_context.array_position, + &obs_context.tile_xyzs, + &mut low_res_modeller, + &mut high_res_modeller, + obs_context.dut1.unwrap_or_default(), + !apply_precession, + &multi_progress, + ) + .unwrap() + } + }; + + println!("prec: {apply_precession:?}, expected: {consts_lm:?}, got: {iono_consts:?}"); + + display_vis_tfb( + &"peeled@obs".into(), + vis_residual_obs_tfb.view(), + &obs_context, + obs_context.phase_centre, + apply_precession, + ); + + assert_abs_diff_eq!(iono_consts[0].0, consts_lm.0, epsilon = 7e-10); + assert_abs_diff_eq!(iono_consts[0].1, consts_lm.1, epsilon = 7e-10); + + // peel should perfectly remove the iono rotate model vis + for jones_residual in vis_residual_obs_tfb.iter() { + for pol_residual in jones_residual.iter() { + #[cfg(not(feature = "gpu-single"))] + let eps = 1.3e-8; + #[cfg(feature = "gpu-single")] + let eps = 1.7e-8; + assert_abs_diff_eq!(pol_residual.norm(), 0., epsilon = eps); + } + } + } + } +} + +#[track_caller] +fn test_peel_multi_source(peel_type: PeelType) { + // // enable trace + // let mut builder = env_logger::Builder::from_default_env(); + // builder.target(env_logger::Target::Stdout); + // builder.format_target(false); + // builder.filter_level(log::LevelFilter::Trace); + // builder.init(); + + // modify obs_context so that timesteps are closer together + let mut obs_context = get_complex_obs_context(); + let time_res = Duration::from_seconds(1.0); + let second_epoch = obs_context.timestamps[0] + time_res; + obs_context.time_res = Some(time_res); + obs_context.timestamps[1] = second_epoch; + + let array_pos = obs_context.array_position; + let num_tiles = obs_context.get_total_num_tiles(); + let num_times = obs_context.timestamps.len(); + let num_baselines = (num_tiles * (num_tiles - 1)) / 2; + let flagged_tiles = HashSet::new(); + let num_chans = obs_context.fine_chan_freqs.len(); + + 
// lambda = 1m + let fine_chan_freqs_hz = obs_context + .fine_chan_freqs + .iter() + .map(|&f| f as f64) + .collect_vec(); + let lambdas_m = fine_chan_freqs_hz.iter().map(|&f| VEL_C / f).collect_vec(); + + let lst_0h_rad = get_lmst( + array_pos.longitude_rad, + obs_context.timestamps[0], + obs_context.dut1.unwrap_or_default(), + ); + let source_midpoint = + RADec::from_hadec(HADec::from_radians(0., array_pos.latitude_rad), lst_0h_rad); + + let source_list = SourceList::from(indexmap! { + "Four".into() => point_src_i!(RADec {ra: source_midpoint.ra + 0.05, dec: source_midpoint.dec + 0.05}, 0., fine_chan_freqs_hz[0], 4.), + "Three".into() => point_src_i!(RADec {ra: source_midpoint.ra + 0.03, dec: source_midpoint.dec - 0.03}, 0., fine_chan_freqs_hz[0], 3.), + // "Two".into() => point_src_i!(RADec {ra: source_midpoint.ra - 0.01, dec: source_midpoint.dec + 0.02}, 0., fine_chan_freqs_hz[0], 2.), + // "One".into() => point_src_i!(RADec {ra: source_midpoint.ra - 0.02, dec: source_midpoint.dec - 0.01}, 0., fine_chan_freqs_hz[0], 1.), + }); + + let source_weighted_positions = source_list + .iter() + .map(|(_, source)| source.components[0].radec) + .collect_vec(); + + let iono_consts = [ + (-0.00002, -0.00001), + (0.00001, -0.00003), + (0.0003, -0.0001), + (-0.0007, 0.0001), + ]; + + let beam = get_beam(num_tiles); + + // model visibilities of each source + let mut vis_model_tmp_tfb = Array3::>::zeros((num_times, num_chans, num_baselines)); + // iono rotated visibilities of each source + let mut vis_iono_tmp_tfb = Array3::>::zeros((num_times, num_chans, num_baselines)); + // residual visibilities in the observation phase centre + let mut vis_residual_obs_tfb = + Array3::>::zeros((num_times, num_chans, num_baselines)); + // tile uvs and ws in the source phase centre + let mut tile_uvs_src = Array2::default((num_times, num_tiles)); + let mut tile_ws_src = Array2::default((num_times, num_tiles)); + + let vis_weights = Array3::::ones(vis_residual_obs_tfb.dim()); + + let timeblock 
= Timeblock { + index: 0, + range: 0..2, + timestamps: Vec1::try_from_vec( + obs_context + .timestamps + .iter() + .enumerate() + .map(|(i, e)| (*e, i)) + .collect(), + ) + .unwrap(), + median: obs_context.timestamps[0], + }; + + let num_sources_to_iono_subtract = source_list.len(); + + let multi_progress = MultiProgress::with_draw_target(ProgressDrawTarget::hidden()); + + for apply_precession in [false, true] { + let mut high_res_modeller = new_sky_modeller( + &beam, + &source_list, + Polarisations::default(), + &obs_context.tile_xyzs, + &fine_chan_freqs_hz, + &flagged_tiles, + obs_context.phase_centre, + array_pos.longitude_rad, + array_pos.latitude_rad, + obs_context.dut1.unwrap_or_default(), + apply_precession, + ) + .unwrap(); + + let mut low_res_modeller = new_sky_modeller( + &beam, + &source_list, + Polarisations::default(), + &obs_context.tile_xyzs, + &fine_chan_freqs_hz, + &flagged_tiles, + obs_context.phase_centre, + array_pos.longitude_rad, + array_pos.latitude_rad, + obs_context.dut1.unwrap_or_default(), + apply_precession, + ) + .unwrap(); + + vis_residual_obs_tfb.fill(Jones::zero()); + + // model each source in source_list and rotate by iono_consts with apply_iono2 + for (&consts_lm, (name, source)) in izip!(iono_consts.iter(), source_list.iter(),) { + let source_radec = source.components[0].radec; + println!("source {} radec {:?}", name, &source_radec); + + high_res_modeller + .update_with_a_source(source, obs_context.phase_centre) + .unwrap(); + + // model visibilities in the observation phase centre + vis_model_tmp_tfb.fill(Jones::zero()); + model_timesteps( + &*high_res_modeller, + &obs_context.timestamps, + vis_model_tmp_tfb.view_mut(), + ) + .unwrap(); + + setup_tile_uv_w_arrays( + tile_uvs_src.view_mut(), + tile_ws_src.view_mut(), + &obs_context, + source_radec, + apply_precession, + ); + + apply_iono2( + vis_model_tmp_tfb.view(), + vis_iono_tmp_tfb.view_mut(), + tile_uvs_src.view(), + consts_lm, + &lambdas_m, + ); + + display_vis_tfb( + 
&format!("iono@src={} consts={:?}", name, &consts_lm), + vis_iono_tmp_tfb.view(), + &obs_context, + source_radec, + apply_precession, + ); + + // add iono rotated and subtract model visibilities from residual + Zip::from(vis_residual_obs_tfb.view_mut()) + .and(vis_iono_tmp_tfb.view()) + .and(vis_model_tmp_tfb.view()) + .for_each(|res, iono, model| *res += *iono - *model); + } + + let mut iono_consts_result = vec![(0., 0.); num_sources_to_iono_subtract]; + + // When peel_cpu and peel_gpu are able to take generic + // `SkyModeller` objects (requires the generic objects to take + // currently CUDA-only methods), uncomment the following code and + // delete what follows. + + // let function = match peel_type { + // CPU => peel_cpu, + // #[cfg(any(feature = "cuda", feature = "hip"))] + // CUDA => peel_gpu, + // }; + // function( + // vis_residual_obs_tfb.view_mut(), + // vis_weights.view(), + // &timeblock, + // &source_list, + // &mut iono_consts, + // &source_weighted_positions, + // num_sources_to_iono_subtract, + // &fine_chan_freqs_hz, + // &lambdas_m, + // &lambdas_m, + // &obs_context, + // obs_context.array_position.unwrap(), + // &obs_context.tile_xyzs, + // low_res_modeller.deref_mut(), + // high_res_modeller.deref_mut(), + // obs_context.dut1.unwrap_or_default(), + // !apply_precession, + // &multi_progress, + // ) + // .unwrap(); + match peel_type { + PeelType::CPU => peel_cpu( + vis_residual_obs_tfb.view_mut(), + vis_weights.view(), + &timeblock, + &source_list, + &mut iono_consts_result, + &source_weighted_positions, + 3, + &fine_chan_freqs_hz, + &lambdas_m, + &lambdas_m, + &obs_context, + obs_context.array_position, + &obs_context.tile_xyzs, + &mut *low_res_modeller, + &mut *high_res_modeller, + obs_context.dut1.unwrap_or_default(), + !apply_precession, + &multi_progress, + ) + .unwrap(), + + #[cfg(any(feature = "cuda", feature = "hip"))] + PeelType::Gpu => { + let mut high_res_modeller = crate::model::SkyModellerGpu::new( + &beam, + &source_list, + 
Polarisations::default(), + &obs_context.tile_xyzs, + &fine_chan_freqs_hz, + &flagged_tiles, + obs_context.phase_centre, + array_pos.longitude_rad, + array_pos.latitude_rad, + obs_context.dut1.unwrap_or_default(), + apply_precession, + ) + .unwrap(); + + let mut low_res_modeller = crate::model::SkyModellerGpu::new( + &beam, + &source_list, + Polarisations::default(), + &obs_context.tile_xyzs, + &fine_chan_freqs_hz, + &flagged_tiles, + obs_context.phase_centre, + array_pos.longitude_rad, + array_pos.latitude_rad, + obs_context.dut1.unwrap_or_default(), + apply_precession, + ) + .unwrap(); + + peel_gpu( + vis_residual_obs_tfb.view_mut(), + vis_weights.view(), + &timeblock, + &source_list, + &mut iono_consts_result, + &source_weighted_positions, + 3, + &fine_chan_freqs_hz, + &lambdas_m, + &lambdas_m, + &obs_context, + obs_context.array_position, + &obs_context.tile_xyzs, + &mut low_res_modeller, + &mut high_res_modeller, + obs_context.dut1.unwrap_or_default(), + !apply_precession, + &multi_progress, + ) + .unwrap() + } + } + + display_vis_tfb( + &"peeled@obs".into(), + vis_residual_obs_tfb.view(), + &obs_context, + obs_context.phase_centre, + apply_precession, + ); + + for (expected, result) in izip!(iono_consts.iter(), iono_consts_result.iter(),) { + println!("prec: {apply_precession:?}, expected: {expected:?}, got: {result:?}"); + assert_abs_diff_eq!(expected.0, result.0, epsilon = 3e-7); + assert_abs_diff_eq!(expected.1, result.1, epsilon = 3e-7); + } + + // peel should perfectly remove the iono rotate model vis + for jones_residual in vis_residual_obs_tfb.iter() { + for pol_residual in jones_residual.iter() { + #[cfg(not(feature = "gpu-single"))] + let eps = 3e-7; + #[cfg(feature = "gpu-single")] + let eps = 5e-7; + assert_abs_diff_eq!(pol_residual.norm(), 0., epsilon = eps); + } + } + } +} + +#[test] +fn test_peel_cpu_single_source() { + // let mut builder = env_logger::Builder::from_default_env(); + // builder.target(env_logger::Target::Stdout); + // 
builder.format_target(false); + // builder.filter_level(log::LevelFilter::Trace); + // builder.init(); + test_peel_single_source(PeelType::CPU) +} + +#[test] +fn test_peel_cpu_multi_source() { + test_peel_multi_source(PeelType::CPU) +} + +#[cfg(any(feature = "cuda", feature = "hip"))] +mod gpu_tests { + use std::ffi::CStr; + + use marlu::{pos::xyz::xyzs_to_cross_uvws, UVW}; + + use super::*; + use crate::{ + gpu::{self, DevicePointer, GpuFloat}, + model::SkyModellerGpu, + }; + + /// Populate the [UVW] array ([times, baselines]) for the given [ObsContext]. + fn setup_uvw_array( + mut uvws: ArrayViewMut2, + obs_context: &ObsContext, + phase_centre: RADec, + apply_precession: bool, + ) { + let array_pos = obs_context.array_position; + let num_tiles = obs_context.get_total_num_tiles(); + let mut tile_uvws_tmp = vec![UVW::default(); num_tiles]; + // let num_cross_baselines = (num_tiles * (num_tiles - 1)) / 2; + for (&time, mut uvws) in izip!(obs_context.timestamps.iter(), uvws.outer_iter_mut(),) { + let (lmst, precessed_xyzs) = if apply_precession { + let precession_info = precess_time( + array_pos.longitude_rad, + array_pos.latitude_rad, + obs_context.phase_centre, + time, + obs_context.dut1.unwrap_or_default(), + ); + let precessed_xyzs = precession_info.precess_xyz(&obs_context.tile_xyzs); + (precession_info.lmst_j2000, precessed_xyzs) + } else { + let lmst = get_lmst( + array_pos.longitude_rad, + time, + obs_context.dut1.unwrap_or_default(), + ); + (lmst, obs_context.tile_xyzs.clone().into()) + }; + let hadec = phase_centre.to_hadec(lmst); + let (s_ha, c_ha) = hadec.ha.sin_cos(); + let (s_dec, c_dec) = hadec.dec.sin_cos(); + for (tile_uvw, &precessed_xyz) in + izip!(tile_uvws_tmp.iter_mut(), precessed_xyzs.iter(),) + { + *tile_uvw = UVW::from_xyz_inner(precessed_xyz, s_ha, c_ha, s_dec, c_dec); + } + let mut count = 0; + for (i, t1) in tile_uvws_tmp.iter().enumerate() { + for t2 in tile_uvws_tmp.iter().skip(i + 1) { + uvws[count] = *t1 - *t2; + count += 1; + } + } + 
} + } + + #[test] + /// - synthesize model visibilities + /// - apply ionospheric rotation + /// - create residual: ionospheric - model + /// - `gpu::subtract_iono` with the same constants should result in empty visibilities + fn test_gpu_subtract_iono() { + let obs_context = get_simple_obs_context(); + let array_pos = obs_context.array_position; + + let num_tiles = obs_context.get_total_num_tiles(); + let num_times = obs_context.timestamps.len(); + let num_baselines = (num_tiles * (num_tiles - 1)) / 2; + let flagged_tiles = HashSet::new(); + let num_chans = obs_context.fine_chan_freqs.len(); + + // lambda = 1m + let fine_chan_freqs_hz = obs_context + .fine_chan_freqs + .iter() + .map(|&f| f as f64) + .collect_vec(); + let lambdas_m = fine_chan_freqs_hz.iter().map(|&f| VEL_C / f).collect_vec(); + + // source is at zenith at 1h + let lst_1h_rad = get_lmst( + array_pos.longitude_rad, + obs_context.timestamps[1], + obs_context.dut1.unwrap_or_default(), + ); + let source_radec = + RADec::from_hadec(HADec::from_radians(0., array_pos.latitude_rad), lst_1h_rad); + let source_fd = 1.; + let source_list = SourceList::from(indexmap! 
{ + "One".into() => point_src_i!(source_radec, 0., fine_chan_freqs_hz[0], source_fd), + }); + + let beam = get_beam(num_tiles); + + // residual visibilities in the observation phase centre + let mut vis_residual_obs_tfb = + Array3::>::zeros((num_times, num_chans, num_baselines)); + // model visibilities in the observation phase centre + let mut vis_model_obs_tfb = + Array3::>::zeros((num_times, num_chans, num_baselines)); + // iono rotated model visibilities in the observation phase centre + let mut vis_iono_obs_tfb = + Array3::>::zeros((num_times, num_chans, num_baselines)); + // tile uvs and ws in the source phase centre + let mut tile_uvws_src = Array2::default((num_times, num_tiles)); + let mut tile_uvs_src = Array2::default((num_times, num_tiles)); + let mut tile_ws_src = Array2::default((num_times, num_tiles)); + + for apply_precession in [false, true] { + let modeller = SkyModellerGpu::new( + &beam, + &source_list, + Polarisations::default(), + &obs_context.tile_xyzs, + &fine_chan_freqs_hz, + &flagged_tiles, + obs_context.phase_centre, + array_pos.longitude_rad, + array_pos.latitude_rad, + obs_context.dut1.unwrap_or_default(), + apply_precession, + ) + .unwrap(); + + model_timesteps( + &modeller, + &obs_context.timestamps, + vis_model_obs_tfb.view_mut(), + ) + .unwrap(); + + setup_uvw_array( + tile_uvws_src.view_mut(), + &obs_context, + source_radec, + apply_precession, + ); + setup_tile_uv_w_arrays( + tile_uvs_src.view_mut(), + tile_ws_src.view_mut(), + &obs_context, + source_radec, + apply_precession, + ); + + // display_vis_tfb( + // &"model@obs".into(), + // vis_model_obs_tfb.view(), + // &obs_context, + // obs_context.phase_centre, + // apply_precession, + // ); + + let d_high_res_model = + DevicePointer::copy_to_device(vis_model_obs_tfb.as_slice().unwrap()).unwrap(); + + let gpu_uvws_src = tile_uvws_src.mapv(|uvw| gpu::UVW { + u: uvw.u as GpuFloat, + v: uvw.v as GpuFloat, + w: uvw.w as GpuFloat, + }); + + let d_uvws_src = + 
DevicePointer::copy_to_device(gpu_uvws_src.as_slice().unwrap()).unwrap(); + let d_lambdas = DevicePointer::copy_to_device( + &lambdas_m.iter().map(|l| *l as GpuFloat).collect::>(), + ) + .unwrap(); + + for consts_lm in [(0.0001, -0.0003), (0.0003, -0.0001), (-0.0007, 0.0001)] { + // apply iono rotation at source phase to model at observation phase + apply_iono2( + vis_model_obs_tfb.view(), + vis_iono_obs_tfb.view_mut(), + tile_uvs_src.view(), + consts_lm, + &lambdas_m, + ); + + // display_vis_tfb( + // &format!("iono@obs ({}, {})", &consts_lm.0, &consts_lm.1), + // vis_iono_obs_tfb.view(), + // &obs_context, + // obs_context.phase_centre, + // apply_precession, + // ); + + // subtract model from iono at observation phase centre + vis_residual_obs_tfb.assign(&vis_iono_obs_tfb); + vis_residual_obs_tfb -= &vis_model_obs_tfb; + + // display_vis_tfb( + // &"residual@obs before".into(), + // vis_residual_obs_tfb.view(), + // &obs_context, + // obs_context.phase_centre, + // apply_precession, + // ); + + let mut d_high_res_vis = + DevicePointer::copy_to_device(vis_residual_obs_tfb.as_slice().unwrap()) + .unwrap(); + + let error_message_ptr = unsafe { + gpu::subtract_iono( + d_high_res_vis.get_mut().cast(), + d_high_res_model.get().cast(), + consts_lm.0, + consts_lm.1, + 0.0, + 0.0, + d_uvws_src.get(), + d_lambdas.get(), + num_times.try_into().unwrap(), + num_baselines.try_into().unwrap(), + num_chans.try_into().unwrap(), + ) + }; + assert!( + error_message_ptr.is_null(), + "{}", + unsafe { CStr::from_ptr(error_message_ptr) } + .to_str() + .unwrap_or("") + ); + + d_high_res_vis + .copy_from_device(vis_residual_obs_tfb.as_slice_mut().unwrap()) + .unwrap(); + + // display_vis_tfb( + // &"residual@obs after".into(), + // vis_residual_obs_tfb.view(), + // &obs_context, + // obs_context.phase_centre, + // apply_precession, + // ); + + for jones_residual in vis_residual_obs_tfb.iter() { + for pol_residual in jones_residual.iter() { + assert_abs_diff_eq!(pol_residual.norm(), 0., 
epsilon = 1e-6); + } + } + } + } + } + + #[test] + fn test_rotate_average() { + let obs_context = get_simple_obs_context(); + let array_pos = obs_context.array_position; + + let num_tiles = obs_context.get_total_num_tiles(); + let num_times = obs_context.timestamps.len(); + let num_baselines = (num_tiles * (num_tiles - 1)) / 2; + let flagged_tiles = HashSet::new(); + let num_chans = obs_context.fine_chan_freqs.len(); + + let fine_chan_freqs_hz = obs_context + .fine_chan_freqs + .iter() + .map(|&f| f as f64) + .collect_vec(); + let lambdas_m = fine_chan_freqs_hz.iter().map(|&f| VEL_C / f).collect_vec(); + + // source is at zenith at 1h + let hour_epoch = obs_context.timestamps[1]; + let lst_1h_rad = get_lmst( + array_pos.longitude_rad, + hour_epoch, + obs_context.dut1.unwrap_or_default(), + ); + let source_radec = + RADec::from_hadec(HADec::from_radians(0., array_pos.latitude_rad), lst_1h_rad); + let source_fd = 1.; + let source_list = SourceList::from(indexmap! { + "One".into() => point_src_i!(source_radec, 0., fine_chan_freqs_hz[0], source_fd), + }); + + let beam = get_beam(num_tiles); + + let mut vis_tfb = Array3::default((num_times, num_chans, num_baselines)); + let mut vis_rot_tfb = Array3::default((num_times, num_chans, num_baselines)); + + // tile uvs and ws in the observation phase centre + let mut tile_uvs_obs = Array2::default((num_times, num_tiles)); + let mut tile_ws_obs = Array2::default((num_times, num_tiles)); + // tile uvs and ws in the source phase centre + let mut tile_uvs_src = Array2::default((num_times, num_tiles)); + let mut tile_ws_src = Array2::default((num_times, num_tiles)); + + let weights_tfb = Array3::from_elem(vis_tfb.dim(), 2.0); + + for apply_precession in [false, true] { + let modeller = SkyModellerCpu::new( + &beam, + &source_list, + Polarisations::default(), + &obs_context.tile_xyzs, + &fine_chan_freqs_hz, + &flagged_tiles, + obs_context.phase_centre, + array_pos.longitude_rad, + array_pos.latitude_rad, + 
obs_context.dut1.unwrap_or_default(), + apply_precession, + ); + + vis_tfb.fill(Jones::zero()); + model_timesteps(&modeller, &obs_context.timestamps, vis_tfb.view_mut()).unwrap(); + + let (lmsts, xyzs) = setup_tile_uv_w_arrays( + tile_uvs_obs.view_mut(), + tile_ws_obs.view_mut(), + &obs_context, + obs_context.phase_centre, + apply_precession, + ); + setup_tile_uv_w_arrays( + tile_uvs_src.view_mut(), + tile_ws_src.view_mut(), + &obs_context, + source_radec, + apply_precession, + ); + + // iterate over time, rotating visibilities + for (vis_fb, mut vis_rot_fb, tile_ws_obs, tile_ws_src) in izip!( + vis_tfb.outer_iter(), + vis_rot_tfb.view_mut().outer_iter_mut(), + tile_ws_obs.outer_iter(), + tile_ws_src.outer_iter(), + ) { + vis_rotate_fb( + vis_fb.view(), + vis_rot_fb.view_mut(), + tile_ws_obs.as_slice().unwrap(), + tile_ws_src.as_slice().unwrap(), + &lambdas_m, + ); + } + + let mut vis_averaged_tfb = Array3::default((1, 1, num_baselines)); + vis_average2( + vis_rot_tfb.view(), + vis_averaged_tfb.view_mut(), + weights_tfb.view(), + ); + + // display_vis_tfb( + // &"model@obs".into(), + // vis_tfb.view(), + // &obs_context, + // obs_context.phase_centre, + // apply_precession, + // ); + // display_vis_tfb( + // &"rotated@source".into(), + // vis_rot_tfb.view(), + // &obs_context, + // source_radec, + // apply_precession, + // ); + + // if !apply_precession { + // // rotated vis should always have the source in phase, so no angle in pols XX, YY + // for vis_rot in vis_rot_tfb.iter() { + // assert_abs_diff_eq!(vis_rot[0].arg(), 0., epsilon = 1e-6); // XX + // assert_abs_diff_eq!(vis_rot[3].arg(), 0., epsilon = 1e-6); // YY + // } + // // baseline 1, from origin to v has no u or w component, should not be affected by the rotation + // for (vis, vis_rot) in izip!( + // vis_tfb.slice(s![.., .., 1]), + // vis_rot_tfb.slice(s![.., .., 1]), + // ) { + // assert_abs_diff_eq!(vis, vis_rot, epsilon = 1e-6); + // } + // // in the second timestep, the source should be at the 
pointing centre, so should not be + // // attenuated by the beam + // for (vis, vis_rot) in izip!( + // vis_tfb.slice(s![1, .., ..]), + // vis_rot_tfb.slice(s![1, .., ..]), + // ) { + // // XX + // assert_abs_diff_eq!(vis[0].norm(), source_fd as f32, epsilon = 1e-6); + // assert_abs_diff_eq!(vis_rot[0].norm(), source_fd as f32, epsilon = 1e-6); + // // YY + // assert_abs_diff_eq!(vis[3].norm(), source_fd as f32, epsilon = 1e-6); + // assert_abs_diff_eq!(vis_rot[3].norm(), source_fd as f32, epsilon = 1e-6); + // } + // } + + let (time_axis, freq_axis, _baseline_axis) = (Axis(0), Axis(1), Axis(2)); + let gpu_xyzs: Vec<_> = xyzs + .iter() + .copied() + .map(|XyzGeodetic { x, y, z }| gpu::XYZ { + x: x as GpuFloat, + y: y as GpuFloat, + z: z as GpuFloat, + }) + .collect(); + let mut gpu_uvws = Array2::from_elem( + (num_times, num_baselines), + gpu::UVW { + u: -99.0, + v: -99.0, + w: -99.0, + }, + ); + gpu_uvws + .outer_iter_mut() + .zip(xyzs.outer_iter()) + .zip(lmsts.iter()) + .for_each(|((mut gpu_uvws, xyzs), lmst)| { + let phase_centre = obs_context.phase_centre.to_hadec(*lmst); + let v = xyzs_to_cross_uvws(xyzs.as_slice().unwrap(), phase_centre) + .into_iter() + .map(|uvw| gpu::UVW { + u: uvw.u as GpuFloat, + v: uvw.v as GpuFloat, + w: uvw.w as GpuFloat, + }) + .collect::>(); + gpu_uvws.assign(&ArrayView1::from(&v)); + }); + let gpu_lambdas: Vec = lambdas_m.iter().map(|l| *l as GpuFloat).collect(); + + let mut result = vis_averaged_tfb.clone(); + result.fill(Jones::default()); + + let avg_freq = div_ceil( + vis_tfb.len_of(freq_axis), + vis_averaged_tfb.len_of(freq_axis), + ); + + let d_uvws_from = DevicePointer::copy_to_device(gpu_uvws.as_slice().unwrap()).unwrap(); + let mut d_uvws_to = + DevicePointer::malloc(gpu_uvws.len() * std::mem::size_of::()).unwrap(); + + unsafe { + let d_vis_tfb = DevicePointer::copy_to_device(vis_tfb.as_slice().unwrap()).unwrap(); + let d_weights_tfb = + DevicePointer::copy_to_device(weights_tfb.as_slice().unwrap()).unwrap(); + let mut 
d_vis_averaged_tfb = + DevicePointer::copy_to_device(result.as_slice().unwrap()).unwrap(); + let d_lmsts = DevicePointer::copy_to_device( + &lmsts + .iter() + .map(|lmst| *lmst as GpuFloat) + .collect::>(), + ) + .unwrap(); + let d_xyzs = DevicePointer::copy_to_device(&gpu_xyzs).unwrap(); + let d_lambdas = DevicePointer::copy_to_device(&gpu_lambdas).unwrap(); + + gpu::rotate_average( + d_vis_tfb.get().cast(), + d_weights_tfb.get().cast(), + d_vis_averaged_tfb.get_mut().cast(), + gpu::RADec { + ra: source_radec.ra as GpuFloat, + dec: source_radec.dec as GpuFloat, + }, + vis_tfb.len_of(time_axis).try_into().unwrap(), + num_tiles.try_into().unwrap(), + num_baselines.try_into().unwrap(), + vis_tfb.len_of(freq_axis).try_into().unwrap(), + avg_freq.try_into().unwrap(), + d_lmsts.get(), + d_xyzs.get(), + d_uvws_from.get(), + d_uvws_to.get_mut(), + d_lambdas.get(), + ); + + d_vis_averaged_tfb + .copy_from_device(result.as_slice_mut().unwrap()) + .unwrap(); + } + + // Test that the CPU and GPU UVWs are the same. CPU UVWs are per + // tile, GPU UVWs are per baseline, so we just make the CPU UVWs + // ourselves. + let mut cpu_uvws = Array2::from_elem((num_times, num_baselines), UVW::default()); + cpu_uvws + .outer_iter_mut() + .zip(xyzs.outer_iter()) + .zip(lmsts.iter()) + .for_each(|((mut cpu_uvws, xyzs), lmst)| { + let phase_centre = obs_context.phase_centre.to_hadec(*lmst); + let v = xyzs_to_cross_uvws(xyzs.as_slice().unwrap(), phase_centre); + cpu_uvws.assign(&ArrayView1::from(&v)); + }); + let gpu_uvws = Array2::from_shape_vec( + (num_times, num_baselines), + d_uvws_from.copy_from_device_new().unwrap(), + ) + .unwrap() + .mapv(|gpu::UVW { u, v, w }| UVW { + // The GPU float precision might not be f64. + u: u as _, + v: v as _, + w: w as _, + }); + #[cfg(not(feature = "gpu-single"))] + let eps = 0.0; + #[cfg(feature = "gpu-single")] + let eps = 5e-8; + assert_abs_diff_eq!(cpu_uvws, gpu_uvws, epsilon = eps); + + // Hack to use `display_vis_tfb` with low-res visibilities. 
+ let mut low_res_obs_context = get_simple_obs_context(); + low_res_obs_context.timestamps = vec1![hour_epoch]; + low_res_obs_context.fine_chan_freqs = vec1![VEL_C as _]; + display_vis_tfb( + &"host".to_string(), + vis_averaged_tfb.view(), + &low_res_obs_context, + obs_context.phase_centre, + apply_precession, + ); + display_vis_tfb( + &"gpu vis_rotate_average".to_string(), + result.view(), + &low_res_obs_context, + obs_context.phase_centre, + apply_precession, + ); + + assert_abs_diff_eq!(vis_averaged_tfb, result); + + // for (tile_ws_obs, tile_ws_src, vis_fb, vis_rot_fb) in izip!( + // tile_ws_obs.outer_iter(), + // tile_ws_src.outer_iter(), + // vis_tfb.outer_iter(), + // vis_rot_tfb.outer_iter(), + // ) { + // for (lambda_m, vis_b, vis_rot_b) in izip!( + // lambdas_m.iter(), + // vis_fb.outer_iter(), + // vis_rot_fb.outer_iter(), + // ) { + // for (&(ant1, ant2), vis, vis_rot) in + // izip!(ant_pairs.iter(), vis_b.iter(), vis_rot_b.iter(),) + // { + // let w_obs = tile_ws_obs[ant1] - tile_ws_obs[ant2]; + // let w_src = tile_ws_src[ant1] - tile_ws_src[ant2]; + // let arg = (TAU * (w_src - w_obs) / lambda_m) as f64; + // for (pol_model, pol_model_rot) in vis.iter().zip_eq(vis_rot.iter()) { + // // magnitudes should not be affected by rotation + // assert_abs_diff_eq!( + // pol_model.norm(), + // pol_model_rot.norm(), + // epsilon = 1e-6 + // ); + // let pol_model_rot_expected = Complex::from_polar( + // pol_model.norm(), + // (pol_model.arg() as f64 - arg) as f32, + // ); + // assert_abs_diff_eq!( + // pol_model_rot_expected.arg(), + // pol_model_rot.arg(), + // epsilon = 1e-6 + // ); + // } + // } + // } + // } + } + } + + #[test] + fn test_peel_gpu_single_source() { + test_peel_single_source(PeelType::Gpu) + } + + #[test] + fn test_peel_gpu_multi_source() { + test_peel_multi_source(PeelType::Gpu) + } +} diff --git a/src/srclist/types/components/mod.rs b/src/srclist/types/components/mod.rs index de031d8f..4e7041c6 100644 --- a/src/srclist/types/components/mod.rs 
+++ b/src/srclist/types/components/mod.rs @@ -14,7 +14,7 @@ use ndarray::prelude::*; use rayon::prelude::*; use serde::{Deserialize, Deserializer, Serialize, Serializer}; -use super::{FluxDensity, FluxDensityType, SourceList}; +use super::{FluxDensity, FluxDensityType}; /// Information on a source's component. #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] @@ -173,15 +173,19 @@ pub(crate) struct ComponentList { } impl ComponentList { - /// Given a source list, split the components into each [ComponentType]. + /// Given sky-model source components, split them into each + /// [`ComponentType`]. /// /// These parameters don't change over time, so it's ideal to run this /// function once. - pub(crate) fn new( - source_list: &SourceList, + pub(crate) fn new<'a, I>( + components: I, unflagged_fine_chan_freqs: &[f64], phase_centre: RADec, - ) -> ComponentList { + ) -> ComponentList + where + I: IntoIterator, + { // Unpack each of the component parameters into vectors. let mut point_radecs = vec![]; let mut point_lmns = vec![]; @@ -198,6 +202,7 @@ impl ComponentList { let mut shapelet_gaussian_params = vec![]; let mut shapelet_coeffs: Vec> = vec![]; + // TODO: Reverse elsewhere // Reverse the source list; if the source list has been sorted // (brightest sources first), reversing makes the dimmest sources get // used first. This is good because floating-point precision errors are @@ -205,11 +210,7 @@ impl ComponentList { // float starting from the brightest component means that the // floating-point precision errors are greater as we work through the // source list. 
- for comp in source_list - .iter() - .rev() - .flat_map(|(_, src)| src.components.iter()) - { + for comp in components.into_iter() { let comp_lmn = comp.radec.to_lmn(phase_centre).prepare_for_rime(); match &comp.comp_type { ComponentType::Point => { diff --git a/src/srclist/types/components/tests.rs b/src/srclist/types/components/tests.rs index 8557021c..a1220256 100644 --- a/src/srclist/types/components/tests.rs +++ b/src/srclist/types/components/tests.rs @@ -59,7 +59,11 @@ fn test_split_components() { .count() }); - let split_components = ComponentList::new(&srclist, &freqs, phase_centre); + let split_components = ComponentList::new( + srclist.values().rev().flat_map(|src| src.components.iter()), + &freqs, + phase_centre, + ); let points = split_components.points; let gaussians = split_components.gaussians; let shapelets = split_components.shapelets; diff --git a/src/unit_parsing/mod.rs b/src/unit_parsing/mod.rs index 29ae0838..8cdd106d 100644 --- a/src/unit_parsing/mod.rs +++ b/src/unit_parsing/mod.rs @@ -77,6 +77,10 @@ pub(crate) enum FreqFormat { /// kiloHertz #[strum(serialize = "kHz")] kHz, + + /// MegaHertz + #[strum(serialize = "MHz")] + MHz, } /// Parse a string that may have a unit of frequency attached to it.