In [10]:
import torch
import inspect
from collections import deque, defaultdict

# Step 1: Identify C-level functions
def is_c_level_function(func):
    return not hasattr(func, '__code__')

# Step 2: Build a call graph
call_graph = defaultdict(set)
functions = [getattr(torch, attr) for attr in dir(torch) if callable(getattr(torch, attr))]

print(functions)


for func in functions:
    if not is_c_level_function(func):
        source_code = inspect.getsource(func)
        called_functions = [getattr(torch, line.strip()) for line in source_code.split() if callable(getattr(torch, line.strip(), None))]
        call_graph[func].update(called_functions)
    else:
        call_graph[func] = set()

# number of functions in the call graph
print(len(call_graph))

# Step 3: Calculate distances using BFS
def calculate_distances(call_graph):
    distances = {}
    queue = deque()

    # Initialize distances
    for func in call_graph:
        if is_c_level_function(func):
            distances[func] = 1
            queue.append((func, 1))
        else:
            distances[func] = float('inf')
    while queue:
        current_func, current_distance = queue.popleft()
        for called_func in call_graph[current_func]:
            new_distance = current_distance + 1
            if called_func not in distances or new_distance < distances[called_func]:
                distances[called_func] = new_distance
                queue.append((called_func, new_distance))

    return distances

distances = calculate_distances(call_graph)

# Print the distances
for func, dist in distances.items():
    print(f"{func.__name__}: {dist}")


[<class 'torch.AggregationType'>, <class 'torch.AliasDb'>, typing.Any, <class 'torch.AnyType'>, <class 'torch.Argument'>, <class 'torch.ArgumentSpec'>, <class 'torch.AwaitType'>, <class 'torch.BFloat16Storage'>, <class 'torch.BFloat16Tensor'>, <class 'torch.BenchmarkConfig'>, <class 'torch.BenchmarkExecutionStats'>, <class 'torch.Block'>, <class 'torch.BoolStorage'>, <class 'torch.BoolTensor'>, <class 'torch.BoolType'>, <class 'torch.BufferDict'>, <class 'torch.ByteStorage'>, <class 'torch.ByteTensor'>, <class 'torch.CallStack'>, typing.Callable, <class 'torch.Capsule'>, <class 'torch.CharStorage'>, <class 'torch.CharTensor'>, <class 'torch.ClassType'>, <class 'torch.Code'>, <class 'torch.jit.CompilationUnit'>, <class 'torch.CompleteArgumentSpec'>, <class 'torch.ComplexDoubleStorage'>, <class 'torch.ComplexFloatStorage'>, <class 'torch.ComplexType'>, <class 'torch.ConcreteModuleType'>, <class 'torch.ConcreteModuleTypeBuilder'>, <class 'torch.DeepCopyMemoTable'>, <class 'torch.Deseriali

In [18]:
import torch
import inspect
import re
from collections import deque, defaultdict

# Helper function to check if a function is a C-level function
def is_c_level_function(func):
    return not hasattr(func, '__code__')

# Helper function to get all functions in a module (including submodules)
def get_functions(module, visited=None):
    if visited is None:
        visited = set()
    elif module in visited:
        return []
    visited.add(module)

    funcs = []
    for name in dir(module):
        obj = getattr(module, name)
        if inspect.isfunction(obj) or inspect.ismethod(obj):
            funcs.append(obj)
        elif inspect.ismodule(obj):
            funcs.extend(get_functions(obj, visited))
    return funcs

# Build the call graph
def build_call_graph(funcs):
    call_graph = defaultdict(list)
    for func in funcs:
        # skip C-level functions
        if is_c_level_function(func):
            continue
        try:
            source = inspect.getsource(func)
            for line in source.splitlines():
                # Ignore comments
                line = re.sub(r'#.*', '', line)
                for other_func in funcs:
                    if other_func.__name__ in line and other_func != func:
                        # skip redundant calls
                        if other_func not in call_graph[func]:
                            call_graph[func].append(other_func)
        except (TypeError, OSError):
            continue
    return call_graph

# Compute the distance to the nearest C-level function using BFS
def compute_distances(call_graph, funcs):
    distances = {func: float('inf') for func in funcs}
    visited = set()
    queue = deque()

    # Initialize the queue with C-level functions
    for func in funcs:
        if is_c_level_function(func):
            distances[func] = 0
            queue.append(func)
            visited.add(func)

    # Perform BFS
    while queue:
        current_func = queue.popleft()
        current_distance = distances[current_func]
        for neighbor in call_graph[current_func]:
            if neighbor not in visited and distances[neighbor] > current_distance + 1:
                distances[neighbor] = current_distance + 1
                queue.append(neighbor)
                visited.add(neighbor)

    return distances

# Get all functions in the torch library
torch_functions = get_functions(torch)

print(len(torch_functions))

# Build the call graph
call_graph = build_call_graph(torch_functions)

print("call_graph is built")
# Compute distances
distances = compute_distances(call_graph, torch_functions)

# Print distances
for func, distance in distances.items():
    print(f"Function: {func.__name__}, Distance to C-level function: {distance}")


9083


KeyboardInterrupt: 

In [1]:
import torch
import inspect

def get_functions(module, prefix='torch'):
    functions = {}
    for name, obj in inspect.getmembers(module):
        full_name = f"{prefix}.{name}"
        if inspect.isfunction(obj):
            functions[full_name] = obj
        elif inspect.ismodule(obj) and obj.__name__.startswith('torch'):
            # Recursively get functions in submodules
            functions.update(get_functions(obj, full_name))
    return functions

# Get all functions in the torch library and its submodules
torch_functions = get_functions(torch)

# Print all functions
for func_name in sorted(torch_functions.keys()):
    print(func_name)


RecursionError: maximum recursion depth exceeded while calling a Python object