In [3]:
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
from datetime import datetime
import json
import torch as t
import pandas as pd


## The transofrmer model initialization

On this notebook we aim to evaluate Mistral model for predicting vulnerability CVSS by providing a set of examples on the in-context learning coniguration. We will introduce three kind of variables in the prompt:
1. Code with a context length less than 100 words.
2. Code with a context lenth between 100 and 300 words.
3. Only CVSS text description.

Since the model supports a limited number of tokens/words we truncate some descriptions

In [4]:
# Suppress warning messages
from transformers.utils import logging
logging.set_verbosity(40)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'


In [5]:
# Program variables
#max_iterations = 30
conversation_history = list()
model_id = "mistralai/Mistral-7B-Instruct-v0.2"
filename = f"{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}.txt"


In [6]:
device = "cuda:0" if t.cuda.is_available() else "cpu"

In [7]:
device

'cuda:0'

In [8]:
cache_dir ="../datax/models"

In [9]:
# Load model
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32, cache_dir=cache_dir, device_map=device, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, cache_dir=cache_dir, padding_side="left")
streamer = TextStreamer(tokenizer, skip_prompt=True)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [10]:
model.to(device)


MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0): MistralDecoderLayer(
        (self_attn): MistralAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
      (1): MistralDecoderLayer(
        (self

## Step 1: Prompt template configuration

This experiment considers three type of prompts:
1. Only code with out any context abour the score, we just want the model to classify and guess the score just providing some examples with the real ground truth score
2. Only code with some CVSS context, explaining first that CVSS consist on a score from 1.0 to 10.0 and we need to obtain that range of values depending on the severity of the vulnerability.
3. CVSS description and CVSS score.

In [80]:
prompt_templates = {}
prompt_templates['p1']={'role': None,
    'example':"The following snippet is a exploitable code with a score of {}:  {}",
          'question':"What is the score for the following snippet?   {}"}
          
prompt_templates['p2']={'role':'You are a software code vulnerability inspector, CVSS score is a number between 0.0 and 10.0',
        'example':"The following snippet is a vulnerable code with a CVSS score of {}:  {}",
        'question':"What is the CVSS score for the following snippet?  {}"}
        
prompt_templates['p3']={'role':'You are a software code vulnerability inspector and you should to provide a CVSS score depending on the severity. a CVSS score is a number between 0.0 and 10.0',
        'example':"The following vulnerability description has a escore of {}:  {}",
        'question':"What is the CVSS score from the following description?  {}"}



### Prompt examples

The following are just prompt examples with vulnerable code blocks and linked scores

In [8]:
prompt1 = '''The following snippet is a exploitable code with a score of 4.5
	dev = usb_get_intfdata(interface);
	if (!dev) {
		retval = -ENODEV;
		goto exit;
	}

	/* increment our usage count for the device */
	kref_get(&dev->kref);

	/* save our object in the file's private structure */
	mutex_lock(&dev->io_mutex);
	file->private_data = dev;
	mutex_unlock(&dev->io_mutex);'''

In [9]:
prompt2 = '''The following snippet has a score of 7.5: int i;
char inLine[64];
cin >> inLine;
i = atoi (inLine);
sleep(i);'''

In [10]:
prompt3 = '''What is the score for the following snippet? 
int main(int argc, char *argv[])
{
	rc = SQLConnect(Example.ConHandle, argv[0], SQL_NTS,
	(SQLCHAR *) "", SQL_NTS, (SQLCHAR *) "", SQL_NTS);
} '''

## Step 2: Experiment configuration

In [46]:
N_EXAMPLES = 2

In [13]:
max_trials = 30

## Step 3: Load testbed

On this experiment we are using *Big-vul* datasets. For providing the examples we are reusing training split from that dataset.
For building the question prompt we use the testing split dataset.

In [14]:
val_data = pd.read_csv("../data/big-vul/train.csv")

In [15]:
val_data.head()

Unnamed: 0,index,Access Gained,Attack Origin,Authentication Required,Availability,CVE ID,CVE Page,CWE ID,Complexity,Confidentiality,...,parentID,patch,project,project_after,project_before,target,vul_func_with_fix,processed_func,flaw_line,flaw_line_index
0,12473,,Remote,Not required,,CVE-2014-3508,https://www.cvedetails.com/cve/CVE-2014-3508/,CWE-200,Medium,Partial,...,17160033765480453be0a41335fa6b833691c049,"@@ -471,11 +471,12 @@ int OBJ_obj2txt(char *bu...",openssl,https://git.openssl.org/gitweb/?p=openssl.git;...,https://git.openssl.org/gitweb/?p=openssl.git;...,0,const char *OBJ_nid2sn(int n)\n\t{\n\tADDED_OB...,const char *OBJ_nid2sn(int n)\n\t{\n\tADDED_OB...,,
1,24444,,Local,Not required,,CVE-2011-4080,https://www.cvedetails.com/cve/CVE-2011-4080/,CWE-264,High,Complete,...,,"@@ -170,6 +170,11 @@ static int proc_taint(str...",linux,bfdc0b497faa82a0ba2f9dddcf109231dd519fcc,cb16e95fa2996743a6e80a665ed2ed0590bd38cf,0,void register_sysctl_root(struct ctl_table_roo...,void register_sysctl_root(struct ctl_table_roo...,,
2,111436,,Remote,Not required,Partial,CVE-2012-2875,https://www.cvedetails.com/cve/CVE-2012-2875/,,Medium,Partial,...,,"@@ -4058,11 +4058,6 @@ bool WebPage::touchEven...",Chrome,d345af9ed62ee5f431be327967f41c3cc3fe936a,e261bb8e47a6a9fdd1d26fd52b1538c5c9bcb122,0,WebPagePrivate::~WebPagePrivate()\n{\n // H...,WebPagePrivate::~WebPagePrivate()\n{\n m_we...,,
3,1314,,Remote,Not required,Complete,CVE-2009-3604,https://www.cvedetails.com/cve/CVE-2009-3604/,CWE-399,Medium,Complete,...,75c3466ba2e4980802e80b939495981240261cd5,"@@ -216,6 +216,28 @@ void *gmallocn_checkoverf...",poppler,https://cgit.freedesktop.org/poppler/poppler/t...,https://cgit.freedesktop.org/poppler/poppler/t...,0,"char *gstrndup(const char *s, size_t n) {\n c...","char *gstrndup(const char *s, size_t n) {\n c...",,
4,88406,,Remote,Not required,,CVE-2019-15164,https://www.cvedetails.com/cve/CVE-2019-15164/,CWE-918,Low,,...,,"@@ -156,6 +156,8 @@ static int rpcapd_recv(SOC...",libpcap,33834cb2a4d035b52aa2a26742f832a112e90a0a,484d60cbf7ca4ec758c3cbb8a82d68b244a78d58,0,"daemon_AuthUserPwd(char *username, char *passw...","daemon_AuthUserPwd(char *username, char *passw...",,


In [16]:
val_data.shape

(150908, 39)

In [17]:
val_data['Summary'].notna().sum()

139607

### Data filtering by size

Filtering functions between 100 and 300 length at the function 

In [106]:
filtered_val_300 = val_data[val_data['func_before'].str.len().between(100,300)]
filtered_val_300 = filtered_val_300[filtered_val_300['Score'].notna()]
filtered_val_300.shape

(43856, 39)

Filtering functions less 100

In [107]:
filtered_val_100 = val_data[val_data['func_before'].str.len()<100]
filtered_val_100 = filtered_val_100[filtered_val_100['Score'].notna()]
filtered_val_100.shape

(10425, 39)

### Load test split from Big-vul

In [19]:
test_data = pd.read_csv("../data/big-vul/test.csv")

In [20]:
test_data.shape

(18864, 39)

In [21]:
test_data['Summary'].notna().sum()

17568

In [22]:
filtered_test_300 = test_data[test_data['func_before'].str.len().between(100,300)]
filtered_test_300 = filtered_test_300[filtered_test_300['Score'].notna()]
filtered_test_300.shape

(5524, 39)

In [23]:
filtered_test_100 = test_data[test_data['func_before'].str.len()< 100]
filtered_test_100 = filtered_test_100[filtered_test_100['Score'].notna()]
filtered_test_100.shape

(1334, 39)

In [24]:
test_data.head()

Unnamed: 0,index,Access Gained,Attack Origin,Authentication Required,Availability,CVE ID,CVE Page,CWE ID,Complexity,Confidentiality,...,parentID,patch,project,project_after,project_before,target,vul_func_with_fix,processed_func,flaw_line,flaw_line_index
0,73752,,Remote,Not required,Complete,CVE-2016-6561,https://www.cvedetails.com/cve/CVE-2016-6561/,CWE-476,Low,,...,,"@@ -21,6 +21,7 @@\n /*\n * Copyright (c) 2007...",illumos-gate,6d1c73b5858fefc6161c7d686345f0dc887ea799,516627f338a630bcf9806a91aa873bbbae9a2fac,0,smb_ofile_delete(void *arg)\n{\n\tsmb_tree_t\t...,smb_ofile_delete(void *arg)\n{\n\tsmb_tree_t\t...,,
1,54196,,Local,Not required,Complete,CVE-2016-3138,https://www.cvedetails.com/cve/CVE-2016-3138/,,Low,,...,,"@@ -1179,6 +1179,9 @@ static int acm_probe(str...",linux,8835ba4a39cf53f705417b3b3a94eb067673f2c9,0b818e3956fc1ad976bee791eadcbb3b5fec5bfd,0,static inline int acm_set_control(struct acm *...,static inline int acm_set_control(struct acm *...,,
2,169124,,Remote,Not required,,CVE-2018-6145,https://www.cvedetails.com/cve/CVE-2018-6145/,CWE-79,Medium,,...,,"@@ -82,13 +82,6 @@ static bool TokenExitsForei...",Chrome,133bc5c262b2555af223263452e9875a95db9eb7,1e8327c88920544f1503004b4e32850c935d4efb,0,HTMLTreeBuilderSimulator::State HTMLTreeBuilde...,HTMLTreeBuilderSimulator::State HTMLTreeBuilde...,,
3,109551,,Remote,Not required,Partial,CVE-2012-5135,https://www.cvedetails.com/cve/CVE-2012-5135/,CWE-399,Low,Partial,...,,"@@ -713,7 +713,8 @@ PrintWebViewHelper::PrintW...",Chrome,b755ebba29dd405d6f1e4cf70f5bc81ffd33b0f6,7b688dec9fa8ab42a4933e381ad9aeb63413139b,0,int PrintWebViewHelper::PrintPreviewContext::t...,int PrintWebViewHelper::PrintPreviewContext::t...,,
4,78906,,Local,Not required,Complete,CVE-2018-16276,https://www.cvedetails.com/cve/CVE-2018-16276/,CWE-20,Low,Complete,...,,"@@ -396,35 +396,24 @@ static ssize_t yurex_rea...",linux,f1e255d60ae66a9f672ff9a207ee6cd8e33d2679,bba57eddadda936c94b5dccf73787cb9e159d0a5,0,"static int yurex_open(struct inode *inode, str...","static int yurex_open(struct inode *inode, str...",,


In [25]:
test_data.columns.tolist()

['index',
 'Access Gained',
 'Attack Origin',
 'Authentication Required',
 'Availability',
 'CVE ID',
 'CVE Page',
 'CWE ID',
 'Complexity',
 'Confidentiality',
 'Integrity',
 'Known Exploits',
 'Publish Date',
 'Score',
 'Summary',
 'Update Date',
 'Vulnerability Classification',
 'add_lines',
 'codeLink',
 'commit_id',
 'commit_message',
 'del_lines',
 'file_name',
 'files_changed',
 'func_after',
 'func_before',
 'lang',
 'lines_after',
 'lines_before',
 'parentID',
 'patch',
 'project',
 'project_after',
 'project_before',
 'target',
 'vul_func_with_fix',
 'processed_func',
 'flaw_line',
 'flaw_line_index']

In [26]:
random_row = test_data.sample(n=1)
text = random_row['func_before'].values[0]
random_row['Score'].values[0]

5.8

In [27]:
text

'WebGLRenderingContextBase::CreateContextProviderInternal(\n    CanvasRenderingContextHost* host,\n    const CanvasContextCreationAttributes& attributes,\n    unsigned web_gl_version,\n    bool* using_gpu_compositing) {\n  DCHECK(host);\n  ExecutionContext* execution_context = host->GetTopExecutionContext();\n  DCHECK(execution_context);\n\n  Platform::ContextAttributes context_attributes = ToPlatformContextAttributes(\n      attributes, web_gl_version,\n      SupportOwnOffscreenSurface(execution_context));\n\n  Platform::GraphicsInfo gl_info;\n  std::unique_ptr<WebGraphicsContext3DProvider> context_provider;\n  const auto& url = execution_context->Url();\n  if (IsMainThread()) {\n    *using_gpu_compositing = !Platform::Current()->IsGpuCompositingDisabled();\n    context_provider =\n        Platform::Current()->CreateOffscreenGraphicsContext3DProvider(\n            context_attributes, url, nullptr, &gl_info);\n  } else {\n    context_provider = CreateContextProviderOnWorkerThread(\n   

In [30]:
def build_messages(filtered_val,  filtered_test, prompt_template):
    messages = []
    indexes = []
    gt = None
    if (role_template := prompt_template['role']) :
        messages.append(role_template)
    p1_template = prompt_template['example']
    p2_template = prompt_template['question']
    for i in range(N_EXAMPLES):
        random_row = filtered_val.sample(n=1)
        text = random_row['func_before'].values[0]
        score = random_row['Score'].values[0]
        message = p1_template.format(score, text)
        indexes.append(int(random_row['index'].values[0]))
        messages.append(message)
    random_row = filtered_test.sample(n=1)
    text = random_row['func_before'].values[0]
    gt = random_row['Score'].values[0]
    message = p2_template.format(text)
    indexes.append(int(random_row['index'].values[0]))
    messages.append(message)
    return messages, indexes, gt

### Build messages sandbox

On this example we are taking random examples from training dataset with code length less than 100 and random question from test dataset with code length less than 100, we also use the prompt template configuration 1

In [81]:
messages, indexes, gt_score = build_messages(filtered_val_100, filtered_test_100, prompt_templates['p2'])

**Indexes:** Each datapoint from the testbed has an index, we capture the data point index from the dataset to have trazability. The indexes array reports the indexes from the training dataset (examples) and the index from the test dataset (question)

In [76]:
indexes

[104444, 5556, 8502]

In [82]:
messages

['You are a software code vulnerability inspector, CVSS score is a number between 0.0 and 10.0',
 'The following snippet is a vulnerable code with a CVSS score of 4.3:  SchedulerObject::~SchedulerObject()\n{\n\tdelete m_codec;\n}\n',
 'The following snippet is a vulnerable code with a CVSS score of 4.3:    void PushNextTask(base::OnceClosure task) {\n    task_stack_.push(std::move(task));\n  }\n',
 'What is the CVSS score for the following snippet?  static bool interface_ready(void) {\n return bt_hal_cbacks != NULL;\n}\n']

In [100]:
# Limit maximum iterations for conversation
def generate_prediction(messages):
    conversation_history = list()

    for message in messages:

        conversation_history.append({"role": "user", "content": message})
        conversation_history.append({"role": "assistant", "content": ""})
        # Convert conversational history into chat template and tokenize
        inputs = tokenizer.apply_chat_template(conversation_history, return_tensors="pt", return_attention_mask=False).to(device)

        # Generate output
        generated_ids = model.generate(inputs,
            #streamer=streamer,
            max_new_tokens=30,
            do_sample=True,
            top_k=50,
            top_p=0.92,
            temperature= 0.9,
            pad_token_id=tokenizer.eos_token_id
        )

        # Get complete output from model including input prompt
        output = tokenizer.batch_decode(generated_ids)[0]

        # Filter only new output information using '</s>' delimiter, then strip starting and trailing whitespace
        output_filtered = output.split('[/INST]')[-1].strip()

        # Update conversation history with the latest output
        conversation_history[-1]["content"] = output_filtered

    return conversation_history

        # Capture input before start of next iteration
        #capture_input()

In [92]:
conversation = generate_prediction(messages)

In [93]:
conversation

[{'role': 'user',
  'content': 'You are a software code vulnerability inspector, CVSS score is a number between 0.0 and 10.0'},
 {'role': 'assistant',
  'content': "</s> I'm a text-based AI and don't have the ability to directly inspect software code or assign CVSS scores. The Common Vulnerability Scoring System (CVSS) is a freely available and open industry standard for assessing the severity of computer system security vulnerabilities. A CVSS score"},
 {'role': 'user',
  'content': 'The following snippet is a vulnerable code with a CVSS score of 4.3:  SchedulerObject::~SchedulerObject()\n{\n\tdelete m_codec;\n}\n'},
 {'role': 'assistant',
  'content': "</s> The provided code snippet is a destructor for a class named `SchedulerObject`. It appears to only contain a single line of code, which is deleting the `m_codec` member variable. Based on the information given, it's difficult to determine why this code would have a CVSS score of"},
 {'role': 'user',
  'content': 'The following snip

In [94]:
def generate_trials():
    conversations = list()
    for i in range(max_trials):
        conversation = generate_prediction(messages)
        conversations.append(conversation)
    return conversations
        
    

### Conversation example

In [95]:
conversations= generate_trials()

In [96]:
conversations

[[{'role': 'user',
   'content': 'You are a software code vulnerability inspector, CVSS score is a number between 0.0 and 10.0'},
  {'role': 'assistant',
   'content': '</s> I understand that I am a software code vulnerability inspector, and I use the Common Vulnerability Scoring System (CVSS) to assign a severity score to vulnerabilities I discover. The CVSS score ranges from 0.0 to 10.0, where 0.0 represents'},
  {'role': 'user',
   'content': 'The following snippet is a vulnerable code with a CVSS score of 4.3:  SchedulerObject::~SchedulerObject()\n{\n\tdelete m_codec;\n}\n'},
  {'role': 'assistant',
   'content': "</s> Based on the given code snippet, it's not immediately clear why this code would have a CVSS score of 4.3. The code appears to be a destructor for a `SchedulerObject` class that deletes the `m_codec` pointer.\n\nHowever, a CVSS score"},
  {'role': 'user',
   'content': 'The following snippet is a vulnerable code with a CVSS score of 4.3:    void PushNextTask(base::Onc

In [67]:
def save_conversations(conversation_history):
    # Save entire conversation history to text file for debugging or use for loading conversational context
    with open(filename, 'w') as f:
        json.dump(conversation_history, f, ensure_ascii=False, indent=4)

In [101]:
SAMPLES = 2

## Step 4: Parameter validation and experiment execution

In [102]:
N_EXAMPLES

2

In [103]:
max_trials

30

In [104]:
SAMPLES

2

In [108]:
results = []
for i in range(SAMPLES):
    result = dict()
    messages, indexes, gt_score = build_messages(filtered_val_300, filtered_test_300, prompt_templates['p1'])
    if not gt_score:
        continue #TODO: DRC filter data with gt_score only
    result["indexes"] = indexes
    result["gt_score"] = gt_score
    result["chats"] = generate_trials()
    results.append(result)

In [35]:
results

[{'indexes': [25121, 10375],
  'gt_score': 1.9,
  'chats': [[{'role': 'user',
     'content': 'The following snippet is a vulnerable code with a CVSS score of 6.8  void ip_rt_multicast_event(struct in_device *in_dev)\n{\n\trt_cache_flush(dev_net(in_dev->dev), 0);\n}\n'},
    {'role': 'assistant',
     'content': '</s> I see that you have provided a vulnerable function named `ip_rt_multicast_event'},
    {'role': 'user',
     'content': 'What is the CVSS score for the following snippet?  user_local_get_user_name (User *user)\n{\n        return user->user_name;\n}\n'},
    {'role': 'assistant',
     'content': '</s> The given function `user_local_get_user_name` is not vulnerable on its own'}],
   [{'role': 'user',
     'content': 'The following snippet is a vulnerable code with a CVSS score of 6.8  void ip_rt_multicast_event(struct in_device *in_dev)\n{\n\trt_cache_flush(dev_net(in_dev->dev), 0);\n}\n'},
    {'role': 'assistant',
     'content': '</s> The provided code snippet is a simpl

In [35]:
import json
json_data = json.dumps(results, ensure_ascii=False)

In [36]:
filename = f"{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}_300_100_30_1_Mistral7b_Explicit.txt"
save_conversations(results)

In [None]:


# Load conversational history from a previous context file
context_filename = "./*.txt"
with open(context_filename, 'r') as f:
     data = json.load(f)
     conversation_history = data


