In [2]:
import ast


class CryptoVarExtractor(ast.NodeTransformer):
    def __init__(self):
        self.next_var_id = 0

        # Instead of global mapping, we keep per-scope mapping
        self.scope_stack = []  # Each element: {"mapping":{}, "assigns":[]}

    def _new_scope(self):
        scope = {"mapping": {}, "assigns": []}
        self.scope_stack.append(scope)
        return scope

    def _end_scope(self):
        return self.scope_stack.pop()

    def _make_var_name(self):
        name = f"cond_{self.next_var_id}"
        self.next_var_id += 1
        return name

    def _replace_condition(self, node):
        """Return ast.Name referencing a variable for this condition."""
        scope = self.scope_stack[-1]

        if id(node) not in scope["mapping"]:
            var_name = self._make_var_name()
            scope["mapping"][id(node)] = var_name

            assign = ast.Assign(
                targets=[ast.Name(id=var_name, ctx=ast.Store())],
                value=node
            )
            ast.fix_missing_locations(assign)
            scope["assigns"].append(assign)

        return ast.Name(id=scope["mapping"][id(node)], ctx=ast.Load())

    def visit_FunctionDef(self, node):
        self._new_scope()

        self.generic_visit(node)

        scope = self._end_scope()

        node.body = scope["assigns"] + node.body

        return node

    def visit_Module(self, node):
        self._new_scope()

        self.generic_visit(node)

        scope = self._end_scope()

        node.body = scope["assigns"] + node.body

        return node

    def visit_If(self, node):
        self.generic_visit(node)
        node.test = self._replace_condition(node.test)
        return node

    def visit_Compare(self, node):
        return self._replace_condition(self.generic_visit(node))

    def visit_BoolOp(self, node):
        return self._replace_condition(self.generic_visit(node))

    def get_refactored_code(self, source_code):
        
        try:
            tree = ast.parse(source_code)
            self.next_var_id = 0 
            new_tree = self.visit(tree)
            result = ast.unparse(new_tree)

        except Exception as e:
            raise ValueError(f"Parse transform failed: {e}")


        return result
if __name__ == "__main__":
    src = """ 
a = 1
b = 4 
c = 3
d = 2

def fun():
    if a > b and (a != c) or (d != b):
        print("condition check 2")

if (a>b) or ((a !=b) or (c < d)) or ((c > a) and (b > d)):
    print("condition check")
else:
    pass
"""

    extractor = CryptoVarExtractor()
    print(extractor.get_refactored_code(src))
    
    src = """ 
a = 1
b = 4 
c = 3
d = 2

def fun():
    if a > b and (a != c) or (d != b):
        print("condition check 2")

if (a>b) or ((a !=b) or (c < d)) or ((c > a) and (b > d)):
    print("condition check")
else:
    pass
"""
    print("\n**********************************\n")
    extractor = CryptoVarExtractor()
    print(extractor.get_refactored_code(src))

cond_6 = a > b
cond_7 = a != b
cond_8 = c < d
cond_9 = cond_7 or cond_8
cond_10 = c > a
cond_11 = b > d
cond_12 = cond_10 and cond_11
cond_13 = cond_6 or cond_9 or cond_12
cond_14 = cond_13
a = 1
b = 4
c = 3
d = 2

def fun():
    cond_0 = a > b
    cond_1 = a != c
    cond_2 = cond_0 and cond_1
    cond_3 = d != b
    cond_4 = cond_2 or cond_3
    cond_5 = cond_4
    if cond_5:
        print('condition check 2')
if cond_14:
    print('condition check')
else:
    pass

**********************************

cond_6 = a > b
cond_7 = a != b
cond_8 = c < d
cond_9 = cond_7 or cond_8
cond_10 = c > a
cond_11 = b > d
cond_12 = cond_10 and cond_11
cond_13 = cond_6 or cond_9 or cond_12
cond_14 = cond_13
a = 1
b = 4
c = 3
d = 2

def fun():
    cond_0 = a > b
    cond_1 = a != c
    cond_2 = cond_0 and cond_1
    cond_3 = d != b
    cond_4 = cond_2 or cond_3
    cond_5 = cond_4
    if cond_5:
        print('condition check 2')
if cond_14:
    print('condition check')
else:
    pass


Lambda Refactor

In [13]:
import ast

class LambdaRefactor(ast.NodeTransformer):



    @staticmethod
    def has_decorators(func_def: ast.FunctionDef) -> bool:
        return bool(func_def.decorator_list)

    def _function_to_lambda(self, func_def: ast.FunctionDef):
        if len(func_def.body) != 1:
            return None

        stmt = func_def.body[0]
        if not isinstance(stmt, ast.Return) or stmt.value is None:
            return None

        if self.has_decorators(func_def):
            return None

        return ast.Assign(
            targets=[ast.Name(id=func_def.name, ctx=ast.Store())],
            value=ast.Lambda(
                args=func_def.args,
                body=stmt.value
            )
        )

    def _lambda_to_function(self, assign: ast.Assign):
        if len(assign.targets) != 1:
            return None

        target = assign.targets[0]
        if not isinstance(target, ast.Name):
            return None

        if not isinstance(assign.value, ast.Lambda):
            return None

        lam = assign.value

        return ast.FunctionDef(
            name=target.id,
            args=lam.args,
            body=[ast.Return(value=lam.body)],
            decorator_list=[]
        )

    def visit_Module(self, node: ast.Module):
        new_body = []
        for stmt in node.body:
            if isinstance(stmt, ast.FunctionDef):
                converted = self._function_to_lambda(stmt)
                if converted:
                    new_body.append(converted)
                    continue

            if isinstance(stmt, ast.Assign):
                converted = self._lambda_to_function(stmt)
                if converted:
                    new_body.append(converted)
                    continue

            new_body.append(self.visit(stmt))

        node.body = new_body
        return node

    def visit_ClassDef(self, node: ast.ClassDef):
        new_body = []
        for stmt in node.body:
            if isinstance(stmt, ast.FunctionDef):
                converted = self._function_to_lambda(stmt)
                if converted:
                    new_body.append(converted)
                    continue

            if isinstance(stmt, ast.Assign):
                converted = self._lambda_to_function(stmt)
                if converted:
                    new_body.append(converted)
                    continue

            new_body.append(self.visit(stmt))

        node.body = new_body
        return node

    def visit_FunctionDef(self, node: ast.FunctionDef):
        node.body = [self.visit(stmt) for stmt in node.body]
        return node

    def get_refactored_code(self, source_code: str) -> str:
        try:
            tree = ast.parse(source_code)
            tree = self.visit(tree)
            ast.fix_missing_locations(tree)
            return ast.unparse(tree)
        except SyntaxError as e:
            raise ValueError(f"Syntax error in source code: {e}")



if __name__ == "__main__":
    code = """
class Math:
    def add(a, b):
        return a + b

    subtract = lambda x, y: x - y

    def mul(x): return x * 2

    div = lambda a, b=1: a / b
"""

    refactored = LambdaRefactor().get_refactored_code(code)
    print(refactored)

class Math:
    add = lambda a, b: a + b

    def subtract(x, y):
        return x - y
    mul = lambda x: x * 2

    def div(a, b=1):
        return a / b


elseif and elif

In [27]:
import libcst as cst

class ElseTransformer(cst.CSTTransformer):
    
    def leave_If(self, original: cst.If, updated: cst.If) -> cst.If:
        return self._process_chain(updated)

    def _process_chain(self, node: cst.If) -> cst.If:
        chain = self._collect_chain(node)
        has_elif = any(isinstance(n.orelse, cst.If) for n in chain)
        has_else_if = any(
            isinstance(n.orelse, cst.Else) and len(n.orelse.body.body) == 1 and isinstance(n.orelse.body.body[0], cst.If)
            for n in chain
        )

        if has_elif:
            return self._convert_elif_to_else_if(node)
        elif has_else_if:
            return self._convert_else_if_to_elif(node)
        return node

    def _collect_chain(self, node: cst.If):
        chain = []
        cur = node
        while isinstance(cur, cst.If):
            chain.append(cur)
            if isinstance(cur.orelse, cst.If):
                cur = cur.orelse
            elif (
                isinstance(cur.orelse, cst.Else)
                and len(cur.orelse.body.body) == 1
                and isinstance(cur.orelse.body.body[0], cst.If)
            ):
                cur = cur.orelse.body.body[0]
            else:
                break
        return chain

    def _convert_elif_to_else_if(self, node: cst.If) -> cst.If:
        if isinstance(node.orelse, cst.If):
            return node.with_changes(
                orelse=cst.Else(
                    body=cst.IndentedBlock(
                        body=[self._convert_elif_to_else_if(node.orelse)]
                    )
                )
            )
        elif isinstance(node.orelse, cst.Else):
            new_body = [self._convert_elif_to_else_if(b) if isinstance(b, cst.If) else b for b in node.orelse.body.body]
            return node.with_changes(
                orelse=node.orelse.with_changes(body=cst.IndentedBlock(body=new_body))
            )
        return node

    def _convert_else_if_to_elif(self, node: cst.If) -> cst.If:
        orelse = node.orelse
        if (
            isinstance(orelse, cst.Else)
            and len(orelse.body.body) == 1
            and isinstance(orelse.body.body[0], cst.If)
        ):
            return node.with_changes(
                orelse=self._convert_else_if_to_elif(
                    orelse.body.body[0].with_changes(leading_lines=[])
                )
            )
        elif isinstance(orelse, cst.Else):
            new_body = [self._convert_else_if_to_elif(b) if isinstance(b, cst.If) else b for b in orelse.body.body]
            return node.with_changes(
                orelse=orelse.with_changes(body=cst.IndentedBlock(body=new_body))
            )
        return node

    def get_refactored_code(self, source: str) -> str:
        module = cst.parse_module(source)
        return module.visit(self).code



if __name__ == "__main__":
    src = """
if a > b:
    print("A")
elif a == b:
    print("Equal")
else:
    if a < b:
        print("B")

if x > y:
    print("X")
else:
    if x == y:
        print("EQ")
    elif x < y:
        print("Y")
"""

    transformer = ElseTransformer()
    print(transformer.get_refactored_code(src))



if a > b:
    print("A")
else:
    if a == b:
        print("Equal")
    else:
        if a < b:
            print("B")

if x > y:
    print("X")
elif x == y:
    print("EQ")
elif x < y:
    print("Y")



For While Loop viceversa

In [None]:
import ast


class LoopStyleTransformer(ast.NodeTransformer):
    def __init__(self):
        self.loop_stack = []

    def visit_For(self, node):
        if (isinstance(node.iter, ast.Call) and
            isinstance(node.iter.func, ast.Name) and
            node.iter.func.id == "range"):

            self.generic_visit(node)

            if not isinstance(node.target, ast.Name):
                return node

            loop_var = node.target.id

            args = node.iter.args
            if len(args) == 1:
                start = ast.Constant(value=0)
                stop = args[0]
                step = ast.Constant(value=1)
            elif len(args) == 2:
                start = args[0]
                stop = args[1]
                step = ast.Constant(value=1)
            elif len(args) == 3:
                start = args[0]
                stop = args[1]
                step = args[2]
            else:
                return node

            ops = [ast.Lt()]
            if isinstance(step, ast.Constant) and step.value < 0:
                ops = [ast.Gt()]

            init = ast.Assign(
                targets=[ast.Name(id=loop_var, ctx=ast.Store())],
                value=start
            )

            cond = ast.Compare(
                left=ast.Name(id=loop_var, ctx=ast.Load()),
                ops=ops,
                comparators=[stop]
            )

            incr = ast.AugAssign(
                target=ast.Name(id=loop_var, ctx=ast.Store()),
                op=ast.Add(),
                value=step
            )

            while_node = ast.While(
                test=cond,
                body=node.body + [incr],
                orelse=node.orelse
            )

            ast.fix_missing_locations(init)
            ast.fix_missing_locations(while_node)

            return [init, while_node]

        else:
            if not isinstance(node.target, ast.Name):
                return node

            loop_var = node.target.id

            node.iter = self.visit(node.iter)
            node.orelse = [self.visit(o) for o in node.orelse]

            idx_name = f"_idx_{loop_var}_{len(self.loop_stack)}"

            init_idx = ast.Assign(
                targets=[ast.Name(id=idx_name, ctx=ast.Store())],
                value=ast.Constant(value=0)
            )

            len_call = ast.Call(
                func=ast.Name(id="len", ctx=ast.Load()),
                args=[node.iter],
                keywords=[]
            )

            cond = ast.Compare(
                left=ast.Name(id=idx_name, ctx=ast.Load()),
                ops=[ast.Lt()],
                comparators=[len_call]
            )

            incr_idx = ast.AugAssign(
                target=ast.Name(id=idx_name, ctx=ast.Store()),
                op=ast.Add(),
                value=ast.Constant(value=1)
            )

            self.loop_stack.append((loop_var, node.iter, idx_name))

            new_body = [self.visit(stmt) for stmt in node.body] + [incr_idx]

            self.loop_stack.pop()

            while_node = ast.While(
                test=cond,
                body=new_body,
                orelse=node.orelse
            )

            ast.fix_missing_locations(init_idx)
            ast.fix_missing_locations(while_node)

            return [init_idx, while_node]

    def visit_While(self, node):
        self.generic_visit(node)

        if not isinstance(node.test, ast.Compare):
            return node

        if len(node.test.ops) != 1 or len(node.test.comparators) != 1:
            return node

        op = node.test.ops[0]
        left = node.test.left
        right = node.test.comparators[0]

        if not isinstance(left, ast.Name):
            return node

        loop_var = left.id

        incr_stmt = None
        step_node = ast.Constant(value=1)
        is_add = True

        new_body = node.body[:]

        for i, stmt in enumerate(new_body):
            if isinstance(stmt, ast.AugAssign) and isinstance(stmt.target, ast.Name) and stmt.target.id == loop_var:
                if isinstance(stmt.op, ast.Add):
                    is_add = True
                    step_node = stmt.value
                elif isinstance(stmt.op, ast.Sub):
                    is_add = False
                    step_node = stmt.value
                incr_stmt = stmt
                del new_body[i]
                break

            elif isinstance(stmt, ast.Assign) and len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name) and stmt.targets[0].id == loop_var and isinstance(stmt.value, ast.BinOp) and isinstance(stmt.value.left, ast.Name) and stmt.value.left.id == loop_var:
                if isinstance(stmt.value.op, ast.Add):
                    is_add = True
                    step_node = stmt.value.right
                elif isinstance(stmt.value.op, ast.Sub):
                    is_add = False
                    step_node = stmt.value.right
                incr_stmt = stmt
                del new_body[i]
                break

        if not incr_stmt:
            return node

        if isinstance(op, ast.Lt) and is_add:
            start_node = ast.Name(id=loop_var, ctx=ast.Load())
            stop_node = right
            step_node = step_node

        elif isinstance(op, ast.Gt) and not is_add:
            start_node = ast.Name(id=loop_var, ctx=ast.Load())
            stop_node = right
            step_node = ast.UnaryOp(op=ast.USub(), operand=step_node)

        else:
            return node

        for_node = ast.For(
            target=ast.Name(id=loop_var, ctx=ast.Store()),
            iter=ast.Call(
                func=ast.Name(id="range", ctx=ast.Load()),
                args=[start_node, stop_node, step_node],
                keywords=[]
            ),
            body=new_body,
            orelse=node.orelse
        )

        ast.fix_missing_locations(for_node)
        return for_node

    def visit_Name(self, node):
        if not self.loop_stack or not isinstance(node.ctx, ast.Load):
            return node

        for loop_var, iterable, idx_name in reversed(self.loop_stack):
            if node.id == loop_var:
                return ast.Subscript(
                    value=iterable,
                    slice=ast.Name(id=idx_name, ctx=ast.Load()),
                    ctx=ast.Load()
                )
        return node

    def visit_Subscript(self, node):
        self.generic_visit(node)

        if not self.loop_stack:
            return node

        for loop_var, _, idx_name in reversed(self.loop_stack):
            if isinstance(node.slice, ast.Name) and node.slice.id == loop_var:
                node.slice = ast.Name(id=idx_name, ctx=ast.Load())
                break

        return node

    def get_refactored_code(self, source_code):
        try:
            tree = ast.parse(source_code)
            tree = self.visit(tree)
            ast.fix_missing_locations(tree)
            return ast.unparse(tree)
        except Exception as e:
            raise ValueError(f"Error during transformation: {type(e).__name__}: {str(e)}")


if __name__ == "__main__":
    src="""   
for i in range(5):
    print(i)

fruits = ["apple", "banana", "mango"]

for fruit in fruits:
    print(fruit)

i = 0
while i < 5:
    print(i)
    i += 1

count = 5
while count > 0:
    print(count)
    count -= 1

for i in range(5):
    if i == 2:
        continue
    print(i)

for i in range(1, 3):
    for j in range(1, 4):
        print(i, j)

for i in range(1, 3):
    j = 0
    while j < 4:
        print(i, j)
        j = j + 1


"""
    transformer = LoopStyleTransformer()
    print(transformer.get_refactored_code(src))

i = 0
while i < 5:
    print(i)
    i += 1
fruits = ['apple', 'banana', 'mango']
_idx_fruit_0 = 0
while _idx_fruit_0 < len(fruits):
    print(fruits[_idx_fruit_0])
    _idx_fruit_0 += 1
i = 0
for i in range(i, 5, 1):
    print(i)
count = 5
for count in range(count, 0, -1):
    print(count)
i = 0
while i < 5:
    if i == 2:
        continue
    print(i)
    i += 1
i = 1
while i < 3:
    j = 1
    while j < 4:
        print(i, j)
        j += 1
    i += 1
i = 1
while i < 3:
    j = 0
    for j in range(j, 4, 1):
        print(i, j)
    i += 1


hardcode params not working for intergers and outside of funciton

In [10]:
import ast
from collections import defaultdict


class HardcodedValues(ast.NodeTransformer):

    SUPPORTED_TYPES = (str, int)

    def _collect_from_call(self, call, collected):
        for arg in call.args:
            if isinstance(arg, ast.Constant) and isinstance(arg.value, self.SUPPORTED_TYPES):
                collected["p"].add(arg.value)

        for kw in call.keywords:
            if (
                isinstance(kw.value, ast.Constant)
                and isinstance(kw.value.value, self.SUPPORTED_TYPES)
            ):
                collected["p"].add(kw.value.value)

    def _build_vars_and_map(self, node):
        raw = defaultdict(set)

        for sub in ast.walk(node):
            if isinstance(sub, ast.Call):
                self._collect_from_call(sub, raw)

        var_list = []
        counters = defaultdict(int)

        for lit in raw["p"]:
            idx = counters["p"]
            counters["p"] += 1
            var_list.append((f"var{idx}", lit))

        literal_to_var = {lit: name for name, lit in var_list}

        assignments = [
            ast.Assign(
                targets=[ast.Name(id=name, ctx=ast.Store())],
                value=ast.Constant(value=lit),
            )
            for name, lit in var_list
        ]

        for a in assignments:
            ast.fix_missing_locations(a)

        return assignments, literal_to_var

    class _Replacer(ast.NodeTransformer):
        def __init__(self, literal_to_var):
            self.l2v = literal_to_var

        def visit_Call(self, node):
            self.generic_visit(node)

            node.args = [
                ast.Name(id=self.l2v[arg.value], ctx=ast.Load())
                if (
                    isinstance(arg, ast.Constant)
                    and arg.value in self.l2v
                )
                else arg
                for arg in node.args
            ]

            node.keywords = [
                ast.keyword(
                    arg=kw.arg,
                    value=ast.Name(id=self.l2v[kw.value.value], ctx=ast.Load()),
                )
                if (
                    isinstance(kw.value, ast.Constant)
                    and kw.value.value in self.l2v
                )
                else kw
                for kw in node.keywords
            ]

            return node

    def _inject(self, node):
        assignments, literal_to_var = self._build_vars_and_map(node)

        if assignments:
            node.body[0:0] = assignments

        if literal_to_var:
            self._Replacer(literal_to_var).visit(node)

    def refactor(self, tree):
        self._inject(tree)

        for node in ast.walk(tree):
            if isinstance(node, ast.FunctionDef):
                self._inject(node)

        return ast.unparse(tree)

    def get_refactored_code(self, source_code):
        try:
            tree = ast.parse(source_code)
            return self.refactor(tree)
        except SyntaxError as e:
            raise ValueError(f"Syntax error in source code: {e}")


if __name__ == "__main__":
    src = """
print("hello",10)

def foo():
    bar("test", count=5)

def fun():
    print("hello")
    print(123)
"""
    t = HardcodedValues()
    print(t.get_refactored_code(src))


var0 = 5
var1 = 'test'
var2 = 10
var3 = 'hello'
var4 = 123
print(var3, var2)

def foo():
    bar(var1, count=var0)

def fun():
    print(var3)
    print(var4)


mergeparamater

In [None]:
class ParamRecorder(ast.NodeVisitor):
    def __init__(self):
        self.func2par = {}
        self.class2func = {}
        self.par2arg = {}
    def visit_ClassDef(self, node):
        func = []
        for inner_node in node.body:
            if isinstance(inner_node, ast.FunctionDef) and not (inner_node.name.startswith("__") and inner_node.name.endswith("__")):
                func.append(inner_node.name)
            self.class2func[node.name] = func
        self.generic_visit(node)
    def visit_FunctionDef(self, node):
        flag = 0
        if not (node.name.startswith("__") and node.name.endswith("__")):
            for ls in self.class2func.values():
                if node.name in ls:
                    flag = 1
            if flag:
                return node 
            params = []
            for elt in node.args.args:
                params.append(elt.arg)
            self.func2par[node.name] = params
        return node
    def visit_Call(self, node):
        if isinstance(node.func, ast.Name) and node.func.id in self.func2par.keys():
            for idx, elt in enumerate(node.args):
                ls = list(self.func2par[node.func.id])
                self.par2arg[ls[idx]] = elt.id 
        return node


class ParamRefactor(ast.NodeTransformer):
    def __init__(self):

        pass 
    def get_refactored_code(self, code):
        obj = ParamRecorder()
        obj.visit(code)
        self.class_w_func = list(obj.class2func.values())
        self.valid_functions = list(obj.func2par.keys())
        self.par2arg_map = obj.par2arg
        

strutucre obje relation operatiion al converstion and integer refactoring scripts 

In [3]:
#integer refactor
import ast
import random
class ConstantRefactor(ast.NodeTransformer):
    def findFactor(self, num):
        if num%2 != 0:
            return [num/2, num/2]
        else:
            n = num//2
            return [n] + self.findFactor(n)
    def binOpFunc(self, length):
        if length <= 2:
            new_binop = ast.BinOp(
                left = ast.Constant(value=self.num_arr[length-2]),
                op=ast.Add(),
                right=ast.Constant(value=self.num_arr[length-1])
            )
            return new_binop
        else:
            length -= 1
            new_binop = ast.BinOp(
                left = self.binOpFunc(length),
                op = ast.Add(),
                right = ast.Constant(value=self.num_arr[length - 1])
            )
            return new_binop
    def visit_Constant(self, node):
        r = random.randint(0, 2)
        if isinstance(node.value, int) and node.value != 0 and r:
            self.num_arr = self.findFactor(node.value)
            return self.binOpFunc(len(self.num_arr))
        return node 
    def get_refactored_code(self, code):
        try:
            tree = ast.parse(code)
            mod_tree = self.visit(tree)
            return ast.unparse(mod_tree)
        except SyntaxError as e:
            raise ValueError(f"Syntax error in source code: {e}")
        
if __name__=="__main__":
    src = """  
a[2] = [90//2, 100]
a = {1: '29'}
class Test():
    def func(a=54):
        print(51)
        print(25*89.90)
        print(0)
"""

    trans = ConstantRefactor()
    print(trans.get_refactored_code(src))

a[1 + 0.5 + 0.5] = [(45 + 22.5 + 22.5) // 2, 50 + 25 + 25 + 12.5]
a = {0.5 + 0.5: '29'}

class Test:

    def func(a=27 + 13.5 + 13.5):
        print(25.5 + 25.5)
        print(25 * 89.9)
        print(0)


Relation operators

In [9]:
import ast 

class RelationRefactors(ast.NodeTransformer):
    def change_exp(self, node):
        temp_left = node.left.id 
        temp_ops = node.ops[0]
        temp_right = node.comparators[0].id 
        if isinstance(temp_ops, ast.Gt):
            new_exp = ast.Compare(
                left = ast.Name(id=temp_right, ctx=ast.Load()),
                ops=[ast.Lt()],
                comparators=[ast.Name(id=temp_left, ctx=ast.Load())]
            )
        elif isinstance(temp_ops, ast.GtE):
            new_exp = ast.Compare(
                left=ast.Name(id=temp_right, ctx=ast.Load()),
                ops=[ast.LtE()],
                comparators=[ast.Name(id=temp_left, ctx=ast.Load())]
            )
        elif isinstance(temp_ops, ast.Lt):
            new_exp = ast.Compare(
                left = ast.Name(id=temp_right, ctx=ast.Load()),
                ops=[ast.Gt()],
                comparators=[ast.Name(id=temp_left, ctx=ast.Load())]
            )
        elif isinstance(temp_ops, ast.LtE):
            new_exp=ast.Compare(
                left = ast.Name(id=temp_right, ctx=ast.Load()),
                ops=[ast.GtE()],
                comparators=[ast.Name(id=temp_left, ctx=ast.Load())]
            )
        elif isinstance(temp_ops, ast.Eq):
            new_exp = ast.UnaryOp(
                op=ast.Not(),
                operand = ast.Compare(
                    left=ast.Name(id=temp_left, ctx=ast.Load()),
                    ops=[ast.NotEq()],
                    comparators=[ast.Name(id=temp_right, ctx=ast.Load())]
                )
            )
        elif isinstance(temp_ops, ast.NotEq):
            new_exp = ast.UnaryOp(
                op=ast.Not(),
                operand=ast.Compare(
                    left=ast.Name(id=temp_left, ctx=ast.Load()),
                    ops=[ast.Eq()],
                    comparators=[ast.Name(id=temp_right, ctx=ast.Load())]
                )
            )
        return new_exp
    def conditional_check(self, node):
        if isinstance(node, ast.Compare):
            return self.change_exp(node)
        elif isinstance(node, ast.BoolOp):
            exp = ast.BoolOp(
                op=node.op,
                values=[self.conditional_check(value) for value in node.values]
            ) 
            return exp 
    def get_refactored_code(self, code):
        try:
            tree = ast.parse(code)
            for node in ast.walk(tree):
                if isinstance(node, ast.If) or isinstance(node, ast.While):
                    if isinstance(node.test, ast.Compare):
                        node.test = self.change_exp(node.test)
                    elif isinstance(node.test, ast.BoolOp):
                        node.test.values = [self.conditional_check(value) for value in node.test.values]
            return ast.unparse(tree)
        except SyntaxError as e:
            raise ValueError(f"Syntax error in source code: {e}")
        
if __name__=="__main__":
    src = """
def func(a, b, c):
    if a>b and b >= c or a==c:
        return a
    elif c < b:
        return b
    elif a == c:
        return c
    elif func2(): return None

while(a<b and var1!=var2):
    a += 1
"""


    trans = RelationRefactors()
    print(trans.get_refactored_code(src))
    

def func(a, b, c):
    if b < a and c <= b or not a != c:
        return a
    elif b > c:
        return b
    elif not a != c:
        return c
    elif func2():
        return None
while b > a and (not var1 == var2):
    a += 1


In [None]:
import ast

class RelationRefactors(ast.NodeTransformer):

    def change_exp(self, node: ast.Compare):
        if not (
            isinstance(node.left, ast.Name)
            and len(node.ops) == 1
            and len(node.comparators) == 1
            and isinstance(node.comparators[0], ast.Name)
        ):
            return node

        left = node.left.id
        right = node.comparators[0].id
        op = node.ops[0]

        op_map = {
            ast.Gt: ast.Lt,
            ast.GtE: ast.LtE,
            ast.Lt: ast.Gt,
            ast.LtE: ast.GtE,
            ast.Eq: ast.Eq,
            ast.NotEq: ast.NotEq,
        }

        new_op_cls = op_map.get(type(op))
        if not new_op_cls:
            return node

        new_node = ast.Compare(
            left=ast.Name(id=right, ctx=ast.Load()),
            ops=[new_op_cls()],
            comparators=[ast.Name(id=left, ctx=ast.Load())],
        )

        return ast.copy_location(new_node, node)

    def conditional_check(self, node):
        if isinstance(node, ast.Compare):
            return self.change_exp(node)

        if isinstance(node, ast.BoolOp):
            return ast.copy_location(
                ast.BoolOp(
                    op=node.op,
                    values=[self.conditional_check(v) for v in node.values],
                ),
                node,
            )

        return node  
    def get_refactored_code(self, code):
        tree = ast.parse(code)

        for node in ast.walk(tree):
            if isinstance(node, (ast.If, ast.While)):
                node.test = self.conditional_check(node.test)

        ast.fix_missing_locations(tree)
        return ast.unparse(tree)


if __name__ == "__main__":
    src = """
def func(a, b, c):
    if a>b and b >= c or a==c:
        return a
    elif c < b:
        return b
    elif a == c:
        return c
    elif func2(): return None

while(a<b and var1!=var2):
    a += 1
"""
    print(RelationRefactors().get_refactored_code(src))


def func(a, b, c):
    if b < a and c <= b or c == a:
        return a
    elif b > c:
        return b
    elif c == a:
        return c
    elif func2():
        return None
while b > a and var2 != var1:
    a += 1


elseif not wokring