In [1]:
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
from datetime import datetime
import json
import torch as t
import pandas as pd


In [2]:
# Suppress warning messages
# The transformers logging helper is imported under an alias so that the name
# `logging` stays free for the standard-library module used later in this notebook.
from transformers.utils import logging as hf_logging
hf_logging.set_verbosity_error()  # ERROR level, i.e. the numeric level 40 used before
# Presumably silences TensorFlow's native log output when TensorFlow is installed.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'


In [3]:
# Run-wide settings
conversation_history = []
model_id = "codellama/CodeLlama-7b-Instruct-hf"
# Timestamped default output file, e.g. 2024_01_31_09_15_00.txt
filename = datetime.now().strftime('%Y_%m_%d_%H_%M_%S') + ".txt"


In [4]:
# Use the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if t.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [5]:
# Echo the selected device so the run records whether a GPU was used.
device

'cuda:0'

In [6]:
# Directory passed as cache_dir when the model and tokenizer are loaded.
cache_dir = "../datax/models"

In [7]:
# Load model
# float16 is kept for GPU runs; on the CPU fallback half-precision kernels may be
# missing or very slow depending on the PyTorch build, so float32 is used there.
model_dtype = torch.float16 if str(device).startswith("cuda") else torch.float32
# NOTE(review): trust_remote_code=True lets Python code shipped in the model repository
# run at load time. Confirm it is actually required for this model_id; drop it if not.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=model_dtype,
    cache_dir=cache_dir,
    device_map=device,
    trust_remote_code=True,
)
# Left padding is requested on the tokenizer.
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    trust_remote_code=True,
    cache_dir=cache_dir,
    padding_side="left",
)
# Only referenced from a commented-out `streamer=` argument in generate_prediction.
streamer = TextStreamer(tokenizer, skip_prompt=True)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

## Prompt template configuration

In [8]:
# Prompt templates, keyed by experiment id. Each entry holds:
#   role     - optional opening message (None means no opening message is sent)
#   example  - few-shot line, formatted with (score, text)
#   question - final query, formatted with (text)
#   field    - dataset column declared as the source of the text
# Typos in the p3 wording ("should to provide", "a escore", lower-case sentence
# start) are corrected because these strings are sent verbatim to the model.
# NOTE(review): p3 talks about a "description" but still declares the
# 'func_before' (code) column - confirm which column was intended.
prompt_templates = {
    'p1': {
        'role': None,
        'example': "The following snippet is a confirmed vulnerable code with a severity score of {}:  {}",
        'question': "What is the vulnerability severity score for the following snippet?   {}",
        "field": "func_before",
    },
    'p2': {
        'role': 'You are a software code vulnerability inspector, CVSS score is a number between 0.0 and 10.0',
        'example': "The following snippet is a vulnerable code with a CVSS score of {}:  {}",
        'question': "What is the CVSS score for the following snippet?  {}",
        "field": "func_before",
    },
    'p3': {
        'role': 'You are a software code vulnerability inspector and you should provide a CVSS score depending on the severity. A CVSS score is a number between 0.0 and 10.0',
        'example': "The following vulnerability description has a score of {}:  {}",
        'question': "What is the CVSS score from the following description?  {}",
        "field": "func_before",
    },
}
  

### Examples

In [9]:
# Hand-written sample "scored example" prompt; no later cell in this notebook reads it.
prompt1 = '''The following snippet is a exploitable code with a score of 4.5
	dev = usb_get_intfdata(interface);
	if (!dev) {
		retval = -ENODEV;
		goto exit;
	}

	/* increment our usage count for the device */
	kref_get(&dev->kref);

	/* save our object in the file's private structure */
	mutex_lock(&dev->io_mutex);
	file->private_data = dev;
	mutex_unlock(&dev->io_mutex);'''

In [10]:
# Second hand-written scored example; no later cell in this notebook reads it.
prompt2 = '''The following snippet has a score of 7.5: int i;
char inLine[64];
cin >> inLine;
i = atoi (inLine);
sleep(i);'''

In [11]:
# Hand-written sample question prompt; no later cell in this notebook reads it.
prompt3 = '''What is the score for the following snippet? 
int main(int argc, char *argv[])
{
	rc = SQLConnect(Example.ConHandle, argv[0], SQL_NTS,
	(SQLCHAR *) "", SQL_NTS, (SQLCHAR *) "", SQL_NTS);
} '''

In [12]:
# NOTE(review): the model was already loaded with device_map=device, so this move is
# probably redundant; as the cell's last expression it also echoes the full module tree.
model.to(device)


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32016, 4096)
    (layers): ModuleList(
      (0): LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
      (1): LlamaDecoderLayer(
        (self_attn): LlamaAtten

In [13]:
# Number of scored examples build_messages places before the question.
N_EXAMPLES = 3

In [14]:
# Number of conversations generate_trials produces for one set of messages.
max_trials = 30

# Data loading and filtering

In [15]:
# Pool from which the few-shot examples are drawn.
# NOTE(review): the variable is called val_data but the file read is train.csv - confirm the naming.
val_data = pd.read_csv("../datax/big-vul/train.csv")

In [16]:
# Preview the first rows and the available columns.
val_data.head()

Unnamed: 0,index,Access Gained,Attack Origin,Authentication Required,Availability,CVE ID,CVE Page,CWE ID,Complexity,Confidentiality,...,parentID,patch,project,project_after,project_before,target,vul_func_with_fix,processed_func,flaw_line,flaw_line_index
0,12473,,Remote,Not required,,CVE-2014-3508,https://www.cvedetails.com/cve/CVE-2014-3508/,CWE-200,Medium,Partial,...,17160033765480453be0a41335fa6b833691c049,"@@ -471,11 +471,12 @@ int OBJ_obj2txt(char *bu...",openssl,https://git.openssl.org/gitweb/?p=openssl.git;...,https://git.openssl.org/gitweb/?p=openssl.git;...,0,const char *OBJ_nid2sn(int n)\n\t{\n\tADDED_OB...,const char *OBJ_nid2sn(int n)\n\t{\n\tADDED_OB...,,
1,24444,,Local,Not required,,CVE-2011-4080,https://www.cvedetails.com/cve/CVE-2011-4080/,CWE-264,High,Complete,...,,"@@ -170,6 +170,11 @@ static int proc_taint(str...",linux,bfdc0b497faa82a0ba2f9dddcf109231dd519fcc,cb16e95fa2996743a6e80a665ed2ed0590bd38cf,0,void register_sysctl_root(struct ctl_table_roo...,void register_sysctl_root(struct ctl_table_roo...,,
2,111436,,Remote,Not required,Partial,CVE-2012-2875,https://www.cvedetails.com/cve/CVE-2012-2875/,,Medium,Partial,...,,"@@ -4058,11 +4058,6 @@ bool WebPage::touchEven...",Chrome,d345af9ed62ee5f431be327967f41c3cc3fe936a,e261bb8e47a6a9fdd1d26fd52b1538c5c9bcb122,0,WebPagePrivate::~WebPagePrivate()\n{\n // H...,WebPagePrivate::~WebPagePrivate()\n{\n m_we...,,
3,1314,,Remote,Not required,Complete,CVE-2009-3604,https://www.cvedetails.com/cve/CVE-2009-3604/,CWE-399,Medium,Complete,...,75c3466ba2e4980802e80b939495981240261cd5,"@@ -216,6 +216,28 @@ void *gmallocn_checkoverf...",poppler,https://cgit.freedesktop.org/poppler/poppler/t...,https://cgit.freedesktop.org/poppler/poppler/t...,0,"char *gstrndup(const char *s, size_t n) {\n c...","char *gstrndup(const char *s, size_t n) {\n c...",,
4,88406,,Remote,Not required,,CVE-2019-15164,https://www.cvedetails.com/cve/CVE-2019-15164/,CWE-918,Low,,...,,"@@ -156,6 +156,8 @@ static int rpcapd_recv(SOC...",libpcap,33834cb2a4d035b52aa2a26742f832a112e90a0a,484d60cbf7ca4ec758c3cbb8a82d68b244a78d58,0,"daemon_AuthUserPwd(char *username, char *passw...","daemon_AuthUserPwd(char *username, char *passw...",,


## Filtering data by length

Functions whose source code is 100 to 300 characters long

In [None]:
# Keep rows whose function source is 100-300 characters long and that carry a score.
filtered_val_300 = val_data[
    val_data['func_before'].str.len().between(100, 300)
    & val_data['Score'].notna()
]
filtered_val_300.shape

Functions whose source code is shorter than 100 characters

In [None]:
# Keep rows whose function source is shorter than 100 characters and that carry a score.
filtered_val_100 = val_data[
    val_data['func_before'].str.len().lt(100)
    & val_data['Score'].notna()
]
filtered_val_100.shape

### Load testing data

In [None]:
# Pool from which the question (the snippet to be scored) is drawn.
test_data = pd.read_csv("../datax/big-vul/test.csv")

In [None]:
# Keep test rows whose function source is 100-300 characters long and that carry a score.
filtered_test_300 = test_data[
    test_data['func_before'].str.len().between(100, 300)
    & test_data['Score'].notna()
]
filtered_test_300.shape

In [None]:
# Keep test rows whose function source is shorter than 100 characters and that carry a score.
filtered_test_100 = test_data[
    test_data['func_before'].str.len().lt(100)
    & test_data['Score'].notna()
]
filtered_test_100.shape

In [None]:
# Preview the first rows of the test split.
test_data.head()

In [None]:
def build_messages(filtered_val, filtered_test, prompt_template, n_examples=None):
    """Build the message list for one few-shot conversation.

    The list contains, in order: the template's role text (when it is truthy),
    ``n_examples`` scored examples drawn at random from ``filtered_val``, and one
    question drawn at random from ``filtered_test``.

    Args:
        filtered_val: DataFrame supplying the examples. Must have 'index' and
            'Score' columns plus the text column named by the template.
        filtered_test: DataFrame supplying the question, with the same columns.
        prompt_template: dict with 'example' and 'question' format strings and
            optional 'role' and 'field' entries. 'field' names the column the
            text is read from; it defaults to 'func_before'.
        n_examples: number of examples to include. Defaults to the notebook-level
            N_EXAMPLES when omitted.

    Returns:
        (messages, indexes, gt): the message strings, the 'index' value of every
        sampled row (examples first, question last), and the question row's Score.
    """
    if n_examples is None:
        n_examples = N_EXAMPLES
    # Use the column the template declares instead of a hard-coded one.
    field = prompt_template.get('field') or 'func_before'
    example_template = prompt_template['example']
    question_template = prompt_template['question']

    messages = []
    indexes = []
    if (role_template := prompt_template.get('role')):
        messages.append(role_template)

    # Each example is drawn independently, so the same row can appear more than once.
    for _ in range(n_examples):
        example_row = filtered_val.sample(n=1)
        score = example_row['Score'].values[0]
        text = example_row[field].values[0]
        messages.append(example_template.format(score, text))
        indexes.append(int(example_row['index'].values[0]))

    question_row = filtered_test.sample(n=1)
    gt = question_row['Score'].values[0]
    messages.append(question_template.format(question_row[field].values[0]))
    indexes.append(int(question_row['index'].values[0]))
    return messages, indexes, gt

In [None]:
# Build one sample conversation from the short (<100 character) subsets with template p1.
messages, indexes, gt_score = build_messages(filtered_val_100, filtered_test_100, prompt_templates['p1'])

In [None]:
# Show the 'index' values of the sampled rows (examples first, question last).
indexes

In [None]:
def generate_prediction(messages):
    """Run one multi-turn chat and return its full history.

    Every entry of ``messages`` is sent as a user turn. After each one the model
    generates a short sampled reply (at most 20 new tokens), which is stored as
    the following assistant turn so that it forms part of the next prompt.

    Args:
        messages: iterable of user message strings, in conversation order.

    Returns:
        A list of {"role", "content"} dicts alternating "user" and "assistant".
    """
    conversation_history = []

    for message in messages:
        conversation_history.append({"role": "user", "content": message})

        # Render the history with the model's chat template and tokenize it.
        inputs = tokenizer.apply_chat_template(
            conversation_history,
            add_generation_prompt=True,
            return_tensors="pt",
            return_attention_mask=False,
        ).to(device)

        # Sampled generation; the reply is capped at 20 new tokens.
        generated_ids = model.generate(
            inputs,
            max_new_tokens=20,
            do_sample=True,
            top_k=50,
            top_p=0.92,
            temperature=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )

        # generate() returns prompt + continuation; keep only the continuation and
        # drop special tokens rather than string-splitting the decoded prompt.
        new_tokens = generated_ids[0, inputs.shape[-1]:]
        reply = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

        # Store the reply as an assistant turn. Recording it under the "system" role
        # would feed it back to the model as instructions on the next iteration.
        conversation_history.append({"role": "assistant", "content": reply})

    return conversation_history

In [None]:
def generate_trials(trial_messages=None, n_trials=None):
    """Generate several independent conversations for the same messages.

    Args:
        trial_messages: messages handed to generate_prediction for every trial.
            Defaults to the notebook-level ``messages`` variable.
        n_trials: number of conversations to generate. Defaults to the
            notebook-level ``max_trials``.

    Returns:
        A list with one conversation history per trial.
    """
    # The globals are resolved at call time, so calling with no arguments still
    # picks up whatever `messages` / `max_trials` currently hold.
    if trial_messages is None:
        trial_messages = messages
    if n_trials is None:
        n_trials = max_trials
    return [generate_prediction(trial_messages) for _ in range(n_trials)]
        
        
    

In [None]:
# Trial run on the sample messages built above: max_trials conversations, read from the global `messages`.
conversations= generate_trials()

In [None]:
def save_conversations(conversation_history, path=None):
    """Write conversation data to disk as indented JSON.

    Useful for debugging or for reloading conversational context later.

    Args:
        conversation_history: any JSON-serialisable object (for example a list
            of chats or of result dicts).
        path: destination file. Defaults to the notebook-level ``filename``,
            read at call time.
    """
    target = filename if path is None else path
    # ensure_ascii=False emits non-ASCII characters as-is, so the file encoding
    # is pinned to UTF-8 instead of depending on the platform default.
    with open(target, 'w', encoding='utf-8') as f:
        json.dump(conversation_history, f, ensure_ascii=False, indent=4)

In [None]:
# NOTE(review): repeats the value already assigned in an earlier cell.
N_EXAMPLES = 3

In [None]:
# Number of iterations of the main experiment loop (one sampled question per iteration).
SAMPLES = 300

In [None]:
# Standard-library logging: timestamped INFO records are written to my_script.log.
# NOTE(review): make sure no other cell binds the name `logging` to a different
# module (e.g. a transformers logging helper); whichever import ran last wins, and
# basicConfig / logging.info only exist on the standard-library module.
import logging
import time  # NOTE(review): not used by any cell visible in this notebook
logging.basicConfig(filename='my_script.log', level=logging.INFO, format='%(asctime)s %(message)s')


In [None]:
# Main experiment: for each sample, build a fresh few-shot prompt, run the trials,
# and keep the sampled row indexes, the ground-truth score and every chat.
results = []
for i in range(SAMPLES):
    print("Executing...")
    logging.info(f"Logging message {i}")
    # `messages` is rebound at notebook level on purpose: generate_trials() reads it.
    messages, indexes, gt_score = build_messages(filtered_val_100, filtered_test_100, prompt_templates['p1'])
    # Skip only when the ground truth is really missing. The previous truthiness
    # test dropped a legitimate score of 0.0 and let NaN through (NaN is truthy).
    if pd.isna(gt_score):
        continue
    results.append({
        "indexes": indexes,
        "gt_score": gt_score,
        "chats": generate_trials(),
    })

In [None]:
# Serialise the results to an in-memory JSON string.
# NOTE(review): json_data is not read by any later cell visible in this notebook.
import json
json_data = json.dumps(results, ensure_ascii=False)

In [None]:
# NOTE(review): this cell redefines a function with the same name that an earlier
# cell also defines; keep a single definition once the notebook is cleaned up.
def save_conversations(conversation_history, path=None):
    """Write conversation data to disk as indented JSON.

    Useful for debugging or for reloading conversational context later.

    Args:
        conversation_history: any JSON-serialisable object (for example a list
            of chats or of result dicts).
        path: destination file. Defaults to the notebook-level ``filename``,
            read at call time.
    """
    target = filename if path is None else path
    # ensure_ascii=False emits non-ASCII characters as-is, so the file encoding
    # is pinned to UTF-8 instead of depending on the platform default.
    with open(target, 'w', encoding='utf-8') as f:
        json.dump(conversation_history, f, ensure_ascii=False, indent=4)

In [None]:
# Rebind the global `filename` (which save_conversations reads) to a fresh timestamp
# plus a run tag, then write every result to that file.
# NOTE(review): the "100_30_3" part of the tag presumably mirrors the <100-character
# filter, max_trials and N_EXAMPLES; it is hard-coded, so update it by hand if those change.
filename = f"{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}_100_30_3_codeLlama7b_P1.txt"
save_conversations(results)