From dfc3bed9650b04e37ed687f1d0504657581f73aa Mon Sep 17 00:00:00 2001 From: dance858 Date: Mon, 30 Mar 2026 08:51:35 -0700 Subject: [PATCH 01/13] new infrastructure --- include/utils/dense_matrix.h | 20 + include/utils/linalg_dense_sparse_matmuls.h | 26 ++ include/utils/matrix.h | 12 +- src/affine/left_matmul.c | 2 +- src/utils/dense_matrix.c | 138 +------ src/utils/linalg_dense_sparse_matmuls.c | 346 ++++++++++++++++++ tests/all_tests.c | 5 + .../test_linalg_utils_matmul_chain_rule.h | 228 ++++++++++++ tests/utils/test_matrix.h | 2 +- 9 files changed, 632 insertions(+), 147 deletions(-) create mode 100644 include/utils/dense_matrix.h create mode 100644 include/utils/linalg_dense_sparse_matmuls.h create mode 100644 src/utils/linalg_dense_sparse_matmuls.c create mode 100644 tests/utils/test_linalg_utils_matmul_chain_rule.h diff --git a/include/utils/dense_matrix.h b/include/utils/dense_matrix.h new file mode 100644 index 0000000..8121020 --- /dev/null +++ b/include/utils/dense_matrix.h @@ -0,0 +1,20 @@ +#ifndef DENSE_MATRIX_H +#define DENSE_MATRIX_H + +#include "matrix.h" + +/* Dense matrix (row-major) */ +typedef struct Dense_Matrix +{ + Matrix base; + double *x; + double *work; /* scratch buffer, length n */ +} Dense_Matrix; + +/* Constructors */ +Matrix *new_dense_matrix(int m, int n, const double *data); + +/* Transpose helper */ +Matrix *dense_matrix_trans(const Dense_Matrix *self); + +#endif /* DENSE_MATRIX_H */ diff --git a/include/utils/linalg_dense_sparse_matmuls.h b/include/utils/linalg_dense_sparse_matmuls.h new file mode 100644 index 0000000..8404940 --- /dev/null +++ b/include/utils/linalg_dense_sparse_matmuls.h @@ -0,0 +1,26 @@ +#ifndef LINALG_DENSE_SPARSE_H +#define LINALG_DENSE_SPARSE_H + +#include "CSC_Matrix.h" +#include "CSR_Matrix.h" +#include "matrix.h" + +/* C = (I_p kron A) @ J via the polymorphic Matrix interface. + * A is dense m x n, J is (n*p) x k in CSC, C is (m*p) x k in CSC. 
*/ +// TODO: maybe we can replace these with I_kron_X functionality? +CSC_Matrix *I_kron_A_alloc(const Matrix *A, const CSC_Matrix *J, int p); +void I_kron_A_fill_values(const Matrix *A, const CSC_Matrix *J, CSC_Matrix *C); + +/* Sparsity and values of C = (Y^T kron I_m) @ J where Y is k x n, J is (m*k) x p, + and C is (m*n) x p. Y is given in column-major dense format. */ +CSR_Matrix *YT_kron_I_alloc(int m, int k, int n, const CSC_Matrix *J); +void YT_kron_I_fill_values(int m, int k, int n, const double *Y, const CSC_Matrix *J, + CSR_Matrix *C); + +/* Sparsity and values of C = (I_n kron X) @ J where X is m x k (col-major dense), + J is (k*n) x p, and C is (m*n) x p. */ +CSR_Matrix *I_kron_X_alloc(int m, int k, int n, const CSC_Matrix *J); +void I_kron_X_fill_values(int m, int k, int n, const double *X, const CSC_Matrix *J, + CSR_Matrix *C); + +#endif /* LINALG_DENSE_SPARSE_H */ diff --git a/include/utils/matrix.h b/include/utils/matrix.h index fe7db5f..478fabd 100644 --- a/include/utils/matrix.h +++ b/include/utils/matrix.h @@ -41,21 +41,11 @@ typedef struct Sparse_Matrix CSR_Matrix *csr; } Sparse_Matrix; -/* Dense matrix (row-major) */ -typedef struct Dense_Matrix -{ - Matrix base; - double *x; - double *work; /* scratch buffer, length n */ -} Dense_Matrix; - /* Constructors */ Matrix *new_sparse_matrix(const CSR_Matrix *A); -Matrix *new_dense_matrix(int m, int n, const double *data); -/* Transpose helpers */ +/* Transpose helper */ Matrix *sparse_matrix_trans(const Sparse_Matrix *self, int *iwork); -Matrix *dense_matrix_trans(const Dense_Matrix *self); /* Free helper */ static inline void free_matrix(Matrix *m) diff --git a/src/affine/left_matmul.c b/src/affine/left_matmul.c index 26ea172..4067e65 100644 --- a/src/affine/left_matmul.c +++ b/src/affine/left_matmul.c @@ -17,7 +17,7 @@ */ #include "affine.h" #include "subexpr.h" -#include "utils/matrix.h" +#include "utils/dense_matrix.h" #include #include #include diff --git a/src/utils/dense_matrix.c 
b/src/utils/dense_matrix.c index 63f3442..8a2dd2c 100644 --- a/src/utils/dense_matrix.c +++ b/src/utils/dense_matrix.c @@ -15,9 +15,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "utils/dense_matrix.h" #include "utils/cblas_wrapper.h" -#include "utils/iVec.h" -#include "utils/matrix.h" +#include "utils/linalg_dense_sparse_matmuls.h" #include #include @@ -41,136 +41,6 @@ static void dense_block_left_mult_vec(const Matrix *A, const double *x, double * n, 0.0, y, m); } -static CSC_Matrix *dense_block_left_mult_sparsity(const Matrix *A, - const CSC_Matrix *J, int p) -{ - int m = A->m; - int n = A->n; - int i, j, jj, block, block_start, block_end, block_jj_start, row_offset; - - int *Cp = (int *) malloc((J->n + 1) * sizeof(int)); - iVec *Ci = iVec_new(J->n * m); - Cp[0] = 0; - - /* for each column of J */ - for (j = 0; j < J->n; j++) - { - /* if empty we continue */ - if (J->p[j] == J->p[j + 1]) - { - Cp[j + 1] = Cp[j]; - continue; - } - - /* process each of p blocks of rows in this column of J */ - jj = J->p[j]; - for (block = 0; block < p; block++) - { - // ----------------------------------------------------------------- - // find start and end indices of rows of J in this block - // ----------------------------------------------------------------- - block_start = block * n; - block_end = block_start + n; - while (jj < J->p[j + 1] && J->i[jj] < block_start) - { - jj++; - } - - block_jj_start = jj; - while (jj < J->p[j + 1] && J->i[jj] < block_end) - { - jj++; - } - - /* if no entries in this block, continue */ - if (jj == block_jj_start) - { - continue; - } - - /* dense A: all m rows contribute */ - row_offset = block * m; - for (i = 0; i < m; i++) - { - iVec_append(Ci, row_offset + i); - } - } - Cp[j + 1] = Ci->len; - } - - CSC_Matrix *C = new_csc_matrix(m * p, J->n, Ci->len); - memcpy(C->p, Cp, (J->n + 1) * sizeof(int)); - memcpy(C->i, Ci->data, Ci->len * sizeof(int)); - free(Cp); - 
iVec_free(Ci); - - return C; -} - -static void dense_block_left_mult_values(const Matrix *A, const CSC_Matrix *J, - CSC_Matrix *C) -{ - const Dense_Matrix *dm = (const Dense_Matrix *) A; - int m = dm->base.m; - int n = dm->base.n; - int k = J->n; - - int i, j, s, block, block_start, block_end, start, end; - - double *j_dense = dm->work; - - /* for each column of J (and C) */ - for (j = 0; j < k; j++) - { - for (i = C->p[j]; i < C->p[j + 1]; i += m) - { - block = C->i[i] / m; - block_start = block * n; - block_end = block_start + n; - - start = J->p[j]; - end = J->p[j + 1]; - - while (start < J->p[j + 1] && J->i[start] < block_start) - { - start++; - } - - while (end > start && J->i[end - 1] >= block_end) - { - end--; - } - - int count = end - start; - - if (count == 1) - { - /* Fast path: C column segment = val * A[:, row_in_block] */ - int row_in_block = J->i[start] - block_start; - double val = J->x[start]; - cblas_dcopy(m, dm->x + row_in_block, n, C->x + i, 1); - if (val != 1.0) - { - cblas_dscal(m, val, C->x + i, 1); - } - } - else - { - /* scatter sparse J col into dense vector and then compute A @ - * j_dense */ - memset(j_dense, 0, n * sizeof(double)); - for (s = start; s < end; s++) - { - j_dense[J->i[s] - block_start] = J->x[s]; - } - - cblas_dgemv(CblasRowMajor, CblasNoTrans, m, n, 1.0, dm->x, n, - j_dense, 1, 0.0, C->x + i, 1); - } - } - } -} - static void dense_free(Matrix *A) { Dense_Matrix *dm = (Dense_Matrix *) A; @@ -185,8 +55,8 @@ Matrix *new_dense_matrix(int m, int n, const double *data) dm->base.m = m; dm->base.n = n; dm->base.block_left_mult_vec = dense_block_left_mult_vec; - dm->base.block_left_mult_sparsity = dense_block_left_mult_sparsity; - dm->base.block_left_mult_values = dense_block_left_mult_values; + dm->base.block_left_mult_sparsity = I_kron_A_alloc; + dm->base.block_left_mult_values = I_kron_A_fill_values; dm->base.free_fn = dense_free; dm->x = (double *) malloc(m * n * sizeof(double)); memcpy(dm->x, data, m * n * sizeof(double)); 
diff --git a/src/utils/linalg_dense_sparse_matmuls.c b/src/utils/linalg_dense_sparse_matmuls.c new file mode 100644 index 0000000..8fdef54 --- /dev/null +++ b/src/utils/linalg_dense_sparse_matmuls.c @@ -0,0 +1,346 @@ +/* + * Copyright 2026 Daniel Cederberg and William Zhang + * + * This file is part of the DNLP-differentiation-engine project. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "utils/CSC_Matrix.h" +#include "utils/CSR_Matrix.h" +#include "utils/cblas_wrapper.h" +#include "utils/dense_matrix.h" +#include "utils/iVec.h" +#include +#include +#include + +/* --------------------------------------------------------------- + * C = (I_p kron A) @ J via the polymorphic Matrix interface. + * A is dense m x n, J is (n*p) x k CSC, C is (m*p) x k CSC. 
+ * --------------------------------------------------------------- */ +CSC_Matrix *I_kron_A_alloc(const Matrix *A, const CSC_Matrix *J, int p) +{ + int m = A->m; + int n = A->n; + int i, j, jj, block, block_start, block_end, block_jj_start, row_offset; + + int *Cp = (int *) malloc((J->n + 1) * sizeof(int)); + iVec *Ci = iVec_new(J->n * m); + Cp[0] = 0; + + /* for each column of J */ + for (j = 0; j < J->n; j++) + { + /* if empty we continue */ + if (J->p[j] == J->p[j + 1]) + { + Cp[j + 1] = Cp[j]; + continue; + } + + /* process each of p blocks of rows in this column of J */ + jj = J->p[j]; + for (block = 0; block < p; block++) + { + block_start = block * n; + block_end = block_start + n; + while (jj < J->p[j + 1] && J->i[jj] < block_start) + { + jj++; + } + + block_jj_start = jj; + while (jj < J->p[j + 1] && J->i[jj] < block_end) + { + jj++; + } + + /* if no entries in this block, continue */ + if (jj == block_jj_start) + { + continue; + } + + /* dense A: all m rows contribute */ + row_offset = block * m; + for (i = 0; i < m; i++) + { + iVec_append(Ci, row_offset + i); + } + } + Cp[j + 1] = Ci->len; + } + + CSC_Matrix *C = new_csc_matrix(m * p, J->n, Ci->len); + memcpy(C->p, Cp, (J->n + 1) * sizeof(int)); + memcpy(C->i, Ci->data, Ci->len * sizeof(int)); + free(Cp); + iVec_free(Ci); + + return C; +} + +void I_kron_A_fill_values(const Matrix *A, const CSC_Matrix *J, CSC_Matrix *C) +{ + const Dense_Matrix *dm = (const Dense_Matrix *) A; + int m = dm->base.m; + int n = dm->base.n; + int k = J->n; + + int i, j, s, block, block_start, block_end, start, end; + + double *j_dense = dm->work; + + /* for each column of J (and C) */ + for (j = 0; j < k; j++) + { + for (i = C->p[j]; i < C->p[j + 1]; i += m) + { + block = C->i[i] / m; + block_start = block * n; + block_end = block_start + n; + + start = J->p[j]; + end = J->p[j + 1]; + + while (start < J->p[j + 1] && J->i[start] < block_start) + { + start++; + } + + while (end > start && J->i[end - 1] >= block_end) + { + end--; 
+ } + + int count = end - start; + + if (count == 1) + { + /* Fast path: C column segment = val * A[:, row_in_block] */ + int row_in_block = J->i[start] - block_start; + double val = J->x[start]; + cblas_dcopy(m, dm->x + row_in_block, n, C->x + i, 1); + if (val != 1.0) + { + cblas_dscal(m, val, C->x + i, 1); + } + } + else + { + /* scatter sparse J col into dense vector and then + * compute A @ j_dense */ + memset(j_dense, 0, n * sizeof(double)); + for (s = start; s < end; s++) + { + j_dense[J->i[s] - block_start] = J->x[s]; + } + + cblas_dgemv(CblasRowMajor, CblasNoTrans, m, n, 1.0, dm->x, n, + j_dense, 1, 0.0, C->x + i, 1); + } + } + } +} + +/* --------------------------------------------------------------- + * C = (Y^T kron I_m) @ J + * Y is k x n (col-major), J is (m*k) x p CSC, C is (m*n) x p CSR + * --------------------------------------------------------------- */ +CSR_Matrix *YT_kron_I_alloc(int m, int k, int n, const CSC_Matrix *J) +{ + (void) k; + /* C has n blocks of m rows. All rows at the same position within + their block (same blk_row) share the same column sparsity. + + Step 1: for each blk_row, find which columns of J contribute. + Column j contributes iff J[:,j] has any nonzero r + with r % m == blk_row. + Step 2: replicate each blk_row's pattern across all n blocks. 
*/ + + int i, j, ii, blk_row, total_nnz; + + // --------------------------------------------------------------- + // build sparsity pattern per blk_row + // --------------------------------------------------------------- + iVec **pattern = (iVec **) malloc(m * sizeof(iVec *)); + total_nnz = 0; + for (blk_row = 0; blk_row < m; blk_row++) + { + pattern[blk_row] = iVec_new(J->n); + + /* check each column of J */ + for (j = 0; j < J->n; j++) + { + for (ii = J->p[j]; ii < J->p[j + 1]; ii++) + { + if (J->i[ii] % m == blk_row) + { + iVec_append(pattern[blk_row], j); + break; + } + } + } + total_nnz += pattern[blk_row]->len * n; + } + + // --------------------------------------------------------------- + // replicate sparsity pattern across blocks + // --------------------------------------------------------------- + CSR_Matrix *C = new_csr_matrix(m * n, J->n, total_nnz); + int idx = 0; + for (i = 0; i < m * n; i++) + { + blk_row = i % m; + C->p[i] = idx; + int len = pattern[blk_row]->len; + memcpy(C->i + idx, pattern[blk_row]->data, len * sizeof(int)); + idx += len; + } + C->p[m * n] = idx; + assert(idx == total_nnz); + + for (blk_row = 0; blk_row < m; blk_row++) + { + iVec_free(pattern[blk_row]); + } + free(pattern); + return C; +} + +void YT_kron_I_fill_values(int m, int k, int n, const double *Y, const CSC_Matrix *J, + CSR_Matrix *C) +{ + assert(C->m == m * n); + /* C[i, j] = sum_l Y[l, blk] * J[blk_row + l*m, j] + * where blk_row = i % m, blk = i / m */ + + int i, j, ii, jj, blk, blk_row; + double sum; + + /* for each row i of C */ + for (i = 0; i < C->m; i++) + { + blk = i / m; /* which block of C */ + blk_row = i % m; /* which row within block */ + const double *Y_col = Y + blk * k; /* column blk of Y */ + + /* for each column j in row i of C */ + for (jj = C->p[i]; jj < C->p[i + 1]; jj++) + { + j = C->i[jj]; + sum = 0.0; + + /* matching J nonzeros in this column */ + for (ii = J->p[j]; ii < J->p[j + 1]; ii++) + { + if (J->i[ii] % m == blk_row) + { + sum += 
Y_col[J->i[ii] / m] * J->x[ii]; + } + } + C->x[jj] = sum; + } + } +} + +CSR_Matrix *I_kron_X_alloc(int m, int k, int n, const CSC_Matrix *J) +{ + /* Step 1: for each block, find which columns of J have any + * nonzero in row range [blk*k, blk*k + k). */ + int i, j, ii, blk; + + iVec **pattern = (iVec **) malloc(n * sizeof(iVec *)); + int total_nnz = 0; + for (blk = 0; blk < n; blk++) + { + int blk_start = blk * k; + int blk_end = blk_start + k; + pattern[blk] = iVec_new(J->n); + + /* check each column of J */ + for (j = 0; j < J->n; j++) + { + for (ii = J->p[j]; ii < J->p[j + 1]; ii++) + { + if (J->i[ii] >= blk_start && J->i[ii] < blk_end) + { + iVec_append(pattern[blk], j); + break; + } + } + } + total_nnz += (int) pattern[blk]->len * m; + } + + /* Step 2: replicate each block's pattern for all m rows + * within that block. */ + CSR_Matrix *C = new_csr_matrix(m * n, J->n, total_nnz); + int idx = 0; + for (i = 0; i < m * n; i++) + { + blk = i / m; + C->p[i] = idx; + int len = (int) pattern[blk]->len; + memcpy(C->i + idx, pattern[blk]->data, len * sizeof(int)); + idx += len; + } + C->p[m * n] = idx; + assert(idx == total_nnz); + + for (blk = 0; blk < n; blk++) + { + iVec_free(pattern[blk]); + } + free(pattern); + return C; +} + +void I_kron_X_fill_values(int m, int k, int n, const double *X, const CSC_Matrix *J, + CSR_Matrix *C) +{ + assert(C->m == m * n); + /* C[i, j] = sum_l X[blk_row + l*m] * J[blk*k + l, j] + * where blk = i / m, blk_row = i % m */ + + int i, j, ii, jj, blk, blk_row; + double sum; + + /* for each row i of C */ + for (i = 0; i < C->m; i++) + { + blk = i / m; /* which block of C */ + blk_row = i % m; /* which row within block */ + int blk_start = blk * k; + int blk_end = blk_start + k; + + /* for each column j in row i of C */ + for (jj = C->p[i]; jj < C->p[i + 1]; jj++) + { + j = C->i[jj]; + sum = 0.0; + + /* J nonzeros in column j within this block's row range */ + for (ii = J->p[j]; ii < J->p[j + 1]; ii++) + { + int r = J->i[ii]; + if (r >= 
blk_start && r < blk_end) + { + int l = r - blk_start; + sum += X[blk_row + l * m] * J->x[ii]; + } + } + C->x[jj] = sum; + } + } +} diff --git a/tests/all_tests.c b/tests/all_tests.c index a92a6d9..92be5f0 100644 --- a/tests/all_tests.c +++ b/tests/all_tests.c @@ -55,6 +55,7 @@ #include "utils/test_csr_csc_conversion.h" #include "utils/test_csr_matrix.h" #include "utils/test_linalg_sparse_matmuls.h" +#include "utils/test_linalg_utils_matmul_chain_rule.h" #include "utils/test_matrix.h" #include "wsum_hess/affine/test_broadcast.h" #include "wsum_hess/affine/test_const_scalar_mult.h" @@ -318,6 +319,10 @@ int main(void) mu_run_test(test_sparse_vs_dense_mult_vec, tests_run); mu_run_test(test_dense_matrix_trans, tests_run); mu_run_test(test_sparse_vs_dense_mult_vec_blocks, tests_run); + mu_run_test(test_YT_kron_I, tests_run); + mu_run_test(test_YT_kron_I_larger, tests_run); + mu_run_test(test_I_kron_X, tests_run); + mu_run_test(test_I_kron_X_larger, tests_run); printf("\n--- Numerical Diff Tests ---\n"); mu_run_test(test_check_jacobian_composite_exp, tests_run); diff --git a/tests/utils/test_linalg_utils_matmul_chain_rule.h b/tests/utils/test_linalg_utils_matmul_chain_rule.h new file mode 100644 index 0000000..0c66589 --- /dev/null +++ b/tests/utils/test_linalg_utils_matmul_chain_rule.h @@ -0,0 +1,228 @@ +#include +#include +#include + +#include "minunit.h" +#include "test_helpers.h" +#include "utils/CSC_Matrix.h" +#include "utils/CSR_Matrix.h" +#include "utils/linalg_dense_sparse_matmuls.h" + +/* Test YT_kron_I_alloc and YT_kron_I_fill_values + * + * C = (Y^T kron I_m) @ J + * m=2, k=2, n=2, p=3 + * + * Y (k x n, col-major [1,2,3,4]): + * [1 3] + * [2 4] + * + * J (mk=4 x p=3, CSC): + * [1 0 2] + * [0 1 0] + * [3 0 0] + * [0 0 1] + * + * C = (Y^T kron I_2) @ J: + * [ 7 0 2] + * [ 0 1 2] + * [15 0 6] + * [ 0 3 4] + */ +const char *test_YT_kron_I(void) +{ + int m = 2, k = 2, n = 2; + + /* J is 4x3 CSC */ + CSC_Matrix *J = new_csc_matrix(4, 3, 5); + int Jp[4] = {0, 2, 3, 
5}; + int Ji[5] = {0, 2, 1, 0, 3}; + double Jx[5] = {1.0, 3.0, 1.0, 2.0, 1.0}; + memcpy(J->p, Jp, 4 * sizeof(int)); + memcpy(J->i, Ji, 5 * sizeof(int)); + memcpy(J->x, Jx, 5 * sizeof(double)); + + /* Y col-major: Y[0,0]=1, Y[1,0]=2, Y[0,1]=3, Y[1,1]=4 */ + double Y[4] = {1.0, 2.0, 3.0, 4.0}; + + CSR_Matrix *C = YT_kron_I_alloc(m, k, n, J); + + /* Expected CSR (from scipy) */ + int exp_p[5] = {0, 2, 4, 6, 8}; + int exp_i[8] = {0, 2, 1, 2, 0, 2, 1, 2}; + double exp_x[8] = {7.0, 2.0, 1.0, 2.0, 15.0, 6.0, 3.0, 4.0}; + + mu_assert("C dims", C->m == 4 && C->n == 3); + mu_assert("C nnz", C->nnz == 8); + mu_assert("C row ptrs", cmp_int_array(C->p, exp_p, 5)); + mu_assert("C col indices", cmp_int_array(C->i, exp_i, 8)); + + YT_kron_I_fill_values(m, k, n, Y, J, C); + mu_assert("C values", cmp_double_array(C->x, exp_x, 8)); + + free_csr_matrix(C); + free_csc_matrix(J); + return NULL; +} + +/* Test YT_kron_I with larger dimensions: m=3, k=2, n=3, p=4 + * + * Y (k=2 x n=3, col-major [1,3,0.5,1,2,0.5]): + * [1.0 0.5 2.0] + * [3.0 1.0 0.5] + * + * J (mk=6 x p=4, CSC): + * [1 0 0 2] + * [0 0 1 0] + * [0 3 0 0] + * [2 0 0 1] + * [0 1 0 0] + * [0 0 4 0] + * + * C = (Y^T kron I_3) @ J is 9 x 4 + */ +const char *test_YT_kron_I_larger(void) +{ + int m = 3, k = 2, n = 3; + + /* J is 6x4 CSC */ + CSC_Matrix *J = new_csc_matrix(6, 4, 8); + int Jp[5] = {0, 2, 4, 6, 8}; + int Ji[8] = {0, 3, 2, 4, 1, 5, 0, 3}; + double Jx[8] = {1.0, 2.0, 3.0, 1.0, 1.0, 4.0, 2.0, 1.0}; + memcpy(J->p, Jp, 5 * sizeof(int)); + memcpy(J->i, Ji, 8 * sizeof(int)); + memcpy(J->x, Jx, 8 * sizeof(double)); + + /* Y col-major */ + double Y[6] = {1.0, 3.0, 0.5, 1.0, 2.0, 0.5}; + + CSR_Matrix *C = YT_kron_I_alloc(m, k, n, J); + + /* Expected CSR (from scipy) */ + int exp_p[10] = {0, 2, 4, 6, 8, 10, 12, 14, 16, 18}; + int exp_i[18] = {0, 3, 1, 2, 1, 2, 0, 3, 1, 2, 1, 2, 0, 3, 1, 2, 1, 2}; + double exp_x[18] = {7.0, 5.0, 3.0, 1.0, 3.0, 12.0, 2.5, 2.0, 1.0, + 0.5, 1.5, 4.0, 3.0, 4.5, 0.5, 2.0, 6.0, 2.0}; + + mu_assert("C2 
dims", C->m == 9 && C->n == 4); + mu_assert("C2 nnz", C->nnz == 18); + mu_assert("C2 row ptrs", cmp_int_array(C->p, exp_p, 10)); + mu_assert("C2 col indices", cmp_int_array(C->i, exp_i, 18)); + + YT_kron_I_fill_values(m, k, n, Y, J, C); + mu_assert("C2 values", cmp_double_array(C->x, exp_x, 18)); + + free_csr_matrix(C); + free_csc_matrix(J); + return NULL; +} + +/* Test I_kron_X_alloc and I_kron_X_fill_values + * + * C = (I_n kron X) @ J + * m=2, k=2, n=2, p=3 + * + * X (m x k, col-major [1,2,3,4]): + * [1 3] + * [2 4] + * + * J (kn=4 x p=3, CSC): + * [1 0 2] + * [0 1 0] + * [3 0 0] + * [0 0 1] + * + * C = (I_2 kron X) @ J: + * [1 3 2] + * [2 4 4] + * [3 0 3] + * [6 0 4] + */ +const char *test_I_kron_X(void) +{ + int m = 2, k = 2, n = 2; + + /* J is 4x3 CSC */ + CSC_Matrix *J = new_csc_matrix(4, 3, 5); + int Jp[4] = {0, 2, 3, 5}; + int Ji[5] = {0, 2, 1, 0, 3}; + double Jx[5] = {1.0, 3.0, 1.0, 2.0, 1.0}; + memcpy(J->p, Jp, 4 * sizeof(int)); + memcpy(J->i, Ji, 5 * sizeof(int)); + memcpy(J->x, Jx, 5 * sizeof(double)); + + /* X col-major */ + double X[4] = {1.0, 2.0, 3.0, 4.0}; + + CSR_Matrix *C = I_kron_X_alloc(m, k, n, J); + + /* Expected CSR */ + int exp_p[5] = {0, 3, 6, 8, 10}; + int exp_i[10] = {0, 1, 2, 0, 1, 2, 0, 2, 0, 2}; + double exp_x[10] = {1.0, 3.0, 2.0, 2.0, 4.0, 4.0, 3.0, 3.0, 6.0, 4.0}; + + mu_assert("C dims", C->m == 4 && C->n == 3); + mu_assert("C nnz", C->nnz == 10); + mu_assert("C row ptrs", cmp_int_array(C->p, exp_p, 5)); + mu_assert("C col indices", cmp_int_array(C->i, exp_i, 10)); + + I_kron_X_fill_values(m, k, n, X, J, C); + mu_assert("C values", cmp_double_array(C->x, exp_x, 10)); + + free_csr_matrix(C); + free_csc_matrix(J); + return NULL; +} + +/* Test I_kron_X with larger dimensions: m=3, k=2, n=2, p=4 + * + * X (m=3 x k=2, col-major [1,2,3,0.5,1,0.5]): + * [1.0 0.5] + * [2.0 1.0] + * [3.0 0.5] + * + * J (kn=4 x p=4, CSC): + * [1 0 0 2] + * [0 3 1 0] + * [0 0 4 0] + * [2 0 0 1] + * + * C = (I_2 kron X) @ J is 6 x 4 + */ +const char 
*test_I_kron_X_larger(void) +{ + int m = 3, k = 2, n = 2; + + /* J is 4x4 CSC */ + CSC_Matrix *J = new_csc_matrix(4, 4, 7); + int Jp[5] = {0, 2, 3, 5, 7}; + int Ji[7] = {0, 3, 1, 1, 2, 0, 3}; + double Jx[7] = {1.0, 2.0, 3.0, 1.0, 4.0, 2.0, 1.0}; + memcpy(J->p, Jp, 5 * sizeof(int)); + memcpy(J->i, Ji, 7 * sizeof(int)); + memcpy(J->x, Jx, 7 * sizeof(double)); + + /* X col-major */ + double X[6] = {1.0, 2.0, 3.0, 0.5, 1.0, 0.5}; + + CSR_Matrix *C = I_kron_X_alloc(m, k, n, J); + + /* Expected CSR */ + int exp_p[7] = {0, 4, 8, 12, 15, 18, 21}; + int exp_i[21] = {0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 2, 3, 0, 2, 3, 0, 2, 3}; + double exp_x[21] = {1.0, 1.5, 0.5, 2.0, 2.0, 3.0, 1.0, 4.0, 3.0, 1.5, 0.5, + 6.0, 1.0, 4.0, 0.5, 2.0, 8.0, 1.0, 1.0, 12.0, 0.5}; + + mu_assert("C2 dims", C->m == 6 && C->n == 4); + mu_assert("C2 nnz", C->nnz == 21); + mu_assert("C2 row ptrs", cmp_int_array(C->p, exp_p, 7)); + mu_assert("C2 col indices", cmp_int_array(C->i, exp_i, 21)); + + I_kron_X_fill_values(m, k, n, X, J, C); + mu_assert("C2 values", cmp_double_array(C->x, exp_x, 21)); + + free_csr_matrix(C); + free_csc_matrix(J); + return NULL; +} diff --git a/tests/utils/test_matrix.h b/tests/utils/test_matrix.h index 8add477..c329a16 100644 --- a/tests/utils/test_matrix.h +++ b/tests/utils/test_matrix.h @@ -3,7 +3,7 @@ #include "minunit.h" #include "test_helpers.h" -#include "utils/matrix.h" +#include "utils/dense_matrix.h" #include #include From 89d2010b0863125ff47c9fa1debce7f34b0cc9c1 Mon Sep 17 00:00:00 2001 From: dance858 Date: Mon, 30 Mar 2026 09:20:16 -0700 Subject: [PATCH 02/13] add chain rule for jacobian + tests --- include/subexpr.h | 10 ++ src/bivariate_full_dom/matmul.c | 162 +++++++++++------- tests/all_tests.c | 5 + .../composite/test_chain_rule_jacobian.h | 99 +++++++++++ 4 files changed, 211 insertions(+), 65 deletions(-) diff --git a/include/subexpr.h b/include/subexpr.h index c87da07..64d32d5 100644 --- a/include/subexpr.h +++ b/include/subexpr.h @@ -131,6 +131,16 @@ 
typedef struct right_matmul_expr CSC_Matrix *CSC_work; } right_matmul_expr; +/* Bivariate matrix multiplication: Z = f(u) @ g(u) where both children + * may be composite expressions. Stores intermediate CSR results for the + * Jacobian chain rule: J_Z = (Y^T x I_m) @ J_f + (I_n x X) @ J_g */ +typedef struct matmul_expr +{ + expr base; + CSR_Matrix *term1_CSR; /* (Y^T x I_m) @ J_left */ + CSR_Matrix *term2_CSR; /* (I_n x X) @ J_right */ +} matmul_expr; + /* Constant scalar multiplication: y = a * child where a is a constant double */ typedef struct const_scalar_mult_expr { diff --git a/src/bivariate_full_dom/matmul.c b/src/bivariate_full_dom/matmul.c index 97e6ba9..fa83b7c 100644 --- a/src/bivariate_full_dom/matmul.c +++ b/src/bivariate_full_dom/matmul.c @@ -17,6 +17,9 @@ */ #include "bivariate_full_dom.h" #include "subexpr.h" +#include "utils/CSC_Matrix.h" +#include "utils/CSR_sum.h" +#include "utils/linalg_dense_sparse_matmuls.h" #include "utils/mini_numpy.h" #include #include @@ -57,51 +60,68 @@ static void jacobian_init_impl(expr *node) int m = x->d1; int k = x->d2; int n = y->d2; - int nnz = m * n * 2 * k; - node->jacobian = new_csr_matrix(node->size, node->n_vars, nnz); - /* fill sparsity pattern */ - int nnz_idx = 0; - for (int i = 0; i < node->size; i++) + if (x->var_id != NOT_A_VARIABLE && y->var_id != NOT_A_VARIABLE && + x->var_id != y->var_id) { - /* Convert flat index to (row, col) in Z */ - int row = i % m; - int col = i / m; + /* both children are different leaf variables */ + int nnz = m * n * 2 * k; + node->jacobian = new_csr_matrix(node->size, node->n_vars, nnz); - node->jacobian->p[i] = nnz_idx; - - /* X has lower var_id */ - if (x->var_id < y->var_id) + int nnz_idx = 0; + for (int i = 0; i < node->size; i++) { - /* sparsity pattern of kron(YT, I) for this row */ - for (int j = 0; j < k; j++) - { - node->jacobian->i[nnz_idx++] = x->var_id + row + j * m; - } + int row = i % m; + int col = i / m; - /* sparsity pattern of kron(I, X) for this row */ - 
for (int j = 0; j < k; j++) - { - node->jacobian->i[nnz_idx++] = y->var_id + col * k + j; - } - } - else /* Y has lower var_id */ - { - /* sparsity pattern of kron(I, X) for this row */ - for (int j = 0; j < k; j++) + node->jacobian->p[i] = nnz_idx; + + if (x->var_id < y->var_id) { - node->jacobian->i[nnz_idx++] = y->var_id + col * k + j; + for (int j = 0; j < k; j++) + { + node->jacobian->i[nnz_idx++] = x->var_id + row + j * m; + } + for (int j = 0; j < k; j++) + { + node->jacobian->i[nnz_idx++] = y->var_id + col * k + j; + } } - - /* sparsity pattern of kron(YT, I) for this row */ - for (int j = 0; j < k; j++) + else { - node->jacobian->i[nnz_idx++] = x->var_id + row + j * m; + for (int j = 0; j < k; j++) + { + node->jacobian->i[nnz_idx++] = y->var_id + col * k + j; + } + for (int j = 0; j < k; j++) + { + node->jacobian->i[nnz_idx++] = x->var_id + row + j * m; + } } } + node->jacobian->p[node->size] = nnz_idx; + assert(nnz_idx == nnz); + } + else + { + /* chain rule: the jacobian of f(u) @ g(u) with f(u) and g(u) matrices + is term1 + term2 where term1 = (g(u)^T kron I) @ J_f and + term2 = (I kron f(u)) @ J_g. 
*/ + matmul_expr *mnode = (matmul_expr *) node; + + jacobian_init(x); + jacobian_init(y); + jacobian_csc_init(x); + jacobian_csc_init(y); + + mnode->term1_CSR = YT_kron_I_alloc(m, k, n, x->work->jacobian_csc); + mnode->term2_CSR = I_kron_X_alloc(m, k, n, y->work->jacobian_csc); + + int max_nnz = mnode->term1_CSR->nnz + mnode->term2_CSR->nnz; + node->jacobian = new_csr_matrix(node->size, node->n_vars, max_nnz); + sum_csr_matrices_fill_sparsity(mnode->term1_CSR, mnode->term2_CSR, + node->jacobian); } - node->jacobian->p[node->size] = nnz_idx; - assert(nnz_idx == nnz); } static void eval_jacobian(expr *node) @@ -112,38 +132,59 @@ static void eval_jacobian(expr *node) /* dimensions: X is m x k, Y is k x n, Z is m x n */ int m = x->d1; int k = x->d2; - double *Jx = node->jacobian->x; + int n = y->d2; - /* fill values row-by-row */ - for (int i = 0; i < node->size; i++) + if (x->var_id != NOT_A_VARIABLE && y->var_id != NOT_A_VARIABLE && + x->var_id != y->var_id) { - int row = i % m; /* row in Z */ - int col = i / m; /* col in Z */ - int pos = node->jacobian->p[i]; + /* both children are different leaf variables */ + double *Jx = node->jacobian->x; - if (x->var_id < y->var_id) + for (int i = 0; i < node->size; i++) { - /* contribution to this row from YT */ - memcpy(Jx + pos, y->value + col * k, k * sizeof(double)); + int row = i % m; + int col = i / m; + int pos = node->jacobian->p[i]; - /* contribution to this row from X */ - for (int j = 0; j < k; j++) + if (x->var_id < y->var_id) { - Jx[pos + k + j] = x->value[row + j * m]; + memcpy(Jx + pos, y->value + col * k, k * sizeof(double)); + for (int j = 0; j < k; j++) + { + Jx[pos + k + j] = x->value[row + j * m]; + } } - } - else - { - /* contribution to this row from X */ - for (int j = 0; j < k; j++) + else { - Jx[pos + j] = x->value[row + j * m]; + for (int j = 0; j < k; j++) + { + Jx[pos + j] = x->value[row + j * m]; + } + memcpy(Jx + pos + k, y->value + col * k, k * sizeof(double)); } - - /* contribution to this row 
from YT */ - memcpy(Jx + pos + k, y->value + col * k, k * sizeof(double)); } } + else + { + /* composite case */ + matmul_expr *mnode = (matmul_expr *) node; + + x->eval_jacobian(x); + y->eval_jacobian(y); + + CSC_Matrix *Jx_csc = x->work->jacobian_csc; + CSC_Matrix *Jy_csc = y->work->jacobian_csc; + + /* refresh children's CSC values */ + csr_to_csc_fill_values(x->jacobian, Jx_csc, x->work->csc_work); + csr_to_csc_fill_values(y->jacobian, Jy_csc, y->work->csc_work); + + /* compute term1, term2, and sum */ + YT_kron_I_fill_values(m, k, n, y->value, Jx_csc, mnode->term1_CSR); + I_kron_X_fill_values(m, k, n, x->value, Jy_csc, mnode->term2_CSR); + sum_csr_matrices_fill_values(mnode->term1_CSR, mnode->term2_CSR, + node->jacobian); + } } static void wsum_hess_init_impl(expr *node) @@ -317,17 +358,8 @@ expr *new_matmul(expr *x, expr *y) exit(1); } - /* verify both are variables and not the same variable */ - if (x->var_id == NOT_A_VARIABLE || y->var_id == NOT_A_VARIABLE || - x->var_id == y->var_id) - { - fprintf(stderr, "Error in new_matmul: operands must be variables and not " - "the same variable\n"); - exit(1); - } - /* Allocate the expression node */ - expr *node = (expr *) calloc(1, sizeof(expr)); + expr *node = (expr *) calloc(1, sizeof(matmul_expr)); /* Initialize with d1 = x->d1, d2 = y->d2 (result is m x n) */ init_expr(node, x->d1, y->d2, x->n_vars, forward, jacobian_init_impl, diff --git a/tests/all_tests.c b/tests/all_tests.c index 92be5f0..ba59bc0 100644 --- a/tests/all_tests.c +++ b/tests/all_tests.c @@ -140,6 +140,11 @@ int main(void) mu_run_test(test_jacobian_AX_BX_multiply, tests_run); mu_run_test(test_jacobian_quad_form_Ax, tests_run); mu_run_test(test_jacobian_quad_form_exp, tests_run); + mu_run_test(test_jacobian_matmul_exp_exp, tests_run); + mu_run_test(test_jacobian_matmul_sin_cos, tests_run); + mu_run_test(test_jacobian_matmul_Ax_By, tests_run); + mu_run_test(test_jacobian_matmul_sin_Ax_cos_Bx, tests_run); + mu_run_test(test_jacobian_matmul_X_X, 
tests_run); mu_run_test(test_jacobian_composite_exp_add, tests_run); mu_run_test(test_jacobian_const_scalar_mult_log_vector, tests_run); mu_run_test(test_jacobian_const_scalar_mult_log_matrix, tests_run); diff --git a/tests/jacobian_tests/composite/test_chain_rule_jacobian.h b/tests/jacobian_tests/composite/test_chain_rule_jacobian.h index 9e916af..710aa89 100644 --- a/tests/jacobian_tests/composite/test_chain_rule_jacobian.h +++ b/tests/jacobian_tests/composite/test_chain_rule_jacobian.h @@ -170,3 +170,102 @@ const char *test_jacobian_quad_form_exp(void) free_csr_matrix(Q); return 0; } + +const char *test_jacobian_matmul_exp_exp(void) +{ + /* Z = exp(X) @ exp(Y), X is 2x3, Y is 3x2 */ + double u_vals[12] = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2}; + + expr *X = new_variable(2, 3, 0, 12); + expr *Y = new_variable(3, 2, 6, 12); + expr *exp_X = new_exp(X); + expr *exp_Y = new_exp(Y); + expr *Z = new_matmul(exp_X, exp_Y); + + mu_assert("check_jacobian failed", + check_jacobian(Z, u_vals, NUMERICAL_DIFF_DEFAULT_H)); + + free_expr(Z); + return 0; +} + +const char *test_jacobian_matmul_sin_cos(void) +{ + /* Z = sin(X) @ cos(Y), X is 2x2, Y is 2x3 */ + double u_vals[10] = {0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0}; + + expr *X = new_variable(2, 2, 0, 10); + expr *Y = new_variable(2, 3, 4, 10); + expr *sin_X = new_sin(X); + expr *cos_Y = new_cos(Y); + expr *Z = new_matmul(sin_X, cos_Y); + + mu_assert("check_jacobian failed", + check_jacobian(Z, u_vals, NUMERICAL_DIFF_DEFAULT_H)); + + free_expr(Z); + return 0; +} + +const char *test_jacobian_matmul_Ax_By(void) +{ + /* Z = (A @ X) @ (B @ Y) with constant matrices A, B */ + double u_vals[10] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0}; + + CSR_Matrix *A = new_csr_random(3, 2, 1.0); + CSR_Matrix *B = new_csr_random(2, 3, 1.0); + + expr *X = new_variable(2, 2, 0, 10); /* 2x2, vars 0-3 */ + expr *Y = new_variable(3, 2, 4, 10); /* 3x2, vars 4-9 */ + expr *AX = new_left_matmul(X, A); /* 3x2 */ 
+ expr *BY = new_left_matmul(Y, B); /* 2x2 */ + expr *Z = new_matmul(AX, BY); /* 3x2 */ + + mu_assert("check_jacobian failed", + check_jacobian(Z, u_vals, NUMERICAL_DIFF_DEFAULT_H)); + + free_expr(Z); + free_csr_matrix(A); + free_csr_matrix(B); + return 0; +} + +const char *test_jacobian_matmul_sin_Ax_cos_Bx(void) +{ + /* Z = sin(A @ X) @ cos(B @ X), shared variable X */ + double u_vals[6] = {0.5, 1.0, 1.5, 2.0, 2.5, 3.0}; + + CSR_Matrix *A = new_csr_random(2, 3, 1.0); + CSR_Matrix *B = new_csr_random(2, 3, 1.0); + + expr *X = new_variable(3, 2, 0, 6); /* 3x2, vars 0-5 */ + expr *AX = new_left_matmul(X, A); /* 2x2 */ + expr *BX = new_left_matmul(X, B); /* 2x2 */ + expr *sin_AX = new_sin(AX); /* 2x2 */ + expr *cos_BX = new_cos(BX); /* 2x2 */ + expr *Z = new_matmul(sin_AX, cos_BX); /* 2x2 */ + + mu_assert("check_jacobian failed", + check_jacobian(Z, u_vals, NUMERICAL_DIFF_DEFAULT_H)); + + free_expr(Z); + free_csr_matrix(A); + free_csr_matrix(B); + return 0; +} + +const char *test_jacobian_matmul_X_X(void) +{ + /* Z = X @ X, same leaf variable as both children */ + double u_vals[4] = {1.0, 2.0, 3.0, 4.0}; + + expr *X = new_variable(2, 2, 0, 4); + expr *Z = new_matmul(X, X); /* 2x2 */ + + mu_assert("check_jacobian failed", + check_jacobian(Z, u_vals, NUMERICAL_DIFF_DEFAULT_H)); + + free_expr(Z); + return 0; +} + From b2cb5475b39b54e2d2063520dec9092c65d581ed Mon Sep 17 00:00:00 2001 From: dance858 Date: Mon, 30 Mar 2026 09:33:23 -0700 Subject: [PATCH 03/13] add comment --- src/bivariate_full_dom/matmul.c | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/bivariate_full_dom/matmul.c b/src/bivariate_full_dom/matmul.c index fa83b7c..b85af4d 100644 --- a/src/bivariate_full_dom/matmul.c +++ b/src/bivariate_full_dom/matmul.c @@ -175,7 +175,8 @@ static void eval_jacobian(expr *node) CSC_Matrix *Jx_csc = x->work->jacobian_csc; CSC_Matrix *Jy_csc = y->work->jacobian_csc; - /* refresh children's CSC values */ + /* refresh children's CSC values (TODO: is this 
necessary in the sense that + * the kron infrastructure requires it?) */ csr_to_csc_fill_values(x->jacobian, Jx_csc, x->work->csc_work); csr_to_csc_fill_values(y->jacobian, Jy_csc, y->work->csc_work); From 8239141ffa3849d85085f04153750dfea90b57ef Mon Sep 17 00:00:00 2001 From: dance858 Date: Mon, 30 Mar 2026 10:13:27 -0700 Subject: [PATCH 04/13] chain rule for hessian first draft --- include/subexpr.h | 20 +- src/bivariate_full_dom/matmul.c | 434 +++++++++++++----- tests/all_tests.c | 5 + .../composite/test_chain_rule_jacobian.h | 13 +- .../composite/test_chain_rule_wsum_hess.h | 103 +++++ 5 files changed, 437 insertions(+), 138 deletions(-) diff --git a/include/subexpr.h b/include/subexpr.h index 64d32d5..3ef8f86 100644 --- a/include/subexpr.h +++ b/include/subexpr.h @@ -132,13 +132,25 @@ typedef struct right_matmul_expr } right_matmul_expr; /* Bivariate matrix multiplication: Z = f(u) @ g(u) where both children - * may be composite expressions. Stores intermediate CSR results for the - * Jacobian chain rule: J_Z = (Y^T x I_m) @ J_f + (I_n x X) @ J_g */ + * may be composite expressions. 
*/ typedef struct matmul_expr { expr base; - CSR_Matrix *term1_CSR; /* (Y^T x I_m) @ J_left */ - CSR_Matrix *term2_CSR; /* (I_n x X) @ J_right */ + /* Jacobian workspace */ + CSR_Matrix *term1_CSR; /* (Y^T x I_m) @ J_f */ + CSR_Matrix *term2_CSR; /* (I_n x X) @ J_g */ + + /* Hessian workspace (composite only) */ + CSR_Matrix *B; /* cross-Hessian B(w), mk x kn */ + CSR_Matrix *BJ_g; /* B @ J_g */ + CSC_Matrix *BJ_g_CSC; /* BJ_g in CSC */ + int *BJ_g_csc_work; /* CSR-to-CSC workspace */ + CSR_Matrix *C; /* J_f^T @ B @ J_g */ + CSR_Matrix *CT; /* C^T */ + int *idx_map_C; + int *idx_map_CT; + int *idx_map_Hf; + int *idx_map_Hg; } matmul_expr; /* Constant scalar multiplication: y = a * child where a is a constant double */ diff --git a/src/bivariate_full_dom/matmul.c b/src/bivariate_full_dom/matmul.c index b85af4d..263e61d 100644 --- a/src/bivariate_full_dom/matmul.c +++ b/src/bivariate_full_dom/matmul.c @@ -18,14 +18,70 @@ #include "bivariate_full_dom.h" #include "subexpr.h" #include "utils/CSC_Matrix.h" +#include "utils/CSR_Matrix.h" #include "utils/CSR_sum.h" #include "utils/linalg_dense_sparse_matmuls.h" +#include "utils/linalg_sparse_matmuls.h" #include "utils/mini_numpy.h" +#include "utils/utils.h" #include #include #include #include +// ------------------------------------------------------------------------------ +// Helpers for the cross-Hessian B(w) of the bilinear map X @ Y. +// B is mk x kn with B[row + j*m, j + col*k] = w[row + col*m]. +// Each row has exactly n nonzeros. 
+// ------------------------------------------------------------------------------ + +static CSR_Matrix *build_cross_hessian_sparsity(int m, int k, int n) +{ + int total_nnz = m * k * n; + CSR_Matrix *B = new_csr_matrix(m * k, k * n, total_nnz); + int idx = 0; + + for (int j = 0; j < k; j++) + { + for (int row = 0; row < m; row++) + { + B->p[row + j * m] = idx; + for (int col = 0; col < n; col++) + { + B->i[idx++] = j + col * k; + } + } + } + B->p[m * k] = idx; + assert(idx == total_nnz); + return B; +} + +static void fill_cross_hessian_values(int m, int k, int n, const double *w, + CSR_Matrix *B) +{ + int idx = 0; + for (int j = 0; j < k; j++) + { + for (int row = 0; row < m; row++) + { + for (int col = 0; col < n; col++) + { + B->x[idx++] = w[row + col * m]; + } + } + } +} + +static void accumulate_mapped(double *dest, const CSR_Matrix *src, + const int *idx_map) +{ + for (int j = 0; j < src->nnz; j++) + { + dest[idx_map[j]] += src->x[j]; + } +} + // ------------------------------------------------------------------------------ // Implementation of matrix multiplication: Z = X @ Y // where X is m x k and Y is k x n, producing Z which is m x n @@ -51,149 +107,141 @@ static bool is_affine(const expr *node) return false; } -static void jacobian_init_impl(expr *node) +// ------------------------------------------------------------------------------ +// No chain rule: both children are different leaf variables +// ------------------------------------------------------------------------------ + +static void jacobian_init_no_chain_rule(expr *node) { expr *x = node->left; expr *y = node->right; - - /* dimensions: X is m x k, Y is k x n, Z is m x n */ int m = x->d1; int k = x->d2; int n = y->d2; + int nnz = m * n * 2 * k; + node->jacobian = new_csr_matrix(node->size, node->n_vars, nnz); - if (x->var_id != NOT_A_VARIABLE && y->var_id != NOT_A_VARIABLE && - x->var_id != y->var_id) + int nnz_idx = 0; + for (int i = 0; i < node->size; i++) { - /* both children are 
differentleaf variables */ - int nnz = m * n * 2 * k; - node->jacobian = new_csr_matrix(node->size, node->n_vars, nnz); - - int nnz_idx = 0; - for (int i = 0; i < node->size; i++) - { - int row = i % m; - int col = i / m; + int row = i % m; + int col = i / m; - node->jacobian->p[i] = nnz_idx; + node->jacobian->p[i] = nnz_idx; - if (x->var_id < y->var_id) + if (x->var_id < y->var_id) + { + for (int j = 0; j < k; j++) { - for (int j = 0; j < k; j++) - { - node->jacobian->i[nnz_idx++] = x->var_id + row + j * m; - } - for (int j = 0; j < k; j++) - { - node->jacobian->i[nnz_idx++] = y->var_id + col * k + j; - } + node->jacobian->i[nnz_idx++] = x->var_id + row + j * m; } - else + for (int j = 0; j < k; j++) { - for (int j = 0; j < k; j++) - { - node->jacobian->i[nnz_idx++] = y->var_id + col * k + j; - } - for (int j = 0; j < k; j++) - { - node->jacobian->i[nnz_idx++] = x->var_id + row + j * m; - } + node->jacobian->i[nnz_idx++] = y->var_id + col * k + j; + } + } + else + { + for (int j = 0; j < k; j++) + { + node->jacobian->i[nnz_idx++] = y->var_id + col * k + j; + } + for (int j = 0; j < k; j++) + { + node->jacobian->i[nnz_idx++] = x->var_id + row + j * m; } } - node->jacobian->p[node->size] = nnz_idx; - assert(nnz_idx == nnz); - } - else - { - /* chain rule: the jacobian of f(u) @ g(u) with f(u) and g(u) matrices - is term1 + term2 where term1 = (g(u)^T kron I) @ J_f and - term2 = (I kron f(u)) @ J_g. 
*/ - matmul_expr *mnode = (matmul_expr *) node; - - jacobian_init(x); - jacobian_init(y); - jacobian_csc_init(x); - jacobian_csc_init(y); - - mnode->term1_CSR = YT_kron_I_alloc(m, k, n, x->work->jacobian_csc); - mnode->term2_CSR = I_kron_X_alloc(m, k, n, y->work->jacobian_csc); - - int max_nnz = mnode->term1_CSR->nnz + mnode->term2_CSR->nnz; - node->jacobian = new_csr_matrix(node->size, node->n_vars, max_nnz); - sum_csr_matrices_fill_sparsity(mnode->term1_CSR, mnode->term2_CSR, - node->jacobian); } + node->jacobian->p[node->size] = nnz_idx; + assert(nnz_idx == nnz); } -static void eval_jacobian(expr *node) +static void eval_jacobian_no_chain_rule(expr *node) { expr *x = node->left; expr *y = node->right; - - /* dimensions: X is m x k, Y is k x n, Z is m x n */ int m = x->d1; int k = x->d2; - int n = y->d2; + double *Jx = node->jacobian->x; - if (x->var_id != NOT_A_VARIABLE && y->var_id != NOT_A_VARIABLE && - x->var_id != y->var_id) + for (int i = 0; i < node->size; i++) { - /* both children are different leaf variables */ - double *Jx = node->jacobian->x; + int row = i % m; + int col = i / m; + int pos = node->jacobian->p[i]; - for (int i = 0; i < node->size; i++) + if (x->var_id < y->var_id) { - int row = i % m; - int col = i / m; - int pos = node->jacobian->p[i]; - - if (x->var_id < y->var_id) + memcpy(Jx + pos, y->value + col * k, k * sizeof(double)); + for (int j = 0; j < k; j++) { - memcpy(Jx + pos, y->value + col * k, k * sizeof(double)); - for (int j = 0; j < k; j++) - { - Jx[pos + k + j] = x->value[row + j * m]; - } + Jx[pos + k + j] = x->value[row + j * m]; } - else + } + else + { + for (int j = 0; j < k; j++) { - for (int j = 0; j < k; j++) - { - Jx[pos + j] = x->value[row + j * m]; - } - memcpy(Jx + pos + k, y->value + col * k, k * sizeof(double)); + Jx[pos + j] = x->value[row + j * m]; } + memcpy(Jx + pos + k, y->value + col * k, k * sizeof(double)); } } - else - { - /* composite case */ - matmul_expr *mnode = (matmul_expr *) node; +} - 
x->eval_jacobian(x); - y->eval_jacobian(y); +// ------------------------------------------------------------------------------ +// Chain rule: at least one child is composite, or same variable +// ------------------------------------------------------------------------------ - CSC_Matrix *Jx_csc = x->work->jacobian_csc; - CSC_Matrix *Jy_csc = y->work->jacobian_csc; +static void jacobian_init_chain_rule(expr *node) +{ + expr *x = node->left; + expr *y = node->right; + matmul_expr *mnode = (matmul_expr *) node; + int m = x->d1; + int k = x->d2; + int n = y->d2; - /* refresh children's CSC values (TODO: is this necessary in the sense that - * the kron infrastructure requires it?) */ - csr_to_csc_fill_values(x->jacobian, Jx_csc, x->work->csc_work); - csr_to_csc_fill_values(y->jacobian, Jy_csc, y->work->csc_work); + jacobian_init(x); + jacobian_init(y); + jacobian_csc_init(x); + jacobian_csc_init(y); - /* compute term1, term2, and sum */ - YT_kron_I_fill_values(m, k, n, y->value, Jx_csc, mnode->term1_CSR); - I_kron_X_fill_values(m, k, n, x->value, Jy_csc, mnode->term2_CSR); - sum_csr_matrices_fill_values(mnode->term1_CSR, mnode->term2_CSR, - node->jacobian); - } + mnode->term1_CSR = YT_kron_I_alloc(m, k, n, x->work->jacobian_csc); + mnode->term2_CSR = I_kron_X_alloc(m, k, n, y->work->jacobian_csc); + + int max_nnz = mnode->term1_CSR->nnz + mnode->term2_CSR->nnz; + node->jacobian = new_csr_matrix(node->size, node->n_vars, max_nnz); + sum_csr_matrices_fill_sparsity(mnode->term1_CSR, mnode->term2_CSR, + node->jacobian); } -static void wsum_hess_init_impl(expr *node) +static void eval_jacobian_chain_rule(expr *node) { expr *x = node->left; expr *y = node->right; + matmul_expr *mnode = (matmul_expr *) node; + int m = x->d1; + int k = x->d2; + int n = y->d2; + + x->eval_jacobian(x); + y->eval_jacobian(y); + + /* refresh children's CSC values */ + csr_to_csc_fill_values(x->jacobian, x->work->jacobian_csc, x->work->csc_work); + csr_to_csc_fill_values(y->jacobian, 
y->work->jacobian_csc, y->work->csc_work); + + YT_kron_I_fill_values(m, k, n, y->value, x->work->jacobian_csc, + mnode->term1_CSR); + I_kron_X_fill_values(m, k, n, x->value, y->work->jacobian_csc, mnode->term2_CSR); + sum_csr_matrices_fill_values(mnode->term1_CSR, mnode->term2_CSR, node->jacobian); +} - /* dimensions: X is m x k, Y is k x n, Z is m x n */ +static void wsum_hess_init_no_chain_rule(expr *node) +{ + expr *x = node->left; + expr *y = node->right; int m = x->d1; int k = x->d2; int n = y->d2; @@ -206,7 +254,6 @@ static void wsum_hess_init_impl(expr *node) if (x->var_id < y->var_id) { - /* fill rows corresponding to x */ for (i = 0; i < x->size; i++) { Hp[x->var_id + i] = nnz; @@ -216,14 +263,10 @@ static void wsum_hess_init_impl(expr *node) Hi[nnz++] = start + col * k; } } - - /* fill rows between x and y */ for (i = x->var_id + x->size; i < y->var_id; i++) { Hp[i] = nnz; } - - /* fill rows corresponding to y */ for (i = 0; i < y->size; i++) { Hp[y->var_id + i] = nnz; @@ -233,8 +276,6 @@ static void wsum_hess_init_impl(expr *node) Hi[nnz++] = start + row; } } - - /* fill rows after y */ for (i = y->var_id + y->size; i <= node->n_vars; i++) { Hp[i] = nnz; @@ -242,8 +283,6 @@ static void wsum_hess_init_impl(expr *node) } else { - /* Y has lower var_id than X */ - /* fill rows corresponding to y */ for (i = 0; i < y->size; i++) { Hp[y->var_id + i] = nnz; @@ -253,14 +292,10 @@ static void wsum_hess_init_impl(expr *node) Hi[nnz++] = start + row; } } - - /* fill rows between y and x */ for (i = y->var_id + y->size; i < x->var_id; i++) { Hp[i] = nnz; } - - /* fill rows corresponding to x */ for (i = 0; i < x->size; i++) { Hp[x->var_id + i] = nnz; @@ -270,34 +305,80 @@ static void wsum_hess_init_impl(expr *node) Hi[nnz++] = start + col * k; } } - - /* fill rows after x */ for (i = x->var_id + x->size; i <= node->n_vars; i++) { Hp[i] = nnz; } } - Hp[node->n_vars] = nnz; assert(nnz == total_nnz); } -static void eval_wsum_hess(expr *node, const double *w) +static 
void wsum_hess_init_chain_rule(expr *node) { expr *x = node->left; expr *y = node->right; + matmul_expr *mnode = (matmul_expr *) node; + int m = x->d1; + int k = x->d2; + int n = y->d2; + + jacobian_csc_init(x); + jacobian_csc_init(y); + CSC_Matrix *Jf = x->work->jacobian_csc; + CSC_Matrix *Jg = y->work->jacobian_csc; + /* build cross-Hessian B sparsity */ + mnode->B = build_cross_hessian_sparsity(m, k, n); + + /* C = J_f^T @ B @ J_g: + * step 1: BJ_g = B @ J_g */ + mnode->BJ_g = csr_csc_matmul_alloc(mnode->B, Jg); + mnode->BJ_g_csc_work = + (int *) malloc(MAX(mnode->BJ_g->m, mnode->BJ_g->n) * sizeof(int)); + mnode->BJ_g_CSC = csr_to_csc_alloc(mnode->BJ_g, mnode->BJ_g_csc_work); + + /* step 2: C = J_f^T @ BJ_g via BTA (B^T D A with D=I) */ + mnode->C = BTA_alloc(mnode->BJ_g_CSC, Jf); + + /* C^T */ + node->work->iwork = (int *) malloc(mnode->C->m * sizeof(int)); + mnode->CT = AT_alloc(mnode->C, node->work->iwork); + + /* allocate weight backprop workspace */ + if (!x->is_affine(x) || !y->is_affine(y)) + { + node->work->dwork = + (double *) malloc(MAX(x->size, y->size) * sizeof(double)); + } + + /* init child Hessians */ + wsum_hess_init(x); + wsum_hess_init(y); + + /* merge 4 sparsity patterns */ + int *maps[4]; + node->wsum_hess = sum_4_csr_fill_sparsity_and_idx_maps( + mnode->C, mnode->CT, x->wsum_hess, y->wsum_hess, maps); + mnode->idx_map_C = maps[0]; + mnode->idx_map_CT = maps[1]; + mnode->idx_map_Hf = maps[2]; + mnode->idx_map_Hg = maps[3]; +} + +static void eval_wsum_hess_no_chain_rule(expr *node, const double *w) +{ + expr *x = node->left; + expr *y = node->right; int m = x->d1; int k = x->d2; int n = y->d2; int offset = 0; - double *Hx = node->wsum_hess->x; const double *w_temp; if (x->var_id < y->var_id) { - /* rows corresponding to x */ for (int k_idx = 0; k_idx < k; k_idx++) { for (int row = 0; row < m; row++) @@ -308,8 +389,6 @@ static void eval_wsum_hess(expr *node, const double *w) } } } - - /* rows corresponding to y */ for (int col = 0; col < n; 
col++) { w_temp = w + col * m; @@ -322,7 +401,6 @@ static void eval_wsum_hess(expr *node, const double *w) } else { - /* rows corresponding to y */ for (int col = 0; col < n; col++) { w_temp = w + col * m; @@ -332,8 +410,6 @@ static void eval_wsum_hess(expr *node, const double *w) offset += m; } } - - /* rows corresponding to x */ for (int k_idx = 0; k_idx < k; k_idx++) { for (int row = 0; row < m; row++) @@ -347,6 +423,98 @@ static void eval_wsum_hess(expr *node, const double *w) } } +static void eval_wsum_hess_chain_rule(expr *node, const double *w) +{ + expr *x = node->left; + expr *y = node->right; + matmul_expr *mnode = (matmul_expr *) node; + int m = x->d1; + int k = x->d2; + int n = y->d2; + bool is_x_affine = x->is_affine(x); + bool is_y_affine = y->is_affine(y); + + /* refresh child Jacobian CSC values (cache if affine) */ + if (!x->work->jacobian_csc_filled) + { + csr_to_csc_fill_values(x->jacobian, x->work->jacobian_csc, + x->work->csc_work); + if (is_x_affine) + { + x->work->jacobian_csc_filled = true; + } + } + if (!y->work->jacobian_csc_filled) + { + csr_to_csc_fill_values(y->jacobian, y->work->jacobian_csc, + y->work->csc_work); + if (is_y_affine) + { + y->work->jacobian_csc_filled = true; + } + } + + CSC_Matrix *Jf = x->work->jacobian_csc; + CSC_Matrix *Jg = y->work->jacobian_csc; + + /* compute C = J_f^T @ B(w) @ J_g */ + fill_cross_hessian_values(m, k, n, w, mnode->B); + csr_csc_matmul_fill_values(mnode->B, Jg, mnode->BJ_g); + csr_to_csc_fill_values(mnode->BJ_g, mnode->BJ_g_CSC, mnode->BJ_g_csc_work); + BTDA_fill_values(mnode->BJ_g_CSC, Jf, NULL, mnode->C); + + /* C^T */ + AT_fill_values(mnode->C, mnode->CT, node->work->iwork); + + /* backpropagate weights and recurse into children */ + if (!is_x_affine) + { + /* v_f = vec(W @ Y^T): + * v_f[row + j*m] = sum_col Y[j,col] * w[row + col*m] */ + double *v_f = node->work->dwork; + for (int j = 0; j < k; j++) + { + for (int row = 0; row < m; row++) + { + double sum = 0.0; + for (int col = 0; col < n; 
col++) + { + sum += y->value[j + col * k] * w[row + col * m]; + } + v_f[row + j * m] = sum; + } + } + x->eval_wsum_hess(x, v_f); + } + + if (!is_y_affine) + { + /* v_g = vec(X^T @ W): + * v_g[j + col*k] = sum_row X[row,j] * w[row + col*m] */ + double *v_g = node->work->dwork; + for (int col = 0; col < n; col++) + { + for (int j = 0; j < k; j++) + { + double sum = 0.0; + for (int row = 0; row < m; row++) + { + sum += x->value[row + j * m] * w[row + col * m]; + } + v_g[j + col * k] = sum; + } + } + y->eval_wsum_hess(y, v_g); + } + + /* accumulate H = C + C^T + H_f + H_g */ + memset(node->wsum_hess->x, 0, node->wsum_hess->nnz * sizeof(double)); + accumulate_mapped(node->wsum_hess->x, mnode->C, mnode->idx_map_C); + accumulate_mapped(node->wsum_hess->x, mnode->CT, mnode->idx_map_CT); + accumulate_mapped(node->wsum_hess->x, x->wsum_hess, mnode->idx_map_Hf); + accumulate_mapped(node->wsum_hess->x, y->wsum_hess, mnode->idx_map_Hg); +} + expr *new_matmul(expr *x, expr *y) { /* Verify dimensions: x->d2 must equal y->d1 */ @@ -362,9 +530,21 @@ expr *new_matmul(expr *x, expr *y) /* Allocate the expression node */ expr *node = (expr *) calloc(1, sizeof(matmul_expr)); - /* Initialize with d1 = x->d1, d2 = y->d2 (result is m x n) */ - init_expr(node, x->d1, y->d2, x->n_vars, forward, jacobian_init_impl, - eval_jacobian, is_affine, wsum_hess_init_impl, eval_wsum_hess, NULL); + /* Choose no-chain-rule or chain-rule function pointers */ + bool use_chain_rule = !(x->var_id != NOT_A_VARIABLE && + y->var_id != NOT_A_VARIABLE && x->var_id != y->var_id); + + jacobian_init_fn jac_init = + use_chain_rule ? jacobian_init_chain_rule : jacobian_init_no_chain_rule; + eval_jacobian_fn jac_eval = + use_chain_rule ? eval_jacobian_chain_rule : eval_jacobian_no_chain_rule; + wsum_hess_init_fn hess_init = + use_chain_rule ? wsum_hess_init_chain_rule : wsum_hess_init_no_chain_rule; + wsum_hess_fn hess_eval = + use_chain_rule ? 
eval_wsum_hess_chain_rule : eval_wsum_hess_no_chain_rule; + + init_expr(node, x->d1, y->d2, x->n_vars, forward, jac_init, jac_eval, is_affine, + hess_init, hess_eval, NULL); /* Set children */ node->left = x; diff --git a/tests/all_tests.c b/tests/all_tests.c index ba59bc0..c004226 100644 --- a/tests/all_tests.c +++ b/tests/all_tests.c @@ -285,6 +285,11 @@ int main(void) mu_run_test(test_wsum_hess_quad_form_Ax, tests_run); mu_run_test(test_wsum_hess_quad_form_sin_Ax, tests_run); mu_run_test(test_wsum_hess_quad_form_exp, tests_run); + mu_run_test(test_wsum_hess_matmul_exp_exp, tests_run); + mu_run_test(test_wsum_hess_matmul_sin_cos, tests_run); + mu_run_test(test_wsum_hess_matmul_Ax_By, tests_run); + mu_run_test(test_wsum_hess_matmul_sin_Ax_cos_Bx, tests_run); + mu_run_test(test_wsum_hess_matmul_X_X, tests_run); printf("\n--- Utility Tests ---\n"); mu_run_test(test_cblas_ddot, tests_run); diff --git a/tests/jacobian_tests/composite/test_chain_rule_jacobian.h b/tests/jacobian_tests/composite/test_chain_rule_jacobian.h index 710aa89..bc61d83 100644 --- a/tests/jacobian_tests/composite/test_chain_rule_jacobian.h +++ b/tests/jacobian_tests/composite/test_chain_rule_jacobian.h @@ -238,12 +238,12 @@ const char *test_jacobian_matmul_sin_Ax_cos_Bx(void) CSR_Matrix *A = new_csr_random(2, 3, 1.0); CSR_Matrix *B = new_csr_random(2, 3, 1.0); - expr *X = new_variable(3, 2, 0, 6); /* 3x2, vars 0-5 */ - expr *AX = new_left_matmul(X, A); /* 2x2 */ - expr *BX = new_left_matmul(X, B); /* 2x2 */ - expr *sin_AX = new_sin(AX); /* 2x2 */ - expr *cos_BX = new_cos(BX); /* 2x2 */ - expr *Z = new_matmul(sin_AX, cos_BX); /* 2x2 */ + expr *X = new_variable(3, 2, 0, 6); /* 3x2, vars 0-5 */ + expr *AX = new_left_matmul(X, A); /* 2x2 */ + expr *BX = new_left_matmul(X, B); /* 2x2 */ + expr *sin_AX = new_sin(AX); /* 2x2 */ + expr *cos_BX = new_cos(BX); /* 2x2 */ + expr *Z = new_matmul(sin_AX, cos_BX); /* 2x2 */ mu_assert("check_jacobian failed", check_jacobian(Z, u_vals, NUMERICAL_DIFF_DEFAULT_H)); 
@@ -268,4 +268,3 @@ const char *test_jacobian_matmul_X_X(void) free_expr(Z); return 0; } - diff --git a/tests/wsum_hess/composite/test_chain_rule_wsum_hess.h b/tests/wsum_hess/composite/test_chain_rule_wsum_hess.h index d0dbe4a..a0c469c 100644 --- a/tests/wsum_hess/composite/test_chain_rule_wsum_hess.h +++ b/tests/wsum_hess/composite/test_chain_rule_wsum_hess.h @@ -259,6 +259,109 @@ const char *test_wsum_hess_quad_form_sin_Ax(void) return 0; } +const char *test_wsum_hess_matmul_exp_exp(void) +{ + /* Z = exp(X) @ exp(Y), X is 2x3, Y is 3x2, Z is 2x2 */ + double u_vals[12] = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2}; + double w[4] = {1.0, 2.0, 3.0, 4.0}; + + expr *X = new_variable(2, 3, 0, 12); + expr *Y = new_variable(3, 2, 6, 12); + expr *exp_X = new_exp(X); + expr *exp_Y = new_exp(Y); + expr *Z = new_matmul(exp_X, exp_Y); + + mu_assert("check_wsum_hess failed", + check_wsum_hess(Z, u_vals, w, NUMERICAL_DIFF_DEFAULT_H)); + + free_expr(Z); + return 0; +} + +const char *test_wsum_hess_matmul_sin_cos(void) +{ + /* Z = sin(X) @ cos(Y), X is 2x2, Y is 2x3, Z is 2x3 */ + double u_vals[10] = {0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0}; + double w[6] = {1.1, 2.2, 3.3, 4.4, 5.5, 6.6}; + + expr *X = new_variable(2, 2, 0, 10); + expr *Y = new_variable(2, 3, 4, 10); + expr *sin_X = new_sin(X); + expr *cos_Y = new_cos(Y); + expr *Z = new_matmul(sin_X, cos_Y); + + mu_assert("check_wsum_hess failed", + check_wsum_hess(Z, u_vals, w, NUMERICAL_DIFF_DEFAULT_H)); + + free_expr(Z); + return 0; +} + +const char *test_wsum_hess_matmul_Ax_By(void) +{ + /* Z = (A @ X) @ (B @ Y), affine children */ + double u_vals[10] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0}; + double w[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; + + CSR_Matrix *A = new_csr_random(3, 2, 1.0); + CSR_Matrix *B = new_csr_random(2, 3, 1.0); + + expr *X = new_variable(2, 2, 0, 10); + expr *Y = new_variable(3, 2, 4, 10); + expr *AX = new_left_matmul(X, A); /* 3x2 */ + expr *BY = new_left_matmul(Y, 
B); /* 2x2 */ + expr *Z = new_matmul(AX, BY); /* 3x2 */ + + mu_assert("check_wsum_hess failed", + check_wsum_hess(Z, u_vals, w, NUMERICAL_DIFF_DEFAULT_H)); + + free_expr(Z); + free_csr_matrix(A); + free_csr_matrix(B); + return 0; +} + +const char *test_wsum_hess_matmul_sin_Ax_cos_Bx(void) +{ + /* Z = sin(A @ X) @ cos(B @ X), shared variable */ + double u_vals[6] = {0.5, 1.0, 1.5, 2.0, 2.5, 3.0}; + double w[4] = {1.0, 2.0, 3.0, 4.0}; + + CSR_Matrix *A = new_csr_random(2, 3, 1.0); + CSR_Matrix *B = new_csr_random(2, 3, 1.0); + + expr *X = new_variable(3, 2, 0, 6); + expr *AX = new_left_matmul(X, A); /* 2x2 */ + expr *BX = new_left_matmul(X, B); /* 2x2 */ + expr *sin_AX = new_sin(AX); + expr *cos_BX = new_cos(BX); + expr *Z = new_matmul(sin_AX, cos_BX); /* 2x2 */ + + mu_assert("check_wsum_hess failed", + check_wsum_hess(Z, u_vals, w, NUMERICAL_DIFF_DEFAULT_H)); + + free_expr(Z); + free_csr_matrix(A); + free_csr_matrix(B); + return 0; +} + +const char *test_wsum_hess_matmul_X_X(void) +{ + /* Z = X @ X, same leaf variable */ + double u_vals[4] = {1.0, 2.0, 3.0, 4.0}; + double w[4] = {1.0, 2.0, 3.0, 4.0}; + + expr *X = new_variable(2, 2, 0, 4); + expr *Z = new_matmul(X, X); + + mu_assert("check_wsum_hess failed", + check_wsum_hess(Z, u_vals, w, NUMERICAL_DIFF_DEFAULT_H)); + + free_expr(Z); + return 0; +} + const char *test_wsum_hess_quad_form_exp(void) { double u_vals[3] = {0.5, 1.0, 1.5}; From 6126885742898e90bc28f45270e16723c15b0d98 Mon Sep 17 00:00:00 2001 From: dance858 Date: Mon, 30 Mar 2026 10:33:38 -0700 Subject: [PATCH 05/13] split up into several functions --- include/utils/CSC_Matrix.h | 14 +-- include/utils/CSR_Matrix.h | 6 +- include/utils/CSR_sum.h | 14 ++- include/utils/linalg_dense_sparse_matmuls.h | 10 +- include/utils/linalg_sparse_matmuls.h | 9 +- src/affine/add.c | 13 +-- src/affine/hstack.c | 4 +- src/affine/left_matmul.c | 4 +- src/bivariate_full_dom/matmul.c | 101 +++++++++--------- src/bivariate_full_dom/multiply.c | 19 ++-- 
src/bivariate_restricted_dom/quad_over_lin.c | 5 +- src/elementwise_full_dom/common.c | 26 ++--- src/other/quad_form.c | 18 ++-- src/utils/CSC_Matrix.c | 14 +-- src/utils/CSR_Matrix.c | 6 +- src/utils/CSR_sum.c | 16 ++- src/utils/dense_matrix.c | 2 +- src/utils/linalg_dense_sparse_matmuls.c | 10 +- src/utils/linalg_sparse_matmuls.c | 8 +- src/utils/sparse_matrix.c | 2 +- tests/utils/test_csc_matrix.h | 6 +- tests/utils/test_csr_csc_conversion.h | 12 +-- tests/utils/test_csr_matrix.h | 4 +- tests/utils/test_linalg_sparse_matmuls.h | 2 +- .../test_linalg_utils_matmul_chain_rule.h | 12 +-- 25 files changed, 165 insertions(+), 172 deletions(-) diff --git a/include/utils/CSC_Matrix.h b/include/utils/CSC_Matrix.h index 951b088..ae62aab 100644 --- a/include/utils/CSC_Matrix.h +++ b/include/utils/CSC_Matrix.h @@ -39,27 +39,27 @@ CSR_Matrix *BTA_alloc(const CSC_Matrix *A, const CSC_Matrix *B); CSC_Matrix *symBA_alloc(const CSR_Matrix *B, const CSC_Matrix *A); /* Compute values for C = A^T D A (null d corresponds to D as identity) */ -void ATDA_fill_values(const CSC_Matrix *A, const double *d, CSR_Matrix *C); +void ATDA_fill_vals(const CSC_Matrix *A, const double *d, CSR_Matrix *C); /* Compute values for C = B^T D A (null d corresonds to D as identity) */ -void BTDA_fill_values(const CSC_Matrix *A, const CSC_Matrix *B, const double *d, - CSR_Matrix *C); +void BTDA_fill_vals(const CSC_Matrix *A, const CSC_Matrix *B, const double *d, + CSR_Matrix *C); /* Fill values of C = BA. The matrix B does not have to be symmetric */ -void BA_fill_values(const CSR_Matrix *B, const CSC_Matrix *A, CSC_Matrix *C); +void BA_fill_vals(const CSR_Matrix *B, const CSC_Matrix *A, CSC_Matrix *C); /* Fill values of C = x^T A. The matrix C must have filled sparsity. 
*/ -void yTA_fill_values(const CSC_Matrix *A, const double *x, CSR_Matrix *C); +void yTA_fill_vals(const CSC_Matrix *A, const double *x, CSR_Matrix *C); /* Count nonzero columns of a CSC matrix */ int count_nonzero_cols_csc(const CSC_Matrix *A); /* convert from CSR to CSC format */ CSC_Matrix *csr_to_csc_alloc(const CSR_Matrix *A, int *iwork); -void csr_to_csc_fill_values(const CSR_Matrix *A, CSC_Matrix *C, int *iwork); +void csr_to_csc_fill_vals(const CSR_Matrix *A, CSC_Matrix *C, int *iwork); /* convert from CSC to CSR format */ CSR_Matrix *csc_to_csr_alloc(const CSC_Matrix *A, int *iwork); -void csc_to_csr_fill_values(const CSC_Matrix *A, CSR_Matrix *C, int *iwork); +void csc_to_csr_fill_vals(const CSC_Matrix *A, CSR_Matrix *C, int *iwork); #endif /* CSC_MATRIX_H */ diff --git a/include/utils/CSR_Matrix.h b/include/utils/CSR_Matrix.h index df81839..5bc3846 100644 --- a/include/utils/CSR_Matrix.h +++ b/include/utils/CSR_Matrix.h @@ -35,7 +35,7 @@ void copy_csr_matrix(const CSR_Matrix *A, CSR_Matrix *C); /* transpose functionality (iwork must be of size A->n) */ CSR_Matrix *transpose(const CSR_Matrix *A, int *iwork); CSR_Matrix *AT_alloc(const CSR_Matrix *A, int *iwork); -void AT_fill_values(const CSR_Matrix *A, CSR_Matrix *AT, int *iwork); +void AT_fill_vals(const CSR_Matrix *A, CSR_Matrix *AT, int *iwork); /* Build (I_p kron A) = blkdiag(A, A, ..., A) of size (p*A->m) x (p*A->n) */ CSR_Matrix *block_diag_repeat_csr(const CSR_Matrix *A, int p); @@ -49,7 +49,7 @@ void csr_matvec_wo_offset(const CSR_Matrix *A, const double *x, double *y); /* Computes values of the row matrix C = z^T A (column indices must have been pre-computed) and transposed matrix AT must be provided) */ -void csr_matvec_fill_values(const CSR_Matrix *AT, const double *z, CSR_Matrix *C); +void csr_matvec_fill_vals(const CSR_Matrix *AT, const double *z, CSR_Matrix *C); /* Insert value into CSR matrix A with just one row at col_idx. 
Assumes that A has enough space and that A does not have an element at col_idx. It does update @@ -60,7 +60,7 @@ void csr_insert_value(CSR_Matrix *A, int col_idx, double value); * d must have length m * C must be pre-allocated with same dimensions as A */ void diag_csr_mult(const double *d, const CSR_Matrix *A, CSR_Matrix *C); -void diag_csr_mult_fill_values(const double *d, const CSR_Matrix *A, CSR_Matrix *C); +void diag_csr_mult_fill_vals(const double *d, const CSR_Matrix *A, CSR_Matrix *C); /* Count number of columns with nonzero entries */ int count_nonzero_cols(const CSR_Matrix *A, bool *col_nz); diff --git a/include/utils/CSR_sum.h b/include/utils/CSR_sum.h index c04c7a8..6476071 100644 --- a/include/utils/CSR_sum.h +++ b/include/utils/CSR_sum.h @@ -14,14 +14,12 @@ void sum_csr_matrices(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C); /* Compute sparsity pattern of A + B where A, B, C are CSR matrices. * Fills C->p, C->i, and C->nnz; does not touch C->x. */ -void sum_csr_matrices_fill_sparsity(const CSR_Matrix *A, const CSR_Matrix *B, - CSR_Matrix *C); +void sum_csr_alloc(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C); /* Fill only the values of C = A + B, assuming C's sparsity pattern (p and i) * is already filled and matches the union of A and B per row. Does not modify * C->p, C->i, or C->nnz. */ -void sum_csr_matrices_fill_values(const CSR_Matrix *A, const CSR_Matrix *B, - CSR_Matrix *C); +void sum_csr_fill_vals(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C); /* Compute C = diag(d1) * A + diag(d2) * B where A, B, C are CSR matrices */ void sum_scaled_csr_matrices(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C, @@ -30,9 +28,9 @@ void sum_scaled_csr_matrices(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matri /* Fill only the values of C = diag(d1) * A + diag(d2) * B, assuming C's sparsity * pattern (p and i) is already filled and matches the union of A and B per row. * Does not modify C->p, C->i, or C->nnz. 
*/ -void sum_scaled_csr_matrices_fill_values(const CSR_Matrix *A, const CSR_Matrix *B, - CSR_Matrix *C, const double *d1, - const double *d2); +void sum_scaled_csr_matrices_fill_vals(const CSR_Matrix *A, const CSR_Matrix *B, + CSR_Matrix *C, const double *d1, + const double *d2); /* Sum all rows of A into a single row matrix C */ void sum_all_rows_csr(const CSR_Matrix *A, CSR_Matrix *C, @@ -43,7 +41,7 @@ void sum_all_rows_csr_fill_sparsity_and_idx_map(const CSR_Matrix *A, CSR_Matrix int *iwork, int *idx_map); /* Fill values of summed rows using precomputed idx_map and sparsity of C */ -// void sum_all_rows_csr_fill_values(const CSR_Matrix *A, CSR_Matrix *C, +// void sum_all_rows_csr_fill_vals(const CSR_Matrix *A, CSR_Matrix *C, // const int *idx_map); /* Fill accumulator for summing rows using precomputed idx_map for each nnz of A. diff --git a/include/utils/linalg_dense_sparse_matmuls.h b/include/utils/linalg_dense_sparse_matmuls.h index 8404940..cbc4634 100644 --- a/include/utils/linalg_dense_sparse_matmuls.h +++ b/include/utils/linalg_dense_sparse_matmuls.h @@ -9,18 +9,18 @@ * A is dense m x n, J is (n*p) x k in CSC, C is (m*p) x k in CSC. */ // TODO: maybe we can replace these with I_kron_X functionality? CSC_Matrix *I_kron_A_alloc(const Matrix *A, const CSC_Matrix *J, int p); -void I_kron_A_fill_values(const Matrix *A, const CSC_Matrix *J, CSC_Matrix *C); +void I_kron_A_fill_vals(const Matrix *A, const CSC_Matrix *J, CSC_Matrix *C); /* Sparsity and values of C = (Y^T kron I_m) @ J where Y is k x n, J is (m*k) x p, and C is (m*n) x p. Y is given in column-major dense format. 
*/ CSR_Matrix *YT_kron_I_alloc(int m, int k, int n, const CSC_Matrix *J); -void YT_kron_I_fill_values(int m, int k, int n, const double *Y, const CSC_Matrix *J, - CSR_Matrix *C); +void YT_kron_I_fill_vals(int m, int k, int n, const double *Y, const CSC_Matrix *J, + CSR_Matrix *C); /* Sparsity and values of C = (I_n kron X) @ J where X is m x k (col-major dense), J is (k*n) x p, and C is (m*n) x p. */ CSR_Matrix *I_kron_X_alloc(int m, int k, int n, const CSC_Matrix *J); -void I_kron_X_fill_values(int m, int k, int n, const double *X, const CSC_Matrix *J, - CSR_Matrix *C); +void I_kron_X_fill_vals(int m, int k, int n, const double *X, const CSC_Matrix *J, + CSR_Matrix *C); #endif /* LINALG_DENSE_SPARSE_H */ diff --git a/include/utils/linalg_sparse_matmuls.h b/include/utils/linalg_sparse_matmuls.h index 25890a9..22c3576 100644 --- a/include/utils/linalg_sparse_matmuls.h +++ b/include/utils/linalg_sparse_matmuls.h @@ -19,9 +19,8 @@ struct CSC_Matrix *block_left_multiply_fill_sparsity(const struct CSR_Matrix *A, const struct CSC_Matrix *J, int p); -void block_left_multiply_fill_values(const struct CSR_Matrix *A, - const struct CSC_Matrix *J, - struct CSC_Matrix *C); +void block_left_multiply_fill_vals(const struct CSR_Matrix *A, + const struct CSC_Matrix *J, struct CSC_Matrix *C); /* Compute y = kron(I_p, A) @ x where A is m x n and x is(n*p)-length vector. The output y is m*p-length vector corresponding to @@ -34,8 +33,8 @@ void block_left_multiply_vec(const struct CSR_Matrix *A, const double *x, double /* Fill values of C = A @ B where A is CSR, B is CSC. * C must have sparsity pattern already computed. */ -void csr_csc_matmul_fill_values(const struct CSR_Matrix *A, - const struct CSC_Matrix *B, struct CSR_Matrix *C); +void csr_csc_matmul_fill_vals(const struct CSR_Matrix *A, const struct CSC_Matrix *B, + struct CSR_Matrix *C); /* C = A @ B where A is CSR, B is CSC. Result C is CSR. * Allocates and precomputes sparsity pattern. No workspace required. 
diff --git a/src/affine/add.c b/src/affine/add.c index f34160b..3231507 100644 --- a/src/affine/add.c +++ b/src/affine/add.c @@ -45,8 +45,7 @@ static void jacobian_init_impl(expr *node) node->jacobian = new_csr_matrix(node->size, node->n_vars, nnz_max); /* fill sparsity pattern */ - sum_csr_matrices_fill_sparsity(node->left->jacobian, node->right->jacobian, - node->jacobian); + sum_csr_alloc(node->left->jacobian, node->right->jacobian, node->jacobian); } static void eval_jacobian(expr *node) @@ -56,8 +55,7 @@ static void eval_jacobian(expr *node) node->right->eval_jacobian(node->right); /* sum children's jacobians */ - sum_csr_matrices_fill_values(node->left->jacobian, node->right->jacobian, - node->jacobian); + sum_csr_fill_vals(node->left->jacobian, node->right->jacobian, node->jacobian); } static void wsum_hess_init_impl(expr *node) @@ -71,8 +69,7 @@ static void wsum_hess_init_impl(expr *node) node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, nnz_max); /* fill sparsity pattern of hessian */ - sum_csr_matrices_fill_sparsity(node->left->wsum_hess, node->right->wsum_hess, - node->wsum_hess); + sum_csr_alloc(node->left->wsum_hess, node->right->wsum_hess, node->wsum_hess); } static void eval_wsum_hess(expr *node, const double *w) @@ -82,8 +79,8 @@ static void eval_wsum_hess(expr *node, const double *w) node->right->eval_wsum_hess(node->right, w); /* sum children's wsum_hess */ - sum_csr_matrices_fill_values(node->left->wsum_hess, node->right->wsum_hess, - node->wsum_hess); + sum_csr_fill_vals(node->left->wsum_hess, node->right->wsum_hess, + node->wsum_hess); } static bool is_affine(const expr *node) diff --git a/src/affine/hstack.c b/src/affine/hstack.c index 8636969..b10356f 100644 --- a/src/affine/hstack.c +++ b/src/affine/hstack.c @@ -124,7 +124,7 @@ static void wsum_hess_init_impl(expr *node) { expr *child = hnode->args[i]; copy_csr_matrix(H, hnode->CSR_work); - sum_csr_matrices_fill_sparsity(hnode->CSR_work, child->wsum_hess, H); + 
sum_csr_alloc(hnode->CSR_work, child->wsum_hess, H); } } @@ -140,7 +140,7 @@ static void wsum_hess_eval(expr *node, const double *w) expr *child = hnode->args[i]; child->eval_wsum_hess(child, w + row_offset); copy_csr_matrix(H, hnode->CSR_work); - sum_csr_matrices_fill_values(hnode->CSR_work, child->wsum_hess, H); + sum_csr_fill_vals(hnode->CSR_work, child->wsum_hess, H); row_offset += child->size; } } diff --git a/src/affine/left_matmul.c b/src/affine/left_matmul.c index 4067e65..a9a782d 100644 --- a/src/affine/left_matmul.c +++ b/src/affine/left_matmul.c @@ -106,11 +106,11 @@ static void eval_jacobian(expr *node) /* evaluate child's jacobian and convert to CSC */ x->eval_jacobian(x); - csr_to_csc_fill_values(x->jacobian, Jchild_CSC, node->work->iwork); + csr_to_csc_fill_vals(x->jacobian, Jchild_CSC, node->work->iwork); /* compute this node's jacobian: */ lnode->A->block_left_mult_values(lnode->A, Jchild_CSC, J_CSC); - csc_to_csr_fill_values(J_CSC, node->jacobian, lnode->csc_to_csr_work); + csc_to_csr_fill_vals(J_CSC, node->jacobian, lnode->csc_to_csr_work); } static void wsum_hess_init_impl(expr *node) diff --git a/src/bivariate_full_dom/matmul.c b/src/bivariate_full_dom/matmul.c index 263e61d..4344a8f 100644 --- a/src/bivariate_full_dom/matmul.c +++ b/src/bivariate_full_dom/matmul.c @@ -107,12 +107,16 @@ static bool is_affine(const expr *node) return false; } -// ------------------------------------------------------------------------------ -// No chain rule: both children are different leaf variables -// ------------------------------------------------------------------------------ - +// -------------------------------------------------------------------------------- +// Jacobian initialization and evaluation for the case when no chain rule is needed, +// ie., when both children are different leaf variables. 
+// -------------------------------------------------------------------------------- static void jacobian_init_no_chain_rule(expr *node) { + assert(node->left->var_id != NOT_A_VARIABLE && + node->right->var_id != NOT_A_VARIABLE && + node->left->var_id != node->right->var_id); + expr *x = node->left; expr *y = node->right; int m = x->d1; @@ -189,53 +193,54 @@ static void eval_jacobian_no_chain_rule(expr *node) } } -// ------------------------------------------------------------------------------ -// Chain rule: at least one child is composite, or same variable -// ------------------------------------------------------------------------------ - +// ------------------------------------------------------------------------------------ +// Jacobian initialization and evaluation for the case where chain rule is needed, +// ie., when at least one child is composite, or same variable. The jacobian of h(u) +// = f(u) @ g(u) where f is m x k and g is k x n, is given by Jh = (g^T kron I_m) +// Jf + (I_n kron f) Jg . 
*/ +// ------------------------------------------------------------------------------------ static void jacobian_init_chain_rule(expr *node) { - expr *x = node->left; - expr *y = node->right; + expr *f = node->left; + expr *g = node->right; matmul_expr *mnode = (matmul_expr *) node; - int m = x->d1; - int k = x->d2; - int n = y->d2; - - jacobian_init(x); - jacobian_init(y); - jacobian_csc_init(x); - jacobian_csc_init(y); - - mnode->term1_CSR = YT_kron_I_alloc(m, k, n, x->work->jacobian_csc); - mnode->term2_CSR = I_kron_X_alloc(m, k, n, y->work->jacobian_csc); - + int m = f->d1; + int k = f->d2; + int n = g->d2; + + /* initialize Jacobian of children */ + jacobian_init(f); + jacobian_init(g); + jacobian_csc_init(f); + jacobian_csc_init(g); + + /* initialize term1, term2, and their sum */ + mnode->term1_CSR = YT_kron_I_alloc(m, k, n, f->work->jacobian_csc); + mnode->term2_CSR = I_kron_X_alloc(m, k, n, g->work->jacobian_csc); int max_nnz = mnode->term1_CSR->nnz + mnode->term2_CSR->nnz; node->jacobian = new_csr_matrix(node->size, node->n_vars, max_nnz); - sum_csr_matrices_fill_sparsity(mnode->term1_CSR, mnode->term2_CSR, - node->jacobian); + sum_csr_alloc(mnode->term1_CSR, mnode->term2_CSR, node->jacobian); } static void eval_jacobian_chain_rule(expr *node) { - expr *x = node->left; - expr *y = node->right; + expr *f = node->left; + expr *g = node->right; matmul_expr *mnode = (matmul_expr *) node; - int m = x->d1; - int k = x->d2; - int n = y->d2; - - x->eval_jacobian(x); - y->eval_jacobian(y); - - /* refresh children's CSC values */ - csr_to_csc_fill_values(x->jacobian, x->work->jacobian_csc, x->work->csc_work); - csr_to_csc_fill_values(y->jacobian, y->work->jacobian_csc, y->work->csc_work); - - YT_kron_I_fill_values(m, k, n, y->value, x->work->jacobian_csc, - mnode->term1_CSR); - I_kron_X_fill_values(m, k, n, x->value, y->work->jacobian_csc, mnode->term2_CSR); - sum_csr_matrices_fill_values(mnode->term1_CSR, mnode->term2_CSR, node->jacobian); + int m = f->d1; + int k 
= f->d2; + int n = g->d2; + + /* evaluate Jacobians of children */ + f->eval_jacobian(f); + g->eval_jacobian(g); + csr_to_csc_fill_vals(f->jacobian, f->work->jacobian_csc, f->work->csc_work); + csr_to_csc_fill_vals(g->jacobian, g->work->jacobian_csc, g->work->csc_work); + + /* evaluate term1, term2, and their sum */ + YT_kron_I_fill_vals(m, k, n, g->value, f->work->jacobian_csc, mnode->term1_CSR); + I_kron_X_fill_vals(m, k, n, f->value, g->work->jacobian_csc, mnode->term2_CSR); + sum_csr_fill_vals(mnode->term1_CSR, mnode->term2_CSR, node->jacobian); } static void wsum_hess_init_no_chain_rule(expr *node) @@ -437,8 +442,7 @@ static void eval_wsum_hess_chain_rule(expr *node, const double *w) /* refresh child Jacobian CSC values (cache if affine) */ if (!x->work->jacobian_csc_filled) { - csr_to_csc_fill_values(x->jacobian, x->work->jacobian_csc, - x->work->csc_work); + csr_to_csc_fill_vals(x->jacobian, x->work->jacobian_csc, x->work->csc_work); if (is_x_affine) { x->work->jacobian_csc_filled = true; @@ -446,8 +450,7 @@ static void eval_wsum_hess_chain_rule(expr *node, const double *w) } if (!y->work->jacobian_csc_filled) { - csr_to_csc_fill_values(y->jacobian, y->work->jacobian_csc, - y->work->csc_work); + csr_to_csc_fill_vals(y->jacobian, y->work->jacobian_csc, y->work->csc_work); if (is_y_affine) { y->work->jacobian_csc_filled = true; @@ -459,12 +462,12 @@ static void eval_wsum_hess_chain_rule(expr *node, const double *w) /* compute C = J_f^T @ B(w) @ J_g */ fill_cross_hessian_values(m, k, n, w, mnode->B); - csr_csc_matmul_fill_values(mnode->B, Jg, mnode->BJ_g); - csr_to_csc_fill_values(mnode->BJ_g, mnode->BJ_g_CSC, mnode->BJ_g_csc_work); - BTDA_fill_values(mnode->BJ_g_CSC, Jf, NULL, mnode->C); + csr_csc_matmul_fill_vals(mnode->B, Jg, mnode->BJ_g); + csr_to_csc_fill_vals(mnode->BJ_g, mnode->BJ_g_CSC, mnode->BJ_g_csc_work); + BTDA_fill_vals(mnode->BJ_g_CSC, Jf, NULL, mnode->C); /* C^T */ - AT_fill_values(mnode->C, mnode->CT, node->work->iwork); + 
AT_fill_vals(mnode->C, mnode->CT, node->work->iwork); /* backpropagate weights and recurse into children */ if (!is_x_affine) diff --git a/src/bivariate_full_dom/multiply.c b/src/bivariate_full_dom/multiply.c index bd13bfb..3ce8299 100644 --- a/src/bivariate_full_dom/multiply.c +++ b/src/bivariate_full_dom/multiply.c @@ -62,8 +62,7 @@ static void jacobian_init_impl(expr *node) node->jacobian = new_csr_matrix(node->size, node->n_vars, nnz_max); /* fill sparsity pattern */ - sum_csr_matrices_fill_sparsity(node->left->jacobian, node->right->jacobian, - node->jacobian); + sum_csr_alloc(node->left->jacobian, node->right->jacobian, node->jacobian); } static void eval_jacobian(expr *node) @@ -76,8 +75,8 @@ static void eval_jacobian(expr *node) /* chain rule: the jacobian of h(x) = f(g1(x), g2(x))) is Jh = J_{f, 1} J_{g1} + * J_{f, 2} J_{g2} */ - sum_scaled_csr_matrices_fill_values(x->jacobian, y->jacobian, node->jacobian, - y->value, x->value); + sum_scaled_csr_matrices_fill_vals(x->jacobian, y->jacobian, node->jacobian, + y->value, x->value); } static void wsum_hess_init_impl(expr *node) @@ -211,8 +210,8 @@ static void eval_wsum_hess(expr *node, const double *w) // ---------------------------------------------------------------------- if (!x->work->jacobian_csc_filled) { - csr_to_csc_fill_values(x->jacobian, x->work->jacobian_csc, - x->work->csc_work); + csr_to_csc_fill_vals(x->jacobian, x->work->jacobian_csc, + x->work->csc_work); if (is_x_affine) { @@ -222,8 +221,8 @@ static void eval_wsum_hess(expr *node, const double *w) if (!y->work->jacobian_csc_filled) { - csr_to_csc_fill_values(y->jacobian, y->work->jacobian_csc, - y->work->csc_work); + csr_to_csc_fill_vals(y->jacobian, y->work->jacobian_csc, + y->work->csc_work); if (is_y_affine) { @@ -240,8 +239,8 @@ static void eval_wsum_hess(expr *node, const double *w) elementwise_mult_expr *mul_node = (elementwise_mult_expr *) node; CSR_Matrix *C = mul_node->CSR_work1; CSR_Matrix *CT = mul_node->CSR_work2; - 
BTDA_fill_values(Jg1, Jg2, w, C); - AT_fill_values(C, CT, node->work->iwork); + BTDA_fill_vals(Jg1, Jg2, w, C); + AT_fill_vals(C, CT, node->work->iwork); // --------------------------------------------------------------- // compute term2 and term 3 diff --git a/src/bivariate_restricted_dom/quad_over_lin.c b/src/bivariate_restricted_dom/quad_over_lin.c index bc8ea05..9b78c83 100644 --- a/src/bivariate_restricted_dom/quad_over_lin.c +++ b/src/bivariate_restricted_dom/quad_over_lin.c @@ -124,8 +124,7 @@ static void jacobian_init_impl(expr *node) * For a linear operator the values are constant, so fill * them once here. */ jacobian_csc_init(x); - csr_to_csc_fill_values(x->jacobian, x->work->jacobian_csc, - x->work->csc_work); + csr_to_csc_fill_vals(x->jacobian, x->work->jacobian_csc, x->work->csc_work); } } @@ -164,7 +163,7 @@ static void eval_jacobian(expr *node) } /* chain rule (no derivative wrt y) using CSC format */ - yTA_fill_values(x->work->jacobian_csc, node->work->dwork, node->jacobian); + yTA_fill_vals(x->work->jacobian_csc, node->work->dwork, node->jacobian); /* insert derivative wrt y at right place (for correctness this assumes that y does not appear in the numerator, but this will always be diff --git a/src/elementwise_full_dom/common.c b/src/elementwise_full_dom/common.c index 59c43c5..9874bc0 100644 --- a/src/elementwise_full_dom/common.c +++ b/src/elementwise_full_dom/common.c @@ -49,7 +49,7 @@ void eval_jacobian_elementwise(expr *node) node->local_jacobian(node, node->work->local_jac_diag); memcpy(node->work->dwork, node->work->local_jac_diag, node->size * sizeof(double)); - diag_csr_mult_fill_values(node->work->dwork, Jg, node->jacobian); + diag_csr_mult_fill_vals(node->work->dwork, Jg, node->jacobian); } } @@ -102,8 +102,8 @@ void wsum_hess_init_elementwise(expr *node) /* wsum_hess = term1 + term2 */ int max_nnz = node->work->hess_term1->nnz + node->work->hess_term2->nnz; node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, max_nnz); - 
sum_csr_matrices_fill_sparsity(node->work->hess_term1, - node->work->hess_term2, node->wsum_hess); + sum_csr_alloc(node->work->hess_term1, node->work->hess_term2, + node->wsum_hess); } } } @@ -122,25 +122,25 @@ void eval_wsum_hess_elementwise(expr *node, const double *w) { if (!child->work->jacobian_csc_filled) { - csr_to_csc_fill_values(child->jacobian, child->work->jacobian_csc, - child->work->csc_work); + csr_to_csc_fill_vals(child->jacobian, child->work->jacobian_csc, + child->work->csc_work); child->work->jacobian_csc_filled = true; } node->local_wsum_hess(node, node->work->dwork, w); - ATDA_fill_values(child->work->jacobian_csc, node->work->dwork, - node->wsum_hess); + ATDA_fill_vals(child->work->jacobian_csc, node->work->dwork, + node->wsum_hess); } else { /* refresh CSC jacobian values */ - csr_to_csc_fill_values(child->jacobian, child->work->jacobian_csc, - child->work->csc_work); + csr_to_csc_fill_vals(child->jacobian, child->work->jacobian_csc, + child->work->csc_work); /* term1: Jg^T @ D @ Jg */ node->local_wsum_hess(node, node->work->dwork, w); - ATDA_fill_values(child->work->jacobian_csc, node->work->dwork, - node->work->hess_term1); + ATDA_fill_vals(child->work->jacobian_csc, node->work->dwork, + node->work->hess_term1); /* term2: child Hessian with weight Jf^T w */ memcpy(node->work->dwork, node->work->local_jac_diag, @@ -155,8 +155,8 @@ void eval_wsum_hess_elementwise(expr *node, const double *w) child->wsum_hess->nnz * sizeof(double)); /* wsum_hess = term1 + term2 */ - sum_csr_matrices_fill_values(node->work->hess_term1, - node->work->hess_term2, node->wsum_hess); + sum_csr_fill_vals(node->work->hess_term1, node->work->hess_term2, + node->wsum_hess); } } } diff --git a/src/other/quad_form.c b/src/other/quad_form.c index 91ea602..0311d22 100644 --- a/src/other/quad_form.c +++ b/src/other/quad_form.c @@ -88,8 +88,8 @@ static void eval_jacobian(expr *node) if (!x->work->jacobian_csc_filled) { - csr_to_csc_fill_values(x->jacobian, 
x->work->jacobian_csc, - x->work->csc_work); + csr_to_csc_fill_vals(x->jacobian, x->work->jacobian_csc, + x->work->csc_work); if (x->is_affine(x)) { @@ -99,7 +99,7 @@ static void eval_jacobian(expr *node) /* The jacobian has same values as the gradient, which is J_f^T (Q @ f(x)). Here, dwork stores Q @ f(x) from forward */ - yTA_fill_values(x->work->jacobian_csc, node->work->dwork, node->jacobian); + yTA_fill_vals(x->work->jacobian_csc, node->work->dwork, node->jacobian); cblas_dscal(node->jacobian->nnz, 2.0, node->jacobian->x, 1); } @@ -156,8 +156,8 @@ static void wsum_hess_init_impl(expr *node) /* hess = term1 + term2 */ int max_nnz = node->work->hess_term1->nnz + node->work->hess_term2->nnz; node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, max_nnz); - sum_csr_matrices_fill_sparsity(node->work->hess_term1, - node->work->hess_term2, node->wsum_hess); + sum_csr_alloc(node->work->hess_term1, node->work->hess_term2, + node->wsum_hess); } } @@ -180,7 +180,7 @@ static void eval_wsum_hess(expr *node, const double *w) CSC_Matrix *Jf = x->work->jacobian_csc; if (!x->work->jacobian_csc_filled) { - csr_to_csc_fill_values(x->jacobian, Jf, x->work->csc_work); + csr_to_csc_fill_vals(x->jacobian, Jf, x->work->csc_work); if (x->is_affine(x)) { @@ -193,8 +193,8 @@ static void eval_wsum_hess(expr *node, const double *w) CSR_Matrix *term2 = node->work->hess_term2; /* term1 = J_f^T Q J_f = J_f^T B */ - BA_fill_values(Q, Jf, QJf); - BTDA_fill_values(Jf, QJf, NULL, term1); + BA_fill_vals(Q, Jf, QJf); + BTDA_fill_vals(Jf, QJf, NULL, term1); /* term2 */ x->eval_wsum_hess(x, node->work->dwork); @@ -205,7 +205,7 @@ static void eval_wsum_hess(expr *node, const double *w) cblas_dscal(term2->nnz, two_w, term2->x, 1); /* sum the two terms */ - sum_csr_matrices_fill_values(term1, term2, node->wsum_hess); + sum_csr_fill_vals(term1, term2, node->wsum_hess); } } diff --git a/src/utils/CSC_Matrix.c b/src/utils/CSC_Matrix.c index 9d4e078..bdb7f59 100644 --- a/src/utils/CSC_Matrix.c +++ 
b/src/utils/CSC_Matrix.c @@ -168,7 +168,7 @@ static inline double sparse_wdot(const double *a_x, const int *a_i, int a_nnz, return sum; } -void ATDA_fill_values(const CSC_Matrix *A, const double *d, CSR_Matrix *C) +void ATDA_fill_vals(const CSC_Matrix *A, const double *d, CSR_Matrix *C) { int j, ii, jj; for (ii = 0; ii < C->m; ii++) @@ -246,7 +246,7 @@ CSC_Matrix *csr_to_csc_alloc(const CSR_Matrix *A, int *iwork) return C; } -void csr_to_csc_fill_values(const CSR_Matrix *A, CSC_Matrix *C, int *iwork) +void csr_to_csc_fill_vals(const CSR_Matrix *A, CSC_Matrix *C, int *iwork) { int i, j; int *count = iwork; @@ -311,7 +311,7 @@ CSR_Matrix *csc_to_csr_alloc(const CSC_Matrix *A, int *iwork) return C; } -void csc_to_csr_fill_values(const CSC_Matrix *A, CSR_Matrix *C, int *iwork) +void csc_to_csr_fill_vals(const CSC_Matrix *A, CSR_Matrix *C, int *iwork) { int i, j; int *count = iwork; @@ -388,7 +388,7 @@ CSR_Matrix *BTA_alloc(const CSC_Matrix *A, const CSC_Matrix *B) return C; } -void yTA_fill_values(const CSC_Matrix *A, const double *y, CSR_Matrix *C) +void yTA_fill_vals(const CSC_Matrix *A, const double *y, CSR_Matrix *C) { for (int col = 0; col < A->n; col++) { @@ -413,8 +413,8 @@ void yTA_fill_values(const CSC_Matrix *A, const double *y, CSR_Matrix *C) } /* computes C = B^T * D * A in CSR */ -void BTDA_fill_values(const CSC_Matrix *A, const CSC_Matrix *B, const double *d, - CSR_Matrix *C) +void BTDA_fill_vals(const CSC_Matrix *A, const CSC_Matrix *B, const double *d, + CSR_Matrix *C) { int i, j, jj; for (i = 0; i < C->m; i++) @@ -445,7 +445,7 @@ void BTDA_fill_values(const CSC_Matrix *A, const CSC_Matrix *B, const double *d, * faster when Q is dense, since it touches each Q entry exactly once. * The sparse_dot approach below is simpler but redundantly scans * column j of A for each nonzero row of C. 
*/ -void BA_fill_values(const CSR_Matrix *Q, const CSC_Matrix *A, CSC_Matrix *C) +void BA_fill_vals(const CSR_Matrix *Q, const CSC_Matrix *A, CSC_Matrix *C) { /* fill values of C = Q * A, given the sparsity pattern of C. */ int i, j, ii; diff --git a/src/utils/CSR_Matrix.c b/src/utils/CSR_Matrix.c index b175ea1..e09a146 100644 --- a/src/utils/CSR_Matrix.c +++ b/src/utils/CSR_Matrix.c @@ -215,7 +215,7 @@ void diag_csr_mult(const double *d, const CSR_Matrix *A, CSR_Matrix *C) } } -void diag_csr_mult_fill_values(const double *d, const CSR_Matrix *A, CSR_Matrix *C) +void diag_csr_mult_fill_vals(const double *d, const CSR_Matrix *A, CSR_Matrix *C) { memcpy(C->x, A->x, A->nnz * sizeof(double)); @@ -346,7 +346,7 @@ CSR_Matrix *AT_alloc(const CSR_Matrix *A, int *iwork) return AT; } -void AT_fill_values(const CSR_Matrix *A, CSR_Matrix *AT, int *iwork) +void AT_fill_vals(const CSR_Matrix *A, CSR_Matrix *AT, int *iwork) { /* Fill values of A^T given sparsity pattern is already computed */ int i, j; @@ -365,7 +365,7 @@ void AT_fill_values(const CSR_Matrix *A, CSR_Matrix *AT, int *iwork) } /**/ -void csr_matvec_fill_values(const CSR_Matrix *AT, const double *z, CSR_Matrix *C) +void csr_matvec_fill_vals(const CSR_Matrix *AT, const double *z, CSR_Matrix *C) { int A_ncols = AT->m; diff --git a/src/utils/CSR_sum.c b/src/utils/CSR_sum.c index e024a0d..9dc2356 100644 --- a/src/utils/CSR_sum.c +++ b/src/utils/CSR_sum.c @@ -85,8 +85,7 @@ void sum_csr_matrices(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C) C->p[A->m] = C->nnz; } -void sum_csr_matrices_fill_sparsity(const CSR_Matrix *A, const CSR_Matrix *B, - CSR_Matrix *C) +void sum_csr_alloc(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C) { /* A and B must be different from C */ assert(A != C && B != C); @@ -143,8 +142,7 @@ void sum_csr_matrices_fill_sparsity(const CSR_Matrix *A, const CSR_Matrix *B, C->p[A->m] = C->nnz; } -void sum_csr_matrices_fill_values(const CSR_Matrix *A, const CSR_Matrix *B, - CSR_Matrix *C) 
+void sum_csr_fill_vals(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C) { /* Assumes C->p and C->i already contain the sparsity pattern of A+B. Fills only C->x accordingly. */ @@ -176,9 +174,9 @@ void sum_csr_matrices_fill_values(const CSR_Matrix *A, const CSR_Matrix *B, } } -void sum_scaled_csr_matrices_fill_values(const CSR_Matrix *A, const CSR_Matrix *B, - CSR_Matrix *C, const double *d1, - const double *d2) +void sum_scaled_csr_matrices_fill_vals(const CSR_Matrix *A, const CSR_Matrix *B, + CSR_Matrix *C, const double *d1, + const double *d2) { /* Assumes C->p and C->i already contain the sparsity pattern of A+B. Fills only C->x accordingly with scaling. */ @@ -439,7 +437,7 @@ void sum_block_of_rows_csr_fill_sparsity_and_idx_map(const CSR_Matrix *A, } /* -void sum_block_of_rows_csr_fill_values(const CSR_Matrix *A, CSR_Matrix *C, +void sum_block_of_rows_csr_fill_vals(const CSR_Matrix *A, CSR_Matrix *C, const int *idx_map) { memset(C->x, 0, C->nnz * sizeof(double)); @@ -688,7 +686,7 @@ void sum_all_rows_csr_fill_sparsity_and_idx_map(const CSR_Matrix *A, CSR_Matrix } /* -void sum_all_rows_csr_fill_values(const CSR_Matrix *A, CSR_Matrix *C, +void sum_all_rows_csr_fill_vals(const CSR_Matrix *A, CSR_Matrix *C, const int *idx_map) { memset(C->x, 0, C->nnz * sizeof(double)); diff --git a/src/utils/dense_matrix.c b/src/utils/dense_matrix.c index 8a2dd2c..bdbe0c0 100644 --- a/src/utils/dense_matrix.c +++ b/src/utils/dense_matrix.c @@ -56,7 +56,7 @@ Matrix *new_dense_matrix(int m, int n, const double *data) dm->base.n = n; dm->base.block_left_mult_vec = dense_block_left_mult_vec; dm->base.block_left_mult_sparsity = I_kron_A_alloc; - dm->base.block_left_mult_values = I_kron_A_fill_values; + dm->base.block_left_mult_values = I_kron_A_fill_vals; dm->base.free_fn = dense_free; dm->x = (double *) malloc(m * n * sizeof(double)); memcpy(dm->x, data, m * n * sizeof(double)); diff --git a/src/utils/linalg_dense_sparse_matmuls.c b/src/utils/linalg_dense_sparse_matmuls.c 
index 8fdef54..0126a6e 100644 --- a/src/utils/linalg_dense_sparse_matmuls.c +++ b/src/utils/linalg_dense_sparse_matmuls.c @@ -90,7 +90,7 @@ CSC_Matrix *I_kron_A_alloc(const Matrix *A, const CSC_Matrix *J, int p) return C; } -void I_kron_A_fill_values(const Matrix *A, const CSC_Matrix *J, CSC_Matrix *C) +void I_kron_A_fill_vals(const Matrix *A, const CSC_Matrix *J, CSC_Matrix *C) { const Dense_Matrix *dm = (const Dense_Matrix *) A; int m = dm->base.m; @@ -218,8 +218,8 @@ CSR_Matrix *YT_kron_I_alloc(int m, int k, int n, const CSC_Matrix *J) return C; } -void YT_kron_I_fill_values(int m, int k, int n, const double *Y, const CSC_Matrix *J, - CSR_Matrix *C) +void YT_kron_I_fill_vals(int m, int k, int n, const double *Y, const CSC_Matrix *J, + CSR_Matrix *C) { assert(C->m == m * n); /* C[i, j] = sum_l Y[l, blk] * J[blk_row + l*m, j] @@ -306,8 +306,8 @@ CSR_Matrix *I_kron_X_alloc(int m, int k, int n, const CSC_Matrix *J) return C; } -void I_kron_X_fill_values(int m, int k, int n, const double *X, const CSC_Matrix *J, - CSR_Matrix *C) +void I_kron_X_fill_vals(int m, int k, int n, const double *X, const CSC_Matrix *J, + CSR_Matrix *C) { assert(C->m == m * n); /* C[i, j] = sum_l X[blk_row + l*m] * J[blk*k + l, j] diff --git a/src/utils/linalg_sparse_matmuls.c b/src/utils/linalg_sparse_matmuls.c index 9a7b164..9be3499 100644 --- a/src/utils/linalg_sparse_matmuls.c +++ b/src/utils/linalg_sparse_matmuls.c @@ -183,8 +183,8 @@ CSC_Matrix *block_left_multiply_fill_sparsity(const CSR_Matrix *A, return C; } -void block_left_multiply_fill_values(const CSR_Matrix *A, const CSC_Matrix *J, - CSC_Matrix *C) +void block_left_multiply_fill_vals(const CSR_Matrix *A, const CSC_Matrix *J, + CSC_Matrix *C) { /* A is m x n, J is (n*p) x k, C is (m*p) x k */ int m = A->m; @@ -246,8 +246,8 @@ void block_left_multiply_fill_values(const CSR_Matrix *A, const CSC_Matrix *J, } /* Fill values of C = A @ B where A is CSR, B is CSC. 
*/ -void csr_csc_matmul_fill_values(const CSR_Matrix *A, const CSC_Matrix *B, - CSR_Matrix *C) +void csr_csc_matmul_fill_vals(const CSR_Matrix *A, const CSC_Matrix *B, + CSR_Matrix *C) { for (int i = 0; i < A->m; i++) { diff --git a/src/utils/sparse_matrix.c b/src/utils/sparse_matrix.c index 24ed539..6d9b365 100644 --- a/src/utils/sparse_matrix.c +++ b/src/utils/sparse_matrix.c @@ -37,7 +37,7 @@ static void sparse_block_left_mult_values(const Matrix *self, const CSC_Matrix * CSC_Matrix *C) { const Sparse_Matrix *sm = (const Sparse_Matrix *) self; - block_left_multiply_fill_values(sm->csr, J, C); + block_left_multiply_fill_vals(sm->csr, J, C); } static void sparse_free(Matrix *self) diff --git a/tests/utils/test_csc_matrix.h b/tests/utils/test_csc_matrix.h index 5459021..312ce66 100644 --- a/tests/utils/test_csc_matrix.h +++ b/tests/utils/test_csc_matrix.h @@ -101,7 +101,7 @@ const char *test_ATA_alloc_random(void) double d[10] = {2, 8, 6, 2, 5, 1, 6, 9, 1, 3}; - ATDA_fill_values(A, d, C); + ATDA_fill_vals(A, d, C); double Cx_correct[38] = { 49., 21., 491., 56., 240., 416., 144., 288., 56., 98., 56., 21., 9., @@ -139,7 +139,7 @@ const char *test_ATA_alloc_random2(void) double d[15] = {-0.6, -0.23, -0.29, -1.36, 0.4, 0.36, 0.11, -0.13, -1.32, -0.32, -0.24, -0.7, -0.06, 0.5, 1.99}; - ATDA_fill_values(A, d, C); + ATDA_fill_vals(A, d, C); double Cx_correct[17] = {-0.362232, -0.189896, 0.06656, -0.228888, -0.025732, -0.016146, 0.032857, 0.06656, -1.004802, 0.1505, @@ -198,7 +198,7 @@ const char *test_BTA_alloc_and_BTDA_fill(void) /* Fill values with diagonal weights d */ double d[4] = {1.0, 2.0, 3.0, 4.0}; - BTDA_fill_values(A, B, d, C); + BTDA_fill_vals(A, B, d, C); double expected_x[3] = {37.0, 47.0, 108.0}; mu_assert("C values incorrect", cmp_double_array(C->x, expected_x, 3)); diff --git a/tests/utils/test_csr_csc_conversion.h b/tests/utils/test_csr_csc_conversion.h index efbdc9e..5070e12 100644 --- a/tests/utils/test_csr_csc_conversion.h +++ 
b/tests/utils/test_csr_csc_conversion.h @@ -8,7 +8,7 @@ #include "utils/CSC_Matrix.h" #include "utils/CSR_Matrix.h" -/* Test CSR to CSC conversion with fill_sparsity and fill_values */ +/* Test CSR to CSC conversion with fill_sparsity and fill_vals */ const char *test_csr_to_csc_split(void) { /* Create a 4x5 CSR matrix A: @@ -39,7 +39,7 @@ const char *test_csr_to_csc_split(void) mu_assert("C row indices incorrect", cmp_int_array(C->i, Ci_correct, 5)); /* Now fill values */ - csr_to_csc_fill_values(A, C, iwork); + csr_to_csc_fill_vals(A, C, iwork); /* Check values */ double Cx_correct[5] = {1.0, 2.0, 3.0, 4.0, 1.0}; @@ -97,7 +97,7 @@ const char *test_csc_to_csr_sparsity(void) return 0; } -/* Test CSC to CSR conversion with fill_values */ +/* Test CSC to CSR conversion with fill_vals */ const char *test_csc_to_csr_values(void) { /* Create a 4x5 CSC matrix A */ @@ -116,7 +116,7 @@ const char *test_csc_to_csr_values(void) CSR_Matrix *C = csc_to_csr_alloc(A, iwork); /* Fill values */ - csc_to_csr_fill_values(A, C, iwork); + csc_to_csr_fill_vals(A, C, iwork); /* Check values */ double Cx_correct[5] = {1.0, 2.0, 3.0, 4.0, 5.0}; @@ -149,12 +149,12 @@ const char *test_csr_csc_csr_roundtrip(void) /* Convert CSR to CSC */ int *iwork_csc = (int *) malloc(A->n * sizeof(int)); CSC_Matrix *B = csr_to_csc_alloc(A, iwork_csc); - csr_to_csc_fill_values(A, B, iwork_csc); + csr_to_csc_fill_vals(A, B, iwork_csc); /* Convert CSC back to CSR */ int *iwork_csr = (int *) malloc(B->m * sizeof(int)); CSR_Matrix *C = csc_to_csr_alloc(B, iwork_csr); - csc_to_csr_fill_values(B, C, iwork_csr); + csc_to_csr_fill_vals(B, C, iwork_csr); /* C should match A */ mu_assert("Round-trip: vals incorrect", cmp_double_array(C->x, Ax, 8)); diff --git a/tests/utils/test_csr_matrix.h b/tests/utils/test_csr_matrix.h index f6c9536..410da18 100644 --- a/tests/utils/test_csr_matrix.h +++ b/tests/utils/test_csr_matrix.h @@ -187,7 +187,7 @@ const char *test_csr_vecmat_values_sparse(void) CSR_Matrix *AT = 
transpose(A, iwork); - csr_matvec_fill_values(AT, z, C); + csr_matvec_fill_vals(AT, z, C); double Cx_correct[3] = {7.0, 22.0, 1.0}; @@ -412,7 +412,7 @@ const char *test_AT_alloc_and_fill(void) CSR_Matrix *AT = AT_alloc(A, iwork); /* Fill values of A^T */ - AT_fill_values(A, AT, iwork); + AT_fill_vals(A, AT, iwork); /* Expected A^T: * [1.0 0.0 5.0] diff --git a/tests/utils/test_linalg_sparse_matmuls.h b/tests/utils/test_linalg_sparse_matmuls.h index 120f99b..2f8c639 100644 --- a/tests/utils/test_linalg_sparse_matmuls.h +++ b/tests/utils/test_linalg_sparse_matmuls.h @@ -110,7 +110,7 @@ const char *test_block_left_multiply_two_blocks(void) * [0.0 1.0 1.0] */ CSC_Matrix *C = block_left_multiply_fill_sparsity(A, J, 2); - block_left_multiply_fill_values(A, J, C); + block_left_multiply_fill_vals(A, J, C); int expected_p2[4] = {0, 1, 2, 3}; int expected_i2[3] = {0, 2, 3}; diff --git a/tests/utils/test_linalg_utils_matmul_chain_rule.h b/tests/utils/test_linalg_utils_matmul_chain_rule.h index 0c66589..8724aa3 100644 --- a/tests/utils/test_linalg_utils_matmul_chain_rule.h +++ b/tests/utils/test_linalg_utils_matmul_chain_rule.h @@ -8,7 +8,7 @@ #include "utils/CSR_Matrix.h" #include "utils/linalg_dense_sparse_matmuls.h" -/* Test YT_kron_I_alloc and YT_kron_I_fill_values +/* Test YT_kron_I_alloc and YT_kron_I_fill_vals * * C = (Y^T kron I_m) @ J * m=2, k=2, n=2, p=3 @@ -57,7 +57,7 @@ const char *test_YT_kron_I(void) mu_assert("C row ptrs", cmp_int_array(C->p, exp_p, 5)); mu_assert("C col indices", cmp_int_array(C->i, exp_i, 8)); - YT_kron_I_fill_values(m, k, n, Y, J, C); + YT_kron_I_fill_vals(m, k, n, Y, J, C); mu_assert("C values", cmp_double_array(C->x, exp_x, 8)); free_csr_matrix(C); @@ -110,7 +110,7 @@ const char *test_YT_kron_I_larger(void) mu_assert("C2 row ptrs", cmp_int_array(C->p, exp_p, 10)); mu_assert("C2 col indices", cmp_int_array(C->i, exp_i, 18)); - YT_kron_I_fill_values(m, k, n, Y, J, C); + YT_kron_I_fill_vals(m, k, n, Y, J, C); mu_assert("C2 values", 
cmp_double_array(C->x, exp_x, 18)); free_csr_matrix(C); @@ -118,7 +118,7 @@ const char *test_YT_kron_I_larger(void) return NULL; } -/* Test I_kron_X_alloc and I_kron_X_fill_values +/* Test I_kron_X_alloc and I_kron_X_fill_vals * * C = (I_n kron X) @ J * m=2, k=2, n=2, p=3 @@ -167,7 +167,7 @@ const char *test_I_kron_X(void) mu_assert("C row ptrs", cmp_int_array(C->p, exp_p, 5)); mu_assert("C col indices", cmp_int_array(C->i, exp_i, 10)); - I_kron_X_fill_values(m, k, n, X, J, C); + I_kron_X_fill_vals(m, k, n, X, J, C); mu_assert("C values", cmp_double_array(C->x, exp_x, 10)); free_csr_matrix(C); @@ -219,7 +219,7 @@ const char *test_I_kron_X_larger(void) mu_assert("C2 row ptrs", cmp_int_array(C->p, exp_p, 7)); mu_assert("C2 col indices", cmp_int_array(C->i, exp_i, 21)); - I_kron_X_fill_values(m, k, n, X, J, C); + I_kron_X_fill_vals(m, k, n, X, J, C); mu_assert("C2 values", cmp_double_array(C->x, exp_x, 21)); free_csr_matrix(C); From faff41762ba40742fed03c39ec04e947dc5d0e0b Mon Sep 17 00:00:00 2001 From: dance858 Date: Mon, 30 Mar 2026 10:35:38 -0700 Subject: [PATCH 06/13] change order of two functions --- src/bivariate_full_dom/matmul.c | 108 +++++++++++++++++--------------- 1 file changed, 56 insertions(+), 52 deletions(-) diff --git a/src/bivariate_full_dom/matmul.c b/src/bivariate_full_dom/matmul.c index 4344a8f..3d51722 100644 --- a/src/bivariate_full_dom/matmul.c +++ b/src/bivariate_full_dom/matmul.c @@ -243,6 +243,10 @@ static void eval_jacobian_chain_rule(expr *node) sum_csr_fill_vals(mnode->term1_CSR, mnode->term2_CSR, node->jacobian); } +// ------------------------------------------------------------------------------------ +// Hessian initialization and evaluation for the case where no chain rule is needed, +// ie., when both children are different leaf variables. 
+// ------------------------------------------------------------------------------------ static void wsum_hess_init_no_chain_rule(expr *node) { expr *x = node->left; @@ -319,58 +323,6 @@ static void wsum_hess_init_no_chain_rule(expr *node) assert(nnz == total_nnz); } -static void wsum_hess_init_chain_rule(expr *node) -{ - expr *x = node->left; - expr *y = node->right; - matmul_expr *mnode = (matmul_expr *) node; - int m = x->d1; - int k = x->d2; - int n = y->d2; - - jacobian_csc_init(x); - jacobian_csc_init(y); - CSC_Matrix *Jf = x->work->jacobian_csc; - CSC_Matrix *Jg = y->work->jacobian_csc; - - /* build cross-Hessian B sparsity */ - mnode->B = build_cross_hessian_sparsity(m, k, n); - - /* C = J_f^T @ B @ J_g: - * step 1: BJ_g = B @ J_g */ - mnode->BJ_g = csr_csc_matmul_alloc(mnode->B, Jg); - mnode->BJ_g_csc_work = - (int *) malloc(MAX(mnode->BJ_g->m, mnode->BJ_g->n) * sizeof(int)); - mnode->BJ_g_CSC = csr_to_csc_alloc(mnode->BJ_g, mnode->BJ_g_csc_work); - - /* step 2: C = J_f^T @ BJ_g via BTA (B^T D A with D=I) */ - mnode->C = BTA_alloc(mnode->BJ_g_CSC, Jf); - - /* C^T */ - node->work->iwork = (int *) malloc(mnode->C->m * sizeof(int)); - mnode->CT = AT_alloc(mnode->C, node->work->iwork); - - /* allocate weight backprop workspace */ - if (!x->is_affine(x) || !y->is_affine(y)) - { - node->work->dwork = - (double *) malloc(MAX(x->size, y->size) * sizeof(double)); - } - - /* init child Hessians */ - wsum_hess_init(x); - wsum_hess_init(y); - - /* merge 4 sparsity patterns */ - int *maps[4]; - node->wsum_hess = sum_4_csr_fill_sparsity_and_idx_maps( - mnode->C, mnode->CT, x->wsum_hess, y->wsum_hess, maps); - mnode->idx_map_C = maps[0]; - mnode->idx_map_CT = maps[1]; - mnode->idx_map_Hf = maps[2]; - mnode->idx_map_Hg = maps[3]; -} - static void eval_wsum_hess_no_chain_rule(expr *node, const double *w) { expr *x = node->left; @@ -428,6 +380,58 @@ static void eval_wsum_hess_no_chain_rule(expr *node, const double *w) } } +static void wsum_hess_init_chain_rule(expr *node) 
+{ + expr *x = node->left; + expr *y = node->right; + matmul_expr *mnode = (matmul_expr *) node; + int m = x->d1; + int k = x->d2; + int n = y->d2; + + jacobian_csc_init(x); + jacobian_csc_init(y); + CSC_Matrix *Jf = x->work->jacobian_csc; + CSC_Matrix *Jg = y->work->jacobian_csc; + + /* build cross-Hessian B sparsity */ + mnode->B = build_cross_hessian_sparsity(m, k, n); + + /* C = J_f^T @ B @ J_g: + * step 1: BJ_g = B @ J_g */ + mnode->BJ_g = csr_csc_matmul_alloc(mnode->B, Jg); + mnode->BJ_g_csc_work = + (int *) malloc(MAX(mnode->BJ_g->m, mnode->BJ_g->n) * sizeof(int)); + mnode->BJ_g_CSC = csr_to_csc_alloc(mnode->BJ_g, mnode->BJ_g_csc_work); + + /* step 2: C = J_f^T @ BJ_g via BTA (B^T D A with D=I) */ + mnode->C = BTA_alloc(mnode->BJ_g_CSC, Jf); + + /* C^T */ + node->work->iwork = (int *) malloc(mnode->C->m * sizeof(int)); + mnode->CT = AT_alloc(mnode->C, node->work->iwork); + + /* allocate weight backprop workspace */ + if (!x->is_affine(x) || !y->is_affine(y)) + { + node->work->dwork = + (double *) malloc(MAX(x->size, y->size) * sizeof(double)); + } + + /* init child Hessians */ + wsum_hess_init(x); + wsum_hess_init(y); + + /* merge 4 sparsity patterns */ + int *maps[4]; + node->wsum_hess = sum_4_csr_fill_sparsity_and_idx_maps( + mnode->C, mnode->CT, x->wsum_hess, y->wsum_hess, maps); + mnode->idx_map_C = maps[0]; + mnode->idx_map_CT = maps[1]; + mnode->idx_map_Hf = maps[2]; + mnode->idx_map_Hg = maps[3]; +} + static void eval_wsum_hess_chain_rule(expr *node, const double *w) { expr *x = node->left; From 09c6427c21c03a320b9fbf0945bc58f24ce98849 Mon Sep 17 00:00:00 2001 From: dance858 Date: Mon, 30 Mar 2026 11:03:19 -0700 Subject: [PATCH 07/13] better infrastructure --- include/subexpr.h | 6 +- include/utils/mini_numpy.h | 12 ++ src/bivariate_full_dom/matmul.c | 190 +++++++++++++++----------------- src/utils/mini_numpy.c | 34 ++++++ 4 files changed, 135 insertions(+), 107 deletions(-) diff --git a/include/subexpr.h b/include/subexpr.h index 3ef8f86..6fb00c5 
100644 --- a/include/subexpr.h +++ b/include/subexpr.h @@ -142,9 +142,9 @@ typedef struct matmul_expr /* Hessian workspace (composite only) */ CSR_Matrix *B; /* cross-Hessian B(w), mk x kn */ - CSR_Matrix *BJ_g; /* B @ J_g */ - CSC_Matrix *BJ_g_CSC; /* BJ_g in CSC */ - int *BJ_g_csc_work; /* CSR-to-CSC workspace */ + CSR_Matrix *BJg; /* B @ J_g */ + CSC_Matrix *BJg_CSC; /* BJg in CSC */ + int *BJg_csc_work; /* CSR-to-CSC workspace */ CSR_Matrix *C; /* J_f^T @ B @ J_g */ CSR_Matrix *CT; /* C^T */ int *idx_map_C; diff --git a/include/utils/mini_numpy.h b/include/utils/mini_numpy.h index b29c673..bb512c2 100644 --- a/include/utils/mini_numpy.h +++ b/include/utils/mini_numpy.h @@ -23,4 +23,16 @@ void scaled_ones(double *result, int size, double value); /* Naive implementation of Z = X @ Y, X is m x k, Y is k x n, Z is m x n */ void mat_mat_mult(const double *X, const double *Y, double *Z, int m, int k, int n); +/* Compute v = (Y kron I_m) @ w where Y is k x n (col-major), w has + length m*n, and v has length m*k. Equivalently, reshape w as the + m x n matrix W (col-major) and compute v = vec(W @ Y^T). */ +void Y_kron_I_vec(int m, int k, int n, const double *Y, const double *w, + double *v); + +/* Compute v = (I_n kron X^T) @ w where X is m x k (col-major), w has + length m*n, and v has length k*n. Equivalently, reshape w as the + m x n matrix W (col-major) and compute v = vec(X^T @ W). */ +void I_kron_XT_vec(int m, int k, int n, const double *X, const double *w, + double *v); + #endif /* MINI_NUMPY_H */ diff --git a/src/bivariate_full_dom/matmul.c b/src/bivariate_full_dom/matmul.c index 3d51722..9847f42 100644 --- a/src/bivariate_full_dom/matmul.c +++ b/src/bivariate_full_dom/matmul.c @@ -31,8 +31,13 @@ // ------------------------------------------------------------------------------ // Helpers for the cross-Hessian B(w) of the bilinear map X @ Y. -// B is mk x kn with B[row + j*m, j + col*k] = w[row + col*m]. -// Each row has exactly n nonzeros. 
+// B(w) is the mk x kn weighted cross-Hessian of the bilinear map Z = XY. +// It captures d^2(w^T vec(Z)) / d(vec(X)) d(vec(Y)). +// +// Entry: B[row + j*m, j + col*k] = w[row + col*m]. +// Each row has exactly n nonzeros. All k block-rows (one per j) have +// the same values (columns of W = reshape(w, m, n)), but at different +// column positions (offset by j in the Y-variable indexing). // ------------------------------------------------------------------------------ static CSR_Matrix *build_cross_hessian_sparsity(int m, int k, int n) @@ -380,146 +385,121 @@ static void eval_wsum_hess_no_chain_rule(expr *node, const double *w) } } +// ------------------------------------------------------------------------------------ +// Hessian chain rule for Z = f(u) @ g(u). +// H = C + C^T + H_f(v_f) + H_g(v_g) where: +// +// C = J_f^T B(w) J_g cross term (B is the weighted cross-Hessian) +// v_f = (Y kron I_m) w backpropagated weights for left child +// v_g = (I_n kron X^T) w backpropagated weights for right child +// H_f(v_f), H_g(v_g) child Hessians evaluated with transformed weights +// ------------------------------------------------------------------------------------ static void wsum_hess_init_chain_rule(expr *node) { - expr *x = node->left; - expr *y = node->right; + expr *f = node->left; + expr *g = node->right; matmul_expr *mnode = (matmul_expr *) node; - int m = x->d1; - int k = x->d2; - int n = y->d2; - - jacobian_csc_init(x); - jacobian_csc_init(y); - CSC_Matrix *Jf = x->work->jacobian_csc; - CSC_Matrix *Jg = y->work->jacobian_csc; + int m = f->d1; + int k = f->d2; + int n = g->d2; + CSC_Matrix *Jf = f->work->jacobian_csc; + CSC_Matrix *Jg = g->work->jacobian_csc; - /* build cross-Hessian B sparsity */ + /* initialize C = Jf^T @ B @ Jg = Jf^T @ (B @ Jg) */ mnode->B = build_cross_hessian_sparsity(m, k, n); + mnode->BJg = csr_csc_matmul_alloc(mnode->B, Jg); + int max_alloc = MAX(mnode->BJg->m, mnode->BJg->n); + mnode->BJg_csc_work = (int *) malloc(max_alloc * 
sizeof(int)); + mnode->BJg_CSC = csr_to_csc_alloc(mnode->BJg, mnode->BJg_csc_work); + mnode->C = BTA_alloc(mnode->BJg_CSC, Jf); - /* C = J_f^T @ B @ J_g: - * step 1: BJ_g = B @ J_g */ - mnode->BJ_g = csr_csc_matmul_alloc(mnode->B, Jg); - mnode->BJ_g_csc_work = - (int *) malloc(MAX(mnode->BJ_g->m, mnode->BJ_g->n) * sizeof(int)); - mnode->BJ_g_CSC = csr_to_csc_alloc(mnode->BJ_g, mnode->BJ_g_csc_work); - - /* step 2: C = J_f^T @ BJ_g via BTA (B^T D A with D=I) */ - mnode->C = BTA_alloc(mnode->BJ_g_CSC, Jf); - - /* C^T */ + /* initialize C^T */ node->work->iwork = (int *) malloc(mnode->C->m * sizeof(int)); mnode->CT = AT_alloc(mnode->C, node->work->iwork); - /* allocate weight backprop workspace */ - if (!x->is_affine(x) || !y->is_affine(y)) - { - node->work->dwork = - (double *) malloc(MAX(x->size, y->size) * sizeof(double)); - } - - /* init child Hessians */ - wsum_hess_init(x); - wsum_hess_init(y); + /* initialize Hessians of children */ + wsum_hess_init(f); + wsum_hess_init(g); - /* merge 4 sparsity patterns */ + /* sum the four terms and fill idx maps */ int *maps[4]; node->wsum_hess = sum_4_csr_fill_sparsity_and_idx_maps( - mnode->C, mnode->CT, x->wsum_hess, y->wsum_hess, maps); + mnode->C, mnode->CT, f->wsum_hess, g->wsum_hess, maps); mnode->idx_map_C = maps[0]; mnode->idx_map_CT = maps[1]; mnode->idx_map_Hf = maps[2]; mnode->idx_map_Hg = maps[3]; + + /* allocate weight backprop workspace */ + if (!f->is_affine(f) || !g->is_affine(g)) + { + node->work->dwork = + (double *) malloc(MAX(f->size, g->size) * sizeof(double)); + } } static void eval_wsum_hess_chain_rule(expr *node, const double *w) { - expr *x = node->left; - expr *y = node->right; + expr *f = node->left; + expr *g = node->right; matmul_expr *mnode = (matmul_expr *) node; - int m = x->d1; - int k = x->d2; - int n = y->d2; - bool is_x_affine = x->is_affine(x); - bool is_y_affine = y->is_affine(y); + int m = f->d1; + int k = f->d2; + int n = g->d2; + bool is_f_affine = f->is_affine(f); + bool is_g_affine 
= g->is_affine(g); /* refresh child Jacobian CSC values (cache if affine) */ - if (!x->work->jacobian_csc_filled) + if (!f->work->jacobian_csc_filled) { - csr_to_csc_fill_vals(x->jacobian, x->work->jacobian_csc, x->work->csc_work); - if (is_x_affine) + csr_to_csc_fill_vals(f->jacobian, f->work->jacobian_csc, f->work->csc_work); + if (is_f_affine) { - x->work->jacobian_csc_filled = true; + f->work->jacobian_csc_filled = true; } } - if (!y->work->jacobian_csc_filled) + if (!g->work->jacobian_csc_filled) { - csr_to_csc_fill_vals(y->jacobian, y->work->jacobian_csc, y->work->csc_work); - if (is_y_affine) + csr_to_csc_fill_vals(g->jacobian, g->work->jacobian_csc, g->work->csc_work); + if (is_g_affine) { - y->work->jacobian_csc_filled = true; + g->work->jacobian_csc_filled = true; } } - CSC_Matrix *Jf = x->work->jacobian_csc; - CSC_Matrix *Jg = y->work->jacobian_csc; + CSC_Matrix *Jf = f->work->jacobian_csc; + CSC_Matrix *Jg = g->work->jacobian_csc; /* compute C = J_f^T @ B(w) @ J_g */ fill_cross_hessian_values(m, k, n, w, mnode->B); - csr_csc_matmul_fill_vals(mnode->B, Jg, mnode->BJ_g); - csr_to_csc_fill_vals(mnode->BJ_g, mnode->BJ_g_CSC, mnode->BJ_g_csc_work); - BTDA_fill_vals(mnode->BJ_g_CSC, Jf, NULL, mnode->C); + csr_csc_matmul_fill_vals(mnode->B, Jg, mnode->BJg); + csr_to_csc_fill_vals(mnode->BJg, mnode->BJg_CSC, mnode->BJg_csc_work); + BTDA_fill_vals(mnode->BJg_CSC, Jf, NULL, mnode->C); - /* C^T */ + /* compute CT */ AT_fill_vals(mnode->C, mnode->CT, node->work->iwork); /* backpropagate weights and recurse into children */ - if (!is_x_affine) + if (!is_f_affine) { - /* v_f = vec(W @ Y^T): - * v_f[row + j*m] = sum_col Y[j,col] * w[row + col*m] */ - double *v_f = node->work->dwork; - for (int j = 0; j < k; j++) - { - for (int row = 0; row < m; row++) - { - double sum = 0.0; - for (int col = 0; col < n; col++) - { - sum += y->value[j + col * k] * w[row + col * m]; - } - v_f[row + j * m] = sum; - } - } - x->eval_wsum_hess(x, v_f); + Y_kron_I_vec(m, k, n, g->value, w, + 
node->work->dwork); + f->eval_wsum_hess(f, node->work->dwork); } - if (!is_y_affine) + if (!is_g_affine) { - /* v_g = vec(X^T @ W): - * v_g[j + col*k] = sum_row X[row,j] * w[row + col*m] */ - double *v_g = node->work->dwork; - for (int col = 0; col < n; col++) - { - for (int j = 0; j < k; j++) - { - double sum = 0.0; - for (int row = 0; row < m; row++) - { - sum += x->value[row + j * m] * w[row + col * m]; - } - v_g[j + col * k] = sum; - } - } - y->eval_wsum_hess(y, v_g); + I_kron_XT_vec(m, k, n, f->value, w, + node->work->dwork); + g->eval_wsum_hess(g, node->work->dwork); } /* accumulate H = C + C^T + H_f + H_g */ memset(node->wsum_hess->x, 0, node->wsum_hess->nnz * sizeof(double)); accumulate_mapped(node->wsum_hess->x, mnode->C, mnode->idx_map_C); accumulate_mapped(node->wsum_hess->x, mnode->CT, mnode->idx_map_CT); - accumulate_mapped(node->wsum_hess->x, x->wsum_hess, mnode->idx_map_Hf); - accumulate_mapped(node->wsum_hess->x, y->wsum_hess, mnode->idx_map_Hg); + accumulate_mapped(node->wsum_hess->x, f->wsum_hess, mnode->idx_map_Hf); + accumulate_mapped(node->wsum_hess->x, g->wsum_hess, mnode->idx_map_Hg); } expr *new_matmul(expr *x, expr *y) @@ -541,17 +521,19 @@ expr *new_matmul(expr *x, expr *y) bool use_chain_rule = !(x->var_id != NOT_A_VARIABLE && y->var_id != NOT_A_VARIABLE && x->var_id != y->var_id); - jacobian_init_fn jac_init = - use_chain_rule ? jacobian_init_chain_rule : jacobian_init_no_chain_rule; - eval_jacobian_fn jac_eval = - use_chain_rule ? eval_jacobian_chain_rule : eval_jacobian_no_chain_rule; - wsum_hess_init_fn hess_init = - use_chain_rule ? wsum_hess_init_chain_rule : wsum_hess_init_no_chain_rule; - wsum_hess_fn hess_eval = - use_chain_rule ? 
eval_wsum_hess_chain_rule : eval_wsum_hess_no_chain_rule; - - init_expr(node, x->d1, y->d2, x->n_vars, forward, jac_init, jac_eval, is_affine, - hess_init, hess_eval, NULL); + if (!use_chain_rule) + { + init_expr(node, x->d1, y->d2, x->n_vars, forward, + jacobian_init_no_chain_rule, eval_jacobian_no_chain_rule, + is_affine, wsum_hess_init_no_chain_rule, + eval_wsum_hess_no_chain_rule, NULL); + } + else + { + init_expr(node, x->d1, y->d2, x->n_vars, forward, jacobian_init_chain_rule, + eval_jacobian_chain_rule, is_affine, wsum_hess_init_chain_rule, + eval_wsum_hess_chain_rule, NULL); + } /* Set children */ node->left = x; diff --git a/src/utils/mini_numpy.c b/src/utils/mini_numpy.c index c30073c..656cc3f 100644 --- a/src/utils/mini_numpy.c +++ b/src/utils/mini_numpy.c @@ -68,3 +68,37 @@ void mat_mat_mult(const double *X, const double *Y, double *Z, int m, int k, int } } } + +void Y_kron_I_vec(int m, int k, int n, const double *Y, const double *w, + double *v) +{ + for (int j = 0; j < k; j++) + { + for (int row = 0; row < m; row++) + { + double sum = 0.0; + for (int col = 0; col < n; col++) + { + sum += Y[j + col * k] * w[row + col * m]; + } + v[row + j * m] = sum; + } + } +} + +void I_kron_XT_vec(int m, int k, int n, const double *X, const double *w, + double *v) +{ + for (int col = 0; col < n; col++) + { + for (int j = 0; j < k; j++) + { + double sum = 0.0; + for (int row = 0; row < m; row++) + { + sum += X[row + j * m] * w[row + col * m]; + } + v[j + col * k] = sum; + } + } +} From e5e33ecc96eb81bce5174f2a0903d1e0764e9622 Mon Sep 17 00:00:00 2001 From: dance858 Date: Mon, 30 Mar 2026 11:06:01 -0700 Subject: [PATCH 08/13] even better infrastructure --- include/subexpr.h | 6 +++--- include/utils/mini_numpy.h | 6 ++---- src/bivariate_full_dom/matmul.c | 20 ++++++++++---------- src/utils/mini_numpy.c | 6 ++---- 4 files changed, 17 insertions(+), 21 deletions(-) diff --git a/include/subexpr.h b/include/subexpr.h index 6fb00c5..44a964b 100644 --- a/include/subexpr.h +++ 
b/include/subexpr.h @@ -141,12 +141,12 @@ typedef struct matmul_expr CSR_Matrix *term2_CSR; /* (I_n x X) @ J_g */ /* Hessian workspace (composite only) */ - CSR_Matrix *B; /* cross-Hessian B(w), mk x kn */ + CSR_Matrix *B; /* cross-Hessian B(w), mk x kn */ CSR_Matrix *BJg; /* B @ J_g */ CSC_Matrix *BJg_CSC; /* BJg in CSC */ int *BJg_csc_work; /* CSR-to-CSC workspace */ - CSR_Matrix *C; /* J_f^T @ B @ J_g */ - CSR_Matrix *CT; /* C^T */ + CSR_Matrix *C; /* J_f^T @ B @ J_g */ + CSR_Matrix *CT; /* C^T */ int *idx_map_C; int *idx_map_CT; int *idx_map_Hf; diff --git a/include/utils/mini_numpy.h b/include/utils/mini_numpy.h index bb512c2..931f516 100644 --- a/include/utils/mini_numpy.h +++ b/include/utils/mini_numpy.h @@ -26,13 +26,11 @@ void mat_mat_mult(const double *X, const double *Y, double *Z, int m, int k, int /* Compute v = (Y kron I_m) @ w where Y is k x n (col-major), w has length m*n, and v has length m*k. Equivalently, reshape w as the m x n matrix W (col-major) and compute v = vec(W @ Y^T). */ -void Y_kron_I_vec(int m, int k, int n, const double *Y, const double *w, - double *v); +void Y_kron_I_vec(int m, int k, int n, const double *Y, const double *w, double *v); /* Compute v = (I_n kron X^T) @ w where X is m x k (col-major), w has length m*n, and v has length k*n. Equivalently, reshape w as the m x n matrix W (col-major) and compute v = vec(X^T @ W). 
*/ -void I_kron_XT_vec(int m, int k, int n, const double *X, const double *w, - double *v); +void I_kron_XT_vec(int m, int k, int n, const double *X, const double *w, double *v); #endif /* MINI_NUMPY_H */ diff --git a/src/bivariate_full_dom/matmul.c b/src/bivariate_full_dom/matmul.c index 9847f42..923722c 100644 --- a/src/bivariate_full_dom/matmul.c +++ b/src/bivariate_full_dom/matmul.c @@ -448,28 +448,29 @@ static void eval_wsum_hess_chain_rule(expr *node, const double *w) int n = g->d2; bool is_f_affine = f->is_affine(f); bool is_g_affine = g->is_affine(g); + CSC_Matrix *Jf = f->work->jacobian_csc; + CSC_Matrix *Jg = g->work->jacobian_csc; /* refresh child Jacobian CSC values (cache if affine) */ if (!f->work->jacobian_csc_filled) { - csr_to_csc_fill_vals(f->jacobian, f->work->jacobian_csc, f->work->csc_work); + csr_to_csc_fill_vals(f->jacobian, Jf, f->work->csc_work); if (is_f_affine) { f->work->jacobian_csc_filled = true; } } + + /* refresh child Jacobian CSC values (cache if affine) */ if (!g->work->jacobian_csc_filled) { - csr_to_csc_fill_vals(g->jacobian, g->work->jacobian_csc, g->work->csc_work); + csr_to_csc_fill_vals(g->jacobian, Jg, g->work->csc_work); if (is_g_affine) { g->work->jacobian_csc_filled = true; } } - CSC_Matrix *Jf = f->work->jacobian_csc; - CSC_Matrix *Jg = g->work->jacobian_csc; - /* compute C = J_f^T @ B(w) @ J_g */ fill_cross_hessian_values(m, k, n, w, mnode->B); csr_csc_matmul_fill_vals(mnode->B, Jg, mnode->BJg); @@ -479,18 +480,17 @@ static void eval_wsum_hess_chain_rule(expr *node, const double *w) /* compute CT */ AT_fill_vals(mnode->C, mnode->CT, node->work->iwork); - /* backpropagate weights and recurse into children */ + /* compute Hessian of f */ if (!is_f_affine) { - Y_kron_I_vec(m, k, n, g->value, w, - node->work->dwork); + Y_kron_I_vec(m, k, n, g->value, w, node->work->dwork); f->eval_wsum_hess(f, node->work->dwork); } + /* compute Hessian of g */ if (!is_g_affine) { - I_kron_XT_vec(m, k, n, f->value, w, - node->work->dwork); 
+ I_kron_XT_vec(m, k, n, f->value, w, node->work->dwork); g->eval_wsum_hess(g, node->work->dwork); } diff --git a/src/utils/mini_numpy.c b/src/utils/mini_numpy.c index 656cc3f..27f50d8 100644 --- a/src/utils/mini_numpy.c +++ b/src/utils/mini_numpy.c @@ -69,8 +69,7 @@ void mat_mat_mult(const double *X, const double *Y, double *Z, int m, int k, int } } -void Y_kron_I_vec(int m, int k, int n, const double *Y, const double *w, - double *v) +void Y_kron_I_vec(int m, int k, int n, const double *Y, const double *w, double *v) { for (int j = 0; j < k; j++) { @@ -86,8 +85,7 @@ void Y_kron_I_vec(int m, int k, int n, const double *Y, const double *w, } } -void I_kron_XT_vec(int m, int k, int n, const double *X, const double *w, - double *v) +void I_kron_XT_vec(int m, int k, int n, const double *X, const double *w, double *v) { for (int col = 0; col < n; col++) { From 021eb7b4b9d1e293002ebfc231c0308be72cdb9a Mon Sep 17 00:00:00 2001 From: dance858 Date: Mon, 30 Mar 2026 11:10:35 -0700 Subject: [PATCH 09/13] free' --- src/bivariate_full_dom/matmul.c | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/src/bivariate_full_dom/matmul.c b/src/bivariate_full_dom/matmul.c index 923722c..236c9c7 100644 --- a/src/bivariate_full_dom/matmul.c +++ b/src/bivariate_full_dom/matmul.c @@ -106,6 +106,25 @@ static void forward(expr *node, const double *u) mat_mat_mult(x->value, y->value, node->value, x->d1, x->d2, y->d2); } +static void free_matmul_data(expr *node) +{ + matmul_expr *mnode = (matmul_expr *) node; + /* Jacobian workspace */ + free_csr_matrix(mnode->term1_CSR); + free_csr_matrix(mnode->term2_CSR); + /* Hessian workspace */ + free_csr_matrix(mnode->B); + free_csr_matrix(mnode->BJg); + free_csc_matrix(mnode->BJg_CSC); + free(mnode->BJg_csc_work); + free_csr_matrix(mnode->C); + free_csr_matrix(mnode->CT); + free(mnode->idx_map_C); + free(mnode->idx_map_CT); + free(mnode->idx_map_Hf); + free(mnode->idx_map_Hg); +} + static bool is_affine(const 
expr *node) { (void) node; @@ -526,13 +545,13 @@ expr *new_matmul(expr *x, expr *y) init_expr(node, x->d1, y->d2, x->n_vars, forward, jacobian_init_no_chain_rule, eval_jacobian_no_chain_rule, is_affine, wsum_hess_init_no_chain_rule, - eval_wsum_hess_no_chain_rule, NULL); + eval_wsum_hess_no_chain_rule, free_matmul_data); } else { init_expr(node, x->d1, y->d2, x->n_vars, forward, jacobian_init_chain_rule, eval_jacobian_chain_rule, is_affine, wsum_hess_init_chain_rule, - eval_wsum_hess_chain_rule, NULL); + eval_wsum_hess_chain_rule, free_matmul_data); } /* Set children */ From 1096984802bd133943e0b41019334185e626fe8b Mon Sep 17 00:00:00 2001 From: dance858 Date: Mon, 30 Mar 2026 11:17:56 -0700 Subject: [PATCH 10/13] refactor accumulator --- src/bivariate_full_dom/matmul.c | 17 ++++------------- src/bivariate_full_dom/multiply.c | 18 ++++-------------- src/utils/CSR_sum.c | 9 +++------ 3 files changed, 11 insertions(+), 33 deletions(-) diff --git a/src/bivariate_full_dom/matmul.c b/src/bivariate_full_dom/matmul.c index 236c9c7..aaa8fa1 100644 --- a/src/bivariate_full_dom/matmul.c +++ b/src/bivariate_full_dom/matmul.c @@ -78,15 +78,6 @@ static void fill_cross_hessian_values(int m, int k, int n, const double *w, } } -static void accumulate_mapped(double *dest, const CSR_Matrix *src, - const int *idx_map) -{ - for (int j = 0; j < src->nnz; j++) - { - dest[idx_map[j]] += src->x[j]; - } -} - // ------------------------------------------------------------------------------ // Implementation of matrix multiplication: Z = X @ Y // where X is m x k and Y is k x n, producing Z which is m x n @@ -515,10 +506,10 @@ static void eval_wsum_hess_chain_rule(expr *node, const double *w) /* accumulate H = C + C^T + H_f + H_g */ memset(node->wsum_hess->x, 0, node->wsum_hess->nnz * sizeof(double)); - accumulate_mapped(node->wsum_hess->x, mnode->C, mnode->idx_map_C); - accumulate_mapped(node->wsum_hess->x, mnode->CT, mnode->idx_map_CT); - accumulate_mapped(node->wsum_hess->x, f->wsum_hess, 
mnode->idx_map_Hf); - accumulate_mapped(node->wsum_hess->x, g->wsum_hess, mnode->idx_map_Hg); + idx_map_accumulator(mnode->C, mnode->idx_map_C, node->wsum_hess->x); + idx_map_accumulator(mnode->CT, mnode->idx_map_CT, node->wsum_hess->x); + idx_map_accumulator(f->wsum_hess, mnode->idx_map_Hf, node->wsum_hess->x); + idx_map_accumulator(g->wsum_hess, mnode->idx_map_Hg, node->wsum_hess->x); } expr *new_matmul(expr *x, expr *y) diff --git a/src/bivariate_full_dom/multiply.c b/src/bivariate_full_dom/multiply.c index 3ce8299..98139c1 100644 --- a/src/bivariate_full_dom/multiply.c +++ b/src/bivariate_full_dom/multiply.c @@ -23,16 +23,6 @@ #include #include -/* Scatter-add src->x into dest using precomputed index map */ -static void accumulate_mapped(double *dest, const CSR_Matrix *src, - const int *idx_map) -{ - for (int j = 0; j < src->nnz; j++) - { - dest[idx_map[j]] += src->x[j]; - } -} - // ------------------------------------------------------------------------------ // Implementation of elementwise multiplication when both arguments are vectors. 
// If one argument is a scalar variable, the broadcasting should be represented @@ -267,10 +257,10 @@ static void eval_wsum_hess(expr *node, const double *w) // compute H = C + C^T + term2 + term3 // --------------------------------------------------------------- memset(node->wsum_hess->x, 0, node->wsum_hess->nnz * sizeof(double)); - accumulate_mapped(node->wsum_hess->x, C, mul_node->idx_map_C); - accumulate_mapped(node->wsum_hess->x, CT, mul_node->idx_map_CT); - accumulate_mapped(node->wsum_hess->x, x->wsum_hess, mul_node->idx_map_Hx); - accumulate_mapped(node->wsum_hess->x, y->wsum_hess, mul_node->idx_map_Hy); + idx_map_accumulator(C, mul_node->idx_map_C, node->wsum_hess->x); + idx_map_accumulator(CT, mul_node->idx_map_CT, node->wsum_hess->x); + idx_map_accumulator(x->wsum_hess, mul_node->idx_map_Hx, node->wsum_hess->x); + idx_map_accumulator(y->wsum_hess, mul_node->idx_map_Hy, node->wsum_hess->x); } } diff --git a/src/utils/CSR_sum.c b/src/utils/CSR_sum.c index 9dc2356..b6cb59d 100644 --- a/src/utils/CSR_sum.c +++ b/src/utils/CSR_sum.c @@ -571,13 +571,10 @@ void sum_evenly_spaced_rows_csr_fill_sparsity_and_idx_map(const CSR_Matrix *A, void idx_map_accumulator(const CSR_Matrix *A, const int *idx_map, double *accumulator) { - /* don't forget to initialze accumulator to 0 before calling this */ - for (int row = 0; row < A->m; row++) + /* don't forget to initialize accumulator to 0 before calling this */ + for (int j = 0; j < A->nnz; j++) { - for (int j = A->p[row]; j < A->p[row + 1]; j++) - { - accumulator[idx_map[j]] += A->x[j]; - } + accumulator[idx_map[j]] += A->x[j]; } } From 25bb0ea1c101258f68961c29b2bf2acb9c5e0db6 Mon Sep 17 00:00:00 2001 From: dance858 Date: Mon, 30 Mar 2026 12:00:22 -0700 Subject: [PATCH 11/13] minor --- src/utils/linalg_dense_sparse_matmuls.c | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/utils/linalg_dense_sparse_matmuls.c b/src/utils/linalg_dense_sparse_matmuls.c index 0126a6e..3c92bad 100644 --- 
a/src/utils/linalg_dense_sparse_matmuls.c +++ b/src/utils/linalg_dense_sparse_matmuls.c @@ -221,6 +221,7 @@ CSR_Matrix *YT_kron_I_alloc(int m, int k, int n, const CSC_Matrix *J) void YT_kron_I_fill_vals(int m, int k, int n, const double *Y, const CSC_Matrix *J, CSR_Matrix *C) { + (void) n; assert(C->m == m * n); /* C[i, j] = sum_l Y[l, blk] * J[blk_row + l*m, j] * where blk_row = i % m, blk = i / m */ @@ -309,6 +310,7 @@ CSR_Matrix *I_kron_X_alloc(int m, int k, int n, const CSC_Matrix *J) void I_kron_X_fill_vals(int m, int k, int n, const double *X, const CSC_Matrix *J, CSR_Matrix *C) { + (void) n; assert(C->m == m * n); /* C[i, j] = sum_l X[blk_row + l*m] * J[blk*k + l, j] * where blk = i / m, blk_row = i % m */ From 96d46e3df367b79a1d6a714d5d46df2aa2dee12c Mon Sep 17 00:00:00 2001 From: dance858 Date: Mon, 30 Mar 2026 12:44:56 -0700 Subject: [PATCH 12/13] redo name change --- include/utils/CSC_Matrix.h | 12 +++++----- include/utils/CSR_Matrix.h | 6 ++--- include/utils/CSR_sum.h | 6 ++--- include/utils/linalg_dense_sparse_matmuls.h | 6 ++--- include/utils/linalg_sparse_matmuls.h | 23 ++++++++----------- src/affine/add.c | 6 ++--- src/affine/hstack.c | 2 +- src/affine/left_matmul.c | 4 ++-- src/bivariate_full_dom/matmul.c | 23 ++++++++++--------- src/bivariate_full_dom/multiply.c | 16 ++++++------- src/bivariate_restricted_dom/quad_over_lin.c | 5 ++-- src/elementwise_full_dom/common.c | 22 +++++++++--------- src/other/quad_form.c | 14 +++++------ src/utils/CSC_Matrix.c | 14 +++++------ src/utils/CSR_Matrix.c | 6 ++--- src/utils/CSR_sum.c | 12 +++++----- src/utils/dense_matrix.c | 2 +- src/utils/linalg_dense_sparse_matmuls.c | 10 ++++---- src/utils/linalg_sparse_matmuls.c | 8 +++---- src/utils/sparse_matrix.c | 2 +- tests/utils/test_csc_matrix.h | 6 ++--- tests/utils/test_csr_csc_conversion.h | 12 +++++----- tests/utils/test_csr_matrix.h | 4 ++-- tests/utils/test_linalg_sparse_matmuls.h | 2 +- .../test_linalg_utils_matmul_chain_rule.h | 12 +++++----- 25 files 
changed, 117 insertions(+), 118 deletions(-) diff --git a/include/utils/CSC_Matrix.h b/include/utils/CSC_Matrix.h index ae62aab..def4a86 100644 --- a/include/utils/CSC_Matrix.h +++ b/include/utils/CSC_Matrix.h @@ -39,27 +39,27 @@ CSR_Matrix *BTA_alloc(const CSC_Matrix *A, const CSC_Matrix *B); CSC_Matrix *symBA_alloc(const CSR_Matrix *B, const CSC_Matrix *A); /* Compute values for C = A^T D A (null d corresponds to D as identity) */ -void ATDA_fill_vals(const CSC_Matrix *A, const double *d, CSR_Matrix *C); +void ATDA_fill_values(const CSC_Matrix *A, const double *d, CSR_Matrix *C); /* Compute values for C = B^T D A (null d corresonds to D as identity) */ -void BTDA_fill_vals(const CSC_Matrix *A, const CSC_Matrix *B, const double *d, +void BTDA_fill_values(const CSC_Matrix *A, const CSC_Matrix *B, const double *d, CSR_Matrix *C); /* Fill values of C = BA. The matrix B does not have to be symmetric */ -void BA_fill_vals(const CSR_Matrix *B, const CSC_Matrix *A, CSC_Matrix *C); +void BA_fill_values(const CSR_Matrix *B, const CSC_Matrix *A, CSC_Matrix *C); /* Fill values of C = x^T A. The matrix C must have filled sparsity. 
*/ -void yTA_fill_vals(const CSC_Matrix *A, const double *x, CSR_Matrix *C); +void yTA_fill_values(const CSC_Matrix *A, const double *x, CSR_Matrix *C); /* Count nonzero columns of a CSC matrix */ int count_nonzero_cols_csc(const CSC_Matrix *A); /* convert from CSR to CSC format */ CSC_Matrix *csr_to_csc_alloc(const CSR_Matrix *A, int *iwork); -void csr_to_csc_fill_vals(const CSR_Matrix *A, CSC_Matrix *C, int *iwork); +void csr_to_csc_fill_values(const CSR_Matrix *A, CSC_Matrix *C, int *iwork); /* convert from CSC to CSR format */ CSR_Matrix *csc_to_csr_alloc(const CSC_Matrix *A, int *iwork); -void csc_to_csr_fill_vals(const CSC_Matrix *A, CSR_Matrix *C, int *iwork); +void csc_to_csr_fill_values(const CSC_Matrix *A, CSR_Matrix *C, int *iwork); #endif /* CSC_MATRIX_H */ diff --git a/include/utils/CSR_Matrix.h b/include/utils/CSR_Matrix.h index 5bc3846..df81839 100644 --- a/include/utils/CSR_Matrix.h +++ b/include/utils/CSR_Matrix.h @@ -35,7 +35,7 @@ void copy_csr_matrix(const CSR_Matrix *A, CSR_Matrix *C); /* transpose functionality (iwork must be of size A->n) */ CSR_Matrix *transpose(const CSR_Matrix *A, int *iwork); CSR_Matrix *AT_alloc(const CSR_Matrix *A, int *iwork); -void AT_fill_vals(const CSR_Matrix *A, CSR_Matrix *AT, int *iwork); +void AT_fill_values(const CSR_Matrix *A, CSR_Matrix *AT, int *iwork); /* Build (I_p kron A) = blkdiag(A, A, ..., A) of size (p*A->m) x (p*A->n) */ CSR_Matrix *block_diag_repeat_csr(const CSR_Matrix *A, int p); @@ -49,7 +49,7 @@ void csr_matvec_wo_offset(const CSR_Matrix *A, const double *x, double *y); /* Computes values of the row matrix C = z^T A (column indices must have been pre-computed) and transposed matrix AT must be provided) */ -void csr_matvec_fill_vals(const CSR_Matrix *AT, const double *z, CSR_Matrix *C); +void csr_matvec_fill_values(const CSR_Matrix *AT, const double *z, CSR_Matrix *C); /* Insert value into CSR matrix A with just one row at col_idx. 
Assumes that A has enough space and that A does not have an element at col_idx. It does update @@ -60,7 +60,7 @@ void csr_insert_value(CSR_Matrix *A, int col_idx, double value); * d must have length m * C must be pre-allocated with same dimensions as A */ void diag_csr_mult(const double *d, const CSR_Matrix *A, CSR_Matrix *C); -void diag_csr_mult_fill_vals(const double *d, const CSR_Matrix *A, CSR_Matrix *C); +void diag_csr_mult_fill_values(const double *d, const CSR_Matrix *A, CSR_Matrix *C); /* Count number of columns with nonzero entries */ int count_nonzero_cols(const CSR_Matrix *A, bool *col_nz); diff --git a/include/utils/CSR_sum.h b/include/utils/CSR_sum.h index 6476071..abe6c84 100644 --- a/include/utils/CSR_sum.h +++ b/include/utils/CSR_sum.h @@ -19,7 +19,7 @@ void sum_csr_alloc(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C); /* Fill only the values of C = A + B, assuming C's sparsity pattern (p and i) * is already filled and matches the union of A and B per row. Does not modify * C->p, C->i, or C->nnz. */ -void sum_csr_fill_vals(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C); +void sum_csr_fill_values(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C); /* Compute C = diag(d1) * A + diag(d2) * B where A, B, C are CSR matrices */ void sum_scaled_csr_matrices(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C, @@ -28,7 +28,7 @@ void sum_scaled_csr_matrices(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matri /* Fill only the values of C = diag(d1) * A + diag(d2) * B, assuming C's sparsity * pattern (p and i) is already filled and matches the union of A and B per row. * Does not modify C->p, C->i, or C->nnz. 
*/ -void sum_scaled_csr_matrices_fill_vals(const CSR_Matrix *A, const CSR_Matrix *B, +void sum_scaled_csr_matrices_fill_values(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C, const double *d1, const double *d2); @@ -41,7 +41,7 @@ void sum_all_rows_csr_fill_sparsity_and_idx_map(const CSR_Matrix *A, CSR_Matrix int *iwork, int *idx_map); /* Fill values of summed rows using precomputed idx_map and sparsity of C */ -// void sum_all_rows_csr_fill_vals(const CSR_Matrix *A, CSR_Matrix *C, +// void sum_all_rows_csr_fill_values(const CSR_Matrix *A, CSR_Matrix *C, // const int *idx_map); /* Fill accumulator for summing rows using precomputed idx_map for each nnz of A. diff --git a/include/utils/linalg_dense_sparse_matmuls.h b/include/utils/linalg_dense_sparse_matmuls.h index cbc4634..5057424 100644 --- a/include/utils/linalg_dense_sparse_matmuls.h +++ b/include/utils/linalg_dense_sparse_matmuls.h @@ -9,18 +9,18 @@ * A is dense m x n, J is (n*p) x k in CSC, C is (m*p) x k in CSC. */ // TODO: maybe we can replace these with I_kron_X functionality? CSC_Matrix *I_kron_A_alloc(const Matrix *A, const CSC_Matrix *J, int p); -void I_kron_A_fill_vals(const Matrix *A, const CSC_Matrix *J, CSC_Matrix *C); +void I_kron_A_fill_values(const Matrix *A, const CSC_Matrix *J, CSC_Matrix *C); /* Sparsity and values of C = (Y^T kron I_m) @ J where Y is k x n, J is (m*k) x p, and C is (m*n) x p. Y is given in column-major dense format. */ CSR_Matrix *YT_kron_I_alloc(int m, int k, int n, const CSC_Matrix *J); -void YT_kron_I_fill_vals(int m, int k, int n, const double *Y, const CSC_Matrix *J, +void YT_kron_I_fill_values(int m, int k, int n, const double *Y, const CSC_Matrix *J, CSR_Matrix *C); /* Sparsity and values of C = (I_n kron X) @ J where X is m x k (col-major dense), J is (k*n) x p, and C is (m*n) x p. 
*/ CSR_Matrix *I_kron_X_alloc(int m, int k, int n, const CSC_Matrix *J); -void I_kron_X_fill_vals(int m, int k, int n, const double *X, const CSC_Matrix *J, +void I_kron_X_fill_values(int m, int k, int n, const double *X, const CSC_Matrix *J, CSR_Matrix *C); #endif /* LINALG_DENSE_SPARSE_H */ diff --git a/include/utils/linalg_sparse_matmuls.h b/include/utils/linalg_sparse_matmuls.h index 22c3576..daa6c75 100644 --- a/include/utils/linalg_sparse_matmuls.h +++ b/include/utils/linalg_sparse_matmuls.h @@ -1,9 +1,8 @@ #ifndef LINALG_H #define LINALG_H -/* Forward declarations */ -struct CSR_Matrix; -struct CSC_Matrix; +#include "CSC_Matrix.h" +#include "CSR_Matrix.h" /* Compute sparsity pattern and values for the matrix-matrix multiplication C = (I_p kron A) @ J where A is m x n, J is (n*p) x k, and C is (m*p) x k, @@ -15,31 +14,29 @@ struct CSC_Matrix; * Mathematically it corresponds to C = [A @ J1; A @ J2; ...; A @ Jp], where J = [J1; J2; ...; Jp] */ -struct CSC_Matrix *block_left_multiply_fill_sparsity(const struct CSR_Matrix *A, - const struct CSC_Matrix *J, - int p); +CSC_Matrix *block_left_multiply_fill_sparsity(const CSR_Matrix *A, + const CSC_Matrix *J, int p); -void block_left_multiply_fill_vals(const struct CSR_Matrix *A, - const struct CSC_Matrix *J, struct CSC_Matrix *C); +void block_left_multiply_fill_values(const CSR_Matrix *A, const CSC_Matrix *J, + CSC_Matrix *C); /* Compute y = kron(I_p, A) @ x where A is m x n and x is(n*p)-length vector. The output y is m*p-length vector corresponding to y = [A @ x1; A @ x2; ...; A @ xp] where x is divided into p blocks of n elements. */ -void block_left_multiply_vec(const struct CSR_Matrix *A, const double *x, double *y, +void block_left_multiply_vec(const CSR_Matrix *A, const double *x, double *y, int p); /* Fill values of C = A @ B where A is CSR, B is CSC. * C must have sparsity pattern already computed. 
*/ -void csr_csc_matmul_fill_vals(const struct CSR_Matrix *A, const struct CSC_Matrix *B, - struct CSR_Matrix *C); +void csr_csc_matmul_fill_values(const CSR_Matrix *A, const CSC_Matrix *B, + CSR_Matrix *C); /* C = A @ B where A is CSR, B is CSC. Result C is CSR. * Allocates and precomputes sparsity pattern. No workspace required. */ -struct CSR_Matrix *csr_csc_matmul_alloc(const struct CSR_Matrix *A, - const struct CSC_Matrix *B); +CSR_Matrix *csr_csc_matmul_alloc(const CSR_Matrix *A, const CSC_Matrix *B); #endif /* LINALG_H */ diff --git a/src/affine/add.c b/src/affine/add.c index 3231507..bd86da4 100644 --- a/src/affine/add.c +++ b/src/affine/add.c @@ -55,7 +55,7 @@ static void eval_jacobian(expr *node) node->right->eval_jacobian(node->right); /* sum children's jacobians */ - sum_csr_fill_vals(node->left->jacobian, node->right->jacobian, node->jacobian); + sum_csr_fill_values(node->left->jacobian, node->right->jacobian, node->jacobian); } static void wsum_hess_init_impl(expr *node) @@ -79,8 +79,8 @@ static void eval_wsum_hess(expr *node, const double *w) node->right->eval_wsum_hess(node->right, w); /* sum children's wsum_hess */ - sum_csr_fill_vals(node->left->wsum_hess, node->right->wsum_hess, - node->wsum_hess); + sum_csr_fill_values(node->left->wsum_hess, node->right->wsum_hess, + node->wsum_hess); } static bool is_affine(const expr *node) diff --git a/src/affine/hstack.c b/src/affine/hstack.c index b10356f..92f99ad 100644 --- a/src/affine/hstack.c +++ b/src/affine/hstack.c @@ -140,7 +140,7 @@ static void wsum_hess_eval(expr *node, const double *w) expr *child = hnode->args[i]; child->eval_wsum_hess(child, w + row_offset); copy_csr_matrix(H, hnode->CSR_work); - sum_csr_fill_vals(hnode->CSR_work, child->wsum_hess, H); + sum_csr_fill_values(hnode->CSR_work, child->wsum_hess, H); row_offset += child->size; } } diff --git a/src/affine/left_matmul.c b/src/affine/left_matmul.c index a9a782d..4067e65 100644 --- a/src/affine/left_matmul.c +++ 
b/src/affine/left_matmul.c @@ -106,11 +106,11 @@ static void eval_jacobian(expr *node) /* evaluate child's jacobian and convert to CSC */ x->eval_jacobian(x); - csr_to_csc_fill_vals(x->jacobian, Jchild_CSC, node->work->iwork); + csr_to_csc_fill_values(x->jacobian, Jchild_CSC, node->work->iwork); /* compute this node's jacobian: */ lnode->A->block_left_mult_values(lnode->A, Jchild_CSC, J_CSC); - csc_to_csr_fill_vals(J_CSC, node->jacobian, lnode->csc_to_csr_work); + csc_to_csr_fill_values(J_CSC, node->jacobian, lnode->csc_to_csr_work); } static void wsum_hess_init_impl(expr *node) diff --git a/src/bivariate_full_dom/matmul.c b/src/bivariate_full_dom/matmul.c index aaa8fa1..64b6d0b 100644 --- a/src/bivariate_full_dom/matmul.c +++ b/src/bivariate_full_dom/matmul.c @@ -249,13 +249,14 @@ static void eval_jacobian_chain_rule(expr *node) /* evaluate Jacobians of children */ f->eval_jacobian(f); g->eval_jacobian(g); - csr_to_csc_fill_vals(f->jacobian, f->work->jacobian_csc, f->work->csc_work); - csr_to_csc_fill_vals(g->jacobian, g->work->jacobian_csc, g->work->csc_work); + csr_to_csc_fill_values(f->jacobian, f->work->jacobian_csc, f->work->csc_work); + csr_to_csc_fill_values(g->jacobian, g->work->jacobian_csc, g->work->csc_work); /* evaluate term1, term2, and their sum */ - YT_kron_I_fill_vals(m, k, n, g->value, f->work->jacobian_csc, mnode->term1_CSR); - I_kron_X_fill_vals(m, k, n, f->value, g->work->jacobian_csc, mnode->term2_CSR); - sum_csr_fill_vals(mnode->term1_CSR, mnode->term2_CSR, node->jacobian); + YT_kron_I_fill_values(m, k, n, g->value, f->work->jacobian_csc, + mnode->term1_CSR); + I_kron_X_fill_values(m, k, n, f->value, g->work->jacobian_csc, mnode->term2_CSR); + sum_csr_fill_values(mnode->term1_CSR, mnode->term2_CSR, node->jacobian); } // ------------------------------------------------------------------------------------ @@ -464,7 +465,7 @@ static void eval_wsum_hess_chain_rule(expr *node, const double *w) /* refresh child Jacobian CSC values (cache if affine) 
*/ if (!f->work->jacobian_csc_filled) { - csr_to_csc_fill_vals(f->jacobian, Jf, f->work->csc_work); + csr_to_csc_fill_values(f->jacobian, Jf, f->work->csc_work); if (is_f_affine) { f->work->jacobian_csc_filled = true; @@ -474,7 +475,7 @@ static void eval_wsum_hess_chain_rule(expr *node, const double *w) /* refresh child Jacobian CSC values (cache if affine) */ if (!g->work->jacobian_csc_filled) { - csr_to_csc_fill_vals(g->jacobian, Jg, g->work->csc_work); + csr_to_csc_fill_values(g->jacobian, Jg, g->work->csc_work); if (is_g_affine) { g->work->jacobian_csc_filled = true; @@ -483,12 +484,12 @@ static void eval_wsum_hess_chain_rule(expr *node, const double *w) /* compute C = J_f^T @ B(w) @ J_g */ fill_cross_hessian_values(m, k, n, w, mnode->B); - csr_csc_matmul_fill_vals(mnode->B, Jg, mnode->BJg); - csr_to_csc_fill_vals(mnode->BJg, mnode->BJg_CSC, mnode->BJg_csc_work); - BTDA_fill_vals(mnode->BJg_CSC, Jf, NULL, mnode->C); + csr_csc_matmul_fill_values(mnode->B, Jg, mnode->BJg); + csr_to_csc_fill_values(mnode->BJg, mnode->BJg_CSC, mnode->BJg_csc_work); + BTDA_fill_values(mnode->BJg_CSC, Jf, NULL, mnode->C); /* compute CT */ - AT_fill_vals(mnode->C, mnode->CT, node->work->iwork); + AT_fill_values(mnode->C, mnode->CT, node->work->iwork); /* compute Hessian of f */ if (!is_f_affine) diff --git a/src/bivariate_full_dom/multiply.c b/src/bivariate_full_dom/multiply.c index 98139c1..a77a88c 100644 --- a/src/bivariate_full_dom/multiply.c +++ b/src/bivariate_full_dom/multiply.c @@ -65,8 +65,8 @@ static void eval_jacobian(expr *node) /* chain rule: the jacobian of h(x) = f(g1(x), g2(x))) is Jh = J_{f, 1} J_{g1} + * J_{f, 2} J_{g2} */ - sum_scaled_csr_matrices_fill_vals(x->jacobian, y->jacobian, node->jacobian, - y->value, x->value); + sum_scaled_csr_matrices_fill_values(x->jacobian, y->jacobian, node->jacobian, + y->value, x->value); } static void wsum_hess_init_impl(expr *node) @@ -200,8 +200,8 @@ static void eval_wsum_hess(expr *node, const double *w) // 
---------------------------------------------------------------------- if (!x->work->jacobian_csc_filled) { - csr_to_csc_fill_vals(x->jacobian, x->work->jacobian_csc, - x->work->csc_work); + csr_to_csc_fill_values(x->jacobian, x->work->jacobian_csc, + x->work->csc_work); if (is_x_affine) { @@ -211,8 +211,8 @@ static void eval_wsum_hess(expr *node, const double *w) if (!y->work->jacobian_csc_filled) { - csr_to_csc_fill_vals(y->jacobian, y->work->jacobian_csc, - y->work->csc_work); + csr_to_csc_fill_values(y->jacobian, y->work->jacobian_csc, + y->work->csc_work); if (is_y_affine) { @@ -229,8 +229,8 @@ static void eval_wsum_hess(expr *node, const double *w) elementwise_mult_expr *mul_node = (elementwise_mult_expr *) node; CSR_Matrix *C = mul_node->CSR_work1; CSR_Matrix *CT = mul_node->CSR_work2; - BTDA_fill_vals(Jg1, Jg2, w, C); - AT_fill_vals(C, CT, node->work->iwork); + BTDA_fill_values(Jg1, Jg2, w, C); + AT_fill_values(C, CT, node->work->iwork); // --------------------------------------------------------------- // compute term2 and term 3 diff --git a/src/bivariate_restricted_dom/quad_over_lin.c b/src/bivariate_restricted_dom/quad_over_lin.c index 9b78c83..bc8ea05 100644 --- a/src/bivariate_restricted_dom/quad_over_lin.c +++ b/src/bivariate_restricted_dom/quad_over_lin.c @@ -124,7 +124,8 @@ static void jacobian_init_impl(expr *node) * For a linear operator the values are constant, so fill * them once here. 
*/ jacobian_csc_init(x); - csr_to_csc_fill_vals(x->jacobian, x->work->jacobian_csc, x->work->csc_work); + csr_to_csc_fill_values(x->jacobian, x->work->jacobian_csc, + x->work->csc_work); } } @@ -163,7 +164,7 @@ static void eval_jacobian(expr *node) } /* chain rule (no derivative wrt y) using CSC format */ - yTA_fill_vals(x->work->jacobian_csc, node->work->dwork, node->jacobian); + yTA_fill_values(x->work->jacobian_csc, node->work->dwork, node->jacobian); /* insert derivative wrt y at right place (for correctness this assumes that y does not appear in the numerator, but this will always be diff --git a/src/elementwise_full_dom/common.c b/src/elementwise_full_dom/common.c index 9874bc0..74280f9 100644 --- a/src/elementwise_full_dom/common.c +++ b/src/elementwise_full_dom/common.c @@ -49,7 +49,7 @@ void eval_jacobian_elementwise(expr *node) node->local_jacobian(node, node->work->local_jac_diag); memcpy(node->work->dwork, node->work->local_jac_diag, node->size * sizeof(double)); - diag_csr_mult_fill_vals(node->work->dwork, Jg, node->jacobian); + diag_csr_mult_fill_values(node->work->dwork, Jg, node->jacobian); } } @@ -122,25 +122,25 @@ void eval_wsum_hess_elementwise(expr *node, const double *w) { if (!child->work->jacobian_csc_filled) { - csr_to_csc_fill_vals(child->jacobian, child->work->jacobian_csc, - child->work->csc_work); + csr_to_csc_fill_values(child->jacobian, child->work->jacobian_csc, + child->work->csc_work); child->work->jacobian_csc_filled = true; } node->local_wsum_hess(node, node->work->dwork, w); - ATDA_fill_vals(child->work->jacobian_csc, node->work->dwork, - node->wsum_hess); + ATDA_fill_values(child->work->jacobian_csc, node->work->dwork, + node->wsum_hess); } else { /* refresh CSC jacobian values */ - csr_to_csc_fill_vals(child->jacobian, child->work->jacobian_csc, - child->work->csc_work); + csr_to_csc_fill_values(child->jacobian, child->work->jacobian_csc, + child->work->csc_work); /* term1: Jg^T @ D @ Jg */ node->local_wsum_hess(node, 
node->work->dwork, w); - ATDA_fill_vals(child->work->jacobian_csc, node->work->dwork, - node->work->hess_term1); + ATDA_fill_values(child->work->jacobian_csc, node->work->dwork, + node->work->hess_term1); /* term2: child Hessian with weight Jf^T w */ memcpy(node->work->dwork, node->work->local_jac_diag, @@ -155,8 +155,8 @@ void eval_wsum_hess_elementwise(expr *node, const double *w) child->wsum_hess->nnz * sizeof(double)); /* wsum_hess = term1 + term2 */ - sum_csr_fill_vals(node->work->hess_term1, node->work->hess_term2, - node->wsum_hess); + sum_csr_fill_values(node->work->hess_term1, node->work->hess_term2, + node->wsum_hess); } } } diff --git a/src/other/quad_form.c b/src/other/quad_form.c index 0311d22..78bd436 100644 --- a/src/other/quad_form.c +++ b/src/other/quad_form.c @@ -88,8 +88,8 @@ static void eval_jacobian(expr *node) if (!x->work->jacobian_csc_filled) { - csr_to_csc_fill_vals(x->jacobian, x->work->jacobian_csc, - x->work->csc_work); + csr_to_csc_fill_values(x->jacobian, x->work->jacobian_csc, + x->work->csc_work); if (x->is_affine(x)) { @@ -99,7 +99,7 @@ static void eval_jacobian(expr *node) /* The jacobian has same values as the gradient, which is J_f^T (Q @ f(x)). 
Here, dwork stores Q @ f(x) from forward */ - yTA_fill_vals(x->work->jacobian_csc, node->work->dwork, node->jacobian); + yTA_fill_values(x->work->jacobian_csc, node->work->dwork, node->jacobian); cblas_dscal(node->jacobian->nnz, 2.0, node->jacobian->x, 1); } @@ -180,7 +180,7 @@ static void eval_wsum_hess(expr *node, const double *w) CSC_Matrix *Jf = x->work->jacobian_csc; if (!x->work->jacobian_csc_filled) { - csr_to_csc_fill_vals(x->jacobian, Jf, x->work->csc_work); + csr_to_csc_fill_values(x->jacobian, Jf, x->work->csc_work); if (x->is_affine(x)) { @@ -193,8 +193,8 @@ static void eval_wsum_hess(expr *node, const double *w) CSR_Matrix *term2 = node->work->hess_term2; /* term1 = J_f^T Q J_f = J_f^T B */ - BA_fill_vals(Q, Jf, QJf); - BTDA_fill_vals(Jf, QJf, NULL, term1); + BA_fill_values(Q, Jf, QJf); + BTDA_fill_values(Jf, QJf, NULL, term1); /* term2 */ x->eval_wsum_hess(x, node->work->dwork); @@ -205,7 +205,7 @@ static void eval_wsum_hess(expr *node, const double *w) cblas_dscal(term2->nnz, two_w, term2->x, 1); /* sum the two terms */ - sum_csr_fill_vals(term1, term2, node->wsum_hess); + sum_csr_fill_values(term1, term2, node->wsum_hess); } } diff --git a/src/utils/CSC_Matrix.c b/src/utils/CSC_Matrix.c index bdb7f59..9d4e078 100644 --- a/src/utils/CSC_Matrix.c +++ b/src/utils/CSC_Matrix.c @@ -168,7 +168,7 @@ static inline double sparse_wdot(const double *a_x, const int *a_i, int a_nnz, return sum; } -void ATDA_fill_vals(const CSC_Matrix *A, const double *d, CSR_Matrix *C) +void ATDA_fill_values(const CSC_Matrix *A, const double *d, CSR_Matrix *C) { int j, ii, jj; for (ii = 0; ii < C->m; ii++) @@ -246,7 +246,7 @@ CSC_Matrix *csr_to_csc_alloc(const CSR_Matrix *A, int *iwork) return C; } -void csr_to_csc_fill_vals(const CSR_Matrix *A, CSC_Matrix *C, int *iwork) +void csr_to_csc_fill_values(const CSR_Matrix *A, CSC_Matrix *C, int *iwork) { int i, j; int *count = iwork; @@ -311,7 +311,7 @@ CSR_Matrix *csc_to_csr_alloc(const CSC_Matrix *A, int *iwork) return C; } -void 
csc_to_csr_fill_vals(const CSC_Matrix *A, CSR_Matrix *C, int *iwork) +void csc_to_csr_fill_values(const CSC_Matrix *A, CSR_Matrix *C, int *iwork) { int i, j; int *count = iwork; @@ -388,7 +388,7 @@ CSR_Matrix *BTA_alloc(const CSC_Matrix *A, const CSC_Matrix *B) return C; } -void yTA_fill_vals(const CSC_Matrix *A, const double *y, CSR_Matrix *C) +void yTA_fill_values(const CSC_Matrix *A, const double *y, CSR_Matrix *C) { for (int col = 0; col < A->n; col++) { @@ -413,8 +413,8 @@ void yTA_fill_vals(const CSC_Matrix *A, const double *y, CSR_Matrix *C) } /* computes C = B^T * D * A in CSR */ -void BTDA_fill_vals(const CSC_Matrix *A, const CSC_Matrix *B, const double *d, - CSR_Matrix *C) +void BTDA_fill_values(const CSC_Matrix *A, const CSC_Matrix *B, const double *d, + CSR_Matrix *C) { int i, j, jj; for (i = 0; i < C->m; i++) @@ -445,7 +445,7 @@ void BTDA_fill_vals(const CSC_Matrix *A, const CSC_Matrix *B, const double *d, * faster when Q is dense, since it touches each Q entry exactly once. * The sparse_dot approach below is simpler but redundantly scans * column j of A for each nonzero row of C. */ -void BA_fill_vals(const CSR_Matrix *Q, const CSC_Matrix *A, CSC_Matrix *C) +void BA_fill_values(const CSR_Matrix *Q, const CSC_Matrix *A, CSC_Matrix *C) { /* fill values of C = Q * A, given the sparsity pattern of C. 
*/ int i, j, ii; diff --git a/src/utils/CSR_Matrix.c b/src/utils/CSR_Matrix.c index e09a146..b175ea1 100644 --- a/src/utils/CSR_Matrix.c +++ b/src/utils/CSR_Matrix.c @@ -215,7 +215,7 @@ void diag_csr_mult(const double *d, const CSR_Matrix *A, CSR_Matrix *C) } } -void diag_csr_mult_fill_vals(const double *d, const CSR_Matrix *A, CSR_Matrix *C) +void diag_csr_mult_fill_values(const double *d, const CSR_Matrix *A, CSR_Matrix *C) { memcpy(C->x, A->x, A->nnz * sizeof(double)); @@ -346,7 +346,7 @@ CSR_Matrix *AT_alloc(const CSR_Matrix *A, int *iwork) return AT; } -void AT_fill_vals(const CSR_Matrix *A, CSR_Matrix *AT, int *iwork) +void AT_fill_values(const CSR_Matrix *A, CSR_Matrix *AT, int *iwork) { /* Fill values of A^T given sparsity pattern is already computed */ int i, j; @@ -365,7 +365,7 @@ void AT_fill_vals(const CSR_Matrix *A, CSR_Matrix *AT, int *iwork) } /**/ -void csr_matvec_fill_vals(const CSR_Matrix *AT, const double *z, CSR_Matrix *C) +void csr_matvec_fill_values(const CSR_Matrix *AT, const double *z, CSR_Matrix *C) { int A_ncols = AT->m; diff --git a/src/utils/CSR_sum.c b/src/utils/CSR_sum.c index b6cb59d..1f30906 100644 --- a/src/utils/CSR_sum.c +++ b/src/utils/CSR_sum.c @@ -142,7 +142,7 @@ void sum_csr_alloc(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C) C->p[A->m] = C->nnz; } -void sum_csr_fill_vals(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C) +void sum_csr_fill_values(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C) { /* Assumes C->p and C->i already contain the sparsity pattern of A+B. Fills only C->x accordingly. 
*/ @@ -174,9 +174,9 @@ void sum_csr_fill_vals(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C) } } -void sum_scaled_csr_matrices_fill_vals(const CSR_Matrix *A, const CSR_Matrix *B, - CSR_Matrix *C, const double *d1, - const double *d2) +void sum_scaled_csr_matrices_fill_values(const CSR_Matrix *A, const CSR_Matrix *B, + CSR_Matrix *C, const double *d1, + const double *d2) { /* Assumes C->p and C->i already contain the sparsity pattern of A+B. Fills only C->x accordingly with scaling. */ @@ -437,7 +437,7 @@ void sum_block_of_rows_csr_fill_sparsity_and_idx_map(const CSR_Matrix *A, } /* -void sum_block_of_rows_csr_fill_vals(const CSR_Matrix *A, CSR_Matrix *C, +void sum_block_of_rows_csr_fill_values(const CSR_Matrix *A, CSR_Matrix *C, const int *idx_map) { memset(C->x, 0, C->nnz * sizeof(double)); @@ -683,7 +683,7 @@ void sum_all_rows_csr_fill_sparsity_and_idx_map(const CSR_Matrix *A, CSR_Matrix } /* -void sum_all_rows_csr_fill_vals(const CSR_Matrix *A, CSR_Matrix *C, +void sum_all_rows_csr_fill_values(const CSR_Matrix *A, CSR_Matrix *C, const int *idx_map) { memset(C->x, 0, C->nnz * sizeof(double)); diff --git a/src/utils/dense_matrix.c b/src/utils/dense_matrix.c index bdbe0c0..8a2dd2c 100644 --- a/src/utils/dense_matrix.c +++ b/src/utils/dense_matrix.c @@ -56,7 +56,7 @@ Matrix *new_dense_matrix(int m, int n, const double *data) dm->base.n = n; dm->base.block_left_mult_vec = dense_block_left_mult_vec; dm->base.block_left_mult_sparsity = I_kron_A_alloc; - dm->base.block_left_mult_values = I_kron_A_fill_vals; + dm->base.block_left_mult_values = I_kron_A_fill_values; dm->base.free_fn = dense_free; dm->x = (double *) malloc(m * n * sizeof(double)); memcpy(dm->x, data, m * n * sizeof(double)); diff --git a/src/utils/linalg_dense_sparse_matmuls.c b/src/utils/linalg_dense_sparse_matmuls.c index 3c92bad..a073349 100644 --- a/src/utils/linalg_dense_sparse_matmuls.c +++ b/src/utils/linalg_dense_sparse_matmuls.c @@ -90,7 +90,7 @@ CSC_Matrix *I_kron_A_alloc(const Matrix 
*A, const CSC_Matrix *J, int p) return C; } -void I_kron_A_fill_vals(const Matrix *A, const CSC_Matrix *J, CSC_Matrix *C) +void I_kron_A_fill_values(const Matrix *A, const CSC_Matrix *J, CSC_Matrix *C) { const Dense_Matrix *dm = (const Dense_Matrix *) A; int m = dm->base.m; @@ -218,8 +218,8 @@ CSR_Matrix *YT_kron_I_alloc(int m, int k, int n, const CSC_Matrix *J) return C; } -void YT_kron_I_fill_vals(int m, int k, int n, const double *Y, const CSC_Matrix *J, - CSR_Matrix *C) +void YT_kron_I_fill_values(int m, int k, int n, const double *Y, const CSC_Matrix *J, + CSR_Matrix *C) { (void) n; assert(C->m == m * n); @@ -307,8 +307,8 @@ CSR_Matrix *I_kron_X_alloc(int m, int k, int n, const CSC_Matrix *J) return C; } -void I_kron_X_fill_vals(int m, int k, int n, const double *X, const CSC_Matrix *J, - CSR_Matrix *C) +void I_kron_X_fill_values(int m, int k, int n, const double *X, const CSC_Matrix *J, + CSR_Matrix *C) { (void) n; assert(C->m == m * n); diff --git a/src/utils/linalg_sparse_matmuls.c b/src/utils/linalg_sparse_matmuls.c index 9be3499..9a7b164 100644 --- a/src/utils/linalg_sparse_matmuls.c +++ b/src/utils/linalg_sparse_matmuls.c @@ -183,8 +183,8 @@ CSC_Matrix *block_left_multiply_fill_sparsity(const CSR_Matrix *A, return C; } -void block_left_multiply_fill_vals(const CSR_Matrix *A, const CSC_Matrix *J, - CSC_Matrix *C) +void block_left_multiply_fill_values(const CSR_Matrix *A, const CSC_Matrix *J, + CSC_Matrix *C) { /* A is m x n, J is (n*p) x k, C is (m*p) x k */ int m = A->m; @@ -246,8 +246,8 @@ void block_left_multiply_fill_vals(const CSR_Matrix *A, const CSC_Matrix *J, } /* Fill values of C = A @ B where A is CSR, B is CSC. 
*/ -void csr_csc_matmul_fill_vals(const CSR_Matrix *A, const CSC_Matrix *B, - CSR_Matrix *C) +void csr_csc_matmul_fill_values(const CSR_Matrix *A, const CSC_Matrix *B, + CSR_Matrix *C) { for (int i = 0; i < A->m; i++) { diff --git a/src/utils/sparse_matrix.c b/src/utils/sparse_matrix.c index 6d9b365..24ed539 100644 --- a/src/utils/sparse_matrix.c +++ b/src/utils/sparse_matrix.c @@ -37,7 +37,7 @@ static void sparse_block_left_mult_values(const Matrix *self, const CSC_Matrix * CSC_Matrix *C) { const Sparse_Matrix *sm = (const Sparse_Matrix *) self; - block_left_multiply_fill_vals(sm->csr, J, C); + block_left_multiply_fill_values(sm->csr, J, C); } static void sparse_free(Matrix *self) diff --git a/tests/utils/test_csc_matrix.h b/tests/utils/test_csc_matrix.h index 312ce66..5459021 100644 --- a/tests/utils/test_csc_matrix.h +++ b/tests/utils/test_csc_matrix.h @@ -101,7 +101,7 @@ const char *test_ATA_alloc_random(void) double d[10] = {2, 8, 6, 2, 5, 1, 6, 9, 1, 3}; - ATDA_fill_vals(A, d, C); + ATDA_fill_values(A, d, C); double Cx_correct[38] = { 49., 21., 491., 56., 240., 416., 144., 288., 56., 98., 56., 21., 9., @@ -139,7 +139,7 @@ const char *test_ATA_alloc_random2(void) double d[15] = {-0.6, -0.23, -0.29, -1.36, 0.4, 0.36, 0.11, -0.13, -1.32, -0.32, -0.24, -0.7, -0.06, 0.5, 1.99}; - ATDA_fill_vals(A, d, C); + ATDA_fill_values(A, d, C); double Cx_correct[17] = {-0.362232, -0.189896, 0.06656, -0.228888, -0.025732, -0.016146, 0.032857, 0.06656, -1.004802, 0.1505, @@ -198,7 +198,7 @@ const char *test_BTA_alloc_and_BTDA_fill(void) /* Fill values with diagonal weights d */ double d[4] = {1.0, 2.0, 3.0, 4.0}; - BTDA_fill_vals(A, B, d, C); + BTDA_fill_values(A, B, d, C); double expected_x[3] = {37.0, 47.0, 108.0}; mu_assert("C values incorrect", cmp_double_array(C->x, expected_x, 3)); diff --git a/tests/utils/test_csr_csc_conversion.h b/tests/utils/test_csr_csc_conversion.h index 5070e12..efbdc9e 100644 --- a/tests/utils/test_csr_csc_conversion.h +++ 
b/tests/utils/test_csr_csc_conversion.h @@ -8,7 +8,7 @@ #include "utils/CSC_Matrix.h" #include "utils/CSR_Matrix.h" -/* Test CSR to CSC conversion with fill_sparsity and fill_vals */ +/* Test CSR to CSC conversion with fill_sparsity and fill_values */ const char *test_csr_to_csc_split(void) { /* Create a 4x5 CSR matrix A: @@ -39,7 +39,7 @@ const char *test_csr_to_csc_split(void) mu_assert("C row indices incorrect", cmp_int_array(C->i, Ci_correct, 5)); /* Now fill values */ - csr_to_csc_fill_vals(A, C, iwork); + csr_to_csc_fill_values(A, C, iwork); /* Check values */ double Cx_correct[5] = {1.0, 2.0, 3.0, 4.0, 1.0}; @@ -97,7 +97,7 @@ const char *test_csc_to_csr_sparsity(void) return 0; } -/* Test CSC to CSR conversion with fill_vals */ +/* Test CSC to CSR conversion with fill_values */ const char *test_csc_to_csr_values(void) { /* Create a 4x5 CSC matrix A */ @@ -116,7 +116,7 @@ const char *test_csc_to_csr_values(void) CSR_Matrix *C = csc_to_csr_alloc(A, iwork); /* Fill values */ - csc_to_csr_fill_vals(A, C, iwork); + csc_to_csr_fill_values(A, C, iwork); /* Check values */ double Cx_correct[5] = {1.0, 2.0, 3.0, 4.0, 5.0}; @@ -149,12 +149,12 @@ const char *test_csr_csc_csr_roundtrip(void) /* Convert CSR to CSC */ int *iwork_csc = (int *) malloc(A->n * sizeof(int)); CSC_Matrix *B = csr_to_csc_alloc(A, iwork_csc); - csr_to_csc_fill_vals(A, B, iwork_csc); + csr_to_csc_fill_values(A, B, iwork_csc); /* Convert CSC back to CSR */ int *iwork_csr = (int *) malloc(B->m * sizeof(int)); CSR_Matrix *C = csc_to_csr_alloc(B, iwork_csr); - csc_to_csr_fill_vals(B, C, iwork_csr); + csc_to_csr_fill_values(B, C, iwork_csr); /* C should match A */ mu_assert("Round-trip: vals incorrect", cmp_double_array(C->x, Ax, 8)); diff --git a/tests/utils/test_csr_matrix.h b/tests/utils/test_csr_matrix.h index 410da18..f6c9536 100644 --- a/tests/utils/test_csr_matrix.h +++ b/tests/utils/test_csr_matrix.h @@ -187,7 +187,7 @@ const char *test_csr_vecmat_values_sparse(void) CSR_Matrix *AT = 
transpose(A, iwork); - csr_matvec_fill_vals(AT, z, C); + csr_matvec_fill_values(AT, z, C); double Cx_correct[3] = {7.0, 22.0, 1.0}; @@ -412,7 +412,7 @@ const char *test_AT_alloc_and_fill(void) CSR_Matrix *AT = AT_alloc(A, iwork); /* Fill values of A^T */ - AT_fill_vals(A, AT, iwork); + AT_fill_values(A, AT, iwork); /* Expected A^T: * [1.0 0.0 5.0] diff --git a/tests/utils/test_linalg_sparse_matmuls.h b/tests/utils/test_linalg_sparse_matmuls.h index 2f8c639..120f99b 100644 --- a/tests/utils/test_linalg_sparse_matmuls.h +++ b/tests/utils/test_linalg_sparse_matmuls.h @@ -110,7 +110,7 @@ const char *test_block_left_multiply_two_blocks(void) * [0.0 1.0 1.0] */ CSC_Matrix *C = block_left_multiply_fill_sparsity(A, J, 2); - block_left_multiply_fill_vals(A, J, C); + block_left_multiply_fill_values(A, J, C); int expected_p2[4] = {0, 1, 2, 3}; int expected_i2[3] = {0, 2, 3}; diff --git a/tests/utils/test_linalg_utils_matmul_chain_rule.h b/tests/utils/test_linalg_utils_matmul_chain_rule.h index 8724aa3..0c66589 100644 --- a/tests/utils/test_linalg_utils_matmul_chain_rule.h +++ b/tests/utils/test_linalg_utils_matmul_chain_rule.h @@ -8,7 +8,7 @@ #include "utils/CSR_Matrix.h" #include "utils/linalg_dense_sparse_matmuls.h" -/* Test YT_kron_I_alloc and YT_kron_I_fill_vals +/* Test YT_kron_I_alloc and YT_kron_I_fill_values * * C = (Y^T kron I_m) @ J * m=2, k=2, n=2, p=3 @@ -57,7 +57,7 @@ const char *test_YT_kron_I(void) mu_assert("C row ptrs", cmp_int_array(C->p, exp_p, 5)); mu_assert("C col indices", cmp_int_array(C->i, exp_i, 8)); - YT_kron_I_fill_vals(m, k, n, Y, J, C); + YT_kron_I_fill_values(m, k, n, Y, J, C); mu_assert("C values", cmp_double_array(C->x, exp_x, 8)); free_csr_matrix(C); @@ -110,7 +110,7 @@ const char *test_YT_kron_I_larger(void) mu_assert("C2 row ptrs", cmp_int_array(C->p, exp_p, 10)); mu_assert("C2 col indices", cmp_int_array(C->i, exp_i, 18)); - YT_kron_I_fill_vals(m, k, n, Y, J, C); + YT_kron_I_fill_values(m, k, n, Y, J, C); mu_assert("C2 values", 
cmp_double_array(C->x, exp_x, 18)); free_csr_matrix(C); @@ -118,7 +118,7 @@ const char *test_YT_kron_I_larger(void) return NULL; } -/* Test I_kron_X_alloc and I_kron_X_fill_vals +/* Test I_kron_X_alloc and I_kron_X_fill_values * * C = (I_n kron X) @ J * m=2, k=2, n=2, p=3 @@ -167,7 +167,7 @@ const char *test_I_kron_X(void) mu_assert("C row ptrs", cmp_int_array(C->p, exp_p, 5)); mu_assert("C col indices", cmp_int_array(C->i, exp_i, 10)); - I_kron_X_fill_vals(m, k, n, X, J, C); + I_kron_X_fill_values(m, k, n, X, J, C); mu_assert("C values", cmp_double_array(C->x, exp_x, 10)); free_csr_matrix(C); @@ -219,7 +219,7 @@ const char *test_I_kron_X_larger(void) mu_assert("C2 row ptrs", cmp_int_array(C->p, exp_p, 7)); mu_assert("C2 col indices", cmp_int_array(C->i, exp_i, 21)); - I_kron_X_fill_vals(m, k, n, X, J, C); + I_kron_X_fill_values(m, k, n, X, J, C); mu_assert("C2 values", cmp_double_array(C->x, exp_x, 21)); free_csr_matrix(C); From 107c5596dc390a44e915c0fb2f76c91de3109cd3 Mon Sep 17 00:00:00 2001 From: dance858 Date: Mon, 30 Mar 2026 12:49:10 -0700 Subject: [PATCH 13/13] run formaterre --- include/utils/CSC_Matrix.h | 2 +- include/utils/CSR_sum.h | 4 ++-- include/utils/linalg_dense_sparse_matmuls.h | 4 ++-- include/utils/linalg_sparse_matmuls.h | 7 +++---- 4 files changed, 8 insertions(+), 9 deletions(-) diff --git a/include/utils/CSC_Matrix.h b/include/utils/CSC_Matrix.h index def4a86..951b088 100644 --- a/include/utils/CSC_Matrix.h +++ b/include/utils/CSC_Matrix.h @@ -43,7 +43,7 @@ void ATDA_fill_values(const CSC_Matrix *A, const double *d, CSR_Matrix *C); /* Compute values for C = B^T D A (null d corresonds to D as identity) */ void BTDA_fill_values(const CSC_Matrix *A, const CSC_Matrix *B, const double *d, - CSR_Matrix *C); + CSR_Matrix *C); /* Fill values of C = BA. 
The matrix B does not have to be symmetric */ void BA_fill_values(const CSR_Matrix *B, const CSC_Matrix *A, CSC_Matrix *C); diff --git a/include/utils/CSR_sum.h b/include/utils/CSR_sum.h index abe6c84..063b024 100644 --- a/include/utils/CSR_sum.h +++ b/include/utils/CSR_sum.h @@ -29,8 +29,8 @@ void sum_scaled_csr_matrices(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matri * pattern (p and i) is already filled and matches the union of A and B per row. * Does not modify C->p, C->i, or C->nnz. */ void sum_scaled_csr_matrices_fill_values(const CSR_Matrix *A, const CSR_Matrix *B, - CSR_Matrix *C, const double *d1, - const double *d2); + CSR_Matrix *C, const double *d1, + const double *d2); /* Sum all rows of A into a single row matrix C */ void sum_all_rows_csr(const CSR_Matrix *A, CSR_Matrix *C, diff --git a/include/utils/linalg_dense_sparse_matmuls.h b/include/utils/linalg_dense_sparse_matmuls.h index 5057424..8404940 100644 --- a/include/utils/linalg_dense_sparse_matmuls.h +++ b/include/utils/linalg_dense_sparse_matmuls.h @@ -15,12 +15,12 @@ void I_kron_A_fill_values(const Matrix *A, const CSC_Matrix *J, CSC_Matrix *C); and C is (m*n) x p. Y is given in column-major dense format. */ CSR_Matrix *YT_kron_I_alloc(int m, int k, int n, const CSC_Matrix *J); void YT_kron_I_fill_values(int m, int k, int n, const double *Y, const CSC_Matrix *J, - CSR_Matrix *C); + CSR_Matrix *C); /* Sparsity and values of C = (I_n kron X) @ J where X is m x k (col-major dense), J is (k*n) x p, and C is (m*n) x p. 
*/ CSR_Matrix *I_kron_X_alloc(int m, int k, int n, const CSC_Matrix *J); void I_kron_X_fill_values(int m, int k, int n, const double *X, const CSC_Matrix *J, - CSR_Matrix *C); + CSR_Matrix *C); #endif /* LINALG_DENSE_SPARSE_H */ diff --git a/include/utils/linalg_sparse_matmuls.h b/include/utils/linalg_sparse_matmuls.h index daa6c75..7a40c92 100644 --- a/include/utils/linalg_sparse_matmuls.h +++ b/include/utils/linalg_sparse_matmuls.h @@ -18,21 +18,20 @@ CSC_Matrix *block_left_multiply_fill_sparsity(const CSR_Matrix *A, const CSC_Matrix *J, int p); void block_left_multiply_fill_values(const CSR_Matrix *A, const CSC_Matrix *J, - CSC_Matrix *C); + CSC_Matrix *C); /* Compute y = kron(I_p, A) @ x where A is m x n and x is(n*p)-length vector. The output y is m*p-length vector corresponding to y = [A @ x1; A @ x2; ...; A @ xp] where x is divided into p blocks of n elements. */ -void block_left_multiply_vec(const CSR_Matrix *A, const double *x, double *y, - int p); +void block_left_multiply_vec(const CSR_Matrix *A, const double *x, double *y, int p); /* Fill values of C = A @ B where A is CSR, B is CSC. * C must have sparsity pattern already computed. */ void csr_csc_matmul_fill_values(const CSR_Matrix *A, const CSC_Matrix *B, - CSR_Matrix *C); + CSR_Matrix *C); /* C = A @ B where A is CSR, B is CSC. Result C is CSR. * Allocates and precomputes sparsity pattern. No workspace required.