diff --git a/.luacheckrc b/.luacheckrc index 0291ef5..30a5aab 100644 --- a/.luacheckrc +++ b/.luacheckrc @@ -24,6 +24,7 @@ stds.roblox = { stds.testez = { read_globals = { "describe", + "beforeEach", "afterEach", "beforeAll", "afterAll", "it", "itFOCUS", "itSKIP", "FOCUS", "SKIP", "HACK_NO_XPCALL", "expect", diff --git a/CHANGELOG.md b/CHANGELOG.md index 0406b07..32331f6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ # Rodux Changelog ## Unreleased Changes +* Introduce error handling to catch and report errors during reducers ([#60](https://github.com/Roblox/rodux/pull/60)). ## 1.1.0 (2021-01-04) * Added color schemes for documentation based on user preference ([#56](https://github.com/Roblox/rodux/pull/56)). diff --git a/modules/lemur b/modules/lemur index 74286fd..3e4ff7d 160000 --- a/modules/lemur +++ b/modules/lemur @@ -1 +1 @@ -Subproject commit 74286fdacd7ba01024d18839371d6b8aa4b8ff96 +Subproject commit 3e4ff7d8f09e57164ad0614c1b81cc9338caf006 diff --git a/modules/testez b/modules/testez index 442b719..edc7246 160000 --- a/modules/testez +++ b/modules/testez @@ -1 +1 @@ -Subproject commit 442b71926d4e9bd9933bbdd87d95679062723dad +Subproject commit edc7246d0173a3a90eba4a9f64ea26c576be5873 diff --git a/spec.lua b/spec.lua index 0038199..0eefd0e 100644 --- a/spec.lua +++ b/spec.lua @@ -5,7 +5,7 @@ -- If you add any dependencies, add them to this table so they'll be loaded! local LOAD_MODULES = { {"src", "Library"}, - {"modules/testez/lib", "TestEZ"}, + {"modules/testez/src", "TestEZ"}, } -- This makes sure we can load Lemur and other libraries that depend on init.lua @@ -31,7 +31,10 @@ end -- Load TestEZ and run our tests local TestEZ = habitat:require(Root.TestEZ) -local results = TestEZ.TestBootstrap:run(Root.Library, TestEZ.Reporters.TextReporter) +local results = TestEZ.TestBootstrap:run( + { Root.Library }, + TestEZ.Reporters.TextReporter +) -- Did something go wrong? if results.failureCount > 0 then diff --git a/src/NoYield.lua b/src/NoYield.lua index f9519f1..3be5a39 100644 --- a/src/NoYield.lua +++ b/src/NoYield.lua @@ -1,3 +1,5 @@ +--!nocheck + --[[ Calls a function and throws an error if it attempts to yield. @@ -26,4 +28,4 @@ local function NoYield(callback, ...) return resultHandler(co, coroutine.resume(co, ...)) end -return NoYield \ No newline at end of file +return NoYield diff --git a/src/Signal.lua b/src/Signal.lua index dc4d041..1846b8e 100644 --- a/src/Signal.lua +++ b/src/Signal.lua @@ -4,6 +4,7 @@ Handlers are fired in order, and (dis)connections are properly handled when executing an event. ]] +local inspect = require(script.Parent.inspect).inspect local function immutableAppend(list, ...) local new = {} @@ -36,9 +37,10 @@ local Signal = {} Signal.__index = Signal -function Signal.new() +function Signal.new(store) local self = { - _listeners = {} + _listeners = {}, + _store = store } setmetatable(self, Signal) @@ -47,15 +49,45 @@ function Signal.new() end function Signal:connect(callback) + if typeof(callback) ~= "function" then + error("Expected the listener to be a function.") + end + + if self._store and self._store._isDispatching then + error( + 'You may not call store.changed:connect() while the reducer is executing. ' .. + 'If you would like to be notified after the store has been updated, subscribe from a ' .. + 'component and invoke store:getState() in the callback to access the latest state. ' + ) + end + local listener = { callback = callback, disconnected = false, + connectTraceback = debug.traceback(), + disconnectTraceback = nil } self._listeners = immutableAppend(self._listeners, listener) local function disconnect() + if listener.disconnected then + local errorMessage = ("Listener connected at: \n%s\n" .. + "was already disconnected at: \n%s\n"):format( + tostring(listener.connectTraceback), + tostring(listener.disconnectTraceback) + ) + self._store._errorReporter:reportErrorDeferred(errorMessage, debug.traceback()) + + return + end + + if self._store and self._store._isDispatching then + error("You may not unsubscribe from a store listener while the reducer is executing.") + end + listener.disconnected = true + listener.disconnectTraceback = debug.traceback() self._listeners = immutableRemoveValue(self._listeners, listener) end @@ -64,10 +96,35 @@ function Signal:connect(callback) } end +function Signal:reportListenerError(listener, callbackArgs, error_) + local message = ("Caught error when calling event listener (%s), " .. + "originally subscribed from: \n%s\n" .. + "with arguments: \n%s\n"):format( + tostring(listener.callback), + tostring(listener.connectTraceback), + inspect(callbackArgs) + ) + + if self._store then + self._store._errorReporter:reportErrorImmediately(message, error_) + else + print(message .. tostring(error_)) + end +end + function Signal:fire(...) for _, listener in ipairs(self._listeners) do if not listener.disconnected then - listener.callback(...) + local ok, result = pcall(function(...) + listener.callback(...) + end, ...) + if not ok then + self:reportListenerError( + listener, + {...}, + result + ) + end end end end diff --git a/src/Signal.spec.lua b/src/Signal.spec.lua index f00f947..0cc3b47 100644 --- a/src/Signal.spec.lua +++ b/src/Signal.spec.lua @@ -111,4 +111,85 @@ return function() expect(countA).to.equal(1) expect(countB).to.equal(0) end) + + describe("when event handlers error", function() + local reportedErrorError, reportedErrorMessage + local mockStore = { + _errorReporter = { + reportErrorImmediately = function(_self, message, error_) + reportedErrorMessage = message + reportedErrorError = error_ + end, + reportErrorDeferred = function(_self, message, error_) + reportedErrorMessage = message + reportedErrorError = error_ + end + } + } + + beforeEach(function() + reportedErrorError = "" + reportedErrorMessage = "" + end) + + it("first listener succeeds when second listener errors", function() + local signal = Signal.new(mockStore) + local countA = 0 + + signal:connect(function() + countA = countA + 1 + end) + + signal:connect(function() + error("connectionB") + end) + + signal:fire() + + expect(countA).to.equal(1) + local caughtErrorMessage = "Caught error when calling event listener" + expect(string.find(reportedErrorMessage, caughtErrorMessage)).to.be.ok() + local caughtErrorError = "connectionB" + expect(string.find(reportedErrorError, caughtErrorError)).to.be.ok() + end) + + it("second listener succeeds when first listener errors", function() + local signal = Signal.new(mockStore) + local countB = 0 + + signal:connect(function() + error("connectionA") + end) + + signal:connect(function() + countB = countB + 1 + end) + + signal:fire() + + expect(countB).to.equal(1) + local caughtErrorMessage = "Caught error when calling event listener" + expect(string.find(reportedErrorMessage, caughtErrorMessage)).to.be.ok() + local caughtErrorError = "connectionA" + expect(string.find(reportedErrorError, caughtErrorError)).to.be.ok() + end) + + it("serializes table arguments when reporting errors", function() + local signal = Signal.new(mockStore) + + signal:connect(function() + error("connectionA") + end) + + local actionCommand = "SENTINEL" + signal:fire({actionCommand = actionCommand}) + + local caughtErrorMessage = "Caught error when calling event listener" + local caughtErrorArg = "actionCommand: \"" .. actionCommand .. "\"" + expect(string.find(reportedErrorMessage, caughtErrorMessage)).to.be.ok() + expect(string.find(reportedErrorMessage, caughtErrorArg)).to.be.ok() + local caughtErrorError = "connectionA" + expect(string.find(reportedErrorError, caughtErrorError)).to.be.ok() + end) + end) end \ No newline at end of file diff --git a/src/Store.lua b/src/Store.lua index 90aa02f..ba6a034 100644 --- a/src/Store.lua +++ b/src/Store.lua @@ -2,6 +2,18 @@ local RunService = game:GetService("RunService") local Signal = require(script.Parent.Signal) local NoYield = require(script.Parent.NoYield) +local inspect = require(script.Parent.inspect).inspect + +local defaultErrorReporter = { + reportErrorDeferred = function(self, message, stacktrace) + print(message) + print(stacktrace) + end, + reportErrorImmediately = function(self, message, stacktrace) + print(message) + print(stacktrace) + end +} local Store = {} @@ -23,22 +35,41 @@ Store.__index = Store Reducers do not mutate the state object, so the original state is still valid. ]] -function Store.new(reducer, initialState, middlewares) +function Store.new(reducer, initialState, middlewares, errorReporter) assert(typeof(reducer) == "function", "Bad argument #1 to Store.new, expected function.") assert(middlewares == nil or typeof(middlewares) == "table", "Bad argument #3 to Store.new, expected nil or table.") + if middlewares ~= nil then + for i=1, #middlewares, 1 do + assert( + typeof(middlewares[i]) == "function", + ("Expected the middleware ('%s') at index %d to be a function."):format(tostring(middlewares[i]), i) + ) + end + end local self = {} + self._errorReporter = errorReporter or defaultErrorReporter + self._isDispatching = false self._reducer = reducer - self._state = reducer(initialState, { + local initAction = { type = "@@INIT", - }) + } + self._lastAction = initAction + local ok, result = pcall(function() + self._state = reducer(initialState, initAction) + end) + if not ok then + local message = ("Caught error with init action of reducer (%s): %s"):format(tostring(reducer), tostring(result)) + errorReporter:reportErrorImmediately(message, debug.traceback()) + self._state = initialState + end self._lastState = self._state self._mutatedSinceFlush = false self._connections = {} - self.changed = Signal.new() + self.changed = Signal.new(self) setmetatable(self, Store) @@ -58,7 +89,7 @@ function Store.new(reducer, initialState, middlewares) dispatch = middleware(dispatch, self) end - self.dispatch = function(self, ...) + self.dispatch = function(_self, ...) return dispatch(...) end end @@ -70,9 +101,29 @@ end Get the current state of the Store. Do not mutate this! ]] function Store:getState() + if self._isDispatching then + error(("You may not call store:getState() while the reducer is executing. " .. + "The reducer (%s) has already received the state as an argument. " .. + "Pass it down from the top reducer instead of reading it from the store."):format(tostring(self._reducer))) + end + return self._state end +function Store:_reportReducerError(failedAction, error_, traceback) + local message = ("Caught error when running action (%s) " .. + "through reducer (%s): \n%s \n" .. + "previous action type was: %s" + ):format( + tostring(failedAction), + tostring(self._reducer), + tostring(error_), + inspect(self._lastAction) + ) + + self._errorReporter:reportErrorImmediately(message, traceback) +end + --[[ Dispatch an action to the store. This allows the store's reducer to mutate the state of the application by creating a new copy of the state. @@ -81,16 +132,39 @@ end changes, but not necessarily on every Dispatch. ]] function Store:dispatch(action) - if typeof(action) == "table" then - if action.type == nil then - error("action does not have a type field", 2) - end + if typeof(action) ~= "table" then + error(("Actions must be tables. " .. + "Use custom middleware for %q actions."):format(typeof(action)), + 2 + ) + end + if action.type == nil then + error("Actions may not have an undefined 'type' property. " .. + "Have you misspelled a constant? \n" .. + inspect(action), 2) + end + + if self._isDispatching then + error("Reducers may not dispatch actions.") + end + + local ok, result = pcall(function() + self._isDispatching = true self._state = self._reducer(self._state, action) self._mutatedSinceFlush = true - else - error(("actions of type %q are not permitted"):format(typeof(action)), 2) + end) + + self._isDispatching = false + + if not ok then + self:_reportReducerError( + action, + result, + debug.traceback() + ) end + self._lastAction = action end --[[ diff --git a/src/Store.spec.lua b/src/Store.spec.lua index 16a1e20..91bc6d0 100644 --- a/src/Store.spec.lua +++ b/src/Store.spec.lua @@ -142,6 +142,39 @@ return function() store:destruct() end) + + it("should error if the reducer errors", function() + local reportedErrorMessage, reportedErrorError + local mockErrorReporter = { + reportErrorImmediately = function(_self, message, error_) + reportedErrorMessage = message + reportedErrorError = error_ + end, + reportErrorDeferred = function(_self, message, error_) + reportedErrorMessage = message + reportedErrorError = error_ + end + } + + local innerErrorMessage = "Z4PH0D" + local reducerThatErrors = function(state, action) + error(innerErrorMessage) + end + + local store + store = Store.new(reducerThatErrors, nil, nil, mockErrorReporter) + + local caughtErrorMessage = "Caught error with init" + expect(string.find(reportedErrorMessage, caughtErrorMessage)).to.be.ok() + expect(string.find(reportedErrorMessage, innerErrorMessage)).to.be.ok() + -- We want to verify that this is a stacktrace without caring too + -- much about the format, so we look for the stack frame associated + -- with this test file + expect(string.find(reportedErrorError, script.Name)).to.be.ok() + + store:destruct() + end) + end) describe("getState", function() @@ -260,13 +293,24 @@ return function() end) it("should prevent yielding from changed handler", function() + local reportedErrorMessage, reportedErrorError + local mockErrorReporter = { + reportErrorImmediately = function(_self, message, error_) + reportedErrorMessage = message + reportedErrorError = error_ + end, + reportErrorDeferred = function(_self, message, error_) + reportedErrorMessage = message + reportedErrorError = error_ + end + } local preCount = 0 local postCount = 0 local store = Store.new(function(state, action) state = state or 0 return state + 1 - end) + end, nil, nil, mockErrorReporter) store.changed:connect(function(state, oldState) preCount = preCount + 1 @@ -278,13 +322,25 @@ return function() type = "increment", }) - expect(function() - store:flush() - end).to.throw() + store:flush() expect(preCount).to.equal(1) expect(postCount).to.equal(0) + local caughtErrorMessage = "Caught error when calling event listener" + expect(string.find(reportedErrorMessage, caughtErrorMessage)).to.be.ok() + -- We want to verify that this is a stacktrace without caring too + -- much about the format, so we look for the stack frame associated + -- with this test file + expect(string.find(reportedErrorMessage, script.Name)).to.be.ok() + -- In vanilla lua, we get this message: + -- "attempt to yield across metamethod/C-call boundary" + -- In luau, we should end up wrapping our own NoYield message: + -- "Attempted to yield inside changed event!" + -- For convenience's sake, we just look for the common substring + local caughtErrorSubstring = "to yield" + expect(string.find(reportedErrorError, caughtErrorSubstring)).to.be.ok() + store:destruct() end) @@ -311,6 +367,42 @@ return function() store:destruct() end) + + it("should report an error if the reducer errors", function() + local reportedErrorMessage, reportedErrorError + local mockErrorReporter = { + reportErrorImmediately = function(_self, message, error_) + reportedErrorMessage = message + reportedErrorError = error_ + end, + reportErrorDeferred = function(_self, message, error_) + reportedErrorMessage = message + reportedErrorError = error_ + end + } + + local innerErrorMessage = "Z4PH0D" + local reducerCallCount = 0 + local reducerThatErrors = function(state, action) + if reducerCallCount > 0 then + error(innerErrorMessage) + end + reducerCallCount = reducerCallCount + 1 + end + local store = Store.new(reducerThatErrors, nil, nil, mockErrorReporter) + expect(reportedErrorMessage).to.equal(nil) + + store:dispatch({type = "any"}) + + local previousAction = "previous action type was: { type: \"@@INIT\" }" + expect(string.find(reportedErrorMessage, innerErrorMessage)).to.be.ok() + expect(string.find(reportedErrorMessage, previousAction)).to.be.ok() + -- We want to verify that this is a stacktrace without caring too + -- much about the format, so we look for the stack frame associated + -- with this test file + expect(string.find(reportedErrorError, script.Name)).to.be.ok() + store:destruct() + end) end) describe("flush", function() diff --git a/src/inspect.lua b/src/inspect.lua new file mode 100644 index 0000000..32082de --- /dev/null +++ b/src/inspect.lua @@ -0,0 +1,133 @@ +-- upstream: https://github.com/graphql/graphql-js/blob/1951bce42092123e844763b6a8e985a8a3327511/src/jsutils/inspect.js +local HttpService = game:GetService("HttpService") + +local isArray = require(script.Parent.isArray) +local objectKeys = require(script.Parent.objectKeys) + +local MAX_ARRAY_LENGTH = 10 +local MAX_RECURSIVE_DEPTH = 2 + +local formatValue +local formatObjectValue +local formatArray +local formatObject +local getObjectTag + +local function find(array, value) + for i = 1, #array do + if array[i] == value then + return i + end + end + return nil +end + +--[[ + * Used to print values in error messages. + ]] +local function inspect(value) + return formatValue(value, {}) +end + +function formatValue(value, seenValues) + local valueType = typeof(value) + if valueType == "string" then + return HttpService:JSONEncode(value) + elseif valueType == "number" then + if value ~= value then + return "NaN" + elseif value == math.huge then + return "Infinity" + elseif value == -math.huge then + return "-Infinity" + else + return tostring(value) + end + elseif valueType == "function" then + return "[function]" + elseif valueType == "table" then + return formatObjectValue(value, seenValues) + else + return tostring(value) + end +end + +function formatObjectValue(value, previouslySeenValues) + if find(previouslySeenValues, value) ~= nil then + return "[Circular]" + end + + local seenValues = { unpack(previouslySeenValues) } + table.insert(seenValues, value) + + if typeof(value.toJSON) == "function" then + local jsonValue = value:toJSON(value) + + if jsonValue ~= value then + if typeof(jsonValue) == "string" then + return jsonValue + else + return formatValue(jsonValue, seenValues) + end + end + elseif isArray(value) then + return formatArray(value, seenValues) + end + + return formatObject(value, seenValues) +end + +function formatObject(object, seenValues) + local keys = objectKeys(object) + + if #keys == 0 then + return "{}" + end + if #seenValues > MAX_RECURSIVE_DEPTH then + return "[" .. getObjectTag(object) .. "]" + end + + local properties = {} + for i = 1, #keys do + local key = keys[i] + local value = formatValue(object[key], seenValues) + + properties[i] = key .. ": " .. value + end + + return "{ " .. table.concat(properties, ", ") .. " }" +end + +function formatArray(array, seenValues) + local length = #array + if length == 0 then + return "[]" + end + if #seenValues > MAX_RECURSIVE_DEPTH then + return "[Array]" + end + + local len = math.min(MAX_ARRAY_LENGTH, length) + local remaining = length - len + local items = {} + + for i = 1, len do + items[i] = (formatValue(array[i], seenValues)) + end + + if remaining == 1 then + table.insert(items, "... 1 more item") + elseif remaining > 1 then + table.insert(items, ("... %s more items"):format(remaining)) + end + + return "[" .. table.concat(items, ", ") .. "]" +end + +function getObjectTag(_object) + return "Object" +end + +return { + inspect = inspect, +} diff --git a/src/isArray.lua b/src/isArray.lua new file mode 100644 index 0000000..2a13bc5 --- /dev/null +++ b/src/isArray.lua @@ -0,0 +1,30 @@ +return function(value) + if typeof(value) ~= "table" then + return false + end + if next(value) == nil then + -- an empty table is an empty array + return true + end + + local length = #value + + if length == 0 then + return false + end + + local count = 0 + local sum = 0 + for key in pairs(value) do + if typeof(key) ~= "number" then + return false + end + if key % 1 ~= 0 or key < 1 then + return false + end + count = count + 1 + sum = sum + key + end + + return sum == (count * (count + 1) / 2) +end diff --git a/src/loggerMiddleware.lua b/src/loggerMiddleware.lua index 9bc922b..b889474 100644 --- a/src/loggerMiddleware.lua +++ b/src/loggerMiddleware.lua @@ -1,40 +1,8 @@ -local indent = " " - -local function prettyPrint(value, indentLevel) - indentLevel = indentLevel or 0 - local output = {} - - if typeof(value) == "table" then - table.insert(output, "{\n") - - for key, value in pairs(value) do - table.insert(output, indent:rep(indentLevel + 1)) - table.insert(output, tostring(key)) - table.insert(output, " = ") - - table.insert(output, prettyPrint(value, indentLevel + 1)) - table.insert(output, "\n") - end - - table.insert(output, indent:rep(indentLevel)) - table.insert(output, "}") - elseif typeof(value) == "string" then - table.insert(output, string.format("%q", value)) - table.insert(output, " (string)") - else - table.insert(output, tostring(value)) - table.insert(output, " (") - table.insert(output, typeof(value)) - table.insert(output, ")") - end - - return table.concat(output, "") -end - -- We want to be able to override outputFunction in tests, so the shape of this -- module is kind of unconventional. -- -- We fix it this weird shape in init.lua. +local prettyPrint = require(script.Parent.prettyPrint) local loggerMiddleware = { outputFunction = print, } diff --git a/src/objectKeys.lua b/src/objectKeys.lua new file mode 100644 index 0000000..66e97d6 --- /dev/null +++ b/src/objectKeys.lua @@ -0,0 +1,21 @@ +return function(value) + if value == nil then + error("cannot extract keys from a nil value") + end + + local valueType = typeof(value) + + local keys = {} + if valueType == "table" then + for key in pairs(value) do + table.insert(keys, key) + end + elseif valueType == "string" then + local length = value:len() + for i = 1, length do + keys[i] = tostring(i) + end + end + + return keys +end diff --git a/src/prettyPrint.lua b/src/prettyPrint.lua new file mode 100644 index 0000000..dedbe1a --- /dev/null +++ b/src/prettyPrint.lua @@ -0,0 +1,34 @@ +local indent = " " + +local function prettyPrint(value, indentLevel) + indentLevel = indentLevel or 0 + local output = {} + + if typeof(value) == "table" then + table.insert(output, "{\n") + + for tableKey, tableValue in pairs(value) do + table.insert(output, indent:rep(indentLevel + 1)) + table.insert(output, tostring(tableKey)) + table.insert(output, " = ") + + table.insert(output, prettyPrint(tableValue, indentLevel + 1)) + table.insert(output, "\n") + end + + table.insert(output, indent:rep(indentLevel)) + table.insert(output, "}") + elseif typeof(value) == "string" then + table.insert(output, string.format("%q", value)) + table.insert(output, " (string)") + else + table.insert(output, tostring(value)) + table.insert(output, " (") + table.insert(output, typeof(value)) + table.insert(output, ")") + end + + return table.concat(output, "") +end + +return prettyPrint \ No newline at end of file diff --git a/src/thunkMiddleware.lua b/src/thunkMiddleware.lua index 08c676b..d521df9 100644 --- a/src/thunkMiddleware.lua +++ b/src/thunkMiddleware.lua @@ -4,13 +4,40 @@ This middleware consumes the function; middleware further down the chain will not receive it. ]] +local function reportThunkError(errorReporter, failedAction, error_, traceback) + local message = ("Caught error when running thunk (%s) " .. + "through thunk: \n%s"):format(tostring(failedAction), tostring(error_)) + + errorReporter:reportErrorImmediately(message, traceback) +end + local function thunkMiddleware(nextDispatch, store) return function(action) if typeof(action) == "function" then - return action(store) - else + local ok, result = pcall(function() + return action(store) + end) + + if not ok then + -- report the error and move on so it's non-fatal app + reportThunkError(store._errorReporter, action, result, debug.traceback()) + return nil + end + + return result + end + + local ok, result = pcall(function() return nextDispatch(action) + end) + + if not ok then + -- report the error and move on so it's non-fatal app + reportThunkError(store._errorReporter, action, result, debug.traceback()) + return nil end + + return result end end diff --git a/src/thunkMiddleware.spec.lua b/src/thunkMiddleware.spec.lua index 8f717e4..1cb9372 100644 --- a/src/thunkMiddleware.spec.lua +++ b/src/thunkMiddleware.spec.lua @@ -10,7 +10,7 @@ return function() local store = Store.new(reducer, {}, { thunkMiddleware }) local thunkCount = 0 - local function thunk(store) + local function thunk(_store) thunkCount = thunkCount + 1 end @@ -47,7 +47,7 @@ return function() local store = Store.new(reducer, {}, { thunkMiddleware }) local thunkValue = "test" - local function thunk(store) + local function thunk(_store) return thunkValue end