diff --git a/llamafile/sgemm_q0q0s_dotprod.cpp b/llamafile/sgemm_q0q0s_dotprod.cpp
index 184b70e6d4..c737f66c68 100644
--- a/llamafile/sgemm_q0q0s_dotprod.cpp
+++ b/llamafile/sgemm_q0q0s_dotprod.cpp
@@ -56,8 +56,7 @@ class GEMMERQ0ARM {
         mp = m0 + (m - m0) / mc * mc;
         np = n0 + (n - n0) / nc * nc;
         mnpack(mp, m, n0, np);
-        mnpack(m0, mp, np, n);
-        mnpack(mp, m, np, n);
+        mnpack(m0, m, np, n);
     }
 
     dontinline void gemm3x3(int m0, int m, int n0, int n) {
diff --git a/llamafile/sgemmer.inc b/llamafile/sgemmer.inc
index aa9fad16f7..762cafaa13 100644
--- a/llamafile/sgemmer.inc
+++ b/llamafile/sgemmer.inc
@@ -15,6 +15,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <algorithm>
+
 #include "llama.cpp/ggml.h"
 
 #include "hsum.h"
@@ -34,36 +36,155 @@ class SGEMMER {
     }
 
   private:
-    dontinline void mnpack(int m0, int m, int n0, int n) {
+    void mnpack(int m0, int m, int n0, int n) {
         int mc, nc, mp, np;
-        if (m - m0 <= 0 || n - n0 <= 0)
-            return;
-        if (VECTOR_REGISTERS >= 32 && m - m0 >= 8 && n - n0 >= 3) {
-            mc = 8;
+        switch ((std::min(m - m0, 5) << 4) | std::min(n - n0, 5)) {
+#if VECTOR_REGISTERS == 32
+        case 0x55:
+            mc = 5;
+            nc = 5;
+            gemm<5, 5>(m0, m, n0, n);
+            break;
+        case 0x45:
+            mc = 4;
+            nc = 5;
+            gemm<4, 5>(m0, m, n0, n);
+            break;
+        case 0x54:
+            mc = 5;
+            nc = 4;
+            gemm<5, 4>(m0, m, n0, n);
+            break;
+        case 0x44:
+            mc = 4;
+            nc = 4;
+            gemm<4, 4>(m0, m, n0, n);
+            break;
+        case 0x53:
+            mc = 5;
             nc = 3;
-            gemm<8, 3>(m0, m, n0, n);
-        } else if (m - m0 >= 4 && n - n0 >= 3) {
+            gemm<5, 3>(m0, m, n0, n);
+            break;
+        case 0x35:
+            mc = 3;
+            nc = 5;
+            gemm<3, 5>(m0, m, n0, n);
+            break;
+        case 0x43:
             mc = 4;
             nc = 3;
             gemm<4, 3>(m0, m, n0, n);
-        } else if (n - n0 >= 4) {
-            mc = 1;
+            break;
+#else
+        case 0x55:
+        case 0x54:
+        case 0x53:
+        case 0x45:
+        case 0x44:
+        case 0x43:
+            mc = 4;
+            nc = 3;
+            gemm<4, 3>(m0, m, n0, n);
+            break;
+        case 0x35:
+#endif
+        case 0x34:
+            mc = 3;
             nc = 4;
-            gemm<1, 4>(m0, m, n0, n);
-        } else if (m - m0 >= 4) {
+            gemm<3, 4>(m0, m, n0, n);
+            break;
+        case 0x52:
+            mc = 5;
+            nc = 2;
+            gemm<5, 2>(m0, m, n0, n);
+            break;
+        case 0x33:
+            mc = 3;
+            nc = 3;
+            gemm<3, 3>(m0, m, n0, n);
+            break;
+        case 0x25:
+            mc = 2;
+            nc = 5;
+            gemm<2, 5>(m0, m, n0, n);
+            break;
+        case 0x42:
+            mc = 4;
+            nc = 2;
+            gemm<4, 2>(m0, m, n0, n);
+            break;
+        case 0x24:
+            mc = 2;
+            nc = 4;
+            gemm<2, 4>(m0, m, n0, n);
+            break;
+        case 0x32:
+            mc = 3;
+            nc = 2;
+            gemm<3, 2>(m0, m, n0, n);
+            break;
+        case 0x23:
+            mc = 2;
+            nc = 3;
+            gemm<2, 3>(m0, m, n0, n);
+            break;
+        case 0x51:
+            mc = 5;
+            nc = 1;
+            gemm<5, 1>(m0, m, n0, n);
+            break;
+        case 0x41:
             mc = 4;
             nc = 1;
             gemm<4, 1>(m0, m, n0, n);
-        } else {
+            break;
+        case 0x22:
+            mc = 2;
+            nc = 2;
+            gemm<2, 2>(m0, m, n0, n);
+            break;
+        case 0x15:
+            mc = 1;
+            nc = 5;
+            gemm<1, 5>(m0, m, n0, n);
+            break;
+        case 0x14:
+            mc = 1;
+            nc = 4;
+            gemm<1, 4>(m0, m, n0, n);
+            break;
+        case 0x31:
+            mc = 3;
+            nc = 1;
+            gemm<3, 1>(m0, m, n0, n);
+            break;
+        case 0x13:
+            mc = 1;
+            nc = 3;
+            gemm<1, 3>(m0, m, n0, n);
+            break;
+        case 0x21:
+            mc = 2;
+            nc = 1;
+            gemm<2, 1>(m0, m, n0, n);
+            break;
+        case 0x12:
+            mc = 1;
+            nc = 2;
+            gemm<1, 2>(m0, m, n0, n);
+            break;
+        case 0x11:
             mc = 1;
             nc = 1;
             gemm<1, 1>(m0, m, n0, n);
+            break;
+        default:
+            return;
         }
         mp = m0 + (m - m0) / mc * mc;
         np = n0 + (n - n0) / nc * nc;
         mnpack(mp, m, n0, np);
-        mnpack(m0, mp, np, n);
-        mnpack(mp, m, np, n);
+        mnpack(m0, m, np, n);
     }
 
     template <int RM, int RN> dontinline void gemm(int m0, int m, int n0, int n) {
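Note: the rewritten mnpack() above packs the two clamped remainders into one
byte, so every case label reads as a kernel shape: 0x43 selects the 4x3
kernel, 0x11 the 1x1 kernel, and any key with a zero nibble (nothing left in
one dimension) falls to default and ends the recursion. A minimal standalone
sketch of that encoding, with hypothetical names not taken from the patch:

    #include <algorithm>
    #include <cassert>

    // Models the dispatch key in sgemmer.inc: remainders are clamped to 5
    // because 5x5 is the largest kernel instantiated there.
    static int kernel_key(int m0, int m, int n0, int n) {
        return (std::min(m - m0, 5) << 4) | std::min(n - n0, 5);
    }

    int main() {
        assert(kernel_key(0, 9, 0, 3) == 0x53); // 9 rows left -> 5-row kernel
        assert(kernel_key(0, 2, 0, 1) == 0x21); // small tail -> 2x1 kernel
        assert(kernel_key(0, 0, 0, 7) == 0x05); // m exhausted -> default: return
    }

Packing both remainders into a single switch key lets the compiler emit one
jump table in place of the old if/else ladder.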
diff --git a/llamafile/sgemmer0.inc b/llamafile/sgemmer0.inc
index 68fecf7afc..2475875e76 100644
--- a/llamafile/sgemmer0.inc
+++ b/llamafile/sgemmer0.inc
@@ -15,6 +15,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <algorithm>
+
 #include "llama.cpp/ggml-impl.h"
 #include "llama.cpp/ggml.h"
 
@@ -35,217 +37,146 @@ class SGEMMER0 {
     }
 
   private:
-    dontinline void mnpack(int m0, int m, int n0, int n) {
-        if (m - m0 <= 0 || n - n0 <= 0)
-            return;
+    void mnpack(int m0, int m, int n0, int n) {
         int mc, nc, mp, np;
-        if (m - m0 >= 4 && n - n0 >= 3) {
+        switch ((std::min(m - m0, 4) << 4) | std::min(n - n0, 4)) {
+#if VECTOR_REGISTERS == 32
+        case 0x44:
+            mc = 4;
+            nc = 4;
+            gemm<4, 4>(m0, m, n0, n);
+            break;
+        case 0x43:
+            mc = 4;
+            nc = 3;
+            gemm<4, 3>(m0, m, n0, n);
+            break;
+        case 0x34:
+            mc = 3;
+            nc = 4;
+            gemm<3, 4>(m0, m, n0, n);
+            break;
+        case 0x33:
+            mc = 3;
+            nc = 3;
+            gemm<3, 3>(m0, m, n0, n);
+            break;
+        case 0x42:
+            mc = 4;
+            nc = 2;
+            gemm<4, 2>(m0, m, n0, n);
+            break;
+        case 0x24:
+            mc = 2;
+            nc = 4;
+            gemm<2, 4>(m0, m, n0, n);
+            break;
+#else
+        case 0x44:
+        case 0x43:
+        case 0x42:
             mc = 4;
+            nc = 2;
+            gemm<4, 2>(m0, m, n0, n);
+            break;
+        case 0x34:
+        case 0x24:
+            mc = 2;
+            nc = 4;
+            gemm<2, 4>(m0, m, n0, n);
+            break;
+        case 0x33:
+#endif
+        case 0x32:
+            mc = 3;
+            nc = 2;
+            gemm<3, 2>(m0, m, n0, n);
+            break;
+        case 0x23:
+            mc = 2;
             nc = 3;
-            gemm4x3(m0, m, n0, n);
-        } else if (m - m0 >= 4 && n - n0 >= 1) {
+            gemm<2, 3>(m0, m, n0, n);
+            break;
+        case 0x41:
             mc = 4;
             nc = 1;
-            gemm4x1(m0, m, n0, n);
-        } else if (m - m0 >= 1 && n - n0 >= 4) {
+            gemm<4, 1>(m0, m, n0, n);
+            break;
+        case 0x22:
+            mc = 2;
+            nc = 2;
+            gemm<2, 2>(m0, m, n0, n);
+            break;
+        case 0x14:
             mc = 1;
             nc = 4;
-            gemm1x4(m0, m, n0, n);
-        } else {
+            gemm<1, 4>(m0, m, n0, n);
+            break;
+        case 0x31:
+            mc = 3;
+            nc = 1;
+            gemm<3, 1>(m0, m, n0, n);
+            break;
+        case 0x13:
+            mc = 1;
+            nc = 3;
+            gemm<1, 3>(m0, m, n0, n);
+            break;
+        case 0x21:
+            mc = 2;
+            nc = 1;
+            gemm<2, 1>(m0, m, n0, n);
+            break;
+        case 0x12:
+            mc = 1;
+            nc = 2;
+            gemm<1, 2>(m0, m, n0, n);
+            break;
+        case 0x11:
             mc = 1;
             nc = 1;
-            gemm1x1(m0, m, n0, n);
+            gemm<1, 1>(m0, m, n0, n);
+            break;
+        default:
+            return;
         }
         mp = m0 + (m - m0) / mc * mc;
         np = n0 + (n - n0) / nc * nc;
         mnpack(mp, m, n0, np);
-        mnpack(m0, mp, np, n);
-        mnpack(mp, m, np, n);
-    }
-
-    dontinline void gemm4x3(int m0, int m, int n0, int n) {
-        BEGIN_KERNEL(4, 3)
-        __m256 c00 = _mm256_setzero_ps();
-        __m256 c10 = _mm256_setzero_ps();
-        __m256 c20 = _mm256_setzero_ps();
-        __m256 c30 = _mm256_setzero_ps();
-        __m256 c01 = _mm256_setzero_ps();
-        __m256 c11 = _mm256_setzero_ps();
-        __m256 c21 = _mm256_setzero_ps();
-        __m256 c31 = _mm256_setzero_ps();
-        __m256 c02 = _mm256_setzero_ps();
-        __m256 c12 = _mm256_setzero_ps();
-        __m256 c22 = _mm256_setzero_ps();
-        __m256 c32 = _mm256_setzero_ps();
-        const TA *Ap0 = A + lda * (i + 0);
-        const TA *Ap1 = A + lda * (i + 1);
-        const TA *Ap2 = A + lda * (i + 2);
-        const TA *Ap3 = A + lda * (i + 3);
-        const TB *Bp0 = B + ldb * (j + 0);
-        const TB *Bp1 = B + ldb * (j + 1);
-        const TB *Bp2 = B + ldb * (j + 2);
-        for (int l = 0; l < k; ++l) {
-            float da0 = unhalf(Ap0[l].d);
-            float da1 = unhalf(Ap1[l].d);
-            float da2 = unhalf(Ap2[l].d);
-            float da3 = unhalf(Ap3[l].d);
-            __m256i e0 = load(Ap0 + l);
-            __m256i e1 = load(Ap1 + l);
-            __m256i e2 = load(Ap2 + l);
-            __m256i e3 = load(Ap3 + l);
-            float db0 = unhalf(Bp0[l].d);
-            __m256 d00 = _mm256_set1_ps(da0 * db0);
-            __m256 d10 = _mm256_set1_ps(da1 * db0);
-            __m256 d20 = _mm256_set1_ps(da2 * db0);
-            __m256 d30 = _mm256_set1_ps(da3 * db0);
-            __m256i f0 = load(Bp0 + l);
-            __m256i u0 = _mm256_sign_epi8(f0, f0);
-            __m256i s00 = _mm256_sign_epi8(e0, f0);
-            __m256i s10 = _mm256_sign_epi8(e1, f0);
-            __m256i s20 = _mm256_sign_epi8(e2, f0);
-            __m256i s30 = _mm256_sign_epi8(e3, f0);
-            c00 = madd(d00, updot(u0, s00), c00);
-            c10 = madd(d10, updot(u0, s10), c10);
-            c20 = madd(d20, updot(u0, s20), c20);
-            c30 = madd(d30, updot(u0, s30), c30);
-            float db1 = unhalf(Bp1[l].d);
-            __m256 d01 = _mm256_set1_ps(da0 * db1);
-            __m256 d11 = _mm256_set1_ps(da1 * db1);
-            __m256 d21 = _mm256_set1_ps(da2 * db1);
-            __m256 d31 = _mm256_set1_ps(da3 * db1);
-            __m256i f1 = load(Bp1 + l);
-            __m256i u1 = _mm256_sign_epi8(f1, f1);
-            __m256i s01 = _mm256_sign_epi8(e0, f1);
-            __m256i s11 = _mm256_sign_epi8(e1, f1);
-            __m256i s21 = _mm256_sign_epi8(e2, f1);
-            __m256i s31 = _mm256_sign_epi8(e3, f1);
-            c01 = madd(d01, updot(u1, s01), c01);
-            c11 = madd(d11, updot(u1, s11), c11);
-            c21 = madd(d21, updot(u1, s21), c21);
-            c31 = madd(d31, updot(u1, s31), c31);
-            float db2 = unhalf(Bp2[l].d);
-            __m256 d02 = _mm256_set1_ps(da0 * db2);
-            __m256 d12 = _mm256_set1_ps(da1 * db2);
-            __m256 d22 = _mm256_set1_ps(da2 * db2);
-            __m256 d32 = _mm256_set1_ps(da3 * db2);
-            __m256i f2 = load(Bp2 + l);
-            __m256i u2 = _mm256_sign_epi8(f2, f2);
-            __m256i s02 = _mm256_sign_epi8(e0, f2);
-            __m256i s12 = _mm256_sign_epi8(e1, f2);
-            __m256i s22 = _mm256_sign_epi8(e2, f2);
-            __m256i s32 = _mm256_sign_epi8(e3, f2);
-            c02 = madd(d02, updot(u2, s02), c02);
-            c12 = madd(d12, updot(u2, s12), c12);
-            c22 = madd(d22, updot(u2, s22), c22);
-            c32 = madd(d32, updot(u2, s32), c32);
-        }
-        C[ldc * (j + 0) + (i + 0)] = hsum(c00);
-        C[ldc * (j + 0) + (i + 1)] = hsum(c10);
-        C[ldc * (j + 0) + (i + 2)] = hsum(c20);
-        C[ldc * (j + 0) + (i + 3)] = hsum(c30);
-        C[ldc * (j + 1) + (i + 0)] = hsum(c01);
-        C[ldc * (j + 1) + (i + 1)] = hsum(c11);
-        C[ldc * (j + 1) + (i + 2)] = hsum(c21);
-        C[ldc * (j + 1) + (i + 3)] = hsum(c31);
-        C[ldc * (j + 2) + (i + 0)] = hsum(c02);
-        C[ldc * (j + 2) + (i + 1)] = hsum(c12);
-        C[ldc * (j + 2) + (i + 2)] = hsum(c22);
-        C[ldc * (j + 2) + (i + 3)] = hsum(c32);
-        END_KERNEL()
-    }
-
-    dontinline void gemm4x1(int m0, int m, int n0, int n) {
-        BEGIN_KERNEL(4, 1)
-        __m256 c0 = _mm256_setzero_ps();
-        __m256 c1 = _mm256_setzero_ps();
-        __m256 c2 = _mm256_setzero_ps();
-        __m256 c3 = _mm256_setzero_ps();
-        const TA *Ap0 = A + lda * (i + 0);
-        const TA *Ap1 = A + lda * (i + 1);
-        const TA *Ap2 = A + lda * (i + 2);
-        const TA *Ap3 = A + lda * (i + 3);
-        const TB *Bp = B + ldb * j;
-        for (int l = 0; l < k; ++l) {
-            float db0 = unhalf(Bp[l].d);
-            __m256i f = load(Bp + l);
-            __m256i u = _mm256_sign_epi8(f, f);
-            __m256 d0 = _mm256_set1_ps(unhalf(Ap0[l].d) * db0);
-            __m256 d1 = _mm256_set1_ps(unhalf(Ap1[l].d) * db0);
-            __m256 d2 = _mm256_set1_ps(unhalf(Ap2[l].d) * db0);
-            __m256 d3 = _mm256_set1_ps(unhalf(Ap3[l].d) * db0);
-            __m256i e0 = load(Ap0 + l);
-            __m256i e1 = load(Ap1 + l);
-            __m256i e2 = load(Ap2 + l);
-            __m256i e3 = load(Ap3 + l);
-            __m256i s0 = _mm256_sign_epi8(e0, f);
-            __m256i s1 = _mm256_sign_epi8(e1, f);
-            __m256i s2 = _mm256_sign_epi8(e2, f);
-            __m256i s3 = _mm256_sign_epi8(e3, f);
-            __m256 g0 = updot(u, s0);
-            __m256 g1 = updot(u, s1);
-            __m256 g2 = updot(u, s2);
-            __m256 g3 = updot(u, s3);
-            c0 = madd(d0, g0, c0);
-            c1 = madd(d1, g1, c1);
-            c2 = madd(d2, g2, c2);
-            c3 = madd(d3, g3, c3);
-        }
-        C[ldc * j + (i + 0)] = hsum(c0);
-        C[ldc * j + (i + 1)] = hsum(c1);
-        C[ldc * j + (i + 2)] = hsum(c2);
-        C[ldc * j + (i + 3)] = hsum(c3);
-        END_KERNEL()
-    }
-
-    dontinline void gemm1x4(int m0, int m, int n0, int n) {
-        BEGIN_KERNEL(1, 4)
-        __m256 c0 = _mm256_setzero_ps();
-        __m256 c1 = _mm256_setzero_ps();
-        __m256 c2 = _mm256_setzero_ps();
-        __m256 c3 = _mm256_setzero_ps();
-        const TB *Bp0 = B + ldb * (j + 0);
-        const TB *Bp1 = B + ldb * (j + 1);
-        const TB *Bp2 = B + ldb * (j + 2);
-        const TB *Bp3 = B + ldb * (j + 3);
-        const TA *Ap = A + lda * i;
-        for (int l = 0; l < k; ++l) {
-            float da0 = unhalf(Ap[l].d);
-            __m256i f = load(Ap + l);
-            __m256i u = _mm256_sign_epi8(f, f);
-            __m256 d0 = _mm256_set1_ps(unhalf(Bp0[l].d) * da0);
-            __m256 d1 = _mm256_set1_ps(unhalf(Bp1[l].d) * da0);
-            __m256 d2 = _mm256_set1_ps(unhalf(Bp2[l].d) * da0);
-            __m256 d3 = _mm256_set1_ps(unhalf(Bp3[l].d) * da0);
-            __m256 g0 = updot(u, _mm256_sign_epi8(load(Bp0 + l), f));
-            __m256 g1 = updot(u, _mm256_sign_epi8(load(Bp1 + l), f));
-            __m256 g2 = updot(u, _mm256_sign_epi8(load(Bp2 + l), f));
-            __m256 g3 = updot(u, _mm256_sign_epi8(load(Bp3 + l), f));
-            c0 = madd(d0, g0, c0);
-            c1 = madd(d1, g1, c1);
-            c2 = madd(d2, g2, c2);
-            c3 = madd(d3, g3, c3);
-        }
-        C[ldc * (j + 0) + i] = hsum(c0);
-        C[ldc * (j + 1) + i] = hsum(c1);
-        C[ldc * (j + 2) + i] = hsum(c2);
-        C[ldc * (j + 3) + i] = hsum(c3);
-        END_KERNEL()
+        mnpack(m0, m, np, n);
     }
 
-    dontinline void gemm1x1(int m0, int m, int n0, int n) {
-        BEGIN_KERNEL(1, 1)
-        __m256 c = _mm256_setzero_ps();
-        const TA *Ap = A + lda * i;
-        const TB *Bp = B + ldb * j;
-        for (int l = 0; l < k; ++l) {
-            __m256 d = _mm256_set1_ps(unhalf(Ap[l].d) * unhalf(Bp[l].d));
-            __m256i e = load(Ap + l);
-            __m256i f = load(Bp + l);
-            __m256 g = updot(_mm256_sign_epi8(e, e), _mm256_sign_epi8(f, e));
-            c = madd(d, g, c);
+    template <int RM, int RN> dontinline void gemm(int m0, int m, int n0, int n) {
+        int ytiles = (m - m0) / RM;
+        int xtiles = (n - n0) / RN;
+        int tiles = xtiles * ytiles;
+        int duty = (tiles + nth - 1) / nth;
+        int start = duty * ith;
+        int end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int job = start; job < end; ++job) {
+            int ii = m0 + job / xtiles * RM;
+            int jj = n0 + job % xtiles * RN;
+            __m256 Cv[RN][RM] = {0};
+            for (int l = 0; l < k; ++l)
+                for (int j = 0; j < RN; ++j)
+                    for (int i = 0; i < RM; ++i)
+                        Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
+                                                       unhalf(B[ldb * (jj + j) + l].d)),
+                                        updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+                                                               load(A + lda * (ii + i) + l)),
+                                              _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
+                                                               load(A + lda * (ii + i) + l))),
+                                        Cv[j][i]);
+            TC Cd[RN][RM];
+            for (int j = 0; j < RN; ++j)
+                for (int i = 0; i < RM; ++i)
+                    Cd[j][i] = hsum(Cv[j][i]);
+            for (int j = 0; j < RN; ++j)
+                for (int i = 0; i < RM; ++i)
+                    C[ldc * (jj + j) + (ii + i)] = Cd[j][i];
         }
-        C[ldc * j + i] = hsum(c);
-        END_KERNEL()
     }
 
     inline __m256i load(const block_q8_0 *b) {
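Note: the templated gemm() above carves the output into RM x RN tiles and
deals each thread a contiguous run of them via the duty/start/end arithmetic.
A self-contained sketch of that partitioning, assuming ith and nth are the
worker index and worker count as elsewhere in these classes (names here are
illustrative, not from the patch):

    #include <cassert>

    // Each thread claims ceil(tiles / nth) consecutive tiles; the last run
    // is clipped, and surplus threads get an empty range.
    static void tile_range(int tiles, int ith, int nth, int *start, int *end) {
        int duty = (tiles + nth - 1) / nth;
        *start = duty * ith;
        *end = *start + duty;
        if (*end > tiles)
            *end = tiles;
    }

    int main() {
        int s, e;
        tile_range(10, 0, 3, &s, &e); // duty = 4
        assert(s == 0 && e == 4);
        tile_range(10, 2, 3, &s, &e); // clipped at the end
        assert(s == 8 && e == 10);
    }

Static striding keeps the split deterministic and needs no synchronization;
the trade-off is that the last thread can carry up to duty fewer tiles than
the others rather than the load being rebalanced dynamically.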
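Note: both the deleted unrolled kernels and the new gemm() lean on the same
vpsignb idiom. updot() is not shown in this diff, but on AVX2 it presumably
reduces an unsigned-times-signed byte multiply (vpmaddubsw), which requires a
non-negative left operand; re-signing both vectors by one of them preserves
every product while turning that operand into an absolute value. A scalar
model of the per-lane identity, illustrative only:

    #include <cassert>

    // Per-lane model of _mm256_sign_epi8: negate, pass through, or zero.
    static int sign8(int x, int s) {
        return s < 0 ? -x : s > 0 ? x : 0;
    }

    int main() {
        // Q8_0 quants should stay within +/-127, which sidesteps the -128
        // wraparound of the real instruction.
        for (int a = -127; a <= 127; ++a)
            for (int b = -127; b <= 127; ++b)
                assert(a * b == sign8(a, a) * sign8(b, a));
        // sign8(a, a) == |a|, i.e. the left operand becomes unsigned.
    }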
diff --git a/llamafile/sgemmer1.inc b/llamafile/sgemmer1.inc
index 7eb37d40fa..f8966d8fb5 100644
--- a/llamafile/sgemmer1.inc
+++ b/llamafile/sgemmer1.inc
@@ -55,8 +55,7 @@ class SGEMMER1 {
         mp = m0 + (m - m0) / mc * mc;
         np = n0 + (n - n0) / nc * nc;
         mnpack(mp, m, n0, np);
-        mnpack(m0, mp, np, n);
-        mnpack(mp, m, np, n);
+        mnpack(m0, m, np, n);
     }
 
     dontinline void gemm4x2(int m0, int m, int n0, int n) {
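Note: every file in this patch gets the same two-line change to the remainder
recursion. The old code finished with three calls (bottom strip, right strip,
bottom-right corner); the new second call spans the full row range m0..m, so
it absorbs the corner. A small self-checking sketch of the coverage argument,
with the kernel choice stubbed to a tile capped at 4x3 rather than the real
switch:

    #include <algorithm>
    #include <cassert>
    #include <vector>

    static void mnpack(int m0, int m, int n0, int n,
                       std::vector<std::vector<int>> &hits) {
        if (m - m0 <= 0 || n - n0 <= 0)
            return; // stands in for the switch's default: return
        int mc = std::min(m - m0, 4); // stub kernel shape
        int nc = std::min(n - n0, 3);
        int mp = m0 + (m - m0) / mc * mc;
        int np = n0 + (n - n0) / nc * nc;
        for (int i = m0; i < mp; ++i) // cells the chosen kernel handles
            for (int j = n0; j < np; ++j)
                ++hits[i][j];
        mnpack(mp, m, n0, np, hits); // bottom remainder strip
        mnpack(m0, m, np, n, hits);  // full-height right strip (the fix)
    }

    int main() {
        int m = 17, n = 13;
        std::vector<std::vector<int>> hits(m, std::vector<int>(n, 0));
        mnpack(0, m, 0, n, hits);
        for (int i = 0; i < m; ++i)
            for (int j = 0; j < n; ++j)
                assert(hits[i][j] == 1); // every cell covered exactly once
    }

Beyond saving a recursive call, the widened second call hands mnpack() a
taller right-hand strip, so it can pick a kernel there at least as large as
the old corner-sized call could.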