In [54]:
import inspect
import importlib.util
import os
import ast
import re

def camel_to_snake(name):
    # Convert camel case to snake case
    name = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
    return re.sub('([a-z0-9])([A-Z])', r'\1_\2', name).lower()

def get_classes_and_create_files(file_path):
    # Load the module from the file
    spec = importlib.util.spec_from_file_location("module_name", file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    
    # Get all classes defined in the module
    classes = inspect.getmembers(module, inspect.isclass)
    
    # Filter out classes that are imported (not defined in this file)
    classes = [cls for name, cls in classes if cls.__module__ == "module_name"]
    
    # Get the directory of the original file
    directory = os.path.dirname(file_path)
    
    # Read the content of the original file
    with open(file_path, 'r') as file:
        content = file.read()
    
    # Parse the content into an AST
    tree = ast.parse(content)
    
    # Extract all import statements
    import_statements = []
    for node in ast.iter_child_nodes(tree):
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            import_statements.append(ast.get_source_segment(content, node))
    
    # Join import statements
    imports = '\n'.join(import_statements)
    
    # Find and write each class to a separate file
    for node in ast.walk(tree):
        if isinstance(node, ast.ClassDef):
            class_name = node.name
            snake_case_name = camel_to_snake(class_name)
            class_content = ast.get_source_segment(content, node)
            
            # Create a new file for the class with snake case name
            new_file_path = os.path.join(directory, f"{snake_case_name}.py")
            with open(new_file_path, 'w') as new_file:
                # Write imports first, then an empty line, then the class
                new_file.write(f"{imports}\n\n{class_content}")
            
            print(f"Created file: {new_file_path}")
    
    return classes

# Example usage
file_path = "/Users/alexander/Documents/GitHub/prefect/src/prefect/client/schemas/filters/__init__.py"
classes = get_classes_and_create_files(file_path)

for cls in classes:
    print(f"Found class: {cls.__name__}")

Created file: /Users/alexander/Documents/GitHub/prefect/src/prefect/client/schemas/filters/operator.py
Created file: /Users/alexander/Documents/GitHub/prefect/src/prefect/client/schemas/filters/operator_mixin.py
Created file: /Users/alexander/Documents/GitHub/prefect/src/prefect/client/schemas/filters/flow_filter_id.py
Created file: /Users/alexander/Documents/GitHub/prefect/src/prefect/client/schemas/filters/flow_filter_name.py
Created file: /Users/alexander/Documents/GitHub/prefect/src/prefect/client/schemas/filters/flow_filter_tags.py
Created file: /Users/alexander/Documents/GitHub/prefect/src/prefect/client/schemas/filters/flow_filter.py
Created file: /Users/alexander/Documents/GitHub/prefect/src/prefect/client/schemas/filters/flow_run_filter_id.py
Created file: /Users/alexander/Documents/GitHub/prefect/src/prefect/client/schemas/filters/flow_run_filter_name.py
Created file: /Users/alexander/Documents/GitHub/prefect/src/prefect/client/schemas/filters/flow_run_filter_tags.py
Created 

In [53]:
prefix = 'from typing import TYPE_CHECKING\n'
prefix += 'if TYPE_CHECKING:\t'
for import_ in [f'from .{camel_to_snake(x.__name__)} import {x.__name__}' for x in classes]:
    prefix+='\n\t' + import_
prefix += '\n\n_public_api={\n'
for (key, value) in {x.__name__: f'{camel_to_snake(x.__name__)}' for x in classes}.items():
    prefix+=f'\t"{key}": (__spec__.parent, ".{value}")\n'
prefix += '}\n'
prefix += '__all__=['
for (key, value) in {x.__name__: f'{camel_to_snake(x.__name__)}' for x in classes}.items():
    prefix += f'"{key}",\n'
prefix+=']'
# {x.__name__: f'(__spec__.main, .{camel_to_snake(x.__name__)})' for x in classes}
print(prefix)

from typing import TYPE_CHECKING
if TYPE_CHECKING:	
	from .artifact_create import ArtifactCreate
	from .artifact_update import ArtifactUpdate
	from .block_document_create import BlockDocumentCreate
	from .block_document_reference_create import BlockDocumentReferenceCreate
	from .block_document_update import BlockDocumentUpdate
	from .block_schema_create import BlockSchemaCreate
	from .block_type_create import BlockTypeCreate
	from .block_type_update import BlockTypeUpdate
	from .concurrency_limit_create import ConcurrencyLimitCreate
	from .concurrency_limit_v2_create import ConcurrencyLimitV2Create
	from .concurrency_limit_v2_update import ConcurrencyLimitV2Update
	from .deployment_create import DeploymentCreate
	from .deployment_flow_run_create import DeploymentFlowRunCreate
	from .deployment_schedule_create import DeploymentScheduleCreate
	from .deployment_schedule_update import DeploymentScheduleUpdate
	from .deployment_update import DeploymentUpdate
	from .flow_create import FlowCr

In [56]:
print('\n'.join([f'from .{camel_to_snake(x.__name__)} import {x.__name__}' for x in classes]))

from .artifact_collection_filter import ArtifactCollectionFilter
from .artifact_collection_filter_flow_run_id import ArtifactCollectionFilterFlowRunId
from .artifact_collection_filter_key import ArtifactCollectionFilterKey
from .artifact_collection_filter_latest_id import ArtifactCollectionFilterLatestId
from .artifact_collection_filter_task_run_id import ArtifactCollectionFilterTaskRunId
from .artifact_collection_filter_type import ArtifactCollectionFilterType
from .artifact_filter import ArtifactFilter
from .artifact_filter_flow_run_id import ArtifactFilterFlowRunId
from .artifact_filter_id import ArtifactFilterId
from .artifact_filter_key import ArtifactFilterKey
from .artifact_filter_task_run_id import ArtifactFilterTaskRunId
from .artifact_filter_type import ArtifactFilterType
from .block_document_filter import BlockDocumentFilter
from .block_document_filter_block_type_id import BlockDocumentFilterBlockTypeId
from .block_document_filter_id import BlockDocumentFilterId
from .block_

In [71]:
import os
import re

def remove_specific_imports(directory):
    # Ensure the directory path ends with a slash
    directory = os.path.join(directory, "")

    # Compile a regex pattern to match the import statements we want to remove
    # This pattern matches "from .snake_case import CamelCase"
    pattern = re.compile(r'^from\s+\.(\w+)\s+import\s+(\w+)$')

    # Loop through all .py files in the directory
    for filename in os.listdir(directory):
        if filename.endswith('.py'):
            file_path = os.path.join(directory, filename)
            
            # Read the content of the file
            with open(file_path, 'r') as file:
                lines = file.readlines()

            # Flag to check if any changes were made
            changes_made = False

            # Process each line
            new_lines = []
            for line in lines:
                match = pattern.match(line.strip())
                if match:
                    snake_case = match.group(1)
                    camel_case = match.group(2)
                    if snake_case == '_'.join(re.findall('[A-Z][^A-Z]*', camel_case)).lower():
                        # This is an import we want to remove
                        changes_made = True
                        print(f"Removing import: {line.strip()}")
                    else:
                        new_lines.append(line)
                else:
                    new_lines.append(line)

            # If changes were made, write the new content back to the file
            if changes_made:
                with open(file_path, 'w') as file:
                    file.writelines(new_lines)
                print(f"Updated {filename}")
            else:
                print(f"No changes needed in {filename}")
remove_specific_imports('/Users/alexander/Documents/GitHub/prefect/src/prefect/client/schemas/filters/')

Removing import: from .operator import Operator
Updated operator_mixin.py
Removing import: from .deployment_filter_tags import DeploymentFilterTags
Removing import: from .operator_mixin import OperatorMixin
Updated deployment_filter_tags.py
Removing import: from .operator_mixin import OperatorMixin
Removing import: from .worker_filter_last_heartbeat_time import WorkerFilterLastHeartbeatTime
Updated worker_filter.py
Removing import: from .artifact_filter_key import ArtifactFilterKey
Updated artifact_filter_key.py
No changes needed in variable_filter_id.py
No changes needed in log_filter_timestamp.py
Removing import: from .block_document_filter_block_type_id import BlockDocumentFilterBlockTypeId
Updated block_document_filter_block_type_id.py
Removing import: from .block_type_filter_slug import BlockTypeFilterSlug
Updated block_type_filter_slug.py
No changes needed in log_filter_level.py
Removing import: from .block_type_filter_name import BlockTypeFilterName
Updated block_type_filter_nam

In [134]:
import jinja2
import inspect
file_path = '/Users/alexander/Documents/GitHub/prefect/src/prefect/client/schemas/filters/__init__.py'

def get_classes(file_path: str):
    spec = importlib.util.spec_from_file_location("module_name", file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    
    # Get all classes defined in the module
    classes = inspect.getmembers(module, inspect.isclass)
    
    # Filter out classes that are imported (not defined in this file)
    classes = [cls.__name__ for name, cls in classes if cls.__module__ == "module_name"]
    return classes

def get_classes_and_create_files(file_path):
    # Load the module from the file
    spec = importlib.util.spec_from_file_location("module_name", file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    
    # Get all classes defined in the module
    classes = inspect.getmembers(module, inspect.isclass)
    
    # Filter out classes that are imported (not defined in this file)
    classes = [cls for name, cls in classes if cls.__module__ == "module_name"]
    
    # Get the directory of the original file
    directory = os.path.dirname(file_path)
    
    # Read the content of the original file
    with open(file_path, 'r') as file:
        content = file.read()
    
    # Parse the content into an AST
    tree = ast.parse(content)
    
    # Extract all import statements
    import_statements = []
    for node in ast.iter_child_nodes(tree):
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            import_statements.append(ast.get_source_segment(content, node))
    
    # Join import statements
    imports = '\n'.join(import_statements)
    
    # Find and write each class to a separate file
    for node in ast.walk(tree):
        if isinstance(node, ast.ClassDef):
            class_name = node.name
            snake_case_name = camel_to_snake(class_name)
            class_content = ast.get_source_segment(content, node)
            
            # Create a new file for the class with snake case name
            new_file_path = os.path.join(directory, f"{snake_case_name}.py")
            with open(new_file_path, 'w') as new_file:
                # Write imports first, then an empty line, then the class
                for class_ in classes:
                    if snake_case_name != camel_to_snake(class_.__name__):
                        new_file.write(f"from .{camel_to_snake(class_.__name__)} import {class_.__name__}\n")
                new_file.write(f"{imports}\n\n{class_content}")
            
            print(f"Created file: {new_file_path}")

def write_init(file_path: str): 
    classes = get_classes(file_path)
    with open(file_path, 'w') as new_file:
        new_file.write(jinja2.Template(inspect.cleandoc(
    '''
    from typing import TYPE_CHECKING
    
    if TYPE_CHECKING:
        {% for class in classes -%}
        from .{{camel_to_snake(class)}} import {{class}}
        {% endfor %}
    
    _public_api: dict[str, tuple[str, str]] = {
        {% for class in classes %}"{{class}}": (__spec__.parent, ".{{camel_to_snake(class)}}"),
        {% endfor %}
        }

    __all__ = [
        {% for class in classes -%}
        "{{class}}",
        {% endfor %}  
    ]
    
    def __getattr__(attr_name: str) -> object:
        dynamic_attr = _public_api.get(attr_name)
        if dynamic_attr is None:
            return importlib.import_module(f".{attr_name}", package=__name__)
    
        package, module_name = dynamic_attr
    
        from importlib import import_module
    
        if module_name == "__module__":
            return import_module(f".{attr_name}", package=package)
        else:
            module = import_module(module_name, package=package)
            return getattr(module, attr_name)
    
        
    ''')).render(classes = classes, camel_to_snake = camel_to_snake).strip())

# get_classes_and_create_files(file_path)
write_init(file_path)