In [1]:
from typing import Optional
import sys
import json
import yaml
import inspect
from enum import Enum

import torch
# torch.set_printoptions(profile="full")

In [2]:
# This notebook is based on the ATen native-function docs and on the file
# native_functions.yaml in the PyTorch GitHub repo:
# https://github.com/pytorch/pytorch/tree/master/aten/src/ATen/native

In [3]:
# Download native_functions.yaml into the current working directory (curl -O
# keeps the remote file name); a later cell opens it by that name.
!curl -O https://raw.githubusercontent.com/pytorch/pytorch/1.6/aten/src/ATen/native/native_functions.yaml

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  267k  100  267k    0     0  3220k      0 --:--:-- --:--:-- --:--:-- 3181k


In [4]:
# Parse the downloaded schema file; entries are indexed positionally below
# (native_functions[0]) and each entry is read via its "func" key.
# NOTE(review): this file was fetched over the network. yaml.safe_load would
# presumably be sufficient for a plain-data file like this and avoids the
# extra tag handling FullLoader enables — confirm before switching.
with open(r"native_functions.yaml") as file:
    native_functions = yaml.load(file, Loader=yaml.FullLoader)

In [5]:
# example structure
# - func: add_relu.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
#   variants: function
#   dispatch:
#     CPU: add_relu_out

def process_yaml_entry(native_func):
    """Split one schema entry into ``(func_name, args, return_type)`` strings.

    ``native_func["func"]`` is expected to look like
    ``name(arg, arg, ...) -> return``. The name is everything before the first
    ``(``; the return type is everything after the last ``->``; the argument
    text is what lies between, with every parenthesis removed (so alias
    annotations such as ``Tensor(a!)`` become ``Tensora!``).
    """
    schema = native_func["func"]
    around_arrow = schema.split("->")

    result_type = around_arrow[-1].strip()
    name = schema.split("(")[0].strip()

    # Everything after the first "(" on the left-hand side of the arrow.
    after_first_paren = around_arrow[0].strip().partition("(")[2]
    arg_text = after_first_paren.replace("(", "").strip().replace(")", "")

    return name, arg_text, result_type

# Sanity check: show the first raw entry followed by its parsed
# (func_name, args, return_type) triple.
print(native_functions[0])
print(process_yaml_entry(native_functions[0]))

{'func': '_cast_Byte(Tensor self, bool non_blocking=False) -> Tensor', 'use_c10_dispatcher': 'full', 'variants': 'function'}
('_cast_Byte', 'Tensor self, bool non_blocking=False', 'Tensor')


In [6]:
def return_type_to_python(return_type):
    """Translate a schema return-type string into a list of Python type names.

    Args:
        return_type: The text after ``->`` in a schema, e.g. ``"Tensor"``,
            ``"()"`` or ``"(Tensor(a!) values, Tensor(b!) indices)"``.

    Returns:
        A list with one string per returned value. ``"()"`` yields
        ``["None"]``. A single return type also yields a one-element list,
        never a bare string.
    """
    return_type = return_type.strip()
    if return_type == "()":
        return ["None"]

    # Drop tuple parentheses and alias annotations such as "(a!)".
    flattened = return_type.replace("(", "").replace(")", "")

    python_types = []
    for raw in flattened.split(","):
        is_list = "[]" in raw
        # Normalise the element type on its own so that list markers do not
        # leak into the result ("int[]" -> "List[int]", not "List[int[]]").
        element = raw.strip().replace("[]", "")

        if element.startswith("Tensor"):
            # Also discards any return name, e.g. "Tensor values".
            element = "Tensor"
        elif element == "Scalar":
            # according to the docs Scalar is any kind of numeric in python or a unit tensor
            element = "Union[int, float, complex, Tensor]"
        elif element == "ScalarType":
            # it seems like these are the pytorch dtype types
            element = "torch.dtype"

        python_types.append(f"List[{element}]" if is_list else element)

    return python_types
    
return_type_to_python("(Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)")

['Tensor', 'Tensor']

In [7]:
def signature_type_to_python(signature):
    """Translate a schema argument list into Python-style argument strings.

    Args:
        signature: Comma-separated schema arguments, e.g.
            ``"Scalar end, *, ScalarType? dtype=None"``. Parentheses are
            stripped first, so ``Tensor(a!)`` is treated like ``Tensor``.

    Returns:
        A list of strings shaped like ``"name: type"`` or
        ``"name: type = default"`` (just ``"type"`` when the argument has no
        name). The keyword-only marker ``*`` is dropped. An empty signature,
        or one with nothing left after dropping ``*``, yields ``["None"]``.
    """
    signature = signature.strip().replace("(", "").replace(")", "")
    if signature == "":
        return ["None"]

    python_args = []
    # default int[] value [0,1] should not add space after comma, since native_parse.py uses ', ' to split args
    for part in signature.split(", "):
        declaration = part.strip()
        if declaration == "*":
            # all args after this must be kwargs; not tracked right now
            continue

        tokens = declaration.split(" ")
        type_token = tokens[0]

        name = None
        default = None
        if len(tokens) > 1:
            # Split on the first "=" only so a default that itself contains
            # "=" is kept whole.
            name, has_default, default_text = tokens[1].partition("=")
            if has_default:
                default = default_text

        # Map the element type first, then wrap it. Wrapping used to split the
        # already-mapped type on "[", which turned "Scalar[]" into
        # "List[Union]".
        if type_token.startswith("Tensor"):
            python_type = "Tensor"
        elif type_token.startswith("ScalarType"):
            python_type = "torch.dtype"
        elif type_token.startswith("Scalar"):
            python_type = "Union[int, float, complex, Tensor]"
        else:
            # Drop any "[...]" size suffix; a trailing "?" is removed below.
            python_type = type_token.split("[")[0]

        if "[" in type_token:
            python_type = f"List[{python_type}]"

        if "?" in part:
            python_type = f"Optional[{python_type}]".replace("?", "")

        rendered = python_type
        if name is not None:
            rendered = f"{name}: {rendered}"
        if default is not None:
            rendered = f"{rendered} = {default}"

        python_args.append(rendered)

    if len(python_args) == 0:
        return ["None"]
    return python_args

signature_type_to_python( 'Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None')

['end: Union[int, float, complex, Tensor]',
 'dtype: Optional[torch.dtype] = None',
 'layout: Optional[Layout] = None',
 'device: Optional[Device] = None',
 'pin_memory: Optional[bool] = None']

In [8]:
def process_func_yaml(native_functions):
    """Convert every schema entry and collect the raw strings seen along the way.

    Returns a 3-tuple:
      * dict mapping each function name to ``{"args": [...], "return_type": [...]}``
        as produced by ``signature_type_to_python`` / ``return_type_to_python``;
      * set of the raw return-type strings;
      * set of the raw argument strings.

    A repeated function name is reported with a printed message and the later
    entry replaces the earlier one.
    """
    converted = {}
    raw_return_types, raw_arg_strings = set(), set()

    for entry in native_functions:
        func_name, args, return_type = process_yaml_entry(entry)
        raw_return_types.add(return_type)
        raw_arg_strings.add(args)

        if func_name in converted:
            print(f"Error duplicate func_name: {func_name}")

        python_args = signature_type_to_python(args)
        python_returns = return_type_to_python(return_type)
        converted[func_name] = dict(args=python_args, return_type=python_returns)

    return converted, raw_return_types, raw_arg_strings

In [9]:
# Run the conversion over the whole file; the printed number is how many
# distinct function names ended up in the mapping.
native_func_dict, unique_return_types, unique_function_args = process_func_yaml(native_functions)
print(len(native_func_dict))

1527


In [10]:
# Spot-check one entry (keys are the text before "(" in the schema, so they
# keep the ".Tensor" overload suffix), then write the full mapping to
# torch_types.json in the working directory.
print(native_func_dict["div_.Tensor"])

type_json = json.dumps(native_func_dict)
with open("torch_types.json", "w") as f:
    f.write(type_json)

{'args': ['self: Tensor', 'other: Tensor'], 'return_type': ['Tensor']}


In [11]:
unique_return_types

{'()',
 '(Tensor Q, Tensor R)',
 '(Tensor U, Tensor S, Tensor V)',
 '(Tensor a, Tensor tau)',
 '(Tensor eigenvalues, Tensor eigenvectors)',
 '(Tensor grad_input, Tensor grad_weight)',
 '(Tensor grad_input, Tensor grad_weight, Tensor grad_bias)',
 '(Tensor grad_self, Tensor grad_grid)',
 '(Tensor output, Tensor buffer)',
 '(Tensor output, Tensor finput, Tensor fgrad_input)',
 '(Tensor output, Tensor is_target)',
 '(Tensor output, Tensor total_weight)',
 '(Tensor sign, Tensor logabsdet)',
 '(Tensor solution, Tensor LU)',
 '(Tensor solution, Tensor QR)',
 '(Tensor solution, Tensor cloned_coefficient)',
 '(Tensor values, Tensor indices)',
 '(Tensor(a!) Q, Tensor(b!) R)',
 '(Tensor(a!) U, Tensor(b!) S, Tensor(c!) V)',
 '(Tensor(a!) a, Tensor(b!) tau)',
 '(Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)',
 '(Tensor(a!) solution, Tensor(b!) LU)',
 '(Tensor(a!) solution, Tensor(b!) QR)',
 '(Tensor(a!) solution, Tensor(b!) cloned_coefficient)',
 '(Tensor(a!) values, Tensor(b!) indices)',
 '(Te

In [12]:
# Convert every raw return-type string and de-duplicate the results.
# return_type_to_python returns lists, which are unhashable, so each one is
# stored as a tuple.
unique_python_return_types = set()
for line in [return_type_to_python(t) for t in unique_return_types]:
    unique_python_return_types.add(tuple(line))

unique_python_return_types

{('List[Tensor]',),
 ('None',),
 ('QScheme',),
 ('Tensor',),
 ('Tensor', 'Tensor'),
 ('Tensor', 'Tensor', 'Tensor'),
 ('Tensor', 'Tensor', 'Tensor', 'List[Tensor]'),
 ('Tensor', 'Tensor', 'Tensor', 'Tensor'),
 ('Tensor', 'Tensor', 'Tensor', 'Tensor', 'Tensor'),
 ('Tensor', 'Tensor', 'Tensor', 'Tensor', 'int'),
 ('Tensor', 'Tensor', 'float', 'int'),
 ('Union[int, float, complex, Tensor]',),
 ('bool',),
 ('float',),
 ('float', 'int'),
 ('int',),
 ('torch.dtype',)}

In [13]:
list(unique_function_args)[0:10]

['',
 'Tensora! self, Tensor end, Tensor weight',
 'Tensor grad_output, Tensor self, Tensor target, int reduction',
 'Tensor self, int dim, *, ScalarType? dtype=None',
 'Tensor grad, Tensor self, float scale, int zero_point, int quant_min, int quant_max',
 'Tensor self, Tensor target, Scalar p=1, Scalar margin=1, Tensor? weight=None, int reduction=Mean',
 'Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, Tensor mean_dy, Tensor mean_dy_xmu',
 'Tensor self, int[1] output_size',
 'Tensor self, int dim0, int dim1',
 'Tensor grad_output, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensora! grad_input']

In [14]:
# Convert every raw argument string and de-duplicate the results; each
# converted list is stored as a tuple so it can live in a set.
unique_python_arg_signatures = set()
for line in [signature_type_to_python(t) for t in unique_function_args]:
    unique_python_arg_signatures.add(tuple(line))
    
print(len(unique_python_arg_signatures))

# Set iteration order is arbitrary, so this is just a sample of ten.
list(unique_python_arg_signatures)[0:10]

716


[('self: Tensor',
  'eigenvectors: bool = False',
  'upper: bool = True',
  'e: Tensor',
  'V: Tensor'),
 ('self: Tensor',
  'dim: Dimname',
  'keepdim: bool = False',
  'min: Tensor',
  'min_indices: Tensor'),
 ('self: Tensor', 'pivot: bool = True', 'check_errors: bool = True'),
 ('self: Tensor', 'dim: List[int]', 'keepdim: bool = False'),
 ('input: Tensor',),
 ('self: Tensor',
  'output_size: List[int]',
  'kernel_size: List[int]',
  'dilation: List[int]',
  'padding: List[int]',
  'stride: List[int]'),
 ('self: Tensor',
  'dim: int',
  'sorted: bool = True',
  'return_inverse: bool = False',
  'return_counts: bool = False'),
 ('input: Tensor', 'coefficients: Tensor'),
 ('n: int', 'm: int', 'out: Tensor'),
 ('self: Tensor',
  'dim: int',
  'keepdim: bool = False',
  'min: Tensor',
  'min_indices: Tensor')]

In [15]:
# Reduce each "name: type = default" string to just its type so signatures
# that differ only in argument names or defaults collapse together.
arg_signatures_no_names = set()
for sig in unique_python_arg_signatures:
    new_parts = []
    for part in sig:
        # Text before the first "=" drops the default value.
        new_part = part.split("=")[0]
        # Text after the last ":" drops the argument name.
        new_part = new_part.split(":")[-1]
        new_part = new_part.strip()
        new_parts.append(new_part)
    arg_signatures_no_names.add(tuple(new_parts))
    
print(len(arg_signatures_no_names))

# Set iteration order is arbitrary, so this is just a sample of ten.
list(arg_signatures_no_names)[0:10]

452


[('Tensor', 'Tensor', 'bool'),
 ('Tensor', 'Dimname', 'Tensor', 'bool'),
 ('Tensor', 'Tensor', 'Optional[float]', 'Tensor'),
 ('Tensor', 'Tensor', 'Optional[Generator]', 'Tensor'),
 ('Tensor', 'Tensor', 'Tensor', 'Optional[Tensor]', 'int'),
 ('Tensor', 'Tensor', 'Tensor', 'Tensor', 'int'),
 ('Tensor', 'int', 'List[int]', 'Optional[List[Dimname]]'),
 ('Union[int, float, complex, Tensor]', 'Tensor', 'Tensor'),
 ('Tensor', 'Dimname', 'Tensor', 'Tensor'),
 ('Tensor',
  'List[Tensor]',
  'int',
  'Tensor',
  'Tensor',
  'Optional[Tensor]',
  'Tensor',
  'Optional[Tensor]',
  'Optional[Tensor]',
  'Optional[Tensor]',
  'int',
  'int',
  'int',
  'bool',
  'float',
  'bool',
  'bool',
  'List[int]',
  'Optional[Tensor]',
  'Tensor',
  'List[bool]')]

In [16]:
# Gather every distinct type string that appears in any argument position
# or any return position.
all_types = set()
for sigs in arg_signatures_no_names:
    for part in sigs:
        all_types.add(part)
        
for return_types in unique_python_return_types:
    for part in return_types:
        all_types.add(part)
        
print(len(all_types))
all_types

32


{'ConstQuantizerPtr',
 'Device',
 'Dimname',
 'List[Dimname]',
 'List[Tensor]',
 'List[bool]',
 'List[int]',
 'MemoryFormat',
 'None',
 'Optional[Device]',
 'Optional[Generator]',
 'Optional[Layout]',
 'Optional[List[Dimname]]',
 'Optional[List[Tensor]]',
 'Optional[List[float]]',
 'Optional[List[int]]',
 'Optional[MemoryFormat]',
 'Optional[Tensor]',
 'Optional[Union[int, float, complex, Tensor]]',
 'Optional[bool]',
 'Optional[float]',
 'Optional[int]',
 'Optional[torch.dtype]',
 'QScheme',
 'Storage',
 'Tensor',
 'Union[int, float, complex, Tensor]',
 'bool',
 'float',
 'int',
 'str',
 'torch.dtype'}

In [17]:
def full_name_with_qualname(klass: type) -> str:
    """Return ``klass``'s ``__module__`` and ``__qualname__`` joined by a dot."""
    module_name = klass.__module__
    qualified_name = klass.__qualname__
    return "{}.{}".format(module_name, qualified_name)

In [18]:
def whatis(thing, optional_key: Optional[str] = None):
    """Print a short description of ``thing`` and return the tags that apply to it.

    Args:
        thing: Any object.
        optional_key: When given, the printed line is prefixed with
            ``.<optional_key> == ``.

    Returns:
        A list of tag strings:
          * the name of every ``inspect.is*`` predicate listed below that
            returned ``True``;
          * ``"callable"`` if ``callable(thing)``;
          * ``"ispyenum"`` if ``type(thing)`` subclasses ``enum.Enum``;
          * ``"iscenum"`` if ``thing`` exposes ``__members__``;
          * ``"ispybind11"`` if ``"pybind11_type"`` occurs in ``str(type(thing))``;
          * ``"isproperty"`` only when nothing else matched and ``thing`` is
            ``None`` or a bool/int/float/complex/str.

    Side effects:
        Prints the description, the MRO when ``thing`` has an ``mro``
        attribute, and ``__members__`` when present.
    """
    name = getattr(thing, "__name__", "")
    name = f"{thing} {type(thing)} {name}"
    if optional_key is not None:
        name = f".{optional_key} == {name}"

    print(name)

    mro = getattr(thing, "mro", None)
    if mro is not None:
        hierarchy = mro()
        print(f"Hierarchy: {hierarchy}")

    test_types = [
        "ismodule", "isclass", "ismethod", "ismethoddescriptor", "isfunction", "isgeneratorfunction", "isgenerator",
        "isbuiltin", "isroutine", "isdatadescriptor", "isgetsetdescriptor", "ismemberdescriptor"
    ]
    result_ttypes = []
    for ttype in test_types:
        istype = getattr(inspect, ttype)(thing)
        if istype is True:
            result_ttypes.append(ttype)

    if callable(thing):
        result_ttypes.append("callable")

    if issubclass(type(thing), Enum):
        # Previously appended to the undefined name `result_ttype`, which
        # raised NameError for every Python Enum member.
        result_ttypes.append("ispyenum")

    members = getattr(thing, "__members__", None)
    if members is not None:
        print("members: ", members)
        result_ttypes.append("iscenum")

    if "pybind11_type" in str(type(thing)):
        result_ttypes.append("ispybind11")

    if len(result_ttypes) == 0:
        if type(thing) in [type(None), bool, int, float, complex, str]:
            result_ttypes.append("isproperty")

    return result_ttypes

In [19]:
def get_signature(thing):
    """Return ``inspect.getfullargspec`` for ``thing``'s ``__init__``."""
    constructor = thing.__init__
    return inspect.getfullargspec(constructor)

In [20]:
def get_c_class_signature(thing):
    """Probe ``thing``'s constructor by calling it with no arguments.

    * If the call succeeds, return ``"1. <thing>()"``.
    * If it raises with "No constructor defined" in the message, return ``None``.
    * Otherwise return the exception message with every line mentioning
      "TypeError", "incompatible" or "Invoked with:" removed and the result
      stripped of surrounding whitespace.
    """
    noise_markers = ["TypeError", "incompatible", "Invoked with:"]
    try:
        thing()
        return f"1. {thing}()"
    except Exception as error:
        message = str(error)
        if "No constructor defined" in message:
            return None

        kept_lines = [
            line
            for line in message.splitlines()
            if not any(marker in line.strip() for marker in noise_markers)
        ]
        return "".join(f"{line}\n" for line in kept_lines).strip()

In [21]:
def get_native_signature(key):
    """Return the converted argument list stored for ``key`` in ``native_func_dict``.

    Prints a message and returns ``None`` when ``key`` is not present.
    """
    if key not in native_func_dict:
        print(f"Cant find {key} in native_func_dict")
        return None

    entry = native_func_dict[key]
    return entry["args"]

In [22]:
def get_native_return_type(key):
    """Return the converted return-type list stored for ``key`` in ``native_func_dict``.

    Prints a message and returns ``None`` when ``key`` is not present.
    """
    if key not in native_func_dict:
        print(f"Cant find {key} in native_func_dict")
        return None

    entry = native_func_dict[key]
    return entry["return_type"]
    

In [23]:
# Names of builtin routines that had no signature or return type in
# native_func_dict. Module-level and never cleared here, so it accumulates
# across every detect_attrs call (and across re-runs of this cell's callers).
missing_attr_matches = []
def detect_attrs(obj) -> dict:
    """Classify every non-dunder attribute of ``obj``.

    Each attribute is described with ``whatis`` (which also prints it) and,
    depending on the returned tags, recorded as one of:

      * ``{"type": "enum", "members": ...}`` — has ``__members__`` and is not a pybind11 class;
      * ``{"type": "class", "init": ...}`` — ``get_signature`` for Python
        classes, ``get_c_class_signature`` for pybind11 classes;
      * ``{"type": "function", "signature": ..., "return_type": ...}`` —
        builtin routines, looked up by attribute name in ``native_func_dict``;
      * ``{"type": "property", "return_type": type(attr)}`` — plain values.

    The checks run in that order and a later match replaces an earlier one.
    Attributes matching none of them are omitted.

    Each helper is called once per attribute. Earlier versions called them
    twice, which instantiated pybind11 classes twice and repeated the
    "Cant find ..." message.

    Returns:
        dict mapping attribute name to its description dict.

    Side effects:
        Appends to the module-level ``missing_attr_matches`` and prints. Any
        exception raised while handling one attribute is printed and that
        attribute is skipped (best effort, so one odd attribute does not abort
        the scan).
    """
    path_dict = {}

    for key in dir(obj):
        try:
            if key.startswith("__"):
                continue

            attr = getattr(obj, key)
            prop_types = whatis(attr, key)
            is_pybind = "ispybind11" in prop_types

            if "iscenum" in prop_types and not is_pybind:
                path_dict[key] = {"type": "enum", "members": type(attr).__members__}

            if "isclass" in prop_types:
                if is_pybind:
                    init_signature = get_c_class_signature(attr)
                else:
                    init_signature = get_signature(attr)
                path_dict[key] = {"type": "class", "init": init_signature}

            if "isbuiltin" in prop_types and "isroutine" in prop_types:
                signature = get_native_signature(key)
                return_type = get_native_return_type(key)
                if signature is None or return_type is None:
                    missing_attr_matches.append(key)
                path_dict[key] = {
                    "type": "function",
                    "signature": signature,
                    "return_type": return_type
                }

            if "isproperty" in prop_types:
                path_dict[key] = {"type": "property", "return_type": type(attr)}

        except Exception as e:
            print(f"Exception with {key}. {e}")
    return path_dict

attrs = detect_attrs(torch)

.AVG == AggregationType.AVG <class 'torch._C.AggregationType'> 
members:  {'SUM': AggregationType.SUM, 'AVG': AggregationType.AVG}
.AggregationType == <class 'torch._C.AggregationType'> <class 'pybind11_builtins.pybind11_type'> AggregationType
Hierarchy: [<class 'torch._C.AggregationType'>, <class 'pybind11_builtins.pybind11_object'>, <class 'object'>]
members:  {'SUM': AggregationType.SUM, 'AVG': AggregationType.AVG}
.AnyType == <class 'torch._C.AnyType'> <class 'pybind11_builtins.pybind11_type'> AnyType
Hierarchy: [<class 'torch._C.AnyType'>, <class 'torch._C.Type'>, <class 'pybind11_builtins.pybind11_object'>, <class 'object'>]
.Argument == <class 'torch._C.Argument'> <class 'pybind11_builtins.pybind11_type'> Argument
Hierarchy: [<class 'torch._C.Argument'>, <class 'pybind11_builtins.pybind11_object'>, <class 'object'>]
.ArgumentSpec == <class 'torch._C.ArgumentSpec'> <class 'pybind11_builtins.pybind11_type'> ArgumentSpec
Hierarchy: [<class 'torch._C.ArgumentSpec'>, <class 'pybind11

In [24]:
missing_attr_matches

['add',
 'as_tensor',
 'autocast_decrement_nesting',
 'autocast_increment_nesting',
 'bitwise_and',
 'bitwise_or',
 'bitwise_xor',
 'bucketize',
 'clear_autocast_cache',
 'conv_transpose2d',
 'conv_transpose3d',
 'ctc_loss',
 'dequantize',
 'div',
 'dsmm',
 'empty',
 'eq',
 'fill_',
 'flatten',
 'fmod',
 'fork',
 'from_numpy',
 'ge',
 'get_default_dtype',
 'get_device',
 'get_num_interop_threads',
 'get_num_threads',
 'gru',
 'gt',
 'hsmm',
 'import_ir_module',
 'import_ir_module_from_buffer',
 'index_fill',
 'init_num_threads',
 'is_anomaly_enabled',
 'is_autocast_enabled',
 'is_grad_enabled',
 'le',
 'lerp',
 'log_softmax',
 'lstm',
 'lt',
 'masked_fill',
 'merge_type_from_type_comment',
 'mul',
 'ne',
 'normal',
 'numel',
 'parse_ir',
 'parse_schema',
 'parse_type_comment',
 'pow',
 'quantized_gru',
 'quantized_lstm',
 'remainder',
 'repeat_interleave',
 'result_type',
 'rnn_relu',
 'rnn_tanh',
 'rsub',
 'saddmm',
 'scatter',
 'searchsorted',
 'select',
 'set_anomaly_enabled',
 'set

In [25]:
def print_attrs(attrs, key: Optional[str] = None):
    """Print one summary line per entry of a ``detect_attrs`` result.

    Args:
        attrs: Mapping of attribute name to a description dict with a
            ``"type"`` key of ``"enum"``, ``"class"``, ``"function"`` or
            ``"property"`` plus the fields that type carries.
        key: Optional prefix inserted between ``torch.`` and the attribute name.

    Each line has the form ``torch.<path> - <type> (<details>)``. An
    unrecognised type prints ``unknown`` as its details. Previously the
    details variable was not reset per entry, so such an entry either showed
    the previous entry's details or raised ``UnboundLocalError``.
    """
    for attr, meta in attrs.items():
        kind = meta["type"]
        if kind == "enum":
            meta_info = meta["members"]
        elif kind == "class":
            meta_info = meta["init"]
        elif kind == "function":
            sig = meta["signature"]
            return_type = meta["return_type"]
            meta_info = f"{sig} => {return_type}"
        elif kind == "property":
            return_type = meta["return_type"]
            meta_info = f"=> {return_type}"
        else:
            meta_info = "unknown"

        path = attr if key is None else f"{key}.{attr}"
        print(f"torch.{path} - {kind} ({meta_info})")

In [26]:
#ConstQuantizerPtr
# Device
# Dimname
# DONE Generator
# Layout
# MemoryFormat
# QScheme
# torch.Storage
# torch.dtype

In [27]:
# Run the attribute scan on a torch.device instance and print the result
# under the "device" prefix.
a = torch.device('cuda', 1)
d = detect_attrs(a)
print_attrs(d, "device")

.index == 1 <class 'int'> 
.type == cuda <class 'str'> 
torch.device.index - property (=> <class 'int'>)
torch.device.type - property (=> <class 'str'>)


In [None]:
# playground

In [None]:
# NOTE(review): the issubclass(...) result is discarded — it is not the last
# expression in the cell, so nothing is displayed for it. It also tests the
# type of the class object itself rather than the class.
issubclass(type(torch.AggregationType), Enum)
print(torch.AggregationType)
print(torch.AggregationType.SUM)

In [None]:
print(torch.Tensor.set_.__doc__)
print(get_native_signature("set_"))
print(get_native_return_type("set_"))

In [None]:
whatis(torch.SUM)
torch.SUM.__members__

In [None]:
a = torch._C.AggregationType(0)
print(a)

In [None]:
print(torch.AVG.__members__)

In [None]:
whatis(torch._C.AggregationType)

In [None]:
print(torch.AVG.__members__)
# print(torch.AVG.mro())

#callable(torch.AVG)
#inspect.getmro(torch.AVG)
# NOTE(review): only the last expression of a cell is displayed, so the
# id(torch.AVG) value is computed and dropped; only id(torch.SUM) is shown.
id(torch.AVG)
id(torch.SUM)

In [None]:
# Print the result of each inspect predicate for the `index` attribute of a
# CPU torch.device.
a = torch.device('cpu')
thing = getattr(a, "index")
print(type(thing))
print(f"ismodule: {inspect.ismodule(thing)}")
print(f"isclass: {inspect.isclass(thing)}")
print(f"ismethod: {inspect.ismethod(thing)}")
print(f"ismethoddescriptor: {inspect.ismethoddescriptor(thing)}")
print(f"isfunction: {inspect.isfunction(thing)}")
print(f"isgeneratorfunction: {inspect.isgeneratorfunction(thing)}")
print(f"isgenerator: {inspect.isgenerator(thing)}")
print(f"isbuiltin: {inspect.isbuiltin(thing)}")
print(f"isroutine: {inspect.isroutine(thing)}")
# The three lines below previously all called inspect.isroutine, so each
# label reported the isroutine result instead of its own predicate.
print(f"isdatadescriptor: {inspect.isdatadescriptor(thing)}")
print(f"isgetsetdescriptor: {inspect.isgetsetdescriptor(thing)}")
print(f"ismemberdescriptor: {inspect.ismemberdescriptor(thing)}")

In [None]:
def compare_class_w_instance(klass, inst):
    """Print the set of ``dir()`` names present on ``klass`` but absent from ``inst``."""
    class_names = set(dir(klass))
    instance_names = set(dir(inst))
    class_only = class_names - instance_names

    print(class_only)

In [None]:
compare_class_w_instance(torch.Tensor, torch.Tensor([1, 2, 3]))