In [1]:
import numpy as np
from scipy.stats import norm

def compute_mse(grid):
    """
    Compute the mean squared error (MSE) of quantizing X ~ N(0, 1) onto `grid`.

    Each sample is mapped to its nearest grid point, so the decision boundaries
    are the midpoints between neighbouring (sorted) grid points, and the two
    outermost cells extend to -inf and +inf.

    For a cell (l, r] with center c the contribution is
        E_i = M2 - 2*c*M1 + c**2 * P
    where the truncated moments of the standard normal are evaluated in closed
    form instead of by numerical integration:
        P  = Phi(r) - Phi(l)
        M1 = phi(l) - phi(r)
        M2 = P + l*phi(l) - r*phi(r)

    Parameters
    ----------
    grid : array_like of float
        Quantization centers. They are sorted internally, so any order is
        accepted. Repeated values are allowed; the zero-width cell between
        them contributes nothing.

    Returns
    -------
    numpy.float64
        E[(X - Q(X))**2], where Q rounds to the nearest grid point.

    Raises
    ------
    ValueError
        If `grid` contains no points.
    """
    centers = np.sort(np.asarray(grid, dtype=float).ravel())
    if centers.size == 0:
        raise ValueError("grid must contain at least one quantization level")

    # Quantization boundaries: midpoints between neighbours, open-ended at both sides.
    midpoints = (centers[:-1] + centers[1:]) / 2
    edges = np.concatenate(([-np.inf], midpoints, [np.inf]))

    cdf = norm.cdf(edges)
    pdf = norm.pdf(edges)

    # t * phi(t) -> 0 as t -> +/-inf; assign it explicitly to avoid inf * 0 = nan.
    finite = np.isfinite(edges)
    t_pdf = np.zeros_like(edges)
    t_pdf[finite] = edges[finite] * pdf[finite]

    # Probability, first and second moments of every cell.
    P = np.diff(cdf)
    M1 = pdf[:-1] - pdf[1:]
    M2 = P + t_pdf[:-1] - t_pdf[1:]

    # Total MSE is the sum of the per-cell errors.
    return np.sum(M2 - 2 * centers * M1 + centers**2 * P)


def get_uniform_grid(a, N):
    """Return N evenly spaced quantization centers spanning [-a, a], endpoints included."""
    half_width = a
    return np.linspace(start=-half_width, stop=half_width, num=N)

def get_fp4_grid(a:float=1):
    """
    Return the 16 FP4 code values, sorted ascending and multiplied by `a`.

    The unscaled grid is +/-{0, 0.5, 1, 1.5, 2, 3, 4, 6}. Zero appears twice
    because both the +0 and the -0 code map to the same value.
    """
    signs = (1, -1)

    # The two zero codes.
    levels = [+0, -0]

    # Normal codes: exponent field e in 1..3, mantissa index m in 1..2.
    for sign in signs:
        for e in range(1, 4):
            scale = 2**(e-1)
            for m in range(1, 3):
                levels.append(sign * (1+m)/2 * scale)

    # Subnormal codes: a single non-zero mantissa at the smallest scale.
    for m in range(1, 2):
        for sign in signs:
            levels.append(sign * (0+m) * 2**(-1))

    levels.sort()
    return a * np.array(levels)


In [2]:
# Minimum MSE found per grid: integer keys are uniform-grid bit-widths, "fp4" is the FP4 grid.
GRID_MSES = dict()

In [3]:
from scipy.optimize import minimize
from tqdm.auto import tqdm

# Sweep uniform grids from 1 to 8 bits and fit the best half-width for each.
for bits in range(1, 9):
    # A grid with `bits` bits has 2**bits quantization levels.
    N = 1 << bits

    def objective(a):
        """MSE of the N-level uniform grid spanning [-a[0], a[0]]."""
        return compute_mse(get_uniform_grid(a[0], N))

    # Start the search at a = 2 and confine it to [0.1, 10] so 'a' stays positive.
    a0 = [2.0]
    bounds = [(0.1, 10.0)]

    result = minimize(objective, x0=a0, method='L-BFGS-B', bounds=bounds)

    # Best half-width found and the MSE it achieves.
    optimal_a, minimum_mse = result.x[0], result.fun

    GRID_MSES[bits] = minimum_mse
    print(f"{bits}: optimal scaling parameter (a): {optimal_a}, Minimum MSE: {minimum_mse}")


1: optimal scaling parameter (a): 0.7978845587140913, Minimum MSE: 0.3633802276324186
2: optimal scaling parameter (a): 1.493534520977036, Minimum MSE: 0.11884605038769407
3: optimal scaling parameter (a): 2.0510679063024964, Minimum MSE: 0.03743965939152373
4: optimal scaling parameter (a): 2.5139324513630887, Minimum MSE: 0.011542884500323213
5: optimal scaling parameter (a): 2.9160897658147453, Minimum MSE: 0.003495211376111403
6: optimal scaling parameter (a): 3.276597435983721, Minimum MSE: 0.0010400475795804263
7: optimal scaling parameter (a): 3.6010436416224247, Minimum MSE: 0.00030436603842457166
8: optimal scaling parameter (a): 3.884997364699907, Minimum MSE: 8.782117814336654e-05


In [8]:
from scipy.optimize import minimize

# The FP4 grid has 16 entries; N is kept for reference only, since get_fp4_grid takes no level count.
N = 16


def objective(a):
    """MSE of the FP4 grid scaled by a[0]."""
    return compute_mse(get_fp4_grid(a[0]))


# Start from the unscaled grid and confine the scale to [0.1, 10] so it stays positive.
a0 = [1.0]
bounds = [(0.1, 10.0)]

result = minimize(objective, x0=a0, method='L-BFGS-B', bounds=bounds)

# Best scale found and the MSE it achieves.
optimal_a, minimum_mse = result.x[0], result.fun

GRID_MSES["fp4"] = minimum_mse

print(f"Optimal scaling parameter (a): {optimal_a}")
print(f"Minimum MSE: {minimum_mse}")


Optimal scaling parameter (a): 0.487079483934662
Minimum MSE: 0.012684904138719949


In [9]:
# Display the minimum MSEs collected so far, keyed by bit-width or by "fp4".
GRID_MSES

{1: np.float64(0.3633802276324186),
 2: np.float64(0.11884605038769407),
 3: np.float64(0.03743965939152373),
 4: np.float64(0.011542884500323213),
 8: np.float64(8.782117814336654e-05),
 'fp4': np.float64(0.012684904138719949)}

In [None]:
from scipy.optimize import minimize

# Three-level uniform grid: centers at -a, 0 and +a.
N = 3


def objective(a):
    """MSE of the N-level uniform grid spanning [-a[0], a[0]]."""
    return compute_mse(get_uniform_grid(a[0], N))


# Start the search at a = 2 and confine it to [0.1, 10] so 'a' stays positive.
a0 = [2.0]
bounds = [(0.1, 10.0)]

result = minimize(objective, x0=a0, method='L-BFGS-B', bounds=bounds)

# Best half-width found and the MSE it achieves.
optimal_a, minimum_mse = result.x[0], result.fun

print(f"Optimal scaling parameter (a): {optimal_a}")
print(f"Minimum MSE: {minimum_mse}")


Optimal scaling parameter (a): 1.2240089519030855
Minimum MSE: 0.19017403925019966


In [3]:
# Display the uniform grid built from the current `optimal_a` and `N` globals.
get_uniform_grid(optimal_a, N)

array([-1.22400895,  0.        ,  1.22400895])