In [1]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import os
from ADC.ste import ste_round, ste_floor # Import from your ste.py
from ADC.quantizers import ADCQuantizer, ADCQuantizerAshift

In [2]:
# Re-import edited project modules (e.g. ADC.ste, ADC.quantizers) automatically
# before each cell executes, so source changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2


In [3]:
def quantize_uniform_core(x, n_bits, round_fn, data_min_val, data_max_val):
    """
    Uniformly quantize ``x`` onto ``2**n_bits`` evenly spaced levels covering
    ``[data_min_val, data_max_val]``.

    Args:
        x: input tensor. The range bounds are materialized with ``x``'s dtype and device.
        n_bits: bit width; the grid has ``2**n_bits`` levels.
        round_fn: callable that maps real-valued grid positions to integer-valued
            tensors (e.g. ``torch.round`` or a straight-through variant).
        data_min_val, data_max_val: scalar clipping range.

    Returns:
        ``(dequantized, levels)``. ``dequantized`` has ``x``'s dtype. ``levels`` holds
        float32 level indices in ``[0, 2**n_bits - 1]``. If the grid has at most one
        level, or the range width is below 1e-9, every element maps to
        ``data_min_val`` with level 0.
    """
    lo = torch.tensor(data_min_val, device=x.device, dtype=x.dtype)
    hi = torch.tensor(data_max_val, device=x.device, dtype=x.dtype)
    n_levels = 2**n_bits
    span = hi - lo

    # Degenerate grid: nothing to resolve, so collapse everything onto the lower bound.
    if n_levels <= 1 or span.abs() < 1e-9:
        return torch.full_like(x, lo), torch.zeros_like(x, dtype=torch.float32)

    step = span / (n_levels - 1)
    # Clip first so out-of-range inputs land on the end levels of the grid.
    positions = (torch.clamp(x, lo, hi) - lo) / step
    # Re-clamp after rounding in case round_fn returns a value outside the grid.
    levels = torch.clamp(round_fn(positions), 0, n_levels - 1)
    # Levels 0..N-1 are kept unshifted (not centred on 0), even for symmetric ranges.
    return levels * step + lo, levels.to(torch.float32)

In [4]:
def quantize_elementwise(data, n_bits, quant_type, round_fn):
    """
    Quantize every element of ``data`` with a single range derived from the data itself.

    Args:
        data: tensor to quantize.
        n_bits: bit width forwarded to ``quantize_uniform_core``.
        quant_type: ``'affine'`` uses ``[data.min(), data.max()]``.
            ``'symmetric'`` uses ``[-max|data|, +max|data|]``.
        round_fn: rounding callable forwarded to ``quantize_uniform_core``.

    Returns:
        ``(dequantized, levels)`` as returned by ``quantize_uniform_core``.

    Raises:
        ValueError: if ``quant_type`` is neither ``'affine'`` nor ``'symmetric'``.
    """
    if quant_type not in ('affine', 'symmetric'):
        raise ValueError("quant_type must be 'affine' or 'symmetric'")

    if quant_type == 'symmetric':
        bound = data.abs().max().item()
        lo, hi = -bound, bound
        # All-zero (or vanishingly small) data: fall back to a default [-1, 1] range.
        if hi <= lo + 1e-9:
            lo, hi = -1.0, 1.0
    else:
        lo, hi = data.min().item(), data.max().item()
        # Constant data: widen to a unit-width range starting at the constant.
        if hi <= lo + 1e-9:
            hi = lo + 1.0

    return quantize_uniform_core(data, n_bits, round_fn, lo, hi)

In [5]:
def generate_normal_data(num_samples, mean=0.0, std=0.1):
    """Draw ``num_samples`` i.i.d. Normal(mean, std) values using the global torch RNG.

    Returns:
        ``(samples, label)``: a 1-D tensor and a human-readable distribution label.
    """
    label = f"Normal (μ={mean:.2f}, σ={std:.2f})"
    samples = torch.randn(num_samples) * std + mean
    return samples, label

def generate_exponential_data(num_samples, rate=1.0):
    """Draw ``num_samples`` i.i.d. Exponential(rate) values (mean ``1/rate``) using the global torch RNG.

    Returns:
        ``(samples, label)``: a 1-D tensor and a human-readable distribution label.
    """
    dist = torch.distributions.Exponential(rate)
    label = f"Exponential (rate={rate:.1f})"
    return dist.sample((num_samples,)), label

In [18]:
def _adc_product_path(w, x, bw, bx, ba, k_adc, round_fn):
    """Pre-quantize ``w``/``x`` with ``round_fn``, multiply them, and digitize the product with an M=1 ADCQuantizer.

    ``w`` is quantized symmetrically to ``bw`` bits and ``x`` affinely to ``bx`` bits
    via ``quantize_elementwise``. Their dequantized values are multiplied element-wise
    and passed through a freshly constructed ``ADCQuantizer(M=1, ...)``.

    Returns:
        Tuple ``(w_dequant, x_dequant, w_levels, x_levels, adc_output)``.
    """
    w_dequant, w_levels = quantize_elementwise(w, n_bits=bw, quant_type='symmetric', round_fn=round_fn)
    x_dequant, x_levels = quantize_elementwise(x, n_bits=bx, quant_type='affine', round_fn=round_fn)
    product = w_dequant * x_dequant
    # NOTE(review): the ADC receives products of *dequantized* floats, not products of
    # integer levels. If ADCQuantizer derives its full-scale range from (M, bx, bw) in
    # the integer-level domain, almost every product would collapse into one or two
    # codes. The saved output of the bw=bx=ba=8 run reports only 2 ADC bins.
    # Confirm which input scale ADCQuantizer expects.
    adc = ADCQuantizer(M=1, bx=bx, bw=bw, ba=ba, k=k_adc)  # M=1: element-wise product
    return w_dequant, x_dequant, w_levels, x_levels, adc(product)


def _plot_quant_levels(ax, values, title, color, *, force_hist=False, hist_threshold=float("inf")):
    """Draw the distribution of quantized values (or integer levels) on ``ax``.

    When there are at least two distinct values and either ``force_hist`` is set or
    the distinct-value count exceeds ``hist_threshold``, a 50-bin density histogram
    is drawn. Otherwise each distinct value is drawn as a bar whose height is its
    fraction of all elements in ``values``. The title is suffixed with the
    distinct-value count.
    """
    values = values.detach().reshape(-1)  # detach so .numpy() works even if grads are tracked
    unique_vals, counts = torch.unique(values, return_counts=True)  # sorted ascending by default
    n_unique = unique_vals.numel()
    total = values.numel()

    if n_unique == 0:
        ax.text(0.5, 0.5, "No data/bins", ha="center", va="center")
    elif n_unique > 1 and (force_hist or n_unique > hist_threshold):
        ax.hist(values.cpu().numpy(), bins=50, color=color, alpha=0.8, density=True)
    else:
        if n_unique > 1:
            # Width follows the tightest spacing between neighbouring levels, with fallbacks
            # so the bars never become invisible.
            min_diff = (unique_vals[1:] - unique_vals[:-1]).min().item()
            bar_width = min_diff * 0.8 if min_diff > 1e-6 else 0.8
            if bar_width < 1e-5:
                bar_width = 0.05 * unique_vals.abs().mean().item()
            if bar_width < 1e-5:
                bar_width = 0.05
        else:
            magnitude = abs(unique_vals.item())
            bar_width = 0.05 * magnitude if magnitude > 0 else 0.05
        ax.bar(unique_vals.cpu().numpy(), counts.cpu().numpy() / total,
               width=bar_width, color=color, alpha=0.85)

    ax.set_title(f"{title}\nUnique Bins/Values: {n_unique}")
    ax.set_ylabel("Normalized Freq. / Density")
    ax.grid(True)


def run_product_quantization_comparison(bw=4, bx=4, ba=8, k_adc=4, num_samples=10000, seed=None):
    """Compare how many distinct output codes three product-quantization pipelines produce.

    Samples ``w ~ Normal(0, 0.05)`` and ``x ~ Exponential(rate=0.5)`` and quantizes
    their element-wise product in three ways:

    1. Standard: round-to-nearest pre-quantization of w (symmetric, ``bw`` bits) and
       x (affine, ``bx`` bits), then affine ``ba``-bit quantization of the product.
    2. ADC-style: ``ste_floor`` pre-quantization of w and x, with the product fed to
       ``ADCQuantizer(M=1, bx, bw, ba, k=k_adc)``.
    3. Same as 2, but with ``ste_round`` pre-quantization.

    Prints the distinct-value counts and saves a 3x3 figure to
    ``ADC/analysis_results/product_quant_compare_w{bw}x{bx}_out{ba}_k{k_adc}.png``.
    The figure is closed after saving, not displayed.

    Args:
        bw, bx: bit widths used to pre-quantize w and x.
        ba: target bit width of the quantized product.
        k_adc: ``k`` argument forwarded to ``ADCQuantizer``.
        num_samples: number of (w, x) pairs to sample.
        seed: if not None, passed to ``torch.manual_seed`` before sampling so the run is
            reproducible. The default None leaves the global RNG untouched.
    """
    if seed is not None:
        torch.manual_seed(seed)

    print("\n--- Product Quantization Comparison ---")
    print(f"Params: w_bits={bw}, x_bits={bx}, final_product_bits={ba}, adc_k={k_adc}, Samples={num_samples}")

    w_orig, w_dist_name = generate_normal_data(num_samples, std=0.05)
    x_orig, x_dist_name = generate_exponential_data(num_samples, rate=0.5)  # mean = 1/rate = 2

    # --- Method 1: standard (round-to-nearest) quantization of the product ---
    w_std_dequant, _ = quantize_elementwise(w_orig, n_bits=bw, quant_type='symmetric', round_fn=torch.round)
    x_std_dequant, _ = quantize_elementwise(x_orig, n_bits=bx, quant_type='affine', round_fn=torch.round)
    _, product_std_quant_levels = quantize_elementwise(
        w_std_dequant * x_std_dequant, n_bits=ba, quant_type='affine', round_fn=torch.round)

    # --- Method 2: ADC-style, STE-floor pre-quantization of w and x ---
    (w_floor_inter_dequant, x_floor_inter_dequant,
     w_floor_inter_levels, x_floor_inter_levels,
     adc_floor_output_levels) = _adc_product_path(w_orig, x_orig, bw, bx, ba, k_adc, ste_floor)

    # --- Method 3: ADC-style, STE-round pre-quantization of w and x ---
    *_, adc_round_output_levels = _adc_product_path(w_orig, x_orig, bw, bx, ba, k_adc, ste_round)

    print("\n  --- Product Bin Counts (all target 'ba' bits) ---")
    print(f"    Standard Product Quant (affine, round) bins: {len(torch.unique(product_std_quant_levels))}")
    print(f"    ADC Output (pre w,x quant w/ STE-Floor) bins: {len(torch.unique(adc_floor_output_levels))}")
    print(f"    ADC Output (pre w,x quant w/ STE-Round) bins: {len(torch.unique(adc_round_output_levels))}")
    print("  --- Intermediate Quantized W/X Bin Counts ---")
    print(f"    Intermediate W (symm, STE-Floor, {bw}-bit) bins: {len(torch.unique(w_floor_inter_levels))}")
    print(f"    Intermediate X (aff, STE-Floor, {bx}-bit) bins: {len(torch.unique(x_floor_inter_levels))}")

    # --- Visualization ---
    fig, axes = plt.subplots(3, 3, figsize=(20, 16))
    fig.suptitle(
        f"Product Quantization Bin Comparison ({ba}-bit output)\n"
        f"Inputs: w ({bw}b-symm), x ({bx}b-aff). ADC M=1, k={k_adc}",
        fontsize=15,
    )

    # Row 0: raw sampled inputs.
    axes[0, 0].hist(w_orig.cpu().numpy(), bins=100, color='gray', alpha=0.7, density=True)
    axes[0, 0].set_title(f"Original W ({w_dist_name})")
    axes[0, 0].grid(True)
    axes[0, 1].hist(x_orig.cpu().numpy(), bins=100, color='skyblue', alpha=0.7, density=True)
    axes[0, 1].set_title(f"Original X ({x_dist_name})")
    axes[0, 1].grid(True)
    axes[0, 2].axis('off')

    # Above this many distinct values a level plot switches from bars to a histogram.
    hist_threshold = 2 * (2 ** max(bw, bx, ba))

    # Row 1: dequantized STE-floor intermediates (the inputs to the ADC-floor product).
    _plot_quant_levels(axes[1, 0], w_floor_inter_dequant,
                       f"Intermed. Quant W ({bw}b, STE-Floor)\n(Dequantized)", 'darkblue',
                       force_hist=True, hist_threshold=hist_threshold)
    _plot_quant_levels(axes[1, 1], x_floor_inter_dequant,
                       f"Intermed. Quant X ({bx}b, STE-Floor)\n(Dequantized)", 'darkgreen',
                       force_hist=True, hist_threshold=hist_threshold)
    axes[1, 2].axis('off')

    # Row 2: final product levels from the three methods.
    _plot_quant_levels(axes[2, 0], product_std_quant_levels,
                       f"Std. Product Quant Levels ({ba}-bit)\n(Affine, torch.round)", 'purple',
                       hist_threshold=hist_threshold)
    _plot_quant_levels(axes[2, 1], adc_floor_output_levels,
                       f"ADC Output Levels ({ba}-bit)\n(Pre w,x STE-Floor)", 'teal',
                       hist_threshold=hist_threshold)
    _plot_quant_levels(axes[2, 2], adc_round_output_levels,
                       f"ADC Output Levels ({ba}-bit)\n(Pre w,x STE-Round)", 'orangered',
                       hist_threshold=hist_threshold)

    for ax in axes.flat:
        ax.set_xlabel("Value / Level")

    fig.tight_layout(rect=[0, 0, 1, 0.92])
    output_dir = os.path.join("ADC", "analysis_results")
    os.makedirs(output_dir, exist_ok=True)
    plot_filename = os.path.join(output_dir, f"product_quant_compare_w{bw}x{bx}_out{ba}_k{k_adc}.png")
    fig.savefig(plot_filename)
    print(f"\nPlot saved to {plot_filename}")
    plt.close(fig)  # close this figure explicitly rather than whatever is "current"

In [19]:
# Seed the global torch RNG so the sampled w/x data, and hence the reported bin counts,
# are reproducible on Restart & Run All.
torch.manual_seed(0)
run_product_quantization_comparison(bw=8, bx=8, ba=8, k_adc=4)


--- Product Quantization Comparison ---
Params: w_bits=8, x_bits=8, final_product_bits=8, adc_k=4, Samples=10000

  --- Product Bin Counts (all target 'ba' bits) ---
    Standard Product Quant (affine, round) bins: 144
    ADC Output (pre w,x quant w/ STE-Floor) bins: 2
    ADC Output (pre w,x quant w/ STE-Round) bins: 2
  --- Intermediate Quantized W/X Bin Counts ---
    Intermediate W (symm, STE-Floor, 8-bit) bins: 186
    Intermediate X (aff, STE-Floor, 8-bit) bins: 153

Plot saved to ADC/analysis_results/product_quant_compare_w8x8_out8_k4.png
