# DAY 20: Rotary Position Embedding (RoPE)

Implementation of Rotary Position Embedding used in modern transformer architectures like LLaMA and GPT-NeoX.



In [None]:
%%writefile rope.cu
// nvcc rope.cu -o rope -lm

#include <cuda_runtime.h>
#include <math.h>
#include <stdio.h>

/**
 * Rotates one query head vector and one key head vector in place using
 * rotary position embedding.
 *
 * Adjacent elements (i, i + 1) are treated as a (real, imaginary) pair and
 * rotated by the angle theta = position / base^(i / head_dim).
 *
 * Callable from both host and device so the same code can serve as a CPU
 * reference when validating kernel output.
 *
 * @param q         query vector with at least head_dim elements; modified in place
 * @param k         key vector with at least head_dim elements; modified in place
 * @param head_dim  number of elements in each vector; if odd, the final
 *                  unpaired element is left unchanged
 * @param position  position index used to scale the rotation angle
 * @param base      base of the geometric frequency progression
 */
__host__ __device__ void apply_rotary_embedding(
    float* q,           // query vectors
    float* k,           // key vectors
    const int head_dim, // dimension of each head
    const int position, // absolute position in sequence
    const float base = 10000.0f
) {
    // Process pairs of elements (real, imaginary). The bound is on i + 1 so
    // that an odd head_dim never touches memory past the end of the vector.
    for (int i = 0; i + 1 < head_dim; i += 2) {
        float freq = 1.0f / powf(base, (float)(i) / head_dim);
        float theta = position * freq;

        // Calculate rotation matrix elements
        float cos_theta = cosf(theta);
        float sin_theta = sinf(theta);

        // Cache original values: both outputs of a pair depend on both inputs
        float q_real = q[i];
        float q_img = q[i + 1];
        float k_real = k[i];
        float k_img = k[i + 1];

        // Apply rotation to query
        q[i] = q_real * cos_theta - q_img * sin_theta;
        q[i + 1] = q_real * sin_theta + q_img * cos_theta;

        // Apply rotation to key
        k[i] = k_real * cos_theta - k_img * sin_theta;
        k[i + 1] = k_real * sin_theta + k_img * cos_theta;
    }
}

/**
 * Applies rotary position embedding in place to every (batch, position, head)
 * vector of the query and key tensors.
 *
 * Launch layout: 1D grid of 1D blocks. Each thread handles one head vector of
 * head_dim elements per loop iteration; the grid-stride loop makes the kernel
 * correct for any grid size, including grids smaller than the problem.
 * No shared memory is used.
 *
 * Both tensors are indexed as contiguous [batch_size, seq_len, num_heads,
 * head_dim] arrays, and the sequence index is passed as the rotation position.
 *
 * @param queries     query tensor, modified in place
 * @param keys        key tensor, modified in place
 * @param batch_size  extent of the batch dimension
 * @param seq_len     extent of the sequence dimension
 * @param num_heads   extent of the head dimension
 * @param head_dim    number of elements per head vector
 */
__global__ void rope_kernel(
    float* queries,        // [batch_size, seq_len, num_heads, head_dim]
    float* keys,          // [batch_size, seq_len, num_heads, head_dim]
    const int batch_size,
    const int seq_len,
    const int num_heads,
    const int head_dim
) {
    // Degenerate shapes have no work and would otherwise divide by zero below.
    if (batch_size <= 0 || seq_len <= 0 || num_heads <= 0 || head_dim <= 0) return;

    // 64-bit arithmetic: element offsets can exceed the range of int for
    // large tensors even when every individual dimension fits in int.
    const size_t total_heads = (size_t)batch_size * (size_t)seq_len * (size_t)num_heads;
    const size_t stride = (size_t)gridDim.x * blockDim.x;

    for (size_t idx = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
         idx < total_heads;
         idx += stride) {
        // idx enumerates (batch, seq, head) with head varying fastest, so the
        // sequence position is the middle component of the decomposition.
        const int seq_idx = (int)((idx / (size_t)num_heads) % (size_t)seq_len);

        // Because the layout is contiguous, the head vector for flat index
        // idx starts at idx * head_dim; this equals
        // batch * (seq_len * num_heads * head_dim) + seq * (num_heads * head_dim)
        // + head * head_dim.
        const size_t base_idx = idx * (size_t)head_dim;

        // Apply rotary embedding to this position
        apply_rotary_embedding(
            &queries[base_idx],
            &keys[base_idx],
            head_dim,
            seq_idx
        );
    }
}

// Helper function to launch the kernel
/**
 * Host helper that launches rope_kernel on the default stream with one thread
 * per (batch, position, head) vector.
 *
 * The launch is asynchronous: this function does not synchronize, so callers
 * must synchronize (or issue a blocking copy) before reading results on the
 * host. Launch failures are reported on stderr; the error is only peeked, not
 * cleared, so callers can still observe it through cudaGetLastError().
 *
 * @param d_queries   device pointer to the query tensor, modified in place
 * @param d_keys      device pointer to the key tensor, modified in place
 * @param batch_size  extent of the batch dimension
 * @param seq_len     extent of the sequence dimension
 * @param num_heads   extent of the head dimension
 * @param head_dim    number of elements per head vector
 */
void apply_rope(
    float* d_queries,
    float* d_keys,
    const int batch_size,
    const int seq_len,
    const int num_heads,
    const int head_dim
) {
    // Nothing to do for empty tensors; a zero-block grid is an invalid launch.
    if (batch_size <= 0 || seq_len <= 0 || num_heads <= 0 || head_dim <= 0) return;

    // Count work items in 64 bits so the product cannot overflow int.
    const size_t total_heads = (size_t)batch_size * (size_t)seq_len * (size_t)num_heads;

    const dim3 block_size(256);
    const size_t num_blocks = (total_heads + block_size.x - 1) / block_size.x;

    // gridDim.x is limited to 2^31 - 1 blocks.
    if (num_blocks > 0x7FFFFFFFu) {
        fprintf(stderr, "apply_rope: problem too large for a 1D grid (%zu blocks)\n", num_blocks);
        return;
    }
    const dim3 grid_size((unsigned int)num_blocks);

    rope_kernel<<<grid_size, block_size>>>(
        d_queries,
        d_keys,
        batch_size,
        seq_len,
        num_heads,
        head_dim
    );

    // Kernel launches do not return a status; configuration errors surface here.
    const cudaError_t launch_status = cudaPeekAtLastError();
    if (launch_status != cudaSuccess) {
        fprintf(stderr, "apply_rope: CUDA error after kernel launch: %s\n",
                cudaGetErrorString(launch_status));
    }
}

#include <stdio.h>
#include <stdlib.h>

// Aborts the program with a diagnostic if a CUDA runtime call failed.
static void check_cuda(cudaError_t status, const char* what) {
    if (status != cudaSuccess) {
        fprintf(stderr, "CUDA error during %s: %s\n", what, cudaGetErrorString(status));
        exit(EXIT_FAILURE);
    }
}

/**
 * Demo driver: fills small query/key tensors with a repeating pattern, applies
 * RoPE on the GPU, and prints the leading elements of each result.
 *
 * Returns 0 on success; exits with EXIT_FAILURE on any allocation or CUDA error.
 */
int main() {
    // Configuration
    const int batch_size = 2;
    const int seq_len = 4;
    const int num_heads = 2;
    const int head_dim = 4;

    const int total_size = batch_size * seq_len * num_heads * head_dim;
    const size_t total_bytes = (size_t)total_size * sizeof(float);

    // Allocate host memory
    float *h_queries = (float*)malloc(total_bytes);
    float *h_keys = (float*)malloc(total_bytes);
    if (h_queries == NULL || h_keys == NULL) {
        fprintf(stderr, "Host allocation of %zu bytes failed\n", total_bytes);
        free(h_queries);
        free(h_keys);
        return EXIT_FAILURE;
    }

    // Initialize with simple values
    for (int i = 0; i < total_size; i++) {
        h_queries[i] = (float)(i % 8) * 0.1f;
        h_keys[i] = (float)(i % 6) * 0.2f;
    }

    // Allocate device memory
    float *d_queries = NULL, *d_keys = NULL;
    check_cuda(cudaMalloc(&d_queries, total_bytes), "cudaMalloc(d_queries)");
    check_cuda(cudaMalloc(&d_keys, total_bytes), "cudaMalloc(d_keys)");

    // Copy to device
    check_cuda(cudaMemcpy(d_queries, h_queries, total_bytes, cudaMemcpyHostToDevice),
               "copy of queries to device");
    check_cuda(cudaMemcpy(d_keys, h_keys, total_bytes, cudaMemcpyHostToDevice),
               "copy of keys to device");

    // Apply RoPE
    apply_rope(d_queries, d_keys, batch_size, seq_len, num_heads, head_dim);
    // Launch-configuration errors show up in the last-error state; faults
    // inside the kernel only show up once the device has been synchronized.
    check_cuda(cudaGetLastError(), "RoPE kernel launch");
    check_cuda(cudaDeviceSynchronize(), "RoPE kernel execution");

    // Copy back to host
    check_cuda(cudaMemcpy(h_queries, d_queries, total_bytes, cudaMemcpyDeviceToHost),
               "copy of queries to host");
    check_cuda(cudaMemcpy(h_keys, d_keys, total_bytes, cudaMemcpyDeviceToHost),
               "copy of keys to host");

    // Print sample results (never read past the end of a smaller tensor)
    const int sample_count = total_size < 8 ? total_size : 8;

    printf("Sample Query after RoPE (first 8 elements): ");
    for (int i = 0; i < sample_count; i++) {
        printf("%.3f ", h_queries[i]);
    }
    printf("\n");

    printf("Sample Key after RoPE (first 8 elements): ");
    for (int i = 0; i < sample_count; i++) {
        printf("%.3f ", h_keys[i]);
    }
    printf("\n");

    // Cleanup
    free(h_queries);
    free(h_keys);
    check_cuda(cudaFree(d_queries), "cudaFree(d_queries)");
    check_cuda(cudaFree(d_keys), "cudaFree(d_keys)");

    return 0;
}

In [None]:
# Compile and run the RoPE implementation
!nvcc rope.cu -o rope -lm
!./rope

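## Output:
```
Sample Query after RoPE (first 8 elements): 0.000 0.100 0.200 0.300 0.400 0.500 0.600 0.700 
Sample Key after RoPE (first 8 elements): 0.000 0.200 0.400 0.600 0.800 1.000 0.000 0.200 
```

The first 8 elements belong to sequence position 0 (heads 0 and 1 of batch 0), where the rotation angle is zero, so they are identical to the initial values.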