In [1]:
import ipywidgets as widgets

import torch

from nnsum.data import Seq2SeqDataLoader
from nnsum.datasets import CopyDataset
import nnsum.embedding_context as ec
import nnsum.seq2seq as s2s

In [50]:
def _int_slider(description, value, min_value, max_value):
    """Build a horizontal integer slider with the settings shared by every
    slider in the training panel.

    ``continuous_update=False`` means the value is only committed when the
    user releases the handle, not on every intermediate drag position.
    """
    return widgets.IntSlider(
        value=value,
        min=min_value,
        max=max_value,
        step=1,
        description=description,
        disabled=False,
        continuous_update=False,
        orientation='horizontal',
        readout=True,
        readout_format='d',
    )


def _train_copy_model(vocab_size, max_length, train_size, test_size,
                      batch_size, learning_rate, epochs, embedding_dim,
                      hidden_dim, rnn_cell, rnn_layers, attention,
                      seed=7492023):
    """Train an RNN encoder-decoder on ``CopyDataset`` and return the model.

    This is the widget-free core of the training panel, so a run can also be
    reproduced from code alone.

    Args:
        vocab_size, max_length, train_size, test_size: forwarded positionally
            to ``CopyDataset`` (presumably vocabulary size, maximum sequence
            length and number of examples -- TODO confirm against nnsum).
        batch_size: batch size of the training ``Seq2SeqDataLoader``.
        learning_rate: SGD learning rate.
        epochs: number of passes over the training data. One data seed is
            drawn per epoch.
        embedding_dim, hidden_dim, rnn_cell, rnn_layers, attention:
            forwarded to the embedding contexts / ``RNNEncoder`` /
            ``RNNDecoder``.
        seed: seed passed to ``torch.manual_seed`` before anything random is
            drawn. Note this resets torch's *global* RNG.

    Returns:
        The trained ``s2s.EncoderDecoderBase`` model.

    Prints a running average cross-entropy per epoch to stdout.
    """
    # Keep the order of random draws below unchanged: the test seed, the
    # per-epoch train seeds, dataset construction and parameter init all
    # consume the global torch RNG, so reordering changes every run.
    torch.manual_seed(seed)
    test_seed = torch.LongTensor(1).random_(0, 2**31).item()
    train_seeds = torch.LongTensor(epochs).random_(0, 2**31).tolist()
    test_dataset = CopyDataset(vocab_size, max_length, test_size,
                               random_seed=test_seed)
    train_dataset = CopyDataset(vocab_size, max_length, train_size,
                                random_seed=train_seeds[0])

    src_vocab = ec.Vocab.from_word_list(train_dataset.word_list(), pad="<PAD>",
                                        unk="<UNK>", start="<START>")
    src_vocabs = {"tokens": src_vocab}

    # The target vocabulary additionally has a stop symbol.
    tgt_vocab = ec.Vocab.from_word_list(train_dataset.word_list(), pad="<PAD>",
                                        unk="<UNK>", start="<START>",
                                        stop="<STOP>")
    tgt_vocabs = {"tokens": tgt_vocab}

    train_dataloader = Seq2SeqDataLoader(train_dataset, src_vocabs, tgt_vocabs,
                                         batch_size=batch_size,
                                         include_original_data=True)
    # NOTE(review): test_dataloader is built but never used for evaluation.
    # TODO: add an evaluation pass, or drop it together with the Test Size
    # slider.
    test_dataloader = Seq2SeqDataLoader(test_dataset, src_vocabs, tgt_vocabs,
                                        batch_size=2,
                                        include_original_data=True)

    src_emb_ctx = ec.EmbeddingContext(src_vocab, embedding_size=embedding_dim,
                                      name="tokens")
    enc = s2s.RNNEncoder(src_emb_ctx, hidden_dim=hidden_dim,
                         num_layers=rnn_layers, rnn_cell=rnn_cell)

    tgt_emb_ctx = ec.EmbeddingContext(tgt_vocab, embedding_size=embedding_dim,
                                      name="tokens")
    dec = s2s.RNNDecoder(tgt_emb_ctx, hidden_dim=hidden_dim,
                         num_layers=rnn_layers, attention=attention,
                         rnn_cell=rnn_cell)

    model = s2s.EncoderDecoderBase(enc, dec)
    model.initialize_parameters()

    # Padding positions are excluded from the loss via the pad index.
    loss_func = s2s.CrossEntropyLoss(tgt_vocab.pad_index)

    model.train()
    optim = torch.optim.SGD(model.parameters(), lr=learning_rate,
                            weight_decay=.0001)

    for epoch, epoch_seed in enumerate(train_seeds):
        # Re-seed the training data each epoch so every epoch sees a
        # different (but reproducible) sample.
        train_dataset.seed(epoch_seed)
        total_xent = 0
        total_tokens = 0
        for step, batch in enumerate(train_dataloader, 1):
            optim.zero_grad()
            model_state = model(batch)
            loss = loss_func(model_state,
                             batch["target_output_features"]["tokens"],
                             batch["target_lengths"])
            loss.backward()
            # Weight each batch's loss by its token count to get a
            # per-token average over the whole epoch (assumes the loss is
            # a per-token mean -- TODO confirm for s2s.CrossEntropyLoss).
            num_tokens = batch["target_lengths"].sum().item()
            total_xent += loss.item() * num_tokens
            total_tokens += num_tokens
            optim.step()
            # "\r" rewrites the same line so progress does not flood output.
            print("\rEpoch={} Step={}/{}  Avg. X-Entropy={:0.5f}".format(
                epoch, step, len(train_dataloader), total_xent / total_tokens),
                end="", flush=True)
        print()

    return model


def make_dataset_widget(seed=7492023):
    """Build an interactive panel for training a seq2seq model on the copy task.

    The panel has three rows of controls (dataset, trainer, model), a
    "Train Model!" button, and an output area. Clicking the button reads the
    current control values and runs ``_train_copy_model`` with them. Training
    progress is printed into the output area.

    Args:
        seed: torch seed used for every training run started from the panel.
            Identical settings therefore reproduce identical runs.

    Returns:
        ``widgets.VBox`` holding the control rows, the button and the
        ``widgets.Output`` that captures training output.
    """
    # Dataset controls.
    vocab_size_slider = _int_slider('Vocab Size:', 25, 1, 50000)
    max_length_slider = _int_slider('Max Length:', 10, 1, 500)
    train_size_slider = _int_slider('Training Size:', 2048, 1, 20000)
    test_size_slider = _int_slider('Test Size:', 100, 1, 20000)
    dataset_box = widgets.Box([vocab_size_slider, max_length_slider,
                               train_size_slider, test_size_slider])

    # Optimisation controls.
    batch_size_slider = _int_slider('Batch Size:', 32, 1, 2048)
    learning_rate_text = widgets.BoundedFloatText(
        value=0.1,
        min=0,
        max=100.0,
        step=0.05,
        description='Learning Rate:',
        disabled=False,
    )
    training_epochs_slider = _int_slider('Training Epochs:', 50, 1, 300)
    trainer_box = widgets.Box([batch_size_slider, learning_rate_text,
                               training_epochs_slider])

    # Architecture controls.
    embedding_dim_slider = _int_slider('Embedding Dim:', 100, 1, 1024)
    hidden_dim_slider = _int_slider('Hidden Dim:', 100, 1, 1024)
    rnn_cell_dropdown = widgets.Dropdown(
        options=['rnn', 'gru', 'lstm'],
        value='gru',
        description='RNN Cell:',
        disabled=False,
    )
    rnn_layers_slider = _int_slider('RNN Layers:', 1, 1, 3)
    attention_dropdown = widgets.Dropdown(
        options=['none', 'dot'],
        value='dot',
        description='Attention:',
        disabled=False,
    )
    model_box = widgets.Box([embedding_dim_slider, hidden_dim_slider,
                             rnn_cell_dropdown, rnn_layers_slider,
                             attention_dropdown])

    start_train = widgets.Button(
        description='Train Model!',
        disabled=False,
        button_style='',  # 'success', 'info', 'warning', 'danger' or ''
        tooltip='Click me',
        icon='check',
    )

    # Output produced inside a widget callback is not tied to the cell that
    # created the widget, so route it into an Output widget in the panel.
    output = widgets.Output()

    def on_button_clicked(_button):
        with output:
            _train_copy_model(
                vocab_size=vocab_size_slider.value,
                max_length=max_length_slider.value,
                train_size=train_size_slider.value,
                test_size=test_size_slider.value,
                batch_size=batch_size_slider.value,
                learning_rate=learning_rate_text.value,
                epochs=training_epochs_slider.value,
                embedding_dim=embedding_dim_slider.value,
                hidden_dim=hidden_dim_slider.value,
                rnn_cell=rnn_cell_dropdown.value,
                rnn_layers=rnn_layers_slider.value,
                attention=attention_dropdown.value,
                seed=seed,
            )

    start_train.on_click(on_button_clicked)

    return widgets.VBox([dataset_box, trainer_box, model_box, start_train,
                         output])



In [52]:
# Build the training panel and render it as this cell's last expression.
training_panel = make_dataset_widget()
training_panel

VBox(children=(Box(children=(IntSlider(value=25, continuous_update=False, description='Vocab Size:', max=50000…