In [10]:
%%writefile lash_attention_forward.cu
#include <iostream>
#include <cuda_runtime.h>
#include <curand.h>
#include <curand_kernel.h>
#include <cmath>

#define sequence_length 64
#define embed_dimension 64

constexpr int Block_column_size = 16;
constexpr int Block_row_size = 16;

__global__ void initRandom(float* matrix, int rows, int cols, unsigned long seed) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int total = rows * cols;
    if (idx < total) {
        curandState state;
        curand_init(seed, idx, 0, &state);
        matrix[idx] = curand_uniform(&state);
    }
}

__global__ void flashAttentionForward(
    float* Q, float* K, float* V, float* O,
    int seq_len, int dim)
{
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row < seq_len && col < dim) {
        float sum = 0.0f;
        for (int i = 0; i < seq_len; ++i) {
            float dot = 0.0f;
            for (int d = 0; d < dim; ++d) {
                dot += Q[row * dim + d] * K[i * dim + d];
            }
            float score = expf(dot / sqrtf((float)dim));
            sum += score * V[i * dim + col];
        }
        O[row * dim + col] = sum;
    }
}

void printMatrix(float* matrix, int rows, int cols, int printRows = 5, int printCols = 5) {
    printf("Matrix sample (%dx%d of %dx%d):\n", printRows, printCols, rows, cols);
    for (int i = 0; i < printRows; ++i) {
        for (int j = 0; j < printCols; ++j) {
            printf("%f ", matrix[i * cols + j]);
        }
        printf("\n");
    }
}

int main() {
    size_t size = sequence_length * embed_dimension * sizeof(float);
    float *d_Q, *d_K, *d_V, *d_O;
    cudaMalloc(&d_Q, size);
    cudaMalloc(&d_K, size);
    cudaMalloc(&d_V, size);
    cudaMalloc(&d_O, size);

    dim3 blockDim(256);
    dim3 gridDim((sequence_length * embed_dimension + 255) / 256);
    initRandom<<<gridDim, blockDim>>>(d_Q, sequence_length, embed_dimension, 1234);
    initRandom<<<gridDim, blockDim>>>(d_K, sequence_length, embed_dimension, 5678);
    initRandom<<<gridDim, blockDim>>>(d_V, sequence_length, embed_dimension, 91011);

    dim3 threads(Block_column_size, Block_row_size);
    dim3 grid((embed_dimension + threads.x - 1) / threads.x,
              (sequence_length + threads.y - 1) / threads.y);
    flashAttentionForward<<<grid, threads>>>(d_Q, d_K, d_V, d_O, sequence_length, embed_dimension);

    float* h_Q = new float[sequence_length * embed_dimension];
    float* h_K = new float[sequence_length * embed_dimension];
    float* h_V = new float[sequence_length * embed_dimension];
    float* h_O = new float[sequence_length * embed_dimension];
    cudaMemcpy(h_Q, d_Q, size, cudaMemcpyDeviceToHost);
    cudaMemcpy(h_K, d_K, size, cudaMemcpyDeviceToHost);
    cudaMemcpy(h_V, d_V, size, cudaMemcpyDeviceToHost);
    cudaMemcpy(h_O, d_O, size, cudaMemcpyDeviceToHost);

    printf("Q:\n"); printMatrix(h_Q, sequence_length, embed_dimension);
    printf("K:\n"); printMatrix(h_K, sequence_length, embed_dimension);
    printf("V:\n"); printMatrix(h_V, sequence_length, embed_dimension);
    printf("O:\n"); printMatrix(h_O, sequence_length, embed_dimension);

    delete[] h_Q;
    delete[] h_K;
    delete[] h_V;
    delete[] h_O;
    cudaFree(d_Q);
    cudaFree(d_K);
    cudaFree(d_V);
    cudaFree(d_O);

    return 0;
}


Overwriting lash_attention_forward.cu


In [11]:
!nvcc lash_attention_forward.cu -o lash_attention_forward -gencode arch=compute_75,code=sm_75 -lcurand
!./lash_attention_forward


Q:
Matrix sample (5x5 of 64x64):
0.145468 0.820181 0.550399 0.294830 0.914733 
0.875473 0.221577 0.295817 0.404566 0.389569 
0.127504 0.049953 0.715364 0.101053 0.322029 
0.375354 0.159771 0.114455 0.104466 0.069897 
0.182689 0.044311 0.159555 0.029856 0.563860 
K:
Matrix sample (5x5 of 64x64):
0.661367 0.135027 0.782970 0.186697 0.234071 
0.911707 0.277680 0.795242 0.987023 0.082308 
0.684712 0.102407 0.411118 0.508820 0.921537 
0.143896 0.125743 0.863702 0.980113 0.251365 
0.018184 0.405578 0.006400 0.798228 0.754991 
V:
Matrix sample (5x5 of 64x64):
0.651542 0.707912 0.237569 0.089901 0.431445 
0.977253 0.048838 0.813942 0.426914 0.397408 
0.845735 0.783821 0.840378 0.296731 0.079394 
0.196971 0.339382 0.245803 0.880243 0.291420 
0.702024 0.153379 0.898005 0.907008 0.050992 
O:
Matrix sample (5x5 of 64x64):
261.446136 213.799515 253.103088 228.763107 213.048355 
259.997131 213.284042 252.750275 230.118103 213.506989 
238.396255 196.676163 235.457123 214.321701 201.330933 
227.966705