In [50]:
import re
import pickle
import random

import numpy as np
import pandas as pd
import torch
import torch.nn as nn

In [9]:
with open("plots_text.pickle", "rb") as pickle_in:
    movie_plots = pickle.load(pickle_in)
    
display(len(movie_plots))

500

In [3]:
movie_plots[0]

'barry is a private with the 101st airborne division of the united states army, stationed at fort campbell, kentucky. calpernia works as a showgirl at a transgender revue in nashville, tennessee when the two met in 1999. barry\'s roommate justin fisher  brings barry to the club where she performs. when barry and calpernia begin seeing each other regularly, fisher begins spreading rumors on base about their relationship, which appeared to be a violation of the military\'s "don\'t ask, don\'t tell" policy about discussing the sexual orientation of military personnel. barry faces increasing harassment and pressure, which explode into violence over fourth of july weekend. while calpernia performs in a pageant in nashville, barry is beaten to death in his sleep with a baseball bat by calvin glover, who had been goaded by fisher into committing the crime. the film ends with a discussion of the aftermath.'

In [11]:
movie_plots = pd.DataFrame(movie_plots, columns=["plot_summary"])
display(movie_plots)

Unnamed: 0,plot_summary
0,barry is a private with the 101st airborne div...
1,chinese exorcist one-eyebrow priest leads a p...
2,while playing baseball on a busy street in gre...
3,thadeous and fabious ([[danny mcbride are son...
4,"{{plot}} jung su-ji is a quiet, mysterious bea..."
...,...
495,the film opens with harlow as a struggling ext...
496,{{plot}} edgar and alan frog interrupt a half-...
497,{{copyedit}} after smuggling a strong box of t...
498,when the existence of a strain of plague is r...


In [12]:
p = re.compile("[^a-z' ]")
movie_plots['plot_summary'] = movie_plots['plot_summary'].apply(lambda x: p.sub("", x))

In [17]:
def create_seq(df, seq_len):
    lst = []
    for plot in df['plot_summary']:
        plot_toks = plot.split()
        if len(plot_toks) > seq_len:
            for i in range(seq_len, len(plot_toks)):
                lst.append(plot_toks[i-seq_len: i+1])
        else:
            lst.append(plot_toks)
    return lst

In [18]:
seqs = create_seq(movie_plots, 5)
display(seqs[:20])

[['barry', 'is', 'a', 'private', 'with', 'the'],
 ['is', 'a', 'private', 'with', 'the', 'st'],
 ['a', 'private', 'with', 'the', 'st', 'airborne'],
 ['private', 'with', 'the', 'st', 'airborne', 'division'],
 ['with', 'the', 'st', 'airborne', 'division', 'of'],
 ['the', 'st', 'airborne', 'division', 'of', 'the'],
 ['st', 'airborne', 'division', 'of', 'the', 'united'],
 ['airborne', 'division', 'of', 'the', 'united', 'states'],
 ['division', 'of', 'the', 'united', 'states', 'army'],
 ['of', 'the', 'united', 'states', 'army', 'stationed'],
 ['the', 'united', 'states', 'army', 'stationed', 'at'],
 ['united', 'states', 'army', 'stationed', 'at', 'fort'],
 ['states', 'army', 'stationed', 'at', 'fort', 'campbell'],
 ['army', 'stationed', 'at', 'fort', 'campbell', 'kentucky'],
 ['stationed', 'at', 'fort', 'campbell', 'kentucky', 'calpernia'],
 ['at', 'fort', 'campbell', 'kentucky', 'calpernia', 'works'],
 ['fort', 'campbell', 'kentucky', 'calpernia', 'works', 'as'],
 ['campbell', 'kentucky', 'c

In [19]:
x = []
y = []

for s in seqs:
    x.append(" ".join(s[:-1]))
    y.append(" ".join(s[1:]))

In [51]:
def create_dictionary(corpus):
    int2tok = {}
    
    for i, word in enumerate(set(" ".join(corpus).split())):
        int2tok[i] = word
        
    tok2int = {tok: i for i, tok in int2tok.items()}
    return int2tok, tok2int

int2tok, tok2int = create_dictionary(movie_plots['plot_summary'])
vocab_size = len(int2tok)

In [42]:
def vectorize(seq):
    return [tok2int[word] for word in seq.split()]

x_vector = np.array([vectorize(i) for i in x])
y_vector = np.array([vectorize(i) for i in y])

In [104]:
def make_batches(x_arr, y_arr, batch_size):
    place = 0
        
    while place <= x_arr.shape[0] - batch_size:
        x = x_arr[place:place+batch_size, :]
        y = y_arr[place:place+batch_size, :]
        place += batch_size
        yield x, y

In [105]:
class WordLSTM(nn.Module):
    def __init__(self, n_hidden=64, n_layers=4, drop_prob=0.3, lr=0.001):
        super().__init__()
        
        self.drop_prob = drop_prob
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.lr = lr
        
        self.emb_layer = nn.Embedding(vocab_size, 42)
        self.lstm = nn.LSTM(42, n_hidden, n_layers, dropout=drop_prob, batch_first=True)
        self.dropout = nn.Dropout(drop_prob)
        self.fc = nn.Linear(n_hidden, vocab_size)
        
    def forward(self, x, hidden):
        embedded = self.emb_layer(x)
        lstm_output, hidden = self.lstm(embedded, hidden)
        out = self.dropout(lstm_output)
        out = out.reshape(-1, self.n_hidden)
        out = self.fc(out)
        return out, hidden
    
    def init_hidden(self, batch_size):
        weight = next(self.parameters()).data
        hidden = (weight.new(self.n_layers, batch_size, self.n_hidden).zero_(),
                  weight.new(self.n_layers, batch_size, self.n_hidden).zero_())
        return hidden

In [106]:
del net
net = WordLSTM()
print(net)

WordLSTM(
  (emb_layer): Embedding(16592, 42)
  (lstm): LSTM(42, 64, num_layers=4, batch_first=True, dropout=0.3)
  (dropout): Dropout(p=0.3, inplace=False)
  (fc): Linear(in_features=64, out_features=16592, bias=True)
)


In [71]:
len(x_vector)

152644

In [107]:
def train(net, epochs, batch_size=32, lr=0.001, clip=1, print_every=256):
    opt = torch.optim.Adam(net.parameters(), lr=lr)
    loss_func = nn.CrossEntropyLoss()
    
    counter = 0
    net.train()
    
    for e in range(epochs):
        h = net.init_hidden(batch_size)
        
        for x, y in make_batches(x_vector, y_vector, batch_size):
            counter += 1
            inputs, targets = torch.from_numpy(x), torch.from_numpy(y)
            h = tuple([each.data for each in h])
            net.zero_grad()
            output, h = net(inputs, h)
            loss = loss_func(output, targets.view(-1))
            loss.backward()
            nn.utils.clip_grad_norm(net.parameters(), clip)
            opt.step()
            
            if counter % print_every == 0:
                print(f'Epoch: {e+1}/{epochs}',
                     f'Step: {counter}...')

In [108]:
train(net, epochs=1)

  nn.utils.clip_grad_norm(net.parameters(), clip)


Epoch: 1/1 Step: 256...
Epoch: 1/1 Step: 512...
Epoch: 1/1 Step: 768...
Epoch: 1/1 Step: 1024...
Epoch: 1/1 Step: 1280...
Epoch: 1/1 Step: 1536...
Epoch: 1/1 Step: 1792...
Epoch: 1/1 Step: 2048...
Epoch: 1/1 Step: 2304...
Epoch: 1/1 Step: 2560...
Epoch: 1/1 Step: 2816...
Epoch: 1/1 Step: 3072...
Epoch: 1/1 Step: 3328...
Epoch: 1/1 Step: 3584...
Epoch: 1/1 Step: 3840...
Epoch: 1/1 Step: 4096...
Epoch: 1/1 Step: 4352...
Epoch: 1/1 Step: 4608...


In [109]:
def predict(net, tok, h=None):
    x = np.array([[tok2int[tok]]])
    inputs = torch.from_numpy(x)
    
    h = tuple([each.data for each in h])
    out, h = net(inputs, h)
    p = nn.functional.softmax(out, dim=1).data
    
    p = p.numpy()
    p = p.reshape(p.shape[1],)
    top5idx = p.argsort()[-5:][::-1]
    sampled_tok_idx = top5idx[random.sample([0,1,2,3,4],1)[0]]
    
    return int2tok[sampled_tok_idx], h

def sample(net, size, prime):
    net.eval()
    h = net.init_hidden(1)
    toks = prime.split()
    
    for tok in toks:
        nextTok, h = predict(net, tok, h=h)
        
    toks.append(nextTok)
    
    for i in range(size-1):
        nextTok, h = predict(net, toks[-1], h)
        toks.append(nextTok)
        
    return ' '.join(toks)

In [120]:
sample(net, 12, 'star wars')

'star wars the the is and and and and a the and is and'