Merged
45 commits
ca2d15d
quick and dirty cupy build_clmatrix
garrettwrong Aug 14, 2024
99e0bf2
stashing
garrettwrong Aug 15, 2024
80d9acd
stashing 2, clmatrix match, good speedup
garrettwrong Aug 15, 2024
0ec15b4
cleanup
garrettwrong Aug 19, 2024
6cb90d1
implement transpose PF for better mem patterns
garrettwrong Aug 19, 2024
9137b6c
use cu complex for CL kernel
garrettwrong Aug 19, 2024
6a4689a
fixed shifts, dtype
garrettwrong Aug 20, 2024
6a0e276
revert some unused optimizations, fix minor casting/dtype issue, begi…
garrettwrong Aug 20, 2024
0930e69
tox check cleanup
garrettwrong Aug 20, 2024
997da45
stashing stub
garrettwrong Aug 26, 2024
425d3ee
stashing stubs
garrettwrong Sep 4, 2024
5b0d7cc
stashing init kernel port
garrettwrong Sep 4, 2024
0b779a1
update kernel with some of the angs work
garrettwrong Sep 17, 2024
6235b1f
start populating eastimate_Rijs kernel call (stash)
garrettwrong Sep 17, 2024
41665b7
breakout angles and add angles map (stash)
garrettwrong Sep 17, 2024
be32959
pair_idx doesn't include diag
garrettwrong Sep 17, 2024
62d6d1a
angles matching (Stash), bug in angles to rot func
garrettwrong Sep 18, 2024
558fddc
fixed zyz angles conversion
garrettwrong Sep 18, 2024
cd26b1a
remove dbg prints
garrettwrong Sep 18, 2024
9e7a3b4
add adaptive width to kernel
garrettwrong Sep 18, 2024
9d1d6e2
1d rij kernel
garrettwrong Sep 19, 2024
1882d6c
implement nvcc backend and int16_t
garrettwrong Sep 19, 2024
306e35e
split kernels
garrettwrong Sep 19, 2024
18b0548
general cleanup
garrettwrong Sep 19, 2024
2f54330
threads over k
garrettwrong Sep 19, 2024
4b663fd
continue cleanup threads over k
garrettwrong Sep 19, 2024
0e219cc
fix j<i bound bug
garrettwrong Sep 23, 2024
3eaef56
fix adative param oversight bug
garrettwrong Sep 23, 2024
1748a69
parallel case bug
garrettwrong Sep 23, 2024
90ff5b3
C order, sigh
garrettwrong Sep 24, 2024
627d4c5
remove unused vars from build cl kernel
garrettwrong Oct 4, 2024
230fd0f
remove unused vars from build cl kernel
garrettwrong Oct 4, 2024
b64164e
continue removing unused vars
garrettwrong Oct 10, 2024
58fdb60
update constants
garrettwrong Oct 10, 2024
7e6415d
add single precision build CL kernel and launching code
garrettwrong Oct 10, 2024
a570155
revert accidental config commit
garrettwrong Oct 10, 2024
bc4c3b8
cleanup base cuda code a little
garrettwrong Oct 10, 2024
c4ca481
self review cleanup
garrettwrong Oct 10, 2024
1631cff
use adaptive width mode for sync3n tests
garrettwrong Oct 10, 2024
706c2bd
add additional sync3n code paths
garrettwrong Oct 10, 2024
f2025c8
must use smaller shift step for unit test size problem
garrettwrong Oct 11, 2024
13cc09f
Remove missed debug string change
garrettwrong Oct 16, 2024
78030fe
Remove range(0,...)
garrettwrong Oct 16, 2024
6a99b05
change var name from dist to xcorr
garrettwrong Oct 18, 2024
c47b626
Remove kernel timing
garrettwrong Oct 18, 2024
151 changes: 151 additions & 0 deletions src/aspire/abinitio/commonline_base.cu
@@ -0,0 +1,151 @@
#include <stdint.h>
#include <math.h>
#include <cupy/complex.cuh>

extern "C" __global__
void build_clmatrix_kernel(
const int n,
const int m,
const int r,
const complex<double>* __restrict__ pf,
int16_t* const __restrict__ clmatrix,
const int n_shifts,
const complex<double>* const __restrict__ shift_phases)
{
/* n n_img */
/* m angular componentns, n_theta//2 */
/* r radial componentns */
/* (n, m, r) = pf.shape in python (before transpose for CUDA kernel) */

/* thread index (2d), represents "i" and "j" indices */
const unsigned int i = blockDim.x * blockIdx.x + threadIdx.x;
const unsigned int j = blockDim.y * blockIdx.y + threadIdx.y;

/* no-op when out of bounds */
if(i >= n) return;
if(j >= n) return;
/* no-op lower triangle */
if(j <= i) return;

int k;
int s;
int cl1, cl2;
int best_cl1, best_cl2;
double xcorr, best_cl_xcorr;
double p1, p2;
complex<double> pfik, pfjk;

best_cl1 = -1;
best_cl2 = -1;
best_cl_xcorr = -INFINITY;

for(cl1=0; cl1<m; cl1++){
for(cl2=0; cl2<m; cl2++){
for(s=0; s<n_shifts; s++){
p1 = 0;
p2 = 0;
/* inner most dim of dot (matmul) */
for(k=0; k<r; k++){
pfik = pf[k*m*n + cl1*n + i];
pfjk = conj(pf[k*m*n + cl2*n + j]) * shift_phases[s*r + k];
p1 += real(pfik) * real(pfjk);
p2 += imag(pfik) * imag(pfjk);
} /* k */

xcorr = p1 - p2;
if(xcorr > best_cl_xcorr){
best_cl_xcorr = xcorr;
best_cl1 = cl1;
best_cl2 = cl2;
}

xcorr = p1 + p2;
if(xcorr > best_cl_xcorr){
best_cl_xcorr = xcorr;
best_cl1 = cl1;
best_cl2 = cl2 + m; /* m is pf.shape[1], which should be n_theta//2 */
}

} /* s */
} /* cl2 */
}/* cl1 */

/* update global best for i, j */
clmatrix[i*n + j] = best_cl1;
clmatrix[j*n+i] = best_cl2; /* [j,i] */

} /* build_clmatrix_kernel */

extern "C" __global__
void fbuild_clmatrix_kernel(
const int n,
const int m,
const int r,
const complex<float>* __restrict__ pf,
int16_t* const __restrict__ clmatrix,
const int n_shifts,
const complex<float>* const __restrict__ shift_phases)
{
/* n n_img */
/* m angular componentns, n_theta//2 */
/* r radial componentns */
/* (n, m, r) = pf.shape in python (before transpose for CUDA kernel) */

/* thread index (2d), represents "i" and "j" indices */
const unsigned int i = blockDim.x * blockIdx.x + threadIdx.x;
const unsigned int j = blockDim.y * blockIdx.y + threadIdx.y;

/* no-op when out of bounds */
if(i >= n) return;
if(j >= n) return;
/* no-op lower triangle */
if(j <= i) return;

int k;
int s;
int cl1, cl2;
int best_cl1, best_cl2;
float xcorr, best_cl_xcorr;
float p1, p2;
complex<float> pfik, pfjk;

best_cl1 = -1;
best_cl2 = -1;
best_cl_xcorr = -INFINITY;

for(cl1=0; cl1<m; cl1++){
for(cl2=0; cl2<m; cl2++){
for(s=0; s<n_shifts; s++){
p1 = 0;
p2 = 0;
/* inner most dim of dot (matmul) */
for(k=0; k<r; k++){
pfik = pf[k*m*n + cl1*n + i];
pfjk = conj(pf[k*m*n + cl2*n + j]) * shift_phases[s*r + k];
p1 += real(pfik) * real(pfjk);
p2 += imag(pfik) * imag(pfjk);
} /* k */

xcorr = p1 - p2;
if(xcorr > best_cl_xcorr){
best_cl_xcorr = xcorr;
best_cl1 = cl1;
best_cl2 = cl2;
}

xcorr = p1 + p2;
if(xcorr > best_cl_xcorr){
best_cl_xcorr = xcorr;
best_cl1 = cl1;
best_cl2 = cl2 + m; /* m is pf.shape[1], which should be n_theta//2 */
}

} /* s */
} /* cl2 */
}/* cl1 */

/* update global best for i, j */
clmatrix[i*n + j] = best_cl1;
clmatrix[j*n+i] = best_cl2; /* [j,i] */

} /* fbuild_clmatrix_kernel */
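Editor's note: the kernels above score each candidate ray pair (cl1, cl2) and trial shift s by accumulating p1 = Σ_k Re(pfik)·Re(pfjk) and p2 = Σ_k Im(pfik)·Im(pfjk), then reading off two correlations at once: p1 - p2 = Re Σ_k pfik·pfjk and p1 + p2 = Re Σ_k pfik·conj(pfjk). Conjugating a polar ray of a real image gives its antipodal ray, which is why the second candidate is stored as cl2 + m with m = n_theta//2. The sketch below is not part of the PR; the array names and sizes are illustrative only, and it simply checks the identity with NumPy.

```python
# Editor's sketch (not part of the PR): checks the identity the kernel uses to
# score both half-circle candidates (cl2 and cl2 + m) from one pass over k.
# Array names and shapes are illustrative only.
import numpy as np

rng = np.random.default_rng(0)
r = 8  # radial components
a = rng.normal(size=r) + 1j * rng.normal(size=r)  # ray cl1 of image i
b = rng.normal(size=r) + 1j * rng.normal(size=r)  # conj(ray cl2 of image j) * shift phase

p1 = np.sum(a.real * b.real)
p2 = np.sum(a.imag * b.imag)

assert np.isclose(p1 - p2, np.real(np.sum(a * b)))            # candidate cl2
assert np.isclose(p1 + p2, np.real(np.sum(a * np.conj(b))))   # conjugate ray -> cl2 + m
```

The flat index `pf[k*m*n + cl*n + i]` matches the transposed, C-ordered (r, m, n) layout prepared on the Python side, so threads with consecutive i read adjacent addresses.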
176 changes: 169 additions & 7 deletions src/aspire/abinitio/commonline_base.py
@@ -1,12 +1,13 @@
import logging
import math
import os

import numpy as np
import scipy.sparse as sparse

from aspire.image import Image
from aspire.operators import PolarFT
from aspire.utils import common_line_from_rots, fuzzy_mask, tqdm
from aspire.utils import common_line_from_rots, complex_type, fuzzy_mask, tqdm
from aspire.utils.random import choice

logger = logging.getLogger(__name__)
@@ -60,14 +61,39 @@
self.n_theta = n_theta
self.n_check = n_check
self.hist_bin_width = hist_bin_width
self.full_width = full_width
if str(full_width).lower() == "adaptive":
full_width = -1
self.full_width = int(full_width)
self.clmatrix = None
self.max_shift = math.ceil(max_shift * self.n_res)
self.shift_step = shift_step
self.mask = mask
self.rotations = None
self._pf = None

# Sanity limit to match potential clmatrix dtype of int16.
if self.n_img > (2**15 - 1):
raise NotImplementedError(
"Commonlines implementation limited to <2**15 images."
)

# Auto configure GPU
self.__gpu_module = None
try:
import cupy as cp

if cp.cuda.runtime.getDeviceCount() >= 1:
gpu_id = cp.cuda.runtime.getDevice()
logger.info(
f"cupy and GPU {gpu_id} found by cuda runtime; enabling cupy."
)
self.__gpu_module = self.__init_cupy_module()
else:
logger.info("GPU not found, defaulting to numpy.")

Check warning on line 92 in src/aspire/abinitio/commonline_base.py

View check run for this annotation

Codecov / codecov/patch

src/aspire/abinitio/commonline_base.py#L92

Added line #L92 was not covered by tests

except ModuleNotFoundError:
logger.info("cupy not found, defaulting to numpy.")

self._build()

def _build(self):
@@ -131,6 +157,24 @@
def build_clmatrix(self):
"""
Build common-lines matrix from Fourier stack of 2D images

Wrapper for CPU/GPU dispatch.
"""

logger.info("Begin building Common Lines Matrix")

# host/gpu dispatch
if self.__gpu_module:
res = self.build_clmatrix_cu()
else:
res = self.build_clmatrix_host()

# Unpack result
self.shifts_1d, self.clmatrix = res

def build_clmatrix_host(self):
"""
Build common-lines matrix from Fourier stack of 2D images
"""

n_img = self.n_img
@@ -233,8 +277,102 @@
pbar.update()
pbar.close()

self.clmatrix = clmatrix
self.shifts_1d = shifts_1d
return shifts_1d, clmatrix
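Editor's note: a small sketch (not part of the PR) of how a returned `clmatrix` is typically read. It assumes the usual convention that index cl in image i maps to the in-plane angle 2π·cl/n_theta; the toy values and `n_theta` below are illustrative only.

```python
# Editor's sketch (not part of the PR): reading a common-line pair out of
# clmatrix.  clmatrix[i, j] is the line index in image i shared with image j;
# -1 (e.g. the diagonal) means no common line.  Values here are toy data.
import numpy as np

n_theta = 360
clmatrix = np.array([[-1, 42], [137, -1]], dtype=np.int16)  # toy 2-image example

i, j = 0, 1
cl_ij, cl_ji = clmatrix[i, j], clmatrix[j, i]
angle_i = 2 * np.pi * cl_ij / n_theta  # direction of the common line in image i
angle_j = 2 * np.pi * cl_ji / n_theta  # direction of the common line in image j
print(angle_i, angle_j)
```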

def build_clmatrix_cu(self):
"""
Build common-lines matrix from Fourier stack of 2D images
"""

import cupy as cp

n_img = self.n_img
r = self.pf.shape[2]

if self.n_theta % 2 == 1:
msg = "n_theta must be even"
logger.error(msg)
raise NotImplementedError(msg)

# Copy to prevent modifying self.pf for other functions
# Simultaneously place on GPU
pf = cp.array(self.pf)

# Allocate local variables for return
# clmatrix represents the common lines matrix.
# Namely, clmatrix[i,j] contains the index in image i of
# the common line with image j. Note the common line index
# starts from 0 instead of 1 as in the Matlab version. -1 means
# there is no common line such as clmatrix[i,i].
clmatrix = -cp.ones((n_img, n_img), dtype=np.int16)

# Allocate variables used for shift estimation
#
# Set maximum value of 1D shift (in pixels) to search
# between common-lines.
# Set resolution of shift estimation in pixels. Note that
# shift_step can be any positive real number.
#
# Prepare the shift phases to try and the filter for common-line detection
#
# Note the CUDA implementation has been optimized to not
# compute or return diagnostic 1d shifts.
_, shift_phases, h = self._generate_shift_phase_and_filter(
r, self.max_shift, self.shift_step
)
# Transfer to device, dtypes must match kernel header.
shift_phases = cp.asarray(shift_phases, dtype=complex_type(self.dtype))

# Apply bandpass filter, normalize each ray of each image
# Note that this only uses half of each ray
pf = self._apply_filter_and_norm("ijk, k -> ijk", pf, r, h)

# Transpose `pf` for a better (CUDA) memory access pattern, and cast as needed.
pf = cp.ascontiguousarray(pf.T, dtype=complex_type(self.dtype))

# Get kernel
if self.dtype == np.float64:
build_clmatrix_kernel = self.__gpu_module.get_function(
"build_clmatrix_kernel"
)
elif self.dtype == np.float32:
build_clmatrix_kernel = self.__gpu_module.get_function(
"fbuild_clmatrix_kernel"
)
else:
raise NotImplementedError(
"build_clmatrix_kernel only implemented for float32 and float64."
)

# Configure grid of blocks
blkszx = 32
# Enough blocks to cover n_img-1
nblkx = (self.n_img + blkszx - 2) // blkszx
blkszy = 32
# Enough blocks to cover n_img
nblky = (self.n_img + blkszy - 1) // blkszy

# Launch
logger.info("Launching `build_clmatrix_kernel`.")
build_clmatrix_kernel(
(nblkx, nblky),
(blkszx, blkszy),
(
n_img,
pf.shape[1],
r,
pf,
clmatrix,
len(shift_phases),
shift_phases,
),
)

# Copy result device arrays to host
clmatrix = clmatrix.get().astype(self.dtype, copy=False)

# Note diagnostic 1d shifts are not computed in the CUDA implementation.
return None, clmatrix
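Editor's note: a quick check (not part of the PR) of the launch geometry configured above. Because the kernel returns early when j <= i, the x dimension only needs to cover i = 0 .. n_img - 2, which is where the `- 2` in `nblkx` comes from. The helper name below is hypothetical.

```python
# Editor's sketch (not part of the PR): sanity-check of the launch geometry
# used above.  The x dimension only needs to cover i = 0 .. n_img - 2 (the
# kernel exits when j <= i), hence the "- 2" in nblkx.
def grid_dims(n_img, blksz=32):
    nblkx = (n_img + blksz - 2) // blksz  # covers n_img - 1 values of i
    nblky = (n_img + blksz - 1) // blksz  # covers n_img values of j
    return nblkx, nblky

for n_img in (1, 31, 32, 33, 100, 2**15 - 1):
    nblkx, nblky = grid_dims(n_img)
    assert nblkx * 32 >= n_img - 1
    assert nblky * 32 >= n_img
```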

def estimate_shifts(self, equations_factor=1, max_memory=4000):
"""
@@ -488,10 +626,10 @@
n_shifts = int(np.ceil(2 * max_shift / shift_step + 1))

# only half of ray, excluding the DC component.
rk = np.arange(1, r_max + 1)
rk = np.arange(1, r_max + 1, dtype=self.dtype)

# Generate all shift phases
shifts = -max_shift + shift_step * np.arange(n_shifts)
shifts = -max_shift + shift_step * np.arange(n_shifts, dtype=self.dtype)
shift_phases = np.exp(np.outer(shifts, -2 * np.pi * 1j * rk / (2 * r_max + 1)))
# Set filter for common-line detection
h = np.sqrt(np.abs(rk)) * np.exp(-np.square(rk) / (2 * (r_max / 4) ** 2))
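Editor's note: restating the two quantities computed above in math form (k = 1, …, r_max indexes the radial samples and s denotes a trial shift value in pixels):

$$\text{shift\_phase}(s, k) = \exp\!\left(\frac{-2\pi i\, s\, k}{2\, r_{\max} + 1}\right),
\qquad
h(k) = \sqrt{k}\;\exp\!\left(\frac{-k^{2}}{2\,(r_{\max}/4)^{2}}\right).$$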
@@ -556,11 +694,35 @@

# Note if we'd rather not have the dtype and casting args,
# we can control h.dtype instead.
np.einsum(subscripts, pf, h, out=pf, dtype=pf.dtype, casting="same_kind")
pf = np.einsum(subscripts, pf, h, dtype=pf.dtype)

# This is a high pass filter, cutting out the lowest frequency
# (DC has already been removed).
pf[..., 0] = 0
pf /= np.linalg.norm(pf, axis=-1)[..., np.newaxis]

return pf
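Editor's note: a toy NumPy sketch (not part of the PR) of the filtering and per-ray normalization performed above; the shapes and values are illustrative only.

```python
# Editor's sketch (not part of the PR): the effect of the "ijk, k -> ijk"
# einsum and the per-ray normalization above, on a toy array.
import numpy as np

pf = np.ones((2, 3, 5), dtype=np.complex128) * (1 + 1j)  # (n_img, n_rays, r)
h = np.linspace(1.0, 2.0, 5)

pf = np.einsum("ijk, k -> ijk", pf, h)  # scale every ray by the filter h
pf[..., 0] = 0                          # cut the lowest remaining frequency
pf /= np.linalg.norm(pf, axis=-1)[..., np.newaxis]

assert np.allclose(np.linalg.norm(pf, axis=-1), 1.0)  # each ray has unit norm
```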

@staticmethod
def __init_cupy_module():
"""
Private utility method to read in CUDA source and return it as a
compiled CuPy module.
"""

import cupy as cp

# Read in contents of file
fp = os.path.join(os.path.dirname(__file__), "commonline_base.cu")
with open(fp, "r") as fh:
module_code = fh.read()

# CuPy compile the CUDA code
# Note these options are meant to steer aggressive optimization
# for single precision code. Fast math will potentially
# reduce accuracy in single precision.
return cp.RawModule(
code=module_code,
backend="nvcc",
options=("-O3", "--use_fast_math", "--extra-device-vectorization"),
)
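Editor's note: for readers unfamiliar with `cupy.RawModule`, the sketch below (not part of the PR; it requires CuPy and, for the nvcc backend, a CUDA toolkit) shows the same compile-once / `get_function` / launch-with-((grid), (block), (args)) pattern that `build_clmatrix_cu` uses, on a trivial kernel of my own.

```python
# Editor's sketch (not part of the PR): the general CuPy raw-kernel pattern
# used by this class -- compile a CUDA module once, fetch a kernel by name,
# and launch it with ((grid), (block), (args)).  The kernel is illustrative.
import cupy as cp

module = cp.RawModule(code=r"""
extern "C" __global__
void scale(const int n, const double a, double* x)
{
    const unsigned int i = blockDim.x * blockIdx.x + threadIdx.x;
    if (i < n) x[i] *= a;
}
""", backend="nvcc", options=("-O3",))

kernel = module.get_function("scale")

n = 1000
x = cp.arange(n, dtype=cp.float64)
blk = 32
nblk = (n + blk - 1) // blk
# Explicit NumPy scalar types keep the host-side args in sync with the
# kernel signature (const int, const double, double*).
kernel((nblk,), (blk,), (cp.int32(n), cp.float64(2.0), x))

assert cp.allclose(x, 2.0 * cp.arange(n))
```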