In [1]:
# Imports, model definition, checkpoint/vocabulary loading, and the name-sampling helper
import html
import json
from urllib.parse import quote

import numpy as np
import torch
import torch.nn as nn
from thefuzz import process

class MyModel(nn.Module):
    """Character-level GRU language model.

    A single-layer GRU consumes character vectors (``sample()`` feeds it
    one-hot vectors) and a linear layer projects the GRU output back to one
    logit per vocabulary entry.

    Args:
        hidden_layer_size: Number of features in the GRU hidden state.
        vocab_size: Size of the character vocabulary; used both as the GRU
            input size and as the number of output logits.
    """

    def __init__(self, hidden_layer_size, vocab_size):
        super().__init__()
        self.hidden_layer_size = hidden_layer_size
        self.vocab_size = vocab_size

        # Attribute names must not change: pickled checkpoints refer to them.
        self.gru = nn.GRU(input_size=vocab_size, hidden_size=hidden_layer_size)
        self.h2logits = nn.Linear(hidden_layer_size, vocab_size)

    def forward(self, x, h):
        """Run the GRU on ``x`` from hidden state ``h`` and project to logits.

        Returns:
            A ``(logits, output)`` pair. ``output`` is the GRU's per-timestep
            output (not its separate final-state tensor); for a single
            timestep through this one-layer GRU the two coincide, which is
            why callers can feed it back in as the next hidden state.
        """
        gru_out = self.gru(x, h)[0]
        return self.h2logits(gru_out), gru_out

    def initHidden(self):
        """Return an all-zero initial hidden state of shape (1, hidden_layer_size)."""
        return torch.zeros((1, self.hidden_layer_size))

# Load the trained network. The checkpoint is a pickled nn.Module (not a
# state_dict), so the MyModel class must be defined before this line.
# weights_only=False is required for that on PyTorch >= 2.6, where the default
# became True; it runs arbitrary pickled code, so only load trusted files.
# map_location='cpu' because sample() builds its input tensors on the CPU.
model = torch.load('dino_rnn.pth', map_location='cpu', weights_only=False)
model.eval()

with open('dino_model_vocab.json', encoding='utf-8') as f:
    vocab = json.load(f)
# Bidirectional lookups between characters and their vocabulary positions.
chars_to_idx = {ch: i for i, ch in enumerate(vocab)}
idx_to_chars = dict(enumerate(vocab))

# Length cap for sample(): its loop draws at most n + 1 characters.
n = 20

def sample():
    """Generate one name by sampling characters from ``model``.

    Starts from ``model.initHidden()`` and an all-zero input vector, then
    repeatedly draws a character from the softmax distribution over the
    vocabulary and feeds it back in as a one-hot vector. Sampling stops once
    the newline character is drawn or after ``n + 1`` draws, whichever comes
    first.

    Returns:
        str: The drawn characters concatenated, with every character at
        vocabulary index 0 removed (presumably ``'\\n'``, the terminator —
        TODO confirm against dino_model_vocab.json). May be ``''``.
    """
    vocab_size = model.vocab_size
    newline_idx = chars_to_idx['\n']  # loop-invariant, so look it up once

    indices = []
    idx = -1
    counter = 0
    # Pure inference: no_grad stops autograd from building a graph that grows
    # through every step via the recurrent hidden state.
    with torch.no_grad():
        h_prev = model.initHidden()
        x = torch.zeros(1, vocab_size)
        while counter <= n and idx != newline_idx:
            logits, h_prev = model(x, h_prev)
            probs = torch.softmax(logits, dim=1)

            # Draw the next character index from the predicted distribution.
            idx = np.random.choice(vocab_size, p=probs.ravel().numpy())
            indices.append(idx)

            # Feed the sampled character back in as a one-hot input vector.
            x = torch.zeros(1, vocab_size)
            x[0, idx] = 1

            counter += 1

    return "".join(idx_to_chars[i] for i in indices if i != 0)

In [73]:
import ipywidgets as widgets
from IPython.display import display
button = widgets.Button(
    description='Generate',
    disabled=False,
    button_style='',  # 'success', 'info', 'warning', 'danger' or ''
)
# Output areas for the generated name and for its closest real matches.
generated_output = widgets.Output()
top5_closest_output = widgets.Output()

# Known dinosaur names, one per line (presumably lowercase to match the
# model's vocabulary — TODO confirm). Read as UTF-8 explicitly so decoding
# does not depend on the platform's locale encoding. Note: split('\n') leaves
# a trailing '' entry when the file ends with a newline.
with open('dino_names_scraped_from_wikipedia.txt', encoding='utf-8') as f:
    real_dino_names = f.read().split('\n')

def onclick(change):
    """Handle a click on the Generate button.

    Samples names until one is non-empty and not already in
    ``real_dino_names``, shows it as a heading, and lists the five closest
    real names (ranked by ``thefuzz.process.extract``) as Wikipedia links.

    Args:
        change: The argument ipywidgets passes to ``on_click`` handlers;
            unused.
    """
    generated_output.clear_output()
    top5_closest_output.clear_output()

    generated = sample()
    # Reject real names (we want a new one) and empty samples: sample() can
    # return '' when every character it drew was filtered out.
    while not generated or generated in real_dino_names:
        generated = sample()

    top5 = [match[0].capitalize()
            for match in process.extract(generated, real_dino_names, limit=5)]

    # HTML-escape interpolated text, percent-encode the URL path segment, and
    # wrap the items in a <ul> so the list markup is valid.
    items = ''.join(
        f'<li><a href="https://en.wikipedia.org/wiki/{quote(name)}">{html.escape(name)}</a></li>'
        for name in top5
    )
    with generated_output:
        display(widgets.HTML(f'<h2>{html.escape(generated.capitalize())}</h2>'))
    with top5_closest_output:
        display(widgets.HTML(f'<p>Top 5 closest real dinosaur names:</p><ul>{items}</ul>'))


button.on_click(onclick)

In [74]:
# Assemble the UI: a prompt, the Generate button, and the two output areas.
# As the cell's last expression, the VBox is rendered by Jupyter.
widgets.VBox(children=[
    widgets.Label(value='Click to generate a unique dinosaur name!'),
    button,
    generated_output,
    top5_closest_output,
])

VBox(children=(Label(value='Click to generate a unique dinosaur name!'), Button(description='Generate', style=…