In [5]:
# Model definition, vocabulary / first-letter data, and name-sampling helpers
import torch
import torch.nn as nn
import numpy as np
import json
import random

# First-letter frequency table: the file is read as whitespace-separated
# records of the form "<letter>,<frequency>".
with open('dino_first_letter_freq_dist.csv') as f:
    first_letter, frequency = zip(*(record.split(',') for record in f.read().split()))
# The frequencies arrive as strings; convert them to floats for use as weights.
frequency = list(map(float, frequency))

# Model vocabulary, plus lookup tables in both directions
# (token -> position and position -> token).
with open('dino_model_vocab.json') as f:
    vocab = json.load(f)
token_to_index = {token: position for position, token in enumerate(vocab)}
index_to_token = dict(enumerate(vocab))

class RNN(nn.Module):
    """Recurrent network that maps one-hot tokens to next-token logits.

    Architecture: ``nn.RNN`` -> dropout -> linear -> dropout -> linear.
    The network returns raw logits (no softmax layer), as intended for use
    with ``CrossEntropyLoss``.

    Args:
        input_size: Width of each one-hot input vector (the vocabulary size).
        hidden_size: Width of the recurrent state and of the first linear layer.
        output_size: Number of logits produced per time step.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size

        # Size the recurrent layer from the constructor argument rather than
        # from the module-level ``vocab``, so the class honours ``input_size``
        # and does not depend on a global being defined first.
        self.rnn = nn.RNN(input_size, hidden_size)
        self.linear1 = nn.Linear(hidden_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, output_size)
        # One Dropout module (p=0.1) is reused at both points in forward().
        self.dropout = nn.Dropout(0.1)

    def forward(self, input, hidden):
        """Run the network and return ``(logits, hidden)``.

        Args:
            input: One-hot encoded token(s) laid out as
                ``(seq_len, batch, input_size)`` (``nn.RNN`` is built with its
                default ``batch_first=False``). ``letterTensor`` in this
                notebook produces a single token of shape
                ``(1, 1, vocab_size)``.
            hidden: Previous hidden state, e.g. the result of ``initHidden()``
                or the ``hidden`` returned by an earlier call.

        Returns:
            A tuple ``(logits, hidden)``: ``logits`` has one row of
            ``output_size`` scores per time step, and ``hidden`` is the
            recurrent state after the last time step, ready to be passed to
            the next call.
        """
        # nn.RNN returns (output, h_n): ``output`` is the hidden state at every
        # time step, ``h_n`` is the state after the final step only. For a
        # single-step input the two hold the same values, but only h_n has the
        # right shape to be fed back in when seq_len > 1.
        output, hidden = self.rnn(input, hidden)
        x = self.dropout(output)
        x = self.linear1(x)
        x = self.dropout(x)
        logits = self.linear2(x)
        return logits, hidden

    def initHidden(self):
        """Return an all-zero initial hidden state of shape ``(1, 1, hidden_size)``."""
        return torch.zeros(1, 1, self.hidden_size)
# SECURITY NOTE: torch.load unpickles the entire model object, and unpickling
# can execute arbitrary code -- only load 'dino_rnn.pth' from a trusted source.
# weights_only=False is stated explicitly because PyTorch >= 2.6 defaults to
# weights_only=True, which refuses to load a fully pickled module like this one.
rnn = torch.load('dino_rnn.pth', weights_only=False)
# Put the model in inference mode so its Dropout layers are disabled when sampling.
rnn.eval()

def letterTensor(letter):
    """Return the one-hot encoding of ``letter`` as a float32 tensor.

    The result has shape ``(1, 1, len(vocab))`` and contains a single 1 at
    position ``token_to_index[letter]``.
    """
    # Allocate directly as float32 (numpy's default would be float64, which
    # torch would then have to convert).
    encoded = np.zeros((1, 1, len(vocab)), dtype=np.float32)
    encoded[0, 0, token_to_index[letter]] = 1
    return torch.from_numpy(encoded)

def generate(maxlen=30):
    """Sample one name, character by character, from the loaded model ``rnn``.

    The first character is drawn from the first-letter frequency table
    (``first_letter`` weighted by ``frequency``). Every following character is
    drawn from the softmax of the model's logits. Sampling stops when the
    newline token is drawn or when the name reaches ``maxlen`` characters.

    Args:
        maxlen: Maximum number of characters in the returned string. Pass
            ``None`` to disable the limit.

    Returns:
        The generated string, without the terminating newline.
    """
    # The first letter comes from the frequency table, not from the model.
    x = random.choices(first_letter, weights=frequency, k=1)[0]
    letters = []

    with torch.no_grad():
        hidden = rnn.initHidden()

        # Enforce maxlen so an unlucky run of samples cannot loop forever.
        while x != '\n' and (maxlen is None or len(letters) < maxlen):
            letters.append(x)
            # Feed the current letter and carry the hidden state forward.
            logits, hidden = rnn(letterTensor(x), hidden)
            # Convert the logits to plain Python floats: random.choices does
            # scalar arithmetic on its weights, which is slow on torch tensors.
            probs = torch.softmax(logits, dim=-1).flatten().tolist()
            # Draw the next letter from this probability distribution.
            x = random.choices(vocab, weights=probs, k=1)[0]

    return ''.join(letters)

In [19]:
# Draw one sample name; the result is random (no seed is set in the cells shown here).
generate()

'bukpnoleon'

In [28]:
import ipywidgets as widgets
from IPython.display import display
# "Generate" button and the output area that its click handler writes into.
button = widgets.Button(
    description='Generate',
    disabled=False,
    button_style='',  # one of 'success', 'info', 'warning', 'danger' or ''
)
out = widgets.Output()
def onclick(change):
    """Click handler: replace the contents of ``out`` with a freshly generated name.

    Args:
        change: Argument supplied by ``Button.on_click`` when the callback
            fires; it is not used here.
    """
    out.clear_output()
    name = generate()
    with out:
        # display() of a str shows its repr, so the name appears in quotes.
        display(name)


# Run the handler each time the button is clicked.
button.on_click(onclick)

In [29]:
# Stack the button above its output area; as the last expression in the cell,
# this container is what the notebook renders.
widgets.VBox([button, out])

VBox(children=(Button(description='Generate', style=ButtonStyle()), Output()))