This document introduces the intrinsics for RISC-V matrix programming, including the general naming rules for the intrinsics, the matrix data types, and the full set of intrinsics.
-
Data type naming rules: Prefix the basic data type with 'm'. Additionally, use the suffix "x2" to denote a pair of registers, exemplified by "mint8x2_t".
-
Function interface naming rules: For simplicity, the interface is named by the instruction name and the prefix "__riscv_th_", while mm, mv, and mx are used to distinguish the source operands as matrix-matrix, matrix-vector, and matrix-scalar.
Type |
Name |
Description |
matrix int |
mint8_t |
All elements of the matrix are int8_t |
mint16_t |
||
mint32_t |
||
mint64_t |
||
muint8_t |
All elements of the matrix are uint8_t |
|
muint16_t |
||
muint32_t |
||
muint64_t |
||
matrix float |
mfloat16_t |
All elements of the matrix are float16_t |
mfloat32_t |
||
mfloat64_t |
||
matrix int x2 |
mint8x2_t |
The type is denoted by a pair of sequential matrix registers, with the initial register being an even number. |
mint16x2_t |
||
mint32x2_t |
||
mint64x2_t |
||
muint8x2_t |
||
muint16x2_t |
||
muint32x2_t |
||
muint64x2_t |
||
matrix float x2 |
mfloat16x2_t |
|
mfloat32x2_t |
||
mfloat64x2_t |
||
matrix row num |
mrow_t |
The type is actually size_t, which represents the number of matrix rows. Only the lower 8 bits of this type are valid. |
matrix column num |
mcol_t |
The type is actually size_t, which represents the number of matrix columns. Only the lower 16 bits of this type are valid. |
Instructions
mcfgi<m/n/k> rd,uimm7
mcfg<m/n/k> rd,rs1
Intrinsic functions list
mrow_t __riscv_th_msetmrow_m (mrow_t m);
mrow_t __riscv_th_msetmrow_n (mrow_t n);
mcol_t __riscv_th_msetmcol_e8 (mcol_t c);
mcol_t __riscv_th_msetmcol_e16 (mcol_t c);
mcol_t __riscv_th_msetmcol_e32 (mcol_t c);
mcol_t __riscv_th_msetmcol_e64 (mcol_t c);
Note
|
Set m, n, and k in the CSR xmsize and return the values that actually take effect. When the value of any parameter exceeds the maximum allowable setting, only the lower bits are considered valid: 8 bits for m, 8 bits for n, and 16 bits for k.
|
Instructions
csrr rd,xmrstart
csrr rd,xmcsr
csrr rd,xmsize
csrr rd,xmlenb
csrr rd,xrlenb
csrr rd,xmisa
csrw xmrstart,rs1
csrw xmcsr,rs1
csrw xmsize,rs1
Intrinsic functions list
/*
 * Selector passed to __riscv_th_mread_csr() / __riscv_th_mwrite_csr().
 * The enumerators follow the order of the csrr instructions listed above
 * (xmrstart, xmcsr, xmsize, xmlenb, xrlenb, xmisa); that list shows a csrw
 * form only for xmrstart, xmcsr and xmsize.
 */
enum RVM_CSR {
  RVM_XMRSTART = 0, /* xmrstart */
  RVM_XMCSR = 1,    /* xmcsr */
  RVM_XMSIZE = 2,   /* xmsize */
  RVM_XMLENB = 3,   /* xmlenb */
  RVM_XRLENB = 4,   /* xrlenb */
  RVM_XMISA = 5     /* xmisa */
};
unsigned long __riscv_th_mread_csr(enum RVM_CSR csr);
void __riscv_th_mwrite_csr(enum RVM_CSR csr, unsigned long value);
// specialization version
unsigned long __riscv_th_xmlenb();
unsigned long __riscv_th_xrlenb();
unsigned long __riscv_th_xmsize();
Intrinsic functions list
mint8_t __riscv_th_mundefined_i8 ();
mint16_t __riscv_th_mundefined_i16 ();
mint32_t __riscv_th_mundefined_i32 ();
mint64_t __riscv_th_mundefined_i64 ();
muint8_t __riscv_th_mundefined_u8 ();
muint16_t __riscv_th_mundefined_u16 ();
muint32_t __riscv_th_mundefined_u32 ();
muint64_t __riscv_th_mundefined_u64 ();
mfloat16_t __riscv_th_mundefined_f16 ();
mfloat32_t __riscv_th_mundefined_f32 ();
mfloat64_t __riscv_th_mundefined_f64 ();
mint8x2_t __riscv_th_mundefined_i8x2 ();
mint16x2_t __riscv_th_mundefined_i16x2 ();
mint32x2_t __riscv_th_mundefined_i32x2 ();
mint64x2_t __riscv_th_mundefined_i64x2 ();
muint8x2_t __riscv_th_mundefined_u8x2 ();
muint16x2_t __riscv_th_mundefined_u16x2 ();
muint32x2_t __riscv_th_mundefined_u32x2 ();
muint64x2_t __riscv_th_mundefined_u64x2 ();
mfloat16x2_t __riscv_th_mundefined_f16x2 ();
mfloat32x2_t __riscv_th_mundefined_f32x2 ();
mfloat64x2_t __riscv_th_mundefined_f64x2 ();
Intrinsic functions list
mint8_t __riscv_th_mreinterpret_i8 (src);
mint16_t __riscv_th_mreinterpret_i16 (src);
mint32_t __riscv_th_mreinterpret_i32 (src);
mint64_t __riscv_th_mreinterpret_i64 (src);
muint8_t __riscv_th_mreinterpret_u8 (src);
muint16_t __riscv_th_mreinterpret_u16 (src);
muint32_t __riscv_th_mreinterpret_u32 (src);
muint64_t __riscv_th_mreinterpret_u64 (src);
mfloat16_t __riscv_th_mreinterpret_f16 (src);
mfloat32_t __riscv_th_mreinterpret_f32 (src);
mfloat64_t __riscv_th_mreinterpret_f64 (src);
mint8x2_t __riscv_th_mreinterpret_i8x2 (src);
mint16x2_t __riscv_th_mreinterpret_i16x2 (src);
mint32x2_t __riscv_th_mreinterpret_i32x2 (src);
mint64x2_t __riscv_th_mreinterpret_i64x2 (src);
muint8x2_t __riscv_th_mreinterpret_u8x2 (src);
muint16x2_t __riscv_th_mreinterpret_u16x2 (src);
muint32x2_t __riscv_th_mreinterpret_u32x2 (src);
muint64x2_t __riscv_th_mreinterpret_u64x2 (src);
mfloat16x2_t __riscv_th_mreinterpret_f16x2 (src);
mfloat32x2_t __riscv_th_mreinterpret_f32x2 (src);
mfloat64x2_t __riscv_th_mreinterpret_f64x2 (src);
Note
|
The type of SRC can be any matrix type with the same number of registers. |
Instructions
mzero rd
Intrinsic functions list
mint8_t __riscv_th_mzero_i8 ();
mint16_t __riscv_th_mzero_i16 ();
mint32_t __riscv_th_mzero_i32 ();
mint64_t __riscv_th_mzero_i64 ();
muint8_t __riscv_th_mzero_u8 ();
muint16_t __riscv_th_mzero_u16 ();
muint32_t __riscv_th_mzero_u32 ();
muint64_t __riscv_th_mzero_u64 ();
mfloat16_t __riscv_th_mzero_f16 ();
mfloat32_t __riscv_th_mzero_f32 ();
mfloat64_t __riscv_th_mzero_f64 ();
mint8x2_t __riscv_th_mzero_i8x2 ();
mint16x2_t __riscv_th_mzero_i16x2 ();
mint32x2_t __riscv_th_mzero_i32x2 ();
mint64x2_t __riscv_th_mzero_i64x2 ();
muint8x2_t __riscv_th_mzero_u8x2 ();
muint16x2_t __riscv_th_mzero_u16x2 ();
muint32x2_t __riscv_th_mzero_u32x2 ();
muint64x2_t __riscv_th_mzero_u64x2 ();
mfloat16x2_t __riscv_th_mzero_f16x2 ();
mfloat32x2_t __riscv_th_mzero_f32x2 ();
mfloat64x2_t __riscv_th_mzero_f64x2 ();
Note
|
Zero all elements of matrix register. |
Instructions
#matrix load
mld<b/h/w/d> md, rs2, (rs1)
#stream matrix load
msld<b/h/w/d> md, rs2, (rs1)
Intrinsic functions list
//matrix load
mint8_t __riscv_th_mld (const int8_t *base, long stride, mrow_t row, mcol_t col);
muint8_t __riscv_th_mld (const uint8_t *base, long stride, mrow_t row, mcol_t col);
mint16_t __riscv_th_mld (const int16_t *base, long stride, mrow_t row, mcol_t col);
muint16_t __riscv_th_mld (const uint16_t *base, long stride, mrow_t row, mcol_t col);
mint32_t __riscv_th_mld (const int32_t *base, long stride, mrow_t row, mcol_t col);
muint32_t __riscv_th_mld (const uint32_t *base, long stride, mrow_t row, mcol_t col);
mint64_t __riscv_th_mld (const int64_t *base, long stride, mrow_t row, mcol_t col);
muint64_t __riscv_th_mld (const uint64_t *base, long stride, mrow_t row, mcol_t col);
mfloat16_t __riscv_th_mld (const float16_t *base, long stride, mrow_t row, mcol_t col);
mfloat32_t __riscv_th_mld (const float32_t *base, long stride, mrow_t row, mcol_t col);
mfloat64_t __riscv_th_mld (const float64_t *base, long stride, mrow_t row, mcol_t col);
//stream matrix load
mint8_t __riscv_th_msld (const int8_t *base, long stride, mrow_t row, mcol_t col);
muint8_t __riscv_th_msld (const uint8_t *base, long stride, mrow_t row, mcol_t col);
mint16_t __riscv_th_msld (const int16_t *base, long stride, mrow_t row, mcol_t col);
muint16_t __riscv_th_msld (const uint16_t *base, long stride, mrow_t row, mcol_t col);
mint32_t __riscv_th_msld (const int32_t *base, long stride, mrow_t row, mcol_t col);
muint32_t __riscv_th_msld (const uint32_t *base, long stride, mrow_t row, mcol_t col);
mint64_t __riscv_th_msld (const int64_t *base, long stride, mrow_t row, mcol_t col);
muint64_t __riscv_th_msld (const uint64_t *base, long stride, mrow_t row, mcol_t col);
mfloat16_t __riscv_th_msld (const float16_t *base, long stride, mrow_t row, mcol_t col);
mfloat32_t __riscv_th_msld (const float32_t *base, long stride, mrow_t row, mcol_t col);
mfloat64_t __riscv_th_msld (const float64_t *base, long stride, mrow_t row, mcol_t col);
Note
|
Read from memory into a matrix register: the input parameters are the memory base address, the stride, and the row and column counts; the return value is the loaded matrix. |
Instructions
#matrix store
mst<b/h/w/d> ms3, rs2, (rs1)
#stream matrix store
msst<b/h/w/d> ms3, rs2, (rs1)
Intrinsic functions list
//matrix store
void __riscv_th_mst (const int8_t *base, long stride, mint8_t value, mrow_t row, mcol_t col);
void __riscv_th_mst (const uint8_t *base, long stride, muint8_t value, mrow_t row, mcol_t col);
void __riscv_th_mst (const int16_t *base, long stride, mint16_t value, mrow_t row, mcol_t col);
void __riscv_th_mst (const uint16_t *base, long stride, muint16_t value, mrow_t row, mcol_t col);
void __riscv_th_mst (const int32_t *base, long stride, mint32_t value, mrow_t row, mcol_t col);
void __riscv_th_mst (const uint32_t *base, long stride, muint32_t value, mrow_t row, mcol_t col);
void __riscv_th_mst (const int64_t *base, long stride, mint64_t value, mrow_t row, mcol_t col);
void __riscv_th_mst (const uint64_t *base, long stride, muint64_t value, mrow_t row, mcol_t col);
void __riscv_th_mst (const float16_t *base, long stride, mfloat16_t value, mrow_t row, mcol_t col);
void __riscv_th_mst (const float32_t *base, long stride, mfloat32_t value, mrow_t row, mcol_t col);
void __riscv_th_mst (const float64_t *base, long stride, mfloat64_t value, mrow_t row, mcol_t col);
//stream matrix store
void __riscv_th_msst (const int8_t *base, long stride, mint8_t value, mrow_t row, mcol_t col);
void __riscv_th_msst (const uint8_t *base, long stride, muint8_t value, mrow_t row, mcol_t col);
void __riscv_th_msst (const int16_t *base, long stride, mint16_t value, mrow_t row, mcol_t col);
void __riscv_th_msst (const uint16_t *base, long stride, muint16_t value, mrow_t row, mcol_t col);
void __riscv_th_msst (const int32_t *base, long stride, mint32_t value, mrow_t row, mcol_t col);
void __riscv_th_msst (const uint32_t *base, long stride, muint32_t value, mrow_t row, mcol_t col);
void __riscv_th_msst (const int64_t *base, long stride, mint64_t value, mrow_t row, mcol_t col);
void __riscv_th_msst (const uint64_t *base, long stride, muint64_t value, mrow_t row, mcol_t col);
void __riscv_th_msst (const float16_t *base, long stride, mfloat16_t value, mrow_t row, mcol_t col);
void __riscv_th_msst (const float32_t *base, long stride, mfloat32_t value, mrow_t row, mcol_t col);
void __riscv_th_msst (const float64_t *base, long stride, mfloat64_t value, mrow_t row, mcol_t col);
Note
|
Write the matrix register data to memory: the input parameters are the destination base address, the stride, the source matrix value, and the row and column counts. |
Instructions
#matrix-matrix mov
mmov.mm md, ms1
#matrix-vector mov,rs1'/uimm3
mmov.mv.x md, ms1[rs1']
mmov.mv.i md, ms1[uimm3]
#matrix-scalar mov with duplicate
mdup<b/h/w/d>.m.x md, rs2
#matrix-scalar mov
mmov<b/h/w/d>.m.x md, rs2, rs1
mmov<b/h/w/d>.x.m rd, ms2, rs1
Intrinsic functions list
//matrix-vector mov,rs1/uimm3
mint8_t __riscv_th_mmov_mv (mint8_t src, size_t index);
muint8_t __riscv_th_mmov_mv (muint8_t src, size_t index);
mint16_t __riscv_th_mmov_mv (mint16_t src, size_t index);
muint16_t __riscv_th_mmov_mv (muint16_t src, size_t index);
mint32_t __riscv_th_mmov_mv (mint32_t src, size_t index);
muint32_t __riscv_th_mmov_mv (muint32_t src, size_t index);
mint64_t __riscv_th_mmov_mv (mint64_t src, size_t index);
muint64_t __riscv_th_mmov_mv (muint64_t src, size_t index);
mfloat16_t __riscv_th_mmov_mv (mfloat16_t src, size_t index);
mfloat32_t __riscv_th_mmov_mv (mfloat32_t src, size_t index);
mfloat64_t __riscv_th_mmov_mv (mfloat64_t src, size_t index);
// matrix-scalar mov with duplicate
mint8_t __riscv_th_mdup_m_x (int8_t src);
muint8_t __riscv_th_mdup_m_x (uint8_t src);
mint16_t __riscv_th_mdup_m_x (int16_t src);
muint16_t __riscv_th_mdup_m_x (uint16_t src);
mint32_t __riscv_th_mdup_m_x (int32_t src);
muint32_t __riscv_th_mdup_m_x (uint32_t src);
mint64_t __riscv_th_mdup_m_x (int64_t src);
muint64_t __riscv_th_mdup_m_x (uint64_t src);
// matrix-scalar mov
mint8_t __riscv_th_mmov_m_x (mint8_t dest, int8_t src, size_t index);
muint8_t __riscv_th_mmov_m_x (muint8_t dest, uint8_t src, size_t index);
mint16_t __riscv_th_mmov_m_x (mint16_t dest, int16_t src, size_t index);
muint16_t __riscv_th_mmov_m_x (muint16_t dest, uint16_t src, size_t index);
mint32_t __riscv_th_mmov_m_x (mint32_t dest, int32_t src, size_t index);
muint32_t __riscv_th_mmov_m_x (muint32_t dest, uint32_t src, size_t index);
mint64_t __riscv_th_mmov_m_x (mint64_t dest, int64_t src, size_t index);
muint64_t __riscv_th_mmov_m_x (muint64_t dest, uint64_t src, size_t index);
int8_t __riscv_th_mmov_x_m (mint8_t src, size_t index);
uint8_t __riscv_th_mmov_x_m (muint8_t src, size_t index);
int16_t __riscv_th_mmov_x_m (mint16_t src, size_t index);
uint16_t __riscv_th_mmov_x_m (muint16_t src, size_t index);
int32_t __riscv_th_mmov_x_m (mint32_t src, size_t index);
uint32_t __riscv_th_mmov_x_m (muint32_t src, size_t index);
int64_t __riscv_th_mmov_x_m (mint64_t src, size_t index);
uint64_t __riscv_th_mmov_x_m (muint64_t src, size_t index);
Intrinsic functions list
// matrix tuple
mint8x2_t __riscv_th_mset (mint8x2_t src, size_t index, mint8_t value);
mint16x2_t __riscv_th_mset (mint16x2_t src, size_t index, mint16_t value);
mint32x2_t __riscv_th_mset (mint32x2_t src, size_t index, mint32_t value);
mint64x2_t __riscv_th_mset (mint64x2_t src, size_t index, mint64_t value);
muint8x2_t __riscv_th_mset (muint8x2_t src, size_t index, muint8_t value);
muint16x2_t __riscv_th_mset (muint16x2_t src, size_t index, muint16_t value);
muint32x2_t __riscv_th_mset (muint32x2_t src, size_t index, muint32_t value);
muint64x2_t __riscv_th_mset (muint64x2_t src, size_t index, muint64_t value);
mfloat16x2_t __riscv_th_mset (mfloat16x2_t src, size_t index, mfloat16_t value);
mfloat32x2_t __riscv_th_mset (mfloat32x2_t src, size_t index, mfloat32_t value);
mfloat64x2_t __riscv_th_mset (mfloat64x2_t src, size_t index, mfloat64_t value);
mint8_t __riscv_th_mget (mint8x2_t src, size_t index);
mint16_t __riscv_th_mget (mint16x2_t src, size_t index);
mint32_t __riscv_th_mget (mint32x2_t src, size_t index);
mint64_t __riscv_th_mget (mint64x2_t src, size_t index);
muint8_t __riscv_th_mget (muint8x2_t src, size_t index);
muint16_t __riscv_th_mget (muint16x2_t src, size_t index);
muint32_t __riscv_th_mget (muint32x2_t src, size_t index);
muint64_t __riscv_th_mget (muint64x2_t src, size_t index);
mfloat16_t __riscv_th_mget (mfloat16x2_t src, size_t index);
mfloat32_t __riscv_th_mget (mfloat32x2_t src, size_t index);
mfloat64_t __riscv_th_mget (mfloat64x2_t src, size_t index);
Note
|
The INDEX argument must be provided as a constant integer expression. |
Instructions
#matrix-matrix add
madd.<s/d>.mm md, ms2, ms1
#matrix-vector add,rs1/uimm3
madd.<s/d>.mv.x md, ms2, ms1[rs1]
madd.<s/d>.mv.i md, ms2, ms1[uimm3]
#matrix-scalar add
madd.<s/d>.mx md, ms2, rs1
Intrinsic functions list
//matrix-matrix add
mint32_t __riscv_th_madd_mm (mint32_t src1, mint32_t src2, mrow_t row, mcol_t col);
mint64_t __riscv_th_madd_mm (mint64_t src1, mint64_t src2, mrow_t row, mcol_t col);
//matrix-vector add,rs1/uimm3
mint32_t __riscv_th_madd_mv (mint32_t src1, mint32_t src2, size_t index, mrow_t row, mcol_t col);
mint64_t __riscv_th_madd_mv (mint64_t src1, mint64_t src2, size_t index, mrow_t row, mcol_t col);
//matrix-scalar add
mint32_t __riscv_th_madd_mx (mint32_t src1, int32_t src2, mrow_t row, mcol_t col);
mint64_t __riscv_th_madd_mx (mint64_t src1, int64_t src2, mrow_t row, mcol_t col);
Instructions
#matrix-matrix sub
msub.<s/d>.mm md, ms2, ms1
#matrix-vector sub,rs1/uimm3
msub.<s/d>.mv.x md, ms2, ms1[rs1]
msub.<s/d>.mv.i md, ms2, ms1[uimm3]
#matrix-scalar sub
msub.<s/d>.mx md, ms2, rs1
Intrinsic functions list
//matrix-matrix sub
mint32_t __riscv_th_msub_mm (mint32_t src1, mint32_t src2, mrow_t row, mcol_t col);
mint64_t __riscv_th_msub_mm (mint64_t src1, mint64_t src2, mrow_t row, mcol_t col);
//matrix-vector sub,rs1/uimm3
mint32_t __riscv_th_msub_mv (mint32_t src1, mint32_t src2, size_t index, mrow_t row, mcol_t col);
mint64_t __riscv_th_msub_mv (mint64_t src1, mint64_t src2, size_t index, mrow_t row, mcol_t col);
//matrix-scalar sub
mint32_t __riscv_th_msub_mx (mint32_t src1, int32_t src2, mrow_t row, mcol_t col);
mint64_t __riscv_th_msub_mx (mint64_t src1, int64_t src2, mrow_t row, mcol_t col);
Instructions
#matrix-matrix shift
msra.<s/d>.mm md, ms2, ms1
#matrix-vector shift,rs1
msra.<s/d>.mv.x md, ms2, ms1[rs1]
msra.<s/d>.mv.i md, ms2, ms1[uimm3]
#matrix-scalar shift
msra.<s/d>.mx md, ms2, rs1
Intrinsic functions list
//matrix-matrix sra
mint32_t __riscv_th_msra_mm (mint32_t src1, muint32_t src2, mrow_t row, mcol_t col);
mint64_t __riscv_th_msra_mm (mint64_t src1, muint64_t src2, mrow_t row, mcol_t col);
//matrix-vector sra,rs1/uimm3
mint32_t __riscv_th_msra_mv (mint32_t src1, muint32_t src2, size_t index, mrow_t row, mcol_t col);
mint64_t __riscv_th_msra_mv (mint64_t src1, muint64_t src2, size_t index, mrow_t row, mcol_t col);
//matrix-scalar sra
mint32_t __riscv_th_msra_mx (mint32_t src1, uint32_t src2, mrow_t row, mcol_t col);
mint64_t __riscv_th_msra_mx (mint64_t src1, uint64_t src2, mrow_t row, mcol_t col);
Instructions
#matrix-matrix signed clip
mn4clip.<s/d>.mm md, ms2, ms1
#matrix-vector clip,rs1/uimm3
mn4clip.<s/d>.mv.x md, ms2, ms1[rs1]
mn4clip.<s/d>.mv.i md, ms2, ms1[uimm3]
#matrix-scalar clip
mn4clip.<s/d>.mx md, ms2, rs1
#matrix-matrix unsigned clip
mn4clipu.<s/d>.mm md, ms2, ms1
#matrix-vector clip,rs1/uimm3
mn4clipu.<s/d>.mv.x md, ms2, ms1[rs1]
mn4clipu.<s/d>.mv.i md, ms2, ms1[uimm3]
#matrix-scalar clip
mn4clipu.<s/d>.mx md, ms2, rs1
Intrinsic functions list
//matrix-matrix signed clip
mint8_t __riscv_th_mn4clip_mm (mint32_t src1, muint32_t src2, mrow_t row, mcol_t col);
mint8_t __riscv_th_mn4clip_mm (mint64_t src1, muint64_t src2, mrow_t row, mcol_t col);
//matrix-vector clip,rs1/uimm3
mint8_t __riscv_th_mn4clip_mv (mint32_t src1, muint32_t src2, size_t index, mrow_t row, mcol_t col);
mint8_t __riscv_th_mn4clip_mv (mint64_t src1, muint64_t src2, size_t index, mrow_t row, mcol_t col);
//matrix-scalar clip
mint8_t __riscv_th_mn4clip_mx (mint32_t src1, uint32_t src2, mrow_t row, mcol_t col);
mint8_t __riscv_th_mn4clip_mx (mint64_t src1, uint64_t src2, mrow_t row, mcol_t col);
//matrix-matrix unsigned clip
muint8_t __riscv_th_mn4clipu_mm (muint32_t src1, muint32_t src2, mrow_t row, mcol_t col);
muint8_t __riscv_th_mn4clipu_mm (muint64_t src1, muint64_t src2, mrow_t row, mcol_t col);
//matrix-vector clip,rs1/uimm3
muint8_t __riscv_th_mn4clipu_mv (muint32_t src1, muint32_t src2, size_t index, mrow_t row, mcol_t col);
muint8_t __riscv_th_mn4clipu_mv (muint64_t src1, muint64_t src2, size_t index, mrow_t row, mcol_t col);
//matrix-scalar clip
muint8_t __riscv_th_mn4clipu_mx (muint32_t src1, uint32_t src2, mrow_t row, mcol_t col);
muint8_t __riscv_th_mn4clipu_mx (muint64_t src1, uint64_t src2, mrow_t row, mcol_t col);
Instructions
#matrix-matrix mul
mmul.<s/d>.mm md, ms2, ms1
#matrix-vector mul, rs1
mmul.<s/d>.mv.x md, ms2, ms1[rs1]
mmul.<s/d>.mv.i md, ms2, ms1[uimm3]
#matrix-scalar mul
mmul.<s/d>.mx md, ms2, rs1
Intrinsic functions list
//matrix-matrix mul
mint32_t __riscv_th_mmul_mm (mint32_t src1, mint32_t src2, mrow_t row, mcol_t col);
mint64_t __riscv_th_mmul_mm (mint64_t src1, mint64_t src2, mrow_t row, mcol_t col);
//matrix-vector mul,rs1/uimm3
mint32_t __riscv_th_mmul_mv (mint32_t src1, mint32_t src2, size_t index, mrow_t row, mcol_t col);
mint64_t __riscv_th_mmul_mv (mint64_t src1, mint64_t src2, size_t index, mrow_t row, mcol_t col);
//matrix-scalar mul
mint32_t __riscv_th_mmul_mx (mint32_t src1, int32_t src2, mrow_t row, mcol_t col);
mint64_t __riscv_th_mmul_mx (mint64_t src1, int64_t src2, mrow_t row, mcol_t col);
Note: the low half of the 64-bit result is kept.
Instructions
#matrix-matrix mul
mmulh.s.mm md, ms2, ms1
#matrix-vector mul, rs1
mmulh.s.mv.x md, ms2, ms1[rs1]
mmulh.s.mv.i md, ms2, ms1[uimm3]
#matrix-scalar mul
mmulh.s.mx md, ms2, rs1
Intrinsic functions list
//matrix-matrix mulh
mint32_t __riscv_th_mmulh_mm (mint32_t src1, mint32_t src2, mrow_t row, mcol_t col);
//matrix-vector mulh,rs1/uimm3
mint32_t __riscv_th_mmulh_mv (mint32_t src1, mint32_t src2, size_t index, mrow_t row, mcol_t col);
//matrix-scalar mulh
mint32_t __riscv_th_mmulh_mx (mint32_t src1, int32_t src2, mrow_t row, mcol_t col);
Note
|
The high half of the 64-bit result is retained. |
The DEST represents the previous value of the return value, which requires initialization in the absence of an old value to prevent the appearance of unknown data. Furthermore, both the SRC1 and SRC2 serve as multipliers.
Instructions
#matrix-matrix
fmmacc.<h/s/d> md, ms2, ms1
Intrinsic functions list
//matrix-matrix
mfloat16_t __riscv_th_fmmacc (mfloat16_t dest, mfloat16_t src1, mfloat16x2_t src2, mrow_t row1, mrow_t row2, mcol_t col);
mfloat32_t __riscv_th_fmmacc (mfloat32_t dest, mfloat32_t src1, mfloat32_t src2, mrow_t row1, mrow_t row2, mcol_t col);
mfloat64x2_t __riscv_th_fmmacc (mfloat64x2_t dest, mfloat64_t src1, mfloat64_t src2, mrow_t row1, mrow_t row2, mcol_t col);
Instructions
#matrix-matrix
fwmmacc.<h/s> md, ms2, ms1
Intrinsic functions list
//matrix-matrix
mfloat32_t __riscv_th_fwmmacc (mfloat32_t dest, mfloat16_t src1, mfloat16_t src2, mrow_t row1, mrow_t row2, mcol_t col);
mfloat64x2_t __riscv_th_fwmmacc (mfloat64x2_t dest, mfloat32_t src1, mfloat32_t src2, mrow_t row1, mrow_t row2, mcol_t col);
Instructions
#8bit data width
#signed matrix multiply
mmaqa.<b/h> md, ms2, ms1
#unsigned matrix multiply
mmaqau.<b/h> md, ms2, ms1
#unsigned-signed matrix multiply
mmaqaus.<b/h> md, ms2, ms1
#signed-unsigned matrix multiply
mmaqasu.<b/h> md, ms2, ms1
Intrinsic functions list
//signed matrix multiply
mint32_t __riscv_th_mmaqa (mint32_t dest, mint8_t src1, mint8_t src2, mrow_t row1, mrow_t row2, mcol_t col);
mint64x2_t __riscv_th_mmaqa (mint64x2_t dest, mint16_t src1, mint16_t src2, mrow_t row1, mrow_t row2, mcol_t col);
//unsigned matrix multiply
mint32_t __riscv_th_mmaqau (mint32_t dest, muint8_t src1, muint8_t src2, mrow_t row1, mrow_t row2, mcol_t col);
mint64x2_t __riscv_th_mmaqau (mint64x2_t dest, muint16_t src1, muint16_t src2, mrow_t row1, mrow_t row2, mcol_t col);
//unsigned-signed matrix multiply
mint32_t __riscv_th_mmaqaus (mint32_t dest, muint8_t src1, mint8_t src2, mrow_t row1, mrow_t row2, mcol_t col);
mint64x2_t __riscv_th_mmaqaus (mint64x2_t dest, muint16_t src1, mint16_t src2, mrow_t row1, mrow_t row2, mcol_t col);
//signed-unsigned matrix multiply
mint32_t __riscv_th_mmaqasu (mint32_t dest, mint8_t src1, muint8_t src2, mrow_t row1, mrow_t row2, mcol_t col);
mint64x2_t __riscv_th_mmaqasu (mint64x2_t dest, mint16_t src1, muint16_t src2, mrow_t row1, mrow_t row2, mcol_t col);
Instructions
#4bit data width
#signed matrix multiply
pmmaqa.b md, ms2, ms1
#unsigned matrix multiply
pmmaqau.b md, ms2, ms1
#unsigned-signed matrix multiply
pmmaqaus.b md, ms2, ms1
#signed-unsigned matrix multiply
pmmaqasu.b md, ms2, ms1
Intrinsic functions list
//signed matrix multiply
mint32_t __riscv_th_pmmaqa (mint32_t dest, mint8_t src1, mint8_t src2, mrow_t row1, mrow_t row2, mcol_t col);
//unsigned matrix multiply
mint32_t __riscv_th_pmmaqau (mint32_t dest, muint8_t src1, muint8_t src2, mrow_t row1, mrow_t row2, mcol_t col);
//unsigned-signed matrix multiply
mint32_t __riscv_th_pmmaqaus (mint32_t dest, muint8_t src1, mint8_t src2, mrow_t row1, mrow_t row2, mcol_t col);
//signed-unsigned matrix multiply
mint32_t __riscv_th_pmmaqasu (mint32_t dest, mint8_t src1, muint8_t src2, mrow_t row1, mrow_t row2, mcol_t col);
Source:
#include <stdio.h>
#include <thead_matrix.h>
#define N 16
/*
 * Print the matrices MA, MB and ANS side by side, one matrix row per line.
 *
 * fmt   heading printed (followed by ':') before the table
 * ma    first operand matrix
 * mb    second operand matrix
 * ans   result matrix
 * mrow  number of rows to print
 * mcol  number of columns to print
 *
 * Each matrix is first stored into a scratch array of N elements with its
 * rows packed back to back, so element (row, col) lives at row * mcol + col.
 * Nothing is printed if mrow * mcol does not fit in N elements.
 */
void __attribute__((noinline))
print_data(const char *fmt, mint32_t ma, mint32_t mb, mint32_t ans, mrow_t mrow, mcol_t mcol)
{
  int32_t tmp_ma[N];
  int32_t tmp_mb[N];
  int32_t tmp_ans[N];
  /* Byte distance between consecutive rows in the scratch arrays; it must
     agree with the row pitch (mcol) used for indexing below. */
  const long stride = (long)(mcol * sizeof(int32_t));
  size_t row, col;

  /* The stores below would run past the scratch arrays otherwise. */
  if (mrow * mcol > (size_t)N)
  {
    fprintf(stderr, "print_data: %zu x %zu matrix does not fit in %d elements\n",
            (size_t)mrow, (size_t)mcol, N);
    return;
  }

  printf("%s:\n", fmt);
  __riscv_th_mst(tmp_ma, stride, ma, mrow, mcol);
  __riscv_th_mst(tmp_mb, stride, mb, mrow, mcol);
  __riscv_th_mst(tmp_ans, stride, ans, mrow, mcol);

  printf("ma:\t\tmb:\t\tans:\n");
  for (row = 0; row < mrow; row++)
  {
    const size_t base = row * mcol; /* index of the first element of this row */

    for (col = 0; col < mcol; col++)
    {
      printf("%-3d ", tmp_ma[base + col]);
    }
    printf("\t");
    for (col = 0; col < mcol; col++)
    {
      printf("%-3d ", tmp_mb[base + col]);
    }
    printf("\t");
    for (col = 0; col < mcol; col++)
    {
      /* A zero first element selects the plain layout; presumably this
         marks a result matrix that has not been computed yet. */
      if (tmp_ans[0] == 0)
        printf("%-2d ", tmp_ans[base + col]);
      else
        printf("%-2d = %-2d * %-2d ", tmp_ans[base + col], tmp_ma[base + col], tmp_mb[base + col]);
    }
    printf("\n");
  }
}
int main()
{
/* init data */
int32_t x[N] = {16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1};
int32_t y[N] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
int32_t z[N] = {0};
uint8_t msize_m = 2;
uint8_t msize_k = 2;
long stride = 2 * sizeof(int32_t); // sizeof(int32_t) * 2;
/* init matrix value*/
mint32_t ma = __riscv_th_mld(x, stride, msize_m, msize_k);
mint32_t mb = __riscv_th_mld(y, stride, msize_m, msize_k);
mint32_t ans = __riscv_th_mld(z, stride, msize_m, msize_k);
print_data("Initial value of matrix", ma, mb, ans, msize_m, msize_k);
ans = __riscv_th_mmul_mm(ma, mb, msize_m, msize_k);
print_data("Results of multiplication", ma, mb, ans, msize_m, msize_k);
return 0;
}
Compile:
riscv64-unknown-linux-gnu-gcc -static -O2 -march=rv64g_xtheadmatrix matrix-mul.c -o matrix-mul
Result:
$ qemu-riscv64 -cpu c907fdvm ./matrix-mul
Initial value of matrix:
ma: mb: ans:
16 15 1 2 0 0
14 13 3 4 0 0
Results of multiplication:
ma: mb: ans:
16 15 1 2 16 = 16 * 1 30 = 15 * 2
14 13 3 4 42 = 14 * 3 52 = 13 * 4