In [1]:
from elasticsearch import Elasticsearch

es = Elasticsearch("http://localhost:9200")
es.info().body

{'name': '3587fef945a1',
 'cluster_name': 'docker-cluster',
 'cluster_uuid': 'Cv6QqXF_R_2HGTAJdqcWeA',
 'version': {'number': '8.7.0',
  'build_flavor': 'default',
  'build_type': 'docker',
  'build_hash': '09520b59b6bc1057340b55750186466ea715e30e',
  'build_date': '2023-03-27T16:31:09.816451435Z',
  'build_snapshot': False,
  'lucene_version': '9.5.0',
  'minimum_wire_compatibility_version': '7.17.0',
  'minimum_index_compatibility_version': '7.0.0'},
 'tagline': 'You Know, for Search'}

In [3]:
es.indices.create(index="test")

ObjectApiResponse({'acknowledged': True, 'shards_acknowledged': True, 'index': 'test'})

In [5]:
import datasets, evaluate
tldr = datasets.load_dataset("neulab/tldr")
tldr_metric = evaluate.load("neulab/tldr_eval")

In [10]:
tldr['test'][0]

{'question_id': '21',
 'nl': 'delete a shared memory segment by id',
 'cmd': 'ipcrm --shmem-id {{shmem_id}}',
 'oracle_man': ['ipcrm_4', 'ipcrm_12'],
 'cmd_name': 'ipcrm',
 'tldr_cmd_name': 'ipcrm',
 'manual_exist': True,
 'matching_info': {'token': ['--shmem-id', '|main|'],
  'oracle_man': [['ipcrm_12'],
   ['ipcrm_4', 'ipcrm_5', 'ipcrm_6', 'ipcrm_7', 'ipcrm_8']]}}

In [11]:
import json
import tqdm

doc_path = "docprompting_data/tldr/manual_all_raw.json"
tldr_doc_raw = json.load(open(doc_path, 'r'))

In [32]:
processed_docs = list()
for lib_key, lib_man in tldr_doc_raw.items():
    cmd_name = '_'.join(lib_key.split("_")[:-1]) if lib_key[-1].isdigit() else lib_key
    processed_docs.append(dict(_index="test", manual=lib_man, lib_key=lib_key, cmd_name=cmd_name))

In [33]:
from elasticsearch.helpers import bulk

bulk(es, processed_docs, index="test")

(1927, [])

In [49]:
question_file = 'docprompting_data/tldr/cmd_dev.seed.json'
questions = json.load(open(question_file, 'r'))
res_list = list()
for question in questions:
    query = {'query':
                 {'match': 
                      {'manual': question['nl']}},
             'size': 35}
    res = es.search(index="test", body=query)['hits']['hits']
    _res = list()
    for item in res:
        _res.append({'lib_key': item['_source']['lib_key'], 'score': item['_score']})
    res_list.append(_res)

In [53]:
save_file = 'docprompting_data/tldr/test_dev_whole_retrieve_result.json'
with open(save_file, 'w+') as f:
    json.dump(res_list, f, indent=2)

In [56]:
auth_cmd_name_list = [question['cmd_name'] for question in questions]
retrieved_result = list()
for res in res_list:
    retrieved_result.append([item['lib_key'] for item in res])

In [61]:
def calc_hit(src, pred, top_k=None):
    top_k = TOP_K if top_k is None else top_k
    hit_n = {x: 0 for x in top_k}
    assert len(src) == len(pred), (len(src), len(pred))

    for s, p in zip(src, pred):
        cmd_name = s
        pred_man = p

        for tk in hit_n.keys():
            cur_result_vids = pred_man[:tk]
            cur_hit = any([cmd_name in x for x in cur_result_vids])
            hit_n[tk] += cur_hit

    hit_n = {k: v / len(pred) for k, v in hit_n.items()}
    for k in sorted(hit_n.keys()):
        print(f"{hit_n[k] :.3f}", end="\t")
    print()
    return hit_n

In [63]:
# how to reduce the influence of retrieval error?
calc_hit(auth_cmd_name_list, retrieved_result, top_k=[1, 3, 5, 10, 15, 20, 30])

0.328	0.455	0.517	0.598	0.640	0.661	0.703	


{1: 0.32791327913279134,
 3: 0.45474254742547426,
 5: 0.5170731707317073,
 10: 0.5978319783197832,
 15: 0.6395663956639567,
 20: 0.6612466124661247,
 30: 0.7029810298102981}

In [None]:
# line-level retrieve (retrieve can be improved)
top_k_retrieve = 5
for r1_res, question in zip(res_list, questions):
    retrieved_cmd_list = [item['lib_key'] for item in r1_res][:top_k_retrieve]
    for cmd in retrieved_cmd_list:
        query = {'query':
                    {'bool':
                        {'must':
                            [{'term': {'cmd_name.keyword': cmd}},
                             {'match': {'manual': question['nl']}}]
                        }
                    },
                'size': 30
                }
        res = es.search(index='test', body=query)['hits']['hits']