## Static Quantization with focus on placement of Stubs

This notebook contains the final code for static quantization. Dynamic quantization was also tried, but we settled on static quantization as the working methodology.


In [None]:
from dquartic.model.unet1d import UNet1d
from dquartic.utils.data_loader import DIAMSDataset
import torch
import torch.quantization
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.ao.quantization as tq

In [None]:
import gc

# Run a full garbage-collection pass to release unreachable Python objects.
# This does NOT restart the kernel. Run this cell if you hit OOM errors or
# GPU usage stays high even after restarting the kernel.
gc.collect()

14076

In [None]:
# Input files. NPY files are used here, but the process is the same for
# parquet files.
ms1_file, ms2_file = (
    "npy/ms1_data_int32.npy",
    "npy/ms2_data_cat_int32.npy",
)

In [None]:
# Build the dataset from the two files, with min-max normalisation requested.
dataset = DIAMSDataset(
    ms1_file=ms1_file,
    ms2_file=ms2_file,
    normalize="minmax",
)

# One sample per batch, shuffled, loaded by a single worker process.
data_loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=1,
)

# Pull one batch to have example tensors at hand.
data_iter = iter(data_loader)
first_batch = next(data_iter)
ms2_1, ms1_1, ms2_2, ms1_2 = first_batch

Info: Loaded 520 MS2 slice samples and 520 MS1 slice samples from NPY files.


Below is the model configuration used for the original model checkpoint.

In [None]:
# Keyword arguments used to construct the UNet1d (passed as **model_init).
model_init = dict(
    dim=4,
    channels=1,
    dim_mults=[1, 2, 2, 2, 4, 4, 4],
    conditional=True,
    init_cond_channels=1,
    attn_cond_channels=1,
    tfer_dim_mult=620,
    downsample_dim=40000,
    simple=True,
)

After in-depth experimentation, we found that eager-mode static quantization in PyTorch supports only specific layer types (Conv1d, etc.), so the exact placement of the stubs does not matter much; here we simply wrap the UNet's forward pass with them. As a result, only the supported layers are effectively quantized and dequantized.

In [None]:

# Wrapper for quantization (QuantStub / DeQuantStub).
# Only supported layers get quantized, so the exact stub placement does not matter.
class QuantizedUNet1dWrapper(nn.Module):
    """
    Eager-mode static-quantization wrapper around a UNet1d.

    The wrapped network is sandwiched between two stubs:
      - a QuantStub applied to the input ``x``
      - a DeQuantStub applied to the network output
    PyTorch's built-in prepare() and convert() take care of everything else.

    Args:
        unet: the network to wrap; it is called as
            ``unet(x, time, init_cond, attn_cond)``.
    """
    def __init__(self, unet: nn.Module):
        super().__init__()
        # Registration order (quant, model, dequant) is kept on purpose: it
        # determines the submodule order seen by repr() and state_dict().
        self.quant = tq.QuantStub()
        self.model = unet
        self.dequant = tq.DeQuantStub()

    def forward(self, x, time, init_cond=None, attn_cond=None):
        """Run ``x`` through quant stub -> wrapped network -> dequant stub.

        Only ``x`` goes through the QuantStub; ``time``, ``init_cond`` and
        ``attn_cond`` are forwarded to the wrapped network untouched.
        """
        quantized_input = self.quant(x)
        return self.dequant(
            self.model(quantized_input, time, init_cond, attn_cond)
        )


In [None]:
 # Build original UNet
# Instantiate the float UNet on the GPU and wrap it with the quant/dequant stubs.
unet = UNet1d(**model_init).to('cuda')
model = QuantizedUNet1dWrapper(unet)

# Per-tensor affine quantization for both activations and weights. (Author's
# note: this choice is not very significant when the goal is a heavily
# compressed model size.)
activation_observer = tq.HistogramObserver.with_args(
    qscheme=torch.per_tensor_affine,
    reduce_range=False,
)
weight_observer = tq.HistogramObserver.with_args(
    qscheme=torch.per_tensor_affine,
    dtype=torch.qint8,
    reduce_range=False,
)
per_tensor_qconfig = tq.QConfig(
    activation=activation_observer,
    weight=weight_observer,
)

# Attach the config to the wrapper so that prepare() picks it up.
model.qconfig = per_tensor_qconfig

In [None]:
# Insert observers into `model` in place. The printed module tree below shows
# a HistogramObserver attached to the quant stub and to the Conv1d / Linear
# layers.
tq.prepare(model, inplace=True)
# Calibration for static quantization needs representative data to be run
# through the model in eval mode. Calling eval() after prepare() also puts the
# freshly inserted observer modules into eval mode.
model.eval()

QuantizedUNet1dWrapper(
  (quant): QuantStub(
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (model): UNet1d(
    (init_conv): Conv1d(
      2, 4, kernel_size=(7,), stride=(1,), padding=(3,)
      (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
    )
    (time_mlp): Sequential(
      (0): SinusoidalPosEmb()
      (1): Linear(
        in_features=4, out_features=16, bias=True
        (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
      )
      (2): GELU(approximate='none')
      (3): Linear(
        in_features=16, out_features=16, bias=True
        (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
      )
    )
    (init_cond_proj): ConditionalScaleShift(
      (to_scale_shift): Sequential(
        (0): SiLU()
        (1): Linear(
          in_features=16, out_features=2, bias=True
          (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
        )
      )
 

In [None]:
# Dump every parameter that still requires gradients, together with its values.
for name, param in filter(lambda item: item[1].requires_grad, model.named_parameters()):
    print(f"Parameter: {name}, Value: {param.data}")

Parameter: model.init_conv.weight, Value: tensor([[[-0.1598,  0.2116, -0.2439,  0.0442,  0.0973, -0.1735,  0.1961],
         [-0.0621,  0.1701, -0.0554,  0.0978,  0.0060, -0.0044,  0.1072]],

        [[ 0.2516,  0.1461, -0.2599,  0.2404, -0.2043,  0.2586, -0.2442],
         [ 0.2524,  0.0733, -0.0006,  0.1177, -0.0262,  0.0623, -0.0380]],

        [[-0.0697,  0.0896,  0.1742,  0.2116, -0.0023, -0.2453, -0.0146],
         [-0.1664, -0.1529,  0.1832, -0.0101,  0.2246, -0.1536, -0.2020]],

        [[ 0.1580,  0.0400, -0.2502,  0.2476, -0.1447, -0.1175, -0.1621],
         [ 0.0909,  0.0851, -0.2309,  0.0562,  0.0414,  0.0204, -0.2634]]],
       device='cuda:0')
Parameter: model.init_conv.bias, Value: tensor([-0.1324,  0.1871,  0.2254, -0.1771], device='cuda:0')
Parameter: model.time_mlp.1.weight, Value: tensor([[ 0.0625, -0.2258, -0.1206, -0.1757],
        [ 0.4835, -0.2103, -0.2763, -0.2715],
        [-0.4072,  0.2352, -0.1956, -0.4070],
        [-0.1788,  0.3326, -0.2065, -0.3320],
     

A brief comparison of the parameters of the normal UNet and the quantized UNet:

In [None]:
from torchinfo import summary
# Layer summary of the quantization wrapper `model`. The captured output below
# lists a `Quantize` layer and no parameter counts, i.e. it was taken from the
# quantized model.
summary(model)

Layer (type:depth-idx)                                       Param #
QuantizedUNet1dWrapper                                       --
├─Quantize: 1-1                                              --
├─UNet1d: 1-2                                                --
│    └─Conv1d: 2-1                                           --
│    └─Sequential: 2-2                                       --
│    │    └─SinusoidalPosEmb: 3-1                            --
│    │    └─Linear: 3-2                                      --
│    │    └─GELU: 3-3                                        --
│    │    └─Linear: 3-4                                      --
│    └─ConditionalScaleShift: 2-3                            --
│    │    └─Sequential: 3-5                                  --
│    └─ModuleList: 2-4                                       --
│    │    └─Identity: 3-6                                    --
│    │    └─Sequential: 3-7                                  --
│    └─ModuleList: 2-5             

In [None]:
# Fresh, un-wrapped UNet1d built from the same configuration, summarised for
# comparison; the captured output below shows per-layer parameter counts.
m = UNet1d(**model_init)
summary(m)

Layer (type:depth-idx)                             Param #
UNet1d                                             --
├─Conv1d: 1-1                                      60
├─Sequential: 1-2                                  --
│    └─SinusoidalPosEmb: 2-1                       --
│    └─Linear: 2-2                                 80
│    └─GELU: 2-3                                   --
│    └─Linear: 2-4                                 272
├─ConditionalScaleShift: 1-3                       --
│    └─Sequential: 2-5                             --
│    │    └─SiLU: 3-1                              --
│    │    └─Linear: 3-2                            34
├─ModuleList: 1-4                                  --
│    └─Identity: 2-6                               --
│    └─Sequential: 2-7                             --
│    │    └─Conv1d: 3-3                            64
│    │    └─GELU: 3-4                              --
│    │    └─Conv1d: 3-5                            72
├─ModuleList: 1-5     

The code below prevents deserialization errors in PyTorch.

In [None]:
import numpy as np
torch.serialization.add_safe_globals([np.dtype])
from torch.serialization import add_safe_globals
add_safe_globals([np.core.multiarray.scalar])

In [None]:
# Location of the original (float) model checkpoint.
path='dquartic/model/best_model.pth'
# Load the whole checkpoint onto the GPU.
# NOTE(review): weights_only=False makes torch.load fall back to full pickle
# deserialisation, which can execute arbitrary code -- only use it on
# checkpoints from a trusted source. Safe-global allow-lists are only
# consulted when weights_only=True.
checkpoint = torch.load(path, map_location="cuda",weights_only=False)

In [None]:
if "model_state_dict" in checkpoint:
    state_dict = checkpoint["model_state_dict"]

    # `model` is the QuantizedUNet1dWrapper, so its own keys carry a "model."
    # prefix (e.g. "model.init_conv.weight"). A checkpoint saved from a bare
    # UNet1d has no such prefix; loading it into the wrapper with strict=False
    # would silently skip every tensor and leave the network randomly
    # initialised. Pick the load target from the key names instead.
    keys_have_wrapper_prefix = all(key.startswith("model.") for key in state_dict)
    load_target = model if keys_have_wrapper_prefix else model.model

    # strict=False is still required: the observers inserted by prepare() own
    # buffers that a float checkpoint does not contain.
    load_result = load_target.load_state_dict(state_dict, strict=False)

    # Report anything that did not match, ignoring the expected observer buffers,
    # so that a key mismatch can no longer go unnoticed.
    missing_keys = [
        key for key in load_result.missing_keys
        if "activation_post_process" not in key
    ]
    unexpected_keys = list(load_result.unexpected_keys)
    if missing_keys or unexpected_keys:
        print(
            f"Warning: checkpoint did not fully match the model "
            f"({len(missing_keys)} missing, {len(unexpected_keys)} unexpected keys)"
        )
        print(f"  first missing keys: {missing_keys[:5]}")
        print(f"  first unexpected keys: {unexpected_keys[:5]}")
else:
    print("Error: Checkpoint doesn't contain 'model_state_dict'")

# # Assign a quantization config
# model.qconfig = tq.get_default_qconfig("fbgemm")
# tq.prepare(model, inplace=True)

In [None]:
# Calibrate with a small sample of data from your dataloader
# `math` is needed for math.sqrt below and is not imported anywhere else in
# the notebook, so import it here to keep the cell runnable on a fresh kernel.
import math

torch.manual_seed(42)
device='cuda'

# Calibration settings, hoisted out of the loop because they never change.
num_calibration_batches = 2        # keep calibration to a small number of batches
num_timesteps = 1000               # upper bound (exclusive) for the random timesteps
mixture_weights = (0.5, 0.5)       # mixing weights for the conditioning (as done in training)
alpha = 0.7                        # fixed alpha for simplicity
signal_scale = math.sqrt(alpha)
noise_scale = math.sqrt(1-alpha)

with torch.no_grad():
    calibration_samples = 0
    # The fourth element of each batch is not used during calibration.
    for i, (ms2_1, ms1_1, ms2_2, _unused_ms1_2) in enumerate(data_loader):
        if i >= num_calibration_batches:
            break

        # Move tensors to the right device
        ms2_1 = ms2_1.to(device)
        ms1_1 = ms1_1.to(device)
        ms2_2 = ms2_2.to(device)

        # Create mixture conditioning
        ms2_cond = (ms2_1 * mixture_weights[0]) + (ms2_2 * mixture_weights[1])

        # Create random timesteps for batch
        batch_size = ms2_1.shape[0]
        timestep = torch.randint(0, num_timesteps, (batch_size,), device=device)

        # Add noise to ms2_1 to simulate diffusion process
        noise = torch.randn_like(ms2_1, device=device)
        noisy_ms2 = signal_scale * ms2_1 + noise_scale * noise

        # Pass through model so the observers can record activation statistics
        _ = model(noisy_ms2, timestep, ms2_cond, ms1_1)

        calibration_samples += len(ms2_1)
        print(f"Calibrated with {calibration_samples} samples so far")

# Convert the calibrated model to quantized form


Calibrated with 1 samples so far
Calibrated with 2 samples so far


In [None]:

# Turn the calibrated model into its quantized counterpart (on a copy).
model_int8 = tq.convert(model, inplace=False)

# Save the quantized weights together with every other component of the
# original checkpoint, so the file has the same layout as the original when it
# is loaded with model_interface.py. Key order matches the original checkpoint
# dict: epoch, model, optimizer, scheduler, best_loss.
quantized_checkpoint = {
    key: (model_int8.state_dict() if key == 'model_state_dict' else checkpoint[key])
    for key in (
        'epoch',
        'model_state_dict',
        'optimizer_state_dict',
        'scheduler_state_dict',
        'best_loss',
    )
}
torch.save(quantized_checkpoint, "my_unet_checkpoint_int8.pth")
print("Successfully saved quantized model")


Successfully saved quantized model


We saw the size decrease from the original 14 GB to 10 GB, but only because the supported layers' weights were quantized; everything else was left as it was. A reasonable next step is to try
PyTorch FX graph mode or PyTorch 2 quantization to extend quantization to more layers, for deeper quantization of our custom modules (the BIGGEST HINDRANCE).