<a href="https://colab.research.google.com/github/akashe/NLP/blob/main/Converstational_ChatBot_with_Transformers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In this notebook, we will train a converstational agent(Chatbot) using [Cornell Movie-Dialogs Corpus](https://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html) using the Transformer mentioned in the [Attention is all you Need](https://arxiv.org/pdf/1706.03762.pdf) paper.

This notebook dervies heavily from:

1. [Pytorch Chatbot Tutorial](https://pytorch.org/tutorials/beginner/chatbot_tutorial.html) but instead of RNN Seq2Seq model we will use Transformers.
2. [Attention is all you Need notebook](https://github.com/akashe/NLP/blob/main/Attention_is_all_you_need.ipynb) we wrote previosuly.


The Pytorch tutorial uses MAX_LENGHT of 10 and removes any conversation pairs with <unk> token and achieves a train NLLloss of 2.4606 after 4000 train steps.

Objective:
To get NLLloss less that 2.4606 in 4000 iterations using Transformers.

Things left:

1. reducing qa pairs with max_length and unk token
1. restricting to 4000 iterations with random batches
2. final qa interface.

### Importing libs and setting device

In [3]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import torch
from torch.jit import script, trace
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import csv
import random
import re
import os
import unicodedata
import codecs
from io import open
import itertools
import math
import spacy
from torchtext.data import Field,Dataset, Example,BucketIterator
import time

USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")

In [4]:
SEED = 1007

random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

### Downloading data

In [5]:
!wget -c http://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip

--2021-02-07 20:53:48--  http://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip
Resolving www.cs.cornell.edu (www.cs.cornell.edu)... 132.236.207.36
Connecting to www.cs.cornell.edu (www.cs.cornell.edu)|132.236.207.36|:80... connected.
HTTP request sent, awaiting response... 416 Requested Range Not Satisfiable

    The file is already fully retrieved; nothing to do.



In [6]:
!unzip -n /content/cornell_movie_dialogs_corpus.zip

Archive:  /content/cornell_movie_dialogs_corpus.zip


In [7]:
!head -10 /content/cornell\ movie-dialogs\ corpus/movie_lines.txt

L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!
L1044 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ They do to!
L985 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I hope so.
L984 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ She okay?
L925 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Let's go.
L924 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ Wow
L872 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Okay -- you're gonna need to learn how to lie.
L871 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ No
L870 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I'm kidding.  You know how sometimes you just become this "persona"?  And you don't know how to quit?
L869 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Like my fear of wearing pastels?


In [8]:
!head -10 /content/cornell\ movie-dialogs\ corpus/movie_conversations.txt

u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L194', 'L195', 'L196', 'L197']
u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L198', 'L199']
u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L200', 'L201', 'L202', 'L203']
u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L204', 'L205', 'L206']
u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L207', 'L208']
u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L271', 'L272', 'L273', 'L274', 'L275']
u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L276', 'L277']
u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L280', 'L281']
u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L363', 'L364']
u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L365', 'L366']


### Processing data

In the pytorch tutorial there are seperate function for:

1. Removing unicode character.
2. Giving spaces between fullstop and question mark. i.e. Akash? -> Akash ?
3. removing redundant spaces

Luckily spacy takes care of all of the above while tokenizing. We do have to take care of keeping the text in lower case.

Also, since we will be using bucketiterator, we need not worry about setting a max_len for the sentences. Add it to the fact that Transformers dont have an upper limit of src_len or trg_len

### Loading qa-pairs

In [9]:
def loadLines(fileName, fields):
    lines = {}
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(" +++$+++ ")
            # Extract fields
            lineObj = {}
            for i, field in enumerate(fields):
                lineObj[field] = values[i]
            lines[lineObj['lineID']] = lineObj
    return lines


# Groups fields of lines from `loadLines` into conversations based on *movie_conversations.txt*
def loadConversations(fileName, lines, fields):
    conversations = []
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(" +++$+++ ")
            # Extract fields
            convObj = {}
            for i, field in enumerate(fields):
                convObj[field] = values[i]
            # Convert string to list (convObj["utteranceIDs"] == "['L598485', 'L598486', ...]")
            utterance_id_pattern = re.compile('L[0-9]+')
            lineIds = utterance_id_pattern.findall(convObj["utteranceIDs"])
            # Reassemble lines
            convObj["lines"] = []
            for lineId in lineIds:
                convObj["lines"].append(lines[lineId])
            conversations.append(convObj)
    return conversations


# Extracts pairs of sentences from conversations
def extractSentencePairs(conversations):
    qa_pairs = []
    for conversation in conversations:
        # Iterate over all the lines of the conversation
        for i in range(len(conversation["lines"]) - 1):  # We ignore the last line (no answer for it)
            inputLine = conversation["lines"][i]["text"].strip()
            targetLine = conversation["lines"][i+1]["text"].strip()
            # Filter wrong samples (if one of the lists is empty)
            if inputLine and targetLine:
                qa_pairs.append([inputLine, targetLine])
    return qa_pairs

In [10]:
corpus = "/content/cornell movie-dialogs corpus/"

lines = {}
conversations = []
MOVIE_LINES_FIELDS = ["lineID", "characterID", "movieID", "character", "text"]
MOVIE_CONVERSATIONS_FIELDS = ["character1ID", "character2ID", "movieID", "utteranceIDs"]

# Load lines and process conversations
print("\nProcessing corpus...")
lines = loadLines(os.path.join(corpus, "movie_lines.txt"), MOVIE_LINES_FIELDS)
print("\nLoading conversations...")
conversations = loadConversations(os.path.join(corpus, "movie_conversations.txt"),
                                  lines, MOVIE_CONVERSATIONS_FIELDS)

qa_pairs = extractSentencePairs(conversations)
print(f'Total number of qa pairs = {len(qa_pairs)}')


Processing corpus...

Loading conversations...
Total number of qa pairs = 221282


In [11]:
MAX_LENGTH = 10
# Reducing number of examples by setting max_len = 10
length_mask = [len(i[0].split(" "))<MAX_LENGTH and len(i[1].split(" "))<MAX_LENGTH for i in qa_pairs]

qa_pairs = [pair for i,pair in enumerate(qa_pairs) if length_mask[i]]

### Setting up preprocessing necesseties

In [12]:
!python -m spacy download en

[38;5;2m✔ Download and installation successful[0m
You can now load the model via spacy.load('en_core_web_sm')
[38;5;2m✔ Linking successful[0m
/usr/local/lib/python3.6/dist-packages/en_core_web_sm -->
/usr/local/lib/python3.6/dist-packages/spacy/data/en
You can now load the model via spacy.load('en')


In [13]:
spacy_en = spacy.load('en')

In [14]:
print([i.text.lower() for i in spacy_en.tokenizer('What!!! No?? \u1F600')])

['what', '!', '!', '!', 'no', '?', '?', 'ὠ0']


In [15]:
def tokenize_en(text):
    """
    Tokenizes English text from a string into a list of strings
    """
    return [tok.text for tok in spacy_en.tokenizer(text)]

### Setting up Fields, Vocabs, Datasets and Iterators

In [16]:
# Defining a single field to keep the vocab same for encoder and decoder

SRC_TRG = Field(tokenize = tokenize_en, 
            init_token = '<sos>', 
            eos_token = '<eos>', 
            lower = True, 
            batch_first = True)

In [17]:
fields = [('src',SRC_TRG),('trg',SRC_TRG)]

chat_examples = [Example.fromlist([pair[0],pair[1]],fields) for pair in qa_pairs]
chat_dataset = Dataset(chat_examples,fields)

train_data, valid_data = chat_dataset.split(split_ratio=[0.85,0.15],random_state = random.seed(SEED))

In [18]:
print(vars(train_data[0]))

{'src': ['he', 'tapped', 'that', '.'], 'trg': ['naw', '!']}


In [19]:
SRC_TRG.build_vocab(train_data, max_size = 7823,min_freq = 3)
# TRG.build_vocab(train_data, max_size = 5000,min_freq = 3)

In [20]:
# There is no need for the this step. Ideally model should be able to handle unk tokens.
# Since we are trying to compare results with the pytorch tutorial,
# we are removing pairs with unk token in them
pruned_Examples = []
for i in train_data.examples:
  flag_ = True
  for j in i.src:
    if j not in SRC_TRG.vocab.stoi:
      flag_ = False
  for j in i.trg:
    if j not in SRC_TRG.vocab.stoi:
      flag_ = False
  if flag_:
    pruned_Examples.append(i)
  
print(f"Total number of train_examples after pruning pairs with unk tokens {len(pruned_Examples)}")
train_data.examples = pruned_Examples

Total number of train_examples after pruning pairs with unk tokens 57320


In [21]:
BATCH_SIZE = 64

train_iterator, valid_iterator = BucketIterator.splits(
    (train_data, valid_data), 
     batch_size = BATCH_SIZE,
     sort_key = lambda x: len(x.src),
     device = device)

### Model

In [22]:
class PositionalEncodingComponent(nn.Module):
  '''
  Class to encode positional information to tokens.
  

  '''
  def __init__(self,hid_dim,device,dropout=0.2,max_len=5000):
    super().__init__()

    assert hid_dim%2==0 # If not, it will result error in allocation to positional_encodings[:,1::2] later

    self.dropout = nn.Dropout(dropout)

    self.positional_encodings = torch.zeros(max_len,hid_dim)

    pos = torch.arange(0,max_len).unsqueeze(1) # pos : [max_len,1]
    div_term  = torch.exp(-torch.arange(0,hid_dim,2)*math.log(10000.0)/hid_dim) # Calculating value of 1/(10000^(2i/hid_dim)) in log space and then exponentiating it
    # div_term: [hid_dim//2]

    self.positional_encodings[:,0::2] = torch.sin(pos*div_term) # pos*div_term [max_len,hid_dim//2]
    self.positional_encodings[:,1::2] = torch.cos(pos*div_term) 

    self.positional_encodings = self.positional_encodings.unsqueeze(0) # To account for batch_size in inputs
    # positional_encodings : [1,max_len,hid_dim]
    
    self.device = device

  def forward(self,x):
    x = x + self.positional_encodings[:,:x.size(1)].detach().to(self.device)
    return self.dropout(x)


In [23]:
class FeedForwardComponent(nn.Module):
  '''
  Class for pointwise feed forward connections
  '''
  def __init__(self,hid_dim,pf_dim,dropout):
    super().__init__()

    self.dropout = nn.Dropout(dropout)

    self.fc1 = nn.Linear(hid_dim,pf_dim)
    self.fc2 = nn.Linear(pf_dim,hid_dim)

  def forward(self,x):

    # x : [batch_size,seq_len,hid_dim]
    x = self.dropout(torch.relu(self.fc1(x)))

    # x : [batch_size,seq_len,pf_dim]
    x = self.fc2(x)

    # x : [batch_size,seq_len,hid_dim]
    return x

In [24]:
class MultiHeadedAttentionComponent(nn.Module):
  '''
  Multiheaded attention Component. This implementation also supports mask. 
  The reason for mask that in Decoder, we don't want attention mechanism to get
  important information from future tokens.
  '''
  def __init__(self,hid_dim, n_heads, dropout, device):
    super().__init__()

    assert hid_dim % n_heads == 0 # Since we split hid_dims into n_heads

    self.hid_dim = hid_dim
    self.n_heads = n_heads # no of heads in 'multiheaded' attention
    self.head_dim = hid_dim//n_heads # dims of each head

    # Transformation from source vector to query vector
    self.fc_q = nn.Linear(hid_dim,hid_dim)

    # Transformation from source vector to key vector
    self.fc_k = nn.Linear(hid_dim,hid_dim)

    # Transformation from source vector to value vector
    self.fc_v = nn.Linear(hid_dim,hid_dim)

    self.fc_o = nn.Linear(hid_dim,hid_dim)

    self.dropout = nn.Dropout(dropout)

    # Used in self attention for smoother gradients
    self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)

  def forward(self,query,key,value,mask=None):

    #query : [batch_size, query_len, hid_dim]
    #key : [batch_size, key_len, hid_dim]
    #value : [batch_size, value_len, hid_dim]

    batch_size = query.shape[0]

    # Transforming quey,key,values
    Q = self.fc_q(query)
    K = self.fc_k(key)
    V = self.fc_v(value)

    #Q : [batch_size, query_len, hid_dim]
    #K : [batch_size, key_len, hid_dim]
    #V : [batch_size, value_len,hid_dim]

    # Changing shapes to acocmadate n_heads information
    Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
    K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
    V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)

    #Q : [batch_size, n_heads, query_len, head_dim]
    #K : [batch_size, n_heads, key_len, head_dim]
    #V : [batch_size, n_heads, value_len, head_dim]

    # Calculating alpha
    score = torch.matmul(Q,K.permute(0,1,3,2))/self.scale
    # score : [batch_size, n_heads, query_len, key_len]

    if mask is not None:
      score = score.masked_fill(mask==0,-1e10)

    alpha = torch.softmax(score,dim=-1)
    # alpha : [batch_size, n_heads, query_len, key_len]

    # Get the final self-attention  vector
    x = torch.matmul(self.dropout(alpha),V)
    # x : [batch_size, n_heads, query_len, head_dim]

    # Reshaping self attention vector to concatenate
    x = x.permute(0,2,1,3).contiguous()
    # x : [batch_size, query_len, n_heads, head_dim]

    x = x.view(batch_size,-1,self.hid_dim)
    # x: [batch_size, query_len, hid_dim]

    # Transforming concatenated outputs 
    x = self.fc_o(x)
    #x : [batch_size, query_len, hid_dim] 

    return x, alpha

In [25]:
class EncoderLayer(nn.Module):
  '''
  Operations of a single layer in an Encoder. An Encoder employs multiple such layers. Each layer contains:
  1) multihead attention, folllowed by
  2) LayerNorm of addition of multihead attention output and input to the layer, followed by
  3) FeedForward connections, followed by
  4) LayerNorm of addition of FeedForward outputs and output of previous layerNorm.
  '''
  def __init__(self, hid_dim,n_heads,pf_dim,dropout,device):
    super().__init__()
    
    self.self_attn_layer_norm = nn. LayerNorm(hid_dim) #Layer norm after self-attention
    self.ff_layer_norm = nn.LayerNorm(hid_dim) # Layer norm after FeedForward component

    self.self_attention = MultiHeadedAttentionComponent(hid_dim,n_heads,dropout,device)
    self.feed_forward = FeedForwardComponent(hid_dim,pf_dim,dropout)

    self.dropout = nn.Dropout(dropout)
    
  def forward(self,src,src_mask):
    
    # src : [batch_size, src_len, hid_dim]
    # src_mask : [batch_size, 1, 1, src_len]

    # get self-attention
    _src, _ = self.self_attention(src,src,src,src_mask)

    # LayerNorm after dropout
    src = self.self_attn_layer_norm(src + self.dropout(_src))
    # src : [batch_size, src_len, hid_dim]

    # FeedForward
    _src = self.feed_forward(src)

    # layerNorm after dropout
    src = self.ff_layer_norm(src + self.dropout(_src))
    # src: [batch_size, src_len, hid_dim]

    return src
    

In [26]:
class DecoderLayer(nn.Module):
  '''
  Operations of a single layer in an Decoder. An Decoder employs multiple such layers. Each layer contains:
  1) masked decoder self attention, followed by
  2) LayerNorm of addition of previous attention output and input to the layer,, followed by
  3) encoder self attention, followed by
  4) LayerNorm of addition of result of encoder self attention and its input, followed by
  5) FeedForward connections, followed by
  6) LayerNorm of addition of Feedforward results and its input.
  '''
  def __init__(self,hid_dim,n_heads,pf_dim,dropout,device):
    super().__init__()

    self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
    self.enc_attn_layer_norm = nn.LayerNorm(hid_dim)
    self.ff_layer_norm = nn.LayerNorm(hid_dim)

    # decoder self attention
    self.self_attention = MultiHeadedAttentionComponent(hid_dim,n_heads,dropout,device)

    # encoder attention
    self.encoder_attention = MultiHeadedAttentionComponent(hid_dim,n_heads,dropout,device)

    # FeedForward
    self.feed_forward = FeedForwardComponent(hid_dim,pf_dim,dropout)

    self.dropout = nn.Dropout(dropout)

  def forward(self,trg, enc_src,trg_mask,src_mask):

    #trg : [batch_size, trg_len, hid_dim]
    #enc_src : [batch_size, src_len, hid_dim]
    #trg_mask : [batch_size, 1, trg_len, trg_len]
    #src_mask : [batch_size, 1, 1, src_len]

    '''
    Decoder self-attention
    trg_mask is to force decoder to look only into past tokens and not get information from future tokens.
    Since we apply mask before doing softmax, the final self attention vector gets no information from future tokens.
    '''
    _trg, _ = self.self_attention(trg,trg,trg,trg_mask)

    # LayerNorm and dropout with resdiual connection
    trg = self.self_attn_layer_norm(trg + self.dropout(_trg))
    # trg : [batch_size, trg_len, hid_dim]

    '''
    Encoder attention:
    Query: trg
    key: enc_src
    Value : enc_src
    Why? 
    the idea here is to extract information from encoder outputs. So we use decoder self-attention as a query to find important values from enc_src
    and that is why we use src_mask, to avoid getting information from enc_src positions where it is equal to pad-id
    After we get necessary infromation from encoder outputs we add them back to decoder self-attention.
    '''
    _trg, encoder_attn_alpha = self.encoder_attention(trg,enc_src,enc_src,src_mask)

    # LayerNorm , residual connection and dropout
    trg = self.enc_attn_layer_norm(trg + self.dropout(_trg))
    # trg : [ batch_size, trg_len, hid_dim]

    # Feed Forward
    _trg = self.feed_forward(trg)

    # LayerNorm, residual connection and dropout
    trg = self.ff_layer_norm(trg + self.dropout(_trg))

    return trg, encoder_attn_alpha
    

In [27]:
class Encoder(nn.Module):
  '''
  An encoder, creates token embeddings and position embeddings and passes them through multiple encoder layers
  This is a bidirectional encoder. So we will reverse the input, add appropriate padding and positional embeddings and
  concat it with the forward input.
  '''
  def __init__(self,input_dim,hid_dim,n_layers,n_heads,pf_dim,dropout,device,max_length = 5000):
    super().__init__()
    self.device = device

    self.tok_embedding = nn.Embedding(input_dim,hid_dim)
    self.pos_embedding = PositionalEncodingComponent(hid_dim,device,dropout,max_length)

    # encoder layers
    self.layers = nn.ModuleList([EncoderLayer(2*hid_dim,n_heads,pf_dim,dropout,device) for _ in range(n_layers)])

    self.dropout = nn.Dropout(dropout)

    self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)

    self.hid_dim = hid_dim

  def reverse_src(self,src,src_mask):
    # src : [batch_size, src_len]
    # src_mask : [batch_size,1,1,src_len]

    total_non_padded_values = src_mask.sum(dim=-1).squeeze(1)
    # total_non_padded_values : [batch_size,1]

    reversed_src = torch.ones(src.shape,dtype=torch.int64)
    # reversed_src : [batch_size,src_len]

    # flip the src :[10,9,8] -> [8,9,10]
    src = torch.flip(src,dims=[-1])
    # src : [batch_size, src_len]
    
    for pos,(i,j) in enumerate(zip(src,total_non_padded_values)):
      j = int(j.item())
      reversed_src[pos] = torch.cat([i[-j:],i[:-j]],dim=-1)

    return reversed_src
  
  def forward(self,src,src_mask):

    # src : [batch_size, src_len]
    # src_mask : [batch_size,1,1,src_len]

    batch_size = src.shape[0]
    src_len = src.shape[1]

    # reversed source
    reversed_src = self.reverse_src(src,src_mask).to(device)
    # reversed_src : [batch_size,src_len]

    tok_embeddings = self.tok_embedding(src)*self.scale
    reversed_tok_embeddings = self.tok_embedding(reversed_src)*self.scale

    # token plus position embeddings
    src  = self.pos_embedding(tok_embeddings)
    reversed_src = self.pos_embedding(reversed_tok_embeddings)

    concatenated_src = torch.ones([batch_size,src_len,self.hid_dim*2]).to(device)
    # Interleaving concatenated source such that f1,b1,f2,b2....
    concatenated_src[:,:,0::2] = src
    concatenated_src[:,:,1::2] = reversed_src

    src = concatenated_src

    for layer in self.layers:
      src = layer(src,src_mask)
    # src : [batch_size, src_len, 2*hid_dim]

    return src

In [28]:
class Decoder(nn.Module):
  '''
  An decoder, creates token embeddings and position embeddings and passes them through multiple decoder layers
  '''
  def __init__(self,output_dim,hid_dim,n_layers,n_heads,pf_dim,dropout,device,max_length= 5000):
    super().__init__()

    self.device = device

    self.tok_embedding = nn.Embedding(output_dim,hid_dim)
    self.pos_embedding = PositionalEncodingComponent(hid_dim,device,dropout,max_length)

    # decoder layers
    self.layers = nn.ModuleList([DecoderLayer(hid_dim,n_heads,pf_dim,dropout,device) for _ in range(n_layers)])

    # convert decoder outputs to real outputs
    self.fc_out = nn.Linear(hid_dim,output_dim)

    self.dropout = nn.Dropout(dropout)

    self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)

  def forward(self, trg, enc_src,trg_mask,src_mask):
    
    #trg : [batch_size, trg_len]
    #enc_src : [batch_size, src_len, hid_dim]
    #trg_mask : [batch_size, 1, trg_len, trg_len]
    #src_mask : [batch_size, 1, 1, src_len]

    batch_size = trg.shape[0]
    trg_len = trg.shape[1]

    tok_embeddings = self.tok_embedding(trg)*self.scale

    # token plus pos embeddings
    trg = self.pos_embedding(tok_embeddings)
    # trg : [batch_size, trg_len, hid_dim]

    # Pass trg thorugh decoder layers
    for layer in self.layers:
      trg, encoder_attention = layer(trg,enc_src,trg_mask,src_mask)
    
    # trg : [batch_size,trg_len,hid_dim]
    # encoder_attention :  [batch_size, n_head,trg_len, src_len]

    # Convert to outputs
    output = self.fc_out(trg)
    # output : [batch_size, trg_len, output_dim]
    
    return output, encoder_attention

In [29]:
class Seq2Seq(nn.Module):
  def __init__(self, encoder, decoder, src_pad_idx, trg_pad_idx, device):
    super().__init__()
    self.encoder = encoder
    self.decoder = decoder
    self.src_pad_idx = src_pad_idx
    self.trg_pad_idx = trg_pad_idx
    self.device = device

  def make_src_mask(self,src):
    # src : [batch_size, src_len]

    # Masking pad values
    src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
    # src_mask : [batch_size,1,1,src_len]

    return src_mask

  def make_trg_mask(self,trg):
    # trg : [batch_size, trg_len]

    # Masking pad values
    trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)
    # trg_pad_mask : [batch_size,1,1, trg_len]

    # Masking future values
    trg_len = trg.shape[1]
    trg_sub_mask = torch.tril(torch.ones((trg_len,trg_len),device= self.device)).bool()
    # trg_sub_mask : [trg_len, trg_len]

    # combine both masks
    trg_mask = trg_pad_mask & trg_sub_mask
    # trg_mask = [batch_size,1,trg_len,trg_len]

    return trg_mask

  def forward(self,src,trg):

    # src : [batch_size, src_len]
    # trg : [batch_size, trg_len]

    src_mask = self.make_src_mask(src)
    trg_mask = self.make_trg_mask(trg)

    # src_mask : [ batch_size, 1,1,src_len]
    # trg_mask : [batch_size, 1, trg_len, trg_len]

    enc_src = self.encoder(src,src_mask)
    #enc_src : [batch_size, src_len, hid_dim]

    output, encoder_decoder_attention = self.decoder(trg,enc_src,trg_mask,src_mask)
    # output : [batch_size, trg_len, output_dim]
    # encoder_decoder_attention : [batch_size, n_heads, trg_len, src_len]

    return output, encoder_decoder_attention

### Initializing Network

In [54]:
INPUT_DIM = len(SRC_TRG.vocab)
OUTPUT_DIM = len(SRC_TRG.vocab)
HID_DIM = 256
ENC_LAYERS = 1
DEC_LAYERS = 1
ENC_HEADS = 8
DEC_HEADS = 8
ENC_PF_DIM = 512
DEC_PF_DIM = 512
ENC_DROPOUT = 0.1
DEC_DROPOUT = 0.1

enc = Encoder(INPUT_DIM, 
              HID_DIM, 
              ENC_LAYERS, 
              ENC_HEADS, 
              ENC_PF_DIM, 
              ENC_DROPOUT, 
              device)

dec = Decoder(OUTPUT_DIM, 
              2*HID_DIM, 
              DEC_LAYERS, 
              DEC_HEADS, 
              DEC_PF_DIM, 
              DEC_DROPOUT, 
              device)

SRC_PAD_IDX = SRC_TRG.vocab.stoi[SRC_TRG.pad_token]
TRG_PAD_IDX = SRC_TRG.vocab.stoi[SRC_TRG.pad_token]

model = Seq2Seq(enc, dec, SRC_PAD_IDX, TRG_PAD_IDX, device).to(device)

### Initialize weights and total model params

In [55]:
def initialize_weights(m):
    if hasattr(m, 'weight') and m.weight.dim() > 1:
        nn.init.xavier_uniform_(m.weight.data)

model.apply(initialize_weights);

In [56]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 14,234,003 trainable parameters


### Learning rate, criterion and optimizer


In [57]:
LEARNING_RATE = 0.0001
decoder_learning_ratio = 5

encoder_optimizer = torch.optim.Adam(model.encoder.parameters(), lr = LEARNING_RATE)
decoder_optimizer = torch.optim.Adam(model.decoder.parameters(),lr= decoder_learning_ratio*LEARNING_RATE)

optimizer = (encoder_optimizer,decoder_optimizer)

criterion = nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX,size_average= True) # size average = True calculates mean loss only of non pad idx values



### Train and Eval Loop

In [58]:
def train(model, iterator, optimizer, criterion, clip,iters):
    
    model.train()
    
    epoch_loss = 0
    
    for i, batch in enumerate(iterator):
        

        src = batch.src
        trg = batch.trg
        
        optimizer[0].zero_grad()
        optimizer[1].zero_grad()
        
        output, _ = model(src, trg[:,:-1])
                
        #output = [batch size, trg len - 1, output dim]
        #trg = [batch size, trg len]
            
        output_dim = output.shape[-1]
            
        output = output.contiguous().view(-1, output_dim)
        trg = trg[:,1:].contiguous().view(-1)
                
        #output = [batch size * trg len - 1, output dim]
        #trg = [batch size * trg len - 1]
            
        loss = criterion(output, trg)
        
        loss.backward()

        if iters <= 4000:
          print(f'Train loss at {iters} iteration is {loss.item()} ')
          
        if iters >=4000:
          break

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        
        optimizer[0].step()
        optimizer[1].step()
        
        epoch_loss += loss.item()
        
        iters += 1

    return epoch_loss / len(iterator), iters

In [59]:
def evaluate(model, iterator, criterion):
    
    model.eval()
    
    epoch_loss = 0
    
    with torch.no_grad():
    
        for i, batch in enumerate(iterator):

            src = batch.src
            trg = batch.trg

            output, _ = model(src, trg[:,:-1])
            
            #output = [batch size, trg len - 1, output dim]
            #trg = [batch size, trg len]
            
            output_dim = output.shape[-1]
            
            output = output.contiguous().view(-1, output_dim)
            trg = trg[:,1:].contiguous().view(-1)
            
            #output = [batch size * trg len - 1, output dim]
            #trg = [batch size * trg len - 1]
            
            loss = criterion(output, trg)

            epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)

### Runner Loop

In [60]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [61]:
N_EPOCHS = 5
CLIP = 1

best_valid_loss = float('inf')

iters = 0

for epoch in range(N_EPOCHS):
    
    start_time = time.time()
    
    train_loss,iters = train(model, train_iterator, optimizer, criterion, CLIP ,iters)
    # valid_loss = evaluate(model, valid_iterator, criterion)
    
    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    # if valid_loss < best_valid_loss:
    #     best_valid_loss = valid_loss
    #     torch.save(model.state_dict(), 'chat_bot_model.pt')
    
    # print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    # print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    # print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

Train loss at 0 iteration is 9.00206470489502 
Train loss at 1 iteration is 8.232087135314941 
Train loss at 2 iteration is 7.64331579208374 
Train loss at 3 iteration is 7.442013263702393 
Train loss at 4 iteration is 7.192834854125977 
Train loss at 5 iteration is 6.757341384887695 
Train loss at 6 iteration is 6.532083988189697 
Train loss at 7 iteration is 6.337754726409912 
Train loss at 8 iteration is 6.048821926116943 
Train loss at 9 iteration is 5.9464240074157715 
Train loss at 10 iteration is 5.775899410247803 
Train loss at 11 iteration is 5.612894535064697 
Train loss at 12 iteration is 5.5215582847595215 
Train loss at 13 iteration is 5.067715167999268 
Train loss at 14 iteration is 5.517705917358398 
Train loss at 15 iteration is 5.209707260131836 
Train loss at 16 iteration is 5.2359209060668945 
Train loss at 17 iteration is 5.240084648132324 
Train loss at 18 iteration is 5.263883590698242 
Train loss at 19 iteration is 4.915689945220947 
Train loss at 20 iteration is