In [1]:
import torch
from transformers import Blip2Processor, Blip2ForConditionalGeneration
from datasets import COCODataset
from datasets import COCODataset
from tqdm import tqdm
from PIL import Image
from torch.utils.data import DataLoader

# Setup device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load BLIP-2 model and processor
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")

# Ensure the model is on the correct device
model = model.to(device)

# Load COCO dataset
coco_dataset = COCODataset(ann_file='./data/coco/annotations/captions_val2017.json',
                           img_dir='./data/coco/val2017')



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

loading annotations into memory...
Done (t=0.04s)
creating index...
index created!


In [26]:
import torch
from typing import List, Tuple, Callable, Union
from enum import Enum, auto
from torch import nn, Tensor

class ModelPart(Enum):
    VIT = auto()
    QFORMER = auto()
    LLM = auto()

class LayerGroup(Enum):
    FIRST = auto()
    MIDDLE = auto()
    LAST = auto()
    ALL = auto()

class LayerType(Enum):
    MLP = auto()
    ATTENTION = auto()
    BOTH = auto()

class QuantConfig:
    def __init__(self, 
                 model_part: ModelPart,
                 layer_group: LayerGroup,
                 layer_type: LayerType,
                 quant_function: Callable,
                 num_bits: int):
        self.model_part = model_part
        self.layer_group = layer_group
        self.layer_type = layer_type
        self.quant_function = quant_function
        self.num_bits = num_bits

class BlipQuantizer:
    def __init__(self, model: nn.Module):
        self.model = model
        self.num_bits = 0

    def apply_quantization(self, configs: List[QuantConfig]):
        for config in configs:
            self._quantize_part(config)

    def _quantize_part(self, config: QuantConfig):
        if config.model_part == ModelPart.VIT:
            layers = self.model.vision_model.encoder.layers
        elif config.model_part == ModelPart.QFORMER:
            layers = self.model.qformer.encoder.layer
        else:  # LLM
            layers = self.model.language_model.model.decoder.layers

        total_layers = len(layers)
        start, end = self._get_layer_range(config.layer_group, total_layers)

        self.num_bits = config.num_bits
        print(f"running {self.num_bits} quant")
        bit_quant_function = config.quant_function(config.num_bits)

        for layer in layers[start:end]:
            if config.layer_type in [LayerType.MLP, LayerType.BOTH]:
                self._quantize_mlp(layer, bit_quant_function)
            if config.layer_type in [LayerType.ATTENTION, LayerType.BOTH]:
                self._quantize_attention(layer, bit_quant_function)

    def _get_layer_range(self, group: LayerGroup, total_layers: int):
        if group == LayerGroup.FIRST:
            return 0, total_layers // 3
        elif group == LayerGroup.LAST:
            return 2 * total_layers // 3, total_layers
        elif group == LayerGroup.MIDDLE:
            return total_layers // 3, 2 * total_layers // 3
        else:  # ALL
            return 0, total_layers

    def _quantize_mlp(self, layer: nn.Module, quant_function: Callable):
        if hasattr(layer, 'mlp'):
            self._quantize_linear(layer.mlp.fc1, quant_function)
            self._quantize_linear(layer.mlp.fc2, quant_function)
        elif hasattr(layer, 'fc1') and hasattr(layer, 'fc2'):
            self._quantize_linear(layer.fc1, quant_function)
            self._quantize_linear(layer.fc2, quant_function)

    def _quantize_attention(self, layer: nn.Module, quant_function: Callable):
        if hasattr(layer, 'self_attn'):
            if hasattr(layer.self_attn, 'qkv'):
                self._quantize_linear(layer.self_attn.qkv, quant_function)
            if hasattr(layer.self_attn, 'projection'):
                self._quantize_linear(layer.self_attn.projection, quant_function)
        elif hasattr(layer, 'attention'):
            if hasattr(layer.attention, 'attention'):
                self._quantize_linear(layer.attention.attention.query, quant_function)
                self._quantize_linear(layer.attention.attention.key, quant_function)
                self._quantize_linear(layer.attention.attention.value, quant_function)
            if hasattr(layer.attention, 'output'):
                self._quantize_linear(layer.attention.output.dense, quant_function)
        elif hasattr(layer, 'k_proj'):
            self._quantize_linear(layer.k_proj, quant_function)
            self._quantize_linear(layer.v_proj, quant_function)
            self._quantize_linear(layer.q_proj, quant_function)
            self._quantize_linear(layer.out_proj, quant_function)

    def _quantize_linear(self, module: nn.Module, quant_function: Callable):
        if hasattr(module, 'weight') and isinstance(module.weight, torch.Tensor):
            module.weight.data = quant_function(module.weight.data)
            module.quantized = True  # Add this line
            module.num_bits = self.num_bits  # Assuming num_bits is the first default argument
        if hasattr(module, 'bias') and isinstance(module.bias, torch.Tensor):
            module.bias.data = quant_function(module.bias.data)

    def count_quantized_layers(self):
        count = 0
        for name, module in self.model.named_modules():
            if hasattr(module, 'quantized') and module.quantized:
                count += 1
        return count

def print_model_structure(model, indent=0):
    for name, module in model.named_children():
        print('  ' * indent + name + ': ' + module.__class__.__name__, end='')
        if hasattr(module, 'quantized'):
            print(f" (Quantized: {module.num_bits} bits)", end='')
        print()
        if len(list(module.children())) > 0:
            print_model_structure(module, indent + 1)

def quant_function(x: Tensor, num_bits):
    min_val = x.min()
    max_val = x.max()
    
    alpha = max_val - min_val
    x = (x - min_val) / alpha
    
    scale = (2**num_bits - 1)
    result = (scale * x).round()
    result /= scale
    
    result = alpha * result + min_val
    
    return result

In [28]:
from quant_functions import uniform_quantization

quantizer = BlipQuantizer(model)

configs = [
    QuantConfig(ModelPart.VIT, LayerGroup.ALL, LayerType.BOTH, 
                uniform_quantization, num_bits=4),
    QuantConfig(ModelPart.QFORMER, LayerGroup.ALL, LayerType.BOTH, 
                uniform_quantization, num_bits=8),
    QuantConfig(ModelPart.LLM, LayerGroup.ALL, LayerType.BOTH, 
                uniform_quantization, num_bits=2)
]

print("Quantizing model...")
quantizer.apply_quantization(configs)
print_model_structure(model)

Quantizing model...
running 4 quant
running 8 quant
running 2 quant
vision_model: Blip2VisionModel
  embeddings: Blip2VisionEmbeddings
    patch_embedding: Conv2d
  encoder: Blip2Encoder
    layers: ModuleList
      0: Blip2EncoderLayer
        self_attn: Blip2Attention
          dropout: Dropout
          qkv: Linear (Quantized: 4 bits)
          projection: Linear (Quantized: 4 bits)
        layer_norm1: LayerNorm
        mlp: Blip2MLP
          activation_fn: GELUActivation
          fc1: Linear (Quantized: 4 bits)
          fc2: Linear (Quantized: 4 bits)
        layer_norm2: LayerNorm
      1: Blip2EncoderLayer
        self_attn: Blip2Attention
          dropout: Dropout
          qkv: Linear (Quantized: 4 bits)
          projection: Linear (Quantized: 4 bits)
        layer_norm1: LayerNorm
        mlp: Blip2MLP
          activation_fn: GELUActivation
          fc1: Linear (Quantized: 4 bits)
          fc2: Linear (Quantized: 4 bits)
        layer_norm2: LayerNorm
      2: Blip2En

In [None]:

quantizer = BlipQuantizer(model)

configs = [
    QuantConfig(ModelPart.VIT, LayerGroup.ALL, LayerType.BOTH, 
                lambda b: lambda x: quant_function(x, num_bits=b), num_bits=4),
    QuantConfig(ModelPart.QFORMER, LayerGroup.ALL, LayerType.BOTH, 
                lambda b: lambda x: quant_function(x, num_bits=b), num_bits=8),
    QuantConfig(ModelPart.LLM, LayerGroup.ALL, LayerType.BOTH, 
                lambda b: lambda x: quant_function(x, num_bits=b), num_bits=2)
]

quantizer.apply_quantization(configs)


In [7]:
import torch
from transformers import Blip2ForConditionalGeneration
import gc

def test_blip_quantizer(model: nn.Module):
    # Store original parameters
    original_params = {}
    sample_params = [
        'vision_model.encoder.layers.0.self_attn.qkv.weight',
        'qformer.encoder.layer.0.attention.attention.query.weight',
        'language_model.model.decoder.layers.0.self_attn.k_proj.weight'
    ]
    for name in sample_params:
        param = model
        for attr in name.split('.'):
            param = getattr(param, attr)
        original_params[name] = param.detach().clone()

    quantizer = BlipQuantizer(model)
    configs = [
        QuantConfig(ModelPart.VIT, LayerGroup.ALL, LayerType.BOTH, 
                    lambda b: lambda x: quant_function(x, num_bits=b), num_bits=4),
        QuantConfig(ModelPart.QFORMER, LayerGroup.ALL, LayerType.BOTH, 
                    lambda b: lambda x: quant_function(x, num_bits=b), num_bits=8),
        QuantConfig(ModelPart.LLM, LayerGroup.ALL, LayerType.BOTH, 
                    lambda b: lambda x: quant_function(x, num_bits=b), num_bits=2)
    ]
    print("Testing BlipQuantizer...")
    
    # Test 1: Apply quantization
    print("Applying quantization...")
    quantizer.apply_quantization(configs)
    print("Quantization applied.")

    # Test 2: Count quantized layers
    quantized_layers = quantizer.count_quantized_layers()
    print(f"Number of quantized layers: {quantized_layers}")

    # Test 3: Check quantization bits for a sample of layers
    print("\nSampling quantized layers:")
    sampled_modules = [
        ('vision_model.encoder.layers.0', model.vision_model.encoder.layers[0]),
        ('qformer.encoder.layer.0', model.qformer.encoder.layer[0]),
        ('language_model.model.decoder.layers.0', model.language_model.model.decoder.layers[0])
    ]
    
    for name, module in sampled_modules:
        print(f"\nChecking {name}:")
        for sub_name, sub_module in module.named_modules():
            if hasattr(sub_module, 'quantized'):
                print(f"  {sub_name}: Quantized to {sub_module.num_bits} bits")

    # Test 4: Verify quantization effects for a sample parameter
    print("\nVerifying quantization effects:")
    sample_params = [
        ('vision_model.encoder.layers.0.self_attn.qkv', 'weight'),
        ('qformer.encoder.layer.0.attention.attention.query', 'weight'),
        ('language_model.model.decoder.layers.0.self_attn.k_proj', 'weight')
    ]
    
    for module_name, param_name in sample_params:
        module = model
        for attr in module_name.split('.'):
            module = getattr(module, attr)
        
        param = getattr(module, param_name)
        original_param = original_params[f"{module_name}.{param_name}"]
        
        if hasattr(module, 'quantized'):
            diff = torch.abs(param - original_param).mean().item()
            print(f"{module_name}.{param_name}: Mean absolute difference after quantization: {diff:.6f}")
            print(f"Quantized to {module.num_bits} bits")
        else:
            print(f"{module_name}.{param_name}: Not quantized")
    
        # Print a small sample of the parameter values before and after quantization
        print("Sample values:")
        print("Original:", original_param.flatten()[:5].tolist())
        print("Quantized:", param.flatten()[:5].tolist())
        print()

    # Clean up
    del model, quantizer, original_params
    gc.collect()
    torch.cuda.empty_cache()


print("Loading BLIP-2 model...")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")
print("BLIP-2 model loaded.")

# Run the tests
test_blip_quantizer(model)

Loading BLIP-2 model...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

BLIP-2 model loaded.
Testing BlipQuantizer...
Applying quantization...
running 4 quant
running 8 quant
running 2 quant
Quantization applied.
Number of quantized layers: 268

Sampling quantized layers:

Checking vision_model.encoder.layers.0:
  self_attn.qkv: Quantized to 4 bits
  self_attn.projection: Quantized to 4 bits
  mlp.fc1: Quantized to 4 bits
  mlp.fc2: Quantized to 4 bits

Checking qformer.encoder.layer.0:
  attention.attention.query: Quantized to 8 bits
  attention.attention.key: Quantized to 8 bits
  attention.attention.value: Quantized to 8 bits
  attention.output.dense: Quantized to 8 bits

Checking language_model.model.decoder.layers.0:
  fc1: Quantized to 2 bits
  fc2: Quantized to 2 bits

Verifying quantization effects:
vision_model.encoder.layers.0.self_attn.qkv.weight: Mean absolute difference after quantization: 0.030251
Quantized to 4 bits
Sample values:
Original: [-2.3543834686279297e-05, 4.231929779052734e-06, 0.00011557340621948242, 0.00534820556640625, -0.03228

In [11]:
print_model_structure(model)

vision_model: Blip2VisionModel
  embeddings: Blip2VisionEmbeddings
    patch_embedding: Conv2d
  encoder: Blip2Encoder
    layers: ModuleList
      0: Blip2EncoderLayer
        self_attn: Blip2Attention
          dropout: Dropout
          qkv: Linear (Quantized: 4 bits)
          projection: Linear (Quantized: 4 bits)
        layer_norm1: LayerNorm
        mlp: Blip2MLP
          activation_fn: GELUActivation
          fc1: Linear (Quantized: 4 bits)
          fc2: Linear (Quantized: 4 bits)
        layer_norm2: LayerNorm
      1: Blip2EncoderLayer
        self_attn: Blip2Attention
          dropout: Dropout
          qkv: Linear (Quantized: 4 bits)
          projection: Linear (Quantized: 4 bits)
        layer_norm1: LayerNorm
        mlp: Blip2MLP
          activation_fn: GELUActivation
          fc1: Linear (Quantized: 4 bits)
          fc2: Linear (Quantized: 4 bits)
        layer_norm2: LayerNorm
      2: Blip2EncoderLayer
        self_attn: Blip2Attention
          dropout: Drop

In [7]:
model

Blip2ForConditionalGeneration(
  (vision_model): Blip2VisionModel(
    (embeddings): Blip2VisionEmbeddings(
      (patch_embedding): Conv2d(3, 1408, kernel_size=(14, 14), stride=(14, 14))
    )
    (encoder): Blip2Encoder(
      (layers): ModuleList(
        (0-38): 39 x Blip2EncoderLayer(
          (self_attn): Blip2Attention(
            (dropout): Dropout(p=0.0, inplace=False)
            (qkv): Linear(in_features=1408, out_features=4224, bias=True)
            (projection): Linear(in_features=1408, out_features=1408, bias=True)
          )
          (layer_norm1): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
          (mlp): Blip2MLP(
            (activation_fn): GELUActivation()
            (fc1): Linear(in_features=1408, out_features=6144, bias=True)
            (fc2): Linear(in_features=6144, out_features=1408, bias=True)
          )
          (layer_norm2): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
        )
      )
    )
    (post_layernorm): LayerNorm((