In [2]:
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
# Small toy dimensions shared by every cell of the walkthrough.

# Number of sequences processed side by side.
batch_size = 2

# Number of time steps in the source and target batches.
source_seq_len = 6
target_seq_len = 7

# Number of distinct token ids on each side (rows of each embedding table).
source_vocab_size = 9
target_vocab_size = 10

# Feature widths: embedding vector size and GRU hidden-state size.
embedding_size = 4
hidden_size = 5

## Embedding the SOURCE:

In [4]:
# Seed before building the layer: nn.Embedding initialises its weights from
# the global RNG, so this keeps the printed values reproducible.
torch.manual_seed(1)

# Source token ids: one row per time step, one column per sequence in the
# batch (6 x 2 here, matching source_seq_len x batch_size).
source_vector = torch.tensor(
    [
        [2, 2],
        [1, 5],
        [3, 6],
        [4, 7],
        [0, 8],
        [0, 4],
    ]
)
print("source vector:", source_vector, sep="\n")

# Look-up table turning every token id into an embedding_size-wide vector.
embed = nn.Embedding(source_vocab_size, embedding_size)
source_embedded = embed(source_vector)

print()
print("source_embedded:", source_embedded.shape)
print(source_embedded)

source vector:
tensor([[2, 2],
        [1, 5],
        [3, 6],
        [4, 7],
        [0, 8],
        [0, 4]])

source_embedded: torch.Size([6, 2, 4])
tensor([[[-0.7121,  0.3037, -0.7773, -0.2515],
         [-0.7121,  0.3037, -0.7773, -0.2515]],

        [[-0.1002, -0.6092, -0.9798, -1.6091],
         [ 0.1991,  0.0457,  0.1530, -0.4757]],

        [[-0.2223,  1.6871,  0.2284,  0.4676],
         [-1.8821, -0.7765,  2.0242, -0.0865]],

        [[-0.6970, -1.1608,  0.6995,  0.1991],
         [ 2.3571, -1.0373,  1.5748, -0.6298]],

        [[-1.5256, -0.7502, -0.6540, -1.6095],
         [ 2.4070,  0.2786,  0.2468,  1.1843]],

        [[-1.5256, -0.7502, -0.6540, -1.6095],
         [-0.6970, -1.1608,  0.6995,  0.1991]]], grad_fn=<EmbeddingBackward>)


## ENCODING using embedded SOURCE:

In [5]:
# Seed so the GRU and the bridge layer below get reproducible initial weights.
torch.manual_seed(10)

# Bidirectional GRU that reads the embedded source in both directions.
source_rnn = nn.GRU(embedding_size, hidden_size, bidirectional=True)

# encoder_output     : one state per source time step, forward and backward
#                      halves concatenated (last dim = 2 * hidden_size).
# encoder_hidden_raw : final state of each direction, stacked on dim 0.
encoder_output, encoder_hidden_raw = source_rnn(source_embedded)

print("encoder_output:", encoder_output.shape)
print(encoder_output)
print()
print("encoder_hidden directly from rnn:", encoder_hidden_raw.shape)
print(encoder_hidden_raw)

# Bridge layer that squeezes the two directions back down to hidden_size.
# It is created after the GRU so parameter initialisation consumes the RNG in
# the same order as before.
fc = nn.Linear(hidden_size * 2, hidden_size)

# Slicing with 0:1 / 1:2 (instead of [0] / [1]) keeps the leading dimension,
# so the result still has the (1, batch, features) layout a GRU hidden state uses.
encoder_hidden_cat = torch.cat((encoder_hidden_raw[0:1], encoder_hidden_raw[1:2]), -1)
print()
print("encoder_hidden concatenated:", encoder_hidden_cat.shape)
print(encoder_hidden_cat)

# Final name kept as `encoder_hidden`: this projected state is what the
# decoding cell picks up as its starting hidden state.
encoder_hidden = fc(encoder_hidden_cat)
print()
print("encoder_hidden after passing through fc layer:", encoder_hidden.shape)
print(encoder_hidden)

encoder_output: torch.Size([6, 2, 10])
tensor([[[-0.0719,  0.0092,  0.0766,  0.2258, -0.3897,  0.4973,  0.2241,
          -0.0364,  0.2452,  0.5391],
         [-0.0719,  0.0092,  0.0766,  0.2258, -0.3897,  0.1393, -0.0579,
           0.0384,  0.3339,  0.5365]],

        [[-0.0766, -0.0131, -0.2374,  0.0349, -0.6761,  0.4118,  0.2044,
          -0.0644,  0.1900,  0.5684],
         [-0.0207,  0.0344, -0.1456,  0.0855, -0.4137, -0.2120, -0.3217,
           0.1128,  0.3892,  0.5949]],

        [[ 0.2076,  0.0988, -0.2090,  0.2032, -0.4063,  0.2258, -0.0907,
           0.1444,  0.1525,  0.4167],
         [-0.6163,  0.0439, -0.1630,  0.2272, -0.0916, -0.4553, -0.6790,
           0.2554,  0.4721,  0.5932]],

        [[-0.3797,  0.1084, -0.1694,  0.1712, -0.2130,  0.0815, -0.4713,
           0.0978,  0.5312,  0.6088],
         [-0.0351,  0.2360, -0.6361, -0.0924, -0.0800, -0.4727, -0.1312,
           0.2407,  0.5484,  0.5065]],

        [[-0.4578,  0.0858, -0.1673,  0.1922, -0.6676,  0.4791,  

## The TARGET

In [6]:
# Nothing random happens in this cell; the seed call is kept so the global
# RNG state after the cell stays the same.
torch.manual_seed(1)

# Target token ids: one row per time step, one column per sequence in the
# batch (7 x 2 here).
target_vector = torch.tensor(
    [
        [2, 2],
        [1, 5],
        [3, 6],
        [4, 7],
        [0, 8],
        [0, 9],
        [0, 4],
    ]
)
print("target vector:", target_vector, sep="\n")

target vector:
tensor([[2, 2],
        [1, 5],
        [3, 6],
        [4, 7],
        [0, 8],
        [0, 9],
        [0, 4]])


## DECODING:
The decoder's initial hidden state is the encoder's final hidden state (the forward and backward states concatenated and projected by the `fc` layer).

In [12]:
def _report(label, tensor, gap=True, values=True):
    """Print `label` followed by the tensor's shape, then optionally the tensor.

    gap    -- emit an empty line first, separating this report from the last.
    values -- also print the tensor contents, not just its shape.
    """
    if gap:
        print()
    print(label, tensor.shape)
    if values:
        print(tensor)


# Seed so the layers created below get reproducible initial weights.
torch.manual_seed(10)

# The decoder starts from the projected encoder state.
decoder_hidden = encoder_hidden
# Buffer sized for per-step vocabulary scores; it is allocated here but not
# written to anywhere in this cell.
outputs = torch.zeros(target_seq_len, batch_size, target_vocab_size)

# Layers are built in this exact order because each one draws its initial
# weights from the global RNG.
embed = nn.Embedding(target_vocab_size, embedding_size)
# Scores one (decoder state, encoder state) pair: hidden_size + 2 * hidden_size inputs.
energy = nn.Linear(3 * hidden_size, 1)
relu = nn.ReLU()
# dim=0 normalises across the source time steps.
softmax = nn.Softmax(dim=0)
# Input per step = context vector (2 * hidden_size) + embedded token.
rnn = nn.GRU(2 * hidden_size + embedding_size, hidden_size)

# First decoder input: the first row of the target batch.
x = target_vector[0]

# NOTE(review): the loop ends with an unconditional `break`, so only the step
# i == 1 ever runs. `x` is rebound to `x.unsqueeze(0)` each iteration and is
# never set to the next token, so removing the `break` as-is would keep
# stacking dimensions onto `x`.
for i in range(1, target_seq_len):
    # Add a leading length-1 time dimension before embedding.
    x = x.unsqueeze(0)
    x_embed = embed(x)
    _report("x_embed:", x_embed, gap=False)

    seq_len = encoder_output.size(0)

    # Alignment: copy the decoder state once per source time step so it can
    # be paired with every encoder output.
    h_reshaped = decoder_hidden.repeat(seq_len, 1, 1)
    _report("h_reshaped (decoder_hidden repeated n times):", h_reshaped)

    h_reshaped_encoder_out_cat = torch.cat([h_reshaped, encoder_output], dim=2)
    _report("h_reshaped_encoder_out_cat:", h_reshaped_encoder_out_cat)

    # One non-negative score per (source step, batch item).
    e = relu(energy(h_reshaped_encoder_out_cat))
    _report("e:", e)

    # Weighing: turn the scores into weights over the source steps.
    attention = softmax(e)
    _report("attention:", attention)

    # Weighted sum of encoder outputs over the source-step axis `s`;
    # the size-1 axis `k` of `attention` becomes the leading output axis.
    context_vector = torch.einsum("snk,snl->knl", attention, encoder_output)
    _report("context_vector:", context_vector)

    # Feed the context vector and the embedded token together to the GRU.
    rnn_input = torch.cat([context_vector, x_embed], dim=2)
    _report("rnn_input:", rnn_input)

    out, decoder_hidden = rnn(rnn_input, decoder_hidden)
    _report("decoder rnn out:", out, values=False)
    _report("decoder hidden:", decoder_hidden, gap=False, values=False)

    break

x_embed: torch.Size([1, 2, 4])
tensor([[[-0.7465,  1.0051, -0.2568,  0.4765],
         [-0.7465,  1.0051, -0.2568,  0.4765]]], grad_fn=<EmbeddingBackward>)

h_reshaped (decoder_hidden repeated n times): torch.Size([6, 2, 5])
tensor([[[ 0.1603, -0.0372,  0.0328, -0.3453,  0.0616],
         [-0.0238, -0.2260, -0.2417,  0.0431,  0.0713]],

        [[ 0.1603, -0.0372,  0.0328, -0.3453,  0.0616],
         [-0.0238, -0.2260, -0.2417,  0.0431,  0.0713]],

        [[ 0.1603, -0.0372,  0.0328, -0.3453,  0.0616],
         [-0.0238, -0.2260, -0.2417,  0.0431,  0.0713]],

        [[ 0.1603, -0.0372,  0.0328, -0.3453,  0.0616],
         [-0.0238, -0.2260, -0.2417,  0.0431,  0.0713]],

        [[ 0.1603, -0.0372,  0.0328, -0.3453,  0.0616],
         [-0.0238, -0.2260, -0.2417,  0.0431,  0.0713]],

        [[ 0.1603, -0.0372,  0.0328, -0.3453,  0.0616],
         [-0.0238, -0.2260, -0.2417,  0.0431,  0.0713]]],
       grad_fn=<RepeatBackward>)

h_reshaped_encoder_out_cat: torch.Size([6, 2, 15])
tensor