In [1]:
%load_ext autoreload
%autoreload 2

In [46]:
import inspect
import ast
import numpy as np
import typing

import typist
from typist import type_check

Python built-in types are [int, float, str, bool, bytes, list, tuple, dict, set, frozenset, None]

We aim to keep track of all variables in the function that are derived from the function argument. That means we need to know what each operators will return. This must be simple as all Uniary operators on a Ciphertext, binary operaotrs on two Ciphertexts should result in a Ciphertext. 

In [47]:
class Ciphertext:
    def __init__(self, val):
        self.val = val

    def __add__(self, o):
        new_instance = Ciphertext(self.val + o.val)
        return new_instance

def print_ast(tree, indent=2):
    print(ast.dump(tree, indent=indent))

def fun(a: Ciphertext, b: Ciphertext, c:int = 0) -> Ciphertext:
    d = a + b
    c = c + 5
    return d

In [48]:
source_code = inspect.getsource(fun)
tree = ast.parse(source_code)

In [57]:
body = tree.body[0].body

In [58]:
assign1, assign2, ret = body

1. c = a ? b 일 때, c가 무엇이 될지 알려면, a와 b의 type을 미리 알아야함. 
2. a와 b의 type을 추적하기 위해서는 우변에서 type의 변화가 일어나지 않는 것이 좋음 (SSO) 
3. 따라서 복잡한 expression을 SSO 형태로 먼저 분해해내는 것이 중요함. 
4. c = fun(a,b)일 경우, fun()의 return type이 explicit하게 주어져야함
5. ALL ciphertext represent a single data type (Float in CKKS or Int in BF/GV). All non-Ciphertext variables that interacts with Ciphertext are automatically converted to either Float or Int. 
6. If one or more of operands are Ciphertext and the operation is valid, resulting variable is always a Ciphertext. 

In [79]:
assign1.value.left

<ast.Name at 0x7f71082a5750>

In [75]:
print(ast.dump(tree, indent=2))

Module(
  body=[
    FunctionDef(
      name='fun',
      args=arguments(
        posonlyargs=[],
        args=[
          arg(
            arg='a',
            annotation=Name(id='Ciphertext', ctx=Load())),
          arg(
            arg='b',
            annotation=Name(id='Ciphertext', ctx=Load())),
          arg(
            arg='c',
            annotation=Name(id='int', ctx=Load()))],
        kwonlyargs=[],
        kw_defaults=[],
        defaults=[
          Constant(value=0)]),
      body=[
        Assign(
          targets=[
            Name(id='d', ctx=Store())],
          value=BinOp(
            left=Name(id='a', ctx=Load()),
            op=Add(),
            right=Name(id='b', ctx=Load()))),
        Assign(
          targets=[
            Name(id='c', ctx=Store())],
          value=BinOp(
            left=Name(id='c', ctx=Load()),
            op=Add(),
            right=Constant(value=5))),
        Return(
          value=Name(id='d', ctx=Load()))],
      decorator_list=[],


In [77]:
annotated_tree = typist.AnnotateVariableTypes().visit(tree)

Visiting ASSIGN
NODE <ast.BinOp object at 0x7f71082a6470>
Inferred type:  None
Visiting ASSIGN
NODE <ast.BinOp object at 0x7f71082a75b0>
Inferred type:  None


In [36]:
print_ast(annotated_tree)

Module(
  body=[
    FunctionDef(
      name='fun',
      args=arguments(
        posonlyargs=[],
        args=[
          arg(
            arg='a',
            annotation=Name(id='Ciphertext', ctx=Load())),
          arg(
            arg='b',
            annotation=Name(id='Ciphertext', ctx=Load())),
          arg(
            arg='c',
            annotation=Name(id='int', ctx=Load()))],
        kwonlyargs=[],
        kw_defaults=[],
        defaults=[
          Constant(value=0)]),
      body=[
        Assign(
          targets=[
            Name(id='d', ctx=Store())],
          value=BinOp(
            left=Name(id='a', ctx=Load()),
            op=Add(),
            right=Name(id='b', ctx=Load()))),
        Assign(
          targets=[
            Name(id='c', ctx=Store())],
          value=BinOp(
            left=Name(id='c', ctx=Load()),
            op=Add(),
            right=Constant(value=5))),
        Return(
          value=Name(id='d', ctx=Load()))],
      decorator_list=[],


In [37]:
annotated_source_code = ast.unparse(annotated_tree)
print(annotated_source_code)

def fun(a: Ciphertext, b: Ciphertext, c: int=0) -> Ciphertext:
    d = a + b
    c = c + 5
    return d


In [38]:
annotated_tree = typist.AnnotateVariableTypes().visit(tree)
annotated_source_code = ast.unparse(annotated_tree)
print(annotated_source_code)

def fun(a: Ciphertext, b: Ciphertext, c: int=0) -> Ciphertext:
    d = a + b
    c = c + 5
    return d


In [27]:
annotated_tree.body

[]

In [26]:

# Todo:
# Add a decorator for an explicit type check. 
@type_check
def example_function(a: int, b: int, c: float, ext = None):
    d = a + b
    e = a + c
    
    return d

source_code = inspect.getsource(example_function)
tree = ast.parse(source_code)
annotated_tree = typist.AnnotateVariableTypes().visit(tree)
annotated_source_code = ast.unparse(annotated_tree)
print(annotated_source_code)




In [28]:
example_function(1,5,3)

TypeError: Argument 'c' should be of type <class 'float'>, but got <class 'int'>

In [16]:
function_def = tree.body[0]
arg1 = function_def.args.args[0]

In [18]:
arg1.annotation.id

'int'

In [9]:
print(ast.dump(annotated_tree, indent=2))

Module(
  body=[
    FunctionDef(
      name='example_function',
      args=arguments(
        posonlyargs=[],
        args=[],
        kwonlyargs=[],
        kw_defaults=[],
        defaults=[]),
      body=[
        AnnAssign(
          target=Name(id='a', ctx=Store()),
          annotation=Name(id='int', ctx=Load()),
          value=Constant(value=1),
          simple=1),
        AnnAssign(
          target=Name(id='b', ctx=Store()),
          annotation=Name(id='float', ctx=Load()),
          value=Constant(value=2.0),
          simple=1),
        AnnAssign(
          target=Name(id='c', ctx=Store()),
          annotation=Name(id='str', ctx=Load()),
          value=Constant(value='hello'),
          simple=1),
        Assign(
          targets=[
            Name(id='d', ctx=Store())],
          value=BinOp(
            left=Name(id='a', ctx=Load()),
            op=Add(),
            right=Name(id='b', ctx=Load())))],
      decorator_list=[])],
  type_ignores=[])


In [71]:
import ast
import inspect
from functools import wraps
from typing import Any, Callable

class DtypeCheckVisitor(ast.NodeVisitor):
    def __init__(self, arg_types):
        self.arg_types = arg_types
        self.errors = []

    def visit_Assign(self, node):
        for target in node.targets:
            if isinstance(target, ast.Name) and target.id in self.arg_types:
                if not self._is_assigning_correct_type(node.value, self.arg_types[target.id]):
                    self.errors.append(f"Data type of '{target.id}' changed before return")

    def _is_assigning_correct_type(self, value_node, expected_type):
        if isinstance(value_node, ast.Num):
            return isinstance(value_node.n, expected_type)
        elif isinstance(value_node, ast.Str):
            return expected_type == str
        elif isinstance(value_node, ast.Call):
            return self._is_assigning_correct_type(value_node.func, expected_type)
        return True

def enforce_dtype_unchanged(func: Callable[..., Any]) -> Callable[..., Any]:
    @wraps(func)
    def wrapper(*args, **kwargs):
        # Get the source code of the function
        source_code = inspect.getsource(func)
        tree = ast.parse(source_code)

        # Extract argument types from function annotations
        arg_types = {arg.arg: arg.annotation for arg in tree.body[0].args.args if arg.annotation}

        # Visit the AST and find assignments that change the data type of the input arguments
        visitor = DtypeCheckVisitor(arg_types)
        visitor.visit(tree)

        if visitor.errors:
            raise TypeError(f"Dtype change(s) detected before return: {', '.join(visitor.errors)}")

        return func(*args, **kwargs)

    return wrapper

In [72]:
@enforce_dtype_unchanged
def valid_function(a: int, b: float):
    c = a * b
    a = int(c)
    return c

result = valid_function(2, 3.0)  # No error

@enforce_dtype_unchanged
def invalid_function(a: int, b: float):
    c = a * b
    a = c
    return c

try:
    result = invalid_function(2, 3.0)
except TypeError as e:
    print(e)  # Output: Dtype change(s) detected before return: Data type of 'a' changed before return

In [74]:
invalid_function(3.0, 4)

12.0