In [None]:
from torch import Tensor

import sys
sys.path.append('./../..')

from image_gen.noise import BaseNoiseSchedule
from image_gen.utils import get_class_source

In [None]:
class Test:
    """Minimal base class used by the demo class below."""

    def __mul__(self, other):
        """Decline ``self * other`` so Python falls back to ``other.__rmul__``."""
        return NotImplemented

    # Backward-compatible alias for the original misspelled name; Python
    # itself only ever looks up ``__mul__`` for the ``*`` operator.
    __mult__ = __mul__

class OriginalClassName(Test):
    """Demo class whose source is extracted and renamed in the cells below.

    It deliberately contains a class attribute, a by-name self-reference,
    a nested function, class/static methods and a property with a setter.
    """

    VAR = 1

    def __init__(self, a: int, b: bool = True):
        self._a = a
        self._b = b
        self.var = OriginalClassName.VAR  # Self-reference example

    def __call__(self):
        """Return the stored ``a`` value."""
        return self._a

    def method(self):
        """Return (without calling) a nested function that prints a message."""
        def nested_func():
            print("Inside nested")
        return nested_func

    @classmethod
    def class_method(cls):
        """Print a message naming the class (referenced by name, not via ``cls``)."""
        print(f"Class method of {OriginalClassName.__name__}")

    @staticmethod
    def static_method():
        """Print a fixed message."""
        print("Static method of OriginalClassName")

    @property
    def a(self):
        """The ``a`` value passed to ``__init__`` (or set later)."""
        # Read the backing field: returning ``self.a`` would re-enter this
        # property and recurse until RecursionError.
        return self._a

    @a.setter
    def a(self, value):
        self._a = value

def sample_function():
    """Print a fixed message identifying this object as a plain function."""
    message = "I'm a function"
    print(message)

In [None]:
class ExponentialNoiseSchedule(BaseNoiseSchedule):
    """Noise schedule ``beta(t) = beta_min + t**e * (beta_max - beta_min)``.

    Despite the name, ``t`` is raised to the fixed power ``e``. For ``e > 0``
    this gives ``beta(0) = beta_min`` and ``beta(1) = beta_max``.
    """

    def __init__(self, *args, beta_min: float = 0.001, beta_max: float = 50.0, e: float = 2.0, **kwargs):
        # Extra positional/keyword arguments are accepted but neither used nor
        # forwarded. NOTE(review): BaseNoiseSchedule.__init__ is never called —
        # confirm that is intended.
        self.beta_min, self.beta_max, self.e = beta_min, beta_max, e

    def __call__(self, t: Tensor, *args, **kwargs) -> Tensor:
        """Evaluate ``beta`` at time ``t``."""
        beta_span = self.beta_max - self.beta_min
        return self.beta_min + t ** self.e * beta_span

    def integral_beta(self, t: Tensor, *args, **kwargs) -> Tensor:
        """Return the integral of ``beta`` from 0 to ``t`` (valid for ``e > -1``).

        ``beta_min * t + (beta_max - beta_min) * t**(e + 1) / (e + 1)``
        """
        power = self.e + 1
        linear_part = self.beta_min * t
        power_part = (self.beta_max - self.beta_min) * (t ** power) / power
        return linear_part + power_part

    def config(self) -> dict:
        """Return the keyword arguments that rebuild this schedule via ``__init__``."""
        return dict(beta_min=self.beta_min, beta_max=self.beta_max, e=self.e)

In [None]:
# Show the source text that get_class_source extracts for the demo class.
print(get_class_source(OriginalClassName))

In [None]:
# Same extraction for a class that derives from a project base class.
print(get_class_source(ExponentialNoiseSchedule))

In [None]:
import ast

class ClassRenamer(ast.NodeTransformer):
    """AST transformer that renames a class and every by-name reference to it.

    ``class old_name`` definitions are renamed to ``new_name``, and every
    ``ast.Name`` whose identifier is ``old_name`` is rewritten. This covers
    bases, annotations of any argument kind, calls, ``isinstance`` checks,
    ``OldName.attr`` accesses and f-string fields, both inside and outside
    the class body. String literals are left untouched.
    """

    def __init__(self, old_name, new_name):
        self.old_name = old_name
        self.new_name = new_name
        # True while visiting the body of a class originally named ``old_name``.
        self.in_class = False

    def visit_ClassDef(self, node):
        is_target = node.name == self.old_name
        if is_target:
            node.name = self.new_name
        # Save the flag *before* changing it so it is restored correctly once
        # a (possibly nested) class definition has been processed.
        outer_in_class = self.in_class
        self.in_class = is_target
        # generic_visit covers decorators, bases, keywords and the body once each.
        self.generic_visit(node)
        self.in_class = outer_in_class
        return node

    def visit_Name(self, node):
        # Rename regardless of position: self-references inside the class
        # (``isinstance(x, OldName)``, ``other: OldName``) must follow too,
        # otherwise the renamed source refers to a name that no longer exists.
        if node.id == self.old_name:
            node.id = self.new_name
        return node

    def visit_Attribute(self, node):
        """``OldName.attr``: the ``OldName`` part is a Name handled by visit_Name."""
        return self.generic_visit(node)

    def visit_Call(self, node):
        """``OldName(...)``: the callee is a Name handled by visit_Name."""
        return self.generic_visit(node)

    def visit_FunctionDef(self, node):
        """All argument annotations and the return annotation are reached by generic_visit."""
        return self.generic_visit(node)

def rename_class(source_code, old_name, new_name):
    """Return ``source_code`` with class ``old_name`` renamed to ``new_name``.

    Args:
        source_code: Python source containing the class. Uniformly indented
            source (e.g. a class defined inside a function) is dedented first.
        old_name: Current class name.
        new_name: Desired class name.

    Returns:
        Regenerated source text. It is produced by ``ast.unparse``, so
        comments and original formatting are not preserved.

    Raises:
        SyntaxError: If the (dedented) source is not valid Python.
    """
    import textwrap

    tree = ast.parse(textwrap.dedent(source_code))
    new_tree = ClassRenamer(old_name, new_name).visit(tree)
    ast.fix_missing_locations(new_tree)
    return ast.unparse(new_tree)

In [None]:
# Rename the demo class. References such as `OriginalClassName.VAR` are rewritten,
# while string literals like "Static method of OriginalClassName" are left as-is.
print(rename_class(get_class_source(OriginalClassName), "OriginalClassName", "NewClassName"))

In [None]:
# Same rename on a class with a project base class and Tensor annotations.
print(rename_class(get_class_source(ExponentialNoiseSchedule), "ExponentialNoiseSchedule", "NewClassName2"))

In [None]:
from image_gen.utils import CustomClassWrapper

In [None]:
# Wrap the schedule's *source text* (not the class object), registered under its class name.
exp = CustomClassWrapper(
    get_class_source(ExponentialNoiseSchedule),
    "ExponentialNoiseSchedule",
)

In [None]:
# NOTE(review): expected to be False here, since the wrapper has not been used yet — confirm against CustomClassWrapper.
exp.loaded

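The wrapped class code isn't executed until it is first needed (compare `exp.loaded` before and after the plot below), which gives the user time to inspect and verify its content first.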

In [None]:
import torch
import matplotlib.pyplot as plt

# Sample t strictly inside (0, 1); presumably to stay clear of the endpoints — confirm.
x = torch.linspace(0.0001, 0.9999, 100)

# Explicit figure/axes interface so the plot is self-contained and labelled.
fig, ax = plt.subplots(figsize=(5, 5))
# Calling the wrapper is presumably what triggers loading the wrapped class.
ax.plot(x, exp(x), label='Exponential Schedule', color='blue', linewidth=2)
ax.set(title='Exponential noise schedule', xlabel='t', ylabel=r'$\beta(t)$')
ax.legend()
plt.show()

In [None]:
# Re-check after the plot cell called exp(x); presumably True now that the class has been loaded.
exp.loaded

In [None]:
# Show the code held by the wrapper (presumably the source passed to its constructor).
# NOTE(review): `_code` is a private attribute of CustomClassWrapper and may change without notice.
print(exp._code)