In [15]:
import os, sys
import json
import re
from multiprocessing import Pool
import time
import Levenshtein
from tqdm import tqdm

In [2]:
# Root folder of the VulLib data; '{path_to_libgen}' is a literal placeholder
# that has to be replaced with a real location before the notebook can run.
input_folder = '{path_to_libgen}/libgen/VulLib'

# Each dataset file lives directly under input_folder.
train_path, valid_path, test_path, maven_path = (
    os.path.join(input_folder, file_name)
    for file_name in ('train.json', 'valid.json', 'test.json', 'maven_corpus_new.json')
)

In [3]:
# Read the three dataset splits in train / valid / test order.
splits = []
for split_path in (train_path, valid_path, test_path):
    with open(split_path, 'r') as split_file:
        splits.append(json.load(split_file))
train, valid, test = splits

# All splits joined with `+` (presumably each file holds a JSON list — verify).
vulns = train + valid + test

In [5]:
# Load the Maven corpus and keep the distinct 'name' value of every entry.
with open(maven_path, 'r') as corpus_file:
    maven_corpus = json.load(corpus_file)
lib_names = {entry['name'] for entry in maven_corpus}

In [6]:
# Index library names by their last ':'-separated field, so that every
# artifact id maps to the set of full names that end with it.
artifacts = {}
for lib_name in lib_names:
    artifact_key = lib_name.split(':')[-1]
    artifacts.setdefault(artifact_key, set()).add(lib_name)

In [7]:
# Load the saved responses that the rest of the notebook post-processes.
with open('response/1_res_gpt-4-1106-preview.json', 'r') as response_file:
    res = json.load(response_file)

In [9]:
# Matches 'maven:<group>:<artifact>' where both fields consist of word
# characters, '-' or '.'. Written as a raw string so that '\w' reaches the
# regex engine untouched instead of being an invalid str escape sequence.
maven_regex = r'maven:[-.\w]+:[-.\w]+'

# Pull every Maven coordinate out of the first element of each response's
# 'top_res' and 'rerank_res' entries, in order of appearance.
for item in res:
    item['top_k'] = re.findall(maven_regex, item['top_res'][0])
    item['rerank_k'] = re.findall(maven_regex, item['rerank_res'][0])

In [55]:
# Edit-operation costs passed as `weights=` to Levenshtein.distance by the
# matching helpers in this cell.
# NOTE(review): the Levenshtein docs describe the tuple order as
# (insertion, deletion, substitution) — confirm for the installed version.
weights = (1, 4, 4)

def cloest_artifact(artifact_id):
    """Return the key of ``artifacts`` that is nearest to ``artifact_id``.

    An exact key match is returned unchanged. Otherwise every key is scored
    with ``Levenshtein.distance`` using the module-level ``weights`` and the
    lowest-scoring key is returned; equal scores are resolved in favour of
    the key that sorts first as a string.
    """
    if artifact_id in artifacts:
        return artifact_id

    def rank(candidate):
        # (distance, name): the name acts as the tie-breaker.
        return Levenshtein.distance(artifact_id, candidate, weights=weights), candidate

    return min(artifacts, key=rank)

def cloest_group(group_id, groups):
    """Pick the entry of ``groups`` whose group field is nearest to ``group_id``.

    ``groups`` is a collection of ':'-separated names; the second-to-last
    field of each one is compared with ``group_id``.

    * no candidates  -> ``group_id`` itself is returned;
    * one candidate  -> that candidate is returned without scoring;
    * otherwise      -> the candidate with the smallest weighted
      ``Levenshtein.distance`` wins, ties going to the name that sorts first.
    """
    candidate_count = len(groups)
    if candidate_count == 0:
        return group_id
    if candidate_count == 1:
        return next(iter(groups))

    scored = []
    for candidate in groups:
        candidate_group = candidate.split(':')[-2]
        score = Levenshtein.distance(group_id, candidate_group, weights=weights)
        scored.append((score, candidate))
    _, best_candidate = min(scored)
    return best_candidate


def closest_lib(label):
    """Map ``label`` onto a name taken from ``lib_names``.

    A label that already is a known library name is returned as is.
    Otherwise the last ':'-separated field is treated as the artifact id and
    the field before it as the group id (empty when the label has no ':').
    The artifact id is resolved to a key of ``artifacts`` (exactly, or via
    ``cloest_artifact``), and ``cloest_group`` then chooses among the
    libraries that share this artifact.
    """
    if label in lib_names:
        return label

    parts = label.split(':')
    artifact_id = parts[-1]
    group_id = parts[-2] if len(parts) > 1 else ""

    if artifact_id not in artifacts:
        # Unknown artifact: fall back to the nearest known one.
        artifact_id = cloest_artifact(artifact_id)
    return cloest_group(group_id, artifacts[artifact_id])

In [37]:
# Flatten every response's 'rerank_k' list into one list, keeping the order
# of the responses and of the entries within each response.
target = []
for response in res:
    target.extend(response['rerank_k'])

In [52]:
# Resolve every extracted name with closest_lib in worker processes.
# pool.imap yields results in input order, so result[i] belongs to target[i].
# The iterator has no length, so `total` is passed explicitly to let tqdm
# render a real progress bar with an ETA instead of a bare counter.
with Pool(processes=32) as pool:
    result = list(tqdm(pool.imap(closest_lib, target), total=len(target)))

3623it [00:03, 915.13it/s] 


In [39]:
def precision(vuln, pred, k):
    """Precision of the first ``k`` predictions against ``vuln['labels']``.

    Returns ``None`` when there are no labels. Otherwise the number of
    distinct labels found among ``pred[:k]`` is divided by
    ``min(k, number of labels)``.
    """
    gold = vuln['labels']
    gold_count = len(gold)
    if gold_count == 0:
        return None
    hits = set(gold).intersection(pred[:k])
    return len(hits) / min(k, gold_count)

def recall(vuln, pred, k):
    """Recall of the first ``k`` predictions against ``vuln['labels']``.

    Returns ``None`` when there are no labels. Otherwise the number of
    distinct labels found among ``pred[:k]`` is divided by the number of
    labels.
    """
    gold = vuln['labels']
    gold_count = len(gold)
    if gold_count == 0:
        return None
    hits = set(gold).intersection(pred[:k])
    return len(hits) / gold_count

def f1_score(p, r):
    """Return the harmonic mean ``2*p*r / (p + r)`` of precision and recall.

    When ``p + r`` is zero the quotient is undefined; 0.0 is returned in
    that case instead of raising ZeroDivisionError.
    """
    if p + r == 0:
        return 0.0
    return 2*p*r/(p+r)

In [56]:
# `result` is a flat list aligned with the concatenation of every response's
# 'rerank_k'; hand each response its own slice back, prefixed with 'maven:'.
idx = 0
for response in res:
    lib_count = len(response['rerank_k'])
    response['rerank_k_post'] = [
        'maven:' + result[pos] for pos in range(idx, idx + lib_count)
    ]
    idx += lib_count

In [57]:
k = 1

# vulns and res are paired by position (zip stops at the shorter of the two).
# precision/recall return None for entries without labels; those entries are
# skipped. Each metric is evaluated once per pair and filtered with
# `is not None` rather than `!= None`.
p = [
    score
    for score in (precision(vuln, response['rerank_k_post'], k)
                  for vuln, response in zip(vulns, res))
    if score is not None
]
r = [
    score
    for score in (recall(vuln, response['rerank_k_post'], k)
                  for vuln, response in zip(vulns, res))
    if score is not None
]

# Mean precision@k, mean recall@k, and the F1 of those two means.
mean_p = sum(p) / len(p)
mean_r = sum(r) / len(r)
mean_p, mean_r, f1_score(mean_p, mean_r)

(0.754392255288634, 0.5723564886436876, 0.6508863178076171)

In [59]:
# Display the 'rerank_k' list stored on the first entry of res.
res[0]['rerank_k']

['maven:org.jgroups:jgroups']

##### Lack/More Token (missing or extra tokens)

In [77]:
def is_sub_sequence(src, dst):
    """Return True when the items of ``src`` occur in ``dst`` in the same order.

    The items need not be adjacent in ``dst``; an empty ``src`` is a
    subsequence of anything.

    NOTE(review): the original body stopped at a bare ``for`` (a syntax error
    that made the whole cell unrunnable). This implements the meaning the
    function name suggests — confirm it matches the intended check.
    """
    remaining = iter(dst)
    # `item in remaining` consumes the iterator up to the first match, so each
    # subsequent item has to be found after the previous one.
    return all(item in remaining for item in src)

def is_lack_more_token(raw_name, target_name):
    """Work-in-progress comparison of two ``prefix:group:artifact`` names.

    Both names are split on ':' and must yield exactly three fields,
    otherwise the unpacking raises ValueError. At present only the group and
    artifact of ``raw_name`` are tokenised on '-' and '.', and the two token
    lists are printed; nothing is returned yet.
    """
    _, raw_group, raw_artifact = raw_name.split(':')
    # Not used yet, but the unpacking still checks the shape of target_name.
    _, target_group, target_artifact = target_name.split(':')
    # Raw string: '\.' in a normal string literal is an invalid escape sequence.
    token_separator = r'-|\.'
    raw_group_tokens = re.split(token_separator, raw_group)
    raw_artifact_tokens = re.split(token_separator, raw_artifact)
    print(raw_group_tokens, raw_artifact_tokens)

In [78]:
# Try the helper on the first label of the first vulnerability and the first
# 'rerank_k' entry of the first response.
is_lack_more_token(vulns[0]['labels'][0], res[0]['rerank_k'][0])

['jgroups'] ['jgroups', 'all']


In [74]:
# Scratch check of the '-' / '.' tokeniser on a full coordinate; note that ':'
# is not a separator here. Raw string avoids the invalid '\.' str escape.
re.split(r'-|\.', 'maven:org.jgroups:jgroups-wow.as-wow')

['maven:org', 'jgroups:jgroups', 'wow', 'as', 'wow']