In [426]:
import pandas as pd
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from dotenv import dotenv_values
import sys
sys.path.insert(0,'/workspaces/RAG_secure_code_generation/src')
from utils.utils import load_yaml, init_argument_parser, sanitize_output, fill_default_parameters
from langchain.prompts import (
    ChatPromptTemplate, PromptTemplate
)
from utils.openai_utils import is_openai_model, build_chat_model
from langchain.chat_models import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
import random
import numpy as np
from functools import partial
from typing import List


In [427]:
seed = 156
np.random.seed(seed)
random.seed(seed)

In [428]:
env = dotenv_values()


In [429]:
samples_to_extract = 4

In [430]:
template_file = "../data/templates/complete_function_readable.yaml"
template_examples_file = "../data/example_templates/detect_xss_simple_prompt.txt"
task_file = "../data/tasks/detect_xss_simple_prompt.txt"
parameters_file = "../data/prompt_parameters/empty.yaml"
train_set_file = "../data/train.csv"
model_name = "gpt-3.5-turbo-0613"

In [431]:
template = load_yaml(template_file)
    # load parameters
prompt_parameters = load_yaml(parameters_file)

#read txt containing the task
with open(task_file) as f:
    prompt_parameters["input"] = f.read()
with open(template_examples_file) as f:
    example_template = f.read()
prompt_parameters = fill_default_parameters(prompt_parameters, template["default_parameters"])
use_openai_api = is_openai_model(model_name)
openai_key = env['OPENAI_API_KEY']
model = ChatOpenAI(temperature=1, openai_api_key=openai_key, model=model_name)

In [432]:
prompt = ChatPromptTemplate.from_messages([("system", template["input"]), ("human", "{input}")])
print(prompt.format(**prompt_parameters))


System: The user will provide the initial part of the function (function name, parameters with types, return type and a comment describing the purpose of the function, with some optional example. 
You are a coding assistant that writes some python code to complete the user's function. Be sure that the code is syntactically correct and that it returns the correct type. 
The lenght of the code should short, readable, and without redundant checks on the paylaod. 
Return only python code in Markdown format, e.g.:
```python
....
```
Human: def detect_xss(http_get_request: str)->bool: 
""" Check if in the given http_get_request there is an XSS exploit, considering also the possible evasions that an attacker can perform.""" 


In [433]:
trainset = pd.read_csv(train_set_file)

In [434]:
def sample_examples(df:pd.DataFrame,
                    samples_per_class:int = 2, 
                    label_column:str = "Class",
                    shuffle = False,
                    format_function = lambda x: x):
    #extract different labels using label_column
    labels = df[label_column].unique()
    examples = pd.DataFrame()
    for l in labels:
        #extract samples_per_class examples for each label
        sample = df[df[label_column] == l].sample(samples_per_class)
        #format the examples

        #add the examples to the examples dataframe
        print(sample)
        examples = pd.concat([examples, sample], ignore_index=True)
    if shuffle:
        examples = examples.sample(frac=1)
    
    #format examples
    examples = format_function(examples)

    return examples

In [435]:
def humaneval_style_format(examples:pd.Series,
                    template:str,
                    label_column:str = "Class",
                    payload_column:str = "Payloads",
                    mappig:dict = {"Benign": False, "Malicious": True})-> List[str]:
    formatted_examples = []
    for _ , row in examples.iterrows():
        formatted_examples.append(template.format(input = row[payload_column], output = mappig[row[label_column]]))
    return formatted_examples
   

In [436]:
examples = sample_examples(trainset, samples_per_class=samples_to_extract, shuffle=True)
examples

                                               Payloads      Class
2877  https://expand.snowballeffect.net:8442/plesk/u...  Malicious
6306  http://www.aromasin.com/content/news.jsp?setsh...  Malicious
3748  http://www.talonline.ca/searchalberta/search.j...  Malicious
6108  http://rules.nyse.com/nysetools/tocchapter.asp...  Malicious
                                                Payloads   Class
7762   http://www.wikihow.com/purchase-drafting-suppl...  Benign
12707  http://localhost:8080/tienda1/miembros/editar....  Benign
11988  http://localhost:8080/tienda1/publico/pagar.js...  Benign
9888   http://www.wikihow.com/preserve-your-paint-wit...  Benign


Unnamed: 0,Payloads,Class
3,http://rules.nyse.com/nysetools/tocchapter.asp...,Malicious
0,https://expand.snowballeffect.net:8442/plesk/u...,Malicious
4,http://www.wikihow.com/purchase-drafting-suppl...,Benign
2,http://www.talonline.ca/searchalberta/search.j...,Malicious
6,http://localhost:8080/tienda1/publico/pagar.js...,Benign
5,http://localhost:8080/tienda1/miembros/editar....,Benign
7,http://www.wikihow.com/preserve-your-paint-wit...,Benign
1,http://www.aromasin.com/content/news.jsp?setsh...,Malicious


In [437]:
mapping_dict = {"Benign": False, "Malicious": True}
partial_format_function = partial(humaneval_style_format, label_column="Class", payload_column="Payloads", template =  example_template, mappig = mapping_dict)
examples = sample_examples(trainset, samples_per_class=samples_to_extract, shuffle=True, format_function=partial_format_function)
examples

                                               Payloads      Class
3798  http://www.theatronhometheater.com/index.php?p...  Malicious
1744  http://www.artcreationselite.com/login.asp?fol...  Malicious
1084  http://www2.camarapoa.rs.gov.br/default.php?p_...  Malicious
1781  http://www.devrekanadolulisesi.k12.tr/for/logi...  Malicious
                                                Payloads   Class
7366   http://localhost:8080/tienda1/publico/pagar.js...  Benign
13519  http://www.wikihow.com/make-a-dessert-taco&t=1...  Benign
11375  http://www.wikihow.com/make-a-peep-cake&t=1396...  Benign
9172   http://www.wikihow.com/look-korean/ulzzang/asi...  Benign


['>>> detect_xss(http://localhost:8080/tienda1/publico/pagar.jsp?modo=insertar&precio=1611&b1=pasar+por+caja)\nFalse\n',
 '>>> detect_xss(http://www.devrekanadolulisesi.k12.tr/for/login.asp?folder=&quot;&gt;&lt;script&gt;alert(document.cookie)&lt;/script&gt;<br>&gt;&lt;marquee&gt;&lt;h1&gt;by+narcoticxs&lt;/h1&gt;&lt;/marquee&gt;)\nTrue\n',
 ">>> detect_xss(http://www2.camarapoa.rs.gov.br/default.php?p_secao=%22%3e'%3e%3cscript%20src=http://vuln.xssed.net/<br>thirdparty/scripts/python5.js%3e%3c/script%3e)\nTrue\n",
 '>>> detect_xss(http://www.theatronhometheater.com/index.php?page=&quot;&gt;&lt;/iframe&gt;&lt;script&gt;alert(1);&lt;/script&gt;&amp;pagetitle=&lt;/<br>title&gt;&lt;script&gt;alert(1);&lt;/script&gt;)\nTrue\n',
 '>>> detect_xss(http://www.wikihow.com/make-a-peep-cake&t=1396546410095&n=1703252&k=mainentity)\nFalse\n',
 '>>> detect_xss(http://www.artcreationselite.com/login.asp?folder=&quot;&gt;&lt;script&gt;alert(document.cookie)&lt;/script&gt;&gt;&lt;marquee<br>&gt;&lt;h1&

In [438]:
prompt_parameters["input"] += "\n"
for example in examples:
    prompt_parameters["input"] += example

In [439]:
print(prompt.format(**prompt_parameters))


System: The user will provide the initial part of the function (function name, parameters with types, return type and a comment describing the purpose of the function, with some optional example. 
You are a coding assistant that writes some python code to complete the user's function. Be sure that the code is syntactically correct and that it returns the correct type. 
The lenght of the code should short, readable, and without redundant checks on the paylaod. 
Return only python code in Markdown format, e.g.:
```python
....
```
Human: def detect_xss(http_get_request: str)->bool: 
""" Check if in the given http_get_request there is an XSS exploit, considering also the possible evasions that an attacker can perform.""" 
>>> detect_xss(http://localhost:8080/tienda1/publico/pagar.jsp?modo=insertar&precio=1611&b1=pasar+por+caja)
False
>>> detect_xss(http://www.devrekanadolulisesi.k12.tr/for/login.asp?folder=&quot;&gt;&lt;script&gt;alert(document.cookie)&lt;/script&gt;<br>&gt;&lt;marquee&gt;

In [440]:
chain = prompt | model | StrOutputParser() | sanitize_output