In [1]:
import torch as t
import csv
import ast
import re

# Data Preprocessing

## 1. Visualise the first few lines of the tsv files

In [2]:
movie_lines_filename = "cornell_movie_corpus/movie_lines.tsv"
lines_to_visualise = 10

# Open the TSV file in read mode
with open(movie_lines_filename, "r", encoding="utf-8") as file:
    # Iterate over each line in the file
    for i, line in enumerate(file):
        # Remove any leading/trailing whitespace and split the line by tabs
        row = line.strip().split("\t")
        
        # Process the row as desired
        print(row)
        
        # Check if the desired number of lines has been reached
        if i + 1 >= lines_to_visualise:
            break

    print("\n")
    print(type(row))

['L1045', 'u0', 'm0', 'BIANCA', 'They do not!']
['L1044', 'u2', 'm0', 'CAMERON', 'They do to!']
['L985', 'u0', 'm0', 'BIANCA', 'I hope so.']
['L984', 'u2', 'm0', 'CAMERON', 'She okay?']
['L925', 'u0', 'm0', 'BIANCA', "Let's go."]
['L924', 'u2', 'm0', 'CAMERON', 'Wow']
['L872', 'u0', 'm0', 'BIANCA', "Okay -- you're gonna need to learn how to lie."]
['L871', 'u2', 'm0', 'CAMERON', 'No']
['"L870', 'u0', 'm0', 'BIANCA', 'I\'m kidding.  You know how sometimes you just become this ""persona""?  And you don\'t know how to quit?"']
['L869', 'u0', 'm0', 'BIANCA', 'Like my fear of wearing pastels?']


<class 'list'>


In [3]:
movie_conv_filename = "cornell_movie_corpus/movie_conversations.tsv"
lines_to_visualise = 10

# Open the TSV file in read mode
with open(movie_conv_filename, "r", encoding="utf-8") as file:
    # Iterate over each line in the file
    for i, line in enumerate(file):
        # Remove any leading/trailing whitespace and split the line by tabs
        row = line.strip().split("\t")
        
        # Process the row as desired
        print(row)
        
        # Check if the desired number of lines has been reached
        if i + 1 >= lines_to_visualise:
            break
     

    print("\n")
    print(type(row))

['u0', 'u2', 'm0', "['L194' 'L195' 'L196' 'L197']"]
['u0', 'u2', 'm0', "['L198' 'L199']"]
['u0', 'u2', 'm0', "['L200' 'L201' 'L202' 'L203']"]
['u0', 'u2', 'm0', "['L204' 'L205' 'L206']"]
['u0', 'u2', 'm0', "['L207' 'L208']"]
['u0', 'u2', 'm0', "['L271' 'L272' 'L273' 'L274' 'L275']"]
['u0', 'u2', 'm0', "['L276' 'L277']"]
['u0', 'u2', 'm0', "['L280' 'L281']"]
['u0', 'u2', 'm0', "['L363' 'L364']"]
['u0', 'u2', 'm0', "['L365' 'L366']"]


<class 'list'>


## 2. Create relevant dictionaries based on the tsv file datas 

In [4]:
movie_lines = {}
movie_lines_fields = ['line ID','user ID','movie ID', 'char name', 'text']

# Open the TSV file in read mode
with open(movie_lines_filename, "r", encoding="utf-8") as file:
    for i, line in enumerate(file):
        
        # Remove any leading/trailing whitespace and split the line by tabs
        line_parts = line.strip().split("\t")
        
        if(len(line_parts) > 4):             # Handling the bad datas where there are empty texts
            
            # Extract the individual parts
            line_id = line_parts[0]
            user_id = line_parts[1]
            movie_id = line_parts[2]
            character_name = line_parts[3]
            text = line_parts[4]

        # Create a dictionary for the line
        line_dict = {
            'lineID': line_id,
            'userID': user_id,
            'movieID': movie_id,
            'charName': character_name,
            'text': text
        }

        # Add the line dictionary to the result dictionary
        movie_lines[line_id] = line_dict


In [5]:
# movie_lines is a dictionary of dictionaries
movie_lines

{'L1045': {'lineID': 'L1045',
  'userID': 'u0',
  'movieID': 'm0',
  'charName': 'BIANCA',
  'text': 'They do not!'},
 'L1044': {'lineID': 'L1044',
  'userID': 'u2',
  'movieID': 'm0',
  'charName': 'CAMERON',
  'text': 'They do to!'},
 'L985': {'lineID': 'L985',
  'userID': 'u0',
  'movieID': 'm0',
  'charName': 'BIANCA',
  'text': 'I hope so.'},
 'L984': {'lineID': 'L984',
  'userID': 'u2',
  'movieID': 'm0',
  'charName': 'CAMERON',
  'text': 'She okay?'},
 'L925': {'lineID': 'L925',
  'userID': 'u0',
  'movieID': 'm0',
  'charName': 'BIANCA',
  'text': "Let's go."},
 'L924': {'lineID': 'L924',
  'userID': 'u2',
  'movieID': 'm0',
  'charName': 'CAMERON',
  'text': 'Wow'},
 'L872': {'lineID': 'L872',
  'userID': 'u0',
  'movieID': 'm0',
  'charName': 'BIANCA',
  'text': "Okay -- you're gonna need to learn how to lie."},
 'L871': {'lineID': 'L871',
  'userID': 'u2',
  'movieID': 'm0',
  'charName': 'CAMERON',
  'text': 'No'},
 '"L870': {'lineID': '"L870',
  'userID': 'u0',
  'movieID

In [6]:
# First element of the movie_lines dictionary
list(movie_lines.items())[0]

('L1045',
 {'lineID': 'L1045',
  'userID': 'u0',
  'movieID': 'm0',
  'charName': 'BIANCA',
  'text': 'They do not!'})

In [7]:
# Demonstrating the regex function used
import re

line_numbers_str = "['L194' 'L195' 'L196' 'L197']"

# Extract the line numbers using regular expressions
line_numbers_list = re.findall(r"'(\w+)'", line_numbers_str)

print(line_numbers_list)
print(type(line_numbers_list))

['L194', 'L195', 'L196', 'L197']
<class 'list'>


In [8]:
movie_conv_fields = ['charID1','charID2','movieID', 'lineIDs']
conversations = []

# Open the TSV file in read mode
with open(movie_conv_filename, "r", encoding="utf-8") as file:
    
    # Iterate over each line in the file
    for i, line in enumerate(file):
        conv_dict = {}                            # Declare an empty dictionary
        row = line.strip().split("\t")            #  ['u0', 'u2', 'm0', "['L194' 'L195' 'L196' 'L197']"]
        for j, conv_field in enumerate(movie_conv_fields):
            if(conv_field == 'lineIDs'):
                row[j] = re.findall(r"'(\w+)'", row[j])   # matches any alphanumeric characters (\w+) enclosed in single quotes
            conv_dict[conv_field] = row[j] 
        
        conversations.append(conv_dict)

In [9]:
# conversations is a list of dictionaries
# First element of the conversations list
conversations[0]

{'charID1': 'u0',
 'charID2': 'u2',
 'movieID': 'm0',
 'lineIDs': ['L194', 'L195', 'L196', 'L197']}

In [10]:
c = [
    {
        'charID1': 'u0',
        'charID2': 'u2',
        'movieID': 'm0',
        'lineIDs': ['L194', 'L195', 'L196', 'L197']
    },
    # other dictionaries
    ]

m_lines = {
    'L194': {'lineID': 'L194', 'userID': 'u0', 'movieID': 'm0', 'charName': 'BIANCA', 'text': 'They do not!'},
    'L195': {'lineID': 'L195', 'userID': 'u2', 'movieID': 'm0', 'charName': 'CAMERON', 'text': 'They do to!'},
    'L500': {'lineID': 'L500', 'userID': 'u2', 'movieID': 'm0', 'charName': 'CAMERON', 'text': 'They do to!'},
    # other dictionaries
}

for item in c:
    line_ids = item['lineIDs']
    lines = []
    for line_id in line_ids:
        line = m_lines.get(line_id)
        if line:
            lines.append(line)
    item['lines'] = lines

print(c)

[{'charID1': 'u0', 'charID2': 'u2', 'movieID': 'm0', 'lineIDs': ['L194', 'L195', 'L196', 'L197'], 'lines': [{'lineID': 'L194', 'userID': 'u0', 'movieID': 'm0', 'charName': 'BIANCA', 'text': 'They do not!'}, {'lineID': 'L195', 'userID': 'u2', 'movieID': 'm0', 'charName': 'CAMERON', 'text': 'They do to!'}]}]


## 3. Merging the processed conversations list and movie_lines dictionary 
- Add a new key 'lines' in each dictionary of the conversations list of dictionaries

In [11]:
for conv in conversations:
    line_ids = conv['lineIDs']
    lines = []
    for line_id in line_ids:
        line = movie_lines.get(line_id)
        if line:
            lines.append(line)
    conv['lines'] = lines

In [12]:
conversations[0]

{'charID1': 'u0',
 'charID2': 'u2',
 'movieID': 'm0',
 'lineIDs': ['L194', 'L195', 'L196', 'L197'],
 'lines': [{'lineID': 'L194',
   'userID': 'u0',
   'movieID': 'm0',
   'charName': 'BIANCA',
   'text': 'Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.'},
  {'lineID': 'L195',
   'userID': 'u2',
   'movieID': 'm0',
   'charName': 'CAMERON',
   'text': "Well I thought we'd start with pronunciation if that's okay with you."},
  {'lineID': 'L196',
   'userID': 'u0',
   'movieID': 'm0',
   'charName': 'BIANCA',
   'text': 'Not the hacking and gagging and spitting part.  Please.'},
  {'lineID': 'L197',
   'userID': 'u2',
   'movieID': 'm0',
   'charName': 'CAMERON',
   'text': "Okay... then how 'bout we try out some French cuisine.  Saturday?  Night?"}]}

## 4. Extract q/a pairs

In [13]:
qa_pairs = []

for conv in conversations:
    lines = conv['lines']
    for i in range(len(lines) - 1):
        q = lines[i]['text'].strip()       # remove trailing and leading whitespace characters 
        a = lines[i+1]['text'].strip()
        if len(q) > 0 and len(a) > 0:       # filter empty lists
            qa_pairs.append([q, a])

        
len(qa_pairs)

217150

In [14]:
qa_pairs

[['Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.',
  "Well I thought we'd start with pronunciation if that's okay with you."],
 ["Well I thought we'd start with pronunciation if that's okay with you.",
  'Not the hacking and gagging and spitting part.  Please.'],
 ['Not the hacking and gagging and spitting part.  Please.',
  "Okay... then how 'bout we try out some French cuisine.  Saturday?  Night?"],
 ["You're asking me out.  That's so cute. What's your name again?",
  'Forget it.'],
 ["No no it's my fault -- we didn't have a proper introduction ---",
  'Cameron.'],
 ['Cameron.',
  "The thing is Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does."],
 ["The thing is Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.",
  'Seems like she could get a date easy enough...'],
 ['Why?',


## 5. Writing a new tsv file using the qa_pairs

In [15]:
filename = "cornell_movie_corpus/formatted_qa_pairs.tsv"

print("Writting a newly formatted file .....")
with open(filename, "w", encoding="utf-8") as file:
    for pair in qa_pairs:
        q, a = pair
        file.write(f"{q}\t{a}\n")     # format the QnA with a TAB character b/w and a newline b/w every Qna pair 

print("File written successfully.")

Writting a newly formatted file .....
File written successfully.


In [16]:
with open(filename, "rb") as file:
    lines = file.readlines()
for line in lines[:5]:
    print(line)

b"Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\tWell I thought we'd start with pronunciation if that's okay with you.\r\n"
b"Well I thought we'd start with pronunciation if that's okay with you.\tNot the hacking and gagging and spitting part.  Please.\r\n"
b"Not the hacking and gagging and spitting part.  Please.\tOkay... then how 'bout we try out some French cuisine.  Saturday?  Night?\r\n"
b"You're asking me out.  That's so cute. What's your name again?\tForget it.\r\n"
b"No no it's my fault -- we didn't have a proper introduction ---\tCameron.\r\n"


## 6. Processing the words
- Defining a WordIndexer class
- Define helper functions to preprocess the text such as - convert unicode to ASCII, normalise the strings
- Read from the formatted qa pairs tsv file and store the preprocessed qa pairs in pair 

In [17]:
# Defining tokens
PAD_token = 0             # for padding short sentences
SOS_token = 1             # start of sentence token
EOS_token = 2             # End of sentence

class WordIndexer:
    def __init__(self, corpus_name):
        """
        Initializes a WordIndexer object.

        Args:
        - corpus_name: A string representing the name of the corpus or dataset.
                                                                                        """
        self.corpus_name = corpus_name
        self.word_to_index = {'PAD' : PAD_token, 'SOS': SOS_token, 'EOS': EOS_token}
        self.word_counts = {}           # Stores the count of each word in the corpus
        self.index_to_word = {PAD_token:'PAD', SOS_token:'SOS', EOS_token:'EOS'}
        self.num_words = 3             # to include the 3 special tokens: PAD, SOS, EOS

    def add_word(self, word):
        """
        Adds a word to the vocabulary/WordIndexer object

        Args:
        - word: A string representing the word to be added.
                                                                    """
        if word not in self.word_to_index:
            index = len(self.word_to_index) + 1
            self.word_to_index[word] = index
            self.index_to_word[index] = word
            self.word_counts[word] = 1
            self.num_words += 1
        else:
            self.word_counts[word] += 1

    def add_sentence(self, sentence):
        """
            Adds all words in a sentence to the vocabulary/WordIndexer object

        Args:
        - sentence: A string representing the sentence.
                                                                """
        words = sentence.split()
        for word in words:
            self.add_word(word)
    
    
    def trim_less_freq_words(self, threshold):
        """
            Remove word below a certain count threshold

        Args:
        - threshold: int : the minimum count for a word to be retained.
                                                                          """
        words_to_remove = []
        for word, count in self.word_counts.items():
            if count < threshold:                        # remove if count < threshold
                words_to_remove.append(word)

        for word in words_to_remove:
            index = self.word_to_index[word]
            del self.word_to_index[word]
            del self.index_to_word[index]
            del self.word_counts[word]
            self.num_words -= 1

    # Retrieve the index of a word from the vocabulary
    def get_word_index(self, word):  
        return self.word_to_index.get(word)

    # Retruns the word corresponding to an index from the vocabulary
    def get_index_word(self, index):
        return self.index_to_word.get(index)


### A sample Run to test/show the working of the WordIndexer class

In [18]:
W = WordIndexer("sample")
W.add_sentence("My name is Donal. The name is Sherlock Holmes")
W.word_to_index

{'PAD': 0,
 'SOS': 1,
 'EOS': 2,
 'My': 4,
 'name': 5,
 'is': 6,
 'Donal.': 7,
 'The': 8,
 'Sherlock': 9,
 'Holmes': 10}

In [19]:
W.word_counts

{'My': 1,
 'name': 2,
 'is': 2,
 'Donal.': 1,
 'The': 1,
 'Sherlock': 1,
 'Holmes': 1}

In [20]:
W.index_to_word

{0: 'PAD',
 1: 'SOS',
 2: 'EOS',
 4: 'My',
 5: 'name',
 6: 'is',
 7: 'Donal.',
 8: 'The',
 9: 'Sherlock',
 10: 'Holmes'}

In [21]:
W.num_words

10

In [22]:
W.trim_less_freq_words(2)

In [23]:
W.word_to_index

{'PAD': 0, 'SOS': 1, 'EOS': 2, 'name': 5, 'is': 6}

In [24]:
!pip install unidecode





In [25]:
from unidecode import unidecode

# Thanks to stackoverflow "https://stackoverflow.com/a/518232/2809427"
def unicode_to_ascii(text):
    """
    Convert Unicode text to ASCII by transliterating non-ASCII characters to their closest ASCII equivalents.
    
    Args:
        text (str): The Unicode text to convert.
    Returns:
        str: The converted ASCII text.
                                                                                                                """
    return unidecode(text)

In [26]:
unicode_to_ascii('北亰')

'Bei Jing '

In [27]:
unicode_to_ascii('François')

'Francois'

In [28]:
unicode_to_ascii('kožušček')

'kozuscek'

In [29]:
def normalize_string(text):
    """
    Normalize a string by converting it to lowercase, adding space before punctuation marks,
    removing non-letter characters, and removing sequences of whitespace.
    
    Args:
        text (str): The input string to normalize.
    
    Returns:
        str: The normalized string.
    """
    # Convert the text to lowercase
    normalized_text = unicode_to_ascii(text.lower())         # no need to strip()

    # Add space before punctuation marks
    normalized_text = re.sub(r"([.,!?])", r" \1", normalized_text)

    # Remove non-letter characters
    normalized_text = re.sub(r"[^a-zA-Z.,!? ]", "", normalized_text)

    # Remove sequences of whitespace
    normalized_text = re.sub(r"\s+", " ", normalized_text)

    return normalized_text.strip()

In [30]:
normalize_string("    AbC123aa!s's    dd?    ")

'abcaa !ss dd ?'

In [31]:
def read_tsv_file(file_path):
    data = []
    
    print("Reading and processing file ...")
    with open(file_path, 'r', encoding='utf-8') as tsv_file:
        reader = csv.reader(tsv_file, delimiter='\t')
        for row in reader:
            if len(row) == 2:
                q = normalize_string(row[0])
                a = normalize_string(row[1])
                data.append([q, a])
    print("Done reading.")
    return data

In [32]:
data_filename = "cornell_movie_corpus/formatted_qa_pairs.tsv"
pairs = read_tsv_file(data_filename)

Reading and processing file ...
Done reading.


In [33]:
# Visualise the first 2 pairs
pairs[0:2]

[['can we make this quick ? roxanne korrine and andrew barrett are having an incredibly horrendous public break up on the quad . again .',
  'well i thought wed start with pronunciation if thats okay with you .'],
 ['well i thought wed start with pronunciation if thats okay with you .',
  'not the hacking and gagging and spitting part . please .']]

In [34]:
len(pairs)

217150

## Filtering the text 

In [35]:
def filter_qa_pairs(data, q_threshold=12, a_threshold=12):
    """
    Filters the question-answer pairs beyond a threshold length of words,
    Args:
    data = (the list of lists containing the normalized question-answer pairs) 
    q_threshold, a_threshold = (the maximum number of words allowed for the question and answer).
                                                                                                        """
    filtered_data = []
    for pair in data:
        q = pair[0]
        a = pair[1]
        if len(q.split()) <= q_threshold and len(a.split()) <= a_threshold:
            filtered_data.append(pair)
    
    return filtered_data


In [36]:
print(f"There are {len(pairs)} pairs/conversations in the dataset")
filtered_pairs = filter_qa_pairs(pairs)
print(f"After filtering (by threshold) , there are {len(filtered_pairs)} pairs/conversations")

There are 217150 pairs/conversations in the dataset
After filtering (by threshold) , there are 100847 pairs/conversations


### Instantiate an object of the class WordIndexer

In [37]:
voc = WordIndexer("cornell_movie_corpus")

In [38]:
for pair in filtered_pairs:
    q, a = pair[0], pair[1]
    voc.add_sentence(q)
    voc.add_sentence(a)

print(f"Count of words in the voc = {voc.num_words}")
for pair in filtered_pairs[:5]:
    print(pair)
    

Count of words in the voc = 28151
['no no its my fault we didnt have a proper introduction', 'cameron .']
['gosh if only we could find kat a boyfriend . . .', 'let me see what i can do .']
['cesc ma tete . this is my head', 'right . see ? youre ready for the quiz .']
['thats because its such a nice one .', 'forget french .']
['how is our little find the wench a date plan progressing ?', 'well theres someone i think might be']


## Remove those qa pairs if any word of  'q' or 'a' occurs less than a threshold value

In [39]:
def filter_by_word_frequency(voc, qa_pairs, threshold):

    # Remove words below the threshold from the class instance
    voc.trim_less_freq_words(threshold)

    # Filter QA pairs based on word frequency
    keep_pairs = []
    for pair in qa_pairs:
        q, a = pair[0], pair[1]
        q_words = q.split()
        a_words = a.split()

        # Check if any word in 'q' or 'a' is below the threshold
        if any(voc.word_counts.get(word, 0) < threshold for word in q_words + a_words):
            continue                                    # Skip this QA pair if any word is below the threshold

        # Append the QA pair to the filtered list
        keep_pairs.append(pair)

    print("Trimmed from {} pairs to {}, {:.4f} of total".format(len(qa_pairs), len(keep_pairs), len(keep_pairs) / len(qa_pairs)))
    
    return keep_pairs



In [40]:
threshold = 3    # trial and error
pairs = filter_by_word_frequency(voc, filtered_pairs, threshold)

Trimmed from 100847 pairs to 83359, 0.8266 of total


In [41]:
pairs[0:5]

[['no no its my fault we didnt have a proper introduction', 'cameron .'],
 ['gosh if only we could find kat a boyfriend . . .',
  'let me see what i can do .'],
 ['thats because its such a nice one .', 'forget french .'],
 ['there .', 'where ?'],
 ['you have my word . as a gentleman', 'youre sweet .']]

# Data Preparation
- voc = The WordIndexer Object instantiated, it contains the words of the dataset and corresponding indexes  
- pairs = The question-answer pairs after all data preprocessing, in the form of [['q1', 'a1'], [  ], ......,[  ]]

In [50]:
def sentence2indexes(voc, sentence):
    """ Given a sentence as input and returns a list of indexes 
        corresponding to the words in the sentence,
        followed by the EOS token.                               
    Args: 
        voc : The WordIndexer class instantiated.
        sentence : string : The sentence
    returns:
        list: the list of indices
                                                                    """
    words = sentence.split()
    indexes = [voc.get_word_index(word) for word in words]
    indexes.append(EOS_token)
    
    return indexes

In [49]:
pairs[0][0]

'no no its my fault we didnt have a proper introduction'

In [52]:
sentence2indexes(voc, pairs[0][0])

[4, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 2]

In [57]:
# TEST ON A BATCH_SIZE = 6
batch_size = 8
inp = []
op = []
for pair in pairs[:batch_size]:
    inp.append(pair[0])
    op.append(pair[1])
    
indexes = [sentence2indexes(voc, sentence) for sentence in inp]
indexes

[[4, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 2],
 [16, 17, 18, 8, 19, 20, 21, 11, 22, 15, 15, 15, 2],
 [43, 44, 5, 45, 11, 46, 47, 15, 2],
 [63, 15, 2],
 [65, 10, 6, 66, 15, 67, 11, 68, 2],
 [70, 15, 2],
 [10, 85, 76, 37, 2],
 [27, 87, 37, 2]]

In [78]:
def zero_pad_rows(index_list, pad_token=PAD_token):
    """
    Zero-pads the rows in the index_list so that all rows have the same length.
    Transposes the resulting list of lists.

    Args:
        index_list (list[list[int]]): List of index lists to be zero-padded and transposed.
        pad_token (int, optional): The padding token to use. Defaults to PAD_token = 0

    Returns:
        list[tuples]: Transposed and zero-padded list of tuples.
                    Shape = ()

    """
    max_length = max(len(row) for row in index_list)
    
    # Zero pad the rows
    padded_list = [row + [pad_token] * (max_length - len(row)) for row in index_list]
    
    # Transpose the list[list]
    transposed = list(zip(*padded_list))
    
    return transposed

### Testing the function - zero_pad_rows( ) and sentence2index()

In [79]:
padded_list = zero_pad_rows(indexes)
padded_list

[(4, 16, 43, 63, 65, 70, 10, 27),
 (4, 17, 44, 15, 10, 15, 85, 87),
 (5, 18, 5, 2, 6, 2, 76, 37),
 (6, 8, 45, 0, 66, 0, 37, 2),
 (7, 19, 11, 0, 15, 0, 2, 0),
 (8, 20, 46, 0, 67, 0, 0, 0),
 (9, 21, 47, 0, 11, 0, 0, 0),
 (10, 11, 15, 0, 68, 0, 0, 0),
 (11, 22, 2, 0, 2, 0, 0, 0),
 (12, 15, 0, 0, 0, 0, 0, 0),
 (13, 15, 0, 0, 0, 0, 0, 0),
 (2, 15, 0, 0, 0, 0, 0, 0),
 (0, 2, 0, 0, 0, 0, 0, 0)]

In [80]:
rows = len(padded_list)
columns = len(padded_list[0])

print("Shape is (max_length, batch_size) = ", (rows, columns))

Shape is (max_length, batch_size) =  (13, 8)


In [84]:
# This will later help us save space and time during training as it can be stored in 1 bit also
def binaryMatrix(padded_list):
    """ Given a padded matrix, converts it into a binary matrix
    by replacing non-zero elements with 1, else 0
    Args:
        padded_list (list of tuples): A list of tuples representing a matrix.
    Returns:
        list[list]: The binary matrix .
                                                                        """
    # convert non-zero elements to 1 and 0s to 0
    binary_matrix = [[1 if element > 0 else 0 for element in row] for row in padded_list]
    
    return binary_matrix

In [86]:
## Test the function
binaryMatrix(padded_list)

[[1, 1, 1, 1, 1, 1, 1, 1],
 [1, 1, 1, 1, 1, 1, 1, 1],
 [1, 1, 1, 1, 1, 1, 1, 1],
 [1, 1, 1, 0, 1, 0, 1, 1],
 [1, 1, 1, 0, 1, 0, 1, 0],
 [1, 1, 1, 0, 1, 0, 0, 0],
 [1, 1, 1, 0, 1, 0, 0, 0],
 [1, 1, 1, 0, 1, 0, 0, 0],
 [1, 1, 1, 0, 1, 0, 0, 0],
 [1, 1, 0, 0, 0, 0, 0, 0],
 [1, 1, 0, 0, 0, 0, 0, 0],
 [1, 1, 0, 0, 0, 0, 0, 0],
 [0, 1, 0, 0, 0, 0, 0, 0]]

In [89]:
def generateInputTensor(word_indexer, sentence_list):
    # Convert sentences to indexes
    indexes_batch = [sentence2indexes(voc, sentence) for sentence in sentence_list]    
    
    # Get the lengths of each sentences + 1 (EOS_token)
    lengths = torch.tensor([len(index) for index in indexes_batch])
    
    # Zero-pad the index list and transpose it, so as to be able to pass as batches
    padded_batches = zero_pad_rows(indexes_batch)
        
    # Convert the transposed matrix to a LongTensor
    input_tensor = torch.LongTensor(padded_batches)
    
    return input_tensor, lengths

In [90]:
def generateOutputTensor(word_indexer, sentence_list):
    # Convert sentences to indexes
    indexes_batch = [sentence2indexes(voc, sentence) for sentence in sentence_list]    
    
    # Get the maximum target/output length in the batch
    max_target_len = max([len(index) for index in indexes_batch])
    
    # Zero-pad the index list and transpose it, so as to be able to pass as batches
    padded_batches = zero_pad_rows(indexes_batch)
    
    # Get the binary mask
    binary_mask = binaryMatrix(padded_batches)
    binary_mask = torch.ByteTensor(binary_mask)
        
    # Convert the transposed matrix to a LongTensor
    output_tensor = torch.LongTensor(padded_batches)
    
    return output_tensor, binary_mask, max_target_len

In [91]:
def batch2Train(word_indexer, qa_batches):
    """
    Convert question-answer batches into input and output tensors for training.

    Arguments:
        word_indexer (WordIndexer): An instance of the WordIndexer class.
        qa_batches (list of lists): A list of question-answer batches, where each batch is a list of two elements: question and answer.

    Returns:
        input_tensor (torch.LongTensor): The input tensor containing the indexes of the questions after padding.
        input_lengths (torch.Tensor): The tensor containing the lengths of each input sequence.
        output_tensor (torch.LongTensor): The output tensor containing the indexes of the answers after padding.
        output_mask (torch.ByteTensor): The binary mask indicating the positions with non-zero elements in the output tensor.
        max_target_len (int): The maximum length of the target/output sequence.

    """
    # Sort the batches in descending order of question length
    sorted_batches = sorted(qa_batches, key=lambda x: len(x[0]), reverse=True)
    
    question_batch, answer_batch = [], []
    for qa_batch in qa_batches:
        question_batch.append(qa_batch[0])
        answer_batch.append(qa_batch[1])
        
    # generate input tensor and input lengths
    input_tensor, input_lengths = generateInputTensor(word_indexer, question_batch)
    
    # get output tensor, binary mask, and max target length
    output_tensor, output_mask, max_target_len = generateOutputTensor(word_indexer, answer_batch)
    
    return input_tensor, input_lengths, output_tensor, output_mask, max_target_len

# Building the Model

The model is built using these 2 blocks, which is discussed briefly below
### Encoder
-  aedewfwefw

### Decoder
- wefewfwefw

In [42]:
import torch 
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

In [43]:
# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117

In [44]:
CUDA = torch.cuda.is_available()
device = torch.device("cuda" if CUDA else "cpu")

  device = torch.device("cuda" if CUDA else "cpu")


In [45]:
CUDA

True

In [46]:
device

device(type='cuda')

## Encoder
- edkjewbfowen

## Decoder
- wefncnhwoihfno

# Training the Model