In [43]:
import attr
import ast
import logging


@attr.s
class PythonModuleSplitter:

    """
    A class designed to split Python modules into smaller parts based on class definitions,
    optionally including or excluding docstrings and class headers from the output.

    This class is useful for processing and analyzing Python code at a granular level,
    especially when working with large modules.

    Attributes
    ----------
    module_path : str, optional
        Path to the Python module file to be split. Default is None.
    include_docstrings : bool, optional
        Flag to include docstrings in the output. Default is True.
    include_class_header : bool, optional
        Flag to include the class header in the output. Default is True.
    module_content : str
        The content of the Python module file as a string. This attribute is initialized after the class is instantiated.
    module_content_no_docstring : str
        The content of the Python module file as a string, with docstrings removed. This attribute is initialized after the class is instantiated.
    logger : logging.Logger, optional
        Custom logger for logging information. Default is None, which causes the class to initialize a new logger.
    logger_name : str, optional
        Name of the logger to be initialized if no custom logger is provided. Default is 'Python Module Splitter'.
    loggerLvl : logging.LEVEL, optional
        Logging level for the logger. Default is logging.INFO.
    logger_format : str, optional
        Logging format for the logger. Default is None, which uses the basic logging format.

    """

    module_path = attr.ib(default=None, type = str)
    include_docstrings = attr.ib(default = True, type = bool)
    include_class_header = attr.ib(default = True, type = bool)

    module_content = attr.ib(init=False)
    module_content_no_docstring = attr.ib(init=False)

    logger = attr.ib(default=None)
    logger_name = attr.ib(default='Python Module Splitter')
    loggerLvl = attr.ib(default=logging.INFO)
    logger_format = attr.ib(default=None)

    def __attrs_post_init__(self):
        self._initialize_logger()
        if self.module_path:
            self.get_module_code()

    def _initialize_logger(self):

        """
        Initialize a logger for the class instance based on the specified logging level and logger name.
        """

        if self.logger is None:
            logging.basicConfig(level=self.loggerLvl, format=self.logger_format)
            logger = logging.getLogger(self.logger_name)
            logger.setLevel(self.loggerLvl)

            self.logger = logger

    def get_module_code(self, module_path : str = None):

        if module_path is None:
            module_path = self.module_path
            return_content = False
        else:
            return_content = True

        if module_path is None:
            raise ValueError("Module was not provided!")

        with open(module_path, 'r') as file:
            module_content = file.read()

        self._remove_docstrings(module_content = module_content)

        self.module_content = module_content

        if return_content:
            return module_content

    def _remove_docstrings(self, module_content : str = None):

        if module_content is None:
            module_content = self.module_content

        class DocstringRemover(ast.NodeTransformer):
            def visit_FunctionDef(self, node):
                """Strip docstring from a function definition."""
                self.generic_visit(node)
                if ast.get_docstring(node):
                    node.body = node.body[1:]
                return node

            def visit_ClassDef(self, node):
                """Strip docstring from a class definition."""
                self.generic_visit(node)
                if ast.get_docstring(node):
                    node.body = node.body[1:]
                return node

        parsed_tree = ast.parse(module_content)
        docstring_remover = DocstringRemover()
        modified_tree = docstring_remover.visit(parsed_tree)
        self.module_content_no_docstring =  ast.unparse(modified_tree)

    def split_text(self,
                   text : str = None,
                   include_docstrings: bool = None,
                   include_class_header : bool = None):

        if include_docstrings is None:
            include_docstrings = self.include_docstrings

        if include_class_header is None:
            include_class_header = self.include_class_header

        if text is None:

            if include_docstrings:
                module_content = self.module_content
            else:
                module_content = self.module_content_no_docstring
        else:
            module_content = text


        tree = ast.parse(module_content)
        class_definitions = [node for node in tree.body if isinstance(node, ast.ClassDef)]

        segments = []
        for class_def in class_definitions:

            if include_class_header:

                # Find the first method
                first_method_index = next((i for i, n in enumerate(class_def.body) if isinstance(n, ast.FunctionDef)), None)

                # If there's no method, continue to next class
                if first_method_index is None:
                    continue

                # Everything before the first method
                pre_method_body = class_def.body[:first_method_index]


            for method in [n for n in class_def.body if isinstance(n, ast.FunctionDef)]:

                body_code = []
                if include_class_header:
                    # Class body consists of pre-method part and the current method
                    body_code = pre_method_body
                body_code = body_code + [method]

                class_copy = ast.ClassDef(name=class_def.name,
                                        bases=class_def.bases,
                                        keywords=class_def.keywords,
                                        body=body_code,
                                        decorator_list=class_def.decorator_list)
                class_code = ast.unparse(class_copy)
                segments.append(class_code)

        return segments


In [44]:
# Example usage

ps = PythonModuleSplitter(module_path = '../python_modules/mocker_db.py',
                          include_docstrings = False)

splits = ps.split_text()

for split in splits:
    print(split)
    print("--------------------")

class SentenceTransformerEmbedder:

    def __init__(self, tbatch_size=32, processing_type='batch', max_workers=2, *args, **kwargs):
        logging.getLogger('sentence_transformers').setLevel(logging.ERROR)
        self.tbatch_size = tbatch_size
        self.processing_type = processing_type
        self.max_workers = max_workers
        self.model = SentenceTransformer(*args, **kwargs)
class SentenceTransformerEmbedder:

    def embed_sentence_transformer(self, text):
        return self.model.encode(text)
class SentenceTransformerEmbedder:

    def embed(self, text, processing_type: str=None):
        if processing_type is None:
            processing_type = self.processing_type
        if processing_type == 'batch':
            return self.embed_texts_in_batches(texts=text)
        if processing_type == 'parallel':
            return self.embed_sentences_in_batches_parallel(texts=text)
        return self.embed_sentence_transformer(text=str(text))
class SentenceTransformerEmbedder:

In [54]:
import ast
import logging

@attr.s
class PythonCodeAnalyzer(ast.NodeVisitor):

    """
    A class to analyze Python code files for various aspects such as defined functions,
    call chains within those functions, and class definitions. It leverages the AST
    (Abstract Syntax Tree) module to parse and visit nodes in the syntax tree of a Python file.

    Attributes
    ----------
    filename : str
        The path to the Python file to be analyzed.
    logger : logging.Logger, optional
        Custom logger for logging information. If not provided, a new logger will be initialized.
    logger_name : str, optional
        The name of the logger to use or initialize. Default is 'Python Code Analyzer'.
    loggerLvl : logging.Level, optional
        The logging level for the logger. Default is logging.INFO.
    logger_format : str, optional
        The logging format to use for the logger. Default is None, which uses the basic logging format.
    defined_functions : set
        A set of names of all functions and methods defined in the analyzed file. Populated after file parsing.
    call_chains : dict
        A dictionary mapping function names to lists of functions they call. Populated after file parsing.
    classes : dict
        A dictionary mapping class names to lists of their method names. Populated after file parsing.


    """

    filename = attr.ib()

    logger = attr.ib(default=None)
    logger_name = attr.ib(default='Python Code Analyzer')
    loggerLvl = attr.ib(default=logging.INFO)
    logger_format = attr.ib(default=None)


    def __attrs_post_init__(self):
        self._initialize_logger()

        self.defined_functions = set()  # Stores names of defined functions and methods
        self.call_chains = {}  # Stores call chains
        self.classes = {}

    def _initialize_logger(self):

        """
        Initialize a logger for the class instance based on the specified logging level and logger name.
        """

        if self.logger is None:
            logging.basicConfig(level=self.loggerLvl, format=self.logger_format)
            logger = logging.getLogger(self.logger_name)
            logger.setLevel(self.loggerLvl)

            self.logger = logger

    def visit_FunctionDef(self, node):

        """
        Visits FunctionDef nodes in the AST. It registers the function's name, adds it to the set of
        defined functions, and tracks the call chain within the function.

        Parameters
        ----------
        node : ast.FunctionDef
            The node representing a function definition in the AST.
        """

        function_name = node.name
        self.defined_functions.add(function_name)
        self.call_chains[function_name] = self._find_function_calls(node)
        self.generic_visit(node)

    def visit_ClassDef(self, node):

        """
        Visits ClassDef nodes in the AST. It registers the class's name and stores the names of its methods.

        Parameters
        ----------
        node : ast.ClassDef
            The node representing a class definition in the AST.
        """

        class_name = node.name
        self.classes[class_name] = []
        for item in node.body:
            if isinstance(item, ast.FunctionDef):
                self.classes[class_name].append(item.name)
        self.generic_visit(node)

    def _find_function_calls(self, node):

        """
        Identifies and returns function calls within a given node. It checks if the function called
        is defined within the file and adds it to the call chain.

        Parameters
        ----------
        node : ast.Node
            The node to search for function calls within.

        Returns
        -------
        list of str
            A list of names of functions called within the node.
        """

        calls = []
        for n in ast.walk(node):
            if isinstance(n, ast.Call) and isinstance(n.func, (ast.Attribute, ast.Name)):
                called_func_name = ''
                if isinstance(n.func, ast.Attribute):
                    if isinstance(n.func.value, ast.Name) and n.func.value.id == 'self':
                        called_func_name = n.func.attr
                elif isinstance(n.func, ast.Name):
                    called_func_name = n.func.id

                if called_func_name in self.defined_functions:
                    calls.append(called_func_name)
        return calls

    def parse_file(self):

        """
        Parses the Python file specified by `filename`, visiting nodes to collect information on
        function definitions, call chains, and class definitions.
        """

        with open(self.filename, 'r') as file:
            source_code = file.read()
        tree = ast.parse(source_code)
        self.visit(tree)

    def report(self):

        """
        Prints a report of the function call chains and class method definitions found in the file.
        """

        print("Function Call Chains:")
        for func, calls in self.call_chains.items():
            print(f"{func} calls: {', '.join(calls) if calls else 'None'}")


In [55]:
# Example usage
analyzer = PythonCodeAnalyzer("../python_modules/mocker_db.py")
analyzer.parse_file()
analyzer.report()

Function Call Chains:
__init__ calls: None
embed_sentence_transformer calls: None
embed calls: embed_sentence_transformer
embed_texts_in_batches calls: None
embed_sentences_in_batches_parallel calls: None
__attrs_post_init__ calls: _initialize_logger
_initialize_logger calls: None
hnsw_search calls: None
linear_search calls: None
search calls: linear_search, hnsw_search
establish_connection calls: None
save_data calls: None
hash_string_sha256 calls: None
_make_key calls: hash_string_sha256
_make_embs_key calls: hash_string_sha256
_insert_values_dict calls: save_data, _make_embs_key
insert_values calls: _make_key, _insert_values_dict
flush_database calls: save_data
filter_keys calls: None
filter_database calls: None
remove_from_database calls: save_data
search_database_keys calls: _make_embs_key
get_dict_results calls: None
search_database calls: search_database_keys, get_dict_results, filter_database


In [53]:
analyzer.classes

{'SentenceTransformerEmbedder': ['__init__',
  'embed_sentence_transformer',
  'embed',
  'embed_texts_in_batches',
  'embed_sentences_in_batches_parallel'],
 'MockerSimilaritySearch': ['__attrs_post_init__',
  '_initialize_logger',
  'hnsw_search',
  'linear_search',
  'search'],
 'MockerDB': ['__attrs_post_init__',
  '_initialize_logger',
  'establish_connection',
  'save_data',
  'hash_string_sha256',
  '_make_key',
  '_make_embs_key',
  '_insert_values_dict',
  'insert_values',
  'flush_database',
  'filter_keys',
  'filter_database',
  'remove_from_database',
  'search_database_keys',
  'get_dict_results',
  'search_database']}

In [51]:
analyzer.call_chains

{'__init__': [],
 'embed_sentence_transformer': [],
 'embed': ['embed_sentence_transformer'],
 'embed_texts_in_batches': [],
 'embed_sentences_in_batches_parallel': [],
 '__attrs_post_init__': ['_initialize_logger'],
 '_initialize_logger': [],
 'hnsw_search': [],
 'linear_search': [],
 'search': ['linear_search', 'hnsw_search'],
 'establish_connection': [],
 'save_data': [],
 'hash_string_sha256': [],
 '_make_key': ['hash_string_sha256'],
 '_make_embs_key': ['hash_string_sha256'],
 '_insert_values_dict': ['save_data', '_make_embs_key'],
 'insert_values': ['_make_key', '_insert_values_dict'],
 'flush_database': ['save_data'],
 'filter_keys': [],
 'filter_database': [],
 'remove_from_database': ['save_data'],
 'search_database_keys': ['_make_embs_key'],
 'get_dict_results': [],
 'search_database': ['search_database_keys',
  'get_dict_results',
  'filter_database']}