In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle

import numpy as np
from collections import Counter
import os
from argparse import Namespace

In [2]:
class RNNModule(nn.Module):
    """Word-level language model: embedding -> single-layer LSTM -> linear decoder.

    Args:
        n_vocab: number of rows in the embedding table and width of the
            output logits.
        seq_size: stored on the instance as ``seq_size``; the module itself
            never reads it.
        embedding_size: size of each embedding vector fed to the LSTM.
        lstm_size: hidden size of the LSTM.
    """

    def __init__(self, n_vocab, seq_size, embedding_size, lstm_size):
        super(RNNModule, self).__init__()
        self.seq_size = seq_size
        self.lstm_size = lstm_size
        self.embedding = nn.Embedding(n_vocab, embedding_size)
        self.lstm = nn.LSTM(embedding_size,
                            lstm_size,
                            batch_first=True)
        self.dense = nn.Linear(lstm_size, n_vocab)

    def forward(self, x, prev_state=None):
        """Run the model over a batch of token-id sequences.

        Args:
            x: integer tensor of token ids; the LSTM is built with
                ``batch_first=True``, so the leading dimension is the batch.
            prev_state: ``(h, c)`` tuple carried over from a previous call.
                When omitted, ``nn.LSTM`` starts from an all-zero state.

        Returns:
            ``(logits, state)`` where ``logits`` has ``n_vocab`` entries along
            its last dimension and ``state`` is the LSTM's new ``(h, c)``.
        """
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.dense(output)

        return logits, state

    def zero_state(self, batch_size):
        """Return an all-zero ``(h, c)`` pair shaped ``(1, batch_size, lstm_size)``.

        The tensors are created on the same device and with the same dtype as
        the module's parameters, so they can be passed to ``forward`` directly
        after the module has been moved with ``.to(...)`` or ``.double()``.
        """
        reference = self.dense.weight
        shape = (1, batch_size, self.lstm_size)
        return (torch.zeros(shape, device=reference.device, dtype=reference.dtype),
                torch.zeros(shape, device=reference.device, dtype=reference.dtype))

In [3]:
def predict(device, net, words, n_vocab, vocab_to_int, int_to_vocab, top_k=5,
            n_words=101):
    """Generate text by repeatedly sampling from the model's top-k predictions.

    The seed ``words`` are fed through ``net`` one at a time to prime the
    recurrent state. Each following word is drawn uniformly at random (via
    ``np.random.choice``) from the ``top_k`` highest-scoring vocabulary entries
    and fed back in as the next input.

    Args:
        device: torch device the input and state tensors are moved to.
        net: model exposing ``eval()``, ``zero_state(batch_size)`` and
            ``net(ix, (h, c)) -> (logits, (h, c))``.
        words: non-empty list of seed words. The list is not modified.
        n_vocab: unused; kept so existing call sites keep working.
        vocab_to_int: mapping from word to token id.
        int_to_vocab: mapping from token id to word.
        top_k: number of best candidates to sample from; clamped to the size
            of the model's output.
        n_words: number of words to generate after the seed. The default of
            101 reproduces the original fixed length (1 + 100).

    Returns:
        The seed words followed by the generated words, joined by single spaces.

    Raises:
        ValueError: if ``words`` is empty.
        KeyError: if a seed word is missing from ``vocab_to_int``.
    """
    if not words:
        raise ValueError("predict() needs at least one seed word")
    try:
        seed_ids = [vocab_to_int[w] for w in words]
    except KeyError as exc:
        raise KeyError("seed word not in vocabulary: {}".format(exc)) from exc

    net.eval()
    # Work on a copy so the caller's list is left untouched.
    generated = list(words)

    # Inference only: without no_grad the autograd graph would keep growing
    # through the recurrent state on every step.
    with torch.no_grad():
        state_h, state_c = net.zero_state(1)
        state_h = state_h.to(device)
        state_c = state_c.to(device)

        # Prime the recurrent state with the seed words.
        for token_id in seed_ids:
            ix = torch.tensor([[token_id]]).to(device)
            output, (state_h, state_c) = net(ix, (state_h, state_c))

        choice = None
        for step in range(n_words):
            if step > 0:
                # Feed the previously sampled word back in.
                ix = torch.tensor([[choice]]).to(device)
                output, (state_h, state_c) = net(ix, (state_h, state_c))

            # torch.topk raises if k exceeds the number of candidates.
            k = min(top_k, output.shape[-1])
            _, top_ix = torch.topk(output[0], k=k)
            choice = int(np.random.choice(top_ix.tolist()[0]))
            generated.append(int_to_vocab[choice])

    return ' '.join(generated)

In [4]:
# Interactive generation: load the saved vocabulary and weights, ask for seed
# words, and print the generated lyrics.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# Only load a 'dictionaries.pkl' that you produced yourself.
with open('dictionaries' + '.pkl', 'rb') as f:
    diction = pickle.load(f)

# Positions 0-5 of the pickled object hold the hyper-parameters and vocabulary
# maps, in this order.
(seq_size, n_vocab, vocab_to_int,
 embedding_size, lstm_size, int_to_vocab) = (diction[i] for i in range(6))

net = RNNModule(n_vocab, seq_size,
                embedding_size, lstm_size)

# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
net.load_state_dict(torch.load('LSTM_gener', map_location=device))

net = net.to(device)

# Keep asking until every seed word is in the vocabulary; an unknown word would
# otherwise fail with a KeyError inside predict().
while True:
    # split() without arguments drops the empty tokens that repeated spaces
    # would otherwise produce.
    initial_words = input("Please the word: ").split()
    if not initial_words:
        continue
    unknown_words = [w for w in initial_words if w not in vocab_to_int]
    if unknown_words:
        print("Sorry, METAL SONGS DO NOT USE THIS POPPY WORDS!")
        continue
    break

METAL_text = predict(device, net, initial_words, n_vocab,
                     vocab_to_int, int_to_vocab, top_k=5)

# Re-attach punctuation and line breaks that are stored as stand-alone tokens.
print((METAL_text.strip().lower()
    .replace(' \n ', '\n')
    .replace(' ! ', '! ')
    .replace(' ? ', '? ')
    .replace(' , ', ', ')))

Please the word: dragon from the sky
dragon from the sky is no longer to the world to the world is
i can never be a man ,
i'm my head to be, i'm my head
i am
and i'm the last man
i don't believe, i'm a man and
you have been, i'm no one and you are
the night is my mind and you will see you to me to be the one
i'm my mind
and we got a new
the one can you see the end, i'm not, the land
the last
