In [1]:
import os

os.environ['KERAS_BACKEND'] = "torch"

## Masked Average

In [78]:
import torch
import keras
from layer import MaskedAverage

In [59]:
x_tensor = torch.tensor([
    [100, 1, 10],
    [110, 10, 30],
    [90, 33, 99]
])

In [60]:
x_tensor.shape

torch.Size([3, 3])

In [68]:
masked_impl_avg = MaskedAverage()(input_data=x_tensor, mask=[1, 1, 1])
keras_avg = keras.layers.Average()(inputs=[x_tensor, torch.tensor([1, 1, 1])])

print(f"MaskedAverage: {masked_impl_avg} \nkeras.layers.Average: {keras_avg}")

MaskedAverage: tensor([100.0000,  14.6667,  46.3333]) 
keras.layers.Average: tensor([[50.5000,  1.0000,  5.5000],
        [55.5000,  5.5000, 15.5000],
        [45.5000, 17.0000, 50.0000]])


Average works on wrong axis.

In [71]:
import numpy as np

x = np.array([
    [1., 2., 3.],
    [4., 5., 6.],
    [7., 8., 9.],
    [7., 8., 9.],
    [11., 8., 9.]]
)

In [72]:
MaskedAverage()(x)

tensor([6.0000, 6.2000, 7.2000])

In [73]:
inputs = np.random.random([10, 3]).astype(np.float32)
# We create input and simulate empty rows as could be yielded by our Embedding model on unknown words (Not that any of those exist in corpus)
inputs[3, :] = 0.
inputs[5, :] = 0.

In [74]:
inputs  # 0.4803591

array([[0.720102  , 0.85836864, 0.3445978 ],
       [0.29986498, 0.5420876 , 0.88653666],
       [0.8114074 , 0.52032554, 0.1960794 ],
       [0.        , 0.        , 0.        ],
       [0.3971557 , 0.19923395, 0.11562309],
       [0.        , 0.        , 0.        ],
       [0.22157855, 0.1699425 , 0.39186078],
       [0.80099577, 0.2734881 , 0.918462  ],
       [0.16538976, 0.5264178 , 0.91051924],
       [0.8423247 , 0.5678992 , 0.27526826]], dtype=float32)

In [75]:
x = keras.layers.Masking(mask_value=0.0)(inputs)
res = MaskedAverage()(x)

In [77]:
print(res) # The result is correct wrt to the zero values to skip for the average

tensor([0.5324, 0.4572, 0.5049])


## WeightedSumLayer

In [None]:
# Class Definition
class WeightedSumLayer(keras.layers.Layer):
    def __init__(self, **kwargs):
        super(WeightedSumLayer, self).__init__(**kwargs)
        self.supports_masking = True

    def call(self, inputs):
        x, w = inputs
        w = keras.ops.expand_dims(w, axis=-1)

        weighted_val = w * x
        print(weighted_val)
        return keras.ops.sum(x * w, axis=1)

    def get_output_shape_for(self, input_shape):
        return input_shape[0][0], input_shape[0][-1]

