Skip to content

Commit

Permalink
Merge pull request #340 from uhziel/doc-type-table
Browse files Browse the repository at this point in the history
 添加对 EmmyLua table type 的支持
  • Loading branch information
sumneko committed Jan 20, 2021
2 parents 2fac237 + 217aa4e commit 924aa35
Show file tree
Hide file tree
Showing 6 changed files with 218 additions and 21 deletions.
2 changes: 2 additions & 0 deletions script/parser/compile.lua
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ local specials = {
['loadfile'] = true,
['pcall'] = true,
['xpcall'] = true,
['pairs'] = true,
['ipairs'] = true,
}

_ENV = nil
Expand Down
108 changes: 102 additions & 6 deletions script/parser/guide.lua
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ m.childMap = {
['doc.generic'] = {'#generics', 'comment'},
['doc.generic.object'] = {'generic', 'extends', 'comment'},
['doc.vararg'] = {'vararg', 'comment'},
['doc.type.table'] = {'key', 'value', 'comment'},
['doc.type.table'] = {'node', 'key', 'value', 'comment'},
['doc.type.function'] = {'#args', '#returns', 'comment'},
['doc.type.typeliteral'] = {'node'},
['doc.overload'] = {'overload', 'comment'},
Expand Down Expand Up @@ -169,6 +169,21 @@ function m.getParentFunction(obj)
return nil
end

--- 寻找父的table类型 doc.type.table
function m.getParentDocTypeTable(obj)
for _ = 1, 1000 do
local parent = obj.parent
if not parent then
return nil
end
if parent.type == 'doc.type.table' then
return obj
end
obj = parent
end
error('guide.getParentDocTypeTable overstack')
end

--- 寻找所在区块
function m.getBlock(obj)
for _ = 1, 1000 do
Expand Down Expand Up @@ -1671,12 +1686,20 @@ function m.checkSameSimpleOfRefByDocSource(status, obj, start, pushQueue, mode)
end
end

local function getArrayLevel(obj)
local function getArrayOrTableLevel(obj)
local level = 0
while true do
local parent = obj.parent
if parent.type == 'doc.type.array' then
level = level + 1
elseif parent.type == 'doc.type.table' then
if obj.type == 'doc.type' then
level = level + 1
-- else 只存在 obj.type == 'doc.type.name' 的情况,即 table<k,v> 中的 table,这种是不需要再增加层级的
end
elseif parent.type == 'doc.type' and parent.parent and parent.parent.type == 'doc.type.table' then
level = level + 1
parent = parent.parent
else
break
end
Expand Down Expand Up @@ -1733,9 +1756,10 @@ function m.checkSameSimpleByDoc(status, obj, start, pushQueue, mode)
for _, res in ipairs(pieceResult) do
pushQueue(res, start, true)
end

local state = m.getDocState(obj)
if state.type == 'doc.type' and mode == 'ref' then
m.checkSameSimpleOfRefByDocSource(status, state, start - getArrayLevel(obj), pushQueue, mode)
m.checkSameSimpleOfRefByDocSource(status, state, start - getArrayOrTableLevel(obj), pushQueue, mode)
end
return true
elseif obj.type == 'doc.field' then
Expand All @@ -1746,6 +1770,10 @@ function m.checkSameSimpleByDoc(status, obj, start, pushQueue, mode)
elseif obj.type == 'doc.type.array' then
pushQueue(obj.node, start + 1, true)
return true
elseif obj.type == 'doc.type.table' then
pushQueue(obj.node, start, true)
pushQueue(obj.value, start + 1, true)
return true
end
end

Expand Down Expand Up @@ -2211,6 +2239,72 @@ function m.checkSameSimpleAsSetValue(status, ref, start, pushQueue)
end
end

local function getTableAndIndexIfIsForPairsKeyOrValue(ref)
if ref.type ~= 'local' then
return
end

if not ref.parent or ref.parent.type ~= 'in' then
return
end

if not ref.value or ref.value.type ~= 'select' then
return
end

local rootSelectObj = ref.value
if rootSelectObj.index ~= 1 and rootSelectObj.index ~= 2 then
return
end

if not rootSelectObj.vararg or rootSelectObj.vararg.type ~= 'call' then
return
end
local rootCallObj = rootSelectObj.vararg

if not rootCallObj.node or rootCallObj.node.type ~= 'call' then
return
end
local pairsCallObj = rootCallObj.node

if not pairsCallObj.node
or (pairsCallObj.node.special ~= 'pairs' and pairsCallObj.node.special ~= 'ipairs') then
return
end

if not pairsCallObj.args or not pairsCallObj.args[1] then
return
end
local tableObj = pairsCallObj.args[1]

return tableObj, rootSelectObj.index
end

function m.checkSameSimpleAsKeyOrValueInForParis(status, ref, start, pushQueue)
local tableObj, index = getTableAndIndexIfIsForPairsKeyOrValue(ref)
if not tableObj then
return
end

local newStatus = m.status(status)
m.searchRefs(newStatus, tableObj, 'def')
for _, def in ipairs(newStatus.results) do
if def.bindDocs then
for _, binddoc in ipairs(def.bindDocs) do
if binddoc.type == 'doc.type' then
if binddoc.types[1] and binddoc.types[1].type == 'doc.type.table' then
if index == 1 then
pushQueue(binddoc.types[1].key, start, true)
elseif index == 2 then
pushQueue(binddoc.types[1].value, start, true)
end
end
end
end
end
end
end

local function hasTypeName(doc, name)
if doc.type ~= 'doc.type' then
return false
Expand Down Expand Up @@ -2447,6 +2541,8 @@ function m.checkSameSimple(status, simple, ref, start, force, mode, pushQueue)
m.checkSameSimpleAsReturn(status, ref, i, pushQueue)
-- 检查形如 a = f 的情况
m.checkSameSimpleAsSetValue(status, ref, i, pushQueue)
-- 检查形如 for k,v in pairs()/ipairs() do end 的情况
m.checkSameSimpleAsKeyOrValueInForParis(status, ref, i, pushQueue)
end
end
if i == #simple then
Expand Down Expand Up @@ -2888,7 +2984,7 @@ function m.viewInferType(infers)
or src.type == 'doc.class.name'
or src.type == 'doc.type.name'
or src.type == 'doc.type.array'
or src.type == 'doc.type.generic' then
or src.type == 'doc.type.table' then
if infer.type ~= 'any' then
hasDoc = true
break
Expand All @@ -2903,7 +2999,7 @@ function m.viewInferType(infers)
or src.type == 'doc.class.name'
or src.type == 'doc.type.name'
or src.type == 'doc.type.array'
or src.type == 'doc.type.generic'
or src.type == 'doc.type.table'
or src.type == 'doc.type.enum'
or src.type == 'doc.resume' then
local tp = infer.type or 'any'
Expand Down Expand Up @@ -3132,7 +3228,7 @@ function m.getDocTypeUnitName(status, unit)
typeName = 'function'
elseif unit.type == 'doc.type.array' then
typeName = m.getDocTypeUnitName(status, unit.node) .. '[]'
elseif unit.type == 'doc.type.generic' then
elseif unit.type == 'doc.type.table' then
typeName = ('%s<%s, %s>'):format(
m.getDocTypeUnitName(status, unit.node),
m.viewInferType(m.getDocTypeNames(status, unit.key)),
Expand Down
30 changes: 18 additions & 12 deletions script/parser/luadoc.lua
Original file line number Diff line number Diff line change
Expand Up @@ -261,30 +261,36 @@ local function parseTypeUnitArray(node)
return result
end

local function parseTypeUnitGeneric(node)
local function parseTypeUnitTable(parent, node)
if not checkToken('symbol', '<', 1) then
return nil
end
if not nextSymbolOrError('<') then
return nil
end
local key = parseType(node)

local result = {
type = 'doc.type.table',
start = node.start,
node = node,
parent = parent,
}

local key = parseType(result)
if not key or not nextSymbolOrError(',') then
return nil
end
local value = parseType(node)
local value = parseType(result)
if not value then
return nil
end
nextSymbolOrError('>')
local result = {
type = 'doc.type.generic',
start = node.start,
finish = getFinish(),
node = node,
key = key,
value = value,
}

node.parent = result;
result.finish = getFinish()
result.key = key
result.value = value

return result
end

Expand Down Expand Up @@ -398,7 +404,7 @@ local function parseTypeUnit(parent, content)
result.parent = parent
while true do
local newResult = parseTypeUnitArray(result)
or parseTypeUnitGeneric(result)
or parseTypeUnitTable(parent, result)
if not newResult then
break
end
Expand Down
5 changes: 5 additions & 0 deletions script/vm/getDocs.lua
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ local function getTypesOfFile(uri)
or src.type == 'doc.class.name'
or src.type == 'doc.extends.name'
or src.type == 'doc.alias.name' then
if src.type == 'doc.type.name' then
if guide.getParentDocTypeTable(src) then
return
end
end
local name = src[1]
if name then
if not types[name] then
Expand Down
88 changes: 88 additions & 0 deletions test/definition/luadoc.lua
Original file line number Diff line number Diff line change
Expand Up @@ -253,3 +253,91 @@ function Generic(arg1) print(arg1) end
local v1 = Generic("Foo")
print(v1.<?bar1?>)
]]

TEST [[
---@class Foo
local Foo = {}
function Foo:<!bar1!>() end
---@type table<number, Foo>
local v1
print(v1[1].<?bar1?>)
]]

TEST [[
---@class Foo
local Foo = {}
function Foo:<!bar1!>() end
---@class Foo2
local Foo2 = {}
function Foo2:bar1() end
---@type Foo2<number, Foo>
local v1
print(v1[1].<?bar1?>)
]]

--TODO 得扩展 simple 的信息才能识别这种情况了
--TEST [[
-----@class Foo
--local Foo = {}
--function Foo:bar1() end
--
-----@class Foo2
--local Foo2 = {}
--function Foo2:<!bar1!>() end
--
-----@type Foo2<number, Foo>
--local v1
--print(v1.<?bar1?>)
--]]

TEST [[
---@class Foo
local Foo = {}
function Foo:<!bar1!>() end
---@type table<number, Foo>
local v1
local ipairs = ipairs
for i, v in ipairs(v1) do
print(v.<?bar1?>)
end
]]

TEST [[
---@class Foo
local Foo = {}
function Foo:<!bar1!>() end
---@type table<Foo, Foo>
local v1
for k, v in pairs(v1) do
print(k.<?bar1?>)
print(v.bar1)
end
]]

TEST [[
---@class Foo
local Foo = {}
function Foo:<!bar1!>() end
---@type table<number, table<number, Foo>>
local v1
for i, v in ipairs(v1) do
local v2 = v[1]
print(v2.<?bar1?>)
end
]]

TEST [[
---@class Foo
local Foo = {}
function Foo:<!bar1!>() end
---@type table<number, table<number, Foo>>
local v1
print(v1[1][1].<?bar1?>)
]]
6 changes: 3 additions & 3 deletions test/example/guide.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2702,7 +2702,7 @@ function m.viewInferType(infers)
or src.type == 'doc.class.name'
or src.type == 'doc.type.name'
or src.type == 'doc.type.array'
or src.type == 'doc.type.generic' then
or src.type == 'doc.type.table' then
if infer.type ~= 'any' then
hasDoc = true
break
Expand All @@ -2717,7 +2717,7 @@ function m.viewInferType(infers)
or src.type == 'doc.class.name'
or src.type == 'doc.type.name'
or src.type == 'doc.type.array'
or src.type == 'doc.type.generic'
or src.type == 'doc.type.table'
or src.type == 'doc.type.enum'
or src.type == 'doc.resume' then
local tp = infer.type or 'any'
Expand Down Expand Up @@ -2946,7 +2946,7 @@ local function getDocTypeUnitName(status, unit)
typeName = 'function'
elseif unit.type == 'doc.type.array' then
typeName = getDocTypeUnitName(status, unit.node) .. '[]'
elseif unit.type == 'doc.type.generic' then
elseif unit.type == 'doc.type.table' then
typeName = ('%s<%s, %s>'):format(
getDocTypeUnitName(status, unit.node),
m.viewInferType(m.getDocTypeNames(status, unit.key)),
Expand Down

0 comments on commit 924aa35

Please sign in to comment.