In [None]:
aa = torch.tensor([2,2,4])

torch.pow(aa, exponent=2)

In [3]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '2,3'
import pickle
import torch
import numpy as np
import glob
import json
from argparse import ArgumentParser
from itertools import chain
from tqdm import tqdm

import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict

def pickle_load(path):
    with open(path, 'rb') as f:
        obj = pickle.load(f)
    return obj

def convert_id(look_up_list):
    converted_look_up = {}
    for idx, ids in enumerate(look_up_list):
        converted_look_up[ids] = idx
    return converted_look_up

def load_trec(file_path):
    qid2dids = defaultdict(list)
    with open(file_path, "r", encoding="utf-8") as fi:
        for line in tqdm(fi):
            qid, did, _ = line.strip().split()
            qid2dids[qid].append(did)
    return qid2dids



class build_dataset(Dataset):
    def __init__(
        self, 
        qid2dids,
        qid2index,
        docid2index,
        q_vectors,
        d_vectors,
    ):
        self.qid2dids = qid2dids
        self.qid2index = qid2index
        self.docid2index = docid2index
        self.q_vectors = q_vectors
        self.d_vectors = d_vectors
    
        self.qids = list(qid2dids.keys())

    def __len__(self):
        return len(self.qids)

    def __getitem__(self, index):
        qid = self.qids[index]
        dids = self.qid2dids[qid]
        return get_item(
            qid,
            dids,
            qid2index=self.qid2index,
            docid2index=self.docid2index,
            q_vectors=self.q_vectors,
            d_vectors=self.d_vectors,
        )
    
    
def get_item(
    qid,
    dids,
    qid2index,
    docid2index,
    q_vectors,
    d_vectors,
):
    q_index = qid2index[qid]
    doc_indexs = [docid2index[did] for did in dids]
    
    q_vecs = q_vectors[q_index]
    doc_vecs = [d_vectors[index] for index in doc_indexs]
    
    return q_vecs, doc_vecs, qid, dids
            

def batchify_fct(batch):
    
    q_vecs = [ex[0] for ex in batch]
    doc_vecs = [ex[1] for ex in batch]
    qids = [ex[2] for ex in batch]
    dids = [ex[3] for ex in batch]
    
    q_tensor = torch.tensor(q_vecs)
    doc_tensor = torch.tensor(doc_vecs)

    return q_tensor, doc_tensor, qids, dids
    
    
if __name__ == "__main__":
    
    DATA_DIR = "/data/private/sunsi/experiments/cocondenser/results/inference.iter-1.self-neg-cocondenser-20k"
    output_path = os.path.join(DATA_DIR, "dev-mmr.jsonl")
    
    trec_path = os.path.join(DATA_DIR, "dev.rank.tsv")
    query_reps_path = os.path.join(DATA_DIR, "query/qry.pt")
    passage_reps_path = os.path.join(DATA_DIR, "corpus/*.pt")
    ## load query2vec
    q_reps, q_lookup = pickle_load(query_reps_path)
    
    ## load passage2vec
    index_files = glob.glob(passage_reps_path)
    print(f'Pattern match found {len(index_files)} files; loading them into index.')
    
    ## load retrieval results
    qid2dids = load_trec(trec_path)
    
    p_reps_0, p_lookup_0 = pickle_load(index_files[0])
    shards = chain([(p_reps_0, p_lookup_0)], map(pickle_load, index_files[1:]))
    if len(index_files) > 1:
        shards = tqdm(shards, desc='Loading shards into index', total=len(index_files))    
    
    p_reps = []
    look_up = []
    for _p_reps, p_lookup in shards:
        p_reps.append(_p_reps)
        look_up += p_lookup
      
    p_reps = np.concatenate(p_reps, axis=0)
    
    ## id2index
    qid2index = convert_id(q_lookup)
    docid2index = convert_id(look_up)
    
    ## dataset
    encode_dataset = build_dataset(
        qid2dids=qid2dids,
        qid2index=qid2index,
        docid2index=docid2index,
        q_vectors=q_reps,
        d_vectors=p_reps,
    )
    sampler = torch.utils.data.sampler.SequentialSampler(encode_dataset)
    
    
    
    batch_size = 128
    encode_data_loader = DataLoader(
        encode_dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=16,
        collate_fn=batchify_fct,
        pin_memory=True,
    )
    
    
    with open(output_path, "w", encoding="utf-8") as fw:
    
        batch_iterator = tqdm(encode_data_loader, desc="Iteration")
        for batch in tqdm(batch_iterator):
            q_tensor, doc_tensor, qids, dids = batch

            q_tensor = q_tensor.cuda() # bz * 768
            doc_tensor = doc_tensor.cuda() # bz * topk * 768

            qd_score = torch.sum(q_tensor.unsqueeze(1) * doc_tensor, dim=-1)
            norm_qd_score = F.normalize(qd_score, p=1, dim=-1)

            cos_dd_score = torch.cosine_similarity(doc_tensor.unsqueeze(1), doc_tensor.unsqueeze(2), dim=-1)

            norm_qd_score = norm_qd_score.cpu().detach().tolist()
            cos_dd_score = cos_dd_score.cpu().detach().tolist()

            for batch_id, (qid, sub_dids) in enumerate(zip(qids, dids)):
                save_item = {
                    "qid":qid,
                    "dids":sub_dids,
                    "qd_score":norm_qd_score,
                    "dd_score":cos_dd_score,
                }
                fw.write(json.dumps(save_item)+"\n")


Pattern match found 10 files; loading them into index.


Loading shards into index: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:17<00:00,  1.77s/it]


Iteration:   0%|                                                                                                                                       | 0/7 [00:00<?, ?it/s]
  0%|                                                                                                                                                  | 0/7 [00:00<?, ?it/s][A
Iteration:  14%|██████████████████▏                                                                                                            | 1/7 [01:08<06:53, 68.96s/it][A
Iteration:  29%|████████████████████████████████████▎                                                                                          | 2/7 [02:13<05:30, 66.11s/it][A
Iteration:  43%|██████████████████████████████████████████████████████▍                                                                        | 3/7 [03:18<04:22, 65.57s/it][A
Iteration:  57%|████████████████████████████████████████████████████████████████████████▌                             