In [3]:
import tensorflow as tf
import tensorflow.keras.layers as tfl
import numpy as np

In [14]:
class PositionalEncoding(tfl.Layer):
    """Keras layer that adds a fixed sinusoidal positional encoding to its input.

    The table is built once in ``__init__`` with shape ``(1, seq_length, d_model)``:
    even feature columns hold ``sin(pos * rate)`` and odd columns ``cos(pos * rate)``,
    where ``rate = 1 / 10000 ** (2 * (i // 2) / d_model)`` for column ``i``.
    The layer creates no trainable weights.

    Args:
        d_model: size of the feature (last) axis the table is built for.
        seq_length: number of positions in the precomputed table. ``call`` slices
            the table to the input's axis-1 length, so inputs longer than this
            will fail to broadcast against it.
        **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``, ``dtype``);
            this also lets ``from_config`` rebuild the layer from ``get_config()``.
    """

    def __init__(self, d_model: int, seq_length: int, **kwargs):
        super().__init__(**kwargs)
        # Kept so get_config() can serialize the constructor arguments.
        self.d_model = d_model
        self.seq_length = seq_length
        self.pos_encoding = self.positional_encoding(seq_length, d_model)

    def get_angles(self, pos, i, d_model):
        """Return ``pos / 10000 ** (2 * (i // 2) / d_model)``, broadcast over ``pos`` and ``i``.

        Columns ``2k`` and ``2k + 1`` share one rate (via ``i // 2``), so each
        sin/cos column pair uses the same frequency.
        """
        angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))
        return pos * angle_rates

    def positional_encoding(self, seq_length, d_model):
        """Build the ``(1, seq_length, d_model)`` float32 encoding table as a ``tf.Tensor``."""
        angle_rads = self.get_angles(np.arange(seq_length)[:, np.newaxis],  # (seq_length, 1)
                                     np.arange(d_model)[np.newaxis, :],     # (1, d_model)
                                     d_model)                               # -> (seq_length, d_model)

        # sin on even columns, cos on odd columns (written in place into the NumPy array).
        angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
        angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])

        # Leading size-1 axis so the table broadcasts over the input's first axis in call().
        pos_encoding = angle_rads[np.newaxis, ...]
        return tf.cast(pos_encoding, dtype=tf.float32)

    def call(self, x):
        """Return ``x`` plus the encoding rows for the first ``tf.shape(x)[1]`` positions.

        Assumes ``x`` is (batch, seq, d_model) with seq <= seq_length — TODO confirm callers.
        """
        seq_len = tf.shape(x)[1]
        # The table is stored as float32; cast it to the layer's compute dtype so
        # the add also works under a mixed-precision policy (e.g. float16 compute).
        pos = tf.cast(self.pos_encoding[:, :seq_len, :], self.compute_dtype)
        return x + pos

    def get_config(self):
        """Return a serializable config that includes the constructor arguments."""
        config = super().get_config()
        config.update({"d_model": self.d_model, "seq_length": self.seq_length})
        return config

In [20]:
# Demo: add the positional encoding to random embeddings of shape (batch=1, seq=6, d_model=10).
# Seed the RNG so the displayed output is reproducible on Restart & Run All.
tf.random.set_seed(0)
DEMO_SEQ_LEN, DEMO_D_MODEL = 6, 10

embeddings = tf.random.normal(shape=(1, DEMO_SEQ_LEN, DEMO_D_MODEL))
pos_layer = PositionalEncoding(DEMO_D_MODEL, DEMO_SEQ_LEN)
encoded = pos_layer(embeddings)

# Sanity check: the layer should only add the fixed table, i.e. output - input == encoding.
np.testing.assert_allclose(
    (encoded - embeddings).numpy(), pos_layer.pos_encoding.numpy(), atol=1e-5
)
encoded

tf.Tensor(
[[[ 0.0000000e+00  1.0000000e+00  0.0000000e+00  1.0000000e+00
    0.0000000e+00  1.0000000e+00  0.0000000e+00  1.0000000e+00
    0.0000000e+00  1.0000000e+00]
  [ 8.4147096e-01  5.4030228e-01  1.5782665e-01  9.8746681e-01
    2.5116222e-02  9.9968451e-01  3.9810613e-03  9.9999207e-01
    6.3095731e-04  9.9999982e-01]
  [ 9.0929741e-01 -4.1614684e-01  3.1169716e-01  9.5018148e-01
    5.0216600e-02  9.9873835e-01  7.9620592e-03  9.9996829e-01
    1.2619144e-03  9.9999923e-01]
  [ 1.4112000e-01 -9.8999250e-01  4.5775455e-01  8.8907862e-01
    7.5285293e-02  9.9716204e-01  1.1942931e-02  9.9992865e-01
    1.8928709e-03  9.9999821e-01]
  [-7.5680250e-01 -6.5364361e-01  5.9233773e-01  8.0568975e-01
    1.0030649e-01  9.9495661e-01  1.5923614e-02  9.9987322e-01
    2.5238267e-03  9.9999684e-01]
  [-9.5892429e-01  2.8366220e-01  7.1207315e-01  7.0210528e-01
    1.2526439e-01  9.9212337e-01  1.9904044e-02  9.9980187e-01
    3.1547814e-03  9.9999505e-01]]], shape=(1, 6, 10), dtype=fl

<tf.Tensor: shape=(1, 6, 10), dtype=float32, numpy=
array([[[-0.7859173 ,  0.7773997 , -0.08671875,  0.8779846 ,
          1.6954538 ,  1.8347679 , -0.5001034 ,  0.98377967,
          0.4602798 ,  1.3874539 ],
        [ 1.0675675 ,  2.6301348 ,  0.22089678,  0.80823755,
          1.6050917 ,  0.6280722 , -0.11870308,  0.58685625,
          1.8923658 ,  1.3858421 ],
        [ 0.99622196, -0.5422464 ,  0.88039744,  0.523929  ,
          1.2406073 ,  0.8052842 ,  0.57554185,  1.5938379 ,
          1.0584487 ,  0.5241351 ],
        [-0.7065241 , -1.3726014 ,  1.4321014 , -0.73010564,
          0.95234376,  1.7666137 ,  0.3044888 ,  0.52966166,
         -0.856288  ,  0.88565606],
        [-0.8899487 , -1.5368311 , -0.04313767,  3.1681867 ,
         -0.3792765 , -0.33048046,  0.23004808,  0.41639006,
          0.17384364,  1.1872885 ],
        [-1.4351509 ,  1.0429194 ,  2.131488  , -0.27410465,
          1.8655901 ,  0.4377225 ,  0.01971366,  0.8009309 ,
         -1.131422  ,  1.8008431 ]]]

In [21]:
import torch
import torch.nn as nn

In [50]:
x = torch.arange(0, 6).unsqueeze(0).unsqueeze(0)  # shape (batch=1, seq=1, emb=6)

# view_as_complex needs a float tensor whose last dim has size 2, so split the
# last (emb) dim into (emb // 2, 2): [..., k, 0] becomes the real part and
# [..., k, 1] the imaginary part of the k-th complex number. Using x.shape[:-1]
# instead of a hard-coded (1, 1) keeps this valid for any leading shape.
comp = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
real = torch.view_as_real(comp)  # back to (..., emb // 2, 2)

x_roundtrip = real.reshape(*x.shape[:-1], -1)
# The complex packing is a pure reinterpretation, so the round trip must be lossless.
assert torch.equal(x_roundtrip, x.float()), "complex round trip should reproduce x"
x_roundtrip

tensor([[[0., 1., 2., 3., 4., 5.]]])