In [33]:
import ast

from ast import AST, stmt, expr, Load, Store, dump, parse

from ast import Name

from ast import NodeVisitor

p_dump = lambda node: print(dump(node, indent=2))

In [37]:

source = """
self.list.clear()
x = 42
l = api.fetch()
if not l: return
self.list.extend(l)
a = 1
b = a + 1
c = b + 1
print(a)
d = (a + b + c) * x * 2
print(d)
"""

In [38]:
root = parse(source)
p_dump(parse("exit()"))

Module(
  body=[
    Expr(
      value=Call(
        func=Name(id='exit', ctx=Load()),
        args=[],
        keywords=[]))],
  type_ignores=[])


In [39]:
class NameVisitor(NodeVisitor):
    
    def get_names(self, node: AST, id: str = None) -> list:
        self._id    = id
        self._names = []
        self.visit(node)
        return self._names

    def visit_Name(self, node: ast.Name):
        match self._id:
            case None:
                self._names.append(node)
            case id:
                if id != node.id: return
                self._names.append(node)

class Node:
    
    @staticmethod
    def all_names(node: AST, id: str = None) -> list[Name]:
        """Returns a list the list of all `ast.Name` nodes inside `node`.
        """
        visitor = NameVisitor()
        mentions = visitor.get_names(node, id)
        return mentions
    
    @staticmethod
    def has_names(node: AST, id: str = None) -> bool:
        visitor = NameVisitor()
        mentions = visitor.get_names(node, id)
        return bool(len(mentions))

 
def all_ids(node: AST) -> list[str]: return [name.id for name in Node.all_names(node)]

def isexitcall(node: AST) -> bool:
    match node:
        case ast.Call(
            func=Name(id='exit', ctx=Load()), args=_, keywords=_):
            return True
        case _:
            return False
                
def move_down_node(node: AST, body: list[AST]):
    if isinstance(node, ast.Return) or isexitcall(node):
        return
    
    i, last = body.index(node), len(body) - 1
    
    if i == last:
        return
    
    while i <= last - 1 and not (
        set(all_ids(body[i]))
        &
        set(all_ids(body[i + 1]))
    ):
        if isinstance(body[i + 1], ast.Return) or isexitcall(body[i + 1]):
            break
        
        body[i], body[i + 1] = body[i + 1], body[i]
        i += 1


class NodeMover(NodeVisitor):
    def generic_visit(self, node):
        super().generic_visit(node)
        match node:
        # match body with at least 2 elements
            case AST(body=[_, _, *_] as body):
                for node in reversed(body[:-1]):
                    move_down_node(node, body)
            case _:
                pass
        #return node


print(ast.unparse(root))
print('-'*100)
mover = NodeMover()
mover.visit(root)
print('-'*100)
print(ast.unparse(root))


self.list.clear()
x = 42
l = api.fetch()
if not l:
    return
self.list.extend(l)
a = 1
b = a + 1
c = b + 1
print(a)
d = (a + b + c) * x * 2
print(d)
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
a = 1
b = a + 1
print(a)
c = b + 1
x = 42
d = (a + b + c) * x * 2
print(d)
l = api.fetch()
if not l:
    return
self.list.clear()
self.list.extend(l)
