In [19]:
import os
import sys
import time
import torch
import numpy as np
import json
import torch.nn as nn

from utils.timingtools import now
from exp_vgg import validate_vgg19_cifar10

In [20]:
class ScalarQuantizer:
    """Uniform (linear) scalar quantizer that returns de-quantized tensors.

    Values are snapped onto an evenly spaced grid of at most ``2**n_bits``
    levels and immediately mapped back to floating point ("fake"
    quantization). Outputs therefore keep the input's shape and dtype; the
    accuracy impact of low-bit values is simulated, but the tensor itself
    does not get smaller.
    """

    @staticmethod
    def _quantize_rows(rows, n_bits, symmetric):
        """Quantize every row of the 2-D tensor ``rows`` with its own scale.

        Symmetric mode maps ``[-max|x|, max|x|]`` onto the signed integer grid
        ``[-2**(n_bits-1), 2**(n_bits-1) - 1]``. Asymmetric mode maps
        ``[min(x), max(x)]`` onto ``[0, 2**n_bits - 1]``. A row whose range is
        zero uses a scale of 1.0 so no division by zero occurs.

        Raises:
            ValueError: if ``n_bits`` leaves no usable grid (``< 2`` for
                symmetric, ``< 1`` for asymmetric).
        """
        if symmetric:
            if n_bits < 2:
                # For n_bits == 1 the positive range 2**0 - 1 is 0: the scale
                # would be infinite and the result NaN.
                raise ValueError(
                    f"symmetric quantization needs n_bits >= 2, got {n_bits}")
            q_max = 2 ** (n_bits - 1) - 1
            max_val = rows.abs().amax(dim=1, keepdim=True)
            scale = torch.where(max_val > 0, max_val / q_max,
                                torch.ones_like(max_val))
            quantized = torch.clamp(torch.round(rows / scale), -(q_max + 1), q_max)
            return quantized * scale

        if n_bits < 1:
            raise ValueError(
                f"asymmetric quantization needs n_bits >= 1, got {n_bits}")
        levels = 2 ** n_bits - 1
        min_val = rows.amin(dim=1, keepdim=True)
        value_range = rows.amax(dim=1, keepdim=True) - min_val
        scale = torch.where(value_range > 0, value_range / levels,
                            torch.ones_like(value_range))
        quantized = torch.clamp(torch.round((rows - min_val) / scale), 0, levels)
        return quantized * scale + min_val

    def quantize_tensor(self, tensor, n_bits, symmetric=True):
        """Quantize a whole tensor with a single scale (per-tensor quantization).

        Args:
            tensor: non-empty tensor of any shape.
            n_bits: bit width of the integer grid.
            symmetric: zero-centred signed grid if True, otherwise an affine
                grid spanning ``[min, max]``.

        Returns:
            De-quantized tensor with the same shape and dtype as ``tensor``.

        Raises:
            ValueError: if ``n_bits`` is too small for the chosen mode.
        """
        flat = tensor.reshape(1, -1)
        return self._quantize_rows(flat, n_bits, symmetric).reshape(tensor.shape)

    def quantize_kernels(self, kernels, n_bits=4, symmetric=True):
        """Quantize each ``kernels[i]`` independently (one scale per kernel).

        All kernels are processed in one vectorised pass rather than a Python
        loop; the per-kernel results are identical to calling
        ``quantize_tensor`` on each ``kernels[i]``.

        Args:
            kernels: tensor whose first dimension indexes kernels; the remaining
                dimensions of each kernel are quantized together.
            n_bits: bit width of the integer grid (default 4).
            symmetric: quantization mode (default True, as before).

        Returns:
            ``(quantized_kernels, mse)``:
              - ``quantized_kernels``: de-quantized tensor with the same shape
                and dtype as ``kernels`` and no autograd history.
              - ``mse``: mean squared error against ``kernels``, as a Python
                float (NaN when there are no kernels).
        """
        n = kernels.shape[0]
        with torch.no_grad():
            if n == 0:
                # Mean over an empty tensor is NaN, matching the previous behaviour.
                return torch.zeros_like(kernels), float('nan')
            rows = kernels.reshape(n, -1)
            quantized_kernels = self._quantize_rows(
                rows, n_bits, symmetric).reshape(kernels.shape)
            mse = torch.mean((kernels - quantized_kernels) ** 2).item()
        return quantized_kernels, mse

In [21]:
def compress_torch_weights(model_path, compressed_model_path, n_bits=4,
                           results_dir='./results'):
    """Fake-quantize model weight kernels, save them, validate, and log the result.

    Steps:
      1. Load the checkpoint at ``model_path``. It may be a raw state dict or
         a dict with a ``'state_dict'`` entry.
      2. Gather kernels with ``weights.concat_weights``.
      3. Quantize each kernel with its own scale via
         ``ScalarQuantizer.quantize_kernels``.
      4. Write the de-quantized kernels back with ``weights.reassign_weights``.
      5. Save ``{'state_dict': ...}`` to ``compressed_model_path``.
      6. Evaluate the saved file with ``validate_vgg19_cifar10``.
      7. Dump a JSON summary to ``<results_dir>/result_<n_bits>bits_scalar.json``.

    The saved weights are de-quantized floats, so the checkpoint on disk is not
    smaller than the original; only the set of weight values is restricted.

    Args:
        model_path: path of the source checkpoint.
        compressed_model_path: output checkpoint path. Its parent directory is
            created if it does not exist.
        n_bits: quantization bit width (default 4).
        results_dir: directory for the JSON result file (default ``'./results'``).

    Returns:
        ``(top1_avg, mse)``: the value returned by ``validate_vgg19_cifar10``
        and the quantization mean squared error.
    """
    # Project helpers for flattening weights into kernels and writing them back.
    from weights import concat_weights, reassign_weights

    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    # Named `state_dict` (not `weights`) so it does not shadow the `weights` module.
    state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint

    kernels = concat_weights(state_dict)
    print(f"Total kernels: {kernels.shape[0]}, Quantization bits: {n_bits}")

    quantized_kernels, mse = ScalarQuantizer().quantize_kernels(kernels, n_bits)
    # Scientific notation: MSE is typically ~1e-5, where .6f keeps one digit.
    print(f"Quantization MSE: {mse:.6e}")

    compressed_dict = reassign_weights(quantized_kernels, state_dict)

    parent_dir = os.path.dirname(compressed_model_path)
    if parent_dir:  # os.makedirs('') raises for a bare file name
        os.makedirs(parent_dir, exist_ok=True)
    torch.save(dict(state_dict=compressed_dict), compressed_model_path)

    top1_avg = validate_vgg19_cifar10(compressed_model_path)

    result_dict = dict(
        timestamp=now(),
        # float() so a tensor / numpy scalar from validation stays JSON-serialisable.
        precision=float(top1_avg),
        quantization_bits=n_bits,
        quantization_mse=mse,
        model_name='vgg19',
        task='classification',
        dataset='cifar10',
        compression_type_base='scalar_quantization',
        compressed_model_name=compressed_model_path
    )

    os.makedirs(results_dir, exist_ok=True)
    json_file_path = os.path.join(results_dir, f'result_{n_bits}bits_scalar.json')
    with open(json_file_path, 'w') as f:
        json.dump(result_dict, f, indent=4)

    return top1_avg, mse

In [22]:
def simulate_activation_quantization(model, n_bits=4):
    """Attach forward hooks that fake-quantize Conv2d / Linear outputs.

    Every ``nn.Conv2d`` and ``nn.Linear`` submodule of ``model`` gets a hook.
    The hook replaces a tensor output with
    ``ScalarQuantizer().quantize_tensor(output, n_bits)``. That is symmetric,
    with one scale for the whole output tensor (shared across the batch).
    Non-tensor outputs are passed through unchanged.

    ``model`` is modified in place (hooks are registered on it) and is also
    returned.

    Returns:
        ``(model, hooks)``, where ``hooks`` holds the handles returned by
        ``register_forward_hook``. Call ``.remove()`` on each to undo the
        simulation.
    """
    act_quantizer = ScalarQuantizer()
    target_types = (nn.Conv2d, nn.Linear)

    def _fake_quantize_output(module, inputs, out):
        if not isinstance(out, torch.Tensor):
            return out
        return act_quantizer.quantize_tensor(out, n_bits)

    handles = [
        layer.register_forward_hook(_fake_quantize_output)
        for layer in model.modules()
        if isinstance(layer, target_types)
    ]
    return model, handles

In [23]:
def main(model_path='vgg19.pth', n_bits=3):
    """Run scalar weight quantization on a VGG-19 checkpoint and print a summary.

    Args:
        model_path: source checkpoint (default ``'vgg19.pth'``).
        n_bits: quantization bit width (default 3, the value previously
            hard-coded in the notebook).
    """
    print("Starting main quantization process...")
    print(f"Using {n_bits}-bit quantization")

    # NOTE(review): the file name says "per_input_channel", but quantize_kernels
    # uses one scale per row of concat_weights' output. Confirm the naming matches.
    compressed_model_path = f'./compressed_models/vgg19_per_input_channel_{n_bits}bit.pth'
    banner = "=" * 50

    # perf_counter is monotonic, unlike time.time(), so it suits elapsed-time measurement.
    start = time.perf_counter()

    print("\n" + banner)
    print("Starting weight compression")
    print(banner)

    acc, mse = compress_torch_weights(
        model_path, compressed_model_path,
        n_bits=n_bits
    )

    elapsed = time.perf_counter() - start

    print("\n" + banner)
    print("Quantization Summary")
    print(banner)
    print(f'Processing time: {elapsed:.2f} seconds')
    print(f'Quantization bits: {n_bits}')
    print(f'Top-1 Accuracy: {acc:.4f}')
    # Scientific notation keeps significant digits for tiny MSE values.
    print(f'Quantization MSE: {mse:.6e}')

print("Main function defined!")

Main function defined!


In [24]:
# Inside a Jupyter kernel __name__ is "__main__", so running this cell calls main().
if __name__ == "__main__":
    main()

Starting main quantization process...
Using 3-bit quantization

Starting weight compression
Total kernels: 2224320, Quantization bits: 3
Quantization MSE: 0.000009
reassign features.0.weight torch.Size([64, 3, 3, 3])
reassign features.2.weight torch.Size([64, 64, 3, 3])
reassign features.5.weight torch.Size([128, 64, 3, 3])
reassign features.7.weight torch.Size([128, 128, 3, 3])
reassign features.10.weight torch.Size([256, 128, 3, 3])
reassign features.12.weight torch.Size([256, 256, 3, 3])
reassign features.14.weight torch.Size([256, 256, 3, 3])
reassign features.16.weight torch.Size([256, 256, 3, 3])
reassign features.19.weight torch.Size([512, 256, 3, 3])
reassign features.21.weight torch.Size([512, 512, 3, 3])
reassign features.23.weight torch.Size([512, 512, 3, 3])
reassign features.25.weight torch.Size([512, 512, 3, 3])
reassign features.28.weight torch.Size([512, 512, 3, 3])
reassign features.30.weight torch.Size([512, 512, 3, 3])
reassign features.32.weight torch.Size([512, 512



Test: [0/79]	Time 13.280 (13.280)	Loss 0.5412 (0.5412)	Prec@1 89.844 (89.844)
Test: [10/79]	Time 0.435 (1.577)	Loss 0.7209 (0.6308)	Prec@1 89.844 (89.844)
Test: [20/79]	Time 0.451 (1.042)	Loss 0.6206 (0.7004)	Prec@1 89.062 (88.876)
Test: [30/79]	Time 0.466 (0.854)	Loss 0.8140 (0.7192)	Prec@1 89.844 (89.138)
Test: [40/79]	Time 0.477 (0.759)	Loss 0.7260 (0.6935)	Prec@1 85.156 (89.310)
Test: [50/79]	Time 0.436 (0.704)	Loss 0.4572 (0.6753)	Prec@1 89.062 (89.369)
Test: [60/79]	Time 0.476 (0.668)	Loss 0.9168 (0.6834)	Prec@1 85.938 (89.344)
Test: [70/79]	Time 0.519 (0.646)	Loss 0.5485 (0.6865)	Prec@1 92.188 (89.283)
 * Prec@1 89.340

Quantization Summary
Processing time: 152.01 seconds
Quantization bits: 3
Top-1 Accuracy: 89.3400
Quantization MSE: 0.000009
