In [1]:
import pandas as pd
import numpy as np

# 读取txt文件为DataFrame
file_path = '../cornell/formatted_movie_lines.txt'
data = pd.read_csv(file_path, sep='\t', header=None, names=['speaker1', 'speaker2'])

# 可选：查看前几行数据
print(data.head())

                                            speaker1  \
0  Can we make this quick?  Roxanne Korrine and A...   
1  Well, I thought we'd start with pronunciation,...   
2  Not the hacking and gagging and spitting part....   
3  You're asking me out.  That's so cute. What's ...   
4  No, no, it's my fault -- we didn't have a prop...   

                                            speaker2  
0  Well, I thought we'd start with pronunciation,...  
1  Not the hacking and gagging and spitting part....  
2  Okay... then how 'bout we try out some French ...  
3                                         Forget it.  
4                                           Cameron.  


In [4]:
import unicodedata
import re
import csv
import codecs
import os
# 定义新文件的路径
corpus='../cornell/'
datafile = os.path.join(corpus, "formatted_movie_lines.txt")
# 默认词向量
PAD_token = 0  # Used for padding short sentences
SOS_token = 1  # Start-of-sentence token
EOS_token = 2  # End-of-sentence token

class Voc:
    def __init__(self, name):
        self.name = name
        self.trimmed = False
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3  # Count SOS, EOS, PAD
    # 添加句子中的所有单词到词汇表
    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)
    # 向词汇表中添加单词
    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1

    # 删除低于特定计数阈值的单词
    def trim(self, min_count):
        if self.trimmed:
            return
        self.trimmed = True

        keep_words = []

        for k, v in self.word2count.items():
            if v >= min_count:
                keep_words.append(k)

        print('keep_words {} / {} = {:.4f}'.format(
            len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index)
        ))

        # 重初始化字典
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3 # Count default tokens

        for word in keep_words:
            self.addWord(word)
MAX_LENGTH = 10  # Maximum sentence length to consider
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    return s
# 初始化Voc对象 和 格式化pairs对话存放到list中
def readVocs(datafile, corpus_name):
    print("Reading lines...")
    # Read the file and split into lines
    lines = open(datafile, encoding='utf-8').read().strip().split('\n')
    # Split every line into pairs and normalize
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]
    voc = Voc(corpus_name)
    return voc, pairs

# 如果对 'p' 中的两个句子都低于 MAX_LENGTH 阈值，则返回True
def filterPair(p):
    # Input sequences need to preserve the last word for EOS token
    return len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH

# 过滤满足条件的 pairs 对话
def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]

# 使用上面定义的函数，返回一个填充的voc对象和对列表
def loadPrepareData(corpus, corpus_name, datafile, save_dir):
    print("Start preparing training data ...")
    voc, pairs = readVocs(datafile, corpus_name)
    print("Read {!s} sentence pairs".format(len(pairs)))
    pairs = filterPairs(pairs)
    print("Trimmed to {!s} sentence pairs".format(len(pairs)))
    print("Counting words...")
    for pair in pairs:
        voc.addSentence(pair[0])
        voc.addSentence(pair[1])
    print("Counted words:", voc.num_words)
    return voc, pairs

# 加载/组装voc和对
save_dir = os.path.join("data", "save")
voc, pairs = loadPrepareData(corpus, '../cornell/formatted_movie_lines.txt', datafile, save_dir)
# 打印一些对进行验证
print("\npairs:")
for pair in pairs[:10]:
    print(pair)

Start preparing training data ...
Reading lines...
Read 221282 sentence pairs
Trimmed to 63446 sentence pairs
Counting words...
Counted words: 17774

pairs:
['there .', 'where ?']
['you have my word . as a gentleman', 'you re sweet .']
['hi .', 'looks like things worked out tonight huh ?']
['you know chastity ?', 'i believe we share an art instructor']
['have fun tonight ?', 'tons']
['well no . . .', 'then that s all you had to say .']
['then that s all you had to say .', 'but']
['but', 'you always been this selfish ?']
['do you listen to this crap ?', 'what crap ?']
['what good stuff ?', ' the real you . ']


In [6]:
len(pairs)

63446

In [2]:
input_sentences = data['speaker1'].tolist()
output_sentences = data['speaker2'].tolist()

In [3]:
import nltk
from nltk.tokenize import word_tokenize
nltk.download('punkt')

# 分词函数
def tokenize_sentences(sentences):
    tokenized = [word_tokenize(sentence.lower()) for sentence in sentences]
    return tokenized

input_tokenized = tokenize_sentences(input_sentences)
output_tokenized = tokenize_sentences(output_sentences)


[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\Justi\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [5]:
from tensorflow.keras.preprocessing.text import Tokenizer

tokenizer = Tokenizer()
tokenizer.fit_on_texts(input_tokenized + output_tokenized)

input_sequences = tokenizer.texts_to_sequences(input_tokenized)
output_sequences = tokenizer.texts_to_sequences(output_tokenized)

vocab_size = len(tokenizer.word_index) + 1


In [6]:
from tensorflow.keras.preprocessing.text import Tokenizer

tokenizer = Tokenizer()
tokenizer.fit_on_texts(input_tokenized + output_tokenized)

input_sequences = tokenizer.texts_to_sequences(input_tokenized)
output_sequences = tokenizer.texts_to_sequences(output_tokenized)

vocab_size = len(tokenizer.word_index) + 1


In [7]:
from tensorflow.keras.preprocessing.sequence import pad_sequences

max_sequence_length = max(max(len(seq) for seq in input_sequences), max(len(seq) for seq in output_sequences))

input_padded = pad_sequences(input_sequences, maxlen=max_sequence_length, padding='pre')
output_padded = pad_sequences(output_sequences, maxlen=max_sequence_length, padding='pre')

In [8]:
X = np.array(input_padded)
y = np.array(output_padded)

In [9]:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, LSTM, Dense

embedding_dim = 128
hidden_units = 256

model = Sequential([
    Embedding(vocab_size, embedding_dim, input_length=max_sequence_length),
    LSTM(hidden_units, return_sequences=True),
    Dense(vocab_size, activation='softmax')
])

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.summary()


Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding (Embedding)        (None, 684, 128)          7525504   
_________________________________________________________________
lstm (LSTM)                  (None, 684, 256)          394240    
_________________________________________________________________
dense (Dense)                (None, 684, 58793)        15109801  
Total params: 23,029,545
Trainable params: 23,029,545
Non-trainable params: 0
_________________________________________________________________


In [13]:
import tensorflow as tf
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

Num GPUs Available:  0


In [15]:
model.fit(X, y, batch_size=16, epochs=5, validation_split=0.2)


Epoch 1/5
  171/11065 [..............................] - ETA: 14:28:49 - loss: 1.0604 - accuracy: 0.9797

KeyboardInterrupt: 