In [1]:
import datetime
import gc
import os
import tempfile

import torch.quantization
from ptflops import get_model_complexity_info


def timestamp():
    """Print the current local date and time, formatted like ``Apr 07 2023, 23:16:20``."""
    fmt = "%b %d %Y, %H:%M:%S"
    current = datetime.datetime.now()
    print(current.strftime(fmt))

# Quantization of ColBERT Model

In [2]:
from transformers import pipeline, AutoConfig
from colbert.modeling.colbert import colbert_score
from colbert.modeling.checkpoint import Checkpoint
from colbert.infra import Run, RunConfig, ColBERTConfig
from colbert import Trainer, Indexer, Searcher
from transformers import AutoTokenizer
from colbert.data import Queries
import torch
from tqdm.auto import tqdm
import re


In [3]:
def filter_layers(name, prune_type, ignore_bias=True):
    """Return True when parameter ``name`` should be filtered out for ``prune_type``.

    Presumably used as a skip-predicate by a pruning routine not shown in this notebook.

    Rules, in order:
      * embedding parameters (``model.bert.embeddings...``) and any ``LayerNorm`` are always filtered;
      * biases are filtered when ``ignore_bias`` is true;
      * ``prune_type == "dense"`` filters attention parameters (keeps the rest);
      * a ``prune_type`` containing ``"attention"`` filters non-attention parameters, and,
        if it also contains ``"no_dense"``, filters attention parameters whose name has ``"dense"``;
      * otherwise nothing is filtered.
    """
    always_skipped = name.startswith('model.bert.embeddings') or 'LayerNorm' in name
    if always_skipped:
        return True
    if ignore_bias and name.endswith('bias'):
        return True

    is_attention = "attention" in name
    if prune_type == "dense":
        return is_attention
    if "attention" in prune_type:
        if not is_attention:
            return True
        return "no_dense" in prune_type and "dense" in name
    return False

In [4]:
def print_size_of_model(model):
    """Print and return the size, in MB (1e6 bytes), of ``model.state_dict()`` saved with ``torch.save``.

    The state dict is written inside a private temporary directory. Nothing is
    created in (or overwritten in) the current working directory, and the file is
    removed even if ``torch.save`` raises.

    Args:
        model: a ``torch.nn.Module`` (float or dynamically quantized).

    Returns:
        float: the serialized size in megabytes, i.e. the value that was printed.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Keep the historical file name so sizes stay byte-comparable with earlier
        # runs (torch's zip format may embed the file's base name).
        path = os.path.join(tmp_dir, "temp.p")
        torch.save(model.state_dict(), path)
        size_mb = os.path.getsize(path) / 1e6
    print('Size (MB):', size_mb)
    return size_mb

In [5]:
def quantization_data_new(config, quant_type, quant_Int):
    use_iter = "v2.0"
    
    use_full_data = False
    nbits = 2
    k = 1000
    maxsteps = 10000

    base_path = fr"experiments/"
    experiment = fr""




    #use_iter_str=f"{use_iter:,}".replace(',','.')
    index_name = f""


    checkpoint = fr"experiments/model_dump/colbert{use_iter}" 
    retrieval_name = f"{index_name}.ranking={k}.tsv"




    if not os.path.exists(checkpoint):
        #anil checkpoint = fr"{base_path}/checkpoints/colbert"
        print(f"Couldn't find checkpoint. Using default checkpoint: {checkpoint}")
        checkpoint = fr"experiments/model_dump/colbertv2.0"

    config = ColBERTConfig(
        bsize = 64,
        root=base_path,
        experiment=experiment,
        triples=r"../../kngo/data/triples.train.small.id.json",
        collection= r"../../kngo/data/collection.tsv",

        checkpoint=checkpoint,
        nbits=nbits,
        overwrite='resume',
        index_name=index_name,
        index_path=fr"./indexes",
        rank = 0,
        nranks = 1,
        amp = True,
        gpus = 1,
    )

    print("index_name=",index_name)
        
    for q_type in quant_type:
        print(f"pruning model on prune type {q_type} to: {quant_Int}")
        with Run().context(RunConfig(nranks=config.nranks, experiment=config.experiment)):
            model = Checkpoint(config.checkpoint, colbert_config=config)
        model_state_dict = model.state_dict()
        quantized_model = torch.quantization.quantize_dynamic(model,q_type , dtype=quant_Int)
        quantized_state_dict = quantized_model.state_dict()       
        print_size_of_model(model)
        print_size_of_model(quantized_model)
        
        if do_retrieval:
            timestamp()
            gc.collect()
            config.set("queries", r"../../kngo/data/queries.dev.tsv")
            
  
            with Run().context(RunConfig(nranks=config.nranks, experiment=config.experiment, name='retrieval', overwrite = True)):
                
                config.checkpoint = model
                model.to('cpu')
                searcher = Searcher(index=config.index_name, config=config, checkpoint=model)
                queries = Queries(config.queries)
                count = 0
                while(count !=2):
                    print(f"Base model #", count)
                    ranking = searcher.search_all(queries, k=k)
                    count = count + 1
            timestamp()

            del searcher, queries, ranking
            gc.collect()
            
            with Run().context(RunConfig(nranks=config.nranks, experiment=config.experiment, name='retrieval', overwrite = True)):
                
                config.checkpoint = quantized_model
                quantized_model.to('cpu')
                searcher = Searcher(index=config.index_name, config=config, checkpoint=quantized_model)
                queries = Queries(config.queries)
                count = 0
                while(count !=2):
                    print(f"Quantized model #", count)
                    ranking = searcher.search_all(queries, k=k)
                    count = count + 1
                #ranking.save(f"msmarco.{use_iter}.nbits={config.nbits}.prune={prune_amount}.prune_type={prune_type}.ranking={k}.tsv")
                #ranking.save(retrieval_name)
            timestamp()

            del searcher, queries, ranking
            gc.collect()
             

        if do_eval:
            #!python -m utility.evaluate.msmarco_passages \
            #     --ranking "experiments/msmarco_{maxsteps_str}/retrieval/msmarco.{use_iter}.nbits={config.nbits}.prune={prune_amount}.prune_type={prune_type}.ranking={k}.tsv" \
            #     --qrels "../data/qrels.dev.tsv" > "experiments/msmarco_{maxsteps_str}/retrieval/msmarco.{use_iter}.nbits={config.nbits}.prune={prune_amount}.prune_type={prune_type}.ranking={k}.tsv.log"
            !python -m utility.evaluate.msmarco_passages \
                --ranking "experiments/{experiment}/none/retrieval/{retrieval_name}" \
                --qrels "../../kngo/data/qrels.dev.tsv" #> "experiments/{experiment}/retrieval/{retrieval_name}.log"
        del model,quantized_model
        gc.collect()

In [6]:
def quantization_data(config, quant_type, quant_Int):
    """Dynamically quantize the checkpoint at ``config.checkpoint`` and save one copy per module-type set.

    For each ``qconfig_spec`` in ``quant_type``, the checkpoint is loaded, quantized
    with ``torch.quantization.quantize_dynamic(..., dtype=quant_Int)`` and saved to
    ``f"{config.checkpoint}.quant=<dtype>.quant_type=<modules>"``. The dtype and
    module names are reduced to filesystem-safe tags, e.g. ``qint8`` and ``Linear``.

    Args:
        config: ColBERTConfig; ``checkpoint``, ``nranks`` and ``experiment`` are read.
        quant_type: iterable of ``qconfig_spec`` values, e.g. ``[{torch.nn.Linear}]``.
        quant_Int: target quantized dtype, e.g. ``torch.qint8``.
    """
    # str(torch.qint8) == "torch.qint8"; keep only the dtype name for the path.
    dtype_tag = str(quant_Int).replace("torch.", "")

    for q_type in quant_type:
        print(f"quantizing model with module types {q_type} to: {quant_Int}")
        with Run().context(RunConfig(nranks=config.nranks, experiment=config.experiment)):
            model = Checkpoint(config.checkpoint, colbert_config=config)

        # Dynamic quantization kernels are CPU-only, so quantize on CPU instead of
        # first moving the model to CUDA.
        model.to('cpu')
        quantized_model = torch.quantization.quantize_dynamic(model, q_type, dtype=quant_Int)

        # The repr of a set of classes ("{<class '...'>}") is not a usable path
        # component; join the class names instead.
        if q_type:
            modules_tag = "-".join(sorted(getattr(t, "__name__", str(t)) for t in q_type))
        else:
            modules_tag = "default"
        quantized_model.save(f"{config.checkpoint}.quant={dtype_tag}.quant_type={modules_tag}")

        del model, quantized_model
        gc.collect()

In [7]:

# Sweep settings for the dynamic-quantization experiment.
quantization_Int = [
    torch.qint8,        # target dtype(s) passed as `dtype=` to quantize_dynamic
]
quantization_Type = [
    {torch.nn.Linear},  # module-type set(s) passed as `qconfig_spec`
]



In [8]:

# Driver: build the base config, then run one quantization experiment per target dtype.
base_path = fr"experiments/"
checkpoint = fr"experiments/model_dump/colbertv2.0"

# Flags read by quantization_data_new.
do_retrieval, do_eval = True, True

# NOTE(review): quantization_data_new builds its own ColBERTConfig and ignores this
# one, so `ncells` and the ../kngo paths below do not reach it.
config = ColBERTConfig(
    # model / checkpoint
    checkpoint=checkpoint,
    bsize=64,
    amp=True,
    # data and output
    root=base_path,
    triples=r"../kngo/data/triples.train.small.id.json",
    collection=r"../kngo/data/collection.tsv",
    overwrite='resume',
    ncells=10,
    # distributed setup
    rank=0,
    nranks=1,
    gpus=1,
)

for q_Int in quantization_Int:
    quantization_data_new(config, quantization_Type, q_Int)
print("quantization experiment complete")



index_name= 
pruning model on prune type {<class 'torch.nn.modules.linear.Linear'>} to: torch.qint8
Size (MB): 438.393806
Size (MB): 181.584042
Apr 07 2023, 23:16:20
[Apr 07, 23:16:20] #> Loading collection...
0M 1M 2M 3M 4M 5M 6M 7M 8M 
[Apr 07, 23:16:43] #> Loading codec...
[Apr 07, 23:16:43] Loading decompress_residuals_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...
[Apr 07, 23:17:36] Loading packbits_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...
[Apr 07, 23:18:30] #> Loading IVF...
[Apr 07, 23:18:32] #> Loading doclens...


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 354/354 [00:00<00:00, 1198.63it/s]


[Apr 07, 23:18:33] #> Loading codes and residuals...


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 354/354 [01:31<00:00,  3.88it/s]


[Apr 07, 23:20:04] #> Loading the queries from ../kngo/data/queries.dev.tsv ...


FileNotFoundError: [Errno 2] No such file or directory: '../kngo/data/queries.dev.tsv'

In [None]:
# FIXME(review): `prune_experiment`, `prune_amount` and `prune_type` are not defined
# in any visible cell of this notebook (apparently carried over from a pruning
# notebook), so this cell raises NameError on a fresh kernel.
for p_amount in prune_amount:
    gc.collect()
    prune_experiment(
        prune_type,
        p_amount,
        maxsteps=10000,
        k=1000,
        do_train=False,
        do_index=False,
        do_retrieval=False,
        do_eval=True,
        nbits=2,
        use_full_data=False,
    )
print("!!!!all done!!!!!")