Skip to content

Commit

Permalink
Change the way the new fast/precise flags work
Browse files Browse the repository at this point in the history
  • Loading branch information
jart committed May 7, 2024
1 parent e6532f7 commit b749326
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 27 deletions.
2 changes: 0 additions & 2 deletions llama.cpp/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,12 +291,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
}
if (arg == "--fast") {
FLAG_precise = false;
FLAG_precision_specified = true;
return true;
}
if (arg == "--precise") {
FLAG_precise = true;
FLAG_precision_specified = true;
return true;
}
if (arg == "--trap") {
Expand Down
2 changes: 0 additions & 2 deletions llama.cpp/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2525,12 +2525,10 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
else if (arg == "--fast")
{
FLAG_precise = false;
FLAG_precision_specified = true;
}
else if (arg == "--precise")
{
FLAG_precise = true;
FLAG_precision_specified = true;
}
else if (arg == "--trap")
{
Expand Down
3 changes: 1 addition & 2 deletions llamafile/flags.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,4 @@

#include "llamafile.h"

bool FLAG_precise = true;
bool FLAG_precision_specified;
bool FLAG_precise;
1 change: 0 additions & 1 deletion llamafile/llamafile.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ void llamafile_launch_browser(const char *);
extern bool FLAG_trap;
extern bool FLAG_precise;
extern bool FLAG_unsecure;
extern bool FLAG_precision_specified;

#define LLAMAFILE_GPU_ERROR -2
#define LLAMAFILE_GPU_DISABLE -1
Expand Down
16 changes: 10 additions & 6 deletions llamafile/numba.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,27 +23,31 @@ inline float float01(unsigned x) { // (0,1)
return 1.f / 8388608 * ((x >> 9) + .5f);
}

inline float numba(void) { // (-1,1)
return float01(rand32()) * 2 - 1;
// Rescales float01(rand32()) with `x * 2 - 1`. float01() is annotated as
// producing values in (0,1), so the result of this expression lies in (-1,1).
inline float numba(void) { // (-1,1)
return float01(rand32()) * 2.f - 1.f;
}

template <typename T> void randomize(T *A, int n) {
// Fills the first n elements of A with successive numba() samples, each
// converted to T on assignment. Elements are written in ascending index
// order; nothing is written when n <= 0.
template <typename T>
void randomize(T *A, int n) {
    int idx = 0;
    while (idx < n) {
        A[idx] = numba();
        ++idx;
    }
}

template <typename T> void randomize(int m, int n, T *A, int lda) {
// Fills an m-by-n region of A with numba() samples. Element (i, j) is stored
// at A[lda * j + i]; j advances in the outer loop and i in the inner loop,
// so samples are drawn in that order.
template <typename T>
void randomize(int m, int n, T *A, int lda) {
    int col = 0;
    while (col < n) {
        int row = 0;
        while (row < m) {
            const int offset = lda * col + row;
            A[offset] = numba();
            ++row;
        }
        ++col;
    }
}

template <typename T, typename U> void broadcast(T *A, int n, U x) {
// Assigns x to each of the first n elements of A, converting U to T on every
// assignment. Elements at index n and beyond are left untouched; nothing is
// written when n <= 0.
template <typename T, typename U>
void broadcast(T *A, int n, U x) {
    int idx = 0;
    while (idx < n) {
        A[idx] = x;
        ++idx;
    }
}

template <typename T, typename U> void broadcast(int m, int n, T *A, int lda, U x) {
template <typename T, typename U>
void broadcast(int m, int n, T *A, int lda, U x) {
for (int j = 0; j < n; ++j)
for (int i = 0; i < m; ++i)
A[lda * j + i] = x;
Expand Down
18 changes: 15 additions & 3 deletions llamafile/tinyblas_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,9 @@ class tinyBLAS {
D Cv[RN][RM] = {};
D Ce[RN][RM] = {};
for (long l = 0; l < k; l += KN)
#pragma GCC unroll 100
for (int j = 0; j < RN; ++j)
#pragma GCC unroll 100
for (int i = 0; i < RM; ++i)
if (PRECISE)
Cv[j][i] = madder(load<V>(INDEX(A, lda, ii + i, l)), //
Expand All @@ -632,7 +634,9 @@ class tinyBLAS {
Cv[j][i] = madd(load<V>(INDEX(A, lda, ii + i, l)), //
load<V>(INDEX(B, ldb, jj + j, l)), //
Cv[j][i]);
#pragma GCC unroll 100
for (int j = 0; j < RN; ++j)
#pragma GCC unroll 100
for (int i = 0; i < RM; ++i)
store(INDEX(C, ldc, jj + j, ii + i), hsum(Cv[j][i]));
}
Expand Down Expand Up @@ -670,7 +674,7 @@ class tinyBLAS_Q0_ARM {
NOINLINE void mnpack(long m0, long m, long n0, long n) {
long mc, nc, mp, np;

if (!FLAG_precise || (!FLAG_precision_specified && sizeof(TB) == sizeof(block_q4_0))) {
if (!FLAG_precise) {
switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3)) {
case 0x33:
mc = 3;
Expand Down Expand Up @@ -762,7 +766,9 @@ class tinyBLAS_Q0_ARM {
float32x4_t Cv[RN][RM] = {};
float32x4_t Ce[RN][RM] = {};
for (int l = 0; l < k; ++l)
#pragma GCC unroll 100
for (int j = 0; j < RN; ++j)
#pragma GCC unroll 100
for (int i = 0; i < RM; ++i) {
float32x4_t a = vcvtq_f32_s32(vdotq_s32(
vdotq_s32(vdupq_n_s32(0), load_lo(INDEX(A, lda, ii + i, l)),
Expand All @@ -775,7 +781,9 @@ class tinyBLAS_Q0_ARM {
else
Cv[j][i] = vmlaq_n_f32(Cv[j][i], a, b);
}
#pragma GCC unroll 100
for (int j = 0; j < RN; ++j)
#pragma GCC unroll 100
for (int i = 0; i < RM; ++i)
store(INDEX(C, ldc, jj + j, ii + i), hsum(Cv[j][i]));
}
Expand Down Expand Up @@ -829,7 +837,7 @@ class tinyBLAS_Q0_AVX2 {
long mc, nc, mp, np;

#if VECTOR_REGISTERS == 32
if (!FLAG_precise || (!FLAG_precision_specified && sizeof(TB) == sizeof(block_q4_0))) {
if (!FLAG_precise) {
switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3)) {
case 0x33:
mc = 3;
Expand Down Expand Up @@ -901,7 +909,7 @@ class tinyBLAS_Q0_AVX2 {
#endif

#if VECTOR_REGISTERS == 16
if (!FLAG_precise || (!FLAG_precision_specified && sizeof(TB) == sizeof(block_q4_0))) {
if (!FLAG_precise) {
switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 2)) {
case 0x32:
mc = 3;
Expand Down Expand Up @@ -982,7 +990,9 @@ class tinyBLAS_Q0_AVX2 {
__m256 Cv[RN][RM] = {};
__m256 Ce[RN][RM] = {};
for (long l = 0; l < k; ++l)
#pragma GCC unroll 100
for (int j = 0; j < RN; ++j)
#pragma GCC unroll 100
for (int i = 0; i < RM; ++i) {
__m256 a = _mm256_set1_ps(unhalf(INDEX(A, lda, ii + i, l)->d) *
unhalf(INDEX(B, ldb, jj + j, l)->d));
Expand All @@ -995,7 +1005,9 @@ class tinyBLAS_Q0_AVX2 {
else
Cv[j][i] = madd(a, b, Cv[j][i]);
}
#pragma GCC unroll 100
for (int j = 0; j < RN; ++j)
#pragma GCC unroll 100
for (int i = 0; i < RM; ++i)
store(INDEX(C, ldc, jj + j, ii + i), hsum(Cv[j][i]));
}
Expand Down
22 changes: 11 additions & 11 deletions llamafile/tinyblas_mnpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# # tinyBLAS
# MAX_M = 5
# MAX_N = 5
# EDGE_M = 2
# EDGE_N = 2
# OVERHEAD = 1

# tinyBLAS_Q0
MAX_M = 3
MAX_N = 3
# tinyBLAS
MAX_M = 5
MAX_N = 5
EDGE_M = 2
EDGE_N = 2
OVERHEAD = 8
OVERHEAD = 1

# # tinyBLAS_Q0
# MAX_M = 3
# MAX_N = 3
# EDGE_M = 2
# EDGE_N = 2
# OVERHEAD = 8

def doit(VECTOR_REGISTERS, PRECISE):
# choose tile size that exploits all vector registers
Expand Down

0 comments on commit b749326

Please sign in to comment.