In [1]:
import os
import sys
from pathlib import Path

# Make the project root (two levels above the notebook's working directory)
# importable. The membership check keeps re-runs of this cell from piling up
# duplicate entries on sys.path.
project_root = str(Path.cwd().parent.parent)
if project_root not in sys.path:
    sys.path.append(project_root)

# Project imports; these resolve only once the project root is on sys.path.
from parser_engine.language_parsers.python_parser import PythonParser
from parser_engine.models.data_models import *
from parser_engine.core.repo_analyzer import RepoIndexer
from parser_engine.language_parsers.python_parser import PythonFunctionCallVisitor


# Directory whose *.py files are parsed and indexed further down.
# Set PARSER_ENGINE_REPO_PATH to point the notebook at a different checkout;
# when the variable is unset the original machine-specific location is used.
repo_path = Path(os.environ.get(
    'PARSER_ENGINE_REPO_PATH',
    '/Users/qudi/Desktop/workspace/third-party/fairscale/fairscale/nn/model_parallel/',
))

In [2]:
# Fixture source for the call visitor exercised below. It contains bare
# built-in calls (print/abs/max), attribute calls on an imported module
# (math.sin/math.floor), a call through an import alias (dd), a method call on
# a subscript result (d['key'].append), a nested call, and a self.method call
# inside a class.
test_code = """
import math
from typing import List
from collections import defaultdict as dd

def complex_function(x: int, y: List[int]) -> int:
    # Direct function call
    print("Testing function calls")
    
    # Built-in function call
    abs_val = abs(x)
    
    # Imported module function call
    sin_val = math.sin(x)
    
    # Call with multiple arguments
    max_val = max(x, *y)
    
    # Method call on imported alias
    d = dd(list)
    d['key'].append(1)
    
    # Nested function calls
    result = abs(math.floor(sin_val))
    
    return result

class TestClass:
    def method1(self):
        # Method calling another method
        self.method2()
        
    def method2(self):
        print("Called method2")
"""

# Turn the fixture source into an AST
import ast
tree = ast.parse(test_code)

# Run the call visitor over the tree; the dict passed in maps each local
# import name used by the fixture to the dotted path it stands for
visitor = PythonFunctionCallVisitor({'math': 'math', 'dd': 'collections.defaultdict'})
visitor.visit(tree)

# Report every collected call as one four-line record
print("Found function calls:")
for call in visitor.calls:
    print(
        f"\nFunction: {call.name}\n"
        f"Qualified name: {call.resolved_name}\n"
        f"Module name: {call.module_name}\n"
        f"Line number: {call.line_number}"
    )


function_name: print
module_name: builtins
function_name: abs
module_name: builtins
function_name: math.sin
module_name: math.
function_name: max
module_name: builtins
function_name: dd
module_name: collections.defaultdict
function_name: append
module_name: None
function_name: abs
module_name: builtins
function_name: math.floor
module_name: math.
function_name: self.method2
module_name: None
function_name: print
module_name: builtins
Found function calls:

Function: print
Qualified name: None
Module name: builtins
Line number: 8

Function: abs
Qualified name: None
Module name: builtins
Line number: 11

Function: math.sin
Qualified name: None
Module name: math.
Line number: 14

Function: max
Qualified name: None
Module name: builtins
Line number: 17

Function: dd
Qualified name: None
Module name: collections.defaultdict
Line number: 20

Function: append
Qualified name: None
Module name: None
Line number: 21

Function: abs
Qualified name: None
Module name: builtins
Line number: 24

Funct

In [3]:
parser = PythonParser()

# Collect one parsed module per *.py file under repo_path that the parser accepts
modules = []
for file_path in repo_path.rglob('*.py'):
    if not parser.can_parse(file_path):
        continue
    try:
        module = parser.parse_file(file_path)
    except Exception as e:
        # Best effort: name the file that failed and carry on with the rest
        print(f"Error parsing {file_path}: {e}")
    else:
        modules.append(module)

# Build the index over everything that parsed
indexer = RepoIndexer(modules)
indexer.index_repository()

extracting function calls in model_parallel.cross_entropy.vocab_parallel_cross_entropy(vocab_parallel_logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor
function_name: _VocabParallelCrossEntropy.apply
module_name: None
    extracted function calls internally: [FunctionCallElement(name='_VocabParallelCrossEntropy.apply', module_name=None, line_number=3, resolved_name=None)]
module.imports_mapping: {'torch': 'model_parallel.torch', 'get_model_parallel_group': 'model_parallel.initialize.get_model_parallel_group', 'get_model_parallel_rank': 'model_parallel.initialize.get_model_parallel_rank', 'get_model_parallel_world_size': 'model_parallel.initialize.get_model_parallel_world_size', 'VocabUtility': 'model_parallel.utils.VocabUtility'}
        call: _VocabParallelCrossEntropy.apply, resolved call: None
extracting function calls in model_parallel.cross_entropy._VocabParallelCrossEntropy.forward(ctx: Any, vocab_parallel_logits: Any, target: Any) -> Any
    extracted function calls in

In [4]:
# Display the indexer's symbol table (rendered as the cell's output).
# NOTE(review): this prints the whole table; consider showing only a few keys.
indexer.symbol_table

{'model_parallel.cross_entropy._VocabParallelCrossEntropy': ClassElement(name='model_parallel.cross_entropy._VocabParallelCrossEntropy', path=PosixPath('/Users/qudi/Desktop/workspace/third-party/fairscale/fairscale/nn/model_parallel/cross_entropy.py'), start_line=29, end_line=101, module=ModuleElement(name='model_parallel.cross_entropy', path=PosixPath('/Users/qudi/Desktop/workspace/third-party/fairscale/fairscale/nn/model_parallel/cross_entropy.py'), language='Python', classes=[...], functions=[FunctionElement(name='model_parallel.cross_entropy.vocab_parallel_cross_entropy(vocab_parallel_logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor', path=PosixPath('/Users/qudi/Desktop/workspace/third-party/fairscale/fairscale/nn/model_parallel/cross_entropy.py'), start_line=104, end_line=106, module=..., documentation=DocumentationElement(content='Helper function for the cross entropy.', path='104', line_number=104, type='docstring', context=None), parameters=['vocab_parallel_logits: t

In [5]:
# For each parsed module, show its name followed by its two import views
for module in modules:
    print(
        f"module: {module.name}\n"
        f"    imports_mapping: {module.imports_mapping}\n"
        f"    imports: {module.imports}"
    )

module: model_parallel.cross_entropy
    imports_mapping: {'torch': 'model_parallel.torch', 'get_model_parallel_group': 'model_parallel.initialize.get_model_parallel_group', 'get_model_parallel_rank': 'model_parallel.initialize.get_model_parallel_rank', 'get_model_parallel_world_size': 'model_parallel.initialize.get_model_parallel_world_size', 'VocabUtility': 'model_parallel.utils.VocabUtility'}
    imports: ['torch', 'get_model_parallel_group', 'get_model_parallel_rank', 'get_model_parallel_world_size', 'VocabUtility']
module: model_parallel.initialize
    imports_mapping: {'List': 'model_parallel.typing.List', 'Optional': 'model_parallel.typing.Optional', 'torch': 'model_parallel.torch', 'timedelta': 'model_parallel.datetime.timedelta', 'ensure_divisibility': 'model_parallel.utils.ensure_divisibility'}
    imports: ['List', 'Optional', 'torch', 'timedelta', 'ensure_divisibility']
module: model_parallel.__init__
    imports_mapping: {'List': 'model_parallel.typing.List', 'vocab_parall

In [14]:


# Headline counts across all parsed modules
print(f"Found {len(modules)} Python modules")
print(f"Total classes: {sum(len(m.classes) for m in modules)}")
print(f"Total functions: {sum(len(m.functions) for m in modules)}")

# Per-module breakdown: classes first, then functions with their calls
for module in modules:
    print(f"\nModule: {module.name}")

    print("Classes:")
    for class_elem in module.classes:
        print(
            f"  {class_elem.name}\n"
            f"    Methods: {len(class_elem.methods)}\n"
            f"    Base classes: {class_elem.base_classes}"
        )

    print("Functions:")
    for func in module.functions:
        print(f"  {func.name}")
        if not func.function_calls:
            # Nothing recorded for this function, so no "Calls:" header either
            continue
        print("    Calls:")
        for call in func.function_calls:
            print(f"      {call.name} -> {call.resolved_name}")


Found 7 Python modules
Total classes: 12
Total functions: 32

Module: model_parallel.cross_entropy
Classes:
  model_parallel.cross_entropy._VocabParallelCrossEntropy
    Methods: 2
    Base classes: ['torch.autograd.Function']
Functions:
  model_parallel.cross_entropy.vocab_parallel_cross_entropy(vocab_parallel_logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor
    Calls:
      apply -> None

Module: model_parallel.initialize
Classes:
Functions:
  model_parallel.initialize.initialize_model_parallel(model_parallel_size_: int, pipeline_length: int, context_parallel_size: int) -> None
    Calls:
      is_initialized -> None
      get_world_size -> None
      int -> int
      min -> min
      ensure_divisibility -> None
      ensure_divisibility -> None
      ensure_divisibility -> None
      get_rank -> None
      int -> int
      get_rank -> None
      print -> print
      format -> format
      print -> print
      format -> format
      print -> print
      format -> format
  

In [15]:
# Display the indexer's symbol table again.
# NOTE(review): the output saved under this cell has keys with the module
# prefix repeated ('model_parallel.cross_entropy.model_parallel.cross_entropy...'),
# unlike the output saved under In [4]. The two were presumably produced by
# different indexer states; re-run the notebook top to bottom to confirm.
indexer.symbol_table

{'model_parallel.cross_entropy.model_parallel.cross_entropy._VocabParallelCrossEntropy': ClassElement(name='model_parallel.cross_entropy._VocabParallelCrossEntropy', path=PosixPath('/Users/qudi/Desktop/workspace/third-party/fairscale/fairscale/nn/model_parallel/cross_entropy.py'), start_line=29, end_line=101, module=ModuleElement(name='model_parallel.cross_entropy', path=PosixPath('/Users/qudi/Desktop/workspace/third-party/fairscale/fairscale/nn/model_parallel/cross_entropy.py'), language='Python', classes=[...], functions=[FunctionElement(name='model_parallel.cross_entropy.vocab_parallel_cross_entropy(vocab_parallel_logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor', path=PosixPath('/Users/qudi/Desktop/workspace/third-party/fairscale/fairscale/nn/model_parallel/cross_entropy.py'), start_line=104, end_line=106, module=..., documentation=DocumentationElement(content='Helper function for the cross entropy.', path='104', line_number=104, type='docstring', context=None), paramete

In [6]:
import ast

# Fixture for the import-mapping check below. Its import lines cover a plain
# import, an aliased module import, a from-import, a from-import mixing an
# aliased and an unaliased name, and an aliased from-import; the function
# body then uses each of the imported names.
sample_code = """
import os
import sys as system
from pathlib import Path
from datetime import datetime as dt, timedelta
from json import dumps as to_json

def process_file(filepath):
    if system.version_info >= (3, 8):
        current_time = dt.now()
        path = Path(filepath)
        if path.exists():
            stats = os.stat(filepath)
            result = {
                'filename': path.name,
                'modified': current_time - timedelta(days=1),
                'size': stats.st_size
            }
            return to_json(result)
    return None
"""

# Small harness around PythonParser._parse_imports
def test_import_parsing():
    """Print the local-name -> dotted-path mapping for every import in ``sample_code``.

    Parses the module-level ``sample_code`` string, hands each ``ast.Import`` /
    ``ast.ImportFrom`` node to ``PythonParser._parse_imports`` and merges the
    returned dicts (a later entry overwrites an earlier one with the same key).
    """
    tree = ast.parse(sample_code)
    parser = PythonParser()

    # Lazily pick out just the import statements while walking the tree
    import_nodes = (
        candidate
        for candidate in ast.walk(tree)
        if isinstance(candidate, (ast.Import, ast.ImportFrom))
    )

    imports_mapping = {}
    for import_node in import_nodes:
        imports_mapping.update(parser._parse_imports(import_node))

    # Local names are padded to 12 columns so the arrows line up
    print("Import mappings:")
    for local_name, full_path in imports_mapping.items():
        print(f"{local_name:12} -> {full_path}")

test_import_parsing()

Import mappings:
os           -> os
system       -> sys
Path         -> pathlib.Path
dt           -> datetime.datetime
timedelta    -> datetime.timedelta
to_json      -> json.dumps


In [10]:
import ast
from typing import Union, Dict
from pathlib import Path

def _parse_imports(node: Union[ast.Import, ast.ImportFrom], parent_module: str = '') -> Dict[str, str]:
    """Parse import statements and build a mapping."""
    imports_mapping = dict()
    top_level = parent_module.split('.')[0] if parent_module else ''
    
    print(f"\nProcessing import with parent_module: {parent_module}")
    print(f"Top level module: {top_level}")
    print(f"Node type: {type(node).__name__}")

    if isinstance(node, ast.Import):
        for alias in node.names:
            name = alias.name
            asname = alias.asname if alias.asname else alias.name.split('.')[0]
            if parent_module and not name.startswith('.'):
                if not name.startswith(top_level):
                    name = f"{top_level}.{name}"
            imports_mapping[asname] = name
    elif isinstance(node, ast.ImportFrom):
        module = node.module or ''
        print(f"Import from module: {module}")
        print(f"Level: {node.level}")
        
        # Handle relative imports
        if node.level > 0:  # This is a relative import
            if parent_module:
                parts = parent_module.split('.')
                # For level=1 (current directory), we want all parts
                # For level=2 (parent directory), we want all parts except last one
                base_path = parts[:-node.level] if node.level > 1 else parts
                module = '.'.join(base_path + ([module] if module else []))
                print(f"Parts: {parts}")
                print(f"Base path: {base_path}")
                print(f"Resolved relative import to: {module}")
        else:
            # For absolute imports
            if module == top_level:
                # Case: from model_parallel import utils
                for alias in node.names:
                    name = alias.name
                    asname = alias.asname if alias.asname else name
                    full_name = f"{module}.{name}"
                    imports_mapping[asname] = full_name
                return imports_mapping
            elif module.startswith(top_level + '.'):
                # Case: from model_parallel.utils import split_tensor
                pass  # Keep the module name as is
            elif parent_module and not module.startswith(top_level):
                # Other imports - prepend top-level module if needed
                module = f"{top_level}.{module}"
        
        for alias in node.names:
            name = alias.name
            asname = alias.asname if alias.asname else name
            full_name = f"{module}.{name}" if module else name
            imports_mapping[asname] = full_name
    
    return imports_mapping

def test_import_parsing(code: str, parent_module: str = ''):
    """Parse ``code`` and print what ``_parse_imports`` yields for each import node.

    A separator line and the snippet are printed first. Then, for every
    ``ast.Import`` / ``ast.ImportFrom`` node that ``ast.walk`` finds, the
    returned entries are listed as ``local -> dotted``. ``parent_module`` is
    passed through to ``_parse_imports`` unchanged.
    """
    separator = "=" * 50
    print("\n" + separator)
    print(f"Testing code:\n{code}")

    import_nodes = (
        candidate
        for candidate in ast.walk(ast.parse(code))
        if isinstance(candidate, (ast.Import, ast.ImportFrom))
    )
    for import_node in import_nodes:
        mapping = _parse_imports(import_node, parent_module)
        print("\nResult mapping:")
        for local_name, dotted_path in mapping.items():
            print(f"  {local_name} -> {dotted_path}")

def main():
    """Run ``test_import_parsing`` over a fixed list of (snippet, parent module) pairs."""
    plain_imports = [
        ("import utils", "model_parallel"),
        ("import model_parallel.utils", "model_parallel"),
    ]
    from_imports = [
        ("from model_parallel import utils", "model_parallel"),
        ("from model_parallel.utils import split_tensor", "model_parallel"),
    ]
    relative_imports = [
        ("from . import utils", "model_parallel.core"),
        ("from .utils import split_tensor", "model_parallel.core"),
        ("from ..utils import split_tensor", "model_parallel.core.submodule"),
    ]
    # The snippet's continuation lines are part of the string, so their
    # indentation is kept exactly as written.
    multi_name_imports = [
        ("""from model_parallel.utils import (
            split_tensor,
            combine_tensor
        )""", "model_parallel"),
    ]
    aliased_imports = [
        ("from model_parallel.utils import split_tensor as split", "model_parallel"),
    ]

    # Groups are concatenated in the order the scenarios should be reported
    scenarios = (
        plain_imports
        + from_imports
        + relative_imports
        + multi_name_imports
        + aliased_imports
    )
    for snippet, package in scenarios:
        test_import_parsing(snippet, package)

main()


Testing code:
import utils

Processing import with parent_module: model_parallel
Top level module: model_parallel
Node type: Import

Result mapping:
  utils -> model_parallel.utils

Testing code:
import model_parallel.utils

Processing import with parent_module: model_parallel
Top level module: model_parallel
Node type: Import

Result mapping:
  model_parallel -> model_parallel.utils

Testing code:
from model_parallel import utils

Processing import with parent_module: model_parallel
Top level module: model_parallel
Node type: ImportFrom
Import from module: model_parallel
Level: 0

Result mapping:
  utils -> model_parallel.utils

Testing code:
from model_parallel.utils import split_tensor

Processing import with parent_module: model_parallel
Top level module: model_parallel
Node type: ImportFrom
Import from module: model_parallel.utils
Level: 0

Result mapping:
  split_tensor -> model_parallel.utils.split_tensor

Testing code:
from . import utils

Processing import with parent_module: