In [1]:
import os
import csv
import codecs
from utils.file_utils import print_lines, load_lines, load_conversations, extract_sentence_pairs

### 查看原始文本数据

In [2]:
corpus_name = "cornell movie-dialogs corpus"
corpus_path = os.path.join("data", corpus_name)
movie_lines_path = os.path.join(corpus_path, "movie_lines.txt") # 原始文件路径
print_lines(movie_lines_path)


b'L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!\n'
b'L1044 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ They do to!\n'
b'L985 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I hope so.\n'
b'L984 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ She okay?\n'
b"L925 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Let's go.\n"
b'L924 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ Wow\n'
b"L872 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Okay -- you're gonna need to learn how to lie.\n"
b'L871 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ No\n'
b'L870 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I\'m kidding.  You know how sometimes you just become this "persona"?  And you don\'t know how to quit?\n'
b'L869 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Like my fear of wearing pastels?\n'


### 查看使用 file_utils 中的函数处理好文本后的结果

In [3]:
# Define path to new file
datafile_path = os.path.join(corpus_path, "formatted_movie_lines.txt")

delimiter = '\t'
# Unescape the delimiter
delimiter = str(codecs.decode(delimiter, "unicode_escape"))

# Initialize lines dict, conversations list, and field ids
lines = {}
conversations = []
MOVIE_LINES_FIELDS = ["lineID", "characterID",
                    "movieID", "character", "text"]
MOVIE_CONVERSATIONS_FIELDS = ["character1ID", "character2ID",
                            "movieID", "utteranceIDs"]

# Load lines and process conversations
print("\nProcessing corpus...")
lines = load_lines(movie_lines_path, MOVIE_LINES_FIELDS)
print("\nLoading conversations...")
conversations = load_conversations(os.path.join(corpus_path, "movie_conversations.txt"), lines, MOVIE_CONVERSATIONS_FIELDS)

# Write new csv file
print("\nWriting newly formatted file...")
with open(datafile_path, 'w', encoding='utf-8') as outputfile:
    writer = csv.writer(outputfile, delimiter=delimiter,
                        lineterminator='\n')
    for pair in extract_sentence_pairs(conversations):
        writer.writerow(pair)

# Print a sample of lines
print("\nSample lines from file:")
print_lines(datafile_path, n=10)



Processing corpus...

Loading conversations...

Writing newly formatted file...

Sample lines from file:
b"Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\tWell, I thought we'd start with pronunciation, if that's okay with you.\r\n"
b"Well, I thought we'd start with pronunciation, if that's okay with you.\tNot the hacking and gagging and spitting part.  Please.\r\n"
b"Not the hacking and gagging and spitting part.  Please.\tOkay... then how 'bout we try out some French cuisine.  Saturday?  Night?\r\n"
b"You're asking me out.  That's so cute. What's your name again?\tForget it.\r\n"
b"No, no, it's my fault -- we didn't have a proper introduction ---\tCameron.\r\n"
b"Cameron.\tThe thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\r\n"
b"The thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister. 

### TODO Vocab 类测试

在运行一下代码之前，请在这里测试你编写好的 Vocab 类。

In [4]:
# limsum，待测试句子
from utils.vocab import Vocab
limsum = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum."
vocab = Vocab("Lorem ipsum")
vocab.add_sentence(limsum)
vocab.word2count

{'Lorem': 1,
 'ipsum': 1,
 'dolor': 2,
 'sit': 1,
 'amet,': 1,
 'consectetur': 1,
 'adipiscing': 1,
 'elit,': 1,
 'sed': 1,
 'do': 1,
 'eiusmod': 1,
 'tempor': 1,
 'incididunt': 1,
 'ut': 2,
 'labore': 1,
 'et': 1,
 'dolore': 2,
 'magna': 1,
 'aliqua.': 1,
 'Ut': 1,
 'enim': 1,
 'ad': 1,
 'minim': 1,
 'veniam,': 1,
 'quis': 1,
 'nostrud': 1,
 'exercitation': 1,
 'ullamco': 1,
 'laboris': 1,
 'nisi': 1,
 'aliquip': 1,
 'ex': 1,
 'ea': 1,
 'commodo': 1,
 'consequat.': 1,
 'Duis': 1,
 'aute': 1,
 'irure': 1,
 'in': 3,
 'reprehenderit': 1,
 'voluptate': 1,
 'velit': 1,
 'esse': 1,
 'cillum': 1,
 'eu': 1,
 'fugiat': 1,
 'nulla': 1,
 'pariatur.': 1,
 'Excepteur': 1,
 'sint': 1,
 'occaecat': 1,
 'cupidatat': 1,
 'non': 1,
 'proident,': 1,
 'sunt': 1,
 'culpa': 1,
 'qui': 1,
 'officia': 1,
 'deserunt': 1,
 'mollit': 1,
 'anim': 1,
 'id': 1,
 'est': 1,
 'laborum.': 1}

### 生成单词词典

这里可能会运行十几秒，耐心等待。

In [5]:
from utils.data_utils import load_prepare_data

# [TODO][Optional] 如果实现了硬盘保存，文件会保存的位置
save_dir = os.path.join("data", corpus_name)
# 载入/处理 词典 和 对话文本对
vocab, pairs = load_prepare_data(corpus_name, datafile_path, save_dir)
print("Counted words:", vocab.num_words)


已获取大小为18004的vocab,长度为64271的pairs
Counted words: 18007


In [6]:
# 检验，打印前10个对话文本对
print("\npairs:")
for pair in pairs[:10]:
    print(pair)


pairs:
['there .', 'where ?']
['you have my word . as a gentleman', 'you re sweet .']
['hi .', 'looks like things worked out tonight huh ?']
['you know chastity ?', 'i believe we share an art instructor']
['have fun tonight ?', 'tons']
['well no . . .', 'then that s all you had to say .']
['then that s all you had to say .', 'but']
['but', 'you always been this selfish ?']
['do you listen to this crap ?', 'what crap ?']
['what good stuff ?', 'the real you .']


In [7]:
# 检验，查看 vocab 的属性/成员变量
print(vocab.__dict__.keys())
# 查看 word2count 的前10项
print("vocab.word2count:")
for i, (key, value) in enumerate(vocab.word2count.items()):
    print(key, value)
    if i >= 10-1:
        break
# 查看 index2word 的前10项
print("vocab.index2word:")
for i, (key, value) in enumerate(vocab.index2word.items()):
    print(key, value)
    if i >= 10-1:
        break

dict_keys(['name', 'trimmed', 'word2index', 'word2count', 'index2word', 'num_words'])
vocab.word2count:
there 2013
. 104124
where 2475
? 43942
you 29248
have 3023
my 3148
word 125
as 558
a 8579
vocab.index2word:
3 there
4 .
5 where
6 ?
7 you
8 have
9 my
10 word
11 as
12 a


### TODO 测试 trim_rare_words

In [8]:
from utils.preprocessing import trim_rare_words, MIN_COUNT
pairs = trim_rare_words(vocab, pairs, MIN_COUNT)

Trimmed from 64271 pairs to 53125, 0.8266 of total


In [9]:
from utils.preprocessing import indices_from_sentence, zero_padding, binary_matrix
# Example for validation
small_batch_size = 5
# [random.choice(pairs) for _ in range(small_batch_size)]
chosen_pairs = [pairs[_] for _ in range(small_batch_size)]
input_batch, output_batch = [], []
for pair in chosen_pairs:
    input_batch.append(pair[0])
    output_batch.append(pair[1])


In [10]:
print(input_batch)
print(output_batch)

['there .', 'you have my word . as a gentleman', 'hi .', 'have fun tonight ?', 'well no . . .']
['where ?', 'you re sweet .', 'looks like things worked out tonight huh ?', 'tons', 'then that s all you had to say .']


#### 检查 input_batch 的 indices_batch 和 padded_list 结果

In [11]:
indices_batch = [indices_from_sentence(
    vocab, sentence) for sentence in input_batch]
print(f"indices_batch: {indices_batch}")
padded_list = zero_padding(indices_batch)
print(f"padded_list: {padded_list}")


indices_batch: [[3, 4, 2], [7, 8, 9, 10, 4, 11, 12, 13, 2], [16, 4, 2], [8, 31, 22, 6, 2], [33, 34, 4, 4, 4, 2]]
padded_list: [(3, 7, 16, 8, 33), (4, 8, 4, 31, 34), (2, 9, 2, 22, 4), (0, 10, 0, 6, 4), (0, 4, 0, 2, 4), (0, 11, 0, 0, 2), (0, 12, 0, 0, 0), (0, 13, 0, 0, 0), (0, 2, 0, 0, 0)]


#### 检查 output_batch 的 indices_batch 和 padded_list 结果

In [12]:
indices_batch = [indices_from_sentence(
    vocab, sentence) for sentence in output_batch]
print(f"indices_batch: {indices_batch}")
padded_list = zero_padding(indices_batch)
print(f"padded_list: {padded_list}")
[print(_) for _ in padded_list]


indices_batch: [[5, 6, 2], [7, 14, 15, 4, 2], [17, 18, 19, 20, 21, 22, 23, 6, 2], [32, 2], [35, 36, 37, 38, 7, 39, 40, 41, 4, 2]]
padded_list: [(5, 7, 17, 32, 35), (6, 14, 18, 2, 36), (2, 15, 19, 0, 37), (0, 4, 20, 0, 38), (0, 2, 21, 0, 7), (0, 0, 22, 0, 39), (0, 0, 23, 0, 40), (0, 0, 6, 0, 41), (0, 0, 2, 0, 4), (0, 0, 0, 0, 2)]
(5, 7, 17, 32, 35)
(6, 14, 18, 2, 36)
(2, 15, 19, 0, 37)
(0, 4, 20, 0, 38)
(0, 2, 21, 0, 7)
(0, 0, 22, 0, 39)
(0, 0, 23, 0, 40)
(0, 0, 6, 0, 41)
(0, 0, 2, 0, 4)
(0, 0, 0, 0, 2)


[None, None, None, None, None, None, None, None, None, None]

#### 检查 binary_matrix 函数的输出 mask

In [13]:
mask = binary_matrix(padded_list)
[print(_) for _ in mask]
print(type(mask))


[1, 1, 1, 1, 1]
[1, 1, 1, 1, 1]
[1, 1, 1, 0, 1]
[0, 1, 1, 0, 1]
[0, 1, 1, 0, 1]
[0, 0, 1, 0, 1]
[0, 0, 1, 0, 1]
[0, 0, 1, 0, 1]
[0, 0, 1, 0, 1]
[0, 0, 0, 0, 1]
<class 'list'>


#### 检查 binary_matrix 函数的输出 mask

In [14]:
from utils.preprocessing import input_variable, output_variable
input_, lengths = input_variable(input_batch, vocab)
target, mask, max_target_len = output_variable(output_batch, vocab)
max_target_len


10

In [15]:
print(f"input_:\n{input_}")
print(f"target:\n{target}")
print(f"lengths:\n{lengths}")
print(f"mask:\n{mask}")
print(f"max_target_len: \n{max_target_len}")


input_:
tensor([[ 3,  7, 16,  8, 33],
        [ 4,  8,  4, 31, 34],
        [ 2,  9,  2, 22,  4],
        [ 0, 10,  0,  6,  4],
        [ 0,  4,  0,  2,  4],
        [ 0, 11,  0,  0,  2],
        [ 0, 12,  0,  0,  0],
        [ 0, 13,  0,  0,  0],
        [ 0,  2,  0,  0,  0]])
target:
tensor([[ 5,  7, 17, 32, 35],
        [ 6, 14, 18,  2, 36],
        [ 2, 15, 19,  0, 37],
        [ 0,  4, 20,  0, 38],
        [ 0,  2, 21,  0,  7],
        [ 0,  0, 22,  0, 39],
        [ 0,  0, 23,  0, 40],
        [ 0,  0,  6,  0, 41],
        [ 0,  0,  2,  0,  4],
        [ 0,  0,  0,  0,  2]])
lengths:
tensor([3, 9, 3, 5, 6])
mask:
tensor([[ True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True],
        [ True,  True,  True, False,  True],
        [False,  True,  True, False,  True],
        [False,  True,  True, False,  True],
        [False, False,  True, False,  True],
        [False, False,  True, False,  True],
        [False, False,  True, False,  True],
        [Fals

In [20]:
import itertools
l = [[1,3,2,4,5,6],[2,4,5],[1,2]]
print(*l)
print(l)
list(itertools.zip_longest(*l,fillvalue=0))

[1, 3, 2, 4, 5, 6] [2, 4, 5] [1, 2]
[[1, 3, 2, 4, 5, 6], [2, 4, 5], [1, 2]]


[(1, 2, 1), (3, 4, 2), (2, 5, 0), (4, 0, 0), (5, 0, 0), (6, 0, 0)]