In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import pickle

import os
import csv
import codecs
import numpy  as np
import pandas as pd

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [2]:
corpus_name = "cornell movie-dialogs corpus"

datafile = os.path.join(corpus_name, "movie_lines.txt")

# Peek at the raw byte lines; the fields are separated by " +++$+++ ".
with open(datafile, 'rb') as raw_file:
    lines = raw_file.readlines()

for raw_line in lines[:10]:
    print(raw_line)

b'L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!\n'
b'L1044 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ They do to!\n'
b'L985 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I hope so.\n'
b'L984 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ She okay?\n'
b"L925 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Let's go.\n"
b'L924 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ Wow\n'
b"L872 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Okay -- you're gonna need to learn how to lie.\n"
b'L871 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ No\n'
b'L870 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I\'m kidding.  You know how sometimes you just become this "persona"?  And you don\'t know how to quit?\n'
b'L869 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Like my fear of wearing pastels?\n'


In [3]:
# NOTE(review): formatted_movie_lines.txt is not created anywhere in this notebook.
# It is presumably produced by a separate preprocessing step; confirm it exists
# before doing Restart & Run All. The loading cell below splits each line on '\t'.
datafile = os.path.join(corpus_name, "formatted_movie_lines.txt")

<h3>Preprocess data: lowercase, strip accents, keep only letters and . ! ?</h3>

In [4]:
import re
import unicodedata

def unicodeToAscii(s):
    '''Drop combining marks after NFD decomposition (e.g. "café" -> "cafe").

    Letters with no decomposition are left as-is; normalizeString's regex
    removes any non-ASCII characters that remain.
    '''
    decomposed = unicodedata.normalize('NFD', s)
    return ''.join(filter(lambda ch: unicodedata.category(ch) != 'Mn', decomposed))

def normalizeString(s):
    '''Lowercase, strip accents, split off . ! ? as tokens, and collapse whitespace.

    Any character other than ASCII letters and . ! ? becomes a single space.
    '''
    text = unicodeToAscii(s.lower().strip())
    text = re.sub(r"([.!?])", r" \1", text)       # detach sentence punctuation
    text = re.sub(r"[^a-zA-Z.!?]+", r" ", text)   # everything else -> space
    return re.sub(r"\s+", r" ", text).strip()

In [5]:
# Each line of the formatted file is assumed to hold a tab-separated
# (utterance, reply) pair. Use `with` so the file handle is closed
# (the bare open(...).read() left it open).
with open(datafile, encoding='utf-8') as formatted_file:
    lines = formatted_file.read().strip().split('\n')

pairs = [[normalizeString(s) for s in line.split('\t')] for line in lines]

In [6]:
# Spot-check one normalized (utterance, reply) pair.
pairs[456]

['what s a synonym for throbbing ?',
 'sarah lawrence is on the other side of the country .']

<h3>Split sentences into word tokens</h3>

In [7]:
# Whitespace-tokenize both sides of each pair; normalizeString has already
# separated . ! ? into their own tokens.
# Not idempotent: re-running this cell fails because `pairs` then holds lists.
def _split_pair(pair):
    utterance, reply = pair[0], pair[1]
    return [utterance.split(), reply.split()]

pairs = list(map(_split_pair, pairs))

<h3>Drop overly long sentence pairs</h3>

In [8]:
# Keep only pairs where BOTH sides have at most `max_seq_length` tokens.
# `min_length` is a historical misnomer for the same cap. It is kept because
# the greedy-decoding helpers later in the notebook use it as the generation limit.
max_seq_length = 16
min_length = max_seq_length
pairs = [
    pair for pair in pairs
    if len(pair[0]) <= max_seq_length and len(pair[1]) <= max_seq_length
]

In [9]:
# Number of pairs left after the length filter.
len(pairs)

126591

<h3>Train Test Split</h3>

In [10]:
import random

test_size = 0.1
SEED = 42

# Seed every RNG the notebook draws from: random.shuffle here, np.random in
# get_batch, and torch for weight initialisation. This makes Restart & Run All
# reproduce the same split and training run.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

random.shuffle(pairs)
idx = int(len(pairs) * test_size)

# The first `test_size` fraction of the shuffled pairs is held out for testing.
train_pairs, test_pairs = pairs[idx:], pairs[:idx]

In [11]:
# Sizes of the train / test splits.
len(train_pairs), len(test_pairs)

(113932, 12659)

<h3>Count words</h3>

In [12]:
from collections import Counter

# Word frequencies over both sides of the TRAINING pairs only, so the
# vocabulary is not influenced by the held-out test split.
word_count = Counter()
for utterance, reply in train_pairs:
    word_count.update(utterance)
    word_count.update(reply)

<h3>Word to Id</h3>

In [13]:
min_freq = 10

# Reserved ids for the special tokens; real words are numbered after them.
pad_idx = 0
unk_idx = 1
sos_idx = 2
eos_idx = 3

special_tokens = ["<pad>", "<unk>", "<sos>", "<eos>"]
word2id = dict(zip(special_tokens, [pad_idx, unk_idx, sos_idx, eos_idx]))

# Number the words seen at least `min_freq` times consecutively from 4, in
# first-seen order. Rarer words get no id and are later looked up as <unk>.
frequent_words = (word for word, count in word_count.items() if count >= min_freq)
for word_id, word in enumerate(frequent_words, start=4):
    word2id[word] = word_id

In [14]:
# Vocabulary size, including the 4 special tokens.
len(word2id)

5784

<h3>Tokenize</h3>

In [15]:
def _encode(tokens):
    """Map a list of words to vocabulary ids; out-of-vocabulary words become unk_idx."""
    return [word2id.get(token, unk_idx) for token in tokens]

# Each entry is [utterance_ids, reply_ids].
train_data = [[_encode(pair[0]), _encode(pair[1])] for pair in train_pairs]
test_data = [[_encode(pair[0]), _encode(pair[1])] for pair in test_pairs]

In [16]:
# Spot-check one tokenized training pair (lists of word ids).
train_data[6454]

[[16, 526, 167, 30, 81, 825, 9, 9, 9],
 [1655, 9, 9, 3888, 37, 23, 123, 46, 1156, 20]]

<h3>Get Batch</h3>

In [17]:
def padding(sequences, pad_idx, max_length=None):
    '''
    Right-pad every sequence with `pad_idx` up to a common length.

    Inputs:
        sequences: list of list of tokens
        pad_idx: token id appended as padding
        max_length: target length; defaults to the longest sequence. Sequences
            already longer than `max_length` are returned unchanged (no truncation).
    Returns:
        A new list of padded lists (inputs are not mutated); [] for empty input.
    '''
    if max_length is None:
        # default=0 so an empty batch yields [] instead of raising ValueError.
        max_length = max(map(len, sequences), default=0)

    return [seq + [pad_idx] * (max_length - len(seq)) for seq in sequences]

def get_batch(batch_size, train):
    '''
    Sample `batch_size` pairs (with replacement) and build padded model tensors.

    Inputs:
        batch_size: number of pairs to sample
        train: sample from train_data if True, else from test_data
    Returns:
        source:     (batch x source_len) LongTensor on `device`
        target_in:  (batch x target_len) <sos> + reply, used as decoder input
        target_out: (batch x target_len) reply + <eos>, used as decoder labels
    '''
    data = train_data if train else test_data

    rand_ids = np.random.randint(0, len(data), batch_size)

    source = [data[idx][0] for idx in rand_ids]
    target = [data[idx][1] for idx in rand_ids]

    # Shift by one position: the decoder reads target_in and predicts target_out.
    target_in  = [[sos_idx] + sequence for sequence in target]
    target_out = [sequence + [eos_idx] for sequence in target]

    source     = padding(source, pad_idx)
    target_in  = padding(target_in, pad_idx)
    target_out = padding(target_out, pad_idx)

    source     = torch.LongTensor(source).to(device)
    target_in  = torch.LongTensor(target_in).to(device)
    target_out = torch.LongTensor(target_out).to(device)

    return source, target_in, target_out

In [18]:
# Smoke-test the batcher with 32 random training pairs.
source, target_in, target_out = get_batch(32, train=True)

In [19]:
# target_in / target_out have the same length: each reply is padded after <sos> / <eos> is added.
source.size(), target_in.size(), target_out.size()

(torch.Size([32, 15]), torch.Size([32, 17]), torch.Size([32, 17]))

<h3>Neural Network</h3>

In [20]:
class Encoder(nn.Module):
    '''Embeds source token ids and runs a single-layer, batch-first LSTM over them.

    Only the per-timestep outputs are returned (they feed the decoder's
    attention); the LSTM's final (h, c) state is discarded.
    '''
    def __init__(self, vocab_size, emb_size, hidden_size, padding_idx):
        super().__init__()
        # Attribute names are kept as-is: they are the state_dict keys saved to disk.
        self.embedding = nn.Embedding(vocab_size, emb_size, padding_idx=padding_idx)
        self.lstm      = nn.LSTM(emb_size, hidden_size, batch_first=True)

    def forward(self, batch_words):
        '''
        Inputs:
            batch_words: (batch x source_len) LongTensor of token ids
        Returns:
            (batch x source_len x hidden_size) LSTM outputs
        '''
        outputs, _final_state = self.lstm(self.embedding(batch_words))
        return outputs
    
class Attention(nn.Module):
    '''Bilinear attention: score_s = enc_s . (W h), softmax over source positions.'''
    def __init__(self, hidden_size):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, hidden, encoder_output, encoder_mask):
        '''
        Inputs:
            hidden: (batch x hidden_size) current decoder state
            encoder_output: (batch x source_len x hidden_size)
            encoder_mask: (batch x source_len) bool, True at positions to ignore
                (e.g. padding), or None
        Returns:
            (batch x hidden_size) attention-weighted context vector
        '''
        query = self.linear(hidden).unsqueeze(2)      # (batch x hidden_size x 1)
        energies = encoder_output @ query             # (batch x source_len x 1)

        if encoder_mask is not None:
            # Large negative energy -> ~0 weight after softmax (out-of-place fill).
            energies = energies.masked_fill(encoder_mask.unsqueeze(-1), -1e16)

        weights = F.softmax(energies, dim=1)          # normalize over source positions
        return (weights * encoder_output).sum(dim=1)
    
class Decoder(nn.Module):
    '''Attention GRU decoder run step by step over the (teacher-forced) input tokens.

    At each step the token embedding is concatenated with an attention context
    over the encoder outputs and fed to a GRUCell; every hidden state is then
    projected to vocabulary logits.
    '''
    def __init__(self, vocab_size, emb_size, hidden_size, padding_idx):
        super().__init__()

        self.vocab_size = vocab_size

        # Attribute names are kept as-is: they are the state_dict keys saved to disk.
        self.embedding = nn.Embedding(vocab_size, emb_size, padding_idx=padding_idx)
        self.attn      = Attention(hidden_size)
        self.gru_cell  = nn.GRUCell(emb_size + hidden_size, hidden_size)
        self.linear    = nn.Linear(hidden_size, vocab_size)

    def forward(self, batch_trans_in, encoder_output, hidden, encoder_mask=None):
        '''
        Inputs:
            batch_trans_in: (batch x target_len) token ids
            encoder_output: (batch x source_len x hidden_size)
            hidden: (batch x hidden_size) initial decoder state
            encoder_mask: optional (batch x source_len) bool mask passed to attention
        Returns:
            logits: ((batch * target_len) x vocab_size)
            hidden: (batch x hidden_size) state after the last step
        '''
        embedded = self.embedding(batch_trans_in)

        states = []
        for step_input in embedded.unbind(dim=1):
            context = self.attn(hidden, encoder_output, encoder_mask)
            hidden = self.gru_cell(torch.cat([step_input, context], dim=1), hidden)
            states.append(hidden)

        logits = self.linear(torch.stack(states, dim=1))
        return logits.view(-1, self.vocab_size), hidden

In [21]:
# Inverse vocabulary: id -> word.
id2word = dict(zip(word2id.values(), word2id.keys()))

In [22]:
def _print(train=True):
    '''Greedily decode a reply for one random pair and print source and reply.

    Args:
        train: sample the source from train_data if True, else from test_data.
    '''
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        # Honour the caller's split choice (previously hard-coded to train=True).
        source, _, _ = get_batch(1, train=train)
        encoder_output = encoder(source)

        hidden = torch.zeros(1, hidden_size, device=device)

        generated = [sos_idx]
        for _ in range(min_length):
            generated_in = torch.tensor([[generated[-1]]], dtype=torch.long, device=device)

            logit, hidden = decoder(generated_in, encoder_output, hidden)
            next_idx = logit.argmax(dim=1)[0].item()
            generated.append(next_idx)

            if next_idx == eos_idx:
                break

    source_words = [id2word.get(idx, "unk") for idx in source[0].tolist()]
    reply_words = [id2word.get(idx, "unk") for idx in generated if idx not in (sos_idx, eos_idx)]

    print(' '.join(source_words))
    print(' '.join(reply_words))

In [23]:
emb_size    = 256
hidden_size = 512

# Both modules use the full vocabulary; pad_idx is forwarded to nn.Embedding(padding_idx=...).
# Construction order (encoder, then decoder) determines which random init each gets.
encoder = Encoder(len(word2id), emb_size, hidden_size, pad_idx).to(device)
decoder = Decoder(len(word2id), emb_size, hidden_size, pad_idx).to(device)

# Padding positions are removed in the training loop (decoder_mask), not via ignore_index.
criterion = nn.CrossEntropyLoss()

# Separate Adam optimizers (default learning rate); the training loop steps both.
encoder_optimizer = optim.Adam(encoder.parameters())
decoder_optimizer = optim.Adam(decoder.parameters())

In [24]:
# Baseline: decode one example with the untrained (randomly initialised) model.
_print(train=False)

to hide it from the i .r .s .
rifle units choice shell honor worry glass paint print clarice party farewell clarice pie budget pedro


In [25]:
losses = []          # per-batch training loss (was declared but never filled)
batch_size = 128
n_epochs = 15
log_every = 1000     # batches; with fewer than 1000 batches per epoch this logs once per epoch

for epoch in range(n_epochs):
    for batch_idx in range(len(train_data) // batch_size):

        source, target_in, target_out = get_batch(batch_size, train=True)
        encoder_output = encoder(source)
        # True at padded source positions so attention ignores them.
        encoder_mask = source == pad_idx
        hidden = torch.zeros(batch_size, hidden_size, device=device)
        logit, hidden = decoder(target_in, encoder_output, hidden, encoder_mask)

        # Flatten labels to line up with the decoder's ((batch * target_len) x vocab) logits,
        # and drop padded positions from the loss.
        target_out = target_out.view(-1)
        decoder_mask = target_out != pad_idx

        loss = criterion(logit[decoder_mask], target_out[decoder_mask])

        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()
        loss.backward()
        decoder_optimizer.step()
        encoder_optimizer.step()

        losses.append(loss.item())

        if batch_idx % log_every == 0:
            print("Epoch: %s. Loss: %s" % (epoch, losses[-1]))
            _print(True)
            print('------')
            _print(False)
            print('-------')
            print()

Epoch: 0. Loss: 8.689029693603516
aren t you ever going to get married ?
i accident t ordinary lie fooled pick pick tv nashville invitation used world slaughtered crystal nerve
------
did you see a girl come by here ?
i accident t ordinary fooled woulda plank singing paint girlfriends rifle chauncey sheldrake collect bored jump
-------

Epoch: 1. Loss: 3.575674295425415
that s none of your business .
i m not .
------
the windows don t open .
i m not .
-------

Epoch: 2. Loss: 3.3263769149780273
that s off now .
i m not .
------
he was a <unk> .
he s a <unk> .
-------

Epoch: 3. Loss: 2.8957908153533936
right now ?
i m not sure .
------
i took care of everybody .
you don t have to be <unk> .
-------

Epoch: 4. Loss: 2.8700754642486572
where the hell are you from <unk> ? ? ?
i m not . . .
------
some people would say you re paranoid .
i m not .
-------

Epoch: 5. Loss: 2.616819143295288
then we must reach her before she feels that pain .
i know .
------
. . <unk> .
i m sorry . . .
------

In [42]:
# Sample free-text utterance for the interactive demo below.
a = "Hello, my name is Zuzee"

In [43]:
def prepr(a):
    '''Convert a raw utterance into a 1-D LongTensor of vocabulary ids on `device`.

    Applies the same normalization used for the training data. normalizeString
    already calls unicodeToAscii, so that step is not repeated here. Unknown
    words map to unk_idx. An input with no letters or . ! ? yields an empty tensor.
    '''
    tokens = normalizeString(a).split()
    ids = [word2id.get(token, unk_idx) for token in tokens]
    return torch.tensor(ids, dtype=torch.long, device=device)

In [44]:
#a = [word2id.get(word, unk_idx) for word in a]
#a = torch.LongTensor(a).to(device)

In [45]:
#a = a + [pad_idx]*(min_length - len(a))

In [46]:
#a = torch.LongTensor(a).to(device)

In [47]:
def eval(source1):
    '''Greedily decode a reply to a free-text utterance, print it, and return it.

    NOTE: this name shadows Python's builtin `eval` in the notebook namespace;
    it is kept because later cells call it.

    Args:
        source1: raw input string.
    Returns:
        The generated reply as a space-joined string (special tokens removed).
    '''
    seq = prepr(source1).unsqueeze(0)    # (1 x source_len)
    if seq.size(1) == 0:
        # Nothing survived normalization (e.g. only digits or symbols). Running the
        # encoder and attention on a zero-length sequence is not meaningful, so
        # return an empty reply instead.
        return ''

    generated = [sos_idx]
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        encoder_output = encoder(seq)
        hidden = torch.zeros(1, hidden_size, device=device)

        for _ in range(min_length):
            generated_in = torch.tensor([[generated[-1]]], dtype=torch.long, device=device)

            logit, hidden = decoder(generated_in, encoder_output, hidden)
            next_idx = logit.argmax(dim=1)[0].item()
            generated.append(next_idx)

            if next_idx == eos_idx:
                break

    source_text = ' '.join(id2word.get(idx, "unk") for idx in seq[0].tolist())
    target = ' '.join(id2word.get(idx, "unk") for idx in generated if idx not in (sos_idx, eos_idx))

    print(source_text)
    print(target)
    return target

In [None]:
# `eval` here is the chatbot function defined above, not Python's builtin.
v = eval(a)

In [35]:
# NOTE: unpickling can execute arbitrary code; only load this file from a trusted source.
with open('id2word.pkl', 'wb') as pickle_file:
    pickle.dump(id2word, pickle_file, protocol=pickle.HIGHEST_PROTOCOL)

In [36]:
# NOTE: unpickling can execute arbitrary code; only load this file from a trusted source.
with open('word2id.pkl', 'wb') as pickle_file:
    pickle.dump(word2id, pickle_file, protocol=pickle.HIGHEST_PROTOCOL)

In [40]:
# Persist only the encoder's parameters (state_dict), not the module object.
torch.save(encoder.state_dict(), 'encoder.pt')

In [41]:
# Persist only the decoder's parameters (state_dict), not the module object.
torch.save(decoder.state_dict(), 'decoder.pt')