diff --git a/script/parser/compile.lua b/script/parser/compile.lua index d4129ab4f..f9992047a 100644 --- a/script/parser/compile.lua +++ b/script/parser/compile.lua @@ -12,6 +12,8 @@ local specials = { ['loadfile'] = true, ['pcall'] = true, ['xpcall'] = true, + ['pairs'] = true, + ['ipairs'] = true, } _ENV = nil diff --git a/script/parser/guide.lua b/script/parser/guide.lua index 354a5eed7..d0e82ba8a 100644 --- a/script/parser/guide.lua +++ b/script/parser/guide.lua @@ -97,7 +97,7 @@ m.childMap = { ['doc.generic'] = {'#generics', 'comment'}, ['doc.generic.object'] = {'generic', 'extends', 'comment'}, ['doc.vararg'] = {'vararg', 'comment'}, - ['doc.type.table'] = {'key', 'value', 'comment'}, + ['doc.type.table'] = {'node', 'key', 'value', 'comment'}, ['doc.type.function'] = {'#args', '#returns', 'comment'}, ['doc.type.typeliteral'] = {'node'}, ['doc.overload'] = {'overload', 'comment'}, @@ -169,6 +169,21 @@ function m.getParentFunction(obj) return nil end +--- 寻找父的table类型 doc.type.table +function m.getParentDocTypeTable(obj) + for _ = 1, 1000 do + local parent = obj.parent + if not parent then + return nil + end + if parent.type == 'doc.type.table' then + return obj + end + obj = parent + end + error('guide.getParentDocTypeTable overstack') +end + --- 寻找所在区块 function m.getBlock(obj) for _ = 1, 1000 do @@ -1648,12 +1663,20 @@ function m.checkSameSimpleOfRefByDocSource(status, obj, start, pushQueue, mode) end end -local function getArrayLevel(obj) +local function getArrayOrTableLevel(obj) local level = 0 while true do local parent = obj.parent if parent.type == 'doc.type.array' then level = level + 1 + elseif parent.type == 'doc.type.table' then + if obj.type == 'doc.type' then + level = level + 1 + -- else 只存在 obj.type == 'doc.type.name' 的情况,即 table 中的 table,这种是不需要再增加层级的 + end + elseif parent.type == 'doc.type' and parent.parent and parent.parent.type == 'doc.type.table' then + level = level + 1 + parent = parent.parent else break end @@ -1710,9 +1733,10 @@ function m.checkSameSimpleByDoc(status, obj, start, pushQueue, mode) for _, res in ipairs(pieceResult) do pushQueue(res, start, true) end + local state = m.getDocState(obj) if state.type == 'doc.type' and mode == 'ref' then - m.checkSameSimpleOfRefByDocSource(status, state, start - getArrayLevel(obj), pushQueue, mode) + m.checkSameSimpleOfRefByDocSource(status, state, start - getArrayOrTableLevel(obj), pushQueue, mode) end return true elseif obj.type == 'doc.field' then @@ -1723,6 +1747,10 @@ function m.checkSameSimpleByDoc(status, obj, start, pushQueue, mode) elseif obj.type == 'doc.type.array' then pushQueue(obj.node, start + 1, true) return true + elseif obj.type == 'doc.type.table' then + pushQueue(obj.node, start, true) + pushQueue(obj.value, start + 1, true) + return true end end @@ -2188,6 +2216,72 @@ function m.checkSameSimpleAsSetValue(status, ref, start, pushQueue) end end +local function getTableAndIndexIfIsForPairsKeyOrValue(ref) + if ref.type ~= 'local' then + return + end + + if not ref.parent or ref.parent.type ~= 'in' then + return + end + + if not ref.value or ref.value.type ~= 'select' then + return + end + + local rootSelectObj = ref.value + if rootSelectObj.index ~= 1 and rootSelectObj.index ~= 2 then + return + end + + if not rootSelectObj.vararg or rootSelectObj.vararg.type ~= 'call' then + return + end + local rootCallObj = rootSelectObj.vararg + + if not rootCallObj.node or rootCallObj.node.type ~= 'call' then + return + end + local pairsCallObj = rootCallObj.node + + if not pairsCallObj.node + or (pairsCallObj.node.special ~= 'pairs' and pairsCallObj.node.special ~= 'ipairs') then + return + end + + if not pairsCallObj.args or not pairsCallObj.args[1] then + return + end + local tableObj = pairsCallObj.args[1] + + return tableObj, rootSelectObj.index +end + +function m.checkSameSimpleAsKeyOrValueInForParis(status, ref, start, pushQueue) + local tableObj, index = getTableAndIndexIfIsForPairsKeyOrValue(ref) + if not tableObj then + return + end + + local newStatus = m.status(status) + m.searchRefs(newStatus, tableObj, 'def') + for _, def in ipairs(newStatus.results) do + if def.bindDocs then + for _, binddoc in ipairs(def.bindDocs) do + if binddoc.type == 'doc.type' then + if binddoc.types[1] and binddoc.types[1].type == 'doc.type.table' then + if index == 1 then + pushQueue(binddoc.types[1].key, start, true) + elseif index == 2 then + pushQueue(binddoc.types[1].value, start, true) + end + end + end + end + end + end +end + local function hasTypeName(doc, name) if doc.type ~= 'doc.type' then return false @@ -2422,6 +2516,8 @@ function m.checkSameSimple(status, simple, ref, start, force, mode, pushQueue) m.checkSameSimpleAsReturn(status, ref, i, pushQueue) -- 检查形如 a = f 的情况 m.checkSameSimpleAsSetValue(status, ref, i, pushQueue) + -- 检查形如 for k,v in pairs()/ipairs() do end 的情况 + m.checkSameSimpleAsKeyOrValueInForParis(status, ref, i, pushQueue) end end if i == #simple then @@ -2863,7 +2959,7 @@ function m.viewInferType(infers) or src.type == 'doc.class.name' or src.type == 'doc.type.name' or src.type == 'doc.type.array' - or src.type == 'doc.type.generic' then + or src.type == 'doc.type.table' then if infer.type ~= 'any' then hasDoc = true break @@ -2878,7 +2974,7 @@ function m.viewInferType(infers) or src.type == 'doc.class.name' or src.type == 'doc.type.name' or src.type == 'doc.type.array' - or src.type == 'doc.type.generic' + or src.type == 'doc.type.table' or src.type == 'doc.type.enum' or src.type == 'doc.resume' then local tp = infer.type or 'any' @@ -3107,7 +3203,7 @@ function m.getDocTypeUnitName(status, unit) typeName = 'function' elseif unit.type == 'doc.type.array' then typeName = m.getDocTypeUnitName(status, unit.node) .. '[]' - elseif unit.type == 'doc.type.generic' then + elseif unit.type == 'doc.type.table' then typeName = ('%s<%s, %s>'):format( m.getDocTypeUnitName(status, unit.node), m.viewInferType(m.getDocTypeNames(status, unit.key)), diff --git a/script/parser/luadoc.lua b/script/parser/luadoc.lua index 647a6bedc..47248ba4f 100644 --- a/script/parser/luadoc.lua +++ b/script/parser/luadoc.lua @@ -261,30 +261,36 @@ local function parseTypeUnitArray(node) return result end -local function parseTypeUnitGeneric(node) +local function parseTypeUnitTable(parent, node) if not checkToken('symbol', '<', 1) then return nil end if not nextSymbolOrError('<') then return nil end - local key = parseType(node) + + local result = { + type = 'doc.type.table', + start = node.start, + node = node, + parent = parent, + } + + local key = parseType(result) if not key or not nextSymbolOrError(',') then return nil end - local value = parseType(node) + local value = parseType(result) if not value then return nil end nextSymbolOrError('>') - local result = { - type = 'doc.type.generic', - start = node.start, - finish = getFinish(), - node = node, - key = key, - value = value, - } + + node.parent = result; + result.finish = getFinish() + result.key = key + result.value = value + return result end @@ -398,7 +404,7 @@ local function parseTypeUnit(parent, content) result.parent = parent while true do local newResult = parseTypeUnitArray(result) - or parseTypeUnitGeneric(result) + or parseTypeUnitTable(parent, result) if not newResult then break end diff --git a/script/vm/getDocs.lua b/script/vm/getDocs.lua index 632dd1c28..790a9b50d 100644 --- a/script/vm/getDocs.lua +++ b/script/vm/getDocs.lua @@ -16,6 +16,11 @@ local function getTypesOfFile(uri) or src.type == 'doc.class.name' or src.type == 'doc.extends.name' or src.type == 'doc.alias.name' then + if src.type == 'doc.type.name' then + if guide.getParentDocTypeTable(src) then + return + end + end local name = src[1] if name then if not types[name] then diff --git a/test/definition/luadoc.lua b/test/definition/luadoc.lua index 1f3dae007..5315b5fd6 100644 --- a/test/definition/luadoc.lua +++ b/test/definition/luadoc.lua @@ -253,3 +253,91 @@ function Generic(arg1) print(arg1) end local v1 = Generic("Foo") print(v1.) ]] + +TEST [[ +---@class Foo +local Foo = {} +function Foo:() end + +---@type table +local v1 +print(v1[1].) +]] + +TEST [[ +---@class Foo +local Foo = {} +function Foo:() end + +---@class Foo2 +local Foo2 = {} +function Foo2:bar1() end + +---@type Foo2 +local v1 +print(v1[1].) +]] + +--TODO 得扩展 simple 的信息才能识别这种情况了 +--TEST [[ +-----@class Foo +--local Foo = {} +--function Foo:bar1() end +-- +-----@class Foo2 +--local Foo2 = {} +--function Foo2:() end +-- +-----@type Foo2 +--local v1 +--print(v1.) +--]] + +TEST [[ +---@class Foo +local Foo = {} +function Foo:() end + +---@type table +local v1 +local ipairs = ipairs +for i, v in ipairs(v1) do + print(v.) +end +]] + +TEST [[ +---@class Foo +local Foo = {} +function Foo:() end + +---@type table +local v1 +for k, v in pairs(v1) do + print(k.) + print(v.bar1) +end +]] + +TEST [[ +---@class Foo +local Foo = {} +function Foo:() end + +---@type table> +local v1 +for i, v in ipairs(v1) do + local v2 = v[1] + print(v2.) +end +]] + +TEST [[ +---@class Foo +local Foo = {} +function Foo:() end + +---@type table> +local v1 +print(v1[1][1].) +]] diff --git a/test/example/guide.txt b/test/example/guide.txt index 437e37b0c..da8d5c326 100644 --- a/test/example/guide.txt +++ b/test/example/guide.txt @@ -2702,7 +2702,7 @@ function m.viewInferType(infers) or src.type == 'doc.class.name' or src.type == 'doc.type.name' or src.type == 'doc.type.array' - or src.type == 'doc.type.generic' then + or src.type == 'doc.type.table' then if infer.type ~= 'any' then hasDoc = true break @@ -2717,7 +2717,7 @@ function m.viewInferType(infers) or src.type == 'doc.class.name' or src.type == 'doc.type.name' or src.type == 'doc.type.array' - or src.type == 'doc.type.generic' + or src.type == 'doc.type.table' or src.type == 'doc.type.enum' or src.type == 'doc.resume' then local tp = infer.type or 'any' @@ -2946,7 +2946,7 @@ local function getDocTypeUnitName(status, unit) typeName = 'function' elseif unit.type == 'doc.type.array' then typeName = getDocTypeUnitName(status, unit.node) .. '[]' - elseif unit.type == 'doc.type.generic' then + elseif unit.type == 'doc.type.table' then typeName = ('%s<%s, %s>'):format( getDocTypeUnitName(status, unit.node), m.viewInferType(m.getDocTypeNames(status, unit.key)),