diff --git a/src/vba_unit/Coverage/coverage.py b/src/vba_unit/Coverage/coverage.py index fb020bf..7c88fca 100644 --- a/src/vba_unit/Coverage/coverage.py +++ b/src/vba_unit/Coverage/coverage.py @@ -2,7 +2,7 @@ import requests from typing import TypeVar from .git_repo import GitRepo -from vba_unit.Interpreter.coverage_table import CoverageTable +from pyvba_interpreter.symbol_table import SymbolTable T = TypeVar('T', bound='Coverage') @@ -16,7 +16,7 @@ class Coverage(): """ def __init__(self: T) -> None: self.endpoint = '' - self.table: CoverageTable + self.table: SymbolTable self.git: GitRepo def generate_report(self: T) -> dict: diff --git a/src/vba_unit/Coverage/coveralls.py b/src/vba_unit/Coverage/coveralls.py index 179f3a2..481896d 100644 --- a/src/vba_unit/Coverage/coveralls.py +++ b/src/vba_unit/Coverage/coveralls.py @@ -3,7 +3,7 @@ from typing import TypeVar from vba_unit.Coverage.coverage import Coverage from vba_unit.Coverage.git_repo import GitRepo -from vba_unit.Interpreter.coverage_table import CoverageTable, VbaUnitModDef +from pyvba_interpreter.symbol_table import ModuleDefinition, SymbolTable T = TypeVar('T', bound='Coveralls') @@ -13,13 +13,13 @@ class Coveralls(Coverage): def __init__(self: T) -> None: self.endpoint = "https://coveralls.io/api/v1/jobs" self.git: GitRepo - self.table: CoverageTable + self.table: SymbolTable def generate_report(self: T) -> dict: source_files = [] for lib in self.table.definitions.values(): for module in lib["modules"].values(): - if module["cover"]: + if module["extra"]["vba_unit"]["cover"]: file_cov = self.file_coverage(module) source_files.append(file_cov) @@ -35,13 +35,13 @@ def generate_report(self: T) -> dict: } return report - def file_coverage(self: T, module: VbaUnitModDef) -> dict: - file_path = module["path"] + def file_coverage(self: T, module: ModuleDefinition) -> dict: + file_path = module["extra"]["vba_unit"]["path"] with open(file_path, 'r') as f: source_code = f.read() digest = hashlib.md5(source_code.encode('utf-8')).hexdigest() return { "name": file_path, "source_digest": digest, - "coverage": module["coverage"], + "coverage": module["extra"]["vba_unit"]["coverage"], } diff --git a/src/vba_unit/Interpreter/coverage_table.py b/src/vba_unit/Interpreter/coverage_table.py deleted file mode 100644 index 3032948..0000000 --- a/src/vba_unit/Interpreter/coverage_table.py +++ /dev/null @@ -1,23 +0,0 @@ -from pyvba_interpreter.symbol_table import ( - FunctionDefinition, ModuleDefinition, ProjectDefinition, SymbolTable -) - - -class VbaUnitFuncDef(FunctionDefinition): - visited: bool - start_end_lines: tuple[int, int] - - -class VbaUnitModDef(ModuleDefinition): - path: str # The file path - cover: bool # Track coverage on this file? - coverage: list[None | int] # lines covered - functions: dict[str, VbaUnitFuncDef] # {@inheritDoc} - - -class VbaUnitProjDef(ProjectDefinition): - modules: dict[str, VbaUnitModDef] - - -class CoverageTable(SymbolTable): - pass diff --git a/src/vba_unit/Interpreter/vba_unit_listener.py b/src/vba_unit/Interpreter/vba_unit_listener.py index 0d50fae..79ec950 100644 --- a/src/vba_unit/Interpreter/vba_unit_listener.py +++ b/src/vba_unit/Interpreter/vba_unit_listener.py @@ -1,27 +1,46 @@ +from antlr4 import CommonTokenStream from antlr4_vba.vbaParser import vbaParser as Parser from antlr4_vba.vbaLexer import vbaLexer as Lexer +from pyvba_interpreter.symbol_table import SymbolTable from pyvba_interpreter.vba_listener import VbaListener -from typing import TypeVar +from typing import cast, TypedDict, TypeVar T = TypeVar('T', bound='VbaUnitListener') +class VbaUnitModuleextra(TypedDict): + path: str # The file path + cover: bool # Track coverage on this file? + coverage: list[None | int] # lines covered + + class VbaUnitListener(VbaListener): + def __init__(self: T, project: str, table: SymbolTable) -> None: + self.parser: Parser + super().__init__(project, table) def enterProceduralModuleHeader( # noqa: N802 self: T, ctx: Parser.ProceduralModuleHeaderContext) -> None: super().enterProceduralModuleHeader(ctx) - token_stream = self.parser.getInputStream() + token_stream = cast(CommonTokenStream, self.parser.getTokenStream()) token_stream.fill() eof_token = token_stream.get(len(token_stream.tokens) - 1) total_lines = eof_token.line - mods = self.table.definitions[self.project_name]["modules"] - mods[self.module_name.lower()]["coverage"] = [0] * total_lines - mods[self.module_name.lower()]["cover"] = True - mods[self.module_name.lower()]["coverage"][total_lines - 1] = None - mods[self.module_name.lower()]["coverage"][ctx.start.line - 1] = 1 + name = self.module_name.lower() + mod = self.table.definitions[self.project_name]["modules"][name] + extra: VbaUnitModuleextra = { + "coverage": [0] * total_lines, + "cover": True, + "path": '' + } + mod["extra"]["vba_unit"] = extra + # EOF line is ignored. + # Need to test the case where EOF is on the same line as code. + mod["extra"]["vba_unit"]["coverage"][total_lines - 1] = None + if ctx.start is not None: + mod["extra"]["vba_unit"]["coverage"][ctx.start.line - 1] = 1 def enterCommentBody( # noqa: N802 self: T, @@ -29,22 +48,23 @@ def enterCommentBody( # noqa: N802 super().enterCommentBody(ctx) # Check if the token starts at column 1 or if the token before this is # a wsc and it starts at column 1 - in_str = self.parser.getInputStream() - if ctx.start is not None: + in_str = cast(CommonTokenStream, self.parser.getTokenStream()) + if ctx.start is not None and ctx.start.tokenIndex is not None: if (ctx.start.column == 0 or in_str.get(ctx.start.tokenIndex - 1).column == 0): # Comments cannot be the first token in a file, so # tokenIndex - 1 cannot be less than zero index_num = ctx.start.line - 1 + name = self.module_name.lower() mods = self.table.definitions[self.project_name]["modules"] - mods[self.module_name.lower()]["coverage"][index_num] = None + mods[name]["extra"]["vba_unit"]["coverage"][index_num] = None def enterEndOfLine( # noqa: N802 self: T, ctx: Parser.EndOfLineContext) -> None: super().enterEndOfLine(ctx) - in_str = self.parser.getInputStream() - if ctx.start is not None: + in_str = cast(CommonTokenStream, self.parser.getTokenStream()) + if ctx.start is not None and ctx.start.tokenIndex is not None: tok_ind = ctx.start.tokenIndex is_wsc = in_str.get(tok_ind - 1).type == Lexer.WS if ( @@ -52,5 +72,6 @@ def enterEndOfLine( # noqa: N802 (is_wsc and (in_str.get(tok_ind - 1).column == 0)) ): index_num = ctx.start.line - 1 + name = self.module_name.lower() mods = self.table.definitions[self.project_name]["modules"] - mods[self.module_name.lower()]["coverage"][index_num] = None + mods[name]["extra"]["vba_unit"]["coverage"][index_num] = None diff --git a/src/vba_unit/Interpreter/vba_unit_visitor.py b/src/vba_unit/Interpreter/vba_unit_visitor.py index f5b9137..8cbd483 100644 --- a/src/vba_unit/Interpreter/vba_unit_visitor.py +++ b/src/vba_unit/Interpreter/vba_unit_visitor.py @@ -1,10 +1,12 @@ from antlr4.tree.Tree import Tree from antlr4 import ParserRuleContext from antlr4_vba.vbaParser import vbaParser as Parser +from pyvba_interpreter.symbol_table import ( + FunctionDefinition, LibraryDefinition, SymbolTable +) from pyvba_interpreter.vba_visitor import VbaVisitor from typing import Any, TypeVar from vba_unit.test_fail_exception import TestFailException -from .coverage_table import VbaUnitFuncDef, CoverageTable T = TypeVar('T', bound='VbaUnitVisitor') @@ -12,7 +14,7 @@ class VbaUnitVisitor(VbaVisitor): - def __init__(self: T, table: CoverageTable) -> None: + def __init__(self: T, table: SymbolTable) -> None: self.current_line = 0 super().__init__(table) @@ -24,14 +26,15 @@ def visit(self: T, tree: Tree) -> Any: prev_line = self.current_line if tok is not None: mods = self.table.definitions[self.context[0]]["modules"] - if mods[self.context[1]]["cover"]: + if mods[self.context[1]]["extra"]["vba_unit"]["cover"]: line_num = tok.line if line_num != self.current_line: self.current_line = line_num self.context_changed = False name = self.context[0] mods = self.table.definitions[name]["modules"] - coverage = mods[self.context[1]]["coverage"] + extra = mods[self.context[1]]["extra"]["vba_unit"] + coverage = extra["coverage"] coverage[line_num - 1] += 1 # Call the original visit to continue traversal return super().visit(tree) @@ -43,10 +46,11 @@ def visitFunctionDeclaration( # noqa: N802 # Touch the end function statement # If there is an Exit Function statement immediately beore the end, # is there a way to prohibit it from being touched...does it matter? - line_num = ctx.stop.line - mods = self.table.definitions[self.context[0]]["modules"] - coverage = mods[self.context[1]]["coverage"] - coverage[line_num - 1] += 1 + if ctx.stop is not None: + line_num = ctx.stop.line + mods = self.table.definitions[self.context[0]]["modules"] + coverage = mods[self.context[1]]["extra"]["vba_unit"]["coverage"] + coverage[line_num - 1] += 1 return super().visitFunctionDeclaration(ctx) @@ -60,7 +64,7 @@ def visitAssertStatement( # noqa: N802 raise TestFailException() def run_function(self: T, - defn: VbaUnitFuncDef, + defn: FunctionDefinition | LibraryDefinition, args: list[Any]) -> Any: prev_line = self.current_line if (self.context[0] != defn["project"] or diff --git a/src/vba_unit/cli.py b/src/vba_unit/cli.py index 0447684..4513572 100644 --- a/src/vba_unit/cli.py +++ b/src/vba_unit/cli.py @@ -4,10 +4,10 @@ from antlr4 import FileStream, CommonTokenStream, ParseTreeWalker from antlr4_vba.vbaLexer import vbaLexer from antlr4_vba.vbaParser import vbaParser +from pyvba_interpreter.symbol_table import ModuleDefinition, SymbolTable from typing import TypeVar from vba_unit.Coverage.coverage_factory import CovFact from vba_unit.Coverage.git_factory import GitFact -from vba_unit.Interpreter.coverage_table import VbaUnitModDef, CoverageTable from vba_unit.Interpreter.vba_unit_listener import VbaUnitListener from vba_unit.Interpreter.vba_unit_visitor import VbaUnitVisitor from vba_unit.test_fail_exception import TestFailException @@ -52,7 +52,7 @@ def main() -> None: ) args = parser.parse_args() - table = CoverageTable() + table = SymbolTable() run_tests(args.src, args.tests, args.project, table) # Submit Coverage @@ -72,7 +72,7 @@ def main() -> None: def run_tests(src: str, tests: str, - project_name: str, table: CoverageTable) -> None: + project_name: str, table: SymbolTable) -> None: test_project_name = "vbatests" # Parse source code @@ -95,7 +95,7 @@ def run_tests(src: str, tests: str, _generate_report(report) -def _parse_file(file_path: str, project: str, table: CoverageTable) -> None: +def _parse_file(file_path: str, project: str, table: SymbolTable) -> None: input_stream = FileStream(file_path, encoding="cp1252") lexer = vbaLexer(input_stream) ts = CommonTokenStream(lexer) @@ -107,16 +107,18 @@ def _parse_file(file_path: str, project: str, table: CoverageTable) -> None: walker.walk(listener, tree) mod_name = listener.module_name.lower() project = project.lower() - table.definitions[project]["modules"][mod_name]["path"] = file_path + mod = table.definitions[project]["modules"][mod_name] + extra = mod["extra"]["vba_unit"] + extra["path"] = file_path if project == "vbatests": - table.definitions[project]["modules"][mod_name]["cover"] = False + extra["cover"] = False else: - table.definitions[project]["modules"][mod_name]["cover"] = True + extra["cover"] = True def _run_all_tests( - test_modules: dict[str, VbaUnitModDef], - table: CoverageTable) -> list: + test_modules: dict[str, ModuleDefinition], + table: SymbolTable) -> list: report = [] visitor = VbaUnitVisitor(table) for mod_name, module in test_modules.items(): diff --git a/tests/Unit/test_coveralls.py b/tests/Unit/test_coveralls.py index b5cb077..4f8e57b 100644 --- a/tests/Unit/test_coveralls.py +++ b/tests/Unit/test_coveralls.py @@ -37,10 +37,16 @@ def __init__(self: T) -> None: "modules": { "roots": { "name": "roots", - "cover": True, - "coverage": [1, None, None, None, None, None, None, - None, None, None, None, 1, 1, 1], - "path": 'src/Modules/Roots.bas' + "extra": { + "vba_unit": { + "cover": True, + "coverage": [1, None, None, None, + None, None, None, + None, None, None, None, + 1, 1, 1], + "path": 'src/Modules/Roots.bas' + } + } } } } diff --git a/tests/Unit/test_listener.py b/tests/Unit/test_listener.py index 980d93f..945e7f1 100644 --- a/tests/Unit/test_listener.py +++ b/tests/Unit/test_listener.py @@ -1,12 +1,12 @@ from antlr4 import FileStream, CommonTokenStream, ParseTreeWalker from antlr4_vba.vbaLexer import vbaLexer from antlr4_vba.vbaParser import vbaParser -from vba_unit.Interpreter.coverage_table import CoverageTable +from pyvba_interpreter.symbol_table import SymbolTable from vba_unit.Interpreter.vba_unit_listener import VbaUnitListener def test_listener() -> None: - table = CoverageTable() + table = SymbolTable() input_stream = FileStream( "tests/src/VbaProject/Module1.bas", encoding="cp1252" @@ -19,9 +19,10 @@ def test_listener() -> None: listener.parser = parser walker = ParseTreeWalker() walker.walk(listener, tree) + mod = table.definitions["vbaproject"]["modules"]["module1"] + extra = mod["extra"]["vba_unit"] + assert extra["cover"] - assert table.definitions["vbaproject"]["modules"]["module1"]["cover"] - - result = table.definitions["vbaproject"]["modules"]["module1"]["coverage"] + result = extra["coverage"] expected = [1, 0, None, 0, 0, None] assert result == expected diff --git a/tests/Unit/test_visitor.py b/tests/Unit/test_visitor.py index f0e3e8e..b635272 100644 --- a/tests/Unit/test_visitor.py +++ b/tests/Unit/test_visitor.py @@ -1,13 +1,13 @@ from antlr4 import FileStream, CommonTokenStream, ParseTreeWalker from antlr4_vba.vbaLexer import vbaLexer from antlr4_vba.vbaParser import vbaParser -from vba_unit.Interpreter.coverage_table import CoverageTable +from pyvba_interpreter.symbol_table import SymbolTable from vba_unit.Interpreter.vba_unit_listener import VbaUnitListener from vba_unit.Interpreter.vba_unit_visitor import VbaUnitVisitor -def test_listener() -> None: - table = CoverageTable() +def test_visitor() -> None: + table = SymbolTable() input_stream = FileStream( "tests/src/VbaProject/Module1.bas", encoding="cp1252" @@ -23,6 +23,6 @@ def test_listener() -> None: visitor = VbaUnitVisitor(table) mod = table.definitions["vbaproject"]["modules"]["module1"] func = mod["functions"]["foo"] - assert mod["coverage"][3] == 0 + assert mod["extra"]["vba_unit"]["coverage"][3] == 0 visitor.run_function(func, []) - assert mod["coverage"][3] == 1 + assert mod["extra"]["vba_unit"]["coverage"][3] == 1