In [2]:
import ast


class CryptoVarExtractor(ast.NodeTransformer):
    def __init__(self):
        self.next_var_id = 0

        # Instead of global mapping, we keep per-scope mapping
        self.scope_stack = []  # Each element: {"mapping":{}, "assigns":[]}

    def _new_scope(self):
        scope = {"mapping": {}, "assigns": []}
        self.scope_stack.append(scope)
        return scope

    def _end_scope(self):
        return self.scope_stack.pop()

    def _make_var_name(self):
        name = f"cond_{self.next_var_id}"
        self.next_var_id += 1
        return name

    def _replace_condition(self, node):
        """Return ast.Name referencing a variable for this condition."""
        scope = self.scope_stack[-1]

        if id(node) not in scope["mapping"]:
            var_name = self._make_var_name()
            scope["mapping"][id(node)] = var_name

            assign = ast.Assign(
                targets=[ast.Name(id=var_name, ctx=ast.Store())],
                value=node
            )
            ast.fix_missing_locations(assign)
            scope["assigns"].append(assign)

        return ast.Name(id=scope["mapping"][id(node)], ctx=ast.Load())

    def visit_FunctionDef(self, node):
        self._new_scope()

        self.generic_visit(node)

        scope = self._end_scope()

        node.body = scope["assigns"] + node.body

        return node

    def visit_Module(self, node):
        self._new_scope()

        self.generic_visit(node)

        scope = self._end_scope()

        node.body = scope["assigns"] + node.body

        return node

    def visit_If(self, node):
        self.generic_visit(node)
        node.test = self._replace_condition(node.test)
        return node

    def visit_Compare(self, node):
        return self._replace_condition(self.generic_visit(node))

    def visit_BoolOp(self, node):
        return self._replace_condition(self.generic_visit(node))

    def get_refactored_code(self, source_code):
        
        try:
            tree = ast.parse(source_code)
            self.next_var_id = 0 
            new_tree = self.visit(tree)
            result = ast.unparse(new_tree)

        except Exception as e:
            raise ValueError(f"Parse transform failed: {e}")


        return result
if __name__ == "__main__":
    src = """ 
a = 1
b = 4 
c = 3
d = 2

def fun():
    if a > b and (a != c) or (d != b):
        print("condition check 2")

if (a>b) or ((a !=b) or (c < d)) or ((c > a) and (b > d)):
    print("condition check")
else:
    pass
"""

    extractor = CryptoVarExtractor()
    print(extractor.get_refactored_code(src))
    
    src = """ 
a = 1
b = 4 
c = 3
d = 2

def fun():
    if a > b and (a != c) or (d != b):
        print("condition check 2")

if (a>b) or ((a !=b) or (c < d)) or ((c > a) and (b > d)):
    print("condition check")
else:
    pass
"""
    print("\n**********************************\n")
    extractor = CryptoVarExtractor()
    print(extractor.get_refactored_code(src))

cond_6 = a > b
cond_7 = a != b
cond_8 = c < d
cond_9 = cond_7 or cond_8
cond_10 = c > a
cond_11 = b > d
cond_12 = cond_10 and cond_11
cond_13 = cond_6 or cond_9 or cond_12
cond_14 = cond_13
a = 1
b = 4
c = 3
d = 2

def fun():
    cond_0 = a > b
    cond_1 = a != c
    cond_2 = cond_0 and cond_1
    cond_3 = d != b
    cond_4 = cond_2 or cond_3
    cond_5 = cond_4
    if cond_5:
        print('condition check 2')
if cond_14:
    print('condition check')
else:
    pass

**********************************

cond_6 = a > b
cond_7 = a != b
cond_8 = c < d
cond_9 = cond_7 or cond_8
cond_10 = c > a
cond_11 = b > d
cond_12 = cond_10 and cond_11
cond_13 = cond_6 or cond_9 or cond_12
cond_14 = cond_13
a = 1
b = 4
c = 3
d = 2

def fun():
    cond_0 = a > b
    cond_1 = a != c
    cond_2 = cond_0 and cond_1
    cond_3 = d != b
    cond_4 = cond_2 or cond_3
    cond_5 = cond_4
    if cond_5:
        print('condition check 2')
if cond_14:
    print('condition check')
else:
    pass


Lambda Refactor

In [13]:
import ast

class LambdaRefactor(ast.NodeTransformer):



    @staticmethod
    def has_decorators(func_def: ast.FunctionDef) -> bool:
        return bool(func_def.decorator_list)

    def _function_to_lambda(self, func_def: ast.FunctionDef):
        if len(func_def.body) != 1:
            return None

        stmt = func_def.body[0]
        if not isinstance(stmt, ast.Return) or stmt.value is None:
            return None

        if self.has_decorators(func_def):
            return None

        return ast.Assign(
            targets=[ast.Name(id=func_def.name, ctx=ast.Store())],
            value=ast.Lambda(
                args=func_def.args,
                body=stmt.value
            )
        )

    def _lambda_to_function(self, assign: ast.Assign):
        if len(assign.targets) != 1:
            return None

        target = assign.targets[0]
        if not isinstance(target, ast.Name):
            return None

        if not isinstance(assign.value, ast.Lambda):
            return None

        lam = assign.value

        return ast.FunctionDef(
            name=target.id,
            args=lam.args,
            body=[ast.Return(value=lam.body)],
            decorator_list=[]
        )

    def visit_Module(self, node: ast.Module):
        new_body = []
        for stmt in node.body:
            if isinstance(stmt, ast.FunctionDef):
                converted = self._function_to_lambda(stmt)
                if converted:
                    new_body.append(converted)
                    continue

            if isinstance(stmt, ast.Assign):
                converted = self._lambda_to_function(stmt)
                if converted:
                    new_body.append(converted)
                    continue

            new_body.append(self.visit(stmt))

        node.body = new_body
        return node

    def visit_ClassDef(self, node: ast.ClassDef):
        new_body = []
        for stmt in node.body:
            if isinstance(stmt, ast.FunctionDef):
                converted = self._function_to_lambda(stmt)
                if converted:
                    new_body.append(converted)
                    continue

            if isinstance(stmt, ast.Assign):
                converted = self._lambda_to_function(stmt)
                if converted:
                    new_body.append(converted)
                    continue

            new_body.append(self.visit(stmt))

        node.body = new_body
        return node

    def visit_FunctionDef(self, node: ast.FunctionDef):
        node.body = [self.visit(stmt) for stmt in node.body]
        return node

    def get_refactored_code(self, source_code: str) -> str:
        try:
            tree = ast.parse(source_code)
            tree = self.visit(tree)
            ast.fix_missing_locations(tree)
            return ast.unparse(tree)
        except SyntaxError as e:
            raise ValueError(f"Syntax error in source code: {e}")



if __name__ == "__main__":
    code = """
class Math:
    def add(a, b):
        return a + b

    subtract = lambda x, y: x - y

    def mul(x): return x * 2

    div = lambda a, b=1: a / b
"""

    refactored = LambdaRefactor().get_refactored_code(code)
    print(refactored)

class Math:
    add = lambda a, b: a + b

    def subtract(x, y):
        return x - y
    mul = lambda x: x * 2

    def div(a, b=1):
        return a / b


elseif and elif

In [27]:
import libcst as cst

class ElseTransformer(cst.CSTTransformer):
    
    def leave_If(self, original: cst.If, updated: cst.If) -> cst.If:
        return self._process_chain(updated)

    def _process_chain(self, node: cst.If) -> cst.If:
        chain = self._collect_chain(node)
        has_elif = any(isinstance(n.orelse, cst.If) for n in chain)
        has_else_if = any(
            isinstance(n.orelse, cst.Else) and len(n.orelse.body.body) == 1 and isinstance(n.orelse.body.body[0], cst.If)
            for n in chain
        )

        if has_elif:
            return self._convert_elif_to_else_if(node)
        elif has_else_if:
            return self._convert_else_if_to_elif(node)
        return node

    def _collect_chain(self, node: cst.If):
        chain = []
        cur = node
        while isinstance(cur, cst.If):
            chain.append(cur)
            if isinstance(cur.orelse, cst.If):
                cur = cur.orelse
            elif (
                isinstance(cur.orelse, cst.Else)
                and len(cur.orelse.body.body) == 1
                and isinstance(cur.orelse.body.body[0], cst.If)
            ):
                cur = cur.orelse.body.body[0]
            else:
                break
        return chain

    def _convert_elif_to_else_if(self, node: cst.If) -> cst.If:
        if isinstance(node.orelse, cst.If):
            return node.with_changes(
                orelse=cst.Else(
                    body=cst.IndentedBlock(
                        body=[self._convert_elif_to_else_if(node.orelse)]
                    )
                )
            )
        elif isinstance(node.orelse, cst.Else):
            new_body = [self._convert_elif_to_else_if(b) if isinstance(b, cst.If) else b for b in node.orelse.body.body]
            return node.with_changes(
                orelse=node.orelse.with_changes(body=cst.IndentedBlock(body=new_body))
            )
        return node

    def _convert_else_if_to_elif(self, node: cst.If) -> cst.If:
        orelse = node.orelse
        if (
            isinstance(orelse, cst.Else)
            and len(orelse.body.body) == 1
            and isinstance(orelse.body.body[0], cst.If)
        ):
            return node.with_changes(
                orelse=self._convert_else_if_to_elif(
                    orelse.body.body[0].with_changes(leading_lines=[])
                )
            )
        elif isinstance(orelse, cst.Else):
            new_body = [self._convert_else_if_to_elif(b) if isinstance(b, cst.If) else b for b in orelse.body.body]
            return node.with_changes(
                orelse=orelse.with_changes(body=cst.IndentedBlock(body=new_body))
            )
        return node

    def get_refactored_code(self, source: str) -> str:
        module = cst.parse_module(source)
        return module.visit(self).code



if __name__ == "__main__":
    src = """
if a > b:
    print("A")
elif a == b:
    print("Equal")
else:
    if a < b:
        print("B")

if x > y:
    print("X")
else:
    if x == y:
        print("EQ")
    elif x < y:
        print("Y")
"""

    transformer = ElseTransformer()
    print(transformer.get_refactored_code(src))



if a > b:
    print("A")
else:
    if a == b:
        print("Equal")
    else:
        if a < b:
            print("B")

if x > y:
    print("X")
elif x == y:
    print("EQ")
elif x < y:
    print("Y")

