/
te_compiler.cc
1205 lines (1081 loc) · 51.6 KB
/
te_compiler.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "./te_compiler.h"
#include <tvm/driver/driver_api.h>
#include <tvm/ir/attrs.h>
#include <tvm/ir/function.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/attrs/call.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/schedule.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/tir/transform.h>
#include <tvm/topi/tags.h>
#include <functional>
#include <limits>
#include <mutex>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
#include "../op/annotation/annotation.h"
#include "../op/call/call.h"
#include "../op/memory/device_copy.h"
#include "../transforms/device_aware_visitors.h"
#include "./te_compiler_cache.h"
#include "./utils.h"
namespace tvm {
namespace relay {
// TODO(@jroesch, @csullivan): declare directly elsewhere
backend::StaticMemoryPlan GraphPlanMemory(const Function& func);
namespace tec {
using namespace tvm::relay::transform;
TVM_REGISTER_OBJECT_TYPE(TECompilerNode);
class TECompilerImpl : public TECompilerNode {
public:
explicit TECompilerImpl(Optional<IRModule> opt_mod) {
// Make sure we don't collide with any existing globals in the module.
if (opt_mod) {
for (const auto& kv : opt_mod.value()->functions) {
name_map_[kv.first->name_hint] = 1;
}
}
}
// Lower the function.
CachedFunc Lower(const CCacheKey& key, std::function<String(String)> mangle_fn) {
return LowerInternal(key, mangle_fn)->cached_func;
}
CachedFunc Lower(const CCacheKey& key, const String mod_name) {
auto mangle_fn = [mod_name](String name) { return runtime::get_name_mangled(mod_name, name); };
return Lower(key, mangle_fn);
}
// For now, build one module per function.
PackedFunc JIT(const CCacheKey& key) final {
auto mangle_fn = [](String name) { return name; };
CCacheValue value = LowerInternal(key, mangle_fn);
if (value->packed_func != nullptr) {
return value->packed_func;
}
auto m = build(value->cached_func->funcs, key->target, Target(nullptr));
value->packed_func = m.GetFunction(value->cached_func->prim_fn_var->name_hint);
return value->packed_func;
}
CachedFunc LowerShapeFunc(const CCacheKey& key) final {
return LowerShapeFuncInternal(key)->cached_func;
}
IRModule GetLoweredFunctions() {
VLOG(1) << "GetLoweredFunctions";
IRModule mod;
// Extract lowered functions from the cache
for (const auto& it : cache_) {
auto source_func = it.first;
auto lowered_func = it.second;
IRModule lowered_mod = lowered_func->cached_func->funcs;
// Annotate functions with their target and put them in the return module
for (const auto& kv : lowered_mod->functions) {
const GlobalVar& var = kv.first;
const BaseFunc& func = kv.second;
// Only add functions that are not external functions
if (!func->GetAttr<String>(attr::kCompiler).defined()) {
ICHECK(func->IsInstance<tir::PrimFuncNode>())
<< "Expected all functions that are not external to be PrimFuncs, but found:"
<< std::endl
<< PrettyPrint(func);
const tir::PrimFunc& prim_func = Downcast<tir::PrimFunc>(func);
mod->Update(var, WithAttr(prim_func, tvm::attr::kTarget, source_func->target));
}
}
}
// Extract lowered dynamic shape functions from the shape cache
for (const auto& it : shape_func_cache_) {
auto source_func = it.first;
auto lowered_func = it.second;
auto target = source_func->target;
IRModule lowered_mod = lowered_func->cached_func->funcs;
// Annotate functions with their target and put them in the return module
for (auto kv : lowered_mod->functions) {
const GlobalVar& var = kv.first;
const BaseFunc& func = kv.second;
const tir::PrimFunc& prim_func = Downcast<tir::PrimFunc>(func);
mod->Update(var, WithAttr(prim_func, tvm::attr::kTarget, source_func->target));
}
}
return mod;
}
void AddExterns(IRModule module) {
// Everything tagged with "Compiler" has been compiled, so remove those definitions.
std::vector<GlobalVar> to_be_deleted;
for (const auto& kv : module->functions) {
if (kv.second->GetAttr<String>(attr::kCompiler).defined()) {
to_be_deleted.push_back(kv.first);
}
}
for (const auto& global_var : to_be_deleted) {
module->Remove(global_var);
}
// HOWEVER we still need a Relay definition to go with those now external functions, so
// retrieve them from the cache and mark them with "ExternalSymbol".
for (const auto& kv1 : cache_) {
auto src_func = kv1.first->source_func;
ICHECK(src_func.defined());
if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
for (const auto& kv2 : kv1.second->cached_func->funcs->functions) {
if (const auto* function_node = kv2.second.as<FunctionNode>()) {
// Abandon the existing function annotations.
// Unfortuantely, Optional<DictAttrs>() is indistinguishable from
// NullValue<DictAttrs>(), and DictAttrs() is nullptr, so to erase the attributes, we
// need pass in DictAttrs<Map<String, ObjectRef>()), which is a DictAttrs containing no
// attributes.
Function function =
WithFields(GetRef<Function>(function_node), function_node->params,
function_node->body, function_node->ret_type, function_node->type_params,
/* erase attributes */ DictAttrs(Map<String, ObjectRef>()));
// Mark function as 'extern' using the "ExternalSymbol" attribute.
function = WithAttr(std::move(function), attr::kExternalSymbol, kv2.first->name_hint);
module->Add(kv2.first, function);
}
}
}
}
}
Array<tvm::runtime::Module> LowerExternalFunctions() {
Array<tvm::runtime::Module> ret;
std::vector<CCacheKey> cached_ext_funcs;
for (const auto& it : cache_) {
auto src_func = it.first->source_func;
ICHECK(src_func.defined());
Optional<String> opt_compiler = src_func->GetAttr<String>(attr::kCompiler);
if (opt_compiler.defined()) {
Optional<String> opt_symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(opt_symbol_name.defined()) << "No external symbol is set for:" << std::endl
<< PrettyPrint(src_func);
VLOG(1) << "using external codegen '" << opt_compiler.value() << "' for name '"
<< opt_symbol_name.value() << "' and function:" << std::endl
<< PrettyPrint(src_func);
cached_ext_funcs.push_back(it.first);
std::string ext_name = "relay.ext." + opt_compiler.value();
auto pf = tvm::runtime::Registry::Get(ext_name);
ICHECK(pf) << "Failed to find the codegen tool for " << ext_name;
// No need to keep compiler attribute at this point, functions have been
// extracted for specific codegen.
src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue<ObjectRef>());
VLOG_CONTEXT << ext_name;
runtime::Module ext_mod = (*pf)(src_func);
if (ext_mod.defined()) {
if (ext_mod->GetFunction(opt_symbol_name.value(), /*query_imports=*/true) == nullptr) {
// It's possible the codegen yielded C or C++ tracked separately and thus the
// returned runtime module can be empty.
VLOG(1) << "Unable to find definition for the external function '"
<< opt_symbol_name.value()
<< "' in the runtime module generated by external codegen '"
<< opt_compiler.value() << "'";
}
ret.push_back(ext_mod);
} else {
// A warning only so that we can write unit tests which can return an empty runtime
// module.
LOG(WARNING) << "No external runtime module was generated by external codegen '"
<< opt_compiler.value() << "'";
}
}
}
// No need to cache external functions as we collected them all to create
// external runtime modules.
for (const auto& it : cached_ext_funcs) {
cache_.erase(it);
}
return ret;
}
Map<GlobalVar, String> GetDeviceContexts() { return device_contexts_; }
void SetDeviceContexts(const Map<GlobalVar, String>& device_contexts) {
device_contexts_ = device_contexts;
}
void Clear() final { cache_.clear(); }
// List all items in the cache.
Array<ObjectRef> ListItems() {
std::lock_guard<std::mutex> lock(mutex_);
Array<ObjectRef> items;
for (auto& kv : cache_) {
items.push_back(kv.first);
items.push_back(kv.second);
}
return items;
}
/*!
* \brief Get the cache key of the function that is being lowered currently
* \return the cache key
*/
CCacheKey GetCurrentCCacheKey() { return cur_ccache_key_; }
private:
// implement lowered func
CCacheValue LowerInternal(const CCacheKey& key, std::function<String(String)> mangle_fn) {
VLOG(1) << "lowering:" << std::endl
<< PrettyPrint(key->source_func) << std::endl
<< "for target:" << std::endl
<< key->target->ToDebugString();
std::lock_guard<std::mutex> lock(mutex_);
CCacheValue value;
auto it = cache_.find(key);
if (it != cache_.end()) {
VLOG(1) << "already lowered to name:" << std::endl
<< PrettyPrint(it->second->cached_func->prim_fn_var);
it->second->use_count += 1;
if (it->second->cached_func.defined()) return it->second;
value = it->second;
} else {
value = CCacheValue(make_object<CCacheValueNode>());
value->use_count = 1;
cache_[key] = value;
}
cur_ccache_key_ = key;
Optional<String> opt_compiler = key->source_func->GetAttr<String>(attr::kCompiler);
if (opt_compiler.defined()) {
// Don't compile now since we don't have anywhere to put the resulting runtime module.
// Instead place the original definition in the cache and wait for LowerExternalFunctions.
IRModule ir_module;
Optional<String> opt_global_symbol =
key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(opt_global_symbol.defined()) << "External function has not been attached a name yet.";
// Note that the source_func may already be bound to a global function in the module
// we are compiling, in which case we should not attempt to make its name unique w.r.t.
// the module's globals. Furthermore, the external codegen tool must bind the compiled
// function to the "global_symbol" attribute on the source_func. So do not use GetUniqueName
// here.
auto target = Target("ext_dev");
auto global_var = GlobalVar(opt_global_symbol.value());
global_var->checked_type_ = key->source_func->checked_type();
ir_module->Add(global_var, key->source_func);
value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule{nullptr},
tir::PrimFunc{nullptr}, {}, ir_module);
// Collect these here as it's removed in LowerExternalFunctions()
device_contexts_.Set(value->cached_func->prim_fn_var, opt_compiler.value());
VLOG(1) << "preparing to use external codegen '" << opt_compiler.value()
<< "' with name:" << std::endl
<< PrettyPrint(value->cached_func->prim_fn_var) << std::endl
<< "and definitions:" << std::endl
<< PrettyPrint(value->cached_func->funcs);
return value;
}
// Enforce use the target.
With<Target> target_scope(key->target);
ICHECK(!value->cached_func.defined());
value->cached_func = PrimFuncFor(key->source_func, key->target, [&](std::string name) {
auto mangled = mangle_fn(name);
return GetUniqueName(mangled, &name_map_);
});
if (value->cached_func->prim_func.defined()) {
VLOG(1) << "Lowering PrimFunc";
IRModule lowered = tvm::LowerPrimFunc(value->cached_func->prim_func.value(),
value->cached_func->prim_fn_var->name_hint, false);
ICHECK_EQ(lowered->functions.size(), 1);
for (const auto& kv : lowered->functions) {
value->cached_func->funcs->Add(value->cached_func->prim_fn_var, kv.second);
}
} else {
// NOTE: array will copy on write.
Array<te::Tensor> all_args = Array<te::Tensor>(value->cached_func->inputs);
for (te::Tensor arg : value->cached_func->outputs) {
all_args.push_back(arg);
}
Array<runtime::NDArray> all_consts;
for (auto kv : value->cached_func->constant_tensors) {
all_args.push_back(kv.second);
all_consts.push_back(kv.first->data);
}
// lower the function
std::unordered_map<te::Tensor, tir::Buffer> binds;
auto func_name = value->cached_func->prim_fn_var->name_hint;
VLOG(1) << "scheduling";
IRModule scheduled_module =
tvm::LowerSchedule(value->cached_func->schedule, all_args, func_name, binds);
scheduled_module->Update(tir::transform::BindParams(all_consts)(scheduled_module));
// Unfortunately the above machinery creates its own GlobalVars instead of using *the*
// GlobalVar we established above. Fix this before the confusion spreads any further.
// TODO(mbs): LowerSchedule should be given prim_fn_gvar instead of func_name.
for (const auto& kv : scheduled_module->functions) {
GlobalVar global_var = kv.first->name_hint == value->cached_func->prim_fn_var->name_hint
? value->cached_func->prim_fn_var
: kv.first;
auto func = kv.second;
// Propagate the structural hash of the relay function to the tir
// function so associations can be made between the two.
Optional<String> hash = key->source_func->attrs.GetAttr<String>("hash");
if (hash) {
func = WithAttrs(Downcast<tir::PrimFunc>(func), {{String("hash"), hash.value()}});
}
value->cached_func->funcs->Add(global_var, func);
}
ICHECK(value->cached_func->funcs->Lookup(value->cached_func->prim_fn_var)
.as<tir::PrimFuncNode>());
}
VLOG(1) << "lowered to name:" << std::endl
<< PrettyPrint(value->cached_func->prim_fn_var) << std::endl
<< "with definitions:" << std::endl
<< PrettyPrint(value->cached_func->funcs);
return value;
}
// implement lowered shape func
CCacheValue LowerShapeFuncInternal(const CCacheKey& key) {
VLOG(1) << "lowering dynamic shape function for:" << std::endl
<< PrettyPrint(key->source_func) << std::endl
<< "for target:" << std::endl
<< key->target->ToDebugString();
std::lock_guard<std::mutex> lock(mutex_);
CCacheValue value;
auto it = shape_func_cache_.find(key);
if (it != shape_func_cache_.end()) {
it->second->use_count += 1;
if (it->second->cached_func.defined()) return it->second;
value = it->second;
} else {
value = CCacheValue(make_object<CCacheValueNode>());
value->use_count = 0;
shape_func_cache_[key] = value;
}
// Enforce use the target.
With<Target> target_scope(key->target);
ICHECK(!value->cached_func.defined());
using tvm::transform::PassContext;
With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
value->cached_func = ShapeFuncFor(key->source_func, key->target, [&](std::string name) {
return GetUniqueName(name, &name_map_);
});
ICHECK(
value->cached_func->funcs->Lookup(value->cached_func->prim_fn_var).as<tir::PrimFuncNode>());
VLOG(1) << "lowered to name:" << std::endl
<< PrettyPrint(value->cached_func->prim_fn_var) << std::endl
<< "with definitions:" << std::endl
<< PrettyPrint(value->cached_func->funcs);
return value;
}
Map<String, Integer> GetOpWeights() const {
Map<String, Integer> weights;
for (const auto& kv : cache_) {
auto value = kv.second;
auto name = value->cached_func->prim_fn_var->name_hint;
weights.Set(name, value->use_count);
}
return weights;
}
// TODO(mbs): Hold the output module here and reduce the cache_ to just be from
// Function to GlobalVar.
/*! \brief compiler cache lock*/
std::mutex mutex_;
/*! \brief internal name map to get an unique name */
std::unordered_map<std::string, int> name_map_;
/*! \brief internal compiler cache */
std::unordered_map<CCacheKey, CCacheValue> cache_;
/*! \brief internal compiler cache for shape funcs */
std::unordered_map<CCacheKey, CCacheValue> shape_func_cache_;
/*! \brief the cache key of the function that is being lowered currently*/
CCacheKey cur_ccache_key_;
/*! \brief Map of GlobalVar to C Device API context names */
Map<GlobalVar, String> device_contexts_;
};
TECompiler::TECompiler(Optional<IRModule> opt_mod) {
  // Construct the implementation node directly into the ObjectRef's payload.
  data_ = make_object<TECompilerImpl>(std::move(opt_mod));
}
/*! \brief The global TE compiler */
// TODO(mbs): To be terminated with extreme prejudice.
TECompiler& TECompiler::Global() {
  // Intentionally leaked singleton: never destroyed, so it cannot participate in
  // static-destruction-order problems at process exit. Initialization of the
  // function-local static is thread-safe (C++11 magic statics).
  static TECompiler* inst = new TECompiler(make_object<TECompilerImpl>(Optional<IRModule>()));
  return *inst;
}
// Pass-context flags selecting which automatic scheduling backend is consulted
// during lowering.
TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_auto_scheduler", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_meta_schedule", Bool);

// FFI bindings exposing the TE compiler to the Python side (tests and tooling).
TVM_REGISTER_GLOBAL("relay.backend._TECompilerGlobal").set_body_typed([]() {
  return TECompiler::Global();
});

TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey")
    .set_body_typed([](Function source_func, Target target) {
      return CCacheKey(source_func, target);
    });

TVM_REGISTER_GLOBAL("relay.backend._make_LoweredOutput")
    .set_body_typed([](tvm::Array<te::Tensor> outputs, OpImplementation impl) {
      return LoweredOutput(outputs, impl);
    });

TVM_REGISTER_GLOBAL("relay.backend._TECompilerClear").set_body_typed([](TECompiler self) {
  self->Clear();
});

TVM_REGISTER_GLOBAL("relay.backend._TECompilerLower")
    .set_body_typed([](TECompiler self, CCacheKey key, const String mod_name) {
      return self->Lower(key, mod_name);
    });

TVM_REGISTER_GLOBAL("relay.backend._TECompilerJIT")
    .set_body_typed([](TECompiler self, CCacheKey key) { return self->JIT(key); });

TVM_REGISTER_GLOBAL("relay.backend._TECompilerListItems").set_body_typed([](TECompiler self) {
  // ListItems lives only on the concrete impl, so downcast from the interface node.
  TECompilerImpl* ptr = dynamic_cast<TECompilerImpl*>(self.operator->());
  ICHECK(ptr != nullptr);
  return ptr->ListItems();
});

// Maps original (pre-rewrite) expressions to their rewritten forms.
using AnalysisRemapping = std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual>;
/*!
* \brief Rewrites call expressions to Relay Functions marked as "primitive"
* to calls to the corresponding TIR PrimFunc for the appropriate target.
*
* \code
* %0 = fn(...) { prim_op(...) } OR let %p = fn(...) { prim_op(...) }
* ... %0(...) ... ... %p(...) ...
* ==>
* def @q(..., target=<target>) { <tir body> }
* ... @q(...) ...
* \endcode
*
* Requires FuseOps, ToANormalForm, EtaExpand and InferType to have run.
*
* FuseOps is needed to identify and lift all prim op calls:
* \code
* ... prim_op(...) ...
* ==>
* %0 = fn(...) { prim_op(...) }
* ... %0(...) ...
* \endcode
*
* ToANormalForm is needed so we only need to consider vars and function literals as the call
* target.
*
 * EtaExpand is needed to ensure all calls to primitives are direct:
* \code
* let %p1 = fn(...) { prim_op1(...) }
* let %p2 = fn(...) { prim_op2(...) }
* let %p = if (...) { %p1 } else { %p2 }
* ... %p(...) ...
* ==>
* let %p1 = fn(...) { prim_op1(...) }
* let %p2 = fn(...) { prim_op2(...) }
* let %p = fn(...) { if (...) { %p1(...) } else { %p2(...) } }
* ... %p(...) ...
* \endcode
*/
class LowerTensorExprMutator : public DeviceAwareExprMutator {
 public:
  LowerTensorExprMutator(const IRModule& module, ProcessFn process_fn, String module_name,
                         TECompiler compiler, VirtualDevice host_virtual_device)
      : DeviceAwareExprMutator(module),
        module_(module),
        process_fn_(std::move(process_fn)),
        module_name_(std::move(module_name)),
        compiler_(std::move(compiler)),
        host_virtual_device_(std::move(host_virtual_device)),
        debug_op_(Op::Get("debug")) {}

  /*!
   * \brief Returns the primitive function associated with \p expr, or nullptr if none.
   */
  BaseFunc ResolveToPrimitive(const Expr& expr) {
    // NOTE: We can't assume expr->checked_type_ is defined, so can't early exit for first-order
    // expressions.
    if (const auto* global_var_node = expr.as<GlobalVarNode>()) {
      if (!module_->ContainGlobalVar(global_var_node->name_hint)) {
        // TODO(mbs): extern function cleanup
        // Assume the function is extern and thus no longer in the IRModule.
        return {};
      } else {
        // Chase the global binding: it may be a primitive Function or a PrimFunc.
        BaseFunc base_func = module_->Lookup(GetRef<GlobalVar>(global_var_node));
        return ResolveToPrimitive(base_func);
      }
    } else if (const auto* prim_func_node = expr.as<tir::PrimFuncNode>()) {
      return GetRef<tir::PrimFunc>(prim_func_node);
    } else if (const auto* var_node = expr.as<VarNode>()) {
      auto itr = primitive_functions_.find(var_node);
      if (itr == primitive_functions_.end()) {
        // Not bound to a primitive function.
        return {};
      } else {
        return itr->second;
      }
    } else if (const auto* function_node = expr.as<FunctionNode>()) {
      if (!function_node->HasNonzeroAttr(attr::kPrimitive)) {
        // Not marked as primitive by FuseOps.
        return {};
      }
      if (const auto* call_node = function_node->body.as<CallNode>()) {
        if (call_node->op == debug_op_) {
          // Debug 'primitives' are not lowered.
          return {};
        }
      }
      return GetRef<Function>(function_node);
    } else {
      return {};
    }
  }

  /*!
   * \brief Lowers the primitive function \p func to TIR for ultimate execution
   * on a device with configuration \p target. Returns the global var bound
   * to the TIR implementation, and attributes to attach to the call to identify it as
   * a TIR call.
   */
  Expr MakeLoweredCall(Function func, Array<Expr> visited_args, Span span, Target target) {
    CCacheKey key = CCacheKey(func, target);
    CachedFunc cfunc = compiler_->Lower(key, module_name_);
    ICHECK(cfunc.defined());
    auto opt_compiler = func->GetAttr<String>(attr::kCompiler);
    // Add some metadata on top of the *original function* and invoke the callback so it can
    // be captured.
    // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT
    Map<GlobalVar, tir::PrimFunc> prim_fns;
    Array<GlobalVar> all_prim_fn_vars;
    for (const auto& kv : cfunc->funcs->functions) {
      if (opt_compiler) {
        // We expect just the original func but with just the ExternalSymbol attribute signaling
        // the function (will be) compiled externally.
        ICHECK(kv.second.as<FunctionNode>())
            << PrettyPrint(kv.first) << " must be bound to an (external) Function";
      } else {
        // We expect one or more PrimFuncs, one of which corresponds to 'the' lowered primitive
        // (and the rest in support of that via tir::Calls).
        ICHECK(kv.second.as<tir::PrimFuncNode>())
            << PrettyPrint(kv.first) << " must be bound to a PrimFunc";
        prim_fns.Set(kv.first, Downcast<tir::PrimFunc>(kv.second));
        all_prim_fn_vars.push_back(kv.first);
      }
    }
    Function func_with_metadata = func;
    func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", cfunc->prim_fn_var);
    func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns);
    func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, cfunc->target);
    this->process_fn_(func_with_metadata);
    CallLoweredAttrs call_lowered_attrs;
    // Non-External Relay Function
    // TODO(mbs): "reshape" cleanup.
    if (!opt_compiler && func->HasNonzeroAttr(attr::kReshapeOnly)) {
      call_lowered_attrs.metadata.Set(attr::kReshapeOnly, tvm::Integer(1));
    }
    call_lowered_attrs.metadata.Set("relay_attrs", func->attrs);
    call_lowered_attrs.metadata.Set("all_prim_fn_vars", all_prim_fn_vars);
    if (IsDynamic(func->ret_type)) {
      // Also lower the companion dynamic shape function.
      // Shape function keys use the underlying primitive function as their 'function',
      // but the generic 'cpu' target as the target since all shape functions run
      // on the host cpu irrespective of where the primitive runs.
      CCacheKey shape_key(func, host_virtual_device_->target);
      CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key);
      // Capture the shape function's global var and parameters 'states' in call
      // annotations so calling convention can be recovered.
      // TODO(mbs): Shape cleanup.
      call_lowered_attrs.metadata.Set("prim_shape_fn_var", lowered_shape_func->prim_fn_var);
      call_lowered_attrs.metadata.Set("prim_shape_fn_states",
                                      lowered_shape_func->shape_func_param_states);
      call_lowered_attrs.metadata.Set("prim_shape_fn_num_inputs",
                                      Integer(static_cast<int>(lowered_shape_func->inputs.size())));
      call_lowered_attrs.metadata.Set(
          "prim_shape_fn_num_outputs",
          Integer(static_cast<int>(lowered_shape_func->outputs.size())));
      Array<GlobalVar> all_prim_shape_fn_vars;
      for (const auto& kv : lowered_shape_func->funcs->functions) {
        // ICHECK for consistency with the rest of this file (was CHECK).
        ICHECK(kv.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
        all_prim_shape_fn_vars.push_back(kv.first);
      }
      call_lowered_attrs.metadata.Set("all_prim_shape_fn_vars", all_prim_shape_fn_vars);
    }
    return CallLowered(cfunc->prim_fn_var, std::move(visited_args), std::move(call_lowered_attrs),
                       std::move(span));
  }

  std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value) final {
    Var new_var = Downcast<Var>(Mutate(var));
    Expr new_value = Mutate(value);
    BaseFunc prim_func = ResolveToPrimitive(new_value);
    if (prim_func.defined()) {
      // Remember let var is bound (possibly indirectly) to a primitive function.
      primitive_functions_.emplace(var.get(), prim_func);
    }
    return {new_var, new_value};
  }

  Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node) final {
    BaseFunc prim_func = ResolveToPrimitive(post_let_node->value);
    if (prim_func.defined()) {
      // Leaving let var scope
      primitive_functions_.erase(pre_let_node->var.get());
    }
    return DeviceAwareExprMutator::PostVisitLet_(pre_let_node, post_let_node);
  }

  Expr DeviceAwareVisitExpr_(const FunctionNode* function_node) override {
    if (function_node->HasNonzeroAttr(attr::kPrimitive) ||
        function_node->GetAttr<String>(attr::kExternalSymbol)) {
      // Nothing to lower inside primitive/external functions.
      return GetRef<Function>(function_node);
    } else {
      return DeviceAwareExprMutator::DeviceAwareVisitExpr_(function_node);
    }
  }

  Expr DeviceAwareVisitExpr_(const CallNode* call_node) override {
    // We can see five forms of calls:
    //  1. A 'normal' Relay call to a Function with the "primitive" attribute. We will need
    //     to lower that to a global PrimFunc and rewrite the call to:
    //       call_lowered(@new_global, (arg1, ..., argn), <attributes>)
    //     However there are a few special forms which are excluded from this treatment, see
    //     below.
    //  2. A 'normal' Relay call to a Function with the "compiler" attribute. We will need
    //     to invoke the appropriate BYOC toolchain function to yield a runtime module and
    //     rewrite the call to the same form as above.
    //  3. A 'normal' Relay call to a PrimFunc which has already been supplied via a global
    //     definition. We rewrite to use the call_lowered form, but otherwise nothing else
    //     needs to be done.
    //  4. A 'normal' Relay call to a Relay Function without any special attribute. These
    //     calls are not changed.
    //  5. A call_lowered call from an earlier invocation of this pass.
    // Note that ResolveToPrimitive will yield non-null only for cases 1-3.

    // Look for (possibly indirect) calls to primitives.
    BaseFunc primitive_func = ResolveToPrimitive(call_node->op);
    if (!primitive_func.defined()) {
      // Not a call to a primitive function we need to rewrite.
      if (const auto* function_node = call_node->op.as<FunctionNode>()) {
        process_fn_(GetRef<Function>(function_node));
      }
      return DeviceAwareExprMutator::DeviceAwareVisitExpr_(call_node);
    }

    // Prepare the arguments.
    Array<Expr> new_args;
    for (const auto& arg : call_node->args) {
      new_args.push_back(VisitExpr(arg));
    }

    // Special case: device_copies are left as calls to primitive operators
    // (thus undoing FuseOps) so that each backend can handle them directly.
    // TODO(mbs): device_copy cleanup. Would be better for FuseOps to just leave device_copy alone.
    if (const auto* function_node = primitive_func.as<FunctionNode>()) {
      DeviceCopyProps device_copy_props = GetDeviceCopyProps(function_node->body);
      if (device_copy_props.body.defined()) {
        ICHECK_EQ(new_args.size(), 1);
        return DeviceCopy(new_args[0], device_copy_props.src_virtual_device,
                          device_copy_props.dst_virtual_device);
      }
    }

    // Special case: If already lowered by other means then so we don't need to mutate
    // the call but we do need to mutate the arguments
    if (const auto* prim_func_node = primitive_func.as<tir::PrimFuncNode>()) {
      // Function should already be Target annotated by this point
      // but the TE Compiler metadata is still needed for the callback
      // TODO(Mousius) - Robustify this to not assume we're in the GlobalVar for Target Hooks
      GlobalVar prim_func_var = Downcast<GlobalVar>(call_node->op);
      tir::PrimFunc prim_func = GetRef<tir::PrimFunc>(prim_func_node);
      Map<GlobalVar, tir::PrimFunc> prim_fns = {{prim_func_var, prim_func}};
      tir::PrimFunc func_with_metadata = WithAttrs(prim_func, {
                                                                  {"prim_fn_var", prim_func_var},
                                                                  {"prim_funcs", prim_fns},
                                                              });
      ICHECK(!IsDynamic(call_node->checked_type()));
      CallLoweredAttrs call_lowered_attrs;
      call_lowered_attrs.metadata.Set("relay_attrs", primitive_func->attrs);
      process_fn_(func_with_metadata);
      ICHECK(call_node->type_args.empty()) << "lowered functions cannot be polymorphic";
      return CallLowered(prim_func_var, std::move(new_args), std::move(call_lowered_attrs),
                         call_node->span);
    }

    // Typical case: call to fused primitive Relay Function.
    // Find the desired target device.
    Target target;
    if (primitive_func->GetAttr<String>(attr::kCompiler).defined()) {
      // The generic 'external device' target.
      // TODO(mbs): Retire once replaced unified BYOC compiler and target machinery
      target = Target("ext_dev");
    } else {
      // The target corresponding to the call_node expression's annotation.
      VirtualDevice virtual_device = GetVirtualDevice(GetRef<Call>(call_node));
      ICHECK(!virtual_device->IsFullyUnconstrained());
      target = virtual_device->target;
      ICHECK(target.defined());
    }

    // Lower the primitive function for that target.
    Function function = Downcast<Function>(primitive_func);
    ICHECK(call_node->type_args.empty()) << "lowered functions cannot be polymorphic";
    return MakeLoweredCall(function, std::move(new_args), call_node->span, target);
  }

  IRModule module_;
  ProcessFn process_fn_;
  // Map from in-scope let-bound variables to Functions known to be primitive, or PrimFuncs which
  // have already been lowered. We'll rewrite these to the fresh global vars bound to the lowered
  // primitive function as we go. Those vars will be bound in the target device-type specific
  // module we'll ultimately emit for each required device-type. Note that a primitive may be
  // lowered for multiple device types, each which will be assigned a fresh var.
  std::unordered_map<const VarNode*, BaseFunc> primitive_functions_;
  String module_name_;
  TECompiler compiler_;
  /*!
   * \brief The \p VirtualDevice for the host, which is where all shape-related data and computation
   * must live.
   */
  VirtualDevice host_virtual_device_;
  // Cache ops that need to be frequently used later to reduce lookup overhead.
  const Op& debug_op_;
};
/*!
 * \brief Returns the \p Target to use when lowering for device type \p dev_type, looked up
 * in \p targets.
 *
 * \param dev_type Device type to find a target for. A value of 0 means the device name
 *                 cannot be recovered and is reported as "unknown" in diagnostics.
 * \param targets Map from device type to target.
 * \return The matching target. In the homogeneous case (exactly one entry) that sole target
 *         is returned regardless of \p dev_type. Fatal error if no target matches.
 */
Target GetTargetFromInteger(DLDeviceType dev_type, tec::TargetMap targets) {
  if (targets.size() == 1) {
    // The homogeneous execution case, return the only target.
    const auto& it = targets.begin();
    return (*it).second;
  } else {
    // The heterogeneous execution case, return the target associated with the
    // given device type.
    // If "dev_type" equals to 0, the device name only can be got from
    // "targets", and it may not be "llvm", so here just set it to "unknown".
    std::string dev_name = "unknown";
    if (dev_type != 0) {
      dev_name = runtime::DeviceName(dev_type);
    }
    if (targets.count(dev_type) == 0) {
      // Build a single diagnostic message listing all known targets before failing.
      std::stringstream msg;
      msg << "No target is specified for provided device name: `" << dev_name << "`\n\n"
          << dev_name << " mapped to device type (" << dev_type
          << ") which was not found in the target map.\n"
          << "Available targets: \n";
      // Iterate by const reference to avoid copying each map entry.
      for (const auto& target : targets) {
        msg << "  " << target.first << " -> " << target.second << "\n";
      }
      LOG(FATAL) << msg.str();
    }
    return targets[dev_type];
  }
}
/*!
 * \brief Returns a function pass which rewrites all calls to primitive functions inside a
 * Relay function into calls to the corresponding lowered (TE-compiled) functions.
 *
 * \param module_name Name of the module being compiled, used to name lowered functions.
 * \param compiler The TE compiler instance performing the actual lowering.
 * \param process_fn Callback invoked on each function as it is processed.
 * \param host_virtual_device Virtual device on which all shape computation must reside.
 */
Pass LowerTensorExpr(const String& module_name, TECompiler compiler, ProcessFn process_fn,
                     VirtualDevice host_virtual_device) {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function func, IRModule module, PassContext ctx) -> Function {
        // A fresh mutator per function keeps any per-function state isolated.
        LowerTensorExprMutator mutator(module, process_fn, module_name, compiler,
                                       host_virtual_device);
        Expr lowered = mutator.Mutate(func);
        return Downcast<Function>(lowered);
      };
  return CreateFunctionPass(pass_func, 0, "LowerTensorExpr", {});
}
/*!
 * \brief Compute the \p FunctionInfo (per-target workspace, io and constant byte sizes) for
 * the "main" function of \p mod, given the memory planner's storage assignments.
 *
 * \param mod Module whose "main" function is analyzed.
 * \param targets Map from device type to the target used for that device.
 * \param storage_info_map Map from each planned (sub-)expression to its storage assignment
 *        (storage ids and virtual devices).
 * \return FunctionInfo summarizing, per target, the workspace size (max over tensors sharing
 *         each storage id), total io size (inputs/outputs) and total constant size.
 */
backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, tec::TargetMap targets,
                                              Map<Expr, backend::StorageInfo> storage_info_map) {
  Function func = Downcast<Function>(mod->Lookup("main"));
  VLOG_CONTEXT << "UpdateMainWorkspaceSize";
  VLOG(1) << "calculating FunctionInfo for main:" << std::endl << PrettyPrint(func);
  for (const auto& kv : targets) {
    VLOG(1) << " target " << kv.first << " = " << kv.second->str();
  }
  // This is a Map<device,Map<storage_id, size>>
  // TODO(mbs): Collapsing VirtualDevices to just device type.
  // NOTE(review): the accumulator value types below are int while sizes are computed as
  // int64_t, so sizes over 2GiB would narrow — confirm whether that range matters here.
  std::unordered_map<DLDeviceType, std::unordered_map<int, int>, backend::EnumClassHash>
      sid_workspace;
  // This is a Map<device, size_of_inputs_and_outputs>
  std::unordered_map<DLDeviceType, int, backend::EnumClassHash> device_io;
  // This is a Map<device, size_of_constants>
  std::unordered_map<DLDeviceType, int, backend::EnumClassHash> device_consts;
  // Initialize the mapping from all storage identifiers to workspace sizes,
  // the amount of device io, and the device constants.
  for (const auto& kv : storage_info_map) {
    const backend::StorageInfo& storage_info = kv.second;
    const std::vector<int64_t>& storage_ids = storage_info->storage_ids;
    const std::vector<VirtualDevice>& virtual_devices = storage_info->virtual_devices;
    CHECK_EQ(storage_ids.size(), virtual_devices.size());
    for (uint32_t i = 0; i < virtual_devices.size(); i++) {
      DLDeviceType device_type = virtual_devices[i]->device_type();
      sid_workspace[device_type][storage_ids[i]] = 0;
      device_io[device_type] = 0;
      device_consts[device_type] = 0;
    }
  }
  // Iterate the storage map to compute all the tensor sizes in the program.
  // There are 3 cases in this code:
  //
  // First we need to compute the sizes of all
  // inline constants.
  //
  // Second we compute the size of any bound variable as these are input and output
  // sizes of the program.
  //
  // Finally for all other expressions we check which storage identifier they have
  // been assigned and we compute the maximal size of the storage, as tensors can
  // share storage with other tensors which are the same size or larger.
  //
  // In this final case there is only one allocation for all tensors which share storage
  // which will be the maximal size of all tensors which were assigned to it.
  for (const auto& kv : storage_info_map) {
    const Expr& expr = kv.first;
    const backend::StorageInfo& storage_info = kv.second;
    int64_t size_bytes = backend::CalculateRelayExprSizeBytes(expr->checked_type());
    VLOG(1) << "expression:" << std::endl
            << PrettyPrint(expr) << std::endl
            << "of type:" << std::endl
            << PrettyPrint(expr->checked_type()) << std::endl
            << "has size " << size_bytes << " and storage info:" << std::endl
            << storage_info;
    const std::vector<int64_t>& storage_ids = storage_info->storage_ids;
    const std::vector<VirtualDevice>& virtual_devices = storage_info->virtual_devices;
    if (expr->IsInstance<ConstantNode>()) {
      // Case 1: inline constants contribute to the per-device constant total.
      for (const auto& virtual_device : virtual_devices) {
        DLDeviceType device_type = virtual_device->device_type();
        ICHECK_EQ(device_consts.count(device_type), 1);
        device_consts[device_type] += size_bytes;
      }
    } else if (expr->IsInstance<VarNode>() || expr.same_as(func->body)) {
      // Case 2: bound variables and the function result are program inputs/outputs.
      CHECK(size_bytes == 0 || virtual_devices.size() >= 1) << "must be at least one device";
      for (const auto& virtual_device : virtual_devices) {
        DLDeviceType device_type = virtual_device->device_type();
        device_io[device_type] += size_bytes;
      }
    } else {
      // Case 3: everything else shares workspace storage keyed by storage id.
      // TODO(@electriclilies): This code is never being called which means sid_workspace is not
      // updated.. This means that storage info is probably not being created correctly. Or is not
      // equivalent to what was here previously
      for (uint32_t i = 0; i < storage_ids.size(); i++) {
        // Here we record the largest size of the tensor
        // that share the same storage id, because storage_id will
        // be shared between multiple tensors that are not live simultaneously.
        DLDeviceType device_type = virtual_devices[i]->device_type();
        if (size_bytes > sid_workspace[device_type][storage_ids[i]]) {
          sid_workspace[device_type][storage_ids[i]] = size_bytes;
        }
      }
    }
  }
  // This is a Map<device, workspace_size>
  std::unordered_map<DLDeviceType, int, backend::EnumClassHash> device_workspace;
  // Once we know the sizes of sids, we need to accumulate per device
  for (const auto& dev_sid_size : sid_workspace) {
    auto dev = dev_sid_size.first;
    device_workspace[dev] = 0;
    for (const auto& sid_size : dev_sid_size.second) {
      device_workspace[dev] += sid_size.second;
    }
  }
  Map<Target, Integer> workspace_sizes;
  Map<Target, Integer> io_sizes;
  Map<Target, Integer> constant_sizes;
  Map<Target, tir::PrimFunc> tir_primfuncs;
  Map<Target, Function> relay_primfuncs;
  // Initialize all target workspaces to zero
  for (const auto& kv : targets) {
    auto tgt = kv.second;
    workspace_sizes.Set(tgt, 0);
  }
  // Translate the per-device-type totals into per-target entries.
  for (const auto& dev_and_size : device_workspace) {
    auto tgt = tec::GetTargetFromInteger(dev_and_size.first, targets);
    workspace_sizes.Set(tgt, dev_and_size.second);
    relay_primfuncs.Set(tgt, func);
  }
  for (const auto& dev_and_size : device_io) {
    auto tgt = tec::GetTargetFromInteger(dev_and_size.first, targets);
    io_sizes.Set(tgt, dev_and_size.second);
  }
  for (const auto& dev_and_size : device_consts) {
    auto tgt = tec::GetTargetFromInteger(dev_and_size.first, targets);
    ICHECK_EQ(constant_sizes.count(tgt), 0);
    constant_sizes.Set(tgt, dev_and_size.second);
  }
  backend::FunctionInfo func_info(std::move(workspace_sizes), std::move(io_sizes),
                                  std::move(constant_sizes), std::move(tir_primfuncs),
                                  std::move(relay_primfuncs));
  VLOG(1) << "func_info: " << func_info;
  // Return the local directly: std::move here would inhibit NRVO / the automatic move
  // (clang-tidy performance-no-automatic-move, -Wpessimizing-move).
  return func_info;
}
/*!
 * \brief Create the function metadata for an input function (i.e., calculate buffer
 * input/output sizes).
 * \param func The function to calculate function metadata for.
 * \param function_metadata The map that stores the metadata of all functions.
 */
void UpdateFunctionMetadata(BaseFunc func,
Map<String, backend::FunctionInfo>& function_metadata, // NOLINT(*)
Integer workspace_byte_alignment) {
VLOG_CONTEXT << "UpdateFunctionMetadata";
VLOG(1) << "updating function metadata for:" << std::endl << PrettyPrint(func);
// Originally UpdateFunctionMetadata took in CCachedFunc and looped through all the funcs stored
// there Now the goal is to take only one func because process_fn should be controlling the
// iteration However, to do the workspace calculations we need the primfuncs. So process_fn
// needs to either access the cached funcs or be directly passed primfuncs This is bad and
// ideally we don't want process_fn to look at primfuncs There's also the question now of what
// the function metadatas are and how they are used if we can do something else to replicate the
// behavior of the function metadatas that might be good (ie annotating functions or something).