diff --git a/src/assertDeepEqual.lua b/src/assertDeepEqual.lua index 3f422d85..2fbda552 100644 --- a/src/assertDeepEqual.lua +++ b/src/assertDeepEqual.lua @@ -5,56 +5,7 @@ This should only be used in tests. ]] - -local function deepEqual(a, b) - if typeof(a) ~= typeof(b) then - local message = ("{1} is of type %s, but {2} is of type %s"):format( - typeof(a), - typeof(b) - ) - return false, message - end - - if typeof(a) == "table" then - local visitedKeys = {} - - for key, value in pairs(a) do - visitedKeys[key] = true - - local success, innerMessage = deepEqual(value, b[key]) - if not success then - local message = innerMessage - :gsub("{1}", ("{1}[%s]"):format(tostring(key))) - :gsub("{2}", ("{2}[%s]"):format(tostring(key))) - - return false, message - end - end - - for key, value in pairs(b) do - if not visitedKeys[key] then - local success, innerMessage = deepEqual(value, a[key]) - - if not success then - local message = innerMessage - :gsub("{1}", ("{1}[%s]"):format(tostring(key))) - :gsub("{2}", ("{2}[%s]"):format(tostring(key))) - - return false, message - end - end - end - - return true - end - - if a == b then - return true - end - - local message = "{1} ~= {2}" - return false, message -end +local deepEqual = require(script.Parent.deepEqual) local function assertDeepEqual(a, b) local success, innerMessageTemplate = deepEqual(a, b) diff --git a/src/assertDeepEqual.spec.lua b/src/assertDeepEqual.spec.lua index 484eeffd..b7610137 100644 --- a/src/assertDeepEqual.spec.lua +++ b/src/assertDeepEqual.spec.lua @@ -1,7 +1,12 @@ return function() local assertDeepEqual = require(script.Parent.assertDeepEqual) - it("should fail with a message when args are not equal", function() + it("should not throw if the args are equal", function() + assertDeepEqual(1, 1) + assertDeepEqual("hello", "hello") + end) + + it("should throw and format the error message when args are not equal", function() local success, message = pcall(assertDeepEqual, 1, 2) expect(success).to.equal(false) @@ -15,85 +20,7 @@ return function() expect(success).to.equal(false) expect(message:find("first%[foo%] ~= second%[foo%]")).to.be.ok() - end) - - it("should compare non-table values using standard '==' equality", function() - assertDeepEqual(1, 1) - assertDeepEqual("hello", "hello") - assertDeepEqual(nil, nil) - - local someFunction = function() end - local theSameFunction = someFunction - - assertDeepEqual(someFunction, theSameFunction) - - local A = { - foo = someFunction - } - local B = { - foo = theSameFunction - } - - assertDeepEqual(A, B) - end) - - it("should fail when types differ", function() - local success, message = pcall(assertDeepEqual, 1, "1") - - expect(success).to.equal(false) - expect(message:find("first is of type number, but second is of type string")).to.be.ok() - end) - - it("should compare (and report about) nested tables", function() - local A = { - foo = "bar", - nested = { - foo = 1, - bar = 2, - } - } - local B = { - foo = "bar", - nested = { - foo = 1, - bar = 2, - } - } - - assertDeepEqual(A, B) - - local C = { - foo = "bar", - nested = { - foo = 1, - bar = 3, - } - } - - local success, message = pcall(assertDeepEqual, A, C) - - expect(success).to.equal(false) - expect(message:find("first%[nested%]%[bar%] ~= second%[nested%]%[bar%]")).to.be.ok() - end) - - it("should be commutative", function() - local equalArgsA = { - foo = "bar", - hello = "world", - } - local equalArgsB = { - foo = "bar", - hello = "world", - } - - assertDeepEqual(equalArgsA, equalArgsB) - assertDeepEqual(equalArgsB, equalArgsA) - - local nonEqualArgs = { - foo = "bar", - } - - expect(function() assertDeepEqual(equalArgsA, nonEqualArgs) end).to.throw() - expect(function() assertDeepEqual(nonEqualArgs, equalArgsA) end).to.throw() + expect(message:find("{1}")).never.to.be.ok() + expect(message:find("{2}")).never.to.be.ok() end) end \ No newline at end of file diff --git a/src/deepEqual.lua b/src/deepEqual.lua new file mode 100644 index 00000000..0e052ba1 --- /dev/null +++ b/src/deepEqual.lua @@ -0,0 +1,51 @@ +local function deepEqual(a, b) + if typeof(a) ~= typeof(b) then + local message = ("{1} is of type %s, but {2} is of type %s"):format( + typeof(a), + typeof(b) + ) + return false, message + end + + if typeof(a) == "table" then + local visitedKeys = {} + + for key, value in pairs(a) do + visitedKeys[key] = true + + local success, innerMessage = deepEqual(value, b[key]) + if not success then + local message = innerMessage + :gsub("{1}", ("{1}[%s]"):format(tostring(key))) + :gsub("{2}", ("{2}[%s]"):format(tostring(key))) + + return false, message + end + end + + for key, value in pairs(b) do + if not visitedKeys[key] then + local success, innerMessage = deepEqual(value, a[key]) + + if not success then + local message = innerMessage + :gsub("{1}", ("{1}[%s]"):format(tostring(key))) + :gsub("{2}", ("{2}[%s]"):format(tostring(key))) + + return false, message + end + end + end + + return true + end + + if a == b then + return true + end + + local message = "{1} ~= {2}" + return false, message +end + +return deepEqual \ No newline at end of file diff --git a/src/deepEqual.spec.lua b/src/deepEqual.spec.lua new file mode 100644 index 00000000..41931156 --- /dev/null +++ b/src/deepEqual.spec.lua @@ -0,0 +1,99 @@ +return function() + local deepEqual = require(script.Parent.deepEqual) + + it("should compare non-table values using standard '==' equality", function() + expect(deepEqual(1, 1)).to.equal(true) + expect(deepEqual("hello", "hello")).to.equal(true) + expect(deepEqual(nil, nil)).to.equal(true) + + local someFunction = function() end + local theSameFunction = someFunction + + expect(deepEqual(someFunction, theSameFunction)).to.equal(true) + + local A = { + foo = someFunction + } + local B = { + foo = theSameFunction + } + + expect(deepEqual(A, B)).to.equal(true) + end) + + it("should fail with a message when args are not equal", function() + local success, message = deepEqual(1, 2) + + expect(success).to.equal(false) + expect(message:find("{1} ~= {2}")).to.be.ok() + + success, message = deepEqual({ + foo = 1, + }, { + foo = 2, + }) + + expect(success).to.equal(false) + expect(message:find("{1}%[foo%] ~= {2}%[foo%]")).to.be.ok() + end) + + it("should fail when types differ", function() + local success, message = deepEqual(1, "1") + + expect(success).to.equal(false) + expect(message:find("{1} is of type number, but {2} is of type string")).to.be.ok() + end) + + it("should compare (and report about) nested tables", function() + local A = { + foo = "bar", + nested = { + foo = 1, + bar = 2, + } + } + local B = { + foo = "bar", + nested = { + foo = 1, + bar = 2, + } + } + + deepEqual(A, B) + + local C = { + foo = "bar", + nested = { + foo = 1, + bar = 3, + } + } + + local success, message = deepEqual(A, C) + + expect(success).to.equal(false) + expect(message:find("{1}%[nested%]%[bar%] ~= {2}%[nested%]%[bar%]")).to.be.ok() + end) + + it("should be commutative", function() + local equalArgsA = { + foo = "bar", + hello = "world", + } + local equalArgsB = { + foo = "bar", + hello = "world", + } + + expect(deepEqual(equalArgsA, equalArgsB)).to.equal(true) + expect(deepEqual(equalArgsB, equalArgsA)).to.equal(true) + + local nonEqualArgs = { + foo = "bar", + } + + expect(deepEqual(equalArgsA, nonEqualArgs)).to.equal(false) + expect(deepEqual(nonEqualArgs, equalArgsA)).to.equal(false) + end) +end \ No newline at end of file diff --git a/src/init.lua b/src/init.lua index f002f975..8e9240f4 100644 --- a/src/init.lua +++ b/src/init.lua @@ -8,6 +8,7 @@ local createReconcilerCompat = require(script.createReconcilerCompat) local RobloxRenderer = require(script.RobloxRenderer) local strict = require(script.strict) local Binding = require(script.Binding) +local shallow = require(script.shallow) local robloxReconciler = createReconciler(RobloxRenderer) local reconcilerCompat = createReconcilerCompat(robloxReconciler) @@ -32,6 +33,7 @@ local Roact = strict { mount = robloxReconciler.mountVirtualTree, unmount = robloxReconciler.unmountVirtualTree, update = robloxReconciler.updateVirtualTree, + shallow = shallow, reify = reconcilerCompat.reify, teardown = reconcilerCompat.teardown, diff --git a/src/init.spec.lua b/src/init.spec.lua index 7fcf79c8..439620a6 100644 --- a/src/init.spec.lua +++ b/src/init.spec.lua @@ -11,6 +11,7 @@ return function() mount = "function", unmount = "function", update = "function", + shallow = "function", oneChild = "function", setGlobalConfig = "function", diff --git a/src/shallow/ShallowWrapper.lua b/src/shallow/ShallowWrapper.lua new file mode 100644 index 00000000..2961f941 --- /dev/null +++ b/src/shallow/ShallowWrapper.lua @@ -0,0 +1,184 @@ +local RoactRoot = script.Parent.Parent + +local Children = require(RoactRoot.PropMarkers.Children) +local ElementKind = require(RoactRoot.ElementKind) +local ElementUtils = require(RoactRoot.ElementUtils) +local VirtualNodeConstraints = require(script.Parent.VirtualNodeConstraints) +local Snapshot = require(script.Parent.Snapshot) + +local ShallowWrapper = {} +local ShallowWrapperMetatable = { + __index = ShallowWrapper, +} + +local function getTypeFromVirtualNode(virtualNode) + local element = virtualNode.currentElement + local kind = ElementKind.of(element) + + if kind == ElementKind.Host then + return { + kind = ElementKind.Host, + className = element.component, + } + elseif kind == ElementKind.Function then + return { + kind = ElementKind.Function, + functionComponent = element.component, + } + elseif kind == ElementKind.Stateful then + return { + kind = ElementKind.Stateful, + component = element.component, + } + else + error(('shallow wrapper does not support element of kind %q'):format(tostring(kind))) + end +end + +local function findNextVirtualNode(virtualNode, maxDepth) + local currentDepth = 0 + local currentNode = virtualNode + local nextNode = currentNode.children[ElementUtils.UseParentKey] + + while currentDepth < maxDepth and nextNode ~= nil do + currentNode = nextNode + nextNode = currentNode.children[ElementUtils.UseParentKey] + currentDepth = currentDepth + 1 + end + + return currentNode +end + +local function countChildrenOfElement(element) + if ElementKind.of(element) == ElementKind.Fragment then + local count = 0 + + for _, subElement in pairs(element.elements) do + count = count + countChildrenOfElement(subElement) + end + + return count + else + return 1 + end +end + +local function getChildren(virtualNode, results, maxDepth) + if ElementKind.of(virtualNode.currentElement) == ElementKind.Fragment then + for _, subVirtualNode in pairs(virtualNode.children) do + getChildren(subVirtualNode, results, maxDepth) + end + else + local childWrapper = ShallowWrapper.new( + virtualNode, + maxDepth + ) + + table.insert(results, childWrapper) + end +end + +local function filterProps(props) + if props[Children] == nil then + return props + end + + local filteredProps = {} + + for key, value in pairs(props) do + if key ~= Children then + filteredProps[key] = value + end + end + + return filteredProps +end + +function ShallowWrapper.new(virtualNode, maxDepth) + virtualNode = findNextVirtualNode(virtualNode, maxDepth) + + local wrapper = { + _virtualNode = virtualNode, + _childrenMaxDepth = maxDepth - 1, + _virtualNodeChildren = maxDepth == 0 and {} or virtualNode.children, + type = getTypeFromVirtualNode(virtualNode), + props = filterProps(virtualNode.currentElement.props), + hostKey = virtualNode.hostKey, + instance = virtualNode.hostObject, + } + + return setmetatable(wrapper, ShallowWrapperMetatable) +end + +function ShallowWrapper:childrenCount() + local count = 0 + + for _, virtualNode in pairs(self._virtualNodeChildren) do + local element = virtualNode.currentElement + count = count + countChildrenOfElement(element) + end + + return count +end + +function ShallowWrapper:find(constraints) + VirtualNodeConstraints.validate(constraints) + + local results = {} + local children = self:getChildren() + + for i=1, #children do + local childWrapper = children[i] + + if VirtualNodeConstraints.satisfiesAll(childWrapper._virtualNode, constraints) then + table.insert(results, childWrapper) + end + end + + return results +end + +function ShallowWrapper:findUnique(constraints) + local children = self:getChildren() + + if constraints == nil then + assert( + #children == 1, + ("expect to contain exactly one child, but found %d"):format(#children) + ) + return children[1] + end + + local constrainedChildren = self:find(constraints) + + assert( + #constrainedChildren == 1, + ("expect to find only one child, but found %d"):format(#constrainedChildren) + ) + + return constrainedChildren[1] +end + +function ShallowWrapper:getChildren() + local results = {} + + for _, childVirtualNode in pairs(self._virtualNodeChildren) do + getChildren(childVirtualNode, results, self._childrenMaxDepth) + end + + return results +end + +function ShallowWrapper:matchSnapshot(identifier) + assert(typeof(identifier) == "string", "Snapshot identifier must be a string") + + local snapshotResult = Snapshot.createMatcher(identifier, self) + + snapshotResult:match() +end + +function ShallowWrapper:snapshotToString() + return Snapshot.toString(self) +end + +return ShallowWrapper \ No newline at end of file diff --git a/src/shallow/ShallowWrapper.spec.lua b/src/shallow/ShallowWrapper.spec.lua new file mode 100644 index 00000000..e2ddd7a5 --- /dev/null +++ b/src/shallow/ShallowWrapper.spec.lua @@ -0,0 +1,575 @@ +return function() + local RoactRoot = script.Parent.Parent + local ShallowWrapper = require(script.Parent.ShallowWrapper) + + local assertDeepEqual = require(RoactRoot.assertDeepEqual) + local Children = require(RoactRoot.PropMarkers.Children) + local ElementKind = require(RoactRoot.ElementKind) + local createElement = require(RoactRoot.createElement) + local createFragment = require(RoactRoot.createFragment) + local createReconciler = require(RoactRoot.createReconciler) + local RoactComponent = require(RoactRoot.Component) + local RobloxRenderer = require(RoactRoot.RobloxRenderer) + + local robloxReconciler = createReconciler(RobloxRenderer) + + local function shallow(element, options) + options = options or {} + local maxDepth = options.depth or 1 + local hostKey = options.hostKey or "ShallowTree" + local hostParent = options.hostParent or Instance.new("Folder") + + local virtualNode = robloxReconciler.mountVirtualNode(element, hostParent, hostKey) + + return ShallowWrapper.new(virtualNode, maxDepth) + end + + describe("single host element", function() + local className = "TextLabel" + + local function Component(props) + return createElement(className, props) + end + + it("should have it's type.kind to Host", function() + local element = createElement(Component) + + local result = shallow(element) + + expect(result.type.kind).to.equal(ElementKind.Host) + end) + + it("should have its type.className to given instance class", function() + local element = createElement(Component) + + local result = shallow(element) + + expect(result.type.className).to.equal(className) + end) + + it("children count should be zero", function() + local element = createElement(Component) + + local result = shallow(element) + + expect(result:childrenCount()).to.equal(0) + end) + end) + + describe("single function element", function() + local function FunctionComponent(props) + return createElement("TextLabel") + end + + local function Component(props) + return createElement(FunctionComponent, props) + end + + it("should have its type.kind to Function", function() + local element = createElement(Component) + + local result = shallow(element) + + expect(result.type.kind).to.equal(ElementKind.Function) + end) + + it("should have its type.functionComponent to Function", function() + local element = createElement(Component) + + local result = shallow(element) + + expect(result.type.functionComponent).to.equal(FunctionComponent) + end) + end) + + describe("single stateful element", function() + local StatefulComponent = RoactComponent:extend("StatefulComponent") + + function StatefulComponent:render() + return createElement("TextLabel") + end + + local function Component(props) + return createElement(StatefulComponent, props) + end + + it("should have its type.kind to Stateful", function() + local element = createElement(Component) + + local result = shallow(element) + + expect(result.type.kind).to.equal(ElementKind.Stateful) + end) + + it("should have its type.component to given component class", function() + local element = createElement(Component) + + local result = shallow(element) + + expect(result.type.component).to.equal(StatefulComponent) + end) + end) + + describe("depth", function() + local unwrappedClassName = "TextLabel" + local function A(props) + return createElement(unwrappedClassName) + end + + local function B(props) + return createElement(A) + end + + local function Component(props) + return createElement(B) + end + + local function ComponentWithChildren(props) + return createElement("Frame", {}, { + ChildA = createElement(A), + ChildB = createElement(B), + }) + end + + it("should unwrap function components when depth has not exceeded", function() + local element = createElement(Component) + + local result = shallow(element, { + depth = 3, + }) + + expect(result.type.kind).to.equal(ElementKind.Host) + expect(result.type.className).to.equal(unwrappedClassName) + end) + + it("should stop unwrapping function components when depth has exceeded", function() + local element = createElement(Component) + + local result = shallow(element, { + depth = 2, + }) + + expect(result.type.kind).to.equal(ElementKind.Function) + expect(result.type.functionComponent).to.equal(A) + end) + + it("should not unwrap the element when depth is zero", function() + local element = createElement(Component) + + local result = shallow(element, { + depth = 0, + }) + + expect(result.type.kind).to.equal(ElementKind.Function) + expect(result.type.functionComponent).to.equal(Component) + end) + + it("should not unwrap children when depth is one", function() + local element = createElement(ComponentWithChildren) + + local result = shallow(element, { + depth = 1, + }) + + local childA = result:find({ + component = A, + }) + expect(#childA).to.equal(1) + + local childB = result:find({ + component = B, + }) + expect(#childB).to.equal(1) + end) + + it("should unwrap children when depth is two", function() + local element = createElement(ComponentWithChildren) + + local result = shallow(element, { + depth = 2, + }) + + local hostChild = result:find({ + component = unwrappedClassName, + }) + expect(#hostChild).to.equal(1) + + local unwrappedBChild = result:find({ + component = A, + }) + expect(#unwrappedBChild).to.equal(1) + end) + + it("should not include any children when depth is zero", function() + local element = createElement(ComponentWithChildren) + + local result = shallow(element, { + depth = 0, + }) + + expect(result:childrenCount()).to.equal(0) + end) + + it("should not include any grand-children when depth is one", function() + local function ParentComponent() + return createElement("Folder", {}, { + Child = createElement(ComponentWithChildren), + }) + end + + local element = createElement(ParentComponent) + + local result = shallow(element, { + depth = 1, + }) + + expect(result:childrenCount()).to.equal(1) + + local componentWithChildrenWrapper = result:find({ + component = ComponentWithChildren, + })[1] + expect(componentWithChildrenWrapper).to.be.ok() + + expect(componentWithChildrenWrapper:childrenCount()).to.equal(0) + end) + end) + + describe("childrenCount", function() + local childClassName = "TextLabel" + + local function Component(props) + local children = {} + + for i=1, props.childrenCount do + children[("Key%d"):format(i)] = createElement(childClassName) + end + + return createElement("Frame", {}, children) + end + + it("should return 1 when the element contains only one child element", function() + local element = createElement(Component, { + childrenCount = 1, + }) + + local result = shallow(element) + + expect(result:childrenCount()).to.equal(1) + end) + + it("should return 0 when the element does not contain elements", function() + local element = createElement(Component, { + childrenCount = 0, + }) + + local result = shallow(element) + + expect(result:childrenCount()).to.equal(0) + end) + + it("should count children in a fragment", function() + local element = createElement("Frame", {}, { + Frag = createFragment({ + Label = createElement("TextLabel"), + Button = createElement("TextButton"), + }) + }) + + local result = shallow(element) + + expect(result:childrenCount()).to.equal(2) + end) + + it("should count children nested in fragments", function() + local element = createElement("Frame", {}, { + Frag = createFragment({ + SubFrag = createFragment({ + Frame = createElement("Frame"), + }), + Label = createElement("TextLabel"), + Button = createElement("TextButton"), + }) + }) + + local result = shallow(element) + + expect(result:childrenCount()).to.equal(3) + end) + end) + + describe("props", function() + it("should contains the same props using Host element", function() + local function Component(props) + return createElement("Frame", props) + end + + local props = { + BackgroundTransparency = 1, + Visible = false, + } + local element = createElement(Component, props) + + local result = shallow(element) + + expect(result.type.kind).to.equal(ElementKind.Host) + expect(result.props).to.be.ok() + + assertDeepEqual(props, result.props) + end) + + it("should have the same props using function element", function() + local function ChildComponent(props) + return createElement("Frame", props) + end + + local function Component(props) + return createElement(ChildComponent, props) + end + + local props = { + BackgroundTransparency = 1, + Visible = false, + } + local propsCopy = {} + for key, value in pairs(props) do + propsCopy[key] = value + end + local element = createElement(Component, props) + + local result = shallow(element) + + expect(result.type.kind).to.equal(ElementKind.Function) + expect(result.props).to.be.ok() + + assertDeepEqual(propsCopy, result.props) + end) + + it("should not have the children property", function() + local function ComponentWithChildren(props) + return createElement("Frame", props, { + Key = createElement("TextLabel"), + }) + end + + local props = { + BackgroundTransparency = 1, + Visible = false, + } + + local element = createElement(ComponentWithChildren, props) + + local result = shallow(element) + + expect(result.props).to.be.ok() + expect(result.props[Children]).never.to.be.ok() + end) + + it("should have the inherited props", function() + local function Component(props) + local frameProps = { + LayoutOrder = 7, + } + for key, value in pairs(props) do + frameProps[key] = value + end + + return createElement("Frame", frameProps) + end + + local element = createElement(Component, { + BackgroundTransparency = 1, + Visible = false, + }) + + local result = shallow(element) + + expect(result.props).to.be.ok() + + local expectProps = { + BackgroundTransparency = 1, + Visible = false, + LayoutOrder = 7, + } + + assertDeepEqual(expectProps, result.props) + end) + end) + + describe("instance", function() + it("should contain the instance when it is a host component", function() + local className = "Frame" + local function Component(props) + return createElement(className) + end + + local element = createElement(Component) + + local result = shallow(element) + + expect(result.instance).to.be.ok() + expect(result.instance.ClassName).to.equal(className) + end) + + it("should not have an instance if it is a function component", function() + local function Child() + return createElement("Frame") + end + local function Component(props) + return createElement(Child) + end + + local element = createElement(Component) + + local result = shallow(element) + + expect(result.instance).never.to.be.ok() + end) + end) + + describe("find children", function() + it("should throw if the constraint does not exist", function() + local element = createElement("Frame") + + local result = shallow(element) + + local function findWithInvalidConstraint() + result:find({ + nothing = false, + }) + end + + expect(findWithInvalidConstraint).to.throw() + end) + + it("should return children that matches all contraints", function() + local function ComponentWithChildren() + return createElement("Frame", {}, { + ChildA = createElement("TextLabel", { + Visible = false, + }), + ChildB = createElement("TextButton", { + Visible = false, + }), + }) + end + + local element = createElement(ComponentWithChildren) + + local result = shallow(element) + + local children = result:find({ + className = "TextLabel", + props = { + Visible = false, + }, + }) + + expect(#children).to.equal(1) + end) + + it("should return children from fragments", function() + local childClassName = "TextLabel" + + local function ComponentWithFragment() + return createElement("Frame", {}, { + Fragment = createFragment({ + Child = createElement(childClassName), + }), + }) + end + + local element = createElement(ComponentWithFragment) + + local result = shallow(element) + + local children = result:find({ + className = childClassName + }) + + expect(#children).to.equal(1) + end) + + it("should return children from nested fragments", function() + local childClassName = "TextLabel" + + local function ComponentWithFragment() + return createElement("Frame", {}, { + Fragment = createFragment({ + SubFragment = createFragment({ + Child = createElement(childClassName), + }), + }), + }) + end + + local element = createElement(ComponentWithFragment) + + local result = shallow(element) + + local children = result:find({ + className = childClassName + }) + + expect(#children).to.equal(1) + end) + end) + + describe("findUnique", function() + it("should return the only child when no constraints are given", function() + local element = createElement("Frame", {}, { + Child = createElement("TextLabel"), + }) + + local result = shallow(element) + + local child = result:findUnique() + + expect(child.type.kind).to.equal(ElementKind.Host) + expect(child.type.className).to.equal("TextLabel") + end) + + it("should return the only child that satifies the constraint", function() + local element = createElement("Frame", {}, { + ChildA = createElement("TextLabel"), + ChildB = createElement("TextButton"), + }) + + local result = shallow(element) + + local child = result:findUnique({ + className = "TextLabel", + }) + + expect(child.type.className).to.equal("TextLabel") + end) + + it("should throw if there is not any child element", function() + local element = createElement("Frame") + + local result = shallow(element) + + local function shouldThrow() + result:findUnique() + end + + expect(shouldThrow).to.throw() + end) + + it("should throw if more than one child satisfies the constraint", function() + local element = createElement("Frame", {}, { + ChildA = createElement("TextLabel"), + ChildB = createElement("TextLabel"), + }) + + local result = shallow(element) + + local function shouldThrow() + result:findUnique({ + className = "TextLabel", + }) + end + + expect(shouldThrow).to.throw() + end) + end) +end \ No newline at end of file diff --git a/src/shallow/Snapshot/Serialize/IndentedOutput.lua b/src/shallow/Snapshot/Serialize/IndentedOutput.lua new file mode 100644 index 00000000..c4ccd3ae --- /dev/null +++ b/src/shallow/Snapshot/Serialize/IndentedOutput.lua @@ -0,0 +1,54 @@ +local IndentedOutput = {} +local IndentedOutputMetatable = { + __index = IndentedOutput, +} + +function IndentedOutput.new(indentation) + indentation = indentation or 2 + + local output = { + _level = 0, + _indentation = (" "):rep(indentation), + _lines = {}, + } + + setmetatable(output, IndentedOutputMetatable) + + return output +end + +function IndentedOutput:write(line, ...) + if select("#", ...) > 0 then + line = line:format(...) + end + + local indentedLine = ("%s%s"):format(self._indentation:rep(self._level), line) + + table.insert(self._lines, indentedLine) +end + +function IndentedOutput:push() + self._level = self._level + 1 +end + +function IndentedOutput:pop() + self._level = math.max(self._level - 1, 0) +end + +function IndentedOutput:writeAndPush(...) + self:write(...) + self:push() +end + +function IndentedOutput:popAndWrite(...) + self:pop() + self:write(...) +end + +function IndentedOutput:join(separator) + separator = separator or "\n" + + return table.concat(self._lines, separator) +end + +return IndentedOutput diff --git a/src/shallow/Snapshot/Serialize/IndentedOutput.spec.lua b/src/shallow/Snapshot/Serialize/IndentedOutput.spec.lua new file mode 100644 index 00000000..7f9ffbe1 --- /dev/null +++ b/src/shallow/Snapshot/Serialize/IndentedOutput.spec.lua @@ -0,0 +1,72 @@ +return function() + local IndentedOutput = require(script.Parent.IndentedOutput) + + describe("join", function() + it("should concat the lines with a new line by default", function() + local output = IndentedOutput.new() + + output:write("foo") + output:write("bar") + + expect(output:join()).to.equal("foo\nbar") + end) + + it("should concat the lines with the given string", function() + local output = IndentedOutput.new() + + output:write("foo") + output:write("bar") + + expect(output:join("-")).to.equal("foo-bar") + end) + end) + + describe("push", function() + it("should indent next written lines", function() + local output = IndentedOutput.new() + + output:write("foo") + output:push() + output:write("bar") + + expect(output:join()).to.equal("foo\n bar") + end) + end) + + describe("pop", function() + it("should dedent next written lines", function() + local output = IndentedOutput.new() + + output:write("foo") + output:push() + output:write("bar") + output:pop() + output:write("baz") + + expect(output:join()).to.equal("foo\n bar\nbaz") + end) + end) + + describe("writeAndPush", function() + it("should write the line and push", function() + local output = IndentedOutput.new() + + output:writeAndPush("foo") + output:write("bar") + + expect(output:join()).to.equal("foo\n bar") + end) + end) + + describe("popAndWrite", function() + it("should write the line and push", function() + local output = IndentedOutput.new() + + output:writeAndPush("foo") + output:write("bar") + output:popAndWrite("baz") + + expect(output:join()).to.equal("foo\n bar\nbaz") + end) + end) +end \ No newline at end of file diff --git a/src/shallow/Snapshot/Serialize/Markers/AnonymousFunction.lua b/src/shallow/Snapshot/Serialize/Markers/AnonymousFunction.lua new file mode 100644 index 00000000..4a3c04fa --- /dev/null +++ b/src/shallow/Snapshot/Serialize/Markers/AnonymousFunction.lua @@ -0,0 +1,7 @@ +local RoactRoot = script.Parent.Parent.Parent.Parent.Parent + +local Symbol = require(RoactRoot.Symbol) + +local AnonymousFunction = Symbol.named("AnonymousFunction") + +return AnonymousFunction \ No newline at end of file diff --git a/src/shallow/Snapshot/Serialize/Markers/EmptyRef.lua b/src/shallow/Snapshot/Serialize/Markers/EmptyRef.lua new file mode 100644 index 00000000..53b80650 --- /dev/null +++ b/src/shallow/Snapshot/Serialize/Markers/EmptyRef.lua @@ -0,0 +1,7 @@ +local RoactRoot = script.Parent.Parent.Parent.Parent.Parent + +local Symbol = require(RoactRoot.Symbol) + +local EmptyRef = Symbol.named("EmptyRef") + +return EmptyRef \ No newline at end of file diff --git a/src/shallow/Snapshot/Serialize/Markers/Signal.lua b/src/shallow/Snapshot/Serialize/Markers/Signal.lua new file mode 100644 index 00000000..04244470 --- /dev/null +++ b/src/shallow/Snapshot/Serialize/Markers/Signal.lua @@ -0,0 +1,7 @@ +local RoactRoot = script.Parent.Parent.Parent.Parent.Parent + +local Symbol = require(RoactRoot.Symbol) + +local Signal = Symbol.named("Signal") + +return Signal \ No newline at end of file diff --git a/src/shallow/Snapshot/Serialize/Markers/Unknown.lua b/src/shallow/Snapshot/Serialize/Markers/Unknown.lua new file mode 100644 index 00000000..82ab0d50 --- /dev/null +++ b/src/shallow/Snapshot/Serialize/Markers/Unknown.lua @@ -0,0 +1,7 @@ +local RoactRoot = script.Parent.Parent.Parent.Parent.Parent + +local Symbol = require(RoactRoot.Symbol) + +local Unkown = Symbol.named("Unkown") + +return Unkown \ No newline at end of file diff --git a/src/shallow/Snapshot/Serialize/Markers/init.lua b/src/shallow/Snapshot/Serialize/Markers/init.lua new file mode 100644 index 00000000..3999927c --- /dev/null +++ b/src/shallow/Snapshot/Serialize/Markers/init.lua @@ -0,0 +1,10 @@ +local RoactRoot = script.Parent.Parent.Parent.Parent + +local strict = require(RoactRoot.strict) + +return strict({ + AnonymousFunction = require(script.AnonymousFunction), + EmptyRef = require(script.EmptyRef), + Signal = require(script.Signal), + Unknown = require(script.Unknown), +}, "Markers") \ No newline at end of file diff --git a/src/shallow/Snapshot/Serialize/Serializer.lua b/src/shallow/Snapshot/Serialize/Serializer.lua new file mode 100644 index 00000000..b74a8537 --- /dev/null +++ b/src/shallow/Snapshot/Serialize/Serializer.lua @@ -0,0 +1,235 @@ +local RoactRoot = script.Parent.Parent.Parent.Parent + +local ElementKind = require(RoactRoot.ElementKind) +local Ref = require(RoactRoot.PropMarkers.Ref) +local Type = require(RoactRoot.Type) +local Markers = require(script.Parent.Markers) +local IndentedOutput = require(script.Parent.IndentedOutput) + +local Serializer = {} + +function Serializer.kind(kind) + if kind == ElementKind.Host then + return "Host" + elseif kind == ElementKind.Function then + return "Function" + elseif kind == ElementKind.Stateful then + return "Stateful" + else + error(("Cannot serialize ElementKind %q"):format(tostring(kind))) + end +end + +function Serializer.type(data, output) + output:writeAndPush("type = {") + output:write("kind = ElementKind.%s,", Serializer.kind(data.kind)) + + if data.className then + output:write("className = %q,", data.className) + elseif data.componentName then + output:write("componentName = %q,", data.componentName) + end + + output:popAndWrite("},") +end + +function Serializer.tableKey(key) + local keyType = type(key) + + if keyType == "string" and key:match("^%a%w+$") then + return key + else + return ("[%s]"):format(Serializer.tableValue(key)) + end +end + +function Serializer.number(value) + local _, fraction = math.modf(value) + + if fraction == 0 then + return ("%s"):format(tostring(value)) + else + return ("%0.7f"):format(value):gsub("%.?0+$", "") + end +end + +function Serializer.tableValue(value) + local valueType = typeof(value) + + if valueType == "string" then + return ("%q"):format(value) + + elseif valueType == "number" then + return Serializer.number(value) + + elseif valueType == "boolean" then + return ("%s"):format(tostring(value)) + + elseif valueType == "Color3" then + return ("Color3.new(%s, %s, %s)"):format( + Serializer.number(value.r), + Serializer.number(value.g), + Serializer.number(value.b) + ) + + elseif valueType == "EnumItem" then + return ("%s"):format(tostring(value)) + + elseif valueType == "UDim" then + return ("UDim.new(%s, %d)"):format(Serializer.number(value.Scale), value.Offset) + + elseif valueType == "UDim2" then + return ("UDim2.new(%s, %d, %s, %d)"):format( + Serializer.number(value.X.Scale), + value.X.Offset, + Serializer.number(value.Y.Scale), + value.Y.Offset + ) + + elseif valueType == "Vector2" then + return ("Vector2.new(%s, %s)"):format( + Serializer.number(value.X), + Serializer.number(value.Y) + ) + + elseif Type.of(value) == Type.HostEvent then + return ("Roact.Event.%s"):format(value.name) + + elseif Type.of(value) == Type.HostChangeEvent then + return ("Roact.Change.%s"):format(value.name) + + elseif value == Ref then + return "Roact.Ref" + + else + for markerName, marker in pairs(Markers) do + if value == marker then + return ("Markers.%s"):format(markerName) + end + end + + error(("Cannot serialize value %q of type %q"):format( + tostring(value), + valueType + )) + end +end + +function Serializer.getKeyTypeOrder(key) + if type(key) == "string" then + return 1 + elseif Type.of(key) == Type.HostEvent then + return 2 + elseif Type.of(key) == Type.HostChangeEvent then + return 3 + elseif key == Ref then + return 4 + else + return math.huge + end +end + +function Serializer.compareKeys(a, b) + -- a and b are of the same type here, because Serializer.sortTableKeys + -- will only use this function to compare keys of the same type + if Type.of(a) == Type.HostEvent or Type.of(a) == Type.HostChangeEvent then + return a.name < b.name + else + return a < b + end +end + +function Serializer.sortTableKeys(a, b) + -- first sort by the type of key, to place string props, then Roact.Event + -- events, Roact.Change events and the Ref + local orderA = Serializer.getKeyTypeOrder(a) + local orderB = Serializer.getKeyTypeOrder(b) + + if orderA == orderB then + return Serializer.compareKeys(a, b) + else + return orderA < orderB + end +end + +function Serializer.table(tableKey, dict, output) + if next(dict) == nil then + output:write("%s = {},", tableKey) + return + end + + output:writeAndPush("%s = {", tableKey) + + local keys = {} + + for key in pairs(dict) do + table.insert(keys, key) + end + + table.sort(keys, Serializer.sortTableKeys) + + for i=1, #keys do + local key = keys[i] + local value = dict[key] + local serializedKey = Serializer.tableKey(key) + + if type(value) == "table" then + Serializer.table(serializedKey, value, output) + else + output:write("%s = %s,", serializedKey, Serializer.tableValue(value)) + end + end + + output:popAndWrite("},") +end + +function Serializer.props(props, output) + Serializer.table("props", props, output) +end + +function Serializer.children(children, output) + if #children == 0 then + output:write("children = {},") + return + end + + output:writeAndPush("children = {") + + for i=1, #children do + Serializer.snapshotData(children[i], output) + end + + output:popAndWrite("},") +end + +function Serializer.snapshotDataContent(snapshotData, output) + Serializer.type(snapshotData.type, output) + output:write("hostKey = %q,", snapshotData.hostKey) + Serializer.props(snapshotData.props, output) + Serializer.children(snapshotData.children, output) +end + +function Serializer.snapshotData(snapshotData, output) + output:writeAndPush("{") + Serializer.snapshotDataContent(snapshotData, output) + output:popAndWrite("},") +end + +function Serializer.firstSnapshotData(snapshotData) + local output = IndentedOutput.new() + output:writeAndPush("return function(dependencies)") + output:write("local Roact = dependencies.Roact") + output:write("local ElementKind = dependencies.ElementKind") + output:write("local Markers = dependencies.Markers") + output:write("") + output:writeAndPush("return {") + + Serializer.snapshotDataContent(snapshotData, output) + + output:popAndWrite("}") + output:popAndWrite("end") + + return output:join() +end + +return Serializer diff --git a/src/shallow/Snapshot/Serialize/Serializer.spec.lua b/src/shallow/Snapshot/Serialize/Serializer.spec.lua new file mode 100644 index 00000000..833102ac --- /dev/null +++ b/src/shallow/Snapshot/Serialize/Serializer.spec.lua @@ -0,0 +1,377 @@ +return function() + local RoactRoot = script.Parent.Parent.Parent.Parent + + local Markers = require(script.Parent.Markers) + local Change = require(RoactRoot.PropMarkers.Change) + local Event = require(RoactRoot.PropMarkers.Event) + local ElementKind = require(RoactRoot.ElementKind) + local IndentedOutput = require(script.Parent.IndentedOutput) + local Ref = require(RoactRoot.PropMarkers.Ref) + local Serializer = require(script.Parent.Serializer) + + describe("type", function() + it("should serialize host elements", function() + local output = IndentedOutput.new() + Serializer.type({ + kind = ElementKind.Host, + className = "TextLabel", + }, output) + + expect(output:join()).to.equal( + "type = {\n" + .. " kind = ElementKind.Host,\n" + .. " className = \"TextLabel\",\n" + .. "}," + ) + end) + + it("should serialize stateful elements", function() + local output = IndentedOutput.new() + Serializer.type({ + kind = ElementKind.Stateful, + componentName = "SomeComponent", + }, output) + + expect(output:join()).to.equal( + "type = {\n" + .. " kind = ElementKind.Stateful,\n" + .. " componentName = \"SomeComponent\",\n" + .. "}," + ) + end) + + it("should serialize function elements", function() + local output = IndentedOutput.new() + Serializer.type({ + kind = ElementKind.Function, + }, output) + + expect(output:join()).to.equal( + "type = {\n" + .. " kind = ElementKind.Function,\n" + .. "}," + ) + end) + end) + + describe("tableKey", function() + it("should serialize to a named dictionary field", function() + local keys = {"foo", "foo1"} + + for i=1, #keys do + local key = keys[i] + local result = Serializer.tableKey(key) + + expect(result).to.equal(key) + end + end) + + it("should serialize to a string value field to escape non-alphanumeric characters", function() + local keys = {"foo.bar", "1foo"} + + for i=1, #keys do + local key = keys[i] + local result = Serializer.tableKey(key) + + expect(result).to.equal('["' .. key .. '"]') + end + end) + end) + + describe("number", function() + it("should format integers", function() + expect(Serializer.number(1)).to.equal("1") + expect(Serializer.number(0)).to.equal("0") + expect(Serializer.number(10)).to.equal("10") + end) + + it("should minimize floating points zeros", function() + expect(Serializer.number(1.2)).to.equal("1.2") + expect(Serializer.number(0.002)).to.equal("0.002") + expect(Serializer.number(5.5001)).to.equal("5.5001") + end) + + it("should keep only 7 decimals", function() + expect(Serializer.number(0.123456789)).to.equal("0.1234568") + expect(Serializer.number(0.123456709)).to.equal("0.1234567") + end) + end) + + describe("tableValue", function() + it("should serialize strings", function() + local result = Serializer.tableValue("foo") + + expect(result).to.equal('"foo"') + end) + + it("should serialize strings with \"", function() + local result = Serializer.tableValue('foo"bar') + + expect(result).to.equal('"foo\\"bar"') + end) + + it("should serialize numbers", function() + local result = Serializer.tableValue(10.5) + + expect(result).to.equal("10.5") + end) + + it("should serialize booleans", function() + expect(Serializer.tableValue(true)).to.equal("true") + expect(Serializer.tableValue(false)).to.equal("false") + end) + + it("should serialize enum items", function() + local result = Serializer.tableValue(Enum.SortOrder.LayoutOrder) + + expect(result).to.equal("Enum.SortOrder.LayoutOrder") + end) + + it("should serialize Color3", function() + local result = Serializer.tableValue(Color3.new(0.1, 0.2, 0.3)) + + expect(result).to.equal("Color3.new(0.1, 0.2, 0.3)") + end) + + it("should serialize UDim", function() + local result = Serializer.tableValue(UDim.new(1.2, 0)) + + expect(result).to.equal("UDim.new(1.2, 0)") + end) + + it("should serialize UDim2", function() + local result = Serializer.tableValue(UDim2.new(1.5, 5, 2, 3)) + + expect(result).to.equal("UDim2.new(1.5, 5, 2, 3)") + end) + + it("should serialize Vector2", function() + local result = Serializer.tableValue(Vector2.new(1.5, 0.3)) + + expect(result).to.equal("Vector2.new(1.5, 0.3)") + end) + + it("should serialize markers symbol", function() + for name, marker in pairs(Markers) do + local result = Serializer.tableValue(marker) + + expect(result).to.equal(("Markers.%s"):format(name)) + end + end) + + it("should serialize Roact.Event events", function() + local result = Serializer.tableValue(Event.Activated) + + expect(result).to.equal("Roact.Event.Activated") + end) + + it("should serialize Roact.Change events", function() + local result = Serializer.tableValue(Change.AbsoluteSize) + + expect(result).to.equal("Roact.Change.AbsoluteSize") + end) + end) + + describe("table", function() + it("should serialize an empty nested table", function() + local output = IndentedOutput.new() + Serializer.table("sub", {}, output) + + expect(output:join()).to.equal("sub = {},") + end) + + it("should serialize an nested table", function() + local output = IndentedOutput.new() + Serializer.table("sub", { + foo = 1, + }, output) + + expect(output:join()).to.equal("sub = {\n foo = 1,\n},") + end) + end) + + describe("props", function() + it("should serialize an empty table", function() + local output = IndentedOutput.new() + Serializer.props({}, output) + + expect(output:join()).to.equal("props = {},") + end) + + it("should serialize table fields", function() + local output = IndentedOutput.new() + Serializer.props({ + key = 8, + }, output) + + expect(output:join()).to.equal("props = {\n key = 8,\n},") + end) + + it("should serialize Roact.Event", function() + local output = IndentedOutput.new() + Serializer.props({ + [Event.Activated] = Markers.AnonymousFunction, + }, output) + + expect(output:join()).to.equal( + "props = {\n" + .. " [Roact.Event.Activated] = Markers.AnonymousFunction,\n" + .. "}," + ) + end) + + it("should sort Roact.Event", function() + local output = IndentedOutput.new() + Serializer.props({ + [Event.Activated] = Markers.AnonymousFunction, + [Event.MouseEnter] = Markers.AnonymousFunction, + }, output) + + expect(output:join()).to.equal( + "props = {\n" + .. " [Roact.Event.Activated] = Markers.AnonymousFunction,\n" + .. " [Roact.Event.MouseEnter] = Markers.AnonymousFunction,\n" + .. "}," + ) + end) + + it("should serialize Roact.Change", function() + local output = IndentedOutput.new() + Serializer.props({ + [Change.Position] = Markers.AnonymousFunction, + }, output) + + expect(output:join()).to.equal( + "props = {\n" + .. " [Roact.Change.Position] = Markers.AnonymousFunction,\n" + .. "}," + ) + end) + + it("should sort props, Roact.Event, Roact.Change and Ref", function() + local output = IndentedOutput.new() + Serializer.props({ + foo = 1, + [Event.Activated] = Markers.AnonymousFunction, + [Change.Position] = Markers.AnonymousFunction, + [Ref] = Markers.EmptyRef, + }, output) + + expect(output:join()).to.equal( + "props = {\n" + .. " foo = 1,\n" + .. " [Roact.Event.Activated] = Markers.AnonymousFunction,\n" + .. " [Roact.Change.Position] = Markers.AnonymousFunction,\n" + .. " [Roact.Ref] = Markers.EmptyRef,\n" + .. "}," + ) + end) + + it("should sort props within themselves", function() + local output = IndentedOutput.new() + Serializer.props({ + foo = 1, + bar = 2, + }, output) + + expect(output:join()).to.equal( + "props = {\n" + .. " bar = 2,\n" + .. " foo = 1,\n" + .. "}," + ) + end) + end) + + describe("children", function() + it("should serialize an empty table", function() + local output = IndentedOutput.new() + Serializer.children({}, output) + + expect(output:join()).to.equal("children = {},") + end) + + it("should serialize children in an array", function() + local snapshotData = { + type = { + kind = ElementKind.Function, + }, + hostKey = "HostKey", + props = {}, + children = {}, + } + + local childrenOutput = IndentedOutput.new() + Serializer.children({snapshotData}, childrenOutput) + + local snapshotDataOutput = IndentedOutput.new() + snapshotDataOutput:push() + Serializer.snapshotData(snapshotData, snapshotDataOutput) + + local expectResult = "children = {\n" .. snapshotDataOutput:join() .. "\n}," + expect(childrenOutput:join()).to.equal(expectResult) + end) + end) + + describe("snapshotDataContent", function() + it("should serialize all fields", function() + local snapshotData = { + type = { + kind = ElementKind.Function, + }, + hostKey = "HostKey", + props = {}, + children = {}, + } + local output = IndentedOutput.new() + Serializer.snapshotDataContent(snapshotData, output) + + expect(output:join()).to.equal( + "type = {\n" + .. " kind = ElementKind.Function,\n" + .. "},\n" + .. 'hostKey = "HostKey",\n' + .. "props = {},\n" + .. "children = {}," + ) + end) + end) + + describe("snapshotData", function() + it("should wrap snapshotDataContent result between curly braces", function() + local snapshotData = { + type = { + kind = ElementKind.Function, + }, + hostKey = "HostKey", + props = {}, + children = {}, + } + local contentOutput = IndentedOutput.new() + contentOutput:push() + Serializer.snapshotDataContent(snapshotData, contentOutput) + + local output = IndentedOutput.new() + Serializer.snapshotData(snapshotData, output) + + local expectResult = "{\n" .. contentOutput:join() .. "\n}," + expect(output:join()).to.equal(expectResult) + end) + end) + + describe("firstSnapshotData", function() + it("should return a function that returns a table", function() + local result = Serializer.firstSnapshotData({ + type = { + kind = ElementKind.Function, + }, + hostKey = "HostKey", + props = {}, + children = {}, + }) + + local pattern = "^return function%(.-%).+return%s+{(.+)}%s+end$" + expect(result:match(pattern)).to.be.ok() + end) + end) +end \ No newline at end of file diff --git a/src/shallow/Snapshot/Serialize/Snapshot.lua b/src/shallow/Snapshot/Serialize/Snapshot.lua new file mode 100644 index 00000000..137934e6 --- /dev/null +++ b/src/shallow/Snapshot/Serialize/Snapshot.lua @@ -0,0 +1,124 @@ +local RoactRoot = script.Parent.Parent.Parent.Parent + +local Markers = require(script.Parent.Markers) +local ElementKind = require(RoactRoot.ElementKind) +local Type = require(RoactRoot.Type) +local Ref = require(RoactRoot.PropMarkers.Ref) + +local function sortSerializedChildren(childA, childB) + return childA.hostKey < childB.hostKey +end + +local Snapshot = {} + +function Snapshot.type(wrapperType) + local typeData = { + kind = wrapperType.kind, + } + + if wrapperType.kind == ElementKind.Host then + typeData.className = wrapperType.className + elseif wrapperType.kind == ElementKind.Stateful then + typeData.componentName = tostring(wrapperType.component) + end + + return typeData +end + +function Snapshot.signal(signal) + local signalToString = tostring(signal) + local signalName = signalToString:match("Signal (%w+)") + + assert(signalName ~= nil, ("Can not extract signal name from %q"):format(signalToString)) + + return { + [Markers.Signal] = signalName + } +end + +function Snapshot.propValue(prop) + local propType = type(prop) + + if propType == "string" + or propType == "number" + or propType == "boolean" + then + return prop + + elseif propType == "function" then + return Markers.AnonymousFunction + + elseif typeof(prop) == "RBXScriptSignal" then + return Snapshot.signal(prop) + + elseif propType == "userdata" then + return prop + + elseif propType == "table" then + return Snapshot.props(prop) + + else + warn(("Snapshot does not support prop with value %q (type %q)"):format( + tostring(prop), + propType + )) + return Markers.Unknown + end +end + +function Snapshot.props(wrapperProps) + local serializedProps = {} + + for key, prop in pairs(wrapperProps) do + if type(key) == "string" + or Type.of(key) == Type.HostChangeEvent + or Type.of(key) == Type.HostEvent + then + serializedProps[key] = Snapshot.propValue(prop) + + elseif key == Ref then + local current = prop:getValue() + + if current then + serializedProps[key] = { + className = current.ClassName, + } + else + serializedProps[key] = Markers.EmptyRef + end + + else + error(("Snapshot does not support prop with key %q (type: %s)"):format( + tostring(key), + type(key) + )) + end + end + + return serializedProps +end + +function Snapshot.children(children) + local serializedChildren = {} + + for i=1, #children do + local childWrapper = children[i] + + serializedChildren[i] = Snapshot.new(childWrapper) + end + + table.sort(serializedChildren, sortSerializedChildren) + + return serializedChildren +end + +function Snapshot.new(wrapper) + return { + type = Snapshot.type(wrapper.type), + hostKey = wrapper.hostKey, + props = Snapshot.props(wrapper.props), + children = Snapshot.children(wrapper:getChildren()), + } +end + +return Snapshot diff --git a/src/shallow/Snapshot/Serialize/Snapshot.spec.lua b/src/shallow/Snapshot/Serialize/Snapshot.spec.lua new file mode 100644 index 00000000..551ebd4a --- /dev/null +++ b/src/shallow/Snapshot/Serialize/Snapshot.spec.lua @@ -0,0 +1,296 @@ +return function() + local RoactRoot = script.Parent.Parent.Parent.Parent + + local Markers = require(script.Parent.Markers) + local assertDeepEqual = require(RoactRoot.assertDeepEqual) + local Binding = require(RoactRoot.Binding) + local Change = require(RoactRoot.PropMarkers.Change) + local Component = require(RoactRoot.Component) + local createElement = require(RoactRoot.createElement) + local createReconciler = require(RoactRoot.createReconciler) + local createRef = require(RoactRoot.createRef) + local ElementKind = require(RoactRoot.ElementKind) + local Event = require(RoactRoot.PropMarkers.Event) + local Ref = require(RoactRoot.PropMarkers.Ref) + local RobloxRenderer = require(RoactRoot.RobloxRenderer) + local ShallowWrapper = require(script.Parent.Parent.Parent.ShallowWrapper) + + local Snapshot = require(script.Parent.Snapshot) + + local robloxReconciler = createReconciler(RobloxRenderer) + + local function shallow(element, options) + options = options or {} + local maxDepth = options.depth or 1 + local hostKey = options.hostKey or "ShallowTree" + local hostParent = options.hostParent or Instance.new("Folder") + + local virtualNode = robloxReconciler.mountVirtualNode(element, hostParent, hostKey) + + return ShallowWrapper.new(virtualNode, maxDepth) + end + + describe("type", function() + describe("host elements", function() + it("should contain the host kind", function() + local wrapper = shallow(createElement("Frame")) + + local result = Snapshot.type(wrapper.type) + + expect(result.kind).to.equal(ElementKind.Host) + end) + + it("should contain the class name", function() + local className = "Frame" + local wrapper = shallow(createElement(className)) + + local result = Snapshot.type(wrapper.type) + + expect(result.className).to.equal(className) + end) + end) + + describe("function elements", function() + local function SomeComponent() + return nil + end + + it("should contain the host kind", function() + local wrapper = shallow(createElement(SomeComponent)) + + local result = Snapshot.type(wrapper.type) + + expect(result.kind).to.equal(ElementKind.Function) + end) + end) + + describe("stateful elements", function() + local componentName = "ComponentName" + local SomeComponent = Component:extend(componentName) + + function SomeComponent:render() + return nil + end + + it("should contain the host kind", function() + local wrapper = shallow(createElement(SomeComponent)) + + local result = Snapshot.type(wrapper.type) + + expect(result.kind).to.equal(ElementKind.Stateful) + end) + + it("should contain the component name", function() + local wrapper = shallow(createElement(SomeComponent)) + + local result = Snapshot.type(wrapper.type) + + expect(result.componentName).to.equal(componentName) + end) + end) + end) + + describe("signal", function() + it("should convert signals", function() + local signalName = "Foo" + local signalMock = setmetatable({}, { + __tostring = function() + return "Signal " .. signalName + end + }) + + local result = Snapshot.signal(signalMock) + + assertDeepEqual(result, { + [Markers.Signal] = signalName + }) + end) + end) + + describe("propValue", function() + it("should return the same value for basic types", function() + local propValues = {7, "hello", Enum.SortOrder.LayoutOrder} + + for i=1, #propValues do + local prop = propValues[i] + local result = Snapshot.propValue(prop) + + expect(result).to.equal(prop) + end + end) + + it("should return an empty table given an empty table", function() + local result = Snapshot.propValue({}) + + expect(next(result)).never.to.be.ok() + end) + + it("should serialize a table as a props table", function() + local key = "some key" + local value = { + [key] = "foo", + } + local result = Snapshot.propValue(value) + + expect(result[key]).to.equal("foo") + expect(next(result, key)).never.to.be.ok() + end) + + it("should return the AnonymousFunction symbol when given a function", function() + local result = Snapshot.propValue(function() end) + + expect(result).to.equal(Markers.AnonymousFunction) + end) + + it("should return the Unknown symbol when given an unexpected value", function() + local result = Snapshot.propValue(nil) + + expect(result).to.equal(Markers.Unknown) + end) + end) + + describe("props", function() + it("should keep props with string keys", function() + local props = { + image = "hello", + text = "never", + } + + local result = Snapshot.props(props) + + assertDeepEqual(result, props) + end) + + it("should map Roact.Event to AnonymousFunction", function() + local props = { + [Event.Activated] = function() end, + } + + local result = Snapshot.props(props) + + assertDeepEqual(result, { + [Event.Activated] = Markers.AnonymousFunction, + }) + end) + + it("should map Roact.Change to AnonymousFunction", function() + local props = { + [Change.Position] = function() end, + } + + local result = Snapshot.props(props) + + assertDeepEqual(result, { + [Change.Position] = Markers.AnonymousFunction, + }) + end) + + it("should map empty refs to the EmptyRef symbol", function() + local props = { + [Ref] = createRef(), + } + + local result = Snapshot.props(props) + + assertDeepEqual(result, { + [Ref] = Markers.EmptyRef, + }) + end) + + it("should map refs with value to their symbols", function() + local instanceClassName = "Folder" + local ref = createRef() + Binding.update(ref, Instance.new(instanceClassName)) + + local props = { + [Ref] = ref, + } + + local result = Snapshot.props(props) + + assertDeepEqual(result, { + [Ref] = { + className = instanceClassName, + }, + }) + end) + + it("should throw when the key is a table", function() + local function shouldThrow() + Snapshot.props({ + [{}] = "invalid", + }) + end + + expect(shouldThrow).to.throw() + end) + end) + + describe("wrapper", function() + it("should have the host key", function() + local hostKey = "SomeKey" + local wrapper = shallow(createElement("Frame")) + wrapper.hostKey = hostKey + + local result = Snapshot.new(wrapper) + + expect(result.hostKey).to.equal(hostKey) + end) + + it("should contain the element type", function() + local wrapper = shallow(createElement("Frame")) + + local result = Snapshot.new(wrapper) + + expect(result.type).to.be.ok() + expect(result.type.kind).to.equal(ElementKind.Host) + expect(result.type.className).to.equal("Frame") + end) + + it("should contain the props", function() + local props = { + LayoutOrder = 3, + [Change.Size] = function() end, + } + local expectProps = { + LayoutOrder = 3, + [Change.Size] = Markers.AnonymousFunction, + } + + local wrapper = shallow(createElement("Frame", props)) + + local result = Snapshot.new(wrapper) + + expect(result.props).to.be.ok() + assertDeepEqual(result.props, expectProps) + end) + + it("should contain the element children", function() + local wrapper = shallow(createElement("Frame", {}, { + Child = createElement("TextLabel"), + })) + + local result = Snapshot.new(wrapper) + + expect(result.children).to.be.ok() + expect(#result.children).to.equal(1) + local childData = result.children[1] + expect(childData.type.kind).to.equal(ElementKind.Host) + expect(childData.type.className).to.equal("TextLabel") + end) + + it("should sort children by their host key", function() + local wrapper = shallow(createElement("Frame", {}, { + Child = createElement("TextLabel"), + Label = createElement("TextLabel"), + })) + + local result = Snapshot.new(wrapper) + + expect(result.children).to.be.ok() + expect(#result.children).to.equal(2) + expect(result.children[1].hostKey).to.equal("Child") + expect(result.children[2].hostKey).to.equal("Label") + end) + end) +end \ No newline at end of file diff --git a/src/shallow/Snapshot/Serialize/init.lua b/src/shallow/Snapshot/Serialize/init.lua new file mode 100644 index 00000000..170cca39 --- /dev/null +++ b/src/shallow/Snapshot/Serialize/init.lua @@ -0,0 +1,11 @@ +local Serializer = require(script.Serializer) +local Snapshot = require(script.Snapshot) + +return { + wrapperToSnapshot = function(wrapper) + return Snapshot.new(wrapper) + end, + snapshotToString = function(snapshot) + return Serializer.firstSnapshotData(snapshot) + end, +} diff --git a/src/shallow/Snapshot/SnapshotMatcher.lua b/src/shallow/Snapshot/SnapshotMatcher.lua new file mode 100644 index 00000000..1560bbd7 --- /dev/null +++ b/src/shallow/Snapshot/SnapshotMatcher.lua @@ -0,0 +1,99 @@ +local ReplicatedStorage = game:GetService("ReplicatedStorage") +local RoactRoot = script.Parent.Parent.Parent + +local Markers = require(script.Parent.Serialize.Markers) +local Serialize = require(script.Parent.Serialize) +local deepEqual = require(RoactRoot.deepEqual) +local ElementKind = require(RoactRoot.ElementKind) + +local SnapshotFolderName = "RoactSnapshots" +local SnapshotFolder = ReplicatedStorage:FindFirstChild(SnapshotFolderName) + +local SnapshotMatcher = {} +local SnapshotMetatable = { + __index = SnapshotMatcher, +} + +function SnapshotMatcher.new(identifier, snapshot) + local snapshotMatcher = { + _identifier = identifier, + _snapshot = snapshot, + _existingSnapshot = SnapshotMatcher._loadExistingData(identifier), + } + + setmetatable(snapshotMatcher, SnapshotMetatable) + + return snapshotMatcher +end + +function SnapshotMatcher:match() + if self._existingSnapshot == nil then + self:serialize() + self._existingSnapshot = self._snapshot + return + end + + local areEqual, innerMessageTemplate = deepEqual(self._snapshot, self._existingSnapshot) + + if areEqual then + return + end + + local newSnapshot = SnapshotMatcher.new(self._identifier .. ".NEW", self._snapshot) + newSnapshot:serialize() + + local innerMessage = innerMessageTemplate + :gsub("{1}", "new") + :gsub("{2}", "existing") + + local message = ("Snapshots do not match.\n%s"):format(innerMessage) + + error(message, 2) +end + +function SnapshotMatcher:serialize() + local folder = SnapshotMatcher.getSnapshotFolder() + + local snapshotSource = Serialize.snapshotToString(self._snapshot) + local existingData = folder:FindFirstChild(self._identifier) + + if not (existingData and existingData:IsA("StringValue")) then + existingData = Instance.new("StringValue") + existingData.Name = self._identifier + existingData.Parent = folder + end + + existingData.Value = snapshotSource +end + +function SnapshotMatcher.getSnapshotFolder() + SnapshotFolder = ReplicatedStorage:FindFirstChild(SnapshotFolderName) + + if not SnapshotFolder then + SnapshotFolder = Instance.new("Folder") + SnapshotFolder.Name = SnapshotFolderName + SnapshotFolder.Parent = ReplicatedStorage + end + + return SnapshotFolder +end + +function SnapshotMatcher._loadExistingData(identifier) + local folder = SnapshotMatcher.getSnapshotFolder() + + local existingData = folder:FindFirstChild(identifier) + + if not (existingData and existingData:IsA("ModuleScript")) then + return nil + end + + local loadSnapshot = require(existingData) + + return loadSnapshot({ + Roact = require(RoactRoot), + ElementKind = ElementKind, + Markers = Markers, + }) +end + +return SnapshotMatcher \ No newline at end of file diff --git a/src/shallow/Snapshot/SnapshotMatcher.spec.lua b/src/shallow/Snapshot/SnapshotMatcher.spec.lua new file mode 100644 index 00000000..be44e88f --- /dev/null +++ b/src/shallow/Snapshot/SnapshotMatcher.spec.lua @@ -0,0 +1,142 @@ +return function() + local RoactRoot = script.Parent.Parent.Parent + + local SnapshotMatcher = require(script.Parent.SnapshotMatcher) + + local ElementKind = require(RoactRoot.ElementKind) + local createSpy = require(RoactRoot.createSpy) + + local snapshotFolder = Instance.new("Folder") + local originalGetSnapshotFolder = SnapshotMatcher.getSnapshotFolder + + local function mockGetSnapshotFolder() + return snapshotFolder + end + + local originalLoadExistingData = SnapshotMatcher._loadExistingData + local loadExistingDataSpy = nil + + describe("match", function() + local snapshotMap = {} + + local function beforeTest() + snapshotMap = {} + + loadExistingDataSpy = createSpy(function(identifier) + return snapshotMap[identifier] + end) + SnapshotMatcher._loadExistingData = loadExistingDataSpy.value + end + + local function cleanTest() + loadExistingDataSpy = nil + SnapshotMatcher._loadExistingData = originalLoadExistingData + end + + it("should serialize the snapshot if no data is found", function() + beforeTest() + + local snapshot = {} + local serializeSpy = createSpy() + + local matcher = SnapshotMatcher.new("foo", snapshot) + matcher.serialize = serializeSpy.value + + matcher:match() + + cleanTest() + + serializeSpy:assertCalledWith(matcher) + end) + + it("should not serialize if the snapshot already exist", function() + beforeTest() + + local snapshot = {} + local identifier = "foo" + snapshotMap[identifier] = snapshot + + local serializeSpy = createSpy() + + local matcher = SnapshotMatcher.new(identifier, snapshot) + matcher.serialize = serializeSpy.value + + matcher:match() + + cleanTest() + + expect(serializeSpy.callCount).to.equal(0) + end) + + it("should throw an error if the previous snapshot does not match", function() + beforeTest() + + local snapshot = {} + local identifier = "foo" + snapshotMap[identifier] = { + Key = "Value" + } + + local serializeSpy = createSpy() + + local matcher = SnapshotMatcher.new(identifier, snapshot) + matcher.serialize = serializeSpy.value + + local function shouldThrow() + matcher:match() + end + + cleanTest() + + expect(shouldThrow).to.throw() + end) + end) + + describe("serialize", function() + it("should create a StringValue if it does not exist", function() + SnapshotMatcher.getSnapshotFolder = mockGetSnapshotFolder + + local identifier = "foo" + + local matcher = SnapshotMatcher.new(identifier, { + type = { + kind = ElementKind.Function, + }, + hostKey = "HostKey", + props = {}, + children = {}, + }) + + matcher:serialize() + local stringValue = snapshotFolder:FindFirstChild(identifier) + + SnapshotMatcher.getSnapshotFolder = originalGetSnapshotFolder + + expect(stringValue).to.be.ok() + expect(stringValue.Value:len() > 0).to.equal(true) + + stringValue:Destroy() + end) + end) + + describe("_loadExistingData", function() + it("should return nil if data is not found", function() + SnapshotMatcher.getSnapshotFolder = mockGetSnapshotFolder + + local result = SnapshotMatcher._loadExistingData("foo") + + SnapshotMatcher.getSnapshotFolder = originalGetSnapshotFolder + + expect(result).never.to.be.ok() + end) + end) + + describe("getSnapshotFolder", function() + it("should create a folder in the ReplicatedStorage if it is not found", function() + local folder = SnapshotMatcher.getSnapshotFolder() + + expect(folder).to.be.ok() + expect(folder.Parent).to.equal(game:GetService("ReplicatedStorage")) + end) + end) +end \ No newline at end of file diff --git a/src/shallow/Snapshot/init.lua b/src/shallow/Snapshot/init.lua new file mode 100644 index 00000000..88b5ea84 --- /dev/null +++ b/src/shallow/Snapshot/init.lua @@ -0,0 +1,28 @@ +local Serialize = require(script.Serialize) +local SnapshotMatcher = require(script.SnapshotMatcher) + +local characterClass = "%w_%-%." +local identifierPattern = "^[" .. characterClass .. "]+$" +local invalidPattern = "[^" .. characterClass .. "]" + +local function createMatcher(identifier, shallowWrapper) + if not identifier:match(identifierPattern) then + error(("Snapshot identifier has invalid character: '%s'"):format(identifier:match(invalidPattern))) + end + + local snapshot = Serialize.wrapperToSnapshot(shallowWrapper) + local matcher = SnapshotMatcher.new(identifier, snapshot) + + return matcher +end + +local function toString(shallowWrapper) + local snapshot = Serialize.wrapperToSnapshot(shallowWrapper) + + return Serialize.snapshotToString(snapshot) +end + +return { + createMatcher = createMatcher, + toString = toString, +} \ No newline at end of file diff --git a/src/shallow/Snapshot/init.spec.lua b/src/shallow/Snapshot/init.spec.lua new file mode 100644 index 00000000..69fb22e1 --- /dev/null +++ b/src/shallow/Snapshot/init.spec.lua @@ -0,0 +1,104 @@ +return function() + local RoactRoot = script.Parent.Parent.Parent + + local Change = require(RoactRoot.PropMarkers.Change) + local Component = require(RoactRoot.Component) + local createElement = require(RoactRoot.createElement) + local createReconciler = require(RoactRoot.createReconciler) + local Event = require(RoactRoot.PropMarkers.Event) + local RobloxRenderer = require(RoactRoot.RobloxRenderer) + local ShallowWrapper = require(script.Parent.Parent.ShallowWrapper) + local Snapshot = require(script.Parent) + + local robloxReconciler = createReconciler(RobloxRenderer) + + local hostTreeKey = "RoactTree" + + it("should match snapshot of host component with multiple props", function() + local element = createElement("Frame", { + AnchorPoint = Vector2.new(0, 0.5), + BackgroundColor3 = Color3.new(0.1, 0.2, 0.3), + BackgroundTransparency = 0.205, + ClipsDescendants = false, + Size = UDim2.new(0.5, 0, 0.4, 1), + SizeConstraint = Enum.SizeConstraint.RelativeXY, + Visible = true, + ZIndex = 5, + }) + + local rootNode = robloxReconciler.mountVirtualNode(element, nil, hostTreeKey) + local wrapper = ShallowWrapper.new(rootNode, 1) + + Snapshot.createMatcher("host-frame-with-multiple-props", wrapper):match() + end) + + it("should match snapshot of function component children", function() + local function LabelComponent(props) + return createElement("TextLabel", props) + end + + local element = createElement("Frame", {}, { + LabelA = createElement(LabelComponent, { + Text = "I am label A", + }), + LabelB = createElement(LabelComponent, { + Text = "I am label B", + }), + }) + + local rootNode = robloxReconciler.mountVirtualNode(element, nil, hostTreeKey) + local wrapper = ShallowWrapper.new(rootNode, 1) + + Snapshot.createMatcher("function-component-children", wrapper):match() + end) + + it("should match snapshot of stateful component", function() + local StatefulComponent = Component:extend("CoolComponent") + + function StatefulComponent:render() + return createElement("TextLabel") + end + + local element = createElement("Frame", {}, { + Child = createElement(StatefulComponent, { + label = { + Text = "foo", + }, + }), + }) + + local rootNode = robloxReconciler.mountVirtualNode(element, nil, hostTreeKey) + local wrapper = ShallowWrapper.new(rootNode, 1) + + Snapshot.createMatcher("stateful-component-children", wrapper):match() + end) + + it("should match snapshot with event props", function() + local function emptyFunction() + end + + local element = createElement("TextButton", { + [Change.AbsoluteSize] = emptyFunction, + [Change.Visible] = emptyFunction, + [Event.Activated] = emptyFunction, + [Event.MouseButton1Click] = emptyFunction, + }) + + local rootNode = robloxReconciler.mountVirtualNode(element, nil, hostTreeKey) + local wrapper = ShallowWrapper.new(rootNode, 1) + + Snapshot.createMatcher("component-with-event-props", wrapper):match() + end) + + it("should throw if the identifier contains invalid characters", function() + local invalidCharacters = {"\\", "/", "?"} + + for i=1, #invalidCharacters do + local function shouldThrow() + Snapshot.createMatcher("id" .. invalidCharacters[i], {}) + end + + expect(shouldThrow).to.throw() + end + end) +end \ No newline at end of file diff --git a/src/shallow/VirtualNodeConstraints/Constraints.lua b/src/shallow/VirtualNodeConstraints/Constraints.lua new file mode 100644 index 00000000..821b771e --- /dev/null +++ b/src/shallow/VirtualNodeConstraints/Constraints.lua @@ -0,0 +1,42 @@ +local RoactRoot = script.Parent.Parent.Parent + +local ElementKind = require(RoactRoot.ElementKind) + +local Constraints = setmetatable({}, { + __index = function(self, unexpectedConstraint) + error(("unknown constraint %q"):format(unexpectedConstraint)) + end, +}) + +function Constraints.kind(virtualNode, expectKind) + return ElementKind.of(virtualNode.currentElement) == expectKind +end + +function Constraints.className(virtualNode, className) + local element = virtualNode.currentElement + local isHost = ElementKind.of(element) == ElementKind.Host + + return isHost and element.component == className +end + +function Constraints.component(virtualNode, expectComponentValue) + return virtualNode.currentElement.component == expectComponentValue +end + +function Constraints.props(virtualNode, propSubSet) + local elementProps = virtualNode.currentElement.props + + for propKey, propValue in pairs(propSubSet) do + if elementProps[propKey] ~= propValue then + return false + end + end + + return true +end + +function Constraints.hostKey(virtualNode, expectHostKey) + return virtualNode.hostKey == expectHostKey +end + +return Constraints \ No newline at end of file diff --git a/src/shallow/VirtualNodeConstraints/Constraints.spec.lua b/src/shallow/VirtualNodeConstraints/Constraints.spec.lua new file mode 100644 index 00000000..5a40af42 --- /dev/null +++ b/src/shallow/VirtualNodeConstraints/Constraints.spec.lua @@ -0,0 +1,200 @@ +return function() + local RoactRoot = script.Parent.Parent.Parent + + local ElementKind = require(RoactRoot.ElementKind) + local createElement = require(RoactRoot.createElement) + local createReconciler = require(RoactRoot.createReconciler) + local RoactComponent = require(RoactRoot.Component) + local RobloxRenderer = require(RoactRoot.RobloxRenderer) + + local Constraints = require(script.Parent.Constraints) + + local robloxReconciler = createReconciler(RobloxRenderer) + + local HOST_PARENT = nil + local HOST_KEY = "ConstraintsTree" + + local function getVirtualNode(element) + return robloxReconciler.mountVirtualNode(element, HOST_PARENT, HOST_KEY) + end + + describe("kind", function() + it("should return true when the element is of the same kind", function() + local element = createElement("TextLabel") + local virtualNode = getVirtualNode(element) + + local result = Constraints.kind(virtualNode, ElementKind.Host) + + expect(result).to.equal(true) + end) + + it("should return false when the element is not of the same kind", function() + local element = createElement("TextLabel") + local virtualNode = getVirtualNode(element) + + local result = Constraints.kind(virtualNode, ElementKind.Stateful) + + expect(result).to.equal(false) + end) + end) + + describe("className", function() + it("should return true when a host virtualNode has the given class name", function() + local className = "TextLabel" + local element = createElement(className) + + local virtualNode = getVirtualNode(element) + + local result = Constraints.className(virtualNode, className) + + expect(result).to.equal(true) + end) + + it("should return false when a host virtualNode does not have the same class name", function() + local element = createElement("Frame") + + local virtualNode = getVirtualNode(element) + + local result = Constraints.className(virtualNode, "TextLabel") + + expect(result).to.equal(false) + end) + + it("should return false when not a host virtualNode", function() + local function Component() + return createElement("TextLabel") + end + local element = createElement(Component) + + local virtualNode = getVirtualNode(element) + + local result = Constraints.className(virtualNode, "TextLabel") + + expect(result).to.equal(false) + end) + end) + + describe("component", function() + it("should return true given a host virtualNode with the same class name", function() + local className = "TextLabel" + local element = createElement(className) + + local virtualNode = getVirtualNode(element) + + local result = Constraints.component(virtualNode, className) + + expect(result).to.equal(true) + end) + + it("should return true given a functional virtualNode function", function() + local function Component(props) + return nil + end + + local element = createElement(Component) + local virtualNode = getVirtualNode(element) + + local result = Constraints.component(virtualNode, Component) + + expect(result).to.equal(true) + end) + + it("should return true given a stateful virtualNode component class", function() + local Component = RoactComponent:extend("Foo") + + function Component:render() + return nil + end + + local element = createElement(Component) + local virtualNode = getVirtualNode(element) + + local result = Constraints.component(virtualNode, Component) + + expect(result).to.equal(true) + end) + + it("should return false when components kind do not match", function() + local function Component(props) + return nil + end + + local element = createElement(Component) + local virtualNode = getVirtualNode(element) + + local result = Constraints.component(virtualNode, "TextLabel") + + expect(result).to.equal(false) + end) + end) + + describe("props", function() + it("should return true when the virtualNode satisfies all prop constraints", function() + local props = { + Visible = false, + LayoutOrder = 7, + } + local element = createElement("TextLabel", props) + local virtualNode = getVirtualNode(element) + + local result = Constraints.props(virtualNode, { + Visible = false, + LayoutOrder = 7, + }) + + expect(result).to.equal(true) + end) + + it("should return true if the props are from a subset of the virtualNode props", function() + local props = { + Visible = false, + LayoutOrder = 7, + } + + local element = createElement("TextLabel", props) + local virtualNode = getVirtualNode(element) + + local result = Constraints.props(virtualNode, { + LayoutOrder = 7, + }) + + expect(result).to.equal(true) + end) + + it("should return false if a subset of the props are different from the given props", function() + local props = { + Visible = false, + LayoutOrder = 1, + } + + local element = createElement("TextLabel", props) + local virtualNode = getVirtualNode(element) + + local result = Constraints.props(virtualNode, { + LayoutOrder = 7, + }) + + expect(result).to.equal(false) + end) + end) + + describe("hostKey", function() + it("should return true when the virtualNode has the same hostKey", function() + local element = createElement("TextLabel") + local virtualNode = getVirtualNode(element) + + local result = Constraints.hostKey(virtualNode, HOST_KEY) + + expect(result).to.equal(true) + end) + + it("should return false when the virtualNode hostKey is different", function() + local element = createElement("TextLabel") + local virtualNode = getVirtualNode(element) + + local result = Constraints.hostKey(virtualNode, "foo") + + expect(result).to.equal(false) + end) + end) +end \ No newline at end of file diff --git a/src/shallow/VirtualNodeConstraints/init.lua b/src/shallow/VirtualNodeConstraints/init.lua new file mode 100644 index 00000000..1b3e2a7e --- /dev/null +++ b/src/shallow/VirtualNodeConstraints/init.lua @@ -0,0 +1,24 @@ +local Constraints = require(script.Constraints) + +local function satisfiesAll(virtualNode, constraints) + for constraint, value in pairs(constraints) do + local constraintFunction = Constraints[constraint] + + if not constraintFunction(virtualNode, value) then + return false + end + end + + return true +end + +local function validate(constraints) + for constraint in pairs(constraints) do + assert(Constraints[constraint] ~= nil, ("unknown constraint %q"):format(constraint)) + end +end + +return { + satisfiesAll = satisfiesAll, + validate = validate, +} \ No newline at end of file diff --git a/src/shallow/VirtualNodeConstraints/init.spec.lua b/src/shallow/VirtualNodeConstraints/init.spec.lua new file mode 100644 index 00000000..d7df1e91 --- /dev/null +++ b/src/shallow/VirtualNodeConstraints/init.spec.lua @@ -0,0 +1,31 @@ +return function() + local VirtualNodesConstraints = require(script.Parent) + + describe("validate", function() + it("should throw when a constraint does not exist", function() + local constraints = { + hostKey = "Key", + foo = "bar", + } + + local function validateNotExistingConstraint() + VirtualNodesConstraints.validate(constraints) + end + + expect(validateNotExistingConstraint).to.throw() + end) + + it("should not throw when all constraints exsits", function() + local constraints = { + hostKey = "Key", + className = "Frame", + } + + local function validateExistingConstraints() + VirtualNodesConstraints.validate(constraints) + end + + expect(validateExistingConstraints).never.to.throw() + end) + end) +end \ No newline at end of file diff --git a/src/shallow/init.lua b/src/shallow/init.lua new file mode 100644 index 00000000..e4dd6b9f --- /dev/null +++ b/src/shallow/init.lua @@ -0,0 +1,21 @@ +local createReconciler = require(script.Parent.createReconciler) +local Type = require(script.Parent.Type) +local RobloxRenderer = require(script.Parent.RobloxRenderer) +local ShallowWrapper = require(script.ShallowWrapper) + +local robloxReconciler = createReconciler(RobloxRenderer) + +local shallowTreeKey = "RoactTree" + +local function shallow(element, options) + assert(Type.of(element) == Type.Element, "Expected arg #1 to be an Element") + + options = options or {} + local maxDepth = options.depth or 1 + + local rootNode = robloxReconciler.mountVirtualNode(element, nil, shallowTreeKey) + + return ShallowWrapper.new(rootNode, maxDepth) +end + +return shallow \ No newline at end of file diff --git a/src/shallow/init.spec.lua b/src/shallow/init.spec.lua new file mode 100644 index 00000000..10e0ee65 --- /dev/null +++ b/src/shallow/init.spec.lua @@ -0,0 +1,17 @@ +return function() + local createElement = require(script.Parent.Parent.createElement) + local shallow = require(script.Parent) + + it("should return a shallow wrapper with depth = 1 by default", function() + local element = createElement("Frame", {}, { + Child = createElement("Frame", {}, { + SubChild = createElement("Frame"), + }), + }) + + local wrapper = shallow(element) + local childWrapper = wrapper:findUnique() + + expect(childWrapper:childrenCount()).to.equal(0) + end) +end \ No newline at end of file