From 5781f9a62d54bee7dfd2926f9618a61e2d266848 Mon Sep 17 00:00:00 2001 From: "Christopher H. Jordan" Date: Mon, 4 Apr 2022 14:27:49 +0800 Subject: [PATCH] Move vis weighting outside calibration. Multiply visibilities by weights before calibration and divide by weights if writing visibilities out. This means weights are needed during calibration and the code is able to run about 25% faster. Note that there are relatively big (up to 1e-4) float-point errors introduced by doing this; writing visibilities out of di-calibration should therefore only be done as a "quick and dirty" convenience. When solutions-apply is introduced, that will be the preferred way to obtain calibrated visibilities. --- benches/bench.rs | 2 - src/calibrate/di/code/mod.rs | 151 ++++++++++++++++----------------- src/calibrate/di/code/tests.rs | 18 ++-- src/calibrate/di/mod.rs | 28 +++++- 4 files changed, 108 insertions(+), 91 deletions(-) diff --git a/benches/bench.rs b/benches/bench.rs index 2e562e9e..0eda858e 100644 --- a/benches/bench.rs +++ b/benches/bench.rs @@ -534,7 +534,6 @@ fn calibrate_benchmarks(c: &mut Criterion) { let vis_shape = (num_timesteps, num_baselines, num_chanblocks); let vis_data: Array3> = Array3::from_elem(vis_shape, Jones::identity() * 4.0); - let vis_weights: Array3 = Array3::ones(vis_shape); let vis_model: Array3> = Array3::from_elem(vis_shape, Jones::identity()); let baseline_weights = vec![1.0; num_baselines]; @@ -544,7 +543,6 @@ fn calibrate_benchmarks(c: &mut Criterion) { b.iter(|| { calibrate_timeblocks( vis_data.view(), - vis_weights.view(), vis_model.view(), &timeblocks, &chanblocks, diff --git a/src/calibrate/di/code/mod.rs b/src/calibrate/di/code/mod.rs index 93b7a621..8da0960c 100644 --- a/src/calibrate/di/code/mod.rs +++ b/src/calibrate/di/code/mod.rs @@ -281,6 +281,46 @@ pub(crate) fn get_cal_vis( ), } + debug!("Multiplying visibilities by weights"); + + // Multiply the visibilities by the weights (and baseline weights based on + // UVW cuts). If a weight is negative, it means the corresponding visibility + // should be flagged, so that visibility is set to 0; this means it does not + // affect calibration. Not iterating over weights during calibration makes + // makes calibration run significantly faster. + vis_data + .outer_iter_mut() + .into_par_iter() + .zip(vis_model.outer_iter_mut()) + .zip(vis_weights.outer_iter()) + .for_each(|((mut vis_data, mut vis_model), vis_weights)| { + vis_data + .outer_iter_mut() + .zip(vis_model.outer_iter_mut()) + .zip(vis_weights.outer_iter()) + .zip(params.baseline_weights.iter()) + .for_each( + |(((mut vis_data, mut vis_model), vis_weights), &baseline_weight)| { + vis_data + .iter_mut() + .zip(vis_model.iter_mut()) + .zip(vis_weights.iter()) + .for_each(|((vis_data, vis_model), &vis_weight)| { + let weight = f64::from(vis_weight) * baseline_weight; + if weight <= 0.0 { + *vis_data = Jones::default(); + *vis_model = Jones::default(); + } else { + *vis_data = + Jones::::from(Jones::::from(*vis_data) * weight); + *vis_model = + Jones::::from(Jones::::from(*vis_model) * weight); + } + }); + }, + ); + }); + info!("Finished reading input data and sky modelling"); Ok(CalVis { @@ -634,7 +674,6 @@ impl<'a> IncompleteSolutions<'a> { #[allow(clippy::too_many_arguments)] pub fn calibrate_timeblocks<'a>( vis_data: ArrayView3>, - vis_weights: ArrayView3, vis_model: ArrayView3>, timeblocks: &'a [Timeblock], chanblocks: &'a [Chanblock], @@ -645,21 +684,6 @@ pub fn calibrate_timeblocks<'a>( draw_progress_bar: bool, print_convergence_messages: bool, ) -> (IncompleteSolutions<'a>, Array2) { - // Multiply the baseline weights against the visibility weights. Then, only - // the visibility weights need to be multiplied against the data and model - // visibilities. - assert_eq!(vis_weights.len_of(Axis(1)), baseline_weights.len()); - let mut vis_weights = vis_weights.to_owned(); - vis_weights - .axis_iter_mut(Axis(1)) - .into_par_iter() - .zip(baseline_weights) - .for_each(|(mut vis_weights, &baseline_weight)| { - vis_weights.iter_mut().for_each(|vis_weight| { - *vis_weight = (*vis_weight as f64 * baseline_weight) as f32; - }) - }); - let num_unflagged_tiles = num_tiles_from_num_cross_correlation_baselines(vis_data.dim().1); let num_timeblocks = timeblocks.len(); let num_chanblocks = chanblocks.len(); @@ -675,7 +699,6 @@ pub fn calibrate_timeblocks<'a>( ); let cal_results = calibrate_timeblock( vis_data.view(), - vis_weights.view(), vis_model.view(), di_jones.view_mut(), timeblocks.first().unwrap(), @@ -727,7 +750,6 @@ pub fn calibrate_timeblocks<'a>( ); let cal_results = calibrate_timeblock( vis_data.view(), - vis_weights.view(), vis_model.view(), di_jones.view_mut(), &timeblock, @@ -763,7 +785,6 @@ pub fn calibrate_timeblocks<'a>( ); let mut cal_results = calibrate_timeblock( vis_data.view(), - vis_weights.view(), vis_model.view(), di_jones.view_mut(), timeblock, @@ -831,7 +852,6 @@ fn make_calibration_progress_bar( #[allow(clippy::too_many_arguments)] fn calibrate_timeblock( vis_data: ArrayView3>, - vis_weights: ArrayView3, vis_model: ArrayView3>, mut di_jones: ArrayViewMut3>, timeblock: &Timeblock, @@ -863,7 +883,6 @@ fn calibrate_timeblock( ]; let mut cal_result = calibrate( vis_data.slice(range), - vis_weights.slice(range), vis_model.slice(range), di_jones, max_iterations, @@ -990,7 +1009,6 @@ fn calibrate_timeblock( let range = s![timeblock.range.clone(), .., i_chanblock..i_chanblock + 1]; let mut new_cal_result = calibrate( vis_data.slice(range), - vis_weights.slice(range), vis_model.slice(range), di_jones, max_iterations, @@ -1068,7 +1086,6 @@ pub struct CalibrationResult { /// parallel code is inside this function. pub(super) fn calibrate( data: ArrayView3>, - weights: ArrayView3, model: ArrayView3>, mut di_jones: ArrayViewMut1>, max_iterations: usize, @@ -1096,14 +1113,7 @@ pub(super) fn calibrate( top.fill(Jones::default()); bot.fill(Jones::default()); - calibration_loop( - data, - weights, - model, - di_jones.view(), - top.view_mut(), - bot.view_mut(), - ); + calibration_loop(data, model, di_jones.view(), top.view_mut(), bot.view_mut()); // Obtain the new DI Jones matrices from "top" and "bot". // Tile/antenna axis. @@ -1239,7 +1249,6 @@ pub(super) fn calibrate( /// "MitchCal". fn calibration_loop( data: ArrayView3>, - weights: ArrayView3, model: ArrayView3>, di_jones: ArrayView1>, mut top: ArrayViewMut1>, @@ -1249,58 +1258,48 @@ fn calibration_loop( // Time axis. data.outer_iter() - .zip(weights.outer_iter()) .zip(model.outer_iter()) - .for_each(|((data, weights), model)| { + .for_each(|(data, model)| { // Unflagged baseline axis. data.outer_iter() - .zip(weights.outer_iter()) .zip(model.outer_iter()) .enumerate() - .for_each(|(i_baseline, ((data, weights), model))| { + .for_each(|(i_baseline, (data, model))| { let (tile1, tile2) = cross_correlation_baseline_to_tiles(num_tiles, i_baseline); // Unflagged frequency chan axis. - data.iter() - .zip(weights) - .zip(model) - // Don't do anything if the weight is flagged. - .filter(|((_, weight), _)| **weight > 0.0) - .for_each(|((j_data, weight), j_model)| { - // Copy and promote the data and model Jones - // matrices. - let weight = *weight as f64; - let j_data: Jones = Jones::from(j_data) * weight; - let j_model: Jones = Jones::from(j_model) * weight; - - // Suppress boundary checks for maximum performance! - unsafe { - let j_t1 = di_jones.uget(tile1); - let j_t2 = di_jones.uget(tile2); - - let top_t1 = top.uget_mut(tile1); - let bot_t1 = bot.uget_mut(tile1); - - // André's calibrate: ( D J M^H ) / ( M J^H J M^H ) - // J M^H - let z = *j_t2 * j_model.h(); - // D (J M^H) - *top_t1 += j_data * z; - // (J M^H)^H (J M^H) - *bot_t1 += z.h() * z; - - let top_t2 = top.uget_mut(tile2); - let bot_t2 = bot.uget_mut(tile2); - - // André's calibrate: ( D J M^H ) / ( M J^H J M^H ) - // J (M^H)^H - let z = *j_t1 * j_model; - // D^H (J M^H)^H - *top_t2 += j_data.h() * z; - // (J M^H) (J M^H) - *bot_t2 += z.h() * z; - } - }); + data.iter().zip(model).for_each(|(j_data, j_model)| { + let j_data = Jones::::from(j_data); + let j_model = Jones::::from(j_model); + + // Suppress boundary checks for maximum performance! + unsafe { + let j_t1 = di_jones.uget(tile1); + let j_t2 = di_jones.uget(tile2); + + let top_t1 = top.uget_mut(tile1); + let bot_t1 = bot.uget_mut(tile1); + + // André's calibrate: ( D J M^H ) / ( M J^H J M^H ) + // J M^H + let z = *j_t2 * j_model.h(); + // D (J M^H) + *top_t1 += j_data * z; + // (J M^H)^H (J M^H) + *bot_t1 += z.h() * z; + + let top_t2 = top.uget_mut(tile2); + let bot_t2 = bot.uget_mut(tile2); + + // André's calibrate: ( D J M^H ) / ( M J^H J M^H ) + // J (M^H)^H + let z = *j_t1 * j_model; + // D^H (J M^H)^H + *top_t2 += j_data.h() * z; + // (J M^H) (J M^H) + *bot_t2 += z.h() * z; + } + }); }) }); } diff --git a/src/calibrate/di/code/tests.rs b/src/calibrate/di/code/tests.rs index a8030a27..dd88fb49 100644 --- a/src/calibrate/di/code/tests.rs +++ b/src/calibrate/di/code/tests.rs @@ -33,7 +33,6 @@ fn test_calibrate_trivial() { let vis_shape = (num_timesteps, num_baselines, num_chanblocks); let vis_data: Array3> = Array3::from_elem(vis_shape, Jones::identity() * 4.0); - let vis_weights: Array3 = Array3::ones(vis_shape); let vis_model: Array3> = Array3::from_elem(vis_shape, Jones::identity()); let mut di_jones = Array3::from_elem( (num_timeblocks, num_tiles, num_chanblocks), @@ -56,11 +55,9 @@ fn test_calibrate_trivial() { chanblock_index..chanblock_index + 1 ]; let vis_data_slice = vis_data.slice(range); - let vis_weight_slice = vis_weights.slice(range); let vis_model_slice = vis_model.slice(range); let result = calibrate( vis_data_slice, - vis_weight_slice, vis_model_slice, di_jones_rev.view_mut(), 20, @@ -112,7 +109,7 @@ fn test_calibrate_trivial_with_flags() { let bad_vis = vis_data.get_mut((0, 0, 0)).unwrap(); *bad_vis = Jones::identity() * 9000.0; let mut vis_weights: Array3 = Array3::ones(vis_shape); - let vis_model: Array3> = Array3::from_elem(vis_shape, Jones::identity()); + let mut vis_model: Array3> = Array3::from_elem(vis_shape, Jones::identity()); let mut di_jones = Array3::from_elem( (num_timeblocks, num_tiles, num_chanblocks), Jones::::identity(), @@ -134,11 +131,9 @@ fn test_calibrate_trivial_with_flags() { chanblock_index..chanblock_index + 1 ]; let vis_data_slice = vis_data.slice(range); - let vis_weight_slice = vis_weights.slice(range); let vis_model_slice = vis_model.slice(range); let result = calibrate( vis_data_slice, - vis_weight_slice, vis_model_slice, di_jones_rev.view_mut(), 20, @@ -165,9 +160,15 @@ fn test_calibrate_trivial_with_flags() { } } - // Fix the weight and repeat. + // Fix the weight and repeat. We have to set the corresponding visibilities + // to 0 (this is normally done before the visibilities are returned via + // `CalVis`). let bad_weight = vis_weights.get_mut((0, 0, 0)).unwrap(); *bad_weight = -1.0; + let bad_data = vis_data.get_mut((0, 0, 0)).unwrap(); + *bad_data = Jones::default(); + let bad_model = vis_model.get_mut((0, 0, 0)).unwrap(); + *bad_model = Jones::default(); di_jones.fill(Jones::identity()); for timeblock in 0..num_timeblocks { let time_range_start = timeblock * timeblock_length; @@ -185,11 +186,9 @@ fn test_calibrate_trivial_with_flags() { chanblock_index..chanblock_index + 1 ]; let vis_data_slice = vis_data.slice(range); - let vis_weight_slice = vis_weights.slice(range); let vis_model_slice = vis_model.slice(range); let result = calibrate( vis_data_slice, - vis_weight_slice, vis_model_slice, di_jones_rev.view_mut(), 20, @@ -592,7 +591,6 @@ fn incomplete_to_complete_flags_complex() { pub(crate) fn test_1090008640_quality(params: CalibrateParams, cal_vis: CalVis) { let (_, cal_results) = calibrate_timeblocks( cal_vis.vis_data.view(), - cal_vis.vis_weights.view(), cal_vis.vis_model.view(), ¶ms.timeblocks, ¶ms.fences.first().chanblocks, diff --git a/src/calibrate/di/mod.rs b/src/calibrate/di/mod.rs index 8d440adb..4d4a183f 100644 --- a/src/calibrate/di/mod.rs +++ b/src/calibrate/di/mod.rs @@ -40,6 +40,7 @@ pub(crate) fn di_calibrate( vis_weights, vis_model, } = get_cal_vis(params, !params.no_progress_bars)?; + assert_eq!(vis_weights.len_of(Axis(1)), params.baseline_weights.len()); let obs_context = params.input_data.get_obs_context(); @@ -57,14 +58,13 @@ pub(crate) fn di_calibrate( shape.1, shape.2, shape.0 * shape.1 * shape.2 * std::mem::size_of::>() - // 1024 * 1024 == 1 MiB. - / 1024 / 1024 + // 1024 * 1024 == 1 MiB. + / 1024 / 1024 ); } let (sols, _) = calibrate_timeblocks( vis_data.view(), - vis_weights.view(), vis_model.view(), ¶ms.timeblocks, ¶ms.fences.first().chanblocks, @@ -140,6 +140,28 @@ pub(crate) fn di_calibrate( // Write out calibrated visibilities. if !params.output_vis_filenames.is_empty() { + debug!("Dividing visibilities by weights"); + // Divide the visibilities by the weights (undoing the multiplication earlier). + vis_data + .outer_iter_mut() + .into_par_iter() + .zip(vis_weights.outer_iter()) + .for_each(|(mut vis_data, vis_weights)| { + vis_data + .outer_iter_mut() + .zip(vis_weights.outer_iter()) + .zip(params.baseline_weights.iter()) + .for_each(|((mut vis_data, vis_weights), &baseline_weight)| { + vis_data.iter_mut().zip(vis_weights.iter()).for_each( + |(vis_data, &vis_weight)| { + let weight = f64::from(vis_weight) * baseline_weight; + *vis_data = + Jones::::from(Jones::::from(*vis_data) / weight); + }, + ); + }); + }); + info!("Writing visibilities..."); // TODO(dev): support and test autos