<a href="https://colab.research.google.com/github/ZHOU-py/NLP/blob/master/%E8%81%8A%E5%A4%A9%E6%9C%BA%E5%99%A8%E4%BA%BA/Cornell_Movies_Dialogs_Corpus_S2S_%E8%81%8A%E5%A4%A9%E6%9C%BA%E5%99%A8%E4%BA%BA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#### 下载数据文件

In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import torch
from torch.jit import script, trace
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import csv
import random
import re
import os
import unicodedata
import codecs
from io import open
import itertools
import math

In [2]:
USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")

#### 加载和预处理数据
- 220,579 conversational exchanges between 10,292 pairs of movie characters
- involves 9,035 characters from 617 movies
- in total 304,713 utterances

In [3]:
corpus_name = "cornell movie-dialogs corpus" 
corpus = os.path.join("data", corpus_name)
def printLines(file, n=10):
  with open(file, 'rb') as datafile:
    lines = datafile.readlines() 
  for line in lines[:n]:
    print(line)
printLines(os.path.join(corpus, "movie_lines.txt"))

b'L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!\n'
b'L1044 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ They do to!\n'
b'L985 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I hope so.\n'
b'L984 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ She okay?\n'
b"L925 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Let's go.\n"
b'L924 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ Wow\n'
b"L872 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Okay -- you're gonna need to learn how to lie.\n"
b'L871 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ No\n'
b'L870 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I\'m kidding.  You know how sometimes you just become this "persona"?  And you don\'t know how to quit?\n'
b'L869 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Like my fear of wearing pastels?\n'


##### 创建格式化数据
解析原始数据文件 movie_lines.txt
* `loadLines`：将文件的每一行拆分为字段（lineID, characterID, movieID, character, text)组合的字典
* `loadConversations`: 根据`movie_conversations.txt`将`loadLines`中的每一行数据进行归类
* `extractSentencePairs`: 从对话中提取句子对


In [14]:
# 将文件的每一行拆分为字段字典
def loadLines(fileName, fields):
  lines = {}
  with open(fileName, 'r', encoding='iso-8859-1') as f:
    for line in f:
      values = line.split(" +++$+++ ")
      lineObj = {}
      for i, field in enumerate(fields):
        lineObj[field] = values[i]
      lines[lineObj['lineID']] = lineObj
  return lines

In [15]:
#  将 'loadLines'中的行字段分组为基于 *movie_conversations.txt* 的对话
def loadConversations(fileName, lines, fields):
  conversations = []
  with open(fileName, 'r', encoding='iso-8859-1') as f:
    for line in f:
      values = line.split(" +++$+++ ")
      # Extract fields
      convObj = {}
      for i, field in enumerate(fields):
        convObj[field] = values[i]
      # Convert string to list (convObj["utteranceIDs"] == "['L598485','L598486', ...]")
      lineIds = eval(convObj["utteranceIDs"])
      # Reassemble lines
      convObj["lines"] = []
#      print(lineIds)
      for lineId in lineIds:
       # print(lines[lineId])
        if lines[lineId]:
          convObj["lines"].append(lines[lineId])
        else:
          continue
      conversations.append(convObj)

  return conversations

In [16]:
# 从对话中提取一对句子
def extractSentencePairs(conversations):
  qa_pairs = []
  for conversation in conversations:
    for i in range(len(conversation["lines"])-1):
      inputLine = conversation["lines"][i]["text"].strip()
      targetLine = conversation["lines"][i+1]["text"].strip()
      if inputLine and targetLine:
        qa_pairs.append([inputLine,targetLine])
  return qa_pairs

In [17]:
# 定义新文件的路径
datafile = os.path.join(corpus, "formatted_movie_lines.txt") 
delimiter = '\t'
delimiter = str(codecs.decode(delimiter, "unicode_escape"))
# 初始化行dict，对话列表和字段ID
lines = {}
conversations = []
MOVIE_LINES_FIELDS = ["lineID", "characterID", "movieID", "character", "text"] 
MOVIE_CONVERSATIONS_FIELDS = ["character1ID", "character2ID", "movieID", "utteranceIDs"]

# 加载行和进程对话
print("\n Processing corpus...")
lines = loadLines(os.path.join(corpus,"movie_lines.txt"), MOVIE_LINES_FIELDS)
print("\n Loading conversations...")
conversations = loadConversations(os.path.join(corpus, "movie_conversations.txt"),\
                                  lines, MOVIE_CONVERSATIONS_FIELDS)

# 写入新的csv文件
print("\nWriting newly formatted file...")
with open(datafile, 'w', encoding='utf-8') as outputfile:
  writer = csv.writer(outputfile, delimiter=delimiter, lineterminator='\n') 
  for pair in extractSentencePairs(conversations):
    writer.writerow(pair)

# 打印一个样本的行
print("\nSample lines from file:") 
printLines(datafile)


 Processing corpus...

 Loading conversations...

Writing newly formatted file...

Sample lines from file:
b"Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\tWell, I thought we'd start with pronunciation, if that's okay with you.\n"
b"Well, I thought we'd start with pronunciation, if that's okay with you.\tNot the hacking and gagging and spitting part.  Please.\n"
b"Not the hacking and gagging and spitting part.  Please.\tOkay... then how 'bout we try out some French cuisine.  Saturday?  Night?\n"
b"You're asking me out.  That's so cute. What's your name again?\tForget it.\n"
b"No, no, it's my fault -- we didn't have a proper introduction ---\tCameron.\n"
b"Cameron.\tThe thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\n"
b"The thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't d

In [None]:
for i,fields in enumerate(MOVIE_CONVERSATIONS_FIELDS):
  print(i,fields)


#### 加载和清洗数据
创建词汇表并将查询/响应句子加载到内存  
通过数据集中的单词创建一个索引。  
创建`Voc`类，以存储从单词到索引的映射，索引到单词的反向映射，每个单词的计数和总单词量。  
这个类提供：  
* 词汇表中添加单词的方法（`addWord`)  
* 添加所有单词到句子中的方法（`addSentence`)
* 清洗不常见的单词方法（`trim`)  

In [22]:
# 默认词向量
PAD_token = 0 # Used for padding short sentences 
SOS_token = 1 # Start-of-sentence token 
EOS_token = 2 # End-of-sentence token

class Voc:
  def __init__(self, name):
    self.name = name
    self.trimmed = False
    self.word2index = {}
    self.word2count = {}
    self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"} 
    self.num_words = 3 # Count SOS, EOS, PAD

  def addSentence(self, sentence):
    for word in sentence.split(' '):
      self.addWord(word)
  
  def addWord(self, word):
    if word not in self.word2index:
      self.word2index[word] = self.num_words 
      self.word2count[word] = 1 
      self.index2word[self.num_words] = word 
      self.num_words += 1
    else:
      self.word2count[word] += 1
  # 删除低于特定计数阈值的单词 
  def trim(self, min_count):
    if self.trimmed: return
    self.trimmed = True
    
    keep_words = []
    
    for k, v in self.word2count.items(): 
      if v >= min_count:
        keep_words.append(k)
    
    print('keep_words {} / {} = {:.4f}'.format(
      len(keep_words), len(self.word2index), len(keep_words) /
    len(self.word2index) ))
    
    # 重初始化字典
    self.word2index = {}
    self.word2count = {}
    self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}

    self.num_words = 3 # Count default tokens 
    for word in keep_words:
        self.addWord(word)

# 小写并删除非字母字符
def normalizeString(s):
  s = s.lower()
  s = re.sub(r"([.!?])", r" \1", s)
  s = re.sub(r"[^a-zA-Z.!?]+", r" ", s) 
  return s
# 使用字符串句子，返回单词索引的句子

def indexesFromSentence(voc, sentence):
  return [voc.word2index[word] for word in sentence.split(' ')] + [EOS_token]

组装词汇表和查询/响应语句对。

`unicodeToAscii`将unicode字符串转换为ASCII。将所有字母转换为小写字母并清洗掉除基本标点之外的所有非字母字符（`normalizaString`). 最后为了收敛，过滤掉长度大于MAX_LENGTH的句子（`filterParis`)。

In [19]:
MAX_LENGTH = 10 # Maximum sentence length to consider

# 将Unicode字符串转换为纯ASCII，引用
# https://stackoverflow.com/a/518232/2809427 
def unicodeToAscii(s):
  return ''.join(
  c for c in unicodedata.normalize('NFD', s) if unicodedata.category(c) != 'Mn'
  )

# 初始化Voc对象 和 格式化pairs对话存放到list中 
def readVocs(datafile, corpus_name):
  print("Reading lines...")
  # Read the file and split into lines
  lines = open(datafile, encoding='utf-8').read().strip().split('\n') 
  # Split every line into pairs and normalize
  pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines] 
  voc = Voc(corpus_name)
  return voc, pairs

# 如果对 'p' 中的两个句子都低于 MAX_LENGTH 阈值，则返回True 
def filterPair(p):
  # Input sequences need to preserve the last word for EOS token
  return len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH

# 过滤满足条件的 pairs 对话 
def filterPairs(pairs):
  return [pair for pair in pairs if filterPair(pair)]

# 使用上面定义的函数，返回一个填充的voc对象和对列表
def loadPrepareData(corpus, corpus_name, datafile, save_dir):
  print("Start preparing training data ...")
  voc, pairs = readVocs(datafile, corpus_name) 
  print("Read {!s} sentence pairs".format(len(pairs))) 
  pairs = filterPairs(pairs)
  print("Trimmed to {!s} sentence pairs".format(len(pairs))) 
  print("Counting words...")
  for pair in pairs:
    voc.addSentence(pair[0])
    voc.addSentence(pair[1]) 
  print("Counted words:", voc.num_words) 
  return voc, pairs


In [23]:
# 加载/组装voc和对
save_dir = os.path.join("data", "save")
voc, pairs = loadPrepareData(corpus, corpus_name, datafile, save_dir) 
# 打印一些对进行验证
print("\npairs:")
for pair in pairs[:10]:
  print(pair)

Start preparing training data ...
Reading lines...
Read 221282 sentence pairs
Trimmed to 63436 sentence pairs
Counting words...
Counted words: 17755

pairs:
['there .', 'where ?']
['you have my word . as a gentleman', 'you re sweet .']
['hi .', 'looks like things worked out tonight huh ?']
['you know chastity ?', 'i believe we share an art instructor']
['have fun tonight ?', 'tons']
['well no . . .', 'then that s all you had to say .']
['then that s all you had to say .', 'but']
['but', 'you always been this selfish ?']
['do you listen to this crap ?', 'what crap ?']
['what good stuff ?', ' the real you . ']


另一个加快收敛的方式是 去除词汇表中很少使用的单词。减少特征空间也会降低模型学习目标函数的难度。  
* 使用`voc.trim`函数去除MIN_COUNT阈值以下的单词。  
* 如果句子中包含词频过小的单词，整个句子也被过滤掉。  

In [24]:
MIN_COUNT = 3 # 修剪的最小字数阈值

def trimRareWords(voc, pairs, MIN_COUNT): 
  # 修剪来自voc的MIN_COUNT下使用的单词 
  voc.trim(MIN_COUNT)

  # Filter out pairs with trimmed words 
  keep_pairs = []
  for pair in pairs:
    input_sentence = pair[0] 
    output_sentence = pair[1] 
    keep_input = True 
    keep_output = True
    # 检查输入句子
    for word in input_sentence.split(' '):
      if word not in voc.word2index: 
        keep_input = False
        break

    # 检查输出句子
    for word in output_sentence.split(' '):
      if word not in voc.word2index: 
        keep_output = False
        break
    
    # 只保留输入或输出句子中不包含修剪单词的对 
    if keep_input and keep_output:
      keep_pairs.append(pair)
  print("Trimmed from {} pairs to {}, {:.4f} of total".format(len(pairs), \
                                                              len(keep_pairs), \
                                                              len(keep_pairs) / len(pairs)))
  return keep_pairs 

# 修剪voc和对
pairs = trimRareWords(voc, pairs, MIN_COUNT)

keep_words 7700 / 17752 = 0.4338
Trimmed from 63436 pairs to 52460, 0.8270 of total


#### 为模型准备数据
