In [1]:
from fastNLP.embeddings import StaticEmbedding
from fastNLP import Vocabulary, BucketSampler, DataSetIter, SequentialSampler
from fastNLP import DataSet
from fastNLP.io.data_bundle import DataBundle
from nltk.tokenize import word_tokenize
from tqdm import tqdm

In [2]:
def load_data(path, keys):
    metadata = {}
    with open(path, 'r') as f:
        for line in f:
            line = line.split('+++$+++')
            line = [x.strip() for x in line]
            metadata[line[0]] = dict(zip(keys, line[1:]))
    return metadata

In [3]:
movie_titles = load_data('./data/cornell movie-dialogs corpus/movie_titles_metadata.txt', 
                         ['movie title', 'movie year', 'IMDB rating', 'no. IMDB votes', 'genres'])
movie_characters = load_data('./data/cornell movie-dialogs corpus/movie_characters_metadata.txt',
                            ['character name', 'movieID', 'movie title', 'gender', 'position in credits']) 
movie_lines = load_data('./data/cornell movie-dialogs corpus/movie_lines.txt',
                       ['characterID', 'movieID', 'character name', 'utterance'])

In [4]:
movie_lines['L12455']

{'characterID': 'u152',
 'movieID': 'm10',
 'character name': 'ROLFE',
 'utterance': 'Get some sleep.'}

In [5]:
input_sent = []
output_sent = []
speakers = []
movies = []
with open('./data/cornell movie-dialogs corpus/movie_conversations.txt', 'r') as f:
    cnt = 0
    for line in f:
        line = line.split('+++$+++')
        line = [x.strip() for x in line]
        line[3] = line[3][1:-1].split(', ')
        raw_words = [movie_lines[x[1:-1]]['utterance'] for x in line[3]]
        words = [word_tokenize(x) for x in raw_words]
        w = False
        for x in words:
            if len(x) == 0: w = True
        if w: continue
        input_sent += words[:-1]
        output_sent += words[1:]
        speakers += [line[1], line[0]] * ((len(line[3]) - 1) // 2)
        if len(line[3]) % 2 == 0: speakers += [line[1]]
        movies += [line[2]] * (len(line[3]) - 1)
        cnt += 1
        if cnt == 2000: break
dataset = DataSet({'input_sent':input_sent, 'speakers':speakers, 'movies':movies, 'output_sent':output_sent})
dataset.apply_field(lambda x: len(x), field_name='input_sent', new_field_name='seq_len')
dataset.apply_field(lambda x: x + ['<EOS>'], field_name='output_sent', new_field_name='output_sent')
print(dataset)

+----------------+----------+--------+-----------------+---------+
| input_sent     | speakers | movies | output_sent     | seq_len |
+----------------+----------+--------+-----------------+---------+
| ['Can', 'we... | u2       | m0     | ['Well', ','... | 25      |
| ['Well', ',... | u0       | m0     | ['Not', 'the... | 17      |
| ['Not', 'th... | u2       | m0     | ['Okay', '..... | 11      |
| ['You', "'r... | u2       | m0     | ['Forget', '... | 17      |
| ['No', ',',... | u2       | m0     | ['Cameron', ... | 18      |
| ['Cameron',... | u0       | m0     | ['The', 'thi... | 2       |
| ['The', 'th... | u2       | m0     | ['Seems', 'l... | 30      |
| ['Why', '?'... | u2       | m0     | ['Unsolved',... | 2       |
| ['Unsolved'... | u0       | m0     | ['That', "'s... | 28      |
| ['Gosh', ',... | u2       | m0     | ['Let', 'me'... | 11      |
| ["C'esc", '... | u2       | m0     | ['Right', '.... | 8       |
| ['Right', '... | u0       | m0     | ['I', 'do', ... | 11   

In [6]:
dataset.set_input('input_sent', 'speakers', 'seq_len')
dataset.set_target('output_sent')
train, test = dataset.split(0.005)
train, dev = train.split(0.005)
train = train[:20]
data_bundle = DataBundle()
data_bundle.set_dataset(train, 'train')
data_bundle.set_dataset(dev, 'dev')
data_bundle.set_dataset(test, 'test')
print(data_bundle)

In total 3 datasets:
	train has 20 instances.
	dev has 24 instances.
	test has 24 instances.



In [7]:
vocab = Vocabulary(min_freq=1)
vocab.from_dataset(train, field_name=['input_sent', 'output_sent'], no_create_entry_dataset=[dev, test])
vocab.index_dataset(train, field_name=['input_sent', 'output_sent'])
vocab.index_dataset(dev, field_name=['input_sent', 'output_sent'])
vocab.index_dataset(test, field_name=['input_sent', 'output_sent'])
speaker_vocab = Vocabulary()
speaker_vocab.from_dataset(train, field_name='speakers', no_create_entry_dataset=[dev, test])
speaker_vocab.index_dataset(train, field_name='speakers')
speaker_vocab.index_dataset(dev, field_name='speakers')
speaker_vocab.index_dataset(test, field_name='speakers')
print(len(vocab))
print(len(speaker_vocab))

726
51


In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time

In [9]:
class Encoder(nn.Module):
    def __init__(self, vocab_size, embedding, hidden_dim=64, dropout=0.5):
        super().__init__()

        self.embedding_dim = 100
        self.vocab_size = vocab_size
        self.embedding = embedding
        self.gru = nn.GRU(self.embedding_dim, hidden_dim, num_layers=1, dropout=0)
        self.dropout = nn.Dropout(dropout)

    def forward(self, words, input_lengths):
        # (input) words : (batch_size, seq_len)
        words = words.permute(1, 0)
        # words : (seq_len, batch_size)
        
        sorted_seq_lens = list(sorted([(idx, seq_len.item()) for
                                       idx, seq_len in zip(range(input_lengths.size(0)), list(input_lengths))],
                                      key=lambda x: x[1], reverse=True))
        idx = torch.LongTensor([x[0] for x in sorted_seq_lens])
        seq_len = torch.LongTensor([x[1] for x in sorted_seq_lens])
        new_words = torch.index_select(words, 1, idx)
        
        embedded = self.dropout(self.embedding(new_words))
        # embedded : (seq_len, batch_size, embedding_dim)

        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, seq_len)
        output, hidden = self.gru(packed)
        # output: (seq_len, batch_size, hidden_dim)
        # hidden: (num_layers, batch_size, hidden_dim)

        hidden = self.dropout(hidden[-1, :, :])
        # hidden: (batch_size, hidden_dim)
        
        return hidden

In [10]:
class Decoder(nn.Module):
    def __init__(self, embedding, speakervocab_size, speaker_dim, output_dim, hidden_dim=64, dropout=0.5):
        super().__init__()

        self.embedding_dim = 100
        self.embedding = embedding
        self.hidden_dim = hidden_dim
        self.speaker_dim = speaker_dim
        self.speaker = nn.Embedding(speakervocab_size, speaker_dim)
        self.gru = nn.GRU(self.embedding_dim + speaker_dim, hidden_dim, num_layers=1, dropout=0)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, words, lst_hidden, speaker):
        # words : (1, batch_size)
        # lst_hidden: (batch_size, hidden_dim)

        embedded = self.dropout(self.embedding(words))
        # embedded : (1, batch_size, embedding_dim)
        speaker_embedded = self.speaker(speaker).view(1, -1, self.speaker_dim)
        # speaker_embedded : (1, batch_size, speaker_dim)
        embedded = torch.cat((embedded, speaker_embedded), dim=2)
        # embedded : (1, batch_size, embedding_dim + speaker_dim)
        lst_hidden = lst_hidden.view(1, -1, self.hidden_dim)
        # lst_hidden : (1, batch_size, hidden_dim)
        output, hidden = self.gru(embedded, lst_hidden)
        # output: (1, batch_size, hidden_dim)
        # hidden: (1, batch_size, hidden_dim)

        hidden = self.dropout(hidden[-1, :, :])
        # hidden: (batch_size, hidden_dim)

        pred = self.fc(hidden)
        # result: (batch_size, output_dim)
        return hidden, pred

In [11]:
hidden_size = 64
embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-100d', lower=True)
encoder = Encoder(len(vocab), embed, hidden_size, 0)
decoder = Decoder(embed, len(speaker_vocab), 100, len(vocab), hidden_size, 0)
print(encoder)
print(decoder)

All word in the vocab have been lowered. There are 726 words, 654 unique lowered words.
Found 636 out of 654 words in the pre-training embedding.
Encoder(
  (embedding): StaticEmbedding(
    (dropout_layer): Dropout(p=0)
    (embedding): Embedding(642, 100, padding_idx=0)
  )
  (gru): GRU(100, 64)
  (dropout): Dropout(p=0)
)
Decoder(
  (embedding): StaticEmbedding(
    (dropout_layer): Dropout(p=0)
    (embedding): Embedding(642, 100, padding_idx=0)
  )
  (speaker): Embedding(51, 100)
  (gru): GRU(200, 64)
  (fc): Linear(in_features=64, out_features=726, bias=True)
  (dropout): Dropout(p=0)
)


In [12]:
def adjust_learning_rate(optimizer):
    """Sets the learning rate to the initial LR decayed by 10 every X epochs"""
    for param_group in optimizer.param_groups:
        param_group['lr'] *= 0.1

In [13]:
def Test(dev):
    batch = DataSetIter(batch_size=32, dataset=dev, sampler=SequentialSampler())
    PAD_token = 0
    EOS_token = vocab.to_index('<EOS>')
    per_list = []
    for batch_x, batch_y in batch:
        batch_size = batch_x['input_sent'].size(0)
        #print(batch_x)
        decoder_hidden = encoder(batch_x['input_sent'], batch_x['seq_len'])
        decoder_input = torch.LongTensor([[EOS_token for _ in range(batch_size)]])
        max_target_len = max([len(x) for x in batch_y['output_sent']])
        per = torch.zeros((batch_size))
        for t in range(max_target_len):
            decoder_hidden, decoder_output = decoder(
                decoder_input, decoder_hidden, batch_x['speakers']
            )
            #print(decoder_output.size())
            _, topi = decoder_output.topk(1) # [64, 1]
            pred = F.softmax(decoder_output)
            #decoder_input = torch.LongTensor([[topi[i][0] for i in range(batch_size)]])
            decoder_input = batch_y['output_sent'][:, t].view(1, -1)
            pred[:, PAD_token] = torch.ones((batch_size))
            pred = torch.gather(pred, 1, batch_y['output_sent'][:, t:t+1]).squeeze()
            per -= torch.log(pred)
            #print(decoder_input)
            #print(batch_y['output_sent'][:, t])
        #print(per, batch_x['seq_len'].float())
        per /= batch_x['seq_len'].float()
        per_list.append(sum(per))
    #print(per_list)
    return sum(per_list)/dev.get_length()

In [14]:
def Train(epoch, batch_size, decoder, encoder, data_bundle):
    encoder_optimizer = torch.optim.RMSprop(encoder.parameters(), lr=0.0001)
    decoder_optimizer = torch.optim.RMSprop(decoder.parameters(), lr=0.0001 * 5)
    sampler = BucketSampler(batch_size=batch_size, seq_len_field_name='seq_len')
    batch = DataSetIter(batch_size=batch_size, dataset=data_bundle.get_dataset('train'), sampler=sampler)
    
    PAD_token = 0
    EOS_token = vocab.to_index('<EOS>')
    
    start_time = time.time()
    print("-"*5+"start training"+"-"*5)
    best_res = 100 
    for i in range(epoch):
        if i % 200 == 199:
            adjust_learning_rate(encoder_optimizer)
            adjust_learning_rate(decoder_optimizer)
        loss_list = []
        for batch_x, batch_y in batch: # tqdm(batch):
            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()
            decoder_hidden = encoder(batch_x['input_sent'], batch_x['seq_len'])
            decoder_input = torch.LongTensor([[EOS_token for _ in range(batch_x['input_sent'].size(0))]])
            loss = 0
            max_target_len = max([len(x) for x in batch_y['output_sent']])
            pred_sent = [[] for _ in range(batch_x['input_sent'].size(0))]
            ground_truth = [[] for _ in range(batch_x['input_sent'].size(0))]
            
            use_teacher_forcing = True
            for t in range(max_target_len):
                decoder_hidden, decoder_output = decoder(
                    decoder_input, decoder_hidden, batch_x['speakers']
                )
                _, topi = decoder_output.topk(1) # [64, 1]
                if use_teacher_forcing:
                    decoder_input = batch_y['output_sent'][:, t].view(1, -1) # Next input is current target                        
                else:
                    decoder_input = torch.LongTensor([[topi[i][0] for i in range(batch_x['input_sent'].size(0))]])
                pred = F.softmax(decoder_output)
                for j in range(batch_x['input_sent'].size(0)):
                    pred_sent[j].append((topi[j][0].item()))#, pred[j, topi[j][0]].item()))
                    ground_truth[j].append((batch_y['output_sent'][j, t].item()))#, pred[j, batch_y['output_sent'][i, t].item()].item()))
                #print(batch_y['output_sent'][:, t])
                #print(loss)
                loss += F.cross_entropy(decoder_output, batch_y['output_sent'][:, t], ignore_index=PAD_token)
            
            #print('mine ', pred_sent)
            #print('truth ', ground_truth)
            loss.backward()
            encoder_optimizer.step()
            decoder_optimizer.step()
            loss_list.append(loss.item())

        res = Test(data_bundle.get_dataset('dev'))

        print('Epoch {:d} Avg Loss: {:.2f}'.format(i, sum(loss_list) / len(loss_list)), end=" ")
        print("Perplexity on Development set: {:.4f}".format(res), end=" ")
        print('{:d}ms'.format(round((time.time()-start_time)*1000)))
        
        if res < best_res:
            print('Saving')
            torch.save(embed.state_dict(), './trained_model/embed.pkl')
            torch.save(encoder.state_dict(), './trained_model/encoder.pkl')
            torch.save(decoder.state_dict(), './trained_model/decoder.pkl')
            print('Model Saved')
            # best_res = res
        
        loss_list.clear()

In [15]:
Train(300, 4, decoder, encoder, data_bundle)

-----start training-----




Epoch 0 Avg Loss: 156.99 Perplexity on Development set: 9.4556 517ms
Saving
Model Saved
Epoch 1 Avg Loss: 142.83 Perplexity on Development set: 9.1644 990ms
Saving
Model Saved
Epoch 2 Avg Loss: 131.02 Perplexity on Development set: 8.8928 1433ms
Saving
Model Saved
Epoch 3 Avg Loss: 121.90 Perplexity on Development set: 8.7000 1862ms
Saving
Model Saved
Epoch 4 Avg Loss: 115.59 Perplexity on Development set: 8.5817 2305ms
Saving
Model Saved
Epoch 5 Avg Loss: 111.16 Perplexity on Development set: 8.5077 2764ms
Saving
Model Saved
Epoch 6 Avg Loss: 107.86 Perplexity on Development set: 8.4690 3222ms
Saving
Model Saved
Epoch 7 Avg Loss: 105.20 Perplexity on Development set: 8.4454 3678ms
Saving
Model Saved
Epoch 8 Avg Loss: 102.93 Perplexity on Development set: 8.4339 4099ms
Saving
Model Saved
Epoch 9 Avg Loss: 100.90 Perplexity on Development set: 8.4239 4511ms
Saving
Model Saved
Epoch 10 Avg Loss: 99.07 Perplexity on Development set: 8.4211 4921ms
Saving
Model Saved
Epoch 11 Avg Loss: 97.3

Epoch 92 Avg Loss: 24.86 Perplexity on Development set: 9.3729 38254ms
Saving
Model Saved
Epoch 93 Avg Loss: 24.39 Perplexity on Development set: 9.3692 38706ms
Saving
Model Saved
Epoch 94 Avg Loss: 23.91 Perplexity on Development set: 9.3865 39134ms
Saving
Model Saved
Epoch 95 Avg Loss: 23.43 Perplexity on Development set: 9.4030 39570ms
Saving
Model Saved
Epoch 96 Avg Loss: 22.97 Perplexity on Development set: 9.4008 39986ms
Saving
Model Saved
Epoch 97 Avg Loss: 22.50 Perplexity on Development set: 9.4406 40412ms
Saving
Model Saved
Epoch 98 Avg Loss: 22.07 Perplexity on Development set: 9.4319 40859ms
Saving
Model Saved
Epoch 99 Avg Loss: 21.64 Perplexity on Development set: 9.4548 41291ms
Saving
Model Saved
Epoch 100 Avg Loss: 21.19 Perplexity on Development set: 9.4507 41720ms
Saving
Model Saved
Epoch 101 Avg Loss: 20.78 Perplexity on Development set: 9.4770 42190ms
Saving
Model Saved
Epoch 102 Avg Loss: 20.37 Perplexity on Development set: 9.5009 42627ms
Saving
Model Saved
Epoch 1

Epoch 183 Avg Loss: 4.00 Perplexity on Development set: 10.4308 75521ms
Saving
Model Saved
Epoch 184 Avg Loss: 3.91 Perplexity on Development set: 10.5058 75898ms
Saving
Model Saved
Epoch 185 Avg Loss: 3.87 Perplexity on Development set: 10.4891 76280ms
Saving
Model Saved
Epoch 186 Avg Loss: 3.77 Perplexity on Development set: 10.5130 76661ms
Saving
Model Saved
Epoch 187 Avg Loss: 3.69 Perplexity on Development set: 10.5232 77048ms
Saving
Model Saved
Epoch 188 Avg Loss: 3.62 Perplexity on Development set: 10.5357 77422ms
Saving
Model Saved
Epoch 189 Avg Loss: 3.56 Perplexity on Development set: 10.5249 77813ms
Saving
Model Saved
Epoch 190 Avg Loss: 3.48 Perplexity on Development set: 10.5772 78219ms
Saving
Model Saved
Epoch 191 Avg Loss: 3.43 Perplexity on Development set: 10.5380 78618ms
Saving
Model Saved
Epoch 192 Avg Loss: 3.34 Perplexity on Development set: 10.5530 78995ms
Saving
Model Saved
Epoch 193 Avg Loss: 3.28 Perplexity on Development set: 10.5934 79428ms
Saving
Model Saved

Epoch 273 Avg Loss: 2.39 Perplexity on Development set: 10.7808 112930ms
Saving
Model Saved
Epoch 274 Avg Loss: 2.37 Perplexity on Development set: 10.7827 113308ms
Saving
Model Saved
Epoch 275 Avg Loss: 2.40 Perplexity on Development set: 10.7862 113757ms
Saving
Model Saved
Epoch 276 Avg Loss: 2.36 Perplexity on Development set: 10.7879 114140ms
Saving
Model Saved
Epoch 277 Avg Loss: 2.35 Perplexity on Development set: 10.7903 114585ms
Saving
Model Saved
Epoch 278 Avg Loss: 2.38 Perplexity on Development set: 10.7927 114973ms
Saving
Model Saved
Epoch 279 Avg Loss: 2.33 Perplexity on Development set: 10.7968 115363ms
Saving
Model Saved
Epoch 280 Avg Loss: 2.34 Perplexity on Development set: 10.7972 115801ms
Saving
Model Saved
Epoch 281 Avg Loss: 2.32 Perplexity on Development set: 10.8038 116211ms
Saving
Model Saved
Epoch 282 Avg Loss: 2.32 Perplexity on Development set: 10.8011 116616ms
Saving
Model Saved
Epoch 283 Avg Loss: 2.31 Perplexity on Development set: 10.8107 117029ms
Saving


In [16]:
def Pred(test):
    batch = DataSetIter(batch_size=1, dataset=test, sampler=SequentialSampler())
    EOS_token = vocab.to_index('<EOS>')
    per_list = []
    for batch_x, batch_y in batch:
        decoder_hidden = encoder(batch_x['input_sent'], batch_x['seq_len'])
        decoder_input = torch.LongTensor([[EOS_token]])
        per = 0
        pred = []
        for t in range(20):
            decoder_hidden, decoder_output = decoder(
                decoder_input, decoder_hidden, batch_x['speakers']
            )
            _, topi = decoder_output.topk(1) # [64, 1]
            decoder_input = torch.LongTensor([[topi[0][0]]])
            pred.append(decoder_input[0, 0].item())
            if decoder_input[0, 0].item() == vocab.to_index('<EOS>'): break
        print(" ".join([vocab.to_word(x.item()) for x in list(batch_x['input_sent'][0, :])]))
        print(movie_characters[speaker_vocab.to_word(batch_x['speakers'].item())]['character name'] + " says: " + " ".join([vocab.to_word(x) for x in pred]))
        print(" ".join([vocab.to_word(x.item()) for x in list(batch_y['output_sent'][0, :])]))
        print()

In [17]:
embed.load_state_dict(torch.load('./trained_model/embed.pkl'))
encoder.load_state_dict(torch.load('./trained_model/encoder.pkl'))
decoder.load_state_dict(torch.load('./trained_model/decoder.pkl'))
Pred(train[:10])

Elaine , I 'm going back there . Just hold onto that stick and try to control this hunk of tin as best you can .
TED says: Ted , please be careful . <EOS>
Ted , please be careful . <EOS>

You 've got Penthouse , Playboy , Hustler , etc . Nobody even considers them pornography anymore . Then , there 's mainstream hardcore . Triple X . The difference is penetration . That 's hardcore . That whole industry 's up in the valley . Writers , directors , porn stars . They 're celebrities , or they think they are . They pump out 150 videos a week . A week . They 've even got a porno Academy Awards . America loves pornography . Anybody tells you they never use pornography , they 're lying . Somebody 's buying those videos . Somebody 's out there spending 900 million dollars a year on phone sex . Know what else ? It 's only gon na get worse . More and more you 'll see perverse hardcore coming into the mainstream , because that 's evolution . Desensitization . Oh my God , Elvis Presley 's wiggling