In [2]:
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
from datetime import datetime
import json
import torch as t
import pandas as pd


In [2]:
# Suppress warning messages
from transformers.utils import logging
logging.set_verbosity(40)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'


In [3]:
# Program variables

conversation_history = list()
model_id = "codellama/CodeLlama-7b-Instruct-hf"
filename = f"{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}.txt"


In [4]:
device = "cuda:0" if t.cuda.is_available() else "cpu"

In [5]:
device

'cuda:0'

In [6]:
cache_dir ="../datax/models"

In [7]:
# Load model
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32, cache_dir=cache_dir, device_map=device, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, cache_dir=cache_dir, padding_side="left")
streamer = TextStreamer(tokenizer, skip_prompt=True)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [8]:
model.to(device)


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32016, 4096)
    (layers): ModuleList(
      (0): LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
      (1): LlamaDecoderLayer(
        (self_attn): LlamaAtten

In [9]:
p1_template= "The following snippet contians vulnerable code with a CVSS score of {}  {}"

In [10]:
p2_template = "What is the CVSS score for the following snippet?  {}"

In [8]:
prompt1 = '''The following snippet is a exploitable code with a score of 4.5
	dev = usb_get_intfdata(interface);
	if (!dev) {
		retval = -ENODEV;
		goto exit;
	}

	/* increment our usage count for the device */
	kref_get(&dev->kref);

	/* save our object in the file's private structure */
	mutex_lock(&dev->io_mutex);
	file->private_data = dev;
	mutex_unlock(&dev->io_mutex);'''

In [9]:
prompt2 = '''The following snippet has a score of 7.5: int i;
char inLine[64];
cin >> inLine;
i = atoi (inLine);
sleep(i);'''

In [10]:
prompt3 = '''What is the score for the following snippet? 
int main(int argc, char *argv[])
{
	rc = SQLConnect(Example.ConHandle, argv[0], SQL_NTS,
	(SQLCHAR *) "", SQL_NTS, (SQLCHAR *) "", SQL_NTS);
} '''

In [11]:
MESSAGES_LENGTH = 2

In [12]:
max_samples = 30

In [3]:
val_data = pd.read_csv("../data/big-vul/train.csv")

In [6]:
val_data["func_before"][0]

'const char *OBJ_nid2sn(int n)\n\t{\n\tADDED_OBJ ad,*adp;\n\tASN1_OBJECT ob;\n\n\tif ((n >= 0) && (n < NUM_NID))\n\t\t{\n\t\tif ((n != NID_undef) && (nid_objs[n].nid == NID_undef))\n\t\t\t{\n\t\t\tOBJerr(OBJ_F_OBJ_NID2SN,OBJ_R_UNKNOWN_NID);\n\t\t\treturn(NULL);\n\t\t\t}\n\t\treturn(nid_objs[n].sn);\n\t\t}\n\telse if (added == NULL)\n\t\treturn(NULL);\n\telse\n\t\t{\n\t\tad.type=ADDED_NID;\n\t\tad.obj= &ob;\n\t\tob.nid=n;\n\t\tadp=lh_ADDED_OBJ_retrieve(added,&ad);\n\t\tif (adp != NULL)\n\t\t\treturn(adp->obj->sn);\n\t\telse\n\t\t\t{\n\t\t\tOBJerr(OBJ_F_OBJ_NID2SN,OBJ_R_UNKNOWN_NID);\n\t\t\treturn(NULL);\n\t\t\t}\n\t\t}\n\t}\n'

In [7]:
val_data["vul_func_with_fix"][0]

'const char *OBJ_nid2sn(int n)\n\t{\n\tADDED_OBJ ad,*adp;\n\tASN1_OBJECT ob;\n\n\tif ((n >= 0) && (n < NUM_NID))\n\t\t{\n\t\tif ((n != NID_undef) && (nid_objs[n].nid == NID_undef))\n\t\t\t{\n\t\t\tOBJerr(OBJ_F_OBJ_NID2SN,OBJ_R_UNKNOWN_NID);\n\t\t\treturn(NULL);\n\t\t\t}\n\t\treturn(nid_objs[n].sn);\n\t\t}\n\telse if (added == NULL)\n\t\treturn(NULL);\n\telse\n\t\t{\n\t\tad.type=ADDED_NID;\n\t\tad.obj= &ob;\n\t\tob.nid=n;\n\t\tadp=lh_ADDED_OBJ_retrieve(added,&ad);\n\t\tif (adp != NULL)\n\t\t\treturn(adp->obj->sn);\n\t\telse\n\t\t\t{\n\t\t\tOBJerr(OBJ_F_OBJ_NID2SN,OBJ_R_UNKNOWN_NID);\n\t\t\treturn(NULL);\n\t\t\t}\n\t\t}\n\t}\n'

In [8]:
val_data["processed_func"][0]

'const char *OBJ_nid2sn(int n)\n\t{\n\tADDED_OBJ ad,*adp;\n\tASN1_OBJECT ob;\n\n\tif ((n >= 0) && (n < NUM_NID))\n\t\t{\n\t\tif ((n != NID_undef) && (nid_objs[n].nid == NID_undef))\n\t\t\t{\n\t\t\tOBJerr(OBJ_F_OBJ_NID2SN,OBJ_R_UNKNOWN_NID);\n\t\t\treturn(NULL);\n\t\t\t}\n\t\treturn(nid_objs[n].sn);\n\t\t}\n\telse if (added == NULL)\n\t\treturn(NULL);\n\telse\n\t\t{\n\t\tad.type=ADDED_NID;\n\t\tad.obj= &ob;\n\t\tob.nid=n;\n\t\tadp=lh_ADDED_OBJ_retrieve(added,&ad);\n\t\tif (adp != NULL)\n\t\t\treturn(adp->obj->sn);\n\t\telse\n\t\t\t{\n\t\t\tOBJerr(OBJ_F_OBJ_NID2SN,OBJ_R_UNKNOWN_NID);\n\t\t\treturn(NULL);\n\t\t\t}\n\t\t}\n\t}\n'

In [15]:
val_data.head()

Unnamed: 0,index,Access Gained,Attack Origin,Authentication Required,Availability,CVE ID,CVE Page,CWE ID,Complexity,Confidentiality,...,parentID,patch,project,project_after,project_before,target,vul_func_with_fix,processed_func,flaw_line,flaw_line_index
0,12473,,Remote,Not required,,CVE-2014-3508,https://www.cvedetails.com/cve/CVE-2014-3508/,CWE-200,Medium,Partial,...,17160033765480453be0a41335fa6b833691c049,"@@ -471,11 +471,12 @@ int OBJ_obj2txt(char *bu...",openssl,https://git.openssl.org/gitweb/?p=openssl.git;...,https://git.openssl.org/gitweb/?p=openssl.git;...,0,const char *OBJ_nid2sn(int n)\n\t{\n\tADDED_OB...,const char *OBJ_nid2sn(int n)\n\t{\n\tADDED_OB...,,
1,24444,,Local,Not required,,CVE-2011-4080,https://www.cvedetails.com/cve/CVE-2011-4080/,CWE-264,High,Complete,...,,"@@ -170,6 +170,11 @@ static int proc_taint(str...",linux,bfdc0b497faa82a0ba2f9dddcf109231dd519fcc,cb16e95fa2996743a6e80a665ed2ed0590bd38cf,0,void register_sysctl_root(struct ctl_table_roo...,void register_sysctl_root(struct ctl_table_roo...,,
2,111436,,Remote,Not required,Partial,CVE-2012-2875,https://www.cvedetails.com/cve/CVE-2012-2875/,,Medium,Partial,...,,"@@ -4058,11 +4058,6 @@ bool WebPage::touchEven...",Chrome,d345af9ed62ee5f431be327967f41c3cc3fe936a,e261bb8e47a6a9fdd1d26fd52b1538c5c9bcb122,0,WebPagePrivate::~WebPagePrivate()\n{\n // H...,WebPagePrivate::~WebPagePrivate()\n{\n m_we...,,
3,1314,,Remote,Not required,Complete,CVE-2009-3604,https://www.cvedetails.com/cve/CVE-2009-3604/,CWE-399,Medium,Complete,...,75c3466ba2e4980802e80b939495981240261cd5,"@@ -216,6 +216,28 @@ void *gmallocn_checkoverf...",poppler,https://cgit.freedesktop.org/poppler/poppler/t...,https://cgit.freedesktop.org/poppler/poppler/t...,0,"char *gstrndup(const char *s, size_t n) {\n c...","char *gstrndup(const char *s, size_t n) {\n c...",,
4,88406,,Remote,Not required,,CVE-2019-15164,https://www.cvedetails.com/cve/CVE-2019-15164/,CWE-918,Low,,...,,"@@ -156,6 +156,8 @@ static int rpcapd_recv(SOC...",libpcap,33834cb2a4d035b52aa2a26742f832a112e90a0a,484d60cbf7ca4ec758c3cbb8a82d68b244a78d58,0,"daemon_AuthUserPwd(char *username, char *passw...","daemon_AuthUserPwd(char *username, char *passw...",,


In [16]:
filtered_val = val_data[val_data['func_before'].str.len().between(100,300)]
filtered_val = filtered_val[filtered_val['Score'].notna()]
filtered_val.shape

(43856, 39)

In [4]:
filtered_val = val_data[val_data['func_before'].str.len()<100]
filtered_val = filtered_val[filtered_val['Score'].notna()]
filtered_val.shape

(10425, 39)

In [17]:
test_data = pd.read_csv("../data/big-vul/test.csv")

In [20]:
filtered_test = test_data[test_data['func_before'].str.len().between(100,300)]
filtered_test = filtered_test[filtered_test['Score'].notna()]
filtered_test.shape

(5524, 39)

In [18]:
filtered_test = test_data[test_data['func_before'].str.len()< 100]
filtered_test = filtered_test[filtered_test['Score'].notna()]
filtered_test.shape

(1334, 39)

In [19]:
test_data.head()

Unnamed: 0,index,Access Gained,Attack Origin,Authentication Required,Availability,CVE ID,CVE Page,CWE ID,Complexity,Confidentiality,...,parentID,patch,project,project_after,project_before,target,vul_func_with_fix,processed_func,flaw_line,flaw_line_index
0,73752,,Remote,Not required,Complete,CVE-2016-6561,https://www.cvedetails.com/cve/CVE-2016-6561/,CWE-476,Low,,...,,"@@ -21,6 +21,7 @@\n /*\n * Copyright (c) 2007...",illumos-gate,6d1c73b5858fefc6161c7d686345f0dc887ea799,516627f338a630bcf9806a91aa873bbbae9a2fac,0,smb_ofile_delete(void *arg)\n{\n\tsmb_tree_t\t...,smb_ofile_delete(void *arg)\n{\n\tsmb_tree_t\t...,,
1,54196,,Local,Not required,Complete,CVE-2016-3138,https://www.cvedetails.com/cve/CVE-2016-3138/,,Low,,...,,"@@ -1179,6 +1179,9 @@ static int acm_probe(str...",linux,8835ba4a39cf53f705417b3b3a94eb067673f2c9,0b818e3956fc1ad976bee791eadcbb3b5fec5bfd,0,static inline int acm_set_control(struct acm *...,static inline int acm_set_control(struct acm *...,,
2,169124,,Remote,Not required,,CVE-2018-6145,https://www.cvedetails.com/cve/CVE-2018-6145/,CWE-79,Medium,,...,,"@@ -82,13 +82,6 @@ static bool TokenExitsForei...",Chrome,133bc5c262b2555af223263452e9875a95db9eb7,1e8327c88920544f1503004b4e32850c935d4efb,0,HTMLTreeBuilderSimulator::State HTMLTreeBuilde...,HTMLTreeBuilderSimulator::State HTMLTreeBuilde...,,
3,109551,,Remote,Not required,Partial,CVE-2012-5135,https://www.cvedetails.com/cve/CVE-2012-5135/,CWE-399,Low,Partial,...,,"@@ -713,7 +713,8 @@ PrintWebViewHelper::PrintW...",Chrome,b755ebba29dd405d6f1e4cf70f5bc81ffd33b0f6,7b688dec9fa8ab42a4933e381ad9aeb63413139b,0,int PrintWebViewHelper::PrintPreviewContext::t...,int PrintWebViewHelper::PrintPreviewContext::t...,,
4,78906,,Local,Not required,Complete,CVE-2018-16276,https://www.cvedetails.com/cve/CVE-2018-16276/,CWE-20,Low,Complete,...,,"@@ -396,35 +396,24 @@ static ssize_t yurex_rea...",linux,f1e255d60ae66a9f672ff9a207ee6cd8e33d2679,bba57eddadda936c94b5dccf73787cb9e159d0a5,0,"static int yurex_open(struct inode *inode, str...","static int yurex_open(struct inode *inode, str...",,


In [20]:
test_data.columns.tolist()

['index',
 'Access Gained',
 'Attack Origin',
 'Authentication Required',
 'Availability',
 'CVE ID',
 'CVE Page',
 'CWE ID',
 'Complexity',
 'Confidentiality',
 'Integrity',
 'Known Exploits',
 'Publish Date',
 'Score',
 'Summary',
 'Update Date',
 'Vulnerability Classification',
 'add_lines',
 'codeLink',
 'commit_id',
 'commit_message',
 'del_lines',
 'file_name',
 'files_changed',
 'func_after',
 'func_before',
 'lang',
 'lines_after',
 'lines_before',
 'parentID',
 'patch',
 'project',
 'project_after',
 'project_before',
 'target',
 'vul_func_with_fix',
 'processed_func',
 'flaw_line',
 'flaw_line_index']

In [1]:
random_row = test_data.sample(n=1)
text = random_row['func_before'].values[0]
random_row['Score'].values[0]

NameError: name 'test_data' is not defined

In [22]:
text

'void GLES2DecoderImpl::SetGLError(GLenum error, const char* msg) {\n  if (msg) {\n    last_error_ = msg;\n    if (log_synthesized_gl_errors()) {\n      LOG(ERROR) << last_error_;\n    }\n    if (!msg_callback_.is_null()) {\n      msg_callback_.Run(0, GLES2Util::GetStringEnum(error) + " : " + msg);\n    }\n  }\n  error_bits_ |= GLES2Util::GLErrorToErrorBit(error);\n}\n'

In [23]:
def build_messages():
    messages = []
    indexes = []
    gt = None
    for i in range(MESSAGES_LENGTH):
        random_row = filtered_val.sample(n=1)
        text = random_row['func_before'].values[0]
        score = random_row['Score'].values[0]
        message = p1_template.format(score, text)
        indexes.append(int(random_row['index'].values[0]))
        messages.append(message)
    random_row = filtered_test.sample(n=1)
    text = random_row['func_before'].values[0]
    gt = random_row['Score'].values[0]
    message = p2_template.format(text)
    indexes.append(int(random_row['index'].values[0]))
    messages.append(message)
    return messages, indexes, gt

In [24]:
messages, indexes, gt_score = build_messages()

In [25]:
indexes

[96262, 123689, 84058]

In [26]:
# Limit maximum iterations for conversation
def generate_prediction(messages):
    conversation_history = list()

    for message in messages:

        conversation_history.append({"role": "user", "content": message})
        conversation_history.append({"role": "system", "content": ""})
        # Convert conversational history into chat template and tokenize
        inputs = tokenizer.apply_chat_template(conversation_history, return_tensors="pt", return_attention_mask=False).to(device)

        # Generate output
        generated_ids = model.generate(inputs,
            #streamer=streamer,
            max_new_tokens=20,
            do_sample=True,
            top_k=50,
            top_p=0.92,
            temperature= 0.9,
            pad_token_id=tokenizer.eos_token_id
        )

        # Get complete output from model including input prompt
        output = tokenizer.batch_decode(generated_ids)[0]

        # Filter only new output information using '</s>' delimiter, then strip starting and trailing whitespace
        output_filtered = output.split('[/INST]')[-1].strip()

        # Update conversation history with the latest output
        conversation_history[-1]["content"] = output_filtered

    return conversation_history

        # Capture input before start of next iteration
        #capture_input()

In [27]:
def generate_trials():
    conversations = list()
    result = dict()
    messages, indexes, gt_score = build_messages()
    result["indexes"] = indexes
    result["gt_score"] = gt_score
    if not gt_score:
        return None
    for i in range(max_samples):
        conversation = generate_prediction(messages)
        conversations.append(conversation)
    result["chats"] = conversations
    return result
        
    

In [28]:
conversations= generate_trials()

In [29]:
def save_conversations(conversation_history):
    # Save entire conversation history to text file for debugging or use for loading conversational context
    with open(filename, 'w') as f:
        json.dump(conversation_history, f, ensure_ascii=False, indent=4)

In [30]:
SAMPLES = 300

In [31]:
results = []
for i in range(SAMPLES):
    result = generate_trials()
    if result:
        results.append(result)

In [32]:
results

[{'indexes': [82214, 21462, 40661],
  'gt_score': 4.9,
  'chats': [[{'role': 'user',
     'content': 'The following snippet contians vulnerable code with a CVSS score of 7.1  SYSCALL_DEFINE2(shutdown, int, fd, int, how)\n{\n\treturn __sys_shutdown(fd, how);\n}\n'},
    {'role': 'system',
     'content': 'The provided snippet contains a vulnerability in the `shutdown` system call. The `'},
    {'role': 'user',
     'content': 'The following snippet contians vulnerable code with a CVSS score of 7.2  static int __init init_elf_binfmt(void)\n{\n\treturn register_binfmt(&elf_format);\n}\n'},
    {'role': 'system',
     'content': 'This code snippet contains a vulnerability with a CVSS score of 7.2. The'},
    {'role': 'user',
     'content': 'What is the CVSS score for the following snippet?  static void prb_thaw_queue(struct tpacket_kbdq_core *pkc)\n{\n\tpkc->reset_pending_on_curr_blk = 0;\n}\n'},
    {'role': 'system',
     'content': 'The CVSS score for the given snippet is 5.8, as it is

In [33]:
import json
json_data = json.dumps(results, ensure_ascii=False)

In [None]:
filename = f"{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}_300_100_30_2_codeLlama7b.txt"
save_conversations(results)

In [None]:


# Load conversational history from a previous context file
context_filename = "./*.txt"
with open(context_filename, 'r') as f:
     data = json.load(f)
     conversation_history = data


