In [9]:
# Install the third-party packages this notebook imports (LangChain, Chroma, PDF loading,
# sentence-transformer embeddings, the OpenAI client and the Excel reader).
# NOTE(review): no versions are pinned, so a fresh environment may resolve to newer,
# incompatible releases — consider pinning once a working set is known.
!pip install langchain
!pip install chromadb
!pip install pypdf
!pip install sentence-transformers
!pip install langchain-community langchain-core
!pip install openai
!pip install openpyxl





In [10]:
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from openai import OpenAI
import openpyxl



In [11]:
# Reference material for retrieval: each loader in this list contributes its pages.
loaders = [PyPDFLoader("/content/CWEexplanations.pdf")]

# Collect the pages produced by every loader into one flat list.
docs = []
for pdf_loader in loaders:
  docs += pdf_loader.load()

# Cut the pages into 1000-character chunks that overlap by 100 characters.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
docs = text_splitter.split_documents(docs)

# CPU-hosted MiniLM sentence-transformer used to embed every chunk.
embedding_function = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={"device": "cpu"},
)

# Embed the chunks and store them in the on-disk Chroma collection.
# NOTE(review): re-running this cell against an existing ./chroma_db_nccn directory
# presumably adds the same chunks again — confirm, and clear the directory if so.
vectorstore = Chroma.from_documents(
    docs,
    embedding_function,
    persist_directory="./chroma_db_nccn",
)

# Report how many chunks the collection now holds.
print(vectorstore._collection.count())

108


In [None]:
import os
import signal
import sys
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
import openai


def signal_handler(sig, frame):
  """Say goodbye on standard output and terminate the interpreter with status 0.

  Args:
    sig: Signal number passed in by the ``signal`` module (not used).
    frame: Stack frame that was executing when the signal arrived (not used).

  Raises:
    SystemExit: Always, with exit code 0.
  """
  # Leading newline keeps the message off whatever line was being printed.
  sys.stdout.write('\nThanks for using Gemini' + '\n')
  raise SystemExit(0)

# Route SIGINT (Ctrl+C) to signal_handler, so an interrupt prints the farewell message and
# exits with status 0 instead of surfacing as a KeyboardInterrupt.
# NOTE(review): interrupting a cell in a notebook kernel also delivers SIGINT — confirm that
# exiting is the behaviour wanted there.
signal.signal(signal.SIGINT, signal_handler)

def generate_rag_prompt(query,context):
  """Build the retrieval-augmented prompt sent to the chat model.

  The retrieved context is sanitized before it is embedded in the template: single and
  double quotes are removed and newlines become spaces. The query is inserted verbatim.

  Args:
    query: The question (including any code under analysis) to put in the QUESTION field.
    context: Reference text retrieved from the vector store; must be a string.

  Returns:
    The fully formatted prompt string, ending with an empty ANSWER section.
  """
  # Strip quote characters and flatten newlines so the retrieved text cannot close the
  # single-quoted CONTEXT field early. (Previously this value was computed and discarded,
  # and the raw context was formatted into the prompt instead.)
  escaped_context = context.replace("'","").replace('"',"").replace("\n"," ")
  prompt = ("""
  You are a helpful and informative bot that answers questions using text from the reference context included below. \
  If the context is irrelevant to the answer, you may ignore it.
      QUESTION: '{query}'
      CONTEXT: '{context}'

      ANSWER:
  """).format(query=query, context=escaped_context)
  return prompt

  #last line prevents hallucinations

# Embedding models keyed by model name, so the model is loaded once per kernel session
# rather than once per query.
_EMBEDDING_CACHE = {}


def _get_embedding_function(model_name="sentence-transformers/all-MiniLM-L6-v2"):
  """Return a shared CPU HuggingFaceEmbeddings instance, creating it on first use."""
  if model_name not in _EMBEDDING_CACHE:
    _EMBEDDING_CACHE[model_name] = HuggingFaceEmbeddings(
        model_name=model_name, model_kwargs={'device': 'cpu'}
    )
  return _EMBEDDING_CACHE[model_name]


def get_relevant_context_from_db(query, k=1):
  """Fetch the stored chunks most similar to ``query`` from the persisted Chroma index.

  Args:
    query: Text to run the similarity search with.
    k: Number of chunks to retrieve. Defaults to 1, the value previously hard-coded.

  Returns:
    The ``page_content`` of each hit, each followed by a newline, concatenated in the
    order the search returned them. Empty string when there are no hits.
  """
  # The Chroma handle is reopened per call (as before); only the embedding model is reused.
  vector_db = Chroma(
      persist_directory="./chroma_db_nccn",
      embedding_function=_get_embedding_function(),
  )
  search_results = vector_db.similarity_search(query, k=k)
  return "".join(result.page_content + "\n" for result in search_results)

# Read the API key from the environment so the secret is never stored in the notebook.
# The previous bare `OPENAI_API_KEY =` had no right-hand side and was a syntax error.
# This is None when the variable is unset; set OPENAI_API_KEY before running the cell.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")


def generate_answer(prompt, model="gpt-4", max_tokens=150, temperature=0.7):
  """Send ``prompt`` to the OpenAI chat completions endpoint and return the reply text.

  Args:
    prompt: Text sent as the user message.
    model: Chat model name. Defaults to "gpt-4" (e.g. "gpt-3.5-turbo" also works).
    max_tokens: Upper bound on the length of the generated reply. Defaults to 150.
    temperature: Sampling temperature. Defaults to 0.7.

  Returns:
    The first choice's message content with surrounding whitespace removed, or an empty
    string when the response carries no text content.
  """
  client = OpenAI(api_key=OPENAI_API_KEY)

  chat_completion = client.chat.completions.create(
      model=model,
      messages=[
          # System message sets the assistant's general behaviour.
          {"role": "system", "content": "You are a helpful assistant."},
          # The RAG prompt built by the caller.
          {"role": "user", "content": prompt},
      ],
      max_tokens=max_tokens,
      temperature=temperature,
  )

  # The message content is optional in the response; guard so a missing value yields ""
  # instead of an AttributeError on None.strip().
  content = chat_completion.choices[0].message.content
  return (content or "").strip()

# Open the spreadsheet of functions to evaluate and work on its active sheet.
wrkbk = openpyxl.load_workbook("CWE-416_vul.xlsx")
sh = wrkbk.active


for i in range(1, sh.max_row + 1):
    # Column 1 of each row is read as the source code of one function.
    function_code = sh.cell(row=i, column=1).value

    # Rows inside max_row can still be blank; skip them rather than paying for an API call
    # that asks the model about the literal text "None". The row number is still reported
    # so the printed answers stay aligned with the spreadsheet.
    if function_code is None or not str(function_code).strip():
        print(f"Row {i} - Skipped: empty cell")
        continue

    # Benchmark 1 (active): binary vulnerable / non-vulnerable prediction.
    query = f"""Predict whether the C/C++ function below is vulnerable. Strictly return 1 for a vulnerable function or 0 for a non-vulnerable function without further explanation.
    {function_code}
    """

    # Benchmark 2 (disabled): CWE-ID classification. Uncomment to replace the query above.
    # query = f"""
    #   identify the CWE-ID of this vulnerability. The potential CWE-IDs
    #    are listed in the CWEexplanations document. Strictly
    #    return one of the CWE-IDs from the list
    #   {function_code}

    #   """

    # Benchmark 3 (disabled): CVSS severity estimation. Uncomment to replace the query above.
    #   query = f"""
    #    The following C/C++ function is confirmed vulnerable AS CWE-79, esti-
    # mate the CVSS severity score (Version 3.1) of this vulnerable
    # function. The potential severity score is a float between 0
    # to 10. Strictly return a severity score estimation without
    # any further explanation.
    #      {function_code}

    #      """

    # Benchmark 4 (disabled): few-shot repair-patch generation. This prompt is a plain string
    # with the target function written into it, so it does not use {function_code}.
    # query = """Example Vulnerable Function 1: PS_SERIALIZER_DECODE_FUNC(php_binary) /* {{{ */ { const char *p; char *name; const char *endptr = val + vallen; zval *current; int namelen; int has_value; php_unserialize_data_t var_hash; PHP_VAR_UNSERIALIZE_INIT(var_hash); for (p = val; p < endptr; ) { zval **tmp; namelen = ((unsigned char)(*p)) & (~PS_BIN_UNDEF); if (namelen < 0 || namelen > PS_BIN_MAX || (p + namelen) >= endptr) { return FAILURE; } name = estrndup(p + 1, namelen); p += namelen + 1; if (zend_hash_find(&EG(symbol_table), name, namelen + 1, (void **) &tmp) == SUCCESS) { if ((Z_TYPE_PP(tmp) == IS_ARRAY && Z_ARRVAL_PP(tmp) == &EG(symbol_table)) || *tmp == PS(http_session_vars)) { efree(name); continue; } } if (has_value) { ALLOC_INIT_ZVAL(current); if (php_var_unserialize(&current, (const unsigned char **) &p, (const unsigned char *) endptr, &var_hash TSRMLS_CC)) { php_set_session_var(name, namelen, current, &var_hash TSRMLS_CC); } else { PHP_VAR_UNSERIALIZE_DESTROY(var_hash); return FAILURE; } var_push_dtor_no_addref(&var_hash, &current); } PS_ADD_VARL(name, namelen); efree(name); } PHP_VAR_UNSERIALIZE_DESTROY(var_hash); return SUCCESS; } /* }}} */ Example Repair Patch 1: : PS_SERIALIZER_DECODE_FUNC(php_binary) /* {{{ */ { const char *p; char *name; const char *endptr = val + vallen; zval *current; int namelen; int has_value; php_unserialize_data_t var_hash; PHP_VAR_UNSERIALIZE_INIT(var_hash); for (p = val; p < endptr; ) { zval **tmp; namelen = ((unsigned char)(*p)) & (~PS_BIN_UNDEF); if (namelen < 0 || namelen > PS_BIN_MAX || (p + namelen) >= endptr) { PHP_VAR_UNSERIALIZE_DESTROY(var_hash); return FAILURE; } name = estrndup(p + 1, namelen); p += namelen + 1; if (zend_hash_find(&EG(symbol_table), name, namelen + 1, (void **) &tmp) == SUCCESS) { if ((Z_TYPE_PP(tmp) == IS_ARRAY && Z_ARRVAL_PP(tmp) == &EG(symbol_table)) || *tmp == PS(http_session_vars)) { efree(name); continue; } } if (has_value) { ALLOC_INIT_ZVAL(current); if (php_var_unserialize(&current, (const unsigned char **) &p, (const unsigned char *) endptr, &var_hash TSRMLS_CC)) { php_set_session_var(name, namelen, current, &var_hash TSRMLS_CC); } else { PHP_VAR_UNSERIALIZE_DESTROY(var_hash); return FAILURE; } var_push_dtor_no_addref(&var_hash, &current); } PS_ADD_VARL(name, namelen); efree(name); } PHP_VAR_UNSERIALIZE_DESTROY(var_hash); return SUCCESS; } /* }}} */ Example Vulnerable Function 2: ZEND_METHOD(CURLFile, __wakeup) { zend_update_property_string(curl_CURLFile_class, getThis(), "name", sizeof("name")-1, "" TSRMLS_CC); zend_throw_exception(NULL, "Unserialization of CURLFile instances is not allowed", 0 TSRMLS_CC); } Example Repair Patch 2: ZEND_METHOD(CURLFile, __wakeup) { zval *_this = getThis(); zend_unset_property(curl_CURLFile_class, _this, "name", sizeof("name")-1 TSRMLS_CC); zend_update_property_string(curl_CURLFile_class, _this, "name", sizeof("name")-1, "" TSRMLS_CC); zend_throw_exception(NULL, "Unserialization of CURLFile instances is not allowed", 0 TSRMLS_CC); } Generate repair patches for the following vulnerable func- tion. The vulnerability is CWE-416. The return format should strictly follow the Example Repair Patch provided above PHP_FUNCTION(snmp_set_enum_print)
    #       {
    # 	    long a1;

    # 	if (zend_parse_parameters(ZEND_NUM_ARGS() TSRMLS_CC, "l", &a1) == FAILURE) {
    # 		RETURN_FALSE;
    # 	}

    #         netsnmp_ds_set_boolean(NETSNMP_DS_LIBRARY_ID, NETSNMP_DS_LIB_PRINT_NUMERIC_ENUM, (int) a1);
    #         RETURN_TRUE;
    # }"""

    # Retrieve supporting text for the query, wrap both in the RAG prompt, and ask the model.
    context = get_relevant_context_from_db(query)
    prompt = generate_rag_prompt(query=query, context=context)
    answer = generate_answer(prompt=prompt)

    # Report the model's answer against the spreadsheet row it belongs to.
    print(f"Row {i} - Answer: {answer}")



Row 1 - Answer: 0
Row 2 - Answer: 0
Row 3 - Answer: 0
Row 4 - Answer: 0
Row 5 - Answer: 1
Row 6 - Answer: 0
Row 7 - Answer: 0
Row 8 - Answer: 0
Row 9 - Answer: 1
Row 10 - Answer: 0
