In [70]:
import numpy as np
import torch
import transformers
import faiss
from transformers import PreTrainedModel, AutoConfig
from dense_retriever_utils import get_model, get_tokenizer

In [84]:
import argparse, shlex, json

def config(in_program_call=None):
    """Parse the retrieval settings and derive the id-file paths.

    Args:
        in_program_call: optional shell-style argument string. When given it is
            split with ``shlex.split`` and parsed instead of ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` with two extra attributes,
        ``source_idx_file`` and ``target_idx_file``, obtained by replacing
        ``.txt`` with ``.id`` in the source/target file paths. The namespace is
        also printed as indented JSON.
    """
    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        ('--model_name', dict(type=str)),
        ('--batch_size', dict(type=int, default=48)),
        ('--source_file', dict(default='docprompting_data/conala/conala_nl.txt')),
        ('--target_file', dict(default='docprompting_data/conala/python_manual_firstpara.tok.txt')),
        ('--source_embed_save_file', dict(default='docprompting_data/conala/.tmp/src_embedding')),
        ('--target_embed_save_file', dict(default='docprompting_data/conala/.tmp/tgt_embedding')),
        ('--save_file', dict(default='[REPLACE]data/conala/simcse.[MODEL].[SOURCE].[TARGET].[POOLER].t[TOPK].json')),
        ('--top_k', dict(type=int, default=200)),
        ('--cpu', dict(action='store_true')),
        ('--pooler', dict(choices=('cls', 'cls_before_pooler'), default='cls')),
        ('--log_level', dict(default='verbose')),
        ('--nl_cm_folder', dict(default='docprompting_data/conala/nl.cm')),
        ('--sim_func', dict(default='cls_distance.cosine',
                            choices=('cls_distance.cosine', 'cls_distance.l2', 'bertscore'))),
        ('--num_layers', dict(type=int, default=12)),
        ('--origin_mode', dict(action='store_true')),
        ('--oracle_eval_file', dict(default='docprompting_data/conala/cmd_dev.oracle_man.full.json')),
        ('--eval_hit', dict(action='store_true')),
        ('--normalize_embed', dict(action='store_true')),
    )

    parser = argparse.ArgumentParser()
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    if in_program_call is None:
        args = parser.parse_args()
    else:
        args = parser.parse_args(shlex.split(in_program_call))

    # The id files sit next to the text files, with ".txt" swapped for ".id".
    args.source_idx_file = args.source_file.replace(".txt", ".id")
    args.target_idx_file = args.target_file.replace(".txt", ".id")

    print(json.dumps(vars(args), indent=2))
    return args

# Parse the settings in-notebook by passing the flags as a string; only the
# model name and the results path are overridden, everything else keeps its default.
in_program_call = ("--model_name 'neulab/docprompting-codet5-python-doc-retriever'"
                  " --save_file docprompting_data/conala/retrieval_results_test.json")
args = config(in_program_call)

{
  "model_name": "neulab/docprompting-codet5-python-doc-retriever",
  "batch_size": 48,
  "source_file": "docprompting_data/conala/conala_nl.txt",
  "target_file": "docprompting_data/conala/python_manual_firstpara.tok.txt",
  "source_embed_save_file": "docprompting_data/conala/.tmp/src_embedding",
  "target_embed_save_file": "docprompting_data/conala/.tmp/tgt_embedding",
  "save_file": "docprompting_data/conala/retrieval_results_test.json",
  "top_k": 200,
  "cpu": false,
  "pooler": "cls",
  "log_level": "verbose",
  "nl_cm_folder": "docprompting_data/conala/nl.cm",
  "sim_func": "cls_distance.cosine",
  "num_layers": 12,
  "origin_mode": false,
  "oracle_eval_file": "docprompting_data/conala/cmd_dev.oracle_man.full.json",
  "eval_hit": false,
  "normalize_embed": false,
  "source_idx_file": "docprompting_data/conala/conala_nl.id",
  "target_idx_file": "docprompting_data/conala/python_manual_firstpara.tok.id"
}


In [17]:
class RetrievalModel(PreTrainedModel):
    """PreTrainedModel wrapper that keeps a tokenizer and the model returned by ``get_model``.

    Args:
        config: configuration object forwarded to ``PreTrainedModel.__init__``.
        model_name: identifier passed to ``get_model`` (and to ``get_tokenizer``
            when no tokenizer is supplied).
        tokenizer: tokenizer to store; ``None`` means "build one with
            ``get_tokenizer(model_name, use_fast=True)``".
        model_args: opaque settings object, stored as-is.
        batch_size: stored as-is; defaults to 64.
    """

    def __init__(self, config, model_name, tokenizer, model_args, batch_size=64):
        super().__init__(config)
        self.model_name = model_name
        self.model_args = model_args
        self.batch_size = batch_size

        # Only build a tokenizer when the caller did not hand one in.
        if tokenizer is None:
            tokenizer = get_tokenizer(model_name, use_fast=True)
        self.tokenizer = tokenizer

        self.model = get_model(self.model_name)

In [14]:
class CodeT5Retriever:
    """Bundle a tokenizer and a ``RetrievalModel`` for one checkpoint and place the model on a device.

    Args:
        args: settings object. Reads ``model_name`` and ``sim_func``; the
            boolean ``cpu`` attribute is optional and, when true, forces the
            model onto the CPU even if CUDA is available.

    Attributes set: ``args``, ``model_name``, ``tokenizer``, ``model``, ``device``.
    """

    def __init__(self, args):
        self.args = args
        self.model_name = self.args.model_name
        self.tokenizer = transformers.RobertaTokenizer.from_pretrained(self.model_name)

        config = AutoConfig.from_pretrained(self.model_name)

        # Minimal stand-in for the model-arguments object: only `sim_func` is carried.
        class Dummy():
            def __init__(self, sim_func):
                self.sim_func = sim_func
        model_arg = Dummy(args.sim_func)

        self.model = RetrievalModel(model_name=self.model_name,
                                    tokenizer=self.tokenizer,
                                    config=config,
                                    model_args=model_arg)

        # Honour the `--cpu` switch, which was previously parsed but never used.
        # `getattr` keeps settings objects without a `cpu` attribute working.
        force_cpu = getattr(args, 'cpu', False)
        use_cuda = torch.cuda.is_available() and not force_cpu
        self.device = torch.device("cuda") if use_cuda else torch.device("cpu")
        self.model.eval()
        self.model.to(self.device)

In [18]:
# Build the retriever from the parsed settings (constructs the tokenizer and the wrapped model).
searcher = CodeT5Retriever(args)

pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

In [67]:
# NOTE(review): spelled `normaliza_embed`; the embedding cell reads
# `args.normalize_embed` instead, so this flag has no visible effect — confirm intent.
normaliza_embed = True
# Swap in the two commented lines below to embed the source (query) side instead.
# text_file = args.source_file
# save_file = args.source_embed_save_file
text_file = args.target_file
save_file = args.target_embed_save_file
# One sentence per line, surrounding whitespace removed.
with open(text_file, 'r') as f:
    dataset = [line.strip() for line in f]
print(f'number of sentences: {len(dataset)}')

number of sentences: 30755


In [68]:
import os

# For each sentence, take the encoder's per-token hidden states and mean-pool
# them over the real (non-padding) tokens to get one vector per sentence.
# L2 normalisation is applied only when args.normalize_embed is set.
bs = 128
with torch.no_grad():
    all_embeddings = []
    for start in range(0, len(dataset), bs):
        batch = dataset[start: start + bs]

        # pad batch by hand so the un-padded lengths stay available for pooling
        sent_features = searcher.tokenizer(batch, add_special_tokens=True,
                                           max_length=searcher.tokenizer.model_max_length, truncation=True)
        arr = sent_features['input_ids']
        lens = torch.LongTensor([len(a) for a in arr])
        max_len = lens.max().item()
        padded = torch.ones(len(arr), max_len, dtype=torch.long) * searcher.tokenizer.pad_token_id
        mask = torch.zeros(len(arr), max_len, dtype=torch.long)
        # `row` rather than `i`, so the outer loop's batch offset is not shadowed
        for row, a in enumerate(arr):
            padded[row, : lens[row]] = torch.tensor(a, dtype=torch.long)
            mask[row, : lens[row]] = 1
        padded_batch = {'input_ids': padded, 'attention_mask': mask, 'lengths': lens}
        for key in padded_batch:
            if isinstance(padded_batch[key], torch.Tensor):
                # index with the loop variable; the previous `j` is not defined in this cell
                padded_batch[key] = padded_batch[key].to(searcher.device)

        # get embedding
        input_ids = padded_batch['input_ids']
        attention_mask = padded_batch['attention_mask']
        lengths = padded_batch['lengths']
        output = searcher.model.model(input_ids, attention_mask=attention_mask, output_hidden_states=False)
        emb = output['last_hidden_state']
        # zero the padding positions, then divide the token sum by the true length
        emb.masked_fill_(~attention_mask.bool().unsqueeze(-1), 0)
        emb = emb.sum(dim=1) / lengths.unsqueeze(-1)
        if args.normalize_embed:
            emb = emb / emb.norm(dim=1, keepdim=True)

        # Move to host memory as NumPy: np.concatenate cannot consume CUDA
        # tensors, and this also avoids keeping every batch on the device.
        all_embeddings.append(emb.cpu().numpy())

    all_embeddings = np.concatenate(all_embeddings, axis=0)
    print(f"done embedding: {all_embeddings.shape}")

    # exist_ok avoids the check-then-create race; skip when save_file has no directory part
    out_dir = os.path.dirname(save_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    np.save(save_file, all_embeddings)

done embedding: (30755, 768)


In [73]:
# retrieve
source_id_map, target_id_map = {}, {}
with open(args.source_idx_file, 'r') as f:
    for idx, line in enumerate(f):
        source_id_map[idx] = line.strip()
with open(args.target_idx_file, 'r') as f:
    for idx, line in enumerate(f):
        target_id_map[idx] = line.strip()
source_embed = np.load(args.source_embed_save_file + '.npy')
target_embed = np.load(args.target_embed_save_file + '.npy')
assert len(source_id_map) == source_embed.shape[0]
assert len(target_id_map) == target_embed.shape[0]
indexer = faiss.IndexFlatIP(target_embed.shape[1])
indexer.add(target_embed)
print(source_embed.shape, target_embed.shape)
D, I = indexer.search(source_embed, args.top_k)

(2879, 768) (30755, 768)


In [82]:
# Map each source id to its ranked target ids and the matching scores, then save as JSON.
results = {}
for source_idx, (dist, retrieved_index) in enumerate(zip(D, I)):
    results[source_id_map[source_idx]] = {
        'retrieved': [target_id_map[x] for x in retrieved_index],
        'score': dist.tolist(),
    }

with open(args.save_file, 'w+') as f:
    json.dump(results, f, indent=2)

In [86]:
# evaluate
# Load the oracle annotations and the saved retrieval results, closing both
# files deterministically instead of leaving the handles to the garbage collector.
with open(args.oracle_eval_file, 'r') as f:
    d = json.load(f)
gold = [item['oracle_man'] for item in d]
with open(args.save_file, 'r') as f:
    r_d = json.load(f)
# Align predictions with the oracle entries through their question ids.
pred = [r_d[x['question_id']]['retrieved'] for x in d]
# Cut-offs at which recall/precision are reported.
top_k = [1, 3, 5, 8, 10, 12, 15, 20, 30, 50, 100, 200]

def calc_recall(src, pred, top_k, print_result=True):
    """Average recall@k and precision@k of ranked predictions against gold sets.

    Args:
        src: one collection of gold ids per query.
        pred: one ranked list of predicted ids per query, paired with ``src``
            by position.
        top_k: iterable of cut-off values k.
        print_result: when true, print three tab-separated rows (recall,
            precision, F1) with one column per k in ascending order.

    Returns:
        ``{'recall': {k: value}, 'precision': {k: value}}``. Both are averaged
        over ``len(pred)``. A query with no gold ids counts as recall 1.
        When ``pred`` is empty every value is 0.0 (previously this raised
        ``ZeroDivisionError``).
    """
    recall_n = {x: 0 for x in top_k}
    precision_n = {x: 0 for x in top_k}

    for oracle_man, pred_man in zip(src, pred):
        for tk in recall_n:
            cur_result_vids = pred_man[:tk]
            cur_hit = sum(x in cur_result_vids for x in oracle_man)
            # an empty gold set is treated as fully recalled
            recall_n[tk] += cur_hit / len(oracle_man) if len(oracle_man) else 1
            precision_n[tk] += cur_hit / tk

    n_queries = len(pred)
    if n_queries:
        recall_n = {k: v / n_queries for k, v in recall_n.items()}
        precision_n = {k: v / n_queries for k, v in precision_n.items()}
    else:
        # nothing to average over: report zeros rather than dividing by zero
        recall_n = {k: 0.0 for k in recall_n}
        precision_n = {k: 0.0 for k in precision_n}

    if print_result:
        for k in sorted(recall_n.keys()):
            print(f"{recall_n[k] :.3f}", end="\t")
        print()
        for k in sorted(precision_n.keys()):
            print(f"{precision_n[k] :.3f}", end="\t")
        print()
        for k in sorted(recall_n.keys()):
            # the small constant keeps F1 defined when precision and recall are both 0
            print(f"{2 * precision_n[k] * recall_n[k] / (precision_n[k] + recall_n[k] + 1e-10) :.3f}", end="\t")
        print()

    return {'recall': recall_n, 'precision': precision_n}

# Score the retrieval results; calc_recall prints the recall / precision / F1 rows itself,
# and the returned dict is printed afterwards with full precision.
metrics = calc_recall(gold, pred, top_k)
print(metrics)

0.165	0.326	0.423	0.506	0.558	0.579	0.604	0.667	0.737	0.812	0.879	0.949	
0.254	0.182	0.140	0.105	0.095	0.083	0.069	0.059	0.044	0.029	0.016	0.009	
0.200	0.234	0.211	0.174	0.162	0.144	0.124	0.109	0.084	0.056	0.032	0.017	
{'recall': {1: 0.1654228855721393, 3: 0.32645107794361544, 5: 0.4234660033167495, 8: 0.5059701492537313, 10: 0.5582089552238806, 12: 0.5785240464344941, 15: 0.6042288557213931, 20: 0.6669983416252073, 30: 0.7372305140961858, 50: 0.8124378109452735, 100: 0.8793532338308457, 200: 0.949419568822554}, 'precision': {1: 0.2537313432835821, 3: 0.18242122719734652, 5: 0.14029850746268627, 8: 0.10509950248756218, 10: 0.0950248756218905, 12: 0.0825041459369819, 15: 0.0693200663349916, 20: 0.05920398009950251, 30: 0.044278606965174064, 50: 0.029154228855721342, 100: 0.016218905472636776, 200: 0.008631840796019878}}


In [3]:
import transformers
import numpy as np
import torch

In [4]:
class model_config:
    """Attribute-only settings holder used in place of parsed command-line arguments."""

    # checkpoint and similarity settings
    model_name = 'Salesforce/codet5-base'
    sim_func = 'cls_distance.cosine'
    normalize_embed = True
    num_layers = 12

    # batching / retrieval sizes
    batch_size = 128
    top_k = 200


args = model_config()

In [5]:
# Load the tokenizer and the encoder-only model for the checkpoint named by args.model_name.
tokenizer = transformers.RobertaTokenizer.from_pretrained(args.model_name)
model = transformers.T5EncoderModel.from_pretrained(args.model_name)
# Switch to evaluation mode; as the last expression this also displays the module tree.
model.eval()

tokenizer_config.json:   0%|          | 0.00/1.48k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/703k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/294k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/2.00 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/12.5k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.57k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/892M [00:00<?, ?B/s]

T5EncoderModel(
  (shared): Embedding(32100, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32100, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=768, out_features=3072, bias=False)
              (wo): Linear(in_features=3072, out_features=768, bias=False)
              (dropout): Dropout(p=0.1, 

In [13]:
from dataset_utils import tldr_config
import json

tldr_args = tldr_config()
# Read the whole-document JSON with a context manager so the file handle is closed.
with open(tldr_args.doc_whole, 'r') as f:
    tldr_doc_whole = json.load(f)
# Take the value stored under the first key as a sample document.
test_txt = tldr_doc_whole[next(iter(tldr_doc_whole))]

In [14]:
# Display the tokenizer's maximum sequence length (used as the truncation limit below).
tokenizer.model_max_length

512

In [31]:
# Long sample manual-page text for inspecting tokenisation and truncation.
# The escaped \u00e2\u20ac\u2122 sequence is part of the sample and is kept as-is.
test_text = "notify-send - a program to send desktop notifications\nnotify-send [OPTIONS] {summary} [body]\nWith notify-send you can send desktop notifications to the user via a notification daemon from the command line. These notifications can be used to inform the user about an event or display some form of information without getting in the user\u00e2\u20ac\u2122s way.\n-?, --help Show help and exit.\n-u, --urgency=LEVEL Specifies the urgency level (low, normal, critical).\n-t, --expire-time=TIME The duration, in milliseconds, for the notification to appear on screen.\n(Ubuntu's Notify OSD and GNOME Shell both ignore this parameter.)\n-i, --icon=ICON[,ICON...] Specifies an icon filename or stock icon to display.\n-c, --category=TYPE[,TYPE...] Specifies the notification category.\n-h, --hint=TYPE:NAME:VALUE Specifies basic extra data to pass. Valid types are INT, DOUBLE, STRING and BYTE.\nThe Desktop Notification Spec on http://www.galago-project.org/specs/notification/.\nAndre Filipe de Assuncao e Brito <decko@noisemakers.org> Original author\nPaul van Tilburg <paulvt@debian.org> Original author\nRiccardo Setti <giskard@debian.org> Original author"
# tokenizer(batch, add_special_tokens=True, max_length=self.tokenizer.model_max_length, truncation=True)

In [35]:
# Token ids of the long sample, truncated at the tokenizer's maximum length.
test1_id = tokenizer(test_text, add_special_tokens=True, max_length=tokenizer.model_max_length, truncation=True)['input_ids']

In [37]:
# Show the ids produced for the long sample.
print(test1_id)

[1, 12336, 17, 4661, 300, 279, 5402, 358, 1366, 21304, 9208, 203, 12336, 17, 4661, 306, 12422, 65, 288, 7687, 97, 306, 3432, 65, 203, 1190, 5066, 17, 4661, 1846, 848, 1366, 21304, 9208, 358, 326, 729, 3970, 279, 3851, 8131, 628, 326, 1296, 980, 18, 8646, 9208, 848, 506, 1399, 358, 13235, 326, 729, 2973, 392, 871, 578, 2562, 2690, 646, 434, 1779, 2887, 8742, 316, 326, 729, 132, 100, 163, 229, 110, 163, 231, 100, 87, 4031, 18, 203, 17, 35, 16, 1493, 5201, 9674, 2809, 471, 2427, 18, 203, 17, 89, 16, 1493, 295, 75, 2075, 33, 10398, 4185, 5032, 326, 8896, 75, 2075, 1801, 261, 821, 16, 2212, 16, 11239, 2934, 203, 17, 88, 16, 1493, 14070, 17, 957, 33, 4684, 1021, 3734, 16, 316, 10993, 16, 364, 326, 3851, 358, 9788, 603, 5518, 18, 203, 12, 57, 70, 25348, 1807, 10918, 5932, 40, 471, 611, 3417, 958, 19433, 3937, 2305, 333, 1569, 12998, 203, 17, 77, 16, 1493, 3950, 33, 21745, 63, 16, 21745, 2777, 65, 4185, 5032, 392, 4126, 1544, 578, 12480, 4126, 358, 2562, 18, 203, 17, 71, 16, 1493, 4743, 33, 23

In [38]:
# Short sample: only the first two lines of the long text, tokenised with the same settings.
test_text2 = "notify-send - a program to send desktop notifications\nnotify-send [OPTIONS] {summary} [body]"
test2_id = tokenizer(test_text2, add_special_tokens=True, max_length=tokenizer.model_max_length, truncation=True)['input_ids']

In [39]:
# Show the ids produced for the short sample.
print(test2_id)

[1, 12336, 17, 4661, 300, 279, 5402, 358, 1366, 21304, 9208, 203, 12336, 17, 4661, 306, 12422, 65, 288, 7687, 97, 306, 3432, 65, 2]


In [15]:
# Tokenise the sample document, pad it by hand and run it through the encoder.
# This cell runs at notebook level, so it uses the `tokenizer` and `model`
# globals; there is no `self` here.
with torch.no_grad():
    # A single string yields one flat id list, so wrap it into a one-element batch
    # before the per-sequence padding below.
    batch = [test_txt] if isinstance(test_txt, str) else list(test_txt)
    sent_features = tokenizer(batch, add_special_tokens=True, truncation=True)
    arr = sent_features['input_ids']

    # pad batch
    lens = torch.LongTensor([len(a) for a in arr])
    max_len = lens.max().item()
    padded = torch.ones(len(arr), max_len, dtype=torch.long) * tokenizer.pad_token_id
    mask = torch.zeros(len(arr), max_len, dtype=torch.long)
    for i, a in enumerate(arr):
        padded[i, : lens[i]] = torch.tensor(a, dtype=torch.long)
        mask[i, : lens[i]] = 1

    # no [cls] so get sentence embedding by pooling hidden state
    device = model.device
    input_ids, attention_mask, lengths = padded.to(device), mask.to(device), lens.to(device)
    # `max_length` is a tokenizer/generation option, not an encoder forward argument, so it is not passed.
    output = model(input_ids, attention_mask=attention_mask, output_hidden_states=True)

{'input_ids': [1, 4400, 12096, 17, 21675, 12, 21, 13, 10792, 2177, 18034, 10402, 4269, 12096, 17, 21675, 12, 21, 13, 203, 12336, 17, 4661, 300, 279, 5402, 358, 1366, 21304, 9208, 203, 12336, 17, 4661, 306, 12422, 65, 288, 7687, 97, 306, 3432, 65, 203, 1190, 5066, 17, 4661, 1846, 848, 1366, 21304, 9208, 358, 326, 729, 3970, 279, 3851, 8131, 628, 326, 1296, 980, 18, 8646, 9208, 848, 506, 1399, 358, 13235, 326, 729, 2973, 392, 871, 578, 2562, 2690, 646, 434, 1779, 2887, 8742, 316, 326, 729, 132, 100, 163, 229, 110, 163, 231, 100, 87, 4031, 18, 203, 17, 35, 16, 1493, 5201, 9674, 2809, 471, 2427, 18, 203, 17, 89, 16, 1493, 295, 75, 2075, 33, 10398, 4185, 5032, 326, 8896, 75, 2075, 1801, 261, 821, 16, 2212, 16, 11239, 2934, 203, 17, 88, 16, 1493, 14070, 17, 957, 33, 4684, 1021, 3734, 16, 316, 10993, 16, 364, 326, 3851, 358, 9788, 603, 5518, 18, 203, 12, 57, 70, 25348, 1807, 10918, 5932, 40, 471, 611, 3417, 958, 19433, 3937, 2305, 333, 1569, 12998, 203, 17, 77, 16, 1493, 3950, 33, 21745, 63, 

In [16]:
from dense_retriever import dense_retriever_config
# Build the dense-retriever settings for the tldr dataset from an in-notebook argument string.
model_name = 'Salesforce/codet5-base'
in_program_call = f"--dataset tldr --model_name {model_name}"
args = dense_retriever_config(in_program_call)

{
  "dataset": "tldr",
  "model_name": "Salesforce/codet5-base",
  "batch_size": 128,
  "top_k": 200,
  "sim_func": "cls_distance.cosine",
  "normalize_embed": false,
  "save_file": "docprompting_data/tldr/ret_results_codet5-base.json",
  "conala_qs_embed_save_file": "docprompting_data/conala/.tmp/src_embedding_codet5-base",
  "conala_doc_firstpara_save_file": "docprompting_data/conala/.tmp/tgt_embedding_codet5-base",
  "tldr_qs_embed_save_file": "docprompting_data/tldr/.tmp/src_embedding_codet5-base",
  "tldr_doc_whole_embed_save_file": "docprompting_data/tldr/.tmp/whole_embedding_codet5-base",
  "tldr_doc_line_embed_save_file": "docprompting_data/tldr/.tmp/line_embedding_codet5-base"
}


In [17]:
tldr_args = tldr_config()
# Keys of the whole-document JSON, in file order; read via a context manager
# so the file handle is closed.
with open(tldr_args.doc_whole, 'r') as f:
    tldr_key_list_whole = list(json.load(f).keys())
# Saved query embeddings and whole-document embeddings (np.save adds the '.npy' suffix).
source_embed = np.load(args.tldr_qs_embed_save_file + '.npy')
result_whole_save_file = args.save_file.replace('.json', '_whole.json')
doc_whole_embed = np.load(args.tldr_doc_whole_embed_save_file + '.npy')

In [22]:
import faiss
# Select the whole-document variant as the retrieval target.
target_embed = doc_whole_embed
save_file = result_whole_save_file
target_key_list = tldr_key_list_whole

# Flat inner-product index over the target vectors; D holds scores, I holds row indices.
indexer = faiss.IndexFlatIP(target_embed.shape[1])
indexer.add(target_embed)
print(source_embed.shape, target_embed.shape)
D, I = indexer.search(source_embed, args.top_k)

(1845, 768) (1927, 768)


In [29]:
# Ranked (key, score) pairs for the first query only.
dist, retrieved_index = D[0], I[0]
retrieved_target_key = [target_key_list[x] for x in retrieved_index]
results = [dict(lib_key=key, score=distance)
           for key, distance in zip(retrieved_target_key, dist)]
results

[{'lib_key': 'wp', 'score': 2.4240243},
 {'lib_key': 'p5', 'score': 2.2991867},
 {'lib_key': 'qpdf', 'score': 2.1964781},
 {'lib_key': 'asciiart', 'score': 2.1012466},
 {'lib_key': 'po4a-gettextize', 'score': 2.068434},
 {'lib_key': 'wpa_passphrase', 'score': 2.0590746},
 {'lib_key': 'wordgrinder', 'score': 2.0583205},
 {'lib_key': 'steghide', 'score': 2.0401578},
 {'lib_key': 'gofmt', 'score': 2.0368047},
 {'lib_key': 'lorem', 'score': 2.0340328},
 {'lib_key': 'steam', 'score': 2.0106273},
 {'lib_key': 'sphinx-build', 'score': 2.009172},
 {'lib_key': 'dexdump', 'score': 2.006337},
 {'lib_key': 'solo', 'score': 2.0050251},
 {'lib_key': 'banner', 'score': 1.9967657},
 {'lib_key': 'debman', 'score': 1.990351},
 {'lib_key': 'guetzli', 'score': 1.9849534},
 {'lib_key': 'ngrok', 'score': 1.982582},
 {'lib_key': 'assimp', 'score': 1.9684796},
 {'lib_key': 'ember', 'score': 1.9613961},
 {'lib_key': 'pdftotext', 'score': 1.9578973},
 {'lib_key': 'alex', 'score': 1.9551215},
 {'lib_key': 'tesse

In [140]:
import numpy as np
# Load two saved source-embedding files so they can be compared against each other.
auth_src_embed_file = 'docprompting_data/conala/.tmp/src_embedding.npy'
src_embed_file = 'docprompting_data/conala/.tmp/src_embedding_docprompting-codet5-python-doc-retriever.npy'
auth_src_embed = np.load(auth_src_embed_file)
src_embed = np.load(src_embed_file)

In [143]:
def check_consistency_with_auth_src(embed):
    """Compare ``embed`` with the module-level ``auth_src_embed`` array.

    Args:
        embed: array with the same shape as ``auth_src_embed``.

    Returns:
        ``True`` when the first rows agree within ``np.allclose``'s default
        tolerances, else ``False``. Only row 0 is compared.

    Raises:
        AssertionError: if the shapes differ.
    """
    assert embed.shape == auth_src_embed.shape
    # Return the comparison; it used to be computed and silently discarded.
    return bool(np.allclose(embed[0], auth_src_embed[0]))
    
def check_consistency_with_src(embed):
    """Compare ``embed`` with the module-level ``src_embed`` array.

    Args:
        embed: array with the same shape as ``src_embed``.

    Returns:
        ``True`` when the first rows agree within ``np.allclose``'s default
        tolerances, else ``False``. Only row 0 is compared.

    Raises:
        AssertionError: if the shapes differ.
    """
    assert embed.shape == src_embed.shape
    # Return the comparison; it used to be computed and silently discarded.
    return bool(np.allclose(embed[0], src_embed[0]))

In [148]:
from dense_retriever import dense_retriever_config
# Settings for the conala dataset with the fine-tuned retriever and normalised embeddings.
# The backslash continuations sit inside the string literal, so the argument
# string contains the indentation spaces as well.
model_name = 'neulab/docprompting-codet5-python-doc-retriever'
in_program_call = f"--dataset conala \
                    --model_name {model_name} \
                    --sim_func cls_distance.cosine \
                    --normalize_embed"
ret_config = dense_retriever_config(in_program_call)

{
  "dataset": "conala",
  "model_name": "neulab/docprompting-codet5-python-doc-retriever",
  "batch_size": 128,
  "top_k": 200,
  "sim_func": "cls_distance.cosine",
  "normalize_embed": true,
  "save_file": "docprompting_data/conala/ret_results_docprompting-codet5-python-doc-retriever.json",
  "conala_qs_embed_save_file": "docprompting_data/conala/.tmp/src_embedding_docprompting-codet5-python-doc-retriever",
  "conala_doc_firstpara_save_file": "docprompting_data/conala/.tmp/tgt_embedding_docprompting-codet5-python-doc-retriever",
  "tldr_qs_embed_save_file": "docprompting_data/tldr/.tmp/src_embedding_docprompting-codet5-python-doc-retriever",
  "tldr_doc_whole_embed_save_file": "docprompting_data/tldr/.tmp/whole_embedding_docprompting-codet5-python-doc-retriever",
  "tldr_doc_line_embed_save_file": "docprompting_data/tldr/.tmp/line_embedding_docprompting-codet5-python-doc-retriever"
}


In [128]:
from dense_encoder import CodeT5Retriever, denseEncoder

# Build both encoder implementations from the same settings object so their outputs can be compared.
# NOTE(review): this uses `args`, not the `ret_config` created in the cell above — confirm which was intended.
auth_encoder = CodeT5Retriever(args)
encoder = denseEncoder(args)

In [96]:
from dataset_utils import conala_config
conala_args = conala_config()
# One question per line, surrounding whitespace removed.
with open(conala_args.qs_file, "r") as f:
    dataset = [line.strip() for line in f]
print(dataset[0])

Create list `instancelist` containing 29 objects of type MyClass


In [130]:
# Tokenise the first batch with both encoders' tokenizers and require identical token ids.
with torch.no_grad():
    batch = dataset[0:args.batch_size]
    auth_sent_features = auth_encoder.tokenizer(batch, add_special_tokens=True,
                                                   max_length=auth_encoder.tokenizer.model_max_length, truncation=True)
    sent_features = encoder.tokenizer(batch, add_special_tokens=True, max_length=encoder.tokenizer.model_max_length, truncation=True)
    assert auth_sent_features['input_ids'] == sent_features['input_ids']

In [131]:
def tmp(sent_features, encoder):
    """Right-pad a batch of token-id lists into rectangular tensors.

    Args:
        sent_features: mapping whose ``'input_ids'`` entry is a list of
            token-id lists.
        encoder: object exposing ``tokenizer.pad_token_id``, the fill value.

    Returns:
        ``(padded, mask, lens)``: the padded id tensor, a 0/1 mask of the same
        shape marking real tokens, and the original length of every sequence.
        All three are ``torch.long``.
    """
    token_rows = sent_features['input_ids']
    lens = torch.LongTensor([len(row) for row in token_rows])
    max_len = lens.max().item()
    n_rows = len(token_rows)

    # Start from an all-padding grid and an all-zero mask, then fill each row's real prefix.
    padded = torch.full((n_rows, max_len), encoder.tokenizer.pad_token_id, dtype=torch.long)
    mask = torch.zeros(n_rows, max_len, dtype=torch.long)
    for row_idx, (row, length) in enumerate(zip(token_rows, lens.tolist())):
        padded[row_idx, :length] = torch.tensor(row, dtype=torch.long)
        mask[row_idx, :length] = 1

    return padded, mask, lens

# Pad both tokenisations; `attention_mask` and `lengths` are later read as globals by tmp2.
auth_input_ids, auth_attention_mask, auth_lengths = tmp(auth_sent_features, auth_encoder)
input_ids, attention_mask, lengths = tmp(sent_features, encoder)
# print(auth_input_ids)
# print(input_ids)
# print(auth_attention_mask)
# print(attention_mask)
# print(auth_lengths)
# print(lengths)

In [132]:
# Run both models on the same padded batch and keep their last hidden states.
# Both calls receive `input_ids` / `attention_mask`; the `auth_*` tensors are not used here.
with torch.no_grad():
    auth_output = auth_encoder.model.model(input_ids, attention_mask=attention_mask, output_hidden_states=True)
    auth_emb = auth_output['last_hidden_state']
    output = encoder.model(input_ids, attention_mask=attention_mask, output_hidden_states=True)
    embed = output['last_hidden_state']

In [133]:
def tmp2(emb):
    """Mean-pool token states over real tokens and L2-normalise the result.

    Reads the module-level ``attention_mask`` and ``lengths``. ``emb`` is
    modified in place: positions where the mask is 0 are set to 0.

    Args:
        emb: token-state tensor whose first two dimensions match ``attention_mask``.

    Returns:
        One unit-norm vector per row, moved to the CPU.
    """
    padding_positions = ~attention_mask.bool().unsqueeze(-1)
    emb.masked_fill_(padding_positions, 0)
    mean_pooled = emb.sum(dim=1) / lengths.unsqueeze(-1)
    unit_norm = mean_pooled / mean_pooled.norm(dim=1, keepdim=True)
    return unit_norm.cpu()
# Pool and normalise both models' token states; tmp2 zeroes the padding positions of its argument in place.
auth_emb = tmp2(auth_emb)
emb = tmp2(embed)

In [134]:
# Convert both pooled embeddings to NumPy and report whether they agree within np.allclose's default tolerances.
auth_emb = auth_emb.detach().numpy()
emb = emb.detach().numpy()
print(np.allclose(auth_emb, emb))

True


In [None]:
# Reload both saved source-embedding files (same paths as in the earlier loading cell).
auth_src_embed_file = 'docprompting_data/conala/.tmp/src_embedding.npy'
src_embed_file = 'docprompting_data/conala/.tmp/src_embedding_docprompting-codet5-python-doc-retriever.npy'
auth_src_embed = np.load(auth_src_embed_file)
src_embed = np.load(src_embed_file)

In [120]:
# Keep only the first args.batch_size rows of each saved embedding matrix (overwrites the full arrays).
src_embed = src_embed[0:args.batch_size]
auth_src_embed = auth_src_embed[0:args.batch_size]

In [121]:
# Compare the freshly computed batch embeddings with the first rows of the saved file.
print(np.allclose(emb, auth_src_embed))

False


In [139]:
# Number of entries in the encoder's block list.
len(encoder.model.encoder.block)

12

In [136]:
# First row of the saved embeddings from src_embed_file.
print(src_embed[0])

[-1.51683511e-02 -6.61402717e-02 -5.66248782e-02  2.51631625e-02
 -4.06907052e-02 -2.15809513e-02 -6.44470900e-02 -3.19580697e-02
  3.80077437e-02 -2.18240060e-02  2.73901541e-02  5.24688326e-02
 -1.23803802e-02 -8.38407036e-03  4.24143206e-03 -1.99274137e-03
  4.68592420e-02  4.56570536e-02 -8.80548637e-03 -5.32950237e-02
 -1.20829698e-02 -2.99830791e-02  3.33988778e-02 -3.56720239e-02
  8.65500513e-03 -7.42584514e-03 -2.69307569e-02 -5.50832897e-02
 -5.25289476e-02  1.27503369e-02 -9.04892758e-03  8.46601836e-03
 -2.69015692e-02  2.20235735e-02 -1.11881895e-02  1.41660534e-02
 -2.83475574e-02  2.53256317e-02 -2.60235500e-02 -3.99435908e-02
 -2.30046585e-02  6.42091557e-02 -1.00606289e-02  5.75726619e-03
  7.37272725e-02 -3.19232941e-02  3.95331867e-02 -5.09065688e-02
  2.80672275e-02  5.77943847e-02 -2.93761343e-02  1.89341400e-02
  4.92721051e-03 -5.61733767e-02  2.16703769e-02 -4.99217510e-02
 -4.16233689e-02  6.39647171e-02 -2.95666065e-02 -9.43135750e-03
  3.32091888e-03  1.05230

In [137]:
# First row of the saved embeddings from auth_src_embed_file.
# NOTE(review): the printed values look like a constant multiple of src_embed[0],
# which would fit only one of the two files being L2-normalised — confirm.
print(auth_src_embed[0])

[-4.80976552e-02 -2.09725618e-01 -1.79553062e-01  7.97904208e-02
 -1.29027039e-01 -6.84315115e-02 -2.04356685e-01 -1.01336539e-01
  1.20519578e-01 -6.92022145e-02  8.68520364e-02  1.66374564e-01
 -3.92572172e-02 -2.65852306e-02  1.34492498e-02 -6.31882716e-03
  1.48586988e-01  1.44774944e-01 -2.79215090e-02 -1.68994352e-01
 -3.83141525e-02 -9.50739980e-02  1.05905227e-01 -1.13113195e-01
  2.74443459e-02 -2.35467739e-02 -8.53953212e-02 -1.74664810e-01
 -1.66565180e-01  4.04303223e-02 -2.86934432e-02  2.68450826e-02
 -8.53027701e-02  6.98350295e-02 -3.54768746e-02  4.49194461e-02
 -8.98878872e-02  8.03055987e-02 -8.25186446e-02 -1.26658008e-01
 -7.29459748e-02  2.03602210e-01 -3.19014676e-02  1.82558410e-02
  2.33783424e-01 -1.01226270e-01  1.25356644e-01 -1.61420748e-01
  8.89989808e-02  1.83261469e-01 -9.31494236e-02  6.00386783e-02
  1.56237995e-02 -1.78121388e-01  6.87150732e-02 -1.58297971e-01
 -1.31984442e-01  2.02827126e-01 -9.37533975e-02 -2.99060978e-02
  1.05303740e-02  3.33678

In [135]:
# First row of the embeddings recomputed in this session with auth_encoder.
print(auth_emb[0])

[-1.24917468e-02 -7.73063898e-02 -5.53829446e-02  3.31767686e-02
 -3.81548591e-02 -1.72845758e-02 -3.36980224e-02 -1.79621112e-02
  3.13402936e-02 -4.63802330e-02 -2.34607905e-02  1.51392343e-02
  1.55846868e-02 -2.04078690e-03 -4.84605767e-02 -3.56928930e-02
  1.56200025e-02  8.02709013e-02  3.92882191e-02  2.55777650e-02
  2.59470213e-02 -5.27198724e-02  5.19483723e-02 -5.48982136e-02
 -5.29148765e-02 -2.20311363e-03  2.66518164e-02 -3.50047126e-02
 -5.46136759e-02 -2.41253246e-02 -1.06619939e-03 -4.07711826e-02
 -1.74031220e-02 -1.44696329e-02  2.79256310e-02 -2.13246848e-02
  1.92914177e-02 -2.82315928e-02  2.86931638e-02 -1.00778369e-02
  4.82489262e-03  1.87020004e-03 -2.43233163e-02  2.72888057e-02
  4.30642143e-02 -4.64540720e-02  3.73025984e-02 -6.98341131e-02
  8.19585919e-02  3.58498618e-02 -4.97059152e-02 -2.94583524e-03
 -2.73831449e-02 -4.93291281e-02  3.86470668e-02 -2.35444400e-02
 -1.92547981e-02  6.45612553e-02 -3.92575786e-02  1.10852933e-02
  3.35079990e-02  2.52354