-
Notifications
You must be signed in to change notification settings - Fork 502
/
Copy pathLlamaBatch.cc
1723 lines (1476 loc) · 67.5 KB
/
LlamaBatch.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
// Copyright (c) OpenMMLab. All rights reserved.
#include "src/turbomind/models/llama/LlamaBatch.h"
#include "src/turbomind/kernels/attention/data_type.h"
#include "src/turbomind/kernels/decoding_kernels.h"
#include "src/turbomind/kernels/sampling_topk_kernels.h"
#include "src/turbomind/macro.h"
#include "src/turbomind/models/llama/BlockManager.h"
#include "src/turbomind/models/llama/LlamaNcclGuard.h"
#include "src/turbomind/models/llama/LlamaV2.h"
#include "src/turbomind/models/llama/Request.h"
#include "src/turbomind/models/llama/SequenceManager.h"
#include "src/turbomind/models/llama/copy.h"
#include "src/turbomind/models/llama/llama_kernels.h"
#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/utils/Tensor.h"
#include "src/turbomind/utils/constant.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/debug_utils.h"
#include "src/turbomind/utils/gemm_test/gemm_func.h"
#include "src/turbomind/utils/logger.h"
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iomanip>
#include <iterator>
#include <mutex>
#include <numeric>
#include <sstream>
#include <unordered_map>
#include <utility>
namespace turbomind {
/// Debug utility: fetch a token matrix from `token_ids` and dump it to stdout.
///
/// The buffer is read as `max_seq_len * batch_sizse` ints in [S, B] layout and printed
/// transposed, i.e. one text line per batch entry. Every printed line starts with `[msg]`.
/// The copy is issued on `stream` and the stream is synchronized before anything is printed.
void PrintDecodeTokens(
    const int* token_ids, int max_seq_len, int batch_sizse, cudaStream_t stream, const std::string& msg)
{
    // host staging buffer, tokens in [S, B] layout
    std::vector<int> tokens(max_seq_len * batch_sizse);
    check_cuda_error(cudaMemcpyAsync(tokens.data(), token_ids, sizeof(int) * tokens.size(), cudaMemcpyDefault, stream));
    check_cuda_error(cudaStreamSynchronize(stream));

    const char* const tag = msg.c_str();

    // header line with the position indices
    printf("[%s] ", tag);
    int pos = 0;
    while (pos < max_seq_len) {
        printf("%5d ", pos);
        ++pos;
    }
    printf("\n");

    // one line per batch entry: walk down column `b` of the [S, B] matrix
    for (int b = 0; b < batch_sizse; ++b) {
        printf("[%s] ", tag);
        for (int s = 0; s < max_seq_len; ++s) {
            const int token = tokens[s * batch_sizse + b];
            printf("%5d ", token);
        }
        printf("\n");
    }
}
/// Empty a batch state: null out the first `s.size` sequence / request slots and
/// reset both the total and the active slot counters to zero.
void ClearState(BatchState& s)
{
    const auto used_slots = s.size;
    std::fill_n(s.sequences.begin(), used_slots, nullptr);
    std::fill_n(s.requests.begin(), used_slots, nullptr);
    s.active_size = 0;
    s.size        = 0;
}
/// Discard trailing input embeddings whose range ends beyond the current token count of `seq`.
///
/// Embeddings are examined from the back; the first one (counting from the end) whose range
/// end is within `seq.tokens.size()` and everything before it are kept.
void DropEmbeddings(const Sequence& seq)
{
    const int token_count = seq.tokens.size();
    const int emb_count   = seq.input_embeddings.size();

    // number of leading embeddings that survive
    size_t keep = emb_count;
    while (keep >= 1 && seq.input_embedding_ranges[keep - 1].second > token_count) {
        --keep;
    }

    // should we keep part of embedding?
    seq.input_embedding_ranges.resize(keep);
    seq.input_embeddings.resize(keep);
}
/// Filter out requests that cannot be served, signalling each rejected request with an error code.
///
/// A request is rejected when
///   - its id appears more than once across `stop_reqs` and `infer_reqs` (kConflict),
///   - it sets both `start_flag` and `stop_flag`, or continues a sequence that is unknown to
///     the sequence manager (kInvalid),
///   - its input (plus retained history for continued sequences) exceeds `session_len_` (kTooLong),
///   - it is a stop request without `end_flag` for an id with no request in `state_` (kInactive),
///   - it is an infer request for an id that already has a request in `state_` (kBusy).
/// Rejected entries are removed from the two containers; the order of the survivors is kept.
template<typename T>
void LlamaBatch<T>::RejectInvalidRequests(Requests& stop_reqs, Requests& infer_reqs)
{
    // how often each request id shows up over both queues
    std::unordered_map<uint64_t, int> occurrence;
    const Requests*                   queues[] = {&stop_reqs, &infer_reqs};
    for (const Requests* queue : queues) {
        for (const auto& req : *queue) {
            ++occurrence[req->id];
        }
    }

    // log, signal the error code to the requester and null the slot
    auto reject = [](const char* type, std::shared_ptr<Request>& req, int ec) {
        TM_LOG_WARNING(
            "[RejectInvalidRequests] Skipping invalid %s request for id %ld, code = %d", type, (long)req->id, ec);
        req->signal.set_value(ec);
        req.reset();
    };

    // true when some slot of the current batch state holds a request with this id
    auto has_active_request = [this](const auto& id) {
        for (int slot = 0; slot < state_->size; ++slot) {
            if (state_->requests[slot] && state_->requests[slot]->id == id) {
                return true;
            }
        }
        return false;
    };

    // error code for a single request, 0 when it passes all generic checks
    auto classify = [this, &occurrence](const std::shared_ptr<Request>& req) -> int {
        const int input_length = req->inputs[rank_].getVal<int>("input_lengths", 0);
        if (occurrence[req->id] != 1) {
            return Request::kConflict;
        }
        if (req->start_flag && req->stop_flag) {
            return Request::kInvalid;
        }
        if (input_length > session_len_) {
            return Request::kTooLong;
        }
        if (req->start_flag) {
            return 0;
        }
        // continuation of an existing sequence
        auto seq = sequence_manager_->Get(req->id);
        if (seq == nullptr) {
            return Request::kInvalid;
        }
        // history that is retained: "step" clamped into [0, history length]
        const int history_len = seq->tokens.size();
        const int offset = std::max(0, std::min(history_len, req->inputs[rank_].getVal<int>("step", history_len)));
        if (offset + input_length > session_len_) {
            return Request::kTooLong;
        }
        return 0;
    };

    auto handle_conflict_or_invalid = [&classify, &reject](Requests& rs, const char* type) {
        for (auto& req : rs) {
            if (!req) {
                continue;
            }
            if (const int ec = classify(req)) {
                reject(type, req, ec);
            }
        }
    };

    // compact the container, removing slots nulled by `reject`
    auto drop_invalid = [](Requests& rs) { rs.erase(std::remove(rs.begin(), rs.end(), nullptr), rs.end()); };

    if (!stop_reqs.empty()) {
        handle_conflict_or_invalid(stop_reqs, "stop");
        // invalidate stop-only requests for inactive sequences
        for (auto& req : stop_reqs) {
            if (!req || req->end_flag != false) {
                continue;
            }
            const int ec = has_active_request(req->id) ? 0 : Request::kInactive;
            if (ec) {
                reject("stop", req, ec);
            }
        }
        drop_invalid(stop_reqs);
    }

    if (!infer_reqs.empty()) {
        handle_conflict_or_invalid(infer_reqs, "infer");
        // invalidate requests for busy sequences
        for (auto& req : infer_reqs) {
            if (req && has_active_request(req->id)) {
                reject("infer", req, Request::kBusy);
            }
        }
        drop_invalid(infer_reqs);
    }
}
/// Handle stop requests and return the signals that must be fired afterwards.
///
/// For each request the matching active slot in `state_` (if any) is interrupted; when no slot
/// matches and the request carries `end_flag`, the sequence is erased from the sequence manager
/// instead. A completion signal carrying the resulting error code is appended for every request
/// (it only sets the value on rank 0). The stream is synchronized if any slot was interrupted.
template<typename T>
auto LlamaBatch<T>::ProcessStopRequests(const Requests& requests) -> std::vector<Signal>
{
    NvtxScope scope("stop_request");

    std::vector<Signal> signals;
    int                 count = 0;

    for (const auto& r : requests) {
        // find matching active sequence
        int slot = -1;
        for (int i = 0; i < state_->size; ++i) {
            if (state_->requests[i] && state_->requests[i]->id == r->id) {
                slot = i;
                break;
            }
        }

        int ec = Request::kFail;
        if (slot >= 0) {
            // stop & optionally erase active sequence
            ec = 0;
            signals.push_back(Interrupt(slot, true, r->end_flag));
            ++count;
        }

        // mismatch, try erase inactive sequence, in this case there is no active request to interrupt
        if (ec && r->end_flag) {
            if (sequence_manager_->Erase(r->id)) {
                ec = 0;
            }
        }

        // completion notification, the value is only set on rank 0
        signals.push_back([this, r, ec] {
            if (rank_ == 0) {
                r->signal.set_value(ec);
            }
        });
    }

    if (count) {
        check_cuda_error(cudaStreamSynchronize(stream_));
    }

    return signals;
}
// Stage a batch of inference requests into the `incoming_` batch state, which must be empty.
//
// For every request, in order:
//   1. bind the request to a slot and obtain its sequence (created when `start_flag` is set,
//      looked up otherwise),
//   2. optionally rewind the sequence to the requested "step",
//   3. lay out history tokens followed by the new input tokens in the slot's `output_ids` row,
//   4. record the prompt (for prefix matching) and any valid input embeddings,
//   5. compute context length, sequence length limit and rope theta,
//   6. prepare the random state (seed for new sequences, saved curand state otherwise).
// After the loop the random states of the staged slots are set up on the device.
template<typename T>
void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
{
NvtxScope scope("infer_request");
auto& state = *incoming_;
// the incoming buffer must have been cleared by the previous iteration
FT_CHECK(state.size == 0);
FT_CHECK(state.active_size == 0);
// slots of requests that continue an existing sequence (their curand state is restored below)
std::vector<int> existing_idx;
int idx = 0;
for (const auto& r : requests) {
FT_CHECK(!state.requests[idx]);
if (rank_ == 0) {
TM_LOG_WARNING("[ProcessInferRequests] Request for %ld received.", (long)r->id);
}
state.requests[idx] = r;
// get sequence for the request
state.sequences[idx] = r->start_flag ? sequence_manager_->Create(r->id) : sequence_manager_->Get(r->id);
FT_CHECK(state.sequences[idx]);
auto& seq = *state.sequences[idx];
// optional "step" input: truncate the history to `step` tokens, clamp the cached length
// accordingly and drop embeddings that no longer fit; a step beyond the history is ignored
if (int step = r->inputs[rank_].getVal<int>("step", -1); step >= 0) {
if (step <= seq.tokens.size()) {
seq.tokens.resize(step);
seq.cache_len = std::min(seq.cache_len, step);
DropEmbeddings(seq);
}
else if (rank_ == 0) {
TM_LOG_WARNING(
"[ProcessInferRequests] Skipping invalid step (%d) setting for ID %ld", step, (long)seq.id);
}
}
const int input_length = r->inputs[rank_].getVal<int>("input_lengths");
const int* input_ids = r->inputs[rank_].getPtr<int>("input_ids");
// `output_ids` contains all token ids of the sequences
// each slot owns a row of `session_len_` ints
const auto output_ids_base = state.output_ids + session_len_ * idx;
auto output_ids = output_ids_base;
// copy history tokens
if (!seq.tokens.empty()) {
output_ids = Copy(seq.tokens.data(), seq.tokens.size(), output_ids);
}
// copy input tokens
if (input_length) {
output_ids = Copy(input_ids, input_length, output_ids);
}
// copy input tokens to prompt for prefix matching
// (only for new sequences without input embeddings)
if (input_length && r->start_flag && !r->inputs[rank_].isExist("input_embedding_ranges")) {
// TODO: truncate prompt to enable prefix caching for VLM
seq.prompt.resize(input_length);
std::copy_n(input_ids, input_length, seq.prompt.data());
}
// copy input embeddings
if (r->inputs[rank_].isExist("input_embedding_ranges")) {
const auto range_tensor = r->inputs[rank_].at("input_embedding_ranges");
const auto emb_tensor = r->inputs[rank_].at("input_embeddings");
const int* ranges = range_tensor.getPtr<int>();
// Validate the (begin, end) pairs in `ranges`. Scanning stops at the first pair holding a
// negative value (presumably padding — confirm against the caller). Pairs must be non-empty,
// lie within the input, not start before the previous pair's end, and the accumulated
// embedding bytes must not exceed `emb_tensor.shape[1]`. On success `num_valid_embeddings`
// holds the number of accepted pairs.
auto check_embeddings = [&](int& num_valid_embeddings) {
if (range_tensor.shape.size() != 3 || range_tensor.shape[2] % 2 != 0) {
return false;
}
int embedding_count = range_tensor.shape[1];
int embedding_length = 0;
int pre_end = -1;
for (size_t i = 0; i < embedding_count; i++) {
int begin = ranges[i * 2];
int end = ranges[i * 2 + 1];
embedding_length += (end - begin);
if (begin < 0 || end < 0) {
break;
}
if (begin >= end || end > input_length || begin < pre_end
|| embedding_length * model_->hidden_units_ * sizeof(T) > emb_tensor.shape[1]) {
return false;
}
pre_end = end;
num_valid_embeddings = i + 1;
}
return true;
};
int num_valid_embeddings = 0;
if (!check_embeddings(num_valid_embeddings)) {
TM_LOG_WARNING("[ImageFeature] Skip invalid input embeddings, id = %ld, input_length = %d, "
"input embeddings = %s, range_tensor = %s",
(long)seq.id,
input_length,
emb_tensor.toString().c_str(),
range_tensor.toString().c_str());
}
else {
// embeddings are read back-to-back from `emb_tensor`, each one spanning
// (end - begin) * hidden_units * sizeof(T) bytes
char* emb_tensor_ptr = emb_tensor.getPtr<char>();
for (size_t i = 0; i < num_valid_embeddings; i++) {
int begin = ranges[i * 2];
int end = ranges[i * 2 + 1];
size_t count = (end - begin) * model_->hidden_units_ * sizeof(T);
seq.input_embeddings.emplace_back((std::byte*)emb_tensor_ptr, (std::byte*)(emb_tensor_ptr + count));
// ranges are relative to the input, shift them by the history length
seq.input_embedding_ranges.emplace_back(begin + seq.tokens.size(), end + seq.tokens.size());
emb_tensor_ptr += count;
}
}
}
// total context length (history + input)
state.h_prompt_length[idx] = output_ids - output_ids_base;
state.h_context_length[idx] = output_ids - output_ids_base;
state.h_finished[idx] = false;
const int request_output_len = state.requests[idx]->inputs[rank_].getVal<int>("request_output_len");
state.seq_len_limit[idx] = state.h_context_length[idx] + request_output_len;
// `length_criterion` sets finish flag when step >= seq_limit_len, however when step == seq_limit_len
// the actual sequence length is seq_limit_len + 1, hence seq_limit_len must be truncated to session_len - 1
if (state.seq_len_limit[idx] >= session_len_) {
state.seq_len_limit[idx] = session_len_ - 1;
if (rank_ == 0) {
const int trunc_output_len = state.seq_len_limit[idx] - state.h_context_length[idx];
TM_LOG_WARNING(
"[ProcessInferRequests] [%ld] total sequence length (%d + %d) exceeds `session_len` (%d), `request_output_len` is truncated to %d",
(long)seq.id,
state.h_context_length[idx],
request_output_len,
(int)session_len_,
trunc_output_len);
}
}
// compute rope scaling factor
// (only for new sequences; continued sequences keep the theta stored in `seq`)
if (r->start_flag) {
seq.rope_theta = model_->attn_params_.rotary_embedding_base;
if (model_->attn_params_.use_dynamic_ntk) {
auto scaling_factor = model_->attn_params_.rope_scaling_factor;
if (scaling_factor >= 1.f) { // infer by current context length
auto max_seq_len = state.h_context_length[idx];
auto max_pos_emb = model_->attn_params_.max_position_embeddings;
// theta is only rescaled when the context is longer than `max_position_embeddings`
if (max_seq_len > max_pos_emb) {
scaling_factor = scaling_factor * max_seq_len / max_pos_emb - (scaling_factor - 1);
float rope_dim = model_->attn_params_.rotary_embedding_dim;
seq.rope_theta *= powf(scaling_factor, rope_dim / (rope_dim - 2.f));
TM_LOG_INFO("[ProcessInferRequests] %ld rope_scaling_factor: %f, rope_theta = %f",
(long)seq.id,
scaling_factor,
seq.rope_theta);
}
}
}
}
state.h_rope_theta[idx] = seq.rope_theta;
if (r->start_flag) {
// prepare to initialize random state for new sequence
h_random_seed_[idx] = r->inputs[rank_].getVal<unsigned long long>("random_seed", 0);
}
else {
// Recover device states if not a new sequence
// (saved states are packed densely in `h_curand_state_`, `existing_idx` maps them to slots)
h_curand_state_[existing_idx.size()] = *(curandState_t*)seq.random_state.data();
existing_idx.push_back(idx);
}
// ! SHARED STATE IS MODIFIED, BARRIER SYNCHRONIZATION REQUIRED
// assign priority based on arrival time
if (rank_ == 0) {
r->unique_id = request_count_++;
}
// increment pointer
idx++;
}
state.size = idx;
// when there are new sequences
// NOTE(review): this initializes the curand state of ALL `state.size` slots, including slots of
// continued sequences whose `h_random_seed_` entry was not written in this call; those slots are
// overwritten with the saved states right below.
if (state.size != existing_idx.size()) {
// copy random seeds to device
Copy(h_random_seed_, state.size, d_random_seed_);
// initialize random states
invokeCurandBatchInitialize(state.curand_state, state.size, d_random_seed_, stream_);
sync_check_cuda_error();
}
if (!existing_idx.empty()) {
// copy existing curand states to device
Copy(h_curand_state_, existing_idx.size(), d_curand_state_);
// insert the states to their correct positions in the batch
IndexedCopy({}, existing_idx, std::tuple{d_curand_state_, state.curand_state, 1});
}
}
/// Recompute the per-iteration input token budget stored in `g`.
///
/// The number of not-yet-cached tokens over all `sequences` (minus one per sequence) is spread
/// over `max_prefill_iters_` iterations. The result enters the sliding window
/// `g.min_input_count`, whose entries are all raised to at least the new value. The budget
/// derived from the window front is bounded below by `num_tokens_per_iter_` and above by
/// `max_context_token_num_`, then written to `g.max_input_count1` / `g.max_input_count2`
/// (the latter adds `extra_tokens_per_iter_`, still capped by `max_context_token_num_`).
template<typename T>
void LlamaBatch<T>::AdjustMaxInputCount(GenerationState&                    g,
                                        const std::vector<const Sequence*>& sequences,
                                        const std::vector<int>&             context_length)
{
    // tokens of the batch that are not covered by the cache yet
    int pending_tokens = 0;
    for (int i = 0; i < sequences.size(); ++i) {
        pending_tokens += context_length[i] - sequences[i]->cache_len;
    }

    const int batch_size = sequences.size();
    pending_tokens -= batch_size;

    // min tokens per iter for satisfying max prefill iters constraint
    const int tokens_per_iter = (pending_tokens + max_prefill_iters_ - 1) / max_prefill_iters_;

    auto& window = g.min_input_count;
    if (window.empty()) {
        window.resize(max_prefill_iters_);
    }
    // slide the window by one iteration
    window.pop_front();
    window.push_back(tokens_per_iter);

    /// TODO: sub-optimal when there are inactive sequences due to memory constraint
    for (auto& entry : window) {
        entry = std::max(entry, tokens_per_iter);
    }

    int budget = std::max(window.front() + batch_size, num_tokens_per_iter_);
    budget     = std::min(budget, max_context_token_num_);

    // update max input count
    g.max_input_count1 = budget;
    g.max_input_count2 = std::min(budget + extra_tokens_per_iter_, max_context_token_num_);
}
// Build the batch for the next iteration.
//
// Requests from the current state (`state_`) and the newly staged ones (`incoming_`) are handed
// to the sequence manager, which decides what is active. When anything changed (swaps, holes left
// by finished / stopped requests, or new requests) the surviving entries are reordered into
// `back_` and the buffers are swapped. Afterwards the block pointer tables and the per-slot
// device buffers (context length, finished flags, rope theta) are refreshed, and sampling is
// re-initialized unless the set of active requests is unchanged.
template<typename T>
void LlamaBatch<T>::Initialize(GenerationState& g)
{
NvtxScope scope("initialize");
// parallel arrays describing every live request of `state_` and `incoming_`
std::vector<const Sequence*> sequences;
std::vector<Sequence::Status> status;
std::vector<uint64_t> priorities;
std::vector<int> context_lengths;
// (buffer, slot) where each entry currently lives
std::vector<std::pair<BatchState*, int>> coords;
// count the holes introduced by finished requests from the previous iteration or stop requests from
// the current iteration
int holes{};
int active_holes{};
for (int i = 0; i < state_->size; ++i) {
if (!state_->requests[i]) {
++holes;
if (i < state_->active_size) {
++active_holes;
}
}
}
// collect all non-empty slots of a batch state; `status` snapshots the sequence status
// BEFORE `Materialize` runs
auto process = [&](BatchState* state) {
for (int i = 0; i < state->size; ++i) {
if (auto& r = state->requests[i]) {
sequences.push_back(state->sequences[i]);
status.push_back(state->sequences[i]->status);
priorities.push_back(r->unique_id);
context_lengths.push_back(state->h_context_length[i]);
coords.emplace_back(state, i);
}
}
};
process(state_);
process(incoming_);
// callback handed to the sequence manager: refreshes and returns the two input count limits
auto adjust = [this, &g](const Sequences& sequences,
const std::vector<int>& context_length) -> std::pair<int, int> {
AdjustMaxInputCount(g, sequences, context_length);
return {g.max_input_count1, g.max_input_count2};
};
// TM_LOG_INFO("max_input_count %d", max_input_count);
auto outcome = sequence_manager_->Materialize(sequences, context_lengths, priorities, step_length_, adjust);
if (outcome.allocation || outcome.swap_in || outcome.swap_out) {
dbg(outcome);
}
bool exchange = outcome.swap_in + outcome.swap_out > 0;
std::vector<int> idxs(sequences.size());
std::iota(idxs.begin(), idxs.end(), 0);
// the batch layout has to be rebuilt
if (exchange || holes || incoming_->size) {
// put active ones first
auto active_end = std::stable_partition(idxs.begin(), idxs.end(), [&](int idx) {
return sequences[idx]->status == Sequence::kActive; // current status
});
// all blocks are not enough to hold a single sequence
if (!sequences.empty()) {
FT_CHECK_WITH_INFO(active_end != idxs.begin(), "No enough blocks.");
}
// move the partial seq to the back
// (a sequence is partial when cache_len + input_length does not reach its context length)
auto partial_beg = std::stable_partition(idxs.begin(), active_end, [&](int i) {
return sequences[i]->cache_len + sequences[i]->input_length == context_lengths[i];
});
// at most one partial sequence is expected
FT_CHECK(active_end - partial_beg <= 1);
// among the complete active ones: previously active first, swapped-in ones after
auto swapin_beg = std::stable_partition(idxs.begin(), partial_beg, [&](int i) {
return status[i] == Sequence::kActive; // past status
});
// sort swap-ins according to input length
if (swapin_beg != partial_beg) {
std::stable_sort(swapin_beg, partial_beg, [&](int i, int j) {
return sequences[i]->input_length < sequences[j]->input_length;
});
}
// Copy sequence states to back buffer
FT_CHECK(back_->size == 0 && back_->active_size == 0);
// (source buffer, destination buffer, source slot, destination slot)
std::vector<std::tuple<BatchState*, BatchState*, int, int>> cpys;
for (const auto& i : idxs) {
auto& s = *sequences[i];
if (s.status == Sequence::kActive) {
++back_->active_size;
}
cpys.emplace_back(coords[i].first, back_, coords[i].second, back_->size++);
}
CopyState(cpys);
// Swap the buffers
std::swap(state_, back_);
ClearState(*back_);
ClearState(*incoming_);
}
FT_CHECK(state_->size <= max_batch_size_);
/// Update block ptrs when there were
// 1. swap-in or swap-out
// 2. holes in the active buffer
// 3. new allocations (for existing active sequences)
if (exchange || active_holes || outcome.allocation) {
// Prepare intermediate buffers
h_cu_block_counts_[0] = 0;
auto block_ptrs = h_block_ptrs_;
const int batch_size = state_->active_size;
for (int i = 0; i < batch_size; ++i) {
const auto& seq = *state_->sequences[i];
// cumulative num of blocks
h_cu_block_counts_[i + 1] = h_cu_block_counts_[i] + seq.blocks.size();
// append the block addresses of this sequence
block_ptrs = std::transform(seq.blocks.cbegin(), seq.blocks.cend(), block_ptrs, [&](int block_id) {
return reinterpret_cast<uintptr_t>(sequence_manager_->GetBlockPtr(block_id));
});
}
static_assert(sizeof(uintptr_t) == sizeof(void*));
Copy(h_cu_block_counts_, batch_size + 1, cu_block_counts_);
Copy(h_block_ptrs_, h_cu_block_counts_[batch_size], block_ptrs_);
// Copy(h_k_block_ptrs_, h_cu_block_counts_[batch_size], k_block_ptrs_);
// Copy(h_v_block_ptrs_, h_cu_block_counts_[batch_size], v_block_ptrs_);
}
const int batch_size = state_->active_size;
// check if the last sequence is partial
int partial = 0;
int partial_len = -1;
if (state_->active_size) {
const int i = state_->active_size - 1;
partial = state_->sequences[i]->cache_len + state_->sequences[i]->input_length != state_->h_context_length[i];
if (partial) {
// backup full context length of partial
partial_len = state_->h_context_length[i];
// replace with partial context length
state_->h_context_length[i] = state_->sequences[i]->cache_len + state_->sequences[i]->input_length;
}
}
// NOTE(review): when `batch_size` is 0 `std::max_element` returns its end iterator, so this
// dereferences `h_context_length[0]` and yields whatever value is stored there — confirm that
// an empty active batch cannot reach this point or that the value is unused in that case.
const int max_context_len = *std::max_element(state_->h_context_length, state_->h_context_length + batch_size);
std::vector<uint64_t> unique_ids(batch_size);
for (int i = 0; i < batch_size; ++i) {
unique_ids[i] = state_->requests[i]->unique_id;
}
// Real-time context length that will change during generation
Copy(state_->h_context_length, batch_size, context_length_buf_);
Copy(state_->h_finished, batch_size, finished_buf_);
Copy(state_->h_rope_theta, batch_size, rope_theta_);
// sampling state can be kept when the non-partial requests are the same, in the same order,
// as in the previous iteration
bool skip_init_sampling = std::equal(g.unique_ids.begin(), //
g.unique_ids.end() - g.partial,
unique_ids.begin(),
unique_ids.end() - partial);
g.partial = partial;
g.partial_context_legnth = partial_len;
g.unique_ids = std::move(unique_ids);
g.finished_count = 0;
if (!skip_init_sampling) {
g.max_init_ctx_len = max_context_len;
g.step = max_context_len;
InitializeSampling(g);
}
}
/// Copy per-slot state between batch buffers.
///
/// Each descriptor is (source state, destination state, source slot, destination slot).
/// Descriptors are grouped by their (source, destination) pair so that `output_ids` rows and
/// curand states of a group are moved with a single `IndexedCopy` call; the host-side fields
/// are then copied descriptor by descriptor in the original order.
template<typename T>
void LlamaBatch<T>::CopyState(const std::vector<std::tuple<BatchState*, BatchState*, int, int>>& desc)
{
    if (desc.empty()) {
        return;
    }

    // visiting order in which descriptors with the same (src, dst) pair are adjacent
    std::vector<int> order(desc.size());
    std::iota(order.begin(), order.end(), 0);
    std::sort(order.begin(), order.end(), [&desc](int a, int b) { return desc[a] < desc[b]; });

    using Signature = std::pair<BatchState*, BatchState*>;

    auto signature_at = [&](int pos) -> Signature {
        const auto& entry = desc[order[pos]];
        return std::make_pair(std::get<0>(entry), std::get<1>(entry));
    };

    // start positions of the runs sharing a signature, terminated by the total count
    std::vector<int> run_bounds;
    run_bounds.push_back(0);
    Signature current = signature_at(0);
    for (int pos = 1; pos < order.size(); ++pos) {
        const Signature candidate = signature_at(pos);
        if (candidate != current) {
            current = candidate;
            run_bounds.push_back(pos);
        }
    }
    run_bounds.push_back(order.size());

    // one indexed copy per run
    for (int run = 1; run < run_bounds.size(); ++run) {
        const int first = run_bounds[run - 1];
        const int last  = run_bounds[run];
        if (first == last) {
            continue;
        }

        auto [src, dst] = signature_at(first);

        std::vector<int> src_slots;
        std::vector<int> dst_slots;
        for (int pos = first; pos < last; ++pos) {
            const auto& entry = desc[order[pos]];
            src_slots.push_back(std::get<2>(entry));
            dst_slots.push_back(std::get<3>(entry));
        }

        IndexedCopy(src_slots,
                    dst_slots,
                    std::tuple{src->output_ids, dst->output_ids, session_len_},
                    std::tuple{src->curand_state, dst->curand_state, 1});
    }

    // host-side bookkeeping, one descriptor at a time
    for (const auto& entry : desc) {
        BatchState* const src      = std::get<0>(entry);
        BatchState* const dst      = std::get<1>(entry);
        const int         src_slot = std::get<2>(entry);
        const int         dst_slot = std::get<3>(entry);

        dst->h_prompt_length[dst_slot]  = src->h_prompt_length[src_slot];
        dst->h_context_length[dst_slot] = src->h_context_length[src_slot];
        dst->h_finished[dst_slot]       = src->h_finished[src_slot];
        dst->h_rope_theta[dst_slot]     = src->h_rope_theta[src_slot];
        dst->seq_len_limit[dst_slot]    = src->seq_len_limit[src_slot];
        dst->sequences[dst_slot]        = src->sequences[src_slot];
        dst->requests[dst_slot]         = src->requests[src_slot];
    }
}
/// (Re)allocate the working buffers of the batch through `allocator_`.
///
/// Buffer sizes are derived from `batch_size`, `session_len`, `max_context_token_num_` and the
/// model dimensions (`hidden_units_`, `vocab_size_padded_`). When the LoRA policy is `kPlora`
/// a LoRA mask buffer is allocated as well and the context decoder output buffer is widened by
/// `lora_params_.max_wo_r` per token. Sets `is_allocate_buffer_` on completion.
///
/// \param batch_size           number of batch slots the buffers are sized for
/// \param session_len          number of token slots reserved per batch entry
/// \param cache_block_seq_len  divisor used to derive the number of cache blocks per sequence
template<typename T>
void LlamaBatch<T>::AllocateBuffer(size_t batch_size, size_t session_len, int cache_block_seq_len)
{
    TM_LOG_DEBUG(__PRETTY_FUNCTION__);
    const size_t batchxbeam = batch_size;

    const size_t hidden_units = model_->hidden_units_;
    const size_t vocab_size   = model_->vocab_size_padded_;

    // +1 padding, BlockIterator does not use predicate
    const size_t max_batch_block_count =
        batch_size * ((session_len + cache_block_seq_len - 1) / cache_block_seq_len) + 1;

    if (model_->lora_params_.policy == LoraPolicy::kPlora) {
        lora_mask_buf_ = (int*)allocator_->reMalloc(lora_mask_buf_, sizeof(int) * max_context_token_num_, false);
        // room for `max_wo_r` extra values per token next to the hidden state
        size_t sz = sizeof(T) * max_context_token_num_ * (hidden_units + model_->lora_params_.max_wo_r);
        context_decoder_output_buf_ = (T*)allocator_->reMalloc(context_decoder_output_buf_, sz, false);
    }
    else {
        context_decoder_output_buf_ = (T*)allocator_->reMalloc(
            context_decoder_output_buf_, sizeof(T) * max_context_token_num_ * hidden_units, false);
    }

    // context (prefill) buffers, sized by the max number of context tokens per iteration
    context_decoder_input_buf_ =
        (T*)allocator_->reMalloc(context_decoder_input_buf_, sizeof(T) * max_context_token_num_ * hidden_units, false);
    context_decoder_ids_buf_ =
        (int*)allocator_->reMalloc(context_decoder_ids_buf_, sizeof(int) * max_context_token_num_, false);

    // decoder buffers, one hidden state per batch slot
    decoder_input_buf_  = (T*)allocator_->reMalloc(decoder_input_buf_, sizeof(T) * batchxbeam * hidden_units, false);
    decoder_output_buf_ = (T*)allocator_->reMalloc(decoder_output_buf_, sizeof(T) * batchxbeam * hidden_units, false);

    // token ids and per-slot lengths
    input_ids_buf_       = (int*)allocator_->reMalloc(input_ids_buf_, sizeof(int) * batchxbeam * session_len, true);
    input_length_buf_    = (int*)allocator_->reMalloc(input_length_buf_, sizeof(int) * batchxbeam);
    context_length_buf_  = (int*)allocator_->reMalloc(context_length_buf_, sizeof(int) * batchxbeam);
    init_context_length_ = (int*)allocator_->reMalloc(init_context_length_, sizeof(int) * batchxbeam);

    sequence_lengths_ = (int*)allocator_->reMalloc(sequence_lengths_, sizeof(int) * batchxbeam, false);

    // cache block tables: cumulative block counts per slot and the block addresses
    cu_block_counts_ = (int*)allocator_->reMalloc(cu_block_counts_, sizeof(int) * (batch_size + 1));
    block_ptrs_      = (uintptr_t*)allocator_->reMalloc(block_ptrs_, sizeof(uintptr_t) * max_batch_block_count);

    // logits and sampling outputs
    logits_buf_       = (float*)allocator_->reMalloc(logits_buf_, sizeof(float) * batchxbeam * vocab_size, false);
    local_logits_buf_ = (float*)allocator_->reMalloc(local_logits_buf_, sizeof(float) * batchxbeam * vocab_size, false);

    sampled_logprobs_ =
        (float*)allocator_->reMalloc(sampled_logprobs_, sizeof(float) * batchxbeam * kMaxLogProb, false);
    sampled_indexes_ =
        (uint32_t*)allocator_->reMalloc(sampled_indexes_, sizeof(uint32_t) * batchxbeam * kMaxLogProb, false);
    sampled_nums_ = (uint32_t*)allocator_->reMalloc(sampled_nums_, sizeof(uint32_t) * batchxbeam, false);

    token_ids_buf_ = (int*)allocator_->reMalloc(token_ids_buf_, sizeof(int) * batchxbeam * session_len * 2, true);

    // per-slot generation state
    finished_buf_  = (bool*)allocator_->reMalloc(finished_buf_, sizeof(bool) * batchxbeam, false);
    seq_limit_len_ = (uint32_t*)allocator_->reMalloc(seq_limit_len_, sizeof(uint32_t) * batch_size, false);

    rope_theta_ = (float*)allocator_->reMalloc(rope_theta_, sizeof(float) * batch_size, false);

    is_allocate_buffer_ = true;
}
// (Re)allocate the buffers that are sized by `max_batch_size` and `session_len_`: stop / bad
// word lists, host-side sampling parameters, random seeds and curand states, end ids, the
// `output_ids` / `curand_state` storage of every batch state in `states_`, host staging
// buffers and host copies of the sampling outputs. Also (re)builds `sampling_params_` and sets
// `is_allocate_persistant_buffer_`.
//
// NOTE(review): the meaning of the trailing boolean arguments of `reMalloc` is not visible
// here. The 4-argument form with a final `true` is used for the `h_*` buffers (and for
// `h_*_words_`), so it presumably selects host memory — confirm against the allocator.
template<typename T>
void LlamaBatch<T>::AllocatePersistantBuffer(size_t max_batch_size, int cache_block_seq_len)
{
// stop / bad word lists, device and host copies of identical size
d_stop_words_ =
(int*)allocator_->reMalloc(d_stop_words_, sizeof(int) * max_batch_size * 2 * kMaxStopBadWordsLen, true);
d_bad_words_ =
(int*)allocator_->reMalloc(d_bad_words_, sizeof(int) * max_batch_size * 2 * kMaxStopBadWordsLen, true);
h_stop_words_ =
(int*)allocator_->reMalloc(h_stop_words_, sizeof(int) * max_batch_size * 2 * kMaxStopBadWordsLen, true, true);
h_bad_words_ =
(int*)allocator_->reMalloc(h_bad_words_, sizeof(int) * max_batch_size * 2 * kMaxStopBadWordsLen, true, true);
// per-slot sampling parameters
h_min_length_ = (int*)allocator_->reMalloc(h_min_length_, sizeof(int) * max_batch_size, true, true);
h_runtime_top_k_ = (int*)allocator_->reMalloc(h_runtime_top_k_, sizeof(int) * max_batch_size, true, true);
h_runtime_top_p_ = (float*)allocator_->reMalloc(h_runtime_top_p_, sizeof(float) * max_batch_size, true, true);
h_temperature_ = (float*)allocator_->reMalloc(h_temperature_, sizeof(float) * max_batch_size, true, true);
h_repetition_penalty_ =
(float*)allocator_->reMalloc(h_repetition_penalty_, sizeof(float) * max_batch_size, true, true);
// random seeds and curand states, `h_*` / `d_*` pairs of identical size
h_random_seed_ = (unsigned long long*)allocator_->reMalloc(
h_random_seed_, sizeof(unsigned long long) * max_batch_size, true, true);
d_random_seed_ = (unsigned long long*)allocator_->reMalloc(
d_random_seed_, sizeof(unsigned long long) * max_batch_size, true, false);
h_curand_state_ =
(curandState_t*)allocator_->reMalloc(h_curand_state_, sizeof(curandState_t) * max_batch_size, true, true);
d_curand_state_ =
(curandState_t*)allocator_->reMalloc(d_curand_state_, sizeof(curandState_t) * max_batch_size, true, false);
d_end_ids_buf_ = (int*)allocator_->reMalloc(d_end_ids_buf_, sizeof(int) * max_batch_size, false);
h_end_ids_buf_ = (int*)allocator_->reMalloc(h_end_ids_buf_, sizeof(int) * max_batch_size, false, true);
// table of (parameter name, host buffer, device buffer or nullptr)
sampling_params_ = {
{"stop_words_list", (std::byte*)h_stop_words_, (std::byte*)d_stop_words_},
{"bad_words_list", (std::byte*)h_bad_words_, (std::byte*)d_bad_words_},
{"min_length", (std::byte*)h_min_length_, nullptr},
{"runtime_top_k", (std::byte*)h_runtime_top_k_, nullptr},
{"runtime_top_p", (std::byte*)h_runtime_top_p_, nullptr},
{"temperature", (std::byte*)h_temperature_, nullptr},
{"repetition_penalty", (std::byte*)h_repetition_penalty_, nullptr},
};
// per batch state: one `session_len_` row of token ids and one curand state per slot
for (auto& s : states_) {
s.output_ids = (int*)allocator_->reMalloc(s.output_ids, sizeof(int) * max_batch_size * session_len_, true);
s.curand_state =
(curandState_t*)allocator_->reMalloc(s.curand_state, sizeof(curandState_t) * max_batch_size, true);
}
// NOTE(review): unlike the count computed in AllocateBuffer there is no "+ 1" padding here;
// this count only sizes the host-side `h_block_ptrs_` — confirm the padding is not needed.
const size_t max_batch_block_count =
max_batch_size * ((session_len_ + cache_block_seq_len - 1) / cache_block_seq_len);
{
// the allocations in this scope are made while `barrier` is alive
// (presumably to keep the tensor-parallel ranks in step — confirm against NcclGuard)
NcclGuard barrier(model_->tensor_para_, stream_, true);
h_input_ids_buf_ =
(int*)allocator_->reMalloc(h_input_ids_buf_, sizeof(int) * max_batch_size * session_len_, false, true);
h_input_length_buf_ =
(int*)allocator_->reMalloc(h_input_length_buf_, sizeof(int) * max_batch_size, false, true);
h_cu_block_counts_ =
(int*)allocator_->reMalloc(h_cu_block_counts_, sizeof(int) * (max_batch_size + 1), false, true);
h_block_ptrs_ =
(uintptr_t*)allocator_->reMalloc(h_block_ptrs_, sizeof(uintptr_t) * max_batch_block_count, false, true);
for (auto& s : states_) {
s.h_prompt_length =
(int*)allocator_->reMalloc(s.h_prompt_length, sizeof(int) * max_batch_size, false, true);
s.h_context_length =
(int*)allocator_->reMalloc(s.h_context_length, sizeof(int) * max_batch_size, false, true);
// note: `h_finished` is allocated with twice `max_batch_size` entries
s.h_finished = (bool*)allocator_->reMalloc(s.h_finished, sizeof(bool) * max_batch_size * 2, false, true);
s.h_rope_theta = (float*)allocator_->reMalloc(s.h_rope_theta, sizeof(float) * max_batch_size, false, true);
}
h_seq_limit_len_ =
(uint32_t*)allocator_->reMalloc(h_seq_limit_len_, sizeof(uint32_t) * max_batch_size, false, true);
h_output_ids_ =
(int*)allocator_->reMalloc(h_output_ids_, sizeof(int) * max_batch_size * session_len_, false, true);
}
// host copies of the sampling outputs
h_sampled_logprobs_ =
(float*)allocator_->reMalloc(h_sampled_logprobs_, sizeof(float) * max_batch_size * kMaxLogProb, false, true);
h_sampled_indexes_ = (uint32_t*)allocator_->reMalloc(
h_sampled_indexes_, sizeof(uint32_t) * max_batch_size * kMaxLogProb, false, true);
h_sampled_nums_ = (uint32_t*)allocator_->reMalloc(h_sampled_nums_, sizeof(uint32_t) * max_batch_size, false, true);
is_allocate_persistant_buffer_ = true;
}
// Releases the buffers owned by this batch through `allocator_`.
//
// Two groups are handled independently, each guarded by its own flag:
//   * `is_allocate_buffer_`            - the per-step working buffers;
//   * `is_allocate_persistant_buffer_` - the persistent buffers, including the
//                                        per-state buffers held in `states_`.
// Each flag is cleared once its group has been released, so a second call
// does not free that group again.
//
// Host-side (`h_`-prefixed) buffers are released with the second argument of
// `free` set to `true`, matching the trailing `true` they are allocated with
// in this file (presumably the allocator's "is host memory" flag - confirm
// against the allocator interface).
template<typename T>
void LlamaBatch<T>::FreeBuffer()
{
    TM_LOG_DEBUG(__PRETTY_FUNCTION__);

    if (is_allocate_buffer_) {
        allocator_->free((void**)&context_decoder_input_buf_);
        allocator_->free((void**)&context_decoder_output_buf_);
        allocator_->free((void**)&context_decoder_ids_buf_);
        allocator_->free((void**)&lora_mask_buf_);

        allocator_->free((void**)&decoder_input_buf_);
        allocator_->free((void**)&decoder_output_buf_);

        allocator_->free((void**)&input_ids_buf_);
        allocator_->free((void**)&input_length_buf_);
        allocator_->free((void**)&context_length_buf_);
        allocator_->free((void**)&init_context_length_);

        allocator_->free((void**)&sequence_lengths_);

        allocator_->free((void**)&cu_block_counts_);
        allocator_->free((void**)&block_ptrs_);

        allocator_->free((void**)&logits_buf_);
        allocator_->free((void**)&local_logits_buf_);

        // The context-logits buffers are only released when they are non-null.
        if (local_context_logits_buf_) {
            allocator_->free((void**)&local_context_logits_buf_);
        }
        if (context_logits_buf_) {
            allocator_->free((void**)&context_logits_buf_);
        }

        allocator_->free((void**)&token_ids_buf_);

        allocator_->free((void**)&d_end_ids_buf_);
        allocator_->free((void**)&h_end_ids_buf_, true);

        allocator_->free((void**)&finished_buf_);
        allocator_->free((void**)&seq_limit_len_);

        allocator_->free((void**)&rope_theta_);

        allocator_->free((void**)&sampled_logprobs_);
        allocator_->free((void**)&sampled_indexes_);
        allocator_->free((void**)&sampled_nums_);

        is_allocate_buffer_ = false;
    }

    if (is_allocate_persistant_buffer_) {
        allocator_->free((void**)&d_stop_words_);
        allocator_->free((void**)&h_stop_words_, true);
        allocator_->free((void**)&d_bad_words_);
        allocator_->free((void**)&h_bad_words_, true);
        allocator_->free((void**)&d_random_seed_);
        allocator_->free((void**)&h_random_seed_, true);
        allocator_->free((void**)&d_curand_state_);
        allocator_->free((void**)&h_curand_state_, true);

        for (auto& s : states_) {
            // `h_prompt_length` is allocated per state alongside
            // `h_context_length`; it must be released here as well,
            // otherwise it is leaked.
            allocator_->free((void**)&s.h_prompt_length, true);
            allocator_->free((void**)&s.h_context_length, true);
            allocator_->free((void**)&s.h_finished, true);
            allocator_->free((void**)&s.h_rope_theta, true);
            allocator_->free((void**)&s.output_ids);
            allocator_->free((void**)&s.curand_state);
        }

        allocator_->free((void**)&h_cu_block_counts_, true);
        allocator_->free((void**)&h_block_ptrs_, true);
        allocator_->free((void**)&h_input_ids_buf_, true);
        allocator_->free((void**)&h_input_length_buf_, true);
        allocator_->free((void**)&h_seq_limit_len_, true);
        allocator_->free((void**)&h_output_ids_, true);

        // These three are allocated with the same flags as the other host
        // buffers above, so they are freed the same way.
        allocator_->free((void**)&h_sampled_logprobs_, true);
        allocator_->free((void**)&h_sampled_indexes_, true);
        allocator_->free((void**)&h_sampled_nums_, true);

        is_allocate_persistant_buffer_ = false;
    }
}
// Constructs a batch bound to `model`.
//
// Steps performed:
//   1. Copies the engine limits from `params` and borrows the stream,
//      allocator and cuBLAS wrapper from `model`.
//   2. Creates the `SequenceManager` that owns the KV-cache blocks.
//   3. Clamps `session_len_` to the number of tokens the available cache
//      blocks can hold (`max_block_count() * cache_block_seq_len`).
//   4. Sizes the per-state containers, assigns the three `states_` entries to
//      `state_`, `back_` and `incoming_`, and allocates the working and
//      persistent buffers.
//
// @param params               engine configuration (batch size, session length, cache settings, ...)
// @param cache_block_seq_len  value forwarded to the block config and used to convert
//                             a block count into a token count
// @param quant_policy         when non-zero it is used directly as the element bit width
//                             of the block config; when zero `bitsof<T>` is used
// @param model                model this batch runs on; dereferenced unconditionally
template<typename T>
LlamaBatch<T>::LlamaBatch(const EngineParams& params, int cache_block_seq_len, int quant_policy, LlamaV2<T>* model):
    max_batch_size_(params.max_batch_size),
    max_context_token_num_(params.max_context_token_num),
    session_len_(params.session_len),
    rank_(model->tensor_para_.rank_),
    debug_(model->debug_),
    step_length_(params.step_length),
    model_(model),
    data_type_(getTensorType<T>()),
    num_tokens_per_iter_(params.num_tokens_per_iter),
    extra_tokens_per_iter_(params.extra_tokens_per_iter),
    max_prefill_iters_(params.max_prefill_iters)
{
    stream_         = model_->stream_;
    allocator_      = model_->allocator_;
    cublas_wrapper_ = model_->cublas_wrapper_;

    // A zero `quant_policy` falls back to the native bit width of T.
    const int elem_bits = quant_policy ? quant_policy : bitsof<T>;

    // Callback handed to the sequence manager for querying free memory.
    auto get_free_size = [&] {
        return GetSyncFreeMemSize(*model_->shared_state_->barrier, model_->shared_state_->free_size);
    };

    SequenceManager::BlockConfig block_config{
        (int)model_->size_per_head_,
        (int)model_->local_kv_head_num_,
        cache_block_seq_len,
        // 0 when the cache elements use T's native width, otherwise bitsof<T>.
        elem_bits == bitsof<T> ? 0 : bitsof<T>,
        elem_bits,
    };

    sequence_manager_.reset(new SequenceManager{model_->num_layer_,
                                                block_config,
                                                params.cache_max_block_count,
                                                params.cache_chunk_size,
                                                params.enable_prefix_caching,
                                                model->tensor_para_.rank_,
                                                allocator_,
                                                get_free_size});

    // Upper bound on the session length imposed by the number of cache blocks.
    const size_t max_session_len = sequence_manager_->max_block_count() * cache_block_seq_len;
    if (max_session_len < session_len_) {
        if (rank_ == 0) {
            // `max_session_len` is a size_t, so it is printed with %zu; passing it
            // to %d is a format/argument mismatch in a printf-style call.
            TM_LOG_WARNING("No enough blocks for `session_len` (%d), `session_len` truncated to %zu.",
                           session_len_,
                           max_session_len);
        }
        session_len_ = max_session_len;
    }

    FT_CHECK(max_context_token_num_ >= session_len_);

    for (auto& s : states_) {
        s.requests.resize(max_batch_size_);
        s.sequences.resize(max_batch_size_);
        s.seq_len_limit.resize(max_batch_size_);
    }

    state_    = &states_[0];
    back_     = &states_[1];
    incoming_ = &states_[2];

    AllocateBuffer(max_batch_size_, session_len_, cache_block_seq_len);
    AllocatePersistantBuffer(max_batch_size_, cache_block_seq_len);
}
template<typename T>
void LlamaBatch<T>::InitializeSampling(const GenerationState& g)
{
NvtxScope _("InitSampling");