Skip to content

Commit

Permalink
Move vis weighting outside calibration.
Browse files Browse the repository at this point in the history
Multiply visibilities by weights before calibration and divide by
weights if writing visibilities out. This means weights are needed
during calibration and the code is able to run about 25% faster.

Note that there are relatively big (up to 1e-4) float-point errors
introduced by doing this; writing visibilities out of di-calibration
should therefore only be done as a "quick and dirty" convenience. When
solutions-apply is introduced, that will be the preferred way to obtain
calibrated visibilities.
  • Loading branch information
cjordan committed Apr 5, 2022
1 parent 16d004a commit 5781f9a
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 91 deletions.
2 changes: 0 additions & 2 deletions benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,6 @@ fn calibrate_benchmarks(c: &mut Criterion) {

let vis_shape = (num_timesteps, num_baselines, num_chanblocks);
let vis_data: Array3<Jones<f32>> = Array3::from_elem(vis_shape, Jones::identity() * 4.0);
let vis_weights: Array3<f32> = Array3::ones(vis_shape);
let vis_model: Array3<Jones<f32>> = Array3::from_elem(vis_shape, Jones::identity());
let baseline_weights = vec![1.0; num_baselines];

Expand All @@ -544,7 +543,6 @@ fn calibrate_benchmarks(c: &mut Criterion) {
b.iter(|| {
calibrate_timeblocks(
vis_data.view(),
vis_weights.view(),
vis_model.view(),
&timeblocks,
&chanblocks,
Expand Down
151 changes: 75 additions & 76 deletions src/calibrate/di/code/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,46 @@ pub(crate) fn get_cal_vis(
),
}

debug!("Multiplying visibilities by weights");

// Multiply the visibilities by the weights (and baseline weights based on
// UVW cuts). If a weight is negative, it means the corresponding visibility
// should be flagged, so that visibility is set to 0; this means it does not
// affect calibration. Not iterating over weights during calibration makes
// makes calibration run significantly faster.
vis_data
.outer_iter_mut()
.into_par_iter()
.zip(vis_model.outer_iter_mut())
.zip(vis_weights.outer_iter())
.for_each(|((mut vis_data, mut vis_model), vis_weights)| {
vis_data
.outer_iter_mut()
.zip(vis_model.outer_iter_mut())
.zip(vis_weights.outer_iter())
.zip(params.baseline_weights.iter())
.for_each(
|(((mut vis_data, mut vis_model), vis_weights), &baseline_weight)| {
vis_data
.iter_mut()
.zip(vis_model.iter_mut())
.zip(vis_weights.iter())
.for_each(|((vis_data, vis_model), &vis_weight)| {
let weight = f64::from(vis_weight) * baseline_weight;
if weight <= 0.0 {
*vis_data = Jones::default();
*vis_model = Jones::default();
} else {
*vis_data =
Jones::<f32>::from(Jones::<f64>::from(*vis_data) * weight);
*vis_model =
Jones::<f32>::from(Jones::<f64>::from(*vis_model) * weight);
}
});
},
);
});

info!("Finished reading input data and sky modelling");

Ok(CalVis {
Expand Down Expand Up @@ -634,7 +674,6 @@ impl<'a> IncompleteSolutions<'a> {
#[allow(clippy::too_many_arguments)]
pub fn calibrate_timeblocks<'a>(
vis_data: ArrayView3<Jones<f32>>,
vis_weights: ArrayView3<f32>,
vis_model: ArrayView3<Jones<f32>>,
timeblocks: &'a [Timeblock],
chanblocks: &'a [Chanblock],
Expand All @@ -645,21 +684,6 @@ pub fn calibrate_timeblocks<'a>(
draw_progress_bar: bool,
print_convergence_messages: bool,
) -> (IncompleteSolutions<'a>, Array2<CalibrationResult>) {
// Multiply the baseline weights against the visibility weights. Then, only
// the visibility weights need to be multiplied against the data and model
// visibilities.
assert_eq!(vis_weights.len_of(Axis(1)), baseline_weights.len());
let mut vis_weights = vis_weights.to_owned();
vis_weights
.axis_iter_mut(Axis(1))
.into_par_iter()
.zip(baseline_weights)
.for_each(|(mut vis_weights, &baseline_weight)| {
vis_weights.iter_mut().for_each(|vis_weight| {
*vis_weight = (*vis_weight as f64 * baseline_weight) as f32;
})
});

let num_unflagged_tiles = num_tiles_from_num_cross_correlation_baselines(vis_data.dim().1);
let num_timeblocks = timeblocks.len();
let num_chanblocks = chanblocks.len();
Expand All @@ -675,7 +699,6 @@ pub fn calibrate_timeblocks<'a>(
);
let cal_results = calibrate_timeblock(
vis_data.view(),
vis_weights.view(),
vis_model.view(),
di_jones.view_mut(),
timeblocks.first().unwrap(),
Expand Down Expand Up @@ -727,7 +750,6 @@ pub fn calibrate_timeblocks<'a>(
);
let cal_results = calibrate_timeblock(
vis_data.view(),
vis_weights.view(),
vis_model.view(),
di_jones.view_mut(),
&timeblock,
Expand Down Expand Up @@ -763,7 +785,6 @@ pub fn calibrate_timeblocks<'a>(
);
let mut cal_results = calibrate_timeblock(
vis_data.view(),
vis_weights.view(),
vis_model.view(),
di_jones.view_mut(),
timeblock,
Expand Down Expand Up @@ -831,7 +852,6 @@ fn make_calibration_progress_bar(
#[allow(clippy::too_many_arguments)]
fn calibrate_timeblock(
vis_data: ArrayView3<Jones<f32>>,
vis_weights: ArrayView3<f32>,
vis_model: ArrayView3<Jones<f32>>,
mut di_jones: ArrayViewMut3<Jones<f64>>,
timeblock: &Timeblock,
Expand Down Expand Up @@ -863,7 +883,6 @@ fn calibrate_timeblock(
];
let mut cal_result = calibrate(
vis_data.slice(range),
vis_weights.slice(range),
vis_model.slice(range),
di_jones,
max_iterations,
Expand Down Expand Up @@ -990,7 +1009,6 @@ fn calibrate_timeblock(
let range = s![timeblock.range.clone(), .., i_chanblock..i_chanblock + 1];
let mut new_cal_result = calibrate(
vis_data.slice(range),
vis_weights.slice(range),
vis_model.slice(range),
di_jones,
max_iterations,
Expand Down Expand Up @@ -1068,7 +1086,6 @@ pub struct CalibrationResult {
/// parallel code is inside this function.
pub(super) fn calibrate(
data: ArrayView3<Jones<f32>>,
weights: ArrayView3<f32>,
model: ArrayView3<Jones<f32>>,
mut di_jones: ArrayViewMut1<Jones<f64>>,
max_iterations: usize,
Expand Down Expand Up @@ -1096,14 +1113,7 @@ pub(super) fn calibrate(
top.fill(Jones::default());
bot.fill(Jones::default());

calibration_loop(
data,
weights,
model,
di_jones.view(),
top.view_mut(),
bot.view_mut(),
);
calibration_loop(data, model, di_jones.view(), top.view_mut(), bot.view_mut());

// Obtain the new DI Jones matrices from "top" and "bot".
// Tile/antenna axis.
Expand Down Expand Up @@ -1239,7 +1249,6 @@ pub(super) fn calibrate(
/// "MitchCal".
fn calibration_loop(
data: ArrayView3<Jones<f32>>,
weights: ArrayView3<f32>,
model: ArrayView3<Jones<f32>>,
di_jones: ArrayView1<Jones<f64>>,
mut top: ArrayViewMut1<Jones<f64>>,
Expand All @@ -1249,58 +1258,48 @@ fn calibration_loop(

// Time axis.
data.outer_iter()
.zip(weights.outer_iter())
.zip(model.outer_iter())
.for_each(|((data, weights), model)| {
.for_each(|(data, model)| {
// Unflagged baseline axis.
data.outer_iter()
.zip(weights.outer_iter())
.zip(model.outer_iter())
.enumerate()
.for_each(|(i_baseline, ((data, weights), model))| {
.for_each(|(i_baseline, (data, model))| {
let (tile1, tile2) = cross_correlation_baseline_to_tiles(num_tiles, i_baseline);

// Unflagged frequency chan axis.
data.iter()
.zip(weights)
.zip(model)
// Don't do anything if the weight is flagged.
.filter(|((_, weight), _)| **weight > 0.0)
.for_each(|((j_data, weight), j_model)| {
// Copy and promote the data and model Jones
// matrices.
let weight = *weight as f64;
let j_data: Jones<f64> = Jones::from(j_data) * weight;
let j_model: Jones<f64> = Jones::from(j_model) * weight;

// Suppress boundary checks for maximum performance!
unsafe {
let j_t1 = di_jones.uget(tile1);
let j_t2 = di_jones.uget(tile2);

let top_t1 = top.uget_mut(tile1);
let bot_t1 = bot.uget_mut(tile1);

// André's calibrate: ( D J M^H ) / ( M J^H J M^H )
// J M^H
let z = *j_t2 * j_model.h();
// D (J M^H)
*top_t1 += j_data * z;
// (J M^H)^H (J M^H)
*bot_t1 += z.h() * z;

let top_t2 = top.uget_mut(tile2);
let bot_t2 = bot.uget_mut(tile2);

// André's calibrate: ( D J M^H ) / ( M J^H J M^H )
// J (M^H)^H
let z = *j_t1 * j_model;
// D^H (J M^H)^H
*top_t2 += j_data.h() * z;
// (J M^H) (J M^H)
*bot_t2 += z.h() * z;
}
});
data.iter().zip(model).for_each(|(j_data, j_model)| {
let j_data = Jones::<f64>::from(j_data);
let j_model = Jones::<f64>::from(j_model);

// Suppress boundary checks for maximum performance!
unsafe {
let j_t1 = di_jones.uget(tile1);
let j_t2 = di_jones.uget(tile2);

let top_t1 = top.uget_mut(tile1);
let bot_t1 = bot.uget_mut(tile1);

// André's calibrate: ( D J M^H ) / ( M J^H J M^H )
// J M^H
let z = *j_t2 * j_model.h();
// D (J M^H)
*top_t1 += j_data * z;
// (J M^H)^H (J M^H)
*bot_t1 += z.h() * z;

let top_t2 = top.uget_mut(tile2);
let bot_t2 = bot.uget_mut(tile2);

// André's calibrate: ( D J M^H ) / ( M J^H J M^H )
// J (M^H)^H
let z = *j_t1 * j_model;
// D^H (J M^H)^H
*top_t2 += j_data.h() * z;
// (J M^H) (J M^H)
*bot_t2 += z.h() * z;
}
});
})
});
}
18 changes: 8 additions & 10 deletions src/calibrate/di/code/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ fn test_calibrate_trivial() {

let vis_shape = (num_timesteps, num_baselines, num_chanblocks);
let vis_data: Array3<Jones<f32>> = Array3::from_elem(vis_shape, Jones::identity() * 4.0);
let vis_weights: Array3<f32> = Array3::ones(vis_shape);
let vis_model: Array3<Jones<f32>> = Array3::from_elem(vis_shape, Jones::identity());
let mut di_jones = Array3::from_elem(
(num_timeblocks, num_tiles, num_chanblocks),
Expand All @@ -56,11 +55,9 @@ fn test_calibrate_trivial() {
chanblock_index..chanblock_index + 1
];
let vis_data_slice = vis_data.slice(range);
let vis_weight_slice = vis_weights.slice(range);
let vis_model_slice = vis_model.slice(range);
let result = calibrate(
vis_data_slice,
vis_weight_slice,
vis_model_slice,
di_jones_rev.view_mut(),
20,
Expand Down Expand Up @@ -112,7 +109,7 @@ fn test_calibrate_trivial_with_flags() {
let bad_vis = vis_data.get_mut((0, 0, 0)).unwrap();
*bad_vis = Jones::identity() * 9000.0;
let mut vis_weights: Array3<f32> = Array3::ones(vis_shape);
let vis_model: Array3<Jones<f32>> = Array3::from_elem(vis_shape, Jones::identity());
let mut vis_model: Array3<Jones<f32>> = Array3::from_elem(vis_shape, Jones::identity());
let mut di_jones = Array3::from_elem(
(num_timeblocks, num_tiles, num_chanblocks),
Jones::<f64>::identity(),
Expand All @@ -134,11 +131,9 @@ fn test_calibrate_trivial_with_flags() {
chanblock_index..chanblock_index + 1
];
let vis_data_slice = vis_data.slice(range);
let vis_weight_slice = vis_weights.slice(range);
let vis_model_slice = vis_model.slice(range);
let result = calibrate(
vis_data_slice,
vis_weight_slice,
vis_model_slice,
di_jones_rev.view_mut(),
20,
Expand All @@ -165,9 +160,15 @@ fn test_calibrate_trivial_with_flags() {
}
}

// Fix the weight and repeat.
// Fix the weight and repeat. We have to set the corresponding visibilities
// to 0 (this is normally done before the visibilities are returned via
// `CalVis`).
let bad_weight = vis_weights.get_mut((0, 0, 0)).unwrap();
*bad_weight = -1.0;
let bad_data = vis_data.get_mut((0, 0, 0)).unwrap();
*bad_data = Jones::default();
let bad_model = vis_model.get_mut((0, 0, 0)).unwrap();
*bad_model = Jones::default();
di_jones.fill(Jones::identity());
for timeblock in 0..num_timeblocks {
let time_range_start = timeblock * timeblock_length;
Expand All @@ -185,11 +186,9 @@ fn test_calibrate_trivial_with_flags() {
chanblock_index..chanblock_index + 1
];
let vis_data_slice = vis_data.slice(range);
let vis_weight_slice = vis_weights.slice(range);
let vis_model_slice = vis_model.slice(range);
let result = calibrate(
vis_data_slice,
vis_weight_slice,
vis_model_slice,
di_jones_rev.view_mut(),
20,
Expand Down Expand Up @@ -592,7 +591,6 @@ fn incomplete_to_complete_flags_complex() {
pub(crate) fn test_1090008640_quality(params: CalibrateParams, cal_vis: CalVis) {
let (_, cal_results) = calibrate_timeblocks(
cal_vis.vis_data.view(),
cal_vis.vis_weights.view(),
cal_vis.vis_model.view(),
&params.timeblocks,
&params.fences.first().chanblocks,
Expand Down
28 changes: 25 additions & 3 deletions src/calibrate/di/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ pub(crate) fn di_calibrate(
vis_weights,
vis_model,
} = get_cal_vis(params, !params.no_progress_bars)?;
assert_eq!(vis_weights.len_of(Axis(1)), params.baseline_weights.len());

let obs_context = params.input_data.get_obs_context();

Expand All @@ -57,14 +58,13 @@ pub(crate) fn di_calibrate(
shape.1,
shape.2,
shape.0 * shape.1 * shape.2 * std::mem::size_of::<Jones<f64>>()
// 1024 * 1024 == 1 MiB.
/ 1024 / 1024
// 1024 * 1024 == 1 MiB.
/ 1024 / 1024
);
}

let (sols, _) = calibrate_timeblocks(
vis_data.view(),
vis_weights.view(),
vis_model.view(),
&params.timeblocks,
&params.fences.first().chanblocks,
Expand Down Expand Up @@ -140,6 +140,28 @@ pub(crate) fn di_calibrate(

// Write out calibrated visibilities.
if !params.output_vis_filenames.is_empty() {
debug!("Dividing visibilities by weights");
// Divide the visibilities by the weights (undoing the multiplication earlier).
vis_data
.outer_iter_mut()
.into_par_iter()
.zip(vis_weights.outer_iter())
.for_each(|(mut vis_data, vis_weights)| {
vis_data
.outer_iter_mut()
.zip(vis_weights.outer_iter())
.zip(params.baseline_weights.iter())
.for_each(|((mut vis_data, vis_weights), &baseline_weight)| {
vis_data.iter_mut().zip(vis_weights.iter()).for_each(
|(vis_data, &vis_weight)| {
let weight = f64::from(vis_weight) * baseline_weight;
*vis_data =
Jones::<f32>::from(Jones::<f64>::from(*vis_data) / weight);
},
);
});
});

info!("Writing visibilities...");

// TODO(dev): support and test autos
Expand Down

0 comments on commit 5781f9a

Please sign in to comment.