In [22]:
import numpy as np
import torch
import transformers
import faiss
from transformers import PreTrainedModel, AutoConfig
from dense_retriever_utils import get_model, get_tokenizer

In [23]:
import argparse, shlex, json

def config(in_program_call=None):
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str)
    parser.add_argument('--batch_size', type=int, default=48)
    parser.add_argument('--source_file', default='docprompting_data/conala/conala_nl.txt')
    parser.add_argument('--target_file', default='docprompting_data/conala/python_manual_firstpara.tok.txt')
    parser.add_argument('--source_embed_save_file', default='docprompting_data/conala/.tmp/src_embedding')
    parser.add_argument('--target_embed_save_file', default='docprompting_data/conala/.tmp/tgt_embedding')
    parser.add_argument('--save_file', default='[REPLACE]data/conala/simcse.[MODEL].[SOURCE].[TARGET].[POOLER].t[TOPK].json')
    parser.add_argument('--top_k', type=int, default=200)
    parser.add_argument('--cpu', action='store_true')
    parser.add_argument('--pooler', choices=('cls', 'cls_before_pooler'), default='cls')
    parser.add_argument('--log_level', default='verbose')
    parser.add_argument('--nl_cm_folder', default='docprompting_data/conala/nl.cm')
    parser.add_argument('--sim_func', default='cls_distance.cosine', choices=('cls_distance.cosine', 'cls_distance.l2', 'bertscore'))
    parser.add_argument('--num_layers', type=int, default=12)
    parser.add_argument('--origin_mode', action='store_true')
    parser.add_argument('--oracle_eval_file', default='docprompting_data/conala/cmd_dev.oracle_man.full.json')
    parser.add_argument('--eval_hit', action='store_true')
    parser.add_argument('--normalize_embed', action='store_true')



    args = parser.parse_args() if in_program_call is None else parser.parse_args(shlex.split(in_program_call))

    args.source_idx_file = args.source_file.replace(".txt", ".id")
    args.target_idx_file = args.target_file.replace(".txt", ".id")

    print(json.dumps(vars(args), indent=2))
    return args

in_program_call = ("--model_name 'neulab/docprompting-codet5-python-doc-retriever'"
                  " --save_file docprompting_data/conala/retrieval_results_test.json")
args = config(in_program_call)

{
  "model_name": "neulab/docprompting-codet5-python-doc-retriever",
  "batch_size": 48,
  "source_file": "docprompting_data/conala/conala_nl.txt",
  "target_file": "docprompting_data/conala/python_manual_firstpara.tok.txt",
  "source_embed_save_file": "docprompting_data/conala/.tmp/src_embedding",
  "target_embed_save_file": "docprompting_data/conala/.tmp/tgt_embedding",
  "save_file": "docprompting_data/conala/retrieval_results_test.json",
  "top_k": 200,
  "cpu": false,
  "pooler": "cls",
  "log_level": "verbose",
  "nl_cm_folder": "docprompting_data/conala/nl.cm",
  "sim_func": "cls_distance.cosine",
  "num_layers": 12,
  "origin_mode": false,
  "oracle_eval_file": "docprompting_data/conala/cmd_dev.oracle_man.full.json",
  "eval_hit": false,
  "normalize_embed": false,
  "source_idx_file": "docprompting_data/conala/conala_nl.id",
  "target_idx_file": "docprompting_data/conala/python_manual_firstpara.tok.id"
}


In [24]:
class RetrievalModel(PreTrainedModel):
    def __init__(self, config, model_name, tokenizer, model_args, batch_size=64):
        super().__init__(config)
        self.model_args = model_args
        self.batch_size = batch_size
        self.model_name = model_name
        self.tokenizer = get_tokenizer(model_name, use_fast=True) if tokenizer is None else tokenizer
        self.model = get_model(self.model_name)

In [25]:
class CodeT5Retriever:
    def __init__(self, args):
        self.args = args
        self.model_name = self.args.model_name
        self.tokenizer = transformers.RobertaTokenizer.from_pretrained(self.model_name)

        config = AutoConfig.from_pretrained(self.model_name)
        class Dummy():
            def __init__(self, sim_func):
                self.sim_func = sim_func
        model_arg = Dummy(args.sim_func)

        self.model = RetrievalModel(model_name=self.model_name,
                                    tokenizer=self.tokenizer,
                                    config=config,
                                    model_args=model_arg)
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.model.eval()
        self.model.to(self.device)

In [26]:
searcher = CodeT5Retriever(args)

In [56]:
normaliza_embed = True
text_file = args.source_file
save_file = args.source_embed_save_file
# text_file = args.target_file
# save_file = args.target_embed_save_file
with open(text_file, 'r') as f:
    dataset = []
    for line in f:
        dataset.append(line.strip())
print(f'number of sentences: {len(dataset)}')
dataset = dataset[0:128]

number of sentences: 2879


In [61]:
import os

# for each sentence, get n(token)*768 embedding, then do a mean pooling
bs = 128
with torch.no_grad():
    all_embeddings = []
    for i in range(0, len(dataset), bs):
        batch = dataset[i: i+bs]
        
        # pad batch
        sent_features = searcher.tokenizer(batch, add_special_tokens=True, 
                                           max_length=searcher.tokenizer.model_max_length, truncation=True)
        arr = sent_features['input_ids']
        print(arr)
        lens = torch.LongTensor([len(a) for a in arr])
        max_len = lens.max().item()
        padded = torch.ones(len(arr), max_len, dtype=torch.long) * searcher.tokenizer.pad_token_id
        mask = torch.zeros(len(arr), max_len, dtype=torch.long)
        for i, a in enumerate(arr):
            padded[i, : lens[i]] = torch.tensor(a, dtype=torch.long)
            mask[i, : lens[i]] = 1
        padded_batch = {'input_ids': padded, 'attention_mask': mask, 'lengths': lens}
        for k in padded_batch:
            if isinstance(padded_batch[k], torch.Tensor):
                padded_batch[k] = padded_batch[k].to(searcher.device)
        
        # get embedding
        input_ids = padded_batch['input_ids']
        attention_mask = padded_batch['attention_mask']
        lengths = padded_batch['lengths']
        output = searcher.model.model(input_ids, attention_mask=attention_mask, output_hidden_states=False)
        emb = output['last_hidden_state']
        emb.masked_fill_(~attention_mask.bool().unsqueeze(-1), 0)
        # base = torch.arange(max_len, dtype=torch.long).expand(len(lengths), max_len).to(lengths.device)
        # pad_mask = base < lengths.unsqueeze(1) # pad token set to false
        # emb = (emb*pad_mask.unsqueeze(-1)).sum(dim=1) / pad_mask.sum(-1).unsqueeze(-1)
        emb = emb.sum(dim=1) / lengths.unsqueeze(-1)
        if args.normalize_embed:
            emb = emb / emb.norm(dim=1, keepdim=True)
            
        all_embeddings.append(emb)
        
    all_embeddings = np.concatenate(all_embeddings, axis=0)
    print(f"done embedding: {all_embeddings.shape}")
    
#     if not os.path.exists(os.path.dirname(save_file)):
#             os.makedirs(os.path.dirname(save_file))
#     np.save(save_file, all_embeddings)

[[1, 1788, 666, 1375, 267, 541, 2183, 376, 68, 4191, 12899, 2184, 434, 618, 8005, 797, 2], [1, 399, 6159, 326, 1686, 434, 279, 24817, 1296, 315, 2219, 79, 2292, 1188, 271, 2163, 16, 271, 2499, 1713, 1842, 18, 5830, 405, 1842, 22, 18, 5830, 6, 2], [1, 29480, 2529, 364, 3419, 1375, 7407, 68, 4752, 3974, 1865, 19440, 279, 5823, 1046, 11416, 1520, 2], [1, 29480, 26328, 1620, 358, 2623, 296, 3789, 461, 11, 2], [1, 968, 12037, 3756, 858, 733, 598, 2142, 1879, 326, 2595, 598, 1967, 1057, 1257, 3470, 225, 12170, 1375, 2180, 68, 2], [1, 7074, 3207, 1375, 24339, 68, 487, 14476, 585, 296, 3459, 18, 6446, 11, 2], [1, 7074, 13892, 2667, 358, 1316, 585, 1375, 3459, 18, 6446, 68, 622, 279, 7861, 434, 1375, 19249, 16361, 68, 2], [1, 1623, 364, 3936, 1936, 296, 4709, 14361, 1188, 11, 316, 533, 1375, 3813, 1585, 68, 6508, 394, 980, 3351, 2337, 82, 11, 2], [1, 1374, 4412, 5600, 316, 10681, 316, 279, 533, 1375, 2503, 353, 1300, 404, 471, 333, 353, 1300, 11201, 68, 2], [1, 6164, 310, 585, 1375, 768, 68, 62

done embedding: (128, 768)


In [89]:
fuck_emb = auth_output[0]
torch.equal(fuck_emb, output[0])

True

In [92]:
fuck_emb = auth_output[0]
fuck_emb.masked_fill_(~attention_mask.bool().unsqueeze(-1), 0)
fuck_emb = fuck_emb.sum(dim=1) / lengths.unsqueeze(-1)
if args.normalize_embed:
    fuck_emb = fuck_emb / fuck_emb.norm(dim=1, keepdim=True)
torch.equal(fuck_emb, emb)

True

In [94]:
auth_emb = auth_output[0]
auth_emb.masked_fill_(~attention_mask.bool().unsqueeze(-1), 0)
auth_emb = auth_emb.sum(dim=1) / lengths.unsqueeze(-1)
if args.normalize_embed:
    auth_emb = auth_emb / auth_emb.norm(dim=1, keepdim=True)
torch.equal(auth_emb, emb)

True

In [96]:
print(args.normalize_embed)

False


In [77]:
torch.equal(emb, auth_emb)

False

In [78]:
emb

tensor([[-0.0481, -0.2097, -0.1796,  ..., -0.0620,  0.1195,  0.0246],
        [ 0.0107, -0.1357,  0.1582,  ..., -0.0061,  0.0038, -0.1792],
        [-0.0133, -0.0564,  0.2823,  ...,  0.0524, -0.1482, -0.0964],
        ...,
        [-0.0522, -0.1945, -0.2370,  ..., -0.0774,  0.2625,  0.0512],
        [-0.1862, -0.0600, -0.1827,  ..., -0.2017,  0.2697, -0.0348],
        [-0.1862, -0.0600, -0.1827,  ..., -0.2017,  0.2697, -0.0348]])

In [79]:
auth_emb

tensor([[-0.0152, -0.0661, -0.0566,  ..., -0.0195,  0.0377,  0.0077],
        [ 0.0034, -0.0432,  0.0504,  ..., -0.0020,  0.0012, -0.0571],
        [-0.0036, -0.0153,  0.0767,  ...,  0.0142, -0.0403, -0.0262],
        ...,
        [-0.0135, -0.0504, -0.0614,  ..., -0.0201,  0.0680,  0.0133],
        [-0.0450, -0.0145, -0.0441,  ..., -0.0487,  0.0651, -0.0084],
        [-0.0450, -0.0145, -0.0441,  ..., -0.0487,  0.0651, -0.0084]])

In [76]:
auth_emb.shape

torch.Size([128, 768])

In [70]:
auth_output[0].shape

torch.Size([128, 45, 768])

In [71]:
output[0].shape

torch.Size([128, 45, 768])

In [72]:
torch.equal(auth_output[0], output[0])

True

In [65]:
torch.equal(input_ids, auth_input_ids)

True

In [66]:
torch.equal(attention_mask, auth_attention_mask)

True

In [35]:
auth_src_embed.shape

(2879, 768)

In [40]:
auth_emb[0]

array([-1.51683511e-02, -6.61402717e-02, -5.66248782e-02,  2.51631625e-02,
       -4.06907052e-02, -2.15809513e-02, -6.44470900e-02, -3.19580697e-02,
        3.80077437e-02, -2.18240060e-02,  2.73901541e-02,  5.24688326e-02,
       -1.23803802e-02, -8.38407036e-03,  4.24143206e-03, -1.99274137e-03,
        4.68592420e-02,  4.56570536e-02, -8.80548637e-03, -5.32950237e-02,
       -1.20829698e-02, -2.99830791e-02,  3.33988778e-02, -3.56720239e-02,
        8.65500513e-03, -7.42584514e-03, -2.69307569e-02, -5.50832897e-02,
       -5.25289476e-02,  1.27503369e-02, -9.04892758e-03,  8.46601836e-03,
       -2.69015692e-02,  2.20235735e-02, -1.11881895e-02,  1.41660534e-02,
       -2.83475574e-02,  2.53256317e-02, -2.60235500e-02, -3.99435908e-02,
       -2.30046585e-02,  6.42091557e-02, -1.00606289e-02,  5.75726619e-03,
        7.37272725e-02, -3.19232941e-02,  3.95331867e-02, -5.09065688e-02,
        2.80672275e-02,  5.77943847e-02, -2.93761343e-02,  1.89341400e-02,
        4.92721051e-03, -

In [36]:
auth_src_embed[0]

array([-4.80976552e-02, -2.09725618e-01, -1.79553062e-01,  7.97904208e-02,
       -1.29027039e-01, -6.84315115e-02, -2.04356685e-01, -1.01336539e-01,
        1.20519578e-01, -6.92022145e-02,  8.68520364e-02,  1.66374564e-01,
       -3.92572172e-02, -2.65852306e-02,  1.34492498e-02, -6.31882716e-03,
        1.48586988e-01,  1.44774944e-01, -2.79215090e-02, -1.68994352e-01,
       -3.83141525e-02, -9.50739980e-02,  1.05905227e-01, -1.13113195e-01,
        2.74443459e-02, -2.35467739e-02, -8.53953212e-02, -1.74664810e-01,
       -1.66565180e-01,  4.04303223e-02, -2.86934432e-02,  2.68450826e-02,
       -8.53027701e-02,  6.98350295e-02, -3.54768746e-02,  4.49194461e-02,
       -8.98878872e-02,  8.03055987e-02, -8.25186446e-02, -1.26658008e-01,
       -7.29459748e-02,  2.03602210e-01, -3.19014676e-02,  1.82558410e-02,
        2.33783424e-01, -1.01226270e-01,  1.25356644e-01, -1.61420748e-01,
        8.89989808e-02,  1.83261469e-01, -9.31494236e-02,  6.00386783e-02,
        1.56237995e-02, -

In [37]:
all_embeddings[0]

array([-4.80976552e-02, -2.09725618e-01, -1.79553062e-01,  7.97904208e-02,
       -1.29027039e-01, -6.84315115e-02, -2.04356685e-01, -1.01336539e-01,
        1.20519578e-01, -6.92022145e-02,  8.68520364e-02,  1.66374564e-01,
       -3.92572172e-02, -2.65852306e-02,  1.34492498e-02, -6.31882716e-03,
        1.48586988e-01,  1.44774944e-01, -2.79215090e-02, -1.68994352e-01,
       -3.83141525e-02, -9.50739980e-02,  1.05905227e-01, -1.13113195e-01,
        2.74443459e-02, -2.35467739e-02, -8.53953212e-02, -1.74664810e-01,
       -1.66565180e-01,  4.04303223e-02, -2.86934432e-02,  2.68450826e-02,
       -8.53027701e-02,  6.98350295e-02, -3.54768746e-02,  4.49194461e-02,
       -8.98878872e-02,  8.03055987e-02, -8.25186446e-02, -1.26658008e-01,
       -7.29459748e-02,  2.03602210e-01, -3.19014676e-02,  1.82558410e-02,
        2.33783424e-01, -1.01226270e-01,  1.25356644e-01, -1.61420748e-01,
        8.89989808e-02,  1.83261469e-01, -9.31494236e-02,  6.00386783e-02,
        1.56237995e-02, -

In [73]:
# retrieve
source_id_map, target_id_map = {}, {}
with open(args.source_idx_file, 'r') as f:
    for idx, line in enumerate(f):
        source_id_map[idx] = line.strip()
with open(args.target_idx_file, 'r') as f:
    for idx, line in enumerate(f):
        target_id_map[idx] = line.strip()
source_embed = np.load(args.source_embed_save_file + '.npy')
target_embed = np.load(args.target_embed_save_file + '.npy')
assert len(source_id_map) == source_embed.shape[0]
assert len(target_id_map) == target_embed.shape[0]
indexer = faiss.IndexFlatIP(target_embed.shape[1])
indexer.add(target_embed)
print(source_embed.shape, target_embed.shape)
D, I = indexer.search(source_embed, args.top_k)

(2879, 768) (30755, 768)


In [82]:
results = dict()
for source_idx, (dist, retrieved_index) in enumerate(zip(D, I)):
    source_id = source_id_map[source_idx]
    results[source_id] = {}
    retrieved_target_id = [target_id_map[x] for x in retrieved_index]
    results[source_id]['retrieved'] = retrieved_target_id
    results[source_id]['score'] = dist.tolist()
    
with open(args.save_file, 'w+') as f:
    json.dump(results, f, indent=2)

In [86]:
# evaluate
d = json.load(open(args.oracle_eval_file, 'r'))
gold = [item['oracle_man'] for item in d]
r_d = json.load(open(args.save_file, 'r'))
pred = [r_d[x['question_id']]['retrieved'] for x in d]
top_k = [1, 3, 5, 8, 10, 12, 15, 20, 30, 50, 100, 200]

def calc_recall(src, pred, top_k, print_result=True):
    recall_n = {x: 0 for x in top_k}
    precision_n = {x: 0 for x in top_k}

    for s, p in zip(src, pred):
        # cmd_name = s['cmd_name']
        oracle_man = s
        pred_man = p

        for tk in recall_n.keys():
            cur_result_vids = pred_man[:tk]
            cur_hit = sum([x in cur_result_vids for x in oracle_man])
            # recall_n[tk] += cur_hit / (len(oracle_man) + 1e-10)
            recall_n[tk] += cur_hit / (len(oracle_man)) if len(oracle_man) else 1
            precision_n[tk] += cur_hit / tk
    recall_n = {k: v / len(pred) for k, v in recall_n.items()}
    precision_n = {k: v / len(pred) for k, v in precision_n.items()}

    if print_result:
        for k in sorted(recall_n.keys()):
            print(f"{recall_n[k] :.3f}", end="\t")
        print()
        for k in sorted(precision_n.keys()):
            print(f"{precision_n[k] :.3f}", end="\t")
        print()
        for k in sorted(recall_n.keys()):
            print(f"{2 * precision_n[k] * recall_n[k] / (precision_n[k] + recall_n[k] + 1e-10) :.3f}", end="\t")
        print()

    return {'recall': recall_n, 'precision': precision_n}

metrics = calc_recall(gold, pred, top_k)
print(metrics)

0.165	0.326	0.423	0.506	0.558	0.579	0.604	0.667	0.737	0.812	0.879	0.949	
0.254	0.182	0.140	0.105	0.095	0.083	0.069	0.059	0.044	0.029	0.016	0.009	
0.200	0.234	0.211	0.174	0.162	0.144	0.124	0.109	0.084	0.056	0.032	0.017	
{'recall': {1: 0.1654228855721393, 3: 0.32645107794361544, 5: 0.4234660033167495, 8: 0.5059701492537313, 10: 0.5582089552238806, 12: 0.5785240464344941, 15: 0.6042288557213931, 20: 0.6669983416252073, 30: 0.7372305140961858, 50: 0.8124378109452735, 100: 0.8793532338308457, 200: 0.949419568822554}, 'precision': {1: 0.2537313432835821, 3: 0.18242122719734652, 5: 0.14029850746268627, 8: 0.10509950248756218, 10: 0.0950248756218905, 12: 0.0825041459369819, 15: 0.0693200663349916, 20: 0.05920398009950251, 30: 0.044278606965174064, 50: 0.029154228855721342, 100: 0.016218905472636776, 200: 0.008631840796019878}}


In [16]:
from dense_retriever import dense_retriever_config
model_name = 'Salesforce/codet5-base'
in_program_call = f"--dataset tldr --model_name {model_name}"
args = dense_retriever_config(in_program_call)

{
  "dataset": "tldr",
  "model_name": "Salesforce/codet5-base",
  "batch_size": 128,
  "top_k": 200,
  "sim_func": "cls_distance.cosine",
  "normalize_embed": false,
  "save_file": "docprompting_data/tldr/ret_results_codet5-base.json",
  "conala_qs_embed_save_file": "docprompting_data/conala/.tmp/src_embedding_codet5-base",
  "conala_doc_firstpara_save_file": "docprompting_data/conala/.tmp/tgt_embedding_codet5-base",
  "tldr_qs_embed_save_file": "docprompting_data/tldr/.tmp/src_embedding_codet5-base",
  "tldr_doc_whole_embed_save_file": "docprompting_data/tldr/.tmp/whole_embedding_codet5-base",
  "tldr_doc_line_embed_save_file": "docprompting_data/tldr/.tmp/line_embedding_codet5-base"
}


In [17]:
tldr_args = tldr_config()
tldr_key_list_whole = list(json.load(open(tldr_args.doc_whole, 'r')).keys())
source_embed = np.load(args.tldr_qs_embed_save_file + '.npy')
result_whole_save_file = args.save_file.replace('.json', '_whole.json')
doc_whole_embed = np.load(args.tldr_doc_whole_embed_save_file + '.npy')

In [22]:
import faiss
target_embed = doc_whole_embed
save_file = result_whole_save_file
target_key_list = tldr_key_list_whole

indexer = faiss.IndexFlatIP(target_embed.shape[1])
indexer.add(target_embed)
print(source_embed.shape, target_embed.shape)
D, I = indexer.search(source_embed, args.top_k)

(1845, 768) (1927, 768)


In [29]:
results = list()
dist, retrieved_index = D[0], I[0]
retrieved_target_key = [target_key_list[x] for x in retrieved_index]
for key, distance in zip(retrieved_target_key, dist):
    results.append(dict(lib_key=key, score=distance))
results

[{'lib_key': 'wp', 'score': 2.4240243},
 {'lib_key': 'p5', 'score': 2.2991867},
 {'lib_key': 'qpdf', 'score': 2.1964781},
 {'lib_key': 'asciiart', 'score': 2.1012466},
 {'lib_key': 'po4a-gettextize', 'score': 2.068434},
 {'lib_key': 'wpa_passphrase', 'score': 2.0590746},
 {'lib_key': 'wordgrinder', 'score': 2.0583205},
 {'lib_key': 'steghide', 'score': 2.0401578},
 {'lib_key': 'gofmt', 'score': 2.0368047},
 {'lib_key': 'lorem', 'score': 2.0340328},
 {'lib_key': 'steam', 'score': 2.0106273},
 {'lib_key': 'sphinx-build', 'score': 2.009172},
 {'lib_key': 'dexdump', 'score': 2.006337},
 {'lib_key': 'solo', 'score': 2.0050251},
 {'lib_key': 'banner', 'score': 1.9967657},
 {'lib_key': 'debman', 'score': 1.990351},
 {'lib_key': 'guetzli', 'score': 1.9849534},
 {'lib_key': 'ngrok', 'score': 1.982582},
 {'lib_key': 'assimp', 'score': 1.9684796},
 {'lib_key': 'ember', 'score': 1.9613961},
 {'lib_key': 'pdftotext', 'score': 1.9578973},
 {'lib_key': 'alex', 'score': 1.9551215},
 {'lib_key': 'tesse

In [31]:
import numpy as np
auth_src_embed_file = 'docprompting_data/conala/.tmp/src_embedding.npy'
src_embed_file = 'docprompting_data/conala/.tmp/src_embedding_docprompting-codet5-python-doc-retriever.npy'
auth_src_embed = np.load(auth_src_embed_file)
src_embed = np.load(src_embed_file)

(2879, 768)

In [19]:
def check_consistency_with_auth_src(embed):
    assert embed[0].shape == auth_src_embed[0].shape
    return np.allclose(embed[0], auth_src_embed[0])
    
def check_consistency_with_src(embed):
    assert embed[0].shape == src_embed[0].shape
    return np.allclose(embed[0], src_embed[0])

In [4]:
from dense_retriever import dense_retriever_config
model_name = 'neulab/docprompting-codet5-python-doc-retriever'
in_program_call = f"--dataset conala \
                    --model_name {model_name} \
                    --sim_func cls_distance.cosine \
                    --normalize_embed"
ret_args = dense_retriever_config(in_program_call)

{
  "dataset": "conala",
  "model_name": "neulab/docprompting-codet5-python-doc-retriever",
  "batch_size": 128,
  "top_k": 200,
  "sim_func": "cls_distance.cosine",
  "normalize_embed": true,
  "save_file": "docprompting_data/conala/ret_results_docprompting-codet5-python-doc-retriever.json",
  "conala_qs_embed_save_file": "docprompting_data/conala/.tmp/src_embedding_docprompting-codet5-python-doc-retriever",
  "conala_doc_firstpara_embed_save_file": "docprompting_data/conala/.tmp/tgt_embedding_docprompting-codet5-python-doc-retriever"
}


In [5]:
from dense_encoder import CodeT5Retriever, denseEncoder

auth_encoder = CodeT5Retriever(ret_args)
encoder = denseEncoder(ret_args)

In [42]:
from dataset_utils import conala_config
conala_args = conala_config()
with open(conala_args.qs_file, "r") as f:
    dataset = []
    for line in f:
        dataset.append(line.strip())
print(len(dataset))

2879


In [60]:
import torch
with torch.no_grad():
    batch = dataset[0:ret_args.batch_size]
    auth_sent_features = auth_encoder.tokenizer(batch, add_special_tokens=True,
                                                   max_length=auth_encoder.tokenizer.model_max_length, truncation=True)
    # sent_features = encoder.tokenizer(batch, add_special_tokens=True, max_length=encoder.tokenizer.model_max_length, truncation=True)
    # assert auth_sent_features['input_ids'] == sent_features['input_ids']
    print(auth_sent_features['input_ids'])

[[1, 1788, 666, 1375, 267, 541, 2183, 376, 68, 4191, 12899, 2184, 434, 618, 8005, 797, 2], [1, 399, 6159, 326, 1686, 434, 279, 24817, 1296, 315, 2219, 79, 2292, 1188, 271, 2163, 16, 271, 2499, 1713, 1842, 18, 5830, 405, 1842, 22, 18, 5830, 6, 2], [1, 29480, 2529, 364, 3419, 1375, 7407, 68, 4752, 3974, 1865, 19440, 279, 5823, 1046, 11416, 1520, 2], [1, 29480, 26328, 1620, 358, 2623, 296, 3789, 461, 11, 2], [1, 968, 12037, 3756, 858, 733, 598, 2142, 1879, 326, 2595, 598, 1967, 1057, 1257, 3470, 225, 12170, 1375, 2180, 68, 2], [1, 7074, 3207, 1375, 24339, 68, 487, 14476, 585, 296, 3459, 18, 6446, 11, 2], [1, 7074, 13892, 2667, 358, 1316, 585, 1375, 3459, 18, 6446, 68, 622, 279, 7861, 434, 1375, 19249, 16361, 68, 2], [1, 1623, 364, 3936, 1936, 296, 4709, 14361, 1188, 11, 316, 533, 1375, 3813, 1585, 68, 6508, 394, 980, 3351, 2337, 82, 11, 2], [1, 1374, 4412, 5600, 316, 10681, 316, 279, 533, 1375, 2503, 353, 1300, 404, 471, 333, 353, 1300, 11201, 68, 2], [1, 6164, 310, 585, 1375, 768, 68, 62

In [43]:
def tmp(sent_features, encoder):
    arr = sent_features['input_ids']
    # pad batch
    lens = torch.LongTensor([len(a) for a in arr])
    max_len = lens.max().item()
    padded = torch.ones(len(arr), max_len, dtype=torch.long) * encoder.tokenizer.pad_token_id
    mask = torch.zeros(len(arr), max_len, dtype=torch.long)
    for i, a in enumerate(arr):
        padded[i, : lens[i]] = torch.tensor(a, dtype=torch.long)
        mask[i, : lens[i]] = 1
    padded_batch = {'input_ids': padded, 'attention_mask': mask, 'lengths': lens}
    for key in padded_batch:
        if isinstance(padded_batch[key], torch.Tensor):
            padded_batch[key] = padded_batch[key].to(encoder.device)
        
    return padded_batch

auth_padded_batch = tmp(auth_sent_features, auth_encoder)
auth_input_ids, auth_attention_mask, auth_lengths = auth_padded_batch['input_ids'], auth_padded_batch['attention_mask'], auth_padded_batch['lengths']
input_ids, attention_mask, lengths = tmp(sent_features, encoder)
# print(auth_input_ids)
# print(input_ids)
# print(auth_attention_mask)
# print(attention_mask)
# print(auth_lengths)
# print(lengths)

In [73]:
with torch.no_grad():
    auth_output = auth_encoder.model.model(auth_input_ids, attention_mask=auth_attention_mask, output_hidden_states=False)
    auth_emb = auth_output['last_hidden_state']
#     output = encoder.model(input_ids, attention_mask=attention_mask, output_hidden_states=False)
#     embed = output['last_hidden_state']

In [50]:
# auth_emb = auth_emb.detach().numpy()
# emb = emb.detach().numpy()
# print(np.allclose(auth_emb, emb))

AttributeError: 'numpy.ndarray' object has no attribute 'detach'

In [51]:
print(check_consistency_with_auth_src(emb))

False


In [52]:
print(check_consistency_with_src(emb))

False


In [55]:
src_embed[0]

array([-1.51683511e-02, -6.61402717e-02, -5.66248782e-02,  2.51631625e-02,
       -4.06907052e-02, -2.15809513e-02, -6.44470900e-02, -3.19580697e-02,
        3.80077437e-02, -2.18240060e-02,  2.73901541e-02,  5.24688326e-02,
       -1.23803802e-02, -8.38407036e-03,  4.24143206e-03, -1.99274137e-03,
        4.68592420e-02,  4.56570536e-02, -8.80548637e-03, -5.32950237e-02,
       -1.20829698e-02, -2.99830791e-02,  3.33988778e-02, -3.56720239e-02,
        8.65500513e-03, -7.42584514e-03, -2.69307569e-02, -5.50832897e-02,
       -5.25289476e-02,  1.27503369e-02, -9.04892758e-03,  8.46601836e-03,
       -2.69015692e-02,  2.20235735e-02, -1.11881895e-02,  1.41660534e-02,
       -2.83475574e-02,  2.53256317e-02, -2.60235500e-02, -3.99435908e-02,
       -2.30046585e-02,  6.42091557e-02, -1.00606289e-02,  5.75726619e-03,
        7.37272725e-02, -3.19232941e-02,  3.95331867e-02, -5.09065688e-02,
        2.80672275e-02,  5.77943847e-02, -2.93761343e-02,  1.89341400e-02,
        4.92721051e-03, -

In [53]:
auth_emb[0]

array([-1.51683511e-02, -6.61402717e-02, -5.66248782e-02,  2.51631625e-02,
       -4.06907052e-02, -2.15809513e-02, -6.44470900e-02, -3.19580697e-02,
        3.80077437e-02, -2.18240060e-02,  2.73901541e-02,  5.24688326e-02,
       -1.23803802e-02, -8.38407036e-03,  4.24143206e-03, -1.99274137e-03,
        4.68592420e-02,  4.56570536e-02, -8.80548637e-03, -5.32950237e-02,
       -1.20829698e-02, -2.99830791e-02,  3.33988778e-02, -3.56720239e-02,
        8.65500513e-03, -7.42584514e-03, -2.69307569e-02, -5.50832897e-02,
       -5.25289476e-02,  1.27503369e-02, -9.04892758e-03,  8.46601836e-03,
       -2.69015692e-02,  2.20235735e-02, -1.11881895e-02,  1.41660534e-02,
       -2.83475574e-02,  2.53256317e-02, -2.60235500e-02, -3.99435908e-02,
       -2.30046585e-02,  6.42091557e-02, -1.00606289e-02,  5.75726619e-03,
        7.37272725e-02, -3.19232941e-02,  3.95331867e-02, -5.09065688e-02,
        2.80672275e-02,  5.77943847e-02, -2.93761343e-02,  1.89341400e-02,
        4.92721051e-03, -