<a href="https://colab.research.google.com/github/AndreSlavescu/Intermediate-Gauss-Seidel-Decoding/blob/main/Intermediate_Gauss_Seidel_Decoding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!nvidia-smi

In [None]:
!pip install pycuda

In [95]:
import numpy as np
import numpy.linalg as la

import pycuda.autoinit
import pycuda.driver as cuda
from pycuda.compiler import SourceModule
from pycuda.driver import Event

# Kernel code
"""
Idea:

Although the entries of the vector x in Ax = b can be computed in parallel within a
single iteration, the Gauss-Seidel iteration method is sequential across iterations
when n > 1 iterations are required. The idea behind the kernel below is to perform a
kind of "jump iteration", where the even indices of x are computed on even iterations
and the odd indices on odd iterations, allowing a two-fold parallelism while converging
to the solution. The implementation below, together with the test for 100 iterations,
suggests that this approach may be effective and could be applied to problems such as
parallel token decoding in LLMs, as seen in lookahead decoding
(https://lmsys.org/blog/2023-11-21-lookahead-decoding/).
"""


# CUDA C source for the "jump iteration" solver. One thread owns one entry of x.
# Two buffers (x_1, x_2) are used in ping-pong fashion: every iteration reads one
# buffer and writes the other, then the roles are swapped.
kernel_code = '''
__global__ void jump_iteration_gauss_seidel(float *A, float *b, float *x_1, float *x_2, int size, int iterations) {
    // A          : row-major size x size coefficient matrix
    // b          : right-hand side, length size
    // x_1        : initial guess on entry, final iterate on exit
    // x_2        : scratch buffer of length size
    // iterations : number of jump iterations to run
    const int index = blockIdx.x * blockDim.x + threadIdx.x;

    // Threads past the end of the vector do no arithmetic, but they must still
    // take part in every barrier below, so the bounds test guards only the work
    // and never the loop or the barrier itself.
    const bool active = (index < size);

    // Even entries are refreshed on even iterations, odd entries on odd ones.
    const bool is_even_index = (index % 2 == 0);

    float *x_read = x_1;
    float *x_write = x_2;

    for (int iter = 0; iter < iterations; ++iter) {
        const bool even_iteration = (iter % 2 == 0);

        if (active) {
            if (even_iteration == is_even_index) {
                // Relax this entry against the values of the previous iterate.
                float sum = 0.0f;
                for (int j = 0; j < size; ++j) {
                    if (j != index) {
                        sum += A[index * size + j] * x_read[j];
                    }
                }
                x_write[index] = (b[index] - sum) / A[index * size + index];
            } else {
                // Not this entry's turn: carry the old value forward.
                x_write[index] = x_read[index];
            }
        }

        // Block-wide barrier: every write of this iteration is visible before any
        // thread of the block starts reading it in the next one. It does not
        // order threads of different blocks, so the launch should cover the
        // whole vector with a single block.
        __syncthreads();

        // Ping-pong: the buffer just written becomes the next input.
        float *swap_tmp = x_read;
        x_read = x_write;
        x_write = swap_tmp;
    }

    // After the final swap x_read is the most recently written buffer. Callers
    // read the answer from x_1, so copy it over when the last write hit x_2.
    if (active && x_read != x_1) {
        x_1[index] = x_read[index];
    }
}
'''

# Compile the kernel code
# SourceModule builds the CUDA source held in `kernel_code` when this cell runs;
# get_function then looks the kernel up by its name in that compiled module and
# returns the callable handle that the host code launches.
mod = SourceModule(kernel_code)
jump_iteration_gauss_seidel = mod.get_function("jump_iteration_gauss_seidel")

def run_gauss_seidel_gpu(A, b, x, size, iterations):
    """Run the jump-iteration kernel on the GPU and return the final iterate.

    Args:
        A: Coefficient matrix holding at least ``size * size`` values, row-major.
        b: Right-hand side holding at least ``size`` values.
        x: Initial guess holding at least ``size`` values. When it is already a
            C-contiguous float32 array it is overwritten in place with the
            result; otherwise it is left untouched and a float32 copy is used.
        size: Number of unknowns.
        iterations: Number of kernel iterations to run.

    Returns:
        The array holding the final iterate (``x`` itself when it could be
        used directly, otherwise a new float32 array).

    Raises:
        ValueError: If ``A``, ``b`` or ``x`` hold fewer values than ``size``
            requires, which would make the kernel read out of bounds.
    """
    import warnings

    size = int(size)

    # The kernel reads raw ``float*`` buffers, so every host array must be
    # C-contiguous float32 before its bytes are uploaded.
    A_host = np.ascontiguousarray(A, dtype=np.float32)
    b_host = np.ascontiguousarray(b, dtype=np.float32)
    x_host = np.ascontiguousarray(x, dtype=np.float32)

    if A_host.size < size * size or b_host.size < size or x_host.size < size:
        raise ValueError(
            "A, b and x must hold at least size*size, size and size values"
        )

    A_gpu = cuda.mem_alloc(A_host.nbytes)
    b_gpu = cuda.mem_alloc(b_host.nbytes)
    x_gpu_1 = cuda.mem_alloc(x_host.nbytes)
    x_gpu_2 = cuda.mem_alloc(x_host.nbytes)

    # host to device copy
    cuda.memcpy_htod(A_gpu, A_host)
    cuda.memcpy_htod(b_gpu, b_host)
    cuda.memcpy_htod(x_gpu_1, x_host)

    # The kernel separates iterations with a block-level barrier only, so all
    # unknowns have to live in ONE block for the iterations to be ordered.
    block_size = 256
    if size > block_size:
        max_threads = jump_iteration_gauss_seidel.get_attribute(
            cuda.function_attribute.MAX_THREADS_PER_BLOCK
        )
        if size <= max_threads:
            # Grow the single block, rounded up to a whole number of warps.
            block_size = min(max_threads, -(-size // 32) * 32)
        else:
            warnings.warn(
                "size exceeds the largest thread block; iterations are not "
                "synchronized across blocks and the result may be inaccurate",
                RuntimeWarning,
            )
    grid_size = (size + block_size - 1) // block_size

    # Time the kernel launch (all requested iterations) with CUDA events
    start = cuda.Event()
    end = cuda.Event()
    start.record()

    jump_iteration_gauss_seidel(A_gpu, b_gpu, x_gpu_1, x_gpu_2, np.int32(size), np.int32(iterations),
                              block=(block_size, 1, 1), grid=(grid_size, 1))

    end.record()
    end.synchronize()

    # Calculate the elapsed time
    elapsed_time = start.time_till(end)
    print(f"Kernel execution time: {elapsed_time} milliseconds")

    # device to host copy; the kernel leaves the final iterate in the first buffer
    cuda.memcpy_dtoh(x_host, x_gpu_1)
    return x_host

# Smoke test: a small, strictly diagonally dominant 3x3 system in float32
A = np.array(
    [
        [10.0, -1.0, 2.0],
        [-1.0, 11.0, -1.0],
        [2.0, -1.0, 10.0],
    ],
    dtype=np.float32,
)
b = np.array([6.0, 25.0, -11.0], dtype=np.float32)
x0 = np.zeros(b.shape, dtype=b.dtype)  # start from the zero vector

iterations = 100
# Hand the solver a copy so the initial guess x0 itself is never overwritten
x_gpu = run_gauss_seidel_gpu(A, b, np.copy(x0), b.shape[0], iterations)

# Reference answer from NumPy's direct solver
x_correct = la.solve(A, b)

print()
print("GPU Result:", x_gpu)
print("Correct Solution:", x_correct)
print()

# The iterative GPU result must agree with the direct solve within allclose's default tolerances
assert np.allclose(x_gpu, x_correct), "GPU result does not match correct solution"
print("GPU result matches correct solution")

Kernel execution time: 0.14905600249767303 milliseconds

GPU Result: [ 1.0432693  2.2692308 -1.0817307]
Correct Solution: [ 1.0432693  2.2692308 -1.0817307]

GPU result matches correct solution
