Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions include/subexpr.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,28 @@ typedef struct right_matmul_expr
CSC_Matrix *CSC_work;
} right_matmul_expr;

/* Bivariate matrix multiplication: Z = f(u) @ g(u) where both children
 * may be composite expressions.
 *
 * Besides the common expr header, the struct stores the workspace matrices
 * used for the node's Jacobian and Hessian. In the field comments, J_f and
 * J_g denote the children's Jacobians.
 * NOTE(review): "x" in the term1/term2 comments presumably denotes the
 * Kronecker product, with X and Y the values of the left and right child
 * (matching the YT_kron_I_* / I_kron_X_* helpers) - confirm. */
typedef struct matmul_expr
{
expr base;
/* Jacobian workspace */
CSR_Matrix *term1_CSR; /* (Y^T x I_m) @ J_f */
CSR_Matrix *term2_CSR; /* (I_n x X) @ J_g */

/* Hessian workspace (composite only) */
CSR_Matrix *B; /* cross-Hessian B(w), mk x kn */
CSR_Matrix *BJg; /* B @ J_g */
CSC_Matrix *BJg_CSC; /* BJg in CSC */
int *BJg_csc_work; /* CSR-to-CSC workspace */
CSR_Matrix *C; /* J_f^T @ B @ J_g */
CSR_Matrix *CT; /* C^T */
/* Index maps, one per Hessian contribution.
 * NOTE(review): presumably each map records where the entries of C, C^T and
 * the children's Hessians (Hf, Hg) land in this node's Hessian - confirm
 * against the implementation. */
int *idx_map_C;
int *idx_map_CT;
int *idx_map_Hf;
int *idx_map_Hg;
} matmul_expr;

/* Constant scalar multiplication: y = a * child where a is a constant double */
typedef struct const_scalar_mult_expr
{
Expand Down
6 changes: 2 additions & 4 deletions include/utils/CSR_sum.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,12 @@ void sum_csr_matrices(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C);

/* Compute sparsity pattern of A + B where A, B, C are CSR matrices.
 * Fills C->p, C->i, and C->nnz; does not touch C->x. C is passed in by the
 * caller (callers create it beforehand, e.g. with new_csr_matrix). */
void sum_csr_matrices_fill_sparsity(const CSR_Matrix *A, const CSR_Matrix *B,
CSR_Matrix *C);
void sum_csr_alloc(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C);

/* Fill only the values of C = A + B, assuming C's sparsity pattern (p and i)
* is already filled and matches the union of A and B per row. Does not modify
* C->p, C->i, or C->nnz. */
void sum_csr_matrices_fill_values(const CSR_Matrix *A, const CSR_Matrix *B,
CSR_Matrix *C);
void sum_csr_fill_values(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C);

/* Compute C = diag(d1) * A + diag(d2) * B where A, B, C are CSR matrices */
void sum_scaled_csr_matrices(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C,
Expand Down
20 changes: 20 additions & 0 deletions include/utils/dense_matrix.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#ifndef DENSE_MATRIX_H
#define DENSE_MATRIX_H

#include "matrix.h"

/* Dense matrix (row-major). Embeds the common Matrix struct from matrix.h as
 * its first member. */
typedef struct Dense_Matrix
{
Matrix base; /* common Matrix part, declared in matrix.h */
double *x; /* matrix entries, stored row-major */
double *work; /* scratch buffer, length n (NOTE(review): n is presumably the
                 column count, as in new_dense_matrix(m, n, ...) - confirm) */
} Dense_Matrix;

/* Constructors */
Matrix *new_dense_matrix(int m, int n, const double *data);

/* Transpose helper */
Matrix *dense_matrix_trans(const Dense_Matrix *self);

#endif /* DENSE_MATRIX_H */
26 changes: 26 additions & 0 deletions include/utils/linalg_dense_sparse_matmuls.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#ifndef LINALG_DENSE_SPARSE_H
#define LINALG_DENSE_SPARSE_H

#include "CSC_Matrix.h"
#include "CSR_Matrix.h"
#include "matrix.h"

/* C = (I_p kron A) @ J via the polymorphic Matrix interface.
* A is dense m x n, J is (n*p) x k in CSC, C is (m*p) x k in CSC. */
// TODO: maybe we can replace these with I_kron_X functionality?
CSC_Matrix *I_kron_A_alloc(const Matrix *A, const CSC_Matrix *J, int p);
void I_kron_A_fill_values(const Matrix *A, const CSC_Matrix *J, CSC_Matrix *C);

/* Sparsity and values of C = (Y^T kron I_m) @ J where Y is k x n, J is (m*k) x p,
and C is (m*n) x p. Y is given in column-major dense format. */
CSR_Matrix *YT_kron_I_alloc(int m, int k, int n, const CSC_Matrix *J);
void YT_kron_I_fill_values(int m, int k, int n, const double *Y, const CSC_Matrix *J,
CSR_Matrix *C);

/* Sparsity and values of C = (I_n kron X) @ J where X is m x k (col-major dense),
J is (k*n) x p, and C is (m*n) x p. */
CSR_Matrix *I_kron_X_alloc(int m, int k, int n, const CSC_Matrix *J);
void I_kron_X_fill_values(int m, int k, int n, const double *X, const CSC_Matrix *J,
CSR_Matrix *C);

#endif /* LINALG_DENSE_SPARSE_H */
25 changes: 10 additions & 15 deletions include/utils/linalg_sparse_matmuls.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
#ifndef LINALG_H
#define LINALG_H

/* Forward declarations */
struct CSR_Matrix;
struct CSC_Matrix;
#include "CSC_Matrix.h"
#include "CSR_Matrix.h"

/* Compute sparsity pattern and values for the matrix-matrix multiplication
C = (I_p kron A) @ J where A is m x n, J is (n*p) x k, and C is (m*p) x k,
Expand All @@ -15,32 +14,28 @@ struct CSC_Matrix;
* Mathematically it corresponds to C = [A @ J1; A @ J2; ...; A @ Jp],
where J = [J1; J2; ...; Jp]
*/
struct CSC_Matrix *block_left_multiply_fill_sparsity(const struct CSR_Matrix *A,
const struct CSC_Matrix *J,
int p);
CSC_Matrix *block_left_multiply_fill_sparsity(const CSR_Matrix *A,
const CSC_Matrix *J, int p);

void block_left_multiply_fill_values(const struct CSR_Matrix *A,
const struct CSC_Matrix *J,
struct CSC_Matrix *C);
void block_left_multiply_fill_values(const CSR_Matrix *A, const CSC_Matrix *J,
CSC_Matrix *C);

/* Compute y = kron(I_p, A) @ x where A is m x n and x is an (n*p)-length vector.
The output y is an (m*p)-length vector corresponding to
y = [A @ x1; A @ x2; ...; A @ xp] where x is divided into p blocks of n
elements.
*/
void block_left_multiply_vec(const struct CSR_Matrix *A, const double *x, double *y,
int p);
void block_left_multiply_vec(const CSR_Matrix *A, const double *x, double *y, int p);

/* Fill values of C = A @ B where A is CSR, B is CSC.
* C must have sparsity pattern already computed.
*/
void csr_csc_matmul_fill_values(const struct CSR_Matrix *A,
const struct CSC_Matrix *B, struct CSR_Matrix *C);
void csr_csc_matmul_fill_values(const CSR_Matrix *A, const CSC_Matrix *B,
CSR_Matrix *C);

/* C = A @ B where A is CSR, B is CSC. Result C is CSR.
* Allocates and precomputes sparsity pattern. No workspace required.
*/
struct CSR_Matrix *csr_csc_matmul_alloc(const struct CSR_Matrix *A,
const struct CSC_Matrix *B);
CSR_Matrix *csr_csc_matmul_alloc(const CSR_Matrix *A, const CSC_Matrix *B);

#endif /* LINALG_H */
12 changes: 1 addition & 11 deletions include/utils/matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,21 +41,11 @@ typedef struct Sparse_Matrix
CSR_Matrix *csr;
} Sparse_Matrix;

/* Dense matrix (row-major) */
typedef struct Dense_Matrix
{
Matrix base;
double *x;
double *work; /* scratch buffer, length n */
} Dense_Matrix;

/* Constructors */
Matrix *new_sparse_matrix(const CSR_Matrix *A);
Matrix *new_dense_matrix(int m, int n, const double *data);

/* Transpose helpers */
/* Transpose helper */
Matrix *sparse_matrix_trans(const Sparse_Matrix *self, int *iwork);
Matrix *dense_matrix_trans(const Dense_Matrix *self);

/* Free helper */
static inline void free_matrix(Matrix *m)
Expand Down
10 changes: 10 additions & 0 deletions include/utils/mini_numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,14 @@ void scaled_ones(double *result, int size, double value);
/* Naive implementation of Z = X @ Y, X is m x k, Y is k x n, Z is m x n */
void mat_mat_mult(const double *X, const double *Y, double *Z, int m, int k, int n);

/* Compute v = (Y kron I_m) @ w where Y is k x n (col-major), w has
length m*n, and v has length m*k. Equivalently, reshape w as the
m x n matrix W (col-major) and compute v = vec(W @ Y^T). */
void Y_kron_I_vec(int m, int k, int n, const double *Y, const double *w, double *v);

/* Compute v = (I_n kron X^T) @ w where X is m x k (col-major), w has
length m*n, and v has length k*n. Equivalently, reshape w as the
m x n matrix W (col-major) and compute v = vec(X^T @ W). */
void I_kron_XT_vec(int m, int k, int n, const double *X, const double *w, double *v);

#endif /* MINI_NUMPY_H */
13 changes: 5 additions & 8 deletions src/affine/add.c
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ static void jacobian_init_impl(expr *node)
node->jacobian = new_csr_matrix(node->size, node->n_vars, nnz_max);

/* fill sparsity pattern */
sum_csr_matrices_fill_sparsity(node->left->jacobian, node->right->jacobian,
node->jacobian);
sum_csr_alloc(node->left->jacobian, node->right->jacobian, node->jacobian);
}

static void eval_jacobian(expr *node)
Expand All @@ -56,8 +55,7 @@ static void eval_jacobian(expr *node)
node->right->eval_jacobian(node->right);

/* sum children's jacobians */
sum_csr_matrices_fill_values(node->left->jacobian, node->right->jacobian,
node->jacobian);
sum_csr_fill_values(node->left->jacobian, node->right->jacobian, node->jacobian);
}

static void wsum_hess_init_impl(expr *node)
Expand All @@ -71,8 +69,7 @@ static void wsum_hess_init_impl(expr *node)
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, nnz_max);

/* fill sparsity pattern of hessian */
sum_csr_matrices_fill_sparsity(node->left->wsum_hess, node->right->wsum_hess,
node->wsum_hess);
sum_csr_alloc(node->left->wsum_hess, node->right->wsum_hess, node->wsum_hess);
}

static void eval_wsum_hess(expr *node, const double *w)
Expand All @@ -82,8 +79,8 @@ static void eval_wsum_hess(expr *node, const double *w)
node->right->eval_wsum_hess(node->right, w);

/* sum children's wsum_hess */
sum_csr_matrices_fill_values(node->left->wsum_hess, node->right->wsum_hess,
node->wsum_hess);
sum_csr_fill_values(node->left->wsum_hess, node->right->wsum_hess,
node->wsum_hess);
}

static bool is_affine(const expr *node)
Expand Down
4 changes: 2 additions & 2 deletions src/affine/hstack.c
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ static void wsum_hess_init_impl(expr *node)
{
expr *child = hnode->args[i];
copy_csr_matrix(H, hnode->CSR_work);
sum_csr_matrices_fill_sparsity(hnode->CSR_work, child->wsum_hess, H);
sum_csr_alloc(hnode->CSR_work, child->wsum_hess, H);
}
}

Expand All @@ -140,7 +140,7 @@ static void wsum_hess_eval(expr *node, const double *w)
expr *child = hnode->args[i];
child->eval_wsum_hess(child, w + row_offset);
copy_csr_matrix(H, hnode->CSR_work);
sum_csr_matrices_fill_values(hnode->CSR_work, child->wsum_hess, H);
sum_csr_fill_values(hnode->CSR_work, child->wsum_hess, H);
row_offset += child->size;
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/affine/left_matmul.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
*/
#include "affine.h"
#include "subexpr.h"
#include "utils/matrix.h"
#include "utils/dense_matrix.h"
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
Expand Down
Loading
Loading