In [1]:
from blip_quantizer import QuantConfig, ModelPart, LayerGroup, LayerType
from quant_functions import uniform_quantization
from utils import save_quant_configs
import itertools
import os

# Create the default output directory at import time so config writes have a target.
os.makedirs("./configs", exist_ok=True)

# Sweep axes. Every combination of these is turned into QuantConfig objects
# by generate_configs(); layer_types is passed straight through to each config.
bit_widths = [16, 8, 6, 5, 4, 3]
model_parts = [ModelPart.VIT, ModelPart.QFORMER, ModelPart.LLM]
layer_types = [LayerType.MLP, LayerType.ATTENTION, LayerType.BOTH]

# Every subset of (FIRST, MIDDLE, LAST), ordered by size and then by position:
# [], [FIRST], [MIDDLE], [LAST], [FIRST, MIDDLE], [FIRST, LAST],
# [MIDDLE, LAST], [FIRST, MIDDLE, LAST]. The empty entry means "no group of
# the main part is selected".
layer_groups = [
    list(subset)
    for size in range(4)
    for subset in itertools.combinations(
        (LayerGroup.FIRST, LayerGroup.MIDDLE, LayerGroup.LAST), size
    )
]

def generate_configs(output_dir="./configs", *, skip_duplicates=False):
    """Write one JSON quantization-config file per combination of the sweep.

    For every (bit width, layer type, "main" model part) triple taken from the
    module-level ``bit_widths``, ``layer_types`` and ``model_parts``:

    - each entry of ``layer_groups`` picks which layer groups of the main part
      get a ``QuantConfig``;
    - every on/off combination of the remaining parts decides whether each of
      them gets a single ``QuantConfig`` covering ``LayerGroup.ALL``.

    All configs in one file share the same layer type, ``uniform_quantization``
    and bit width. Combinations that select nothing are not written.

    Args:
        output_dir: Directory that receives the ``<index>.json`` files. It is
            created if missing. Defaults to ``"./configs"``.
        skip_duplicates: When the main-group selection is empty, the file
            depends only on which *other* parts are switched on. The same
            config list is then produced under two different main parts
            (e.g. main QFORMER and main LLM both yield a lone VIT/ALL config).
            With the module-level sweep that is 54 repeated files. Pass True to
            write each distinct list only once per bit width and layer type.
            Defaults to False so file indices match earlier runs.

    Returns:
        The number of files written. Files are named ``0.json`` up to
        ``<count - 1>.json`` inside ``output_dir``.
    """
    os.makedirs(output_dir, exist_ok=True)
    seen = set()
    config_index = 0
    # product() iterates the rightmost input fastest, matching the original
    # bit_width -> layer_type -> main_part nesting (and thus file numbering).
    for bit_width, layer_type, main_part in itertools.product(
        bit_widths, layer_types, model_parts
    ):
        other_parts = [part for part in model_parts if part != main_part]
        # One on/off flag per remaining part. Sizing by len(other_parts) instead
        # of a fixed 2 keeps zip() below from silently dropping parts if
        # model_parts ever changes length.
        flag_combos = list(itertools.product([True, False], repeat=len(other_parts)))
        for main_groups, other_quant in itertools.product(layer_groups, flag_combos):
            # (part, group) pairs in file order: main-part groups first.
            selection = [(main_part, group) for group in main_groups]
            selection += [
                (part, LayerGroup.ALL)
                for part, should_quant in zip(other_parts, other_quant)
                if should_quant
            ]
            if not selection:
                continue
            if skip_duplicates:
                # NOTE(review): assumes ModelPart/LayerGroup/LayerType members
                # are hashable (true for Enum members) -- confirm.
                key = (bit_width, layer_type, tuple(selection))
                if key in seen:
                    continue
                seen.add(key)
            configs = [
                QuantConfig(part, group, layer_type, uniform_quantization, bit_width)
                for part, group in selection
            ]
            save_quant_configs(configs, os.path.join(output_dir, f"{config_index}.json"))
            config_index += 1
    return config_index

# Script entry point: run the full config sweep and print how many JSON files
# were written. The count is bound at module level as ``config_index``.
if __name__ == "__main__":
    config_index = generate_configs()
    print(f"Generated {config_index} configuration files.")

Generated 1674 configuration files.
