<a href="https://colab.research.google.com/github/AndreSlavescu/llmdotc_testbench_tools/blob/main/llmc_testbench.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%writefile common.h
#include <stdlib.h>
#include <stdio.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <cublasLt.h>
#include <float.h>

// Integer ceiling division: the smallest quotient q with q * divisor >= dividend
// (for non-negative operands). Callable from both host and device code.
template<class T>
__host__ __device__ T ceil_div(T dividend, T divisor) {
    auto rounded_up = dividend + divisor - 1;
    return rounded_up / divisor;
}

// Butterfly (XOR-shuffle) sum over the 32 lanes of a warp: after the loop every
// lane holds the warp-wide total. Uses the full-warp mask 0xFFFFFFFF, so all 32
// lanes must reach this call.
__device__ float warpReduceSum(float val) {
    float acc = val;
    for (int lane_mask = 16; lane_mask >= 1; lane_mask >>= 1) {
        float partner = __shfl_xor_sync(0xFFFFFFFF, acc, lane_mask);
        acc += partner;
    }
    return acc;
}


// CUDA error checking: print a diagnostic naming the call site and terminate
// the process with EXIT_FAILURE when `error` is anything but cudaSuccess.
void cudaCheck(cudaError_t error, const char *file, int line) {
  if (error == cudaSuccess) {
    return;
  }
  const char *message = cudaGetErrorString(error);
  printf("[CUDA ERROR] at file %s:%d:\n%s\n", file, line, message);
  exit(EXIT_FAILURE);
}
// Shadow the function with a macro so each call records its own __FILE__/__LINE__.
#define cudaCheck(err) (cudaCheck(err, __FILE__, __LINE__))

// Allocate a host array of N floats filled with pseudo-random values in [-1, 1]
// drawn from rand() (call srand() first for a reproducible sequence).
// The caller owns the returned buffer and must free() it.
// Terminates the program with EXIT_FAILURE if the allocation fails.
float* make_random_float(size_t N) {
    float* arr = (float*)malloc(N * sizeof(float));
    // malloc(0) may legitimately return NULL, so only treat NULL as failure for N > 0.
    if (arr == NULL && N > 0) {
        fprintf(stderr, "make_random_float: failed to allocate %zu floats\n", N);
        exit(EXIT_FAILURE);
    }
    for (size_t i = 0; i < N; i++) {
        // 2x is exact and the subtraction rounds once, so float literals give the
        // same values as the double-precision form while avoiding double math.
        arr[i] = ((float)rand() / RAND_MAX) * 2.0f - 1.0f; // range -1..1
    }
    return arr;
}

// Copy `num_elements` values of type D back from device memory and compare them
// element-wise with the host reference `cpu_reference`.
// - Reference entries that are not finite (masked elements) are skipped.
// - Elements with index < 5 are printed as "<cpu> <gpu>" for a quick visual check.
// - Allowed error per element: tolerance + |ref| * epsilon, where epsilon is
//   FLT_EPSILON normally and 0.079 when ENABLE_BF16 is defined.
// - Each mismatch is printed; the program exits with EXIT_FAILURE once 10
//   mismatches are seen, or after the scan if any mismatch occurred.
template<class D, class T>
void validate_result(D* device_result, const T* cpu_reference, const char* name, std::size_t num_elements, T tolerance=1e-4) {
    D* out_gpu = (D*)malloc(num_elements * sizeof(D));
    if (out_gpu == NULL && num_elements > 0) {
        printf("validate_result: failed to allocate host buffer for %s\n", name);
        exit(EXIT_FAILURE);
    }
    cudaCheck(cudaMemcpy(out_gpu, device_result, num_elements * sizeof(D), cudaMemcpyDeviceToHost));
    int nfaults = 0;
#ifndef ENABLE_BF16
    float epsilon = FLT_EPSILON;
#else
    float epsilon = 0.079;
#endif
    // size_t index: num_elements is a size_t and may exceed INT_MAX, where an
    // int counter would overflow.
    for (std::size_t i = 0; i < num_elements; i++) {
        // Skip masked elements
        if (!isfinite(cpu_reference[i]))
            continue;

        const T gpu_value = (T)out_gpu[i];
        // print the first few comparisons
        if (i < 5) {
            printf("%f %f\n", cpu_reference[i], gpu_value);
        }
        // effective tolerance is based on expected rounding error (epsilon),
        // plus any specified additional tolerance
        float t_eff = tolerance + fabs(cpu_reference[i]) * epsilon;
        if (fabs(cpu_reference[i] - gpu_value) > t_eff) {
            printf("Mismatch of %s at %zu: CPU_ref: %f vs GPU: %f\n", name, i, cpu_reference[i], gpu_value);
            nfaults++;
            if (nfaults >= 10) {
                free(out_gpu);
                exit(EXIT_FAILURE);
            }
        }
    }

    free(out_gpu);
    if (nfaults > 0) {
        exit(EXIT_FAILURE);
    }
}

// Time `kernel(kernel_args...)` on the default stream and return the mean
// elapsed time per call in milliseconds (cudaEventElapsedTime units).
// Before every call the L2 cache is scrubbed by memsetting a buffer of
// deviceProp.l2CacheSize bytes, as recommended by
// https://stackoverflow.com/questions/31429377/how-can-i-clear-flush-the-l2-cache-and-the-tlb-of-a-gpu
// (and apparently used in nvbench).
// Side effect: selects device 0 via cudaSetDevice. Returns 0 when repeats <= 0.
template<class Kernel, class... KernelArgs>
float benchmark_kernel(int repeats, Kernel kernel, KernelArgs&&... kernel_args) {
    int deviceIdx = 0;
    cudaCheck(cudaSetDevice(deviceIdx));
    cudaDeviceProp deviceProp;
    cudaCheck(cudaGetDeviceProperties(&deviceProp, deviceIdx));
    void* flush_buffer;
    cudaCheck(cudaMalloc(&flush_buffer, deviceProp.l2CacheSize));

    cudaEvent_t start, stop;
    cudaCheck(cudaEventCreate(&start));
    cudaCheck(cudaEventCreate(&stop));
    float elapsed_time = 0.f;
    for (int i = 0; i < repeats; i++) {
        // clear L2
        cudaCheck(cudaMemset(flush_buffer, 0, deviceProp.l2CacheSize));
        cudaCheck(cudaEventRecord(start, nullptr));
        // Arguments are passed as lvalues on purpose: std::forward inside a loop
        // would move from rvalue arguments on the first iteration and hand
        // moved-from values to every later call.
        kernel(kernel_args...);
        cudaCheck(cudaEventRecord(stop, nullptr));
        // Both events are on the same stream, so `stop` completing implies `start` has too.
        cudaCheck(cudaEventSynchronize(stop));
        float single_call;
        cudaCheck(cudaEventElapsedTime(&single_call, start, stop));
        elapsed_time += single_call;
    }

    // Release every resource created above (the events previously leaked per call).
    cudaCheck(cudaEventDestroy(start));
    cudaCheck(cudaEventDestroy(stop));
    cudaCheck(cudaFree(flush_buffer));

    return repeats > 0 ? elapsed_time / repeats : 0.0f;
}


Overwriting common.h


In [None]:
%%writefile rmsnorm_forward.cu
#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime.h>
#include <assert.h>
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include "common.h"

// Root Mean Square Layernorm Forward Pass (CPU reference).
// inp/out are indexed as (B, T, C) row-major; weight/bias have C entries and
// rms has B*T entries. For every row (b, t):
//   rms_val = rsqrtf(mean(x^2) + 1e-6)
//   out[b,t,i] = x[i] * rms_val * weight[i] + bias[i]
//   rms[b*T + t] = rms_val (cached for the backward pass)
void rmsnorm_forward_cpu(
    float *out,
    float *rms,
    const float *inp,
    const float *weight,
    const float *bias,
    int B,
    int T,
    int C
) {
    const float eps = 1e-6f;

    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            // size_t offsets: b * T * C in int overflows once B*T*C exceeds INT_MAX.
            const size_t row = (size_t)b * T + t;
            // seek to the input position inp[b,t,:]
            const float* x = inp + row * C;
            // compute RMS
            float sum_of_squares = 0.0f;
            for (int i = 0; i < C; i++) {
                sum_of_squares += x[i] * x[i];
            }
            float rms_val = rsqrtf(sum_of_squares / C + eps);
            // seek to the output position in out[b,t,:]
            float* out_bt = out + row * C;
            for (int i = 0; i < C; i++) {
                float n = x[i] * rms_val; // normalized output
                float o = n * weight[i] + bias[i]; // scale and shift it
                out_bt[i] = o; // write
            }
            // cache the rms for the backward pass later
            rms[row] = rms_val;
        }
    }
}

// ----------------------------------------------------------------------------
// GPU kernels

// Naive RMSNorm forward: one thread per row of N = B*T rows, each C channels long.
// Thread `idx` computes rms_val = rsqrtf(mean(x^2) + 1e-6) over its row, writes
// out[idx, i] = x[i] * rms_val * weight[i] + bias[i], and stores rms[idx] = rms_val.
// Any 1-D grid covering N threads is valid; surplus threads return immediately.
__global__ void rmsnorm_forward_kernel1(
    float* out,
    float* rms,
    const float* inp,
    const float* weight,
    const float* bias,
    int N,
    int C
) {
    const float eps = 1e-6f;
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= N) {
        return;
    }

    // size_t row offset: idx * C in int overflows once N*C exceeds INT_MAX.
    const size_t row_offset = (size_t)idx * C;
    // Seek to the input position inp[idx,:]
    const float* x = inp + row_offset;

    // Sum of squares over the row (C is a runtime value, so no #pragma unroll).
    float sum_of_squares = 0.0f;
    for (int i = 0; i < C; i++) {
        sum_of_squares += x[i] * x[i];
    }

    // Compute the inverse RMS scale
    float rms_val = rsqrtf(sum_of_squares / C + eps);

    // Seek to the output position in out[idx,:]
    float* out_idx = out + row_offset;
    for (int i = 0; i < C; i++) {
        float n = x[i] * rms_val; // Normalized output
        out_idx[i] = n * weight[i] + bias[i]; // Scale and shift it
    }

    // Cache the RMS for the backward pass later
    rms[idx] = rms_val;
}

// First stage of rmsnorm_forward2: one block per row (grid = N = B*T blocks).
// Block `blockIdx.x` writes rms[row] = rsqrtf(mean(x^2) + 1e-6) for its row of C values.
// Launch requirements: block_size == blockDim.x, and block_size * sizeof(float)
// bytes of dynamic shared memory.
// Each thread sums x^2 over a block-strided slice of the row. A shared-memory
// tree reduction then combines the partials. The tree starts at the smallest
// power of two >= block_size and skips partner slots past the end, so block
// sizes that are not powers of two are still summed completely.
__global__ void rms_val_kernel(
    float* rms,
    const float* inp,
    int N,
    int C,
    int block_size
) {
    extern __shared__ float shared[];
    int idx = blockIdx.x; // row index, range [0, B*T)
    int tid = threadIdx.x; // range [0, block_size)
    // Block-uniform early exit: the whole block leaves together, so no thread is
    // left waiting at a __syncthreads() below.
    if (idx >= N) {
        return;
    }
    // size_t row offset: idx * C in int overflows once N*C exceeds INT_MAX.
    const float *x = inp + (size_t)idx * C;

    const float eps = 1e-6f;
    float sum_of_squares = 0.0f;
    for (int i = tid; i < C; i += block_size) {
        sum_of_squares += x[i] * x[i];
    }
    shared[tid] = sum_of_squares;

    int span = 1;
    while (span < block_size) {
        span <<= 1;
    }
    // The barrier at the top of each iteration also orders the shared[tid]
    // writes above before the first reads.
    for (int stride = span >> 1; stride > 0; stride >>= 1) {
        __syncthreads();
        if (tid < stride && tid + stride < block_size) {
            shared[tid] += shared[tid + stride];
        }
    }

    if (tid == 0) {
        // thread 0 performed the last addition itself, so shared[0] is final here
        rms[idx] = rsqrtf(shared[0] / C + eps);
    }
}

// Second stage of rmsnorm_forward2: one thread per element of the (B, T, C) tensor.
// Reads the per-row scale rms[b*T + t] (written by rms_val_kernel) and stores
// out = inp * rms * weight[c] + bias[c].
// The launcher rounds the grid up with ceil_div, so threads past B*T*C must do
// nothing; without the guard the grid tail would read and write out of bounds.
__global__ void rmsnorm_forward_kernel2(
    float* out,
    float* rms,
    const float* inp,
    const float* weight,
    const float* bias,
    int B,
    int T,
    int C
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= B * T * C) {
        return;
    }
    int bt = idx / C; // row index b*T + t
    int c = idx % C;  // channel index

    float rms_val = rms[bt];
    float xi = inp[idx];
    float n = xi * rms_val;
    float o = n * weight[c] + bias[c];

    out[idx] = o;
}

// Fused RMSNorm forward: one thread block per row (grid = B*T blocks).
// Requirements: blockDim.x must be a multiple of 32 and at most 1024, so there
// are at most 32 warps, matching the 32-entry shared array.
// Steps:
//   1. Each thread accumulates x^2 over a block-strided slice of the row.
//   2. Each warp reduces its partials with cg::reduce; lane 0 stores shared[warp_id].
//   3. After one block-wide barrier, every warp re-reduces the per-warp partials,
//      so every thread obtains the full row sum.
//   4. rms_val = rsqrtf(mean + eps) is stored (when rms != nullptr) and used to
//      write out = x * rms_val * weight + bias.
// __ldcs/__stcs use the streaming ("accessed once") cache hints for inp/out/rms.
__global__ void rmsnorm_forward_kernel3(
    float* __restrict__ out,
    float* __restrict__ rms,
    const float* __restrict__ inp,
    const float* __restrict__ weight,
    const float* __restrict__ bias,
    int B,
    int T,
    int C
) {
    namespace cg = cooperative_groups;
    constexpr unsigned WARP_SIZE = 32;

    int num_warps = blockDim.x / WARP_SIZE;
    int lane_id = threadIdx.x % WARP_SIZE;
    int warp_id = threadIdx.x / WARP_SIZE;
    int idx = blockIdx.x;
    // Block-uniform early exit: the whole block leaves together, so it cannot
    // strand threads at the barrier below.
    if (idx >= B * T) {
        return;
    }

    __shared__ float shared[WARP_SIZE];
    // size_t row offset: idx * C in int overflows once B*T*C exceeds INT_MAX.
    const size_t row_offset = (size_t)idx * C;
    const float *x = inp + row_offset;

    const float eps = 1e-6f;
    float thread_sum_of_squares = 0.0f;

    for (int i = threadIdx.x; i < C; i += blockDim.x) {
        float xi = x[i];
        thread_sum_of_squares += xi * xi;
    }

    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> warp = cg::tiled_partition<WARP_SIZE>(block);

    float warp_sum_of_squares = cg::reduce(warp, thread_sum_of_squares, cg::plus<float>{}); // sum(x * x)
    if (lane_id == 0) {
        shared[warp_id] = warp_sum_of_squares;
    }
    // The barrier must be reached by every thread of the block. It therefore
    // sits outside the lane-0 branch: __syncthreads() in divergent code is
    // undefined behavior.
    __syncthreads();

    warp_sum_of_squares = (lane_id < num_warps) ? shared[lane_id] : 0.0f;
    float block_sum_of_squares = cg::reduce(warp, warp_sum_of_squares, cg::plus<float>{}); // sum(x * x)

    // compute rms
    float rms_val = rsqrtf(block_sum_of_squares / C + eps);
    if (threadIdx.x == 0 && rms != nullptr) {
        __stcs(rms + idx, rms_val);
    }

    float *o = out + row_offset;

    for (int i = threadIdx.x; i < C; i += blockDim.x) {
        float n = __ldcs(x + i) * rms_val;
        __stcs(o + i, n * weight[i] + bias[i]);
    }
}

// ----------------------------------------------------------------------------
// kernel launcher

void rmsnorm_forward1(
    float* out,
    float* rms,
    const float* inp,
    const float* weight,
    const float* bias,
    int B,
    int T,
    int C,
    const int block_size
) {
    // Launch kernel1 with one thread per (b, t) row, rounding the grid up to whole blocks.
    const int num_rows = B * T;
    const int num_blocks = ceil_div(num_rows, block_size);
    rmsnorm_forward_kernel1<<<num_blocks, block_size>>>(out, rms, inp, weight, bias, num_rows, C);
    cudaCheck(cudaGetLastError());
}

void rmsnorm_forward2(
    float* out,
    float* rms,
    const float* inp,
    const float* weight,
    const float* bias,
    int B,
    int T,
    int C,
    const int block_size
) {
    const int num_rows = B * T;
    const size_t reduction_smem = block_size * sizeof(float);
    // Stage 1: one block per row; its threads cooperate through a shared-memory
    // reduction to produce rms[row].
    rms_val_kernel<<<num_rows, block_size, reduction_smem>>>(rms, inp, num_rows, C, block_size);
    cudaCheck(cudaGetLastError());

    // Stage 2: one thread per element applies the per-row scale, weight and bias.
    const int num_elements = B * T * C;
    const int elementwise_blocks = ceil_div(num_elements, block_size);
    rmsnorm_forward_kernel2<<<elementwise_blocks, block_size>>>(out, rms, inp, weight, bias, B, T, C);
    cudaCheck(cudaGetLastError());
}

void rmsnorm_forward3(
    float* out,
    float* rms,
    const float* inp,
    const float* weight,
    const float* bias,
    int B,
    int T,
    int C,
    const int block_size
) {
    // kernel3 partitions each block into 32-lane tiles, so the block must be whole warps.
    assert(block_size % 32 == 0);
    // One block per (b, t) row.
    const int num_blocks = B * T;
    rmsnorm_forward_kernel3<<<num_blocks, block_size>>>(out, rms, inp, weight, bias, B, T, C);
    cudaCheck(cudaGetLastError());
}

// Kernel version dispatch: route to launcher 1, 2 or 3 by `kernel_num`;
// any other value prints an error and terminates with exit code 1.
void rmsnorm_forward(
    int kernel_num,
    float* out,
    float* rms,
    const float* inp,
    const float* weight,
    const float* bias,
    int B,
    int T,
    int C,
    const int block_size
) {
    if (kernel_num == 1) {
        rmsnorm_forward1(out, rms, inp, weight, bias, B, T, C, block_size);
    } else if (kernel_num == 2) {
        rmsnorm_forward2(out, rms, inp, weight, bias, B, T, C, block_size);
    } else if (kernel_num == 3) {
        rmsnorm_forward3(out, rms, inp, weight, bias, B, T, C, block_size);
    } else {
        printf("Invalid kernel number\n");
        exit(1);
    }
}

// ----------------------------------------------------------------------------

// Test-and-benchmark driver for the RMSNorm forward kernels.
// 1. Builds random (B, T, C) = (8, 1024, 768) inputs with srand(0).
// 2. Validates the kernel selected by argv[1] (default 3) against
//    rmsnorm_forward_cpu at every block size; validate_result exits on mismatch.
// 3. Times each block size with benchmark_kernel and prints an estimated
//    memory bandwidth.
int main(int argc, char **argv) {
    srand(0);

    int B = 8;
    int T = 1024;
    int C = 768;

    int deviceIdx = 0;
    cudaCheck(cudaSetDevice(deviceIdx));

    // Element counts as size_t so the byte sizes below cannot overflow int.
    const size_t n_elems = (size_t)B * T * C;
    const size_t n_rows = (size_t)B * T;

    // create host memory of random numbers
    float* out = (float*)malloc(n_elems * sizeof(float));
    float* rms = (float*)malloc(n_rows * sizeof(float));
    float* inp = make_random_float(n_elems);
    float* weight = make_random_float(C);
    float* bias = make_random_float(C);
    if (out == NULL || rms == NULL) {
        fprintf(stderr, "host allocation failed\n");
        return EXIT_FAILURE;
    }

    // move to GPU
    float* d_out;
    float* d_rms;
    float* d_inp;
    float* d_weight;
    float* d_bias;
    cudaCheck(cudaMalloc(&d_out, n_elems * sizeof(float)));
    cudaCheck(cudaMalloc(&d_rms, n_rows * sizeof(float)));
    cudaCheck(cudaMalloc(&d_inp, n_elems * sizeof(float)));
    cudaCheck(cudaMalloc(&d_weight, C * sizeof(float)));
    cudaCheck(cudaMalloc(&d_bias, C * sizeof(float)));
    cudaCheck(cudaMemcpy(d_inp, inp, n_elems * sizeof(float), cudaMemcpyHostToDevice));
    cudaCheck(cudaMemcpy(d_weight, weight, C * sizeof(float), cudaMemcpyHostToDevice));
    cudaCheck(cudaMemcpy(d_bias, bias, C * sizeof(float), cudaMemcpyHostToDevice));

    // read kernel_num from command line
    int kernel_num = 3;
    if (argc > 1) {
        kernel_num = atoi(argv[1]);
    }
    printf("Using kernel %d\n", kernel_num);

    int block_sizes[] = {32, 64, 128, 256, 512, 1024};
    const size_t num_block_sizes = sizeof(block_sizes) / sizeof(block_sizes[0]);

    rmsnorm_forward_cpu(out, rms, inp, weight, bias, B, T, C);

    // check the correctness of the kernel at all block sizes
    for (size_t j = 0; j < num_block_sizes; j++) {
        int block_size = block_sizes[j];
        printf("Checking block size %d.\n", block_size);

        rmsnorm_forward(kernel_num, d_out, d_rms, d_inp, d_weight, d_bias, B, T, C, block_size);

        validate_result(d_out, out, "out", n_elems, 1e-5f);
        validate_result(d_rms, rms, "rms", n_rows, 1e-5f);
    }

    printf("All results match. Starting benchmarks.\n\n");

    // time the kernel at different block sizes
    for (size_t j = 0; j < num_block_sizes; j++) {
        int block_size = block_sizes[j];

        int repeat_times = 2000;
        float elapsed_time = benchmark_kernel(
                                repeat_times,
                                rmsnorm_forward,
                                kernel_num,
                                d_out,
                                d_rms,
                                d_inp,
                                d_weight,
                                d_bias,
                                B,
                                T,
                                C,
                                block_size
                            );

        // napkin math: estimate the memory bandwidth achieved, counting one read
        // of inp and one write of out (weight/bias/rms traffic is ignored).
        // e.g. A100 40GB PCIe is advertised at 1,555GB/s
        // Computed in long so the product cannot overflow int.
        const long memory_ops = 2L * B * T * C * 4; // *4 for float
        float memory_bandwidth = memory_ops / elapsed_time / 1e6;

        printf("block_size %4d | time %.4f ms | bandwidth %.2f GB/s\n", block_size, elapsed_time, memory_bandwidth);
    }

    // free memory
    free(out);
    free(rms);
    free(inp);
    free(weight);
    free(bias);
    cudaCheck(cudaFree(d_out));
    cudaCheck(cudaFree(d_rms));
    cudaCheck(cudaFree(d_inp));
    cudaCheck(cudaFree(d_weight));
    cudaCheck(cudaFree(d_bias));

    return 0;
}



Overwriting rmsnorm_forward.cu


In [None]:
!nvcc -I /usr/local/cuda/samples/common/inc/ -L/usr/local/cuda/include -lcublas -lcusolver -O3 --use_fast_math -lcublas -std=c++17 rmsnorm_forward.cu -o rmsnorm_forward

In [None]:
!./rmsnorm_forward

Using kernel 3
Checking block size 32.
1.101622 1.101622
0.951610 0.951610
1.522446 1.522446
-1.583111 -1.583111
-0.604489 -0.604489
1.769185 1.769185
1.743165 1.743165
1.696036 1.696035
1.725868 1.725868
1.705505 1.705506
Checking block size 64.
1.101622 1.101622
0.951610 0.951610
1.522446 1.522446
-1.583111 -1.583111
-0.604489 -0.604489
1.769185 1.769185
1.743165 1.743165
1.696036 1.696035
1.725868 1.725868
1.705505 1.705506
Checking block size 128.
1.101622 1.101622
0.951610 0.951610
1.522446 1.522446
-1.583111 -1.583111
-0.604489 -0.604489
1.769185 1.769185
1.743165 1.743165
1.696036 1.696035
1.725868 1.725868
1.705505 1.705506
Checking block size 256.
1.101622 1.101622
0.951610 0.951610
1.522446 1.522446
-1.583111 -1.583111
-0.604489 -0.604489
1.769185 1.769185
1.743165 1.743165
1.696036 1.696035
1.725868 1.725868
1.705505 1.705506
Checking block size 512.
1.101622 1.101622
0.951610 0.951610
1.522446 1.522446
-1.583111 -1.583111
-0.604489 -0.604489
1.769185 1.769185
1.743165 1.743

In [None]:
%%writefile simple_test.c
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <assert.h>

// Root Mean Square Layernorm Forward Pass
void rmsnorm_forward_cpu(
    float *out,
    float *rms,
    const float *inp,
    const float *weight,
    const float *bias,
    int B,
    int T,
    int C
) {
    const float eps = 1e-6f;

    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            // seek to the input position inp[b,t,:]
            const float* x = inp + b * T * C + t * C;
            // compute RMS
            float sum_of_squares = 0.0f;
            for (int i = 0; i < C; i++) {
                sum_of_squares += x[i] * x[i];
            }
            float rms_val = 1.0f / sqrtf(sum_of_squares / C + eps);
            // seek to the output position in out[b,t,:]
            float* out_bt = out + b * T * C + t * C;
            for (int i = 0; i < C; i++) {
                float n = x[i] * rms_val; // normalized output
                float o = n * weight[i] + bias[i]; // scale and shift it
                out_bt[i] = o; // write
            }
            // cache the rms for the backward pass later
            rms[b * T + t] = rms_val;
        }
    }
}

// Input data: the values 1..24 laid out as B*T*C = 2*3*4 floats, row-major.
float inp[2 * 3 * 4] = {
     1.0f,  2.0f,  3.0f,  4.0f,
     5.0f,  6.0f,  7.0f,  8.0f,
     9.0f, 10.0f, 11.0f, 12.0f,
    13.0f, 14.0f, 15.0f, 16.0f,
    17.0f, 18.0f, 19.0f, 20.0f,
    21.0f, 22.0f, 23.0f, 24.0f
};

// Weights and bias: identity scale (all ones) and no shift (all zeros).
float weight[4] = {1.0f, 1.0f, 1.0f, 1.0f};
float bias[4] = {0.0f, 0.0f, 0.0f, 0.0f};

// Simplified main function to test a small batch.
// It runs the CPU RMSNorm on the 2x3x4 example data and prints two tables:
// the normalized outputs, C values per line, then the per-row RMS scales,
// T values per line.
int main() {
    int B = 2, T = 3, C = 4;

    // Outputs
    float out[B * T * C];
    float rms[B * T];

    rmsnorm_forward_cpu(out, rms, inp, weight, bias, B, T, C);

    // Print outputs for comparison
    printf("Output:\n");
    for (int n = 1; n <= B * T * C; n++) {
        printf("%f ", out[n - 1]);
        if (n % C == 0) {
            printf("\n");
        }
    }

    printf("RMS:\n");
    for (int n = 1; n <= B * T; n++) {
        printf("%f ", rms[n - 1]);
        if (n % T == 0) {
            printf("\n");
        }
    }

    return 0;
}


Overwriting simple_test.c


In [None]:
!gcc simple_test.c -o simple_test -lm

In [None]:
!./simple_test

Output:
0.365148 0.730297 1.095445 1.460593 
0.758098 0.909718 1.061337 1.212957 
0.852325 0.947028 1.041730 1.136433 
0.893898 0.962660 1.031421 1.100183 
0.917245 0.971201 1.025157 1.079112 
0.932183 0.976573 1.020963 1.065352 
RMS:
0.365148 0.151620 0.094703 
0.068761 0.053956 0.044390 


In [None]:
import torch
import torch.nn as nn
import numpy as np

class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Each vector along the last axis is divided by sqrt(mean(x**2) + eps) and then
    multiplied by a learnable per-channel ``weight`` (initialized to ones). There
    is no bias term.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # mean of squares over the last axis, kept as a size-1 dim for broadcasting
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        return x * inv_rms

    def forward(self, x):
        # normalize in float32, cast back to the input dtype, then apply the scale
        normed = self._norm(x.float()).type_as(x)
        return normed * self.weight

# Input data: the values 1..24 arranged as a (B, T, C) = (2, 3, 4) tensor
B, T, C = 2, 3, 4
inputs = torch.arange(1.0, B * T * C + 1.0).reshape(B, T, C)

# PyTorch RMSNorm model (unit weights), evaluated without autograd tracking
model = RMSNorm(C)
with torch.no_grad():
    result = model(inputs)
torch_output = result.numpy()

print("PyTorch Output:")
print(torch_output)

PyTorch Output:
[[[0.36514837 0.73029673 1.0954452  1.4605935 ]
  [0.75809807 0.9097177  1.0613372  1.2129569 ]
  [0.8523247  0.9470275  1.0417303  1.136433  ]]

 [[0.8938984  0.96265984 1.0314212  1.1001827 ]
  [0.91724545 0.97120106 1.0251567  1.0791123 ]
  [0.9321832  0.9765729  1.0209626  1.0653522 ]]]


In [None]:
%%writefile llama3_rope_forward.c

#include <stdio.h>
#include <stdlib.h>
#include <math.h>

// Deinterleave the last axis of a (B, T, C) array of (real, imag) pairs into
// two (B, T, C/2) arrays: out_r[b,t,c] = inp[b,t,2c] and out_i[b,t,c] = inp[b,t,2c+1].
void reshape_complex(float* inp, int B, int T, int C, float* out_r, float* out_i) {
    const int half = C / 2;
    const int num_rows = B * T;
    for (int row = 0; row < num_rows; row++) {
        const float* src = inp + row * C;
        float* dst_r = out_r + row * half;
        float* dst_i = out_i + row * half;
        for (int p = 0; p < half; p++) {
            dst_r[p] = src[2 * p];
            dst_i[p] = src[2 * p + 1];
        }
    }
}

// Apply rotary embeddings to query and key tensors of shape (B, T, C).
// The last axis holds C/2 interleaved (real, imag) pairs. Pair c of every row is
// rotated by the angle whose cosine/sine are freqs_cos[c] / freqs_sin[c]:
//   out_real = r * cos - i * sin
//   out_imag = r * sin + i * cos
// Outputs use the same interleaved (B, T, C) layout as the inputs. Each pair is
// computed directly from the input, so no scratch buffers are allocated. The
// previous version made four unchecked malloc calls; those failure paths are
// gone.
// In-place use (xq_out == xq_inp, xk_out == xk_inp) is safe: all four input
// values of a pair are read before either output is written.
// If C is odd, the trailing channel of each row is left unwritten.
// NOTE(review): the frequency tables are indexed by channel pair only, so every
// position t receives the same rotation. Standard RoPE uses a (T, C/2) table
// that also depends on position; confirm this simplification is intended.
void apply_rotary_emb_cpu(
    float* xq_inp,
    float* xk_inp,
    float* freqs_cos,
    float* freqs_sin,
    float* xq_out,
    float* xk_out,
    int B,
    int T,
    int C
) {
    const int half = C / 2;
    const int num_rows = B * T;

    for (int row = 0; row < num_rows; row++) {
        const int base = row * C;
        for (int c = 0; c < half; c++) {
            const int re = base + 2 * c; // real part of pair c
            const int im = re + 1;       // imaginary part of pair c

            const float cos_val = freqs_cos[c];
            const float sin_val = freqs_sin[c];

            const float xq_r_val = xq_inp[re];
            const float xq_i_val = xq_inp[im];
            const float xk_r_val = xk_inp[re];
            const float xk_i_val = xk_inp[im];

            xq_out[re] = xq_r_val * cos_val - xq_i_val * sin_val;
            xq_out[im] = xq_r_val * sin_val + xq_i_val * cos_val;

            xk_out[re] = xk_r_val * cos_val - xk_i_val * sin_val;
            xk_out[im] = xk_r_val * sin_val + xk_i_val * cos_val;
        }
    }
}

// Demo driver: builds query/key inputs holding 1..24, uses frequency tables
// cos(c) / sin(c) for each channel pair c, applies the rotary embedding, and
// prints every output element, queries and keys interleaved.
int main() {
    const int B = 2, T = 3, C = 4;
    const int total = B * T * C;
    const int half = C / 2;

    float* xq_inp = (float*)malloc(total * sizeof(float));
    float* xk_inp = (float*)malloc(total * sizeof(float));
    float* freqs_cos = (float*)malloc(half * sizeof(float));
    float* freqs_sin = (float*)malloc(half * sizeof(float));
    float* xq_out = (float*)malloc(total * sizeof(float));
    float* xk_out = (float*)malloc(total * sizeof(float));

    for (int n = 0; n < total; n++) {
        const float value = (float)(n + 1);
        xq_inp[n] = value;
        xk_inp[n] = value;
    }
    for (int p = 0; p < half; p++) {
        freqs_cos[p] = (float)cos((double)p);
        freqs_sin[p] = (float)sin((double)p);
    }

    apply_rotary_emb_cpu(xq_inp, xk_inp, freqs_cos, freqs_sin, xq_out, xk_out, B, T, C);

    for (int n = 0; n < total; n++) {
        printf("xq_out[%d] = %f\n", n, xq_out[n]);
        printf("xk_out[%d] = %f\n", n, xk_out[n]);
    }

    free(xq_inp);
    free(xk_inp);
    free(freqs_cos);
    free(freqs_sin);
    free(xq_out);
    free(xk_out);

    return 0;
}



Overwriting llama3_rope_forward.c


In [None]:
!gcc llama3_rope_forward.c -o llama3_rope_forward -lm

In [None]:
!./llama3_rope_forward

xq_out[0] = 1.000000
xk_out[0] = 1.000000
xq_out[1] = 2.000000
xk_out[1] = 2.000000
xq_out[2] = -1.744977
xk_out[2] = -1.744977
xq_out[3] = 4.685622
xk_out[3] = 4.685622
xq_out[4] = 5.000000
xk_out[4] = 5.000000
xq_out[5] = 6.000000
xk_out[5] = 6.000000
xq_out[6] = -2.949652
xk_out[6] = -2.949652
xq_out[7] = 10.212715
xk_out[7] = 10.212715
xq_out[8] = 9.000000
xk_out[8] = 9.000000
xq_out[9] = 10.000000
xk_out[9] = 10.000000
xq_out[10] = -4.154326
xk_out[10] = -4.154326
xq_out[11] = 15.739808
xk_out[11] = 15.739808
xq_out[12] = 13.000000
xk_out[12] = 13.000000
xq_out[13] = 14.000000
xk_out[13] = 14.000000
xq_out[14] = -5.359001
xk_out[14] = -5.359001
xq_out[15] = 21.266901
xk_out[15] = 21.266901
xq_out[16] = 17.000000
xk_out[16] = 17.000000
xq_out[17] = 18.000000
xk_out[17] = 18.000000
xq_out[18] = -6.563675
xk_out[18] = -6.563675
xq_out[19] = 26.793995
xk_out[19] = 26.793995
xq_out[20] = 21.000000
xk_out[20] = 21.000000
xq_out[21] = 22.000000
xk_out[21] = 22.000000
xq_out[22] = -7.7683

In [None]:
import torch
import numpy as np

from typing import Tuple

# Helper function to reshape for broadcasting
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(shape)

# Rotary Embedding Function
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary embeddings to query and key tensors.

    The last dimension of ``xq``/``xk`` holds interleaved (real, imag) pairs.
    Each pair is rotated by the matching (cos, sin) entry:
    ``out_r = r*cos - i*sin`` and ``out_i = r*sin + i*cos``.

    Args:
        xq, xk: tensors whose dim 1 has size T and whose last dim D is even,
            e.g. (B, T, D) or (B, T, H, D).
        freqs_cos, freqs_sin: tables of shape (T, D // 2); they are broadcast
            via ``reshape_for_broadcast``.

    Returns:
        ``(xq_out, xk_out)`` with the same shapes and dtypes as ``(xq, xk)``.
    """
    # reshape xq and xk to match the complex representation
    xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
    xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)

    # reshape freqs_cos and freqs_sin for broadcasting
    freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
    freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)

    # apply rotation using real numbers
    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos

    # Re-interleave (real, imag) and merge the trailing pair axis back into D.
    # flatten(-2) targets the last two dims for any input rank. A fixed start
    # dim such as flatten(3) only merges them for 4-D inputs; for 3-D (B, T, C)
    # inputs it does nothing and leaves the result as (B, T, C/2, 2).
    xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(-2)
    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(-2)

    return xq_out.type_as(xq), xk_out.type_as(xk)

# Input data: queries and keys both hold the values 1..24 laid out as (B, T, C)
B, T, C = 2, 3, 4
xq_inp = torch.arange(1.0, B * T * C + 1.0).reshape(B, T, C)
xk_inp = xq_inp.clone()

# Precompute frequency components: channel pair c uses angle c (radians) at
# every one of the T positions, giving (T, C // 2) tables.
angles = range(C // 2)
freqs_cos = torch.tensor([np.cos(a) for a in angles]).repeat(T, 1)
freqs_sin = torch.tensor([np.sin(a) for a in angles]).repeat(T, 1)

# Apply rotary embeddings using PyTorch
xq_out, xk_out = apply_rotary_emb(xq_inp, xk_inp, freqs_cos, freqs_sin)

for label, tensor in (("xq_out:", xq_out), ("xk_out:", xk_out)):
    print(label)
    print(tensor)


xq_out:
tensor([[[[ 1.0000,  2.0000],
          [-1.7450,  4.6856]],

         [[ 5.0000,  6.0000],
          [-2.9497, 10.2127]],

         [[ 9.0000, 10.0000],
          [-4.1543, 15.7398]]],


        [[[13.0000, 14.0000],
          [-5.3590, 21.2669]],

         [[17.0000, 18.0000],
          [-6.5637, 26.7940]],

         [[21.0000, 22.0000],
          [-7.7684, 32.3211]]]])
xk_out:
tensor([[[[ 1.0000,  2.0000],
          [-1.7450,  4.6856]],

         [[ 5.0000,  6.0000],
          [-2.9497, 10.2127]],

         [[ 9.0000, 10.0000],
          [-4.1543, 15.7398]]],


        [[[13.0000, 14.0000],
          [-5.3590, 21.2669]],

         [[17.0000, 18.0000],
          [-6.5637, 26.7940]],

         [[21.0000, 22.0000],
          [-7.7684, 32.3211]]]])
