In [1]:
import argparse
import random

import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn

import lavis.tasks as tasks
from lavis.common.config import Config
from lavis.common.dist_utils import get_rank, init_distributed_mode
from lavis.common.logger import setup_logger
from lavis.common.optims import (
    LinearWarmupCosineLRScheduler,
    LinearWarmupStepLRScheduler,
)
from lavis.common.utils import now

# imports modules for registration
from lavis.datasets.builders import *
from lavis.models import *
from lavis.processors import *
from lavis.runners.runner_base import RunnerBase
from lavis.tasks import *
from layers.nbitlineardynamic import NBitLinearDynamic

  from .autonotebook import tqdm as notebook_tqdm


In [23]:
def parse_args(argv=None):
    """Parse the command-line style options that configure this notebook.

    Args:
        argv: optional sequence of argument strings to parse. When omitted,
            the built-in notebook argument list below is parsed; ``sys.argv``
            is never consulted, so the function can be called from a notebook
            kernel without passing anything.

    Returns:
        argparse.Namespace with ``cfg_path``, ``options``,
        ``visual_encoder_block_modules``, ``visual_encoder_block_indices`` and
        ``visual_encoder_block_weight_bits``.

    Raises:
        SystemExit: raised by argparse (via ``parser.error``) on invalid
            input, including when only some of the three
            ``--visual-encoder-block-*`` options are supplied.
    """
    parser = argparse.ArgumentParser(description="Training")

    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecate), "
        "change to --cfg-options instead.",
    )

    parser.add_argument(
        '--visual-encoder-block-modules',
        required=False,
        nargs="*",
        choices=['qkv', 'proj', 'fc1', 'fc2'],
        default=None,
        help='modules of visual-encoder blocks to quantize',
    )

    # NOTE: `choices` could enforce a hard-coded number of possible ViT blocks.
    parser.add_argument(
        '--visual-encoder-block-indices',
        required=False,
        nargs='*',
        type=int,
        default=None,
        help='indices of visual-encoder blocks to quantize',
    )

    parser.add_argument(
        '--visual-encoder-block-weight-bits',
        required=False,
        type=int,
        choices=[1, 2, 4, 6, 8],
        default=None,
        help='weight bits for visual-encoder blocks',
    )

    # TODO: options to quantize qformer bert

    # TODO: options to quantize qformer cls (for LLM head?)

    if argv is None:
        # Notebook default: every supported module in blocks 0..38 at 8 bits.
        argv = [
            "--cfg-path", "/nfshomes/vla/scratch/LAVIS/ret_flickr_eval.yaml",
            "--visual-encoder-block-modules", "qkv", "proj", "fc1", "fc2",
            "--visual-encoder-block-indices", *(str(i) for i in range(39)),
            "--visual-encoder-block-weight-bits", "8",
        ]

    args = parser.parse_args(list(argv))

    # The three quantization options only make sense as a group: reject any
    # partial combination, whichever of the three is the one that was supplied.
    quant_options = (
        args.visual_encoder_block_modules,
        args.visual_encoder_block_indices,
        args.visual_encoder_block_weight_bits,
    )
    supplied = [value is not None for value in quant_options]
    if any(supplied) and not all(supplied):
        parser.error('--visual-encoder-block-modules, --visual-encoder-block-indices, --visual-encoder-block-weight-bits, must be given together')

    return args

# Parse the notebook arguments once and keep them as a plain dict; the
# quantize_* helpers in this notebook read this module-level `args` by key.
args = vars(parse_args())
args

{'cfg_path': '/nfshomes/vla/scratch/LAVIS/ret_flickr_eval.yaml',
 'options': None,
 'visual_encoder_block_modules': ['qkv', 'proj', 'fc1', 'fc2'],
 'visual_encoder_block_indices': [0,
  1,
  2,
  3,
  4,
  5,
  6,
  7,
  8,
  9,
  10,
  11,
  12,
  13,
  14,
  15,
  16,
  17,
  18,
  19,
  20,
  21,
  22,
  23,
  24,
  25,
  26,
  27,
  28,
  29,
  30,
  31,
  32,
  33,
  34,
  35,
  36,
  37,
  38],
 'visual_encoder_block_weight_bits': 8}

In [6]:
# Build the LAVIS Config from a freshly parsed argument namespace
# (Config receives the Namespace itself, not the `args` dict).
cfg = Config(parse_args())
cfg

<lavis.common.config.Config at 0x7f02683e9660>

In [7]:
# Set up the task object from the config.
task = tasks.setup_task(cfg)
task

<lavis.tasks.retrieval.RetrievalTask at 0x7f0170d03430>

In [8]:
# Build the model for this task from the config.
model = task.build_model(cfg)



Position interpolate from 16x16 to 26x26


In [12]:
# model_parts = {name:m.__class__.__name__ for name,m in model.named_children()}
# print(model_parts)

def quantize_layer(module:nn.Linear, weight_bits = 32, activation_bits=32):
    """Build an NBitLinearDynamic replacement for an existing nn.Linear.

    Args:
        module: source linear layer. Its ``in_features`` / ``out_features``
            size the new layer, and its weight (and bias, when it has one)
            are copied over.
        weight_bits: forwarded to NBitLinearDynamic as ``weight_bits``.
        activation_bits: forwarded to NBitLinearDynamic as ``activation_bits``.

    Returns:
        The new NBitLinearDynamic layer, moved to the device that holds
        ``module.weight``. The dtype is left as constructed.

    NOTE(review): the device move assumes NBitLinearDynamic is an nn.Module
    (it appears as a registered child in the model printouts) — confirm.
    """
    print('weight_bits: ', weight_bits)

    # Identity check: `!= None` on a tensor goes through tensor comparison,
    # which is not the question being asked here.
    has_bias = module.bias is not None

    with torch.no_grad():
        q_layer = NBitLinearDynamic(
            module.in_features,
            module.out_features,
            bias=has_bias,
            weight_bits=weight_bits,
            activation_bits=activation_bits,
        )

        # Keep the replacement on the same device as the layer it stands in
        # for; otherwise swapping it into a model on another device leaves a
        # layer behind on the default device.
        q_layer = q_layer.to(module.weight.device)

        # copy over weights
        q_layer.weight.copy_(module.weight)
        if has_bias:
            q_layer.bias.copy_(module.bias)

    return q_layer


def quantize_visual_encoder_block(module_parent):
    """Recursively swap the selected linear layers under ``module_parent``.

    Every direct child whose name is listed in
    ``args['visual_encoder_block_modules']`` and which is an ``nn.Linear`` is
    replaced in place by ``quantize_layer(child, weight_bits=...)`` using
    ``args['visual_encoder_block_weight_bits']``. All other children are
    searched recursively.

    Args:
        module_parent: the module whose descendants should be inspected.

    Returns:
        None. ``module_parent`` is modified in place.

    NOTE(review): swapping a child only changes the computation if the
    parent's forward calls that child (rather than reading its ``.weight``
    directly) — confirm this for the attention ``qkv`` layer.
    """
    target_names = args['visual_encoder_block_modules']

    # Snapshot the children so that the setattr below never interacts with
    # the live iterator over the parent's submodules.
    for name, module in list(module_parent.named_children()):
        # A matching name is not enough: only nn.Linear layers carry the
        # in_features / out_features / weight that quantize_layer reads.
        if name in target_names and isinstance(module, nn.Linear):
            print('parent: ', module_parent)
            print('child: ', name)

            # TODO: could customize weight_bits/activation_bits per block
            replacement = quantize_layer(
                module, weight_bits=args['visual_encoder_block_weight_bits']
            )
            setattr(module_parent, name, replacement)
        else:
            quantize_visual_encoder_block(module)
            

def quantize_visual_encoder_blocks(blocks):
    """Quantize only the blocks whose position is selected in ``args``.

    Child names of ``blocks`` are converted with ``int`` and compared against
    ``args['visual_encoder_block_indices']``; each selected child is handed to
    ``quantize_visual_encoder_block``. Nothing is returned.
    """
    for index_name, block in blocks.named_children():
        is_selected = int(index_name) in args['visual_encoder_block_indices']
        if not is_selected:
            continue
        quantize_visual_encoder_block(block)
         


def quantize(model):
    """Apply the quantization selected in ``args`` to ``model`` in place.

    Currently only the visual-encoder blocks are handled: when
    ``args['visual_encoder_block_modules']`` is truthy,
    ``model.visual_encoder.blocks`` is passed to
    ``quantize_visual_encoder_blocks``. Returns None.
    """
    # Visual encoder blocks
    if not args['visual_encoder_block_modules']:
        return
    quantize_visual_encoder_blocks(model.visual_encoder.blocks)

# def apply_quant_to_selected_modules(model: nn.Module, target_modules: List[str], bits: int = 4, apply=None):
    
#     for name, module in model.named_children():
#         if (apply is None):
#             if name in target_modules:
#                 print(f"Applying GPTQ to {name} module")
#                 apply_quant_to_selected-modules(module, target_modules, bits, True)
#             else:
#                 apply_quant_to_selected-modules(module, target_modules, bits, False)
#         else:
#             if isinstance(module, nn.Linear):
#                 print(f"Found a layer to quantize {name}")
#                 gptq_quantize_layer(module, bits)
#             elif isinstance(module, nn.Module):
#                 apply_quant_to_selected-modules(module, target_modules, bits, apply)
#     return model

# quantized_model = apply_quant_to_selected-modules(model, target_modules, bits=8)

In [13]:
# Replace the selected visual-encoder layers in place (quantize returns nothing).
quantize(model)

parent:  Attention(
  (qkv): NBitLinearDynamic(in_features=1408, out_features=4224, bias=False)
  (attn_drop): Dropout(p=0.0, inplace=False)
  (proj): NBitLinearDynamic(in_features=1408, out_features=1408, bias=True)
  (proj_drop): Dropout(p=0.0, inplace=False)
)
child:  qkv
weight_bits:  8
parent:  Attention(
  (qkv): NBitLinearDynamic(in_features=1408, out_features=4224, bias=False)
  (attn_drop): Dropout(p=0.0, inplace=False)
  (proj): NBitLinearDynamic(in_features=1408, out_features=1408, bias=True)
  (proj_drop): Dropout(p=0.0, inplace=False)
)
child:  proj
weight_bits:  8
parent:  Mlp(
  (fc1): NBitLinearDynamic(in_features=1408, out_features=6144, bias=True)
  (act): GELU(approximate='none')
  (fc2): NBitLinearDynamic(in_features=6144, out_features=1408, bias=True)
  (drop): Dropout(p=0.0, inplace=False)
)
child:  fc1
weight_bits:  8
parent:  Mlp(
  (fc1): NBitLinearDynamic(in_features=1408, out_features=6144, bias=True)
  (act): GELU(approximate='none')
  (fc2): NBitLinearDynam

In [14]:
# Display the visual-encoder blocks to check which layers were replaced.
model.visual_encoder.blocks

ModuleList(
  (0-38): 39 x Block(
    (norm1): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
    (attn): Attention(
      (qkv): NBitLinearDynamic(in_features=1408, out_features=4224, bias=False)
      (attn_drop): Dropout(p=0.0, inplace=False)
      (proj): NBitLinearDynamic(in_features=1408, out_features=1408, bias=True)
      (proj_drop): Dropout(p=0.0, inplace=False)
    )
    (drop_path): Identity()
    (norm2): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
    (mlp): Mlp(
      (fc1): NBitLinearDynamic(in_features=1408, out_features=6144, bias=True)
      (act): GELU(approximate='none')
      (fc2): NBitLinearDynamic(in_features=6144, out_features=1408, bias=True)
      (drop): Dropout(p=0.0, inplace=False)
    )
  )
)

In [21]:
# Print the names of the Q-Former's direct child modules.
for name, module in model.Qformer.named_children():
    print(name)

bert
cls


In [22]:
# Display the Q-Former `cls` head to see which linear layers it contains.
model.Qformer.cls

BertOnlyMLMHead(
  (predictions): BertLMPredictionHead(
    (transform): BertPredictionHeadTransform(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (transform_act_fn): GELUActivation()
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    )
    (decoder): Linear(in_features=768, out_features=30523, bias=True)
  )
)

In [None]:
# TODO:
# arg --> list of blocks to quantize for ViT/Q-Former

In [45]:
# Bit widths (weights / activations) for the quantized-layer experiments —
# presumably consumed by later cells; verify against their usage.
weight_bits = 8
activation_bits = 32

In [46]:
from layers.nbitlineardynamic import *

# Build the quantized stand-in for the vision projection with the shared
# quantize_layer helper rather than a hand-rolled copy: the helper detects
# whether the source layer has a bias (instead of assuming bias=True) and the
# bit widths come from the `weight_bits` / `activation_bits` variables defined
# in this notebook instead of repeated literals.
Q_layer = quantize_layer(
    model.vision_proj,
    weight_bits=weight_bits,
    activation_bits=activation_bits,
)

Q_layer

NBitLinearDynamic(in_features=768, out_features=256, bias=True)

In [47]:
# Install the quantized layer in place of the model's vision projection.
model.vision_proj = Q_layer

In [48]:
# Confirm that the vision projection is now the quantized layer.
model.vision_proj

NBitLinearDynamic(in_features=768, out_features=256, bias=True)

In [49]:
# Display the full module tree of the model.
model

Blip2Qformer(
  (visual_encoder): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 1408, kernel_size=(14, 14), stride=(14, 14))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0-38): 39 x Block(
        (norm1): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=1408, out_features=4224, bias=False)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=1408, out_features=1408, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=1408, out_features=6144, bias=True)
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=6144, out_features=1408, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
    )


In [30]:
# Materialise every module yielded by model.modules() into a list for display.
list(model.modules())

[Blip2Qformer(
   (visual_encoder): VisionTransformer(
     (patch_embed): PatchEmbed(
       (proj): Conv2d(3, 1408, kernel_size=(14, 14), stride=(14, 14))
     )
     (pos_drop): Dropout(p=0.0, inplace=False)
     (blocks): ModuleList(
       (0-38): 39 x Block(
         (norm1): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
         (attn): Attention(
           (qkv): Linear(in_features=1408, out_features=4224, bias=False)
           (attn_drop): Dropout(p=0.0, inplace=False)
           (proj): Linear(in_features=1408, out_features=1408, bias=True)
           (proj_drop): Dropout(p=0.0, inplace=False)
         )
         (drop_path): Identity()
         (norm2): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
         (mlp): Mlp(
           (fc1): Linear(in_features=1408, out_features=6144, bias=True)
           (act): GELU(approximate='none')
           (fc2): Linear(in_features=6144, out_features=1408, bias=True)
           (drop): Dropout(p=0.0, inplace=False)
  

In [36]:
# Display the module tree of the visual encoder.
model.visual_encoder

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 1408, kernel_size=(14, 14), stride=(14, 14))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0-38): 39 x Block(
      (norm1): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=1408, out_features=4224, bias=False)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=1408, out_features=1408, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=1408, out_features=6144, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=6144, out_features=1408, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
)