In [1]:
import tensorflow as tf
import torch
from torch import nn
from random import randint
import numpy as np

In [2]:
'''
Declare model parameters
'''
vocab_size = 83
embedding_dim = 256
rnn_units = 1024
batch_size = 4
seq_length = 100

In [3]:
'''
Create an embedding layer in keras
'''
k_emb = tf.keras.layers.Embedding(vocab_size, embedding_dim, batch_input_shape=[batch_size, None])

In [4]:
'''
Create input for keras model
'''
k_input = np.array([[randint(0,vocab_size-1) for j in range(seq_length)] for i in range(batch_size)])

In [5]:
'''
Get output of embedding layer
'''
k_test = k_emb(k_input)

In [6]:
'''
Create an embedding layer in PyTorch
'''
p_emb = torch.nn.Embedding(batch_size*seq_length, embedding_dim)

In [7]:
'''
Use the same input we used for keras but make it into a PyTorch transformer
'''
p_input = torch.tensor(k_input)

In [8]:
'''
Get output from embedding layer
'''
p_test = p_emb(p_input)

In [9]:
'''
During forward propagation, the Pytorch LSTM layer expects a 3 dimensional tensor as input (similar to TF).
But the default ordering of the dimensions changes. The input tensor should be of shape (timesteps, batch, input_features).
If we want to get the same order of dimensions as TF, we should set batch_first=True at layer initiation.
'''
p_lstm = torch.nn.LSTM(embedding_dim, batch_size*embedding_dim, batch_first=True)

In [10]:
'''
Get output of PyTorch LSTM
'''
p_lstm_test = p_lstm(p_test)

In [11]:
'''
The output of the Pytorch LSTM layer is a tuple with two elements.
The first element of the tuple is LSTM’s output corresponding to all timesteps (hᵗ : ∀t = 1,2…T)
with shape (timesteps, batch, output_features).
The second element of the tuple is another tuple with two elements.
The first element of this second tuple is the output corresponding to the last timestep (hᵀ).
It has the shape (1, batch, output_features).
The second element of this second tuple is the cell state corresponding to the last timestep (cᵀ).
It also has the shape (1, batch, output_features).
If we had initiated the LSTM as a block of stacked layers by setting num_layers=k,
then hᵀ and cᵀ would have the shape (k, batch, output_features).
Here, both hᵀ and cᵀ would have the last states of all the k layers in the stack.
Further at initiation, had we set batch_first=True,
then the timesteps and batch dimensions would swap in the output (similar to the input).
'''
#(seq_length * batch_size * hidden_size)

p_lstm_test[0].shape

torch.Size([4, 100, 1024])

In [12]:
'''
TF LSTM layer expects a 3 dimensional tensor as input during forward propagation.
This input should be of the shape (batch, timesteps, input_features)
'''
k_lstm = tf.keras.layers.LSTM(
        rnn_units,
        return_sequences=True,
        recurrent_initializer='glorot_uniform',
        recurrent_activation='sigmoid',
        stateful=True,
    )

In [13]:
'''
Get output from keras LSTM
'''
k_lstm_test = k_lstm(k_test)

In [14]:
'''
Make a keras dense layer
'''
k_dense = tf.keras.layers.Dense(units=vocab_size)

In [15]:
'''
Get output of keras dense layer
'''
k_out = k_dense(k_lstm_test)

In [16]:
'''
Create a PyTorch linear (dense) layer
'''
p_dense = torch.nn.Linear(rnn_units, vocab_size)

In [17]:
'''
Output for PyTorch LSTM is a tuple (tensor, (tensor, tensor)) so take the first element of the tuple and pass into dense layer
'''
p_out = p_dense(p_lstm_test[0])

In [18]:
'''
Create some random labels to calculate loss function. We're using sparse_categorical_crossentropy in keras
and CrossEntropyLoss in PyTorch
'''

labels = [[randint(0, vocab_size-1) for j in range(seq_length)] for i in range(batch_size)]
#print(labels)
loss = tf.keras.losses.sparse_categorical_crossentropy(labels, k_out, from_logits=False)
print(loss)

tf.Tensor(
[[13.765961   4.1679363  3.5024328 14.174576   3.3851292  3.3538055
   3.146059  14.424006  14.415008  14.439827  14.352431  14.454153
  14.51372    3.8280933 14.532884  14.395902  14.468557   4.229684
  14.568549   6.37278    5.405327   3.3354204 14.299489  14.591692
   4.0014505  4.0593195  3.7824407  2.9956534  4.916977  14.780802
   3.3925154  2.8981307  5.1328225  3.1150978  5.724106   4.9540133
  14.750332   3.706162   3.3901477  3.316679   3.1679006 14.707722
  14.5920315 14.462907   3.9426234 14.373063  14.308451  14.228557
  14.253304  14.494127   2.3596013 14.146403  14.26072    4.4960294
  14.46563   14.403555  14.448816   2.8844917 14.497851   3.6454391
   4.5588217  2.6133156 14.214947  14.333407   2.7033124 14.382575
  14.392125  14.413561  14.449843  14.576517  14.523473   3.932469
   3.298026  14.470236  14.435741  14.564459  14.494001   5.1289854
  14.500172  14.348515  14.478828  14.441821  14.537241  14.597513
  14.715878  14.528492  14.674841  14.527171  

In [19]:
'''
We have to change the shape of the output from the dense layer
'''

x=torch.Tensor(p_out).permute((0,2,1))  # shape of preds must be (N, C, H, W) instead of (N, H, W, C)
y=torch.Tensor(labels).long() #  shape of labels must be (N, H, W) and type must be long integer
print(x)
print(y)
losses = nn.CrossEntropyLoss(reduction="none")(x, y)
losses

tensor([[[-0.0354,  0.0064, -0.0296,  ..., -0.0619,  0.0011,  0.1172],
         [-0.0968, -0.0472,  0.0209,  ..., -0.0288, -0.0691, -0.0899],
         [-0.0085, -0.0450,  0.0250,  ...,  0.0522, -0.0569, -0.0256],
         ...,
         [ 0.0208, -0.0147,  0.0275,  ..., -0.0376, -0.0263, -0.0321],
         [ 0.0405,  0.0866,  0.0260,  ..., -0.0834, -0.0638, -0.0335],
         [-0.0341,  0.0209, -0.0576,  ..., -0.0285,  0.0321, -0.0439]],

        [[ 0.0194,  0.0387,  0.0793,  ..., -0.0215,  0.0194,  0.0167],
         [-0.0249,  0.0300, -0.0852,  ..., -0.0272, -0.0147,  0.0240],
         [-0.0398,  0.0186,  0.0375,  ..., -0.0213, -0.0470, -0.0081],
         ...,
         [ 0.0354, -0.0178, -0.0602,  ...,  0.0076, -0.0272, -0.0084],
         [-0.0127, -0.0078, -0.0472,  ..., -0.0145,  0.0718, -0.0316],
         [ 0.0278, -0.0214, -0.0678,  ..., -0.0419, -0.0032, -0.0162]],

        [[ 0.0057, -0.0062,  0.0109,  ...,  0.0142,  0.0307, -0.0475],
         [-0.0299, -0.0326, -0.0675,  ..., -0

tensor([[4.4556, 4.4446, 4.3917, 4.4269, 4.4474, 4.4787, 4.3162, 4.4109, 4.4482,
         4.4008, 4.3853, 4.3863, 4.3461, 4.3837, 4.3364, 4.3408, 4.4233, 4.3968,
         4.3867, 4.3710, 4.4553, 4.3982, 4.3778, 4.4252, 4.4640, 4.4664, 4.4299,
         4.4188, 4.3705, 4.4602, 4.4568, 4.3845, 4.4133, 4.4565, 4.4071, 4.3972,
         4.4566, 4.3954, 4.4144, 4.3814, 4.2896, 4.4201, 4.3011, 4.5056, 4.4360,
         4.4831, 4.4163, 4.4677, 4.5258, 4.4243, 4.4262, 4.5003, 4.5333, 4.4032,
         4.3575, 4.4255, 4.4375, 4.3780, 4.4160, 4.4436, 4.4497, 4.4436, 4.5374,
         4.4290, 4.4244, 4.4642, 4.4132, 4.4152, 4.3564, 4.3603, 4.4069, 4.3782,
         4.4222, 4.4271, 4.4165, 4.4453, 4.4194, 4.3374, 4.4369, 4.4376, 4.4419,
         4.4414, 4.4036, 4.3221, 4.4511, 4.4091, 4.4005, 4.4421, 4.4221, 4.4580,
         4.5162, 4.4325, 4.4830, 4.4016, 4.2722, 4.4357, 4.4149, 4.3777, 4.3639,
         4.4577],
        [4.4158, 4.4058, 4.4664, 4.4990, 4.4742, 4.4212, 4.4077, 4.4694, 4.3952,
         4