Skip to content

Commit

Permalink
Standalone: Added support for "tensorflow.function" JIT
Browse files Browse the repository at this point in the history
* This adds support to preserve the source code of
  decorated functions and provide it at runtime to
  tensorflow so it can do its tracing.

* The code generation now has  a way of providing
  module level init codes.

* This mechanism should be possible to generalize
  to other JIT making modules as well.

* Without this, some codes using tensorflow.function
  could totally miss out on the specialization it does,
  that compilation with Nuitka does not currently
  replace.
  • Loading branch information
kayhayen committed Apr 30, 2024
1 parent 0a9cedb commit 74d167b
Show file tree
Hide file tree
Showing 25 changed files with 1,121 additions and 111 deletions.
4 changes: 3 additions & 1 deletion nuitka/HardImportRegistry.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
"pkg_resources",
"importlib_metadata",
"importlib_resources",
"tensorflow",
)
)

Expand Down Expand Up @@ -109,7 +110,7 @@ def isHardModule(module_name):


# These modules can cause issues if imported during compile time.
hard_modules_trust_with_side_effects = set(["site"])
hard_modules_trust_with_side_effects = set(["site", "tensorflow"])
if not isWin32Windows():
# Crashing on anything but Windows.
hard_modules_trust_with_side_effects.add("ctypes.wintypes")
Expand Down Expand Up @@ -264,6 +265,7 @@ def isHardModuleWithoutSideEffect(module_name):
"ctypes.wintypes": {},
"ctypes.macholib": {},
"builtins": module_builtins_trust,
"tensorflow": {"function": trust_node},
}


Expand Down
3 changes: 3 additions & 0 deletions nuitka/build/include/nuitka/helper/import_hard.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ extern PyObject *IMPORT_HARD_SYS(void);
/* C helper for hard import of module "sysconfig" import. */
extern PyObject *IMPORT_HARD_SYSCONFIG(void);

/* C helper for hard import of module "tensorflow" import. */
extern PyObject *IMPORT_HARD_TENSORFLOW(void);

/* C helper for hard import of module "types" import. */
extern PyObject *IMPORT_HARD_TYPES(void);

Expand Down
25 changes: 25 additions & 0 deletions nuitka/build/include/nuitka/jit_sources.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright 2024, Kay Hayen, mailto:kay.hayen@gmail.com find license text at end of file

#ifndef __NUITKA_JIT_SOURCES_H__
#define __NUITKA_JIT_SOURCES_H__

// Helpers for making source available at run-time for JIT systems
// outside of Nuitka that want it.

extern void SET_UNCOMPILED_FUNCTION_SOURCE_DICT(PyObject *name, PyObject *source);

#endif
// Part of "Nuitka", an optimizing Python compiler that is compatible and
// integrates with CPython, but also works on its own.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
2 changes: 2 additions & 0 deletions nuitka/build/include/nuitka/prelude.h
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,8 @@ extern PyObject *Nuitka_dunder_compiled_value;
#include "nuitka/filesystem_paths.h"
#include "nuitka/safe_string_ops.h"

#include "nuitka/jit_sources.h"

#if _NUITKA_EXPERIMENTAL_WRITEABLE_CONSTANTS
#include "nuitka_data_decoder.h"
#else
Expand Down
2 changes: 2 additions & 0 deletions nuitka/build/static_src/CompiledCodeHelpers.c
Original file line number Diff line number Diff line change
Expand Up @@ -2109,6 +2109,8 @@ PyObject *MAKE_UNION_TYPE(PyObject *args) {
#include "HelpersDumpBacktraces.c"
#endif

#include "HelpersJitSources.c"

// Part of "Nuitka", an optimizing Python compiler that is compatible and
// integrates with CPython, but also works on its own.
//
Expand Down
15 changes: 15 additions & 0 deletions nuitka/build/static_src/HelpersImportHard.c
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,21 @@ PyObject *IMPORT_HARD_SYSCONFIG(void) {
return module_import_hard_sysconfig;
}

/* C helper for hard import of module "tensorflow" import. */
PyObject *IMPORT_HARD_TENSORFLOW(void) {
static PyObject *module_import_hard_tensorflow = NULL;

if (module_import_hard_tensorflow == NULL) {
module_import_hard_tensorflow = PyImport_ImportModule("tensorflow");

if (unlikely(module_import_hard_tensorflow == NULL)) {
return NULL;
}
}

return module_import_hard_tensorflow;
}

/* C helper for hard import of module "types" import. */
PyObject *IMPORT_HARD_TYPES(void) {
static PyObject *module_import_hard_types = NULL;
Expand Down
46 changes: 46 additions & 0 deletions nuitka/build/static_src/HelpersJitSources.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright 2024, Kay Hayen, mailto:kay.hayen@gmail.com find license text at end of file

// This file is included from another C file, help IDEs to still parse it on
// its own.
#ifdef __IDE_ONLY__
#include "nuitka/prelude.h"
#endif

#ifdef _NUITKA_STANDALONE

static char const *uncompiled_sources_dict_attribute_name = "_uncompiled_function_sources_dict";

void SET_UNCOMPILED_FUNCTION_SOURCE_DICT(PyObject *name, PyObject *source) {
PyObject *uncompiled_function_sources_dict =
PyObject_GetAttrString((PyObject *)builtin_module, uncompiled_sources_dict_attribute_name);

if (uncompiled_function_sources_dict == NULL) {
PyThreadState *tstate = PyThreadState_GET();

DROP_ERROR_OCCURRED(tstate);

uncompiled_function_sources_dict = MAKE_DICT_EMPTY();

PyObject_SetAttrString((PyObject *)builtin_module, uncompiled_sources_dict_attribute_name,
uncompiled_function_sources_dict);
}

bool res = DICT_SET_ITEM(uncompiled_function_sources_dict, name, source);
assert(res == false);
}

#endif
// Part of "Nuitka", an optimizing Python compiler that is compatible and
// integrates with CPython, but also works on its own.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
4 changes: 4 additions & 0 deletions nuitka/code_generation/CodeGeneration.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@
generateSubscriptCheckCode,
generateSubscriptLookupCode,
)
from .TensorflowCodes import generateTensorflowFunctionCallCode
from .TryCodes import generateTryCode
from .TupleCodes import generateBuiltinTupleCode, generateTupleCreationCode
from .VariableCodes import (
Expand Down Expand Up @@ -943,6 +944,9 @@ def generateHelpersCode():
# TODO: Should have all of these generically or not. This one is required for now.
"EXPRESSION_DICT_OPERATION_FROMKEYS_REF": generateDictOperationFromkeysRefCode,
"EXPRESSION_TYPE_OPERATION_PREPARE": generateTypeOperationPrepareCode,
# PyPI module "tensorflow" specific stuff
"EXPRESSION_TENSORFLOW_FUNCTION_REF": generateImportModuleNameHardCode,
"EXPRESSION_TENSORFLOW_FUNCTION_CALL": generateTensorflowFunctionCallCode,
}
)

Expand Down
16 changes: 16 additions & 0 deletions nuitka/code_generation/Contexts.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,10 @@ def getInplaceLeftName(self):
def getConstantCode(self, constant, deep_check=False):
pass

@abstractmethod
def addModuleInitCode(self, code):
pass

@abstractmethod
def getModuleCodeName(self):
pass
Expand Down Expand Up @@ -527,6 +531,9 @@ def __init__(self, parent):
def getConstantCode(self, constant, deep_check=False):
return self.parent.getConstantCode(constant, deep_check=deep_check)

def addModuleInitCode(self, code):
self.parent.addModuleInitCode(code)

def getModuleCodeName(self):
return self.parent.getModuleCodeName()

Expand Down Expand Up @@ -764,6 +771,7 @@ class PythonModuleContext(
"variable_storage",
"function_table_entries",
"constant_accessor",
"module_init_codes",
# FrameDeclarationsMixin
"frame_variables_stack",
"frame_type_descriptions",
Expand Down Expand Up @@ -820,6 +828,8 @@ def __init__(self, module, data_filename):
top_level_name="mod_consts", data_filename=data_filename
)

self.module_init_codes = []

def __repr__(self):
return "<PythonModuleContext instance for module %s>" % self.name

Expand Down Expand Up @@ -879,6 +889,12 @@ def getConstantCode(self, constant, deep_check=False):
def getConstantsCount(self):
return self.constant_accessor.getConstantsCount()

def getModuleInitCodes(self):
return self.module_init_codes

def addModuleInitCode(self, code):
self.module_init_codes.append(code)

def addFunctionCreationInfo(self, creation_info):
self.function_table_entries.append(creation_info)

Expand Down
2 changes: 2 additions & 0 deletions nuitka/code_generation/ImportCodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,8 @@ def getImportHardModuleGetterCode(module_name, context):
def getImportModuleNameHardCode(
to_name, module_name, import_name, needs_check, emit, context
):
module_name = ModuleName(module_name)

if module_name == "sys":
emit("""%s = Nuitka_SysGetObject("%s");""" % (to_name, import_name))
needs_release = False
Expand Down
3 changes: 2 additions & 1 deletion nuitka/code_generation/ModuleCodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ def getModuleCode(
"module_functions_code": function_body_codes,
"module_function_table_entries": indented(function_table_entries_decl),
"temps_decl": indented(local_var_inits),
"module_code": indented(module_codes.codes),
"module_init_codes": indented(context.getModuleInitCodes()),
"module_codes": indented(module_codes.codes),
"module_exit": module_exit,
"module_code_objects_decl": indented(module_code_objects_decl, 0),
"module_code_objects_init": indented(module_code_objects_init, 1),
Expand Down
72 changes: 72 additions & 0 deletions nuitka/code_generation/TensorflowCodes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright 2024, Kay Hayen, mailto:kay.hayen@gmail.com find license text at end of file


""" Code generation for tensorflow module specific stuff. """

from nuitka.Options import isStandaloneMode

from .BuiltinCodes import getBuiltinCallViaSpecCode
from .ImportCodes import getImportModuleNameHardCode


def generateTensorflowFunctionCallCode(to_name, expression, emit, context):
"""This is for tensorflow.function calls."""

# TODO: Have global cached forms of hard attribute lookup results too.
tensorflow_function_name = context.allocateTempName(
"tensorflow_function", unique=True
)

getImportModuleNameHardCode(
to_name=tensorflow_function_name,
module_name="tensorflow",
import_name="function",
needs_check=False,
emit=emit,
context=context,
)

# Include source code of "tensorflow.function" decorated functions.
if expression.subnode_func is not None and isStandaloneMode():
func_value = expression.subnode_func

if func_value.isExpressionFunctionCreation():
function_ref = func_value.subnode_function_ref

function_super_qualified_name = function_ref.getFunctionSuperQualifiedName()
function_source_code = function_ref.getFunctionSourceCode()

context.addModuleInitCode(
"""\
SET_UNCOMPILED_FUNCTION_SOURCE_DICT(%s, %s);
"""
% (
context.getConstantCode(function_super_qualified_name),
context.getConstantCode(function_source_code),
)
)

getBuiltinCallViaSpecCode(
spec=expression.spec,
called_name=tensorflow_function_name,
to_name=to_name,
expression=expression,
emit=emit,
context=context,
)


# Part of "Nuitka", an optimizing Python compiler that is compatible and
# integrates with CPython, but also works on its own.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
5 changes: 4 additions & 1 deletion nuitka/code_generation/templates/CodeTemplatesModules.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,8 +451,11 @@
// Temp variables if any
%(temps_decl)s
// Module init code if any
%(module_init_codes)s
// Module code.
%(module_code)s
%(module_codes)s
// Report to PGO about leaving the module without error.
PGO_onModuleExit("%(module_identifier)s", false);
Expand Down
36 changes: 19 additions & 17 deletions nuitka/nodes/AttributeNodesGenerated.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,33 @@


# We are not avoiding these in generated code at all
# pylint: disable=I0021,too-many-lines
# pylint: disable=I0021,line-too-long
# pylint: disable=I0021,too-many-instance-attributes
# pylint: disable=I0021,too-many-return-statements
# pylint: disable=I0021,line-too-long,too-many-instance-attributes,too-many-lines
# pylint: disable=I0021,too-many-arguments,too-many-return-statements,too-many-statements


"""Specialized attribute nodes
WARNING, this code is GENERATED. Modify the template AttributeNodeFixed.py.j2 instead!
spell-checker: ignore __prepare__ append args buffering capitalize casefold center chars
spell-checker: ignore clear closefd copy count decode default delete dist
spell-checker: ignore __prepare__ append args autograph buffering capitalize casefold
spell-checker: ignore center chars clear closefd copy count decode default delete dist
spell-checker: ignore distribution_name encode encoding end endswith errors exit_code
spell-checker: ignore expandtabs extend file fillchar find format format_map formatmap
spell-checker: ignore fromkeys get group handle has_key haskey index insert isalnum
spell-checker: ignore isalpha isascii isdecimal isdigit isidentifier islower isnumeric
spell-checker: ignore isprintable isspace istitle isupper item items iterable iteritems
spell-checker: ignore iterkeys itervalues join keepends key keys kwargs ljust lower lstrip
spell-checker: ignore maketrans maxsplit mode name new newline old opener p package
spell-checker: ignore expandtabs experimental_attributes experimental_autograph_options
spell-checker: ignore experimental_compile experimental_follow_type_hints
spell-checker: ignore experimental_implements experimental_relax_shapes extend file
spell-checker: ignore fillchar find format format_map formatmap fromkeys func get group
spell-checker: ignore handle has_key haskey index input_signature insert isalnum isalpha
spell-checker: ignore isascii isdecimal isdigit isidentifier islower isnumeric isprintable
spell-checker: ignore isspace istitle isupper item items iterable iteritems iterkeys
spell-checker: ignore itervalues jit_compile join keepends key keys kwargs ljust lower
spell-checker: ignore lstrip maketrans maxsplit mode name new newline old opener p package
spell-checker: ignore package_or_requirement pairs partition path pop popitem prefix
spell-checker: ignore prepare remove replace resource resource_name reverse rfind rindex
spell-checker: ignore rjust rpartition rsplit rstrip s sep setdefault sort split
spell-checker: ignore splitlines start startswith stop strip sub suffix swapcase table
spell-checker: ignore tabsize title translate update upper use_errno use_last_error value
spell-checker: ignore values viewitems viewkeys viewvalues width winmode zfill
spell-checker: ignore prepare reduce_retracing remove replace resource resource_name
spell-checker: ignore reverse rfind rindex rjust rpartition rsplit rstrip s sep setdefault
spell-checker: ignore sort split splitlines start startswith stop strip sub suffix
spell-checker: ignore swapcase table tabsize title translate update upper use_errno
spell-checker: ignore use_last_error value values viewitems viewkeys viewvalues width
spell-checker: ignore winmode zfill
"""


Expand Down
Loading

0 comments on commit 74d167b

Please sign in to comment.