In [1]:
import sys
from pathlib import Path
# Add the project root to Python path
sys.path.append(str(Path.cwd().parent.parent))

# Project-local imports; they resolve only because the project root was appended to sys.path above
from parser_engine.language_parsers.python_parser import PythonParser
from parser_engine.models.data_models import *
from parser_engine.core.repo_analyzer import RepoIndexer
from parser_engine.language_parsers.python_parser import PythonFunctionCallVisitor


# Root of the source tree to analyse. The default is the original local
# checkout; set the FAIRSCALE_REPO_PATH environment variable to point the
# notebook at a checkout on another machine without editing this cell.
import os

repo_path = Path(
    os.environ.get("FAIRSCALE_REPO_PATH")
    or '/Users/qudi/Desktop/workspace/third-party/fairscale/fairscale/'
)

In [2]:
parser = PythonParser()

# Walk the repository and keep one parsed module per Python file the parser
# accepts. A file that fails to parse is reported and skipped so that a single
# bad file does not abort the whole run.
modules = []
for file_path in repo_path.rglob('*.py'):
    if not parser.can_parse(file_path):
        continue
    try:
        module = parser.parse_file(file_path, repo_root=repo_path)
    except Exception as parse_error:
        print(f"Error parsing {file_path}: {parse_error}")
    else:
        modules.append(module)

# Build the repository-wide index from everything that parsed successfully
indexer = RepoIndexer(modules)
indexer.index_repository()

In [None]:
# Show, for every parsed module, its name followed by the import data the
# parser recorded for it (one indented line per attribute).
for module in modules:
    print(
        f"module: {module.name}",
        f"    imports_mapping: {module.imports_mapping}",
        f"    imports: {module.imports}",
        sep="\n",
    )

In [None]:


# Repository-wide totals
print(f"Found {len(modules)} Python modules")
class_total = sum(len(m.classes) for m in modules)
print(f"Total classes: {class_total}")
function_total = sum(len(m.functions) for m in modules)
print(f"Total functions: {function_total}")

# Per-module breakdown: classes first, then module-level functions together
# with the calls recorded for each of them
for module in modules:
    print(f"\nModule: {module.name}")

    print("Classes:")
    for class_elem in module.classes:
        print(
            f"  {class_elem.name}",
            f"    Methods: {len(class_elem.methods)}",
            f"    Base classes: {class_elem.base_classes}",
            sep="\n",
        )

    print("Functions:")
    for func in module.functions:
        print(f"  {func.name}")
        recorded_calls = func.function_calls
        if not recorded_calls:
            # Nothing recorded for this function: no "Calls:" header either
            continue
        print("    Calls:")
        for call in recorded_calls:
            print(f"      {call.name} -> {call.resolved_name}")


In [5]:
from collections import defaultdict

def build_call_graph(modules: List[ModuleElement], symbol_table: Dict[str, FunctionElement]) -> Dict[str, List[str]]:
    """Map every indexed caller to the indexed callables it invokes.

    Args:
        modules: Parsed modules; each exposes ``functions`` and ``classes``
            (whose items expose ``methods``).
        symbol_table: Known callables keyed by qualified name. Only callers
            and callees whose names are keys of this table are kept.

    Returns:
        A plain dict ``caller qualified name -> list of callee names``, in
        call order. A callee invoked several times appears several times, and
        a caller without any kept callee has no entry at all.
    """
    def iter_callables(module):
        # Module-level functions first, then the methods of each class
        yield from module.functions
        for cls in module.classes:
            yield from cls.methods

    edges = {}
    for module in modules:
        for callable_elem in iter_callables(module):
            caller = callable_elem.qualified_name
            if caller and caller in symbol_table:
                for call in callable_elem.function_calls:
                    target = call.resolved_name
                    # Drop unresolved calls and calls leaving the indexed code
                    if target and target in symbol_table:
                        edges.setdefault(caller, []).append(target)
    return edges

In [None]:
# Build the caller -> callees mapping for the whole repository, then list it
# one caller per line
call_graph = build_call_graph(modules=modules, symbol_table=indexer.symbol_table)
for caller in call_graph:
    callees = call_graph[caller]
    print(f"{caller}: {callees}")


In [7]:
def get_function_source_code(function_name: str, symbol_table: Dict[str, FunctionElement]) -> str:
    """Return the source text of one indexed function.

    Args:
        function_name: Key of the function in ``symbol_table``.
        symbol_table: Indexed callables; the selected entry must expose
            ``module.body`` (the full module text) plus ``start_line`` and
            ``end_line``, used here as a 1-based inclusive line range.

    Returns:
        The lines ``start_line`` through ``end_line`` of the owning module's
        body, joined with ``"\\n"``.

    Raises:
        KeyError: If ``function_name`` is not in ``symbol_table``.
    """
    element = symbol_table[function_name]
    module_lines = element.module.body.splitlines()
    # 1-based inclusive range -> 0-based half-open slice
    first, last = element.start_line - 1, element.end_line
    return '\n'.join(module_lines[first:last])


In [None]:
# Resolved callees of flatten_optim_state_dict. build_call_graph only records
# callers that have at least one resolved callee, so use .get() with an empty
# list instead of a subscript that would raise KeyError for such a function.
call_graph.get("data_parallel.fsdp_optim_utils.flatten_optim_state_dict", [])

In [None]:
# Fetch the source text of flatten_optim_state_dict from the index and show it
flatten_source = get_function_source_code(
    function_name="data_parallel.fsdp_optim_utils.flatten_optim_state_dict",
    symbol_table=indexer.symbol_table,
)
print(flatten_source)


In [22]:
from collections import defaultdict
from typing import Dict, List, Optional

def group_calls_by_line(
    function: FunctionElement,
    symbol_table: Dict[str, FunctionElement]
) -> Dict[int, List[FunctionCallElement]]:
    """
    Bucket a function's resolvable calls by the line they occur on.

    Only calls whose ``resolved_name`` is a key of ``symbol_table`` are kept.
    Each bucket is keyed by ``call.line_number - 1`` (the subtraction turns a
    1-based line number into a 0-based index; change it if the parser already
    reports 0-based numbers) and keeps the calls in their original order.
    """
    grouped = {}
    for call in function.function_calls:
        if call.resolved_name not in symbol_table:
            # Unresolved, or resolved to something outside the indexed code
            continue
        grouped.setdefault(call.line_number - 1, []).append(call)
    return grouped


def display_function_source_and_calls(
    func_fqn: str,
    symbol_table: Dict[str, FunctionElement],
    call_graph: Dict[str, List[str]],       # Not read here; only forwarded to recursive calls. Kept for existing callers.
    visited_stack: Optional[List[str]] = None,
    indent: int = 0
):
    """
    Print a function's source line by line, expanding callees inline.

    After each source line, every call that ``group_calls_by_line`` reports
    for that line is announced with ``-> calls <name>`` and the callee is
    printed recursively in the same way.

    Args:
        func_fqn: Qualified name of the function to print.
        symbol_table: Indexed callables keyed by qualified name.
        call_graph: Unused by this function itself; passed through unchanged.
        visited_stack: Qualified names currently being expanded, used to stop
            on recursion cycles. Callers normally leave this as ``None``.
        indent: Number of spaces printed before this function's heading; its
            source lines are printed two spaces deeper.

    Returns:
        None. All output goes to stdout.
    """
    if visited_stack is None:
        visited_stack = []

    # --- Cycle detection: the function is already being expanded further up
    if func_fqn in visited_stack:
        cycle_start_index = visited_stack.index(func_fqn)
        cycle_path = visited_stack[cycle_start_index:]
        print(" " * indent + f"(cycle) {func_fqn} ... {' -> '.join(cycle_path)} -> {func_fqn}")
        return

    # --- Retrieve the FunctionElement
    func_elem = symbol_table.get(func_fqn)
    if not func_elem:
        # Probably a built-in or unresolved function
        print(" " * indent + f"* {func_fqn} (unresolved or built-in)")
        return

    visited_stack.append(func_fqn)
    try:
        # --- Print a short heading (function name or FQN)
        short_name = func_elem.name or func_fqn
        print(" " * indent + f"{short_name}()")

        # --- Get the function's source
        function_source_code = get_function_source_code(func_fqn, symbol_table)
        if not function_source_code:
            print(" " * (indent + 2) + "<No source available>")
            return

        source_lines = function_source_code.splitlines()
        # NOTE(review): the keys are call.line_number - 1 while idx below counts
        # from the function's first line. These only line up if call.line_number
        # is relative to the function rather than to the file. TODO confirm
        # against the parser.
        calls_by_line = group_calls_by_line(func_elem, symbol_table)

        for idx, line_text in enumerate(source_lines):
            # Print the line of code
            print(" " * (indent + 2) + line_text)

            line_calls = calls_by_line.get(idx)
            if not line_calls:
                continue

            # Callee output is indented by the length of the preceding source
            # line. Computed once per line, before branching, so that the
            # unresolved branch below can no longer hit an unbound variable.
            previous_line_length = len(source_lines[idx - 1]) if idx > 0 else 0
            for call in line_calls:
                callee_fqn = call.resolved_name
                if callee_fqn:
                    # Print a short note & inline-expand the callee
                    print(" " * (previous_line_length) + f"-> calls {callee_fqn}")
                    display_function_source_and_calls(
                        callee_fqn,
                        symbol_table,
                        call_graph,
                        visited_stack,
                        previous_line_length
                    )
                else:
                    # Unresolved or missing from symbol_table
                    print(" " * (previous_line_length) + f"-> calls {call.name} (unresolved)")
    finally:
        # --- Always restore the recursion stack, even if printing fails midway
        visited_stack.pop()

In [None]:
# Print flatten_optim_state_dict with every resolvable callee expanded inline
display_function_source_and_calls(
    func_fqn="data_parallel.fsdp_optim_utils.flatten_optim_state_dict",
    symbol_table=indexer.symbol_table,
    call_graph=call_graph,
)

In [None]:
# Show the indexer's whole symbol table as the cell output. Elsewhere in this
# notebook it is used as a mapping from qualified name to function element.
indexer.symbol_table

In [None]:
import ast

# Sample Python code to analyze.
# It covers the import forms the harness below should map: a plain import,
# an aliased import (`import sys as system`), a from-import, and aliased
# from-imports. process_file then uses each imported name once.
sample_code = """
import os
import sys as system
from pathlib import Path
from datetime import datetime as dt, timedelta
from json import dumps as to_json

def process_file(filepath):
    if system.version_info >= (3, 8):
        current_time = dt.now()
        path = Path(filepath)
        if path.exists():
            stats = os.stat(filepath)
            result = {
                'filename': path.name,
                'modified': current_time - timedelta(days=1),
                'size': stats.st_size
            }
            return to_json(result)
    return None
"""

# Small harness around the parser's import handling
def test_import_parsing():
    """Print the import mapping the parser derives from ``sample_code``.

    Every ``import`` / ``from ... import`` node in the sample's AST is passed
    to ``PythonParser._parse_imports`` and the results are merged into one
    dict, which is then printed one ``local name -> full path`` pair per line.
    """
    tree = ast.parse(sample_code)
    import_parser = PythonParser()

    # Merge the mapping produced for each import statement, in AST walk order
    imports_mapping = {}
    import_nodes = (
        node for node in ast.walk(tree)
        if isinstance(node, (ast.Import, ast.ImportFrom))
    )
    for import_node in import_nodes:
        imports_mapping.update(import_parser._parse_imports(import_node))

    print("Import mappings:")
    for local_name, full_path in imports_mapping.items():
        print(f"{local_name:12} -> {full_path}")


test_import_parsing()