In [7]:
config = {
    'MAX_LEN' : 200, #max length of a sequence 
    'bs' : 16, #batch size
    'substitue' : 'X', #substitute tag
    'name' : 'bert-base-uncased', #model name
    'do_lower_case' : True, #lower case
    'fold_num' : 5, #number of fold to cross validation
    'hidden_size' : 768, #input size of classifier
    'dropout' : 0.1, #dropout rate
    'decay' : 1e-5, #weight decay
    'lr' : 4e-5, #learning rate
    'mu' : 5, #parameter in computing weights
    'FULL_FINETUNING' : True, #if finetune bert or not
    'max_grad_norm' : 1.0, #grad clipping
    'epochs' : 50, #training epochs
    'period' : 10, #period of display training results
}

In [8]:
import pandas as pd
import numpy as np
from tqdm import tqdm, trange
import pickle as pk

#add pname at the beginning of the sentence
def get_groups(data, if_add=False, cve_cpe_pnames=None):
    def add(pname):
        if len(pname)==0:
            return [], []
        add_sent = ['Product','name','is']
        add_label = ['O','O','O']
        #only consider pname without '_'
        pname = [i for i in pname if '_' not in i]

        for i in range(len(pname)):
            spl_name = pname[i].split()
            add_sent.extend(spl_name)
            add_label.extend(['pn']*len(spl_name))

            if i!=len(pname)-1:
                add_sent.append(',')
            else:
                add_sent.append('.')
            add_label.append('O')
        return add_sent, add_label

    def agg_func(s):
        if if_add:
            add_sent, add_label = add(cve_cpe_pnames[s.name[0]])
        else:
            add_sent, add_label = [], []
        new_sent = add_sent + s["token"].values.tolist()
        new_tag = add_label + s["label"].values.tolist()
        return [(w, t) for w, t in zip(new_sent, new_tag)]


    grouped = data.groupby(['sent_ind','cve_sent_ind']).apply(agg_func)
    words = [w for w in grouped]
    return words

def read_data(path):
    data = pd.read_csv(path, encoding="latin1").fillna(method="ffill")
    count_label = data.groupby('label')['sent_ind'].count()
    # sentence id
    i=0
    cve_sent_ind = 0
    sent_ind = data.loc[0]['sent_ind']
    while i<len(data):
        d = data.loc[i]
        if d['sent_ind']==sent_ind:
            data.loc[i, 'cve_sent_ind']=cve_sent_ind
            if d['token']=='.':
                cve_sent_ind=cve_sent_ind+1
            i+=1
        else:
            sent_ind = d['sent_ind']
            cve_sent_ind = 0
            
    with open('data/cpe.pkl','rb') as f:
        cve_cpe_pnames,cve_cpe_vendors = pk.load(f)
    words = add_pname(data, cve_cpe_pnames)
    sentences = [" ".join([s[0] for s in sent]) for sent in words]
    labels = [[s[1] for s in sent] for sent in words]
    substitue = config['substitue']
    tags_vals = list(set(data["label"].values)) + [substitue]
    tag2idx = {t: i for i, t in enumerate(tags_vals)}
    return words, sentences, labels, tags_vals, tag2idx


In [9]:
from pytorch_pretrained_bert import BertTokenizer, BertConfig, BertModel
from keras.preprocessing.sequence import pad_sequences
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler



def vectorization(config, sentences, labels, tags_vals, tag2idx):
    #use bert tokenization and substitute label
    #vectorize and pad dataset
    tokenizer = BertTokenizer.from_pretrained(config['name'], do_lower_case=config['do_lower_case'])

    mytexts = []
    mylabels = []
    for sent, tags in zip(sentences,labels):
        BERT_texts = []
        BERT_labels = np.array([])
        for word, tag in zip(sent.split(),tags):
            sub_words = tokenizer.tokenize(word)
            n_underscore = sub_words.count('_') 
            for i in range(n_underscore):
                sub_words.remove('_')
            tags = np.array([tag for x in sub_words])
            tags[1:] = config['substitue']
            BERT_texts += sub_words
            BERT_labels = np.append(BERT_labels,tags)
        mytexts.append(BERT_texts)
        mylabels.append(BERT_labels)

    l = 0
    for w in mytexts:
        if len(w)>l:
            l = len(w)
    print('The longest sentence has {} tokens.'.format(l))

    MAX_LEN = config['MAX_LEN']
    #padding data
    input_ids = pad_sequences([tokenizer.convert_tokens_to_ids(txt) for txt in mytexts],
                              maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")
    tags = pad_sequences([[tag2idx.get(l) for l in lab] for lab in mylabels],
                         maxlen=MAX_LEN, value=tag2idx["O"], padding="post",
                         dtype="long", truncating="post")
    attention_masks = np.array([[float(i>0) for i in ii] for ii in input_ids])
    data_fold = (input_ids, tags, attention_masks)
    return data_fold

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


Using TensorFlow backend.


In [10]:
def myDataLoader(data_fold, train_index, test_index):
    bs = config['bs']
    input_ids, tags, attention_masks = data_fold
    
    tr_inputs, val_inputs = input_ids[train_index], input_ids[test_index]
    tr_tags, val_tags = tags[train_index], tags[test_index]
    tr_masks, val_masks = attention_masks[train_index], attention_masks[test_index]
    
    tr_inputs = torch.tensor(tr_inputs)
    val_inputs = torch.tensor(val_inputs)
    tr_tags = torch.tensor(tr_tags)
    val_tags = torch.tensor(val_tags)
    tr_masks = torch.tensor(tr_masks)
    val_masks = torch.tensor(val_masks)
    
    train_data = TensorDataset(tr_inputs, tr_masks, tr_tags)
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=bs, drop_last=False)

    valid_data = TensorDataset(val_inputs, val_masks, val_tags)
    valid_sampler = SequentialSampler(valid_data)
    valid_dataloader = DataLoader(valid_data, sampler=valid_sampler, batch_size=bs, drop_last=False)
    dataloader = (train_dataloader, valid_dataloader)
    
    count = np.unique(tr_tags, return_counts=True)[1]
    return dataloader, count

In [11]:
from pytorch_pretrained_bert import BertTokenizer, BertConfig, BertModel
from pytorch_pretrained_bert import BertForTokenClassification, BertAdam

def BuildModel(config, weight=None):
    # change the forward method: do not consider 'X' when computing loss
    def new_forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, weight=weight):
        sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        if labels is not None:
            if weight is not None:
                weight = weight.to(torch.float).to(config['device'])
            loss_fct = nn.CrossEntropyLoss(weight=weight, ignore_index=self.num_labels-1)
            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = (attention_mask.view(-1) == 1) #* (labels.view(-1) != self.num_labels-1)
                active_logits = logits.view(-1, self.num_labels)[active_loss]
                active_labels = labels.view(-1)[active_loss]
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            return loss
        else:
            return logits
    BertForTokenClassification.forward = new_forward
    model = BertForTokenClassification.from_pretrained(config['name'], num_labels=len(tag2idx))
    model.to(config['device'])
    
    return model

In [12]:
from seqeval.metrics import precision_score, recall_score
from sklearn.metrics import confusion_matrix ,f1_score, accuracy_score, classification_report
def test(config, model, dataloader, validation = False):
    #dataloader is only validation data or test data
    model.eval()
    eval_loss, eval_accuracy = 0, 0
    nb_eval_steps, nb_eval_examples = 0, 0
    predictions , true_labels = [], []
    for batch in dataloader:
        batch = tuple(t.to(config['device']) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        
        with torch.no_grad():
            tmp_eval_loss = model(b_input_ids, token_type_ids=None,
                                  attention_mask=b_input_mask, labels=b_labels)
            logits = model(b_input_ids, token_type_ids=None,
                           attention_mask=b_input_mask)
        
        active = ((b_input_mask.view(-1) == 1) * (b_labels.view(-1) != len(tag2idx)-1))
        active_logits = logits.view(-1, len(tag2idx))[active].cpu().numpy()
        active_labels = b_labels.view(-1)[active].cpu().numpy()
        pred_labels = np.argmax(active_logits, axis=1)
#         label_ids = b_labels.to('cpu').numpy()
#         logits = logits.detach().cpu().numpy()
        predictions.append(pred_labels)
        true_labels.append(active_labels)
        
#         tmp_eval_accuracy = np.sum(pred_labels == active_labels) / len(active_labels)
        
        eval_loss += tmp_eval_loss.mean().item()
#         eval_accuracy += tmp_eval_accuracy
        
        nb_eval_examples += b_input_ids.size(0)
        nb_eval_steps += 1
    eval_loss = eval_loss/nb_eval_steps
#     eval_accuracy = eval_accuracy/nb_eval_steps
    predictions = np.concatenate(predictions)
    true_labels = np.concatenate(true_labels)
    
    eval_accuracy = accuracy_score(true_labels, predictions, normalize=True, sample_weight=None)
    f1 = f1_score(true_labels, predictions, average='macro')
    if validation==True:
        return eval_loss, eval_accuracy, f1
    else:
        print("Test loss: {}".format(eval_loss))
        print("Test Accuracy: {}".format(eval_accuracy))
        print("micro F1-Score: {}".format(f1_score(true_labels, predictions, average='micro',)))
        print("macro F1-Score: {}".format(f1))
        print("weighted F1-Score: {}".format(f1_score(true_labels, predictions, average='weighted')))
        
        pred_tags = [tags_vals[p] for p in predictions]
        valid_tags = [tags_vals[l] for l in true_labels]
        counts = [valid_tags.count(tag) for tag in tags_vals]
        cfs_mat = confusion_matrix(valid_tags, pred_tags,tags_vals)
        cfs_with_index = pd.DataFrame(cfs_mat, index = tags_vals,
                      columns = tags_vals)
        cfs_mat_norm = cfs_mat/cfs_mat.sum(axis=1, keepdims = True)
        cfs_with_index_norm = pd.DataFrame(cfs_mat_norm, index = tags_vals,
                      columns = tags_vals)
        print('')
        print('test counts:')
        print(pd.DataFrame(tags_vals,counts))
        print('')
        print(classification_report(valid_tags, pred_tags))
        print('')
        print('Confusion matrix:')
        print(cfs_with_index)
        sn.heatmap(cfs_with_index_norm)
        print('')
        return predictions, true_labels, eval_loss, eval_accuracy, f1

In [13]:
from torch.optim import Adam
import matplotlib.pyplot as plt
import seaborn as sn
from copy import deepcopy

def train(config, model, dataloader, if_plot=True, fold_id=None):
    #the dataloader is the combination of training data and validation data
    epochs = config['epochs']
    max_grad_norm = config['max_grad_norm']
    period = config['period']
    FULL_FINETUNING = config['FULL_FINETUNING']
    
    if FULL_FINETUNING:
        
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'gamma', 'beta']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
             'weight_decay_rate': config['decay']},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
             'weight_decay_rate': 0.0}
        ]
    else:
        param_optimizer = list(model.classifier.named_parameters()) 
        optimizer_grouped_parameters = [{"params": [p for n, p in param_optimizer]}]
    
    optimizer = Adam(optimizer_grouped_parameters, lr=config['lr'])
    
    tr_loss_list = []
    eval_loss_list = []
    eval_acc_list = []
    f1_list = []
    max_acc = 0
    max_f1 = 0
    
    train_dataloader, valid_dataloader = dataloader
    eval_loss, eval_accuracy, f1 = test(config, model, dataloader=valid_dataloader, validation=True)
    # print train loss per epoch
    print('Epoch: {}'.format(0))
    # VALIDATION on validation set
    print("Validation loss: {}".format(eval_loss))
    print("Validation Accuracy: {}".format(eval_accuracy))
    print("F1-Score: {}".format(f1))
    print('')
    
    for epoch in range(1, epochs+1):
        # TRAIN loop
        model.train()
        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0

        for step, batch in enumerate(train_dataloader):
            # add batch to gpu
            batch = tuple(t.to(config['device']) for t in batch)
            b_input_ids, b_input_mask, b_labels = batch
            # forward pass
            loss = model(b_input_ids, token_type_ids=None,
                         attention_mask=b_input_mask, labels=b_labels)
            # backward pass
            loss.backward()
            # track train loss
            tr_loss += loss.item()
            nb_tr_examples += b_input_ids.size(0)
            nb_tr_steps += 1
            # gradient clipping
            torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=max_grad_norm)
            # update parameters
            optimizer.step()
            model.zero_grad()
        tr_loss = tr_loss/nb_tr_steps
        eval_loss, eval_accuracy, f1 = test(config, model, valid_dataloader, validation = True)
        
        if f1>max_f1:
            max_acc = eval_accuracy
            max_f1 = f1
            best_model = deepcopy(model)
            
        tr_loss_list.append(tr_loss)
        eval_loss_list.append(eval_loss)
        eval_acc_list.append(eval_accuracy)
        f1_list.append(f1)
        
        if epoch % period == 0:
            # print train loss per epoch
            print('Epoch: {}'.format(epoch))
            print("Train loss: {}".format(tr_loss))
            # VALIDATION on validation set
            print("Validation loss: {}".format(eval_loss))
            print("Validation Accuracy: {}".format(eval_accuracy))
            print("F1-Score: {}".format(f1))
            print('')
            
    
    print('The best result: ')
    print('Validation Accuracy: {}, F1-Score: {}'.format(max_acc, max_f1))
    
    if if_plot:
#     pk.dump((tr_loss_list, eval_loss_list, eval_acc_list, f1_list), open("results/train_result.pkl",'wb'))
    
        ax1=plt.subplot(1, 3, 1)
        ax2=plt.subplot(1, 3, 2)
        ax3=plt.subplot(1, 3, 3)

        ax1.plot(tr_loss_list)
        ax1.plot(eval_loss_list)

        ax2.plot(eval_acc_list)

        ax3.plot(f1_list)
        plt.show()
        plt.savefig('results/train_img{}.png'.format(fold_id))
    
    return best_model, max_acc, max_f1

In [14]:
with open('data/cve_desc.pickle', 'rb') as f:
    cve_desc = pk.load(f)


In [15]:
data = pd.read_csv("data/dataset.csv", encoding="latin1").fillna(method="ffill")
count_label = data.groupby('label')['sent_ind'].count()
# # sentence id
# i=0
# cve_sent_ind = 0
# sent_ind = data.loc[0]['sent_ind']
# while i<len(data):
#     d = data.loc[i]
#     if d['sent_ind']==sent_ind:
#         data.loc[i, 'cve_sent_ind']=cve_sent_ind
#         if d['token']=='.':
#             cve_sent_ind=cve_sent_ind+1
#         i+=1
#     else:
#         sent_ind = d['sent_ind']
#         cve_sent_ind = 0
        
# words = get_groups(data, if_add=False, cve_cpe_pnames = cve_cpe_pnames)
# sentences = [" ".join([s[0] for s in sent]) for sent in words]
# labels = [[s[1] for s in sent] for sent in words]
# substitue = config['substitue']
# tags_vals = list(set(data["label"].values)) + [substitue]
# tag2idx = {t: i for i, t in enumerate(tags_vals)}

# grouped = data.groupby(['sent_ind','cve_sent_ind'])
# # grouped = [g for g in grouped]
# words_wo_tag = [[w[0] for w in s] for s in words]
# words_group = list(zip(grouped.groups.keys(),words_wo_tag))

In [16]:
cve_desc = {cve_id:{'description' : cve_info['description'].lower(),
                   'cpes' : [cpe.lower() for cpe in cve_info['cpes'] if cpe[8]=='a']} for (cve_id,cve_info) in cve_desc.items()}

cve_desc = {cve_id:cve_desc[cve_id] for cve_id in cve_desc if cve_desc[cve_id]['cpes']}


In [19]:
cve_desc

{'CVE-2005-3331': {'description': 'viewpatch in mgdiff 1.0 allows local users to overwrite arbitrary files via a symlink attack on temporary files.',
  'cpes': ['cpe:2.3:a:rogers_software_source:mgdiff_patch_viewer:1.0:*:*:*:*:*:*:*']},
 'CVE-2005-3330': {'description': 'the _httpsrequest function in snoopy 1.2, as used in products such as (1) magpierss, (2) wordpress, (3) ampache, and (4) jinzora, allows remote attackers to execute arbitrary commands via shell metacharacters in an https url to an ssl protected web page, which is not properly handled by the fetch function.',
  'cpes': ['cpe:2.3:a:snoopy:snoopy:1.2:*:*:*:*:*:*:*']},
 'CVE-2005-3333': {'description': 'sql injection vulnerability in ebaseweb 3.0 allows remote attackers to execute arbitrary sql commands via unknown attack vectors.',
  'cpes': ['cpe:2.3:a:ebase:ebaseweb:3.0:*:*:*:*:*:*:*']},
 'CVE-2005-3332': {'description': 'php remote file include vulnerability in admin/define.inc.php in belchior foundry vcard 2.9 allows 

In [20]:
descriptions = {i:cve_desc[i]['description'] for i in cve_desc}

In [121]:
import re
cve_wo_js = {}
cve_w_js = set()
js_dict = {}
# pop_package = ['Angular', 'React', 'Vue', 'Ember', 'Meteor', 'Mithril', 'Node', 'Polymer', 'Aurelia', 'Backbone']
for cve_id in cve_desc:
    for cpe in cve_desc[cve_id]['cpes']:
        js_name_list = re.findall('[\w\-\.]*\.js(?!p|on)', cpe)
        if js_name_list:
            for js_name in js_name_list:
                cve_w_js.add(cve_id)
                js_name = js_name[:-3]
                if js_name not in js_dict:
                    js_dict[js_name] = []
                js_dict[js_name].append(cve_id)
        else:
            cve_wo_js[cve_id] = cve_desc[cve_id]
            
    desc = cve_desc[cve_id]['description']
    js_name_list = re.findall('[\w\-\.]*\.js(?!p|on)', desc)
    if js_name_list:
        for js_name in js_name_list:
            cve_w_js.add(cve_id)
            js_name = js_name[:-3]
            if js_name not in js_dict:
                js_dict[js_name] = []
            js_dict[js_name].append(cve_id)
    else:
        cve_wo_js[cve_id] = cve_desc[cve_id]

In [122]:
re.findall('[\w\-\.]*\.js(?!p|on)', cve_desc['CVE-2017-6818']['description'])

['tags-box.js']

In [123]:

for cve_id in cve_w_js:
    if 'js' in cve_desc[cve_id]['description']:
        print(cve_desc[cve_id]['description'],'\n',cve_desc[cve_id]['cpes'],'\n')

the microsoft azure active directory passport (aka passport-azure-ad) library 1.x before 1.4.6 and 2.x before 2.0.1 for node.js does not recognize the validateissuer setting, which allows remote attackers to bypass authentication via a crafted token. 
 ['cpe:2.3:a:microsoft:azure_active_directory_passport:1.0.0:*:*:*:*:*:*:*', 'cpe:2.3:a:microsoft:azure_active_directory_passport:1.1.0:*:*:*:*:*:*:*', 'cpe:2.3:a:microsoft:azure_active_directory_passport:1.1.1:*:*:*:*:*:*:*', 'cpe:2.3:a:microsoft:azure_active_directory_passport:1.2.0:*:*:*:*:*:*:*', 'cpe:2.3:a:microsoft:azure_active_directory_passport:1.3.0:*:*:*:*:*:*:*', 'cpe:2.3:a:microsoft:azure_active_directory_passport:1.3.1:*:*:*:*:*:*:*', 'cpe:2.3:a:microsoft:azure_active_directory_passport:1.3.2:*:*:*:*:*:*:*', 'cpe:2.3:a:microsoft:azure_active_directory_passport:1.3.3:*:*:*:*:*:*:*', 'cpe:2.3:a:microsoft:azure_active_directory_passport:1.3.4:*:*:*:*:*:*:*', 'cpe:2.3:a:microsoft:azure_active_directory_passport:1.3.5:*:*:*:*:*:*:

In [124]:
len(cve_desc)

93833

In [125]:
len(cve_wo_js)

93703

In [126]:
cve_desc['CVE-2010-3355']

{'description': 'ember 0.5.7 places a zero-length directory name in the ld_library_path, which allows local users to gain privileges via a trojan horse shared library in the current working directory.',
 'cpes': ['cpe:2.3:a:erik_hjortsberg:ember:0.5.7:*:*:*:*:*:*:*']}

In [127]:
js_dict_count = {k:len(v) for (k,v) in js_dict.items()}
js_dict_count

{'nsloginmanagerprompter': 1,
 'node': 1388,
 'mapbox': 2,
 'login': 2,
 'globals': 2,
 'swfobject': 1,
 'httplisteneredit': 2,
 '': 11,
 'serc': 4,
 'datachannel': 1,
 'rtcmulticonnection': 1,
 'xterm': 5,
 'share': 2,
 'viewcode': 1,
 'spaw_script': 1,
 'horizon.instances': 1,
 'clickstream': 1,
 'functions': 6,
 'normalization': 2,
 'printer': 1,
 'index': 9,
 'switch': 1,
 'regress-410192': 1,
 'javascript-static': 1,
 'sshprofiles': 2,
 'encryptionprofiles': 2,
 'sqlite': 3,
 'mssql': 3,
 'flaudio': 2,
 'flvideo': 2,
 'build': 2,
 'buildutil': 2,
 'ta_loaded': 1,
 'sftp': 1,
 'controlsocket': 1,
 'relatedobjectlookups': 1,
 'applications': 1,
 'configuration': 1,
 'custommbeans': 1,
 'resources': 2,
 'registration': 1,
 'webservicesgeneral': 1,
 'auditmoduleedit': 1,
 'jdbcresourceedit': 1,
 'next': 23,
 'asm': 4,
 'webchannel': 1,
 'og': 1,
 'nw': 4,
 'ember': 128,
 'handlebars': 2,
 'mustache': 2,
 'xss': 1,
 'comment2': 1,
 'public': 4,
 'plugin': 1,
 'textangular-sanitize': 1,

In [6]:
js_dict.keys(ym)

NameError: name 'js_dict' is not defined

In [19]:
js_dict['']

['CVE-2009-2664',
 'CVE-2008-0124',
 'CVE-2018-6389',
 'CVE-2017-14749',
 'CVE-2017-17092',
 'CVE-2006-5031',
 'CVE-2009-2419',
 'CVE-2010-3155',
 'CVE-2017-16139',
 'CVE-2005-4855',
 'CVE-2007-1044']

In [28]:
import re
word = 'jquery'
a = [k for (k,v) in descriptions.items() if re.search('(?<!\w)'+word+'(?!\.js|\w)',v)]
b = [k for (k,v) in descriptions.items() if re.search('(?<!\w)'+word+'\.js(?!on|p)',v)]

In [29]:
print(len(a))
print(len(b))

32
1


In [34]:
# i=-1
i+=1
cve_desc[a[i]]

{'description': 'multiple cross-site request forgery (csrf) vulnerabilities in the crossslide jquery (crossslide-jquery-plugin-for-wordpress) plugin 2.0.5 for wordpress allow remote attackers to hijack the authentication of administrators for requests that (1) change plugin settings or conduct cross-site scripting (xss) attacks via the (2) csj_width, (3) csj_height, (4) csj_sleep, (5) csj_fade, or (6) upload_image parameter in the thisismyurl_csj.php page to wp-admin/options-general.php.',
 'cpes': ['cpe:2.3:a:crossslide_jquery_project:crossslide_jquery:2.0.5:*:*:*:*:wordpress:*:*']}

In [36]:
a[i]

'CVE-2015-2089'

In [134]:
x = [k for (k,v) in descriptions.items() if re.search('(?<!\w)java\s*script(?!\w)',v)]
print(len(x))

2054


In [23]:
any_js = set()
for cve_id in descriptions:
    desc = descriptions[cve_id]
    if re.search('js', desc):
        any_js.add(cve_id)
        print(cve_id)
        print(cve_desc[cve_id])
        print()

CVE-2015-6055
{'description': 'the microsoft (1) vbscript 5.7 and 5.8 and (2) jscript 5.7 and 5.8 engines, as used in internet explorer 8 through 11 and other products, allow remote attackers to execute arbitrary code or cause a denial of service (memory corruption) via crafted filter arguments, aka "scripting engine memory corruption vulnerability."', 'cpes': ['cpe:2.3:a:microsoft:jscript:5.6:*:*:*:*:*:*:*', 'cpe:2.3:a:microsoft:jscript:5.7:*:*:*:*:*:*:*', 'cpe:2.3:a:microsoft:jscript:5.8:*:*:*:*:*:*:*', 'cpe:2.3:a:microsoft:vbscript:5.6:*:*:*:*:*:*:*', 'cpe:2.3:a:microsoft:vbscript:5.7:*:*:*:*:*:*:*', 'cpe:2.3:a:microsoft:vbscript:5.8:*:*:*:*:*:*:*', 'cpe:2.3:a:microsoft:internet_explorer:8:*:*:*:*:*:*:*', 'cpe:2.3:a:microsoft:internet_explorer:9:*:*:*:*:*:*:*', 'cpe:2.3:a:microsoft:internet_explorer:10:*:*:*:*:*:*:*', 'cpe:2.3:a:microsoft:internet_explorer:11:-:*:*:*:*:*:*']}

CVE-2011-0663
{'description': 'multiple integer overflows in the microsoft (1) jscript 5.6 through 5.8 and 

{'description': 'multiple cross-site scripting (xss) vulnerabilities in the help jsp scripts in sun java web console 3.0.2 through 3.0.5, and sun java web console in solaris 10, allow remote attackers to inject arbitrary web script or html via unspecified vectors.', 'cpes': ['cpe:2.3:a:sun:java_web_console:3.0.2:*:*:*:*:*:*:*', 'cpe:2.3:a:sun:java_web_console:3.0.2:*:linux:*:*:*:*:*', 'cpe:2.3:a:sun:java_web_console:3.0.2:*:solaris8_sparc:*:*:*:*:*', 'cpe:2.3:a:sun:java_web_console:3.0.2:*:solaris8_x86:*:*:*:*:*', 'cpe:2.3:a:sun:java_web_console:3.0.2:*:solaris9_sparc:*:*:*:*:*', 'cpe:2.3:a:sun:java_web_console:3.0.2:*:windows:*:*:*:*:*', 'cpe:2.3:a:sun:java_web_console:3.0.3:*:*:*:*:*:*:*', 'cpe:2.3:a:sun:java_web_console:3.0.3:*:linux:*:*:*:*:*', 'cpe:2.3:a:sun:java_web_console:3.0.3:*:solaris9_sparc:*:*:*:*:*', 'cpe:2.3:a:sun:java_web_console:3.0.3:*:solaris9_x86:*:*:*:*:*', 'cpe:2.3:a:sun:java_web_console:3.0.3:*:windows:*:*:*:*:*', 'cpe:2.3:a:sun:java_web_console:3.0.4:*:*:*:*:*:*


CVE-2018-19775
{'description': 'cross site scripting exists in infovista vistaportal se version 5.1 (build 51029). the page "variables.jsp" has reflected xss via the connpoolname and groupid parameters.', 'cpes': ['cpe:2.3:a:infovista:vistaportal:5.1:*:*:*:standard:*:*:*']}

CVE-2018-19774
{'description': 'cross site scripting exists in infovista vistaportal se version 5.1 (build 51029). the page "presentspace.jsp" has reflected xss via the groupid and connpoolname parameters.', 'cpes': ['cpe:2.3:a:infovista:vistaportal:5.1:*:*:*:standard:*:*:*']}

CVE-2018-19771
{'description': 'cross site scripting exists in infovista vistaportal se version 5.1 (build 51029). the page "editcurrentpool.jsp" has reflected xss via the propname parameter.', 'cpes': ['cpe:2.3:a:infovista:vistaportal:5.1:*:*:*:standard:*:*:*']}

CVE-2018-19770
{'description': 'cross site scripting exists in infovista vistaportal se version 5.1 (build 51029). the page "users.jsp" has reflected xss via the connpoolname para

CVE-2010-2433
{'description': 'multiple cross-site scripting (xss) vulnerabilities in content/internalerror.jsp in ibm websphere ilog jrules 6.7 allow remote attackers to inject arbitrary web script or html via an rts url to (1) explore/explore.jsp, (2) compose/compose.jsp, or (3) home.jsp in faces/.', 'cpes': ['cpe:2.3:a:ibm:websphere_ilog_jrules:6.7:*:*:*:*:*:*:*']}

CVE-2005-3966
{'description': 'cross-site scripting (xss) vulnerability in search.jsp in java search engine (jse) 0.9.34 allows remote attackers to inject arbitrary web script or html via the q parameter.', 'cpes': ['cpe:2.3:a:java_search_engine:java_search_engine:0.9.34:*:*:*:*:*:*:*']}

CVE-2012-4937
{'description': 'session fixation vulnerability in the web interface in pattern insight 2.3 allows remote attackers to hijack web sessions via a jsession_id cookie.', 'cpes': ['cpe:2.3:a:patterninsight:pattern_insight:2.3:*:*:*:*:*:*:*']}

CVE-2013-2041
{'description': 'multiple cross-site scripting (xss) vulnerabilities i

In [268]:
len(any_js)

1728

In [132]:
any_js = set()
for cve_id in cve_desc:
    cpes = cve_desc[cve_id]['cpes']
    for cpe in cpes:
        if re.search('java[\s*]script', cpe):
            any_js.add(cve_id)
            print(cve_id)
            print(cve_desc[cve_id])
            print()

In [35]:
# i=-1
# i+=1
cve_desc[a[i]]

{'description': 'multiple cross-site request forgery (csrf) vulnerabilities in the crossslide jquery (crossslide-jquery-plugin-for-wordpress) plugin 2.0.5 for wordpress allow remote attackers to hijack the authentication of administrators for requests that (1) change plugin settings or conduct cross-site scripting (xss) attacks via the (2) csj_width, (3) csj_height, (4) csj_sleep, (5) csj_fade, or (6) upload_image parameter in the thisismyurl_csj.php page to wp-admin/options-general.php.',
 'cpes': ['cpe:2.3:a:crossslide_jquery_project:crossslide_jquery:2.0.5:*:*:*:*:wordpress:*:*']}

In [172]:
i=0
cve_desc[x[i]]

{'description': 'multiple cross-site scripting (xss) vulnerabilities in mantis before 0.19.3 allow remote attackers to inject arbitrary web script or html via (1) unknown vectors involving javascript and (2) mantis/view_all_set.php.',
 'cpes': ['cpe:2.3:a:mantis:mantis:0.19.0:*:*:*:*:*:*:*',
  'cpe:2.3:a:mantis:mantis:0.19.0_rc1:*:*:*:*:*:*:*',
  'cpe:2.3:a:mantis:mantis:0.19.0a1:*:*:*:*:*:*:*',
  'cpe:2.3:a:mantis:mantis:0.19.0a2:*:*:*:*:*:*:*',
  'cpe:2.3:a:mantis:mantis:0.19.1:*:*:*:*:*:*:*',
  'cpe:2.3:a:mantis:mantis:0.19.2:*:*:*:*:*:*:*',
  'cpe:2.3:a:mantis:mantis:0.19.3:*:*:*:*:*:*:*']}

In [173]:
x[0]

'CVE-2005-3337'

In [163]:
i #inject

0

In [141]:
i #execute

4

In [144]:
i #embed

6

In [149]:
i #jscript.dll

10

In [157]:
i #svg file with js

14

In [139]:
i #bypass javascript api execution restrictions

3

In [156]:
i #impact via the array index of the arguments array in a javascript function

13

In [166]:
i #can result in javascript execution

21

In [330]:
# cve_wo_js_desc = {}
# cve_w_js_desc = []
pop_package = ['angular', 'react', 'vue', 'ember', 'meteor', 'mithril', 'node', 'polymer', 'aurelia', 'backbone'] + ['Angular', 'React', 'Vue', 'Ember', 'Meteor', 'Mithril', 'Node', 'Polymer', 'Aurelia', 'Backbone']
for cve_id in cve_wo_js:
    for desc in cve_wo_js[cve_id]['description']:
        for p in pop_package:
            if p in cpe:
                print(desc)
#                 cve_w_js_desc.append(cve_id)
#                 break
#     else:
#         cve_wo_js_desc[cve_id] = cve_wo_js[cve_id]

In [324]:
len(cve_wo_js_desc)

118311

In [41]:
oo = []
for cve_id in cve_desc:
    for cpe in cve_desc[cve_id]['cpes']:
        if cpe[3]!=':' or cpe[3]!=':' or cpe[3]!=':' or cpe[9]!=':' :
            oo.append(cve_id)
            break

In [42]:
oo

[]

In [18]:
import copy
cve_desc_ori = copy.deepcopy(cve_desc)
for cve_id in cve_desc_green:
    cpes = []
    for cpe in cve_desc[cve_id]['cpes']:
        if cpe[8]=='a':
            cpe_ = cpe[10:]
            p1 = cpe_.find(':')
            p2 = cpe_[p1+1:].find(':')
            new_cpe = cpe_[:p1] + ' ' + cpe_[p1+1:p1+p2+1]
            if new_cpe not in cpes:
                cpes.append(new_cpe)
    
    cve_desc_ori[cve_id]['cpes'] = cpes

NameError: name 'cve_desc_green' is not defined

In [22]:
import copy
cve_desc_ori = copy.deepcopy(cve_desc)
for cve_id in cve_desc_green:
    cpes = []
    for cpe in cve_desc[cve_id]['cpes']:
        if cpe[8]=='a':
            cpe_ = cpe[10:]
            p1 = cpe_.find(':')
            p2 = cpe_[p1+1:].find(':')
            new_cpe = cpe_[p1+1:p1+p2+1]
            if new_cpe not in cpes:
                cpes.append(new_cpe)
    
    cve_desc_ori[cve_id]['cpes'] = cpes

In [23]:
cve_desc_ori

{'CVE-2005-3331': {'description': 'viewpatch in mgdiff 1.0 allows local users to overwrite arbitrary files via a symlink attack on temporary files.',
  'cpes': ['mgdiff_patch_viewer']},
 'CVE-2005-3330': {'description': 'the _httpsrequest function in snoopy 1.2, as used in products such as (1) magpierss, (2) wordpress, (3) ampache, and (4) jinzora, allows remote attackers to execute arbitrary commands via shell metacharacters in an https url to an ssl protected web page, which is not properly handled by the fetch function.',
  'cpes': ['snoopy']},
 'CVE-2005-3333': {'description': 'sql injection vulnerability in ebaseweb 3.0 allows remote attackers to execute arbitrary sql commands via unknown attack vectors.',
  'cpes': ['ebaseweb']},
 'CVE-2005-3332': {'description': 'php remote file include vulnerability in admin/define.inc.php in belchior foundry vcard 2.9 allows remote attackers to execute arbitrary php code via the match parameter.',
  'cpes': ['vcard']},
 'CVE-2005-3335': {'desc

In [279]:
for cve_id in cve_desc_ori:
    for cpe in cve_desc_ori[cve_id]['cpes']:
        if re.search('[0-9][^\d{1}][0-9]',cpe):
            print(cve_id,cpe)

CVE-2012-2023 adobe illustrator_cs5.5
CVE-2012-2026 adobe illustrator_cs5.5
CVE-2012-2024 adobe illustrator_cs5.5
CVE-2012-2025 adobe illustrator_cs5.5
CVE-2009-0761 team5.team_board 1.0
CVE-2009-0761 team5.team_board 1.0.1
CVE-2009-0761 team5.team_board 1.0.2
CVE-2009-0761 team5.team_board 1.0.3
CVE-2009-0761 team5.team_board 1.0.4
CVE-2009-0761 team5.team_board 1.0.5
CVE-2001-0955 xfree86_project x11r6
CVE-2012-2028 adobe photoshop_cs5.5
CVE-2012-2027 adobe photoshop_cs5.5
CVE-2014-0176 redhat cloudforms_3.0_management_engine
CVE-2008-6761 china-on-site flexcustomer0.0.6
CVE-2010-2050 m0r0n com_mscomment
CVE-2004-0419 x.org x11r6
CVE-2012-3533 ovirt-engine-sdk 3.1.0.5
CVE-2009-1381 squirrelmail squirrelmail1.4.19-1
CVE-2014-4624 avamar_virtual_edition 6.0
CVE-2014-4624 avamar_virtual_edition 6.0.402
CVE-2014-4624 avamar_virtual_edition 7.0
CVE-2014-4624 avamar_virtual_edition 7.0.2-43
CVE-2004-0337 software602 602pro_lan_suite
CVE-2004-0335 software602 602pro_lan_suite
CVE-2010-5214 

CVE-2018-14403 techsmith mp4v2
CVE-2017-16250 mitel st14.2
CVE-2017-16251 mitel st14.2
CVE-2016-2275 advantech vesp211-232_firmware
CVE-2004-1513 soft3304 04webserver
CVE-2014-3692 redhat cloudforms_3.1_management_engine
CVE-2013-6717 ibm db2_purescale_feature_9.8
CVE-2017-6753 cisco webex_meetings_server_2.0
CVE-2017-6753 cisco webex_meetings_server_2.0_mr8_patch
CVE-2017-6753 cisco webex_meetings_server_2.0_mr9_patch
CVE-2017-6753 cisco webex_meetings_server_2.5
CVE-2017-6753 cisco webex_meetings_server_2.5_mr2_patch
CVE-2017-6753 cisco webex_meetings_server_2.5_mr5_patch
CVE-2017-6753 cisco webex_meetings_server_2.5_mr6_patch
CVE-2017-6753 cisco webex_meetings_server_2.6
CVE-2017-6753 cisco webex_meetings_server_2.6_mr1_patch
CVE-2017-6753 cisco webex_meetings_server_2.6_mr2_patch
CVE-2017-6753 cisco webex_meetings_server_2.6_mr3_patch
CVE-2017-6753 cisco webex_meetings_server_2.7
CVE-2017-6753 cisco webex_meetings_server_2.7_mr1_patch
CVE-2017-6753 cisco webex_meetings_server_2.7_m

In [24]:
cve_desc_ori['CVE-2007-3520']

{'description': 'sql injection vulnerability in process.php in easybe 1-2-3 music store allows remote attackers to execute arbitrary sql commands via the categoryid parameter.',
 'cpes': ['1-2-3_music_store']}

In [94]:
cve_desc

{'CVE-2017-2185': {'description': 'HOME SPOT CUBE2 firmware V101 and earlier allows authenticated attackers to execute arbitrary OS commands via WebUI.',
  'cpes': ['cpe:2.3:o:kddi:home_spot_cube_2_firmware:v100:*:*:*:*:*:*:*',
   'cpe:2.3:o:kddi:home_spot_cube_2_firmware:v101:*:*:*:*:*:*:*',
   'cpe:2.3:h:kddi:home_spot_cube_2:-:*:*:*:*:*:*:*']},
 'CVE-2005-3331': {'description': 'viewpatch in mgdiff 1.0 allows local users to overwrite arbitrary files via a symlink attack on temporary files.',
  'cpes': ['cpe:2.3:a:rogers_software_source:mgdiff_patch_viewer:1.0:*:*:*:*:*:*:*']},
 'CVE-2005-3330': {'description': 'The _httpsrequest function in Snoopy 1.2, as used in products such as (1) MagpieRSS, (2) WordPress, (3) Ampache, and (4) Jinzora, allows remote attackers to execute arbitrary commands via shell metacharacters in an HTTPS URL to an SSL protected web page, which is not properly handled by the fetch function.',
  'cpes': ['cpe:2.3:a:snoopy:snoopy:1.2:*:*:*:*:*:*:*']},
 'CVE-2005

In [25]:
def string_form(s):
    s = re.sub('[_\-/]', ' ', s) 
    s = re.sub('\.(?![\d])|(?<![\d])\.', ' ', s) #dot not between digits
    s = re.sub(r'[\\]', '', s)
    return s

In [27]:
import copy
import re
cve_desc_green = copy.deepcopy(cve_desc)
for cve_id in cve_desc_green:
    cpes = []
    for cpe in cve_desc_green[cve_id]['cpes']:
        if cpe[8]=='a':
            cpe_ = cpe[10:]
            p1 = cpe_.find(':')
            p2 = cpe_[p1+1:].find(':')
            new_cpe = string_form( cpe_[p1+1:p1+p2+1])
#             new_cpe = re.sub('[_.\-/]', ' ', cpe_[:p1] + ' ' + cpe_[p1+1:p1+p2+1])
#             new_cpe = re.sub(r'[\\]', '', new_cpe)
#             new_cpe = (cpe_[:p1] + ' ' + cpe_[p1+1:p1+p2+1]).replace('_',' ').replace('.',' ').replace('-',' ')
            if new_cpe not in cpes:
                cpes.append(new_cpe)
    cve_desc_green[cve_id]['cpes'] = cpes

In [47]:
cve_desc_green

{'CVE-2005-3331': {'description': 'viewpatch in mgdiff 1.0 allows local users to overwrite arbitrary files via a symlink attack on temporary files.',
  'cpes': ['mgdiff patch viewer']},
 'CVE-2005-3330': {'description': 'the _httpsrequest function in snoopy 1.2, as used in products such as (1) magpierss, (2) wordpress, (3) ampache, and (4) jinzora, allows remote attackers to execute arbitrary commands via shell metacharacters in an https url to an ssl protected web page, which is not properly handled by the fetch function.',
  'cpes': ['snoopy']},
 'CVE-2005-3333': {'description': 'sql injection vulnerability in ebaseweb 3.0 allows remote attackers to execute arbitrary sql commands via unknown attack vectors.',
  'cpes': ['ebaseweb']},
 'CVE-2005-3332': {'description': 'php remote file include vulnerability in admin/define.inc.php in belchior foundry vcard 2.9 allows remote attackers to execute arbitrary php code via the match parameter.',
  'cpes': ['vcard']},
 'CVE-2005-3335': {'desc

In [48]:
cpes_2_cve = {}
for cve_id in cve_desc_green:
    for cpe in cve_desc_green[cve_id]['cpes']:
        if cpe not in cpes_2_cve:
            cpes_2_cve[cpe] = []
        cpes_2_cve[cpe].append(cve_id)

In [44]:
import re
for cve_id in cve_desc_green:
    for cpe in cve_desc_green[cve_id]['cpes']:
        if re.search('\s[0-9]\s',cpe):
            print(cve_id,cpe)

CVE-2000-0303 quake 3 arena
CVE-2009-4250 utf 8 cutenews
CVE-2006-3325 quake 3 engine
CVE-2006-3400 quake 3 engine
CVE-2006-0970 1 2 all
CVE-2007-0757 call of duty 2 dreamstats system
CVE-2015-7876 drupal 7 driver for sql server and sql azure
CVE-2013-6448 jboss seam 2 framework
CVE-2013-6447 jboss seam 2 framework
CVE-2016-5563 hospitality opera 5 property services
CVE-2016-5564 hospitality opera 5 property services
CVE-2016-5565 hospitality opera 5 property services
CVE-2018-2436 r 3 enterprise retail
CVE-2009-4777 jp1 automatic job management system 2 view
CVE-2009-4777 job management partner 1 automatic job management system 2 view
CVE-2009-4777 job management partner 1 integrated management view
CVE-2009-4777 job management partner 1 integrated manager console view
CVE-2009-4777 job management partner 1 integrated manager view
CVE-2009-4777 job management partner 1 performance management snmp system observer
CVE-2009-4777 job management partner 1 snmp system observer
CVE-2006-0320

In [272]:
x = re.sub('/',' ',cve_desc['CVE-2012-5321']['cpes'][0])
re.sub(r'\\','',x)

'cpe:2.3:a:tiki:tikiwiki_cms groupware:8.3:*:*:*:*:*:*:*'

In [256]:
re.sub('\.(?![\d])|(?<![\d])\.', ' ', 'dg.fdgh5_er.5gt-4.6-4.6b-4.y7-v.4-v.d')

'dg fdgh5_er 5gt-4.6-4.6b-4 y7-v 4-v d'

In [115]:
description = {'CVE-2017-2185':'HOME SPOT CUBE2 firmware V101 and earlier allows authenticated attackers to execute arbitrary OS commands via WebUI.'}

In [32]:
import pickle as pk
with open('data/npmPackageNames_2019-05-09_16:11:48.pickle','rb') as f:
    npm = pk.load(f)
npm = [string_form(i) for i in npm]

In [1]:
import pickle as pk
with open('data/npmPackageNames_2019-05-09_16:11:48.pickle','rb') as f:
    npm = pk.load(f)
# npm = [string_form(i) for i in npm]

In [2]:
npm

['0',
 '0-',
 '0----',
 '0-1-project',
 '0-100',
 '0-24',
 '0-60',
 '0-9',
 '0-_-0',
 '0.',
 '0.0',
 '0.0.1',
 '0.0.168',
 '0.0.250',
 '0.1.0',
 '0.1.unity-settings-daemon',
 '0.1f',
 '0.2.18',
 '0.css',
 '0.js',
 '0.workspace',
 '00',
 '00-components',
 '00-test',
 '00.demo',
 '000-webpack',
 '0003-lion-lib',
 '001',
 '001-nodelist',
 '001-npm',
 '001_skt',
 '001_test',
 '001senge',
 '002',
 '002-bao',
 '002-globals',
 '002-npm',
 '002twoweek',
 '003',
 '003-3',
 '003-bingpic',
 '003-npm',
 '003threeweek',
 '004-bingpic',
 '004-module',
 '004-week',
 '005',
 '005-gxp',
 '005-http',
 '005-http-antao',
 '005-http-jin',
 '005-http-open',
 '005-https-cxb',
 '005http',
 '007',
 '008-mysql',
 '008-somepackage',
 '009',
 '01',
 '01-03',
 '01-bibliotheque',
 '01-calc',
 '01-cute',
 '01-dj',
 '01-hhtclac',
 '01-numacert',
 '01-plugin',
 '01-practicemyself',
 '01-simple',
 '01-szhm-wb',
 '01-szhmqd25calc',
 '01-szhmqd27cale',
 '01-upload',
 '01.22pagination',
 '01.test-cnpm',
 '010-static-http-

In [4]:
import nltk
def tokenize_only(text):
    # first tokenize by sentence, then by word to ensure that punctuation is caught as it's own token
    tokens = [word for sent in nltk.sent_tokenize(text) for word in nltk.word_tokenize(sent)]
    filtered_tokens = []
    for token in tokens:
        if re.search('[a-zA-Z0-9]', token):
            filtered_tokens.append(token)
    return filtered_tokens

In [49]:
all_cpes = [cpe for cpe in cpes_2_cve]
text = all_cpes + npm

In [50]:
cpes_2_cve

{'mgdiff patch viewer': ['CVE-2005-3331'],
 'snoopy': ['CVE-2005-3330',
  'CVE-2014-5009',
  'CVE-2014-5008',
  'CVE-2008-7313',
  'CVE-2008-4796',
  'CVE-2009-0502'],
 'ebaseweb': ['CVE-2005-3333'],
 'vcard': ['CVE-2005-3332',
  'CVE-2004-1828',
  'CVE-2006-1230',
  'CVE-2009-3779',
  'CVE-2006-2810',
  'CVE-2006-3474'],
 'mantis': ['CVE-2005-3335',
  'CVE-2005-3339',
  'CVE-2005-3338',
  'CVE-2005-3337',
  'CVE-2005-3336',
  'CVE-2008-4689',
  'CVE-2008-4688',
  'CVE-2002-1113',
  'CVE-2002-1110',
  'CVE-2002-1114',
  'CVE-2005-4521',
  'CVE-2005-4520',
  'CVE-2005-4523',
  'CVE-2008-3332',
  'CVE-2008-3333',
  'CVE-2002-1112',
  'CVE-2008-2276',
  'CVE-2006-0146',
  'CVE-2006-0147',
  'CVE-2005-2556',
  'CVE-2005-2557',
  'CVE-2005-4238',
  'CVE-2006-1577',
  'CVE-2006-6515',
  'CVE-2008-4687',
  'CVE-2003-0499',
  'CVE-2007-6611',
  'CVE-2008-0404',
  'CVE-2006-6574',
  'CVE-2004-2666',
  'CVE-2005-4524',
  'CVE-2002-1111',
  'CVE-2002-1116',
  'CVE-2002-1115',
  'CVE-2004-1734',
 

In [51]:
def tokenized_vocab(text_list):
    totalvocab_tokenized = []
    for i in text_list:

        allwords_tokenized = tokenize_only(i)
        totalvocab_tokenized.extend(allwords_tokenized)
    return totalvocab_tokenized
totalvocab = set(tokenized_vocab(text))

In [52]:
with open('totalvocab.pkl','wb') as f:
    pk.dump(totalvocab, f)

#     totalvocab = pickle.load(f,'rb'))

In [53]:
len(totalvocab)

329075

In [54]:
# # tf-idf vectorizer
from sklearn.feature_extraction.text import TfidfVectorizer

tfidf_vectorizer = TfidfVectorizer(#min_df =10**-3 ,
                                   analyzer = 'word', max_features=len(set(totalvocab)), 
#                                    stop_words=my_stop_words, 
                                   tokenizer=tokenize_only, 
#                                     ngram_range=(1,3)
                                    )



In [72]:
tfidf_vectorizer

TfidfVectorizer(analyzer='word', binary=False, decode_error='strict',
        dtype=<class 'numpy.float64'>, encoding='utf-8', input='content',
        lowercase=True, max_df=1.0, max_features=329075, min_df=1,
        ngram_range=(1, 1), norm='l2', preprocessor=None, smooth_idf=True,
        stop_words=None, strip_accents=None, sublinear_tf=False,
        token_pattern='(?u)\\b\\w\\w+\\b',
        tokenizer=<function tokenize_only at 0x7fd670595488>, use_idf=True,
        vocabulary=None)

In [55]:
tfidf_matrix = tfidf_vectorizer.fit_transform(text)
all_cpes_vec = tfidf_matrix[:len(all_cpes)]
npm_vec = tfidf_matrix[len(all_cpes):]
grade = all_cpes_vec.dot(npm_vec.transpose())

most_arg = np.argmax(grade, axis=1)

most_grade = np.max(grade, axis=1)
all_match = list(zip(np.array(all_cpes),np.array(npm)[most_arg].reshape((len(all_cpes))), most_grade.toarray().reshape(-1)))

In [175]:
cve_desc_green[cpes_2_cve['video'][2]]

{'description': 'sql injection vulnerability in default.asp in efestech video 5.0 allows remote attackers to execute arbitrary sql commands via the catid parameter.',
 'cpes': ['video']}

In [176]:
cve_desc[cpes_2_cve['video'][2]]

{'description': 'sql injection vulnerability in default.asp in efestech video 5.0 allows remote attackers to execute arbitrary sql commands via the catid parameter.',
 'cpes': ['cpe:2.3:a:efestech:video:5.0:*:*:*:*:*:*:*']}

In [56]:
with open( "tfidf_matrix.pkl", "wb" ) as f:
    pk.dump(tfidf_matrix, f)
    # tfidf_matrix = pickle.load(f)
with open( "tfidf_vectorizer.pkl", "wb" ) as f:
    pk.dump(tfidf_vectorizer, f)
    # tfidf_vectorizer = pickle.load(f)

In [57]:
all_cpes_vec = tfidf_matrix[:len(all_cpes)]
npm_vec = tfidf_matrix[len(all_cpes):]

In [58]:
grade = all_cpes_vec.dot(npm_vec.transpose())

most_arg = np.argmax(grade, axis=1)

most_grade = np.max(grade, axis=1)

In [59]:
all_match = list(zip(np.array(all_cpes),np.array(npm)[most_arg].reshape((len(all_cpes))), most_grade.toarray().reshape(-1)))

In [161]:
all_match

[('mgdiff patch viewer', 'patch', 0.46975437103935863),
 ('snoopy', 'snoopy', 1.0),
 ('ebaseweb', '0', 0.0),
 ('vcard', 'vcard', 1.0),
 ('mantis', 'mantis', 1.0),
 ('flyspray', '0', 0.0),
 ('python', 'python', 1.0),
 ('postgresql', 'postgresql', 1.0),
 ('openssh', 'openssh rsa dsa parse', 0.6189807681860662),
 ('unified threat management software', 'threat', 0.5886744885675435),
 ('tomee', '0', 0.0),
 ('simplesamlphp', '0', 0.0),
 ('saml2', 'saml2', 1.0),
 ('gstreamer', 'gstreamer', 1.0),
 ('libav', '0', 0.0),
 ('jscript', 'jscript', 1.0),
 ('vbscript', 'vbscript to typescript', 0.8258526671011653),
 ('internet explorer', 'internet explorer version', 0.862816514766701),
 ('firefox', 'firefox', 1.0),
 ('backupninja', '0', 0.0),
 ('pillow', 'pillow', 1.0),
 ('remote device access virtual customer access system',
  'remote access',
  0.7083000871554723),
 ('connections business directory plugin', 'connections', 0.6321427907792957),
 ('samba', 'samba', 1.0),
 ('infosphere information serve

In [168]:
match_ranges = {round(10*k)/10:len([i for i in all_match if k+0.1>i[2]>=k]) for k in np.arange(1,-0.1,-0.1)}

In [169]:
print(match_ranges)
print(np.array(list(match_ranges.values()))/sum(match_ranges.values()))

{1.0: 3556, 0.9: 642, 0.8: 2311, 0.7: 4682, 0.6: 4445, 0.5: 4352, 0.4: 2337, 0.3: 341, 0.2: 16, 0.1: 0, 0.0: 0}
[0.1567763  0.02830438 0.10188696 0.20641919 0.19597037 0.19187021
 0.10303324 0.01503395 0.00070541 0.         0.        ]


In [216]:
has_list = []
for cve_id in cve_desc_green:
    cpes = cve_desc_green[cve_id]['cpes']
    alltext = ' '.join(cpes) +cve_desc_green[cve_id]['description']
    for cpe in cpes:
        if re.search('jquery', alltext):
#         if cpe=='fuse':
            has_list.append(cve_id)


In [327]:
# i=-1
i+=1
print(i)
cve_desc[has_list[i]]

1


{'description': 'mtappjquery 1.8.1 and earlier allows remote php code execution via unspecified vectors.',
 'cpes': ['cpe:2.3:a:bit-part:mtappjquery:*:*:*:*:*:movabletype:*:*']}

In [278]:
len(has_list)

46

In [210]:
i=-1
i+=1
cve_desc[a[i]]

{'description': 'cross-site scripting (xss) vulnerability in the datatables plugin 1.10.8 and earlier for jquery allows remote attackers to inject arbitrary web script or html via the scripts parameter to media/unit_testing/templates/6776.php.',
 'cpes': ['cpe:2.3:a:sprymedia:datatables:*:*:*:*:*:jquery:*:*']}

In [162]:
match_100 = [i for i in all_match if 1<=i[2]]
len(match_100)

3556

In [163]:
match_100

[('snoopy', 'snoopy', 1.0),
 ('vcard', 'vcard', 1.0),
 ('mantis', 'mantis', 1.0),
 ('python', 'python', 1.0),
 ('postgresql', 'postgresql', 1.0),
 ('saml2', 'saml2', 1.0),
 ('gstreamer', 'gstreamer', 1.0),
 ('jscript', 'jscript', 1.0),
 ('firefox', 'firefox', 1.0),
 ('pillow', 'pillow', 1.0),
 ('samba', 'samba', 1.0),
 ('vlc media player', 'vlc media player', 1.0),
 ('openrefine', 'openrefine', 1.0),
 ('popcorn', 'popcorn', 1.0),
 ('freetype', 'freetype', 1.0),
 ('safari', 'safari', 1.0),
 ('cups', 'cups', 1.0),
 ('ie', 'ie', 1.0),
 ('opencv', 'opencv', 1.0),
 ('qpdf', 'qpdf', 1.0),
 ('office', 'office', 1.0),
 ('server', 'Server', 1.0),
 ('visio', 'visio', 1.0),
 ('works', 'works', 1.0),
 ('ftp', 'ftp', 1.0),
 ('sendmail', 'sendmail', 1.0),
 ('mesa', 'mesa', 1.0),
 ('mutt', 'mutt', 1.0),
 ('ssh', 'ssh', 1.0),
 ('global', 'global', 1.0),
 ('binutils', 'binutils', 1.0),
 ('word', 'word', 1.0),
 ('excel', 'excel', 1.0),
 ('zero cms', 'cms zero', 1.0),
 ('endless', 'endless', 1.0),
 ('unz

In [63]:
match_90 = [i for i in all_match if 1>i[2]>=0.9]
len(match_90)

642

In [64]:
match_90

[('wireshark', 'node wireshark', 0.9401619617812456),
 ('jrun', 'jrun js', 0.9424776543929972),
 ('report viewer', 'report viewer', 0.9999999999999999),
 ('poppler', 'poppler simple', 0.9054465550396035),
 ('typo3', 'generator typo3', 0.9185011275897756),
 ('exchange server', 'exchange test server', 0.9018773456407552),
 ('teamspeak3', 'teamspeak3 client', 0.9211198351487024),
 ('libshout', 'libshout js', 0.9424776543929972),
 ('application server', 'react application server', 0.934701271966239),
 ('hylafax', 'hylafax client', 0.9211198351487024),
 ('skeleton theme', 'react theme skeleton', 0.9410919140464948),
 ('virtuemart', 'virtuemart api', 0.9233583537857097),
 ('business manager', 'business manager', 0.9999999999999998),
 ('plesk', '@plesk plesk ext sdk', 0.9175482888617823),
 ('apple tv', 'apple tv', 0.9999999999999998),
 ('libpng', 'node libpng', 0.9424344106729406),
 ('cloud backup', 'cloud backup server', 0.9007764213332878),
 ('weblogic server', 'weblogic', 0.905029383988992

In [165]:
match_80 = [i for i in all_match if 0.8>i[2]>=0.7]
len(match_80)

4682

In [166]:
match_80

[('remote device access virtual customer access system',
  'remote access',
  0.7083000871554723),
 (' net framework', 'net', 0.7433681328714203),
 ('forefront client security', 'forefront', 0.774027282093347),
 ('office powerpoint viewer', 'powerpoint', 0.7030576224952497),
 ('excel viewer', 'excel', 0.7478987568469974),
 ('lonely maple', 'lonely', 0.7662519265813593),
 ('orange cutout', 'cutout', 0.787237978060705),
 ('adaptive server enterprise', 'adaptive', 0.7055013040681335),
 ('icloud for windows', 'icloud', 0.748833266672302),
 ('insite', 'insite infrajs', 0.714557201677125),
 ('command antivirus', 'antivirus', 0.7858552123523799),
 ('scan engine', 'scan', 0.7919962665133217),
 ('panda antivirus', 'antivirus', 0.7194740512688714),
 ('anti malware', 'malware', 0.7521857296892208),
 ('endpoint protection', 'protection', 0.7250184903761094),
 ('rumpus ftp server', 'rumpus', 0.7814633091203284),
 ('orange web server', 'orange', 0.7686286187613062),
 ('dive assistant', 'dive', 0.776

In [67]:
with open('data/npm_package_name_with_matching_cveid.pickle','rb') as f:
    npm_match_ground_truth = pk.load(f)

In [68]:
npm_match_ground_truth

[('datatables', 'CVE-2015-6584'),
 ('grunt-webdriver-qunit', 'CVE-2016-10606'),
 ('panels.js', 'CVE-2015-8861'),
 ('sfml', 'CVE-2016-10654'),
 ('ember', 'CVE-2014-0014'),
 ('looppake', 'CVE-2017-16169'),
 ('kibana', 'CVE-2017-8452'),
 ('backbone', 'CVE-2016-10537'),
 ('ldapauth-fork', 'CVE-2015-7294'),
 ('node', 'CVE-2016-5325'),
 ('windows-selenium-chromedriver', 'CVE-2016-10687'),
 ('cobalt-cli', 'CVE-2016-10597'),
 ('zeroclipboard', 'CVE-2014-1869'),
 ('electron', 'CVE-2017-12581'),
 ('keycloak-auth-utils', 'CVE-2017-7474'),
 ('pandora-doomsday', 'CVE-2017-16127'),
 ('cue-sdk-node', 'CVE-2016-10590'),
 ('gaoxuyan', 'CVE-2017-16153'),
 ('node', 'CVE-2017-14919'),
 ('bkjs-wand', 'CVE-2016-10571'),
 ('mediaelement', 'CVE-2016-4567'),
 ('string', 'CVE-2017-16116'),
 ('dylmomo', 'CVE-2017-16163'),
 ('assembly', 'CVE-2015-8861'),
 ('galenframework-cli', 'CVE-2016-10560'),
 ('mapbox.js', 'CVE-2017-1000043'),
 ('nodemailer.js', 'CVE-2017-16072'),
 ('lz4', 'CVE-2014-4611'),
 ('kibana', 'CVE-

In [185]:
a = [i[1] for i in npm_match_ground_truth]
b = [string_form(i[0]) for i in npm_match_ground_truth]

In [257]:
jq_in_a = []
for i in a:
    if 'jquery' in i[0]:
        jq_in_a.append(i)

In [146]:
b_unique = list(set(b))

In [79]:
no_index = []
for i in range(len(b)):
    if b[i] not in npm:
        no_index.append(i)
print(no_index)

[53, 54, 94, 132, 212, 238, 277, 290, 332, 357, 387, 490, 531, 541, 547, 548, 614, 678, 687, 711, 792]


In [170]:
cve_desc_green[a[b.index('cordova android')]]

ValueError: 'video' is not in list

In [82]:
cpe_list_ = []
for i in range(len(npm_match_ground_truth)):
    cve_id = npm_match_ground_truth[i][1]
    if cve_id in cve_desc_green:
        cpe_list_.extend(cve_desc_green[cve_id]['cpes'])
cpe_list_ = list(set(cpe_list_))

In [71]:
d_npm_cve = {}
for i in a:
    if i in cve_desc_green:
        for cpe in cve_desc_green[i]['cpes']:
            if cpe not in d_cve_npm:
                d_cve_npm[cpe] = []
            d_cve_npm[cpe].append(i)
        

In [147]:
all_cpes_ = [cpe for cpe in cpe_list_]
text_ = all_cpes_ + b_unique

In [126]:

tfidf_matrix_ = tfidf_vectorizer.transform(text_)
all_cpes_vec_ = tfidf_matrix_[:len(all_cpes_)]
npm_vec_ = tfidf_matrix_[len(all_cpes_):]
grade_ = all_cpes_vec_.dot(npm_vec_.transpose())

most_arg_ = np.argmax(grade_, axis=0)

most_grade_ = np.max(grade_, axis=0)


In [127]:
npm_vec_.shape[0]

641

In [148]:
all_match_ = list(zip(np.array(all_cpes_)[most_arg_].reshape((npm_vec_.shape[0])),np.array(b_unique), most_grade_.toarray().reshape(-1)))

In [129]:
len(all_match_)

641

In [131]:
match_90_ = [i for i in all_match_ if 1>i[2]>=0.9]
len(match_90_)

81

In [132]:
match_80_ = [i for i in all_match_ if 0.9>i[2]>=0.8]
len(match_80_)

10

In [133]:
match_ranges = {round(10*k)/10:len([i for i in all_match_ if k+0.1>i[2]>=k]) for k in np.arange(1,-0.1,-0.1)}

In [149]:
print(match_ranges)
print(np.array(list(match_ranges.values()))/sum(match_ranges.values()))

{1.0: 456, 0.9: 81, 0.8: 10, 0.7: 10, 0.6: 10, 0.5: 14, 0.4: 10, 0.3: 10, 0.2: 6, 0.1: 0, 0.0: 0}
[0.75123558 0.13344316 0.01647446 0.01647446 0.01647446 0.02306425
 0.01647446 0.01647446 0.00988468 0.         0.        ]


In [164]:
[i for i in all_match_ if 0.8>i[2]>=0.7]

[('adm zip', 'adm zip mit', 0.7876956047126276),
 ('flowplayer html5', 'flowplayer', 0.7940680475299016),
 ('adm zip', 'adm zip with enc', 0.7096629221767243),
 ('serve', 'serve index', 0.727090996370414),
 ('query mysql', 'mysql', 0.7205384828635941),
 ('ember js', 'ember', 0.7844933642504155),
 ('adm zip', 'adm zip iconv', 0.7895573702298826),
 ('auth0', 'auth0 lock', 0.7148464281458531),
 ('node js', 'node', 0.7072518992211563),
 ('ejs', 'ejs co', 0.7340395644158756)]

In [354]:
import json

with open('data/parsed_mvn_cent_idx_solr_06032019.json') as f:
    mvn = json.load(f)

In [356]:
mvn_all = {}
for i in range(len(mvn)):
    artifact_u = string_form(mvn[i].get('artifact_u',''))
    if artifact_u not in mvn_all:
        mvn_all[artifact_u] = (i, string_form(mvn[i].get('group_u','')))
mvn_all


{'yom': (0, 'yom'),
 'ymsg test': (2, 'ymsg'),
 'ymsg support': (3, 'ymsg'),
 'ymsg network': (4, 'ymsg'),
 'yan': (6, 'yan'),
 'jfunutil': (13, 'yan'),
 'xxl': (19, 'xxl'),
 'xtiff jai': (20, 'xtiff jai'),
 'xstream': (21, 'xstream'),
 'xsdlib': (48, 'xsdlib'),
 'xsddoc': (49, 'xsddoc'),
 'maven xsddoc plugin': (54, 'xsddoc'),
 'xpp3 xpath': (58, 'xpp3'),
 'xpp3 min': (62, 'xpp3'),
 'xpp3': (67, 'xpp3'),
 'xom': (81, 'xom'),
 'xmlwriter': (93, 'xmlwriter'),
 'xmlunit': (111, 'xmlunit'),
 'xmlrpc helma': (118, 'xmlrpc helma'),
 'xmlrpc server': (119, 'xmlrpc'),
 'xmlrpc common': (124, 'xmlrpc'),
 'xmlrpc client': (128, 'xmlrpc'),
 'xmlrpc': (132, 'xmlrpc'),
 'xmlpull': (150, 'xmlpull'),
 'xfc': (154, 'xmlmind'),
 'aptconvert': (155, 'xmlmind'),
 'xmlenc': (156, 'xmlenc'),
 'xmldb xupdate': (159, 'xmldb'),
 'xmldb common': (160, 'xmldb'),
 'xmldb api sdk': (161, 'xmldb'),
 'xmldb api': (162, 'xmldb'),
 'xmlpublic': (163, 'xmlbeans'),
 'xmlbeans xmlpublic': (167, 'xmlbeans'),
 'xmlbeans'

In [359]:
mvn_conca_w_id = {(mvn_all[a_u][1] + 2 * (' ' + a_u)): mvn_all[a_u][0] for a_u in mvn_all}

In [361]:
mvn_conca_w_lst = [i for i in mvn_conca_w_id]

In [368]:
import copy
cve_desc_green_ = copy.deepcopy(cve_desc)
for cve_id in cve_desc_green_:
    cpes = []
    for cpe in cve_desc_green_[cve_id]['cpes']:
        if cpe[8]=='a':
            cpe_ = cpe[10:]
            p1 = cpe_.find(':')
            p2 = cpe_[p1+1:].find(':')
            new_cpe = string_form(cpe_[:p1] + 2 * (' ' + cpe_[p1+1:p1+p2+1]))
#             new_cpe = re.sub('[_.\-/]', ' ', cpe_[:p1] + ' ' + cpe_[p1+1:p1+p2+1])
#             new_cpe = re.sub(r'[\\]', '', new_cpe)
#             new_cpe = (cpe_[:p1] + ' ' + cpe_[p1+1:p1+p2+1]).replace('_',' ').replace('.',' ').replace('-',' ')
            if new_cpe not in cpes:
                cpes.append(new_cpe)
    cve_desc_green_[cve_id]['cpes'] = cpes

In [370]:
cpes_2_cve_ = {}
for cve_id in cve_desc_green_:
    for cpe in cve_desc_green_[cve_id]['cpes']:
        if cpe not in cpes_2_cve_:
            cpes_2_cve_[cpe] = []
        cpes_2_cve_[cpe].append(cve_id)

In [413]:
cve_id = cpes_2_cve_['google api search api search']
print(cve_id)

['CVE-2005-3869']


In [414]:
cve_desc[cve_id[0]]

{'description': 'Cross-site scripting (XSS) vulnerability in index.php in Google API Search 1.3.1 and earlier allows remote attackers to inject arbitrary web script or HTML via hex-encoded values in the REQ parameter.',
 'cpes': ['cpe:2.3:a:google:api_search:*:*:*:*:*:*:*:*']}

In [374]:
all_cpes_double = [cpe for cpe in cpes_2_cve_]
text = all_cpes_double + mvn_conca_w_lst

In [377]:
all_cpes_vec_mvn_double.shape

(33831, 94890)

In [378]:
mvn_double_vec.shape

(941456, 314990)

In [420]:
a = np.array([1,2,3,5,4])
print(np.argsort(a)[::-1][:3])
print(a)

[3 4 2]
[1 2 3 5 4]


In [421]:
a[np.argsort(a)[::-1][:3]]

array([5, 4, 3])

In [437]:
# tfidf_matrix_mvn_double = tfidf_vectorizer.fit_transform(text)
# all_cpes_vec_mvn_double = tfidf_matrix_mvn_double[:len(all_cpes_double)]
# mvn_double_vec = tfidf_matrix_mvn_double[len(all_cpes_double):]
# grade_ = all_cpes_vec_mvn_double.dot(mvn_double_vec.transpose())

# most_arg_ = np.argmax(grade_, axis=1)
n=5
most_n_arg_ = np.argsort(grade_.toarray(), axis=1)[:,-1:-n-1:-1]

most_grade_ = np.sort(grade_, axis=1)[:,-1:-n-1:-1]
all_match_ = list(zip(np.array(all_cpes_),np.array(mvn_conca_w_lst)[most_arg_].reshape((len(all_cpes_))), most_grade_.toarray().reshape(-1)))

MemoryError: 

In [443]:
grade_.sort(axis = 0)

AttributeError: sort not found

In [429]:
np.array(mvn_conca_w_lst)[most_n_arg_]

NameError: name 'most_n_arg_' is not defined

In [444]:
match_90_ = [i for i in all_match_ if 0.5>i[2]>=0.4]
len(match_90_)

5762

In [445]:
match_90_

[('rogers software source mgdiff patch viewer mgdiff patch viewer',
  'io fabric8 patch patch core patch core',
  0.414927835229414),
 ('sophos unified threat management software unified threat management software',
  'org webjars npm unified unified',
  0.4934875058057485),
 ('zahmit design connections business directory plugin connections business directory plugin',
  'directory maven directory plugin maven directory plugin',
  0.448523464770778),
 ('cisco adaptive security appliance software adaptive security appliance software',
  'me adaptive adaptive nibble common adaptive nibble common',
  0.4209054607190967),
 ('cisco adaptive security virtual appliance adaptive security virtual appliance',
  'me adaptive adaptive nibble common adaptive nibble common',
  0.40078021650940243),
 ('thedaylightstudio fuel cms fuel cms',
  'com buabook sf et fuel sdk sf et fuel sdk',
  0.4730020288511579),
 ('microsoft digital image suite digital image suite',
  'it tidalwave image image core image 

In [405]:
import random
for x in range(36): 
    print(random.choice(match_90_)) 

('perforce perforce client perforce client', 'org jvnet hudson plugins perforce perforce', 0.8547382874388395)
('redhat sendmail sendmail', 'org apache geronimo samples sendmail sendmail', 0.855462214770309)
('google talk talk', 'com pddstudio talk talk', 0.8572247069568956)
('easypage easypage easypage', 'com github ethendev easypage easypage', 0.8792666393943822)
('signal messenger messenger', 'messenger messenger messenger', 0.8721138098986067)
('barracuda load balancer load balancer', 'org apache stratos load balancer load balancer parent load balancer parent', 0.8744097015738728)
('google api search api search', 'org sakaiproject search search api search api', 0.8879348786037732)
('uochm signup signup', 'org sakaiproject signup signup signup', 0.8660247326465632)
('apache commons jelly commons jelly', 'org jvnet hudson commons jelly commons jelly', 0.8969044946002815)
('openstack grizzly grizzly', 'grizzly grizzly grizzly', 0.8845809684364562)
('zabbix zabbix zabbix', 'com github 

In [62]:
from pytorch_transformers.tokenization_bert import BertTokenizer
from pytorch_transformers.modeling_bert import (
        BertModel,
        BertForNextSentencePrediction,
        BertForMaskedLM,
        BertForMultipleChoice,
        BertForPreTraining,
        BertForQuestionAnswering,
        BertForSequenceClassification,
        BertForTokenClassification,
        )
tokenizer = BertTokenizer.from_pretrained(config['name'], do_lower_case=config['do_lower_case'])

100%|██████████| 231508/231508 [00:00<00:00, 811841.27B/s]


In [132]:
MAX_LEN = config['MAX_LEN']
    #padding data
sentence_pairs = ['[CLS] ' + n + ' [SEP] ' + description[cve_id] + ' [SEP]' for n in npm for cve_id in description]
tokenized_text = [tokenizer.tokenize(sentence_pair) for sentence_pair in sentence_pairs] 
input_ids = pad_sequences([tokenizer.convert_tokens_to_ids(txt) for txt in tokenized_text],
                              maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")

KeyboardInterrupt: 

In [125]:
[tokenizer.convert_tokens_to_ids(txt) for txt in sentence_pairs]

['[CLS] 0 [SEP] HOME SPOT CUBE2 firmware V101 and earlier allows authenticated attackers to execute arbitrary OS commands via WebUI. [SEP]',
 '[CLS] 0  [SEP] HOME SPOT CUBE2 firmware V101 and earlier allows authenticated attackers to execute arbitrary OS commands via WebUI. [SEP]',
 '[CLS] 0     [SEP] HOME SPOT CUBE2 firmware V101 and earlier allows authenticated attackers to execute arbitrary OS commands via WebUI. [SEP]',
 '[CLS] 0 1 project [SEP] HOME SPOT CUBE2 firmware V101 and earlier allows authenticated attackers to execute arbitrary OS commands via WebUI. [SEP]',
 '[CLS] 0 100 [SEP] HOME SPOT CUBE2 firmware V101 and earlier allows authenticated attackers to execute arbitrary OS commands via WebUI. [SEP]',
 '[CLS] 0 24 [SEP] HOME SPOT CUBE2 firmware V101 and earlier allows authenticated attackers to execute arbitrary OS commands via WebUI. [SEP]',
 '[CLS] 0 60 [SEP] HOME SPOT CUBE2 firmware V101 and earlier allows authenticated attackers to execute arbitrary OS commands via Web

In [130]:
tokenizer.convert_tokens_to_ids(tokenized_text)

[101,
 1014,
 102,
 2188,
 3962,
 14291,
 2475,
 3813,
 8059,
 1058,
 10790,
 2487,
 1998,
 3041,
 4473,
 14469,
 4383,
 17857,
 2000,
 15389,
 15275,
 9808,
 10954,
 3081,
 4773,
 10179,
 1012,
 102]

In [129]:
tokenized_text

['[CLS]',
 '0',
 '[SEP]',
 'home',
 'spot',
 'cube',
 '##2',
 'firm',
 '##ware',
 'v',
 '##10',
 '##1',
 'and',
 'earlier',
 'allows',
 'authentic',
 '##ated',
 'attackers',
 'to',
 'execute',
 'arbitrary',
 'os',
 'commands',
 'via',
 'web',
 '##ui',
 '.',
 '[SEP]']

In [99]:
import torch
# from pytorch_pretrained_bert import bertForNextSentencePrediction
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = tokenizer.tokenize(text)
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
segments_ids = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
tokens_tensor = torch.tensor([indexed_tokens,indexed_tokens])
segments_tensors = torch.tensor([segments_ids,segments_ids])
# Load bertForNextSentencePrediction


In [100]:
tokens_tensor

tensor([[  101,  2040,  2001,  3958, 27227,  1029,   102,  3958, 27227,  2001,
          1037, 13997, 11510,   102],
        [  101,  2040,  2001,  3958, 27227,  1029,   102,  3958, 27227,  2001,
          1037, 13997, 11510,   102]])

In [70]:
model = torch.hub.load('huggingface/pytorch-transformers', 'bertForNextSentencePrediction', 'bert-base-uncased')
model.eval()


In [101]:
# Predict the next sentence classification logits
with torch.no_grad():
    next_sent_classif_logits = model(tokens_tensor, segments_tensors)

In [105]:
torch.softmax(next_sent_classif_logits,1)

tensor([[1.0000e+00, 2.8751e-06],
        [1.0000e+00, 2.8751e-06]])

In [103]:
next_sent_classif_logits

tensor([[ 6.3714, -6.3880],
        [ 6.3714, -6.3880]])

In [1]:
config

Available objects for config:
     AliasManager
     DisplayFormatter
     HistoryManager
     IPCompleter
     IPKernelApp
     LoggingMagics
     MagicsManager
     OSMagics
     PrefilterManager
     ScriptMagics
     StoreMagics
     ZMQInteractiveShell
