In [222]:
import random
import itertools
import torch
import numpy as np
from process import IndexMapping

In [223]:
with open('intermedium/processed_movie_lines.txt', 'r') as f:
    read_in = f.read().strip().split('\n')

In [224]:
pairs = [[s for s in l.split('\t')] for l in read_in]

In [225]:
pairs[:5]

[['not the hacking and gagging and spitting part. please.',
  'okay... then how bout we try out some french cuisine. saturday? night?'],
 ['you re asking me out. that s so cute. what s your name again?',
  'forget it.'],
 ['no no it s my fault we didn t have a proper introduction', 'cameron.'],
 ['gosh if only we could find kat a boyfriend...',
  'let me see what i can do.'],
 ['that s because it s such a nice one.', 'forget french.']]

In [226]:
index_mapping = IndexMapping()

In [227]:
for p in pairs:
    index_mapping.add_sentence(p[0])
    index_mapping.add_sentence(p[1])

In [228]:
index_mapping.n_words

28410

In [229]:
index_mapping.word2index

{'not': 3,
 'the': 4,
 'hacking': 5,
 'and': 6,
 'gagging': 7,
 'spitting': 8,
 'part.': 9,
 'please.': 10,
 'okay...': 11,
 'then': 12,
 'how': 13,
 'bout': 14,
 'we': 15,
 'try': 16,
 'out': 17,
 'some': 18,
 'french': 19,
 'cuisine.': 20,
 'saturday?': 21,
 'night?': 22,
 'you': 23,
 're': 24,
 'asking': 25,
 'me': 26,
 'out.': 27,
 'that': 28,
 's': 29,
 'so': 30,
 'cute.': 31,
 'what': 32,
 'your': 33,
 'name': 34,
 'again?': 35,
 'forget': 36,
 'it.': 37,
 'no': 38,
 'it': 39,
 'my': 40,
 'fault': 41,
 'didn': 42,
 't': 43,
 'have': 44,
 'a': 45,
 'proper': 46,
 'introduction': 47,
 'cameron.': 48,
 'gosh': 49,
 'if': 50,
 'only': 51,
 'could': 52,
 'find': 53,
 'kat': 54,
 'boyfriend...': 55,
 'let': 56,
 'see': 57,
 'i': 58,
 'can': 59,
 'do.': 60,
 'because': 61,
 'such': 62,
 'nice': 63,
 'one.': 64,
 'french.': 65,
 'there.': 66,
 'where?': 67,
 'word.': 68,
 'as': 69,
 'gentleman': 70,
 'sweet.': 71,
 'sure': 72,
 'have.': 73,
 'really': 74,
 'wanna': 75,
 'go': 76,
 'but':

In [230]:
PAD = 0
SOS = 1
EOS = 2

In [231]:
def word_indexing(mapping, sent):
    """
    Represent each word in the given sentence by an integer defined in the created indexing mapping

    Args:
        mapping (dict): the established mapping between words and integers
        sent (str): The sentence for indexing
    Return:
        (list<int>): A list of integers for word representation
    """
    return [mapping.word2index[word] for word in sent.split(' ')] + [EOS]

In [232]:
# Indexing one sentence
rand_idx = random.randint(0, len(index_mapping.word2index))
sample_sent = pairs[rand_idx][0]
print('Given sentence: ')
print(f'=> {sample_sent}')
print('Index representation: ')
print(f'=> {word_indexing(index_mapping, sample_sent)}')

Given sentence: 
=> no that could not be done mr. ruby. there are a good many things involved in that.
Index representation: 
=> [38, 28, 52, 3, 89, 603, 2190, 14066, 328, 302, 45, 131, 420, 96, 3768, 182, 777, 2]


In [233]:
# Use first 32 pairs for showing how to prepare data for training
first_batch = pairs[:32]
input_batch, output_batch = zip(*first_batch)

In [234]:
indexes_batch = [word_indexing(index_mapping, input_sent) for input_sent in input_batch]


In [235]:
for line in indexes_batch:
    print(line)

[3, 4, 5, 6, 7, 6, 8, 9, 10, 2]
[23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 29, 33, 34, 35, 2]
[38, 38, 39, 29, 40, 41, 15, 42, 43, 44, 45, 46, 47, 2]
[49, 50, 51, 15, 52, 53, 54, 45, 55, 2]
[28, 29, 61, 39, 29, 62, 45, 63, 64, 2]
[66, 2]
[23, 44, 40, 68, 69, 45, 70, 2]
[72, 73, 2]
[58, 74, 74, 74, 75, 76, 77, 58, 59, 78, 3, 79, 40, 80, 81, 2]
[93, 2]
[44, 100, 101, 2]
[58, 103, 91, 23, 104, 105, 4, 106, 77, 23, 107, 108, 88, 89, 109, 110, 2]
[58, 111, 2]
[117, 118, 2]
[12, 28, 29, 119, 23, 120, 88, 121, 2]
[77, 2]
[125, 23, 126, 88, 123, 127, 2]
[58, 128, 23, 129, 130, 88, 4, 131, 132, 133, 2]
[32, 131, 134, 2]
[38, 2]
[141, 2]
[85, 143, 2]
[146, 125, 147, 2]
[115, 23, 149, 33, 150, 2]
[151, 2]
[155, 115, 156, 157, 156, 158, 159, 160, 2]
[161, 2]
[163, 2]
[169, 156, 170, 171, 172, 2]
[58, 44, 88, 89, 181, 182, 183, 184, 2]
[23, 153, 23, 24, 4, 51, 187, 105, 4, 188, 2]
[39, 29, 178, 2]


In [236]:
max_lengths = max([len(indexes) for indexes in indexes_batch])
aligned_batch = [line + [PAD] * (max_lengths - len(line)) for line in indexes_batch]

In [237]:
for line in aligned_batch:
    print(f'{line} => length:{len(line)}')

[3, 4, 5, 6, 7, 6, 8, 9, 10, 2, 0, 0, 0, 0, 0, 0, 0] => length:17
[23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 29, 33, 34, 35, 2, 0, 0] => length:17
[38, 38, 39, 29, 40, 41, 15, 42, 43, 44, 45, 46, 47, 2, 0, 0, 0] => length:17
[49, 50, 51, 15, 52, 53, 54, 45, 55, 2, 0, 0, 0, 0, 0, 0, 0] => length:17
[28, 29, 61, 39, 29, 62, 45, 63, 64, 2, 0, 0, 0, 0, 0, 0, 0] => length:17
[66, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] => length:17
[23, 44, 40, 68, 69, 45, 70, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0] => length:17
[72, 73, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] => length:17
[58, 74, 74, 74, 75, 76, 77, 58, 59, 78, 3, 79, 40, 80, 81, 2, 0] => length:17
[93, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] => length:17
[44, 100, 101, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] => length:17
[58, 103, 91, 23, 104, 105, 4, 106, 77, 23, 107, 108, 88, 89, 109, 110, 2] => length:17
[58, 111, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] => length:17
[117, 118, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [239]:
def aligned_index_batch(input_batch, mapping):
    """
    Get an aligned batch of sentence that each sentence is represented by indexing integers with same length

    Args:
        input_batch:
        mapping
    Return:
        batch tensor:
        lengths tensor:
    """
    indexes_batch = [word_indexing(mapping, input_sent) for input_sent in input_batch]
    lengths = [len(indexes) for indexes in indexes_batch]

    # Padding the line with PAD token if the length of a sample is less than the maximum length
    aligned_batch = [line + [PAD] * (max(lengths) - len(line)) for line in indexes_batch]

    return torch.LongTensor(aligned_batch), torch.tensor(lengths)

In [240]:
input_index_batch, _ = aligned_index_batch(input_batch, index_mapping)
input_index_batch = torch.transpose(input_index_batch, dim0=0, dim1=1)
input_index_batch
# input_index_batch.shape

tensor([[  3,  23,  38,  49,  28,  66,  23,  72,  58,  93,  44,  58,  58, 117,
          12,  77, 125,  58,  32,  38, 141,  85, 146, 115, 151, 155, 161, 163,
         169,  58,  23,  39],
        [  4,  24,  38,  50,  29,   2,  44,  73,  74,   2, 100, 103, 111, 118,
          28,   2,  23, 128, 131,   2,   2, 143, 125,  23,   2, 115,   2,   2,
         156,  44, 153,  29],
        [  5,  25,  39,  51,  61,   0,  40,   2,  74,   0, 101,  91,   2,   2,
          29,   0, 126,  23, 134,   0,   0,   2, 147, 149,   0, 156,   0,   0,
         170,  88,  23, 178],
        [  6,  26,  29,  15,  39,   0,  68,   0,  74,   0,   2,  23,   0,   0,
         119,   0,  88, 129,   2,   0,   0,   0,   2,  33,   0, 157,   0,   0,
         171,  89,  24,   2],
        [  7,  27,  40,  52,  29,   0,  69,   0,  75,   0,   0, 104,   0,   0,
          23,   0, 123, 130,   0,   0,   0,   0,   0, 150,   0, 156,   0,   0,
         172, 181,   4,   0],
        [  6,  28,  41,  53,  62,   0,  45,   0,  76,   0,  

In [241]:
output_index_batch, _ = aligned_index_batch(output_batch, index_mapping)
output_index_batch = torch.transpose(output_index_batch, dim0=0, dim1=1)
output_index_batch.shape

torch.Size([16, 32])

In [242]:
output_index_batch.shape

torch.Size([16, 32])

In [243]:
def get_binary_matrix(batch):
    """
    Get a binary matrix that all entries placed 1 represent to the non-zero entries in the
    given batch matrix.

    Args:

    Return:

    """
    assert isinstance(batch, torch.Tensor), "The given batch should be a tensor"
        
    nonzero_indexes = torch.nonzero(batch)
    
    mask = np.zeros_like(batch)
    for p in nonzero_indexes:
        row, col = p[0], p[1]
        mask[row][col] = 1
    
    return torch.BoolTensor(mask)

In [244]:
mask = get_binary_matrix(output_index_batch)
mask

tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True],
        [ True,  True, False,  True,  True, False,  True,  True,  True,  True,
         False,  True,  True,  True, False,  True,  True,  True,  True,  True,
          True,  True,  True, False,  True, False, False,  True,  True,  True,
          True, False],
        [ True, False, False,  True, False, False,  True,  True,  True,  True,
         False, False,  True,  True, False,  True, False,  True,  True,  True,
          True,  True,  True, False,  True, False, False,  

In [245]:
def split_pairs(pairs, training_ratio=0.9, shuffle=True):
    """
    Split the whole pairs dataset into training set and testing set

    Args:
        pairs (list<list<str>>): The processed pairs
        training_ratio (int): the propotion for dataset splition, default as 0.9 for training
    """
    if shuffle:
        random.shuffle(pairs)
    
    split_index = int(len(pairs) * training_ratio)
    return pairs[:split_index], pairs[split_index:]

In [246]:
train_pairs, test_pairs = split_pairs(pairs)

In [247]:
print(f'Total samples in training set: {len(train_pairs)}')
print(f'Total samples in test set: {len((test_pairs))}')

Total samples in training set: 105405
Total samples in test set: 11712


In [248]:
def get_batch(pairs, batch_size, shuffle=False, drop_last=False):
    """
    Obtain a mini-batch from the given dataset

    Args:
        pairs (list<list<str>>): the given pairs set for creating mini-batch
        batch_size (int): the desired batch size in each batch
        shuffle (bool): 
        drop_last (bool):
    Return: 
        A batch of pairs
    """
    if shuffle:
        random.shuffle(pairs)
    
    # Two points for slicing samplgs to form a batch
    start_index = 0 
    end_index = batch_size

    while end_index < len(pairs):
        batch = pairs[start_index : end_index]
        temp = end_index
        end_index += batch_size
        start_index = temp
        yield batch

    if not drop_last:
        # Return all remaining sample to form a batch if not drop the imcompleted batch
        batch = pairs[start_index:]
        yield batch

In [249]:
for i, batch in enumerate(get_batch(train_pairs, batch_size=64, drop_last=False)):
    print(f'Batch index {i}: batch size: {len(batch)}')

Batch index 0: batch size: 64
Batch index 1: batch size: 64
Batch index 2: batch size: 64
Batch index 3: batch size: 64
Batch index 4: batch size: 64
Batch index 5: batch size: 64
Batch index 6: batch size: 64
Batch index 7: batch size: 64
Batch index 8: batch size: 64
Batch index 9: batch size: 64
Batch index 10: batch size: 64
Batch index 11: batch size: 64
Batch index 12: batch size: 64
Batch index 13: batch size: 64
Batch index 14: batch size: 64
Batch index 15: batch size: 64
Batch index 16: batch size: 64
Batch index 17: batch size: 64
Batch index 18: batch size: 64
Batch index 19: batch size: 64
Batch index 20: batch size: 64
Batch index 21: batch size: 64
Batch index 22: batch size: 64
Batch index 23: batch size: 64
Batch index 24: batch size: 64
Batch index 25: batch size: 64
Batch index 26: batch size: 64
Batch index 27: batch size: 64
Batch index 28: batch size: 64
Batch index 29: batch size: 64
Batch index 30: batch size: 64
Batch index 31: batch size: 64
Batch index 32: ba