Skip to content

Commit

Permalink
cuda : fix RoPE after ggerganov#2268 (ggerganov#3897)
Browse files Browse the repository at this point in the history
  • Loading branch information
cebtenzzre authored and olexiyb committed Nov 23, 2023
1 parent 5cf4114 commit cdd4a93
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4539,7 +4539,7 @@ static __global__ void rope(
const int i2 = row/p_delta_rows;

const int p = has_pos ? pos[i2] : 0;
const float theta_base = p*powf(freq_base, -col/ncols);
const float theta_base = p*powf(freq_base, -float(col)/ncols);

float cos_theta, sin_theta;
rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
Expand All @@ -4566,8 +4566,8 @@ static __global__ void rope_neox(
const int i = row*ncols + col/2;
const int i2 = row/p_delta_rows;

// simplified from `(row * ncols + col) * (-1 / ncols)`
const float cur_rot = -col/ncols - row;
// simplified from `(ib * ncols + col) * (-1 / ncols)`, where ib is assumed to be zero
const float cur_rot = -float(col)/ncols;

const int p = has_pos ? pos[i2] : 0;
const float theta_base = p*powf(freq_base, cur_rot);
Expand Down

0 comments on commit cdd4a93

Please sign in to comment.