diff --git a/include/subexpr.h b/include/subexpr.h index c87da07..44a964b 100644 --- a/include/subexpr.h +++ b/include/subexpr.h @@ -131,6 +131,28 @@ typedef struct right_matmul_expr CSC_Matrix *CSC_work; } right_matmul_expr; +/* Bivariate matrix multiplication: Z = f(u) @ g(u) where both children + * may be composite expressions. */ +typedef struct matmul_expr +{ + expr base; + /* Jacobian workspace */ + CSR_Matrix *term1_CSR; /* (Y^T x I_m) @ J_f */ + CSR_Matrix *term2_CSR; /* (I_n x X) @ J_g */ + + /* Hessian workspace (composite only) */ + CSR_Matrix *B; /* cross-Hessian B(w), mk x kn */ + CSR_Matrix *BJg; /* B @ J_g */ + CSC_Matrix *BJg_CSC; /* BJg in CSC */ + int *BJg_csc_work; /* CSR-to-CSC workspace */ + CSR_Matrix *C; /* J_f^T @ B @ J_g */ + CSR_Matrix *CT; /* C^T */ + int *idx_map_C; + int *idx_map_CT; + int *idx_map_Hf; + int *idx_map_Hg; +} matmul_expr; + /* Constant scalar multiplication: y = a * child where a is a constant double */ typedef struct const_scalar_mult_expr { diff --git a/include/utils/CSR_sum.h b/include/utils/CSR_sum.h index c04c7a8..063b024 100644 --- a/include/utils/CSR_sum.h +++ b/include/utils/CSR_sum.h @@ -14,14 +14,12 @@ void sum_csr_matrices(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C); /* Compute sparsity pattern of A + B where A, B, C are CSR matrices. * Fills C->p, C->i, and C->nnz; does not touch C->x. */ -void sum_csr_matrices_fill_sparsity(const CSR_Matrix *A, const CSR_Matrix *B, - CSR_Matrix *C); +void sum_csr_alloc(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C); /* Fill only the values of C = A + B, assuming C's sparsity pattern (p and i) * is already filled and matches the union of A and B per row. Does not modify * C->p, C->i, or C->nnz. */ -void sum_csr_matrices_fill_values(const CSR_Matrix *A, const CSR_Matrix *B, - CSR_Matrix *C); +void sum_csr_fill_values(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C); /* Compute C = diag(d1) * A + diag(d2) * B where A, B, C are CSR matrices */ void sum_scaled_csr_matrices(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C, diff --git a/include/utils/dense_matrix.h b/include/utils/dense_matrix.h new file mode 100644 index 0000000..8121020 --- /dev/null +++ b/include/utils/dense_matrix.h @@ -0,0 +1,20 @@ +#ifndef DENSE_MATRIX_H +#define DENSE_MATRIX_H + +#include "matrix.h" + +/* Dense matrix (row-major) */ +typedef struct Dense_Matrix +{ + Matrix base; + double *x; + double *work; /* scratch buffer, length n */ +} Dense_Matrix; + +/* Constructors */ +Matrix *new_dense_matrix(int m, int n, const double *data); + +/* Transpose helper */ +Matrix *dense_matrix_trans(const Dense_Matrix *self); + +#endif /* DENSE_MATRIX_H */ diff --git a/include/utils/linalg_dense_sparse_matmuls.h b/include/utils/linalg_dense_sparse_matmuls.h new file mode 100644 index 0000000..8404940 --- /dev/null +++ b/include/utils/linalg_dense_sparse_matmuls.h @@ -0,0 +1,26 @@ +#ifndef LINALG_DENSE_SPARSE_H +#define LINALG_DENSE_SPARSE_H + +#include "CSC_Matrix.h" +#include "CSR_Matrix.h" +#include "matrix.h" + +/* C = (I_p kron A) @ J via the polymorphic Matrix interface. + * A is dense m x n, J is (n*p) x k in CSC, C is (m*p) x k in CSC. */ +// TODO: maybe we can replace these with I_kron_X functionality? +CSC_Matrix *I_kron_A_alloc(const Matrix *A, const CSC_Matrix *J, int p); +void I_kron_A_fill_values(const Matrix *A, const CSC_Matrix *J, CSC_Matrix *C); + +/* Sparsity and values of C = (Y^T kron I_m) @ J where Y is k x n, J is (m*k) x p, + and C is (m*n) x p. Y is given in column-major dense format. */ +CSR_Matrix *YT_kron_I_alloc(int m, int k, int n, const CSC_Matrix *J); +void YT_kron_I_fill_values(int m, int k, int n, const double *Y, const CSC_Matrix *J, + CSR_Matrix *C); + +/* Sparsity and values of C = (I_n kron X) @ J where X is m x k (col-major dense), + J is (k*n) x p, and C is (m*n) x p. */ +CSR_Matrix *I_kron_X_alloc(int m, int k, int n, const CSC_Matrix *J); +void I_kron_X_fill_values(int m, int k, int n, const double *X, const CSC_Matrix *J, + CSR_Matrix *C); + +#endif /* LINALG_DENSE_SPARSE_H */ diff --git a/include/utils/linalg_sparse_matmuls.h b/include/utils/linalg_sparse_matmuls.h index 25890a9..7a40c92 100644 --- a/include/utils/linalg_sparse_matmuls.h +++ b/include/utils/linalg_sparse_matmuls.h @@ -1,9 +1,8 @@ #ifndef LINALG_H #define LINALG_H -/* Forward declarations */ -struct CSR_Matrix; -struct CSC_Matrix; +#include "CSC_Matrix.h" +#include "CSR_Matrix.h" /* Compute sparsity pattern and values for the matrix-matrix multiplication C = (I_p kron A) @ J where A is m x n, J is (n*p) x k, and C is (m*p) x k, @@ -15,32 +14,28 @@ struct CSC_Matrix; * Mathematically it corresponds to C = [A @ J1; A @ J2; ...; A @ Jp], where J = [J1; J2; ...; Jp] */ -struct CSC_Matrix *block_left_multiply_fill_sparsity(const struct CSR_Matrix *A, - const struct CSC_Matrix *J, - int p); +CSC_Matrix *block_left_multiply_fill_sparsity(const CSR_Matrix *A, + const CSC_Matrix *J, int p); -void block_left_multiply_fill_values(const struct CSR_Matrix *A, - const struct CSC_Matrix *J, - struct CSC_Matrix *C); +void block_left_multiply_fill_values(const CSR_Matrix *A, const CSC_Matrix *J, + CSC_Matrix *C); /* Compute y = kron(I_p, A) @ x where A is m x n and x is(n*p)-length vector. The output y is m*p-length vector corresponding to y = [A @ x1; A @ x2; ...; A @ xp] where x is divided into p blocks of n elements. */ -void block_left_multiply_vec(const struct CSR_Matrix *A, const double *x, double *y, - int p); +void block_left_multiply_vec(const CSR_Matrix *A, const double *x, double *y, int p); /* Fill values of C = A @ B where A is CSR, B is CSC. * C must have sparsity pattern already computed. */ -void csr_csc_matmul_fill_values(const struct CSR_Matrix *A, - const struct CSC_Matrix *B, struct CSR_Matrix *C); +void csr_csc_matmul_fill_values(const CSR_Matrix *A, const CSC_Matrix *B, + CSR_Matrix *C); /* C = A @ B where A is CSR, B is CSC. Result C is CSR. * Allocates and precomputes sparsity pattern. No workspace required. */ -struct CSR_Matrix *csr_csc_matmul_alloc(const struct CSR_Matrix *A, - const struct CSC_Matrix *B); +CSR_Matrix *csr_csc_matmul_alloc(const CSR_Matrix *A, const CSC_Matrix *B); #endif /* LINALG_H */ diff --git a/include/utils/matrix.h b/include/utils/matrix.h index fe7db5f..478fabd 100644 --- a/include/utils/matrix.h +++ b/include/utils/matrix.h @@ -41,21 +41,11 @@ typedef struct Sparse_Matrix CSR_Matrix *csr; } Sparse_Matrix; -/* Dense matrix (row-major) */ -typedef struct Dense_Matrix -{ - Matrix base; - double *x; - double *work; /* scratch buffer, length n */ -} Dense_Matrix; - /* Constructors */ Matrix *new_sparse_matrix(const CSR_Matrix *A); -Matrix *new_dense_matrix(int m, int n, const double *data); -/* Transpose helpers */ +/* Transpose helper */ Matrix *sparse_matrix_trans(const Sparse_Matrix *self, int *iwork); -Matrix *dense_matrix_trans(const Dense_Matrix *self); /* Free helper */ static inline void free_matrix(Matrix *m) diff --git a/include/utils/mini_numpy.h b/include/utils/mini_numpy.h index b29c673..931f516 100644 --- a/include/utils/mini_numpy.h +++ b/include/utils/mini_numpy.h @@ -23,4 +23,14 @@ void scaled_ones(double *result, int size, double value); /* Naive implementation of Z = X @ Y, X is m x k, Y is k x n, Z is m x n */ void mat_mat_mult(const double *X, const double *Y, double *Z, int m, int k, int n); +/* Compute v = (Y kron I_m) @ w where Y is k x n (col-major), w has + length m*n, and v has length m*k. Equivalently, reshape w as the + m x n matrix W (col-major) and compute v = vec(W @ Y^T). */ +void Y_kron_I_vec(int m, int k, int n, const double *Y, const double *w, double *v); + +/* Compute v = (I_n kron X^T) @ w where X is m x k (col-major), w has + length m*n, and v has length k*n. Equivalently, reshape w as the + m x n matrix W (col-major) and compute v = vec(X^T @ W). */ +void I_kron_XT_vec(int m, int k, int n, const double *X, const double *w, double *v); + #endif /* MINI_NUMPY_H */ diff --git a/src/affine/add.c b/src/affine/add.c index f34160b..bd86da4 100644 --- a/src/affine/add.c +++ b/src/affine/add.c @@ -45,8 +45,7 @@ static void jacobian_init_impl(expr *node) node->jacobian = new_csr_matrix(node->size, node->n_vars, nnz_max); /* fill sparsity pattern */ - sum_csr_matrices_fill_sparsity(node->left->jacobian, node->right->jacobian, - node->jacobian); + sum_csr_alloc(node->left->jacobian, node->right->jacobian, node->jacobian); } static void eval_jacobian(expr *node) @@ -56,8 +55,7 @@ static void eval_jacobian(expr *node) node->right->eval_jacobian(node->right); /* sum children's jacobians */ - sum_csr_matrices_fill_values(node->left->jacobian, node->right->jacobian, - node->jacobian); + sum_csr_fill_values(node->left->jacobian, node->right->jacobian, node->jacobian); } static void wsum_hess_init_impl(expr *node) @@ -71,8 +69,7 @@ static void wsum_hess_init_impl(expr *node) node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, nnz_max); /* fill sparsity pattern of hessian */ - sum_csr_matrices_fill_sparsity(node->left->wsum_hess, node->right->wsum_hess, - node->wsum_hess); + sum_csr_alloc(node->left->wsum_hess, node->right->wsum_hess, node->wsum_hess); } static void eval_wsum_hess(expr *node, const double *w) @@ -82,8 +79,8 @@ static void eval_wsum_hess(expr *node, const double *w) node->right->eval_wsum_hess(node->right, w); /* sum children's wsum_hess */ - sum_csr_matrices_fill_values(node->left->wsum_hess, node->right->wsum_hess, - node->wsum_hess); + sum_csr_fill_values(node->left->wsum_hess, node->right->wsum_hess, + node->wsum_hess); } static bool is_affine(const expr *node) diff --git a/src/affine/hstack.c b/src/affine/hstack.c index 8636969..92f99ad 100644 --- a/src/affine/hstack.c +++ b/src/affine/hstack.c @@ -124,7 +124,7 @@ static void wsum_hess_init_impl(expr *node) { expr *child = hnode->args[i]; copy_csr_matrix(H, hnode->CSR_work); - sum_csr_matrices_fill_sparsity(hnode->CSR_work, child->wsum_hess, H); + sum_csr_alloc(hnode->CSR_work, child->wsum_hess, H); } } @@ -140,7 +140,7 @@ static void wsum_hess_eval(expr *node, const double *w) expr *child = hnode->args[i]; child->eval_wsum_hess(child, w + row_offset); copy_csr_matrix(H, hnode->CSR_work); - sum_csr_matrices_fill_values(hnode->CSR_work, child->wsum_hess, H); + sum_csr_fill_values(hnode->CSR_work, child->wsum_hess, H); row_offset += child->size; } } diff --git a/src/affine/left_matmul.c b/src/affine/left_matmul.c index 26ea172..4067e65 100644 --- a/src/affine/left_matmul.c +++ b/src/affine/left_matmul.c @@ -17,7 +17,7 @@ */ #include "affine.h" #include "subexpr.h" -#include "utils/matrix.h" +#include "utils/dense_matrix.h" #include #include #include diff --git a/src/bivariate_full_dom/matmul.c b/src/bivariate_full_dom/matmul.c index 97e6ba9..64b6d0b 100644 --- a/src/bivariate_full_dom/matmul.c +++ b/src/bivariate_full_dom/matmul.c @@ -17,12 +17,67 @@ */ #include "bivariate_full_dom.h" #include "subexpr.h" +#include "utils/CSC_Matrix.h" +#include "utils/CSR_Matrix.h" +#include "utils/CSR_sum.h" +#include "utils/linalg_dense_sparse_matmuls.h" +#include "utils/linalg_sparse_matmuls.h" #include "utils/mini_numpy.h" +#include "utils/utils.h" #include #include #include #include +// ------------------------------------------------------------------------------ +// Helpers for the cross-Hessian B(w) of the bilinear map X @ Y. +// B(w) is the mk x kn weighted cross-Hessian of the bilinear map Z = XY. +// It captures d^2(w^T vec(Z)) / d(vec(X)) d(vec(Y)). +// +// Entry: B[row + j*m, j + col*k] = w[row + col*m]. +// Each row has exactly n nonzeros. All k block-rows (one per j) have +// the same values (columns of W = reshape(w, m, n)), but at different +// column positions (offset by j in the Y-variable indexing). +// ------------------------------------------------------------------------------ + +static CSR_Matrix *build_cross_hessian_sparsity(int m, int k, int n) +{ + int total_nnz = m * k * n; + CSR_Matrix *B = new_csr_matrix(m * k, k * n, total_nnz); + int idx = 0; + + for (int j = 0; j < k; j++) + { + for (int row = 0; row < m; row++) + { + B->p[row + j * m] = idx; + for (int col = 0; col < n; col++) + { + B->i[idx++] = j + col * k; + } + } + } + B->p[m * k] = idx; + assert(idx == total_nnz); + return B; +} + +static void fill_cross_hessian_values(int m, int k, int n, const double *w, + CSR_Matrix *B) +{ + int idx = 0; + for (int j = 0; j < k; j++) + { + for (int row = 0; row < m; row++) + { + for (int col = 0; col < n; col++) + { + B->x[idx++] = w[row + col * m]; + } + } + } +} + // ------------------------------------------------------------------------------ // Implementation of matrix multiplication: Z = X @ Y // where X is m x k and Y is k x n, producing Z which is m x n @@ -42,58 +97,74 @@ static void forward(expr *node, const double *u) mat_mat_mult(x->value, y->value, node->value, x->d1, x->d2, y->d2); } +static void free_matmul_data(expr *node) +{ + matmul_expr *mnode = (matmul_expr *) node; + /* Jacobian workspace */ + free_csr_matrix(mnode->term1_CSR); + free_csr_matrix(mnode->term2_CSR); + /* Hessian workspace */ + free_csr_matrix(mnode->B); + free_csr_matrix(mnode->BJg); + free_csc_matrix(mnode->BJg_CSC); + free(mnode->BJg_csc_work); + free_csr_matrix(mnode->C); + free_csr_matrix(mnode->CT); + free(mnode->idx_map_C); + free(mnode->idx_map_CT); + free(mnode->idx_map_Hf); + free(mnode->idx_map_Hg); +} + static bool is_affine(const expr *node) { (void) node; return false; } -static void jacobian_init_impl(expr *node) +// -------------------------------------------------------------------------------- +// Jacobian initialization and evaluation for the case when no chain rule is needed, +// ie., when both children are different leaf variables. +// -------------------------------------------------------------------------------- +static void jacobian_init_no_chain_rule(expr *node) { + assert(node->left->var_id != NOT_A_VARIABLE && + node->right->var_id != NOT_A_VARIABLE && + node->left->var_id != node->right->var_id); + expr *x = node->left; expr *y = node->right; - - /* dimensions: X is m x k, Y is k x n, Z is m x n */ int m = x->d1; int k = x->d2; int n = y->d2; int nnz = m * n * 2 * k; node->jacobian = new_csr_matrix(node->size, node->n_vars, nnz); - /* fill sparsity pattern */ int nnz_idx = 0; for (int i = 0; i < node->size; i++) { - /* Convert flat index to (row, col) in Z */ int row = i % m; int col = i / m; node->jacobian->p[i] = nnz_idx; - /* X has lower var_id */ if (x->var_id < y->var_id) { - /* sparsity pattern of kron(YT, I) for this row */ for (int j = 0; j < k; j++) { node->jacobian->i[nnz_idx++] = x->var_id + row + j * m; } - - /* sparsity pattern of kron(I, X) for this row */ for (int j = 0; j < k; j++) { node->jacobian->i[nnz_idx++] = y->var_id + col * k + j; } } - else /* Y has lower var_id */ + else { - /* sparsity pattern of kron(I, X) for this row */ for (int j = 0; j < k; j++) { node->jacobian->i[nnz_idx++] = y->var_id + col * k + j; } - - /* sparsity pattern of kron(YT, I) for this row */ for (int j = 0; j < k; j++) { node->jacobian->i[nnz_idx++] = x->var_id + row + j * m; @@ -104,29 +175,23 @@ static void jacobian_init_impl(expr *node) assert(nnz_idx == nnz); } -static void eval_jacobian(expr *node) +static void eval_jacobian_no_chain_rule(expr *node) { expr *x = node->left; expr *y = node->right; - - /* dimensions: X is m x k, Y is k x n, Z is m x n */ int m = x->d1; int k = x->d2; double *Jx = node->jacobian->x; - /* fill values row-by-row */ for (int i = 0; i < node->size; i++) { - int row = i % m; /* row in Z */ - int col = i / m; /* col in Z */ + int row = i % m; + int col = i / m; int pos = node->jacobian->p[i]; if (x->var_id < y->var_id) { - /* contribution to this row from YT */ memcpy(Jx + pos, y->value + col * k, k * sizeof(double)); - - /* contribution to this row from X */ for (int j = 0; j < k; j++) { Jx[pos + k + j] = x->value[row + j * m]; @@ -134,24 +199,74 @@ static void eval_jacobian(expr *node) } else { - /* contribution to this row from X */ for (int j = 0; j < k; j++) { Jx[pos + j] = x->value[row + j * m]; } - - /* contribution to this row from YT */ memcpy(Jx + pos + k, y->value + col * k, k * sizeof(double)); } } } -static void wsum_hess_init_impl(expr *node) +// ------------------------------------------------------------------------------------ +// Jacobian initialization and evaluation for the case where chain rule is needed, +// ie., when at least one child is composite, or same variable. The jacobian of h(u) +// = f(u) @ g(u) where f is m x k and g is k x n, is given by Jh = (g^T kron I_m) +// Jf + (I_n kron f) Jg . */ +// ------------------------------------------------------------------------------------ +static void jacobian_init_chain_rule(expr *node) +{ + expr *f = node->left; + expr *g = node->right; + matmul_expr *mnode = (matmul_expr *) node; + int m = f->d1; + int k = f->d2; + int n = g->d2; + + /* initialize Jacobian of children */ + jacobian_init(f); + jacobian_init(g); + jacobian_csc_init(f); + jacobian_csc_init(g); + + /* initialize term1, term2, and their sum */ + mnode->term1_CSR = YT_kron_I_alloc(m, k, n, f->work->jacobian_csc); + mnode->term2_CSR = I_kron_X_alloc(m, k, n, g->work->jacobian_csc); + int max_nnz = mnode->term1_CSR->nnz + mnode->term2_CSR->nnz; + node->jacobian = new_csr_matrix(node->size, node->n_vars, max_nnz); + sum_csr_alloc(mnode->term1_CSR, mnode->term2_CSR, node->jacobian); +} + +static void eval_jacobian_chain_rule(expr *node) +{ + expr *f = node->left; + expr *g = node->right; + matmul_expr *mnode = (matmul_expr *) node; + int m = f->d1; + int k = f->d2; + int n = g->d2; + + /* evaluate Jacobians of children */ + f->eval_jacobian(f); + g->eval_jacobian(g); + csr_to_csc_fill_values(f->jacobian, f->work->jacobian_csc, f->work->csc_work); + csr_to_csc_fill_values(g->jacobian, g->work->jacobian_csc, g->work->csc_work); + + /* evaluate term1, term2, and their sum */ + YT_kron_I_fill_values(m, k, n, g->value, f->work->jacobian_csc, + mnode->term1_CSR); + I_kron_X_fill_values(m, k, n, f->value, g->work->jacobian_csc, mnode->term2_CSR); + sum_csr_fill_values(mnode->term1_CSR, mnode->term2_CSR, node->jacobian); +} + +// ------------------------------------------------------------------------------------ +// Hessian initialization and evaluation for the case where no chain rule is needed, +// ie., when both children are different leaf variables. +// ------------------------------------------------------------------------------------ +static void wsum_hess_init_no_chain_rule(expr *node) { expr *x = node->left; expr *y = node->right; - - /* dimensions: X is m x k, Y is k x n, Z is m x n */ int m = x->d1; int k = x->d2; int n = y->d2; @@ -164,7 +279,6 @@ static void wsum_hess_init_impl(expr *node) if (x->var_id < y->var_id) { - /* fill rows corresponding to x */ for (i = 0; i < x->size; i++) { Hp[x->var_id + i] = nnz; @@ -174,14 +288,10 @@ static void wsum_hess_init_impl(expr *node) Hi[nnz++] = start + col * k; } } - - /* fill rows between x and y */ for (i = x->var_id + x->size; i < y->var_id; i++) { Hp[i] = nnz; } - - /* fill rows corresponding to y */ for (i = 0; i < y->size; i++) { Hp[y->var_id + i] = nnz; @@ -191,8 +301,6 @@ static void wsum_hess_init_impl(expr *node) Hi[nnz++] = start + row; } } - - /* fill rows after y */ for (i = y->var_id + y->size; i <= node->n_vars; i++) { Hp[i] = nnz; @@ -200,8 +308,6 @@ static void wsum_hess_init_impl(expr *node) } else { - /* Y has lower var_id than X */ - /* fill rows corresponding to y */ for (i = 0; i < y->size; i++) { Hp[y->var_id + i] = nnz; @@ -211,14 +317,10 @@ static void wsum_hess_init_impl(expr *node) Hi[nnz++] = start + row; } } - - /* fill rows between y and x */ for (i = y->var_id + y->size; i < x->var_id; i++) { Hp[i] = nnz; } - - /* fill rows corresponding to x */ for (i = 0; i < x->size; i++) { Hp[x->var_id + i] = nnz; @@ -228,34 +330,28 @@ static void wsum_hess_init_impl(expr *node) Hi[nnz++] = start + col * k; } } - - /* fill rows after x */ for (i = x->var_id + x->size; i <= node->n_vars; i++) { Hp[i] = nnz; } } - Hp[node->n_vars] = nnz; assert(nnz == total_nnz); } -static void eval_wsum_hess(expr *node, const double *w) +static void eval_wsum_hess_no_chain_rule(expr *node, const double *w) { expr *x = node->left; expr *y = node->right; - int m = x->d1; int k = x->d2; int n = y->d2; int offset = 0; - double *Hx = node->wsum_hess->x; const double *w_temp; if (x->var_id < y->var_id) { - /* rows corresponding to x */ for (int k_idx = 0; k_idx < k; k_idx++) { for (int row = 0; row < m; row++) @@ -266,8 +362,6 @@ static void eval_wsum_hess(expr *node, const double *w) } } } - - /* rows corresponding to y */ for (int col = 0; col < n; col++) { w_temp = w + col * m; @@ -280,7 +374,6 @@ static void eval_wsum_hess(expr *node, const double *w) } else { - /* rows corresponding to y */ for (int col = 0; col < n; col++) { w_temp = w + col * m; @@ -290,8 +383,6 @@ static void eval_wsum_hess(expr *node, const double *w) offset += m; } } - - /* rows corresponding to x */ for (int k_idx = 0; k_idx < k; k_idx++) { for (int row = 0; row < m; row++) @@ -305,6 +396,123 @@ static void eval_wsum_hess(expr *node, const double *w) } } +// ------------------------------------------------------------------------------------ +// Hessian chain rule for Z = f(u) @ g(u). +// H = C + C^T + H_f(v_f) + H_g(v_g) where: +// +// C = J_f^T B(w) J_g cross term (B is the weighted cross-Hessian) +// v_f = (Y kron I_m) w backpropagated weights for left child +// v_g = (I_n kron X^T) w backpropagated weights for right child +// H_f(v_f), H_g(v_g) child Hessians evaluated with transformed weights +// ------------------------------------------------------------------------------------ +static void wsum_hess_init_chain_rule(expr *node) +{ + expr *f = node->left; + expr *g = node->right; + matmul_expr *mnode = (matmul_expr *) node; + int m = f->d1; + int k = f->d2; + int n = g->d2; + CSC_Matrix *Jf = f->work->jacobian_csc; + CSC_Matrix *Jg = g->work->jacobian_csc; + + /* initialize C = Jf^T @ B @ Jg = Jf^T @ (B @ Jg) */ + mnode->B = build_cross_hessian_sparsity(m, k, n); + mnode->BJg = csr_csc_matmul_alloc(mnode->B, Jg); + int max_alloc = MAX(mnode->BJg->m, mnode->BJg->n); + mnode->BJg_csc_work = (int *) malloc(max_alloc * sizeof(int)); + mnode->BJg_CSC = csr_to_csc_alloc(mnode->BJg, mnode->BJg_csc_work); + mnode->C = BTA_alloc(mnode->BJg_CSC, Jf); + + /* initialize C^T */ + node->work->iwork = (int *) malloc(mnode->C->m * sizeof(int)); + mnode->CT = AT_alloc(mnode->C, node->work->iwork); + + /* initialize Hessians of children */ + wsum_hess_init(f); + wsum_hess_init(g); + + /* sum the four terms and fill idx maps */ + int *maps[4]; + node->wsum_hess = sum_4_csr_fill_sparsity_and_idx_maps( + mnode->C, mnode->CT, f->wsum_hess, g->wsum_hess, maps); + mnode->idx_map_C = maps[0]; + mnode->idx_map_CT = maps[1]; + mnode->idx_map_Hf = maps[2]; + mnode->idx_map_Hg = maps[3]; + + /* allocate weight backprop workspace */ + if (!f->is_affine(f) || !g->is_affine(g)) + { + node->work->dwork = + (double *) malloc(MAX(f->size, g->size) * sizeof(double)); + } +} + +static void eval_wsum_hess_chain_rule(expr *node, const double *w) +{ + expr *f = node->left; + expr *g = node->right; + matmul_expr *mnode = (matmul_expr *) node; + int m = f->d1; + int k = f->d2; + int n = g->d2; + bool is_f_affine = f->is_affine(f); + bool is_g_affine = g->is_affine(g); + CSC_Matrix *Jf = f->work->jacobian_csc; + CSC_Matrix *Jg = g->work->jacobian_csc; + + /* refresh child Jacobian CSC values (cache if affine) */ + if (!f->work->jacobian_csc_filled) + { + csr_to_csc_fill_values(f->jacobian, Jf, f->work->csc_work); + if (is_f_affine) + { + f->work->jacobian_csc_filled = true; + } + } + + /* refresh child Jacobian CSC values (cache if affine) */ + if (!g->work->jacobian_csc_filled) + { + csr_to_csc_fill_values(g->jacobian, Jg, g->work->csc_work); + if (is_g_affine) + { + g->work->jacobian_csc_filled = true; + } + } + + /* compute C = J_f^T @ B(w) @ J_g */ + fill_cross_hessian_values(m, k, n, w, mnode->B); + csr_csc_matmul_fill_values(mnode->B, Jg, mnode->BJg); + csr_to_csc_fill_values(mnode->BJg, mnode->BJg_CSC, mnode->BJg_csc_work); + BTDA_fill_values(mnode->BJg_CSC, Jf, NULL, mnode->C); + + /* compute CT */ + AT_fill_values(mnode->C, mnode->CT, node->work->iwork); + + /* compute Hessian of f */ + if (!is_f_affine) + { + Y_kron_I_vec(m, k, n, g->value, w, node->work->dwork); + f->eval_wsum_hess(f, node->work->dwork); + } + + /* compute Hessian of g */ + if (!is_g_affine) + { + I_kron_XT_vec(m, k, n, f->value, w, node->work->dwork); + g->eval_wsum_hess(g, node->work->dwork); + } + + /* accumulate H = C + C^T + H_f + H_g */ + memset(node->wsum_hess->x, 0, node->wsum_hess->nnz * sizeof(double)); + idx_map_accumulator(mnode->C, mnode->idx_map_C, node->wsum_hess->x); + idx_map_accumulator(mnode->CT, mnode->idx_map_CT, node->wsum_hess->x); + idx_map_accumulator(f->wsum_hess, mnode->idx_map_Hf, node->wsum_hess->x); + idx_map_accumulator(g->wsum_hess, mnode->idx_map_Hg, node->wsum_hess->x); +} + expr *new_matmul(expr *x, expr *y) { /* Verify dimensions: x->d2 must equal y->d1 */ @@ -317,21 +525,26 @@ expr *new_matmul(expr *x, expr *y) exit(1); } - /* verify both are variables and not the same variable */ - if (x->var_id == NOT_A_VARIABLE || y->var_id == NOT_A_VARIABLE || - x->var_id == y->var_id) - { - fprintf(stderr, "Error in new_matmul: operands must be variables and not " - "the same variable\n"); - exit(1); - } - /* Allocate the expression node */ - expr *node = (expr *) calloc(1, sizeof(expr)); + expr *node = (expr *) calloc(1, sizeof(matmul_expr)); + + /* Choose no-chain-rule or chain-rule function pointers */ + bool use_chain_rule = !(x->var_id != NOT_A_VARIABLE && + y->var_id != NOT_A_VARIABLE && x->var_id != y->var_id); - /* Initialize with d1 = x->d1, d2 = y->d2 (result is m x n) */ - init_expr(node, x->d1, y->d2, x->n_vars, forward, jacobian_init_impl, - eval_jacobian, is_affine, wsum_hess_init_impl, eval_wsum_hess, NULL); + if (!use_chain_rule) + { + init_expr(node, x->d1, y->d2, x->n_vars, forward, + jacobian_init_no_chain_rule, eval_jacobian_no_chain_rule, + is_affine, wsum_hess_init_no_chain_rule, + eval_wsum_hess_no_chain_rule, free_matmul_data); + } + else + { + init_expr(node, x->d1, y->d2, x->n_vars, forward, jacobian_init_chain_rule, + eval_jacobian_chain_rule, is_affine, wsum_hess_init_chain_rule, + eval_wsum_hess_chain_rule, free_matmul_data); + } /* Set children */ node->left = x; diff --git a/src/bivariate_full_dom/multiply.c b/src/bivariate_full_dom/multiply.c index bd13bfb..a77a88c 100644 --- a/src/bivariate_full_dom/multiply.c +++ b/src/bivariate_full_dom/multiply.c @@ -23,16 +23,6 @@ #include #include -/* Scatter-add src->x into dest using precomputed index map */ -static void accumulate_mapped(double *dest, const CSR_Matrix *src, - const int *idx_map) -{ - for (int j = 0; j < src->nnz; j++) - { - dest[idx_map[j]] += src->x[j]; - } -} - // ------------------------------------------------------------------------------ // Implementation of elementwise multiplication when both arguments are vectors. // If one argument is a scalar variable, the broadcasting should be represented @@ -62,8 +52,7 @@ static void jacobian_init_impl(expr *node) node->jacobian = new_csr_matrix(node->size, node->n_vars, nnz_max); /* fill sparsity pattern */ - sum_csr_matrices_fill_sparsity(node->left->jacobian, node->right->jacobian, - node->jacobian); + sum_csr_alloc(node->left->jacobian, node->right->jacobian, node->jacobian); } static void eval_jacobian(expr *node) @@ -268,10 +257,10 @@ static void eval_wsum_hess(expr *node, const double *w) // compute H = C + C^T + term2 + term3 // --------------------------------------------------------------- memset(node->wsum_hess->x, 0, node->wsum_hess->nnz * sizeof(double)); - accumulate_mapped(node->wsum_hess->x, C, mul_node->idx_map_C); - accumulate_mapped(node->wsum_hess->x, CT, mul_node->idx_map_CT); - accumulate_mapped(node->wsum_hess->x, x->wsum_hess, mul_node->idx_map_Hx); - accumulate_mapped(node->wsum_hess->x, y->wsum_hess, mul_node->idx_map_Hy); + idx_map_accumulator(C, mul_node->idx_map_C, node->wsum_hess->x); + idx_map_accumulator(CT, mul_node->idx_map_CT, node->wsum_hess->x); + idx_map_accumulator(x->wsum_hess, mul_node->idx_map_Hx, node->wsum_hess->x); + idx_map_accumulator(y->wsum_hess, mul_node->idx_map_Hy, node->wsum_hess->x); } } diff --git a/src/elementwise_full_dom/common.c b/src/elementwise_full_dom/common.c index 59c43c5..74280f9 100644 --- a/src/elementwise_full_dom/common.c +++ b/src/elementwise_full_dom/common.c @@ -102,8 +102,8 @@ void wsum_hess_init_elementwise(expr *node) /* wsum_hess = term1 + term2 */ int max_nnz = node->work->hess_term1->nnz + node->work->hess_term2->nnz; node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, max_nnz); - sum_csr_matrices_fill_sparsity(node->work->hess_term1, - node->work->hess_term2, node->wsum_hess); + sum_csr_alloc(node->work->hess_term1, node->work->hess_term2, + node->wsum_hess); } } } @@ -155,8 +155,8 @@ void eval_wsum_hess_elementwise(expr *node, const double *w) child->wsum_hess->nnz * sizeof(double)); /* wsum_hess = term1 + term2 */ - sum_csr_matrices_fill_values(node->work->hess_term1, - node->work->hess_term2, node->wsum_hess); + sum_csr_fill_values(node->work->hess_term1, node->work->hess_term2, + node->wsum_hess); } } } diff --git a/src/other/quad_form.c b/src/other/quad_form.c index 91ea602..78bd436 100644 --- a/src/other/quad_form.c +++ b/src/other/quad_form.c @@ -156,8 +156,8 @@ static void wsum_hess_init_impl(expr *node) /* hess = term1 + term2 */ int max_nnz = node->work->hess_term1->nnz + node->work->hess_term2->nnz; node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, max_nnz); - sum_csr_matrices_fill_sparsity(node->work->hess_term1, - node->work->hess_term2, node->wsum_hess); + sum_csr_alloc(node->work->hess_term1, node->work->hess_term2, + node->wsum_hess); } } @@ -205,7 +205,7 @@ static void eval_wsum_hess(expr *node, const double *w) cblas_dscal(term2->nnz, two_w, term2->x, 1); /* sum the two terms */ - sum_csr_matrices_fill_values(term1, term2, node->wsum_hess); + sum_csr_fill_values(term1, term2, node->wsum_hess); } } diff --git a/src/utils/CSR_sum.c b/src/utils/CSR_sum.c index e024a0d..1f30906 100644 --- a/src/utils/CSR_sum.c +++ b/src/utils/CSR_sum.c @@ -85,8 +85,7 @@ void sum_csr_matrices(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C) C->p[A->m] = C->nnz; } -void sum_csr_matrices_fill_sparsity(const CSR_Matrix *A, const CSR_Matrix *B, - CSR_Matrix *C) +void sum_csr_alloc(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C) { /* A and B must be different from C */ assert(A != C && B != C); @@ -143,8 +142,7 @@ void sum_csr_matrices_fill_sparsity(const CSR_Matrix *A, const CSR_Matrix *B, C->p[A->m] = C->nnz; } -void sum_csr_matrices_fill_values(const CSR_Matrix *A, const CSR_Matrix *B, - CSR_Matrix *C) +void sum_csr_fill_values(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C) { /* Assumes C->p and C->i already contain the sparsity pattern of A+B. Fills only C->x accordingly. */ @@ -573,13 +571,10 @@ void sum_evenly_spaced_rows_csr_fill_sparsity_and_idx_map(const CSR_Matrix *A, void idx_map_accumulator(const CSR_Matrix *A, const int *idx_map, double *accumulator) { - /* don't forget to initialze accumulator to 0 before calling this */ - for (int row = 0; row < A->m; row++) + /* don't forget to initialize accumulator to 0 before calling this */ + for (int j = 0; j < A->nnz; j++) { - for (int j = A->p[row]; j < A->p[row + 1]; j++) - { - accumulator[idx_map[j]] += A->x[j]; - } + accumulator[idx_map[j]] += A->x[j]; } } diff --git a/src/utils/dense_matrix.c b/src/utils/dense_matrix.c index 63f3442..8a2dd2c 100644 --- a/src/utils/dense_matrix.c +++ b/src/utils/dense_matrix.c @@ -15,9 +15,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "utils/dense_matrix.h" #include "utils/cblas_wrapper.h" -#include "utils/iVec.h" -#include "utils/matrix.h" +#include "utils/linalg_dense_sparse_matmuls.h" #include #include @@ -41,136 +41,6 @@ static void dense_block_left_mult_vec(const Matrix *A, const double *x, double * n, 0.0, y, m); } -static CSC_Matrix *dense_block_left_mult_sparsity(const Matrix *A, - const CSC_Matrix *J, int p) -{ - int m = A->m; - int n = A->n; - int i, j, jj, block, block_start, block_end, block_jj_start, row_offset; - - int *Cp = (int *) malloc((J->n + 1) * sizeof(int)); - iVec *Ci = iVec_new(J->n * m); - Cp[0] = 0; - - /* for each column of J */ - for (j = 0; j < J->n; j++) - { - /* if empty we continue */ - if (J->p[j] == J->p[j + 1]) - { - Cp[j + 1] = Cp[j]; - continue; - } - - /* process each of p blocks of rows in this column of J */ - jj = J->p[j]; - for (block = 0; block < p; block++) - { - // ----------------------------------------------------------------- - // find start and end indices of rows of J in this block - // ----------------------------------------------------------------- - block_start = block * n; - block_end = block_start + n; - while (jj < J->p[j + 1] && J->i[jj] < block_start) - { - jj++; - } - - block_jj_start = jj; - while (jj < J->p[j + 1] && J->i[jj] < block_end) - { - jj++; - } - - /* if no entries in this block, continue */ - if (jj == block_jj_start) - { - continue; - } - - /* dense A: all m rows contribute */ - row_offset = block * m; - for (i = 0; i < m; i++) - { - iVec_append(Ci, row_offset + i); - } - } - Cp[j + 1] = Ci->len; - } - - CSC_Matrix *C = new_csc_matrix(m * p, J->n, Ci->len); - memcpy(C->p, Cp, (J->n + 1) * sizeof(int)); - memcpy(C->i, Ci->data, Ci->len * sizeof(int)); - free(Cp); - iVec_free(Ci); - - return C; -} - -static void dense_block_left_mult_values(const Matrix *A, const CSC_Matrix *J, - CSC_Matrix *C) -{ - const Dense_Matrix *dm = (const Dense_Matrix *) A; - int m = dm->base.m; - int n = dm->base.n; - int k = J->n; - - int i, j, s, block, block_start, block_end, start, end; - - double *j_dense = dm->work; - - /* for each column of J (and C) */ - for (j = 0; j < k; j++) - { - for (i = C->p[j]; i < C->p[j + 1]; i += m) - { - block = C->i[i] / m; - block_start = block * n; - block_end = block_start + n; - - start = J->p[j]; - end = J->p[j + 1]; - - while (start < J->p[j + 1] && J->i[start] < block_start) - { - start++; - } - - while (end > start && J->i[end - 1] >= block_end) - { - end--; - } - - int count = end - start; - - if (count == 1) - { - /* Fast path: C column segment = val * A[:, row_in_block] */ - int row_in_block = J->i[start] - block_start; - double val = J->x[start]; - cblas_dcopy(m, dm->x + row_in_block, n, C->x + i, 1); - if (val != 1.0) - { - cblas_dscal(m, val, C->x + i, 1); - } - } - else - { - /* scatter sparse J col into dense vector and then compute A @ - * j_dense */ - memset(j_dense, 0, n * sizeof(double)); - for (s = start; s < end; s++) - { - j_dense[J->i[s] - block_start] = J->x[s]; - } - - cblas_dgemv(CblasRowMajor, CblasNoTrans, m, n, 1.0, dm->x, n, - j_dense, 1, 0.0, C->x + i, 1); - } - } - } -} - static void dense_free(Matrix *A) { Dense_Matrix *dm = (Dense_Matrix *) A; @@ -185,8 +55,8 @@ Matrix *new_dense_matrix(int m, int n, const double *data) dm->base.m = m; dm->base.n = n; dm->base.block_left_mult_vec = dense_block_left_mult_vec; - dm->base.block_left_mult_sparsity = dense_block_left_mult_sparsity; - dm->base.block_left_mult_values = dense_block_left_mult_values; + dm->base.block_left_mult_sparsity = I_kron_A_alloc; + dm->base.block_left_mult_values = I_kron_A_fill_values; dm->base.free_fn = dense_free; dm->x = (double *) malloc(m * n * sizeof(double)); memcpy(dm->x, data, m * n * sizeof(double)); diff --git a/src/utils/linalg_dense_sparse_matmuls.c b/src/utils/linalg_dense_sparse_matmuls.c new file mode 100644 index 0000000..a073349 --- /dev/null +++ b/src/utils/linalg_dense_sparse_matmuls.c @@ -0,0 +1,348 @@ +/* + * Copyright 2026 Daniel Cederberg and William Zhang + * + * This file is part of the DNLP-differentiation-engine project. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "utils/CSC_Matrix.h" +#include "utils/CSR_Matrix.h" +#include "utils/cblas_wrapper.h" +#include "utils/dense_matrix.h" +#include "utils/iVec.h" +#include +#include +#include + +/* --------------------------------------------------------------- + * C = (I_p kron A) @ J via the polymorphic Matrix interface. + * A is dense m x n, J is (n*p) x k CSC, C is (m*p) x k CSC. + * --------------------------------------------------------------- */ +CSC_Matrix *I_kron_A_alloc(const Matrix *A, const CSC_Matrix *J, int p) +{ + int m = A->m; + int n = A->n; + int i, j, jj, block, block_start, block_end, block_jj_start, row_offset; + + int *Cp = (int *) malloc((J->n + 1) * sizeof(int)); + iVec *Ci = iVec_new(J->n * m); + Cp[0] = 0; + + /* for each column of J */ + for (j = 0; j < J->n; j++) + { + /* if empty we continue */ + if (J->p[j] == J->p[j + 1]) + { + Cp[j + 1] = Cp[j]; + continue; + } + + /* process each of p blocks of rows in this column of J */ + jj = J->p[j]; + for (block = 0; block < p; block++) + { + block_start = block * n; + block_end = block_start + n; + while (jj < J->p[j + 1] && J->i[jj] < block_start) + { + jj++; + } + + block_jj_start = jj; + while (jj < J->p[j + 1] && J->i[jj] < block_end) + { + jj++; + } + + /* if no entries in this block, continue */ + if (jj == block_jj_start) + { + continue; + } + + /* dense A: all m rows contribute */ + row_offset = block * m; + for (i = 0; i < m; i++) + { + iVec_append(Ci, row_offset + i); + } + } + Cp[j + 1] = Ci->len; + } + + CSC_Matrix *C = new_csc_matrix(m * p, J->n, Ci->len); + memcpy(C->p, Cp, (J->n + 1) * sizeof(int)); + memcpy(C->i, Ci->data, Ci->len * sizeof(int)); + free(Cp); + iVec_free(Ci); + + return C; +} + +void I_kron_A_fill_values(const Matrix *A, const CSC_Matrix *J, CSC_Matrix *C) +{ + const Dense_Matrix *dm = (const Dense_Matrix *) A; + int m = dm->base.m; + int n = dm->base.n; + int k = J->n; + + int i, j, s, block, block_start, block_end, start, end; + + double *j_dense = dm->work; + + /* for each column of J (and C) */ + for (j = 0; j < k; j++) + { + for (i = C->p[j]; i < C->p[j + 1]; i += m) + { + block = C->i[i] / m; + block_start = block * n; + block_end = block_start + n; + + start = J->p[j]; + end = J->p[j + 1]; + + while (start < J->p[j + 1] && J->i[start] < block_start) + { + start++; + } + + while (end > start && J->i[end - 1] >= block_end) + { + end--; + } + + int count = end - start; + + if (count == 1) + { + /* Fast path: C column segment = val * A[:, row_in_block] */ + int row_in_block = J->i[start] - block_start; + double val = J->x[start]; + cblas_dcopy(m, dm->x + row_in_block, n, C->x + i, 1); + if (val != 1.0) + { + cblas_dscal(m, val, C->x + i, 1); + } + } + else + { + /* scatter sparse J col into dense vector and then + * compute A @ j_dense */ + memset(j_dense, 0, n * sizeof(double)); + for (s = start; s < end; s++) + { + j_dense[J->i[s] - block_start] = J->x[s]; + } + + cblas_dgemv(CblasRowMajor, CblasNoTrans, m, n, 1.0, dm->x, n, + j_dense, 1, 0.0, C->x + i, 1); + } + } + } +} + +/* --------------------------------------------------------------- + * C = (Y^T kron I_m) @ J + * Y is k x n (col-major), J is (m*k) x p CSC, C is (m*n) x p CSR + * --------------------------------------------------------------- */ +CSR_Matrix *YT_kron_I_alloc(int m, int k, int n, const CSC_Matrix *J) +{ + (void) k; + /* C has n blocks of m rows. All rows at the same position within + their block (same blk_row) share the same column sparsity. + + Step 1: for each blk_row, find which columns of J contribute. + Column j contributes iff J[:,j] has any nonzero r + with r % m == blk_row. + Step 2: replicate each blk_row's pattern across all n blocks. */ + + int i, j, ii, blk_row, total_nnz; + + // --------------------------------------------------------------- + // build sparsity pattern per blk_row + // --------------------------------------------------------------- + iVec **pattern = (iVec **) malloc(m * sizeof(iVec *)); + total_nnz = 0; + for (blk_row = 0; blk_row < m; blk_row++) + { + pattern[blk_row] = iVec_new(J->n); + + /* check each column of J */ + for (j = 0; j < J->n; j++) + { + for (ii = J->p[j]; ii < J->p[j + 1]; ii++) + { + if (J->i[ii] % m == blk_row) + { + iVec_append(pattern[blk_row], j); + break; + } + } + } + total_nnz += pattern[blk_row]->len * n; + } + + // --------------------------------------------------------------- + // replicate sparsity pattern across blocks + // --------------------------------------------------------------- + CSR_Matrix *C = new_csr_matrix(m * n, J->n, total_nnz); + int idx = 0; + for (i = 0; i < m * n; i++) + { + blk_row = i % m; + C->p[i] = idx; + int len = pattern[blk_row]->len; + memcpy(C->i + idx, pattern[blk_row]->data, len * sizeof(int)); + idx += len; + } + C->p[m * n] = idx; + assert(idx == total_nnz); + + for (blk_row = 0; blk_row < m; blk_row++) + { + iVec_free(pattern[blk_row]); + } + free(pattern); + return C; +} + +void YT_kron_I_fill_values(int m, int k, int n, const double *Y, const CSC_Matrix *J, + CSR_Matrix *C) +{ + (void) n; + assert(C->m == m * n); + /* C[i, j] = sum_l Y[l, blk] * J[blk_row + l*m, j] + * where blk_row = i % m, blk = i / m */ + + int i, j, ii, jj, blk, blk_row; + double sum; + + /* for each row i of C */ + for (i = 0; i < C->m; i++) + { + blk = i / m; /* which block of C */ + blk_row = i % m; /* which row within block */ + const double *Y_col = Y + blk * k; /* column blk of Y */ + + /* for each column j in row i of C */ + for (jj = C->p[i]; jj < C->p[i + 1]; jj++) + { + j = C->i[jj]; + sum = 0.0; + + /* matching J nonzeros in this column */ + for (ii = J->p[j]; ii < J->p[j + 1]; ii++) + { + if (J->i[ii] % m == blk_row) + { + sum += Y_col[J->i[ii] / m] * J->x[ii]; + } + } + C->x[jj] = sum; + } + } +} + +CSR_Matrix *I_kron_X_alloc(int m, int k, int n, const CSC_Matrix *J) +{ + /* Step 1: for each block, find which columns of J have any + * nonzero in row range [blk*k, blk*k + k). */ + int i, j, ii, blk; + + iVec **pattern = (iVec **) malloc(n * sizeof(iVec *)); + int total_nnz = 0; + for (blk = 0; blk < n; blk++) + { + int blk_start = blk * k; + int blk_end = blk_start + k; + pattern[blk] = iVec_new(J->n); + + /* check each column of J */ + for (j = 0; j < J->n; j++) + { + for (ii = J->p[j]; ii < J->p[j + 1]; ii++) + { + if (J->i[ii] >= blk_start && J->i[ii] < blk_end) + { + iVec_append(pattern[blk], j); + break; + } + } + } + total_nnz += (int) pattern[blk]->len * m; + } + + /* Step 2: replicate each block's pattern for all m rows + * within that block. */ + CSR_Matrix *C = new_csr_matrix(m * n, J->n, total_nnz); + int idx = 0; + for (i = 0; i < m * n; i++) + { + blk = i / m; + C->p[i] = idx; + int len = (int) pattern[blk]->len; + memcpy(C->i + idx, pattern[blk]->data, len * sizeof(int)); + idx += len; + } + C->p[m * n] = idx; + assert(idx == total_nnz); + + for (blk = 0; blk < n; blk++) + { + iVec_free(pattern[blk]); + } + free(pattern); + return C; +} + +void I_kron_X_fill_values(int m, int k, int n, const double *X, const CSC_Matrix *J, + CSR_Matrix *C) +{ + (void) n; + assert(C->m == m * n); + /* C[i, j] = sum_l X[blk_row + l*m] * J[blk*k + l, j] + * where blk = i / m, blk_row = i % m */ + + int i, j, ii, jj, blk, blk_row; + double sum; + + /* for each row i of C */ + for (i = 0; i < C->m; i++) + { + blk = i / m; /* which block of C */ + blk_row = i % m; /* which row within block */ + int blk_start = blk * k; + int blk_end = blk_start + k; + + /* for each column j in row i of C */ + for (jj = C->p[i]; jj < C->p[i + 1]; jj++) + { + j = C->i[jj]; + sum = 0.0; + + /* J nonzeros in column j within this block's row range */ + for (ii = J->p[j]; ii < J->p[j + 1]; ii++) + { + int r = J->i[ii]; + if (r >= blk_start && r < blk_end) + { + int l = r - blk_start; + sum += X[blk_row + l * m] * J->x[ii]; + } + } + C->x[jj] = sum; + } + } +} diff --git a/src/utils/mini_numpy.c b/src/utils/mini_numpy.c index c30073c..27f50d8 100644 --- a/src/utils/mini_numpy.c +++ b/src/utils/mini_numpy.c @@ -68,3 +68,35 @@ void mat_mat_mult(const double *X, const double *Y, double *Z, int m, int k, int } } } + +void Y_kron_I_vec(int m, int k, int n, const double *Y, const double *w, double *v) +{ + for (int j = 0; j < k; j++) + { + for (int row = 0; row < m; row++) + { + double sum = 0.0; + for (int col = 0; col < n; col++) + { + sum += Y[j + col * k] * w[row + col * m]; + } + v[row + j * m] = sum; + } + } +} + +void I_kron_XT_vec(int m, int k, int n, const double *X, const double *w, double *v) +{ + for (int col = 0; col < n; col++) + { + for (int j = 0; j < k; j++) + { + double sum = 0.0; + for (int row = 0; row < m; row++) + { + sum += X[row + j * m] * w[row + col * m]; + } + v[j + col * k] = sum; + } + } +} diff --git a/tests/all_tests.c b/tests/all_tests.c index a92a6d9..c004226 100644 --- a/tests/all_tests.c +++ b/tests/all_tests.c @@ -55,6 +55,7 @@ #include "utils/test_csr_csc_conversion.h" #include "utils/test_csr_matrix.h" #include "utils/test_linalg_sparse_matmuls.h" +#include "utils/test_linalg_utils_matmul_chain_rule.h" #include "utils/test_matrix.h" #include "wsum_hess/affine/test_broadcast.h" #include "wsum_hess/affine/test_const_scalar_mult.h" @@ -139,6 +140,11 @@ int main(void) mu_run_test(test_jacobian_AX_BX_multiply, tests_run); mu_run_test(test_jacobian_quad_form_Ax, tests_run); mu_run_test(test_jacobian_quad_form_exp, tests_run); + mu_run_test(test_jacobian_matmul_exp_exp, tests_run); + mu_run_test(test_jacobian_matmul_sin_cos, tests_run); + mu_run_test(test_jacobian_matmul_Ax_By, tests_run); + mu_run_test(test_jacobian_matmul_sin_Ax_cos_Bx, tests_run); + mu_run_test(test_jacobian_matmul_X_X, tests_run); mu_run_test(test_jacobian_composite_exp_add, tests_run); mu_run_test(test_jacobian_const_scalar_mult_log_vector, tests_run); mu_run_test(test_jacobian_const_scalar_mult_log_matrix, tests_run); @@ -279,6 +285,11 @@ int main(void) mu_run_test(test_wsum_hess_quad_form_Ax, tests_run); mu_run_test(test_wsum_hess_quad_form_sin_Ax, tests_run); mu_run_test(test_wsum_hess_quad_form_exp, tests_run); + mu_run_test(test_wsum_hess_matmul_exp_exp, tests_run); + mu_run_test(test_wsum_hess_matmul_sin_cos, tests_run); + mu_run_test(test_wsum_hess_matmul_Ax_By, tests_run); + mu_run_test(test_wsum_hess_matmul_sin_Ax_cos_Bx, tests_run); + mu_run_test(test_wsum_hess_matmul_X_X, tests_run); printf("\n--- Utility Tests ---\n"); mu_run_test(test_cblas_ddot, tests_run); @@ -318,6 +329,10 @@ int main(void) mu_run_test(test_sparse_vs_dense_mult_vec, tests_run); mu_run_test(test_dense_matrix_trans, tests_run); mu_run_test(test_sparse_vs_dense_mult_vec_blocks, tests_run); + mu_run_test(test_YT_kron_I, tests_run); + mu_run_test(test_YT_kron_I_larger, tests_run); + mu_run_test(test_I_kron_X, tests_run); + mu_run_test(test_I_kron_X_larger, tests_run); printf("\n--- Numerical Diff Tests ---\n"); mu_run_test(test_check_jacobian_composite_exp, tests_run); diff --git a/tests/jacobian_tests/composite/test_chain_rule_jacobian.h b/tests/jacobian_tests/composite/test_chain_rule_jacobian.h index 9e916af..bc61d83 100644 --- a/tests/jacobian_tests/composite/test_chain_rule_jacobian.h +++ b/tests/jacobian_tests/composite/test_chain_rule_jacobian.h @@ -170,3 +170,101 @@ const char *test_jacobian_quad_form_exp(void) free_csr_matrix(Q); return 0; } + +const char *test_jacobian_matmul_exp_exp(void) +{ + /* Z = exp(X) @ exp(Y), X is 2x3, Y is 3x2 */ + double u_vals[12] = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2}; + + expr *X = new_variable(2, 3, 0, 12); + expr *Y = new_variable(3, 2, 6, 12); + expr *exp_X = new_exp(X); + expr *exp_Y = new_exp(Y); + expr *Z = new_matmul(exp_X, exp_Y); + + mu_assert("check_jacobian failed", + check_jacobian(Z, u_vals, NUMERICAL_DIFF_DEFAULT_H)); + + free_expr(Z); + return 0; +} + +const char *test_jacobian_matmul_sin_cos(void) +{ + /* Z = sin(X) @ cos(Y), X is 2x2, Y is 2x3 */ + double u_vals[10] = {0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0}; + + expr *X = new_variable(2, 2, 0, 10); + expr *Y = new_variable(2, 3, 4, 10); + expr *sin_X = new_sin(X); + expr *cos_Y = new_cos(Y); + expr *Z = new_matmul(sin_X, cos_Y); + + mu_assert("check_jacobian failed", + check_jacobian(Z, u_vals, NUMERICAL_DIFF_DEFAULT_H)); + + free_expr(Z); + return 0; +} + +const char *test_jacobian_matmul_Ax_By(void) +{ + /* Z = (A @ X) @ (B @ Y) with constant matrices A, B */ + double u_vals[10] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0}; + + CSR_Matrix *A = new_csr_random(3, 2, 1.0); + CSR_Matrix *B = new_csr_random(2, 3, 1.0); + + expr *X = new_variable(2, 2, 0, 10); /* 2x2, vars 0-3 */ + expr *Y = new_variable(3, 2, 4, 10); /* 3x2, vars 4-9 */ + expr *AX = new_left_matmul(X, A); /* 3x2 */ + expr *BY = new_left_matmul(Y, B); /* 2x2 */ + expr *Z = new_matmul(AX, BY); /* 3x2 */ + + mu_assert("check_jacobian failed", + check_jacobian(Z, u_vals, NUMERICAL_DIFF_DEFAULT_H)); + + free_expr(Z); + free_csr_matrix(A); + free_csr_matrix(B); + return 0; +} + +const char *test_jacobian_matmul_sin_Ax_cos_Bx(void) +{ + /* Z = sin(A @ X) @ cos(B @ X), shared variable X */ + double u_vals[6] = {0.5, 1.0, 1.5, 2.0, 2.5, 3.0}; + + CSR_Matrix *A = new_csr_random(2, 3, 1.0); + CSR_Matrix *B = new_csr_random(2, 3, 1.0); + + expr *X = new_variable(3, 2, 0, 6); /* 3x2, vars 0-5 */ + expr *AX = new_left_matmul(X, A); /* 2x2 */ + expr *BX = new_left_matmul(X, B); /* 2x2 */ + expr *sin_AX = new_sin(AX); /* 2x2 */ + expr *cos_BX = new_cos(BX); /* 2x2 */ + expr *Z = new_matmul(sin_AX, cos_BX); /* 2x2 */ + + mu_assert("check_jacobian failed", + check_jacobian(Z, u_vals, NUMERICAL_DIFF_DEFAULT_H)); + + free_expr(Z); + free_csr_matrix(A); + free_csr_matrix(B); + return 0; +} + +const char *test_jacobian_matmul_X_X(void) +{ + /* Z = X @ X, same leaf variable as both children */ + double u_vals[4] = {1.0, 2.0, 3.0, 4.0}; + + expr *X = new_variable(2, 2, 0, 4); + expr *Z = new_matmul(X, X); /* 2x2 */ + + mu_assert("check_jacobian failed", + check_jacobian(Z, u_vals, NUMERICAL_DIFF_DEFAULT_H)); + + free_expr(Z); + return 0; +} diff --git a/tests/utils/test_linalg_utils_matmul_chain_rule.h b/tests/utils/test_linalg_utils_matmul_chain_rule.h new file mode 100644 index 0000000..0c66589 --- /dev/null +++ b/tests/utils/test_linalg_utils_matmul_chain_rule.h @@ -0,0 +1,228 @@ +#include +#include +#include + +#include "minunit.h" +#include "test_helpers.h" +#include "utils/CSC_Matrix.h" +#include "utils/CSR_Matrix.h" +#include "utils/linalg_dense_sparse_matmuls.h" + +/* Test YT_kron_I_alloc and YT_kron_I_fill_values + * + * C = (Y^T kron I_m) @ J + * m=2, k=2, n=2, p=3 + * + * Y (k x n, col-major [1,2,3,4]): + * [1 3] + * [2 4] + * + * J (mk=4 x p=3, CSC): + * [1 0 2] + * [0 1 0] + * [3 0 0] + * [0 0 1] + * + * C = (Y^T kron I_2) @ J: + * [ 7 0 2] + * [ 0 1 2] + * [15 0 6] + * [ 0 3 4] + */ +const char *test_YT_kron_I(void) +{ + int m = 2, k = 2, n = 2; + + /* J is 4x3 CSC */ + CSC_Matrix *J = new_csc_matrix(4, 3, 5); + int Jp[4] = {0, 2, 3, 5}; + int Ji[5] = {0, 2, 1, 0, 3}; + double Jx[5] = {1.0, 3.0, 1.0, 2.0, 1.0}; + memcpy(J->p, Jp, 4 * sizeof(int)); + memcpy(J->i, Ji, 5 * sizeof(int)); + memcpy(J->x, Jx, 5 * sizeof(double)); + + /* Y col-major: Y[0,0]=1, Y[1,0]=2, Y[0,1]=3, Y[1,1]=4 */ + double Y[4] = {1.0, 2.0, 3.0, 4.0}; + + CSR_Matrix *C = YT_kron_I_alloc(m, k, n, J); + + /* Expected CSR (from scipy) */ + int exp_p[5] = {0, 2, 4, 6, 8}; + int exp_i[8] = {0, 2, 1, 2, 0, 2, 1, 2}; + double exp_x[8] = {7.0, 2.0, 1.0, 2.0, 15.0, 6.0, 3.0, 4.0}; + + mu_assert("C dims", C->m == 4 && C->n == 3); + mu_assert("C nnz", C->nnz == 8); + mu_assert("C row ptrs", cmp_int_array(C->p, exp_p, 5)); + mu_assert("C col indices", cmp_int_array(C->i, exp_i, 8)); + + YT_kron_I_fill_values(m, k, n, Y, J, C); + mu_assert("C values", cmp_double_array(C->x, exp_x, 8)); + + free_csr_matrix(C); + free_csc_matrix(J); + return NULL; +} + +/* Test YT_kron_I with larger dimensions: m=3, k=2, n=3, p=4 + * + * Y (k=2 x n=3, col-major [1,3,0.5,1,2,0.5]): + * [1.0 0.5 2.0] + * [3.0 1.0 0.5] + * + * J (mk=6 x p=4, CSC): + * [1 0 0 2] + * [0 0 1 0] + * [0 3 0 0] + * [2 0 0 1] + * [0 1 0 0] + * [0 0 4 0] + * + * C = (Y^T kron I_3) @ J is 9 x 4 + */ +const char *test_YT_kron_I_larger(void) +{ + int m = 3, k = 2, n = 3; + + /* J is 6x4 CSC */ + CSC_Matrix *J = new_csc_matrix(6, 4, 8); + int Jp[5] = {0, 2, 4, 6, 8}; + int Ji[8] = {0, 3, 2, 4, 1, 5, 0, 3}; + double Jx[8] = {1.0, 2.0, 3.0, 1.0, 1.0, 4.0, 2.0, 1.0}; + memcpy(J->p, Jp, 5 * sizeof(int)); + memcpy(J->i, Ji, 8 * sizeof(int)); + memcpy(J->x, Jx, 8 * sizeof(double)); + + /* Y col-major */ + double Y[6] = {1.0, 3.0, 0.5, 1.0, 2.0, 0.5}; + + CSR_Matrix *C = YT_kron_I_alloc(m, k, n, J); + + /* Expected CSR (from scipy) */ + int exp_p[10] = {0, 2, 4, 6, 8, 10, 12, 14, 16, 18}; + int exp_i[18] = {0, 3, 1, 2, 1, 2, 0, 3, 1, 2, 1, 2, 0, 3, 1, 2, 1, 2}; + double exp_x[18] = {7.0, 5.0, 3.0, 1.0, 3.0, 12.0, 2.5, 2.0, 1.0, + 0.5, 1.5, 4.0, 3.0, 4.5, 0.5, 2.0, 6.0, 2.0}; + + mu_assert("C2 dims", C->m == 9 && C->n == 4); + mu_assert("C2 nnz", C->nnz == 18); + mu_assert("C2 row ptrs", cmp_int_array(C->p, exp_p, 10)); + mu_assert("C2 col indices", cmp_int_array(C->i, exp_i, 18)); + + YT_kron_I_fill_values(m, k, n, Y, J, C); + mu_assert("C2 values", cmp_double_array(C->x, exp_x, 18)); + + free_csr_matrix(C); + free_csc_matrix(J); + return NULL; +} + +/* Test I_kron_X_alloc and I_kron_X_fill_values + * + * C = (I_n kron X) @ J + * m=2, k=2, n=2, p=3 + * + * X (m x k, col-major [1,2,3,4]): + * [1 3] + * [2 4] + * + * J (kn=4 x p=3, CSC): + * [1 0 2] + * [0 1 0] + * [3 0 0] + * [0 0 1] + * + * C = (I_2 kron X) @ J: + * [1 3 2] + * [2 4 4] + * [3 0 3] + * [6 0 4] + */ +const char *test_I_kron_X(void) +{ + int m = 2, k = 2, n = 2; + + /* J is 4x3 CSC */ + CSC_Matrix *J = new_csc_matrix(4, 3, 5); + int Jp[4] = {0, 2, 3, 5}; + int Ji[5] = {0, 2, 1, 0, 3}; + double Jx[5] = {1.0, 3.0, 1.0, 2.0, 1.0}; + memcpy(J->p, Jp, 4 * sizeof(int)); + memcpy(J->i, Ji, 5 * sizeof(int)); + memcpy(J->x, Jx, 5 * sizeof(double)); + + /* X col-major */ + double X[4] = {1.0, 2.0, 3.0, 4.0}; + + CSR_Matrix *C = I_kron_X_alloc(m, k, n, J); + + /* Expected CSR */ + int exp_p[5] = {0, 3, 6, 8, 10}; + int exp_i[10] = {0, 1, 2, 0, 1, 2, 0, 2, 0, 2}; + double exp_x[10] = {1.0, 3.0, 2.0, 2.0, 4.0, 4.0, 3.0, 3.0, 6.0, 4.0}; + + mu_assert("C dims", C->m == 4 && C->n == 3); + mu_assert("C nnz", C->nnz == 10); + mu_assert("C row ptrs", cmp_int_array(C->p, exp_p, 5)); + mu_assert("C col indices", cmp_int_array(C->i, exp_i, 10)); + + I_kron_X_fill_values(m, k, n, X, J, C); + mu_assert("C values", cmp_double_array(C->x, exp_x, 10)); + + free_csr_matrix(C); + free_csc_matrix(J); + return NULL; +} + +/* Test I_kron_X with larger dimensions: m=3, k=2, n=2, p=4 + * + * X (m=3 x k=2, col-major [1,2,3,0.5,1,0.5]): + * [1.0 0.5] + * [2.0 1.0] + * [3.0 0.5] + * + * J (kn=4 x p=4, CSC): + * [1 0 0 2] + * [0 3 1 0] + * [0 0 4 0] + * [2 0 0 1] + * + * C = (I_2 kron X) @ J is 6 x 4 + */ +const char *test_I_kron_X_larger(void) +{ + int m = 3, k = 2, n = 2; + + /* J is 4x4 CSC */ + CSC_Matrix *J = new_csc_matrix(4, 4, 7); + int Jp[5] = {0, 2, 3, 5, 7}; + int Ji[7] = {0, 3, 1, 1, 2, 0, 3}; + double Jx[7] = {1.0, 2.0, 3.0, 1.0, 4.0, 2.0, 1.0}; + memcpy(J->p, Jp, 5 * sizeof(int)); + memcpy(J->i, Ji, 7 * sizeof(int)); + memcpy(J->x, Jx, 7 * sizeof(double)); + + /* X col-major */ + double X[6] = {1.0, 2.0, 3.0, 0.5, 1.0, 0.5}; + + CSR_Matrix *C = I_kron_X_alloc(m, k, n, J); + + /* Expected CSR */ + int exp_p[7] = {0, 4, 8, 12, 15, 18, 21}; + int exp_i[21] = {0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 2, 3, 0, 2, 3, 0, 2, 3}; + double exp_x[21] = {1.0, 1.5, 0.5, 2.0, 2.0, 3.0, 1.0, 4.0, 3.0, 1.5, 0.5, + 6.0, 1.0, 4.0, 0.5, 2.0, 8.0, 1.0, 1.0, 12.0, 0.5}; + + mu_assert("C2 dims", C->m == 6 && C->n == 4); + mu_assert("C2 nnz", C->nnz == 21); + mu_assert("C2 row ptrs", cmp_int_array(C->p, exp_p, 7)); + mu_assert("C2 col indices", cmp_int_array(C->i, exp_i, 21)); + + I_kron_X_fill_values(m, k, n, X, J, C); + mu_assert("C2 values", cmp_double_array(C->x, exp_x, 21)); + + free_csr_matrix(C); + free_csc_matrix(J); + return NULL; +} diff --git a/tests/utils/test_matrix.h b/tests/utils/test_matrix.h index 8add477..c329a16 100644 --- a/tests/utils/test_matrix.h +++ b/tests/utils/test_matrix.h @@ -3,7 +3,7 @@ #include "minunit.h" #include "test_helpers.h" -#include "utils/matrix.h" +#include "utils/dense_matrix.h" #include #include diff --git a/tests/wsum_hess/composite/test_chain_rule_wsum_hess.h b/tests/wsum_hess/composite/test_chain_rule_wsum_hess.h index d0dbe4a..a0c469c 100644 --- a/tests/wsum_hess/composite/test_chain_rule_wsum_hess.h +++ b/tests/wsum_hess/composite/test_chain_rule_wsum_hess.h @@ -259,6 +259,109 @@ const char *test_wsum_hess_quad_form_sin_Ax(void) return 0; } +const char *test_wsum_hess_matmul_exp_exp(void) +{ + /* Z = exp(X) @ exp(Y), X is 2x3, Y is 3x2, Z is 2x2 */ + double u_vals[12] = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2}; + double w[4] = {1.0, 2.0, 3.0, 4.0}; + + expr *X = new_variable(2, 3, 0, 12); + expr *Y = new_variable(3, 2, 6, 12); + expr *exp_X = new_exp(X); + expr *exp_Y = new_exp(Y); + expr *Z = new_matmul(exp_X, exp_Y); + + mu_assert("check_wsum_hess failed", + check_wsum_hess(Z, u_vals, w, NUMERICAL_DIFF_DEFAULT_H)); + + free_expr(Z); + return 0; +} + +const char *test_wsum_hess_matmul_sin_cos(void) +{ + /* Z = sin(X) @ cos(Y), X is 2x2, Y is 2x3, Z is 2x3 */ + double u_vals[10] = {0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0}; + double w[6] = {1.1, 2.2, 3.3, 4.4, 5.5, 6.6}; + + expr *X = new_variable(2, 2, 0, 10); + expr *Y = new_variable(2, 3, 4, 10); + expr *sin_X = new_sin(X); + expr *cos_Y = new_cos(Y); + expr *Z = new_matmul(sin_X, cos_Y); + + mu_assert("check_wsum_hess failed", + check_wsum_hess(Z, u_vals, w, NUMERICAL_DIFF_DEFAULT_H)); + + free_expr(Z); + return 0; +} + +const char *test_wsum_hess_matmul_Ax_By(void) +{ + /* Z = (A @ X) @ (B @ Y), affine children */ + double u_vals[10] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0}; + double w[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; + + CSR_Matrix *A = new_csr_random(3, 2, 1.0); + CSR_Matrix *B = new_csr_random(2, 3, 1.0); + + expr *X = new_variable(2, 2, 0, 10); + expr *Y = new_variable(3, 2, 4, 10); + expr *AX = new_left_matmul(X, A); /* 3x2 */ + expr *BY = new_left_matmul(Y, B); /* 2x2 */ + expr *Z = new_matmul(AX, BY); /* 3x2 */ + + mu_assert("check_wsum_hess failed", + check_wsum_hess(Z, u_vals, w, NUMERICAL_DIFF_DEFAULT_H)); + + free_expr(Z); + free_csr_matrix(A); + free_csr_matrix(B); + return 0; +} + +const char *test_wsum_hess_matmul_sin_Ax_cos_Bx(void) +{ + /* Z = sin(A @ X) @ cos(B @ X), shared variable */ + double u_vals[6] = {0.5, 1.0, 1.5, 2.0, 2.5, 3.0}; + double w[4] = {1.0, 2.0, 3.0, 4.0}; + + CSR_Matrix *A = new_csr_random(2, 3, 1.0); + CSR_Matrix *B = new_csr_random(2, 3, 1.0); + + expr *X = new_variable(3, 2, 0, 6); + expr *AX = new_left_matmul(X, A); /* 2x2 */ + expr *BX = new_left_matmul(X, B); /* 2x2 */ + expr *sin_AX = new_sin(AX); + expr *cos_BX = new_cos(BX); + expr *Z = new_matmul(sin_AX, cos_BX); /* 2x2 */ + + mu_assert("check_wsum_hess failed", + check_wsum_hess(Z, u_vals, w, NUMERICAL_DIFF_DEFAULT_H)); + + free_expr(Z); + free_csr_matrix(A); + free_csr_matrix(B); + return 0; +} + +const char *test_wsum_hess_matmul_X_X(void) +{ + /* Z = X @ X, same leaf variable */ + double u_vals[4] = {1.0, 2.0, 3.0, 4.0}; + double w[4] = {1.0, 2.0, 3.0, 4.0}; + + expr *X = new_variable(2, 2, 0, 4); + expr *Z = new_matmul(X, X); + + mu_assert("check_wsum_hess failed", + check_wsum_hess(Z, u_vals, w, NUMERICAL_DIFF_DEFAULT_H)); + + free_expr(Z); + return 0; +} + const char *test_wsum_hess_quad_form_exp(void) { double u_vals[3] = {0.5, 1.0, 1.5};