## Load dataset

In [2]:
import json

# Read the whole cleaned MSR dump into memory; later cells read records from `data`.
# JSON text is UTF-8, so pin the encoding instead of relying on the platform
# default (which differs between operating systems/locales).
with open("MSR_data_cleaned.json", "r", encoding="utf-8") as file:
    data = json.load(file)

In [8]:
# Peek at the first record without copying every record into a temporary list.
print(next(iter(data.values())))

{'': '0', 'Access Gained': 'None', 'Attack Origin': 'Remote', 'Authentication Required': 'Single system', 'Availability': 'Partial', 'CVE ID': 'CVE-2015-8467', 'CVE Page': 'https://www.cvedetails.com/cve/CVE-2015-8467/', 'CWE ID': 'CWE-264', 'Complexity': 'Medium', 'Confidentiality': 'Partial', 'Integrity': 'Partial', 'Known Exploits': '', 'Publish Date': '2015-12-29', 'Score': '6.0', 'Summary': 'The samldb_check_user_account_control_acl function in dsdb/samdb/ldb_modules/samldb.c in Samba 4.x before 4.1.22, 4.2.x before 4.2.7, and 4.3.x before 4.3.3 does not properly check for administrative privileges during creation of machine accounts, which allows remote authenticated users to bypass intended access restrictions by leveraging the existence of a domain with both a Samba DC and a Windows DC, a similar issue to CVE-2015-2535.', 'Update Date': '2016-12-30', 'Vulnerability Classification': 'Bypass', 'add_lines': '0', 'codeLink': 'https://git.samba.org/?p=samba.git;a=commit;h=b000da128b

In [None]:
# Record fields used for fine-tuning (kept as a note: evaluating the bare names
# raised NameError because no such variables exist):
#   func_before        - read as the model input by tokenize_function
#   vul_func_with_fix  - read as the training target by tokenize_function

In [4]:
import os
import csv
import sys
from tree_sitter import Language, Parser
from subprocess import check_output
from tree_sitter_cpp import language

# Wrap the grammar handle returned by tree_sitter_cpp.language() in a
# tree_sitter Language object; parse_cpp() hands this to Parser.
CPP_LANGUAGE = Language(language())

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

def ast_to_graph(ast):
    """Flatten a syntax tree into node, edge and feature lists.

    Nodes are numbered in pre-order (a parent before its children, children
    left to right); a node's id is its index in the returned ``nodes`` list.

    Args:
        ast: Object exposing ``root_node``; every node must expose ``type``
            and ``children``.

    Returns:
        A tuple ``(nodes, edges, node_features)`` where ``nodes`` holds the
        node objects, ``edges`` holds ``(parent_id, child_id)`` pairs in the
        order the children were visited, and ``node_features[i]`` is the
        ``type`` of ``nodes[i]``.
    """
    nodes = []
    edges = []
    node_features = []

    # An explicit stack replaces recursion so deeply nested trees cannot hit
    # Python's recursion limit.
    pending = [(ast.root_node, None)]
    while pending:
        node, parent_id = pending.pop()
        node_id = len(nodes)
        nodes.append(node)
        node_features.append(node.type)  # Node type is the only feature

        if parent_id is not None:
            edges.append((parent_id, node_id))  # Parent-child edge

        # Push children right-to-left so the leftmost child is popped first,
        # which reproduces the pre-order numbering of a recursive walk.
        pending.extend((child, node_id) for child in reversed(node.children))

    return nodes, edges, node_features

def parse_cpp(code):
    """Parse source text with the C++ tree-sitter grammar.

    Args:
        code: Source code as a ``str``; it is encoded as UTF-8 before parsing.

    Returns:
        The tree produced by ``Parser.parse`` for the encoded source.
    """
    # A fresh parser bound to the module-level C++ grammar is built per call.
    cpp_parser = Parser(CPP_LANGUAGE)
    source_bytes = bytes(code, encoding='utf8')
    return cpp_parser.parse(source_bytes)


In [None]:
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer

# Build the dataset from the records already loaded into `data`.
# `data` is a dict whose values are the individual records (see the
# `data.values()` preview above), so each value becomes one row. Only the two
# text fields that tokenize_function reads are kept.
# NOTE(review): load_dataset("json", data_files=...) was used here before; a
# single top-level JSON object keyed by row id does not look like the
# records / JSON-lines layout it turns into one row per example - confirm
# before switching back.
# The 80/20 train/validation split is seeded so re-running the notebook
# reproduces the same partition.
dataset = Dataset.from_list(
    [
        {
            "func_before": record["func_before"],
            "vul_func_with_fix": record["vul_func_with_fix"],
        }
        for record in data.values()
    ]
).train_test_split(test_size=0.2, seed=42)

# Initialize the tokenizer for Llama
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # Replace with your desired Llama model

# The Llama tokenizer defines no padding token, and padding to a fixed length
# fails without one; reuse EOS for padding.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Tokenize the data
def tokenize_function(example, max_length=512):
    """Convert vulnerable/fixed function pairs into causal-LM training features.

    Accepts either one example (string fields) or the dict-of-lists batch that
    ``Dataset.map(..., batched=True)`` passes in; the return value has the same
    single/batched shape as the input.

    Each sequence is ``prompt tokens + fixed-code tokens + EOS``, right-padded
    to ``max_length``. ``labels`` is aligned position-by-position with
    ``input_ids``: prompt and padding positions are set to -100 (ignored by
    the loss), so only the fixed code is learned.

    Args:
        example: Mapping with ``"func_before"`` and ``"vul_func_with_fix"``.
        max_length: Total number of tokens per returned sequence.

    Returns:
        Dict with ``"input_ids"``, ``"attention_mask"`` and ``"labels"``.
    """
    sources = example["func_before"]
    targets = example["vul_func_with_fix"]
    single = isinstance(sources, str)
    if single:
        sources, targets = [sources], [targets]

    eos_id = tokenizer.eos_token_id
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        # Padding is done by hand below, so a tokenizer without a pad token
        # still works; padded positions are masked out anyway.
        pad_id = eos_id if eos_id is not None else 0

    features = {"input_ids": [], "attention_mask": [], "labels": []}
    for source, target in zip(sources, targets):
        prompt_ids = list(
            tokenizer(
                f"Fix the vulnerability: {source}",
                truncation=True,
                max_length=max_length,
            )["input_ids"]
        )
        # No special tokens here: the target continues the prompt sequence.
        target_ids = list(
            tokenizer(
                target,
                add_special_tokens=False,
                truncation=True,
                max_length=max_length,
            )["input_ids"]
        )
        if eos_id is not None:
            target_ids.append(eos_id)  # teach the model where the fix ends

        # Trim the prompt only as much as needed, but never let it take more
        # than half the window when the target also needs the space, so every
        # example keeps some supervised tokens.
        prompt_budget = max(max_length - len(target_ids), max_length // 2)
        prompt_ids = prompt_ids[:prompt_budget]
        target_ids = target_ids[: max_length - len(prompt_ids)]

        used = len(prompt_ids) + len(target_ids)
        padding = max_length - used
        features["input_ids"].append(prompt_ids + target_ids + [pad_id] * padding)
        features["attention_mask"].append([1] * used + [0] * padding)
        features["labels"].append(
            [-100] * len(prompt_ids) + target_ids + [-100] * padding
        )

    if single:
        return {key: rows[0] for key, rows in features.items()}
    return features

# Run tokenize_function over every split, one batch of rows per call.
tokenized_dataset = dataset.map(function=tokenize_function, batched=True)

# Expose only the model inputs, returned as PyTorch tensors.
tokenized_dataset.set_format(
    "torch",
    columns=["input_ids", "attention_mask", "labels"],
)

In [None]:
import torch
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

# Base checkpoint that gets fine-tuned on the vulnerability-fix pairs.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

# Hyper-parameters, grouped by concern.
training_args = TrainingArguments(
    # where checkpoints and logs go
    output_dir="./llama-vuln-fixer",
    logging_dir="./logs",
    logging_steps=10,
    save_total_limit=2,
    # optimisation
    num_train_epochs=3,
    learning_rate=5e-5,
    weight_decay=0.01,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    # evaluate on the held-out split after every epoch
    # NOTE(review): newer transformers releases rename this argument to
    # `eval_strategy` - confirm against the installed version.
    evaluation_strategy="epoch",
)

# Wire model, arguments and the two tokenized splits together.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
)

# Start fine-tuning.
trainer.train()