diff --git a/src/aspire/abinitio/commonline_base.cu b/src/aspire/abinitio/commonline_base.cu new file mode 100644 index 0000000000..a3ee96ceed --- /dev/null +++ b/src/aspire/abinitio/commonline_base.cu @@ -0,0 +1,151 @@ +#include +#include +#include + +extern "C" __global__ +void build_clmatrix_kernel( + const int n, + const int m, + const int r, + const complex* __restrict__ pf, + int16_t* const __restrict__ clmatrix, + const int n_shifts, + const complex* const __restrict__ shift_phases) +{ + /* n n_img */ + /* m angular componentns, n_theta//2 */ + /* r radial componentns */ + /* (n, m, r) = pf.shape in python (before transpose for CUDA kernel) */ + + /* thread index (2d), represents "i" and "j" indices */ + const unsigned int i = blockDim.x * blockIdx.x + threadIdx.x; + const unsigned int j = blockDim.y * blockIdx.y + threadIdx.y; + + /* no-op when out of bounds */ + if(i >= n) return; + if(j >= n) return; + /* no-op lower triangle */ + if(j <= i) return; + + int k; + int s; + int cl1, cl2; + int best_cl1, best_cl2; + double xcorr, best_cl_xcorr; + double p1, p2; + complex pfik, pfjk; + + best_cl1 = -1; + best_cl2 = -1; + best_cl_xcorr = -INFINITY; + + for(cl1=0; cl1 best_cl_xcorr){ + best_cl_xcorr = xcorr; + best_cl1 = cl1; + best_cl2 = cl2; + } + + xcorr = p1 + p2; + if(xcorr > best_cl_xcorr){ + best_cl_xcorr = xcorr; + best_cl1 = cl1; + best_cl2 = cl2 + m; /* m is pf.shape[1], which should be n_theta//2 */ + } + + } /* s */ + } /* cl2 */ + }/* cl1 */ + + /* update global best for i, j */ + clmatrix[i*n + j] = best_cl1; + clmatrix[j*n+i] = best_cl2; /* [j,i] */ + +} /* build_clmatrix_kernel */ + +extern "C" __global__ +void fbuild_clmatrix_kernel( + const int n, + const int m, + const int r, + const complex* __restrict__ pf, + int16_t* const __restrict__ clmatrix, + const int n_shifts, + const complex* const __restrict__ shift_phases) +{ + /* n n_img */ + /* m angular componentns, n_theta//2 */ + /* r radial componentns */ + /* (n, m, r) = pf.shape in python (before transpose for CUDA kernel) */ + + /* thread index (2d), represents "i" and "j" indices */ + const unsigned int i = blockDim.x * blockIdx.x + threadIdx.x; + const unsigned int j = blockDim.y * blockIdx.y + threadIdx.y; + + /* no-op when out of bounds */ + if(i >= n) return; + if(j >= n) return; + /* no-op lower triangle */ + if(j <= i) return; + + int k; + int s; + int cl1, cl2; + int best_cl1, best_cl2; + float xcorr, best_cl_xcorr; + float p1, p2; + complex pfik, pfjk; + + best_cl1 = -1; + best_cl2 = -1; + best_cl_xcorr = -INFINITY; + + for(cl1=0; cl1 best_cl_xcorr){ + best_cl_xcorr = xcorr; + best_cl1 = cl1; + best_cl2 = cl2; + } + + xcorr = p1 + p2; + if(xcorr > best_cl_xcorr){ + best_cl_xcorr = xcorr; + best_cl1 = cl1; + best_cl2 = cl2 + m; /* m is pf.shape[1], which should be n_theta//2 */ + } + + } /* s */ + } /* cl2 */ + }/* cl1 */ + + /* update global best for i, j */ + clmatrix[i*n + j] = best_cl1; + clmatrix[j*n+i] = best_cl2; /* [j,i] */ + +} /* fbuild_clmatrix_kernel */ diff --git a/src/aspire/abinitio/commonline_base.py b/src/aspire/abinitio/commonline_base.py index c0c3718803..3329e117a3 100644 --- a/src/aspire/abinitio/commonline_base.py +++ b/src/aspire/abinitio/commonline_base.py @@ -1,12 +1,13 @@ import logging import math +import os import numpy as np import scipy.sparse as sparse from aspire.image import Image from aspire.operators import PolarFT -from aspire.utils import common_line_from_rots, fuzzy_mask, tqdm +from aspire.utils import common_line_from_rots, complex_type, fuzzy_mask, tqdm from aspire.utils.random import choice logger = logging.getLogger(__name__) @@ -60,7 +61,9 @@ def __init__( self.n_theta = n_theta self.n_check = n_check self.hist_bin_width = hist_bin_width - self.full_width = full_width + if str(full_width).lower() == "adaptive": + full_width = -1 + self.full_width = int(full_width) self.clmatrix = None self.max_shift = math.ceil(max_shift * self.n_res) self.shift_step = shift_step @@ -68,6 +71,29 @@ def __init__( self.rotations = None self._pf = None + # Sanity limit to match potential clmatrix dtype of int16. + if self.n_img > (2**15 - 1): + raise NotImplementedError( + "Commonlines implementation limited to <2**15 images." + ) + + # Auto configure GPU + self.__gpu_module = None + try: + import cupy as cp + + if cp.cuda.runtime.getDeviceCount() >= 1: + gpu_id = cp.cuda.runtime.getDevice() + logger.info( + f"cupy and GPU {gpu_id} found by cuda runtime; enabling cupy." + ) + self.__gpu_module = self.__init_cupy_module() + else: + logger.info("GPU not found, defaulting to numpy.") + + except ModuleNotFoundError: + logger.info("cupy not found, defaulting to numpy.") + self._build() def _build(self): @@ -131,6 +157,24 @@ def estimate_rotations(self): def build_clmatrix(self): """ Build common-lines matrix from Fourier stack of 2D images + + Wrapper for cpu/gpu dispatch. + """ + + logger.info("Begin building Common Lines Matrix") + + # host/gpu dispatch + if self.__gpu_module: + res = self.build_clmatrix_cu() + else: + res = self.build_clmatrix_host() + + # Unpack result + self.shifts_1d, self.clmatrix = res + + def build_clmatrix_host(self): + """ + Build common-lines matrix from Fourier stack of 2D images """ n_img = self.n_img @@ -233,8 +277,102 @@ def build_clmatrix(self): pbar.update() pbar.close() - self.clmatrix = clmatrix - self.shifts_1d = shifts_1d + return shifts_1d, clmatrix + + def build_clmatrix_cu(self): + """ + Build common-lines matrix from Fourier stack of 2D images + """ + + import cupy as cp + + n_img = self.n_img + r = self.pf.shape[2] + + if self.n_theta % 2 == 1: + msg = "n_theta must be even" + logger.error(msg) + raise NotImplementedError(msg) + + # Copy to prevent modifying self.pf for other functions + # Simultaneously place on GPU + pf = cp.array(self.pf) + + # Allocate local variables for return + # clmatrix represents the common lines matrix. + # Namely, clmatrix[i,j] contains the index in image i of + # the common line with image j. Note the common line index + # starts from 0 instead of 1 as Matlab version. -1 means + # there is no common line such as clmatrix[i,i]. + clmatrix = -cp.ones((n_img, n_img), dtype=np.int16) + + # Allocate variables used for shift estimation + # + # Set maximum value of 1D shift (in pixels) to search + # between common-lines. + # Set resolution of shift estimation in pixels. Note that + # shift_step can be any positive real number. + # + # Prepare the shift phases to try and generate filter for common-line detection + # + # Note the CUDA implementation has been optimized to not + # compute or return diagnostic 1d shifts. + _, shift_phases, h = self._generate_shift_phase_and_filter( + r, self.max_shift, self.shift_step + ) + # Transfer to device, dtypes must match kernel header. + shift_phases = cp.asarray(shift_phases, dtype=complex_type(self.dtype)) + + # Apply bandpass filter, normalize each ray of each image + # Note that this only uses half of each ray + pf = self._apply_filter_and_norm("ijk, k -> ijk", pf, r, h) + + # Tranpose `pf` for better (CUDA) memory access pattern, and cast as needed. + pf = cp.ascontiguousarray(pf.T, dtype=complex_type(self.dtype)) + + # Get kernel + if self.dtype == np.float64: + build_clmatrix_kernel = self.__gpu_module.get_function( + "build_clmatrix_kernel" + ) + elif self.dtype == np.float32: + build_clmatrix_kernel = self.__gpu_module.get_function( + "fbuild_clmatrix_kernel" + ) + else: + raise NotImplementedError( + "build_clmatrix_kernel only implemented for float32 and float64." + ) + + # Configure grid of blocks + blkszx = 32 + # Enough blocks to cover n_img-1 + nblkx = (self.n_img + blkszx - 2) // blkszx + blkszy = 32 + # Enough blocks to cover n_img + nblky = (self.n_img + blkszy - 1) // blkszy + + # Launch + logger.info("Launching `build_clmatrix_kernel`.") + build_clmatrix_kernel( + (nblkx, nblky), + (blkszx, blkszy), + ( + n_img, + pf.shape[1], + r, + pf, + clmatrix, + len(shift_phases), + shift_phases, + ), + ) + + # Copy result device arrays to host + clmatrix = clmatrix.get().astype(self.dtype, copy=False) + + # Note diagnostic 1d shifts are not computed in the CUDA implementation. + return None, clmatrix def estimate_shifts(self, equations_factor=1, max_memory=4000): """ @@ -488,10 +626,10 @@ def _generate_shift_phase_and_filter(self, r_max, max_shift, shift_step): n_shifts = int(np.ceil(2 * max_shift / shift_step + 1)) # only half of ray, excluding the DC component. - rk = np.arange(1, r_max + 1) + rk = np.arange(1, r_max + 1, dtype=self.dtype) # Generate all shift phases - shifts = -max_shift + shift_step * np.arange(n_shifts) + shifts = -max_shift + shift_step * np.arange(n_shifts, dtype=self.dtype) shift_phases = np.exp(np.outer(shifts, -2 * np.pi * 1j * rk / (2 * r_max + 1))) # Set filter for common-line detection h = np.sqrt(np.abs(rk)) * np.exp(-np.square(rk) / (2 * (r_max / 4) ** 2)) @@ -556,7 +694,7 @@ def _apply_filter_and_norm(self, subscripts, pf, r_max, h): # Note if we'd rather not have the dtype and casting args, # we can control h.dtype instead. - np.einsum(subscripts, pf, h, out=pf, dtype=pf.dtype, casting="same_kind") + pf = np.einsum(subscripts, pf, h, dtype=pf.dtype) # This is a high pass filter, cutting out the lowest frequency # (DC has already been removed). @@ -564,3 +702,27 @@ def _apply_filter_and_norm(self, subscripts, pf, r_max, h): pf /= np.linalg.norm(pf, axis=-1)[..., np.newaxis] return pf + + @staticmethod + def __init_cupy_module(): + """ + Private utility method to read in CUDA source and return as + compiled CuPy module. + """ + + import cupy as cp + + # Read in contents of file + fp = os.path.join(os.path.dirname(__file__), "commonline_base.cu") + with open(fp, "r") as fh: + module_code = fh.read() + + # CuPy compile the CUDA code + # Note these optimizations are to steer aggresive optimization + # for single precision code. Fast math will potentionally + # reduce accuracy in single precision. + return cp.RawModule( + code=module_code, + backend="nvcc", + options=("-O3", "--use_fast_math", "--extra-device-vectorization"), + ) diff --git a/src/aspire/abinitio/commonline_sync3n.cu b/src/aspire/abinitio/commonline_sync3n.cu index eeaee723b9..65ab91a77d 100644 --- a/src/aspire/abinitio/commonline_sync3n.cu +++ b/src/aspire/abinitio/commonline_sync3n.cu @@ -1,8 +1,36 @@ +#include "stdint.h" +#include "math.h" -/* from i,j indices to the common index in the N-choose-2 sized array */ +/* From i,j indices to the common index in the N-choose-2 sized array */ +/* Careful, this is strictly the upper triangle! */ #define PAIR_IDX(N,I,J) ((2*N-I-1)*I/2 + J-I-1) +/* convert euler angles (a,b,c) in ZYZ to rotation matrix r */ +__host__ __device__ +inline void ang2orth(double* r, double a, double b, double c){ + double sa = sin(a); + double sb = sin(b); + double sc = sin(c); + double ca = cos(a); + double cb = cos(b); + double cc = cos(c); + + /* ZYZ Proper Euler angles */ + /* https://en.wikipedia.org/wiki/Euler_angles#Rotation_matrix */ + r[0] = ca*cb*cc - sa*sc; + r[1] = -cc*sa -ca*cb*sc; + r[2] = ca*sb; + r[3] = ca*sc + cb*cc*sa; + r[4] = ca*cc - cb*sa*sc; + r[5] = sa*sb; + r[6] = -cc*sb; + r[7] = sb*sc; + r[8] = cb; +} + + +__host__ __device__ inline void mult_3x3(double *out, double *R1, double *R2) { /* 3X3 matrices multiplication: out = R1*R2 * Note, this differs from the MATLAB mult_3x3. @@ -20,6 +48,7 @@ inline void mult_3x3(double *out, double *R1, double *R2) { } } +__host__ __device__ inline void JRJ(double *R, double *A) { /* multiple 3X3 matrix by J from both sizes: A = JRJ */ A[0]=R[0]; @@ -33,6 +62,7 @@ inline void JRJ(double *R, double *A) { A[8]=R[8]; } +__host__ __device__ inline double diff_norm_3x3(const double *R1, const double *R2) { /* difference 2 matrices and return squared norm: ||R1-R2||^2 */ int i; @@ -413,3 +443,273 @@ void triangle_scores_inner(int n, double* Rijs, int n_intervals, unsigned int* s return; }; + +extern "C" __global__ +void estimate_all_angles1(int j, + int n, + int n_theta, + double hist_bin_width, + int full_width, + double sigma, + int sync, + int16_t* __restrict__ clmatrix, + double* __restrict__ hist, + uint16_t* __restrict__ k_map, + double* __restrict__ angles_map, + double* __restrict__ angles) +{ + /* n n_img */ + /* j is image j index */ + + /* thread index represents "i" index */ + const unsigned int i = blockDim.x * blockIdx.x + threadIdx.x; + /* thread index represents "k" index */ + const unsigned int k = blockDim.y * blockIdx.y + threadIdx.y; + + int cl_diff1, cl_diff2, cl_diff3; + double theta1, theta2, theta3; + double c1, c2, c3; + double cond; + double cos_phi2; + double w_theta_need; + + /* no-op when out of bounds */ + if(i >= n) return; + if(k >= n) return; + if(i >= j) return; + /* + These are also tested later via the clmatrix values, + testing now avoids extra reads. + */ + if(k==i) return; + if(k==j) return; + + int map_idx; /* tmp index var */ + + int cl_idx12, cl_idx21; + int cl_idx13, cl_idx31; + int cl_idx23, cl_idx32; + const int ntics = 180. / hist_bin_width; + const double TOL_idx = 1e-12; + bool ind1, ind2; + double grid_angle, angle_diff, angle; + int b; + const double two_sigma_sq = 2*sigma*sigma; + + + const int pair_idx = PAIR_IDX(n,i,j); + + cl_idx12 = clmatrix[i*n + j]; + cl_idx21 = clmatrix[j*n + i]; + /* + MATLAB code indicated this condition might occur outside i==j; + Ask Yoel what other reasons this would occur. + */ + if(cl_idx12 == -1) return; + + /* Assume that k_list starts as all n images */ + + cl_idx13 = clmatrix[i*n + k]; + cl_idx31 = clmatrix[k*n + i]; + cl_idx23 = clmatrix[j*n + k]; + cl_idx32 = clmatrix[k*n + j]; + + /* test `k` values */ + if(cl_idx13 == -1) return; /* i, k */ + if(cl_idx23 == -1) return; /* j, k */ + + /* get cosine angles */ + cl_diff1 = cl_idx13 - cl_idx12; + cl_diff2 = cl_idx23 - cl_idx21; + cl_diff3 = cl_idx32 - cl_idx31; + + theta1 = cl_diff1 * 2 * M_PI / n_theta; + theta2 = cl_diff2 * 2 * M_PI / n_theta; + theta3 = cl_diff3 * 2 * M_PI / n_theta; + + c1 = cos(theta1); + c2 = cos(theta2); + c3 = cos(theta3); + + /* test if we have a good index */ + cond = 1 + 2 * c1 * c2 * c3 - (c1*c1 + c2*c2 + c3*c3); + if(cond <= 1e-5) return; /* current value of k is not good, skip */ + + /* Calculated cos values of angle between i and j images */ + if( sync == 1){ + + cos_phi2 = (c3 - c1*c2) / (sqrt(1 - c1*c1) * sqrt(1 - c2*c2)); + + /* + Some synchronization must be applied when common line is out by 180 degrees. + Here fix the angles between c_ij(c_ji) and c_ik(c_jk) to be smaller than pi/2, + otherwise there will be an ambiguity between alpha and pi-alpha. + */ + + /* Check sync conditions */ + ind1 = (theta1 > (M_PI + TOL_idx)) || ( + (theta1 < -TOL_idx) && (theta1 > -M_PI) + ); + ind2 = (theta2 > (M_PI + TOL_idx)) || ( + (theta2 < -TOL_idx) && (theta2 > -M_PI) + ); + if( (ind1 && !ind2) || (!ind1 && ind2)){ + /* Apply sync */ + cos_phi2 = -cos_phi2; + } + + } /* end sync */ + else{ + cos_phi2 = (c3 - c1*c2 ) / (sin(theta1) * sin(theta2)); + } /* end not sync */ + + /* clip cosine phi between [-1,1] */ + if(cos_phi2 > 1){ + cos_phi2 = 1; + } + if(cos_phi2 < -1){ + cos_phi2 = -1; + } + + /* compute histogram contribution, angle mapping, and index mappings. */ + angle = acos(cos_phi2) * 180. / M_PI; + /* index of angle's bin */ + map_idx = i*n + k; + /* + For each k, keep track of bin and angles. + Note, this is slightly different than the host + which uses slightly different angle/hist grids (likely an oversight). + */ + k_map[map_idx] = angle / hist_bin_width; + angles_map[map_idx] = angle; /* degrees */ + for(b=0; b= n) return; + if(i >= j) return; + + int map_idx; /* tmp index var */ + + const int ntics = 180. / hist_bin_width; + int b; + int peak_idx; + double peak; + + const int pair_idx = PAIR_IDX(n,i,j); + + /* Find peak and peak index in histogram */ + peak = -99999; + peak_idx = -1; + for(b=0; b peak){ + peak = hist[map_idx]; + peak_idx = b; + } + } + + /* find mean of rotations */ + + if(full_width==-1){ + /* adaptive width*/ + w_theta_needed = 0; + cnt = 0; + while(cnt == 0){ + /* broaden search width */ + w_theta_needed += hist_bin_width; + /* find satisfying indices */ + for(k=0; k= n) return; + + ang2orth(&(rotations[i*9]), angles[i*3], angles[i*3+1], angles[i*3+2]); + +} diff --git a/src/aspire/abinitio/commonline_sync3n.py b/src/aspire/abinitio/commonline_sync3n.py index 6a7e5390d3..879728436d 100644 --- a/src/aspire/abinitio/commonline_sync3n.py +++ b/src/aspire/abinitio/commonline_sync3n.py @@ -62,6 +62,7 @@ def __init__( full_width="adaptive", epsilon=1e-2, max_iters=1000, + sigma=3, seed=None, mask=True, S_weighting=False, @@ -83,6 +84,7 @@ def __init__( `hist_bin_width`s required to find at least one valid image index. :param epsilon: Tolerance for the power method. :param max_iter: Maximum iterations for the power method. + :param sigma: Voting contribution smoothing factor. :param seed: Optional seed for RNG. :param mask: Option to mask `src.images` with a fuzzy mask (boolean). Default, `True`, applies a mask. @@ -113,6 +115,7 @@ def __init__( self.epsilon = epsilon self.max_iters = max_iters + self.sigma = float(sigma) self.seed = seed # Sync3N specific vars @@ -820,6 +823,136 @@ def _estimate_all_Rijs(self, clmatrix): """ Estimate Rijs using the voting method. + :param clmatrix: Common lines matrix + :return: Estimated rotations + """ + # host/gpu dispatch + if self.__gpu_module: + res = self._estimate_all_Rijs_cu(clmatrix) + else: + res = self._estimate_all_Rijs_host(clmatrix) + + return res + + def _estimate_all_Rijs_cu(self, clmatrix): + import cupy as cp + + estimate_all_angles1 = self.__gpu_module.get_function("estimate_all_angles1") + estimate_all_angles2 = self.__gpu_module.get_function("estimate_all_angles2") + angles_to_rots = self.__gpu_module.get_function("angles_to_rots") + + # Use the sync3n MATLAB implementation, + # other mode exists to support other CL methods. + sync = 1 + + # transfer input to device + clmatrix = cp.asarray(clmatrix, order="C", dtype=np.int16) + + # workspace arrays + ntics = int(180 / self.hist_bin_width) + n_pairs = self.n_img * (self.n_img - 1) // 2 + hist = cp.zeros((self.n_img, ntics), dtype=np.float64) + # k_map stores the mapping of i, k indices to histogram bins + k_map = cp.zeros((self.n_img, self.n_img), dtype=np.uint16) + # angles_map stores mapping of i, k indices to angles + angles_map = cp.zeros((self.n_img, self.n_img), dtype=np.float64) + # resulting pairs i,j euler angles + angles = cp.zeros((n_pairs, 3), dtype=np.float64) + + # Configure 2d grid of blocks (kernel 1) + blkszx = 32 + nblkx = (self.n_img + blkszx - 1) // blkszx # i + blkszy = 32 + nblky = (self.n_img + blkszy - 1) // blkszy # k + + # Configure 1d grid of blocks (kernel 2) + blksz = 1024 + nblk = (self.n_img + blksz - 1) // blksz + + for j in range(self.n_img): + + # ------------------------------------------ + # Zero histogram and k mapping for each `j`. + hist[:] = 0 + k_map[:] = 0 + + # ------------------- + # Vote into histogram + estimate_all_angles1( + (nblkx, nblky), + (blkszx, blkszy), + ( + j, + self.n_img, + self.n_theta, + np.float64(self.hist_bin_width), + self.full_width, + np.float64(self.sigma), + sync, + clmatrix, # input + hist, # tmp + k_map, # tmp + angles_map, # tmp + angles, # output + ), + ) + + # ------------------------- + # Solve hist for mean angle + estimate_all_angles2( + (nblk,), + (blksz,), + ( + j, + self.n_img, + np.float64(self.hist_bin_width), + self.full_width, + hist, # tmp + k_map, # tmp + angles_map, # tmp + angles, # output + ), + ) + + # Force all kernels to complete + cp.cuda.runtime.deviceSynchronize() + + # Explicitly inform CuPy we longer need these workspace vars + del hist + del k_map + del angles_map + + # --------------------------- + # Convert angles to rotations + rotations = cp.empty((n_pairs, 3, 3), dtype=np.float64) + + # Configure another 1d grid of blocks + blksz = 1024 + nblk = (n_pairs + blksz - 1) // blksz + + logger.info("Launching `angles_to_rots` kernel.") + angles_to_rots( + (nblk,), + (blksz,), + ( + n_pairs, + angles, + rotations, + ), + ) + + # Force all kernels to complete + cp.cuda.runtime.deviceSynchronize() + + # transfer results device to host + rotations = rotations.get() + + return rotations + + def _estimate_all_Rijs_host(self, clmatrix): + """ + Estimate Rijs using the voting method. + :param clmatrix: Common lines matrix :return: Estimated rotations """ @@ -1061,4 +1194,4 @@ def __init_cupy_module(): module_code = fh.read() # CUPY compile the CUDA code - return cp.RawModule(code=module_code) + return cp.RawModule(code=module_code, backend="nvcc") diff --git a/src/aspire/abinitio/sync_voting.py b/src/aspire/abinitio/sync_voting.py index fb626a8a91..651ba35c5e 100644 --- a/src/aspire/abinitio/sync_voting.py +++ b/src/aspire/abinitio/sync_voting.py @@ -142,7 +142,7 @@ def _vote_ij(self, clmatrix, n_theta, i, j, k_list, sync=False): # similar. This sigma ensures that the width of the density # estimation kernel is roughly 10 degrees. For 15 degrees, the # value of the kernel is negligible. - sigma = 3.0 + sigma = getattr(self, "sigma", 3.0) # get from class if avail # Compute the histogram of the angles between images i and j angles_distances = angles_grid[None, :] - angles[:, None] @@ -156,7 +156,7 @@ def _vote_ij(self, clmatrix, n_theta, i, j, k_list, sync=False): # that accidentally fall near the peak. peak_idx = angles_hist.argmax() - if str(self.full_width).lower() == "adaptive": + if self.full_width == -1: # Adaptive width (MATLAB) # Look for the estimations in the peak of the histogram w_theta_needed = 0 diff --git a/tests/test_commonline_sync3n.py b/tests/test_commonline_sync3n.py index 600f883c2d..8305aa85e9 100644 --- a/tests/test_commonline_sync3n.py +++ b/tests/test_commonline_sync3n.py @@ -62,7 +62,12 @@ def source_orientation_objs(resolution, offsets, dtype): max_shift = 0.20 shift_step = 0.25 # Reduce shift steps for non-integer offsets of Simulation. - orient_est = CLSync3N(src, max_shift=max_shift, shift_step=shift_step, seed=789) + orient_est = CLSync3N( + src, + max_shift=max_shift, + shift_step=shift_step, + seed=789, + ) # Estimate rotations once for all tests. orient_est.estimate_rotations() @@ -136,3 +141,45 @@ def test_estimate_rotations(source_orientation_objs): if src.offsets.all() != 0: tol = 4 mean_aligned_angular_distance(orient_est.rotations, src.rotations, degree_tol=tol) + + +@pytest.mark.expensive +def test_weighted_sync3n(source_orientation_objs): + """ + Test alternative Sync3N configuration code paths. + """ + src, _ = source_orientation_objs + + # Search for common lines over less shifts for 0 offsets. + max_shift = 1 / src.L + shift_step = 1 + if src.offsets.all() != 0: + max_shift = 0.20 + shift_step = 0.25 # Reduce shift steps for non-integer offsets of Simulation. + + orient_est = CLSync3N( + src, + max_shift=max_shift, + shift_step=shift_step, + seed=789, + S_weighting=True, + J_weighting=True, + full_width=2, + sigma=2.9, + ) + # Estimate rotations + orient_est.estimate_rotations() + + gt_clmatrix = rots_to_clmatrix(src.rotations, orient_est.n_theta) + + angle_diffs = abs(orient_est.clmatrix - gt_clmatrix) * 360 / orient_est.n_theta + + # Count number of estimates within 5 degrees of ground truth. + within_5 = np.sum((angle_diffs - 360) % 360 < 5) + + # Check that at least 98% of estimates are within 5 degrees. + tol = 0.98 + if src.offsets.all() != 0: + # Set tolerance to 75% when using nonzero offsets. + tol = 0.75 + assert within_5 / angle_diffs.size > tol