Skip to content

Commit

Permalink
offset support in table
Browse files Browse the repository at this point in the history
  • Loading branch information
kripken committed Aug 15, 2016
1 parent 086c4c0 commit 113efca
Show file tree
Hide file tree
Showing 45 changed files with 240 additions and 102 deletions.
10 changes: 7 additions & 3 deletions src/asm2wasm.h
Expand Up @@ -660,13 +660,17 @@ void Asm2WasmBuilder::processAsm(Ref ast) {
// TODO: when not using aliasing function pointers, we could merge them by noticing that
// index 0 in each table is the null func, and each other index should only have one
// non-null func. However, that breaks down when function pointer casts are emulated.
functionTableStarts[name] = wasm.table.names.size(); // this table starts here
if (wasm.table.segments.size() == 0) {
wasm.table.segments.emplace_back(wasm.allocator.alloc<Const>()->set(Literal(uint32_t(0))));
}
auto& segment = wasm.table.segments[0];
functionTableStarts[name] = segment.data.size(); // this table starts here
Ref contents = value[1];
for (unsigned k = 0; k < contents->size(); k++) {
IString curr = contents[k][1]->getIString();
wasm.table.names.push_back(curr);
segment.data.push_back(curr);
}
wasm.table.initial = wasm.table.max = wasm.table.names.size();
wasm.table.initial = wasm.table.max = segment.data.size();
} else {
abort_on("invalid var element", pair);
}
Expand Down
6 changes: 4 additions & 2 deletions src/binaryen-c.cpp
Expand Up @@ -729,10 +729,12 @@ void BinaryenSetFunctionTable(BinaryenModuleRef module, BinaryenFunctionRef* fun
}

auto* wasm = (Module*)module;
Table::Segment segment(wasm->allocator.alloc<Const>()->set(Literal(int32_t(0))));
for (BinaryenIndex i = 0; i < numFuncs; i++) {
wasm->table.names.push_back(((Function*)funcs[i])->name);
segment.data.push_back(((Function*)funcs[i])->name);
}
wasm->table.initial = wasm->table.max = wasm->table.names.size();
wasm->table.segments.push_back(segment);
wasm->table.initial = wasm->table.max = numFuncs;
}

// Memory. One per module
Expand Down
10 changes: 6 additions & 4 deletions src/passes/DuplicateFunctionElimination.cpp
Expand Up @@ -123,10 +123,12 @@ struct DuplicateFunctionElimination : public Pass {
replacerRunner.add<FunctionReplacer>(&replacements);
replacerRunner.run();
// replace in table
for (auto& name : module->table.names) {
auto iter = replacements.find(name);
if (iter != replacements.end()) {
name = iter->second;
for (auto& segment : module->table.segments) {
for (auto& name : segment.data) {
auto iter = replacements.find(name);
if (iter != replacements.end()) {
name = iter->second;
}
}
}
// replace in start
Expand Down
15 changes: 9 additions & 6 deletions src/passes/Print.cpp
Expand Up @@ -579,12 +579,15 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
if (curr->max && curr->max != Table::kMaxSize) o << ' ' << curr->max;
o << " anyfunc)\n";
doIndent(o, indent);
printOpening(o, "elem", true);
for (auto name : curr->names) {
o << ' ';
printName(name);
for (auto& segment : curr->segments) {
printOpening(o, "elem ", true);
visit(segment.offset);
for (auto name : segment.data) {
o << ' ';
printName(name);
}
o << ')';
}
o << ')';
}
void visitModule(Module *curr) {
currModule = curr;
Expand Down Expand Up @@ -652,7 +655,7 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
visitGlobal(child.get());
o << maybeNewLine;
}
if (curr->table.names.size() > 0) {
if (curr->table.segments.size() > 0 || curr->table.initial > 0 || curr->table.max != Table::kMaxSize) {
doIndent(o, indent);
visitTable(&curr->table);
o << maybeNewLine;
Expand Down
6 changes: 4 additions & 2 deletions src/passes/RemoveUnusedFunctions.cpp
Expand Up @@ -39,8 +39,10 @@ struct RemoveUnusedFunctions : public Pass {
root.push_back(module->getFunction(curr->value));
}
// For now, all functions that can be called indirectly are marked as roots.
for (auto& curr : module->table.names) {
root.push_back(module->getFunction(curr));
for (auto& segment : module->table.segments) {
for (auto& curr : segment.data) {
root.push_back(module->getFunction(curr));
}
}
// Compute function reachability starting from the root set.
DirectCallGraphAnalyzer analyzer(module, root);
Expand Down
6 changes: 4 additions & 2 deletions src/passes/ReorderFunctions.cpp
Expand Up @@ -38,8 +38,10 @@ struct ReorderFunctions : public WalkerPass<PostWalker<ReorderFunctions, Visitor
for (auto& curr : module->exports) {
counts[curr->value]++;
}
for (auto& curr : module->table.names) {
counts[curr]++;
for (auto& segment : module->table.segments) {
for (auto& curr : segment.data) {
counts[curr]++;
}
}
std::sort(module->functions.begin(), module->functions.end(), [this](
const std::unique_ptr<Function>& a,
Expand Down
24 changes: 24 additions & 0 deletions src/shell-interface.h
Expand Up @@ -86,6 +86,8 @@ struct ShellExternalInterface : ModuleInstance::ExternalInterface {
}
} memory;

std::vector<Name> table;

ShellExternalInterface() : memory() {}

void init(Module& wasm) override {
Expand All @@ -98,6 +100,15 @@ struct ShellExternalInterface : ModuleInstance::ExternalInterface {
memory.set(offset + i, segment.data[i]);
}
}

table.resize(wasm.table.initial);
for (auto& segment : wasm.table.segments) {
Address offset = ConstantExpressionRunner().visit(segment.offset).value.geti32();
assert(offset + segment.data.size() <= wasm.table.initial);
for (size_t i = 0; i != segment.data.size(); ++i) {
table[offset + i] = segment.data[i];
}
}
}

Literal callImport(Import *import, LiteralList& arguments) override {
Expand All @@ -115,6 +126,19 @@ struct ShellExternalInterface : ModuleInstance::ExternalInterface {
abort();
}

Literal callTable(Index index, Name type, LiteralList& arguments, ModuleInstance& instance) override {
if (index >= table.size()) trap("callTable overflow");
auto* func = instance.wasm.getFunction(table[index]);
if (func->type.is() && func->type != type) trap("callIndirect: bad type");
if (func->params.size() != arguments.size()) trap("callIndirect: bad # of arguments");
for (size_t i = 0; i < func->params.size(); i++) {
if (func->params[i] != arguments[i].type) {
trap("callIndirect: bad argument type");
}
}
return instance.callFunctionInternal(func->name, arguments);
}

Literal load(Load* load, Address addr) override {
switch (load->type) {
case i32: {
Expand Down
37 changes: 26 additions & 11 deletions src/wasm-binary.h
Expand Up @@ -740,12 +740,19 @@ class WasmBinaryWriter : public Visitor<WasmBinaryWriter, void> {
}

void writeFunctionTable() {
if (wasm->table.names.size() == 0) return;
if (wasm->table.segments.size() == 0) return;
if (debug) std::cerr << "== writeFunctionTable" << std::endl;
auto start = startSection(BinaryConsts::Section::FunctionTable);
o << U32LEB(wasm->table.names.size());
for (auto name : wasm->table.names) {
o << U32LEB(getFunctionIndex(name));
o << U32LEB(wasm->table.initial);
o << U32LEB(wasm->table.max);
o << U32LEB(wasm->table.segments.size());
for (auto& segment : wasm->table.segments) {
writeExpression(segment.offset);
o << int8_t(BinaryConsts::End);
o << U32LEB(segment.data.size());
for (auto name : segment.data) {
o << U32LEB(getFunctionIndex(name));
}
}
finishSection(start);
}
Expand Down Expand Up @@ -1644,11 +1651,13 @@ class WasmBinaryBuilder {
}
}

for (size_t index : functionTable) {
assert(index < wasm.functions.size());
wasm.table.names.push_back(wasm.functions[index]->name);
for (auto& pair : functionTable) {
auto i = pair.first;
auto& indexes = pair.second;
for (auto j : indexes) {
wasm.table.segments[i].data.push_back(wasm.functions[j]->name);
}
}
wasm.table.initial = wasm.table.max = wasm.table.names.size();
}

void readDataSegments() {
Expand All @@ -1667,14 +1676,20 @@ class WasmBinaryBuilder {
}
}

std::vector<size_t> functionTable;
std::map<Index, std::vector<Index>> functionTable;

void readFunctionTable() {
if (debug) std::cerr << "== readFunctionTable" << std::endl;
wasm.table.initial = getU32LEB();
wasm.table.max = getU32LEB();
auto num = getU32LEB();
for (size_t i = 0; i < num; i++) {
auto index = getU32LEB();
functionTable.push_back(index);
wasm.table.segments.emplace_back(readExpression());
auto& temporary = functionTable[i];
auto size = getU32LEB();
for (Index j = 0; j < size; j++) {
temporary.push_back(getU32LEB());
}
}
}

Expand Down
21 changes: 7 additions & 14 deletions src/wasm-interpreter.h
Expand Up @@ -533,6 +533,7 @@ class ModuleInstance {
struct ExternalInterface {
virtual void init(Module& wasm) {}
virtual Literal callImport(Import* import, LiteralList& arguments) = 0;
virtual Literal callTable(Index index, Name type, LiteralList& arguments, ModuleInstance& instance) = 0;
virtual Literal load(Load* load, Address addr) = 0;
virtual void store(Store* store, Address addr, Literal value) = 0;
virtual void growMemory(Address oldSize, Address newSize) = 0;
Expand Down Expand Up @@ -591,8 +592,8 @@ class ModuleInstance {
return callFunctionInternal(name, arguments);
}

private:
// Internal function call.
public:
// Internal function call. Must be public so that callTable implementations can use it (refactor?)
Literal callFunctionInternal(IString name, LiteralList& arguments) {

class FunctionScope {
Expand Down Expand Up @@ -672,18 +673,8 @@ class ModuleInstance {
LiteralList arguments;
Flow flow = generateArguments(curr->operands, arguments);
if (flow.breaking()) return flow;
size_t index = target.value.geti32();
if (index >= instance.wasm.table.names.size()) trap("callIndirect: overflow");
Name name = instance.wasm.table.names[index];
Function *func = instance.wasm.getFunction(name);
if (func->type.is() && func->type != curr->fullType) trap("callIndirect: bad type");
if (func->params.size() != arguments.size()) trap("callIndirect: bad # of arguments");
for (size_t i = 0; i < func->params.size(); i++) {
if (func->params[i] != arguments[i].type) {
trap("callIndirect: bad argument type");
}
}
return instance.callFunctionInternal(name, arguments);
Index index = target.value.geti32();
return instance.externalInterface->callTable(index, curr->fullType, arguments, instance);
}

Flow visitGetLocal(GetLocal *curr) {
Expand Down Expand Up @@ -802,6 +793,8 @@ class ModuleInstance {
return ret;
}

private:

Address memorySize; // in pages

template <class LS>
Expand Down
68 changes: 60 additions & 8 deletions src/wasm-js.cpp
Expand Up @@ -195,10 +195,23 @@ extern "C" void EMSCRIPTEN_KEEPALIVE instantiate() {
target.set(source, $0);
}, ConstantExpressionRunner().visit(segment.offset).value.geti32(), &segment.data[0], segment.data.size());
}
// Table support is in a JS array. If the entry is a number, it's a function pointer. If not, it's a JS method to be called directly
// TODO: make them all JS methods, wrapping a dynCall where necessary?
EM_ASM_({
Module['outside']['wasmTable'] = new Array($0);
}, wasm.table.initial);
for (auto segment : wasm.table.segments) {
Address offset = ConstantExpressionRunner().visit(segment.offset).value.geti32();
assert(offset + segment.data.size() <= wasm.table.initial);
for (size_t i = 0; i != segment.data.size(); ++i) {
EM_ASM_({
Module['outside']['wasmTable'][$0] = $1;
}, offset + i, wasm.getFunction(segment.data[i]));
}
}
}

Literal callImport(Import *import, LiteralList& arguments) override {
if (wasmJSDebug) std::cout << "calling import " << import->name.str << '\n';
void prepareTempArgments(LiteralList& arguments) {
EM_ASM({
Module['tempArguments'] = [];
});
Expand All @@ -213,6 +226,21 @@ extern "C" void EMSCRIPTEN_KEEPALIVE instantiate() {
abort();
}
}
}

Literal getResultFromJS(double ret, WasmType type) {
switch (type) {
case none: return Literal(0);
case i32: return Literal((int32_t)ret);
case f32: return Literal((float)ret);
case f64: return Literal((double)ret);
default: abort();
}
}

Literal callImport(Import *import, LiteralList& arguments) override {
if (wasmJSDebug) std::cout << "calling import " << import->name.str << '\n';
prepareTempArgments(arguments);
double ret = EM_ASM_DOUBLE({
var mod = Pointer_stringify($0);
var base = Pointer_stringify($1);
Expand All @@ -224,12 +252,36 @@ extern "C" void EMSCRIPTEN_KEEPALIVE instantiate() {

if (wasmJSDebug) std::cout << "calling import returning " << ret << '\n';

switch (import->type->result) {
case none: return Literal(0);
case i32: return Literal((int32_t)ret);
case f32: return Literal((float)ret);
case f64: return Literal((double)ret);
default: abort();
return getResultFromJS(ret, import->type->result);
}

Literal callTable(Index index, Name type, LiteralList& arguments, ModuleInstance& instance) override {
void* ptr = (void*)EM_ASM_INT({
var value = Module['outside']['wasmTable'][$0];
return typeof value === "number" ? value : -1;
}, index);
if (ptr == nullptr) trap("callTable overflow");
if (ptr != (void*)-1) {
// a Function we can call
Function* func = (Function*)ptr;
if (func->type.is() && func->type != type) trap("callIndirect: bad type");
if (func->params.size() != arguments.size()) trap("callIndirect: bad # of arguments");
for (size_t i = 0; i < func->params.size(); i++) {
if (func->params[i] != arguments[i].type) {
trap("callIndirect: bad argument type");
}
}
return instance.callFunctionInternal(func->name, arguments);
} else {
// A JS function JS can call
prepareTempArgments(arguments);
double ret = EM_ASM_DOUBLE({
var func = Module['outside']['wasmTable'][$0];
var tempArguments = Module['tempArguments'];
Module['tempArguments'] = null;
return func.apply(null, tempArguments);
}, index);
return getResultFromJS(ret, instance.wasm.getFunctionType(type)->result);
}
}

Expand Down

0 comments on commit 113efca

Please sign in to comment.