In [1]:
from argparse import Namespace
from pathlib import Path
from functools import partial
from collections import (
    OrderedDict,
    Counter,
    defaultdict
)

# torch 
import torch 
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.nn.utils.rnn import pack_padded_sequence,pad_packed_sequence
from torch.utils.data import default_collate
# other 
from torchtext import vocab
try:
    from torchdata import datapipes as dp
    from torchdata.dataloader2 import DataLoader2
except:
    !pip install torchdata
    from torchdata import datapipes as dp
    from torchdata.dataloader2 import DataLoader2

# manipulation 
import numpy as np

#other
from tqdm import tqdm

In [2]:
args = Namespace(
    # data
    data_base_path = "../../data/mt/",
    dataset = ["train_1"],
    
    # vocab
    mask_tkn = "<MASK>",
    ukn_tkn = "<UKN>",
    beg_tkn = "<BEG>",
    end_tkn = "<END>",
    
    # Training
    batch_size = 2,
    embedding_size = 5,
    rnn_hidden_size = 4,
    # Seeds torch's RNG so embedding / layer initialisation is reproducible on Restart & Run All.
    seed = 1337,
    
    # runtime
    cuda = torch.cuda.is_available(),
    device = "cuda" if torch.cuda.is_available() else "cpu",
    
)
torch.manual_seed(args.seed)

# Create every configured "*base*" directory. vars() is the public way to read a
# Namespace's fields (Namespace._get_kwargs() is a private argparse helper).
for k,v in vars(args).items():
    if "base" in k:
        Path(v).mkdir(parents=True,exist_ok=True)

In [3]:
def build_pipe_dict(args=args, encoding="utf-8"):
    """Build one CSV-parsing datapipe per dataset split named in ``args.dataset``.

    Args:
        args: namespace providing ``data_base_path`` (directory holding the CSV
            files) and ``dataset`` (split names; ``<name>.csv`` is read for each).
        encoding: text encoding used to open the CSV files. Defaults to UTF-8 so
            non-ASCII tokens (e.g. French accents) are not decoded with the
            platform's locale encoding, a likely cause of the ``vannÃ©`` entry
            seen in the target vocabulary.

    Returns:
        dict mapping each split name to a datapipe of parsed CSV rows, with the
        first (header) line skipped.
    """
    pipe_dict = {}
    for fname in args.dataset:
        # Path join works whether or not data_base_path ends with a separator.
        csv_path = str(Path(args.data_base_path) / f"{fname}.csv")
        pipe = dp.iter.FileOpener([csv_path], mode="r", encoding=encoding)
        pipe_dict[fname] = pipe.parse_csv(skip_lines=1)
    return pipe_dict

In [4]:
pipe_dict = build_pipe_dict(args)
pipe_dict

{'train_1': CSVParserIterDataPipe}

In [5]:
def create_vocab(counter,max_seq_length,args):
    """Build a torchtext vocabulary from token counts.

    Tokens are ordered by descending frequency, with ties broken alphabetically,
    and the four special tokens from ``args`` (mask, unknown, begin, end) are
    registered as specials.

    Args:
        counter: mapping token -> count (e.g. a ``collections.Counter``).
        max_seq_length: longest observed sentence length (in tokens); stored on
            the returned vocabulary as ``max_seq_length`` for later padding.
        args: namespace with ``mask_tkn``, ``ukn_tkn``, ``beg_tkn``, ``end_tkn``.

    Returns:
        The vocabulary, whose default index (used for out-of-vocabulary lookups)
        is the index of ``args.ukn_tkn``.
    """
    by_frequency = sorted(counter.items(), key=lambda item: (-item[1], item[0]))
    special_tokens = [args.mask_tkn, args.ukn_tkn, args.beg_tkn, args.end_tkn]
    vocabulary = vocab.vocab(ordered_dict=OrderedDict(by_frequency),
                             specials=special_tokens)
    vocabulary.max_seq_length = max_seq_length
    unknown_index = vocabulary[args.ukn_tkn]
    vocabulary.set_default_index(unknown_index)
    return vocabulary

In [6]:
def build_vocab_dict(train_pipe,args=args):
    """Build source and target vocabularies from a datapipe of (source, target) rows.

    Each sentence is tokenised by splitting on single spaces. Token frequencies
    and the longest sentence length (in tokens) are tracked separately for the
    "source" (row item 0) and "target" (row item 1) sides.

    Args:
        train_pipe: iterable of rows; items 0 and 1 are the source and target sentences.
        args: namespace forwarded to ``create_vocab``.

    Returns:
        dict keyed by "source" / "target" mapping to the vocabulary built by ``create_vocab``.
    """
    token_counts = defaultdict(Counter)
    longest = defaultdict(int)
    sides = ["source", "target"]
    for row in train_pipe:
        for side, sentence in zip(sides, row):
            tokens = sentence.split(" ")
            longest[side] = max(longest[side], len(tokens))
            token_counts[side].update(tokens)

    vocabularies = {}
    for side in token_counts:
        vocabularies[side] = create_vocab(token_counts[side], longest[side], args)
    return vocabularies

In [7]:
vocab_dict = build_vocab_dict(pipe_dict["train_1"],args)
vocab_dict

{'source': Vocab(), 'target': Vocab()}

In [8]:
def vectorize(indices,seq_length,mask_idx):
    """Right-pad a list of token indices to a fixed length.

    Args:
        indices: sequence of integer token indices.
        seq_length: length of the returned vector.
        mask_idx: index written into every position after ``len(indices)``.

    Returns:
        np.ndarray of dtype int64 and shape (seq_length,).

    Raises:
        ValueError: if ``indices`` is longer than ``seq_length`` (e.g. a sentence
            longer than the training-set maximum the vector size was derived from).
    """
    n_tokens = len(indices)
    if n_tokens > seq_length:
        raise ValueError(
            f"sequence of {n_tokens} indices does not fit in seq_length={seq_length}")
    vector = np.full(shape=seq_length,
                     fill_value=mask_idx,
                     dtype=np.int64)
    vector[:n_tokens] = indices
    return vector

def get_source_indices(source_text,source_vocab,args=args):
    """Map a source sentence to vocabulary indices wrapped in BEG ... END markers.

    The sentence is tokenised by splitting on single spaces. Tokens are looked
    up with ``source_vocab.lookup_indices``, so unknown tokens presumably map to
    the vocabulary's default index.

    Returns:
        list[int]: [index(beg_tkn), *token indices, index(end_tkn)].
    """
    begin_marker = source_vocab[args.beg_tkn]
    token_ids = source_vocab.lookup_indices(source_text.split(" "))
    end_marker = source_vocab[args.end_tkn]
    return [begin_marker, *token_ids, end_marker]

def get_target_indices(target_text,target_vocab,args=args):
    """Build the teacher-forcing decoder input and label index lists for a target sentence.

    The sentence is split on single spaces and looked up in ``target_vocab``.
    The decoder input is the tokens shifted right by a BEG marker, and the
    labels are the same tokens followed by an END marker, so both lists have
    len(tokens) + 1 entries.

    Returns:
        tuple(list[int], list[int]): (decoder input, decoder labels).
    """
    token_ids = target_vocab.lookup_indices(target_text.split(" "))
    decoder_input = [target_vocab[args.beg_tkn], *token_ids]
    decoder_labels = [*token_ids, target_vocab[args.end_tkn]]
    return decoder_input, decoder_labels

def create_dataset(vocab_dict,args,row):
    """Vectorize one (source, target) CSV row into fixed-length int64 arrays.

    Args:
        vocab_dict: dict with "source" and "target" vocabularies; each exposes
            ``max_seq_length`` and item lookup for the special tokens.
        args: namespace with ``mask_tkn``, ``beg_tkn`` and ``end_tkn``.
        row: sequence whose item 0 is the source sentence and item 1 the target sentence.

    Returns:
        dict with
          - "source_vector": BEG + source tokens + END, padded to source max_seq_length + 2
          - "target_x_vector": BEG + target tokens (decoder input), padded to target max_seq_length + 1
          - "target_y_vector": target tokens + END (decoder labels), padded to the same length
          - "source_length": number of unpadded source indices (tokens + BEG + END)
    """
    src_vocab = vocab_dict["source"]
    source_indices = get_source_indices(source_text=row[0],
                                        source_vocab=src_vocab,
                                        args=args)
    # +2 leaves room for the BEG and END markers around the longest training sentence.
    source_vector = vectorize(source_indices,
                              seq_length=src_vocab.max_seq_length + 2,
                              mask_idx=src_vocab[args.mask_tkn])

    tgt_vocab = vocab_dict["target"]
    target_x_indices, target_y_indices = get_target_indices(target_text=row[1],
                                                            target_vocab=tgt_vocab,
                                                            args=args)
    # Decoder input and labels each carry exactly one extra marker (BEG or END), hence +1.
    target_length = tgt_vocab.max_seq_length + 1
    target_pad = tgt_vocab[args.mask_tkn]
    target_x_vector, target_y_vector = (
        vectorize(indices, seq_length=target_length, mask_idx=target_pad)
        for indices in (target_x_indices, target_y_indices)
    )

    return {"source_vector": source_vector,
            "target_x_vector": target_x_vector,
            "target_y_vector": target_y_vector,
            "source_length": len(source_indices)}
    
    
def collate_fn(args,batch):
    """Collate a list of example dicts into one batch, sorted by source length (longest first).

    Args:
        args: namespace providing ``device``, where the batch tensors are moved.
        batch: list of dicts as produced by ``create_dataset``; must contain "source_length".

    Returns:
        dict with the same keys, each value a tensor whose first (batch) dim is
        reordered so "source_length" is descending. This is the order
        ``pack_padded_sequence`` expects with its default ``enforce_sorted=True``.
        "source_length" stays on the CPU, because ``pack_padded_sequence``
        requires CPU lengths. Every other tensor is moved to ``args.device``.
    """
    data_dict = default_collate(batch)
    lengths = data_dict["source_length"].numpy()
    # Descending order; a stable sort keeps tied examples in their original order.
    order = np.argsort(-lengths, kind="stable").tolist()

    sorted_batch = {}
    for k, v in data_dict.items():
        v = v[order]
        sorted_batch[k] = v if k == "source_length" else v.to(args.device)
    return sorted_batch
        
        

In [9]:
def build_dataset_pipe(pipe_dict,vocab_dict,args):
    """Turn each raw CSV datapipe into a batched, collated tensor datapipe.

    Each row is vectorized with ``create_dataset``, grouped into batches of
    ``args.batch_size`` (the last incomplete batch is dropped) and collated with
    ``collate_fn``. Splits whose name starts with "train" are shuffled before
    vectorization.

    Args:
        pipe_dict: dict split name -> datapipe of (source, target) rows.
        vocab_dict: vocabularies passed to ``create_dataset``.
        args: namespace with ``batch_size`` and whatever the helpers read.

    Returns:
        dict split name -> datapipe yielding collated batch dicts.
    """
    dataset_pipe = {}
    fn = partial(create_dataset,vocab_dict,args)
    for dataset,pipe in pipe_dict.items():
        # Split names look like "train_1" (see args.dataset), so match on the prefix.
        # The previous exact `== "train"` check never matched, so training data was never shuffled.
        if dataset.startswith("train"):
            pipe = pipe.shuffle()

        pipe = pipe.map(fn)
        pipe = pipe.batch(args.batch_size,drop_last=True)
        pipe = pipe.collate(partial(collate_fn,args))

        dataset_pipe[dataset] = pipe

    return dataset_pipe

In [10]:
dataset_dict = build_dataset_pipe(pipe_dict,vocab_dict,args)
dataset_dict

{'train_1': CollatorIterDataPipe}

In [58]:
args.source_vocab_size = len(vocab_dict["source"])
args.target_vocab_size = len(vocab_dict["target"])
args.padding_idx = vocab_dict["source"][args.mask_tkn]
args.target_embedding_size = 5

In [12]:
vocab_dict["source"].get_itos()

['<MASK>',
 '<UKN>',
 '<BEG>',
 '<END>',
 '.',
 'always',
 'am',
 'complaining',
 'exhausted',
 'her',
 'i',
 'is',
 'job',
 'of',
 'she']

In [13]:
vocab_dict["target"].get_itos()

['<MASK>',
 '<UKN>',
 '<BEG>',
 '<END>',
 '.',
 'de',
 'elle',
 'je',
 'plaint',
 'se',
 'son',
 'suis',
 'toujours',
 'travail',
 'vannÃ©']

In [14]:
dataloader = DataLoader2(dataset_dict["train_1"])

In [15]:
sample = next(iter(dataset_dict["train_1"]))
sample

{'source_vector': tensor([[ 2, 14, 11,  5,  7, 13,  9, 12,  4,  3],
         [ 2, 10,  6,  8,  4,  3,  0,  0,  0,  0]]),
 'target_x_vector': tensor([[ 2,  6,  9,  8, 12,  5, 10, 13,  4],
         [ 2,  7, 11, 14,  4,  0,  0,  0,  0]]),
 'target_y_vector': tensor([[ 6,  9,  8, 12,  5, 10, 13,  4,  3],
         [ 7, 11, 14,  4,  3,  0,  0,  0,  0]]),
 'source_length': tensor([10,  6])}

# step 1. take in input [ B , Seq ]

In [16]:
source_input = sample["source_vector"]
source_input.shape

torch.Size([2, 10])

1. The batch has two inputs, each padded to a fixed sequence length of 10.
2. Each element holds the index of a word in the source vocabulary.

# Step 2. Convert input vector into its embedding [ B , Seq , Emb]

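An embedding is simply a lookup table: each word index selects one row of the weight matrix, so every word is represented by a dense vector (the cell below checks that `source_embedding(x)` equals `source_embedding.weight[x]`).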

In [17]:
# Padding tokens map to a zero vector; use the configured source mask index rather than a literal 0.
source_embedding = nn.Embedding(num_embeddings=args.source_vocab_size,
                                embedding_dim=args.embedding_size,padding_idx=args.padding_idx)

In [18]:
torch.all(source_embedding(source_input) == source_embedding.weight[source_input])

tensor(True)

In [19]:
source_emb_out = source_embedding(source_input)
source_emb_out.shape

#? each word is represent as the 5 dim vector

torch.Size([2, 10, 5])

# Step 3 Encode the emb using bi_gru

1. output shape ==> [B,Seq,D*hid]
2. Hidden shape ==> [D*layers,B,hid]

In [20]:
encoder_bi_gru = nn.GRU(input_size=args.embedding_size,
                        hidden_size = args.rnn_hidden_size,
                        bias=False,
                        batch_first=True,
                        num_layers=1,
                        bidirectional=True)
encoder_bi_gru

GRU(5, 4, bias=False, batch_first=True, bidirectional=True)

1. A single-layer bidirectional GRU has four weight matrices (and no bias vectors here, since `bias=False`).
2. Two belong to the forward direction and two to the backward direction.
3. Each GRU cell has three gates: reset, update, and new (candidate).
4. Each gate has two weight matrices: one applied to the input and one applied to the hidden state.
5. Each gate's input weight has shape [hidden_size, input_size]. The three are concatenated into the input-hidden weight matrix (`weight_ih`) of shape [3*hidden_size, input_size].
6. Each gate's hidden weight has shape [hidden_size, hidden_size]. Likewise, the three are concatenated into the hidden-hidden weight matrix (`weight_hh`) of shape [3*hidden_size, hidden_size].


In [21]:
W_hh_f = encoder_bi_gru.weight_hh_l0
W_hh_b = encoder_bi_gru.weight_hh_l0_reverse

W_hh_f.shape,W_hh_b.shape

#? [3(#gate)*4(hidden_size),4(hidden_size)]

(torch.Size([12, 4]), torch.Size([12, 4]))

In [22]:
W_ih_f = encoder_bi_gru.weight_ih_l0
W_ih_b = encoder_bi_gru.weight_ih_l0_reverse

W_ih_f.shape,W_ih_b.shape

#? [3(#gate)*4(hidden_size),5(input_size)]

(torch.Size([12, 5]), torch.Size([12, 5]))

In [23]:
W_ir_f,W_iz_f,W_in_f = W_ih_f.split(args.rnn_hidden_size)
W_hr_f,W_hz_f,W_hn_f = W_hh_f.split(args.rnn_hidden_size)
W_ir_b,W_iz_b,W_in_b = W_ih_b.split(args.rnn_hidden_size)
W_hr_b,W_hz_b,W_hn_b = W_hh_b.split(args.rnn_hidden_size)

Each weight matrix is split into its three per-gate matrices (reset, update, new).

In [24]:
def compute_hidden(xt,h,W_ir,W_iz,W_in,W_hr,W_hz,W_hn):
    """Compute one bias-free GRU step by hand.

    With x = xt and h the previous hidden state:
        r  = sigmoid(x W_ir^T + h W_hr^T)        reset gate
        z  = sigmoid(x W_iz^T + h W_hz^T)        update gate
        n  = tanh(x W_in^T + r * (h W_hn^T))     candidate ("new") state
        h' = (1 - z) * n + z * h

    The W_i* matrices are the per-gate input weights and the W_h* matrices the
    per-gate hidden weights (rows of ``weight_ih`` / ``weight_hh`` split by hidden size).

    Returns:
        The next hidden state, broadcast from the shapes of ``xt @ W.T`` and ``h``.
    """
    hidden_new_proj = h @ W_hn.T
    reset_gate = torch.sigmoid(xt @ W_ir.T + h @ W_hr.T)
    update_gate = torch.sigmoid(xt @ W_iz.T + h @ W_hz.T)
    candidate = torch.tanh(xt @ W_in.T + reset_gate * hidden_new_proj)
    return (1 - update_gate) * candidate + update_gate * h

In [25]:
# Zero initial state of shape (2, batch, hidden): one slice per direction of the 1-layer bi-GRU.
# The same `h` is later passed to encoder_bi_gru, so the manual and built-in runs start identically.
h = torch.zeros(size=(2,args.batch_size,args.rnn_hidden_size))

forward = []   # forward[t]: forward-direction state after reading position t
backward = []  # backward[t]: backward-direction state after reading position -1-t (reverse time order)
# Keep a leading dim of size 1 so each state is (1, batch, hidden); torch.cat((hf,hb),dim=0)
# then rebuilds the (2, batch, hidden) layout of nn.GRU's final hidden state.
hf = h[0].unsqueeze(0)
hb = h[1].unsqueeze(0)
#? we have to iterate over sequence so permute to [seq,batch,emb]
source_seq_input = source_emb_out.permute(1,0,2)
for t in range(source_seq_input.size(0)):
    xt_f = source_seq_input[t]      # forward direction reads left-to-right
    xt_b = source_seq_input[-1-t]   # backward direction reads right-to-left
    hf = compute_hidden(xt_f,hf,W_ir_f,W_iz_f,W_in_f,W_hr_f,W_hz_f,W_hn_f)
    hb = compute_hidden(xt_b,hb,W_ir_b,W_iz_b,W_in_b,W_hr_b,W_hz_b,W_hn_b)
    forward.append(hf)
    backward.append(hb)    

In [26]:
manual_h = torch.cat((hf,hb),dim=0)

In [27]:
# Each step stored a (1, batch, hidden) state, so the stacks are (seq, 1, batch, hidden).
# Squeeze only that singleton dim 1: a bare .squeeze() would also drop the batch (or seq)
# dim when it equals 1, and the concat along dim=2 below would then fail.
forward_states = torch.stack(forward).squeeze(1)
# The backward list is in reverse time order; flip along seq so both directions align per position.
backward_states = torch.flip(torch.stack(backward).squeeze(1),dims=[0])

# Concat along the feature dim -> (seq, batch, 2*hidden), then permute to (batch, seq, 2*hidden)
# to match the batch_first=True output of encoder_bi_gru.
manual_out = torch.cat((forward_states,backward_states),dim=2).permute(1,0,2)

In [28]:
source_emb_out

tensor([[[ 0.9074, -2.1085, -0.3591,  1.1981, -0.3069],
         [-0.8585, -0.1450,  1.7102,  1.5498,  0.8394],
         [ 0.6814,  1.2009,  2.8156, -0.8746,  0.4696],
         [-0.8103,  0.8221, -0.8722, -1.5284, -0.3509],
         [-0.6266, -0.0082, -1.1014,  0.9016,  0.5329],
         [-1.1077, -0.7007,  2.5961,  0.6664,  0.8542],
         [ 0.5068, -1.8874, -0.2909, -0.1333,  1.1500],
         [ 0.1781,  0.2571,  1.3376,  0.9025,  0.5470],
         [-0.6205, -1.5771,  0.1747, -0.3322, -0.2763],
         [ 0.2494, -0.2989, -0.3181,  1.2363,  0.7480]],

        [[ 0.9074, -2.1085, -0.3591,  1.1981, -0.3069],
         [ 0.7964, -0.1401,  0.6662, -0.4248,  0.1476],
         [-0.3629, -0.9660,  0.8563,  0.7797, -1.0582],
         [-1.1474,  1.1108,  0.7414, -0.3431, -0.3324],
         [-0.6205, -1.5771,  0.1747, -0.3322, -0.2763],
         [ 0.2494, -0.2989, -0.3181,  1.2363,  0.7480],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000, 

In [29]:
encoder_out , encoder_h = encoder_bi_gru(source_emb_out,h)

In [30]:
torch.all(torch.isclose(encoder_h ,manual_h))

tensor(True)

In [31]:
torch.all(torch.isclose(encoder_out ,manual_out,atol=1e-7))

tensor(True)

## Step 3A Using packed sequence

In [32]:
source_emb_packed = pack_padded_sequence(source_emb_out,
                                         sample["source_length"].numpy().tolist(),batch_first=True)
source_emb_packed

PackedSequence(data=tensor([[ 0.9074, -2.1085, -0.3591,  1.1981, -0.3069],
        [ 0.9074, -2.1085, -0.3591,  1.1981, -0.3069],
        [-0.8585, -0.1450,  1.7102,  1.5498,  0.8394],
        [ 0.7964, -0.1401,  0.6662, -0.4248,  0.1476],
        [ 0.6814,  1.2009,  2.8156, -0.8746,  0.4696],
        [-0.3629, -0.9660,  0.8563,  0.7797, -1.0582],
        [-0.8103,  0.8221, -0.8722, -1.5284, -0.3509],
        [-1.1474,  1.1108,  0.7414, -0.3431, -0.3324],
        [-0.6266, -0.0082, -1.1014,  0.9016,  0.5329],
        [-0.6205, -1.5771,  0.1747, -0.3322, -0.2763],
        [-1.1077, -0.7007,  2.5961,  0.6664,  0.8542],
        [ 0.2494, -0.2989, -0.3181,  1.2363,  0.7480],
        [ 0.5068, -1.8874, -0.2909, -0.1333,  1.1500],
        [ 0.1781,  0.2571,  1.3376,  0.9025,  0.5470],
        [-0.6205, -1.5771,  0.1747, -0.3322, -0.2763],
        [ 0.2494, -0.2989, -0.3181,  1.2363,  0.7480]],
       grad_fn=<PackPaddedSequenceBackward0>), batch_sizes=tensor([2, 2, 2, 2, 2, 2, 1, 1, 1, 1]), 

In [33]:
source_emb_out

tensor([[[ 0.9074, -2.1085, -0.3591,  1.1981, -0.3069],
         [-0.8585, -0.1450,  1.7102,  1.5498,  0.8394],
         [ 0.6814,  1.2009,  2.8156, -0.8746,  0.4696],
         [-0.8103,  0.8221, -0.8722, -1.5284, -0.3509],
         [-0.6266, -0.0082, -1.1014,  0.9016,  0.5329],
         [-1.1077, -0.7007,  2.5961,  0.6664,  0.8542],
         [ 0.5068, -1.8874, -0.2909, -0.1333,  1.1500],
         [ 0.1781,  0.2571,  1.3376,  0.9025,  0.5470],
         [-0.6205, -1.5771,  0.1747, -0.3322, -0.2763],
         [ 0.2494, -0.2989, -0.3181,  1.2363,  0.7480]],

        [[ 0.9074, -2.1085, -0.3591,  1.1981, -0.3069],
         [ 0.7964, -0.1401,  0.6662, -0.4248,  0.1476],
         [-0.3629, -0.9660,  0.8563,  0.7797, -1.0582],
         [-1.1474,  1.1108,  0.7414, -0.3431, -0.3324],
         [-0.6205, -1.5771,  0.1747, -0.3322, -0.2763],
         [ 0.2494, -0.2989, -0.3181,  1.2363,  0.7480],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000, 

In [34]:
source_emb_packed.data.shape

torch.Size([16, 5])

In [35]:
pack_encoder_out,pack_encoder_h = encoder_bi_gru(source_emb_packed,h)

In [36]:
padded_encoder_out,_ = pad_packed_sequence(pack_encoder_out,batch_first=True)

In [37]:
torch.isclose(padded_encoder_out ,encoder_out)

tensor([[[ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True]],

        [[ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  

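1. The `False` entries in the output above occur only in the forward-direction half of the features, never in the backward half.
2. All of them are at padded positions.

> Why does this happen only at the padded positions?
>>  Without packing, the forward GRU keeps running over the padding tokens after the real sentence ends. The state it built from the non-padded tokens is carried into the padded positions and keeps changing there. With a packed sequence the padding is never fed to the GRU, so those positions stay zero, which is what we want. For the same reason, the forward final hidden state of the shorter sentence differs between the two runs (see the `isclose` check on the hidden states below). The backward direction happens to agree here only because this GRU has `bias=False`, the padding embedding is the zero vector, and the initial hidden state is zero. A GRU step with zero input and zero state returns zero, so the backward pass stays at zero until it reaches the real tokens. **So always use packed sequences when training RNNs on padded batches.**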

In [39]:
torch.isclose(pack_encoder_h ,encoder_h)

tensor([[[ True,  True,  True,  True],
         [False, False, False, False]],

        [[ True,  True,  True,  True],
         [ True,  True,  True,  True]]])

In [40]:
padded_encoder_out[1]

tensor([[-0.1182,  0.5731,  0.2595, -0.5787, -0.4998, -0.0218,  0.4067, -0.2413],
        [-0.0100,  0.2469, -0.0475, -0.3790, -0.3101,  0.1612,  0.1219, -0.2342],
        [ 0.0364,  0.1461, -0.2013, -0.2400, -0.3631,  0.3954,  0.4890, -0.2404],
        [ 0.2429, -0.1025, -0.5825,  0.1807, -0.2435,  0.3893,  0.1379,  0.0524],
        [ 0.2124, -0.0534, -0.1428, -0.2280, -0.4336,  0.0330,  0.4035,  0.2857],
        [ 0.1100,  0.3438,  0.1894, -0.4629, -0.3680, -0.2646,  0.2554, -0.0948],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],
       grad_fn=<SelectBackward0>)

In [41]:
encoder_out[1]

tensor([[-0.1182,  0.5731,  0.2595, -0.5787, -0.4998, -0.0218,  0.4067, -0.2413],
        [-0.0100,  0.2469, -0.0475, -0.3790, -0.3101,  0.1612,  0.1219, -0.2342],
        [ 0.0364,  0.1461, -0.2013, -0.2400, -0.3631,  0.3954,  0.4890, -0.2404],
        [ 0.2429, -0.1025, -0.5825,  0.1807, -0.2435,  0.3893,  0.1379,  0.0524],
        [ 0.2124, -0.0534, -0.1428, -0.2280, -0.4336,  0.0330,  0.4035,  0.2857],
        [ 0.1100,  0.3438,  0.1894, -0.4629, -0.3680, -0.2646,  0.2554, -0.0948],
        [ 0.0340,  0.1789,  0.0060, -0.1277,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0082,  0.0880, -0.0293, -0.0322,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.0006,  0.0419, -0.0276, -0.0049,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.0030,  0.0195, -0.0190,  0.0017,  0.0000,  0.0000,  0.0000,  0.0000]],
       grad_fn=<SelectBackward0>)

In [42]:
source_emb_out

tensor([[[ 0.9074, -2.1085, -0.3591,  1.1981, -0.3069],
         [-0.8585, -0.1450,  1.7102,  1.5498,  0.8394],
         [ 0.6814,  1.2009,  2.8156, -0.8746,  0.4696],
         [-0.8103,  0.8221, -0.8722, -1.5284, -0.3509],
         [-0.6266, -0.0082, -1.1014,  0.9016,  0.5329],
         [-1.1077, -0.7007,  2.5961,  0.6664,  0.8542],
         [ 0.5068, -1.8874, -0.2909, -0.1333,  1.1500],
         [ 0.1781,  0.2571,  1.3376,  0.9025,  0.5470],
         [-0.6205, -1.5771,  0.1747, -0.3322, -0.2763],
         [ 0.2494, -0.2989, -0.3181,  1.2363,  0.7480]],

        [[ 0.9074, -2.1085, -0.3591,  1.1981, -0.3069],
         [ 0.7964, -0.1401,  0.6662, -0.4248,  0.1476],
         [-0.3629, -0.9660,  0.8563,  0.7797, -1.0582],
         [-1.1474,  1.1108,  0.7414, -0.3431, -0.3324],
         [-0.6205, -1.5771,  0.1747, -0.3322, -0.2763],
         [ 0.2494, -0.2989, -0.3181,  1.2363,  0.7480],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000, 

# Step 4 Use the output of encoder as encoder state

In [44]:
# Use the packed-sequence encoder output (padded back to a dense tensor) as the attention memory.
# Unlike `encoder_out`, its padded positions are zeros rather than forward states computed over padding.
encoder_state = padded_encoder_out

# Step 5 Use the last hidden output as the initial hidden state for the decoder

In [66]:
args.decoding_size = 2 * args.rnn_hidden_size

In [51]:
# Build the decoder's initial hidden from the *packed* run's final states. In the unpacked run the
# forward final state of a shorter sequence has also consumed its padding, so the two runs disagree there.
# (num_directions, batch, hidden) -> (batch, num_directions * hidden): forward and backward concatenated.
initial_hidden = pack_encoder_h.permute(1,0,2).contiguous().view(pack_encoder_h.size(1),-1)
initial_hidden.shape 

torch.Size([2, 8])

Note:
* The decoder's hidden size is 2 × the encoder's hidden size, because the final forward and backward encoder states are concatenated.
* Here that is 2 × 4 = 8.

# Step 5A Pass the initial hidden state through a linear layer

In [67]:
hidden_fc = nn.Linear(args.decoding_size,args.decoding_size)
initial_hidden =  hidden_fc(initial_hidden)

# Step 6 Define the target embedding

In [59]:
# Pad with the *target* vocabulary's mask index (args.padding_idx is taken from the source vocabulary).
target_embedding = nn.Embedding(embedding_dim=args.target_embedding_size,
                                num_embeddings=args.target_vocab_size,
                                padding_idx=vocab_dict["target"][args.mask_tkn])
target_embedding

Embedding(15, 5, padding_idx=0)

# Step 7 Define the decoder cell

In [62]:
decoder_cell = nn.GRUCell(input_size=args.target_embedding_size + args.decoding_size, # for context vector
                          hidden_size=args.decoding_size,
                          bias=False,)

# Step 7A Define the classifier

In [70]:
classifier = nn.Linear(2*args.decoding_size,args.target_vocab_size)

# Step 8 Initialize the context vector

In [61]:
context_vector = torch.zeros(size=(args.batch_size,args.decoding_size))

# Step 9 Get the input of the target sequence [batch,max_seq_length_target_vocab]

In [54]:
target_input = sample["target_x_vector"]
target_input

tensor([[ 2,  6,  9,  8, 12,  5, 10, 13,  4],
        [ 2,  7, 11, 14,  4,  0,  0,  0,  0]])

In [55]:
target_input.shape

torch.Size([2, 9])

# Step 10 Permute to [Seq,Batch]

This lets us iterate over the time steps directly.

In [56]:
target_input = target_input.permute(1,0)
target_input.shape

torch.Size([9, 2])

# Step 11 loop through seq

In [63]:
context_vector.shape

torch.Size([2, 8])

In [65]:
target_embedding(target_input[0]).shape

torch.Size([2, 5])

In [68]:
h_t = initial_hidden

In [69]:
def attention_mechanism(encoder_state,query_vector,mask=None):
    """Dot-product attention of a query over the encoder states.

    Args:
        encoder_state: tensor unpacked as (batch, seq, dim).
        query_vector: tensor with batch * dim elements, reshaped to (batch, 1, dim).
        mask: optional (batch, seq) tensor, truthy where a position may be attended.
            Masked positions get a score of -inf and therefore zero probability.
            Every row needs at least one unmasked position, otherwise the softmax yields NaN.
            With ``mask=None`` all positions are attended, as before.

    Returns:
        tuple (context_vector, vector_prob, vector_score):
          - context_vector: (batch, dim), probability-weighted sum of encoder states
          - vector_prob: (batch, seq), softmax of the scores over seq
          - vector_score: (batch, seq), raw dot-product scores (after masking, if given)
    """
    batch_size,seq_size,emb_size = encoder_state.size()

    # reshape (not view) also accepts a non-contiguous query.
    vector_score = torch.sum(encoder_state * query_vector.reshape(batch_size,1,emb_size),
                             dim=2)
    if mask is not None:
        vector_score = vector_score.masked_fill(~mask.to(torch.bool), float("-inf"))

    vector_prob = F.softmax(vector_score,dim=1)

    weight_vector = encoder_state * vector_prob.view(batch_size,seq_size,1)

    context_vector = torch.sum(weight_vector,dim=1)

    return context_vector,vector_prob,vector_score

In [71]:
# Reset the decoder state here so re-running this cell always starts from the encoder-derived
# initial hidden and an all-zero context, instead of continuing from a previous run's h_t / context.
h_t = initial_hidden
context_vector = torch.zeros(size=(args.batch_size,args.decoding_size))
output_vector = []
# Teacher forcing: each step's input token comes from target_x (BEG + gold target tokens).
for xt in target_input:
    # step 11a: embed the current input token -> [batch, target_embedding_size]
    xt_vector = target_embedding(xt)
    # step 11b: concatenate the token embedding with the previous context vector to form the
    # decoder-cell input -> [batch, target_embedding_size + decoding_size]
    cell_input = torch.cat((xt_vector,context_vector),dim=1)
    h_t = decoder_cell(cell_input,h_t)

    # step 11c: attend over the encoder states with the new hidden state as the query
    context_vector,p_attn,_ = attention_mechanism(encoder_state=encoder_state,
                                                  query_vector=h_t)

    # step 11d: predict scores over the target vocabulary from [context ; hidden]
    pred_vector = torch.cat((context_vector,h_t),dim=1)
    score_out_index = classifier(pred_vector)
    output_vector.append(score_out_index)

In [73]:
decode_out = torch.stack(output_vector).permute(1,0,2)
decode_out.shape

torch.Size([2, 9, 15])