In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.utils.data as data
#import timm

# ---- Hyperparameters ----
BATCH_SIZE = 128
NUM_BITS = 8            # bits per weight -> 2**NUM_BITS quantization levels
FIXED_T = 100.5         # Fixed temperature for soft rounding (softmax sharpness)
LR = 0.001              # Learning rate of the per-layer Adam optimizer
NUM_ITERATIONS = 100    # Per-layer optimization iterations

# ---- CIFAR-10 test split, normalised per channel with mean 0.5 / std 0.5 ----
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

# The train split is not used in this experiment:
#train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.CIFAR10(
    root="./data", train=False, download=True, transform=transform
)
test_loader = data.DataLoader(test_dataset, shuffle=False, batch_size=BATCH_SIZE)

# Build the CIFAR-10 ResNet-18 from the local `resnet` module (not timm) and
# load the fine-tuned weights from a local checkpoint.
import resnet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = resnet.resnet18(pretrained=False, device=device)
model.to(device)

# weights_only=True: only tensors/plain containers are unpickled, never arbitrary objects.
state_dict = torch.load('/content/resnet18.pt', map_location=torch.device('cpu'), weights_only=True)
# strict=False silently tolerates key mismatches; report them so a partly
# random (un-loaded) model is not mistaken for the pretrained one.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"⚠️ Checkpoint key mismatch: missing={load_result.missing_keys}, "
          f"unexpected={load_result.unexpected_keys}")
model.eval()

print("\nAccuracy BEFORE Quantization:")

# Function to evaluate model accuracy
def evaluate(model, test_loader):
    """Print (and return) the top-1 accuracy of `model` on `test_loader`.

    Runs on the device of the model's first parameter (CPU for a
    parameterless model), in eval mode and without gradients.

    Args:
        model: classifier; assumed to return (batch, num_classes) scores.
        test_loader: iterable of (images, labels) batches.

    Returns:
        Accuracy in percent, or None when the loader yields no samples.
    """
    model.eval()
    # Follow the model instead of relying on a notebook-global `device`.
    first_param = next(model.parameters(), None)
    eval_device = first_param.device if first_param is not None else torch.device("cpu")
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(eval_device), labels.to(eval_device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        # Avoid ZeroDivisionError on an empty loader.
        print("Test Accuracy: n/a (empty loader)")
        return None
    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

# ====== Hook for capturing activations ======
# Maps a conv module's name to the (detached) input tensor of its latest forward call.
temp_activations = {}


def activation_hook(layer_name):
    """Build a forward hook that records a module's first positional input.

    The detached tensor is stored in `temp_activations[layer_name]`,
    replacing whatever was recorded there before.
    """
    def _record_input(module, inputs, output):
        first_input = inputs[0]
        temp_activations[layer_name] = first_input.detach()

    return _record_input

# Attach the input-recording hook to every Conv2d, keyed by its module path.
for name, layer in model.named_modules():
    if not isinstance(layer, nn.Conv2d):
        continue
    layer.register_forward_hook(activation_hook(name))

# ====== Differentiable Min-Max Quantization Module ======
class MinMaxQuantization(nn.Module):
    """Soft uniform quantizer over a learnable clipping range [w_min, w_max].

    The range starts at the weight's min/max widened by 5% of its span on each
    side. forward() rescales the input to [0, 1], replaces every value by a
    softmax(-fixed_T * |x - level|)-weighted average of `num_levels` evenly
    spaced levels, and maps the result back, so the output is differentiable
    with respect to w_min and w_max.
    """

    def __init__(self, weight, num_levels=2**NUM_BITS, fixed_T=FIXED_T):
        super().__init__()
        self.num_levels = num_levels
        self.fixed_T = fixed_T

        lo = weight.min().detach()
        hi = weight.max().detach()
        margin = 0.05 * (hi - lo)
        self.w_min = nn.Parameter(lo - margin)
        self.w_max = nn.Parameter(hi + margin)

    def forward(self, w):
        """Return the soft-quantized (dequantized) version of `w`, same shape as `w`."""
        eps = 1e-6
        # Keep the range non-empty; the limits are Python floats, so the clamp
        # bound itself carries no gradient.
        lo = self.w_min.clamp(max=self.w_max.item() - eps)
        hi = self.w_max.clamp(min=lo.item() + eps)
        span = hi - lo

        unit = (w - lo) / (span + eps)
        grid = torch.linspace(0, 1, self.num_levels, device=w.device)
        logits = -torch.abs(unit.unsqueeze(-1) - grid) * self.fixed_T
        assign = torch.softmax(logits, dim=-1)
        return (assign * grid).sum(dim=-1) * span + lo

# ====== Per-Layer Optimization with Activation Usage ======
def optimize_per_layer(model, test_loader, num_iterations=NUM_ITERATIONS, lr=LR):
    """Fit a soft min/max quantizer to every conv weight of `model`, one layer at a time.

    Only the first batch of `test_loader` is used. For each parameter whose
    name contains "conv" and "weight", a fresh `MinMaxQuantization`
    (learnable w_min / w_max) is trained with Adam on

        0.1 * MSE(conv(x, w_q), conv(x, w)) + 0.9 * CE(model with w -> w_q, labels)

    where x is the layer input recorded by the forward hooks in
    `temp_activations`. A layer's optimisation stops early as soon as the
    classification loss increases. The soft-quantized weights of all layers are
    loaded into `model` at the end; during the loop the model keeps its
    original weights, so each layer is tuned with only itself quantized.
    Because the hooks also fire inside the substituted forward pass, later
    layers' recorded inputs reflect this layer's quantized weight.

    Args:
        model: network to quantize; modified in place.
        test_loader: iterable of (images, labels) batches.
        num_iterations: maximum Adam steps per layer.
        lr: Adam learning rate for w_min / w_max.
    """
    model.eval()
    updated_state_dict = model.state_dict()
    modules = dict(model.named_modules())
    criterion = nn.CrossEntropyLoss()
    mse_loss_fn = nn.MSELoss()

    print("\nStarting per-layer quantization optimization...")

    # Get one batch for activations and loss computation
    images, labels = next(iter(test_loader))
    images, labels = images.to(device), labels.to(device)

    # One pass both records the conv inputs (via the hooks) and gives the reference loss.
    with torch.no_grad():
        initial_loss = criterion(model(images), labels).item()
    print(f"Initial Classification Loss Before Optimization: {initial_loss:.6f}")

    # Only w_min / w_max are trained: stop autograd from also building gradients
    # for every model weight. The original flags are restored on exit.
    saved_flags = [(p, p.requires_grad) for p in model.parameters()]
    for p, _ in saved_flags:
        p.requires_grad_(False)

    try:
        for name, param in model.named_parameters():
            if "conv" not in name or "weight" not in name:
                continue
            print(f"\nOptimizing {name}...")
            layer_name = name.replace(".weight", "")

            if layer_name not in temp_activations:
                print(f"Skipping {layer_name}: No activation found.")
                continue

            conv = modules[layer_name]
            original_weight = param.detach().clone()
            quant_layer = MinMaxQuantization(original_weight).to(device)
            optimizer = optim.Adam(quant_layer.parameters(), lr=lr)
            activation_input = temp_activations[layer_name]  # Correct input for this layer

            def conv_forward(weight):
                # Use the layer's real geometry; the kernel shape is not its stride/padding.
                return nn.functional.conv2d(
                    activation_input, weight,
                    stride=conv.stride, padding=conv.padding,
                    dilation=conv.dilation, groups=conv.groups,
                )

            with torch.no_grad():
                original_output = conv_forward(original_weight)  # loop-invariant target

            prev_class_loss = float('inf')
            for iteration in range(num_iterations):
                optimizer.zero_grad()
                quantized_weight = quant_layer(original_weight)
                recon_loss = mse_loss_fn(conv_forward(quantized_weight), original_output)

                # Run the model with this layer's weight replaced by the quantized one,
                # so the classification loss depends on (and back-propagates to) w_min / w_max.
                logits = torch.func.functional_call(model, {name: quantized_weight}, (images,))
                class_loss = criterion(logits, labels)

                if class_loss.item() > prev_class_loss:
                    print(f"Early stop at iter {iteration}: class_loss increased.")
                    break
                prev_class_loss = class_loss.item()  # plain float: don't keep the graph alive

                total_loss = 0.1 * recon_loss + 0.9 * class_loss
                total_loss.backward()
                optimizer.step()

                if iteration % 10 == 0:
                    print(f"Iter {iteration}: recon_loss={recon_loss.item():.8f}, "
                          f"class_loss={class_loss.item():.4f}, "
                          f"w_min={quant_layer.w_min.item():.4f}, "
                          f"w_max={quant_layer.w_max.item():.4f}")

            with torch.no_grad():
                updated_state_dict[name] = quant_layer(original_weight).detach()
    finally:
        for p, flag in saved_flags:
            p.requires_grad_(flag)

    model.load_state_dict(updated_state_dict)
    print("\nPer-layer optimization complete.")

# ====== Run Optimization and Evaluation ======
# Accuracy checks are switched off in this run; only the optimisation executes.
#evaluate(model, test_loader)  # Before quantization
optimize_per_layer(model=model, test_loader=test_loader)
#evaluate(model, test_loader)  # After quantization


Accuracy BEFORE Quantization:

Starting per-layer quantization optimization...
Initial Classification Loss Before Optimization: 0.352729

Optimizing conv1.weight...
Iter 0: recon_loss=0.00000000, class_loss=0.3527, w_min=-0.1654, w_max=0.1546
Iter 10: recon_loss=0.00000000, class_loss=0.3527, w_min=-0.1656, w_max=0.1551
Iter 20: recon_loss=0.00000000, class_loss=0.3527, w_min=-0.1654, w_max=0.1554
Iter 30: recon_loss=0.00000000, class_loss=0.3527, w_min=-0.1653, w_max=0.1551
Iter 40: recon_loss=0.00000000, class_loss=0.3527, w_min=-0.1650, w_max=0.1549
Iter 50: recon_loss=0.00000000, class_loss=0.3527, w_min=-0.1650, w_max=0.1549
Iter 60: recon_loss=0.00000000, class_loss=0.3527, w_min=-0.1651, w_max=0.1549
Iter 70: recon_loss=0.00000000, class_loss=0.3527, w_min=-0.1651, w_max=0.1550
Iter 80: recon_loss=0.00000000, class_loss=0.3527, w_min=-0.1651, w_max=0.1550
Iter 90: recon_loss=0.00000000, class_loss=0.3527, w_min=-0.1650, w_max=0.1550

Optimizing layer1.0.conv1.weight...
Iter 0: 

KeyboardInterrupt: 

In [None]:
# empty_cache can only hand back blocks no live tensor uses, so first collect
# unreachable objects (e.g. reference cycles left by the interrupted run).
import gc
gc.collect()
torch.cuda.empty_cache()

In [None]:
class MinMaxQuantization(nn.Module):
    """Soft uniform quantizer with a learnable range plus an entropy term.

    The range [w_min, w_max] starts at the weight's min/max widened by 5% of
    its span, and forward() replaces each value by a
    softmax(-fixed_T * |x - level|)-weighted average of `num_levels` evenly
    spaced levels. forward() also returns the per-weight entropy (nats) of
    those soft assignments summed over all weights and, when `entropy_budget`
    is set, entropy / entropy_budget as a penalty (0 otherwise).
    NOTE(review): the caller computes entropy_budget in bits
    (32 * numel / CR_target), while this entropy is in nats and summed over
    weights — confirm the intended units before relying on the penalty scale.
    """

    def __init__(self, weight, num_levels=2**NUM_BITS, fixed_T=FIXED_T, entropy_budget=None):
        super().__init__()
        self.num_levels = num_levels
        self.fixed_T = fixed_T
        self.entropy_budget = entropy_budget

        lo = weight.min().detach()
        hi = weight.max().detach()
        margin = 0.05 * (hi - lo)
        self.w_min = nn.Parameter(lo - margin)
        self.w_max = nn.Parameter(hi + margin)

    def forward(self, w):
        """Return (dequantized_w, entropy, budget_penalty) for input `w`."""
        eps = 1e-6
        lo = self.w_min.clamp(max=self.w_max.item() - eps)
        hi = self.w_max.clamp(min=lo.item() + eps)
        span = hi - lo
        unit = (w - lo) / (span + eps)

        grid = torch.linspace(0, 1, self.num_levels, device=w.device)
        assign = torch.softmax(-torch.abs(unit.unsqueeze(-1) - grid) * self.fixed_T, dim=-1)
        w_hat = (assign * grid).sum(dim=-1) * span + lo

        # Differentiable entropy of the soft assignments, summed over every weight.
        shifted = assign + eps
        entropy = -(shifted * shifted.log()).sum()

        if self.entropy_budget is None:
            penalty = torch.tensor(0.0, device=w.device)
        else:
            penalty = entropy / self.entropy_budget
        return w_hat, entropy, penalty


In [None]:
def optimize_per_layer(model, test_loader, num_iterations=100, lr=1e-3, CR_target=10):
    """Soft-quantize every conv weight of `model` in place, one layer at a time.

    Uses the first two images of the first `test_loader` batch. For each
    parameter whose name contains "conv" and "weight", a `MinMaxQuantization`
    range is trained with Adam to minimise

        0.1 * recon + 0.9 * class + 0.001 * budget_penalty

    where recon is the MSE between the layer's outputs with quantized and
    original weights on its hooked input, and budget_penalty comes from the
    quantizer built with entropy_budget = 32 * numel / CR_target. The class
    term is the model's cross-entropy with its *current* weights: it does not
    depend on the quantizer, so it shifts the loss but adds no gradient.
    The final soft-quantized weight is copied into the parameter before moving
    on, so later layers are tuned on top of already-quantized ones.

    Side effect: sets requires_grad=False on every model parameter.
    """
    model.eval()
    print("\nStarting per-layer quantization optimization...")

    modules = dict(model.named_modules())
    mse_loss_fn = nn.MSELoss()
    criterion = nn.CrossEntropyLoss()

    # Use a smaller batch for activation capture + class loss
    images, labels = next(iter(test_loader))
    images, labels = images[:2].to(device), labels[:2].to(device)

    # Run one forward pass to trigger activation hooks
    with torch.no_grad():
        model(images)

    # Freeze all weights to save memory
    for p in model.parameters():
        p.requires_grad = False

    for name, param in model.named_parameters():
        if "conv" not in name or "weight" not in name:
            continue
        print(f"\n🔧 Optimizing {name}...")
        layer_name = name.replace(".weight", "")

        if temp_activations.get(layer_name) is None:
            print(f"⏭️ Skipping {layer_name}: No activation found.")
            continue

        conv = modules[layer_name]
        original_weight = param.detach().clone()
        entropy_budget = (32 * original_weight.numel()) / CR_target

        # Create quantization module
        quant_layer = MinMaxQuantization(original_weight, entropy_budget=entropy_budget).to(device)
        optimizer = torch.optim.Adam(quant_layer.parameters(), lr=lr)

        activation_input = temp_activations.pop(layer_name).detach().clone().to(device)

        def conv_forward(weight):
            # Use the layer's real geometry; the kernel shape is not its stride/padding.
            return nn.functional.conv2d(
                activation_input, weight,
                stride=conv.stride, padding=conv.padding,
                dilation=conv.dilation, groups=conv.groups,
            )

        with torch.no_grad():
            original_output = conv_forward(original_weight)
            # The model still holds this layer's original weight during the loop,
            # so the class loss is the same every iteration: evaluate it once.
            # (This forward also re-records the conv inputs through the hooks.)
            class_loss = criterion(model(images), labels)

        prev_entropy = None
        for iteration in range(num_iterations):
            optimizer.zero_grad()
            quantized_weight, entropy, budget_penalty = quant_layer(original_weight)
            recon_loss = mse_loss_fn(conv_forward(quantized_weight), original_output)

            total_loss = (
                0.1 * recon_loss +
                0.9 * class_loss +
                0.001 * budget_penalty  # soft constraint
            )
            total_loss.backward()
            optimizer.step()

            if iteration % 10 == 0:
                delta_entropy = (
                    entropy.item() - prev_entropy if prev_entropy is not None else 0.0
                )
                print(f"Iter {iteration:03d}: recon={recon_loss.item():.6f}, "
                      f"class={class_loss.item():.4f}, "
                      f"entropy={entropy.item():.2f}, "
                      f"Δentropy={delta_entropy:.2f}, "
                      f"budget={entropy_budget:.2f}, "
                      f"penalty={budget_penalty.item():.4f}, "
                      f"w_min={quant_layer.w_min.item():.4f}, "
                      f"w_max={quant_layer.w_max.item():.4f}")
                prev_entropy = entropy.item()

        # Overwrite quantized weights in-place
        with torch.no_grad():
            quantized_w, _, _ = quant_layer(original_weight)
            param.copy_(quantized_w.to(param.device))

        # Cleanup to prevent memory buildup. Loop temporaries are unbound when
        # num_iterations == 0, so they are rebound to None instead of deleted.
        del quant_layer, optimizer, activation_input, original_output, quantized_w
        quantized_weight = entropy = budget_penalty = recon_loss = total_loss = None
        torch.cuda.empty_cache()

    print("\n✅ Per-layer optimization complete.")


In [None]:
# Rebuild the CIFAR-10 ResNet-18 from the local `resnet` module (not timm) and
# load the fine-tuned checkpoint into it.
import resnet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = resnet.resnet18(pretrained=False, device=device)
model.to(device)

# weights_only=True: only tensors/plain containers are unpickled, never arbitrary objects.
state_dict = torch.load('/content/resnet18.pt', map_location=torch.device('cpu'), weights_only=True)
# strict=False silently tolerates key mismatches; report them so a partly
# random model is not mistaken for the pretrained one.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"⚠️ Checkpoint key mismatch: missing={load_result.missing_keys}, "
          f"unexpected={load_result.unexpected_keys}")
model.eval()

# This is a new model object; the hooks registered earlier live on the old one.
# Drop the stale activations and re-attach the hooks so optimize_per_layer
# records its layer inputs from *this* model.
temp_activations.clear()
for name, layer in model.named_modules():
    if isinstance(layer, nn.Conv2d):
        layer.register_forward_hook(activation_hook(name))

evaluate(model, test_loader)  # Before
optimize_per_layer(model, test_loader)
evaluate(model, test_loader);  # After


Test Accuracy: 86.53%

Starting per-layer quantization optimization...

🔧 Optimizing conv1.weight...
Iter 000: recon=0.000000, class=0.0103, entropy=4539.62, Δentropy=0.00, budget=5529.60, penalty=0.8210, w_min=-0.1649, w_max=0.1551
Iter 010: recon=0.000000, class=0.0103, entropy=4538.21, Δentropy=-1.41, budget=5529.60, penalty=0.8207, w_min=-0.1631, w_max=0.1570
Iter 020: recon=0.000000, class=0.0103, entropy=4538.19, Δentropy=-0.01, budget=5529.60, penalty=0.8207, w_min=-0.1630, w_max=0.1567
Iter 030: recon=0.000000, class=0.0103, entropy=4539.19, Δentropy=1.00, budget=5529.60, penalty=0.8209, w_min=-0.1630, w_max=0.1566
Iter 040: recon=0.000000, class=0.0103, entropy=4538.41, Δentropy=-0.78, budget=5529.60, penalty=0.8207, w_min=-0.1630, w_max=0.1566
Iter 050: recon=0.000000, class=0.0103, entropy=4538.26, Δentropy=-0.14, budget=5529.60, penalty=0.8207, w_min=-0.1629, w_max=0.1568
Iter 060: recon=0.000000, class=0.0103, entropy=4540.37, Δentropy=2.11, budget=5529.60, penalty=0.8211,

In [None]:
pip install constriction




In [None]:
import torch
import numpy as np
# `constriction.stream` is a compiled submodule: `from constriction.stream import ...`
# raised ModuleNotFoundError, so import the package and use attribute access.
import constriction
from collections import Counter

def compress_tensor_with_constriction(tensor):
    """Entropy-code a tensor's values with constriction's ANS coder and report the ratio.

    Every distinct value becomes one symbol. The symbols are coded with an
    i.i.d. categorical model fitted to their empirical frequencies
    (constriction's leaky quantization keeps every present symbol encodable).
    The model itself (the table of distinct values and their probabilities)
    is NOT counted in `compressed_bits`, so the ratio is optimistic, especially
    when almost every value is unique.

    Args:
        tensor: any torch tensor; it is detached, moved to CPU and flattened.

    Returns:
        (cr, original_bits, compressed_bits, unique_vals); cr is inf if the
        coder produced no output words.

    Raises:
        ValueError: if the tensor is empty.
    """
    # Step 1: Flatten and convert to numpy
    flat = tensor.detach().cpu().reshape(-1).numpy()
    if flat.size == 0:
        raise ValueError("cannot compress an empty tensor")

    # Step 2: Map values to integer symbols (indices into unique_vals)
    unique_vals, inverse = np.unique(flat, return_inverse=True)
    symbols = inverse.reshape(-1).astype(np.int32)

    # Step 3: Empirical symbol probabilities, in symbol-index order
    counts = np.bincount(symbols, minlength=len(unique_vals))
    probs = counts.astype(np.float64) / counts.sum()

    # Step 4: ANS-encode (stack semantics, hence encode_reverse)
    entropy_model = constriction.stream.model.Categorical(probs, perfect=False)
    encoder = constriction.stream.stack.AnsCoder()
    encoder.encode_reverse(symbols, entropy_model)

    compressed_bits = encoder.get_compressed().nbytes * 8
    original_bits = flat.nbytes * 8  # size of the input dtype (float32 for these weights)
    cr = original_bits / compressed_bits if compressed_bits else float("inf")

    return cr, original_bits, compressed_bits, unique_vals


ModuleNotFoundError: No module named 'constriction.stream'

In [None]:
# Accumulate coded vs. raw size over all conv weights.
total_orig_bits = 0
total_comp_bits = 0

for name, param in model.named_parameters():
    if "conv" not in name or "weight" not in name:
        continue
    ratio, orig_bits, comp_bits, _ = compress_tensor_with_constriction(param)
    total_orig_bits += orig_bits
    total_comp_bits += comp_bits
    print(f"{name}: CR = {ratio:.2f}×, Orig = {orig_bits}, Comp = {comp_bits}")

# Model-level ratio = total raw bits / total coded bits.
final_cr = total_orig_bits / total_comp_bits
print(f"\n🎯 Final Model Compression Ratio (with constriction): {final_cr:.2f}×")


In [None]:
pip install range-coder




In [None]:
from range_coder import RangeEncoder, RangeDecoder
import torch
import numpy as np
from collections import Counter
import io

def compress_tensor(tensor, resolution=10000):
    """Range-code a tensor's values with `range_coder` and report the compression ratio.

    Each distinct value becomes one symbol. Symbol counts are scaled to
    integer frequencies (floor(count * resolution / total), at least 1 so every
    present symbol stays encodable) and passed to the coder as a cumulative
    table in *symbol-index* order. `range_coder` writes to a file, so a
    temporary file is used and its size measured; the frequency table and the
    value dictionary are not counted.
    NOTE(review): the coder's limit on the total frequency is not checked here;
    with very many distinct values the total exceeds `resolution` — confirm the
    limit in the range_coder documentation.

    Args:
        tensor: any torch tensor; detached, moved to CPU and flattened.
        resolution: approximate total of the integer frequency table.

    Returns:
        (ratio, original_bits, compressed_bits); ratio is inf for empty output.
    """
    import os
    import tempfile

    flat = tensor.detach().cpu().reshape(-1).numpy()

    # Step 1: Create integer symbols (indices into unique_vals)
    unique_vals, inverse = np.unique(flat, return_inverse=True)
    symbols = inverse.reshape(-1)

    # Step 2: Integer frequencies in symbol order -> cumulative table
    counts = np.bincount(symbols, minlength=len(unique_vals))
    freqs = np.maximum(1, (counts * resolution) // max(counts.sum(), 1))
    cum_freq = [0] + np.cumsum(freqs).tolist()

    # Step 3: Range encode into a temporary file and measure it
    fd, path = tempfile.mkstemp(suffix=".rc")
    os.close(fd)
    try:
        encoder = RangeEncoder(path)
        try:
            encoder.encode(symbols.tolist(), cum_freq)
        finally:
            encoder.close()
        compressed_bytes = os.path.getsize(path)
    finally:
        os.remove(path)

    original_bits = flat.nbytes * 8  # size of the input dtype (float32 for these weights)
    compressed_bits = compressed_bytes * 8

    ratio = original_bits / compressed_bits if compressed_bits else float("inf")
    return ratio, original_bits, compressed_bits


In [None]:
# Accumulate range-coded vs. raw size over all conv weights.
total_original_bits = 0
total_compressed_bits = 0

for name, param in model.named_parameters():
    if "conv" not in name or "weight" not in name:
        continue
    ratio, orig_bits, comp_bits = compress_tensor(param)
    total_original_bits += orig_bits
    total_compressed_bits += comp_bits
    print(f"{name}: CR = {ratio:.2f}, Orig = {orig_bits}, Compressed = {comp_bits}")

# Model-level ratio = total raw bits / total coded bits.
overall_ratio = total_original_bits / total_compressed_bits
print(f"\n🌟 Final Compression Ratio (Range Encoded): {overall_ratio:.2f}×")


TypeError: function missing required argument 'filepath' (pos 1)

In [None]:
pip install zstandard




In [None]:
import torch
import zstandard as zstd
import numpy as np
import io

def compress_tensor_with_zstd(tensor, dtype=torch.float16, compression_level=3):
    """Serialize `tensor` as a .npy blob in `dtype` and compress it with Zstandard.

    The baseline is the size of that serialized blob (header included) in
    the reduced dtype, not the tensor's original storage size.

    Returns:
        (compression_ratio, original_bits, compressed_bits, compressed_bytes)
    """
    host_tensor = tensor.detach().cpu().to(dtype)
    with io.BytesIO() as buffer:
        np.save(buffer, host_tensor.numpy(), allow_pickle=False)
        raw_bytes = buffer.getvalue()

    compressed = zstd.ZstdCompressor(level=compression_level).compress(raw_bytes)

    original_bits = 8 * len(raw_bytes)
    compressed_bits = 8 * len(compressed)
    return original_bits / compressed_bits, original_bits, compressed_bits, compressed


In [None]:
# Accumulate fp16+zstd size vs. fp16 .npy size over all conv weights.
total_orig_bits = 0
total_comp_bits = 0

for name, param in model.named_parameters():
    if "conv" not in name or "weight" not in name:
        continue
    cr, ob, cb, _ = compress_tensor_with_zstd(param, dtype=torch.float16)
    total_orig_bits += ob
    total_comp_bits += cb
    print(f"{name}: Zstd CR = {cr:.2f}×, Orig = {ob}, Comp = {cb}")

# Model-level ratio = total serialized bits / total compressed bits.
final_cr = total_orig_bits / total_comp_bits
print(f"\n📦 Final Model Compression Ratio with Zstd: {final_cr:.2f}×")


conv1.weight: Zstd CR = 1.05×, Orig = 28672, Comp = 27192
layer1.0.conv1.weight: Zstd CR = 1.07×, Orig = 590848, Comp = 551912
layer1.0.conv2.weight: Zstd CR = 1.09×, Orig = 590848, Comp = 544280
layer1.1.conv1.weight: Zstd CR = 1.08×, Orig = 590848, Comp = 545144
layer1.1.conv2.weight: Zstd CR = 1.08×, Orig = 590848, Comp = 544992
layer2.0.conv1.weight: Zstd CR = 1.09×, Orig = 1180672, Comp = 1083992
layer2.0.conv2.weight: Zstd CR = 1.09×, Orig = 2360320, Comp = 2166728
layer2.1.conv1.weight: Zstd CR = 1.09×, Orig = 2360320, Comp = 2169616
layer2.1.conv2.weight: Zstd CR = 1.09×, Orig = 2360320, Comp = 2171408
layer3.0.conv1.weight: Zstd CR = 1.09×, Orig = 4719616, Comp = 4335304
layer3.0.conv2.weight: Zstd CR = 1.09×, Orig = 9438208, Comp = 8687568
layer3.1.conv1.weight: Zstd CR = 1.08×, Orig = 9438208, Comp = 8711256
layer3.1.conv2.weight: Zstd CR = 1.08×, Orig = 9438208, Comp = 8704256
layer4.0.conv1.weight: Zstd CR = 1.09×, Orig = 18875392, Comp = 17341664
layer4.0.conv2.weight: Zs

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.utils.data as data
import timm
import numpy as np

# ---- Hyperparameters ----
BATCH_SIZE = 128
NUM_BITS = 4            # bits per weight -> 2**NUM_BITS quantization levels
FIXED_T = 100.5         # softmax temperature of the soft rounding
LR = 0.001              # Adam learning rate for the quantization range
NUM_ITERATIONS = 100    # optimisation steps per layer
CR_target = 10  # Global compression ratio target (relative to 32-bit weights)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- CIFAR-10, normalised per channel with mean 0.5 / std 0.5 ----
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)
train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)
test_loader = data.DataLoader(test_dataset, shuffle=False, batch_size=BATCH_SIZE)

# Model
# CIFAR-10 ResNet-18 from the local `resnet` module (not timm), loaded from a
# local checkpoint. `device` is already defined above in this cell.
import resnet

model = resnet.resnet18(pretrained=False, device=device)
model.to(device)

# weights_only=True: only tensors/plain containers are unpickled, never arbitrary objects.
state_dict = torch.load('/content/resnet18.pt', map_location=torch.device('cpu'), weights_only=True)
# strict=False silently tolerates key mismatches; report them so a partly
# random model is not mistaken for the pretrained one.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"⚠️ Checkpoint key mismatch: missing={load_result.missing_keys}, "
          f"unexpected={load_result.unexpected_keys}")
model.eval()

# Evaluation function
def evaluate(model, test_loader):
    """Print (and return) the top-1 accuracy of `model` on `test_loader`.

    Runs on the device of the model's first parameter (CPU for a
    parameterless model), in eval mode and without gradients.

    Args:
        model: classifier; assumed to return (batch, num_classes) scores.
        test_loader: iterable of (images, labels) batches.

    Returns:
        Accuracy in percent, or None when the loader yields no samples.
    """
    model.eval()
    # Follow the model instead of relying on a notebook-global `device`.
    first_param = next(model.parameters(), None)
    eval_device = first_param.device if first_param is not None else torch.device("cpu")
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(eval_device), labels.to(eval_device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        # Avoid ZeroDivisionError on an empty loader.
        print("🎯 Accuracy: n/a (empty loader)")
        return None
    accuracy = 100 * correct / total
    print(f"🎯 Accuracy: {accuracy:.2f}%")
    return accuracy

# Activation storage: conv module name -> private copy of its most recent input.
temp_activations = {}


def activation_hook(layer_name):
    """Return a forward hook that stores a detached clone of the module's first input.

    The clone has its own storage, so later in-place changes to the input do
    not alter what was recorded in `temp_activations[layer_name]`.
    """
    def _save_input(module, inputs, output):
        recorded = inputs[0].detach()
        temp_activations[layer_name] = recorded.clone()

    return _save_input

# Attach the input-recording hook to every Conv2d of the freshly built model.
for name, layer in model.named_modules():
    if not isinstance(layer, nn.Conv2d):
        continue
    layer.register_forward_hook(activation_hook(name))

# Filled by optimize_per_layer: parameter name -> soft level probabilities (on CPU).
layer_soft_probs = {}
# Quantization Module
class MinMaxQuantization(nn.Module):
    """Differentiable uniform quantizer with a learnable clipping range [w_min, w_max].

    Weights are mapped to [0, 1], softly assigned to `num_levels` evenly spaced
    levels with softmax(-fixed_T * |x - level|), and mapped back. The range
    starts at the weight's min/max widened by 5% of its span on each side.

    Args:
        weight: tensor used only to initialise the range.
        num_levels: number of quantization levels.
        fixed_T: softmax temperature; larger values approach hard rounding.
        entropy_budget: optional scale for `budget_penalty`; when None the
            penalty is 0.
    """

    def __init__(self, weight, num_levels=2**NUM_BITS, fixed_T=FIXED_T, entropy_budget=None):
        super().__init__()
        self.num_levels = num_levels
        self.fixed_T = fixed_T
        self.entropy_budget = entropy_budget

        w_min_init = weight.min().detach()
        w_max_init = weight.max().detach()
        pad = 0.05 * (w_max_init - w_min_init)
        self.w_min = nn.Parameter(w_min_init - pad)
        self.w_max = nn.Parameter(w_max_init + pad)

    def forward(self, w):
        """Soft-quantize `w`.

        Returns:
            w_deq: soft-quantized tensor, same shape as `w`.
            entropy: entropy (nats) of the level-usage histogram, i.e. of the
                soft probability mass of all weights pooled per level.
            budget_penalty: (entropy / entropy_budget) ** 2, or 0 without a budget.
            soft_probs: per-weight level probabilities, shape w.shape + (num_levels,).
        """
        EPS = 1e-6
        w_min = self.w_min.clamp(max=self.w_max.item() - EPS)
        w_max = self.w_max.clamp(min=w_min.item() + EPS)
        w_norm = (w - w_min) / (w_max - w_min + EPS)

        q_levels = torch.linspace(0, 1, self.num_levels, device=w.device)
        dists = -torch.abs(w_norm.unsqueeze(-1) - q_levels)
        soft_probs = torch.softmax(dists * self.fixed_T, dim=-1)
        w_q = (soft_probs * q_levels).sum(dim=-1)
        w_deq = w_q * (w_max - w_min) + w_min

        # Entropy across quantization bins: pool every weight's soft assignment
        # per level. Summing only dim 0 would keep the in-channel/kernel axes and
        # measure entropy over those cells instead of over the bins.
        bin_mass = soft_probs.reshape(-1, self.num_levels).sum(dim=0)  # [num_bins]
        bin_probs = bin_mass / bin_mass.sum()
        entropy = -(bin_probs * (bin_probs + EPS).log()).sum()

        if self.entropy_budget is None:
            # Default budget: no penalty (previously `None + EPS` raised TypeError).
            budget_penalty = torch.tensor(0.0, device=w.device)
        else:
            budget_penalty = (entropy / (self.entropy_budget + EPS)) ** 2
        return w_deq, entropy, budget_penalty, soft_probs

# Optimization
def optimize_per_layer(model, test_loader, max_class_loss=0.2):
    """Soft-quantize every conv weight of `model` in place, one layer at a time.

    Works on the first two images of the first `test_loader` batch and reads
    NUM_ITERATIONS, LR and CR_target from the notebook config. For each
    parameter whose name contains "conv" and "weight":

    * a `MinMaxQuantization` range is trained with Adam on
      0.1 * recon + 0.9 * class + 0.1 * entropy. Here recon is the MSE between
      the layer's outputs with soft-quantized vs. original weights on its
      hooked input. class is the model's cross-entropy with this layer's weight
      temporarily swapped for the quantized one. The class term is computed
      without gradient, so it only acts as a stopping criterion: the layer
      stops as soon as it exceeds `max_class_loss`;
    * the final soft-quantized weight is written into the parameter, and the
      matching soft level probabilities are stored in `layer_soft_probs[name]`
      (on CPU);
    * the conv inputs are re-recorded so later layers are tuned on the
      already-quantized upstream network.

    Side effect: sets requires_grad=False on every model parameter.

    Args:
        model: network to quantize; modified in place.
        test_loader: iterable of (images, labels) batches.
        max_class_loss: early-stop threshold on the classification loss
            (default 0.2, the previously hard-coded value).
    """
    model.eval()
    for p in model.parameters():
        p.requires_grad = False

    modules = dict(model.named_modules())
    mse_loss_fn = nn.MSELoss()
    criterion = nn.CrossEntropyLoss()

    def capture_activations():
        """Run the first two images of the first batch through the model (fires the hooks)."""
        images, labels = next(iter(test_loader))
        images, labels = images[:2].to(device), labels[:2].to(device)
        with torch.no_grad():
            model(images)
        return images, labels

    images, labels = capture_activations()

    for name, param in model.named_parameters():
        if "conv" not in name or "weight" not in name:
            continue
        print(f"\n🔧 Optimizing {name}...")
        layer_name = name.replace(".weight", "")
        if layer_name not in temp_activations:
            print(f"⏭️ Skipping {layer_name} (no activation).")
            continue

        conv = modules[layer_name]
        original_weight = param.detach().clone()
        entropy_budget = (32 * original_weight.numel()) / CR_target

        quant_layer = MinMaxQuantization(original_weight, entropy_budget=entropy_budget).to(device)
        optimizer = optim.Adam(quant_layer.parameters(), lr=LR)
        activation_input = temp_activations.pop(layer_name).detach().clone().to(device)

        def conv_forward(weight):
            # Use the layer's real geometry; the kernel shape is not its stride/padding.
            return nn.functional.conv2d(
                activation_input, weight,
                stride=conv.stride, padding=conv.padding,
                dilation=conv.dilation, groups=conv.groups,
            )

        with torch.no_grad():
            o_out = conv_forward(original_weight)  # loop-invariant reconstruction target

        original_param_data = param.data.clone()
        prev_entropy = None
        for iteration in range(NUM_ITERATIONS):
            optimizer.zero_grad()
            q_weight, entropy, penalty, _ = quant_layer(original_weight)
            recon_loss = mse_loss_fn(conv_forward(q_weight), o_out)

            # Evaluate the full model with this layer quantized, then always restore it.
            with torch.no_grad():
                param.data = q_weight.detach()
                try:
                    class_loss = criterion(model(images), labels)
                finally:
                    param.data = original_param_data

            if class_loss > max_class_loss:
                break

            total_loss = 0.1 * recon_loss + 0.9 * class_loss + 0.1 * entropy
            total_loss.backward()
            optimizer.step()

            if iteration % 10 == 0:
                delta_entropy = entropy.item() - prev_entropy if prev_entropy is not None else 0.0
                print(f"Iter {iteration:03d}: recon={recon_loss.item():.6f}, "
                      f"class={class_loss.item():.4f}, entropy={entropy.item():.2f}, "
                      f"Δentropy={delta_entropy:.2f}, total_loss={total_loss:.4f}, "
                      f"penalty={penalty.item():.4f}, "
                      f"w_min={quant_layer.w_min.item():.4f}, w_max={quant_layer.w_max.item():.4f}")
                prev_entropy = entropy.item()

        with torch.no_grad():
            final_weight, _, _, final_probs = quant_layer(original_weight)
            # Store the probabilities that produced the weight written below; the
            # in-loop ones predate the last optimizer step.
            layer_soft_probs[name] = final_probs.cpu()
            param.copy_(final_weight.to(param.device))

        del quant_layer, optimizer, activation_input, o_out, final_weight, final_probs
        # Loop temporaries are unbound when NUM_ITERATIONS == 0, so rebind instead of del.
        q_weight = entropy = penalty = recon_loss = total_loss = None
        torch.cuda.empty_cache()

        # Refresh activations after every layer so the next one sees quantized inputs.
        images, labels = capture_activations()
        print(f"🔁 Refreshed activations after {name}")

    print("\n✅ Per-layer optimization complete.")


In [None]:
# Accuracy before quantization -> per-layer quantization -> accuracy after.
evaluate(model=model, test_loader=test_loader)
optimize_per_layer(model=model, test_loader=test_loader)
evaluate(model=model, test_loader=test_loader);


🎯 Accuracy: 86.53%

🔧 Optimizing conv1.weight...
Iter 000: recon=0.000457, class=0.0098, entropy=5.08, Δentropy=0.00, total_loss=0.5169, penalty=0.0000, w_min=-0.1669, w_max=0.1531
Iter 010: recon=0.000352, class=0.0099, entropy=5.03, Δentropy=-0.05, total_loss=0.5120, penalty=0.0000, w_min=-0.1766, w_max=0.1618
Iter 020: recon=0.000508, class=0.0098, entropy=4.99, Δentropy=-0.04, total_loss=0.5074, penalty=0.0000, w_min=-0.1860, w_max=0.1718
Iter 030: recon=0.000556, class=0.0098, entropy=4.95, Δentropy=-0.04, total_loss=0.5034, penalty=0.0000, w_min=-0.1954, w_max=0.1803
Iter 040: recon=0.000454, class=0.0099, entropy=4.91, Δentropy=-0.04, total_loss=0.4997, penalty=0.0000, w_min=-0.2051, w_max=0.1870
Iter 050: recon=0.000334, class=0.0100, entropy=4.87, Δentropy=-0.03, total_loss=0.4964, penalty=0.0000, w_min=-0.2144, w_max=0.1934
Iter 060: recon=0.000285, class=0.0102, entropy=4.84, Δentropy=-0.03, total_loss=0.4933, penalty=0.0000, w_min=-0.2234, w_max=0.2002
Iter 070: recon=0.000

In [None]:
def quantize_and_compress_with_zstd(weight_tensor, soft_probs, num_bits=4, zstd_level=5, reference_bits=None):
    """Hard-assign each weight to its most likely level, bit-pack the indices, and zstd them.

    Args:
        weight_tensor: unused; kept for call compatibility (indices come from soft_probs).
        soft_probs: per-weight level probabilities; last dim must equal 2**num_bits.
        num_bits: bits per level index, 1..8.
        zstd_level: Zstandard compression level.
        reference_bits: bits per weight of the baseline the ratio is measured
            against. Defaults to `num_bits`, i.e. the gain of zstd over the
            packed indices alone; pass 32 for the ratio against float32 weights.

    Returns:
        (cr, compressed) where compressed is the zstd frame as bytes.

    Raises:
        ValueError: if num_bits is outside 1..8 (indices are stored as uint8).
    """
    if not 1 <= num_bits <= 8:
        raise ValueError(f"num_bits must be in 1..8, got {num_bits}")
    assert soft_probs.shape[-1] == 2 ** num_bits, "Soft probs must match number of bins"
    quantized = soft_probs.argmax(dim=-1).cpu().numpy().astype(np.uint8)
    flat = quantized.reshape(-1)

    # MSB-first bit-packing of the low `num_bits` bits of each index. For
    # num_bits=4 this yields exactly (first << 4) | second per byte, with a zero
    # nibble pad for an odd count; other widths are no longer corrupted.
    bits = np.unpackbits(flat[:, None], axis=1)[:, 8 - num_bits:]
    packed_bytes = np.packbits(bits.reshape(-1)).tobytes()

    compressor = zstd.ZstdCompressor(level=zstd_level)
    compressed = compressor.compress(packed_bytes)

    if reference_bits is None:
        reference_bits = num_bits
    original_bits = quantized.size * reference_bits
    compressed_bits = len(compressed) * 8
    cr = original_bits / compressed_bits

    return cr, compressed
# Report per-layer ratios for every layer that optimize_per_layer recorded.
print("\n📦 Final Compression Ratios (4-bit + Zstd):")
for name, param in model.named_parameters():
    if name not in layer_soft_probs:
        continue
    soft_probs = layer_soft_probs[name]
    cr, compressed = quantize_and_compress_with_zstd(param.data, soft_probs)
    print(f"{name}: CR = {cr:.2f}×, Size = {len(compressed)} bytes")





📦 Final Compression Ratios (4-bit + Zstd):
conv1.weight: CR = 1.71×, Size = 505 bytes
layer1.0.conv1.weight: CR = 3.97×, Size = 4641 bytes
layer1.0.conv2.weight: CR = 3.08×, Size = 5994 bytes
layer1.1.conv1.weight: CR = 3.37×, Size = 5467 bytes
layer1.1.conv2.weight: CR = 3.99×, Size = 4616 bytes
layer2.0.conv1.weight: CR = 3.17×, Size = 11611 bytes
layer2.0.conv2.weight: CR = 3.31×, Size = 22268 bytes
layer2.1.conv1.weight: CR = 3.29×, Size = 22394 bytes
layer2.1.conv2.weight: CR = 4.41×, Size = 16732 bytes
layer3.0.conv1.weight: CR = 6.80×, Size = 21688 bytes
layer3.0.conv2.weight: CR = 9.07×, Size = 32507 bytes
layer3.1.conv1.weight: CR = 22.01×, Size = 13401 bytes
layer3.1.conv2.weight: CR = 11.64×, Size = 25344 bytes
layer4.0.conv1.weight: CR = 2.67×, Size = 220622 bytes
layer4.0.conv2.weight: CR = 7.90×, Size = 149364 bytes
layer4.1.conv1.weight: CR = 7.64×, Size = 154500 bytes
layer4.1.conv2.weight: CR = 8.43×, Size = 139983 bytes
