In [1]:
import tensorflow as tf
import plotly.graph_objects as go
from ipywidgets import widgets
import numpy as np
import math

In [2]:
# --- Controls ---------------------------------------------------------------
# Sliders: softmax temperature (gamma) and how many leading values are plotted (n).
gammaSlide = widgets.FloatSlider(value=0.2, min=0.01, max=3, step=0.01, description="gamma")
nSlide = widgets.IntSlider(value=1000, min=5, max=1000, step=5, description="n")

# Input shaping: scale-up, injected "matches", and trailing zeros.
valBoost = widgets.Checkbox(description='higher values')
matchCase = widgets.Dropdown(
    options=[('Only Random', 1), ('Single Match', 2), ('10 Repeat Matches', 3)],
    value=1,
    description='How many matches',
)
addZeros = widgets.Checkbox(description='add 25% zeros')

# Normalizations applied to the values before the softmax.
zscaleSwitch = widgets.Checkbox(description='z-normalize')
mscaleSwitch = widgets.Checkbox(description='minmax-normalize')

# Softmax variants.
squareSwitch = widgets.Checkbox(description='square softmax')
sqrtSwitch = widgets.Checkbox(description='sqrt softmax')
doubleSwitch = widgets.Checkbox(description='double softmax')
smnormSwitch = widgets.Checkbox(description='minmax normalize softmax')

# Post-processing that combines the values with the softmax.
valdivSwitch = widgets.Checkbox(description='divide values by softmax')
valmulSwitch = widgets.Checkbox(description='multiply values with softmax')

# Three rows of controls, stacked above the figure by the display cell.
container = widgets.HBox(children=[nSlide, gammaSlide, valBoost, matchCase, addZeros])
container2 = widgets.HBox(
    children=[zscaleSwitch, mscaleSwitch, squareSwitch, sqrtSwitch, doubleSwitch, smnormSwitch]
)
container3 = widgets.HBox(children=[valdivSwitch, valmulSwitch])

In [3]:
# Seed the generator so the sampled values (and every plot derived from them)
# are identical on each Restart & Run All.
SEED = 42
rng = np.random.default_rng(SEED)

# True: sample from a fixed discrete score distribution; False: synthetic gumbel scores.
realisticValues = True
if realisticValues:
    # Discrete score levels and their probabilities (must sum to 1; Generator.choice checks).
    # NOTE(review): these look like empirical frequencies from real data — TODO record the source.
    SCORE_LEVELS = [-118.4823226928711, -108.17951202392578, -105.08866882324219, -97.87670135498047, -94.78585815429688, -91.69501495361328, -86.54360961914062, -83.45276641845703, -80.36192321777344, -77.27108001708984, -76.24079895019531, -73.14995574951172, -70.05911254882812, -64.90770721435547, -63.87742614746094, -61.816864013671875, -58.72602081298828, -55.63517761230469, -54.604896545410156, -51.51405334472656, -49.4534912109375, -48.42321014404297, -43.27180480957031, -40.18096160888672, -37.090118408203125, -32.968994140625, -29.878150939941406, -26.787307739257812, -15.454216003417969]
    SCORE_PROBS = [0.5565744122640054, 0.00017362004218323987, 0.3369836411338032, 0.00016075929831781472, 0.000109316322856114, 0.08850120890992336, 0.00019934152991409024, 0.00012217706672153916, 2.5721487730850353e-05, 0.014095375276505994, 0.0002122022737795154, 9.002520705797624e-05, 1.9291115798137764e-05, 0.00023792376151036577, 0.0014468336848603323, 7.716446319255105e-05, 1.9291115798137764e-05, 6.430371932712588e-06, 0.0002636452492412161, 5.787334739441329e-05, 0.00013503781058696435, 1.9291115798137764e-05, 0.00022506301764494059, 3.215185966356294e-05, 1.2860743865425176e-05, 0.00010288595092340141, 7.073409125983846e-05, 1.2860743865425176e-05, 1.2860743865425176e-05]
    # Draw as many values as the n slider can ever request.
    values = rng.choice(SCORE_LEVELS, size=nSlide.max, p=SCORE_PROBS)
else:
    values = rng.gumbel(1, 0.5, nSlide.max)
    # Move the maximum to index 0 so it stays inside every plotted slice values[:n].
    max_idx = int(np.argmax(values))
    values[0], values[max_idx] = values[max_idx], values[0]

In [4]:
# Initial render: raw values and their plain softmax at the default slider settings.
# The title carries the same statistics (softmax line, then value line) that
# `response` writes on every update, so the first render matches later ones.
n_init = nSlide.value
x_init = list(range(n_init))
vals_init = values[:n_init]
softmax_init = tf.nn.softmax(gammaSlide.value * vals_init).numpy()

trace1 = go.Bar(x=x_init, y=vals_init, name='values')
trace2 = go.Bar(x=x_init, y=softmax_init, name='softmax')

title_text = "<br>".join(
    f"{label} var: {np.var(arr)!s}   {label} stdev: {np.std(arr)!s}"
    f"   {label} mean: {np.mean(arr)!s}   {label} median: {np.median(arr)!s}"
    for label, arr in (("softmax", softmax_init), ("value", vals_init))
)
fig = go.FigureWidget(data=[trace1, trace2], layout=go.Layout(title={'text': title_text}))

In [5]:
# Magic numbers used by the "match" scenarios and the value boost.
MATCH_SCORE = 33.4895253  # value written into "match" slots when realisticValues is True
N_REPEAT_MATCHES = 10     # leading entries turned into matches for "10 Repeat Matches"
MATCH_FACTOR = 3          # gumbel mode: all values divided by this, matches scaled back up
BOOST_FACTOR = 10         # multiplier applied when "higher values" is ticked


def _minmax_normalize(arr):
    """Return a copy of ``arr`` linearly rescaled to [0, 1].

    A constant array (zero range) comes back as all zeros instead of the
    all-NaN result a plain ``(arr - min) / (max - min)`` would give.
    """
    lo = np.min(arr)
    span = np.max(arr) - lo
    shifted = arr - lo
    return shifted / span if span != 0 else shifted


def _stats_text(label, arr):
    """Format variance, standard deviation, mean and median of ``arr`` for the figure title."""
    return (
        f"{label} var: {np.var(arr)!s}   {label} stdev: {np.std(arr)!s}"
        f"   {label} mean: {np.mean(arr)!s}   {label} median: {np.median(arr)!s}"
    )


def response(change):
    """Recompute both traces from the current widget state and redraw ``fig``.

    Registered as the ``observe`` callback of every control; ``change`` (the
    traitlets change event) is ignored because all state is re-read from widgets.

    Pipeline, in order: optional boost -> inject matches -> zero the last 25% ->
    z-score and/or min-max normalization -> softmax (optionally twice, squared,
    square-rooted) -> optional min-max of the softmax -> optional divide/multiply
    of the values by the softmax.
    """
    size = nSlide.value
    gamma = gammaSlide.value
    # Work on a float copy so the global `values` array is never mutated.
    vals = np.array(values[:size], dtype=float)

    if valBoost.value:
        vals *= BOOST_FACTOR

    n_matches = {2: 1, 3: N_REPEAT_MATCHES}.get(matchCase.value, 0)
    if n_matches:
        # Slices clamp to len(vals): with n < N_REPEAT_MATCHES (the n slider goes
        # down to 5) per-index writes would raise IndexError.
        if realisticValues:
            vals[:n_matches] = MATCH_SCORE
        else:
            vals /= MATCH_FACTOR
            vals[:n_matches] *= MATCH_FACTOR

    if addZeros.value:
        n_zeros = len(vals) // 4
        # Explicit start index: `vals[-0:]` would zero the whole array when n_zeros == 0.
        vals[len(vals) - n_zeros:] = 0

    if zscaleSwitch.value:
        std = np.std(vals)
        vals = vals - np.mean(vals)
        if std != 0:  # a constant array stays at 0 instead of becoming NaN
            vals = vals / std

    if mscaleSwitch.value:
        vals = _minmax_normalize(vals)

    softmax = vals
    for run in range(2 if doubleSwitch.value else 1):
        if run == 1:
            # "double softmax": rescale the first result to [0, 1] before the second pass.
            softmax = _minmax_normalize(softmax)
        softmax = tf.nn.softmax(gamma * softmax).numpy()
        if squareSwitch.value:
            softmax = softmax * softmax
        if sqrtSwitch.value:
            softmax = np.sqrt(softmax)

    if smnormSwitch.value:
        softmax = _minmax_normalize(softmax)

    if valdivSwitch.value:
        # NOTE: entries where softmax is 0 (always the minimum after
        # "minmax normalize softmax") divide by zero and yield inf/nan.
        vals = vals / softmax

    if valmulSwitch.value:
        vals = vals * softmax

    with fig.batch_update():
        fig.data[0].x = list(range(size))
        fig.data[1].x = list(range(size))
        fig.data[0].y = vals
        fig.data[1].y = softmax
        fig.layout.title.text = _stats_text("softmax", softmax) + "<br>" + _stats_text("value", vals)

In [6]:
# Wire every control to `response`: any change of a control's `value` trait redraws the figure.
for _control in (
    gammaSlide, nSlide, valBoost, matchCase, addZeros,
    zscaleSwitch, mscaleSwitch, squareSwitch, sqrtSwitch,
    doubleSwitch, smnormSwitch, valdivSwitch, valmulSwitch,
):
    _control.observe(response, names='value')
del _control

In [7]:
# Stack the three control rows above the figure; as the cell's last expression it is displayed.
widgets.VBox(children=[container, container2, container3, fig])

VBox(children=(HBox(children=(IntSlider(value=1000, description='n', max=1000, min=5, step=5), FloatSlider(val…