In [1]:
import json
import os
import sys
sys.path.insert(0, "../../src/chung/")

import matplotlib.pyplot as plt

from ipywidgets import Layout, IntSlider, Image, Text, HBox, VBox, interact
from IPython.display import display

# Training Visualizations

This notebook shows how the prediction for a single example changes over the course of training.

### Music Sequence Modeling Visualizations

We can see how the model learns to output a music piece over the course of training. Here's the code for visualizing:

In [2]:
def visualize_music_training(dataset_name, model_name, data_type):
    # Per-dataset display options: 'hbox' picks side-by-side (HBox) or stacked
    # (VBox) layout, and 'width' is an optional width for each image.
    music_dataset_configs = {'all': {'hbox': True, 'width': '45%'},
                             'nott': {'hbox': False},
                             'jsb': {'hbox': True, 'width': '45%'},
                             'muse': {'hbox': False},
                             'piano': {'hbox': False}}
    config = music_dataset_configs[dataset_name]
    # 'hbox' only selects the container and is not an Image trait, so it is
    # kept out of the keyword arguments passed to the Image widgets.
    img_kwargs = {key: value for key, value in config.items() if key != 'hbox'}

    pred_dir = f'../../src/chung/exp/music/{dataset_name}/{model_name}/img_{data_type}'
    with open(os.path.join(pred_dir, 'gold.png'), 'rb') as f:
        img_gold = f.read()
    wig_img_gold = Image(value=img_gold, format='png', **img_kwargs)

    def visualizer(img_idx):
        # Load the prediction image saved at epoch `img_idx`.
        with open(os.path.join(pred_dir, f'{img_idx}.png'), 'rb') as f:
            img_pred = f.read()
        wig_img_pred = Image(value=img_pred, format='png', **img_kwargs)
        box_cls = HBox if config['hbox'] else VBox
        display(box_cls([wig_img_pred, wig_img_gold]))

    return visualizer

* The `dataset_name` can be one of `['all', 'nott', 'jsb', 'muse', 'piano']`.
* The `model_name` can be one of `['lstm', 'gru', 'tanh']`.
* The `data_type` can be one of `['train', 'valid']`, and selects which dataset split the example comes from.

We use the `interact` widget to visualize how the example changes, as shown below. Move the slider to see results from later epochs. The first image is the predicted music sequence and the second is the ground-truth music sequence.
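The slider ranges in the cells below are hard-coded. As a minimal sketch, assuming the prediction images are stored as `<epoch>.png` next to `gold.png` (as the visualizer code suggests), the available epochs could be discovered from the directory instead. The helper name `available_epochs` is hypothetical and not part of the original code.

```python
import os
import re


def available_epochs(pred_dir):
    """Return the sorted epoch indices that have a '<epoch>.png' file in pred_dir."""
    epochs = []
    for name in os.listdir(pred_dir):
        match = re.fullmatch(r'(\d+)\.png', name)
        if match:  # skips gold.png and any unrelated files
            epochs.append(int(match.group(1)))
    return sorted(epochs)
```

With this, `max(available_epochs(pred_dir))` could be passed as the slider's `max` so the range follows whatever was actually saved during training.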

First, let's compare GRU, LSTM, and Tanh RNN on the JSB validation set.

#### GRU on JSB Validation Set

In [3]:
visual_func = visualize_music_training('jsb', 'gru', 'valid')
_ = interact(visual_func, img_idx=IntSlider(description='Epoch:',
                            value=70, min=0, max=499, step=1,
                            layout=Layout(width='90%')))


#### LSTM on JSB Validation Set

In [4]:
visual_func = visualize_music_training('jsb', 'lstm', 'valid')
_ = interact(visual_func, img_idx=IntSlider(description='Epoch:',
                            value=70, min=0, max=499, step=1,
                            layout=Layout(width='90%')))


#### Tanh RNN on JSB Validation Set

In [5]:
visual_func = visualize_music_training('jsb', 'tanh', 'valid')
_ = interact(visual_func, img_idx=IntSlider(description='Epoch:',
                            value=70, min=0, max=499, step=1,
                            layout=Layout(width='90%')))


It seems that the GRU never really learned the notes at the top, while the LSTM generalized somewhat well for them. The top notes in the GRU predictions are quite dense when they should be sparser. The Tanh RNN, on the other hand, has trouble generalizing the basic structure of the music (such as chords) for most of the epochs.

Notes that are isolated and not really connected with others are the hardest for these models to learn. It should be noted that the LSTM seems to have generalized the structure of the music quite well, especially for the notes at the bottom.

We can also see the difference on long music pieces by comparing the LSTM with the Tanh RNN on the Nottingham dataset.

#### LSTM on Nottingham Validation Set

In [6]:
visual_func = visualize_music_training('nott', 'lstm', 'valid')
_ = interact(visual_func, img_idx=IntSlider(description='Epoch:',
                            value=221, min=0, max=499, step=1,
                            layout=Layout(width='90%')))


#### Tanh RNN on Nottingham Validation Set

In [7]:
visual_func = visualize_music_training('nott', 'tanh', 'valid')
_ = interact(visual_func, img_idx=IntSlider(description='Epoch:',
                            value=221, min=0, max=499, step=1,
                            layout=Layout(width='90%')))


It seems that the LSTM has generalized the unique structure of the top notes quite well, whereas the Tanh RNN is constantly confused about what the top notes should be. Checking around epoch 221 shows how consistently the LSTM predicts the top notes, while the Tanh RNN's predictions keep changing.

### Machine Translation Visualizations

Let's also do some visualizations on the machine translation task. Here's the code for visualizing:

In [8]:
def visualize_translation_training(dataset_name, model_name, data_type):
    pred_dir = f'../../src/chung/exp/translation/{dataset_name}/{model_name}/text_{data_type}'
    # gold.json holds the source sentence and its reference translation.
    with open(os.path.join(pred_dir, 'gold.json'), 'r') as f:
        text_gold = json.load(f)
    wig_text_src = Text(value=text_gold['src_text'], description="Source Text:", layout=Layout(width="95%"))
    wig_text_gold = Text(value=text_gold['tgt_text'], description="Gold Text:", layout=Layout(width="95%"))

    def visualizer(text_idx):
        with open(os.path.join(pred_dir, f'{text_idx}.json'), 'r') as f:
            text_pred = json.load(f)
        wig_text_pred = Text(value=text_pred['tgt_text'], description="Pred Text:", layout=Layout(width="95%"))
        wig_text_pred_nll = Text(value=str(text_pred['nll']), description="Pred NLL:", layout=Layout(width="95%"))
        box = VBox([wig_text_src, wig_text_gold, wig_text_pred, wig_text_pred_nll])
        display(box)

    return visualizer

* The `dataset_name` can be one of `['en_de', 'de_it', 'it_en']`.
* The `model_name` can be one of `['lstm', 'gru', 'tanh']`.
* The `data_type` can be one of `['train', 'valid']`, and selects which dataset split the example comes from.

The visualization also shows the negative log-likelihood (NLL) stored with the prediction at each epoch. The NLL should decrease over training, since the model is learning to assign higher probability to the gold (target) text.
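To make the NLL trend concrete, here is a small worked example. It is only a sketch: the probabilities are made up for illustration, and it assumes the NLL is the sum of per-token negative log-probabilities. The notebook does not state whether the stored value is summed or averaged over tokens.

```python
import math


def sequence_nll(token_probs):
    """Negative log-likelihood of a sequence from per-token probabilities."""
    return -sum(math.log(p) for p in token_probs)


# Hypothetical probabilities the model assigns to the same 3-token gold
# target at an early and a late point in training.
early = [0.1, 0.2, 0.1]
late = [0.6, 0.7, 0.5]

print(sequence_nll(early))  # about 6.21
print(sequence_nll(late))   # about 1.56
```

As the model puts more probability mass on the gold tokens, each log-probability moves closer to zero and the NLL shrinks.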

Let's look at how GRU, LSTM, and Tanh RNN trained on the IWSLT2017 English to German dataset:

#### GRU on IWSLT2017 English to German Validation Set

In [9]:
visual_func = visualize_translation_training('en_de', 'gru', 'valid')
_ = interact(visual_func, text_idx=IntSlider(description='Epoch:',
                            value=0, min=0, max=99, step=1,
                            layout=Layout(width='90%')))


#### LSTM on IWSLT2017 English to German Validation Set

In [10]:
visual_func = visualize_translation_training('en_de', 'lstm', 'valid')
_ = interact(visual_func, text_idx=IntSlider(description='Epoch:',
                            value=0, min=0, max=99, step=1,
                            layout=Layout(width='90%')))


#### Tanh RNN on IWSLT2017 English to German Validation Set

In [11]:
visual_func = visualize_translation_training('en_de', 'tanh', 'valid')
_ = interact(visual_func, text_idx=IntSlider(description='Epoch:',
                            value=0, min=0, max=99, step=1,
                            layout=Layout(width='90%')))


All models have learned to output "Und der", "Und das", or "Und die" when translating "And the". In later epochs, the Tanh RNN seems to have completely overfit and forgotten about "Und der". The LSTM and GRU seem to output similar gibberish throughout training.
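These qualitative impressions could be complemented by looking at the stored NLL across epochs. The following is a minimal sketch, assuming each `<epoch>.json` prediction file contains an `nll` field, as the visualizer above reads it. The helper name `load_nll_curve` is hypothetical.

```python
import json
import os


def load_nll_curve(pred_dir, num_epochs):
    """Collect (epoch, nll) pairs from '<epoch>.json' files for epochs 0..num_epochs-1."""
    curve = []
    for epoch in range(num_epochs):
        path = os.path.join(pred_dir, f'{epoch}.json')
        if not os.path.exists(path):
            continue  # tolerate missing epochs rather than failing
        with open(path, 'r') as f:
            curve.append((epoch, json.load(f)['nll']))
    return curve
```

For example, `plt.plot(*zip(*load_nll_curve(pred_dir, 100)))` would draw one curve per model directory, which makes it easier to check whether the Tanh RNN's NLL on the validation example rises again in later epochs.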