diff --git a/src/Debug/debugger.cpp b/src/Debug/debugger.cpp index 8828f2fe..c2b2ce63 100644 --- a/src/Debug/debugger.cpp +++ b/src/Debug/debugger.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #ifndef ARDUINO #include #else @@ -332,6 +333,14 @@ bool Debugger::checkDebugMessages(Module *m, RunningState *program_state) { this->dumpCallbackmapping(); free(interruptData); break; + case interruptSetOverridePinValue: + this->addOverride(m, interruptData + 1); + free(interruptData); + break; + case interruptUnsetOverridePinValue: + this->removeOverride(m, interruptData + 1); + free(interruptData); + break; default: // handle later this->channel->write("COULD not parse interrupt data!\n"); @@ -725,7 +734,7 @@ bool Debugger::handlePushedEvent(char *bytes) const { } void Debugger::snapshot(Module *m) const { - uint16_t numberBytes = 11; + uint16_t numberBytes = 12; uint8_t state[] = {pcState, breakpointsState, callstackState, @@ -736,7 +745,8 @@ void Debugger::snapshot(Module *m) const { stackState, callbacksState, eventsState, - ioState}; + ioState, + overridesState}; inspect(m, numberBytes, state); } @@ -885,6 +895,23 @@ void Debugger::inspect(Module *m, const uint16_t sizeStateArray, addComma = true; break; } + case overridesState: { + this->channel->write("%s", addComma ? "," : ""); + this->channel->write(R"("overrides": [)"); + bool comma = false; + for (auto key : overrides) { + for (auto argResult : key.second) { + this->channel->write("%s", comma ? ", " : ""); + this->channel->write( + R"({"fidx": %d, "arg": %d, "return_value": %d})", + key.first, argResult.first, argResult.second); + comma = true; + } + } + this->channel->write("]"); + addComma = true; + break; + } default: { debug("dumpExecutionState: Received unknown state request\n"); break; @@ -1240,6 +1267,19 @@ bool Debugger::saveState(Module *m, uint8_t *interruptData) { restore_external_state(m, external_state); break; } + case overridesState: { + debug("receiving overridesState\n"); + overrides.clear(); + uint8_t overrides_count = *program_state++; + for (uint32_t i = 0; i < overrides_count; i++) { + uint32_t fidx = read_B32(&program_state); + uint32_t arg = read_B32(&program_state); + uint32_t return_value = read_B32(&program_state); + overrides[fidx][arg] = return_value; + debug("Override %d %d %d\n", fidx, arg, return_value); + } + break; + } default: { FATAL("saveState: Received unknown program state\n"); } @@ -1411,6 +1451,66 @@ bool Debugger::reset(Module *m) const { return true; } +std::optional resolve_imported_function(Module *m, + std::string function_name) { + for (uint32_t fidx = 0; fidx < m->import_count; fidx++) { + if (!strcmp(m->functions[fidx].import_field, function_name.c_str())) { + return fidx; + } + } + return {}; +} + +std::string read_string(uint8_t **pos) { + std::string str = ""; + char c = *(*pos)++; + while (c != '\0') { + str += c; + c = *(*pos)++; + } + return str; +} + +void Debugger::addOverride(Module *m, uint8_t *interruptData) { + std::string primitive_name = read_string(&interruptData); + uint32_t arg = read_B32(&interruptData); + uint32_t result = read_B32(&interruptData); + + std::optional fidx = resolve_imported_function(m, primitive_name); + if (!fidx) { + channel->write( + "Cannot override the result for unknown function \"%s\".\n", + primitive_name.c_str()); + return; + } + + channel->write("Override %s(%d) = %d.\n", primitive_name.c_str(), arg, + result); + overrides[fidx.value()][arg] = result; +} + +void Debugger::removeOverride(Module *m, uint8_t *interruptData) { + std::string primitive_name = read_string(&interruptData); + uint32_t arg = read_B32(&interruptData); + + std::optional fidx = resolve_imported_function(m, primitive_name); + if (!fidx) { + channel->write("Cannot remove override for unknown function \"%s\".\n", + primitive_name.c_str()); + return; + } + + if (overrides[fidx.value()].count(arg) == 0) { + channel->write("Override for %s(%d) not found.\n", + primitive_name.c_str(), arg); + return; + } + + channel->write("Removing override %s(%d) = %d.\n", primitive_name.c_str(), + arg, overrides[fidx.value()][arg]); + overrides[fidx.value()].erase(arg); +} + Debugger::~Debugger() { this->disconnect_proxy(); this->stop(); diff --git a/src/Debug/debugger.h b/src/Debug/debugger.h index 0dbb1cca..c43a00af 100644 --- a/src/Debug/debugger.h +++ b/src/Debug/debugger.h @@ -49,6 +49,7 @@ enum ExecutionState { callbacksState = 0x09, eventsState = 0x0A, ioState = 0x0B, + overridesState = 0x0C, }; enum InterruptTypes { @@ -90,10 +91,13 @@ enum InterruptTypes { interruptDUMPCallbackmapping = 0x74, interruptRecvCallbackmapping = 0x75, + // Primitive overrides + interruptSetOverridePinValue = 0x80, + interruptUnsetOverridePinValue = 0x81, + // Operations interruptStore = 0xa0, interruptStored = 0xa1, - }; class Debugger { @@ -117,6 +121,9 @@ class Debugger { bool asyncSnapshots; + std::unordered_map> + overrides; + // Private methods void printValue(const StackValue *v, uint32_t idx, bool end) const; @@ -270,4 +277,15 @@ class Debugger { void notifyPushedEvent() const; bool handlePushedEvent(char *bytes) const; + + // Concolic Multiverse Debugging + inline bool isMocked(uint32_t fidx, uint32_t argument) { + return overrides.count(fidx) > 0 && overrides[fidx].count(argument) > 0; + } + inline uint32_t getMockedValue(uint32_t fidx, uint32_t argument) { + return overrides[fidx][argument]; + } + + void addOverride(Module *m, uint8_t *interruptData); + void removeOverride(Module *m, uint8_t *interruptData); }; diff --git a/src/Interpreter/instructions.cpp b/src/Interpreter/instructions.cpp index 9046c0d0..0e4ceb3c 100644 --- a/src/Interpreter/instructions.cpp +++ b/src/Interpreter/instructions.cpp @@ -287,6 +287,16 @@ bool i_instr_call(Module *m) { } if (fidx < m->import_count) { + // Mocking only works on primitives, no need to check for it otherwise. + if (m->sp >= 0) { + uint32_t arg = m->stack[m->sp].value.uint32; + if (m->warduino->debugger->isMocked(fidx, arg)) { + m->stack[m->sp].value.uint32 = + m->warduino->debugger->getMockedValue(fidx, arg); + return true; + } + } + return ((Primitive)m->functions[fidx].func_ptr)(m); } else { if (m->csp >= CALLSTACK_SIZE) {