In [1]:
import ast
import astor
import astunparse
from astunparse.unparser import Unparser
import pandas as pd
import numpy as np
try:
    import re2 as re
except:
    import re
import bandit
import os
import subprocess
import json
import tempfile
import time
import io
import logging
from tqdm import tqdm

In [2]:
class ExceptPassTransformer(ast.NodeTransformer):
    """Turn ``except ...: pass`` handlers into handlers that re-raise.

    Only a handler whose whole body is a single ``pass`` is rewritten, to a bare
    ``raise``. When ``pass`` is followed by other statements, replacing it with
    ``raise`` would make those statements unreachable, so such handlers are left
    unchanged.
    """

    def visit_ExceptHandler(self, node):
        self.generic_visit(node)
        if len(node.body) == 1 and isinstance(node.body[0], ast.Pass):
            # A bare ``raise`` re-raises the exception currently being handled.
            node.body[0] = ast.copy_location(ast.Raise(exc=None, cause=None), node.body[0])
        return node


class EvalTransformer(ast.NodeTransformer):
    """Replace calls to the builtin ``eval(...)`` with ``ast.literal_eval(...)``.

    The callee is rebuilt as a real ``ast.literal_eval`` attribute access. The
    alternative, a ``Name`` whose id is the dotted string ``'ast.literal_eval'``,
    is not a valid identifier.

    NOTE(review): the rewritten source needs ``import ast``, which is not added
    here. ``literal_eval`` also takes a single argument, so ``eval(src, globals)``
    calls will raise TypeError after the rewrite.
    """

    def visit_Call(self, node):
        self.generic_visit(node)
        if isinstance(node.func, ast.Name) and node.func.id == 'eval':
            literal_eval = ast.Attribute(
                value=ast.Name(id='ast', ctx=ast.Load()),
                attr='literal_eval',
                ctx=ast.Load(),
            )
            node.func = ast.copy_location(literal_eval, node.func)
        return node


class AssertTransformer(ast.NodeTransformer):
    """Replace ``assert test[, msg]`` with ``if not test: raise AssertionError([msg])``.

    Unlike ``assert``, the generated check is not stripped under ``python -O``.
    """

    def visit_Assert(self, node):
        self.generic_visit(node)
        # ``Call.args`` must be a list; the optional message is its single element.
        exc = ast.Call(
            func=ast.Name(id='AssertionError', ctx=ast.Load()),
            args=[node.msg] if node.msg is not None else [],
            keywords=[],
        )
        new_node = ast.If(
            test=ast.UnaryOp(op=ast.Not(), operand=node.test),
            body=[ast.Raise(exc=exc, cause=None)],
            orelse=[],
        )
        return ast.copy_location(new_node, node)


class OpenTransformer(ast.NodeTransformer):
    """Wrap bare ``open(...)`` expression statements in a ``with`` block.

    ``open(path)`` used on its own as a statement becomes
    ``with open(path): pass``, so the file object is closed.

    ``open(...)`` calls in any other position (assignment, argument, an existing
    ``with``) are left unchanged, for two reasons:
    - a ``With`` is a statement and cannot stand in for an expression;
    - standard ``ast`` nodes carry no ``parent`` attribute to inspect.
    """

    @staticmethod
    def _is_open_call(node):
        # True for a direct call to the name ``open``.
        return (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Name)
            and node.func.id == "open"
        )

    def visit_Expr(self, node):
        self.generic_visit(node)
        if self._is_open_call(node.value):
            with_node = ast.With(
                items=[ast.withitem(context_expr=node.value, optional_vars=None)],
                body=[ast.Pass()],  # a ``with`` body may not be empty
                type_comment=None,
            )
            return ast.copy_location(with_node, node)
        return node

    def visit_Call(self, node):
        # Kept for interface compatibility. Expression-level open() calls are
        # not rewritten (see the class docstring).
        self.generic_visit(node)
        return node


class PathTraversalTransformer(ast.NodeTransformer):
    """Remove ``..`` sequences from string-literal arguments of os path helpers.

    Matches ``os.path.join(...)`` and ``os.path.abspath(...)``. In those calls
    ``func.value`` is the ``os.path`` attribute, not a bare name. Also matches
    ``os.path(...)``, ``os.join(...)`` and ``os.abspath(...)``, the forms the
    original check looked for.

    NOTE(review): only literals are sanitised; paths built from variables are not.
    """

    @staticmethod
    def _is_os_path_helper(func):
        # True when ``func`` is one of the helper spellings listed in the docstring.
        if not isinstance(func, ast.Attribute):
            return False
        owner = func.value
        if isinstance(owner, ast.Name):
            return owner.id == "os" and func.attr in ("path", "join", "abspath")
        return (
            isinstance(owner, ast.Attribute)
            and owner.attr == "path"
            and isinstance(owner.value, ast.Name)
            and owner.value.id == "os"
            and func.attr in ("join", "abspath")
        )

    def visit_Call(self, node):
        self.generic_visit(node)
        if self._is_os_path_helper(node.func):
            for i, arg in enumerate(node.args):
                if isinstance(arg, ast.Constant) and isinstance(arg.value, str) and ".." in arg.value:
                    # str.replace removes the same non-overlapping ".." runs as
                    # re.sub(r"\.\.", "", ...).
                    node.args[i] = ast.copy_location(
                        ast.Constant(value=arg.value.replace("..", "")), arg)
        return node


class SQLInjectionTransformer(ast.NodeTransformer):
    """Fold a ``'...SELECT' + 'FROM...'`` concatenation of two string literals.

    Both operands are literals, so the concatenation cannot carry untrusted
    input. Folding it into one literal keeps the runtime value identical while
    removing the string-built expression.

    The earlier rewrite read ``node.right.s`` after replacing ``node.right`` with
    a Tuple, so it raised AttributeError on every match.

    NOTE(review): SQL assembled from variables is not handled; that needs a
    DB-API-specific parameterisation.
    """

    def visit_BinOp(self, node):
        self.generic_visit(node)
        left, right = node.left, node.right
        if (
            isinstance(node.op, ast.Add)
            and isinstance(left, ast.Constant) and isinstance(left.value, str)
            and isinstance(right, ast.Constant) and isinstance(right.value, str)
            and left.value.endswith("SELECT")
            and right.value.startswith("FROM")
        ):
            return ast.copy_location(ast.Constant(value=left.value + right.value), node)
        return node


class XSSTransformer(ast.NodeTransformer):
    """Strip ``<``, ``>``, ``"`` and ``'`` from literal ``re.sub``/``re.subn`` replacements.

    Applies only when the first two positional arguments (pattern and
    replacement) are both string literals. Only the replacement is modified.

    NOTE(review): this changes the text the program produces and does not escape
    anything rendered as HTML; treat it as a heuristic.
    """

    def visit_Call(self, node):
        self.generic_visit(node)
        func = node.func
        if (
            isinstance(func, ast.Attribute)
            and isinstance(func.value, ast.Name)
            and func.value.id == "re"
            and func.attr in ("sub", "subn")
            and len(node.args) >= 2
            and all(isinstance(a, ast.Constant) and isinstance(a.value, str) for a in node.args[:2])
        ):
            replacement = node.args[1]
            node.args[1] = ast.copy_location(
                ast.Constant(value=re.sub(r"[<>\"']", "", replacement.value)), replacement)
        return node


class ConstantTransformer(ast.NodeTransformer):
    """Rebuild str/int/float (incl. bool) constants as plain ``ast.Constant`` nodes.

    The rebuilt node has no ``kind``, so e.g. a ``u'...'`` prefix is dropped,
    matching what the deprecated ``ast.Str``/``ast.Num`` constructors produced.
    Those constructors are removed in Python 3.14. Other constants are returned
    unchanged.
    """

    def visit_Constant(self, node):
        if isinstance(node.value, (str, int, float)):
            return ast.copy_location(ast.Constant(value=node.value), node)
        return node


class CustomTransformer(ExceptPassTransformer, EvalTransformer, AssertTransformer,
                        OpenTransformer, PathTraversalTransformer, SQLInjectionTransformer, XSSTransformer, ConstantTransformer):
    """Run every rewrite transformer above over a tree, one after another.

    With plain multiple inheritance, Python keeps only the first ``visit_<Node>``
    found in the MRO. EvalTransformer's ``visit_Call`` therefore shadowed the
    Open/PathTraversal/XSS ones, which never ran.

    The bases are kept so ``isinstance`` checks still hold. ``visit`` now applies
    each class in ``transformer_classes`` in order, each with a fresh instance.
    """

    transformer_classes = (
        ExceptPassTransformer, EvalTransformer, AssertTransformer, OpenTransformer,
        PathTraversalTransformer, SQLInjectionTransformer, XSSTransformer, ConstantTransformer,
    )

    def visit(self, node):
        for transformer_cls in self.transformer_classes:
            node = transformer_cls().visit(node)
        return node

def apply_transformers(code):
    """Rewrite ``code`` with CustomTransformer and return the regenerated source.

    Nodes created by the transformers get line/column information via
    ``ast.fix_missing_locations`` before being unparsed with ``custom_unparse``.
    Surrounding whitespace is stripped from the result.

    Raises SyntaxError when ``code`` cannot be parsed by ``ast.parse``.
    """
    rewritten_tree = CustomTransformer().visit(ast.parse(code))
    rewritten_tree = ast.fix_missing_locations(rewritten_tree)
    return custom_unparse(rewritten_tree).strip()

class CustomUnparser(Unparser):
    """astunparse ``Unparser`` whose ``_Call`` tolerates a non-list ``Call.args``.

    If a transformer stored a single expression in ``Call.args`` instead of a
    list, it is treated as a one-element argument list.
    """

    def _Call(self, t):
        # Emit the callee first. Without it every call is rendered as a bare
        # "(...)" and the function name is lost (``print(x)`` -> ``(x)``).
        self.dispatch(t.func)
        self.write("(")
        args = t.args if isinstance(t.args, list) else [t.args]
        # Positional arguments, then keyword arguments, comma-separated.
        for position, item in enumerate(list(args) + list(getattr(t, 'keywords', []))):
            if position:
                self.write(", ")
            self.dispatch(item)
        self.write(")")

def custom_unparse(tree):
    """Render ``tree`` as source text using CustomUnparser.

    The unparser writes into an in-memory buffer, whose contents are returned.
    """
    buffer = io.StringIO()
    CustomUnparser(tree, file=buffer)
    return buffer.getvalue()
# def apply_transformers(code):
#     # Parse the code into an AST
#     tree = ast.parse(code)

#     # Apply the combined transformer
#     tree = CustomTransformer().visit(tree)

#     # Unparse the modified AST back into a code string
#     fixed_code = astunparse.unparse(tree).strip()
#     return fixed_code

# class CustomTransformer(ast.NodeTransformer):
#     def __init__(self):
#         self.changes = []

#     def visit_Str(self, node):
#         new_node = ast.Constant(value=node.s, kind=None)
#         self.changes.append((node, new_node))
#         return new_node

In [3]:
def reduce_mem_usage(props):
    """Down-cast the numeric columns of ``props`` to smaller dtypes, in place.

    - Columns whose values are all (near-)whole numbers become the narrowest
      unsigned (non-negative data) or signed integer dtype that fits their range.
    - Other numeric columns become ``float32``.
    - Non-numeric columns are skipped.

    Integer dtypes cannot hold NaN. A column with non-finite values therefore has
    its missing values filled with ``min - 1`` first and is recorded in the
    returned list.

    Parameters
    ----------
    props : pandas.DataFrame
        Frame to shrink; its columns are reassigned in place.

    Returns
    -------
    tuple of (pandas.DataFrame, list)
        ``props`` itself and the names of the columns whose missing values were filled.
    """
    start_mem_usg = props.memory_usage().sum() / 1024**2
    print("Memory usage of properties dataframe is :",start_mem_usg," MB")
    NAlist = []  # Keeps track of columns that have missing values filled in.
    for col in props.columns:
        # Only numeric columns are down-cast; object/string, datetime and
        # categorical columns are left unchanged.
        if not pd.api.types.is_numeric_dtype(props[col]):
            continue

        # Integer does not support NA, therefore NA needs to be filled. Assign the
        # result back: a chained ``fillna(..., inplace=True)`` is not guaranteed
        # to reach the frame.
        if not np.isfinite(props[col]).all():
            NAlist.append(col)
            props[col] = props[col].fillna(props[col].min() - 1)

        values = props[col]
        # Computed after the fill so the ``min - 1`` sentinel is inside the range.
        # Otherwise a column with min 0 would be cast to unsigned and wrap the -1.
        mx = values.max()
        mn = values.min()

        # Integral only if the column is non-empty, every value is finite (inf
        # cannot be cast) and the total absolute fractional part is negligible.
        is_int = values.size > 0 and bool(np.isfinite(values).all())
        if is_int:
            is_int = (values - values.astype(np.int64)).abs().sum() < 0.01

        if is_int:
            if mn >= 0:
                if mx < 255:
                    props[col] = values.astype(np.uint8)
                elif mx < 65535:
                    props[col] = values.astype(np.uint16)
                elif mx < 4294967295:
                    props[col] = values.astype(np.uint32)
                else:
                    props[col] = values.astype(np.uint64)
            else:
                # Narrowest signed dtype whose range strictly contains [mn, mx].
                for int_type in (np.int8, np.int16, np.int32, np.int64):
                    info = np.iinfo(int_type)
                    if mn > info.min and mx < info.max:
                        props[col] = values.astype(int_type)
                        break
        else:
            # Make float datatypes 32 bit
            props[col] = values.astype(np.float32)

    # Print final result
    print("___MEMORY USAGE AFTER COMPLETION:___")
    mem_usg = props.memory_usage().sum() / 1024**2
    print("Memory usage is: ",mem_usg," MB")
    print("This is ",100*mem_usg/start_mem_usg,"% of the initial size")
    return props, NAlist

In [4]:
TRAIN_CSV = '/data/kiho/autocode/dataset/pep8_format/source/python100k_train_pep8_only3.csv'
EVAL_CSV = '/data/kiho/autocode/dataset/pep8_format/source/python50k_eval_pep8_only3.csv'
RESULT_COLUMNS = ['vulnerable_lines', 'fixed_code', 'after_fix_vulnerable_lines']

train_df = pd.read_csv(TRAIN_CSV, index_col=[0])
eval_df = pd.read_csv(EVAL_CSV, index_col=[0])

# Initialise the result columns filled by the bandit cells with the placeholder "0".
for split_df in (train_df, eval_df):
    split_df[RESULT_COLUMNS] = str(0)

In [5]:
# Down-cast numeric columns of both splits. The train split's filled-column list
# is kept under its own name instead of being overwritten; NAlist still ends up
# holding the eval split's list.
train_df, train_NAlist = reduce_mem_usage(train_df)
eval_df, NAlist = reduce_mem_usage(eval_df)

Memory usage of properties dataframe is : 3.814697265625  MB
___MEMORY USAGE AFTER COMPLETION:___
Memory usage is:  3.814697265625  MB
This is  100.0 % of the initial size
Memory usage of properties dataframe is : 1.9073486328125  MB
___MEMORY USAGE AFTER COMPLETION:___
Memory usage is:  1.9073486328125  MB
This is  100.0 % of the initial size


In [4]:
def run_bandit(code_str, output_file):
    """Scan ``code_str`` with the bandit CLI and return its parsed JSON report.

    The code is written to a temporary ``.py`` file (UTF-8) that bandit scans.
    The temporary file is always removed, even if launching bandit fails. The raw
    JSON report is also written to ``output_file``.

    Parameters
    ----------
    code_str : str
        Python source to scan.
    output_file : str
        Path that receives bandit's raw stdout.

    Returns
    -------
    dict
        ``json.loads`` of bandit's stdout.

    Raises
    ------
    json.JSONDecodeError
        If bandit printed no JSON (e.g. it crashed).
    FileNotFoundError
        If the ``bandit`` executable cannot be found.
    """
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, encoding="utf-8") as temp_file:
        temp_file.write(code_str)
        temp_file_name = temp_file.name

    try:
        # The return code is not checked: the report on stdout is what matters
        # (presumably bandit exits non-zero when it reports issues).
        result = subprocess.run(
            ["bandit", temp_file_name, "-f", "json"],
            capture_output=True,
            text=True,
        )
    finally:
        os.remove(temp_file_name)

    with open(output_file, "w") as f:
        f.write(result.stdout)

    return json.loads(result.stdout)

In [5]:
def get_vulnerable_lines(input_string):
    """Collect the line numbers of every finding in a bandit JSON report.

    Parameters
    ----------
    input_string : dict or str or bytes
        Either the parsed report (as returned by ``run_bandit``) or its raw JSON text.

    Returns
    -------
    list
        The ``line_range`` entries of all ``results``, concatenated in report order.
    """
    # Parse only when given raw JSON text. A dict is used directly; the previous
    # dumps/loads round-trip only copied it.
    if isinstance(input_string, (str, bytes, bytearray)):
        data = json.loads(input_string)
    else:
        data = input_string
    return [line for result in data['results'] for line in result['line_range']]

def fix_vulnerable_code(code_string):
    """Return ``code_string`` rewritten by the CustomTransformer pipeline.

    Delegates to ``apply_transformers``. The previous body spliced replacements
    from ``transformer.changes``. No visible CustomTransformer (or its bases)
    defines that attribute, so every call raised AttributeError.
    """
    return apply_transformers(code_string)

# def fix_vulnerable_code(code_string):
#     tree = ast.parse(code_string)
#     transformer = CustomTransformer()
#     fixed_tree = transformer.visit(tree)
#     fixed_code = astor.to_source(fixed_tree)
#     return fixed_code

# def fix_vulnerable_code(code_string):
#     tree = ast.parse(code_string)
#     transformer = CustomTransformer()
#     fixed_tree = transformer.visit(tree)
#     fixed_code = astunparse.unparse(fixed_tree)
#     return fixed_code

In [30]:
import concurrent.futures
from functools import partial


def process_single_entry(i, train_df, dir_path):
    """Scan row ``i`` with bandit and, if it has findings, scan the auto-fixed code too.

    Parameters
    ----------
    i : int
        Row label in ``train_df`` (the driver below passes ``range(len(...))``).
    train_df : pandas.DataFrame
        Any split with a ``text`` column; the eval cell passes ``eval_df`` here.
    dir_path : str
        Directory that receives the ``<i>.json`` / ``<i>_fixed.json`` bandit reports.

    Returns
    -------
    tuple of (int, dict)
        ``i`` and a mapping of column name -> new cell value for that row.
    """
    result = {}
    # Drop 13 characters from each end. The visible rows start with the
    # 13-character '<|endoftext|>' marker (presumably it also ends each row).
    target_text = train_df['text'][i][13:-13]
    report = run_bandit(target_text, os.path.join(dir_path, f"{i}.json"))
    vulnerable_lines = get_vulnerable_lines(report)
    if not vulnerable_lines:
        result['vulnerable_lines'] = str(0)
        return i, result

    result['vulnerable_lines'] = str(vulnerable_lines)
    fixed_code = apply_transformers(target_text)
    result['fixed_code'] = str(fixed_code)
    report = run_bandit(fixed_code, os.path.join(dir_path, f"{i}_fixed.json"))
    remaining_lines = get_vulnerable_lines(report)
    if remaining_lines:
        result['after_fix_vulnerable_lines'] = str(remaining_lines)
    return i, result


dir_path = "/data/kiho/secure_coding/bandit_output_train/"
num_threads = 8

with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
    process_func = partial(process_single_entry, train_df=train_df, dir_path=dir_path)
    results = list(tqdm(executor.map(process_func, range(len(train_df['text']))), total=len(train_df['text'])))

# Write the per-row results back. ``.at`` assigns into the frame itself. The
# chained ``train_df[key][i] = value`` goes through an intermediate Series and
# stops working once pandas copy-on-write is enabled.
for i, result in results:
    for key, value in result.items():
        train_df.at[i, key] = value

# train_df.to_csv('py_train_fixed_df.csv')

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100000/100000 [1:39:46<00:00, 16.71it/s]


In [31]:
import concurrent.futures
from functools import partial

# Reuse process_single_entry from the train cell above rather than redefining an
# identical copy here; its ``train_df`` keyword simply receives the eval split.
dir_path = "/data/kiho/secure_coding/bandit_output_eval/"
num_threads = 8

with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
    process_func = partial(process_single_entry, train_df=eval_df, dir_path=dir_path)
    results = list(tqdm(executor.map(process_func, range(len(eval_df['text']))), total=len(eval_df['text'])))

# ``.at`` writes into eval_df directly. Chained ``eval_df[key][i] = value`` is
# unreliable under pandas copy-on-write.
for i, result in results:
    for key, value in result.items():
        eval_df.at[i, key] = value

# eval_df.to_csv('py_eval_fixed_df.csv')

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50000/50000 [50:34<00:00, 16.48it/s]


In [None]:
# dir_path = "/data/kiho/secure_coding/bandit_output_eval/"
# for i in tqdm(range(len(eval_df['text'])), desc='fixing eval_data', mininterval=0.01):
#     target_text = eval_df['text'][i][13:-13]
#     output_file = dir_path + str(i) + ".json"
#     report = run_bandit(target_text, output_file)
#     vulnerable_lines = get_vulnerable_lines(report)
#     if vulnerable_lines != []:
#         eval_df['vulnerable_lines'][i] = str(vulnerable_lines)
#         fixed_code = fix_vulnerable_code(target_text)
#         eval_df['fixed_code'][i] = str(fixed_code)
#         output_file = dir_path + str(i) + "_fixed.json"
#         report = run_bandit(fixed_code, output_file)
#         vulnerable_lines = get_vulnerable_lines(report)
#         if vulnerable_lines != []:
#             train_df['after_fix_vulnerable_lines'][i] = str(vulnerable_lines)
#     else:
#         train_df['vulnerable_lines'][i] = str(0)
# train_df.to_csv('py_eval_fixed_df.csv')

### Generated DataFrame

In [6]:
# Reload the saved bandit/fix results for both splits (presumably produced by the
# commented-out to_csv calls in the bandit cells).
df_py_train, df_py_eval = (
    pd.read_csv(path, index_col=0)
    for path in ('/data/kiho/secure_coding/py_train_fixed_df.csv',
                 '/data/kiho/secure_coding/py_eval_fixed_df.csv')
)

In [7]:
# Inspect the reloaded train-split results (pandas truncates the rendered frame).
df_py_train

Unnamed: 0,text,vulnerable_lines,fixed_code,after_fix_vulnerable_lines
0,<|endoftext|>#!/usr/bin/env python\n# -*- codi...,0,0,0
1,<|endoftext|># -*- coding: utf-8 -*-\n# Open S...,[221],import msgpack\nimport gevent.pool\nimport gev...,0
2,"<|endoftext|>#!/usr/bin/env python\n""""""Django'...",0,0,0
3,"<|endoftext|>""""""Installer for hippybot\n""""""\n\...",0,0,0
4,<|endoftext|>#!/usr/bin/env python\nimport os\...,0,0,0
...,...,...,...,...
99995,<|endoftext|># vim: tabstop=4 shiftwidth=4 sof...,0,0,0
99996,<|endoftext|># vim: tabstop=4 shiftwidth=4 sof...,0,0,0
99997,<|endoftext|># vim: tabstop=4 shiftwidth=4 sof...,0,0,0
99998,<|endoftext|># vim: tabstop=4 shiftwidth=4 sof...,0,0,0


In [28]:
# Inspect the reloaded eval-split results (pandas truncates the rendered frame).
df_py_eval

Unnamed: 0,text,vulnerable_lines,fixed_code,after_fix_vulnerable_lines
0,<|endoftext|># vim: tabstop=4 shiftwidth=4 sof...,0,0,0
1,<|endoftext|># vim: tabstop=4 shiftwidth=4 sof...,0,0,0
2,<|endoftext|>#!/usr/bin/env python\n# vim: tab...,0,0,0
3,<|endoftext|># vim: tabstop=4 shiftwidth=4 sof...,0,0,0
4,<|endoftext|>from horizon import tables\nfrom ...,0,0,0
...,...,...,...,...
49995,<|endoftext|># The MIT License (MIT)\n#\n# Cop...,"[231, 235]",' Search Imgur for a random image '\nimport co...,0
49996,<|endoftext|># -*- coding: utf-8 -*-\n#\n# Fla...,0,0,0
49997,<|endoftext|># coding: utf-8\n\nimport base64\...,0,0,0
49998,"<|endoftext|># coding: utf-8\n""""""\n flask_s...",0,0,0


In [29]:
# Number of distinct after-fix values (NaN counted as a value, like len(unique())).
df_py_eval['after_fix_vulnerable_lines'].nunique(dropna=False)

844

### Semgrep

In [249]:
import yaml
import json
import os
import tempfile
# from semgrep import semgrep_main
# from semgrep.output import OutputHandler

In [379]:
# Semgrep rules with python: paths of rule YAML files from
# https://github.com/returntocorp/semgrep-rules, one per line. Blank lines are
# skipped so they do not turn into open('') calls below.
with open("/data/kiho/secure_coding/python_yamls_path.txt", 'r') as f:
    yaml_lst = [line.strip() for line in f if line.strip()]
# List of input YAML files to merge
input_files = yaml_lst

# Load and merge the 'rules' lists of all input files.
combined_rules = {'rules': []}
for file in input_files:
    with open(file, 'r') as f:
        content = yaml.safe_load(f)
    # safe_load returns None for an empty document; only a mapping can hold 'rules'.
    if isinstance(content, dict) and 'rules' in content:
        combined_rules['rules'].extend(content['rules'])

# Save the combined content to a new YAML file
with open('combined_rules.yml', 'w') as output_file:
    yaml.dump(combined_rules, output_file)

In [378]:
def run_semgrep(code_str, output_file):
    """Scan ``code_str`` with the semgrep CLI using several registry rule packs.

    The code is written to a temporary ``.py`` file (UTF-8). The file is always
    removed, even if launching semgrep fails.

    Parameters
    ----------
    code_str : str
        Python source to scan.
    output_file : object
        Accepted for signature parity with ``run_bandit``; currently unused.

    Returns
    -------
    dict
        ``json.loads`` of semgrep's stdout.
    """
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, encoding="utf-8") as temp_file:
        temp_file.write(code_str)
        temp_file_name = temp_file.name

    config_ids = [
        "p/default",
        "p/comment",
        "p/cwe-top-25",
        "p/owasp-top-ten",
        "p/r2c-security-audit",
        "p/bandit"
    ]
    # Build an argv list instead of a shell string: no shell parsing of the temp
    # path, and each rule pack becomes its own ``--config <id>`` pair.
    semgrep_cmd = ["semgrep", "--verbose"]
    for config in config_ids:
        semgrep_cmd += ["--config", config]
    semgrep_cmd += ["--json", temp_file_name]

    try:
        results = subprocess.run(semgrep_cmd, capture_output=True, encoding='utf-8')
    finally:
        os.remove(temp_file_name)

    return json.loads(results.stdout)

In [381]:
# Scan one explicit sample. `target_text` is not defined by any module-level cell
# (it is local to process_single_entry), and `output_file` at this point is the
# closed handle from the rules-merging cell. run_semgrep does not use its
# output_file argument, so None is passed.
semgrep_sample = df_py_eval['text'][0][13:-13]
report = run_semgrep(semgrep_sample, None)
report

{'errors': [],
 'paths': {'scanned': ['/tmp/tmptw1xrh94.py'], 'skipped': []},
 'results': [],
 'version': '1.20.0'}