Skip to content

Commit

Permalink
FIX: make dgemm fallback kernel work for all beta
Browse files Browse the repository at this point in the history
  • Loading branch information
SuperFluffy committed Dec 4, 2018
1 parent aedb824 commit 6d48971
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions src/dgemm_kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -665,13 +665,12 @@ unsafe fn kernel_x86_avx(k: usize, alpha: T, a: *const T, b: *const T,
}

#[inline(always)]
pub unsafe fn kernel_fallback_impl(k: usize, alpha: T, a: *const T, b: *const T,
unsafe fn kernel_fallback_impl(k: usize, alpha: T, a: *const T, b: *const T,
beta: T, c: *mut T, rsc: isize, csc: isize)
{
let mut ab: [[T; NR]; MR] = [[0.; NR]; MR];
let mut a = a;
let mut b = b;
debug_assert_eq!(beta, 0.); // always masked

// Compute matrix multiplication into ab[i][j]
unroll_by!(4 => k, {
Expand All @@ -685,8 +684,12 @@ pub unsafe fn kernel_fallback_impl(k: usize, alpha: T, a: *const T, b: *const T,
($i:expr, $j:expr) => (c.offset(rsc * $i as isize + csc * $j as isize));
}

// set C = α A B
loop_m!(i, loop_n!(j, *c![i, j] = alpha * ab[i][j]));
// set C = α A B + β C
if beta == 0. {
loop_n!(j, loop_m!(i, *c![i, j] = alpha * ab[i][j]));
} else {
loop_n!(j, loop_m!(i, *c![i, j] = *c![i, j] * beta + alpha * ab[i][j]));
}
}

#[inline(always)]
Expand Down

0 comments on commit 6d48971

Please sign in to comment.