# Библиотеки

In [1]:
import numpy as np
import tensorflow as tf

# 1. Подготовка датасета

In [2]:
with open('movie_conversations.txt', 'r') as file:
    lines1 = file.read().splitlines()

In [3]:
num = []
for line1 in lines1:
    num.append(line1[line1.index('[') + 1 : line1.index(']')].replace("'", ""))

In [4]:
input_, output_ = [], []
for num_ in num:
    for i in range(len(num_.split(', ')) - 1):
        input_.append(num_.split(', ')[i])
        output_.append(num_.split(', ')[i + 1])

In [10]:
with open('movie_lines.txt', 'r', encoding='ISO-8859-1') as file:
    lines2 = file.read().splitlines()

In [19]:
lines2[-1]

"L666256 +++$+++ u9034 +++$+++ m616 +++$+++ VEREKER +++$+++ Colonel Durnford... William Vereker. I hear you 've been seeking Officers?"

In [16]:
lines2[0][:lines2[0].index('+') - 1]

'L1045'

In [21]:
name, line = [], []
for line2 in lines2:
  try:
    name.append(line2[:line2.index('+') - 1])
    for _ in range(4):
        line2 = line2[line2.index('$')+1:]  
    line.append(line2[4:]) 
  except:
    print(name)

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [22]:
line_name = {}
for name, line in zip(name, line):
    line_name[name] = line

In [None]:
line_name

In [24]:
input_texts, target_texts = [], []
for input_1, output_1 in zip(input_, output_):
    input_texts.append(line_name[input_1])
    target_texts.append(line_name[output_1])

# 2. Признаки

## 2.1 Подготовка словарей

In [25]:
def prepare_vocab(texts):
    vocab = sorted(set(''.join(texts)))
    vocab.append('<START>')
    vocab.append('<END>')
    vocab_size = len(vocab)
    char2idx = {u:i for i, u in enumerate(vocab)}
    idx2char = np.array(vocab)
    return vocab_size, char2idx, idx2char

INPUT_VOCAB_SIZE, input_char2idx, input_idx2char = prepare_vocab(input_texts)
TARGET_VOCAB_SIZE, target_char2idx, target_idx2char = prepare_vocab(target_texts)

In [None]:
input_char2idx

## 2.2 Токенизация

In [27]:
input_texts_as_int = [[input_char2idx[c] for c in text] for text in input_texts]
target_texts_as_int = [[target_char2idx[c] for c in text] for text in target_texts]

encoder_input_seqs = [np.array(text) for text in input_texts_as_int]
decoder_input_seqs = []
decoder_target_seqs = []
for target_text in target_texts_as_int:
    decoder_input_seqs.append(np.array([target_char2idx['<START>']] + target_text))
    decoder_target_seqs.append(np.array(target_text + [target_char2idx['<END>']]))

## 2.3 Паддинг

In [28]:
max([len(seq) for seq in encoder_input_seqs]), np.mean([len(seq) for seq in encoder_input_seqs]), np.median([len(seq) for seq in encoder_input_seqs])

(1903, 53.6749828532236, 34.0)

In [29]:
max([len(seq) for seq in decoder_input_seqs]), np.mean([len(seq) for seq in decoder_input_seqs]), np.median([len(seq) for seq in decoder_input_seqs])

(3047, 56.47703234423507, 36.0)

In [30]:
max_enc_seq_length = 64
max_dec_seq_length = 64

encoder_input_seqs = tf.keras.preprocessing.sequence.pad_sequences(
    encoder_input_seqs,
    value=input_char2idx[' '],
    padding='post',
    maxlen=max_enc_seq_length)

decoder_input_seqs = tf.keras.preprocessing.sequence.pad_sequences(
    decoder_input_seqs,
    value=target_char2idx[' '],
    padding='post',
    maxlen=max_dec_seq_length)

decoder_target_seqs = tf.keras.preprocessing.sequence.pad_sequences(
    decoder_target_seqs,
    value=target_char2idx[' '],
    padding='post',
    maxlen=max_dec_seq_length)

# 3. Модель

## 3.1 Обучение

In [31]:
H_SIZE = 256 # Размерность скрытого состояния LSTM
EMB_SIZE = 256 # размерность эмбеддингов (и для входных и для выходных цепочек)

class Encoder(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.embed = tf.keras.layers.Embedding(INPUT_VOCAB_SIZE, EMB_SIZE)
        self.lstm = tf.keras.layers.LSTM(H_SIZE, return_sequences=False, return_state=True)
        
    def call(self, x):
        out = self.embed(x)
        _, h, c = self.lstm(out)
        state = (h, c)
        return state

class Decoder(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.embed = tf.keras.layers.Embedding(TARGET_VOCAB_SIZE, EMB_SIZE)
        self.lstm = tf.keras.layers.LSTM(H_SIZE, return_sequences=True, return_state=True)
        self.fc = tf.keras.layers.Dense(TARGET_VOCAB_SIZE, activation='softmax')
        
    def call(self, x, init_state):
        out = self.embed(x)
        out, h, c = self.lstm(out, initial_state=init_state)
        out = self.fc(out)
        state = (h, c)
        return out, state

encoder_model = Encoder()
decoder_model = Decoder()

encoder_inputs = tf.keras.layers.Input(shape=(None,))
decoder_inputs = tf.keras.layers.Input(shape=(None,))

In [32]:
type(encoder_inputs)

keras.engine.keras_tensor.KerasTensor

In [33]:
enc_state = encoder_model(encoder_inputs)
decoder_outputs, _ = decoder_model(decoder_inputs, enc_state)

seq2seq = tf.keras.Model([encoder_inputs, decoder_inputs], decoder_outputs)

In [34]:
BATCH_SIZE = 64
EPOCHS = 100

loss = tf.losses.SparseCategoricalCrossentropy()
seq2seq.compile(optimizer='rmsprop', loss=loss, metrics=['accuracy'])
seq2seq.fit([encoder_input_seqs, decoder_input_seqs], decoder_target_seqs,
          batch_size=BATCH_SIZE,
          epochs=EPOCHS)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78

<keras.callbacks.History at 0x7f73bc35ef10>

## 3.2 Инференс

In [35]:
def seq2seq_inference(input_seq):
    state = encoder_model(input_seq)

    target_seq = np.array([[target_char2idx['<START>']]])

    decoded_sentence = ''
    while True:
        output_tokens, state = decoder_model(target_seq, state)

        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        sampled_char = target_idx2char[sampled_token_index]
        decoded_sentence += sampled_char

        if (sampled_char == '<END>' or
           len(decoded_sentence) > max_dec_seq_length):
            break

        target_seq = np.array([[sampled_token_index]])

    return decoded_sentence

# 4. Результат

In [36]:
for seq_index in range(0, 20):
    input_seq = encoder_input_seqs[seq_index: seq_index + 1]
    decoded_sentence = seq2seq_inference(input_seq)
    print('-')
    print('Input sentence:', input_texts[seq_index])
    print('Result sentence:', decoded_sentence)
    print('Target sentence:', target_texts[seq_index])

-
Input sentence: Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.
Result sentence: What do you mean?<END>
Target sentence: Well, I thought we'd start with pronunciation, if that's okay with you.
-
Input sentence: Well, I thought we'd start with pronunciation, if that's okay with you.
Result sentence: I don't know what you mean.<END>
Target sentence: Not the hacking and gagging and spitting part.  Please.
-
Input sentence: Not the hacking and gagging and spitting part.  Please.
Result sentence: What do you mean?<END>
Target sentence: Okay... then how 'bout we try out some French cuisine.  Saturday?  Night?
-
Input sentence: You're asking me out.  That's so cute. What's your name again?
Result sentence: Just a little thing.<END>
Target sentence: Forget it.
-
Input sentence: No, no, it's my fault -- we didn't have a proper introduction ---
Result sentence: You don't have to take me to the bathroom.<END>
