-
Notifications
You must be signed in to change notification settings - Fork 26
CUDA build common lines #1167
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
CUDA build common lines #1167
Changes from all commits
Commits
Show all changes
45 commits
Select commit
Hold shift + click to select a range
ca2d15d
quick and dirty cupy build_clmatrix
garrettwrong 99e0bf2
stashing
garrettwrong 80d9acd
stashing 2, clmatrix match, good speedup
garrettwrong 0ec15b4
cleanup
garrettwrong 6cb90d1
implement transpose PF for better mem patterns
garrettwrong 9137b6c
use cu complex for CL kernel
garrettwrong 6a4689a
fixed shifts, dtype
garrettwrong 6a0e276
revert some unused optimizations, fix minor casting/dtype issue, begi…
garrettwrong 0930e69
tox check cleanup
garrettwrong 997da45
stashing stub
garrettwrong 425d3ee
stashing stubs
garrettwrong 5b0d7cc
stashing init kernel port
garrettwrong 0b779a1
update kernel with some of the angs work
garrettwrong 6235b1f
start populating eastimate_Rijs kernel call (stash)
garrettwrong 41665b7
breakout angles and add angles map (stash)
garrettwrong be32959
pair_idx doesn't include diag
garrettwrong 62d6d1a
angles matching (Stash), bug in angles to rot func
garrettwrong 558fddc
fixed zyz angles conversion
garrettwrong cd26b1a
remove dbg prints
garrettwrong 9e7a3b4
add adaptive width to kernel
garrettwrong 9d1d6e2
1d rij kernel
garrettwrong 1882d6c
implement nvcc backend and int16_t
garrettwrong 306e35e
split kernels
garrettwrong 18b0548
general cleanup
garrettwrong 2f54330
threads over k
garrettwrong 4b663fd
continue cleanup threads over k
garrettwrong 0e219cc
fix j<i bound bug
garrettwrong 3eaef56
fix adative param oversight bug
garrettwrong 1748a69
parallel case bug
garrettwrong 90ff5b3
C order, sigh
garrettwrong 627d4c5
remove unused vars from build cl kernel
garrettwrong 230fd0f
remove unused vars from build cl kernel
garrettwrong b64164e
continue removing unused vars
garrettwrong 58fdb60
update constants
garrettwrong 7e6415d
add single precision build CL kernel and launching code
garrettwrong a570155
revert accidental config commit
garrettwrong bc4c3b8
cleanup base cuda code a little
garrettwrong c4ca481
self review cleanup
garrettwrong 1631cff
use adaptive width mode for sync3n tests
garrettwrong 706c2bd
add additional sync3n code paths
garrettwrong f2025c8
must use smaller shift step for unit test size problem
garrettwrong 13cc09f
Remove missed debug string change
garrettwrong 78030fe
Remove range(0,...)
garrettwrong 6a99b05
change var name from dist to xcorr
garrettwrong c47b626
Remove kernel timing
garrettwrong File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,151 @@ | ||
| #include <stdint.h> | ||
| #include <math.h> | ||
| #include <cupy/complex.cuh> | ||
|
|
||
/*
 * build_clmatrix_kernel (double precision)
 *
 * For every image pair (i, j) with j > i, exhaustively search all
 * (cl1, cl2, shift) combinations and record the common-line pair with
 * the largest cross-correlation into `clmatrix`.
 *
 * Launch layout: 2D grid/blocks; thread (x, y) handles pair (i, j).
 * Threads on/below the diagonal or out of bounds exit immediately.
 *
 * Parameters:
 *   n            number of images (n_img)
 *   m            angular components, n_theta//2
 *   r            radial components
 *   pf           polar Fourier data; Python shape is (n, m, r) and it is
 *                transposed before launch so this kernel indexes it as
 *                (r, m, n) — the image axis is fastest moving, which makes
 *                the loads below coalesce across adjacent threads in i.
 *   clmatrix     (n, n) int16 output; [i, j] gets best cl1, [j, i] best cl2
 *   n_shifts     number of shift candidates
 *   shift_phases (n_shifts, r) complex phase factors, one row per shift
 */
extern "C" __global__
void build_clmatrix_kernel(
        const int n,
        const int m,
        const int r,
        const complex<double>* __restrict__ pf,
        int16_t* const __restrict__ clmatrix,
        const int n_shifts,
        const complex<double>* const __restrict__ shift_phases)
{
    /* thread index (2d), represents "i" and "j" image indices */
    const unsigned int i = blockDim.x * blockIdx.x + threadIdx.x;
    const unsigned int j = blockDim.y * blockIdx.y + threadIdx.y;

    /* no-op when out of bounds */
    if(i >= n) return;
    if(j >= n) return;
    /* no-op lower triangle and diagonal; only pairs with j > i do work */
    if(j <= i) return;

    /* 64-bit strides: int arithmetic on k*m*n would silently overflow
       once n*m*r exceeds 2^31 elements. */
    const size_t plane = (size_t)m * (size_t)n; /* elements per radial slice */

    int best_cl1 = -1;
    int best_cl2 = -1;
    double best_cl_xcorr = -INFINITY; /* any real xcorr beats this */

    for(int cl1 = 0; cl1 < m; cl1++){
        /* loop-invariant base offset for image i, line cl1 */
        const size_t base_i = (size_t)cl1 * n + i;
        for(int cl2 = 0; cl2 < m; cl2++){
            /* loop-invariant base offset for image j, line cl2 */
            const size_t base_j = (size_t)cl2 * n + j;
            for(int s = 0; s < n_shifts; s++){
                const size_t shift_base = (size_t)s * r;
                double p1 = 0;
                double p2 = 0;
                /* inner most dim of dot (matmul) over radial samples */
                for(int k = 0; k < r; k++){
                    const complex<double> pfik = pf[(size_t)k * plane + base_i];
                    const complex<double> pfjk =
                        conj(pf[(size_t)k * plane + base_j])
                        * shift_phases[shift_base + k];
                    p1 += real(pfik) * real(pfjk);
                    p2 += imag(pfik) * imag(pfjk);
                } /* k */

                /* p1 - p2 == real(sum pfik * pfjk): direct candidate */
                double xcorr = p1 - p2;
                if(xcorr > best_cl_xcorr){
                    best_cl_xcorr = xcorr;
                    best_cl1 = cl1;
                    best_cl2 = cl2;
                }

                /* p1 + p2 == real(sum pfik * conj(pfjk)): the mirrored
                   line; encode it by offsetting cl2 by m (n_theta//2). */
                xcorr = p1 + p2;
                if(xcorr > best_cl_xcorr){
                    best_cl_xcorr = xcorr;
                    best_cl1 = cl1;
                    best_cl2 = cl2 + m; /* m is pf.shape[1], n_theta//2 */
                }
            } /* s */
        } /* cl2 */
    } /* cl1 */

    /* update global best for pair (i, j); each thread owns a unique
       (i, j)/(j, i) cell pair, so no atomics are needed */
    clmatrix[(size_t)i * n + j] = best_cl1;
    clmatrix[(size_t)j * n + i] = best_cl2; /* [j, i] */
} /* build_clmatrix_kernel */
|
|
||
/*
 * fbuild_clmatrix_kernel (single precision)
 *
 * Single-precision twin of build_clmatrix_kernel: for every image pair
 * (i, j) with j > i, exhaustively search all (cl1, cl2, shift)
 * combinations and record the common-line pair with the largest
 * cross-correlation into `clmatrix`.
 *
 * Launch layout: 2D grid/blocks; thread (x, y) handles pair (i, j).
 * Threads on/below the diagonal or out of bounds exit immediately.
 *
 * Parameters:
 *   n            number of images (n_img)
 *   m            angular components, n_theta//2
 *   r            radial components
 *   pf           polar Fourier data; Python shape is (n, m, r) and it is
 *                transposed before launch so this kernel indexes it as
 *                (r, m, n) — the image axis is fastest moving, which makes
 *                the loads below coalesce across adjacent threads in i.
 *   clmatrix     (n, n) int16 output; [i, j] gets best cl1, [j, i] best cl2
 *   n_shifts     number of shift candidates
 *   shift_phases (n_shifts, r) complex phase factors, one row per shift
 */
extern "C" __global__
void fbuild_clmatrix_kernel(
        const int n,
        const int m,
        const int r,
        const complex<float>* __restrict__ pf,
        int16_t* const __restrict__ clmatrix,
        const int n_shifts,
        const complex<float>* const __restrict__ shift_phases)
{
    /* thread index (2d), represents "i" and "j" image indices */
    const unsigned int i = blockDim.x * blockIdx.x + threadIdx.x;
    const unsigned int j = blockDim.y * blockIdx.y + threadIdx.y;

    /* no-op when out of bounds */
    if(i >= n) return;
    if(j >= n) return;
    /* no-op lower triangle and diagonal; only pairs with j > i do work */
    if(j <= i) return;

    /* 64-bit strides: int arithmetic on k*m*n would silently overflow
       once n*m*r exceeds 2^31 elements. */
    const size_t plane = (size_t)m * (size_t)n; /* elements per radial slice */

    int best_cl1 = -1;
    int best_cl2 = -1;
    float best_cl_xcorr = -INFINITY; /* converts to -inf float; any real xcorr beats it */

    for(int cl1 = 0; cl1 < m; cl1++){
        /* loop-invariant base offset for image i, line cl1 */
        const size_t base_i = (size_t)cl1 * n + i;
        for(int cl2 = 0; cl2 < m; cl2++){
            /* loop-invariant base offset for image j, line cl2 */
            const size_t base_j = (size_t)cl2 * n + j;
            for(int s = 0; s < n_shifts; s++){
                const size_t shift_base = (size_t)s * r;
                float p1 = 0.0f;
                float p2 = 0.0f;
                /* inner most dim of dot (matmul) over radial samples */
                for(int k = 0; k < r; k++){
                    const complex<float> pfik = pf[(size_t)k * plane + base_i];
                    const complex<float> pfjk =
                        conj(pf[(size_t)k * plane + base_j])
                        * shift_phases[shift_base + k];
                    p1 += real(pfik) * real(pfjk);
                    p2 += imag(pfik) * imag(pfjk);
                } /* k */

                /* p1 - p2 == real(sum pfik * pfjk): direct candidate */
                float xcorr = p1 - p2;
                if(xcorr > best_cl_xcorr){
                    best_cl_xcorr = xcorr;
                    best_cl1 = cl1;
                    best_cl2 = cl2;
                }

                /* p1 + p2 == real(sum pfik * conj(pfjk)): the mirrored
                   line; encode it by offsetting cl2 by m (n_theta//2). */
                xcorr = p1 + p2;
                if(xcorr > best_cl_xcorr){
                    best_cl_xcorr = xcorr;
                    best_cl1 = cl1;
                    best_cl2 = cl2 + m; /* m is pf.shape[1], n_theta//2 */
                }
            } /* s */
        } /* cl2 */
    } /* cl1 */

    /* update global best for pair (i, j); each thread owns a unique
       (i, j)/(j, i) cell pair, so no atomics are needed */
    clmatrix[(size_t)i * n + j] = best_cl1;
    clmatrix[(size_t)j * n + i] = best_cl2; /* [j, i] */
} /* fbuild_clmatrix_kernel */
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.