Introduction

This document introduces the intrinsics for RISC-V matrix programming, covering the general naming rules for intrinsics, the matrix data types, and the full set of intrinsic functions.

Naming Rules

  • Data type naming rules: prefix the basic scalar type with "m". The suffix "x2" denotes a pair of matrix registers, e.g. "mint8x2_t".

  • Function interface naming rules: each interface is named after the instruction with the prefix "__riscv_th_". The suffixes mm, mv, and mx distinguish the source operands as matrix-matrix, matrix-vector, and matrix-scalar respectively.

Data Types

Table 1. Matrix data types

matrix int
  mint8_t, mint16_t, mint32_t, mint64_t: all elements of the matrix are int8_t, int16_t, int32_t, or int64_t respectively.
  muint8_t, muint16_t, muint32_t, muint64_t: all elements of the matrix are uint8_t, uint16_t, uint32_t, or uint64_t respectively.

matrix float
  mfloat16_t, mfloat32_t, mfloat64_t: all elements of the matrix are float16_t, float32_t, or float64_t respectively.

matrix int x2
  mint8x2_t, mint16x2_t, mint32x2_t, mint64x2_t, muint8x2_t, muint16x2_t, muint32x2_t, muint64x2_t: each type occupies a pair of consecutive matrix registers, with the first register index being even.

matrix float x2
  mfloat16x2_t, mfloat32x2_t, mfloat64x2_t: same register-pair layout as the integer x2 types.

matrix row num
  mrow_t: actually size_t; represents the number of matrix rows. Only the lower 8 bits are valid.

matrix column num
  mcol_t: actually size_t; represents the number of matrix columns. Only the lower 16 bits are valid.

Intrinsic interface

Configuration instructions

Instructions

mcfgi<m/n/k>  rd,uimm7
mcfg<m/n/k>    rd,rs1

Intrinsic functions list

mrow_t __riscv_th_msetmrow_m (mrow_t m);
mrow_t __riscv_th_msetmrow_n (mrow_t n);
mcol_t __riscv_th_msetmcol_e8 (mcol_t c);
mcol_t __riscv_th_msetmcol_e16 (mcol_t c);
mcol_t __riscv_th_msetmcol_e32 (mcol_t c);
mcol_t __riscv_th_msetmcol_e64 (mcol_t c);
Note
These functions set the m, n, and k fields of the xmsize CSR and return the values actually applied. When a parameter exceeds the maximum allowable setting, only its lower bits are used: 8 bits for m, 8 bits for n, and 16 bits for k.
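
A minimal usage sketch (the helper name configure_2x2_i32 is illustrative, not part of the API): it programs a 2x2 tile shape for 32-bit elements before any matrix operation is issued, keeping the values that were actually applied.

#include <thead_matrix.h>

void configure_2x2_i32(void)
{
  mrow_t m = __riscv_th_msetmrow_m(2);    // m field of xmsize
  mrow_t n = __riscv_th_msetmrow_n(2);    // n field of xmsize
  mcol_t k = __riscv_th_msetmcol_e32(2);  // k (columns), counted in 32-bit elements
  (void)m; (void)n; (void)k;              // the returned values are the settings actually applied
}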

Read/Write Matrix CSRs

Instructions

csrr    rd,xmrstart
csrr    rd,xmcsr
csrr    rd,xmsize
csrr    rd,xmlenb
csrr    rd,xrlenb
csrr    rd,xmisa

csrw    xmrstart,rs1
csrw    xmcsr,rs1
csrw    xmsize,rs1

Intrinsic functions list

enum RVM_CSR {
  RVM_XMRSTART = 0,
  RVM_XMCSR,
  RVM_XMSIZE,
  RVM_XMLENB,
  RVM_XRLENB,
  RVM_XMISA
};

unsigned long __riscv_th_mread_csr(enum RVM_CSR csr);
void __riscv_th_mwrite_csr(enum RVM_CSR csr, unsigned long value);

// specialization version
unsigned long __riscv_th_xmlenb();
unsigned long __riscv_th_xrlenb();
unsigned long __riscv_th_xmsize();
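
A minimal sketch of querying implementation parameters (the helper max_cols_e32 is illustrative; it assumes xrlenb reports the matrix row width in bytes):

#include <stdint.h>
#include <thead_matrix.h>

static inline unsigned long max_cols_e32(void)
{
  return __riscv_th_xrlenb() / sizeof(int32_t);  // row bytes divided by element size
}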

Undefined

Intrinsic functions list

mint8_t __riscv_th_mundefined_i8 ();
mint16_t __riscv_th_mundefined_i16 ();
mint32_t __riscv_th_mundefined_i32 ();
mint64_t __riscv_th_mundefined_i64 ();
muint8_t __riscv_th_mundefined_u8 ();
muint16_t __riscv_th_mundefined_u16 ();
muint32_t __riscv_th_mundefined_u32 ();
muint64_t __riscv_th_mundefined_u64 ();
mfloat16_t __riscv_th_mundefined_f16 ();
mfloat32_t __riscv_th_mundefined_f32 ();
mfloat64_t __riscv_th_mundefined_f64 ();

mint8x2_t __riscv_th_mundefined_i8x2 ();
mint16x2_t __riscv_th_mundefined_i16x2 ();
mint32x2_t __riscv_th_mundefined_i32x2 ();
mint64x2_t __riscv_th_mundefined_i64x2 ();
muint8x2_t __riscv_th_mundefined_u8x2 ();
muint16x2_t __riscv_th_mundefined_u16x2 ();
muint32x2_t __riscv_th_mundefined_u32x2 ();
muint64x2_t __riscv_th_mundefined_u64x2 ();
mfloat16x2_t __riscv_th_mundefined_f16x2 ();
mfloat32x2_t __riscv_th_mundefined_f32x2 ();
mfloat64x2_t __riscv_th_mundefined_f64x2 ();

Reinterpret Cast Conversion Functions

Intrinsic functions list

mint8_t __riscv_th_mreinterpret_i8 (src);
mint16_t __riscv_th_mreinterpret_i16 (src);
mint32_t __riscv_th_mreinterpret_i32 (src);
mint64_t __riscv_th_mreinterpret_i64 (src);
muint8_t __riscv_th_mreinterpret_u8 (src);
muint16_t __riscv_th_mreinterpret_u16 (src);
muint32_t __riscv_th_mreinterpret_u32 (src);
muint64_t __riscv_th_mreinterpret_u64 (src);
mfloat16_t __riscv_th_mreinterpret_f16 (src);
mfloat32_t __riscv_th_mreinterpret_f32 (src);
mfloat64_t __riscv_th_mreinterpret_f64 (src);

mint8x2_t __riscv_th_mreinterpret_i8x2 (src);
mint16x2_t __riscv_th_mreinterpret_i16x2 (src);
mint32x2_t __riscv_th_mreinterpret_i32x2 (src);
mint64x2_t __riscv_th_mreinterpret_i64x2 (src);
muint8x2_t __riscv_th_mreinterpret_u8x2 (src);
muint16x2_t __riscv_th_mreinterpret_u16x2 (src);
muint32x2_t __riscv_th_mreinterpret_u32x2 (src);
muint64x2_t __riscv_th_mreinterpret_u64x2 (src);
mfloat16x2_t __riscv_th_mreinterpret_f16x2 (src);
mfloat32x2_t __riscv_th_mreinterpret_f32x2 (src);
mfloat64x2_t __riscv_th_mreinterpret_f64x2 (src);
Note
The type of SRC can be any matrix type with the same number of registers.
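
A minimal sketch (the helper as_unsigned is illustrative): reinterpretation changes only the element type the register is viewed as; the underlying bits are untouched.

#include <thead_matrix.h>

muint8_t as_unsigned(mint8_t s)
{
  return __riscv_th_mreinterpret_u8(s);  // same registers, viewed as unsigned 8-bit elements
}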

Mzero

Instructions

mzero rd

Intrinsic functions list

mint8_t __riscv_th_mzero_i8 ();
mint16_t __riscv_th_mzero_i16 ();
mint32_t __riscv_th_mzero_i32 ();
mint64_t __riscv_th_mzero_i64 ();
muint8_t __riscv_th_mzero_u8 ();
muint16_t __riscv_th_mzero_u16 ();
muint32_t __riscv_th_mzero_u32 ();
muint64_t __riscv_th_mzero_u64 ();
mfloat16_t __riscv_th_mzero_f16 ();
mfloat32_t __riscv_th_mzero_f32 ();
mfloat64_t __riscv_th_mzero_f64 ();

mint8x2_t __riscv_th_mzero_i8x2 ();
mint16x2_t __riscv_th_mzero_i16x2 ();
mint32x2_t __riscv_th_mzero_i32x2 ();
mint64x2_t __riscv_th_mzero_i64x2 ();
muint8x2_t __riscv_th_mzero_u8x2 ();
muint16x2_t __riscv_th_mzero_u16x2 ();
muint32x2_t __riscv_th_mzero_u32x2 ();
muint64x2_t __riscv_th_mzero_u64x2 ();
mfloat16x2_t __riscv_th_mzero_f16x2 ();
mfloat32x2_t __riscv_th_mzero_f32x2 ();
mfloat64x2_t __riscv_th_mzero_f64x2 ();
Note
Zero all elements of the destination matrix register (or register pair for the x2 types).

Load and store instructions

Load

Instructions

#matrix load
mld<b/h/w/d> md, rs2, (rs1)

#stream matrix load
msld<b/h/w/d>  md, rs2, (rs1)

Intrinsic functions list

//matrix load
mint8_t __riscv_th_mld (const int8_t *base, long stride, mrow_t row, mcol_t col);
muint8_t __riscv_th_mld (const uint8_t *base, long stride, mrow_t row, mcol_t col);
mint16_t __riscv_th_mld (const int16_t *base, long stride, mrow_t row, mcol_t col);
muint16_t __riscv_th_mld (const uint16_t *base, long stride, mrow_t row, mcol_t col);
mint32_t __riscv_th_mld (const int32_t *base, long stride, mrow_t row, mcol_t col);
muint32_t __riscv_th_mld (const uint32_t *base, long stride, mrow_t row, mcol_t col);
mint64_t __riscv_th_mld (const int64_t *base, long stride, mrow_t row, mcol_t col);
muint64_t __riscv_th_mld (const uint64_t *base, long stride, mrow_t row, mcol_t col);
mfloat16_t __riscv_th_mld (const float16_t *base, long stride, mrow_t row, mcol_t col);
mfloat32_t __riscv_th_mld (const float32_t *base, long stride, mrow_t row, mcol_t col);
mfloat64_t __riscv_th_mld (const float64_t *base, long stride, mrow_t row, mcol_t col);

//stream matrix load
mint8_t __riscv_th_msld (const int8_t *base, long stride, mrow_t row, mcol_t col);
muint8_t __riscv_th_msld (const uint8_t *base, long stride, mrow_t row, mcol_t col);
mint16_t __riscv_th_msld (const int16_t *base, long stride, mrow_t row, mcol_t col);
muint16_t __riscv_th_msld (const uint16_t *base, long stride, mrow_t row, mcol_t col);
mint32_t __riscv_th_msld (const int32_t *base, long stride, mrow_t row, mcol_t col);
muint32_t __riscv_th_msld (const uint32_t *base, long stride, mrow_t row, mcol_t col);
mint64_t __riscv_th_msld (const int64_t *base, long stride, mrow_t row, mcol_t col);
muint64_t __riscv_th_msld (const uint64_t *base, long stride, mrow_t row, mcol_t col);
mfloat16_t __riscv_th_msld (const float16_t *base, long stride, mrow_t row, mcol_t col);
mfloat32_t __riscv_th_msld (const float32_t *base, long stride, mrow_t row, mcol_t col);
mfloat64_t __riscv_th_msld (const float64_t *base, long stride, mrow_t row, mcol_t col);
Note
Reads data from memory into a matrix register. The input parameters are the memory base address, the row stride in bytes, and the row/column counts; the return value is the loaded matrix.
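
A minimal sketch of a strided tile load (the helper load_tile_i32 and its leading-dimension parameter ld are illustrative): rows in memory are ld elements apart, and the stride passed to the intrinsic is expressed in bytes, as in the example at the end of this document.

#include <stddef.h>
#include <stdint.h>
#include <thead_matrix.h>

mint32_t load_tile_i32(const int32_t *base, size_t ld, mrow_t row, mcol_t col)
{
  long stride = (long)(ld * sizeof(int32_t));  // byte distance between consecutive rows
  return __riscv_th_mld(base, stride, row, col);
}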

Store

Instructions

#matrix store
mst<b/h/w/d>  ms3, rs2, (rs1)

#stream matrix store
msst<b/h/w/d>  ms3, rs2, (rs1)

Intrinsic functions list

//matrix store
void __riscv_th_mst (const int8_t *base, long stride, mint8_t value, mrow_t row, mcol_t col);
void __riscv_th_mst (const uint8_t *base, long stride, muint8_t value, mrow_t row, mcol_t col);
void __riscv_th_mst (const int16_t *base, long stride, mint16_t value, mrow_t row, mcol_t col);
void __riscv_th_mst (const uint16_t *base, long stride, muint16_t value, mrow_t row, mcol_t col);
void __riscv_th_mst (const int32_t *base, long stride, mint32_t value, mrow_t row, mcol_t col);
void __riscv_th_mst (const uint32_t *base, long stride, muint32_t value, mrow_t row, mcol_t col);
void __riscv_th_mst (const int64_t *base, long stride, mint64_t value, mrow_t row, mcol_t col);
void __riscv_th_mst (const uint64_t *base, long stride, muint64_t value, mrow_t row, mcol_t col);
void __riscv_th_mst (const float16_t *base, long stride, mfloat16_t value, mrow_t row, mcol_t col);
void __riscv_th_mst (const float32_t *base, long stride, mfloat32_t value, mrow_t row, mcol_t col);
void __riscv_th_mst (const float64_t *base, long stride, mfloat64_t value, mrow_t row, mcol_t col);

//stream matrix store
void __riscv_th_msst (const int8_t *base, long stride, mint8_t value, mrow_t row, mcol_t col);
void __riscv_th_msst (const uint8_t *base, long stride, muint8_t value, mrow_t row, mcol_t col);
void __riscv_th_msst (const int16_t *base, long stride, mint16_t value, mrow_t row, mcol_t col);
void __riscv_th_msst (const uint16_t *base, long stride, muint16_t value, mrow_t row, mcol_t col);
void __riscv_th_msst (const int32_t *base, long stride, mint32_t value, mrow_t row, mcol_t col);
void __riscv_th_msst (const uint32_t *base, long stride, muint32_t value, mrow_t row, mcol_t col);
void __riscv_th_msst (const int64_t *base, long stride, mint64_t value, mrow_t row, mcol_t col);
void __riscv_th_msst (const uint64_t *base, long stride, muint64_t value, mrow_t row, mcol_t col);
void __riscv_th_msst (const float16_t *base, long stride, mfloat16_t value, mrow_t row, mcol_t col);
void __riscv_th_msst (const float32_t *base, long stride, mfloat32_t value, mrow_t row, mcol_t col);
void __riscv_th_msst (const float64_t *base, long stride, mfloat64_t value, mrow_t row, mcol_t col);
Note
Writes the matrix register data to memory. The input parameters are the destination base address, the row stride in bytes, the source matrix operand, and the row/column counts.
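
A minimal sketch of the matching strided tile store (the helper store_tile_i32 and its parameter ld are illustrative):

#include <stddef.h>
#include <stdint.h>
#include <thead_matrix.h>

void store_tile_i32(int32_t *base, size_t ld, mint32_t tile, mrow_t row, mcol_t col)
{
  long stride = (long)(ld * sizeof(int32_t));  // byte distance between consecutive rows
  __riscv_th_mst(base, stride, tile, row, col);
}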

Mov instructions

Instructions

#matrix-matrix mov
mmov.mm md, ms1

#matrix-vector mov,rs1'/uimm3
mmov.mv.x md, ms1[rs1']
mmov.mv.i md, ms1[uimm3]

#matrix-scalar mov with duplicate
mdup<b/h/w/d>.m.x md, rs2

#matrix-scalar mov
mmov<b/h/w/d>.m.x md, rs2, rs1

mmov<b/h/w/d>.x.m rd, ms2, rs1

Intrinsic functions list

//matrix-vector mov,rs1/uimm3
mint8_t __riscv_th_mmov_mv (mint8_t src, size_t index);
muint8_t __riscv_th_mmov_mv (muint8_t src, size_t index);
mint16_t __riscv_th_mmov_mv (mint16_t src, size_t index);
muint16_t __riscv_th_mmov_mv (muint16_t src, size_t index);
mint32_t __riscv_th_mmov_mv (mint32_t src, size_t index);
muint32_t __riscv_th_mmov_mv (muint32_t src, size_t index);
mint64_t __riscv_th_mmov_mv (mint64_t src, size_t index);
muint64_t __riscv_th_mmov_mv (muint64_t src, size_t index);
mfloat16_t __riscv_th_mmov_mv (mfloat16_t src, size_t index);
mfloat32_t __riscv_th_mmov_mv (mfloat32_t src, size_t index);
mfloat64_t __riscv_th_mmov_mv (mfloat64_t src, size_t index);

// matrix-scalar mov with duplicate
mint8_t __riscv_th_mdup_m_x (int8_t src);
muint8_t __riscv_th_mdup_m_x (uint8_t src);
mint16_t __riscv_th_mdup_m_x (int16_t src);
muint16_t __riscv_th_mdup_m_x (uint16_t src);
mint32_t __riscv_th_mdup_m_x (int32_t src);
muint32_t __riscv_th_mdup_m_x (uint32_t src);
mint64_t __riscv_th_mdup_m_x (int64_t src);
muint64_t __riscv_th_mdup_m_x (uint64_t src);

// matrix-scalar mov
mint8_t __riscv_th_mmov_m_x (mint8_t dest, int8_t src, size_t index);
muint8_t __riscv_th_mmov_m_x (muint8_t dest, uint8_t src, size_t index);
mint16_t __riscv_th_mmov_m_x (mint16_t dest, int16_t src, size_t index);
muint16_t __riscv_th_mmov_m_x (muint16_t dest, uint16_t src, size_t index);
mint32_t __riscv_th_mmov_m_x (mint32_t dest, int32_t src, size_t index);
muint32_t __riscv_th_mmov_m_x (muint32_t dest, uint32_t src, size_t index);
mint64_t __riscv_th_mmov_m_x (mint64_t dest, int64_t src, size_t index);
muint64_t __riscv_th_mmov_m_x (muint64_t dest, uint64_t src, size_t index);

int8_t __riscv_th_mmov_x_m (mint8_t src, size_t index);
uint8_t __riscv_th_mmov_x_m (muint8_t src, size_t index);
int16_t __riscv_th_mmov_x_m (mint16_t src, size_t index);
uint16_t __riscv_th_mmov_x_m (muint16_t src, size_t index);
int32_t __riscv_th_mmov_x_m (mint32_t src, size_t index);
uint32_t __riscv_th_mmov_x_m (muint32_t src, size_t index);
int64_t __riscv_th_mmov_x_m (mint64_t src, size_t index);
uint64_t __riscv_th_mmov_x_m (muint64_t src, size_t index);
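
A minimal sketch combining the scalar-move forms (the helper roundtrip is illustrative, and the exact element addressed by index is whatever the mmov instructions above define):

#include <stddef.h>
#include <stdint.h>
#include <thead_matrix.h>

int32_t roundtrip(int32_t fill, int32_t v, size_t index)
{
  mint32_t m = __riscv_th_mdup_m_x(fill);  // broadcast fill into every element
  m = __riscv_th_mmov_m_x(m, v, index);    // insert v at the position selected by index
  return __riscv_th_mmov_x_m(m, index);    // read the same position back
}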

Tuple instructions

Intrinsic functions list

// matrix tuple
mint8x2_t    __riscv_th_mset (mint8x2_t    src, size_t index, mint8_t    value);
mint16x2_t   __riscv_th_mset (mint16x2_t   src, size_t index, mint16_t   value);
mint32x2_t   __riscv_th_mset (mint32x2_t   src, size_t index, mint32_t   value);
mint64x2_t   __riscv_th_mset (mint64x2_t   src, size_t index, mint64_t   value);
muint8x2_t   __riscv_th_mset (muint8x2_t   src, size_t index, muint8_t   value);
muint16x2_t  __riscv_th_mset (muint16x2_t  src, size_t index, muint16_t  value);
muint32x2_t  __riscv_th_mset (muint32x2_t  src, size_t index, muint32_t  value);
muint64x2_t  __riscv_th_mset (muint64x2_t  src, size_t index, muint64_t  value);
mfloat16x2_t __riscv_th_mset (mfloat16x2_t src, size_t index, mfloat16_t value);
mfloat32x2_t __riscv_th_mset (mfloat32x2_t src, size_t index, mfloat32_t value);
mfloat64x2_t __riscv_th_mset (mfloat64x2_t src, size_t index, mfloat64_t value);

mint8_t    __riscv_th_mget (mint8x2_t    src, size_t index);
mint16_t   __riscv_th_mget (mint16x2_t   src, size_t index);
mint32_t   __riscv_th_mget (mint32x2_t   src, size_t index);
mint64_t   __riscv_th_mget (mint64x2_t   src, size_t index);
muint8_t   __riscv_th_mget (muint8x2_t   src, size_t index);
muint16_t  __riscv_th_mget (muint16x2_t  src, size_t index);
muint32_t  __riscv_th_mget (muint32x2_t  src, size_t index);
muint64_t  __riscv_th_mget (muint64x2_t  src, size_t index);
mfloat16_t __riscv_th_mget (mfloat16x2_t src, size_t index);
mfloat32_t __riscv_th_mget (mfloat32x2_t src, size_t index);
mfloat64_t __riscv_th_mget (mfloat64x2_t src, size_t index);
Note
The INDEX argument must be provided as a constant integer expression.
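
A minimal sketch of building and splitting a register pair (the helper second_half is illustrative; the index literals satisfy the constant-expression requirement):

#include <thead_matrix.h>

mint32_t second_half(mint32_t lo, mint32_t hi)
{
  mint32x2_t pair = __riscv_th_mundefined_i32x2();
  pair = __riscv_th_mset(pair, 0, lo);  // register 0 of the pair
  pair = __riscv_th_mset(pair, 1, hi);  // register 1 of the pair
  return __riscv_th_mget(pair, 1);      // extract the second half again
}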

Matrix Integer Operation Instruction

Add

Instructions

#matrix-matrix add
madd.<s/d>.mm md, ms2, ms1

#matrix-vector add,rs1/uimm3
madd.<s/d>.mv.x md, ms2, ms1[rs1]
madd.<s/d>.mv.i md, ms2, ms1[uimm3]

#matrix-scalar add
madd.<s/d>.mx md, ms2, rs1

Intrinsic functions list

//matrix-matrix add
mint32_t __riscv_th_madd_mm (mint32_t src1, mint32_t src2, mrow_t row, mcol_t col);
mint64_t __riscv_th_madd_mm (mint64_t src1, mint64_t src2, mrow_t row, mcol_t col);

//matrix-vector add,rs1/uimm3
mint32_t __riscv_th_madd_mv (mint32_t src1, mint32_t src2, size_t index, mrow_t row, mcol_t col);
mint64_t __riscv_th_madd_mv (mint64_t src1, mint64_t src2, size_t index, mrow_t row, mcol_t col);

//matrix-scalar add
mint32_t __riscv_th_madd_mx (mint32_t src1, int32_t src2, mrow_t row, mcol_t col);
mint64_t __riscv_th_madd_mx (mint64_t src1, int64_t src2, mrow_t row, mcol_t col);
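
A minimal sketch of the matrix-matrix and matrix-scalar forms (helper names are illustrative; row and col bound the region that is operated on):

#include <stdint.h>
#include <thead_matrix.h>

mint32_t add_tiles(mint32_t a, mint32_t b, mrow_t row, mcol_t col)
{
  return __riscv_th_madd_mm(a, b, row, col);   // element-wise a + b
}

mint32_t add_scalar(mint32_t a, int32_t s, mrow_t row, mcol_t col)
{
  return __riscv_th_madd_mx(a, s, row, col);   // a + s in every element
}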

Sub

Instructions

#matrix-matrix sub
msub.<s/d>.mm md, ms2, ms1

#matrix-vector sub,rs1/uimm3
msub.<s/d>.mv.x md, ms2, ms1[rs1]
msub.<s/d>.mv.i md, ms2, ms1[uimm3]

#matrix-scalar sub
msub.<s/d>.mx md, ms2, rs1

Intrinsic functions list

//matrix-matrix sub
mint32_t __riscv_th_msub_mm (mint32_t src1, mint32_t src2, mrow_t row, mcol_t col);
mint64_t __riscv_th_msub_mm (mint64_t src1, mint64_t src2, mrow_t row, mcol_t col);

//matrix-vector sub,rs1/uimm3
mint32_t __riscv_th_msub_mv (mint32_t src1, mint32_t src2, size_t index, mrow_t row, mcol_t col);
mint64_t __riscv_th_msub_mv (mint64_t src1, mint64_t src2, size_t index, mrow_t row, mcol_t col);

//matrix-scalar sub
mint32_t __riscv_th_msub_mx (mint32_t src1, int32_t src2, mrow_t row, mcol_t col);
mint64_t __riscv_th_msub_mx (mint64_t src1, int64_t src2, mrow_t row, mcol_t col);

Shift

Msra

Instructions

#matrix-matrix shift
msra.<s/d>.mm md, ms2, ms1

#matrix-vector shift,rs1/uimm3
msra.<s/d>.mv.x md, ms2, ms1[rs1]
msra.<s/d>.mv.i md, ms2, ms1[uimm3]

#matrix-scalar shift
msra.<s/d>.mx md, ms2, rs1

Intrinsic functions list

//matrix-matrix sra
mint32_t __riscv_th_msra_mm (mint32_t src1, muint32_t src2, mrow_t row, mcol_t col);
mint64_t __riscv_th_msra_mm (mint64_t src1, muint64_t src2, mrow_t row, mcol_t col);

//matrix-vector sra,rs1/uimm3
mint32_t __riscv_th_msra_mv (mint32_t src1, muint32_t src2, size_t index, mrow_t row, mcol_t col);
mint64_t __riscv_th_msra_mv (mint64_t src1, muint64_t src2, size_t index, mrow_t row, mcol_t col);

//matrix-scalar sra
mint32_t __riscv_th_msra_mx (mint32_t src1, uint32_t src2, mrow_t row, mcol_t col);
mint64_t __riscv_th_msra_mx (mint64_t src1, uint64_t src2, mrow_t row, mcol_t col);

Mn4clip/Mn4clipu

Instructions

#matrix-matrix signed clip
mn4clip.<s/d>.mm md, ms2, ms1

#matrix-vector clip,rs1/uimm3
mn4clip.<s/d>.mv.x md, ms2, ms1[rs1]
mn4clip.<s/d>.mv.i md, ms2, ms1[uimm3]

#matrix-scalar clip
mn4clip.<s/d>.mx md, ms2, rs1


#matrix-matrix unsigned clip
mn4clipu.<s/d>.mm md, ms2, ms1

#matrix-vector clip,rs1/uimm3
mn4clipu.<s/d>.mv.x md, ms2, ms1[rs1]
mn4clipu.<s/d>.mv.i md, ms2, ms1[uimm3]

#matrix-scalar clip
mn4clipu.<s/d>.mx md, ms2, rs1

Intrinsic functions list

//matrix-matrix signed clip
mint8_t __riscv_th_mn4clip_mm (mint32_t src1, muint32_t src2, mrow_t row, mcol_t col);
mint8_t __riscv_th_mn4clip_mm (mint64_t src1, muint64_t src2, mrow_t row, mcol_t col);

//matrix-vector clip,rs1/uimm3
mint8_t __riscv_th_mn4clip_mv (mint32_t src1, muint32_t src2, size_t index, mrow_t row, mcol_t col);
mint8_t __riscv_th_mn4clip_mv (mint64_t src1, muint64_t src2, size_t index, mrow_t row, mcol_t col);

//matrix-scalar clip
mint8_t __riscv_th_mn4clip_mx (mint32_t src1, uint32_t src2, mrow_t row, mcol_t col);
mint8_t __riscv_th_mn4clip_mx (mint64_t src1, uint64_t src2, mrow_t row, mcol_t col);

//matrix-matrix unsigned clip
muint8_t __riscv_th_mn4clipu_mm (muint32_t src1, muint32_t src2, mrow_t row, mcol_t col);
muint8_t __riscv_th_mn4clipu_mm (muint64_t src1, muint64_t src2, mrow_t row, mcol_t col);

//matrix-vector clip,rs1/uimm3
muint8_t __riscv_th_mn4clipu_mv (muint32_t src1, muint32_t src2, size_t index, mrow_t row, mcol_t col);
muint8_t __riscv_th_mn4clipu_mv (muint64_t src1, muint64_t src2, size_t index, mrow_t row, mcol_t col);

//matrix-scalar clip
muint8_t __riscv_th_mn4clipu_mx (muint32_t src1, uint32_t src2, mrow_t row, mcol_t col);
muint8_t __riscv_th_mn4clipu_mx (muint64_t src1, uint64_t src2, mrow_t row, mcol_t col);
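
A minimal sketch of a typical requantization step (the helper requantize is illustrative, and treating the scalar operand as a shift amount is an assumption based on the clip semantics): an int32 accumulator tile is narrowed to int8 with the matrix-scalar form.

#include <stdint.h>
#include <thead_matrix.h>

mint8_t requantize(mint32_t acc, uint32_t shift, mrow_t row, mcol_t col)
{
  return __riscv_th_mn4clip_mx(acc, shift, row, col);  // narrow int32 elements to int8
}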

Multiply Instruction

Low-half-reserved multiplication

Instructions

#matrix-matrix mul
mmul.<s/d>.mm md, ms2, ms1

#matrix-vector mul,rs1/uimm3
mmul.<s/d>.mv.x md, ms2, ms1[rs1]
mmul.<s/d>.mv.i md, ms2, ms1[uimm3]

#matrix-scalar mul
mmul.<s/d>.mx md, ms2, rs1

Intrinsic functions list

//matrix-matrix mul
mint32_t __riscv_th_mmul_mm (mint32_t src1, mint32_t src2, mrow_t row, mcol_t col);
mint64_t __riscv_th_mmul_mm (mint64_t src1, mint64_t src2, mrow_t row, mcol_t col);

//matrix-vector mul,rs1/uimm3
mint32_t __riscv_th_mmul_mv (mint32_t src1, mint32_t src2, size_t index, mrow_t row, mcol_t col);
mint64_t __riscv_th_mmul_mv (mint64_t src1, mint64_t src2, size_t index, mrow_t row, mcol_t col);

//matrix-scalar mul
mint32_t __riscv_th_mmul_mx (mint32_t src1, int32_t src2, mrow_t row, mcol_t col);
mint64_t __riscv_th_mmul_mx (mint64_t src1, int64_t src2, mrow_t row, mcol_t col);

Note
Only the low half of the double-width product is kept.
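
A minimal sketch (the helper name scale_tiles is illustrative):

#include <thead_matrix.h>

mint32_t scale_tiles(mint32_t a, mint32_t b, mrow_t row, mcol_t col)
{
  return __riscv_th_mmul_mm(a, b, row, col);  // element-wise product, low half kept
}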

High-half-reserved multiplication

Instructions

#matrix-matrix mul
mmulh.s.mm md, ms2, ms1

#matrix-vector mul,rs1/uimm3
mmulh.s.mv.x md, ms2, ms1[rs1]
mmulh.s.mv.i md, ms2, ms1[uimm3]

#matrix-scalar mul
mmulh.s.mx md, ms2, rs1

Intrinsic functions list

//matrix-matrix mulh
mint32_t __riscv_th_mmulh_mm (mint32_t src1, mint32_t src2, mrow_t row, mcol_t col);

//matrix-vector mulh,rs1/uimm3
mint32_t __riscv_th_mmulh_mv (mint32_t src1, mint32_t src2, size_t index, mrow_t row, mcol_t col);

//matrix-scalar mulh
mint32_t __riscv_th_mmulh_mx (mint32_t src1, int32_t src2, mrow_t row, mcol_t col);
Note
Only the high half of the 64-bit product is kept.

Matrix Multiplication Instruction

DEST holds the previous value of the accumulator (the return value of an earlier call). When no previous value exists, it must be initialized to avoid consuming undefined data. SRC1 and SRC2 are the two matrix multiplicands.

Floating point Matrix Multiplication

Fmmacc

Instructions

#matrix-matrix
fmmacc.<h/s/d> md, ms2, ms1

Intrinsic functions list

//matrix-matrix
mfloat16_t __riscv_th_fmmacc (mfloat16_t dest, mfloat16_t src1, mfloat16x2_t src2, mrow_t row1, mrow_t row2, mcol_t col);
mfloat32_t __riscv_th_fmmacc (mfloat32_t dest, mfloat32_t src1, mfloat32_t src2, mrow_t row1, mrow_t row2, mcol_t col);
mfloat64x2_t __riscv_th_fmmacc (mfloat64x2_t dest, mfloat64_t src1, mfloat64_t src2, mrow_t row1, mrow_t row2, mcol_t col);
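
A minimal sketch of a single FP32 accumulation step (the helper tile_matmul_f32 is illustrative, and the exact shape described by row1, row2, and col follows the fmmacc instruction definition): the accumulator is zeroed first because no previous value exists, as required by the note on DEST above.

#include <thead_matrix.h>

mfloat32_t tile_matmul_f32(mfloat32_t a, mfloat32_t b,
                           mrow_t row1, mrow_t row2, mcol_t col)
{
  mfloat32_t acc = __riscv_th_mzero_f32();              // fresh accumulator
  return __riscv_th_fmmacc(acc, a, b, row1, row2, col); // accumulate a x b into acc
}
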
Fwmmacc

Instructions

#matrix-matrix
fwmmacc.<h/s> md, ms2, ms1

Intrinsic functions list

//matrix-matrix
mfloat32_t __riscv_th_fwmmacc (mfloat32_t dest, mfloat16_t src1, mfloat16_t src2, mrow_t row1, mrow_t row2, mcol_t col);
mfloat64x2_t __riscv_th_fwmmacc (mfloat64x2_t dest, mfloat32_t src1, mfloat32_t src2, mrow_t row1, mrow_t row2, mcol_t col);

Integer 4x Extension Matrix Multiplication

Mmaqa

Instructions

#8bit data width
#signed matrix multiply
mmaqa.<b/h> md, ms2, ms1

#unsigned matrix multiply
mmaqau.<b/h> md, ms2, ms1

#unsigned-signed matrix multiply
mmaqaus.<b/h> md, ms2, ms1

#signed-unsigned matrix multiply
mmaqasu.<b/h> md, ms2, ms1

Intrinsic functions list

//signed matrix multiply
mint32_t __riscv_th_mmaqa (mint32_t dest, mint8_t src1, mint8_t src2, mrow_t row1, mrow_t row2, mcol_t col);
mint64x2_t __riscv_th_mmaqa (mint64x2_t dest, mint16_t src1, mint16_t src2, mrow_t row1, mrow_t row2, mcol_t col);

//unsigned matrix multiply
mint32_t __riscv_th_mmaqau (mint32_t dest, muint8_t src1, muint8_t src2, mrow_t row1, mrow_t row2, mcol_t col);
mint64x2_t __riscv_th_mmaqau (mint64x2_t dest, muint16_t src1, muint16_t src2, mrow_t row1, mrow_t row2, mcol_t col);

//unsigned-signed matrix multiply
mint32_t __riscv_th_mmaqaus (mint32_t dest, muint8_t src1, mint8_t src2, mrow_t row1, mrow_t row2, mcol_t col);
mint64x2_t __riscv_th_mmaqaus (mint64x2_t dest, muint16_t src1, mint16_t src2, mrow_t row1, mrow_t row2, mcol_t col);

//signed-unsigned matrix multiply
mint32_t __riscv_th_mmaqasu (mint32_t dest, mint8_t src1, muint8_t src2, mrow_t row1, mrow_t row2, mcol_t col);
mint64x2_t __riscv_th_mmaqasu (mint64x2_t dest, mint16_t src1, muint16_t src2, mrow_t row1, mrow_t row2, mcol_t col);
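
A minimal sketch of an int8 accumulation step (the helper tile_matmul_i8 is illustrative), again zero-initializing the int32 accumulator before the first use:

#include <thead_matrix.h>

mint32_t tile_matmul_i8(mint8_t a, mint8_t b,
                        mrow_t row1, mrow_t row2, mcol_t col)
{
  mint32_t acc = __riscv_th_mzero_i32();               // fresh int32 accumulator
  return __riscv_th_mmaqa(acc, a, b, row1, row2, col); // signed 8-bit multiply-accumulate
}
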
Pmmaqa

Instructions

#4bit data width
#signed matrix multiply
pmmaqa.b md, ms2, ms1

#unsigned matrix multiply
pmmaqau.b md, ms2, ms1

#unsigned-signed matrix multiply
pmmaqaus.b md, ms2, ms1

#signed-unsigned matrix multiply
pmmaqasu.b md, ms2, ms1

Intrinsic functions list

//signed matrix multiply
mint32_t __riscv_th_pmmaqa (mint32_t dest, mint8_t src1, mint8_t src2, mrow_t row1, mrow_t row2, mcol_t col);

//unsigned matrix multiply
mint32_t __riscv_th_pmmaqau (mint32_t dest, muint8_t src1, muint8_t src2, mrow_t row1, mrow_t row2, mcol_t col);

//unsigned-signed matrix multiply
mint32_t __riscv_th_pmmaqaus (mint32_t dest, muint8_t src1, mint8_t src2, mrow_t row1, mrow_t row2, mcol_t col);

//signed-unsigned matrix multiply
mint32_t __riscv_th_pmmaqasu (mint32_t dest, mint8_t src1, muint8_t src2, mrow_t row1, mrow_t row2, mcol_t col);

Mrelease

Instructions

mrelease

Intrinsic functions list

void __riscv_th_mrelease();

Example

Source:

#include <stdio.h>
#include <thead_matrix.h>
#define N 16

void __attribute__((noinline))
print_data(const char *fmt, mint32_t ma, mint32_t mb, mint32_t ans, mrow_t mrow, mcol_t mcol)
{
  unsigned int row, col;
  int32_t tmp_ma[N];
  int32_t tmp_mb[N];
  int32_t tmp_ans[N];

  printf("%s:\n", fmt);

  __riscv_th_mst(tmp_ma, 8, ma, mrow, mcol);
  __riscv_th_mst(tmp_mb, 8, mb, mrow, mcol);
  __riscv_th_mst(tmp_ans, 8, ans, mrow, mcol);

  printf("ma:\t\tmb:\t\tans:\n");
  for (row = 0; row < mrow; row++)
  {
    for (col = 0; col < mcol; col++)
    {
      printf("%-3d ", tmp_ma[row * mrow + col]);
    }
    printf("\t");
    for (col = 0; col < mcol; col++)
    {
      printf("%-3d ", tmp_mb[row * mrow + col]);
    }
    printf("\t");
    for (col = 0; col < mcol; col++)
    {
      if (tmp_ans[0] == 0)
        printf("%-2d ", tmp_ans[row * mrow + col]);
      else
        printf("%-2d = %-2d * %-2d  ", tmp_ans[row * mrow + col], tmp_ma[row * mrow + col], tmp_mb[row * mrow + col]);
    }
    printf("\n");
  }
}

int main()
{
  /* init data */
  int32_t x[N] = {16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1};
  int32_t y[N] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
  int32_t z[N] = {0};

  uint8_t msize_m = 2;
  uint8_t msize_k = 2;
  long stride = 2 * sizeof(int32_t); // row stride in bytes (two int32 elements per row)

  /* init matrix value*/
  mint32_t ma = __riscv_th_mld(x, stride, msize_m, msize_k);
  mint32_t mb = __riscv_th_mld(y, stride, msize_m, msize_k);
  mint32_t ans = __riscv_th_mld(z, stride, msize_m, msize_k);

  print_data("Initial value of matrix", ma, mb, ans, msize_m, msize_k);

  ans = __riscv_th_mmul_mm(ma, mb, msize_m, msize_k);

  print_data("Results of multiplication", ma, mb, ans, msize_m, msize_k);
  return 0;
}

Compile:

riscv64-unknown-linux-gnu-gcc -static -O2 -march=rv64g_xtheadmatrix matrix-mul.c -o matrix-mul

Result:

$ qemu-riscv64 -cpu c907fdvm ./matrix-mul
Initial value of matrix:
ma:             mb:             ans:
16  15          1   2           0  0
14  13          3   4           0  0
Results of multiplication:
ma:             mb:             ans:
16  15          1   2           16 = 16 * 1   30 = 15 * 2
14  13          3   4           42 = 14 * 3   52 = 13 * 4