**Importing Required Libraries**

In [1]:
import random
import numpy as np

import plotly.graph_objs as go
import plotly.subplots as sp

from colors import colors
from utils import apply_repetition_penalty

import ipywidgets as widgets

# Loading Words & Creating (Sampled) Vocabulary

In [2]:
# Read the candidate word list, one word per line. The encoding is pinned so the
# result does not depend on the platform's default locale encoding.
# NOTE(review): assumes words.txt is UTF-8 (plain ASCII is a subset) — confirm.
with open('words.txt', encoding='utf-8') as words_file:
    words = words_file.read().splitlines()

In [3]:
# Draw a small, reproducible vocabulary from the full word list, then sort it
# (lexicographically) so every plot shares the same x-axis ordering.
vocab_size = 15
random.seed(0)
vocab = random.sample(words, vocab_size)
vocab.sort()

# Creating Random Logits Distribution

In [4]:
# Seed NumPy's global RNG so the synthetic logits are identical on every run,
# then draw one logit per vocabulary word from a normal(mean=0, std=0.8).
np.random.seed(seed=1618)
logits = np.random.normal(loc=0, scale=0.8, size=vocab_size)

# Defining Parameters for Inference

In [5]:
# Strength of the penalty passed to apply_repetition_penalty() on every redraw.
# The exact formula lives in utils.apply_repetition_penalty — not visible here.
repetition_penalty: int = 10

# Plotting

In [6]:
# Per-word occurrence counts, all starting at zero. The frequency slider writes
# into this array in place, and it is fed to the repetition penalty on redraw.
frequencies = np.zeros(vocab_size, np.int16)

In [8]:
def _softmax(values):
    """Convert a vector of logits into a probability mass function.

    The maximum is subtracted before exponentiating so large logits cannot
    overflow ``np.exp``. Softmax is shift-invariant, so this is the same
    distribution as ``exp(values) / exp(values).sum()``.
    """
    shifted = np.asarray(values, dtype=float) - np.max(values)
    weights = np.exp(shifted)
    return weights / weights.sum()


def _add_bar_row(fig, y, name, color, row, y_title):
    """Add one bar trace over ``vocab`` to subplot ``row`` and title its y-axis."""
    fig.add_trace(
        go.Bar(
            x=vocab,
            y=y,
            name=name,
            marker={
                'color': color
            }
        ), row=row, col=1
    )
    fig.update_yaxes(title_text=y_title, row=row, col=1)


def update_graph():
    """Redraw the five-row comparison figure inside ``output_widget``.

    Reads the notebook globals ``vocab``, ``logits``, ``frequencies`` and
    ``repetition_penalty``, and plots, top to bottom:

    1. the original logits,
    2. the current per-word frequencies,
    3. the logits after ``apply_repetition_penalty``,
    4. the softmax of those penalized logits as bars,
    5. the same probabilities drawn as a filled line.

    Takes no arguments and returns nothing; the figure is shown in the widget.
    """
    with output_widget:
        # wait=True keeps the previous figure until the new one is emitted,
        # avoiding a blank flicker on every slider move.
        output_widget.clear_output(wait=True)
        fig = sp.make_subplots(rows=5, cols=1, shared_xaxes=True)

        _add_bar_row(fig, logits, 'Original Logits Distribution',
                     colors['logits'], row=1, y_title='Logits')
        _add_bar_row(fig, frequencies, 'Frequency',
                     colors['logits'], row=2, y_title='Frequency')

        processed_logits = apply_repetition_penalty(logits, frequencies, repetition_penalty)
        _add_bar_row(fig, processed_logits,
                     'Logits Distribution After Applying Repetition Penalty (RP)',
                     colors['logits'], row=3, y_title='Logits')

        pmf = _softmax(processed_logits)
        _add_bar_row(fig, pmf, 'Probability Distribution',
                     colors['pmf'], row=4, y_title='Mass')

        # Same probabilities as row 4, drawn as a filled line to suggest a
        # density. No temperature scaling is applied in this notebook.
        fig.add_trace(
            go.Scatter(
                x=vocab,
                y=pmf,
                name='Approx. Probability Density Function',
                marker={
                    'color': colors['pmf']
                },
                fill='tozeroy',
                fillcolor=colors['pmf']
            ), row=5, col=1
        )
        fig.update_yaxes(title_text='Mass', row=5, col=1)

        fig.update_layout(
            showlegend=True,
            yaxis={'autorange': True},
            legend=dict(
                x=0.5,
                y=1.0,
                xanchor='center',
                yanchor='bottom',
                orientation='h'
            ),
            width=1000,
            height=1000
        )

        fig.show()

def update_frequencies_slider(change):
    """Move the frequency slider to the stored count of the newly chosen word.

    Registered as an observer on ``vocab_dropdown``'s ``value``; ``change['new']``
    is the word that was just selected.
    """
    selected_word = change['new']
    word_position = vocab.index(selected_word)
    frequencies_slider.value = frequencies[word_position]

def update_frequencies(change):
    """Write the slider's new value into ``frequencies`` and redraw the plot.

    Registered as an observer on ``frequencies_slider``'s ``value``. The entry
    updated is the one for the word currently selected in ``vocab_dropdown``.
    """
    new_count = change['new']
    word_position = vocab.index(vocab_dropdown.value)
    frequencies[word_position] = new_count
    update_graph()

# Interactive controls: picking a word in the dropdown moves the slider to that
# word's stored frequency; moving the slider writes the new count back into
# `frequencies` and redraws the figure.
output_widget = widgets.Output()

vocab_dropdown = widgets.Dropdown(
    options=vocab,
    # Explicit initial selection so the slider below can be initialised to match.
    value=vocab[0],
    layout=widgets.Layout(width='500px')
)

frequencies_slider = widgets.IntSlider(
    min=0,
    max=10,
    # Start at the initially selected word's count; `frequencies` holds np.int16
    # elements, so cast to a plain int for the IntSlider.
    value=int(frequencies[vocab.index(vocab_dropdown.value)]),
    step=1,
    layout=widgets.Layout(width='500px')
)

vocab_dropdown.observe(update_frequencies_slider, names='value')
frequencies_slider.observe(update_frequencies, names='value')
display(widgets.VBox(
    [
        widgets.Label('Select word to change its frequency'), vocab_dropdown,
        widgets.Label('Change frequency'), frequencies_slider,
        output_widget
    ]
))
update_graph()

VBox(children=(Label(value="Select word to change it's frequency"), Dropdown(layout=Layout(width='500px'), opt…