In [71]:
from keras.models import load_model
import keras
import tensorflow as tf
import numpy as np


In [72]:
%run ../../layers/dimension_utils

In [73]:
def np2tf(x):
    """Convert a NumPy array into a float32 TensorFlow tensor.

    Args:
        x: Any value. Only ``np.ndarray`` instances are converted.

    Returns:
        A ``(value, untouched)`` pair. ``untouched`` is ``False`` when ``x``
        was an ndarray and has been converted, and ``True`` when ``x`` was
        returned exactly as it came in.
    """
    # Guard clause: anything that is not an ndarray passes straight through.
    if not isinstance(x, np.ndarray):
        return x, True
    return tf.convert_to_tensor(x, dtype=tf.float32), False


def match_tensor(x1: "tf.Tensor | np.ndarray", x2: "tf.Tensor | np.ndarray"):
    """Align a NumPy weight operand with the layout of a graph tensor.

    Both operands are first passed through ``np2tf``. If neither of them was a
    NumPy array they are returned untouched. Otherwise the operand treated as
    "weights" is left-padded with singleton axes up to the rank of the other
    operand and then transposed with the permutation produced by
    ``shape_NCD_to_NDC_format``.

    Args:
        x1: First operand (tensor-like object or ``np.ndarray``).
        x2: Second operand (tensor-like object or ``np.ndarray``).

    Returns:
        ``(x1, x2)`` in the same order as the arguments, with the weight
        operand converted, padded and transposed.

    Notes:
        * When both operands are NumPy arrays, only ``x2`` is padded and
          transposed; ``x1`` is merely converted to a tensor.
        * ``shape_NCD_to_NDC_format`` comes from ``dimension_utils`` and is
          presumably a channel-first -> channel-last axis permutation
          (TODO confirm against that module).
    """
    x1, f1 = np2tf(x1)
    x2, f2 = np2tf(x2)

    # No need to transpose if both operands are already tensors; we assume
    # tensors were computed by the graph.
    if f1 and f2:
        return x1, x2

    # Ensure the tensor is bound to x1 and the weights to x2.
    if f2:
        x1, x2 = x2, x1

    # Left-pad the weights with singleton axes until the ranks agree.
    # len(shape) is used instead of the deprecated, TF-only `shape.ndims`, so
    # operands whose shape is a plain tuple are handled as well.
    target_rank = len(x1.shape)
    while len(x2.shape) < target_rank:
        x2 = tf.expand_dims(x2, axis=0)

    new_shape = shape_NCD_to_NDC_format(list(range(len(x2.shape))))
    x2 = tf.transpose(x2, new_shape)

    # Undo the swap so the caller receives the operands in the original order.
    return (x2, x1) if f2 else (x1, x2)

In [79]:
@keras.saving.register_keras_serializable()
class TFAdd(keras.layers.Layer):
    """Keras layer that adds its first two call arguments element-wise.

    The constructor keeps the converter's node description (graph tensors,
    weights, input names, attributes) so that ``get_config`` can serialize it.

    Args:
        tensor_grap: Mapping from input name to graph value.
        node_weights: Mapping from input name to constant weights; used when a
            name is not present in ``tensor_grap``.
        node_inputs: Sequence of input names; the first two are the operands.
        node_attribute: Node attributes, stored for serialization only.
        *args: Ignored.
        **kwargs: ``name``, ``trainable`` and ``dtype`` are forwarded to the
            Keras base layer; every other key is ignored.
    """

    def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs):
        # Forward only the options the base layer understands so that a layer
        # rebuilt from get_config() keeps its saved name/dtype/trainable flag.
        # The remaining keys (e.g. "first_operand"/"second_operand" written by
        # get_config) stay ignored, exactly as before.
        base_options = {key: kwargs[key] for key in ("name", "trainable", "dtype") if key in kwargs}
        super().__init__(**base_options)
        self.tensor_grap = tensor_grap
        self.node_weights = node_weights
        self.node_inputs = node_inputs
        self.node_attribute = node_attribute

        def resolve(key):
            # Graph tensors take precedence over constant weights.
            return tensor_grap[key] if key in tensor_grap else node_weights[key]

        self.first_operand, self.second_operand = match_tensor(
            resolve(node_inputs[0]), resolve(node_inputs[1])
        )

    def call(self, *args, **kwargs):
        """Return ``args[0] + args[1]``; the stored operands are not used here."""
        return keras.ops.add(args[0], args[1])

    def get_config(self):
        """Return the base layer config extended with the node description."""
        config = super().get_config()
        config.update({
            "tensor_grap": self.tensor_grap,
            'node_weights': self.node_weights,
            'node_inputs': self.node_inputs,
            "first_operand": self.first_operand,
            "second_operand": self.second_operand,
            'node_attribute': self.node_attribute
        })
        return config


class TFSub(keras.layers.Layer):
    """Keras layer that subtracts its second call argument from the first.

    The constructor keeps the converter's node description (graph tensors,
    weights, input names, attributes) so that ``get_config`` can serialize it.

    Args:
        tensor_grap: Mapping from input name to graph value.
        node_weights: Mapping from input name to constant weights; used when a
            name is not present in ``tensor_grap``.
        node_inputs: Sequence of input names; the first two are the operands.
        node_attribute: Node attributes, stored for serialization only.
        *args: Ignored.
        **kwargs: ``name``, ``trainable`` and ``dtype`` are forwarded to the
            Keras base layer; every other key is ignored.
    """

    def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs):
        # Forward only the options the base layer understands so that a layer
        # rebuilt from get_config() keeps its saved name/dtype/trainable flag.
        # The remaining keys (e.g. "first_operand"/"second_operand" written by
        # get_config) stay ignored, exactly as before.
        base_options = {key: kwargs[key] for key in ("name", "trainable", "dtype") if key in kwargs}
        super().__init__(**base_options)
        self.tensor_grap = tensor_grap
        self.node_weights = node_weights
        self.node_inputs = node_inputs
        self.node_attribute = node_attribute

        def resolve(key):
            # Graph tensors take precedence over constant weights.
            return tensor_grap[key] if key in tensor_grap else node_weights[key]

        self.first_operand, self.second_operand = match_tensor(
            resolve(node_inputs[0]), resolve(node_inputs[1])
        )

    def call(self, *args, **kwargs):
        """Return ``args[0] - args[1]``; the stored operands are not used here."""
        return keras.ops.subtract(args[0], args[1])

    def get_config(self):
        """Return the base layer config extended with the node description."""
        config = super().get_config()
        config.update({
            "tensor_grap": self.tensor_grap,
            'node_weights': self.node_weights,
            'node_inputs': self.node_inputs,
            "first_operand": self.first_operand,
            "second_operand": self.second_operand,
            'node_attribute': self.node_attribute
        })
        return config


class TFEqual(keras.layers.Layer):
    """Keras layer that compares its first two call arguments for equality.

    The constructor keeps the converter's node description (graph tensors,
    weights, input names, attributes) so that ``get_config`` can serialize it.

    Args:
        tensor_grap: Mapping from input name to graph value.
        node_weights: Mapping from input name to constant weights; used when a
            name is not present in ``tensor_grap``.
        node_inputs: Sequence of input names; the first two are the operands.
        node_attribute: Node attributes, stored for serialization only.
        *args: Ignored.
        **kwargs: ``name``, ``trainable`` and ``dtype`` are forwarded to the
            Keras base layer; every other key is ignored.
    """

    def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs):
        # Forward only the options the base layer understands so that a layer
        # rebuilt from get_config() keeps its saved name/dtype/trainable flag.
        # The remaining keys (e.g. "first_operand"/"second_operand" written by
        # get_config) stay ignored, exactly as before.
        base_options = {key: kwargs[key] for key in ("name", "trainable", "dtype") if key in kwargs}
        super().__init__(**base_options)
        self.tensor_grap = tensor_grap
        self.node_weights = node_weights
        self.node_inputs = node_inputs
        self.node_attribute = node_attribute

        def resolve(key):
            # Graph tensors take precedence over constant weights.
            return tensor_grap[key] if key in tensor_grap else node_weights[key]

        self.first_operand, self.second_operand = match_tensor(
            resolve(node_inputs[0]), resolve(node_inputs[1])
        )

    def call(self, *args, **kwargs):
        """Return the element-wise result of ``args[0] == args[1]``."""
        return keras.ops.equal(args[0], args[1])

    def get_config(self):
        """Return the base layer config extended with the node description."""
        config = super().get_config()
        config.update({
            "tensor_grap": self.tensor_grap,
            'node_weights': self.node_weights,
            'node_inputs': self.node_inputs,
            "first_operand": self.first_operand,
            "second_operand": self.second_operand,
            'node_attribute': self.node_attribute
        })
        return config
class TFLog(keras.layers.Layer):
    """Keras layer that applies ``keras.ops.log`` element-wise to its input.

    Args:
        *args: Ignored.
        **kwargs: ``name``, ``trainable`` and ``dtype`` are forwarded to the
            Keras base layer; every other key is ignored.
    """

    def __init__(self, *args, **kwargs):
        # Forward only the options the base layer understands so that a layer
        # rebuilt from its config keeps its saved name/dtype/trainable flag.
        base_options = {key: kwargs[key] for key in ("name", "trainable", "dtype") if key in kwargs}
        super().__init__(**base_options)

    def call(self, inputs, *args, **kwargs):
        """Return ``keras.ops.log(inputs)``; extra arguments are ignored."""
        return keras.ops.log(inputs)

class TFGreater(keras.layers.Layer):
    """Keras layer that evaluates ``args[0] > args[1]`` element-wise.

    Args:
        *args: Ignored.
        **kwargs: ``name``, ``trainable`` and ``dtype`` are forwarded to the
            Keras base layer; every other key is ignored.
    """

    def __init__(self, *args, **kwargs):
        # Forward only the options the base layer understands so that a layer
        # rebuilt from its config keeps its saved name/dtype/trainable flag.
        base_options = {key: kwargs[key] for key in ("name", "trainable", "dtype") if key in kwargs}
        super().__init__(**base_options)

    def call(self, *args, **kwargs):
        """Return ``keras.ops.greater`` of the first two positional inputs."""
        return keras.ops.greater(args[0], args[1])


class TFWhere(keras.layers.Layer):
    """Keras layer that selects between two inputs based on a condition.

    The constructor keeps the converter's node description (graph tensors,
    weights, input names, attributes) so that ``get_config`` can serialize it.

    Args:
        tensor_grap: Mapping from input name to graph value.
        node_weights: Mapping from input name to constant weights; used when a
            name is not present in ``tensor_grap``.
        node_inputs: Sequence of input names; index 1 is the value taken where
            the condition holds, index 2 the value taken elsewhere.
        node_attribute: Node attributes, stored for serialization only.
        *args: Ignored.
        **kwargs: ``name``, ``trainable`` and ``dtype`` are forwarded to the
            Keras base layer; every other key is ignored.
    """

    def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs):
        # Forward only the options the base layer understands so that a layer
        # rebuilt from get_config() keeps its saved name/dtype/trainable flag.
        # The remaining keys (e.g. "true_value"/"false_value" written by
        # get_config) stay ignored, exactly as before.
        base_options = {key: kwargs[key] for key in ("name", "trainable", "dtype") if key in kwargs}
        super().__init__(**base_options)
        self.tensor_grap = tensor_grap
        self.node_weights = node_weights
        self.node_inputs = node_inputs
        self.node_attribute = node_attribute

        def resolve(key):
            # Graph tensors take precedence over constant weights.
            return tensor_grap[key] if key in tensor_grap else node_weights[key]

        self.true_value, self.false_value = match_tensor(
            resolve(node_inputs[1]), resolve(node_inputs[2])
        )

    def call(self, *args, **kwargs):
        """Return ``keras.ops.where(condition, x, y)`` for the three inputs."""
        return keras.ops.where(args[0], args[1], args[2])

    def get_config(self):
        """Return the base layer config extended with the node description."""
        config = super().get_config()
        config.update({
            "tensor_grap": self.tensor_grap,
            'node_weights': self.node_weights,
            'node_inputs': self.node_inputs,
            "true_value": self.true_value,
            "false_value": self.false_value,
            'node_attribute': self.node_attribute
        })
        return config


class TFReduceMin(keras.layers.Layer):
    """Keras layer that reduces its input with ``keras.ops.min``.

    Args:
        tensor_grap: Mapping from input name to graph value. The entry for
            ``node_inputs[0]`` is only used to determine the input rank; it
            may be a serialized dict carrying ``['config']['shape']`` or any
            object exposing a ``shape`` attribute.
        node_weights: Constant weights, stored for serialization only.
        node_inputs: Sequence of input names; the first one is the reduced
            input.
        node_attribute: Node attributes. ``"keepdims"`` (default ``1``) and
            ``"axes"`` (default ``[-1]``) are read.
        *args: Ignored.
        **kwargs: ``name``, ``trainable`` and ``dtype`` are forwarded to the
            Keras base layer; every other key is ignored.
    """

    def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs):
        # Forward only the options the base layer understands so that a layer
        # rebuilt from get_config() keeps its saved name/dtype/trainable flag.
        base_options = {key: kwargs[key] for key in ("name", "trainable", "dtype") if key in kwargs}
        super().__init__(**base_options)
        self.tensor_grap = tensor_grap
        self.node_weights = node_weights
        self.node_inputs = node_inputs
        self.node_attribute = node_attribute

        self.keep_dims = node_attribute.get("keepdims", 1) == 1

        # A layer rebuilt from a saved config receives serialized dicts, which
        # have no `.shape`; a freshly built graph passes tensor-like objects.
        source = tensor_grap[node_inputs[0]]
        if isinstance(source, dict):
            input_shape = source['config']['shape']
        else:
            input_shape = source.shape
        input_shape_len = len(input_shape)

        # Negative axes are made non-negative before being remapped by
        # channel_to_last_dimension (defined in dimension_utils).
        self.axes = [
            channel_to_last_dimension(axis if axis >= 0 else input_shape_len + axis)
            for axis in node_attribute.get("axes", [-1])
        ]

    def call(self, inputs, *args, **kwargs):
        """Return the minimum of ``inputs`` over ``self.axes``."""
        return keras.ops.min(inputs, axis=self.axes, keepdims=self.keep_dims)

    def get_config(self):
        """Return the base layer config extended with the node description."""
        config = super().get_config()
        config.update({
            "tensor_grap": self.tensor_grap,
            'node_weights': self.node_weights,
            'node_inputs': self.node_inputs,
            'node_attribute': self.node_attribute
        })
        return config

    
from keras import backend as K

# Start from a clean slate: reset the Keras session and empty the global
# custom-object registry. This also removes the registration performed by the
# decorator on TFAdd, so every layer is supplied explicitly through the scope
# below.
K.clear_session()
keras.saving.get_custom_objects().clear()

# Each layer is registered under its own class name.
custom_objects = {
    layer_cls.__name__: layer_cls
    for layer_cls in (TFAdd, TFLog, TFGreater, TFWhere, TFSub, TFEqual, TFReduceMin)
}

with keras.saving.custom_object_scope(custom_objects):
    model = load_model("add.keras")
    model.summary()



In [80]:
import torch

# Sample inputs written directly in their final (1, 3, 1) layout.
input1 = torch.tensor([[[10], [40], [50]]], dtype=torch.float32)
input2 = torch.tensor([[[13], [4], [7]]], dtype=torch.float32)
# input3 = torch.tensor([3, 14, 7], dtype = torch.float32).reshape(1,-1,1)

# Only input1 is fed to the loaded model; the result is the cell's output.
model.predict(input1)

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step


array([[10.]], dtype=float32)