In [1]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider

In [2]:

def soft_quantize(x, k, levels=None):
    """Differentiable ("soft") approximation of a hard quantizer.

    Each input value is mapped to a softmax-weighted average of the
    quantization levels, with weights proportional to
    ``exp(-k * (x - level) ** 2)``. For large positive ``k`` the weights
    concentrate on the nearest level, so the output approaches hard rounding
    to that level. At ``k = 0`` every level gets equal weight and the output
    is the mean of the levels.

    Parameters
    ----------
    x : array_like
        Input values. Any shape is accepted, including scalars.
    k : float
        Annealing sharpness (inverse temperature) of the softmax.
    levels : array_like, optional
        1-D sequence of quantization levels. Defaults to
        ``[-1.0, -0.5, 0.0, 0.5]``, which has to stay the same as the
        QKeras quantizer being modeled.

    Returns
    -------
    numpy.ndarray
        Soft-quantized values with the same shape as ``x``.
    """
    if levels is None:
        # Have to be same as QKeras
        levels = np.array([-1.0, -0.5, 0.0, 0.5])
    else:
        levels = np.asarray(levels, dtype=float)

    x = np.asarray(x)
    # Trailing axis indexes the levels: shape x.shape + (n_levels,).
    dist = np.square(x[..., np.newaxis] - levels)
    logits = -k * dist
    # Standard softmax stabilization: shifting by the per-sample max leaves
    # the weights unchanged mathematically, but keeps the largest exponent at
    # exp(0) = 1. Without it, large k (e.g. 1000) underflows every term to 0
    # and the normalization becomes 0/0 = NaN.
    logits = logits - logits.max(axis=-1, keepdims=True)
    exp_term = np.exp(logits)
    weights = exp_term / exp_term.sum(axis=-1, keepdims=True)

    return np.sum(weights * levels, axis=-1)


In [None]:
# Interactive view of the soft quantizer: a log-scale slider sweeps the
# annealing sharpness k, and the curve is redrawn by the `update` callback.
# NOTE: the Slider only responds to input under an interactive backend
# (e.g. `%matplotlib widget` via ipympl); Jupyter's default inline backend
# renders a static image.
fig, ax = plt.subplots(figsize=(10, 7))
fig.subplots_adjust(bottom=0.25)  # leave room below the plot for the slider

x_input = np.linspace(-1.5, 1.5, 500)
initial_k = 1.0

identity_line, = ax.plot(x_input, x_input, 'k--', alpha=0.5, label='Identity (y=x)')
soft_quant_line, = ax.plot(
    x_input,
    soft_quantize(x_input, initial_k),
    'b-',
    linewidth=2.5,
    label='Soft Quantization (Trainable)',
)

ax.grid(True, linestyle=':')
ax.set_xlim(-1.5, 1.5)
ax.set_ylim(-1.5, 1.5)
ax.axhline(0, color='black', linewidth=0.5)
ax.axvline(0, color='black', linewidth=0.5)
ax.set_xlabel('Input x')
ax.set_ylabel('Output')
ax.set_title('Soft quantization under annealing')
# Both lines carry labels; without this call they are never shown.
ax.legend(loc='upper left')

ax_slider = fig.add_axes([0.2, 0.1, 0.65, 0.03])
# The slider position is log10(k), so k spans 10**0 = 1 .. 10**3 = 1000.
# valinit is derived from initial_k so the slider and the initial curve
# cannot drift apart, and valfmt makes the readout show it is an exponent.
k_slider = Slider(
    ax=ax_slider,
    label='Annealing k',
    valmin=0,
    valmax=3,
    valinit=np.log10(initial_k),
    valfmt='10^%.2f',
    color='#007ACC',
)

def update(val):
    """Slider callback: redraw the soft-quantization curve for the new k.

    The slider position is the base-10 exponent of k, so the curve is
    recomputed with ``k = 10 ** k_slider.val``. The ``val`` argument passed
    by ``Slider.on_changed`` is ignored; the value is read back from the
    slider itself.
    """
    exponent = k_slider.val
    soft_quant_line.set_ydata(soft_quantize(x_input, 10 ** exponent))
    # Request a repaint for the next event-loop pass rather than redrawing now.
    fig.canvas.draw_idle()

# Hook the callback to slider movement, then display the figure.
k_slider.on_changed(update)

plt.show()