In [4]:
def accumulator_simd(accm_op, variables, func="_mm512_add_pd", indent_level=1):
    """Emit C statements that reduce ``variables`` pairwise into the first one.

    The reduction proceeds in rounds.  In each round neighbouring names are
    combined as ``dst =func(dst, src);`` and only ``dst`` survives; an
    unpaired trailing name is carried over unchanged to the next round.
    Because the first name is always a ``dst``, the final result ends up in
    ``variables[0]``.

    Args:
        accm_op: List that receives the generated lines (mutated in place).
            A ``"\\n"`` separator entry is appended at the start of every
            round, including the terminating one.
        variables: Non-empty sequence of C variable names to reduce.
        func: Name of the binary intrinsic used to combine two variables.
            It is applied in every round of the reduction.
        indent_level: Number of 4-space indents placed before each statement.

    Returns:
        The name of the variable that holds the fully reduced value.

    Raises:
        ValueError: If ``variables`` is empty.
    """
    if not variables:
        raise ValueError("accumulator_simd requires at least one variable name")

    indent = " " * 4 * indent_level
    remaining = list(variables)
    while True:
        # Every round (even the final single-variable one) starts with a separator.
        accm_op.append("\n")
        if len(remaining) == 1:
            return remaining[0]

        # Largest even prefix; an odd trailing element waits for the next round.
        paired_len = len(remaining) - len(remaining) % 2
        survivors = []
        for dst, src in zip(remaining[0:paired_len:2], remaining[1:paired_len:2]):
            accm_op.append(f"{indent}{dst} ={func}({dst}, {src});")
            survivors.append(f"{dst}")
        survivors.extend(f"{name}" for name in remaining[paired_len:])
        remaining = survivors


# Generate one C source file per inner-loop unroll factor (1..8).  Each file
# contains a horizontal-sum helper, a non-transposed DGEMV kernel and a
# transposed DGEMV kernel, all written with AVX-512 intrinsics.
#
# Register naming used by the generated kernels:
#   zmm0                         broadcast x[k] (normal) / running accumulator (transposed)
#   zmm1 .. zmm{unroll}          tiles of y (normal) / tiles of x (transposed)
#   zmm{unroll+1} .. zmm{2*unroll}  tiles of the matrix
for unroll in range(1, 8+1, 1):
    with open(f"./small_dgemv_naive_c_implementation_ver3-1_ZmmNaive_Unroll-InnerLoop-{unroll}.c", "w") as p:
        for header in (
            "stddef.h",
            "stdio.h",
            "stdlib.h",
            "stdint.h",
            "math.h",
            "omp.h",
            "immintrin.h",
            "xmmintrin.h",
        ):
            p.write(f"#include <{header}>\n")

        # Helper: horizontally add the 8 lanes of a zmm register and add the scalar tail sum.
        p.write(f"""
double sum_zmm_elements_with_remain_v3_1_{unroll}(__m512d zmmX, double remain_sum){{
    __m256d ymm_upper = _mm512_extractf64x4_pd(zmmX, 1);
    __m256d ymm_lower = _mm512_extractf64x4_pd(zmmX, 0);

    __m256d ymm_sum = _mm256_add_pd(ymm_upper, ymm_lower);

    __m128d xmm_sum = _mm_add_pd(_mm256_extractf128_pd(ymm_sum, 1), _mm256_extractf128_pd(ymm_sum, 0));

    __m128d xmm_high_low = _mm_add_pd(xmm_sum, _mm_unpackhi_pd(xmm_sum, xmm_sum));

    double result;
    _mm_store_sd(&result, xmm_high_low);

    return result + remain_sum;
}}  
        """)


        # Normal ----------------------------------------------------------------------------------------------------------
        y_start, y_end = 1, 1+unroll
        a_start, a_end = 1+unroll, 1+unroll*2

        # Scalar tail for the rows not covered by the 8*unroll-wide vector loop.
        remain_op_i = f"""
    for (int64_t i = lda-i_remain; i<lda; i++) {{
        y[i] += a[i] * x[k];
    }}        
        """

        load_y = "".join([f"""
            zmm{i} = _mm512_loadu_pd(&y[i+{(i-y_start)*8}]);""" for i in range(y_start, y_end, 1)])
        load_a = "".join([f"""
            zmm{i} = _mm512_loadu_pd(&a[i+{(i-a_start)*8}]);""" for i in range(a_start, a_end, 1)])

        # a-tile *= x[k] (broadcast in zmm0), then a-tile += y-tile, then store back to y.
        mul_ax = "".join([f"""
            zmm{i+a_start-1} = _mm512_mul_pd(zmm0, zmm{i+a_start-1});""" for i in range(y_start, y_end)])

        add_ax = "".join([f"""
            zmm{i+a_start-1} = _mm512_add_pd(zmm{i}, zmm{i+a_start-1});""" for i in range(y_start, y_end)])

        store_y = "".join([f"""
            _mm512_storeu_pd(&y[i+{(i-1)*8}], zmm{i+a_start-1});""" for i in range(y_start, y_end)])

        code = f"""        
void mydgemv_n_ver3_1_unroll{unroll}(double a[], double x[], double y[], int64_t lda, int64_t ldx, int64_t ldy){{
    __m512d zmm0, {", ".join([f"zmm{i}" for i in range(1, 1+unroll*2)])};
    double tmp_x;
    for (int64_t i=0; i<lda; i++){{
        y[i] = 0.0;
    }}

    int64_t i_remain = lda % {8*unroll};
    for (int64_t k=0; k<ldx; k++){{
        zmm0 = _mm512_set1_pd(x[k+0]);
        for (int64_t i=0; i<lda-i_remain; i+={8*unroll}){{
            {load_y}
            {load_a}

            {mul_ax}
            {add_ax}
            {store_y}
        }}

        {remain_op_i}
        a += lda;
    }}
}}
        """
        p.write(code)


        # Transposed ----------------------------------------------------------------------------------------------------------
        x_start, x_end = 1, 1+unroll
        a_start, a_end = 1+unroll, 1+unroll*2

        # Scalar tail of the dot product.  It is emitted for every unroll
        # factor: k_remain = ldx % (8*unroll) can be non-zero for unroll == 1
        # too, and when it is zero the generated loop simply does not execute.
        remain_op = f"""
        for (int64_t k =ldx-k_remain; k<ldx; k++) {{
            remain_sum += a_t[k] * x[k];
        }}
            """

        load_x = "".join([f"""
            zmm{i} = _mm512_loadu_pd(&x[k+{(i-x_start)*8}]);""" for i in range(x_start, x_end, 1)])
        load_a = "".join([f"""
            zmm{i} = _mm512_loadu_pd(&a_t[k+{(i-a_start)*8}]);""" for i in range(a_start, a_end, 1)])

        mul_ax = "".join([f"""
            zmm{i+a_start-1} = _mm512_mul_pd(zmm{i}, zmm{i+a_start-1});""" for i in range(x_start, x_end)])

        # Fold all product tiles into zmm0; zmm0 is listed first so it is the
        # destination of the final reduction statement.
        accm_op = []
        accumulator_simd(accm_op, ["zmm0"] + [f"zmm{i+a_start-1}" for i in range(x_start, x_end)], indent_level=3)
        accm_op = "\n".join(accm_op)

        code = f"""        
void mydgemv_t_ver3_1_unroll{unroll}(double a_t[], double x[], double y[], int64_t lda, int64_t ldx, int64_t ldy){{
    __m512d zmm0, {", ".join([f"zmm{i}" for i in range(1, 1+unroll*2)])};

    int64_t k_remain = ldx % {unroll*8};
    for (int64_t i=0; i<lda; i++){{
        zmm0 = _mm512_setzero_pd();

        for (int64_t k=0; k<ldx-k_remain; k+={unroll*8}){{
            {load_x}
            {load_a}

            {mul_ax}

            {accm_op}
        }}
        double remain_sum=0.0;

        {remain_op}

        y[i] = sum_zmm_elements_with_remain_v3_1_{unroll}(zmm0, remain_sum);
        a_t += ldx;
    }}
}}
        """
        p.write(code)