In [30]:
import cupy as cp
import numpy as np

# CUDA C++ kernel definition
_weights_to_lambdas_kernel = r'''
extern "C" __global__
void weights_to_lambdas(const float* lambdas, const float* weightss, float* new_lambdas,
                        int num_terms, int num_qubits) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int size = 1 << (num_qubits * 2); // 4^num_qubits
    
    if (idx < size) {
        new_lambdas[idx] = 0.0f;
        
        // For each term
        for (int j = 0; j < num_terms; j++) {
            float product = 1.0f;
            int temp_idx = idx;
            
            // Calculate product for this index
            for (int k = 0; k < num_qubits; k++) {
                int weight_idx = temp_idx & 3;  // mod 4 to get the index within this qubit's weights
                temp_idx >>= 2;  // Move to next qubit
                
                // weightss is arranged as [term][qubit][4]
                int offset = j * (num_qubits * 4) + k * 4 + weight_idx;
                product *= weightss[offset];
            }
            
            new_lambdas[idx] += lambdas[j] * product;
        }
    }
}
'''

# Compile the kernel
weights_to_lambdas_kernel = cp.RawKernel(_weights_to_lambdas_kernel, 'weights_to_lambdas')

def weightss_to_lambdas_cuda(lambdas: cp.ndarray, weightss: cp.ndarray) -> cp.ndarray:
    """CUDA implementation of weightss_to_lambdas using raw kernel
    
    Args:
        lambdas: Coefficients for each term, shape (num_terms,)
        weightss: Weights matrix, shape (num_terms, num_qubits, 4)
        
    Returns:
        Tuple containing:
        - new_lambdas: Non-zero lambda values
        - non_zeros_indices: Indices of non-zero values
    """
    num_terms, num_qubits, _ = weightss.shape
    
    # Reshape weightss for easier access in CUDA kernel
    weightss_flat = weightss.reshape(num_terms * num_qubits * 4)
    
    # Initialize output array for all possible combinations (4^num_qubits)
    size = 4**num_qubits
    new_lambdas = cp.zeros(size, dtype=cp.float32)
    
    # Calculate grid and block dimensions
    threads_per_block = 256
    blocks_per_grid = (size + threads_per_block - 1) // threads_per_block
    
    # Launch kernel
    weights_to_lambdas_kernel(
        (blocks_per_grid,), (threads_per_block,),
        (lambdas, weightss_flat, new_lambdas, num_terms, num_qubits)
    )
    
    # Extract non-zero elements and their indices
    non_zeros_indices = cp.nonzero(new_lambdas)[0]
    result_lambdas = new_lambdas[non_zeros_indices]
    
    return result_lambdas, non_zeros_indices

# Example usage
def test_cuda_kernel():
    # Create test data
    num_terms = 3
    num_qubits = 2
    
    # Example weights array like in the docstring
    test_weights = cp.ones((num_terms, num_qubits, 4), dtype=cp.float32)
    for i in range(num_terms):
        for j in range(num_qubits):
            test_weights[i, j] = cp.array([1, 2, 3, 4], dtype=cp.float32)
    
    test_lambdas = cp.ones(num_terms, dtype=cp.float32)
    
    # Execute CUDA kernel implementation
    cuda_result, indices = weightss_to_lambdas_cuda(test_lambdas, test_weights)
    
    print("CUDA Result:")
    print(cuda_result)
    print("Indices:")
    print(indices)

if __name__ == "__main__":
    test_cuda_kernel()

CUDA Result:
[ 3.  6.  9. 12.  6. 12. 18. 24.  9. 18. 27. 36. 12. 24. 36. 48.]
Indices:
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15]


In [31]:
lambdas = cp.array([[1], [2]])
weightss = cp.array([
    [[0,2,3,4],[1,2,3,4]], 
    [[0,2,3,4],[1,2,3,4]]],
      dtype=cp.float32)

weightss_to_lambdas_cuda(lambdas, weightss)

(array([], dtype=float32), array([], dtype=int64))

In [28]:
lambdas = cp.array([[1], [2]])
weightss = cp.array([
    [[0,2,3,4],[1,2,3,4]], 
    [[0,2,3,4],[1,2,3,4]]],
      dtype=cp.float32)

weightss_to_lambdas(lambdas, weightss)

(array([ 6., 12., 18., 24.,  9., 18., 27., 36., 12., 24., 36., 48.]),
 array([ 4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15], dtype=int64))

In [18]:
lambdas = cp.array([[1]])
weightss = cp.array([[
    	[ 0.        , -0.36763045,  0.9023816 , -0.22484507],
        [ 1.        ,  0.        ,  0.        ,  0.        ],
        [ 1.        ,  0.        ,  0.        ,  0.        ],
        [ 0.        , -0.36763045,  0.9023816 , -0.22484507]]],
      dtype=cp.float32)

weightss_to_lambdas(lambdas, weightss)

(array([ 0.13515215, -0.33174294,  0.08265989, -0.33174294,  0.81429255,
        -0.20289604,  0.08265989, -0.20289604,  0.0505553 ]),
 array([ 65,  66,  67, 129, 130, 131, 193, 194, 195], dtype=int64))

In [None]:
lambdass = [cp.array([1.]), cp.array([1.]), cp.array([1.]), cp.array([1.])]
weightsss = [
    cp.array([[[ 0.        , -0.36763045,  0.9023816 , -0.22484507],
        [ 1.        ,  0.        ,  0.        ,  0.        ],
        [ 1.        ,  0.        ,  0.        ,  0.        ],
        [ 0.        , -0.36763045,  0.9023816 , -0.22484507]]],
      dtype=cp.float32), 
    cp.array([[[ 0.        , -0.36763045,  0.9023816 , -0.22484507],
        [ 0.        , -0.36763045,  0.9023816 , -0.22484507],
        [ 1.        ,  0.        ,  0.        ,  0.        ],
        [ 0.        , -0.36763045,  0.9023816 , -0.22484507]]],
      dtype=cp.float32), 
    cp.array([[[ 1.        ,  0.        ,  0.        ,  0.        ],
        [ 0.        , -0.36763045,  0.9023816 , -0.22484507],
        [ 0.        , -0.36763045,  0.9023816 , -0.22484507],
        [ 1.        ,  0.        ,  0.        ,  0.        ]]],
      dtype=cp.float32), 
    cp.array([[[ 1.        ,  0.        ,  0.        ,  0.        ],
        [ 1.        ,  0.        ,  0.        ,  0.        ],
        [ 0.        , -0.36763045,  0.9023816 , -0.22484507],
        [ 0.        , -0.36763045,  0.9023816 , -0.22484507]]],
      dtype=cp.float32)]

weightsss_to_lambdass(lambdass, weightsss)

([array([-1.48219694e-323, -9.88131292e-324, -2.96439388e-323,
         -9.88131292e-324, -9.88131292e-324, -4.94065646e-324,
         -2.47032823e-323, -9.88131292e-324, -2.96439388e-323,
         -2.47032823e-323, -7.90505033e-323, -2.47032823e-323,
         -9.88131292e-324, -4.94065646e-324, -2.47032823e-323,
         -9.88131292e-324,  4.17565817e-015,  2.94439541e-015,
          1.02091435e-014,  3.36134641e-015,  2.94439541e-015,
          2.07619110e-015,  7.19880652e-015,  2.37019712e-015,
          1.02091435e-014,  7.19880652e-015,  2.49605228e-014,
          8.21821770e-015,  3.36134641e-015,  2.37019712e-015,
          8.21821770e-015,  2.70583684e-015,  1.02819297e-016,
          7.25013051e-017,  2.51384792e-016,  8.27680960e-017,
          7.25013051e-017,  5.11230808e-017,  1.77259775e-016,
          5.83625364e-017,  2.51384792e-016,  1.77259775e-016,
          6.14615304e-016,  2.02361241e-016,  8.27680960e-017,
          5.83625364e-017,  2.02361241e-016,  6.6627159