Merge branch ihavnoid/tensorcore_opt
alreadydone committed Sep 23, 2019
2 parents 1e6c6eb + 4c7b38e commit 70a8aff
Showing 11 changed files with 131 additions and 121 deletions.
18 changes: 12 additions & 6 deletions README.md
@@ -72,7 +72,13 @@ a (weaker) network trained from human games [here](https://sjeng.org/zero/best_v
If you are on Windows, download an official release from [here](https://github.com/leela-zero/leela-zero/releases) and head to the [Usage](#usage-for-playing-or-analyzing-games)
section of this README.

If you are on Unix or macOS, you have to compile the program yourself. Follow
If you are on macOS, Leela Zero is available through [Homebrew](https://homebrew.sh), the de facto standard
package manager. You can install it with:
```
brew install leela-zero
```

If you are on Unix, you have to compile the program yourself. Follow
the compilation instructions below and then read the [Usage](#usage-for-playing-or-analyzing-games) section.

# Compiling AutoGTP and/or Leela Zero
@@ -312,13 +318,13 @@ This requires a working installation of TensorFlow 1.4 or later:
src/leelaz -w weights.txt
dump_supervised bigsgf.sgf train.out
exit
training/tf/parse.py train.out
training/tf/parse.py 6 128 train.out

This will run and regularly dump Leela Zero weight files to disk, as
well as snapshots of the learning state numbered by the batch number.
If interrupted, training can be resumed with:
This will run and regularly dump Leela Zero weight files (of networks with 6
blocks and 128 filters) to disk, as well as snapshots of the learning state
numbered by the batch number. If interrupted, training can be resumed with:

training/tf/parse.py train.out leelaz-model-batchnumber
training/tf/parse.py 6 128 train.out leelaz-model-batchnumber

# Todo

5 changes: 5 additions & 0 deletions src/GTP.cpp
@@ -414,6 +414,7 @@ const std::string GTP::s_commands[] = {
"fixed_handicap",
"last_move",
"move_history",
"clear_cache",
"place_free_handicap",
"set_free_handicap",
"loadsgf",
@@ -943,6 +944,10 @@ void GTP::execute(GameState & game, const std::string& xinput) {
}
gtp_printf_raw("\n");
return;
} else if (command.find("clear_cache") == 0) {
s_network->nncache_clear();
gtp_printf(id, "");
return;
} else if (command.find("place_free_handicap") == 0) {
std::istringstream cmdstream(command);
std::string tmp;
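The new `clear_cache` handler asks the network to drop its cached evaluations and acknowledges over GTP. The NNCache implementation itself is not part of this diff; the following is only a rough sketch of the kind of state a call like `nncache_clear()` resets, assuming a mutex-guarded lookup table keyed by position hash (the `EvalCache` and `NetEval` names below are illustrative, not leela-zero's types):

```cpp
#include <cstdint>
#include <deque>
#include <mutex>
#include <unordered_map>

// Illustrative cache of network evaluations keyed by position hash.
// Not the actual NNCache: a sketch of what "clear the cache" entails.
struct NetEval { float policy_pass; float winrate; };

class EvalCache {
public:
    void insert(std::uint64_t hash, NetEval result) {
        std::lock_guard<std::mutex> lock(m_mutex);
        m_order.push_back(hash);
        m_cache[hash] = result;
    }
    bool lookup(std::uint64_t hash, NetEval& result) {
        std::lock_guard<std::mutex> lock(m_mutex);
        auto it = m_cache.find(hash);
        if (it == m_cache.end()) return false;
        result = it->second;
        return true;
    }
    // What a GTP "clear_cache" ultimately triggers: forget every entry so
    // subsequent lookups miss and positions are re-evaluated by the net.
    void clear() {
        std::lock_guard<std::mutex> lock(m_mutex);
        m_cache.clear();
        m_order.clear();
    }

private:
    std::mutex m_mutex;
    std::unordered_map<std::uint64_t, NetEval> m_cache;
    std::deque<std::uint64_t> m_order;  // insertion order, for eviction
};
```

On the GTP side, the handler above replies with a standard `=` acknowledgement, so a client can send `clear_cache` between any other commands without special handling.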
2 changes: 1 addition & 1 deletion src/Network.h
@@ -110,10 +110,10 @@ class Network {
size_t get_estimated_size();
size_t get_estimated_cache_size();
void nncache_resize(int max_count);
void nncache_clear();

void clear_stats() { m_forward->clear_stats(); m_nncache.clear_stats(); }
void dump_stats() { m_forward->dump_stats(); m_nncache.dump_stats(); }
void nncache_clear();

//int get_max_size() { return m_forward->m_max_queue_size.load(); }
void set_search(UCTSearch* search) { m_search = m_forward->m_search = search; m_forward->m_network = this; }
21 changes: 15 additions & 6 deletions src/OpenCL.cpp
@@ -920,12 +920,21 @@ OpenCL<net_t>::OpenCL(int gpu, bool silent) {
}

myprintf("Tensor Core support: ");
try {
cl::Program(m_context, sourceCode_tensorcore_test).build(m_cl_args.c_str());
m_tensorcore = true;
myprintf("Yes.\n");
} catch (...) {
myprintf("No.\n");
{
// if this is a nvidia GPU, test-compile a sample inline assembly code with
// tensor wmma instructions. if not, don't bother trying
std::string this_vendor = m_device.getInfo<CL_DEVICE_VENDOR>();
if (boost::icontains(this_vendor, "nvidia")) {
try {
cl::Program(m_context, sourceCode_tensorcore_test).build(m_cl_args.c_str());
m_tensorcore = true;
myprintf("Yes.\n");
} catch (...) {
myprintf("No.\n");
}
} else {
myprintf("No.\n");
}
}
}

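The detection change gates the Tensor Core probe on the device vendor string before attempting to build the `wmma` test kernel, so non-NVIDIA drivers are never asked to compile NVIDIA-specific PTX inline assembly. Below is a minimal, self-contained sketch of that probe-compile pattern, assuming the OpenCL C++ bindings (`cl2.hpp`) with exceptions enabled; the probe source is a placeholder, not the real `sourceCode_tensorcore_test`:

```cpp
#define CL_HPP_ENABLE_EXCEPTIONS
#define CL_HPP_MINIMUM_OPENCL_VERSION 120
#define CL_HPP_TARGET_OPENCL_VERSION 120
#include <CL/cl2.hpp>

#include <boost/algorithm/string/predicate.hpp>
#include <string>

// Placeholder probe kernel; the real test source embeds wmma PTX assembly.
static const std::string kProbeSource = "kernel void probe() {}";

bool has_tensor_core_support(const cl::Device& device,
                             const cl::Context& context,
                             const std::string& build_args) {
    // Only NVIDIA exposes the inline-PTX path; skip the probe elsewhere.
    const auto vendor = device.getInfo<CL_DEVICE_VENDOR>();
    if (!boost::icontains(vendor, "nvidia")) {
        return false;
    }
    try {
        // If the probe builds, the driver accepts the test kernel.
        cl::Program(context, kProbeSource).build(build_args.c_str());
        return true;
    } catch (...) {
        return false;
    }
}
```

In the actual code the probe source is `sourceCode_tensorcore_test`, which contains the wmma inline assembly, so a successful build implies the driver can assemble those instructions.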
23 changes: 16 additions & 7 deletions src/Tuner.cpp
@@ -188,6 +188,15 @@ bool Tuner<net_t>::valid_config_sgemm(Parameters p, bool exhaustive) {
if (p["NDIMC"] < p["NDIMB"]) {
return false;
}
if (p["MWG"] < 32) {
return false;
}
if (p["NWG"] < 32) {
return false;
}
if (p["KWG"] < 32) {
return false;
}
// VWM / VWN has no meaning if we don't do SA / SB.
// Only test VWM / VWN == 2
if (p["SA"] == 0 && p["VWM"] != 2) {
@@ -335,8 +344,8 @@ std::vector<Parameters> Tuner<net_t>::build_valid_params() {
if (cfg_sgemm_exhaustive) {
topts = {
{"MWG", {32, 64, 128, 256}},
{"NWG", {8, 16, 32, 64}},
{"KWG", {16, 32, 64}},
{"NWG", {8, 16, 32, 64, 128, 256}},
{"KWG", {16, 32, 64, 128, 256}},
{"MDIMC", {8, 16, 32, 64}},
{"NDIMC", {8, 16, 32, 64}},
{"MDIMA", {8, 16, 32}},
@@ -352,8 +361,8 @@ std::vector<Parameters> Tuner<net_t>::build_valid_params() {
} else {
topts = {
{"MWG", {32, 64, 128}},
{"NWG", {8, 16, 32}},
{"KWG", {16, 32}},
{"NWG", {16, 32, 64, 128}},
{"KWG", {16, 32, 64, 128}},
{"MDIMC", {8, 16, 32}},
{"NDIMC", {8, 16, 32}},
{"MDIMA", {8, 16, 32}},
@@ -401,9 +410,9 @@ template <typename net_t>
std::string Tuner<net_t>::tune_sgemm(const int m, const int n, const int k,
const int batch_size, const int runs) {
// This needs to be at minimum the maximum (MNK/WG) values above.
auto m_max = std::max(64, m);
auto n_max = std::max(64, n);
auto k_max = std::max(32, k);
auto m_max = std::max(256, m);
auto n_max = std::max(256, n);
auto k_max = std::max(256, k);

auto at_size = batch_size
* next_power_of_two(k_max) * next_power_of_two(m_max);
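The tuner builds its candidate list by taking the Cartesian product of the per-parameter option lists above and discarding combinations that fail `valid_config_sgemm`; the added checks additionally reject any configuration whose MWG, NWG, or KWG tile size is below 32. As a hedged sketch of that enumerate-and-filter pattern, using a simplified stand-in for the `Parameters` type and only the constraints visible in this hunk:

```cpp
#include <map>
#include <string>
#include <utility>
#include <vector>

using Parameters = std::map<std::string, int>;  // simplified stand-in

// Only the constraints shown in this hunk; the real valid_config_sgemm
// checks many more relations. Assumes the option lists cover every key
// referenced here.
bool valid_config(const Parameters& p) {
    if (p.at("NDIMC") < p.at("NDIMB")) return false;
    if (p.at("MWG") < 32 || p.at("NWG") < 32 || p.at("KWG") < 32) return false;
    return true;
}

// Cartesian product of the option lists, filtered for validity.
std::vector<Parameters> build_valid_params(
        const std::vector<std::pair<std::string, std::vector<int>>>& opts) {
    std::vector<Parameters> out{Parameters{}};
    for (const auto& [name, values] : opts) {
        std::vector<Parameters> next;
        for (const auto& partial : out) {
            for (int v : values) {
                auto candidate = partial;
                candidate[name] = v;
                next.push_back(std::move(candidate));
            }
        }
        out = std::move(next);
    }
    std::vector<Parameters> valid;
    for (const auto& p : out) {
        if (valid_config(p)) valid.push_back(p);
    }
    return valid;
}
```

Separately, because the search space now goes up to 256 for the workgroup tile sizes, `tune_sgemm` pads its test matrices to at least 256 in each dimension, matching the comment that the test size must cover the maximum M/N/K workgroup values.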
5 changes: 3 additions & 2 deletions src/UCTNodePointer.cpp
@@ -113,8 +113,9 @@ void UCTNodePointer::inflate() const {
if (is_inflated(v)) return;

auto v2 = reinterpret_cast<std::uint64_t>(
new UCTNode(read_vertex(v), read_policy(v))
) | POINTER;
new UCTNode(read_vertex(v), read_policy(v)));
assert((v2 & 3ULL) == 0);
v2 |= POINTER;
bool success = m_data.compare_exchange_strong(v, v2);
if (success) {
increment_tree_size(sizeof(UCTNode));
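The refactor makes the pointer-tagging assumption explicit: `UCTNodePointer` stores a flag in the low bits of the node address, which is only safe if `new` returns at least 4-byte-aligned storage, so the change asserts that the low two bits are zero before OR-ing in `POINTER`. A minimal sketch of the technique with illustrative names (`Node`, `POINTER_TAG`, and `TAG_MASK` are assumptions, not the leela-zero definitions):

```cpp
#include <cassert>
#include <cstdint>

// Sketch of low-bit pointer tagging: stash a 2-bit tag in the low bits of a
// 64-bit pointer value, relying on allocations being at least 4-byte aligned.
struct Node { int vertex; float policy; };

constexpr std::uint64_t POINTER_TAG = 1ULL;  // hypothetical tag value
constexpr std::uint64_t TAG_MASK    = 3ULL;  // low two bits reserved for tags

std::uint64_t make_tagged(Node* node) {
    auto raw = reinterpret_cast<std::uint64_t>(node);
    // If the allocator ever returned a pointer with the low bits set, the
    // tag would corrupt the address; assert instead of silently OR-ing.
    assert((raw & TAG_MASK) == 0);
    return raw | POINTER_TAG;
}

Node* untag(std::uint64_t v) {
    return reinterpret_cast<Node*>(v & ~TAG_MASK);
}

int main() {
    auto* n = new Node{0, 0.5f};
    auto tagged = make_tagged(n);
    assert((tagged & TAG_MASK) == POINTER_TAG);
    delete untag(tagged);
    return 0;
}
```

Asserting before tagging turns a silent address corruption into an immediate failure in debug builds, which is the point of splitting the old one-liner.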
118 changes: 45 additions & 73 deletions src/kernels/clblast/hgemm_tensorcore.opencl
@@ -24,8 +24,6 @@
// literal). Comment-out this line for syntax-highlighting when developing.

R"(
#define USE_TC

#ifndef SA
#define SA 1
#endif
@@ -153,7 +151,6 @@ void HgemmBody(const int kSizeM, const int kSizeN, const int kSizeK,
}

int k, m, n, mb, nb, kb, kwg;
#ifdef USE_TC
int zero_pair;
asm("{\n"
".reg .b16 xh;\n"
@@ -182,18 +179,6 @@ void HgemmBody(const int kSizeM, const int kSizeN, const int kSizeK,
c3[mb][nb] = zero_pair;
}
}
#else
float acc[MWG/MDIMC][NWG/NDIMC][2][4];
for(mb = 0; mb < MWG / MDIMC; mb += 1) {
for(nb = 0; nb < NWG / NDIMC; nb += 1) {
for(m=0; m<2; m++) {
for(int n=0; n<4; n++) {
acc[mb][nb][m][n] = 0.0f;
}
}
}
}
#endif
for(kwg = 0; kwg < kSizeK; kwg += KWG) {
#if SA == 1
GlobalToLocalA(get_local_id(0) + get_local_id(1) * WARP_SIZE * MDIMC / MDIMA, kSizeM,
@@ -216,77 +201,75 @@

#pragma unroll
for(kb = 0; kb < KWG; kb += 16) {
#pragma promote_to_registers
int b[NWG/NDIMC][8];
for(nb = 0; nb < NWG / NDIMC; nb += 1) {
#if SB == 1
const int block_loc_n = (get_local_id(1)) % (NDIMC/NDIMB);
const int bgm_stride = NWG;
const __local half * b_bgm_ = (const __local half *)(blm + (nb + block_loc_n * (NWG/NDIMC)) * NDIMB);
const __local half * bb_bgm_ = b_bgm_ + bgm_stride * kb;
#else
const int bgm_stride = kSizeN;
const __global half * b_bgm_ = bgm_ + nb * NDIMB;
const __global half * bb_bgm_ = b_bgm_ + kSizeN * (kb + kwg);
#endif
asm("{\n"
#if SB == 1
"wmma.load.b.sync.aligned." WMMA_SHAPE ".shared.row.f16 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8], %9;\n"
#else
"wmma.load.b.sync.aligned." WMMA_SHAPE ".row.f16 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8], %9;\n"
#endif
"}": "=r"(b[nb][0]), "=r"(b[nb][1]), "=r"(b[nb][2]), "=r"(b[nb][3]), "=r"(b[nb][4]), "=r"(b[nb][5]), "=r"(b[nb][6]), "=r"(b[nb][7]) : "l"(bb_bgm_), "r"(bgm_stride));
}
#pragma unroll
for(mb = 0; mb < MWG / MDIMC; mb += 1) {
#pragma unroll
for(nb = 0; nb < NWG / NDIMC; nb += 1) {
#pragma promote_to_registers
int a[8];
#if SA == 1
const int block_loc_m = (get_local_id(0)/WARP_SIZE) % (MDIMC/MDIMA);
const int agm_stride = MWG;
const __local half * b_agm_ = (const __local half *)(alm + (mb + block_loc_m * (MWG/MDIMC)) * MDIMA);
const __local half * bb_agm_ = b_agm_ + agm_stride * kb;
const int block_loc_m = (get_local_id(0)/WARP_SIZE) % (MDIMC/MDIMA);
const int agm_stride = MWG;
const __local half * b_agm_ = (const __local half *)(alm + (mb + block_loc_m * (MWG/MDIMC)) * MDIMA);
const __local half * bb_agm_ = b_agm_ + agm_stride * kb;
#else
const int agm_stride = kSizeM;
const __global half * b_agm_ = agm_ + mb * MDIMA;
const __global half * bb_agm_ = b_agm_ + kSizeM * (kb + kwg);
const int agm_stride = kSizeM;
const __global half * b_agm_ = agm_ + mb * MDIMA;
const __global half * bb_agm_ = b_agm_ + kSizeM * (kb + kwg);
#endif

#if SB == 1
const int block_loc_n = (get_local_id(1)) % (NDIMC/NDIMB);
const int bgm_stride = NWG;
const __local half * b_bgm_ = (const __local half *)(blm + (nb + block_loc_n * (NWG/NDIMC)) * NDIMB);
const __local half * bb_bgm_ = b_bgm_ + bgm_stride * kb;
asm("{\n"
#if SA == 1
"wmma.load.a.sync.aligned." WMMA_SHAPE ".shared.col.f16 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8], %9;\n"
#else
const int bgm_stride = kSizeN;
const __global half * b_bgm_ = bgm_ + nb * NDIMB;
const __global half * bb_bgm_ = b_bgm_ + kSizeN * (kb + kwg);
"wmma.load.a.sync.aligned." WMMA_SHAPE ".col.f16 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8], %9;\n"
#endif
#ifdef USE_TC
"}": "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]), "=r"(a[4]), "=r"(a[5]), "=r"(a[6]), "=r"(a[7]) : "l"(bb_agm_), "r"(agm_stride));

#pragma unroll
for(nb = 0; nb < NWG / NDIMC; nb += 1) {
int d0_, d1_, d2_, d3_;
int c0_ = c0[mb][nb];
int c1_ = c1[mb][nb];
int c2_ = c2[mb][nb];
int c3_ = c3[mb][nb];
asm("{\n"
".reg .b32 a0, a1, a2, a3, a4, a5, a6, a7;\n"
".reg .b32 b0, b1, b2, b3, b4, b5, b6, b7;\n"
#if SA == 1
"wmma.load.a.sync.aligned." WMMA_SHAPE ".shared.col.f16 {a0,a1,a2,a3,a4,a5,a6,a7}, [%4], %6;\n"
#else
"wmma.load.a.sync.aligned." WMMA_SHAPE ".col.f16 {a0,a1,a2,a3,a4,a5,a6,a7}, [%4], %6;\n"
#endif
#if SB == 1
"wmma.load.b.sync.aligned." WMMA_SHAPE ".shared.row.f16 {b0,b1,b2,b3,b4,b5,b6,b7}, [%5], %7;\n"
#else
"wmma.load.b.sync.aligned." WMMA_SHAPE ".row.f16 {b0,b1,b2,b3,b4,b5,b6,b7}, [%5], %7;\n"
#endif
"wmma.mma.sync.aligned.col.row." WMMA_SHAPE ".f16.f16 "
" {%0,%1,%2,%3},\n"
" {a0,a1,a2,a3,a4,a5,a6,a7},\n"
" {b0,b1,b2,b3,b4,b5,b6,b7},\n"
" {%8,%9,%10,%11};\n"
"}": "=r"(d0_), "=r"(d1_), "=r"(d2_), "=r"(d3_) : "l"(bb_agm_), "l"(bb_bgm_), "r"(agm_stride), "r"(bgm_stride), "r"(c0_), "r"(c1_), "r"(c2_), "r"(c3_));
" {%8,%9,%10,%11,%12,%13,%14,%15},\n"
" {%16,%17,%18,%19,%20,%21,%22,%23},\n"
" {%4,%5,%6,%7};\n"
"}": "=r"(d0_), "=r"(d1_), "=r"(d2_), "=r"(d3_) : "r"(c0_), "r"(c1_), "r"(c2_), "r"(c3_),
"r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(a[4]), "r"(a[5]), "r"(a[6]), "r"(a[7]),
"r"(b[nb][0]), "r"(b[nb][1]), "r"(b[nb][2]), "r"(b[nb][3]), "r"(b[nb][4]), "r"(b[nb][5]), "r"(b[nb][6]), "r"(b[nb][7])
);
c0[mb][nb] = d0_;
c1[mb][nb] = d1_;
c2[mb][nb] = d2_;
c3[mb][nb] = d3_;
#else
for(m = offset_m; m < MDIMA; m += MDIMA/2) {
for(n = offset_n; n < NDIMB; n += NDIMB/4) {
float a = 0.0f;
for(k = 0; k < 16; k++) {
a += vload_half(agm_stride * k + m, bb_agm_) * vload_half(bgm_stride * k + n, bb_bgm_);
}
acc[mb][nb][m/(MDIMA/2)][n/(NDIMB/4)] += a;
}
}
#endif
}
}
}
}

#ifdef USE_TC
#pragma unroll
for(mb = 0; mb < MWG / MDIMC; mb += 1) {
#pragma unroll
@@ -301,17 +284,6 @@ void HgemmBody(const int kSizeM, const int kSizeN, const int kSizeK,
"}" : : "r"(c0_), "r"(c1_), "r"(c2_), "r"(c3_), "l"(b_cgm_), "r"(kSizeM));
}
}
#else
for(mb = 0; mb < MWG / MDIMC; mb += 1) {
for(nb = 0; nb < NWG / NDIMC; nb += 1) {
for(m = offset_m; m < MDIMA; m += MDIMA/2) {
for(n = offset_n; n < NDIMB; n += NDIMB/4) {
vstore_half(acc[mb][nb][m/(MDIMA/2)][n/(NDIMB/4)], kSizeM * (nb * NDIMB + n) + mb * MDIMA + m, cgm_);
}
}
}
}
#endif
}

struct alm_t {short alm[KWG * MWG];} __attribute__((aligned(32)));
8 changes: 1 addition & 7 deletions training/tf/net_to_model.py
@@ -23,14 +23,8 @@
blocks //= 8
print("Blocks", blocks)

tfprocess = TFProcess()
tfprocess = TFProcess(blocks, channels)
tfprocess.init(batch_size=1, gpus_num=1)
if tfprocess.RESIDUAL_BLOCKS != blocks:
raise ValueError("Number of blocks in tensorflow model doesn't match "\
"number of blocks in input network")
if tfprocess.RESIDUAL_FILTERS != channels:
raise ValueError("Number of filters in tensorflow model doesn't match "\
"number of filters in input network")
tfprocess.replace_weights(weights)
path = os.path.join(os.getcwd(), "leelaz-model")
save_path = tfprocess.saver.save(tfprocess.session, path, global_step=0)
20 changes: 17 additions & 3 deletions training/tf/parse.py
@@ -107,24 +107,38 @@ def split_chunks(chunks, test_ratio):
def main():
parser = argparse.ArgumentParser(
description='Train network from game data.')
parser.add_argument("blockspref",
help="Number of blocks", nargs='?', type=int)
parser.add_argument("filterspref",
help="Number of filters", nargs='?', type=int)
parser.add_argument("trainpref",
help='Training file prefix', nargs='?', type=str)
parser.add_argument("restorepref",
help='Training snapshot prefix', nargs='?', type=str)
parser.add_argument("--blocks", '-b',
help="Number of blocks", type=int)
parser.add_argument("--filters", '-f',
help="Number of filters", type=int)
parser.add_argument("--train", '-t',
help="Training file prefix", type=str)
parser.add_argument("--test", help="Test file prefix", type=str)
parser.add_argument("--restore", type=str,
help="Prefix of tensorflow snapshot to restore from")
parser.add_argument("--logbase", default='leelalogs', type=str,
help="Log file prefix (for tensorboard)")
help="Log file prefix (for tensorboard) (default: %(default)s)")
parser.add_argument("--sample", default=DOWN_SAMPLE, type=int,
help="Rate of data down-sampling to use")
help="Rate of data down-sampling to use (default: %(default)d)")
args = parser.parse_args()

blocks = args.blocks or args.blockspref
filters = args.filters or args.filterspref
train_data_prefix = args.train or args.trainpref
restore_prefix = args.restore or args.restorepref

if not blocks or not filters:
print("Must supply number of blocks and filters")
return

training = get_chunks(train_data_prefix)
if not args.test:
# Generate test by taking 10% of the training chunks.
@@ -150,7 +164,7 @@ def main():
sample=args.sample,
batch_size=RAM_BATCH_SIZE).parse()

tfprocess = TFProcess()
tfprocess = TFProcess(blocks, filters)
tfprocess.init(RAM_BATCH_SIZE,
logbase=args.logbase,
macrobatch=BATCH_SIZE // RAM_BATCH_SIZE)